├── lib ├── __init__.py ├── loss.py ├── utils.py ├── scene_util.py ├── config.py ├── projection.py └── dataset.py ├── requirements.txt ├── pointnet2 ├── _version.py ├── _ext_src │ ├── include │ │ ├── ball_query.h │ │ ├── group_points.h │ │ ├── sampling.h │ │ ├── interpolate.h │ │ ├── utils.h │ │ └── cuda_utils.h │ └── src │ │ ├── bindings.cpp │ │ ├── ball_query.cpp │ │ ├── ball_query_gpu.cu │ │ ├── group_points.cpp │ │ ├── group_points_gpu.cu │ │ ├── sampling.cpp │ │ ├── interpolate.cpp │ │ ├── interpolate_gpu.cu │ │ └── sampling_gpu.cu ├── src │ ├── cuda_utils.h │ ├── ball_query_gpu.h │ ├── group_points_gpu.h │ ├── ball_query.cpp │ ├── sampling_gpu.h │ ├── pointnet2_api.cpp │ ├── interpolate_gpu.h │ ├── group_points.cpp │ ├── sampling.cpp │ ├── interpolate.cpp │ ├── ball_query_gpu.cu │ ├── group_points_gpu.cu │ ├── interpolate_gpu.cu │ └── sampling_gpu.cu ├── pointnet2_test.py ├── setup.py ├── pointnet2_semseg.py ├── pytorch_utils.py └── pointnet2_utils.py ├── img ├── snapshot.png └── snapshot_pred.png ├── data ├── scannetv2_enet.pth ├── scannetv2_test.txt └── scannetv2_val.txt ├── .gitignore ├── slurm ├── prep.job ├── compute_multiview_features.job ├── compute_multiview_projection.job ├── project_multiview_features.job ├── train.job ├── eval.job └── visualize.job ├── preprocessing ├── scannet_util.py ├── visualize_prep_scene.py └── collect_scannet_scenes.py ├── LICENSE ├── scripts ├── compute_multiview_features.py ├── compute_multiview_projection.py ├── visualize.py ├── train.py ├── project_multiview_features.py └── eval.py └── README.md /lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | prefetch_generator -------------------------------------------------------------------------------- /pointnet2/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "3.0.0" 2 | -------------------------------------------------------------------------------- /img/snapshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daveredrum/Pointnet2.ScanNet/HEAD/img/snapshot.png -------------------------------------------------------------------------------- /img/snapshot_pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daveredrum/Pointnet2.ScanNet/HEAD/img/snapshot_pred.png -------------------------------------------------------------------------------- /data/scannetv2_enet.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daveredrum/Pointnet2.ScanNet/HEAD/data/scannetv2_enet.pth -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 5 | const int nsample); 6 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/group_points.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor group_points(at::Tensor points, 
at::Tensor idx); 5 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/sampling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # cache 2 | __pycache__ 3 | lib/__pycache__ 4 | pointnet2/__pycache__ 5 | preprocessing/__pycache__ 6 | 7 | # pointnet2 compiled 8 | pointnet2/build 9 | pointnet2/dist 10 | pointnet2/pointnet2.egg-info 11 | 12 | # prep data 13 | preprocessing/scannet_scenes 14 | preprocessing/label_point_clouds/ 15 | 16 | # outputs 17 | outputs/ 18 | logs/ -------------------------------------------------------------------------------- /pointnet2/src/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | 6 | #define TOTAL_THREADS 1024 7 | #define THREADS_PER_BLOCK 256 8 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) 9 | 10 | inline int opt_n_threads(int work_size) { 11 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 12 | 13 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 14 | } 15 | #endif 16 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 7 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 8 | at::Tensor weight); 9 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 10 | at::Tensor weight, const int m); 11 | -------------------------------------------------------------------------------- /pointnet2/src/ball_query_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _BALL_QUERY_GPU_H 2 | #define _BALL_QUERY_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, 10 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor); 11 | 12 | void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, 13 | const float *xyz, const float *new_xyz, int *idx, cudaStream_t stream); 14 | 15 | #endif 16 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "group_points.h" 3 | #include "interpolate.h" 4 | #include "sampling.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("gather_points", &gather_points); 8 | m.def("gather_points_grad", &gather_points_grad); 9 | m.def("furthest_point_sampling", &furthest_point_sampling); 10 | 11 | m.def("three_nn", &three_nn); 12 | m.def("three_interpolate", &three_interpolate); 13 | m.def("three_interpolate_grad", 
&three_interpolate_grad); 14 | 15 | m.def("ball_query", &ball_query); 16 | 17 | m.def("group_points", &group_points); 18 | m.def("group_points_grad", &group_points_grad); 19 | } 20 | -------------------------------------------------------------------------------- /lib/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class WeightedCrossEntropyLoss(nn.Module): 6 | def __init__(self, ignore_index=-100): 7 | super(WeightedCrossEntropyLoss, self).__init__() 8 | self.ignore_index = ignore_index 9 | 10 | def forward(self, inputs, targets, weights=None): 11 | assert inputs.size(0) == targets.size(0) and (weights is None or weights.size(0) == inputs.size(0)) 12 | 13 | loss = F.cross_entropy(input=inputs, target=targets, reduction="none", ignore_index=self.ignore_index) 14 | if weights is not None: 15 | loss = torch.mean(loss * weights) 16 | else: 17 | loss = torch.mean(loss) 18 | 19 | return loss -------------------------------------------------------------------------------- /slurm/prep.job: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -p debug 4 | #SBATCH -q normal 5 | #SBATCH --job-name=prep # Job name 6 | #SBATCH --mail-type=BEGIN,END,FAIL # Mail events (NONE, BEGIN, END, FAIL, ALL) 7 | #SBATCH --mail-user=zhenyu.chen@tum.de # Where to send mail 8 | #SBATCH --mem=60gb # Job memory request 9 | #SBATCH --cpus-per-gpu=8 # Job CPUs request 10 | 11 | # #SBATCH --time=48:00:00 # Time limit hrs:min:sec 12 | #SBATCH --output=/rhome/dchen/Pointnet2.ScanNet/logs/%j.log # Standard output and error log 13 | 14 | # Default output information 15 | date;hostname;pwd 16 | 17 | # scripts 18 | python preprocessing/collect_scannet_scenes.py -------------------------------------------------------------------------------- /slurm/compute_multiview_features.job: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -p debug 4 | #SBATCH -q normal 5 | #SBATCH --job-name=prep # Job name 6 | #SBATCH --mail-type=BEGIN,END,FAIL # Mail events (NONE, BEGIN, END, FAIL, ALL) 7 | #SBATCH --mail-user=zhenyu.chen@tum.de # Where to send mail 8 | #SBATCH --mem=60gb # Job memory request 9 | #SBATCH --cpus-per-gpu=8 # Job CPUs request 10 | #SBATCH --gpus=rtx_2080:1 # Job GPUs request 11 | 12 | # #SBATCH --time=48:00:00 # Time limit hrs:min:sec 13 | #SBATCH --output=/rhome/dchen/Pointnet2.ScanNet/logs/%j.log # Standard output and error log 14 | 15 | # Default output information 16 | date;hostname;pwd 17 | 18 | # scripts 19 | python scripts/compute_multiview_features.py -------------------------------------------------------------------------------- /slurm/compute_multiview_projection.job: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -p debug 4 | #SBATCH -q normal 5 | #SBATCH --job-name=prep # Job name 6 | #SBATCH --mail-type=BEGIN,END,FAIL # Mail events (NONE, BEGIN, END, FAIL, ALL) 7 | #SBATCH --mail-user=zhenyu.chen@tum.de # Where to send mail 8 | #SBATCH --mem=60gb # Job memory request 9 | #SBATCH --cpus-per-gpu=8 # Job CPUs request 10 | #SBATCH --gpus=rtx_2080:1 # Job GPUs request 11 | 12 | # #SBATCH --time=48:00:00 # Time limit hrs:min:sec 13 | #SBATCH --output=/rhome/dchen/Pointnet2.ScanNet/logs/%j.log # Standard output and error log 14 | 15 | # Default output information 16 | date;hostname;pwd 17 | 18 | # scripts 19 | python 
scripts/compute_multiview_projection.py -------------------------------------------------------------------------------- /slurm/project_multiview_features.job: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -p debug 4 | #SBATCH -q normal 5 | #SBATCH --job-name=prep # Job name 6 | #SBATCH --mail-type=BEGIN,END,FAIL # Mail events (NONE, BEGIN, END, FAIL, ALL) 7 | #SBATCH --mail-user=zhenyu.chen@tum.de # Where to send mail 8 | #SBATCH --mem=60gb # Job memory request 9 | #SBATCH --cpus-per-gpu=8 # Job CPUs request 10 | #SBATCH --gpus=rtx_2080:1 # Job GPUs request 11 | 12 | # #SBATCH --time=48:00:00 # Time limit hrs:min:sec 13 | #SBATCH --output=/rhome/dchen/Pointnet2.ScanNet/logs/%j.log # Standard output and error log 14 | 15 | # Default output information 16 | date;hostname;pwd 17 | 18 | # scripts 19 | python scripts/project_multiview_features.py -------------------------------------------------------------------------------- /slurm/train.job: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -p debug 4 | #SBATCH -q normal 5 | #SBATCH --job-name=train # Job name 6 | #SBATCH --mail-type=BEGIN,END,FAIL # Mail events (NONE, BEGIN, END, FAIL, ALL) 7 | #SBATCH --mail-user=zhenyu.chen@tum.de # Where to send mail 8 | #SBATCH --mem=100gb # Job memory request 9 | #SBATCH --cpus-per-gpu=8 # Job CPUs request 10 | #SBATCH --gpus=rtx_3090:1 11 | 12 | # #SBATCH --time=48:00:00 # Time limit hrs:min:sec 13 | #SBATCH --output=/rhome/dchen/Pointnet2.ScanNet/logs/%j.log # Standard output and error log 14 | 15 | # Default output information 16 | date;hostname;pwd 17 | 18 | # scripts 19 | python scripts/train.py --use_multiview --use_normal --tag ssg 20 | # python scripts/train.py --use_multiview --use_normal --use_msg --tag msg 21 | -------------------------------------------------------------------------------- /slurm/eval.job: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -p debug 4 | #SBATCH -q normal 5 | #SBATCH --job-name=eval # Job name 6 | #SBATCH --mail-type=BEGIN,END,FAIL # Mail events (NONE, BEGIN, END, FAIL, ALL) 7 | #SBATCH --mail-user=zhenyu.chen@tum.de # Where to send mail 8 | #SBATCH --mem=60gb # Job memory request 9 | #SBATCH --cpus-per-gpu=8 # Job CPUs request 10 | #SBATCH --gpus=rtx_2080:1 11 | 12 | # #SBATCH --time=48:00:00 # Time limit hrs:min:sec 13 | #SBATCH --output=/rhome/dchen/Pointnet2.ScanNet/logs/%j.log # Standard output and error log 14 | 15 | # Default output information 16 | date;hostname;pwd 17 | 18 | # scripts 19 | # python scripts/eval.py --folder 2021-07-30_21-39-11_SSG --use_multiview --use_normal 20 | python scripts/eval.py --folder 2021-07-30_21-39-10_MSG --use_multiview --use_normal --use_msg -------------------------------------------------------------------------------- /preprocessing/scannet_util.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(".") 3 | from lib.config import CONF 4 | 5 | g_label_names = CONF.NYUCLASSES 6 | 7 | def get_raw2scannet_label_map(): 8 | # lines = [line.rstrip() for line in open('scannet-labels.combined.tsv')] 9 | lines = [line.rstrip() for line in open('preprocessing/scannetv2-labels.combined.tsv')] 10 | lines = lines[1:] 11 | raw2scannet = {} 12 | for i in range(len(lines)): 13 | label_classes_set = set(g_label_names) 14 | elements = lines[i].split('\t') 15 | # raw_name = 
elements[0] 16 | # nyu40_name = elements[6] 17 | raw_name = elements[1] 18 | nyu40_name = elements[7] 19 | if nyu40_name not in label_classes_set: 20 | raw2scannet[raw_name] = 'otherprop' 21 | else: 22 | raw2scannet[raw_name] = nyu40_name 23 | return raw2scannet 24 | 25 | 26 | g_raw2scannet = get_raw2scannet_label_map() 27 | -------------------------------------------------------------------------------- /pointnet2/src/group_points_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _GROUP_POINTS_GPU_H 2 | #define _GROUP_POINTS_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, 11 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); 12 | 13 | void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 14 | const float *points, const int *idx, float *out, cudaStream_t stream); 15 | 16 | int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, 17 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); 18 | 19 | void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 20 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); 21 | 22 | #endif 23 | -------------------------------------------------------------------------------- /slurm/visualize.job: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -p debug 4 | #SBATCH -q normal 5 | #SBATCH --job-name=prep # Job name 6 | #SBATCH --mail-type=BEGIN,END,FAIL # Mail events (NONE, BEGIN, END, FAIL, ALL) 7 | #SBATCH --mail-user=zhenyu.chen@tum.de # Where to send mail 8 | #SBATCH --mem=60gb # Job memory request 9 | #SBATCH --cpus-per-gpu=8 # Job CPUs request 10 | #SBATCH --gpus=rtx_2080:1 11 | 12 | # #SBATCH --time=48:00:00 # Time limit hrs:min:sec 13 | #SBATCH --output=/rhome/dchen/Pointnet2.ScanNet/logs/%j.log # Standard output and error log 14 | 15 | # Default output information 16 | date;hostname;pwd 17 | 18 | # scripts 19 | # python preprocessing/visualize_prep_scene.py --scene_id scene0000_00 20 | python scripts/visualize.py --use_color --use_normal --use_msg --folder 2021-07-29_11-24-45_MSG --scene_id scene0427_00 -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | def get_eta(start, end, extra, num_left): 2 | exe_s = end - start 3 | eta_s = (exe_s + extra) * num_left 4 | eta = {'h': 0, 'm': 0, 's': 0} 5 | if eta_s < 60: 6 | eta['s'] = int(eta_s) 7 | elif eta_s >= 60 and eta_s < 3600: 8 | eta['m'] = int(eta_s / 60) 9 | eta['s'] = int(eta_s % 60) 10 | else: 11 | eta['h'] = int(eta_s / (60 * 60)) 12 | eta['m'] = int(eta_s % (60 * 60) / 60) 13 | eta['s'] = int(eta_s % (60 * 60) % 60) 14 | 15 | return eta 16 | 17 | def decode_eta(eta_sec): 18 | eta = {'h': 0, 'm': 0, 's': 0} 19 | if eta_sec < 60: 20 | eta['s'] = int(eta_sec) 21 | elif eta_sec >= 60 and eta_sec < 3600: 22 | eta['m'] = int(eta_sec / 60) 23 | eta['s'] = int(eta_sec % 60) 24 | else: 25 | eta['h'] = int(eta_sec / (60 * 60)) 26 | eta['m'] = int(eta_sec % (60 * 60) / 60) 27 | eta['s'] = int(eta_sec % (60 * 60) % 60) 28 | 29 | return eta -------------------------------------------------------------------------------- /pointnet2/src/ball_query.cpp: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "ball_query_gpu.h" 7 | 8 | extern THCState *state; 9 | 10 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") 11 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) 13 | 14 | int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, 15 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) { 16 | CHECK_INPUT(new_xyz_tensor); 17 | CHECK_INPUT(xyz_tensor); 18 | const float *new_xyz = new_xyz_tensor.data(); 19 | const float *xyz = xyz_tensor.data(); 20 | int *idx = idx_tensor.data(); 21 | 22 | cudaStream_t stream = THCState_getCurrentStream(state); 23 | ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream); 24 | return 1; 25 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Dave Z. Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #define CHECK_CUDA(x) \ 6 | do { \ 7 | AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \ 8 | } while (0) 9 | 10 | #define CHECK_CONTIGUOUS(x) \ 11 | do { \ 12 | AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_IS_INT(x) \ 16 | do { \ 17 | AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \ 18 | #x " must be an int tensor"); \ 19 | } while (0) 20 | 21 | #define CHECK_IS_FLOAT(x) \ 22 | do { \ 23 | AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \ 24 | #x " must be a float tensor"); \ 25 | } while (0) 26 | -------------------------------------------------------------------------------- /pointnet2/pointnet2_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | ''' Testing customized ops. 
''' 7 | 8 | import torch 9 | from torch.autograd import gradcheck 10 | import numpy as np 11 | 12 | import os 13 | import sys 14 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 15 | sys.path.append(BASE_DIR) 16 | import pointnet2_utils 17 | 18 | def test_interpolation_grad(): 19 | batch_size = 1 20 | feat_dim = 2 21 | m = 4 22 | feats = torch.randn(batch_size, feat_dim, m, requires_grad=True).float().cuda() 23 | 24 | def interpolate_func(inputs): 25 | idx = torch.from_numpy(np.array([[[0,1,2],[1,2,3]]])).int().cuda() 26 | weight = torch.from_numpy(np.array([[[1,1,1],[2,2,2]]])).float().cuda() 27 | interpolated_feats = pointnet2_utils.three_interpolate(inputs, idx, weight) 28 | return interpolated_feats 29 | 30 | assert (gradcheck(interpolate_func, feats, atol=1e-1, rtol=1e-1)) 31 | 32 | if __name__=='__main__': 33 | test_interpolation_grad() 34 | -------------------------------------------------------------------------------- /pointnet2/src/sampling_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _SAMPLING_GPU_H 2 | #define _SAMPLING_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | 9 | int gather_points_wrapper_fast(int b, int c, int n, int npoints, 10 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); 11 | 12 | void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, 13 | const float *points, const int *idx, float *out, cudaStream_t stream); 14 | 15 | 16 | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, 17 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); 18 | 19 | void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, 20 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); 21 | 22 | 23 | int furthest_point_sampling_wrapper(int b, int n, int m, 24 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor); 25 | 26 | void furthest_point_sampling_kernel_launcher(int b, int n, int m, 27 | const float *dataset, float *temp, int *idxs, cudaStream_t stream); 28 | 29 | #endif 30 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "utils.h" 3 | 4 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 5 | int nsample, const float *new_xyz, 6 | const float *xyz, int *idx); 7 | 8 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 9 | const int nsample) { 10 | CHECK_CONTIGUOUS(new_xyz); 11 | CHECK_CONTIGUOUS(xyz); 12 | CHECK_IS_FLOAT(new_xyz); 13 | CHECK_IS_FLOAT(xyz); 14 | 15 | if (new_xyz.is_cuda()) { 16 | CHECK_CUDA(xyz); 17 | } 18 | 19 | at::Tensor idx = 20 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 21 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 22 | 23 | if (new_xyz.is_cuda()) { 24 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 25 | radius, nsample, new_xyz.data_ptr(), 26 | xyz.data_ptr(), idx.data_ptr()); 27 | } else { 28 | AT_ASSERT(false, "CPU not supported"); 29 | } 30 | 31 | return idx; 32 | } 33 | -------------------------------------------------------------------------------- /pointnet2/setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import os.path as osp 4 | 5 | from setuptools 
import find_packages, setup 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | _this_dir = osp.dirname(osp.abspath(__file__)) 9 | _ext_src_root = "_ext_src" 10 | _ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob( 11 | "{}/src/*.cu".format(_ext_src_root) 12 | ) 13 | _ext_headers = glob.glob("{}/include/*".format(_ext_src_root)) 14 | 15 | requirements = ["torch>=1.4"] 16 | 17 | os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5" 18 | 19 | exec(open("_version.py").read()) 20 | 21 | setup( 22 | name='pointnet2', 23 | version=__version__, 24 | packages=find_packages(), 25 | install_requires=requirements, 26 | ext_modules=[ 27 | CUDAExtension( 28 | name='pointnet2._ext', 29 | sources=_ext_sources, 30 | extra_compile_args={ 31 | "cxx": ["-O3"], 32 | "nvcc": ["-O3", "-Xfatbin", "-compress-all"], 33 | }, 34 | include_dirs=[osp.join(_this_dir, _ext_src_root, "include")], 35 | ) 36 | ], 37 | cmdclass={"build_ext": BuildExtension}, 38 | include_package_data=True, 39 | ) -------------------------------------------------------------------------------- /pointnet2/src/pointnet2_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "ball_query_gpu.h" 5 | #include "group_points_gpu.h" 6 | #include "sampling_gpu.h" 7 | #include "interpolate_gpu.h" 8 | 9 | 10 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 11 | m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast"); 12 | 13 | m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast"); 14 | m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast"); 15 | 16 | m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast"); 17 | m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast"); 18 | 19 | m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper"); 20 | 21 | m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast"); 22 | m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast"); 23 | m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast"); 24 | } 25 | -------------------------------------------------------------------------------- /pointnet2/src/interpolate_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _INTERPOLATE_GPU_H 2 | #define _INTERPOLATE_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, 11 | at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor); 12 | 13 | void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown, 14 | const float *known, float *dist2, int *idx, cudaStream_t stream); 15 | 16 | 17 | void three_interpolate_wrapper_fast(int b, int c, int m, int n, at::Tensor points_tensor, 18 | at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor); 19 | 20 | void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, 21 | const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream); 22 | 23 | 24 | void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, at::Tensor grad_out_tensor, 25 | at::Tensor idx_tensor, 
at::Tensor weight_tensor, at::Tensor grad_points_tensor); 26 | 27 | void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out, 28 | const int *idx, const float *weight, float *grad_points, cudaStream_t stream); 29 | 30 | #endif 31 | -------------------------------------------------------------------------------- /pointnet2/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "group_points_gpu.h" 7 | 8 | extern THCState *state; 9 | 10 | 11 | int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, 12 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { 13 | 14 | float *grad_points = grad_points_tensor.data(); 15 | const int *idx = idx_tensor.data(); 16 | const float *grad_out = grad_out_tensor.data(); 17 | 18 | cudaStream_t stream = THCState_getCurrentStream(state); 19 | 20 | group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points, stream); 21 | return 1; 22 | } 23 | 24 | 25 | int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, 26 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) { 27 | 28 | const float *points = points_tensor.data(); 29 | const int *idx = idx_tensor.data(); 30 | float *out = out_tensor.data(); 31 | 32 | cudaStream_t stream = THCState_getCurrentStream(state); 33 | 34 | group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out, stream); 35 | return 1; 36 | } -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | #define TOTAL_THREADS 512 14 | 15 | inline int opt_n_threads(int work_size) { 16 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 17 | 18 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 19 | } 20 | 21 | inline dim3 opt_block_config(int x, int y) { 22 | const int x_threads = opt_n_threads(x); 23 | const int y_threads = 24 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 25 | dim3 block_config(x_threads, y_threads, 1); 26 | 27 | return block_config; 28 | } 29 | 30 | #define CUDA_CHECK_ERRORS() \ 31 | do { \ 32 | cudaError_t err = cudaGetLastError(); \ 33 | if (cudaSuccess != err) { \ 34 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 35 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 36 | __FILE__); \ 37 | exit(-1); \ 38 | } \ 39 | } while (0) 40 | 41 | #endif 42 | -------------------------------------------------------------------------------- /preprocessing/visualize_prep_scene.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | from plyfile import PlyElement, PlyData 5 | 6 | import sys 7 | sys.path.append(".") 8 | from lib.config import CONF 9 | 10 | def visualize(args): 11 | print("visualizing...") 12 | scene = np.load(CONF.SCANNETV2_FILE.format(args.scene_id)) 13 | 14 | vertex = [] 15 | for i in range(scene.shape[0]): 16 | vertex.append( 17 | ( 18 | scene[i][0], 19 | scene[i][1], 20 | scene[i][2], 21 | CONF.PALETTE[int(scene[i][-1])][0], 22 | CONF.PALETTE[int(scene[i][-1])][1], 23 | 
CONF.PALETTE[int(scene[i][-1])][2] 24 | ) 25 | ) 26 | 27 | vertex = np.array( 28 | vertex, 29 | dtype=[ 30 | ("x", np.dtype("float32")), 31 | ("y", np.dtype("float32")), 32 | ("z", np.dtype("float32")), 33 | ("red", np.dtype("uint8")), 34 | ("green", np.dtype("uint8")), 35 | ("blue", np.dtype("uint8")) 36 | ] 37 | ) 38 | 39 | output_pc = PlyElement.describe(vertex, "vertex") 40 | output_pc = PlyData([output_pc]) 41 | os.makedirs(CONF.SCAN_LABELS, exist_ok=True) 42 | output_pc.write(CONF.SCANNETV2_LABEL.format(args.scene_id)) 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument("--scene_id", type=str, required=True) 48 | args = parser.parse_args() 49 | 50 | visualize(args) 51 | print("done!") -------------------------------------------------------------------------------- /pointnet2/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "sampling_gpu.h" 7 | 8 | extern THCState *state; 9 | 10 | 11 | int gather_points_wrapper_fast(int b, int c, int n, int npoints, 12 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){ 13 | const float *points = points_tensor.data(); 14 | const int *idx = idx_tensor.data(); 15 | float *out = out_tensor.data(); 16 | 17 | cudaStream_t stream = THCState_getCurrentStream(state); 18 | gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream); 19 | return 1; 20 | } 21 | 22 | 23 | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, 24 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { 25 | 26 | const float *grad_out = grad_out_tensor.data(); 27 | const int *idx = idx_tensor.data(); 28 | float *grad_points = grad_points_tensor.data(); 29 | 30 | cudaStream_t stream = THCState_getCurrentStream(state); 31 | gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream); 32 | return 1; 33 | } 34 | 35 | 36 | int furthest_point_sampling_wrapper(int b, int n, int m, 37 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) { 38 | 39 | const float *points = points_tensor.data(); 40 | float *temp = temp_tensor.data(); 41 | int *idx = idx_tensor.data(); 42 | 43 | cudaStream_t stream = THCState_getCurrentStream(state); 44 | furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream); 45 | return 1; 46 | } 47 | -------------------------------------------------------------------------------- /data/scannetv2_test.txt: -------------------------------------------------------------------------------- 1 | scene0707_00 2 | scene0708_00 3 | scene0709_00 4 | scene0710_00 5 | scene0711_00 6 | scene0712_00 7 | scene0713_00 8 | scene0714_00 9 | scene0715_00 10 | scene0716_00 11 | scene0717_00 12 | scene0718_00 13 | scene0719_00 14 | scene0720_00 15 | scene0721_00 16 | scene0722_00 17 | scene0723_00 18 | scene0724_00 19 | scene0725_00 20 | scene0726_00 21 | scene0727_00 22 | scene0728_00 23 | scene0729_00 24 | scene0730_00 25 | scene0731_00 26 | scene0732_00 27 | scene0733_00 28 | scene0734_00 29 | scene0735_00 30 | scene0736_00 31 | scene0737_00 32 | scene0738_00 33 | scene0739_00 34 | scene0740_00 35 | scene0741_00 36 | scene0742_00 37 | scene0743_00 38 | scene0744_00 39 | scene0745_00 40 | scene0746_00 41 | scene0747_00 42 | scene0748_00 43 | scene0749_00 44 | scene0750_00 45 | scene0751_00 46 | scene0752_00 47 | scene0753_00 48 | scene0754_00 49 | scene0755_00 50 
| scene0756_00 51 | scene0757_00 52 | scene0758_00 53 | scene0759_00 54 | scene0760_00 55 | scene0761_00 56 | scene0762_00 57 | scene0763_00 58 | scene0764_00 59 | scene0765_00 60 | scene0766_00 61 | scene0767_00 62 | scene0768_00 63 | scene0769_00 64 | scene0770_00 65 | scene0771_00 66 | scene0772_00 67 | scene0773_00 68 | scene0774_00 69 | scene0775_00 70 | scene0776_00 71 | scene0777_00 72 | scene0778_00 73 | scene0779_00 74 | scene0780_00 75 | scene0781_00 76 | scene0782_00 77 | scene0783_00 78 | scene0784_00 79 | scene0785_00 80 | scene0786_00 81 | scene0787_00 82 | scene0788_00 83 | scene0789_00 84 | scene0790_00 85 | scene0791_00 86 | scene0792_00 87 | scene0793_00 88 | scene0794_00 89 | scene0795_00 90 | scene0796_00 91 | scene0797_00 92 | scene0798_00 93 | scene0799_00 94 | scene0800_00 95 | scene0801_00 96 | scene0802_00 97 | scene0803_00 98 | scene0804_00 99 | scene0805_00 100 | scene0806_00 101 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 8 | // output: idx(b, m, nsample) 9 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 10 | int nsample, 11 | const float *__restrict__ new_xyz, 12 | const float *__restrict__ xyz, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | xyz += batch_index * n * 3; 16 | new_xyz += batch_index * m * 3; 17 | idx += m * nsample * batch_index; 18 | 19 | int index = threadIdx.x; 20 | int stride = blockDim.x; 21 | 22 | float radius2 = radius * radius; 23 | for (int j = index; j < m; j += stride) { 24 | float new_x = new_xyz[j * 3 + 0]; 25 | float new_y = new_xyz[j * 3 + 1]; 26 | float new_z = new_xyz[j * 3 + 2]; 27 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 28 | float x = xyz[k * 3 + 0]; 29 | float y = xyz[k * 3 + 1]; 30 | float z = xyz[k * 3 + 2]; 31 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 32 | (new_z - z) * (new_z - z); 33 | if (d2 < radius2) { 34 | if (cnt == 0) { 35 | for (int l = 0; l < nsample; ++l) { 36 | idx[j * nsample + l] = k; 37 | } 38 | } 39 | idx[j * nsample + cnt] = k; 40 | ++cnt; 41 | } 42 | } 43 | } 44 | } 45 | 46 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 47 | int nsample, const float *new_xyz, 48 | const float *xyz, int *idx) { 49 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 50 | query_ball_point_kernel<<>>( 51 | b, n, m, radius, nsample, new_xyz, xyz, idx); 52 | 53 | CUDA_CHECK_ERRORS(); 54 | } 55 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | #include "group_points.h" 2 | #include "utils.h" 3 | 4 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 5 | const float *points, const int *idx, 6 | float *out); 7 | 8 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 9 | int nsample, const float *grad_out, 10 | const int *idx, float *grad_points); 11 | 12 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 13 | CHECK_CONTIGUOUS(points); 14 | CHECK_CONTIGUOUS(idx); 15 | CHECK_IS_FLOAT(points); 16 | CHECK_IS_INT(idx); 17 | 18 | if (points.is_cuda()) { 19 | CHECK_CUDA(idx); 20 | } 21 | 22 | at::Tensor output = 23 | 
torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 24 | at::device(points.device()).dtype(at::ScalarType::Float)); 25 | 26 | if (points.is_cuda()) { 27 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 28 | idx.size(1), idx.size(2), 29 | points.data_ptr(), idx.data_ptr(), 30 | output.data_ptr()); 31 | } else { 32 | AT_ASSERT(false, "CPU not supported"); 33 | } 34 | 35 | return output; 36 | } 37 | 38 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 39 | CHECK_CONTIGUOUS(grad_out); 40 | CHECK_CONTIGUOUS(idx); 41 | CHECK_IS_FLOAT(grad_out); 42 | CHECK_IS_INT(idx); 43 | 44 | if (grad_out.is_cuda()) { 45 | CHECK_CUDA(idx); 46 | } 47 | 48 | at::Tensor output = 49 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 50 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 51 | 52 | if (grad_out.is_cuda()) { 53 | group_points_grad_kernel_wrapper( 54 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 55 | grad_out.data_ptr(), idx.data_ptr(), 56 | output.data_ptr()); 57 | } else { 58 | AT_ASSERT(false, "CPU not supported"); 59 | } 60 | 61 | return output; 62 | } 63 | -------------------------------------------------------------------------------- /pointnet2/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include "interpolate_gpu.h" 10 | 11 | extern THCState *state; 12 | 13 | 14 | void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, 15 | at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) { 16 | const float *unknown = unknown_tensor.data(); 17 | const float *known = known_tensor.data(); 18 | float *dist2 = dist2_tensor.data(); 19 | int *idx = idx_tensor.data(); 20 | 21 | cudaStream_t stream = THCState_getCurrentStream(state); 22 | three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream); 23 | } 24 | 25 | 26 | void three_interpolate_wrapper_fast(int b, int c, int m, int n, 27 | at::Tensor points_tensor, 28 | at::Tensor idx_tensor, 29 | at::Tensor weight_tensor, 30 | at::Tensor out_tensor) { 31 | 32 | const float *points = points_tensor.data(); 33 | const float *weight = weight_tensor.data(); 34 | float *out = out_tensor.data(); 35 | const int *idx = idx_tensor.data(); 36 | 37 | cudaStream_t stream = THCState_getCurrentStream(state); 38 | three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream); 39 | } 40 | 41 | void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, 42 | at::Tensor grad_out_tensor, 43 | at::Tensor idx_tensor, 44 | at::Tensor weight_tensor, 45 | at::Tensor grad_points_tensor) { 46 | 47 | const float *grad_out = grad_out_tensor.data(); 48 | const float *weight = weight_tensor.data(); 49 | float *grad_points = grad_points_tensor.data(); 50 | const int *idx = idx_tensor.data(); 51 | 52 | cudaStream_t stream = THCState_getCurrentStream(state); 53 | three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream); 54 | } -------------------------------------------------------------------------------- /pointnet2/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "ball_query_gpu.h" 6 | #include "cuda_utils.h" 7 | 8 | 9 | __global__ void ball_query_kernel_fast(int b, int n, int m, 
float radius, int nsample, 10 | const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) { 11 | // new_xyz: (B, M, 3) 12 | // xyz: (B, N, 3) 13 | // output: 14 | // idx: (B, M, nsample) 15 | int bs_idx = blockIdx.y; 16 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 17 | if (bs_idx >= b || pt_idx >= m) return; 18 | 19 | new_xyz += bs_idx * m * 3 + pt_idx * 3; 20 | xyz += bs_idx * n * 3; 21 | idx += bs_idx * m * nsample + pt_idx * nsample; 22 | 23 | float radius2 = radius * radius; 24 | float new_x = new_xyz[0]; 25 | float new_y = new_xyz[1]; 26 | float new_z = new_xyz[2]; 27 | 28 | int cnt = 0; 29 | for (int k = 0; k < n; ++k) { 30 | float x = xyz[k * 3 + 0]; 31 | float y = xyz[k * 3 + 1]; 32 | float z = xyz[k * 3 + 2]; 33 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); 34 | if (d2 < radius2){ 35 | if (cnt == 0){ 36 | for (int l = 0; l < nsample; ++l) { 37 | idx[l] = k; 38 | } 39 | } 40 | idx[cnt] = k; 41 | ++cnt; 42 | if (cnt >= nsample) break; 43 | } 44 | } 45 | } 46 | 47 | 48 | void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \ 49 | const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) { 50 | // new_xyz: (B, M, 3) 51 | // xyz: (B, N, 3) 52 | // output: 53 | // idx: (B, M, nsample) 54 | 55 | cudaError_t err; 56 | 57 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) 58 | dim3 threads(THREADS_PER_BLOCK); 59 | 60 | ball_query_kernel_fast<<>>(b, n, m, radius, nsample, new_xyz, xyz, idx); 61 | // cudaDeviceSynchronize(); // for using printf in kernel function 62 | err = cudaGetLastError(); 63 | if (cudaSuccess != err) { 64 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 65 | exit(-1); 66 | } 67 | } -------------------------------------------------------------------------------- /lib/scene_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | from sklearn.neighbors import NearestNeighbors 5 | from numpy import linalg as la 6 | import scipy.io as sio 7 | 8 | def cart2sph(xyz): 9 | xy = xyz[:,0]**2+xyz[:,1]**2 10 | aer = np.zeros(xyz.shape) 11 | aer[:,2] = np.sqrt(xy+xyz[:,2]**2) 12 | aer[:,1] = np.arctan2(xyz[:,2],np.sqrt(xy)) 13 | aer[:,0] = np.arctan2(xyz[:,1],xyz[:,0]) 14 | 15 | return aer 16 | 17 | # generate virtual scan of a scene by subsampling the point cloud 18 | def virtual_scan(xyz, mode=-1): 19 | camloc = np.mean(xyz,axis=0) 20 | camloc[2] = 1.5 # human height 21 | if mode==-1: 22 | view_dr = np.array([2*np.pi*np.random.random(), np.pi/10*(np.random.random()-0.75)]) 23 | camloc[:2] -= (0.8+0.7*np.random.random())*np.array([np.cos(view_dr[0]),np.sin(view_dr[0])]) 24 | else: 25 | view_dr = np.array([np.pi/4*mode, 0]) 26 | camloc[:2] -= np.array([np.cos(view_dr[0]),np.sin(view_dr[0])]) 27 | 28 | ct_ray_dr = np.array([np.cos(view_dr[1])*np.cos(view_dr[0]), np.cos(view_dr[1])*np.sin(view_dr[0]), np.sin(view_dr[1])]) 29 | hr_dr = np.cross(ct_ray_dr, np.array([0,0,1])) 30 | hr_dr /= la.norm(hr_dr) 31 | vt_dr = np.cross(hr_dr, ct_ray_dr) 32 | vt_dr /= la.norm(vt_dr) 33 | xx = np.linspace(-0.6,0.6,200) #200 34 | yy = np.linspace(-0.45,0.45,150) #150 35 | xx, yy = np.meshgrid(xx,yy) 36 | xx = xx.reshape(-1,1) 37 | yy = yy.reshape(-1,1) 38 | rays = xx*hr_dr.reshape(1,-1)+yy*vt_dr.reshape(1,-1)+ct_ray_dr.reshape(1,-1) 39 | rays_aer = cart2sph(rays) 40 | local_xyz = xyz-camloc.reshape(1,-1) 41 | local_aer = 
cart2sph(local_xyz) 42 | nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(rays_aer[:,:2]) 43 | mindd, minidx = nbrs.kneighbors(local_aer[:,:2]) 44 | mindd = mindd.reshape(-1) 45 | minidx = minidx.reshape(-1) 46 | 47 | sub_idx = mindd<0.01 48 | if sum(sub_idx)<100: 49 | return np.ones(0) 50 | 51 | sub_r = local_aer[sub_idx,2] 52 | sub_minidx = minidx[sub_idx] 53 | min_r = float('inf')*np.ones(np.max(sub_minidx)+1) 54 | for i in range(len(sub_r)): 55 | if sub_r[i]<min_r[sub_minidx[i]]: 56 | min_r[sub_minidx[i]] = sub_r[i] 57 | 58 | sub_smpidx = np.ones(len(sub_r)) 59 | for i in range(len(sub_r)): 60 | if sub_r[i]>min_r[sub_minidx[i]]: 61 | sub_smpidx[i] = 0 62 | 63 | smpidx = np.where(sub_idx)[0] 64 | smpidx = smpidx[sub_smpidx==1] 65 | 66 | return smpidx 67 | 68 | if __name__=='__main__': 69 | pc = np.load('scannet_dataset/scannet_scenes/scene0015_00.npy') 70 | print(pc.shape) 71 | xyz = pc[:,:3] 72 | seg = pc[:,7] 73 | smpidx = virtual_scan(xyz,mode=2) 74 | xyz = xyz[smpidx,:] 75 | seg = seg[smpidx] 76 | sio.savemat('tmp.mat',{'pc':xyz,'seg':seg}) -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, npoints, nsample) 7 | // output: out(b, c, npoints, nsample) 8 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 9 | int nsample, 10 | const float *__restrict__ points, 11 | const int *__restrict__ idx, 12 | float *__restrict__ out) { 13 | int batch_index = blockIdx.x; 14 | points += batch_index * n * c; 15 | idx += batch_index * npoints * nsample; 16 | out += batch_index * npoints * nsample * c; 17 | 18 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 19 | const int stride = blockDim.y * blockDim.x; 20 | for (int i = index; i < c * npoints; i += stride) { 21 | const int l = i / npoints; 22 | const int j = i % npoints; 23 | for (int k = 0; k < nsample; ++k) { 24 | int ii = idx[j * nsample + k]; 25 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 26 | } 27 | } 28 | } 29 | 30 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 31 | const float *points, const int *idx, 32 | float *out) { 33 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 34 | 35 | group_points_kernel<<>>( 36 | b, c, n, npoints, nsample, points, idx, out); 37 | 38 | CUDA_CHECK_ERRORS(); 39 | } 40 | 41 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 42 | // output: grad_points(b, c, n) 43 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 44 | int nsample, 45 | const float *__restrict__ grad_out, 46 | const int *__restrict__ idx, 47 | float *__restrict__ grad_points) { 48 | int batch_index = blockIdx.x; 49 | grad_out += batch_index * npoints * nsample * c; 50 | idx += batch_index * npoints * nsample; 51 | grad_points += batch_index * n * c; 52 | 53 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 54 | const int stride = blockDim.y * blockDim.x; 55 | for (int i = index; i < c * npoints; i += stride) { 56 | const int l = i / npoints; 57 | const int j = i % npoints; 58 | for (int k = 0; k < nsample; ++k) { 59 | int ii = idx[j * nsample + k]; 60 | atomicAdd(grad_points + l * n + ii, 61 | grad_out[(l * npoints + j) * nsample + k]); 62 | } 63 | } 64 | } 65 | 66 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 67 | int nsample, const float *grad_out, 68 | const int *idx, float *grad_points) { 69 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 70 | 71 
| group_points_grad_kernel<<>>( 72 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 73 | 74 | CUDA_CHECK_ERRORS(); 75 | } 76 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include "sampling.h" 2 | #include "utils.h" 3 | 4 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 5 | const float *points, const int *idx, 6 | float *out); 7 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 8 | const float *grad_out, const int *idx, 9 | float *grad_points); 10 | 11 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 12 | const float *dataset, float *temp, 13 | int *idxs); 14 | 15 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 16 | CHECK_CONTIGUOUS(points); 17 | CHECK_CONTIGUOUS(idx); 18 | CHECK_IS_FLOAT(points); 19 | CHECK_IS_INT(idx); 20 | 21 | if (points.is_cuda()) { 22 | CHECK_CUDA(idx); 23 | } 24 | 25 | at::Tensor output = 26 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 27 | at::device(points.device()).dtype(at::ScalarType::Float)); 28 | 29 | if (points.is_cuda()) { 30 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 31 | idx.size(1), points.data_ptr(), 32 | idx.data_ptr(), output.data_ptr()); 33 | } else { 34 | AT_ASSERT(false, "CPU not supported"); 35 | } 36 | 37 | return output; 38 | } 39 | 40 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, 41 | const int n) { 42 | CHECK_CONTIGUOUS(grad_out); 43 | CHECK_CONTIGUOUS(idx); 44 | CHECK_IS_FLOAT(grad_out); 45 | CHECK_IS_INT(idx); 46 | 47 | if (grad_out.is_cuda()) { 48 | CHECK_CUDA(idx); 49 | } 50 | 51 | at::Tensor output = 52 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 53 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 54 | 55 | if (grad_out.is_cuda()) { 56 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, 57 | idx.size(1), grad_out.data_ptr(), 58 | idx.data_ptr(), 59 | output.data_ptr()); 60 | } else { 61 | AT_ASSERT(false, "CPU not supported"); 62 | } 63 | 64 | return output; 65 | } 66 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { 67 | CHECK_CONTIGUOUS(points); 68 | CHECK_IS_FLOAT(points); 69 | 70 | at::Tensor output = 71 | torch::zeros({points.size(0), nsamples}, 72 | at::device(points.device()).dtype(at::ScalarType::Int)); 73 | 74 | at::Tensor tmp = 75 | torch::full({points.size(0), points.size(1)}, 1e10, 76 | at::device(points.device()).dtype(at::ScalarType::Float)); 77 | 78 | if (points.is_cuda()) { 79 | furthest_point_sampling_kernel_wrapper( 80 | points.size(0), points.size(1), nsamples, points.data_ptr(), 81 | tmp.data_ptr(), output.data_ptr()); 82 | } else { 83 | AT_ASSERT(false, "CPU not supported"); 84 | } 85 | 86 | return output; 87 | } 88 | -------------------------------------------------------------------------------- /lib/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from easydict import EasyDict 3 | 4 | CONF = EasyDict() 5 | 6 | # BASE PATH 7 | CONF.ROOT = "/rhome/dchen/Pointnet2.ScanNet" # TODO change this 8 | CONF.SCANNET_DIR = "/canis/Datasets/ScanNet/public/v2/scans" # TODO change this 9 | 10 | # Uncomment the followings if you're NOT on slurm 11 | # CONF.SCANNET_FRAMES_ROOT = os.path.join(CONF.ROOT, "frames_square") 12 | # CONF.PROJECTION = os.path.join(CONF.ROOT, 
"multiview_projection_pointnet") 13 | # CONF.ENET_FEATURES_ROOT = os.path.join(CONF.ROOT, "enet_features") 14 | 15 | # Uncomment the followings if you're on slurm 16 | CONF.CLUSTER = "/cluster/balrog/dchen/Pointnet2.ScanNet" 17 | CONF.SCANNET_FRAMES_ROOT = os.path.join(CONF.CLUSTER, "frames_square") 18 | CONF.PROJECTION = os.path.join(CONF.CLUSTER, "multiview_projection_pointnet") 19 | CONF.ENET_FEATURES_ROOT = os.path.join(CONF.CLUSTER, "enet_features") 20 | 21 | CONF.ENET_FEATURES_SUBROOT = os.path.join(CONF.ENET_FEATURES_ROOT, "{}") # scene_id 22 | CONF.ENET_FEATURES_PATH = os.path.join(CONF.ENET_FEATURES_SUBROOT, "{}.npy") # frame_id 23 | CONF.SCANNET_FRAMES = os.path.join(CONF.SCANNET_FRAMES_ROOT, "{}/{}") # scene_id, mode 24 | CONF.SCENE_NAMES = sorted(os.listdir(CONF.SCANNET_DIR)) 25 | 26 | CONF.PREP = os.path.join(CONF.ROOT, "preprocessing") 27 | CONF.PREP_SCANS = os.path.join(CONF.PREP, "scannet_scenes") 28 | CONF.SCAN_LABELS = os.path.join(CONF.PREP, "label_point_clouds") 29 | CONF.OUTPUT_ROOT = os.path.join(CONF.ROOT, "outputs") 30 | CONF.ENET_WEIGHTS = os.path.join(CONF.ROOT, "data/scannetv2_enet.pth") 31 | CONF.MULTIVIEW = os.path.join(CONF.PREP_SCANS, "enet_feats.hdf5") 32 | 33 | CONF.SCANNETV2_TRAIN = os.path.join(CONF.ROOT, "data/scannetv2_train.txt") 34 | CONF.SCANNETV2_VAL = os.path.join(CONF.ROOT, "data/scannetv2_val.txt") 35 | CONF.SCANNETV2_TEST = os.path.join(CONF.ROOT, "data/scannetv2_test.txt") 36 | CONF.SCANNETV2_LIST = os.path.join(CONF.ROOT, "data/scannetv2.txt") 37 | CONF.SCANNETV2_FILE = os.path.join(CONF.PREP_SCANS, "{}.npy") # scene_id 38 | CONF.SCANNETV2_LABEL = os.path.join(CONF.SCAN_LABELS, "{}.ply") # scene_id 39 | 40 | CONF.NYUCLASSES = [ 41 | 'floor', 42 | 'wall', 43 | 'cabinet', 44 | 'bed', 45 | 'chair', 46 | 'sofa', 47 | 'table', 48 | 'door', 49 | 'window', 50 | 'bookshelf', 51 | 'picture', 52 | 'counter', 53 | 'desk', 54 | 'curtain', 55 | 'refrigerator', 56 | 'bathtub', 57 | 'shower curtain', 58 | 'toilet', 59 | 'sink', 60 | 'otherprop' 61 | ] 62 | CONF.NUM_CLASSES = len(CONF.NYUCLASSES) 63 | CONF.PALETTE = [ 64 | (152, 223, 138), # floor 65 | (174, 199, 232), # wall 66 | (31, 119, 180), # cabinet 67 | (255, 187, 120), # bed 68 | (188, 189, 34), # chair 69 | (140, 86, 75), # sofa 70 | (255, 152, 150), # table 71 | (214, 39, 40), # door 72 | (197, 176, 213), # window 73 | (148, 103, 189), # bookshelf 74 | (196, 156, 148), # picture 75 | (23, 190, 207), # counter 76 | (247, 182, 210), # desk 77 | (219, 219, 141), # curtain 78 | (255, 127, 14), # refrigerator 79 | (227, 119, 194), # bathtub 80 | (158, 218, 229), # shower curtain 81 | (44, 160, 44), # toilet 82 | (112, 128, 144), # sink 83 | (82, 84, 163), # otherfurn 84 | ] -------------------------------------------------------------------------------- /pointnet2/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | #include "group_points_gpu.h" 6 | 7 | 8 | __global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample, 9 | const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) { 10 | // grad_out: (B, C, npoints, nsample) 11 | // idx: (B, npoints, nsample) 12 | // output: 13 | // grad_points: (B, C, N) 14 | int bs_idx = blockIdx.z; 15 | int c_idx = blockIdx.y; 16 | int index = blockIdx.x * blockDim.x + threadIdx.x; 17 | int pt_idx = index / nsample; 18 | if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; 19 | 
20 | int sample_idx = index % nsample; 21 | grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; 22 | idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; 23 | 24 | atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0] , grad_out[0]); 25 | } 26 | 27 | void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 28 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) { 29 | // grad_out: (B, C, npoints, nsample) 30 | // idx: (B, npoints, nsample) 31 | // output: 32 | // grad_points: (B, C, N) 33 | cudaError_t err; 34 | dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 35 | dim3 threads(THREADS_PER_BLOCK); 36 | 37 | group_points_grad_kernel_fast<<>>(b, c, n, npoints, nsample, grad_out, idx, grad_points); 38 | 39 | err = cudaGetLastError(); 40 | if (cudaSuccess != err) { 41 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 42 | exit(-1); 43 | } 44 | } 45 | 46 | 47 | __global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample, 48 | const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) { 49 | // points: (B, C, N) 50 | // idx: (B, npoints, nsample) 51 | // output: 52 | // out: (B, C, npoints, nsample) 53 | int bs_idx = blockIdx.z; 54 | int c_idx = blockIdx.y; 55 | int index = blockIdx.x * blockDim.x + threadIdx.x; 56 | int pt_idx = index / nsample; 57 | if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; 58 | 59 | int sample_idx = index % nsample; 60 | 61 | idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; 62 | int in_idx = bs_idx * c * n + c_idx * n + idx[0]; 63 | int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; 64 | 65 | out[out_idx] = points[in_idx]; 66 | } 67 | 68 | 69 | void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 70 | const float *points, const int *idx, float *out, cudaStream_t stream) { 71 | // points: (B, C, N) 72 | // idx: (B, npoints, nsample) 73 | // output: 74 | // out: (B, C, npoints, nsample) 75 | cudaError_t err; 76 | dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 77 | dim3 threads(THREADS_PER_BLOCK); 78 | 79 | group_points_kernel_fast<<>>(b, c, n, npoints, nsample, points, idx, out); 80 | // cudaDeviceSynchronize(); // for using printf in kernel function 81 | err = cudaGetLastError(); 82 | if (cudaSuccess != err) { 83 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 84 | exit(-1); 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include "interpolate.h" 2 | #include "utils.h" 3 | 4 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 5 | const float *known, float *dist2, int *idx); 6 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 7 | const float *points, const int *idx, 8 | const float *weight, float *out); 9 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 10 | const float *grad_out, 11 | const int *idx, const float *weight, 12 | float *grad_points); 13 | 14 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { 15 | CHECK_CONTIGUOUS(unknowns); 16 | CHECK_CONTIGUOUS(knows); 17 
| CHECK_IS_FLOAT(unknowns); 18 | CHECK_IS_FLOAT(knows); 19 | 20 | if (unknowns.is_cuda()) { 21 | CHECK_CUDA(knows); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 26 | at::device(unknowns.device()).dtype(at::ScalarType::Int)); 27 | at::Tensor dist2 = 28 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 29 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (unknowns.is_cuda()) { 32 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 33 | unknowns.data_ptr(), knows.data_ptr(), 34 | dist2.data_ptr(), idx.data_ptr()); 35 | } else { 36 | AT_ASSERT(false, "CPU not supported"); 37 | } 38 | 39 | return {dist2, idx}; 40 | } 41 | 42 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 43 | at::Tensor weight) { 44 | CHECK_CONTIGUOUS(points); 45 | CHECK_CONTIGUOUS(idx); 46 | CHECK_CONTIGUOUS(weight); 47 | CHECK_IS_FLOAT(points); 48 | CHECK_IS_INT(idx); 49 | CHECK_IS_FLOAT(weight); 50 | 51 | if (points.is_cuda()) { 52 | CHECK_CUDA(idx); 53 | CHECK_CUDA(weight); 54 | } 55 | 56 | at::Tensor output = 57 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 58 | at::device(points.device()).dtype(at::ScalarType::Float)); 59 | 60 | if (points.is_cuda()) { 61 | three_interpolate_kernel_wrapper( 62 | points.size(0), points.size(1), points.size(2), idx.size(1), 63 | points.data_ptr(), idx.data_ptr(), weight.data_ptr(), 64 | output.data_ptr()); 65 | } else { 66 | AT_ASSERT(false, "CPU not supported"); 67 | } 68 | 69 | return output; 70 | } 71 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 72 | at::Tensor weight, const int m) { 73 | CHECK_CONTIGUOUS(grad_out); 74 | CHECK_CONTIGUOUS(idx); 75 | CHECK_CONTIGUOUS(weight); 76 | CHECK_IS_FLOAT(grad_out); 77 | CHECK_IS_INT(idx); 78 | CHECK_IS_FLOAT(weight); 79 | 80 | if (grad_out.is_cuda()) { 81 | CHECK_CUDA(idx); 82 | CHECK_CUDA(weight); 83 | } 84 | 85 | at::Tensor output = 86 | torch::zeros({grad_out.size(0), grad_out.size(1), m}, 87 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 88 | 89 | if (grad_out.is_cuda()) { 90 | three_interpolate_grad_kernel_wrapper( 91 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m, 92 | grad_out.data_ptr(), idx.data_ptr(), 93 | weight.data_ptr(), output.data_ptr()); 94 | } else { 95 | AT_ASSERT(false, "CPU not supported"); 96 | } 97 | 98 | return output; 99 | } 100 | -------------------------------------------------------------------------------- /preprocessing/collect_scannet_scenes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import time 5 | import numpy as np 6 | 7 | sys.path.append(".") 8 | from scannet_util import g_label_names, g_raw2scannet 9 | from lib.pc_util import read_ply_xyzrgbnormal 10 | from lib.utils import get_eta 11 | from lib.config import CONF 12 | 13 | CLASS_NAMES = g_label_names 14 | RAW2SCANNET = g_raw2scannet 15 | NUM_MAX_PTS = 100000 16 | 17 | def collect_one_scene_data_label(scene_name, out_filename): 18 | # Over-segmented segments: maps from segment to vertex/point IDs 19 | data_folder = os.path.join(CONF.SCANNET_DIR, scene_name) 20 | mesh_seg_filename = os.path.join(data_folder, '%s_vh_clean_2.0.010000.segs.json'%(scene_name)) 21 | #print mesh_seg_filename 22 | with open(mesh_seg_filename) as jsondata: 23 | d = json.load(jsondata) 24 | seg = d['segIndices'] 25 | #print len(seg) 26 | segid_to_pointid = {} 27 | for i in range(len(seg)): 28 | if seg[i] not 
in segid_to_pointid: 29 | segid_to_pointid[seg[i]] = [] 30 | segid_to_pointid[seg[i]].append(i) 31 | 32 | # Raw points in XYZRGBA 33 | ply_filename = os.path.join(data_folder, '%s_vh_clean_2.ply' % (scene_name)) 34 | points = read_ply_xyzrgbnormal(ply_filename) 35 | 36 | # Instances over-segmented segment IDs: annotation on segments 37 | instance_segids = [] 38 | labels = [] 39 | annotation_filename = os.path.join(data_folder, '%s.aggregation.json'%(scene_name)) # low-res mesh 40 | # annotation_filename = os.path.join(data_folder, '%s_vh_clean.aggregation.json'%(scene_name)) # high-res mesh 41 | #print annotation_filename 42 | with open(annotation_filename) as jsondata: 43 | d = json.load(jsondata) 44 | for x in d['segGroups']: 45 | instance_segids.append(x['segments']) 46 | labels.append(x['label']) 47 | 48 | #print len(instance_segids) 49 | #print labels 50 | 51 | # Each instance's points 52 | instance_points_list = [] 53 | instance_labels_list = [] 54 | semantic_labels_list = [] 55 | for i in range(len(instance_segids)): 56 | segids = instance_segids[i] 57 | pointids = [] 58 | for segid in segids: 59 | pointids += segid_to_pointid[segid] 60 | instance_points = points[np.array(pointids),:] 61 | instance_points_list.append(instance_points) 62 | instance_labels_list.append(np.ones((instance_points.shape[0], 1))*i) 63 | label = RAW2SCANNET[labels[i]] 64 | label = CLASS_NAMES.index(label) 65 | semantic_labels_list.append(np.ones((instance_points.shape[0], 1))*label) 66 | 67 | # Refactor data format 68 | scene_points = np.concatenate(instance_points_list, 0) 69 | scene_points = scene_points[:,0:9] # XYZ+RGB+NORMAL 70 | instance_labels = np.concatenate(instance_labels_list, 0) 71 | semantic_labels = np.concatenate(semantic_labels_list, 0) 72 | data = np.concatenate((scene_points, instance_labels, semantic_labels), 1) 73 | 74 | if data.shape[0] > NUM_MAX_PTS: 75 | choices = np.random.choice(data.shape[0], NUM_MAX_PTS, replace=False) 76 | data = data[choices] 77 | 78 | print("shape of subsampled scene data: {}".format(data.shape)) 79 | np.save(out_filename, data) 80 | 81 | if __name__=='__main__': 82 | os.makedirs(CONF.PREP_SCANS, exist_ok=True) 83 | 84 | for i, scene_name in enumerate(CONF.SCENE_NAMES): 85 | try: 86 | start = time.time() 87 | out_filename = scene_name+'.npy' # scene0000_00.npy 88 | collect_one_scene_data_label(scene_name, os.path.join(CONF.PREP_SCANS, out_filename)) 89 | 90 | # report 91 | num_left = len(CONF.SCENE_NAMES) - i - 1 92 | eta = get_eta(start, time.time(), 0, num_left) 93 | print("preprocessed {}, {} left, ETA: {}h {}m {}s".format( 94 | scene_name, 95 | num_left, 96 | eta["h"], 97 | eta["m"], 98 | eta["s"] 99 | )) 100 | 101 | except Exception as e: 102 | print(scene_name+'ERROR!!') 103 | 104 | print("done!") -------------------------------------------------------------------------------- /scripts/compute_multiview_features.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import torch 5 | import argparse 6 | import numpy as np 7 | import torch.nn as nn 8 | import torchvision.transforms as transforms 9 | from torch.utils.data import Dataset, DataLoader 10 | from imageio import imread 11 | from PIL import Image 12 | from tqdm import tqdm 13 | 14 | sys.path.append(os.path.join(os.getcwd())) # HACK add the root folder 15 | from lib.enet import create_enet_for_3d 16 | from lib.config import CONF 17 | 18 | # scannet data 19 | # NOTE: read only! 
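# The templates below are filled in later as SCANNET_FRAME_ROOT.format(scene_id, "color")
# and SCANNET_FRAME_PATH.format(scene_id, "color", "<frame_id>.jpg"), so CONF.SCANNET_FRAMES
# is assumed to contain two "{}" placeholders (scene id and frame type) and SCANNET_FRAME_PATH
# appends a third one for the file name (see lib/config.py for the actual template).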
20 | SCANNET_FRAME_ROOT = CONF.SCANNET_FRAMES 21 | SCANNET_FRAME_PATH = os.path.join(SCANNET_FRAME_ROOT, "{}") # name of the file 22 | SCANNET_LIST = CONF.SCANNETV2_LIST 23 | 24 | ENET_PATH = CONF.ENET_WEIGHTS 25 | ENET_FEATURE_ROOT = CONF.ENET_FEATURES_SUBROOT 26 | ENET_FEATURE_PATH = CONF.ENET_FEATURES_PATH 27 | 28 | class EnetDataset(Dataset): 29 | def __init__(self): 30 | self._init_resources() 31 | 32 | def __len__(self): 33 | return len(self.data) 34 | 35 | def __getitem__(self, idx): 36 | scene_id, frame_id = self.data[idx] 37 | image = self._load_image(SCANNET_FRAME_PATH.format(scene_id, "color", "{}.jpg".format(frame_id)), [328, 256]) 38 | 39 | return scene_id, frame_id, image 40 | 41 | def _init_resources(self): 42 | self._get_scene_list() 43 | self.data = [] 44 | for scene_id in self.scene_list: 45 | frame_list = sorted(os.listdir(SCANNET_FRAME_ROOT.format(scene_id, "color")), key=lambda x:int(x.split(".")[0])) 46 | for frame_file in frame_list: 47 | self.data.append( 48 | ( 49 | scene_id, 50 | int(frame_file.split(".")[0]) 51 | ) 52 | ) 53 | 54 | def _get_scene_list(self): 55 | with open(SCANNET_LIST, 'r') as f: 56 | self.scene_list = sorted(list(set(f.read().splitlines()))) 57 | 58 | def _resize_crop_image(self, image, new_image_dims): 59 | image_dims = [image.shape[1], image.shape[0]] 60 | if image_dims != new_image_dims: 61 | resize_width = int(math.floor(new_image_dims[1] * float(image_dims[0]) / float(image_dims[1]))) 62 | image = transforms.Resize([new_image_dims[1], resize_width], interpolation=Image.NEAREST)(Image.fromarray(image)) 63 | image = transforms.CenterCrop([new_image_dims[1], new_image_dims[0]])(image) 64 | 65 | return np.array(image) 66 | 67 | def _load_image(self, file, image_dims): 68 | image = imread(file) 69 | # preprocess 70 | image = self._resize_crop_image(image, image_dims) 71 | if len(image.shape) == 3: # color image 72 | image = np.transpose(image, [2, 0, 1]) # move feature to front 73 | image = transforms.Normalize(mean=[0.496342, 0.466664, 0.440796], std=[0.277856, 0.28623, 0.291129])(torch.Tensor(image.astype(np.float32) / 255.0)) 74 | elif len(image.shape) == 2: # label image 75 | image = np.expand_dims(image, 0) 76 | else: 77 | raise ValueError 78 | 79 | return image 80 | 81 | def collate_fn(self, data): 82 | scene_ids, frame_ids, images = zip(*data) 83 | scene_ids = list(scene_ids) 84 | frame_ids = list(frame_ids) 85 | images = torch.stack(images, 0).cuda() 86 | 87 | return scene_ids, frame_ids, images 88 | 89 | def create_enet(): 90 | enet_fixed, enet_trainable, _ = create_enet_for_3d(41, ENET_PATH, 21) 91 | enet = nn.Sequential( 92 | enet_fixed, 93 | enet_trainable 94 | ).cuda() 95 | enet.eval() 96 | for param in enet.parameters(): 97 | param.requires_grad = False 98 | 99 | return enet 100 | 101 | if __name__ == "__main__": 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument('--gpu', type=str, help='gpu', default='0') 104 | args = parser.parse_args() 105 | 106 | # setting 107 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 108 | os.environ["CUDA_LAUNCH_BLOCKING"] = "1" 109 | 110 | # init 111 | dataset = EnetDataset() 112 | dataloader = DataLoader(dataset, batch_size=256, shuffle=False, collate_fn=dataset.collate_fn) 113 | enet = create_enet() 114 | 115 | # feed 116 | print("extracting multiview features from ENet...") 117 | for scene_ids, frame_ids, images in tqdm(dataloader): 118 | features = enet(images) 119 | batch_size = images.shape[0] 120 | for batch_id in range(batch_size): 121 | 
os.makedirs(ENET_FEATURE_ROOT.format(scene_ids[batch_id]), exist_ok=True) 122 | np.save(ENET_FEATURE_PATH.format(scene_ids[batch_id], frame_ids[batch_id]), features[batch_id].cpu().numpy()) 123 | 124 | print("done!") 125 | 126 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pointnet2.ScanNet 2 | PointNet++ Semantic Segmentation on ScanNet in PyTorch with CUDA acceleration based on the original [PointNet++ repo](https://github.com/charlesq34/pointnet2) and [the PyTorch implementation with CUDA](https://github.com/sshaoshuai/Pointnet2.PyTorch) 3 | 4 | ## Performance 5 | The semantic segmentation results in percentage on the ScanNet train/val split in `data/`. 6 | 7 | | use XYZ | use color | use normal | use multiview | use MSG | mIoU | weights | 8 | |---------|-----------|------------|---------------|---------|------|---------| 9 | |:heavy_check_mark: |:heavy_check_mark: |:heavy_check_mark: |- |- |50.48 |[download](https://drive.google.com/file/d/16rsLQwonnf0vvAi4QFaUg6xCxD2pJqEP/view?usp=sharing) | 10 | |:heavy_check_mark: |:heavy_check_mark: |:heavy_check_mark: |- |:heavy_check_mark: |52.50 |[download](https://drive.google.com/file/d/1iMmuZgh8VeYO02tdOSgSKVyXDcvXPior/view?usp=sharing) | 11 | |:heavy_check_mark: |- |:heavy_check_mark: |:heavy_check_mark: |- |65.75 |[download](https://drive.google.com/file/d/1vK9VwIMu__TKOQIlwoN8XZw70FPM5loI/view?usp=sharing) | 12 | |:heavy_check_mark: |- |:heavy_check_mark: |:heavy_check_mark: |:heavy_check_mark: |67.60 |[download](https://drive.google.com/file/d/1twJmV1QuAZ2GHfp8Ae7HyJkKvWbWPK5l/view?usp=sharing) | 13 | 14 | If you want to play around with the pre-trained model, please download the zip file and unzip it under `outputs/`. 15 | 16 | 17 | ## Installation 18 | ### Requirements 19 | * Linux (tested on Ubuntu 14.04/16.04) 20 | * Python 3.6+ 21 | * PyTorch 1.8 22 | * TensorBoardX 23 | 24 | Please run `conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch` to install PyTorch 1.8 and run `pip install -r requirements.txt` to install other required packages. 25 | 26 | ### Install CUDA accelerated PointNet++ library 27 | Install this library by running the following command: 28 | 29 | ```shell 30 | cd pointnet2 31 | python setup.py install 32 | ``` 33 | 34 | ### Configure 35 | Change the path configurations for the ScanNet data in `lib/config.py` 36 | 37 | ### Prepare multiview features (optional) 38 | 1. Download the ScanNet frames [here](http://kaldir.vc.in.tum.de/3dsis/scannet_train_images.zip) (~13GB) and unzip it under the project directory. 39 | 40 | 2. Extract the multiview features from ENet: 41 | ```shell 42 | python scripts/compute_multiview_features.py 43 | ``` 44 | 45 | 3. Generate the projection mapping between image and point cloud 46 | ```shell 47 | python scripts/compute_multiview_projection.py 48 | ``` 49 | 50 | 4. 
Project the multiview features from image space to point cloud 51 | ```shell 52 | python scripts/project_multiview_features.py 53 | ``` 54 | 55 | > Note you might need ~100GB RAM to train the model with multiview features 56 | 57 | ## Usage 58 | ### Preprocess ScanNet scenes 59 | Parse the ScanNet data into `*.npy` files and save them in `preprocessing/scannet_scenes/` 60 | ```shell 61 | python preprocessing/collect_scannet_scenes.py 62 | ``` 63 | ### Sanity check 64 | Don't forget to visualize the preprocessed scenes to check the consistency 65 | ```shell 66 | python preprocessing/visualize_prep_scene.py --scene_id 67 | ``` 68 | The visualized `.ply` is stored in `preprocessing/label_point_clouds/` - Drag that file into MeshLab and you'll see something like this: 69 | 70 | 71 | 72 | ### train 73 | Train the PointNet++ semantic segmentation model on ScanNet scenes with raw RGB values and point normals (for more training options, see `python scripts/train.py -h`) 74 | ```shell 75 | python scripts/train.py --use_color --use_normal --use_msg 76 | ``` 77 | The trained models and logs will be saved in `outputs//` 78 | 79 | ### eval 80 | Evaluate the trained models and report the segmentation performance in point accuracy, voxel accuracy and calibrated voxel accuracy 81 | ```shell 82 | python scripts/eval.py --folder 83 | ``` 84 | 85 | > Note that all model options must match the ones used for training. 86 | 87 | ### vis 88 | Visualize the semantic segmentation results on points in a given scene 89 | ```shell 90 | python scripts/visualize.py --folder --scene_id 91 | ``` 92 | 93 | > Note that all model options must match the ones used for training. 94 | 95 | The generated `.ply` is stored in `outputs//preds` - Drag that file into MeshLab and you'll see something like the one below. See the class palette [here](http://kaldir.vc.in.tum.de/scannet_benchmark/img/legend.jpg) 96 | 97 | 98 | 99 | ## Changelog 100 | 101 | * __07/29/2021__ Upgrade to PyTorch 1.8 & fix existing issues 102 | * __03/29/2020__ Release the code 103 | 104 | ## TODOs 105 | 106 | - [x] Release all pretrained models 107 | - [x] Upgrade to PyTorch 1.8 108 | - [x] Fix issues with loading pre-trained models 109 | 110 | ## Acknowledgement 111 | * [charlesq34/pointnet2](https://github.com/charlesq34/pointnet2): Paper author and official code repo. 112 | * [erikwijmans/Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch): Initial work of PyTorch implementation of PointNet++ with CUDA acceleration. 
113 | -------------------------------------------------------------------------------- /data/scannetv2_val.txt: -------------------------------------------------------------------------------- 1 | scene0568_00 2 | scene0568_01 3 | scene0568_02 4 | scene0304_00 5 | scene0488_00 6 | scene0488_01 7 | scene0412_00 8 | scene0412_01 9 | scene0217_00 10 | scene0019_00 11 | scene0019_01 12 | scene0414_00 13 | scene0575_00 14 | scene0575_01 15 | scene0575_02 16 | scene0426_00 17 | scene0426_01 18 | scene0426_02 19 | scene0426_03 20 | scene0549_00 21 | scene0549_01 22 | scene0578_00 23 | scene0578_01 24 | scene0578_02 25 | scene0665_00 26 | scene0665_01 27 | scene0050_00 28 | scene0050_01 29 | scene0050_02 30 | scene0257_00 31 | scene0025_00 32 | scene0025_01 33 | scene0025_02 34 | scene0583_00 35 | scene0583_01 36 | scene0583_02 37 | scene0701_00 38 | scene0701_01 39 | scene0701_02 40 | scene0580_00 41 | scene0580_01 42 | scene0565_00 43 | scene0169_00 44 | scene0169_01 45 | scene0655_00 46 | scene0655_01 47 | scene0655_02 48 | scene0063_00 49 | scene0221_00 50 | scene0221_01 51 | scene0591_00 52 | scene0591_01 53 | scene0591_02 54 | scene0678_00 55 | scene0678_01 56 | scene0678_02 57 | scene0462_00 58 | scene0427_00 59 | scene0595_00 60 | scene0193_00 61 | scene0193_01 62 | scene0164_00 63 | scene0164_01 64 | scene0164_02 65 | scene0164_03 66 | scene0598_00 67 | scene0598_01 68 | scene0598_02 69 | scene0599_00 70 | scene0599_01 71 | scene0599_02 72 | scene0328_00 73 | scene0300_00 74 | scene0300_01 75 | scene0354_00 76 | scene0458_00 77 | scene0458_01 78 | scene0423_00 79 | scene0423_01 80 | scene0423_02 81 | scene0307_00 82 | scene0307_01 83 | scene0307_02 84 | scene0606_00 85 | scene0606_01 86 | scene0606_02 87 | scene0432_00 88 | scene0432_01 89 | scene0608_00 90 | scene0608_01 91 | scene0608_02 92 | scene0651_00 93 | scene0651_01 94 | scene0651_02 95 | scene0430_00 96 | scene0430_01 97 | scene0689_00 98 | scene0357_00 99 | scene0357_01 100 | scene0574_00 101 | scene0574_01 102 | scene0574_02 103 | scene0329_00 104 | scene0329_01 105 | scene0329_02 106 | scene0153_00 107 | scene0153_01 108 | scene0616_00 109 | scene0616_01 110 | scene0671_00 111 | scene0671_01 112 | scene0618_00 113 | scene0382_00 114 | scene0382_01 115 | scene0490_00 116 | scene0621_00 117 | scene0607_00 118 | scene0607_01 119 | scene0149_00 120 | scene0695_00 121 | scene0695_01 122 | scene0695_02 123 | scene0695_03 124 | scene0389_00 125 | scene0377_00 126 | scene0377_01 127 | scene0377_02 128 | scene0342_00 129 | scene0139_00 130 | scene0629_00 131 | scene0629_01 132 | scene0629_02 133 | scene0496_00 134 | scene0633_00 135 | scene0633_01 136 | scene0518_00 137 | scene0652_00 138 | scene0406_00 139 | scene0406_01 140 | scene0406_02 141 | scene0144_00 142 | scene0144_01 143 | scene0494_00 144 | scene0278_00 145 | scene0278_01 146 | scene0316_00 147 | scene0609_00 148 | scene0609_01 149 | scene0609_02 150 | scene0609_03 151 | scene0084_00 152 | scene0084_01 153 | scene0084_02 154 | scene0696_00 155 | scene0696_01 156 | scene0696_02 157 | scene0351_00 158 | scene0351_01 159 | scene0643_00 160 | scene0644_00 161 | scene0645_00 162 | scene0645_01 163 | scene0645_02 164 | scene0081_00 165 | scene0081_01 166 | scene0081_02 167 | scene0647_00 168 | scene0647_01 169 | scene0535_00 170 | scene0353_00 171 | scene0353_01 172 | scene0353_02 173 | scene0559_00 174 | scene0559_01 175 | scene0559_02 176 | scene0593_00 177 | scene0593_01 178 | scene0246_00 179 | scene0653_00 180 | scene0653_01 181 | scene0064_00 182 | scene0064_01 183 | 
scene0356_00 184 | scene0356_01 185 | scene0356_02 186 | scene0030_00 187 | scene0030_01 188 | scene0030_02 189 | scene0222_00 190 | scene0222_01 191 | scene0338_00 192 | scene0338_01 193 | scene0338_02 194 | scene0378_00 195 | scene0378_01 196 | scene0378_02 197 | scene0660_00 198 | scene0553_00 199 | scene0553_01 200 | scene0553_02 201 | scene0527_00 202 | scene0663_00 203 | scene0663_01 204 | scene0663_02 205 | scene0664_00 206 | scene0664_01 207 | scene0664_02 208 | scene0334_00 209 | scene0334_01 210 | scene0334_02 211 | scene0046_00 212 | scene0046_01 213 | scene0046_02 214 | scene0203_00 215 | scene0203_01 216 | scene0203_02 217 | scene0088_00 218 | scene0088_01 219 | scene0088_02 220 | scene0088_03 221 | scene0086_00 222 | scene0086_01 223 | scene0086_02 224 | scene0670_00 225 | scene0670_01 226 | scene0256_00 227 | scene0256_01 228 | scene0256_02 229 | scene0249_00 230 | scene0441_00 231 | scene0658_00 232 | scene0704_00 233 | scene0704_01 234 | scene0187_00 235 | scene0187_01 236 | scene0131_00 237 | scene0131_01 238 | scene0131_02 239 | scene0207_00 240 | scene0207_01 241 | scene0207_02 242 | scene0461_00 243 | scene0011_00 244 | scene0011_01 245 | scene0343_00 246 | scene0251_00 247 | scene0077_00 248 | scene0077_01 249 | scene0684_00 250 | scene0684_01 251 | scene0550_00 252 | scene0686_00 253 | scene0686_01 254 | scene0686_02 255 | scene0208_00 256 | scene0500_00 257 | scene0500_01 258 | scene0552_00 259 | scene0552_01 260 | scene0648_00 261 | scene0648_01 262 | scene0435_00 263 | scene0435_01 264 | scene0435_02 265 | scene0435_03 266 | scene0690_00 267 | scene0690_01 268 | scene0693_00 269 | scene0693_01 270 | scene0693_02 271 | scene0700_00 272 | scene0700_01 273 | scene0700_02 274 | scene0699_00 275 | scene0231_00 276 | scene0231_01 277 | scene0231_02 278 | scene0697_00 279 | scene0697_01 280 | scene0697_02 281 | scene0697_03 282 | scene0474_00 283 | scene0474_01 284 | scene0474_02 285 | scene0474_03 286 | scene0474_04 287 | scene0474_05 288 | scene0355_00 289 | scene0355_01 290 | scene0146_00 291 | scene0146_01 292 | scene0146_02 293 | scene0196_00 294 | scene0702_00 295 | scene0702_01 296 | scene0702_02 297 | scene0314_00 298 | scene0277_00 299 | scene0277_01 300 | scene0277_02 301 | scene0095_00 302 | scene0095_01 303 | scene0015_00 304 | scene0100_00 305 | scene0100_01 306 | scene0100_02 307 | scene0558_00 308 | scene0558_01 309 | scene0558_02 310 | scene0685_00 311 | scene0685_01 312 | scene0685_02 313 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: unknown(b, n, 3) known(b, m, 3) 8 | // output: dist2(b, n, 3), idx(b, n, 3) 9 | __global__ void three_nn_kernel(int b, int n, int m, 10 | const float *__restrict__ unknown, 11 | const float *__restrict__ known, 12 | float *__restrict__ dist2, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | unknown += batch_index * n * 3; 16 | known += batch_index * m * 3; 17 | dist2 += batch_index * n * 3; 18 | idx += batch_index * n * 3; 19 | 20 | int index = threadIdx.x; 21 | int stride = blockDim.x; 22 | for (int j = index; j < n; j += stride) { 23 | float ux = unknown[j * 3 + 0]; 24 | float uy = unknown[j * 3 + 1]; 25 | float uz = unknown[j * 3 + 2]; 26 | 27 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 28 | int besti1 = 0, besti2 = 0, besti3 = 0; 29 | for (int k = 
0; k < m; ++k) { 30 | float x = known[k * 3 + 0]; 31 | float y = known[k * 3 + 1]; 32 | float z = known[k * 3 + 2]; 33 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 34 | if (d < best1) { 35 | best3 = best2; 36 | besti3 = besti2; 37 | best2 = best1; 38 | besti2 = besti1; 39 | best1 = d; 40 | besti1 = k; 41 | } else if (d < best2) { 42 | best3 = best2; 43 | besti3 = besti2; 44 | best2 = d; 45 | besti2 = k; 46 | } else if (d < best3) { 47 | best3 = d; 48 | besti3 = k; 49 | } 50 | } 51 | dist2[j * 3 + 0] = best1; 52 | dist2[j * 3 + 1] = best2; 53 | dist2[j * 3 + 2] = best3; 54 | 55 | idx[j * 3 + 0] = besti1; 56 | idx[j * 3 + 1] = besti2; 57 | idx[j * 3 + 2] = besti3; 58 | } 59 | } 60 | 61 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 62 | const float *known, float *dist2, int *idx) { 63 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 64 | three_nn_kernel<<>>(b, n, m, unknown, known, 65 | dist2, idx); 66 | 67 | CUDA_CHECK_ERRORS(); 68 | } 69 | 70 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 71 | // output: out(b, c, n) 72 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 73 | const float *__restrict__ points, 74 | const int *__restrict__ idx, 75 | const float *__restrict__ weight, 76 | float *__restrict__ out) { 77 | int batch_index = blockIdx.x; 78 | points += batch_index * m * c; 79 | 80 | idx += batch_index * n * 3; 81 | weight += batch_index * n * 3; 82 | 83 | out += batch_index * n * c; 84 | 85 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 86 | const int stride = blockDim.y * blockDim.x; 87 | for (int i = index; i < c * n; i += stride) { 88 | const int l = i / n; 89 | const int j = i % n; 90 | float w1 = weight[j * 3 + 0]; 91 | float w2 = weight[j * 3 + 1]; 92 | float w3 = weight[j * 3 + 2]; 93 | 94 | int i1 = idx[j * 3 + 0]; 95 | int i2 = idx[j * 3 + 1]; 96 | int i3 = idx[j * 3 + 2]; 97 | 98 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 99 | points[l * m + i3] * w3; 100 | } 101 | } 102 | 103 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 104 | const float *points, const int *idx, 105 | const float *weight, float *out) { 106 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 107 | three_interpolate_kernel<<>>( 108 | b, c, m, n, points, idx, weight, out); 109 | 110 | CUDA_CHECK_ERRORS(); 111 | } 112 | 113 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 114 | // output: grad_points(b, c, m) 115 | 116 | __global__ void three_interpolate_grad_kernel( 117 | int b, int c, int n, int m, const float *__restrict__ grad_out, 118 | const int *__restrict__ idx, const float *__restrict__ weight, 119 | float *__restrict__ grad_points) { 120 | int batch_index = blockIdx.x; 121 | grad_out += batch_index * n * c; 122 | idx += batch_index * n * 3; 123 | weight += batch_index * n * 3; 124 | grad_points += batch_index * m * c; 125 | 126 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 127 | const int stride = blockDim.y * blockDim.x; 128 | for (int i = index; i < c * n; i += stride) { 129 | const int l = i / n; 130 | const int j = i % n; 131 | float w1 = weight[j * 3 + 0]; 132 | float w2 = weight[j * 3 + 1]; 133 | float w3 = weight[j * 3 + 2]; 134 | 135 | int i1 = idx[j * 3 + 0]; 136 | int i2 = idx[j * 3 + 1]; 137 | int i3 = idx[j * 3 + 2]; 138 | 139 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 140 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 141 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 
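    // The three atomicAdd calls above are the adjoint of the forward weighted sum: each
    // output point's gradient is scattered back to its three nearest source points with
    // the same interpolation weights, and atomics are needed because several outputs may
    // share a source point.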
142 | } 143 | } 144 | 145 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 146 | const float *grad_out, 147 | const int *idx, const float *weight, 148 | float *grad_points) { 149 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 150 | three_interpolate_grad_kernel<<>>( 151 | b, c, n, m, grad_out, idx, weight, grad_points); 152 | 153 | CUDA_CHECK_ERRORS(); 154 | } 155 | -------------------------------------------------------------------------------- /pointnet2/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | #include "interpolate_gpu.h" 7 | 8 | 9 | __global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown, 10 | const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) { 11 | // unknown: (B, N, 3) 12 | // known: (B, M, 3) 13 | // output: 14 | // dist2: (B, N, 3) 15 | // idx: (B, N, 3) 16 | 17 | int bs_idx = blockIdx.y; 18 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 19 | if (bs_idx >= b || pt_idx >= n) return; 20 | 21 | unknown += bs_idx * n * 3 + pt_idx * 3; 22 | known += bs_idx * m * 3; 23 | dist2 += bs_idx * n * 3 + pt_idx * 3; 24 | idx += bs_idx * n * 3 + pt_idx * 3; 25 | 26 | float ux = unknown[0]; 27 | float uy = unknown[1]; 28 | float uz = unknown[2]; 29 | 30 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 31 | int besti1 = 0, besti2 = 0, besti3 = 0; 32 | for (int k = 0; k < m; ++k) { 33 | float x = known[k * 3 + 0]; 34 | float y = known[k * 3 + 1]; 35 | float z = known[k * 3 + 2]; 36 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 37 | if (d < best1) { 38 | best3 = best2; besti3 = besti2; 39 | best2 = best1; besti2 = besti1; 40 | best1 = d; besti1 = k; 41 | } 42 | else if (d < best2) { 43 | best3 = best2; besti3 = besti2; 44 | best2 = d; besti2 = k; 45 | } 46 | else if (d < best3) { 47 | best3 = d; besti3 = k; 48 | } 49 | } 50 | dist2[0] = best1; dist2[1] = best2; dist2[2] = best3; 51 | idx[0] = besti1; idx[1] = besti2; idx[2] = besti3; 52 | } 53 | 54 | 55 | void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown, 56 | const float *known, float *dist2, int *idx, cudaStream_t stream) { 57 | // unknown: (B, N, 3) 58 | // known: (B, M, 3) 59 | // output: 60 | // dist2: (B, N, 3) 61 | // idx: (B, N, 3) 62 | 63 | cudaError_t err; 64 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) 65 | dim3 threads(THREADS_PER_BLOCK); 66 | 67 | three_nn_kernel_fast<<>>(b, n, m, unknown, known, dist2, idx); 68 | 69 | err = cudaGetLastError(); 70 | if (cudaSuccess != err) { 71 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 72 | exit(-1); 73 | } 74 | } 75 | 76 | 77 | __global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points, 78 | const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) { 79 | // points: (B, C, M) 80 | // idx: (B, N, 3) 81 | // weight: (B, N, 3) 82 | // output: 83 | // out: (B, C, N) 84 | 85 | int bs_idx = blockIdx.z; 86 | int c_idx = blockIdx.y; 87 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 88 | 89 | if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; 90 | 91 | weight += bs_idx * n * 3 + pt_idx * 3; 92 | points += bs_idx * c * m + c_idx * m; 93 | idx += bs_idx * n * 3 + pt_idx * 3; 94 | out += bs_idx * c * n + c_idx * n; 95 | 96 | out[pt_idx] = weight[0] * 
points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]]; 97 | } 98 | 99 | void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, 100 | const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) { 101 | // points: (B, C, M) 102 | // idx: (B, N, 3) 103 | // weight: (B, N, 3) 104 | // output: 105 | // out: (B, C, N) 106 | 107 | cudaError_t err; 108 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 109 | dim3 threads(THREADS_PER_BLOCK); 110 | three_interpolate_kernel_fast<<>>(b, c, m, n, points, idx, weight, out); 111 | 112 | err = cudaGetLastError(); 113 | if (cudaSuccess != err) { 114 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 115 | exit(-1); 116 | } 117 | } 118 | 119 | 120 | __global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out, 121 | const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) { 122 | // grad_out: (B, C, N) 123 | // weight: (B, N, 3) 124 | // output: 125 | // grad_points: (B, C, M) 126 | 127 | int bs_idx = blockIdx.z; 128 | int c_idx = blockIdx.y; 129 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 130 | 131 | if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; 132 | 133 | grad_out += bs_idx * c * n + c_idx * n + pt_idx; 134 | weight += bs_idx * n * 3 + pt_idx * 3; 135 | grad_points += bs_idx * c * m + c_idx * m; 136 | idx += bs_idx * n * 3 + pt_idx * 3; 137 | 138 | 139 | atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]); 140 | atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]); 141 | atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]); 142 | } 143 | 144 | void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out, 145 | const int *idx, const float *weight, float *grad_points, cudaStream_t stream) { 146 | // grad_out: (B, C, N) 147 | // weight: (B, N, 3) 148 | // output: 149 | // grad_points: (B, C, M) 150 | 151 | cudaError_t err; 152 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 153 | dim3 threads(THREADS_PER_BLOCK); 154 | three_interpolate_grad_kernel_fast<<>>(b, c, n, m, grad_out, idx, weight, grad_points); 155 | 156 | err = cudaGetLastError(); 157 | if (cudaSuccess != err) { 158 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 159 | exit(-1); 160 | } 161 | } -------------------------------------------------------------------------------- /scripts/compute_multiview_projection.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import numpy as np 5 | import argparse 6 | import torch 7 | import torchvision.transforms as transforms 8 | from imageio import imread 9 | from PIL import Image 10 | from tqdm import tqdm 11 | 12 | sys.path.append(os.path.join(os.getcwd())) # HACK add the root folder 13 | from lib.config import CONF 14 | from lib.projection import ProjectionHelper 15 | 16 | # data path 17 | SCANNET_LIST = CONF.SCANNETV2_LIST 18 | SCANNET_DATA = CONF.PREP_SCANS 19 | PROJECTION_ROOT = CONF.PROJECTION 20 | PROJECTION_PATH = os.path.join(PROJECTION_ROOT, "{}_{}.npy") # scene_id, mode 21 | 22 | # scannet data 23 | # NOTE: read only! 
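# For every scene this script caches two arrays under PROJECTION_ROOT, <scene_id>_3d.npy and
# <scene_id>_2d.npy, each of shape (num_frames, num_points + 1): entry 0 of a row is the number
# of valid correspondences for that frame and the remaining entries are the point / pixel
# indices (see the compute_projection docstring below).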
24 | SCANNET_FRAME_ROOT = CONF.SCANNET_FRAMES 25 | SCANNET_FRAME_PATH = os.path.join(SCANNET_FRAME_ROOT, "{}") # name of the file 26 | 27 | # projection 28 | INTRINSICS = [[37.01983, 0, 20, 0],[0, 38.52470, 15.5, 0],[0, 0, 1, 0],[0, 0, 0, 1]] 29 | PROJECTOR = ProjectionHelper(INTRINSICS, 0.1, 4.0, [41, 32], 0.05) 30 | 31 | def get_scene_list(): 32 | with open(SCANNET_LIST, 'r') as f: 33 | scene_list = sorted(list(set(f.read().splitlines()))) 34 | 35 | return scene_list 36 | 37 | def load_scene(scene_list): 38 | scene_data = {} 39 | for scene_id in scene_list: 40 | scene_data[scene_id] = np.load(os.path.join(SCANNET_DATA, scene_id)+".npy")[:, :3] 41 | 42 | return scene_data 43 | 44 | def resize_crop_image(image, new_image_dims): 45 | image_dims = [image.shape[1], image.shape[0]] 46 | if image_dims != new_image_dims: 47 | resize_width = int(math.floor(new_image_dims[1] * float(image_dims[0]) / float(image_dims[1]))) 48 | image = transforms.Resize([new_image_dims[1], resize_width], interpolation=Image.NEAREST)(Image.fromarray(image)) 49 | image = transforms.CenterCrop([new_image_dims[1], new_image_dims[0]])(image) 50 | 51 | return np.array(image) 52 | 53 | def load_pose(filename): 54 | lines = open(filename).read().splitlines() 55 | assert len(lines) == 4 56 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] 57 | 58 | return np.asarray(lines).astype(np.float32) 59 | 60 | def load_image(file, image_dims): 61 | image = imread(file) 62 | # preprocess 63 | image = resize_crop_image(image, image_dims) 64 | if len(image.shape) == 3: # color image 65 | image = np.transpose(image, [2, 0, 1]) # move feature to front 66 | image = transforms.Normalize(mean=[0.496342, 0.466664, 0.440796], std=[0.277856, 0.28623, 0.291129])(torch.Tensor(image.astype(np.float32) / 255.0)) 67 | elif len(image.shape) == 2: # label image 68 | image = np.expand_dims(image, 0) 69 | else: 70 | raise ValueError 71 | 72 | return image 73 | 74 | def load_depth(file, image_dims): 75 | depth_image = imread(file) 76 | # preprocess 77 | depth_image = resize_crop_image(depth_image, image_dims) 78 | depth_image = depth_image.astype(np.float32) / 1000.0 79 | 80 | return depth_image 81 | 82 | def to_tensor(arr): 83 | return torch.Tensor(arr).cuda() 84 | 85 | def compute_projection(points, depth, camera_to_world): 86 | """ 87 | :param points: tensor containing all points of the point cloud (num_points, 3) 88 | :param depth: depth map (size: proj_image) 89 | :param camera_to_world: camera pose (4, 4) 90 | 91 | :return indices_3d (array with point indices that correspond to a pixel), 92 | :return indices_2d (array with pixel indices that correspond to a point) 93 | 94 | note: 95 | the first digit of indices represents the number of relevant points 96 | the rest digits are for the projection mapping 97 | """ 98 | num_points = points.shape[0] 99 | num_frames = depth.shape[0] 100 | indices_3ds = torch.zeros(num_frames, num_points + 1).long().cuda() 101 | indices_2ds = torch.zeros(num_frames, num_points + 1).long().cuda() 102 | 103 | for i in range(num_frames): 104 | indices = PROJECTOR.compute_projection(to_tensor(points), to_tensor(depth[i]), to_tensor(camera_to_world[i])) 105 | if indices: 106 | indices_3ds[i] = indices[0].long() 107 | indices_2ds[i] = indices[1].long() 108 | 109 | return indices_3ds, indices_2ds 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument('--gpu', type=str, help='gpu', default='0') 114 | args = parser.parse_args() 115 | 116 | # setting 117 
| os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 118 | os.environ["CUDA_LAUNCH_BLOCKING"] = "1" 119 | 120 | # make dir 121 | os.makedirs(PROJECTION_ROOT, exist_ok=True) 122 | scene_list = get_scene_list() 123 | scene_data = load_scene(scene_list) 124 | 125 | print("computing multiview projections...") 126 | for scene_id in tqdm(scene_list): 127 | # if os.path.exists(PROJECTION_PATH.format(scene_id, "3d")) and os.path.exists(PROJECTION_PATH.format(scene_id, "2d")): 128 | # print("skipping {}...".format(scene_id)) 129 | # continue 130 | 131 | point_cloud = scene_data[scene_id] 132 | frame_list = list(map(lambda x: x.split(".")[0], os.listdir(SCANNET_FRAME_ROOT.format(scene_id, "color")))) 133 | 134 | # load frames 135 | scene_images = np.zeros((len(frame_list), 3, 256, 328)) 136 | scene_depths = np.zeros((len(frame_list), 32, 41)) 137 | scene_poses = np.zeros((len(frame_list), 4, 4)) 138 | for i, frame_id in enumerate(frame_list): 139 | scene_images[i] = load_image(SCANNET_FRAME_PATH.format(scene_id, "color", "{}.jpg".format(frame_id)), [328, 256]) 140 | scene_depths[i] = load_depth(SCANNET_FRAME_PATH.format(scene_id, "depth", "{}.png".format(frame_id)), [41, 32]) 141 | scene_poses[i] = load_pose(SCANNET_FRAME_PATH.format(scene_id, "pose", "{}.txt".format(frame_id))) 142 | 143 | projection_3d, projection_2d = compute_projection(point_cloud[:, :3], scene_depths, scene_poses) 144 | np.save(PROJECTION_PATH.format(scene_id, "3d"), projection_3d.cpu().numpy()) 145 | np.save(PROJECTION_PATH.format(scene_id, "2d"), projection_2d.cpu().numpy()) 146 | 147 | print("done!") 148 | 149 | -------------------------------------------------------------------------------- /scripts/visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import importlib 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from plyfile import PlyElement, PlyData 8 | 9 | # for PointNet2.PyTorch module 10 | import sys 11 | sys.path.append(".") 12 | sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../pointnet2/')) 13 | from lib.config import CONF 14 | from lib.dataset import ScannetDatasetWholeScene, collate_wholescene 15 | 16 | def forward(args, model, coords, feats): 17 | pred = [] 18 | coord_chunk, feat_chunk = torch.split(coords.squeeze(0), args.batch_size, 0), torch.split(feats.squeeze(0), args.batch_size, 0) 19 | assert len(coord_chunk) == len(feat_chunk) 20 | for coord, feat in zip(coord_chunk, feat_chunk): 21 | output = model(torch.cat([coord, feat], dim=2)) 22 | pred.append(output) 23 | 24 | pred = torch.cat(pred, dim=0) # (CK, N, C) 25 | outputs = pred.max(2)[1] 26 | 27 | return outputs 28 | 29 | def filter_points(coords, preds): 30 | assert coords.shape[0] == preds.shape[0] 31 | 32 | _, coord_ids = np.unique(coords, axis=0, return_index=True) 33 | coord_filtered, pred_filtered = coords[coord_ids], preds[coord_ids] 34 | # coord_filtered, pred_filtered = coords, preds 35 | filtered = [] 36 | for point_idx in range(coord_filtered.shape[0]): 37 | filtered.append( 38 | [ 39 | coord_filtered[point_idx][0], 40 | coord_filtered[point_idx][1], 41 | coord_filtered[point_idx][2], 42 | CONF.PALETTE[pred_filtered[point_idx]][0], 43 | CONF.PALETTE[pred_filtered[point_idx]][1], 44 | CONF.PALETTE[pred_filtered[point_idx]][2] 45 | ] 46 | ) 47 | 48 | return np.array(filtered) 49 | 50 | 51 | def predict_label(args, model, dataloader): 52 | output_coords, output_preds = [], [] 53 | print("predicting 
labels...") 54 | for data in dataloader: 55 | # unpack 56 | coords, feats, targets, weights, _ = data 57 | coords, feats, targets, weights = coords.cuda(), feats.cuda(), targets.cuda(), weights.cuda() 58 | 59 | # feed 60 | preds = forward(args, model, coords, feats) 61 | 62 | # dump 63 | coords = coords.squeeze(0).view(-1, 3).cpu().numpy() 64 | preds = preds.view(-1).cpu().numpy() 65 | output_coords.append(coords) 66 | output_preds.append(preds) 67 | 68 | print("filtering points...") 69 | output_coords = np.concatenate(output_coords, axis=0) 70 | output_preds = np.concatenate(output_preds, axis=0) 71 | filtered = filter_points(output_coords, output_preds) 72 | 73 | return filtered 74 | 75 | def visualize(args, preds): 76 | vertex = [] 77 | for i in range(preds.shape[0]): 78 | vertex.append( 79 | ( 80 | preds[i][0], 81 | preds[i][1], 82 | preds[i][2], 83 | preds[i][3], 84 | preds[i][4], 85 | preds[i][5], 86 | ) 87 | ) 88 | 89 | vertex = np.array( 90 | vertex, 91 | dtype=[ 92 | ("x", np.dtype("float32")), 93 | ("y", np.dtype("float32")), 94 | ("z", np.dtype("float32")), 95 | ("red", np.dtype("uint8")), 96 | ("green", np.dtype("uint8")), 97 | ("blue", np.dtype("uint8")) 98 | ] 99 | ) 100 | 101 | output_pc = PlyElement.describe(vertex, "vertex") 102 | output_pc = PlyData([output_pc]) 103 | output_root = os.path.join(CONF.OUTPUT_ROOT, args.folder, "preds") 104 | os.makedirs(output_root, exist_ok=True) 105 | output_pc.write(os.path.join(output_root, "{}.ply".format(args.scene_id))) 106 | 107 | 108 | def get_scene_list(args): 109 | scene_list = [] 110 | if args.scene_id: 111 | scene_list.append(args.scene_id) 112 | else: 113 | raise ValueError("Select a scene to visualize") 114 | 115 | return scene_list 116 | 117 | def evaluate(args): 118 | # prepare data 119 | print("preparing data...") 120 | scene_list = get_scene_list(args) 121 | dataset = ScannetDatasetWholeScene(scene_list, use_color=args.use_color, use_normal=args.use_normal, use_multiview=args.use_multiview) 122 | dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_wholescene) 123 | 124 | # load model 125 | print("loading model...") 126 | model_path = os.path.join(CONF.OUTPUT_ROOT, args.folder, "model.pth") 127 | Pointnet = importlib.import_module("pointnet2_semseg") 128 | input_channels = int(args.use_color) * 3 + int(args.use_normal) * 3 + int(args.use_multiview) * 128 129 | model = Pointnet.get_model(num_classes=CONF.NUM_CLASSES, is_msg=args.use_msg, input_channels=input_channels, use_xyz=not args.no_xyz, bn=not args.no_bn).cuda() 130 | model.load_state_dict(torch.load(model_path)) 131 | model.eval() 132 | 133 | # predict 134 | print("predicting...") 135 | preds = predict_label(args, model, dataloader) 136 | 137 | # visualize 138 | print("visualizing...") 139 | visualize(args, preds) 140 | 141 | 142 | if __name__ == "__main__": 143 | parser = argparse.ArgumentParser() 144 | parser.add_argument('--folder', type=str, help='output folder containing the best model from training', required=True) 145 | parser.add_argument('--batch_size', type=int, help='size of the batch/chunk', default=1) 146 | parser.add_argument('--gpu', type=str, help='gpu', default='0') 147 | parser.add_argument("--scene_id", type=str, default=None) 148 | parser.add_argument('--no_bn', action="store_true", help="do not apply batch normalization in pointnet++") 149 | parser.add_argument('--no_xyz', action="store_true", help="do not apply coordinates as features in pointnet++") 150 | parser.add_argument("--use_msg", action="store_true", help="apply multiscale 
grouping or not") 151 | parser.add_argument("--use_color", action="store_true", help="use color values or not") 152 | parser.add_argument("--use_normal", action="store_true", help="use normals or not") 153 | parser.add_argument("--use_multiview", action="store_true", help="use multiview image features or not") 154 | args = parser.parse_args() 155 | 156 | # setting 157 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 158 | os.environ["CUDA_LAUNCH_BLOCKING"] = "1" 159 | 160 | evaluate(args) 161 | print("done!") -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import argparse 5 | import importlib 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import numpy as np 10 | from datetime import datetime 11 | from torch.utils.data import DataLoader 12 | 13 | sys.path.append(os.path.join(os.getcwd())) # HACK add the root folder 14 | from lib.solver import Solver 15 | from lib.dataset import ScannetDataset, ScannetDatasetWholeScene, collate_random, collate_wholescene 16 | from lib.loss import WeightedCrossEntropyLoss 17 | from lib.config import CONF 18 | 19 | 20 | def get_dataloader(args, scene_list, phase): 21 | if args.use_wholescene: 22 | dataset = ScannetDatasetWholeScene(scene_list, is_weighting=not args.no_weighting, use_color=args.use_color, use_normal=args.use_normal, use_multiview=args.use_multiview) 23 | dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_wholescene, num_workers=args.num_workers, pin_memory=True) 24 | else: 25 | dataset = ScannetDataset(phase, scene_list, is_weighting=not args.no_weighting, use_color=args.use_color, use_normal=args.use_normal, use_multiview=args.use_multiview) 26 | dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate_random, num_workers=args.num_workers, pin_memory=True) 27 | 28 | return dataset, dataloader 29 | 30 | def get_num_params(model): 31 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 32 | num_params = int(sum([np.prod(p.size()) for p in model_parameters])) 33 | 34 | return num_params 35 | 36 | def get_solver(args, dataset, dataloader, stamp, weight): 37 | sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../pointnet2/')) 38 | Pointnet = importlib.import_module("pointnet2_semseg") 39 | input_channels = int(args.use_color) * 3 + int(args.use_normal) * 3 + int(args.use_multiview) * 128 40 | model = Pointnet.get_model(num_classes=CONF.NUM_CLASSES, is_msg=args.use_msg, input_channels=input_channels, use_xyz=not args.no_xyz, bn=not args.no_bn).cuda() 41 | 42 | num_params = get_num_params(model) 43 | criterion = WeightedCrossEntropyLoss() 44 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd) 45 | solver = Solver(model, dataset, dataloader, criterion, optimizer, args.batch_size, stamp, args.use_wholescene, args.ds, args.df) 46 | 47 | return solver, num_params 48 | 49 | def get_scene_list(path): 50 | scene_list = [] 51 | with open(path) as f: 52 | for scene_id in f.readlines(): 53 | scene_list.append(scene_id.strip()) 54 | 55 | return scene_list 56 | 57 | def save_info(args, root, train_examples, val_examples, num_params): 58 | info = {} 59 | for key, value in vars(args).items(): 60 | info[key] = value 61 | 62 | info["num_train"] = train_examples 63 | info["num_val"] = val_examples 64 | info["num_params"] = num_params 65 | 66 | 
with open(os.path.join(root, "info.json"), "w") as f: 67 | json.dump(info, f, indent=4) 68 | 69 | def train(args): 70 | # init training dataset 71 | print("preparing data...") 72 | if args.debug: 73 | train_scene_list = ["scene0000_00"] 74 | val_scene_list = ["scene0000_00"] 75 | else: 76 | train_scene_list = get_scene_list(CONF.SCANNETV2_TRAIN) 77 | val_scene_list = get_scene_list(CONF.SCANNETV2_VAL) 78 | 79 | # dataloader 80 | train_dataset, train_dataloader = get_dataloader(args, train_scene_list, "train") 81 | val_dataset, val_dataloader = get_dataloader(args, val_scene_list, "val") 82 | dataset = { 83 | "train": train_dataset, 84 | "val": val_dataset 85 | } 86 | dataloader = { 87 | "train": train_dataloader, 88 | "val": val_dataloader 89 | } 90 | weight = train_dataset.labelweights 91 | train_examples = len(train_dataset) 92 | val_examples = len(val_dataset) 93 | 94 | print("initializing...") 95 | stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 96 | if args.tag: stamp += "_"+args.tag.upper() 97 | root = os.path.join(CONF.OUTPUT_ROOT, stamp) 98 | os.makedirs(root, exist_ok=True) 99 | solver, num_params = get_solver(args, dataset, dataloader, stamp, weight) 100 | 101 | print("\n[info]") 102 | print("Train examples: {}".format(train_examples)) 103 | print("Evaluation examples: {}".format(val_examples)) 104 | print("Start training...\n") 105 | save_info(args, root, train_examples, val_examples, num_params) 106 | solver(args.epoch, args.verbose) 107 | 108 | if __name__ == '__main__': 109 | parser = argparse.ArgumentParser() 110 | parser.add_argument('--tag', type=str, help="tag for the training, e.g. cuda_wl", default="cuda_wl") 111 | parser.add_argument('--gpu', type=str, help='gpu', default='0') 112 | parser.add_argument('--batch_size', type=int, help='batch size', default=32) 113 | parser.add_argument('--epoch', type=int, help='number of epochs', default=500) 114 | parser.add_argument('--verbose', type=int, help='iterations of showing verbose', default=10) 115 | parser.add_argument('--num_workers', type=int, help='number of workers in dataloader', default=0) 116 | parser.add_argument('--lr', type=float, help='learning rate', default=1e-3) 117 | parser.add_argument('--wd', type=float, help='weight decay', default=0) 118 | parser.add_argument('--ds', type=int, help='decay step', default=100) 119 | parser.add_argument('--df', type=float, help='decay factor', default=0.7) 120 | parser.add_argument("--debug", action="store_true") 121 | parser.add_argument("--no_weighting", action="store_true", help="weight the classes") 122 | parser.add_argument('--no_bn', action="store_true", help="do not apply batch normalization in pointnet++") 123 | parser.add_argument('--no_xyz', action="store_true", help="do not apply coordinates as features in pointnet++") 124 | parser.add_argument("--use_wholescene", action="store_true", help="on the whole scene or on a random chunk") 125 | parser.add_argument("--use_msg", action="store_true", help="apply multiscale grouping or not") 126 | parser.add_argument("--use_color", action="store_true", help="use color values or not") 127 | parser.add_argument("--use_normal", action="store_true", help="use normals or not") 128 | parser.add_argument("--use_multiview", action="store_true", help="use multiview image features or not") 129 | args = parser.parse_args() 130 | 131 | # setting 132 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 133 | os.environ["CUDA_LAUNCH_BLOCKING"] = "1" 134 | 135 | train(args) 
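For reference, a minimal sketch (not part of the repo) of exercising the model entry point that train.py imports, pointnet2_semseg.get_model (shown further below). The class count, chunk size and feature flags are placeholder values, it assumes the compiled pointnet2 extension is already installed, and it is meant to be run from the repository root:

```python
import importlib
import sys

import torch

sys.path.insert(0, "pointnet2")  # make pointnet2_semseg and its helpers importable

Pointnet = importlib.import_module("pointnet2_semseg")

# same channel bookkeeping as scripts/train.py: 3 for color, 3 for normals, 128 for multiview
use_color, use_normal, use_multiview = True, True, False
input_channels = int(use_color) * 3 + int(use_normal) * 3 + int(use_multiview) * 128

# the real scripts read the class count from CONF.NUM_CLASSES; 21 is only a placeholder here
model = Pointnet.get_model(num_classes=21, is_msg=False,
                           input_channels=input_channels, use_xyz=True, bn=True).cuda()
model.eval()

# a random chunk: xyz in the first 3 columns, per-point features in the remaining ones
points = torch.rand(2, 8192, 3 + input_channels).cuda()

with torch.no_grad():
    logits = model(points)       # (B, N, num_classes)
    labels = logits.max(2)[1]    # per-point class ids, as in scripts/visualize.py

print(labels.shape)  # torch.Size([2, 8192])
```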
-------------------------------------------------------------------------------- /scripts/project_multiview_features.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import h5py 4 | import torch 5 | import torch.nn as nn 6 | import argparse 7 | import numpy as np 8 | from tqdm import tqdm 9 | from plyfile import PlyData, PlyElement 10 | import math 11 | from imageio import imread 12 | from PIL import Image 13 | import torchvision.transforms as transforms 14 | 15 | sys.path.append(os.path.join(os.getcwd())) # HACK add the root folder 16 | from lib.config import CONF 17 | from lib.projection import ProjectionHelper 18 | 19 | SCANNET_LIST = CONF.SCANNETV2_LIST 20 | SCANNET_DATA = CONF.PREP_SCANS 21 | SCANNET_FRAME_ROOT = CONF.SCANNET_FRAMES 22 | SCANNET_FRAME_PATH = os.path.join(SCANNET_FRAME_ROOT, "{}") # name of the file 23 | 24 | ENET_FEATURE_PATH = CONF.ENET_FEATURES_PATH 25 | ENET_FEATURE_DATABASE = CONF.MULTIVIEW 26 | 27 | # projection 28 | INTRINSICS = [[37.01983, 0, 20, 0],[0, 38.52470, 15.5, 0],[0, 0, 1, 0],[0, 0, 0, 1]] 29 | PROJECTOR = ProjectionHelper(INTRINSICS, 0.1, 4.0, [41, 32], 0.05) 30 | 31 | def get_scene_list(): 32 | with open(SCANNET_LIST, 'r') as f: 33 | return sorted(list(set(f.read().splitlines()))) 34 | 35 | def to_tensor(arr): 36 | return torch.Tensor(arr).cuda() 37 | 38 | def resize_crop_image(image, new_image_dims): 39 | image_dims = [image.shape[1], image.shape[0]] 40 | if image_dims == new_image_dims: 41 | return image 42 | resize_width = int(math.floor(new_image_dims[1] * float(image_dims[0]) / float(image_dims[1]))) 43 | image = transforms.Resize([new_image_dims[1], resize_width], interpolation=Image.NEAREST)(Image.fromarray(image)) 44 | image = transforms.CenterCrop([new_image_dims[1], new_image_dims[0]])(image) 45 | image = np.array(image) 46 | 47 | return image 48 | 49 | def load_image(file, image_dims): 50 | image = imread(file) 51 | # preprocess 52 | image = resize_crop_image(image, image_dims) 53 | if len(image.shape) == 3: # color image 54 | image = np.transpose(image, [2, 0, 1]) # move feature to front 55 | image = transforms.Normalize(mean=[0.496342, 0.466664, 0.440796], std=[0.277856, 0.28623, 0.291129])(torch.Tensor(image.astype(np.float32) / 255.0)) 56 | elif len(image.shape) == 2: # label image 57 | # image = np.expand_dims(image, 0) 58 | pass 59 | else: 60 | raise 61 | 62 | return image 63 | 64 | def load_pose(filename): 65 | lines = open(filename).read().splitlines() 66 | assert len(lines) == 4 67 | lines = [[x[0],x[1],x[2],x[3]] for x in (x.split(" ") for x in lines)] 68 | 69 | return np.asarray(lines).astype(np.float32) 70 | 71 | def load_depth(file, image_dims): 72 | depth_image = imread(file) 73 | # preprocess 74 | depth_image = resize_crop_image(depth_image, image_dims) 75 | depth_image = depth_image.astype(np.float32) / 1000.0 76 | 77 | return depth_image 78 | 79 | def get_scene_data(scene_list): 80 | scene_data = {} 81 | for scene_id in scene_list: 82 | scene_data[scene_id] = np.load(os.path.join(SCANNET_DATA, scene_id)+".npy")[:, :3] 83 | 84 | return scene_data 85 | 86 | def compute_projection(points, depth, camera_to_world): 87 | """ 88 | :param points: tensor containing all points of the point cloud (num_points, 3) 89 | :param depth: depth map (size: proj_image) 90 | :param camera_to_world: camera pose (4, 4) 91 | 92 | :return indices_3d (array with point indices that correspond to a pixel), 93 | :return indices_2d (array with pixel indices that correspond to a point) 
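    both returned tensors are LongTensors of shape (num_frames, num_points + 1); rows stay zero for frames without any valid projection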
94 | 95 | note: 96 | the first digit of indices represents the number of relevant points 97 | the rest digits are for the projection mapping 98 | """ 99 | num_points = points.shape[0] 100 | num_frames = depth.shape[0] 101 | indices_3ds = torch.zeros(num_frames, num_points + 1).long().cuda() 102 | indices_2ds = torch.zeros(num_frames, num_points + 1).long().cuda() 103 | 104 | for i in range(num_frames): 105 | indices = PROJECTOR.compute_projection(to_tensor(points), to_tensor(depth[i]), to_tensor(camera_to_world[i])) 106 | if indices: 107 | indices_3ds[i] = indices[0].long() 108 | indices_2ds[i] = indices[1].long() 109 | 110 | return indices_3ds, indices_2ds 111 | 112 | if __name__ == "__main__": 113 | scene_list = get_scene_list() 114 | scene_data = get_scene_data(scene_list) 115 | with h5py.File(ENET_FEATURE_DATABASE, "w", libver="latest") as database: 116 | print("projecting multiview features to point cloud...") 117 | for scene_id in tqdm(scene_list): 118 | scene = scene_data[scene_id] 119 | # load frames 120 | frame_list = list(map(lambda x: x.split(".")[0], os.listdir(SCANNET_FRAME_ROOT.format(scene_id, "color")))) 121 | scene_images = np.zeros((len(frame_list), 3, 256, 328)) 122 | scene_depths = np.zeros((len(frame_list), 32, 41)) 123 | scene_poses = np.zeros((len(frame_list), 4, 4)) 124 | for i, frame_id in enumerate(frame_list): 125 | scene_images[i] = load_image(SCANNET_FRAME_PATH.format(scene_id, "color", "{}.jpg".format(frame_id)), [328, 256]) 126 | scene_depths[i] = load_depth(SCANNET_FRAME_PATH.format(scene_id, "depth", "{}.png".format(frame_id)), [41, 32]) 127 | scene_poses[i] = load_pose(SCANNET_FRAME_PATH.format(scene_id, "pose", "{}.txt".format(frame_id))) 128 | 129 | # compute projections for each chunk 130 | projection_3d, projection_2d = compute_projection(scene, scene_depths, scene_poses) 131 | _, inds = torch.sort(projection_3d[:, 0], descending=True) 132 | projection_3d, projection_2d = projection_3d[inds], projection_2d[inds] 133 | 134 | # compute valid projections 135 | projections = [] 136 | for i in range(projection_3d.shape[0]): 137 | num_valid = projection_3d[i, 0] 138 | if num_valid == 0: 139 | continue 140 | 141 | projections.append((frame_list[inds[i].long().item()], projection_3d[i], projection_2d[i])) 142 | 143 | # project 144 | point_features = to_tensor(scene).new(scene.shape[0], 128).fill_(0) 145 | for i, projection in enumerate(projections): 146 | frame_id = projection[0] 147 | projection_3d = projection[1] 148 | projection_2d = projection[2] 149 | feat = to_tensor(np.load(ENET_FEATURE_PATH.format(scene_id, frame_id))) 150 | proj_feat = PROJECTOR.project(feat, projection_3d, projection_2d, scene.shape[0]).transpose(1, 0) 151 | if i == 0: 152 | point_features = proj_feat 153 | else: 154 | mask = ((point_features == 0).sum(1) == 128).nonzero().squeeze(1) 155 | point_features[mask] = proj_feat[mask] 156 | 157 | # save 158 | database.create_dataset(scene_id, data=point_features.cpu().numpy()) 159 | 160 | print("done!") 161 | 162 | 163 | -------------------------------------------------------------------------------- /pointnet2/pointnet2_semseg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pointnet2_modules import PointnetFPModule, PointnetSAModule, PointnetSAModuleMSG 4 | import pytorch_utils as pt_utils 5 | 6 | 7 | def get_model(num_classes, is_msg=True, input_channels=6, use_xyz=True, bn=True): 8 | if is_msg: 9 | model = Pointnet2MSG( 10 | num_classes=num_classes, 
11 | input_channels=input_channels, 12 | use_xyz=use_xyz, 13 | bn=bn 14 | ) 15 | else: 16 | model = Pointnet2SSG( 17 | num_classes=num_classes, 18 | input_channels=input_channels, 19 | use_xyz=use_xyz, 20 | bn=bn 21 | ) 22 | 23 | return model 24 | 25 | class Pointnet2MSG(nn.Module): 26 | def __init__(self, num_classes, input_channels=3, use_xyz=True, bn=True): 27 | super().__init__() 28 | 29 | NPOINTS = [1024, 256, 64, 16] 30 | RADIUS = [[0.05, 0.1], [0.1, 0.2], [0.2, 0.4], [0.4, 0.8]] 31 | NSAMPLE = [[16, 32], [16, 32], [16, 32], [16, 32]] 32 | MLPS = [[[16, 16, 32], [32, 32, 64]], [[64, 64, 128], [64, 96, 128]], 33 | [[128, 196, 256], [128, 196, 256]], [[256, 256, 512], [256, 384, 512]]] 34 | FP_MLPS = [[128, 128], [256, 256], [512, 512], [512, 512]] 35 | CLS_FC = [128] 36 | DP_RATIO = 0.5 37 | 38 | self.SA_modules = nn.ModuleList() 39 | channel_in = input_channels 40 | 41 | skip_channel_list = [input_channels] 42 | for k in range(NPOINTS.__len__()): 43 | mlps = MLPS[k].copy() 44 | channel_out = 0 45 | for idx in range(mlps.__len__()): 46 | mlps[idx] = [channel_in] + mlps[idx] 47 | channel_out += mlps[idx][-1] 48 | 49 | self.SA_modules.append( 50 | PointnetSAModuleMSG( 51 | npoint=NPOINTS[k], 52 | radii=RADIUS[k], 53 | nsamples=NSAMPLE[k], 54 | mlps=mlps, 55 | use_xyz=use_xyz, 56 | bn=bn 57 | ) 58 | ) 59 | skip_channel_list.append(channel_out) 60 | channel_in = channel_out 61 | 62 | self.FP_modules = nn.ModuleList() 63 | 64 | for k in range(FP_MLPS.__len__()): 65 | pre_channel = FP_MLPS[k + 1][-1] if k + 1 < len(FP_MLPS) else channel_out 66 | self.FP_modules.append( 67 | PointnetFPModule( 68 | mlp=[pre_channel + skip_channel_list[k]] + FP_MLPS[k], 69 | bn=bn 70 | ) 71 | ) 72 | 73 | cls_layers = [] 74 | pre_channel = FP_MLPS[0][-1] 75 | for k in range(0, CLS_FC.__len__()): 76 | cls_layers.append(pt_utils.Conv1d(pre_channel, CLS_FC[k], bn=bn)) 77 | pre_channel = CLS_FC[k] 78 | cls_layers.append(pt_utils.Conv1d(pre_channel, num_classes, activation=None, bn=bn)) 79 | cls_layers.insert(1, nn.Dropout(DP_RATIO)) 80 | self.cls_layer = nn.Sequential(*cls_layers) 81 | 82 | def _break_up_pc(self, pc): 83 | xyz = pc[..., 0:3].contiguous() 84 | features = pc[..., 3:].transpose(1, 2).contiguous() if pc.size(-1) > 3 else None 85 | 86 | return xyz, features 87 | 88 | def forward(self, pointcloud: torch.cuda.FloatTensor): 89 | xyz, features = self._break_up_pc(pointcloud) 90 | 91 | l_xyz, l_features = [xyz], [features] 92 | for i in range(len(self.SA_modules)): 93 | li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i]) 94 | l_xyz.append(li_xyz) 95 | l_features.append(li_features) 96 | 97 | for i in range(-1, -(len(self.FP_modules) + 1), -1): 98 | l_features[i - 1] = self.FP_modules[i]( 99 | l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i] 100 | ) 101 | 102 | pred_cls = self.cls_layer(l_features[0]).transpose(1, 2).contiguous() # (B, N, 1) 103 | return pred_cls 104 | 105 | class Pointnet2SSG(nn.Module): 106 | def __init__(self, num_classes, input_channels=3, use_xyz=True, bn=True): 107 | super().__init__() 108 | 109 | NPOINTS = [1024, 256, 64, 16] 110 | RADIUS = [0.1, 0.2, 0.4, 0.8] 111 | NSAMPLE = [32, 32, 32, 32] 112 | MLPS = [[32, 32, 64], [64, 64, 128], 113 | [128, 128, 256], [256, 256, 512]] 114 | FP_MLPS = [[128, 128], [256, 128], [256, 256], [256, 256]] 115 | CLS_FC = [128] 116 | DP_RATIO = 0.5 117 | 118 | self.SA_modules = nn.ModuleList() 119 | channel_in = input_channels 120 | 121 | skip_channel_list = [input_channels] 122 | for k in range(NPOINTS.__len__()): 123 | mlps = 
MLPS[k].copy() 124 | channel_out = 0 125 | mlps = [channel_in] + mlps 126 | channel_out += mlps[-1] 127 | 128 | self.SA_modules.append( 129 | PointnetSAModule( 130 | npoint=NPOINTS[k], 131 | radius=RADIUS[k], 132 | nsample=NSAMPLE[k], 133 | mlp=mlps, 134 | use_xyz=use_xyz, 135 | bn=bn 136 | ) 137 | ) 138 | skip_channel_list.append(channel_out) 139 | channel_in = channel_out 140 | 141 | self.FP_modules = nn.ModuleList() 142 | 143 | for k in range(FP_MLPS.__len__()): 144 | pre_channel = FP_MLPS[k + 1][-1] if k + 1 < len(FP_MLPS) else channel_out 145 | self.FP_modules.append( 146 | PointnetFPModule( 147 | mlp=[pre_channel + skip_channel_list[k]] + FP_MLPS[k], 148 | bn=bn 149 | ) 150 | ) 151 | 152 | cls_layers = [] 153 | pre_channel = FP_MLPS[0][-1] 154 | for k in range(0, CLS_FC.__len__()): 155 | cls_layers.append(pt_utils.Conv1d(pre_channel, CLS_FC[k], bn=bn)) 156 | pre_channel = CLS_FC[k] 157 | cls_layers.append(pt_utils.Conv1d(pre_channel, num_classes, activation=None, bn=bn)) 158 | cls_layers.insert(1, nn.Dropout(DP_RATIO)) 159 | self.cls_layer = nn.Sequential(*cls_layers) 160 | 161 | def _break_up_pc(self, pc): 162 | xyz = pc[..., 0:3].contiguous() 163 | features = pc[..., 3:].transpose(1, 2).contiguous() if pc.size(-1) > 3 else None 164 | 165 | return xyz, features 166 | 167 | def forward(self, pointcloud: torch.cuda.FloatTensor): 168 | xyz, features = self._break_up_pc(pointcloud) 169 | 170 | l_xyz, l_features = [xyz], [features] 171 | for i in range(len(self.SA_modules)): 172 | li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i]) 173 | l_xyz.append(li_xyz) 174 | l_features.append(li_features) 175 | 176 | for i in range(-1, -(len(self.FP_modules) + 1), -1): 177 | l_features[i - 1] = self.FP_modules[i]( 178 | l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i] 179 | ) 180 | 181 | pred_cls = self.cls_layer(l_features[0]).transpose(1, 2).contiguous() # (B, N, 1) 182 | return pred_cls -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, m) 7 | // output: out(b, c, m) 8 | __global__ void gather_points_kernel(int b, int c, int n, int m, 9 | const float *__restrict__ points, 10 | const int *__restrict__ idx, 11 | float *__restrict__ out) { 12 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 13 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 14 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 15 | int a = idx[i * m + j]; 16 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; 17 | } 18 | } 19 | } 20 | } 21 | 22 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 23 | const float *points, const int *idx, 24 | float *out) { 25 | gather_points_kernel<<>>(b, c, n, npoints, 27 | points, idx, out); 28 | 29 | CUDA_CHECK_ERRORS(); 30 | } 31 | 32 | // input: grad_out(b, c, m) idx(b, m) 33 | // output: grad_points(b, c, n) 34 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m, 35 | const float *__restrict__ grad_out, 36 | const int *__restrict__ idx, 37 | float *__restrict__ grad_points) { 38 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 39 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 40 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 41 | int a = idx[i * m + j]; 42 | atomicAdd(grad_points + (i * c + l) * n + a, 43 | grad_out[(i * c + l) * m + j]); 44 | } 45 | } 46 | } 
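  // The backward pass of gather is a scatter-add: several output columns j can
  // reference the same source index idx[i * m + j], so their gradients must be
  // accumulated with atomicAdd rather than overwritten.  A hypothetical PyTorch
  // reference for a single (batch, channel) slice (names illustrative only):
  //   grad_points = torch.zeros(n)
  //   grad_points.index_add_(0, idx, grad_out)   // idx: (m,), grad_out: (m,)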
47 | } 48 | 49 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 50 | const float *grad_out, const int *idx, 51 | float *grad_points) { 52 | gather_points_grad_kernel<<>>( 54 | b, c, n, npoints, grad_out, idx, grad_points); 55 | 56 | CUDA_CHECK_ERRORS(); 57 | } 58 | 59 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, 60 | int idx1, int idx2) { 61 | const float v1 = dists[idx1], v2 = dists[idx2]; 62 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 63 | dists[idx1] = max(v1, v2); 64 | dists_i[idx1] = v2 > v1 ? i2 : i1; 65 | } 66 | 67 | // Input dataset: (b, n, 3), tmp: (b, n) 68 | // Ouput idxs (b, m) 69 | template 70 | __global__ void furthest_point_sampling_kernel( 71 | int b, int n, int m, const float *__restrict__ dataset, 72 | float *__restrict__ temp, int *__restrict__ idxs) { 73 | if (m <= 0) return; 74 | __shared__ float dists[block_size]; 75 | __shared__ int dists_i[block_size]; 76 | 77 | int batch_index = blockIdx.x; 78 | dataset += batch_index * n * 3; 79 | temp += batch_index * n; 80 | idxs += batch_index * m; 81 | 82 | int tid = threadIdx.x; 83 | const int stride = block_size; 84 | 85 | int old = 0; 86 | if (threadIdx.x == 0) idxs[0] = old; 87 | 88 | __syncthreads(); 89 | for (int j = 1; j < m; j++) { 90 | int besti = 0; 91 | float best = -1; 92 | float x1 = dataset[old * 3 + 0]; 93 | float y1 = dataset[old * 3 + 1]; 94 | float z1 = dataset[old * 3 + 2]; 95 | for (int k = tid; k < n; k += stride) { 96 | float x2, y2, z2; 97 | x2 = dataset[k * 3 + 0]; 98 | y2 = dataset[k * 3 + 1]; 99 | z2 = dataset[k * 3 + 2]; 100 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 101 | if (mag <= 1e-3) continue; 102 | 103 | float d = 104 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 105 | 106 | float d2 = min(d, temp[k]); 107 | temp[k] = d2; 108 | besti = d2 > best ? k : besti; 109 | best = d2 > best ? 
d2 : best; 110 | } 111 | dists[tid] = best; 112 | dists_i[tid] = besti; 113 | __syncthreads(); 114 | 115 | if (block_size >= 512) { 116 | if (tid < 256) { 117 | __update(dists, dists_i, tid, tid + 256); 118 | } 119 | __syncthreads(); 120 | } 121 | if (block_size >= 256) { 122 | if (tid < 128) { 123 | __update(dists, dists_i, tid, tid + 128); 124 | } 125 | __syncthreads(); 126 | } 127 | if (block_size >= 128) { 128 | if (tid < 64) { 129 | __update(dists, dists_i, tid, tid + 64); 130 | } 131 | __syncthreads(); 132 | } 133 | if (block_size >= 64) { 134 | if (tid < 32) { 135 | __update(dists, dists_i, tid, tid + 32); 136 | } 137 | __syncthreads(); 138 | } 139 | if (block_size >= 32) { 140 | if (tid < 16) { 141 | __update(dists, dists_i, tid, tid + 16); 142 | } 143 | __syncthreads(); 144 | } 145 | if (block_size >= 16) { 146 | if (tid < 8) { 147 | __update(dists, dists_i, tid, tid + 8); 148 | } 149 | __syncthreads(); 150 | } 151 | if (block_size >= 8) { 152 | if (tid < 4) { 153 | __update(dists, dists_i, tid, tid + 4); 154 | } 155 | __syncthreads(); 156 | } 157 | if (block_size >= 4) { 158 | if (tid < 2) { 159 | __update(dists, dists_i, tid, tid + 2); 160 | } 161 | __syncthreads(); 162 | } 163 | if (block_size >= 2) { 164 | if (tid < 1) { 165 | __update(dists, dists_i, tid, tid + 1); 166 | } 167 | __syncthreads(); 168 | } 169 | 170 | old = dists_i[0]; 171 | if (tid == 0) idxs[j] = old; 172 | } 173 | } 174 | 175 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 176 | const float *dataset, float *temp, 177 | int *idxs) { 178 | unsigned int n_threads = opt_n_threads(n); 179 | 180 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 181 | 182 | switch (n_threads) { 183 | case 512: 184 | furthest_point_sampling_kernel<512> 185 | <<>>(b, n, m, dataset, temp, idxs); 186 | break; 187 | case 256: 188 | furthest_point_sampling_kernel<256> 189 | <<>>(b, n, m, dataset, temp, idxs); 190 | break; 191 | case 128: 192 | furthest_point_sampling_kernel<128> 193 | <<>>(b, n, m, dataset, temp, idxs); 194 | break; 195 | case 64: 196 | furthest_point_sampling_kernel<64> 197 | <<>>(b, n, m, dataset, temp, idxs); 198 | break; 199 | case 32: 200 | furthest_point_sampling_kernel<32> 201 | <<>>(b, n, m, dataset, temp, idxs); 202 | break; 203 | case 16: 204 | furthest_point_sampling_kernel<16> 205 | <<>>(b, n, m, dataset, temp, idxs); 206 | break; 207 | case 8: 208 | furthest_point_sampling_kernel<8> 209 | <<>>(b, n, m, dataset, temp, idxs); 210 | break; 211 | case 4: 212 | furthest_point_sampling_kernel<4> 213 | <<>>(b, n, m, dataset, temp, idxs); 214 | break; 215 | case 2: 216 | furthest_point_sampling_kernel<2> 217 | <<>>(b, n, m, dataset, temp, idxs); 218 | break; 219 | case 1: 220 | furthest_point_sampling_kernel<1> 221 | <<>>(b, n, m, dataset, temp, idxs); 222 | break; 223 | default: 224 | furthest_point_sampling_kernel<512> 225 | <<>>(b, n, m, dataset, temp, idxs); 226 | } 227 | 228 | CUDA_CHECK_ERRORS(); 229 | } 230 | -------------------------------------------------------------------------------- /pointnet2/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
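# This file collects thin nn.Sequential building blocks (SharedMLP, Conv1d/2d/3d,
# FC, BatchNorm1d/2d/3d) used for the per-point MLPs in the PointNet++ modules,
# plus BNMomentumScheduler for decaying BatchNorm momentum over training epochs.
# A minimal, hypothetical SharedMLP example (channel sizes are illustrative):
#   mlp = SharedMLP([6, 32, 64], bn=True)      # two 1x1 Conv2d blocks with BN + ReLU
#   out = mlp(torch.rand(8, 6, 1024, 32))      # (B, C_in, npoint, nsample) -> (8, 64, 1024, 32)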
5 | 6 | ''' Modified based on Ref: https://github.com/erikwijmans/Pointnet2_PyTorch ''' 7 | import torch 8 | import torch.nn as nn 9 | from typing import List, Tuple 10 | 11 | class SharedMLP(nn.Sequential): 12 | 13 | def __init__( 14 | self, 15 | args: List[int], 16 | *, 17 | bn: bool = False, 18 | activation=nn.ReLU(inplace=True), 19 | preact: bool = False, 20 | first: bool = False, 21 | name: str = "" 22 | ): 23 | super().__init__() 24 | 25 | for i in range(len(args) - 1): 26 | self.add_module( 27 | name + 'layer{}'.format(i), 28 | Conv2d( 29 | args[i], 30 | args[i + 1], 31 | bn=(not first or not preact or (i != 0)) and bn, 32 | activation=activation 33 | if (not first or not preact or (i != 0)) else None, 34 | preact=preact 35 | ) 36 | ) 37 | 38 | 39 | class _BNBase(nn.Sequential): 40 | 41 | def __init__(self, in_size, batch_norm=None, name=""): 42 | super().__init__() 43 | self.add_module(name + "bn", batch_norm(in_size)) 44 | 45 | nn.init.constant_(self[0].weight, 1.0) 46 | nn.init.constant_(self[0].bias, 0) 47 | 48 | 49 | class BatchNorm1d(_BNBase): 50 | 51 | def __init__(self, in_size: int, *, name: str = ""): 52 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) 53 | 54 | 55 | class BatchNorm2d(_BNBase): 56 | 57 | def __init__(self, in_size: int, name: str = ""): 58 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) 59 | 60 | 61 | class BatchNorm3d(_BNBase): 62 | 63 | def __init__(self, in_size: int, name: str = ""): 64 | super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name) 65 | 66 | 67 | class _ConvBase(nn.Sequential): 68 | 69 | def __init__( 70 | self, 71 | in_size, 72 | out_size, 73 | kernel_size, 74 | stride, 75 | padding, 76 | activation, 77 | bn, 78 | init, 79 | conv=None, 80 | batch_norm=None, 81 | bias=True, 82 | preact=False, 83 | name="" 84 | ): 85 | super().__init__() 86 | 87 | bias = bias and (not bn) 88 | conv_unit = conv( 89 | in_size, 90 | out_size, 91 | kernel_size=kernel_size, 92 | stride=stride, 93 | padding=padding, 94 | bias=bias 95 | ) 96 | init(conv_unit.weight) 97 | if bias: 98 | nn.init.constant_(conv_unit.bias, 0) 99 | 100 | if bn: 101 | if not preact: 102 | bn_unit = batch_norm(out_size) 103 | else: 104 | bn_unit = batch_norm(in_size) 105 | 106 | if preact: 107 | if bn: 108 | self.add_module(name + 'bn', bn_unit) 109 | 110 | if activation is not None: 111 | self.add_module(name + 'activation', activation) 112 | 113 | self.add_module(name + 'conv', conv_unit) 114 | 115 | if not preact: 116 | if bn: 117 | self.add_module(name + 'bn', bn_unit) 118 | 119 | if activation is not None: 120 | self.add_module(name + 'activation', activation) 121 | 122 | 123 | class Conv1d(_ConvBase): 124 | 125 | def __init__( 126 | self, 127 | in_size: int, 128 | out_size: int, 129 | *, 130 | kernel_size: int = 1, 131 | stride: int = 1, 132 | padding: int = 0, 133 | activation=nn.ReLU(inplace=True), 134 | bn: bool = False, 135 | init=nn.init.kaiming_normal_, 136 | bias: bool = True, 137 | preact: bool = False, 138 | name: str = "" 139 | ): 140 | super().__init__( 141 | in_size, 142 | out_size, 143 | kernel_size, 144 | stride, 145 | padding, 146 | activation, 147 | bn, 148 | init, 149 | conv=nn.Conv1d, 150 | batch_norm=BatchNorm1d, 151 | bias=bias, 152 | preact=preact, 153 | name=name 154 | ) 155 | 156 | 157 | class Conv2d(_ConvBase): 158 | 159 | def __init__( 160 | self, 161 | in_size: int, 162 | out_size: int, 163 | *, 164 | kernel_size: Tuple[int, int] = (1, 1), 165 | stride: Tuple[int, int] = (1, 1), 166 | padding: Tuple[int, int] = (0, 
0), 167 | activation=nn.ReLU(inplace=True), 168 | bn: bool = False, 169 | init=nn.init.kaiming_normal_, 170 | bias: bool = True, 171 | preact: bool = False, 172 | name: str = "" 173 | ): 174 | super().__init__( 175 | in_size, 176 | out_size, 177 | kernel_size, 178 | stride, 179 | padding, 180 | activation, 181 | bn, 182 | init, 183 | conv=nn.Conv2d, 184 | batch_norm=BatchNorm2d, 185 | bias=bias, 186 | preact=preact, 187 | name=name 188 | ) 189 | 190 | 191 | class Conv3d(_ConvBase): 192 | 193 | def __init__( 194 | self, 195 | in_size: int, 196 | out_size: int, 197 | *, 198 | kernel_size: Tuple[int, int, int] = (1, 1, 1), 199 | stride: Tuple[int, int, int] = (1, 1, 1), 200 | padding: Tuple[int, int, int] = (0, 0, 0), 201 | activation=nn.ReLU(inplace=True), 202 | bn: bool = False, 203 | init=nn.init.kaiming_normal_, 204 | bias: bool = True, 205 | preact: bool = False, 206 | name: str = "" 207 | ): 208 | super().__init__( 209 | in_size, 210 | out_size, 211 | kernel_size, 212 | stride, 213 | padding, 214 | activation, 215 | bn, 216 | init, 217 | conv=nn.Conv3d, 218 | batch_norm=BatchNorm3d, 219 | bias=bias, 220 | preact=preact, 221 | name=name 222 | ) 223 | 224 | 225 | class FC(nn.Sequential): 226 | 227 | def __init__( 228 | self, 229 | in_size: int, 230 | out_size: int, 231 | *, 232 | activation=nn.ReLU(inplace=True), 233 | bn: bool = False, 234 | init=None, 235 | preact: bool = False, 236 | name: str = "" 237 | ): 238 | super().__init__() 239 | 240 | fc = nn.Linear(in_size, out_size, bias=not bn) 241 | if init is not None: 242 | init(fc.weight) 243 | if not bn: 244 | nn.init.constant_(fc.bias, 0) 245 | 246 | if preact: 247 | if bn: 248 | self.add_module(name + 'bn', BatchNorm1d(in_size)) 249 | 250 | if activation is not None: 251 | self.add_module(name + 'activation', activation) 252 | 253 | self.add_module(name + 'fc', fc) 254 | 255 | if not preact: 256 | if bn: 257 | self.add_module(name + 'bn', BatchNorm1d(out_size)) 258 | 259 | if activation is not None: 260 | self.add_module(name + 'activation', activation) 261 | 262 | def set_bn_momentum_default(bn_momentum): 263 | 264 | def fn(m): 265 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 266 | m.momentum = bn_momentum 267 | 268 | return fn 269 | 270 | 271 | class BNMomentumScheduler(object): 272 | 273 | def __init__( 274 | self, model, bn_lambda, last_epoch=-1, 275 | setter=set_bn_momentum_default 276 | ): 277 | if not isinstance(model, nn.Module): 278 | raise RuntimeError( 279 | "Class '{}' is not a PyTorch nn Module".format( 280 | type(model).__name__ 281 | ) 282 | ) 283 | 284 | self.model = model 285 | self.setter = setter 286 | self.lmbd = bn_lambda 287 | 288 | self.step(last_epoch + 1) 289 | self.last_epoch = last_epoch 290 | 291 | def step(self, epoch=None): 292 | if epoch is None: 293 | epoch = self.last_epoch + 1 294 | 295 | self.last_epoch = epoch 296 | self.model.apply(self.setter(self.lmbd(epoch))) 297 | 298 | 299 | -------------------------------------------------------------------------------- /pointnet2/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | #include "sampling_gpu.h" 6 | 7 | 8 | __global__ void gather_points_kernel_fast(int b, int c, int n, int m, 9 | const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) { 10 | // points: (B, C, N) 11 | // idx: (B, M) 12 | // output: 13 | // out: (B, C, M) 14 | 15 | int bs_idx = blockIdx.z; 16 | int c_idx = 
blockIdx.y; 17 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 18 | if (bs_idx >= b || c_idx >= c || pt_idx >= m) return; 19 | 20 | out += bs_idx * c * m + c_idx * m + pt_idx; 21 | idx += bs_idx * m + pt_idx; 22 | points += bs_idx * c * n + c_idx * n; 23 | out[0] = points[idx[0]]; 24 | } 25 | 26 | void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, 27 | const float *points, const int *idx, float *out, cudaStream_t stream) { 28 | // points: (B, C, N) 29 | // idx: (B, npoints) 30 | // output: 31 | // out: (B, C, npoints) 32 | 33 | cudaError_t err; 34 | dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 35 | dim3 threads(THREADS_PER_BLOCK); 36 | 37 | gather_points_kernel_fast<<>>(b, c, n, npoints, points, idx, out); 38 | 39 | err = cudaGetLastError(); 40 | if (cudaSuccess != err) { 41 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 42 | exit(-1); 43 | } 44 | } 45 | 46 | __global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out, 47 | const int *__restrict__ idx, float *__restrict__ grad_points) { 48 | // grad_out: (B, C, M) 49 | // idx: (B, M) 50 | // output: 51 | // grad_points: (B, C, N) 52 | 53 | int bs_idx = blockIdx.z; 54 | int c_idx = blockIdx.y; 55 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 56 | if (bs_idx >= b || c_idx >= c || pt_idx >= m) return; 57 | 58 | grad_out += bs_idx * c * m + c_idx * m + pt_idx; 59 | idx += bs_idx * m + pt_idx; 60 | grad_points += bs_idx * c * n + c_idx * n; 61 | 62 | atomicAdd(grad_points + idx[0], grad_out[0]); 63 | } 64 | 65 | void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, 66 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) { 67 | // grad_out: (B, C, npoints) 68 | // idx: (B, npoints) 69 | // output: 70 | // grad_points: (B, C, N) 71 | 72 | cudaError_t err; 73 | dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 74 | dim3 threads(THREADS_PER_BLOCK); 75 | 76 | gather_points_grad_kernel_fast<<>>(b, c, n, npoints, grad_out, idx, grad_points); 77 | 78 | err = cudaGetLastError(); 79 | if (cudaSuccess != err) { 80 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 81 | exit(-1); 82 | } 83 | } 84 | 85 | 86 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){ 87 | const float v1 = dists[idx1], v2 = dists[idx2]; 88 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 89 | dists[idx1] = max(v1, v2); 90 | dists_i[idx1] = v2 > v1 ? 
i2 : i1; 91 | } 92 | 93 | template 94 | __global__ void furthest_point_sampling_kernel(int b, int n, int m, 95 | const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) { 96 | // dataset: (B, N, 3) 97 | // tmp: (B, N) 98 | // output: 99 | // idx: (B, M) 100 | 101 | if (m <= 0) return; 102 | __shared__ float dists[block_size]; 103 | __shared__ int dists_i[block_size]; 104 | 105 | int batch_index = blockIdx.x; 106 | dataset += batch_index * n * 3; 107 | temp += batch_index * n; 108 | idxs += batch_index * m; 109 | 110 | int tid = threadIdx.x; 111 | const int stride = block_size; 112 | 113 | int old = 0; 114 | if (threadIdx.x == 0) 115 | idxs[0] = old; 116 | 117 | __syncthreads(); 118 | for (int j = 1; j < m; j++) { 119 | int besti = 0; 120 | float best = -1; 121 | float x1 = dataset[old * 3 + 0]; 122 | float y1 = dataset[old * 3 + 1]; 123 | float z1 = dataset[old * 3 + 2]; 124 | for (int k = tid; k < n; k += stride) { 125 | float x2, y2, z2; 126 | x2 = dataset[k * 3 + 0]; 127 | y2 = dataset[k * 3 + 1]; 128 | z2 = dataset[k * 3 + 2]; 129 | // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 130 | // if (mag <= 1e-3) 131 | // continue; 132 | 133 | float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 134 | float d2 = min(d, temp[k]); 135 | temp[k] = d2; 136 | besti = d2 > best ? k : besti; 137 | best = d2 > best ? d2 : best; 138 | } 139 | dists[tid] = best; 140 | dists_i[tid] = besti; 141 | __syncthreads(); 142 | 143 | if (block_size >= 1024) { 144 | if (tid < 512) { 145 | __update(dists, dists_i, tid, tid + 512); 146 | } 147 | __syncthreads(); 148 | } 149 | 150 | if (block_size >= 512) { 151 | if (tid < 256) { 152 | __update(dists, dists_i, tid, tid + 256); 153 | } 154 | __syncthreads(); 155 | } 156 | if (block_size >= 256) { 157 | if (tid < 128) { 158 | __update(dists, dists_i, tid, tid + 128); 159 | } 160 | __syncthreads(); 161 | } 162 | if (block_size >= 128) { 163 | if (tid < 64) { 164 | __update(dists, dists_i, tid, tid + 64); 165 | } 166 | __syncthreads(); 167 | } 168 | if (block_size >= 64) { 169 | if (tid < 32) { 170 | __update(dists, dists_i, tid, tid + 32); 171 | } 172 | __syncthreads(); 173 | } 174 | if (block_size >= 32) { 175 | if (tid < 16) { 176 | __update(dists, dists_i, tid, tid + 16); 177 | } 178 | __syncthreads(); 179 | } 180 | if (block_size >= 16) { 181 | if (tid < 8) { 182 | __update(dists, dists_i, tid, tid + 8); 183 | } 184 | __syncthreads(); 185 | } 186 | if (block_size >= 8) { 187 | if (tid < 4) { 188 | __update(dists, dists_i, tid, tid + 4); 189 | } 190 | __syncthreads(); 191 | } 192 | if (block_size >= 4) { 193 | if (tid < 2) { 194 | __update(dists, dists_i, tid, tid + 2); 195 | } 196 | __syncthreads(); 197 | } 198 | if (block_size >= 2) { 199 | if (tid < 1) { 200 | __update(dists, dists_i, tid, tid + 1); 201 | } 202 | __syncthreads(); 203 | } 204 | 205 | old = dists_i[0]; 206 | if (tid == 0) 207 | idxs[j] = old; 208 | } 209 | } 210 | 211 | void furthest_point_sampling_kernel_launcher(int b, int n, int m, 212 | const float *dataset, float *temp, int *idxs, cudaStream_t stream) { 213 | // dataset: (B, N, 3) 214 | // tmp: (B, N) 215 | // output: 216 | // idx: (B, M) 217 | 218 | cudaError_t err; 219 | unsigned int n_threads = opt_n_threads(n); 220 | 221 | switch (n_threads) { 222 | case 1024: 223 | furthest_point_sampling_kernel<1024><<>>(b, n, m, dataset, temp, idxs); break; 224 | case 512: 225 | furthest_point_sampling_kernel<512><<>>(b, n, m, dataset, temp, idxs); break; 226 | case 256: 227 | 
furthest_point_sampling_kernel<256><<>>(b, n, m, dataset, temp, idxs); break; 228 | case 128: 229 | furthest_point_sampling_kernel<128><<>>(b, n, m, dataset, temp, idxs); break; 230 | case 64: 231 | furthest_point_sampling_kernel<64><<>>(b, n, m, dataset, temp, idxs); break; 232 | case 32: 233 | furthest_point_sampling_kernel<32><<>>(b, n, m, dataset, temp, idxs); break; 234 | case 16: 235 | furthest_point_sampling_kernel<16><<>>(b, n, m, dataset, temp, idxs); break; 236 | case 8: 237 | furthest_point_sampling_kernel<8><<>>(b, n, m, dataset, temp, idxs); break; 238 | case 4: 239 | furthest_point_sampling_kernel<4><<>>(b, n, m, dataset, temp, idxs); break; 240 | case 2: 241 | furthest_point_sampling_kernel<2><<>>(b, n, m, dataset, temp, idxs); break; 242 | case 1: 243 | furthest_point_sampling_kernel<1><<>>(b, n, m, dataset, temp, idxs); break; 244 | default: 245 | furthest_point_sampling_kernel<512><<>>(b, n, m, dataset, temp, idxs); 246 | } 247 | 248 | err = cudaGetLastError(); 249 | if (cudaSuccess != err) { 250 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 251 | exit(-1); 252 | } 253 | } 254 | -------------------------------------------------------------------------------- /scripts/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import importlib 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | 9 | # for PointNet2.PyTorch module 10 | import sys 11 | sys.path.append(os.path.join(os.getcwd())) # HACK add the root folder 12 | sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../pointnet2/')) 13 | from lib.config import CONF 14 | from lib.dataset import ScannetDatasetWholeScene, collate_wholescene 15 | from lib.pc_util import point_cloud_label_to_surface_voxel_label_fast 16 | 17 | def get_scene_list(path): 18 | scene_list = [] 19 | with open(path) as f: 20 | for scene_id in f.readlines(): 21 | scene_list.append(scene_id.strip()) 22 | 23 | scene_list = sorted(scene_list, key=lambda x: int(x.split("_")[0][5:])) 24 | 25 | return scene_list 26 | 27 | def forward(args, model, coords, feats): 28 | pred = [] 29 | coord_chunk, feat_chunk = torch.split(coords.squeeze(0), args.batch_size, 0), torch.split(feats.squeeze(0), args.batch_size, 0) 30 | assert len(coord_chunk) == len(feat_chunk) 31 | for coord, feat in zip(coord_chunk, feat_chunk): 32 | output = model(torch.cat([coord, feat], dim=2)) 33 | pred.append(output) 34 | 35 | pred = torch.cat(pred, dim=0).unsqueeze(0) # (1, CK, N, C) 36 | outputs = pred.max(3)[1] 37 | 38 | return outputs 39 | 40 | def filter_points(coords, preds, targets, weights): 41 | assert coords.shape[0] == preds.shape[0] == targets.shape[0] == weights.shape[0] 42 | coord_hash = [hash(str(coords[point_idx][0]) + str(coords[point_idx][1]) + str(coords[point_idx][2])) for point_idx in range(coords.shape[0])] 43 | _, coord_ids = np.unique(np.array(coord_hash), return_index=True) 44 | coord_filtered, pred_filtered, target_filtered, weight_filtered = coords[coord_ids], preds[coord_ids], targets[coord_ids], weights[coord_ids] 45 | 46 | return coord_filtered, pred_filtered, target_filtered, weight_filtered 47 | 48 | def compute_acc(coords, preds, targets, weights): 49 | coords, preds, targets, weights = filter_points(coords, preds, targets, weights) 50 | seen_classes = np.unique(targets) 51 | mask = np.zeros(CONF.NUM_CLASSES) 52 | mask[seen_classes] = 1 53 | 54 | total_correct = 0 55 | 
total_seen = 0 56 | total_seen_class = [0 for _ in range(CONF.NUM_CLASSES)] 57 | total_correct_class = [0 for _ in range(CONF.NUM_CLASSES)] 58 | 59 | total_correct_vox = 0 60 | total_seen_vox = 0 61 | total_seen_class_vox = [0 for _ in range(CONF.NUM_CLASSES)] 62 | total_correct_class_vox = [0 for _ in range(CONF.NUM_CLASSES)] 63 | 64 | labelweights = np.zeros(CONF.NUM_CLASSES) 65 | labelweights_vox = np.zeros(CONF.NUM_CLASSES) 66 | 67 | correct = np.sum(preds == targets) # evaluate only on 20 categories but not unknown 68 | total_correct += correct 69 | total_seen += targets.shape[0] 70 | tmp,_ = np.histogram(targets,range(CONF.NUM_CLASSES+1)) 71 | labelweights += tmp 72 | for l in seen_classes: 73 | total_seen_class[l] += np.sum(targets==l) 74 | total_correct_class[l] += np.sum((preds==l) & (targets==l)) 75 | 76 | _, uvlabel, _ = point_cloud_label_to_surface_voxel_label_fast(coords, np.concatenate((np.expand_dims(targets,1),np.expand_dims(preds,1)),axis=1), res=0.02) 77 | total_correct_vox += np.sum(uvlabel[:,0]==uvlabel[:,1]) 78 | total_seen_vox += uvlabel[:,0].shape[0] 79 | tmp,_ = np.histogram(uvlabel[:,0],range(CONF.NUM_CLASSES+1)) 80 | labelweights_vox += tmp 81 | for l in seen_classes: 82 | total_seen_class_vox[l] += np.sum(uvlabel[:,0]==l) 83 | total_correct_class_vox[l] += np.sum((uvlabel[:,0]==l) & (uvlabel[:,1]==l)) 84 | 85 | pointacc = total_correct / float(total_seen) 86 | voxacc = total_correct_vox / float(total_seen_vox) 87 | 88 | labelweights = labelweights.astype(np.float32)/np.sum(labelweights.astype(np.float32)) 89 | labelweights_vox = labelweights_vox.astype(np.float32)/np.sum(labelweights_vox.astype(np.float32)) 90 | caliweights = labelweights_vox 91 | voxcaliacc = np.average(np.array(total_correct_class_vox)/(np.array(total_seen_class_vox,dtype=np.float)+1e-8),weights=caliweights) 92 | 93 | pointacc_per_class = np.zeros(CONF.NUM_CLASSES) 94 | voxacc_per_class = np.zeros(CONF.NUM_CLASSES) 95 | for l in seen_classes: 96 | pointacc_per_class[l] = total_correct_class[l]/(total_seen_class[l] + 1e-8) 97 | voxacc_per_class[l] = total_correct_class_vox[l]/(total_seen_class_vox[l] + 1e-8) 98 | 99 | return pointacc, pointacc_per_class, voxacc, voxacc_per_class, voxcaliacc, mask 100 | 101 | def compute_miou(coords, preds, targets, weights): 102 | coords, preds, targets, weights = filter_points(coords, preds, targets, weights) 103 | seen_classes = np.unique(targets) 104 | mask = np.zeros(CONF.NUM_CLASSES) 105 | mask[seen_classes] = 1 106 | 107 | pointmiou = np.zeros(CONF.NUM_CLASSES) 108 | voxmiou = np.zeros(CONF.NUM_CLASSES) 109 | 110 | uvidx, uvlabel, _ = point_cloud_label_to_surface_voxel_label_fast(coords, np.concatenate((np.expand_dims(targets,1),np.expand_dims(preds,1)),axis=1), res=0.02) 111 | for l in seen_classes: 112 | target_label = np.arange(targets.shape[0])[targets==l] 113 | pred_label = np.arange(preds.shape[0])[preds==l] 114 | num_intersection_label = np.intersect1d(pred_label, target_label).shape[0] 115 | num_union_label = np.union1d(pred_label, target_label).shape[0] 116 | pointmiou[l] = num_intersection_label / (num_union_label + 1e-8) 117 | 118 | target_label_vox = uvidx[(uvlabel[:, 0] == l)] 119 | pred_label_vox = uvidx[(uvlabel[:, 1] == l)] 120 | num_intersection_label_vox = np.intersect1d(pred_label_vox, target_label_vox).shape[0] 121 | num_union_label_vox = np.union1d(pred_label_vox, target_label_vox).shape[0] 122 | voxmiou[l] = num_intersection_label_vox / (num_union_label_vox + 1e-8) 123 | 124 | return pointmiou, voxmiou, mask 125 | 126 | def 
eval_one_batch(args, model, data): 127 | # unpack 128 | coords, feats, targets, weights, _ = data 129 | coords, feats, targets, weights = coords.cuda(), feats.cuda(), targets.cuda(), weights.cuda() 130 | 131 | # feed 132 | preds = forward(args, model, coords, feats) 133 | 134 | # eval 135 | coords = coords.squeeze(0).view(-1, 3).cpu().numpy() # (CK*N, C) 136 | preds = preds.squeeze(0).view(-1).cpu().numpy() # (CK*N, C) 137 | targets = targets.squeeze(0).view(-1).cpu().numpy() # (CK*N, C) 138 | weights = weights.squeeze(0).view(-1).cpu().numpy() # (CK*N, C) 139 | pointacc, pointacc_per_class, voxacc, voxacc_per_class, voxcaliacc, acc_mask = compute_acc(coords, preds, targets, weights) 140 | pointmiou, voxmiou, miou_mask = compute_miou(coords, preds, targets, weights) 141 | assert acc_mask.all() == miou_mask.all() 142 | 143 | return pointacc, pointacc_per_class, voxacc, voxacc_per_class, voxcaliacc, pointmiou, voxmiou, acc_mask 144 | 145 | 146 | def eval_wholescene(args, model, dataloader): 147 | # init 148 | pointacc_list = [] 149 | pointacc_per_class_array = np.zeros((len(dataloader), CONF.NUM_CLASSES)) 150 | voxacc_list = [] 151 | voxacc_per_class_array = np.zeros((len(dataloader), CONF.NUM_CLASSES)) 152 | voxcaliacc_list = [] 153 | pointmiou_per_class_array = np.zeros((len(dataloader), CONF.NUM_CLASSES)) 154 | voxmiou_per_class_array = np.zeros((len(dataloader), CONF.NUM_CLASSES)) 155 | masks = np.zeros((len(dataloader), CONF.NUM_CLASSES)) 156 | 157 | # iter 158 | for load_idx, data in enumerate(tqdm(dataloader)): 159 | # feed 160 | pointacc, pointacc_per_class, voxacc, voxacc_per_class, voxcaliacc, pointmiou, voxmiou, mask = eval_one_batch(args, model, data) 161 | 162 | # dump 163 | pointacc_list.append(pointacc) 164 | pointacc_per_class_array[load_idx] = pointacc_per_class 165 | voxacc_list.append(voxacc) 166 | voxacc_per_class_array[load_idx] = voxacc_per_class 167 | voxcaliacc_list.append(voxcaliacc) 168 | pointmiou_per_class_array[load_idx] = pointmiou 169 | voxmiou_per_class_array[load_idx] = voxmiou 170 | masks[load_idx] = mask 171 | 172 | return pointacc_list, pointacc_per_class_array, voxacc_list, voxacc_per_class_array, voxcaliacc_list, pointmiou_per_class_array, voxmiou_per_class_array, masks 173 | 174 | def evaluate(args): 175 | # prepare data 176 | print("preparing data...") 177 | scene_list = get_scene_list("data/scannetv2_val.txt") 178 | dataset = ScannetDatasetWholeScene(scene_list, use_color=args.use_color, use_normal=args.use_normal, use_multiview=args.use_multiview) 179 | dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_wholescene) 180 | 181 | # load model 182 | print("loading model...") 183 | model_path = os.path.join(CONF.OUTPUT_ROOT, args.folder, "model.pth") 184 | Pointnet = importlib.import_module("pointnet2_semseg") 185 | input_channels = int(args.use_color) * 3 + int(args.use_normal) * 3 + int(args.use_multiview) * 128 186 | model = Pointnet.get_model(num_classes=CONF.NUM_CLASSES, is_msg=args.use_msg, input_channels=input_channels, use_xyz=not args.no_xyz, bn=not args.no_bn).cuda() 187 | model.load_state_dict(torch.load(model_path)) 188 | model.eval() 189 | 190 | # eval 191 | print("evaluating...") 192 | pointacc_list, pointacc_per_class_array, voxacc_list, voxacc_per_class_array, voxcaliacc_list, pointmiou_per_class_array, voxmiou_per_class_array, masks = eval_wholescene(args, model, dataloader) 193 | 194 | avg_pointacc = np.mean(pointacc_list) 195 | avg_pointacc_per_class = np.sum(pointacc_per_class_array * masks, axis=0)/np.sum(masks, axis=0) 
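    # Per-class metrics are averaged only over the scenes in which a class actually
    # occurs: masks[i, l] is 1 when class l is present in scene i, so
    # np.sum(x * masks, axis=0) / np.sum(masks, axis=0) ignores absent classes
    # instead of pulling the class average toward zero.
    # Toy illustration (numbers are hypothetical):
    #   per_class = np.array([[0.8], [0.0]]); m = np.array([[1.0], [0.0]])
    #   np.sum(per_class * m, axis=0) / np.sum(m, axis=0)   # -> array([0.8])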
196 | 197 | avg_voxacc = np.mean(voxacc_list) 198 | avg_voxacc_per_class = np.sum(voxacc_per_class_array * masks, axis=0)/np.sum(masks, axis=0) 199 | 200 | avg_voxcaliacc = np.mean(voxcaliacc_list) 201 | 202 | avg_pointmiou_per_class = np.sum(pointmiou_per_class_array * masks, axis=0)/np.sum(masks, axis=0) 203 | avg_pointmiou = np.mean(avg_pointmiou_per_class) 204 | 205 | avg_voxmiou_per_class = np.sum(voxmiou_per_class_array * masks, axis=0)/np.sum(masks, axis=0) 206 | avg_voxmiou = np.mean(avg_voxmiou_per_class) 207 | 208 | # report 209 | print() 210 | print("Point accuracy: {}".format(avg_pointacc)) 211 | print("Point accuracy per class: {}".format(np.mean(avg_pointacc_per_class))) 212 | print("Voxel accuracy: {}".format(avg_voxacc)) 213 | print("Voxel accuracy per class: {}".format(np.mean(avg_voxacc_per_class))) 214 | print("Calibrated voxel accuracy: {}".format(avg_voxcaliacc)) 215 | print("Point miou: {}".format(avg_pointmiou)) 216 | print("Voxel miou: {}".format(avg_voxmiou)) 217 | print() 218 | 219 | print("Point acc/voxel acc/point miou/voxel miou per class:") 220 | for l in range(CONF.NUM_CLASSES): 221 | print("Class {}: {}/{}/{}/{}".format(CONF.NYUCLASSES[l], avg_pointacc_per_class[l], avg_voxacc_per_class[l], avg_pointmiou_per_class[l], avg_voxmiou_per_class[l])) 222 | 223 | 224 | if __name__ == "__main__": 225 | parser = argparse.ArgumentParser() 226 | parser.add_argument('--folder', type=str, help='output folder containing the best model from training', required=True) 227 | parser.add_argument('--batch_size', type=int, help='size of the batch/chunk', default=32) 228 | parser.add_argument('--gpu', type=str, help='gpu', default='0') 229 | parser.add_argument('--no_bn', action="store_true", help="do not apply batch normalization in pointnet++") 230 | parser.add_argument('--no_xyz', action="store_true", help="do not apply coordinates as features in pointnet++") 231 | parser.add_argument("--use_msg", action="store_true", help="apply multiscale grouping or not") 232 | parser.add_argument("--use_color", action="store_true", help="use color values or not") 233 | parser.add_argument("--use_normal", action="store_true", help="use normals or not") 234 | parser.add_argument("--use_multiview", action="store_true", help="use multiview image features or not") 235 | args = parser.parse_args() 236 | 237 | # setting 238 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 239 | os.environ["CUDA_LAUNCH_BLOCKING"] = "1" 240 | 241 | evaluate(args) -------------------------------------------------------------------------------- /pointnet2/pointnet2_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
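# This file wraps the compiled pointnet2._ext CUDA kernels in autograd Functions
# (furthest point sampling, gather, three_nn / three_interpolate, ball query,
# grouping) and provides the QueryAndGroup / GroupAll grouping layers.
# A hypothetical sampling + gathering round trip (shapes illustrative only):
#   xyz = torch.rand(2, 4096, 3).cuda()                                  # (B, N, 3)
#   idx = furthest_point_sample(xyz, 1024)                               # (B, 1024)
#   new_xyz = gather_operation(xyz.transpose(1, 2).contiguous(), idx)    # (B, 3, 1024)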
5 | 6 | ''' Modified based on: https://github.com/erikwijmans/Pointnet2_PyTorch ''' 7 | from __future__ import ( 8 | division, 9 | absolute_import, 10 | with_statement, 11 | print_function, 12 | unicode_literals, 13 | ) 14 | import torch 15 | from torch.autograd import Function 16 | import torch.nn as nn 17 | import pytorch_utils as pt_utils 18 | import sys 19 | 20 | try: 21 | import builtins 22 | except: 23 | import __builtin__ as builtins 24 | 25 | try: 26 | import pointnet2._ext as _ext 27 | except ImportError: 28 | if not getattr(builtins, "__POINTNET2_SETUP__", False): 29 | raise ImportError( 30 | "Could not import _ext module.\n" 31 | "Please see the setup instructions in the README: " 32 | "https://github.com/erikwijmans/Pointnet2_PyTorch/blob/master/README.rst" 33 | ) 34 | 35 | if False: 36 | # Workaround for type hints without depending on the `typing` module 37 | from typing import * 38 | 39 | 40 | class RandomDropout(nn.Module): 41 | def __init__(self, p=0.5, inplace=False): 42 | super(RandomDropout, self).__init__() 43 | self.p = p 44 | self.inplace = inplace 45 | 46 | def forward(self, X): 47 | theta = torch.Tensor(1).uniform_(0, self.p)[0] 48 | return pt_utils.feature_dropout_no_scaling(X, theta, self.train, self.inplace) 49 | 50 | 51 | class FurthestPointSampling(Function): 52 | @staticmethod 53 | def forward(ctx, xyz, npoint): 54 | # type: (Any, torch.Tensor, int) -> torch.Tensor 55 | r""" 56 | Uses iterative furthest point sampling to select a set of npoint features that have the largest 57 | minimum distance 58 | 59 | Parameters 60 | ---------- 61 | xyz : torch.Tensor 62 | (B, N, 3) tensor where N > npoint 63 | npoint : int32 64 | number of features in the sampled set 65 | 66 | Returns 67 | ------- 68 | torch.Tensor 69 | (B, npoint) tensor containing the set 70 | """ 71 | fps_inds = _ext.furthest_point_sampling(xyz, npoint) 72 | ctx.mark_non_differentiable(fps_inds) 73 | return fps_inds 74 | 75 | @staticmethod 76 | def backward(xyz, a=None): 77 | return None, None 78 | 79 | 80 | furthest_point_sample = FurthestPointSampling.apply 81 | 82 | 83 | class GatherOperation(Function): 84 | @staticmethod 85 | def forward(ctx, features, idx): 86 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 87 | r""" 88 | 89 | Parameters 90 | ---------- 91 | features : torch.Tensor 92 | (B, C, N) tensor 93 | 94 | idx : torch.Tensor 95 | (B, npoint) tensor of the features to gather 96 | 97 | Returns 98 | ------- 99 | torch.Tensor 100 | (B, C, npoint) tensor 101 | """ 102 | 103 | _, C, N = features.size() 104 | 105 | ctx.for_backwards = (idx, C, N) 106 | 107 | return _ext.gather_points(features, idx) 108 | 109 | @staticmethod 110 | def backward(ctx, grad_out): 111 | idx, C, N = ctx.for_backwards 112 | 113 | grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N) 114 | return grad_features, None 115 | 116 | 117 | gather_operation = GatherOperation.apply 118 | 119 | 120 | class ThreeNN(Function): 121 | @staticmethod 122 | def forward(ctx, unknown, known): 123 | # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 124 | r""" 125 | Find the three nearest neighbors of unknown in known 126 | Parameters 127 | ---------- 128 | unknown : torch.Tensor 129 | (B, n, 3) tensor of known features 130 | known : torch.Tensor 131 | (B, m, 3) tensor of unknown features 132 | 133 | Returns 134 | ------- 135 | dist : torch.Tensor 136 | (B, n, 3) l2 distance to the three nearest neighbors 137 | idx : torch.Tensor 138 | (B, n, 3) index of 3 nearest neighbors 139 | """ 
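        # _ext.three_nn returns squared distances, hence the sqrt below: for every
        # query point in `unknown` it finds its 3 nearest neighbours among `known`.
        # In feature propagation the result is typically turned into inverse-distance
        # weights before interpolation, e.g. (a sketch, not code from this file):
        #   dist, idx = three_nn(unknown, known)
        #   w = 1.0 / (dist + 1e-8); w = w / w.sum(dim=2, keepdim=True)   # (B, n, 3)
        #   new_feats = three_interpolate(known_feats, idx, w)            # (B, C, n)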
140 | dist2, idx = _ext.three_nn(unknown, known) 141 | 142 | return torch.sqrt(dist2), idx 143 | 144 | @staticmethod 145 | def backward(ctx, a=None, b=None): 146 | return None, None 147 | 148 | 149 | three_nn = ThreeNN.apply 150 | 151 | 152 | class ThreeInterpolate(Function): 153 | @staticmethod 154 | def forward(ctx, features, idx, weight): 155 | # type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor 156 | r""" 157 | Performs weight linear interpolation on 3 features 158 | Parameters 159 | ---------- 160 | features : torch.Tensor 161 | (B, c, m) Features descriptors to be interpolated from 162 | idx : torch.Tensor 163 | (B, n, 3) three nearest neighbors of the target features in features 164 | weight : torch.Tensor 165 | (B, n, 3) weights 166 | 167 | Returns 168 | ------- 169 | torch.Tensor 170 | (B, c, n) tensor of the interpolated features 171 | """ 172 | B, c, m = features.size() 173 | n = idx.size(1) 174 | 175 | ctx.three_interpolate_for_backward = (idx, weight, m) 176 | 177 | return _ext.three_interpolate(features, idx, weight) 178 | 179 | @staticmethod 180 | def backward(ctx, grad_out): 181 | # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor] 182 | r""" 183 | Parameters 184 | ---------- 185 | grad_out : torch.Tensor 186 | (B, c, n) tensor with gradients of ouputs 187 | 188 | Returns 189 | ------- 190 | grad_features : torch.Tensor 191 | (B, c, m) tensor with gradients of features 192 | 193 | None 194 | 195 | None 196 | """ 197 | idx, weight, m = ctx.three_interpolate_for_backward 198 | 199 | grad_features = _ext.three_interpolate_grad( 200 | grad_out.contiguous(), idx, weight, m 201 | ) 202 | 203 | return grad_features, None, None 204 | 205 | 206 | three_interpolate = ThreeInterpolate.apply 207 | 208 | 209 | class GroupingOperation(Function): 210 | @staticmethod 211 | def forward(ctx, features, idx): 212 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 213 | r""" 214 | 215 | Parameters 216 | ---------- 217 | features : torch.Tensor 218 | (B, C, N) tensor of features to group 219 | idx : torch.Tensor 220 | (B, npoint, nsample) tensor containing the indicies of features to group with 221 | 222 | Returns 223 | ------- 224 | torch.Tensor 225 | (B, C, npoint, nsample) tensor 226 | """ 227 | B, nfeatures, nsample = idx.size() 228 | _, C, N = features.size() 229 | 230 | ctx.for_backwards = (idx, N) 231 | 232 | return _ext.group_points(features, idx) 233 | 234 | @staticmethod 235 | def backward(ctx, grad_out): 236 | # type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor] 237 | r""" 238 | 239 | Parameters 240 | ---------- 241 | grad_out : torch.Tensor 242 | (B, C, npoint, nsample) tensor of the gradients of the output from forward 243 | 244 | Returns 245 | ------- 246 | torch.Tensor 247 | (B, C, N) gradient of the features 248 | None 249 | """ 250 | idx, N = ctx.for_backwards 251 | 252 | grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N) 253 | 254 | return grad_features, None 255 | 256 | 257 | grouping_operation = GroupingOperation.apply 258 | 259 | 260 | class BallQuery(Function): 261 | @staticmethod 262 | def forward(ctx, radius, nsample, xyz, new_xyz): 263 | # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor 264 | r""" 265 | 266 | Parameters 267 | ---------- 268 | radius : float 269 | radius of the balls 270 | nsample : int 271 | maximum number of features in the balls 272 | xyz : torch.Tensor 273 | (B, N, 3) xyz coordinates of the features 274 | new_xyz : torch.Tensor 275 | (B, npoint, 3) 
centers of the ball query 276 | 277 | Returns 278 | ------- 279 | torch.Tensor 280 | (B, npoint, nsample) tensor with the indicies of the features that form the query balls 281 | """ 282 | inds = _ext.ball_query(new_xyz, xyz, radius, nsample) 283 | ctx.mark_non_differentiable(inds) 284 | return inds 285 | 286 | @staticmethod 287 | def backward(ctx, a=None): 288 | return None, None, None, None 289 | 290 | 291 | ball_query = BallQuery.apply 292 | 293 | 294 | class QueryAndGroup(nn.Module): 295 | r""" 296 | Groups with a ball query of radius 297 | 298 | Parameters 299 | --------- 300 | radius : float32 301 | Radius of ball 302 | nsample : int32 303 | Maximum number of features to gather in the ball 304 | """ 305 | 306 | def __init__(self, radius, nsample, use_xyz=True, ret_grouped_xyz=False, normalize_xyz=False, sample_uniformly=False, ret_unique_cnt=False): 307 | # type: (QueryAndGroup, float, int, bool) -> None 308 | super(QueryAndGroup, self).__init__() 309 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz 310 | self.ret_grouped_xyz = ret_grouped_xyz 311 | self.normalize_xyz = normalize_xyz 312 | self.sample_uniformly = sample_uniformly 313 | self.ret_unique_cnt = ret_unique_cnt 314 | if self.ret_unique_cnt: 315 | assert(self.sample_uniformly) 316 | 317 | def forward(self, xyz, new_xyz, features=None): 318 | # type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor] 319 | r""" 320 | Parameters 321 | ---------- 322 | xyz : torch.Tensor 323 | xyz coordinates of the features (B, N, 3) 324 | new_xyz : torch.Tensor 325 | centriods (B, npoint, 3) 326 | features : torch.Tensor 327 | Descriptors of the features (B, C, N) 328 | 329 | Returns 330 | ------- 331 | new_features : torch.Tensor 332 | (B, 3 + C, npoint, nsample) tensor 333 | """ 334 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz) 335 | 336 | if self.sample_uniformly: 337 | unique_cnt = torch.zeros((idx.shape[0], idx.shape[1])) 338 | for i_batch in range(idx.shape[0]): 339 | for i_region in range(idx.shape[1]): 340 | unique_ind = torch.unique(idx[i_batch, i_region, :]) 341 | num_unique = unique_ind.shape[0] 342 | unique_cnt[i_batch, i_region] = num_unique 343 | sample_ind = torch.randint(0, num_unique, (self.nsample - num_unique,), dtype=torch.long) 344 | all_ind = torch.cat((unique_ind, unique_ind[sample_ind])) 345 | idx[i_batch, i_region, :] = all_ind 346 | 347 | 348 | xyz_trans = xyz.transpose(1, 2).contiguous() 349 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 350 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 351 | if self.normalize_xyz: 352 | grouped_xyz /= self.radius 353 | 354 | if features is not None: 355 | grouped_features = grouping_operation(features, idx) 356 | if self.use_xyz: 357 | new_features = torch.cat( 358 | [grouped_xyz, grouped_features], dim=1 359 | ) # (B, C + 3, npoint, nsample) 360 | else: 361 | new_features = grouped_features 362 | else: 363 | assert ( 364 | self.use_xyz 365 | ), "Cannot have not features and not use xyz as a feature!" 
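            # Without per-point features the group descriptor is just the centred
            # (and, if normalize_xyz, radius-normalised) neighbour coordinates, so
            # new_features has shape (B, 3, npoint, nsample); with features and
            # use_xyz=True the branch above yields (B, 3 + C, npoint, nsample).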
366 | new_features = grouped_xyz 367 | 368 | ret = [new_features] 369 | if self.ret_grouped_xyz: 370 | ret.append(grouped_xyz) 371 | if self.ret_unique_cnt: 372 | ret.append(unique_cnt) 373 | if len(ret) == 1: 374 | return ret[0] 375 | else: 376 | return tuple(ret) 377 | 378 | 379 | class GroupAll(nn.Module): 380 | r""" 381 | Groups all features 382 | 383 | Parameters 384 | --------- 385 | """ 386 | 387 | def __init__(self, use_xyz=True, ret_grouped_xyz=False): 388 | # type: (GroupAll, bool) -> None 389 | super(GroupAll, self).__init__() 390 | self.use_xyz = use_xyz 391 | 392 | def forward(self, xyz, new_xyz, features=None): 393 | # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor] 394 | r""" 395 | Parameters 396 | ---------- 397 | xyz : torch.Tensor 398 | xyz coordinates of the features (B, N, 3) 399 | new_xyz : torch.Tensor 400 | Ignored 401 | features : torch.Tensor 402 | Descriptors of the features (B, C, N) 403 | 404 | Returns 405 | ------- 406 | new_features : torch.Tensor 407 | (B, C + 3, 1, N) tensor 408 | """ 409 | 410 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) 411 | if features is not None: 412 | grouped_features = features.unsqueeze(2) 413 | if self.use_xyz: 414 | new_features = torch.cat( 415 | [grouped_xyz, grouped_features], dim=1 416 | ) # (B, 3 + C, 1, N) 417 | else: 418 | new_features = grouped_features 419 | else: 420 | new_features = grouped_xyz 421 | 422 | if self.ret_grouped_xyz: 423 | return new_features, grouped_xyz 424 | else: 425 | return new_features 426 | -------------------------------------------------------------------------------- /lib/projection.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.autograd import Function 4 | 5 | 6 | class ProjectionHelper(): 7 | def __init__(self, intrinsic, depth_min, depth_max, image_dims, accuracy, cuda=True): 8 | self.intrinsic = intrinsic 9 | self.depth_min = depth_min 10 | self.depth_max = depth_max 11 | self.image_dims = image_dims 12 | self.accuracy = accuracy 13 | self.cuda = cuda 14 | 15 | # precompute 16 | self._compute_corner_points() 17 | 18 | def depth_to_skeleton(self, ux, uy, depth): 19 | # 2D to 3D coordinates with depth (used in compute_frustum_bounds) 20 | x = (ux - self.intrinsic[0][2]) / self.intrinsic[0][0] 21 | y = (uy - self.intrinsic[1][2]) / self.intrinsic[1][1] 22 | return torch.Tensor([depth*x, depth*y, depth]) 23 | 24 | def skeleton_to_depth(self, p): 25 | x = (p[0] * self.intrinsic[0][0]) / p[2] + self.intrinsic[0][2] 26 | y = (p[1] * self.intrinsic[1][1]) / p[2] + self.intrinsic[1][2] 27 | return torch.Tensor([x, y, p[2]]) 28 | 29 | def _compute_corner_points(self): 30 | if self.cuda: 31 | corner_points = torch.ones(8, 4).cuda() 32 | else: 33 | corner_points = torch.ones(8, 4) 34 | 35 | # image to camera 36 | # depth min 37 | corner_points[0][:3] = self.depth_to_skeleton(0, 0, self.depth_min) 38 | corner_points[1][:3] = self.depth_to_skeleton(self.image_dims[0] - 1, 0, self.depth_min) 39 | corner_points[2][:3] = self.depth_to_skeleton(self.image_dims[0] - 1, self.image_dims[1] - 1, self.depth_min) 40 | corner_points[3][:3] = self.depth_to_skeleton(0, self.image_dims[1] - 1, self.depth_min) 41 | # depth max 42 | corner_points[4][:3] = self.depth_to_skeleton(0, 0, self.depth_max) 43 | corner_points[5][:3] = self.depth_to_skeleton(self.image_dims[0] - 1, 0, self.depth_max) 44 | corner_points[6][:3] = self.depth_to_skeleton(self.image_dims[0] - 1, self.image_dims[1] - 1, self.depth_max) 45 | 
corner_points[7][:3] = self.depth_to_skeleton(0, self.image_dims[1] - 1, self.depth_max) 46 | 47 | self.corner_points = corner_points 48 | 49 | def compute_frustum_corners(self, camera_to_world): 50 | """ 51 | Computes the coordinates of the viewing frustum corresponding to one image and given camera parameters 52 | 53 | :param camera_to_world: torch tensor of shape (4, 4) 54 | :return: corner_coords: torch tensor of shape (8, 4) 55 | """ 56 | # input: camera pose (torch.Size([4, 4])) 57 | # output: coordinates of the corner points of the viewing frustum of the camera 58 | 59 | # corner_points = camera_to_world.new(8, 4, 1).fill_(1) 60 | 61 | # # image to camera 62 | # # depth min 63 | # corner_points[0][:3] = self.depth_to_skeleton(0, 0, self.depth_min).unsqueeze(1) 64 | # corner_points[1][:3] = self.depth_to_skeleton(self.image_dims[0] - 1, 0, self.depth_min).unsqueeze(1) 65 | # corner_points[2][:3] = self.depth_to_skeleton(self.image_dims[0] - 1, self.image_dims[1] - 1, self.depth_min).unsqueeze(1) 66 | # corner_points[3][:3] = self.depth_to_skeleton(0, self.image_dims[1] - 1, self.depth_min).unsqueeze(1) 67 | # # depth max 68 | # corner_points[4][:3] = self.depth_to_skeleton(0, 0, self.depth_max).unsqueeze(1) 69 | # corner_points[5][:3] = self.depth_to_skeleton(self.image_dims[0] - 1, 0, self.depth_max).unsqueeze(1) 70 | # corner_points[6][:3] = self.depth_to_skeleton(self.image_dims[0] - 1, self.image_dims[1] - 1, self.depth_max).unsqueeze(1) 71 | # corner_points[7][:3] = self.depth_to_skeleton(0, self.image_dims[1] - 1, self.depth_max).unsqueeze(1) 72 | 73 | 74 | # camera to world 75 | corner_coords = torch.bmm(camera_to_world.repeat(8, 1, 1), self.corner_points.unsqueeze(2)) 76 | 77 | return corner_coords 78 | 79 | def compute_frustum_normals(self, corner_coords): 80 | """ 81 | Computes the normal vectors (pointing inwards) to the 6 planes that bound the viewing frustum 82 | 83 | :param corner_coords: torch tensor of shape (8, 4), coordinates of the corner points of the viewing frustum 84 | :return: normals: torch tensor of shape (6, 3) 85 | """ 86 | 87 | normals = corner_coords.new(6, 3) 88 | 89 | # compute plane normals 90 | # front plane 91 | plane_vec1 = corner_coords[3][:3] - corner_coords[0][:3] 92 | plane_vec2 = corner_coords[1][:3] - corner_coords[0][:3] 93 | normals[0] = torch.cross(plane_vec1.view(-1), plane_vec2.view(-1)) 94 | 95 | # right side plane 96 | plane_vec1 = corner_coords[2][:3] - corner_coords[1][:3] 97 | plane_vec2 = corner_coords[5][:3] - corner_coords[1][:3] 98 | normals[1] = torch.cross(plane_vec1.view(-1), plane_vec2.view(-1)) 99 | 100 | # roof plane 101 | plane_vec1 = corner_coords[3][:3] - corner_coords[2][:3] 102 | plane_vec2 = corner_coords[6][:3] - corner_coords[2][:3] 103 | normals[2] = torch.cross(plane_vec1.view(-1), plane_vec2.view(-1)) 104 | 105 | # left side plane 106 | plane_vec1 = corner_coords[0][:3] - corner_coords[3][:3] 107 | plane_vec2 = corner_coords[7][:3] - corner_coords[3][:3] 108 | normals[3] = torch.cross(plane_vec1.view(-1), plane_vec2.view(-1)) 109 | 110 | # bottom plane 111 | plane_vec1 = corner_coords[1][:3] - corner_coords[0][:3] 112 | plane_vec2 = corner_coords[4][:3] - corner_coords[0][:3] 113 | normals[4] = torch.cross(plane_vec1.view(-1), plane_vec2.view(-1)) 114 | 115 | # back plane 116 | plane_vec1 = corner_coords[6][:3] - corner_coords[5][:3] 117 | plane_vec2 = corner_coords[4][:3] - corner_coords[5][:3] 118 | normals[5] = torch.cross(plane_vec1.view(-1), plane_vec2.view(-1)) 119 | 120 | return normals 121 | 122 | def 
points_in_frustum(self, corner_coords, normals, new_pts, return_mask=False): 123 | """ 124 | Checks whether new_pts ly in the frustum defined by the coordinates of the corners coner_coords 125 | 126 | :param corner_coords: torch tensor of shape (8, 4), coordinates of the corners of the viewing frustum 127 | :param normals: torch tensor of shape (6, 3), normal vectors of the 6 planes of the viewing frustum 128 | :param new_pts: (num_points, 3) 129 | :param return_mask: if False, returns number of new_points in frustum 130 | :return: if return_mask=True, returns Boolean mask determining whether point is in frustum or not 131 | """ 132 | 133 | # create vectors from point set to the planes 134 | point_to_plane1 = (new_pts.cuda() - corner_coords[2][:3].view(-1)) 135 | point_to_plane2 = (new_pts.cuda() - corner_coords[4][:3].view(-1)) 136 | 137 | # check if the scalar product with the normals is positive 138 | masks = list() 139 | # for each normal, create a mask for points that lie on the correct side of the plane 140 | for k, normal in enumerate(normals): 141 | if k < 3: 142 | masks.append(torch.round(torch.mm(point_to_plane1, normal.unsqueeze(1)) * 100) / 100 < 0) 143 | else: 144 | masks.append(torch.round(torch.mm(point_to_plane2, normal.unsqueeze(1)) * 100) / 100 < 0) 145 | mask = torch.ones(point_to_plane1.shape[0]) > 0 146 | mask = mask.cuda() 147 | 148 | # create a combined mask, which keeps only the points that lie on the correct side of each plane 149 | for addMask in masks: 150 | mask = mask * addMask.squeeze() 151 | 152 | if return_mask: 153 | return mask 154 | else: 155 | return torch.sum(mask) 156 | 157 | def points_in_frustum_cpu(self, corner_coords, normals, new_pts, return_mask=False): 158 | """ 159 | Checks whether new_pts ly in the frustum defined by the coordinates of the corners coner_coords 160 | 161 | :param corner_coords: torch tensor of shape (8, 4), coordinates of the corners of the viewing frustum 162 | :param normals: torch tensor of shape (6, 3), normal vectors of the 6 planes of the viewing frustum 163 | :param new_pts: (num_points, 3) 164 | :param return_mask: if False, returns number of new_points in frustum 165 | :return: if return_mask=True, returns Boolean mask determining whether point is in frustum or not 166 | """ 167 | 168 | # create vectors from point set to the planes 169 | point_to_plane1 = (new_pts - corner_coords[2][:3].view(-1)) 170 | point_to_plane2 = (new_pts - corner_coords[4][:3].view(-1)) 171 | 172 | # check if the scalar product with the normals is positive 173 | masks = list() 174 | # for each normal, create a mask for points that lie on the correct side of the plane 175 | for k, normal in enumerate(normals): 176 | if k < 3: 177 | masks.append(torch.round(torch.mm(point_to_plane1, normal.unsqueeze(1)) * 100) / 100 < 0) 178 | else: 179 | masks.append(torch.round(torch.mm(point_to_plane2, normal.unsqueeze(1)) * 100) / 100 < 0) 180 | mask = torch.ones(point_to_plane1.shape[0]) > 0 181 | 182 | # create a combined mask, which keeps only the points that lie on the correct side of each plane 183 | for addMask in masks: 184 | mask = mask * addMask.squeeze() 185 | 186 | if return_mask: 187 | return mask 188 | else: 189 | return torch.sum(mask) 190 | 191 | def compute_projection(self, points, depth, camera_to_world): 192 | """ 193 | Computes correspondances of points to pixels 194 | 195 | :param points: tensor containing all points of the point cloud (num_points, 3) 196 | :param depth: depth map (size: proj_image) 197 | :param camera_to_world: camera 
pose (4, 4) 198 | :param num_points: number of points in one sample point cloud (4096) 199 | :return: indices_3d (array with point indices that correspond to a pixel), 200 | indices_2d (array with pixel indices that correspond to a point) 201 | """ 202 | 203 | num_points = points.shape[0] 204 | world_to_camera = torch.inverse(camera_to_world) 205 | 206 | # create 1-dim array with all indices and array with 4-dim coordinates x, y, z, 1 of points 207 | ind_points = torch.arange(0, num_points, out=torch.LongTensor()).cuda() 208 | coords = camera_to_world.new(4, num_points) 209 | coords[:3, :] = torch.t(points) 210 | coords[3, :].fill_(1) 211 | 212 | # compute viewing frustum 213 | corner_coords = self.compute_frustum_corners(camera_to_world) 214 | normals = self.compute_frustum_normals(corner_coords) 215 | 216 | # check if points are in viewing frustum and only keep according indices 217 | mask_frustum_bounds = self.points_in_frustum(corner_coords, normals, points, return_mask=True).cuda() 218 | 219 | if not mask_frustum_bounds.any(): 220 | return None 221 | ind_points = ind_points[mask_frustum_bounds] 222 | coords = coords[:, ind_points] 223 | 224 | # project world (coords) to camera 225 | camera = torch.mm(world_to_camera, coords) 226 | 227 | # project camera to image 228 | camera[0] = (camera[0] * self.intrinsic[0][0]) / camera[2] + self.intrinsic[0][2] 229 | camera[1] = (camera[1] * self.intrinsic[1][1]) / camera[2] + self.intrinsic[1][2] 230 | image = torch.round(camera).long() 231 | 232 | # keep points that are projected onto the image into the correct pixel range 233 | valid_ind_mask = torch.ge(image[0], 0) * torch.ge(image[1], 0) * torch.lt(image[0], self.image_dims[0]) * torch.lt(image[1], self.image_dims[1]) 234 | if not valid_ind_mask.any(): 235 | return None 236 | valid_image_ind_x = image[0][valid_ind_mask] 237 | valid_image_ind_y = image[1][valid_ind_mask] 238 | valid_image_ind = valid_image_ind_y * self.image_dims[0] + valid_image_ind_x 239 | 240 | # keep only points that are in the correct depth ranges (self.depth_min - self.depth_max) 241 | depth_vals = torch.index_select(depth.view(-1), 0, valid_image_ind.cuda()) 242 | depth_mask = depth_vals.ge(self.depth_min) * depth_vals.le(self.depth_max) * torch.abs(depth_vals - camera[2][valid_ind_mask]).le(self.accuracy) 243 | if not depth_mask.any(): 244 | return None 245 | 246 | # create two vectors for all considered points that establish 3d to 2d correspondence 247 | ind_update = ind_points[valid_ind_mask] 248 | ind_update = ind_update[depth_mask] 249 | indices_3d = ind_update.new(num_points + 1).fill_(0) # needs to be same size for all in batch... (first element has size) 250 | indices_2d = ind_update.new(num_points + 1).fill_(0) # needs to be same size for all in batch... 
(first element has size) 251 | indices_3d[0] = ind_update.shape[0] # first entry: number of relevant entries (of points) 252 | indices_2d[0] = ind_update.shape[0] 253 | indices_3d[1:1 + indices_3d[0]] = ind_update # indices of points 254 | indices_2d[1:1 + indices_2d[0]] = torch.index_select(valid_image_ind, 0, torch.nonzero(depth_mask)[:, 0]) # indices of corresponding pixels 255 | 256 | return indices_3d, indices_2d 257 | 258 | @torch.no_grad() 259 | def project(self, label, lin_indices_3d, lin_indices_2d, num_points): 260 | """ 261 | forward pass of backprojection for 2d features onto 3d points 262 | 263 | :param label: image features (shape: (num_input_channels, proj_image_dims[0], proj_image_dims[1])) 264 | :param lin_indices_3d: point indices from projection (shape: (num_input_channels, num_points_sample)) 265 | :param lin_indices_2d: pixel indices from projection (shape: (num_input_channels, num_points_sample)) 266 | :param num_points: number of points in one sample 267 | :return: array of points in sample with projected features (shape: (num_input_channels, num_points)) 268 | """ 269 | 270 | num_label_ft = 1 if len(label.shape) == 2 else label.shape[0] # = num_input_channels 271 | 272 | output = label.new(num_label_ft, num_points).fill_(0) 273 | num_ind = lin_indices_3d[0] 274 | if num_ind > 0: 275 | # selects values from image_features at indices given by lin_indices_2d 276 | vals = torch.index_select(label.view(num_label_ft, -1), 1, lin_indices_2d[1:1+num_ind]) 277 | output.view(num_label_ft, -1)[:, lin_indices_3d[1:1+num_ind]] = vals 278 | 279 | return output 280 | 281 | 282 | # Inherit from Function 283 | class Projection(Function): 284 | 285 | @staticmethod 286 | def forward(ctx, label, lin_indices_3d, lin_indices_2d, num_points): 287 | """ 288 | forward pass of backprojection for 2d features onto 3d points 289 | 290 | :param label: image features (shape: (num_input_channels, proj_image_dims[0], proj_image_dims[1])) 291 | :param lin_indices_3d: point indices from projection (shape: (num_input_channels, num_points_sample)) 292 | :param lin_indices_2d: pixel indices from projection (shape: (num_input_channels, num_points_sample)) 293 | :param num_points: number of points in one sample 294 | :return: array of points in sample with projected features (shape: (num_input_channels, num_points)) 295 | """ 296 | ctx.save_for_backward(lin_indices_3d, lin_indices_2d) # save indices for the backward pass 297 | num_label_ft = 1 if len(label.shape) == 2 else label.shape[0] # = num_input_channels 298 | 299 | output = label.new(num_label_ft, num_points).fill_(0) 300 | num_ind = lin_indices_3d[0] 301 | if num_ind > 0: 302 | # selects values from image_features at indices given by lin_indices_2d 303 | vals = torch.index_select(label.view(num_label_ft, -1), 1, lin_indices_2d[1:1+num_ind]) 304 | output.view(num_label_ft, -1)[:, lin_indices_3d[1:1+num_ind]] = vals 305 | return output 306 | 307 | @staticmethod 308 | def backward(ctx, grad_output): 309 | grad_label = grad_output.clone() 310 | num_ft = grad_output.shape[0] 311 | grad_label.resize_(num_ft, 32, 41) # hard-coded projection image size (32 x 41) 312 | lin_indices_3d, lin_indices_2d = ctx.saved_tensors 313 | num_ind = lin_indices_3d.data[0] 314 | vals = torch.index_select(grad_output.data.contiguous().view(num_ft, -1), 1, lin_indices_3d.data[1:1+num_ind]) 315 | grad_label.data.view(num_ft, -1)[:, lin_indices_2d.data[1:1+num_ind]] = vals 316 | 317 | return grad_label, None, None, None 318 | -------------------------------------------------------------------------------- /lib/dataset.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import h5py 5 | import torch 6 | import numpy as np 7 | import multiprocessing as mp 8 | from tqdm import tqdm 9 | from prefetch_generator import background 10 | 11 | sys.path.append(".") 12 | from lib.config import CONF 13 | 14 | class ScannetDataset(): 15 | def __init__(self, phase, scene_list, num_classes=21, npoints=8192, is_weighting=True, use_multiview=False, use_color=False, use_normal=False): 16 | self.phase = phase 17 | assert phase in ["train", "val", "test"] 18 | self.scene_list = scene_list 19 | self.num_classes = num_classes 20 | self.npoints = npoints 21 | self.is_weighting = is_weighting 22 | self.use_multiview = use_multiview 23 | self.use_color = use_color 24 | self.use_normal = use_normal 25 | self.chunk_data = {} # init in generate_chunks() 26 | 27 | self._prepare_weights() 28 | 29 | def _prepare_weights(self): 30 | self.scene_data = {} 31 | self.multiview_data = {} 32 | scene_points_list = [] 33 | semantic_labels_list = [] 34 | if self.use_multiview: 35 | multiview_database = h5py.File(CONF.MULTIVIEW, "r", libver="latest") 36 | for scene_id in tqdm(self.scene_list): 37 | scene_data = np.load(CONF.SCANNETV2_FILE.format(scene_id)) 38 | label = scene_data[:, 10] 39 | 40 | # append 41 | scene_points_list.append(scene_data) 42 | semantic_labels_list.append(label) 43 | self.scene_data[scene_id] = scene_data 44 | 45 | if self.use_multiview: 46 | feature = multiview_database.get(scene_id)[()] 47 | self.multiview_data[scene_id] = feature 48 | 49 | if self.is_weighting: 50 | labelweights = np.zeros(self.num_classes) 51 | for seg in semantic_labels_list: 52 | tmp,_ = np.histogram(seg,range(self.num_classes + 1)) 53 | labelweights += tmp 54 | labelweights = labelweights.astype(np.float32) 55 | labelweights = labelweights/np.sum(labelweights) 56 | self.labelweights = 1/np.log(1.2+labelweights) 57 | else: 58 | self.labelweights = np.ones(self.num_classes) 59 | 60 | @background() 61 | def __getitem__(self, index): 62 | start = time.time() 63 | 64 | # load chunks 65 | scene_id = self.scene_list[index] 66 | scene_data = self.chunk_data[scene_id] 67 | # unpack 68 | point_set = scene_data[:, :3] # include xyz by default 69 | rgb = scene_data[:, 3:6] / 255. 
# normalize the rgb values to [0, 1] 70 | normal = scene_data[:, 6:9] 71 | label = scene_data[:, 10].astype(np.int32) 72 | if self.use_multiview: 73 | feature = scene_data[:, 11:] 74 | point_set = np.concatenate([point_set, feature], axis=1) 75 | 76 | if self.use_color: 77 | point_set = np.concatenate([point_set, rgb], axis=1) 78 | 79 | if self.use_normal: 80 | point_set = np.concatenate([point_set, normal], axis=1) 81 | 82 | if self.phase == "train": 83 | point_set = self._augment(point_set) 84 | 85 | # prepare mask 86 | curmin = np.min(point_set, axis=0)[:3] 87 | curmax = np.max(point_set, axis=0)[:3] 88 | mask = np.sum((point_set[:, :3] >= (curmin - 0.01)) * (point_set[:, :3] <= (curmax + 0.01)), axis=1) == 3 89 | sample_weight = self.labelweights[label] 90 | sample_weight *= mask 91 | 92 | fetch_time = time.time() - start 93 | 94 | return point_set, label, sample_weight, fetch_time 95 | 96 | def __len__(self): 97 | return len(self.scene_list) 98 | 99 | def _augment(self, point_set): 100 | # translate the chunk center to the origin 101 | center = np.mean(point_set[:, :3], axis=0) 102 | coords = point_set[:, :3] - center 103 | 104 | p = np.random.choice(np.arange(0.01, 1.01, 0.01), size=1)[0] 105 | if p < 1 / 8: 106 | # random translation 107 | coords = self._translate(coords) 108 | elif p >= 1 / 8 and p < 2 / 8: 109 | # random rotation 110 | coords = self._rotate(coords) 111 | elif p >= 2 / 8 and p < 3 / 8: 112 | # random scaling 113 | coords = self._scale(coords) 114 | elif p >= 3 / 8 and p < 4 / 8: 115 | # random translation 116 | coords = self._translate(coords) 117 | # random rotation 118 | coords = self._rotate(coords) 119 | elif p >= 4 / 8 and p < 5 / 8: 120 | # random translation 121 | coords = self._translate(coords) 122 | # random scaling 123 | coords = self._scale(coords) 124 | elif p >= 5 / 8 and p < 6 / 8: 125 | # random rotation 126 | coords = self._rotate(coords) 127 | # random scaling 128 | coords = self._scale(coords) 129 | elif p >= 6 / 8 and p < 7 / 8: 130 | # random translation 131 | coords = self._translate(coords) 132 | # random rotation 133 | coords = self._rotate(coords) 134 | # random scaling 135 | coords = self._scale(coords) 136 | else: 137 | # no augmentation 138 | pass 139 | 140 | # translate the chunk center back to the original center 141 | coords += center 142 | point_set[:, :3] = coords 143 | 144 | return point_set 145 | 146 | def _translate(self, point_set): 147 | # translation factors 148 | x_factor = np.random.choice(np.arange(-0.5, 0.501, 0.001), size=1)[0] 149 | y_factor = np.random.choice(np.arange(-0.5, 0.501, 0.001), size=1)[0] 150 | z_factor = np.random.choice(np.arange(-0.5, 0.501, 0.001), size=1)[0] 151 | 152 | coords = point_set[:, :3] 153 | coords += [x_factor, y_factor, z_factor] 154 | point_set[:, :3] = coords 155 | 156 | return point_set 157 | 158 | def _rotate(self, point_set): 159 | coords = point_set[:, :3] 160 | 161 | # x rotation matrix 162 | theta = np.random.choice(np.arange(-5, 5.001, 0.001), size=1)[0] * 3.14 / 180 # in radians 163 | Rx = np.array( 164 | [[1, 0, 0], 165 | [0, np.cos(theta), -np.sin(theta)], 166 | [0, np.sin(theta), np.cos(theta)]] 167 | ) 168 | 169 | # y rotation matrix 170 | theta = np.random.choice(np.arange(-5, 5.001, 0.001), size=1)[0] * 3.14 / 180 # in radians 171 | Ry = np.array( 172 | [[np.cos(theta), 0, np.sin(theta)], 173 | [0, 1, 0], 174 | [-np.sin(theta), 0, np.cos(theta)]] 175 | ) 176 | 177 | # z rotation matrix 178 | theta = np.random.choice(np.arange(-5, 5.001, 0.001), size=1)[0] * 3.14 / 180 # in 
radians 179 | Rz = np.array( 180 | [[np.cos(theta), -np.sin(theta), 0], 181 | [np.sin(theta), np.cos(theta), 0], 182 | [0, 0, 1]] 183 | ) 184 | 185 | # rotate 186 | R = np.matmul(np.matmul(Rz, Ry), Rx) 187 | coords = np.matmul(R, coords.T).T 188 | 189 | # dump 190 | point_set[:, :3] = coords 191 | 192 | return point_set 193 | 194 | def _scale(self, point_set): 195 | # scaling factors 196 | factor = np.random.choice(np.arange(0.95, 1.051, 0.001), size=1)[0] 197 | 198 | coords = point_set[:, :3] 199 | coords *= [factor, factor, factor] 200 | point_set[:, :3] = coords 201 | 202 | return point_set 203 | 204 | def generate_chunks(self): 205 | """ 206 | note: must be called before training 207 | """ 208 | 209 | print("generate new chunks for {}...".format(self.phase)) 210 | for scene_id in tqdm(self.scene_list): 211 | scene = self.scene_data[scene_id] 212 | semantic = scene[:, 10].astype(np.int32) 213 | if self.use_multiview: 214 | feature = self.multiview_data[scene_id] 215 | 216 | coordmax = np.max(scene, axis=0)[:3] 217 | coordmin = np.min(scene, axis=0)[:3] 218 | 219 | for _ in range(5): 220 | curcenter = scene[np.random.choice(len(semantic), 1)[0],:3] 221 | curmin = curcenter-[0.75,0.75,1.5] 222 | curmax = curcenter+[0.75,0.75,1.5] 223 | curmin[2] = coordmin[2] 224 | curmax[2] = coordmax[2] 225 | curchoice = np.sum((scene[:, :3]>=(curmin-0.2))*(scene[:, :3]<=(curmax+0.2)),axis=1)==3 226 | cur_point_set = scene[curchoice] 227 | cur_semantic_seg = semantic[curchoice] 228 | if self.use_multiview: 229 | cur_feature = feature[curchoice] 230 | 231 | if len(cur_semantic_seg)==0: 232 | continue 233 | 234 | mask = np.sum((cur_point_set[:, :3]>=(curmin-0.01))*(cur_point_set[:, :3]<=(curmax+0.01)),axis=1)==3 235 | vidx = np.ceil((cur_point_set[mask,:3]-curmin)/(curmax-curmin)*[31.0,31.0,62.0]) 236 | vidx = np.unique(vidx[:,0]*31.0*62.0+vidx[:,1]*62.0+vidx[:,2]) 237 | isvalid = np.sum(cur_semantic_seg>0)/len(cur_semantic_seg)>=0.7 and len(vidx)/31.0/31.0/62.0>=0.02 238 | 239 | if isvalid: 240 | break 241 | 242 | # store chunk 243 | if self.use_multiview: 244 | chunk = np.concatenate([cur_point_set, cur_feature], axis=1) 245 | else: 246 | chunk = cur_point_set 247 | 248 | choices = np.random.choice(chunk.shape[0], self.npoints, replace=True) 249 | chunk = chunk[choices] 250 | self.chunk_data[scene_id] = chunk 251 | 252 | print("done!\n") 253 | 254 | class ScannetDatasetWholeScene(): 255 | def __init__(self, scene_list, npoints=8192, is_weighting=True, use_color=False, use_normal=False, use_multiview=False): 256 | self.scene_list = scene_list 257 | self.npoints = npoints 258 | self.is_weighting = is_weighting 259 | self.use_color = use_color 260 | self.use_normal = use_normal 261 | self.use_multiview = use_multiview 262 | 263 | self._load_scene_file() 264 | 265 | def _load_scene_file(self): 266 | self.scene_points_list = [] 267 | self.semantic_labels_list = [] 268 | if self.use_multiview: 269 | multiview_database = h5py.File(CONF.MULTIVIEW, "r", libver="latest") 270 | self.multiview_data = [] 271 | 272 | for scene_id in tqdm(self.scene_list): 273 | scene_data = np.load(CONF.SCANNETV2_FILE.format(scene_id)) 274 | label = scene_data[:, 10].astype(np.int32) 275 | self.scene_points_list.append(scene_data) 276 | self.semantic_labels_list.append(label) 277 | 278 | if self.use_multiview: 279 | feature = multiview_database.get(scene_id)[()] 280 | self.multiview_data.append(feature) 281 | 282 | if self.is_weighting: 283 | labelweights = np.zeros(CONF.NUM_CLASSES) 284 | for seg in self.semantic_labels_list: 285 | 
tmp,_ = np.histogram(seg,range(CONF.NUM_CLASSES + 1)) 286 | labelweights += tmp 287 | labelweights = labelweights.astype(np.float32) 288 | labelweights = labelweights/np.sum(labelweights) 289 | self.labelweights = 1/np.log(1.2+labelweights) 290 | else: 291 | self.labelweights = np.ones(CONF.NUM_CLASSES) 292 | 293 | @background() 294 | def __getitem__(self, index): 295 | start = time.time() 296 | scene_data = self.scene_points_list[index] 297 | 298 | # unpack 299 | point_set_ini = scene_data[:, :3] # include xyz by default 300 | color = scene_data[:, 3:6] / 255. # normalize the rgb values to [0, 1] 301 | normal = scene_data[:, 6:9] 302 | 303 | if self.use_color: 304 | point_set_ini = np.concatenate([point_set_ini, color], axis=1) 305 | 306 | if self.use_normal: 307 | point_set_ini = np.concatenate([point_set_ini, normal], axis=1) 308 | 309 | if self.use_multiview: 310 | multiview_features = self.multiview_data[index] 311 | point_set_ini = np.concatenate([point_set_ini, multiview_features], axis=1) 312 | 313 | semantic_seg_ini = self.semantic_labels_list[index].astype(np.int32) 314 | coordmax = point_set_ini[:, :3].max(axis=0) 315 | coordmin = point_set_ini[:, :3].min(axis=0) 316 | xlength = 1.5 317 | ylength = 1.5 318 | nsubvolume_x = np.ceil((coordmax[0]-coordmin[0])/xlength).astype(np.int32) 319 | nsubvolume_y = np.ceil((coordmax[1]-coordmin[1])/ylength).astype(np.int32) 320 | point_sets = list() 321 | semantic_segs = list() 322 | sample_weights = list() 323 | 324 | for i in range(nsubvolume_x): 325 | for j in range(nsubvolume_y): 326 | curmin = coordmin+[i*xlength, j*ylength, 0] 327 | curmax = coordmin+[(i+1)*xlength, (j+1)*ylength, coordmax[2]-coordmin[2]] 328 | mask = np.sum((point_set_ini[:, :3]>=(curmin-0.01))*(point_set_ini[:, :3]<=(curmax+0.01)), axis=1)==3 329 | cur_point_set = point_set_ini[mask,:] 330 | cur_semantic_seg = semantic_seg_ini[mask] 331 | if len(cur_semantic_seg) == 0: 332 | continue 333 | 334 | choice = np.random.choice(len(cur_semantic_seg), self.npoints, replace=True) 335 | point_set = cur_point_set[choice,:] # Nx3 336 | semantic_seg = cur_semantic_seg[choice] # N 337 | mask = mask[choice] 338 | # if sum(mask)/float(len(mask))<0.01: 339 | # continue 340 | 341 | sample_weight = self.labelweights[semantic_seg] 342 | sample_weight *= mask # N 343 | point_sets.append(np.expand_dims(point_set,0)) # 1xNx3 344 | semantic_segs.append(np.expand_dims(semantic_seg,0)) # 1xN 345 | sample_weights.append(np.expand_dims(sample_weight,0)) # 1xN 346 | 347 | point_sets = np.concatenate(tuple(point_sets),axis=0) 348 | semantic_segs = np.concatenate(tuple(semantic_segs),axis=0) 349 | sample_weights = np.concatenate(tuple(sample_weights),axis=0) 350 | 351 | fetch_time = time.time() - start 352 | 353 | return point_sets, semantic_segs, sample_weights, fetch_time 354 | 355 | def __len__(self): 356 | return len(self.scene_points_list) 357 | 358 | def collate_random(data): 359 | ''' 360 | for ScannetDataset: collate_fn=collate_random 361 | 362 | return: 363 | coords # torch.FloatTensor(B, N, 3) 364 | feats # torch.FloatTensor(B, N, 3) 365 | semantic_segs # torch.FloatTensor(B, N) 366 | sample_weights # torch.FloatTensor(B, N) 367 | fetch_time # float 368 | ''' 369 | 370 | # load data 371 | ( 372 | point_set, 373 | semantic_seg, 374 | sample_weight, 375 | fetch_time 376 | ) = zip(*data) 377 | 378 | # convert to tensor 379 | point_set = torch.FloatTensor(point_set) 380 | semantic_seg = torch.LongTensor(semantic_seg) 381 | sample_weight = torch.FloatTensor(sample_weight) 382 | 383 | # split 
points to coords and feats 384 | coords = point_set[:, :, :3] 385 | feats = point_set[:, :, 3:] 386 | 387 | # pack 388 | batch = ( 389 | coords, # (B, N, 3) 390 | feats, # (B, N, 3) 391 | semantic_seg, # (B, N) 392 | sample_weight, # (B, N) 393 | sum(fetch_time) # float 394 | ) 395 | 396 | return batch 397 | 398 | def collate_wholescene(data): 399 | ''' 400 | for ScannetDatasetWholeScene: collate_fn=collate_wholescene 401 | 402 | return: 403 | coords # torch.FloatTensor(B, C, N, 3) 404 | feats # torch.FloatTensor(B, C, N, 3) 405 | semantic_segs # torch.FloatTensor(B, C, N) 406 | sample_weights # torch.FloatTensor(B, C, N) 407 | fetch_time # float 408 | ''' 409 | 410 | # load data 411 | ( 412 | point_sets, 413 | semantic_segs, 414 | sample_weights, 415 | fetch_time 416 | ) = zip(*data) 417 | 418 | # convert to tensor 419 | point_sets = torch.FloatTensor(point_sets) 420 | semantic_segs = torch.LongTensor(semantic_segs) 421 | sample_weights = torch.FloatTensor(sample_weights) 422 | 423 | # split points to coords and feats 424 | coords = point_sets[:, :, :, :3] 425 | feats = point_sets[:, :, :, 3:] 426 | 427 | # pack 428 | batch = ( 429 | coords, # (B, C, N, 3) 430 | feats, # (B, C, N, 3) 431 | semantic_segs, # (B, C, N) 432 | sample_weights, # (B, C, N) 433 | sum(fetch_time) # float 434 | ) 435 | 436 | return batch --------------------------------------------------------------------------------
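Usage note (not part of the repository sources above): a minimal sketch of how ScannetDataset and collate_random from lib/dataset.py are typically wired into a PyTorch DataLoader. The scene ids, batch size, and flags below are placeholder assumptions; generate_chunks() is re-run before every epoch, as its note ("must be called before training") requires.

from torch.utils.data import DataLoader

from lib.dataset import ScannetDataset, collate_random

# hypothetical scene ids; in this repo the split lists live under data/ (e.g. scannetv2_val.txt)
scene_list = ["scene0000_00", "scene0001_00"]

# assumes the preprocessed scenes referenced by CONF.SCANNETV2_FILE already exist
dataset = ScannetDataset("train", scene_list, npoints=8192, use_color=True)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_random)

for epoch in range(10):
    # resample one fixed-size chunk per scene before each epoch
    dataset.generate_chunks()
    for coords, feats, labels, weights, fetch_time in dataloader:
        # coords: (B, N, 3), feats: (B, N, F), labels/weights: (B, N)
        pass  # model forward/backward would go here

Since __getitem__ is decorated with prefetch_generator's background(), each sample is yielded from a background thread; collate_random's zip(*data) consumes those generators transparently, so the loop above receives ordinary tensors.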