├── LICENSE ├── ModelNet ├── ModelNet.py ├── ModelNetDataSet.py ├── ModelNetEval.py ├── ModelNetNormals.py └── ModelNetNormalsEval.py ├── README.md ├── ScanNet ├── ScanNet.py ├── ScanNetDataSet.py ├── ScanNetEval.py ├── genScanNetData.py └── ply_reader.py ├── ShapeNet ├── ShapeNet.py ├── ShapeNetDataSet.py └── ShapeNetEval.py ├── models ├── MCClass.py ├── MCClassH.py ├── MCClassS.py ├── MCNorm.py ├── MCNormS.py ├── MCSeg.py └── MCSegScanNet.py ├── teaser └── Teaser.png ├── tf_ops ├── MCConvModuleSrc ├── aabb_gpu.cc ├── aabb_gpu.cu ├── compute_pdf.cc ├── compute_pdf.cu ├── cuda_kernel_utils.h ├── find_neighbors.cc ├── find_neighbors.cu ├── genCompileScript.py ├── poisson_sampling.cc ├── poisson_sampling.cu ├── sort_gpu.cc ├── sort_gpu.cu ├── spatial_conv.cc └── spatial_conv.cu └── utils ├── DataSet.py ├── GenerateSphereMeshes.py ├── MCConvBuilder.py ├── MCNetworkUtils.py └── PyUtils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MCCNN: Monte Carlo Convolution for Learning on Non-Uniformly Sampled Point Clouds 2 | 3 | Copyright (c) 2018, Visual Computing group form Ulm University, Germany 4 | 5 | The MIT License (MIT) 6 | 7 | Copyright (c) 2018 Pedro Hermosilla 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy 10 | of this software and associated documentation files (the "Software"), to deal 11 | in the Software without restriction, including without limitation the rights 12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included in all 17 | copies or substantial portions of the Software. 
18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | SOFTWARE. 26 | -------------------------------------------------------------------------------- /ModelNet/ModelNetDataSet.py: -------------------------------------------------------------------------------- 1 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 2 | \file ModelNetDataSet.py 3 | 4 | \brief ModelNet dataset class. 5 | 6 | \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 7 | Germany. See the LICENSE file at the top-level directory of 8 | this distribution. 9 | 10 | \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 11 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 12 | 13 | import sys 14 | import os 15 | import math 16 | import time 17 | import numpy as np 18 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 19 | ROOT_DIR = os.path.dirname(BASE_DIR) 20 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 21 | from DataSet import DataSet 22 | 23 | class ModelNetDataSet(DataSet): 24 | """ModelNet dataset. 25 | 26 | Attributes: 27 | useNormalsAsLabels_ (bool): Boolean that indicates if the normals will be used as the destination 28 | labels per each point. 29 | useNormalsAsFeatures_ (bool): Boolean that indicates if the normals will be used as the input features. 30 | maxStoredNumPoints_ (int): Maximum number of points stored per model. 31 | catNames_ (string array): Name of the categories in the dataset. 
32 | """ 33 | 34 | def __init__(self, train, numPoints, ptDropOut, maxStoredNumPoints, batchSize, 35 | allowedSamplings=[0], augment=False, useNormalsAsLabels=False, 36 | useNormalsAsFeatures=False, folder="data", seed=None): 37 | """Constructor. 38 | 39 | Args: 40 | train (bool): Boolean that indicates if this is the train or test dataset. 41 | numPoints (int): Number of points that will be sampled from each model. If 0, all the 42 | points of each model are used. 43 | ptDropOut (float): Probability to keep a point during uniform sampling when all the points 44 | or only the first n number of points are selected. 45 | maxStoredNumPoints (int): Maximum number of points stored per model. 46 | batchSize (int): Size of the batch used. 47 | allowedSamplings (array of ints): Each element of the array determines an allowed sampling protocol 48 | that will be used to sample the different models. The implemented sampling protocols are: 49 | - 0: Uniform sampling 50 | - 1: Split sampling 51 | - 2: Gradient sampling 52 | - 3: Lambert sampling 53 | - 4: Occlusion sampling 54 | augment (bool): Boolean that indicates if data augmentation will be used in the models. 55 | useNormalsAsLabels (bool): Boolean that indicates if the normals will be used as the destination 56 | labels per each point. 57 | useNormalsAsFeatures (bool): Boolean that indicates if the normals will be used as the input features. 58 | folder (int): Folder in which the data is stored. 59 | seed (int): Seed used to initialize the random number generator. If None is provided instead, the current 60 | time on the machine will be used to initialize the number generator. 61 | """ 62 | 63 | # Store the parameters of the class. 64 | self.useNormalsAsLabels_ = useNormalsAsLabels 65 | self.useNormalsAsFeatures_ = useNormalsAsFeatures 66 | self.maxStoredNumPoints_ = maxStoredNumPoints 67 | 68 | # Create the list of labels that need to be augmented. 
69 | augmentedLabels = [] 70 | if useNormalsAsLabels: 71 | augmentedLabels = [0] 72 | 73 | # Create the list of features that need to be augmented. 74 | augmentedFeatures = [] 75 | if useNormalsAsFeatures: 76 | augmentedFeatures = [0] 77 | 78 | # Call the constructor of the parent class. 79 | super(ModelNetDataSet,self).__init__(numPoints, ptDropOut, useNormalsAsFeatures, 80 | useNormalsAsLabels, True, not(useNormalsAsLabels), False, batchSize, 81 | allowedSamplings, 100000000, 0, augment, 1, True, True, augmentedFeatures, 82 | augmentedLabels, seed) 83 | 84 | # Get the category names. 85 | self.catNames_ =[] 86 | with open(folder+"/modelnet40_shape_names.txt", 'r') as nameFile: 87 | for line in nameFile: 88 | self.catNames_.append(line.replace("\n","")) 89 | 90 | # List of files 91 | fileList = folder+"/modelnet40_test.txt" 92 | if train: 93 | fileList = folder+"/modelnet40_train.txt" 94 | with open(fileList, 'r') as nameFile: 95 | for line in nameFile: 96 | catId = -1 97 | for i in range(len(self.catNames_)): 98 | if self.catNames_[i] in line: 99 | catId = i 100 | break 101 | if catId >= 0: 102 | self.fileList_.append(folder+"/"+self.catNames_[catId]+"/"+line.replace("\n","")+".txt") 103 | self.categories_.append(catId) 104 | self.numPts_.append(self.maxStoredNumPoints_) 105 | 106 | 107 | def get_categories(self): 108 | """Method to get the list of categories. 109 | 110 | Returns: 111 | pts (n np.array string): List of categories. 112 | """ 113 | return self.catNames_ 114 | 115 | 116 | def _load_model_from_disk_(self, modelPath): 117 | """Abstract method that should be implemented by child class which loads a model 118 | from disk. 119 | 120 | Args: 121 | modelPath (string): Path to the model that needs to be loaded. 122 | 123 | Returns: 124 | pts (nx3 np.array): List of points. 125 | normals (nx3 np.array): List of normals. If the dataset does not contain 126 | normals, None should be returned. 127 | features (nxm np.array): List of features. 
If the dataset does not contain 128 | features, None should be returned. 129 | labels (nxl np.array): List of labels. If the dataset does not contain 130 | labels, None should be returned. 131 | """ 132 | 133 | fileDataArray = [] 134 | with open(modelPath, 'r') as modelFile: 135 | it = 0 136 | for line in modelFile: 137 | if it < self.maxStoredNumPoints_: 138 | line = line.replace("\n", "") 139 | currPoint = line.split(',') 140 | fileDataArray.append([float(currPoint[0]), float(currPoint[1]), 141 | float(currPoint[2]), float(currPoint[3]), float(currPoint[4]), 142 | float(currPoint[5])]) 143 | it+=1 144 | else: 145 | break 146 | fileData = np.array(fileDataArray) 147 | 148 | pts = fileData[:,0:3] 149 | normals = fileData[:,3:6] 150 | features = None 151 | if self.useNormalsAsFeatures_: 152 | features = normals 153 | labels = None 154 | if self.useNormalsAsLabels_: 155 | labels = normals 156 | 157 | return pts, normals, features, labels -------------------------------------------------------------------------------- /ModelNet/ModelNetEval.py: -------------------------------------------------------------------------------- 1 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 2 | \file ModelNetEval.py 3 | 4 | \brief Code to evaluate a classification network on the ModelNet40 5 | dataset. 6 | 7 | \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 8 | Germany. See the LICENSE file at the top-level directory of 9 | this distribution. 
10 | 11 | \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 12 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 13 | 14 | import sys 15 | import math 16 | import time 17 | import argparse 18 | import copy 19 | import random 20 | import importlib 21 | import os 22 | from os import listdir 23 | from os.path import isfile, isdir, join 24 | import numpy as np 25 | import tensorflow as tf 26 | from tensorflow.python import debug as tf_debug 27 | 28 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 29 | ROOT_DIR = os.path.dirname(BASE_DIR) 30 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 31 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 32 | 33 | from PyUtils import visualize_progress 34 | from ModelNetDataSet import ModelNetDataSet 35 | 36 | current_milli_time = lambda: time.time() * 1000.0 # Wall-clock time in milliseconds (float). 37 | 38 | if __name__ == '__main__': 39 | 40 | parser = argparse.ArgumentParser(description='Evaluation of classification networks (ModelNet40)') 41 | parser.add_argument('--inTrainedModel', default='log/model.ckpt', help='Input trained model (default: log/model.ckpt)') 42 | parser.add_argument('--model', default='MCClass', help='model (default: MCClass)') 43 | parser.add_argument('--grow', default=32, type=int, help='Grow rate (default: 32)') 44 | parser.add_argument('--nPoints', default=1024, type=int, help='Number of points (default: 1024)') 45 | parser.add_argument('--nExec', default=1, type=int, help='Number of executions per model (default: 1)') 46 | parser.add_argument('--gpu', default='0', help='GPU (default: 0)') 47 | parser.add_argument('--gpuMem', default=0.5, type=float, help='GPU memory used (default: 0.5)') 48 | args = parser.parse_args() 49 | 50 | print("Trained model: "+args.inTrainedModel) 51 | print("Model: "+args.model) 52 | print("Grow: "+str(args.grow)) 53 | print("nPoints: "+str(args.nPoints)) 54 | print("nExec: "+str(args.nExec)) 55 | 56 | #Load the model 57 | model =
importlib.import_module(args.model) 58 | 59 | #Get the test dataset (at most 5000 stored points per model, batch size nExec, uniform sampling) 60 | mTestDataSet = ModelNetDataSet(False, args.nPoints, 1.0, 5000, 61 | args.nExec, [0], False) 62 | numTestModels = mTestDataSet.get_num_models() 63 | categories = mTestDataSet.get_categories() 64 | print(categories) 65 | print("Test models: " + str(numTestModels)) 66 | 67 | #Create variable and place holders (inLabels is fed below but is not an input of the network) 68 | inPts = tf.placeholder(tf.float32, [None, 3]) 69 | inBatchIds = tf.placeholder(tf.int32, [None, 1]) 70 | inFeatures = tf.placeholder(tf.float32, [None, 1]) 71 | inLabels = tf.placeholder(tf.int32, [None]) 72 | isTraining = tf.placeholder(tf.bool, shape=()) 73 | keepProbConv = tf.placeholder(tf.float32) 74 | keepProbFull = tf.placeholder(tf.float32) 75 | 76 | #Create the network 77 | logits = model.create_network(inPts, inBatchIds, inFeatures, 1, args.nExec, args.grow, 78 | len(categories), isTraining, keepProbFull, keepProbConv, False, False) 79 | 80 | #Create init variables 81 | init = tf.global_variables_initializer() 82 | initLocal = tf.local_variables_initializer() 83 | 84 | #create the saver 85 | saver = tf.train.Saver() 86 | 87 | #Create session 88 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpuMem, visible_device_list=args.gpu) 89 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 90 | 91 | #Init variables 92 | sess.run(init) 93 | sess.run(initLocal) 94 | 95 | #Restore the model 96 | saver.restore(sess, args.inTrainedModel) 97 | 98 | #Test the dataset once with each of the 5 sampling protocols (0-4). 99 | titleTexts = [ 100 | "Uniform sampling", 101 | "Non-uniform split", 102 | "Non-uniform gradient", 103 | "Non-uniform lambert", 104 | "Non-uniform occlusion"] 105 | for samp in range(5): 106 | #Print the type of sampling used. 107 | print(titleTexts[samp]) 108 | #Update the dataset. 
109 | allowedSamplingsTest = [samp] 110 | mTestDataSet.set_allowed_samplings(allowedSamplingsTest) 111 | mTestDataSet.start_iteration() 112 | #Create the auxiliar variables.
113 | i = 0 114 | accumTime = 0.0 115 | totalAccuracy = 0.0 116 | #Iterate over the models. 117 | while mTestDataSet.has_more_batches(): 118 | #Get the batch dataset. 119 | _, points, batchIds, features, _, labels, _ = mTestDataSet.get_next_batch(True) 120 | 121 | #Compute the predicted logits. 122 | startTimeMeasure = current_milli_time() 123 | logitsRes = sess.run(logits, 124 | {inPts: points, inBatchIds: batchIds, inFeatures: features, inLabels: labels, 125 | isTraining: False, keepProbConv: 1.0, keepProbFull: 1.0}) 126 | endTimeMeasure = current_milli_time() 127 | accumTime = accumTime + (endTimeMeasure - startTimeMeasure) 128 | 129 | #Print the progress. 130 | if i%100 == 0: 131 | visualize_progress(i, numTestModels) 132 | 133 | #Compute the predicted class: the logits are summed over the whole batch and compared with the first label (presumably every batch entry is the same model - TODO confirm against DataSet.get_next_batch). 134 | predCat = np.argmax(np.sum(logitsRes, axis=0)) 135 | if predCat == labels[0]: 136 | totalAccuracy = totalAccuracy + 1.0 137 | 138 | i += 1 139 | 140 | #Print the results: accumulated time in ms divided by the number of test models, and the accuracy in percent. 141 | print("Time: %.8f" % (accumTime/(float(numTestModels)))) 142 | totalAccuracy = totalAccuracy/float(numTestModels) 143 | print("Test accuracy: %.4f" % (totalAccuracy*100.0)) 144 | -------------------------------------------------------------------------------- /ModelNet/ModelNetNormals.py: -------------------------------------------------------------------------------- 1 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 2 | \file ModelNetNormals.py 3 | 4 | \brief Code to train a normal estimation network on the ModelNet40 5 | dataset. 6 | 7 | \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 8 | Germany. See the LICENSE file at the top-level directory of 9 | this distribution.
10 | 11 | \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 12 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 13 | 14 | import sys 15 | import math 16 | import time 17 | import argparse 18 | import importlib 19 | import os 20 | import numpy as np 21 | import tensorflow as tf 22 | from tensorflow.python import debug as tf_debug 23 | 24 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 25 | ROOT_DIR = os.path.dirname(BASE_DIR) 26 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 27 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 28 | 29 | from PyUtils import visualize_progress 30 | from ModelNetDataSet import ModelNetDataSet 31 | 32 | current_milli_time = lambda: time.time() * 1000.0 33 | 34 | def create_angle(convResult, normals): 35 | normalized_conv = tf.nn.l2_normalize(convResult, axis=1) 36 | normalized_normals = tf.nn.l2_normalize(normals, axis=1) 37 | error = tf.multiply(normalized_conv, normalized_normals) 38 | error = tf.reduce_sum(error, 1) 39 | return tf.acos(tf.reduce_mean(error)) 40 | 41 | def create_loss(convResult, normals): 42 | normalized_normals = tf.nn.l2_normalize(normals, axis=1) 43 | normalized_conv = tf.nn.l2_normalize(convResult, axis=1) 44 | return tf.losses.cosine_distance(normalized_normals, normalized_conv, axis = 1) 45 | 46 | def create_trainning(lossGraph, learningRate, maxLearningRate, learningDecayFactor, learningRateDecay, global_step): 47 | learningRateExp = tf.train.exponential_decay(learningRate, global_step, learningRateDecay, learningDecayFactor, staircase=True) 48 | learningRateExp = tf.maximum(learningRateExp, maxLearningRate) 49 | optimizer = tf.train.AdamOptimizer(learning_rate =learningRateExp) 50 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 51 | with tf.control_dependencies(update_ops): 52 | train_op = optimizer.minimize(lossGraph, global_step=global_step) 53 | return train_op, learningRateExp 54 | 55 | 56 | if __name__ == '__main__': 57 | 58 | parser = 
argparse.ArgumentParser(description='Script to train MCCNN for normal estimation of point clouds (ModelNet40)') 59 | parser.add_argument('--logFolder', default='log', help='Folder of the output models (default: log)') 60 | parser.add_argument('--model', default='MCNorm', help='model (default: MCNorm)') 61 | parser.add_argument('--grow', default=32, type=int, help='Grow rate (default: 32)') 62 | parser.add_argument('--batchSize', default=32, type=int, help='Batch size (default: 32)') 63 | parser.add_argument('--maxEpoch', default=201, type=int, help='Max Epoch (default: 201)') 64 | parser.add_argument('--initLearningRate', default=0.005, type=float, help='Init learning rate (default: 0.005)') 65 | parser.add_argument('--learningDeacyFactor', default=0.5, type=float, help='Learning deacy factor (default: 0.5)') 66 | parser.add_argument('--learningDecayRate', default=20, type=int, help='Learning decay rate (default: 20 Epochs)') 67 | parser.add_argument('--maxLearningRate', default=0.00001, type=float, help='Maximum Learning rate (default: 0.00001)') 68 | parser.add_argument('--nPoints', default=1024, type=int, help='Number of points (default: 1024)') 69 | parser.add_argument('--augment', action='store_true', help='Augment data (default: False)') 70 | parser.add_argument('--nonunif', action='store_true', help='Train on non-uniform (default: False)') 71 | parser.add_argument('--gpu', default='0', help='GPU (default: 0)') 72 | parser.add_argument('--gpuMem', default=0.5, type=float, help='GPU memory used (default: 0.5)') 73 | args = parser.parse_args() 74 | 75 | #Create log folder. 76 | if not os.path.exists(args.logFolder): os.mkdir(args.logFolder) 77 | os.system('cp ../models/%s.py %s' % (args.model, args.logFolder)) 78 | os.system('cp ModelNetNormals.py %s' % (args.logFolder)) 79 | logFile = args.logFolder+"/log.txt" 80 | 81 | #Write execution info. 
82 | with open(logFile, "a") as myFile: 83 | myFile.write("Model: "+args.model+"\n") 84 | myFile.write("Grow: "+str(args.grow)+"\n") 85 | myFile.write("BatchSize: "+str(args.batchSize)+"\n") 86 | myFile.write("MaxEpoch: "+str(args.maxEpoch)+"\n") 87 | myFile.write("InitLearningRate: "+str(args.initLearningRate)+"\n") 88 | myFile.write("LearningDeacyFactor: "+str(args.learningDeacyFactor)+"\n") 89 | myFile.write("LearningDecayRate: "+str(args.learningDecayRate)+"\n") 90 | myFile.write("MaxLearningRate: "+str(args.maxLearningRate)+"\n") 91 | myFile.write("nPoints: "+str(args.nPoints)+"\n") 92 | myFile.write("Augment: "+str(args.augment)+"\n") 93 | myFile.write("Nonunif: "+str(args.nonunif)+"\n") 94 | 95 | print("Model: "+args.model) 96 | print("Grow: "+str(args.grow)) 97 | print("BatchSize: "+str(args.batchSize)) 98 | print("MaxEpoch: "+str(args.maxEpoch)) 99 | print("InitLearningRate: "+str(args.initLearningRate)) 100 | print("LearningDeacyFactor: "+str(args.learningDeacyFactor)) 101 | print("LearningDecayRate: "+str(args.learningDecayRate)) 102 | print("MaxLearningRate: "+str(args.maxLearningRate)) 103 | print("nPoints: "+str(args.nPoints)) 104 | print("Augment: "+str(args.augment)) 105 | print("Nonunif: "+str(args.nonunif)) 106 | 107 | #Load the model 108 | model = importlib.import_module(args.model) 109 | 110 | #Get train and test datasets 111 | allowedSamplingsTrain=[] 112 | allowedSamplingsTest=[] 113 | maxStoredPoints = args.nPoints 114 | if args.nonunif: 115 | maxStoredPoints = 5000 116 | allowedSamplingsTrain = [1, 2, 3, 4] 117 | allowedSamplingsTest = [0, 1, 2, 3, 4] 118 | else: 119 | allowedSamplingsTrain = [0] 120 | allowedSamplingsTest = [0] 121 | 122 | mTrainDataSet = ModelNetDataSet(True, args.nPoints, 1.0, maxStoredPoints, 123 | args.batchSize, allowedSamplingsTrain, args.augment, True) 124 | mTestDataSet = ModelNetDataSet(False, args.nPoints, 1.0, maxStoredPoints, 125 | 1, allowedSamplingsTest, False, True) 126 | 127 | numTrainModels = 
mTrainDataSet.get_num_models() 128 | numBatchesXEpoch = numTrainModels/args.batchSize 129 | if numTrainModels%args.batchSize != 0: 130 | numBatchesXEpoch = numBatchesXEpoch + 1 131 | numTestModels = mTestDataSet.get_num_models() 132 | print("Train models: " + str(numTrainModels)) 133 | print("Test models: " + str(numTestModels)) 134 | 135 | #Create variable and place holders 136 | global_step = tf.Variable(0, name='global_step', trainable=False) 137 | inPts = tf.placeholder(tf.float32, [None, 3]) 138 | inBatchIds = tf.placeholder(tf.int32, [None, 1]) 139 | inFeatures = tf.placeholder(tf.float32, [None, 1]) 140 | inNormals = tf.placeholder(tf.float32, [None, 3]) 141 | isTraining = tf.placeholder(tf.bool) 142 | accuracyAngle = tf.placeholder(tf.float32) 143 | accuracyTestAngle = tf.placeholder(tf.float32) 144 | 145 | #Create the network 146 | predNormals = model.create_network(inPts, inBatchIds, inFeatures, 1, args.batchSize, args.grow, isTraining) 147 | 148 | #Create loss 149 | loss = create_loss(predNormals, inNormals) 150 | 151 | #Create angle 152 | angle = create_angle(predNormals, inNormals) 153 | 154 | #Create training 155 | trainning, learningRateExp = create_trainning(loss, 156 | args.initLearningRate, args.maxLearningRate, args.learningDeacyFactor, 157 | args.learningDecayRate*numBatchesXEpoch, global_step) 158 | learningRateSumm = tf.summary.scalar('learninRate', learningRateExp) 159 | 160 | #Create sumaries 161 | lossSummary = tf.summary.scalar('loss', loss) 162 | trainingSummary = tf.summary.merge([lossSummary, learningRateSumm]) 163 | metricsSummary = tf.summary.scalar('accuracy', accuracyAngle) 164 | metricsTestSummary = tf.summary.scalar('Tes_Accuracy', accuracyTestAngle) 165 | 166 | #Create init variables 167 | init = tf.global_variables_initializer() 168 | initLocal = tf.local_variables_initializer() 169 | 170 | #create the saver 171 | saver = tf.train.Saver() 172 | 173 | #Create session 174 | gpu_options = 
tf.GPUOptions(per_process_gpu_memory_fraction=args.gpuMem, visible_device_list=args.gpu) 175 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 176 | 177 | #Create the summary writer 178 | summary_writer = tf.summary.FileWriter(args.logFolder, sess.graph) 179 | summary_writer.add_graph(sess.graph) 180 | 181 | #Init variables 182 | sess.run(init) 183 | sess.run(initLocal) 184 | step = 0 185 | epochStep = 0 186 | np.random.seed(int(time.time())) 187 | 188 | #Train 189 | for epoch in range(args.maxEpoch): 190 | 191 | startEpochTime = current_milli_time() 192 | startTrainTime = current_milli_time() 193 | 194 | epochStep = 0 195 | lossInfoCounter = 0 196 | lossAccumValue = 0.0 197 | angleAccumValue = 0.0 198 | 199 | #Iterate over all the train files 200 | mTrainDataSet.start_iteration() 201 | while mTrainDataSet.has_more_batches(): 202 | 203 | _, points, batchIds, features, normals, _, _ = mTrainDataSet.get_next_batch() 204 | 205 | _, lossRes, angleRes, trainingSummRes = \ 206 | sess.run([trainning, loss, angle, trainingSummary], 207 | {inPts: points, inBatchIds: batchIds, inFeatures: features, inNormals: normals, isTraining: True}) 208 | 209 | summary_writer.add_summary(trainingSummRes, step) 210 | 211 | angleAccumValue += angleRes 212 | lossAccumValue += lossRes 213 | lossInfoCounter += 1 214 | 215 | if lossInfoCounter == 10: 216 | currAngle = math.degrees(angleAccumValue/10.0) 217 | endTrainTime = current_milli_time() 218 | metricsSummRes = sess.run(metricsSummary, {accuracyAngle: currAngle}) 219 | summary_writer.add_summary(metricsSummRes, step) 220 | 221 | visualize_progress(epochStep, numBatchesXEpoch, "Loss: %.6f | Angle: %.4f | Time: %.4f" % ( 222 | lossAccumValue/10.0, currAngle, (endTrainTime-startTrainTime)/1000.0)) 223 | 224 | with open(logFile, "a") as myfile: 225 | myfile.write("Step: %6d (%4d) | Loss: %.6f | Angle: %.4f\n" % (step, epochStep, lossAccumValue/10.0, currAngle)) 226 | 227 | lossInfoCounter = 0 228 | lossAccumValue = 0.0 229 | 
angleAccumValue = 0.0 230 | startTrainTime = current_milli_time() 231 | 232 | step += 1 233 | epochStep += 1 234 | 235 | endEpochTime = current_milli_time() 236 | print("Epoch %3d train time: %.4f" %(epoch, (endEpochTime-startEpochTime)/1000.0)) 237 | with open(logFile, "a") as myfile: 238 | myfile.write("Epoch %3d train time: %.4f \n" %(epoch, (endEpochTime-startEpochTime)/1000.0)) 239 | 240 | if epoch%10==0: 241 | saver.save(sess, args.logFolder+"/model.ckpt") 242 | 243 | #Test data 244 | accumTestLoss = 0.0 245 | accumAngleTest = 0.0 246 | it = 0 247 | mTestDataSet.start_iteration() 248 | while mTestDataSet.has_more_batches(): 249 | _, points, batchIds, features, normals, _, _ = mTestDataSet.get_next_batch() 250 | 251 | lossRes, angleRes = sess.run([loss, angle], 252 | {inPts: points, inBatchIds: batchIds, inFeatures: features, inNormals: normals, isTraining: False}) 253 | 254 | accumTestLoss +=lossRes 255 | accumAngleTest += angleRes 256 | 257 | if it%100 == 0: 258 | visualize_progress(it, numTestModels) 259 | 260 | it += 1 261 | 262 | accumTestLoss = accumTestLoss/float(numTestModels) 263 | currTestAngle = math.degrees(accumAngleTest/float(numTestModels)) 264 | metricsTestSummRes = sess.run(metricsTestSummary, {accuracyTestAngle: currTestAngle}) 265 | summary_writer.add_summary(metricsTestSummRes, step) 266 | 267 | print("Loss: %.6f | Test accuracy: %.4f" % (accumTestLoss, currTestAngle)) 268 | with open(logFile, "a") as myfile: 269 | myfile.write("Loss: %.6f | Test accuracy: %.4f \n" % (accumTestLoss, currTestAngle)) 270 | -------------------------------------------------------------------------------- /ModelNet/ModelNetNormalsEval.py: -------------------------------------------------------------------------------- 1 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 2 | \file ModelNetNormalsEval.py 3 | 4 | \brief Code to evaluate a normal estimation network on the ModelNet40 5 | dataset. 
6 | 7 | \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 8 | Germany. See the LICENSE file at the top-level directory of 9 | this distribution. 10 | 11 | \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 12 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 13 | 14 | 15 | import sys 16 | import math 17 | import time 18 | import argparse 19 | import importlib 20 | import os 21 | import numpy as np 22 | import tensorflow as tf 23 | from tensorflow.python import debug as tf_debug 24 | 25 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 26 | ROOT_DIR = os.path.dirname(BASE_DIR) 27 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 28 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 29 | 30 | from PyUtils import visualize_progress 31 | from ModelNetDataSet import ModelNetDataSet 32 | 33 | current_milli_time = lambda: time.time() * 1000.0 34 | 35 | def create_angle(convResult, normals): 36 | normalized_conv = tf.nn.l2_normalize(convResult, axis=1) 37 | normalized_normals = tf.nn.l2_normalize(normals, axis=1) 38 | error = tf.multiply(normalized_conv, normalized_normals) 39 | error = tf.reduce_sum(error, 1) 40 | return tf.acos(tf.reduce_mean(error)) 41 | 42 | def create_loss(convResult, normals): 43 | normalized_normals = tf.nn.l2_normalize(normals, axis=1) 44 | normalized_conv = tf.nn.l2_normalize(convResult, axis=1) 45 | return tf.losses.cosine_distance(normalized_normals, normalized_conv, axis = 1) 46 | 47 | 48 | if __name__ == '__main__': 49 | 50 | parser = argparse.ArgumentParser(description='Evaluation of normal estimation networks (ModelNet40)') 51 | parser.add_argument('--inTrainedModel', default='log/model.ckpt', help='Input trained model (default: log/model.ckpt)') 52 | parser.add_argument('--model', default='MCNorm', help='model (default: MCNorm)') 53 | parser.add_argument('--grow', default=32, type=int, help='Grow rate (default: 32)') 54 | parser.add_argument('--nPoints', default=1024, 
type=int, help='Number of points (default: 1024)') 55 | parser.add_argument('--nExec', default=1, type=int, help='Number of executions per model (default: 1)') 56 | parser.add_argument('--gpu', default='0', help='GPU (default: 0)') 57 | parser.add_argument('--gpuMem', default=0.5, type=float, help='GPU memory used (default: 0.5)') 58 | args = parser.parse_args() 59 | 60 | print("Trained model: "+args.inTrainedModel) 61 | print("Model: "+args.model) 62 | print("Grow: "+str(args.grow)) 63 | print("nPoints: "+str(args.nPoints)) 64 | print("nExec: "+str(args.nExec)) 65 | 66 | #Load the model 67 | model = importlib.import_module(args.model) 68 | 69 | #Get train and test files 70 | mTestDataSet = ModelNetDataSet(False, args.nPoints, 1.0, 5000, 71 | args.nExec, [0], False, True) 72 | numTestModels = mTestDataSet.get_num_models() 73 | print("Test models: " + str(numTestModels)) 74 | 75 | #Create variable and place holders 76 | inPts = tf.placeholder(tf.float32, [None, 3]) 77 | inBatchIds = tf.placeholder(tf.int32, [None, 1]) 78 | inFeatures = tf.placeholder(tf.float32, [None, 1]) 79 | inNormals = tf.placeholder(tf.float32, [None, 3]) 80 | isTraining = tf.placeholder(tf.bool) 81 | 82 | #Create the network 83 | predNormals = model.create_network(inPts, inBatchIds, inFeatures, 1, args.nExec, args.grow, isTraining) 84 | 85 | #Create loss 86 | loss = create_loss(predNormals, inNormals) 87 | 88 | #Create angle 89 | angle = create_angle(predNormals, inNormals) 90 | 91 | #Create init variables 92 | init = tf.global_variables_initializer() 93 | initLocal = tf.local_variables_initializer() 94 | 95 | #create the saver 96 | saver = tf.train.Saver() 97 | 98 | #Create session 99 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpuMem, visible_device_list=args.gpu) 100 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 101 | 102 | #Init variables 103 | sess.run(init) 104 | sess.run(initLocal) 105 | 106 | #Restore the model 107 | saver.restore(sess, 
args.inTrainedModel) 108 | 109 | #Test the dataset. 110 | titleTexts = [ 111 | "Uniform sampling", 112 | "Non-uniform split", 113 | "Non-uniform gradient", 114 | "Non-uniform lambert", 115 | "Non-uniform occlusion"] 116 | for samp in range(5): 117 | #Print the type of sampling used. 118 | print(titleTexts[samp]) 119 | #Update the dataset. 120 | allowedSamplingsTest = [samp] 121 | mTestDataSet.set_allowed_samplings(allowedSamplingsTest) 122 | mTestDataSet.start_iteration() 123 | #Create the auxiliar variables. 124 | i = 0 125 | accumTime = 0.0 126 | totalLoss = 0.0 127 | totalAngle = 0.0 128 | #Iterate over the models. 129 | while mTestDataSet.has_more_batches(): 130 | #Get the batch dataset. 131 | _, points, batchIds, features, normals, _, _ = mTestDataSet.get_next_batch(True) 132 | 133 | #Compute the loss. 134 | startTimeMeasure = current_milli_time() 135 | lossRes, angleRes = sess.run([loss, angle], 136 | {inPts: points, inBatchIds: batchIds, inFeatures: features, inNormals: normals, isTraining: False}) 137 | endTimeMeasure = current_milli_time() 138 | accumTime = accumTime + (endTimeMeasure - startTimeMeasure) 139 | totalLoss += lossRes 140 | totalAngle += angleRes 141 | 142 | #Print the progress. 143 | if i%100 == 0: 144 | visualize_progress(i, numTestModels) 145 | 146 | i += 1 147 | 148 | #Print the results. 149 | print("Time: %.8f" % (accumTime/(float(numTestModels)*float(args.nExec)))) 150 | print("Test loss: %.4f | Test angle: %.4f" % (totalLoss/float(numTestModels*args.nExec), 151 | math.degrees((totalAngle/float(numTestModels*args.nExec))))) 152 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### MCCNN: *Monte Carlo Convolution for Learning on Non-Uniformly Sampled Point Clouds* 2 | Created by Pedro Hermosilla, Tobias Ritschel, Pere-Pau Vazquez, Alvar Vinacua, Timo Ropinski. 
3 | 4 | ![teaser](https://github.com/viscom-ulm/MCCNN/blob/master/teaser/Teaser.png) 5 | 6 | ### Citation 7 | If you find this code useful, please consider citing us: 8 | 9 | @article{hermosilla2018mccnn, 10 | title={Monte Carlo Convolution for Learning on Non-Uniformly Sampled Point Clouds}, 11 | author={Hermosilla, P. and Ritschel, T. and Vazquez, P-P and Vinacua, A. and Ropinski, T.}, 12 | journal={ACM Transactions on Graphics (Proceedings of SIGGRAPH Asia 2018)}, 13 | volume={37}, 14 | number={6}, 15 | year={2018}, 16 | doi={10.1145/3272127.3275110} 17 | } 18 | 19 | ### Introduction 20 | Deep learning systems extensively use convolution operations to process input data. Though convolution is clearly defined for structured data such as 2D images or 3D volumes, this is not true for other data types such as sparse point clouds. Previous techniques have developed approximations to convolutions for restricted conditions. Unfortunately, their applicability is limited and cannot be used for general point clouds. We propose an efficient and effective method to learn convolutions for non-uniformly sampled point clouds, as they are obtained with modern acquisition techniques. Learning is enabled by four key novelties: first, representing the convolution kernel itself as a multilayer perceptron; second, phrasing convolution as a Monte Carlo integration problem; third, using this notion to combine information from multiple samplings at different levels; and fourth, using Poisson disk sampling as a scalable means of hierarchical point cloud learning. The key idea across all these contributions is to guarantee adequate consideration of the underlying non-uniform sample distribution function from a Monte Carlo perspective. To make the proposed concepts applicable to real-world tasks, we furthermore propose an efficient implementation which significantly reduces the GPU memory required during the training process.
By employing our method in hierarchical network architectures, we can outperform most of the state-of-the-art networks on established point cloud segmentation, classification and normal estimation benchmarks. Furthermore, in contrast to most existing approaches, we also demonstrate the robustness of our method with respect to sampling variations, even when training with uniformly sampled data only. 21 | 22 | In this repository, we release the code of our tensor operations and network architectures for classification, segmentation, and normal estimation tasks, which realize the ideas presented in our paper. For further details of the techniques implemented here, you can refer to the paper or visit our project page. 23 | 24 | ### Installation 25 | First, install TensorFlow. The code presented here was developed using TensorFlow v1.5 GPU version, Python 2.7, and Ubuntu 16.04 LTS. However, it should also work with TensorFlow v1.8 GPU version and Python 3. All the operations were implemented on the GPU; no CPU implementation is provided. Therefore, a workstation with a state-of-the-art GPU is required. 26 | 27 | #### Compiling tensorflow operations 28 | In order to train the networks provided in this repository, we first have to compile the new tensor operations which implement the Monte Carlo convolutions. These operations are located in the folder `tf_ops`. To compile them, execute the following commands: 29 | 30 | cd tf_ops 31 | python genCompileScript.py --cudaFolder *path_to_cuda* 32 | sh compile.sh 33 | 34 | The python script `genCompileScript.py` also provides two more options: `MLPSize`, which defines the size of the MLPs used to approximate the convolution kernels; and `debugInfo`, to print debug information during the execution of the layers. 35 | 36 | For newer versions of TensorFlow, the generated shared library may produce a link error.
In these cases, replace `-D_GLIBCXX_USE_CXX11_ABI=0` with `-D_GLIBCXX_USE_CXX11_ABI=1` in the compile.sh script. 37 | 38 | ### Tasks 39 | The network architectures used for the different tasks can be found in the folder `models`, whilst the scripts to train and evaluate such networks can be found in the folders of the different datasets. See our paper for more details on the network architectures. All scripts can be executed with the argument `--help` to visualize the different options. Some arguments used in our scripts are: 40 | 41 | * **grow:** Determines the growth rate of the number of features in the networks. All layers of the networks produce a number of features which is a multiple of this number. 42 | * **useDropOut:** If this argument is provided, dropout layers are used in the final MLPs in classification and segmentation networks. 43 | * **dropOutKeepProb:** If useDropOut is provided, this argument determines the probability to keep the value of a neuron in the MLPs. 44 | * **useDropOutConv:** If this argument is provided, dropout layers are used before each convolution layer in classification and segmentation networks. 45 | * **dropOutKeepProbConv:** If useDropOutConv is provided, this argument determines the probability to keep the value of a feature before each convolution layer. 46 | * **weightDecay:** Scale used in the L2 regularization. If 0.0 is provided, no L2 regularization is performed. 47 | * **ptDropOut:** Probability of selecting a point during loading of the models in the training phase. 48 | 49 | #### Classification 50 | We provide 3 different networks for classification tasks (MCClassS, MCClass, and MCClassH), which have been tested on the ModelNet40 dataset. We used the resampled ModelNet40 dataset provided in PointNet++, which contains XYZ position and normals for 10k points per model. Once downloaded, uncompress the zip file inside the folder ModelNet and rename the folder to `data`.
Then, you can train and evaluate the different networks. 51 | 52 | **MCClassS** This network is composed of only 3 pooling Monte Carlo convolutions. This is the default model used in the classification script, and it can be trained and evaluated using the following commands: 53 | 54 | python ModelNet.py --grow 128 --useDropOut 55 | python ModelNetEval.py --grow 128 56 | 57 | **MCClass** This network is composed of a set of Monte Carlo convolutions on the different levels of the point hierarchy. It can be trained and evaluated using the following commands: 58 | 59 | python ModelNet.py --model MCClass --useDropOut --useDropOutConv --weightDecay 0.00001 60 | python ModelNetEval.py --model MCClass 61 | 62 | **MCClassH** This network is composed of two different paths which process the different levels of the point hierarchy independently. Whilst one path works directly on the initial point set, the second path works on a lower level of the point hierarchy. It is trained by activating and deactivating these paths in order to be more robust to non-uniformly sampled point clouds. The results obtained by using this network were the ones reported in our paper. It can be trained and evaluated using the following commands: 63 | 64 | python ModelNet.py --model MCClassH --useDropOut --useDropOutConv 65 | python ModelNetEval.py --model MCClassH 66 | 67 | #### Segmentation 68 | We provide a network for segmentation tasks (MCSeg), which has been tested on the ShapeNet dataset. We used the resampled ShapeNet dataset provided in PointNet++, which contains XYZ position, normals and part label for each point. Once downloaded, uncompress the zip file inside the folder ShapeNet and rename the folder to `shape_data`. Then, you can train and evaluate the networks. 69 | 70 | **MCSeg** This network has an encoder-decoder architecture, in which the decoder part upsamples features from different levels at the same time.
It can be trained and evaluated using the following commands: 71 | 72 | python ShapeNet.py --useDropOut 73 | python ShapeNetEval.py 74 | 75 | The results reported in our paper were obtained by training the network with the following command. However, this configuration requires more GPU memory. 76 | 77 | python ShapeNet.py --grow 128 --useDropOut --useDropOutConv 78 | 79 | #### Normal Estimation 80 | We provide 2 networks for normal estimation tasks (MCNorm and MCNormS), which have been tested on the ModelNet40 dataset. We used the same resampled dataset as the one used in the classification task. Follow the instructions provided in the classification section to download the data. Then, you can train and evaluate the different networks. 81 | 82 | **MCNorm:** Network with an encoder-decoder architecture which outputs 3 floats per point. It can be trained and evaluated using the following commands: 83 | 84 | python ModelNetNormals.py 85 | python ModelNetNormalsEval.py 86 | 87 | **MCNormS:** Small network composed of only two consecutive Monte Carlo convolutions which output 3 floats per point. This network was designed to evaluate the performance of our convolutions without considering the point hierarchy used. It can be trained and evaluated using the following commands: 88 | 89 | python ModelNetNormals.py --model MCNormS 90 | python ModelNetNormalsEval.py --model MCNormS 91 | 92 | #### Real world dataset 93 | We provide a network architecture used to perform semantic segmentation on real scans. We tested our network on the ScanNet dataset (version 1). In order to download the dataset, please contact the ScanNet authors, who will provide a script for the download. Download the decimated mesh and the task file as well.
Then, we will process the data by using the following command: 94 | 95 | python genScanNetData.py --inFolder data --outFolder data_mccnn 96 | 97 | The script appends the ScanNet version to the output folder name (e.g. `data_mccnn_v1`), whereas `ScanNetDataSet` reads from `data_mccnn` by default, so rename the generated folder accordingly. Then, we will be able to train and evaluate our network with the commands: 98 | 99 | python ScanNet.py 100 | python ScanNetEval.py --inTrainedModel log/model.ckpt-# (number of the best performing epoch) 101 | 102 | In order to reproduce the results obtained in our paper, execute the following command (it requires more GPU memory): 103 | 104 | python ScanNet.py --grow 64 105 | 106 | #### Custom Architectures 107 | We provide a builder module in the folder `utils` which allows the creation of spatial convolutions by using a set of simple interfaces, hiding the tedious generation of auxiliary data structures. In order to create our first spatial convolution, we first have to create our point hierarchy with the command: 108 | 109 | mPointHierarchy = PointHierarchy( 110 | inPoints=points, 111 | inFeatures=features, 112 | inBatchIds=batchIds, 113 | radiusList=[0.1, 0.4], 114 | hierarchyName="MyPointHierarchy", 115 | batchSize=batchSize, 116 | relativeRadius=True) 117 | 118 | Then, we can create our convolution builder with the command: 119 | 120 | mConvBuilder = ConvolutionBuilder( 121 | multiFeatureConvs=False, 122 | KDEWindow=0.25, 123 | relativeRadius=True, 124 | usePDF=True) 125 | 126 | Lastly, we can create a spatial convolution on the first level of the hierarchy with: 127 | 128 | convFeatures1 = mConvBuilder.create_convolution( 129 | convName="Conv_1", 130 | inPointHierarchy=mPointHierarchy, 131 | inPointLevel=0, 132 | inFeatures=features, 133 | inNumFeatures=1, 134 | outNumFeatures=16, 135 | convRadius= 0.1, 136 | multiFeatureConv=True) 137 | 138 | And transfer these features to the next level by using another spatial convolution: 139 | 140 | poolFeatures1 = mConvBuilder.create_convolution( 141 | convName="Pool_1", 142 | inPointHierarchy=mPointHierarchy, 143 | inPointLevel=0, 144 | outPointLevel=1, 145 |
inFeatures=convFeatures1, 146 | inNumFeatures=16, 147 | convRadius=0.2, 148 | KDEWindow= 0.2) 149 | 150 | We recommend first taking a look at the network `MCNormS.py`, since it is only composed of two convolutions on a point hierarchy of only one level. Then, for a better understanding of how to create a deeper architecture, one can refer to the network `MCClass.py`. 151 | 152 | #### Custom Datasets 153 | We also provide an interface to create your own dataset in the folder `utils` (`DataSet.py`). This class implements the different sampling protocols described in our paper and other features such as model caching. In order to take advantage of these features, your dataset class should inherit from this one, define the files containing your models, and implement the loader function. The files `ModelNetDataSet.py` and `ShapeNetDataSet.py` provide examples of datasets used for classification and segmentation tasks. The `ScanNetDataSet.py` file implements a dataset composed of models with variable sizes and variable bounding boxes. 154 | 155 | If you have any doubts about the usage of the operations or scripts, do not hesitate to contact us. 156 | 157 | #### Updates 158 | 159 | * 25/09/2018: Added scripts for training on ScanNet and builders to create point hierarchies and convolutions in a more intuitive way. 160 | 161 | ### License 162 | Our code is released under the MIT License (see LICENSE file for details). -------------------------------------------------------------------------------- /ScanNet/ScanNetDataSet.py: -------------------------------------------------------------------------------- 1 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 2 | \file ScanNetDataSet.py 3 | 4 | \brief ScanNet dataset class. 5 | 6 | \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 7 | Germany. See the LICENSE file at the top-level directory of 8 | this distribution.
9 | 10 | \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 11 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 12 | 13 | import sys 14 | import os 15 | import math 16 | import time 17 | import json 18 | import numpy as np 19 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 20 | ROOT_DIR = os.path.dirname(BASE_DIR) 21 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 22 | from DataSet import DataSet 23 | 24 | class ScanNetDataSet(DataSet): 25 | """ScanNet dataset. 26 | 27 | Attributes: 28 | useColorsAsFeatures_ (bool): Boolean that indicates if the colors will be used as the input features. 29 | dataFolder_ (string): Path of the folder with the data. 30 | weights_ (array of floats): List of weights for each label or category. 31 | """ 32 | 33 | def __init__(self, dataset, batchSize, ptDropOut, maxNumPtsxBatch=600000, 34 | augment=False, useColorsAsFeatures=False, dataFolder="data_mccnn", seed=None): 35 | """Constructor. 36 | 37 | Args: 38 | dataset (int): Index of the dataset that will be used. 0 - train, 1 - val, 2 - test 39 | batchSize (int): Size of the batch used. 40 | ptDropOut (float): Probability to keep a point during uniform sampling when all the points 41 | or only the first n number of points are selected. 42 | augment (bool): Boolean that indicates if data augmentation will be used in the models. 43 | useColorsAsFeatures (bool): Boolean that indicates if the colors will be used as the input features. 44 | dataFolder (string): Path of the folder with the data. 45 | seed (int): Seed used to initialize the random number generator. If None is provided instead, the current 46 | time on the machine will be used to initialize the number generator. 47 | """ 48 | 49 | # Check if all the models will fit in the batch. 
50 | if (maxNumPtsxBatch < 600000) and (maxNumPtsxBatch > 0): 51 | raise RuntimeError('The number of points per batch should be big enough to store'\ 52 | ' all the models (greater than 600000).') 53 | 54 | # Store the parameters of the class. 55 | self.useColorsAsFeatures_ = useColorsAsFeatures 56 | self.dataFolder_ = dataFolder 57 | 58 | # Call the constructor of the parent class. 59 | super(ScanNetDataSet,self).__init__(0, ptDropOut, 60 | useColorsAsFeatures, True, False, 61 | False, False, batchSize, [0], 0, maxNumPtsxBatch, 62 | augment, 2, False, False, [], [], seed) 63 | 64 | # Load the room list. 65 | rooms = np.loadtxt(dataFolder+"/rooms.txt", dtype='str') 66 | # Load the number of points per rooms. 67 | roomsNummPoints = np.loadtxt(dataFolder+"/num_points.txt")*ptDropOut 68 | # Room per dataset. 69 | trainRooms = set(np.loadtxt(dataFolder+"/scannet_train.txt", dtype='str')) 70 | valRooms = set(np.loadtxt(dataFolder+"/scannet_val.txt", dtype='str')) 71 | testRooms = set(np.loadtxt(dataFolder+"/scannet_test.txt", dtype='str')) 72 | # Determine the indexs of the rooms for each dataset. 73 | roomIndexs = np.array([]) 74 | if dataset == 0: 75 | roomIndexs = np.array([i for i in range(len(rooms)) if rooms[i] in trainRooms]) 76 | elif dataset == 1: 77 | roomIndexs = np.array([i for i in range(len(rooms)) if rooms[i] in valRooms]) 78 | else: 79 | roomIndexs = np.array([i for i in range(len(rooms)) if rooms[i] in testRooms]) 80 | 81 | # Store the file rooms of the dataset with the number of points. 82 | self.fileList_ = rooms[roomIndexs]#["scene0497_00"] 83 | self.numPts_ = roomsNummPoints[roomIndexs]#[590000] 84 | 85 | # Load the labels identifiers and the weights. 
86 | self.semLabels_ = np.loadtxt(dataFolder+"/labels.txt", dtype='str', delimiter=':') 87 | weights = np.loadtxt(dataFolder+"/weights.txt") 88 | for i in range(len(self.semLabels_)): 89 | weights[0][i] = 1.0/np.log(1.2 + weights[0][i]) 90 | weights[1][i] = 1.0/np.log(1.2 + weights[1][i]) 91 | self.weights_ = weights[0] 92 | self.weights_[0] = 0.0 93 | 94 | 95 | def get_labels(self): 96 | """Method to get the list of labels. 97 | 98 | Returns: 99 | pts (n np.array string): List of labels. 100 | """ 101 | return self.semLabels_ 102 | 103 | 104 | def get_weights(self, labels): 105 | """Method to get the weights associated to the labels. 106 | 107 | Args: 108 | catlabs (nxm np.array): Labels for which we want the weights. 109 | Returns: 110 | weights (nxm): Weights associated with the input labels. 111 | """ 112 | 113 | if len(self.weights_) == 0: 114 | raise RuntimeError('No weights associated to the labels.') 115 | 116 | outWeights = np.array([[self.weights_[currLab[0]]] for currLab in labels]) 117 | return outWeights 118 | 119 | 120 | def get_accuracy_masks(self, labels): 121 | """Method to get the list of mask for each label to compute the accuracy. 122 | 123 | Args: 124 | labels (np.array): Labels for which we want the weights. 125 | Returns: 126 | masks (np.array): List of mask for each label to compute 127 | the accuracy. 128 | """ 129 | outMasks = np.array([[1.0] if lab[0] != 0 else [0.0] for lab in labels]) 130 | return outMasks 131 | 132 | 133 | def _load_model_from_disk_(self, modelPath): 134 | """Abstract method that should be implemented by child class which loads a model 135 | from disk. 136 | 137 | Args: 138 | modelPath (string): Path to the model that needs to be loaded. 139 | 140 | Returns: 141 | pts (nx3 np.array): List of points. 142 | normals (nx3 np.array): List of normals. If the dataset does not contain 143 | normals, None should be returned. 144 | features (nxm np.array): List of features. 
If the dataset does not contain 145 | features, None should be returned. 146 | labels (nxl np.array): List of labels. If the dataset does not contain 147 | labels, None should be returned. 148 | """ 149 | 150 | pts = np.load(self.dataFolder_+"/"+modelPath+"_pos.npy") 151 | labels = np.load(self.dataFolder_+"/"+modelPath+"_labels.npy") 152 | normals = None 153 | features = None 154 | if self.useColorsAsFeatures_: 155 | features = np.load(self.dataFolder_+"/"+modelPath+"_colors.npy") 156 | centroid = np.mean(pts, axis= 0) 157 | pts = pts - centroid 158 | 159 | return pts, normals, features, labels.reshape((-1,1)) -------------------------------------------------------------------------------- /ScanNet/ScanNetEval.py: -------------------------------------------------------------------------------- 1 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 2 | \file ShapeNet.py 3 | 4 | \brief Code to evaluate a segmentation network on the ScanNet dataset. 5 | 6 | \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 7 | Germany. See the LICENSE file at the top-level directory of 8 | this distribution. 
9 | 10 | \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 11 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 12 | 13 | import sys 14 | import math 15 | import time 16 | import argparse 17 | import importlib 18 | import os 19 | import numpy as np 20 | import tensorflow as tf 21 | 22 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 23 | ROOT_DIR = os.path.dirname(BASE_DIR) 24 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 25 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 26 | 27 | from PyUtils import visualize_progress, save_model 28 | from ScanNetDataSet import ScanNetDataSet 29 | 30 | 31 | current_milli_time = lambda: time.time() * 1000.0 32 | 33 | 34 | def create_loss(logits, labels, labelWeights, weigthDecay): 35 | labels = tf.to_int64(labels) 36 | cross_entropy = tf.losses.sparse_softmax_cross_entropy( 37 | labels=tf.reshape(labels, [-1]), logits=logits, 38 | weights=tf.reshape(labelWeights, [-1]), scope='xentropy') 39 | xentropyloss = tf.reduce_mean(cross_entropy, name='xentropy_mean') 40 | regularizer = tf.contrib.layers.l2_regularizer(scale=weigthDecay) 41 | regVariables = tf.get_collection('weight_decay_loss') 42 | regTerm = tf.contrib.layers.apply_regularization(regularizer, regVariables) 43 | return xentropyloss, regTerm 44 | 45 | 46 | def create_accuracy(logits, labels, inAccWeights, scope): 47 | _, logitsIndexs = tf.nn.top_k(logits) 48 | with tf.variable_scope(scope): 49 | return tf.metrics.accuracy(labels, logitsIndexs, weights=inAccWeights) 50 | 51 | 52 | if __name__ == '__main__': 53 | 54 | parser = argparse.ArgumentParser(description='Script to train MCCNN for segmentation tasks (S3DIS)') 55 | parser.add_argument('--inTrainedModel', default='log/model.ckpt', help='Input trained model (default: log/model.ckpt)') 56 | parser.add_argument('--model', default='MCSegScanNet', help='model (default: MCSegScanNet)') 57 | parser.add_argument('--grow', default=32, type=int, help='Grow rate 
(default: 32)') 58 | parser.add_argument('--nExec', default=1, type=int, help='Number of executions per model (default: 1)') 59 | parser.add_argument('--gpu', default='0', help='GPU (default: 0)') 60 | parser.add_argument('--gpuMem', default=1.0, type=float, help='GPU memory used (default: 1.0)') 61 | parser.add_argument('--useColor', action='store_true', help='Augment data (default: False)') 62 | parser.add_argument('--saveModels', action='store_true', help='Save models (default: False)') 63 | args = parser.parse_args() 64 | 65 | print("Trained Model: "+args.inTrainedModel) 66 | print("Model: "+args.model) 67 | print("Grow: "+str(args.grow)) 68 | print("Num executions: "+str(args.nExec)) 69 | print("Use color: "+str(args.useColor)) 70 | 71 | objColors = [ [0,0,0], # Unannotated 72 | [174,198,232], # Wall 73 | [151,223,137], # Floor 74 | [187,188,34], # Chair 75 | [254,151,150], # Table 76 | [247,183,210], # Desk 77 | [255,188,120], # Bed 78 | [148,103,188], # Bookshelf 79 | [140,86,74], # Sofa 80 | [112,128,144], # Sink 81 | [226,118,193], # Bathtub 82 | [42,159,44], # Toilet 83 | [218,219,141], # Curtain 84 | [23,190,208], # Counter 85 | [213,39,40], # Door 86 | [196,176,213], # Window 87 | [158,218,229], # Shower curtain 88 | [254,127,14], # Refrigerator 89 | [196,156,148], # Picture 90 | [31,120,180], # Cabinet 91 | [82,83,163]] # Other furniture 92 | 93 | #Save models, create folder 94 | if args.saveModels: 95 | if not os.path.exists("savedModels"): os.mkdir("savedModels") 96 | 97 | #Load the model 98 | model = importlib.import_module(args.model) 99 | 100 | #Get train and test files 101 | mTestDataSet = ScanNetDataSet(2, 1, 1.0, 0, False, args.useColor) 102 | semLabels = mTestDataSet.get_labels() 103 | print(semLabels) 104 | numTestRooms = mTestDataSet.get_num_models() 105 | print("Num test rooms: " + str(numTestRooms)) 106 | 107 | #Create variable and place holders 108 | inPts = tf.placeholder(tf.float32, [None, 3]) 109 | inBatchIds = tf.placeholder(tf.int32, 
[None, 1]) 110 | if args.useColor: 111 | inFeatures = tf.placeholder(tf.float32, [None, 4]) 112 | else: 113 | inFeatures = tf.placeholder(tf.float32, [None, 1]) 114 | inLabels = tf.placeholder(tf.int32, [None, 1]) 115 | inWeights = tf.placeholder(tf.float32, [None, 1]) 116 | inAccWeights = tf.placeholder(tf.float32, [None, 1]) 117 | isTraining = tf.placeholder(tf.bool) 118 | keepProbConv = tf.placeholder(tf.float32) 119 | keepProbFull = tf.placeholder(tf.float32) 120 | 121 | #Create the network 122 | numInputs = 1 123 | if args.useColor: 124 | numInputs = 4 125 | logits = model.create_network(inPts, inBatchIds, inFeatures, numInputs, len(semLabels), 1, 126 | args.grow, isTraining, keepProbConv, keepProbFull, False, False) 127 | 128 | #Create predict labels 129 | predictedLabels = tf.argmax(logits, 1) 130 | 131 | #Create loss 132 | xentropyLoss, regularizationLoss = create_loss(logits, inLabels, inWeights, 0.0) 133 | loss = xentropyLoss + regularizationLoss 134 | 135 | #Create accuracy metric 136 | accuracyVal, accuracyAccumOp = create_accuracy(logits, inLabels, inAccWeights, 'metrics') 137 | metricsVars = tf.contrib.framework.get_variables('metrics', collection=tf.GraphKeys.LOCAL_VARIABLES) 138 | resetMetrics = tf.variables_initializer(metricsVars) 139 | 140 | #Create init variables 141 | init = tf.global_variables_initializer() 142 | initLocal = tf.local_variables_initializer() 143 | 144 | #Create session 145 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpuMem, visible_device_list=args.gpu) 146 | sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) 147 | 148 | #create the saver 149 | saver = tf.train.Saver() 150 | 151 | #Init variables 152 | sess.run(init) 153 | sess.run(initLocal) 154 | np.random.seed(int(time.time())) 155 | 156 | #Restore the model 157 | saver.restore(sess, args.inTrainedModel) 158 | 159 | #Test data 160 | print("############################## Evaluation" ) 161 | it = 0 162 | accumLoss = 0.0 163 | accumTestLoss 
= 0.0 164 | sess.run(resetMetrics) 165 | accumIntersection = [0.0 for i in range(len(semLabels))] 166 | accumUnion = [0.0 for i in range(len(semLabels))] 167 | accumGt = [0.0 for i in range(len(semLabels))] 168 | accumVox = [0.0 for i in range(len(semLabels))] 169 | accumVoxGt = [0.0 for i in range(len(semLabels))] 170 | mTestDataSet.start_iteration() 171 | while mTestDataSet.has_more_batches(): 172 | 173 | _, points, batchIds, features, labels, _, sceneName = mTestDataSet.get_next_batch() 174 | currAccWeights = mTestDataSet.get_accuracy_masks(labels) 175 | 176 | for iterExec in range(args.nExec): 177 | 178 | lossRes, predictedLabelsRes = sess.run([loss, predictedLabels], 179 | {inPts: points, inBatchIds: batchIds, inFeatures: features, inWeights: currAccWeights, 180 | inAccWeights: currAccWeights, inLabels: labels, isTraining: False, keepProbConv: 1.0, keepProbFull: 1.0}) 181 | 182 | 183 | #Save models 184 | if args.saveModels: 185 | save_model("savedModels/"+sceneName[0]+"_gt", points, labels, objColors) 186 | save_model("savedModels/"+sceneName[0]+"_pred", points, predictedLabelsRes.reshape((-1, 1)), 187 | objColors) 188 | 189 | labels = labels.reshape((-1)) 190 | 191 | #Compute IoU 192 | for k in range(len(predictedLabelsRes)): 193 | if labels[k] != 0: 194 | if labels[k] == predictedLabelsRes[k]: 195 | accumIntersection[predictedLabelsRes[k]] = accumIntersection[predictedLabelsRes[k]] + 1.0 196 | accumUnion[predictedLabelsRes[k]] = accumUnion[predictedLabelsRes[k]] + 1.0 197 | else: 198 | accumUnion[labels[k]] = accumUnion[labels[k]] + 1.0 199 | accumUnion[predictedLabelsRes[k]] = accumUnion[predictedLabelsRes[k]] + 1.0 200 | accumGt[labels[k]] = accumGt[labels[k]] + 1.0 201 | 202 | accumLoss += lossRes 203 | 204 | accumTestLoss += lossRes 205 | 206 | #Compute Voxel accuracy 207 | resolution = 0.02 208 | coordMax = np.amax(points, axis=0) 209 | coordMin = np.amin(points, axis=0) 210 | nVoxels = np.ceil((coordMax-coordMin)/resolution) 211 | vidx = 
np.ceil((points-coordMin)/resolution) 212 | vidx = vidx[:,0]+vidx[:,1]*nVoxels[0]+vidx[:,2]*nVoxels[0]*nVoxels[1] 213 | uvidx = np.unique(vidx) 214 | voxelLabelCount = [np.bincount(labels[vidx==uv].astype(np.int32), minlength=len(semLabels)) for uv in uvidx] 215 | voxelPredLabelCount = [np.bincount(predictedLabelsRes[vidx==uv].astype(np.int32), minlength=len(semLabels)) for uv in uvidx] 216 | uvlabel = np.argmax(voxelLabelCount, axis = 1) 217 | uvpredlabel = np.argmax(voxelPredLabelCount, axis = 1) 218 | validVoxels = [1 if float(voxelLabelCount[k][0])/float(np.sum(voxelLabelCount[k])) < 0.3 and uvlabel[k] > 0 else 0 for k in range(len(uvidx))] 219 | 220 | for k in range(len(uvlabel)): 221 | if validVoxels[k] == 1: 222 | if uvlabel[k] == uvpredlabel[k]: 223 | accumVox[uvlabel[k]] = accumVox[uvlabel[k]] + 1.0 224 | accumVoxGt[uvlabel[k]] = accumVoxGt[uvlabel[k]] + 1.0 225 | 226 | visualize_progress(it, numTestRooms, ("Loss: %.6f "+sceneName[0]) % (accumLoss/(args.nExec))) 227 | accumLoss = 0.0 228 | it += 1 229 | 230 | #Compute mean IoU 231 | print("############################## Category IoU / Acc / VoxAcc") 232 | meanIoUxCat = 0.0 233 | totalAccuracy = 0.0 234 | totalVoxAccuracy = 0.0 235 | totalIntersection = 0.0 236 | totalGt = 0.0 237 | for i in range(1, len(semLabels)): 238 | 239 | currMean = 0.0 240 | if accumUnion[i] <= 0.0: 241 | currMean = 1.0 242 | else: 243 | currMean = accumIntersection[i] / accumUnion[i] 244 | 245 | currAccuracy = 0.0 246 | if accumGt[i] <= 0.0: 247 | currAccuracy = 1.0 248 | else: 249 | currAccuracy = accumIntersection[i] / accumGt[i] 250 | 251 | currVoxAccuracy = 0.0 252 | if accumVoxGt[i] <= 0.0: 253 | currVoxAccuracy = 1.0 254 | else: 255 | currVoxAccuracy = accumVox[i] / accumVoxGt[i] 256 | 257 | totalIntersection = totalIntersection + accumIntersection[i] 258 | totalGt = totalGt + accumGt[i] 259 | 260 | print("Mean category "+semLabels[i]+": %.4f | %.4f | %.4f" % (currMean*100.0, currAccuracy*100.0, currVoxAccuracy*100.0)) 261 | 
262 | meanIoUxCat = meanIoUxCat + currMean 263 | totalAccuracy = totalAccuracy + currAccuracy 264 | totalVoxAccuracy = totalVoxAccuracy + currVoxAccuracy 265 | 266 | meanIoUxCat = meanIoUxCat / float(len(semLabels)-1) 267 | totalAccuracy = totalAccuracy / float(len(semLabels)-1) 268 | totalVoxAccuracy = totalVoxAccuracy / float(len(semLabels)-1) 269 | accumTestLoss = accumTestLoss/float(numTestRooms*args.nExec) 270 | noMeantotalAccuracy = totalIntersection / totalGt 271 | 272 | #Print results 273 | print("############################## Global Accuracy and IoU") 274 | print("Loss: %.6f" % (accumTestLoss)) 275 | print("Test total accuracy: %.4f" % (noMeantotalAccuracy*100.0)) 276 | print("Test accuracy: %.4f" % (totalAccuracy*100.0)) 277 | print("Test voxel accuracy: %.4f" % (totalVoxAccuracy*100.0)) 278 | print("Test IoU %.4f" % (meanIoUxCat*100.0)) 279 | -------------------------------------------------------------------------------- /ScanNet/genScanNetData.py: -------------------------------------------------------------------------------- 1 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 2 | \file genScanNetData.py 3 | 4 | \brief Code to process the scannet dataset. 5 | 6 | \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 7 | Germany. See the LICENSE file at the top-level directory of 8 | this distribution. 
# \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de)

import sys
import argparse
import os
import json
import gc
from os import listdir
from os.path import isdir, join
import numpy as np
from ply_reader import read_points_binary_ply

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(os.path.join(ROOT_DIR, 'utils'))

from PyUtils import visualize_progress


def create_raw_to_scannet_label_map(folder, scannetLabels, version):
    """Build the map from raw ScanNet category names to benchmark labels.

    Args:
        folder (string): Root folder of the ScanNet data. The label table
            scannet-labels.combined.tsv (v1) or scannetv2-labels.combined.tsv
            (v2) is read from it.
        scannetLabels (list of strings): Labels used by the benchmark.
        version (int): ScanNet version, 1 or 2.

    Returns:
        rawToScanNet (dict): Raw category name -> benchmark label. Raw names
            whose mapped label is not in scannetLabels map to 'unannotated'.

    Raises:
        RuntimeError: If version is neither 1 nor 2.
    """
    # Column of the raw name and column of the mapped label in the TSV table.
    if version == 1:
        rawCol, mappedCol = 0, 6
        fileName = "scannet-labels.combined.tsv"
    elif version == 2:
        rawCol, mappedCol = 1, 7
        fileName = "scannetv2-labels.combined.tsv"
    else:
        raise RuntimeError('Unrecognized ScanNet version')

    # Strip only the line terminator so empty trailing columns keep their
    # position in the split.
    with open(folder + "/" + fileName) as labelFile:
        lines = [line.rstrip('\r\n') for line in labelFile]

    # Built once, not once per row.
    scannetLabelsSet = set(scannetLabels)
    rawToScanNet = {}
    # The first line is the table header.
    for line in lines[1:]:
        if not line:
            continue
        elements = line.split('\t')
        mapped = elements[mappedCol]
        if mapped in scannetLabelsSet:
            rawToScanNet[elements[rawCol]] = mapped
        else:
            rawToScanNet[elements[rawCol]] = 'unannotated'
    return rawToScanNet


def get_scanned_rooms(folder):
    """Return the names of the room folders inside <folder>/scans.

    Args:
        folder (string): Root folder of the ScanNet data.

    Returns:
        scannedRooms (list of strings): Sub-folder names of <folder>/scans,
            in the order returned by os.listdir.
    """
    scansFolder = folder + "/scans"
    return [f for f in listdir(scansFolder) if isdir(scansFolder + "/" + f)]


def load_room(roomName, folder):
    """Load the vertices of a room mesh (<room>_vh_clean_2.ply).

    Args:
        roomName (string): Name of the room folder.
        folder (string): Root folder of the ScanNet data.

    Returns:
        positions (nx3 np.array): Vertex positions (first three fields).
        normals (nx3 np.array or empty list): Vertex normals (fields 3-5) when
            the vertex has more than 7 fields, otherwise an empty list.
        colors (nx4 np.array): The four fields following positions (and
            normals, when present).
    """
    plydata = read_points_binary_ply(
        folder + "/scans/" + roomName + "/" + roomName + "_vh_clean_2.ply")
    # Field names of the structured vertex array, in file order.
    names = plydata.dtype.names

    def columns(first, last):
        # Stack fields [first, last) into an (n, last-first) array.
        return np.stack([plydata[name] for name in names[first:last]], axis=1)

    positions = columns(0, 3)
    if len(names) > 7:
        normals = columns(3, 6)
        colors = columns(6, 10)
    else:
        normals = []
        colors = columns(3, 7)

    return positions, normals, colors


def load_segmentation(roomName, folder):
    """Group vertex indices by over-segmentation segment id.

    Reads <folder>/scans/<room>/<room>_vh_clean_2.0.010000.segs.json, where
    entry i of 'segIndices' is the segment of vertex i.

    Args:
        roomName (string): Name of the room folder.
        folder (string): Root folder of the ScanNet data.

    Returns:
        objToSegMap (dict): Segment id -> list of vertex indices.
    """
    segPath = folder + "/scans/" + roomName + "/" + roomName + "_vh_clean_2.0.010000.segs.json"
    with open(segPath) as jsondata:
        segIndices = json.load(jsondata)['segIndices']
    objToSegMap = {}
    for vertexId, segmentId in enumerate(segIndices):
        objToSegMap.setdefault(segmentId, []).append(vertexId)
    return objToSegMap


def load_labels(roomName, folder, pts, segMap, rawToScanNetMap, scannetLabels, weights):
    """Assign a benchmark label index to every vertex of a room.

    Reads <folder>/scans/<room>/<room>.aggregation.json. Vertices that do not
    belong to any segment group keep label 0.

    Args:
        roomName (string): Name of the room folder.
        folder (string): Root folder of the ScanNet data.
        pts (sequence): Vertices of the room (only its length is used).
        segMap (dict): Segment id -> vertex indices (see load_segmentation).
        rawToScanNetMap (dict): Raw category name -> benchmark label.
        scannetLabels (list of strings): Benchmark labels.
        weights (list of floats): Per-label counters, incremented in place
            once per labeled vertex.

    Returns:
        labels (list of ints): Index into scannetLabels for every vertex.
    """
    labels = [0] * len(pts)
    aggregationPath = folder + "/scans/" + roomName + "/" + roomName + ".aggregation.json"
    with open(aggregationPath) as jsondata:
        segGroups = json.load(jsondata)['segGroups']
    for group in segGroups:
        label = rawToScanNetMap.get(group['label'], 'unannotated')
        labelId = scannetLabels.index(label)
        for segment in group['segments']:
            for ptId in segMap[segment]:
                labels[ptId] = labelId
                weights[labelId] += 1.0
    return labels


def compute_aabb(pts):
    """Compute the axis-aligned bounding box of a set of points.

    Args:
        pts (nx3 array-like): Points.

    Returns:
        minPt, maxPt (np.array float64): Per-axis minimum and maximum.

    Raises:
        ValueError: If pts is empty.
    """
    pts = np.asarray(pts, dtype=np.float64)
    if pts.shape[0] == 0:
        raise ValueError('Cannot compute the bounding box of an empty point set')
    # Real min/max; the former +-10000 sentinels gave wrong boxes for scenes
    # with coordinates outside that range.
    return pts.min(axis=0), pts.max(axis=0)


def process_room(folder, outFolder, room, scannetLabels, rawToScanNetMap, weights, aabbSizesVec, numPointsVec):
    """Convert one ScanNet room into the files used by MCCNN.

    Writes <outFolder>/<room>_pos.npy, _normals.npy (only if the mesh has
    normals), _colors.npy, _labels.npy and _aabb.txt. Appends the number of
    points to numPointsVec and the largest bounding-box side to aabbSizesVec.
    The per-label counts of this room are added to weights only after all
    files are written, so a failing room leaves the statistics untouched.

    Returns:
        numPoints (int): Number of points of the room.
    """
    pos, normals, colors = load_room(room, folder)
    segMap = load_segmentation(room, folder)
    roomWeights = [0.0] * len(weights)
    labels = load_labels(room, folder, pos, segMap, rawToScanNetMap, scannetLabels, roomWeights)
    minPt, maxPt = compute_aabb(pos)

    outPrefix = outFolder + "/" + room
    np.save(outPrefix + "_pos.npy", pos)
    if len(normals) > 0:
        np.save(outPrefix + "_normals.npy", normals)
    np.save(outPrefix + "_colors.npy", colors)
    np.save(outPrefix + "_labels.npy", labels)
    np.savetxt(outPrefix + "_aabb.txt", [minPt, maxPt])

    for labelId, count in enumerate(roomWeights):
        weights[labelId] += count
    numPointsVec.append(len(pos))
    aabbSizesVec.append(np.amax(maxPt - minPt))

    return len(pos)


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Script to preprocess the ScanNet dataset for MCCNN')
    parser.add_argument('--inFolder', default='data', help='Folder of the input ScanNet data (default: data)')
    parser.add_argument('--outFolder', default='data_mccnn', help='Folder of the output ScanNet data (default: data_mccnn)')
    parser.add_argument('--version', default=1, type=int, help='ScanNet version (default: 1)')
    args = parser.parse_args()

    scannetLabels = ['unannotated', 'wall', 'floor', 'chair', 'table', 'desk', 'bed', 'bookshelf', 'sofa',
        'sink', 'bathtub', 'toilet', 'curtain', 'counter', 'door', 'window', 'shower curtain', 'refridgerator',
        'picture', 'cabinet', 'otherfurniture']
    weights = [0.0] * len(scannetLabels)
    print(scannetLabels)

    outFolder = args.outFolder + "_v" + str(args.version)
    if not os.path.exists(outFolder):
        os.mkdir(outFolder)

    np.savetxt(outFolder + "/labels.txt", scannetLabels, fmt='%s')

    rawToScanNetMap = create_raw_to_scannet_label_map(args.inFolder, scannetLabels, args.version)
    print(rawToScanNetMap)

    scannedRooms = get_scanned_rooms(args.inFolder)

    numPointsVec = []
    aabbSizeVec = []
    processedRooms = []
    for roomIter, room in enumerate(scannedRooms):
        visualize_progress(roomIter, len(scannedRooms), room)

        numPoints = 0
        try:
            numPoints = process_room(args.inFolder, outFolder, room, scannetLabels, rawToScanNetMap,
                                     weights, aabbSizeVec, numPointsVec)
            processedRooms.append(room)
        except Exception as e:
            # Best effort: report the broken room and continue with the rest.
            print(room + ' ERROR!!')
            print(str(e))

        gc.collect()

        visualize_progress(roomIter, len(scannedRooms), room + " | " + str(numPoints))

    # Written after processing so that rooms.txt stays line-aligned with
    # num_points.txt and aabb_sizes.txt even when some rooms fail.
    np.savetxt(outFolder + "/rooms.txt", processedRooms, fmt='%s')

    # weights1: counts relative to the annotated points (labels 1..n).
    # weights2: counts relative to all points.
    annotatedSum = sum(weights[1:])
    totalSum = annotatedSum + weights[0]
    weights1 = [w / annotatedSum if annotatedSum > 0.0 else 0.0 for w in weights]
    weights2 = [w / totalSum if totalSum > 0.0 else 0.0 for w in weights]

    print(weights1)
    print(weights2)

    np.savetxt(outFolder + "/weights.txt", [weights1, weights2])
    np.savetxt(outFolder + "/num_points.txt", numPointsVec)
    np.savetxt(outFolder + "/aabb_sizes.txt", aabbSizeVec)

# --------------------------------------------------------------------------------
# /ScanNet/ply_reader.py:
# --------------------------------------------------------------------------------
# \file ply_reader.py
#
# \brief Ply binary reader. This code is a modification from the file ply.py
#        from the project pyntcloud of David de la Iglesia Castro
#        (https://github.com/daavoo/pyntcloud)
#
# \copyright Copyright (c) 2018 Visual Computing group of Ulm University,
#            Germany. See the LICENSE file at the top-level directory of
#            this distribution.
# \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de)

import sys
import numpy as np
from collections import defaultdict

# Byte-order character ('<' or '>') of the machine running this code.
sys_byteorder = ('>', '<')[sys.byteorder == 'little']

# PLY property type name -> numpy type code (byte order is prefixed later).
ply_dtypes = {
    b'int8': 'i1',
    b'char': 'i1',
    b'uint8': 'u1',
    b'uchar': 'u1',
    b'int16': 'i2',
    b'short': 'i2',
    b'uint16': 'u2',
    b'ushort': 'u2',
    b'int32': 'i4',
    b'int': 'i4',
    b'uint32': 'u4',
    b'uint': 'u4',
    b'float32': 'f4',
    b'float': 'f4',
    b'float64': 'f8',
    b'double': 'f8',
}

# PLY format name -> numpy byte-order character.
valid_formats = {'ascii': '', 'binary_big_endian': '>',
                 'binary_little_endian': '<'}


def read_points_binary_ply(filename):
    """Read the vertex element of a binary PLY file.

    Only scalar properties are recorded (list properties are skipped). The
    vertex data is read directly after the header, so the 'vertex' element
    must be the first element stored in the file.

    Args:
        filename (string): Path to the PLY file.

    Returns:
        points_np (structured np.array): One record per vertex, one field per
            vertex property, in native byte order.

    Raises:
        ValueError: If the file is not a binary PLY file, its header is
            truncated or malformed, or it declares no vertex element.
    """
    with open(filename, 'rb') as ply:

        if b'ply' not in ply.readline():
            raise ValueError('The file does not start with the word ply')

        fmt = ply.readline().split()[1].decode()
        if fmt == "ascii":
            raise ValueError('The file format is ascii not binary')

        ext = valid_formats[fmt]

        dtypes = defaultdict(list)
        points_size = None
        name = None
        while True:
            line = ply.readline()
            if line == b'':
                raise ValueError('Unexpected end of file while reading the PLY header')
            # Match whole keywords so comments mentioning "element" or
            # "property" are not mistaken for declarations.
            tokens = line.split()
            if not tokens:
                continue
            keyword = tokens[0]
            if keyword == b'end_header':
                break
            if keyword == b'element':
                name = tokens[1].decode()
                if name == "vertex":
                    points_size = int(tokens[2])
            elif keyword == b'property':
                if name is None:
                    raise ValueError('PLY property declared before any element')
                if tokens[1] != b'list':
                    dtypes[name].append((tokens[2].decode(), ext + ply_dtypes[tokens[1]]))

        end_header = ply.tell()

    if points_size is None:
        raise ValueError('The file does not declare a vertex element')

    with open(filename, 'rb') as ply:
        ply.seek(end_header)
        points_np = np.fromfile(ply, dtype=np.dtype(dtypes["vertex"]), count=points_size)

    if ext != sys_byteorder:
        # Convert to native byte order keeping the values. ndarray.newbyteorder()
        # was removed in NumPy 2.0; viewing through the swapped dtype is the
        # portable equivalent.
        points_np = points_np.byteswap().view(points_np.dtype.newbyteorder())

    return points_np

# --------------------------------------------------------------------------------
# /ShapeNet/ShapeNetDataSet.py:
# --------------------------------------------------------------------------------
r"""
\file ShapeNetDataSet.py

\brief ShapeNet dataset class.

\copyright Copyright (c) 2018 Visual Computing group of Ulm University,
           Germany. See the LICENSE file at the top-level directory of
           this distribution.

\author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de)
"""

import sys
import os
import math
import time
import json
import numpy as np
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(os.path.join(ROOT_DIR, 'utils'))
from DataSet import DataSet


class ShapeNetDataSet(DataSet):
    """ShapeNet dataset.

    Attributes:
        useNormalsAsFeatures_ (bool): Boolean that indicates if the normals will be used as the input features.
        catNames_ (nx2 array): List of tuples (category name, category folder) of the categories in the dataset.
        segClasses_ (dictionary of arrays): Each entry of the dictionary has a key equal to the name of the
            category and a list of part identifiers.
    """

    def __init__(self, train, batchSize, ptDropOut, allowedSamplings=None, augment=False,
                 useNormalsAsFeatures=False, seed=None):
        """Constructor.

        Args:
            train (bool): Boolean that indicates if this is the train or test dataset. The train dataset
                contains the models of both the train and the validation splits.
            batchSize (int): Size of the batch used.
            ptDropOut (float): Probability to keep a point during uniform sampling when all the points
                or only the first n number of points are selected.
            allowedSamplings (array of ints): Each element of the array determines an allowed sampling protocol
                that will be used to sample the different models. If None, only uniform sampling ([0]) is used.
                The implemented sampling protocols are:
                - 0: Uniform sampling
                - 1: Split sampling
                - 2: Gradient sampling
                - 3: Lambert sampling
                - 4: Occlusion sampling
            augment (bool): Boolean that indicates if data augmentation will be used in the models.
            useNormalsAsFeatures (bool): Boolean that indicates if the normals will be used as the input features.
            seed (int): Seed used to initialize the random number generator. If None is provided instead, the current
                time on the machine will be used to initialize the number generator.
        """
        # A fresh list per instance; a list default would be shared by every
        # instance created without this argument.
        if allowedSamplings is None:
            allowedSamplings = [0]

        # Store the parameters of the class.
        self.useNormalsAsFeatures_ = useNormalsAsFeatures

        # Create the list of features that need to be augmented.
        augmentedFeatures = []
        if useNormalsAsFeatures:
            augmentedFeatures = [0]

        # Call the constructor of the parent class.
        super(ShapeNetDataSet, self).__init__(0, ptDropOut, useNormalsAsFeatures, True, True,
            True, True, batchSize, allowedSamplings, 100000000, 0,
            augment, 1, True, False, augmentedFeatures, [], seed)

        # Get the categories as (name, folder) tuples.
        self.catNames_ = []
        with open("./shape_data/synsetoffset2category.txt", 'r') as nameFile:
            for line in nameFile:
                strings = line.replace("\n", "").split("\t")
                self.catNames_.append((strings[0], strings[1]))

        # Part identifiers of each category.
        self.segClasses_ = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
            'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37],
            'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49],
            'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}

        # List of files: train + validation splits for training, test split otherwise.
        splitFolder = "./shape_data/train_test_split/"
        if train:
            splitFiles = ["shuffled_train_file_list.json", "shuffled_val_file_list.json"]
        else:
            splitFiles = ["shuffled_test_file_list.json"]
        self.fileList_ = []
        for splitFile in splitFiles:
            with open(splitFolder + splitFile, 'r') as f:
                self.fileList_.extend(json.load(f))

        # Category of each model: the last category whose folder name appears
        # in the model path (0 if none does).
        for currModel in self.fileList_:
            catId = 0
            for currCat, (_, catFolder) in enumerate(self.catNames_):
                if catFolder in currModel:
                    catId = currCat
            self.categories_.append(catId)

        # The sizes of the models are unknown in advance. They are initialized
        # to 0 and updated automatically the first time each model is loaded.
        self.numPts_ = [0] * len(self.fileList_)


    def get_categories(self):
        """Method to get the list of categories.

        Returns:
            catNames (nx2 list of strings): List of tuples with the category name and the folder name.
        """
        return self.catNames_


    def get_categories_seg_parts(self):
        """Method to get the list of parts per category.

        Returns:
            segClasses (dict of arrays): Each entry of the dictionary has a key equal to the name of the
                category and a list of part identifiers.
        """
        return self.segClasses_


    def _load_model_from_disk_(self, modelPath):
        """Load a model from the text file <modelPath>.txt.

        Each line holds at least 7 whitespace-separated numbers: position (3),
        normal (3) and part label (1). Extra columns are ignored.

        Args:
            modelPath (string): Path to the model that needs to be loaded (without extension).

        Returns:
            pts (nx3 np.array): List of points.
            normals (nx3 np.array): List of normals.
            features (nx3 np.array): The normals if useNormalsAsFeatures_ is set, None otherwise.
            labels (nx1 np.array): Part label of each point (as float).
        """
        # ndmin=2 keeps a single-point file two-dimensional.
        fileData = np.loadtxt(modelPath + ".txt", usecols=(0, 1, 2, 3, 4, 5, 6), ndmin=2)

        pts = fileData[:, 0:3]
        normals = fileData[:, 3:6]
        features = None
        if self.useNormalsAsFeatures_:
            features = normals
        labels = fileData[:, 6:7]

        return pts, normals, features, labels

# --------------------------------------------------------------------------------
# /ShapeNet/ShapeNetEval.py:
# --------------------------------------------------------------------------------
# \file ShapeNetEval.py
#
# \brief Code to evaluate a segmentation network on the ShapeNet dataset.
#
# \copyright Copyright (c) 2018 Visual Computing group of Ulm University,
#            Germany. See the LICENSE file at the top-level directory of
#            this distribution.
# \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de)

import sys
import math
import time
import argparse
import importlib
import os
import numpy as np
import tensorflow as tf

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(os.path.join(ROOT_DIR, 'models'))
sys.path.append(os.path.join(ROOT_DIR, 'utils'))

from PyUtils import visualize_progress, save_model
from ShapeNetDataSet import ShapeNetDataSet

current_milli_time = lambda: time.time() * 1000.0


def create_accuracy(logits, labels, scope):
    """Create a streaming accuracy metric between the arg-max of logits and labels.

    Args:
        logits (tensor): Per-point class scores.
        labels (tensor): Ground-truth labels.
        scope (string): Variable scope of the metric's local variables, so
            they can be collected and re-initialized later.

    Returns:
        (accuracy, update_op) as returned by tf.metrics.accuracy.
    """
    _, logitsIndexs = tf.nn.top_k(logits)
    with tf.variable_scope(scope):
        return tf.metrics.accuracy(labels, logitsIndexs)


def compute_shape_iou(gtLabels, predLabels, partIds):
    """Mean part IoU of one shape.

    A part that appears neither in the ground truth nor in the prediction
    counts as IoU 1.0.

    Args:
        gtLabels (array-like): Ground-truth label per point (any shape, flattened).
        predLabels (array-like): Predicted label per point, same number of points.
        partIds (list of ints): Part identifiers of the shape's category.

    Returns:
        meanIoU (float): IoU averaged over partIds.
    """
    gt = np.asarray(gtLabels).reshape(-1)
    pred = np.asarray(predLabels).reshape(-1)
    accumIoU = 0.0
    for partId in partIds:
        gtMask = gt == partId
        predMask = pred == partId
        union = np.count_nonzero(gtMask | predMask)
        if union > 0:
            accumIoU += np.count_nonzero(gtMask & predMask) / float(union)
        else:
            accumIoU += 1.0
    return accumIoU / float(len(partIds))


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Evaluation of segmentation networks (ShapeNet)')
    parser.add_argument('--grow', default=32, type=int, help='Grow rate (default: 32)')
    parser.add_argument('--inTrainedModel', default='log/model.ckpt', help='Input trained model (default: log/model.ckpt)')
    parser.add_argument('--model', default='MCSeg', help='model (default: MCSeg)')
    parser.add_argument('--gpu', default='0', help='GPU (default: 0)')
    parser.add_argument('--gpuMem', default=0.5, type=float, help='GPU memory used (default: 0.5)')
    parser.add_argument('--nExec', default=1, type=int, help='Number of executions per model (default: 1)')
    parser.add_argument('--saveModels', action='store_true', help='Save models (default: False)')
    args = parser.parse_args()

    print("Trained model: "+args.inTrainedModel)
    print("Model: "+args.model)
    print("Grow: "+str(args.grow))
    print("nExec: "+str(args.nExec))

    # Colors assigned to each part (used to save the model as a file).
    colors = [[228, 26, 28],
              [55, 126, 184],
              [77, 175, 74],
              [152, 78, 163],
              [255, 127, 0],
              [255, 255, 51]]

    # Load the model.
    model = importlib.import_module(args.model)

    # Get the test files.
    mTestDataSet = ShapeNetDataSet(False, args.nExec, 1.0, [0], False)
    cat = mTestDataSet.get_categories()
    segClasses = mTestDataSet.get_categories_seg_parts()
    print(segClasses)
    numTestModels = mTestDataSet.get_num_models()
    print("Test models: " + str(numTestModels))

    # Save models, create folder.
    if args.saveModels:
        if not os.path.exists("savedModels"):
            os.mkdir("savedModels")

    # Create variables and placeholders.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    inPts = tf.placeholder(tf.float32, [None, 3])
    inBatchIds = tf.placeholder(tf.int32, [None, 1])
    inFeatures = tf.placeholder(tf.float32, [None, 1])
    inCatLabels = tf.placeholder(tf.int32, [None, 1])
    inLabels = tf.placeholder(tf.int32, [None, 1])
    isTraining = tf.placeholder(tf.bool)
    keepProbConv = tf.placeholder(tf.float32)
    keepProbFull = tf.placeholder(tf.float32)

    # Create the network.
    logits = model.create_network(inPts, inBatchIds, inFeatures, inCatLabels, 1, len(cat), 50,
        args.nExec, args.grow, isTraining, keepProbConv, keepProbFull, False, False)

    # Create predict labels.
    predictedLabels = tf.argmax(logits, 1)

    # Create accuracy metric.
    accuracyVal, accuracyAccumOp = create_accuracy(logits, inLabels, 'metrics')
    metricsVars = tf.contrib.framework.get_variables('metrics', collection=tf.GraphKeys.LOCAL_VARIABLES)
    resetMetrics = tf.variables_initializer(metricsVars)

    # Create init variables.
    init = tf.global_variables_initializer()
    initLocal = tf.local_variables_initializer()

    # Create the saver.
    saver = tf.train.Saver()

    # Create session.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpuMem, visible_device_list=args.gpu)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    # Init variables.
    sess.run(init)
    sess.run(initLocal)

    # Restore the model.
    saver.restore(sess, args.inTrainedModel)

    # Test the dataset with each sampling protocol.
    titleTexts = [
        "Uniform sampling",
        "Non-uniform split",
        "Non-uniform gradient",
        "Non-uniform lambert",
        "Non-uniform occlusion"]
    for samp in range(len(titleTexts)):
        # Print the type of sampling used.
        print(titleTexts[samp])
        # Update the dataset.
        mTestDataSet.set_allowed_samplings([samp])
        mTestDataSet.start_iteration()
        # tf.metrics.accuracy accumulates across update_op calls; without this
        # reset every protocol after the first would report a running total.
        sess.run(resetMetrics)
        # Create the auxiliary variables.
        it = 0
        accumTime = 0.0
        IoUxCat = [[] for _ in cat]
        # Iterate over the models.
        while mTestDataSet.has_more_batches():
            # Get the batch dataset.
            _, points, batchIds, features, labels, catLabels, modelsPath = mTestDataSet.get_next_batch(True)

            # Compute the predicted logits.
            startTimeMeasure = current_milli_time()
            predictedLabelsRes, _ = sess.run([predictedLabels, accuracyAccumOp],
                {inPts: points, inBatchIds: batchIds, inFeatures: features, inCatLabels: catLabels,
                 inLabels: labels, isTraining: False, keepProbConv: 1.0, keepProbFull: 1.0})
            endTimeMeasure = current_milli_time()
            accumTime = accumTime + (endTimeMeasure - startTimeMeasure)

            # Save models.
            if args.saveModels:
                save_model("savedModels/"+modelsPath[0].replace("/", "-")+"_sampling_"+
                    str(samp)+"_gt", points, labels, colors, 6)
                save_model("savedModels/"+modelsPath[0].replace("/", "-")+"_sampling_"+
                    str(samp)+"_pred", points, predictedLabelsRes.reshape((-1, 1)),
                    colors, 6)

            # Compute IoU over the parts of the batch's category (the category
            # of the first entry of the batch is used).
            catId = catLabels[0][0]
            IoUxCat[catId].append(compute_shape_iou(labels, predictedLabelsRes, segClasses[cat[catId][0]]))

            if it % 100 == 0:
                visualize_progress(it, numTestModels)

            it += 1

        # Compute mean IoU per category and over all evaluated shapes.
        accumIoU = 0.0
        numEvaluated = 0
        for i, catIoUs in enumerate(IoUxCat):
            if len(catIoUs) == 0:
                print("Mean IoU category "+cat[i][0]+": no test models")
                continue
            currMean = sum(catIoUs) / float(len(catIoUs))
            print("Mean IoU category "+cat[i][0]+": "+str(currMean))
            accumIoU += sum(catIoUs)
            numEvaluated += len(catIoUs)
        meanIoUxCat = accumIoU / float(numEvaluated) if numEvaluated > 0 else 0.0

        totalAccuracy = sess.run(accuracyVal)

        print("Time: %.8f" % (accumTime/(float(numTestModels))))
        print("Test accuracy: %.4f | Test IoU %.4f" % (totalAccuracy*100.0, meanIoUxCat*100.0))

# --------------------------------------------------------------------------------
# /models/MCClass.py:
# --------------------------------------------------------------------------------
r"""
\file MCClass.py

\brief Definition of the network architecture MCClass for classification
       tasks.

\copyright Copyright (c) 2018 Visual Computing group of Ulm University,
           Germany. See the LICENSE file at the top-level directory of
           this distribution.

\author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de)
"""

import sys
import os
import math
import tensorflow as tf
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(os.path.join(ROOT_DIR, 'tf_ops'))
sys.path.append(os.path.join(ROOT_DIR, 'utils'))
from MCConvBuilder import PointHierarchy, ConvolutionBuilder
from MCNetworkUtils import MLP_2_hidden, batch_norm_RELU_drop_out, conv_1x1


def create_network(points, batchIds, features, numInputFeatures, batchSize, k, numOutCat, isTraining,
                   keepProbConv, keepProbFull, useConvDropOut=False, useDropOutFull=True):
    """Build the MCClass classification network.

    A point hierarchy with radii [0.1, 0.4, sqrt(3)+0.1] is built. Three
    Monte Carlo convolution levels follow, each ending in a pooling
    convolution to the next level, and then a 2-hidden-layer MLP
    (k*32 -> k*16 -> k*8 -> numOutCat). The scope names (Conv_*, Pool_*,
    Reduce_Pool_*, Final_Logits, ...) name the trained variables and must
    stay stable for checkpoints to load.

    Args:
        points: Input point positions (tensor).
        batchIds: Batch index of each point (tensor).
        features: Input features of each point (tensor).
        numInputFeatures (int): Number of input features per point.
        batchSize (int): Number of models per batch.
        k (int): Base feature width; deeper layers use multiples of it.
        numOutCat (int): Number of output classes.
        isTraining (bool tensor): Training-mode flag for batch norm/dropout.
        keepProbConv (float tensor): Keep probability of convolution dropout.
        keepProbFull (float tensor): Keep probability of MLP dropout.
        useConvDropOut (bool): Apply dropout after convolutions.
        useDropOutFull (bool): Apply dropout in the final MLP.

    Returns:
        finalLogits (tensor): numOutCat logits, presumably one row per model
            (points of the last hierarchy level) -- confirm against MLP_2_hidden.
    """

    ############################################ Compute point hierarchy
    mPointHierarchy = PointHierarchy(points, features, batchIds, [0.1, 0.4, math.sqrt(3.0)+0.1], "MCClass_PH", batchSize)

    ############################################ Convolutions
    mConvBuilder = ConvolutionBuilder(KDEWindow=0.25)

    # First Convolution
    convFeatures1 = mConvBuilder.create_convolution(
        convName="Conv_1",
        inPointHierarchy=mPointHierarchy,
        inPointLevel=0,
        inFeatures=features,
        inNumFeatures=numInputFeatures,
        outNumFeatures=k,
        convRadius=0.1,
        multiFeatureConv=True)

    # First Pooling
    convFeatures1 = batch_norm_RELU_drop_out("Reduce_Pool_1_In_BN", convFeatures1, isTraining, useConvDropOut, keepProbConv)
    convFeatures1 = conv_1x1("Reduce_Pool_1", convFeatures1, k, k*2)
    convFeatures1 = batch_norm_RELU_drop_out("Reduce_Pool_1_Out_BN", convFeatures1, isTraining, useConvDropOut, keepProbConv)
    poolFeatures1 = mConvBuilder.create_convolution(
        convName="Pool_1",
        inPointHierarchy=mPointHierarchy,
        inPointLevel=0,
        outPointLevel=1,
        inFeatures=convFeatures1,
        inNumFeatures=k*2,
        convRadius=0.2,
        KDEWindow=0.2)

    # Second Convolution
    bnPoolFeatures1 = batch_norm_RELU_drop_out("Conv_2_In_BN", poolFeatures1, isTraining, useConvDropOut, keepProbConv)
    convFeatures2 = mConvBuilder.create_convolution(
        convName="Conv_2",
        inPointHierarchy=mPointHierarchy,
        inPointLevel=1,
        inFeatures=bnPoolFeatures1,
        inNumFeatures=k*2,
        convRadius=0.4)
    # Skip connection: pooled features + convolved features (k*4 channels).
    convFeatures2 = tf.concat([poolFeatures1, convFeatures2], 1)

    # Second Pooling
    convFeatures2 = batch_norm_RELU_drop_out("Reduce_Pool_2_In_BN", convFeatures2, isTraining, useConvDropOut, keepProbConv)
    convFeatures2 = conv_1x1("Reduce_Pool_2", convFeatures2, k*4, k*8)
    convFeatures2 = batch_norm_RELU_drop_out("Reduce_Pool_2_Out_BN", convFeatures2, isTraining, useConvDropOut, keepProbConv)
    poolFeatures2 = mConvBuilder.create_convolution(
        convName="Pool_2",
        inPointHierarchy=mPointHierarchy,
        inPointLevel=1,
        outPointLevel=2,
        inFeatures=convFeatures2,
        inNumFeatures=k*8,
        convRadius=0.8,
        KDEWindow=0.2)

    # Third Convolution
    bnPoolFeatures2 = batch_norm_RELU_drop_out("Conv_3_In_BN", poolFeatures2, isTraining, useConvDropOut, keepProbConv)
    convFeatures3 = mConvBuilder.create_convolution(
        convName="Conv_3",
        inPointHierarchy=mPointHierarchy,
        inPointLevel=2,
        inFeatures=bnPoolFeatures2,
        inNumFeatures=k*8,
        convRadius=1.1)
    # Skip connection: pooled features + convolved features (k*16 channels).
    convFeatures3 = tf.concat([poolFeatures2, convFeatures3], 1)

    # Third Pooling
    convFeatures3 = batch_norm_RELU_drop_out("Reduce_Pool_3_In_BN", convFeatures3, isTraining, useConvDropOut, keepProbConv)
    convFeatures3 = conv_1x1("Reduce_Pool_3", convFeatures3, k*16, k*32)
    convFeatures3 = batch_norm_RELU_drop_out("Reduce_Pool_3_Out_BN", convFeatures3, isTraining, useConvDropOut, keepProbConv)
    poolFeatures3 = mConvBuilder.create_convolution(
        convName="Pool_3",
        inPointHierarchy=mPointHierarchy,
        inPointLevel=2,
        outPointLevel=3,
        inFeatures=convFeatures3,
        inNumFeatures=k*32,
        convRadius=math.sqrt(3.0)+0.1,
        KDEWindow=0.2)

    # Fully connected MLP - Global features.
    finalInput = batch_norm_RELU_drop_out("BNRELUDROP_final", poolFeatures3, isTraining, useConvDropOut, keepProbConv)
    finalLogits = MLP_2_hidden(finalInput, k*32, k*16, k*8, numOutCat, "Final_Logits", keepProbFull, isTraining, useDropOutFull)

    return finalLogits

# --------------------------------------------------------------------------------
# /models/MCClassH.py:
# --------------------------------------------------------------------------------
# \file MCClassH.py
#
# \brief Definition of the network architecture MCClassH for classification
#        tasks, in which the class probabilities are computed by two
#        separated paths.
#
# \copyright Copyright (c) 2018 Visual Computing group of Ulm University,
#            Germany. See the LICENSE file at the top-level directory of
#            this distribution.
# \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de)

import sys
import os
import math
import tensorflow as tf
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(os.path.join(ROOT_DIR, 'tf_ops'))
sys.path.append(os.path.join(ROOT_DIR, 'utils'))
from MCConvBuilder import PointHierarchy, ConvolutionBuilder
from MCNetworkUtils import MLP_2_hidden, batch_norm_RELU_drop_out, conv_1x1


def create_network(points, batchIds, features, numInputFeatures, batchSize, k, numOutCat, isTraining,
                   keepProbConv, keepProbFull, useConvDropOut=False, useDropOutFull=True):
    """Build the two-path MCClassH classification network.

    Path 1 starts at hierarchy level 0 and path 2 at level 1, using the
    hierarchy's level-1 features. Each path ends in its own 2-hidden-layer MLP
    producing numOutCat logits. During training a single uniform sample
    drops one of the paths: path 1 is kept when it is <= 0.66 and path 2 when
    it is >= 0.33. The sum is rescaled by 2 / (number of kept paths). At
    inference both paths are always kept.

    Every scope-name string below names trained variables. Keep them
    unchanged so existing checkpoints still load.

    Args:
        points: Input point positions (tensor).
        batchIds: Batch index of each point (tensor).
        features: Input features of each point (tensor).
        numInputFeatures (int): Number of input features per point.
        batchSize (int): Number of models per batch.
        k (int): Base feature width; deeper layers use multiples of it.
        numOutCat (int): Number of output classes.
        isTraining (bool tensor): Training-mode flag (batch norm, dropout, path dropout).
        keepProbConv (float tensor): Keep probability of convolution dropout.
        keepProbFull (float tensor): Keep probability of MLP dropout.
        useConvDropOut (bool): Apply dropout after convolutions.
        useDropOutFull (bool): Apply dropout in the final MLPs.

    Returns:
        Tensor with the combined logits of both paths (numOutCat columns).
    """

    def bn_relu_drop(scopeName, inTensor):
        # Batch norm + ReLU + optional dropout, shared by every stage.
        return batch_norm_RELU_drop_out(scopeName, inTensor, isTraining, useConvDropOut, keepProbConv)

    def reduce_stage(inTensor, inWidth, outWidth, inBnName, convName, outBnName):
        # BN -> 1x1 convolution (inWidth -> outWidth) -> BN.
        reduced = bn_relu_drop(inBnName, inTensor)
        reduced = conv_1x1(convName, reduced, inWidth, outWidth)
        return bn_relu_drop(outBnName, reduced)

    def path_mask(keepCondition):
        # 1.0 when the path is kept this step; always 1.0 outside training.
        return tf.maximum(tf.cast(keepCondition, tf.float32),
                          tf.cast(tf.logical_not(isTraining), tf.float32))

    hierarchy = PointHierarchy(points, features, batchIds, [0.1, 0.4, math.sqrt(3.0)+0.1], "MCClassH_PH", batchSize)
    builder = ConvolutionBuilder(KDEWindow=0.25)
    finalRadius = math.sqrt(3.0)+0.1

    # ------------------------------------------------------------ path 1
    # Level 0: convolution, then pooling into level 1.
    level0 = builder.create_convolution(convName="Conv_1", inPointHierarchy=hierarchy,
        inPointLevel=0, inFeatures=features, inNumFeatures=numInputFeatures,
        outNumFeatures=k, convRadius=0.1, multiFeatureConv=True)
    level0 = reduce_stage(level0, k, k*2, "Reduce_Pool_1_In_BN", "Reduce_Pool_1", "Reduce_Pool_1_Out_BN")
    pooled1 = builder.create_convolution(convName="Pool_1", inPointHierarchy=hierarchy,
        inPointLevel=0, outPointLevel=1, inFeatures=level0, inNumFeatures=k*2,
        convRadius=0.2, KDEWindow=0.2)

    # Level 1: convolution with skip connection, then pooling into level 2.
    conv2Input = reduce_stage(pooled1, k*2, k*2, "Reduce_Conv_2_In_BN", "Reduce_Conv_2", "Reduce_Conv_2_Out_BN")
    level1 = builder.create_convolution(convName="Conv_2", inPointHierarchy=hierarchy,
        inPointLevel=1, inFeatures=conv2Input, inNumFeatures=k*2, convRadius=0.4)
    level1 = tf.concat([pooled1, level1], 1)
    level1 = reduce_stage(level1, k*4, k*8, "Reduce_Pool_2_In_BN", "Reduce_Pool_2", "Reduce_Pool_2_Out_BN")
    pooled2 = builder.create_convolution(convName="Pool_2", inPointHierarchy=hierarchy,
        inPointLevel=1, outPointLevel=2, inFeatures=level1, inNumFeatures=k*8,
        convRadius=0.8, KDEWindow=0.2)

    # Level 2: convolution with skip connection, then pooling into level 3.
    conv3Input = reduce_stage(pooled2, k*8, k*8, "Reduce_Conv_3_In_BN", "Reduce_Conv_3", "Reduce_Conv_3_Out_BN")
    level2 = builder.create_convolution(convName="Conv_3", inPointHierarchy=hierarchy,
        inPointLevel=2, inFeatures=conv3Input, inNumFeatures=k*8, convRadius=1.2)
    level2 = tf.concat([pooled2, level2], 1)
    level2 = reduce_stage(level2, k*16, k*32, "Reduce_Pool_3_In_BN", "Reduce_Pool_3", "Reduce_Pool_3_Out_BN")
    pooled3 = builder.create_convolution(convName="Pool_3", inPointHierarchy=hierarchy,
        inPointLevel=2, outPointLevel=3, inFeatures=level2, inNumFeatures=k*32,
        convRadius=finalRadius, KDEWindow=0.2)

    logitsPath1 = MLP_2_hidden(bn_relu_drop("BNRELUDROP_final", pooled3), k*32, k*16, k*8, numOutCat,
        "Final_Logits", keepProbFull, isTraining, useDropOutFull)

    # ------------------------------------------------------------ path 2
    # Level 1 (from the hierarchy's own level-1 features), pooled into level 2.
    level1b = builder.create_convolution(convName="Conv_2_2", inPointHierarchy=hierarchy,
        inPointLevel=1, inFeatures=hierarchy.features_[1], inNumFeatures=numInputFeatures,
        outNumFeatures=k*2, convRadius=0.4, multiFeatureConv=True)
    level1b = reduce_stage(level1b, k*2, k*8, "Reduce_Pool_2_2_In_BN", "Reduce_Pool_2_2", "Reduce_Pool_2_2_Out_BN")
    pooled2b = builder.create_convolution(convName="Pool_2_2", inPointHierarchy=hierarchy,
        inPointLevel=1, outPointLevel=2, inFeatures=level1b, inNumFeatures=k*8,
        convRadius=0.8, KDEWindow=0.2)

    # Level 2: convolution with skip connection, then pooling into level 3.
    conv3bInput = reduce_stage(pooled2b, k*8, k*8, "Reduce_Conv_3_2_In_BN", "Reduce_Conv_3_2", "Reduce_Conv_3_2_Out_BN")
    level2b = builder.create_convolution(convName="Conv_3_2", inPointHierarchy=hierarchy,
        inPointLevel=2, inFeatures=conv3bInput, inNumFeatures=k*8, convRadius=1.2)
    level2b = tf.concat([pooled2b, level2b], 1)
    level2b = reduce_stage(level2b, k*16, k*32, "Reduce_Pool_3_2_In_BN", "Reduce_Pool_3_2", "Reduce_Pool_3_2_Out_BN")
    pooled3b = builder.create_convolution(convName="Pool_3_2", inPointHierarchy=hierarchy,
        inPointLevel=2, outPointLevel=3, inFeatures=level2b, inNumFeatures=k*32,
        convRadius=finalRadius, KDEWindow=0.2)

    logitsPath2 = MLP_2_hidden(bn_relu_drop("2BNRELUDROP_final", pooled3b), k*32, k*16, k*8, numOutCat,
        "2Final_Logits", keepProbFull, isTraining, useDropOutFull)

    # ------------------------------------------------------------ path dropout
    numActivePaths = tf.constant(0.0, dtype=tf.float32)
    probability = tf.random_uniform([1])

    keep1 = path_mask(tf.less_equal(probability[0], tf.constant(0.66)))
    numActivePaths = tf.add(numActivePaths, keep1)
    logitsPath1 = tf.scalar_mul(keep1, logitsPath1)

    keep2 = path_mask(tf.greater_equal(probability[0], tf.constant(0.33)))
    numActivePaths = tf.add(numActivePaths, keep2)
    logitsPath2 = tf.scalar_mul(keep2, logitsPath2)

    # Rescale so a single kept path weighs as much as the sum of two.
    scale = tf.multiply(tf.constant((2.0), dtype=tf.float32), tf.reciprocal(numActivePaths))
    return tf.scalar_mul(scale, tf.add(logitsPath1, logitsPath2))

# --------------------------------------------------------------------------------
# /models/MCClassS.py:
# --------------------------------------------------------------------------------
# \file MCClassS.py
#
# \brief Definition of the network architecture MCClassS for classification
#        tasks.
#
# \copyright Copyright (c) 2018 Visual Computing group of Ulm University,
#            Germany. See the LICENSE file at the top-level directory of
#            this distribution.
# NOTE(review): this span is a fused repository dump; it holds the rest of
# models/MCClassS.py (header tail, imports, create_network) followed by the
# first header lines of models/MCNorm.py.
#
# MCClassS.create_network -- classification network, as built below:
#   * PointHierarchy over `points` with radii [0.1, 0.4, sqrt(3)+0.1]; the
#     convolutions below use levels 0 through 3. ConvolutionBuilder(KDEWindow=0.2).
#   * Conv_1: level 0 -> 1, numInputFeatures -> k features, radius 0.2,
#     multiFeatureConv=True.
#   * batch_norm_RELU_drop_out, conv_1x1 (k -> 2k), batch_norm_RELU_drop_out,
#     then Conv_2: level 1 -> 2 on 2k input features, radius 0.8.
#   * batch_norm_RELU_drop_out, conv_1x1 (2k -> 4k), batch_norm_RELU_drop_out,
#     then Conv_3: level 2 -> 3 on 4k input features, radius sqrt(3)+0.1.
#   * batch_norm_RELU_drop_out, then MLP_2_hidden(4k, 2k, k, numOutCat)
#     named "Final_Logits"; its result is returned.
#   * useConvDropOut/keepProbConv are forwarded to every batch_norm_RELU_drop_out
#     call; keepProbFull/useDropOutFull are forwarded only to MLP_2_hidden.
10 | 11 | \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 12 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 13 | 14 | import sys 15 | import os 16 | import math 17 | import tensorflow as tf 18 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 19 | ROOT_DIR = os.path.dirname(BASE_DIR) 20 | sys.path.append(os.path.join(ROOT_DIR, 'tf_ops')) 21 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 22 | from MCConvBuilder import PointHierarchy, ConvolutionBuilder 23 | from MCNetworkUtils import MLP_2_hidden, batch_norm_RELU_drop_out, conv_1x1 24 | 25 | def create_network(points, batchIds, features, numInputFeatures, batchSize, k, numOutCat, isTraining, 26 | keepProbConv, keepProbFull, useConvDropOut = False, useDropOutFull = True): 27 | 28 | ############################################ Compute point hierarchy 29 | mPointHierarchy = PointHierarchy(points, features, batchIds, [0.1, 0.4, math.sqrt(3.0)+0.1], "MCClassS_PH", batchSize) 30 | 31 | 32 | ############################################ Convolutions 33 | mConvBuilder = ConvolutionBuilder(KDEWindow=0.2) 34 | 35 | #### Convolution 1 36 | convFeatures1 = mConvBuilder.create_convolution( 37 | convName = "Conv_1", 38 | inPointHierarchy = mPointHierarchy, 39 | inPointLevel=0, 40 | outPointLevel=1, 41 | inFeatures=features, 42 | inNumFeatures=numInputFeatures, 43 | outNumFeatures=k, 44 | convRadius= 0.2, 45 | multiFeatureConv=True) 46 | 47 | #### Convolution 2 48 | convFeatures1 = batch_norm_RELU_drop_out("Reduce_1_In_BN", convFeatures1, isTraining, useConvDropOut, keepProbConv) 49 | convFeatures1 = conv_1x1("Reduce_1", convFeatures1, k, k*2) 50 | convFeatures1 = batch_norm_RELU_drop_out("Reduce_1_Out_BN", convFeatures1, isTraining, useConvDropOut, keepProbConv) 51 | convFeatures2 = mConvBuilder.create_convolution( 52 | convName="Conv_2", 53 | inPointHierarchy=mPointHierarchy, 54 | inPointLevel=1, 55 | outPointLevel=2, 56 | inFeatures=convFeatures1, 57 | 
inNumFeatures=k*2, 58 | convRadius=0.8) 59 | 60 | #### Convolution 3 61 | convFeatures2 = batch_norm_RELU_drop_out("Reduce_2_In_BN", convFeatures2, isTraining, useConvDropOut, keepProbConv) 62 | convFeatures2 = conv_1x1("Reduce_2", convFeatures2, k*2, k*4) 63 | convFeatures2 = batch_norm_RELU_drop_out("Reduce_2_Out_BN", convFeatures2, isTraining, useConvDropOut, keepProbConv) 64 | convFeatures3 = mConvBuilder.create_convolution( 65 | convName="Conv_3", 66 | inPointHierarchy=mPointHierarchy, 67 | inPointLevel=2, 68 | outPointLevel=3, 69 | inFeatures=convFeatures2, 70 | inNumFeatures=k*4, 71 | convRadius=math.sqrt(3.0)+0.1) 72 | 73 | #Fully connected MLP - Global features. 74 | finalInput = batch_norm_RELU_drop_out("BNRELUDROP_final", convFeatures3, isTraining, useConvDropOut, keepProbConv) 75 | finalLogits = MLP_2_hidden(finalInput, k*4, k*2, k, numOutCat, "Final_Logits", keepProbFull, isTraining, useDropOutFull) 76 | 77 | return finalLogits 78 | -------------------------------------------------------------------------------- /models/MCNorm.py: -------------------------------------------------------------------------------- 1 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 2 | \file MCNorm.py 3 | 4 | \brief Definition of the network architecture MCNorm for 5 | normal estimation tasks. 6 | 7 | \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 8 | Germany. See the LICENSE file at the top-level directory of 9 | this distribution.
# NOTE(review): this span is a fused repository dump; it holds the rest of
# models/MCNorm.py (header tail, imports, create_network) followed by the
# first header lines of models/MCNormS.py.
#
# MCNorm.create_network -- encoder/decoder that returns `normals`:
#   * PointHierarchy radii [0.1, 0.4] (levels 0..2); ConvolutionBuilder(KDEWindow=0.25),
#     with Pool_1/Pool_2 overriding KDEWindow=0.2.
#   * Encoder: Conv_1 at level 0 (numInputFeatures -> k, radius 0.1);
#     Pool_1 level 0 -> 1 on 2k features; Conv_2 at level 1, concatenated with
#     its pooled input -> 4k; Pool_2 level 1 -> 2 on 4k; Conv_3 at level 2,
#     concatenated with its pooled input -> 8k.
#     NOTE(review): Conv_3 uses radius math.sqrt(3) while the other models in
#     this dump use sqrt(3.0)+0.1 for their coarsest level -- confirm intended.
#   * Decoder: Up_2_3 (level 2 -> 1) concatenated with convFeatures2 -> 12k,
#     reduced to 4k by conv_1x1 and refined by DeConv_2; Up_1_2 (level 1 -> 0)
#     concatenated with convFeatures1 -> 5k, reduced to 2k.
#   * DeConv_1 at level 0 maps 2k -> 3 features (radius 0.1, multiFeatureConv)
#     and its output is returned.
#   * Every batch_norm_RELU_drop_out call passes False for both the dropout flag
#     and the keep-probability argument, so no dropout is configured here.
#   * multiConv and useMC are accepted but never read in this body.
10 | 11 | \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 12 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 13 | 14 | import math 15 | import sys 16 | import os 17 | import tensorflow as tf 18 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 19 | ROOT_DIR = os.path.dirname(BASE_DIR) 20 | sys.path.append(os.path.join(ROOT_DIR, 'tf_ops')) 21 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 22 | from MCConvBuilder import PointHierarchy, ConvolutionBuilder 23 | from MCNetworkUtils import batch_norm_RELU_drop_out, conv_1x1 24 | 25 | def create_network(points, batchIds, features, numInputFeatures, batchSize, k, isTraining, multiConv = True, useMC = True): 26 | 27 | ############################################ Compute point hierarchy 28 | mPointHierarchy = PointHierarchy(points, features, batchIds, [0.1, 0.4], "MCNorm_PH", batchSize) 29 | 30 | ############################################ Convolutions 31 | mConvBuilder = ConvolutionBuilder(KDEWindow=0.25) 32 | 33 | ############################################ Encoder 34 | 35 | # First Convolution 36 | convFeatures1 = mConvBuilder.create_convolution( 37 | convName="Conv_1", 38 | inPointHierarchy=mPointHierarchy, 39 | inPointLevel=0, 40 | inFeatures=features, 41 | inNumFeatures=numInputFeatures, 42 | outNumFeatures=k, 43 | convRadius=0.1, 44 | multiFeatureConv=True) 45 | 46 | # First Pooling 47 | bnConvFeatures1 = batch_norm_RELU_drop_out("Reduce_Pool_1_In_BN", convFeatures1, isTraining, False, False) 48 | bnConvFeatures1 = conv_1x1("Reduce_Pool_1", bnConvFeatures1, k, k*2) 49 | bnConvFeatures1 = batch_norm_RELU_drop_out("Reduce_Pool_1_Out_BN", bnConvFeatures1, isTraining, False, False) 50 | poolFeatures1 = mConvBuilder.create_convolution( 51 | convName="Pool_1", 52 | inPointHierarchy=mPointHierarchy, 53 | inPointLevel=0, 54 | outPointLevel=1, 55 | inFeatures=bnConvFeatures1, 56 | inNumFeatures=k*2, 57 | convRadius=0.2, 58 | KDEWindow= 0.2) 59 | 60 | # Second 
Convolution 61 | bnPoolFeatures1 = batch_norm_RELU_drop_out("Conv_2_In_BN", poolFeatures1, isTraining, False, False) 62 | convFeatures2 = mConvBuilder.create_convolution( 63 | convName="Conv_2", 64 | inPointHierarchy=mPointHierarchy, 65 | inPointLevel=1, 66 | inFeatures=bnPoolFeatures1, 67 | inNumFeatures=k*2, 68 | convRadius=0.4) 69 | convFeatures2 = tf.concat([poolFeatures1, convFeatures2], 1) 70 | 71 | # Second Pooling 72 | bnConvFeatures2 = batch_norm_RELU_drop_out("Reduce_Pool_2_In_BN", convFeatures2, isTraining, False, False) 73 | bnConvFeatures2 = conv_1x1("Reduce_Pool_2", bnConvFeatures2, k*4, k*4) 74 | bnConvFeatures2 = batch_norm_RELU_drop_out("Reduce_Pool_2_Out_BN", bnConvFeatures2, isTraining, False, False) 75 | poolFeatures2 = mConvBuilder.create_convolution( 76 | convName="Pool_2", 77 | inPointHierarchy=mPointHierarchy, 78 | inPointLevel=1, 79 | outPointLevel=2, 80 | inFeatures=bnConvFeatures2, 81 | inNumFeatures=k*4, 82 | convRadius=0.8, 83 | KDEWindow= 0.2) 84 | 85 | # Third Convolution 86 | bnPoolFeatures2 = batch_norm_RELU_drop_out("Conv_3_In_BN", poolFeatures2, isTraining, False, False) 87 | convFeatures3 = mConvBuilder.create_convolution( 88 | convName="Conv_3", 89 | inPointHierarchy=mPointHierarchy, 90 | inPointLevel=2, 91 | inFeatures=bnPoolFeatures2, 92 | inNumFeatures=k*4, 93 | convRadius=math.sqrt(3)) 94 | convFeatures3 = tf.concat([poolFeatures2, convFeatures3], 1) 95 | 96 | 97 | ##################################################### Multi-hierarchy sampling 98 | 99 | 100 | # Second upsampling 101 | bnFeatures3 = batch_norm_RELU_drop_out("Up_2_3_BN", convFeatures3, isTraining, False, False) 102 | upFeatures2_3 = mConvBuilder.create_convolution( 103 | convName="Up_2_3", 104 | inPointHierarchy=mPointHierarchy, 105 | inPointLevel=2, 106 | outPointLevel=1, 107 | inFeatures=bnFeatures3, 108 | inNumFeatures=k*8, 109 | convRadius=0.8) 110 | deConvFeatures2 = tf.concat([upFeatures2_3, convFeatures2], 1) 111 | deConvFeatures2 = 
batch_norm_RELU_drop_out("DeConv_2_Reduce_In_BN", deConvFeatures2, isTraining, False, False) 112 | deConvFeatures2 = conv_1x1("DeConv_2_Reduce", deConvFeatures2, k*12, k*4) 113 | deConvFeatures2 = batch_norm_RELU_drop_out("DeConv_2_Reduce_Out_BN", deConvFeatures2, isTraining, False, False) 114 | deConvFeatures2 = mConvBuilder.create_convolution( 115 | convName="DeConv_2", 116 | inPointHierarchy=mPointHierarchy, 117 | inPointLevel=1, 118 | inFeatures=deConvFeatures2, 119 | inNumFeatures=k*4, 120 | convRadius=0.4) 121 | 122 | # First upsampling 123 | bnDeConvFeatures2 = batch_norm_RELU_drop_out("Up_1_2_BN", deConvFeatures2, isTraining, False, False) 124 | upFeatures1_2 = mConvBuilder.create_convolution( 125 | convName="Up_1_2", 126 | inPointHierarchy=mPointHierarchy, 127 | inPointLevel=1, 128 | outPointLevel=0, 129 | inFeatures=bnDeConvFeatures2, 130 | inNumFeatures=k*4, 131 | convRadius=0.2) 132 | deConvFeatures1 = tf.concat([upFeatures1_2, convFeatures1], 1) 133 | deConvFeatures1 = batch_norm_RELU_drop_out("DeConv_1_Reduce_In_BN", deConvFeatures1, isTraining, False, False) 134 | deConvFeatures1 = conv_1x1("DeConv_1_Reduce", deConvFeatures1, k*5, k*2) 135 | deConvFeatures1 = batch_norm_RELU_drop_out("DeConv_1_Reduce_Out_BN", deConvFeatures1, isTraining, False, False) 136 | normals = mConvBuilder.create_convolution( 137 | convName="DeConv_1", 138 | inPointHierarchy=mPointHierarchy, 139 | inPointLevel=0, 140 | inFeatures=deConvFeatures1, 141 | inNumFeatures=k*2, 142 | outNumFeatures=3, 143 | convRadius=0.1, 144 | multiFeatureConv=True) 145 | 146 | return normals 147 | -------------------------------------------------------------------------------- /models/MCNormS.py: -------------------------------------------------------------------------------- 1 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 2 | \file MCNormS.py 3 | 4 | \brief Definition of the network architecture MCNormS for normal 5 | estimation tasks.
# NOTE(review): this span is a fused repository dump; it holds the rest of
# models/MCNormS.py (header tail, imports, create_network) and the separator
# that precedes models/MCSeg.py.
#
# MCNormS.create_network -- shallow, single-level normal estimator:
#   * PointHierarchy is built with an empty radius list; both convolutions read
#     and write level 0 only. ConvolutionBuilder(KDEWindow=0.2).
#   * Conv_1: numInputFeatures -> k features, radius 0.15, multiFeatureConv=True.
#   * batch_norm_RELU_drop_out("BN_RELU", ...) with False for both the dropout
#     flag and the keep-probability argument.
#   * Conv_2: k -> 3 features, radius 0.15, multiFeatureConv=True; its output is
#     returned as `normals`.
#   * multiConv and useMC are accepted but never read in this body.
6 | 7 | \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 8 | Germany. See the LICENSE file at the top-level directory of 9 | this distribution. 10 | 11 | \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 12 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 13 | 14 | import math 15 | import sys 16 | import os 17 | import tensorflow as tf 18 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 19 | ROOT_DIR = os.path.dirname(BASE_DIR) 20 | sys.path.append(os.path.join(ROOT_DIR, 'tf_ops')) 21 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 22 | from MCConvBuilder import PointHierarchy, ConvolutionBuilder 23 | from MCNetworkUtils import batch_norm_RELU_drop_out 24 | 25 | def create_network(points, batchIds, features, numInputFeatures, batchSize, k, isTraining, multiConv = True, useMC = True): 26 | 27 | ############################################ Compute point hierarchy 28 | mPointHierarchy = PointHierarchy(points, features, batchIds, [], "MCNormS_PH", batchSize) 29 | 30 | ############################################ Convolutions 31 | mConvBuilder = ConvolutionBuilder(KDEWindow=0.2) 32 | 33 | # Convolution 1 34 | convFeatures1 = mConvBuilder.create_convolution( 35 | convName="Conv_1", 36 | inPointHierarchy=mPointHierarchy, 37 | inPointLevel=0, 38 | inFeatures=features, 39 | inNumFeatures=numInputFeatures, 40 | outNumFeatures=k, 41 | convRadius=0.15, 42 | multiFeatureConv=True) 43 | 44 | #BatchNorm and RELU 45 | convFeatures1 = batch_norm_RELU_drop_out("BN_RELU", convFeatures1, isTraining, False, False) 46 | 47 | # Convolution 2 48 | normals = mConvBuilder.create_convolution( 49 | convName="Conv_2", 50 | inPointHierarchy=mPointHierarchy, 51 | inPointLevel=0, 52 | inFeatures=convFeatures1, 53 | inNumFeatures=k, 54 | outNumFeatures=3, 55 | convRadius=0.15, 56 | multiFeatureConv=True) 57 | 58 | return normals 59 | -------------------------------------------------------------------------------- 
# NOTE(review): this span is a fused repository dump; it holds all of
# models/MCSeg.py followed by the first header lines of models/MCSegScanNet.py.
# That second header's "\file" line reads "MCSeg.py"; it presumably should read
# "MCSegScanNet.py" -- confirm and fix the header text.
#
# MCSeg.create_network -- part-segmentation network:
#   * PointHierarchy radii [0.025, 0.1, 0.4] (levels 0..3);
#     ConvolutionBuilder(KDEWindow=0.25), with Pool_1..Pool_3 overriding KDEWindow=0.2.
#   * Encoder feature widths: level 0 k (Conv_1), level 1 4k, level 2 8k,
#     level 3 16k (Conv_2..Conv_4 outputs are concatenated with their pooled inputs).
#   * Decoder: Up_3_4 + convFeatures3 -> 24k -> 8k -> DeConv_3 (level 2);
#     Up_2_3 + convFeatures2 -> 12k -> 4k -> DeConv_2 (level 1);
#     Up_1_2 + Up_1_3 + convFeatures1 -> 13k -> 4k -> DeConv_1 (level 0).
#   * catLabels are one-hot encoded to numCats columns, reshaped to [-1, numCats]
#     and placed before the per-point features via tf.concat(..., 1).
#     NOTE(review): this assumes catLabels yields one row per level-0 point.
#   * MLP_2_hidden(4k + numCats, 4k, 2k, numParts, useInitBN=False) named
#     "Final_Logits" produces the returned logits.
/models/MCSeg.py: -------------------------------------------------------------------------------- 1 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 2 | \file MCSeg.py 3 | 4 | \brief Definition of the network architecture MCSeg for 5 | segmentation tasks. 6 | 7 | \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 8 | Germany. See the LICENSE file at the top-level directory of 9 | this distribution. 10 | 11 | \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 12 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 13 | 14 | import math 15 | import sys 16 | import os 17 | import tensorflow as tf 18 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 19 | ROOT_DIR = os.path.dirname(BASE_DIR) 20 | sys.path.append(os.path.join(ROOT_DIR, 'tf_ops')) 21 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 22 | from MCConvBuilder import PointHierarchy, ConvolutionBuilder 23 | from MCNetworkUtils import MLP_2_hidden, batch_norm_RELU_drop_out, conv_1x1 24 | 25 | def create_network(points, batchIds, features, catLabels, numInputFeatures, numCats, numParts, batchSize, k, isTraining, 26 | keepProbConv, keepProbFull, useConvDropOut = False, useDropOutFull = True): 27 | 28 | ############################################ Compute point hierarchy 29 | mPointHierarchy = PointHierarchy(points, features, batchIds, [0.025, 0.1, 0.4], "MCSeg_PH", batchSize) 30 | 31 | ############################################ Convolutions 32 | mConvBuilder = ConvolutionBuilder(KDEWindow=0.25) 33 | 34 | ############################################ Encoder 35 | 36 | # First Convolution 37 | convFeatures1 = mConvBuilder.create_convolution( 38 | convName="Conv_1", 39 | inPointHierarchy=mPointHierarchy, 40 | inPointLevel=0, 41 | inFeatures=features, 42 | inNumFeatures=numInputFeatures, 43 | outNumFeatures=k, 44 | convRadius=0.03, 45 | multiFeatureConv=True) 46 | 47 | # First Pooling 48 | bnConvFeatures1 = 
batch_norm_RELU_drop_out("Reduce_Pool_1_In_BN", convFeatures1, isTraining, useConvDropOut, keepProbConv) 49 | bnConvFeatures1 = conv_1x1("Reduce_Pool_1", bnConvFeatures1, k, k*2) 50 | bnConvFeatures1 = batch_norm_RELU_drop_out("Reduce_Pool_1_Out_BN", bnConvFeatures1, isTraining, useConvDropOut, keepProbConv) 51 | poolFeatures1 = mConvBuilder.create_convolution( 52 | convName="Pool_1", 53 | inPointHierarchy=mPointHierarchy, 54 | inPointLevel=0, 55 | outPointLevel=1, 56 | inFeatures=bnConvFeatures1, 57 | inNumFeatures=k*2, 58 | convRadius=0.05, 59 | KDEWindow= 0.2) 60 | 61 | # Second Convolution 62 | bnPoolFeatures1 = batch_norm_RELU_drop_out("Conv_2_In_BN", poolFeatures1, isTraining, useConvDropOut, keepProbConv) 63 | convFeatures2 = mConvBuilder.create_convolution( 64 | convName="Conv_2", 65 | inPointHierarchy=mPointHierarchy, 66 | inPointLevel=1, 67 | inFeatures=bnPoolFeatures1, 68 | inNumFeatures=k*2, 69 | convRadius=0.1) 70 | convFeatures2 = tf.concat([poolFeatures1, convFeatures2], 1) 71 | 72 | # Second Pooling 73 | bnConvFeatures2 = batch_norm_RELU_drop_out("Reduce_Pool_2_In_BN", convFeatures2, isTraining, useConvDropOut, keepProbConv) 74 | bnConvFeatures2 = conv_1x1("Reduce_Pool_2", bnConvFeatures2, k*4, k*4) 75 | bnConvFeatures2 = batch_norm_RELU_drop_out("Reduce_Pool_2_Out_BN", bnConvFeatures2, isTraining, useConvDropOut, keepProbConv) 76 | poolFeatures2 = mConvBuilder.create_convolution( 77 | convName="Pool_2", 78 | inPointHierarchy=mPointHierarchy, 79 | inPointLevel=1, 80 | outPointLevel=2, 81 | inFeatures=bnConvFeatures2, 82 | inNumFeatures=k*4, 83 | convRadius=0.2, 84 | KDEWindow= 0.2) 85 | 86 | # Third Convolution 87 | bnPoolFeatures2 = batch_norm_RELU_drop_out("Conv_3_In_BN", poolFeatures2, isTraining, useConvDropOut, keepProbConv) 88 | convFeatures3 = mConvBuilder.create_convolution( 89 | convName="Conv_3", 90 | inPointHierarchy=mPointHierarchy, 91 | inPointLevel=2, 92 | inFeatures=bnPoolFeatures2, 93 | inNumFeatures=k*4, 94 | convRadius=0.4) 95 | 
convFeatures3 = tf.concat([poolFeatures2, convFeatures3], 1) 96 | 97 | # Third Pooling 98 | bnConvFeatures3 = batch_norm_RELU_drop_out("Reduce_Pool_3_In_BN", convFeatures3, isTraining, useConvDropOut, keepProbConv) 99 | bnConvFeatures3 = conv_1x1("Reduce_Pool_3", bnConvFeatures3, k*8, k*8) 100 | bnConvFeatures3 = batch_norm_RELU_drop_out("Reduce_Pool_3_Out_BN", bnConvFeatures3, isTraining, useConvDropOut, keepProbConv) 101 | poolFeatures3 = mConvBuilder.create_convolution( 102 | convName="Pool_3", 103 | inPointHierarchy=mPointHierarchy, 104 | inPointLevel=2, 105 | outPointLevel=3, 106 | inFeatures=bnConvFeatures3, 107 | inNumFeatures=k*8, 108 | convRadius=0.8, 109 | KDEWindow= 0.2) 110 | 111 | # Fourth Convolution 112 | bnPoolFeatures3 = batch_norm_RELU_drop_out("Conv_4_In_BN", poolFeatures3, isTraining, useConvDropOut, keepProbConv) 113 | convFeatures4 = mConvBuilder.create_convolution( 114 | convName="Conv_4", 115 | inPointHierarchy=mPointHierarchy, 116 | inPointLevel=3, 117 | inFeatures=bnPoolFeatures3, 118 | inNumFeatures=k*8, 119 | convRadius=math.sqrt(3.0)+0.1) 120 | convFeatures4 = tf.concat([poolFeatures3, convFeatures4], 1) 121 | 122 | 123 | ############################################ Decoder 124 | 125 | # Third upsampling 126 | bnConvFeatures4 = batch_norm_RELU_drop_out("Up_3_4_BN", convFeatures4, isTraining, useConvDropOut, keepProbConv) 127 | upFeatures3_4 = mConvBuilder.create_convolution( 128 | convName="Up_3_4", 129 | inPointHierarchy=mPointHierarchy, 130 | inPointLevel=3, 131 | outPointLevel=2, 132 | inFeatures=bnConvFeatures4, 133 | inNumFeatures=k*16, 134 | convRadius=math.sqrt(3.0)+0.1) 135 | deConvFeatures3 = tf.concat([upFeatures3_4, convFeatures3], 1) 136 | deConvFeatures3 = batch_norm_RELU_drop_out("DeConv_3_Reduce_In_BN", deConvFeatures3, isTraining, useConvDropOut, keepProbConv) 137 | deConvFeatures3 = conv_1x1("DeConv_3_Reduce", deConvFeatures3, k*24, k*8) 138 | deConvFeatures3 = batch_norm_RELU_drop_out("DeConv_3_Reduce_Out_BN", 
deConvFeatures3, isTraining, useConvDropOut, keepProbConv) 139 | deConvFeatures3 = mConvBuilder.create_convolution( 140 | convName="DeConv_3", 141 | inPointHierarchy=mPointHierarchy, 142 | inPointLevel=2, 143 | inFeatures=deConvFeatures3, 144 | inNumFeatures=k*8, 145 | convRadius=0.4) 146 | 147 | # Second upsampling 148 | bnDeConvFeatures3 = batch_norm_RELU_drop_out("Up_2_3_BN", deConvFeatures3, isTraining, useConvDropOut, keepProbConv) 149 | upFeatures2_3 = mConvBuilder.create_convolution( 150 | convName="Up_2_3", 151 | inPointHierarchy=mPointHierarchy, 152 | inPointLevel=2, 153 | outPointLevel=1, 154 | inFeatures=bnDeConvFeatures3, 155 | inNumFeatures=k*8, 156 | convRadius=0.2) 157 | deConvFeatures2 = tf.concat([upFeatures2_3, convFeatures2], 1) 158 | deConvFeatures2 = batch_norm_RELU_drop_out("DeConv_2_Reduce_In_BN", deConvFeatures2, isTraining, useConvDropOut, keepProbConv) 159 | deConvFeatures2 = conv_1x1("DeConv_2_Reduce", deConvFeatures2, k*12, k*4) 160 | deConvFeatures2 = batch_norm_RELU_drop_out("DeConv_2_Reduce_Out_BN", deConvFeatures2, isTraining, useConvDropOut, keepProbConv) 161 | deConvFeatures2 = mConvBuilder.create_convolution( 162 | convName="DeConv_2", 163 | inPointHierarchy=mPointHierarchy, 164 | inPointLevel=1, 165 | inFeatures=deConvFeatures2, 166 | inNumFeatures=k*4, 167 | convRadius=0.1) 168 | 169 | # First multiple upsamplings 170 | bnDeConvFeatures2 = batch_norm_RELU_drop_out("Up_1_2_BN", deConvFeatures2, isTraining, useConvDropOut, keepProbConv) 171 | upFeatures1_2 = mConvBuilder.create_convolution( 172 | convName="Up_1_2", 173 | inPointHierarchy=mPointHierarchy, 174 | inPointLevel=1, 175 | outPointLevel=0, 176 | inFeatures=bnDeConvFeatures2, 177 | inNumFeatures=k*4, 178 | convRadius=0.05) 179 | bnDeConvFeatures3 = batch_norm_RELU_drop_out("Up_1_3_BN", deConvFeatures3, isTraining, useConvDropOut, keepProbConv) 180 | upFeatures1_3 = mConvBuilder.create_convolution( 181 | convName="Up_1_3", 182 | inPointHierarchy=mPointHierarchy, 183 | 
inPointLevel=2, 184 | outPointLevel=0, 185 | inFeatures=bnDeConvFeatures3, 186 | inNumFeatures=k*8, 187 | convRadius=0.2) 188 | deConvFeatures1 = tf.concat([upFeatures1_2, upFeatures1_3, convFeatures1], 1) 189 | deConvFeatures1 = batch_norm_RELU_drop_out("DeConv_1_Reduce_In_BN", deConvFeatures1, isTraining, useConvDropOut, keepProbConv) 190 | deConvFeatures1 = conv_1x1("DeConv_1_Reduce", deConvFeatures1, k*13, k*4) 191 | deConvFeatures1 = batch_norm_RELU_drop_out("DeConv_1_Reduce_Out_BN", deConvFeatures1, isTraining, useConvDropOut, keepProbConv) 192 | deConvFeatures1 = mConvBuilder.create_convolution( 193 | convName="DeConv_1", 194 | inPointHierarchy=mPointHierarchy, 195 | inPointLevel=0, 196 | inFeatures=deConvFeatures1, 197 | inNumFeatures=k*4, 198 | convRadius=0.03) 199 | 200 | 201 | # Fully connected MLP - Global features. 202 | finalInput = batch_norm_RELU_drop_out("BNRELUDROP_hier_final", deConvFeatures1, isTraining, useConvDropOut, keepProbConv) 203 | #Convert cat labels 204 | catLabelOneHot = tf.one_hot(catLabels, numCats, on_value=1.0, off_value=0.0) 205 | catLabelOneHot = tf.reshape(catLabelOneHot, [-1, numCats]) 206 | finalInput = tf.concat([catLabelOneHot, finalInput], 1) 207 | finalLogits = MLP_2_hidden(finalInput, k*4 + numCats, k*4, k*2, numParts, "Final_Logits", keepProbFull, isTraining, useDropOutFull, useInitBN = False) 208 | 209 | return finalLogits 210 | -------------------------------------------------------------------------------- /models/MCSegScanNet.py: -------------------------------------------------------------------------------- 1 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 2 | \file MCSeg.py 3 | 4 | \brief Definition of the network architecture MCSegScanNet for 5 | segmentation tasks on the ScanNet dataset. 6 | 7 | \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 8 | Germany. See the LICENSE file at the top-level directory of 9 | this distribution.
10 | 11 | \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 12 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 13 | 14 | 15 | import math 16 | import sys 17 | import os 18 | import tensorflow as tf 19 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 20 | ROOT_DIR = os.path.dirname(BASE_DIR) 21 | sys.path.append(os.path.join(ROOT_DIR, 'tf_ops')) 22 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 23 | from MCConvBuilder import PointHierarchy, ConvolutionBuilder 24 | from MCNetworkUtils import MLP_2_hidden, batch_norm_RELU_drop_out, conv_1x1 25 | 26 | def create_network(points, batchIds, features, numInputFeatures, numSem, batchSize, k, isTraining, 27 | keepProbConv, keepProbFull, useConvDropOut = False, useDropOutFull = True): 28 | 29 | ############################################ Compute point hierarchy 30 | mPointHierarchy = PointHierarchy(points, features, batchIds, [0.1, 0.2, 0.4, 0.8], "MCSegScanNet_PH", batchSize, False) 31 | 32 | ############################################ Convolutions 33 | mConvBuilder = ConvolutionBuilder(KDEWindow=0.25, relativeRadius=False) 34 | 35 | ############################################ Encoder 36 | 37 | # Init pooling 38 | poolFeatures0 = mConvBuilder.create_convolution( 39 | convName="Pool_0", 40 | inPointHierarchy=mPointHierarchy, 41 | inPointLevel=0, 42 | outPointLevel=1, 43 | inFeatures=features, 44 | inNumFeatures=numInputFeatures, 45 | outNumFeatures=k, 46 | convRadius=0.1, 47 | KDEWindow= 0.2, 48 | multiFeatureConv=True) 49 | 50 | # First Convolution 51 | bnPoolFeatures0 = batch_norm_RELU_drop_out("Conv_1_In_BN", poolFeatures0, isTraining, useConvDropOut, keepProbConv) 52 | convFeatures1 = mConvBuilder.create_convolution( 53 | convName="Conv_1", 54 | inPointHierarchy=mPointHierarchy, 55 | inPointLevel=1, 56 | inFeatures=bnPoolFeatures0, 57 | inNumFeatures=k, 58 | convRadius=0.4) 59 | convFeatures1 = tf.concat([poolFeatures0, convFeatures1], 1) 60 | 61 | # First 
Pooling 62 | bnConvFeatures1 = batch_norm_RELU_drop_out("Reduce_Pool_1_In_BN", convFeatures1, isTraining, useConvDropOut, keepProbConv) 63 | bnConvFeatures1 = conv_1x1("Reduce_Pool_1", bnConvFeatures1, k*2, k*2) 64 | bnConvFeatures1 = batch_norm_RELU_drop_out("Reduce_Pool_1_Out_BN", bnConvFeatures1, isTraining, useConvDropOut, keepProbConv) 65 | poolFeatures1 = mConvBuilder.create_convolution( 66 | convName="Pool_1", 67 | inPointHierarchy=mPointHierarchy, 68 | inPointLevel=1, 69 | outPointLevel=2, 70 | inFeatures=bnConvFeatures1, 71 | inNumFeatures=k*2, 72 | convRadius=0.4, 73 | KDEWindow= 0.2) 74 | 75 | # Second Convolution 76 | bnPoolFeatures1 = batch_norm_RELU_drop_out("Conv_2_In_BN", poolFeatures1, isTraining, useConvDropOut, keepProbConv) 77 | convFeatures2 = mConvBuilder.create_convolution( 78 | convName="Conv_2", 79 | inPointHierarchy=mPointHierarchy, 80 | inPointLevel=2, 81 | inFeatures=bnPoolFeatures1, 82 | inNumFeatures=k*2, 83 | convRadius=0.8) 84 | convFeatures2 = tf.concat([poolFeatures1, convFeatures2], 1) 85 | 86 | # Second Pooling 87 | bnConvFeatures2 = batch_norm_RELU_drop_out("Reduce_Pool_2_In_BN", convFeatures2, isTraining, useConvDropOut, keepProbConv) 88 | bnConvFeatures2 = conv_1x1("Reduce_Pool_2", bnConvFeatures2, k*4, k*4) 89 | bnConvFeatures2 = batch_norm_RELU_drop_out("Reduce_Pool_2_Out_BN", bnConvFeatures2, isTraining, useConvDropOut, keepProbConv) 90 | poolFeatures2 = mConvBuilder.create_convolution( 91 | convName="Pool_2", 92 | inPointHierarchy=mPointHierarchy, 93 | inPointLevel=2, 94 | outPointLevel=3, 95 | inFeatures=bnConvFeatures2, 96 | inNumFeatures=k*4, 97 | convRadius=0.8, 98 | KDEWindow= 0.2) 99 | 100 | # Third Convolution 101 | bnPoolFeatures2 = batch_norm_RELU_drop_out("Conv_3_In_BN", poolFeatures2, isTraining, useConvDropOut, keepProbConv) 102 | convFeatures3 = mConvBuilder.create_convolution( 103 | convName="Conv_3", 104 | inPointHierarchy=mPointHierarchy, 105 | inPointLevel=3, 106 | inFeatures=bnPoolFeatures2, 107 | 
inNumFeatures=k*4, 108 | convRadius=1.6) 109 | convFeatures3 = tf.concat([poolFeatures2, convFeatures3], 1) 110 | 111 | # Third Pooling 112 | bnConvFeatures3 = batch_norm_RELU_drop_out("Reduce_Pool_3_In_BN", convFeatures3, isTraining, useConvDropOut, keepProbConv) 113 | bnConvFeatures3 = conv_1x1("Reduce_Pool_3", bnConvFeatures3, k*8, k*8) 114 | bnConvFeatures3 = batch_norm_RELU_drop_out("Reduce_Pool_3_Out_BN", bnConvFeatures3, isTraining, useConvDropOut, keepProbConv) 115 | poolFeatures3 = mConvBuilder.create_convolution( 116 | convName="Pool_3", 117 | inPointHierarchy=mPointHierarchy, 118 | inPointLevel=3, 119 | outPointLevel=4, 120 | inFeatures=bnConvFeatures3, 121 | inNumFeatures=k*8, 122 | convRadius=1.6, 123 | KDEWindow= 0.2) 124 | 125 | # Fourth Convolution 126 | bnPoolFeatures3 = batch_norm_RELU_drop_out("Conv_4_In_BN", poolFeatures3, isTraining, useConvDropOut, keepProbConv) 127 | convFeatures4 = mConvBuilder.create_convolution( 128 | convName="Conv_4", 129 | inPointHierarchy=mPointHierarchy, 130 | inPointLevel=4, 131 | inFeatures=bnPoolFeatures3, 132 | inNumFeatures=k*8, 133 | convRadius=5.0) 134 | convFeatures4 = tf.concat([poolFeatures3, convFeatures4], 1) 135 | 136 | 137 | ############################################ Decoder 138 | 139 | # Third upsampling 140 | bnConvFeatures4 = batch_norm_RELU_drop_out("Up3_4_Reduce_In_BN", convFeatures4, isTraining, useConvDropOut, keepProbConv) 141 | bnConvFeatures4 = conv_1x1("Up3_4_Reduce", bnConvFeatures4, k*16, k*8) 142 | bnConvFeatures4 = batch_norm_RELU_drop_out("Up3_4_Reduce_Out_BN", bnConvFeatures4, isTraining, useConvDropOut, keepProbConv) 143 | upFeatures3_4 = mConvBuilder.create_convolution( 144 | convName="Up_3_4", 145 | inPointHierarchy=mPointHierarchy, 146 | inPointLevel=4, 147 | outPointLevel=3, 148 | inFeatures=bnConvFeatures4, 149 | inNumFeatures=k*8, 150 | convRadius=1.6) 151 | upFeatures3_4 = tf.concat([upFeatures3_4, convFeatures3], 1) 152 | deConvFeatures3 = 
batch_norm_RELU_drop_out("DeConv_3_Reduce_In_BN", upFeatures3_4, isTraining, useConvDropOut, keepProbConv) 153 | deConvFeatures3 = conv_1x1("DeConv_3_Reduce", deConvFeatures3, k*16, k*8) 154 | deConvFeatures3 = batch_norm_RELU_drop_out("DeConv_3_Reduce_Out_BN", deConvFeatures3, isTraining, useConvDropOut, keepProbConv) 155 | deConvFeatures3 = mConvBuilder.create_convolution( 156 | convName="DeConv_3", 157 | inPointHierarchy=mPointHierarchy, 158 | inPointLevel=3, 159 | inFeatures=deConvFeatures3, 160 | inNumFeatures=k*8, 161 | convRadius=1.6) 162 | 163 | 164 | # Second upsampling 165 | bnDeConvFeatures3 = batch_norm_RELU_drop_out("Up2_3_Reduce_In_BN", deConvFeatures3, isTraining, useConvDropOut, keepProbConv) 166 | bnDeConvFeatures3 = conv_1x1("Up2_3_Reduce", bnDeConvFeatures3, k*8, k*4) 167 | bnDeConvFeatures3 = batch_norm_RELU_drop_out("Up2_3_Reduce_Out_BN", bnDeConvFeatures3, isTraining, useConvDropOut, keepProbConv) 168 | upFeatures2_3 = mConvBuilder.create_convolution( 169 | convName="Up_2_3", 170 | inPointHierarchy=mPointHierarchy, 171 | inPointLevel=3, 172 | outPointLevel=2, 173 | inFeatures=bnDeConvFeatures3, 174 | inNumFeatures=k*4, 175 | convRadius=0.8) 176 | upFeatures2_3 = tf.concat([upFeatures2_3, convFeatures2], 1) 177 | deConvFeatures2 = batch_norm_RELU_drop_out("DeConv_2_Reduce_In_BN", upFeatures2_3, isTraining, useConvDropOut, keepProbConv) 178 | deConvFeatures2 = conv_1x1("DeConv_2_Reduce", deConvFeatures2, k*8, k*4) 179 | deConvFeatures2 = batch_norm_RELU_drop_out("DeConv_2_Reduce_Out_BN", deConvFeatures2, isTraining, useConvDropOut, keepProbConv) 180 | deConvFeatures2 = mConvBuilder.create_convolution( 181 | convName="DeConv_2", 182 | inPointHierarchy=mPointHierarchy, 183 | inPointLevel=2, 184 | inFeatures=deConvFeatures2, 185 | inNumFeatures=k*4, 186 | convRadius=0.8) 187 | 188 | 189 | # First multiple upsamplings 190 | bnDeConvFeatures2 = batch_norm_RELU_drop_out("Up1_2_Reduce_In_BN", deConvFeatures2, isTraining, useConvDropOut, keepProbConv) 
191 | bnDeConvFeatures2 = conv_1x1("Up1_2_Reduce", bnDeConvFeatures2, k*4, k*2) 192 | bnDeConvFeatures2 = batch_norm_RELU_drop_out("Up1_2_Reduce_Out_BN", bnDeConvFeatures2, isTraining, useConvDropOut, keepProbConv) 193 | upFeatures1_2 = mConvBuilder.create_convolution( 194 | convName="Up_1_2", 195 | inPointHierarchy=mPointHierarchy, 196 | inPointLevel=2, 197 | outPointLevel=1, 198 | inFeatures=bnDeConvFeatures2, 199 | inNumFeatures=k*2, 200 | convRadius=0.4) 201 | bnDeConvFeatures3 = batch_norm_RELU_drop_out("Up1_3_Reduce_In_BN", deConvFeatures3, isTraining, useConvDropOut, keepProbConv) 202 | bnDeConvFeatures3 = conv_1x1("Up1_3_Reduce", bnDeConvFeatures3, k*8, k*2) 203 | bnDeConvFeatures3 = batch_norm_RELU_drop_out("Up1_3_Reduce_Out_BN", bnDeConvFeatures3, isTraining, useConvDropOut, keepProbConv) 204 | upFeatures1_3 = mConvBuilder.create_convolution( 205 | convName="Up_1_3", 206 | inPointHierarchy=mPointHierarchy, 207 | inPointLevel=3, 208 | outPointLevel=1, 209 | inFeatures=bnDeConvFeatures3, 210 | inNumFeatures=k*2, 211 | convRadius=0.8) 212 | bnDeConvFeatures4 = batch_norm_RELU_drop_out("Up1_4_Reduce_In_BN", convFeatures4, isTraining, useConvDropOut, keepProbConv) 213 | bnDeConvFeatures4 = conv_1x1("Up1_4_Reduce", bnDeConvFeatures4, k*16, k*2) 214 | bnDeConvFeatures4 = batch_norm_RELU_drop_out("Up1_4_Reduce_Out_BN", bnDeConvFeatures4, isTraining, useConvDropOut, keepProbConv) 215 | upFeatures1_4 = mConvBuilder.create_convolution( 216 | convName="Up_1_4", 217 | inPointHierarchy=mPointHierarchy, 218 | inPointLevel=4, 219 | outPointLevel=1, 220 | inFeatures=bnDeConvFeatures4, 221 | inNumFeatures=k*2, 222 | convRadius=1.6) 223 | upFeatures1 = tf.concat([upFeatures1_4, upFeatures1_3, upFeatures1_2, convFeatures1], 1) 224 | deConvFeatures1 = batch_norm_RELU_drop_out("DeConv_1_Reduce_In_BN", upFeatures1, isTraining, useConvDropOut, keepProbConv) 225 | deConvFeatures1 = conv_1x1("DeConv_1_Reduce", deConvFeatures1, k*8, k*4) 226 | deConvFeatures1 = 
batch_norm_RELU_drop_out("DeConv_1_Reduce_Out_BN", deConvFeatures1, isTraining, useConvDropOut, keepProbConv) 227 | deConvFeatures1 = mConvBuilder.create_convolution( 228 | convName="DeConv_1", 229 | inPointHierarchy=mPointHierarchy, 230 | inPointLevel=1, 231 | inFeatures=deConvFeatures1, 232 | inNumFeatures=k*4, 233 | convRadius=0.4) 234 | deConvFeatures1 = tf.concat([upFeatures1_4, upFeatures1_3, upFeatures1_2, convFeatures1, deConvFeatures1], 1) 235 | 236 | 237 | # Final upsampling 238 | upFeaturesFinal = batch_norm_RELU_drop_out("Up_Final_Reduce_In_BN", deConvFeatures1, isTraining, useConvDropOut, keepProbConv) 239 | upFeaturesFinal = conv_1x1("Up_Final_Reduce", upFeaturesFinal, k*12, k*4) 240 | upFeaturesFinal = batch_norm_RELU_drop_out("Up_Final_Reduce_Out_BN", upFeaturesFinal, isTraining, useConvDropOut, keepProbConv) 241 | finalFeatures = mConvBuilder.create_convolution( 242 | convName="Up_0_1", 243 | inPointHierarchy=mPointHierarchy, 244 | inPointLevel=1, 245 | outPointLevel=0, 246 | inFeatures=upFeaturesFinal, 247 | inNumFeatures=k*4, 248 | convRadius=0.1) 249 | 250 | 251 | # Fully connected MLP - Global features. 
252 | finalInput = batch_norm_RELU_drop_out("BNRELUDROP_hier_final", finalFeatures, isTraining, useConvDropOut, keepProbConv) 253 | finalLogits = MLP_2_hidden(finalInput, k*4, k*4, k*2, numSem, "Final_Logits", keepProbFull, isTraining, useDropOutFull, useInitBN = False) 254 | 255 | return finalLogits 256 | -------------------------------------------------------------------------------- /teaser/Teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/viscom-ulm/MCCNN/41c2db148d691686d63682fd2be9c38126b275f1/teaser/Teaser.png -------------------------------------------------------------------------------- /tf_ops/MCConvModuleSrc: -------------------------------------------------------------------------------- 1 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 2 | \file MCConvModule.py 3 | 4 | \brief Python definition of the tensor operations. 5 | 6 | \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 7 | Germany. See the LICENSE file at the top-level directory of 8 | this distribution. 
9 | 10 | \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 11 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 12 | 13 | import tensorflow as tf 14 | import sys 15 | import os 16 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 17 | sys.path.append(BASE_DIR) 18 | MCConv_module=tf.load_op_library(os.path.join(BASE_DIR, 'MCConv.so')) 19 | # Thin wrappers over the custom ops loaded from MCConv.so. compute_aabb returns the per-batch bounding box (aabb_min, aabb_max); not differentiable. 20 | def compute_aabb(inPts, inBatchIds, batchSize, scaleInv = True): 21 | return MCConv_module.compute_aabb(inPts, inBatchIds, batchSize, scaleInv) 22 | tf.NoGradient('ComputeAabb') 23 | # First point-sorting step; not differentiable. NOTE(review): presumably produces the keys/indexs consumed by sort_points_step2 - verify against sort_gpu.cc. 24 | def sort_points_step1(inPts, inBatchIds, aabbMin, aabbMax, batchSize, cellSize, scaleInv): 25 | return MCConv_module.sort_points_step1(inPts, inBatchIds, aabbMin, aabbMax, batchSize, cellSize, scaleInv) 26 | tf.NoGradient('SortPointsStep1') 27 | # Second sorting step. Only inPts (input 0) and inFeatures (input 2) get gradients, computed from the gradients of outputs 0 and 2 and the indexs input (input 4). 28 | def sort_points_step2(inPts, inBatchIds, inFeatures, keys, indexs, aabbMin, aabbMax, batchSize, cellSize, scaleInv): 29 | return MCConv_module.sort_points_step2(inPts, inBatchIds, inFeatures, keys, indexs, aabbMin, aabbMax, batchSize, cellSize, scaleInv) 30 | @tf.RegisterGradient("SortPointsStep2") 31 | def _sort_points_step2_grad(op, *grads): 32 | inPtsGrad, inFeatureGrad = MCConv_module.sort_points_step2_grad(op.inputs[4], grads[0], grads[2]) 33 | return [inPtsGrad, None, inFeatureGrad, None, None, None, None] 34 | # sort_features is forwarded to the SortFeaturesBackGrad op; SortFeaturesBack and SortFeaturesBackGrad are registered as each other's gradient. Gradient flows only to inFeatures. 35 | def sort_features(inFeatures, indexs): 36 | return MCConv_module.sort_features_back_grad(indexs, inFeatures) 37 | @tf.RegisterGradient("SortFeaturesBackGrad") 38 | def _sort_features_grad(op, grads): 39 | return [None, MCConv_module.sort_features_back(grads, op.inputs[0])] 40 | # Gradient flows only to inFeatures (op input 0). 41 | def sort_features_back(inFeatures, indexs): 42 | return MCConv_module.sort_features_back(inFeatures, indexs) 43 | @tf.RegisterGradient("SortFeaturesBack") 44 | def _sort_features_back_grad(op, grads): 45 | return [MCConv_module.sort_features_back_grad(op.inputs[1], grads), None] 46 | # Not differentiable. 47 | def transform_indexs(inIndexs, inNewPositions): 48 | return
MCConv_module.transform_indexs(inIndexs, inNewPositions) 49 | tf.NoGradient('TransformIndexs') 50 | # Neighbor search with the given radius; outputs (start_indexs, neigh_indexs); not differentiable. 51 | def find_neighbors(inPts, inBatchIds, inPts2, cellIndexs, aabbMin, aabbMax, radius, batchSize, scaleInv): 52 | return MCConv_module.find_neighbors(inPts, inBatchIds, inPts2, cellIndexs, aabbMin, aabbMax, radius, batchSize, scaleInv) 53 | tf.NoGradient('FindNeighbors') 54 | # The op takes (startIndexs, neighbors) before (aabbMin, aabbMax) and batchSize before radius, so the arguments are reordered when forwarding. One pdf per packed neighbor pair; not differentiable. 55 | def compute_pdf(inPts, inBatchIds, aabbMin, aabbMax, startIndexs, neighbors, window, radius, batchSize, scaleInv): 56 | return MCConv_module.compute_pdf(inPts, inBatchIds, startIndexs, neighbors, aabbMin, aabbMax, window, batchSize, radius, scaleInv) 57 | tf.NoGradient('ComputePDF') 58 | # Not differentiable. 59 | def poisson_sampling(inPts, inBatchIds, cellIndexs, aabbMin, aabbMax, radius, batchSize, scaleInv): 60 | return MCConv_module.poisson_sampling(inPts, inBatchIds, cellIndexs, aabbMin, aabbMax, radius, batchSize, scaleInv) 61 | tf.NoGradient('PoissonSampling') 62 | # Gradient flows only to pInFeatures (op input 1). 63 | def get_sampled_features(inSampledIndexs, pInFeatures): 64 | return MCConv_module.get_sampled_features(inSampledIndexs, pInFeatures) 65 | @tf.RegisterGradient("GetSampledFeatures") 66 | def _get_sampled_features_grad(op, *grads): 67 | featureGrads = MCConv_module.get_sampled_features_grad(op.inputs[0], op.inputs[1], grads[0]) 68 | return [None, featureGrads] 69 | # The op expects weights and biases interleaved (weights1, biases1, weights2, biases2, weightsOut, biasesOut), so they are reordered when forwarding. The gradient returns one entry per op input (15); only inFeatures (input 1) and the weight/bias inputs (9-14) receive gradients. 70 | def spatial_conv(inPts, inFeatures, inBatchIds, inPDFs, inSamplePts, neighStartIndexs, packedNeighs, aabbMin, aabbMax, 71 | weights1, weights2, weightsOut, biases1, biases2, biasesOut, numOutFeatures, combin, batchSize, radius, scaleInv, avg): 72 | return MCConv_module.spatial_conv(inPts, inFeatures, inBatchIds, inPDFs, inSamplePts, neighStartIndexs, packedNeighs, aabbMin, aabbMax, 73 | weights1, biases1, weights2, biases2, weightsOut, biasesOut, numOutFeatures, combin, batchSize, radius, scaleInv, avg) 74 | @tf.RegisterGradient("SpatialConv") 75 | def _spatial_conv_grad(op, *grads): 76 | featureGrads, weights1Grads, biases1Grads, weights2Grads, biases2Grads,
weightsOutGrads, biasesOutGrads = \ 77 | MCConv_module.spatial_conv_grad(op.inputs[0], op.inputs[1], op.inputs[2], op.inputs[3], op.inputs[4], op.inputs[5], 78 | op.inputs[6], op.inputs[7], op.inputs[8], op.inputs[9], op.inputs[10], op.inputs[11], op.inputs[12], op.inputs[13], 79 | op.inputs[14], grads[0], op.get_attr("num_out_features"), op.get_attr("combin"), op.get_attr("batch_size"), 80 | op.get_attr("radius"), op.get_attr("scale_inv"), op.get_attr("avg")) 81 | return [None, featureGrads, None, None, None, None, None, None, None, weights1Grads, biases1Grads, weights2Grads, biases2Grads, weightsOutGrads, biasesOutGrads] -------------------------------------------------------------------------------- /tf_ops/aabb_gpu.cc: -------------------------------------------------------------------------------- 1 | ///////////////////////////////////////////////////////////////////////////// 2 | /// \file aabb_gpu.cc 3 | /// 4 | /// \brief C++ operation definition to compute the axis aligned bounding box 5 | /// of a batch of point clouds. 6 | /// 7 | /// \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 8 | /// Germany. See the LICENSE file at the top-level directory of 9 | /// this distribution.
10 | /// 11 | /// \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 12 | ///////////////////////////////////////////////////////////////////////////// 13 | 14 | #include "tensorflow/core/framework/op.h" 15 | #include "tensorflow/core/framework/op_kernel.h" 16 | #include "tensorflow/core/framework/shape_inference.h" 17 | #include "tensorflow/core/framework/common_shape_fns.h" 18 | #include 19 | 20 | using namespace tensorflow; 21 | 22 | REGISTER_OP("ComputeAabb") 23 | .Attr("batch_size: int") 24 | .Attr("scale_inv: bool") 25 | .Input("points : float32") 26 | .Input("batch_ids: int32") 27 | .Output("aabb_min : float32") 28 | .Output("aabb_max : float32") 29 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 30 | int batch_size; 31 | TF_RETURN_IF_ERROR(c->GetAttr("batch_size", &batch_size)); 32 | shape_inference::ShapeHandle aabbDims = c->MakeShape({batch_size, 3}); 33 | c->set_output(0, aabbDims); 34 | c->set_output(1, aabbDims); 35 | return Status::OK(); 36 | }); 37 | // Host-side launcher, implemented in aabb_gpu.cu. 38 | void computeAABB( 39 | const bool pScaleInv, const int pNumPoints, const int pBatchSize, 40 | const float* pPoints, const int* pBatchIds, float* pAABBMin, float* pAABBMax); 41 | // Computes the axis-aligned bounding box of the points of each batch element (two outputs of shape (batch_size, 3)). If scale_inv is false, every batch element receives the bounding box of the whole batch. 42 | class ComputeAABBOp : public OpKernel { 43 | public: 44 | explicit ComputeAABBOp(OpKernelConstruction* context) : OpKernel(context) { 45 | OP_REQUIRES_OK(context, context->GetAttr("batch_size", &batchSize_)); 46 | OP_REQUIRES(context, batchSize_ > 0, errors::InvalidArgument("ComputeAabb expects a positive batch size")); 47 | 48 | OP_REQUIRES_OK(context, context->GetAttr("scale_inv", &scaleInv_)); 49 | } 50 | 51 | void Compute(OpKernelContext* context) override { 52 | const Tensor& inPointsTensor=context->input(0); 53 | OP_REQUIRES(context, inPointsTensor.dims() == 2, errors::InvalidArgument 54 | ("ComputeAabb expects as input the following dimensions (numPoints, pointComponents)")); 55 | OP_REQUIRES(context, inPointsTensor.shape().dim_size(1) == 3,
errors::InvalidArgument 56 | ("ComputeAabb expects points with three components")); 57 | int numPoints = inPointsTensor.shape().dim_size(0); 58 | // Exactly 3 components are required: the kernel reads point i at offset 3*i. 59 | auto inPointsFlat = inPointsTensor.flat(); 60 | const float* inPointsPtr = &(inPointsFlat(0)); 61 | 62 | const Tensor& inBatchTensor=context->input(1); 63 | OP_REQUIRES(context, inBatchTensor.dims() == 2 && 64 | inBatchTensor.shape().dim_size(0) == inPointsTensor.shape().dim_size(0) && 65 | inBatchTensor.shape().dim_size(1) == 1, errors::InvalidArgument 66 | ("ComputeAabb expects as batch ids input the following dimensions (numPoints, 1)")); 67 | auto inBatchFlat = inBatchTensor.flat(); 68 | const int* inBatchPtr = &(inBatchFlat(0)); 69 | 70 | Tensor* outAABBMin = NULL; 71 | Tensor* outAABBMax = NULL; 72 | OP_REQUIRES_OK(context,context->allocate_output(0, TensorShape{batchSize_, 3}, &outAABBMin)); 73 | OP_REQUIRES_OK(context,context->allocate_output(1, TensorShape{batchSize_, 3}, &outAABBMax)); 74 | auto outAABBMinFlat = outAABBMin->flat(); 75 | auto outAABBMaxFlat = outAABBMax->flat(); 76 | float* outAABBMinPtr = &(outAABBMinFlat(0)); 77 | float* outAABBMaxPtr = &(outAABBMaxFlat(0)); 78 | 79 | computeAABB(scaleInv_, numPoints, batchSize_, inPointsPtr, inBatchPtr, outAABBMinPtr, outAABBMaxPtr); 80 | } 81 | private: 82 | int batchSize_; 83 | bool scaleInv_; 84 | }; 85 | 86 | REGISTER_KERNEL_BUILDER(Name("ComputeAabb").Device(DEVICE_GPU), ComputeAABBOp); -------------------------------------------------------------------------------- /tf_ops/aabb_gpu.cu: -------------------------------------------------------------------------------- 1 | ///////////////////////////////////////////////////////////////////////////// 2 | /// \file aabb_gpu.cu 3 | /// 4 | /// \brief Cuda implementation of the operations to compute the axis aligned 5 | /// bounding box of a batch of point clouds.
6 | /// 7 | /// \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 8 | /// Germany. See the LICENSE file at the top-level directory of 9 | /// this distribution. 10 | /// 11 | /// \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 12 | ///////////////////////////////////////////////////////////////////////////// 13 | 14 | #include 15 | #include 16 | 17 | #include "cuda_kernel_utils.h" 18 | 19 | #define POINT_BLOCK_SIZE 256 20 | 21 | ////////////////////////////////////////////////////////////////////////////////// GPU 22 | // Float atomic min/max (used on both shared and global memory), emulated with an atomicCAS loop on the value's bit pattern; they return the previous value. 23 | __device__ static float atomicMin(float* address, float val) 24 | { 25 | int* address_as_i = (int*) address; 26 | int old = *address_as_i, assumed; 27 | do { 28 | assumed = old; 29 | old = ::atomicCAS(address_as_i, assumed, 30 | __float_as_int(::fminf(val, __int_as_float(assumed)))); 31 | } while (assumed != old); 32 | return __int_as_float(old); 33 | } 34 | 35 | __device__ static float atomicMax(float* address, float val) 36 | { 37 | int* address_as_i = (int*) address; 38 | int old = *address_as_i, assumed; 39 | do { 40 | assumed = old; 41 | old = ::atomicCAS(address_as_i, assumed, 42 | __float_as_int(::fmaxf(val, __int_as_float(assumed)))); 43 | } while (assumed != old); 44 | return __int_as_float(old); 45 | } 46 | 47 | /** 48 | * Method to compute the bounding box of each point cloud of a batch. Points are read with a stride of 3 floats. Requires pBatchSize <= blockDim.x and 6*pBatchSize floats of dynamic shared memory (per-batch minima followed by per-batch maxima). 49 | * @param pScaleInv Scale invariance. If false, every batch element receives the bounding box of the whole batch. 50 | * @param pNumPoints Number of points. 51 | * @param pBatchSize Size of the batch. 52 | * @param pPoints List of points. 53 | * @param pBatchIds List of identifiers of the batch. 54 | * @param pAABBMin Output parameter with the minimum point of the bounding box. 55 | * @param pAABBMax Output parameter with the maximum point of the bounding box.
56 | */ 57 | __global__ void comp_AABB( 58 | const bool pScaleInv, 59 | const int pNumPoints, 60 | const int pBatchSize, 61 | const float* __restrict__ pPoints, 62 | const int* __restrict__ pBatchIds, 63 | float* __restrict__ pAABBMin, 64 | float* __restrict__ pAABBMax) 65 | { 66 | extern __shared__ float tmpSharedMemPtr[]; 67 | 68 | int currentIndex = threadIdx.x + blockIdx.x * blockDim.x; 69 | // All threads, including those past the last point, take part in the initialization and in both barriers: __syncthreads() must not be reached divergently, and every shared slot must be initialized before it is merged. 70 | 71 | if(threadIdx.x < pBatchSize){ 72 | tmpSharedMemPtr[threadIdx.x*3] = FLT_MAX; 73 | tmpSharedMemPtr[threadIdx.x*3 + 1] = FLT_MAX; 74 | tmpSharedMemPtr[threadIdx.x*3 + 2] = FLT_MAX; 75 | tmpSharedMemPtr[pBatchSize*3 + threadIdx.x*3] = -FLT_MAX; 76 | tmpSharedMemPtr[pBatchSize*3 + threadIdx.x*3 + 1] = -FLT_MAX; 77 | tmpSharedMemPtr[pBatchSize*3 + threadIdx.x*3 + 2] = -FLT_MAX; 78 | } 79 | 80 | __syncthreads(); 81 | 82 | if(currentIndex < pNumPoints){ int batchId = pBatchIds[currentIndex]; 83 | float* aabbMin = &tmpSharedMemPtr[batchId*3]; 84 | float* aabbMax = &tmpSharedMemPtr[pBatchSize*3 + batchId*3]; 85 | 86 | int pointIndex = currentIndex * 3; 87 | atomicMin(&aabbMin[0], pPoints[pointIndex]); 88 | atomicMin(&aabbMin[1], pPoints[pointIndex+1]); 89 | atomicMin(&aabbMin[2], pPoints[pointIndex+2]); 90 | atomicMax(&aabbMax[0], pPoints[pointIndex]); 91 | atomicMax(&aabbMax[1], pPoints[pointIndex+1]); 92 | atomicMax(&aabbMax[2], pPoints[pointIndex+2]); 93 | } 94 | __syncthreads(); 95 | // Merge this block's partial boxes into the global outputs; slots untouched by this block still hold +/-FLT_MAX and leave the outputs unchanged. 96 | if(threadIdx.x < pBatchSize){ 97 | if(pScaleInv){ 98 | atomicMin(&pAABBMin[threadIdx.x*3], tmpSharedMemPtr[threadIdx.x*3]); 99 | atomicMin(&pAABBMin[threadIdx.x*3 + 1], tmpSharedMemPtr[threadIdx.x*3 + 1]); 100 | atomicMin(&pAABBMin[threadIdx.x*3 + 2], tmpSharedMemPtr[threadIdx.x*3 + 2]); 101 | atomicMax(&pAABBMax[threadIdx.x*3], tmpSharedMemPtr[pBatchSize*3 + threadIdx.x*3]); 102 | atomicMax(&pAABBMax[threadIdx.x*3 + 1], tmpSharedMemPtr[pBatchSize*3 + threadIdx.x*3 + 1]); 103 | atomicMax(&pAABBMax[threadIdx.x*3 + 2], tmpSharedMemPtr[pBatchSize*3 + threadIdx.x*3 + 2]); 104 | }else{ 105 |
for(int i = 0; i < pBatchSize; ++i) 106 | { 107 | atomicMin(&pAABBMin[i*3], tmpSharedMemPtr[threadIdx.x*3]); 108 | atomicMin(&pAABBMin[i*3 + 1], tmpSharedMemPtr[threadIdx.x*3 + 1]); 109 | atomicMin(&pAABBMin[i*3 + 2], tmpSharedMemPtr[threadIdx.x*3 + 2]); 110 | atomicMax(&pAABBMax[i*3], tmpSharedMemPtr[pBatchSize*3 + threadIdx.x*3]); 111 | atomicMax(&pAABBMax[i*3 + 1], tmpSharedMemPtr[pBatchSize*3 + threadIdx.x*3 + 1]); 112 | atomicMax(&pAABBMax[i*3 + 2], tmpSharedMemPtr[pBatchSize*3 + threadIdx.x*3 + 2]); 113 | } 114 | } 115 | } 116 | 117 | } 118 | 119 | ////////////////////////////////////////////////////////////////////////////////// CPU 120 | // Host launcher: initializes the outputs to +FLT_MAX / -FLT_MAX, then runs one thread per point (POINT_BLOCK_SIZE threads per block). 121 | void computeAABB( 122 | const bool pScaleInv, 123 | const int pNumPoints, 124 | const int pBatchSize, 125 | const float* pPoints, 126 | const int* pBatchIds, 127 | float* pAABBMin, 128 | float* pAABBMax) 129 | { 130 | float maxFlt[pBatchSize*3]; 131 | float minFlt[pBatchSize*3]; 132 | for(int i = 0; i < pBatchSize*3; ++i){ 133 | maxFlt[i] = FLT_MAX; 134 | minFlt[i] = -FLT_MAX; 135 | } 136 | gpuErrchk(cudaMemcpy(pAABBMin, &maxFlt[0], pBatchSize*3*sizeof(float), cudaMemcpyHostToDevice)); 137 | gpuErrchk(cudaMemcpy(pAABBMax, &minFlt[0], pBatchSize*3*sizeof(float), cudaMemcpyHostToDevice)); 138 | int numBlocksPoints = pNumPoints/POINT_BLOCK_SIZE; 139 | numBlocksPoints += (pNumPoints%POINT_BLOCK_SIZE != 0)?1:0; 140 | if(numBlocksPoints > 0){ comp_AABB<<>>(pScaleInv, pNumPoints, pBatchSize, pPoints, pBatchIds, pAABBMin, pAABBMax); gpuErrchk(cudaPeekAtLastError()); } 141 | } -------------------------------------------------------------------------------- /tf_ops/compute_pdf.cc: -------------------------------------------------------------------------------- 1 | ///////////////////////////////////////////////////////////////////////////// 2 | /// \file compute_pdf.cc 3 | /// 4 | /// \brief C++ operation definition to approximate the probability 5 | /// distribution function at each sample in the different receptive 6 | /// fields.
7 | /// 8 | /// \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 9 | /// Germany. See the LICENSE file at the top-level directory of 10 | /// this distribution. 11 | /// 12 | /// \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 13 | ///////////////////////////////////////////////////////////////////////////// 14 | 15 | #include "tensorflow/core/framework/op.h" 16 | #include "tensorflow/core/framework/op_kernel.h" 17 | #include "tensorflow/core/framework/shape_inference.h" 18 | #include "tensorflow/core/framework/common_shape_fns.h" 19 | #include 20 | 21 | #include "cuda_kernel_utils.h" 22 | 23 | using namespace tensorflow; 24 | 25 | REGISTER_OP("ComputePDF") 26 | .Attr("window: float") 27 | .Attr("batch_size: int") 28 | .Attr("radius: float") 29 | .Attr("scale_inv: bool") 30 | .Input("points: float32") 31 | .Input("batch_ids: int32") 32 | .Input("start_indexs: int32") 33 | .Input("neigbors: int32") 34 | .Input("aabb_min: float32") 35 | .Input("aabb_max: float32") 36 | .Output("pdfs: float32") 37 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 38 | shape_inference::ShapeHandle outputDims = c->MakeShape({c->Dim(c->input(3), 0), 1}); 39 | c->set_output(0, outputDims); 40 | return Status::OK(); 41 | }); 42 | // Host launcher, implemented in compute_pdf.cu. 43 | void computeDPFsCPU( 44 | const bool pScaleInv, 45 | const float pWindow, 46 | const int numSamples, 47 | const int pNumNeighbors, 48 | const float pRadius, 49 | const float* pInPts, 50 | const int* pInBatchIds, 51 | const float* pAABBMin, 52 | const float* pAABBMax, 53 | const int* pStartIndexs, 54 | const int* pPackedIndexs, 55 | float* pPDFs); 56 | // For every packed (neighbor, center) pair, estimates a density value at the neighbor point from the neighbors of the same center; one output value per pair. 57 | class ComputePDFOp : public OpKernel { 58 | public: 59 | explicit ComputePDFOp(OpKernelConstruction* context) : OpKernel(context) 60 | { 61 | OP_REQUIRES_OK(context, context->GetAttr("window", &window_)); 62 | OP_REQUIRES(context, window_ > 0.0, errors::InvalidArgument("ComputePDFOp expects a positive window")); 63 | 64 | OP_REQUIRES_OK(context,
context->GetAttr("radius", &radius_)); 65 | OP_REQUIRES(context, radius_ > 0.0, errors::InvalidArgument("ComputePDFOp expects a positive radius")); 66 | 67 | OP_REQUIRES_OK(context, context->GetAttr("batch_size", &batchSize_)); 68 | OP_REQUIRES(context, batchSize_ > 0, errors::InvalidArgument("ComputePDFOp expects a positive batch size")); 69 | 70 | OP_REQUIRES_OK(context, context->GetAttr("scale_inv", &scaleInv_)); 71 | } 72 | 73 | void Compute(OpKernelContext* context) override { 74 | //Process input points. 75 | const Tensor& inPointsTensor = context->input(0); 76 | OP_REQUIRES(context, inPointsTensor.dims() == 2, errors::InvalidArgument 77 | ("ComputePDFOp expects points with the following dimensions (numPoints, pointComponents)")); 78 | OP_REQUIRES(context, inPointsTensor.shape().dim_size(1) == 3, errors::InvalidArgument 79 | ("ComputePDFOp expects points with three components")); 80 | 81 | auto inPointsFlat = inPointsTensor.flat(); 82 | const float* inPointsPtr = &(inPointsFlat(0)); 83 | 84 | const Tensor& inBatchTensor=context->input(1); 85 | OP_REQUIRES(context, inBatchTensor.dims() == 2 && 86 | inBatchTensor.shape().dim_size(0) == inPointsTensor.shape().dim_size(0) && 87 | inBatchTensor.shape().dim_size(1) == 1, errors::InvalidArgument 88 | ("ComputePDFOp expects as batch ids input the following dimensions (numPoints, 1)")); 89 | auto inBatchFlat = inBatchTensor.flat(); 90 | const int* inBatchPtr = &(inBatchFlat(0)); 91 | 92 | //Process start indexs.
93 | const Tensor& startIndexTensor = context->input(2); 94 | OP_REQUIRES(context, startIndexTensor.dims() == 2 && 95 | startIndexTensor.shape().dim_size(1) == 1, errors::InvalidArgument 96 | ("ComputePDFOp expects start indices with the following dimensions (numSamples, 1)")); 97 | int numSamples = startIndexTensor.shape().dim_size(0); 98 | auto startIndexTensorFlat = startIndexTensor.flat(); 99 | const int* startIndexTensorPtr = &(startIndexTensorFlat(0)); 100 | 101 | //Process packed neighbors. 102 | const Tensor& packedNeighTensor = context->input(3); 103 | OP_REQUIRES(context, packedNeighTensor.dims() == 2 && 104 | packedNeighTensor.shape().dim_size(1) == 2, errors::InvalidArgument 105 | ("ComputePDFOp expects packed neighbors with the following dimensions (numNeighbors, 2)")); 106 | int numNeighs = packedNeighTensor.shape().dim_size(0); 107 | auto packedNeighTensorFlat = packedNeighTensor.flat(); 108 | const int* packedNeighTensorPtr = &(packedNeighTensorFlat(0)); 109 | 110 | //Process input bounding box. 111 | const Tensor& inAABBMinTensor = context->input(4); 112 | OP_REQUIRES(context, inAABBMinTensor.dims() == 2 113 | && inAABBMinTensor.shape().dim_size(0) == batchSize_ && inAABBMinTensor.shape().dim_size(1) == 3, errors::InvalidArgument 114 | ("ComputePDFOp expects a minimum point of the bounding box with the following dimensions (batchSize, 3)")); 115 | auto inAABBMinFlat = inAABBMinTensor.flat(); 116 | const float* inAABBMinPtr = &(inAABBMinFlat(0)); 117 | 118 | const Tensor& inAABBMaxTensor = context->input(5); 119 | OP_REQUIRES(context, inAABBMaxTensor.dims() == 2 120 | && inAABBMaxTensor.shape().dim_size(0) == batchSize_ && inAABBMaxTensor.shape().dim_size(1) == 3, errors::InvalidArgument 121 | ("ComputePDFOp expects a maximum point of the bounding box with the following dimensions (batchSize, 3)")); 122 | auto inAABBMaxFlat = inAABBMaxTensor.flat(); 123 | const float* inAABBMaxPtr = &(inAABBMaxFlat(0)); 124 | 125 | //Create the output tensors.
126 | Tensor* pdfs = nullptr; 127 | OP_REQUIRES_OK(context,context->allocate_output(0, TensorShape{numNeighs, 1}, &pdfs)); 128 | auto pdfsFlat = pdfs->flat(); 129 | float* pdfsPtr = &(pdfsFlat(0)); 130 | if(numNeighs == 0) return; // Empty neighbor list: nothing to compute, and a kernel launch with an empty grid would fail. 131 | //Compute the pdfs 132 | computeDPFsCPU(scaleInv_, window_, numSamples, numNeighs, radius_, inPointsPtr, inBatchPtr, inAABBMinPtr, inAABBMaxPtr, 133 | startIndexTensorPtr, packedNeighTensorPtr, pdfsPtr); 134 | } 135 | 136 | private: 137 | 138 | float window_; 139 | float radius_; 140 | int batchSize_; 141 | bool scaleInv_; 142 | }; 143 | 144 | REGISTER_KERNEL_BUILDER(Name("ComputePDF").Device(DEVICE_GPU), ComputePDFOp); -------------------------------------------------------------------------------- /tf_ops/compute_pdf.cu: -------------------------------------------------------------------------------- 1 | ///////////////////////////////////////////////////////////////////////////// 2 | /// \file compute_pdf.cu 3 | /// 4 | /// \brief Cuda implementation of the operation to approximate the 5 | /// probability distribution function at each sample in the different 6 | /// receptive fields. 7 | /// 8 | /// \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 9 | /// Germany. See the LICENSE file at the top-level directory of 10 | /// this distribution. 11 | /// 12 | /// \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 13 | ///////////////////////////////////////////////////////////////////////////// 14 | 15 | #include 16 | #include 17 | #include 18 | 19 | #include "cuda_kernel_utils.h" 20 | 21 | #define NEIGHBOR_BLOCK_PDF_SIZE 256 22 | 23 | ////////////////////////////////////////////////////////////////////////////////// GPU 24 | 25 | /** 26 | * Method to compute the pdfs of each neighboring point. 27 | * @param pWindow Window used to compute the pdfs. 28 | * @param numSamples Number of samples (entries of pStartIndexs). 29 | * @param pNumNeighbors Number of neighboring points. 30 | * @param pRadius Radius of the convolution.
31 | * @param pAABBMin Minimum point of the grid (3 components). 32 | * @param pAABBMax Maximum point of the grid (3 components). 33 | * @param pPoints List of points. 34 | * @param pBatchIds List of batch ids. 35 | * @param pScaleInv If true, the radius is scaled by the largest extent of the bounding box of the point's batch. 36 | * @param pStartIndexs List of the starting indices in the neighboring list. 37 | * @param pNeigbors Packed list of (neighbor point, central point) index pairs. 38 | * @param pOutPDFs Output parameter with the pdfs. 39 | */ 40 | __global__ void computePDFs( 41 | const bool pScaleInv, 42 | const float pWindow, 43 | const int numSamples, 44 | const int pNumNeighbors, 45 | const float pRadius, 46 | const float* __restrict__ pAABBMin, 47 | const float* __restrict__ pAABBMax, 48 | const float* __restrict__ pPoints, 49 | const int* __restrict__ pBatchIds, 50 | const int* __restrict__ pStartIndexs, 51 | const int* __restrict__ pNeigbors, 52 | float* __restrict__ pOutPDFs) 53 | { 54 | int currentNeighborIndex = threadIdx.x + blockDim.x*(blockIdx.x + blockIdx.y*gridDim.x + blockIdx.z*gridDim.x*gridDim.y); 55 | if(currentNeighborIndex < pNumNeighbors){ 56 | // Each thread handles one (neighbor, center) pair of the packed list. 57 | int neighborIndex = currentNeighborIndex * 2; 58 | int currentPoint = pNeigbors[neighborIndex]; 59 | float currPointCoords[3] = {pPoints[currentPoint*3], pPoints[currentPoint*3+1], pPoints[currentPoint*3+2]}; 60 | int currBatchId = pBatchIds[currentPoint]; 61 | 62 | float maxAabbSize = max(max( 63 | pAABBMax[currBatchId*3] - pAABBMin[currBatchId*3], 64 | pAABBMax[currBatchId*3+1] - pAABBMin[currBatchId*3+1]), 65 | pAABBMax[currBatchId*3+2] - pAABBMin[currBatchId*3+2]); 66 | float scaledRadius = (pScaleInv)?pRadius*maxAabbSize:pRadius; 67 | // Range of this center's neighbors in the packed list; the range of the last sample ends at pNumNeighbors. 68 | int centralPoint = pNeigbors[neighborIndex+1]; 69 | int initIter = pStartIndexs[centralPoint]; 70 | int endIter = (centralPoint < numSamples-1)?pStartIndexs[centralPoint+1]:pNumNeighbors; 71 | // Product of three 1D Gaussian kernels (0.39894228 = 1/sqrt(2*pi)) evaluated on the offsets divided by radius*window. 72 | const float h = pWindow; 73 | const float invH = 1/h; 74 | const float invRadH = 1.0/(scaledRadius*h); 75 | float currPdf = 0.0; 76 | int iter =
initIter; 77 | while(iter < endIter) 78 | { 79 | int iterPoint = pNeigbors[iter*2]*3; 80 | float iterPointCoords[3] = {pPoints[iterPoint], pPoints[iterPoint+1], pPoints[iterPoint+2]}; 81 | float diff [3] = { 82 | (iterPointCoords[0] - currPointCoords[0])*invRadH, 83 | (iterPointCoords[1] - currPointCoords[1])*invRadH, 84 | (iterPointCoords[2] - currPointCoords[2])*invRadH}; 85 | float gaussVal = invH*((0.39894228)*exp((-0.5)*diff[0]*diff[0])); 86 | gaussVal = gaussVal*invH*((0.39894228)*exp((-0.5)*diff[1]*diff[1])); 87 | gaussVal = gaussVal*invH*((0.39894228)*exp((-0.5)*diff[2]*diff[2])); 88 | currPdf += gaussVal; 89 | iter++; 90 | } 91 | // Mean over the center's neighbor range. NOTE(review): assumes the range is non-empty (it should contain this pair) - verify. 92 | pOutPDFs[currentNeighborIndex] = (currPdf)/((float)endIter-initIter); 93 | } 94 | } 95 | 96 | ////////////////////////////////////////////////////////////////////////////////// CPU 97 | // Host launcher: one thread per packed neighbor pair, on a grid produced by computeBlockGrid. 98 | void computeDPFsCPU( 99 | const bool scaleInv, 100 | const float pWindow, 101 | const int numSamples, 102 | const int pNumNeighbors, 103 | const float pRadius, 104 | const float* pInPts, 105 | const int* pInBatchIds, 106 | const float* pAABBMin, 107 | const float* pAABBMax, 108 | const int* pStartIndexs, 109 | const int* pPackedIndexs, 110 | float* pPDFs) 111 | { 112 | 113 | //Compute the PDF. 114 | dim3 gridDimension = computeBlockGrid(pNumNeighbors, NEIGHBOR_BLOCK_PDF_SIZE); 115 | 116 | computePDFs<<>>(scaleInv, pWindow, numSamples, pNumNeighbors, 117 | pRadius, pAABBMin, pAABBMax, pInPts, pInBatchIds, pStartIndexs, pPackedIndexs, pPDFs); 118 | 119 | gpuErrchk(cudaPeekAtLastError()); 120 | } -------------------------------------------------------------------------------- /tf_ops/cuda_kernel_utils.h: -------------------------------------------------------------------------------- 1 | ///////////////////////////////////////////////////////////////////////////// 2 | /// \file cuda_kernel_utils.h 3 | /// 4 | /// \brief Utilities for the cuda implementations of the tensor operations.
5 | /// 6 | /// \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 7 | /// Germany. See the LICENSE file at the top-level directory of 8 | /// this distribution. 9 | /// 10 | /// \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 11 | ///////////////////////////////////////////////////////////////////////////// 12 | 13 | #ifndef CUDA_KERNEL_UTILS_H_ 14 | #define CUDA_KERNEL_UTILS_H_ 15 | // gpuErrchk/gpuAssert: on a CUDA error, print it with file and line to stderr and (by default) exit the process. NOTE(review): this header uses fprintf/stderr/exit without including their headers itself; it relies on the including file. 16 | #define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); } 17 | inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true) 18 | { 19 | if (code != cudaSuccess) 20 | { 21 | fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); 22 | if (abort) exit(code); 23 | } 24 | } 25 | // Grid of ceil(pNumElements / pNumThreads) blocks along x; while x >= 65536 it is halved (rounded up) and y doubled, and likewise y into z. The grid may cover more threads than elements, so kernels must bounds-check their linear index. An empty input yields x == 0, which is not a launchable grid. 26 | inline dim3 computeBlockGrid(const unsigned long long int pNumElements, const int pNumThreads) 27 | { 28 | dim3 finalDimension(pNumElements/pNumThreads, 1, 1); 29 | finalDimension.x += (pNumElements%pNumThreads!= 0)?1:0; 30 | while(finalDimension.x >= 65536){ 31 | finalDimension.y *= 2; 32 | int auxDim = finalDimension.x/2; 33 | auxDim += (finalDimension.x%2!=0)?1:0; 34 | finalDimension.x = auxDim; 35 | } 36 | 37 | while(finalDimension.y >= 65536){ 38 | finalDimension.z *= 2; 39 | int auxDim = finalDimension.y/2; 40 | auxDim += (finalDimension.y%2!=0)?1:0; 41 | finalDimension.y = auxDim; 42 | } 43 | 44 | return finalDimension; 45 | } 46 | 47 | #endif 48 | -------------------------------------------------------------------------------- /tf_ops/find_neighbors.cc: -------------------------------------------------------------------------------- 1 | ///////////////////////////////////////////////////////////////////////////// 2 | /// \file find_neighbors.cc 3 | /// 4 | /// \brief C++ operation definition to find the neighboring points within a 5 | /// certain radius. 6 | /// 7 | /// \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 8 | /// Germany.
See the LICENSE file at the top-level directory of 9 | /// this distribution. 10 | /// 11 | /// \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 12 | ///////////////////////////////////////////////////////////////////////////// 13 | 14 | 15 | #include "tensorflow/core/framework/op.h" 16 | #include "tensorflow/core/framework/op_kernel.h" 17 | #include "tensorflow/core/framework/shape_inference.h" 18 | #include "tensorflow/core/framework/common_shape_fns.h" 19 | #include 20 | 21 | #include "cuda_kernel_utils.h" 22 | 23 | using namespace tensorflow; 24 | 25 | REGISTER_OP("FindNeighbors") 26 | .Attr("radius: float") 27 | .Attr("batch_size: int") 28 | .Attr("scale_inv: bool") 29 | .Input("points: float32") 30 | .Input("batch_ids: int32") 31 | .Input("points2: float32") 32 | .Input("cell_indexs: int32") 33 | .Input("aabb_min: float32") 34 | .Input("aabb_max: float32") 35 | .Output("start_indexs: int32") 36 | .Output("neigh_indexs: int32") 37 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 38 | shape_inference::ShapeHandle outputDims = c->MakeShape({c->Dim(c->input(0), 0), 1}); 39 | shape_inference::ShapeHandle outputDims2 = c->MakeShape({-1, 2}); 40 | c->set_output(0, outputDims); 41 | c->set_output(1, outputDims2); 42 | return Status::OK(); 43 | }); 44 | 45 | unsigned int countNeighborsCPU( 46 | const bool pScaleInv, 47 | const int pNumPoints, 48 | const int pNumCells, 49 | const float pRadius, 50 | const float* pInPts, 51 | const int* pInBatchIds, 52 | const float* pInPts2, 53 | const int* pCellIndexs, 54 | const float* pAABBMin, 55 | const float* pAABBMax, 56 | int* pStartIndex); 57 | 58 | void packNeighborsCPU( 59 | const bool pScaleInv, 60 | const int pNumPoints, 61 | const int pNumNeighbors, 62 | const int pNumCells, 63 | const float pRadius, 64 | const float* pInPts, 65 | const int* pInBatchIds, 66 | const float* pInPts2, 67 | const int* pCellIndexs, 68 | const float* pAABBMin, 69 | const float* pAABBMax, 70 | int* 
pAuxBuffOffsets, 71 | int* pAuxBuffOffsets2, 72 | int* pStartIndexs, 73 | int* pPackedIndexs); 74 | 75 | void computeAuxiliarBuffersSize( 76 | const int pNumPoints, 77 | int* PBufferSize1, 78 | int* PBufferSize2); 79 | 80 | class FindNeighborsOp : public OpKernel { 81 | public: 82 | explicit FindNeighborsOp(OpKernelConstruction* context) : OpKernel(context) 83 | { 84 | OP_REQUIRES_OK(context, context->GetAttr("radius", &radius_)); 85 | OP_REQUIRES(context, radius_ > 0.0, errors::InvalidArgument("FindNeighborsOp expects a positive radius")); 86 | 87 | OP_REQUIRES_OK(context, context->GetAttr("batch_size", &batchSize_)); 88 | OP_REQUIRES(context, batchSize_ > 0, errors::InvalidArgument("FindNeighborsOp expects a positive batch size")); 89 | 90 | OP_REQUIRES_OK(context, context->GetAttr("scale_inv", &scaleInv_)); 91 | } 92 | 93 | void Compute(OpKernelContext* context) override { 94 | //Process input points. 95 | const Tensor& inPointsTensor = context->input(0); 96 | OP_REQUIRES(context, inPointsTensor.dims() == 2, errors::InvalidArgument 97 | ("FindNeighborsOp expects points with the following dimensions (batchSize, pointComponents)")); 98 | OP_REQUIRES(context, inPointsTensor.shape().dim_size(1) == 3, errors::InvalidArgument 99 | ("FindNeighborsOp expects points with three components")); 100 | int numPoints = inPointsTensor.shape().dim_size(0); 101 | auto inPointsFlat = inPointsTensor.flat(); 102 | const float* inPointsPtr = &(inPointsFlat(0)); 103 | 104 | const Tensor& inBatchTensor=context->input(1); 105 | OP_REQUIRES(context, inBatchTensor.dims() == 2 && 106 | inBatchTensor.shape().dim_size(0) == inPointsTensor.shape().dim_size(0) && 107 | inBatchTensor.shape().dim_size(1) == 1, errors::InvalidArgument 108 | ("FindNeighborsOp expects as batch ids input the following dimensions (numPoints)")); 109 | auto inBatchFlat = inBatchTensor.flat(); 110 | const int* inBatchPtr = &(inBatchFlat(0)); 111 | 112 | //Process input points. 
113 | const Tensor& inPointsTensor2 = context->input(2); 114 | OP_REQUIRES(context, inPointsTensor2.dims() == 2, errors::InvalidArgument 115 | ("FindNeighborsOp expects points with the following dimensions (batchSize, pointComponents)")); 116 | OP_REQUIRES(context, inPointsTensor2.shape().dim_size(1) == 3, errors::InvalidArgument 117 | ("FindNeighborsOp expects points with three components")); 118 | int numPoints2 = inPointsTensor2.shape().dim_size(0); 119 | auto inPointsFlat2 = inPointsTensor2.flat(); 120 | const float* inPointsPtr2 = &(inPointsFlat2(0)); 121 | 122 | //Process input cell ids. 123 | const Tensor& inCellIdsTensor = context->input(3); 124 | OP_REQUIRES(context, inCellIdsTensor.dims() == 5 && 125 | inCellIdsTensor.shape().dim_size(0) == batchSize_, errors::InvalidArgument 126 | ("FindNeighborsOp expects a four dimension tensor for the cell indices")); 127 | int numCells = inCellIdsTensor.shape().dim_size(1); 128 | auto inCellIdsFlat = inCellIdsTensor.flat(); 129 | const int* inCellIdsPtr = &(inCellIdsFlat(0)); 130 | 131 | //Process input bounding box. 
132 | const Tensor& inAABBMinTensor = context->input(4); 133 | OP_REQUIRES(context, inAABBMinTensor.dims() == 2 134 | && inAABBMinTensor.shape().dim_size(0) == batchSize_ && inAABBMinTensor.shape().dim_size(1) == 3, errors::InvalidArgument 135 | ("FindNeighborsOp expects a minimum point of the bounding box with 3 components")); 136 | auto inAABBMinFlat = inAABBMinTensor.flat(); 137 | const float* inAABBMinPtr = &(inAABBMinFlat(0)); 138 | 139 | const Tensor& inAABBMaxTensor = context->input(5); 140 | OP_REQUIRES(context, inAABBMaxTensor.dims() == 2 141 | && inAABBMaxTensor.shape().dim_size(0) == batchSize_ && inAABBMaxTensor.shape().dim_size(1) == 3, errors::InvalidArgument 142 | ("FindNeighborsOp expects a maximum point of the bounding box with 3 components")); 143 | auto inAABBMaxFlat = inAABBMaxTensor.flat(); 144 | const float* inAABBMaxPtr = &(inAABBMaxFlat(0)); 145 | 146 | //Create the output tensors. 147 | Tensor* startIndexs = nullptr; 148 | OP_REQUIRES_OK(context,context->allocate_output(0, TensorShape{inPointsTensor.shape().dim_size(0), 1}, &startIndexs)); 149 | auto startIndexsFlat = startIndexs->flat(); 150 | int* startIndexsPtr = &(startIndexsFlat(0)); 151 | 152 | //Determine the number of neighbors. 153 | unsigned int numNeighs = countNeighborsCPU(scaleInv_, numPoints, numCells, radius_, 154 | inPointsPtr, inBatchPtr, inPointsPtr2, inCellIdsPtr, inAABBMinPtr, inAABBMaxPtr, startIndexsPtr); 155 | 156 | //Create the second output 157 | Tensor* neighIndexs = nullptr; 158 | OP_REQUIRES_OK(context,context->allocate_output(1, TensorShape{numNeighs, 2}, &neighIndexs)); 159 | auto neighIndexsFlat = neighIndexs->flat(); 160 | int* neighIndexsPtr = &(neighIndexsFlat(0)); 161 | 162 | //Create the temporal tensors. 
163 | int tmpBuff1Size, tmpBuff2Size; 164 | computeAuxiliarBuffersSize(numPoints, &tmpBuff1Size, &tmpBuff2Size); 165 | Tensor tmpBuff1; 166 | OP_REQUIRES_OK(context,context->allocate_temp(DataTypeToEnum::value,TensorShape{tmpBuff1Size}, &tmpBuff1)); 167 | auto tmpBuff1Flat = tmpBuff1.flat(); 168 | int* tmpBuff1Ptr = &(tmpBuff1Flat(0)); 169 | Tensor tmpBuff2; 170 | OP_REQUIRES_OK(context,context->allocate_temp(DataTypeToEnum::value,TensorShape{tmpBuff2Size}, &tmpBuff2)); 171 | auto tmpBuff2Flat = tmpBuff2.flat(); 172 | int* tmpBuff2Ptr = &(tmpBuff2Flat(0)); 173 | 174 | //Pack neighbors 175 | packNeighborsCPU(scaleInv_, numPoints, numNeighs, numCells, radius_, 176 | inPointsPtr, inBatchPtr, inPointsPtr2, inCellIdsPtr, inAABBMinPtr, inAABBMaxPtr, 177 | tmpBuff1Ptr, tmpBuff2Ptr, startIndexsPtr, neighIndexsPtr); 178 | } 179 | 180 | private: 181 | 182 | float radius_; 183 | int batchSize_; 184 | bool scaleInv_; 185 | }; 186 | 187 | REGISTER_KERNEL_BUILDER(Name("FindNeighbors").Device(DEVICE_GPU), FindNeighborsOp); -------------------------------------------------------------------------------- /tf_ops/genCompileScript.py: -------------------------------------------------------------------------------- 1 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 2 | \file genCompileScript.py 3 | 4 | \brief Python script to generate the compile script for unix systems. 5 | 6 | \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 7 | Germany. See the LICENSE file at the top-level directory of 8 | this distribution. 
9 | 10 | \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 11 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 12 | 13 | import argparse 14 | import tensorflow as tf 15 | 16 | if __name__ == '__main__': 17 | 18 | parser = argparse.ArgumentParser(description='Generate the compile script for the MCCNN operations.') 19 | parser.add_argument('--cudaFolder', required=True, help='Path to the CUDA folder') 20 | parser.add_argument('--MLPSize', default=8, type=int, help='Size of the MLPs (default 8)') 21 | parser.add_argument('--debugInfo', action='store_true', help='Print debug information during execution (default: False)') 22 | args = parser.parse_args() 23 | 24 | debugString = " -DPRINT_CONV_INFO" if args.debugInfo else "" 25 | 26 | with open("compile.sh", "w") as myCompileScript: 27 | myCompileScript.write(args.cudaFolder+"/bin/nvcc -DBLOCK_MLP_SIZE="+str(args.MLPSize)+debugString+" -std=c++11 aabb_gpu.cu -o aabb_gpu.cu.o -c -O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC\n") 28 | myCompileScript.write(args.cudaFolder+"/bin/nvcc -DBLOCK_MLP_SIZE="+str(args.MLPSize)+debugString+" -std=c++11 sort_gpu.cu -o sort_gpu.cu.o -c -O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC\n") 29 | myCompileScript.write(args.cudaFolder+"/bin/nvcc -DBLOCK_MLP_SIZE="+str(args.MLPSize)+debugString+" -std=c++11 find_neighbors.cu -o find_neighbors.cu.o -c -O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC\n") 30 | myCompileScript.write(args.cudaFolder+"/bin/nvcc -DBLOCK_MLP_SIZE="+str(args.MLPSize)+debugString+" -std=c++11 compute_pdf.cu -o compute_pdf.cu.o -c -O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC\n") 31 | myCompileScript.write(args.cudaFolder+"/bin/nvcc -DBLOCK_MLP_SIZE="+str(args.MLPSize)+debugString+" -std=c++11 poisson_sampling.cu -o poisson_sampling.cu.o -c -O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC\n") 32 | myCompileScript.write(args.cudaFolder+"/bin/nvcc -DBLOCK_MLP_SIZE="+str(args.MLPSize)+debugString+" -std=c++11 spatial_conv.cu -o spatial_conv.cu.o -c 
-O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC\n") 33 | tensorflowInclude = tf.sysconfig.get_include() 34 | tensorflowLib = tf.sysconfig.get_lib() 35 | myCompileScript.write("g++ -std=c++11 -DBLOCK_MLP_SIZE="+str(args.MLPSize)+debugString+" spatial_conv.cc poisson_sampling.cc compute_pdf.cc "\ 36 | "find_neighbors.cc sort_gpu.cc aabb_gpu.cc spatial_conv.cu.o poisson_sampling.cu.o compute_pdf.cu.o "\ 37 | "find_neighbors.cu.o sort_gpu.cu.o aabb_gpu.cu.o -o MCConv.so -shared -fPIC -I"+tensorflowInclude+" -I"+tensorflowInclude+"/external/nsync/public "\ 38 | "-I"+args.cudaFolder+"/include -lcudart -L "+args.cudaFolder+"/lib64/ -L"+tensorflowLib+" -ltensorflow_framework -O2 -D_GLIBCXX_USE_CXX11_ABI=0\n") 39 | 40 | with open("MCConvModuleSrc", "r") as mySrcPyScript: 41 | with open("MCConvModule.py", "w") as myDestPyScript: 42 | for line in mySrcPyScript: 43 | myDestPyScript.write(line) 44 | myDestPyScript.write("\n") 45 | myDestPyScript.write("\n") 46 | myDestPyScript.write("def get_block_size():\n") 47 | myDestPyScript.write(" return "+str(args.MLPSize)+"\n") 48 | myDestPyScript.write("\n") 49 | 50 | -------------------------------------------------------------------------------- /tf_ops/poisson_sampling.cu: -------------------------------------------------------------------------------- 1 | ///////////////////////////////////////////////////////////////////////////// 2 | /// \file poisson_sampling.cu 3 | /// 4 | /// \brief Cuda implementation of the operations to perform a poisson disk 5 | /// sampling on a batch of point clouds (O(n)), to obtain the 6 | /// associated features to the selected points, and to propagate the 7 | /// feature gradients. 8 | /// 9 | /// \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 10 | /// Germany. See the LICENSE file at the top-level directory of 11 | /// this distribution. 
12 | /// 13 | /// \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 14 | ///////////////////////////////////////////////////////////////////////////// 15 | 16 | #include 17 | #include 18 | 19 | #include "cuda_kernel_utils.h" 20 | 21 | #define BLOCK_SIZE 4 22 | #define PT_BLOCK_SIZE 128 23 | 24 | ////////////////////////////////////////////////////////////////////////////////// GPU 25 | 26 | __constant__ int cellOffsetsPool[27][3]; // The 27 (x, y, z) cell offsets in {-1, 0, 1}; uploaded by samplePointCloud. 27 | 28 | /** 29 | * Method to select a subset of points (Poisson disk sampling): a point is rejected if an already selected point 30 | * in one of the 27 surrounding cells lies closer than the (possibly scaled) radius. Each thread handles one cell; cells handled by the same launch are 3 cells apart per axis, so their 3x3x3 neighborhoods never overlap. 31 | * @param scaleInv If true, pRadius is multiplied by the largest extent of the bounding box of the current batch. 32 | * @param pCurrBatch Current batch processed. 33 | * @param pCurrentCell Index (0-26) into cellOffsetsPool selecting the cell phase processed by this launch. 34 | * @param pNumPoints Number of points (not read by this kernel). 35 | * @param pBatchSize Size of the batch (not read by this kernel). 36 | * @param pNumCells Number of cells of the grid per axis. 37 | * @param pRadius Radius of the Poisson disk. 38 | * @param pAABBMinPoint Minimum point of the bounding box of each batch (3 components). 39 | * @param pAABBMaxPoint Maximum point of the bounding box of each batch (3 components). 40 | * @param pPoints List of points (3 floats per point). 41 | * @param pBatchIds List of the batch identifiers (not read by this kernel). 42 | * @param pCellIndexs Indexs of the grid cells: for each cell, a pair of point indices 43 | * [start, end) into pPoints. 44 | * @param pAuxBooleanBuffer Input/Output parameter with the list of booleans indicating 45 | * if a point was selected. 46 | * @param pOutSampledPoints Output parameter with the list of sampled points. 47 | * @param pOutSampleBatchIds Output parameter with the list of sampled batch ids. 48 | * @param pOutSampleIndexs Output parameter with the list of indexs of the sampled points. 49 | * @param pOutNumSelectedPoints Output parameter with the number of selected points (incremented atomically).
50 | */ 51 | __global__ void selectSamples( 52 | const bool scaleInv, 53 | const int pCurrBatch, 54 | const int pCurrentCell, 55 | const int pNumPoints, 56 | const int pBatchSize, 57 | const int pNumCells, 58 | const float pRadius, 59 | const float* __restrict__ pAABBMinPoint, 60 | const float* __restrict__ pAABBMaxPoint, 61 | const float* __restrict__ pPoints, 62 | const int* __restrict__ pBatchIds, 63 | const int* __restrict__ pCellIndexs, 64 | bool* __restrict__ pAuxBooleanBuffer, 65 | float* __restrict__ pOutSampledPoints, 66 | int* __restrict__ pOutSampleBatchIds, 67 | int* __restrict__ pOutSampleIndexs, 68 | int* __restrict__ pOutNumSelectedPoints) 69 | { 70 | int xCell = (threadIdx.x + blockIdx.x * blockDim.x)*3 + 1 + cellOffsetsPool[pCurrentCell][0]; // Cell of this thread: every third cell per axis, shifted by the current phase offset. 71 | int yCell = (threadIdx.y + blockIdx.y * blockDim.y)*3 + 1 + cellOffsetsPool[pCurrentCell][1]; 72 | int zCell = (threadIdx.z + blockIdx.z * blockDim.z)*3 + 1 + cellOffsetsPool[pCurrentCell][2]; 73 | 74 | if(xCell < pNumCells && yCell < pNumCells && zCell < pNumCells){ 75 | 76 | float maxAabbSize = max(max( 77 | pAABBMaxPoint[pCurrBatch*3] - pAABBMinPoint[pCurrBatch*3], 78 | pAABBMaxPoint[pCurrBatch*3 + 1] - pAABBMinPoint[pCurrBatch*3 + 1]), 79 | pAABBMaxPoint[pCurrBatch*3 + 2] - pAABBMinPoint[pCurrBatch*3 + 2]); 80 | float radius = (scaleInv)?pRadius*maxAabbSize:pRadius; 81 | 82 | int cellIndex = pCurrBatch*pNumCells*pNumCells*pNumCells + xCell*pNumCells*pNumCells + yCell*pNumCells + zCell; 83 | int initPoint = pCellIndexs[cellIndex*2]; 84 | int endPoint = pCellIndexs[cellIndex*2 +1]; 85 | for(int i = initPoint; i < endPoint; ++i) 86 | { 87 | float centralCoords[3] = {pPoints[i*3], pPoints[i*3+1], pPoints[i*3+2]}; 88 | bool collision = false; 89 | 90 | for(int neighIter = 0; (neighIter < 27) && !collision; ++neighIter) 91 | { 92 | int currCellIndex[3] = {xCell+cellOffsetsPool[neighIter][0], yCell+cellOffsetsPool[neighIter][1], zCell+cellOffsetsPool[neighIter][2]}; 93 | if(currCellIndex[0] >= 0 &&
currCellIndex[0] < pNumCells && 94 | currCellIndex[1] >= 0 && currCellIndex[1] < pNumCells && 95 | currCellIndex[2] >= 0 && currCellIndex[2] < pNumCells) 96 | { 97 | int cellIndexFlat = pCurrBatch*pNumCells*pNumCells*pNumCells + currCellIndex[0]*pNumCells*pNumCells + currCellIndex[1]*pNumCells + currCellIndex[2]; 98 | int initNeighIndex = pCellIndexs[cellIndexFlat*2]; 99 | int endNeighIndex = pCellIndexs[cellIndexFlat*2 + 1]; 100 | for(int j = initNeighIndex; (j < endNeighIndex) && !collision; ++j) 101 | { 102 | int currPointIndex = j * 3; 103 | float currentCoords[3] = {pPoints[currPointIndex], pPoints[currPointIndex+1], pPoints[currPointIndex+2]}; 104 | float diffVector[3] = {currentCoords[0] - centralCoords[0], currentCoords[1] - centralCoords[1], currentCoords[2] - centralCoords[2]}; 105 | float pointDist = sqrt(diffVector[0]*diffVector[0] + diffVector[1]*diffVector[1] + diffVector[2]*diffVector[2]); 106 | if(pointDist < radius && pAuxBooleanBuffer[j]){ // Only points already selected can reject the current one. 107 | collision = true; 108 | } 109 | } 110 | } 111 | } 112 | 113 | if(!collision){ 114 | pAuxBooleanBuffer[i] = true; 115 | int finalPointIndex = atomicAdd(&pOutNumSelectedPoints[0], 1); 116 | pOutSampledPoints[finalPointIndex*3] = centralCoords[0]; 117 | pOutSampledPoints[finalPointIndex*3+1] = centralCoords[1]; 118 | pOutSampledPoints[finalPointIndex*3+2] = centralCoords[2]; 119 | pOutSampleBatchIds[finalPointIndex] = pCurrBatch; 120 | pOutSampleIndexs[finalPointIndex] = i; 121 | } 122 | } 123 | } 124 | } 125 | 126 | 127 | /** 128 | * Method to get the features of the sampled points. 129 | * @param pNumSamples Number of samples. 130 | * @param pNumFeatures Number of features. 131 | * @param pSampledIndexs List of indexs of the sampled points. 132 | * @param pFeatures List of input features. 133 | * @param pOutSampledFeatures List of output sampled features.
134 | */ 135 | __global__ void selectFeatureSamples( 136 | const int pNumSamples, 137 | const int pNumFeatures, 138 | const int* __restrict__ pSampledIndexs, 139 | const float* __restrict__ pFeatures, 140 | float* __restrict__ pOutSampledFeatures) 141 | { 142 | int currentIndex = threadIdx.x + blockIdx.x * blockDim.x; 143 | int sampleIndex = currentIndex/pNumFeatures; 144 | int featureIndex = currentIndex%pNumFeatures; 145 | if(sampleIndex < pNumSamples){ 146 | pOutSampledFeatures[currentIndex] = pFeatures[pSampledIndexs[sampleIndex]*pNumFeatures + featureIndex]; 147 | } 148 | } 149 | 150 | /** 151 | * Method to get the gradients of the features of the sampled points. 152 | * @param pNumSamples Number of samples. 153 | * @param pNumFeatures Number of features. 154 | * @param pSampledIndexs List of indexs of the sampled points. 155 | * @param pFeaturesGrads List of gradients of output features. 156 | * @param pOutSampledFeaturesGrads List of output gradients of input features. 157 | */ 158 | __global__ void selectFeatureSamplesGrad( 159 | const int pNumSamples, 160 | const int pNumFeatures, 161 | const int* __restrict__ pSampledIndexs, 162 | const float* __restrict__ pFeaturesGrads, 163 | float* __restrict__ pOutSampledFeaturesGrads) 164 | { 165 | int currentIndex = threadIdx.x + blockIdx.x * blockDim.x; 166 | int sampleIndex = currentIndex/pNumFeatures; 167 | int featureIndex = currentIndex%pNumFeatures; 168 | if(sampleIndex < pNumSamples){ 169 | pOutSampledFeaturesGrads[pSampledIndexs[sampleIndex]*pNumFeatures + featureIndex] = pFeaturesGrads[currentIndex]; 170 | } 171 | } 172 | 173 | ////////////////////////////////////////////////////////////////////////////////// CPU 174 | 175 | int samplePointCloud( 176 | const bool scaleInv, 177 | const float pRadius, 178 | const int pNumPoints, 179 | const int pBatchSize, 180 | const int pNumCells, 181 | const float* pAABBMin, 182 | const float* pAABBMax, 183 | const float* pPoints, 184 | const int* pBatchIds, 185 | const 
int* pCellIndexs, 186 | float* pSelectedPts, 187 | int* pSelectedBatchIds, 188 | int* pSelectedIndexs, 189 | bool* pAuxBoolBuffer) 190 | { 191 | //Init device symbols. 192 | int cellOffsetsPoolCPU[27][3] = { 193 | {1, 1, -1}, {0, -1, 1}, {0, 1, 1}, {0, 1, 0}, {0, 0, 1}, {0, -1, 0}, {-1, 1, -1}, 194 | {0, -1, -1}, {1, 0, 0}, {1, -1, 1}, {1, 0, 1}, {-1, 1, 1}, {-1, 0, 0}, {1, -1, -1}, 195 | {0, 1, -1}, {-1, -1, 0}, {-1, 1, 0}, {0, 0, 0}, {0, 0, -1}, {1, 1, 0}, {1, 0, -1}, 196 | {1, -1, 0}, {-1, 0, 1}, {1, 1, 1}, {-1, 0, -1}, {-1, -1, -1}, {-1, -1, 1}}; 197 | cudaMemcpyToSymbol(cellOffsetsPool, cellOffsetsPoolCPU, 27*3*sizeof(int)); 198 | int numSelectedPointsCPU = 0; 199 | 200 | gpuErrchk(cudaMemset(pAuxBoolBuffer, 0, sizeof(bool)*pNumPoints)); 201 | 202 | int* numSelectedPoints; 203 | gpuErrchk(cudaMalloc(&numSelectedPoints, sizeof(int))); 204 | gpuErrchk(cudaMemset(numSelectedPoints, 0, sizeof(int))); 205 | 206 | int numPhaseGroups = pNumCells/3; 207 | numPhaseGroups += (pNumCells%3!=0)?1:0; 208 | int numBlocks = numPhaseGroups/BLOCK_SIZE; 209 | numBlocks += (numPhaseGroups%BLOCK_SIZE!=0)?1:0; 210 | for(int b = 0; b < pBatchSize; ++b){ 211 | for(int i = 0; i < 27; ++i){ 212 | selectSamples<<>> 213 | (scaleInv, b, i, pNumPoints, pBatchSize, pNumCells, pRadius, pAABBMin, 214 | pAABBMax, pPoints, pBatchIds, pCellIndexs, pAuxBoolBuffer, pSelectedPts, 215 | pSelectedBatchIds, pSelectedIndexs, numSelectedPoints); 216 | 217 | gpuErrchk(cudaPeekAtLastError()); 218 | } 219 | } 220 | 221 | //Copy from GPU the number of selected samples. 
222 | gpuErrchk(cudaMemcpy(&numSelectedPointsCPU, numSelectedPoints, sizeof(int), cudaMemcpyDeviceToHost)); 223 | gpuErrchk(cudaFree(numSelectedPoints)); 224 | 225 | #ifdef PRINT_CONV_INFO 226 | printf("Num Cells: %d | Input points: %d | Result pooling: %d\n", pNumCells, pNumPoints, numSelectedPointsCPU); 227 | #endif 228 | 229 | return numSelectedPointsCPU; 230 | } 231 | 232 | void copyPoints( 233 | float* pSelectedPts, 234 | int* pSelectedBatchIds, 235 | int* pSelectedIndexs, 236 | const int pNumPts, 237 | float* pDestPts, 238 | int* pDestBatchIds, 239 | int* pDestIndexs) 240 | { 241 | gpuErrchk(cudaMemcpy(pDestPts, pSelectedPts, sizeof(float)*3*pNumPts, cudaMemcpyDeviceToDevice)); 242 | gpuErrchk(cudaMemcpy(pDestBatchIds, pSelectedBatchIds, sizeof(int)*pNumPts, cudaMemcpyDeviceToDevice)); 243 | gpuErrchk(cudaMemcpy(pDestIndexs, pSelectedIndexs, sizeof(int)*pNumPts, cudaMemcpyDeviceToDevice)); 244 | } 245 | 246 | void getFeaturesSampledPoints( 247 | int pNumPoints, 248 | int pNumFeatures, 249 | int pNumSampledPoints, 250 | const int* pInPointsIndexs, 251 | const float* pInFeature, 252 | float* pOutSelFeatures) 253 | { 254 | int numBlocksPoints = pNumSampledPoints/PT_BLOCK_SIZE; 255 | numBlocksPoints += (pNumSampledPoints%PT_BLOCK_SIZE != 0)?1:0; 256 | selectFeatureSamples<<>>(pNumSampledPoints, pNumFeatures, pInPointsIndexs, pInFeature, pOutSelFeatures); 257 | gpuErrchk(cudaPeekAtLastError()); 258 | } 259 | 260 | void getFeaturesSampledPointsGradients( 261 | int pNumPoints, 262 | int pNumFeatures, 263 | int pNumSampledPoints, 264 | const int* pInPointsIndexs, 265 | const float* pInOutFeatureGrad, 266 | float* pOutInFeaturesGradients) 267 | { 268 | gpuErrchk(cudaMemset(pOutInFeaturesGradients, 0, sizeof(int)*pNumFeatures*pNumPoints)); 269 | 270 | int numBlocksPoints = pNumSampledPoints/PT_BLOCK_SIZE; 271 | numBlocksPoints += (pNumSampledPoints%PT_BLOCK_SIZE != 0)?1:0; 272 | selectFeatureSamplesGrad<<>>(pNumSampledPoints, pNumFeatures, pInPointsIndexs, 
pInOutFeatureGrad, pOutInFeaturesGradients); 273 | gpuErrchk(cudaPeekAtLastError()); 274 | } 275 | -------------------------------------------------------------------------------- /utils/GenerateSphereMeshes.py: -------------------------------------------------------------------------------- 1 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 2 | \file GenerateShpereMeshes.py 3 | 4 | \brief Script to create a ply scene from a point cloud with an sphere 5 | for each point. 6 | 7 | \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 8 | Germany. See the LICENSE file at the top-level directory of 9 | this distribution. 10 | 11 | \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 12 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 13 | 14 | import sys 15 | import argparse 16 | import os 17 | from os import listdir 18 | from os.path import isfile, join 19 | import numpy as np 20 | 21 | from PyUtils import visualize_progress 22 | 23 | def icosahedron(): 24 | PHI = (1.0 + np.sqrt(5.0)) / 2.0 25 | sphereLength = np.sqrt(PHI*PHI + 1.0) 26 | dist1 = PHI/sphereLength 27 | dist2 = 1.0/sphereLength 28 | 29 | verts = [ 30 | [-dist2, dist1, 0], [ dist2, dist1, 0], [-dist2, -dist1, 0], [ dist2, -dist1, 0], 31 | [0, -dist2, dist1], [0, dist2, dist1], [0, -dist2, -dist1], [0, dist2, -dist1], 32 | [ dist1, 0, -dist2], [ dist1, 0, dist2], [-dist1, 0, -dist2], [-dist1, 0, dist2] 33 | ] 34 | 35 | faces = [ 36 | [0, 11, 5], [0, 5, 1], [0, 1, 7], [0, 7, 10], [0, 10, 11], 37 | [1, 5, 9], [5, 11, 4], [11, 10, 2], [10, 7, 6], [7, 1, 8], 38 | [3, 9, 4], [3, 4, 2], [3, 2, 6], [3, 6, 8], [3, 8, 9], 39 | [4, 9, 5], [2, 4, 11], [6, 2, 10], [8, 6, 7], [9, 8, 1] 40 | ] 41 | 42 | return verts, faces 43 | 44 | def createEdgeIndex(index1, index2, totalVerts): 45 | if index1 > index2: 46 | auxVal = index1 47 | index1 = index2 48 | index2 = auxVal 49 | index1 *= totalVerts 50 | outIndex = index1 + index2 51 | 
return outIndex 52 | 53 | def subdivide(verts, faces): 54 | triangles = len(faces) 55 | edgeMap = dict([]) 56 | currLength = len(verts) 57 | for faceIndex in xrange(triangles): 58 | face = faces[faceIndex] 59 | v0 = verts[face[0]] 60 | v1 = verts[face[1]] 61 | v2 = verts[face[2]] 62 | 63 | v3EdgeIndex = createEdgeIndex(face[0], face[1], currLength) 64 | v3Index = -1 65 | if v3EdgeIndex in edgeMap: 66 | v3Index = edgeMap[v3EdgeIndex] 67 | else: 68 | newVert = np.array([(v0[0]+v1[0])*0.5, (v0[1]+v1[1])*0.5, (v0[2]+v1[2])*0.5]) 69 | length = np.linalg.norm(newVert) 70 | verts.append([newVert[0]/length, newVert[1]/length, newVert[2]/length]) 71 | edgeMap[v3EdgeIndex] = len(verts) - 1 72 | v3Index = len(verts) - 1 73 | 74 | v4EdgeIndex = createEdgeIndex(face[1], face[2], currLength) 75 | v4Index = -1 76 | if v4EdgeIndex in edgeMap: 77 | v4Index = edgeMap[v4EdgeIndex] 78 | else: 79 | newVert = np.array([(v1[0]+v2[0])*0.5, (v1[1]+v2[1])*0.5, (v1[2]+v2[2])*0.5]) 80 | length = np.linalg.norm(newVert) 81 | verts.append([newVert[0]/length, newVert[1]/length, newVert[2]/length]) 82 | edgeMap[v4EdgeIndex] = len(verts) - 1 83 | v4Index = len(verts) - 1 84 | 85 | v5EdgeIndex = createEdgeIndex(face[0], face[2], currLength) 86 | v5Index = -1 87 | if v5EdgeIndex in edgeMap: 88 | v5Index = edgeMap[v5EdgeIndex] 89 | else: 90 | newVert = np.array([(v0[0]+v2[0])*0.5, (v0[1]+v2[1])*0.5, (v0[2]+v2[2])*0.5]) 91 | length = np.linalg.norm(newVert) 92 | verts.append([newVert[0]/length, newVert[1]/length, newVert[2]/length]) 93 | edgeMap[v5EdgeIndex] = len(verts) - 1 94 | v5Index = len(verts) - 1 95 | 96 | faces.append([v3Index, v4Index, v5Index]) 97 | faces.append([face[0], v3Index, v5Index]) 98 | faces.append([v3Index, face[1], v4Index]) 99 | faces[faceIndex] = [v5Index, v4Index, face[2]] 100 | 101 | return verts, faces 102 | 103 | def load_model(modelsPath): 104 | points = [] 105 | colors = [] 106 | with open(modelsPath, 'r') as modelFile: 107 | for line in modelFile: 108 | line = 
line.replace("\n", "") 109 | currPoint = line.split(',') 110 | points.append([float(currPoint[0]), float(currPoint[1]), float(currPoint[2])]) 111 | colors.append([int(currPoint[3]), int(currPoint[4]), int(currPoint[5])]) 112 | return points, colors 113 | 114 | def save_model_ply(modelName, points, colors, sphPts, sphFaces, sphScale): 115 | coordMax = np.amax(points, axis=0) 116 | coordMin = np.amin(points, axis=0) 117 | aabbSize = (1.0/np.amax(coordMax - coordMin))*sphScale 118 | 119 | newModelName = modelName[:-4]+"_spheres.ply" 120 | with open(newModelName, 'w') as myFile: 121 | myFile.write("ply\n") 122 | myFile.write("format ascii 1.0\n") 123 | myFile.write("element vertex "+ str(len(sphPts)*len(points))+"\n") 124 | myFile.write("property float x\n") 125 | myFile.write("property float y\n") 126 | myFile.write("property float z\n") 127 | myFile.write("property uchar red\n") 128 | myFile.write("property uchar green\n") 129 | myFile.write("property uchar blue\n") 130 | myFile.write("element face "+ str(len(sphFaces)*len(points))+"\n") 131 | myFile.write("property list uchar int vertex_index\n") 132 | myFile.write("end_header\n") 133 | 134 | for point, color in zip(points, colors): 135 | for currSphPt in sphPts: 136 | currPtFlt = [aabbSize*currSphPt[0]+point[0], aabbSize*currSphPt[1]+point[1], aabbSize*currSphPt[2]+point[2]] 137 | myFile.write(str(currPtFlt[0])+" "+str(currPtFlt[1])+" "+str(currPtFlt[2])+" "+str(color[0])+" "+ str(color[1])+ " "+str(color[2])+"\n") 138 | 139 | offset = 0 140 | for i in range(len(points)): 141 | for currSphFace in sphFaces: 142 | myFile.write("3 "+str(currSphFace[0]+offset)+" "+str(currSphFace[1]+offset)+" "+str(currSphFace[2]+offset)+"\n") 143 | offset += len(sphPts) 144 | 145 | myFile.close() 146 | 147 | 148 | if __name__ == '__main__': 149 | 150 | parser = argparse.ArgumentParser(description='Script to generate a 3D model with a sphere for each point.') 151 | parser.add_argument('--inFolder', default='SphereModels', help='Folder 
of the input/output models (default: SphereModels)') 152 | parser.add_argument('--sphSub', default=2, type=int, help='Number of subdivisions applied to the sphere models (default: 2)') 153 | parser.add_argument('--sphScaling', default=0.005, type=float, help='Scaling applied to the sphere models (default: 0.005)') 154 | args = parser.parse_args() 155 | 156 | files = [f for f in listdir(args.inFolder+"/") if isfile(join(args.inFolder+"/", f))] 157 | 158 | sphPts, sphFaces = icosahedron() 159 | for i in range(args.sphSub): 160 | sphPts, sphFaces = subdivide(sphPts, sphFaces) 161 | 162 | iter = 0 163 | for currFile in files: 164 | points, colors = load_model(args.inFolder+"/"+currFile) 165 | save_model_ply(args.inFolder+"/"+currFile, points, colors, sphPts, sphFaces, args.sphScaling) 166 | visualize_progress(iter, len(files)) 167 | iter += 1 168 | -------------------------------------------------------------------------------- /utils/MCNetworkUtils.py: -------------------------------------------------------------------------------- 1 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 2 | \file MCNetworkUtils.py 3 | 4 | \brief Helper functions to build neural networks. 5 | 6 | \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 7 | Germany. See the LICENSE file at the top-level directory of 8 | this distribution. 9 | 10 | \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 11 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 12 | 13 | import tensorflow as tf 14 | import os 15 | import sys 16 | import math 17 | 18 | ############################################################################# Network Utils 19 | 20 | def MLP_2_hidden(features, numInputFeatures, hidden1_units, hidden2_units, numOutFeatures, 21 | layerName, keepProb, isTraining, useDropOut = False, useInitBN = True): 22 | """Method to create the graph of a MLP of two hidden layers. 
23 | 24 | Args: 25 | features (nxm tensor): Input features. 26 | numInputFeatures (int): Number of input features. 27 | hidden1_units (int): Number of units in the first hidden layer. 28 | hidden2_units (int): Number of units in the second hidden layer. 29 | numOutFeatures (int): Number of output features. 30 | layerName (string): Name of the MLP. 31 | keepProb (tensor): Tensor with the probability to maintain a input in the MLP. 32 | isTraining (tensor): Tensor with a boolean that indicates if the MLP is executed 33 | in a training mode or not. 34 | useDropOut (bool): Boolean that indicates if dropout should be used in the MLP. 35 | useInitBN (bool): Boolean that indicates if an initial batch normalization should be used. 36 | """ 37 | 38 | initializer = tf.contrib.layers.variance_scaling_initializer(factor=1.0, mode='FAN_AVG', uniform=True) 39 | initializerBiases = tf.zeros_initializer() 40 | 41 | if useInitBN: 42 | features = tf.layers.batch_normalization(inputs = features, training = isTraining, name = layerName+"_BN_Init") 43 | 44 | # Hidden 1 45 | weights = tf.get_variable(layerName+'_weights1', [numInputFeatures, hidden1_units], initializer=initializer) 46 | tf.add_to_collection('weight_decay_loss', weights) 47 | biases = tf.get_variable(layerName+'_biases1', [hidden1_units], initializer=initializerBiases) 48 | mul1 = tf.matmul(features, weights) + biases 49 | mul1 = tf.layers.batch_normalization(inputs = mul1, training = isTraining, name = layerName+"_BN_h1") 50 | hidden1 = tf.nn.relu(mul1) 51 | 52 | # Hidden 2 53 | if useDropOut: 54 | hidden1 = tf.nn.dropout(hidden1, keepProb) 55 | weights = tf.get_variable(layerName+'_weights2', [hidden1_units, hidden2_units]) 56 | tf.add_to_collection('weight_decay_loss', weights) 57 | biases = tf.get_variable(layerName+'_biases2', [hidden2_units], initializer=initializerBiases) 58 | mul2 = tf.matmul(hidden1, weights) + biases 59 | mul2 = tf.layers.batch_normalization(inputs = mul2, training = isTraining, name = 
layerName+"_BN_h2") 60 | hidden2 = tf.nn.relu(mul2) 61 | 62 | # Linear 63 | if useDropOut: 64 | hidden2 = tf.nn.dropout(hidden2, keepProb) 65 | weights = tf.get_variable(layerName+'_weights3', [hidden2_units, numOutFeatures], initializer=initializer) 66 | tf.add_to_collection('weight_decay_loss', weights) 67 | biases = tf.get_variable(layerName+'_biases3', [numOutFeatures], initializer=initializerBiases) 68 | logits = tf.matmul(hidden2, weights) + biases 69 | return logits 70 | 71 | 72 | def MLP_1_hidden(features, numInputFeatures, hidden_units, numOutFeatures, layerName, 73 | keepProb, isTraining, useDropOut = False): 74 | """Method to create the graph of a MLP of one hidden layers. 75 | 76 | Args: 77 | features (nxm tensor): Input features. 78 | numInputFeatures (int): Number of input features. 79 | hidden_units (int): Number of units in the hidden layer. 80 | numOutFeatures (int): Number of output features. 81 | layerName (string): Name of the MLP. 82 | keepProb (tensor): Tensor with the probability to maintain a input in the MLP. 83 | isTraining (tensor): Tensor with a boolean that indicates if the MLP is executed 84 | in a training mode or not. 85 | useDropOut (bool): Boolean that indicates if dropout should be used in the MLP. 
86 | """ 87 | 88 | initializer = tf.contrib.layers.variance_scaling_initializer(factor=1.0, mode='FAN_AVG', uniform=True) 89 | initializerBiases = tf.zeros_initializer() 90 | 91 | # Hidden 1 92 | weights = tf.get_variable(layerName+'_weights1', [numInputFeatures, hidden_units], initializer=initializer) 93 | tf.add_to_collection('weight_decay_loss', weights) 94 | biases = tf.get_variable(layerName+'_biases1', [hidden_units], initializer=initializerBiases) 95 | mul = tf.matmul(features, weights) + biases 96 | mul = tf.layers.batch_normalization(inputs = mul, training = isTraining, name = layerName+"_BN_h") 97 | hidden = tf.nn.relu(mul) 98 | 99 | # Linear 100 | if useDropOut: 101 | hidden = tf.nn.dropout(hidden, keepProb) 102 | weights = tf.get_variable(layerName+'_weights2', [hidden_units, numOutFeatures], initializer=initializer) 103 | tf.add_to_collection('weight_decay_loss', weights) 104 | biases = tf.get_variable(layerName+'_biases2', [numOutFeatures], initializer=initializerBiases) 105 | linear = tf.matmul(hidden, weights) + biases 106 | return linear 107 | 108 | 109 | def conv_1x1(layerName, inputs, numInputs, numOutFeatures): 110 | """Method to create a fully connected layer to compute a new set of features 111 | by combining the input features. 112 | 113 | Args: 114 | layerName (string): Name of the layer. 115 | inputs (nxm tensor): Input features. 116 | numInputs (int): Number of input features. 117 | numOutFeatures (int): Number of output features. 
118 | """ 119 | 120 | initializer = tf.contrib.layers.variance_scaling_initializer(factor=1.0, mode='FAN_AVG', uniform=True) 121 | initializerBiases = tf.zeros_initializer() 122 | weights = tf.get_variable(layerName+'_weights', [numInputs, numOutFeatures], initializer=initializer) 123 | tf.add_to_collection('weight_decay_loss', weights) 124 | biases = tf.get_variable(layerName+'_biases', [numOutFeatures], initializer=initializerBiases) 125 | reducedOutput = tf.matmul(inputs, weights) + biases 126 | return reducedOutput 127 | 128 | 129 | def batch_norm_RELU_drop_out(layerName, inFeatures, isTraining, usedDropOut, keepProb): 130 | """Method to create a combination of layers: Batch norm + RELU + Drop out. 131 | 132 | Args: 133 | layerName (string): Name of the layer. 134 | inFeatures (nxm tensor): Input features. 135 | isTraining (tensor): Tensor with a boolean that indicates if the MLP is executed 136 | in a training mode or not. 137 | useDropOut (bool): Boolean that indicates if dropout should be used in the MLP. 138 | keepProb (tensor): Tensor with the probability to maintain a input in the MLP. 139 | """ 140 | inFeatures = tf.layers.batch_normalization(inputs = inFeatures, training = isTraining, name = layerName+"_BN") 141 | inFeatures = tf.nn.relu(inFeatures) 142 | if usedDropOut: 143 | inFeatures = tf.nn.dropout(inFeatures, keepProb) 144 | return inFeatures 145 | -------------------------------------------------------------------------------- /utils/PyUtils.py: -------------------------------------------------------------------------------- 1 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 2 | \file PyUtils.py 3 | 4 | \brief File with sevaral utils for python applications. 5 | 6 | \copyright Copyright (c) 2018 Visual Computing group of Ulm University, 7 | Germany. See the LICENSE file at the top-level directory of 8 | this distribution. 
9 | 10 | \author pedro hermosilla (pedro-1.hermosilla-casajus@uni-ulm.de) 11 | ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' 12 | 13 | import sys 14 | 15 | def visualize_progress(val, maxVal, description="", barWidth=20): 16 | """Method to visualize the progress of a process in the console. 17 | 18 | Args: 19 | val (int): Current step in the process. 20 | maxVal (int): Maximum numbef of step of the process. 21 | description (string): String to be displayed at the current step. 22 | barWidth (int): Size of the progress bar displayed. 23 | """ 24 | 25 | progress = int((val*barWidth) / maxVal) 26 | progressBar = ['='] * (progress) + ['>'] + ['.'] * (barWidth - (progress+1)) 27 | progressBar = ''.join(progressBar) 28 | initBar = "%5d/%5d" % (val + 1, maxVal) 29 | print(initBar + ' [' + progressBar + '] ' + description) 30 | sys.stdout.flush() 31 | 32 | def save_model(modelName, points, labels = None, colors = None, modLabel = 0): 33 | """Method to save a model into a txt file. 34 | 35 | Args: 36 | modelName (string): Path of the model to be saved. 37 | points (nx3 np.array): List of points of the model. 38 | labels (nxm np.array): List of point labels. 39 | colors (nx3 array): Color associated to each label. If None is provided, 40 | the method will save the labels instead. 41 | modLabel (int): Integer value that will be used to apply the mod operation 42 | to each label. 
43 | """ 44 | 45 | with open(modelName+".txt", 'w') as myFile: 46 | for it, point in enumerate(points): 47 | 48 | myFile.write(str(point[0])+",") 49 | myFile.write(str(point[1])+",") 50 | myFile.write(str(point[2])) 51 | 52 | if not(labels is None): 53 | if not(colors is None): 54 | currLabel = int(labels[it][0]) 55 | if modLabel > 0: 56 | currLabel = currLabel%modLabel 57 | currColor = colors[currLabel] 58 | myFile.write(","+str(currColor[0])+",") 59 | myFile.write(str(currColor[1])+",") 60 | myFile.write(str(currColor[2])) 61 | else: 62 | currLabels = labels[it] 63 | for label in currLabels: 64 | myFile.write(","+str(label)) 65 | myFile.write("\n") 66 | 67 | myFile.close() --------------------------------------------------------------------------------