├── pointnet2 ├── pointnet2.egg-info │ ├── dependency_links.txt │ ├── top_level.txt │ ├── PKG-INFO │ └── SOURCES.txt ├── _ext_src │ ├── include │ │ ├── ball_query.h │ │ ├── group_points.h │ │ ├── sampling.h │ │ ├── interpolate.h │ │ ├── utils.h │ │ └── cuda_utils.h │ └── src │ │ ├── bindings.cpp │ │ ├── ball_query.cpp │ │ ├── ball_query_gpu.cu │ │ ├── group_points.cpp │ │ ├── group_points_gpu.cu │ │ ├── sampling.cpp │ │ ├── interpolate.cpp │ │ ├── interpolate_gpu.cu │ │ └── sampling_gpu.cu ├── setup.py ├── pytorch_utils.py └── pointnet2_utils.py ├── auction_match ├── __init__.py ├── auction_match_gpu.cpp ├── auction_match.py └── auction_match_gpu.cu ├── chamfer_distance ├── __init__.py ├── chamfer_distance.py ├── chamfer_distance.cu └── chamfer_distance.cpp ├── notebook ├── test_img.jpg ├── .ipynb_checkpoints │ ├── test_a_epoch-checkpoint.ipynb │ ├── test_loss-checkpoint.ipynb │ └── Untitled-checkpoint.ipynb └── Untitled.ipynb ├── .idea ├── inspectionProfiles │ └── profiles_settings.xml ├── vcs.xml ├── modules.xml ├── remote-mappings.xml ├── misc.xml ├── PUGAN_pytorch.iml ├── webServers.xml ├── deployment.xml ├── sshConfigs.xml └── workspace.xml ├── utils ├── xyz_util.py ├── visualize_utils.py ├── Logger.py ├── data_util.py ├── pc_util.py └── ply_utils.py ├── option ├── test_option.py └── train_option.py ├── PUNet-Evaluation ├── CMakeLists.txt ├── nicolo_density.xyz └── evaluation.cpp ├── README.md ├── test ├── visualize_all.py ├── test.py └── eval.py ├── requirements.txt ├── loss └── loss.py ├── data ├── data_loader.py └── test_list.txt └── train ├── train_recon.py ├── train.py └── train_emd_only.py /pointnet2/pointnet2.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /pointnet2/pointnet2.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | pointnet2 2 | 
-------------------------------------------------------------------------------- /auction_match/__init__.py: -------------------------------------------------------------------------------- 1 | from .auction_match import auction_match 2 | -------------------------------------------------------------------------------- /chamfer_distance/__init__.py: -------------------------------------------------------------------------------- 1 | from .chamfer_distance import chamfer_distance 2 | -------------------------------------------------------------------------------- /notebook/test_img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaolinLiu97/PUGAN-pytorch/HEAD/notebook/test_img.jpg -------------------------------------------------------------------------------- /notebook/.ipynb_checkpoints/test_a_epoch-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 4 6 | } 7 | -------------------------------------------------------------------------------- /notebook/.ipynb_checkpoints/test_loss-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 4 6 | } 7 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /pointnet2/pointnet2.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: pointnet2 3 | Version: 0.0.0 4 | Summary: UNKNOWN 5 | Home-page: UNKNOWN 6 | Author: UNKNOWN 7 | Author-email: UNKNOWN 8 | License: UNKNOWN 9 
| Description: UNKNOWN 10 | Platform: UNKNOWN 11 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /utils/xyz_util.py: -------------------------------------------------------------------------------- 1 | # Write the first three columns of each row of numpy_array as one "%f %f %f" text line to xyz_dir, overwriting the file. 2 | def save_xyz_file(numpy_array, xyz_dir): 3 | num_points = numpy_array.shape[0] 4 | with open(xyz_dir, 'w') as f: 5 | for i in range(num_points): 6 | line = "%f %f %f\n" % (numpy_array[i, 0], numpy_array[i, 1], numpy_array[i, 2]) 7 | f.write(line) 8 | return 9 | -------------------------------------------------------------------------------- /.idea/remote-mappings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree.
5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 10 | const int nsample); 11 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/group_points.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 10 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 11 | -------------------------------------------------------------------------------- /.idea/PUGAN_pytorch.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /utils/visualize_utils.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | 3 | def visualize_point_cloud(xyz): 4 | ''' 5 | Args: 6 | xyz is of shape N,3 7 | ''' 8 | pcd=o3d.geometry.PointCloud() 9 | pcd.points=o3d.utility.Vector3dVector(xyz) 10 | vis=o3d.visualization.Visualizer() 11 | 12 | vis.create_window() 13 | vis.add_geometry(pcd) 14 | img=vis.capture_screen_float_buffer(True) 15 | 16 | return img 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /pointnet2/pointnet2.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | _ext_src/src/ball_query.cpp 3 | _ext_src/src/ball_query_gpu.cu 4 | _ext_src/src/bindings.cpp 5 | _ext_src/src/group_points.cpp 6 | _ext_src/src/group_points_gpu.cu 7 | _ext_src/src/interpolate.cpp 8 | _ext_src/src/interpolate_gpu.cu 9 | 
_ext_src/src/sampling.cpp 10 | _ext_src/src/sampling_gpu.cu 11 | pointnet2.egg-info/PKG-INFO 12 | pointnet2.egg-info/SOURCES.txt 13 | pointnet2.egg-info/dependency_links.txt 14 | pointnet2.egg-info/top_level.txt -------------------------------------------------------------------------------- /.idea/webServers.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 11 | 12 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/sampling.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 10 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 11 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 12 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | 11 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 12 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 13 | at::Tensor weight); 14 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 15 | at::Tensor weight, const int m); 16 | -------------------------------------------------------------------------------- /option/test_option.py: -------------------------------------------------------------------------------- 1 | import os 2 | # NOTE(review): despite its name, this builds the evaluation-time option dict (isTrain=False, batch_size=1). 3 | def get_train_options(): 4 | opt = {} 5 | # Edit project_dir for your machine; the checkpoint, result, log and dataset paths below are derived from it. 6 | opt['project_dir'] = "/data2/haolin/PUGAN_pytorch" 7 | opt['model_save_dir'] = opt['project_dir'] + '/checkpoints' 8 | opt["test_save_dir"]=opt['project_dir'] + '/test_results' 9 | opt['test_log_dir']=opt['project_dir'] + '/log_results' 10 | opt['dataset_dir'] = os.path.join(opt["project_dir"],"Patches_noHole_and_collected.h5") 11 | opt['isTrain']=False 12 | opt['batch_size'] = 1 13 | opt["patch_num_point"]=1024 14 | opt['lr_D']=1e-4 15 | opt['lr_G']=1e-3 16 | opt['emd_w']=100.0 17 | opt['uniform_w']=10.0 18 | opt['gan_w']=0.5 19 | opt['repulsion_w']=5.0 20 | opt['use_gan']=False 21 | return opt 22 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 22 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree.
5 | 6 | #include "ball_query.h" 7 | #include "group_points.h" 8 | #include "interpolate.h" 9 | #include "sampling.h" 10 | // Register each C++ entry point under the Python name it is called by. 11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 12 | m.def("gather_points", &gather_points); 13 | m.def("gather_points_grad", &gather_points_grad); 14 | m.def("furthest_point_sampling", &furthest_point_sampling); 15 | 16 | m.def("three_nn", &three_nn); 17 | m.def("three_interpolate", &three_interpolate); 18 | m.def("three_interpolate_grad", &three_interpolate_grad); 19 | 20 | m.def("ball_query", &ball_query); 21 | 22 | m.def("group_points", &group_points); 23 | m.def("group_points_grad", &group_points_grad); 24 | } 25 | -------------------------------------------------------------------------------- /auction_match/auction_match_gpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | // Implemented in auction_match_gpu.cu, which auction_match.py compiles together with this file. 3 | void AuctionMatchLauncher(int b,int n,const float * xyz1,const float * xyz2,int * matchl,int * matchr,float * cost); 4 | 5 | int auction_match_wrapper_fast(int b, int n, 6 | at::Tensor xyz1_tensor, at::Tensor xyz2_tensor, at::Tensor matchl_tensor, 7 | at::Tensor matchr_tensor, at::Tensor cost_tensor) { 8 | // NOTE(review): no dtype/device/contiguity checks here; callers are expected to pass contiguous CUDA tensors (float xyz1/xyz2/cost, int matchl/matchr). 9 | const float *xyz1 = xyz1_tensor.data(); 10 | const float *xyz2 = xyz2_tensor.data(); 11 | int *matchl = matchl_tensor.data(); 12 | int *matchr = matchr_tensor.data(); 13 | float *cost = cost_tensor.data(); 14 | // Always returns 1; the launcher returns void, so failures are not reported here. 15 | AuctionMatchLauncher(b, n, xyz1, xyz2, matchl, matchr, cost); 16 | return 1; 17 | } 18 | // Exposed to Python as am.auction_match_cuda by auction_match.py. 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("auction_match_cuda", &auction_match_wrapper_fast, "auction_match_wrapper_fast forward"); 21 | } -------------------------------------------------------------------------------- /PUNet-Evaluation/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Created by the script cgal_create_cmake_script 2 | # This is the CMake script for compiling a CGAL application.
3 | 4 | 5 | project( Distance_2_Tests ) 6 | cmake_minimum_required(VERSION 2.8.10) 7 | set (CMAKE_CXX_STANDARD 11) 8 | 9 | find_package(OpenMP) 10 | if (OPENMP_FOUND) 11 | set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") 12 | set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 13 | endif() 14 | 15 | 16 | find_package(CGAL QUIET) 17 | if ( CGAL_FOUND ) 18 | include( ${CGAL_USE_FILE} ) 19 | include( CGAL_CreateSingleSourceCGALProgram ) 20 | include_directories (BEFORE "../../include") 21 | # create a target per cppfile 22 | file(GLOB cppfiles RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) 23 | foreach(cppfile ${cppfiles}) 24 | create_single_source_cgal_program( "${cppfile}" ) 25 | endforeach() 26 | 27 | else() 28 | message(STATUS "This program requires the CGAL library, and will not be compiled.") 29 | endif() 30 | 31 | -------------------------------------------------------------------------------- /option/train_option.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def get_train_options(): 4 | opt = {} 5 | 6 | opt['project_dir'] = "/mnt/beegfs/haolin/PUGAN-pytorch" 7 | opt['model_save_dir'] = opt['project_dir'] + '/checkpoints' 8 | opt["test_save_dir"]=opt['project_dir'] + '/test_results' 9 | opt['test_log_dir']=opt['project_dir'] + '/log_results' 10 | opt['dataset_dir'] = os.path.join(opt["project_dir"],"Patches_noHole_and_collected.h5") 11 | opt['test_split']= os.path.join(opt['project_dir'],'data','test_list.txt') 12 | opt['train_split']=os.path.join(opt['project_dir'],'data','train_list.txt') 13 | opt['isTrain']=True 14 | opt['batch_size'] = 2 15 | opt['nepoch'] = 100 16 | opt['model_save_interval'] = 10 17 | opt['model_vis_interval']=200 18 | opt["up_ratio"]=4 19 | opt["patch_num_point"]=1024 20 | opt['lr_D']=1e-4 21 | opt['lr_G']=1e-3 22 | opt['emd_w']=100.0 23 | opt['uniform_w']=10.0 24 | opt['gan_w']=0.5 25 | opt['repulsion_w']=5.0 26 | 
opt['use_gan']=False 27 | return opt 28 | -------------------------------------------------------------------------------- /pointnet2/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from setuptools import setup 7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 8 | import glob 9 | 10 | _ext_src_root = "_ext_src" 11 | _ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob( 12 | "{}/src/*.cu".format(_ext_src_root) 13 | ) 14 | _ext_headers = glob.glob("{}/include/*".format(_ext_src_root)) 15 | 16 | setup( 17 | name='pointnet2', 18 | ext_modules=[ 19 | CUDAExtension( 20 | name='pointnet2._ext', 21 | sources=_ext_sources, 22 | extra_compile_args={ 23 | "cxx": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], 24 | "nvcc": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], 25 | }, 26 | ) 27 | ], 28 | cmdclass={ 29 | 'build_ext': BuildExtension 30 | } 31 | ) 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PUGAN-pytorch 2 | Pytorch unofficial implementation of PUGAN (a Point Cloud Upsampling Adversarial Network, ICCV, 2019) 3 | 4 | #### Install some packages 5 | simply by 6 | ``` 7 | pip install -r requirement.txt 8 | ``` 9 | #### Install Pointnet2 module 10 | ``` 11 | cd pointnet2 12 | python setup.py install 13 | ``` 14 | #### Install KNN_cuda 15 | ``` 16 | pip install --upgrade https://github.com/unlimblue/KNN_CUDA/releases/download/0.2/KNN_CUDA-0.2-py3-none-any.whl 17 | ``` 18 | #### dataset 19 | We use the PU-Net dataset for training, you can refer to https://github.com/yulequan/PU-Net to download the .h5 dataset file, which can be 
directly used in this project. 20 | #### Modify the settings in option/train_option.py 21 | Change opt['project_dir'] to where this project is located, and change opt['dataset_dir'] to where you store the dataset. 22 |
23 | also change params['train_split'] and params['test_split'] to where you save the train/test split txt files. 24 | #### training 25 | ``` 26 | cd train 27 | python train.py --exp_name=the_project_name --gpu=gpu_number --use_gan --batch_size=12 28 | ``` 29 | 30 | -------------------------------------------------------------------------------- /test/visualize_all.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import argparse 3 | import glob 4 | import numpy as np 5 | import os, sys 6 | sys.path.append("../") 7 | from utils.pc_util import draw_point_cloud 8 | 9 | parser = argparse.ArgumentParser(description="Arg parser") 10 | parser.add_argument('--data_dir', type=str, required=True) 11 | parser.add_argument('--exp_name',type=str,required=True) 12 | 13 | args = parser.parse_args() 14 | 15 | if __name__=="__main__": 16 | file_dir=glob.glob(os.path.join(args.data_dir,"*.xyz"))##visualize all xyz file 17 | pcd_list=[] 18 | for file in file_dir: 19 | if file.split('/')[-1].find("_")<0: 20 | pcd_list.append(file) 21 | image_save_dir=os.path.join("../vis_result",args.exp_name) 22 | if os.path.exists(image_save_dir)==False: 23 | os.makedirs(image_save_dir) 24 | 25 | for file in pcd_list: 26 | file_name=file.split("/")[-1].split('.')[0] 27 | pcd=np.loadtxt(file) 28 | img = draw_point_cloud(pcd, zrot=90 / 180.0 * np.pi, xrot=90 / 180.0 * np.pi, yrot=0 / 180.0 * np.pi, 29 | diameter=4) 30 | img=(img*255).astype(np.uint8) 31 | image_save_path=os.path.join(image_save_dir,file_name+".png") 32 | cv2.imwrite(image_save_path,img) 33 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | // Argument-validation macros used by the extension's C++ wrappers (e.g. ball_query.cpp, group_points.cpp). 6 | #pragma once 7 | #include 8 | #include 9 | // NOTE(review): AT_CHECK and Tensor::type() are deprecated in newer PyTorch releases (TORCH_CHECK / x.is_cuda()); confirm against the torch version this is built with. 10 | #define CHECK_CUDA(x) \ 11 | do { \ 12 | AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_CONTIGUOUS(x) \ 16 | do { \ 17 | AT_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \ 18 | } while (0) 19 | 20 | #define CHECK_IS_INT(x) \ 21 | do { \ 22 | AT_CHECK(x.scalar_type() == at::ScalarType::Int, \ 23 | #x " must be an int tensor"); \ 24 | } while (0) 25 | 26 | #define CHECK_IS_FLOAT(x) \ 27 | do { \ 28 | AT_CHECK(x.scalar_type() == at::ScalarType::Float, \ 29 | #x " must be a float tensor"); \ 30 | } while (0) 31 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree.
5 | 6 | #include "ball_query.h" 7 | #include "utils.h" 8 | // CUDA launcher defined in ball_query_gpu.cu. 9 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 10 | int nsample, const float *new_xyz, 11 | const float *xyz, int *idx); 12 | // Returns an int32 tensor idx of shape (B, M, nsample), B and M taken from new_xyz, listing per query point the indices of xyz points with squared distance < radius^2; only CUDA inputs are supported. 13 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 14 | const int nsample) { 15 | CHECK_CONTIGUOUS(new_xyz); 16 | CHECK_CONTIGUOUS(xyz); 17 | CHECK_IS_FLOAT(new_xyz); 18 | CHECK_IS_FLOAT(xyz); 19 | // NOTE(review): only xyz gets CHECK_CUDA; the device of new_xyz decides which branch runs below. 20 | if (new_xyz.type().is_cuda()) { 21 | CHECK_CUDA(xyz); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 26 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 27 | 28 | if (new_xyz.type().is_cuda()) { 29 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 30 | radius, nsample, new_xyz.data(), 31 | xyz.data(), idx.data()); 32 | } else { 33 | AT_CHECK(false, "CPU not supported"); 34 | } 35 | 36 | return idx; 37 | } 38 | -------------------------------------------------------------------------------- /.idea/sshConfigs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree.
5 | 6 | #ifndef _CUDA_UTILS_H 7 | #define _CUDA_UTILS_H 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | #define TOTAL_THREADS 512 19 | 20 | inline int opt_n_threads(int work_size) { 21 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 22 | 23 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 24 | } 25 | 26 | inline dim3 opt_block_config(int x, int y) { 27 | const int x_threads = opt_n_threads(x); 28 | const int y_threads = 29 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 30 | dim3 block_config(x_threads, y_threads, 1); 31 | 32 | return block_config; 33 | } 34 | 35 | #define CUDA_CHECK_ERRORS() \ 36 | do { \ 37 | cudaError_t err = cudaGetLastError(); \ 38 | if (cudaSuccess != err) { \ 39 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 40 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 41 | __FILE__); \ 42 | exit(-1); \ 43 | } \ 44 | } while (0) 45 | 46 | #endif 47 | -------------------------------------------------------------------------------- /auction_match/auction_match.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.cpp_extension import load 3 | import os 4 | 5 | script_dir = os.path.dirname(__file__) 6 | sources = [ 7 | os.path.join(script_dir, "auction_match_gpu.cpp"), 8 | os.path.join(script_dir, "auction_match_gpu.cu"), 9 | ] 10 | 11 | am = load(name="am", sources=sources) 12 | 13 | class AuctionMatch(torch.autograd.Function): 14 | @staticmethod 15 | def forward(ctx, xyz1: torch.Tensor, xyz2: int) -> torch.Tensor: 16 | """ 17 | Uses iterative furthest point sampling to select a set of npoint features that have the largest 18 | minimum distance 19 | :param ctx: 20 | :param xyz1: (B, N, 3) 21 | :param xyz2: (B, N, 3) 22 | :return: 23 | match_left: (B, N) tensor containing the set 24 | match_right: (B, N) tensor containing the set 25 | """ 26 | assert 
xyz1.is_contiguous() and xyz2.is_contiguous() 27 | assert xyz1.shape[1] <= 4096 28 | 29 | B, N, _ = xyz1.size() 30 | match_left = torch.cuda.IntTensor(B, N) 31 | match_right = torch.cuda.IntTensor(B, N) 32 | temp = torch.cuda.FloatTensor(B, N, N).fill_(0) 33 | 34 | am.auction_match_cuda(B, N, xyz1, xyz2, match_left, match_right, temp) 35 | return match_left, match_right 36 | 37 | @staticmethod 38 | def backward(ml, mr, a=None): 39 | return None, None 40 | 41 | auction_match = AuctionMatch.apply 42 | 43 | if __name__ == '__main__': 44 | import numpy as np 45 | p1 = torch.from_numpy(np.array([[[1,0,0],[2,0,0],[3,0,0],[4,0,0]]], dtype=np.float32)).cuda() 46 | p2 = torch.from_numpy(np.array([[[-10,0,0], [1,0, 0], [2,0, 0], [3,0,0]]], dtype=np.float32)).cuda() 47 | ml, mr = auction_match(p2, p1) 48 | print(ml, mr) -------------------------------------------------------------------------------- /test/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | sys.path.append("../") 4 | 5 | parser = argparse.ArgumentParser(description="Arg parser") 6 | parser.add_argument('--gpu', type=int, default=0, help='GPU to use') 7 | parser.add_argument('--resume', type=str, required=True) 8 | parser.add_argument('--exp_name',type=str,required=True) 9 | 10 | args = parser.parse_args() 11 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 12 | 13 | import torch 14 | import torch.nn as nn 15 | from torch.utils.data import DataLoader 16 | from utils.xyz_util import save_xyz_file 17 | 18 | from network.networks import Generator 19 | from data.data_loader import PUNET_Dataset_Whole 20 | 21 | if __name__ == '__main__': 22 | model = Generator() 23 | 24 | checkpoint = torch.load(args.resume) 25 | model.load_state_dict(checkpoint) 26 | model.eval().cuda() 27 | 28 | eval_dst = PUNET_Dataset_Whole(data_dir='../MC_5k') 29 | eval_loader = DataLoader(eval_dst, batch_size=1, 30 | shuffle=False, pin_memory=True, 
num_workers=0) 31 | 32 | names = eval_dst.names 33 | exp_name=args.exp_name 34 | save_dir=os.path.join('../outputs',exp_name) 35 | if os.path.exists(save_dir)==False: 36 | os.makedirs(save_dir) 37 | for itr, batch in enumerate(eval_loader): 38 | name = names[itr] 39 | points = batch[:,:,0:3].permute(0,2,1).float().cuda() 40 | preds = model(points) 41 | #radius=radius.float().cuda() 42 | #centroid=centroid.float().cuda() 43 | #print(preds.shape,radius.shape,centroid.shape) 44 | #preds=preds*radius+centroid.unsqueeze(2).repeat(1,1,4096) 45 | 46 | preds = preds.permute(0,2,1).data.cpu().numpy()[0] 47 | points = points.permute(0,2,1).data.cpu().numpy() 48 | save_file='../outputs/{}/{}.xyz'.format(exp_name,name) 49 | #print(preds.shape) 50 | save_xyz_file(preds,save_file) 51 | 52 | -------------------------------------------------------------------------------- /chamfer_distance/chamfer_distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.cpp_extension import load 3 | import os 4 | 5 | script_dir = os.path.dirname(__file__) 6 | sources = [ 7 | os.path.join(script_dir, "chamfer_distance.cpp"), 8 | os.path.join(script_dir, "chamfer_distance.cu"), 9 | ] 10 | 11 | cd = load(name="cd", sources=sources) 12 | 13 | 14 | class ChamferDistanceFunction(torch.autograd.Function): 15 | @staticmethod 16 | def forward(ctx, xyz1, xyz2): 17 | batchsize, n, _ = xyz1.size() 18 | _, m, _ = xyz2.size() 19 | xyz1 = xyz1.contiguous() 20 | xyz2 = xyz2.contiguous() 21 | dist1 = torch.zeros(batchsize, n) 22 | dist2 = torch.zeros(batchsize, m) 23 | 24 | idx1 = torch.zeros(batchsize, n, dtype=torch.int) 25 | idx2 = torch.zeros(batchsize, m, dtype=torch.int) 26 | 27 | if not xyz1.is_cuda: 28 | cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 29 | else: 30 | dist1 = dist1.cuda() 31 | dist2 = dist2.cuda() 32 | idx1 = idx1.cuda() 33 | idx2 = idx2.cuda() 34 | cd.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2) 35 | 36 | 
ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 37 | 38 | return dist1, dist2 39 | 40 | @staticmethod 41 | def backward(ctx, graddist1, graddist2): 42 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 43 | 44 | graddist1 = graddist1.contiguous() 45 | graddist2 = graddist2.contiguous() 46 | 47 | gradxyz1 = torch.zeros(xyz1.size()) 48 | gradxyz2 = torch.zeros(xyz2.size()) 49 | 50 | if not graddist1.is_cuda: 51 | cd.backward( 52 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 53 | ) 54 | else: 55 | gradxyz1 = gradxyz1.cuda() 56 | gradxyz2 = gradxyz2.cuda() 57 | cd.backward_cuda( 58 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 59 | ) 60 | 61 | return gradxyz1, gradxyz2 62 | 63 | chamfer_distance = ChamferDistanceFunction.apply 64 | -------------------------------------------------------------------------------- /notebook/.ipynb_checkpoints/Untitled-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os,sys\n", 10 | "sys.path.append('../')\n", 11 | "import Common.pc_util as pc_util\n", 12 | "import numpy as np" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 4, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "data_dir='/data2/haolin/PUGAN_pytorch/outputs/full_bs12_0508/camel.xyz'\n", 22 | "data=np.loadtxt(data_dir)\n", 23 | "img=pc_util.point_cloud_three_views(data,diameter=3)" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 5, 29 | "metadata": {}, 30 | "outputs": [ 31 | { 32 | "name": "stdout", 33 | "output_type": "stream", 34 | "text": [ 35 | "(1000, 3000)\n" 36 | ] 37 | }, 38 | { 39 | "data": { 40 | "text/plain": [ 41 | "True" 42 | ] 43 | }, 44 | "execution_count": 5, 45 | "metadata": {}, 46 | "output_type": "execute_result" 47 | } 48 | ], 49 | "source": [ 50 | "print(img.shape)\n", 
51 | "import cv2\n", 52 | "cv2.imwrite('./test_img.jpg',(img*255).astype(np.uint8))" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [] 68 | } 69 | ], 70 | "metadata": { 71 | "kernelspec": { 72 | "display_name": "Python 3", 73 | "language": "python", 74 | "name": "python3" 75 | }, 76 | "language_info": { 77 | "codemirror_mode": { 78 | "name": "ipython", 79 | "version": 3 80 | }, 81 | "file_extension": ".py", 82 | "mimetype": "text/x-python", 83 | "name": "python", 84 | "nbconvert_exporter": "python", 85 | "pygments_lexer": "ipython3", 86 | "version": "3.6.10" 87 | } 88 | }, 89 | "nbformat": 4, 90 | "nbformat_minor": 4 91 | } 92 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "cuda_utils.h" 11 | // One block per batch element (blockIdx.x); each thread handles query points j = threadIdx.x, threadIdx.x + blockDim.x, ... 12 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 13 | // output: idx(b, m, nsample) 14 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 15 | int nsample, 16 | const float *__restrict__ new_xyz, 17 | const float *__restrict__ xyz, 18 | int *__restrict__ idx) { 19 | int batch_index = blockIdx.x; 20 | xyz += batch_index * n * 3; 21 | new_xyz += batch_index * m * 3; 22 | idx += m * nsample * batch_index; 23 | // Per query j: store the first nsample indices k with d2 < radius^2; the first hit is also copied into every slot, so unfilled slots repeat it (slots stay untouched if there is no hit). 24 | int index = threadIdx.x; 25 | int stride = blockDim.x; 26 | 27 | float radius2 = radius * radius; 28 | for (int j = index; j < m; j += stride) { 29 | float new_x = new_xyz[j * 3 + 0]; 30 | float new_y = new_xyz[j * 3 + 1]; 31 | float new_z = new_xyz[j * 3 + 2]; 32 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 33 | float x = xyz[k * 3 + 0]; 34 | float y = xyz[k * 3 + 1]; 35 | float z = xyz[k * 3 + 2]; 36 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 37 | (new_z - z) * (new_z - z); 38 | if (d2 < radius2) { 39 | if (cnt == 0) { 40 | for (int l = 0; l < nsample; ++l) { 41 | idx[j * nsample + l] = k; 42 | } 43 | } 44 | idx[j * nsample + cnt] = k; 45 | ++cnt; 46 | } 47 | } 48 | } 49 | } 50 | // Host launcher: runs the kernel on PyTorch's current CUDA stream, then CUDA_CHECK_ERRORS() exits the process if the launch failed. 51 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 52 | int nsample, const float *new_xyz, 53 | const float *xyz, int *idx) { 54 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 55 | query_ball_point_kernel<<>>( 56 | b, n, m, radius, nsample, new_xyz, xyz, idx); 57 | 58 | CUDA_CHECK_ERRORS(); 59 | } 60 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree.
5 | 6 | #include "group_points.h" 7 | #include "utils.h" 8 | 9 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 10 | const float *points, const int *idx, 11 | float *out); 12 | 13 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 14 | int nsample, const float *grad_out, 15 | const int *idx, float *grad_points); 16 | 17 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 18 | CHECK_CONTIGUOUS(points); 19 | CHECK_CONTIGUOUS(idx); 20 | CHECK_IS_FLOAT(points); 21 | CHECK_IS_INT(idx); 22 | 23 | if (points.type().is_cuda()) { 24 | CHECK_CUDA(idx); 25 | } 26 | 27 | at::Tensor output = 28 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 29 | at::device(points.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (points.type().is_cuda()) { 32 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 33 | idx.size(1), idx.size(2), points.data(), 34 | idx.data(), output.data()); 35 | } else { 36 | AT_CHECK(false, "CPU not supported"); 37 | } 38 | 39 | return output; 40 | } 41 | 42 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 43 | CHECK_CONTIGUOUS(grad_out); 44 | CHECK_CONTIGUOUS(idx); 45 | CHECK_IS_FLOAT(grad_out); 46 | CHECK_IS_INT(idx); 47 | 48 | if (grad_out.type().is_cuda()) { 49 | CHECK_CUDA(idx); 50 | } 51 | 52 | at::Tensor output = 53 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 54 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 55 | 56 | if (grad_out.type().is_cuda()) { 57 | group_points_grad_kernel_wrapper( 58 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 59 | grad_out.data(), idx.data(), output.data()); 60 | } else { 61 | AT_CHECK(false, "CPU not supported"); 62 | } 63 | 64 | return output; 65 | } 66 | -------------------------------------------------------------------------------- /PUNet-Evaluation/nicolo_density.xyz: 
-------------------------------------------------------------------------------- 1 | 1.15451 1.1966 1.15852 1.13948 1.18338 1.18658 1.16894 2 | 1.64759 1.55739 1.55538 1.49124 1.4696 1.43312 1.41428 3 | 1.2868 1.377 1.37098 1.32288 1.2868 1.28279 1.24912 4 | 0.926015 0.871897 0.861875 0.907976 0.901963 0.90998 0.902764 5 | 0.998172 0.95608 0.986146 0.989152 1.0126 0.992159 1.00699 6 | 0.673465 0.685492 0.693509 0.682485 0.668655 0.641396 0.617343 7 | 0.950067 1.04628 1.03826 1.06131 1.02703 1.02222 1.03265 8 | 1.13046 1.02222 0.934032 0.910982 0.897152 0.915993 0.934834 9 | 0.56523 0.637387 0.5973 0.583269 0.598903 0.595295 0.610929 10 | 1.08236 1.14249 1.09438 1.11843 1.11843 1.12244 1.11924 11 | 1.73177 1.58745 1.43512 1.37399 1.29161 1.26475 1.19941 12 | 0.938041 0.865884 0.845841 0.862878 0.851453 0.83181 0.848246 13 | 1.31085 1.34092 1.39904 1.39203 1.39985 1.41708 1.42711 14 | 0.769675 0.835819 0.821788 0.826799 0.815374 0.821788 0.838625 15 | 1.41909 1.31085 1.379 1.37098 1.3782 1.34894 1.29402 16 | 1.02222 1.06432 1.07434 1.11543 1.14489 1.15652 1.21705 17 | 1.09438 1.16654 1.13447 1.08837 1.06071 1.07835 1.07594 18 | 0.817779 0.577256 0.549195 0.580263 0.589282 0.601308 0.606119 19 | 1.38301 1.24471 1.22667 1.18157 1.17375 1.18859 1.193 20 | 0.456994 0.469021 0.44096 0.453988 0.461805 0.473029 0.468219 21 | 0.625361 0.583269 0.573247 0.562223 0.550799 0.581265 0.598101 22 | 0.926015 0.938041 0.893945 0.898956 0.875505 0.867889 0.877109 23 | 0.709544 0.649413 0.661439 0.655426 0.637387 0.649413 0.646206 24 | 1.17856 1.16654 1.07434 1.04928 0.995767 0.982137 0.976525 25 | 1.5153 1.29281 1.32288 1.34092 1.29161 1.25072 1.17856 26 | 1.5153 1.53334 1.43512 1.44013 1.44314 1.40706 1.36457 27 | 1.0583 0.986146 0.990155 0.995165 0.983741 0.978128 0.994163 28 | 0.805753 0.757649 0.761657 0.787714 0.762459 0.781701 0.772882 29 | 0.87791 0.847845 0.897954 0.88693 0.885126 0.891941 0.901161 30 | 1.2387 1.16053 1.10641 1.10941 1.1016 1.1044 1.13848 31 | 0.745622 
0.781701 0.845841 0.790721 0.793727 0.819784 0.877109 32 | 0.986146 1.08837 1.03024 1.00118 0.998172 1.00619 1.0102 33 | 0.986146 0.89595 0.809762 0.790721 0.743217 0.725579 0.703932 34 | 0.853858 0.745622 0.725579 0.733596 0.707139 0.6895 0.694311 35 | 1.07033 0.992159 1.01822 0.998172 1.03666 1.04628 1.0551 36 | 1.27477 1.26876 1.19861 1.15451 1.16894 1.17055 1.16894 37 | 1.10641 1.18458 1.16253 1.21765 1.1954 1.1946 1.16413 38 | 0.938041 0.932028 0.922006 0.941048 0.945257 0.946059 0.942852 39 | 1.34693 1.16654 1.04628 1.06732 1.00779 0.986146 0.963697 40 | 0.637387 0.613335 0.625361 0.616341 0.589282 0.589282 0.569239 41 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | astor==0.8.1 3 | attrs==19.3.0 4 | backcall==0.1.0 5 | bleach==3.1.3 6 | cachetools==4.0.0 7 | certifi==2019.11.28 8 | chardet==3.0.4 9 | colored-traceback==0.3.0 10 | cycler==0.10.0 11 | decorator==4.4.1 12 | defusedxml==0.6.0 13 | docutils==0.16 14 | entrypoints==0.3 15 | enum34==1.1.10 16 | future==0.18.2 17 | gast==0.3.3 18 | google-auth==1.12.0 19 | google-auth-oauthlib==0.4.1 20 | grpcio==1.27.2 21 | h5py==2.10.0 22 | hydra-core==0.11.3 23 | idna==2.8 24 | imageio==2.6.1 25 | imageio-ffmpeg==0.3.0 26 | importlib-metadata==1.5.0 27 | ipdb==0.13.2 28 | ipykernel==5.2.0 29 | ipython==7.13.0 30 | ipython-genutils==0.2.0 31 | ipywidgets==7.5.1 32 | jedi==0.16.0 33 | Jinja2==2.11.1 34 | joblib==0.14.1 35 | jsonpatch==1.25 36 | jsonpointer==2.0 37 | jsonschema==3.2.0 38 | jupyter-client==6.1.0 39 | jupyter-core==4.6.3 40 | Keras-Applications==1.0.8 41 | Keras-Preprocessing==1.1.0 42 | kiwisolver==1.1.0 43 | lmdb==0.98 44 | Markdown==3.1.1 45 | MarkupSafe==1.1.1 46 | matplotlib==3.1.3 47 | mistune==0.8.4 48 | mock==3.0.5 49 | moviepy==1.0.1 50 | msgpack==1.0.0 51 | msgpack-numpy==0.4.4.3 52 | nbconvert==5.6.1 53 | nbformat==5.0.4 54 | 
networkx==2.4 55 | notebook==6.0.3 56 | numpy==1.18.1 57 | oauthlib==3.1.0 58 | omegaconf==1.4.1 59 | open3d==0.9.0.0 60 | opencv-python==4.2.0.32 61 | pandas==1.0.3 62 | pandocfilters==1.4.2 63 | parso==0.6.2 64 | pexpect==4.8.0 65 | pickleshare==0.7.5 66 | Pillow==6.2.2 67 | plyfile==0.7.2 68 | pprint==0.1 69 | proglog==0.1.9 70 | prometheus-client==0.7.1 71 | prompt-toolkit==3.0.3 72 | protobuf==3.11.3 73 | ptyprocess==0.6.0 74 | pyasn1==0.4.8 75 | pyasn1-modules==0.2.8 76 | Pygments==2.5.2 77 | pyparsing==2.4.6 78 | pyrsistent==0.16.0 79 | python-dateutil==2.8.1 80 | pytorch-lightning==0.7.1 81 | pytz==2019.3 82 | PyWavelets==1.1.1 83 | PyYAML==5.3.1 84 | pyzmq==19.0.0 85 | requests==2.22.0 86 | requests-oauthlib==1.3.0 87 | rsa==4.0 88 | scikit-image==0.16.2 89 | scikit-learn==0.22.2.post1 90 | scikit-video==1.1.11 91 | scipy==1.4.1 92 | Send2Trash==1.5.0 93 | six==1.14.0 94 | statistics==1.0.3.5 95 | tensorboard==1.13.1 96 | tensorboard-plugin-wit==1.6.0.post2 97 | tensorboardX==2.0 98 | tensorflow==1.13.1 99 | tensorflow-estimator==1.13.0 100 | termcolor==1.1.0 101 | terminado==0.8.3 102 | testpath==0.4.4 103 | torch==1.2.0 104 | torchfile==0.1.0 105 | torchvision==0.4.0 106 | tornado==6.0.4 107 | tqdm==4.42.1 108 | traitlets==4.3.3 109 | urllib3==1.25.8 110 | visdom==0.1.8.9 111 | wcwidth==0.1.8 112 | webencodings==0.5.1 113 | websocket-client==0.57.0 114 | Werkzeug==0.16.1 115 | widgetsnbextension==3.5.1 116 | zipp==3.1.0 117 | -------------------------------------------------------------------------------- /utils/Logger.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | try: 4 | from StringIO import StringIO # Python 2.7 5 | except ImportError: 6 | from io import BytesIO # Python 3.x 7 | 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | 11 | class Logger(object): 12 | def __init__(self, log_dir): 13 | """Create a summary writer logging to log_dir.""" 14 | self.writer = 
tf.summary.FileWriter(log_dir) 15 | 16 | def scalar_summary(self, tag, value, step): 17 | """Log a scalar variable.""" 18 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 19 | self.writer.add_summary(summary, step) 20 | 21 | def image_summary(self, tag, images, step): 22 | """Log a list of images.""" 23 | 24 | img_summaries = [] 25 | for i, img in enumerate(images): 26 | # Write the image to a string 27 | try: 28 | s = StringIO() 29 | except: 30 | s = BytesIO() 31 | # scipy.misc.toimage(img).save(s, format="png") 32 | plt.imsave(s, img, format='png') 33 | 34 | # Create an Image object 35 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 36 | height=img.shape[0], 37 | width=img.shape[1]) 38 | # Create a Summary value 39 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 40 | 41 | # Create and write Summary 42 | summary = tf.Summary(value=img_summaries) 43 | self.writer.add_summary(summary, step) 44 | 45 | def histo_summary(self, tag, values, step, bins=1000): 46 | """Log a histogram of the tensor of values.""" 47 | 48 | # Create a histogram using numpy 49 | counts, bin_edges = np.histogram(values, bins=bins) 50 | 51 | # Fill the fields of the histogram proto 52 | hist = tf.HistogramProto() 53 | hist.min = float(np.min(values)) 54 | hist.max = float(np.max(values)) 55 | hist.num = int(np.prod(values.shape)) 56 | hist.sum = float(np.sum(values)) 57 | hist.sum_squares = float(np.sum(values ** 2)) 58 | 59 | # Drop the start of the first bin 60 | bin_edges = bin_edges[1:] 61 | 62 | # Add bin edges and counts 63 | for edge in bin_edges: 64 | hist.bucket_limit.append(edge) 65 | for c in counts: 66 | hist.bucket.append(c) 67 | 68 | # Create and write Summary 69 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 70 | self.writer.add_summary(summary, step) 71 | self.writer.flush() -------------------------------------------------------------------------------- /test/eval.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | sys.path.append("../") 4 | 5 | parser = argparse.ArgumentParser(description="Arg parser") 6 | parser.add_argument('--gpu', type=int, default=0, help='GPU to use') 7 | parser.add_argument("--model", type=str, default='punet') 8 | parser.add_argument("--batch_size", type=int, default=8) 9 | parser.add_argument("--workers", type=int, default=4) 10 | parser.add_argument('--up_ratio', type=int, default=4, help='Upsampling Ratio [default: 4]') 11 | parser.add_argument("--use_bn", action='store_true', default=False) 12 | parser.add_argument("--use_res", action='store_true', default=False) 13 | parser.add_argument('--resume', type=str, required=True) 14 | 15 | args = parser.parse_args() 16 | print(args) 17 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 18 | 19 | import torch 20 | import torch.nn as nn 21 | from torch.utils.data import DataLoader 22 | import numpy as np 23 | 24 | from data.data_loader import PUNET_Dataset 25 | from chamfer_distance import chamfer_distance 26 | from auction_match import auction_match 27 | import pointnet2.utils.pointnet2_utils as pn2_utils 28 | import importlib 29 | from network.networks import Generator 30 | from option.train_option import get_train_options 31 | 32 | 33 | def get_emd_loss(pred, gt, pcd_radius): 34 | idx, _ = auction_match(pred, gt) 35 | matched_out = pn2_utils.gather_operation(gt.transpose(1, 2).contiguous(), idx) 36 | matched_out = matched_out.transpose(1, 2).contiguous() 37 | dist2 = (pred - matched_out) ** 2 38 | dist2 = dist2.view(dist2.shape[0], -1) # <-- ??? 
39 | dist2 = torch.mean(dist2, dim=1, keepdims=True) # B, 40 | dist2 /= pcd_radius 41 | return torch.mean(dist2) 42 | 43 | 44 | def get_cd_loss(pred, gt, pcd_radius): 45 | cost_for, cost_bac = chamfer_distance(gt, pred) 46 | cost = 0.5 * cost_for + 0.5 * cost_bac 47 | cost /= pcd_radius 48 | cost = torch.mean(cost) 49 | return cost 50 | 51 | 52 | if __name__ == '__main__': 53 | param=get_train_options() 54 | model = Generator() 55 | 56 | checkpoint = torch.load(args.resume) 57 | model.load_state_dict(checkpoint) 58 | model.eval().cuda() 59 | 60 | eval_dst = PUNET_Dataset(h5_file_path='../Patches_noHole_and_collected.h5', split_dir=param['test_split'], isTrain=False) 61 | eval_loader = DataLoader(eval_dst, batch_size=args.batch_size, 62 | shuffle=False, pin_memory=True, num_workers=args.workers) 63 | 64 | emd_list = [] 65 | cd_list = [] 66 | with torch.no_grad(): 67 | for itr, batch in enumerate(eval_loader): 68 | points, gt, radius = batch 69 | points = points[..., :3].permute(0,2,1).float().cuda().contiguous() 70 | gt = gt[..., :3].float().cuda().contiguous() 71 | radius = radius.float().cuda() 72 | preds = model(points) # points.shape[1]) 73 | preds=preds.permute(0,2,1).contiguous() 74 | 75 | emd = get_emd_loss(preds, gt, radius) 76 | cd = get_cd_loss(preds, gt, radius) 77 | print(' -- iter {}, emd {}, cd {}.'.format(itr, emd, cd)) 78 | emd_list.append(emd.item()) 79 | cd_list.append(cd.item()) 80 | 81 | print('mean emd: {}'.format(np.mean(emd_list))) 82 | print('mean cd: {}'.format(np.mean(cd_list))) -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include 7 | #include 8 | 9 | #include "cuda_utils.h" 10 | 11 | // input: points(b, c, n) idx(b, npoints, nsample) 12 | // output: out(b, c, npoints, nsample) 13 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 14 | int nsample, 15 | const float *__restrict__ points, 16 | const int *__restrict__ idx, 17 | float *__restrict__ out) { 18 | int batch_index = blockIdx.x; 19 | points += batch_index * n * c; 20 | idx += batch_index * npoints * nsample; 21 | out += batch_index * npoints * nsample * c; 22 | 23 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 24 | const int stride = blockDim.y * blockDim.x; 25 | for (int i = index; i < c * npoints; i += stride) { 26 | const int l = i / npoints; 27 | const int j = i % npoints; 28 | for (int k = 0; k < nsample; ++k) { 29 | int ii = idx[j * nsample + k]; 30 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 31 | } 32 | } 33 | } 34 | 35 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 36 | const float *points, const int *idx, 37 | float *out) { 38 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 39 | 40 | group_points_kernel<<>>( 41 | b, c, n, npoints, nsample, points, idx, out); 42 | 43 | CUDA_CHECK_ERRORS(); 44 | } 45 | 46 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 47 | // output: grad_points(b, c, n) 48 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 49 | int nsample, 50 | const float *__restrict__ grad_out, 51 | const int *__restrict__ idx, 52 | float *__restrict__ grad_points) { 53 | int batch_index = blockIdx.x; 54 | grad_out += batch_index * npoints * nsample * c; 55 | idx += batch_index * npoints * nsample; 56 | grad_points += batch_index * n * c; 57 | 58 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 59 | const int stride = blockDim.y * blockDim.x; 60 | for (int i = index; i < c * npoints; i += stride) { 61 | const int l = i / npoints; 62 | const int j = i % 
npoints; 63 | for (int k = 0; k < nsample; ++k) { 64 | int ii = idx[j * nsample + k]; 65 | atomicAdd(grad_points + l * n + ii, 66 | grad_out[(l * npoints + j) * nsample + k]); 67 | } 68 | } 69 | } 70 | 71 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 72 | int nsample, const float *grad_out, 73 | const int *idx, float *grad_points) { 74 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 75 | 76 | group_points_grad_kernel<<>>( 77 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 78 | 79 | CUDA_CHECK_ERRORS(); 80 | } 81 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "sampling.h" 7 | #include "utils.h" 8 | 9 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 10 | const float *points, const int *idx, 11 | float *out); 12 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 13 | const float *grad_out, const int *idx, 14 | float *grad_points); 15 | 16 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 17 | const float *dataset, float *temp, 18 | int *idxs); 19 | 20 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 21 | CHECK_CONTIGUOUS(points); 22 | CHECK_CONTIGUOUS(idx); 23 | CHECK_IS_FLOAT(points); 24 | CHECK_IS_INT(idx); 25 | 26 | if (points.type().is_cuda()) { 27 | CHECK_CUDA(idx); 28 | } 29 | 30 | at::Tensor output = 31 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 32 | at::device(points.device()).dtype(at::ScalarType::Float)); 33 | 34 | if (points.type().is_cuda()) { 35 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 36 | idx.size(1), 
points.data(), 37 | idx.data(), output.data()); 38 | } else { 39 | AT_CHECK(false, "CPU not supported"); 40 | } 41 | 42 | return output; 43 | } 44 | 45 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, 46 | const int n) { 47 | CHECK_CONTIGUOUS(grad_out); 48 | CHECK_CONTIGUOUS(idx); 49 | CHECK_IS_FLOAT(grad_out); 50 | CHECK_IS_INT(idx); 51 | 52 | if (grad_out.type().is_cuda()) { 53 | CHECK_CUDA(idx); 54 | } 55 | 56 | at::Tensor output = 57 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 58 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 59 | 60 | if (grad_out.type().is_cuda()) { 61 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, 62 | idx.size(1), grad_out.data(), 63 | idx.data(), output.data()); 64 | } else { 65 | AT_CHECK(false, "CPU not supported"); 66 | } 67 | 68 | return output; 69 | } 70 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { 71 | CHECK_CONTIGUOUS(points); 72 | CHECK_IS_FLOAT(points); 73 | 74 | at::Tensor output = 75 | torch::zeros({points.size(0), nsamples}, 76 | at::device(points.device()).dtype(at::ScalarType::Int)); 77 | 78 | at::Tensor tmp = 79 | torch::full({points.size(0), points.size(1)}, 1e10, 80 | at::device(points.device()).dtype(at::ScalarType::Float)); 81 | 82 | if (points.type().is_cuda()) { 83 | furthest_point_sampling_kernel_wrapper( 84 | points.size(0), points.size(1), nsamples, points.data(), 85 | tmp.data(), output.data()); 86 | } else { 87 | AT_CHECK(false, "CPU not supported"); 88 | } 89 | 90 | return output; 91 | } 92 | -------------------------------------------------------------------------------- /loss/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os,sys 4 | sys.path.append('../') 5 | from auction_match import auction_match 6 | import pointnet2.pointnet2_utils as pn2_utils 7 | import math 8 | from knn_cuda import KNN 9 | 10 | class 
Loss(nn.Module): 11 | def __init__(self,radius=1.0): 12 | super(Loss,self).__init__() 13 | self.radius=radius 14 | self.knn_uniform=KNN(k=2,transpose_mode=True) 15 | self.knn_repulsion=KNN(k=20,transpose_mode=True) 16 | def get_emd_loss(self,pred,gt,radius=1.0): 17 | ''' 18 | pred and gt is B N 3 19 | ''' 20 | idx, _ = auction_match(pred.contiguous(), gt.contiguous()) 21 | #gather operation has to be B 3 N 22 | #print(gt.transpose(1,2).shape) 23 | matched_out = pn2_utils.gather_operation(gt.transpose(1, 2).contiguous(), idx) 24 | matched_out = matched_out.transpose(1, 2).contiguous() 25 | dist2 = (pred - matched_out) ** 2 26 | dist2 = dist2.view(dist2.shape[0], -1) # <-- ??? 27 | dist2 = torch.mean(dist2, dim=1, keepdims=True) # B, 28 | dist2 /= radius 29 | return torch.mean(dist2) 30 | def get_uniform_loss(self,pcd,percentage=[0.004,0.006,0.008,0.010,0.012],radius=1.0): 31 | B,N,C=pcd.shape[0],pcd.shape[1],pcd.shape[2] 32 | npoint=int(N*0.05) 33 | loss=0 34 | further_point_idx = pn2_utils.furthest_point_sample(pcd.contiguous(), npoint) 35 | new_xyz = pn2_utils.gather_operation(pcd.permute(0, 2, 1).contiguous(), further_point_idx) # B,C,N 36 | for p in percentage: 37 | nsample=int(N*p) 38 | r=math.sqrt(p*radius) 39 | disk_area=math.pi*(radius**2)/N 40 | 41 | idx=pn2_utils.ball_query(r,nsample,pcd.contiguous(),new_xyz.permute(0,2,1).contiguous()) #b N nsample 42 | 43 | expect_len=math.sqrt(disk_area) 44 | 45 | grouped_pcd=pn2_utils.grouping_operation(pcd.permute(0,2,1).contiguous(),idx)#B C N nsample 46 | grouped_pcd=grouped_pcd.permute(0,2,3,1) #B N nsample C 47 | 48 | grouped_pcd=torch.cat(torch.unbind(grouped_pcd,dim=1),dim=0)#B*N nsample C 49 | 50 | dist,_=self.knn_uniform(grouped_pcd,grouped_pcd) 51 | #print(dist.shape) 52 | uniform_dist=dist[:,:,1:] #B*N nsample 1 53 | uniform_dist=torch.abs(uniform_dist+1e-8) 54 | uniform_dist=torch.mean(uniform_dist,dim=1) 55 | uniform_dist=(uniform_dist-expect_len)**2/(expect_len+1e-8) 56 | 
mean_loss=torch.mean(uniform_dist) 57 | mean_loss=mean_loss*math.pow(p*100,2) 58 | loss+=mean_loss 59 | return loss/len(percentage) 60 | def get_repulsion_loss(self,pcd,h=0.0005): 61 | dist,idx=self.knn_repulsion(pcd,pcd)#B N k 62 | 63 | dist=dist[:,:,1:5]**2 #top 4 cloest neighbors 64 | 65 | loss=torch.clamp(-dist+h,min=0) 66 | loss=torch.mean(loss) 67 | #print(loss) 68 | return loss 69 | def get_discriminator_loss(self,pred_fake,pred_real): 70 | real_loss=torch.mean((pred_real-1)**2) 71 | fake_loss=torch.mean(pred_fake**2) 72 | loss=real_loss+fake_loss 73 | return loss 74 | def get_generator_loss(self,pred_fake): 75 | fake_loss=torch.mean((pred_fake-1)**2) 76 | return fake_loss 77 | def get_discriminator_loss_single(self,pred,label=True): 78 | if label==True: 79 | loss=torch.mean((pred-1)**2) 80 | return loss 81 | else: 82 | loss=torch.mean((pred)**2) 83 | return loss 84 | if __name__=="__main__": 85 | loss=Loss().cuda() 86 | point_cloud=torch.rand(4,4096,3).cuda() 87 | uniform_loss=loss.get_uniform_loss(point_cloud) 88 | repulsion_loss=loss.get_repulsion_loss(point_cloud) 89 | 90 | -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include "interpolate.h" 7 | #include "utils.h" 8 | 9 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 10 | const float *known, float *dist2, int *idx); 11 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 12 | const float *points, const int *idx, 13 | const float *weight, float *out); 14 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 15 | const float *grad_out, 16 | const int *idx, const float *weight, 17 | float *grad_points); 18 | 19 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { 20 | CHECK_CONTIGUOUS(unknowns); 21 | CHECK_CONTIGUOUS(knows); 22 | CHECK_IS_FLOAT(unknowns); 23 | CHECK_IS_FLOAT(knows); 24 | 25 | if (unknowns.type().is_cuda()) { 26 | CHECK_CUDA(knows); 27 | } 28 | 29 | at::Tensor idx = 30 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 31 | at::device(unknowns.device()).dtype(at::ScalarType::Int)); 32 | at::Tensor dist2 = 33 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 34 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 35 | 36 | if (unknowns.type().is_cuda()) { 37 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 38 | unknowns.data(), knows.data(), 39 | dist2.data(), idx.data()); 40 | } else { 41 | AT_CHECK(false, "CPU not supported"); 42 | } 43 | 44 | return {dist2, idx}; 45 | } 46 | 47 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 48 | at::Tensor weight) { 49 | CHECK_CONTIGUOUS(points); 50 | CHECK_CONTIGUOUS(idx); 51 | CHECK_CONTIGUOUS(weight); 52 | CHECK_IS_FLOAT(points); 53 | CHECK_IS_INT(idx); 54 | CHECK_IS_FLOAT(weight); 55 | 56 | if (points.type().is_cuda()) { 57 | CHECK_CUDA(idx); 58 | CHECK_CUDA(weight); 59 | } 60 | 61 | at::Tensor output = 62 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 63 | at::device(points.device()).dtype(at::ScalarType::Float)); 64 | 65 | if (points.type().is_cuda()) { 66 | three_interpolate_kernel_wrapper( 67 | points.size(0), 
points.size(1), points.size(2), idx.size(1), 68 | points.data(), idx.data(), weight.data(), 69 | output.data()); 70 | } else { 71 | AT_CHECK(false, "CPU not supported"); 72 | } 73 | 74 | return output; 75 | } 76 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 77 | at::Tensor weight, const int m) { 78 | CHECK_CONTIGUOUS(grad_out); 79 | CHECK_CONTIGUOUS(idx); 80 | CHECK_CONTIGUOUS(weight); 81 | CHECK_IS_FLOAT(grad_out); 82 | CHECK_IS_INT(idx); 83 | CHECK_IS_FLOAT(weight); 84 | 85 | if (grad_out.type().is_cuda()) { 86 | CHECK_CUDA(idx); 87 | CHECK_CUDA(weight); 88 | } 89 | 90 | at::Tensor output = 91 | torch::zeros({grad_out.size(0), grad_out.size(1), m}, 92 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 93 | 94 | if (grad_out.type().is_cuda()) { 95 | three_interpolate_kernel_wrapper( 96 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m, 97 | grad_out.data(), idx.data(), weight.data(), 98 | output.data()); 99 | } else { 100 | AT_CHECK(false, "CPU not supported"); 101 | } 102 | 103 | return output; 104 | } 105 | -------------------------------------------------------------------------------- /utils/data_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | ''' 3 | from knn_cuda import KNN 4 | 5 | 6 | def knn_point(group_size, point_cloud, query_cloud): 7 | knn_obj = KNN(k=group_size, transpose_mode=False) 8 | dist, idx = knn_obj(point_cloud, query_cloud) 9 | return dist, idx 10 | ''' 11 | 12 | 13 | def nonuniform_sampling(num, sample_num): 14 | sample = set() 15 | loc = np.random.rand() * 0.8 + 0.1 16 | while len(sample) < sample_num: 17 | a = int(np.random.normal(loc=loc, scale=0.3) * num) 18 | if a < 0 or a >= num: 19 | continue 20 | sample.add(a) 21 | return list(sample) 22 | 23 | 24 | def rotate_point_cloud_and_gt(input_data, gt_data=None): 25 | """ Randomly rotate the point clouds to augument the dataset 26 | rotation is per shape based along up 
direction 27 | Input: 28 | Nx3 array, original point cloud 29 | Return: 30 | Nx3 array, rotated point cloud 31 | """ 32 | angles = np.random.uniform(size=(3)) * 2 * np.pi 33 | Rx = np.array([[1, 0, 0], 34 | [0, np.cos(angles[0]), -np.sin(angles[0])], 35 | [0, np.sin(angles[0]), np.cos(angles[0])]]) 36 | Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])], 37 | [0, 1, 0], 38 | [-np.sin(angles[1]), 0, np.cos(angles[1])]]) 39 | Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0], 40 | [np.sin(angles[2]), np.cos(angles[2]), 0], 41 | [0, 0, 1]]) 42 | rotation_matrix = np.dot(Rz, np.dot(Ry, Rx)) 43 | 44 | input_data[:, :3] = np.dot(input_data[:, :3], rotation_matrix) 45 | if input_data.shape[1] > 3: 46 | input_data[:, 3:] = np.dot(input_data[:, 3:], rotation_matrix) 47 | 48 | if gt_data is not None: 49 | gt_data[:, :3] = np.dot(gt_data[:, :3], rotation_matrix) 50 | if gt_data.shape[1] > 3: 51 | gt_data[:, 3:] = np.dot(gt_data[:, 3:], rotation_matrix) 52 | 53 | return input_data, gt_data 54 | 55 | 56 | def random_scale_point_cloud_and_gt(input_data, gt_data=None, scale_low=0.5, scale_high=2): 57 | """ Randomly scale the point cloud. Scale is per point cloud. 58 | Input: 59 | Nx3 array, original point cloud 60 | Return: 61 | Nx3 array, scaled point cloud 62 | """ 63 | scale = np.random.uniform(scale_low, scale_high) 64 | input_data[:, :3] *= scale 65 | if gt_data is not None: 66 | gt_data[:, :3] *= scale 67 | 68 | return input_data, gt_data, scale 69 | 70 | 71 | def shift_point_cloud_and_gt(input_data, gt_data=None, shift_range=0.3): 72 | """ Randomly shift point cloud. Shift is per point cloud. 
73 | Input: 74 | Nx3 array, original point cloud 75 | Return: 76 | Nx3 array, shifted point cloud 77 | """ 78 | shifts = np.random.uniform(-shift_range, shift_range, 3) 79 | input_data[:, :3] += shifts 80 | if gt_data is not None: 81 | gt_data[:, :3] += shifts 82 | return input_data, gt_data 83 | 84 | 85 | def jitter_perturbation_point_cloud(input_data, sigma=0.005, clip=0.02): 86 | """ Randomly jitter points. jittering is per point. 87 | Input: 88 | Nx3 array, original point cloud 89 | Return: 90 | Nx3 array, jittered point cloud 91 | """ 92 | assert (clip > 0) 93 | jitter = np.clip(sigma * np.random.randn(*input_data.shape), -1 * clip, clip) 94 | jitter[:, 3:] = 0 95 | input_data += jitter 96 | return input_data 97 | 98 | 99 | def rotate_perturbation_point_cloud(input_data, angle_sigma=0.03, angle_clip=0.09): 100 | """ Randomly perturb the point clouds by small rotations 101 | Input: 102 | Nx3 array, original point cloud 103 | Return: 104 | Nx3 array, rotated point cloud 105 | """ 106 | angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip) 107 | Rx = np.array([[1, 0, 0], 108 | [0, np.cos(angles[0]), -np.sin(angles[0])], 109 | [0, np.sin(angles[0]), np.cos(angles[0])]]) 110 | Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])], 111 | [0, 1, 0], 112 | [-np.sin(angles[1]), 0, np.cos(angles[1])]]) 113 | Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0], 114 | [np.sin(angles[2]), np.cos(angles[2]), 0], 115 | [0, 0, 1]]) 116 | R = np.dot(Rz, np.dot(Ry, Rx)) 117 | input_data[:, :3] = np.dot(input_data[:, :3], R) 118 | if input_data.shape[1] > 3: 119 | input_data[:, 3:] = np.dot(input_data[:, 3:], R) 120 | return input_data -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import torch.utils.data as data 3 | import os, sys 4 | sys.path.append("../") 5 | import numpy as np 6 | import 
utils.data_util as utils 7 | from torchvision import transforms 8 | 9 | class PUNET_Dataset_Whole(data.Dataset): 10 | def __init__(self, data_dir='../MC_5k',n_input=1024): 11 | super().__init__() 12 | self.raw_input_points=5000 13 | self.n_input=1024 14 | 15 | file_list = os.listdir(data_dir) 16 | self.names = [x.split('.')[0] for x in file_list] 17 | self.sample_path = [os.path.join(data_dir, x) for x in file_list] 18 | 19 | def __len__(self): 20 | return len(self.names) 21 | 22 | def __getitem__(self, index): 23 | random_index=np.random.choice(np.linspace(0,self.raw_input_points,self.raw_input_points,endpoint=False),self.n_input).astype(np.int) 24 | points = np.loadtxt(self.sample_path[index]) 25 | 26 | #centroid=np.mean(points[:,0:3],axis=0) 27 | #dist=np.linalg.norm(points[:,0:3]-centroid,axis=1) 28 | #furthest_dist=np.max(dist) 29 | 30 | #reduced_point=points[random_index][:,0:3] 31 | 32 | #normalized_points=(reduced_point-centroid)/furthest_dist 33 | 34 | return points#normalized_points,furthest_dist,centroid 35 | 36 | class PUNET_Dataset(data.Dataset): 37 | def __init__(self, h5_file_path='../Patches_noHole_and_collected.h5',split_dir='./train_list.txt', 38 | skip_rate=1, npoint=1024, use_random=True, use_norm=True,isTrain=True): 39 | super().__init__() 40 | 41 | self.isTrain=isTrain 42 | 43 | self.npoint = npoint 44 | self.use_random = use_random 45 | self.use_norm = use_norm 46 | 47 | h5_file = h5py.File(h5_file_path) 48 | self.gt = h5_file['poisson_4096'][:] # [:] h5_obj => nparray 49 | self.input = h5_file['poisson_4096'][:] if use_random \ 50 | else h5_file['montecarlo_1024'][:] 51 | assert len(self.input) == len(self.gt), 'invalid data' 52 | self.data_npoint = self.input.shape[1] 53 | 54 | centroid = np.mean(self.gt[..., :3], axis=1, keepdims=True) 55 | furthest_distance = np.amax(np.sqrt(np.sum((self.gt[..., :3] - centroid) ** 2, axis=-1)), axis=1, keepdims=True) 56 | self.radius = furthest_distance[:, 0] # not very sure? 
57 | 58 | if use_norm: 59 | self.radius = np.ones(shape=(len(self.input))) 60 | self.gt[..., :3] -= centroid 61 | self.gt[..., :3] /= np.expand_dims(furthest_distance, axis=-1) 62 | self.input[..., :3] -= centroid 63 | self.input[..., :3] /= np.expand_dims(furthest_distance, axis=-1) 64 | 65 | self.split_dir = split_dir 66 | self.__load_split_file() 67 | 68 | def __load_split_file(self): 69 | index=np.loadtxt(self.split_dir) 70 | index=index.astype(np.int) 71 | print(index) 72 | self.input=self.input[index,:] 73 | self.gt=self.gt[index,:] 74 | self.radius=self.radius[index] 75 | 76 | def __len__(self): 77 | return self.input.shape[0] 78 | 79 | def __getitem__(self, index): 80 | input_data = self.input[index] 81 | gt_data = self.gt[index] 82 | radius_data = np.array([self.radius[index]]) 83 | 84 | sample_idx = utils.nonuniform_sampling(self.data_npoint, sample_num=self.npoint) 85 | input_data = input_data[sample_idx, :] 86 | 87 | if not self.isTrain: 88 | return input_data, gt_data, radius_data 89 | 90 | if self.use_norm: 91 | # for data aug 92 | input_data, gt_data = utils.rotate_point_cloud_and_gt(input_data, gt_data) 93 | input_data, gt_data, scale = utils.random_scale_point_cloud_and_gt(input_data, gt_data, 94 | scale_low=0.9, scale_high=1.1) 95 | input_data, gt_data = utils.shift_point_cloud_and_gt(input_data, gt_data, shift_range=0.1) 96 | radius_data = radius_data * scale 97 | 98 | # for input aug 99 | #if np.random.rand() > 0.5: 100 | # input_data = utils.jitter_perturbation_point_cloud(input_data, sigma=0.025, clip=0.05) 101 | #if np.random.rand() > 0.5: 102 | # input_data = utils.rotate_perturbation_point_cloud(input_data, angle_sigma=0.03, angle_clip=0.09) 103 | else: 104 | raise NotImplementedError 105 | 106 | return input_data, gt_data, radius_data 107 | 108 | if __name__=="__main__": 109 | dataset=PUNET_Dataset() 110 | #(input_data,gt_data,radius_data)=dataset.__getitem__(0) 111 | #print(input_data.shape,gt_data.shape,radius_data.shape) 112 | 
#dataset=PUNET_Dataset_Whole(data_dir="../MC_5k",n_input=1024) 113 | #points=dataset.__getitem__(0) 114 | #print(points.shape) -------------------------------------------------------------------------------- /pointnet2/_ext_src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "cuda_utils.h" 11 | 12 | // input: unknown(b, n, 3) known(b, m, 3) 13 | // output: dist2(b, n, 3), idx(b, n, 3) 14 | __global__ void three_nn_kernel(int b, int n, int m, 15 | const float *__restrict__ unknown, 16 | const float *__restrict__ known, 17 | float *__restrict__ dist2, 18 | int *__restrict__ idx) { 19 | int batch_index = blockIdx.x; 20 | unknown += batch_index * n * 3; 21 | known += batch_index * m * 3; 22 | dist2 += batch_index * n * 3; 23 | idx += batch_index * n * 3; 24 | 25 | int index = threadIdx.x; 26 | int stride = blockDim.x; 27 | for (int j = index; j < n; j += stride) { 28 | float ux = unknown[j * 3 + 0]; 29 | float uy = unknown[j * 3 + 1]; 30 | float uz = unknown[j * 3 + 2]; 31 | 32 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 33 | int besti1 = 0, besti2 = 0, besti3 = 0; 34 | for (int k = 0; k < m; ++k) { 35 | float x = known[k * 3 + 0]; 36 | float y = known[k * 3 + 1]; 37 | float z = known[k * 3 + 2]; 38 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 39 | if (d < best1) { 40 | best3 = best2; 41 | besti3 = besti2; 42 | best2 = best1; 43 | besti2 = besti1; 44 | best1 = d; 45 | besti1 = k; 46 | } else if (d < best2) { 47 | best3 = best2; 48 | besti3 = besti2; 49 | best2 = d; 50 | besti2 = k; 51 | } else if (d < best3) { 52 | best3 = d; 53 | besti3 = k; 54 | } 55 | } 56 | dist2[j * 3 + 0] = best1; 57 | dist2[j * 3 + 1] = best2; 58 | 
dist2[j * 3 + 2] = best3; 59 | 60 | idx[j * 3 + 0] = besti1; 61 | idx[j * 3 + 1] = besti2; 62 | idx[j * 3 + 2] = besti3; 63 | } 64 | } 65 | 66 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 67 | const float *known, float *dist2, int *idx) { 68 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 69 | three_nn_kernel<<>>(b, n, m, unknown, known, 70 | dist2, idx); 71 | 72 | CUDA_CHECK_ERRORS(); 73 | } 74 | 75 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 76 | // output: out(b, c, n) 77 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 78 | const float *__restrict__ points, 79 | const int *__restrict__ idx, 80 | const float *__restrict__ weight, 81 | float *__restrict__ out) { 82 | int batch_index = blockIdx.x; 83 | points += batch_index * m * c; 84 | 85 | idx += batch_index * n * 3; 86 | weight += batch_index * n * 3; 87 | 88 | out += batch_index * n * c; 89 | 90 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 91 | const int stride = blockDim.y * blockDim.x; 92 | for (int i = index; i < c * n; i += stride) { 93 | const int l = i / n; 94 | const int j = i % n; 95 | float w1 = weight[j * 3 + 0]; 96 | float w2 = weight[j * 3 + 1]; 97 | float w3 = weight[j * 3 + 2]; 98 | 99 | int i1 = idx[j * 3 + 0]; 100 | int i2 = idx[j * 3 + 1]; 101 | int i3 = idx[j * 3 + 2]; 102 | 103 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 104 | points[l * m + i3] * w3; 105 | } 106 | } 107 | 108 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 109 | const float *points, const int *idx, 110 | const float *weight, float *out) { 111 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 112 | three_interpolate_kernel<<>>( 113 | b, c, m, n, points, idx, weight, out); 114 | 115 | CUDA_CHECK_ERRORS(); 116 | } 117 | 118 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 119 | // output: grad_points(b, c, m) 120 | 121 | __global__ void three_interpolate_grad_kernel( 122 | int 
b, int c, int n, int m, const float *__restrict__ grad_out, 123 | const int *__restrict__ idx, const float *__restrict__ weight, 124 | float *__restrict__ grad_points) { 125 | int batch_index = blockIdx.x; 126 | grad_out += batch_index * n * c; 127 | idx += batch_index * n * 3; 128 | weight += batch_index * n * 3; 129 | grad_points += batch_index * m * c; 130 | 131 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 132 | const int stride = blockDim.y * blockDim.x; 133 | for (int i = index; i < c * n; i += stride) { 134 | const int l = i / n; 135 | const int j = i % n; 136 | float w1 = weight[j * 3 + 0]; 137 | float w2 = weight[j * 3 + 1]; 138 | float w3 = weight[j * 3 + 2]; 139 | 140 | int i1 = idx[j * 3 + 0]; 141 | int i2 = idx[j * 3 + 1]; 142 | int i3 = idx[j * 3 + 2]; 143 | 144 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 145 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 146 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 147 | } 148 | } 149 | 150 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 151 | const float *grad_out, 152 | const int *idx, const float *weight, 153 | float *grad_points) { 154 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 155 | three_interpolate_grad_kernel<<>>( 156 | b, c, n, m, grad_out, idx, weight, grad_points); 157 | 158 | CUDA_CHECK_ERRORS(); 159 | } 160 | -------------------------------------------------------------------------------- /chamfer_distance/chamfer_distance.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | __global__ 7 | void ChamferDistanceKernel( 8 | int b, 9 | int n, 10 | const float* xyz, 11 | int m, 12 | const float* xyz2, 13 | float* result, 14 | int* result_i) 15 | { 16 | const int batch=512; 17 | __shared__ float buf[batch*3]; 18 | for (int i=blockIdx.x;ibest){ 130 | result[(i*n+j)]=best; 131 | result_i[(i*n+j)]=best_i; 132 | } 133 | } 134 | 
__syncthreads(); 135 | } 136 | } 137 | } 138 | 139 | void ChamferDistanceKernelLauncher( 140 | const int b, const int n, 141 | const float* xyz, 142 | const int m, 143 | const float* xyz2, 144 | float* result, 145 | int* result_i, 146 | float* result2, 147 | int* result2_i) 148 | { 149 | ChamferDistanceKernel<<>>(b, n, xyz, m, xyz2, result, result_i); 150 | ChamferDistanceKernel<<>>(b, m, xyz2, n, xyz, result2, result2_i); 151 | 152 | cudaError_t err = cudaGetLastError(); 153 | if (err != cudaSuccess) 154 | printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err)); 155 | } 156 | 157 | 158 | __global__ 159 | void ChamferDistanceGradKernel( 160 | int b, int n, 161 | const float* xyz1, 162 | int m, 163 | const float* xyz2, 164 | const float* grad_dist1, 165 | const int* idx1, 166 | float* grad_xyz1, 167 | float* grad_xyz2) 168 | { 169 | for (int i = blockIdx.x; i>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2); 204 | ChamferDistanceGradKernel<<>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1); 205 | 206 | cudaError_t err = cudaGetLastError(); 207 | if (err != cudaSuccess) 208 | printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err)); 209 | } 210 | -------------------------------------------------------------------------------- /utils/pc_util.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | sys.path.append('../') 3 | import numpy as np 4 | import math 5 | from functools import reduce 6 | 7 | def euler2mat(z=0, y=0, x=0): 8 | ''' Return matrix for rotations around z, y and x axes 9 | 10 | Uses the z, then y, then x convention above 11 | 12 | Parameters 13 | ---------- 14 | z : scalar 15 | Rotation angle in radians around z-axis (performed first) 16 | y : scalar 17 | Rotation angle in radians around y-axis 18 | x : scalar 19 | Rotation angle in radians around x-axis (performed last) 20 | 21 | Returns 22 | ------- 23 | M : array shape (3,3) 24 | 
Rotation matrix giving same rotation as for given angles 25 | 26 | Examples 27 | -------- 28 | >>> zrot = 1.3 # radians 29 | >>> yrot = -0.1 30 | >>> xrot = 0.2 31 | >>> M = euler2mat(zrot, yrot, xrot) 32 | >>> M.shape == (3, 3) 33 | True 34 | 35 | The output rotation matrix is equal to the composition of the 36 | individual rotations 37 | 38 | >>> M1 = euler2mat(zrot) 39 | >>> M2 = euler2mat(0, yrot) 40 | >>> M3 = euler2mat(0, 0, xrot) 41 | >>> composed_M = np.dot(M3, np.dot(M2, M1)) 42 | >>> np.allclose(M, composed_M) 43 | True 44 | 45 | You can specify rotations by named arguments 46 | 47 | >>> np.all(M3 == euler2mat(x=xrot)) 48 | True 49 | 50 | When applying M to a vector, the vector should column vector to the 51 | right of M. If the right hand side is a 2D array rather than a 52 | vector, then each column of the 2D array represents a vector. 53 | 54 | >>> vec = np.array([1, 0, 0]).reshape((3,1)) 55 | >>> v2 = np.dot(M, vec) 56 | >>> vecs = np.array([[1, 0, 0],[0, 1, 0]]).T # giving 3x2 array 57 | >>> vecs2 = np.dot(M, vecs) 58 | 59 | Rotations are counter-clockwise. 60 | 61 | >>> zred = np.dot(euler2mat(z=np.pi/2), np.eye(3)) 62 | >>> np.allclose(zred, [[0, -1, 0],[1, 0, 0], [0, 0, 1]]) 63 | True 64 | >>> yred = np.dot(euler2mat(y=np.pi/2), np.eye(3)) 65 | >>> np.allclose(yred, [[0, 0, 1],[0, 1, 0], [-1, 0, 0]]) 66 | True 67 | >>> xred = np.dot(euler2mat(x=np.pi/2), np.eye(3)) 68 | >>> np.allclose(xred, [[1, 0, 0],[0, 0, -1], [0, 1, 0]]) 69 | True 70 | 71 | Notes 72 | ----- 73 | The direction of rotation is given by the right-hand rule (orient 74 | the thumb of the right hand along the axis around which the rotation 75 | occurs, with the end of the thumb at the positive end of the axis; 76 | curl your fingers; the direction your fingers curl is the direction 77 | of rotation). Therefore, the rotations are counterclockwise if 78 | looking along the axis of rotation from positive to negative. 
79 | ''' 80 | Ms = [] 81 | if z: 82 | cosz = math.cos(z) 83 | sinz = math.sin(z) 84 | Ms.append(np.array( 85 | [[cosz, -sinz, 0], 86 | [sinz, cosz, 0], 87 | [0, 0, 1]])) 88 | if y: 89 | cosy = math.cos(y) 90 | siny = math.sin(y) 91 | Ms.append(np.array( 92 | [[cosy, 0, siny], 93 | [0, 1, 0], 94 | [-siny, 0, cosy]])) 95 | if x: 96 | cosx = math.cos(x) 97 | sinx = math.sin(x) 98 | Ms.append(np.array( 99 | [[1, 0, 0], 100 | [0, cosx, -sinx], 101 | [0, sinx, cosx]])) 102 | if Ms: 103 | return reduce(np.dot, Ms[::-1]) 104 | return np.eye(3) 105 | 106 | def draw_point_cloud(input_points, canvasSize=1000, space=480, diameter=10, 107 | xrot=0, yrot=0, zrot=0, switch_xyz=[0,1,2], normalize=True): 108 | """ Render point cloud to image with alpha channel. 109 | Input: 110 | points: Nx3 numpy array (+y is up direction) 111 | Output: 112 | gray image as numpy array of size canvasSizexcanvasSize 113 | """ 114 | canvasSizeX = canvasSize 115 | canvasSizeY = canvasSize 116 | 117 | image = np.zeros((canvasSizeX, canvasSizeY)) 118 | if input_points is None or input_points.shape[0] == 0: 119 | return image 120 | 121 | points = input_points[:, switch_xyz] 122 | M = euler2mat(zrot, yrot, xrot) 123 | points = (np.dot(M, points.transpose())).transpose() 124 | 125 | # Normalize the point cloud 126 | # We normalize scale to fit points in a unit sphere 127 | if normalize: 128 | centroid = np.mean(points, axis=0) 129 | points -= centroid 130 | furthest_distance = np.max(np.sqrt(np.sum(abs(points)**2,axis=-1))) 131 | points /= furthest_distance 132 | 133 | # Pre-compute the Gaussian disk 134 | radius = (diameter-1)/2.0 135 | disk = np.zeros((diameter, diameter)) 136 | for i in range(diameter): 137 | for j in range(diameter): 138 | if (i - radius) * (i-radius) + (j-radius) * (j-radius) <= radius * radius: 139 | disk[i, j] = np.exp((-(i-radius)**2 - (j-radius)**2)/(radius**2)) 140 | mask = np.argwhere(disk > 0) 141 | dx = mask[:, 0] 142 | dy = mask[:, 1] 143 | dv = disk[disk > 0] 144 | 145 | # 
Order points by z-buffer 146 | zorder = np.argsort(points[:, 2]) 147 | points = points[zorder, :] 148 | points[:, 2] = (points[:, 2] - np.min(points[:, 2])) / (np.max(points[:, 2] - np.min(points[:, 2]))) 149 | max_depth = np.max(points[:, 2]) 150 | 151 | for i in range(points.shape[0]): 152 | j = points.shape[0] - i - 1 153 | x = points[j, 0] 154 | y = points[j, 1] 155 | xc = canvasSizeX/2 + (x*space) 156 | yc = canvasSizeY/2 + (y*space) 157 | xc = int(np.round(xc)) 158 | yc = int(np.round(yc)) 159 | 160 | px = dx + xc 161 | py = dy + yc 162 | #image[px, py] = image[px, py] * 0.7 + dv * (max_depth - points[j, 2]) * 0.3 163 | image[px, py] = image[px, py] * 0.7 + dv * 0.3 164 | 165 | val = np.max(image)+1e-8 166 | val = np.percentile(image,99.9) 167 | image = image / val 168 | mask = image==0 169 | 170 | image[image>1.0]=1.0 171 | image = 1.0-image 172 | #image = np.expand_dims(image, axis=-1) 173 | #image = np.concatenate((image*0.3+0.7,np.ones_like(image), np.ones_like(image)), axis=2) 174 | #image = colors.hsv_to_rgb(image) 175 | image[mask]=1.0 176 | 177 | 178 | return image 179 | 180 | if __name__=="__main__": 181 | import cv2 182 | data_dir = '/data2/haolin/PUGAN_pytorch/outputs/full_bs12_0508/camel.xyz' 183 | data = np.loadtxt(data_dir) 184 | img = draw_point_cloud(data, zrot=90 / 180.0 * np.pi, xrot=90 / 180.0 * np.pi, yrot=0 / 180.0 * np.pi, diameter=4) 185 | 186 | print(img.shape) 187 | cv2.imwrite('./test_img.jpg', (img * 255).astype(np.uint8)) -------------------------------------------------------------------------------- /train/train_recon.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 4 | sys.path.append('../') 5 | import torch 6 | from network.networks import Generator, Discriminator,Generator_recon 7 | from data.data_loader import PUNET_Dataset 8 | import argparse 9 | import time 10 | from option.train_option import get_train_options 11 | from 
utils.Logger import Logger 12 | from torch.utils import data 13 | from torch.optim import Adam 14 | from torch.optim.lr_scheduler import MultiStepLR 15 | from loss.loss import Loss 16 | import datetime 17 | import torch.nn as nn 18 | from utils.visualize_utils import visualize_point_cloud 19 | import numpy as np 20 | 21 | 22 | def xavier_init(m): 23 | classname = m.__class__.__name__ 24 | # print(classname) 25 | if classname.find('Conv') != -1: 26 | nn.init.xavier_normal(m.weight) 27 | elif classname.find('Linear') != -1: 28 | nn.init.xavier_normal(m.weight) 29 | elif classname.find('BatchNorm') != -1: 30 | nn.init.constant_(m.weight, 1) 31 | nn.init.constant_(m.bias, 0) 32 | 33 | 34 | def train(args): 35 | start_t = time.time() 36 | params = get_train_options() 37 | params["exp_name"] = args.exp_name 38 | params["patch_num_point"] = 1024 39 | params["batch_size"] = args.batch_size 40 | params['use_gan'] = args.use_gan 41 | 42 | if args.debug: 43 | params["nepoch"] = 2 44 | params["model_save_interval"] = 3 45 | params['model_vis_interval'] = 3 46 | 47 | log_dir = os.path.join(params["model_save_dir"], args.exp_name) 48 | if os.path.exists(log_dir) == False: 49 | os.makedirs(log_dir) 50 | tb_logger = Logger(log_dir) 51 | 52 | trainloader = PUNET_Dataset(h5_file_path=params["dataset_dir"]) 53 | # print(params["dataset_dir"]) 54 | num_workers = 4 55 | train_data_loader = data.DataLoader(dataset=trainloader, batch_size=params["batch_size"], shuffle=True, 56 | num_workers=num_workers, pin_memory=True, drop_last=True) 57 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 58 | 59 | G_model = Generator_recon(params) 60 | G_model.apply(xavier_init) 61 | G_model = torch.nn.DataParallel(G_model).to(device) 62 | D_model = torch.nn.DataParallel(Discriminator(params, in_channels=3)).to(device) 63 | 64 | G_model.train() 65 | D_model.train() 66 | 67 | optimizer_D = Adam(D_model.parameters(), lr=params["lr_D"], betas=(0.9, 0.999)) 68 | optimizer_G = 
Adam(G_model.parameters(), lr=params["lr_G"], betas=(0.9, 0.999)) 69 | 70 | D_scheduler = MultiStepLR(optimizer_D, [50, 80], gamma=0.2) 71 | G_scheduler = MultiStepLR(optimizer_G, [50, 80], gamma=0.2) 72 | 73 | Loss_fn = Loss() 74 | 75 | print("preparation time is %fs" % (time.time() - start_t)) 76 | iter = 0 77 | for e in range(params["nepoch"]): 78 | D_scheduler.step() 79 | G_scheduler.step() 80 | for batch_id, (input_data, gt_data, radius_data) in enumerate(train_data_loader): 81 | optimizer_G.zero_grad() 82 | optimizer_D.zero_grad() 83 | 84 | input_data = input_data[:, :, 0:3].permute(0, 2, 1).float().cuda() 85 | gt_data = gt_data[:, :, 0:3].permute(0, 2, 1).float().cuda() 86 | 87 | start_t_batch = time.time() 88 | output_point_cloud = G_model(input_data) 89 | 90 | emd_loss = Loss_fn.get_emd_loss(output_point_cloud.permute(0, 2, 1), input_data.permute(0, 2, 1)) 91 | 92 | total_G_loss=emd_loss 93 | total_G_loss.backward() 94 | optimizer_G.step() 95 | 96 | current_lr_D = optimizer_D.state_dict()['param_groups'][0]['lr'] 97 | current_lr_G = optimizer_G.state_dict()['param_groups'][0]['lr'] 98 | 99 | tb_logger.scalar_summary('emd_loss', emd_loss.item(), iter) 100 | tb_logger.scalar_summary('lr_D', current_lr_D, iter) 101 | tb_logger.scalar_summary('lr_G', current_lr_G, iter) 102 | 103 | msg = "{:0>8},{}:{}, [{}/{}], {}: {},{}:{}".format( 104 | str(datetime.timedelta(seconds=round(time.time() - start_t))), 105 | "epoch", 106 | e, 107 | batch_id + 1, 108 | len(train_data_loader), 109 | "total_G_loss", 110 | total_G_loss.item(), 111 | "iter time", 112 | (time.time() - start_t_batch) 113 | ) 114 | print(msg) 115 | 116 | if iter % params['model_save_interval'] == 0 and iter > 0: 117 | model_save_dir = os.path.join(params['model_save_dir'], params['exp_name']) 118 | if os.path.exists(model_save_dir) == False: 119 | os.makedirs(model_save_dir) 120 | D_ckpt_model_filename = "D_iter_%d.pth" % (iter) 121 | G_ckpt_model_filename = "G_iter_%d.pth" % (iter) 122 | 
D_model_save_path = os.path.join(model_save_dir, D_ckpt_model_filename) 123 | G_model_save_path = os.path.join(model_save_dir, G_ckpt_model_filename) 124 | torch.save(D_model.module.state_dict(), D_model_save_path) 125 | torch.save(G_model.module.state_dict(), G_model_save_path) 126 | 127 | if iter % params['model_vis_interval'] == 0 and iter > 0: 128 | np_pcd = output_point_cloud.permute(0, 2, 1)[0].detach().cpu().numpy() 129 | # print(np_pcd.shape) 130 | img = (np.array(visualize_point_cloud(np_pcd)) * 255).astype(np.uint8) 131 | tb_logger.image_summary("images", img[np.newaxis, :], iter) 132 | 133 | gt_pcd = gt_data.permute(0, 2, 1)[0].detach().cpu().numpy() 134 | # print(gt_pcd.shape) 135 | gt_img = (np.array(visualize_point_cloud(gt_pcd)) * 255).astype(np.uint8) 136 | tb_logger.image_summary("gt", gt_img[np.newaxis, :], iter) 137 | 138 | input_pcd = input_data.permute(0, 2, 1)[0].detach().cpu().numpy() 139 | input_img = (np.array(visualize_point_cloud(input_pcd)) * 255).astype(np.uint8) 140 | tb_logger.image_summary("input", input_img[np.newaxis, :], iter) 141 | iter += 1 142 | 143 | 144 | if __name__ == "__main__": 145 | import colored_traceback 146 | 147 | parser = argparse.ArgumentParser() 148 | parser.add_argument('--exp_name', '-e', type=str, required=True, help='experiment name') 149 | parser.add_argument('--debug', action='store_true', help='specify debug mode') 150 | parser.add_argument('--use_gan', action='store_true') 151 | parser.add_argument('--batch_size', type=int, default=16) 152 | 153 | args = parser.parse_args() 154 | train(args) -------------------------------------------------------------------------------- /chamfer_distance/chamfer_distance.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | // CUDA forward declarations 4 | int ChamferDistanceKernelLauncher( 5 | const int b, const int n, 6 | const float* xyz, 7 | const int m, 8 | const float* xyz2, 9 | float* result, 10 | int* result_i, 
11 | float* result2, 12 | int* result2_i); 13 | 14 | int ChamferDistanceGradKernelLauncher( 15 | const int b, const int n, 16 | const float* xyz1, 17 | const int m, 18 | const float* xyz2, 19 | const float* grad_dist1, 20 | const int* idx1, 21 | const float* grad_dist2, 22 | const int* idx2, 23 | float* grad_xyz1, 24 | float* grad_xyz2); 25 | 26 | 27 | void chamfer_distance_forward_cuda( 28 | const at::Tensor xyz1, 29 | const at::Tensor xyz2, 30 | const at::Tensor dist1, 31 | const at::Tensor dist2, 32 | const at::Tensor idx1, 33 | const at::Tensor idx2) 34 | { 35 | ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(), 36 | xyz2.size(1), xyz2.data(), 37 | dist1.data(), idx1.data(), 38 | dist2.data(), idx2.data()); 39 | } 40 | 41 | void chamfer_distance_backward_cuda( 42 | const at::Tensor xyz1, 43 | const at::Tensor xyz2, 44 | at::Tensor gradxyz1, 45 | at::Tensor gradxyz2, 46 | at::Tensor graddist1, 47 | at::Tensor graddist2, 48 | at::Tensor idx1, 49 | at::Tensor idx2) 50 | { 51 | ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(), 52 | xyz2.size(1), xyz2.data(), 53 | graddist1.data(), idx1.data(), 54 | graddist2.data(), idx2.data(), 55 | gradxyz1.data(), gradxyz2.data()); 56 | } 57 | 58 | 59 | void nnsearch( 60 | const int b, const int n, const int m, 61 | const float* xyz1, 62 | const float* xyz2, 63 | float* dist, 64 | int* idx) 65 | { 66 | for (int i = 0; i < b; i++) { 67 | for (int j = 0; j < n; j++) { 68 | const float x1 = xyz1[(i*n+j)*3+0]; 69 | const float y1 = xyz1[(i*n+j)*3+1]; 70 | const float z1 = xyz1[(i*n+j)*3+2]; 71 | double best = 0; 72 | int besti = 0; 73 | for (int k = 0; k < m; k++) { 74 | const float x2 = xyz2[(i*m+k)*3+0] - x1; 75 | const float y2 = xyz2[(i*m+k)*3+1] - y1; 76 | const float z2 = xyz2[(i*m+k)*3+2] - z1; 77 | const double d=x2*x2+y2*y2+z2*z2; 78 | if (k==0 || d < best){ 79 | best = d; 80 | besti = k; 81 | } 82 | } 83 | dist[i*n+j] = best; 84 | idx[i*n+j] = besti; 85 | } 86 | } 87 | } 88 | 89 
| 90 | void chamfer_distance_forward( 91 | const at::Tensor xyz1, 92 | const at::Tensor xyz2, 93 | const at::Tensor dist1, 94 | const at::Tensor dist2, 95 | const at::Tensor idx1, 96 | const at::Tensor idx2) 97 | { 98 | const int batchsize = xyz1.size(0); 99 | const int n = xyz1.size(1); 100 | const int m = xyz2.size(1); 101 | 102 | const float* xyz1_data = xyz1.data(); 103 | const float* xyz2_data = xyz2.data(); 104 | float* dist1_data = dist1.data(); 105 | float* dist2_data = dist2.data(); 106 | int* idx1_data = idx1.data(); 107 | int* idx2_data = idx2.data(); 108 | 109 | nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data); 110 | nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data); 111 | } 112 | 113 | 114 | void chamfer_distance_backward( 115 | const at::Tensor xyz1, 116 | const at::Tensor xyz2, 117 | at::Tensor gradxyz1, 118 | at::Tensor gradxyz2, 119 | at::Tensor graddist1, 120 | at::Tensor graddist2, 121 | at::Tensor idx1, 122 | at::Tensor idx2) 123 | { 124 | const int b = xyz1.size(0); 125 | const int n = xyz1.size(1); 126 | const int m = xyz2.size(1); 127 | 128 | const float* xyz1_data = xyz1.data(); 129 | const float* xyz2_data = xyz2.data(); 130 | float* gradxyz1_data = gradxyz1.data(); 131 | float* gradxyz2_data = gradxyz2.data(); 132 | float* graddist1_data = graddist1.data(); 133 | float* graddist2_data = graddist2.data(); 134 | const int* idx1_data = idx1.data(); 135 | const int* idx2_data = idx2.data(); 136 | 137 | for (int i = 0; i < b*n*3; i++) 138 | gradxyz1_data[i] = 0; 139 | for (int i = 0; i < b*m*3; i++) 140 | gradxyz2_data[i] = 0; 141 | for (int i = 0;i < b; i++) { 142 | for (int j = 0; j < n; j++) { 143 | const float x1 = xyz1_data[(i*n+j)*3+0]; 144 | const float y1 = xyz1_data[(i*n+j)*3+1]; 145 | const float z1 = xyz1_data[(i*n+j)*3+2]; 146 | const int j2 = idx1_data[i*n+j]; 147 | 148 | const float x2 = xyz2_data[(i*m+j2)*3+0]; 149 | const float y2 = xyz2_data[(i*m+j2)*3+1]; 150 | const float z2 = 
xyz2_data[(i*m+j2)*3+2]; 151 | const float g = graddist1_data[i*n+j]*2; 152 | 153 | gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2); 154 | gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2); 155 | gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2); 156 | gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2)); 157 | gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2)); 158 | gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2)); 159 | } 160 | for (int j = 0; j < m; j++) { 161 | const float x1 = xyz2_data[(i*m+j)*3+0]; 162 | const float y1 = xyz2_data[(i*m+j)*3+1]; 163 | const float z1 = xyz2_data[(i*m+j)*3+2]; 164 | const int j2 = idx2_data[i*m+j]; 165 | const float x2 = xyz1_data[(i*n+j2)*3+0]; 166 | const float y2 = xyz1_data[(i*n+j2)*3+1]; 167 | const float z2 = xyz1_data[(i*n+j2)*3+2]; 168 | const float g = graddist2_data[i*m+j]*2; 169 | gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2); 170 | gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2); 171 | gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2); 172 | gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2)); 173 | gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2)); 174 | gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2)); 175 | } 176 | } 177 | } 178 | 179 | 180 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 181 | m.def("forward", &chamfer_distance_forward, "ChamferDistance forward"); 182 | m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)"); 183 | m.def("backward", &chamfer_distance_backward, "ChamferDistance backward"); 184 | m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)"); 185 | } 186 | -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import argparse 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('--exp_name', '-e', type=str, required=True, help='experiment name') 5 | parser.add_argument('--debug', action='store_true', help='specify debug mode') 6 | 
parser.add_argument('--use_gan',action='store_true') 7 | parser.add_argument('--batch_size',type=int,default=16) 8 | parser.add_argument('--gpu',type=str,default='0') 9 | 10 | args = parser.parse_args() 11 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu 12 | sys.path.append('../') 13 | import torch 14 | from network.networks import Generator,Discriminator 15 | from data.data_loader import PUNET_Dataset 16 | import time 17 | from option.train_option import get_train_options 18 | from utils.Logger import Logger 19 | from torch.utils import data 20 | from torch.optim import Adam 21 | from torch.optim.lr_scheduler import MultiStepLR 22 | from loss.loss import Loss 23 | import datetime 24 | import torch.nn as nn 25 | 26 | def xavier_init(m): 27 | classname = m.__class__.__name__ 28 | #print(classname) 29 | if classname.find('Conv') != -1: 30 | nn.init.xavier_normal(m.weight) 31 | elif classname.find('Linear')!=-1: 32 | nn.init.xavier_normal(m.weight) 33 | elif classname.find('BatchNorm') != -1: 34 | nn.init.constant_(m.weight, 1) 35 | nn.init.constant_(m.bias, 0) 36 | 37 | def train(args): 38 | start_t=time.time() 39 | params=get_train_options() 40 | params["exp_name"]=args.exp_name 41 | params["patch_num_point"]=1024 42 | params["batch_size"]=args.batch_size 43 | params['use_gan']=args.use_gan 44 | 45 | if args.debug: 46 | params["nepoch"]=2 47 | params["model_save_interval"]=3 48 | params['model_vis_interval']=3 49 | 50 | log_dir=os.path.join(params["model_save_dir"],args.exp_name) 51 | if os.path.exists(log_dir)==False: 52 | os.makedirs(log_dir) 53 | tb_logger=Logger(log_dir) 54 | 55 | trainloader=PUNET_Dataset(h5_file_path=params["dataset_dir"],split_dir=params['train_split']) 56 | #print(params["dataset_dir"]) 57 | num_workers=4 58 | train_data_loader=data.DataLoader(dataset=trainloader,batch_size=params["batch_size"],shuffle=True, 59 | num_workers=num_workers,pin_memory=True,drop_last=True) 60 | device=torch.device('cuda'if torch.cuda.is_available() else 'cpu') 61 | 
62 | G_model=Generator(params) 63 | G_model.apply(xavier_init) 64 | G_model=torch.nn.DataParallel(G_model).to(device) 65 | D_model=Discriminator(params,in_channels=3) 66 | D_model.apply(xavier_init) 67 | D_model=torch.nn.DataParallel(D_model).to(device) 68 | 69 | G_model.train() 70 | D_model.train() 71 | 72 | optimizer_D=Adam(D_model.parameters(),lr=params["lr_D"],betas=(0.9,0.999)) 73 | optimizer_G=Adam(G_model.parameters(),lr=params["lr_G"],betas=(0.9,0.999)) 74 | 75 | D_scheduler = MultiStepLR(optimizer_D,[50,80],gamma=0.2) 76 | G_scheduler = MultiStepLR(optimizer_G,[50,80],gamma=0.2) 77 | 78 | Loss_fn=Loss() 79 | 80 | print("preparation time is %fs" % (time.time() - start_t)) 81 | iter=0 82 | for e in range(params["nepoch"]): 83 | D_scheduler.step() 84 | G_scheduler.step() 85 | for batch_id,(input_data, gt_data, radius_data) in enumerate(train_data_loader): 86 | optimizer_G.zero_grad() 87 | optimizer_D.zero_grad() 88 | 89 | input_data=input_data[:,:,0:3].permute(0,2,1).float().cuda() 90 | gt_data=gt_data[:,:,0:3].permute(0,2,1).float().cuda() 91 | 92 | start_t_batch=time.time() 93 | output_point_cloud=G_model(input_data) 94 | 95 | repulsion_loss = Loss_fn.get_repulsion_loss(output_point_cloud.permute(0, 2, 1)) 96 | uniform_loss = Loss_fn.get_uniform_loss(output_point_cloud.permute(0, 2, 1)) 97 | #print(output_point_cloud.shape,gt_data.shape) 98 | emd_loss = Loss_fn.get_emd_loss(output_point_cloud.permute(0, 2, 1), gt_data.permute(0, 2, 1)) 99 | 100 | if params['use_gan']==True: 101 | fake_pred = D_model(output_point_cloud.detach()) 102 | d_loss_fake = Loss_fn.get_discriminator_loss_single(fake_pred,label=False) 103 | d_loss_fake.backward() 104 | optimizer_D.step() 105 | 106 | real_pred = D_model(gt_data.detach()) 107 | d_loss_real = Loss_fn.get_discriminator_loss_single(real_pred, label=True) 108 | d_loss_real.backward() 109 | optimizer_D.step() 110 | 111 | d_loss=d_loss_real+d_loss_fake 112 | 113 | fake_pred=D_model(output_point_cloud) 114 | 
g_loss=Loss_fn.get_generator_loss(fake_pred) 115 | 116 | #print(repulsion_loss,uniform_loss,emd_loss) 117 | total_G_loss=params['uniform_w']*uniform_loss+params['emd_w']*emd_loss+ \ 118 | repulsion_loss*params['repulsion_w']+ g_loss*params['gan_w'] 119 | else: 120 | #total_G_loss = params['uniform_w'] * uniform_loss + params['emd_w'] * emd_loss + \ 121 | # repulsion_loss * params['repulsion_w'] 122 | total_G_loss=params['emd_w'] * emd_loss + \ 123 | repulsion_loss * params['repulsion_w'] 124 | 125 | #total_G_loss=emd_loss 126 | total_G_loss.backward() 127 | optimizer_G.step() 128 | 129 | current_lr_D=optimizer_D.state_dict()['param_groups'][0]['lr'] 130 | current_lr_G=optimizer_G.state_dict()['param_groups'][0]['lr'] 131 | 132 | tb_logger.scalar_summary('repulsion_loss', repulsion_loss.item(), iter) 133 | tb_logger.scalar_summary('uniform_loss', uniform_loss.item(), iter) 134 | tb_logger.scalar_summary('emd_loss', emd_loss.item(), iter) 135 | if params['use_gan']==True: 136 | tb_logger.scalar_summary('d_loss', d_loss.item(), iter) 137 | tb_logger.scalar_summary('g_loss', g_loss.item(), iter) 138 | tb_logger.scalar_summary('lr_D', current_lr_D, iter) 139 | tb_logger.scalar_summary('lr_G', current_lr_G, iter) 140 | 141 | msg="{:0>8},{}:{}, [{}/{}], {}: {},{}:{}".format( 142 | str(datetime.timedelta(seconds=round(time.time() - start_t))), 143 | "epoch", 144 | e, 145 | batch_id + 1, 146 | len(train_data_loader), 147 | "total_G_loss", 148 | total_G_loss.item(), 149 | "iter time", 150 | (time.time() - start_t_batch) 151 | ) 152 | print(msg) 153 | 154 | iter+=1 155 | if (e+1) % params['model_save_interval'] == 0 and e > 0: 156 | model_save_dir = os.path.join(params['model_save_dir'], params['exp_name']) 157 | if os.path.exists(model_save_dir) == False: 158 | os.makedirs(model_save_dir) 159 | D_ckpt_model_filename = "D_iter_%d.pth" % (e) 160 | G_ckpt_model_filename = "G_iter_%d.pth" % (e) 161 | D_model_save_path = os.path.join(model_save_dir, D_ckpt_model_filename) 162 | 
G_model_save_path = os.path.join(model_save_dir, G_ckpt_model_filename) 163 | torch.save(D_model.module.state_dict(), D_model_save_path) 164 | torch.save(G_model.module.state_dict(), G_model_save_path) 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | if __name__=="__main__": 175 | import colored_traceback 176 | train(args) -------------------------------------------------------------------------------- /train/train_emd_only.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "5" 4 | sys.path.append('../') 5 | import torch 6 | from network.networks import Generator, Discriminator 7 | from data.data_loader import PUNET_Dataset 8 | import argparse 9 | import time 10 | from option.train_option import get_train_options 11 | from utils.Logger import Logger 12 | from torch.utils import data 13 | from torch.optim import Adam 14 | from torch.optim.lr_scheduler import MultiStepLR 15 | from loss.loss import Loss 16 | import datetime 17 | import torch.nn as nn 18 | #from utils.visualize_utils import visualize_point_cloud 19 | import numpy as np 20 | 21 | 22 | def xavier_init(m): 23 | classname = m.__class__.__name__ 24 | # print(classname) 25 | if classname.find('Conv') != -1: 26 | nn.init.xavier_normal(m.weight) 27 | elif classname.find('Linear') != -1: 28 | nn.init.xavier_normal(m.weight) 29 | elif classname.find('BatchNorm') != -1: 30 | nn.init.constant_(m.weight, 1) 31 | nn.init.constant_(m.bias, 0) 32 | 33 | 34 | def train(args): 35 | start_t = time.time() 36 | params = get_train_options() 37 | params["exp_name"] = args.exp_name 38 | params["patch_num_point"] = 1024 39 | params["batch_size"] = args.batch_size 40 | params['use_gan'] = args.use_gan 41 | 42 | if args.debug: 43 | params["nepoch"] = 2 44 | params["model_save_interval"] = 3 45 | params['model_vis_interval'] = 3 46 | 47 | log_dir = os.path.join(params["model_save_dir"], args.exp_name) 48 | if 
os.path.exists(log_dir) == False: 49 | os.makedirs(log_dir) 50 | tb_logger = Logger(log_dir) 51 | 52 | trainloader = PUNET_Dataset(h5_file_path=params["dataset_dir"],split_dir=params['train_split']) 53 | # print(params["dataset_dir"]) 54 | num_workers = 4 55 | train_data_loader = data.DataLoader(dataset=trainloader, batch_size=params["batch_size"], shuffle=True, 56 | num_workers=num_workers, pin_memory=True, drop_last=True) 57 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 58 | 59 | G_model = Generator(params) 60 | G_model.apply(xavier_init) 61 | G_model = torch.nn.DataParallel(G_model).to(device) 62 | D_model = torch.nn.DataParallel(Discriminator(params, in_channels=3)).to(device) 63 | 64 | G_model.train() 65 | D_model.train() 66 | 67 | optimizer_D = Adam(D_model.parameters(), lr=params["lr_D"], betas=(0.9, 0.999)) 68 | optimizer_G = Adam(G_model.parameters(), lr=params["lr_G"], betas=(0.9, 0.999)) 69 | 70 | D_scheduler = MultiStepLR(optimizer_D, [50, 80], gamma=0.2) 71 | G_scheduler = MultiStepLR(optimizer_G, [50, 80], gamma=0.2) 72 | 73 | Loss_fn = Loss() 74 | 75 | print("preparation time is %fs" % (time.time() - start_t)) 76 | iter = 0 77 | for e in range(params["nepoch"]): 78 | D_scheduler.step() 79 | G_scheduler.step() 80 | for batch_id, (input_data, gt_data, radius_data) in enumerate(train_data_loader): 81 | optimizer_G.zero_grad() 82 | optimizer_D.zero_grad() 83 | 84 | input_data = input_data[:, :, 0:3].permute(0, 2, 1).float().cuda() 85 | gt_data = gt_data[:, :, 0:3].permute(0, 2, 1).float().cuda() 86 | 87 | start_t_batch = time.time() 88 | output_point_cloud = G_model(input_data) 89 | 90 | emd_loss = Loss_fn.get_emd_loss(output_point_cloud.permute(0, 2, 1), gt_data.permute(0, 2, 1)) 91 | 92 | if params['use_gan']==True: 93 | fake_pred = D_model(output_point_cloud.detach()) 94 | d_loss_fake = Loss_fn.get_discriminator_loss_single(fake_pred,label=False) 95 | d_loss_fake.backward() 96 | optimizer_D.step() 97 | 98 | real_pred = 
D_model(gt_data.detach()) 99 | d_loss_real = Loss_fn.get_discriminator_loss_single(real_pred, label=True) 100 | d_loss_real.backward() 101 | optimizer_D.step() 102 | 103 | d_loss=d_loss_real+d_loss_fake 104 | 105 | fake_pred=D_model(output_point_cloud) 106 | g_loss=Loss_fn.get_generator_loss(fake_pred) 107 | 108 | total_G_loss=params['emd_w']*emd_loss + g_loss*params['gan_w'] 109 | else: 110 | total_G_loss=params['emd_w']*emd_loss 111 | 112 | total_G_loss.backward() 113 | optimizer_G.step() 114 | 115 | current_lr_D = optimizer_D.state_dict()['param_groups'][0]['lr'] 116 | current_lr_G = optimizer_G.state_dict()['param_groups'][0]['lr'] 117 | 118 | tb_logger.scalar_summary('emd_loss', emd_loss.item(), iter) 119 | tb_logger.scalar_summary('lr_D', current_lr_D, iter) 120 | tb_logger.scalar_summary('lr_G', current_lr_G, iter) 121 | if params['use_gan']==True: 122 | tb_logger.scalar_summary('d_loss', d_loss.item(), iter) 123 | tb_logger.scalar_summary('g_loss', g_loss.item(), iter) 124 | 125 | msg = "{:0>8},{}:{}, [{}/{}], {}: {},{}:{}".format( 126 | str(datetime.timedelta(seconds=round(time.time() - start_t))), 127 | "epoch", 128 | e, 129 | batch_id + 1, 130 | len(train_data_loader), 131 | "total_G_loss", 132 | total_G_loss.item(), 133 | "iter time", 134 | (time.time() - start_t_batch) 135 | ) 136 | print(msg) 137 | ''' 138 | if iter % params['model_vis_interval'] == 0 and iter > 0: 139 | np_pcd = output_point_cloud.permute(0, 2, 1)[0].detach().cpu().numpy() 140 | # print(np_pcd.shape) 141 | img = (np.array(visualize_point_cloud(np_pcd)) * 255).astype(np.uint8) 142 | tb_logger.image_summary("images", img[np.newaxis, :], iter) 143 | 144 | gt_pcd = gt_data.permute(0, 2, 1)[0].detach().cpu().numpy() 145 | # print(gt_pcd.shape) 146 | gt_img = (np.array(visualize_point_cloud(gt_pcd)) * 255).astype(np.uint8) 147 | tb_logger.image_summary("gt", gt_img[np.newaxis, :], iter) 148 | 149 | input_pcd = input_data.permute(0, 2, 1)[0].detach().cpu().numpy() 150 | input_img = 
(np.array(visualize_point_cloud(input_pcd)) * 255).astype(np.uint8) 151 | tb_logger.image_summary("input", input_img[np.newaxis, :], iter) 152 | ''' 153 | iter += 1 154 | if (e+1) % params['model_save_interval'] == 0 and e > 0: 155 | model_save_dir = os.path.join(params['model_save_dir'], params['exp_name']) 156 | if os.path.exists(model_save_dir) == False: 157 | os.makedirs(model_save_dir) 158 | D_ckpt_model_filename = "D_iter_%d.pth" % (e) 159 | G_ckpt_model_filename = "G_iter_%d.pth" % (e) 160 | D_model_save_path = os.path.join(model_save_dir, D_ckpt_model_filename) 161 | G_model_save_path = os.path.join(model_save_dir, G_ckpt_model_filename) 162 | torch.save(D_model.module.state_dict(), D_model_save_path) 163 | torch.save(G_model.module.state_dict(), G_model_save_path) 164 | 165 | 166 | if __name__ == "__main__": 167 | import colored_traceback 168 | 169 | parser = argparse.ArgumentParser() 170 | parser.add_argument('--exp_name', '-e', type=str, required=True, help='experiment name') 171 | parser.add_argument('--debug', action='store_true', help='specify debug mode') 172 | parser.add_argument('--use_gan', action='store_true') 173 | parser.add_argument('--batch_size', type=int, default=16) 174 | 175 | args = parser.parse_args() 176 | train(args) -------------------------------------------------------------------------------- /data/test_list.txt: -------------------------------------------------------------------------------- 1 | 0 2 | 3 3 | 11 4 | 12 5 | 16 6 | 21 7 | 30 8 | 35 9 | 52 10 | 54 11 | 73 12 | 87 13 | 88 14 | 89 15 | 92 16 | 97 17 | 98 18 | 103 19 | 107 20 | 109 21 | 111 22 | 114 23 | 115 24 | 117 25 | 123 26 | 134 27 | 135 28 | 137 29 | 148 30 | 149 31 | 155 32 | 175 33 | 176 34 | 177 35 | 178 36 | 179 37 | 183 38 | 184 39 | 189 40 | 195 41 | 198 42 | 218 43 | 220 44 | 236 45 | 237 46 | 240 47 | 244 48 | 245 49 | 246 50 | 248 51 | 261 52 | 263 53 | 264 54 | 266 55 | 269 56 | 272 57 | 277 58 | 280 59 | 281 60 | 286 61 | 289 62 | 291 63 | 295 64 | 298 
65 | 299 66 | 302 67 | 303 68 | 310 69 | 312 70 | 316 71 | 325 72 | 333 73 | 334 74 | 338 75 | 353 76 | 355 77 | 358 78 | 359 79 | 371 80 | 373 81 | 374 82 | 378 83 | 380 84 | 384 85 | 393 86 | 401 87 | 406 88 | 424 89 | 431 90 | 434 91 | 437 92 | 442 93 | 445 94 | 456 95 | 459 96 | 470 97 | 473 98 | 474 99 | 478 100 | 479 101 | 488 102 | 495 103 | 509 104 | 513 105 | 516 106 | 518 107 | 520 108 | 522 109 | 524 110 | 530 111 | 535 112 | 544 113 | 548 114 | 556 115 | 559 116 | 566 117 | 567 118 | 571 119 | 574 120 | 585 121 | 611 122 | 615 123 | 617 124 | 629 125 | 641 126 | 643 127 | 650 128 | 652 129 | 653 130 | 655 131 | 658 132 | 660 133 | 670 134 | 680 135 | 683 136 | 689 137 | 701 138 | 709 139 | 716 140 | 723 141 | 727 142 | 728 143 | 729 144 | 740 145 | 741 146 | 745 147 | 746 148 | 749 149 | 759 150 | 762 151 | 766 152 | 770 153 | 775 154 | 777 155 | 786 156 | 792 157 | 793 158 | 798 159 | 801 160 | 803 161 | 804 162 | 811 163 | 812 164 | 813 165 | 819 166 | 833 167 | 837 168 | 846 169 | 850 170 | 851 171 | 852 172 | 861 173 | 865 174 | 876 175 | 890 176 | 897 177 | 898 178 | 902 179 | 910 180 | 912 181 | 914 182 | 920 183 | 927 184 | 935 185 | 940 186 | 944 187 | 945 188 | 958 189 | 964 190 | 965 191 | 967 192 | 969 193 | 971 194 | 972 195 | 978 196 | 979 197 | 982 198 | 989 199 | 998 200 | 999 201 | 1000 202 | 1003 203 | 1004 204 | 1007 205 | 1009 206 | 1013 207 | 1019 208 | 1021 209 | 1024 210 | 1025 211 | 1033 212 | 1039 213 | 1053 214 | 1057 215 | 1064 216 | 1066 217 | 1067 218 | 1068 219 | 1075 220 | 1077 221 | 1079 222 | 1085 223 | 1090 224 | 1093 225 | 1098 226 | 1103 227 | 1106 228 | 1108 229 | 1122 230 | 1123 231 | 1125 232 | 1126 233 | 1134 234 | 1145 235 | 1149 236 | 1167 237 | 1176 238 | 1179 239 | 1180 240 | 1185 241 | 1206 242 | 1213 243 | 1214 244 | 1217 245 | 1221 246 | 1240 247 | 1243 248 | 1246 249 | 1253 250 | 1255 251 | 1256 252 | 1259 253 | 1261 254 | 1267 255 | 1274 256 | 1277 257 | 1278 258 | 1279 259 | 1281 260 | 1287 261 | 1289 262 
| 1291 263 | 1296 264 | 1311 265 | 1313 266 | 1314 267 | 1320 268 | 1325 269 | 1333 270 | 1336 271 | 1337 272 | 1338 273 | 1342 274 | 1347 275 | 1348 276 | 1349 277 | 1350 278 | 1354 279 | 1358 280 | 1359 281 | 1387 282 | 1389 283 | 1393 284 | 1394 285 | 1405 286 | 1406 287 | 1417 288 | 1422 289 | 1425 290 | 1429 291 | 1439 292 | 1440 293 | 1447 294 | 1454 295 | 1455 296 | 1457 297 | 1467 298 | 1473 299 | 1478 300 | 1499 301 | 1505 302 | 1506 303 | 1513 304 | 1518 305 | 1520 306 | 1523 307 | 1525 308 | 1526 309 | 1536 310 | 1538 311 | 1543 312 | 1557 313 | 1558 314 | 1565 315 | 1567 316 | 1578 317 | 1580 318 | 1582 319 | 1588 320 | 1589 321 | 1598 322 | 1601 323 | 1612 324 | 1614 325 | 1617 326 | 1619 327 | 1625 328 | 1633 329 | 1636 330 | 1637 331 | 1641 332 | 1643 333 | 1653 334 | 1667 335 | 1668 336 | 1670 337 | 1671 338 | 1679 339 | 1681 340 | 1682 341 | 1686 342 | 1688 343 | 1690 344 | 1692 345 | 1697 346 | 1716 347 | 1721 348 | 1724 349 | 1733 350 | 1735 351 | 1744 352 | 1747 353 | 1751 354 | 1773 355 | 1783 356 | 1785 357 | 1794 358 | 1801 359 | 1802 360 | 1808 361 | 1815 362 | 1819 363 | 1820 364 | 1843 365 | 1844 366 | 1854 367 | 1857 368 | 1858 369 | 1861 370 | 1866 371 | 1871 372 | 1873 373 | 1876 374 | 1878 375 | 1887 376 | 1901 377 | 1902 378 | 1905 379 | 1911 380 | 1921 381 | 1927 382 | 1930 383 | 1963 384 | 1969 385 | 1970 386 | 1971 387 | 1972 388 | 1982 389 | 1988 390 | 1991 391 | 2002 392 | 2008 393 | 2009 394 | 2016 395 | 2020 396 | 2027 397 | 2033 398 | 2035 399 | 2045 400 | 2049 401 | 2053 402 | 2064 403 | 2066 404 | 2074 405 | 2075 406 | 2080 407 | 2085 408 | 2101 409 | 2102 410 | 2104 411 | 2119 412 | 2125 413 | 2127 414 | 2129 415 | 2134 416 | 2136 417 | 2148 418 | 2156 419 | 2159 420 | 2163 421 | 2167 422 | 2169 423 | 2171 424 | 2184 425 | 2189 426 | 2190 427 | 2204 428 | 2205 429 | 2207 430 | 2213 431 | 2225 432 | 2229 433 | 2232 434 | 2238 435 | 2245 436 | 2247 437 | 2252 438 | 2255 439 | 2257 440 | 2259 441 | 2260 442 | 2263 443 | 2264 
444 | 2267 445 | 2275 446 | 2288 447 | 2304 448 | 2306 449 | 2311 450 | 2315 451 | 2317 452 | 2319 453 | 2342 454 | 2350 455 | 2357 456 | 2359 457 | 2361 458 | 2367 459 | 2371 460 | 2376 461 | 2391 462 | 2393 463 | 2399 464 | 2400 465 | 2402 466 | 2414 467 | 2425 468 | 2429 469 | 2438 470 | 2441 471 | 2444 472 | 2445 473 | 2446 474 | 2465 475 | 2466 476 | 2467 477 | 2470 478 | 2471 479 | 2472 480 | 2473 481 | 2478 482 | 2483 483 | 2484 484 | 2490 485 | 2491 486 | 2498 487 | 2500 488 | 2502 489 | 2509 490 | 2512 491 | 2514 492 | 2515 493 | 2516 494 | 2518 495 | 2523 496 | 2528 497 | 2530 498 | 2548 499 | 2556 500 | 2559 501 | 2566 502 | 2567 503 | 2571 504 | 2578 505 | 2579 506 | 2588 507 | 2601 508 | 2604 509 | 2616 510 | 2617 511 | 2621 512 | 2622 513 | 2628 514 | 2629 515 | 2632 516 | 2638 517 | 2640 518 | 2641 519 | 2644 520 | 2649 521 | 2650 522 | 2652 523 | 2653 524 | 2656 525 | 2677 526 | 2681 527 | 2682 528 | 2683 529 | 2684 530 | 2694 531 | 2695 532 | 2696 533 | 2698 534 | 2705 535 | 2716 536 | 2726 537 | 2727 538 | 2728 539 | 2731 540 | 2732 541 | 2735 542 | 2740 543 | 2744 544 | 2749 545 | 2757 546 | 2760 547 | 2766 548 | 2768 549 | 2777 550 | 2780 551 | 2784 552 | 2797 553 | 2798 554 | 2799 555 | 2808 556 | 2810 557 | 2817 558 | 2822 559 | 2824 560 | 2832 561 | 2833 562 | 2834 563 | 2843 564 | 2844 565 | 2861 566 | 2866 567 | 2867 568 | 2885 569 | 2887 570 | 2898 571 | 2915 572 | 2916 573 | 2917 574 | 2919 575 | 2921 576 | 2922 577 | 2933 578 | 2936 579 | 2940 580 | 2948 581 | 2958 582 | 2960 583 | 2961 584 | 2962 585 | 2972 586 | 2981 587 | 2992 588 | 2999 589 | 3000 590 | 3003 591 | 3007 592 | 3012 593 | 3017 594 | 3019 595 | 3022 596 | 3023 597 | 3025 598 | 3027 599 | 3035 600 | 3041 601 | 3046 602 | 3049 603 | 3078 604 | 3081 605 | 3082 606 | 3092 607 | 3097 608 | 3100 609 | 3103 610 | 3109 611 | 3113 612 | 3114 613 | 3120 614 | 3122 615 | 3124 616 | 3126 617 | 3128 618 | 3132 619 | 3135 620 | 3136 621 | 3143 622 | 3157 623 | 3159 624 | 3160 625 | 
3162 626 | 3165 627 | 3170 628 | 3171 629 | 3172 630 | 3176 631 | 3186 632 | 3190 633 | 3191 634 | 3197 635 | 3205 636 | 3208 637 | 3209 638 | 3218 639 | 3233 640 | 3236 641 | 3240 642 | 3248 643 | 3252 644 | 3256 645 | 3258 646 | 3261 647 | 3272 648 | 3274 649 | 3277 650 | 3278 651 | 3282 652 | 3284 653 | 3289 654 | 3290 655 | 3292 656 | 3293 657 | 3295 658 | 3296 659 | 3299 660 | 3313 661 | 3315 662 | 3316 663 | 3336 664 | 3343 665 | 3348 666 | 3355 667 | 3357 668 | 3359 669 | 3367 670 | 3368 671 | 3371 672 | 3380 673 | 3392 674 | 3402 675 | 3409 676 | 3412 677 | 3417 678 | 3418 679 | 3422 680 | 3426 681 | 3437 682 | 3438 683 | 3448 684 | 3455 685 | 3458 686 | 3459 687 | 3462 688 | 3472 689 | 3479 690 | 3489 691 | 3493 692 | 3495 693 | 3509 694 | 3512 695 | 3515 696 | 3526 697 | 3534 698 | 3540 699 | 3543 700 | 3553 701 | 3560 702 | 3562 703 | 3563 704 | 3564 705 | 3570 706 | 3577 707 | 3582 708 | 3590 709 | 3598 710 | 3600 711 | 3602 712 | 3606 713 | 3616 714 | 3618 715 | 3622 716 | 3624 717 | 3626 718 | 3634 719 | 3638 720 | 3642 721 | 3646 722 | 3647 723 | 3648 724 | 3656 725 | 3658 726 | 3662 727 | 3680 728 | 3684 729 | 3685 730 | 3691 731 | 3693 732 | 3695 733 | 3696 734 | 3697 735 | 3705 736 | 3707 737 | 3708 738 | 3709 739 | 3710 740 | 3715 741 | 3718 742 | 3719 743 | 3727 744 | 3729 745 | 3732 746 | 3733 747 | 3740 748 | 3745 749 | 3749 750 | 3751 751 | 3759 752 | 3768 753 | 3770 754 | 3773 755 | 3774 756 | 3777 757 | 3785 758 | 3790 759 | 3794 760 | 3800 761 | 3815 762 | 3822 763 | 3823 764 | 3826 765 | 3827 766 | 3828 767 | 3832 768 | 3838 769 | 3839 770 | 3840 771 | 3842 772 | 3849 773 | 3854 774 | 3856 775 | 3865 776 | 3867 777 | 3870 778 | 3872 779 | 3876 780 | 3878 781 | 3885 782 | 3890 783 | 3891 784 | 3894 785 | 3897 786 | 3902 787 | 3907 788 | 3910 789 | 3912 790 | 3916 791 | 3918 792 | 3921 793 | 3925 794 | 3931 795 | 3935 796 | 3943 797 | 3949 798 | 3968 799 | 3990 800 | 3992 801 | 
-------------------------------------------------------------------------------- /pointnet2/_ext_src/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | #include 8 | 9 | #include "cuda_utils.h" 10 | 11 | // input: points(b, c, n) idx(b, m) 12 | // output: out(b, c, m) 13 | __global__ void gather_points_kernel(int b, int c, int n, int m, 14 | const float *__restrict__ points, 15 | const int *__restrict__ idx, 16 | float *__restrict__ out) { 17 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 18 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 19 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 20 | int a = idx[i * m + j]; 21 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; 22 | } 23 | } 24 | } 25 | } 26 | 27 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 28 | const float *points, const int *idx, 29 | float *out) { 30 | gather_points_kernel<<>>(b, c, n, npoints, 32 | points, idx, out); 33 | 34 | CUDA_CHECK_ERRORS(); 35 | } 36 | 37 | // input: grad_out(b, c, m) idx(b, m) 38 | // output: grad_points(b, c, n) 39 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m, 40 | const float *__restrict__ grad_out, 41 | const int *__restrict__ idx, 42 | float *__restrict__ grad_points) { 43 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 44 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 45 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 46 | int a = idx[i * m + j]; 47 | atomicAdd(grad_points + (i * c + l) * n + a, 48 | grad_out[(i * c + l) * m + j]); 49 | } 50 | } 51 | } 52 | } 53 | 54 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 55 | const float *grad_out, const int *idx, 56 | float *grad_points) { 57 | 
gather_points_grad_kernel<<>>( 59 | b, c, n, npoints, grad_out, idx, grad_points); 60 | 61 | CUDA_CHECK_ERRORS(); 62 | } 63 | 64 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, 65 | int idx1, int idx2) { 66 | const float v1 = dists[idx1], v2 = dists[idx2]; 67 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 68 | dists[idx1] = max(v1, v2); 69 | dists_i[idx1] = v2 > v1 ? i2 : i1; 70 | } 71 | 72 | // Input dataset: (b, n, 3), tmp: (b, n) 73 | // Ouput idxs (b, m) 74 | template 75 | __global__ void furthest_point_sampling_kernel( 76 | int b, int n, int m, const float *__restrict__ dataset, 77 | float *__restrict__ temp, int *__restrict__ idxs) { 78 | if (m <= 0) return; 79 | __shared__ float dists[block_size]; 80 | __shared__ int dists_i[block_size]; 81 | 82 | int batch_index = blockIdx.x; 83 | dataset += batch_index * n * 3; 84 | temp += batch_index * n; 85 | idxs += batch_index * m; 86 | 87 | int tid = threadIdx.x; 88 | const int stride = block_size; 89 | 90 | int old = 0; 91 | if (threadIdx.x == 0) idxs[0] = old; 92 | 93 | __syncthreads(); 94 | for (int j = 1; j < m; j++) { 95 | int besti = 0; 96 | float best = -1; 97 | float x1 = dataset[old * 3 + 0]; 98 | float y1 = dataset[old * 3 + 1]; 99 | float z1 = dataset[old * 3 + 2]; 100 | for (int k = tid; k < n; k += stride) { 101 | float x2, y2, z2; 102 | x2 = dataset[k * 3 + 0]; 103 | y2 = dataset[k * 3 + 1]; 104 | z2 = dataset[k * 3 + 2]; 105 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 106 | if (mag <= 1e-3) continue; 107 | 108 | float d = 109 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 110 | 111 | float d2 = min(d, temp[k]); 112 | temp[k] = d2; 113 | besti = d2 > best ? k : besti; 114 | best = d2 > best ? 
d2 : best; 115 | } 116 | dists[tid] = best; 117 | dists_i[tid] = besti; 118 | __syncthreads(); 119 | 120 | if (block_size >= 512) { 121 | if (tid < 256) { 122 | __update(dists, dists_i, tid, tid + 256); 123 | } 124 | __syncthreads(); 125 | } 126 | if (block_size >= 256) { 127 | if (tid < 128) { 128 | __update(dists, dists_i, tid, tid + 128); 129 | } 130 | __syncthreads(); 131 | } 132 | if (block_size >= 128) { 133 | if (tid < 64) { 134 | __update(dists, dists_i, tid, tid + 64); 135 | } 136 | __syncthreads(); 137 | } 138 | if (block_size >= 64) { 139 | if (tid < 32) { 140 | __update(dists, dists_i, tid, tid + 32); 141 | } 142 | __syncthreads(); 143 | } 144 | if (block_size >= 32) { 145 | if (tid < 16) { 146 | __update(dists, dists_i, tid, tid + 16); 147 | } 148 | __syncthreads(); 149 | } 150 | if (block_size >= 16) { 151 | if (tid < 8) { 152 | __update(dists, dists_i, tid, tid + 8); 153 | } 154 | __syncthreads(); 155 | } 156 | if (block_size >= 8) { 157 | if (tid < 4) { 158 | __update(dists, dists_i, tid, tid + 4); 159 | } 160 | __syncthreads(); 161 | } 162 | if (block_size >= 4) { 163 | if (tid < 2) { 164 | __update(dists, dists_i, tid, tid + 2); 165 | } 166 | __syncthreads(); 167 | } 168 | if (block_size >= 2) { 169 | if (tid < 1) { 170 | __update(dists, dists_i, tid, tid + 1); 171 | } 172 | __syncthreads(); 173 | } 174 | 175 | old = dists_i[0]; 176 | if (tid == 0) idxs[j] = old; 177 | } 178 | } 179 | 180 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 181 | const float *dataset, float *temp, 182 | int *idxs) { 183 | unsigned int n_threads = opt_n_threads(n); 184 | 185 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 186 | 187 | switch (n_threads) { 188 | case 512: 189 | furthest_point_sampling_kernel<512> 190 | <<>>(b, n, m, dataset, temp, idxs); 191 | break; 192 | case 256: 193 | furthest_point_sampling_kernel<256> 194 | <<>>(b, n, m, dataset, temp, idxs); 195 | break; 196 | case 128: 197 | furthest_point_sampling_kernel<128> 198 | 
<<>>(b, n, m, dataset, temp, idxs); 199 | break; 200 | case 64: 201 | furthest_point_sampling_kernel<64> 202 | <<>>(b, n, m, dataset, temp, idxs); 203 | break; 204 | case 32: 205 | furthest_point_sampling_kernel<32> 206 | <<>>(b, n, m, dataset, temp, idxs); 207 | break; 208 | case 16: 209 | furthest_point_sampling_kernel<16> 210 | <<>>(b, n, m, dataset, temp, idxs); 211 | break; 212 | case 8: 213 | furthest_point_sampling_kernel<8> 214 | <<>>(b, n, m, dataset, temp, idxs); 215 | break; 216 | case 4: 217 | furthest_point_sampling_kernel<4> 218 | <<>>(b, n, m, dataset, temp, idxs); 219 | break; 220 | case 2: 221 | furthest_point_sampling_kernel<2> 222 | <<>>(b, n, m, dataset, temp, idxs); 223 | break; 224 | case 1: 225 | furthest_point_sampling_kernel<1> 226 | <<>>(b, n, m, dataset, temp, idxs); 227 | break; 228 | default: 229 | furthest_point_sampling_kernel<512> 230 | <<>>(b, n, m, dataset, temp, idxs); 231 | } 232 | 233 | CUDA_CHECK_ERRORS(); 234 | } 235 | -------------------------------------------------------------------------------- /pointnet2/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | ''' Modified based on Ref: https://github.com/erikwijmans/Pointnet2_PyTorch ''' 7 | import torch 8 | import torch.nn as nn 9 | from typing import List, Tuple 10 | 11 | class SharedMLP(nn.Sequential): 12 | 13 | def __init__( 14 | self, 15 | args: List[int], 16 | *, 17 | bn: bool = False, 18 | activation=nn.ReLU(inplace=True), 19 | preact: bool = False, 20 | first: bool = False, 21 | name: str = "" 22 | ): 23 | super().__init__() 24 | 25 | for i in range(len(args) - 1): 26 | self.add_module( 27 | name + 'layer{}'.format(i), 28 | Conv2d( 29 | args[i], 30 | args[i + 1], 31 | bn=(not first or not preact or (i != 0)) and bn, 32 | activation=activation 33 | if (not first or not preact or (i != 0)) else None, 34 | preact=preact 35 | ) 36 | ) 37 | 38 | 39 | class _BNBase(nn.Sequential): 40 | 41 | def __init__(self, in_size, batch_norm=None, name=""): 42 | super().__init__() 43 | self.add_module(name + "bn", batch_norm(in_size)) 44 | 45 | nn.init.constant_(self[0].weight, 1.0) 46 | nn.init.constant_(self[0].bias, 0) 47 | 48 | 49 | class BatchNorm1d(_BNBase): 50 | 51 | def __init__(self, in_size: int, *, name: str = ""): 52 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) 53 | 54 | 55 | class BatchNorm2d(_BNBase): 56 | 57 | def __init__(self, in_size: int, name: str = ""): 58 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) 59 | 60 | 61 | class BatchNorm3d(_BNBase): 62 | 63 | def __init__(self, in_size: int, name: str = ""): 64 | super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name) 65 | 66 | 67 | class _ConvBase(nn.Sequential): 68 | 69 | def __init__( 70 | self, 71 | in_size, 72 | out_size, 73 | kernel_size, 74 | stride, 75 | padding, 76 | activation, 77 | bn, 78 | init, 79 | conv=None, 80 | batch_norm=None, 81 | bias=True, 82 | preact=False, 83 | name="" 84 | ): 85 | super().__init__() 86 | 87 | bias = bias and (not bn) 88 | conv_unit = conv( 89 | in_size, 90 | out_size, 91 | kernel_size=kernel_size, 92 | stride=stride, 93 
| padding=padding, 94 | bias=bias 95 | ) 96 | init(conv_unit.weight) 97 | if bias: 98 | nn.init.constant_(conv_unit.bias, 0) 99 | 100 | if bn: 101 | if not preact: 102 | bn_unit = batch_norm(out_size) 103 | else: 104 | bn_unit = batch_norm(in_size) 105 | 106 | if preact: 107 | if bn: 108 | self.add_module(name + 'bn', bn_unit) 109 | 110 | if activation is not None: 111 | self.add_module(name + 'activation', activation) 112 | 113 | self.add_module(name + 'conv', conv_unit) 114 | 115 | if not preact: 116 | if bn: 117 | self.add_module(name + 'bn', bn_unit) 118 | 119 | if activation is not None: 120 | self.add_module(name + 'activation', activation) 121 | 122 | 123 | class Conv1d(_ConvBase): 124 | 125 | def __init__( 126 | self, 127 | in_size: int, 128 | out_size: int, 129 | *, 130 | kernel_size: int = 1, 131 | stride: int = 1, 132 | padding: int = 0, 133 | activation=nn.ReLU(inplace=True), 134 | bn: bool = False, 135 | init=nn.init.kaiming_normal_, 136 | bias: bool = True, 137 | preact: bool = False, 138 | name: str = "" 139 | ): 140 | super().__init__( 141 | in_size, 142 | out_size, 143 | kernel_size, 144 | stride, 145 | padding, 146 | activation, 147 | bn, 148 | init, 149 | conv=nn.Conv1d, 150 | batch_norm=BatchNorm1d, 151 | bias=bias, 152 | preact=preact, 153 | name=name 154 | ) 155 | 156 | 157 | class Conv2d(_ConvBase): 158 | 159 | def __init__( 160 | self, 161 | in_size: int, 162 | out_size: int, 163 | *, 164 | kernel_size: Tuple[int, int] = (1, 1), 165 | stride: Tuple[int, int] = (1, 1), 166 | padding: Tuple[int, int] = (0, 0), 167 | activation=nn.ReLU(inplace=True), 168 | bn: bool = False, 169 | init=nn.init.kaiming_normal_, 170 | bias: bool = True, 171 | preact: bool = False, 172 | name: str = "" 173 | ): 174 | super().__init__( 175 | in_size, 176 | out_size, 177 | kernel_size, 178 | stride, 179 | padding, 180 | activation, 181 | bn, 182 | init, 183 | conv=nn.Conv2d, 184 | batch_norm=BatchNorm2d, 185 | bias=bias, 186 | preact=preact, 187 | name=name 188 | ) 
189 | 190 | 191 | class Conv3d(_ConvBase): 192 | 193 | def __init__( 194 | self, 195 | in_size: int, 196 | out_size: int, 197 | *, 198 | kernel_size: Tuple[int, int, int] = (1, 1, 1), 199 | stride: Tuple[int, int, int] = (1, 1, 1), 200 | padding: Tuple[int, int, int] = (0, 0, 0), 201 | activation=nn.ReLU(inplace=True), 202 | bn: bool = False, 203 | init=nn.init.kaiming_normal_, 204 | bias: bool = True, 205 | preact: bool = False, 206 | name: str = "" 207 | ): 208 | super().__init__( 209 | in_size, 210 | out_size, 211 | kernel_size, 212 | stride, 213 | padding, 214 | activation, 215 | bn, 216 | init, 217 | conv=nn.Conv3d, 218 | batch_norm=BatchNorm3d, 219 | bias=bias, 220 | preact=preact, 221 | name=name 222 | ) 223 | 224 | 225 | class FC(nn.Sequential): 226 | 227 | def __init__( 228 | self, 229 | in_size: int, 230 | out_size: int, 231 | *, 232 | activation=nn.ReLU(inplace=True), 233 | bn: bool = False, 234 | init=None, 235 | preact: bool = False, 236 | name: str = "" 237 | ): 238 | super().__init__() 239 | 240 | fc = nn.Linear(in_size, out_size, bias=not bn) 241 | if init is not None: 242 | init(fc.weight) 243 | if not bn: 244 | nn.init.constant_(fc.bias, 0) 245 | 246 | if preact: 247 | if bn: 248 | self.add_module(name + 'bn', BatchNorm1d(in_size)) 249 | 250 | if activation is not None: 251 | self.add_module(name + 'activation', activation) 252 | 253 | self.add_module(name + 'fc', fc) 254 | 255 | if not preact: 256 | if bn: 257 | self.add_module(name + 'bn', BatchNorm1d(out_size)) 258 | 259 | if activation is not None: 260 | self.add_module(name + 'activation', activation) 261 | 262 | def set_bn_momentum_default(bn_momentum): 263 | 264 | def fn(m): 265 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 266 | m.momentum = bn_momentum 267 | 268 | return fn 269 | 270 | 271 | class BNMomentumScheduler(object): 272 | 273 | def __init__( 274 | self, model, bn_lambda, last_epoch=-1, 275 | setter=set_bn_momentum_default 276 | ): 277 | if not 
isinstance(model, nn.Module): 278 | raise RuntimeError( 279 | "Class '{}' is not a PyTorch nn Module".format( 280 | type(model).__name__ 281 | ) 282 | ) 283 | 284 | self.model = model 285 | self.setter = setter 286 | self.lmbd = bn_lambda 287 | 288 | self.step(last_epoch + 1) 289 | self.last_epoch = last_epoch 290 | 291 | def step(self, epoch=None): 292 | if epoch is None: 293 | epoch = self.last_epoch + 1 294 | 295 | self.last_epoch = epoch 296 | self.model.apply(self.setter(self.lmbd(epoch))) 297 | 298 | 299 | -------------------------------------------------------------------------------- /auction_match/auction_match_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | __global__ void AuctionMatchKernel(int b,int n,const float * __restrict__ xyz1,const float * __restrict__ xyz2,int * matchl,int * matchr,float * cost){ 4 | //this kernel handles up to 4096 points 5 | const int NMax=4096; 6 | __shared__ short Queue[NMax]; 7 | __shared__ short matchrbuf[NMax]; 8 | __shared__ float pricer[NMax]; 9 | __shared__ float bests[32][3]; 10 | __shared__ int qhead,qlen; 11 | const int BufLen=2048; 12 | __shared__ float buf[BufLen]; 13 | for (int bno=blockIdx.x;bno1; 92 | } 93 | int vj,vj2,vj3,vj4; 94 | if (value1=blockDim.x*4){ 150 | for (int j=threadIdx.x;j=blockDim.x*2){ 188 | for (int j=threadIdx.x;j0;i>>=1){ 221 | float b1=__shfl_down(best,i,32); 222 | float b2=__shfl_down(best2,i,32); 223 | int bj=__shfl_down(bestj,i,32); 224 | if (best>5][0]=best; 234 | bests[threadIdx.x>>5][1]=best2; 235 | *(int*)&bests[threadIdx.x>>5][2]=bestj; 236 | } 237 | __syncthreads(); 238 | int nn=blockDim.x>>5; 239 | if (threadIdx.x>1;i>0;i>>=1){ 244 | float b1=__shfl_down(best,i,32); 245 | float b2=__shfl_down(best2,i,32); 246 | int bj=__shfl_down(bestj,i,32); 247 | if (best=n) 261 | qhead-=n; 262 | int old=matchrbuf[bestj]; 263 | pricer[bestj]+=delta; 264 | cnt++; 265 | if (old!=-1){ 266 | int ql=qlen; 267 | int tail=qhead+ql; 268 | 
qlen=ql+1; 269 | if (tail>=n) 270 | tail-=n; 271 | Queue[tail]=old; 272 | } 273 | if (cnt==(40*n)){ 274 | if (tolerance==1.0) 275 | qlen=0; 276 | tolerance=fminf(1.0,tolerance*100); 277 | cnt=0; 278 | } 279 | } 280 | __syncthreads(); 281 | if (threadIdx.x==0){ 282 | matchrbuf[bestj]=i; 283 | } 284 | } 285 | __syncthreads(); 286 | for (int j=threadIdx.x;j>>(b,n,xyz1,xyz2,matchl,matchr,cost); 295 | } 296 | 297 | -------------------------------------------------------------------------------- /notebook/Untitled.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 10, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os,sys\n", 10 | "sys.path.append('../')\n", 11 | "import Common.pc_util as pc_util\n", 12 | "import numpy as np\n", 13 | "import math\n", 14 | "from functools import reduce\n", 15 | "\n", 16 | "def euler2mat(z=0, y=0, x=0):\n", 17 | " ''' Return matrix for rotations around z, y and x axes\n", 18 | "\n", 19 | " Uses the z, then y, then x convention above\n", 20 | "\n", 21 | " Parameters\n", 22 | " ----------\n", 23 | " z : scalar\n", 24 | " Rotation angle in radians around z-axis (performed first)\n", 25 | " y : scalar\n", 26 | " Rotation angle in radians around y-axis\n", 27 | " x : scalar\n", 28 | " Rotation angle in radians around x-axis (performed last)\n", 29 | "\n", 30 | " Returns\n", 31 | " -------\n", 32 | " M : array shape (3,3)\n", 33 | " Rotation matrix giving same rotation as for given angles\n", 34 | "\n", 35 | " Examples\n", 36 | " --------\n", 37 | " >>> zrot = 1.3 # radians\n", 38 | " >>> yrot = -0.1\n", 39 | " >>> xrot = 0.2\n", 40 | " >>> M = euler2mat(zrot, yrot, xrot)\n", 41 | " >>> M.shape == (3, 3)\n", 42 | " True\n", 43 | "\n", 44 | " The output rotation matrix is equal to the composition of the\n", 45 | " individual rotations\n", 46 | "\n", 47 | " >>> M1 = euler2mat(zrot)\n", 48 | " >>> M2 = euler2mat(0, 
yrot)\n", 49 | " >>> M3 = euler2mat(0, 0, xrot)\n", 50 | " >>> composed_M = np.dot(M3, np.dot(M2, M1))\n", 51 | " >>> np.allclose(M, composed_M)\n", 52 | " True\n", 53 | "\n", 54 | " You can specify rotations by named arguments\n", 55 | "\n", 56 | " >>> np.all(M3 == euler2mat(x=xrot))\n", 57 | " True\n", 58 | "\n", 59 | " When applying M to a vector, the vector should column vector to the\n", 60 | " right of M. If the right hand side is a 2D array rather than a\n", 61 | " vector, then each column of the 2D array represents a vector.\n", 62 | "\n", 63 | " >>> vec = np.array([1, 0, 0]).reshape((3,1))\n", 64 | " >>> v2 = np.dot(M, vec)\n", 65 | " >>> vecs = np.array([[1, 0, 0],[0, 1, 0]]).T # giving 3x2 array\n", 66 | " >>> vecs2 = np.dot(M, vecs)\n", 67 | "\n", 68 | " Rotations are counter-clockwise.\n", 69 | "\n", 70 | " >>> zred = np.dot(euler2mat(z=np.pi/2), np.eye(3))\n", 71 | " >>> np.allclose(zred, [[0, -1, 0],[1, 0, 0], [0, 0, 1]])\n", 72 | " True\n", 73 | " >>> yred = np.dot(euler2mat(y=np.pi/2), np.eye(3))\n", 74 | " >>> np.allclose(yred, [[0, 0, 1],[0, 1, 0], [-1, 0, 0]])\n", 75 | " True\n", 76 | " >>> xred = np.dot(euler2mat(x=np.pi/2), np.eye(3))\n", 77 | " >>> np.allclose(xred, [[1, 0, 0],[0, 0, -1], [0, 1, 0]])\n", 78 | " True\n", 79 | "\n", 80 | " Notes\n", 81 | " -----\n", 82 | " The direction of rotation is given by the right-hand rule (orient\n", 83 | " the thumb of the right hand along the axis around which the rotation\n", 84 | " occurs, with the end of the thumb at the positive end of the axis;\n", 85 | " curl your fingers; the direction your fingers curl is the direction\n", 86 | " of rotation). 
Therefore, the rotations are counterclockwise if\n", 87 | " looking along the axis of rotation from positive to negative.\n", 88 | " '''\n", 89 | " Ms = []\n", 90 | " if z:\n", 91 | " cosz = math.cos(z)\n", 92 | " sinz = math.sin(z)\n", 93 | " Ms.append(np.array(\n", 94 | " [[cosz, -sinz, 0],\n", 95 | " [sinz, cosz, 0],\n", 96 | " [0, 0, 1]]))\n", 97 | " if y:\n", 98 | " cosy = math.cos(y)\n", 99 | " siny = math.sin(y)\n", 100 | " Ms.append(np.array(\n", 101 | " [[cosy, 0, siny],\n", 102 | " [0, 1, 0],\n", 103 | " [-siny, 0, cosy]]))\n", 104 | " if x:\n", 105 | " cosx = math.cos(x)\n", 106 | " sinx = math.sin(x)\n", 107 | " Ms.append(np.array(\n", 108 | " [[1, 0, 0],\n", 109 | " [0, cosx, -sinx],\n", 110 | " [0, sinx, cosx]]))\n", 111 | " if Ms:\n", 112 | " return reduce(np.dot, Ms[::-1])\n", 113 | " return np.eye(3)\n", 114 | "\n", 115 | "def draw_point_cloud(input_points, canvasSize=1000, space=480, diameter=10,\n", 116 | " xrot=0, yrot=0, zrot=0, switch_xyz=[0,1,2], normalize=True):\n", 117 | " \"\"\" Render point cloud to image with alpha channel.\n", 118 | " Input:\n", 119 | " points: Nx3 numpy array (+y is up direction)\n", 120 | " Output:\n", 121 | " gray image as numpy array of size canvasSizexcanvasSize\n", 122 | " \"\"\"\n", 123 | " canvasSizeX = canvasSize\n", 124 | " canvasSizeY = canvasSize\n", 125 | "\n", 126 | " image = np.zeros((canvasSizeX, canvasSizeY))\n", 127 | " if input_points is None or input_points.shape[0] == 0:\n", 128 | " return image\n", 129 | "\n", 130 | " points = input_points[:, switch_xyz]\n", 131 | " M = euler2mat(zrot, yrot, xrot)\n", 132 | " points = (np.dot(M, points.transpose())).transpose()\n", 133 | "\n", 134 | " # Normalize the point cloud\n", 135 | " # We normalize scale to fit points in a unit sphere\n", 136 | " if normalize:\n", 137 | " centroid = np.mean(points, axis=0)\n", 138 | " points -= centroid\n", 139 | " furthest_distance = np.max(np.sqrt(np.sum(abs(points)**2,axis=-1)))\n", 140 | " points /= furthest_distance\n", 
141 | "\n", 142 | " # Pre-compute the Gaussian disk\n", 143 | " radius = (diameter-1)/2.0\n", 144 | " disk = np.zeros((diameter, diameter))\n", 145 | " for i in range(diameter):\n", 146 | " for j in range(diameter):\n", 147 | " if (i - radius) * (i-radius) + (j-radius) * (j-radius) <= radius * radius:\n", 148 | " disk[i, j] = np.exp((-(i-radius)**2 - (j-radius)**2)/(radius**2))\n", 149 | " mask = np.argwhere(disk > 0)\n", 150 | " dx = mask[:, 0]\n", 151 | " dy = mask[:, 1]\n", 152 | " dv = disk[disk > 0]\n", 153 | "\n", 154 | " # Order points by z-buffer\n", 155 | " zorder = np.argsort(points[:, 2])\n", 156 | " points = points[zorder, :]\n", 157 | " points[:, 2] = (points[:, 2] - np.min(points[:, 2])) / (np.max(points[:, 2] - np.min(points[:, 2])))\n", 158 | " max_depth = np.max(points[:, 2])\n", 159 | "\n", 160 | " for i in range(points.shape[0]):\n", 161 | " j = points.shape[0] - i - 1\n", 162 | " x = points[j, 0]\n", 163 | " y = points[j, 1]\n", 164 | " xc = canvasSizeX/2 + (x*space)\n", 165 | " yc = canvasSizeY/2 + (y*space)\n", 166 | " xc = int(np.round(xc))\n", 167 | " yc = int(np.round(yc))\n", 168 | "\n", 169 | " px = dx + xc\n", 170 | " py = dy + yc\n", 171 | " #image[px, py] = image[px, py] * 0.7 + dv * (max_depth - points[j, 2]) * 0.3\n", 172 | " image[px, py] = image[px, py] * 0.7 + dv * 0.3\n", 173 | "\n", 174 | " val = np.max(image)+1e-8\n", 175 | " val = np.percentile(image,99.9)\n", 176 | " image = image / val\n", 177 | " mask = image==0\n", 178 | "\n", 179 | " image[image>1.0]=1.0\n", 180 | " image = 1.0-image\n", 181 | " #image = np.expand_dims(image, axis=-1)\n", 182 | " #image = np.concatenate((image*0.3+0.7,np.ones_like(image), np.ones_like(image)), axis=2)\n", 183 | " #image = colors.hsv_to_rgb(image)\n", 184 | " image[mask]=1.0\n", 185 | "\n", 186 | "\n", 187 | " return image" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 11, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | 
"data_dir='/data2/haolin/PUGAN_pytorch/outputs/full_bs12_0508/camel.xyz'\n", 197 | "data=np.loadtxt(data_dir)\n", 198 | "img=draw_point_cloud(data, zrot=90 / 180.0 * np.pi, xrot=90 / 180.0 * np.pi, yrot=0 / 180.0 * np.pi,diameter=4)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 12, 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "name": "stdout", 208 | "output_type": "stream", 209 | "text": [ 210 | "(1000, 1000)\n" 211 | ] 212 | }, 213 | { 214 | "data": { 215 | "text/plain": [ 216 | "True" 217 | ] 218 | }, 219 | "execution_count": 12, 220 | "metadata": {}, 221 | "output_type": "execute_result" 222 | } 223 | ], 224 | "source": [ 225 | "print(img.shape)\n", 226 | "import cv2\n", 227 | "cv2.imwrite('./test_img.jpg',(img*255).astype(np.uint8))" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [] 243 | } 244 | ], 245 | "metadata": { 246 | "kernelspec": { 247 | "display_name": "Python 3", 248 | "language": "python", 249 | "name": "python3" 250 | }, 251 | "language_info": { 252 | "codemirror_mode": { 253 | "name": "ipython", 254 | "version": 3 255 | }, 256 | "file_extension": ".py", 257 | "mimetype": "text/x-python", 258 | "name": "python", 259 | "nbconvert_exporter": "python", 260 | "pygments_lexer": "ipython3", 261 | "version": "3.6.10" 262 | } 263 | }, 264 | "nbformat": 4, 265 | "nbformat_minor": 4 266 | } 267 | -------------------------------------------------------------------------------- /utils/ply_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # 3 | # 0===============================0 4 | # | PLY files reader/writer | 5 | # 0===============================0 6 | # 7 | # 8 | # 
# ----------------------------------------------------------------------------------------------------------------------
#
#      function to read/write .ply files
#
# ----------------------------------------------------------------------------------------------------------------------
#
#      Hugues THOMAS - 10/02/2017
#


# ----------------------------------------------------------------------------------------------------------------------
#
#           Imports and global variables
#       \**********************************/
#


# Basic libs
import numpy as np
import sys


def save_ply(save_path, points, faces=None):
    """
    Save a point array (optionally with triangle faces) as a binary PLY file.

    Thin wrapper around ``write_ply`` that names the three columns 'x', 'y', 'z'.

    Parameters
    ----------
    save_path : str
        destination path ('.ply' is appended by ``write_ply`` if missing).
    points : numpy array
        point coordinates; must provide exactly three columns in total.
    faces : numpy array, optional
        (F, 3) vertex indices of triangles.

    Returns
    -------
    bool
        the status returned by ``write_ply`` (True on success, False when the
        inputs are rejected). The original version discarded this value.
    """
    return write_ply(save_path, points, ['x', 'y', 'z'], faces)


# PLY property type names (bytes, exactly as read from a header) -> numpy dtype codes.
# 'int64' / 'uint64' are not part of the PLY specification, but header_properties writes
# numpy dtype names (e.g. 'int64', numpy's default integer type), so they are needed to
# read back files produced by this module.
ply_dtypes = {
    b'int8': 'i1',
    b'char': 'i1',
    b'uint8': 'u1',
    b'uchar': 'u1',
    b'int16': 'i2',
    b'short': 'i2',
    b'uint16': 'u2',
    b'ushort': 'u2',
    b'int32': 'i4',
    b'int': 'i4',
    b'uint32': 'u4',
    b'uint': 'u4',
    b'int64': 'i8',
    b'uint64': 'u8',
    b'float32': 'f4',
    b'float': 'f4',
    b'float64': 'f8',
    b'double': 'f8',
}

# PLY 'format' keyword -> numpy byte-order prefix ('' for ascii, which the readers reject).
valid_formats = {'ascii': '', 'binary_big_endian': '>',
                 'binary_little_endian': '<'}


# ----------------------------------------------------------------------------------------------------------------------
#
#           Functions
#       \***************/
#


def parse_header(plyfile, ext):
    """
    Parse the header of a binary point-cloud PLY file.

    Lines are consumed until one containing ``end_header`` (or EOF), so on
    return the stream is positioned at the start of the binary payload.

    Parameters
    ----------
    plyfile : binary file object
        opened in 'rb' mode and positioned after the 'format' line.
    ext : str
        byte-order prefix taken from ``valid_formats`` ('<' or '>').

    Returns
    -------
    num_points : int or None
        count from the last 'element' line seen (None if there was none).
    properties : list of (str, str)
        (property name, numpy dtype string) pairs, usable as a structured dtype.

    Raises
    ------
    KeyError
        if a property type is not listed in ``ply_dtypes``.
    """
    num_points = None
    properties = []

    # iter(readline, b'') stops cleanly at EOF if 'end_header' is missing.
    for line in iter(plyfile.readline, b''):
        if b'element' in line:
            num_points = int(line.split()[2])
        elif b'property' in line:
            tokens = line.split()
            properties.append((tokens[2].decode(), ext + ply_dtypes[tokens[1]]))
        if b'end_header' in line:
            break

    return num_points, properties
def parse_mesh_header(plyfile, ext):
    """
    Parse the header of a binary triangle-mesh PLY file.

    Lines are consumed until one containing ``end_header`` (or EOF), so on
    return the stream is positioned at the start of the binary payload.

    Parameters
    ----------
    plyfile : binary file object
        opened in 'rb' mode and positioned after the 'format' line.
    ext : str
        byte-order prefix taken from ``valid_formats`` ('<' or '>').

    Returns
    -------
    num_points : int or None
        count of the 'vertex' element.
    num_faces : int or None
        count of the 'face' element.
    vertex_properties : list of (str, str)
        (property name, numpy dtype string) pairs of the vertex element only.

    Raises
    ------
    ValueError
        if a face property is not a list laid out as ``read_ply`` decodes it:
        a 1-byte count followed by 4-byte integer indices.
    """
    # Face layouts with the same binary layout as 'list uchar int', which is
    # the fixed record read_ply decodes. Unsigned indices decode identically
    # below 2**31.
    face_count_types = (b'uchar', b'uint8')
    face_index_types = (b'int', b'int32', b'uint', b'uint32')

    vertex_properties = []
    num_points = None
    num_faces = None
    current_element = None

    for line in iter(plyfile.readline, b''):
        if b'element vertex' in line:
            current_element = 'vertex'
            num_points = int(line.split()[2])

        elif b'element face' in line:
            current_element = 'face'
            num_faces = int(line.split()[2])

        elif line.startswith(b'element'):
            # Any other element: its properties must not be attributed to vertices or faces.
            current_element = None

        elif b'property' in line:
            tokens = line.split()
            if current_element == 'vertex':
                vertex_properties.append((tokens[2].decode(), ext + ply_dtypes[tokens[1]]))
            elif current_element == 'face':
                supported = (len(tokens) >= 5 and tokens[1] == b'list'
                             and tokens[2] in face_count_types
                             and tokens[3] in face_index_types)
                if not supported:
                    raise ValueError('Unsupported faces property : '
                                     + line.decode(errors='replace').strip())

        if b'end_header' in line:
            break

    return num_points, num_faces, vertex_properties


def read_ply(filename, triangular_mesh=False):
    """
    Read ".ply" files

    Only binary (little- or big-endian) files are supported.

    Parameters
    ----------
    filename : string
        the name of the file to read.
    triangular_mesh : bool
        if True, also read a 'face' element of triangles (see parse_mesh_header).

    Returns
    -------
    result : structured array, or [structured array, (F, 3) int array]
        the vertex data (field names taken from the header); with
        triangular_mesh=True, a list [vertex_data, faces].

    Raises
    ------
    ValueError
        if the file does not start with 'ply', is ascii, or declares an
        unsupported face layout.

    Examples
    --------
    Store data in file

    >>> points = np.random.rand(5, 3)
    >>> values = np.random.randint(2, size=5).astype(np.int32)
    >>> write_ply('example.ply', [points, values], ['x', 'y', 'z', 'values'])

    Read the file

    >>> data = read_ply('example.ply')
    >>> values = data['values']                                   # shape (5,)
    >>> points = np.vstack((data['x'], data['y'], data['z'])).T    # shape (5, 3)
    """

    with open(filename, 'rb') as plyfile:

        # Check if the file start with ply
        if b'ply' not in plyfile.readline():
            raise ValueError('The file does not start whith the word ply')

        # get binary_little/big or ascii
        fmt = plyfile.readline().split()[1].decode()
        if fmt == "ascii":
            raise ValueError('The file is not binary')

        # get extension for building the numpy dtypes
        ext = valid_formats[fmt]

        # PointCloud reader vs mesh reader
        if triangular_mesh:

            # Parse header
            num_points, num_faces, properties = parse_mesh_header(plyfile, ext)

            # Get point data
            vertex_data = np.fromfile(plyfile, dtype=properties, count=num_points)

            # Each face record: 1-byte vertex count followed by three 4-byte indices
            face_properties = [('k', ext + 'u1'),
                               ('v1', ext + 'i4'),
                               ('v2', ext + 'i4'),
                               ('v3', ext + 'i4')]
            faces_data = np.fromfile(plyfile, dtype=face_properties, count=num_faces)

            # Return vertex data and concatenated faces
            faces = np.vstack((faces_data['v1'], faces_data['v2'], faces_data['v3'])).T
            data = [vertex_data, faces]

        else:

            # Parse header
            num_points, properties = parse_header(plyfile, ext)

            # Get data
            data = np.fromfile(plyfile, dtype=properties, count=num_points)

    return data


def header_properties(field_list, field_names):
    """
    Build the 'element vertex' and 'property' header lines for write_ply.

    Parameters
    ----------
    field_list : list of 2-D numpy arrays
        all with the same number of rows; every column is one property.
    field_names : list of str
        one name per column, in order.

    Returns
    -------
    list of str
        header lines (without newlines). Property types are numpy dtype names
        such as 'float32' or 'int64'.
    """
    lines = ['element vertex %d' % field_list[0].shape[0]]

    i = 0
    for fields in field_list:
        for field in fields.T:
            lines.append('property %s %s' % (field.dtype.name, field_names[i]))
            i += 1

    return lines


def write_ply(filename, field_list, field_names, triangular_faces=None):
    """
    Write ".ply" files

    Parameters
    ----------
    filename : string
        the name of the file to which the data is saved. A '.ply' extension will be appended to the
        file name if it does no already have one.

    field_list : list, tuple, numpy array
        the fields to be saved in the ply file. Either a numpy array, a list of numpy arrays or a
        tuple of numpy arrays. Each 1D numpy array and each column of 2D numpy arrays are considered
        as one field.

    field_names : list
        the name of each fields as a list of strings. Has to be the same length as the number of
        fields.

    triangular_faces : numpy array, optional
        (F, 3) vertex indices, written as 'property list uchar int vertex_indices'.

    Returns
    -------
    bool
        True on success; False (after printing a message) if the inputs are inconsistent.

    Examples
    --------
    >>> points = np.random.rand(10, 3)
    >>> write_ply('example1.ply', points, ['x', 'y', 'z'])

    >>> values = np.random.randint(2, size=10)
    >>> write_ply('example2.ply', [points, values], ['x', 'y', 'z', 'values'])

    >>> colors = np.random.randint(255, size=(10, 3), dtype=np.uint8)
    >>> field_names = ['x', 'y', 'z', 'red', 'green', 'blue', 'values']
    >>> write_ply('example3.ply', [points, colors, values], field_names)

    """

    # Normalise the input into a fresh list of 2-D arrays (one column per field);
    # the caller's list is never mutated.
    field_list = list(field_list) if isinstance(field_list, (list, tuple)) else [field_list]
    if not field_list:
        print('no fields to write')
        return False
    for i, field in enumerate(field_list):
        field = np.asarray(field)
        if field.ndim > 2:
            print('fields have more than 2 dimensions')
            return False
        if field.ndim < 2:
            field = field.reshape(-1, 1)
        field_list[i] = field

    # check all fields have the same number of data
    n_points = [field.shape[0] for field in field_list]
    if not np.all(np.equal(n_points, n_points[0])):
        print('wrong field dimensions')
        return False

    # Check if field_names and field_list have same nb of column
    n_fields = np.sum([field.shape[1] for field in field_list])
    if n_fields != len(field_names):
        print('wrong number of field names')
        return False

    # Add extension if not there
    if not filename.endswith('.ply'):
        filename += '.ply'

    # Header: the payload is written in native byte order, so declare that order.
    header = ['ply', 'format binary_' + sys.byteorder + '_endian 1.0']
    header.extend(header_properties(field_list, field_names))
    if triangular_faces is not None:
        triangular_faces = np.asarray(triangular_faces)
        header.append('element face {:d}'.format(triangular_faces.shape[0]))
        header.append('property list uchar int vertex_indices')
    header.append('end_header')

    # Vertex payload: one packed structured record per point, one member per column.
    type_list = []
    columns = []
    i = 0
    for fields in field_list:
        for field in fields.T:
            type_list.append((field_names[i], field.dtype.str))
            columns.append(field)
            i += 1
    data = np.empty(field_list[0].shape[0], dtype=type_list)
    for (name, _), column in zip(type_list, columns):
        data[name] = column

    # Single binary handle: the header keeps '\n' line endings on every platform
    # (text mode would write '\r\n' on Windows).
    with open(filename, 'wb') as plyfile:
        plyfile.write(''.join(line + '\n' for line in header).encode())
        data.tofile(plyfile)

        if triangular_faces is not None:
            triangular_faces = triangular_faces.astype(np.int32)
            face_type = [('k', 'uint8')] + [(str(ind), 'int32') for ind in range(3)]
            face_data = np.empty(triangular_faces.shape[0], dtype=face_type)
            face_data['k'] = 3  # every face is a triangle
            face_data['0'] = triangular_faces[:, 0]
            face_data['1'] = triangular_faces[:, 1]
            face_data['2'] = triangular_faces[:, 2]
            face_data.tofile(plyfile)

    return True


def describe_element(name, df):
    """ Takes the columns of the dataframe and builds a ply-like description

    Parameters
    ----------
    name: str
    df: pandas DataFrame

    Returns
    -------
    element: list[str]
    """
    # The first letter of the numpy dtype name ('f'loat, 'u'int, 'i'nt) selects the PLY type.
    property_formats = {'f': 'float', 'u': 'uchar', 'i': 'int'}
    element = ['element ' + name + ' ' + str(len(df))]

    if name == 'face':
        element.append("property list uchar int points_indices")

    else:
        # Pair each column with its dtype explicitly; `df.dtypes[i]` relied on
        # pandas' deprecated positional fallback for Series.__getitem__.
        for column, dtype in zip(df.columns, df.dtypes):
            f = property_formats[str(dtype)[0]]
            element.append('property ' + f + ' ' + str(column))

    return element
-------------------------------------------------------------------------------- /PUNet-Evaluation/evaluation.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include /* sqrt */ 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | #include 22 | #include 23 | #include 24 | 25 | #include 26 | #include 27 | 28 | #include 29 | #include 30 | #include 31 | 32 | //we use multi-thread to accelerate the calculation 33 | //define the thread number here 34 | #define THREAD 4 35 | 36 | typedef CGAL::Exact_predicates_inexact_constructions_kernel Kernel; 37 | typedef CGAL::Surface_mesh Triangle_mesh; 38 | typedef CGAL::Surface_mesh_shortest_path_traits Traits; 39 | typedef CGAL::Surface_mesh_shortest_path Surface_mesh_shortest_path; 40 | typedef Surface_mesh_shortest_path::Face_location Face_location; 41 | typedef boost::graph_traits Graph_traits; 42 | typedef Graph_traits::vertex_iterator vertex_iterator; 43 | typedef Graph_traits::face_iterator face_iterator; 44 | typedef Graph_traits::face_descriptor face_descriptor; 45 | typedef CGAL::AABB_face_graph_triangle_primitive AABB_face_graph_primitive; 46 | typedef CGAL::AABB_traits AABB_face_graph_traits; 47 | typedef CGAL::AABB_tree Tree; 48 | typedef Traits::Barycentric_coordinate Barycentric_coordinate; 49 | typedef Traits::FT FT; 50 | typedef Traits::Point_3 Point_3; 51 | typedef Traits::Vector_3 Vector_3; 52 | 53 | 54 | void calculate_mean_var(std::vector v){ 55 | double sum = std::accumulate(std::begin(v), std::end(v), 0.0); 56 | double mean = sum / v.size(); 57 | double accum = 0.0; 58 | std::for_each (std::begin(v), std::end(v), [&](const double d) { 59 | accum += (d - mean) * (d - mean); 60 | }); 61 | double stdev = sqrt(accum / (v.size()-1)); 62 | auto max = std::max_element(std::begin(v), 
std::end(v)); 63 | auto min = std::min_element(std::begin(v), std::end(v)); 64 | std::cout<<"Mean: "< *pred_face_locations = (std::vector *)(((void**)args)[1]); 72 | std::vector *sample_face_locations = (std::vector *)(((void**)args)[2]); 73 | std::vector *sample_points = (std::vector *)(((void**)args)[3]); 74 | std::vector *pred_map_points = (std::vector *)(((void**)args)[4]); 75 | std::vector > *density = (std::vector > *)(((void**)args)[5]); 76 | std::vector *radius = (std::vector *)(((void**)args)[6]); 77 | //[lower,upper) 78 | int lower = *(int*)(((void**)args)[7]); 79 | int upper = *(int*)(((void**)args)[8]); 80 | std::cout<< "In this function, handle "< radius_cnt; 85 | 86 | for (int sample_iter =lower;sample_iter((*radius).size(),0); 90 | for (unsigned int pred_iter=0;pred_itersize();pred_iter++){ 91 | dist1 = CGAL::squared_distance((*sample_points)[sample_iter],(*pred_map_points)[pred_iter]); 92 | if (CGAL::sqrt(dist1)>(*radius).back()){ 93 | continue; 94 | } 95 | dist2 = shortest_paths.shortest_distance_to_source_points((*pred_face_locations)[pred_iter].first,(*pred_face_locations)[pred_iter].second).first; 96 | for (unsigned int i=0;i<(*radius).size();i++){ 97 | if (dist2 <= (*radius)[i]){ 98 | radius_cnt[i] +=1; 99 | } 100 | } 101 | } 102 | if (sample_iter%20==0){ 103 | std::cout << "ID "<& areas, float number){ 111 | for (unsigned int i=0;i=areas[i] && number < areas[i+1]){ 113 | return i; 114 | } 115 | } 116 | return 0; 117 | } 118 | 119 | 120 | int main(int argc, char* argv[]){ 121 | // If not given the sample position, we will randomly sample THREAD*10 disks 122 | // THREAD is the number of threads 123 | if (argc!=3){ 124 | std::cout << "Usage: ./evaluation mesh_path prediction_path [sampling_seed]\n"; 125 | return -1; 126 | } 127 | 128 | // read input tmesh 129 | Triangle_mesh tmesh; 130 | std::cout << "Read "<> tmesh; 133 | input.close(); 134 | face_iterator fit, fit_end; 135 | boost::tie(fit, fit_end) = faces(tmesh); 136 | std::vector 
face_vector(fit, fit_end); //face_vector of the tmesh 137 | const int face_num = face_vector.size(); 138 | std::cout <<"This mesh has "<< face_num << " faces"< face_areas(face_num+1,0.0); 146 | for (unsigned int i=0;i("f:normals", CGAL::NULL_VECTOR).first; 153 | //CGAL::Polygon_mesh_processing::compute_face_normals(tmesh,fnormals, 154 | // CGAL::Polygon_mesh_processing::parameters::vertex_point_map(tmesh.points()).geom_traits(Kernel())); 155 | 156 | //read the prediction points 157 | std::vector pred_points; 158 | //std::vector normals; 159 | std::ifstream stream(argv[2]); 160 | Point_3 p; 161 | Vector_3 v; 162 | while(stream >> p){ 163 | pred_points.push_back(p); 164 | //normals.push_back(v); 165 | } 166 | const int pred_cnt = pred_points.size(); 167 | std::cout << pred_cnt << " prediction points" << std::endl; 168 | 169 | // For each predicted point, find the coresponded nearest point on the surface. 170 | Surface_mesh_shortest_path shortest_paths(tmesh); 171 | Tree tree; 172 | shortest_paths.build_aabb_tree(tree); 173 | std::vector pred_face_locations(pred_cnt); 174 | std::vector pred_map_points(pred_cnt); 175 | std::vector nearest_distance(pred_cnt); 176 | std::vector gt_normals(pred_cnt); 177 | 178 | // find the basic file name of the mesh 179 | std::string pre = argv[2]; 180 | std::string token1; 181 | if (pre.find('/')== std::string::npos){ 182 | token1 = pre; 183 | } 184 | else{ 185 | token1 = pre.substr(pre.rfind("/")+1); 186 | } 187 | std::string token2 = pre.substr(0,pre.rfind(".")); 188 | const char* prefix = token2.c_str(); 189 | char filename[2048]; 190 | sprintf(filename, "%s_point2mesh_distance.xyz",prefix); 191 | std::ofstream distace_output(filename); 192 | 193 | // calculate the point2surface distance for each predicted point 194 | for (int i=0;i(pred_points[i],tree); 197 | pred_face_locations[i] = location; 198 | // convert the face location to xyz coordinate 199 | pred_map_points[i] = shortest_paths.point(location.first,location.second); 200 | 
//calculate the distance 201 | nearest_distance[i] = CGAL::sqrt(CGAL::squared_distance(pred_points[i],pred_map_points[i])); 202 | distace_output << pred_points[i][0]<<" "< sample_face_locations; 211 | if (argc>3){ //read sampling seeds from file 212 | std::ifstream sample_input(argv[3]); 213 | int id; double x1,x2,x3; 214 | while(sample_input >> id >> x1 >> x2>> x3){ 215 | sample_face_locations.push_back(Face_location(face_vector[id],{{x1,x2,x3}})); 216 | } 217 | } 218 | else{ // randomly pick the seeds on the surface of the mesh 219 | int id; double x1,x2,x3,total; 220 | CGAL::Random rand; 221 | sprintf(filename, "%s_sampling_seed.txt",prefix); 222 | std::ofstream sample_output(filename); 223 | for (int i=0;i sample_points(sample_cnt); 236 | for (unsigned int i=0;i precentage={0.002,0.004,0.006,0.008,0.010,0.012,0.015}; 244 | std::vector radius(precentage.size()); 245 | for (unsigned int i=0;i > density(sample_cnt,std::vector(radius.size())); 251 | auto start = std::chrono::system_clock::now(); 252 | pthread_t tid[THREAD]; 253 | int inds[THREAD+1]; 254 | int interval = ceil(sample_cnt*1.0/THREAD); 255 | for (int i=0;i elapsed_seconds = end-start; 293 | std::time_t end_time = std::chrono::system_clock::to_time_t(end); 294 | std::cout << "finished computation at " << std::ctime(&end_time) 295 | << "elapsed time: " << elapsed_seconds.count() << "s\n"; 296 | return 0; 297 | } 298 | 299 | -------------------------------------------------------------------------------- /pointnet2/pointnet2_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | ''' Modified based on: https://github.com/erikwijmans/Pointnet2_PyTorch ''' 7 | from __future__ import ( 8 | division, 9 | absolute_import, 10 | with_statement, 11 | print_function, 12 | unicode_literals, 13 | ) 14 | import torch 15 | from torch.autograd import Function 16 | import torch.nn as nn 17 | import pointnet2.pytorch_utils as pt_utils 18 | import sys 19 | 20 | try: 21 | import builtins 22 | except: 23 | import __builtin__ as builtins 24 | 25 | try: 26 | import pointnet2._ext as _ext 27 | except ImportError: 28 | if not getattr(builtins, "__POINTNET2_SETUP__", False): 29 | raise ImportError( 30 | "Could not import _ext module.\n" 31 | "Please see the setup instructions in the README: " 32 | "https://github.com/erikwijmans/Pointnet2_PyTorch/blob/master/README.rst" 33 | ) 34 | 35 | if False: 36 | # Workaround for type hints without depending on the `typing` module 37 | from typing import * 38 | 39 | 40 | class RandomDropout(nn.Module): 41 | def __init__(self, p=0.5, inplace=False): 42 | super(RandomDropout, self).__init__() 43 | self.p = p 44 | self.inplace = inplace 45 | 46 | def forward(self, X): 47 | theta = torch.Tensor(1).uniform_(0, self.p)[0] 48 | return pt_utils.feature_dropout_no_scaling(X, theta, self.train, self.inplace) 49 | 50 | 51 | class FurthestPointSampling(Function): 52 | @staticmethod 53 | def forward(ctx, xyz, npoint): 54 | # type: (Any, torch.Tensor, int) -> torch.Tensor 55 | r""" 56 | Uses iterative furthest point sampling to select a set of npoint features that have the largest 57 | minimum distance 58 | 59 | Parameters 60 | ---------- 61 | xyz : torch.Tensor 62 | (B, N, 3) tensor where N > npoint 63 | npoint : int32 64 | number of features in the sampled set 65 | 66 | Returns 67 | ------- 68 | torch.Tensor 69 | (B, npoint) tensor containing the set 70 | """ 71 | return _ext.furthest_point_sampling(xyz, npoint) 72 | 73 | @staticmethod 74 | def backward(xyz, a=None): 75 | return None, None 76 | 77 | 78 | furthest_point_sample = 
FurthestPointSampling.apply 79 | 80 | 81 | class GatherOperation(Function): 82 | @staticmethod 83 | def forward(ctx, features, idx): 84 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 85 | r""" 86 | 87 | Parameters 88 | ---------- 89 | features : torch.Tensor 90 | (B, C, N) tensor 91 | 92 | idx : torch.Tensor 93 | (B, npoint) tensor of the features to gather 94 | 95 | Returns 96 | ------- 97 | torch.Tensor 98 | (B, C, npoint) tensor 99 | """ 100 | 101 | _, C, N = features.size() 102 | 103 | ctx.for_backwards = (idx, C, N) 104 | 105 | return _ext.gather_points(features, idx) 106 | 107 | @staticmethod 108 | def backward(ctx, grad_out): 109 | idx, C, N = ctx.for_backwards 110 | 111 | grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N) 112 | return grad_features, None 113 | 114 | 115 | gather_operation = GatherOperation.apply 116 | 117 | 118 | class ThreeNN(Function): 119 | @staticmethod 120 | def forward(ctx, unknown, known): 121 | # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 122 | r""" 123 | Find the three nearest neighbors of unknown in known 124 | Parameters 125 | ---------- 126 | unknown : torch.Tensor 127 | (B, n, 3) tensor of known features 128 | known : torch.Tensor 129 | (B, m, 3) tensor of unknown features 130 | 131 | Returns 132 | ------- 133 | dist : torch.Tensor 134 | (B, n, 3) l2 distance to the three nearest neighbors 135 | idx : torch.Tensor 136 | (B, n, 3) index of 3 nearest neighbors 137 | """ 138 | dist2, idx = _ext.three_nn(unknown, known) 139 | 140 | return torch.sqrt(dist2), idx 141 | 142 | @staticmethod 143 | def backward(ctx, a=None, b=None): 144 | return None, None 145 | 146 | 147 | three_nn = ThreeNN.apply 148 | 149 | 150 | class ThreeInterpolate(Function): 151 | @staticmethod 152 | def forward(ctx, features, idx, weight): 153 | # type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor 154 | r""" 155 | Performs weight linear interpolation on 3 features 156 | 
Parameters 157 | ---------- 158 | features : torch.Tensor 159 | (B, c, m) Features descriptors to be interpolated from 160 | idx : torch.Tensor 161 | (B, n, 3) three nearest neighbors of the target features in features 162 | weight : torch.Tensor 163 | (B, n, 3) weights 164 | 165 | Returns 166 | ------- 167 | torch.Tensor 168 | (B, c, n) tensor of the interpolated features 169 | """ 170 | B, c, m = features.size() 171 | n = idx.size(1) 172 | 173 | ctx.three_interpolate_for_backward = (idx, weight, m) 174 | 175 | return _ext.three_interpolate(features, idx, weight) 176 | 177 | @staticmethod 178 | def backward(ctx, grad_out): 179 | # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor] 180 | r""" 181 | Parameters 182 | ---------- 183 | grad_out : torch.Tensor 184 | (B, c, n) tensor with gradients of ouputs 185 | 186 | Returns 187 | ------- 188 | grad_features : torch.Tensor 189 | (B, c, m) tensor with gradients of features 190 | 191 | None 192 | 193 | None 194 | """ 195 | idx, weight, m = ctx.three_interpolate_for_backward 196 | 197 | grad_features = _ext.three_interpolate_grad( 198 | grad_out.contiguous(), idx, weight, m 199 | ) 200 | 201 | return grad_features, None, None 202 | 203 | 204 | three_interpolate = ThreeInterpolate.apply 205 | 206 | 207 | class GroupingOperation(Function): 208 | @staticmethod 209 | def forward(ctx, features, idx): 210 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 211 | r""" 212 | 213 | Parameters 214 | ---------- 215 | features : torch.Tensor 216 | (B, C, N) tensor of features to group 217 | idx : torch.Tensor 218 | (B, npoint, nsample) tensor containing the indicies of features to group with 219 | 220 | Returns 221 | ------- 222 | torch.Tensor 223 | (B, C, npoint, nsample) tensor 224 | """ 225 | B, nfeatures, nsample = idx.size() 226 | _, C, N = features.size() 227 | 228 | ctx.for_backwards = (idx, N) 229 | 230 | return _ext.group_points(features, idx) 231 | 232 | @staticmethod 233 | def 
backward(ctx, grad_out): 234 | # type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor] 235 | r""" 236 | 237 | Parameters 238 | ---------- 239 | grad_out : torch.Tensor 240 | (B, C, npoint, nsample) tensor of the gradients of the output from forward 241 | 242 | Returns 243 | ------- 244 | torch.Tensor 245 | (B, C, N) gradient of the features 246 | None 247 | """ 248 | idx, N = ctx.for_backwards 249 | 250 | grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N) 251 | 252 | return grad_features, None 253 | 254 | 255 | grouping_operation = GroupingOperation.apply 256 | 257 | 258 | class BallQuery(Function): 259 | @staticmethod 260 | def forward(ctx, radius, nsample, xyz, new_xyz): 261 | # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor 262 | r""" 263 | 264 | Parameters 265 | ---------- 266 | radius : float 267 | radius of the balls 268 | nsample : int 269 | maximum number of features in the balls 270 | xyz : torch.Tensor 271 | (B, N, 3) xyz coordinates of the features 272 | new_xyz : torch.Tensor 273 | (B, npoint, 3) centers of the ball query 274 | 275 | Returns 276 | ------- 277 | torch.Tensor 278 | (B, npoint, nsample) tensor with the indicies of the features that form the query balls 279 | """ 280 | return _ext.ball_query(new_xyz, xyz, radius, nsample) 281 | 282 | @staticmethod 283 | def backward(ctx, a=None): 284 | return None, None, None, None 285 | 286 | 287 | ball_query = BallQuery.apply 288 | 289 | 290 | class QueryAndGroup(nn.Module): 291 | r""" 292 | Groups with a ball query of radius 293 | 294 | Parameters 295 | --------- 296 | radius : float32 297 | Radius of ball 298 | nsample : int32 299 | Maximum number of features to gather in the ball 300 | """ 301 | 302 | def __init__(self, radius, nsample, use_xyz=True, ret_grouped_xyz=False, normalize_xyz=False, sample_uniformly=False, ret_unique_cnt=False): 303 | # type: (QueryAndGroup, float, int, bool) -> None 304 | super(QueryAndGroup, self).__init__() 305 | 
self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz 306 | self.ret_grouped_xyz = ret_grouped_xyz 307 | self.normalize_xyz = normalize_xyz 308 | self.sample_uniformly = sample_uniformly 309 | self.ret_unique_cnt = ret_unique_cnt 310 | if self.ret_unique_cnt: 311 | assert(self.sample_uniformly) 312 | 313 | def forward(self, xyz, new_xyz, features=None): 314 | # type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor] 315 | r""" 316 | Parameters 317 | ---------- 318 | xyz : torch.Tensor 319 | xyz coordinates of the features (B, N, 3) 320 | new_xyz : torch.Tensor 321 | centriods (B, npoint, 3) 322 | features : torch.Tensor 323 | Descriptors of the features (B, C, N) 324 | 325 | Returns 326 | ------- 327 | new_features : torch.Tensor 328 | (B, 3 + C, npoint, nsample) tensor 329 | """ 330 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz) 331 | 332 | if self.sample_uniformly: 333 | unique_cnt = torch.zeros((idx.shape[0], idx.shape[1])) 334 | for i_batch in range(idx.shape[0]): 335 | for i_region in range(idx.shape[1]): 336 | unique_ind = torch.unique(idx[i_batch, i_region, :]) 337 | num_unique = unique_ind.shape[0] 338 | unique_cnt[i_batch, i_region] = num_unique 339 | sample_ind = torch.randint(0, num_unique, (self.nsample - num_unique,), dtype=torch.long) 340 | all_ind = torch.cat((unique_ind, unique_ind[sample_ind])) 341 | idx[i_batch, i_region, :] = all_ind 342 | 343 | 344 | xyz_trans = xyz.transpose(1, 2).contiguous() 345 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 346 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 347 | if self.normalize_xyz: 348 | grouped_xyz /= self.radius 349 | #print(features.shape,idx.shape) 350 | if features is not None: 351 | grouped_features = grouping_operation(features, idx) 352 | if self.use_xyz: 353 | new_features = torch.cat( 354 | [grouped_xyz, grouped_features], dim=1 355 | ) # (B, C + 3, npoint, nsample) 356 | else: 357 | new_features = 
grouped_features 358 | else: 359 | assert ( 360 | self.use_xyz 361 | ), "Cannot have not features and not use xyz as a feature!" 362 | new_features = grouped_xyz 363 | 364 | ret = [new_features] 365 | if self.ret_grouped_xyz: 366 | ret.append(grouped_xyz) 367 | if self.ret_unique_cnt: 368 | ret.append(unique_cnt) 369 | if len(ret) == 1: 370 | return ret[0] 371 | else: 372 | return tuple(ret) 373 | 374 | 375 | class GroupAll(nn.Module): 376 | r""" 377 | Groups all features 378 | 379 | Parameters 380 | --------- 381 | """ 382 | 383 | def __init__(self, use_xyz=True, ret_grouped_xyz=False): 384 | # type: (GroupAll, bool) -> None 385 | super(GroupAll, self).__init__() 386 | self.use_xyz = use_xyz 387 | 388 | def forward(self, xyz, new_xyz, features=None): 389 | # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor] 390 | r""" 391 | Parameters 392 | ---------- 393 | xyz : torch.Tensor 394 | xyz coordinates of the features (B, N, 3) 395 | new_xyz : torch.Tensor 396 | Ignored 397 | features : torch.Tensor 398 | Descriptors of the features (B, C, N) 399 | 400 | Returns 401 | ------- 402 | new_features : torch.Tensor 403 | (B, C + 3, 1, N) tensor 404 | """ 405 | 406 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) 407 | if features is not None: 408 | grouped_features = features.unsqueeze(2) 409 | if self.use_xyz: 410 | new_features = torch.cat( 411 | [grouped_xyz, grouped_features], dim=1 412 | ) # (B, 3 + C, 1, N) 413 | else: 414 | new_features = grouped_features 415 | else: 416 | new_features = grouped_xyz 417 | 418 | if self.ret_grouped_xyz: 419 | return new_features, grouped_xyz 420 | else: 421 | return new_features 422 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 19 | 20 | 25 | 26 | 27 | 29 | 30 | 31 | 32 | 33 | 37 | 38 | 39 | 40 | 41 | 42 | 
43 | 44 | 45 | 65 | 66 | 67 | 87 | 88 | 89 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 |