├── .gitignore ├── LICENSE ├── README.md ├── pytorch ├── README.md ├── data.py ├── main.py ├── model.py ├── pretrained │ └── model.1024.t7 └── util.py └── tensorflow ├── README.md ├── evaluate.py ├── misc └── demo_teaser.png ├── models ├── dgcnn.py └── transform_nets.py ├── part_seg ├── README.md ├── download_data.sh ├── part_seg_model.py ├── test.py ├── testing_ply_file_list.txt ├── train_multi_gpu.py └── train_results │ └── trained_models │ ├── checkpoint │ ├── epoch_175.ckpt.data-00000-of-00001 │ ├── epoch_175.ckpt.index │ └── epoch_175.ckpt.meta ├── provider.py ├── sem_seg ├── README.md ├── batch_inference.py ├── collect_indoor3d_data.py ├── download_data.sh ├── eval_iou_accuracy.py ├── indoor3d_util.py ├── meta │ ├── all_data_label.txt │ ├── anno_paths.txt │ ├── area1_data_label.txt │ ├── area2_data_label.txt │ ├── area3_data_label.txt │ ├── area4_data_label.txt │ ├── area5_data_label.txt │ ├── area6_data_label.txt │ └── class_names.txt ├── model.py ├── test_job.sh ├── train.py └── train_job.sh ├── train.py └── utils ├── data_prep_util.py ├── eulerangles.py ├── pc_util.py ├── plyfile.py └── tf_util.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | log/ 3 | *.pyc 4 | .DS_Store 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Massachusetts Institute of Technology and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dynamic Graph CNN for Learning on Point Clouds 2 | We propose a new neural network module dubbed EdgeConv suitable for CNN-based high-level tasks on point clouds including classification and segmentation. EdgeConv is differentiable and can be plugged into existing architectures. 3 | 4 | [[Project]](https://liuziwei7.github.io/projects/DGCNN) [[Paper]](https://arxiv.org/abs/1801.07829) [[Press]](http://news.mit.edu/2019/deep-learning-point-clouds-1021) 5 | 6 | ## Overview 7 | `DGCNN` is the author's re-implementation of Dynamic Graph CNN, which achieves state-of-the-art performance on point-cloud-related high-level tasks including category classification, semantic segmentation and part segmentation. 
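At its core, every EdgeConv layer recomputes a k-nearest-neighbor graph over the current point features and applies a shared MLP to per-edge features before max-pooling over the neighbors. For orientation, here is a minimal PyTorch sketch of one such layer, condensed from `knn` and `get_graph_feature` in `pytorch/model.py`; `conv` stands in for any shared point-wise MLP, e.g. `nn.Conv2d(2 * d, d_out, kernel_size=1)`:

```python
import torch

def edge_conv(x, conv, k=20):
    """One EdgeConv layer. x: (batch, dims, points); conv: shared MLP over 2*dims channels."""
    b, d, n = x.size()
    inner = -2 * torch.matmul(x.transpose(2, 1), x)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)
    idx = (-xx - inner - xx.transpose(2, 1)).topk(k, dim=-1)[1]     # k-NN in feature space
    idx = (idx + torch.arange(b, device=x.device).view(-1, 1, 1) * n).view(-1)
    pts = x.transpose(2, 1).reshape(b * n, d)                       # flatten for index gathering
    nbrs = pts[idx].view(b, n, k, d)                                # neighbor features
    ctr = pts.view(b, n, 1, d).expand(-1, -1, k, -1)                # center features, repeated k times
    edge = torch.cat((nbrs - ctr, ctr), dim=3).permute(0, 3, 1, 2)  # (b, 2d, n, k) edge features
    return conv(edge).max(dim=-1)[0]                                # max over neighbors -> (b, d_out, n)
```

Because the graph is rebuilt from the learned features at every layer, rather than fixed once from the input coordinates, the neighborhoods change as training proceeds, which is what makes the graph "dynamic".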
8 | 
9 | 
10 | 
11 | For further information, please contact [Yue Wang](https://www.csail.mit.edu/person/yue-wang) and [Yongbin Sun](https://autoid.mit.edu/people-2).
12 | 
13 | ## Author's Implementations
14 | 
15 | The classification experiments in our paper were done with the PyTorch implementation.
16 | 
17 | * [tensorflow-dgcnn](./tensorflow)
18 | * [pytorch-dgcnn](./pytorch)
19 | 
20 | ## Other Implementations
21 | * [pytorch-geometric](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.EdgeConv)
22 | * [pytorch-dgcnn](https://github.com/AnTao97/dgcnn.pytorch) (on S3DIS, this implementation achieves significantly better results than our TensorFlow implementation)
23 | 
24 | ## Generalization under Corruptions
25 | 
26 | Performance is evaluated on [ModelNet-C](https://github.com/jiawei-ren/ModelNet-C) with mCE (lower is better) and clean OA (higher is better).
27 | 
28 | | Method | Reference | Standalone | mCE | Clean OA |
29 | | --------------- | ---------------------------------------------------------- | :--------: | :---: | :------: |
30 | | PointNet | [Qi et al.](https://arxiv.org/abs/1612.00593) | Yes | 1.422 | 0.907 |
31 | | DGCNN | [Wang et al.](https://arxiv.org/abs/1801.07829) | Yes | 1.000 | 0.926 |
32 | 
33 | 
34 | ## Real-World Applications
35 | * DGCNN has been successfully applied in [ParticleNet](https://arxiv.org/abs/1902.08570) for jet tagging at the Large Hadron Collider (LHC).
36 | 
37 | 
38 | ## Citation
39 | Please cite this paper if you use this code in your work:
40 | 
41 |     @article{dgcnn,
42 |       title={Dynamic Graph CNN for Learning on Point Clouds},
43 |       author={Wang, Yue and Sun, Yongbin and Liu, Ziwei and Sarma, Sanjay E. and Bronstein, Michael M. and Solomon, Justin M.},
44 |       journal={ACM Transactions on Graphics (TOG)},
45 |       year={2019}
46 |     }
47 | 
48 | ## License
49 | MIT License
50 | 
51 | ## Acknowledgement
52 | The structure of this codebase is borrowed from [PointNet](https://github.com/charlesq34/pointnet).
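For quick inspection, the released 1024-point classifier can be loaded along the following lines. This is a minimal sketch (not one of the released scripts), mirroring `test()` in `pytorch/main.py`; it assumes the working directory is `pytorch/` and the default hyper-parameters the checkpoint was trained with (`k=20`, `emb_dims=1024`, `dropout=0.5`):

```python
import argparse
import torch
import torch.nn as nn
from model import DGCNN  # pytorch/model.py

args = argparse.Namespace(k=20, emb_dims=1024, dropout=0.5)
model = nn.DataParallel(DGCNN(args))  # checkpoint keys carry the 'module.' prefix
model.load_state_dict(torch.load('pretrained/model.1024.t7', map_location='cpu'))
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 1024))  # (batch, xyz, points) -> (1, 40) class logits
```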
53 | 
--------------------------------------------------------------------------------
/pytorch/README.md:
--------------------------------------------------------------------------------
1 | # Dynamic Graph CNN for Learning on Point Clouds (PyTorch)
2 | 
3 | ## Point Cloud Classification
4 | * Run the training script:
5 | 
6 | 
7 | ``` 1024 points
8 | python main.py --exp_name=dgcnn_1024 --model=dgcnn --num_points=1024 --k=20 --use_sgd=True
9 | ```
10 | 
11 | ``` 2048 points
12 | python main.py --exp_name=dgcnn_2048 --model=dgcnn --num_points=2048 --k=40 --use_sgd=True
13 | ```
14 | 
15 | * Run the evaluation script after training is finished:
16 | 
17 | ``` 1024 points
18 | python main.py --exp_name=dgcnn_1024_eval --model=dgcnn --num_points=1024 --k=20 --use_sgd=True --eval=True --model_path=checkpoints/dgcnn_1024/models/model.t7
19 | ```
20 | 
21 | ``` 2048 points
22 | python main.py --exp_name=dgcnn_2048_eval --model=dgcnn --num_points=2048 --k=40 --use_sgd=True --eval=True --model_path=checkpoints/dgcnn_2048/models/model.t7
23 | ```
24 | 
25 | * Run the evaluation script with pretrained models:
26 | 
27 | ``` 1024 points
28 | python main.py --exp_name=dgcnn_1024_eval --model=dgcnn --num_points=1024 --k=20 --use_sgd=True --eval=True --model_path=pretrained/model.1024.t7
29 | ```
30 | 
31 | ``` 2048 points
32 | python main.py --exp_name=dgcnn_2048_eval --model=dgcnn --num_points=2048 --k=40 --use_sgd=True --eval=True --model_path=pretrained/model.2048.t7
33 | ```
34 | 
--------------------------------------------------------------------------------
/pytorch/data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Author: Yue Wang
5 | @Contact: yuewangx@mit.edu
6 | @File: data.py
7 | @Time: 2018/10/13 6:21 PM
8 | """
9 | 
10 | 
11 | import os
12 | import sys
13 | import glob
14 | import h5py
15 | import numpy as np
16 | from torch.utils.data import Dataset
17 | 
18 | 
19 | def download():
20 |     BASE_DIR = os.path.dirname(os.path.abspath(__file__))
21 |     DATA_DIR = os.path.join(BASE_DIR, 'data')
22 |     if not os.path.exists(DATA_DIR):
23 |         os.mkdir(DATA_DIR)
24 |     if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')):
25 |         www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
26 |         zipfile = os.path.basename(www)
27 |         os.system('wget %s; unzip %s' % (www, zipfile))
28 |         os.system('mv %s %s' % (zipfile[:-4], DATA_DIR))
29 |         os.system('rm %s' % (zipfile))
30 | 
31 | 
32 | def load_data(partition):
33 |     download()
34 |     BASE_DIR = os.path.dirname(os.path.abspath(__file__))
35 |     DATA_DIR = os.path.join(BASE_DIR, 'data')
36 |     all_data = []
37 |     all_label = []
38 |     for h5_name in glob.glob(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5' % partition)):
39 |         f = h5py.File(h5_name, 'r')  # explicit read mode; required by recent h5py
40 |         data = f['data'][:].astype('float32')
41 |         label = f['label'][:].astype('int64')
42 |         f.close()
43 |         all_data.append(data)
44 |         all_label.append(label)
45 |     all_data = np.concatenate(all_data, axis=0)
46 |     all_label = np.concatenate(all_label, axis=0)
47 |     return all_data, all_label
48 | 
49 | 
50 | def translate_pointcloud(pointcloud):
51 |     xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3])  # random anisotropic scaling
52 |     xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])     # random translation
53 | 
54 |     translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32')
55 |     return translated_pointcloud
56 | 
57 | 
58 | def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
59 |     N, C = pointcloud.shape
60 | pointcloud += np.clip(sigma * np.random.randn(N, C), -1*clip, clip) 61 | return pointcloud 62 | 63 | 64 | class ModelNet40(Dataset): 65 | def __init__(self, num_points, partition='train'): 66 | self.data, self.label = load_data(partition) 67 | self.num_points = num_points 68 | self.partition = partition 69 | 70 | def __getitem__(self, item): 71 | pointcloud = self.data[item][:self.num_points] 72 | label = self.label[item] 73 | if self.partition == 'train': 74 | pointcloud = translate_pointcloud(pointcloud) 75 | np.random.shuffle(pointcloud) 76 | return pointcloud, label 77 | 78 | def __len__(self): 79 | return self.data.shape[0] 80 | 81 | 82 | if __name__ == '__main__': 83 | train = ModelNet40(1024) 84 | test = ModelNet40(1024, 'test') 85 | for data, label in train: 86 | print(data.shape) 87 | print(label.shape) 88 | -------------------------------------------------------------------------------- /pytorch/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Author: Yue Wang 5 | @Contact: yuewangx@mit.edu 6 | @File: main.py 7 | @Time: 2018/10/13 10:39 PM 8 | """ 9 | 10 | 11 | from __future__ import print_function 12 | import os 13 | import argparse 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.optim as optim 18 | from torch.optim.lr_scheduler import CosineAnnealingLR 19 | from data import ModelNet40 20 | from model import PointNet, DGCNN 21 | import numpy as np 22 | from torch.utils.data import DataLoader 23 | from util import cal_loss, IOStream 24 | import sklearn.metrics as metrics 25 | 26 | 27 | def _init_(): 28 | if not os.path.exists('checkpoints'): 29 | os.makedirs('checkpoints') 30 | if not os.path.exists('checkpoints/'+args.exp_name): 31 | os.makedirs('checkpoints/'+args.exp_name) 32 | if not os.path.exists('checkpoints/'+args.exp_name+'/'+'models'): 33 | os.makedirs('checkpoints/'+args.exp_name+'/'+'models') 34 | os.system('cp main.py checkpoints'+'/'+args.exp_name+'/'+'main.py.backup') 35 | os.system('cp model.py checkpoints' + '/' + args.exp_name + '/' + 'model.py.backup') 36 | os.system('cp util.py checkpoints' + '/' + args.exp_name + '/' + 'util.py.backup') 37 | os.system('cp data.py checkpoints' + '/' + args.exp_name + '/' + 'data.py.backup') 38 | 39 | def train(args, io): 40 | train_loader = DataLoader(ModelNet40(partition='train', num_points=args.num_points), num_workers=8, 41 | batch_size=args.batch_size, shuffle=True, drop_last=True) 42 | test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=8, 43 | batch_size=args.test_batch_size, shuffle=True, drop_last=False) 44 | 45 | device = torch.device("cuda" if args.cuda else "cpu") 46 | 47 | #Try to load models 48 | if args.model == 'pointnet': 49 | model = PointNet(args).to(device) 50 | elif args.model == 'dgcnn': 51 | model = DGCNN(args).to(device) 52 | else: 53 | raise Exception("Not implemented") 54 | print(str(model)) 55 | 56 | model = nn.DataParallel(model) 57 | print("Let's use", torch.cuda.device_count(), "GPUs!") 58 | 59 | if args.use_sgd: 60 | print("Use SGD") 61 | opt = optim.SGD(model.parameters(), lr=args.lr*100, momentum=args.momentum, weight_decay=1e-4) 62 | else: 63 | print("Use Adam") 64 | opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4) 65 | 66 | scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr) 67 | 68 | criterion = cal_loss 69 | 70 | best_test_acc = 0 71 | for epoch in 
range(args.epochs): 72 | scheduler.step() 73 | #################### 74 | # Train 75 | #################### 76 | train_loss = 0.0 77 | count = 0.0 78 | model.train() 79 | train_pred = [] 80 | train_true = [] 81 | for data, label in train_loader: 82 | data, label = data.to(device), label.to(device).squeeze() 83 | data = data.permute(0, 2, 1) 84 | batch_size = data.size()[0] 85 | opt.zero_grad() 86 | logits = model(data) 87 | loss = criterion(logits, label) 88 | loss.backward() 89 | opt.step() 90 | preds = logits.max(dim=1)[1] 91 | count += batch_size 92 | train_loss += loss.item() * batch_size 93 | train_true.append(label.cpu().numpy()) 94 | train_pred.append(preds.detach().cpu().numpy()) 95 | train_true = np.concatenate(train_true) 96 | train_pred = np.concatenate(train_pred) 97 | outstr = 'Train %d, loss: %.6f, train acc: %.6f, train avg acc: %.6f' % (epoch, 98 | train_loss*1.0/count, 99 | metrics.accuracy_score( 100 | train_true, train_pred), 101 | metrics.balanced_accuracy_score( 102 | train_true, train_pred)) 103 | io.cprint(outstr) 104 | 105 | #################### 106 | # Test 107 | #################### 108 | test_loss = 0.0 109 | count = 0.0 110 | model.eval() 111 | test_pred = [] 112 | test_true = [] 113 | for data, label in test_loader: 114 | data, label = data.to(device), label.to(device).squeeze() 115 | data = data.permute(0, 2, 1) 116 | batch_size = data.size()[0] 117 | logits = model(data) 118 | loss = criterion(logits, label) 119 | preds = logits.max(dim=1)[1] 120 | count += batch_size 121 | test_loss += loss.item() * batch_size 122 | test_true.append(label.cpu().numpy()) 123 | test_pred.append(preds.detach().cpu().numpy()) 124 | test_true = np.concatenate(test_true) 125 | test_pred = np.concatenate(test_pred) 126 | test_acc = metrics.accuracy_score(test_true, test_pred) 127 | avg_per_class_acc = metrics.balanced_accuracy_score(test_true, test_pred) 128 | outstr = 'Test %d, loss: %.6f, test acc: %.6f, test avg acc: %.6f' % (epoch, 129 | test_loss*1.0/count, 130 | test_acc, 131 | avg_per_class_acc) 132 | io.cprint(outstr) 133 | if test_acc >= best_test_acc: 134 | best_test_acc = test_acc 135 | torch.save(model.state_dict(), 'checkpoints/%s/models/model.t7' % args.exp_name) 136 | 137 | 138 | def test(args, io): 139 | test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), 140 | batch_size=args.test_batch_size, shuffle=True, drop_last=False) 141 | 142 | device = torch.device("cuda" if args.cuda else "cpu") 143 | 144 | #Try to load models 145 | model = DGCNN(args).to(device) 146 | model = nn.DataParallel(model) 147 | model.load_state_dict(torch.load(args.model_path)) 148 | model = model.eval() 149 | test_acc = 0.0 150 | count = 0.0 151 | test_true = [] 152 | test_pred = [] 153 | for data, label in test_loader: 154 | 155 | data, label = data.to(device), label.to(device).squeeze() 156 | data = data.permute(0, 2, 1) 157 | batch_size = data.size()[0] 158 | logits = model(data) 159 | preds = logits.max(dim=1)[1] 160 | test_true.append(label.cpu().numpy()) 161 | test_pred.append(preds.detach().cpu().numpy()) 162 | test_true = np.concatenate(test_true) 163 | test_pred = np.concatenate(test_pred) 164 | test_acc = metrics.accuracy_score(test_true, test_pred) 165 | avg_per_class_acc = metrics.balanced_accuracy_score(test_true, test_pred) 166 | outstr = 'Test :: test acc: %.6f, test avg acc: %.6f'%(test_acc, avg_per_class_acc) 167 | io.cprint(outstr) 168 | 169 | 170 | if __name__ == "__main__": 171 | # Training settings 172 | parser = 
argparse.ArgumentParser(description='Point Cloud Recognition')
173 |     parser.add_argument('--exp_name', type=str, default='exp', metavar='N',
174 |                         help='Name of the experiment')
175 |     parser.add_argument('--model', type=str, default='dgcnn', metavar='N',
176 |                         choices=['pointnet', 'dgcnn'],
177 |                         help='Model to use, [pointnet, dgcnn]')
178 |     parser.add_argument('--dataset', type=str, default='modelnet40', metavar='N',
179 |                         choices=['modelnet40'])
180 |     parser.add_argument('--batch_size', type=int, default=32, metavar='batch_size',
181 |                         help='Size of batch')
182 |     parser.add_argument('--test_batch_size', type=int, default=16, metavar='batch_size',
183 |                         help='Size of batch')
184 |     parser.add_argument('--epochs', type=int, default=250, metavar='N',
185 |                         help='number of epochs to train')
186 |     parser.add_argument('--use_sgd', type=lambda x: str(x).lower() in ('true', '1'), default=True,
187 |                         help='Use SGD')  # type=bool would treat any non-empty string, even "False", as True
188 |     parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
189 |                         help='learning rate (default: 0.001, 0.1 if using sgd)')
190 |     parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
191 |                         help='SGD momentum (default: 0.9)')
192 |     parser.add_argument('--no_cuda', type=lambda x: str(x).lower() in ('true', '1'), default=False,
193 |                         help='disables CUDA training')
194 |     parser.add_argument('--seed', type=int, default=1, metavar='S',
195 |                         help='random seed (default: 1)')
196 |     parser.add_argument('--eval', type=lambda x: str(x).lower() in ('true', '1'), default=False,
197 |                         help='evaluate the model')
198 |     parser.add_argument('--num_points', type=int, default=1024,
199 |                         help='num of points to use')
200 |     parser.add_argument('--dropout', type=float, default=0.5,
201 |                         help='dropout rate')
202 |     parser.add_argument('--emb_dims', type=int, default=1024, metavar='N',
203 |                         help='Dimension of embeddings')
204 |     parser.add_argument('--k', type=int, default=20, metavar='N',
205 |                         help='Num of nearest neighbors to use')
206 |     parser.add_argument('--model_path', type=str, default='', metavar='N',
207 |                         help='Pretrained model path')
208 |     args = parser.parse_args()
209 | 
210 |     _init_()
211 | 
212 |     io = IOStream('checkpoints/' + args.exp_name + '/run.log')
213 |     io.cprint(str(args))
214 | 
215 |     args.cuda = not args.no_cuda and torch.cuda.is_available()
216 |     torch.manual_seed(args.seed)
217 |     if args.cuda:
218 |         io.cprint(
219 |             'Using GPU : ' + str(torch.cuda.current_device()) + ' from ' + str(torch.cuda.device_count()) + ' devices')
220 |         torch.cuda.manual_seed(args.seed)
221 |     else:
222 |         io.cprint('Using CPU')
223 | 
224 |     if not args.eval:
225 |         train(args, io)
226 |     else:
227 |         test(args, io)
228 | 
--------------------------------------------------------------------------------
/pytorch/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Author: Yue Wang
5 | @Contact: yuewangx@mit.edu
6 | @File: model.py
7 | @Time: 2018/10/13 6:35 PM
8 | """
9 | 
10 | 
11 | import os
12 | import sys
13 | import copy
14 | import math
15 | import numpy as np
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | 
20 | 
21 | def knn(x, k):
22 |     inner = -2*torch.matmul(x.transpose(2, 1), x)
23 |     xx = torch.sum(x**2, dim=1, keepdim=True)
24 |     pairwise_distance = -xx - inner - xx.transpose(2, 1)  # negative squared Euclidean distance
25 | 
26 |     idx = pairwise_distance.topk(k=k, dim=-1)[1]   # (batch_size, num_points, k)
27 |     return idx
28 | 
29 | 
30 | def get_graph_feature(x, k=20, idx=None):
31 |     batch_size = x.size(0)
32 |     num_points = x.size(2)
33 |     x = x.view(batch_size, -1, num_points)
34 |     if idx is None:
35 |         idx = knn(x, k=k)   # (batch_size, num_points, k)
36 |     device = x.device  # follow the input's device instead of hard-coding 'cuda', so CPU runs also work
37 | 
38 |     idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1)*num_points
39 | 
40 |     idx = idx + idx_base
41 | 
42 |     idx = idx.view(-1)
43 | 
44 |     _, num_dims, _ = x.size()
45 | 
46 |     x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims); flattened below so idx can gather neighbors
47 |     feature = x.view(batch_size*num_points, -1)[idx, :]
48 |     feature = feature.view(batch_size, num_points, k, num_dims)
49 |     x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
50 | 
51 |     feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous()
52 | 
53 |     return feature
54 | 
55 | 
56 | class PointNet(nn.Module):
57 |     def __init__(self, args, output_channels=40):
58 |         super(PointNet, self).__init__()
59 |         self.args = args
60 |         self.conv1 = nn.Conv1d(3, 64, kernel_size=1, bias=False)
61 |         self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
62 |         self.conv3 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
63 |         self.conv4 = nn.Conv1d(64, 128, kernel_size=1, bias=False)
64 |         self.conv5 = nn.Conv1d(128, args.emb_dims, kernel_size=1, bias=False)
65 |         self.bn1 = nn.BatchNorm1d(64)
66 |         self.bn2 = nn.BatchNorm1d(64)
67 |         self.bn3 = nn.BatchNorm1d(64)
68 |         self.bn4 = nn.BatchNorm1d(128)
69 |         self.bn5 = nn.BatchNorm1d(args.emb_dims)
70 |         self.linear1 = nn.Linear(args.emb_dims, 512, bias=False)
71 |         self.bn6 = nn.BatchNorm1d(512)
72 |         self.dp1 = nn.Dropout()
73 |         self.linear2 = nn.Linear(512, output_channels)
74 | 
75 |     def forward(self, x):
76 |         x = F.relu(self.bn1(self.conv1(x)))
77 |         x = F.relu(self.bn2(self.conv2(x)))
78 |         x = F.relu(self.bn3(self.conv3(x)))
79 |         x = F.relu(self.bn4(self.conv4(x)))
80 |         x = F.relu(self.bn5(self.conv5(x)))
81 |         x = F.adaptive_max_pool1d(x, 1).squeeze()
82 |         x = F.relu(self.bn6(self.linear1(x)))
83 |         x = self.dp1(x)
84 |         x = self.linear2(x)
85 |         return x
86 | 
87 | 
88 | class DGCNN(nn.Module):
89 |     def __init__(self, args, output_channels=40):
90 |         super(DGCNN, self).__init__()
91 |         self.args = args
92 |         self.k = args.k
93 | 
94 |         self.bn1 = nn.BatchNorm2d(64)
95 |         self.bn2 = nn.BatchNorm2d(64)
96 |         self.bn3 = nn.BatchNorm2d(128)
97 |         self.bn4 = nn.BatchNorm2d(256)
98 |         self.bn5 = nn.BatchNorm1d(args.emb_dims)
99 | 
100 |         self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False),
101 |                                    self.bn1,
102 |                                    nn.LeakyReLU(negative_slope=0.2))
103 |         self.conv2 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
104 |                                    self.bn2,
105 |                                    nn.LeakyReLU(negative_slope=0.2))
106 |         self.conv3 = nn.Sequential(nn.Conv2d(64*2, 128, kernel_size=1, bias=False),
107 |                                    self.bn3,
108 |                                    nn.LeakyReLU(negative_slope=0.2))
109 |         self.conv4 = nn.Sequential(nn.Conv2d(128*2, 256, kernel_size=1, bias=False),
110 |                                    self.bn4,
111 |                                    nn.LeakyReLU(negative_slope=0.2))
112 |         self.conv5 = nn.Sequential(nn.Conv1d(512, args.emb_dims, kernel_size=1, bias=False),
113 |                                    self.bn5,
114 |                                    nn.LeakyReLU(negative_slope=0.2))
115 |         self.linear1 = nn.Linear(args.emb_dims*2, 512, bias=False)
116 |         self.bn6 = nn.BatchNorm1d(512)
117 |         self.dp1 = nn.Dropout(p=args.dropout)
118 |         self.linear2 = nn.Linear(512, 256)
119 |         self.bn7 = nn.BatchNorm1d(256)
120 |         self.dp2 = nn.Dropout(p=args.dropout)
121 |         self.linear3 = nn.Linear(256, output_channels)
122 | 
123 |     def forward(self, x):
124 |         batch_size = x.size(0)
125 |         x = get_graph_feature(x, k=self.k)    # edge features on the input coordinates
126 |         x = self.conv1(x)
127 |         x1 = x.max(dim=-1, keepdim=False)[0]  # max over the k neighbors
128 | 
129 |         x = get_graph_feature(x1, k=self.k)   # graph is recomputed on the learned features
130 |         x = self.conv2(x)
131 |         x2 = x.max(dim=-1, keepdim=False)[0]
132 | 
133 |         x = get_graph_feature(x2, k=self.k)
134 |         x = self.conv3(x)
135 |         x3 = x.max(dim=-1, keepdim=False)[0]
136 | 
137 |         x = get_graph_feature(x3, k=self.k)
138 |         x = self.conv4(x)
139 |         x4 = x.max(dim=-1, keepdim=False)[0]
140 | 
141 |         x = torch.cat((x1, x2, x3, x4), dim=1)
142 | 
143 |         x = self.conv5(x)
144 |         x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
145 |         x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)
146 |         x = torch.cat((x1, x2), 1)
147 | 
148 |         x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2)
149 |         x = self.dp1(x)
150 |         x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2)
151 |         x = self.dp2(x)
152 |         x = self.linear3(x)
153 |         return x
154 | 
--------------------------------------------------------------------------------
/pytorch/pretrained/model.1024.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangYueFt/dgcnn/f765b469a67730658ba554e97dc11723a7bab628/pytorch/pretrained/model.1024.t7
--------------------------------------------------------------------------------
/pytorch/util.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Author: Yue Wang
5 | @Contact: yuewangx@mit.edu
6 | @File: util.py
7 | @Time: 4/5/19 3:47 PM
8 | """
9 | 
10 | 
11 | import numpy as np
12 | import torch
13 | import torch.nn.functional as F
14 | 
15 | 
16 | def cal_loss(pred, gold, smoothing=True):
17 |     ''' Calculate cross entropy loss, apply label smoothing if needed. '''
18 | 
19 |     gold = gold.contiguous().view(-1)
20 | 
21 |     if smoothing:
22 |         eps = 0.2
23 |         n_class = pred.size(1)
24 | 
25 |         one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
26 |         one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
27 |         log_prb = F.log_softmax(pred, dim=1)
28 | 
29 |         loss = -(one_hot * log_prb).sum(dim=1).mean()
30 |     else:
31 |         loss = F.cross_entropy(pred, gold, reduction='mean')
32 | 
33 |     return loss
34 | 
35 | 
36 | class IOStream():
37 |     def __init__(self, path):
38 |         self.f = open(path, 'a')
39 | 
40 |     def cprint(self, text):
41 |         print(text)
42 |         self.f.write(text+'\n')
43 |         self.f.flush()
44 | 
45 |     def close(self):
46 |         self.f.close()
--------------------------------------------------------------------------------
/tensorflow/README.md:
--------------------------------------------------------------------------------
1 | # Dynamic Graph CNN for Learning on Point Clouds (TensorFlow)
2 | 
3 | ## Point Cloud Classification
4 | * Run the training script:
5 | 
6 | ``` bash
7 | python train.py
8 | ```
9 | 
10 | * Run the evaluation script after training is finished:
11 | 
12 | ``` bash
13 | python evaluate.py
14 | ```
15 | 
--------------------------------------------------------------------------------
/tensorflow/evaluate.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import argparse
4 | import socket
5 | import importlib
6 | import time
7 | import os
8 | import scipy.misc
9 | import sys
10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11 | sys.path.append(BASE_DIR)
12 | sys.path.append(os.path.join(BASE_DIR, 'models'))
13 | sys.path.append(os.path.join(BASE_DIR, 'utils'))
14 | import provider
15 | import pc_util
16 | 
17 | 
18 | parser = 
argparse.ArgumentParser() 19 | parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]') 20 | parser.add_argument('--model', default='dgcnn', help='Model name: dgcnn [default: dgcnn]') 21 | parser.add_argument('--batch_size', type=int, default=4, help='Batch Size during training [default: 1]') 22 | parser.add_argument('--num_point', type=int, default=1024, help='Point Number [256/512/1024/2048] [default: 1024]') 23 | parser.add_argument('--model_path', default='log/model.ckpt', help='model checkpoint file path [default: log/model.ckpt]') 24 | parser.add_argument('--dump_dir', default='dump', help='dump folder path [dump]') 25 | parser.add_argument('--visu', action='store_true', help='Whether to dump image for error case [default: False]') 26 | FLAGS = parser.parse_args() 27 | 28 | 29 | BATCH_SIZE = FLAGS.batch_size 30 | NUM_POINT = FLAGS.num_point 31 | MODEL_PATH = FLAGS.model_path 32 | GPU_INDEX = FLAGS.gpu 33 | MODEL = importlib.import_module(FLAGS.model) # import network module 34 | DUMP_DIR = FLAGS.dump_dir 35 | if not os.path.exists(DUMP_DIR): os.mkdir(DUMP_DIR) 36 | LOG_FOUT = open(os.path.join(DUMP_DIR, 'log_evaluate.txt'), 'w') 37 | LOG_FOUT.write(str(FLAGS)+'\n') 38 | 39 | NUM_CLASSES = 40 40 | SHAPE_NAMES = [line.rstrip() for line in \ 41 | open(os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/shape_names.txt'))] 42 | 43 | HOSTNAME = socket.gethostname() 44 | 45 | # ModelNet40 official train/test split 46 | TRAIN_FILES = provider.getDataFiles( \ 47 | os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/train_files.txt')) 48 | TEST_FILES = provider.getDataFiles(\ 49 | os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/test_files.txt')) 50 | 51 | def log_string(out_str): 52 | LOG_FOUT.write(out_str+'\n') 53 | LOG_FOUT.flush() 54 | print(out_str) 55 | 56 | def evaluate(num_votes): 57 | is_training = False 58 | 59 | with tf.device('/gpu:'+str(GPU_INDEX)): 60 | pointclouds_pl, labels_pl = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT) 61 | is_training_pl = tf.placeholder(tf.bool, shape=()) 62 | 63 | # simple model 64 | pred, end_points = MODEL.get_model(pointclouds_pl, is_training_pl) 65 | loss = MODEL.get_loss(pred, labels_pl, end_points) 66 | 67 | # Add ops to save and restore all the variables. 68 | saver = tf.train.Saver() 69 | 70 | # Create a session 71 | config = tf.ConfigProto() 72 | config.gpu_options.allow_growth = True 73 | config.allow_soft_placement = True 74 | config.log_device_placement = True 75 | sess = tf.Session(config=config) 76 | 77 | # Restore variables from disk. 
78 | saver.restore(sess, MODEL_PATH) 79 | log_string("Model restored.") 80 | 81 | ops = {'pointclouds_pl': pointclouds_pl, 82 | 'labels_pl': labels_pl, 83 | 'is_training_pl': is_training_pl, 84 | 'pred': pred, 85 | 'loss': loss} 86 | 87 | eval_one_epoch(sess, ops, num_votes) 88 | 89 | 90 | def eval_one_epoch(sess, ops, num_votes=1, topk=1): 91 | error_cnt = 0 92 | is_training = False 93 | total_correct = 0 94 | total_seen = 0 95 | loss_sum = 0 96 | total_seen_class = [0 for _ in range(NUM_CLASSES)] 97 | total_correct_class = [0 for _ in range(NUM_CLASSES)] 98 | fout = open(os.path.join(DUMP_DIR, 'pred_label.txt'), 'w') 99 | for fn in range(len(TEST_FILES)): 100 | log_string('----'+str(fn)+'----') 101 | current_data, current_label = provider.loadDataFile(TEST_FILES[fn]) 102 | current_data = current_data[:,0:NUM_POINT,:] 103 | current_label = np.squeeze(current_label) 104 | print(current_data.shape) 105 | 106 | file_size = current_data.shape[0] 107 | num_batches = file_size // BATCH_SIZE 108 | print(file_size) 109 | 110 | for batch_idx in range(num_batches): 111 | start_idx = batch_idx * BATCH_SIZE 112 | end_idx = (batch_idx+1) * BATCH_SIZE 113 | cur_batch_size = end_idx - start_idx 114 | 115 | # Aggregating BEG 116 | batch_loss_sum = 0 # sum of losses for the batch 117 | batch_pred_sum = np.zeros((cur_batch_size, NUM_CLASSES)) # score for classes 118 | batch_pred_classes = np.zeros((cur_batch_size, NUM_CLASSES)) # 0/1 for classes 119 | for vote_idx in range(num_votes): 120 | rotated_data = provider.rotate_point_cloud_by_angle(current_data[start_idx:end_idx, :, :], 121 | vote_idx/float(num_votes) * np.pi * 2) 122 | feed_dict = {ops['pointclouds_pl']: rotated_data, 123 | ops['labels_pl']: current_label[start_idx:end_idx], 124 | ops['is_training_pl']: is_training} 125 | loss_val, pred_val = sess.run([ops['loss'], ops['pred']], 126 | feed_dict=feed_dict) 127 | batch_pred_sum += pred_val 128 | batch_pred_val = np.argmax(pred_val, 1) 129 | for el_idx in range(cur_batch_size): 130 | batch_pred_classes[el_idx, batch_pred_val[el_idx]] += 1 131 | batch_loss_sum += (loss_val * cur_batch_size / float(num_votes)) 132 | # pred_val_topk = np.argsort(batch_pred_sum, axis=-1)[:,-1*np.array(range(topk))-1] 133 | # pred_val = np.argmax(batch_pred_classes, 1) 134 | pred_val = np.argmax(batch_pred_sum, 1) 135 | # Aggregating END 136 | 137 | correct = np.sum(pred_val == current_label[start_idx:end_idx]) 138 | # correct = np.sum(pred_val_topk[:,0:topk] == label_val) 139 | total_correct += correct 140 | total_seen += cur_batch_size 141 | loss_sum += batch_loss_sum 142 | 143 | for i in range(start_idx, end_idx): 144 | l = current_label[i] 145 | total_seen_class[l] += 1 146 | total_correct_class[l] += (pred_val[i-start_idx] == l) 147 | fout.write('%d, %d\n' % (pred_val[i-start_idx], l)) 148 | 149 | if pred_val[i-start_idx] != l and FLAGS.visu: # ERROR CASE, DUMP! 
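                        # The block below renders three orthographic views of the misclassified
                        # shape and saves them as an image. Note: scipy.misc.imsave was removed in
                        # SciPy 1.2; on newer SciPy, imageio.imwrite(img_filename, output_img)
                        # is a drop-in replacement for the call below.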
150 | img_filename = '%d_label_%s_pred_%s.jpg' % (error_cnt, SHAPE_NAMES[l], 151 | SHAPE_NAMES[pred_val[i-start_idx]]) 152 | img_filename = os.path.join(DUMP_DIR, img_filename) 153 | output_img = pc_util.point_cloud_three_views(np.squeeze(current_data[i, :, :])) 154 | scipy.misc.imsave(img_filename, output_img) 155 | error_cnt += 1 156 | 157 | log_string('eval mean loss: %f' % (loss_sum / float(total_seen))) 158 | log_string('eval accuracy: %f' % (total_correct / float(total_seen))) 159 | log_string('eval avg class acc: %f' % (np.mean(np.array(total_correct_class)/np.array(total_seen_class,dtype=np.float)))) 160 | 161 | class_accuracies = np.array(total_correct_class)/np.array(total_seen_class,dtype=np.float) 162 | for i, name in enumerate(SHAPE_NAMES): 163 | log_string('%10s:\t%0.3f' % (name, class_accuracies[i])) 164 | 165 | 166 | 167 | if __name__=='__main__': 168 | with tf.Graph().as_default(): 169 | evaluate(num_votes=12) 170 | LOG_FOUT.close() 171 | -------------------------------------------------------------------------------- /tensorflow/misc/demo_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangYueFt/dgcnn/f765b469a67730658ba554e97dc11723a7bab628/tensorflow/misc/demo_teaser.png -------------------------------------------------------------------------------- /tensorflow/models/dgcnn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import math 4 | import sys 5 | import os 6 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 7 | sys.path.append(BASE_DIR) 8 | sys.path.append(os.path.join(BASE_DIR, '../utils')) 9 | sys.path.append(os.path.join(BASE_DIR, '../../utils')) 10 | import tf_util 11 | from transform_nets import input_transform_net 12 | 13 | 14 | def placeholder_inputs(batch_size, num_point): 15 | pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3)) 16 | labels_pl = tf.placeholder(tf.int32, shape=(batch_size)) 17 | return pointclouds_pl, labels_pl 18 | 19 | 20 | def get_model(point_cloud, is_training, bn_decay=None): 21 | """ Classification PointNet, input is BxNx3, output Bx40 """ 22 | batch_size = point_cloud.get_shape()[0].value 23 | num_point = point_cloud.get_shape()[1].value 24 | end_points = {} 25 | k = 20 26 | 27 | adj_matrix = tf_util.pairwise_distance(point_cloud) 28 | nn_idx = tf_util.knn(adj_matrix, k=k) 29 | edge_feature = tf_util.get_edge_feature(point_cloud, nn_idx=nn_idx, k=k) 30 | 31 | with tf.variable_scope('transform_net1') as sc: 32 | transform = input_transform_net(edge_feature, is_training, bn_decay, K=3) 33 | 34 | point_cloud_transformed = tf.matmul(point_cloud, transform) 35 | adj_matrix = tf_util.pairwise_distance(point_cloud_transformed) 36 | nn_idx = tf_util.knn(adj_matrix, k=k) 37 | edge_feature = tf_util.get_edge_feature(point_cloud_transformed, nn_idx=nn_idx, k=k) 38 | 39 | net = tf_util.conv2d(edge_feature, 64, [1,1], 40 | padding='VALID', stride=[1,1], 41 | bn=True, is_training=is_training, 42 | scope='dgcnn1', bn_decay=bn_decay) 43 | net = tf.reduce_max(net, axis=-2, keep_dims=True) 44 | net1 = net 45 | 46 | adj_matrix = tf_util.pairwise_distance(net) 47 | nn_idx = tf_util.knn(adj_matrix, k=k) 48 | edge_feature = tf_util.get_edge_feature(net, nn_idx=nn_idx, k=k) 49 | 50 | net = tf_util.conv2d(edge_feature, 64, [1,1], 51 | padding='VALID', stride=[1,1], 52 | bn=True, is_training=is_training, 53 | scope='dgcnn2', bn_decay=bn_decay) 54 | net = 
tf.reduce_max(net, axis=-2, keep_dims=True) 55 | net2 = net 56 | 57 | adj_matrix = tf_util.pairwise_distance(net) 58 | nn_idx = tf_util.knn(adj_matrix, k=k) 59 | edge_feature = tf_util.get_edge_feature(net, nn_idx=nn_idx, k=k) 60 | 61 | net = tf_util.conv2d(edge_feature, 64, [1,1], 62 | padding='VALID', stride=[1,1], 63 | bn=True, is_training=is_training, 64 | scope='dgcnn3', bn_decay=bn_decay) 65 | net = tf.reduce_max(net, axis=-2, keep_dims=True) 66 | net3 = net 67 | 68 | adj_matrix = tf_util.pairwise_distance(net) 69 | nn_idx = tf_util.knn(adj_matrix, k=k) 70 | edge_feature = tf_util.get_edge_feature(net, nn_idx=nn_idx, k=k) 71 | 72 | net = tf_util.conv2d(edge_feature, 128, [1,1], 73 | padding='VALID', stride=[1,1], 74 | bn=True, is_training=is_training, 75 | scope='dgcnn4', bn_decay=bn_decay) 76 | net = tf.reduce_max(net, axis=-2, keep_dims=True) 77 | net4 = net 78 | 79 | net = tf_util.conv2d(tf.concat([net1, net2, net3, net4], axis=-1), 1024, [1, 1], 80 | padding='VALID', stride=[1,1], 81 | bn=True, is_training=is_training, 82 | scope='agg', bn_decay=bn_decay) 83 | 84 | net = tf.reduce_max(net, axis=1, keep_dims=True) 85 | 86 | # MLP on global point cloud vector 87 | net = tf.reshape(net, [batch_size, -1]) 88 | net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training, 89 | scope='fc1', bn_decay=bn_decay) 90 | net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, 91 | scope='dp1') 92 | net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, 93 | scope='fc2', bn_decay=bn_decay) 94 | net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, 95 | scope='dp2') 96 | net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc3') 97 | 98 | return net, end_points 99 | 100 | 101 | def get_loss(pred, label, end_points): 102 | """ pred: B*NUM_CLASSES, 103 | label: B, """ 104 | labels = tf.one_hot(indices=label, depth=40) 105 | loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=pred, label_smoothing=0.2) 106 | classify_loss = tf.reduce_mean(loss) 107 | return classify_loss 108 | 109 | 110 | if __name__=='__main__': 111 | batch_size = 2 112 | num_pt = 124 113 | pos_dim = 3 114 | 115 | input_feed = np.random.rand(batch_size, num_pt, pos_dim) 116 | label_feed = np.random.rand(batch_size) 117 | label_feed[label_feed>=0.5] = 1 118 | label_feed[label_feed<0.5] = 0 119 | label_feed = label_feed.astype(np.int32) 120 | 121 | # # np.save('./debug/input_feed.npy', input_feed) 122 | # input_feed = np.load('./debug/input_feed.npy') 123 | # print input_feed 124 | 125 | with tf.Graph().as_default(): 126 | input_pl, label_pl = placeholder_inputs(batch_size, num_pt) 127 | pos, ftr = get_model(input_pl, tf.constant(True)) 128 | # loss = get_loss(logits, label_pl, None) 129 | 130 | with tf.Session() as sess: 131 | sess.run(tf.global_variables_initializer()) 132 | feed_dict = {input_pl: input_feed, label_pl: label_feed} 133 | res1, res2 = sess.run([pos, ftr], feed_dict=feed_dict) 134 | print res1.shape 135 | print res1 136 | 137 | print res2.shape 138 | print res2 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | -------------------------------------------------------------------------------- /tensorflow/models/transform_nets.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import sys 4 | import os 5 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 6 | sys.path.append(BASE_DIR) 7 | 
sys.path.append(os.path.join(BASE_DIR, '../utils'))
8 | import tf_util
9 | 
10 | def input_transform_net(edge_feature, is_training, bn_decay=None, K=3, is_dist=False):
11 |     """ Input (XYZ) Transform Net, input is BxNx3 point cloud
12 |     Return:
13 |         Transformation matrix of size KxK """
14 |     batch_size = edge_feature.get_shape()[0].value
15 |     num_point = edge_feature.get_shape()[1].value
16 | 
17 |     # input_image = tf.expand_dims(point_cloud, -1)
18 |     net = tf_util.conv2d(edge_feature, 64, [1,1],
19 |                          padding='VALID', stride=[1,1],
20 |                          bn=True, is_training=is_training,
21 |                          scope='tconv1', bn_decay=bn_decay, is_dist=is_dist)
22 |     net = tf_util.conv2d(net, 128, [1,1],
23 |                          padding='VALID', stride=[1,1],
24 |                          bn=True, is_training=is_training,
25 |                          scope='tconv2', bn_decay=bn_decay, is_dist=is_dist)
26 | 
27 |     net = tf.reduce_max(net, axis=-2, keep_dims=True)
28 | 
29 |     net = tf_util.conv2d(net, 1024, [1,1],
30 |                          padding='VALID', stride=[1,1],
31 |                          bn=True, is_training=is_training,
32 |                          scope='tconv3', bn_decay=bn_decay, is_dist=is_dist)
33 |     net = tf_util.max_pool2d(net, [num_point,1],
34 |                              padding='VALID', scope='tmaxpool')
35 | 
36 |     net = tf.reshape(net, [batch_size, -1])
37 |     net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
38 |                                   scope='tfc1', bn_decay=bn_decay, is_dist=is_dist)
39 |     net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
40 |                                   scope='tfc2', bn_decay=bn_decay, is_dist=is_dist)
41 | 
42 |     with tf.variable_scope('transform_XYZ') as sc:
43 |         # assert(K==3)
44 |         with tf.device('/cpu:0'):
45 |             weights = tf.get_variable('weights', [256, K*K],
46 |                                       initializer=tf.constant_initializer(0.0),
47 |                                       dtype=tf.float32)
48 |             biases = tf.get_variable('biases', [K*K],
49 |                                      initializer=tf.constant_initializer(0.0),
50 |                                      dtype=tf.float32)
51 |             biases += tf.constant(np.eye(K).flatten(), dtype=tf.float32)  # start from the identity transform
52 |         transform = tf.matmul(net, weights)
53 |         transform = tf.nn.bias_add(transform, biases)
54 | 
55 |     transform = tf.reshape(transform, [batch_size, K, K])
56 |     return transform
--------------------------------------------------------------------------------
/tensorflow/part_seg/README.md:
--------------------------------------------------------------------------------
1 | ## Part segmentation
2 | 
3 | ### Dataset
4 | 
5 | Download the data for part segmentation:
6 | 
7 | ```
8 | sh download_data.sh
9 | ```
10 | 
11 | ### Train
12 | 
13 | Train the model on 2 GPUs, each with 12 GB of memory:
14 | 
15 | ```
16 | python train_multi_gpu.py
17 | ```
18 | 
19 | Model parameters are saved every 5 epochs in "train_results/trained_models/".
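The checkpoints are standard `tf.train.Saver` files, so a specific epoch can also be restored directly. A minimal sketch, assuming the part-segmentation graph has been rebuilt exactly as in `test.py`:

```
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, 'train_results/trained_models/epoch_175.ckpt')
```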
20 | 21 | ### Evaluation 22 | 23 | To evaluate the model saved after epoch n, 24 | 25 | ``` 26 | python test.py --model_path train_results/trained_models/epoch_n.ckpt 27 | ``` 28 | 29 | For example, if we want to test the model saved after 175 epochs (provided), 30 | 31 | ``` 32 | python test.py --model_path train_results/trained_models/epoch_175.ckpt 33 | ``` 34 | -------------------------------------------------------------------------------- /tensorflow/part_seg/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Download original ShapeNetPart dataset (around 1GB) ['PartAnnotation'] 4 | wget https://shapenet.cs.stanford.edu/ericyi/shapenetcore_partanno_v0.zip 5 | unzip shapenetcore_partanno_v0.zip 6 | rm shapenetcore_partanno_v0.zip 7 | 8 | # Download HDF5 for ShapeNet Part segmentation (around 346MB) ['hdf5_data'] 9 | wget https://shapenet.cs.stanford.edu/media/shapenet_part_seg_hdf5_data.zip 10 | unzip shapenet_part_seg_hdf5_data.zip 11 | rm shapenet_part_seg_hdf5_data.zip 12 | -------------------------------------------------------------------------------- /tensorflow/part_seg/part_seg_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import math 4 | import os 5 | import sys 6 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 7 | sys.path.append(os.path.dirname(BASE_DIR)) 8 | sys.path.append(os.path.join(BASE_DIR, '../utils')) 9 | sys.path.append(os.path.join(BASE_DIR, '../models')) 10 | sys.path.append(os.path.join(BASE_DIR, '../')) 11 | import tf_util 12 | from transform_nets import input_transform_net 13 | 14 | def get_model(point_cloud, input_label, is_training, cat_num, part_num, \ 15 | batch_size, num_point, weight_decay, bn_decay=None): 16 | 17 | batch_size = point_cloud.get_shape()[0].value 18 | num_point = point_cloud.get_shape()[1].value 19 | input_image = tf.expand_dims(point_cloud, -1) 20 | 21 | k = 20 22 | 23 | adj = tf_util.pairwise_distance(point_cloud) 24 | nn_idx = tf_util.knn(adj, k=k) 25 | edge_feature = tf_util.get_edge_feature(input_image, nn_idx=nn_idx, k=k) 26 | 27 | with tf.variable_scope('transform_net1') as sc: 28 | transform = input_transform_net(edge_feature, is_training, bn_decay, K=3, is_dist=True) 29 | point_cloud_transformed = tf.matmul(point_cloud, transform) 30 | 31 | input_image = tf.expand_dims(point_cloud_transformed, -1) 32 | adj = tf_util.pairwise_distance(point_cloud_transformed) 33 | nn_idx = tf_util.knn(adj, k=k) 34 | edge_feature = tf_util.get_edge_feature(input_image, nn_idx=nn_idx, k=k) 35 | 36 | out1 = tf_util.conv2d(edge_feature, 64, [1,1], 37 | padding='VALID', stride=[1,1], 38 | bn=True, is_training=is_training, weight_decay=weight_decay, 39 | scope='adj_conv1', bn_decay=bn_decay, is_dist=True) 40 | 41 | out2 = tf_util.conv2d(out1, 64, [1,1], 42 | padding='VALID', stride=[1,1], 43 | bn=True, is_training=is_training, weight_decay=weight_decay, 44 | scope='adj_conv2', bn_decay=bn_decay, is_dist=True) 45 | 46 | net_1 = tf.reduce_max(out2, axis=-2, keep_dims=True) 47 | 48 | 49 | 50 | adj = tf_util.pairwise_distance(net_1) 51 | nn_idx = tf_util.knn(adj, k=k) 52 | edge_feature = tf_util.get_edge_feature(net_1, nn_idx=nn_idx, k=k) 53 | 54 | out3 = tf_util.conv2d(edge_feature, 64, [1,1], 55 | padding='VALID', stride=[1,1], 56 | bn=True, is_training=is_training, weight_decay=weight_decay, 57 | scope='adj_conv3', bn_decay=bn_decay, is_dist=True) 58 | 59 | out4 = 
tf_util.conv2d(out3, 64, [1,1], 60 | padding='VALID', stride=[1,1], 61 | bn=True, is_training=is_training, weight_decay=weight_decay, 62 | scope='adj_conv4', bn_decay=bn_decay, is_dist=True) 63 | 64 | net_2 = tf.reduce_max(out4, axis=-2, keep_dims=True) 65 | 66 | 67 | 68 | adj = tf_util.pairwise_distance(net_2) 69 | nn_idx = tf_util.knn(adj, k=k) 70 | edge_feature = tf_util.get_edge_feature(net_2, nn_idx=nn_idx, k=k) 71 | 72 | out5 = tf_util.conv2d(edge_feature, 64, [1,1], 73 | padding='VALID', stride=[1,1], 74 | bn=True, is_training=is_training, weight_decay=weight_decay, 75 | scope='adj_conv5', bn_decay=bn_decay, is_dist=True) 76 | 77 | # out6 = tf_util.conv2d(out5, 64, [1,1], 78 | # padding='VALID', stride=[1,1], 79 | # bn=True, is_training=is_training, weight_decay=weight_decay, 80 | # scope='adj_conv6', bn_decay=bn_decay, is_dist=True) 81 | 82 | net_3 = tf.reduce_max(out5, axis=-2, keep_dims=True) 83 | 84 | 85 | 86 | out7 = tf_util.conv2d(tf.concat([net_1, net_2, net_3], axis=-1), 1024, [1, 1], 87 | padding='VALID', stride=[1,1], 88 | bn=True, is_training=is_training, 89 | scope='adj_conv7', bn_decay=bn_decay, is_dist=True) 90 | 91 | out_max = tf_util.max_pool2d(out7, [num_point, 1], padding='VALID', scope='maxpool') 92 | 93 | 94 | one_hot_label_expand = tf.reshape(input_label, [batch_size, 1, 1, cat_num]) 95 | one_hot_label_expand = tf_util.conv2d(one_hot_label_expand, 64, [1, 1], 96 | padding='VALID', stride=[1,1], 97 | bn=True, is_training=is_training, 98 | scope='one_hot_label_expand', bn_decay=bn_decay, is_dist=True) 99 | out_max = tf.concat(axis=3, values=[out_max, one_hot_label_expand]) 100 | expand = tf.tile(out_max, [1, num_point, 1, 1]) 101 | 102 | concat = tf.concat(axis=3, values=[expand, 103 | net_1, 104 | net_2, 105 | net_3]) 106 | 107 | net2 = tf_util.conv2d(concat, 256, [1,1], padding='VALID', stride=[1,1], bn_decay=bn_decay, 108 | bn=True, is_training=is_training, scope='seg/conv1', weight_decay=weight_decay, is_dist=True) 109 | net2 = tf_util.dropout(net2, keep_prob=0.6, is_training=is_training, scope='seg/dp1') 110 | net2 = tf_util.conv2d(net2, 256, [1,1], padding='VALID', stride=[1,1], bn_decay=bn_decay, 111 | bn=True, is_training=is_training, scope='seg/conv2', weight_decay=weight_decay, is_dist=True) 112 | net2 = tf_util.dropout(net2, keep_prob=0.6, is_training=is_training, scope='seg/dp2') 113 | net2 = tf_util.conv2d(net2, 128, [1,1], padding='VALID', stride=[1,1], bn_decay=bn_decay, 114 | bn=True, is_training=is_training, scope='seg/conv3', weight_decay=weight_decay, is_dist=True) 115 | net2 = tf_util.conv2d(net2, part_num, [1,1], padding='VALID', stride=[1,1], activation_fn=None, 116 | bn=False, scope='seg/conv4', weight_decay=weight_decay, is_dist=True) 117 | 118 | net2 = tf.reshape(net2, [batch_size, num_point, part_num]) 119 | 120 | return net2 121 | 122 | 123 | def get_loss(seg_pred, seg): 124 | per_instance_seg_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=seg_pred, labels=seg), axis=1) 125 | seg_loss = tf.reduce_mean(per_instance_seg_loss) 126 | per_instance_seg_pred_res = tf.argmax(seg_pred, 2) 127 | 128 | return seg_loss, per_instance_seg_loss, per_instance_seg_pred_res 129 | 130 | -------------------------------------------------------------------------------- /tensorflow/part_seg/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import tensorflow as tf 3 | import json 4 | import numpy as np 5 | import os 6 | import sys 7 | BASE_DIR = 
os.path.dirname(os.path.abspath(__file__)) 8 | sys.path.append(BASE_DIR) 9 | sys.path.append(os.path.dirname(BASE_DIR)) 10 | import provider 11 | import part_seg_model as model 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--model_path', default='train_results/trained_models/epoch_160.ckpt', help='Model checkpoint path') 15 | FLAGS = parser.parse_args() 16 | 17 | # DEFAULT SETTINGS 18 | pretrained_model_path = FLAGS.model_path 19 | hdf5_data_dir = os.path.join(BASE_DIR, './hdf5_data') 20 | ply_data_dir = os.path.join(BASE_DIR, './PartAnnotation') 21 | gpu_to_use = 0 22 | output_dir = os.path.join(BASE_DIR, './test_results') 23 | output_verbose = False 24 | 25 | # MAIN SCRIPT 26 | point_num = 3000 27 | batch_size = 1 28 | 29 | test_file_list = os.path.join(BASE_DIR, 'testing_ply_file_list.txt') 30 | 31 | oid2cpid = json.load(open(os.path.join(hdf5_data_dir, 'overallid_to_catid_partid.json'), 'r')) 32 | 33 | object2setofoid = {} 34 | for idx in range(len(oid2cpid)): 35 | objid, pid = oid2cpid[idx] 36 | if not objid in object2setofoid.keys(): 37 | object2setofoid[objid] = [] 38 | object2setofoid[objid].append(idx) 39 | 40 | all_obj_cat_file = os.path.join(hdf5_data_dir, 'all_object_categories.txt') 41 | fin = open(all_obj_cat_file, 'r') 42 | lines = [line.rstrip() for line in fin.readlines()] 43 | objcats = [line.split()[1] for line in lines] 44 | objnames = [line.split()[0] for line in lines] 45 | on2oid = {objcats[i]:i for i in range(len(objcats))} 46 | fin.close() 47 | 48 | color_map_file = os.path.join(hdf5_data_dir, 'part_color_mapping.json') 49 | color_map = json.load(open(color_map_file, 'r')) 50 | 51 | NUM_OBJ_CATS = 16 52 | NUM_PART_CATS = 50 53 | 54 | cpid2oid = json.load(open(os.path.join(hdf5_data_dir, 'catid_partid_to_overallid.json'), 'r')) 55 | 56 | def printout(flog, data): 57 | print(data) 58 | flog.write(data + '\n') 59 | 60 | def output_color_point_cloud(data, seg, out_file): 61 | with open(out_file, 'w') as f: 62 | l = len(seg) 63 | for i in range(l): 64 | color = color_map[seg[i]] 65 | f.write('v %f %f %f %f %f %f\n' % (data[i][0], data[i][1], data[i][2], color[0], color[1], color[2])) 66 | 67 | def output_color_point_cloud_red_blue(data, seg, out_file): 68 | with open(out_file, 'w') as f: 69 | l = len(seg) 70 | for i in range(l): 71 | if seg[i] == 1: 72 | color = [0, 0, 1] 73 | elif seg[i] == 0: 74 | color = [1, 0, 0] 75 | else: 76 | color = [0, 0, 0] 77 | 78 | f.write('v %f %f %f %f %f %f\n' % (data[i][0], data[i][1], data[i][2], color[0], color[1], color[2])) 79 | 80 | 81 | def pc_normalize(pc): 82 | l = pc.shape[0] 83 | centroid = np.mean(pc, axis=0) 84 | pc = pc - centroid 85 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 86 | pc = pc / m 87 | return pc 88 | 89 | def placeholder_inputs(): 90 | pointclouds_ph = tf.placeholder(tf.float32, shape=(batch_size, point_num, 3)) 91 | input_label_ph = tf.placeholder(tf.float32, shape=(batch_size, NUM_OBJ_CATS)) 92 | return pointclouds_ph, input_label_ph 93 | 94 | def output_color_point_cloud(data, seg, out_file): 95 | with open(out_file, 'w') as f: 96 | l = len(seg) 97 | for i in range(l): 98 | color = color_map[seg[i]] 99 | f.write('v %f %f %f %f %f %f\n' % (data[i][0], data[i][1], data[i][2], color[0], color[1], color[2])) 100 | 101 | def load_pts_seg_files(pts_file, seg_file, catid): 102 | with open(pts_file, 'r') as f: 103 | pts_str = [item.rstrip() for item in f.readlines()] 104 | pts = np.array([np.float32(s.split()) for s in pts_str], dtype=np.float32) 105 | with open(seg_file, 'r') as f: 106 | 
part_ids = np.array([int(item.rstrip()) for item in f.readlines()], dtype=np.uint8) 107 | seg = np.array([cpid2oid[catid+'_'+str(x)] for x in part_ids]) 108 | return pts, seg 109 | 110 | def pc_augment_to_point_num(pts, pn): 111 | assert(pts.shape[0] <= pn) 112 | cur_len = pts.shape[0] 113 | res = np.array(pts) 114 | while cur_len < pn: 115 | res = np.concatenate((res, pts)) 116 | cur_len += pts.shape[0] 117 | return res[:pn, :] 118 | 119 | def convert_label_to_one_hot(labels): 120 | label_one_hot = np.zeros((labels.shape[0], NUM_OBJ_CATS)) 121 | for idx in range(labels.shape[0]): 122 | label_one_hot[idx, labels[idx]] = 1 123 | return label_one_hot 124 | 125 | def predict(): 126 | is_training = False 127 | 128 | with tf.device('/gpu:'+str(gpu_to_use)): 129 | pointclouds_ph, input_label_ph = placeholder_inputs() 130 | is_training_ph = tf.placeholder(tf.bool, shape=()) 131 | 132 | seg_pred = model.get_model(pointclouds_ph, input_label_ph, \ 133 | cat_num=NUM_OBJ_CATS, part_num=NUM_PART_CATS, is_training=is_training_ph, \ 134 | batch_size=batch_size, num_point=point_num, weight_decay=0.0, bn_decay=None) 135 | 136 | saver = tf.train.Saver() 137 | 138 | config = tf.ConfigProto() 139 | config.gpu_options.allow_growth = True 140 | config.allow_soft_placement = True 141 | 142 | with tf.Session(config=config) as sess: 143 | if not os.path.exists(output_dir): 144 | os.mkdir(output_dir) 145 | 146 | flog = open(os.path.join(output_dir, 'log.txt'), 'a') 147 | 148 | printout(flog, 'Loading model %s' % pretrained_model_path) 149 | saver.restore(sess, pretrained_model_path) 150 | printout(flog, 'Model restored.') 151 | 152 | batch_data = np.zeros([batch_size, point_num, 3]).astype(np.float32) 153 | 154 | total_acc = 0.0 155 | total_seen = 0 156 | total_acc_iou = 0.0 157 | 158 | total_per_cat_acc = np.zeros((NUM_OBJ_CATS)).astype(np.float32) 159 | total_per_cat_iou = np.zeros((NUM_OBJ_CATS)).astype(np.float32) 160 | total_per_cat_seen = np.zeros((NUM_OBJ_CATS)).astype(np.int32) 161 | 162 | ffiles = open(test_file_list, 'r') 163 | lines = [line.rstrip() for line in ffiles.readlines()] 164 | pts_files = [line.split()[0] for line in lines] 165 | seg_files = [line.split()[1] for line in lines] 166 | labels = [line.split()[2] for line in lines] 167 | ffiles.close() 168 | 169 | len_pts_files = len(pts_files) 170 | for shape_idx in range(len_pts_files): 171 | if shape_idx % 100 == 0: 172 | printout(flog, '%d/%d ...' % (shape_idx, len_pts_files)) 173 | 174 | cur_gt_label = on2oid[labels[shape_idx]] # 0/1/.../15 175 | 176 | cur_label_one_hot = np.zeros((1, NUM_OBJ_CATS), dtype=np.float32) 177 | cur_label_one_hot[0, cur_gt_label] = 1 178 | 179 | pts_file_to_load = os.path.join(ply_data_dir, pts_files[shape_idx]) 180 | seg_file_to_load = os.path.join(ply_data_dir, seg_files[shape_idx]) 181 | 182 | pts, seg = load_pts_seg_files(pts_file_to_load, seg_file_to_load, objcats[cur_gt_label]) 183 | ori_point_num = len(seg) 184 | 185 | batch_data[0, ...] = pc_augment_to_point_num(pc_normalize(pts), point_num) 186 | 187 | seg_pred_res = sess.run(seg_pred, feed_dict={ 188 | pointclouds_ph: batch_data, 189 | input_label_ph: cur_label_one_hot, 190 | is_training_ph: is_training}) 191 | 192 | seg_pred_res = seg_pred_res[0, ...] 
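        # The lines below keep only the part scores that belong to the ground-truth
        # object category: every other part label is pushed below the global minimum,
        # so the per-point argmax can only select parts valid for this category.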
193 | 194 | iou_oids = object2setofoid[objcats[cur_gt_label]] 195 | non_cat_labels = list(set(np.arange(NUM_PART_CATS)).difference(set(iou_oids))) 196 | 197 | mini = np.min(seg_pred_res) 198 | seg_pred_res[:, non_cat_labels] = mini - 1000 199 | 200 | seg_pred_val = np.argmax(seg_pred_res, axis=1)[:ori_point_num] 201 | 202 | seg_acc = np.mean(seg_pred_val == seg) 203 | 204 | total_acc += seg_acc 205 | total_seen += 1 206 | 207 | total_per_cat_seen[cur_gt_label] += 1 208 | total_per_cat_acc[cur_gt_label] += seg_acc 209 | 210 | mask = np.int32(seg_pred_val == seg) 211 | 212 | total_iou = 0.0 213 | iou_log = '' 214 | for oid in iou_oids: 215 | n_pred = np.sum(seg_pred_val == oid) 216 | n_gt = np.sum(seg == oid) 217 | n_intersect = np.sum(np.int32(seg == oid) * mask) 218 | n_union = n_pred + n_gt - n_intersect 219 | iou_log += '_' + str(n_pred)+'_'+str(n_gt)+'_'+str(n_intersect)+'_'+str(n_union)+'_' 220 | if n_union == 0: 221 | total_iou += 1 222 | iou_log += '_1\n' 223 | else: 224 | total_iou += n_intersect * 1.0 / n_union 225 | iou_log += '_'+str(n_intersect * 1.0 / n_union)+'\n' 226 | 227 | avg_iou = total_iou / len(iou_oids) 228 | total_acc_iou += avg_iou 229 | total_per_cat_iou[cur_gt_label] += avg_iou 230 | 231 | if output_verbose: 232 | output_color_point_cloud(pts, seg, os.path.join(output_dir, str(shape_idx)+'_gt.obj')) 233 | output_color_point_cloud(pts, seg_pred_val, os.path.join(output_dir, str(shape_idx)+'_pred.obj')) 234 | output_color_point_cloud_red_blue(pts, np.int32(seg == seg_pred_val), 235 | os.path.join(output_dir, str(shape_idx)+'_diff.obj')) 236 | 237 | with open(os.path.join(output_dir, str(shape_idx)+'.log'), 'w') as fout: 238 | fout.write('Total Point: %d\n\n' % ori_point_num) 239 | fout.write('Ground Truth: %s\n' % objnames[cur_gt_label]) 240 | fout.write('Accuracy: %f\n' % seg_acc) 241 | fout.write('IoU: %f\n\n' % avg_iou) 242 | fout.write('IoU details: %s\n' % iou_log) 243 | 244 | printout(flog, 'Accuracy: %f' % (total_acc / total_seen)) 245 | printout(flog, 'IoU: %f' % (total_acc_iou / total_seen)) 246 | 247 | for cat_idx in range(NUM_OBJ_CATS): 248 | printout(flog, '\t ' + objcats[cat_idx] + ' Total Number: ' + str(total_per_cat_seen[cat_idx])) 249 | if total_per_cat_seen[cat_idx] > 0: 250 | printout(flog, '\t ' + objcats[cat_idx] + ' Accuracy: ' + \ 251 | str(total_per_cat_acc[cat_idx] / total_per_cat_seen[cat_idx])) 252 | printout(flog, '\t ' + objcats[cat_idx] + ' IoU: '+ \ 253 | str(total_per_cat_iou[cat_idx] / total_per_cat_seen[cat_idx])) 254 | 255 | with tf.Graph().as_default(): 256 | predict() 257 | -------------------------------------------------------------------------------- /tensorflow/part_seg/train_multi_gpu.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | import tensorflow as tf 4 | import numpy as np 5 | from datetime import datetime 6 | import json 7 | import os 8 | import sys 9 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 10 | sys.path.append(BASE_DIR) 11 | sys.path.append(os.path.dirname(BASE_DIR)) 12 | import provider 13 | import part_seg_model as model 14 | 15 | TOWER_NAME = 'tower' 16 | 17 | # DEFAULT SETTINGS 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--num_gpu', type=int, default=2, help='The number of GPUs to use [default: 2]') 20 | parser.add_argument('--batch', type=int, default=16, help='Batch Size per GPU during training [default: 32]') 21 | parser.add_argument('--epoch', type=int, default=201, help='Epoch to run 
27 | hdf5_data_dir = os.path.join(BASE_DIR, './hdf5_data') 28 | 29 | # MAIN SCRIPT 30 | point_num = FLAGS.point_num 31 | batch_size = FLAGS.batch 32 | output_dir = FLAGS.output_dir 33 | 34 | if not os.path.exists(output_dir): 35 | os.mkdir(output_dir) 36 | 37 | # color_map_file = os.path.join(hdf5_data_dir, 'part_color_mapping.json') 38 | # color_map = json.load(open(color_map_file, 'r')) 39 | 40 | all_obj_cats_file = os.path.join(hdf5_data_dir, 'all_object_categories.txt') 41 | fin = open(all_obj_cats_file, 'r') 42 | lines = [line.rstrip() for line in fin.readlines()] 43 | all_obj_cats = [(line.split()[0], line.split()[1]) for line in lines] 44 | fin.close() 45 | 46 | all_cats = json.load(open(os.path.join(hdf5_data_dir, 'overallid_to_catid_partid.json'), 'r')) 47 | NUM_CATEGORIES = 16 48 | NUM_PART_CATS = len(all_cats) 49 | 50 | print('#### Batch Size Per GPU: {0}'.format(batch_size)) 51 | print('#### Point Number: {0}'.format(point_num)) 52 | print('#### Using GPUs: {0}'.format(FLAGS.num_gpu)) 53 | 54 | DECAY_STEP = 16881 * 20 55 | DECAY_RATE = 0.5 56 | 57 | LEARNING_RATE_CLIP = 1e-5 58 | 59 | BN_INIT_DECAY = 0.5 60 | BN_DECAY_DECAY_RATE = 0.5 61 | BN_DECAY_DECAY_STEP = float(DECAY_STEP * 2) 62 | BN_DECAY_CLIP = 0.99 63 | 64 | BASE_LEARNING_RATE = 0.003 65 | MOMENTUM = 0.9 66 | TRAINING_EPOCHES = FLAGS.epoch 67 | print('### Training epoch: {0}'.format(TRAINING_EPOCHES)) 68 | 69 | TRAINING_FILE_LIST = os.path.join(hdf5_data_dir, 'train_hdf5_file_list.txt') 70 | TESTING_FILE_LIST = os.path.join(hdf5_data_dir, 'val_hdf5_file_list.txt') 71 | 72 | MODEL_STORAGE_PATH = os.path.join(output_dir, 'trained_models') 73 | if not os.path.exists(MODEL_STORAGE_PATH): 74 | os.mkdir(MODEL_STORAGE_PATH) 75 | 76 | LOG_STORAGE_PATH = os.path.join(output_dir, 'logs') 77 | if not os.path.exists(LOG_STORAGE_PATH): 78 | os.mkdir(LOG_STORAGE_PATH) 79 | 80 | SUMMARIES_FOLDER = os.path.join(output_dir, 'summaries') 81 | if not os.path.exists(SUMMARIES_FOLDER): 82 | os.mkdir(SUMMARIES_FOLDER) 83 | 84 | def printout(flog, data): 85 | print(data) 86 | flog.write(data + '\n') 87 | 88 | def convert_label_to_one_hot(labels): 89 | label_one_hot = np.zeros((labels.shape[0], NUM_CATEGORIES)) 90 | for idx in range(labels.shape[0]): 91 | label_one_hot[idx, labels[idx]] = 1 92 | return label_one_hot 93 | 94 | def average_gradients(tower_grads): 95 | """Calculate average gradient for each shared variable across all towers. 96 | 97 | Note that this function provides a synchronization point across all towers. 98 | 99 | Args: 100 | tower_grads: List of lists of (gradient, variable) tuples. The outer list 101 | is over towers. The inner list is over the (gradient, variable) 102 | pairs computed on each tower. 103 | Returns: 104 | List of pairs of (gradient, variable) where the gradient has been 105 | averaged across all towers. 106 | """ 107 | average_grads = [] 108 | for grad_and_vars in zip(*tower_grads): 109 | # Note that each grad_and_vars looks like the following: 110 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
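# Concretely (hypothetical two-tower example with a single shared variable v):
#   tower_grads == [[(g0, v)], [(g1, v)]]
#   zip(*tower_grads) yields one grad_and_vars == ((g0, v), (g1, v))
# so the loop below stacks g0 and g1, averages them, and keeps tower 0's v.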
111 | grads = [] 112 | for g, _ in grad_and_vars: 113 | if g is None: 114 | continue 115 | expanded_g = tf.expand_dims(g, 0) 116 | grads.append(expanded_g) 117 | 118 | # Average over the 'tower' dimension. 119 | grad = tf.concat(grads, 0) 120 | grad = tf.reduce_mean(grad, 0) 121 | 122 | # Keep in mind that the Variables are redundant because they are shared 123 | # across towers. So we will just return the first tower's pointer to 124 | # the Variable. 125 | v = grad_and_vars[0][1] 126 | grad_and_var = (grad, v) 127 | average_grads.append(grad_and_var) 128 | return average_grads 129 | 130 | 131 | def train(): 132 | with tf.Graph().as_default(), tf.device('/cpu:0'): 133 | 134 | batch = tf.Variable(0, trainable=False) 135 | 136 | learning_rate = tf.train.exponential_decay( 137 | BASE_LEARNING_RATE, # base learning rate 138 | batch * batch_size, # global_var indicating the number of steps 139 | DECAY_STEP, # step size 140 | DECAY_RATE, # decay rate 141 | staircase=True # Stair-case or continuous decreasing 142 | ) 143 | learning_rate = tf.maximum(learning_rate, LEARNING_RATE_CLIP) 144 | 145 | bn_momentum = tf.train.exponential_decay( 146 | BN_INIT_DECAY, 147 | batch*batch_size, 148 | BN_DECAY_DECAY_STEP, 149 | BN_DECAY_DECAY_RATE, 150 | staircase=True) 151 | bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum) 152 | 153 | lr_op = tf.summary.scalar('learning_rate', learning_rate) 154 | batch_op = tf.summary.scalar('batch_number', batch) 155 | bn_decay_op = tf.summary.scalar('bn_decay', bn_decay) 156 | 157 | trainer = tf.train.AdamOptimizer(learning_rate) 158 | 159 | # store tensors for different gpus 160 | tower_grads = [] 161 | pointclouds_phs = [] 162 | input_label_phs = [] 163 | seg_phs = [] 164 | is_training_phs = [] 165 | 166 | with tf.variable_scope(tf.get_variable_scope()): 167 | for i in range(FLAGS.num_gpu): 168 | with tf.device('/gpu:%d' % i): 169 | with tf.name_scope('%s_%d' % (TOWER_NAME, i)) as scope: 170 | pointclouds_phs.append(tf.placeholder(tf.float32, shape=(batch_size, point_num, 3))) # for points 171 | input_label_phs.append(tf.placeholder(tf.float32, shape=(batch_size, NUM_CATEGORIES))) # for one-hot category label 172 | seg_phs.append(tf.placeholder(tf.int32, shape=(batch_size, point_num))) # for part labels 173 | is_training_phs.append(tf.placeholder(tf.bool, shape=())) 174 | 175 | seg_pred = model.get_model(pointclouds_phs[-1], input_label_phs[-1], \ 176 | is_training=is_training_phs[-1], bn_decay=bn_decay, cat_num=NUM_CATEGORIES, \ 177 | part_num=NUM_PART_CATS, batch_size=batch_size, num_point=point_num, weight_decay=FLAGS.wd) 178 | 179 | 180 | loss, per_instance_seg_loss, per_instance_seg_pred_res \ 181 | = model.get_loss(seg_pred, seg_phs[-1]) 182 | 183 | total_training_loss_ph = tf.placeholder(tf.float32, shape=()) 184 | total_testing_loss_ph = tf.placeholder(tf.float32, shape=()) 185 | 186 | seg_training_acc_ph = tf.placeholder(tf.float32, shape=()) 187 | seg_testing_acc_ph = tf.placeholder(tf.float32, shape=()) 188 | seg_testing_acc_avg_cat_ph = tf.placeholder(tf.float32, shape=()) 189 | 190 | total_train_loss_sum_op = tf.summary.scalar('total_training_loss', total_training_loss_ph) 191 | total_test_loss_sum_op = tf.summary.scalar('total_testing_loss', total_testing_loss_ph) 192 | 193 | 194 | seg_train_acc_sum_op = tf.summary.scalar('seg_training_acc', seg_training_acc_ph) 195 | seg_test_acc_sum_op = tf.summary.scalar('seg_testing_acc', seg_testing_acc_ph) 196 | seg_test_acc_avg_cat_op = tf.summary.scalar('seg_testing_acc_avg_cat', 
seg_testing_acc_avg_cat_ph) 197 | 198 | tf.get_variable_scope().reuse_variables() 199 | 200 | grads = trainer.compute_gradients(loss) 201 | 202 | tower_grads.append(grads) 203 | 204 | grads = average_gradients(tower_grads) 205 | 206 | train_op = trainer.apply_gradients(grads, global_step=batch) 207 | 208 | saver = tf.train.Saver(tf.global_variables(), sharded=True, max_to_keep=20) 209 | 210 | config = tf.ConfigProto() 211 | config.gpu_options.allow_growth = True 212 | config.allow_soft_placement = True 213 | sess = tf.Session(config=config) 214 | 215 | init = tf.group(tf.global_variables_initializer(), 216 | tf.local_variables_initializer()) 217 | sess.run(init) 218 | 219 | train_writer = tf.summary.FileWriter(SUMMARIES_FOLDER + '/train', sess.graph) 220 | test_writer = tf.summary.FileWriter(SUMMARIES_FOLDER + '/test') 221 | 222 | train_file_list = provider.getDataFiles(TRAINING_FILE_LIST) 223 | num_train_file = len(train_file_list) 224 | test_file_list = provider.getDataFiles(TESTING_FILE_LIST) 225 | num_test_file = len(test_file_list) 226 | 227 | fcmd = open(os.path.join(LOG_STORAGE_PATH, 'cmd.txt'), 'w') 228 | fcmd.write(str(FLAGS)) 229 | fcmd.close() 230 | 231 | # write logs to the disk 232 | flog = open(os.path.join(LOG_STORAGE_PATH, 'log.txt'), 'w') 233 | 234 | def train_one_epoch(train_file_idx, epoch_num): 235 | is_training = True 236 | 237 | for i in range(num_train_file): 238 | cur_train_filename = os.path.join(hdf5_data_dir, train_file_list[train_file_idx[i]]) 239 | printout(flog, 'Loading train file ' + cur_train_filename) 240 | 241 | cur_data, cur_labels, cur_seg = provider.load_h5_data_label_seg(cur_train_filename) 242 | cur_data, cur_labels, order = provider.shuffle_data(cur_data, np.squeeze(cur_labels)) 243 | cur_seg = cur_seg[order, ...] 
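# shuffle_data returns the permutation it applied, and the same order is
# reused above so the per-point part labels stay aligned with their shapes.
# Toy check (illustrative only):
#   data, labels, order = provider.shuffle_data(np.arange(4)[:, None], np.arange(4))
#   assert (labels == order).all() and (data[:, 0] == order).all()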
244 | 245 | cur_labels_one_hot = convert_label_to_one_hot(cur_labels) 246 | 247 | num_data = len(cur_labels) 248 | num_batch = num_data // (FLAGS.num_gpu * batch_size) # For all working gpus 249 | 250 | total_loss = 0.0 251 | total_seg_acc = 0.0 252 | 253 | for j in range(num_batch): 254 | begidx_0 = (2 * j) * batch_size # stride by num_gpu batches per step, so the two towers see disjoint data 255 | endidx_0 = (2 * j + 1) * batch_size 256 | begidx_1 = (2 * j + 1) * batch_size 257 | endidx_1 = (2 * j + 2) * batch_size 258 | 259 | feed_dict = { 260 | # For the first gpu 261 | pointclouds_phs[0]: cur_data[begidx_0: endidx_0, ...], 262 | input_label_phs[0]: cur_labels_one_hot[begidx_0: endidx_0, ...], 263 | seg_phs[0]: cur_seg[begidx_0: endidx_0, ...], 264 | is_training_phs[0]: is_training, 265 | # For the second gpu 266 | pointclouds_phs[1]: cur_data[begidx_1: endidx_1, ...], 267 | input_label_phs[1]: cur_labels_one_hot[begidx_1: endidx_1, ...], 268 | seg_phs[1]: cur_seg[begidx_1: endidx_1, ...], 269 | is_training_phs[1]: is_training, 270 | } 271 | 272 | 273 | # train_op is for both gpus, and the others are for gpu_1 274 | _, loss_val, per_instance_seg_loss_val, seg_pred_val, pred_seg_res \ 275 | = sess.run([train_op, loss, per_instance_seg_loss, seg_pred, per_instance_seg_pred_res], \ 276 | feed_dict=feed_dict) 277 | 278 | per_instance_part_acc = np.mean(pred_seg_res == cur_seg[begidx_1: endidx_1, ...], axis=1) 279 | average_part_acc = np.mean(per_instance_part_acc) 280 | 281 | total_loss += loss_val 282 | total_seg_acc += average_part_acc 283 | 284 | total_loss = total_loss * 1.0 / num_batch 285 | total_seg_acc = total_seg_acc * 1.0 / num_batch 286 | 287 | lr_sum, bn_decay_sum, batch_sum, train_loss_sum, train_seg_acc_sum = sess.run(\ 288 | [lr_op, bn_decay_op, batch_op, total_train_loss_sum_op, seg_train_acc_sum_op], \ 289 | feed_dict={total_training_loss_ph: total_loss, seg_training_acc_ph: total_seg_acc}) 290 | 291 | train_writer.add_summary(train_loss_sum, i + epoch_num * num_train_file) 292 | train_writer.add_summary(lr_sum, i + epoch_num * num_train_file) 293 | train_writer.add_summary(bn_decay_sum, i + epoch_num * num_train_file) 294 | train_writer.add_summary(train_seg_acc_sum, i + epoch_num * num_train_file) 295 | train_writer.add_summary(batch_sum, i + epoch_num * num_train_file) 296 | 297 | printout(flog, '\tTraining Total Mean_loss: %f' % total_loss) 298 | printout(flog, '\t\tTraining Seg Accuracy: %f' % total_seg_acc) 299 | 300 | def eval_one_epoch(epoch_num): 301 | is_training = False 302 | 303 | total_loss = 0.0 304 | total_seg_acc = 0.0 305 | total_seen = 0 306 | 307 | total_seg_acc_per_cat = np.zeros((NUM_CATEGORIES)).astype(np.float32) 308 | total_seen_per_cat = np.zeros((NUM_CATEGORIES)).astype(np.int32) 309 | 310 | for i in range(num_test_file): 311 | cur_test_filename = os.path.join(hdf5_data_dir, test_file_list[i]) 312 | printout(flog, 'Loading test file ' + cur_test_filename) 313 | 314 | cur_data, cur_labels, cur_seg = provider.load_h5_data_label_seg(cur_test_filename) 315 | cur_labels = np.squeeze(cur_labels) 316 | 317 | cur_labels_one_hot = convert_label_to_one_hot(cur_labels) 318 | 319 | num_data = len(cur_labels) 320 | num_batch = num_data // batch_size 321 | 322 | # Run on gpu_1, since the tensors used for evaluation are defined on gpu_1 323 | for j in range(num_batch): 324 | begidx = j * batch_size 325 | endidx = (j + 1) * batch_size 326 | feed_dict = { 327 | pointclouds_phs[1]: cur_data[begidx: endidx, ...], 328 | input_label_phs[1]: cur_labels_one_hot[begidx: endidx, ...], 329 | seg_phs[1]: cur_seg[begidx: endidx, ...], 330 | is_training_phs[1]: 
is_training} 331 | 332 | loss_val, per_instance_seg_loss_val, seg_pred_val, pred_seg_res \ 333 | = sess.run([loss, per_instance_seg_loss, seg_pred, per_instance_seg_pred_res], \ 334 | feed_dict=feed_dict) 335 | 336 | per_instance_part_acc = np.mean(pred_seg_res == cur_seg[begidx: endidx, ...], axis=1) 337 | average_part_acc = np.mean(per_instance_part_acc) 338 | 339 | total_seen += 1 340 | total_loss += loss_val 341 | 342 | total_seg_acc += average_part_acc 343 | 344 | for shape_idx in range(begidx, endidx): 345 | total_seen_per_cat[cur_labels[shape_idx]] += 1 346 | total_seg_acc_per_cat[cur_labels[shape_idx]] += per_instance_part_acc[shape_idx - begidx] 347 | 348 | total_loss = total_loss * 1.0 / total_seen 349 | total_seg_acc = total_seg_acc * 1.0 / total_seen 350 | 351 | test_loss_sum, test_seg_acc_sum = sess.run(\ 352 | [total_test_loss_sum_op, seg_test_acc_sum_op], \ 353 | feed_dict={total_testing_loss_ph: total_loss, \ 354 | seg_testing_acc_ph: total_seg_acc}) 355 | 356 | test_writer.add_summary(test_loss_sum, (epoch_num+1) * num_train_file-1) 357 | test_writer.add_summary(test_seg_acc_sum, (epoch_num+1) * num_train_file-1) 358 | 359 | printout(flog, '\tTesting Total Mean_loss: %f' % total_loss) 360 | printout(flog, '\t\tTesting Seg Accuracy: %f' % total_seg_acc) 361 | 362 | for cat_idx in range(NUM_CATEGORIES): 363 | if total_seen_per_cat[cat_idx] > 0: 364 | printout(flog, '\n\t\tCategory %s Object Number: %d' % (all_obj_cats[cat_idx][0], total_seen_per_cat[cat_idx])) 365 | printout(flog, '\t\tCategory %s Seg Accuracy: %f' % (all_obj_cats[cat_idx][0], total_seg_acc_per_cat[cat_idx]/total_seen_per_cat[cat_idx])) 366 | 367 | if not os.path.exists(MODEL_STORAGE_PATH): 368 | os.mkdir(MODEL_STORAGE_PATH) 369 | 370 | for epoch in range(TRAINING_EPOCHES): 371 | printout(flog, '\n<<< Testing on the test dataset ...') 372 | eval_one_epoch(epoch) 373 | 374 | printout(flog, '\n>>> Training for the epoch %d/%d ...' 
% (epoch, TRAINING_EPOCHES)) 375 | 376 | train_file_idx = np.arange(0, len(train_file_list)) 377 | np.random.shuffle(train_file_idx) 378 | 379 | train_one_epoch(train_file_idx, epoch) 380 | 381 | if epoch % 5 == 0: 382 | cp_filename = saver.save(sess, os.path.join(MODEL_STORAGE_PATH, 'epoch_' + str(epoch)+'.ckpt')) 383 | printout(flog, 'Successfully store the checkpoint model into ' + cp_filename) 384 | 385 | flog.flush() 386 | 387 | flog.close() 388 | 389 | if __name__=='__main__': 390 | train() 391 | -------------------------------------------------------------------------------- /tensorflow/part_seg/train_results/trained_models/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "epoch_200.ckpt" 2 | all_model_checkpoint_paths: "epoch_0.ckpt" 3 | all_model_checkpoint_paths: "epoch_5.ckpt" 4 | all_model_checkpoint_paths: "epoch_10.ckpt" 5 | all_model_checkpoint_paths: "epoch_15.ckpt" 6 | all_model_checkpoint_paths: "epoch_20.ckpt" 7 | all_model_checkpoint_paths: "epoch_25.ckpt" 8 | all_model_checkpoint_paths: "epoch_30.ckpt" 9 | all_model_checkpoint_paths: "epoch_35.ckpt" 10 | all_model_checkpoint_paths: "epoch_40.ckpt" 11 | all_model_checkpoint_paths: "epoch_45.ckpt" 12 | all_model_checkpoint_paths: "epoch_50.ckpt" 13 | all_model_checkpoint_paths: "epoch_55.ckpt" 14 | all_model_checkpoint_paths: "epoch_60.ckpt" 15 | all_model_checkpoint_paths: "epoch_65.ckpt" 16 | all_model_checkpoint_paths: "epoch_70.ckpt" 17 | all_model_checkpoint_paths: "epoch_75.ckpt" 18 | all_model_checkpoint_paths: "epoch_80.ckpt" 19 | all_model_checkpoint_paths: "epoch_85.ckpt" 20 | all_model_checkpoint_paths: "epoch_90.ckpt" 21 | all_model_checkpoint_paths: "epoch_95.ckpt" 22 | all_model_checkpoint_paths: "epoch_100.ckpt" 23 | all_model_checkpoint_paths: "epoch_105.ckpt" 24 | all_model_checkpoint_paths: "epoch_110.ckpt" 25 | all_model_checkpoint_paths: "epoch_115.ckpt" 26 | all_model_checkpoint_paths: "epoch_120.ckpt" 27 | all_model_checkpoint_paths: "epoch_125.ckpt" 28 | all_model_checkpoint_paths: "epoch_130.ckpt" 29 | all_model_checkpoint_paths: "epoch_135.ckpt" 30 | all_model_checkpoint_paths: "epoch_140.ckpt" 31 | all_model_checkpoint_paths: "epoch_145.ckpt" 32 | all_model_checkpoint_paths: "epoch_150.ckpt" 33 | all_model_checkpoint_paths: "epoch_155.ckpt" 34 | all_model_checkpoint_paths: "epoch_160.ckpt" 35 | all_model_checkpoint_paths: "epoch_165.ckpt" 36 | all_model_checkpoint_paths: "epoch_170.ckpt" 37 | all_model_checkpoint_paths: "epoch_175.ckpt" 38 | all_model_checkpoint_paths: "epoch_180.ckpt" 39 | all_model_checkpoint_paths: "epoch_185.ckpt" 40 | all_model_checkpoint_paths: "epoch_190.ckpt" 41 | all_model_checkpoint_paths: "epoch_195.ckpt" 42 | all_model_checkpoint_paths: "epoch_200.ckpt" 43 | -------------------------------------------------------------------------------- /tensorflow/part_seg/train_results/trained_models/epoch_175.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangYueFt/dgcnn/f765b469a67730658ba554e97dc11723a7bab628/tensorflow/part_seg/train_results/trained_models/epoch_175.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /tensorflow/part_seg/train_results/trained_models/epoch_175.ckpt.index: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WangYueFt/dgcnn/f765b469a67730658ba554e97dc11723a7bab628/tensorflow/part_seg/train_results/trained_models/epoch_175.ckpt.index -------------------------------------------------------------------------------- /tensorflow/part_seg/train_results/trained_models/epoch_175.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangYueFt/dgcnn/f765b469a67730658ba554e97dc11723a7bab628/tensorflow/part_seg/train_results/trained_models/epoch_175.ckpt.meta -------------------------------------------------------------------------------- /tensorflow/provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import h5py 5 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 6 | sys.path.append(BASE_DIR) 7 | 8 | # Download dataset for point cloud classification 9 | DATA_DIR = os.path.join(BASE_DIR, 'data') 10 | if not os.path.exists(DATA_DIR): 11 | os.mkdir(DATA_DIR) 12 | if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')): 13 | www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip' 14 | zipfile = os.path.basename(www) 15 | os.system('wget %s; unzip %s' % (www, zipfile)) 16 | os.system('mv %s %s' % (zipfile[:-4], DATA_DIR)) 17 | os.system('rm %s' % (zipfile)) 18 | 19 | 20 | def shuffle_data(data, labels): 21 | """ Shuffle data and labels. 22 | Input: 23 | data: B,N,... numpy array 24 | label: B,... numpy array 25 | Return: 26 | shuffled data, label and shuffle indices 27 | """ 28 | idx = np.arange(len(labels)) 29 | np.random.shuffle(idx) 30 | return data[idx, ...], labels[idx], idx 31 | 32 | 33 | def rotate_point_cloud(batch_data): 34 | """ Randomly rotate the point clouds to augment the dataset 35 | rotation is per shape, along the up direction 36 | Input: 37 | BxNx3 array, original batch of point clouds 38 | Return: 39 | BxNx3 array, rotated batch of point clouds 40 | """ 41 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 42 | for k in range(batch_data.shape[0]): 43 | rotation_angle = np.random.uniform() * 2 * np.pi 44 | cosval = np.cos(rotation_angle) 45 | sinval = np.sin(rotation_angle) 46 | rotation_matrix = np.array([[cosval, 0, sinval], 47 | [0, 1, 0], 48 | [-sinval, 0, cosval]]) 49 | shape_pc = batch_data[k, ...] 50 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 51 | return rotated_data 52 | 53 | 54 | def rotate_point_cloud_by_angle(batch_data, rotation_angle): 55 | """ Rotate the point cloud along the up direction by a given angle. 56 | Input: 57 | BxNx3 array, original batch of point clouds 58 | Return: 59 | BxNx3 array, rotated batch of point clouds 60 | """ 61 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 62 | for k in range(batch_data.shape[0]): 63 | #rotation_angle = np.random.uniform() * 2 * np.pi 64 | cosval = np.cos(rotation_angle) 65 | sinval = np.sin(rotation_angle) 66 | rotation_matrix = np.array([[cosval, 0, sinval], 67 | [0, 1, 0], 68 | [-sinval, 0, cosval]]) 69 | shape_pc = batch_data[k, ...] 70 | rotated_data[k, ...] 
= np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 71 | return rotated_data 72 | 73 | 74 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18): 75 | """ Randomly perturb the point clouds by small rotations 76 | Input: 77 | BxNx3 array, original batch of point clouds 78 | Return: 79 | BxNx3 array, rotated batch of point clouds 80 | """ 81 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 82 | for k in range(batch_data.shape[0]): 83 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 84 | Rx = np.array([[1,0,0], 85 | [0,np.cos(angles[0]),-np.sin(angles[0])], 86 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 87 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 88 | [0,1,0], 89 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 90 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 91 | [np.sin(angles[2]),np.cos(angles[2]),0], 92 | [0,0,1]]) 93 | R = np.dot(Rz, np.dot(Ry,Rx)) 94 | shape_pc = batch_data[k, ...] 95 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R) 96 | return rotated_data 97 | 98 | 99 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05): 100 | """ Randomly jitter points. jittering is per point. 101 | Input: 102 | BxNx3 array, original batch of point clouds 103 | Return: 104 | BxNx3 array, jittered batch of point clouds 105 | """ 106 | B, N, C = batch_data.shape 107 | assert(clip > 0) 108 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip) 109 | jittered_data += batch_data 110 | return jittered_data 111 | 112 | def shift_point_cloud(batch_data, shift_range=0.1): 113 | """ Randomly shift point cloud. Shift is per point cloud. 114 | Input: 115 | BxNx3 array, original batch of point clouds 116 | Return: 117 | BxNx3 array, shifted batch of point clouds 118 | """ 119 | B, N, C = batch_data.shape 120 | shifts = np.random.uniform(-shift_range, shift_range, (B,3)) 121 | for batch_index in range(B): 122 | batch_data[batch_index,:,:] += shifts[batch_index,:] 123 | return batch_data 124 | 125 | 126 | def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25): 127 | """ Randomly scale the point cloud. Scale is per point cloud. 128 | Input: 129 | BxNx3 array, original batch of point clouds 130 | Return: 131 | BxNx3 array, scaled batch of point clouds 132 | """ 133 | B, N, C = batch_data.shape 134 | scales = np.random.uniform(scale_low, scale_high, B) 135 | for batch_index in range(B): 136 | batch_data[batch_index,:,:] *= scales[batch_index] 137 | return batch_data 138 | 139 | def getDataFiles(list_filename): 140 | return [line.rstrip() for line in open(list_filename)] 141 | 142 | def load_h5(h5_filename): 143 | f = h5py.File(h5_filename, 'r') 144 | data = f['data'][:] 145 | label = f['label'][:] 146 | return (data, label) 147 | 148 | def loadDataFile(filename): 149 | return load_h5(filename) 150 | 151 | 152 | def load_h5_data_label_seg(h5_filename): 153 | f = h5py.File(h5_filename, 'r') 154 | data = f['data'][:] # (2048, 2048, 3) 155 | label = f['label'][:] # (2048, 1) 156 | seg = f['pid'][:] # (2048, 2048) 157 | return (data, label, seg) -------------------------------------------------------------------------------- /tensorflow/sem_seg/README.md: -------------------------------------------------------------------------------- 1 | ## Semantic segmentation of indoor scenes 2 | 3 | ### Dataset 4 | 5 | 1. Download prepared HDF5 data for training (a quick way to inspect these files is sketched after this list): 6 | ``` 7 | sh +x download_data.sh 8 | ``` 9 | 2. Download the 3D indoor parsing dataset (S3DIS Dataset) for testing and visualization. "Stanford3dDataset_v1.2_Aligned_Version.zip" of the dataset is used. Unzip the downloaded file into "dgcnn/data/", and then run 10 | ``` 11 | python collect_indoor3d_data.py 12 | ``` 13 | to generate "dgcnn/data/stanford_indoor3d"
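For reference, the downloaded HDF5 files can be inspected with `h5py`; the dataset keys follow `provider.load_h5`, while the exact file name below is illustrative:
```
import h5py
f = h5py.File('indoor3d_sem_seg_hdf5_data/ply_data_all_0.h5', 'r')
print(f['data'].shape)   # blocks x points x 9 channels (XYZ, RGB, normalized XYZ)
print(f['label'].shape)  # per-point semantic labels (13 classes)
```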
14 | 15 | ### Train 16 | 17 | We use 6-fold training: 6 models are trained, each holding out one of the 6 areas as its testing area. Training again uses 2 GPUs. To train the 6 models sequentially, run 18 | ``` 19 | sh +x train_job.sh 20 | ``` 21 | 22 | ### Evaluation 23 | 24 | 1. To generate predicted results for all 6 areas, run 25 | ``` 26 | sh +x test_job.sh 27 | ``` 28 | The model parameters are saved every 10 epochs; the checkpoint used to generate predicted results can be changed by setting "--model_path" in "test_job.sh". For example, to use the model saved after 70 epochs, set "--model_path" to "log*n*/epoch_70.ckpt" for *n* = 1, 2, ..., 6. To visualize the results, add the "--visu" flag at the end of each line in "test_job.sh". 29 | 30 | 2. To obtain overall quantitative evaluation results, run 31 | ``` 32 | python eval_iou_accuracy.py 33 | ``` 34 | -------------------------------------------------------------------------------- /tensorflow/sem_seg/batch_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 5 | ROOT_DIR = os.path.dirname(BASE_DIR) 6 | sys.path.append(BASE_DIR) 7 | from model import * 8 | import indoor3d_util 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]') 12 | parser.add_argument('--batch_size', type=int, default=1, help='Batch Size during inference [default: 1]') 13 | parser.add_argument('--num_point', type=int, default=4096, help='Point number [default: 4096]') 14 | parser.add_argument('--model_path', required=True, help='model checkpoint file path') 15 | parser.add_argument('--dump_dir', required=True, help='dump folder path') 16 | parser.add_argument('--output_filelist', required=True, help='TXT filename, filelist, each line is an output for a room') 17 | parser.add_argument('--room_data_filelist', required=True, help='TXT filename, filelist, each line is a test room data label file.') 18 | parser.add_argument('--no_clutter', action='store_true', help='If true, do not count the clutter class') 19 | parser.add_argument('--visu', action='store_true', help='Whether to output OBJ file for prediction visualization.') 20 | FLAGS = parser.parse_args() 21 | 22 | BATCH_SIZE = FLAGS.batch_size 23 | NUM_POINT = FLAGS.num_point 24 | MODEL_PATH = FLAGS.model_path 25 | GPU_INDEX = FLAGS.gpu 26 | DUMP_DIR = FLAGS.dump_dir 27 | if not os.path.exists(DUMP_DIR): os.mkdir(DUMP_DIR) 28 | LOG_FOUT = open(os.path.join(DUMP_DIR, 'log_evaluate.txt'), 'w') 29 | LOG_FOUT.write(str(FLAGS)+'\n') 30 | ROOM_PATH_LIST = [os.path.join(ROOT_DIR,line.rstrip()) for line in open(FLAGS.room_data_filelist)] 31 | 32 | NUM_CLASSES = 13 33 | 34 | def log_string(out_str): 35 | LOG_FOUT.write(out_str+'\n') 36 | LOG_FOUT.flush() 37 | print(out_str) 38 | 39 | def evaluate(): 40 | is_training = False 41 | 42 | with tf.device('/gpu:'+str(GPU_INDEX)): 43 | pointclouds_pl, labels_pl = placeholder_inputs(BATCH_SIZE, NUM_POINT) 44 | is_training_pl = tf.placeholder(tf.bool, shape=()) 45 | 46 | pred = get_model(pointclouds_pl, 
is_training_pl) 47 | loss = get_loss(pred, labels_pl) 48 | pred_softmax = tf.nn.softmax(pred) 49 | 50 | saver = tf.train.Saver() 51 | 52 | config = tf.ConfigProto() 53 | config.gpu_options.allow_growth = True 54 | config.allow_soft_placement = True 55 | sess = tf.Session(config=config) 56 | 57 | saver.restore(sess, MODEL_PATH) 58 | log_string("Model restored.") 59 | 60 | ops = {'pointclouds_pl': pointclouds_pl, 61 | 'labels_pl': labels_pl, 62 | 'is_training_pl': is_training_pl, 63 | 'pred': pred, 64 | 'pred_softmax': pred_softmax, 65 | 'loss': loss} 66 | 67 | total_correct = 0 68 | total_seen = 0 69 | fout_out_filelist = open(FLAGS.output_filelist, 'w') 70 | for room_path in ROOM_PATH_LIST: 71 | out_data_label_filename = os.path.basename(room_path)[:-4] + '_pred.txt' 72 | out_data_label_filename = os.path.join(DUMP_DIR, out_data_label_filename) 73 | out_gt_label_filename = os.path.basename(room_path)[:-4] + '_gt.txt' 74 | out_gt_label_filename = os.path.join(DUMP_DIR, out_gt_label_filename) 75 | 76 | print(room_path, out_data_label_filename) 77 | # Evaluate room one by one. 78 | a, b = eval_one_epoch(sess, ops, room_path, out_data_label_filename, out_gt_label_filename) 79 | total_correct += a 80 | total_seen += b 81 | fout_out_filelist.write(out_data_label_filename+'\n') 82 | fout_out_filelist.close() 83 | log_string('all room eval accuracy: %f'% (total_correct / float(total_seen))) 84 | 85 | def eval_one_epoch(sess, ops, room_path, out_data_label_filename, out_gt_label_filename): 86 | error_cnt = 0 87 | is_training = False 88 | total_correct = 0 89 | total_seen = 0 90 | loss_sum = 0 91 | total_seen_class = [0 for _ in range(NUM_CLASSES)] 92 | total_correct_class = [0 for _ in range(NUM_CLASSES)] 93 | 94 | if FLAGS.visu: 95 | fout = open(os.path.join(DUMP_DIR, os.path.basename(room_path)[:-4]+'_pred.obj'), 'w') 96 | fout_gt = open(os.path.join(DUMP_DIR, os.path.basename(room_path)[:-4]+'_gt.obj'), 'w') 97 | fout_real_color = open(os.path.join(DUMP_DIR, os.path.basename(room_path)[:-4]+'_real_color.obj'), 'w') 98 | fout_data_label = open(out_data_label_filename, 'w') 99 | fout_gt_label = open(out_gt_label_filename, 'w') 100 | 101 | current_data, current_label = indoor3d_util.room2blocks_wrapper_normalized(room_path, NUM_POINT) 102 | current_data = current_data[:,0:NUM_POINT,:] 103 | current_label = np.squeeze(current_label) 104 | # Get room dimension.. 
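# (The room's max XYZ is reloaded from the raw file because channels 6-8 of
# each block hold coordinates normalized by the room extent, per
# room2blocks_plus_normalized in indoor3d_util.py; the OBJ output below needs
# the points mapped back to room coordinates.)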
105 | data_label = np.load(room_path) 106 | data = data_label[:,0:6] 107 | max_room_x = max(data[:,0]) 108 | max_room_y = max(data[:,1]) 109 | max_room_z = max(data[:,2]) 110 | 111 | file_size = current_data.shape[0] 112 | num_batches = file_size // BATCH_SIZE 113 | print(file_size) 114 | 115 | 116 | for batch_idx in range(num_batches): 117 | start_idx = batch_idx * BATCH_SIZE 118 | end_idx = (batch_idx+1) * BATCH_SIZE 119 | cur_batch_size = end_idx - start_idx 120 | 121 | feed_dict = {ops['pointclouds_pl']: current_data[start_idx:end_idx, :, :], 122 | ops['labels_pl']: current_label[start_idx:end_idx], 123 | ops['is_training_pl']: is_training} 124 | loss_val, pred_val = sess.run([ops['loss'], ops['pred_softmax']], 125 | feed_dict=feed_dict) 126 | 127 | if FLAGS.no_clutter: 128 | pred_label = np.argmax(pred_val[:,:,0:12], 2) # BxN 129 | else: 130 | pred_label = np.argmax(pred_val, 2) # BxN 131 | 132 | # Save prediction labels to OBJ file 133 | for b in range(BATCH_SIZE): 134 | pts = current_data[start_idx+b, :, :] 135 | l = current_label[start_idx+b,:] 136 | pts[:,6] *= max_room_x 137 | pts[:,7] *= max_room_y 138 | pts[:,8] *= max_room_z 139 | pts[:,3:6] *= 255.0 140 | pred = pred_label[b, :] 141 | for i in range(NUM_POINT): 142 | color = indoor3d_util.g_label2color[pred[i]] 143 | color_gt = indoor3d_util.g_label2color[current_label[start_idx+b, i]] 144 | if FLAGS.visu: 145 | fout.write('v %f %f %f %d %d %d\n' % (pts[i,6], pts[i,7], pts[i,8], color[0], color[1], color[2])) 146 | fout_gt.write('v %f %f %f %d %d %d\n' % (pts[i,6], pts[i,7], pts[i,8], color_gt[0], color_gt[1], color_gt[2])) 147 | fout_data_label.write('%f %f %f %d %d %d %f %d\n' % (pts[i,6], pts[i,7], pts[i,8], pts[i,3], pts[i,4], pts[i,5], pred_val[b,i,pred[i]], pred[i])) 148 | fout_gt_label.write('%d\n' % (l[i])) 149 | 150 | correct = np.sum(pred_label == current_label[start_idx:end_idx,:]) 151 | total_correct += correct 152 | total_seen += (cur_batch_size*NUM_POINT) 153 | loss_sum += (loss_val*BATCH_SIZE) 154 | for i in range(start_idx, end_idx): 155 | for j in range(NUM_POINT): 156 | l = current_label[i, j] 157 | total_seen_class[l] += 1 158 | total_correct_class[l] += (pred_label[i-start_idx, j] == l) 159 | 160 | log_string('eval mean loss: %f' % (loss_sum / float(total_seen/NUM_POINT))) 161 | log_string('eval accuracy: %f'% (total_correct / float(total_seen))) 162 | fout_data_label.close() 163 | fout_gt_label.close() 164 | if FLAGS.visu: 165 | fout.close() 166 | fout_gt.close() 167 | return total_correct, total_seen 168 | 169 | 170 | if __name__=='__main__': 171 | with tf.Graph().as_default(): 172 | evaluate() 173 | LOG_FOUT.close() 174 | -------------------------------------------------------------------------------- /tensorflow/sem_seg/collect_indoor3d_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 4 | ROOT_DIR = os.path.dirname(BASE_DIR) 5 | sys.path.append(BASE_DIR) 6 | import indoor3d_util 7 | 8 | anno_paths = [line.rstrip() for line in open(os.path.join(BASE_DIR, 'meta/anno_paths.txt'))] 9 | anno_paths = [os.path.join(indoor3d_util.DATA_PATH, p) for p in anno_paths] 10 | 11 | output_folder = os.path.join(ROOT_DIR, 'data/stanford_indoor3d') 12 | if not os.path.exists(output_folder): 13 | os.mkdir(output_folder) 14 | 15 | # Note: there is an extra character in the v1.2 data in Area_5/hallway_6. It's fixed manually. 
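# For orientation (paths illustrative): each annotation folder becomes one
# room file, e.g. Area_1/office_2/Annotations -> Area_1_office_2.npy, where
# each row is XYZRGBL as written by indoor3d_util.collect_point_label.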
16 | for anno_path in anno_paths: 17 | print(anno_path) 18 | try: 19 | elements = anno_path.split('/') 20 | out_filename = elements[-3]+'_'+elements[-2]+'.npy' 21 | indoor3d_util.collect_point_label(anno_path, os.path.join(output_folder, out_filename), 'numpy') 22 | except Exception: 23 | print(anno_path, 'ERROR!!') 24 | -------------------------------------------------------------------------------- /tensorflow/sem_seg/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Download HDF5 for indoor 3d semantic segmentation (around 1.6GB) -> 'indoor3d_sem_seg_hdf5_data' 4 | wget https://shapenet.cs.stanford.edu/media/indoor3d_sem_seg_hdf5_data.zip 5 | unzip indoor3d_sem_seg_hdf5_data.zip 6 | rm indoor3d_sem_seg_hdf5_data.zip -------------------------------------------------------------------------------- /tensorflow/sem_seg/eval_iou_accuracy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | pred_data_label_filenames = [] 4 | for i in range(1,7): 5 | file_name = 'log{}/output_filelist.txt'.format(i) 6 | pred_data_label_filenames += [line.rstrip() for line in open(file_name)] 7 | # note: str.rstrip strips a character set, not a suffix, so slice off '_pred.txt' instead 8 | gt_label_filenames = [f[:-len('_pred.txt')] + '_gt.txt' for f in pred_data_label_filenames] 9 | 10 | num_room = len(gt_label_filenames) 11 | 12 | gt_classes = [0 for _ in range(13)] 13 | positive_classes = [0 for _ in range(13)] 14 | true_positive_classes = [0 for _ in range(13)] 15 | 16 | for i in range(num_room): 17 | print(i) 18 | data_label = np.loadtxt(pred_data_label_filenames[i]) 19 | pred_label = data_label[:,-1] 20 | gt_label = np.loadtxt(gt_label_filenames[i]) 21 | print(gt_label.shape) 22 | for j in range(gt_label.shape[0]): 23 | gt_l = int(gt_label[j]) 24 | pred_l = int(pred_label[j]) 25 | gt_classes[gt_l] += 1 26 | positive_classes[pred_l] += 1 27 | true_positive_classes[gt_l] += int(gt_l==pred_l) 28 | 29 | 30 | print(gt_classes) 31 | print(positive_classes) 32 | print(true_positive_classes) 33 | 34 | print('Overall accuracy: {0}'.format(sum(true_positive_classes)/float(sum(positive_classes)))) 35 | 36 | print('IoU:') 37 | iou_list = [] 38 | for i in range(13): 39 | iou = true_positive_classes[i]/float(gt_classes[i]+positive_classes[i]-true_positive_classes[i]) 40 | print(iou) 41 | iou_list.append(iou) 42 | 43 | print('avg IoU:') 44 | print(sum(iou_list)/13.0) 45 | -------------------------------------------------------------------------------- /tensorflow/sem_seg/indoor3d_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | import os 4 | import sys 5 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 6 | ROOT_DIR = os.path.dirname(BASE_DIR) 7 | sys.path.append(BASE_DIR) 8 | 9 | # ----------------------------------------------------------------------------- 10 | # CONSTANTS 11 | # ----------------------------------------------------------------------------- 12 | 13 | DATA_PATH = os.path.join(ROOT_DIR, 'data', 'Stanford3dDataset_v1.2_Aligned_Version') 14 | g_classes = [x.rstrip() for x in open(os.path.join(BASE_DIR, 'meta/class_names.txt'))] 15 | g_class2label = {cls: i for i,cls in enumerate(g_classes)} 16 | g_class2color = {'ceiling': [0,255,0], 17 | 'floor': [0,0,255], 18 | 'wall': [0,255,255], 19 | 'beam': [255,255,0], 20 | 'column': [255,0,255], 21 | 'window': [100,100,255], 22 | 'door': [200,200,100], 23 | 'table': [170,120,200], 24 | 'chair': [255,0,0], 25 | 'sofa': [200,100,100], 26 | 
'bookcase': [10,200,100], 27 | 'board': [200,200,200], 28 | 'clutter': [50,50,50]} 29 | g_easy_view_labels = [7,8,9,10,11,1] 30 | g_label2color = {g_classes.index(cls): g_class2color[cls] for cls in g_classes} 31 | 32 | 33 | # ----------------------------------------------------------------------------- 34 | # CONVERT ORIGINAL DATA TO OUR DATA_LABEL FILES 35 | # ----------------------------------------------------------------------------- 36 | 37 | def collect_point_label(anno_path, out_filename, file_format='txt'): 38 | """ Convert original dataset files to data_label file (each line is XYZRGBL). 39 | We aggregated all the points from each instance in the room. 40 | 41 | Args: 42 | anno_path: path to annotations. e.g. Area_1/office_2/Annotations/ 43 | out_filename: path to save collected points and labels (each line is XYZRGBL) 44 | file_format: txt or numpy, determines what file format to save. 45 | Returns: 46 | None 47 | Note: 48 | the points are shifted before save, the most negative point is now at origin. 49 | """ 50 | points_list = [] 51 | 52 | for f in glob.glob(os.path.join(anno_path, '*.txt')): 53 | cls = os.path.basename(f).split('_')[0] 54 | if cls not in g_classes: # note: in some room there is 'staris' class.. 55 | cls = 'clutter' 56 | points = np.loadtxt(f) 57 | labels = np.ones((points.shape[0],1)) * g_class2label[cls] 58 | points_list.append(np.concatenate([points, labels], 1)) # Nx7 59 | 60 | 61 | data_label = np.concatenate(points_list, 0) 62 | xyz_min = np.amin(data_label, axis=0)[0:3] 63 | data_label[:, 0:3] -= xyz_min 64 | 65 | if file_format=='txt': 66 | fout = open(out_filename, 'w') 67 | for i in range(data_label.shape[0]): 68 | fout.write('%f %f %f %d %d %d %d\n' % \ 69 | (data_label[i,0], data_label[i,1], data_label[i,2], 70 | data_label[i,3], data_label[i,4], data_label[i,5], 71 | data_label[i,6])) 72 | fout.close() 73 | elif file_format=='numpy': 74 | np.save(out_filename, data_label) 75 | else: 76 | print('ERROR!! Unknown file format: %s, please use txt or numpy.' % \ 77 | (file_format)) 78 | exit() 79 | 80 | def point_label_to_obj(input_filename, out_filename, label_color=True, easy_view=False, no_wall=False): 81 | """ For visualization of a room from data_label file, 82 | input_filename: each line is X Y Z R G B L 83 | out_filename: OBJ filename, 84 | visualize input file by coloring point with label color 85 | easy_view: only visualize furnitures and floor 86 | """ 87 | data_label = np.loadtxt(input_filename) 88 | data = data_label[:, 0:6] 89 | label = data_label[:, -1].astype(int) 90 | fout = open(out_filename, 'w') 91 | for i in range(data.shape[0]): 92 | color = g_label2color[label[i]] 93 | if easy_view and (label[i] not in g_easy_view_labels): 94 | continue 95 | if no_wall and ((label[i] == 2) or (label[i]==0)): 96 | continue 97 | if label_color: 98 | fout.write('v %f %f %f %d %d %d\n' % \ 99 | (data[i,0], data[i,1], data[i,2], color[0], color[1], color[2])) 100 | else: 101 | fout.write('v %f %f %f %d %d %d\n' % \ 102 | (data[i,0], data[i,1], data[i,2], data[i,3], data[i,4], data[i,5])) 103 | fout.close() 104 | 105 | 106 | 107 | # ----------------------------------------------------------------------------- 108 | # PREPARE BLOCK DATA FOR DEEPNETS TRAINING/TESTING 109 | # ----------------------------------------------------------------------------- 110 | 111 | def sample_data(data, num_sample): 112 | """ data is in N x ... 113 | we want to keep num_samplexC of them. 114 | if N > num_sample, we will randomly keep num_sample of them. 
115 | if N < num_sample, we will randomly duplicate samples. 116 | """ 117 | N = data.shape[0] 118 | if (N == num_sample): 119 | return data, list(range(N)) 120 | elif (N > num_sample): 121 | sample = np.random.choice(N, num_sample) 122 | return data[sample, ...], sample 123 | else: 124 | sample = np.random.choice(N, num_sample-N) 125 | dup_data = data[sample, ...] 126 | return np.concatenate([data, dup_data], 0), list(range(N)) + list(sample) 127 | 128 | def sample_data_label(data, label, num_sample): 129 | new_data, sample_indices = sample_data(data, num_sample) 130 | new_label = label[sample_indices] 131 | return new_data, new_label 132 | 133 | def room2blocks(data, label, num_point, block_size=1.0, stride=1.0, 134 | random_sample=False, sample_num=None, sample_aug=1): 135 | """ Prepare block training data. 136 | Args: 137 | data: N x 6 numpy array, 012 are XYZ in meters, 345 are RGB in [0,1] 138 | assumes the data is shifted (min point is origin) and aligned 139 | (aligned with XYZ axis) 140 | label: N size uint8 numpy array from 0-12 141 | num_point: int, how many points to sample in each block 142 | block_size: float, physical size of the block in meters 143 | stride: float, stride for block sweeping 144 | random_sample: bool, if True, we will randomly sample blocks in the room 145 | sample_num: int, if random sample, how many blocks to sample 146 | [default: room area] 147 | sample_aug: if random sample, how much aug 148 | Returns: 149 | block_datas: K x num_point x 6 np array of XYZRGB, RGB is in [0,1] 150 | block_labels: K x num_point x 1 np array of uint8 labels 151 | 152 | TODO: for this version, blocking uses a fixed, non-overlapping pattern. 153 | """ 154 | assert(stride<=block_size) 155 | 156 | limit = np.amax(data, 0)[0:3] 157 | 158 | # Get the corner location for our sampling blocks 159 | xbeg_list = [] 160 | ybeg_list = [] 161 | if not random_sample: 162 | num_block_x = int(np.ceil((limit[0] - block_size) / stride)) + 1 163 | num_block_y = int(np.ceil((limit[1] - block_size) / stride)) + 1 164 | for i in range(num_block_x): 165 | for j in range(num_block_y): 166 | xbeg_list.append(i*stride) 167 | ybeg_list.append(j*stride) 168 | else: 169 | num_block_x = int(np.ceil(limit[0] / block_size)) 170 | num_block_y = int(np.ceil(limit[1] / block_size)) 171 | if sample_num is None: 172 | sample_num = num_block_x * num_block_y * sample_aug 173 | for _ in range(sample_num): 174 | xbeg = np.random.uniform(-block_size, limit[0]) 175 | ybeg = np.random.uniform(-block_size, limit[1]) 176 | xbeg_list.append(xbeg) 177 | ybeg_list.append(ybeg) 178 | 179 | # Collect blocks 180 | block_data_list = [] 181 | block_label_list = [] 182 | idx = 0 183 | for idx in range(len(xbeg_list)): 184 | xbeg = xbeg_list[idx] 185 | ybeg = ybeg_list[idx] 186 | xcond = (data[:,0]<=xbeg+block_size) & (data[:,0]>=xbeg) 187 | ycond = (data[:,1]<=ybeg+block_size) & (data[:,1]>=ybeg) 188 | cond = xcond & ycond 189 | if np.sum(cond) < 100: # discard block if there are fewer than 100 pts.
190 | continue 191 | 192 | block_data = data[cond, :] 193 | block_label = label[cond] 194 | 195 | # randomly subsample data 196 | block_data_sampled, block_label_sampled = \ 197 | sample_data_label(block_data, block_label, num_point) 198 | block_data_list.append(np.expand_dims(block_data_sampled, 0)) 199 | block_label_list.append(np.expand_dims(block_label_sampled, 0)) 200 | 201 | return np.concatenate(block_data_list, 0), \ 202 | np.concatenate(block_label_list, 0) 203 | 204 | 205 | def room2blocks_plus(data_label, num_point, block_size, stride, 206 | random_sample, sample_num, sample_aug): 207 | """ room2block with input filename and RGB preprocessing. 208 | """ 209 | data = data_label[:,0:6] 210 | data[:,3:6] /= 255.0 211 | label = data_label[:,-1].astype(np.uint8) 212 | 213 | return room2blocks(data, label, num_point, block_size, stride, 214 | random_sample, sample_num, sample_aug) 215 | 216 | def room2blocks_wrapper(data_label_filename, num_point, block_size=1.0, stride=1.0, 217 | random_sample=False, sample_num=None, sample_aug=1): 218 | if data_label_filename[-3:] == 'txt': 219 | data_label = np.loadtxt(data_label_filename) 220 | elif data_label_filename[-3:] == 'npy': 221 | data_label = np.load(data_label_filename) 222 | else: 223 | print('Unknown file type! exiting.') 224 | exit() 225 | return room2blocks_plus(data_label, num_point, block_size, stride, 226 | random_sample, sample_num, sample_aug) 227 | 228 | def room2blocks_plus_normalized(data_label, num_point, block_size, stride, 229 | random_sample, sample_num, sample_aug): 230 | """ room2block, with input filename and RGB preprocessing. 231 | for each block centralize XYZ, add normalized XYZ as 678 channels 232 | """ 233 | data = data_label[:,0:6] 234 | data[:,3:6] /= 255.0 235 | label = data_label[:,-1].astype(np.uint8) 236 | max_room_x = max(data[:,0]) 237 | max_room_y = max(data[:,1]) 238 | max_room_z = max(data[:,2]) 239 | 240 | data_batch, label_batch = room2blocks(data, label, num_point, block_size, stride, 241 | random_sample, sample_num, sample_aug) 242 | new_data_batch = np.zeros((data_batch.shape[0], num_point, 9)) 243 | for b in range(data_batch.shape[0]): 244 | new_data_batch[b, :, 6] = data_batch[b, :, 0]/max_room_x 245 | new_data_batch[b, :, 7] = data_batch[b, :, 1]/max_room_y 246 | new_data_batch[b, :, 8] = data_batch[b, :, 2]/max_room_z 247 | minx = min(data_batch[b, :, 0]) 248 | miny = min(data_batch[b, :, 1]) 249 | data_batch[b, :, 0] -= (minx+block_size/2) 250 | data_batch[b, :, 1] -= (miny+block_size/2) 251 | new_data_batch[:, :, 0:6] = data_batch 252 | return new_data_batch, label_batch 253 | 254 | 255 | def room2blocks_wrapper_normalized(data_label_filename, num_point, block_size=1.0, stride=1.0, 256 | random_sample=False, sample_num=None, sample_aug=1): 257 | if data_label_filename[-3:] == 'txt': 258 | data_label = np.loadtxt(data_label_filename) 259 | elif data_label_filename[-3:] == 'npy': 260 | data_label = np.load(data_label_filename) 261 | else: 262 | print('Unknown file type! exiting.') 263 | exit() 264 | return room2blocks_plus_normalized(data_label, num_point, block_size, stride, 265 | random_sample, sample_num, sample_aug) 266 | 267 | def room2samples(data, label, sample_num_point): 268 | """ Prepare whole room samples. 
269 | 270 | Args: 271 | data: N x 6 numpy array, 012 are XYZ in meters, 345 are RGB in [0,1] 272 | assumes the data is shifted (min point is origin) and 273 | aligned (aligned with XYZ axis) 274 | label: N size uint8 numpy array from 0-12 275 | sample_num_point: int, how many points to sample in each sample 276 | Returns: 277 | sample_datas: K x sample_num_point x 9 278 | numpy array of XYZRGBX'Y'Z', RGB is in [0,1] 279 | sample_labels: K x sample_num_point x 1 np array of uint8 labels 280 | """ 281 | N = data.shape[0] 282 | order = np.arange(N) 283 | np.random.shuffle(order) 284 | data = data[order, :] 285 | label = label[order] 286 | 287 | batch_num = int(np.ceil(N / float(sample_num_point))) 288 | sample_datas = np.zeros((batch_num, sample_num_point, 6)) 289 | sample_labels = np.zeros((batch_num, sample_num_point, 1)) 290 | 291 | for i in range(batch_num): 292 | beg_idx = i*sample_num_point 293 | end_idx = min((i+1)*sample_num_point, N) 294 | num = end_idx - beg_idx 295 | sample_datas[i,0:num,:] = data[beg_idx:end_idx, :] 296 | sample_labels[i,0:num,0] = label[beg_idx:end_idx] 297 | if num < sample_num_point: 298 | makeup_indices = np.random.choice(N, sample_num_point - num) 299 | sample_datas[i,num:,:] = data[makeup_indices, :] 300 | sample_labels[i,num:,0] = label[makeup_indices] 301 | return sample_datas, sample_labels 302 | 303 | def room2samples_plus_normalized(data_label, num_point): 304 | """ room2sample, with input filename and RGB preprocessing. 305 | for each block centralize XYZ, add normalized XYZ as 678 channels 306 | """ 307 | data = data_label[:,0:6] 308 | data[:,3:6] /= 255.0 309 | label = data_label[:,-1].astype(np.uint8) 310 | max_room_x = max(data[:,0]) 311 | max_room_y = max(data[:,1]) 312 | max_room_z = max(data[:,2]) 313 | #print(max_room_x, max_room_y, max_room_z) 314 | 315 | data_batch, label_batch = room2samples(data, label, num_point) 316 | new_data_batch = np.zeros((data_batch.shape[0], num_point, 9)) 317 | for b in range(data_batch.shape[0]): 318 | new_data_batch[b, :, 6] = data_batch[b, :, 0]/max_room_x 319 | new_data_batch[b, :, 7] = data_batch[b, :, 1]/max_room_y 320 | new_data_batch[b, :, 8] = data_batch[b, :, 2]/max_room_z 321 | #minx = min(data_batch[b, :, 0]) 322 | #miny = min(data_batch[b, :, 1]) 323 | #data_batch[b, :, 0] -= (minx+block_size/2) 324 | #data_batch[b, :, 1] -= (miny+block_size/2) 325 | new_data_batch[:, :, 0:6] = data_batch 326 | return new_data_batch, label_batch 327 | 328 | 329 | def room2samples_wrapper_normalized(data_label_filename, num_point): 330 | if data_label_filename[-3:] == 'txt': 331 | data_label = np.loadtxt(data_label_filename) 332 | elif data_label_filename[-3:] == 'npy': 333 | data_label = np.load(data_label_filename) 334 | else: 335 | print('Unknown file type! exiting.') 336 | exit() 337 | return room2samples_plus_normalized(data_label, num_point) 338 | 339 | 340 | # ----------------------------------------------------------------------------- 341 | # EXTRACT INSTANCE BBOX FROM ORIGINAL DATA (for detection evaluation) 342 | # ----------------------------------------------------------------------------- 343 | 344 | def collect_bounding_box(anno_path, out_filename): 345 | """ Compute bounding boxes from each instance in original dataset files on 346 | one room. **We assume the bbox is aligned with XYZ coordinate.** 347 | 348 | Args: 349 | anno_path: path to annotations. e.g. Area_1/office_2/Annotations/ 350 | out_filename: path to save instance bounding boxes for that room. 
351 | each line is x1 y1 z1 x2 y2 z2 label, 352 | where (x1,y1,z1) is the point on the diagonal closer to origin 353 | Returns: 354 | None 355 | Note: 356 | room points are shifted, the most negative point is now at origin. 357 | """ 358 | bbox_label_list = [] 359 | 360 | for f in glob.glob(os.path.join(anno_path, '*.txt')): 361 | cls = os.path.basename(f).split('_')[0] 362 | if cls not in g_classes: # note: in some room there is 'staris' class.. 363 | cls = 'clutter' 364 | points = np.loadtxt(f) 365 | label = g_class2label[cls] 366 | # Compute tightest axis aligned bounding box 367 | xyz_min = np.amin(points[:, 0:3], axis=0) 368 | xyz_max = np.amax(points[:, 0:3], axis=0) 369 | ins_bbox_label = np.expand_dims( 370 | np.concatenate([xyz_min, xyz_max, np.array([label])], 0), 0) 371 | bbox_label_list.append(ins_bbox_label) 372 | 373 | bbox_label = np.concatenate(bbox_label_list, 0) 374 | room_xyz_min = np.amin(bbox_label[:, 0:3], axis=0) 375 | bbox_label[:, 0:3] -= room_xyz_min 376 | bbox_label[:, 3:6] -= room_xyz_min 377 | 378 | fout = open(out_filename, 'w') 379 | for i in range(bbox_label.shape[0]): 380 | fout.write('%f %f %f %f %f %f %d\n' % \ 381 | (bbox_label[i,0], bbox_label[i,1], bbox_label[i,2], 382 | bbox_label[i,3], bbox_label[i,4], bbox_label[i,5], 383 | bbox_label[i,6])) 384 | fout.close() 385 | 386 | def bbox_label_to_obj(input_filename, out_filename_prefix, easy_view=False): 387 | """ Visualization of bounding boxes. 388 | 389 | Args: 390 | input_filename: each line is x1 y1 z1 x2 y2 z2 label 391 | out_filename_prefix: OBJ filename prefix, 392 | visualize object by g_label2color 393 | easy_view: if True, only visualize furniture and floor 394 | Returns: 395 | output a list of OBJ file and MTL files with the same prefix 396 | """ 397 | bbox_label = np.loadtxt(input_filename) 398 | bbox = bbox_label[:, 0:6] 399 | label = bbox_label[:, -1].astype(int) 400 | v_cnt = 0 # count vertex 401 | ins_cnt = 0 # count instance 402 | for i in range(bbox.shape[0]): 403 | if easy_view and (label[i] not in g_easy_view_labels): 404 | continue 405 | obj_filename = out_filename_prefix+'_'+g_classes[label[i]]+'_'+str(ins_cnt)+'.obj' 406 | mtl_filename = out_filename_prefix+'_'+g_classes[label[i]]+'_'+str(ins_cnt)+'.mtl' 407 | fout_obj = open(obj_filename, 'w') 408 | fout_mtl = open(mtl_filename, 'w') 409 | fout_obj.write('mtllib %s\n' % (os.path.basename(mtl_filename))) 410 | 411 | length = bbox[i, 3:6] - bbox[i, 0:3] 412 | a = length[0] 413 | b = length[1] 414 | c = length[2] 415 | x = bbox[i, 0] 416 | y = bbox[i, 1] 417 | z = bbox[i, 2] 418 | color = np.array(g_label2color[label[i]], dtype=float) / 255.0 419 | 420 | material = 'material%d' % (ins_cnt) 421 | fout_obj.write('usemtl %s\n' % (material)) 422 | fout_obj.write('v %f %f %f\n' % (x,y,z+c)) 423 | fout_obj.write('v %f %f %f\n' % (x,y+b,z+c)) 424 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z+c)) 425 | fout_obj.write('v %f %f %f\n' % (x+a,y,z+c)) 426 | fout_obj.write('v %f %f %f\n' % (x,y,z)) 427 | fout_obj.write('v %f %f %f\n' % (x,y+b,z)) 428 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z)) 429 | fout_obj.write('v %f %f %f\n' % (x+a,y,z)) 430 | fout_obj.write('g default\n') 431 | v_cnt = 0 # for individual box 432 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 3+v_cnt, 2+v_cnt, 1+v_cnt)) 433 | fout_obj.write('f %d %d %d %d\n' % (1+v_cnt, 2+v_cnt, 6+v_cnt, 5+v_cnt)) 434 | fout_obj.write('f %d %d %d %d\n' % (7+v_cnt, 6+v_cnt, 2+v_cnt, 3+v_cnt)) 435 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 8+v_cnt, 7+v_cnt, 3+v_cnt)) 436 | fout_obj.write('f 
%d %d %d %d\n' % (5+v_cnt, 8+v_cnt, 4+v_cnt, 1+v_cnt)) 437 | fout_obj.write('f %d %d %d %d\n' % (5+v_cnt, 6+v_cnt, 7+v_cnt, 8+v_cnt)) 438 | fout_obj.write('\n') 439 | 440 | fout_mtl.write('newmtl %s\n' % (material)) 441 | fout_mtl.write('Kd %f %f %f\n' % (color[0], color[1], color[2])) 442 | fout_mtl.write('\n') 443 | fout_obj.close() 444 | fout_mtl.close() 445 | 446 | v_cnt += 8 447 | ins_cnt += 1 448 | 449 | def bbox_label_to_obj_room(input_filename, out_filename_prefix, easy_view=False, permute=None, center=False, exclude_table=False): 450 | """ Visualization of bounding boxes. 451 | 452 | Args: 453 | input_filename: each line is x1 y1 z1 x2 y2 z2 label 454 | out_filename_prefix: OBJ filename prefix, 455 | visualize object by g_label2color 456 | easy_view: if True, only visualize furniture and floor 457 | permute: if not None, permute XYZ for rendering, e.g. [0 2 1] 458 | center: if True, move obj to have zero origin 459 | Returns: 460 | output a list of OBJ file and MTL files with the same prefix 461 | """ 462 | bbox_label = np.loadtxt(input_filename) 463 | bbox = bbox_label[:, 0:6] 464 | if permute is not None: 465 | assert(len(permute)==3) 466 | permute = np.array(permute) 467 | bbox[:,0:3] = bbox[:,permute] 468 | bbox[:,3:6] = bbox[:,permute+3] 469 | if center: 470 | xyz_max = np.amax(bbox[:,3:6], 0) 471 | bbox[:,0:3] -= (xyz_max/2.0) 472 | bbox[:,3:6] -= (xyz_max/2.0) 473 | bbox /= np.max(xyz_max/2.0) 474 | label = bbox_label[:, -1].astype(int) 475 | obj_filename = out_filename_prefix+'.obj' 476 | mtl_filename = out_filename_prefix+'.mtl' 477 | 478 | fout_obj = open(obj_filename, 'w') 479 | fout_mtl = open(mtl_filename, 'w') 480 | fout_obj.write('mtllib %s\n' % (os.path.basename(mtl_filename))) 481 | v_cnt = 0 # count vertex 482 | ins_cnt = 0 # count instance 483 | for i in range(bbox.shape[0]): 484 | if easy_view and (label[i] not in g_easy_view_labels): 485 | continue 486 | if exclude_table and label[i] == g_classes.index('table'): 487 | continue 488 | 489 | length = bbox[i, 3:6] - bbox[i, 0:3] 490 | a = length[0] 491 | b = length[1] 492 | c = length[2] 493 | x = bbox[i, 0] 494 | y = bbox[i, 1] 495 | z = bbox[i, 2] 496 | color = np.array(g_label2color[label[i]], dtype=float) / 255.0 497 | 498 | material = 'material%d' % (ins_cnt) 499 | fout_obj.write('usemtl %s\n' % (material)) 500 | fout_obj.write('v %f %f %f\n' % (x,y,z+c)) 501 | fout_obj.write('v %f %f %f\n' % (x,y+b,z+c)) 502 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z+c)) 503 | fout_obj.write('v %f %f %f\n' % (x+a,y,z+c)) 504 | fout_obj.write('v %f %f %f\n' % (x,y,z)) 505 | fout_obj.write('v %f %f %f\n' % (x,y+b,z)) 506 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z)) 507 | fout_obj.write('v %f %f %f\n' % (x+a,y,z)) 508 | fout_obj.write('g default\n') 509 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 3+v_cnt, 2+v_cnt, 1+v_cnt)) 510 | fout_obj.write('f %d %d %d %d\n' % (1+v_cnt, 2+v_cnt, 6+v_cnt, 5+v_cnt)) 511 | fout_obj.write('f %d %d %d %d\n' % (7+v_cnt, 6+v_cnt, 2+v_cnt, 3+v_cnt)) 512 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 8+v_cnt, 7+v_cnt, 3+v_cnt)) 513 | fout_obj.write('f %d %d %d %d\n' % (5+v_cnt, 8+v_cnt, 4+v_cnt, 1+v_cnt)) 514 | fout_obj.write('f %d %d %d %d\n' % (5+v_cnt, 6+v_cnt, 7+v_cnt, 8+v_cnt)) 515 | fout_obj.write('\n') 516 | 517 | fout_mtl.write('newmtl %s\n' % (material)) 518 | fout_mtl.write('Kd %f %f %f\n' % (color[0], color[1], color[2])) 519 | fout_mtl.write('\n') 520 | 521 | v_cnt += 8 522 | ins_cnt += 1 523 | 524 | fout_obj.close() 525 | fout_mtl.close() 526 | 527 | 528 | def 
collect_point_bounding_box(anno_path, out_filename, file_format):
529 |     """ Compute bounding boxes from each instance in original dataset files on
530 |     one room. **We assume the bbox is aligned with XYZ coordinate.**
531 |     Save both the point XYZRGB and the bounding box for the point's
532 |     parent element.
533 | 
534 |     Args:
535 |         anno_path: path to annotations. e.g. Area_1/office_2/Annotations/
536 |         out_filename: path to save instance bounding boxes for each point,
537 |             plus the point's XYZRGBL
538 |             each line is XYZRGBL offsetX offsetY offsetZ a b c,
539 |             where cx = X+offsetX, cy = Y+offsetY, cz = Z+offsetZ
540 |             where (cx,cy,cz) is the center of the box and a,b,c are distances from the center
541 |             to the surfaces of the box, i.e. x1 = cx-a, x2 = cx+a, y1 = cy-b etc.
542 |         file_format: output file format, txt or numpy
543 |     Returns:
544 |         None
545 | 
546 |     Note:
547 |         room points are shifted; the most negative point is now at origin.
548 |     """
549 |     point_bbox_list = []
550 | 
551 |     for f in glob.glob(os.path.join(anno_path, '*.txt')):
552 |         cls = os.path.basename(f).split('_')[0]
553 |         if cls not in g_classes: # note: some rooms contain a misspelled 'staris' class
554 |             cls = 'clutter'
555 |         points = np.loadtxt(f) # Nx6
556 |         label = g_class2label[cls] # scalar
557 |         # Compute tightest axis aligned bounding box
558 |         xyz_min = np.amin(points[:, 0:3], axis=0) # 3,
559 |         xyz_max = np.amax(points[:, 0:3], axis=0) # 3,
560 |         xyz_center = (xyz_min + xyz_max) / 2
561 |         dimension = (xyz_max - xyz_min) / 2
562 | 
563 |         xyz_offsets = xyz_center - points[:,0:3] # Nx3
564 |         dimensions = np.ones((points.shape[0],3)) * dimension # Nx3
565 |         labels = np.ones((points.shape[0],1)) * label # Nx1
566 |         point_bbox_list.append(np.concatenate([points, labels,
567 |             xyz_offsets, dimensions], 1)) # Nx13
568 | 
569 |     point_bbox = np.concatenate(point_bbox_list, 0) # (total points)x13
570 |     room_xyz_min = np.amin(point_bbox[:, 0:3], axis=0)
571 |     point_bbox[:, 0:3] -= room_xyz_min
572 | 
573 |     if file_format == 'txt':
574 |         fout = open(out_filename, 'w')
575 |         for i in range(point_bbox.shape[0]):
576 |             fout.write('%f %f %f %d %d %d %d %f %f %f %f %f %f\n' % \
577 |                 (point_bbox[i,0], point_bbox[i,1], point_bbox[i,2],
578 |                  point_bbox[i,3], point_bbox[i,4], point_bbox[i,5],
579 |                  point_bbox[i,6],
580 |                  point_bbox[i,7], point_bbox[i,8], point_bbox[i,9],
581 |                  point_bbox[i,10], point_bbox[i,11], point_bbox[i,12]))
582 | 
583 |         fout.close()
584 |     elif file_format == 'numpy':
585 |         np.save(out_filename, point_bbox)
586 |     else:
587 |         print('ERROR!! Unknown file format: %s, please use txt or numpy.'
% \ 588 | (file_format)) 589 | exit() 590 | 591 | 592 | -------------------------------------------------------------------------------- /tensorflow/sem_seg/meta/all_data_label.txt: -------------------------------------------------------------------------------- 1 | Area_1_conferenceRoom_1.npy 2 | Area_1_conferenceRoom_2.npy 3 | Area_1_copyRoom_1.npy 4 | Area_1_hallway_1.npy 5 | Area_1_hallway_2.npy 6 | Area_1_hallway_3.npy 7 | Area_1_hallway_4.npy 8 | Area_1_hallway_5.npy 9 | Area_1_hallway_6.npy 10 | Area_1_hallway_7.npy 11 | Area_1_hallway_8.npy 12 | Area_1_office_10.npy 13 | Area_1_office_11.npy 14 | Area_1_office_12.npy 15 | Area_1_office_13.npy 16 | Area_1_office_14.npy 17 | Area_1_office_15.npy 18 | Area_1_office_16.npy 19 | Area_1_office_17.npy 20 | Area_1_office_18.npy 21 | Area_1_office_19.npy 22 | Area_1_office_1.npy 23 | Area_1_office_20.npy 24 | Area_1_office_21.npy 25 | Area_1_office_22.npy 26 | Area_1_office_23.npy 27 | Area_1_office_24.npy 28 | Area_1_office_25.npy 29 | Area_1_office_26.npy 30 | Area_1_office_27.npy 31 | Area_1_office_28.npy 32 | Area_1_office_29.npy 33 | Area_1_office_2.npy 34 | Area_1_office_30.npy 35 | Area_1_office_31.npy 36 | Area_1_office_3.npy 37 | Area_1_office_4.npy 38 | Area_1_office_5.npy 39 | Area_1_office_6.npy 40 | Area_1_office_7.npy 41 | Area_1_office_8.npy 42 | Area_1_office_9.npy 43 | Area_1_pantry_1.npy 44 | Area_1_WC_1.npy 45 | Area_2_auditorium_1.npy 46 | Area_2_auditorium_2.npy 47 | Area_2_conferenceRoom_1.npy 48 | Area_2_hallway_10.npy 49 | Area_2_hallway_11.npy 50 | Area_2_hallway_12.npy 51 | Area_2_hallway_1.npy 52 | Area_2_hallway_2.npy 53 | Area_2_hallway_3.npy 54 | Area_2_hallway_4.npy 55 | Area_2_hallway_5.npy 56 | Area_2_hallway_6.npy 57 | Area_2_hallway_7.npy 58 | Area_2_hallway_8.npy 59 | Area_2_hallway_9.npy 60 | Area_2_office_10.npy 61 | Area_2_office_11.npy 62 | Area_2_office_12.npy 63 | Area_2_office_13.npy 64 | Area_2_office_14.npy 65 | Area_2_office_1.npy 66 | Area_2_office_2.npy 67 | Area_2_office_3.npy 68 | Area_2_office_4.npy 69 | Area_2_office_5.npy 70 | Area_2_office_6.npy 71 | Area_2_office_7.npy 72 | Area_2_office_8.npy 73 | Area_2_office_9.npy 74 | Area_2_storage_1.npy 75 | Area_2_storage_2.npy 76 | Area_2_storage_3.npy 77 | Area_2_storage_4.npy 78 | Area_2_storage_5.npy 79 | Area_2_storage_6.npy 80 | Area_2_storage_7.npy 81 | Area_2_storage_8.npy 82 | Area_2_storage_9.npy 83 | Area_2_WC_1.npy 84 | Area_2_WC_2.npy 85 | Area_3_conferenceRoom_1.npy 86 | Area_3_hallway_1.npy 87 | Area_3_hallway_2.npy 88 | Area_3_hallway_3.npy 89 | Area_3_hallway_4.npy 90 | Area_3_hallway_5.npy 91 | Area_3_hallway_6.npy 92 | Area_3_lounge_1.npy 93 | Area_3_lounge_2.npy 94 | Area_3_office_10.npy 95 | Area_3_office_1.npy 96 | Area_3_office_2.npy 97 | Area_3_office_3.npy 98 | Area_3_office_4.npy 99 | Area_3_office_5.npy 100 | Area_3_office_6.npy 101 | Area_3_office_7.npy 102 | Area_3_office_8.npy 103 | Area_3_office_9.npy 104 | Area_3_storage_1.npy 105 | Area_3_storage_2.npy 106 | Area_3_WC_1.npy 107 | Area_3_WC_2.npy 108 | Area_4_conferenceRoom_1.npy 109 | Area_4_conferenceRoom_2.npy 110 | Area_4_conferenceRoom_3.npy 111 | Area_4_hallway_10.npy 112 | Area_4_hallway_11.npy 113 | Area_4_hallway_12.npy 114 | Area_4_hallway_13.npy 115 | Area_4_hallway_14.npy 116 | Area_4_hallway_1.npy 117 | Area_4_hallway_2.npy 118 | Area_4_hallway_3.npy 119 | Area_4_hallway_4.npy 120 | Area_4_hallway_5.npy 121 | Area_4_hallway_6.npy 122 | Area_4_hallway_7.npy 123 | Area_4_hallway_8.npy 124 | Area_4_hallway_9.npy 125 | Area_4_lobby_1.npy 126 | 
Area_4_lobby_2.npy 127 | Area_4_office_10.npy 128 | Area_4_office_11.npy 129 | Area_4_office_12.npy 130 | Area_4_office_13.npy 131 | Area_4_office_14.npy 132 | Area_4_office_15.npy 133 | Area_4_office_16.npy 134 | Area_4_office_17.npy 135 | Area_4_office_18.npy 136 | Area_4_office_19.npy 137 | Area_4_office_1.npy 138 | Area_4_office_20.npy 139 | Area_4_office_21.npy 140 | Area_4_office_22.npy 141 | Area_4_office_2.npy 142 | Area_4_office_3.npy 143 | Area_4_office_4.npy 144 | Area_4_office_5.npy 145 | Area_4_office_6.npy 146 | Area_4_office_7.npy 147 | Area_4_office_8.npy 148 | Area_4_office_9.npy 149 | Area_4_storage_1.npy 150 | Area_4_storage_2.npy 151 | Area_4_storage_3.npy 152 | Area_4_storage_4.npy 153 | Area_4_WC_1.npy 154 | Area_4_WC_2.npy 155 | Area_4_WC_3.npy 156 | Area_4_WC_4.npy 157 | Area_5_conferenceRoom_1.npy 158 | Area_5_conferenceRoom_2.npy 159 | Area_5_conferenceRoom_3.npy 160 | Area_5_hallway_10.npy 161 | Area_5_hallway_11.npy 162 | Area_5_hallway_12.npy 163 | Area_5_hallway_13.npy 164 | Area_5_hallway_14.npy 165 | Area_5_hallway_15.npy 166 | Area_5_hallway_1.npy 167 | Area_5_hallway_2.npy 168 | Area_5_hallway_3.npy 169 | Area_5_hallway_4.npy 170 | Area_5_hallway_5.npy 171 | Area_5_hallway_6.npy 172 | Area_5_hallway_7.npy 173 | Area_5_hallway_8.npy 174 | Area_5_hallway_9.npy 175 | Area_5_lobby_1.npy 176 | Area_5_office_10.npy 177 | Area_5_office_11.npy 178 | Area_5_office_12.npy 179 | Area_5_office_13.npy 180 | Area_5_office_14.npy 181 | Area_5_office_15.npy 182 | Area_5_office_16.npy 183 | Area_5_office_17.npy 184 | Area_5_office_18.npy 185 | Area_5_office_19.npy 186 | Area_5_office_1.npy 187 | Area_5_office_20.npy 188 | Area_5_office_21.npy 189 | Area_5_office_22.npy 190 | Area_5_office_23.npy 191 | Area_5_office_24.npy 192 | Area_5_office_25.npy 193 | Area_5_office_26.npy 194 | Area_5_office_27.npy 195 | Area_5_office_28.npy 196 | Area_5_office_29.npy 197 | Area_5_office_2.npy 198 | Area_5_office_30.npy 199 | Area_5_office_31.npy 200 | Area_5_office_32.npy 201 | Area_5_office_33.npy 202 | Area_5_office_34.npy 203 | Area_5_office_35.npy 204 | Area_5_office_36.npy 205 | Area_5_office_37.npy 206 | Area_5_office_38.npy 207 | Area_5_office_39.npy 208 | Area_5_office_3.npy 209 | Area_5_office_40.npy 210 | Area_5_office_41.npy 211 | Area_5_office_42.npy 212 | Area_5_office_4.npy 213 | Area_5_office_5.npy 214 | Area_5_office_6.npy 215 | Area_5_office_7.npy 216 | Area_5_office_8.npy 217 | Area_5_office_9.npy 218 | Area_5_pantry_1.npy 219 | Area_5_storage_1.npy 220 | Area_5_storage_2.npy 221 | Area_5_storage_3.npy 222 | Area_5_storage_4.npy 223 | Area_5_WC_1.npy 224 | Area_5_WC_2.npy 225 | Area_6_conferenceRoom_1.npy 226 | Area_6_copyRoom_1.npy 227 | Area_6_hallway_1.npy 228 | Area_6_hallway_2.npy 229 | Area_6_hallway_3.npy 230 | Area_6_hallway_4.npy 231 | Area_6_hallway_5.npy 232 | Area_6_hallway_6.npy 233 | Area_6_lounge_1.npy 234 | Area_6_office_10.npy 235 | Area_6_office_11.npy 236 | Area_6_office_12.npy 237 | Area_6_office_13.npy 238 | Area_6_office_14.npy 239 | Area_6_office_15.npy 240 | Area_6_office_16.npy 241 | Area_6_office_17.npy 242 | Area_6_office_18.npy 243 | Area_6_office_19.npy 244 | Area_6_office_1.npy 245 | Area_6_office_20.npy 246 | Area_6_office_21.npy 247 | Area_6_office_22.npy 248 | Area_6_office_23.npy 249 | Area_6_office_24.npy 250 | Area_6_office_25.npy 251 | Area_6_office_26.npy 252 | Area_6_office_27.npy 253 | Area_6_office_28.npy 254 | Area_6_office_29.npy 255 | Area_6_office_2.npy 256 | Area_6_office_30.npy 257 | Area_6_office_31.npy 258 | 
Area_6_office_32.npy 259 | Area_6_office_33.npy 260 | Area_6_office_34.npy 261 | Area_6_office_35.npy 262 | Area_6_office_36.npy 263 | Area_6_office_37.npy 264 | Area_6_office_3.npy 265 | Area_6_office_4.npy 266 | Area_6_office_5.npy 267 | Area_6_office_6.npy 268 | Area_6_office_7.npy 269 | Area_6_office_8.npy 270 | Area_6_office_9.npy 271 | Area_6_openspace_1.npy 272 | Area_6_pantry_1.npy 273 | -------------------------------------------------------------------------------- /tensorflow/sem_seg/meta/anno_paths.txt: -------------------------------------------------------------------------------- 1 | Area_1/conferenceRoom_1/Annotations 2 | Area_1/conferenceRoom_2/Annotations 3 | Area_1/copyRoom_1/Annotations 4 | Area_1/hallway_1/Annotations 5 | Area_1/hallway_2/Annotations 6 | Area_1/hallway_3/Annotations 7 | Area_1/hallway_4/Annotations 8 | Area_1/hallway_5/Annotations 9 | Area_1/hallway_6/Annotations 10 | Area_1/hallway_7/Annotations 11 | Area_1/hallway_8/Annotations 12 | Area_1/office_10/Annotations 13 | Area_1/office_11/Annotations 14 | Area_1/office_12/Annotations 15 | Area_1/office_13/Annotations 16 | Area_1/office_14/Annotations 17 | Area_1/office_15/Annotations 18 | Area_1/office_16/Annotations 19 | Area_1/office_17/Annotations 20 | Area_1/office_18/Annotations 21 | Area_1/office_19/Annotations 22 | Area_1/office_1/Annotations 23 | Area_1/office_20/Annotations 24 | Area_1/office_21/Annotations 25 | Area_1/office_22/Annotations 26 | Area_1/office_23/Annotations 27 | Area_1/office_24/Annotations 28 | Area_1/office_25/Annotations 29 | Area_1/office_26/Annotations 30 | Area_1/office_27/Annotations 31 | Area_1/office_28/Annotations 32 | Area_1/office_29/Annotations 33 | Area_1/office_2/Annotations 34 | Area_1/office_30/Annotations 35 | Area_1/office_31/Annotations 36 | Area_1/office_3/Annotations 37 | Area_1/office_4/Annotations 38 | Area_1/office_5/Annotations 39 | Area_1/office_6/Annotations 40 | Area_1/office_7/Annotations 41 | Area_1/office_8/Annotations 42 | Area_1/office_9/Annotations 43 | Area_1/pantry_1/Annotations 44 | Area_1/WC_1/Annotations 45 | Area_2/auditorium_1/Annotations 46 | Area_2/auditorium_2/Annotations 47 | Area_2/conferenceRoom_1/Annotations 48 | Area_2/hallway_10/Annotations 49 | Area_2/hallway_11/Annotations 50 | Area_2/hallway_12/Annotations 51 | Area_2/hallway_1/Annotations 52 | Area_2/hallway_2/Annotations 53 | Area_2/hallway_3/Annotations 54 | Area_2/hallway_4/Annotations 55 | Area_2/hallway_5/Annotations 56 | Area_2/hallway_6/Annotations 57 | Area_2/hallway_7/Annotations 58 | Area_2/hallway_8/Annotations 59 | Area_2/hallway_9/Annotations 60 | Area_2/office_10/Annotations 61 | Area_2/office_11/Annotations 62 | Area_2/office_12/Annotations 63 | Area_2/office_13/Annotations 64 | Area_2/office_14/Annotations 65 | Area_2/office_1/Annotations 66 | Area_2/office_2/Annotations 67 | Area_2/office_3/Annotations 68 | Area_2/office_4/Annotations 69 | Area_2/office_5/Annotations 70 | Area_2/office_6/Annotations 71 | Area_2/office_7/Annotations 72 | Area_2/office_8/Annotations 73 | Area_2/office_9/Annotations 74 | Area_2/storage_1/Annotations 75 | Area_2/storage_2/Annotations 76 | Area_2/storage_3/Annotations 77 | Area_2/storage_4/Annotations 78 | Area_2/storage_5/Annotations 79 | Area_2/storage_6/Annotations 80 | Area_2/storage_7/Annotations 81 | Area_2/storage_8/Annotations 82 | Area_2/storage_9/Annotations 83 | Area_2/WC_1/Annotations 84 | Area_2/WC_2/Annotations 85 | Area_3/conferenceRoom_1/Annotations 86 | Area_3/hallway_1/Annotations 87 | Area_3/hallway_2/Annotations 
88 | Area_3/hallway_3/Annotations 89 | Area_3/hallway_4/Annotations 90 | Area_3/hallway_5/Annotations 91 | Area_3/hallway_6/Annotations 92 | Area_3/lounge_1/Annotations 93 | Area_3/lounge_2/Annotations 94 | Area_3/office_10/Annotations 95 | Area_3/office_1/Annotations 96 | Area_3/office_2/Annotations 97 | Area_3/office_3/Annotations 98 | Area_3/office_4/Annotations 99 | Area_3/office_5/Annotations 100 | Area_3/office_6/Annotations 101 | Area_3/office_7/Annotations 102 | Area_3/office_8/Annotations 103 | Area_3/office_9/Annotations 104 | Area_3/storage_1/Annotations 105 | Area_3/storage_2/Annotations 106 | Area_3/WC_1/Annotations 107 | Area_3/WC_2/Annotations 108 | Area_4/conferenceRoom_1/Annotations 109 | Area_4/conferenceRoom_2/Annotations 110 | Area_4/conferenceRoom_3/Annotations 111 | Area_4/hallway_10/Annotations 112 | Area_4/hallway_11/Annotations 113 | Area_4/hallway_12/Annotations 114 | Area_4/hallway_13/Annotations 115 | Area_4/hallway_14/Annotations 116 | Area_4/hallway_1/Annotations 117 | Area_4/hallway_2/Annotations 118 | Area_4/hallway_3/Annotations 119 | Area_4/hallway_4/Annotations 120 | Area_4/hallway_5/Annotations 121 | Area_4/hallway_6/Annotations 122 | Area_4/hallway_7/Annotations 123 | Area_4/hallway_8/Annotations 124 | Area_4/hallway_9/Annotations 125 | Area_4/lobby_1/Annotations 126 | Area_4/lobby_2/Annotations 127 | Area_4/office_10/Annotations 128 | Area_4/office_11/Annotations 129 | Area_4/office_12/Annotations 130 | Area_4/office_13/Annotations 131 | Area_4/office_14/Annotations 132 | Area_4/office_15/Annotations 133 | Area_4/office_16/Annotations 134 | Area_4/office_17/Annotations 135 | Area_4/office_18/Annotations 136 | Area_4/office_19/Annotations 137 | Area_4/office_1/Annotations 138 | Area_4/office_20/Annotations 139 | Area_4/office_21/Annotations 140 | Area_4/office_22/Annotations 141 | Area_4/office_2/Annotations 142 | Area_4/office_3/Annotations 143 | Area_4/office_4/Annotations 144 | Area_4/office_5/Annotations 145 | Area_4/office_6/Annotations 146 | Area_4/office_7/Annotations 147 | Area_4/office_8/Annotations 148 | Area_4/office_9/Annotations 149 | Area_4/storage_1/Annotations 150 | Area_4/storage_2/Annotations 151 | Area_4/storage_3/Annotations 152 | Area_4/storage_4/Annotations 153 | Area_4/WC_1/Annotations 154 | Area_4/WC_2/Annotations 155 | Area_4/WC_3/Annotations 156 | Area_4/WC_4/Annotations 157 | Area_5/conferenceRoom_1/Annotations 158 | Area_5/conferenceRoom_2/Annotations 159 | Area_5/conferenceRoom_3/Annotations 160 | Area_5/hallway_10/Annotations 161 | Area_5/hallway_11/Annotations 162 | Area_5/hallway_12/Annotations 163 | Area_5/hallway_13/Annotations 164 | Area_5/hallway_14/Annotations 165 | Area_5/hallway_15/Annotations 166 | Area_5/hallway_1/Annotations 167 | Area_5/hallway_2/Annotations 168 | Area_5/hallway_3/Annotations 169 | Area_5/hallway_4/Annotations 170 | Area_5/hallway_5/Annotations 171 | Area_5/hallway_6/Annotations 172 | Area_5/hallway_7/Annotations 173 | Area_5/hallway_8/Annotations 174 | Area_5/hallway_9/Annotations 175 | Area_5/lobby_1/Annotations 176 | Area_5/office_10/Annotations 177 | Area_5/office_11/Annotations 178 | Area_5/office_12/Annotations 179 | Area_5/office_13/Annotations 180 | Area_5/office_14/Annotations 181 | Area_5/office_15/Annotations 182 | Area_5/office_16/Annotations 183 | Area_5/office_17/Annotations 184 | Area_5/office_18/Annotations 185 | Area_5/office_19/Annotations 186 | Area_5/office_1/Annotations 187 | Area_5/office_20/Annotations 188 | Area_5/office_21/Annotations 189 | Area_5/office_22/Annotations 
190 | Area_5/office_23/Annotations 191 | Area_5/office_24/Annotations 192 | Area_5/office_25/Annotations 193 | Area_5/office_26/Annotations 194 | Area_5/office_27/Annotations 195 | Area_5/office_28/Annotations 196 | Area_5/office_29/Annotations 197 | Area_5/office_2/Annotations 198 | Area_5/office_30/Annotations 199 | Area_5/office_31/Annotations 200 | Area_5/office_32/Annotations 201 | Area_5/office_33/Annotations 202 | Area_5/office_34/Annotations 203 | Area_5/office_35/Annotations 204 | Area_5/office_36/Annotations 205 | Area_5/office_37/Annotations 206 | Area_5/office_38/Annotations 207 | Area_5/office_39/Annotations 208 | Area_5/office_3/Annotations 209 | Area_5/office_40/Annotations 210 | Area_5/office_41/Annotations 211 | Area_5/office_42/Annotations 212 | Area_5/office_4/Annotations 213 | Area_5/office_5/Annotations 214 | Area_5/office_6/Annotations 215 | Area_5/office_7/Annotations 216 | Area_5/office_8/Annotations 217 | Area_5/office_9/Annotations 218 | Area_5/pantry_1/Annotations 219 | Area_5/storage_1/Annotations 220 | Area_5/storage_2/Annotations 221 | Area_5/storage_3/Annotations 222 | Area_5/storage_4/Annotations 223 | Area_5/WC_1/Annotations 224 | Area_5/WC_2/Annotations 225 | Area_6/conferenceRoom_1/Annotations 226 | Area_6/copyRoom_1/Annotations 227 | Area_6/hallway_1/Annotations 228 | Area_6/hallway_2/Annotations 229 | Area_6/hallway_3/Annotations 230 | Area_6/hallway_4/Annotations 231 | Area_6/hallway_5/Annotations 232 | Area_6/hallway_6/Annotations 233 | Area_6/lounge_1/Annotations 234 | Area_6/office_10/Annotations 235 | Area_6/office_11/Annotations 236 | Area_6/office_12/Annotations 237 | Area_6/office_13/Annotations 238 | Area_6/office_14/Annotations 239 | Area_6/office_15/Annotations 240 | Area_6/office_16/Annotations 241 | Area_6/office_17/Annotations 242 | Area_6/office_18/Annotations 243 | Area_6/office_19/Annotations 244 | Area_6/office_1/Annotations 245 | Area_6/office_20/Annotations 246 | Area_6/office_21/Annotations 247 | Area_6/office_22/Annotations 248 | Area_6/office_23/Annotations 249 | Area_6/office_24/Annotations 250 | Area_6/office_25/Annotations 251 | Area_6/office_26/Annotations 252 | Area_6/office_27/Annotations 253 | Area_6/office_28/Annotations 254 | Area_6/office_29/Annotations 255 | Area_6/office_2/Annotations 256 | Area_6/office_30/Annotations 257 | Area_6/office_31/Annotations 258 | Area_6/office_32/Annotations 259 | Area_6/office_33/Annotations 260 | Area_6/office_34/Annotations 261 | Area_6/office_35/Annotations 262 | Area_6/office_36/Annotations 263 | Area_6/office_37/Annotations 264 | Area_6/office_3/Annotations 265 | Area_6/office_4/Annotations 266 | Area_6/office_5/Annotations 267 | Area_6/office_6/Annotations 268 | Area_6/office_7/Annotations 269 | Area_6/office_8/Annotations 270 | Area_6/office_9/Annotations 271 | Area_6/openspace_1/Annotations 272 | Area_6/pantry_1/Annotations 273 | -------------------------------------------------------------------------------- /tensorflow/sem_seg/meta/area1_data_label.txt: -------------------------------------------------------------------------------- 1 | data/stanford_indoor3d/Area_1_conferenceRoom_1.npy 2 | data/stanford_indoor3d/Area_1_conferenceRoom_2.npy 3 | data/stanford_indoor3d/Area_1_copyRoom_1.npy 4 | data/stanford_indoor3d/Area_1_hallway_1.npy 5 | data/stanford_indoor3d/Area_1_hallway_2.npy 6 | data/stanford_indoor3d/Area_1_hallway_3.npy 7 | data/stanford_indoor3d/Area_1_hallway_4.npy 8 | data/stanford_indoor3d/Area_1_hallway_5.npy 9 | data/stanford_indoor3d/Area_1_hallway_6.npy 10 | 
data/stanford_indoor3d/Area_1_hallway_7.npy 11 | data/stanford_indoor3d/Area_1_hallway_8.npy 12 | data/stanford_indoor3d/Area_1_office_10.npy 13 | data/stanford_indoor3d/Area_1_office_11.npy 14 | data/stanford_indoor3d/Area_1_office_12.npy 15 | data/stanford_indoor3d/Area_1_office_13.npy 16 | data/stanford_indoor3d/Area_1_office_14.npy 17 | data/stanford_indoor3d/Area_1_office_15.npy 18 | data/stanford_indoor3d/Area_1_office_16.npy 19 | data/stanford_indoor3d/Area_1_office_17.npy 20 | data/stanford_indoor3d/Area_1_office_18.npy 21 | data/stanford_indoor3d/Area_1_office_19.npy 22 | data/stanford_indoor3d/Area_1_office_1.npy 23 | data/stanford_indoor3d/Area_1_office_20.npy 24 | data/stanford_indoor3d/Area_1_office_21.npy 25 | data/stanford_indoor3d/Area_1_office_22.npy 26 | data/stanford_indoor3d/Area_1_office_23.npy 27 | data/stanford_indoor3d/Area_1_office_24.npy 28 | data/stanford_indoor3d/Area_1_office_25.npy 29 | data/stanford_indoor3d/Area_1_office_26.npy 30 | data/stanford_indoor3d/Area_1_office_27.npy 31 | data/stanford_indoor3d/Area_1_office_28.npy 32 | data/stanford_indoor3d/Area_1_office_29.npy 33 | data/stanford_indoor3d/Area_1_office_2.npy 34 | data/stanford_indoor3d/Area_1_office_30.npy 35 | data/stanford_indoor3d/Area_1_office_31.npy 36 | data/stanford_indoor3d/Area_1_office_3.npy 37 | data/stanford_indoor3d/Area_1_office_4.npy 38 | data/stanford_indoor3d/Area_1_office_5.npy 39 | data/stanford_indoor3d/Area_1_office_6.npy 40 | data/stanford_indoor3d/Area_1_office_7.npy 41 | data/stanford_indoor3d/Area_1_office_8.npy 42 | data/stanford_indoor3d/Area_1_office_9.npy 43 | data/stanford_indoor3d/Area_1_pantry_1.npy 44 | data/stanford_indoor3d/Area_1_WC_1.npy 45 | -------------------------------------------------------------------------------- /tensorflow/sem_seg/meta/area2_data_label.txt: -------------------------------------------------------------------------------- 1 | data/stanford_indoor3d/Area_2_auditorium_1.npy 2 | data/stanford_indoor3d/Area_2_auditorium_2.npy 3 | data/stanford_indoor3d/Area_2_conferenceRoom_1.npy 4 | data/stanford_indoor3d/Area_2_hallway_10.npy 5 | data/stanford_indoor3d/Area_2_hallway_11.npy 6 | data/stanford_indoor3d/Area_2_hallway_12.npy 7 | data/stanford_indoor3d/Area_2_hallway_1.npy 8 | data/stanford_indoor3d/Area_2_hallway_2.npy 9 | data/stanford_indoor3d/Area_2_hallway_3.npy 10 | data/stanford_indoor3d/Area_2_hallway_4.npy 11 | data/stanford_indoor3d/Area_2_hallway_5.npy 12 | data/stanford_indoor3d/Area_2_hallway_6.npy 13 | data/stanford_indoor3d/Area_2_hallway_7.npy 14 | data/stanford_indoor3d/Area_2_hallway_8.npy 15 | data/stanford_indoor3d/Area_2_hallway_9.npy 16 | data/stanford_indoor3d/Area_2_office_10.npy 17 | data/stanford_indoor3d/Area_2_office_11.npy 18 | data/stanford_indoor3d/Area_2_office_12.npy 19 | data/stanford_indoor3d/Area_2_office_13.npy 20 | data/stanford_indoor3d/Area_2_office_14.npy 21 | data/stanford_indoor3d/Area_2_office_1.npy 22 | data/stanford_indoor3d/Area_2_office_2.npy 23 | data/stanford_indoor3d/Area_2_office_3.npy 24 | data/stanford_indoor3d/Area_2_office_4.npy 25 | data/stanford_indoor3d/Area_2_office_5.npy 26 | data/stanford_indoor3d/Area_2_office_6.npy 27 | data/stanford_indoor3d/Area_2_office_7.npy 28 | data/stanford_indoor3d/Area_2_office_8.npy 29 | data/stanford_indoor3d/Area_2_office_9.npy 30 | data/stanford_indoor3d/Area_2_storage_1.npy 31 | data/stanford_indoor3d/Area_2_storage_2.npy 32 | data/stanford_indoor3d/Area_2_storage_3.npy 33 | data/stanford_indoor3d/Area_2_storage_4.npy 34 | 
data/stanford_indoor3d/Area_2_storage_5.npy 35 | data/stanford_indoor3d/Area_2_storage_6.npy 36 | data/stanford_indoor3d/Area_2_storage_7.npy 37 | data/stanford_indoor3d/Area_2_storage_8.npy 38 | data/stanford_indoor3d/Area_2_storage_9.npy 39 | data/stanford_indoor3d/Area_2_WC_1.npy 40 | data/stanford_indoor3d/Area_2_WC_2.npy 41 | -------------------------------------------------------------------------------- /tensorflow/sem_seg/meta/area3_data_label.txt: -------------------------------------------------------------------------------- 1 | data/stanford_indoor3d/Area_3_conferenceRoom_1.npy 2 | data/stanford_indoor3d/Area_3_hallway_1.npy 3 | data/stanford_indoor3d/Area_3_hallway_2.npy 4 | data/stanford_indoor3d/Area_3_hallway_3.npy 5 | data/stanford_indoor3d/Area_3_hallway_4.npy 6 | data/stanford_indoor3d/Area_3_hallway_5.npy 7 | data/stanford_indoor3d/Area_3_hallway_6.npy 8 | data/stanford_indoor3d/Area_3_lounge_1.npy 9 | data/stanford_indoor3d/Area_3_lounge_2.npy 10 | data/stanford_indoor3d/Area_3_office_10.npy 11 | data/stanford_indoor3d/Area_3_office_1.npy 12 | data/stanford_indoor3d/Area_3_office_2.npy 13 | data/stanford_indoor3d/Area_3_office_3.npy 14 | data/stanford_indoor3d/Area_3_office_4.npy 15 | data/stanford_indoor3d/Area_3_office_5.npy 16 | data/stanford_indoor3d/Area_3_office_6.npy 17 | data/stanford_indoor3d/Area_3_office_7.npy 18 | data/stanford_indoor3d/Area_3_office_8.npy 19 | data/stanford_indoor3d/Area_3_office_9.npy 20 | data/stanford_indoor3d/Area_3_storage_1.npy 21 | data/stanford_indoor3d/Area_3_storage_2.npy 22 | data/stanford_indoor3d/Area_3_WC_1.npy 23 | data/stanford_indoor3d/Area_3_WC_2.npy 24 | -------------------------------------------------------------------------------- /tensorflow/sem_seg/meta/area4_data_label.txt: -------------------------------------------------------------------------------- 1 | data/stanford_indoor3d/Area_4_conferenceRoom_1.npy 2 | data/stanford_indoor3d/Area_4_conferenceRoom_2.npy 3 | data/stanford_indoor3d/Area_4_conferenceRoom_3.npy 4 | data/stanford_indoor3d/Area_4_hallway_10.npy 5 | data/stanford_indoor3d/Area_4_hallway_11.npy 6 | data/stanford_indoor3d/Area_4_hallway_12.npy 7 | data/stanford_indoor3d/Area_4_hallway_13.npy 8 | data/stanford_indoor3d/Area_4_hallway_14.npy 9 | data/stanford_indoor3d/Area_4_hallway_1.npy 10 | data/stanford_indoor3d/Area_4_hallway_2.npy 11 | data/stanford_indoor3d/Area_4_hallway_3.npy 12 | data/stanford_indoor3d/Area_4_hallway_4.npy 13 | data/stanford_indoor3d/Area_4_hallway_5.npy 14 | data/stanford_indoor3d/Area_4_hallway_6.npy 15 | data/stanford_indoor3d/Area_4_hallway_7.npy 16 | data/stanford_indoor3d/Area_4_hallway_8.npy 17 | data/stanford_indoor3d/Area_4_hallway_9.npy 18 | data/stanford_indoor3d/Area_4_lobby_1.npy 19 | data/stanford_indoor3d/Area_4_lobby_2.npy 20 | data/stanford_indoor3d/Area_4_office_10.npy 21 | data/stanford_indoor3d/Area_4_office_11.npy 22 | data/stanford_indoor3d/Area_4_office_12.npy 23 | data/stanford_indoor3d/Area_4_office_13.npy 24 | data/stanford_indoor3d/Area_4_office_14.npy 25 | data/stanford_indoor3d/Area_4_office_15.npy 26 | data/stanford_indoor3d/Area_4_office_16.npy 27 | data/stanford_indoor3d/Area_4_office_17.npy 28 | data/stanford_indoor3d/Area_4_office_18.npy 29 | data/stanford_indoor3d/Area_4_office_19.npy 30 | data/stanford_indoor3d/Area_4_office_1.npy 31 | data/stanford_indoor3d/Area_4_office_20.npy 32 | data/stanford_indoor3d/Area_4_office_21.npy 33 | data/stanford_indoor3d/Area_4_office_22.npy 34 | data/stanford_indoor3d/Area_4_office_2.npy 35 | 
data/stanford_indoor3d/Area_4_office_3.npy 36 | data/stanford_indoor3d/Area_4_office_4.npy 37 | data/stanford_indoor3d/Area_4_office_5.npy 38 | data/stanford_indoor3d/Area_4_office_6.npy 39 | data/stanford_indoor3d/Area_4_office_7.npy 40 | data/stanford_indoor3d/Area_4_office_8.npy 41 | data/stanford_indoor3d/Area_4_office_9.npy 42 | data/stanford_indoor3d/Area_4_storage_1.npy 43 | data/stanford_indoor3d/Area_4_storage_2.npy 44 | data/stanford_indoor3d/Area_4_storage_3.npy 45 | data/stanford_indoor3d/Area_4_storage_4.npy 46 | data/stanford_indoor3d/Area_4_WC_1.npy 47 | data/stanford_indoor3d/Area_4_WC_2.npy 48 | data/stanford_indoor3d/Area_4_WC_3.npy 49 | data/stanford_indoor3d/Area_4_WC_4.npy 50 | -------------------------------------------------------------------------------- /tensorflow/sem_seg/meta/area5_data_label.txt: -------------------------------------------------------------------------------- 1 | data/stanford_indoor3d/Area_5_conferenceRoom_1.npy 2 | data/stanford_indoor3d/Area_5_conferenceRoom_2.npy 3 | data/stanford_indoor3d/Area_5_conferenceRoom_3.npy 4 | data/stanford_indoor3d/Area_5_hallway_10.npy 5 | data/stanford_indoor3d/Area_5_hallway_11.npy 6 | data/stanford_indoor3d/Area_5_hallway_12.npy 7 | data/stanford_indoor3d/Area_5_hallway_13.npy 8 | data/stanford_indoor3d/Area_5_hallway_14.npy 9 | data/stanford_indoor3d/Area_5_hallway_15.npy 10 | data/stanford_indoor3d/Area_5_hallway_1.npy 11 | data/stanford_indoor3d/Area_5_hallway_2.npy 12 | data/stanford_indoor3d/Area_5_hallway_3.npy 13 | data/stanford_indoor3d/Area_5_hallway_4.npy 14 | data/stanford_indoor3d/Area_5_hallway_5.npy 15 | data/stanford_indoor3d/Area_5_hallway_6.npy 16 | data/stanford_indoor3d/Area_5_hallway_7.npy 17 | data/stanford_indoor3d/Area_5_hallway_8.npy 18 | data/stanford_indoor3d/Area_5_hallway_9.npy 19 | data/stanford_indoor3d/Area_5_lobby_1.npy 20 | data/stanford_indoor3d/Area_5_office_10.npy 21 | data/stanford_indoor3d/Area_5_office_11.npy 22 | data/stanford_indoor3d/Area_5_office_12.npy 23 | data/stanford_indoor3d/Area_5_office_13.npy 24 | data/stanford_indoor3d/Area_5_office_14.npy 25 | data/stanford_indoor3d/Area_5_office_15.npy 26 | data/stanford_indoor3d/Area_5_office_16.npy 27 | data/stanford_indoor3d/Area_5_office_17.npy 28 | data/stanford_indoor3d/Area_5_office_18.npy 29 | data/stanford_indoor3d/Area_5_office_19.npy 30 | data/stanford_indoor3d/Area_5_office_1.npy 31 | data/stanford_indoor3d/Area_5_office_20.npy 32 | data/stanford_indoor3d/Area_5_office_21.npy 33 | data/stanford_indoor3d/Area_5_office_22.npy 34 | data/stanford_indoor3d/Area_5_office_23.npy 35 | data/stanford_indoor3d/Area_5_office_24.npy 36 | data/stanford_indoor3d/Area_5_office_25.npy 37 | data/stanford_indoor3d/Area_5_office_26.npy 38 | data/stanford_indoor3d/Area_5_office_27.npy 39 | data/stanford_indoor3d/Area_5_office_28.npy 40 | data/stanford_indoor3d/Area_5_office_29.npy 41 | data/stanford_indoor3d/Area_5_office_2.npy 42 | data/stanford_indoor3d/Area_5_office_30.npy 43 | data/stanford_indoor3d/Area_5_office_31.npy 44 | data/stanford_indoor3d/Area_5_office_32.npy 45 | data/stanford_indoor3d/Area_5_office_33.npy 46 | data/stanford_indoor3d/Area_5_office_34.npy 47 | data/stanford_indoor3d/Area_5_office_35.npy 48 | data/stanford_indoor3d/Area_5_office_36.npy 49 | data/stanford_indoor3d/Area_5_office_37.npy 50 | data/stanford_indoor3d/Area_5_office_38.npy 51 | data/stanford_indoor3d/Area_5_office_39.npy 52 | data/stanford_indoor3d/Area_5_office_3.npy 53 | data/stanford_indoor3d/Area_5_office_40.npy 54 | 
data/stanford_indoor3d/Area_5_office_41.npy 55 | data/stanford_indoor3d/Area_5_office_42.npy 56 | data/stanford_indoor3d/Area_5_office_4.npy 57 | data/stanford_indoor3d/Area_5_office_5.npy 58 | data/stanford_indoor3d/Area_5_office_6.npy 59 | data/stanford_indoor3d/Area_5_office_7.npy 60 | data/stanford_indoor3d/Area_5_office_8.npy 61 | data/stanford_indoor3d/Area_5_office_9.npy 62 | data/stanford_indoor3d/Area_5_pantry_1.npy 63 | data/stanford_indoor3d/Area_5_storage_1.npy 64 | data/stanford_indoor3d/Area_5_storage_2.npy 65 | data/stanford_indoor3d/Area_5_storage_3.npy 66 | data/stanford_indoor3d/Area_5_storage_4.npy 67 | data/stanford_indoor3d/Area_5_WC_1.npy 68 | data/stanford_indoor3d/Area_5_WC_2.npy 69 | -------------------------------------------------------------------------------- /tensorflow/sem_seg/meta/area6_data_label.txt: -------------------------------------------------------------------------------- 1 | data/stanford_indoor3d/Area_6_conferenceRoom_1.npy 2 | data/stanford_indoor3d/Area_6_copyRoom_1.npy 3 | data/stanford_indoor3d/Area_6_hallway_1.npy 4 | data/stanford_indoor3d/Area_6_hallway_2.npy 5 | data/stanford_indoor3d/Area_6_hallway_3.npy 6 | data/stanford_indoor3d/Area_6_hallway_4.npy 7 | data/stanford_indoor3d/Area_6_hallway_5.npy 8 | data/stanford_indoor3d/Area_6_hallway_6.npy 9 | data/stanford_indoor3d/Area_6_lounge_1.npy 10 | data/stanford_indoor3d/Area_6_office_10.npy 11 | data/stanford_indoor3d/Area_6_office_11.npy 12 | data/stanford_indoor3d/Area_6_office_12.npy 13 | data/stanford_indoor3d/Area_6_office_13.npy 14 | data/stanford_indoor3d/Area_6_office_14.npy 15 | data/stanford_indoor3d/Area_6_office_15.npy 16 | data/stanford_indoor3d/Area_6_office_16.npy 17 | data/stanford_indoor3d/Area_6_office_17.npy 18 | data/stanford_indoor3d/Area_6_office_18.npy 19 | data/stanford_indoor3d/Area_6_office_19.npy 20 | data/stanford_indoor3d/Area_6_office_1.npy 21 | data/stanford_indoor3d/Area_6_office_20.npy 22 | data/stanford_indoor3d/Area_6_office_21.npy 23 | data/stanford_indoor3d/Area_6_office_22.npy 24 | data/stanford_indoor3d/Area_6_office_23.npy 25 | data/stanford_indoor3d/Area_6_office_24.npy 26 | data/stanford_indoor3d/Area_6_office_25.npy 27 | data/stanford_indoor3d/Area_6_office_26.npy 28 | data/stanford_indoor3d/Area_6_office_27.npy 29 | data/stanford_indoor3d/Area_6_office_28.npy 30 | data/stanford_indoor3d/Area_6_office_29.npy 31 | data/stanford_indoor3d/Area_6_office_2.npy 32 | data/stanford_indoor3d/Area_6_office_30.npy 33 | data/stanford_indoor3d/Area_6_office_31.npy 34 | data/stanford_indoor3d/Area_6_office_32.npy 35 | data/stanford_indoor3d/Area_6_office_33.npy 36 | data/stanford_indoor3d/Area_6_office_34.npy 37 | data/stanford_indoor3d/Area_6_office_35.npy 38 | data/stanford_indoor3d/Area_6_office_36.npy 39 | data/stanford_indoor3d/Area_6_office_37.npy 40 | data/stanford_indoor3d/Area_6_office_3.npy 41 | data/stanford_indoor3d/Area_6_office_4.npy 42 | data/stanford_indoor3d/Area_6_office_5.npy 43 | data/stanford_indoor3d/Area_6_office_6.npy 44 | data/stanford_indoor3d/Area_6_office_7.npy 45 | data/stanford_indoor3d/Area_6_office_8.npy 46 | data/stanford_indoor3d/Area_6_office_9.npy 47 | data/stanford_indoor3d/Area_6_openspace_1.npy 48 | data/stanford_indoor3d/Area_6_pantry_1.npy 49 | -------------------------------------------------------------------------------- /tensorflow/sem_seg/meta/class_names.txt: -------------------------------------------------------------------------------- 1 | ceiling 2 | floor 3 | wall 4 | beam 5 | column 6 | window 7 | door 8 | 
table
 9 | chair
10 | sofa
11 | bookcase
12 | board
13 | clutter
14 | 
--------------------------------------------------------------------------------
/tensorflow/sem_seg/model.py:
--------------------------------------------------------------------------------
 1 | import tensorflow as tf
 2 | import math
 3 | import time
 4 | import numpy as np
 5 | import os
 6 | import sys
 7 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 8 | ROOT_DIR = os.path.dirname(BASE_DIR)
 9 | sys.path.append(os.path.join(ROOT_DIR, 'utils'))
10 | sys.path.append(os.path.join(BASE_DIR, '../models'))
11 | import tf_util
12 | 
13 | def placeholder_inputs(batch_size, num_point):
14 |   pointclouds_pl = tf.placeholder(tf.float32,
15 |                                   shape=(batch_size, num_point, 9))
16 |   labels_pl = tf.placeholder(tf.int32,
17 |                              shape=(batch_size, num_point))
18 |   return pointclouds_pl, labels_pl
19 | 
20 | def get_model(point_cloud, is_training, bn_decay=None, weight_decay=0.0):
21 |   """ ConvNet baseline, input is a BxNx9 point cloud (XYZ, RGB, normalized XYZ) """
22 |   batch_size = point_cloud.get_shape()[0].value
23 |   num_point = point_cloud.get_shape()[1].value
24 |   input_image = tf.expand_dims(point_cloud, -1)
25 | 
26 |   k = 20
27 | 
28 |   adj = tf_util.pairwise_distance(point_cloud[:, :, 6:])
29 |   nn_idx = tf_util.knn(adj, k=k) # (batch, num_points, k)
30 |   edge_feature = tf_util.get_edge_feature(input_image, nn_idx=nn_idx, k=k)
31 | 
32 |   out1 = tf_util.conv2d(edge_feature, 64, [1,1],
33 |                         padding='VALID', stride=[1,1],
34 |                         bn=True, is_training=is_training, weight_decay=weight_decay,
35 |                         scope='adj_conv1', bn_decay=bn_decay, is_dist=True)
36 | 
37 |   out2 = tf_util.conv2d(out1, 64, [1,1],
38 |                         padding='VALID', stride=[1,1],
39 |                         bn=True, is_training=is_training, weight_decay=weight_decay,
40 |                         scope='adj_conv2', bn_decay=bn_decay, is_dist=True)
41 | 
42 |   net_1 = tf.reduce_max(out2, axis=-2, keep_dims=True)
43 | 
44 | 
45 | 
46 |   adj = tf_util.pairwise_distance(net_1)
47 |   nn_idx = tf_util.knn(adj, k=k)
48 |   edge_feature = tf_util.get_edge_feature(net_1, nn_idx=nn_idx, k=k)
49 | 
50 |   out3 = tf_util.conv2d(edge_feature, 64, [1,1],
51 |                         padding='VALID', stride=[1,1],
52 |                         bn=True, is_training=is_training, weight_decay=weight_decay,
53 |                         scope='adj_conv3', bn_decay=bn_decay, is_dist=True)
54 | 
55 |   out4 = tf_util.conv2d(out3, 64, [1,1],
56 |                         padding='VALID', stride=[1,1],
57 |                         bn=True, is_training=is_training, weight_decay=weight_decay,
58 |                         scope='adj_conv4', bn_decay=bn_decay, is_dist=True)
59 | 
60 |   net_2 = tf.reduce_max(out4, axis=-2, keep_dims=True)
61 | 
62 | 
63 | 
64 |   adj = tf_util.pairwise_distance(net_2)
65 |   nn_idx = tf_util.knn(adj, k=k)
66 |   edge_feature = tf_util.get_edge_feature(net_2, nn_idx=nn_idx, k=k)
67 | 
68 |   out5 = tf_util.conv2d(edge_feature, 64, [1,1],
69 |                         padding='VALID', stride=[1,1],
70 |                         bn=True, is_training=is_training, weight_decay=weight_decay,
71 |                         scope='adj_conv5', bn_decay=bn_decay, is_dist=True)
72 | 
73 |   # out6 = tf_util.conv2d(out5, 64, [1,1],
74 |   #                       padding='VALID', stride=[1,1],
75 |   #                       bn=True, is_training=is_training, weight_decay=weight_decay,
76 |   #                       scope='adj_conv6', bn_decay=bn_decay, is_dist=True)
77 | 
78 |   net_3 = tf.reduce_max(out5, axis=-2, keep_dims=True)
79 | 
80 | 
81 | 
82 |   out7 = tf_util.conv2d(tf.concat([net_1, net_2, net_3], axis=-1), 1024, [1, 1],
83 |                         padding='VALID', stride=[1,1],
84 |                         bn=True, is_training=is_training,
85 |                         scope='adj_conv7', bn_decay=bn_decay, is_dist=True)
86 | 
87 |   out_max = tf_util.max_pool2d(out7, [num_point, 1], padding='VALID', scope='maxpool')
88 | 
89 | 
90 |   expand = tf.tile(out_max, [1, num_point, 1, 1])
91 | 
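  # Shape walk-through (B = batch, N = num_point, k = 20):
  #   net_1, net_2, net_3: B x N x 1 x 64   -- per-point EdgeConv features
  #   out7:                B x N x 1 x 1024 -- pointwise MLP on their concatenation
  #   out_max:             B x 1 x 1 x 1024 -- global descriptor from max-pooling
  # Tiling out_max back to every point and concatenating it with net_1..net_3
  # gives each point both local and global context before the segmentation head.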
92 |   concat = tf.concat(axis=3, values=[expand,
93 |                                      net_1,
94 |                                      net_2,
95 |                                      net_3])
96 | 
97 |   # CONV
98 |   net = tf_util.conv2d(concat, 512, [1,1], padding='VALID', stride=[1,1],
99 |                        bn=True, is_training=is_training, scope='seg/conv1', is_dist=True)
100 |   net = tf_util.conv2d(net, 256, [1,1], padding='VALID', stride=[1,1],
101 |                        bn=True, is_training=is_training, scope='seg/conv2', is_dist=True)
102 |   net = tf_util.dropout(net, keep_prob=0.7, is_training=is_training, scope='dp1')
103 |   net = tf_util.conv2d(net, 13, [1,1], padding='VALID', stride=[1,1],
104 |                        activation_fn=None, scope='seg/conv3', is_dist=True)
105 |   net = tf.squeeze(net, [2])
106 | 
107 |   return net
108 | 
109 | def get_loss(pred, label):
110 |   """ pred: B,N,13; label: B,N """
111 |   loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=pred, labels=label)
112 |   return tf.reduce_mean(loss)
113 | 
--------------------------------------------------------------------------------
/tensorflow/sem_seg/test_job.sh:
--------------------------------------------------------------------------------
1 | python batch_inference.py --model_path log1/epoch_60.ckpt --dump_dir log1/dump --output_filelist log1/output_filelist.txt --room_data_filelist meta/area1_data_label.txt
2 | python batch_inference.py --model_path log2/epoch_60.ckpt --dump_dir log2/dump --output_filelist log2/output_filelist.txt --room_data_filelist meta/area2_data_label.txt
3 | python batch_inference.py --model_path log3/epoch_60.ckpt --dump_dir log3/dump --output_filelist log3/output_filelist.txt --room_data_filelist meta/area3_data_label.txt
4 | python batch_inference.py --model_path log4/epoch_60.ckpt --dump_dir log4/dump --output_filelist log4/output_filelist.txt --room_data_filelist meta/area4_data_label.txt
5 | python batch_inference.py --model_path log5/epoch_60.ckpt --dump_dir log5/dump --output_filelist log5/output_filelist.txt --room_data_filelist meta/area5_data_label.txt
6 | python batch_inference.py --model_path log6/epoch_60.ckpt --dump_dir log6/dump --output_filelist log6/output_filelist.txt --room_data_filelist meta/area6_data_label.txt
--------------------------------------------------------------------------------
/tensorflow/sem_seg/train.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | import math
 3 | import h5py
 4 | import numpy as np
 5 | import tensorflow as tf
 6 | import socket
 7 | 
 8 | import os
 9 | import sys
10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11 | ROOT_DIR = os.path.dirname(BASE_DIR)
12 | sys.path.append(BASE_DIR)
13 | sys.path.append(ROOT_DIR)
14 | sys.path.append(os.path.join(ROOT_DIR, 'utils'))
15 | import provider
16 | import tf_util
17 | from model import *
18 | 
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('--num_gpu', type=int, default=2, help='the number of GPUs to use [default: 2]')
21 | parser.add_argument('--log_dir', default='log', help='Log dir [default: log]')
22 | parser.add_argument('--num_point', type=int, default=4096, help='Point number [default: 4096]')
23 | parser.add_argument('--max_epoch', type=int, default=101, help='Epoch to run [default: 101]')
24 | parser.add_argument('--batch_size', type=int, default=12, help='Batch Size during training for each GPU [default: 12]')
25 | parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate [default: 0.001]')
26 | parser.add_argument('--momentum', type=float, default=0.9, help='Initial momentum [default: 0.9]')
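# NOTE: train_one_epoch below builds its feed_dict for exactly two towers
# (indices 0 and 1), so running with --num_gpu other than 2 requires
# generalizing that loop.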
27 | parser.add_argument('--optimizer', default='adam', help='adam or momentum [default: adam]')
28 | parser.add_argument('--decay_step', type=int, default=300000, help='Decay step for lr decay [default: 300000]')
29 | parser.add_argument('--decay_rate', type=float, default=0.5, help='Decay rate for lr decay [default: 0.5]')
30 | parser.add_argument('--test_area', type=int, default=6, help='Which area to use for test, option: 1-6 [default: 6]')
31 | FLAGS = parser.parse_args()
32 | 
33 | TOWER_NAME = 'tower'
34 | 
35 | BATCH_SIZE = FLAGS.batch_size
36 | NUM_POINT = FLAGS.num_point
37 | MAX_EPOCH = FLAGS.max_epoch
38 | 
39 | BASE_LEARNING_RATE = FLAGS.learning_rate
40 | MOMENTUM = FLAGS.momentum
41 | OPTIMIZER = FLAGS.optimizer
42 | DECAY_STEP = FLAGS.decay_step
43 | DECAY_RATE = FLAGS.decay_rate
44 | 
45 | LOG_DIR = FLAGS.log_dir
46 | if not os.path.exists(LOG_DIR): os.mkdir(LOG_DIR)
47 | os.system('cp model.py %s' % (LOG_DIR))
48 | os.system('cp train.py %s' % (LOG_DIR))
49 | LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w')
50 | LOG_FOUT.write(str(FLAGS)+'\n')
51 | 
52 | MAX_NUM_POINT = 4096
53 | NUM_CLASSES = 13
54 | 
55 | BN_INIT_DECAY = 0.5
56 | BN_DECAY_DECAY_RATE = 0.5
57 | BN_DECAY_DECAY_STEP = float(DECAY_STEP)
58 | BN_DECAY_CLIP = 0.99
59 | 
60 | HOSTNAME = socket.gethostname()
61 | 
62 | ALL_FILES = provider.getDataFiles('indoor3d_sem_seg_hdf5_data/all_files.txt')
63 | room_filelist = [line.rstrip() for line in open('indoor3d_sem_seg_hdf5_data/room_filelist.txt')]
64 | print(len(room_filelist))
65 | 
66 | # Load ALL data
67 | data_batch_list = []
68 | label_batch_list = []
69 | for h5_filename in ALL_FILES:
70 |     data_batch, label_batch = provider.loadDataFile(h5_filename)
71 |     data_batch_list.append(data_batch)
72 |     label_batch_list.append(label_batch)
73 | data_batches = np.concatenate(data_batch_list, 0)
74 | label_batches = np.concatenate(label_batch_list, 0)
75 | print(data_batches.shape)
76 | print(label_batches.shape)
77 | 
78 | test_area = 'Area_'+str(FLAGS.test_area)
79 | train_idxs = []
80 | test_idxs = []
81 | for i,room_name in enumerate(room_filelist):
82 |     if test_area in room_name:
83 |         test_idxs.append(i)
84 |     else:
85 |         train_idxs.append(i)
86 | 
87 | train_data = data_batches[train_idxs,...]
88 | train_label = label_batches[train_idxs]
89 | test_data = data_batches[test_idxs,...]
90 | test_label = label_batches[test_idxs]
91 | print(train_data.shape, train_label.shape)
92 | print(test_data.shape, test_label.shape)
93 | 
94 | 
95 | def log_string(out_str):
96 |     LOG_FOUT.write(out_str+'\n')
97 |     LOG_FOUT.flush()
98 |     print(out_str)
99 | 
100 | 
101 | def get_learning_rate(batch):
102 |     learning_rate = tf.train.exponential_decay(
103 |                         BASE_LEARNING_RATE,  # Base learning rate.
104 |                         batch * BATCH_SIZE,  # Current index into the dataset.
105 |                         DECAY_STEP,          # Decay step.
106 |                         DECAY_RATE,          # Decay rate.
107 |                         staircase=True)
108 |     learning_rate = tf.maximum(learning_rate, 0.00001) # CLIP THE LEARNING RATE!!
109 |     return learning_rate
110 | 
111 | def get_bn_decay(batch):
112 |     bn_momentum = tf.train.exponential_decay(
113 |                       BN_INIT_DECAY,
114 |                       batch*BATCH_SIZE,
115 |                       BN_DECAY_DECAY_STEP,
116 |                       BN_DECAY_DECAY_RATE,
117 |                       staircase=True)
118 |     bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum)
119 |     return bn_decay
120 | 
121 | def average_gradients(tower_grads):
122 |     """Calculate average gradient for each shared variable across all towers.
123 | 
124 |     Note that this function provides a synchronization point across all towers.
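    Concretely, for each shared variable the returned gradient is the
    element-wise mean of that variable's gradients over all towers.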
125 | 
126 |     Args:
127 |         tower_grads: List of lists of (gradient, variable) tuples. The outer list
128 |             is over individual gradients. The inner list is over the gradient
129 |             calculation for each tower.
130 |     Returns:
131 |         List of pairs of (gradient, variable) where the gradient has been
132 |         averaged across all towers.
133 |     """
134 |     average_grads = []
135 |     for grad_and_vars in zip(*tower_grads):
136 |         # Note that each grad_and_vars looks like the following:
137 |         #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
138 |         grads = []
139 |         for g, _ in grad_and_vars:
140 |             expanded_g = tf.expand_dims(g, 0)
141 |             grads.append(expanded_g)
142 | 
143 |         # Average over the 'tower' dimension.
144 |         grad = tf.concat(grads, 0)
145 |         grad = tf.reduce_mean(grad, 0)
146 | 
147 |         # Keep in mind that the Variables are redundant because they are shared
148 |         # across towers. So .. we will just return the first tower's pointer to
149 |         # the Variable.
150 |         v = grad_and_vars[0][1]
151 |         grad_and_var = (grad, v)
152 |         average_grads.append(grad_and_var)
153 |     return average_grads
154 | 
155 | def train():
156 |     with tf.Graph().as_default(), tf.device('/cpu:0'):
157 |         batch = tf.Variable(0, trainable=False)
158 | 
159 |         bn_decay = get_bn_decay(batch)
160 |         tf.summary.scalar('bn_decay', bn_decay)
161 | 
162 |         learning_rate = get_learning_rate(batch)
163 |         tf.summary.scalar('learning_rate', learning_rate)
164 | 
165 |         trainer = tf.train.AdamOptimizer(learning_rate)
166 | 
167 |         tower_grads = []
168 |         pointclouds_phs = []
169 |         labels_phs = []
170 |         is_training_phs = []
171 | 
172 |         with tf.variable_scope(tf.get_variable_scope()):
173 |             for i in range(FLAGS.num_gpu):
174 |                 with tf.device('/gpu:%d' % i):
175 |                     with tf.name_scope('%s_%d' % (TOWER_NAME, i)) as scope:
176 | 
177 |                         pointclouds_pl, labels_pl = placeholder_inputs(BATCH_SIZE, NUM_POINT)
178 |                         is_training_pl = tf.placeholder(tf.bool, shape=())
179 | 
180 |                         pointclouds_phs.append(pointclouds_pl)
181 |                         labels_phs.append(labels_pl)
182 |                         is_training_phs.append(is_training_pl)
183 | 
184 |                         pred = get_model(pointclouds_phs[-1], is_training_phs[-1], bn_decay=bn_decay)
185 |                         loss = get_loss(pred, labels_phs[-1])
186 |                         tf.summary.scalar('loss', loss)
187 | 
188 |                         correct = tf.equal(tf.argmax(pred, 2), tf.to_int64(labels_phs[-1]))
189 |                         accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE*NUM_POINT)
190 |                         tf.summary.scalar('accuracy', accuracy)
191 | 
192 |                         tf.get_variable_scope().reuse_variables()
193 | 
194 |                         grads = trainer.compute_gradients(loss)
195 | 
196 |                         tower_grads.append(grads)
197 | 
198 |         grads = average_gradients(tower_grads)
199 | 
200 |         train_op = trainer.apply_gradients(grads, global_step=batch)
201 | 
202 |         saver = tf.train.Saver(tf.global_variables(), sharded=True, max_to_keep=10)
203 | 
204 |         # Create a session
205 |         config = tf.ConfigProto()
206 |         config.gpu_options.allow_growth = True
207 |         config.allow_soft_placement = True
208 |         sess = tf.Session(config=config)
209 | 
210 |         # Add summary writers
211 |         merged = tf.summary.merge_all()
212 |         train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'),
213 |                                              sess.graph)
214 |         test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test'))
215 | 
216 |         # Init variables for two GPUs
217 |         init = tf.group(tf.global_variables_initializer(),
218 |                         tf.local_variables_initializer())
219 |         sess.run(init)
220 | 
221 |         ops = {'pointclouds_phs': pointclouds_phs,
222 |                'labels_phs': labels_phs,
223 |                'is_training_phs': is_training_phs,
224 |                'pred': pred,
225 |                'loss': loss,
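               # 'pred' and 'loss' refer to the last tower built in the loop
               # above; logged training accuracy is measured on that tower.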
226 |                'train_op': train_op,
227 |                'merged': merged,
228 |                'step': batch}
229 | 
230 |         for epoch in range(MAX_EPOCH):
231 |             log_string('**** EPOCH %03d ****' % (epoch))
232 |             sys.stdout.flush()
233 | 
234 |             train_one_epoch(sess, ops, train_writer)
235 | 
236 |             # Save the variables to disk.
237 |             if epoch % 10 == 0:
238 |                 save_path = saver.save(sess, os.path.join(LOG_DIR,'epoch_' + str(epoch)+'.ckpt'))
239 |                 log_string("Model saved in file: %s" % save_path)
240 | 
241 | 
242 | 
243 | def train_one_epoch(sess, ops, train_writer):
244 |     """ ops: dict mapping from string to tf ops """
245 |     is_training = True
246 | 
247 |     log_string('----')
248 |     current_data, current_label, _ = provider.shuffle_data(train_data[:,0:NUM_POINT,:], train_label)
249 | 
250 |     file_size = current_data.shape[0]
251 |     num_batches = file_size // (FLAGS.num_gpu * BATCH_SIZE)
252 | 
253 |     total_correct = 0
254 |     total_seen = 0
255 |     loss_sum = 0
256 | 
257 |     for batch_idx in range(num_batches):
258 |         if batch_idx % 100 == 0:
259 |             print('Current batch/total batch num: %d/%d'%(batch_idx,num_batches))
260 |         start_idx_0 = (2 * batch_idx) * BATCH_SIZE  # each step consumes two disjoint batches, one per GPU
261 |         end_idx_0 = start_idx_0 + BATCH_SIZE
262 |         start_idx_1 = end_idx_0
263 |         end_idx_1 = start_idx_1 + BATCH_SIZE
264 | 
265 | 
266 |         feed_dict = {ops['pointclouds_phs'][0]: current_data[start_idx_0:end_idx_0, :, :],
267 |                      ops['pointclouds_phs'][1]: current_data[start_idx_1:end_idx_1, :, :],
268 |                      ops['labels_phs'][0]: current_label[start_idx_0:end_idx_0],
269 |                      ops['labels_phs'][1]: current_label[start_idx_1:end_idx_1],
270 |                      ops['is_training_phs'][0]: is_training,
271 |                      ops['is_training_phs'][1]: is_training}
272 |         summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']],
273 |                                                         feed_dict=feed_dict)
274 |         train_writer.add_summary(summary, step)
275 |         pred_val = np.argmax(pred_val, 2)
276 |         correct = np.sum(pred_val == current_label[start_idx_1:end_idx_1])
277 |         total_correct += correct
278 |         total_seen += (BATCH_SIZE*NUM_POINT)
279 |         loss_sum += loss_val
280 | 
281 |     log_string('mean loss: %f' % (loss_sum / float(num_batches)))
282 |     log_string('accuracy: %f' % (total_correct / float(total_seen)))
283 | 
284 | if __name__ == "__main__":
285 |     train()
286 |     LOG_FOUT.close()
287 | 
--------------------------------------------------------------------------------
/tensorflow/sem_seg/train_job.sh:
--------------------------------------------------------------------------------
1 | python train.py --log_dir log1 --test_area 1
2 | python train.py --log_dir log2 --test_area 2
3 | python train.py --log_dir log3 --test_area 3
4 | python train.py --log_dir log4 --test_area 4
5 | python train.py --log_dir log5 --test_area 5
6 | python train.py --log_dir log6 --test_area 6
--------------------------------------------------------------------------------
/tensorflow/train.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | import math
 3 | import h5py
 4 | import numpy as np
 5 | import tensorflow as tf
 6 | import socket
 7 | import importlib
 8 | import os
 9 | import sys
10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11 | sys.path.append(BASE_DIR)
12 | sys.path.append(os.path.join(BASE_DIR, 'models'))
13 | sys.path.append(os.path.join(BASE_DIR, 'utils'))
14 | import provider
15 | import tf_util
16 | 
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
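# Expects the ModelNet40 HDF5 files under data/modelnet40_ply_hdf5_2048/
# (see TRAIN_FILES / TEST_FILES below).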
19 | parser.add_argument('--model', default='dgcnn', help='Model name: dgcnn')
20 | parser.add_argument('--log_dir', default='log', help='Log dir [default: log]')
21 | parser.add_argument('--num_point', type=int, default=1024, help='Point Number [256/512/1024/2048] [default: 1024]')
22 | parser.add_argument('--max_epoch', type=int, default=250, help='Epoch to run [default: 250]')
23 | parser.add_argument('--batch_size', type=int, default=32, help='Batch Size during training [default: 32]')
24 | parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate [default: 0.001]')
25 | parser.add_argument('--momentum', type=float, default=0.9, help='Initial momentum [default: 0.9]')
26 | parser.add_argument('--optimizer', default='adam', help='adam or momentum [default: adam]')
27 | parser.add_argument('--decay_step', type=int, default=200000, help='Decay step for lr decay [default: 200000]')
28 | parser.add_argument('--decay_rate', type=float, default=0.7, help='Decay rate for lr decay [default: 0.7]')
29 | FLAGS = parser.parse_args()
30 | 
31 | 
32 | BATCH_SIZE = FLAGS.batch_size
33 | NUM_POINT = FLAGS.num_point
34 | MAX_EPOCH = FLAGS.max_epoch
35 | BASE_LEARNING_RATE = FLAGS.learning_rate
36 | GPU_INDEX = FLAGS.gpu
37 | MOMENTUM = FLAGS.momentum
38 | OPTIMIZER = FLAGS.optimizer
39 | DECAY_STEP = FLAGS.decay_step
40 | DECAY_RATE = FLAGS.decay_rate
41 | 
42 | MODEL = importlib.import_module(FLAGS.model) # import network module
43 | MODEL_FILE = os.path.join(BASE_DIR, 'models', FLAGS.model+'.py')
44 | LOG_DIR = FLAGS.log_dir
45 | if not os.path.exists(LOG_DIR): os.mkdir(LOG_DIR)
46 | os.system('cp %s %s' % (MODEL_FILE, LOG_DIR)) # bkp of model def
47 | os.system('cp train.py %s' % (LOG_DIR)) # bkp of train procedure
48 | LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w')
49 | LOG_FOUT.write(str(FLAGS)+'\n')
50 | 
51 | MAX_NUM_POINT = 2048
52 | NUM_CLASSES = 40
53 | 
54 | BN_INIT_DECAY = 0.5
55 | BN_DECAY_DECAY_RATE = 0.5
56 | BN_DECAY_DECAY_STEP = float(DECAY_STEP)
57 | BN_DECAY_CLIP = 0.99
58 | 
59 | HOSTNAME = socket.gethostname()
60 | 
61 | # ModelNet40 official train/test split
62 | TRAIN_FILES = provider.getDataFiles( \
63 |     os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/train_files.txt'))
64 | TEST_FILES = provider.getDataFiles(\
65 |     os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/test_files.txt'))
66 | 
67 | def log_string(out_str):
68 |     LOG_FOUT.write(out_str+'\n')
69 |     LOG_FOUT.flush()
70 |     print(out_str)
71 | 
72 | 
73 | def get_learning_rate(batch):
74 |     learning_rate = tf.train.exponential_decay(
75 |                         BASE_LEARNING_RATE,  # Base learning rate.
76 |                         batch * BATCH_SIZE,  # Current index into the dataset.
77 |                         DECAY_STEP,          # Decay step.
78 |                         DECAY_RATE,          # Decay rate.
79 |                         staircase=True)
80 |     learning_rate = tf.maximum(learning_rate, 0.00001) # CLIP THE LEARNING RATE!
81 |     return learning_rate
82 | 
83 | def get_bn_decay(batch):
84 |     bn_momentum = tf.train.exponential_decay(
85 |                       BN_INIT_DECAY,
86 |                       batch*BATCH_SIZE,
87 |                       BN_DECAY_DECAY_STEP,
88 |                       BN_DECAY_DECAY_RATE,
89 |                       staircase=True)
90 |     bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum)
91 |     return bn_decay
92 | 
93 | def train():
94 |     with tf.Graph().as_default():
95 |         with tf.device('/gpu:'+str(GPU_INDEX)):
96 |             pointclouds_pl, labels_pl = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT)
97 |             is_training_pl = tf.placeholder(tf.bool, shape=())
98 |             print(is_training_pl)
99 | 
100 |             # Note the global_step=batch parameter to minimize.
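            # 'batch' doubles as the global step that drives the learning-rate
            # and bn-decay schedules defined above.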
101 | # That tells the optimizer to helpfully increment the 'batch' parameter for you every time it trains. 102 | batch = tf.Variable(0) 103 | bn_decay = get_bn_decay(batch) 104 | tf.summary.scalar('bn_decay', bn_decay) 105 | 106 | # Get model and loss 107 | pred, end_points = MODEL.get_model(pointclouds_pl, is_training_pl, bn_decay=bn_decay) 108 | loss = MODEL.get_loss(pred, labels_pl, end_points) 109 | tf.summary.scalar('loss', loss) 110 | 111 | correct = tf.equal(tf.argmax(pred, 1), tf.to_int64(labels_pl)) 112 | accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE) 113 | tf.summary.scalar('accuracy', accuracy) 114 | 115 | # Get training operator 116 | learning_rate = get_learning_rate(batch) 117 | tf.summary.scalar('learning_rate', learning_rate) 118 | if OPTIMIZER == 'momentum': 119 | optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=MOMENTUM) 120 | elif OPTIMIZER == 'adam': 121 | optimizer = tf.train.AdamOptimizer(learning_rate) 122 | train_op = optimizer.minimize(loss, global_step=batch) 123 | 124 | # Add ops to save and restore all the variables. 125 | saver = tf.train.Saver() 126 | 127 | # Create a session 128 | config = tf.ConfigProto() 129 | config.gpu_options.allow_growth = True 130 | config.allow_soft_placement = True 131 | config.log_device_placement = False 132 | sess = tf.Session(config=config) 133 | 134 | # Add summary writers 135 | #merged = tf.merge_all_summaries() 136 | merged = tf.summary.merge_all() 137 | train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'), 138 | sess.graph) 139 | test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test')) 140 | 141 | # Init variables 142 | init = tf.global_variables_initializer() 143 | # To fix the bug introduced in TF 0.12.1 as in 144 | # http://stackoverflow.com/questions/41543774/invalidargumenterror-for-tensor-bool-tensorflow-0-12-1 145 | #sess.run(init) 146 | sess.run(init, {is_training_pl: True}) 147 | 148 | ops = {'pointclouds_pl': pointclouds_pl, 149 | 'labels_pl': labels_pl, 150 | 'is_training_pl': is_training_pl, 151 | 'pred': pred, 152 | 'loss': loss, 153 | 'train_op': train_op, 154 | 'merged': merged, 155 | 'step': batch} 156 | 157 | for epoch in range(MAX_EPOCH): 158 | log_string('**** EPOCH %03d ****' % (epoch)) 159 | sys.stdout.flush() 160 | 161 | train_one_epoch(sess, ops, train_writer) 162 | eval_one_epoch(sess, ops, test_writer) 163 | 164 | # Save the variables to disk. 
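        # model.ckpt is overwritten on each save; the sem_seg trainer instead
        # keeps per-epoch checkpoints (epoch_<n>.ckpt).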
165 |             if epoch % 10 == 0:
166 |                 save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt"))
167 |                 log_string("Model saved in file: %s" % save_path)
168 |
169 |
170 |
171 | def train_one_epoch(sess, ops, train_writer):
172 |     """ ops: dict mapping from string to tf ops """
173 |     is_training = True
174 |
175 |     # Shuffle train files
176 |     train_file_idxs = np.arange(0, len(TRAIN_FILES))
177 |     np.random.shuffle(train_file_idxs)
178 |
179 |     for fn in range(len(TRAIN_FILES)):
180 |         log_string('-----' + str(fn) + '-----')
181 |         current_data, current_label = provider.loadDataFile(TRAIN_FILES[train_file_idxs[fn]])
182 |         current_data = current_data[:,0:NUM_POINT,:]
183 |         current_data, current_label, _ = provider.shuffle_data(current_data, np.squeeze(current_label))
184 |         current_label = np.squeeze(current_label)
185 |
186 |         file_size = current_data.shape[0]
187 |         num_batches = file_size // BATCH_SIZE
188 |
189 |         total_correct = 0
190 |         total_seen = 0
191 |         loss_sum = 0
192 |
193 |         for batch_idx in range(num_batches):
194 |             start_idx = batch_idx * BATCH_SIZE
195 |             end_idx = (batch_idx+1) * BATCH_SIZE
196 |
197 |             # Augment batched point clouds by rotation and jittering
198 |             rotated_data = provider.rotate_point_cloud(current_data[start_idx:end_idx, :, :])
199 |             jittered_data = provider.jitter_point_cloud(rotated_data)
200 |             jittered_data = provider.random_scale_point_cloud(jittered_data)
201 |             jittered_data = provider.rotate_perturbation_point_cloud(jittered_data)
202 |             jittered_data = provider.shift_point_cloud(jittered_data)
203 |
204 |             feed_dict = {ops['pointclouds_pl']: jittered_data,
205 |                          ops['labels_pl']: current_label[start_idx:end_idx],
206 |                          ops['is_training_pl']: is_training,}
207 |             summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'],
208 |                 ops['train_op'], ops['loss'], ops['pred']], feed_dict=feed_dict)
209 |             train_writer.add_summary(summary, step)
210 |             pred_val = np.argmax(pred_val, 1)
211 |             correct = np.sum(pred_val == current_label[start_idx:end_idx])
212 |             total_correct += correct
213 |             total_seen += BATCH_SIZE
214 |             loss_sum += loss_val
215 |
216 |         log_string('mean loss: %f' % (loss_sum / float(num_batches)))
217 |         log_string('accuracy: %f' % (total_correct / float(total_seen)))
218 |
219 |
220 | def eval_one_epoch(sess, ops, test_writer):
221 |     """ ops: dict mapping from string to tf ops """
222 |     is_training = False
223 |     total_correct = 0
224 |     total_seen = 0
225 |     loss_sum = 0
226 |     total_seen_class = [0 for _ in range(NUM_CLASSES)]
227 |     total_correct_class = [0 for _ in range(NUM_CLASSES)]
228 |
229 |     for fn in range(len(TEST_FILES)):
230 |         log_string('-----' + str(fn) + '-----')
231 |         current_data, current_label = provider.loadDataFile(TEST_FILES[fn])
232 |         current_data = current_data[:,0:NUM_POINT,:]
233 |         current_label = np.squeeze(current_label)
234 |
235 |         file_size = current_data.shape[0]
236 |         num_batches = file_size // BATCH_SIZE
237 |
238 |         for batch_idx in range(num_batches):
239 |             start_idx = batch_idx * BATCH_SIZE
240 |             end_idx = (batch_idx+1) * BATCH_SIZE
241 |
242 |             feed_dict = {ops['pointclouds_pl']: current_data[start_idx:end_idx, :, :],
243 |                          ops['labels_pl']: current_label[start_idx:end_idx],
244 |                          ops['is_training_pl']: is_training}
245 |             summary, step, loss_val, pred_val = sess.run([ops['merged'], ops['step'],
246 |                 ops['loss'], ops['pred']], feed_dict=feed_dict)
247 |             pred_val = np.argmax(pred_val, 1)
248 |             correct = np.sum(pred_val == current_label[start_idx:end_idx])
249 |             total_correct += correct
250 |             total_seen += BATCH_SIZE
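            # (Added note, not in the original) assuming MODEL.get_loss returns the
            # batch-mean loss, weighting it by BATCH_SIZE here means the division by
            # total_seen below yields a per-example mean over the whole test set.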
251 |             loss_sum += (loss_val*BATCH_SIZE)
252 |             for i in range(start_idx, end_idx):
253 |                 l = current_label[i]
254 |                 total_seen_class[l] += 1
255 |                 total_correct_class[l] += (pred_val[i-start_idx] == l)
256 |
257 |     log_string('eval mean loss: %f' % (loss_sum / float(total_seen)))
258 |     log_string('eval accuracy: %f' % (total_correct / float(total_seen)))
259 |     log_string('eval avg class acc: %f' % (np.mean(np.array(total_correct_class)/np.array(total_seen_class, dtype=np.float64))))
260 |
261 |
262 |
263 | if __name__ == "__main__":
264 |     train()
265 |     LOG_FOUT.close()
266 |
--------------------------------------------------------------------------------
/tensorflow/utils/data_prep_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
4 | sys.path.append(BASE_DIR)
5 | from plyfile import (PlyData, PlyElement, make2d, PlyParseError, PlyProperty)
6 | import numpy as np
7 | import h5py
8 |
9 | SAMPLING_BIN = os.path.join(BASE_DIR, 'third_party/mesh_sampling/build/pcsample')
10 |
11 | SAMPLING_POINT_NUM = 2048
12 | SAMPLING_LEAF_SIZE = 0.005
13 |
14 | MODELNET40_PATH = '../datasets/modelnet40'
15 | def export_ply(pc, filename):
16 |     vertex = np.zeros(pc.shape[0], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
17 |     for i in range(pc.shape[0]):
18 |         vertex[i] = (pc[i][0], pc[i][1], pc[i][2])
19 |     ply_out = PlyData([PlyElement.describe(vertex, 'vertex', comments=['vertices'])])
20 |     ply_out.write(filename)
21 |
22 | # Sample points on the obj shape
23 | def get_sampling_command(obj_filename, ply_filename):
24 |     cmd = SAMPLING_BIN + ' ' + obj_filename
25 |     cmd += ' ' + ply_filename
26 |     cmd += ' -n_samples %d ' % SAMPLING_POINT_NUM
27 |     cmd += ' -leaf_size %f ' % SAMPLING_LEAF_SIZE
28 |     return cmd
29 |
30 | # --------------------------------------------------------------
31 | # Following are the helper functions to load MODELNET40 shapes
32 | # --------------------------------------------------------------
33 |
34 | # Read in the list of categories in MODELNET40
35 | def get_category_names():
36 |     shape_names_file = os.path.join(MODELNET40_PATH, 'shape_names.txt')
37 |     shape_names = [line.rstrip() for line in open(shape_names_file)]
38 |     return shape_names
39 |
40 | # Return all the filepaths for the shapes in MODELNET40
41 | def get_obj_filenames():
42 |     obj_filelist_file = os.path.join(MODELNET40_PATH, 'filelist.txt')
43 |     obj_filenames = [os.path.join(MODELNET40_PATH, line.rstrip()) for line in open(obj_filelist_file)]
44 |     print('Got %d obj files in modelnet40.' % len(obj_filenames))
45 |     return obj_filenames
46 |
47 | # Helper function to create the parent folder and all subdir folders if they do not exist
48 | def batch_mkdir(output_folder, subdir_list):
49 |     if not os.path.exists(output_folder):
50 |         os.mkdir(output_folder)
51 |     for subdir in subdir_list:
52 |         if not os.path.exists(os.path.join(output_folder, subdir)):
53 |             os.mkdir(os.path.join(output_folder, subdir))
54 |
55 | # ----------------------------------------------------------------
56 | # Following are the helper functions to save/load HDF5 files
57 | # ----------------------------------------------------------------
58 |
59 | # Write numpy array data, label and normal to h5_filename
60 | def save_h5_data_label_normal(h5_filename, data, label, normal,
61 |         data_dtype='float32', label_dtype='uint8', normal_dtype='float32'):
62 |     h5_fout = h5py.File(h5_filename, 'w')
63 |     h5_fout.create_dataset(
64 |             'data', data=data,
65 |             compression='gzip', compression_opts=4,
66 |             dtype=data_dtype)
67 |     h5_fout.create_dataset(
68 |             'normal', data=normal,
69 |             compression='gzip', compression_opts=4,
70 |             dtype=normal_dtype)
71 |     h5_fout.create_dataset(
72 |             'label', data=label,
73 |             compression='gzip', compression_opts=1,
74 |             dtype=label_dtype)
75 |     h5_fout.close()
76 |
77 |
78 | # Write numpy array data and label to h5_filename
79 | def save_h5(h5_filename, data, label, data_dtype='uint8', label_dtype='uint8'):
80 |     h5_fout = h5py.File(h5_filename, 'w')
81 |     h5_fout.create_dataset(
82 |             'data', data=data,
83 |             compression='gzip', compression_opts=4,
84 |             dtype=data_dtype)
85 |     h5_fout.create_dataset(
86 |             'label', data=label,
87 |             compression='gzip', compression_opts=1,
88 |             dtype=label_dtype)
89 |     h5_fout.close()
90 |
91 | # Read numpy array data, label and normal from h5_filename
92 | def load_h5_data_label_normal(h5_filename):
93 |     f = h5py.File(h5_filename, 'r')
94 |     data = f['data'][:]
95 |     label = f['label'][:]
96 |     normal = f['normal'][:]
97 |     return (data, label, normal)
98 |
99 | # Read numpy array data, label and part segmentation ids from h5_filename
100 | def load_h5_data_label_seg(h5_filename):
101 |     f = h5py.File(h5_filename, 'r')
102 |     data = f['data'][:]
103 |     label = f['label'][:]
104 |     seg = f['pid'][:]
105 |     return (data, label, seg)
106 |
107 | # Read numpy array data and label from h5_filename
108 | def load_h5(h5_filename):
109 |     f = h5py.File(h5_filename, 'r')
110 |     data = f['data'][:]
111 |     label = f['label'][:]
112 |     return (data, label)
113 |
114 | # ----------------------------------------------------------------
115 | # Following are the helper functions to save/load PLY files
116 | # ----------------------------------------------------------------
117 |
118 | # Load PLY file
119 | def load_ply_data(filename, point_num):
120 |     plydata = PlyData.read(filename)
121 |     pc = plydata['vertex'].data[:point_num]
122 |     pc_array = np.array([[x, y, z] for x,y,z in pc])
123 |     return pc_array
124 |
125 | # Load PLY file
126 | def load_ply_normal(filename, point_num):
127 |     plydata = PlyData.read(filename)
128 |     pc = plydata['normal'].data[:point_num]
129 |     pc_array = np.array([[x, y, z] for x,y,z in pc])
130 |     return pc_array
131 |
132 | # Pad an Nxk array with extra rows so it has `row` rows
133 | # `pad` is 'edge' or 'constant'
134 | def pad_arr_rows(arr, row, pad='edge'):
135 |     assert(len(arr.shape) == 2)
136 |     assert(arr.shape[0] <= row)
137 |     assert(pad == 'edge' or pad == 'constant')
138 |     if arr.shape[0] == row:
139 |         return arr
140 |     if pad == 'edge':
141 |         return np.lib.pad(arr, ((0, row-arr.shape[0]), (0, 0)), 'edge')
142 |     if pad == 'constant':
143 |         return np.lib.pad(arr, ((0, row-arr.shape[0]), (0, 0)), 'constant', constant_values=(0, 0))
144 |
145 |
146 |
--------------------------------------------------------------------------------
/tensorflow/utils/eulerangles.py:
--------------------------------------------------------------------------------
1 | # emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
2 | # vi: set ft=python sts=4 ts=4 sw=4 et:
3 | ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
4 | #
5 | #   See COPYING file distributed along with the NiBabel package for the
6 | #   copyright and license terms.
7 | #
8 | ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
9 | ''' Module implementing Euler angle rotations and their conversions
10 |
11 | See:
12 |
13 | * http://en.wikipedia.org/wiki/Rotation_matrix
14 | * http://en.wikipedia.org/wiki/Euler_angles
15 | * http://mathworld.wolfram.com/EulerAngles.html
16 |
17 | See also: *Representing Attitude with Euler Angles and Quaternions: A
18 | Reference* (2006) by James Diebel. A cached PDF link last found here:
19 |
20 | http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.110.5134
21 |
22 | Euler's rotation theorem tells us that any rotation in 3D can be
23 | described by 3 angles. Let's call the 3 angles the *Euler angle vector*
24 | and call the angles in the vector :math:`alpha`, :math:`beta` and
25 | :math:`gamma`. The vector is [ :math:`alpha`,
26 | :math:`beta`, :math:`gamma` ] and, in this description, the order of the
27 | parameters specifies the order in which the rotations occur (so the
28 | rotation corresponding to :math:`alpha` is applied first).
29 |
30 | In order to specify the meaning of an *Euler angle vector* we need to
31 | specify the axes around which each of the rotations corresponding to
32 | :math:`alpha`, :math:`beta` and :math:`gamma` will occur.
33 |
34 | There are therefore three axes for the rotations :math:`alpha`,
35 | :math:`beta` and :math:`gamma`; let's call them :math:`i`, :math:`j`,
36 | :math:`k`.
37 |
38 | Let us express the rotation :math:`alpha` around axis `i` as a 3 by 3
39 | rotation matrix `A`. Similarly :math:`beta` around `j` becomes a 3 x 3
40 | matrix `B` and :math:`gamma` around `k` becomes matrix `G`. Then the
41 | whole rotation expressed by the Euler angle vector [ :math:`alpha`,
42 | :math:`beta`, :math:`gamma` ], `R`, is given by::
43 |
44 |     R = np.dot(G, np.dot(B, A))
45 |
46 | See http://mathworld.wolfram.com/EulerAngles.html
47 |
48 | The order :math:`G B A` expresses the fact that the rotations are
49 | performed in the order of the vector (:math:`alpha` around axis `i` =
50 | `A` first).
51 |
52 | To convert a given Euler angle vector to a meaningful rotation, and a
53 | rotation matrix, we need to define:
54 |
55 | * the axes `i`, `j`, `k`
56 | * whether a rotation matrix should be applied on the left of a vector to
57 |   be transformed (vectors are column vectors) or on the right (vectors
58 |   are row vectors).
59 | * whether the rotations move the axes as they are applied (intrinsic
60 |   rotations) - compared to the situation where the axes stay fixed and the
61 |   vectors move within the axis frame (extrinsic)
62 | * the handedness of the coordinate system
63 |
64 | See: http://en.wikipedia.org/wiki/Rotation_matrix#Ambiguities
65 |
66 | We are using the following conventions:
67 |
68 | * axes `i`, `j`, `k` are the `z`, `y`, and `x` axes respectively. Thus
69 |   an Euler angle vector [ :math:`alpha`, :math:`beta`, :math:`gamma` ]
70 |   in our convention implies an :math:`alpha` radian rotation around the
71 |   `z` axis, followed by a :math:`beta` rotation around the `y` axis,
72 |   followed by a :math:`gamma` rotation around the `x` axis.
73 | * the rotation matrix applies on the left, to column vectors on the
74 |   right, so if `R` is the rotation matrix, and `v` is a 3 x N matrix
75 |   with N column vectors, the transformed vector set `vdash` is given by
76 |   ``vdash = np.dot(R, v)``.
77 | * extrinsic rotations - the axes are fixed, and do not move with the
78 |   rotations.
79 | * a right-handed coordinate system
80 |
81 | The convention of rotation around ``z``, followed by rotation around
82 | ``y``, followed by rotation around ``x``, is known (confusingly) as
83 | "xyz", pitch-roll-yaw, Cardan angles, or Tait-Bryan angles.
84 | '''
85 |
86 | import math
87 |
88 | import sys
89 | if sys.version_info >= (3,0):
90 |     from functools import reduce
91 |
92 | import numpy as np
93 |
94 |
95 | _FLOAT_EPS_4 = np.finfo(float).eps * 4.0
96 |
97 |
98 | def euler2mat(z=0, y=0, x=0):
99 |     ''' Return matrix for rotations around z, y and x axes
100 |
101 |     Uses the z, then y, then x convention above
102 |
103 |     Parameters
104 |     ----------
105 |     z : scalar
106 |        Rotation angle in radians around z-axis (performed first)
107 |     y : scalar
108 |        Rotation angle in radians around y-axis
109 |     x : scalar
110 |        Rotation angle in radians around x-axis (performed last)
111 |
112 |     Returns
113 |     -------
114 |     M : array shape (3,3)
115 |        Rotation matrix giving same rotation as for given angles
116 |
117 |     Examples
118 |     --------
119 |     >>> zrot = 1.3 # radians
120 |     >>> yrot = -0.1
121 |     >>> xrot = 0.2
122 |     >>> M = euler2mat(zrot, yrot, xrot)
123 |     >>> M.shape == (3, 3)
124 |     True
125 |
126 |     The output rotation matrix is equal to the composition of the
127 |     individual rotations
128 |
129 |     >>> M1 = euler2mat(zrot)
130 |     >>> M2 = euler2mat(0, yrot)
131 |     >>> M3 = euler2mat(0, 0, xrot)
132 |     >>> composed_M = np.dot(M3, np.dot(M2, M1))
133 |     >>> np.allclose(M, composed_M)
134 |     True
135 |
136 |     You can specify rotations by named arguments
137 |
138 |     >>> np.all(M3 == euler2mat(x=xrot))
139 |     True
140 |
141 |     When applying M to a vector, the vector should be a column vector to
142 |     the right of M. If the right hand side is a 2D array rather than a
143 |     vector, then each column of the 2D array represents a vector.
144 |
145 |     >>> vec = np.array([1, 0, 0]).reshape((3,1))
146 |     >>> v2 = np.dot(M, vec)
147 |     >>> vecs = np.array([[1, 0, 0],[0, 1, 0]]).T # giving 3x2 array
148 |     >>> vecs2 = np.dot(M, vecs)
149 |
150 |     Rotations are counter-clockwise.
151 |
152 |     >>> zred = np.dot(euler2mat(z=np.pi/2), np.eye(3))
153 |     >>> np.allclose(zred, [[0, -1, 0],[1, 0, 0], [0, 0, 1]])
154 |     True
155 |     >>> yred = np.dot(euler2mat(y=np.pi/2), np.eye(3))
156 |     >>> np.allclose(yred, [[0, 0, 1],[0, 1, 0], [-1, 0, 0]])
157 |     True
158 |     >>> xred = np.dot(euler2mat(x=np.pi/2), np.eye(3))
159 |     >>> np.allclose(xred, [[1, 0, 0],[0, 0, -1], [0, 1, 0]])
160 |     True
161 |
162 |     Notes
163 |     -----
164 |     The direction of rotation is given by the right-hand rule (orient
165 |     the thumb of the right hand along the axis around which the rotation
166 |     occurs, with the end of the thumb at the positive end of the axis;
167 |     curl your fingers; the direction your fingers curl is the direction
168 |     of rotation). Therefore, the rotations are counterclockwise if
169 |     looking along the axis of rotation from positive to negative.
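
    An added sanity check (illustrative, not in the original): composing
    zero rotations gives the identity matrix.

    >>> np.allclose(euler2mat(), np.eye(3))
    True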
170 |     '''
171 |     Ms = []
172 |     if z:
173 |         cosz = math.cos(z)
174 |         sinz = math.sin(z)
175 |         Ms.append(np.array(
176 |                 [[cosz, -sinz, 0],
177 |                  [sinz, cosz, 0],
178 |                  [0, 0, 1]]))
179 |     if y:
180 |         cosy = math.cos(y)
181 |         siny = math.sin(y)
182 |         Ms.append(np.array(
183 |                 [[cosy, 0, siny],
184 |                  [0, 1, 0],
185 |                  [-siny, 0, cosy]]))
186 |     if x:
187 |         cosx = math.cos(x)
188 |         sinx = math.sin(x)
189 |         Ms.append(np.array(
190 |                 [[1, 0, 0],
191 |                  [0, cosx, -sinx],
192 |                  [0, sinx, cosx]]))
193 |     if Ms:
194 |         return reduce(np.dot, Ms[::-1])
195 |     return np.eye(3)
196 |
197 |
198 | def mat2euler(M, cy_thresh=None):
199 |     ''' Discover Euler angle vector from 3x3 matrix
200 |
201 |     Uses the conventions above.
202 |
203 |     Parameters
204 |     ----------
205 |     M : array-like, shape (3,3)
206 |     cy_thresh : None or scalar, optional
207 |        threshold below which to give up on straightforward arctan for
208 |        estimating x rotation. If None (default), estimate from
209 |        precision of input.
210 |
211 |     Returns
212 |     -------
213 |     z : scalar
214 |     y : scalar
215 |     x : scalar
216 |        Rotations in radians around z, y, x axes, respectively
217 |
218 |     Notes
219 |     -----
220 |     If there was no numerical error, the routine could be derived using
221 |     the Sympy expression for the z then y then x rotation matrix, which is::
222 |
223 |       [                       cos(y)*cos(z),                       -cos(y)*sin(z),         sin(y)],
224 |       [cos(x)*sin(z) + cos(z)*sin(x)*sin(y), cos(x)*cos(z) - sin(x)*sin(y)*sin(z), -cos(y)*sin(x)],
225 |       [sin(x)*sin(z) - cos(x)*cos(z)*sin(y), cos(z)*sin(x) + cos(x)*sin(y)*sin(z),  cos(x)*cos(y)]
226 |
227 |     with the obvious derivations for z, y, and x::
228 |
229 |        z = atan2(-r12, r11)
230 |        y = asin(r13)
231 |        x = atan2(-r23, r33)
232 |
233 |     Problems arise when cos(y) is close to zero, because both of::
234 |
235 |        z = atan2(cos(y)*sin(z), cos(y)*cos(z))
236 |        x = atan2(cos(y)*sin(x), cos(x)*cos(y))
237 |
238 |     will be close to atan2(0, 0), and highly unstable.
239 |
240 |     The ``cy`` fix for numerical instability below is from: *Graphics
241 |     Gems IV*, Paul Heckbert (editor), Academic Press, 1994, ISBN:
242 |     0123361559. Specifically it comes from EulerAngles.c by Ken
243 |     Shoemake, and deals with the case where cos(y) is close to zero:
244 |
245 |     See: http://www.graphicsgems.org/
246 |
247 |     The code appears to be licensed (from the website) as "can be used
248 |     without restrictions".
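
    An added illustrative round trip (not in the original; angles chosen so
    cos(y) stays away from zero):

    >>> z, y, x = mat2euler(euler2mat(0.5, -0.3, 0.2))
    >>> np.allclose([z, y, x], [0.5, -0.3, 0.2])
    True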
249 |     '''
250 |     M = np.asarray(M)
251 |     if cy_thresh is None:
252 |         try:
253 |             cy_thresh = np.finfo(M.dtype).eps * 4
254 |         except ValueError:
255 |             cy_thresh = _FLOAT_EPS_4
256 |     r11, r12, r13, r21, r22, r23, r31, r32, r33 = M.flat
257 |     # cy: sqrt((cos(y)*cos(z))**2 + (cos(x)*cos(y))**2)
258 |     cy = math.sqrt(r33*r33 + r23*r23)
259 |     if cy > cy_thresh: # cos(y) not close to zero, standard form
260 |         z = math.atan2(-r12, r11) # atan2(cos(y)*sin(z), cos(y)*cos(z))
261 |         y = math.atan2(r13, cy) # atan2(sin(y), cy)
262 |         x = math.atan2(-r23, r33) # atan2(cos(y)*sin(x), cos(x)*cos(y))
263 |     else: # cos(y) (close to) zero, so x -> 0.0 (see above)
264 |         # so r21 -> sin(z), r22 -> cos(z), and z comes from atan2(r21, r22)
265 |         z = math.atan2(r21, r22)
266 |         y = math.atan2(r13, cy) # atan2(sin(y), cy)
267 |         x = 0.0
268 |     return z, y, x
269 |
270 |
271 | def euler2quat(z=0, y=0, x=0):
272 |     ''' Return quaternion corresponding to these Euler angles
273 |
274 |     Uses the z, then y, then x convention above
275 |
276 |     Parameters
277 |     ----------
278 |     z : scalar
279 |        Rotation angle in radians around z-axis (performed first)
280 |     y : scalar
281 |        Rotation angle in radians around y-axis
282 |     x : scalar
283 |        Rotation angle in radians around x-axis (performed last)
284 |
285 |     Returns
286 |     -------
287 |     quat : array shape (4,)
288 |        Quaternion in w, x, y, z (real, then vector) format
289 |
290 |     Notes
291 |     -----
292 |     We can derive this formula in Sympy using:
293 |
294 |     1. Formula giving quaternion corresponding to rotation of theta radians
295 |        about arbitrary axis:
296 |        http://mathworld.wolfram.com/EulerParameters.html
297 |     2. Generated formulae from 1.) for quaternions corresponding to
298 |        theta radians rotations about ``x, y, z`` axes
299 |     3. Apply quaternion multiplication formula -
300 |        http://en.wikipedia.org/wiki/Quaternions#Hamilton_product - to
301 |        formulae from 2.) to give formula for combined rotations.
302 |     '''
303 |     z = z/2.0
304 |     y = y/2.0
305 |     x = x/2.0
306 |     cz = math.cos(z)
307 |     sz = math.sin(z)
308 |     cy = math.cos(y)
309 |     sy = math.sin(y)
310 |     cx = math.cos(x)
311 |     sx = math.sin(x)
312 |     return np.array([
313 |              cx*cy*cz - sx*sy*sz,
314 |              cx*sy*sz + cy*cz*sx,
315 |              cx*cz*sy - sx*cy*sz,
316 |              cx*cy*sz + sx*cz*sy])
317 |
318 |
319 | def quat2euler(q):
320 |     ''' Return Euler angles corresponding to quaternion `q`
321 |
322 |     Parameters
323 |     ----------
324 |     q : 4 element sequence
325 |        w, x, y, z of quaternion
326 |
327 |     Returns
328 |     -------
329 |     z : scalar
330 |        Rotation angle in radians around z-axis (performed first)
331 |     y : scalar
332 |        Rotation angle in radians around y-axis
333 |     x : scalar
334 |        Rotation angle in radians around x-axis (performed last)
335 |
336 |     Notes
337 |     -----
338 |     It's possible to reduce the amount of calculation a little, by
339 |     combining parts of the ``quat2mat`` and ``mat2euler`` functions, but
340 |     the reduction in computation is small, and the code repetition is
341 |     large.
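
    An added illustrative check (not in the original; requires ``nibabel``,
    so it is skipped under doctest): the identity quaternion maps to
    all-zero angles.

    >>> np.allclose(quat2euler([1, 0, 0, 0]), 0)  # doctest: +SKIP
    True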
342 |     '''
343 |     # delayed import to avoid cyclic dependencies
344 |     import nibabel.quaternions as nq
345 |     return mat2euler(nq.quat2mat(q))
346 |
347 |
348 | def euler2angle_axis(z=0, y=0, x=0):
349 |     ''' Return angle, axis corresponding to these Euler angles
350 |
351 |     Uses the z, then y, then x convention above
352 |
353 |     Parameters
354 |     ----------
355 |     z : scalar
356 |        Rotation angle in radians around z-axis (performed first)
357 |     y : scalar
358 |        Rotation angle in radians around y-axis
359 |     x : scalar
360 |        Rotation angle in radians around x-axis (performed last)
361 |
362 |     Returns
363 |     -------
364 |     theta : scalar
365 |        angle of rotation
366 |     vector : array shape (3,)
367 |        axis around which rotation occurs
368 |
369 |     Examples
370 |     --------
371 |     >>> theta, vec = euler2angle_axis(0, 1.5, 0)
372 |     >>> print(theta)
373 |     1.5
374 |     >>> np.allclose(vec, [0, 1, 0])
375 |     True
376 |     '''
377 |     # delayed import to avoid cyclic dependencies
378 |     import nibabel.quaternions as nq
379 |     return nq.quat2angle_axis(euler2quat(z, y, x))
380 |
381 |
382 | def angle_axis2euler(theta, vector, is_normalized=False):
383 |     ''' Convert angle, axis pair to Euler angles
384 |
385 |     Parameters
386 |     ----------
387 |     theta : scalar
388 |        angle of rotation
389 |     vector : 3 element sequence
390 |        vector specifying axis for rotation.
391 |     is_normalized : bool, optional
392 |        True if vector is already normalized (has norm of 1). Default
393 |        False
394 |
395 |     Returns
396 |     -------
397 |     z : scalar
398 |     y : scalar
399 |     x : scalar
400 |        Rotations in radians around z, y, x axes, respectively
401 |
402 |     Examples
403 |     --------
404 |     >>> z, y, x = angle_axis2euler(0, [1, 0, 0])
405 |     >>> np.allclose((z, y, x), 0)
406 |     True
407 |
408 |     Notes
409 |     -----
410 |     It's possible to reduce the amount of calculation a little, by
411 |     combining parts of the ``angle_axis2mat`` and ``mat2euler``
412 |     functions, but the reduction in computation is small, and the code
413 |     repetition is large.
414 |     '''
415 |     # delayed import to avoid cyclic dependencies
416 |     import nibabel.quaternions as nq
417 |     M = nq.angle_axis2mat(theta, vector, is_normalized)
418 |     return mat2euler(M)
419 |
--------------------------------------------------------------------------------
/tensorflow/utils/pc_util.py:
--------------------------------------------------------------------------------
1 | """ Utility functions for processing point clouds.
2 |
3 | Author: Charles R. Qi, Hao Su
4 | Date: November 2016
5 | """
6 |
7 | import os
8 | import sys
9 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
10 | sys.path.append(BASE_DIR)
11 |
12 | # Draw point cloud
13 | from eulerangles import euler2mat
14 |
15 | # Point cloud IO
16 | import numpy as np
17 | from plyfile import PlyData, PlyElement
18 |
19 |
20 | # ----------------------------------------
21 | # Point Cloud/Volume Conversions
22 | # ----------------------------------------
23 |
24 | def point_cloud_to_volume_batch(point_clouds, vsize=12, radius=1.0, flatten=True):
25 |     """ Input is BxNx3 batch of point cloud
26 |         Output is Bx(vsize^3)
27 |     """
28 |     vol_list = []
29 |     for b in range(point_clouds.shape[0]):
30 |         vol = point_cloud_to_volume(np.squeeze(point_clouds[b,:,:]), vsize, radius)
31 |         if flatten:
32 |             vol_list.append(vol.flatten())
33 |         else:
34 |             vol_list.append(np.expand_dims(np.expand_dims(vol, -1), 0))
35 |     if flatten:
36 |         return np.vstack(vol_list)
37 |     else:
38 |         return np.concatenate(vol_list, 0)
39 |
40 |
41 | def point_cloud_to_volume(points, vsize, radius=1.0):
42 |     """ input is Nx3 points.
43 |         output is vsize*vsize*vsize
44 |         assumes points are in range [-radius, radius]
45 |     """
46 |     vol = np.zeros((vsize,vsize,vsize))
47 |     voxel = 2*radius/float(vsize)
48 |     locations = (points + radius)/voxel
49 |     locations = locations.astype(int)
50 |     vol[locations[:,0],locations[:,1],locations[:,2]] = 1.0
51 |     return vol
52 |
53 | #a = np.zeros((16,1024,3))
54 | #print(point_cloud_to_volume_batch(a, 12, 1.0, False).shape)
55 |
56 | def volume_to_point_cloud(vol):
57 |     """ vol is occupancy grid (value = 0 or 1) of size vsize*vsize*vsize
58 |         return Nx3 numpy array.
59 |     """
60 |     vsize = vol.shape[0]
61 |     assert(vol.shape[1] == vsize and vol.shape[2] == vsize)
62 |     points = []
63 |     for a in range(vsize):
64 |         for b in range(vsize):
65 |             for c in range(vsize):
66 |                 if vol[a,b,c] == 1:
67 |                     points.append(np.array([a,b,c]))
68 |     if len(points) == 0:
69 |         return np.zeros((0,3))
70 |     points = np.vstack(points)
71 |     return points
72 |
73 | # ----------------------------------------
74 | # Point cloud IO
75 | # ----------------------------------------
76 |
77 | def read_ply(filename):
78 |     """ read XYZ point cloud from filename PLY file """
79 |     plydata = PlyData.read(filename)
80 |     pc = plydata['vertex'].data
81 |     pc_array = np.array([[x, y, z] for x,y,z in pc])
82 |     return pc_array
83 |
84 |
85 | def write_ply(points, filename, text=True):
86 |     """ input: Nx3, write points to filename as PLY format. """
87 |     points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])]
88 |     vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')])
89 |     el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
90 |     PlyData([el], text=text).write(filename)
91 |
92 |
93 | # ----------------------------------------
94 | # Simple Point cloud and Volume Renderers
95 | # ----------------------------------------
96 |
97 | def draw_point_cloud(input_points, canvasSize=500, space=200, diameter=25,
98 |                      xrot=0, yrot=0, zrot=0, switch_xyz=[0,1,2], normalize=True):
99 |     """ Render point cloud to a grayscale image.
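        (Added note, not in the original) Points are rotated by
        euler2mat(zrot, yrot, xrot), optionally normalized to fit a unit
        sphere, orthographically projected onto the canvas, and splatted
        with a precomputed Gaussian disk in depth order.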
100 |         Input:
101 |             points: Nx3 numpy array (+y is up direction)
102 |         Output:
103 |             gray image as numpy array of size canvasSizexcanvasSize
104 |     """
105 |     image = np.zeros((canvasSize, canvasSize))
106 |     if input_points is None or input_points.shape[0] == 0:
107 |         return image
108 |
109 |     points = input_points[:, switch_xyz]
110 |     M = euler2mat(zrot, yrot, xrot)
111 |     points = (np.dot(M, points.transpose())).transpose()
112 |
113 |     # Normalize the point cloud
114 |     # We normalize scale to fit points in a unit sphere
115 |     if normalize:
116 |         centroid = np.mean(points, axis=0)
117 |         points -= centroid
118 |         furthest_distance = np.max(np.sqrt(np.sum(abs(points)**2,axis=-1)))
119 |         points /= furthest_distance
120 |
121 |     # Pre-compute the Gaussian disk
122 |     radius = (diameter-1)/2.0
123 |     disk = np.zeros((diameter, diameter))
124 |     for i in range(diameter):
125 |         for j in range(diameter):
126 |             if (i - radius) * (i-radius) + (j-radius) * (j-radius) <= radius * radius:
127 |                 disk[i, j] = np.exp((-(i-radius)**2 - (j-radius)**2)/(radius**2))
128 |     mask = np.argwhere(disk > 0)
129 |     dx = mask[:, 0]
130 |     dy = mask[:, 1]
131 |     dv = disk[disk > 0]
132 |
133 |     # Order points by z-buffer
134 |     zorder = np.argsort(points[:, 2])
135 |     points = points[zorder, :]
136 |     points[:, 2] = (points[:, 2] - np.min(points[:, 2])) / (np.max(points[:, 2]) - np.min(points[:, 2]))
137 |     max_depth = np.max(points[:, 2])
138 |
139 |     for i in range(points.shape[0]):
140 |         j = points.shape[0] - i - 1
141 |         x = points[j, 0]
142 |         y = points[j, 1]
143 |         xc = canvasSize/2 + (x*space)
144 |         yc = canvasSize/2 + (y*space)
145 |         xc = int(np.round(xc))
146 |         yc = int(np.round(yc))
147 |
148 |         px = dx + xc
149 |         py = dy + yc
150 |
151 |         image[px, py] = image[px, py] * 0.7 + dv * (max_depth - points[j, 2]) * 0.3
152 |
153 |     image = image / np.max(image)
154 |     return image
155 |
156 | def point_cloud_three_views(points):
157 |     """ input points Nx3 numpy array (+y is up direction).
158 |         return a numpy array gray image of size 500x1500.
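
        Illustrative usage (added note, not in the original):

            imgs = point_cloud_three_views(np.random.rand(1024, 3) - 0.5)
            imgs.shape   # (500, 1500) with the default canvas size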
""" 159 | # +y is up direction 160 | # xrot is azimuth 161 | # yrot is in-plane 162 | # zrot is elevation 163 | img1 = draw_point_cloud(points, zrot=110/180.0*np.pi, xrot=45/180.0*np.pi, yrot=0/180.0*np.pi) 164 | img2 = draw_point_cloud(points, zrot=70/180.0*np.pi, xrot=135/180.0*np.pi, yrot=0/180.0*np.pi) 165 | img3 = draw_point_cloud(points, zrot=180.0/180.0*np.pi, xrot=90/180.0*np.pi, yrot=0/180.0*np.pi) 166 | image_large = np.concatenate([img1, img2, img3], 1) 167 | return image_large 168 | 169 | 170 | from PIL import Image 171 | def point_cloud_three_views_demo(): 172 | """ Demo for draw_point_cloud function """ 173 | points = read_ply('../third_party/mesh_sampling/piano.ply') 174 | im_array = point_cloud_three_views(points) 175 | img = Image.fromarray(np.uint8(im_array*255.0)) 176 | img.save('piano.jpg') 177 | 178 | if __name__=="__main__": 179 | point_cloud_three_views_demo() 180 | 181 | 182 | import matplotlib.pyplot as plt 183 | def pyplot_draw_point_cloud(points, output_filename): 184 | """ points is a Nx3 numpy array """ 185 | fig = plt.figure() 186 | ax = fig.add_subplot(111, projection='3d') 187 | ax.scatter(points[:,0], points[:,1], points[:,2]) 188 | ax.set_xlabel('x') 189 | ax.set_ylabel('y') 190 | ax.set_zlabel('z') 191 | #savefig(output_filename) 192 | 193 | def pyplot_draw_volume(vol, output_filename): 194 | """ vol is of size vsize*vsize*vsize 195 | output an image to output_filename 196 | """ 197 | points = volume_to_point_cloud(vol) 198 | pyplot_draw_point_cloud(points, output_filename) 199 | --------------------------------------------------------------------------------