├── .gitignore ├── LICENSE ├── README.md ├── doc └── teaser.png ├── evaluate.py ├── models ├── pointnet_cls.py ├── pointnet_cls_basic.py ├── pointnet_seg.py └── transform_nets.py ├── part_seg ├── download_data.sh ├── pointnet_part_seg.py ├── test.py ├── testing_ply_file_list.txt └── train.py ├── provider.py ├── sem_seg ├── README.md ├── batch_inference.py ├── collect_indoor3d_data.py ├── download_data.sh ├── eval_iou_accuracy.py ├── gen_indoor3d_h5.py ├── indoor3d_util.py ├── meta │ ├── all_data_label.txt │ ├── anno_paths.txt │ ├── area6_data_label.txt │ └── class_names.txt ├── model.py └── train.py ├── train.py ├── train_pytorch.py └── utils ├── .DS_Store ├── data_prep_util.py ├── data_utils.py ├── eulerangles.py ├── model.py ├── pc_util.py ├── plyfile.py ├── tf_util.py ├── util_funcs.py └── util_layers.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Hanxiao Deng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PointCNN.PyTorch 2 | This is a PyTorch implementation of [PointCNN](https://github.com/yangyanli/PointCNN). It is as efficient as the original TensorFlow implementation and achieves the same accuracy on both classification and segmentation tasks. See the following reference for more information: 3 | ``` 4 | "PointCNN" 5 | Yangyan Li, Rui Bu, Mingchao Sun, Baoquan Chen 6 | arXiv preprint arXiv:1801.07791, 2018. 7 | ``` 8 | [https://arxiv.org/abs/1801.07791](https://arxiv.org/abs/1801.07791) 9 | 10 | 11 | # Usage 12 | We have tested the code on ModelNet40 only. 13 | 14 | ```bash 15 | python train_pytorch.py 16 | ``` 17 | 18 | # License 19 | Our code is released under the MIT License (see the LICENSE file for details). 20 | -------------------------------------------------------------------------------- /doc/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hxdengBerkeley/PointCNN.Pytorch/6ec6c291cf97923a84fb6ed8c82e98bf01e7e96d/doc/teaser.png -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import argparse 4 | import socket 5 | import importlib 6 | import time 7 | import os 8 | import scipy.misc 9 | import sys 10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 11 | sys.path.append(BASE_DIR) 12 | sys.path.append(os.path.join(BASE_DIR, 'models')) 13 | sys.path.append(os.path.join(BASE_DIR, 'utils')) 14 | import provider 15 | import pc_util 16 | 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]') 20 | parser.add_argument('--model', default='pointnet_cls', help='Model name: pointnet_cls or pointnet_cls_basic [default: pointnet_cls]') 21 | parser.add_argument('--batch_size', type=int, default=4, help='Batch Size during evaluation [default: 4]') 22 | parser.add_argument('--num_point', type=int, default=1024, help='Point Number [256/512/1024/2048] [default: 1024]') 23 | parser.add_argument('--model_path', default='log/model.ckpt', help='model checkpoint file path [default: log/model.ckpt]') 24 | parser.add_argument('--dump_dir', default='dump', help='dump folder path [default: dump]') 25 | parser.add_argument('--visu', action='store_true', help='Whether to dump images for error cases [default: False]') 26 | FLAGS = parser.parse_args() 27 | 28 | 29 | BATCH_SIZE = FLAGS.batch_size 30 | NUM_POINT = FLAGS.num_point 31 | MODEL_PATH = FLAGS.model_path 32 | GPU_INDEX = FLAGS.gpu 33 | MODEL = importlib.import_module(FLAGS.model) # import network module 34 | DUMP_DIR = FLAGS.dump_dir 35 | if not os.path.exists(DUMP_DIR): os.mkdir(DUMP_DIR) 36 | LOG_FOUT = open(os.path.join(DUMP_DIR, 'log_evaluate.txt'), 'w') 37 |
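# Illustrative invocation (values are examples, not requirements; all flag
# names come from the parser above):
#   python evaluate.py --model pointnet_cls --batch_size 4 --num_point 1024 \
#       --model_path log/model.ckpt --dump_dir dump --visu
# The parsed FLAGS are echoed into dump/log_evaluate.txt below so each run's
# configuration is recorded alongside its results.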
LOG_FOUT.write(str(FLAGS)+'\n') 38 | 39 | NUM_CLASSES = 40 40 | SHAPE_NAMES = [line.rstrip() for line in \ 41 | open(os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/shape_names.txt'))] 42 | 43 | HOSTNAME = socket.gethostname() 44 | 45 | # ModelNet40 official train/test split 46 | TRAIN_FILES = provider.getDataFiles( \ 47 | os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/train_files.txt')) 48 | TEST_FILES = provider.getDataFiles(\ 49 | os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/test_files.txt')) 50 | 51 | def log_string(out_str): 52 | LOG_FOUT.write(out_str+'\n') 53 | LOG_FOUT.flush() 54 | print(out_str) 55 | 56 | def evaluate(num_votes): 57 | is_training = False 58 | 59 | with tf.device('/gpu:'+str(GPU_INDEX)): 60 | pointclouds_pl, labels_pl = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT) 61 | is_training_pl = tf.placeholder(tf.bool, shape=()) 62 | 63 | # simple model 64 | pred, end_points = MODEL.get_model(pointclouds_pl, is_training_pl) 65 | loss = MODEL.get_loss(pred, labels_pl, end_points) 66 | 67 | # Add ops to save and restore all the variables. 68 | saver = tf.train.Saver() 69 | 70 | # Create a session 71 | config = tf.ConfigProto() 72 | config.gpu_options.allow_growth = True 73 | config.allow_soft_placement = True 74 | config.log_device_placement = True 75 | sess = tf.Session(config=config) 76 | 77 | # Restore variables from disk. 78 | saver.restore(sess, MODEL_PATH) 79 | log_string("Model restored.") 80 | 81 | ops = {'pointclouds_pl': pointclouds_pl, 82 | 'labels_pl': labels_pl, 83 | 'is_training_pl': is_training_pl, 84 | 'pred': pred, 85 | 'loss': loss} 86 | 87 | eval_one_epoch(sess, ops, num_votes) 88 | 89 | 90 | def eval_one_epoch(sess, ops, num_votes=1, topk=1): 91 | error_cnt = 0 92 | is_training = False 93 | total_correct = 0 94 | total_seen = 0 95 | loss_sum = 0 96 | total_seen_class = [0 for _ in range(NUM_CLASSES)] 97 | total_correct_class = [0 for _ in range(NUM_CLASSES)] 98 | fout = open(os.path.join(DUMP_DIR, 'pred_label.txt'), 'w') 99 | for fn in range(len(TEST_FILES)): 100 | log_string('----'+str(fn)+'----') 101 | current_data, current_label = provider.loadDataFile(TEST_FILES[fn]) 102 | current_data = current_data[:,0:NUM_POINT,:] 103 | current_label = np.squeeze(current_label) 104 | print(current_data.shape) 105 | 106 | file_size = current_data.shape[0] 107 | num_batches = file_size // BATCH_SIZE 108 | print(file_size) 109 | 110 | for batch_idx in range(num_batches): 111 | start_idx = batch_idx * BATCH_SIZE 112 | end_idx = (batch_idx+1) * BATCH_SIZE 113 | cur_batch_size = end_idx - start_idx 114 | 115 | # Aggregating BEG 116 | batch_loss_sum = 0 # sum of losses for the batch 117 | batch_pred_sum = np.zeros((cur_batch_size, NUM_CLASSES)) # score for classes 118 | batch_pred_classes = np.zeros((cur_batch_size, NUM_CLASSES)) # 0/1 for classes 119 | for vote_idx in range(num_votes): 120 | rotated_data = provider.rotate_point_cloud_by_angle(current_data[start_idx:end_idx, :, :], 121 | vote_idx/float(num_votes) * np.pi * 2) 122 | feed_dict = {ops['pointclouds_pl']: rotated_data, 123 | ops['labels_pl']: current_label[start_idx:end_idx], 124 | ops['is_training_pl']: is_training} 125 | loss_val, pred_val = sess.run([ops['loss'], ops['pred']], 126 | feed_dict=feed_dict) 127 | batch_pred_sum += pred_val 128 | batch_pred_val = np.argmax(pred_val, 1) 129 | for el_idx in range(cur_batch_size): 130 | batch_pred_classes[el_idx, batch_pred_val[el_idx]] += 1 131 | batch_loss_sum += (loss_val * cur_batch_size / float(num_votes)) 132 | # pred_val_topk = 
np.argsort(batch_pred_sum, axis=-1)[:,-1*np.array(range(topk))-1] 133 | # pred_val = np.argmax(batch_pred_classes, 1) 134 | pred_val = np.argmax(batch_pred_sum, 1) 135 | # Aggregating END 136 | 137 | correct = np.sum(pred_val == current_label[start_idx:end_idx]) 138 | # correct = np.sum(pred_val_topk[:,0:topk] == label_val) 139 | total_correct += correct 140 | total_seen += cur_batch_size 141 | loss_sum += batch_loss_sum 142 | 143 | for i in range(start_idx, end_idx): 144 | l = current_label[i] 145 | total_seen_class[l] += 1 146 | total_correct_class[l] += (pred_val[i-start_idx] == l) 147 | fout.write('%d, %d\n' % (pred_val[i-start_idx], l)) 148 | 149 | if pred_val[i-start_idx] != l and FLAGS.visu: # ERROR CASE, DUMP! 150 | img_filename = '%d_label_%s_pred_%s.jpg' % (error_cnt, SHAPE_NAMES[l], 151 | SHAPE_NAMES[pred_val[i-start_idx]]) 152 | img_filename = os.path.join(DUMP_DIR, img_filename) 153 | output_img = pc_util.point_cloud_three_views(np.squeeze(current_data[i, :, :])) 154 | scipy.misc.imsave(img_filename, output_img) 155 | error_cnt += 1 156 | 157 | log_string('eval mean loss: %f' % (loss_sum / float(total_seen))) 158 | log_string('eval accuracy: %f' % (total_correct / float(total_seen))) 159 | log_string('eval avg class acc: %f' % (np.mean(np.array(total_correct_class)/np.array(total_seen_class,dtype=np.float)))) 160 | 161 | class_accuracies = np.array(total_correct_class)/np.array(total_seen_class,dtype=np.float) 162 | for i, name in enumerate(SHAPE_NAMES): 163 | log_string('%10s:\t%0.3f' % (name, class_accuracies[i])) 164 | 165 | 166 | 167 | if __name__=='__main__': 168 | with tf.Graph().as_default(): 169 | evaluate(num_votes=1) 170 | LOG_FOUT.close() 171 | -------------------------------------------------------------------------------- /models/pointnet_cls.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import math 4 | import sys 5 | import os 6 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 7 | sys.path.append(BASE_DIR) 8 | sys.path.append(os.path.join(BASE_DIR, '../utils')) 9 | import tf_util 10 | from transform_nets import input_transform_net, feature_transform_net 11 | 12 | def placeholder_inputs(batch_size, num_point): 13 | pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3)) 14 | labels_pl = tf.placeholder(tf.int32, shape=(batch_size)) 15 | return pointclouds_pl, labels_pl 16 | 17 | 18 | def get_model(point_cloud, is_training, bn_decay=None): 19 | """ Classification PointNet, input is BxNx3, output Bx40 """ 20 | batch_size = point_cloud.get_shape()[0].value 21 | num_point = point_cloud.get_shape()[1].value 22 | end_points = {} 23 | 24 | with tf.variable_scope('transform_net1') as sc: 25 | transform = input_transform_net(point_cloud, is_training, bn_decay, K=3) 26 | point_cloud_transformed = tf.matmul(point_cloud, transform) 27 | input_image = tf.expand_dims(point_cloud_transformed, -1) 28 | 29 | net = tf_util.conv2d(input_image, 64, [1,3], 30 | padding='VALID', stride=[1,1], 31 | bn=True, is_training=is_training, 32 | scope='conv1', bn_decay=bn_decay) 33 | net = tf_util.conv2d(net, 64, [1,1], 34 | padding='VALID', stride=[1,1], 35 | bn=True, is_training=is_training, 36 | scope='conv2', bn_decay=bn_decay) 37 | 38 | with tf.variable_scope('transform_net2') as sc: 39 | transform = feature_transform_net(net, is_training, bn_decay, K=64) 40 | end_points['transform'] = transform 41 | net_transformed = tf.matmul(tf.squeeze(net, axis=[2]), transform) 42 | 
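# Shape walk-through (B = batch, N = points): `net` leaves conv2 as BxNx1x64,
# squeezing axis 2 gives BxNx64, the matmul applies the learned 64x64 feature
# transform per point, and the expand_dims below restores the BxNx1x64 layout
# that tf_util.conv2d expects.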
net_transformed = tf.expand_dims(net_transformed, [2]) 43 | 44 | net = tf_util.conv2d(net_transformed, 64, [1,1], 45 | padding='VALID', stride=[1,1], 46 | bn=True, is_training=is_training, 47 | scope='conv3', bn_decay=bn_decay) 48 | net = tf_util.conv2d(net, 128, [1,1], 49 | padding='VALID', stride=[1,1], 50 | bn=True, is_training=is_training, 51 | scope='conv4', bn_decay=bn_decay) 52 | net = tf_util.conv2d(net, 1024, [1,1], 53 | padding='VALID', stride=[1,1], 54 | bn=True, is_training=is_training, 55 | scope='conv5', bn_decay=bn_decay) 56 | 57 | # Symmetric function: max pooling 58 | net = tf_util.max_pool2d(net, [num_point,1], 59 | padding='VALID', scope='maxpool') 60 | 61 | net = tf.reshape(net, [batch_size, -1]) 62 | net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training, 63 | scope='fc1', bn_decay=bn_decay) 64 | net = tf_util.dropout(net, keep_prob=0.7, is_training=is_training, 65 | scope='dp1') 66 | net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, 67 | scope='fc2', bn_decay=bn_decay) 68 | net = tf_util.dropout(net, keep_prob=0.7, is_training=is_training, 69 | scope='dp2') 70 | net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc3') 71 | 72 | return net, end_points 73 | 74 | 75 | def get_loss(pred, label, end_points, reg_weight=0.001): 76 | """ pred: B*NUM_CLASSES, 77 | label: B, """ 78 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=pred, labels=label) 79 | classify_loss = tf.reduce_mean(loss) 80 | tf.summary.scalar('classify loss', classify_loss) 81 | 82 | # Enforce the transformation as orthogonal matrix 83 | transform = end_points['transform'] # BxKxK 84 | K = transform.get_shape()[1].value 85 | mat_diff = tf.matmul(transform, tf.transpose(transform, perm=[0,2,1])) 86 | mat_diff -= tf.constant(np.eye(K), dtype=tf.float32) 87 | mat_diff_loss = tf.nn.l2_loss(mat_diff) 88 | tf.summary.scalar('mat loss', mat_diff_loss) 89 | 90 | return classify_loss + mat_diff_loss * reg_weight 91 | 92 | 93 | if __name__=='__main__': 94 | with tf.Graph().as_default(): 95 | inputs = tf.zeros((32,1024,3)) 96 | outputs = get_model(inputs, tf.constant(True)) 97 | print(outputs) 98 | -------------------------------------------------------------------------------- /models/pointnet_cls_basic.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import math 4 | import sys 5 | import os 6 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 7 | sys.path.append(BASE_DIR) 8 | sys.path.append(os.path.join(BASE_DIR, '../utils')) 9 | import tf_util 10 | 11 | def placeholder_inputs(batch_size, num_point): 12 | pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3)) 13 | labels_pl = tf.placeholder(tf.int32, shape=(batch_size)) 14 | return pointclouds_pl, labels_pl 15 | 16 | 17 | def get_model(point_cloud, is_training, bn_decay=None): 18 | """ Classification PointNet, input is BxNx3, output Bx40 """ 19 | batch_size = point_cloud.get_shape()[0].value 20 | num_point = point_cloud.get_shape()[1].value 21 | end_points = {} 22 | input_image = tf.expand_dims(point_cloud, -1) 23 | 24 | # Point functions (MLP implemented as conv2d) 25 | net = tf_util.conv2d(input_image, 64, [1,3], 26 | padding='VALID', stride=[1,1], 27 | bn=True, is_training=is_training, 28 | scope='conv1', bn_decay=bn_decay) 29 | net = tf_util.conv2d(net, 64, [1,1], 30 | padding='VALID', stride=[1,1], 31 | bn=True, is_training=is_training, 32 | scope='conv2', 
bn_decay=bn_decay) 33 | 34 | net = tf_util.conv2d(net, 64, [1,1], 35 | padding='VALID', stride=[1,1], 36 | bn=True, is_training=is_training, 37 | scope='conv3', bn_decay=bn_decay) 38 | net = tf_util.conv2d(net, 128, [1,1], 39 | padding='VALID', stride=[1,1], 40 | bn=True, is_training=is_training, 41 | scope='conv4', bn_decay=bn_decay) 42 | net = tf_util.conv2d(net, 1024, [1,1], 43 | padding='VALID', stride=[1,1], 44 | bn=True, is_training=is_training, 45 | scope='conv5', bn_decay=bn_decay) 46 | 47 | # Symmetric function: max pooling 48 | net = tf_util.max_pool2d(net, [num_point,1], 49 | padding='VALID', scope='maxpool') 50 | 51 | # MLP on global point cloud vector 52 | net = tf.reshape(net, [batch_size, -1]) 53 | net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training, 54 | scope='fc1', bn_decay=bn_decay) 55 | net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, 56 | scope='fc2', bn_decay=bn_decay) 57 | net = tf_util.dropout(net, keep_prob=0.7, is_training=is_training, 58 | scope='dp1') 59 | net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc3') 60 | 61 | return net, end_points 62 | 63 | 64 | def get_loss(pred, label, end_points): 65 | """ pred: B*NUM_CLASSES, 66 | label: B, """ 67 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=pred, labels=label) 68 | classify_loss = tf.reduce_mean(loss) 69 | tf.summary.scalar('classify loss', classify_loss) 70 | return classify_loss 71 | 72 | 73 | if __name__=='__main__': 74 | with tf.Graph().as_default(): 75 | inputs = tf.zeros((32,1024,3)) 76 | outputs = get_model(inputs, tf.constant(True)) 77 | print(outputs) 78 | -------------------------------------------------------------------------------- /models/pointnet_seg.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import math 4 | import sys 5 | import os 6 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 7 | sys.path.append(BASE_DIR) 8 | sys.path.append(os.path.join(BASE_DIR, '../utils')) 9 | import tf_util 10 | from transform_nets import input_transform_net, feature_transform_net 11 | 12 | def placeholder_inputs(batch_size, num_point): 13 | pointclouds_pl = tf.placeholder(tf.float32, 14 | shape=(batch_size, num_point, 3)) 15 | labels_pl = tf.placeholder(tf.int32, 16 | shape=(batch_size, num_point)) 17 | return pointclouds_pl, labels_pl 18 | 19 | 20 | def get_model(point_cloud, is_training, bn_decay=None): 21 | """ Segmentation PointNet, input is BxNx3, output BxNx50 (per-point logits) """ 22 | batch_size = point_cloud.get_shape()[0].value 23 | num_point = point_cloud.get_shape()[1].value 24 | end_points = {} 25 | 26 | with tf.variable_scope('transform_net1') as sc: 27 | transform = input_transform_net(point_cloud, is_training, bn_decay, K=3) 28 | point_cloud_transformed = tf.matmul(point_cloud, transform) 29 | input_image = tf.expand_dims(point_cloud_transformed, -1) 30 | 31 | net = tf_util.conv2d(input_image, 64, [1,3], 32 | padding='VALID', stride=[1,1], 33 | bn=True, is_training=is_training, 34 | scope='conv1', bn_decay=bn_decay) 35 | net = tf_util.conv2d(net, 64, [1,1], 36 | padding='VALID', stride=[1,1], 37 | bn=True, is_training=is_training, 38 | scope='conv2', bn_decay=bn_decay) 39 | 40 | with tf.variable_scope('transform_net2') as sc: 41 | transform = feature_transform_net(net, is_training, bn_decay, K=64) 42 | end_points['transform'] = transform 43 | net_transformed = tf.matmul(tf.squeeze(net, axis=[2]), transform) 44 | point_feat = 
tf.expand_dims(net_transformed, [2]) 45 | print(point_feat) 46 | 47 | net = tf_util.conv2d(point_feat, 64, [1,1], 48 | padding='VALID', stride=[1,1], 49 | bn=True, is_training=is_training, 50 | scope='conv3', bn_decay=bn_decay) 51 | net = tf_util.conv2d(net, 128, [1,1], 52 | padding='VALID', stride=[1,1], 53 | bn=True, is_training=is_training, 54 | scope='conv4', bn_decay=bn_decay) 55 | net = tf_util.conv2d(net, 1024, [1,1], 56 | padding='VALID', stride=[1,1], 57 | bn=True, is_training=is_training, 58 | scope='conv5', bn_decay=bn_decay) 59 | global_feat = tf_util.max_pool2d(net, [num_point,1], 60 | padding='VALID', scope='maxpool') 61 | print(global_feat) 62 | 63 | global_feat_expand = tf.tile(global_feat, [1, num_point, 1, 1]) 64 | concat_feat = tf.concat(axis=3, values=[point_feat, global_feat_expand]) 65 | print(concat_feat) 66 | 67 | net = tf_util.conv2d(concat_feat, 512, [1,1], 68 | padding='VALID', stride=[1,1], 69 | bn=True, is_training=is_training, 70 | scope='conv6', bn_decay=bn_decay) 71 | net = tf_util.conv2d(net, 256, [1,1], 72 | padding='VALID', stride=[1,1], 73 | bn=True, is_training=is_training, 74 | scope='conv7', bn_decay=bn_decay) 75 | net = tf_util.conv2d(net, 128, [1,1], 76 | padding='VALID', stride=[1,1], 77 | bn=True, is_training=is_training, 78 | scope='conv8', bn_decay=bn_decay) 79 | net = tf_util.conv2d(net, 128, [1,1], 80 | padding='VALID', stride=[1,1], 81 | bn=True, is_training=is_training, 82 | scope='conv9', bn_decay=bn_decay) 83 | 84 | net = tf_util.conv2d(net, 50, [1,1], 85 | padding='VALID', stride=[1,1], activation_fn=None, 86 | scope='conv10') 87 | net = tf.squeeze(net, [2]) # BxNxC 88 | 89 | return net, end_points 90 | 91 | 92 | def get_loss(pred, label, end_points, reg_weight=0.001): 93 | """ pred: BxNxC, 94 | label: BxN, """ 95 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=pred, labels=label) 96 | classify_loss = tf.reduce_mean(loss) 97 | tf.summary.scalar('classify loss', classify_loss) 98 | 99 | # Enforce the transformation as orthogonal matrix 100 | transform = end_points['transform'] # BxKxK 101 | K = transform.get_shape()[1].value 102 | mat_diff = tf.matmul(transform, tf.transpose(transform, perm=[0,2,1])) 103 | mat_diff -= tf.constant(np.eye(K), dtype=tf.float32) 104 | mat_diff_loss = tf.nn.l2_loss(mat_diff) 105 | tf.summary.scalar('mat_loss', mat_diff_loss) 106 | 107 | return classify_loss + mat_diff_loss * reg_weight 108 | 109 | 110 | if __name__=='__main__': 111 | with tf.Graph().as_default(): 112 | inputs = tf.zeros((32,1024,3)) 113 | outputs = get_model(inputs, tf.constant(True)) 114 | print(outputs) 115 | -------------------------------------------------------------------------------- /models/transform_nets.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import sys 4 | import os 5 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 6 | sys.path.append(BASE_DIR) 7 | sys.path.append(os.path.join(BASE_DIR, '../utils')) 8 | import tf_util 9 | 10 | def input_transform_net(point_cloud, is_training, bn_decay=None, K=3): 11 | """ Input (XYZ) Transform Net, input is BxNx3 gray image 12 | Return: 13 | Transformation matrix of size 3xK """ 14 | batch_size = point_cloud.get_shape()[0].value 15 | num_point = point_cloud.get_shape()[1].value 16 | 17 | input_image = tf.expand_dims(point_cloud, -1) 18 | net = tf_util.conv2d(input_image, 64, [1,3], 19 | padding='VALID', stride=[1,1], 20 | bn=True, is_training=is_training, 21 | scope='tconv1', 
bn_decay=bn_decay) 22 | net = tf_util.conv2d(net, 128, [1,1], 23 | padding='VALID', stride=[1,1], 24 | bn=True, is_training=is_training, 25 | scope='tconv2', bn_decay=bn_decay) 26 | net = tf_util.conv2d(net, 1024, [1,1], 27 | padding='VALID', stride=[1,1], 28 | bn=True, is_training=is_training, 29 | scope='tconv3', bn_decay=bn_decay) 30 | net = tf_util.max_pool2d(net, [num_point,1], 31 | padding='VALID', scope='tmaxpool') 32 | 33 | net = tf.reshape(net, [batch_size, -1]) 34 | net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training, 35 | scope='tfc1', bn_decay=bn_decay) 36 | net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, 37 | scope='tfc2', bn_decay=bn_decay) 38 | 39 | with tf.variable_scope('transform_XYZ') as sc: 40 | assert(K==3) 41 | weights = tf.get_variable('weights', [256, 3*K], 42 | initializer=tf.constant_initializer(0.0), 43 | dtype=tf.float32) 44 | biases = tf.get_variable('biases', [3*K], 45 | initializer=tf.constant_initializer(0.0), 46 | dtype=tf.float32) 47 | biases += tf.constant([1,0,0,0,1,0,0,0,1], dtype=tf.float32) 48 | transform = tf.matmul(net, weights) 49 | transform = tf.nn.bias_add(transform, biases) 50 | 51 | transform = tf.reshape(transform, [batch_size, 3, K]) 52 | return transform 53 | 54 | 55 | def feature_transform_net(inputs, is_training, bn_decay=None, K=64): 56 | """ Feature Transform Net, input is BxNx1xK 57 | Return: 58 | Transformation matrix of size KxK """ 59 | batch_size = inputs.get_shape()[0].value 60 | num_point = inputs.get_shape()[1].value 61 | 62 | net = tf_util.conv2d(inputs, 64, [1,1], 63 | padding='VALID', stride=[1,1], 64 | bn=True, is_training=is_training, 65 | scope='tconv1', bn_decay=bn_decay) 66 | net = tf_util.conv2d(net, 128, [1,1], 67 | padding='VALID', stride=[1,1], 68 | bn=True, is_training=is_training, 69 | scope='tconv2', bn_decay=bn_decay) 70 | net = tf_util.conv2d(net, 1024, [1,1], 71 | padding='VALID', stride=[1,1], 72 | bn=True, is_training=is_training, 73 | scope='tconv3', bn_decay=bn_decay) 74 | net = tf_util.max_pool2d(net, [num_point,1], 75 | padding='VALID', scope='tmaxpool') 76 | 77 | net = tf.reshape(net, [batch_size, -1]) 78 | net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training, 79 | scope='tfc1', bn_decay=bn_decay) 80 | net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, 81 | scope='tfc2', bn_decay=bn_decay) 82 | 83 | with tf.variable_scope('transform_feat') as sc: 84 | weights = tf.get_variable('weights', [256, K*K], 85 | initializer=tf.constant_initializer(0.0), 86 | dtype=tf.float32) 87 | biases = tf.get_variable('biases', [K*K], 88 | initializer=tf.constant_initializer(0.0), 89 | dtype=tf.float32) 90 | biases += tf.constant(np.eye(K).flatten(), dtype=tf.float32) 91 | transform = tf.matmul(net, weights) 92 | transform = tf.nn.bias_add(transform, biases) 93 | 94 | transform = tf.reshape(transform, [batch_size, K, K]) 95 | return transform 96 | -------------------------------------------------------------------------------- /part_seg/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Download original ShapeNetPart dataset (around 1GB) 4 | wget https://shapenet.cs.stanford.edu/ericyi/shapenetcore_partanno_v0.zip 5 | unzip shapenetcore_partanno_v0.zip 6 | rm shapenetcore_partanno_v0.zip 7 | 8 | # Download HDF5 for ShapeNet Part segmentation (around 346MB) 9 | wget https://shapenet.cs.stanford.edu/media/shapenet_part_seg_hdf5_data.zip 10 | unzip 
shapenet_part_seg_hdf5_data.zip 11 | rm shapenet_part_seg_hdf5_data.zip 12 | 13 | -------------------------------------------------------------------------------- /part_seg/pointnet_part_seg.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import math 4 | import os 5 | import sys 6 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 7 | sys.path.append(os.path.dirname(BASE_DIR)) 8 | sys.path.append(os.path.join(BASE_DIR, '../utils')) 9 | import tf_util 10 | 11 | 12 | def get_transform_K(inputs, is_training, bn_decay=None, K = 3): 13 | """ Transform Net, input is BxNx1xK gray image 14 | Return: 15 | Transformation matrix of size KxK """ 16 | batch_size = inputs.get_shape()[0].value 17 | num_point = inputs.get_shape()[1].value 18 | 19 | net = tf_util.conv2d(inputs, 256, [1,1], padding='VALID', stride=[1,1], 20 | bn=True, is_training=is_training, scope='tconv1', bn_decay=bn_decay) 21 | net = tf_util.conv2d(net, 1024, [1,1], padding='VALID', stride=[1,1], 22 | bn=True, is_training=is_training, scope='tconv2', bn_decay=bn_decay) 23 | net = tf_util.max_pool2d(net, [num_point,1], padding='VALID', scope='tmaxpool') 24 | 25 | net = tf.reshape(net, [batch_size, -1]) 26 | net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training, scope='tfc1', bn_decay=bn_decay) 27 | net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, scope='tfc2', bn_decay=bn_decay) 28 | 29 | with tf.variable_scope('transform_feat') as sc: 30 | weights = tf.get_variable('weights', [256, K*K], initializer=tf.constant_initializer(0.0), dtype=tf.float32) 31 | biases = tf.get_variable('biases', [K*K], initializer=tf.constant_initializer(0.0), dtype=tf.float32) + tf.constant(np.eye(K).flatten(), dtype=tf.float32) 32 | transform = tf.matmul(net, weights) 33 | transform = tf.nn.bias_add(transform, biases) 34 | 35 | #transform = tf_util.fully_connected(net, 3*K, activation_fn=None, scope='tfc3') 36 | transform = tf.reshape(transform, [batch_size, K, K]) 37 | return transform 38 | 39 | 40 | 41 | 42 | 43 | def get_transform(point_cloud, is_training, bn_decay=None, K = 3): 44 | """ Transform Net, input is BxNx3 gray image 45 | Return: 46 | Transformation matrix of size 3xK """ 47 | batch_size = point_cloud.get_shape()[0].value 48 | num_point = point_cloud.get_shape()[1].value 49 | 50 | input_image = tf.expand_dims(point_cloud, -1) 51 | net = tf_util.conv2d(input_image, 64, [1,3], padding='VALID', stride=[1,1], 52 | bn=True, is_training=is_training, scope='tconv1', bn_decay=bn_decay) 53 | net = tf_util.conv2d(net, 128, [1,1], padding='VALID', stride=[1,1], 54 | bn=True, is_training=is_training, scope='tconv3', bn_decay=bn_decay) 55 | net = tf_util.conv2d(net, 1024, [1,1], padding='VALID', stride=[1,1], 56 | bn=True, is_training=is_training, scope='tconv4', bn_decay=bn_decay) 57 | net = tf_util.max_pool2d(net, [num_point,1], padding='VALID', scope='tmaxpool') 58 | 59 | net = tf.reshape(net, [batch_size, -1]) 60 | net = tf_util.fully_connected(net, 128, bn=True, is_training=is_training, scope='tfc1', bn_decay=bn_decay) 61 | net = tf_util.fully_connected(net, 128, bn=True, is_training=is_training, scope='tfc2', bn_decay=bn_decay) 62 | 63 | with tf.variable_scope('transform_XYZ') as sc: 64 | assert(K==3) 65 | weights = tf.get_variable('weights', [128, 3*K], initializer=tf.constant_initializer(0.0), dtype=tf.float32) 66 | biases = tf.get_variable('biases', [3*K], initializer=tf.constant_initializer(0.0), 
dtype=tf.float32) + tf.constant([1,0,0,0,1,0,0,0,1], dtype=tf.float32) 67 | transform = tf.matmul(net, weights) 68 | transform = tf.nn.bias_add(transform, biases) 69 | 70 | #transform = tf_util.fully_connected(net, 3*K, activation_fn=None, scope='tfc3') 71 | transform = tf.reshape(transform, [batch_size, 3, K]) 72 | return transform 73 | 74 | 75 | def get_model(point_cloud, input_label, is_training, cat_num, part_num, \ 76 | batch_size, num_point, weight_decay, bn_decay=None): 77 | """ ConvNet baseline, input is BxNx3 gray image """ 78 | end_points = {} 79 | 80 | with tf.variable_scope('transform_net1') as sc: 81 | K = 3 82 | transform = get_transform(point_cloud, is_training, bn_decay, K = 3) 83 | point_cloud_transformed = tf.matmul(point_cloud, transform) 84 | 85 | input_image = tf.expand_dims(point_cloud_transformed, -1) 86 | out1 = tf_util.conv2d(input_image, 64, [1,K], padding='VALID', stride=[1,1], 87 | bn=True, is_training=is_training, scope='conv1', bn_decay=bn_decay) 88 | out2 = tf_util.conv2d(out1, 128, [1,1], padding='VALID', stride=[1,1], 89 | bn=True, is_training=is_training, scope='conv2', bn_decay=bn_decay) 90 | out3 = tf_util.conv2d(out2, 128, [1,1], padding='VALID', stride=[1,1], 91 | bn=True, is_training=is_training, scope='conv3', bn_decay=bn_decay) 92 | 93 | 94 | with tf.variable_scope('transform_net2') as sc: 95 | K = 128 96 | transform = get_transform_K(out3, is_training, bn_decay, K) 97 | 98 | end_points['transform'] = transform 99 | 100 | squeezed_out3 = tf.reshape(out3, [batch_size, num_point, 128]) 101 | net_transformed = tf.matmul(squeezed_out3, transform) 102 | net_transformed = tf.expand_dims(net_transformed, [2]) 103 | 104 | out4 = tf_util.conv2d(net_transformed, 512, [1,1], padding='VALID', stride=[1,1], 105 | bn=True, is_training=is_training, scope='conv4', bn_decay=bn_decay) 106 | out5 = tf_util.conv2d(out4, 2048, [1,1], padding='VALID', stride=[1,1], 107 | bn=True, is_training=is_training, scope='conv5', bn_decay=bn_decay) 108 | out_max = tf_util.max_pool2d(out5, [num_point,1], padding='VALID', scope='maxpool') 109 | 110 | # classification network 111 | net = tf.reshape(out_max, [batch_size, -1]) 112 | net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, scope='cla/fc1', bn_decay=bn_decay) 113 | net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, scope='cla/fc2', bn_decay=bn_decay) 114 | net = tf_util.dropout(net, keep_prob=0.7, is_training=is_training, scope='cla/dp1') 115 | net = tf_util.fully_connected(net, cat_num, activation_fn=None, scope='cla/fc3') 116 | 117 | # segmentation network 118 | one_hot_label_expand = tf.reshape(input_label, [batch_size, 1, 1, cat_num]) 119 | out_max = tf.concat(axis=3, values=[out_max, one_hot_label_expand]) 120 | 121 | expand = tf.tile(out_max, [1, num_point, 1, 1]) 122 | concat = tf.concat(axis=3, values=[expand, out1, out2, out3, out4, out5]) 123 | 124 | net2 = tf_util.conv2d(concat, 256, [1,1], padding='VALID', stride=[1,1], bn_decay=bn_decay, 125 | bn=True, is_training=is_training, scope='seg/conv1', weight_decay=weight_decay) 126 | net2 = tf_util.dropout(net2, keep_prob=0.8, is_training=is_training, scope='seg/dp1') 127 | net2 = tf_util.conv2d(net2, 256, [1,1], padding='VALID', stride=[1,1], bn_decay=bn_decay, 128 | bn=True, is_training=is_training, scope='seg/conv2', weight_decay=weight_decay) 129 | net2 = tf_util.dropout(net2, keep_prob=0.8, is_training=is_training, scope='seg/dp2') 130 | net2 = tf_util.conv2d(net2, 128, [1,1], padding='VALID', stride=[1,1], 
bn_decay=bn_decay, 131 | bn=True, is_training=is_training, scope='seg/conv3', weight_decay=weight_decay) 132 | net2 = tf_util.conv2d(net2, part_num, [1,1], padding='VALID', stride=[1,1], activation_fn=None, 133 | bn=False, scope='seg/conv4', weight_decay=weight_decay) 134 | 135 | net2 = tf.reshape(net2, [batch_size, num_point, part_num]) 136 | 137 | return net, net2, end_points 138 | 139 | def get_loss(l_pred, seg_pred, label, seg, weight, end_points): 140 | per_instance_label_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=l_pred, labels=label) 141 | label_loss = tf.reduce_mean(per_instance_label_loss) 142 | 143 | # size of seg_pred is batch_size x point_num x part_cat_num 144 | # size of seg is batch_size x point_num 145 | per_instance_seg_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=seg_pred, labels=seg), axis=1) 146 | seg_loss = tf.reduce_mean(per_instance_seg_loss) 147 | 148 | per_instance_seg_pred_res = tf.argmax(seg_pred, 2) 149 | 150 | # Enforce the transformation as orthogonal matrix 151 | transform = end_points['transform'] # BxKxK 152 | K = transform.get_shape()[1].value 153 | mat_diff = tf.matmul(transform, tf.transpose(transform, perm=[0,2,1])) - tf.constant(np.eye(K), dtype=tf.float32) 154 | mat_diff_loss = tf.nn.l2_loss(mat_diff) 155 | 156 | 157 | total_loss = weight * seg_loss + (1 - weight) * label_loss + mat_diff_loss * 1e-3 158 | 159 | return total_loss, label_loss, per_instance_label_loss, seg_loss, per_instance_seg_loss, per_instance_seg_pred_res 160 | 161 | -------------------------------------------------------------------------------- /part_seg/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import tensorflow as tf 3 | import json 4 | import numpy as np 5 | import os 6 | import sys 7 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 8 | sys.path.append(BASE_DIR) 9 | sys.path.append(os.path.dirname(BASE_DIR)) 10 | import provider 11 | import pointnet_part_seg as model 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--model_path', default='train_results/trained_models/epoch_190.ckpt', help='Model checkpoint path') 15 | FLAGS = parser.parse_args() 16 | 17 | 18 | # DEFAULT SETTINGS 19 | pretrained_model_path = FLAGS.model_path # os.path.join(BASE_DIR, './pretrained_model/model.ckpt') 20 | hdf5_data_dir = os.path.join(BASE_DIR, './hdf5_data') 21 | ply_data_dir = os.path.join(BASE_DIR, './PartAnnotation') 22 | gpu_to_use = 0 23 | output_dir = os.path.join(BASE_DIR, './test_results') 24 | output_verbose = True # If true, output all color-coded part segmentation obj files 25 | 26 | # MAIN SCRIPT 27 | point_num = 3000 # the max number of points in the all testing data shapes 28 | batch_size = 1 29 | 30 | test_file_list = os.path.join(BASE_DIR, 'testing_ply_file_list.txt') 31 | 32 | oid2cpid = json.load(open(os.path.join(hdf5_data_dir, 'overallid_to_catid_partid.json'), 'r')) 33 | 34 | object2setofoid = {} 35 | for idx in range(len(oid2cpid)): 36 | objid, pid = oid2cpid[idx] 37 | if not objid in object2setofoid.keys(): 38 | object2setofoid[objid] = [] 39 | object2setofoid[objid].append(idx) 40 | 41 | all_obj_cat_file = os.path.join(hdf5_data_dir, 'all_object_categories.txt') 42 | fin = open(all_obj_cat_file, 'r') 43 | lines = [line.rstrip() for line in fin.readlines()] 44 | objcats = [line.split()[1] for line in lines] 45 | objnames = [line.split()[0] for line in lines] 46 | on2oid = {objcats[i]:i for i in range(len(objcats))} 47 | 
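# Each line of all_object_categories.txt is assumed to pair a human-readable
# name with a ShapeNet synset id, e.g. "Airplane 02691156" (illustrative
# values): objnames keeps the names, objcats the ids, and on2oid maps an id
# back to its row index.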
fin.close() 48 | 49 | color_map_file = os.path.join(hdf5_data_dir, 'part_color_mapping.json') 50 | color_map = json.load(open(color_map_file, 'r')) 51 | 52 | NUM_OBJ_CATS = 16 53 | NUM_PART_CATS = 50 54 | 55 | cpid2oid = json.load(open(os.path.join(hdf5_data_dir, 'catid_partid_to_overallid.json'), 'r')) 56 | 57 | def printout(flog, data): 58 | print(data) 59 | flog.write(data + '\n') 60 | 61 | def output_color_point_cloud(data, seg, out_file): 62 | with open(out_file, 'w') as f: 63 | l = len(seg) 64 | for i in range(l): 65 | color = color_map[seg[i]] 66 | f.write('v %f %f %f %f %f %f\n' % (data[i][0], data[i][1], data[i][2], color[0], color[1], color[2])) 67 | 68 | def output_color_point_cloud_red_blue(data, seg, out_file): 69 | with open(out_file, 'w') as f: 70 | l = len(seg) 71 | for i in range(l): 72 | if seg[i] == 1: 73 | color = [0, 0, 1] 74 | elif seg[i] == 0: 75 | color = [1, 0, 0] 76 | else: 77 | color = [0, 0, 0] 78 | 79 | f.write('v %f %f %f %f %f %f\n' % (data[i][0], data[i][1], data[i][2], color[0], color[1], color[2])) 80 | 81 | 82 | def pc_normalize(pc): 83 | l = pc.shape[0] 84 | centroid = np.mean(pc, axis=0) 85 | pc = pc - centroid 86 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 87 | pc = pc / m 88 | return pc 89 | 90 | def placeholder_inputs(): 91 | pointclouds_ph = tf.placeholder(tf.float32, shape=(batch_size, point_num, 3)) 92 | input_label_ph = tf.placeholder(tf.float32, shape=(batch_size, NUM_OBJ_CATS)) 93 | return pointclouds_ph, input_label_ph 94 | 95 | def output_color_point_cloud(data, seg, out_file): 96 | with open(out_file, 'w') as f: 97 | l = len(seg) 98 | for i in range(l): 99 | color = color_map[seg[i]] 100 | f.write('v %f %f %f %f %f %f\n' % (data[i][0], data[i][1], data[i][2], color[0], color[1], color[2])) 101 | 102 | def load_pts_seg_files(pts_file, seg_file, catid): 103 | with open(pts_file, 'r') as f: 104 | pts_str = [item.rstrip() for item in f.readlines()] 105 | pts = np.array([np.float32(s.split()) for s in pts_str], dtype=np.float32) 106 | with open(seg_file, 'r') as f: 107 | part_ids = np.array([int(item.rstrip()) for item in f.readlines()], dtype=np.uint8) 108 | seg = np.array([cpid2oid[catid+'_'+str(x)] for x in part_ids]) 109 | return pts, seg 110 | 111 | def pc_augment_to_point_num(pts, pn): 112 | assert(pts.shape[0] <= pn) 113 | cur_len = pts.shape[0] 114 | res = np.array(pts) 115 | while cur_len < pn: 116 | res = np.concatenate((res, pts)) 117 | cur_len += pts.shape[0] 118 | return res[:pn, :] 119 | 120 | def convert_label_to_one_hot(labels): 121 | label_one_hot = np.zeros((labels.shape[0], NUM_OBJ_CATS)) 122 | for idx in range(labels.shape[0]): 123 | label_one_hot[idx, labels[idx]] = 1 124 | return label_one_hot 125 | 126 | def predict(): 127 | is_training = False 128 | 129 | with tf.device('/gpu:'+str(gpu_to_use)): 130 | pointclouds_ph, input_label_ph = placeholder_inputs() 131 | is_training_ph = tf.placeholder(tf.bool, shape=()) 132 | 133 | # simple model 134 | pred, seg_pred, end_points = model.get_model(pointclouds_ph, input_label_ph, \ 135 | cat_num=NUM_OBJ_CATS, part_num=NUM_PART_CATS, is_training=is_training_ph, \ 136 | batch_size=batch_size, num_point=point_num, weight_decay=0.0, bn_decay=None) 137 | 138 | # Add ops to save and restore all the variables. 139 | saver = tf.train.Saver() 140 | 141 | # Later, launch the model, use the saver to restore variables from disk, and 142 | # do some work with the model. 
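# Session knobs for shared machines: allow_growth stops TensorFlow from
# reserving all GPU memory up front, and allow_soft_placement lets ops without
# a GPU kernel silently fall back to the CPU.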
143 | 144 | config = tf.ConfigProto() 145 | config.gpu_options.allow_growth = True 146 | config.allow_soft_placement = True 147 | 148 | with tf.Session(config=config) as sess: 149 | if not os.path.exists(output_dir): 150 | os.mkdir(output_dir) 151 | 152 | flog = open(os.path.join(output_dir, 'log.txt'), 'w') 153 | 154 | # Restore variables from disk. 155 | printout(flog, 'Loading model %s' % pretrained_model_path) 156 | saver.restore(sess, pretrained_model_path) 157 | printout(flog, 'Model restored.') 158 | 159 | # Note: the evaluation for the model with BN has to have some statistics 160 | # Using some test datas as the statistics 161 | batch_data = np.zeros([batch_size, point_num, 3]).astype(np.float32) 162 | 163 | total_acc = 0.0 164 | total_seen = 0 165 | total_acc_iou = 0.0 166 | 167 | total_per_cat_acc = np.zeros((NUM_OBJ_CATS)).astype(np.float32) 168 | total_per_cat_iou = np.zeros((NUM_OBJ_CATS)).astype(np.float32) 169 | total_per_cat_seen = np.zeros((NUM_OBJ_CATS)).astype(np.int32) 170 | 171 | ffiles = open(test_file_list, 'r') 172 | lines = [line.rstrip() for line in ffiles.readlines()] 173 | pts_files = [line.split()[0] for line in lines] 174 | seg_files = [line.split()[1] for line in lines] 175 | labels = [line.split()[2] for line in lines] 176 | ffiles.close() 177 | 178 | len_pts_files = len(pts_files) 179 | for shape_idx in range(len_pts_files): 180 | if shape_idx % 100 == 0: 181 | printout(flog, '%d/%d ...' % (shape_idx, len_pts_files)) 182 | 183 | cur_gt_label = on2oid[labels[shape_idx]] 184 | 185 | cur_label_one_hot = np.zeros((1, NUM_OBJ_CATS), dtype=np.float32) 186 | cur_label_one_hot[0, cur_gt_label] = 1 187 | 188 | pts_file_to_load = os.path.join(ply_data_dir, pts_files[shape_idx]) 189 | seg_file_to_load = os.path.join(ply_data_dir, seg_files[shape_idx]) 190 | 191 | pts, seg = load_pts_seg_files(pts_file_to_load, seg_file_to_load, objcats[cur_gt_label]) 192 | ori_point_num = len(seg) 193 | 194 | batch_data[0, ...] = pc_augment_to_point_num(pc_normalize(pts), point_num) 195 | 196 | label_pred_val, seg_pred_res = sess.run([pred, seg_pred], feed_dict={ 197 | pointclouds_ph: batch_data, 198 | input_label_ph: cur_label_one_hot, 199 | is_training_ph: is_training, 200 | }) 201 | 202 | label_pred_val = np.argmax(label_pred_val[0, :]) 203 | 204 | seg_pred_res = seg_pred_res[0, ...] 
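# Restrict the prediction to the parts of the ground-truth category: the
# logits of every other part label are pushed 1000 below the per-shape minimum
# so argmax can never select them.  The per-part IoU computed below is
#   n_intersect / (n_pred + n_gt - n_intersect),
# with the convention that a part absent from both prediction and ground truth
# scores IoU = 1.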
205 | 206 | iou_oids = object2setofoid[objcats[cur_gt_label]] 207 | non_cat_labels = list(set(np.arange(NUM_PART_CATS)).difference(set(iou_oids))) 208 | 209 | mini = np.min(seg_pred_res) 210 | seg_pred_res[:, non_cat_labels] = mini - 1000 211 | 212 | seg_pred_val = np.argmax(seg_pred_res, axis=1)[:ori_point_num] 213 | 214 | seg_acc = np.mean(seg_pred_val == seg) 215 | 216 | total_acc += seg_acc 217 | total_seen += 1 218 | 219 | total_per_cat_seen[cur_gt_label] += 1 220 | total_per_cat_acc[cur_gt_label] += seg_acc 221 | 222 | mask = np.int32(seg_pred_val == seg) 223 | 224 | total_iou = 0.0 225 | iou_log = '' 226 | for oid in iou_oids: 227 | n_pred = np.sum(seg_pred_val == oid) 228 | n_gt = np.sum(seg == oid) 229 | n_intersect = np.sum(np.int32(seg == oid) * mask) 230 | n_union = n_pred + n_gt - n_intersect 231 | iou_log += '_' + str(n_pred)+'_'+str(n_gt)+'_'+str(n_intersect)+'_'+str(n_union)+'_' 232 | if n_union == 0: 233 | total_iou += 1 234 | iou_log += '_1\n' 235 | else: 236 | total_iou += n_intersect * 1.0 / n_union 237 | iou_log += '_'+str(n_intersect * 1.0 / n_union)+'\n' 238 | 239 | avg_iou = total_iou / len(iou_oids) 240 | total_acc_iou += avg_iou 241 | total_per_cat_iou[cur_gt_label] += avg_iou 242 | 243 | if output_verbose: 244 | output_color_point_cloud(pts, seg, os.path.join(output_dir, str(shape_idx)+'_gt.obj')) 245 | output_color_point_cloud(pts, seg_pred_val, os.path.join(output_dir, str(shape_idx)+'_pred.obj')) 246 | output_color_point_cloud_red_blue(pts, np.int32(seg == seg_pred_val), 247 | os.path.join(output_dir, str(shape_idx)+'_diff.obj')) 248 | 249 | with open(os.path.join(output_dir, str(shape_idx)+'.log'), 'w') as fout: 250 | fout.write('Total Point: %d\n\n' % ori_point_num) 251 | fout.write('Ground Truth: %s\n' % objnames[cur_gt_label]) 252 | fout.write('Predict: %s\n\n' % objnames[label_pred_val]) 253 | fout.write('Accuracy: %f\n' % seg_acc) 254 | fout.write('IoU: %f\n\n' % avg_iou) 255 | fout.write('IoU details: %s\n' % iou_log) 256 | 257 | printout(flog, 'Accuracy: %f' % (total_acc / total_seen)) 258 | printout(flog, 'IoU: %f' % (total_acc_iou / total_seen)) 259 | 260 | for cat_idx in range(NUM_OBJ_CATS): 261 | printout(flog, '\t ' + objcats[cat_idx] + ' Total Number: ' + str(total_per_cat_seen[cat_idx])) 262 | if total_per_cat_seen[cat_idx] > 0: 263 | printout(flog, '\t ' + objcats[cat_idx] + ' Accuracy: ' + \ 264 | str(total_per_cat_acc[cat_idx] / total_per_cat_seen[cat_idx])) 265 | printout(flog, '\t ' + objcats[cat_idx] + ' IoU: '+ \ 266 | str(total_per_cat_iou[cat_idx] / total_per_cat_seen[cat_idx])) 267 | 268 | 269 | with tf.Graph().as_default(): 270 | predict() 271 | -------------------------------------------------------------------------------- /part_seg/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | import tensorflow as tf 4 | import numpy as np 5 | from datetime import datetime 6 | import json 7 | import os 8 | import sys 9 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 10 | sys.path.append(BASE_DIR) 11 | sys.path.append(os.path.dirname(BASE_DIR)) 12 | import provider 13 | import pointnet_part_seg as model 14 | 15 | # DEFAULT SETTINGS 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--gpu', type=int, default=1, help='GPU to use [default: 1]') 18 | parser.add_argument('--batch', type=int, default=32, help='Batch Size during training [default: 32]') 19 | parser.add_argument('--epoch', type=int, default=200, help='Epoch to run [default: 
200]') 20 | parser.add_argument('--point_num', type=int, default=2048, help='Point Number [256/512/1024/2048]') 21 | parser.add_argument('--output_dir', type=str, default='train_results', help='Directory that stores all training logs and trained models') 22 | parser.add_argument('--wd', type=float, default=0, help='Weight Decay [default: 0.0]') 23 | FLAGS = parser.parse_args() 24 | 25 | hdf5_data_dir = os.path.join(BASE_DIR, './hdf5_data') 26 | 27 | # MAIN SCRIPT 28 | point_num = FLAGS.point_num 29 | batch_size = FLAGS.batch 30 | output_dir = FLAGS.output_dir 31 | 32 | if not os.path.exists(output_dir): 33 | os.mkdir(output_dir) 34 | 35 | color_map_file = os.path.join(hdf5_data_dir, 'part_color_mapping.json') 36 | color_map = json.load(open(color_map_file, 'r')) 37 | 38 | all_obj_cats_file = os.path.join(hdf5_data_dir, 'all_object_categories.txt') 39 | fin = open(all_obj_cats_file, 'r') 40 | lines = [line.rstrip() for line in fin.readlines()] 41 | all_obj_cats = [(line.split()[0], line.split()[1]) for line in lines] 42 | fin.close() 43 | 44 | all_cats = json.load(open(os.path.join(hdf5_data_dir, 'overallid_to_catid_partid.json'), 'r')) 45 | NUM_CATEGORIES = 16 46 | NUM_PART_CATS = len(all_cats) 47 | 48 | print('#### Batch Size: {0}'.format(batch_size)) 49 | print('#### Point Number: {0}'.format(point_num)) 50 | print('#### Training using GPU: {0}'.format(FLAGS.gpu)) 51 | 52 | DECAY_STEP = 16881 * 20 53 | DECAY_RATE = 0.5 54 | 55 | LEARNING_RATE_CLIP = 1e-5 56 | 57 | BN_INIT_DECAY = 0.5 58 | BN_DECAY_DECAY_RATE = 0.5 59 | BN_DECAY_DECAY_STEP = float(DECAY_STEP * 2) 60 | BN_DECAY_CLIP = 0.99 61 | 62 | BASE_LEARNING_RATE = 0.001 63 | MOMENTUM = 0.9 64 | TRAINING_EPOCHES = FLAGS.epoch 65 | print('### Training epoch: {0}'.format(TRAINING_EPOCHES)) 66 | 67 | TRAINING_FILE_LIST = os.path.join(hdf5_data_dir, 'train_hdf5_file_list.txt') 68 | TESTING_FILE_LIST = os.path.join(hdf5_data_dir, 'val_hdf5_file_list.txt') 69 | 70 | MODEL_STORAGE_PATH = os.path.join(output_dir, 'trained_models') 71 | if not os.path.exists(MODEL_STORAGE_PATH): 72 | os.mkdir(MODEL_STORAGE_PATH) 73 | 74 | LOG_STORAGE_PATH = os.path.join(output_dir, 'logs') 75 | if not os.path.exists(LOG_STORAGE_PATH): 76 | os.mkdir(LOG_STORAGE_PATH) 77 | 78 | SUMMARIES_FOLDER = os.path.join(output_dir, 'summaries') 79 | if not os.path.exists(SUMMARIES_FOLDER): 80 | os.mkdir(SUMMARIES_FOLDER) 81 | 82 | def printout(flog, data): 83 | print(data) 84 | flog.write(data + '\n') 85 | 86 | def placeholder_inputs(): 87 | pointclouds_ph = tf.placeholder(tf.float32, shape=(batch_size, point_num, 3)) 88 | input_label_ph = tf.placeholder(tf.float32, shape=(batch_size, NUM_CATEGORIES)) 89 | labels_ph = tf.placeholder(tf.int32, shape=(batch_size)) 90 | seg_ph = tf.placeholder(tf.int32, shape=(batch_size, point_num)) 91 | return pointclouds_ph, input_label_ph, labels_ph, seg_ph 92 | 93 | def convert_label_to_one_hot(labels): 94 | label_one_hot = np.zeros((labels.shape[0], NUM_CATEGORIES)) 95 | for idx in range(labels.shape[0]): 96 | label_one_hot[idx, labels[idx]] = 1 97 | return label_one_hot 98 | 99 | def train(): 100 | with tf.Graph().as_default(): 101 | with tf.device('/gpu:'+str(FLAGS.gpu)): 102 | pointclouds_ph, input_label_ph, labels_ph, seg_ph = placeholder_inputs() 103 | is_training_ph = tf.placeholder(tf.bool, shape=()) 104 | 105 | batch = tf.Variable(0, trainable=False) 106 | learning_rate = tf.train.exponential_decay( 107 | BASE_LEARNING_RATE, # base learning rate 108 | batch * batch_size, # global_var indicating the number of steps 109 | 
DECAY_STEP, # step size 110 | DECAY_RATE, # decay rate 111 | staircase=True # Stair-case or continuous decreasing 112 | ) 113 | learning_rate = tf.maximum(learning_rate, LEARNING_RATE_CLIP) 114 | 115 | bn_momentum = tf.train.exponential_decay( 116 | BN_INIT_DECAY, 117 | batch*batch_size, 118 | BN_DECAY_DECAY_STEP, 119 | BN_DECAY_DECAY_RATE, 120 | staircase=True) 121 | bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum) 122 | 123 | lr_op = tf.summary.scalar('learning_rate', learning_rate) 124 | batch_op = tf.summary.scalar('batch_number', batch) 125 | bn_decay_op = tf.summary.scalar('bn_decay', bn_decay) 126 | 127 | labels_pred, seg_pred, end_points = model.get_model(pointclouds_ph, input_label_ph, \ 128 | is_training=is_training_ph, bn_decay=bn_decay, cat_num=NUM_CATEGORIES, \ 129 | part_num=NUM_PART_CATS, batch_size=batch_size, num_point=point_num, weight_decay=FLAGS.wd) 130 | 131 | # model.py defines both classification net and segmentation net, which share the common global feature extractor network. 132 | # In model.get_loss, we define the total loss to be weighted sum of the classification and segmentation losses. 133 | # Here, we only train for segmentation network. Thus, we set weight to be 1.0. 134 | loss, label_loss, per_instance_label_loss, seg_loss, per_instance_seg_loss, per_instance_seg_pred_res \ 135 | = model.get_loss(labels_pred, seg_pred, labels_ph, seg_ph, 1.0, end_points) 136 | 137 | total_training_loss_ph = tf.placeholder(tf.float32, shape=()) 138 | total_testing_loss_ph = tf.placeholder(tf.float32, shape=()) 139 | 140 | label_training_loss_ph = tf.placeholder(tf.float32, shape=()) 141 | label_testing_loss_ph = tf.placeholder(tf.float32, shape=()) 142 | 143 | seg_training_loss_ph = tf.placeholder(tf.float32, shape=()) 144 | seg_testing_loss_ph = tf.placeholder(tf.float32, shape=()) 145 | 146 | label_training_acc_ph = tf.placeholder(tf.float32, shape=()) 147 | label_testing_acc_ph = tf.placeholder(tf.float32, shape=()) 148 | label_testing_acc_avg_cat_ph = tf.placeholder(tf.float32, shape=()) 149 | 150 | seg_training_acc_ph = tf.placeholder(tf.float32, shape=()) 151 | seg_testing_acc_ph = tf.placeholder(tf.float32, shape=()) 152 | seg_testing_acc_avg_cat_ph = tf.placeholder(tf.float32, shape=()) 153 | 154 | total_train_loss_sum_op = tf.summary.scalar('total_training_loss', total_training_loss_ph) 155 | total_test_loss_sum_op = tf.summary.scalar('total_testing_loss', total_testing_loss_ph) 156 | 157 | label_train_loss_sum_op = tf.summary.scalar('label_training_loss', label_training_loss_ph) 158 | label_test_loss_sum_op = tf.summary.scalar('label_testing_loss', label_testing_loss_ph) 159 | 160 | seg_train_loss_sum_op = tf.summary.scalar('seg_training_loss', seg_training_loss_ph) 161 | seg_test_loss_sum_op = tf.summary.scalar('seg_testing_loss', seg_testing_loss_ph) 162 | 163 | label_train_acc_sum_op = tf.summary.scalar('label_training_acc', label_training_acc_ph) 164 | label_test_acc_sum_op = tf.summary.scalar('label_testing_acc', label_testing_acc_ph) 165 | label_test_acc_avg_cat_op = tf.summary.scalar('label_testing_acc_avg_cat', label_testing_acc_avg_cat_ph) 166 | 167 | seg_train_acc_sum_op = tf.summary.scalar('seg_training_acc', seg_training_acc_ph) 168 | seg_test_acc_sum_op = tf.summary.scalar('seg_testing_acc', seg_testing_acc_ph) 169 | seg_test_acc_avg_cat_op = tf.summary.scalar('seg_testing_acc_avg_cat', seg_testing_acc_avg_cat_ph) 170 | 171 | train_variables = tf.trainable_variables() 172 | 173 | trainer = tf.train.AdamOptimizer(learning_rate) 174 | train_op = 
trainer.minimize(loss, var_list=train_variables, global_step=batch) 175 | 176 | saver = tf.train.Saver() 177 | 178 | config = tf.ConfigProto() 179 | config.gpu_options.allow_growth = True 180 | config.allow_soft_placement = True 181 | sess = tf.Session(config=config) 182 | 183 | init = tf.global_variables_initializer() 184 | sess.run(init) 185 | 186 | train_writer = tf.summary.FileWriter(SUMMARIES_FOLDER + '/train', sess.graph) 187 | test_writer = tf.summary.FileWriter(SUMMARIES_FOLDER + '/test') 188 | 189 | train_file_list = provider.getDataFiles(TRAINING_FILE_LIST) 190 | num_train_file = len(train_file_list) 191 | test_file_list = provider.getDataFiles(TESTING_FILE_LIST) 192 | num_test_file = len(test_file_list) 193 | 194 | fcmd = open(os.path.join(LOG_STORAGE_PATH, 'cmd.txt'), 'w') 195 | fcmd.write(str(FLAGS)) 196 | fcmd.close() 197 | 198 | # write logs to the disk 199 | flog = open(os.path.join(LOG_STORAGE_PATH, 'log.txt'), 'w') 200 | 201 | def train_one_epoch(train_file_idx, epoch_num): 202 | is_training = True 203 | 204 | for i in range(num_train_file): 205 | cur_train_filename = os.path.join(hdf5_data_dir, train_file_list[train_file_idx[i]]) 206 | printout(flog, 'Loading train file ' + cur_train_filename) 207 | 208 | cur_data, cur_labels, cur_seg = provider.loadDataFile_with_seg(cur_train_filename) 209 | cur_data, cur_labels, order = provider.shuffle_data(cur_data, np.squeeze(cur_labels)) 210 | cur_seg = cur_seg[order, ...] 211 | 212 | cur_labels_one_hot = convert_label_to_one_hot(cur_labels) 213 | 214 | num_data = len(cur_labels) 215 | num_batch = num_data // batch_size 216 | 217 | total_loss = 0.0 218 | total_label_loss = 0.0 219 | total_seg_loss = 0.0 220 | total_label_acc = 0.0 221 | total_seg_acc = 0.0 222 | 223 | for j in range(num_batch): 224 | begidx = j * batch_size 225 | endidx = (j + 1) * batch_size 226 | 227 | feed_dict = { 228 | pointclouds_ph: cur_data[begidx: endidx, ...], 229 | labels_ph: cur_labels[begidx: endidx, ...], 230 | input_label_ph: cur_labels_one_hot[begidx: endidx, ...], 231 | seg_ph: cur_seg[begidx: endidx, ...], 232 | is_training_ph: is_training, 233 | } 234 | 235 | _, loss_val, label_loss_val, seg_loss_val, per_instance_label_loss_val, \ 236 | per_instance_seg_loss_val, label_pred_val, seg_pred_val, pred_seg_res \ 237 | = sess.run([train_op, loss, label_loss, seg_loss, per_instance_label_loss, \ 238 | per_instance_seg_loss, labels_pred, seg_pred, per_instance_seg_pred_res], \ 239 | feed_dict=feed_dict) 240 | 241 | per_instance_part_acc = np.mean(pred_seg_res == cur_seg[begidx: endidx, ...], axis=1) 242 | average_part_acc = np.mean(per_instance_part_acc) 243 | 244 | total_loss += loss_val 245 | total_label_loss += label_loss_val 246 | total_seg_loss += seg_loss_val 247 | 248 | per_instance_label_pred = np.argmax(label_pred_val, axis=1) 249 | total_label_acc += np.mean(np.float32(per_instance_label_pred == cur_labels[begidx: endidx, ...])) 250 | total_seg_acc += average_part_acc 251 | 252 | total_loss = total_loss * 1.0 / num_batch 253 | total_label_loss = total_label_loss * 1.0 / num_batch 254 | total_seg_loss = total_seg_loss * 1.0 / num_batch 255 | total_label_acc = total_label_acc * 1.0 / num_batch 256 | total_seg_acc = total_seg_acc * 1.0 / num_batch 257 | 258 | lr_sum, bn_decay_sum, batch_sum, train_loss_sum, train_label_acc_sum, \ 259 | train_label_loss_sum, train_seg_loss_sum, train_seg_acc_sum = sess.run(\ 260 | [lr_op, bn_decay_op, batch_op, total_train_loss_sum_op, label_train_acc_sum_op, \ 261 | label_train_loss_sum_op, seg_train_loss_sum_op, 
seg_train_acc_sum_op], \ 262 | feed_dict={total_training_loss_ph: total_loss, label_training_loss_ph: total_label_loss, \ 263 | seg_training_loss_ph: total_seg_loss, label_training_acc_ph: total_label_acc, \ 264 | seg_training_acc_ph: total_seg_acc}) 265 | 266 | train_writer.add_summary(train_loss_sum, i + epoch_num * num_train_file) 267 | train_writer.add_summary(train_label_loss_sum, i + epoch_num * num_train_file) 268 | train_writer.add_summary(train_seg_loss_sum, i + epoch_num * num_train_file) 269 | train_writer.add_summary(lr_sum, i + epoch_num * num_train_file) 270 | train_writer.add_summary(bn_decay_sum, i + epoch_num * num_train_file) 271 | train_writer.add_summary(train_label_acc_sum, i + epoch_num * num_train_file) 272 | train_writer.add_summary(train_seg_acc_sum, i + epoch_num * num_train_file) 273 | train_writer.add_summary(batch_sum, i + epoch_num * num_train_file) 274 | 275 | printout(flog, '\tTraining Total Mean_loss: %f' % total_loss) 276 | printout(flog, '\t\tTraining Label Mean_loss: %f' % total_label_loss) 277 | printout(flog, '\t\tTraining Label Accuracy: %f' % total_label_acc) 278 | printout(flog, '\t\tTraining Seg Mean_loss: %f' % total_seg_loss) 279 | printout(flog, '\t\tTraining Seg Accuracy: %f' % total_seg_acc) 280 | 281 | def eval_one_epoch(epoch_num): 282 | is_training = False 283 | 284 | total_loss = 0.0 285 | total_label_loss = 0.0 286 | total_seg_loss = 0.0 287 | total_label_acc = 0.0 288 | total_seg_acc = 0.0 289 | total_seen = 0 290 | 291 | total_label_acc_per_cat = np.zeros((NUM_CATEGORIES)).astype(np.float32) 292 | total_seg_acc_per_cat = np.zeros((NUM_CATEGORIES)).astype(np.float32) 293 | total_seen_per_cat = np.zeros((NUM_CATEGORIES)).astype(np.int32) 294 | 295 | for i in range(num_test_file): 296 | cur_test_filename = os.path.join(hdf5_data_dir, test_file_list[i]) 297 | printout(flog, 'Loading test file ' + cur_test_filename) 298 | 299 | cur_data, cur_labels, cur_seg = provider.loadDataFile_with_seg(cur_test_filename) 300 | cur_labels = np.squeeze(cur_labels) 301 | 302 | cur_labels_one_hot = convert_label_to_one_hot(cur_labels) 303 | 304 | num_data = len(cur_labels) 305 | num_batch = num_data // batch_size 306 | 307 | for j in range(num_batch): 308 | begidx = j * batch_size 309 | endidx = (j + 1) * batch_size 310 | feed_dict = { 311 | pointclouds_ph: cur_data[begidx: endidx, ...], 312 | labels_ph: cur_labels[begidx: endidx, ...], 313 | input_label_ph: cur_labels_one_hot[begidx: endidx, ...], 314 | seg_ph: cur_seg[begidx: endidx, ...], 315 | is_training_ph: is_training, 316 | } 317 | 318 | loss_val, label_loss_val, seg_loss_val, per_instance_label_loss_val, \ 319 | per_instance_seg_loss_val, label_pred_val, seg_pred_val, pred_seg_res \ 320 | = sess.run([loss, label_loss, seg_loss, per_instance_label_loss, \ 321 | per_instance_seg_loss, labels_pred, seg_pred, per_instance_seg_pred_res], \ 322 | feed_dict=feed_dict) 323 | 324 | per_instance_part_acc = np.mean(pred_seg_res == cur_seg[begidx: endidx, ...], axis=1) 325 | average_part_acc = np.mean(per_instance_part_acc) 326 | 327 | total_seen += 1 328 | total_loss += loss_val 329 | total_label_loss += label_loss_val 330 | total_seg_loss += seg_loss_val 331 | 332 | per_instance_label_pred = np.argmax(label_pred_val, axis=1) 333 | total_label_acc += np.mean(np.float32(per_instance_label_pred == cur_labels[begidx: endidx, ...])) 334 | total_seg_acc += average_part_acc 335 | 336 | for shape_idx in range(begidx, endidx): 337 | total_seen_per_cat[cur_labels[shape_idx]] += 1 338 | 
total_label_acc_per_cat[cur_labels[shape_idx]] += np.int32(per_instance_label_pred[shape_idx-begidx] == cur_labels[shape_idx]) 339 | total_seg_acc_per_cat[cur_labels[shape_idx]] += per_instance_part_acc[shape_idx - begidx] 340 | 341 | total_loss = total_loss * 1.0 / total_seen 342 | total_label_loss = total_label_loss * 1.0 / total_seen 343 | total_seg_loss = total_seg_loss * 1.0 / total_seen 344 | total_label_acc = total_label_acc * 1.0 / total_seen 345 | total_seg_acc = total_seg_acc * 1.0 / total_seen 346 | 347 | test_loss_sum, test_label_acc_sum, test_label_loss_sum, test_seg_loss_sum, test_seg_acc_sum = sess.run(\ 348 | [total_test_loss_sum_op, label_test_acc_sum_op, label_test_loss_sum_op, seg_test_loss_sum_op, seg_test_acc_sum_op], \ 349 | feed_dict={total_testing_loss_ph: total_loss, label_testing_loss_ph: total_label_loss, \ 350 | seg_testing_loss_ph: total_seg_loss, label_testing_acc_ph: total_label_acc, seg_testing_acc_ph: total_seg_acc}) 351 | 352 | test_writer.add_summary(test_loss_sum, (epoch_num+1) * num_train_file-1) 353 | test_writer.add_summary(test_label_loss_sum, (epoch_num+1) * num_train_file-1) 354 | test_writer.add_summary(test_seg_loss_sum, (epoch_num+1) * num_train_file-1) 355 | test_writer.add_summary(test_label_acc_sum, (epoch_num+1) * num_train_file-1) 356 | test_writer.add_summary(test_seg_acc_sum, (epoch_num+1) * num_train_file-1) 357 | 358 | printout(flog, '\tTesting Total Mean_loss: %f' % total_loss) 359 | printout(flog, '\t\tTesting Label Mean_loss: %f' % total_label_loss) 360 | printout(flog, '\t\tTesting Label Accuracy: %f' % total_label_acc) 361 | printout(flog, '\t\tTesting Seg Mean_loss: %f' % total_seg_loss) 362 | printout(flog, '\t\tTesting Seg Accuracy: %f' % total_seg_acc) 363 | 364 | for cat_idx in range(NUM_CATEGORIES): 365 | if total_seen_per_cat[cat_idx] > 0: 366 | printout(flog, '\n\t\tCategory %s Object Number: %d' % (all_obj_cats[cat_idx][0], total_seen_per_cat[cat_idx])) 367 | printout(flog, '\t\tCategory %s Label Accuracy: %f' % (all_obj_cats[cat_idx][0], total_label_acc_per_cat[cat_idx]/total_seen_per_cat[cat_idx])) 368 | printout(flog, '\t\tCategory %s Seg Accuracy: %f' % (all_obj_cats[cat_idx][0], total_seg_acc_per_cat[cat_idx]/total_seen_per_cat[cat_idx])) 369 | 370 | if not os.path.exists(MODEL_STORAGE_PATH): 371 | os.mkdir(MODEL_STORAGE_PATH) 372 | 373 | for epoch in range(TRAINING_EPOCHES): 374 | printout(flog, '\n<<< Testing on the test dataset ...') 375 | eval_one_epoch(epoch) 376 | 377 | printout(flog, '\n>>> Training for the epoch %d/%d ...' 
% (epoch, TRAINING_EPOCHES))
378 | 
379 |         train_file_idx = np.arange(0, len(train_file_list))
380 |         np.random.shuffle(train_file_idx)
381 | 
382 |         train_one_epoch(train_file_idx, epoch)
383 | 
384 |         if (epoch+1) % 10 == 0:
385 |             cp_filename = saver.save(sess, os.path.join(MODEL_STORAGE_PATH, 'epoch_' + str(epoch+1)+'.ckpt'))
386 |             printout(flog, 'Successfully stored the checkpoint model in ' + cp_filename)
387 | 
388 |         flog.flush()
389 | 
390 |     flog.close()
391 | 
392 | if __name__=='__main__':
393 |     train()
394 | 
--------------------------------------------------------------------------------
/provider.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | import h5py
5 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
6 | sys.path.append(BASE_DIR)
7 | 
8 | # Download dataset for point cloud classification
9 | DATA_DIR = os.path.join(BASE_DIR, 'data')
10 | if not os.path.exists(DATA_DIR):
11 |     os.mkdir(DATA_DIR)
12 | if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')):
13 |     www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
14 |     zipfile = os.path.basename(www)
15 |     os.system('wget %s; unzip %s' % (www, zipfile))
16 |     os.system('mv %s %s' % (zipfile[:-4], DATA_DIR))
17 |     os.system('rm %s' % (zipfile))
18 | 
19 | 
20 | def shuffle_data(data, labels):
21 |     """ Shuffle data and labels.
22 |         Input:
23 |           data: B,N,... numpy array
24 |           label: B,... numpy array
25 |         Return:
26 |           shuffled data, label and shuffle indices
27 |     """
28 |     idx = np.arange(len(labels))
29 |     np.random.shuffle(idx)
30 |     return data[idx, ...], labels[idx], idx
31 | 
32 | 
33 | def rotate_point_cloud(batch_data):
34 |     """ Randomly rotate the point clouds to augment the dataset.
35 |         Rotation is per shape, about the up direction.
36 |         Input:
37 |           BxNx3 array, original batch of point clouds
38 |         Return:
39 |           BxNx3 array, rotated batch of point clouds
40 |     """
41 |     rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
42 |     for k in range(batch_data.shape[0]):
43 |         rotation_angle = np.random.uniform() * 2 * np.pi
44 |         cosval = np.cos(rotation_angle)
45 |         sinval = np.sin(rotation_angle)
46 |         rotation_matrix = np.array([[cosval, 0, sinval],
47 |                                     [0, 1, 0],
48 |                                     [-sinval, 0, cosval]])
49 |         shape_pc = batch_data[k, ...]
50 |         rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
51 |     return rotated_data
52 | 
53 | 
54 | def rotate_point_cloud_by_angle(batch_data, rotation_angle):
55 |     """ Rotate the point cloud about the up direction by a given angle.
56 |         Input:
57 |           BxNx3 array, original batch of point clouds
58 |         Return:
59 |           BxNx3 array, rotated batch of point clouds
60 |     """
61 |     rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
62 |     for k in range(batch_data.shape[0]):
63 |         #rotation_angle = np.random.uniform() * 2 * np.pi
64 |         cosval = np.cos(rotation_angle)
65 |         sinval = np.sin(rotation_angle)
66 |         rotation_matrix = np.array([[cosval, 0, sinval],
67 |                                     [0, 1, 0],
68 |                                     [-sinval, 0, cosval]])
69 |         shape_pc = batch_data[k, ...]
70 |         rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
71 |     return rotated_data
72 | 
73 | 
74 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
75 |     """ Randomly jitter points. Jittering is per point.
76 |         Input:
77 |           BxNx3 array, original batch of point clouds
78 |         Return:
79 |           BxNx3 array, jittered batch of point clouds
80 |     """
81 |     B, N, C = batch_data.shape
82 |     assert(clip > 0)
83 |     jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip)
84 |     jittered_data += batch_data
85 |     return jittered_data
86 | 
87 | def getDataFiles(list_filename):
88 |     return [line.rstrip() for line in open(list_filename)]
89 | 
90 | def load_h5(h5_filename):
91 |     f = h5py.File(h5_filename, 'r')  # open read-only; recent h5py versions require an explicit mode
92 |     data = f['data'][:]
93 |     label = f['label'][:]
94 |     return (data, label)
95 | 
96 | def loadDataFile(filename):
97 |     return load_h5(filename)
98 | 
99 | def load_h5_data_label_seg(h5_filename):
100 |     f = h5py.File(h5_filename, 'r')  # open read-only
101 |     data = f['data'][:]
102 |     label = f['label'][:]
103 |     seg = f['pid'][:]
104 |     return (data, label, seg)
105 | 
106 | 
107 | def loadDataFile_with_seg(filename):
108 |     return load_h5_data_label_seg(filename)
109 | 
--------------------------------------------------------------------------------
/sem_seg/README.md:
--------------------------------------------------------------------------------
1 | ## Semantic Segmentation of Indoor Scenes
2 | 
3 | ### Dataset
4 | 
5 | Download the prepared HDF5 data for training:
6 | 
7 |     sh download_data.sh
8 | 
9 | (optional) Download the 3D indoor parsing dataset (S3DIS) for testing and visualization. Version 1.2 of the dataset is used in this work.
10 | 
11 | 
12 | To prepare your own HDF5 data, first download the 3D indoor parsing dataset, then run `python collect_indoor3d_data.py` to re-organize the data and `python gen_indoor3d_h5.py` to generate the HDF5 files.
13 | 
14 | ### Training
15 | 
16 | Once you have downloaded the prepared HDF5 files or generated them yourself, start training with:
17 | 
18 |     python train.py --log_dir log6 --test_area 6
19 | 
20 | By default a simple model based on vanilla PointNet is used for training, with Area 6 held out as the test set.
21 | 
22 | ### Testing
23 | 
24 | Testing requires downloading the 3D indoor parsing data and preprocessing it with `collect_indoor3d_data.py`.
25 | 
26 | After training, use `batch_inference.py` to segment the rooms in the test set. In our work we use 6-fold training, which trains 6 models: for model 1, areas 2-6 form the train set and area 1 the test set; for model 2, areas 1 and 3-6 form the train set and area 2 the test set, and so on. Note that the S3DIS dataset paper uses a different 3-fold split, which had not been publicly announced at the time of our work.
27 | 
28 | For example, to test model 6, use the command:
29 | 
30 |     python batch_inference.py --model_path log6/model.ckpt --dump_dir log6/dump --output_filelist log6/output_filelist.txt --room_data_filelist meta/area6_data_label.txt --visu
31 | 
32 | Some OBJ files will be created in `log6/dump` for prediction visualization.
33 | 
34 | To evaluate overall segmentation accuracy, we evaluate the 6 models on their corresponding test areas and use `eval_iou_accuracy.py` to produce the point classification accuracy and IoU reported in the paper.
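
For reference, the whole 6-fold protocol can be scripted roughly as below. This is only a sketch: it assumes you have created per-area room filelists analogous to `meta/area6_data_label.txt` for areas 1-5, and that the per-area output filelists are concatenated into the `all_pred_data_label_filelist.txt` that `eval_iou_accuracy.py` reads.

    for i in 1 2 3 4 5 6; do
        python train.py --log_dir log$i --test_area $i
        python batch_inference.py --model_path log$i/model.ckpt --dump_dir log$i/dump \
            --output_filelist log$i/output_filelist.txt \
            --room_data_filelist meta/area${i}_data_label.txt
    done
    cat log?/output_filelist.txt > all_pred_data_label_filelist.txt
    python eval_iou_accuracy.py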
35 | 
36 | 
37 | 
--------------------------------------------------------------------------------
/sem_seg/batch_inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
5 | ROOT_DIR = os.path.dirname(BASE_DIR)
6 | sys.path.append(BASE_DIR)
7 | from model import *
8 | import indoor3d_util
9 | 
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
12 | parser.add_argument('--batch_size', type=int, default=1, help='Batch Size during training [default: 1]')
13 | parser.add_argument('--num_point', type=int, default=4096, help='Point number [default: 4096]')
14 | parser.add_argument('--model_path', required=True, help='model checkpoint file path')
15 | parser.add_argument('--dump_dir', required=True, help='dump folder path')
16 | parser.add_argument('--output_filelist', required=True, help='TXT filename, filelist, each line is an output for a room')
17 | parser.add_argument('--room_data_filelist', required=True, help='TXT filename, filelist, each line is a test room data label file.')
18 | parser.add_argument('--no_clutter', action='store_true', help='If set, do not count the clutter class')
19 | parser.add_argument('--visu', action='store_true', help='Whether to output OBJ file for prediction visualization.')
20 | FLAGS = parser.parse_args()
21 | 
22 | BATCH_SIZE = FLAGS.batch_size
23 | NUM_POINT = FLAGS.num_point
24 | MODEL_PATH = FLAGS.model_path
25 | GPU_INDEX = FLAGS.gpu
26 | DUMP_DIR = FLAGS.dump_dir
27 | if not os.path.exists(DUMP_DIR): os.mkdir(DUMP_DIR)
28 | LOG_FOUT = open(os.path.join(DUMP_DIR, 'log_evaluate.txt'), 'w')
29 | LOG_FOUT.write(str(FLAGS)+'\n')
30 | ROOM_PATH_LIST = [os.path.join(ROOT_DIR,line.rstrip()) for line in open(FLAGS.room_data_filelist)]
31 | 
32 | NUM_CLASSES = 13
33 | 
34 | def log_string(out_str):
35 |     LOG_FOUT.write(out_str+'\n')
36 |     LOG_FOUT.flush()
37 |     print(out_str)
38 | 
39 | def evaluate():
40 |     is_training = False
41 | 
42 |     with tf.device('/gpu:'+str(GPU_INDEX)):
43 |         pointclouds_pl, labels_pl = placeholder_inputs(BATCH_SIZE, NUM_POINT)
44 |         is_training_pl = tf.placeholder(tf.bool, shape=())
45 | 
46 |         # simple model
47 |         pred = get_model(pointclouds_pl, is_training_pl)
48 |         loss = get_loss(pred, labels_pl)
49 |         pred_softmax = tf.nn.softmax(pred)
50 | 
51 |     # Add ops to save and restore all the variables.
52 |     saver = tf.train.Saver()
53 | 
54 |     # Create a session
55 |     config = tf.ConfigProto()
56 |     config.gpu_options.allow_growth = True
57 |     config.allow_soft_placement = True
58 |     config.log_device_placement = True
59 |     sess = tf.Session(config=config)
60 | 
61 |     # Restore variables from disk.
62 | saver.restore(sess, MODEL_PATH) 63 | log_string("Model restored.") 64 | 65 | ops = {'pointclouds_pl': pointclouds_pl, 66 | 'labels_pl': labels_pl, 67 | 'is_training_pl': is_training_pl, 68 | 'pred': pred, 69 | 'pred_softmax': pred_softmax, 70 | 'loss': loss} 71 | 72 | total_correct = 0 73 | total_seen = 0 74 | fout_out_filelist = open(FLAGS.output_filelist, 'w') 75 | for room_path in ROOM_PATH_LIST: 76 | out_data_label_filename = os.path.basename(room_path)[:-4] + '_pred.txt' 77 | out_data_label_filename = os.path.join(DUMP_DIR, out_data_label_filename) 78 | out_gt_label_filename = os.path.basename(room_path)[:-4] + '_gt.txt' 79 | out_gt_label_filename = os.path.join(DUMP_DIR, out_gt_label_filename) 80 | print(room_path, out_data_label_filename) 81 | a, b = eval_one_epoch(sess, ops, room_path, out_data_label_filename, out_gt_label_filename) 82 | total_correct += a 83 | total_seen += b 84 | fout_out_filelist.write(out_data_label_filename+'\n') 85 | fout_out_filelist.close() 86 | log_string('all room eval accuracy: %f'% (total_correct / float(total_seen))) 87 | 88 | def eval_one_epoch(sess, ops, room_path, out_data_label_filename, out_gt_label_filename): 89 | error_cnt = 0 90 | is_training = False 91 | total_correct = 0 92 | total_seen = 0 93 | loss_sum = 0 94 | total_seen_class = [0 for _ in range(NUM_CLASSES)] 95 | total_correct_class = [0 for _ in range(NUM_CLASSES)] 96 | if FLAGS.visu: 97 | fout = open(os.path.join(DUMP_DIR, os.path.basename(room_path)[:-4]+'_pred.obj'), 'w') 98 | fout_gt = open(os.path.join(DUMP_DIR, os.path.basename(room_path)[:-4]+'_gt.obj'), 'w') 99 | fout_data_label = open(out_data_label_filename, 'w') 100 | fout_gt_label = open(out_gt_label_filename, 'w') 101 | 102 | current_data, current_label = indoor3d_util.room2blocks_wrapper_normalized(room_path, NUM_POINT) 103 | current_data = current_data[:,0:NUM_POINT,:] 104 | current_label = np.squeeze(current_label) 105 | # Get room dimension.. 
106 | data_label = np.load(room_path) 107 | data = data_label[:,0:6] 108 | max_room_x = max(data[:,0]) 109 | max_room_y = max(data[:,1]) 110 | max_room_z = max(data[:,2]) 111 | 112 | file_size = current_data.shape[0] 113 | num_batches = file_size // BATCH_SIZE 114 | print(file_size) 115 | 116 | 117 | for batch_idx in range(num_batches): 118 | start_idx = batch_idx * BATCH_SIZE 119 | end_idx = (batch_idx+1) * BATCH_SIZE 120 | cur_batch_size = end_idx - start_idx 121 | 122 | feed_dict = {ops['pointclouds_pl']: current_data[start_idx:end_idx, :, :], 123 | ops['labels_pl']: current_label[start_idx:end_idx], 124 | ops['is_training_pl']: is_training} 125 | loss_val, pred_val = sess.run([ops['loss'], ops['pred_softmax']], 126 | feed_dict=feed_dict) 127 | 128 | if FLAGS.no_clutter: 129 | pred_label = np.argmax(pred_val[:,:,0:12], 2) # BxN 130 | else: 131 | pred_label = np.argmax(pred_val, 2) # BxN 132 | # Save prediction labels to OBJ file 133 | for b in range(BATCH_SIZE): 134 | pts = current_data[start_idx+b, :, :] 135 | l = current_label[start_idx+b,:] 136 | pts[:,6] *= max_room_x 137 | pts[:,7] *= max_room_y 138 | pts[:,8] *= max_room_z 139 | pts[:,3:6] *= 255.0 140 | pred = pred_label[b, :] 141 | for i in range(NUM_POINT): 142 | color = indoor3d_util.g_label2color[pred[i]] 143 | color_gt = indoor3d_util.g_label2color[current_label[start_idx+b, i]] 144 | if FLAGS.visu: 145 | fout.write('v %f %f %f %d %d %d\n' % (pts[i,6], pts[i,7], pts[i,8], color[0], color[1], color[2])) 146 | fout_gt.write('v %f %f %f %d %d %d\n' % (pts[i,6], pts[i,7], pts[i,8], color_gt[0], color_gt[1], color_gt[2])) 147 | fout_data_label.write('%f %f %f %d %d %d %f %d\n' % (pts[i,6], pts[i,7], pts[i,8], pts[i,3], pts[i,4], pts[i,5], pred_val[b,i,pred[i]], pred[i])) 148 | fout_gt_label.write('%d\n' % (l[i])) 149 | correct = np.sum(pred_label == current_label[start_idx:end_idx,:]) 150 | total_correct += correct 151 | total_seen += (cur_batch_size*NUM_POINT) 152 | loss_sum += (loss_val*BATCH_SIZE) 153 | for i in range(start_idx, end_idx): 154 | for j in range(NUM_POINT): 155 | l = current_label[i, j] 156 | total_seen_class[l] += 1 157 | total_correct_class[l] += (pred_label[i-start_idx, j] == l) 158 | 159 | log_string('eval mean loss: %f' % (loss_sum / float(total_seen/NUM_POINT))) 160 | log_string('eval accuracy: %f'% (total_correct / float(total_seen))) 161 | fout_data_label.close() 162 | fout_gt_label.close() 163 | if FLAGS.visu: 164 | fout.close() 165 | fout_gt.close() 166 | return total_correct, total_seen 167 | 168 | 169 | if __name__=='__main__': 170 | with tf.Graph().as_default(): 171 | evaluate() 172 | LOG_FOUT.close() 173 | -------------------------------------------------------------------------------- /sem_seg/collect_indoor3d_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 4 | ROOT_DIR = os.path.dirname(BASE_DIR) 5 | sys.path.append(BASE_DIR) 6 | import indoor3d_util 7 | 8 | anno_paths = [line.rstrip() for line in open(os.path.join(BASE_DIR, 'meta/anno_paths.txt'))] 9 | anno_paths = [os.path.join(indoor3d_util.DATA_PATH, p) for p in anno_paths] 10 | 11 | output_folder = os.path.join(ROOT_DIR, 'data/stanford_indoor3d') 12 | if not os.path.exists(output_folder): 13 | os.mkdir(output_folder) 14 | 15 | # Note: there is an extra character in the v1.2 data in Area_5/hallway_6. It's fixed manually. 
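# (Annotation, not part of the original source: the loop below wraps each room
# in a broad try/except on purpose, so that a single malformed annotation file,
# such as the hallway_6 case noted above, only prints an error instead of
# aborting the whole collection run.)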
16 | for anno_path in anno_paths:
17 |     print(anno_path)
18 |     try:
19 |         elements = anno_path.split('/')
20 |         out_filename = elements[-3]+'_'+elements[-2]+'.npy' # Area_1_hallway_1.npy
21 |         indoor3d_util.collect_point_label(anno_path, os.path.join(output_folder, out_filename), 'numpy')
22 |     except:
23 |         print(anno_path, 'ERROR!!')
24 | 
--------------------------------------------------------------------------------
/sem_seg/download_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Download HDF5 for indoor 3d semantic segmentation (around 1.6GB)
4 | wget https://shapenet.cs.stanford.edu/media/indoor3d_sem_seg_hdf5_data.zip
5 | unzip indoor3d_sem_seg_hdf5_data.zip
6 | rm indoor3d_sem_seg_hdf5_data.zip
7 | 
8 | 
--------------------------------------------------------------------------------
/sem_seg/eval_iou_accuracy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | pred_data_label_filenames = [line.rstrip() for line in open('all_pred_data_label_filelist.txt')]
4 | gt_label_filenames = [f[:-len('_pred.txt')] + '_gt.txt' for f in pred_data_label_filenames]  # strip the '_pred.txt' suffix; rstrip('_pred\.txt') would strip a character set, not the suffix
5 | num_room = len(gt_label_filenames)
6 | 
7 | 
8 | gt_classes = [0 for _ in range(13)]
9 | positive_classes = [0 for _ in range(13)]
10 | true_positive_classes = [0 for _ in range(13)]
11 | for i in range(num_room):
12 |     print(i)
13 |     data_label = np.loadtxt(pred_data_label_filenames[i])
14 |     pred_label = data_label[:,-1]
15 |     gt_label = np.loadtxt(gt_label_filenames[i])
16 |     print(gt_label.shape)
17 |     for j in range(gt_label.shape[0]):
18 |         gt_l = int(gt_label[j])
19 |         pred_l = int(pred_label[j])
20 |         gt_classes[gt_l] += 1
21 |         positive_classes[pred_l] += 1
22 |         true_positive_classes[gt_l] += int(gt_l==pred_l)
23 | 
24 | 
25 | print(gt_classes)
26 | print(positive_classes)
27 | print(true_positive_classes)
28 | 
29 | 
30 | print('Overall accuracy: {0}'.format(sum(true_positive_classes)/float(sum(positive_classes))))
31 | 
32 | print('IoU:')
33 | iou_list = []
34 | for i in range(13):
35 |     iou = true_positive_classes[i]/float(gt_classes[i]+positive_classes[i]-true_positive_classes[i])
36 |     print(iou)
37 |     iou_list.append(iou)
38 | 
39 | print(sum(iou_list)/13.0)
40 | 
--------------------------------------------------------------------------------
/sem_seg/gen_indoor3d_h5.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import sys
4 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
5 | ROOT_DIR = os.path.dirname(BASE_DIR)
6 | sys.path.append(BASE_DIR)
7 | sys.path.append(os.path.join(ROOT_DIR, 'utils'))
8 | import data_prep_util
9 | import indoor3d_util
10 | 
11 | # Constants
12 | data_dir = os.path.join(ROOT_DIR, 'data')
13 | indoor3d_data_dir = os.path.join(data_dir, 'stanford_indoor3d')
14 | NUM_POINT = 4096
15 | H5_BATCH_SIZE = 1000
16 | data_dim = [NUM_POINT, 9]
17 | label_dim = [NUM_POINT]
18 | data_dtype = 'float32'
19 | label_dtype = 'uint8'
20 | 
21 | # Set paths
22 | filelist = os.path.join(BASE_DIR, 'meta/all_data_label.txt')
23 | data_label_files = [os.path.join(indoor3d_data_dir, line.rstrip()) for line in open(filelist)]
24 | output_dir = os.path.join(data_dir, 'indoor3d_sem_seg_hdf5_data')
25 | if not os.path.exists(output_dir):
26 |     os.mkdir(output_dir)
27 | output_filename_prefix = os.path.join(output_dir, 'ply_data_all')
28 | output_room_filelist = os.path.join(output_dir, 'room_filelist.txt')
29 | fout_room =
open(output_room_filelist, 'w') 30 | 31 | # -------------------------------------- 32 | # ----- BATCH WRITE TO HDF5 ----- 33 | # -------------------------------------- 34 | batch_data_dim = [H5_BATCH_SIZE] + data_dim 35 | batch_label_dim = [H5_BATCH_SIZE] + label_dim 36 | h5_batch_data = np.zeros(batch_data_dim, dtype = np.float32) 37 | h5_batch_label = np.zeros(batch_label_dim, dtype = np.uint8) 38 | buffer_size = 0 # state: record how many samples are currently in buffer 39 | h5_index = 0 # state: the next h5 file to save 40 | 41 | def insert_batch(data, label, last_batch=False): 42 | global h5_batch_data, h5_batch_label 43 | global buffer_size, h5_index 44 | data_size = data.shape[0] 45 | # If there is enough space, just insert 46 | if buffer_size + data_size <= h5_batch_data.shape[0]: 47 | h5_batch_data[buffer_size:buffer_size+data_size, ...] = data 48 | h5_batch_label[buffer_size:buffer_size+data_size] = label 49 | buffer_size += data_size 50 | else: # not enough space 51 | capacity = h5_batch_data.shape[0] - buffer_size 52 | assert(capacity>=0) 53 | if capacity > 0: 54 | h5_batch_data[buffer_size:buffer_size+capacity, ...] = data[0:capacity, ...] 55 | h5_batch_label[buffer_size:buffer_size+capacity, ...] = label[0:capacity, ...] 56 | # Save batch data and label to h5 file, reset buffer_size 57 | h5_filename = output_filename_prefix + '_' + str(h5_index) + '.h5' 58 | data_prep_util.save_h5(h5_filename, h5_batch_data, h5_batch_label, data_dtype, label_dtype) 59 | print('Stored {0} with size {1}'.format(h5_filename, h5_batch_data.shape[0])) 60 | h5_index += 1 61 | buffer_size = 0 62 | # recursive call 63 | insert_batch(data[capacity:, ...], label[capacity:, ...], last_batch) 64 | if last_batch and buffer_size > 0: 65 | h5_filename = output_filename_prefix + '_' + str(h5_index) + '.h5' 66 | data_prep_util.save_h5(h5_filename, h5_batch_data[0:buffer_size, ...], h5_batch_label[0:buffer_size, ...], data_dtype, label_dtype) 67 | print('Stored {0} with size {1}'.format(h5_filename, buffer_size)) 68 | h5_index += 1 69 | buffer_size = 0 70 | return 71 | 72 | 73 | sample_cnt = 0 74 | for i, data_label_filename in enumerate(data_label_files): 75 | print(data_label_filename) 76 | data, label = indoor3d_util.room2blocks_wrapper_normalized(data_label_filename, NUM_POINT, block_size=1.0, stride=0.5, 77 | random_sample=False, sample_num=None) 78 | print('{0}, {1}'.format(data.shape, label.shape)) 79 | for _ in range(data.shape[0]): 80 | fout_room.write(os.path.basename(data_label_filename)[0:-4]+'\n') 81 | 82 | sample_cnt += data.shape[0] 83 | insert_batch(data, label, i == len(data_label_files)-1) 84 | 85 | fout_room.close() 86 | print("Total samples: {0}".format(sample_cnt)) 87 | -------------------------------------------------------------------------------- /sem_seg/meta/all_data_label.txt: -------------------------------------------------------------------------------- 1 | Area_1_conferenceRoom_1.npy 2 | Area_1_conferenceRoom_2.npy 3 | Area_1_copyRoom_1.npy 4 | Area_1_hallway_1.npy 5 | Area_1_hallway_2.npy 6 | Area_1_hallway_3.npy 7 | Area_1_hallway_4.npy 8 | Area_1_hallway_5.npy 9 | Area_1_hallway_6.npy 10 | Area_1_hallway_7.npy 11 | Area_1_hallway_8.npy 12 | Area_1_office_10.npy 13 | Area_1_office_11.npy 14 | Area_1_office_12.npy 15 | Area_1_office_13.npy 16 | Area_1_office_14.npy 17 | Area_1_office_15.npy 18 | Area_1_office_16.npy 19 | Area_1_office_17.npy 20 | Area_1_office_18.npy 21 | Area_1_office_19.npy 22 | Area_1_office_1.npy 23 | Area_1_office_20.npy 24 | Area_1_office_21.npy 25 | 
Area_1_office_22.npy 26 | Area_1_office_23.npy 27 | Area_1_office_24.npy 28 | Area_1_office_25.npy 29 | Area_1_office_26.npy 30 | Area_1_office_27.npy 31 | Area_1_office_28.npy 32 | Area_1_office_29.npy 33 | Area_1_office_2.npy 34 | Area_1_office_30.npy 35 | Area_1_office_31.npy 36 | Area_1_office_3.npy 37 | Area_1_office_4.npy 38 | Area_1_office_5.npy 39 | Area_1_office_6.npy 40 | Area_1_office_7.npy 41 | Area_1_office_8.npy 42 | Area_1_office_9.npy 43 | Area_1_pantry_1.npy 44 | Area_1_WC_1.npy 45 | Area_2_auditorium_1.npy 46 | Area_2_auditorium_2.npy 47 | Area_2_conferenceRoom_1.npy 48 | Area_2_hallway_10.npy 49 | Area_2_hallway_11.npy 50 | Area_2_hallway_12.npy 51 | Area_2_hallway_1.npy 52 | Area_2_hallway_2.npy 53 | Area_2_hallway_3.npy 54 | Area_2_hallway_4.npy 55 | Area_2_hallway_5.npy 56 | Area_2_hallway_6.npy 57 | Area_2_hallway_7.npy 58 | Area_2_hallway_8.npy 59 | Area_2_hallway_9.npy 60 | Area_2_office_10.npy 61 | Area_2_office_11.npy 62 | Area_2_office_12.npy 63 | Area_2_office_13.npy 64 | Area_2_office_14.npy 65 | Area_2_office_1.npy 66 | Area_2_office_2.npy 67 | Area_2_office_3.npy 68 | Area_2_office_4.npy 69 | Area_2_office_5.npy 70 | Area_2_office_6.npy 71 | Area_2_office_7.npy 72 | Area_2_office_8.npy 73 | Area_2_office_9.npy 74 | Area_2_storage_1.npy 75 | Area_2_storage_2.npy 76 | Area_2_storage_3.npy 77 | Area_2_storage_4.npy 78 | Area_2_storage_5.npy 79 | Area_2_storage_6.npy 80 | Area_2_storage_7.npy 81 | Area_2_storage_8.npy 82 | Area_2_storage_9.npy 83 | Area_2_WC_1.npy 84 | Area_2_WC_2.npy 85 | Area_3_conferenceRoom_1.npy 86 | Area_3_hallway_1.npy 87 | Area_3_hallway_2.npy 88 | Area_3_hallway_3.npy 89 | Area_3_hallway_4.npy 90 | Area_3_hallway_5.npy 91 | Area_3_hallway_6.npy 92 | Area_3_lounge_1.npy 93 | Area_3_lounge_2.npy 94 | Area_3_office_10.npy 95 | Area_3_office_1.npy 96 | Area_3_office_2.npy 97 | Area_3_office_3.npy 98 | Area_3_office_4.npy 99 | Area_3_office_5.npy 100 | Area_3_office_6.npy 101 | Area_3_office_7.npy 102 | Area_3_office_8.npy 103 | Area_3_office_9.npy 104 | Area_3_storage_1.npy 105 | Area_3_storage_2.npy 106 | Area_3_WC_1.npy 107 | Area_3_WC_2.npy 108 | Area_4_conferenceRoom_1.npy 109 | Area_4_conferenceRoom_2.npy 110 | Area_4_conferenceRoom_3.npy 111 | Area_4_hallway_10.npy 112 | Area_4_hallway_11.npy 113 | Area_4_hallway_12.npy 114 | Area_4_hallway_13.npy 115 | Area_4_hallway_14.npy 116 | Area_4_hallway_1.npy 117 | Area_4_hallway_2.npy 118 | Area_4_hallway_3.npy 119 | Area_4_hallway_4.npy 120 | Area_4_hallway_5.npy 121 | Area_4_hallway_6.npy 122 | Area_4_hallway_7.npy 123 | Area_4_hallway_8.npy 124 | Area_4_hallway_9.npy 125 | Area_4_lobby_1.npy 126 | Area_4_lobby_2.npy 127 | Area_4_office_10.npy 128 | Area_4_office_11.npy 129 | Area_4_office_12.npy 130 | Area_4_office_13.npy 131 | Area_4_office_14.npy 132 | Area_4_office_15.npy 133 | Area_4_office_16.npy 134 | Area_4_office_17.npy 135 | Area_4_office_18.npy 136 | Area_4_office_19.npy 137 | Area_4_office_1.npy 138 | Area_4_office_20.npy 139 | Area_4_office_21.npy 140 | Area_4_office_22.npy 141 | Area_4_office_2.npy 142 | Area_4_office_3.npy 143 | Area_4_office_4.npy 144 | Area_4_office_5.npy 145 | Area_4_office_6.npy 146 | Area_4_office_7.npy 147 | Area_4_office_8.npy 148 | Area_4_office_9.npy 149 | Area_4_storage_1.npy 150 | Area_4_storage_2.npy 151 | Area_4_storage_3.npy 152 | Area_4_storage_4.npy 153 | Area_4_WC_1.npy 154 | Area_4_WC_2.npy 155 | Area_4_WC_3.npy 156 | Area_4_WC_4.npy 157 | Area_5_conferenceRoom_1.npy 158 | Area_5_conferenceRoom_2.npy 159 | Area_5_conferenceRoom_3.npy 160 | 
Area_5_hallway_10.npy 161 | Area_5_hallway_11.npy 162 | Area_5_hallway_12.npy 163 | Area_5_hallway_13.npy 164 | Area_5_hallway_14.npy 165 | Area_5_hallway_15.npy 166 | Area_5_hallway_1.npy 167 | Area_5_hallway_2.npy 168 | Area_5_hallway_3.npy 169 | Area_5_hallway_4.npy 170 | Area_5_hallway_5.npy 171 | Area_5_hallway_6.npy 172 | Area_5_hallway_7.npy 173 | Area_5_hallway_8.npy 174 | Area_5_hallway_9.npy 175 | Area_5_lobby_1.npy 176 | Area_5_office_10.npy 177 | Area_5_office_11.npy 178 | Area_5_office_12.npy 179 | Area_5_office_13.npy 180 | Area_5_office_14.npy 181 | Area_5_office_15.npy 182 | Area_5_office_16.npy 183 | Area_5_office_17.npy 184 | Area_5_office_18.npy 185 | Area_5_office_19.npy 186 | Area_5_office_1.npy 187 | Area_5_office_20.npy 188 | Area_5_office_21.npy 189 | Area_5_office_22.npy 190 | Area_5_office_23.npy 191 | Area_5_office_24.npy 192 | Area_5_office_25.npy 193 | Area_5_office_26.npy 194 | Area_5_office_27.npy 195 | Area_5_office_28.npy 196 | Area_5_office_29.npy 197 | Area_5_office_2.npy 198 | Area_5_office_30.npy 199 | Area_5_office_31.npy 200 | Area_5_office_32.npy 201 | Area_5_office_33.npy 202 | Area_5_office_34.npy 203 | Area_5_office_35.npy 204 | Area_5_office_36.npy 205 | Area_5_office_37.npy 206 | Area_5_office_38.npy 207 | Area_5_office_39.npy 208 | Area_5_office_3.npy 209 | Area_5_office_40.npy 210 | Area_5_office_41.npy 211 | Area_5_office_42.npy 212 | Area_5_office_4.npy 213 | Area_5_office_5.npy 214 | Area_5_office_6.npy 215 | Area_5_office_7.npy 216 | Area_5_office_8.npy 217 | Area_5_office_9.npy 218 | Area_5_pantry_1.npy 219 | Area_5_storage_1.npy 220 | Area_5_storage_2.npy 221 | Area_5_storage_3.npy 222 | Area_5_storage_4.npy 223 | Area_5_WC_1.npy 224 | Area_5_WC_2.npy 225 | Area_6_conferenceRoom_1.npy 226 | Area_6_copyRoom_1.npy 227 | Area_6_hallway_1.npy 228 | Area_6_hallway_2.npy 229 | Area_6_hallway_3.npy 230 | Area_6_hallway_4.npy 231 | Area_6_hallway_5.npy 232 | Area_6_hallway_6.npy 233 | Area_6_lounge_1.npy 234 | Area_6_office_10.npy 235 | Area_6_office_11.npy 236 | Area_6_office_12.npy 237 | Area_6_office_13.npy 238 | Area_6_office_14.npy 239 | Area_6_office_15.npy 240 | Area_6_office_16.npy 241 | Area_6_office_17.npy 242 | Area_6_office_18.npy 243 | Area_6_office_19.npy 244 | Area_6_office_1.npy 245 | Area_6_office_20.npy 246 | Area_6_office_21.npy 247 | Area_6_office_22.npy 248 | Area_6_office_23.npy 249 | Area_6_office_24.npy 250 | Area_6_office_25.npy 251 | Area_6_office_26.npy 252 | Area_6_office_27.npy 253 | Area_6_office_28.npy 254 | Area_6_office_29.npy 255 | Area_6_office_2.npy 256 | Area_6_office_30.npy 257 | Area_6_office_31.npy 258 | Area_6_office_32.npy 259 | Area_6_office_33.npy 260 | Area_6_office_34.npy 261 | Area_6_office_35.npy 262 | Area_6_office_36.npy 263 | Area_6_office_37.npy 264 | Area_6_office_3.npy 265 | Area_6_office_4.npy 266 | Area_6_office_5.npy 267 | Area_6_office_6.npy 268 | Area_6_office_7.npy 269 | Area_6_office_8.npy 270 | Area_6_office_9.npy 271 | Area_6_openspace_1.npy 272 | Area_6_pantry_1.npy 273 | -------------------------------------------------------------------------------- /sem_seg/meta/anno_paths.txt: -------------------------------------------------------------------------------- 1 | Area_1/conferenceRoom_1/Annotations 2 | Area_1/conferenceRoom_2/Annotations 3 | Area_1/copyRoom_1/Annotations 4 | Area_1/hallway_1/Annotations 5 | Area_1/hallway_2/Annotations 6 | Area_1/hallway_3/Annotations 7 | Area_1/hallway_4/Annotations 8 | Area_1/hallway_5/Annotations 9 | Area_1/hallway_6/Annotations 10 | 
Area_1/hallway_7/Annotations 11 | Area_1/hallway_8/Annotations 12 | Area_1/office_10/Annotations 13 | Area_1/office_11/Annotations 14 | Area_1/office_12/Annotations 15 | Area_1/office_13/Annotations 16 | Area_1/office_14/Annotations 17 | Area_1/office_15/Annotations 18 | Area_1/office_16/Annotations 19 | Area_1/office_17/Annotations 20 | Area_1/office_18/Annotations 21 | Area_1/office_19/Annotations 22 | Area_1/office_1/Annotations 23 | Area_1/office_20/Annotations 24 | Area_1/office_21/Annotations 25 | Area_1/office_22/Annotations 26 | Area_1/office_23/Annotations 27 | Area_1/office_24/Annotations 28 | Area_1/office_25/Annotations 29 | Area_1/office_26/Annotations 30 | Area_1/office_27/Annotations 31 | Area_1/office_28/Annotations 32 | Area_1/office_29/Annotations 33 | Area_1/office_2/Annotations 34 | Area_1/office_30/Annotations 35 | Area_1/office_31/Annotations 36 | Area_1/office_3/Annotations 37 | Area_1/office_4/Annotations 38 | Area_1/office_5/Annotations 39 | Area_1/office_6/Annotations 40 | Area_1/office_7/Annotations 41 | Area_1/office_8/Annotations 42 | Area_1/office_9/Annotations 43 | Area_1/pantry_1/Annotations 44 | Area_1/WC_1/Annotations 45 | Area_2/auditorium_1/Annotations 46 | Area_2/auditorium_2/Annotations 47 | Area_2/conferenceRoom_1/Annotations 48 | Area_2/hallway_10/Annotations 49 | Area_2/hallway_11/Annotations 50 | Area_2/hallway_12/Annotations 51 | Area_2/hallway_1/Annotations 52 | Area_2/hallway_2/Annotations 53 | Area_2/hallway_3/Annotations 54 | Area_2/hallway_4/Annotations 55 | Area_2/hallway_5/Annotations 56 | Area_2/hallway_6/Annotations 57 | Area_2/hallway_7/Annotations 58 | Area_2/hallway_8/Annotations 59 | Area_2/hallway_9/Annotations 60 | Area_2/office_10/Annotations 61 | Area_2/office_11/Annotations 62 | Area_2/office_12/Annotations 63 | Area_2/office_13/Annotations 64 | Area_2/office_14/Annotations 65 | Area_2/office_1/Annotations 66 | Area_2/office_2/Annotations 67 | Area_2/office_3/Annotations 68 | Area_2/office_4/Annotations 69 | Area_2/office_5/Annotations 70 | Area_2/office_6/Annotations 71 | Area_2/office_7/Annotations 72 | Area_2/office_8/Annotations 73 | Area_2/office_9/Annotations 74 | Area_2/storage_1/Annotations 75 | Area_2/storage_2/Annotations 76 | Area_2/storage_3/Annotations 77 | Area_2/storage_4/Annotations 78 | Area_2/storage_5/Annotations 79 | Area_2/storage_6/Annotations 80 | Area_2/storage_7/Annotations 81 | Area_2/storage_8/Annotations 82 | Area_2/storage_9/Annotations 83 | Area_2/WC_1/Annotations 84 | Area_2/WC_2/Annotations 85 | Area_3/conferenceRoom_1/Annotations 86 | Area_3/hallway_1/Annotations 87 | Area_3/hallway_2/Annotations 88 | Area_3/hallway_3/Annotations 89 | Area_3/hallway_4/Annotations 90 | Area_3/hallway_5/Annotations 91 | Area_3/hallway_6/Annotations 92 | Area_3/lounge_1/Annotations 93 | Area_3/lounge_2/Annotations 94 | Area_3/office_10/Annotations 95 | Area_3/office_1/Annotations 96 | Area_3/office_2/Annotations 97 | Area_3/office_3/Annotations 98 | Area_3/office_4/Annotations 99 | Area_3/office_5/Annotations 100 | Area_3/office_6/Annotations 101 | Area_3/office_7/Annotations 102 | Area_3/office_8/Annotations 103 | Area_3/office_9/Annotations 104 | Area_3/storage_1/Annotations 105 | Area_3/storage_2/Annotations 106 | Area_3/WC_1/Annotations 107 | Area_3/WC_2/Annotations 108 | Area_4/conferenceRoom_1/Annotations 109 | Area_4/conferenceRoom_2/Annotations 110 | Area_4/conferenceRoom_3/Annotations 111 | Area_4/hallway_10/Annotations 112 | Area_4/hallway_11/Annotations 113 | Area_4/hallway_12/Annotations 114 | 
Area_4/hallway_13/Annotations 115 | Area_4/hallway_14/Annotations 116 | Area_4/hallway_1/Annotations 117 | Area_4/hallway_2/Annotations 118 | Area_4/hallway_3/Annotations 119 | Area_4/hallway_4/Annotations 120 | Area_4/hallway_5/Annotations 121 | Area_4/hallway_6/Annotations 122 | Area_4/hallway_7/Annotations 123 | Area_4/hallway_8/Annotations 124 | Area_4/hallway_9/Annotations 125 | Area_4/lobby_1/Annotations 126 | Area_4/lobby_2/Annotations 127 | Area_4/office_10/Annotations 128 | Area_4/office_11/Annotations 129 | Area_4/office_12/Annotations 130 | Area_4/office_13/Annotations 131 | Area_4/office_14/Annotations 132 | Area_4/office_15/Annotations 133 | Area_4/office_16/Annotations 134 | Area_4/office_17/Annotations 135 | Area_4/office_18/Annotations 136 | Area_4/office_19/Annotations 137 | Area_4/office_1/Annotations 138 | Area_4/office_20/Annotations 139 | Area_4/office_21/Annotations 140 | Area_4/office_22/Annotations 141 | Area_4/office_2/Annotations 142 | Area_4/office_3/Annotations 143 | Area_4/office_4/Annotations 144 | Area_4/office_5/Annotations 145 | Area_4/office_6/Annotations 146 | Area_4/office_7/Annotations 147 | Area_4/office_8/Annotations 148 | Area_4/office_9/Annotations 149 | Area_4/storage_1/Annotations 150 | Area_4/storage_2/Annotations 151 | Area_4/storage_3/Annotations 152 | Area_4/storage_4/Annotations 153 | Area_4/WC_1/Annotations 154 | Area_4/WC_2/Annotations 155 | Area_4/WC_3/Annotations 156 | Area_4/WC_4/Annotations 157 | Area_5/conferenceRoom_1/Annotations 158 | Area_5/conferenceRoom_2/Annotations 159 | Area_5/conferenceRoom_3/Annotations 160 | Area_5/hallway_10/Annotations 161 | Area_5/hallway_11/Annotations 162 | Area_5/hallway_12/Annotations 163 | Area_5/hallway_13/Annotations 164 | Area_5/hallway_14/Annotations 165 | Area_5/hallway_15/Annotations 166 | Area_5/hallway_1/Annotations 167 | Area_5/hallway_2/Annotations 168 | Area_5/hallway_3/Annotations 169 | Area_5/hallway_4/Annotations 170 | Area_5/hallway_5/Annotations 171 | Area_5/hallway_6/Annotations 172 | Area_5/hallway_7/Annotations 173 | Area_5/hallway_8/Annotations 174 | Area_5/hallway_9/Annotations 175 | Area_5/lobby_1/Annotations 176 | Area_5/office_10/Annotations 177 | Area_5/office_11/Annotations 178 | Area_5/office_12/Annotations 179 | Area_5/office_13/Annotations 180 | Area_5/office_14/Annotations 181 | Area_5/office_15/Annotations 182 | Area_5/office_16/Annotations 183 | Area_5/office_17/Annotations 184 | Area_5/office_18/Annotations 185 | Area_5/office_19/Annotations 186 | Area_5/office_1/Annotations 187 | Area_5/office_20/Annotations 188 | Area_5/office_21/Annotations 189 | Area_5/office_22/Annotations 190 | Area_5/office_23/Annotations 191 | Area_5/office_24/Annotations 192 | Area_5/office_25/Annotations 193 | Area_5/office_26/Annotations 194 | Area_5/office_27/Annotations 195 | Area_5/office_28/Annotations 196 | Area_5/office_29/Annotations 197 | Area_5/office_2/Annotations 198 | Area_5/office_30/Annotations 199 | Area_5/office_31/Annotations 200 | Area_5/office_32/Annotations 201 | Area_5/office_33/Annotations 202 | Area_5/office_34/Annotations 203 | Area_5/office_35/Annotations 204 | Area_5/office_36/Annotations 205 | Area_5/office_37/Annotations 206 | Area_5/office_38/Annotations 207 | Area_5/office_39/Annotations 208 | Area_5/office_3/Annotations 209 | Area_5/office_40/Annotations 210 | Area_5/office_41/Annotations 211 | Area_5/office_42/Annotations 212 | Area_5/office_4/Annotations 213 | Area_5/office_5/Annotations 214 | Area_5/office_6/Annotations 215 | Area_5/office_7/Annotations 216 
| Area_5/office_8/Annotations 217 | Area_5/office_9/Annotations 218 | Area_5/pantry_1/Annotations 219 | Area_5/storage_1/Annotations 220 | Area_5/storage_2/Annotations 221 | Area_5/storage_3/Annotations 222 | Area_5/storage_4/Annotations 223 | Area_5/WC_1/Annotations 224 | Area_5/WC_2/Annotations 225 | Area_6/conferenceRoom_1/Annotations 226 | Area_6/copyRoom_1/Annotations 227 | Area_6/hallway_1/Annotations 228 | Area_6/hallway_2/Annotations 229 | Area_6/hallway_3/Annotations 230 | Area_6/hallway_4/Annotations 231 | Area_6/hallway_5/Annotations 232 | Area_6/hallway_6/Annotations 233 | Area_6/lounge_1/Annotations 234 | Area_6/office_10/Annotations 235 | Area_6/office_11/Annotations 236 | Area_6/office_12/Annotations 237 | Area_6/office_13/Annotations 238 | Area_6/office_14/Annotations 239 | Area_6/office_15/Annotations 240 | Area_6/office_16/Annotations 241 | Area_6/office_17/Annotations 242 | Area_6/office_18/Annotations 243 | Area_6/office_19/Annotations 244 | Area_6/office_1/Annotations 245 | Area_6/office_20/Annotations 246 | Area_6/office_21/Annotations 247 | Area_6/office_22/Annotations 248 | Area_6/office_23/Annotations 249 | Area_6/office_24/Annotations 250 | Area_6/office_25/Annotations 251 | Area_6/office_26/Annotations 252 | Area_6/office_27/Annotations 253 | Area_6/office_28/Annotations 254 | Area_6/office_29/Annotations 255 | Area_6/office_2/Annotations 256 | Area_6/office_30/Annotations 257 | Area_6/office_31/Annotations 258 | Area_6/office_32/Annotations 259 | Area_6/office_33/Annotations 260 | Area_6/office_34/Annotations 261 | Area_6/office_35/Annotations 262 | Area_6/office_36/Annotations 263 | Area_6/office_37/Annotations 264 | Area_6/office_3/Annotations 265 | Area_6/office_4/Annotations 266 | Area_6/office_5/Annotations 267 | Area_6/office_6/Annotations 268 | Area_6/office_7/Annotations 269 | Area_6/office_8/Annotations 270 | Area_6/office_9/Annotations 271 | Area_6/openspace_1/Annotations 272 | Area_6/pantry_1/Annotations 273 | -------------------------------------------------------------------------------- /sem_seg/meta/area6_data_label.txt: -------------------------------------------------------------------------------- 1 | data/stanford_indoor3d/Area_6_conferenceRoom_1.npy 2 | data/stanford_indoor3d/Area_6_copyRoom_1.npy 3 | data/stanford_indoor3d/Area_6_hallway_1.npy 4 | data/stanford_indoor3d/Area_6_hallway_2.npy 5 | data/stanford_indoor3d/Area_6_hallway_3.npy 6 | data/stanford_indoor3d/Area_6_hallway_4.npy 7 | data/stanford_indoor3d/Area_6_hallway_5.npy 8 | data/stanford_indoor3d/Area_6_hallway_6.npy 9 | data/stanford_indoor3d/Area_6_lounge_1.npy 10 | data/stanford_indoor3d/Area_6_office_10.npy 11 | data/stanford_indoor3d/Area_6_office_11.npy 12 | data/stanford_indoor3d/Area_6_office_12.npy 13 | data/stanford_indoor3d/Area_6_office_13.npy 14 | data/stanford_indoor3d/Area_6_office_14.npy 15 | data/stanford_indoor3d/Area_6_office_15.npy 16 | data/stanford_indoor3d/Area_6_office_16.npy 17 | data/stanford_indoor3d/Area_6_office_17.npy 18 | data/stanford_indoor3d/Area_6_office_18.npy 19 | data/stanford_indoor3d/Area_6_office_19.npy 20 | data/stanford_indoor3d/Area_6_office_1.npy 21 | data/stanford_indoor3d/Area_6_office_20.npy 22 | data/stanford_indoor3d/Area_6_office_21.npy 23 | data/stanford_indoor3d/Area_6_office_22.npy 24 | data/stanford_indoor3d/Area_6_office_23.npy 25 | data/stanford_indoor3d/Area_6_office_24.npy 26 | data/stanford_indoor3d/Area_6_office_25.npy 27 | data/stanford_indoor3d/Area_6_office_26.npy 28 | data/stanford_indoor3d/Area_6_office_27.npy 29 | 
data/stanford_indoor3d/Area_6_office_28.npy
30 | data/stanford_indoor3d/Area_6_office_29.npy
31 | data/stanford_indoor3d/Area_6_office_2.npy
32 | data/stanford_indoor3d/Area_6_office_30.npy
33 | data/stanford_indoor3d/Area_6_office_31.npy
34 | data/stanford_indoor3d/Area_6_office_32.npy
35 | data/stanford_indoor3d/Area_6_office_33.npy
36 | data/stanford_indoor3d/Area_6_office_34.npy
37 | data/stanford_indoor3d/Area_6_office_35.npy
38 | data/stanford_indoor3d/Area_6_office_36.npy
39 | data/stanford_indoor3d/Area_6_office_37.npy
40 | data/stanford_indoor3d/Area_6_office_3.npy
41 | data/stanford_indoor3d/Area_6_office_4.npy
42 | data/stanford_indoor3d/Area_6_office_5.npy
43 | data/stanford_indoor3d/Area_6_office_6.npy
44 | data/stanford_indoor3d/Area_6_office_7.npy
45 | data/stanford_indoor3d/Area_6_office_8.npy
46 | data/stanford_indoor3d/Area_6_office_9.npy
47 | data/stanford_indoor3d/Area_6_openspace_1.npy
48 | data/stanford_indoor3d/Area_6_pantry_1.npy
49 | 
--------------------------------------------------------------------------------
/sem_seg/meta/class_names.txt:
--------------------------------------------------------------------------------
1 | ceiling
2 | floor
3 | wall
4 | beam
5 | column
6 | window
7 | door
8 | table
9 | chair
10 | sofa
11 | bookcase
12 | board
13 | clutter
14 | 
--------------------------------------------------------------------------------
/sem_seg/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import math
3 | import time
4 | import numpy as np
5 | import os
6 | import sys
7 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
8 | ROOT_DIR = os.path.dirname(BASE_DIR)
9 | sys.path.append(os.path.join(ROOT_DIR, 'utils'))
10 | import tf_util
11 | 
12 | def placeholder_inputs(batch_size, num_point):
13 |     pointclouds_pl = tf.placeholder(tf.float32,
14 |                                     shape=(batch_size, num_point, 9))
15 |     labels_pl = tf.placeholder(tf.int32,
16 |                                shape=(batch_size, num_point))
17 |     return pointclouds_pl, labels_pl
18 | 
19 | def get_model(point_cloud, is_training, bn_decay=None):
20 |     """ ConvNet baseline; input is a BxNx9 point cloud (XYZ, RGB, normalized room coordinates) """
21 |     batch_size = point_cloud.get_shape()[0].value
22 |     num_point = point_cloud.get_shape()[1].value
23 | 
24 |     input_image = tf.expand_dims(point_cloud, -1)
25 |     # CONV
26 |     net = tf_util.conv2d(input_image, 64, [1,9], padding='VALID', stride=[1,1],
27 |                          bn=True, is_training=is_training, scope='conv1', bn_decay=bn_decay)
28 |     net = tf_util.conv2d(net, 64, [1,1], padding='VALID', stride=[1,1],
29 |                          bn=True, is_training=is_training, scope='conv2', bn_decay=bn_decay)
30 |     net = tf_util.conv2d(net, 64, [1,1], padding='VALID', stride=[1,1],
31 |                          bn=True, is_training=is_training, scope='conv3', bn_decay=bn_decay)
32 |     net = tf_util.conv2d(net, 128, [1,1], padding='VALID', stride=[1,1],
33 |                          bn=True, is_training=is_training, scope='conv4', bn_decay=bn_decay)
34 |     points_feat1 = tf_util.conv2d(net, 1024, [1,1], padding='VALID', stride=[1,1],
35 |                          bn=True, is_training=is_training, scope='conv5', bn_decay=bn_decay)
36 |     # MAX
37 |     pc_feat1 = tf_util.max_pool2d(points_feat1, [num_point,1], padding='VALID', scope='maxpool1')
38 |     # FC
39 |     pc_feat1 = tf.reshape(pc_feat1, [batch_size, -1])
40 |     pc_feat1 = tf_util.fully_connected(pc_feat1, 256, bn=True, is_training=is_training, scope='fc1', bn_decay=bn_decay)
41 |     pc_feat1 = tf_util.fully_connected(pc_feat1, 128, bn=True, is_training=is_training, scope='fc2', bn_decay=bn_decay)
42 |     print(pc_feat1)
43 | 
44 |     # CONCAT
45 |     pc_feat1_expand = tf.tile(tf.reshape(pc_feat1, [batch_size, 1, 1, -1]), [1, num_point, 1, 1])
46 |     points_feat1_concat = tf.concat(axis=3, values=[points_feat1, pc_feat1_expand])
47 | 
48 |     # CONV
49 |     net = tf_util.conv2d(points_feat1_concat, 512, [1,1], padding='VALID', stride=[1,1],
50 |                          bn=True, is_training=is_training, scope='conv6')
51 |     net = tf_util.conv2d(net, 256, [1,1], padding='VALID', stride=[1,1],
52 |                          bn=True, is_training=is_training, scope='conv7')
53 |     net = tf_util.dropout(net, keep_prob=0.7, is_training=is_training, scope='dp1')
54 |     net = tf_util.conv2d(net, 13, [1,1], padding='VALID', stride=[1,1],
55 |                          activation_fn=None, scope='conv8')
56 |     net = tf.squeeze(net, [2])
57 | 
58 |     return net
59 | 
60 | def get_loss(pred, label):
61 |     """ pred: B,N,13
62 |         label: B,N """
63 |     loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=pred, labels=label)
64 |     return tf.reduce_mean(loss)
65 | 
66 | if __name__ == "__main__":
67 |     with tf.Graph().as_default():
68 |         a = tf.placeholder(tf.float32, shape=(32,4096,9))
69 |         net = get_model(a, tf.constant(True))
70 |         with tf.Session() as sess:
71 |             init = tf.global_variables_initializer()
72 |             sess.run(init)
73 |             start = time.time()
74 |             for i in range(100):
75 |                 print(i)
76 |                 sess.run(net, feed_dict={a:np.random.rand(32,4096,9)})
77 |             print(time.time() - start)
78 | 
--------------------------------------------------------------------------------
/sem_seg/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import h5py
4 | import numpy as np
5 | import tensorflow as tf
6 | import socket
7 | 
8 | import os
9 | import sys
10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11 | ROOT_DIR = os.path.dirname(BASE_DIR)
12 | sys.path.append(BASE_DIR)
13 | sys.path.append(ROOT_DIR)
14 | sys.path.append(os.path.join(ROOT_DIR, 'utils'))
15 | import provider
16 | import tf_util
17 | from model import *
18 | 
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
21 | parser.add_argument('--log_dir', default='log', help='Log dir [default: log]')
22 | parser.add_argument('--num_point', type=int, default=4096, help='Point number [default: 4096]')
23 | parser.add_argument('--max_epoch', type=int, default=50, help='Epoch to run [default: 50]')
24 | parser.add_argument('--batch_size', type=int, default=24, help='Batch Size during training [default: 24]')
25 | parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate [default: 0.001]')
26 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for the momentum optimizer [default: 0.9]')
27 | parser.add_argument('--optimizer', default='adam', help='adam or momentum [default: adam]')
28 | parser.add_argument('--decay_step', type=int, default=300000, help='Decay step for lr decay [default: 300000]')
29 | parser.add_argument('--decay_rate', type=float, default=0.5, help='Decay rate for lr decay [default: 0.5]')
30 | parser.add_argument('--test_area', type=int, default=6, help='Which area to use for test, option: 1-6 [default: 6]')
31 | FLAGS = parser.parse_args()
32 | 
33 | 
34 | BATCH_SIZE = FLAGS.batch_size
35 | NUM_POINT = FLAGS.num_point
36 | MAX_EPOCH = FLAGS.max_epoch
37 | 
38 | BASE_LEARNING_RATE = FLAGS.learning_rate
39 | GPU_INDEX = FLAGS.gpu
40 | MOMENTUM = FLAGS.momentum
41 | OPTIMIZER = FLAGS.optimizer
42 | DECAY_STEP = FLAGS.decay_step
43 | DECAY_RATE = FLAGS.decay_rate
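# (Annotation, not in the original source: with the defaults above, the
# staircase schedule in get_learning_rate() further down works out to
#     lr = max(0.001 * 0.5 ** floor(samples_seen / 300000), 1e-5)
# where samples_seen = global_step * BATCH_SIZE, i.e. the learning rate
# roughly halves every 300k training samples.)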
44 | 45 | LOG_DIR = FLAGS.log_dir 46 | if not os.path.exists(LOG_DIR): os.mkdir(LOG_DIR) 47 | os.system('cp model.py %s' % (LOG_DIR)) # bkp of model def 48 | os.system('cp train.py %s' % (LOG_DIR)) # bkp of train procedure 49 | LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w') 50 | LOG_FOUT.write(str(FLAGS)+'\n') 51 | 52 | MAX_NUM_POINT = 4096 53 | NUM_CLASSES = 13 54 | 55 | BN_INIT_DECAY = 0.5 56 | BN_DECAY_DECAY_RATE = 0.5 57 | #BN_DECAY_DECAY_STEP = float(DECAY_STEP * 2) 58 | BN_DECAY_DECAY_STEP = float(DECAY_STEP) 59 | BN_DECAY_CLIP = 0.99 60 | 61 | HOSTNAME = socket.gethostname() 62 | 63 | ALL_FILES = provider.getDataFiles('indoor3d_sem_seg_hdf5_data/all_files.txt') 64 | room_filelist = [line.rstrip() for line in open('indoor3d_sem_seg_hdf5_data/room_filelist.txt')] 65 | 66 | # Load ALL data 67 | data_batch_list = [] 68 | label_batch_list = [] 69 | for h5_filename in ALL_FILES: 70 | data_batch, label_batch = provider.loadDataFile(h5_filename) 71 | data_batch_list.append(data_batch) 72 | label_batch_list.append(label_batch) 73 | data_batches = np.concatenate(data_batch_list, 0) 74 | label_batches = np.concatenate(label_batch_list, 0) 75 | print(data_batches.shape) 76 | print(label_batches.shape) 77 | 78 | test_area = 'Area_'+str(FLAGS.test_area) 79 | train_idxs = [] 80 | test_idxs = [] 81 | for i,room_name in enumerate(room_filelist): 82 | if test_area in room_name: 83 | test_idxs.append(i) 84 | else: 85 | train_idxs.append(i) 86 | 87 | train_data = data_batches[train_idxs,...] 88 | train_label = label_batches[train_idxs] 89 | test_data = data_batches[test_idxs,...] 90 | test_label = label_batches[test_idxs] 91 | print(train_data.shape, train_label.shape) 92 | print(test_data.shape, test_label.shape) 93 | 94 | 95 | 96 | 97 | def log_string(out_str): 98 | LOG_FOUT.write(out_str+'\n') 99 | LOG_FOUT.flush() 100 | print(out_str) 101 | 102 | 103 | def get_learning_rate(batch): 104 | learning_rate = tf.train.exponential_decay( 105 | BASE_LEARNING_RATE, # Base learning rate. 106 | batch * BATCH_SIZE, # Current index into the dataset. 107 | DECAY_STEP, # Decay step. 108 | DECAY_RATE, # Decay rate. 109 | staircase=True) 110 | learning_rate = tf.maximum(learning_rate, 0.00001) # CLIP THE LEARNING RATE!! 111 | return learning_rate 112 | 113 | def get_bn_decay(batch): 114 | bn_momentum = tf.train.exponential_decay( 115 | BN_INIT_DECAY, 116 | batch*BATCH_SIZE, 117 | BN_DECAY_DECAY_STEP, 118 | BN_DECAY_DECAY_RATE, 119 | staircase=True) 120 | bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum) 121 | return bn_decay 122 | 123 | def train(): 124 | with tf.Graph().as_default(): 125 | with tf.device('/gpu:'+str(GPU_INDEX)): 126 | pointclouds_pl, labels_pl = placeholder_inputs(BATCH_SIZE, NUM_POINT) 127 | is_training_pl = tf.placeholder(tf.bool, shape=()) 128 | 129 | # Note the global_step=batch parameter to minimize. 130 | # That tells the optimizer to helpfully increment the 'batch' parameter for you every time it trains. 
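            # (Annotation, not in the original source: 'batch' is the global
            # step counter; because it is passed as global_step to
            # optimizer.minimize() below, TensorFlow increments it once per
            # training step, which in turn drives both the learning-rate and
            # BN-decay schedules above.)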
131 | batch = tf.Variable(0) 132 | bn_decay = get_bn_decay(batch) 133 | tf.summary.scalar('bn_decay', bn_decay) 134 | 135 | # Get model and loss 136 | pred = get_model(pointclouds_pl, is_training_pl, bn_decay=bn_decay) 137 | loss = get_loss(pred, labels_pl) 138 | tf.summary.scalar('loss', loss) 139 | 140 | correct = tf.equal(tf.argmax(pred, 2), tf.to_int64(labels_pl)) 141 | accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE*NUM_POINT) 142 | tf.summary.scalar('accuracy', accuracy) 143 | 144 | # Get training operator 145 | learning_rate = get_learning_rate(batch) 146 | tf.summary.scalar('learning_rate', learning_rate) 147 | if OPTIMIZER == 'momentum': 148 | optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=MOMENTUM) 149 | elif OPTIMIZER == 'adam': 150 | optimizer = tf.train.AdamOptimizer(learning_rate) 151 | train_op = optimizer.minimize(loss, global_step=batch) 152 | 153 | # Add ops to save and restore all the variables. 154 | saver = tf.train.Saver() 155 | 156 | # Create a session 157 | config = tf.ConfigProto() 158 | config.gpu_options.allow_growth = True 159 | config.allow_soft_placement = True 160 | config.log_device_placement = True 161 | sess = tf.Session(config=config) 162 | 163 | # Add summary writers 164 | merged = tf.summary.merge_all() 165 | train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'), 166 | sess.graph) 167 | test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test')) 168 | 169 | # Init variables 170 | init = tf.global_variables_initializer() 171 | sess.run(init, {is_training_pl:True}) 172 | 173 | ops = {'pointclouds_pl': pointclouds_pl, 174 | 'labels_pl': labels_pl, 175 | 'is_training_pl': is_training_pl, 176 | 'pred': pred, 177 | 'loss': loss, 178 | 'train_op': train_op, 179 | 'merged': merged, 180 | 'step': batch} 181 | 182 | for epoch in range(MAX_EPOCH): 183 | log_string('**** EPOCH %03d ****' % (epoch)) 184 | sys.stdout.flush() 185 | 186 | train_one_epoch(sess, ops, train_writer) 187 | eval_one_epoch(sess, ops, test_writer) 188 | 189 | # Save the variables to disk. 
190 | if epoch % 10 == 0: 191 | save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt")) 192 | log_string("Model saved in file: %s" % save_path) 193 | 194 | 195 | 196 | def train_one_epoch(sess, ops, train_writer): 197 | """ ops: dict mapping from string to tf ops """ 198 | is_training = True 199 | 200 | log_string('----') 201 | current_data, current_label, _ = provider.shuffle_data(train_data[:,0:NUM_POINT,:], train_label) 202 | 203 | file_size = current_data.shape[0] 204 | num_batches = file_size // BATCH_SIZE 205 | 206 | total_correct = 0 207 | total_seen = 0 208 | loss_sum = 0 209 | 210 | for batch_idx in range(num_batches): 211 | if batch_idx % 100 == 0: 212 | print('Current batch/total batch num: %d/%d'%(batch_idx,num_batches)) 213 | start_idx = batch_idx * BATCH_SIZE 214 | end_idx = (batch_idx+1) * BATCH_SIZE 215 | 216 | feed_dict = {ops['pointclouds_pl']: current_data[start_idx:end_idx, :, :], 217 | ops['labels_pl']: current_label[start_idx:end_idx], 218 | ops['is_training_pl']: is_training,} 219 | summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']], 220 | feed_dict=feed_dict) 221 | train_writer.add_summary(summary, step) 222 | pred_val = np.argmax(pred_val, 2) 223 | correct = np.sum(pred_val == current_label[start_idx:end_idx]) 224 | total_correct += correct 225 | total_seen += (BATCH_SIZE*NUM_POINT) 226 | loss_sum += loss_val 227 | 228 | log_string('mean loss: %f' % (loss_sum / float(num_batches))) 229 | log_string('accuracy: %f' % (total_correct / float(total_seen))) 230 | 231 | 232 | def eval_one_epoch(sess, ops, test_writer): 233 | """ ops: dict mapping from string to tf ops """ 234 | is_training = False 235 | total_correct = 0 236 | total_seen = 0 237 | loss_sum = 0 238 | total_seen_class = [0 for _ in range(NUM_CLASSES)] 239 | total_correct_class = [0 for _ in range(NUM_CLASSES)] 240 | 241 | log_string('----') 242 | current_data = test_data[:,0:NUM_POINT,:] 243 | current_label = np.squeeze(test_label) 244 | 245 | file_size = current_data.shape[0] 246 | num_batches = file_size // BATCH_SIZE 247 | 248 | for batch_idx in range(num_batches): 249 | start_idx = batch_idx * BATCH_SIZE 250 | end_idx = (batch_idx+1) * BATCH_SIZE 251 | 252 | feed_dict = {ops['pointclouds_pl']: current_data[start_idx:end_idx, :, :], 253 | ops['labels_pl']: current_label[start_idx:end_idx], 254 | ops['is_training_pl']: is_training} 255 | summary, step, loss_val, pred_val = sess.run([ops['merged'], ops['step'], ops['loss'], ops['pred']], 256 | feed_dict=feed_dict) 257 | test_writer.add_summary(summary, step) 258 | pred_val = np.argmax(pred_val, 2) 259 | correct = np.sum(pred_val == current_label[start_idx:end_idx]) 260 | total_correct += correct 261 | total_seen += (BATCH_SIZE*NUM_POINT) 262 | loss_sum += (loss_val*BATCH_SIZE) 263 | for i in range(start_idx, end_idx): 264 | for j in range(NUM_POINT): 265 | l = current_label[i, j] 266 | total_seen_class[l] += 1 267 | total_correct_class[l] += (pred_val[i-start_idx, j] == l) 268 | 269 | log_string('eval mean loss: %f' % (loss_sum / float(total_seen/NUM_POINT))) 270 | log_string('eval accuracy: %f'% (total_correct / float(total_seen))) 271 | log_string('eval avg class acc: %f' % (np.mean(np.array(total_correct_class)/np.array(total_seen_class,dtype=np.float)))) 272 | 273 | 274 | 275 | if __name__ == "__main__": 276 | train() 277 | LOG_FOUT.close() 278 | -------------------------------------------------------------------------------- /train.py: 
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import h5py
4 | import numpy as np
5 | import tensorflow as tf
6 | import socket
7 | import importlib
8 | import os
9 | import sys
10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11 | sys.path.append(BASE_DIR)
12 | sys.path.append(os.path.join(BASE_DIR, 'models'))
13 | sys.path.append(os.path.join(BASE_DIR, 'utils'))
14 | import provider
15 | import tf_util
16 | 
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
19 | parser.add_argument('--model', default='pointnet_cls', help='Model name: pointnet_cls or pointnet_cls_basic [default: pointnet_cls]')
20 | parser.add_argument('--log_dir', default='log', help='Log dir [default: log]')
21 | parser.add_argument('--num_point', type=int, default=1024, help='Point Number [256/512/1024/2048] [default: 1024]')
22 | parser.add_argument('--max_epoch', type=int, default=250, help='Epoch to run [default: 250]')
23 | parser.add_argument('--batch_size', type=int, default=32, help='Batch Size during training [default: 32]')
24 | parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate [default: 0.001]')
25 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for momentum optimizer [default: 0.9]')
26 | parser.add_argument('--optimizer', default='adam', help='adam or momentum [default: adam]')
27 | parser.add_argument('--decay_step', type=int, default=200000, help='Decay step for lr decay [default: 200000]')
28 | parser.add_argument('--decay_rate', type=float, default=0.7, help='Decay rate for lr decay [default: 0.7]')
29 | FLAGS = parser.parse_args()
30 | 
31 | 
32 | BATCH_SIZE = FLAGS.batch_size
33 | NUM_POINT = FLAGS.num_point
34 | MAX_EPOCH = FLAGS.max_epoch
35 | BASE_LEARNING_RATE = FLAGS.learning_rate
36 | GPU_INDEX = FLAGS.gpu
37 | MOMENTUM = FLAGS.momentum
38 | OPTIMIZER = FLAGS.optimizer
39 | DECAY_STEP = FLAGS.decay_step
40 | DECAY_RATE = FLAGS.decay_rate
41 | 
42 | MODEL = importlib.import_module(FLAGS.model) # import network module
43 | MODEL_FILE = os.path.join(BASE_DIR, 'models', FLAGS.model+'.py')
44 | LOG_DIR = FLAGS.log_dir
45 | if not os.path.exists(LOG_DIR): os.mkdir(LOG_DIR)
46 | os.system('cp %s %s' % (MODEL_FILE, LOG_DIR)) # bkp of model def
47 | os.system('cp train.py %s' % (LOG_DIR)) # bkp of train procedure
48 | LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w')
49 | LOG_FOUT.write(str(FLAGS)+'\n')
50 | 
51 | MAX_NUM_POINT = 2048
52 | NUM_CLASSES = 40
53 | 
54 | BN_INIT_DECAY = 0.5
55 | BN_DECAY_DECAY_RATE = 0.5
56 | BN_DECAY_DECAY_STEP = float(DECAY_STEP)
57 | BN_DECAY_CLIP = 0.99
58 | 
59 | HOSTNAME = socket.gethostname()
60 | 
61 | # ModelNet40 official train/test split
62 | TRAIN_FILES = provider.getDataFiles( \
63 | os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/train_files.txt'))
64 | TEST_FILES = provider.getDataFiles(\
65 | os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/test_files.txt'))
66 | 
67 | def log_string(out_str):
68 | LOG_FOUT.write(out_str+'\n')
69 | LOG_FOUT.flush()
70 | print(out_str)
71 | 
72 | def get_learning_rate(batch):
73 | learning_rate = tf.train.exponential_decay(
74 | BASE_LEARNING_RATE, # Base learning rate.
75 | batch * BATCH_SIZE, # Current index into the dataset.
76 | DECAY_STEP, # Decay step.
77 | DECAY_RATE, # Decay rate.
78 | staircase=True)
79 | learning_rate = tf.maximum(learning_rate, 0.00001) # CLIP THE LEARNING RATE!
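# Worked example with the defaults: lr = 0.001 * 0.7**floor(batch*32/200000),
# so the first drop (to 0.0007) happens after 200000/32 = 6250 training steps.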
80 | return learning_rate
81 | 
82 | def get_bn_decay(batch):
83 | bn_momentum = tf.train.exponential_decay(
84 | BN_INIT_DECAY,
85 | batch*BATCH_SIZE,
86 | BN_DECAY_DECAY_STEP,
87 | BN_DECAY_DECAY_RATE,
88 | staircase=True)
89 | bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum)
90 | return bn_decay
91 | 
92 | def train():
93 | with tf.Graph().as_default():
94 | with tf.device('/gpu:'+str(GPU_INDEX)):
95 | pointclouds_pl, labels_pl = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT)
96 | is_training_pl = tf.placeholder(tf.bool, shape=())
97 | print(is_training_pl)
98 | 
99 | # Note the global_step=batch parameter to minimize.
100 | # That tells the optimizer to helpfully increment the 'batch' parameter for you every time it trains.
101 | batch = tf.Variable(0)
102 | bn_decay = get_bn_decay(batch)
103 | tf.summary.scalar('bn_decay', bn_decay)
104 | 
105 | # Get model and loss
106 | pred, end_points = MODEL.get_model(pointclouds_pl, is_training_pl, bn_decay=bn_decay)
107 | loss = MODEL.get_loss(pred, labels_pl, end_points)
108 | tf.summary.scalar('loss', loss)
109 | 
110 | correct = tf.equal(tf.argmax(pred, 1), tf.to_int64(labels_pl))
111 | accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE)
112 | tf.summary.scalar('accuracy', accuracy)
113 | 
114 | # Get training operator
115 | learning_rate = get_learning_rate(batch)
116 | tf.summary.scalar('learning_rate', learning_rate)
117 | if OPTIMIZER == 'momentum':
118 | optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=MOMENTUM)
119 | elif OPTIMIZER == 'adam':
120 | optimizer = tf.train.AdamOptimizer(learning_rate)
121 | train_op = optimizer.minimize(loss, global_step=batch)
122 | 
123 | # Add ops to save and restore all the variables.
124 | saver = tf.train.Saver()
125 | 
126 | # Create a session
127 | config = tf.ConfigProto()
128 | config.gpu_options.allow_growth = True
129 | config.allow_soft_placement = True
130 | config.log_device_placement = False
131 | sess = tf.Session(config=config)
132 | 
133 | # Add summary writers
134 | #merged = tf.merge_all_summaries()
135 | merged = tf.summary.merge_all()
136 | train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'),
137 | sess.graph)
138 | test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test'))
139 | 
140 | # Init variables
141 | init = tf.global_variables_initializer()
142 | # To fix the bug introduced in TF 0.12.1 as in
143 | # http://stackoverflow.com/questions/41543774/invalidargumenterror-for-tensor-bool-tensorflow-0-12-1
144 | #sess.run(init)
145 | sess.run(init, {is_training_pl: True})
146 | 
147 | ops = {'pointclouds_pl': pointclouds_pl,
148 | 'labels_pl': labels_pl,
149 | 'is_training_pl': is_training_pl,
150 | 'pred': pred,
151 | 'loss': loss,
152 | 'train_op': train_op,
153 | 'merged': merged,
154 | 'step': batch}
155 | 
156 | for epoch in range(MAX_EPOCH):
157 | log_string('**** EPOCH %03d ****' % (epoch))
158 | sys.stdout.flush()
159 | 
160 | train_one_epoch(sess, ops, train_writer)
161 | eval_one_epoch(sess, ops, test_writer)
162 | 
163 | # Save the variables to disk.
164 | if epoch % 10 == 0: 165 | save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt")) 166 | log_string("Model saved in file: %s" % save_path) 167 | 168 | def train_one_epoch(sess, ops, train_writer): 169 | """ ops: dict mapping from string to tf ops """ 170 | is_training = True 171 | 172 | # Shuffle train files 173 | train_file_idxs = np.arange(0, len(TRAIN_FILES)) 174 | np.random.shuffle(train_file_idxs) 175 | 176 | for fn in range(len(TRAIN_FILES)): 177 | log_string('----' + str(fn) + '-----') 178 | current_data, current_label = provider.loadDataFile(TRAIN_FILES[train_file_idxs[fn]]) 179 | current_data = current_data[:,0:NUM_POINT,:] 180 | current_data, current_label, _ = provider.shuffle_data(current_data, np.squeeze(current_label)) 181 | current_label = np.squeeze(current_label) 182 | 183 | file_size = current_data.shape[0] 184 | num_batches = file_size // BATCH_SIZE 185 | 186 | total_correct = 0 187 | total_seen = 0 188 | loss_sum = 0 189 | 190 | for batch_idx in range(num_batches): 191 | start_idx = batch_idx * BATCH_SIZE 192 | end_idx = (batch_idx+1) * BATCH_SIZE 193 | 194 | # Augment batched point clouds by rotation and jittering 195 | rotated_data = provider.rotate_point_cloud(current_data[start_idx:end_idx, :, :]) 196 | jittered_data = provider.jitter_point_cloud(rotated_data) 197 | feed_dict = {ops['pointclouds_pl']: jittered_data, 198 | ops['labels_pl']: current_label[start_idx:end_idx], 199 | ops['is_training_pl']: is_training,} 200 | summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'], 201 | ops['train_op'], ops['loss'], ops['pred']], feed_dict=feed_dict) 202 | train_writer.add_summary(summary, step) 203 | pred_val = np.argmax(pred_val, 1) 204 | correct = np.sum(pred_val == current_label[start_idx:end_idx]) 205 | total_correct += correct 206 | total_seen += BATCH_SIZE 207 | loss_sum += loss_val 208 | 209 | log_string('mean loss: %f' % (loss_sum / float(num_batches))) 210 | log_string('accuracy: %f' % (total_correct / float(total_seen))) 211 | 212 | def eval_one_epoch(sess, ops, test_writer): 213 | """ ops: dict mapping from string to tf ops """ 214 | is_training = False 215 | total_correct = 0 216 | total_seen = 0 217 | loss_sum = 0 218 | total_seen_class = [0 for _ in range(NUM_CLASSES)] 219 | total_correct_class = [0 for _ in range(NUM_CLASSES)] 220 | 221 | for fn in range(len(TEST_FILES)): 222 | log_string('----' + str(fn) + '-----') 223 | current_data, current_label = provider.loadDataFile(TEST_FILES[fn]) 224 | current_data = current_data[:,0:NUM_POINT,:] 225 | current_label = np.squeeze(current_label) 226 | 227 | file_size = current_data.shape[0] 228 | num_batches = file_size // BATCH_SIZE 229 | 230 | for batch_idx in range(num_batches): 231 | start_idx = batch_idx * BATCH_SIZE 232 | end_idx = (batch_idx+1) * BATCH_SIZE 233 | 234 | feed_dict = {ops['pointclouds_pl']: current_data[start_idx:end_idx, :, :], 235 | ops['labels_pl']: current_label[start_idx:end_idx], 236 | ops['is_training_pl']: is_training} 237 | summary, step, loss_val, pred_val = sess.run([ops['merged'], ops['step'], 238 | ops['loss'], ops['pred']], feed_dict=feed_dict) 239 | pred_val = np.argmax(pred_val, 1) 240 | correct = np.sum(pred_val == current_label[start_idx:end_idx]) 241 | total_correct += correct 242 | total_seen += BATCH_SIZE 243 | loss_sum += (loss_val*BATCH_SIZE) 244 | for i in range(start_idx, end_idx): 245 | l = current_label[i] 246 | total_seen_class[l] += 1 247 | total_correct_class[l] += (pred_val[i-start_idx] == l) 248 | 249 | log_string('eval mean 
loss: %f' % (loss_sum / float(total_seen)))
250 | log_string('eval accuracy: %f'% (total_correct / float(total_seen)))
251 | log_string('eval avg class acc: %f' % (np.mean(np.array(total_correct_class)/np.array(total_seen_class,dtype=np.float))))
252 | 
253 | if __name__ == "__main__":
254 | train()
255 | LOG_FOUT.close()
256 | 
--------------------------------------------------------------------------------
/train_pytorch.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import h5py
4 | import numpy as np
5 | import socket
6 | import importlib
7 | import matplotlib.pyplot as plt
8 | import os
9 | import sys
10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11 | sys.path.append(BASE_DIR)
12 | sys.path.append(os.path.join(BASE_DIR, 'models'))
13 | sys.path.append(os.path.join(BASE_DIR, 'utils'))
14 | 
15 | import provider
16 | import math
17 | import random
18 | import data_utils
19 | import time
20 | 
21 | import torch
22 | from torch import nn
23 | from torch.autograd import Variable
24 | from torch.utils.data import Dataset, DataLoader
25 | 
26 | 
27 | from utils.model import RandPointCNN
28 | from utils.util_funcs import knn_indices_func_gpu, knn_indices_func_cpu
29 | from utils.util_layers import Dense
30 | 
31 | 
32 | random.seed(0)
33 | dtype = torch.cuda.FloatTensor
34 | 
35 | 
36 | 
37 | # Load Hyperparameters
38 | parser = argparse.ArgumentParser()
39 | parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
40 | parser.add_argument('--model', default='pointnet_cls',
41 | help='Model name: pointnet_cls or pointnet_cls_basic [default: pointnet_cls]')
42 | parser.add_argument('--log_dir', default='log', help='Log dir [default: log]')
43 | parser.add_argument('--num_point', type=int, default=1024, help='Point Number [256/512/1024/2048] [default: 1024]')
44 | parser.add_argument('--max_epoch', type=int, default=2, help='Epoch to run [default: 2]')
45 | parser.add_argument('--batch_size', type=int, default=32, help='Batch Size during training [default: 32]')
46 | parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate [default: 0.001]')
47 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for momentum optimizer [default: 0.9]')
48 | parser.add_argument('--optimizer', default='adam', help='adam or momentum [default: adam]')
49 | parser.add_argument('--decay_step', type=int, default=200000, help='Decay step for lr decay [default: 200000]')
50 | parser.add_argument('--decay_rate', type=float, default=0.7, help='Decay rate for lr decay [default: 0.7]')
51 | FLAGS = parser.parse_args()
52 | 
53 | NUM_POINT = FLAGS.num_point
54 | LEARNING_RATE = FLAGS.learning_rate
55 | GPU_INDEX = FLAGS.gpu
56 | MOMENTUM = FLAGS.momentum
57 | 
58 | MAX_NUM_POINT = 2048
59 | 
60 | DECAY_STEP = FLAGS.decay_step
61 | DECAY_RATE = FLAGS.decay_rate
62 | BN_INIT_DECAY = 0.5
63 | BN_DECAY_DECAY_RATE = 0.5
64 | BN_DECAY_DECAY_STEP = float(DECAY_STEP)
65 | BN_DECAY_CLIP = 0.99
66 | 
67 | 
68 | LEARNING_RATE_MIN = 0.00001
69 | 
70 | NUM_CLASS = 40
71 | BATCH_SIZE = FLAGS.batch_size #32
72 | NUM_EPOCHS = FLAGS.max_epoch
73 | jitter = 0.01
74 | jitter_val = 0.01
75 | 
76 | rotation_range = [0, math.pi / 18, 0, 'g']
77 | rotation_range_val = [0, 0, 0, 'u']
78 | order = 'rxyz'
79 | 
80 | scaling_range = [0.05, 0.05, 0.05, 'g']
81 | scaling_range_val = [0, 0, 0, 'u']
82 | 
83 | 
84 | class modelnet40_dataset(Dataset):
85 | 
86 | def __init__(self, data, labels):
87 | self.data = data
88 | self.labels = labels
89 | 
90 | def __len__(self):
91 | return len(self.data)
92 | 
93 | def __getitem__(self, i):
94 | return self.data[i], self.labels[i]
95 | 
96 | 
97 | # RandPointCNN args: C_in, C_out, dims, N_neighbors (K), dilation (D), N_rep (P), r_indices_func
98 | # (a, b, c, d, e) == (C_in, C_out, N_neighbors, dilation, N_rep)
99 | # Abbreviated PointCNN constructor.
100 | AbbPointCNN = lambda a, b, c, d, e: RandPointCNN(a, b, 3, c, d, e, knn_indices_func_gpu)
101 | 
102 | 
103 | class Classifier(nn.Module):
104 | 
105 | def __init__(self):
106 | super(Classifier, self).__init__()
107 | 
108 | self.pcnn1 = AbbPointCNN(3, 32, 8, 1, -1)
109 | self.pcnn2 = nn.Sequential(
110 | AbbPointCNN(32, 64, 8, 2, -1),
111 | AbbPointCNN(64, 96, 8, 4, -1),
112 | AbbPointCNN(96, 128, 12, 4, 120),
113 | AbbPointCNN(128, 160, 12, 6, 120)
114 | )
115 | 
116 | self.fcn = nn.Sequential(
117 | Dense(160, 128),
118 | Dense(128, 64, drop_rate=0.5),
119 | Dense(64, NUM_CLASS, with_bn=False, activation=None)
120 | )
121 | 
122 | def forward(self, x):
123 | x = self.pcnn1(x)
124 | if False:
125 | print("Making graph...")
126 | k = make_dot(x[1])
127 | 
128 | print("Viewing...")
129 | k.view()
130 | print("DONE")
131 | 
132 | assert False
133 | x = self.pcnn2(x)[1] # grab features
134 | 
135 | logits = self.fcn(x)
136 | logits_mean = torch.mean(logits, dim=1)
137 | return logits_mean
138 | 
139 | 
140 | print("------Building model-------")
141 | model = Classifier().cuda()
142 | print("------Successfully Built model-------")
143 | 
144 | 
145 | optimizer = torch.optim.SGD(model.parameters(), lr = LEARNING_RATE, momentum = MOMENTUM)
146 | loss_fn = nn.CrossEntropyLoss()
147 | 
148 | global_step = 1
149 | 
150 | #model_save_dir = os.path.join(CURRENT_DIR, "models", "mnist2")
151 | #os.makedirs(model_save_dir, exist_ok = True)
152 | 
153 | TRAIN_FILES = provider.getDataFiles(os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/train_files.txt'))
154 | TEST_FILES = provider.getDataFiles(os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/test_files.txt'))
155 | 
156 | losses = []
157 | accuracies = []
158 | 
159 | '''
160 | if False:
161 | latest_model = sorted(os.listdir(model_save_dir))[-1]
162 | model.load_state_dict(torch.load(os.path.join(model_save_dir, latest_model)))
163 | '''
164 | 
165 | for epoch in range(1, NUM_EPOCHS+1):
166 | train_file_idxs = np.arange(0, len(TRAIN_FILES))
167 | np.random.shuffle(train_file_idxs)
168 | 
169 | for fn in range(len(TRAIN_FILES)):
170 | #log_string('----' + str(fn) + '-----')
171 | current_data, current_label = provider.loadDataFile(TRAIN_FILES[train_file_idxs[fn]])
172 | current_data = current_data[:, 0:NUM_POINT, :]
173 | 
174 | current_data, current_label, _ = provider.shuffle_data(current_data, np.squeeze(current_label))
175 | current_label = np.squeeze(current_label)
176 | 
177 | file_size = current_data.shape[0]
178 | num_batches = file_size // BATCH_SIZE
179 | 
180 | total_correct = 0
181 | total_seen = 0
182 | loss_sum = 0
183 | 
184 | if epoch > 1:
185 | LEARNING_RATE *= DECAY_RATE ** (global_step // DECAY_STEP)
186 | if LEARNING_RATE > LEARNING_RATE_MIN:
187 | print("NEW LEARNING RATE:", LEARNING_RATE)
188 | optimizer = torch.optim.SGD(model.parameters(), lr = LEARNING_RATE, momentum = MOMENTUM)
189 | 
190 | for batch_idx in range(num_batches):
191 | start_idx = batch_idx * BATCH_SIZE
192 | end_idx = (batch_idx + 1) * BATCH_SIZE
193 | 
194 | # Label
195 | label = current_label[start_idx:end_idx]
196 | label = torch.from_numpy(label).long()
197 | label = Variable(label,
requires_grad=False).cuda() 198 | # Augment batched point clouds by rotation and jittering 199 | rotated_data = provider.rotate_point_cloud(current_data[start_idx:end_idx, :, :]) 200 | jittered_data = provider.jitter_point_cloud(rotated_data) # P_Sampled 201 | P_sampled = jittered_data 202 | F_sampled = np.zeros((BATCH_SIZE, NUM_POINT, 0)) 203 | optimizer.zero_grad() 204 | 205 | t0 = time.time() 206 | P_sampled = torch.from_numpy(P_sampled).float() 207 | P_sampled = Variable(P_sampled, requires_grad=False).cuda() 208 | 209 | #F_sampled = torch.from_numpy(F_sampled) 210 | 211 | out = model((P_sampled, P_sampled)) 212 | loss = loss_fn(out, label) 213 | loss.backward() 214 | optimizer.step() 215 | print("epoch: "+str(epoch) + " loss: "+str(loss.data[0])) 216 | if global_step % 25 == 0: 217 | loss_v = loss.data[0] 218 | print("Loss:", loss_v) 219 | else: 220 | loss_v = 0 221 | global_step += 1 -------------------------------------------------------------------------------- /utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hxdengBerkeley/PointCNN.Pytorch/6ec6c291cf97923a84fb6ed8c82e98bf01e7e96d/utils/.DS_Store -------------------------------------------------------------------------------- /utils/data_prep_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 4 | sys.path.append(BASE_DIR) 5 | from plyfile import (PlyData, PlyElement, make2d, PlyParseError, PlyProperty) 6 | import numpy as np 7 | import h5py 8 | 9 | SAMPLING_BIN = os.path.join(BASE_DIR, 'third_party/mesh_sampling/build/pcsample') 10 | 11 | SAMPLING_POINT_NUM = 2048 12 | SAMPLING_LEAF_SIZE = 0.005 13 | 14 | MODELNET40_PATH = '../datasets/modelnet40' 15 | 16 | 17 | def export_ply(pc, filename): 18 | vertex = np.zeros(pc.shape[0], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')]) 19 | for i in range(pc.shape[0]): 20 | vertex[i] = (pc[i][0], pc[i][1], pc[i][2]) 21 | ply_out = PlyData([PlyElement.describe(vertex, 'vertex', comments=['vertices'])]) 22 | ply_out.write(filename) 23 | 24 | 25 | # Sample points on the obj shape 26 | def get_sampling_command(obj_filename, ply_filename): 27 | cmd = SAMPLING_BIN + ' ' + obj_filename 28 | cmd += ' ' + ply_filename 29 | cmd += ' -n_samples %d ' % SAMPLING_POINT_NUM 30 | cmd += ' -leaf_size %f ' % SAMPLING_LEAF_SIZE 31 | return cmd 32 | 33 | # -------------------------------------------------------------- 34 | # Following are the helper functions to load MODELNET40 shapes 35 | # -------------------------------------------------------------- 36 | 37 | # Read in the list of categories in MODELNET40 38 | def get_category_names(): 39 | shape_names_file = os.path.join(MODELNET40_PATH, 'shape_names.txt') 40 | shape_names = [line.rstrip() for line in open(shape_names_file)] 41 | return shape_names 42 | 43 | # Return all the filepaths for the shapes in MODELNET40 44 | def get_obj_filenames(): 45 | obj_filelist_file = os.path.join(MODELNET40_PATH, 'filelist.txt') 46 | obj_filenames = [os.path.join(MODELNET40_PATH, line.rstrip()) for line in open(obj_filelist_file)] 47 | print('Got %d obj files in modelnet40.' 
% len(obj_filenames))
48 | return obj_filenames
49 | 
50 | # Helper function to create the parent folder and all subdir folders if not exist
51 | def batch_mkdir(output_folder, subdir_list):
52 | if not os.path.exists(output_folder):
53 | os.mkdir(output_folder)
54 | for subdir in subdir_list:
55 | if not os.path.exists(os.path.join(output_folder, subdir)):
56 | os.mkdir(os.path.join(output_folder, subdir))
57 | 
58 | # ----------------------------------------------------------------
59 | # Following are the helper functions to save/load HDF5 files
60 | # ----------------------------------------------------------------
61 | 
62 | # Write numpy array data and label to h5_filename
63 | def save_h5_data_label_normal(h5_filename, data, label, normal,
64 | data_dtype='float32', label_dtype='uint8', normal_dtype='float32'):
65 | h5_fout = h5py.File(h5_filename, 'w')
66 | h5_fout.create_dataset(
67 | 'data', data=data,
68 | compression='gzip', compression_opts=4,
69 | dtype=data_dtype)
70 | h5_fout.create_dataset(
71 | 'normal', data=normal,
72 | compression='gzip', compression_opts=4,
73 | dtype=normal_dtype)
74 | h5_fout.create_dataset(
75 | 'label', data=label,
76 | compression='gzip', compression_opts=1,
77 | dtype=label_dtype)
78 | h5_fout.close()
79 | 
80 | 
81 | # Write numpy array data and label to h5_filename
82 | def save_h5(h5_filename, data, label, data_dtype='uint8', label_dtype='uint8'):
83 | h5_fout = h5py.File(h5_filename, 'w')
84 | h5_fout.create_dataset(
85 | 'data', data=data,
86 | compression='gzip', compression_opts=4,
87 | dtype=data_dtype)
88 | h5_fout.create_dataset(
89 | 'label', data=label,
90 | compression='gzip', compression_opts=1,
91 | dtype=label_dtype)
92 | h5_fout.close()
93 | 
94 | # Read numpy array data and label from h5_filename
95 | def load_h5_data_label_normal(h5_filename):
96 | f = h5py.File(h5_filename, 'r')
97 | data = f['data'][:]
98 | label = f['label'][:]
99 | normal = f['normal'][:]
100 | return (data, label, normal)
101 | 
102 | # Read numpy array data and label from h5_filename
103 | def load_h5_data_label_seg(h5_filename):
104 | f = h5py.File(h5_filename, 'r')
105 | data = f['data'][:]
106 | label = f['label'][:]
107 | seg = f['pid'][:]
108 | return (data, label, seg)
109 | 
110 | # Read numpy array data and label from h5_filename
111 | def load_h5(h5_filename):
112 | f = h5py.File(h5_filename, 'r')
113 | data = f['data'][:]
114 | label = f['label'][:]
115 | return (data, label)
116 | 
117 | # ----------------------------------------------------------------
118 | # Following are the helper functions to save/load PLY files
119 | # ----------------------------------------------------------------
120 | 
121 | # Load PLY file
122 | def load_ply_data(filename, point_num):
123 | plydata = PlyData.read(filename)
124 | pc = plydata['vertex'].data[:point_num]
125 | pc_array = np.array([[x, y, z] for x,y,z in pc])
126 | return pc_array
127 | 
128 | # Load PLY file
129 | def load_ply_normal(filename, point_num):
130 | plydata = PlyData.read(filename)
131 | pc = plydata['normal'].data[:point_num]
132 | pc_array = np.array([[x, y, z] for x,y,z in pc])
133 | return pc_array
134 | 
135 | # Make up rows for Nxk array
136 | # Input Pad is 'edge' or 'constant'
137 | def pad_arr_rows(arr, row, pad='edge'):
138 | assert(len(arr.shape) == 2)
139 | assert(arr.shape[0] <= row)
140 | assert(pad == 'edge' or pad == 'constant')
141 | if arr.shape[0] == row:
142 | return arr
143 | if pad == 'edge':
144 | return np.lib.pad(arr, ((0, row-arr.shape[0]), (0, 0)), 'edge')
145 | if pad == 'constant':
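# Appends zero-valued rows below the array; constant_values fills the new rows with 0.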
146 | return np.lib.pad(arr, ((0, row-arr.shape[0]), (0, 0)), 'constant', constant_values=(0, 0))
147 | 
148 | 
149 | 
--------------------------------------------------------------------------------
/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import os
6 | import h5py
7 | import plyfile
8 | import numpy as np
9 | from matplotlib import cm
10 | import scipy.spatial.distance as distance
11 | 
12 | 
13 | def save_ply(points, filename, colors=None, normals=None):
14 | vertex = np.array([tuple(p) for p in points], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
15 | n = len(vertex)
16 | desc = vertex.dtype.descr
17 | 
18 | if normals is not None:
19 | vertex_normal = np.array([tuple(n) for n in normals], dtype=[('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4')])
20 | assert len(vertex_normal) == n
21 | desc = desc + vertex_normal.dtype.descr
22 | 
23 | if colors is not None:
24 | vertex_color = np.array([tuple(c * 255) for c in colors],
25 | dtype=[('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])
26 | assert len(vertex_color) == n
27 | desc = desc + vertex_color.dtype.descr
28 | 
29 | vertex_all = np.empty(n, dtype=desc)
30 | 
31 | for prop in vertex.dtype.names:
32 | vertex_all[prop] = vertex[prop]
33 | 
34 | if normals is not None:
35 | for prop in vertex_normal.dtype.names:
36 | vertex_all[prop] = vertex_normal[prop]
37 | 
38 | if colors is not None:
39 | for prop in vertex_color.dtype.names:
40 | vertex_all[prop] = vertex_color[prop]
41 | 
42 | ply = plyfile.PlyData([plyfile.PlyElement.describe(vertex_all, 'vertex')], text=False)
43 | if not os.path.exists(os.path.dirname(filename)):
44 | os.makedirs(os.path.dirname(filename))
45 | ply.write(filename)
46 | 
47 | 
48 | def save_ply_property(points, property, property_max, filename, cmap_name='Set1'):
49 | point_num = points.shape[0]
50 | colors = np.full(points.shape, 0.5)
51 | cmap = cm.get_cmap(cmap_name)
52 | for point_idx in range(point_num):
53 | colors[point_idx] = cmap(property[point_idx] / property_max)[:3]
54 | save_ply(points, filename, colors)
55 | 
56 | 
57 | def save_ply_batch(points_batch, file_path, points_num=None):
58 | batch_size = points_batch.shape[0]
59 | if type(file_path) != list:
60 | basename = os.path.splitext(file_path)[0]
61 | ext = '.ply'
62 | for batch_idx in range(batch_size):
63 | point_num = points_batch.shape[1] if points_num is None else points_num[batch_idx]
64 | if type(file_path) == list:
65 | save_ply(points_batch[batch_idx][:point_num], file_path[batch_idx])
66 | else:
67 | save_ply(points_batch[batch_idx][:point_num], '%s_%04d%s' % (basename, batch_idx, ext))
68 | 
69 | 
70 | def save_ply_property_batch(points_batch, property_batch, file_path, points_num=None, property_max=None,
71 | cmap_name='Set1'):
72 | batch_size = points_batch.shape[0]
73 | if type(file_path) != list:
74 | basename = os.path.splitext(file_path)[0]
75 | ext = '.ply'
76 | property_max = np.max(property_batch) if property_max is None else property_max
77 | for batch_idx in range(batch_size):
78 | point_num = points_batch.shape[1] if points_num is None else points_num[batch_idx]
79 | if type(file_path) == list:
80 | save_ply_property(points_batch[batch_idx][:point_num], property_batch[batch_idx][:point_num],
81 | property_max, file_path[batch_idx], cmap_name)
82 | else:
83 | save_ply_property(points_batch[batch_idx][:point_num], property_batch[batch_idx][:point_num],
84 | property_max,
'%s_%04d%s' % (basename, batch_idx, ext), cmap_name)
85 | 
86 | 
87 | def save_ply_point_with_normal(data_sample, folder):
88 | for idx, sample in enumerate(data_sample):
89 | filename_pts = os.path.join(folder, '{:08d}.ply'.format(idx))
90 | save_ply(sample[..., :3], filename_pts, normals=sample[..., 3:])
91 | 
92 | 
93 | def grouped_shuffle(inputs):
94 | for idx in range(len(inputs) - 1):
95 | assert (len(inputs[idx]) == len(inputs[idx + 1]))
96 | 
97 | shuffle_indices = np.arange(inputs[0].shape[0])
98 | np.random.shuffle(shuffle_indices)
99 | outputs = []
100 | for idx in range(len(inputs)):
101 | outputs.append(inputs[idx][shuffle_indices, ...])
102 | return outputs
103 | 
104 | 
105 | def load_cls(filelist):
106 | points = []
107 | labels = []
108 | 
109 | folder = os.path.dirname(filelist)
110 | for line in open(filelist):
111 | filename = os.path.basename(line.rstrip())
112 | data = h5py.File(os.path.join(folder, filename), 'r')
113 | if 'normal' in data:
114 | points.append(np.concatenate([data['data'][...], data['normal'][...]], axis=-1).astype(np.float32))
115 | else:
116 | points.append(data['data'][...].astype(np.float32))
117 | labels.append(np.squeeze(data['label'][:]).astype(np.int32))
118 | return (np.concatenate(points, axis=0),
119 | np.concatenate(labels, axis=0))
120 | 
121 | 
122 | def load_cls_train_val(filelist, filelist_val):
123 | data_train, label_train = grouped_shuffle(load_cls(filelist))
124 | data_val, label_val = load_cls(filelist_val)
125 | return data_train, label_train, data_val, label_val
126 | 
127 | 
128 | def load_seg(filelist):
129 | points = []
130 | labels = []
131 | point_nums = []
132 | labels_seg = []
133 | 
134 | folder = os.path.dirname(filelist)
135 | for line in open(filelist):
136 | filename = os.path.basename(line.rstrip())
137 | data = h5py.File(os.path.join(folder, filename), 'r')
138 | points.append(data['data'][...].astype(np.float32))
139 | labels.append(data['label'][...].astype(np.int32))
140 | point_nums.append(data['data_num'][...].astype(np.int32))
141 | labels_seg.append(data['label_seg'][...].astype(np.int32))
142 | return (np.concatenate(points, axis=0),
143 | np.concatenate(labels, axis=0),
144 | np.concatenate(point_nums, axis=0),
145 | np.concatenate(labels_seg, axis=0))
146 | 
--------------------------------------------------------------------------------
/utils/eulerangles.py:
--------------------------------------------------------------------------------
1 | # emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
2 | # vi: set ft=python sts=4 ts=4 sw=4 et:
3 | ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
4 | #
5 | # See COPYING file distributed along with the NiBabel package for the
6 | # copyright and license terms.
7 | #
8 | ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
9 | ''' Module implementing Euler angle rotations and their conversions
10 | 
11 | See:
12 | 
13 | * http://en.wikipedia.org/wiki/Rotation_matrix
14 | * http://en.wikipedia.org/wiki/Euler_angles
15 | * http://mathworld.wolfram.com/EulerAngles.html
16 | 
17 | See also: *Representing Attitude with Euler Angles and Quaternions: A
18 | Reference* (2006) by James Diebel. A cached PDF link last found here:
19 | 
20 | http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.110.5134
21 | 
22 | Euler's rotation theorem tells us that any rotation in 3D can be
23 | described by 3 angles. Let's call the 3 angles the *Euler angle vector*
24 | and call the angles in the vector :math:`alpha`, :math:`beta` and
25 | :math:`gamma`. The vector is [ :math:`alpha`,
26 | :math:`beta`, :math:`gamma` ] and, in this description, the order of the
27 | parameters specifies the order in which the rotations occur (so the
28 | rotation corresponding to :math:`alpha` is applied first).
29 | 
30 | In order to specify the meaning of an *Euler angle vector* we need to
31 | specify the axes around which each of the rotations corresponding to
32 | :math:`alpha`, :math:`beta` and :math:`gamma` will occur.
33 | 
34 | There are therefore three axes for the rotations :math:`alpha`,
35 | :math:`beta` and :math:`gamma`; let's call them :math:`i`, :math:`j`,
36 | :math:`k`.
37 | 
38 | Let us express the rotation :math:`alpha` around axis `i` as a 3 by 3
39 | rotation matrix `A`. Similarly :math:`beta` around `j` becomes 3 x 3
40 | matrix `B` and :math:`gamma` around `k` becomes matrix `G`. Then the
41 | whole rotation expressed by the Euler angle vector [ :math:`alpha`,
42 | :math:`beta`, :math:`gamma` ], `R` is given by::
43 | 
44 | R = np.dot(G, np.dot(B, A))
45 | 
46 | See http://mathworld.wolfram.com/EulerAngles.html
47 | 
48 | The order :math:`G B A` expresses the fact that the rotations are
49 | performed in the order of the vector (:math:`alpha` around axis `i` =
50 | `A` first).
51 | 
52 | To convert a given Euler angle vector to a meaningful rotation, and a
53 | rotation matrix, we need to define:
54 | 
55 | * the axes `i`, `j`, `k`
56 | * whether a rotation matrix should be applied on the left of a vector to
57 | be transformed (vectors are column vectors) or on the right (vectors
58 | are row vectors).
59 | * whether the rotations move the axes as they are applied (intrinsic
60 | rotations) - compared to the situation where the axes stay fixed and the
61 | vectors move within the axis frame (extrinsic)
62 | * the handedness of the coordinate system
63 | 
64 | See: http://en.wikipedia.org/wiki/Rotation_matrix#Ambiguities
65 | 
66 | We are using the following conventions:
67 | 
68 | * axes `i`, `j`, `k` are the `z`, `y`, and `x` axes respectively. Thus
69 | an Euler angle vector [ :math:`alpha`, :math:`beta`, :math:`gamma` ]
70 | in our convention implies a :math:`alpha` radian rotation around the
71 | `z` axis, followed by a :math:`beta` rotation around the `y` axis,
72 | followed by a :math:`gamma` rotation around the `x` axis.
73 | * the rotation matrix applies on the left, to column vectors on the
74 | right, so if `R` is the rotation matrix, and `v` is a 3 x N matrix
75 | with N column vectors, the transformed vector set `vdash` is given by
76 | ``vdash = np.dot(R, v)``.
77 | * extrinsic rotations - the axes are fixed, and do not move with the
78 | rotations.
79 | * a right-handed coordinate system
80 | 
81 | The convention of rotation around ``z``, followed by rotation around
82 | ``y``, followed by rotation around ``x``, is known (confusingly) as
83 | "xyz", pitch-roll-yaw, Cardan angles, or Tait-Bryan angles.
84 | '''
85 | 
86 | import math
87 | 
88 | import sys
89 | if sys.version_info >= (3,0):
90 | from functools import reduce
91 | 
92 | import numpy as np
93 | 
94 | 
95 | _FLOAT_EPS_4 = np.finfo(float).eps * 4.0
96 | 
97 | 
98 | def euler2mat(z=0, y=0, x=0):
99 | ''' Return matrix for rotations around z, y and x axes
100 | 
101 | Uses the z, then y, then x convention above
102 | 
103 | Parameters
104 | ----------
105 | z : scalar
106 | Rotation angle in radians around z-axis (performed first)
107 | y : scalar
108 | Rotation angle in radians around y-axis
109 | x : scalar
110 | Rotation angle in radians around x-axis (performed last)
111 | 
112 | Returns
113 | -------
114 | M : array shape (3,3)
115 | Rotation matrix giving same rotation as for given angles
116 | 
117 | Examples
118 | --------
119 | >>> zrot = 1.3 # radians
120 | >>> yrot = -0.1
121 | >>> xrot = 0.2
122 | >>> M = euler2mat(zrot, yrot, xrot)
123 | >>> M.shape == (3, 3)
124 | True
125 | 
126 | The output rotation matrix is equal to the composition of the
127 | individual rotations
128 | 
129 | >>> M1 = euler2mat(zrot)
130 | >>> M2 = euler2mat(0, yrot)
131 | >>> M3 = euler2mat(0, 0, xrot)
132 | >>> composed_M = np.dot(M3, np.dot(M2, M1))
133 | >>> np.allclose(M, composed_M)
134 | True
135 | 
136 | You can specify rotations by named arguments
137 | 
138 | >>> np.all(M3 == euler2mat(x=xrot))
139 | True
140 | 
141 | When applying M to a vector, the vector should be a column vector to the
142 | right of M. If the right hand side is a 2D array rather than a
143 | vector, then each column of the 2D array represents a vector.
144 | 
145 | >>> vec = np.array([1, 0, 0]).reshape((3,1))
146 | >>> v2 = np.dot(M, vec)
147 | >>> vecs = np.array([[1, 0, 0],[0, 1, 0]]).T # giving 3x2 array
148 | >>> vecs2 = np.dot(M, vecs)
149 | 
150 | Rotations are counter-clockwise.
151 | 
152 | >>> zred = np.dot(euler2mat(z=np.pi/2), np.eye(3))
153 | >>> np.allclose(zred, [[0, -1, 0],[1, 0, 0], [0, 0, 1]])
154 | True
155 | >>> yred = np.dot(euler2mat(y=np.pi/2), np.eye(3))
156 | >>> np.allclose(yred, [[0, 0, 1],[0, 1, 0], [-1, 0, 0]])
157 | True
158 | >>> xred = np.dot(euler2mat(x=np.pi/2), np.eye(3))
159 | >>> np.allclose(xred, [[1, 0, 0],[0, 0, -1], [0, 1, 0]])
160 | True
161 | 
162 | Notes
163 | -----
164 | The direction of rotation is given by the right-hand rule (orient
165 | the thumb of the right hand along the axis around which the rotation
166 | occurs, with the end of the thumb at the positive end of the axis;
167 | curl your fingers; the direction your fingers curl is the direction
168 | of rotation). Therefore, the rotations are counterclockwise if
169 | looking along the axis of rotation from positive to negative.
170 | '''
171 | Ms = []
172 | if z:
173 | cosz = math.cos(z)
174 | sinz = math.sin(z)
175 | Ms.append(np.array(
176 | [[cosz, -sinz, 0],
177 | [sinz, cosz, 0],
178 | [0, 0, 1]]))
179 | if y:
180 | cosy = math.cos(y)
181 | siny = math.sin(y)
182 | Ms.append(np.array(
183 | [[cosy, 0, siny],
184 | [0, 1, 0],
185 | [-siny, 0, cosy]]))
186 | if x:
187 | cosx = math.cos(x)
188 | sinx = math.sin(x)
189 | Ms.append(np.array(
190 | [[1, 0, 0],
191 | [0, cosx, -sinx],
192 | [0, sinx, cosx]]))
193 | if Ms:
194 | return reduce(np.dot, Ms[::-1])
195 | return np.eye(3)
196 | 
197 | 
198 | def mat2euler(M, cy_thresh=None):
199 | ''' Discover Euler angle vector from 3x3 matrix
200 | 
201 | Uses the conventions above.
202 | 203 | Parameters 204 | ---------- 205 | M : array-like, shape (3,3) 206 | cy_thresh : None or scalar, optional 207 | threshold below which to give up on straightforward arctan for 208 | estimating x rotation. If None (default), estimate from 209 | precision of input. 210 | 211 | Returns 212 | ------- 213 | z : scalar 214 | y : scalar 215 | x : scalar 216 | Rotations in radians around z, y, x axes, respectively 217 | 218 | Notes 219 | ----- 220 | If there was no numerical error, the routine could be derived using 221 | Sympy expression for z then y then x rotation matrix, which is:: 222 | 223 | [ cos(y)*cos(z), -cos(y)*sin(z), sin(y)], 224 | [cos(x)*sin(z) + cos(z)*sin(x)*sin(y), cos(x)*cos(z) - sin(x)*sin(y)*sin(z), -cos(y)*sin(x)], 225 | [sin(x)*sin(z) - cos(x)*cos(z)*sin(y), cos(z)*sin(x) + cos(x)*sin(y)*sin(z), cos(x)*cos(y)] 226 | 227 | with the obvious derivations for z, y, and x 228 | 229 | z = atan2(-r12, r11) 230 | y = asin(r13) 231 | x = atan2(-r23, r33) 232 | 233 | Problems arise when cos(y) is close to zero, because both of:: 234 | 235 | z = atan2(cos(y)*sin(z), cos(y)*cos(z)) 236 | x = atan2(cos(y)*sin(x), cos(x)*cos(y)) 237 | 238 | will be close to atan2(0, 0), and highly unstable. 239 | 240 | The ``cy`` fix for numerical instability below is from: *Graphics 241 | Gems IV*, Paul Heckbert (editor), Academic Press, 1994, ISBN: 242 | 0123361559. Specifically it comes from EulerAngles.c by Ken 243 | Shoemake, and deals with the case where cos(y) is close to zero: 244 | 245 | See: http://www.graphicsgems.org/ 246 | 247 | The code appears to be licensed (from the website) as "can be used 248 | without restrictions". 249 | ''' 250 | M = np.asarray(M) 251 | if cy_thresh is None: 252 | try: 253 | cy_thresh = np.finfo(M.dtype).eps * 4 254 | except ValueError: 255 | cy_thresh = _FLOAT_EPS_4 256 | r11, r12, r13, r21, r22, r23, r31, r32, r33 = M.flat 257 | # cy: sqrt((cos(y)*cos(z))**2 + (cos(x)*cos(y))**2) 258 | cy = math.sqrt(r33*r33 + r23*r23) 259 | if cy > cy_thresh: # cos(y) not close to zero, standard form 260 | z = math.atan2(-r12, r11) # atan2(cos(y)*sin(z), cos(y)*cos(z)) 261 | y = math.atan2(r13, cy) # atan2(sin(y), cy) 262 | x = math.atan2(-r23, r33) # atan2(cos(y)*sin(x), cos(x)*cos(y)) 263 | else: # cos(y) (close to) zero, so x -> 0.0 (see above) 264 | # so r21 -> sin(z), r22 -> cos(z) and 265 | z = math.atan2(r21, r22) 266 | y = math.atan2(r13, cy) # atan2(sin(y), cy) 267 | x = 0.0 268 | return z, y, x 269 | 270 | 271 | def euler2quat(z=0, y=0, x=0): 272 | ''' Return quaternion corresponding to these Euler angles 273 | 274 | Uses the z, then y, then x convention above 275 | 276 | Parameters 277 | ---------- 278 | z : scalar 279 | Rotation angle in radians around z-axis (performed first) 280 | y : scalar 281 | Rotation angle in radians around y-axis 282 | x : scalar 283 | Rotation angle in radians around x-axis (performed last) 284 | 285 | Returns 286 | ------- 287 | quat : array shape (4,) 288 | Quaternion in w, x, y z (real, then vector) format 289 | 290 | Notes 291 | ----- 292 | We can derive this formula in Sympy using: 293 | 294 | 1. Formula giving quaternion corresponding to rotation of theta radians 295 | about arbitrary axis: 296 | http://mathworld.wolfram.com/EulerParameters.html 297 | 2. Generated formulae from 1.) for quaternions corresponding to 298 | theta radians rotations about ``x, y, z`` axes 299 | 3. Apply quaternion multiplication formula - 300 | http://en.wikipedia.org/wiki/Quaternions#Hamilton_product - to 301 | formulae from 2.) 
to give formula for combined rotations. 302 | ''' 303 | z = z/2.0 304 | y = y/2.0 305 | x = x/2.0 306 | cz = math.cos(z) 307 | sz = math.sin(z) 308 | cy = math.cos(y) 309 | sy = math.sin(y) 310 | cx = math.cos(x) 311 | sx = math.sin(x) 312 | return np.array([ 313 | cx*cy*cz - sx*sy*sz, 314 | cx*sy*sz + cy*cz*sx, 315 | cx*cz*sy - sx*cy*sz, 316 | cx*cy*sz + sx*cz*sy]) 317 | 318 | 319 | def quat2euler(q): 320 | ''' Return Euler angles corresponding to quaternion `q` 321 | 322 | Parameters 323 | ---------- 324 | q : 4 element sequence 325 | w, x, y, z of quaternion 326 | 327 | Returns 328 | ------- 329 | z : scalar 330 | Rotation angle in radians around z-axis (performed first) 331 | y : scalar 332 | Rotation angle in radians around y-axis 333 | x : scalar 334 | Rotation angle in radians around x-axis (performed last) 335 | 336 | Notes 337 | ----- 338 | It's possible to reduce the amount of calculation a little, by 339 | combining parts of the ``quat2mat`` and ``mat2euler`` functions, but 340 | the reduction in computation is small, and the code repetition is 341 | large. 342 | ''' 343 | # delayed import to avoid cyclic dependencies 344 | import nibabel.quaternions as nq 345 | return mat2euler(nq.quat2mat(q)) 346 | 347 | 348 | def euler2angle_axis(z=0, y=0, x=0): 349 | ''' Return angle, axis corresponding to these Euler angles 350 | 351 | Uses the z, then y, then x convention above 352 | 353 | Parameters 354 | ---------- 355 | z : scalar 356 | Rotation angle in radians around z-axis (performed first) 357 | y : scalar 358 | Rotation angle in radians around y-axis 359 | x : scalar 360 | Rotation angle in radians around x-axis (performed last) 361 | 362 | Returns 363 | ------- 364 | theta : scalar 365 | angle of rotation 366 | vector : array shape (3,) 367 | axis around which rotation occurs 368 | 369 | Examples 370 | -------- 371 | >>> theta, vec = euler2angle_axis(0, 1.5, 0) 372 | >>> print(theta) 373 | 1.5 374 | >>> np.allclose(vec, [0, 1, 0]) 375 | True 376 | ''' 377 | # delayed import to avoid cyclic dependencies 378 | import nibabel.quaternions as nq 379 | return nq.quat2angle_axis(euler2quat(z, y, x)) 380 | 381 | 382 | def angle_axis2euler(theta, vector, is_normalized=False): 383 | ''' Convert angle, axis pair to Euler angles 384 | 385 | Parameters 386 | ---------- 387 | theta : scalar 388 | angle of rotation 389 | vector : 3 element sequence 390 | vector specifying axis for rotation. 391 | is_normalized : bool, optional 392 | True if vector is already normalized (has norm of 1). Default 393 | False 394 | 395 | Returns 396 | ------- 397 | z : scalar 398 | y : scalar 399 | x : scalar 400 | Rotations in radians around z, y, x axes, respectively 401 | 402 | Examples 403 | -------- 404 | >>> z, y, x = angle_axis2euler(0, [1, 0, 0]) 405 | >>> np.allclose((z, y, x), 0) 406 | True 407 | 408 | Notes 409 | ----- 410 | It's possible to reduce the amount of calculation a little, by 411 | combining parts of the ``angle_axis2mat`` and ``mat2euler`` 412 | functions, but the reduction in computation is small, and the code 413 | repetition is large. 414 | ''' 415 | # delayed import to avoid cyclic dependencies 416 | import nibabel.quaternions as nq 417 | M = nq.angle_axis2mat(theta, vector, is_normalized) 418 | return mat2euler(M) 419 | -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Austin J. 
Garrett
3 | 
4 | PyTorch implementation of the PointCNN paper, as specified in:
5 | https://arxiv.org/pdf/1801.07791.pdf
6 | Original paper by: Yangyan Li, Rui Bu, Mingchao Sun, Baoquan Chen
7 | """
8 | 
9 | # External Modules
10 | import torch
11 | import torch.nn as nn
12 | from torch import FloatTensor
13 | import numpy as np
14 | from typing import Tuple, Callable, Optional
15 | 
16 | # Internal Modules
17 | from util_funcs import UFloatTensor, ULongTensor
18 | from util_layers import Conv, SepConv, Dense, EndChannels
19 | 
20 | class XConv(nn.Module):
21 | """ Convolution over a single point and its neighbors. """
22 | 
23 | def __init__(self, C_in : int, C_out : int, dims : int, K : int,
24 | P : int, C_mid : int, depth_multiplier : int) -> None:
25 | """
26 | :param C_in: Input dimension of the points' features.
27 | :param C_out: Output dimension of the representative point features.
28 | :param dims: Spatial dimensionality of points.
29 | :param K: Number of neighbors to convolve over.
30 | :param P: Number of representative points.
31 | :param C_mid: Dimensionality of lifted point features.
32 | :param depth_multiplier: Depth multiplier for internal depthwise separable convolution.
33 | """
34 | super(XConv, self).__init__()
35 | 
36 | if __debug__:
37 | # Only needed for assertions.
38 | self.C_in = C_in
39 | self.C_mid = C_mid
40 | self.dims = dims
41 | self.K = K
42 | 
43 | self.P = P
44 | 
45 | # Additional processing layers
46 | # self.pts_layernorm = LayerNorm(2, momentum = 0.9)
47 | 
48 | # Main dense linear layers
49 | self.dense1 = Dense(dims, C_mid)
50 | self.dense2 = Dense(C_mid, C_mid)
51 | 
52 | # Layers to generate X
53 | self.x_trans = nn.Sequential(
54 | EndChannels(Conv(
55 | in_channels = dims,
56 | out_channels = K*K,
57 | kernel_size = (1, K),
58 | with_bn = False
59 | )),
60 | Dense(K*K, K*K, with_bn = False),
61 | Dense(K*K, K*K, with_bn = False, activation = None)
62 | )
63 | 
64 | self.end_conv = EndChannels(SepConv(
65 | in_channels = C_mid + C_in,
66 | out_channels = C_out,
67 | kernel_size = (1, K),
68 | depth_multiplier = depth_multiplier
69 | )).cuda()
70 | 
71 | def forward(self, x : Tuple[UFloatTensor, # (N, P, dims)
72 | UFloatTensor, # (N, P, K, dims)
73 | Optional[UFloatTensor]] # (N, P, K, C_in)
74 | ) -> UFloatTensor: # (N, P, C_out)
75 | """
76 | Applies XConv to the input data.
77 | :param x: (rep_pt, pts, fts) where
78 | - rep_pt: Representative point.
79 | - pts: Regional point cloud of the K neighbors around rep_pt.
80 | - fts: Regional features such that fts[:,p_idx,:] is the feature
81 | associated with pts[:,p_idx,:].
82 | :return: Features aggregated into point rep_pt.
83 | """
84 | rep_pt, pts, fts = x
85 | 
86 | if fts is not None:
87 | assert(rep_pt.size()[0] == pts.size()[0] == fts.size()[0]) # Check N is equal.
88 | assert(rep_pt.size()[1] == pts.size()[1] == fts.size()[1]) # Check P is equal.
89 | assert(pts.size()[2] == fts.size()[2] == self.K) # Check K is equal.
90 | assert(fts.size()[3] == self.C_in) # Check C_in is equal.
91 | else:
92 | assert(rep_pt.size()[0] == pts.size()[0]) # Check N is equal.
93 | assert(rep_pt.size()[1] == pts.size()[1]) # Check P is equal.
94 | assert(pts.size()[2] == self.K) # Check K is equal.
95 | assert(rep_pt.size()[2] == pts.size()[3] == self.dims) # Check dims is equal.
96 | 
97 | N = len(pts)
98 | P = rep_pt.size()[1]
99 | p_center = torch.unsqueeze(rep_pt, dim = 2) # (N, P, 1, dims)
100 | 
101 | 
102 | # Move pts to local coordinate system of rep_pt.
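# p_center has shape (N, P, 1, dims), so the subtraction below broadcasts
# across the K neighbors of each representative point.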
103 | pts_local = pts - p_center # (N, P, K, dims)
104 | # pts_local = self.pts_layernorm(pts - p_center)
105 | 
106 | # Individually lift each point into C_mid space.
107 | fts_lifted0 = self.dense1(pts_local)
108 | fts_lifted = self.dense2(fts_lifted0) # (N, P, K, C_mid)
109 | 
110 | if fts is None:
111 | fts_cat = fts_lifted
112 | else:
113 | fts_cat = torch.cat((fts_lifted, fts), -1) # (N, P, K, C_mid + C_in)
114 | 
115 | # Learn the (N, K, K) X-transformation matrix.
116 | X_shape = (N, P, self.K, self.K)
117 | X = self.x_trans(pts_local)
118 | X = X.view(*X_shape)
119 | 
120 | # Weight and permute fts_cat with the learned X.
121 | fts_X = torch.matmul(X, fts_cat)
122 | fts_p = self.end_conv(fts_X).squeeze(dim = 2)
123 | return fts_p
124 | 
125 | class PointCNN(nn.Module):
126 | """ Pointwise convolutional model. """
127 | 
128 | def __init__(self, C_in : int, C_out : int, dims : int, K : int, D : int, P : int,
129 | r_indices_func : Callable[[UFloatTensor, # (N, P, dims)
130 | UFloatTensor, # (N, x, dims)
131 | int, int],
132 | ULongTensor] # (N, P, K)
133 | ) -> None:
134 | """
135 | :param C_in: Input dimension of the points' features.
136 | :param C_out: Output dimension of the representative point features.
137 | :param dims: Spatial dimensionality of points.
138 | :param K: Number of neighbors to convolve over.
139 | :param D: "Spread" of neighboring points.
140 | :param P: Number of representative points.
141 | :param r_indices_func: Selector function of the type,
142 | INPUTS
143 | rep_pts : Representative points.
144 | pts : Point cloud.
145 | K : Number of points for each region.
146 | D : "Spread" of neighboring points.
147 | 
148 | OUTPUT
149 | pts_idx : Array of indices into pts such that pts[pts_idx] is the set
150 | of points in the "region" around rep_pt.
151 | """
152 | super(PointCNN, self).__init__()
153 | 
154 | C_mid = C_out // 2 if C_in == 0 else C_out // 4
155 | 
156 | if C_in == 0:
157 | depth_multiplier = 1
158 | else:
159 | depth_multiplier = min(int(np.ceil(C_out / C_in)), 4)
160 | 
161 | self.r_indices_func = lambda rep_pts, pts: r_indices_func(rep_pts, pts, K, D)
162 | self.dense = Dense(C_in, C_out // 2) if C_in != 0 else None
163 | self.x_conv = XConv(C_out // 2 if C_in != 0 else C_in, C_out, dims, K, P, C_mid, depth_multiplier)
164 | self.D = D
165 | 
166 | def select_region(self, pts : UFloatTensor, # (N, x, dims)
167 | pts_idx : ULongTensor # (N, P, K)
168 | ) -> UFloatTensor: # (N, P, K, dims)
169 | """
170 | Selects neighborhood points based on output of r_indices_func.
171 | :param pts: Point cloud to select regional points from.
172 | :param pts_idx: Indices of points in region to be selected.
173 | :return: Local neighborhoods around each representative point.
174 | """
175 | regions = torch.stack([
176 | pts[n][idx,:] for n, idx in enumerate(torch.unbind(pts_idx, dim = 0))
177 | ], dim = 0)
178 | return regions
179 | 
180 | def forward(self, x : Tuple[FloatTensor, # (N, P, dims)
181 | FloatTensor, # (N, x, dims)
182 | FloatTensor] # (N, x, C_in)
183 | ) -> FloatTensor: # (N, P, C_out)
184 | """
185 | Given a set of representative points, a point cloud, and its
186 | corresponding features, return a new set of representative points with
187 | features projected from the point cloud.
188 | :param x: (rep_pts, pts, fts) where
189 | - rep_pts: Representative points.
190 | - pts: Full point cloud such that fts[:,p_idx,:] is the
191 | feature associated with pts[:,p_idx,:].
192 | - fts: Features of the point cloud, aligned with pts so that
193 | fts[:,p_idx,:] describes pts[:,p_idx,:].
194 | :return: Features aggregated to rep_pts.
195 | """
196 | rep_pts, pts, fts = x
197 | fts = self.dense(fts) if fts is not None else fts
198 | 
199 | # This step takes ~97% of the time. Prime target for optimization: KNN on GPU.
200 | pts_idx = self.r_indices_func(rep_pts.cpu(), pts.cpu()).cuda()
201 | # -------------------------------------------------------------------------- #
202 | 
203 | pts_regional = self.select_region(pts, pts_idx)
204 | fts_regional = self.select_region(fts, pts_idx) if fts is not None else fts
205 | fts_p = self.x_conv((rep_pts, pts_regional, fts_regional))
206 | 
207 | return fts_p
208 | 
209 | class RandPointCNN(nn.Module):
210 | """ PointCNN with randomly subsampled representative points. """
211 | 
212 | def __init__(self, C_in : int, C_out : int, dims : int, K : int, D : int, P : int,
213 | r_indices_func : Callable[[UFloatTensor, # (N, P, dims)
214 | UFloatTensor, # (N, x, dims)
215 | int, int],
216 | ULongTensor] # (N, P, K)
217 | ) -> None:
218 | """ See documentation for PointCNN. """
219 | super(RandPointCNN, self).__init__()
220 | self.pointcnn = PointCNN(C_in, C_out, dims, K, D, P, r_indices_func)
221 | self.P = P
222 | 
223 | def forward(self, x : Tuple[UFloatTensor, # (N, x, dims)
224 | UFloatTensor] # (N, x, dims)
225 | ) -> Tuple[UFloatTensor, # (N, P, dims)
226 | UFloatTensor]: # (N, P, C_out)
227 | """
228 | Given a point cloud, and its corresponding features, return a new set
229 | of randomly-sampled representative points with features projected from
230 | the point cloud.
231 | :param x: (pts, fts) where
232 | - pts: Point cloud such that fts[:,p_idx,:] is the
233 | feature associated with pts[:,p_idx,:].
234 | - fts: Features of the point cloud, aligned with pts so that
235 | fts[:,p_idx,:] describes pts[:,p_idx,:].
236 | :return: Randomly subsampled points and their features.
237 | """
238 | pts, fts = x
239 | if 0 < self.P < pts.size()[1]:
240 | # Select random set of indices of subsampled points.
241 | idx = np.random.choice(pts.size()[1], self.P, replace = False).tolist()
242 | rep_pts = pts[:,idx,:]
243 | else:
244 | # All input points are representative points.
245 | rep_pts = pts
246 | rep_pts_fts = self.pointcnn((rep_pts, pts, fts))
247 | return rep_pts, rep_pts_fts
248 | 
--------------------------------------------------------------------------------
/utils/pc_util.py:
--------------------------------------------------------------------------------
1 | """ Utility functions for processing point clouds.
2 | 
3 | Author: Charles R. Qi, Hao Su
--------------------------------------------------------------------------------
/utils/pc_util.py:
--------------------------------------------------------------------------------
1 | """ Utility functions for processing point clouds.
2 | 
3 | Author: Charles R. Qi, Hao Su
4 | Date: November 2016
5 | """
6 | 
7 | import os
8 | import sys
9 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
10 | sys.path.append(BASE_DIR)
11 | 
12 | # Draw point cloud
13 | from eulerangles import euler2mat
14 | 
15 | # Point cloud IO
16 | import numpy as np
17 | from plyfile import PlyData, PlyElement
18 | 
19 | 
20 | # ----------------------------------------
21 | # Point Cloud/Volume Conversions
22 | # ----------------------------------------
23 | 
24 | def point_cloud_to_volume_batch(point_clouds, vsize=12, radius=1.0, flatten=True):
25 |     """ Input is BxNx3 batch of point cloud
26 |         Output is Bx(vsize^3)
27 |     """
28 |     vol_list = []
29 |     for b in range(point_clouds.shape[0]):
30 |         vol = point_cloud_to_volume(np.squeeze(point_clouds[b,:,:]), vsize, radius)
31 |         if flatten:
32 |             vol_list.append(vol.flatten())
33 |         else:
34 |             vol_list.append(np.expand_dims(np.expand_dims(vol, -1), 0))
35 |     if flatten:
36 |         return np.vstack(vol_list)
37 |     else:
38 |         return np.concatenate(vol_list, 0)
39 | 
40 | 
41 | def point_cloud_to_volume(points, vsize, radius=1.0):
42 |     """ input is Nx3 points.
43 |         output is vsize*vsize*vsize
44 |         assumes points are in range [-radius, radius]
45 |     """
46 |     vol = np.zeros((vsize,vsize,vsize))
47 |     voxel = 2*radius/float(vsize)
48 |     locations = (points + radius)/voxel
49 |     locations = locations.astype(int)
50 |     vol[locations[:,0],locations[:,1],locations[:,2]] = 1.0
51 |     return vol
52 | 
53 | #a = np.zeros((16,1024,3))
54 | #print(point_cloud_to_volume_batch(a, 12, 1.0, False).shape)
55 | 
56 | def volume_to_point_cloud(vol):
57 |     """ vol is occupancy grid (value = 0 or 1) of size vsize*vsize*vsize
58 |         return Nx3 numpy array.
59 |     """
60 |     vsize = vol.shape[0]
61 |     assert(vol.shape[1] == vsize and vol.shape[2] == vsize)
62 |     points = []
63 |     for a in range(vsize):
64 |         for b in range(vsize):
65 |             for c in range(vsize):
66 |                 if vol[a,b,c] == 1:
67 |                     points.append(np.array([a,b,c]))
68 |     if len(points) == 0:
69 |         return np.zeros((0,3))
70 |     points = np.vstack(points)
71 |     return points
72 | 
73 | # ----------------------------------------
74 | # Point cloud IO
75 | # ----------------------------------------
76 | 
77 | def read_ply(filename):
78 |     """ read XYZ point cloud from filename PLY file """
79 |     plydata = PlyData.read(filename)
80 |     pc = plydata['vertex'].data
81 |     pc_array = np.array([[x, y, z] for x,y,z in pc])
82 |     return pc_array
83 | 
84 | 
85 | def write_ply(points, filename, text=True):
86 |     """ input: Nx3, write points to filename as PLY format. """
87 |     points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])]
88 |     vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')])
89 |     el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
90 |     PlyData([el], text=text).write(filename)
91 | 
92 | 
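# Worked example (a sketch; the helper name and sizes are illustrative, not
# part of the original API): voxelization is lossy, so round-tripping a cloud
# through the volume representation recovers occupied voxel coordinates, not
# the original points.
def _volume_roundtrip_demo():
    pts = np.random.uniform(-1.0, 1.0, (128, 3))            # in [-radius, radius)
    vol = point_cloud_to_volume(pts, vsize=12, radius=1.0)  # (12, 12, 12) occupancy grid
    back = volume_to_point_cloud(vol)                       # Mx3 voxel coords, M <= 128
    print(vol.sum(), back.shape)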
93 | # ----------------------------------------
94 | # Simple Point cloud and Volume Renderers
95 | # ----------------------------------------
96 | 
97 | def draw_point_cloud(input_points, canvasSize=500, space=200, diameter=25,
98 |                      xrot=0, yrot=0, zrot=0, switch_xyz=[0,1,2], normalize=True):
99 |     """ Render point cloud to a grayscale image.
100 |         Input:
101 |             points: Nx3 numpy array (+y is up direction)
102 |         Output:
103 |             gray image as numpy array of size canvasSize x canvasSize
104 |     """
105 |     image = np.zeros((canvasSize, canvasSize))
106 |     if input_points is None or input_points.shape[0] == 0:
107 |         return image
108 | 
109 |     points = input_points[:, switch_xyz]
110 |     M = euler2mat(zrot, yrot, xrot)
111 |     points = (np.dot(M, points.transpose())).transpose()
112 | 
113 |     # Normalize the point cloud
114 |     # We normalize scale to fit points in a unit sphere
115 |     if normalize:
116 |         centroid = np.mean(points, axis=0)
117 |         points -= centroid
118 |         furthest_distance = np.max(np.sqrt(np.sum(abs(points)**2,axis=-1)))
119 |         points /= furthest_distance
120 | 
121 |     # Pre-compute the Gaussian disk
122 |     radius = (diameter-1)/2.0
123 |     disk = np.zeros((diameter, diameter))
124 |     for i in range(diameter):
125 |         for j in range(diameter):
126 |             if (i - radius) * (i-radius) + (j-radius) * (j-radius) <= radius * radius:
127 |                 disk[i, j] = np.exp((-(i-radius)**2 - (j-radius)**2)/(radius**2))
128 |     mask = np.argwhere(disk > 0)
129 |     dx = mask[:, 0]
130 |     dy = mask[:, 1]
131 |     dv = disk[disk > 0]
132 | 
133 |     # Order points by z-buffer
134 |     zorder = np.argsort(points[:, 2])
135 |     points = points[zorder, :]
136 |     points[:, 2] = (points[:, 2] - np.min(points[:, 2])) / (np.max(points[:, 2]) - np.min(points[:, 2]))
137 |     max_depth = np.max(points[:, 2])
138 | 
139 |     for i in range(points.shape[0]):
140 |         j = points.shape[0] - i - 1
141 |         x = points[j, 0]
142 |         y = points[j, 1]
143 |         xc = canvasSize/2 + (x*space)
144 |         yc = canvasSize/2 + (y*space)
145 |         xc = int(np.round(xc))
146 |         yc = int(np.round(yc))
147 | 
148 |         px = dx + xc
149 |         py = dy + yc
150 | 
151 |         image[px, py] = image[px, py] * 0.7 + dv * (max_depth - points[j, 2]) * 0.3
152 | 
153 |     image = image / np.max(image)
154 |     return image
155 | 
156 | def point_cloud_three_views(points):
157 |     """ input points Nx3 numpy array (+y is up direction).
158 |         return a numpy array gray image of size 500x1500.
""" 159 | # +y is up direction 160 | # xrot is azimuth 161 | # yrot is in-plane 162 | # zrot is elevation 163 | img1 = draw_point_cloud(points, zrot=110/180.0*np.pi, xrot=45/180.0*np.pi, yrot=0/180.0*np.pi) 164 | img2 = draw_point_cloud(points, zrot=70/180.0*np.pi, xrot=135/180.0*np.pi, yrot=0/180.0*np.pi) 165 | img3 = draw_point_cloud(points, zrot=180.0/180.0*np.pi, xrot=90/180.0*np.pi, yrot=0/180.0*np.pi) 166 | image_large = np.concatenate([img1, img2, img3], 1) 167 | return image_large 168 | 169 | 170 | from PIL import Image 171 | def point_cloud_three_views_demo(): 172 | """ Demo for draw_point_cloud function """ 173 | points = read_ply('../third_party/mesh_sampling/piano.ply') 174 | im_array = point_cloud_three_views(points) 175 | img = Image.fromarray(np.uint8(im_array*255.0)) 176 | img.save('piano.jpg') 177 | 178 | if __name__=="__main__": 179 | point_cloud_three_views_demo() 180 | 181 | 182 | import matplotlib.pyplot as plt 183 | def pyplot_draw_point_cloud(points, output_filename): 184 | """ points is a Nx3 numpy array """ 185 | fig = plt.figure() 186 | ax = fig.add_subplot(111, projection='3d') 187 | ax.scatter(points[:,0], points[:,1], points[:,2]) 188 | ax.set_xlabel('x') 189 | ax.set_ylabel('y') 190 | ax.set_zlabel('z') 191 | #savefig(output_filename) 192 | 193 | def pyplot_draw_volume(vol, output_filename): 194 | """ vol is of size vsize*vsize*vsize 195 | output an image to output_filename 196 | """ 197 | points = volume_to_point_cloud(vol) 198 | pyplot_draw_point_cloud(points, output_filename) 199 | -------------------------------------------------------------------------------- /utils/tf_util.py: -------------------------------------------------------------------------------- 1 | """ Wrapper functions for TensorFlow layers. 2 | 3 | Author: Charles R. Qi 4 | Date: November 2016 5 | """ 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | def _variable_on_cpu(name, shape, initializer, use_fp16=False): 11 | """Helper to create a Variable stored on CPU memory. 12 | Args: 13 | name: name of the variable 14 | shape: list of ints 15 | initializer: initializer for Variable 16 | Returns: 17 | Variable Tensor 18 | """ 19 | with tf.device('/cpu:0'): 20 | dtype = tf.float16 if use_fp16 else tf.float32 21 | var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype) 22 | return var 23 | 24 | def _variable_with_weight_decay(name, shape, stddev, wd, use_xavier=True): 25 | """Helper to create an initialized Variable with weight decay. 26 | 27 | Note that the Variable is initialized with a truncated normal distribution. 28 | A weight decay is added only if one is specified. 29 | 30 | Args: 31 | name: name of the variable 32 | shape: list of ints 33 | stddev: standard deviation of a truncated Gaussian 34 | wd: add L2Loss weight decay multiplied by this float. If None, weight 35 | decay is not added for this Variable. 
36 | use_xavier: bool, whether to use xavier initializer 37 | 38 | Returns: 39 | Variable Tensor 40 | """ 41 | if use_xavier: 42 | initializer = tf.contrib.layers.xavier_initializer() 43 | else: 44 | initializer = tf.truncated_normal_initializer(stddev=stddev) 45 | var = _variable_on_cpu(name, shape, initializer) 46 | if wd is not None: 47 | weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss') 48 | tf.add_to_collection('losses', weight_decay) 49 | return var 50 | 51 | 52 | def conv1d(inputs, 53 | num_output_channels, 54 | kernel_size, 55 | scope, 56 | stride=1, 57 | padding='SAME', 58 | use_xavier=True, 59 | stddev=1e-3, 60 | weight_decay=0.0, 61 | activation_fn=tf.nn.relu, 62 | bn=False, 63 | bn_decay=None, 64 | is_training=None): 65 | """ 1D convolution with non-linear operation. 66 | 67 | Args: 68 | inputs: 3-D tensor variable BxLxC 69 | num_output_channels: int 70 | kernel_size: int 71 | scope: string 72 | stride: int 73 | padding: 'SAME' or 'VALID' 74 | use_xavier: bool, use xavier_initializer if true 75 | stddev: float, stddev for truncated_normal init 76 | weight_decay: float 77 | activation_fn: function 78 | bn: bool, whether to use batch norm 79 | bn_decay: float or float tensor variable in [0,1] 80 | is_training: bool Tensor variable 81 | 82 | Returns: 83 | Variable tensor 84 | """ 85 | with tf.variable_scope(scope) as sc: 86 | num_in_channels = inputs.get_shape()[-1].value 87 | kernel_shape = [kernel_size, 88 | num_in_channels, num_output_channels] 89 | kernel = _variable_with_weight_decay('weights', 90 | shape=kernel_shape, 91 | use_xavier=use_xavier, 92 | stddev=stddev, 93 | wd=weight_decay) 94 | outputs = tf.nn.conv1d(inputs, kernel, 95 | stride=stride, 96 | padding=padding) 97 | biases = _variable_on_cpu('biases', [num_output_channels], 98 | tf.constant_initializer(0.0)) 99 | outputs = tf.nn.bias_add(outputs, biases) 100 | 101 | if bn: 102 | outputs = batch_norm_for_conv1d(outputs, is_training, 103 | bn_decay=bn_decay, scope='bn') 104 | 105 | if activation_fn is not None: 106 | outputs = activation_fn(outputs) 107 | return outputs 108 | 109 | 110 | 111 | 112 | def conv2d(inputs, 113 | num_output_channels, 114 | kernel_size, 115 | scope, 116 | stride=[1, 1], 117 | padding='SAME', 118 | use_xavier=True, 119 | stddev=1e-3, 120 | weight_decay=0.0, 121 | activation_fn=tf.nn.relu, 122 | bn=False, 123 | bn_decay=None, 124 | is_training=None): 125 | """ 2D convolution with non-linear operation. 
126 | 127 | Args: 128 | inputs: 4-D tensor variable BxHxWxC 129 | num_output_channels: int 130 | kernel_size: a list of 2 ints 131 | scope: string 132 | stride: a list of 2 ints 133 | padding: 'SAME' or 'VALID' 134 | use_xavier: bool, use xavier_initializer if true 135 | stddev: float, stddev for truncated_normal init 136 | weight_decay: float 137 | activation_fn: function 138 | bn: bool, whether to use batch norm 139 | bn_decay: float or float tensor variable in [0,1] 140 | is_training: bool Tensor variable 141 | 142 | Returns: 143 | Variable tensor 144 | """ 145 | with tf.variable_scope(scope) as sc: 146 | kernel_h, kernel_w = kernel_size 147 | num_in_channels = inputs.get_shape()[-1].value 148 | kernel_shape = [kernel_h, kernel_w, 149 | num_in_channels, num_output_channels] 150 | kernel = _variable_with_weight_decay('weights', 151 | shape=kernel_shape, 152 | use_xavier=use_xavier, 153 | stddev=stddev, 154 | wd=weight_decay) 155 | stride_h, stride_w = stride 156 | outputs = tf.nn.conv2d(inputs, kernel, 157 | [1, stride_h, stride_w, 1], 158 | padding=padding) 159 | biases = _variable_on_cpu('biases', [num_output_channels], 160 | tf.constant_initializer(0.0)) 161 | outputs = tf.nn.bias_add(outputs, biases) 162 | 163 | if bn: 164 | outputs = batch_norm_for_conv2d(outputs, is_training, 165 | bn_decay=bn_decay, scope='bn') 166 | 167 | if activation_fn is not None: 168 | outputs = activation_fn(outputs) 169 | return outputs 170 | 171 | 172 | def conv2d_transpose(inputs, 173 | num_output_channels, 174 | kernel_size, 175 | scope, 176 | stride=[1, 1], 177 | padding='SAME', 178 | use_xavier=True, 179 | stddev=1e-3, 180 | weight_decay=0.0, 181 | activation_fn=tf.nn.relu, 182 | bn=False, 183 | bn_decay=None, 184 | is_training=None): 185 | """ 2D convolution transpose with non-linear operation. 
186 | 
187 |     Args:
188 |         inputs: 4-D tensor variable BxHxWxC
189 |         num_output_channels: int
190 |         kernel_size: a list of 2 ints
191 |         scope: string
192 |         stride: a list of 2 ints
193 |         padding: 'SAME' or 'VALID'
194 |         use_xavier: bool, use xavier_initializer if true
195 |         stddev: float, stddev for truncated_normal init
196 |         weight_decay: float
197 |         activation_fn: function
198 |         bn: bool, whether to use batch norm
199 |         bn_decay: float or float tensor variable in [0,1]
200 |         is_training: bool Tensor variable
201 | 
202 |     Returns:
203 |         Variable tensor
204 | 
205 |     Note: conv2d(conv2d_transpose(a, num_out, ksize, stride), a.shape[-1], ksize, stride) == a
206 |     """
207 |     with tf.variable_scope(scope) as sc:
208 |         kernel_h, kernel_w = kernel_size
209 |         num_in_channels = inputs.get_shape()[-1].value
210 |         kernel_shape = [kernel_h, kernel_w,
211 |                         num_output_channels, num_in_channels]  # reversed compared to conv2d
212 |         kernel = _variable_with_weight_decay('weights',
213 |                                              shape=kernel_shape,
214 |                                              use_xavier=use_xavier,
215 |                                              stddev=stddev,
216 |                                              wd=weight_decay)
217 |         stride_h, stride_w = stride
218 | 
219 |         # from slim.convolution2d_transpose
220 |         def get_deconv_dim(dim_size, stride_size, kernel_size, padding):
221 |             dim_size *= stride_size
222 | 
223 |             if padding == 'VALID' and dim_size is not None:
224 |                 dim_size += max(kernel_size - stride_size, 0)
225 |             return dim_size
226 | 
227 |         # calculate output shape
228 |         batch_size = inputs.get_shape()[0].value
229 |         height = inputs.get_shape()[1].value
230 |         width = inputs.get_shape()[2].value
231 |         out_height = get_deconv_dim(height, stride_h, kernel_h, padding)
232 |         out_width = get_deconv_dim(width, stride_w, kernel_w, padding)
233 |         output_shape = [batch_size, out_height, out_width, num_output_channels]
234 | 
235 |         outputs = tf.nn.conv2d_transpose(inputs, kernel, output_shape,
236 |                                          [1, stride_h, stride_w, 1],
237 |                                          padding=padding)
238 |         biases = _variable_on_cpu('biases', [num_output_channels],
239 |                                   tf.constant_initializer(0.0))
240 |         outputs = tf.nn.bias_add(outputs, biases)
241 | 
242 |         if bn:
243 |             outputs = batch_norm_for_conv2d(outputs, is_training,
244 |                                             bn_decay=bn_decay, scope='bn')
245 | 
246 |         if activation_fn is not None:
247 |             outputs = activation_fn(outputs)
248 |         return outputs
249 | 
250 | 
251 | 
252 | def conv3d(inputs,
253 |            num_output_channels,
254 |            kernel_size,
255 |            scope,
256 |            stride=[1, 1, 1],
257 |            padding='SAME',
258 |            use_xavier=True,
259 |            stddev=1e-3,
260 |            weight_decay=0.0,
261 |            activation_fn=tf.nn.relu,
262 |            bn=False,
263 |            bn_decay=None,
264 |            is_training=None):
265 |     """ 3D convolution with non-linear operation.
266 | 267 | Args: 268 | inputs: 5-D tensor variable BxDxHxWxC 269 | num_output_channels: int 270 | kernel_size: a list of 3 ints 271 | scope: string 272 | stride: a list of 3 ints 273 | padding: 'SAME' or 'VALID' 274 | use_xavier: bool, use xavier_initializer if true 275 | stddev: float, stddev for truncated_normal init 276 | weight_decay: float 277 | activation_fn: function 278 | bn: bool, whether to use batch norm 279 | bn_decay: float or float tensor variable in [0,1] 280 | is_training: bool Tensor variable 281 | 282 | Returns: 283 | Variable tensor 284 | """ 285 | with tf.variable_scope(scope) as sc: 286 | kernel_d, kernel_h, kernel_w = kernel_size 287 | num_in_channels = inputs.get_shape()[-1].value 288 | kernel_shape = [kernel_d, kernel_h, kernel_w, 289 | num_in_channels, num_output_channels] 290 | kernel = _variable_with_weight_decay('weights', 291 | shape=kernel_shape, 292 | use_xavier=use_xavier, 293 | stddev=stddev, 294 | wd=weight_decay) 295 | stride_d, stride_h, stride_w = stride 296 | outputs = tf.nn.conv3d(inputs, kernel, 297 | [1, stride_d, stride_h, stride_w, 1], 298 | padding=padding) 299 | biases = _variable_on_cpu('biases', [num_output_channels], 300 | tf.constant_initializer(0.0)) 301 | outputs = tf.nn.bias_add(outputs, biases) 302 | 303 | if bn: 304 | outputs = batch_norm_for_conv3d(outputs, is_training, 305 | bn_decay=bn_decay, scope='bn') 306 | 307 | if activation_fn is not None: 308 | outputs = activation_fn(outputs) 309 | return outputs 310 | 311 | def fully_connected(inputs, 312 | num_outputs, 313 | scope, 314 | use_xavier=True, 315 | stddev=1e-3, 316 | weight_decay=0.0, 317 | activation_fn=tf.nn.relu, 318 | bn=False, 319 | bn_decay=None, 320 | is_training=None): 321 | """ Fully connected layer with non-linear operation. 322 | 323 | Args: 324 | inputs: 2-D tensor BxN 325 | num_outputs: int 326 | 327 | Returns: 328 | Variable tensor of size B x num_outputs. 329 | """ 330 | with tf.variable_scope(scope) as sc: 331 | num_input_units = inputs.get_shape()[-1].value 332 | weights = _variable_with_weight_decay('weights', 333 | shape=[num_input_units, num_outputs], 334 | use_xavier=use_xavier, 335 | stddev=stddev, 336 | wd=weight_decay) 337 | outputs = tf.matmul(inputs, weights) 338 | biases = _variable_on_cpu('biases', [num_outputs], 339 | tf.constant_initializer(0.0)) 340 | outputs = tf.nn.bias_add(outputs, biases) 341 | 342 | if bn: 343 | outputs = batch_norm_for_fc(outputs, is_training, bn_decay, 'bn') 344 | 345 | if activation_fn is not None: 346 | outputs = activation_fn(outputs) 347 | return outputs 348 | 349 | 350 | def max_pool2d(inputs, 351 | kernel_size, 352 | scope, 353 | stride=[2, 2], 354 | padding='VALID'): 355 | """ 2D max pooling. 356 | 357 | Args: 358 | inputs: 4-D tensor BxHxWxC 359 | kernel_size: a list of 2 ints 360 | stride: a list of 2 ints 361 | 362 | Returns: 363 | Variable tensor 364 | """ 365 | with tf.variable_scope(scope) as sc: 366 | kernel_h, kernel_w = kernel_size 367 | stride_h, stride_w = stride 368 | outputs = tf.nn.max_pool(inputs, 369 | ksize=[1, kernel_h, kernel_w, 1], 370 | strides=[1, stride_h, stride_w, 1], 371 | padding=padding, 372 | name=sc.name) 373 | return outputs 374 | 375 | def avg_pool2d(inputs, 376 | kernel_size, 377 | scope, 378 | stride=[2, 2], 379 | padding='VALID'): 380 | """ 2D avg pooling. 
381 | 
382 |     Args:
383 |         inputs: 4-D tensor BxHxWxC
384 |         kernel_size: a list of 2 ints
385 |         stride: a list of 2 ints
386 | 
387 |     Returns:
388 |         Variable tensor
389 |     """
390 |     with tf.variable_scope(scope) as sc:
391 |         kernel_h, kernel_w = kernel_size
392 |         stride_h, stride_w = stride
393 |         outputs = tf.nn.avg_pool(inputs,
394 |                                  ksize=[1, kernel_h, kernel_w, 1],
395 |                                  strides=[1, stride_h, stride_w, 1],
396 |                                  padding=padding,
397 |                                  name=sc.name)
398 |         return outputs
399 | 
400 | 
401 | def max_pool3d(inputs,
402 |                kernel_size,
403 |                scope,
404 |                stride=[2, 2, 2],
405 |                padding='VALID'):
406 |     """ 3D max pooling.
407 | 
408 |     Args:
409 |         inputs: 5-D tensor BxDxHxWxC
410 |         kernel_size: a list of 3 ints
411 |         stride: a list of 3 ints
412 | 
413 |     Returns:
414 |         Variable tensor
415 |     """
416 |     with tf.variable_scope(scope) as sc:
417 |         kernel_d, kernel_h, kernel_w = kernel_size
418 |         stride_d, stride_h, stride_w = stride
419 |         outputs = tf.nn.max_pool3d(inputs,
420 |                                    ksize=[1, kernel_d, kernel_h, kernel_w, 1],
421 |                                    strides=[1, stride_d, stride_h, stride_w, 1],
422 |                                    padding=padding,
423 |                                    name=sc.name)
424 |         return outputs
425 | 
426 | def avg_pool3d(inputs,
427 |                kernel_size,
428 |                scope,
429 |                stride=[2, 2, 2],
430 |                padding='VALID'):
431 |     """ 3D avg pooling.
432 | 
433 |     Args:
434 |         inputs: 5-D tensor BxDxHxWxC
435 |         kernel_size: a list of 3 ints
436 |         stride: a list of 3 ints
437 | 
438 |     Returns:
439 |         Variable tensor
440 |     """
441 |     with tf.variable_scope(scope) as sc:
442 |         kernel_d, kernel_h, kernel_w = kernel_size
443 |         stride_d, stride_h, stride_w = stride
444 |         outputs = tf.nn.avg_pool3d(inputs,
445 |                                    ksize=[1, kernel_d, kernel_h, kernel_w, 1],
446 |                                    strides=[1, stride_d, stride_h, stride_w, 1],
447 |                                    padding=padding,
448 |                                    name=sc.name)
449 |         return outputs
450 | 
451 | 
452 | 
453 | 
454 | 
455 | def batch_norm_template(inputs, is_training, scope, moments_dims, bn_decay):
456 |     """ Batch normalization on convolutional maps and beyond...
457 |     Ref.: http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
458 | 
459 |     Args:
460 |         inputs: Tensor, k-D input ... x C could be BC or BHWC or BDHWC
461 |         is_training: boolean tf.Variable, true indicates training phase
462 |         scope: string, variable scope
463 |         moments_dims: a list of ints, indicating dimensions for moments calculation
464 |         bn_decay: float or float tensor variable, controlling moving average weight
465 |     Return:
466 |         normed: batch-normalized maps
467 |     """
468 |     with tf.variable_scope(scope) as sc:
469 |         num_channels = inputs.get_shape()[-1].value
470 |         beta = tf.Variable(tf.constant(0.0, shape=[num_channels]),
471 |                            name='beta', trainable=True)
472 |         gamma = tf.Variable(tf.constant(1.0, shape=[num_channels]),
473 |                             name='gamma', trainable=True)
474 |         batch_mean, batch_var = tf.nn.moments(inputs, moments_dims, name='moments')
475 |         decay = bn_decay if bn_decay is not None else 0.9
476 |         ema = tf.train.ExponentialMovingAverage(decay=decay)
477 |         # Operator that maintains moving averages of variables.
478 |         ema_apply_op = tf.cond(is_training,
479 |                                lambda: ema.apply([batch_mean, batch_var]),
480 |                                lambda: tf.no_op())
481 | 
482 |         # Update moving average and return current batch's avg and var.
483 |         def mean_var_with_update():
484 |             with tf.control_dependencies([ema_apply_op]):
485 |                 return tf.identity(batch_mean), tf.identity(batch_var)
486 | 
487 |         # ema.average returns the Variable holding the average of var.
488 |         mean, var = tf.cond(is_training,
489 |                             mean_var_with_update,
490 |                             lambda: (ema.average(batch_mean), ema.average(batch_var)))
491 |         normed = tf.nn.batch_normalization(inputs, mean, var, beta, gamma, 1e-3)
492 |     return normed
493 | 
494 | 
495 | def batch_norm_for_fc(inputs, is_training, bn_decay, scope):
496 |     """ Batch normalization on FC data.
497 | 
498 |     Args:
499 |         inputs: Tensor, 2D BxC input
500 |         is_training: boolean tf.Variable, true indicates training phase
501 |         bn_decay: float or float tensor variable, controlling moving average weight
502 |         scope: string, variable scope
503 |     Return:
504 |         normed: batch-normalized maps
505 |     """
506 |     return batch_norm_template(inputs, is_training, scope, [0,], bn_decay)
507 | 
508 | 
509 | def batch_norm_for_conv1d(inputs, is_training, bn_decay, scope):
510 |     """ Batch normalization on 1D convolutional maps.
511 | 
512 |     Args:
513 |         inputs: Tensor, 3D BLC input maps
514 |         is_training: boolean tf.Variable, true indicates training phase
515 |         bn_decay: float or float tensor variable, controlling moving average weight
516 |         scope: string, variable scope
517 |     Return:
518 |         normed: batch-normalized maps
519 |     """
520 |     return batch_norm_template(inputs, is_training, scope, [0,1], bn_decay)
521 | 
522 | 
523 | 
524 | 
525 | def batch_norm_for_conv2d(inputs, is_training, bn_decay, scope):
526 |     """ Batch normalization on 2D convolutional maps.
527 | 
528 |     Args:
529 |         inputs: Tensor, 4D BHWC input maps
530 |         is_training: boolean tf.Variable, true indicates training phase
531 |         bn_decay: float or float tensor variable, controlling moving average weight
532 |         scope: string, variable scope
533 |     Return:
534 |         normed: batch-normalized maps
535 |     """
536 |     return batch_norm_template(inputs, is_training, scope, [0,1,2], bn_decay)
537 | 
538 | 
539 | 
540 | def batch_norm_for_conv3d(inputs, is_training, bn_decay, scope):
541 |     """ Batch normalization on 3D convolutional maps.
542 | 
543 |     Args:
544 |         inputs: Tensor, 5D BDHWC input maps
545 |         is_training: boolean tf.Variable, true indicates training phase
546 |         bn_decay: float or float tensor variable, controlling moving average weight
547 |         scope: string, variable scope
548 |     Return:
549 |         normed: batch-normalized maps
550 |     """
551 |     return batch_norm_template(inputs, is_training, scope, [0,1,2,3], bn_decay)
552 | 
553 | 
554 | def dropout(inputs,
555 |             is_training,
556 |             scope,
557 |             keep_prob=0.5,
558 |             noise_shape=None):
559 |     """ Dropout layer.
560 | 
561 |     Args:
562 |         inputs: tensor
563 |         is_training: boolean tf.Variable
564 |         scope: string
565 |         keep_prob: float in [0,1]
566 |         noise_shape: list of ints
567 | 
568 |     Returns:
569 |         tensor variable
570 |     """
571 |     with tf.variable_scope(scope) as sc:
572 |         outputs = tf.cond(is_training,
573 |                           lambda: tf.nn.dropout(inputs, keep_prob, noise_shape),
574 |                           lambda: inputs)
575 |     return outputs
576 | 
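if __name__ == "__main__":
    # Smoke test (a minimal sketch; shapes and scope names are arbitrary):
    # a PointNet-style 1x3 convolution with batch norm over B x N x 3 x 1
    # inputs, built in TF1 graph mode as these wrappers expect.
    with tf.Graph().as_default():
        inputs = tf.placeholder(tf.float32, shape=(32, 1024, 3, 1))
        is_training = tf.placeholder(tf.bool, shape=())
        net = conv2d(inputs, 64, [1, 3], scope='conv1', padding='VALID',
                     bn=True, bn_decay=0.9, is_training=is_training)
        print(net)  # Tensor of shape (32, 1024, 1, 64)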
--------------------------------------------------------------------------------
/utils/util_funcs.py:
--------------------------------------------------------------------------------
1 | # External Modules
2 | import torch
3 | from torch import cuda, FloatTensor, LongTensor
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from sklearn.neighbors import NearestNeighbors
7 | from typing import Union
8 | 
9 | # Types to allow for both CPU and GPU models.
10 | UFloatTensor = Union[FloatTensor, cuda.FloatTensor]
11 | ULongTensor = Union[LongTensor, cuda.LongTensor]
12 | 
13 | def knn_indices_func_cpu(rep_pts : FloatTensor,  # (N, pts, dim)
14 |                          pts : FloatTensor,      # (N, x, dim)
15 |                          K : int, D : int
16 |                         ) -> LongTensor:         # (N, pts, K)
17 |     """
18 |     CPU-based Indexing function based on K-Nearest Neighbors search.
19 |     :param rep_pts: Representative points.
20 |     :param pts: Point cloud to get indices from.
21 |     :param K: Number of nearest neighbors to collect.
22 |     :param D: "Spread" of neighboring points.
23 |     :return: Array of indices, P_idx, into pts such that pts[n][P_idx[n],:]
24 |     is the set of K nearest neighbors of the representative points in pts[n].
25 |     """
26 |     rep_pts = rep_pts.data.numpy()
27 |     pts = pts.data.numpy()
28 |     region_idx = []
29 | 
30 |     for n, p in enumerate(rep_pts):
31 |         P_particular = pts[n]
32 |         nbrs = NearestNeighbors(n_neighbors = D*K + 1, algorithm = "ball_tree").fit(P_particular)
33 |         indices = nbrs.kneighbors(p)[1]
34 |         region_idx.append(indices[:,1::D])
35 | 
36 |     region_idx = torch.from_numpy(np.stack(region_idx, axis = 0))
37 |     return region_idx
38 | 
39 | def knn_indices_func_gpu(rep_pts : cuda.FloatTensor,  # (N, pts, dim)
40 |                          pts : cuda.FloatTensor,      # (N, x, dim)
41 |                          K : int, D : int
42 |                         ) -> cuda.LongTensor:         # (N, pts, K)
43 |     """
44 |     GPU-based Indexing function based on K-Nearest Neighbors search.
45 |     Very memory intensive, and thus suboptimal for large numbers of points.
46 |     :param rep_pts: Representative points.
47 |     :param pts: Point cloud to get indices from.
48 |     :param K: Number of nearest neighbors to collect.
49 |     :param D: "Spread" of neighboring points.
50 |     :return: Array of indices, P_idx, into pts such that pts[n][P_idx[n],:]
51 |     is the set of K nearest neighbors of the representative points in pts[n].
52 |     """
53 |     region_idx = []
54 | 
55 |     for n, qry in enumerate(rep_pts):
56 |         ref = pts[n]
57 |         num_ref, dim = ref.size()
58 |         num_qry, _ = qry.size()
59 |         mref = ref.expand(num_qry, num_ref, dim)
60 |         mqry = qry.expand(num_ref, num_qry, dim).transpose(0, 1)
61 |         dist2 = torch.sum((mqry - mref)**2, 2).squeeze()
62 |         _, inds = torch.topk(dist2, K*D + 1, dim = 1, largest = False)
63 |         region_idx.append(inds[:,1::D])
64 | 
65 |     region_idx = torch.stack(region_idx, dim = 0)
66 |     return region_idx
67 | 
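if __name__ == "__main__":
    # Quick shape check (a sketch; sizes are arbitrary). With K = 4 and
    # D = 2, every 2nd of the 8 nearest neighbors is kept, self excluded.
    pts = torch.rand(2, 100, 3)       # two clouds of 100 3-D points
    rep_pts = pts[:, :10, :]          # first 10 points as representatives
    idx = knn_indices_func_cpu(rep_pts, pts, 4, 2)
    print(idx.size())                 # torch.Size([2, 10, 4])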
--------------------------------------------------------------------------------
/utils/util_layers.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from typing import Callable, Union, Tuple
3 | 
4 | from util_funcs import UFloatTensor
5 | 
6 | def EndChannels(f, make_contiguous = False):
7 |     """ Class decorator to apply 2D convolution along end channels. """
8 | 
9 |     class WrappedLayer(nn.Module):
10 | 
11 |         def __init__(self):
12 |             super(WrappedLayer, self).__init__()
13 |             self.f = f
14 | 
15 |         def forward(self, x):
16 |             x = x.permute(0,3,1,2)
17 |             x = self.f(x)
18 |             x = x.permute(0,2,3,1)
19 |             return x
20 | 
21 |     return WrappedLayer()
22 | 
23 | class Dense(nn.Module):
24 |     """
25 |     Single layer perceptron with optional activation, batch normalization, and dropout.
26 |     """
27 | 
28 |     def __init__(self, in_features : int, out_features : int,
29 |                  drop_rate : float = 0, with_bn : bool = True,
30 |                  activation : Callable[[UFloatTensor], UFloatTensor] = nn.ReLU()
31 |                 ) -> None:
32 |         """
33 |         :param in_features: Length of input features (last dimension).
34 |         :param out_features: Length of output features (last dimension).
35 |         :param drop_rate: Drop rate to be applied after activation.
36 |         :param with_bn: Whether or not to apply batch normalization.
37 |         :param activation: Activation function.
38 |         """
39 |         super(Dense, self).__init__()
40 | 
41 |         self.linear = nn.Linear(in_features, out_features)
42 |         self.activation = activation
43 |         # self.bn = LayerNorm(out_features) if with_bn else None
44 |         self.drop = nn.Dropout(drop_rate) if drop_rate > 0 else None
45 | 
46 |     def forward(self, x : UFloatTensor) -> UFloatTensor:
47 |         """
48 |         :param x: Any input tensor that can be input into nn.Linear.
49 |         :return: Tensor with linear layer and optional activation, batchnorm,
50 |         and dropout applied.
51 |         """
52 |         x = self.linear(x)
53 |         if self.activation:
54 |             x = self.activation(x)
55 |         # if self.bn:
56 |         #     x = self.bn(x)
57 |         if self.drop:
58 |             x = self.drop(x)
59 |         return x
60 | 
61 | class Conv(nn.Module):
62 |     """
63 |     2D convolutional layer with optional activation and batch normalization.
64 |     """
65 | 
66 |     def __init__(self, in_channels : int, out_channels : int,
67 |                  kernel_size : Union[int, Tuple[int, int]], with_bn : bool = True,
68 |                  activation : Callable[[UFloatTensor], UFloatTensor] = nn.ReLU()
69 |                 ) -> None:
70 |         """
71 |         :param in_channels: Length of input features (first dimension).
72 |         :param out_channels: Length of output features (first dimension).
73 |         :param kernel_size: Size of convolutional kernel.
74 |         :param with_bn: Whether or not to apply batch normalization.
75 |         :param activation: Activation function.
76 |         """
77 |         super(Conv, self).__init__()
78 | 
79 |         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias = not with_bn)
80 |         self.activation = activation
81 |         self.bn = nn.BatchNorm2d(out_channels, momentum = 0.9) if with_bn else None
82 | 
83 |     def forward(self, x : UFloatTensor) -> UFloatTensor:
84 |         """
85 |         :param x: Any input tensor that can be input into nn.Conv2d.
86 |         :return: Tensor with convolutional layer and optional activation and batchnorm applied.
87 |         """
88 |         x = self.conv(x)
89 |         if self.activation:
90 |             x = self.activation(x)
91 |         if self.bn:
92 |             x = self.bn(x)
93 |         return x
94 | 
95 | class SepConv(nn.Module):
96 |     """ Depthwise separable convolution with optional activation and batch normalization. """
97 | 
98 |     def __init__(self, in_channels : int, out_channels : int,
99 |                  kernel_size : Union[int, Tuple[int, int]],
100 |                  depth_multiplier : int = 1, with_bn : bool = True,
101 |                  activation : Callable[[UFloatTensor], UFloatTensor] = nn.ReLU()
102 |                 ) -> None:
103 |         """
104 |         :param in_channels: Length of input features (first dimension).
105 |         :param out_channels: Length of output features (first dimension).
106 |         :param kernel_size: Size of convolutional kernel.
107 |         :param depth_multiplier: Depth multiplier for middle part of separable convolution.
108 |         :param with_bn: Whether or not to apply batch normalization.
109 |         :param activation: Activation function.
110 |         """
111 |         super(SepConv, self).__init__()
112 | 
113 |         self.conv = nn.Sequential(
114 |             nn.Conv2d(in_channels, in_channels * depth_multiplier, kernel_size, groups = in_channels),
115 |             nn.Conv2d(in_channels * depth_multiplier, out_channels, 1, bias = not with_bn)
116 |         )
117 | 
118 |         self.activation = activation
119 |         self.bn = nn.BatchNorm2d(out_channels, momentum = 0.9) if with_bn else None
120 | 
121 |     def forward(self, x : UFloatTensor) -> UFloatTensor:
122 |         """
123 |         :param x: Any input tensor that can be input into nn.Conv2d.
124 |         :return: Tensor with depthwise separable convolutional layer and
125 |         optional activation and batchnorm applied.
126 | """ 127 | x = self.conv(x) 128 | if self.activation: 129 | x = self.activation(x) 130 | if self.bn: 131 | x = self.bn(x) 132 | return x 133 | 134 | class LayerNorm(nn.Module): 135 | """ 136 | Batch Normalization over ONLY the mini-batch layer (suitable for nn.Linear layers). 137 | """ 138 | 139 | def __init__(self, N : int, dim : int, *args, **kwargs) -> None: 140 | """ 141 | :param N: Batch size. 142 | :param D: Dimensions. 143 | """ 144 | super(LayerNorm, self).__init__() 145 | if dim == 1: 146 | self.bn = nn.BatchNorm1d(N, *args, **kwargs) 147 | elif dim == 2: 148 | self.bn = nn.BatchNorm2d(N, *args, **kwargs) 149 | elif dim == 3: 150 | self.bn = nn.BatchNorm3d(N, *args, **kwargs) 151 | else: 152 | raise ValueError("Dimensionality %i not supported" % dim) 153 | 154 | self.forward = lambda x: self.bn(x.unsqueeze(0)).squeeze(0) 155 | --------------------------------------------------------------------------------