├── .gitignore ├── README.md ├── config └── s3dis │ └── s3dis_pointtransformer_repro.yaml ├── data └── s3dis │ └── s3dis_names.txt ├── env_setup.sh ├── lib ├── __init__.py └── pointops │ ├── __init__.py │ ├── functions │ ├── __init__.py │ └── pointops.py │ ├── setup.py │ └── src │ ├── __init__.py │ ├── aggregation │ ├── aggregation_cuda.cpp │ ├── aggregation_cuda_kernel.cu │ └── aggregation_cuda_kernel.h │ ├── cuda_utils.h │ ├── grouping │ ├── grouping_cuda.cpp │ ├── grouping_cuda_kernel.cu │ └── grouping_cuda_kernel.h │ ├── interpolation │ ├── interpolation_cuda.cpp │ ├── interpolation_cuda_kernel.cu │ └── interpolation_cuda_kernel.h │ ├── knnquery │ ├── knnquery_cuda.cpp │ ├── knnquery_cuda_kernel.cu │ └── knnquery_cuda_kernel.h │ ├── pointops_api.cpp │ ├── sampling │ ├── sampling_cuda.cpp │ ├── sampling_cuda_kernel.cu │ └── sampling_cuda_kernel.h │ └── subtraction │ ├── subtraction_cuda.cpp │ ├── subtraction_cuda_kernel.cu │ └── subtraction_cuda_kernel.h ├── model ├── __init__.py └── pointtransformer │ └── pointtransformer_seg.py ├── tool ├── test.py ├── test.sh ├── train.py └── train.sh └── util ├── __init__.py ├── common_util.py ├── config.py ├── data_util.py ├── s3dis.py ├── transform.py └── voxelize.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pth* 3 | .autoenv* 4 | runs 5 | build 6 | checkpoints 7 | *.prof 8 | .lvimrc 9 | .vimtags 10 | .ccls 11 | .ccls-cache/ 12 | dist/ 13 | point_transformer.egg-info/ 14 | *.zip 15 | *.so 16 | .tox/ 17 | .mypy_cache 18 | **/*.pyc 19 | point_transformer/data/modelnet40_normal_resampled/ 20 | point_transformer/data/modelnet40_normal_resampled_cache/ 21 | point_transformer/data/modelnet40_ply_hdf5_2048/ 22 | 23 | .ipynb_checkpoints/ 24 | point_transformer/models/test.py 25 | outputs/ 26 | .vscode 27 | .DS_Store -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Point Transformer 2 | This repository reproduces [Point Transformer](https://arxiv.org/abs/2012.09164). \ 3 | The codebase is provided by the first author of [Point Transformer](https://arxiv.org/abs/2012.09164). 4 | 5 | ## Notes 6 | - For shape classification and part segmentation, please use the paconv-codebase branch. After some testing, we will merge it into the master branch. 7 | 8 | --- 9 | ## Dependencies 10 | - Ubuntu: 18.04 or higher 11 | - PyTorch: 1.9.0 12 | - CUDA: 11.1 13 | - Hardware: 4 GPUs (TITAN RTX) to reproduce [Point Transformer](https://arxiv.org/abs/2012.09164) 14 | - To create the conda environment, run the following command: 15 | 16 | ``` 17 | bash env_setup.sh pt 18 | ``` 19 | 20 | ## Dataset preparation 21 | - Download the S3DIS [dataset](https://drive.google.com/uc?export=download&id=1KUxWagmEWnvMhEb4FRwq2Mj0aa3U3xUf) and symlink its path as follows: 22 | 23 | ``` 24 | mkdir -p dataset 25 | ln -s /path_to_s3dis_dataset dataset/s3dis 26 | ``` 27 | 28 | ## Usage 29 | - Shape classification on ModelNet40 30 | - For now, please use the paconv-codebase branch. 31 | - Part segmentation on ShapeNetPart 32 | - For now, please use the paconv-codebase branch. 
33 | - Semantic segmentation on S3DIS Area 5 34 | - Train 35 | 36 | - Specify the GPUs to use in the config, then start training: 37 | 38 | ``` 39 | sh tool/train.sh s3dis pointtransformer_repro 40 | ``` 41 | 42 | - Test 43 | 44 | - After training, you can test the checkpoint as follows: 45 | 46 | ``` 47 | CUDA_VISIBLE_DEVICES=0 sh tool/test.sh s3dis pointtransformer_repro 48 | ``` 49 | --- 50 | ## Experimental Results 51 | 52 | - Semantic Segmentation on S3DIS Area 5 53 | 54 | |Model | mAcc | OA | mIoU | 55 | |-------| ------| ----| -------| 56 | |Paper| 76.5 | 90.8 | 70.4 | 57 | |Hengshuang's code | 76.8 | 90.4 | 70.0 | 58 | --- 59 | ## References 60 | 61 | If you use this code, please cite [Point Transformer](https://arxiv.org/abs/2012.09164): 62 | ``` 63 | @inproceedings{zhao2021point, 64 | title={Point transformer}, 65 | author={Zhao, Hengshuang and Jiang, Li and Jia, Jiaya and Torr, Philip HS and Koltun, Vladlen}, 66 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 67 | pages={16259--16268}, 68 | year={2021} 69 | } 70 | ``` 71 | 72 | ## Acknowledgement 73 | The code is from the first author of [Point Transformer](https://arxiv.org/abs/2012.09164). 74 | We also refer to the [PAConv repository](https://github.com/CVMI-Lab/PAConv). 
75 | -------------------------------------------------------------------------------- /config/s3dis/s3dis_pointtransformer_repro.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | data_name: s3dis 3 | data_root: dataset/s3dis/trainval_fullarea 4 | test_area: 5 5 | classes: 13 6 | fea_dim: 6 7 | voxel_size: 0.04 8 | voxel_max: 80000 9 | loop: 30 10 | 11 | TRAIN: 12 | arch: pointtransformer_seg_repro 13 | use_xyz: True 14 | sync_bn: False 15 | ignore_label: 255 16 | train_gpu: [0, 1, 2, 3] 17 | workers: 16 # data loader workers 18 | batch_size: 16 # batch size for training 19 | batch_size_val: 4 # batch size for validation during training, memory and speed tradeoff 20 | base_lr: 0.5 21 | epochs: 100 22 | start_epoch: 0 23 | step_epoch: 30 24 | multiplier: 0.1 25 | momentum: 0.9 26 | weight_decay: 0.0001 27 | drop_rate: 0.5 28 | manual_seed: 7777 29 | print_freq: 1 30 | save_freq: 1 31 | save_path: 32 | weight: # path to initial weight (default: none) 33 | resume: # path to latest checkpoint (default: none) 34 | evaluate: True # evaluate on validation set; extra GPU memory is needed, so a small batch_size_val is recommended 35 | eval_freq: 1 36 | Distributed: 37 | dist_url: tcp://localhost:8888 38 | dist_backend: 'nccl' 39 | multiprocessing_distributed: True 40 | world_size: 1 41 | rank: 0 42 | 43 | TEST: 44 | test_list: dataset/s3dis/list/val5.txt 45 | test_list_full: dataset/s3dis/list/val5_full.txt 46 | split: val # split in [train, val and test] 47 | test_gpu: [0] 48 | test_workers: 4 49 | batch_size_test: 4 50 | model_path: 51 | save_folder: 52 | names_path: data/s3dis/s3dis_names.txt 53 | -------------------------------------------------------------------------------- /data/s3dis/s3dis_names.txt: -------------------------------------------------------------------------------- 1 | ceiling 2 | floor 3 | wall 4 | beam 5 | column 6 | window 7 | door 8 | chair 9 | table 10 | bookcase 11 | sofa 12 | board 13 | clutter 14 | 
-------------------------------------------------------------------------------- /env_setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | ENVS=$(conda env list | awk '{print $1}') 4 | 5 | if [[ $ENVS = *"$1"* ]]; then 6 | echo "[PT INFO] \"$1\" already exists. Pass the installation" 7 | else 8 | echo "[PT INFO] Creating $1..." 9 | conda create -n $1 python=3.7 -y 10 | conda activate "$1" 11 | echo "[PT INFO] Done !" 12 | 13 | echo "[PT INFO] Dependecies..." 14 | conda install pytorch=1.9.0 torchvision cudatoolkit=11.1 -c pytorch -c nvidia -y 15 | conda install -c anaconda h5py pyyaml -y 16 | conda install -c conda-forge sharedarray tensorboardx -y 17 | echo "[PT INFO] Done !" 18 | 19 | echo "[PT INFO] Installing cuda operations..." 20 | cd lib/pointops 21 | python3 setup.py install 22 | cd ../.. 23 | echo "[PT INFO] Done !" 24 | 25 | NVCC="$(nvcc --version)" 26 | TORCH="$(python -c "import torch; print(torch.__version__)")" 27 | 28 | echo "[PT INFO] Finished the installation!" 
29 | echo "[PT INFO] ========== Configurations ==========" 30 | echo "$NVCC" 31 | echo "[PT INFO] PyTorch version: $TORCH" 32 | echo "[PT INFO] ====================================" 33 | 34 | fi; -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/point-transformer/10d43ab5210fc93ffa15886f2a4c6460cc308780/lib/__init__.py -------------------------------------------------------------------------------- /lib/pointops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/point-transformer/10d43ab5210fc93ffa15886f2a4c6460cc308780/lib/pointops/__init__.py -------------------------------------------------------------------------------- /lib/pointops/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/point-transformer/10d43ab5210fc93ffa15886f2a4c6460cc308780/lib/pointops/functions/__init__.py -------------------------------------------------------------------------------- /lib/pointops/functions/pointops.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | from torch.autograd import Function 5 | import torch.nn as nn 6 | 7 | import pointops_cuda 8 | 9 | 10 | class FurthestSampling(Function): 11 | @staticmethod 12 | def forward(ctx, xyz, offset, new_offset): 13 | """ 14 | input: xyz: (n, 3), offset: (b), new_offset: (b) 15 | output: idx: (m) 16 | """ 17 | assert xyz.is_contiguous() 18 | n, b, n_max = xyz.shape[0], offset.shape[0], offset[0] 19 | for i in range(1, b): 20 | n_max = max(offset[i] - offset[i-1], n_max) 21 | idx = torch.cuda.IntTensor(new_offset[b-1].item()).zero_() 22 | tmp = torch.cuda.FloatTensor(n).fill_(1e10) 23 | 
pointops_cuda.furthestsampling_cuda(b, n_max, xyz, offset, new_offset, tmp, idx) 24 | del tmp 25 | return idx 26 | 27 | furthestsampling = FurthestSampling.apply 28 | 29 | 30 | class KNNQuery(Function): 31 | @staticmethod 32 | def forward(ctx, nsample, xyz, new_xyz, offset, new_offset): 33 | """ 34 | input: xyz: (n, 3), new_xyz: (m, 3), offset: (b), new_offset: (b) 35 | output: idx: (m, nsample), dist2: (m, nsample) 36 | """ 37 | if new_xyz is None: new_xyz = xyz 38 | assert xyz.is_contiguous() and new_xyz.is_contiguous() 39 | m = new_xyz.shape[0] 40 | idx = torch.cuda.IntTensor(m, nsample).zero_() 41 | dist2 = torch.cuda.FloatTensor(m, nsample).zero_() 42 | pointops_cuda.knnquery_cuda(m, nsample, xyz, new_xyz, offset, new_offset, idx, dist2) 43 | return idx, torch.sqrt(dist2) 44 | 45 | knnquery = KNNQuery.apply 46 | 47 | 48 | class Grouping(Function): 49 | @staticmethod 50 | def forward(ctx, input, idx): 51 | """ 52 | input: input: (n, c), idx : (m, nsample) 53 | output: (m, nsample, c) 54 | """ 55 | assert input.is_contiguous() and idx.is_contiguous() 56 | m, nsample, n, c = idx.shape[0], idx.shape[1], input.shape[0], input.shape[1] 57 | output = torch.cuda.FloatTensor(m, nsample, c) 58 | pointops_cuda.grouping_forward_cuda(m, nsample, c, input, idx, output) 59 | ctx.n = n 60 | ctx.save_for_backward(idx) 61 | return output 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | """ 66 | input: grad_out: (m, c, nsample) 67 | output: (n, c), None 68 | """ 69 | n = ctx.n 70 | idx, = ctx.saved_tensors 71 | m, nsample, c = grad_output.shape 72 | grad_input = torch.cuda.FloatTensor(n, c).zero_() 73 | pointops_cuda.grouping_backward_cuda(m, nsample, c, grad_output, idx, grad_input) 74 | return grad_input, None 75 | 76 | grouping = Grouping.apply 77 | 78 | 79 | def queryandgroup(nsample, xyz, new_xyz, feat, idx, offset, new_offset, use_xyz=True): 80 | """ 81 | input: xyz: (n, 3), new_xyz: (m, 3), feat: (n, c), idx: (m, nsample), offset: (b), new_offset: (b) 82 
| output: new_feat: (m, c+3, nsample), grouped_idx: (m, nsample) 83 | """ 84 | assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous() 85 | if new_xyz is None: 86 | new_xyz = xyz 87 | if idx is None: 88 | idx, _ = knnquery(nsample, xyz, new_xyz, offset, new_offset) # (m, nsample) 89 | 90 | n, m, c = xyz.shape[0], new_xyz.shape[0], feat.shape[1] 91 | grouped_xyz = xyz[idx.view(-1).long(), :].view(m, nsample, 3) # (m, nsample, 3) 92 | #grouped_xyz = grouping(xyz, idx) # (m, nsample, 3) 93 | grouped_xyz -= new_xyz.unsqueeze(1) # (m, nsample, 3) 94 | grouped_feat = feat[idx.view(-1).long(), :].view(m, nsample, c) # (m, nsample, c) 95 | #grouped_feat = grouping(feat, idx) # (m, nsample, c) 96 | 97 | if use_xyz: 98 | return torch.cat((grouped_xyz, grouped_feat), -1) # (m, nsample, 3+c) 99 | else: 100 | return grouped_feat 101 | 102 | 103 | class Subtraction(Function): 104 | @staticmethod 105 | def forward(ctx, input1, input2, idx): 106 | """ 107 | input: input1: (n, c), input2: (n, c), idx: (n, nsample) 108 | output: (n, nsample, c) 109 | """ 110 | assert input1.is_contiguous() and input2.is_contiguous() 111 | n, c = input1.shape; nsample = idx.shape[-1] 112 | output = torch.cuda.FloatTensor(n, nsample, c).zero_() 113 | pointops_cuda.subtraction_forward_cuda(n, nsample, c, input1, input2, idx, output) 114 | ctx.save_for_backward(idx) 115 | return output 116 | 117 | @staticmethod 118 | def backward(ctx, grad_output): 119 | """ 120 | input: grad_out: (n, nsample, c) 121 | output: grad_input1: (n, c), grad_input2: (n, c) 122 | """ 123 | idx, = ctx.saved_tensors 124 | n, nsample, c = grad_output.shape 125 | grad_input1 = torch.cuda.FloatTensor(n, c).zero_() 126 | grad_input2 = torch.cuda.FloatTensor(n, c).zero_() 127 | pointops_cuda.subtraction_backward_cuda(n, nsample, c, idx, grad_output, grad_input1, grad_input2) 128 | return grad_input1, grad_input2, None 129 | 130 | subtraction = Subtraction.apply 131 | 132 | 133 | class Aggregation(Function): 
134 | @staticmethod 135 | def forward(ctx, input, position, weight, idx): 136 | """ 137 | input: input: (n, c), position: (n, nsample, c), weight : (n, nsample, c'), idx: (n, nsample) 138 | output: (n, c) 139 | """ 140 | assert input.is_contiguous() and position.is_contiguous() and weight.is_contiguous() 141 | n, nsample, c = position.shape; w_c = weight.shape[-1] 142 | output = torch.cuda.FloatTensor(n, c).zero_() 143 | pointops_cuda.aggregation_forward_cuda(n, nsample, c, w_c, input, position, weight, idx, output) 144 | ctx.save_for_backward(input, position, weight, idx) 145 | return output 146 | 147 | @staticmethod 148 | def backward(ctx, grad_output): 149 | """ 150 | input: grad_out: (n, c) 151 | output: grad_input: (n, c), grad_position: (n, nsample, c), grad_weight : (n, nsample, c') 152 | """ 153 | input, position, weight, idx = ctx.saved_tensors 154 | n, nsample, c = position.shape; w_c = weight.shape[-1] 155 | grad_input = torch.cuda.FloatTensor(n, c).zero_() 156 | grad_position = torch.cuda.FloatTensor(n, nsample, c).zero_() 157 | grad_weight = torch.cuda.FloatTensor(n, nsample, w_c).zero_() 158 | pointops_cuda.aggregation_backward_cuda(n, nsample, c, w_c, input, position, weight, idx, grad_output, grad_input, grad_position, grad_weight) 159 | return grad_input, grad_position, grad_weight, None 160 | 161 | aggregation = Aggregation.apply 162 | 163 | 164 | def interpolation(xyz, new_xyz, feat, offset, new_offset, k=3): 165 | """ 166 | input: xyz: (m, 3), new_xyz: (n, 3), feat: (m, c), offset: (b), new_offset: (b) 167 | output: (n, c) 168 | """ 169 | assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous() 170 | idx, dist = knnquery(k, xyz, new_xyz, offset, new_offset) # (n, 3), (n, 3) 171 | dist_recip = 1.0 / (dist + 1e-8) # (n, 3) 172 | norm = torch.sum(dist_recip, dim=1, keepdim=True) 173 | weight = dist_recip / norm # (n, 3) 174 | 175 | new_feat = torch.cuda.FloatTensor(new_xyz.shape[0], feat.shape[1]).zero_() 176 | for i in 
range(k): 177 | new_feat += feat[idx[:, i].long(), :] * weight[:, i].unsqueeze(-1) 178 | return new_feat 179 | 180 | 181 | class Interpolation(Function): 182 | @staticmethod 183 | def forward(ctx, xyz, new_xyz, input, offset, new_offset, k=3): 184 | """ 185 | input: xyz: (m, 3), new_xyz: (n, 3), input: (m, c), offset: (b), new_offset: (b) 186 | output: (n, c) 187 | """ 188 | assert xyz.is_contiguous() and new_xyz.is_contiguous() and input.is_contiguous() 189 | idx, dist = knnquery(k, xyz, new_xyz, offset, new_offset) # (n, k), (n, k) 190 | dist_recip = 1.0 / (dist + 1e-8) # (n, k) 191 | norm = torch.sum(dist_recip, dim=1, keepdim=True) 192 | weight = dist_recip / norm # (n, k) 193 | 194 | n, c, m = new_xyz.shape[0], input.shape[1], input.shape[0] 195 | output = torch.cuda.FloatTensor(n, c).zero_() 196 | pointops_cuda.interpolation_forward_cuda(n, c, k, input, idx, weight, output) 197 | ctx.m, ctx.k = m, k 198 | ctx.save_for_backward(idx, weight) 199 | return output 200 | 201 | @staticmethod 202 | def backward(ctx, grad_output): 203 | """ 204 | input: xyz: (m, 3), new_xyz: (n, 3), input: (m, c), offset: (b), new_offset: (b) 205 | output: (n, c) 206 | """ 207 | m, k = ctx.m, ctx.k 208 | idx, weight = ctx.saved_tensors 209 | n, c = grad_output.shape 210 | grad_input = torch.cuda.FloatTensor(m, c).zero_() 211 | pointops_cuda.interpolation_backward_cuda(n, c, k, grad_output, idx, weight, grad_input) 212 | return None, None, grad_input, None, None, None 213 | 214 | interpolation2 = Interpolation.apply 215 | -------------------------------------------------------------------------------- /lib/pointops/setup.py: -------------------------------------------------------------------------------- 1 | #python3 setup.py install 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | import os 5 | from distutils.sysconfig import get_config_vars 6 | 7 | (opt,) = get_config_vars('OPT') 8 | os.environ['OPT'] = " ".join( 9 | flag 
for flag in opt.split() if flag != '-Wstrict-prototypes' 10 | ) 11 | 12 | setup( 13 | name='pointops', 14 | author='Hengshuang Zhao', 15 | ext_modules=[ 16 | CUDAExtension('pointops_cuda', [ 17 | 'src/pointops_api.cpp', 18 | 'src/knnquery/knnquery_cuda.cpp', 19 | 'src/knnquery/knnquery_cuda_kernel.cu', 20 | 'src/sampling/sampling_cuda.cpp', 21 | 'src/sampling/sampling_cuda_kernel.cu', 22 | 'src/grouping/grouping_cuda.cpp', 23 | 'src/grouping/grouping_cuda_kernel.cu', 24 | 'src/interpolation/interpolation_cuda.cpp', 25 | 'src/interpolation/interpolation_cuda_kernel.cu', 26 | 'src/subtraction/subtraction_cuda.cpp', 27 | 'src/subtraction/subtraction_cuda_kernel.cu', 28 | 'src/aggregation/aggregation_cuda.cpp', 29 | 'src/aggregation/aggregation_cuda_kernel.cu', 30 | ], 31 | extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']} 32 | ) 33 | ], 34 | cmdclass={'build_ext': BuildExtension} 35 | ) 36 | -------------------------------------------------------------------------------- /lib/pointops/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/point-transformer/10d43ab5210fc93ffa15886f2a4c6460cc308780/lib/pointops/src/__init__.py -------------------------------------------------------------------------------- /lib/pointops/src/aggregation/aggregation_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "aggregation_cuda_kernel.h" 6 | 7 | 8 | void aggregation_forward_cuda(int n, int nsample, int c, int w_c, at::Tensor input_tensor, at::Tensor position_tensor, at::Tensor weight_tensor, at::Tensor idx_tensor, at::Tensor output_tensor) 9 | { 10 | const float *input = input_tensor.data_ptr(); 11 | const float *position = position_tensor.data_ptr(); 12 | const float *weight = weight_tensor.data_ptr(); 13 | const int *idx = idx_tensor.data_ptr(); 14 | float *output = 
output_tensor.data_ptr(); 15 | aggregation_forward_cuda_launcher(n, nsample, c, w_c, input, position, weight, idx, output); 16 | } 17 | 18 | void aggregation_backward_cuda(int n, int nsample, int c, int w_c, at::Tensor input_tensor, at::Tensor position_tensor, at::Tensor weight_tensor, at::Tensor idx_tensor, at::Tensor grad_output_tensor, at::Tensor grad_input_tensor, at::Tensor grad_position_tensor, at::Tensor grad_weight_tensor) 19 | { 20 | const float *input = input_tensor.data_ptr(); 21 | const float *position = position_tensor.data_ptr(); 22 | const float *weight = weight_tensor.data_ptr(); 23 | const int *idx = idx_tensor.data_ptr(); 24 | const float *grad_output = grad_output_tensor.data_ptr(); 25 | float *grad_input = grad_input_tensor.data_ptr(); 26 | float *grad_position = grad_position_tensor.data_ptr(); 27 | float *grad_weight = grad_weight_tensor.data_ptr(); 28 | aggregation_backward_cuda_launcher(n, nsample, c, w_c, input, position, weight, idx, grad_output, grad_input, grad_position, grad_weight); 29 | } 30 | -------------------------------------------------------------------------------- /lib/pointops/src/aggregation/aggregation_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "aggregation_cuda_kernel.h" 3 | 4 | 5 | __global__ void aggregation_forward_cuda_kernel(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, float *output) { 6 | // input: input: (n, c), position: (n, nsample, c), weight: (n, nsample, w_c), idx: (n, nsample), output: (n, c) 7 | int index = blockIdx.x * blockDim.x + threadIdx.x; 8 | if (index >= n * c) return; 9 | const int c_idx = index % c; 10 | const int n_idx = index / c; 11 | const int w_c_idx = c_idx % w_c; 12 | for (int nsample_idx = 0; nsample_idx < nsample; nsample_idx++) 13 | { 14 | int idx_idx = n_idx * nsample + nsample_idx; 15 | int input_idx = idx[idx_idx] * c + 
c_idx; 16 | int position_idx = n_idx * nsample * c + nsample_idx * c + c_idx; 17 | int weight_idx = n_idx * nsample * w_c + nsample_idx * w_c + w_c_idx; 18 | output[index] += (input[input_idx] + position[position_idx]) * weight[weight_idx]; 19 | } 20 | } 21 | 22 | __global__ void aggregation_backward_cuda_kernel(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, const float *grad_output, float *grad_input, float *grad_position, float *grad_weight) { 23 | // input: grad_output: (n, c), output: grad_input: (n, c), grad_position: (n, nsample, c), grad_weight: (n, nsample, w_c) 24 | int index = blockIdx.x * blockDim.x + threadIdx.x; 25 | if (index >= n * c) return; 26 | const int c_idx = index % c; 27 | const int n_idx = index / c; 28 | const int w_c_idx = c_idx % w_c; 29 | for (int nsample_idx = 0; nsample_idx < nsample; nsample_idx++) 30 | { 31 | int idx_idx = n_idx * nsample + nsample_idx; 32 | int input_idx = idx[idx_idx] * c + c_idx; 33 | int position_idx = n_idx * nsample * c + nsample_idx * c + c_idx; 34 | int weight_idx = n_idx * nsample * w_c + nsample_idx * w_c + w_c_idx; 35 | atomicAdd(grad_input + input_idx, grad_output[index] * weight[weight_idx]); 36 | grad_position[position_idx] = grad_output[index] * weight[weight_idx]; 37 | atomicAdd(grad_weight + weight_idx, grad_output[index] * (input[input_idx] + position[position_idx])); 38 | } 39 | } 40 | 41 | void aggregation_forward_cuda_launcher(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, float *output) { 42 | // input: input: (n, c), position: (n, nsample, c), weight: (n, nsample, w_c), idx: (n, nsample), output: (n, c) 43 | dim3 blocks(DIVUP(n * c, THREADS_PER_BLOCK)); 44 | dim3 threads(THREADS_PER_BLOCK); 45 | aggregation_forward_cuda_kernel<<>>(n, nsample, c, w_c, input, position, weight, idx, output); 46 | } 47 | 48 | void aggregation_backward_cuda_launcher(int n, int 
nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, const float *grad_output, float *grad_input, float *grad_position, float *grad_weight) { 49 | // input: grad_output: (n, c), output: grad_input: (n, c), grad_position: (n, nsample, c), grad_weight: (n, nsample, w_c) 50 | dim3 blocks(DIVUP(n * c, THREADS_PER_BLOCK)); 51 | dim3 threads(THREADS_PER_BLOCK); 52 | aggregation_backward_cuda_kernel<<>>(n, nsample, c, w_c, input, position, weight, idx, grad_output, grad_input, grad_position, grad_weight); 53 | } 54 | -------------------------------------------------------------------------------- /lib/pointops/src/aggregation/aggregation_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _AGGREGATION_CUDA_KERNEL 2 | #define _AGGREGATION_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void aggregation_forward_cuda(int n, int nsample, int c, int w_c, at::Tensor input_tensor, at::Tensor position_tensor, at::Tensor weight_tensor, at::Tensor idx_tensor, at::Tensor output_tensor); 8 | void aggregation_backward_cuda(int n, int nsample, int c, int w_c, at::Tensor input_tensor, at::Tensor position_tensor, at::Tensor weight_tensor, at::Tensor idx_tensor, at::Tensor grad_output_tensor, at::Tensor grad_input_tensor, at::Tensor grad_position_tensor, at::Tensor grad_weight_tensor); 9 | 10 | #ifdef __cplusplus 11 | extern "C" { 12 | #endif 13 | 14 | void aggregation_forward_cuda_launcher(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, float *output); 15 | void aggregation_backward_cuda_launcher(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, const float *grad_output, float *grad_input, float *grad_position, float *grad_weight); 16 | 17 | #ifdef __cplusplus 18 | } 19 | #endif 20 | #endif 21 | 
-------------------------------------------------------------------------------- /lib/pointops/src/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | #include 6 | 7 | #define TOTAL_THREADS 1024 8 | #define THREADS_PER_BLOCK 256 9 | #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) 10 | 11 | inline int opt_n_threads(int work_size) { 12 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 13 | return std::max(std::min(1 << pow_2, TOTAL_THREADS), 1); 14 | } 15 | 16 | inline dim3 opt_block_config(int x, int y) { 17 | const int x_threads = opt_n_threads(x); 18 | const int y_threads = std::max(std::min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 19 | dim3 block_config(x_threads, y_threads, 1); 20 | return block_config; 21 | } 22 | 23 | #endif 24 | -------------------------------------------------------------------------------- /lib/pointops/src/grouping/grouping_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "grouping_cuda_kernel.h" 6 | 7 | 8 | void grouping_forward_cuda(int m, int nsample, int c, at::Tensor input_tensor, at::Tensor idx_tensor, at::Tensor output_tensor) 9 | { 10 | const float *input = input_tensor.data_ptr(); 11 | const int *idx = idx_tensor.data_ptr(); 12 | float *output = output_tensor.data_ptr(); 13 | grouping_forward_cuda_launcher(m, nsample, c, input, idx, output); 14 | } 15 | 16 | void grouping_backward_cuda(int m, int nsample, int c, at::Tensor grad_output_tensor, at::Tensor idx_tensor, at::Tensor grad_input_tensor) 17 | { 18 | const float *grad_output = grad_output_tensor.data_ptr(); 19 | const int *idx = idx_tensor.data_ptr(); 20 | float *grad_input = grad_input_tensor.data_ptr(); 21 | grouping_backward_cuda_launcher(m, nsample, c, grad_output, idx, grad_input); 22 | } 23 | 
-------------------------------------------------------------------------------- /lib/pointops/src/grouping/grouping_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "grouping_cuda_kernel.h" 3 | 4 | 5 | __global__ void grouping_forward_cuda_kernel(int m, int nsample, int c, const float *__restrict__ input, const int *__restrict__ idx, float *__restrict__ output) { 6 | // input: input: (n, c), idx: (m, nsample), output: (m, nsample, c) 7 | int index = blockIdx.x * blockDim.x + threadIdx.x; 8 | if (index >= m * nsample * c) return; 9 | const int c_idx = index % c; 10 | const int nsample_idx = (index / c) % nsample; 11 | const int m_idx = index / nsample / c; 12 | const int input_idx = idx[m_idx * nsample + nsample_idx] * c + c_idx; 13 | output[index] = input[input_idx]; 14 | } 15 | 16 | __global__ void grouping_backward_cuda_kernel(int m, int nsample, int c, const float *__restrict__ grad_output, const int *__restrict__ idx, float *__restrict__ grad_input) { 17 | // input: grad_output: (m, nsample, c), idx: (m, nsample), output: grad_input: (n, c) 18 | int index = blockIdx.x * blockDim.x + threadIdx.x; 19 | if (index >= m * nsample * c) return; 20 | const int c_idx = index % c; 21 | const int nsample_idx = (index / c) % nsample; 22 | const int m_idx = index / nsample / c; 23 | const int input_idx = idx[m_idx * nsample + nsample_idx] * c + c_idx; 24 | atomicAdd(grad_input + input_idx, grad_output[index]); 25 | } 26 | 27 | void grouping_forward_cuda_launcher(int m, int nsample, int c, const float *input, const int *idx, float *output) { 28 | // input: input: (n, c), idx: (m, nsample), output: (m, nsample, c) 29 | dim3 blocks(DIVUP(m * nsample * c, THREADS_PER_BLOCK)); 30 | dim3 threads(THREADS_PER_BLOCK); 31 | grouping_forward_cuda_kernel<<>>(m, nsample, c, input, idx, output); 32 | } 33 | 34 | void grouping_backward_cuda_launcher(int m, int nsample, int c, const float 
*grad_output, const int *idx, float *grad_input) 35 | { 36 | // input: grad_output: (m, nsample, c), idx: (m, nsample), output: grad_input: (n, c) 37 | dim3 blocks(DIVUP(m * nsample * c, THREADS_PER_BLOCK)); 38 | dim3 threads(THREADS_PER_BLOCK); 39 | grouping_backward_cuda_kernel<<>>(m, nsample, c, grad_output, idx, grad_input); 40 | } 41 | -------------------------------------------------------------------------------- /lib/pointops/src/grouping/grouping_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _GROUPING_CUDA_KERNEL 2 | #define _GROUPING_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void grouping_forward_cuda(int m, int nsample, int c, at::Tensor input_tensor, at::Tensor idx_tensor, at::Tensor output_tensor); 8 | void grouping_backward_cuda(int m, int nsample, int c, at::Tensor grad_output_tensor, at::Tensor idx_tensor, at::Tensor grad_input_tensor); 9 | 10 | #ifdef __cplusplus 11 | extern "C" { 12 | #endif 13 | 14 | void grouping_forward_cuda_launcher(int m, int nsample, int c, const float *input, const int *idx, float *output); 15 | void grouping_backward_cuda_launcher(int m, int nsample, int c, const float *grad_output, const int *idx, float *grad_input); 16 | 17 | #ifdef __cplusplus 18 | } 19 | #endif 20 | #endif 21 | -------------------------------------------------------------------------------- /lib/pointops/src/interpolation/interpolation_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "interpolation_cuda_kernel.h" 6 | 7 | 8 | void interpolation_forward_cuda(int n, int c, int k, at::Tensor input_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor output_tensor) 9 | { 10 | const float *input = input_tensor.data_ptr(); 11 | const int *idx = idx_tensor.data_ptr(); 12 | const float *weight = weight_tensor.data_ptr(); 13 | float *output = 
output_tensor.data_ptr(); 14 | interpolation_forward_cuda_launcher(n, c, k, input, idx, weight, output); 15 | } 16 | 17 | void interpolation_backward_cuda(int n, int c, int k, at::Tensor grad_output_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_input_tensor) 18 | { 19 | const float *grad_output = grad_output_tensor.data_ptr(); 20 | const int *idx = idx_tensor.data_ptr(); 21 | const float *weight = weight_tensor.data_ptr(); 22 | float *grad_input = grad_input_tensor.data_ptr(); 23 | interpolation_backward_cuda_launcher(n, c, k, grad_output, idx, weight, grad_input); 24 | } 25 | -------------------------------------------------------------------------------- /lib/pointops/src/interpolation/interpolation_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "interpolation_cuda_kernel.h" 3 | 4 | 5 | __global__ void interpolation_forward_cuda_kernel(int n, int c, int k, const float *input, const int *idx, const float *weight, float *output) 6 | { 7 | // input: input: (m, c), idx: (n, k), weight: (n, k), output: output (n, c) 8 | int index = blockIdx.x * blockDim.x + threadIdx.x; 9 | if (index >= n * c) return; 10 | int c_idx = index % c; 11 | int n_idx = index / c; 12 | for (int i = 0; i < k; i++) 13 | { 14 | int idx_idx = n_idx * k + i; 15 | int input_idx = idx[idx_idx] * c + c_idx; 16 | output[index] += input[input_idx] * weight[idx_idx]; 17 | } 18 | } 19 | 20 | __global__ void interpolation_backward_cuda_kernel(int n, int c, int k, const float *grad_output, const int *idx, const float *weight, float *grad_input) 21 | { 22 | // input: grad_output: (n, c), idx: (n, k), weight: (n, k), output: grad_input (m, c) 23 | int index = blockIdx.x * blockDim.x + threadIdx.x; 24 | if (index >= n * c) return; 25 | int c_idx = index % c; 26 | int n_idx = index / c; 27 | for (int i = 0; i < k; i++) 28 | { 29 | int idx_idx = n_idx * k + i; 30 | int input_idx = 
idx[idx_idx] * c + c_idx; 31 | atomicAdd(grad_input + input_idx, grad_output[index] * weight[idx_idx]); 32 | } 33 | } 34 | 35 | void interpolation_forward_cuda_launcher(int n, int c, int k, const float *input, const int *idx, const float *weight, float *output) { 36 | // input: input: (m, c), idx: (n, k), weight: (n, k), output: output (n, c) 37 | dim3 blocks(DIVUP(n * c, THREADS_PER_BLOCK)); 38 | dim3 threads(THREADS_PER_BLOCK); 39 | interpolation_forward_cuda_kernel<<>>(n, c, k, input, idx, weight, output); 40 | } 41 | 42 | void interpolation_backward_cuda_launcher(int n, int c, int k, const float *grad_output, const int *idx, const float *weight, float *grad_input) { 43 | // input: grad_output: (n, c), idx: (n, k), weight: (n, k), output: grad_input (m, c) 44 | dim3 blocks(DIVUP(n * c, THREADS_PER_BLOCK)); 45 | dim3 threads(THREADS_PER_BLOCK); 46 | interpolation_backward_cuda_kernel<<>>(n, c, k, grad_output, idx, weight, grad_input); 47 | } 48 | -------------------------------------------------------------------------------- /lib/pointops/src/interpolation/interpolation_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _INTERPOLATION_CUDA_KERNEL 2 | #define _INTERPOLATION_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void interpolation_forward_cuda(int n, int c, int k, at::Tensor input_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor output_tensor); 8 | void interpolation_backward_cuda(int n, int c, int k, at::Tensor grad_output_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_input_tensor); 9 | 10 | #ifdef __cplusplus 11 | extern "C" { 12 | #endif 13 | 14 | void interpolation_forward_cuda_launcher(int n, int c, int k, const float *input, const int *idx, const float *weight, float *output); 15 | void interpolation_backward_cuda_launcher(int n, int c, int k, const float *grad_output, const int *idx, const float *weight, float *grad_input); 16 | 17 | 
// Restores the max-heap property of dist[0..k) after the root was replaced.
// The root is sifted down: at every level it is compared with the larger of
// its two children and swapped with it unless the root is strictly greater.
// idx is permuted in lock-step so idx[i] always belongs to dist[i].
__device__ void reheap(float *dist, int *idx, int k)
{
    int parent = 0;
    for (int kid = 1; kid < k; kid = 2 * parent + 1)
    {
        // Step to the right sibling when it exists and is the larger one.
        if (kid + 1 < k && dist[kid + 1] > dist[kid])
            ++kid;
        // Parent already dominates its largest child: heap is valid again.
        if (dist[parent] > dist[kid])
            return;
        swap_float(&dist[parent], &dist[kid]);
        swap_int(&idx[parent], &idx[kid]);
        parent = kid;
    }
}
// k-nearest-neighbour query, one thread per query point.
// input: xyz (n, 3) new_xyz (m, 3)
// output: idx (m, nsample) dist2 (m, nsample)
//
// offset / new_offset hold cumulative end indices per batch element for xyz
// and new_xyz respectively. A query only searches the xyz points of its own
// batch element, i.e. the range [start, end) derived from `offset`.
// Candidates are kept in a max-heap keyed on squared distance, so the current
// worst candidate sits at slot 0 and is the one replaced. The heap is sorted
// before it is written out.
__global__ void knnquery_cuda_kernel(int m, int nsample, const float *__restrict__ xyz, const float *__restrict__ new_xyz, const int *__restrict__ offset, const int *__restrict__ new_offset, int *__restrict__ idx, float *__restrict__ dist2) {
    const int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (pt_idx >= m) return;

    // Per-thread views of this query's coordinates and output rows.
    const float *query = new_xyz + pt_idx * 3;
    int *out_idx = idx + pt_idx * nsample;
    float *out_dist2 = dist2 + pt_idx * nsample;

    // Locate the batch element of this query, then its slice of xyz.
    const int bt_idx = get_bt_idx(pt_idx, new_offset);
    const int start = (bt_idx == 0) ? 0 : offset[bt_idx - 1];
    const int end = offset[bt_idx];

    const float new_x = query[0];
    const float new_y = query[1];
    const float new_z = query[2];

    // NOTE(review): the scratch heaps have 100 slots and nsample is not
    // checked against that here - callers must keep nsample <= 100.
    float best_dist[100];
    int best_idx[100];
    for (int s = 0; s < nsample; ++s) {
        best_dist[s] = 1e10;   // sentinel: any real candidate is closer
        best_idx[s] = start;   // fallback index when fewer than nsample points exist
    }

    for (int cand = start; cand < end; ++cand) {
        const float x = xyz[cand * 3 + 0];
        const float y = xyz[cand * 3 + 1];
        const float z = xyz[cand * 3 + 2];
        const float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
        if (d2 < best_dist[0]) {
            // Closer than the current worst: overwrite the root and re-heapify.
            best_dist[0] = d2;
            best_idx[0] = cand;
            reheap(best_dist, best_idx, nsample);
        }
    }

    heap_sort(best_dist, best_idx, nsample);
    for (int s = 0; s < nsample; ++s) {
        out_idx[s] = best_idx[s];
        out_dist2[s] = best_dist[s];
    }
}
{ 112 | // input: new_xyz: (m, 3), xyz: (n, 3), idx: (m, nsample) 113 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK)); 114 | dim3 threads(THREADS_PER_BLOCK); 115 | knnquery_cuda_kernel<<>>(m, nsample, xyz, new_xyz, offset, new_offset, idx, dist2); 116 | } 117 | -------------------------------------------------------------------------------- /lib/pointops/src/knnquery/knnquery_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _KNNQUERY_CUDA_KERNEL 2 | #define _KNNQUERY_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void knnquery_cuda(int m, int nsample, at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, at::Tensor offset_tensor, at::Tensor new_offset_tensor, at::Tensor idx_tensor, at::Tensor dist2_tensor); 8 | 9 | #ifdef __cplusplus 10 | extern "C" { 11 | #endif 12 | 13 | void knnquery_cuda_launcher(int m, int nsample, const float *xyz, const float *new_xyz, const int *offset, const int *new_offset, int *idx, float *dist2); 14 | 15 | #ifdef __cplusplus 16 | } 17 | #endif 18 | #endif 19 | -------------------------------------------------------------------------------- /lib/pointops/src/pointops_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "knnquery/knnquery_cuda_kernel.h" 5 | #include "sampling/sampling_cuda_kernel.h" 6 | #include "grouping/grouping_cuda_kernel.h" 7 | #include "interpolation/interpolation_cuda_kernel.h" 8 | #include "aggregation/aggregation_cuda_kernel.h" 9 | #include "subtraction/subtraction_cuda_kernel.h" 10 | 11 | 12 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 13 | m.def("knnquery_cuda", &knnquery_cuda, "knnquery_cuda"); 14 | m.def("furthestsampling_cuda", &furthestsampling_cuda, "furthestsampling_cuda"); 15 | m.def("grouping_forward_cuda", &grouping_forward_cuda, "grouping_forward_cuda"); 16 | m.def("grouping_backward_cuda", &grouping_backward_cuda, "grouping_backward_cuda"); 17 | 
m.def("interpolation_forward_cuda", &interpolation_forward_cuda, "interpolation_forward_cuda"); 18 | m.def("interpolation_backward_cuda", &interpolation_backward_cuda, "interpolation_backward_cuda"); 19 | m.def("subtraction_forward_cuda", &subtraction_forward_cuda, "subtraction_forward_cuda"); 20 | m.def("subtraction_backward_cuda", &subtraction_backward_cuda, "subtraction_backward_cuda"); 21 | m.def("aggregation_forward_cuda", &aggregation_forward_cuda, "aggregation_forward_cuda"); 22 | m.def("aggregation_backward_cuda", &aggregation_backward_cuda, "aggregation_backward_cuda"); 23 | } 24 | -------------------------------------------------------------------------------- /lib/pointops/src/sampling/sampling_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "sampling_cuda_kernel.h" 6 | 7 | 8 | void furthestsampling_cuda(int b, int n, at::Tensor xyz_tensor, at::Tensor offset_tensor, at::Tensor new_offset_tensor, at::Tensor tmp_tensor, at::Tensor idx_tensor) 9 | { 10 | const float *xyz = xyz_tensor.data_ptr(); 11 | const int *offset = offset_tensor.data_ptr(); 12 | const int *new_offset = new_offset_tensor.data_ptr(); 13 | float *tmp = tmp_tensor.data_ptr(); 14 | int *idx = idx_tensor.data_ptr(); 15 | furthestsampling_cuda_launcher(b, n, xyz, offset, new_offset, tmp, idx); 16 | } 17 | -------------------------------------------------------------------------------- /lib/pointops/src/sampling/sampling_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "sampling_cuda_kernel.h" 3 | 4 | 5 | __device__ void __update(float *dists, int *dists_i, int idx1, int idx2) { 6 | const float v1 = dists[idx1], v2 = dists[idx2]; 7 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 8 | dists[idx1] = max(v1, v2); 9 | dists_i[idx1] = v2 > v1 ? 
// Furthest point sampling, one thread block per batch element.
//
// input  xyz: (n, 3) point coordinates for all batch elements, concatenated
//        offset / new_offset: cumulative end indices per batch element for
//        the input points and for the sampled output respectively
//        tmp: indexed by global point index; holds each point's running
//        minimum squared distance to the points selected so far
//        (presumably initialised to a large value by the caller - confirm)
// output idx: (m) global indices of the selected points
//
// block_size is the number of threads per block. It must be a power of two
// for the tree reduction below, and must equal the launch block dimension
// because it sizes the shared scratch arrays.
template <unsigned int block_size>
__global__ void furthestsampling_cuda_kernel(const float *xyz, const int *offset, const int *new_offset, float *tmp, int *idx)
{
    __shared__ float dists[block_size];
    __shared__ int dists_i[block_size];

    // Input range [start_n, end_n) and output range [start_m, end_m) owned by
    // this block's batch element.
    const int bid = blockIdx.x;
    const int start_n = (bid == 0) ? 0 : offset[bid - 1];
    const int end_n = offset[bid];
    const int start_m = (bid == 0) ? 0 : new_offset[bid - 1];
    const int end_m = new_offset[bid];
    int old = start_n;  // most recently selected point; seeded with the first point

    const int stride = block_size;
    const int tid = threadIdx.x;
    // Seed the selection with the first input point. Skip the write when this
    // batch element has no output slots, otherwise it would land in the next
    // element's range (or past the end of idx for the last element).
    if (tid == 0 && start_m < end_m) idx[start_m] = start_n;

    __syncthreads();
    for (int j = start_m + 1; j < end_m; j++)
    {
        int besti = start_n;
        float best = -1;
        const float x1 = xyz[old * 3 + 0];
        const float y1 = xyz[old * 3 + 1];
        const float z1 = xyz[old * 3 + 2];
        // Each thread scans a strided subset of the points, refreshes their
        // distance-to-selected-set and remembers its local farthest point.
        for (int k = start_n + tid; k < end_n; k += stride)
        {
            const float x2 = xyz[k * 3 + 0];
            const float y2 = xyz[k * 3 + 1];
            const float z2 = xyz[k * 3 + 2];
            const float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
            const float d2 = min(d, tmp[k]);
            tmp[k] = d2;
            besti = d2 > best ? k : besti;
            best = d2 > best ? d2 : best;
        }
        dists[tid] = best;
        dists_i[tid] = besti;
        __syncthreads();

        // Tree reduction over the shared arrays: after the last step slot 0
        // holds the block-wide maximum distance and its point index.
        #pragma unroll
        for (int s = block_size / 2; s > 0; s >>= 1)
        {
            if (tid < s)
                __update(dists, dists_i, tid, tid + s);
            __syncthreads();
        }

        old = dists_i[0];
        if (tid == 0)
            idx[j] = old;
        // Every thread must have read dists_i[0] before any thread overwrites
        // the shared arrays in the next iteration.
        __syncthreads();
    }
}
furthestsampling_cuda_kernel<256><<>>(xyz, offset, new_offset, tmp, idx); 143 | break; 144 | case 128: 145 | furthestsampling_cuda_kernel<128><<>>(xyz, offset, new_offset, tmp, idx); 146 | break; 147 | case 64: 148 | furthestsampling_cuda_kernel<64><<>>(xyz, offset, new_offset, tmp, idx); 149 | break; 150 | case 32: 151 | furthestsampling_cuda_kernel<32><<>>(xyz, offset, new_offset, tmp, idx); 152 | break; 153 | case 16: 154 | furthestsampling_cuda_kernel<16><<>>(xyz, offset, new_offset, tmp, idx); 155 | break; 156 | case 8: 157 | furthestsampling_cuda_kernel<8><<>>(xyz, offset, new_offset, tmp, idx); 158 | break; 159 | case 4: 160 | furthestsampling_cuda_kernel<4><<>>(xyz, offset, new_offset, tmp, idx); 161 | break; 162 | case 2: 163 | furthestsampling_cuda_kernel<2><<>>(xyz, offset, new_offset, tmp, idx); 164 | break; 165 | case 1: 166 | furthestsampling_cuda_kernel<1><<>>(xyz, offset, new_offset, tmp, idx); 167 | break; 168 | default: 169 | furthestsampling_cuda_kernel<512><<>>(xyz, offset, new_offset, tmp, idx); 170 | } 171 | } 172 | -------------------------------------------------------------------------------- /lib/pointops/src/sampling/sampling_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _SAMPLING_CUDA_KERNEL 2 | #define _SAMPLING_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void furthestsampling_cuda(int b, int n, at::Tensor xyz_tensor, at::Tensor offset_tensor, at::Tensor new_offset_tensor, at::Tensor tmp_tensor, at::Tensor idx_tensor); 8 | 9 | #ifdef __cplusplus 10 | extern "C" { 11 | #endif 12 | 13 | void furthestsampling_cuda_launcher(int b, int n, const float *xyz, const int *offset, const int *new_offset, float *tmp, int *idx); 14 | 15 | #ifdef __cplusplus 16 | } 17 | #endif 18 | #endif 19 | -------------------------------------------------------------------------------- /lib/pointops/src/subtraction/subtraction_cuda.cpp: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "subtraction_cuda_kernel.h" 6 | 7 | 8 | void subtraction_forward_cuda(int n, int nsample, int c, at::Tensor input1_tensor, at::Tensor input2_tensor, at::Tensor idx_tensor, at::Tensor output_tensor) 9 | { 10 | const float *input1 = input1_tensor.data_ptr(); 11 | const float *input2 = input2_tensor.data_ptr(); 12 | const int *idx = idx_tensor.data_ptr(); 13 | float *output = output_tensor.data_ptr(); 14 | subtraction_forward_cuda_launcher(n, nsample, c, input1, input2, idx, output); 15 | } 16 | 17 | void subtraction_backward_cuda(int n, int nsample, int c, at::Tensor idx_tensor, at::Tensor grad_output_tensor, at::Tensor grad_input1_tensor, at::Tensor grad_input2_tensor) 18 | { 19 | const int *idx = idx_tensor.data_ptr(); 20 | const float *grad_output = grad_output_tensor.data_ptr(); 21 | float *grad_input1 = grad_input1_tensor.data_ptr(); 22 | float *grad_input2 = grad_input2_tensor.data_ptr(); 23 | subtraction_backward_cuda_launcher(n, nsample, c, idx, grad_output, grad_input1, grad_input2); 24 | } 25 | -------------------------------------------------------------------------------- /lib/pointops/src/subtraction/subtraction_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "subtraction_cuda_kernel.h" 3 | 4 | 5 | __global__ void subtraction_forward_cuda_kernel(int n, int nsample, int c, const float *input1, const float *input2, const int *idx, float *output) { 6 | // input: input1: (n, c), input2: (n, c), idx: (n, nsample), output: (n, nsample, c) 7 | int index = blockIdx.x * blockDim.x + threadIdx.x; 8 | if (index >= n * nsample * c) return; 9 | const int c_idx = index % c; 10 | const int nsample_idx = (index / c) % nsample; 11 | const int n_idx = index / nsample / c; 12 | const int idx_idx = n_idx * nsample + nsample_idx; 13 | const 
// Launches subtraction_forward_cuda_kernel with one thread per output element
// (n * nsample * c threads in total, THREADS_PER_BLOCK threads per block).
// input: input1: (n, c), input2: (n, c), idx: (n, nsample), output: (n, nsample, c)
void subtraction_forward_cuda_launcher(int n, int nsample, int c, const float *input1, const float *input2, const int *idx, float *output) {
    // Nothing to compute; also avoids launching with an empty grid, which
    // CUDA rejects as an invalid configuration.
    if (n <= 0 || nsample <= 0 || c <= 0) return;
    dim3 blocks(DIVUP(n * nsample * c, THREADS_PER_BLOCK));
    dim3 threads(THREADS_PER_BLOCK);
    // The execution configuration must name the grid and block dimensions.
    subtraction_forward_cuda_kernel<<<blocks, threads, 0>>>(n, nsample, c, input1, input2, idx, output);
}

// Launches subtraction_backward_cuda_kernel with one thread per element of
// grad_output (n * nsample * c threads in total).
// input: grad_output: (n, nsample, c), output: grad_input1: (n, c), grad_input2: (n, c)
// The kernel accumulates with atomicAdd, so the caller is expected to pass
// both gradient buffers already initialised.
void subtraction_backward_cuda_launcher(int n, int nsample, int c, const int *idx, const float *grad_output, float *grad_input1, float *grad_input2) {
    // Nothing to compute; also avoids launching with an empty grid.
    if (n <= 0 || nsample <= 0 || c <= 0) return;
    dim3 blocks(DIVUP(n * nsample * c, THREADS_PER_BLOCK));
    dim3 threads(THREADS_PER_BLOCK);
    subtraction_backward_cuda_kernel<<<blocks, threads, 0>>>(n, nsample, c, idx, grad_output, grad_input1, grad_input2);
}
class PointTransformerLayer(nn.Module):
    """Attention layer over the per-point groups built by ``pointops.queryandgroup``.

    Queries, keys and values are linear projections of the input features.
    The attention weights come from ``key - query + position encoding`` passed
    through ``linear_w`` and a softmax over the ``nsample`` axis. Each weight
    channel is shared by ``share_planes`` groups of value channels.
    """

    def __init__(self, in_planes, out_planes, share_planes=8, nsample=16):
        super().__init__()
        mid_planes = out_planes // 1
        self.mid_planes = mid_planes
        self.out_planes = out_planes
        self.share_planes = share_planes
        self.nsample = nsample
        self.linear_q = nn.Linear(in_planes, mid_planes)
        self.linear_k = nn.Linear(in_planes, mid_planes)
        self.linear_v = nn.Linear(in_planes, out_planes)
        # Position encoding: relative xyz (3) -> out_planes.
        self.linear_p = nn.Sequential(
            nn.Linear(3, 3),
            nn.BatchNorm1d(3),
            nn.ReLU(inplace=True),
            nn.Linear(3, out_planes),
        )
        # Weight encoding: mid_planes -> mid_planes // share_planes.
        self.linear_w = nn.Sequential(
            nn.BatchNorm1d(mid_planes),
            nn.ReLU(inplace=True),
            nn.Linear(mid_planes, mid_planes // share_planes),
            nn.BatchNorm1d(mid_planes // share_planes),
            nn.ReLU(inplace=True),
            nn.Linear(out_planes // share_planes, out_planes // share_planes),
        )
        self.softmax = nn.Softmax(dim=1)

    @staticmethod
    def _run_grouped(layers, tensor, is_norm_at):
        """Apply ``layers`` in order to a ``(n, nsample, c)`` tensor.

        Layers at positions where ``is_norm_at(i)`` is true are BatchNorm1d
        modules, which normalise dim 1; the channel axis is swapped into dim 1
        for the call and swapped back afterwards.
        """
        for position, layer in enumerate(layers):
            if is_norm_at(position):
                swapped = tensor.transpose(1, 2).contiguous()
                tensor = layer(swapped).transpose(1, 2).contiguous()
            else:
                tensor = layer(tensor)
        return tensor

    def forward(self, pxo) -> torch.Tensor:
        p, x, o = pxo  # (n, 3), (n, c), (b)
        x_q = self.linear_q(x)  # (n, c)
        x_k = self.linear_k(x)  # (n, c)
        x_v = self.linear_v(x)  # (n, c)
        # Keys are grouped together with xyz so the first 3 channels can be
        # split off as the relative position p_r.
        x_k = pointops.queryandgroup(self.nsample, p, p, x_k, None, o, o, use_xyz=True)  # (n, nsample, 3+c)
        x_v = pointops.queryandgroup(self.nsample, p, p, x_v, None, o, o, use_xyz=False)  # (n, nsample, c)
        p_r = x_k[:, :, 0:3]
        x_k = x_k[:, :, 3:]
        # linear_p has its BatchNorm1d at position 1.
        p_r = self._run_grouped(self.linear_p, p_r, lambda i: i == 1)  # (n, nsample, c)
        folded_p = p_r.view(
            p_r.shape[0], p_r.shape[1], self.out_planes // self.mid_planes, self.mid_planes
        ).sum(2)
        w = x_k - x_q.unsqueeze(1) + folded_p  # (n, nsample, c)
        # linear_w has BatchNorm1d at positions 0 and 3.
        w = self._run_grouped(self.linear_w, w, lambda i: i % 3 == 0)
        w = self.softmax(w)  # (n, nsample, c)
        n, nsample, c = x_v.shape
        s = self.share_planes
        # Broadcast each weight channel over the s groups of value channels,
        # then sum over the nsample axis.
        grouped_v = (x_v + p_r).view(n, nsample, s, c // s)
        x = (grouped_v * w.unsqueeze(2)).sum(1).view(n, c)
        return x
class TransitionUp(nn.Module):
    """Decoder transition with two modes.

    * ``out_planes is None`` (head): each point's features are concatenated
      with a projection of its batch element's mean feature, then mapped back
      to ``in_planes`` channels.
    * otherwise: features of the second input are projected with ``linear2``,
      passed to ``pointops.interpolation`` and added to the ``linear1``
      projection of the first input.
    """

    def __init__(self, in_planes, out_planes=None):
        super().__init__()
        if out_planes is not None:
            self.linear1 = nn.Sequential(
                nn.Linear(out_planes, out_planes),
                nn.BatchNorm1d(out_planes),
                nn.ReLU(inplace=True),
            )
            self.linear2 = nn.Sequential(
                nn.Linear(in_planes, out_planes),
                nn.BatchNorm1d(out_planes),
                nn.ReLU(inplace=True),
            )
        else:
            # Head variant: linear1 consumes [point feature, global feature].
            self.linear1 = nn.Sequential(
                nn.Linear(2 * in_planes, in_planes),
                nn.BatchNorm1d(in_planes),
                nn.ReLU(inplace=True),
            )
            self.linear2 = nn.Sequential(
                nn.Linear(in_planes, in_planes),
                nn.ReLU(inplace=True),
            )

    def forward(self, pxo1, pxo2=None):
        if pxo2 is not None:
            p1, x1, o1 = pxo1
            p2, x2, o2 = pxo2
            return self.linear1(x1) + pointops.interpolation(p2, p1, self.linear2(x2), o2, o1)

        _, x, o = pxo1  # (n, 3), (n, c), (b)
        segments = []
        seg_start = 0
        # o holds cumulative end indices, so [seg_start, seg_end) is one batch element.
        for i in range(o.shape[0]):
            seg_end = o[i]
            cnt = seg_end - seg_start
            x_b = x[seg_start:seg_end, :]
            # Mean feature of the segment, projected and tiled to every point.
            pooled = self.linear2(x_b.sum(0, True) / cnt).repeat(cnt, 1)
            segments.append(torch.cat((x_b, pooled), 1))
            seg_start = seg_end
        x = torch.cat(segments, 0)
        return self.linear1(x)
nsample=nsample[4], is_head=True) # transform p5 137 | self.dec4 = self._make_dec(block, planes[3], 2, share_planes, nsample=nsample[3]) # fusion p5 and p4 138 | self.dec3 = self._make_dec(block, planes[2], 2, share_planes, nsample=nsample[2]) # fusion p4 and p3 139 | self.dec2 = self._make_dec(block, planes[1], 2, share_planes, nsample=nsample[1]) # fusion p3 and p2 140 | self.dec1 = self._make_dec(block, planes[0], 2, share_planes, nsample=nsample[0]) # fusion p2 and p1 141 | self.cls = nn.Sequential(nn.Linear(planes[0], planes[0]), nn.BatchNorm1d(planes[0]), nn.ReLU(inplace=True), nn.Linear(planes[0], k)) 142 | 143 | def _make_enc(self, block, planes, blocks, share_planes=8, stride=1, nsample=16): 144 | layers = [] 145 | layers.append(TransitionDown(self.in_planes, planes * block.expansion, stride, nsample)) 146 | self.in_planes = planes * block.expansion 147 | for _ in range(1, blocks): 148 | layers.append(block(self.in_planes, self.in_planes, share_planes, nsample=nsample)) 149 | return nn.Sequential(*layers) 150 | 151 | def _make_dec(self, block, planes, blocks, share_planes=8, nsample=16, is_head=False): 152 | layers = [] 153 | layers.append(TransitionUp(self.in_planes, None if is_head else planes * block.expansion)) 154 | self.in_planes = planes * block.expansion 155 | for _ in range(1, blocks): 156 | layers.append(block(self.in_planes, self.in_planes, share_planes, nsample=nsample)) 157 | return nn.Sequential(*layers) 158 | 159 | def forward(self, pxo): 160 | p0, x0, o0 = pxo # (n, 3), (n, c), (b) 161 | x0 = p0 if self.c == 3 else torch.cat((p0, x0), 1) 162 | p1, x1, o1 = self.enc1([p0, x0, o0]) 163 | p2, x2, o2 = self.enc2([p1, x1, o1]) 164 | p3, x3, o3 = self.enc3([p2, x2, o2]) 165 | p4, x4, o4 = self.enc4([p3, x3, o3]) 166 | p5, x5, o5 = self.enc5([p4, x4, o4]) 167 | x5 = self.dec5[1:]([p5, self.dec5[0]([p5, x5, o5]), o5])[1] 168 | x4 = self.dec4[1:]([p4, self.dec4[0]([p4, x4, o4], [p5, x5, o5]), o4])[1] 169 | x3 = self.dec3[1:]([p3, self.dec3[0]([p3, 
def get_parser():
    """Parse the command line and return the resulting configuration.

    ``--config`` names the config file to load (defaults to the S3DIS
    reproduction config). All remaining positional arguments are collected
    into ``opts`` and merged on top of the loaded config via
    ``config.merge_cfg_from_list``.
    """
    cli = argparse.ArgumentParser(description='PyTorch Point Cloud Semantic Segmentation')
    cli.add_argument(
        '--config',
        type=str,
        default='config/s3dis/s3dis_pointtransformer_repro.yaml',
        help='config file',
    )
    cli.add_argument(
        'opts',
        help='see config/s3dis/s3dis_pointtransformer_repro.yaml for all options',
        default=None,
        nargs=argparse.REMAINDER,
    )
    parsed = cli.parse_args()
    assert parsed.config is not None
    cfg = config.load_cfg_from_cfg_file(parsed.config)
    if parsed.opts is None:
        return cfg
    return config.merge_cfg_from_list(cfg, parsed.opts)
def main():
    """Build the model, load the checkpoint and run the evaluation.

    Populates the module-level ``args`` and ``logger`` globals, which the other
    functions in this script read. Requires a CUDA device: the model and the
    loss are moved to the GPU.

    Raises:
        Exception: if ``args.arch`` names an unsupported architecture.
        RuntimeError: if ``args.model_path`` is not an existing file.
    """
    global args, logger
    args = get_parser()
    logger = get_logger()
    logger.info(args)
    assert args.classes > 1
    logger.info("=> creating model ...")
    logger.info("Classes: {}".format(args.classes))

    if args.arch == 'pointtransformer_seg_repro':
        from model.pointtransformer.pointtransformer_seg import pointtransformer_seg_repro as Model
    else:
        # Include the offending name; the message had no placeholder before.
        raise Exception('architecture {} not supported yet'.format(args.arch))
    model = Model(c=args.fea_dim, k=args.classes).cuda()
    logger.info(model)
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()
    # Close the names file deterministically instead of leaking the handle.
    with open(args.names_path) as names_file:
        names = [line.rstrip('\n') for line in names_file]
    if not os.path.isfile(args.model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path))

    logger.info("=> loading checkpoint '{}'".format(args.model_path))
    checkpoint = torch.load(args.model_path)
    state_dict = checkpoint['state_dict']
    # Checkpoint keys may carry a 7-character 'module.' wrapper prefix
    # (presumably from DataParallel/DistributedDataParallel training).
    # Strip it only when present so un-prefixed checkpoints load too.
    prefix = 'module.'
    new_state_dict = collections.OrderedDict()
    for key, value in state_dict.items():
        name = key[len(prefix):] if key.startswith(prefix) else key
        new_state_dict[name] = value
    model.load_state_dict(new_state_dict, strict=True)
    logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.model_path, checkpoint['epoch']))
    # The epoch is used later to name the saved prediction files.
    args.epoch = checkpoint['epoch']
    test(model, criterion, names)
def input_normalize(coord, feat):
    """Shift coordinates to a zero-based origin and scale features by 1/255.

    Args:
        coord: array whose per-column minimum (over axis 0) is subtracted.
        feat: array divided by 255.

    Returns:
        ``(coord, feat)`` as new arrays. The inputs are left untouched: the
        shift is no longer done in place, so calling this on a view or a
        shared array cannot corrupt the caller's data.
    """
    coord = coord - np.min(coord, 0)
    feat = feat / 255.
    return coord, feat
item)) 141 | idx_part = idx_data[i] 142 | coord_part, feat_part = coord[idx_part], feat[idx_part] 143 | if args.voxel_max and coord_part.shape[0] > args.voxel_max: 144 | coord_p, idx_uni, cnt = np.random.rand(coord_part.shape[0]) * 1e-3, np.array([]), 0 145 | while idx_uni.size != idx_part.shape[0]: 146 | init_idx = np.argmin(coord_p) 147 | dist = np.sum(np.power(coord_part - coord_part[init_idx], 2), 1) 148 | idx_crop = np.argsort(dist)[:args.voxel_max] 149 | coord_sub, feat_sub, idx_sub = coord_part[idx_crop], feat_part[idx_crop], idx_part[idx_crop] 150 | dist = dist[idx_crop] 151 | delta = np.square(1 - dist / np.max(dist)) 152 | coord_p[idx_crop] += delta 153 | coord_sub, feat_sub = input_normalize(coord_sub, feat_sub) 154 | idx_list.append(idx_sub), coord_list.append(coord_sub), feat_list.append(feat_sub), offset_list.append(idx_sub.size) 155 | idx_uni = np.unique(np.concatenate((idx_uni, idx_sub))) 156 | else: 157 | coord_part, feat_part = input_normalize(coord_part, feat_part) 158 | idx_list.append(idx_part), coord_list.append(coord_part), feat_list.append(feat_part), offset_list.append(idx_part.size) 159 | batch_num = int(np.ceil(len(idx_list) / args.batch_size_test)) 160 | for i in range(batch_num): 161 | s_i, e_i = i * args.batch_size_test, min((i + 1) * args.batch_size_test, len(idx_list)) 162 | idx_part, coord_part, feat_part, offset_part = idx_list[s_i:e_i], coord_list[s_i:e_i], feat_list[s_i:e_i], offset_list[s_i:e_i] 163 | idx_part = np.concatenate(idx_part) 164 | coord_part = torch.FloatTensor(np.concatenate(coord_part)).cuda(non_blocking=True) 165 | feat_part = torch.FloatTensor(np.concatenate(feat_part)).cuda(non_blocking=True) 166 | offset_part = torch.IntTensor(np.cumsum(offset_part)).cuda(non_blocking=True) 167 | with torch.no_grad(): 168 | pred_part = model([coord_part, feat_part, offset_part]) # (n, k) 169 | torch.cuda.empty_cache() 170 | pred[idx_part, :] += pred_part 171 | logger.info('Test: {}/{}, {}/{}, {}/{}'.format(idx + 1, 
len(data_list), e_i, len(idx_list), args.voxel_max, idx_part.shape[0])) 172 | loss = criterion(pred, torch.LongTensor(label).cuda(non_blocking=True)) # for reference 173 | pred = pred.max(1)[1].data.cpu().numpy() 174 | 175 | # calculation 1: add per room predictions 176 | intersection, union, target = intersectionAndUnion(pred, label, args.classes, args.ignore_label) 177 | intersection_meter.update(intersection) 178 | union_meter.update(union) 179 | target_meter.update(target) 180 | 181 | accuracy = sum(intersection) / (sum(target) + 1e-10) 182 | batch_time.update(time.time() - end) 183 | logger.info('Test: [{}/{}]-{} ' 184 | 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 185 | 'Accuracy {accuracy:.4f}.'.format(idx + 1, len(data_list), label.size, batch_time=batch_time, accuracy=accuracy)) 186 | pred_save.append(pred); label_save.append(label) 187 | np.save(pred_save_path, pred); np.save(label_save_path, label) 188 | 189 | with open(os.path.join(args.save_folder, "pred.pickle"), 'wb') as handle: 190 | pickle.dump({'pred': pred_save}, handle, protocol=pickle.HIGHEST_PROTOCOL) 191 | with open(os.path.join(args.save_folder, "label.pickle"), 'wb') as handle: 192 | pickle.dump({'label': label_save}, handle, protocol=pickle.HIGHEST_PROTOCOL) 193 | 194 | # calculation 1 195 | iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) 196 | accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10) 197 | mIoU1 = np.mean(iou_class) 198 | mAcc1 = np.mean(accuracy_class) 199 | allAcc1 = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10) 200 | 201 | # calculation 2 202 | intersection, union, target = intersectionAndUnion(np.concatenate(pred_save), np.concatenate(label_save), args.classes, args.ignore_label) 203 | iou_class = intersection / (union + 1e-10) 204 | accuracy_class = intersection / (target + 1e-10) 205 | mIoU = np.mean(iou_class) 206 | mAcc = np.mean(accuracy_class) 207 | allAcc = sum(intersection) / (sum(target) + 1e-10) 208 | 
#!/bin/sh
# Evaluate the best and the last checkpoint of one experiment.
# All paths are relative, so run this from the repository root:
#   sh tool/test.sh <dataset> <exp_name>

# Without both arguments every derived path collapses to "exp//...".
if [ "$#" -lt 2 ]; then
    echo "Usage: $0 <dataset> <exp_name>" >&2
    exit 1
fi

export PYTHONPATH=./
eval "$(conda shell.bash hook)"
PYTHON=python

TEST_CODE=test.py

dataset=$1
exp_name=$2
exp_dir="exp/${dataset}/${exp_name}"
model_dir="${exp_dir}/model"
result_dir="${exp_dir}/result"
config="config/${dataset}/${dataset}_${exp_name}.yaml"

mkdir -p "${result_dir}/last"
mkdir -p "${result_dir}/best"

now=$(date +"%Y%m%d_%H%M%S")
# Snapshot the config and the scripts next to the results.
cp "${config}" tool/test.sh "tool/${TEST_CODE}" "${exp_dir}"

# NOTE(review): the "#: '" / "#'" pairs look like toggles -- dropping the
# leading '#' on both lines would turn the enclosed run into a no-op string.
#: '
$PYTHON -u "${exp_dir}/${TEST_CODE}" \
  --config="${config}" \
  save_folder "${result_dir}/best" \
  model_path "${model_dir}/model_best.pth" \
  2>&1 | tee "${exp_dir}/test_best-$now.log"
#'

#: '
$PYTHON -u "${exp_dir}/${TEST_CODE}" \
  --config="${config}" \
  save_folder "${result_dir}/last" \
  model_path "${model_dir}/model_last.pth" \
  2>&1 | tee "${exp_dir}/test_last-$now.log"
#'
def get_parser():
    """Parse the command line and return the merged configuration.

    ``--config`` names the YAML file to load; every remaining positional
    token is collected into ``opts`` and merged on top of the file values
    by ``config.merge_cfg_from_list``.
    """
    parser = argparse.ArgumentParser(description='PyTorch Point Cloud Semantic Segmentation')
    parser.add_argument('--config', type=str, default='config/s3dis/s3dis_pointtransformer_repro.yaml', help='config file')
    parser.add_argument('opts', help='see config/s3dis/s3dis_pointtransformer_repro.yaml for all options', default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()
    assert args.config is not None
    cfg = config.load_cfg_from_cfg_file(args.config)
    if args.opts is not None:
        cfg = config.merge_cfg_from_list(cfg, args.opts)
    return cfg


def get_logger():
    """Return the shared ``main-logger`` logger, configured at INFO level.

    The stream handler is attached only once: ``logging.getLogger`` hands
    back the same object on every call, so adding a handler unconditionally
    would print each record once per call made to this function.
    """
    logger = logging.getLogger("main-logger")
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        handler = logging.StreamHandler()
        fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
        handler.setFormatter(logging.Formatter(fmt))
        logger.addHandler(handler)
    return logger


def worker_init_fn(worker_id):
    """Seed Python's ``random`` module with ``manual_seed + worker_id``.

    Does nothing when ``args.manual_seed`` is None (``main`` explicitly
    allows that); the previous version raised ``TypeError`` in that case.
    """
    if args.manual_seed is None:
        return
    random.seed(args.manual_seed + worker_id)


def main_process():
    """Tell whether this process is the one that logs and saves.

    True when multiprocessing-distributed training is off, otherwise only
    for ranks that are a multiple of ``args.ngpus_per_node``.
    """
    return not args.multiprocessing_distributed or args.rank % args.ngpus_per_node == 0
def main_worker(gpu, ngpus_per_node, argss):
    """Build model, optimizer and data loaders, then run the training loop.

    Args:
        gpu: device handed to ``torch.cuda.set_device`` and used as DDP
            ``device_ids`` entry; only read in the distributed branch.
        ngpus_per_node: number of GPUs used to split batch sizes and workers.
        argss: configuration object; stored in the module-level ``args``.

    Side effects:
        Sets the globals ``args``, ``best_iou`` and, on the main process,
        ``logger`` and ``writer``. Writes ``model/model_last.pth`` (and
        ``model/model_best.pth`` on improvement) below ``args.save_path``.

    Raises:
        Exception: if ``args.arch`` is not a supported architecture.
    """
    global args, best_iou, logger, writer
    args, best_iou = argss, 0
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank)

    if args.arch == 'pointtransformer_seg_repro':
        from model.pointtransformer.pointtransformer_seg import pointtransformer_seg_repro as Model
    else:
        # The original message had no placeholder, so the name was dropped.
        raise Exception('architecture not supported yet: {}'.format(args.arch))
    model = Model(c=args.fea_dim, k=args.classes)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda()

    optimizer = torch.optim.SGD(model.parameters(), lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay)
    # Learning rate drops by 10x at 60% and at 80% of the epochs.
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[int(args.epochs*0.6), int(args.epochs*0.8)], gamma=0.1)

    if main_process():
        logger = get_logger()
        writer = SummaryWriter(args.save_path)
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
    if args.distributed:
        torch.cuda.set_device(gpu)
        # Configured sizes are totals; each process gets its share.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(
            model.cuda(),
            device_ids=[gpu],
            find_unused_parameters="transformer" in args.arch
        )
    else:
        model = torch.nn.DataParallel(model.cuda())

    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight '{}'".format(args.weight))
            # Same map_location as the resume branch below, so tensors land
            # on this process's current device rather than the saved one.
            checkpoint = torch.load(args.weight, map_location=lambda storage, loc: storage.cuda())
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        elif main_process():
            # `logger` only exists on the main process; the unguarded call
            # raised NameError on every other rank.
            logger.info("=> no weight found at '{}'".format(args.weight))

    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            best_iou = checkpoint['best_iou']
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        elif main_process():
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    train_transform = t.Compose([t.RandomScale([0.9, 1.1]), t.ChromaticAutoContrast(), t.ChromaticTranslation(), t.ChromaticJitter(), t.HueSaturationTranslation()])
    train_data = S3DIS(split='train', data_root=args.data_root, test_area=args.test_area, voxel_size=args.voxel_size, voxel_max=args.voxel_max, transform=train_transform, shuffle_index=True, loop=args.loop)
    if main_process():
        logger.info("train_data samples: '{}'".format(len(train_data)))
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    else:
        train_sampler = None
    # DataLoader rejects shuffle=True together with a sampler.
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=(train_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True, collate_fn=collate_fn)

    val_loader = None
    if args.evaluate:
        val_transform = None
        val_data = S3DIS(split='val', data_root=args.data_root, test_area=args.test_area, voxel_size=args.voxel_size, voxel_max=800000, transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler, collate_fn=collate_fn)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(train_loader, model, criterion, optimizer, epoch)
        scheduler.step()
        epoch_log = epoch + 1  # 1-based epoch used for logs and checkpoints
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)

        is_best = False
        if args.evaluate and (epoch_log % args.eval_freq == 0):
            if args.data_name == 'shapenet':
                raise NotImplementedError()
            else:
                loss_val, mIoU_val, mAcc_val, allAcc_val = validate(val_loader, model, criterion)

            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
                is_best = mIoU_val > best_iou
                best_iou = max(best_iou, mIoU_val)

        if (epoch_log % args.save_freq == 0) and main_process():
            filename = args.save_path + '/model/model_last.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save({'epoch': epoch_log, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(), 'best_iou': best_iou, 'is_best': is_best}, filename)
            if is_best:
                logger.info('Best validation mIoU updated to: {:.4f}'.format(best_iou))
                shutil.copyfile(filename, args.save_path + '/model/model_best.pth')

    if main_process():
        writer.close()
        logger.info('==>Training done!\nBest Iou: %.3f' % (best_iou))
end = time.time() 236 | max_iter = args.epochs * len(train_loader) 237 | for i, (coord, feat, target, offset) in enumerate(train_loader): # (n, 3), (n, c), (n), (b) 238 | data_time.update(time.time() - end) 239 | coord, feat, target, offset = coord.cuda(non_blocking=True), feat.cuda(non_blocking=True), target.cuda(non_blocking=True), offset.cuda(non_blocking=True) 240 | output = model([coord, feat, offset]) 241 | if target.shape[-1] == 1: 242 | target = target[:, 0] # for cls 243 | loss = criterion(output, target) 244 | optimizer.zero_grad() 245 | loss.backward() 246 | optimizer.step() 247 | 248 | output = output.max(1)[1] 249 | n = coord.size(0) 250 | if args.multiprocessing_distributed: 251 | loss *= n 252 | count = target.new_tensor([n], dtype=torch.long) 253 | dist.all_reduce(loss), dist.all_reduce(count) 254 | n = count.item() 255 | loss /= n 256 | intersection, union, target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label) 257 | if args.multiprocessing_distributed: 258 | dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(target) 259 | intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy() 260 | intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target) 261 | 262 | accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10) 263 | loss_meter.update(loss.item(), n) 264 | batch_time.update(time.time() - end) 265 | end = time.time() 266 | 267 | # calculate remain time 268 | current_iter = epoch * len(train_loader) + i + 1 269 | remain_iter = max_iter - current_iter 270 | remain_time = remain_iter * batch_time.avg 271 | t_m, t_s = divmod(remain_time, 60) 272 | t_h, t_m = divmod(t_m, 60) 273 | remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s)) 274 | 275 | if (i + 1) % args.print_freq == 0 and main_process(): 276 | logger.info('Epoch: [{}/{}][{}/{}] ' 277 | 'Data {data_time.val:.3f} ({data_time.avg:.3f}) ' 
278 | 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 279 | 'Remain {remain_time} ' 280 | 'Loss {loss_meter.val:.4f} ' 281 | 'Accuracy {accuracy:.4f}.'.format(epoch+1, args.epochs, i + 1, len(train_loader), 282 | batch_time=batch_time, data_time=data_time, 283 | remain_time=remain_time, 284 | loss_meter=loss_meter, 285 | accuracy=accuracy)) 286 | if main_process(): 287 | writer.add_scalar('loss_train_batch', loss_meter.val, current_iter) 288 | writer.add_scalar('mIoU_train_batch', np.mean(intersection / (union + 1e-10)), current_iter) 289 | writer.add_scalar('mAcc_train_batch', np.mean(intersection / (target + 1e-10)), current_iter) 290 | writer.add_scalar('allAcc_train_batch', accuracy, current_iter) 291 | 292 | iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) 293 | accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10) 294 | mIoU = np.mean(iou_class) 295 | mAcc = np.mean(accuracy_class) 296 | allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10) 297 | if main_process(): 298 | logger.info('Train result at epoch [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(epoch+1, args.epochs, mIoU, mAcc, allAcc)) 299 | return loss_meter.avg, mIoU, mAcc, allAcc 300 | 301 | 302 | def validate(val_loader, model, criterion): 303 | if main_process(): 304 | logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>') 305 | batch_time = AverageMeter() 306 | data_time = AverageMeter() 307 | loss_meter = AverageMeter() 308 | intersection_meter = AverageMeter() 309 | union_meter = AverageMeter() 310 | target_meter = AverageMeter() 311 | 312 | model.eval() 313 | end = time.time() 314 | for i, (coord, feat, target, offset) in enumerate(val_loader): 315 | data_time.update(time.time() - end) 316 | coord, feat, target, offset = coord.cuda(non_blocking=True), feat.cuda(non_blocking=True), target.cuda(non_blocking=True), offset.cuda(non_blocking=True) 317 | if target.shape[-1] == 1: 318 | target = target[:, 0] # for cls 319 | with 
torch.no_grad(): 320 | output = model([coord, feat, offset]) 321 | loss = criterion(output, target) 322 | 323 | output = output.max(1)[1] 324 | n = coord.size(0) 325 | if args.multiprocessing_distributed: 326 | loss *= n 327 | count = target.new_tensor([n], dtype=torch.long) 328 | dist.all_reduce(loss), dist.all_reduce(count) 329 | n = count.item() 330 | loss /= n 331 | 332 | intersection, union, target = intersectionAndUnionGPU(output, target, args.classes, args.ignore_label) 333 | if args.multiprocessing_distributed: 334 | dist.all_reduce(intersection), dist.all_reduce(union), dist.all_reduce(target) 335 | intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy() 336 | intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target) 337 | 338 | accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10) 339 | loss_meter.update(loss.item(), n) 340 | batch_time.update(time.time() - end) 341 | end = time.time() 342 | if (i + 1) % args.print_freq == 0 and main_process(): 343 | logger.info('Test: [{}/{}] ' 344 | 'Data {data_time.val:.3f} ({data_time.avg:.3f}) ' 345 | 'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 346 | 'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f}) ' 347 | 'Accuracy {accuracy:.4f}.'.format(i + 1, len(val_loader), 348 | data_time=data_time, 349 | batch_time=batch_time, 350 | loss_meter=loss_meter, 351 | accuracy=accuracy)) 352 | 353 | iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) 354 | accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10) 355 | mIoU = np.mean(iou_class) 356 | mAcc = np.mean(accuracy_class) 357 | allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10) 358 | 359 | if main_process(): 360 | logger.info('Val result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(mIoU, mAcc, allAcc)) 361 | for i in range(args.classes): 362 | logger.info('Class_{} Result: iou/accuracy {:.4f}/{:.4f}.'.format(i, 
class AverageMeter(object):
    """Store the latest value plus the running sum, count and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as an observation with weight ``n``.

        ``val`` may be a number or an array; ``sum`` accumulates
        ``val * n`` and ``avg`` is ``sum / count``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # With a total weight of zero there is nothing to average; keep the
        # previous average instead of dividing by zero.
        if self.count:
            self.avg = self.sum / self.count


def intersectionAndUnion(output, target, K, ignore_index=255):
    """Per-class intersection, union and target counts for ``K`` classes.

    Args:
        output: predicted labels, shape N, N * L or N * H * W, each value
            in range 0 to K - 1.
        target: ground-truth labels of the same shape.
        K: number of classes.
        ignore_index: target value marking points excluded from all counts.

    Returns:
        ``(area_intersection, area_union, area_target)``, each an array of
        length ``K``. Neither input array is modified.
    """
    assert (output.ndim in [1, 2, 3])
    assert output.shape == target.shape
    output = output.reshape(output.size)
    target = target.reshape(target.size)
    # Drop ignored points explicitly. The earlier approach overwrote them
    # with ignore_index and relied on the histogram discarding that value,
    # which fails for ignore_index == K: np.histogram's last bin is closed
    # on the right, so those points were counted as class K - 1.
    keep = target != ignore_index
    output, target = output[keep], target[keep]
    bins = np.arange(K + 1)
    area_intersection, _ = np.histogram(output[output == target], bins=bins)
    area_output, _ = np.histogram(output, bins=bins)
    area_target, _ = np.histogram(target, bins=bins)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target
class CfgNode(dict):
    """
    CfgNode represents an internal node in the configuration tree. It's a simple
    dict-like container that allows for attribute-based access to keys.
    """

    def __init__(self, init_dict=None, key_list=None, new_allowed=False):
        """Build a node from ``init_dict``, wrapping nested dicts as CfgNodes.

        The conversion works on a fresh mapping, so ``init_dict`` and the
        dicts nested inside it are left untouched (the previous version
        replaced the caller's nested dicts in place). ``new_allowed`` is
        accepted for interface compatibility and not used.
        """
        init_dict = {} if init_dict is None else init_dict
        key_list = [] if key_list is None else key_list
        # Exact type check on purpose: values that already are CfgNodes
        # (a dict subclass) must not be wrapped a second time.
        converted = {
            k: CfgNode(v, key_list=key_list + [k]) if type(v) is dict else v
            for k, v in init_dict.items()
        }
        super(CfgNode, self).__init__(converted)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; fall back to keys.
        if name in self:
            return self[name]
        raise AttributeError(name)

    def __setattr__(self, name, value):
        # Attributes are stored as dictionary items.
        self[name] = value

    def __str__(self):
        def _indent(text, num_spaces):
            # Indent every line except the first by num_spaces spaces.
            first, *rest = text.split("\n")
            if not rest:
                return text
            pad = num_spaces * " "
            return first + "\n" + "\n".join(pad + line for line in rest)

        lines = []
        for k, v in sorted(self.items()):
            # Nested nodes start on their own line below the key.
            separator = "\n" if isinstance(v, CfgNode) else " "
            lines.append(_indent("{}:{}{}".format(str(k), separator, str(v)), 2))
        return "\n".join(lines)

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__())


def load_cfg_from_cfg_file(file):
    """Load a YAML config file into a flat ``CfgNode``.

    The file is expected to hold top-level sections that are mappings; the
    entries of all sections are merged into one namespace, with later
    sections overriding earlier ones on a key clash.

    Args:
        file: path to an existing file whose name ends in ``.yaml``.

    Returns:
        CfgNode with the merged options (empty for an empty file).
    """
    assert os.path.isfile(file) and file.endswith('.yaml'), \
        '{} is not a yaml file'.format(file)

    with open(file, 'r') as f:
        # safe_load returns None for an empty document; treat it as empty
        # instead of failing on the iteration below.
        cfg_from_file = yaml.safe_load(f) or {}

    cfg = {}
    for section in cfg_from_file.values():
        cfg.update(section)
    return CfgNode(cfg)
| value = _decode_cfg_value(v) 83 | value = _check_and_coerce_cfg_value_type( 84 | value, cfg[subkey], subkey, full_key 85 | ) 86 | setattr(new_cfg, subkey, value) 87 | 88 | return new_cfg 89 | 90 | 91 | def _decode_cfg_value(v): 92 | """Decodes a raw config value (e.g., from a yaml config files or command 93 | line argument) into a Python object. 94 | """ 95 | # All remaining processing is only applied to strings 96 | if not isinstance(v, str): 97 | return v 98 | # Try to interpret `v` as a: 99 | # string, number, tuple, list, dict, boolean, or None 100 | try: 101 | v = literal_eval(v) 102 | # The following two excepts allow v to pass through when it represents a 103 | # string. 104 | # 105 | # Longer explanation: 106 | # The type of v is always a string (before calling literal_eval), but 107 | # sometimes it *represents* a string and other times a data structure, like 108 | # a list. In the case that v represents a string, what we got back from the 109 | # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is 110 | # ok with '"foo"', but will raise a ValueError if given 'foo'. In other 111 | # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval 112 | # will raise a SyntaxError. 113 | except ValueError: 114 | pass 115 | except SyntaxError: 116 | pass 117 | return v 118 | 119 | 120 | def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): 121 | """Checks that `replacement`, which is intended to replace `original` is of 122 | the right type. The type is correct if it matches exactly or is one of a few 123 | cases in which the type can be easily coerced. 
124 | """ 125 | original_type = type(original) 126 | replacement_type = type(replacement) 127 | 128 | # The types must match (with some exceptions) 129 | if replacement_type == original_type or original is None: 130 | return replacement 131 | 132 | # Cast replacement from from_type to to_type if the replacement and original 133 | # types match from_type and to_type 134 | def conditional_cast(from_type, to_type): 135 | if replacement_type == from_type and original_type == to_type: 136 | return True, to_type(replacement) 137 | else: 138 | return False, None 139 | 140 | # Conditionally casts 141 | # list <-> tuple 142 | casts = [(tuple, list), (list, tuple)] 143 | # For py2: allow converting from str (bytes) to a unicode string 144 | try: 145 | casts.append((str, unicode)) # noqa: F821 146 | except Exception: 147 | pass 148 | 149 | for (from_type, to_type) in casts: 150 | converted, converted_value = conditional_cast(from_type, to_type) 151 | if converted: 152 | return converted_value 153 | 154 | raise ValueError( 155 | "Type mismatch ({} vs. {}) with values ({} vs. {}) for config " 156 | "key: {}".format( 157 | original_type, replacement_type, original, replacement, full_key 158 | ) 159 | ) 160 | -------------------------------------------------------------------------------- /util/data_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import SharedArray as SA 4 | 5 | import torch 6 | 7 | from util.voxelize import voxelize 8 | 9 | 10 | def sa_create(name, var): 11 | x = SA.create(name, var.shape, dtype=var.dtype) 12 | x[...] = var[...] 
def collate_fn(batch):
    """Concatenate a list of ``(coord, feat, label)`` samples into one batch.

    Returns:
        ``(coord, feat, label, offset)`` where the first three are the
        per-sample tensors concatenated along dim 0 and ``offset`` is an
        IntTensor holding the cumulative number of points after each sample.
    """
    coord, feat, label = list(zip(*batch))
    offset, count = [], 0
    for item in coord:
        count += item.shape[0]
        offset.append(count)
    return torch.cat(coord), torch.cat(feat), torch.cat(label), torch.IntTensor(offset)


def data_prepare(coord, feat, label, split='train', voxel_size=0.04, voxel_max=None, transform=None, shuffle_index=False):
    """Turn one raw sample into tensors.

    Steps, each optional: apply ``transform``; keep the points selected by
    ``voxelize``; crop to the ``voxel_max`` points nearest to a centre point
    (random when ``'train'`` is in ``split``, otherwise the middle index);
    shuffle the point order. Finally coordinates are shifted to a zero
    minimum and features are divided by 255.

    Args:
        coord, feat, label: numpy arrays indexed by point along axis 0.
        split: name of the split; only checked for containing ``'train'``.
        voxel_size: forwarded to ``voxelize``; falsy disables that step.
        voxel_max: maximum number of points kept; falsy disables cropping.
        transform: optional callable ``(coord, feat, label) -> same triple``.
        shuffle_index: whether to shuffle the points.

    Returns:
        ``(coord, feat, label)`` as FloatTensor, FloatTensor and LongTensor.
        The coordinate shifts build new arrays, so the array passed as
        ``coord`` is not modified by this function itself.
    """
    if transform:
        coord, feat, label = transform(coord, feat, label)
    if voxel_size:
        # Out-of-place shift: `coord -= ...` used to alter the caller's data.
        coord = coord - np.min(coord, 0)
        uniq_idx = voxelize(coord, voxel_size)
        coord, feat, label = coord[uniq_idx], feat[uniq_idx], label[uniq_idx]
    if voxel_max and label.shape[0] > voxel_max:
        init_idx = np.random.randint(label.shape[0]) if 'train' in split else label.shape[0] // 2
        # Indices of the voxel_max points closest to the chosen centre.
        crop_idx = np.argsort(np.sum(np.square(coord - coord[init_idx]), 1))[:voxel_max]
        coord, feat, label = coord[crop_idx], feat[crop_idx], label[crop_idx]
    if shuffle_index:
        shuf_idx = np.arange(coord.shape[0])
        np.random.shuffle(shuf_idx)
        coord, feat, label = coord[shuf_idx], feat[shuf_idx], label[shuf_idx]

    coord = coord - np.min(coord, 0)
    coord = torch.FloatTensor(coord)
    feat = torch.FloatTensor(feat) / 255.
    label = torch.LongTensor(label)
    return coord, feat, label
self.loop 38 | -------------------------------------------------------------------------------- /util/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 5 | 6 | class Compose(object): 7 | def __init__(self, transforms): 8 | self.transforms = transforms 9 | 10 | def __call__(self, coord, feat, label): 11 | for t in self.transforms: 12 | coord, feat, label = t(coord, feat, label) 13 | return coord, feat, label 14 | 15 | 16 | class ToTensor(object): 17 | def __call__(self, coord, feat, label): 18 | coord = torch.from_numpy(coord) 19 | if not isinstance(coord, torch.FloatTensor): 20 | coord = coord.float() 21 | feat = torch.from_numpy(feat) 22 | if not isinstance(feat, torch.FloatTensor): 23 | feat = feat.float() 24 | label = torch.from_numpy(label) 25 | if not isinstance(label, torch.LongTensor): 26 | label = label.long() 27 | return coord, feat, label 28 | 29 | 30 | class RandomRotate(object): 31 | def __init__(self, angle=[0, 0, 1]): 32 | self.angle = angle 33 | 34 | def __call__(self, coord, feat, label): 35 | angle_x = np.random.uniform(-self.angle[0], self.angle[0]) * np.pi 36 | angle_y = np.random.uniform(-self.angle[1], self.angle[1]) * np.pi 37 | angle_z = np.random.uniform(-self.angle[2], self.angle[2]) * np.pi 38 | cos_x, sin_x = np.cos(angle_x), np.sin(angle_x) 39 | cos_y, sin_y = np.cos(angle_y), np.sin(angle_y) 40 | cos_z, sin_z = np.cos(angle_z), np.sin(angle_z) 41 | R_x = np.array([[1, 0, 0], [0, cos_x, -sin_x], [0, sin_x, cos_x]]) 42 | R_y = np.array([[cos_y, 0, sin_y], [0, 1, 0], [-sin_y, 0, cos_y]]) 43 | R_z = np.array([[cos_z, -sin_z, 0], [sin_z, cos_z, 0], [0, 0, 1]]) 44 | R = np.dot(R_z, np.dot(R_y, R_x)) 45 | coord = np.dot(coord, np.transpose(R)) 46 | return coord, feat, label 47 | 48 | 49 | class RandomScale(object): 50 | def __init__(self, scale=[0.9, 1.1], anisotropic=False): 51 | self.scale = scale 52 | self.anisotropic = anisotropic 53 | 54 | def 
__call__(self, coord, feat, label): 55 | scale = np.random.uniform(self.scale[0], self.scale[1], 3 if self.anisotropic else 1) 56 | coord *= scale 57 | return coord, feat, label 58 | 59 | 60 | class RandomShift(object): 61 | def __init__(self, shift=[0.2, 0.2, 0]): 62 | self.shift = shift 63 | 64 | def __call__(self, coord, feat, label): 65 | shift_x = np.random.uniform(-self.shift[0], self.shift[0]) 66 | shift_y = np.random.uniform(-self.shift[1], self.shift[1]) 67 | shift_z = np.random.uniform(-self.shift[2], self.shift[2]) 68 | coord += [shift_x, shift_y, shift_z] 69 | return coord, feat, label 70 | 71 | 72 | class RandomFlip(object): 73 | def __init__(self, p=0.5): 74 | self.p = p 75 | 76 | def __call__(self, coord, feat, label): 77 | if np.random.rand() < self.p: 78 | coord[:, 0] = -coord[:, 0] 79 | if np.random.rand() < self.p: 80 | coord[:, 1] = -coord[:, 1] 81 | return coord, feat, label 82 | 83 | 84 | class RandomJitter(object): 85 | def __init__(self, sigma=0.01, clip=0.05): 86 | self.sigma = sigma 87 | self.clip = clip 88 | 89 | def __call__(self, coord, feat, label): 90 | assert (self.clip > 0) 91 | jitter = np.clip(self.sigma * np.random.randn(coord.shape[0], 3), -1 * self.clip, self.clip) 92 | coord += jitter 93 | return coord, feat, label 94 | 95 | 96 | class ChromaticAutoContrast(object): 97 | def __init__(self, p=0.2, blend_factor=None): 98 | self.p = p 99 | self.blend_factor = blend_factor 100 | 101 | def __call__(self, coord, feat, label): 102 | if np.random.rand() < self.p: 103 | lo = np.min(feat, 0, keepdims=True) 104 | hi = np.max(feat, 0, keepdims=True) 105 | scale = 255 / (hi - lo) 106 | contrast_feat = (feat[:, :3] - lo) * scale 107 | blend_factor = np.random.rand() if self.blend_factor is None else self.blend_factor 108 | feat[:, :3] = (1 - blend_factor) * feat[:, :3] + blend_factor * contrast_feat 109 | return coord, feat, label 110 | 111 | 112 | class ChromaticTranslation(object): 113 | def __init__(self, p=0.95, ratio=0.05): 114 | 
self.p = p 115 | self.ratio = ratio 116 | 117 | def __call__(self, coord, feat, label): 118 | if np.random.rand() < self.p: 119 | tr = (np.random.rand(1, 3) - 0.5) * 255 * 2 * self.ratio 120 | feat[:, :3] = np.clip(tr + feat[:, :3], 0, 255) 121 | return coord, feat, label 122 | 123 | 124 | class ChromaticJitter(object): 125 | def __init__(self, p=0.95, std=0.005): 126 | self.p = p 127 | self.std = std 128 | 129 | def __call__(self, coord, feat, label): 130 | if np.random.rand() < self.p: 131 | noise = np.random.randn(feat.shape[0], 3) 132 | noise *= self.std * 255 133 | feat[:, :3] = np.clip(noise + feat[:, :3], 0, 255) 134 | return coord, feat, label 135 | 136 | 137 | class HueSaturationTranslation(object): 138 | @staticmethod 139 | def rgb_to_hsv(rgb): 140 | # Translated from source of colorsys.rgb_to_hsv 141 | # r,g,b should be a numpy arrays with values between 0 and 255 142 | # rgb_to_hsv returns an array of floats between 0.0 and 1.0. 143 | rgb = rgb.astype('float') 144 | hsv = np.zeros_like(rgb) 145 | # in case an RGBA array was passed, just copy the A channel 146 | hsv[..., 3:] = rgb[..., 3:] 147 | r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2] 148 | maxc = np.max(rgb[..., :3], axis=-1) 149 | minc = np.min(rgb[..., :3], axis=-1) 150 | hsv[..., 2] = maxc 151 | mask = maxc != minc 152 | hsv[mask, 1] = (maxc - minc)[mask] / maxc[mask] 153 | rc = np.zeros_like(r) 154 | gc = np.zeros_like(g) 155 | bc = np.zeros_like(b) 156 | rc[mask] = (maxc - r)[mask] / (maxc - minc)[mask] 157 | gc[mask] = (maxc - g)[mask] / (maxc - minc)[mask] 158 | bc[mask] = (maxc - b)[mask] / (maxc - minc)[mask] 159 | hsv[..., 0] = np.select([r == maxc, g == maxc], [bc - gc, 2.0 + rc - bc], default=4.0 + gc - rc) 160 | hsv[..., 0] = (hsv[..., 0] / 6.0) % 1.0 161 | return hsv 162 | 163 | @staticmethod 164 | def hsv_to_rgb(hsv): 165 | # Translated from source of colorsys.hsv_to_rgb 166 | # h,s should be a numpy arrays with values between 0.0 and 1.0 167 | # v should be a numpy array with 
values between 0.0 and 255.0 168 | # hsv_to_rgb returns an array of uints between 0 and 255. 169 | rgb = np.empty_like(hsv) 170 | rgb[..., 3:] = hsv[..., 3:] 171 | h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2] 172 | i = (h * 6.0).astype('uint8') 173 | f = (h * 6.0) - i 174 | p = v * (1.0 - s) 175 | q = v * (1.0 - s * f) 176 | t = v * (1.0 - s * (1.0 - f)) 177 | i = i % 6 178 | conditions = [s == 0.0, i == 1, i == 2, i == 3, i == 4, i == 5] 179 | rgb[..., 0] = np.select(conditions, [v, q, p, p, t, v], default=v) 180 | rgb[..., 1] = np.select(conditions, [v, v, v, q, p, p], default=t) 181 | rgb[..., 2] = np.select(conditions, [v, p, t, v, v, q], default=p) 182 | return rgb.astype('uint8') 183 | 184 | def __init__(self, hue_max=0.5, saturation_max=0.2): 185 | self.hue_max = hue_max 186 | self.saturation_max = saturation_max 187 | 188 | def __call__(self, coord, feat, label): 189 | # Assume feat[:, :3] is rgb 190 | hsv = HueSaturationTranslation.rgb_to_hsv(feat[:, :3]) 191 | hue_val = (np.random.rand() - 0.5) * 2 * self.hue_max 192 | sat_ratio = 1 + (np.random.rand() - 0.5) * 2 * self.saturation_max 193 | hsv[..., 0] = np.remainder(hue_val + hsv[..., 0] + 1, 1) 194 | hsv[..., 1] = np.clip(sat_ratio * hsv[..., 1], 0, 1) 195 | feat[:, :3] = np.clip(HueSaturationTranslation.hsv_to_rgb(hsv), 0, 255) 196 | return coord, feat, label 197 | 198 | 199 | class RandomDropColor(object): 200 | def __init__(self, p=0.2): 201 | self.p = p 202 | 203 | def __call__(self, coord, feat, label): 204 | if np.random.rand() < self.p: 205 | feat[:, :3] = 0 206 | # feat[:, :3] = 127.5 207 | return coord, feat, label 208 | -------------------------------------------------------------------------------- /util/voxelize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def fnv_hash_vec(arr): 5 | """ 6 | FNV64-1A 7 | """ 8 | assert arr.ndim == 2 9 | # Floor first for negative coordinates 10 | arr = arr.copy() 11 | arr = 
arr.astype(np.uint64, copy=False) 12 | hashed_arr = np.uint64(14695981039346656037) * np.ones(arr.shape[0], dtype=np.uint64) 13 | for j in range(arr.shape[1]): 14 | hashed_arr *= np.uint64(1099511628211) 15 | hashed_arr = np.bitwise_xor(hashed_arr, arr[:, j]) 16 | return hashed_arr 17 | 18 | 19 | def ravel_hash_vec(arr): 20 | """ 21 | Ravel the coordinates after subtracting the min coordinates. 22 | """ 23 | assert arr.ndim == 2 24 | arr = arr.copy() 25 | arr -= arr.min(0) 26 | arr = arr.astype(np.uint64, copy=False) 27 | arr_max = arr.max(0).astype(np.uint64) + 1 28 | 29 | keys = np.zeros(arr.shape[0], dtype=np.uint64) 30 | # Fortran style indexing 31 | for j in range(arr.shape[1] - 1): 32 | keys += arr[:, j] 33 | keys *= arr_max[j + 1] 34 | keys += arr[:, -1] 35 | return keys 36 | 37 | 38 | def voxelize(coord, voxel_size=0.05, hash_type='fnv', mode=0): 39 | discrete_coord = np.floor(coord / np.array(voxel_size)) 40 | if hash_type == 'ravel': 41 | key = ravel_hash_vec(discrete_coord) 42 | else: 43 | key = fnv_hash_vec(discrete_coord) 44 | 45 | idx_sort = np.argsort(key) 46 | key_sort = key[idx_sort] 47 | _, count = np.unique(key_sort, return_counts=True) 48 | if mode == 0: # train mode 49 | idx_select = np.cumsum(np.insert(count, 0, 0)[0:-1]) + np.random.randint(0, count.max(), count.size) % count 50 | idx_unique = idx_sort[idx_select] 51 | return idx_unique 52 | else: # val mode 53 | return idx_sort, count 54 | --------------------------------------------------------------------------------