├── README.md ├── create_pcn_h5.py ├── data ├── __init__.py ├── completion3d │ └── test.list ├── dataset.py └── pcn │ └── test.list ├── main.py ├── main_benchmark.py ├── models ├── __init__.py ├── full_model.py ├── model.py └── model_utils.py ├── overview.png ├── run.sh ├── run_test.sh ├── run_test_benchmark.sh ├── test.py ├── test_benchmark.py ├── test_per_category_l2cd.py ├── train ├── __init__.py └── train.py └── utils ├── ChamferDistancePytorch ├── LICENSE ├── README.md ├── chamfer2D │ ├── chamfer2D.cu │ ├── chamfer_cuda.cpp │ ├── dist_chamfer_2D.py │ └── setup.py ├── chamfer3D │ ├── chamfer3D.cu │ ├── chamfer_cuda.cpp │ ├── dist_chamfer_3D.py │ └── setup.py ├── chamfer5D │ ├── chamfer5D.cu │ ├── chamfer_cuda.cpp │ ├── dist_chamfer_5D.py │ └── setup.py ├── chamfer_python.py ├── fscore.py └── unit_test.py ├── MDS ├── MDS.cpp ├── MDS_cuda.cu ├── MDS_module.py ├── clean.sh ├── run_compile.sh └── setup.py ├── Pointnet2.PyTorch ├── LICENSE ├── README.md ├── pointnet2 │ ├── build │ │ └── temp.linux-x86_64-3.6 │ │ │ ├── build.ninja │ │ │ └── src │ │ │ ├── ball_query_gpu.o │ │ │ ├── group_points_gpu.o │ │ │ ├── interpolate_gpu.o │ │ │ ├── pointnet2_api.o │ │ │ └── sampling_gpu.o │ ├── pointnet2.egg-info │ │ ├── PKG-INFO │ │ ├── SOURCES.txt │ │ ├── dependency_links.txt │ │ └── top_level.txt │ ├── pointnet2_modules.py │ ├── pointnet2_utils.py │ ├── pytorch_utils.py │ ├── setup.py │ └── src │ │ ├── ball_query.cpp │ │ ├── ball_query_gpu.cu │ │ ├── ball_query_gpu.h │ │ ├── cuda_utils.h │ │ ├── group_points.cpp │ │ ├── group_points_gpu.cu │ │ ├── group_points_gpu.h │ │ ├── interpolate.cpp │ │ ├── interpolate_gpu.cu │ │ ├── interpolate_gpu.h │ │ ├── pointnet2_api.cpp │ │ ├── sampling.cpp │ │ ├── sampling_gpu.cu │ │ └── sampling_gpu.h └── tools │ ├── _init_path.py │ ├── data │ └── KITTI │ │ └── ImageSets │ │ ├── test.txt │ │ ├── train.txt │ │ ├── trainval.txt │ │ └── val.txt │ ├── dataset.py │ ├── kitti_utils.py │ ├── pointnet2_msg.py │ └── train_and_eval.py ├── __init__.py ├── emd ├── CDEMD.png ├── README.md ├── clean.sh ├── emd.cpp ├── emd_cuda.cu ├── emd_module.py ├── run_compile.sh └── setup.py ├── expansion_penalty ├── clean.sh ├── expansion_penalty.cpp ├── expansion_penalty_cuda.cu ├── expansion_penalty_module.py ├── run_compile.sh └── setup.py ├── generate_excel_results.py ├── utils.py └── vis_pcd.py /README.md: -------------------------------------------------------------------------------- 1 | # Point Cloud Completion via Skeleton-Detail Transformer 2 | Codes for Point Cloud Completion via Skeleton-Detail Transformer. IEEE Transactions on Visualization and Computer Graphics (TVCG), 2022. See [IEEE PDF](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9804851). 3 | 4 | ![overview](overview.png) 5 | In this work, we present a coarse-to-fine completion framework, which makes full use of both neighboring and long-distance region cues for point cloud completion. Our network leverages a Skeleton-Detail Transformer, which contains cross-attention and self-attention layers, to fully explore the correlation from local patterns to global shape and utilize it to enhance the overall skeleton. Also, we propose a selective attention mechanism to save memory usage in the attention process without significantly affecting performance. 6 | 7 | ### 1) Pre-requisites 8 | * Python3 9 | * CUDA 10 | * pytorch 11 | * open3d-python 12 | 13 | This code is built using Pytorch 1.7.1 with CUDA 10.2 and tested on Ubuntu 18.04 with Python 3.6. 
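A minimal environment sketch matching the tested configuration (the conda/pip commands and the extra packages such as `tensorpack`, `lmdb`, and `visdom` are assumptions inferred from the scripts in this repository, not an official setup):

```shell
# Assumed setup: Python 3.6 + PyTorch 1.7.1 + CUDA 10.2, as stated above.
conda create -n sdt python=3.6 -y
conda activate sdt
conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.2 -c pytorch -y
# open3d / h5py / matplotlib are used by the test scripts; tensorpack + lmdb by create_pcn_h5.py;
# visdom is assumed for the training curves plotted in train/train.py.
pip install open3d h5py matplotlib tensorpack lmdb visdom
```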
14 | 15 | ### 2) Compile 3rd-party libs 16 | The libs are included under `utils/`; compile them first. Each subfolder contains its own `Readme.md` with build instructions (a hedged compile sketch is also given at the end of this README). 17 | 18 | ### 3) Download pre-trained models 19 | Download the pre-trained models (`trained_model` folder) from [Google Drive](https://drive.google.com/file/d/1OlfBdK0707iGLkdn18VgXTrkSxylzqcj/view?usp=sharing) and put them in the `trained_model` directory. 20 | 21 | ### 4) Testing 22 | For PCN: 23 | 1. Download the ShapeNet test data from [Google Drive](https://drive.google.com/drive/folders/1o2Kwi-0127mVZjRJskY9tquJKTD-67Jm?usp=sharing) and put it in the `data/pcn` folder. We use the same testing data as the [PCN](https://www.cs.cmu.edu/~wyuan1/pcn/) project, but in `h5` format. 24 | 2. Run `sh run_test.sh`, after setting `model_dir` to the folder containing your pre-trained model and `datapath` to the testing files. 25 | 26 | For Completion3D: 27 | 1. Download the test data from [Google Drive](https://drive.google.com/drive/folders/1o2Kwi-0127mVZjRJskY9tquJKTD-67Jm?usp=sharing) or [Completion3D](https://completion3d.stanford.edu/) and put it in the `data/completion3d` folder. 28 | 2. Run `sh run_test_benchmark.sh` to generate the `submission.zip` file for the Completion3D benchmark. 29 | 30 | 31 | ### 5) Training 32 | For PCN: 33 | 1. The training data come from the [PCN repository](https://github.com/wentaoyuan/pcn); download the training (`train.lmdb`, `train.lmdb-lock`) and validation (`valid.lmdb`, `valid.lmdb-lock`) data from the `shapenet` directory of the training-set link provided there. 34 | 2. Run `python create_pcn_h5.py` to convert them into training and validation files in `.h5` format. 35 | 3. Run `sh run.sh` to train. 36 | 37 | For Completion3D: 38 | You can download the training files directly from the Completion3D benchmark. Run `sh run.sh` with `dataset` set to `Completion3D`. 39 | 40 | ## [Acknowledgement] 41 | Our code is partly based on [ECG](https://github.com/paul007pl/ECG) and [VRCNET](https://github.com/paul007pl/VRCNet); we sincerely thank the authors for their contributions.
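As a supplement to Section 2, here is a hedged sketch of building the bundled extensions from the repository root; the authoritative steps are the `Readme.md` / `run_compile.sh` files inside each subfolder of `utils/`:

```shell
# Pointnet2 CUDA ops (furthest point sampling, grouping, ball query, three_nn).
cd utils/Pointnet2.PyTorch/pointnet2 && python setup.py install && cd -
# Earth Mover's Distance, expansion penalty and MDS modules (each also ships a run_compile.sh).
cd utils/emd && python setup.py install && cd -
cd utils/expansion_penalty && python setup.py install && cd -
cd utils/MDS && python setup.py install && cd -
# ChamferDistancePytorch can be JIT-compiled on first import (the dist_chamfer_* wrappers
# fall back to torch.utils.cpp_extension.load); chamfer2D/3D/5D also ship a setup.py to pre-build.
```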
42 | 43 | -------------------------------------------------------------------------------- /create_pcn_h5.py: -------------------------------------------------------------------------------- 1 | from tensorpack import dataflow 2 | import h5py 3 | import os 4 | 5 | #generate train files 6 | df = dataflow.LMDBSerializer.load('data/train.lmdb', shuffle=False) 7 | print('df size:', df.size()) 8 | ds = dataflow.PrefetchData(df, nr_prefetch=500, nr_proc=1) 9 | size = df.size() 10 | output_base_folder = 'data/pcn/train' 11 | if not os.path.exists(output_base_folder): 12 | os.makedirs(output_base_folder) 13 | f_list = open('data/pcn/train.list', 'w') 14 | i = 0 15 | for id, input, gt in ds.get_data(): 16 | ids = id.split('_') 17 | category_id = ids[0] 18 | model_id = ids[1] 19 | idx = len(ids) - 3 20 | 21 | partial_output_folder = os.path.join(output_base_folder, 'partial', category_id) 22 | gt_output_folder = os.path.join(output_base_folder, 'gt', category_id) 23 | if not os.path.exists(partial_output_folder): 24 | os.makedirs(partial_output_folder) 25 | if not os.path.exists(gt_output_folder): 26 | os.makedirs(gt_output_folder) 27 | 28 | f = h5py.File(os.path.join(partial_output_folder, '%s_%d.h5' % (model_id, idx)), 'w') 29 | f.create_dataset("data", data=input) 30 | f.close() 31 | 32 | f = h5py.File(os.path.join(gt_output_folder, '%s_%d.h5' % (model_id, idx)), 'w') 33 | f.create_dataset("data", data=gt) 34 | f.close() 35 | 36 | f_list.write(os.path.join(category_id, '%s_%d' % (model_id, idx))) 37 | if i != size-1: 38 | f_list.write('\n') 39 | f_list.close() 40 | 41 | #generate valid files 42 | df = dataflow.LMDBSerializer.load('data/valid.lmdb', shuffle=False) 43 | ds = dataflow.PrefetchData(df, nr_prefetch=500, nr_proc=1) 44 | size = df.size() 45 | output_base_folder = 'data/pcn/val' 46 | if not os.path.exists(output_base_folder): 47 | os.makedirs(output_base_folder) 48 | f_list = open('data/pcn/val.list', 'w') 49 | i = 0 50 | for id, input, gt in ds.get_data(): 51 | ids = id.split('_') 52 | category_id = ids[0] 53 | model_id = ids[1] 54 | idx = len(ids) - 3 55 | 56 | partial_output_folder = os.path.join(output_base_folder, 'partial', category_id) 57 | gt_output_folder = os.path.join(output_base_folder, 'gt', category_id) 58 | if not os.path.exists(partial_output_folder): 59 | os.makedirs(partial_output_folder) 60 | if not os.path.exists(gt_output_folder): 61 | os.makedirs(gt_output_folder) 62 | 63 | f = h5py.File(os.path.join(partial_output_folder, '%s.h5' % (model_id)), 'w') 64 | f.create_dataset("data", data=input) 65 | f.close() 66 | 67 | f = h5py.File(os.path.join(gt_output_folder, '%s.h5' % (model_id)), 'w') 68 | f.create_dataset("data", data=gt) 69 | f.close() 70 | 71 | f_list.write(os.path.join(category_id, '%s' % (model_id))) 72 | if i != size-1: 73 | f_list.write('\n') 74 | f_list.close() -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.utils.data as data 4 | import h5py 5 | import os 6 | 7 | def load_h5(path, verbose=False): 8 | if verbose: 9 | print("Loading %s \n" % (path)) 10 | f = h5py.File(path, 'r') 11 | cloud_data = np.array(f['data']) 12 | f.close() 13 | 14 | #return cloud_data.astype(np.float64) 15 | 
return cloud_data 16 | 17 | def pad_cloudN(P, Nin): 18 | """ Pad or subsample 3D Point cloud to Nin number of points """ 19 | N = P.shape[0] 20 | P = P[:].astype(np.float32) 21 | 22 | rs = np.random.random.__self__ 23 | choice = np.arange(N) 24 | if N > Nin: # need to subsample 25 | ii = rs.choice(N, Nin) 26 | choice = ii 27 | elif N < Nin: # need to pad by duplication 28 | ii = rs.choice(N, Nin - N) 29 | choice = np.concatenate([range(N),ii]) 30 | P = P[choice, :] 31 | 32 | return P 33 | 34 | class Completion3D(data.Dataset): 35 | def __init__(self, datapath, train=True, npoints=2048, use_mean_feature=0, benchmark=False): 36 | # train data only has input(2048) and gt(2048) 37 | self.npoints = npoints 38 | self.train = train 39 | self.use_mean_feature = use_mean_feature 40 | if train: 41 | split = 'train' 42 | elif benchmark: 43 | split = 'test' 44 | else: 45 | split = 'val' 46 | 47 | DATA_PATH = datapath 48 | 49 | self.partial_data_paths = [os.path.join(DATA_PATH, split, 'partial', k.rstrip()+ '.h5') for k in open(DATA_PATH + '/%s.list' % (split)).readlines()] #sorted() 50 | 51 | if benchmark: 52 | self.gt_data_paths = self.partial_data_paths 53 | else: 54 | self.gt_data_paths = [os.path.join(DATA_PATH, split, 'gt', k.rstrip() + '.h5') for k in 55 | open(DATA_PATH + '/%s.list' % (split)).readlines()] #sorted() 56 | #print(self.partial_data_paths, np.array(self.partial_data_paths).shape) 57 | self.len = np.array(self.partial_data_paths).shape[0] 58 | print(self.len) 59 | 60 | def __len__(self): 61 | return self.len 62 | 63 | def __getitem__(self, index): 64 | partial = torch.from_numpy(np.array(load_h5(self.partial_data_paths[index]))).float() 65 | #print('partial.shape', partial.shape) 66 | complete = torch.from_numpy(np.array(load_h5(self.gt_data_paths[index]))).float() 67 | label = self.partial_data_paths[index] 68 | if self.use_mean_feature == 1: 69 | mean_feature_input = torch.from_numpy(np.array(self.mean_feature[label])).float() 70 | return label, partial, complete, mean_feature_input 71 | else: 72 | return label, partial, complete 73 | 74 | class PCN(data.Dataset): 75 | def __init__(self, datapath, train=True, npoints=2048, use_mean_feature=0, test=False): 76 | # train data only has input(2048) and gt(2048) 77 | self.npoints = npoints 78 | self.train = train 79 | self.use_mean_feature = use_mean_feature 80 | if train: 81 | split = 'train' 82 | elif test: 83 | split = 'test' 84 | else: 85 | split = 'val' 86 | 87 | DATA_PATH = datapath 88 | 89 | self.partial_data_paths = [os.path.join(DATA_PATH, split, 'partial', k.rstrip()+ '.h5') for k in open(DATA_PATH + '/%s.list' % (split)).readlines()] #sorted() 90 | 91 | self.gt_data_paths = [os.path.join(DATA_PATH, split, 'gt', k.rstrip() + '.h5') for k in 92 | open(DATA_PATH + '/%s.list' % (split)).readlines()] #sorted() 93 | #print(self.partial_data_paths, np.array(self.partial_data_paths).shape) 94 | self.len = np.array(self.partial_data_paths).shape[0] 95 | print(self.len) 96 | 97 | def __len__(self): 98 | return self.len 99 | 100 | def __getitem__(self, index): 101 | partial = torch.from_numpy(pad_cloudN(np.array(load_h5(self.partial_data_paths[index])), 2048)).float() 102 | #print('partial.shape', partial.shape) 103 | complete = torch.from_numpy(np.array(load_h5(self.gt_data_paths[index]))).float() 104 | label = self.partial_data_paths[index] 105 | if self.use_mean_feature == 1: 106 | mean_feature_input = torch.from_numpy(np.array(self.mean_feature[label])).float() 107 | return label, partial, complete, mean_feature_input 108 | else: 
109 | return label, partial, complete 110 | 111 | 112 | class SCAN(data.Dataset): 113 | def __init__(self, datapath, npoints=2048): 114 | # train data only has input(2048) and gt(2048) 115 | self.npoints = npoints 116 | 117 | DATA_PATH = datapath 118 | 119 | self.partial_data_paths = [os.path.join(DATA_PATH, k.rstrip() + '.h5') for k in open(DATA_PATH + '/data_list.txt').readlines()] #sorted() 120 | 121 | self.gt_data_paths = self.partial_data_paths 122 | 123 | self.len = np.array(self.partial_data_paths).shape[0] 124 | print(self.len) 125 | 126 | def __len__(self): 127 | return self.len 128 | 129 | def __getitem__(self, index): 130 | partial = torch.from_numpy(pad_cloudN(np.array(load_h5(self.partial_data_paths[index])), 2048)).float() 131 | #print('partial.shape', partial.shape) 132 | complete = torch.from_numpy(np.array(load_h5(self.gt_data_paths[index]))).float() 133 | label = self.partial_data_paths[index] 134 | 135 | return label, partial, complete 136 | 137 | 138 | class KITTI(data.Dataset): 139 | def __init__(self, datapath, npoints=2048): 140 | # train data only has input(2048) and gt(2048) 141 | self.npoints = npoints 142 | 143 | DATA_PATH = datapath 144 | 145 | self.partial_data_paths = [os.path.join(DATA_PATH, 'cars_h5', k.rstrip() + '.h5') for k in open(DATA_PATH + '/data_list.txt').readlines()] #sorted() 146 | 147 | self.gt_data_paths = self.partial_data_paths 148 | 149 | self.len = np.array(self.partial_data_paths).shape[0] 150 | print(self.len) 151 | 152 | def __len__(self): 153 | return self.len 154 | 155 | def __getitem__(self, index): 156 | partial = torch.from_numpy(pad_cloudN(np.array(load_h5(self.partial_data_paths[index])), 2048)).float() 157 | complete = torch.from_numpy(np.array(load_h5(self.gt_data_paths[index]))).float() 158 | label = self.partial_data_paths[index] 159 | 160 | return label, partial, complete 161 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import argparse 4 | from test import test 5 | from train.train import train 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser(description='Point Cloud Completion') 8 | 9 | # mode and dataset 10 | parser.add_argument('--mode', type=int, default=0, help='0 for train, 1 for test') 11 | parser.add_argument('--model_dir', type=str, default='/mnt/data2/zwx/ECG/log/PCT_CD_train/PCT_4SA_2021-02-26T16:10:21') # for test only 12 | parser.add_argument('--dataset', type=str, default='Completion3D', help='dataset') 13 | parser.add_argument('--datapath', type=str, default='data/completion3d', help='dataset path') 14 | # common args 15 | parser.add_argument('--batch_size', type=int, default=32, help='input batch size') 16 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=12) 17 | parser.add_argument('--nepoch', type=int, default=100, help='number of epochs to train for') 18 | parser.add_argument('--model_name', type=str, default='ECG', help='model to use') 19 | parser.add_argument('--load_model', type=str, default='', help='load model to resume training / start testing') 20 | parser.add_argument('--resume_epoch', type=int, default=0, help='which epoch to resume from') 21 | parser.add_argument('--num_points', type=int, default=2048, help='number of ground truth points') 22 | parser.add_argument('--log_env', type=str, default="ecg_2048", help='subfolder name inside log/__train/') 23 | parser.add_argument('--loss', 
type=str, default='EMD', help='train loss type; CD or EMD') 24 | parser.add_argument('--manual_seed', type=str, default='', help='manual seed') 25 | parser.add_argument('--lr', type=float, default=0.0001, help='learning rate') # cascade, msn, pcn:0.0001, topnet:0.5e-2 26 | parser.add_argument('--use_mean_feature', type=int, default=0, help='0 if not using, 1 if using') 27 | 28 | args = parser.parse_args() 29 | 30 | #assert args.model_name in list(models_dict.keys()) 31 | assert args.loss == 'EMD' or args.loss == 'CD' 32 | 33 | if args.mode == 0: 34 | print('args.num_points in train', args.num_points) 35 | train(args) 36 | else: 37 | test(args) 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /main_benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | #os.environ["CUDA_VISIBLE_DEVICES"] = "2,3" 3 | 4 | import argparse 5 | from test_benchmark import test 6 | from train.train import train 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser(description='Point Cloud Completion') 9 | 10 | # mode and dataset 11 | parser.add_argument('--mode', type=int, default=0, help='0 for train, 1 for test') 12 | parser.add_argument('--model_dir', type=str, default='/mnt/data2/zwx/ECG/log/PCT_CD_train/PCT_4SA_2021-02-26T16:10:21') # for test only 13 | parser.add_argument('--dataset', type=str, default='Completion3D', help='dataset') 14 | parser.add_argument('--datapath', type=str, default='data/completion3d', help='dataset path') 15 | # common args 16 | parser.add_argument('--batch_size', type=int, default=32, help='input batch size') 17 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=12) 18 | parser.add_argument('--nepoch', type=int, default=100, help='number of epochs to train for') 19 | parser.add_argument('--model_name', type=str, default='ECG', help='model to use') 20 | parser.add_argument('--load_model', type=str, default='', help='load model to resume training / start testing') 21 | parser.add_argument('--resume_epoch', type=int, default=0, help='which epoch to resume from') 22 | parser.add_argument('--num_points', type=int, default=2048, help='number of ground truth points') 23 | parser.add_argument('--log_env', type=str, default="ecg_2048", help='subfolder name inside log/__train/') 24 | parser.add_argument('--loss', type=str, default='EMD', help='train loss type; CD or EMD') 25 | parser.add_argument('--manual_seed', type=str, default='', help='manual seed') 26 | parser.add_argument('--lr', type=float, default=0.0001, help='learning rate') # cascade, msn, pcn:0.0001, topnet:0.5e-2 27 | parser.add_argument('--use_mean_feature', type=int, default=0, help='0 if not using, 1 if using') 28 | 29 | args = parser.parse_args() 30 | 31 | #assert args.model_name in list(models_dict.keys()) 32 | assert args.loss == 'EMD' or args.loss == 'CD' 33 | 34 | if args.mode == 0: 35 | print('args.num_points in train', args.num_points) 36 | train(args) 37 | else: 38 | test(args) 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/full_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn 
as nn 2 | import torch 3 | import sys 4 | import os 5 | proj_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 6 | sys.path.append(os.path.join(proj_dir, "utils/emd")) 7 | # import emd_module as emd 8 | sys.path.append(os.path.join(proj_dir, "utils/ChamferDistancePytorch")) 9 | from chamfer3D import dist_chamfer_3D 10 | chamLoss = dist_chamfer_3D.chamfer_3DDist() 11 | sys.path.append(os.path.join(proj_dir, "utils/Pointnet2.PyTorch/pointnet2")) 12 | import pointnet2_utils as pn2 13 | from models.model_utils import get_uniform_loss,get_repulsion_loss 14 | 15 | 16 | class FullModel(nn.Module): 17 | def __init__(self, model): 18 | super(FullModel, self).__init__() 19 | self.model = model 20 | #self.EMD = emd.emdModule() 21 | 22 | def forward(self, inputs, gt, eps, iters, EMD=True, CD=True): 23 | cur_bs = inputs.size()[0] 24 | output1, output2 = self.model(inputs) 25 | gt = gt[:, :, :3] 26 | 27 | emd1 = emd2 = cd_p1 = cd_p2 = cd_t1 = cd_t2 = origin_cd_p1 = origin_cd_p2 = torch.tensor([0], dtype=torch.float32).cuda() 28 | 29 | # if EMD: 30 | # num_coarse = self.model.num_coarse 31 | # gt_fps = pn2.gather_operation(gt.transpose(1, 2).contiguous(), 32 | # pn2.furthest_point_sample(gt, num_coarse)).transpose(1, 2).contiguous() 33 | # 34 | # dist1, _ = self.EMD(output1, gt_fps, eps, iters) 35 | # emd1 = torch.sqrt(dist1).mean(1) 36 | # 37 | # dist2, _ = self.EMD(output2, gt, eps, iters) 38 | # emd2 = torch.sqrt(dist2).mean(1) 39 | # 40 | # # CD loss 41 | if CD: 42 | dist11, dist12, _, _ = chamLoss(gt, output1) 43 | cd_p1 = (torch.sqrt(dist11).mean(1) + torch.sqrt(dist12).mean(1)) / 2 44 | cd_t1 = (dist11.mean(1) + dist12.mean(1)) 45 | 46 | dist21, dist22, _, _ = chamLoss(gt, output2) 47 | cd_p2 = (torch.sqrt(dist21).mean(1) + torch.sqrt(dist22).mean(1)) / 2 48 | cd_t2 = (dist21.mean(1) + dist22.mean(1)) 49 | 50 | # dist31, dist32, _, _ = chamLoss(inputs, output1) 51 | # dist41, dist42, _, _ = chamLoss(inputs, output2) 52 | # origin_cd_p1 = (torch.sqrt(dist31).mean(1) + torch.sqrt(dist32).mean(1)) / 2 53 | # origin_cd_p2 = (torch.sqrt(dist41).mean(1) + torch.sqrt(dist42).mean(1)) / 2 54 | 55 | 56 | # u1 = get_uniform_loss(output1) 57 | # u2 = get_uniform_loss(output2) 58 | 59 | u1 = u2 = torch.tensor([0], dtype=torch.float32).cuda() 60 | # u2 = get_repulsion_loss(output2) 61 | 62 | return output1, output2, emd1, emd2, cd_p1, cd_p2, cd_t1, cd_t2, u1, u2, origin_cd_p1, origin_cd_p2 63 | -------------------------------------------------------------------------------- /models/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import os 4 | import sys 5 | proj_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 6 | sys.path.append(os.path.join(proj_dir, "utils/Pointnet2.PyTorch/pointnet2")) 7 | import pointnet2_utils as pn2 8 | 9 | def gen_grid(num_grid_point): 10 | x = torch.linspace(-0.05, 0.05, num_grid_point) 11 | x, y = torch.meshgrid(x, x) 12 | grid = torch.stack([x, y], axis=-1).view(2, num_grid_point ** 2) 13 | return grid 14 | 15 | 16 | def gen_1d_grid(num_grid_point): 17 | x = torch.linspace(-0.05, 0.05, num_grid_point) 18 | grid = x.view(1, num_grid_point) 19 | return grid 20 | 21 | 22 | def gen_grid_up(up_ratio, grid_size=0.2): 23 | sqrted = int(math.sqrt(up_ratio)) + 1 24 | for i in range(1, sqrted + 1).__reversed__(): 25 | if (up_ratio % i) == 0: 26 | num_x = i 27 | num_y = up_ratio // i 28 | break 29 | 30 | grid_x = torch.linspace(-grid_size, grid_size, steps=num_x) 31 | grid_y = 
torch.linspace(-grid_size, grid_size, steps=num_y) 32 | 33 | x, y = torch.meshgrid(grid_x, grid_y) # x, y shape: (2, 1) 34 | grid = torch.stack([x, y], dim=-1).view(-1, 2).transpose(0, 1).contiguous() 35 | return grid 36 | 37 | 38 | def symmetric_sample(points, num=512): 39 | p1_idx = pn2.furthest_point_sample(points, num) 40 | input_fps = pn2.gather_operation(points.transpose(1, 2).contiguous(), p1_idx).transpose(1, 2).contiguous() 41 | x = torch.unsqueeze(input_fps[:, :, 0], dim=2) 42 | y = torch.unsqueeze(input_fps[:, :, 1], dim=2) 43 | z = torch.unsqueeze(-input_fps[:, :, 2], dim=2) 44 | input_fps_flip = torch.cat([x, y, z], dim=2) 45 | input_fps = torch.cat([input_fps, input_fps_flip], dim=1) 46 | return input_fps 47 | 48 | 49 | def index_points(points, idx): 50 | """ 51 | Input: 52 | points: input points data, [B, N, C] 53 | idx: sample index data, [B, S] 54 | Return: 55 | new_points:, indexed points data, [B, S, C] 56 | """ 57 | device = points.device 58 | B = points.shape[0] 59 | view_shape = list(idx.shape) 60 | view_shape[1:] = [1] * (len(view_shape) - 1) 61 | repeat_shape = list(idx.shape) 62 | repeat_shape[0] = 1 63 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 64 | new_points = points[batch_indices, idx, :] 65 | return new_points 66 | 67 | 68 | def knn_point(pk, point_input, point_output): 69 | m = point_output.size()[1] 70 | n = point_input.size()[1] 71 | 72 | inner = -2 * torch.matmul(point_output, point_input.transpose(2, 1).contiguous()) 73 | xx = torch.sum(point_output ** 2, dim=2, keepdim=True).repeat(1, 1, n) 74 | yy = torch.sum(point_input ** 2, dim=2, keepdim=False).unsqueeze(1).repeat(1, m, 1) 75 | pairwise_distance = -xx - inner - yy 76 | dist, idx = pairwise_distance.topk(k=pk, dim=-1) 77 | 78 | return dist, idx 79 | 80 | 81 | def edge_preserve_sampling(feature_input, point_input, num_samples, k=10): 82 | batch_size = feature_input.size()[0] 83 | feature_size = feature_input.size()[1] 84 | num_points = feature_input.size()[2] 85 | 86 | p_idx = pn2.furthest_point_sample(point_input, num_samples) 87 | point_output = pn2.gather_operation(point_input.transpose(1, 2).contiguous(), p_idx).transpose(1, 88 | 2).contiguous() # B M 3 89 | 90 | pk = int(min(k, num_points)) 91 | _, pn_idx = knn_point(pk, point_input, point_output) 92 | pn_idx = pn_idx.detach().int() # B M pk 93 | # print(pn_idx.size()) 94 | 95 | # neighbor_feature = pn2.grouping_operation(feature_input, pn_idx) 96 | # neighbor_feature = index_points(feature_input.transpose(1,2).contiguous(), pn_idx).permute(0, 3, 1, 2) 97 | neighbor_feature = pn2.gather_operation(feature_input, pn_idx.view(batch_size, num_samples * pk)).view(batch_size, 98 | feature_size, 99 | num_samples, 100 | pk) 101 | neighbor_feature, _ = torch.max(neighbor_feature, 3) 102 | 103 | center_feature = pn2.grouping_operation(feature_input, p_idx.unsqueeze(2)).view(batch_size, -1, num_samples) 104 | 105 | net = torch.cat((center_feature, neighbor_feature), 1) 106 | 107 | return net, p_idx, pn_idx, point_output 108 | 109 | 110 | def three_nn_upsampling(target_points, source_points): 111 | dist, idx = pn2.three_nn(target_points, source_points) 112 | dist = torch.max(dist, torch.ones(1).cuda() * 1e-10) 113 | norm = torch.sum((1.0 / dist), 2, keepdim=True) 114 | norm = norm.repeat(1, 1, 3) 115 | weight = (1.0 / dist) / norm 116 | 117 | return idx, weight 118 | 119 | 120 | def knn(x, k): 121 | inner = -2 * torch.matmul(x.transpose(2, 1).contiguous(), x) 122 | xx = torch.sum(x ** 2, dim=1, 
keepdim=True) 123 | pairwise_distance = -xx - inner - xx.transpose(2, 1).contiguous() 124 | 125 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k) 126 | return idx 127 | 128 | 129 | def get_graph_feature(x, k=20, minus_center=True): 130 | idx = knn(x, k=k) 131 | batch_size, num_points, _ = idx.size() 132 | device = torch.device('cuda') 133 | 134 | idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points 135 | 136 | idx = idx + idx_base 137 | 138 | idx = idx.view(-1) 139 | 140 | _, num_dims, _ = x.size() 141 | 142 | x = x.transpose(2, 1).contiguous() 143 | feature = x.view(batch_size * num_points, -1)[idx, :] 144 | feature = feature.view(batch_size, num_points, k, num_dims) 145 | x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1) 146 | 147 | if minus_center: 148 | feature = torch.cat((x, feature - x), dim=3).permute(0, 3, 1, 2) 149 | else: 150 | feature = torch.cat((x, feature), dim=3).permute(0, 3, 1, 2) 151 | return feature 152 | 153 | 154 | def get_uniform_loss(pcd, percentages=[0.004,0.006,0.008,0.010,0.012], radius=1.0): 155 | B, N, C = pcd.size() 156 | 157 | npoint = int(N * 0.05) 158 | # loss=[] 159 | loss = 0 160 | for p in percentages: 161 | nsample = int(N*p) 162 | r = math.sqrt(p*radius) 163 | disk_area = math.pi *(radius ** 2) * p/nsample 164 | new_xyz = pn2.gather_operation(pcd.transpose(1,2).contiguous(), pn2.furthest_point_sample(pcd, npoint)).transpose(1,2).contiguous() 165 | idx = pn2.ball_query(r, nsample, pcd, new_xyz) 166 | 167 | expect_len = math.sqrt(disk_area) 168 | 169 | grouped_pcd = pn2.grouping_operation(pcd.transpose(1,2).contiguous(), idx) 170 | grouped_pcd = grouped_pcd.permute(0, 2, 3, 1).contiguous().view(-1, nsample, 3) 171 | 172 | var, _ = knn_point(2, grouped_pcd, grouped_pcd) 173 | uniform_dis = -var[:, :, 1:] 174 | 175 | uniform_dis = torch.sqrt(torch.abs(uniform_dis+1e-8)) 176 | uniform_dis = torch.mean(uniform_dis, dim=-1) 177 | uniform_dis = ((uniform_dis - expect_len)**2 / (expect_len + 1e-8)) 178 | 179 | mean = torch.mean(uniform_dis) 180 | 181 | mean = mean*math.pow(p*100,2) 182 | 183 | loss += mean 184 | return loss/len(percentages) 185 | 186 | def get_repulsion_loss(pred, nn_size=20, radius=0.07, h=0.03, eps=1e-12): 187 | #print('pred shape', pred.shape) 188 | idx = pn2.ball_query(radius, nn_size, pred, pred) 189 | 190 | #_, idx = knn_point(nn_size, pred, pred) 191 | #idx = idx[:, :, 1:].to(torch.int32) # remove first one 192 | idx = idx.contiguous() # B, N, nn 193 | 194 | pred = pred.transpose(1, 2).contiguous() # B, 3, N 195 | grouped_points = pn2.grouping_operation(pred, idx) # (B, 3, N), (B, N, nn) => (B, 3, N, nn) 196 | 197 | grouped_points = grouped_points - pred.unsqueeze(-1) 198 | dist2 = torch.sum(grouped_points ** 2, dim=1) 199 | dist2 = torch.max(dist2, torch.tensor(eps).cuda()) 200 | dist = torch.sqrt(dist2) 201 | weight = torch.exp(- dist2 / h ** 2) 202 | 203 | uniform_loss = torch.mean((radius - dist) * weight) 204 | # uniform_loss = torch.mean(self.radius - dist * weight) # punet 205 | return uniform_loss -------------------------------------------------------------------------------- /overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XLechter/SDT/d87587cc70c4d7bb03fe4a795471984b5a2323ac/overview.png -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 
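# Training entry point; flag meanings are taken from the argparse help in main.py:
#   --mode 0 trains (1 tests), --datapath points at the prepared h5 data, --dataset selects
#   PCN or Completion3D, --num_points is the number of ground-truth points, --loss is CD or EMD,
#   and --log_env names the log subfolder. Adjust CUDA_VISIBLE_DEVICES, --batch_size and
#   --workers to your hardware.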
2 | 3 | CUDA_VISIBLE_DEVICES=0 python main.py --mode 0 --datapath /mnt/data1/zwx/completion3d/data/pcn --dataset PCN --batch_size 4 --workers 16 --nepoch 300 --model_name Model --num_points 2048 --log_env PCN --lr 1e-4 --loss CD --use_mean_feature 0 4 | -------------------------------------------------------------------------------- /run_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=7 python main.py --datapath /mnt/data1/zwx/completion3d/data/pcn --model_dir /mnt/data1/zwx/ECG/trained_model --mode 1 --dataset PCN --batch_size 1 --num_points 16384 --model_name Model --log_env PCN --lr 0.0001 --loss CD --use_mean_feature 0 --workers 16 --nepoch 300 -------------------------------------------------------------------------------- /run_test_benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=0 python main_benchmark.py --datapath /mnt/data2/zwx/completion3d/data/shapenet --model_dir /mnt/data2/zwx/ECG/trained_model --mode 1 --dataset Completion3D --batch_size 1 --num_points 16384 --model_name Model --log_env PCN --lr 0.0001 --loss CD --use_mean_feature 0 --workers 16 --nepoch 300 4 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from utils.utils import * 2 | import torch 3 | import os 4 | import h5py 5 | import sys 6 | import os 7 | proj_dir = os.path.dirname(os.path.abspath(__file__)) 8 | import open3d 9 | from models.model import Model 10 | import subprocess 11 | 12 | sys.path.append(os.path.join(proj_dir, "utils/ChamferDistancePytorch")) 13 | from chamfer3D import dist_chamfer_3D 14 | from fscore import fscore 15 | chamLoss = dist_chamfer_3D.chamfer_3DDist() 16 | 17 | 18 | def calculate_fscore(gt_array, pr_array, th = 0.01): 19 | '''Calculates the F-score between two point clouds with the corresponding threshold value.''' 20 | print('gt_array.shape', gt_array.shape) 21 | gt = open3d.geometry.PointCloud() 22 | gt.points = open3d.utility.Vector3dVector(gt_array) 23 | pr = open3d.geometry.PointCloud() 24 | pr.points = open3d.utility.Vector3dVector(pr_array) 25 | 26 | d1 = gt.compute_point_cloud_distance(pr) 27 | d2 = pr.compute_point_cloud_distance(gt) 28 | 29 | if len(d1) and len(d2): 30 | recall = float(sum(d < th for d in d2)) / float(len(d2)) 31 | precision = float(sum(d < th for d in d1)) / float(len(d1)) 32 | 33 | if recall + precision > 0: 34 | fscore = 2 * recall * precision / (recall + precision) 35 | else: 36 | fscore = 0 37 | else: 38 | fscore = 0 39 | precision = 0 40 | recall = 0 41 | 42 | return fscore, precision, recall 43 | 44 | 45 | def test(args): 46 | model_dir = args.model_dir 47 | log_test = LogString(open(os.path.join(model_dir, 'log_text.txt'), 'w')) 48 | 49 | if args.dataset == 'SCAN': 50 | dataset_test = SCAN(args.datapath, npoints=args.num_points) 51 | elif args.dataset == 'KITTI': 52 | dataset_test = KITTI(args.datapath, npoints=args.num_points) 53 | else: 54 | dataset_test = PCN(args.datapath, train=False, npoints=args.num_points, test=True) 55 | 56 | dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, 57 | shuffle=False, num_workers=int(args.workers)) 58 | dataset_length = len(dataset_test) 59 | 60 | epochs = ['model.pth'] 61 | for epoch in epochs: 62 | load_path = os.path.join(args.model_dir, epoch) 63 | net = 
eval(args.model_name)(num_coarse=1024, num_fine=args.num_points) 64 | args.load_model = load_path 65 | 66 | load_model(args, net, None, log_test, train=False) 67 | net.cuda() 68 | net.eval() 69 | log_test.log_string("Testing...") 70 | 71 | # pcd_file = h5py.File(os.path.join(args.model_dir, '%s_pcds.h5' % epoch.split('.')[0]), 'w') 72 | # pcd_file.create_dataset('output_pcds', (dataset_length, args.num_points, 3)) 73 | 74 | test_loss_cd_p = AverageValueMeter() 75 | test_loss_cd_t = AverageValueMeter() 76 | test_f1_score = AverageValueMeter() 77 | 78 | with torch.no_grad(): 79 | for i, data in enumerate(dataloader_test): 80 | label, inputs, gt = data 81 | 82 | inputs = inputs.float().cuda() 83 | gt = gt.float().cuda() 84 | inputs = inputs.transpose(2, 1).contiguous() 85 | 86 | coarse, output = net(inputs) 87 | 88 | # save pcd 89 | # pcd_index1 = args.batch_size * i 90 | # pcd_index2 = args.batch_size * (i + 1) 91 | # pcd_file['output_pcds'][pcd_index1:pcd_index2, :, :] = output.cpu().numpy() 92 | 93 | #g_input_pcd[f"{i}"] = inputs.cpu().numpy() 94 | #g_gt_pcd[f"{i}"] = gt.cpu().numpy() 95 | # g_output_pcd[f"{i}"] = output.cpu().numpy() 96 | # g_coarse_pcd[f"{i}"] = coarse.cpu().numpy() 97 | 98 | # EMD 99 | # dist, _ = EMD(output, gt, 0.004, 3000) 100 | # emd = torch.sqrt(dist).mean(1) 101 | 102 | # CD 103 | dist1, dist2, _, _ = chamLoss(gt, output) 104 | cd_p = (torch.sqrt(dist1).mean(1) + torch.sqrt(dist2).mean(1)) / 2 105 | cd_t = dist1.mean(1) + dist2.mean(1) 106 | # emd = cd_t 107 | 108 | # f1 109 | #f1, _, _ = fscore(dist1, dist2) 110 | 111 | f1, _, _ = calculate_fscore(gt.squeeze().cpu().numpy(), output.squeeze().cpu().numpy()) 112 | 113 | f1 = torch.tensor(f1) 114 | 115 | test_loss_cd_p.update(cd_p.mean().item()) 116 | test_loss_cd_t.update(cd_t.mean().item()) 117 | test_f1_score.update(f1.mean().item()) 118 | 119 | if i % 100 == 0: 120 | log_test.log_string('test [%d/%d]' % (i, dataset_length / args.batch_size)) 121 | 122 | log_test.log_string('Overview results:') 123 | log_test.log_string( 124 | 'CD_p: %f, CD_t: %f, F1: %f' % (test_loss_cd_p.avg, test_loss_cd_t.avg, 125 | test_f1_score.avg)) 126 | #pcd_file.close() 127 | log_test.close() 128 | 129 | -------------------------------------------------------------------------------- /test_benchmark.py: -------------------------------------------------------------------------------- 1 | from utils.utils import * 2 | import torch 3 | import os 4 | import h5py 5 | import sys 6 | import os 7 | proj_dir = os.path.dirname(os.path.abspath(__file__)) 8 | from models.model import Model 9 | 10 | import subprocess 11 | 12 | sys.path.append(os.path.join(proj_dir, "utils/ChamferDistancePytorch")) 13 | from chamfer3D import dist_chamfer_3D 14 | from fscore import fscore 15 | chamLoss = dist_chamfer_3D.chamfer_3DDist() 16 | 17 | from matplotlib import pyplot as plt 18 | from mpl_toolkits.mplot3d import Axes3D 19 | 20 | def test(args): 21 | model_dir = args.model_dir 22 | log_test = LogString(open(os.path.join(model_dir, 'log_text.txt'), 'w')) 23 | dataset_test = Completion3D(args.datapath, train=False, npoints=args.num_points, use_mean_feature=args.use_mean_feature, benchmark=True) 24 | dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, 25 | shuffle=False, num_workers=int(args.workers)) 26 | dataset_length = len(dataset_test) 27 | print(dataset_length) 28 | epochs = ['model.pth'] 29 | 30 | odir = 'benchmark/' 31 | 32 | if not os.path.exists(odir): 33 | os.makedirs(odir) 34 | 35 | for epoch in epochs: 36 | 
load_path = os.path.join(args.model_dir, epoch) 37 | net = eval(args.model_name)(num_coarse=1024, num_fine=args.num_points, benchmark=True) 38 | args.load_model = load_path 39 | load_model(args, net, None, log_test, train=False) 40 | net.cuda() 41 | net.eval() 42 | log_test.log_string("Testing...") 43 | with torch.no_grad(): 44 | for i, data in enumerate(dataloader_test): 45 | label, inputs, _ = data 46 | 47 | inputs = inputs.float().cuda() 48 | inputs = inputs.transpose(2, 1).contiguous() 49 | 50 | _, output = net(inputs) 51 | output_numpy = output.data.cpu().numpy() 52 | #print('output.shape', output.shape) 53 | for idx in range(output_numpy.shape[0]): 54 | fname = label[idx].split('/')[-1] 55 | #print('fname:', idx, fname) 56 | outp = output_numpy[idx:idx + 1, ...].squeeze() 57 | print('outp.shape', outp.shape) 58 | dir = os.path.join(odir, 'all') 59 | if not os.path.exists(dir): 60 | os.makedirs(dir) 61 | ofile = os.path.join(dir, fname) 62 | print("Saving to %s ..." % (ofile)) 63 | # pltname = ofile.replace('h5', 'png').replace('all', 'plot') 64 | # plot_pcd_three_views(pltname, [outp], ['partial']) 65 | with h5py.File(ofile, "w") as f: 66 | f.create_dataset("data", data=outp) 67 | 68 | cur_dir = os.getcwd() 69 | cmd = "cd %s; zip -r submission.zip * ; cd %s" % (odir, cur_dir) 70 | process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 71 | _, _ = process.communicate() 72 | print("Submission file has been saved to %s/submission.zip" % odir) 73 | 74 | log_test.close() 75 | 76 | def plot_pcd_three_views(filename, pcds, titles, suptitle='', sizes=None, cmap='Reds', zdir='y', 77 | xlim=(-0.3, 0.3), ylim=(-0.3, 0.3), zlim=(-0.3, 0.3)): 78 | if sizes is None: 79 | sizes = [0.5 for i in range(len(pcds))] 80 | fig = plt.figure(figsize=(len(pcds) * 3, 9)) 81 | for i in range(3): 82 | elev = 30 83 | azim = -45 + 90 * i 84 | for j, (pcd, size) in enumerate(zip(pcds, sizes)): 85 | color = pcd[:, 0] 86 | ax = fig.add_subplot(3, len(pcds), i * len(pcds) + j + 1, projection='3d') 87 | ax.view_init(elev, azim) 88 | ax.scatter(pcd[:, 0], pcd[:, 1], pcd[:, 2], zdir=zdir, c=color, s=size, cmap=cmap, vmin=-1, vmax=0.5) 89 | ax.set_title(titles[j]) 90 | ax.set_axis_off() 91 | ax.set_xlim(xlim) 92 | ax.set_ylim(ylim) 93 | ax.set_zlim(zlim) 94 | plt.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.9, wspace=0.1, hspace=0.1) 95 | plt.suptitle(suptitle) 96 | fig.savefig(filename) 97 | plt.close(fig) 98 | -------------------------------------------------------------------------------- /test_per_category_l2cd.py: -------------------------------------------------------------------------------- 1 | from utils.utils import * 2 | import torch 3 | import os 4 | import h5py 5 | import sys 6 | import os 7 | proj_dir = os.path.dirname(os.path.abspath(__file__)) 8 | import open3d 9 | from models.model import Model 10 | import subprocess 11 | 12 | sys.path.append(os.path.join(proj_dir, "utils/ChamferDistancePytorch")) 13 | from chamfer3D import dist_chamfer_3D 14 | from fscore import fscore 15 | chamLoss = dist_chamfer_3D.chamfer_3DDist() 16 | 17 | 18 | def calculate_fscore(gt_array, pr_array, th = 0.01): 19 | '''Calculates the F-score between two point clouds with the corresponding threshold value.''' 20 | print('gt_array.shape', gt_array.shape) 21 | gt = open3d.geometry.PointCloud() 22 | gt.points = open3d.utility.Vector3dVector(gt_array) 23 | pr = open3d.geometry.PointCloud() 24 | pr.points = open3d.utility.Vector3dVector(pr_array) 25 | 26 | d1 = 
gt.compute_point_cloud_distance(pr) 27 | d2 = pr.compute_point_cloud_distance(gt) 28 | 29 | if len(d1) and len(d2): 30 | recall = float(sum(d < th for d in d2)) / float(len(d2)) 31 | precision = float(sum(d < th for d in d1)) / float(len(d1)) 32 | 33 | if recall + precision > 0: 34 | fscore = 2 * recall * precision / (recall + precision) 35 | else: 36 | fscore = 0 37 | else: 38 | fscore = 0 39 | precision = 0 40 | recall = 0 41 | 42 | return fscore, precision, recall 43 | 44 | 45 | def test(args): 46 | model_dir = args.model_dir 47 | log_test = LogString(open(os.path.join(model_dir, 'log_text.txt'), 'w')) 48 | 49 | if args.dataset == 'SCAN': 50 | dataset_test = SCAN(args.datapath, npoints=args.num_points) 51 | elif args.dataset == 'KITTI': 52 | dataset_test = KITTI(args.datapath, npoints=args.num_points) 53 | else: 54 | dataset_test = PCN(args.datapath, train=False, npoints=args.num_points, test=True) 55 | 56 | dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, 57 | shuffle=False, num_workers=int(args.workers)) 58 | dataset_length = len(dataset_test) 59 | 60 | epochs = ['model.pth'] 61 | for epoch in epochs: 62 | load_path = os.path.join(args.model_dir, epoch) 63 | net = eval(args.model_name)(num_coarse=1024, num_fine=args.num_points) 64 | args.load_model = load_path 65 | 66 | load_model(args, net, None, log_test, train=False) 67 | net.cuda() 68 | net.eval() 69 | log_test.log_string("Testing...") 70 | 71 | pcd_file = h5py.File(os.path.join(args.model_dir, '%s_pcds.h5' % epoch.split('.')[0]), 'w') 72 | pcd_file.create_dataset('output_pcds', (dataset_length, args.num_points, 3)) 73 | 74 | test_loss_cd_p = AverageValueMeter() 75 | test_loss_cd_t = AverageValueMeter() 76 | test_f1_score = AverageValueMeter() 77 | 78 | cd_per_cat = {} 79 | 80 | with torch.no_grad(): 81 | for i, data in enumerate(dataloader_test): 82 | label, inputs, gt = data 83 | 84 | synset_id= str(label).split('/')[4] 85 | #print(synset_id) 86 | 87 | inputs = inputs.float().cuda() 88 | gt = gt.float().cuda() 89 | inputs = inputs.transpose(2, 1).contiguous() 90 | 91 | coarse, output = net(inputs) 92 | 93 | # save pcd 94 | # pcd_index1 = args.batch_size * i 95 | # pcd_index2 = args.batch_size * (i + 1) 96 | # pcd_file['output_pcds'][pcd_index1:pcd_index2, :, :] = output.cpu().numpy() 97 | 98 | #g_input_pcd[f"{i}"] = inputs.cpu().numpy() 99 | #g_gt_pcd[f"{i}"] = gt.cpu().numpy() 100 | # g_output_pcd[f"{i}"] = output.cpu().numpy() 101 | # g_coarse_pcd[f"{i}"] = coarse.cpu().numpy() 102 | 103 | # EMD 104 | # dist, _ = EMD(output, gt, 0.004, 3000) 105 | # emd = torch.sqrt(dist).mean(1) 106 | 107 | # CD 108 | dist1, dist2, _, _ = chamLoss(gt, output) 109 | cd_p = (torch.sqrt(dist1).mean(1) + torch.sqrt(dist2).mean(1)) / 2 110 | cd_t = dist1.mean(1) + dist2.mean(1) 111 | emd = cd_t 112 | 113 | if not cd_per_cat.get(synset_id): 114 | cd_per_cat[synset_id] = [] 115 | cd_per_cat[synset_id].append(cd_t.squeeze().cpu().numpy()) 116 | 117 | # f1 118 | #f1, _, _ = fscore(dist1, dist2) 119 | 120 | f1, _, _ = calculate_fscore(gt.squeeze().cpu().numpy(), output.squeeze().cpu().numpy()) 121 | 122 | f1 = torch.tensor(f1) 123 | 124 | test_loss_cd_p.update(cd_p.mean().item()) 125 | test_loss_cd_t.update(cd_t.mean().item()) 126 | test_f1_score.update(f1.mean().item()) 127 | 128 | if i % 100 == 0: 129 | log_test.log_string('test [%d/%d]' % (i, dataset_length / args.batch_size)) 130 | 131 | log_test.log_string('Overview results:') 132 | log_test.log_string( 133 | 'CD_p: %f, CD_t: %f, F1: %f' % 
(test_loss_cd_p.avg, test_loss_cd_t.avg, 134 | test_f1_score.avg)) 135 | dict_known = {'02691156': 'airplane','02933112': 'cabinet', '02958343': 'car', '03001627': 'chair', '03636649': 'lamp', '04256520': 'sofa', 136 | '04379243' : 'table','04530566': 'vessel'} 137 | for synset_id in dict_known.keys(): 138 | print(dict_known[synset_id], ' %f' % np.mean(cd_per_cat[synset_id])) 139 | break 140 | 141 | pcd_file.close() 142 | log_test.close() 143 | -------------------------------------------------------------------------------- /train/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | from models.full_model import FullModel 2 | import torch 3 | import torch.optim as optim 4 | from utils.utils import * 5 | from models.model import Model 6 | 7 | def train(args): 8 | # setup 9 | vis, log_model, log_train, log_path = log_setup(args) 10 | train_curve, val_curves = vis_curve_setup() 11 | best_val_losses, best_val_epochs = best_loss_setup() 12 | train_loss_meter, val_loss_meters = loss_average_meter_setup() 13 | dataset, dataset_test, dataloader, dataloader_test = data_setup(args, log_train) 14 | seed_setup(args, log_train) 15 | 16 | print('args.num_points', args.num_points) 17 | 18 | # model 19 | net = eval(args.model_name)(num_fine=args.num_points).cuda() 20 | net = torch.nn.DataParallel(FullModel(net)) 21 | log_model.log_string(str(net.module.model) + '\n', stdout=False) 22 | 23 | # optim 24 | lrate = args.lr # learning rate 25 | optimizer = optim.Adam(net.module.model.parameters(), lr=lrate) 26 | load_model(args, net, optimizer, log_train) 27 | 28 | for epoch in range(args.resume_epoch, args.nepoch): 29 | train_loss_meter.reset() 30 | net.module.model.train() 31 | 32 | if epoch > 0 and epoch % 20 == 0: 33 | lrate = max(lrate * 0.7, 1e-6) 34 | for param_group in optimizer.param_groups: 35 | param_group['lr'] = lrate 36 | 37 | if epoch < 5: 38 | alpha = 0.01 39 | elif epoch < 10: 40 | alpha = 0.1 41 | elif epoch < 15: 42 | alpha = 0.5 43 | else: 44 | alpha = 1.0 45 | 46 | # if epoch > 0 and epoch % 60 == 0: 47 | # lrate = max(lrate * 0.7, 1e-6) 48 | # for param_group in optimizer.param_groups: 49 | # param_group['lr'] = lrate 50 | # 51 | # if epoch < 5: 52 | # alpha = 0.01 53 | # elif epoch < 15: 54 | # alpha = 0.1 55 | # elif epoch < 30: 56 | # alpha = 0.5 57 | # else: 58 | # alpha = 1.0 59 | 60 | for i, data in enumerate(dataloader, 0): 61 | optimizer.zero_grad() 62 | label, inputs, gt = data 63 | inputs = inputs.float().cuda() 64 | gt = gt.float().cuda() 65 | inputs = inputs.transpose(2, 1).contiguous() 66 | 67 | # if args.loss == 'EMD': 68 | # output1, output2, emd1, emd2, cd1_p, cd2_p, cd1_t, cd2_t, u1, u2 = net(inputs, gt.contiguous(), 0.005, 50, True, False) 69 | # loss_net = emd1.mean() + u1.mean()*0.1 + (emd2.mean() + u2.mean()*0.1) * alpha 70 | # train_loss_meter.update(emd2.mean().item()) 71 | # else: 72 | output1, output2, emd1, emd2, cd1_p, cd2_p, cd1_t, cd2_t, u1, u2, origin_cd_p1, origin_cd_p2 = net(inputs, gt.contiguous(), 0.005, 50, False, True) 73 | loss_net = cd1_p.mean() + u1.mean()*0.1 + (cd2_p.mean() + u2.mean()*0.1) * alpha 74 | 75 | #loss_net = cd1_t.mean() + u1.mean() * 0.1 + (cd2_t.mean() + u2.mean() * 0.1) * alpha 76 | train_loss_meter.update(cd2_p.mean().item()) 77 | loss_net.backward() 78 | optimizer.step() 79 | 80 | if i % 100 == 0: 81 | 
log_train.log_string(args.log_env + ' train [%d: %d/%d] emd1: %f emd2: %f cd1_p: %f cd2_p: %f cd1_t: %f cd2_t: %f u1: %f u2: %f' % ( 82 | epoch, i, len(dataset) / args.batch_size, emd1.mean().item(), emd2.mean().item(), cd1_p.mean().item(), cd2_p.mean().item(), 83 | cd1_t.mean().item(), cd2_t.mean().item(), u1.mean().item(), u2.mean().item())) 84 | best_val_losses, best_val_epochs, save_cd_p, save_cd_t = val(args, net, epoch, dataloader_test, log_train, 85 | best_val_losses, best_val_epochs, val_loss_meters) 86 | if save_cd_p: 87 | save_model('%s/best_cd_p_network.pth' % log_path, net, optimizer) 88 | log_train.log_string('saving best cd_p net...') 89 | 90 | if save_cd_t: 91 | save_model('%s/best_cd_t_network.pth' % log_path, net, optimizer) 92 | log_train.log_string('saving best cd_t net...') 93 | 94 | train_curve.append(train_loss_meter.avg) 95 | #if epoch % 5 == 0: 96 | #save_model('%s/network_%d.pth' % (log_path, epoch), net, optimizer) 97 | save_model('%s/network.pth' % log_path, net, optimizer) 98 | log_train.log_string("saving net...") 99 | 100 | net.module.model.eval() 101 | # VALIDATION 102 | if epoch % 1 == 0 or epoch == args.nepoch - 1: 103 | best_val_losses, best_val_epochs, _, _ = val(args, net, epoch, dataloader_test, log_train, 104 | best_val_losses, best_val_epochs, val_loss_meters) 105 | if best_val_epochs["best_emd_epoch"] == epoch: 106 | save_model('%s/best_emd_network.pth' % log_path, net, optimizer) 107 | log_train.log_string('saving best emd net...') 108 | if best_val_epochs["best_cd_p_epoch"] == epoch: 109 | save_model('%s/best_cd_p_network.pth' % log_path, net, optimizer) 110 | log_train.log_string('saving best cd_p net...') 111 | if best_val_epochs["best_cd_t_epoch"] == epoch: 112 | save_model('%s/best_cd_t_network.pth' % log_path, net, optimizer) 113 | log_train.log_string('saving best cd_t net...') 114 | 115 | val_curves["val_curve_emd"].append(val_loss_meters["val_loss_emd"].avg) 116 | val_curves["val_curve_cd_p"].append(val_loss_meters["val_loss_cd_p"].avg) 117 | val_curves["val_curve_cd_t"].append(val_loss_meters["val_loss_cd_t"].avg) 118 | vis_curve_plot(vis, args.log_env, train_curve, val_curves) 119 | epoch_log(args, log_model, train_loss_meter, val_loss_meters, best_val_losses, best_val_epochs) 120 | net.module.model.train() 121 | 122 | log_model.close() 123 | log_train.close() 124 | -------------------------------------------------------------------------------- /utils/ChamferDistancePytorch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 ThibaultGROUEIX 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/ChamferDistancePytorch/README.md: -------------------------------------------------------------------------------- 1 | # Pytorch Chamfer Distance. 2 | 3 | Include a **CUDA** version, and a **PYTHON** version with pytorch standard operations. 4 | NB : In this depo, dist1 and dist2 are squared pointcloud euclidean distances, so you should adapt thresholds accordingly. 5 | 6 | - [x] F - Score 7 | 8 | 9 | 10 | ### CUDA VERSION 11 | 12 | - [x] JIT compilation 13 | - [x] Supports multi-gpu 14 | - [x] 2D point clouds. 15 | - [x] 3D point clouds. 16 | - [x] 5D point clouds. 17 | - [x] Contiguous() safe. 18 | 19 | 20 | 21 | ### Python Version 22 | 23 | - [x] Supports any dimension 24 | 25 | 26 | 27 | ### Usage 28 | 29 | ```python 30 | import torch, chamfer3D.dist_chamfer_3D, fscore 31 | chamLoss = chamfer3D.dist_chamfer_3D.chamfer_3DDist() 32 | points1 = torch.rand(32, 1000, 3).cuda() 33 | points2 = torch.rand(32, 2000, 3, requires_grad=True).cuda() 34 | dist1, dist2, idx1, idx2 = chamLoss(points1, points2) 35 | f_score, precision, recall = fscore.fscore(dist1, dist2) 36 | ``` 37 | 38 | 39 | 40 | ### Add it to your project as a submodule 41 | 42 | ```shell 43 | git submodule add https://github.com/ThibaultGROUEIX/ChamferDistancePytorch 44 | ``` 45 | 46 | 47 | 48 | ### Benchmark: [forward + backward] pass 49 | - [x] CUDA 10.1, NVIDIA 435, Pytorch 1.4 50 | - [x] p1 : 32 x 2000 x dim 51 | - [x] p2 : 32 x 1000 x dim 52 | 53 | | *Timing (sec * 1000)* | 2D | 3D | 5D | 54 | | ---------- | -------- | ------- | ------- | 55 | | **Cuda Compiled** | **1.2** | 1.4 |1.8 | 56 | | **Cuda JIT** | 1.3 | **1.4** |**1.5** | 57 | | **Python** | 37 | 37 | 37 | 58 | 59 | 60 | | *Memory (MB)* | 2D | 3D | 5D | 61 | | ---------- | -------- | ------- | ------- | 62 | | **Cuda Compiled** | 529 | 529 | 549 | 63 | | **Cuda JIT** | **520** | **529** |**549** | 64 | | **Python** | 2495 | 2495 | 2495 | 65 | 66 | 67 | 68 | ### What is the chamfer distance ? 69 | 70 | [Stanford course](http://graphics.stanford.edu/courses/cs468-17-spring/LectureSlides/L14%20-%203d%20deep%20learning%20on%20point%20cloud%20representation%20(analysis).pdf) on 3D deep Learning 71 | 72 | 73 | 74 | ### Aknowledgment 75 | 76 | Original backbone from [Fei Xia](https://github.com/fxia22/pointGAN/blob/master/nndistance/src/nnd_cuda.cu). 77 | 78 | JIT cool trick from [Christian Diller](https://github.com/chrdiller) 79 | 80 | ### Troubleshoot 81 | 82 | - `Undefined symbol: Zxxxxxxxxxxxxxxxxx `: 83 | 84 | --> Fix: Make sure to `import torch` before you `import chamfer`. 
85 | --> Use pytorch.version >= 1.1.0 86 | 87 | - [RuntimeError: Ninja is required to load C++ extension](https://github.com/zhanghang1989/PyTorch-Encoding/issues/167) 88 | 89 | ```shell 90 | wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip 91 | sudo unzip ninja-linux.zip -d /usr/local/bin/ 92 | sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force 93 | ``` 94 | 95 | 96 | 97 | 98 | 99 | #### TODO: 100 | 101 | * Discuss behaviour of torch.min() and tensor.min() which causes issues in some pytorch versions 102 | -------------------------------------------------------------------------------- /utils/ChamferDistancePytorch/chamfer2D/chamfer2D.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | 11 | 12 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 13 | const int batch=512; 14 | __shared__ float buf[batch*2]; 15 | for (int i=blockIdx.x;ibest){ 117 | result[(i*n+j)]=best; 118 | result_i[(i*n+j)]=best_i; 119 | } 120 | } 121 | __syncthreads(); 122 | } 123 | } 124 | } 125 | // int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ 126 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ 127 | 128 | const auto batch_size = xyz1.size(0); 129 | const auto n = xyz1.size(1); //num_points point cloud A 130 | const auto m = xyz2.size(1); //num_points point cloud B 131 | 132 | NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); 133 | NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); 134 | 135 | cudaError_t err = cudaGetLastError(); 136 | if (err != cudaSuccess) { 137 | printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); 138 | //THError("aborting"); 139 | return 0; 140 | } 141 | return 1; 142 | 143 | 144 | } 145 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 146 | for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); 171 | NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); 172 | 173 | cudaError_t err = cudaGetLastError(); 174 | if (err != cudaSuccess) { 175 | printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); 176 | //THError("aborting"); 177 | return 0; 178 | } 179 | return 1; 180 | 181 | } 182 | 183 | -------------------------------------------------------------------------------- /utils/ChamferDistancePytorch/chamfer2D/chamfer_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | ///TMP 5 | //#include "common.h" 6 | /// NOT TMP 7 | 8 | 9 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); 10 | 11 | 12 | int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); 13 | 14 | 15 | 16 | 17 | int 
chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { 18 | return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); 19 | } 20 | 21 | 22 | int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, 23 | at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { 24 | 25 | return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); 26 | } 27 | 28 | 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); 32 | m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); 33 | } -------------------------------------------------------------------------------- /utils/ChamferDistancePytorch/chamfer2D/dist_chamfer_2D.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Function 3 | import torch 4 | import importlib 5 | import os 6 | chamfer_found = importlib.find_loader("chamfer_2D") is not None 7 | if not chamfer_found: 8 | ## Cool trick from https://github.com/chrdiller 9 | print("Jitting Chamfer 2D") 10 | 11 | from torch.utils.cpp_extension import load 12 | chamfer_2D = load(name="chamfer_2D", 13 | sources=[ 14 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), 15 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer2D.cu"]), 16 | ]) 17 | print("Loaded JIT 2D CUDA chamfer distance") 18 | 19 | else: 20 | import chamfer_2D 21 | print("Loaded compiled 2D CUDA chamfer distance") 22 | 23 | # Chamfer's distance module @thibaultgroueix 24 | # GPU tensors only 25 | class chamfer_2DFunction(Function): 26 | @staticmethod 27 | def forward(ctx, xyz1, xyz2): 28 | batchsize, n, _ = xyz1.size() 29 | _, m, _ = xyz2.size() 30 | device = xyz1.device 31 | 32 | dist1 = torch.zeros(batchsize, n) 33 | dist2 = torch.zeros(batchsize, m) 34 | 35 | idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) 36 | idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) 37 | 38 | dist1 = dist1.to(device) 39 | dist2 = dist2.to(device) 40 | idx1 = idx1.to(device) 41 | idx2 = idx2.to(device) 42 | torch.cuda.set_device(device) 43 | 44 | chamfer_2D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 45 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 46 | return dist1, dist2, idx1, idx2 47 | 48 | @staticmethod 49 | def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): 50 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 51 | graddist1 = graddist1.contiguous() 52 | graddist2 = graddist2.contiguous() 53 | device = graddist1.device 54 | 55 | gradxyz1 = torch.zeros(xyz1.size()) 56 | gradxyz2 = torch.zeros(xyz2.size()) 57 | 58 | gradxyz1 = gradxyz1.to(device) 59 | gradxyz2 = gradxyz2.to(device) 60 | chamfer_2D.backward( 61 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 62 | ) 63 | return gradxyz1, gradxyz2 64 | 65 | 66 | class chamfer_2DDist(nn.Module): 67 | def __init__(self): 68 | super(chamfer_2DDist, self).__init__() 69 | 70 | def forward(self, input1, input2): 71 | input1 = input1.contiguous() 72 | input2 = input2.contiguous() 73 | return chamfer_2DFunction.apply(input1, input2) 74 | -------------------------------------------------------------------------------- /utils/ChamferDistancePytorch/chamfer2D/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension 
import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='chamfer_2D', 6 | ext_modules=[ 7 | CUDAExtension('chamfer_2D', [ 8 | "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), 9 | "/".join(__file__.split('/')[:-1] + ['chamfer2D.cu']), 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /utils/ChamferDistancePytorch/chamfer3D/chamfer3D.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | 11 | 12 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 13 | const int batch=512; 14 | __shared__ float buf[batch*3]; 15 | for (int i=blockIdx.x;ibest){ 127 | result[(i*n+j)]=best; 128 | result_i[(i*n+j)]=best_i; 129 | } 130 | } 131 | __syncthreads(); 132 | } 133 | } 134 | } 135 | // int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ 136 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ 137 | 138 | const auto batch_size = xyz1.size(0); 139 | const auto n = xyz1.size(1); //num_points point cloud A 140 | const auto m = xyz2.size(1); //num_points point cloud B 141 | 142 | NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); 143 | NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); 144 | 145 | cudaError_t err = cudaGetLastError(); 146 | if (err != cudaSuccess) { 147 | printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); 148 | //THError("aborting"); 149 | return 0; 150 | } 151 | return 1; 152 | 153 | 154 | } 155 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 156 | for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); 185 | NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); 186 | 187 | cudaError_t err = cudaGetLastError(); 188 | if (err != cudaSuccess) { 189 | printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); 190 | //THError("aborting"); 191 | return 0; 192 | } 193 | return 1; 194 | 195 | } 196 | 197 | -------------------------------------------------------------------------------- /utils/ChamferDistancePytorch/chamfer3D/chamfer_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | ///TMP 5 | //#include "common.h" 6 | /// NOT TMP 7 | 8 | 9 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); 10 | 11 | 12 | int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); 13 | 14 | 15 | 16 | 17 | int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { 18 | return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); 19 | } 20 | 21 | 22 | int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor 
gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, 23 | at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { 24 | 25 | return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); 26 | } 27 | 28 | 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); 32 | m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); 33 | } -------------------------------------------------------------------------------- /utils/ChamferDistancePytorch/chamfer3D/dist_chamfer_3D.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Function 3 | import torch 4 | import importlib 5 | import os 6 | chamfer_found = importlib.find_loader("chamfer_3D") is not None 7 | if not chamfer_found: 8 | ## Cool trick from https://github.com/chrdiller 9 | print("Jitting Chamfer 3D") 10 | 11 | from torch.utils.cpp_extension import load 12 | chamfer_3D = load(name="chamfer_3D", 13 | sources=[ 14 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), 15 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]), 16 | ]) 17 | print("Loaded JIT 3D CUDA chamfer distance") 18 | 19 | else: 20 | import chamfer_3D 21 | print("Loaded compiled 3D CUDA chamfer distance") 22 | 23 | 24 | # Chamfer's distance module @thibaultgroueix 25 | # GPU tensors only 26 | class chamfer_3DFunction(Function): 27 | @staticmethod 28 | def forward(ctx, xyz1, xyz2): 29 | batchsize, n, _ = xyz1.size() 30 | _, m, _ = xyz2.size() 31 | device = xyz1.device 32 | 33 | dist1 = torch.zeros(batchsize, n) 34 | dist2 = torch.zeros(batchsize, m) 35 | 36 | idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) 37 | idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) 38 | 39 | dist1 = dist1.to(device) 40 | dist2 = dist2.to(device) 41 | idx1 = idx1.to(device) 42 | idx2 = idx2.to(device) 43 | torch.cuda.set_device(device) 44 | 45 | chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 46 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 47 | return dist1, dist2, idx1, idx2 48 | 49 | @staticmethod 50 | def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): 51 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 52 | graddist1 = graddist1.contiguous() 53 | graddist2 = graddist2.contiguous() 54 | device = graddist1.device 55 | 56 | gradxyz1 = torch.zeros(xyz1.size()) 57 | gradxyz2 = torch.zeros(xyz2.size()) 58 | 59 | gradxyz1 = gradxyz1.to(device) 60 | gradxyz2 = gradxyz2.to(device) 61 | chamfer_3D.backward( 62 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 63 | ) 64 | return gradxyz1, gradxyz2 65 | 66 | 67 | class chamfer_3DDist(nn.Module): 68 | def __init__(self): 69 | super(chamfer_3DDist, self).__init__() 70 | 71 | def forward(self, input1, input2): 72 | input1 = input1.contiguous() 73 | input2 = input2.contiguous() 74 | return chamfer_3DFunction.apply(input1, input2) 75 | 76 | -------------------------------------------------------------------------------- /utils/ChamferDistancePytorch/chamfer3D/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='chamfer_3D', 6 | ext_modules=[ 7 | CUDAExtension('chamfer_3D', [ 8 | "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), 9 | "/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']), 10 | ]), 11 | ], 12 | 
cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /utils/ChamferDistancePytorch/chamfer5D/chamfer5D.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | 11 | 12 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 13 | const int batch=2048; 14 | __shared__ float buf[batch*5]; 15 | for (int i=blockIdx.x;ibest){ 147 | result[(i*n+j)]=best; 148 | result_i[(i*n+j)]=best_i; 149 | } 150 | } 151 | __syncthreads(); 152 | } 153 | } 154 | } 155 | // int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ 156 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ 157 | 158 | const auto batch_size = xyz1.size(0); 159 | const auto n = xyz1.size(1); //num_points point cloud A 160 | const auto m = xyz2.size(1); //num_points point cloud B 161 | 162 | NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); 163 | NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); 164 | 165 | cudaError_t err = cudaGetLastError(); 166 | if (err != cudaSuccess) { 167 | printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); 168 | //THError("aborting"); 169 | return 0; 170 | } 171 | return 1; 172 | 173 | 174 | } 175 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 176 | for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); 213 | NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); 214 | 215 | cudaError_t err = cudaGetLastError(); 216 | if (err != cudaSuccess) { 217 | printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); 218 | //THError("aborting"); 219 | return 0; 220 | } 221 | return 1; 222 | 223 | } 224 | -------------------------------------------------------------------------------- /utils/ChamferDistancePytorch/chamfer5D/chamfer_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | ///TMP 5 | //#include "common.h" 6 | /// NOT TMP 7 | 8 | 9 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); 10 | 11 | 12 | int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); 13 | 14 | 15 | 16 | 17 | int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { 18 | return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); 19 | } 20 | 21 | 22 | int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, 23 | at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { 24 | 25 | return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); 26 | } 27 | 28 | 29 | 30 | 
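// Python bindings: when built via setup.py (or JIT-loaded from dist_chamfer_5D.py)
// this translation unit is exposed to Python as the `chamfer_5D` extension module,
// with `forward`/`backward` dispatching to the CUDA wrappers declared above.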
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); 32 | m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); 33 | } -------------------------------------------------------------------------------- /utils/ChamferDistancePytorch/chamfer5D/dist_chamfer_5D.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Function 3 | import torch 4 | import importlib 5 | import os 6 | 7 | chamfer_found = importlib.find_loader("chamfer_5D") is not None 8 | if not chamfer_found: 9 | ## Cool trick from https://github.com/chrdiller 10 | print("Jitting Chamfer 5D") 11 | 12 | from torch.utils.cpp_extension import load 13 | chamfer_5D = load(name="chamfer_5D", 14 | sources=[ 15 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), 16 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer5D.cu"]), 17 | ]) 18 | print("Loaded JIT 5D CUDA chamfer distance") 19 | 20 | else: 21 | import chamfer_5D 22 | print("Loaded compiled 5D CUDA chamfer distance") 23 | 24 | 25 | # Chamfer's distance module @thibaultgroueix 26 | # GPU tensors only 27 | class chamfer_5DFunction(Function): 28 | @staticmethod 29 | def forward(ctx, xyz1, xyz2): 30 | batchsize, n, _ = xyz1.size() 31 | _, m, _ = xyz2.size() 32 | device = xyz1.device 33 | 34 | dist1 = torch.zeros(batchsize, n) 35 | dist2 = torch.zeros(batchsize, m) 36 | 37 | idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) 38 | idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) 39 | 40 | dist1 = dist1.to(device) 41 | dist2 = dist2.to(device) 42 | idx1 = idx1.to(device) 43 | idx2 = idx2.to(device) 44 | torch.cuda.set_device(device) 45 | 46 | chamfer_5D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 47 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 48 | return dist1, dist2, idx1, idx2 49 | 50 | @staticmethod 51 | def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): 52 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 53 | graddist1 = graddist1.contiguous() 54 | graddist2 = graddist2.contiguous() 55 | device = graddist1.device 56 | 57 | gradxyz1 = torch.zeros(xyz1.size()) 58 | gradxyz2 = torch.zeros(xyz2.size()) 59 | 60 | gradxyz1 = gradxyz1.to(device) 61 | gradxyz2 = gradxyz2.to(device) 62 | chamfer_5D.backward( 63 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 64 | ) 65 | return gradxyz1, gradxyz2 66 | 67 | 68 | class chamfer_5DDist(nn.Module): 69 | def __init__(self): 70 | super(chamfer_5DDist, self).__init__() 71 | 72 | def forward(self, input1, input2): 73 | input1 = input1.contiguous() 74 | input2 = input2.contiguous() 75 | return chamfer_5DFunction.apply(input1, input2) 76 | -------------------------------------------------------------------------------- /utils/ChamferDistancePytorch/chamfer5D/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='chamfer_5D', 6 | ext_modules=[ 7 | CUDAExtension('chamfer_5D', [ 8 | "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), 9 | "/".join(__file__.split('/')[:-1] + ['chamfer5D.cu']), 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /utils/ChamferDistancePytorch/chamfer_python.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pairwise_dist(x, y): 5 | xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t()) 6 | rx = xx.diag().unsqueeze(0).expand_as(xx) 7 | ry = yy.diag().unsqueeze(0).expand_as(yy) 8 | P = rx.t() + ry - 2 * zz 9 | return P 10 | 11 | 12 | def NN_loss(x, y, dim=0): 13 | dist = pairwise_dist(x, y) 14 | values, indices = dist.min(dim=dim) 15 | return values.mean() 16 | 17 | 18 | def distChamfer(a, b): 19 | """ 20 | :param a: Pointclouds Batch x nul_points x dim 21 | :param b: Pointclouds Batch x nul_points x dim 22 | :return: 23 | -closest point on b of points from a 24 | -closest point on a of points from b 25 | -idx of closest point on b of points from a 26 | -idx of closest point on a of points from b 27 | Works for pointcloud of any dimension 28 | """ 29 | x, y = a.double(), b.double() 30 | bs, num_points_x, points_dim = x.size() 31 | bs, num_points_y, points_dim = y.size() 32 | 33 | xx = torch.pow(x, 2).sum(2) 34 | yy = torch.pow(y, 2).sum(2) 35 | zz = torch.bmm(x, y.transpose(2, 1)) 36 | rx = xx.unsqueeze(1).expand(bs, num_points_y, num_points_x) # Diagonal elements xx 37 | ry = yy.unsqueeze(1).expand(bs, num_points_x, num_points_y) # Diagonal elements yy 38 | P = rx.transpose(2, 1) + ry - 2 * zz 39 | return torch.min(P, 2)[0].float(), torch.min(P, 1)[0].float(), torch.min(P, 2)[1].int(), torch.min(P, 1)[1].int() 40 | 41 | -------------------------------------------------------------------------------- /utils/ChamferDistancePytorch/fscore.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def fscore(dist1, dist2, threshold=0.001): 4 | """ 5 | Calculates the F-score between two point clouds with the corresponding threshold value. 6 | :param dist1: Batch, N-Points 7 | :param dist2: Batch, N-Points 8 | :param th: float 9 | :return: fscore, precision, recall 10 | """ 11 | # NB : In this depo, dist1 and dist2 are squared pointcloud euclidean distances, so you should adapt the threshold accordingly. 
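    # Because the distances are squared, an absolute point-to-point tolerance of d
    # corresponds to threshold = d**2 here (e.g. d = 0.01 -> threshold = 1e-4), and the
    # default threshold = 0.001 accepts matches within sqrt(0.001) ~= 0.032.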
12 | precision_1 = torch.mean((dist1 < threshold).float(), dim=1) 13 | precision_2 = torch.mean((dist2 < threshold).float(), dim=1) 14 | fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2) 15 | fscore[torch.isnan(fscore)] = 0 16 | return fscore, precision_1, precision_2 17 | 18 | -------------------------------------------------------------------------------- /utils/ChamferDistancePytorch/unit_test.py: -------------------------------------------------------------------------------- 1 | import torch, time 2 | import chamfer2D.dist_chamfer_2D 3 | import chamfer3D.dist_chamfer_3D 4 | import chamfer5D.dist_chamfer_5D 5 | import chamfer_python 6 | 7 | cham2D = chamfer2D.dist_chamfer_2D.chamfer_2DDist() 8 | cham3D = chamfer3D.dist_chamfer_3D.chamfer_3DDist() 9 | cham5D = chamfer5D.dist_chamfer_5D.chamfer_5DDist() 10 | 11 | from torch.autograd import Variable 12 | from fscore import fscore 13 | 14 | def test_chamfer(distChamfer, dim): 15 | points1 = torch.rand(4, 100, dim).cuda() 16 | points2 = torch.rand(4, 200, dim, requires_grad=True).cuda() 17 | dist1, dist2, idx1, idx2= distChamfer(points1, points2) 18 | 19 | loss = torch.sum(dist1) 20 | loss.backward() 21 | 22 | mydist1, mydist2, myidx1, myidx2 = chamfer_python.distChamfer(points1, points2) 23 | d1 = (dist1 - mydist1) ** 2 24 | d2 = (dist2 - mydist2) ** 2 25 | assert ( 26 | torch.mean(d1) + torch.mean(d2) < 0.00000001 27 | ), "chamfer cuda and chamfer normal are not giving the same results" 28 | 29 | xd1 = idx1 - myidx1 30 | xd2 = idx2 - myidx2 31 | assert ( 32 | torch.norm(xd1.float()) + torch.norm(xd2.float()) == 0 33 | ), "chamfer cuda and chamfer normal are not giving the same results" 34 | print(f"fscore :", fscore(dist1, dist2)) 35 | print("Unit test passed") 36 | 37 | 38 | def timings(distChamfer, dim): 39 | p1 = torch.rand(32, 2000, dim).cuda() 40 | p2 = torch.rand(32, 1000, dim).cuda() 41 | print("Timings : Start CUDA version") 42 | start = time.time() 43 | num_it = 100 44 | for i in range(num_it): 45 | points1 = Variable(p1, requires_grad=True) 46 | points2 = Variable(p2) 47 | mydist1, mydist2, idx1, idx2 = distChamfer(points1, points2) 48 | loss = torch.sum(mydist1) 49 | loss.backward() 50 | print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.") 51 | 52 | 53 | print("Timings : Start Pythonic version") 54 | start = time.time() 55 | for i in range(num_it): 56 | points1 = Variable(p1, requires_grad=True) 57 | points2 = Variable(p2) 58 | mydist1, mydist2, idx1, idx2 = chamfer_python.distChamfer(points1, points2) 59 | loss = torch.sum(mydist1) 60 | loss.backward() 61 | print(f"Ellapsed time forward backward is {(time.time() - start)/num_it} seconds.") 62 | 63 | 64 | 65 | dims = [2,3,5] 66 | for i,cham in enumerate([cham2D, cham3D, cham5D]): 67 | print(f"testing Chamfer {dims[i]}D") 68 | test_chamfer(cham, dims[i]) 69 | timings(cham, dims[i]) 70 | -------------------------------------------------------------------------------- /utils/MDS/MDS.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 5 | const float *points, const int *idx, 6 | float *out); 7 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 8 | const float *grad_out, const int *idx, 9 | float *grad_points); 10 | 11 | void minimum_density_sampling_kernel_wrapper(int b, int n, int m, 12 | const float *dataset, float *temp, 13 | int *idxs, float *mean_mst_length); 14 | 15 | 
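// Error-checking helpers used by the bindings below: CUDA_CHECK_ERRORS aborts on a
// failed kernel launch, while the CHECK_* macros validate that input tensors are
// CUDA-resident, contiguous and of the expected dtype before calling the CUDA
// wrappers declared above.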
16 | #define CUDA_CHECK_ERRORS() \ 17 | do { \ 18 | cudaError_t err = cudaGetLastError(); \ 19 | if (cudaSuccess != err) { \ 20 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 21 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 22 | __FILE__); \ 23 | exit(-1); \ 24 | } \ 25 | } while (0) 26 | 27 | #define CHECK_CUDA(x) \ 28 | do { \ 29 | AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \ 30 | } while (0) 31 | 32 | #define CHECK_CONTIGUOUS(x) \ 33 | do { \ 34 | AT_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \ 35 | } while (0) 36 | 37 | #define CHECK_IS_INT(x) \ 38 | do { \ 39 | AT_CHECK(x.scalar_type() == at::ScalarType::Int, \ 40 | #x " must be an int tensor"); \ 41 | } while (0) 42 | 43 | #define CHECK_IS_FLOAT(x) \ 44 | do { \ 45 | AT_CHECK(x.scalar_type() == at::ScalarType::Float, \ 46 | #x " must be a float tensor"); \ 47 | } while (0) 48 | 49 | 50 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 51 | CHECK_CONTIGUOUS(points); 52 | CHECK_CONTIGUOUS(idx); 53 | CHECK_IS_FLOAT(points); 54 | CHECK_IS_INT(idx); 55 | 56 | if (points.type().is_cuda()) { 57 | CHECK_CUDA(idx); 58 | } 59 | 60 | at::Tensor output = 61 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 62 | at::device(points.device()).dtype(at::ScalarType::Float)); 63 | 64 | if (points.type().is_cuda()) { 65 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 66 | idx.size(1), points.data(), 67 | idx.data(), output.data()); 68 | } else { 69 | AT_CHECK(false, "CPU not supported"); 70 | } 71 | 72 | return output; 73 | } 74 | 75 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, 76 | const int n) { 77 | CHECK_CONTIGUOUS(grad_out); 78 | CHECK_CONTIGUOUS(idx); 79 | CHECK_IS_FLOAT(grad_out); 80 | CHECK_IS_INT(idx); 81 | 82 | if (grad_out.type().is_cuda()) { 83 | CHECK_CUDA(idx); 84 | } 85 | 86 | at::Tensor output = 87 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 88 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 89 | 90 | if (grad_out.type().is_cuda()) { 91 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, 92 | idx.size(1), grad_out.data(), 93 | idx.data(), output.data()); 94 | } else { 95 | AT_CHECK(false, "CPU not supported"); 96 | } 97 | 98 | return output; 99 | } 100 | at::Tensor minimum_density_sampling(at::Tensor points, const int nsamples, at::Tensor mean_mst_length, at::Tensor output) { 101 | CHECK_CONTIGUOUS(points); 102 | CHECK_IS_FLOAT(points); 103 | 104 | at::Tensor tmp = 105 | torch::zeros({points.size(0), points.size(1)}, 106 | at::device(points.device()).dtype(at::ScalarType::Float)); 107 | 108 | if (points.type().is_cuda()) { 109 | minimum_density_sampling_kernel_wrapper( 110 | points.size(0), points.size(1), nsamples, points.data(), 111 | tmp.data(), output.data(), mean_mst_length.data()); 112 | } else { 113 | AT_CHECK(false, "CPU not supported"); 114 | } 115 | 116 | return output; 117 | } 118 | 119 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 120 | m.def("minimum_density_sampling", &minimum_density_sampling, "minimum_density_sampling (CUDA)"); 121 | m.def("gather_points_grad", &gather_points_grad, "gather_points_grad (CUDA)"); 122 | m.def("gather_points", &gather_points, "gather_points (CUDA)"); 123 | } 124 | -------------------------------------------------------------------------------- /utils/MDS/MDS_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from 
torch.autograd import Function 4 | import MDS 5 | 6 | class MinimumDensitySampling(Function): 7 | @staticmethod 8 | def forward(ctx, xyz, npoint, mean_mst_length): 9 | # type: (Any, torch.Tensor, int) -> torch.Tensor 10 | r""" 11 | Uses iterative radius point sampling to select a set of npoint features that have the largest 12 | minimum distance 13 | 14 | Parameters 15 | ---------- 16 | xyz : torch.Tensor 17 | (B, N, 3) tensor where N > npoint 18 | npoint : int32 19 | number of features in the sampled set 20 | mean_mst_length : torch.Tensor 21 | (B) the average edge length from expansion penalty module 22 | 23 | Returns 24 | ------- 25 | torch.Tensor 26 | (B, npoint) tensor containing the set 27 | """ 28 | idx = torch.zeros(xyz.shape[0], npoint, requires_grad= False, device='cuda', dtype=torch.int32).contiguous() 29 | MDS.minimum_density_sampling(xyz, npoint, mean_mst_length, idx) 30 | return idx 31 | 32 | @staticmethod 33 | def backward(grad_idx, a=None): 34 | return None, None, None 35 | 36 | 37 | minimum_density_sample = MinimumDensitySampling.apply 38 | 39 | 40 | class GatherOperation(Function): 41 | @staticmethod 42 | def forward(ctx, features, idx): 43 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 44 | r""" 45 | 46 | Parameters 47 | ---------- 48 | features : torch.Tensor 49 | (B, C, N) tensor 50 | 51 | idx : torch.Tensor 52 | (B, npoint) tensor of the features to gather 53 | 54 | Returns 55 | ------- 56 | torch.Tensor 57 | (B, C, npoint) tensor 58 | """ 59 | 60 | _, C, N = features.size() 61 | 62 | ctx.for_backwards = (idx, C, N) 63 | 64 | return MDS.gather_points(features, idx) 65 | 66 | @staticmethod 67 | def backward(ctx, grad_out): 68 | idx, C, N = ctx.for_backwards 69 | 70 | grad_features = MDS.gather_points_grad(grad_out.contiguous(), idx, N) 71 | return grad_features, None 72 | 73 | 74 | gather_operation = GatherOperation.apply 75 | 76 | -------------------------------------------------------------------------------- /utils/MDS/clean.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | rm -rf __pycache__/ 4 | rm -rf dist/ 5 | rm -rf build/ 6 | rm -rf MDS.egg-info/ 7 | rm -rf /mnt/lustre/chenxinyi1/.conda/envs/pt/lib/python3.7/site-packages/MDS-0.0.0-py3.7-linux-x86_64.egg/ 8 | -------------------------------------------------------------------------------- /utils/MDS/run_compile.sh: -------------------------------------------------------------------------------- 1 | partition=ips_share 2 | job_name=compile 3 | gpus=1 4 | g=$((${gpus}<8?${gpus}:8)) 5 | 6 | 7 | srun -u --partition=${partition} --job-name=${job_name} \ 8 | -n1 --gres=gpu:${gpus} --ntasks-per-node=1 -w 'SH-IDC1-10-198-6-85' \ 9 | python3 MDS_module.py 10 | -------------------------------------------------------------------------------- /utils/MDS/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='MDS', 6 | ext_modules=[ 7 | CUDAExtension('MDS', [ 8 | 'MDS_cuda.cu', 9 | 'MDS.cpp', 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Shaoshuai Shi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining 
a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/README.md: -------------------------------------------------------------------------------- 1 | # Pointnet2.PyTorch 2 | 3 | * PyTorch implementation of [PointNet++](https://arxiv.org/abs/1706.02413) based on [erikwijmans/Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch). 4 | * Faster than the original codes by re-implementing the CUDA operations. 5 | 6 | ## Installation 7 | ### Requirements 8 | * Linux (tested on Ubuntu 14.04/16.04) 9 | * Python 3.6+ 10 | * PyTorch 1.0 11 | 12 | ### Install 13 | Install this library by running the following command: 14 | 15 | ```shell 16 | cd pointnet2 17 | python setup.py install 18 | cd ../ 19 | ``` 20 | 21 | ## Examples 22 | Here I provide a simple example to use this library in the task of KITTI ourdoor foreground point cloud segmentation, and you could refer to the paper [PointRCNN](https://arxiv.org/abs/1812.04244) for the details of task description and foreground label generation. 23 | 24 | 1. Download the training data from [KITTI 3D object detection](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d) website and organize the downloaded files as follows: 25 | ``` 26 | Pointnet2.PyTorch 27 | ├── pointnet2 28 | ├── tools 29 | │ ├──data 30 | │ │ ├── KITTI 31 | │ │ │ ├── ImageSets 32 | │ │ │ ├── object 33 | │ │ │ │ ├──training 34 | │ │ │ │ ├──calib & velodyne & label_2 & image_2 35 | │ │ train_and_eval.py 36 | ``` 37 | 38 | 2. Run the following command to train and evaluate: 39 | ```shell 40 | cd tools 41 | python train_and_eval.py --batch_size 8 --epochs 100 --ckpt_save_interval 2 42 | ``` 43 | 44 | 45 | 46 | ## Project using this repo: 47 | * [PointRCNN](https://github.com/sshaoshuai/PointRCNN): 3D object detector from raw point cloud. 48 | 49 | ## Acknowledgement 50 | * [charlesq34/pointnet2](https://github.com/charlesq34/pointnet2): Paper author and official code repo. 51 | * [erikwijmans/Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch): Initial work of PyTorch implementation of PointNet++. 
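
## Usage sketch
The snippet below is a minimal sketch of calling a set-abstraction module from Python, assuming the extension has been installed as described above and that the `pointnet2` directory is on your `PYTHONPATH`. The module and argument names come from `pointnet2_modules.py`; the batch size and point counts are illustrative only.

```python
import torch
from pointnet2_modules import PointnetSAModule

# Single-scale set abstraction: sample 512 centroids with furthest point sampling,
# group up to 64 neighbours within radius 0.2 around each centroid, then apply a
# shared MLP followed by max-pooling over every neighbourhood.
sa_module = PointnetSAModule(mlp=[3, 64, 128], npoint=512, radius=0.2, nsample=64).cuda()

xyz = torch.rand(8, 1024, 3).cuda()          # (B, N, 3) input coordinates
new_xyz, new_features = sa_module(xyz)       # (B, 512, 3) and (B, 128, 512)
```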
52 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/build.ninja: -------------------------------------------------------------------------------- 1 | ninja_required_version = 1.3 2 | cxx = c++ 3 | nvcc = /usr/local/cuda/bin/nvcc 4 | 5 | cflags = -pthread -B /home/user/anaconda3/envs/torch_zwx/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/home/user/anaconda3/envs/torch_zwx/lib/python3.6/site-packages/torch/include -I/home/user/anaconda3/envs/torch_zwx/lib/python3.6/site-packages/torch/include/torch/csrc/api/include -I/home/user/anaconda3/envs/torch_zwx/lib/python3.6/site-packages/torch/include/TH -I/home/user/anaconda3/envs/torch_zwx/lib/python3.6/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/user/anaconda3/envs/torch_zwx/include/python3.6m -c 6 | post_cflags = -g -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=pointnet2_cuda -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++14 7 | cuda_cflags = -I/home/user/anaconda3/envs/torch_zwx/lib/python3.6/site-packages/torch/include -I/home/user/anaconda3/envs/torch_zwx/lib/python3.6/site-packages/torch/include/torch/csrc/api/include -I/home/user/anaconda3/envs/torch_zwx/lib/python3.6/site-packages/torch/include/TH -I/home/user/anaconda3/envs/torch_zwx/lib/python3.6/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/user/anaconda3/envs/torch_zwx/include/python3.6m -c 8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O2 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=pointnet2_cuda -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_70,code=compute_70 -gencode=arch=compute_70,code=sm_70 -std=c++14 9 | ldflags = 10 | 11 | rule compile 12 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 13 | depfile = $out.d 14 | deps = gcc 15 | 16 | rule cuda_compile 17 | depfile = $out.d 18 | deps = gcc 19 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags 20 | 21 | 22 | 23 | build /mnt/data1/zwx/ICCV2021_Submission7567_test/utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/src/pointnet2_api.o: compile /mnt/data1/zwx/ICCV2021_Submission7567_test/utils/Pointnet2.PyTorch/pointnet2/src/pointnet2_api.cpp 24 | build /mnt/data1/zwx/ICCV2021_Submission7567_test/utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/src/ball_query.o: compile /mnt/data1/zwx/ICCV2021_Submission7567_test/utils/Pointnet2.PyTorch/pointnet2/src/ball_query.cpp 25 | build /mnt/data1/zwx/ICCV2021_Submission7567_test/utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/src/ball_query_gpu.o: cuda_compile /mnt/data1/zwx/ICCV2021_Submission7567_test/utils/Pointnet2.PyTorch/pointnet2/src/ball_query_gpu.cu 26 | build /mnt/data1/zwx/ICCV2021_Submission7567_test/utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/src/group_points.o: compile /mnt/data1/zwx/ICCV2021_Submission7567_test/utils/Pointnet2.PyTorch/pointnet2/src/group_points.cpp 27 | build /mnt/data1/zwx/ICCV2021_Submission7567_test/utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/src/group_points_gpu.o: cuda_compile 
/mnt/data1/zwx/ICCV2021_Submission7567_test/utils/Pointnet2.PyTorch/pointnet2/src/group_points_gpu.cu 28 | build /mnt/data1/zwx/ICCV2021_Submission7567_test/utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/src/interpolate.o: compile /mnt/data1/zwx/ICCV2021_Submission7567_test/utils/Pointnet2.PyTorch/pointnet2/src/interpolate.cpp 29 | build /mnt/data1/zwx/ICCV2021_Submission7567_test/utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/src/interpolate_gpu.o: cuda_compile /mnt/data1/zwx/ICCV2021_Submission7567_test/utils/Pointnet2.PyTorch/pointnet2/src/interpolate_gpu.cu 30 | build /mnt/data1/zwx/ICCV2021_Submission7567_test/utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/src/sampling.o: compile /mnt/data1/zwx/ICCV2021_Submission7567_test/utils/Pointnet2.PyTorch/pointnet2/src/sampling.cpp 31 | build /mnt/data1/zwx/ICCV2021_Submission7567_test/utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/src/sampling_gpu.o: cuda_compile /mnt/data1/zwx/ICCV2021_Submission7567_test/utils/Pointnet2.PyTorch/pointnet2/src/sampling_gpu.cu 32 | 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/src/ball_query_gpu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XLechter/SDT/d87587cc70c4d7bb03fe4a795471984b5a2323ac/utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/src/ball_query_gpu.o -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/src/group_points_gpu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XLechter/SDT/d87587cc70c4d7bb03fe4a795471984b5a2323ac/utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/src/group_points_gpu.o -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/src/interpolate_gpu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XLechter/SDT/d87587cc70c4d7bb03fe4a795471984b5a2323ac/utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/src/interpolate_gpu.o -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/src/pointnet2_api.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XLechter/SDT/d87587cc70c4d7bb03fe4a795471984b5a2323ac/utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/src/pointnet2_api.o -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/src/sampling_gpu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XLechter/SDT/d87587cc70c4d7bb03fe4a795471984b5a2323ac/utils/Pointnet2.PyTorch/pointnet2/build/temp.linux-x86_64-3.6/src/sampling_gpu.o -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/pointnet2.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: pointnet2 3 | Version: 0.0.0 4 | Summary: UNKNOWN 5 | Home-page: UNKNOWN 6 | Author: 
UNKNOWN 7 | Author-email: UNKNOWN 8 | License: UNKNOWN 9 | Description: UNKNOWN 10 | Platform: UNKNOWN 11 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/pointnet2.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | pointnet2.egg-info/PKG-INFO 3 | pointnet2.egg-info/SOURCES.txt 4 | pointnet2.egg-info/dependency_links.txt 5 | pointnet2.egg-info/top_level.txt 6 | src/ball_query.cpp 7 | src/ball_query_gpu.cu 8 | src/group_points.cpp 9 | src/group_points_gpu.cu 10 | src/interpolate.cpp 11 | src/interpolate_gpu.cu 12 | src/pointnet2_api.cpp 13 | src/sampling.cpp 14 | src/sampling_gpu.cu -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/pointnet2.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/pointnet2.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | pointnet2_cuda 2 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/pointnet2_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import pointnet2_utils 5 | import pytorch_utils as pt_utils 6 | from typing import List 7 | 8 | 9 | class _PointnetSAModuleBase(nn.Module): 10 | 11 | def __init__(self): 12 | super().__init__() 13 | self.npoint = None 14 | self.groupers = None 15 | self.mlps = None 16 | self.pool_method = 'max_pool' 17 | 18 | def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor): 19 | """ 20 | :param xyz: (B, N, 3) tensor of the xyz coordinates of the features 21 | :param features: (B, N, C) tensor of the descriptors of the the features 22 | :param new_xyz: 23 | :return: 24 | new_xyz: (B, npoint, 3) tensor of the new features' xyz 25 | new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors 26 | """ 27 | new_features_list = [] 28 | xyz_flipped = xyz.transpose(1, 2).contiguous() 29 | if new_xyz is None: 30 | new_xyz = pointnet2_utils.gather_operation( 31 | xyz_flipped, 32 | pointnet2_utils.furthest_point_sample(xyz, self.npoint) 33 | ).transpose(1, 2).contiguous() if self.npoint is not None else None 34 | for i in range(len(self.groupers)): 35 | new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample) 36 | new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample) 37 | if self.pool_method == 'max_pool': 38 | new_features = F.max_pool2d( 39 | new_features, kernel_size=[1, new_features.size(3)] 40 | ) # (B, mlp[-1], npoint, 1) 41 | elif self.pool_method == 'avg_pool': 42 | new_features = F.avg_pool2d( 43 | new_features, kernel_size=[1, new_features.size(3)] 44 | ) # (B, mlp[-1], npoint, 1) 45 | else: 46 | raise NotImplementedError 47 | 48 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) 49 | new_features_list.append(new_features) 50 | return new_xyz, torch.cat(new_features_list, dim=1) 51 | 52 | 53 | class PointnetSAModuleMSG(_PointnetSAModuleBase): 54 | """Pointnet set abstraction layer with multiscale grouping""" 55 | 56 | def __init__(self, *, npoint: int, radii: 
List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True, 57 | use_xyz: bool = True, pool_method='max_pool', instance_norm=False, features: torch.Tensor = None): 58 | """ 59 | :param npoint: int 60 | :param radii: list of float, list of radii to group with 61 | :param nsamples: list of int, number of samples in each ball query 62 | :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale 63 | :param bn: whether to use batchnorm 64 | :param use_xyz: 65 | :param pool_method: max_pool / avg_pool 66 | :param instance_norm: whether to use instance_norm 67 | """ 68 | super().__init__() 69 | 70 | assert len(radii) == len(nsamples) == len(mlps) 71 | 72 | self.npoint = npoint 73 | self.groupers = nn.ModuleList() 74 | self.mlps = nn.ModuleList() 75 | for i in range(len(radii)): 76 | radius = radii[i] 77 | nsample = nsamples[i] 78 | self.groupers.append( 79 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) 80 | if npoint is not None else pointnet2_utils.GroupAll(use_xyz) 81 | ) 82 | mlp_spec = mlps[i] 83 | 84 | if features is not None: 85 | if use_xyz: 86 | mlp_spec[0]+=3 87 | self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm)) 88 | self.pool_method = pool_method 89 | 90 | 91 | class PointnetSAModule(PointnetSAModuleMSG): 92 | """Pointnet set abstraction layer""" 93 | 94 | def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None, 95 | bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False): 96 | """ 97 | :param mlp: list of int, spec of the pointnet before the global max_pool 98 | :param npoint: int, number of features 99 | :param radius: float, radius of ball 100 | :param nsample: int, number of samples in the ball query 101 | :param bn: whether to use batchnorm 102 | :param use_xyz: 103 | :param pool_method: max_pool / avg_pool 104 | :param instance_norm: whether to use instance_norm 105 | """ 106 | super().__init__( 107 | mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz, 108 | pool_method=pool_method, instance_norm=instance_norm 109 | ) 110 | 111 | 112 | class PointnetFPModule(nn.Module): 113 | r"""Propigates the features of one set to another""" 114 | 115 | def __init__(self, *, mlp: List[int], bn: bool = True): 116 | """ 117 | :param mlp: list of int 118 | :param bn: whether to use batchnorm 119 | """ 120 | super().__init__() 121 | self.mlp = pt_utils.SharedMLP(mlp, bn=bn) 122 | 123 | def forward( 124 | self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor 125 | ) -> torch.Tensor: 126 | """ 127 | :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features 128 | :param known: (B, m, 3) tensor of the xyz positions of the known features 129 | :param unknow_feats: (B, C1, n) tensor of the features to be propigated to 130 | :param known_feats: (B, C2, m) tensor of features to be propigated 131 | :return: 132 | new_features: (B, mlp[-1], n) tensor of the features of the unknown features 133 | """ 134 | if known is not None: 135 | dist, idx = pointnet2_utils.three_nn(unknown, known) 136 | dist_recip = 1.0 / (dist + 1e-8) 137 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 138 | weight = dist_recip / norm 139 | 140 | interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight) 141 | else: 142 | interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1)) 143 | 144 | if unknow_feats is 
not None: 145 | new_features = torch.cat([interpolated_feats, unknow_feats], dim=1) # (B, C2 + C1, n) 146 | else: 147 | new_features = interpolated_feats 148 | 149 | new_features = new_features.unsqueeze(-1) 150 | new_features = self.mlp(new_features) 151 | 152 | return new_features.squeeze(-1) 153 | 154 | 155 | if __name__ == "__main__": 156 | pass 157 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from typing import List, Tuple 3 | 4 | 5 | class SharedMLP(nn.Sequential): 6 | 7 | def __init__( 8 | self, 9 | args: List[int], 10 | *, 11 | bn: bool = False, 12 | activation=nn.ReLU(inplace=True), 13 | preact: bool = False, 14 | first: bool = False, 15 | name: str = "", 16 | instance_norm: bool = False, 17 | ): 18 | super().__init__() 19 | 20 | for i in range(len(args) - 1): 21 | self.add_module( 22 | name + 'layer{}'.format(i), 23 | Conv2d( 24 | args[i], 25 | args[i + 1], 26 | bn=(not first or not preact or (i != 0)) and bn, 27 | activation=activation 28 | if (not first or not preact or (i != 0)) else None, 29 | preact=preact, 30 | instance_norm=instance_norm 31 | ) 32 | ) 33 | 34 | 35 | class _ConvBase(nn.Sequential): 36 | 37 | def __init__( 38 | self, 39 | in_size, 40 | out_size, 41 | kernel_size, 42 | stride, 43 | padding, 44 | activation, 45 | bn, 46 | init, 47 | conv=None, 48 | batch_norm=None, 49 | bias=True, 50 | preact=False, 51 | name="", 52 | instance_norm=False, 53 | instance_norm_func=None 54 | ): 55 | super().__init__() 56 | 57 | bias = bias and (not bn) 58 | conv_unit = conv( 59 | in_size, 60 | out_size, 61 | kernel_size=kernel_size, 62 | stride=stride, 63 | padding=padding, 64 | bias=bias 65 | ) 66 | init(conv_unit.weight) 67 | if bias: 68 | nn.init.constant_(conv_unit.bias, 0) 69 | 70 | if bn: 71 | if not preact: 72 | bn_unit = batch_norm(out_size) 73 | else: 74 | bn_unit = batch_norm(in_size) 75 | if instance_norm: 76 | if not preact: 77 | in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False) 78 | else: 79 | in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False) 80 | 81 | if preact: 82 | if bn: 83 | self.add_module(name + 'bn', bn_unit) 84 | 85 | if activation is not None: 86 | self.add_module(name + 'activation', activation) 87 | 88 | if not bn and instance_norm: 89 | self.add_module(name + 'in', in_unit) 90 | 91 | self.add_module(name + 'conv', conv_unit) 92 | 93 | if not preact: 94 | if bn: 95 | self.add_module(name + 'bn', bn_unit) 96 | 97 | if activation is not None: 98 | self.add_module(name + 'activation', activation) 99 | 100 | if not bn and instance_norm: 101 | self.add_module(name + 'in', in_unit) 102 | 103 | 104 | class _BNBase(nn.Sequential): 105 | 106 | def __init__(self, in_size, batch_norm=None, name=""): 107 | super().__init__() 108 | self.add_module(name + "bn", batch_norm(in_size)) 109 | 110 | nn.init.constant_(self[0].weight, 1.0) 111 | nn.init.constant_(self[0].bias, 0) 112 | 113 | 114 | class BatchNorm1d(_BNBase): 115 | 116 | def __init__(self, in_size: int, *, name: str = ""): 117 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) 118 | 119 | 120 | class BatchNorm2d(_BNBase): 121 | 122 | def __init__(self, in_size: int, name: str = ""): 123 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) 124 | 125 | 126 | class Conv1d(_ConvBase): 127 | 128 | def __init__( 129 | self, 130 | 
in_size: int, 131 | out_size: int, 132 | *, 133 | kernel_size: int = 1, 134 | stride: int = 1, 135 | padding: int = 0, 136 | activation=nn.ReLU(inplace=True), 137 | bn: bool = False, 138 | init=nn.init.kaiming_normal_, 139 | bias: bool = True, 140 | preact: bool = False, 141 | name: str = "", 142 | instance_norm=False 143 | ): 144 | super().__init__( 145 | in_size, 146 | out_size, 147 | kernel_size, 148 | stride, 149 | padding, 150 | activation, 151 | bn, 152 | init, 153 | conv=nn.Conv1d, 154 | batch_norm=BatchNorm1d, 155 | bias=bias, 156 | preact=preact, 157 | name=name, 158 | instance_norm=instance_norm, 159 | instance_norm_func=nn.InstanceNorm1d 160 | ) 161 | 162 | 163 | class Conv2d(_ConvBase): 164 | 165 | def __init__( 166 | self, 167 | in_size: int, 168 | out_size: int, 169 | *, 170 | kernel_size: Tuple[int, int] = (1, 1), 171 | stride: Tuple[int, int] = (1, 1), 172 | padding: Tuple[int, int] = (0, 0), 173 | activation=nn.ReLU(inplace=True), 174 | bn: bool = False, 175 | init=nn.init.kaiming_normal_, 176 | bias: bool = True, 177 | preact: bool = False, 178 | name: str = "", 179 | instance_norm=False 180 | ): 181 | super().__init__( 182 | in_size, 183 | out_size, 184 | kernel_size, 185 | stride, 186 | padding, 187 | activation, 188 | bn, 189 | init, 190 | conv=nn.Conv2d, 191 | batch_norm=BatchNorm2d, 192 | bias=bias, 193 | preact=preact, 194 | name=name, 195 | instance_norm=instance_norm, 196 | instance_norm_func=nn.InstanceNorm2d 197 | ) 198 | 199 | 200 | class FC(nn.Sequential): 201 | 202 | def __init__( 203 | self, 204 | in_size: int, 205 | out_size: int, 206 | *, 207 | activation=nn.ReLU(inplace=True), 208 | bn: bool = False, 209 | init=None, 210 | preact: bool = False, 211 | name: str = "" 212 | ): 213 | super().__init__() 214 | 215 | fc = nn.Linear(in_size, out_size, bias=not bn) 216 | if init is not None: 217 | init(fc.weight) 218 | if not bn: 219 | nn.init.constant(fc.bias, 0) 220 | 221 | if preact: 222 | if bn: 223 | self.add_module(name + 'bn', BatchNorm1d(in_size)) 224 | 225 | if activation is not None: 226 | self.add_module(name + 'activation', activation) 227 | 228 | self.add_module(name + 'fc', fc) 229 | 230 | if not preact: 231 | if bn: 232 | self.add_module(name + 'bn', BatchNorm1d(out_size)) 233 | 234 | if activation is not None: 235 | self.add_module(name + 'activation', activation) 236 | 237 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='pointnet2', 6 | ext_modules=[ 7 | CUDAExtension('pointnet2_cuda', [ 8 | 'src/pointnet2_api.cpp', 9 | 10 | 'src/ball_query.cpp', 11 | 'src/ball_query_gpu.cu', 12 | 'src/group_points.cpp', 13 | 'src/group_points_gpu.cu', 14 | 'src/interpolate.cpp', 15 | 'src/interpolate_gpu.cu', 16 | 'src/sampling.cpp', 17 | 'src/sampling_gpu.cu', 18 | ], 19 | extra_compile_args={'cxx': ['-g'], 20 | 'nvcc': ['-O2']}) 21 | ], 22 | cmdclass={'build_ext': BuildExtension} 23 | ) 24 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "ball_query_gpu.h" 7 | 8 | extern THCState *state; 9 | 10 | #define CHECK_CUDA(x) 
AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") 11 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) 13 | 14 | int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, 15 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) { 16 | CHECK_INPUT(new_xyz_tensor); 17 | CHECK_INPUT(xyz_tensor); 18 | const float *new_xyz = new_xyz_tensor.data(); 19 | const float *xyz = xyz_tensor.data(); 20 | int *idx = idx_tensor.data(); 21 | 22 | cudaStream_t stream = THCState_getCurrentStream(state); 23 | ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream); 24 | return 1; 25 | } -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "ball_query_gpu.h" 6 | #include "cuda_utils.h" 7 | 8 | 9 | __global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample, 10 | const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) { 11 | // new_xyz: (B, M, 3) 12 | // xyz: (B, N, 3) 13 | // output: 14 | // idx: (B, M, nsample) 15 | int bs_idx = blockIdx.y; 16 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 17 | if (bs_idx >= b || pt_idx >= m) return; 18 | 19 | new_xyz += bs_idx * m * 3 + pt_idx * 3; 20 | xyz += bs_idx * n * 3; 21 | idx += bs_idx * m * nsample + pt_idx * nsample; 22 | 23 | float radius2 = radius * radius; 24 | float new_x = new_xyz[0]; 25 | float new_y = new_xyz[1]; 26 | float new_z = new_xyz[2]; 27 | 28 | int cnt = 0; 29 | for (int k = 0; k < n; ++k) { 30 | float x = xyz[k * 3 + 0]; 31 | float y = xyz[k * 3 + 1]; 32 | float z = xyz[k * 3 + 2]; 33 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); 34 | if (d2 < radius2){ 35 | if (cnt == 0){ 36 | for (int l = 0; l < nsample; ++l) { 37 | idx[l] = k; 38 | } 39 | } 40 | idx[cnt] = k; 41 | ++cnt; 42 | if (cnt >= nsample) break; 43 | } 44 | } 45 | } 46 | 47 | 48 | void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \ 49 | const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) { 50 | // new_xyz: (B, M, 3) 51 | // xyz: (B, N, 3) 52 | // output: 53 | // idx: (B, M, nsample) 54 | 55 | cudaError_t err; 56 | 57 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) 58 | dim3 threads(THREADS_PER_BLOCK); 59 | 60 | ball_query_kernel_fast<<>>(b, n, m, radius, nsample, new_xyz, xyz, idx); 61 | // cudaDeviceSynchronize(); // for using printf in kernel function 62 | err = cudaGetLastError(); 63 | if (cudaSuccess != err) { 64 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 65 | exit(-1); 66 | } 67 | } -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/src/ball_query_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _BALL_QUERY_GPU_H 2 | #define _BALL_QUERY_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, 10 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor); 11 | 12 | void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int 
nsample, 13 | const float *xyz, const float *new_xyz, int *idx, cudaStream_t stream); 14 | 15 | #endif 16 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/src/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | 6 | #define TOTAL_THREADS 1024 7 | #define THREADS_PER_BLOCK 256 8 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) 9 | 10 | inline int opt_n_threads(int work_size) { 11 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 12 | 13 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 14 | } 15 | #endif 16 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "group_points_gpu.h" 7 | 8 | extern THCState *state; 9 | 10 | 11 | int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, 12 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { 13 | 14 | float *grad_points = grad_points_tensor.data(); 15 | const int *idx = idx_tensor.data(); 16 | const float *grad_out = grad_out_tensor.data(); 17 | 18 | cudaStream_t stream = THCState_getCurrentStream(state); 19 | 20 | group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points, stream); 21 | return 1; 22 | } 23 | 24 | 25 | int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, 26 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) { 27 | 28 | const float *points = points_tensor.data(); 29 | const int *idx = idx_tensor.data(); 30 | float *out = out_tensor.data(); 31 | 32 | cudaStream_t stream = THCState_getCurrentStream(state); 33 | 34 | group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out, stream); 35 | return 1; 36 | } -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | #include "group_points_gpu.h" 6 | 7 | 8 | __global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample, 9 | const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) { 10 | // grad_out: (B, C, npoints, nsample) 11 | // idx: (B, npoints, nsample) 12 | // output: 13 | // grad_points: (B, C, N) 14 | int bs_idx = blockIdx.z; 15 | int c_idx = blockIdx.y; 16 | int index = blockIdx.x * blockDim.x + threadIdx.x; 17 | int pt_idx = index / nsample; 18 | if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; 19 | 20 | int sample_idx = index % nsample; 21 | grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; 22 | idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; 23 | 24 | atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0] , grad_out[0]); 25 | } 26 | 27 | void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 28 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) { 29 | // grad_out: (B, C, npoints, nsample) 30 | // idx: (B, npoints, nsample) 31 | // 
output: 32 | // grad_points: (B, C, N) 33 | cudaError_t err; 34 | dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 35 | dim3 threads(THREADS_PER_BLOCK); 36 | 37 | group_points_grad_kernel_fast<<>>(b, c, n, npoints, nsample, grad_out, idx, grad_points); 38 | 39 | err = cudaGetLastError(); 40 | if (cudaSuccess != err) { 41 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 42 | exit(-1); 43 | } 44 | } 45 | 46 | 47 | __global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample, 48 | const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) { 49 | // points: (B, C, N) 50 | // idx: (B, npoints, nsample) 51 | // output: 52 | // out: (B, C, npoints, nsample) 53 | int bs_idx = blockIdx.z; 54 | int c_idx = blockIdx.y; 55 | int index = blockIdx.x * blockDim.x + threadIdx.x; 56 | int pt_idx = index / nsample; 57 | if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; 58 | 59 | int sample_idx = index % nsample; 60 | 61 | idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; 62 | int in_idx = bs_idx * c * n + c_idx * n + idx[0]; 63 | int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; 64 | 65 | out[out_idx] = points[in_idx]; 66 | } 67 | 68 | 69 | void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 70 | const float *points, const int *idx, float *out, cudaStream_t stream) { 71 | // points: (B, C, N) 72 | // idx: (B, npoints, nsample) 73 | // output: 74 | // out: (B, C, npoints, nsample) 75 | cudaError_t err; 76 | dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 77 | dim3 threads(THREADS_PER_BLOCK); 78 | 79 | group_points_kernel_fast<<>>(b, c, n, npoints, nsample, points, idx, out); 80 | // cudaDeviceSynchronize(); // for using printf in kernel function 81 | err = cudaGetLastError(); 82 | if (cudaSuccess != err) { 83 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 84 | exit(-1); 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/src/group_points_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _GROUP_POINTS_GPU_H 2 | #define _GROUP_POINTS_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, 11 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); 12 | 13 | void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 14 | const float *points, const int *idx, float *out, cudaStream_t stream); 15 | 16 | int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, 17 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); 18 | 19 | void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 20 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); 21 | 22 | #endif 23 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include "interpolate_gpu.h" 10 
| 11 | extern THCState *state; 12 | 13 | 14 | void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, 15 | at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) { 16 | const float *unknown = unknown_tensor.data(); 17 | const float *known = known_tensor.data(); 18 | float *dist2 = dist2_tensor.data(); 19 | int *idx = idx_tensor.data(); 20 | 21 | cudaStream_t stream = THCState_getCurrentStream(state); 22 | three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream); 23 | } 24 | 25 | 26 | void three_interpolate_wrapper_fast(int b, int c, int m, int n, 27 | at::Tensor points_tensor, 28 | at::Tensor idx_tensor, 29 | at::Tensor weight_tensor, 30 | at::Tensor out_tensor) { 31 | 32 | const float *points = points_tensor.data(); 33 | const float *weight = weight_tensor.data(); 34 | float *out = out_tensor.data(); 35 | const int *idx = idx_tensor.data(); 36 | 37 | cudaStream_t stream = THCState_getCurrentStream(state); 38 | three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream); 39 | } 40 | 41 | void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, 42 | at::Tensor grad_out_tensor, 43 | at::Tensor idx_tensor, 44 | at::Tensor weight_tensor, 45 | at::Tensor grad_points_tensor) { 46 | 47 | const float *grad_out = grad_out_tensor.data(); 48 | const float *weight = weight_tensor.data(); 49 | float *grad_points = grad_points_tensor.data(); 50 | const int *idx = idx_tensor.data(); 51 | 52 | cudaStream_t stream = THCState_getCurrentStream(state); 53 | three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream); 54 | } -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | #include "interpolate_gpu.h" 7 | 8 | 9 | __global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown, 10 | const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) { 11 | // unknown: (B, N, 3) 12 | // known: (B, M, 3) 13 | // output: 14 | // dist2: (B, N, 3) 15 | // idx: (B, N, 3) 16 | 17 | int bs_idx = blockIdx.y; 18 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 19 | if (bs_idx >= b || pt_idx >= n) return; 20 | 21 | unknown += bs_idx * n * 3 + pt_idx * 3; 22 | known += bs_idx * m * 3; 23 | dist2 += bs_idx * n * 3 + pt_idx * 3; 24 | idx += bs_idx * n * 3 + pt_idx * 3; 25 | 26 | float ux = unknown[0]; 27 | float uy = unknown[1]; 28 | float uz = unknown[2]; 29 | 30 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 31 | int besti1 = 0, besti2 = 0, besti3 = 0; 32 | for (int k = 0; k < m; ++k) { 33 | float x = known[k * 3 + 0]; 34 | float y = known[k * 3 + 1]; 35 | float z = known[k * 3 + 2]; 36 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 37 | if (d < best1) { 38 | best3 = best2; besti3 = besti2; 39 | best2 = best1; besti2 = besti1; 40 | best1 = d; besti1 = k; 41 | } 42 | else if (d < best2) { 43 | best3 = best2; besti3 = besti2; 44 | best2 = d; besti2 = k; 45 | } 46 | else if (d < best3) { 47 | best3 = d; besti3 = k; 48 | } 49 | } 50 | dist2[0] = best1; dist2[1] = best2; dist2[2] = best3; 51 | idx[0] = besti1; idx[1] = besti2; idx[2] = besti3; 52 | } 53 | 54 | 55 | void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown, 56 | const float *known, float 
*dist2, int *idx, cudaStream_t stream) { 57 | // unknown: (B, N, 3) 58 | // known: (B, M, 3) 59 | // output: 60 | // dist2: (B, N, 3) 61 | // idx: (B, N, 3) 62 | 63 | cudaError_t err; 64 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) 65 | dim3 threads(THREADS_PER_BLOCK); 66 | 67 | three_nn_kernel_fast<<>>(b, n, m, unknown, known, dist2, idx); 68 | 69 | err = cudaGetLastError(); 70 | if (cudaSuccess != err) { 71 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 72 | exit(-1); 73 | } 74 | } 75 | 76 | 77 | __global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points, 78 | const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) { 79 | // points: (B, C, M) 80 | // idx: (B, N, 3) 81 | // weight: (B, N, 3) 82 | // output: 83 | // out: (B, C, N) 84 | 85 | int bs_idx = blockIdx.z; 86 | int c_idx = blockIdx.y; 87 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 88 | 89 | if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; 90 | 91 | weight += bs_idx * n * 3 + pt_idx * 3; 92 | points += bs_idx * c * m + c_idx * m; 93 | idx += bs_idx * n * 3 + pt_idx * 3; 94 | out += bs_idx * c * n + c_idx * n; 95 | 96 | out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]]; 97 | } 98 | 99 | void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, 100 | const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) { 101 | // points: (B, C, M) 102 | // idx: (B, N, 3) 103 | // weight: (B, N, 3) 104 | // output: 105 | // out: (B, C, N) 106 | 107 | cudaError_t err; 108 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 109 | dim3 threads(THREADS_PER_BLOCK); 110 | three_interpolate_kernel_fast<<>>(b, c, m, n, points, idx, weight, out); 111 | 112 | err = cudaGetLastError(); 113 | if (cudaSuccess != err) { 114 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 115 | exit(-1); 116 | } 117 | } 118 | 119 | 120 | __global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out, 121 | const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) { 122 | // grad_out: (B, C, N) 123 | // weight: (B, N, 3) 124 | // output: 125 | // grad_points: (B, C, M) 126 | 127 | int bs_idx = blockIdx.z; 128 | int c_idx = blockIdx.y; 129 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 130 | 131 | if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; 132 | 133 | grad_out += bs_idx * c * n + c_idx * n + pt_idx; 134 | weight += bs_idx * n * 3 + pt_idx * 3; 135 | grad_points += bs_idx * c * m + c_idx * m; 136 | idx += bs_idx * n * 3 + pt_idx * 3; 137 | 138 | 139 | atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]); 140 | atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]); 141 | atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]); 142 | } 143 | 144 | void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out, 145 | const int *idx, const float *weight, float *grad_points, cudaStream_t stream) { 146 | // grad_out: (B, C, N) 147 | // weight: (B, N, 3) 148 | // output: 149 | // grad_points: (B, C, M) 150 | 151 | cudaError_t err; 152 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 153 | dim3 threads(THREADS_PER_BLOCK); 154 | three_interpolate_grad_kernel_fast<<>>(b, c, n, m, grad_out, idx, 
weight, grad_points); 155 | 156 | err = cudaGetLastError(); 157 | if (cudaSuccess != err) { 158 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 159 | exit(-1); 160 | } 161 | } -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/src/interpolate_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _INTERPOLATE_GPU_H 2 | #define _INTERPOLATE_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, 11 | at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor); 12 | 13 | void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown, 14 | const float *known, float *dist2, int *idx, cudaStream_t stream); 15 | 16 | 17 | void three_interpolate_wrapper_fast(int b, int c, int m, int n, at::Tensor points_tensor, 18 | at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor); 19 | 20 | void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, 21 | const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream); 22 | 23 | 24 | void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, at::Tensor grad_out_tensor, 25 | at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor); 26 | 27 | void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out, 28 | const int *idx, const float *weight, float *grad_points, cudaStream_t stream); 29 | 30 | #endif 31 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/src/pointnet2_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "ball_query_gpu.h" 5 | #include "group_points_gpu.h" 6 | #include "sampling_gpu.h" 7 | #include "interpolate_gpu.h" 8 | 9 | 10 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 11 | m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast"); 12 | 13 | m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast"); 14 | m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast"); 15 | 16 | m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast"); 17 | m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast"); 18 | 19 | m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper"); 20 | 21 | m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast"); 22 | m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast"); 23 | m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast"); 24 | } 25 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "sampling_gpu.h" 7 | 8 | extern THCState *state; 9 | 10 | 11 | int gather_points_wrapper_fast(int b, int c, int n, int npoints, 12 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){ 13 | const float *points = 
points_tensor.data(); 14 | const int *idx = idx_tensor.data(); 15 | float *out = out_tensor.data(); 16 | 17 | cudaStream_t stream = THCState_getCurrentStream(state); 18 | gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream); 19 | return 1; 20 | } 21 | 22 | 23 | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, 24 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { 25 | 26 | const float *grad_out = grad_out_tensor.data(); 27 | const int *idx = idx_tensor.data(); 28 | float *grad_points = grad_points_tensor.data(); 29 | 30 | cudaStream_t stream = THCState_getCurrentStream(state); 31 | gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream); 32 | return 1; 33 | } 34 | 35 | 36 | int furthest_point_sampling_wrapper(int b, int n, int m, 37 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) { 38 | 39 | const float *points = points_tensor.data(); 40 | float *temp = temp_tensor.data(); 41 | int *idx = idx_tensor.data(); 42 | 43 | cudaStream_t stream = THCState_getCurrentStream(state); 44 | furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream); 45 | return 1; 46 | } 47 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | #include "sampling_gpu.h" 6 | 7 | 8 | __global__ void gather_points_kernel_fast(int b, int c, int n, int m, 9 | const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) { 10 | // points: (B, C, N) 11 | // idx: (B, M) 12 | // output: 13 | // out: (B, C, M) 14 | 15 | int bs_idx = blockIdx.z; 16 | int c_idx = blockIdx.y; 17 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 18 | if (bs_idx >= b || c_idx >= c || pt_idx >= m) return; 19 | 20 | out += bs_idx * c * m + c_idx * m + pt_idx; 21 | idx += bs_idx * m + pt_idx; 22 | points += bs_idx * c * n + c_idx * n; 23 | out[0] = points[idx[0]]; 24 | } 25 | 26 | void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, 27 | const float *points, const int *idx, float *out, cudaStream_t stream) { 28 | // points: (B, C, N) 29 | // idx: (B, npoints) 30 | // output: 31 | // out: (B, C, npoints) 32 | 33 | cudaError_t err; 34 | dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 35 | dim3 threads(THREADS_PER_BLOCK); 36 | 37 | gather_points_kernel_fast<<>>(b, c, n, npoints, points, idx, out); 38 | 39 | err = cudaGetLastError(); 40 | if (cudaSuccess != err) { 41 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 42 | exit(-1); 43 | } 44 | } 45 | 46 | __global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out, 47 | const int *__restrict__ idx, float *__restrict__ grad_points) { 48 | // grad_out: (B, C, M) 49 | // idx: (B, M) 50 | // output: 51 | // grad_points: (B, C, N) 52 | 53 | int bs_idx = blockIdx.z; 54 | int c_idx = blockIdx.y; 55 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 56 | if (bs_idx >= b || c_idx >= c || pt_idx >= m) return; 57 | 58 | grad_out += bs_idx * c * m + c_idx * m + pt_idx; 59 | idx += bs_idx * m + pt_idx; 60 | grad_points += bs_idx * c * n + c_idx * n; 61 | 62 | atomicAdd(grad_points + idx[0], grad_out[0]); 63 | } 64 | 65 | void gather_points_grad_kernel_launcher_fast(int b, 
int c, int n, int npoints, 66 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) { 67 | // grad_out: (B, C, npoints) 68 | // idx: (B, npoints) 69 | // output: 70 | // grad_points: (B, C, N) 71 | 72 | cudaError_t err; 73 | dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 74 | dim3 threads(THREADS_PER_BLOCK); 75 | 76 | gather_points_grad_kernel_fast<<>>(b, c, n, npoints, grad_out, idx, grad_points); 77 | 78 | err = cudaGetLastError(); 79 | if (cudaSuccess != err) { 80 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 81 | exit(-1); 82 | } 83 | } 84 | 85 | 86 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){ 87 | const float v1 = dists[idx1], v2 = dists[idx2]; 88 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 89 | dists[idx1] = max(v1, v2); 90 | dists_i[idx1] = v2 > v1 ? i2 : i1; 91 | } 92 | 93 | template 94 | __global__ void furthest_point_sampling_kernel(int b, int n, int m, 95 | const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) { 96 | // dataset: (B, N, 3) 97 | // tmp: (B, N) 98 | // output: 99 | // idx: (B, M) 100 | 101 | if (m <= 0) return; 102 | __shared__ float dists[block_size]; 103 | __shared__ int dists_i[block_size]; 104 | 105 | int batch_index = blockIdx.x; 106 | dataset += batch_index * n * 3; 107 | temp += batch_index * n; 108 | idxs += batch_index * m; 109 | 110 | int tid = threadIdx.x; 111 | const int stride = block_size; 112 | 113 | int old = 0; 114 | if (threadIdx.x == 0) 115 | idxs[0] = old; 116 | 117 | __syncthreads(); 118 | for (int j = 1; j < m; j++) { 119 | int besti = 0; 120 | float best = -1; 121 | float x1 = dataset[old * 3 + 0]; 122 | float y1 = dataset[old * 3 + 1]; 123 | float z1 = dataset[old * 3 + 2]; 124 | for (int k = tid; k < n; k += stride) { 125 | float x2, y2, z2; 126 | x2 = dataset[k * 3 + 0]; 127 | y2 = dataset[k * 3 + 1]; 128 | z2 = dataset[k * 3 + 2]; 129 | // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 130 | // if (mag <= 1e-3) 131 | // continue; 132 | 133 | float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 134 | float d2 = min(d, temp[k]); 135 | temp[k] = d2; 136 | besti = d2 > best ? k : besti; 137 | best = d2 > best ? 
d2 : best; 138 | } 139 | dists[tid] = best; 140 | dists_i[tid] = besti; 141 | __syncthreads(); 142 | 143 | if (block_size >= 1024) { 144 | if (tid < 512) { 145 | __update(dists, dists_i, tid, tid + 512); 146 | } 147 | __syncthreads(); 148 | } 149 | 150 | if (block_size >= 512) { 151 | if (tid < 256) { 152 | __update(dists, dists_i, tid, tid + 256); 153 | } 154 | __syncthreads(); 155 | } 156 | if (block_size >= 256) { 157 | if (tid < 128) { 158 | __update(dists, dists_i, tid, tid + 128); 159 | } 160 | __syncthreads(); 161 | } 162 | if (block_size >= 128) { 163 | if (tid < 64) { 164 | __update(dists, dists_i, tid, tid + 64); 165 | } 166 | __syncthreads(); 167 | } 168 | if (block_size >= 64) { 169 | if (tid < 32) { 170 | __update(dists, dists_i, tid, tid + 32); 171 | } 172 | __syncthreads(); 173 | } 174 | if (block_size >= 32) { 175 | if (tid < 16) { 176 | __update(dists, dists_i, tid, tid + 16); 177 | } 178 | __syncthreads(); 179 | } 180 | if (block_size >= 16) { 181 | if (tid < 8) { 182 | __update(dists, dists_i, tid, tid + 8); 183 | } 184 | __syncthreads(); 185 | } 186 | if (block_size >= 8) { 187 | if (tid < 4) { 188 | __update(dists, dists_i, tid, tid + 4); 189 | } 190 | __syncthreads(); 191 | } 192 | if (block_size >= 4) { 193 | if (tid < 2) { 194 | __update(dists, dists_i, tid, tid + 2); 195 | } 196 | __syncthreads(); 197 | } 198 | if (block_size >= 2) { 199 | if (tid < 1) { 200 | __update(dists, dists_i, tid, tid + 1); 201 | } 202 | __syncthreads(); 203 | } 204 | 205 | old = dists_i[0]; 206 | if (tid == 0) 207 | idxs[j] = old; 208 | } 209 | } 210 | 211 | void furthest_point_sampling_kernel_launcher(int b, int n, int m, 212 | const float *dataset, float *temp, int *idxs, cudaStream_t stream) { 213 | // dataset: (B, N, 3) 214 | // tmp: (B, N) 215 | // output: 216 | // idx: (B, M) 217 | 218 | cudaError_t err; 219 | unsigned int n_threads = opt_n_threads(n); 220 | 221 | switch (n_threads) { 222 | case 1024: 223 | furthest_point_sampling_kernel<1024><<>>(b, n, m, dataset, temp, idxs); break; 224 | case 512: 225 | furthest_point_sampling_kernel<512><<>>(b, n, m, dataset, temp, idxs); break; 226 | case 256: 227 | furthest_point_sampling_kernel<256><<>>(b, n, m, dataset, temp, idxs); break; 228 | case 128: 229 | furthest_point_sampling_kernel<128><<>>(b, n, m, dataset, temp, idxs); break; 230 | case 64: 231 | furthest_point_sampling_kernel<64><<>>(b, n, m, dataset, temp, idxs); break; 232 | case 32: 233 | furthest_point_sampling_kernel<32><<>>(b, n, m, dataset, temp, idxs); break; 234 | case 16: 235 | furthest_point_sampling_kernel<16><<>>(b, n, m, dataset, temp, idxs); break; 236 | case 8: 237 | furthest_point_sampling_kernel<8><<>>(b, n, m, dataset, temp, idxs); break; 238 | case 4: 239 | furthest_point_sampling_kernel<4><<>>(b, n, m, dataset, temp, idxs); break; 240 | case 2: 241 | furthest_point_sampling_kernel<2><<>>(b, n, m, dataset, temp, idxs); break; 242 | case 1: 243 | furthest_point_sampling_kernel<1><<>>(b, n, m, dataset, temp, idxs); break; 244 | default: 245 | furthest_point_sampling_kernel<512><<>>(b, n, m, dataset, temp, idxs); 246 | } 247 | 248 | err = cudaGetLastError(); 249 | if (cudaSuccess != err) { 250 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 251 | exit(-1); 252 | } 253 | } 254 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/pointnet2/src/sampling_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _SAMPLING_GPU_H 2 | #define 
_SAMPLING_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | 9 | int gather_points_wrapper_fast(int b, int c, int n, int npoints, 10 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); 11 | 12 | void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, 13 | const float *points, const int *idx, float *out, cudaStream_t stream); 14 | 15 | 16 | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, 17 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); 18 | 19 | void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, 20 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); 21 | 22 | 23 | int furthest_point_sampling_wrapper(int b, int n, int m, 24 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor); 25 | 26 | void furthest_point_sampling_kernel_launcher(int b, int n, int m, 27 | const float *dataset, float *temp, int *idxs, cudaStream_t stream); 28 | 29 | #endif 30 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/tools/_init_path.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../')) 3 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/tools/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch.utils.data as torch_data 4 | import kitti_utils 5 | import cv2 6 | from PIL import Image 7 | 8 | 9 | USE_INTENSITY = False 10 | 11 | 12 | class KittiDataset(torch_data.Dataset): 13 | def __init__(self, root_dir, split='train', mode='TRAIN'): 14 | self.split = split 15 | self.mode = mode 16 | self.classes = ['Car'] 17 | is_test = self.split == 'test' 18 | self.imageset_dir = os.path.join(root_dir, 'KITTI', 'object', 'testing' if is_test else 'training') 19 | 20 | split_dir = os.path.join(root_dir, 'KITTI', 'ImageSets', split + '.txt') 21 | self.image_idx_list = [x.strip() for x in open(split_dir).readlines()] 22 | self.sample_id_list = [int(sample_id) for sample_id in self.image_idx_list] 23 | self.num_sample = self.image_idx_list.__len__() 24 | 25 | self.npoints = 16384 26 | 27 | self.image_dir = os.path.join(self.imageset_dir, 'image_2') 28 | self.lidar_dir = os.path.join(self.imageset_dir, 'velodyne') 29 | self.calib_dir = os.path.join(self.imageset_dir, 'calib') 30 | self.label_dir = os.path.join(self.imageset_dir, 'label_2') 31 | self.plane_dir = os.path.join(self.imageset_dir, 'planes') 32 | 33 | def get_image(self, idx): 34 | img_file = os.path.join(self.image_dir, '%06d.png' % idx) 35 | assert os.path.exists(img_file) 36 | return cv2.imread(img_file) # (H, W, 3) BGR mode 37 | 38 | def get_image_shape(self, idx): 39 | img_file = os.path.join(self.image_dir, '%06d.png' % idx) 40 | assert os.path.exists(img_file) 41 | im = Image.open(img_file) 42 | width, height = im.size 43 | return height, width, 3 44 | 45 | def get_lidar(self, idx): 46 | lidar_file = os.path.join(self.lidar_dir, '%06d.bin' % idx) 47 | assert os.path.exists(lidar_file) 48 | return np.fromfile(lidar_file, dtype=np.float32).reshape(-1, 4) 49 | 50 | def get_calib(self, idx): 51 | calib_file = os.path.join(self.calib_dir, '%06d.txt' % idx) 52 | assert os.path.exists(calib_file) 53 | return kitti_utils.Calibration(calib_file) 54 | 
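    # [Editor's note, not part of the original file] The accessors above load one KITTI frame by its
    # 6-digit sample id: image_2/*.png via cv2 (BGR), velodyne/*.bin as an N x 4 float32 array
    # (x, y, z, intensity), and calib/*.txt via kitti_utils.Calibration; get_label() below parses label_2/*.txt.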
55 | def get_label(self, idx): 56 | label_file = os.path.join(self.label_dir, '%06d.txt' % idx) 57 | assert os.path.exists(label_file) 58 | return kitti_utils.get_objects_from_label(label_file) 59 | 60 | @staticmethod 61 | def get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape): 62 | val_flag_1 = np.logical_and(pts_img[:, 0] >= 0, pts_img[:, 0] < img_shape[1]) 63 | val_flag_2 = np.logical_and(pts_img[:, 1] >= 0, pts_img[:, 1] < img_shape[0]) 64 | val_flag_merge = np.logical_and(val_flag_1, val_flag_2) 65 | pts_valid_flag = np.logical_and(val_flag_merge, pts_rect_depth >= 0) 66 | return pts_valid_flag 67 | 68 | def filtrate_objects(self, obj_list): 69 | type_whitelist = self.classes 70 | if self.mode == 'TRAIN': 71 | type_whitelist = list(self.classes) 72 | if 'Car' in self.classes: 73 | type_whitelist.append('Van') 74 | 75 | valid_obj_list = [] 76 | for obj in obj_list: 77 | if obj.cls_type not in type_whitelist: 78 | continue 79 | 80 | valid_obj_list.append(obj) 81 | return valid_obj_list 82 | 83 | def __len__(self): 84 | return len(self.sample_id_list) 85 | 86 | def __getitem__(self, index): 87 | sample_id = int(self.sample_id_list[index]) 88 | calib = self.get_calib(sample_id) 89 | img_shape = self.get_image_shape(sample_id) 90 | pts_lidar = self.get_lidar(sample_id) 91 | 92 | # get valid point (projected points should be in image) 93 | pts_rect = calib.lidar_to_rect(pts_lidar[:, 0:3]) 94 | pts_intensity = pts_lidar[:, 3] 95 | 96 | pts_img, pts_rect_depth = calib.rect_to_img(pts_rect) 97 | pts_valid_flag = self.get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape) 98 | 99 | pts_rect = pts_rect[pts_valid_flag][:, 0:3] 100 | pts_intensity = pts_intensity[pts_valid_flag] 101 | 102 | if self.npoints < len(pts_rect): 103 | pts_depth = pts_rect[:, 2] 104 | pts_near_flag = pts_depth < 40.0 105 | far_idxs_choice = np.where(pts_near_flag == 0)[0] 106 | near_idxs = np.where(pts_near_flag == 1)[0] 107 | near_idxs_choice = np.random.choice(near_idxs, self.npoints - len(far_idxs_choice), replace=False) 108 | 109 | choice = np.concatenate((near_idxs_choice, far_idxs_choice), axis=0) \ 110 | if len(far_idxs_choice) > 0 else near_idxs_choice 111 | np.random.shuffle(choice) 112 | else: 113 | choice = np.arange(0, len(pts_rect), dtype=np.int32) 114 | if self.npoints > len(pts_rect): 115 | extra_choice = np.random.choice(choice, self.npoints - len(pts_rect), replace=False) 116 | choice = np.concatenate((choice, extra_choice), axis=0) 117 | np.random.shuffle(choice) 118 | 119 | ret_pts_rect = pts_rect[choice, :] 120 | ret_pts_intensity = pts_intensity[choice] - 0.5 # translate intensity to [-0.5, 0.5] 121 | 122 | pts_features = [ret_pts_intensity.reshape(-1, 1)] 123 | ret_pts_features = np.concatenate(pts_features, axis=1) if pts_features.__len__() > 1 else pts_features[0] 124 | 125 | sample_info = {'sample_id': sample_id} 126 | 127 | if self.mode == 'TEST': 128 | if USE_INTENSITY: 129 | pts_input = np.concatenate((ret_pts_rect, ret_pts_features), axis=1) # (N, C) 130 | else: 131 | pts_input = ret_pts_rect 132 | sample_info['pts_input'] = pts_input 133 | sample_info['pts_rect'] = ret_pts_rect 134 | sample_info['pts_features'] = ret_pts_features 135 | return sample_info 136 | 137 | gt_obj_list = self.filtrate_objects(self.get_label(sample_id)) 138 | 139 | gt_boxes3d = kitti_utils.objs_to_boxes3d(gt_obj_list) 140 | 141 | # prepare input 142 | if USE_INTENSITY: 143 | pts_input = np.concatenate((ret_pts_rect, ret_pts_features), axis=1) # (N, C) 144 | else: 145 | pts_input = ret_pts_rect 146 | 
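        # [Editor's note, not part of the original file] generate_training_labels() below marks each point as
        # 1 (inside a ground-truth box), -1 (only inside the slightly enlarged box, treated as "ignore" by
        # DiceLoss(ignore_target=-1) in tools/train_and_eval.py), or 0 (background).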
147 | # generate training labels 148 | cls_labels = self.generate_training_labels(ret_pts_rect, gt_boxes3d) 149 | sample_info['pts_input'] = pts_input 150 | sample_info['pts_rect'] = ret_pts_rect 151 | sample_info['cls_labels'] = cls_labels 152 | return sample_info 153 | 154 | @staticmethod 155 | def generate_training_labels(pts_rect, gt_boxes3d): 156 | cls_label = np.zeros((pts_rect.shape[0]), dtype=np.int32) 157 | gt_corners = kitti_utils.boxes3d_to_corners3d(gt_boxes3d, rotate=True) 158 | extend_gt_boxes3d = kitti_utils.enlarge_box3d(gt_boxes3d, extra_width=0.2) 159 | extend_gt_corners = kitti_utils.boxes3d_to_corners3d(extend_gt_boxes3d, rotate=True) 160 | for k in range(gt_boxes3d.shape[0]): 161 | box_corners = gt_corners[k] 162 | fg_pt_flag = kitti_utils.in_hull(pts_rect, box_corners) 163 | cls_label[fg_pt_flag] = 1 164 | 165 | # enlarge the bbox3d, ignore nearby points 166 | extend_box_corners = extend_gt_corners[k] 167 | fg_enlarge_flag = kitti_utils.in_hull(pts_rect, extend_box_corners) 168 | ignore_flag = np.logical_xor(fg_pt_flag, fg_enlarge_flag) 169 | cls_label[ignore_flag] = -1 170 | 171 | return cls_label 172 | 173 | def collate_batch(self, batch): 174 | batch_size = batch.__len__() 175 | ans_dict = {} 176 | 177 | for key in batch[0].keys(): 178 | if isinstance(batch[0][key], np.ndarray): 179 | ans_dict[key] = np.concatenate([batch[k][key][np.newaxis, ...] for k in range(batch_size)], axis=0) 180 | 181 | else: 182 | ans_dict[key] = [batch[k][key] for k in range(batch_size)] 183 | if isinstance(batch[0][key], int): 184 | ans_dict[key] = np.array(ans_dict[key], dtype=np.int32) 185 | elif isinstance(batch[0][key], float): 186 | ans_dict[key] = np.array(ans_dict[key], dtype=np.float32) 187 | 188 | return ans_dict 189 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/tools/pointnet2_msg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pointnet2.pointnet2_modules import PointnetFPModule, PointnetSAModuleMSG 4 | import pointnet2.pytorch_utils as pt_utils 5 | 6 | 7 | def get_model(input_channels=0): 8 | return Pointnet2MSG(input_channels=input_channels) 9 | 10 | 11 | NPOINTS = [4096, 1024, 256, 64] 12 | RADIUS = [[0.1, 0.5], [0.5, 1.0], [1.0, 2.0], [2.0, 4.0]] 13 | NSAMPLE = [[16, 32], [16, 32], [16, 32], [16, 32]] 14 | MLPS = [[[16, 16, 32], [32, 32, 64]], [[64, 64, 128], [64, 96, 128]], 15 | [[128, 196, 256], [128, 196, 256]], [[256, 256, 512], [256, 384, 512]]] 16 | FP_MLPS = [[128, 128], [256, 256], [512, 512], [512, 512]] 17 | CLS_FC = [128] 18 | DP_RATIO = 0.5 19 | 20 | 21 | class Pointnet2MSG(nn.Module): 22 | def __init__(self, input_channels=6): 23 | super().__init__() 24 | 25 | self.SA_modules = nn.ModuleList() 26 | channel_in = input_channels 27 | 28 | skip_channel_list = [input_channels] 29 | for k in range(NPOINTS.__len__()): 30 | mlps = MLPS[k].copy() 31 | channel_out = 0 32 | for idx in range(mlps.__len__()): 33 | mlps[idx] = [channel_in] + mlps[idx] 34 | channel_out += mlps[idx][-1] 35 | 36 | self.SA_modules.append( 37 | PointnetSAModuleMSG( 38 | npoint=NPOINTS[k], 39 | radii=RADIUS[k], 40 | nsamples=NSAMPLE[k], 41 | mlps=mlps, 42 | use_xyz=True, 43 | bn=True 44 | ) 45 | ) 46 | skip_channel_list.append(channel_out) 47 | channel_in = channel_out 48 | 49 | self.FP_modules = nn.ModuleList() 50 | 51 | for k in range(FP_MLPS.__len__()): 52 | pre_channel = FP_MLPS[k + 1][-1] if k + 1 < len(FP_MLPS) else channel_out 53 | 
self.FP_modules.append( 54 | PointnetFPModule(mlp=[pre_channel + skip_channel_list[k]] + FP_MLPS[k]) 55 | ) 56 | 57 | cls_layers = [] 58 | pre_channel = FP_MLPS[0][-1] 59 | for k in range(0, CLS_FC.__len__()): 60 | cls_layers.append(pt_utils.Conv1d(pre_channel, CLS_FC[k], bn=True)) 61 | pre_channel = CLS_FC[k] 62 | cls_layers.append(pt_utils.Conv1d(pre_channel, 1, activation=None)) 63 | cls_layers.insert(1, nn.Dropout(0.5)) 64 | self.cls_layer = nn.Sequential(*cls_layers) 65 | 66 | def _break_up_pc(self, pc): 67 | xyz = pc[..., 0:3].contiguous() 68 | features = ( 69 | pc[..., 3:].transpose(1, 2).contiguous() 70 | if pc.size(-1) > 3 else None 71 | ) 72 | 73 | return xyz, features 74 | 75 | def forward(self, pointcloud: torch.cuda.FloatTensor): 76 | xyz, features = self._break_up_pc(pointcloud) 77 | 78 | l_xyz, l_features = [xyz], [features] 79 | for i in range(len(self.SA_modules)): 80 | li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i]) 81 | l_xyz.append(li_xyz) 82 | l_features.append(li_features) 83 | 84 | for i in range(-1, -(len(self.FP_modules) + 1), -1): 85 | l_features[i - 1] = self.FP_modules[i]( 86 | l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i] 87 | ) 88 | 89 | pred_cls = self.cls_layer(l_features[0]).transpose(1, 2).contiguous() # (B, N, 1) 90 | return pred_cls 91 | -------------------------------------------------------------------------------- /utils/Pointnet2.PyTorch/tools/train_and_eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.optim.lr_scheduler as lr_sched 7 | from torch.nn.utils import clip_grad_norm_ 8 | from torch.utils.data import DataLoader 9 | import tensorboard_logger as tb_log 10 | from data.dataset import KittiDataset 11 | import argparse 12 | import importlib 13 | 14 | parser = argparse.ArgumentParser(description="Arg parser") 15 | parser.add_argument("--batch_size", type=int, default=8) 16 | parser.add_argument("--epochs", type=int, default=100) 17 | parser.add_argument("--ckpt_save_interval", type=int, default=5) 18 | parser.add_argument('--workers', type=int, default=4) 19 | parser.add_argument("--mode", type=str, default='train') 20 | parser.add_argument("--ckpt", type=str, default='None') 21 | 22 | parser.add_argument("--net", type=str, default='pointnet2_msg') 23 | 24 | parser.add_argument('--lr', type=float, default=0.002) 25 | parser.add_argument('--lr_decay', type=float, default=0.2) 26 | parser.add_argument('--lr_clip', type=float, default=0.000001) 27 | parser.add_argument('--decay_step_list', type=list, default=[50, 70, 80, 90]) 28 | parser.add_argument('--weight_decay', type=float, default=0.001) 29 | 30 | parser.add_argument("--output_dir", type=str, default='output') 31 | parser.add_argument("--extra_tag", type=str, default='default') 32 | 33 | args = parser.parse_args() 34 | 35 | FG_THRESH = 0.3 36 | 37 | 38 | def log_print(info, log_f=None): 39 | print(info) 40 | if log_f is not None: 41 | print(info, file=log_f) 42 | 43 | 44 | class DiceLoss(nn.Module): 45 | def __init__(self, ignore_target=-1): 46 | super().__init__() 47 | self.ignore_target = ignore_target 48 | 49 | def forward(self, input, target): 50 | """ 51 | :param input: (N), logit 52 | :param target: (N), {0, 1} 53 | :return: 54 | """ 55 | input = torch.sigmoid(input.view(-1)) 56 | target = target.float().view(-1) 57 | mask = (target != self.ignore_target).float() 58 | return 1.0 - 
(torch.min(input, target) * mask).sum() / torch.clamp((torch.max(input, target) * mask).sum(), min=1.0) 59 | 60 | 61 | def train_one_epoch(model, train_loader, optimizer, epoch, lr_scheduler, total_it, tb_log, log_f): 62 | model.train() 63 | log_print('===============TRAIN EPOCH %d================' % epoch, log_f=log_f) 64 | loss_func = DiceLoss(ignore_target=-1) 65 | 66 | for it, batch in enumerate(train_loader): 67 | optimizer.zero_grad() 68 | 69 | pts_input, cls_labels = batch['pts_input'], batch['cls_labels'] 70 | pts_input = torch.from_numpy(pts_input).cuda(non_blocking=True).float() 71 | cls_labels = torch.from_numpy(cls_labels).cuda(non_blocking=True).long().view(-1) 72 | 73 | pred_cls = model(pts_input) 74 | pred_cls = pred_cls.view(-1) 75 | 76 | loss = loss_func(pred_cls, cls_labels) 77 | loss.backward() 78 | clip_grad_norm_(model.parameters(), 1.0) 79 | optimizer.step() 80 | 81 | total_it += 1 82 | 83 | pred_class = (torch.sigmoid(pred_cls) > FG_THRESH) 84 | fg_mask = cls_labels > 0 85 | correct = ((pred_class.long() == cls_labels) & fg_mask).float().sum() 86 | union = fg_mask.sum().float() + (pred_class > 0).sum().float() - correct 87 | iou = correct / torch.clamp(union, min=1.0) 88 | 89 | cur_lr = lr_scheduler.get_lr()[0] 90 | tb_log.log_value('learning_rate', cur_lr, epoch) 91 | if tb_log is not None: 92 | tb_log.log_value('train_loss', loss, total_it) 93 | tb_log.log_value('train_fg_iou', iou, total_it) 94 | 95 | log_print('training epoch %d: it=%d/%d, total_it=%d, loss=%.5f, fg_iou=%.3f, lr=%f' % 96 | (epoch, it, len(train_loader), total_it, loss.item(), iou.item(), cur_lr), log_f=log_f) 97 | 98 | return total_it 99 | 100 | 101 | def eval_one_epoch(model, eval_loader, epoch, tb_log=None, log_f=None): 102 | model.train() 103 | log_print('===============EVAL EPOCH %d================' % epoch, log_f=log_f) 104 | 105 | iou_list = [] 106 | for it, batch in enumerate(eval_loader): 107 | pts_input, cls_labels = batch['pts_input'], batch['cls_labels'] 108 | pts_input = torch.from_numpy(pts_input).cuda(non_blocking=True).float() 109 | cls_labels = torch.from_numpy(cls_labels).cuda(non_blocking=True).long().view(-1) 110 | 111 | pred_cls = model(pts_input) 112 | pred_cls = pred_cls.view(-1) 113 | 114 | pred_class = (torch.sigmoid(pred_cls) > FG_THRESH) 115 | fg_mask = cls_labels > 0 116 | correct = ((pred_class.long() == cls_labels) & fg_mask).float().sum() 117 | union = fg_mask.sum().float() + (pred_class > 0).sum().float() - correct 118 | iou = correct / torch.clamp(union, min=1.0) 119 | 120 | iou_list.append(iou.item()) 121 | log_print('EVAL: it=%d/%d, iou=%.3f' % (it, len(eval_loader), iou), log_f=log_f) 122 | 123 | iou_list = np.array(iou_list) 124 | avg_iou = iou_list.mean() 125 | if tb_log is not None: 126 | tb_log.log_value('eval_fg_iou', avg_iou, epoch) 127 | 128 | log_print('\nEpoch %d: Average IoU (samples=%d): %.6f' % (epoch, iou_list.__len__(), avg_iou), log_f=log_f) 129 | return avg_iou 130 | 131 | 132 | def save_checkpoint(model, epoch, ckpt_name): 133 | if isinstance(model, torch.nn.DataParallel): 134 | model_state = model.module.state_dict() 135 | else: 136 | model_state = model.state_dict() 137 | 138 | state = {'epoch': epoch, 'model_state': model_state} 139 | ckpt_name = '{}.pth'.format(ckpt_name) 140 | torch.save(state, ckpt_name) 141 | 142 | 143 | def load_checkpoint(model, filename): 144 | if os.path.isfile(filename): 145 | log_print("==> Loading from checkpoint %s" % filename) 146 | checkpoint = torch.load(filename) 147 | epoch = checkpoint['epoch'] 148 | 
model.load_state_dict(checkpoint['model_state']) 149 | log_print("==> Done") 150 | else: 151 | raise FileNotFoundError 152 | 153 | return epoch 154 | 155 | 156 | def train_and_eval(model, train_loader, eval_loader, tb_log, ckpt_dir, log_f): 157 | model.cuda() 158 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 159 | 160 | def lr_lbmd(cur_epoch): 161 | cur_decay = 1 162 | for decay_step in args.decay_step_list: 163 | if cur_epoch >= decay_step: 164 | cur_decay = cur_decay * args.lr_decay 165 | return max(cur_decay, args.lr_clip / args.lr) 166 | 167 | lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd) 168 | 169 | total_it = 0 170 | for epoch in range(1, args.epochs + 1): 171 | lr_scheduler.step(epoch) 172 | total_it = train_one_epoch(model, train_loader, optimizer, epoch, lr_scheduler, total_it, tb_log, log_f) 173 | 174 | if epoch % args.ckpt_save_interval == 0: 175 | with torch.no_grad(): 176 | avg_iou = eval_one_epoch(model, eval_loader, epoch, tb_log, log_f) 177 | ckpt_name = os.path.join(ckpt_dir, 'checkpoint_epoch_%d' % epoch) 178 | save_checkpoint(model, epoch, ckpt_name) 179 | 180 | 181 | if __name__ == '__main__': 182 | MODEL = importlib.import_module(args.net) # import network module 183 | model = MODEL.get_model(input_channels=0) 184 | 185 | eval_set = KittiDataset(root_dir='data', mode='EVAL', split='val') 186 | eval_loader = DataLoader(eval_set, batch_size=args.batch_size, shuffle=False, pin_memory=True, 187 | num_workers=args.workers, collate_fn=eval_set.collate_batch) 188 | 189 | if args.mode == 'train': 190 | train_set = KittiDataset(root_dir='data', mode='TRAIN', split='train') 191 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, pin_memory=True, 192 | num_workers=args.workers, collate_fn=train_set.collate_batch) 193 | # output dir config 194 | output_dir = os.path.join(args.output_dir, args.extra_tag) 195 | os.makedirs(output_dir, exist_ok=True) 196 | tb_log.configure(os.path.join(output_dir, 'tensorboard')) 197 | ckpt_dir = os.path.join(output_dir, 'ckpt') 198 | os.makedirs(ckpt_dir, exist_ok=True) 199 | 200 | log_file = os.path.join(output_dir, 'log.txt') 201 | log_f = open(log_file, 'w') 202 | 203 | for key, val in vars(args).items(): 204 | log_print("{:16} {}".format(key, val), log_f=log_f) 205 | 206 | # train and eval 207 | train_and_eval(model, train_loader, eval_loader, tb_log, ckpt_dir, log_f) 208 | log_f.close() 209 | elif args.mode == 'eval': 210 | epoch = load_checkpoint(model, args.ckpt) 211 | model.cuda() 212 | with torch.no_grad(): 213 | avg_iou = eval_one_epoch(model, eval_loader, epoch) 214 | else: 215 | raise NotImplementedError 216 | 217 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/emd/CDEMD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XLechter/SDT/d87587cc70c4d7bb03fe4a795471984b5a2323ac/utils/emd/CDEMD.png -------------------------------------------------------------------------------- /utils/emd/README.md: -------------------------------------------------------------------------------- 1 | ## Earth Mover's Distance of point clouds 2 | 3 | ![](/utils/emdls/emd/CDEMD.png) 4 | 5 | Compared to the Chamfer Distance (CD), the Earth Mover's Distance (EMD) is more reliable to 
distinguish the visual quality of the point clouds. See our [paper](http://cseweb.ucsd.edu/~mil070/projects/AAAI2020/paper.pdf) for more details. 6 | 7 | We provide an EMD implementation for point cloud comparison that only needs $O(n)$ memory and thus supports dense point clouds (10,000 points or more) and large batch sizes. It is based on an approximate auction algorithm, so the assignment is a near-bijection rather than a guaranteed bijection. A parameter $\epsilon$ balances the error rate against the speed of convergence: a smaller $\epsilon$ gives more accurate results but needs more iterations to converge. The time complexity is $O(n^2k)$, where $k$ is the number of iterations. We use $\epsilon = 0.005, k = 50$ during training and $\epsilon = 0.002, k = 10000$ during testing. 8 | 9 | ### Compile 10 | Run `python3 setup.py install` to compile. 11 | 12 | ### Example 13 | See `test_emd()` in `emd_module.py` for an example. 14 | 15 | ### Input 16 | 17 | - **xyz1, xyz2**: float tensors with shape `[#batch, #points, 3]`. `xyz1` is the predicted point cloud and `xyz2` is the ground-truth point cloud. The two point clouds must have the same size and be normalized to [0, 1]. The number of points should be a multiple of 1024 and the batch size no greater than 512. Gradients are only computed for `xyz1`, so do not swap `xyz1` and `xyz2`. 18 | - **eps**: a float, the parameter that balances the error rate against the speed of convergence. 19 | - **iters**: an int, the number of iterations. 20 | 21 | ### Output 22 | 23 | - **dist**: a float tensor with shape `[#batch, #points]`; `sqrt(dist)` gives the L2 distances between matched pairs of points. 24 | - **assignment**: an int tensor with shape `[#batch, #points]`, the index of the matched point in the ground-truth point cloud. 
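A minimal usage sketch (editor's addition, not part of the original README) following the Input/Output description above. It assumes the compiled `emd` CUDA extension is importable and a GPU is available; the tensor names `pred`/`gt` are illustrative, and the `eps`/`iters` values are the testing-time setting quoted in this file.

```python
import torch
from emd_module import emdModule  # requires the compiled `emd` CUDA extension

emd = emdModule()
pred = torch.rand(4, 2048, 3).cuda()  # predicted cloud, values in [0, 1], 2048 = 2 * 1024 points
gt = torch.rand(4, 2048, 3).cuda()    # ground-truth cloud, same size as pred

# testing setting from above: eps = 0.002, iters = 10000
dist, assignment = emd(pred, gt, 0.002, 10000)
emd_loss = torch.sqrt(dist).mean()    # sqrt(dist) are the per-point L2 distances
```

During training, the looser setting quoted above (`eps = 0.005`, `iters = 50`) trades accuracy for speed.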
25 | -------------------------------------------------------------------------------- /utils/emd/clean.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | rm -rf __pycache__/ 4 | rm -rf dist/ 5 | rm -rf build/ 6 | rm -rf emd.egg-info/ 7 | rm -rf /mnt/lustre/chenxinyi1/.conda/envs/pt/lib/python3.7/site-packages/emd-0.0.0-py3.7-linux-x86_64.egg/ 8 | -------------------------------------------------------------------------------- /utils/emd/emd.cpp: -------------------------------------------------------------------------------- 1 | // EMD approximation module (based on auction algorithm) 2 | // author: Minghua Liu 3 | #include 4 | #include 5 | 6 | int emd_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist, at::Tensor assignment, at::Tensor price, 7 | at::Tensor assignment_inv, at::Tensor bid, at::Tensor bid_increments, at::Tensor max_increments, 8 | at::Tensor unass_idx, at::Tensor unass_cnt, at::Tensor unass_cnt_sum, at::Tensor cnt_tmp, at::Tensor max_idx, float eps, int iters); 9 | 10 | int emd_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz, at::Tensor graddist, at::Tensor idx); 11 | 12 | 13 | 14 | int emd_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist, at::Tensor assignment, at::Tensor price, 15 | at::Tensor assignment_inv, at::Tensor bid, at::Tensor bid_increments, at::Tensor max_increments, 16 | at::Tensor unass_idx, at::Tensor unass_cnt, at::Tensor unass_cnt_sum, at::Tensor cnt_tmp, at::Tensor max_idx, float eps, int iters) { 17 | return emd_cuda_forward(xyz1, xyz2, dist, assignment, price, assignment_inv, bid, bid_increments, max_increments, unass_idx, unass_cnt, unass_cnt_sum, cnt_tmp, max_idx, eps, iters); 18 | } 19 | 20 | int emd_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz, at::Tensor graddist, at::Tensor idx) { 21 | 22 | return emd_cuda_backward(xyz1, xyz2, gradxyz, graddist, idx); 23 | } 24 | 25 | 26 | 27 | 28 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 29 | m.def("forward", &emd_forward, "emd forward (CUDA)"); 30 | m.def("backward", &emd_backward, "emd backward (CUDA)"); 31 | } -------------------------------------------------------------------------------- /utils/emd/emd_module.py: -------------------------------------------------------------------------------- 1 | # EMD approximation module (based on auction algorithm) 2 | # memory complexity: O(n) 3 | # time complexity: O(n^2 * iter) 4 | # author: Minghua Liu 5 | 6 | # Input: 7 | # xyz1, xyz2: [#batch, #points, 3] 8 | # where xyz1 is the predicted point cloud and xyz2 is the ground truth point cloud 9 | # two point clouds should have same size and be normalized to [0, 1] 10 | # #points should be a multiple of 1024 11 | # #batch should be no greater than 512 12 | # eps is a parameter which balances the error rate and the speed of convergence 13 | # iters is the number of iteration 14 | # we only calculate gradient for xyz1 15 | 16 | # Output: 17 | # dist: [#batch, #points], sqrt(dist) -> L2 distance 18 | # assignment: [#batch, #points], index of the matched point in the ground truth point cloud 19 | # the result is an approximation and the assignment is not guranteed to be a bijection 20 | 21 | import time 22 | import numpy as np 23 | import torch 24 | from torch import nn 25 | from torch.autograd import Function 26 | import emd 27 | 28 | 29 | 30 | 31 | class emdFunction(Function): 32 | @staticmethod 33 | def forward(ctx, xyz1, xyz2, eps, iters): 34 | 35 | batchsize, n, _ = xyz1.size() 36 | _, m, _ = 
xyz2.size() 37 | 38 | assert(n == m) 39 | assert(xyz1.size()[0] == xyz2.size()[0]) 40 | #assert(n % 1024 == 0) 41 | assert(batchsize <= 512) 42 | 43 | xyz1 = xyz1.contiguous().float().cuda() 44 | xyz2 = xyz2.contiguous().float().cuda() 45 | dist = torch.zeros(batchsize, n, device='cuda').contiguous() 46 | assignment = torch.zeros(batchsize, n, device='cuda', dtype=torch.int32).contiguous() - 1 47 | assignment_inv = torch.zeros(batchsize, m, device='cuda', dtype=torch.int32).contiguous() - 1 48 | price = torch.zeros(batchsize, m, device='cuda').contiguous() 49 | bid = torch.zeros(batchsize, n, device='cuda', dtype=torch.int32).contiguous() 50 | bid_increments = torch.zeros(batchsize, n, device='cuda').contiguous() 51 | max_increments = torch.zeros(batchsize, m, device='cuda').contiguous() 52 | unass_idx = torch.zeros(batchsize * n, device='cuda', dtype=torch.int32).contiguous() 53 | max_idx = torch.zeros(batchsize * m, device='cuda', dtype=torch.int32).contiguous() 54 | unass_cnt = torch.zeros(512, dtype=torch.int32, device='cuda').contiguous() 55 | unass_cnt_sum = torch.zeros(512, dtype=torch.int32, device='cuda').contiguous() 56 | cnt_tmp = torch.zeros(512, dtype=torch.int32, device='cuda').contiguous() 57 | 58 | emd.forward(xyz1, xyz2, dist, assignment, price, assignment_inv, bid, bid_increments, max_increments, unass_idx, unass_cnt, unass_cnt_sum, cnt_tmp, max_idx, eps, iters) 59 | 60 | ctx.save_for_backward(xyz1, xyz2, assignment) 61 | return dist, assignment 62 | 63 | @staticmethod 64 | def backward(ctx, graddist, gradidx): 65 | xyz1, xyz2, assignment = ctx.saved_tensors 66 | graddist = graddist.contiguous() 67 | 68 | gradxyz1 = torch.zeros(xyz1.size(), device='cuda').contiguous() 69 | gradxyz2 = torch.zeros(xyz2.size(), device='cuda').contiguous() 70 | 71 | emd.backward(xyz1, xyz2, gradxyz1, graddist, assignment) 72 | return gradxyz1, gradxyz2, None, None 73 | 74 | class emdModule(nn.Module): 75 | def __init__(self): 76 | super(emdModule, self).__init__() 77 | 78 | def forward(self, input1, input2, eps, iters): 79 | return emdFunction.apply(input1, input2, eps, iters) 80 | 81 | def test_emd(): 82 | x1 = torch.rand(20, 8192, 3).cuda() 83 | x2 = torch.rand(20, 8192, 3).cuda() 84 | emd = emdModule() 85 | start_time = time.perf_counter() 86 | dis, assigment = emd(x1, x2, 0.05, 3000) 87 | print("Input_size: ", x1.shape) 88 | print("Runtime: %lfs" % (time.perf_counter() - start_time)) 89 | print("EMD: %lf" % np.sqrt(dis.cpu()).mean()) 90 | print("|set(assignment)|: %d" % assigment.unique().numel()) 91 | assigment = assigment.cpu().numpy() 92 | assigment = np.expand_dims(assigment, -1) 93 | x2 = np.take_along_axis(x2, assigment, axis = 1) 94 | d = (x1 - x2) * (x1 - x2) 95 | print("Verified EMD: %lf" % np.sqrt(d.cpu().sum(-1)).mean()) 96 | 97 | #test_emd() 98 | 99 | 100 | -------------------------------------------------------------------------------- /utils/emd/run_compile.sh: -------------------------------------------------------------------------------- 1 | partition=ips_share 2 | job_name=compile 3 | gpus=1 4 | g=$((${gpus}<8?${gpus}:8)) 5 | 6 | 7 | srun -u --partition=${partition} --job-name=${job_name} \ 8 | -n1 --gres=gpu:${gpus} --ntasks-per-node=1 -w 'SH-IDC1-10-198-6-85' \ 9 | python3 emd_module.py 10 | -------------------------------------------------------------------------------- /utils/emd/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, 
CUDAExtension 3 | 4 | setup( 5 | name='emd', 6 | ext_modules=[ 7 | CUDAExtension('emd', [ 8 | 'emd.cpp', 9 | 'emd_cuda.cu', 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /utils/expansion_penalty/clean.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | rm -rf __pycache__/ 4 | rm -rf dist/ 5 | rm -rf build/ 6 | rm -rf expansion_penalty.egg-info/ 7 | rm -rf /mnt/lustre/chenxinyi1/.conda/envs/pt/lib/python3.7/site-packages/expansion_penalty-0.0.0-py3.7-linux-x86_64.egg/ 8 | -------------------------------------------------------------------------------- /utils/expansion_penalty/expansion_penalty.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <vector> 3 | 4 | int expansion_penalty_cuda_forward(at::Tensor xyz, int primitive_size, at::Tensor father, at::Tensor dist, double alpha, at::Tensor neighbor, at::Tensor cost, at::Tensor mean_mst_length); 5 | 6 | int expansion_penalty_cuda_backward(at::Tensor xyz, at::Tensor gradxyz, at::Tensor graddist, at::Tensor idx); 7 | 8 | int expansion_penalty_forward(at::Tensor xyz, int primitive_size, at::Tensor father, at::Tensor dist, double alpha, at::Tensor neighbor, at::Tensor cost, at::Tensor mean_mst_length) { 9 | return expansion_penalty_cuda_forward(xyz, primitive_size, father, dist, alpha, neighbor, cost, mean_mst_length); 10 | } 11 | 12 | int expansion_penalty_backward(at::Tensor xyz, at::Tensor gradxyz, at::Tensor graddist, at::Tensor idx) { 13 | 14 | return expansion_penalty_cuda_backward(xyz, gradxyz, graddist, idx); 15 | } 16 | 17 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 18 | m.def("forward", &expansion_penalty_forward, "expansion_penalty forward (CUDA)"); 19 | m.def("backward", &expansion_penalty_backward, "expansion_penalty backward (CUDA)"); 20 | } -------------------------------------------------------------------------------- /utils/expansion_penalty/expansion_penalty_cuda.cu: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <ATen/ATen.h> 3 | 4 | #include <cuda.h> 5 | #include <cuda_runtime.h> 6 | 7 | __global__ void calc_penalty(int b, int n, int primitive_size, const float * xyz, int * idx, float * dist, float alpha, int * neighbor, float * cost, float * mean_mst_length) { 8 | const int batch = 512; // primitive_size should be less than 512 9 | __shared__ float xyz_buf[batch * 3]; 10 | __shared__ bool vis[batch]; 11 | __shared__ float cur_dis[batch]; 12 | __shared__ int cur_idx[batch]; 13 | __shared__ float min_dis[batch]; 14 | __shared__ int min_idx[batch]; 15 | __shared__ float sum_dis[batch]; 16 | __shared__ int cnt[batch]; 17 | __shared__ int degree[batch]; 18 | 19 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 20 | vis[threadIdx.x] = false; 21 | cur_dis[threadIdx.x] = 1e9; 22 | cnt[threadIdx.x] = 0; 23 | degree[threadIdx.x] = 0; 24 | 25 | for (int j = threadIdx.x; j < primitive_size * 3; j += blockDim.x) { 26 | xyz_buf[j] = xyz[(i * n + blockIdx.y * primitive_size) * 3 + j]; 27 | } 28 | __syncthreads(); 29 | 30 | __shared__ int last; 31 | __shared__ float x_last; 32 | __shared__ float y_last; 33 | __shared__ float z_last; 34 | 35 | if (threadIdx.x == 0) { 36 | vis[0] = true; 37 | sum_dis[0] = 0; 38 | last = 0; 39 | x_last = xyz_buf[0]; 40 | y_last = xyz_buf[1]; 41 | z_last = xyz_buf[2]; 42 | } 43 | __syncthreads(); 44 | 45 | for (int j = 0; j < primitive_size - 1; j++) { 46 | if 
(vis[threadIdx.x] == false) { 47 | float delta_x = xyz_buf[threadIdx.x * 3 + 0] - x_last; 48 | float delta_y = xyz_buf[threadIdx.x * 3 + 1] - y_last; 49 | float delta_z = xyz_buf[threadIdx.x * 3 + 2] - z_last; 50 | float d = sqrtf(delta_x * delta_x + delta_y * delta_y + delta_z * delta_z); 51 | 52 | if (d < cur_dis[threadIdx.x]) { 53 | cur_dis[threadIdx.x] = d; 54 | cur_idx[threadIdx.x] = last; 55 | } 56 | min_dis[threadIdx.x] = cur_dis[threadIdx.x]; 57 | } 58 | else { 59 | min_dis[threadIdx.x] = 1e9; 60 | } 61 | min_idx[threadIdx.x] = threadIdx.x; 62 | __syncthreads(); 63 | 64 | int stride = 1; 65 | while(stride <= primitive_size / 2) { 66 | int index = (threadIdx.x + 1) * stride * 2 - 1; 67 | if(index < primitive_size && min_dis[index - stride] < min_dis[index]) { 68 | min_dis[index] = min_dis[index - stride]; 69 | min_idx[index] = min_idx[index - stride]; 70 | } 71 | stride = stride * 2; 72 | __syncthreads(); 73 | } 74 | __syncthreads(); 75 | 76 | if (threadIdx.x == primitive_size - 1) { 77 | last = min_idx[threadIdx.x]; 78 | int u = cur_idx[last]; 79 | vis[last] = true; 80 | x_last = xyz_buf[last * 3 + 0]; 81 | y_last = xyz_buf[last * 3 + 1]; 82 | z_last = xyz_buf[last * 3 + 2]; 83 | 84 | cnt[last] += 1; 85 | degree[last] += 1; 86 | neighbor[(i * n + blockIdx.y * primitive_size + last) * 512 + cnt[last]] = u; 87 | cost[(i * n + blockIdx.y * primitive_size + last) * 512 + cnt[last]] = cur_dis[last]; 88 | cnt[u] += 1; 89 | degree[u] += 1; 90 | neighbor[(i * n + blockIdx.y * primitive_size + u) * 512 + cnt[u]] = last; 91 | cost[(i * n + blockIdx.y * primitive_size + u) * 512 + cnt[u]] = cur_dis[last]; 92 | 93 | if (cnt[last] >= 512 || cnt[u] >= 512) { 94 | printf("MST Error: Too many neighbors! %d %d %d %d\n", cnt[last], cnt[u], last, u); 95 | } 96 | 97 | sum_dis[last] = cur_dis[last]; 98 | } 99 | __syncthreads(); 100 | } 101 | 102 | __syncthreads(); 103 | int stride = 1; 104 | while(stride <= primitive_size / 2) { 105 | int index = (threadIdx.x + 1) * stride * 2 - 1; 106 | if (index < primitive_size) 107 | sum_dis[index] += sum_dis[index - stride]; 108 | stride = stride * 2; 109 | __syncthreads(); 110 | } 111 | __syncthreads(); 112 | 113 | __shared__ float mean_dis; 114 | if (threadIdx.x == 0) { 115 | mean_dis = sum_dis[primitive_size - 1] / (primitive_size - 1); 116 | atomicAdd(&mean_mst_length[i], mean_dis); 117 | } 118 | 119 | dist[i * n + blockIdx.y * primitive_size + threadIdx.x] = 0; 120 | idx[i * n + blockIdx.y * primitive_size + threadIdx.x] = -1; 121 | __syncthreads(); 122 | 123 | while (true) { 124 | __shared__ int flag; 125 | flag = 0; 126 | int tmp = cnt[threadIdx.x]; 127 | __syncthreads(); 128 | if (tmp == 1) { 129 | atomicAdd(&flag, 1); 130 | for (int j = 1; j <= degree[threadIdx.x]; j++) { 131 | int u = neighbor[(i * n + blockIdx.y * primitive_size + threadIdx.x) * 512 + j]; 132 | if (cnt[u] > 1 || (cnt[u] == 1 && threadIdx.x > u)) { 133 | float c = cost[(i * n + blockIdx.y * primitive_size + threadIdx.x) * 512 + j]; 134 | atomicAdd(&cnt[threadIdx.x], -1); 135 | atomicAdd(&cnt[u], -1); 136 | if (c > mean_dis * alpha) { 137 | dist[i * n + blockIdx.y * primitive_size + threadIdx.x] = c; 138 | idx[i * n + blockIdx.y * primitive_size + threadIdx.x] = blockIdx.y * primitive_size + u; 139 | } 140 | } 141 | } 142 | } 143 | __syncthreads(); 144 | if (flag == 0) break; 145 | __syncthreads(); 146 | } 147 | __syncthreads(); 148 | } 149 | } 150 | 151 | int expansion_penalty_cuda_forward(at::Tensor xyz, int primitive_size, at::Tensor idx, at::Tensor dist, double alpha, at::Tensor 
neighbor, at::Tensor cost, at::Tensor mean_mst_length) { 152 | 153 | const auto batch_size = xyz.size(0); 154 | const auto n = xyz.size(1); 155 | 156 | // one block of primitive_size threads per surface element, one grid row per batch entry 157 | calc_penalty<<<dim3(batch_size, n / primitive_size, 1), primitive_size>>>(batch_size, n, primitive_size, xyz.data_ptr<float>(), idx.data_ptr<int>(), dist.data_ptr<float>(), alpha, neighbor.data_ptr<int>(), cost.data_ptr<float>(), mean_mst_length.data_ptr<float>()); 158 | 159 | cudaError_t err = cudaGetLastError(); 160 | if (err != cudaSuccess) { 161 | printf("error in nnd Output: %s\n", cudaGetErrorString(err)); 162 | return 0; 163 | } 164 | return 1; 165 | } 166 | 167 | __global__ void calc_grad(int b, int n, const float * xyz, const float * grad_dist, const int * idx, float * grad_xyz) { 168 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 169 | for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; j += blockDim.x * gridDim.y) 170 | if (idx[i * n + j] != -1) { 171 | float x1 = xyz[(i * n + j) * 3 + 0]; 172 | float y1 = xyz[(i * n + j) * 3 + 1]; 173 | float z1 = xyz[(i * n + j) * 3 + 2]; 174 | int j2 = idx[i * n + j]; 175 | float x2 = xyz[(i * n + j2) * 3 + 0]; 176 | float y2 = xyz[(i * n + j2) * 3 + 1]; 177 | float z2 = xyz[(i * n + j2) * 3 + 2]; 178 | float g = grad_dist[i * n + j] * 2; 179 | atomicAdd(&(grad_xyz[(i * n + j) * 3 + 0]), g * (x1 - x2)); 180 | atomicAdd(&(grad_xyz[(i * n + j) * 3 + 1]), g * (y1 - y2)); 181 | atomicAdd(&(grad_xyz[(i * n + j) * 3 + 2]), g * (z1 - z2)); 182 | } 183 | } 184 | } 185 | 186 | int expansion_penalty_cuda_backward(at::Tensor xyz, at::Tensor gradxyz, at::Tensor graddist, at::Tensor idx) { 187 | const auto batch_size = xyz.size(0); 188 | const auto n = xyz.size(1); 189 | 190 | calc_grad<<<dim3(batch_size, 8, 1), 512>>>(batch_size, n, xyz.data_ptr<float>(), graddist.data_ptr<float>(), idx.data_ptr<int>(), gradxyz.data_ptr<float>()); // the kernel's grid-stride loop covers all n points, so the exact launch shape is flexible 191 | 192 | cudaError_t err = cudaGetLastError(); 193 | if (err != cudaSuccess) { 194 | printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); 195 | return 0; 196 | } 197 | return 1; 198 | } 199 | -------------------------------------------------------------------------------- /utils/expansion_penalty/expansion_penalty_module.py: -------------------------------------------------------------------------------- 1 | # Expansion penalty module (based on minimum spanning tree) 2 | # author: Minghua Liu 3 | 4 | # Input: 5 | # xyz: [#batch, #point, 3] 6 | # primitive_size: int, the number of points sampled for each surface element, which should be no greater than 512 7 | # in each point cloud, the points from the same surface element should be consecutive 8 | # alpha: float, > 1, only penalize those edges whose length is greater than (alpha * mean_length) 9 | 10 | #Output: 11 | # dist: [#batch, #point], if the point u is penalized then dist[u] is the distance between u and its neighbor in the MST, otherwise dist[u] is 0 12 | # assignment: [#batch, #point], if the point u is penalized then assignment[u] is its neighbor in the MST, otherwise assignment[u] is -1 13 | # mean_mst_length: [#batch], the average length of the edges in each point cloud 14 | 15 | 16 | import time 17 | import numpy as np 18 | import torch 19 | from torch import nn 20 | from torch.autograd import Function 21 | import expansion_penalty 22 | 23 | # GPU tensors only 24 | class expansionPenaltyFunction(Function): 25 | @staticmethod 26 | def forward(ctx, xyz, primitive_size, alpha): 27 | assert(primitive_size <= 512) 28 | batchsize, n, _ = xyz.size() 29 | assert(n % primitive_size == 0) 30 | xyz = xyz.contiguous().float().cuda() 31 | dist = torch.zeros(batchsize, n, device='cuda').contiguous() 32 | assignment = torch.zeros(batchsize, n, device='cuda', 
dtype=torch.int32).contiguous() - 1 33 | neighbor = torch.zeros(batchsize, n * 512, device='cuda', dtype=torch.int32).contiguous() 34 | cost = torch.zeros(batchsize, n * 512, device='cuda').contiguous() 35 | mean_mst_length = torch.zeros(batchsize, device='cuda').contiguous() 36 | expansion_penalty.forward(xyz, primitive_size, assignment, dist, alpha, neighbor, cost, mean_mst_length) 37 | ctx.save_for_backward(xyz, assignment) 38 | return dist, assignment, mean_mst_length / (n / primitive_size) 39 | 40 | @staticmethod 41 | def backward(ctx, grad_dist, grad_idx, grad_mml): 42 | xyz, assignment = ctx.saved_tensors 43 | grad_dist = grad_dist.contiguous() 44 | grad_xyz = torch.zeros(xyz.size(), device='cuda').contiguous() 45 | expansion_penalty.backward(xyz, grad_xyz, grad_dist, assignment) 46 | return grad_xyz, None, None 47 | 48 | class expansionPenaltyModule(nn.Module): 49 | def __init__(self): 50 | super(expansionPenaltyModule, self).__init__() 51 | 52 | def forward(self, input, primitive_size, alpha): 53 | return expansionPenaltyFunction.apply(input, primitive_size, alpha) 54 | 55 | def test_expansion_penalty(): 56 | x = torch.rand(20, 8192, 3).cuda() 57 | print("Input_size: ", x.shape) 58 | expansion = expansionPenaltyModule() 59 | start_time = time.perf_counter() 60 | dis, ass, mean_length = expansion(x, 512, 1.5) 61 | print("Runtime: %lfs" % (time.perf_counter() - start_time)) 62 | 63 | #test_expansion_penalty() 64 | -------------------------------------------------------------------------------- /utils/expansion_penalty/run_compile.sh: -------------------------------------------------------------------------------- 1 | partition=ips_share 2 | job_name=compile 3 | gpus=1 4 | g=$((${gpus}<8?${gpus}:8)) 5 | 6 | 7 | srun -u --partition=${partition} --job-name=${job_name} \ 8 | -n1 --gres=gpu:${gpus} --ntasks-per-node=1 \ 9 | python3 setup.py install 10 | -------------------------------------------------------------------------------- /utils/expansion_penalty/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='expansion_penalty', 6 | ext_modules=[ 7 | CUDAExtension('expansion_penalty', [ 8 | 'expansion_penalty.cpp', 9 | 'expansion_penalty_cuda.cu', 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /utils/generate_excel_results.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pandas import ExcelWriter 3 | import os 4 | import numpy as np 5 | import sys 6 | 7 | log_base_path = sys.argv[1] 8 | main_categories = ['airplane', 'cabinet', 'car', 'chair', 'lamp', 'sofa', 'table', 'vessel', 'main_mean'] 9 | novel_categories = ['bed', 'bench', 'bookshelf', 'bus', 'guitar', 'motorbike', 'pistol', 'skateboard', 'novel_mean'] 10 | model_names = ['pcn', 'topnet', 'msn', 'cascade'] 11 | train_modes = ['cd', 'emd'] 12 | loss_cols = ['emd', 'cd_p', 'cd_p_f1', 'cd_t', 'cd_t_f1'] 13 | sheet_names = ['cd_train_main_category', 'cd_train_novel_category', 'cd_train_overview','emd_train_main_category', 14 | 'emd_train_novel_category', 'emd_train_overview', ] 15 | 16 | 17 | def save_xls(list_dfs, xls_path): 18 | assert len(list_dfs) == len(sheet_names) 19 | with ExcelWriter(xls_path, engine='xlsxwriter') as writer: 20 | for n, df in enumerate(list_dfs): 21 | df.to_excel(writer, 
sheet_names[n]) 22 | if n != 2 and n != 5: 23 | writer.sheets[sheet_names[n]].set_row(2, None, None, {'hidden': True}) 24 | writer.save() 25 | 26 | 27 | def generate_cat_results_row(best_emd_cat, best_cd_p_cat, best_cd_t_cat): 28 | main_cat_r = [] 29 | novel_cat_r = [] 30 | emd = [float(line.split(' ')[5]) for line in best_emd_cat] 31 | cd_p = [float(line.split(' ')[1][:-1]) for line in best_cd_p_cat] 32 | cd_p_f1 = [float(line.split(' ')[-1]) for line in best_cd_p_cat] 33 | cd_t = [float(line.split(' ')[3][:-1]) for line in best_cd_t_cat] 34 | cd_t_f1 = [float(line.split(' ')[-1]) for line in best_cd_t_cat] 35 | for i in range(8): 36 | main_cat_r.extend([emd[i], cd_p[i], cd_p_f1[i], cd_t[i], cd_t_f1[i]]) 37 | novel_cat_r.extend([emd[i+8], cd_p[i+8], cd_p_f1[i+8], cd_t[i+8], cd_t_f1[i+8]]) 38 | main_cat_r.extend([np.mean(emd[:8]), np.mean(cd_p[:8]), np.mean(cd_p_f1[:8]), np.mean(cd_t[:8]), np.mean(cd_t_f1[:8])]) 39 | novel_cat_r.extend([np.mean(emd[8:]), np.mean(cd_p[8:]), np.mean(cd_p_f1[8:]), np.mean(cd_t[8:]), np.mean(cd_t_f1[8:])]) 40 | return main_cat_r, novel_cat_r 41 | 42 | 43 | def generate_overview_row(best_emd_overview, best_cd_p_overview, best_cd_t_overview): 44 | best_emd = float(best_emd_overview.split(' ')[5]) 45 | best_cd_p = float(best_cd_p_overview.split(' ')[1][:-1]) 46 | best_cd_p_f1 = float(best_cd_p_overview.split(' ')[-1]) 47 | best_cd_t = float(best_cd_t_overview.split(' ')[3][:-1]) 48 | best_cd_t_f1 = float(best_cd_t_overview.split(' ')[-1]) 49 | return [best_emd*(10**4), best_cd_p*(10**4), best_cd_p_f1, best_cd_t*(10**4), best_cd_t_f1] 50 | 51 | 52 | sheets = [] 53 | for mode in train_modes: 54 | main_cat_col = pd.MultiIndex.from_product([main_categories, loss_cols]) 55 | main_cat_df = pd.DataFrame(columns=main_cat_col, index=model_names) 56 | novel_cat_col = pd.MultiIndex.from_product([novel_categories, loss_cols]) 57 | novel_cat_df = pd.DataFrame(columns=novel_cat_col, index=model_names) 58 | overview_df = pd.DataFrame(columns=loss_cols, index=model_names) 59 | 60 | for model in model_names: 61 | log_file = os.path.join(log_base_path, model + '_' + mode, 'log_test.txt') 62 | with open(log_file) as f: 63 | content = f.readlines() 64 | main_cat_row, novel_cat_row = generate_cat_results_row(content[16:32], content[50:66], content[84:100]) 65 | overview_row = generate_overview_row(content[33], content[67], content[101]) 66 | main_cat_df.loc[model] = main_cat_row 67 | novel_cat_df.loc[model] = novel_cat_row 68 | overview_df.loc[model] = overview_row 69 | 70 | sheets.append(main_cat_df) 71 | sheets.append(novel_cat_df) 72 | sheets.append(overview_df) 73 | 74 | save_xls(sheets, os.path.join(log_base_path, 'benchmark_results.xlsx')) 75 | 76 | 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /utils/vis_pcd.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import open3d as o3d 3 | import sys 4 | import numpy as np 5 | from mpl_toolkits.mplot3d import Axes3D 6 | import matplotlib.pyplot as plt 7 | from matplotlib.backends.backend_pdf import PdfPages 8 | from PIL import Image 9 | import io 10 | import os 11 | 12 | 13 | def get_pts(pcd): 14 | points = np.asarray(pcd.points) 15 | X = [] 16 | Y = [] 17 | Z = [] 18 | for pt in range(points.shape[0]): 19 | X.append(points[pt][0]) 20 | Y.append(points[pt][1]) 21 | Z.append(points[pt][2]) 22 | 23 | return np.asarray(X), np.asarray(Y), np.asarray(Z) 24 | 25 | 26 | def set_axes_equal(ax): 27 | '''Make axes of 3D plot have equal 
scale so that spheres appear as spheres, 28 | cubes as cubes, etc. This is one possible solution to Matplotlib's 29 | ax.set_aspect('equal') and ax.axis('equal') not working for 3D. 30 | 31 | Input 32 | ax: a matplotlib axis, e.g., as output from plt.gca(). 33 | ''' 34 | 35 | x_limits = ax.get_xlim3d() 36 | y_limits = ax.get_ylim3d() 37 | z_limits = ax.get_zlim3d() 38 | 39 | x_range = abs(x_limits[1] - x_limits[0]) 40 | x_middle = np.mean(x_limits) 41 | y_range = abs(y_limits[1] - y_limits[0]) 42 | y_middle = np.mean(y_limits) 43 | z_range = abs(z_limits[1] - z_limits[0]) 44 | z_middle = np.mean(z_limits) 45 | 46 | # The plot bounding box is a sphere in the sense of the infinity 47 | # norm, hence I call half the max range the plot radius. 48 | plot_radius = 0.5*max([x_range, y_range, z_range]) 49 | 50 | ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius]) 51 | ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius]) 52 | ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius]) 53 | 54 | 55 | def plot_pcd(pcd_list, object_id, view_id): 56 | num_pcds = len(pcd_list) 57 | fig = plt.figure(figsize=(60, 60)) 58 | ax = fig.add_subplot(111, projection='3d') 59 | ax.set_aspect('equal') 60 | for ind, pcd_original in enumerate(pcd_list): 61 | pcd = o3d.geometry.PointCloud(pcd_original) 62 | translation_matrix = np.asarray( 63 | [[1, 0, 0, ind - int(num_pcds / 2 - 0.5)], 64 | [0, 1, 0, 0.275 * ind - (num_pcds / 2 - 0.5)], 65 | [0, 0, 1, 0], 66 | [0, 0, 0, 1]]) 67 | rotation_matrix = np.asarray( 68 | [[1, 0, 0, 0], 69 | [0, 0, -1, 0], 70 | [0, 1, 0, 0], 71 | [0, 0, 0, 1]]) 72 | transform_matrix = rotation_matrix @ translation_matrix 73 | pcd = pcd.transform(transform_matrix) 74 | X, Y, Z = get_pts(pcd) 75 | t = Z 76 | ax.scatter(X, Y, Z, c=t, cmap='jet', marker='o', s=1) 77 | ax.grid(False) 78 | ax.w_xaxis.set_pane_color((1.0, 1.0, 1.0, 1.0)) 79 | ax.w_yaxis.set_pane_color((1.0, 1.0, 1.0, 1.0)) 80 | ax.w_zaxis.set_pane_color((1.0, 1.0, 1.0, 1.0)) 81 | 82 | set_axes_equal(ax) 83 | if view_id != 0: 84 | plt.title(" Obj: %d Scan: %d" % (object_id, view_id), y=0.62, fontsize=20) 85 | else: 86 | gt_title = " GT" 87 | input_title = " Input" 88 | pcn_title = " pcn" 89 | topnet_title = " TopNet" 90 | msn_title = " MSN" 91 | cascade_title = " Cascaded" 92 | title = gt_title + input_title + pcn_title + topnet_title + msn_title + cascade_title 93 | plt.title(title + "\n Obj: %d Scan: %d" % (object_id, view_id), y=0.62, fontsize=20) 94 | 95 | plt.axis('off') 96 | bbox = fig.bbox_inches.from_bounds(19, 27, 25, 7) 97 | buf = io.BytesIO() 98 | plt.savefig(buf, format='png', bbox_inches=bbox) 99 | buf.seek(0) 100 | im = Image.open(buf) 101 | plt.close() 102 | return im 103 | 104 | 105 | if __name__ == '__main__': 106 | model_type = sys.argv[1] # cd or emd 107 | best_type = sys.argv[2] # cd_t or cd_p or emd 108 | results_prefix = sys.argv[3] 109 | 110 | # outputs by 4 models (pcn, Topnet, MSN, Cascade) 111 | output_file_pcn = h5py.File(os.path.join(results_prefix, 'pcn_%s/best_%s_network_pcds.h5' % (model_type, best_type)), 'r') 112 | pcn = output_file_pcn['output_pcds'][()] 113 | output_file_topnet = h5py.File(os.path.join(results_prefix,'topnet_%s/best_%s_network_pcds.h5' % (model_type, best_type)), 'r') 114 | topnet = output_file_topnet['output_pcds'][()] 115 | output_file_msn = h5py.File(os.path.join(results_prefix,'msn_%s/best_%s_network_pcds.h5' % (model_type, best_type)), 'r') 116 | msn = output_file_msn['output_pcds'][()] 117 | output_file_cascade = 
h5py.File(os.path.join(results_prefix,'cascade_%s/best_%s_network_pcds.h5' % (model_type, best_type)), 'r') 118 | cascade = output_file_cascade['output_pcds'][()] 119 | # gt 120 | gt_file = h5py.File('/mnt/lustre/chenxinyi1/pl/data_generation/my_dataset1/my_test_gt_data_2048_1.h5', 'r') 121 | novel_gt = gt_file['novel_complete_pcds'][()] 122 | gt = gt_file['complete_pcds'][()] 123 | gt = np.concatenate((gt, novel_gt), axis=0) 124 | # input 125 | input_file = h5py.File('/mnt/lustre/chenxinyi1/pl/data_generation/my_dataset1/my_test_input_data_denoised_1.h5', 'r') 126 | novel_input = input_file['novel_incomplete_pcds'][()] 127 | inputs = input_file['incomplete_pcds'][()] 128 | inputs = np.concatenate((inputs, novel_input), axis=0) 129 | 130 | to_plot = [7, 10, 15, 17, 55, 95, 123, 132, 133, 136, 131 | 183, 191, 192, 243, 249, 253, 254, 261, 266, 269, 132 | 303, 311, 323, 357, 367, 400, 405, 419, 434, 449, 133 | 459, 483, 561, 596, 134 | 601, 605, 612, 614, 630, 638, 646, 652, 668, 673, 135 | 787, 793, 799, 807, 823, 829, 851, 864, 879, 895, 136 | 902, 911, 913, 926, 930, 935, 960, 993, 1007, 1011, 137 | 1055, 1063, 1069, 1072, 1075, 1076, 1082, 1104, 1108, 1130, 138 | 1204, 1227, 1248, 139 | 1258, 1271, 1276, 1277, 1282, 140 | 1329, 1349, 141 | 1364, 1377, 1382, 1388, 142 | 1414, 1416, 1420, 143 | 1452, 1454, 1456, 1478, 1479, 144 | 1500, 1508, 1510, 1514, 1515, 1521, 1529, 1549, 145 | 1565, 1574] 146 | page_list = [] 147 | for ind, i in enumerate(to_plot): 148 | print('%d/%d' % (ind, len(to_plot))) 149 | gt_pcd = o3d.geometry.PointCloud() 150 | gt_pcd.points = o3d.utility.Vector3dVector(gt[i]) 151 | width, height = 2500, 700 152 | concat_im = Image.new('RGB', (width, 26 * height)) 153 | for j in range(0, 26): 154 | input_pcd = o3d.geometry.PointCloud() 155 | input_pcd.points = o3d.utility.Vector3dVector(inputs[i * 26 + j]) 156 | pcn_pcd = o3d.geometry.PointCloud() 157 | pcn_pcd.points = o3d.utility.Vector3dVector(pcn[i * 26 + j]) 158 | topnet_pcd = o3d.geometry.PointCloud() 159 | topnet_pcd.points = o3d.utility.Vector3dVector(topnet[i * 26 + j]) 160 | msn_pcd = o3d.geometry.PointCloud() 161 | msn_pcd.points = o3d.utility.Vector3dVector(msn[i * 26 + j]) 162 | cascade_pcd = o3d.geometry.PointCloud() 163 | cascade_pcd.points = o3d.utility.Vector3dVector(cascade[i * 26 + j]) 164 | pcds = plot_pcd([gt_pcd, input_pcd, pcn_pcd, topnet_pcd, msn_pcd, cascade_pcd], i, j) 165 | concat_im.paste(pcds, (0, j * height)) 166 | page_list.append(concat_im) 167 | first = page_list.pop(0) 168 | first.save(os.path.join(results_prefix, '%s_train_best_%s.pdf' % (model_type, best_type)), 169 | "PDF", resolution=100.0, optimize=True, save_all=True, append_images=page_list) 170 | 171 | 172 | 173 | 174 | 175 | 176 | --------------------------------------------------------------------------------
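A minimal usage sketch for the EMD and expansion-penalty modules included above, assuming the extensions have been installed via their respective `setup.py` and that the repository root is on `PYTHONPATH`; the `eps`/`iters` values and the 0.1 penalty weight are illustrative assumptions, not settings taken from this codebase:

```python
import torch
from utils.emd.emd_module import emdModule
from utils.expansion_penalty.expansion_penalty_module import expansionPenaltyModule

emd = emdModule()
expansion = expansionPenaltyModule()

# Both modules expect [batch, #points, 3] CUDA tensors normalized to [0, 1];
# #points must be a multiple of 1024 for EMD and of primitive_size for the penalty.
pred = torch.rand(8, 2048, 3).cuda()  # e.g. a predicted completion
gt = torch.rand(8, 2048, 3).cuda()    # the matching ground truth

# dist holds squared match distances, so take sqrt before averaging (see emd_module.py).
dist, _ = emd(pred, gt, 0.005, 50)        # eps=0.005, iters=50 are example values
emd_loss = torch.sqrt(dist).mean()

# Treat pred as consecutive surface elements of 512 points each and penalize long MST edges.
exp_dist, _, _ = expansion(pred, 512, 1.5)
expansion_loss = exp_dist.mean()

loss = emd_loss + 0.1 * expansion_loss    # 0.1 is an arbitrary example weight
```

Taking the square root of the EMD output before averaging turns the per-point squared match distances into an L2 distance, while averaging the penalized MST edge lengths discourages each surface element from expanding far beyond its neighbors; gradients flow only to `pred`, since the EMD backward pass only computes gradients for its first argument.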