├── .gitignore ├── INSTALL.md ├── LICENSE ├── PartObjaverse-Tiny ├── PartObjaverse-Tiny.md ├── PartObjaverse-Tiny_semantic.json └── eval │ ├── eval_instance.py │ ├── eval_part.py │ └── eval_semantic.py ├── README.md ├── assets ├── partobjaverse-vis.png └── teaser.png ├── configs ├── _base_ │ └── default_runtime.py └── sampart3d │ └── sampart3d-trainmlp-render16views.py ├── launch ├── eval.py └── train.py ├── libs └── pointops │ ├── __init__.py │ ├── functions │ ├── __init__.py │ ├── aggregation.py │ ├── attention.py │ ├── grouping.py │ ├── interpolation.py │ ├── query.py │ ├── sampling.py │ ├── subtraction.py │ └── utils.py │ ├── setup.py │ └── src │ ├── __init__.py │ ├── aggregation │ ├── aggregation_cuda.cpp │ ├── aggregation_cuda_kernel.cu │ └── aggregation_cuda_kernel.h │ ├── attention │ ├── attention_cuda.cpp │ ├── attention_cuda_kernel.cu │ └── attention_cuda_kernel.h │ ├── ball_query │ ├── ball_query_cuda.cpp │ ├── ball_query_cuda_kernel.cu │ └── ball_query_cuda_kernel.h │ ├── cuda_utils.h │ ├── grouping │ ├── grouping_cuda.cpp │ ├── grouping_cuda_kernel.cu │ └── grouping_cuda_kernel.h │ ├── interpolation │ ├── interpolation_cuda.cpp │ ├── interpolation_cuda_kernel.cu │ └── interpolation_cuda_kernel.h │ ├── knn_query │ ├── knn_query_cuda.cpp │ ├── knn_query_cuda_kernel.cu │ └── knn_query_cuda_kernel.h │ ├── pointops_api.cpp │ ├── random_ball_query │ ├── random_ball_query_cuda.cpp │ ├── random_ball_query_cuda_kernel.cu │ └── random_ball_query_cuda_kernel.h │ ├── sampling │ ├── sampling_cuda.cpp │ ├── sampling_cuda_kernel.cu │ └── sampling_cuda_kernel.h │ └── subtraction │ ├── subtraction_cuda.cpp │ ├── subtraction_cuda_kernel.cu │ └── subtraction_cuda_kernel.h ├── pointcept ├── datasets │ ├── __init__.py │ ├── builder.py │ ├── dataset_render_16views.py │ ├── sampart3d_util.py │ ├── transform.py │ └── utils.py ├── engines │ ├── __init__.py │ ├── defaults.py │ ├── eval.py │ ├── hooks │ │ ├── __init__.py │ │ ├── builder.py │ │ ├── default.py │ │ ├── evaluator.py │ │ └── misc.py │ ├── launch.py │ └── train.py ├── models │ ├── PTv3Object.py │ ├── SAMPart3D.py │ ├── __init__.py │ ├── builder.py │ ├── modules.py │ └── utils │ │ ├── __init__.py │ │ ├── checkpoint.py │ │ ├── misc.py │ │ ├── serialization │ │ ├── __init__.py │ │ ├── default.py │ │ ├── hilbert.py │ │ └── z_order.py │ │ └── structure.py └── utils │ ├── __init__.py │ ├── cache.py │ ├── comm.py │ ├── config.py │ ├── env.py │ ├── events.py │ ├── logger.py │ ├── misc.py │ ├── optimizer.py │ ├── path.py │ ├── registry.py │ ├── scheduler.py │ ├── timer.py │ └── visualization.py ├── requirements.txt ├── scripts ├── eval.sh └── train.sh └── tools ├── blender_render_16views.py └── highlight_parts.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | **/build/ 3 | **/*.egg-info/ 4 | **/dist/ 5 | *.so 6 | exp 7 | weights 8 | data 9 | log 10 | ckpt 11 | outputs/ 12 | .vscode 13 | .idea 14 | */.DS_Store 15 | **/*.out 16 | Dockerfile 17 | debug_pcd -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | ### Installation 2 | 3 | We test our model on a 24G RTX4090 GPU with Python 3.10, CUDA 12.1 and Pytorch 2.1.0. 4 | 5 | 1. 
Install the basic modules: PyTorch and the packages in requirements.txt.
6 | ```bash
7 | conda create -n sampart3d python=3.10
8 | conda activate sampart3d
9 | pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
10 | pip install -r requirements.txt
11 | ```
12 | 
13 | 2. Install modules for PTv3-object
14 | ```bash
15 | cd libs/pointops
16 | python setup.py install
17 | cd ../..
18 | 
19 | # spconv (SparseUNet)
20 | # refer to https://github.com/traveller59/spconv
21 | pip install spconv-cu120  # choose the version that matches your local CUDA
22 | ```
23 | 
24 | Follow the [README](https://github.com/Dao-AILab/flash-attention) of the Flash Attention repo to install Flash Attention for PTv3-object.
25 | 
26 | 
27 | 3. Install modules for acceleration (required by the current version of the code)
28 | ```bash
29 | pip install ninja git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
30 | 
31 | # GPU-based HDBSCAN clustering algorithm
32 | # refer to https://docs.rapids.ai/install
33 | pip install --extra-index-url=https://pypi.nvidia.com cudf-cu11==24.6.* cuml-cu11==24.6.*
34 | ```
35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License
2 | 
3 | Copyright (c) 2024 Pointcept
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /PartObjaverse-Tiny/PartObjaverse-Tiny.md: -------------------------------------------------------------------------------- 1 | ## PartObjaverse-Tiny
2 | 
3 | [PartObjaverse-Tiny](https://huggingface.co/datasets/yhyang-myron/PartObjaverse-Tiny) is a 3D part segmentation dataset that provides detailed semantic-level and instance-level part annotations for 200 complex 3D objects from [Objaverse](https://objaverse.allenai.org/). Following [GObjaverse](https://aigc3d.github.io/gobjaverse/), we divide these 200 objects into 8 categories: Human-Shape (29), Animals (23), Daily-Used (25), Buildings&&Outdoor (25), Transportations (38), Plants (18), Food (8) and Electronics (34).
4 | 
5 | PartObjaverse-Tiny can be downloaded [here](https://huggingface.co/datasets/yhyang-myron/PartObjaverse-Tiny).
6 | 
7 | ![](../assets/partobjaverse-vis.png)
8 | 
9 | ### File Description
10 | - **PartObjaverse-Tiny_mesh**: 200 meshes in glb format, named by uid.
11 | - **PartObjaverse-Tiny_semantic.json**: Label set for each mesh.
12 | - **PartObjaverse-Tiny_semantic_gt**: The ground truth of the semantic segmentation task. Each GT number corresponds, in order, to a label in the label set.
13 | - **PartObjaverse-Tiny_instance_gt**: The ground truth of the instance segmentation task. Each GT number corresponds to an instance and does not encode semantics.
14 | 
15 | ### Usage
16 | ```bash
17 | pip install trimesh
18 | ```
19 | ```python
20 | mesh = trimesh.load(${MESH_PATH})
21 | if isinstance(mesh, trimesh.Scene):
22 |     mesh = mesh.dump(concatenate=True)
23 | mesh_faces = mesh.faces  # `faces` is a property, not a method
24 | ```
25 | The face order in the GT is consistent with the order of ``mesh_faces``.
26 | 
27 | Evaluation code is in ``SAMPart3D/PartObjaverse-Tiny/eval/``.
-------------------------------------------------------------------------------- /PartObjaverse-Tiny/eval/eval_instance.py: -------------------------------------------------------------------------------- 1 | import numpy as np
2 | import json
3 | from os.path import join
4 | from typing import List
5 | 
6 | 
7 | def compute_ap(tp, fp, gt_npos, n_bins=100):
8 |     assert len(tp) == len(fp), 'ERROR: the length of true_pos and false_pos is not the same!'
9 | 
10 |     tp = np.cumsum(tp)
11 |     fp = np.cumsum(fp)
12 | 
13 |     rec = tp / gt_npos
14 |     prec = tp / (fp + tp)
15 | 
16 |     rec = np.insert(rec, 0, 0.0)
17 |     prec = np.insert(prec, 0, 1.0)
18 | 
19 |     ap = 0.
20 |     delta = 1.0 / n_bins
21 | 
22 |     out_rec = np.arange(0, 1 + delta, delta)
23 |     out_prec = np.zeros((n_bins+1), dtype=np.float32)
24 | 
25 |     for idx, t in enumerate(out_rec):
26 |         prec1 = prec[rec >= t]
27 |         if len(prec1) == 0:
28 |             p = 0.
29 |         else:
30 |             p = max(prec1)
31 | 
32 |         out_prec[idx] = p
33 |         ap = ap + p / (n_bins + 1)
34 | 
35 |     return ap
36 | 
37 | 
38 | def eval_per_shape_mean_ap(
39 |     part_name_list: List[str],  # The name list of the shape
40 |     pred_sem: np.ndarray,  # Predicted semantic labels, continuous natural numbers, each number is the index of the part_name_list
41 |     pred_ins: np.ndarray,  # Predicted instance labels, continuous natural numbers, each number is the index of the instance (without semantic)
42 |     gt_sem: np.ndarray,  # Ground truth semantic labels
43 |     gt_ins: np.ndarray,  # Ground truth instance labels
44 |     iou_threshold: float = 0.5
45 | ):  # returns (aps, ap_valids, gt_npos, mean_ap)
46 | 
47 |     assert len(pred_sem) == len(pred_ins) == len(gt_sem) == len(gt_ins)
48 | 
49 |     gt_n_ins = gt_ins.max() + 1
50 |     pred_n_ins = pred_ins.max() + 1
51 |     n_labels = len(part_name_list)
52 | 
53 |     true_pos_list = [[] for _ in part_name_list]
54 |     false_pos_list = [[] for _ in part_name_list]
55 |     gt_npos = np.zeros((n_labels), dtype=np.int32)
56 | 
57 |     mapping_insid_to_semid_gt = {}
58 |     for i in range(gt_n_ins):
59 |         sem_id = gt_sem[gt_ins == i][0]
60 |         mapping_insid_to_semid_gt[i] = sem_id
61 |     mapping_insid_to_semid_pred = {}
62 |     for i in range(pred_n_ins):
63 |         sem_id = pred_sem[pred_ins == i][0]
64 |         mapping_insid_to_semid_pred[i] = sem_id
65 | 
66 |     # classify all gt masks by part categories
67 |     gt_mask_per_cat = [[] for _ in part_name_list]
68 |     for i in range(gt_n_ins):
69 |         sem_id = mapping_insid_to_semid_gt[i]
70 |         gt_mask_per_cat[sem_id].append(i)
71 |         gt_npos[sem_id] += 1
72 | 
73 |     gt_used = np.zeros((gt_n_ins), dtype=np.bool_)
74 | 
75 |     # enumerate all pred parts
76 |     for idx in range(pred_n_ins):
77 |         sem_id = mapping_insid_to_semid_pred[idx]
78 | 
79 |         iou_max = 0.0
80 |         cor_gt_id = -1
81 |         for j in gt_mask_per_cat[sem_id]:
82 |             if not gt_used[j]:
83 |                 intersect = np.sum((gt_ins == j)
& (pred_ins == idx)) 84 | union = np.sum((gt_ins == j) | (pred_ins == idx)) 85 | iou = intersect * 1.0 / union 86 | 87 | if iou > iou_max: 88 | iou_max = iou 89 | cor_gt_id = j 90 | 91 | if iou_max > iou_threshold: 92 | gt_used[cor_gt_id] = True 93 | 94 | # add in a true positive 95 | true_pos_list[sem_id].append(True) 96 | false_pos_list[sem_id].append(False) 97 | else: 98 | # add in a false positive 99 | true_pos_list[sem_id].append(False) 100 | false_pos_list[sem_id].append(True) 101 | 102 | # compute per-part-category AP 103 | aps = np.zeros((n_labels), dtype=np.float32) 104 | ap_valids = np.ones((n_labels), dtype=bool) 105 | for i in range(n_labels): 106 | has_pred = (len(true_pos_list[i]) > 0) 107 | has_gt = (gt_npos[i] > 0) 108 | 109 | if not has_gt: 110 | ap_valids[i] = False 111 | continue 112 | 113 | if has_gt and not has_pred: 114 | continue 115 | 116 | true_pos = np.array(true_pos_list[i], dtype=np.float32) 117 | false_pos = np.array(false_pos_list[i], dtype=np.float32) 118 | 119 | aps[i] = compute_ap(true_pos, false_pos, gt_npos[i]) 120 | 121 | # compute mean AP 122 | mean_ap = np.sum(aps * ap_valids) / np.sum(ap_valids) 123 | 124 | return aps, ap_valids, gt_npos, mean_ap 125 | 126 | 127 | def eval_all_shape_mean_ap(meta_path, pred_sem_path, pred_ins_path, gt_sem_path, gt_ins_path, iou_threshold=0.5): 128 | meta_data = json.load(open(meta_path, 'r')) 129 | total_mean_ap = [] 130 | categories_list = ["Human-Shape", "Animals", "Daily-Used", "Buildings&&Outdoor", "Transportations", "Plants", "Food", "Electronics"] 131 | cate_mAP = {} 132 | for cate in categories_list: 133 | cate_mAP[cate] = [] 134 | 135 | for cate in meta_data.keys(): 136 | for uid in meta_data[cate]: 137 | print(f"Evaluating {uid}") 138 | part_name_list = meta_data[cate][uid] 139 | pred_sem = np.load(join(pred_sem_path, f"{uid}.npy")) 140 | pred_ins = np.load(join(pred_ins_path, f"{uid}.npy")) 141 | gt_sem = np.load(join(gt_sem_path, f"{uid}.npy")) 142 | gt_ins = np.load(join(gt_ins_path, f"{uid}.npy")) 143 | 144 | aps, ap_valids, gt_npos, mean_ap = eval_per_shape_mean_ap(part_name_list, pred_sem, pred_ins, gt_sem, gt_ins, iou_threshold) 145 | total_mean_ap.append(mean_ap) 146 | print(f"Mean AP: {mean_ap}") 147 | with open("eval_ins_results.txt", "a") as f: 148 | f.write(f"{uid}: {mean_ap}\n") 149 | cate_mAP[cate].append(mean_ap) 150 | 151 | for cate in categories_list: 152 | print(f"{cate} Mean AP: {np.mean(cate_mAP[cate])}") 153 | with open("eval_ins_results.txt", "a") as f: 154 | f.write(f"{cate} Mean AP: {np.mean(cate_mAP[cate])}\n") 155 | 156 | total_mean_ap = np.mean(total_mean_ap) 157 | print(f"Total Mean AP: {total_mean_ap}") 158 | with open("eval_ins_results.txt", "a") as f: 159 | f.write(f"Total Mean AP: {total_mean_ap}\n") 160 | 161 | 162 | if __name__ == '__main__': 163 | meta_path = "" 164 | pred_sem_path = "" 165 | pred_ins_path = "" 166 | gt_sem_path = "" 167 | gt_ins_path = "" 168 | eval_all_shape_mean_ap(meta_path, pred_sem_path, pred_ins_path, gt_sem_path, gt_ins_path, iou_threshold=0.5) -------------------------------------------------------------------------------- /PartObjaverse-Tiny/eval/eval_part.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | from os.path import join 4 | from typing import List 5 | 6 | 7 | def compute_iou(pred, gt): 8 | intersection = np.logical_and(pred, gt).sum() 9 | union = np.logical_or(pred, gt).sum() 10 | if union != 0: 11 | return (intersection / union) * 100 12 | else: 13 | return 0 14 
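# --- Added illustrative note (not part of the original file) ---
# compute_iou above returns the IoU as a percentage in [0, 100]. A quick
# sanity check on toy boolean masks:
#   pred = np.array([True, True, False, False])
#   gt = np.array([True, False, True, False])
#   compute_iou(pred, gt)  # intersection = 1, union = 3 -> ~33.33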
| 15 | 16 | def eval_per_shape_part_mean_iou( 17 | pred_ins: np.ndarray, # Predicted instance labels, continuous natural numbers, each number is the index of the instance (without semantic) 18 | gt_ins: np.ndarray, # Ground truth instance labels 19 | ) -> float: 20 | 21 | ious = [] 22 | if gt_ins.max() == -1: 23 | return -1 24 | for gt_id in np.unique(gt_ins): 25 | if gt_id == -1: # Ignore the unassigned group 26 | continue 27 | gt_group = gt_ins == gt_id 28 | best_iou = 0 29 | for pred_id in np.unique(pred_ins): 30 | if pred_id == -1: # Ignore the unassigned group 31 | continue 32 | pred_group = pred_ins == pred_id 33 | iou = compute_iou(pred_group, gt_group) 34 | if iou > best_iou: 35 | best_iou = iou 36 | ious.append(best_iou) 37 | return np.mean(ious) 38 | 39 | 40 | def eval_all_shape_part_mean_iou(meta_path, pred_ins_path, gt_ins_path): 41 | 42 | meta_data = json.load(open(meta_path, 'r')) 43 | total_miou = [] 44 | categories_list = ["Human-Shape", "Animals", "Daily-Used", "Buildings&&Outdoor", "Transportations", "Plants", "Food", "Electronics"] 45 | cate_miou = {} 46 | for cate in categories_list: 47 | cate_miou[cate] = [] 48 | 49 | for cate in meta_data.keys(): 50 | for uid in meta_data[cate]: 51 | print(f"Evaluating {uid}") 52 | pred_ins = np.load(join(pred_ins_path, f"{uid}.npy")) 53 | gt_ins = np.load(join(gt_ins_path, f"{uid}.npy")) 54 | obj_iou = eval_per_shape_part_mean_iou(pred_ins, gt_ins) 55 | 56 | total_miou.append(obj_iou) 57 | print(f"miou: {obj_iou}") 58 | with open("eval_partseg_results.txt", "a") as f: 59 | f.write(f"{uid}: {obj_iou}\n") 60 | cate_miou[cate].append(obj_iou) 61 | 62 | for cate in categories_list: 63 | print(f"{cate} miou: {np.mean(cate_miou[cate])}") 64 | with open("eval_partseg_results.txt", "a") as f: 65 | f.write(f"{cate} miou: {np.mean(cate_miou[cate])}\n") 66 | 67 | total_miou = np.mean(total_miou) 68 | print(f"Total miou: {total_miou}") 69 | with open("eval_partseg_results.txt", "a") as f: 70 | f.write(f"Total miou: {total_miou}\n") 71 | 72 | 73 | if __name__ == '__main__': 74 | meta_path = "" 75 | pred_ins_path = "" 76 | gt_ins_path = "" 77 | eval_all_shape_part_mean_iou(meta_path, pred_ins_path, gt_ins_path) -------------------------------------------------------------------------------- /PartObjaverse-Tiny/eval/eval_semantic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | from os.path import join 4 | from typing import List 5 | 6 | 7 | def compute_iou(pred, gt): 8 | intersection = np.logical_and(pred, gt).sum() 9 | union = np.logical_or(pred, gt).sum() 10 | if union != 0: 11 | return (intersection / union) * 100 12 | else: 13 | return 0 14 | 15 | 16 | def eval_per_shape_mean_iou( 17 | part_name_list: List[str], # The name list of the shape 18 | pred_sem: np.ndarray, # Predicted semantic labels, continuous natural numbers, each number is the index of the part_name_list 19 | gt_sem: np.ndarray, # Ground truth semantic labels 20 | ) -> float: 21 | 22 | obj_iou = [] 23 | for i in range(len(part_name_list)): 24 | if (gt_sem == i).sum() == 0: 25 | continue 26 | obj_iou.append(compute_iou(pred_sem == i, gt_sem == i)) 27 | iou = np.mean(obj_iou) 28 | 29 | return iou 30 | 31 | 32 | def eval_all_shape_mean_iou(meta_path, pred_sem_path, gt_sem_path): 33 | 34 | meta_data = json.load(open(meta_path, 'r')) 35 | total_miou = [] 36 | categories_list = ["Human-Shape", "Animals", "Daily-Used", "Buildings&&Outdoor", "Transportations", "Plants", "Food", "Electronics"] 37 | cate_miou = 
{}
38 |     for cate in categories_list:
39 |         cate_miou[cate] = []
40 | 
41 |     for cate in meta_data.keys():
42 |         for uid in meta_data[cate]:
43 |             print(f"Evaluating {uid}")
44 |             part_name_list = meta_data[cate][uid]
45 |             pred_sem = np.load(join(pred_sem_path, f"{uid}.npy"))
46 |             gt_sem = np.load(join(gt_sem_path, f"{uid}.npy"))
47 |             obj_iou = eval_per_shape_mean_iou(part_name_list, pred_sem, gt_sem)
48 | 
49 |             total_miou.append(obj_iou)
50 |             print(f"miou: {obj_iou}")
51 |             with open("eval_sem_results.txt", "a") as f:
52 |                 f.write(f"{uid}: {obj_iou}\n")
53 |             cate_miou[cate].append(obj_iou)
54 | 
55 |     for cate in categories_list:
56 |         print(f"{cate} miou: {np.mean(cate_miou[cate])}")
57 |         with open("eval_sem_results.txt", "a") as f:
58 |             f.write(f"{cate} miou: {np.mean(cate_miou[cate])}\n")
59 | 
60 |     total_miou = np.mean(total_miou)
61 |     print(f"Total miou: {total_miou}")
62 |     with open("eval_sem_results.txt", "a") as f:
63 |         f.write(f"Total miou: {total_miou}\n")
64 | 
65 | 
66 | if __name__ == '__main__':
67 |     meta_path = ""
68 |     pred_sem_path = ""
69 |     gt_sem_path = ""
70 |     eval_all_shape_mean_iou(meta_path, pred_sem_path, gt_sem_path)
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SAMPart3D: Segment Any Part in 3D Objects
2 | 
3 | ## [Project Page](https://yhyang-myron.github.io/SAMPart3D-website/) | [Paper](https://arxiv.org/abs/2411.07184) | [Dataset: PartObjaverse-Tiny](PartObjaverse-Tiny/PartObjaverse-Tiny.md)
4 | 
5 | ![](assets/teaser.png)
6 | 
7 | ## 🔧 Setup
8 | 
9 | ### Installation
10 | Please refer to [INSTALL.md](INSTALL.md).
11 | 
12 | ### Preparation for training
13 | 
14 | 1. Download the pretrained PTv3-object weights from [https://huggingface.co/yhyang-myron/SAMPart3D/tree/main](https://huggingface.co/yhyang-myron/SAMPart3D/tree/main).
15 | 
16 | 2. Data preprocessing
17 | 
18 | We use Blender to render multi-view RGB and depth images of the 3D glb mesh. First, install Blender:
19 | ```bash
20 | wget https://download.blender.org/release/Blender4.0/blender-4.0.0-linux-x64.tar.xz
21 | tar -xf blender-4.0.0-linux-x64.tar.xz
22 | rm blender-4.0.0-linux-x64.tar.xz
23 | ```
24 | Then render RGB and depth:
25 | ```bash
26 | cd tools
27 | ${PATH_TO_BLENDER} -b -P blender_render_16views.py ${MESH_PATH} ${TYPES} ${OUTPUT_PATH}
28 | ```
29 | For example:
30 | ```bash
31 | blender-4.0.0-linux-x64/blender -b -P blender_render_16views.py mesh_root/knight.glb glb data_root/knight
32 | ```
33 | 
34 | ## 🚀 Running SAMPart3D
35 | ### 1. Train
36 | Change the rendering **data_root**, **mesh_root** and **backbone_weight_path** in `configs/sampart3d/sampart3d-trainmlp-render16views.py`.
37 | ```bash
38 | SAMPart3D
39 | |-- ckpt
40 |     |-- ptv3-object.pth
41 | |-- mesh_root
42 |     |-- knight.glb
43 | |-- data_root
44 |     |-- knight
45 |         |-- meta.json
46 |         |-- render_0000.webp
47 |         |-- depth_0000.exr
48 |         ...
49 | ```
50 | 
51 | ```bash
52 | export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}
53 | sh scripts/train.sh -g ${NUM_GPU} -d ${DATASET_NAME} -c ${CONFIG_NAME} -n ${EXP_NAME} -o ${OBJECT_UID}
54 | ```
55 | For example:
56 | ```bash
57 | sh scripts/train.sh -g 1 -d sampart3d -c sampart3d-trainmlp-render16views -n knight -o knight
58 | ```
59 | 
60 | The mesh segmentation results will be saved in `exp/${DATASET_NAME}/${EXP_NAME}/results`, and visualizations of the point clouds and meshes will be saved in `exp/${DATASET_NAME}/${EXP_NAME}/vis_pcd/`. The saved results can be inspected directly, as shown in the sketch below.
61 | 
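The segmentation for each scale is a plain NumPy array with one label per mesh face. A minimal inspection sketch, assuming a hypothetical result file name (check `exp/sampart3d/knight/results/` for the names actually produced by your run):

```python
import numpy as np
import trimesh

# Load the mesh that was segmented.
mesh = trimesh.load("mesh_root/knight.glb")
if isinstance(mesh, trimesh.Scene):
    mesh = mesh.dump(concatenate=True)

# Hypothetical file name for one scale's result.
face_labels = np.load("exp/sampart3d/knight/results/knight_0.5.npy")
assert len(face_labels) == len(mesh.faces)  # one label per face
print("number of parts:", len(np.unique(face_labels)))
```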
62 | ### 2. Test more scales with pretrained MLPs
63 | After training, the checkpoint for the target mesh is saved in `exp/${DATASET_NAME}/${EXP_NAME}/model/`. To try more scales, load this weight directly and modify **val_scales_list** in `exp/${DATASET_NAME}/${EXP_NAME}/config.py`.
64 | 
65 | ```bash
66 | export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}
67 | sh scripts/eval.sh -g ${NUM_GPU} -d ${DATASET_NAME} -n ${EXP_NAME} -w ${WEIGHT_NAME}
68 | ```
69 | For example:
70 | ```bash
71 | sh scripts/eval.sh -g 1 -d sampart3d -n knight -w 5000
72 | ```
73 | 
74 | ### 3. Highlight 3D segments on multi-view renderings
75 | Set **render_dir**, **mesh_path**, **results_dir** and **save_dir** in `tools/highlight_parts.py`.
76 | ```bash
77 | python tools/highlight_parts.py
78 | ```
79 | 
80 | ## 📚 Dataset: PartObjaverse-Tiny
81 | Please refer to [PartObjaverse-Tiny.md](PartObjaverse-Tiny/PartObjaverse-Tiny.md).
82 | 
83 | ## Acknowledgement
84 | SAMPart3D is inspired by the following repos: [garfield](https://github.com/chungmin99/garfield), [PointTransformerV3](https://github.com/Pointcept/PointTransformerV3), [Pointcept](https://github.com/Pointcept/Pointcept), [FeatUp](https://github.com/mhamilton723/FeatUp), [dinov2](https://github.com/facebookresearch/dinov2), [segment-anything](https://github.com/facebookresearch/segment-anything), [PartSLIP2](https://github.com/zyc00/PartSLIP2).
85 | 
86 | Many thanks to the authors for sharing their code.
87 | 
88 | ## Citation
89 | If you find SAMPart3D useful in your project, please cite our work. :)
90 | ```
91 | @article{yang2024sampart3d,
92 |     title={SAMPart3D: Segment Any Part in 3D Objects},
93 |     author={Yang, Yunhan and Huang, Yukun and Guo, Yuan-Chen and Lu, Liangjun and Wu, Xiaoyang and Lam, Edmund Y and Cao, Yan-Pei and Liu, Xihui},
94 |     journal={arXiv preprint arXiv:2411.07184},
95 |     year={2024}
96 | }
97 | ```
98 | 
-------------------------------------------------------------------------------- /assets/partobjaverse-vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pointcept/SAMPart3D/6a508d145be5cecd7e80bce6e4248ea7abbd71b7/assets/partobjaverse-vis.png -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pointcept/SAMPart3D/6a508d145be5cecd7e80bce6e4248ea7abbd71b7/assets/teaser.png -------------------------------------------------------------------------------- /configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | weight = None  # path to model weight
2 | resume = False  # whether to resume the training process
3 | evaluate = True  # evaluate after each training epoch
4 | test_only = False  # test process only
5 | 
6 | seed = None  # the training process will init a random seed and record it
7 | save_path = "exp/default"
8 | num_worker = 16  # total workers across all GPUs
9 | batch_size = 16  # total batch size across all GPUs
10 | batch_size_val = None  # auto adapt to bs 1 for each GPU
11 | batch_size_test = None  # auto adapt to bs 1 for each GPU
12 | epoch = 100  # total epochs; data loop = epoch // eval_epoch
13 | eval_epoch = 100  # scheduled total eval & checkpoint epochs
14 | 
15 | sync_bn = False
16 | enable_amp = False
17 | empty_cache = False
18 | find_unused_parameters = False
19 | 
20 | mix_prob = 0
21 | param_dicts = None  # example: param_dicts =
[dict(keyword="block", lr_scale=0.1)] 22 | 23 | # hook 24 | hooks = [ 25 | dict(type="CheckpointLoader"), 26 | dict(type="IterationTimer", warmup_iter=2), 27 | dict(type="InformationWriter"), 28 | dict(type="SemSegEvaluator"), 29 | dict(type="CheckpointSaver", save_freq=None), 30 | dict(type="PreciseEvaluator", test_last=False), 31 | ] 32 | 33 | # Trainer 34 | train = dict(type="DefaultTrainer") 35 | 36 | # Tester 37 | test = dict(type="SemSegTester", verbose=True) 38 | -------------------------------------------------------------------------------- /configs/sampart3d/sampart3d-trainmlp-render16views.py: -------------------------------------------------------------------------------- 1 | _base_ = ["../_base_/default_runtime.py"] 2 | 3 | # misc custom setting 4 | batch_size = 1 # dummy value 5 | num_worker = 1 # dummy value 6 | mix_prob = 0 # dummy value 7 | empty_cache = False 8 | enable_amp = True 9 | 10 | # hook 11 | hooks = [ 12 | dict(type="CheckpointLoader"), 13 | dict(type="IterationTimer", warmup_iter=2), 14 | dict(type="InformationWriter"), 15 | dict(type="CheckpointSaver", save_freq=None), 16 | ] 17 | 18 | # model settings 19 | model = dict( 20 | type="SAMPart3D", 21 | backbone_dim=384, 22 | output_dim=384, 23 | pcd_feat_dim=9, 24 | freeze_backbone=True, 25 | max_grouping_scale=2, 26 | use_hierarchy_losses=True, 27 | backbone = dict( 28 | type="PTv3-obj", 29 | in_channels=9, 30 | order=["z", "z-trans", "hilbert", "hilbert-trans"], 31 | stride=(), 32 | enc_depths=(3, 3, 3, 6, 16), 33 | enc_channels=(32, 64, 128, 256, 384), 34 | enc_num_head=(2, 4, 8, 16, 24), 35 | enc_patch_size=(1024, 1024, 1024, 1024, 1024), 36 | mlp_ratio=4, 37 | qkv_bias=True, 38 | qk_scale=None, 39 | attn_drop=0.0, 40 | proj_drop=0.0, 41 | drop_path=0.0, 42 | shuffle_orders=False, 43 | pre_norm=True, 44 | enable_rpe=False, 45 | enable_flash=True, 46 | upcast_attention=False, 47 | upcast_softmax=False, 48 | cls_mode=False) 49 | ) 50 | 51 | # scheduler settings 52 | epoch = 5000 # can be smaller 53 | eval_epoch = 5000 # can be smaller 54 | optimizer = dict(type="AdamW", lr=1e-4, weight_decay=1e-8) 55 | scheduler = dict( 56 | type="OneCycleLR", 57 | max_lr=[1e-4], 58 | pct_start=0.1, 59 | anneal_strategy="cos", 60 | div_factor=10.0, 61 | final_div_factor=10.0, 62 | ) 63 | 64 | # dataset settings 65 | dataset_type = "SAMPart3DDataset16Views" 66 | data_root = "" 67 | mesh_root = "" 68 | backbone_weight_path = "" 69 | 70 | # eval 71 | val_scales_list = [0.0, 0.5, 1.0, 1.5, 2.0] 72 | mesh_voting = True 73 | 74 | data = dict( 75 | train=dict( 76 | type=dataset_type, 77 | split="train", 78 | data_root=data_root, 79 | mesh_root=mesh_root, 80 | sample_num=15000, 81 | pixels_per_image=256, 82 | batch_size=90, 83 | extent_scale=10.0, 84 | transform=[ 85 | dict(type="NormalizeCoord"), 86 | dict(type="CenterShift", apply_z=True), 87 | dict( 88 | type="GridSample", 89 | grid_size=0.01, 90 | keys=("coord", "color", "normal", "origin_coord", "face_index"), 91 | hash_type="fnv", 92 | mode="train", 93 | return_grid_coord=True, 94 | return_inverse=True, 95 | ), 96 | dict(type="CenterShift", apply_z=False), 97 | dict(type="NormalizeColor"), 98 | dict(type="ToTensor"), 99 | dict( 100 | type="Collect", 101 | keys=("coord", "grid_coord", "inverse", "origin_coord", "face_index"), 102 | feat_keys=("coord", "normal", "color"), 103 | ), 104 | ], 105 | ), 106 | ) -------------------------------------------------------------------------------- /launch/eval.py: -------------------------------------------------------------------------------- 
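# (Added note) Evaluation entry point: it mirrors launch/train.py below but
# calls trainer.eval() instead of trainer.train(). It is normally invoked via
# scripts/eval.sh; a direct call would look roughly like the following (flag
# names come from default_argument_parser(); the values are placeholders):
#   python launch/eval.py --config-file ${CONFIG_PATH} --num-gpus 1 \
#       --options save_path=${SAVE_PATH} weight=${WEIGHT_PATH}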
1 | from pointcept.engines.defaults import ( 2 | default_argument_parser, 3 | default_config_parser, 4 | default_setup, 5 | ) 6 | from pointcept.engines.eval import TRAINERS 7 | from pointcept.engines.launch import launch 8 | 9 | 10 | def main_worker(cfg): 11 | cfg = default_setup(cfg) 12 | trainer = TRAINERS.build(dict(type=cfg.train.type, cfg=cfg)) 13 | trainer.eval() 14 | 15 | 16 | def main(): 17 | args = default_argument_parser().parse_args() 18 | cfg = default_config_parser(args.config_file, args.options) 19 | 20 | launch( 21 | main_worker, 22 | num_gpus_per_machine=args.num_gpus, 23 | num_machines=args.num_machines, 24 | machine_rank=args.machine_rank, 25 | dist_url=args.dist_url, 26 | cfg=(cfg,), 27 | ) 28 | 29 | 30 | if __name__ == "__main__": 31 | main() 32 | -------------------------------------------------------------------------------- /launch/train.py: -------------------------------------------------------------------------------- 1 | from pointcept.engines.defaults import ( 2 | default_argument_parser, 3 | default_config_parser, 4 | default_setup, 5 | ) 6 | from pointcept.engines.train import TRAINERS 7 | from pointcept.engines.launch import launch 8 | 9 | 10 | def main_worker(cfg): 11 | cfg = default_setup(cfg) 12 | trainer = TRAINERS.build(dict(type=cfg.train.type, cfg=cfg)) 13 | trainer.train() 14 | 15 | 16 | def main(): 17 | args = default_argument_parser().parse_args() 18 | cfg = default_config_parser(args.config_file, args.options) 19 | 20 | launch( 21 | main_worker, 22 | num_gpus_per_machine=args.num_gpus, 23 | num_machines=args.num_machines, 24 | machine_rank=args.machine_rank, 25 | dist_url=args.dist_url, 26 | cfg=(cfg,), 27 | ) 28 | 29 | 30 | if __name__ == "__main__": 31 | main() 32 | -------------------------------------------------------------------------------- /libs/pointops/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import * 2 | -------------------------------------------------------------------------------- /libs/pointops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .query import knn_query, ball_query, random_ball_query 2 | from .sampling import farthest_point_sampling 3 | from .grouping import grouping, grouping2 4 | from .interpolation import interpolation, interpolation2 5 | from .subtraction import subtraction 6 | from .aggregation import aggregation 7 | from .attention import attention_relation_step, attention_fusion_step 8 | from .utils import ( 9 | query_and_group, 10 | knn_query_and_group, 11 | ball_query_and_group, 12 | batch2offset, 13 | offset2batch, 14 | ) 15 | -------------------------------------------------------------------------------- /libs/pointops/functions/aggregation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | from pointops._C import aggregation_forward_cuda, aggregation_backward_cuda 5 | 6 | 7 | class Aggregation(Function): 8 | @staticmethod 9 | def forward(ctx, input, position, weight, idx): 10 | """ 11 | input: input: (n, c), position: (n, nsample, c), weight : (n, nsample, c'), idx: (n, nsample) 12 | output: (n, c) 13 | """ 14 | assert ( 15 | input.is_contiguous() 16 | and position.is_contiguous() 17 | and weight.is_contiguous() 18 | ) 19 | n, nsample, c = position.shape 20 | w_c = weight.shape[-1] 21 | output = torch.cuda.FloatTensor(n, c).zero_() 22 | aggregation_forward_cuda( 23 | n, 
nsample, c, w_c, input, position, weight, idx, output 24 | ) 25 | ctx.save_for_backward(input, position, weight, idx) 26 | return output 27 | 28 | @staticmethod 29 | def backward(ctx, grad_output): 30 | """ 31 | input: grad_out: (n, c) 32 | output: grad_input: (n, c), grad_position: (n, nsample, c), grad_weight : (n, nsample, c') 33 | """ 34 | input, position, weight, idx = ctx.saved_tensors 35 | n, nsample, c = position.shape 36 | w_c = weight.shape[-1] 37 | grad_input = torch.cuda.FloatTensor(n, c).zero_() 38 | grad_position = torch.cuda.FloatTensor(n, nsample, c).zero_() 39 | grad_weight = torch.cuda.FloatTensor(n, nsample, w_c).zero_() 40 | aggregation_backward_cuda( 41 | n, 42 | nsample, 43 | c, 44 | w_c, 45 | input, 46 | position, 47 | weight, 48 | idx, 49 | grad_output, 50 | grad_input, 51 | grad_position, 52 | grad_weight, 53 | ) 54 | return grad_input, grad_position, grad_weight, None 55 | 56 | 57 | aggregation = Aggregation.apply 58 | -------------------------------------------------------------------------------- /libs/pointops/functions/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | from pointops._C import ( 5 | attention_relation_step_forward_cuda, 6 | attention_relation_step_backward_cuda, 7 | attention_fusion_step_forward_cuda, 8 | attention_fusion_step_backward_cuda, 9 | ) 10 | 11 | 12 | class AttentionRelationStep(Function): 13 | @staticmethod 14 | def forward(ctx, query, key, weight, index_target, index_refer): 15 | """ 16 | input - query: (n, g, c), key: (n, g, c), weight: (c) 1_c for scatter attention, 17 | index_target: (m), index_refer: (m) 18 | output - relation: (M, g) 19 | """ 20 | 21 | assert ( 22 | query.is_contiguous() 23 | and key.is_contiguous() 24 | and index_target.is_contiguous() 25 | and index_refer.is_contiguous() 26 | and weight.is_contiguous() 27 | ) 28 | 29 | assert index_target.shape[0] == index_refer.shape[0] 30 | 31 | _, g, c = query.shape 32 | m = index_target.shape[0] 33 | output = torch.cuda.FloatTensor(m, g).zero_() 34 | attention_relation_step_forward_cuda( 35 | m, g, c, query, key, weight, index_target.int(), index_refer.int(), output 36 | ) 37 | ctx.save_for_backward(query, key, weight, index_target, index_refer) 38 | return output 39 | 40 | @staticmethod 41 | def backward(ctx, grad_output): 42 | query, key, weight, index_target, index_refer = ctx.saved_tensors 43 | n, g, c = query.shape 44 | m = index_target.shape[0] 45 | grad_query = torch.cuda.FloatTensor(n, g, c).zero_() 46 | grad_key = torch.cuda.FloatTensor(n, g, c).zero_() 47 | grad_weight = torch.cuda.FloatTensor(c).zero_() 48 | attention_relation_step_backward_cuda( 49 | m, 50 | g, 51 | c, 52 | query, 53 | grad_query, 54 | key, 55 | grad_key, 56 | weight, 57 | grad_weight, 58 | index_target.int(), 59 | index_refer.int(), 60 | grad_output, 61 | ) 62 | return grad_query, grad_key, None, None, None 63 | 64 | 65 | class AttentionFusionStep(Function): 66 | @staticmethod 67 | def forward(ctx, weight, value, index_target, index_refer): 68 | """ 69 | input - weight: (m, g), value: (n, g, c) 70 | index_target: (m), index_value: (m) 71 | output - output: (n, g, c) 72 | """ 73 | 74 | assert ( 75 | weight.is_contiguous() 76 | and value.is_contiguous() 77 | and index_target.is_contiguous() 78 | and index_refer.is_contiguous() 79 | and weight.is_contiguous() 80 | ) 81 | 82 | assert index_target.shape[0] == index_refer.shape[0] 83 | 84 | n, g, c = value.shape 85 | m = 
index_refer.shape[0] 86 | output = torch.cuda.FloatTensor(n, g, c).zero_() 87 | attention_fusion_step_forward_cuda( 88 | m, g, c, weight, value, index_target.int(), index_refer.int(), output 89 | ) 90 | ctx.save_for_backward(weight, value, index_target, index_refer) 91 | return output 92 | 93 | @staticmethod 94 | def backward(ctx, grad_output): 95 | """ 96 | input: grad_output: (n, g, c) 97 | output: grad_weight: (m, g), grad_value: (n, g, c), none, none 98 | """ 99 | weight, value, index_target, index_refer = ctx.saved_tensors 100 | n, g, c = value.shape 101 | m = index_target.shape[0] 102 | grad_weight = torch.cuda.FloatTensor(m, g).zero_() 103 | grad_value = torch.cuda.FloatTensor(n, g, c).zero_() 104 | attention_fusion_step_backward_cuda( 105 | m, 106 | g, 107 | c, 108 | weight, 109 | grad_weight, 110 | value, 111 | grad_value, 112 | index_target.int(), 113 | index_refer.int(), 114 | grad_output, 115 | ) 116 | return grad_weight, grad_value, None, None 117 | 118 | 119 | attention_relation_step = AttentionRelationStep.apply 120 | attention_fusion_step = AttentionFusionStep.apply 121 | -------------------------------------------------------------------------------- /libs/pointops/functions/grouping.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | from pointops._C import grouping_forward_cuda, grouping_backward_cuda 5 | 6 | 7 | class Grouping(Function): 8 | @staticmethod 9 | def forward(ctx, input, idx): 10 | """ 11 | input: input: (n, c), idx : (m, nsample) 12 | output: (m, nsample, c) 13 | """ 14 | assert input.is_contiguous() and idx.is_contiguous() 15 | m, nsample, n, c = idx.shape[0], idx.shape[1], input.shape[0], input.shape[1] 16 | output = torch.cuda.FloatTensor(m, nsample, c) 17 | grouping_forward_cuda(m, nsample, c, input, idx, output) 18 | ctx.n = n 19 | ctx.save_for_backward(idx) 20 | return output 21 | 22 | @staticmethod 23 | def backward(ctx, grad_output): 24 | """ 25 | input: grad_out: (m, c, nsample) 26 | output: (n, c), None 27 | """ 28 | n = ctx.n 29 | (idx,) = ctx.saved_tensors 30 | m, nsample, c = grad_output.shape 31 | grad_input = torch.cuda.FloatTensor(n, c).zero_() 32 | grouping_backward_cuda(m, nsample, c, grad_output, idx, grad_input) 33 | return grad_input, None 34 | 35 | 36 | def grouping(idx, feat, xyz, new_xyz=None, with_xyz=False): 37 | if new_xyz is None: 38 | new_xyz = xyz 39 | assert xyz.is_contiguous() and feat.is_contiguous() 40 | m, nsample, c = idx.shape[0], idx.shape[1], feat.shape[1] 41 | xyz = torch.cat([xyz, torch.zeros([1, 3]).to(xyz.device)], dim=0) 42 | feat = torch.cat([feat, torch.zeros([1, c]).to(feat.device)], dim=0) 43 | grouped_feat = feat[idx.view(-1).long(), :].view( 44 | m, nsample, c 45 | ) # (m, num_sample, c) 46 | 47 | if with_xyz: 48 | assert new_xyz.is_contiguous() 49 | mask = torch.sign(idx + 1) 50 | grouped_xyz = xyz[idx.view(-1).long(), :].view( 51 | m, nsample, 3 52 | ) - new_xyz.unsqueeze( 53 | 1 54 | ) # (m, num_sample, 3) 55 | grouped_xyz = torch.einsum( 56 | "n s c, n s -> n s c", grouped_xyz, mask 57 | ) # (m, num_sample, 3) 58 | return torch.cat((grouped_xyz, grouped_feat), -1) 59 | else: 60 | return grouped_feat 61 | 62 | 63 | grouping2 = Grouping.apply 64 | -------------------------------------------------------------------------------- /libs/pointops/functions/interpolation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 
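# (Added note) Both helpers below implement k-NN inverse-distance
# interpolation: each query point in new_xyz receives a weighted average of
# the features of its k nearest points in xyz, with weights proportional to
# 1 / (dist + 1e-8) and normalized to sum to one. `interpolation` composes
# this from existing ops; the `Interpolation` Function fuses the same math
# into a custom CUDA forward/backward.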
| from pointops._C import interpolation_forward_cuda, interpolation_backward_cuda 5 | from .query import knn_query 6 | 7 | 8 | def interpolation(xyz, new_xyz, feat, offset, new_offset, k=3): 9 | """ 10 | input: coords: (m, 3), new_xyz: (n, 3), color: (m, c), offset: (b), new_offset: (b) 11 | output: (n, c) 12 | """ 13 | assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous() 14 | idx, dist = knn_query(k, xyz, offset, new_xyz, new_offset) # (n, 3), (n, 3) 15 | dist_recip = 1.0 / (dist + 1e-8) # (n, 3) 16 | norm = torch.sum(dist_recip, dim=1, keepdim=True) 17 | weight = dist_recip / norm # (n, 3) 18 | 19 | new_feat = torch.cuda.FloatTensor(new_xyz.shape[0], feat.shape[1]).zero_() 20 | for i in range(k): 21 | new_feat += feat[idx[:, i].long(), :] * weight[:, i].unsqueeze(-1) 22 | return new_feat 23 | 24 | 25 | class Interpolation(Function): 26 | @staticmethod 27 | def forward(ctx, xyz, new_xyz, input, offset, new_offset, k=3): 28 | """ 29 | input: coords: (m, 3), new_xyz: (n, 3), input: (m, c), offset: (b), new_offset: (b) 30 | output: (n, c) 31 | """ 32 | assert xyz.is_contiguous() and new_xyz.is_contiguous() and input.is_contiguous() 33 | idx, dist = knn_query(k, xyz, offset, new_xyz, new_offset) # (n, k), (n, k) 34 | dist_recip = 1.0 / (dist + 1e-8) # (n, k) 35 | norm = torch.sum(dist_recip, dim=1, keepdim=True) 36 | weight = dist_recip / norm # (n, k) 37 | 38 | n, c, m = new_xyz.shape[0], input.shape[1], input.shape[0] 39 | output = torch.cuda.FloatTensor(n, c).zero_() 40 | interpolation_forward_cuda(n, c, k, input, idx, weight, output) 41 | ctx.m, ctx.k = m, k 42 | ctx.save_for_backward(idx, weight) 43 | return output 44 | 45 | @staticmethod 46 | def backward(ctx, grad_output): 47 | """ 48 | input: coords: (m, 3), new_xyz: (n, 3), input: (m, c), offset: (b), new_offset: (b) 49 | output: (n, c) 50 | """ 51 | m, k = ctx.m, ctx.k 52 | idx, weight = ctx.saved_tensors 53 | n, c = grad_output.shape 54 | grad_input = torch.cuda.FloatTensor(m, c).zero_() 55 | interpolation_backward_cuda(n, c, k, grad_output, idx, weight, grad_input) 56 | return None, None, grad_input, None, None, None 57 | 58 | 59 | interpolation2 = Interpolation.apply 60 | -------------------------------------------------------------------------------- /libs/pointops/functions/query.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | from pointops._C import knn_query_cuda, random_ball_query_cuda, ball_query_cuda 5 | 6 | 7 | class KNNQuery(Function): 8 | @staticmethod 9 | def forward(ctx, nsample, xyz, offset, new_xyz=None, new_offset=None): 10 | """ 11 | input: coords: (n, 3), new_xyz: (m, 3), offset: (b), new_offset: (b) 12 | output: idx: (m, nsample) -1 is placeholder, dist2: (m, nsample) 13 | """ 14 | if new_xyz is None or new_offset is None: 15 | new_xyz = xyz 16 | new_offset = offset 17 | assert xyz.is_contiguous() and new_xyz.is_contiguous() 18 | m = new_xyz.shape[0] 19 | # idx = torch.cuda.IntTensor(m, nsample).zero_() 20 | # dist2 = torch.cuda.FloatTensor(m, nsample).zero_() 21 | idx = torch.zeros((m, nsample), dtype=torch.int32, device='cuda') 22 | dist2 = torch.zeros((m, nsample), dtype=torch.float32, device='cuda') 23 | knn_query_cuda( 24 | m, nsample, xyz, new_xyz, offset.int(), new_offset.int(), idx, dist2 25 | ) 26 | return idx, torch.sqrt(dist2) 27 | 28 | 29 | class RandomBallQuery(Function): 30 | """Random Ball Query. 31 | 32 | Find nearby points in spherical space. 
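    (Added note) Candidates are visited in a per-batch random permutation
    (the `order` tensor built in forward below), so the returned neighbors
    are a random subset of the points inside the min_radius/max_radius band
    rather than the closest ones.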
33 | """ 34 | 35 | @staticmethod 36 | def forward( 37 | ctx, nsample, max_radius, min_radius, xyz, offset, new_xyz=None, new_offset=None 38 | ): 39 | """ 40 | input: coords: (n, 3), new_xyz: (m, 3), offset: (b), new_offset: (b) 41 | output: idx: (m, nsample), dist2: (m, nsample) 42 | """ 43 | if new_xyz is None or new_offset is None: 44 | new_xyz = xyz 45 | new_offset = offset 46 | assert xyz.is_contiguous() and new_xyz.is_contiguous() 47 | assert min_radius < max_radius 48 | 49 | m = new_xyz.shape[0] 50 | order = [] 51 | for k in range(offset.shape[0]): 52 | s_k, e_k = (0, offset[0]) if k == 0 else (offset[k - 1], offset[k]) 53 | order.append( 54 | torch.randperm(e_k - s_k, dtype=torch.int32, device=offset.device) + s_k 55 | ) 56 | order = torch.cat(order, dim=0) 57 | # idx = torch.cuda.IntTensor(m, nsample).zero_() 58 | # dist2 = torch.cuda.FloatTensor(m, nsample).zero_() 59 | idx = torch.zeros((m, nsample), dtype=torch.int32, device='cuda') 60 | dist2 = torch.zeros((m, nsample), dtype=torch.float32, device='cuda') 61 | random_ball_query_cuda( 62 | m, 63 | nsample, 64 | min_radius, 65 | max_radius, 66 | order, 67 | xyz, 68 | new_xyz, 69 | offset.int(), 70 | new_offset.int(), 71 | idx, 72 | dist2, 73 | ) 74 | return idx, torch.sqrt(dist2) 75 | 76 | 77 | class BallQuery(Function): 78 | """Ball Query. 79 | 80 | Find nearby points in spherical space. 81 | """ 82 | 83 | @staticmethod 84 | def forward( 85 | ctx, nsample, max_radius, min_radius, xyz, offset, new_xyz=None, new_offset=None 86 | ): 87 | """ 88 | input: coords: (n, 3), new_xyz: (m, 3), offset: (b), new_offset: (b) 89 | output: idx: (m, nsample), dist2: (m, nsample) 90 | """ 91 | if new_xyz is None or new_offset is None: 92 | new_xyz = xyz 93 | new_offset = offset 94 | assert xyz.is_contiguous() and new_xyz.is_contiguous() 95 | assert min_radius < max_radius 96 | 97 | m = new_xyz.shape[0] 98 | # idx = torch.cuda.IntTensor(m, nsample).zero_() 99 | # dist2 = torch.cuda.FloatTensor(m, nsample).zero_() 100 | idx = torch.zeros((m, nsample), dtype=torch.int32, device='cuda') 101 | dist2 = torch.zeros((m, nsample), dtype=torch.float32, device='cuda') 102 | ball_query_cuda( 103 | m, 104 | nsample, 105 | min_radius, 106 | max_radius, 107 | xyz, 108 | new_xyz, 109 | offset.int(), 110 | new_offset.int(), 111 | idx, 112 | dist2, 113 | ) 114 | return idx, torch.sqrt(dist2) 115 | 116 | 117 | knn_query = KNNQuery.apply 118 | ball_query = BallQuery.apply 119 | random_ball_query = RandomBallQuery.apply 120 | -------------------------------------------------------------------------------- /libs/pointops/functions/sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | from pointops._C import farthest_point_sampling_cuda 5 | 6 | 7 | class FarthestPointSampling(Function): 8 | @staticmethod 9 | def forward(ctx, xyz, offset, new_offset): 10 | """ 11 | input: coords: (n, 3), offset: (b), new_offset: (b) 12 | output: idx: (m) 13 | """ 14 | assert xyz.is_contiguous() 15 | n, b, n_max = xyz.shape[0], offset.shape[0], offset[0] 16 | for i in range(1, b): 17 | n_max = max(offset[i] - offset[i - 1], n_max) 18 | idx = torch.cuda.IntTensor(new_offset[b - 1].item()).zero_() 19 | tmp = torch.cuda.FloatTensor(n).fill_(1e10) 20 | farthest_point_sampling_cuda( 21 | b, n_max, xyz, offset.int(), new_offset.int(), tmp, idx 22 | ) 23 | del tmp 24 | return idx 25 | 26 | 27 | farthest_point_sampling = FarthestPointSampling.apply 28 | 
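# --- Added usage sketch (not part of the original file) ---
# Sample 1024 points from a single 8192-point cloud. `offset` / `new_offset`
# hold cumulative point counts per batch element, so a single cloud uses
# one-element tensors:
#   xyz = torch.rand(8192, 3, device="cuda").contiguous()
#   offset = torch.tensor([8192], device="cuda")
#   new_offset = torch.tensor([1024], device="cuda")
#   idx = farthest_point_sampling(xyz, offset, new_offset)  # (1024,) int32
#   sampled = xyz[idx.long()]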
-------------------------------------------------------------------------------- /libs/pointops/functions/subtraction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | from pointops._C import subtraction_forward_cuda, subtraction_backward_cuda 5 | 6 | 7 | class Subtraction(Function): 8 | @staticmethod 9 | def forward(ctx, input1, input2, idx): 10 | """ 11 | input: input1: (n, c), input2: (n, c), idx: (n, nsample) 12 | output: (n, nsample, c) 13 | """ 14 | assert input1.is_contiguous() and input2.is_contiguous() 15 | n, c = input1.shape 16 | nsample = idx.shape[-1] 17 | output = torch.cuda.FloatTensor(n, nsample, c).zero_() 18 | subtraction_forward_cuda(n, nsample, c, input1, input2, idx, output) 19 | ctx.save_for_backward(idx) 20 | return output 21 | 22 | @staticmethod 23 | def backward(ctx, grad_output): 24 | """ 25 | input: grad_out: (n, nsample, c) 26 | output: grad_input1: (n, c), grad_input2: (n, c) 27 | """ 28 | (idx,) = ctx.saved_tensors 29 | n, nsample, c = grad_output.shape 30 | grad_input1 = torch.cuda.FloatTensor(n, c).zero_() 31 | grad_input2 = torch.cuda.FloatTensor(n, c).zero_() 32 | subtraction_backward_cuda( 33 | n, nsample, c, idx, grad_output, grad_input1, grad_input2 34 | ) 35 | return grad_input1, grad_input2, None 36 | 37 | 38 | subtraction = Subtraction.apply 39 | -------------------------------------------------------------------------------- /libs/pointops/functions/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pointops import knn_query, ball_query, grouping 3 | 4 | 5 | def knn_query_and_group( 6 | feat, 7 | xyz, 8 | offset=None, 9 | new_xyz=None, 10 | new_offset=None, 11 | idx=None, 12 | nsample=None, 13 | with_xyz=False, 14 | ): 15 | if idx is None: 16 | assert nsample is not None 17 | idx, _ = knn_query(nsample, xyz, offset, new_xyz, new_offset) 18 | return grouping(idx, feat, xyz, new_xyz, with_xyz), idx 19 | 20 | 21 | def ball_query_and_group( 22 | feat, 23 | xyz, 24 | offset=None, 25 | new_xyz=None, 26 | new_offset=None, 27 | idx=None, 28 | max_radio=None, 29 | min_radio=0, 30 | nsample=None, 31 | with_xyz=False, 32 | ): 33 | if idx is None: 34 | assert nsample is not None and offset is not None 35 | assert max_radio is not None and min_radio is not None 36 | idx, _ = ball_query( 37 | nsample, max_radio, min_radio, xyz, offset, new_xyz, new_offset 38 | ) 39 | return grouping(idx, feat, xyz, new_xyz, with_xyz), idx 40 | 41 | 42 | def query_and_group( 43 | nsample, 44 | xyz, 45 | new_xyz, 46 | feat, 47 | idx, 48 | offset, 49 | new_offset, 50 | dilation=0, 51 | with_feat=True, 52 | with_xyz=True, 53 | ): 54 | """ 55 | input: coords: (n, 3), new_xyz: (m, 3), color: (n, c), idx: (m, nsample), offset: (b), new_offset: (b) 56 | output: new_feat: (m, nsample, c+3), grouped_idx: (m, nsample) 57 | """ 58 | assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous() 59 | if new_xyz is None: 60 | new_xyz = xyz 61 | 62 | if idx is None: 63 | num_samples_total = 1 + (nsample - 1) * (dilation + 1) 64 | # num points in a batch might < num_samples_total => [n1, n2, ..., nk, ns, ns, ns, ...] 
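            # (Added commentary) Dilated grouping: query nsample * (dilation + 1)
            # neighbors once, then keep every (dilation + 1)-th column per batch
            # below. A batch with fewer than num_samples_total points falls back
            # to a fractional "soft" dilation so the nsample picks still span
            # the neighbors that actually exist.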
65 | idx_no_dilation, _ = knn_query( 66 | num_samples_total, xyz, offset, new_xyz, new_offset 67 | ) # (m, nsample * (d + 1)) 68 | idx = [] 69 | batch_end = offset.tolist() 70 | batch_start = [0] + batch_end[:-1] 71 | new_batch_end = new_offset.tolist() 72 | new_batch_start = [0] + new_batch_end[:-1] 73 | for i in range(offset.shape[0]): 74 | if batch_end[i] - batch_start[i] < num_samples_total: 75 | soft_dilation = (batch_end[i] - batch_start[i] - 1) / (nsample - 1) - 1 76 | else: 77 | soft_dilation = dilation 78 | idx.append( 79 | idx_no_dilation[ 80 | new_batch_start[i] : new_batch_end[i], 81 | [int((soft_dilation + 1) * i) for i in range(nsample)], 82 | ] 83 | ) 84 | idx = torch.cat(idx, dim=0) 85 | 86 | if not with_feat: 87 | return idx 88 | 89 | n, m, c = xyz.shape[0], new_xyz.shape[0], feat.shape[1] 90 | grouped_xyz = xyz[idx.view(-1).long(), :].view(m, nsample, 3) # (m, nsample, 3) 91 | # grouped_xyz = grouping(coords, idx) # (m, nsample, 3) 92 | grouped_xyz -= new_xyz.unsqueeze(1) # (m, nsample, 3) 93 | grouped_feat = feat[idx.view(-1).long(), :].view(m, nsample, c) # (m, nsample, c) 94 | # grouped_feat = grouping(color, idx) # (m, nsample, c) 95 | 96 | if with_xyz: 97 | return torch.cat((grouped_xyz, grouped_feat), -1), idx # (m, nsample, 3+c) 98 | else: 99 | return grouped_feat, idx 100 | 101 | 102 | def offset2batch(offset): 103 | return ( 104 | torch.cat( 105 | [ 106 | torch.tensor([i] * (o - offset[i - 1])) 107 | if i > 0 108 | else torch.tensor([i] * o) 109 | for i, o in enumerate(offset) 110 | ], 111 | dim=0, 112 | ) 113 | .long() 114 | .to(offset.device) 115 | ) 116 | 117 | 118 | def batch2offset(batch): 119 | return torch.cumsum(batch.bincount(), dim=0).int() 120 | -------------------------------------------------------------------------------- /libs/pointops/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | from distutils.sysconfig import get_config_vars 5 | 6 | (opt,) = get_config_vars("OPT") 7 | os.environ["OPT"] = " ".join( 8 | flag for flag in opt.split() if flag != "-Wstrict-prototypes" 9 | ) 10 | 11 | src = "src" 12 | sources = [ 13 | os.path.join(root, file) 14 | for root, dirs, files in os.walk(src) 15 | for file in files 16 | if file.endswith(".cpp") or file.endswith(".cu") 17 | ] 18 | 19 | setup( 20 | name="pointops", 21 | version="1.0", 22 | install_requires=["torch", "numpy"], 23 | packages=["pointops"], 24 | package_dir={"pointops": "functions"}, 25 | ext_modules=[ 26 | CUDAExtension( 27 | name="pointops._C", 28 | sources=sources, 29 | extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]}, 30 | ) 31 | ], 32 | cmdclass={"build_ext": BuildExtension}, 33 | ) 34 | -------------------------------------------------------------------------------- /libs/pointops/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pointcept/SAMPart3D/6a508d145be5cecd7e80bce6e4248ea7abbd71b7/libs/pointops/src/__init__.py -------------------------------------------------------------------------------- /libs/pointops/src/aggregation/aggregation_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "aggregation_cuda_kernel.h" 5 | 6 | 7 | void aggregation_forward_cuda(int n, int nsample, int c, int w_c, at::Tensor input_tensor, at::Tensor position_tensor, 
at::Tensor weight_tensor, at::Tensor idx_tensor, at::Tensor output_tensor) 8 | { 9 | const float *input = input_tensor.data_ptr(); 10 | const float *position = position_tensor.data_ptr(); 11 | const float *weight = weight_tensor.data_ptr(); 12 | const int *idx = idx_tensor.data_ptr(); 13 | float *output = output_tensor.data_ptr(); 14 | aggregation_forward_cuda_launcher(n, nsample, c, w_c, input, position, weight, idx, output); 15 | } 16 | 17 | void aggregation_backward_cuda(int n, int nsample, int c, int w_c, at::Tensor input_tensor, at::Tensor position_tensor, at::Tensor weight_tensor, at::Tensor idx_tensor, at::Tensor grad_output_tensor, at::Tensor grad_input_tensor, at::Tensor grad_position_tensor, at::Tensor grad_weight_tensor) 18 | { 19 | const float *input = input_tensor.data_ptr(); 20 | const float *position = position_tensor.data_ptr(); 21 | const float *weight = weight_tensor.data_ptr(); 22 | const int *idx = idx_tensor.data_ptr(); 23 | const float *grad_output = grad_output_tensor.data_ptr(); 24 | float *grad_input = grad_input_tensor.data_ptr(); 25 | float *grad_position = grad_position_tensor.data_ptr(); 26 | float *grad_weight = grad_weight_tensor.data_ptr(); 27 | aggregation_backward_cuda_launcher(n, nsample, c, w_c, input, position, weight, idx, grad_output, grad_input, grad_position, grad_weight); 28 | } 29 | -------------------------------------------------------------------------------- /libs/pointops/src/aggregation/aggregation_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "aggregation_cuda_kernel.h" 3 | 4 | 5 | __global__ void aggregation_forward_cuda_kernel(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, float *output) { 6 | // input: input: (n, c), position: (n, nsample, c), weight: (n, nsample, w_c), idx: (n, nsample), output: (n, c) 7 | int index = blockIdx.x * blockDim.x + threadIdx.x; 8 | if (index >= n * c) return; 9 | const int c_idx = index % c; 10 | const int n_idx = index / c; 11 | const int w_c_idx = c_idx % w_c; 12 | for (int nsample_idx = 0; nsample_idx < nsample; nsample_idx++) 13 | { 14 | int idx_idx = n_idx * nsample + nsample_idx; 15 | int input_idx = idx[idx_idx] * c + c_idx; 16 | int position_idx = n_idx * nsample * c + nsample_idx * c + c_idx; 17 | int weight_idx = n_idx * nsample * w_c + nsample_idx * w_c + w_c_idx; 18 | output[index] += (input[input_idx] + position[position_idx]) * weight[weight_idx]; 19 | } 20 | } 21 | 22 | __global__ void aggregation_backward_cuda_kernel(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, const float *grad_output, float *grad_input, float *grad_position, float *grad_weight) { 23 | // input: grad_output: (n, c), output: grad_input: (n, c), grad_position: (n, nsample, c), grad_weight: (n, nsample, w_c) 24 | int index = blockIdx.x * blockDim.x + threadIdx.x; 25 | if (index >= n * c) return; 26 | const int c_idx = index % c; 27 | const int n_idx = index / c; 28 | const int w_c_idx = c_idx % w_c; 29 | for (int nsample_idx = 0; nsample_idx < nsample; nsample_idx++) 30 | { 31 | int idx_idx = n_idx * nsample + nsample_idx; 32 | int input_idx = idx[idx_idx] * c + c_idx; 33 | int position_idx = n_idx * nsample * c + nsample_idx * c + c_idx; 34 | int weight_idx = n_idx * nsample * w_c + nsample_idx * w_c + w_c_idx; 35 | atomicAdd(grad_input + input_idx, grad_output[index] * weight[weight_idx]); 36 | 
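        // (Added note) grad_input and grad_weight are accumulated with atomicAdd
        // because several (n_idx, nsample_idx) pairs can map to the same input or
        // weight slot; position_idx below is unique per thread and loop step, so a
        // plain store is race-free.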
grad_position[position_idx] = grad_output[index] * weight[weight_idx]; 37 | atomicAdd(grad_weight + weight_idx, grad_output[index] * (input[input_idx] + position[position_idx])); 38 | } 39 | } 40 | 41 | void aggregation_forward_cuda_launcher(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, float *output) { 42 | // input: input: (n, c), position: (n, nsample, c), weight: (n, nsample, w_c), idx: (n, nsample), output: (n, c) 43 | dim3 blocks(DIVUP(n * c, THREADS_PER_BLOCK)); 44 | dim3 threads(THREADS_PER_BLOCK); 45 | aggregation_forward_cuda_kernel<<>>(n, nsample, c, w_c, input, position, weight, idx, output); 46 | } 47 | 48 | void aggregation_backward_cuda_launcher(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, const float *grad_output, float *grad_input, float *grad_position, float *grad_weight) { 49 | // input: grad_output: (n, c), output: grad_input: (n, c), grad_position: (n, nsample, c), grad_weight: (n, nsample, w_c) 50 | dim3 blocks(DIVUP(n * c, THREADS_PER_BLOCK)); 51 | dim3 threads(THREADS_PER_BLOCK); 52 | aggregation_backward_cuda_kernel<<>>(n, nsample, c, w_c, input, position, weight, idx, grad_output, grad_input, grad_position, grad_weight); 53 | } 54 | -------------------------------------------------------------------------------- /libs/pointops/src/aggregation/aggregation_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _AGGREGATION_CUDA_KERNEL 2 | #define _AGGREGATION_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void aggregation_forward_cuda(int n, int nsample, int c, int w_c, at::Tensor input_tensor, at::Tensor position_tensor, at::Tensor weight_tensor, at::Tensor idx_tensor, at::Tensor output_tensor); 8 | void aggregation_backward_cuda(int n, int nsample, int c, int w_c, at::Tensor input_tensor, at::Tensor position_tensor, at::Tensor weight_tensor, at::Tensor idx_tensor, at::Tensor grad_output_tensor, at::Tensor grad_input_tensor, at::Tensor grad_position_tensor, at::Tensor grad_weight_tensor); 9 | 10 | #ifdef __cplusplus 11 | extern "C" { 12 | #endif 13 | 14 | void aggregation_forward_cuda_launcher(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, float *output); 15 | void aggregation_backward_cuda_launcher(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, const float *grad_output, float *grad_input, float *grad_position, float *grad_weight); 16 | 17 | #ifdef __cplusplus 18 | } 19 | #endif 20 | #endif 21 | -------------------------------------------------------------------------------- /libs/pointops/src/attention/attention_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "attention_cuda_kernel.h" 5 | 6 | 7 | void attention_relation_step_forward_cuda(int m, int g, int c, 8 | at::Tensor query_tensor, at::Tensor key_tensor, at::Tensor weight_tensor, 9 | at::Tensor index_target_tensor, at::Tensor index_refer_tensor, 10 | at::Tensor output_tensor) 11 | { 12 | const float *query = query_tensor.data_ptr(); 13 | const float *key = key_tensor.data_ptr(); 14 | const float *weight = weight_tensor.data_ptr(); 15 | const int *index_target = index_target_tensor.data_ptr(); 16 | const int *index_refer = index_refer_tensor.data_ptr(); 17 | float *output = 
output_tensor.data_ptr(); 18 | attention_relation_step_forward_cuda_launcher(m, g, c, query, key, weight, index_target, index_refer, output); 19 | } 20 | 21 | void attention_relation_step_backward_cuda(int m, int g, int c, 22 | at::Tensor query_tensor, at::Tensor grad_query_tensor, 23 | at::Tensor key_tensor, at::Tensor grad_key_tensor, 24 | at::Tensor weight_tensor, at::Tensor grad_weight_tensor, 25 | at::Tensor index_target_tensor, at::Tensor index_refer_tensor, 26 | at::Tensor grad_output_tensor) 27 | { 28 | const float *query = query_tensor.data_ptr(); 29 | float *grad_query = grad_query_tensor.data_ptr(); 30 | const float *key = key_tensor.data_ptr(); 31 | float *grad_key = grad_key_tensor.data_ptr(); 32 | const float *weight = weight_tensor.data_ptr(); 33 | float *grad_weight = grad_weight_tensor.data_ptr(); 34 | const int *index_target = index_target_tensor.data_ptr(); 35 | const int *index_refer = index_refer_tensor.data_ptr(); 36 | const float *grad_output = grad_output_tensor.data_ptr(); 37 | attention_relation_step_backward_cuda_launcher(m, g, c, 38 | query, grad_query, 39 | key, grad_key, 40 | weight, grad_weight, 41 | index_target, index_refer, grad_output); 42 | } 43 | 44 | 45 | void attention_fusion_step_forward_cuda(int m, int g, int c, 46 | at::Tensor weight_tensor, at::Tensor value_tensor, 47 | at::Tensor index_target_tensor, at::Tensor index_refer_tensor, 48 | at::Tensor output_tensor) 49 | { 50 | const float *weight = weight_tensor.data_ptr(); 51 | const float *value = value_tensor.data_ptr(); 52 | const int *index_target = index_target_tensor.data_ptr(); 53 | const int *index_refer = index_refer_tensor.data_ptr(); 54 | float *output = output_tensor.data_ptr(); 55 | attention_fusion_step_forward_cuda_launcher(m, g, c, weight, value, index_target, index_refer, output); 56 | } 57 | 58 | 59 | void attention_fusion_step_backward_cuda(int m, int g, int c, 60 | at::Tensor weight_tensor, at::Tensor grad_weight_tensor, 61 | at::Tensor value_tensor, at::Tensor grad_value_tensor, 62 | at::Tensor index_target_tensor, at::Tensor index_refer_tensor, 63 | at::Tensor grad_output_tensor) 64 | { 65 | const float *weight = weight_tensor.data_ptr(); 66 | float *grad_weight = grad_weight_tensor.data_ptr(); 67 | const float *value = value_tensor.data_ptr(); 68 | float *grad_value = grad_value_tensor.data_ptr(); 69 | const int *index_target = index_target_tensor.data_ptr(); 70 | const int *index_refer = index_refer_tensor.data_ptr(); 71 | const float *grad_output = grad_output_tensor.data_ptr(); 72 | attention_fusion_step_backward_cuda_launcher(m, g, c, 73 | weight, grad_weight, 74 | value, grad_value, 75 | index_target, index_refer, grad_output); 76 | } 77 | -------------------------------------------------------------------------------- /libs/pointops/src/attention/attention_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "attention_cuda_kernel.h" 3 | 4 | 5 | /* 6 | Kernels 7 | */ 8 | 9 | __global__ void attention_relation_step_forward_cuda_kernel(int m, int g, int c, 10 | const float *query, const float *key, const float *weight, 11 | const int *index_target, const int *index_refer, 12 | float *output) 13 | { 14 | int r_idx = blockIdx.x * blockDim.x + threadIdx.x; 15 | int g_idx = blockIdx.y; 16 | int c_idx = blockIdx.z; 17 | 18 | if (r_idx >= m || g_idx >= g || c_idx >= c) return; 19 | int q_idx = index_target[r_idx] * g * c + g_idx * c + c_idx; 20 | int k_idx = index_refer[r_idx] * g * c + g_idx 
* c + c_idx; 21 | 22 | float r = query[q_idx] * key[k_idx] * weight[c_idx]; 23 | atomicAdd(output + r_idx * g + g_idx, r); 24 | } 25 | 26 | __global__ void attention_relation_step_backward_cuda_kernel(int m, int g, int c, 27 | const float *query, float *grad_query, 28 | const float *key, float *grad_key, 29 | const float *weight, float *grad_weight, 30 | const int *index_target, const int *index_refer, 31 | const float *grad_output) 32 | { 33 | int r_idx = blockIdx.x * blockDim.x + threadIdx.x; 34 | int g_idx = blockIdx.y; 35 | int c_idx = blockIdx.z; 36 | 37 | if (r_idx >= m || g_idx >= g || c_idx >= c) return; 38 | 39 | int q_idx = index_target[r_idx] * g * c + g_idx * c + c_idx; 40 | int k_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx; 41 | int o_idx = r_idx * g + g_idx; 42 | float grad_r = grad_output[o_idx]; 43 | atomicAdd(grad_query + q_idx, grad_r * key[k_idx] * weight[c_idx]); 44 | atomicAdd(grad_key + k_idx, grad_r * query[q_idx] * weight[c_idx]); 45 | atomicAdd(grad_weight + c_idx, grad_r * key[k_idx] * query[q_idx]); 46 | } 47 | 48 | 49 | __global__ void attention_fusion_step_forward_cuda_kernel(int m, int g, int c, 50 | const float *weight, const float *value, 51 | const int *index_target, const int *index_refer, 52 | float *output) 53 | { 54 | int r_idx = blockIdx.x * blockDim.x + threadIdx.x; 55 | int g_idx = blockIdx.y; 56 | int c_idx = blockIdx.z; 57 | 58 | if (r_idx >= m || g_idx >= g || c_idx >= c) return; 59 | 60 | int o_idx = index_target[r_idx] * g * c + g_idx * c + c_idx; 61 | int v_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx; 62 | 63 | float f = weight[r_idx * g + g_idx] * value[v_idx]; 64 | atomicAdd(output + o_idx, f); 65 | } 66 | 67 | 68 | __global__ void attention_fusion_step_backward_cuda_kernel(int m, int g, int c, 69 | const float *weight, float *grad_weight, 70 | const float *value, float *grad_value, 71 | const int *index_target, const int *index_refer, 72 | const float *grad_output) 73 | { 74 | int r_idx = blockIdx.x * blockDim.x + threadIdx.x; 75 | int g_idx = blockIdx.y; 76 | int c_idx = blockIdx.z; 77 | 78 | if (r_idx >= m || g_idx >= g || c_idx >= c) return; 79 | 80 | int o_idx = index_target[r_idx] * g * c + g_idx * c + c_idx; 81 | int v_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx; 82 | int w_idx = r_idx * g + g_idx; 83 | float grad = grad_output[o_idx]; 84 | atomicAdd(grad_weight + w_idx, grad * value[v_idx]); 85 | atomicAdd(grad_value + v_idx, grad * weight[w_idx]); 86 | } 87 | 88 | /* 89 | Launchers 90 | */ 91 | 92 | 93 | void attention_relation_step_forward_cuda_launcher(int m, int g, int c, 94 | const float *query, const float *key, const float *weight, 95 | const int *index_target, const int *index_refer, 96 | float *output) 97 | { 98 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c); 99 | dim3 threads(THREADS_PER_BLOCK); 100 | attention_relation_step_forward_cuda_kernel<<>>(m, g, c, query, key, weight, 101 | index_target, index_refer, output); 102 | } 103 | 104 | void attention_relation_step_backward_cuda_launcher(int m, int g, int c, 105 | const float *query, float *grad_query, 106 | const float *key, float *grad_key, 107 | const float *weight, float *grad_weight, 108 | const int *index_target, const int *index_refer, 109 | const float *grad_output) 110 | { 111 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c); 112 | dim3 threads(THREADS_PER_BLOCK); 113 | attention_relation_step_backward_cuda_kernel<<>>(m, g, c, 114 | query, grad_query, 115 | key, grad_key, 116 | weight, grad_weight, 117 | index_target, index_refer, 118 | 
grad_output); 119 | } 120 | 121 | 122 | void attention_fusion_step_forward_cuda_launcher(int m, int g, int c, 123 | const float *weight, const float *value, 124 | const int *index_target, const int *index_refer, 125 | float *output) 126 | { 127 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c); 128 | dim3 threads(THREADS_PER_BLOCK); 129 | attention_fusion_step_forward_cuda_kernel<<>>(m, g, c, weight, value, 130 | index_target, index_refer, output); 131 | } 132 | 133 | 134 | void attention_fusion_step_backward_cuda_launcher(int m, int g, int c, 135 | const float *weight, float *grad_weight, 136 | const float *value, float *grad_value, 137 | const int *index_target, const int *index_refer, 138 | const float *grad_output) 139 | { 140 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c); 141 | dim3 threads(THREADS_PER_BLOCK); 142 | attention_fusion_step_backward_cuda_kernel<<>>(m, g, c, 143 | weight, grad_weight, 144 | value, grad_value, 145 | index_target, index_refer, 146 | grad_output); 147 | } 148 | 149 | 150 | -------------------------------------------------------------------------------- /libs/pointops/src/attention/attention_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _ATTENTION_CUDA_KERNEL 2 | #define _ATTENTION_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void attention_relation_step_forward_cuda(int m, int g, int c, 8 | at::Tensor query_tensor, at::Tensor key_tensor, at::Tensor weight_tensor, 9 | at::Tensor index_target_tensor, at::Tensor index_refer_tensor, 10 | at::Tensor output_tensor); 11 | void attention_relation_step_backward_cuda(int m, int g, int c, 12 | at::Tensor query_tensor, at::Tensor grad_query_tensor, 13 | at::Tensor key_tensor, at::Tensor grad_key_tensor, 14 | at::Tensor weight_tensor, at::Tensor grad_weight_tensor, 15 | at::Tensor index_target_tensor, at::Tensor index_refer_tensor, 16 | at::Tensor grad_output_tensor); 17 | void attention_fusion_step_forward_cuda(int m, int g, int c, 18 | at::Tensor weight_tensor, at::Tensor value_tensor, 19 | at::Tensor index_target_tensor, at::Tensor index_refer_tensor, 20 | at::Tensor output_tensor); 21 | void attention_fusion_step_backward_cuda(int m, int g, int c, 22 | at::Tensor weight_tensor, at::Tensor grad_weight_tensor, 23 | at::Tensor value_tensor, at::Tensor grad_value_tensor, 24 | at::Tensor index_target_tensor, at::Tensor index_refer_tensor, 25 | at::Tensor grad_output_tensor); 26 | 27 | #ifdef __cplusplus 28 | extern "C" { 29 | #endif 30 | 31 | void attention_relation_step_forward_cuda_launcher(int m, int g, int c, 32 | const float *query, const float *key, const float *weight, 33 | const int *index_target, const int *index_refer, 34 | float *output); 35 | void attention_relation_step_backward_cuda_launcher(int m, int g, int c, 36 | const float *query, float *grad_query, 37 | const float *key, float *grad_key, 38 | const float *weight, float *grad_weight, 39 | const int *index_target, const int *index_refer, 40 | const float *grad_output); 41 | void attention_fusion_step_forward_cuda_launcher(int m, int g, int c, 42 | const float *weight, const float *value, 43 | const int *index_target, const int *index_refer, 44 | float *output); 45 | void attention_fusion_step_backward_cuda_launcher(int m, int g, int c, 46 | const float *weight, float *grad_weight, 47 | const float *value, float *grad_value, 48 | const int *index_target, const int *index_refer, 49 | const float *grad_output); 50 | 51 | #ifdef __cplusplus 52 | } 53 | #endif 54 | #endif 55 | 
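The four attention kernels above are thin gather/scatter primitives used by PTv3-object's attention. As a reading aid only (this reference is not part of the repository), the two forward steps can be written densely in NumPy under the shapes implied by the kernel indexing: `query`/`key`/`value` are (n, g, c), the relation `weight` is (c,), the fusion `weight` is (m, g), and `index_target`/`index_refer` are int arrays of length m:

```python
import numpy as np

def attention_relation_step_ref(query, key, weight, index_target, index_refer):
    # query, key: (n, g, c); weight: (c,); index_*: (m,) -> relation: (m, g)
    q = query[index_target]  # gather target rows -> (m, g, c)
    k = key[index_refer]     # gather reference rows -> (m, g, c)
    # channel-weighted dot product per relation and per group
    return np.einsum("mgc,mgc,c->mg", q, k, weight)

def attention_fusion_step_ref(weight, value, index_target, index_refer, n):
    # weight: (m, g); value: (n, g, c) -> output: (n, g, c)
    out = np.zeros((n,) + value.shape[1:], dtype=value.dtype)
    # scatter-add: several relations may share one target row,
    # which is why the CUDA kernel accumulates with atomicAdd
    np.add.at(out, index_target, weight[:, :, None] * value[index_refer])
    return out
```

The backward kernels follow the same index pattern and accumulate gradients for `query`, `key`, `weight`, and `value` via `atomicAdd`, mirroring the product rule applied to these two expressions.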
-------------------------------------------------------------------------------- /libs/pointops/src/ball_query/ball_query_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include <vector> 2 | #include <torch/serialize/tensor.h> 3 | #include <ATen/cuda/CUDAContext.h> 4 | #include "ball_query_cuda_kernel.h" 5 | 6 | 7 | void ball_query_cuda(int m, int nsample, 8 | float min_radius, float max_radius, 9 | at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, 10 | at::Tensor offset_tensor, at::Tensor new_offset_tensor, 11 | at::Tensor idx_tensor, at::Tensor dist2_tensor) 12 | { 13 | const float *xyz = xyz_tensor.data_ptr<float>(); 14 | const float *new_xyz = new_xyz_tensor.data_ptr<float>(); 15 | const int *offset = offset_tensor.data_ptr<int>(); 16 | const int *new_offset = new_offset_tensor.data_ptr<int>(); 17 | int *idx = idx_tensor.data_ptr<int>(); 18 | float *dist2 = dist2_tensor.data_ptr<float>(); 19 | ball_query_cuda_launcher(m, nsample, min_radius, max_radius, xyz, new_xyz, offset, new_offset, idx, dist2); 20 | } 21 | -------------------------------------------------------------------------------- /libs/pointops/src/ball_query/ball_query_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "ball_query_cuda_kernel.h" 3 | 4 | 5 | namespace ball_query_utils{ 6 | 7 | template <typename DType> 8 | __device__ void swap(DType *x, DType *y) 9 | { 10 | DType tmp = *x; 11 | *x = *y; 12 | *y = tmp; 13 | } 14 | 15 | __device__ void reheap(float *dist, int *idx, int k) 16 | { 17 | int root = 0; 18 | int child = root * 2 + 1; 19 | while (child < k) 20 | { 21 | if(child + 1 < k && dist[child+1] > dist[child]) 22 | child++; 23 | if(dist[root] > dist[child]) 24 | return; 25 | swap(&dist[root], &dist[child]); 26 | swap(&idx[root], &idx[child]); 27 | root = child; 28 | child = root * 2 + 1; 29 | } 30 | } 31 | 32 | 33 | __device__ void heap_sort(float *dist, int *idx, int k) 34 | { 35 | int i; 36 | for (i = k - 1; i > 0; i--) 37 | { 38 | swap(&dist[0], &dist[i]); 39 | swap(&idx[0], &idx[i]); 40 | reheap(dist, idx, i); 41 | } 42 | } 43 | 44 | __device__ int get_bt_idx(int idx, const int *offset) 45 | { 46 | int i = 0; 47 | while (1) 48 | { 49 | if (idx < offset[i]) 50 | break; 51 | else 52 | i++; 53 | } 54 | return i; 55 | } 56 | } // namespace ball_query_utils 57 | 58 | __global__ void ball_query_cuda_kernel(int m, int nsample, 59 | float min_radius, float max_radius, 60 | const float *__restrict__ xyz, const float *__restrict__ new_xyz, 61 | const int *__restrict__ offset, const int *__restrict__ new_offset, 62 | int *__restrict__ idx, float *__restrict__ dist2) { 63 | // input: xyz (n, 3) new_xyz (m, 3) 64 | // output: idx (m, nsample) dist (m, nsample) 65 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 66 | if (pt_idx >= m) return; 67 | 68 | new_xyz += pt_idx * 3; 69 | idx += pt_idx * nsample; 70 | dist2 += pt_idx * nsample; 71 | 72 | int bt_idx = ball_query_utils::get_bt_idx(pt_idx, new_offset); 73 | int start; 74 | if (bt_idx == 0) 75 | start = 0; 76 | else 77 | start = offset[bt_idx - 1]; 78 | int end = offset[bt_idx]; 79 | 80 | float max_radius2 = max_radius * max_radius; 81 | float min_radius2 = min_radius * min_radius; 82 | float new_x = new_xyz[0]; 83 | float new_y = new_xyz[1]; 84 | float new_z = new_xyz[2]; 85 | 86 | float candi_dist[2048]; 87 | int candi_idx[2048]; 88 | int candi_num = 0; 89 | 90 | for(int i = start; i < end; i++){ 91 | float x = xyz[i * 3 + 0]; 92 | float y = xyz[i * 3 + 1]; 93 | float z = xyz[i * 3 + 2]; 94 | float d2 = (new_x - x) * (new_x - x) + (new_y
- y) * (new_y - y) + (new_z - z) * (new_z - z); 95 | 96 | if (d2 <= 1e-5 || (d2 >= min_radius2 && d2 < max_radius2)){ 97 | // TODO: Check d2 <= 1e-5 98 | candi_dist[candi_num] = d2; 99 | candi_idx[candi_num] = i; 100 | candi_num += 1; 101 | } 102 | } 103 | ball_query_utils::heap_sort(candi_dist, candi_idx, candi_num); 104 | if(candi_num <= nsample){ 105 | for(int i = 0; i < candi_num; i++){ 106 | idx[i] = candi_idx[i]; 107 | dist2[i] = candi_dist[i]; 108 | } 109 | for(int i = candi_num; i < nsample; i++){ 110 | idx[i] = -1; 111 | dist2[i] = 1e10; 112 | } 113 | } 114 | else{ 115 | float sep = static_cast<float>(candi_num) / nsample; 116 | for(int i = 0; i < nsample; i++) 117 | { 118 | int index = static_cast<int>(sep * i); 119 | idx[i] = candi_idx[index]; 120 | dist2[i] = candi_dist[index]; 121 | } 122 | } 123 | } 124 | 125 | /* Random Sample Mode Ball Query */ 126 | 127 | // __global__ void ball_query_cuda_kernel(int m, int nsample, 128 | // float min_radius, float max_radius, 129 | // const float *__restrict__ xyz, const float *__restrict__ new_xyz, 130 | // const int *__restrict__ offset, const int *__restrict__ new_offset, 131 | // int *__restrict__ idx, float *__restrict__ dist2) { 132 | // // input: xyz (n, 3) new_xyz (m, 3) 133 | // // output: idx (m, nsample) dist (m, nsample) 134 | // int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 135 | // if (pt_idx >= m) return; 136 | // 137 | // new_xyz += pt_idx * 3; 138 | // idx += pt_idx * nsample; 139 | // dist2 += pt_idx * nsample; 140 | // 141 | // int bt_idx = ball_get_bt_idx(pt_idx, new_offset); 142 | // int start; 143 | // if (bt_idx == 0) 144 | // start = 0; 145 | // else 146 | // start = offset[bt_idx - 1]; 147 | // int end = offset[bt_idx]; 148 | // 149 | // float max_radius2 = max_radius * max_radius; 150 | // float min_radius2 = min_radius * min_radius; 151 | // float new_x = new_xyz[0]; 152 | // float new_y = new_xyz[1]; 153 | // float new_z = new_xyz[2]; 154 | // 155 | // int cnt = 0; 156 | // for(int i = start; i < end; i++){ 157 | // float x = xyz[i * 3 + 0]; 158 | // float y = xyz[i * 3 + 1]; 159 | // float z = xyz[i * 3 + 2]; 160 | // float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); 161 | // 162 | // if (d2 == 0 || (d2 >= min_radius2 && d2 < max_radius2)) { 163 | // if (cnt == 0) { 164 | // for (int l = 0; l < nsample; ++l) { 165 | // idx[l] = i; 166 | // dist2[l] = d2; 167 | // } 168 | // } 169 | // idx[cnt] = i; 170 | // ++cnt; 171 | // if (cnt >= nsample) break; 172 | // } 173 | // } 174 | // } 175 | 176 | 177 | void ball_query_cuda_launcher(int m, int nsample, 178 | float min_radius, float max_radius, 179 | const float *xyz, const float *new_xyz, 180 | const int *offset, const int *new_offset, 181 | int *idx, float *dist2) { 182 | // input: new_xyz: (m, 3), xyz: (n, 3), idx: (m, nsample) 183 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK)); 184 | dim3 threads(THREADS_PER_BLOCK); 185 | ball_query_cuda_kernel<<<blocks, threads, 0, 0>>>(m, nsample, 186 | min_radius, max_radius, 187 | xyz, new_xyz, 188 | offset, new_offset, 189 | idx, dist2); 190 | } 191 | -------------------------------------------------------------------------------- /libs/pointops/src/ball_query/ball_query_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _BALL_QUERY_CUDA_KERNEL 2 | #define _BALL_QUERY_CUDA_KERNEL 3 | #include <vector> 4 | #include <torch/serialize/tensor.h> 5 | #include <ATen/cuda/CUDAContext.h> 6 | 7 | void ball_query_cuda(int m, int nsample, 8 | float min_radius, float max_radius, 9 | at::Tensor
new_xyz_tensor, 10 | at::Tensor offset_tensor, at::Tensor new_offset_tensor, 11 | at::Tensor idx_tensor, at::Tensor dist2_tensor); 12 | 13 | #ifdef __cplusplus 14 | extern "C" { 15 | #endif 16 | 17 | void ball_query_cuda_launcher(int m, int nsample, 18 | float min_radius, float max_radius, 19 | const float *xyz, const float *new_xyz, 20 | const int *offset, const int *new_offset, 21 | int *idx, float *dist2); 22 | 23 | #ifdef __cplusplus 24 | } 25 | #endif 26 | #endif 27 | -------------------------------------------------------------------------------- /libs/pointops/src/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | #include 6 | 7 | #define TOTAL_THREADS 1024 8 | #define THREADS_PER_BLOCK 512 9 | #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) 10 | 11 | inline int opt_n_threads(int work_size) { 12 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 13 | return std::max(std::min(1 << pow_2, TOTAL_THREADS), 1); 14 | } 15 | 16 | inline dim3 opt_block_config(int x, int y) { 17 | const int x_threads = opt_n_threads(x); 18 | const int y_threads = std::max(std::min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 19 | dim3 block_config(x_threads, y_threads, 1); 20 | return block_config; 21 | } 22 | 23 | #endif 24 | -------------------------------------------------------------------------------- /libs/pointops/src/grouping/grouping_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "grouping_cuda_kernel.h" 5 | 6 | 7 | void grouping_forward_cuda(int m, int nsample, int c, at::Tensor input_tensor, at::Tensor idx_tensor, at::Tensor output_tensor) 8 | { 9 | const float *input = input_tensor.data_ptr(); 10 | const int *idx = idx_tensor.data_ptr(); 11 | float *output = output_tensor.data_ptr(); 12 | grouping_forward_cuda_launcher(m, nsample, c, input, idx, output); 13 | } 14 | 15 | void grouping_backward_cuda(int m, int nsample, int c, at::Tensor grad_output_tensor, at::Tensor idx_tensor, at::Tensor grad_input_tensor) 16 | { 17 | const float *grad_output = grad_output_tensor.data_ptr(); 18 | const int *idx = idx_tensor.data_ptr(); 19 | float *grad_input = grad_input_tensor.data_ptr(); 20 | grouping_backward_cuda_launcher(m, nsample, c, grad_output, idx, grad_input); 21 | } 22 | -------------------------------------------------------------------------------- /libs/pointops/src/grouping/grouping_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "grouping_cuda_kernel.h" 3 | 4 | 5 | __global__ void grouping_forward_cuda_kernel(int m, int nsample, int c, const float *__restrict__ input, const int *__restrict__ idx, float *__restrict__ output) { 6 | // input: input: (n, c), idx: (m, nsample), output: (m, nsample, c) 7 | int index = blockIdx.x * blockDim.x + threadIdx.x; 8 | if (index >= m * nsample * c) return; 9 | const int c_idx = index % c; 10 | const int nsample_idx = (index / c) % nsample; 11 | const int m_idx = index / nsample / c; 12 | const int input_idx = idx[m_idx * nsample + nsample_idx] * c + c_idx; 13 | output[index] = input[input_idx]; 14 | } 15 | 16 | __global__ void grouping_backward_cuda_kernel(int m, int nsample, int c, const float *__restrict__ grad_output, const int *__restrict__ idx, float *__restrict__ grad_input) { 17 | // input: grad_output: (m, nsample, c), idx: 
(m, nsample), output: grad_input: (n, c) 18 | int index = blockIdx.x * blockDim.x + threadIdx.x; 19 | if (index >= m * nsample * c) return; 20 | const int c_idx = index % c; 21 | const int nsample_idx = (index / c) % nsample; 22 | const int m_idx = index / nsample / c; 23 | const int input_idx = idx[m_idx * nsample + nsample_idx] * c + c_idx; 24 | atomicAdd(grad_input + input_idx, grad_output[index]); 25 | } 26 | 27 | void grouping_forward_cuda_launcher(int m, int nsample, int c, const float *input, const int *idx, float *output) { 28 | // input: input: (n, c), idx: (m, nsample), output: (m, nsample, c) 29 | dim3 blocks(DIVUP(m * nsample * c, THREADS_PER_BLOCK)); 30 | dim3 threads(THREADS_PER_BLOCK); 31 | grouping_forward_cuda_kernel<<>>(m, nsample, c, input, idx, output); 32 | } 33 | 34 | void grouping_backward_cuda_launcher(int m, int nsample, int c, const float *grad_output, const int *idx, float *grad_input) 35 | { 36 | // input: grad_output: (m, nsample, c), idx: (m, nsample), output: grad_input: (n, c) 37 | dim3 blocks(DIVUP(m * nsample * c, THREADS_PER_BLOCK)); 38 | dim3 threads(THREADS_PER_BLOCK); 39 | grouping_backward_cuda_kernel<<>>(m, nsample, c, grad_output, idx, grad_input); 40 | } 41 | -------------------------------------------------------------------------------- /libs/pointops/src/grouping/grouping_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _GROUPING_CUDA_KERNEL 2 | #define _GROUPING_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void grouping_forward_cuda(int m, int nsample, int c, at::Tensor input_tensor, at::Tensor idx_tensor, at::Tensor output_tensor); 8 | void grouping_backward_cuda(int m, int nsample, int c, at::Tensor grad_output_tensor, at::Tensor idx_tensor, at::Tensor grad_input_tensor); 9 | 10 | #ifdef __cplusplus 11 | extern "C" { 12 | #endif 13 | 14 | void grouping_forward_cuda_launcher(int m, int nsample, int c, const float *input, const int *idx, float *output); 15 | void grouping_backward_cuda_launcher(int m, int nsample, int c, const float *grad_output, const int *idx, float *grad_input); 16 | 17 | #ifdef __cplusplus 18 | } 19 | #endif 20 | #endif 21 | -------------------------------------------------------------------------------- /libs/pointops/src/interpolation/interpolation_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "interpolation_cuda_kernel.h" 5 | 6 | 7 | void interpolation_forward_cuda(int n, int c, int k, at::Tensor input_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor output_tensor) 8 | { 9 | const float *input = input_tensor.data_ptr(); 10 | const int *idx = idx_tensor.data_ptr(); 11 | const float *weight = weight_tensor.data_ptr(); 12 | float *output = output_tensor.data_ptr(); 13 | interpolation_forward_cuda_launcher(n, c, k, input, idx, weight, output); 14 | } 15 | 16 | void interpolation_backward_cuda(int n, int c, int k, at::Tensor grad_output_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_input_tensor) 17 | { 18 | const float *grad_output = grad_output_tensor.data_ptr(); 19 | const int *idx = idx_tensor.data_ptr(); 20 | const float *weight = weight_tensor.data_ptr(); 21 | float *grad_input = grad_input_tensor.data_ptr(); 22 | interpolation_backward_cuda_launcher(n, c, k, grad_output, idx, weight, grad_input); 23 | } 24 | -------------------------------------------------------------------------------- 
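The interpolation module propagates features from m source points to n query points as a weighted sum over each query's k nearest sources (the weights are typically normalized inverse distances, computed on the Python side). A minimal NumPy sketch of the forward pass, given as an illustration rather than repository code, assuming `input_feat`: (m, c), `idx`: (n, k) int, `weight`: (n, k):

```python
import numpy as np

def interpolation_forward_ref(input_feat, idx, weight):
    # input_feat: (m, c); idx: (n, k); weight: (n, k) -> output: (n, c)
    # output[i] = sum_k weight[i, k] * input_feat[idx[i, k]]
    return np.einsum("nk,nkc->nc", weight, input_feat[idx])
```

The backward kernel in the next file scatter-adds `grad_output[i] * weight[i, k]` into row `idx[i, k]` of `grad_input` with `atomicAdd`, since several queries can reference the same source point.

--------------------------------------------------------------------------------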
/libs/pointops/src/interpolation/interpolation_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "interpolation_cuda_kernel.h" 3 | 4 | 5 | __global__ void interpolation_forward_cuda_kernel(int n, int c, int k, const float *input, const int *idx, const float *weight, float *output) 6 | { 7 | // input: input: (m, c), idx: (n, k), weight: (n, k), output: output (n, c) 8 | int index = blockIdx.x * blockDim.x + threadIdx.x; 9 | if (index >= n * c) return; 10 | int c_idx = index % c; 11 | int n_idx = index / c; 12 | for (int i = 0; i < k; i++) 13 | { 14 | int idx_idx = n_idx * k + i; 15 | int input_idx = idx[idx_idx] * c + c_idx; 16 | output[index] += input[input_idx] * weight[idx_idx]; 17 | } 18 | } 19 | 20 | __global__ void interpolation_backward_cuda_kernel(int n, int c, int k, const float *grad_output, const int *idx, const float *weight, float *grad_input) 21 | { 22 | // input: grad_output: (n, c), idx: (n, k), weight: (n, k), output: grad_input (m, c) 23 | int index = blockIdx.x * blockDim.x + threadIdx.x; 24 | if (index >= n * c) return; 25 | int c_idx = index % c; 26 | int n_idx = index / c; 27 | for (int i = 0; i < k; i++) 28 | { 29 | int idx_idx = n_idx * k + i; 30 | int input_idx = idx[idx_idx] * c + c_idx; 31 | atomicAdd(grad_input + input_idx, grad_output[index] * weight[idx_idx]); 32 | } 33 | } 34 | 35 | void interpolation_forward_cuda_launcher(int n, int c, int k, const float *input, const int *idx, const float *weight, float *output) { 36 | // input: input: (m, c), idx: (n, k), weight: (n, k), output: output (n, c) 37 | dim3 blocks(DIVUP(n * c, THREADS_PER_BLOCK)); 38 | dim3 threads(THREADS_PER_BLOCK); 39 | interpolation_forward_cuda_kernel<<>>(n, c, k, input, idx, weight, output); 40 | } 41 | 42 | void interpolation_backward_cuda_launcher(int n, int c, int k, const float *grad_output, const int *idx, const float *weight, float *grad_input) { 43 | // input: grad_output: (n, c), idx: (n, k), weight: (n, k), output: grad_input (m, c) 44 | dim3 blocks(DIVUP(n * c, THREADS_PER_BLOCK)); 45 | dim3 threads(THREADS_PER_BLOCK); 46 | interpolation_backward_cuda_kernel<<>>(n, c, k, grad_output, idx, weight, grad_input); 47 | } 48 | -------------------------------------------------------------------------------- /libs/pointops/src/interpolation/interpolation_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _INTERPOLATION_CUDA_KERNEL 2 | #define _INTERPOLATION_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void interpolation_forward_cuda(int n, int c, int k, at::Tensor input_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor output_tensor); 8 | void interpolation_backward_cuda(int n, int c, int k, at::Tensor grad_output_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_input_tensor); 9 | 10 | #ifdef __cplusplus 11 | extern "C" { 12 | #endif 13 | 14 | void interpolation_forward_cuda_launcher(int n, int c, int k, const float *input, const int *idx, const float *weight, float *output); 15 | void interpolation_backward_cuda_launcher(int n, int c, int k, const float *grad_output, const int *idx, const float *weight, float *grad_input); 16 | 17 | #ifdef __cplusplus 18 | } 19 | #endif 20 | #endif 21 | -------------------------------------------------------------------------------- /libs/pointops/src/knn_query/knn_query_cuda.cpp: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "knn_query_cuda_kernel.h" 5 | 6 | 7 | void knn_query_cuda(int m, int nsample, at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, at::Tensor offset_tensor, at::Tensor new_offset_tensor, at::Tensor idx_tensor, at::Tensor dist2_tensor) 8 | { 9 | const float *xyz = xyz_tensor.data_ptr(); 10 | const float *new_xyz = new_xyz_tensor.data_ptr(); 11 | const int *offset = offset_tensor.data_ptr(); 12 | const int *new_offset = new_offset_tensor.data_ptr(); 13 | int *idx = idx_tensor.data_ptr(); 14 | float *dist2 = dist2_tensor.data_ptr(); 15 | knn_query_cuda_launcher(m, nsample, xyz, new_xyz, offset, new_offset, idx, dist2); 16 | } 17 | -------------------------------------------------------------------------------- /libs/pointops/src/knn_query/knn_query_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "knn_query_cuda_kernel.h" 3 | 4 | 5 | namespace knn_query_utils{ 6 | 7 | template 8 | __device__ void swap(DType *x, DType *y) 9 | { 10 | DType tmp = *x; 11 | *x = *y; 12 | *y = tmp; 13 | } 14 | 15 | __device__ void reheap(float *dist, int *idx, int k) 16 | { 17 | int root = 0; 18 | int child = root * 2 + 1; 19 | while (child < k) 20 | { 21 | if(child + 1 < k && dist[child+1] > dist[child]) 22 | child++; 23 | if(dist[root] > dist[child]) 24 | return; 25 | swap(&dist[root], &dist[child]); 26 | swap(&idx[root], &idx[child]); 27 | root = child; 28 | child = root * 2 + 1; 29 | } 30 | } 31 | 32 | 33 | __device__ void heap_sort(float *dist, int *idx, int k) 34 | { 35 | int i; 36 | for (i = k - 1; i > 0; i--) 37 | { 38 | swap(&dist[0], &dist[i]); 39 | swap(&idx[0], &idx[i]); 40 | reheap(dist, idx, i); 41 | } 42 | } 43 | 44 | 45 | __device__ int get_bt_idx(int idx, const int *offset) 46 | { 47 | int i = 0; 48 | while (1) 49 | { 50 | if (idx < offset[i]) 51 | break; 52 | else 53 | i++; 54 | } 55 | return i; 56 | } 57 | } // namespace knn_query_utils 58 | 59 | 60 | __global__ void knn_query_cuda_kernel(int m, int nsample, const float *__restrict__ xyz, const float *__restrict__ new_xyz, const int *__restrict__ offset, const int *__restrict__ new_offset, int *__restrict__ idx, float *__restrict__ dist2) { 61 | // input: xyz (n, 3) new_xyz (m, 3) 62 | // output: idx (m, nsample) dist2 (m, nsample) 63 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 64 | if (pt_idx >= m) return; 65 | 66 | new_xyz += pt_idx * 3; 67 | idx += pt_idx * nsample; 68 | dist2 += pt_idx * nsample; 69 | 70 | int bt_idx = knn_query_utils::get_bt_idx(pt_idx, new_offset); 71 | int start; 72 | if (bt_idx == 0) 73 | start = 0; 74 | else 75 | start = offset[bt_idx - 1]; 76 | int end = offset[bt_idx]; 77 | 78 | float new_x = new_xyz[0]; 79 | float new_y = new_xyz[1]; 80 | float new_z = new_xyz[2]; 81 | 82 | float best_dist[128]; 83 | int best_idx[128]; 84 | for(int i = 0; i < nsample; i++){ 85 | best_dist[i] = 1e10; 86 | best_idx[i] = -1; 87 | } 88 | for(int i = start; i < end; i++){ 89 | float x = xyz[i * 3 + 0]; 90 | float y = xyz[i * 3 + 1]; 91 | float z = xyz[i * 3 + 2]; 92 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); 93 | if (d2 < best_dist[0]){ 94 | best_dist[0] = d2; 95 | best_idx[0] = i; 96 | knn_query_utils::reheap(best_dist, best_idx, nsample); 97 | } 98 | } 99 | knn_query_utils::heap_sort(best_dist, best_idx, nsample); 100 | for(int i = 0; i < nsample; i++){ 101 
| idx[i] = best_idx[i]; 102 | dist2[i] = best_dist[i]; 103 | } 104 | } 105 | 106 | 107 | void knn_query_cuda_launcher(int m, int nsample, const float *xyz, const float *new_xyz, const int *offset, const int *new_offset, int *idx, float *dist2) { 108 | // input: new_xyz: (m, 3), xyz: (n, 3), idx: (m, nsample) 109 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK)); 110 | dim3 threads(THREADS_PER_BLOCK); 111 | knn_query_cuda_kernel<<>>(m, nsample, xyz, new_xyz, offset, new_offset, idx, dist2); 112 | } 113 | -------------------------------------------------------------------------------- /libs/pointops/src/knn_query/knn_query_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _KNN_QUERY_CUDA_KERNEL 2 | #define _KNN_QUERY_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void knn_query_cuda(int m, int nsample, at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, at::Tensor offset_tensor, at::Tensor new_offset_tensor, at::Tensor idx_tensor, at::Tensor dist2_tensor); 8 | 9 | #ifdef __cplusplus 10 | extern "C" { 11 | #endif 12 | 13 | void knn_query_cuda_launcher(int m, int nsample, const float *xyz, const float *new_xyz, const int *offset, const int *new_offset, int *idx, float *dist2); 14 | 15 | #ifdef __cplusplus 16 | } 17 | #endif 18 | #endif 19 | -------------------------------------------------------------------------------- /libs/pointops/src/pointops_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "knn_query/knn_query_cuda_kernel.h" 5 | #include "ball_query/ball_query_cuda_kernel.h" 6 | #include "random_ball_query/random_ball_query_cuda_kernel.h" 7 | #include "sampling/sampling_cuda_kernel.h" 8 | #include "grouping/grouping_cuda_kernel.h" 9 | #include "interpolation/interpolation_cuda_kernel.h" 10 | #include "aggregation/aggregation_cuda_kernel.h" 11 | #include "subtraction/subtraction_cuda_kernel.h" 12 | #include "attention/attention_cuda_kernel.h" 13 | 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("knn_query_cuda", &knn_query_cuda, "knn_query_cuda"); 17 | m.def("ball_query_cuda", &ball_query_cuda, "ball_query_cuda"); 18 | m.def("random_ball_query_cuda", &random_ball_query_cuda, "random_ball_query_cuda"); 19 | m.def("farthest_point_sampling_cuda", &farthest_point_sampling_cuda, "farthest_point_sampling_cuda"); 20 | m.def("grouping_forward_cuda", &grouping_forward_cuda, "grouping_forward_cuda"); 21 | m.def("grouping_backward_cuda", &grouping_backward_cuda, "grouping_backward_cuda"); 22 | m.def("interpolation_forward_cuda", &interpolation_forward_cuda, "interpolation_forward_cuda"); 23 | m.def("interpolation_backward_cuda", &interpolation_backward_cuda, "interpolation_backward_cuda"); 24 | m.def("subtraction_forward_cuda", &subtraction_forward_cuda, "subtraction_forward_cuda"); 25 | m.def("subtraction_backward_cuda", &subtraction_backward_cuda, "subtraction_backward_cuda"); 26 | m.def("aggregation_forward_cuda", &aggregation_forward_cuda, "aggregation_forward_cuda"); 27 | m.def("aggregation_backward_cuda", &aggregation_backward_cuda, "aggregation_backward_cuda"); 28 | m.def("attention_relation_step_forward_cuda", &attention_relation_step_forward_cuda, "attention_relation_step_forward_cuda"); 29 | m.def("attention_relation_step_backward_cuda", &attention_relation_step_backward_cuda, "attention_relation_step_backward_cuda"); 30 | m.def("attention_fusion_step_forward_cuda", &attention_fusion_step_forward_cuda, 
"attention_fusion_step_forward_cuda"); 31 | m.def("attention_fusion_step_backward_cuda", &attention_fusion_step_backward_cuda, "attention_fusion_step_backward_cuda"); 32 | } 33 | -------------------------------------------------------------------------------- /libs/pointops/src/random_ball_query/random_ball_query_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "random_ball_query_cuda_kernel.h" 5 | 6 | 7 | void random_ball_query_cuda(int m, int nsample, 8 | float min_radius, float max_radius, at::Tensor order_tensor, 9 | at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, 10 | at::Tensor offset_tensor, at::Tensor new_offset_tensor, 11 | at::Tensor idx_tensor, at::Tensor dist2_tensor) 12 | { 13 | const int *order = order_tensor.data_ptr(); 14 | const float *xyz = xyz_tensor.data_ptr(); 15 | const float *new_xyz = new_xyz_tensor.data_ptr(); 16 | const int *offset = offset_tensor.data_ptr(); 17 | const int *new_offset = new_offset_tensor.data_ptr(); 18 | int *idx = idx_tensor.data_ptr(); 19 | float *dist2 = dist2_tensor.data_ptr(); 20 | random_ball_query_cuda_launcher(m, nsample, min_radius, max_radius, order, xyz, new_xyz, offset, new_offset, idx, dist2); 21 | } 22 | -------------------------------------------------------------------------------- /libs/pointops/src/random_ball_query/random_ball_query_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "random_ball_query_cuda_kernel.h" 3 | 4 | 5 | namespace random_ball_query_utils{ 6 | 7 | template 8 | __device__ void swap(DType *x, DType *y) 9 | { 10 | DType tmp = *x; 11 | *x = *y; 12 | *y = tmp; 13 | } 14 | 15 | __device__ void reheap(float *dist, int *idx, int k) 16 | { 17 | int root = 0; 18 | int child = root * 2 + 1; 19 | while (child < k) 20 | { 21 | if(child + 1 < k && dist[child+1] > dist[child]) 22 | child++; 23 | if(dist[root] > dist[child]) 24 | return; 25 | swap(&dist[root], &dist[child]); 26 | swap(&idx[root], &idx[child]); 27 | root = child; 28 | child = root * 2 + 1; 29 | } 30 | } 31 | 32 | 33 | __device__ void heap_sort(float *dist, int *idx, int k) 34 | { 35 | int i; 36 | for (i = k - 1; i > 0; i--) 37 | { 38 | swap(&dist[0], &dist[i]); 39 | swap(&idx[0], &idx[i]); 40 | reheap(dist, idx, i); 41 | } 42 | } 43 | 44 | __device__ int get_bt_idx(int idx, const int *offset) 45 | { 46 | int i = 0; 47 | while (1) 48 | { 49 | if (idx < offset[i]) 50 | break; 51 | else 52 | i++; 53 | } 54 | return i; 55 | } 56 | } // namespace ball_query_utils 57 | 58 | __global__ void random_ball_query_cuda_kernel(int m, int nsample, 59 | float min_radius, float max_radius, const int *__restrict__ order, 60 | const float *__restrict__ xyz, const float *__restrict__ new_xyz, 61 | const int *__restrict__ offset, const int *__restrict__ new_offset, 62 | int *__restrict__ idx, float *__restrict__ dist2) { 63 | // input: xyz (n, 3) new_xyz (m, 3) 64 | // output: idx (m, nsample) dist (m, nsample) 65 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 66 | if (pt_idx >= m) return; 67 | 68 | new_xyz += pt_idx * 3; 69 | idx += pt_idx * nsample; 70 | dist2 += pt_idx * nsample; 71 | 72 | int bt_idx = random_ball_query_utils::get_bt_idx(pt_idx, new_offset); 73 | int start; 74 | if (bt_idx == 0) 75 | start = 0; 76 | else 77 | start = offset[bt_idx - 1]; 78 | int end = offset[bt_idx]; 79 | 80 | float max_radius2 = max_radius * max_radius; 81 | float min_radius2 = min_radius * 
min_radius; 82 | float new_x = new_xyz[0]; 83 | float new_y = new_xyz[1]; 84 | float new_z = new_xyz[2]; 85 | 86 | int cnt = 0; 87 | 88 | for(int i = start; i < end; i++){ 89 | float x = xyz[order[i] * 3 + 0]; 90 | float y = xyz[order[i] * 3 + 1]; 91 | float z = xyz[order[i] * 3 + 2]; 92 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); 93 | 94 | if (d2 <= 1e-5 || (d2 >= min_radius2 && d2 < max_radius2)){ 95 | dist2[cnt] = d2; 96 | idx[cnt] = order[i]; 97 | cnt += 1; 98 | if (cnt >= nsample) break; 99 | } 100 | } 101 | 102 | if (cnt < nsample) { 103 | for (int i = cnt; i < nsample; i++){ 104 | idx[i] = -1; 105 | dist2[i] = 1e10; 106 | } 107 | } 108 | } 109 | 110 | void random_ball_query_cuda_launcher(int m, int nsample, 111 | float min_radius, float max_radius, const int *order, 112 | const float *xyz, const float *new_xyz, 113 | const int *offset, const int *new_offset, 114 | int *idx, float *dist2) { 115 | // input: new_xyz: (m, 3), xyz: (n, 3), idx: (m, nsample) 116 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK)); 117 | dim3 threads(THREADS_PER_BLOCK); 118 | random_ball_query_cuda_kernel<<>>(m, nsample, 119 | min_radius, max_radius, order, 120 | xyz, new_xyz, 121 | offset, new_offset, 122 | idx, dist2); 123 | } 124 | -------------------------------------------------------------------------------- /libs/pointops/src/random_ball_query/random_ball_query_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _RANDOM_BALL_QUERY_CUDA_KERNEL 2 | #define _RANDOM_BALL_QUERY_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void random_ball_query_cuda(int m, int nsample, 8 | float min_radius, float max_radius, at::Tensor order_tensor, 9 | at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, 10 | at::Tensor offset_tensor, at::Tensor new_offset_tensor, 11 | at::Tensor idx_tensor, at::Tensor dist2_tensor); 12 | 13 | #ifdef __cplusplus 14 | extern "C" { 15 | #endif 16 | 17 | void random_ball_query_cuda_launcher(int m, int nsample, 18 | float min_radius, float max_radius, const int *order, 19 | const float *xyz, const float *new_xyz, 20 | const int *offset, const int *new_offset, 21 | int *idx, float *dist2); 22 | 23 | #ifdef __cplusplus 24 | } 25 | #endif 26 | #endif 27 | -------------------------------------------------------------------------------- /libs/pointops/src/sampling/sampling_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "sampling_cuda_kernel.h" 5 | 6 | 7 | void farthest_point_sampling_cuda(int b, int n, at::Tensor xyz_tensor, at::Tensor offset_tensor, at::Tensor new_offset_tensor, at::Tensor tmp_tensor, at::Tensor idx_tensor) 8 | { 9 | const float *xyz = xyz_tensor.data_ptr(); 10 | const int *offset = offset_tensor.data_ptr(); 11 | const int *new_offset = new_offset_tensor.data_ptr(); 12 | float *tmp = tmp_tensor.data_ptr(); 13 | int *idx = idx_tensor.data_ptr(); 14 | farthest_point_sampling_cuda_launcher(b, n, xyz, offset, new_offset, tmp, idx); 15 | } 16 | -------------------------------------------------------------------------------- /libs/pointops/src/sampling/sampling_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "sampling_cuda_kernel.h" 3 | 4 | 5 | __device__ void __update(float *dists, int *dists_i, int idx1, int idx2) { 6 | const float v1 = dists[idx1], v2 = dists[idx2]; 7 | const int i1 = 
dists_i[idx1], i2 = dists_i[idx2]; 8 | dists[idx1] = max(v1, v2); 9 | dists_i[idx1] = v2 > v1 ? i2 : i1; 10 | } 11 | 12 | // input xyz: (n, 3), tmp: (b, n_max) 13 | // ouput idx (m) 14 | template 15 | __global__ void farthest_point_sampling_cuda_kernel(const float *xyz, const int *offset, const int *new_offset, float *tmp, int *idx) 16 | { 17 | __shared__ float dists[block_size]; 18 | __shared__ int dists_i[block_size]; 19 | 20 | int bid = blockIdx.x; 21 | int start_n, end_n, start_m, end_m, old; 22 | if (bid == 0) { 23 | start_n = 0; 24 | end_n = offset[0]; 25 | start_m = 0; 26 | end_m = new_offset[0]; 27 | old = 0; 28 | } 29 | else { 30 | start_n = offset[bid - 1]; 31 | end_n = offset[bid]; 32 | start_m = new_offset[bid - 1]; 33 | end_m = new_offset[bid]; 34 | old = offset[bid - 1]; 35 | } 36 | 37 | const int stride = block_size; 38 | int tid = threadIdx.x; 39 | if (tid == 0) idx[start_m] = start_n; 40 | 41 | __syncthreads(); 42 | for (int j = start_m + 1; j < end_m; j++) 43 | { 44 | int besti = start_n; 45 | float best = -1; 46 | float x1 = xyz[old * 3 + 0]; 47 | float y1 = xyz[old * 3 + 1]; 48 | float z1 = xyz[old * 3 + 2]; 49 | for (int k = start_n + tid; k < end_n; k += stride) 50 | { 51 | float x2 = xyz[k * 3 + 0]; 52 | float y2 = xyz[k * 3 + 1]; 53 | float z2 = xyz[k * 3 + 2]; 54 | float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 55 | float d2 = min(d, tmp[k]); 56 | tmp[k] = d2; 57 | besti = d2 > best ? k : besti; 58 | best = d2 > best ? d2 : best; 59 | } 60 | dists[tid] = best; 61 | dists_i[tid] = besti; 62 | __syncthreads(); 63 | 64 | if (block_size >= 1024) { 65 | if (tid < 512) { 66 | __update(dists, dists_i, tid, tid + 512); 67 | } 68 | __syncthreads(); 69 | } 70 | if (block_size >= 512) { 71 | if (tid < 256) { 72 | __update(dists, dists_i, tid, tid + 256); 73 | } 74 | __syncthreads(); 75 | } 76 | if (block_size >= 256) { 77 | if (tid < 128) { 78 | __update(dists, dists_i, tid, tid + 128); 79 | } 80 | __syncthreads(); 81 | } 82 | if (block_size >= 128) { 83 | if (tid < 64) { 84 | __update(dists, dists_i, tid, tid + 64); 85 | } 86 | __syncthreads(); 87 | } 88 | if (block_size >= 64) { 89 | if (tid < 32) { 90 | __update(dists, dists_i, tid, tid + 32); 91 | } 92 | __syncthreads(); 93 | } 94 | if (block_size >= 32) { 95 | if (tid < 16) { 96 | __update(dists, dists_i, tid, tid + 16); 97 | } 98 | __syncthreads(); 99 | } 100 | if (block_size >= 16) { 101 | if (tid < 8) { 102 | __update(dists, dists_i, tid, tid + 8); 103 | } 104 | __syncthreads(); 105 | } 106 | if (block_size >= 8) { 107 | if (tid < 4) { 108 | __update(dists, dists_i, tid, tid + 4); 109 | } 110 | __syncthreads(); 111 | } 112 | if (block_size >= 4) { 113 | if (tid < 2) { 114 | __update(dists, dists_i, tid, tid + 2); 115 | } 116 | __syncthreads(); 117 | } 118 | if (block_size >= 2) { 119 | if (tid < 1) { 120 | __update(dists, dists_i, tid, tid + 1); 121 | } 122 | __syncthreads(); 123 | } 124 | 125 | old = dists_i[0]; 126 | if (tid == 0) 127 | idx[j] = old; 128 | } 129 | } 130 | 131 | void farthest_point_sampling_cuda_launcher(int b, int n, const float *xyz, const int *offset, const int *new_offset, float *tmp, int *idx) 132 | { 133 | unsigned int n_threads = opt_n_threads(n); 134 | switch (n_threads) { 135 | case 1024: 136 | farthest_point_sampling_cuda_kernel<1024><<>>(xyz, offset, new_offset, tmp, idx); 137 | break; 138 | case 512: 139 | farthest_point_sampling_cuda_kernel<512><<>>(xyz, offset, new_offset, tmp, idx); 140 | break; 141 | case 256: 142 | 
farthest_point_sampling_cuda_kernel<256><<>>(xyz, offset, new_offset, tmp, idx); 143 | break; 144 | case 128: 145 | farthest_point_sampling_cuda_kernel<128><<>>(xyz, offset, new_offset, tmp, idx); 146 | break; 147 | case 64: 148 | farthest_point_sampling_cuda_kernel<64><<>>(xyz, offset, new_offset, tmp, idx); 149 | break; 150 | case 32: 151 | farthest_point_sampling_cuda_kernel<32><<>>(xyz, offset, new_offset, tmp, idx); 152 | break; 153 | case 16: 154 | farthest_point_sampling_cuda_kernel<16><<>>(xyz, offset, new_offset, tmp, idx); 155 | break; 156 | case 8: 157 | farthest_point_sampling_cuda_kernel<8><<>>(xyz, offset, new_offset, tmp, idx); 158 | break; 159 | case 4: 160 | farthest_point_sampling_cuda_kernel<4><<>>(xyz, offset, new_offset, tmp, idx); 161 | break; 162 | case 2: 163 | farthest_point_sampling_cuda_kernel<2><<>>(xyz, offset, new_offset, tmp, idx); 164 | break; 165 | case 1: 166 | farthest_point_sampling_cuda_kernel<1><<>>(xyz, offset, new_offset, tmp, idx); 167 | break; 168 | default: 169 | farthest_point_sampling_cuda_kernel<512><<>>(xyz, offset, new_offset, tmp, idx); 170 | } 171 | } 172 | -------------------------------------------------------------------------------- /libs/pointops/src/sampling/sampling_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _SAMPLING_CUDA_KERNEL 2 | #define _SAMPLING_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void farthest_point_sampling_cuda(int b, int n, at::Tensor xyz_tensor, at::Tensor offset_tensor, at::Tensor new_offset_tensor, at::Tensor tmp_tensor, at::Tensor idx_tensor); 8 | 9 | #ifdef __cplusplus 10 | extern "C" { 11 | #endif 12 | 13 | void farthest_point_sampling_cuda_launcher(int b, int n, const float *xyz, const int *offset, const int *new_offset, float *tmp, int *idx); 14 | 15 | #ifdef __cplusplus 16 | } 17 | #endif 18 | #endif 19 | -------------------------------------------------------------------------------- /libs/pointops/src/subtraction/subtraction_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "subtraction_cuda_kernel.h" 5 | 6 | 7 | void subtraction_forward_cuda(int n, int nsample, int c, at::Tensor input1_tensor, at::Tensor input2_tensor, at::Tensor idx_tensor, at::Tensor output_tensor) 8 | { 9 | const float *input1 = input1_tensor.data_ptr(); 10 | const float *input2 = input2_tensor.data_ptr(); 11 | const int *idx = idx_tensor.data_ptr(); 12 | float *output = output_tensor.data_ptr(); 13 | subtraction_forward_cuda_launcher(n, nsample, c, input1, input2, idx, output); 14 | } 15 | 16 | void subtraction_backward_cuda(int n, int nsample, int c, at::Tensor idx_tensor, at::Tensor grad_output_tensor, at::Tensor grad_input1_tensor, at::Tensor grad_input2_tensor) 17 | { 18 | const int *idx = idx_tensor.data_ptr(); 19 | const float *grad_output = grad_output_tensor.data_ptr(); 20 | float *grad_input1 = grad_input1_tensor.data_ptr(); 21 | float *grad_input2 = grad_input2_tensor.data_ptr(); 22 | subtraction_backward_cuda_launcher(n, nsample, c, idx, grad_output, grad_input1, grad_input2); 23 | } 24 | -------------------------------------------------------------------------------- /libs/pointops/src/subtraction/subtraction_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "subtraction_cuda_kernel.h" 3 | 4 | 5 | __global__ void subtraction_forward_cuda_kernel(int n, int 
nsample, int c, const float *input1, const float *input2, const int *idx, float *output) { 6 | // input: input1: (n, c), input2: (n, c), idx: (n, nsample), output: (n, nsample, c) 7 | int index = blockIdx.x * blockDim.x + threadIdx.x; 8 | if (index >= n * nsample * c) return; 9 | const int c_idx = index % c; 10 | const int nsample_idx = (index / c) % nsample; 11 | const int n_idx = index / nsample / c; 12 | const int idx_idx = n_idx * nsample + nsample_idx; 13 | const int input1_idx = n_idx * c + c_idx; 14 | const int input2_idx = idx[idx_idx] * c + c_idx; 15 | output[index] = input1[input1_idx] - input2[input2_idx]; 16 | } 17 | 18 | __global__ void subtraction_backward_cuda_kernel(int n, int nsample, int c, const int *idx, const float *grad_output, float *grad_input1, float *grad_input2) { 19 | // input: grad_output: (n, nsample, c), output: grad_input1: (n, c), grad_input2: (n, c) 20 | int index = blockIdx.x * blockDim.x + threadIdx.x; 21 | if (index >= n * nsample * c) return; 22 | const int c_idx = index % c; 23 | const int nsample_idx = (index / c) % nsample; 24 | const int n_idx = index / nsample / c; 25 | const int idx_idx = n_idx * nsample + nsample_idx; 26 | const int input1_idx = n_idx * c + c_idx; 27 | const int input2_idx = idx[idx_idx] * c + c_idx; 28 | atomicAdd(grad_input1 + input1_idx, grad_output[index]); 29 | atomicAdd(grad_input2 + input2_idx, -grad_output[index]); 30 | } 31 | 32 | void subtraction_forward_cuda_launcher(int n, int nsample, int c, const float *input1, const float *input2, const int *idx, float *output) { 33 | // input: input1: (n, c), input2: (n, c), idx: (n, nsample), output: (n, nsample, c) 34 | dim3 blocks(DIVUP(n * nsample * c, THREADS_PER_BLOCK)); 35 | dim3 threads(THREADS_PER_BLOCK); 36 | subtraction_forward_cuda_kernel<<>>(n, nsample, c, input1, input2, idx, output); 37 | } 38 | 39 | void subtraction_backward_cuda_launcher(int n, int nsample, int c, const int *idx, const float *grad_output, float *grad_input1, float *grad_input2) { 40 | // input: grad_output: (n, nsample, c), output: grad_input1: (n, c), grad_input2: (n, c) 41 | dim3 blocks(DIVUP(n * nsample * c, THREADS_PER_BLOCK)); 42 | dim3 threads(THREADS_PER_BLOCK); 43 | subtraction_backward_cuda_kernel<<>>(n, nsample, c, idx, grad_output, grad_input1, grad_input2); 44 | } 45 | -------------------------------------------------------------------------------- /libs/pointops/src/subtraction/subtraction_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _SUBTRACTION_CUDA_KERNEL 2 | #define _SUBTRACTION_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void subtraction_forward_cuda(int n, int nsample, int c, at::Tensor input1_tensor, at::Tensor input2_tensor, at::Tensor idx_tensor, at::Tensor output_tensor); 8 | void subtraction_backward_cuda(int n, int nsample, int c, at::Tensor idx_tensor, at::Tensor grad_output_tensor, at::Tensor grad_input1_tensor, at::Tensor grad_input2_tensor); 9 | 10 | #ifdef __cplusplus 11 | extern "C" { 12 | #endif 13 | 14 | void subtraction_forward_cuda_launcher(int n, int nsample, int c, const float *input1, const float *input2, const int *idx, float *output); 15 | void subtraction_backward_cuda_launcher(int n, int nsample, int c, const int *idx, const float *grad_output, float *grad_input1, float *grad_input2); 16 | 17 | #ifdef __cplusplus 18 | } 19 | #endif 20 | #endif 21 | -------------------------------------------------------------------------------- /pointcept/datasets/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .builder import build_dataset 2 | from .utils import point_collate_fn, collate_fn 3 | 4 | from .dataset_render_16views import SAMPart3DDataset16Views -------------------------------------------------------------------------------- /pointcept/datasets/builder.py: -------------------------------------------------------------------------------- 1 | from pointcept.utils.registry import Registry 2 | 3 | DATASETS = Registry("datasets") 4 | 5 | 6 | def build_dataset(cfg): 7 | """Build datasets.""" 8 | return DATASETS.build(cfg) 9 | -------------------------------------------------------------------------------- /pointcept/datasets/sampart3d_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import trimesh 3 | import os 4 | import json 5 | import math 6 | import open3d as o3d 7 | import torch 8 | 9 | 10 | def sample_surface(mesh, count, face_weight=None, sample_color=False, seed=147): 11 | 12 | if face_weight is None: 13 | # len(mesh.faces) float, array of the areas 14 | # of each face of the mesh 15 | face_weight = mesh.area_faces 16 | 17 | # cumulative sum of weights (len(mesh.faces)) 18 | weight_cum = np.cumsum(face_weight) 19 | 20 | # seed the random number generator as requested 21 | random = np.random.default_rng(seed).random 22 | 23 | # last value of cumulative sum is total summed weight/area 24 | face_pick = random(count) * weight_cum[-1] 25 | # get the index of the selected faces 26 | face_index = np.searchsorted(weight_cum, face_pick) 27 | 28 | # pull triangles into the form of an origin + 2 vectors 29 | tri_origins = mesh.vertices[mesh.faces[:, 0]] 30 | tri_vectors = mesh.vertices[mesh.faces[:, 1:]].copy() 31 | tri_vectors -= np.tile(tri_origins, (1, 2)).reshape((-1, 2, 3)) 32 | 33 | # pull the vectors for the faces we are going to sample from 34 | tri_origins = tri_origins[face_index] 35 | tri_vectors = tri_vectors[face_index] 36 | 37 | if sample_color and hasattr(mesh.visual, "uv"): 38 | uv_origins = mesh.visual.uv[mesh.faces[:, 0]] 39 | uv_vectors = mesh.visual.uv[mesh.faces[:, 1:]].copy() 40 | uv_origins_tile = np.tile(uv_origins, (1, 2)).reshape((-1, 2, 2)) 41 | uv_vectors -= uv_origins_tile 42 | uv_origins = uv_origins[face_index] 43 | uv_vectors = uv_vectors[face_index] 44 | 45 | # randomly generate two 0-1 scalar components to multiply edge vectors b 46 | random_lengths = random((len(tri_vectors), 2, 1)) 47 | 48 | # points will be distributed on a quadrilateral if we use 2 0-1 samples 49 | # if the two scalar components sum less than 1.0 the point will be 50 | # inside the triangle, so we find vectors longer than 1.0 and 51 | # transform them to be inside the triangle 52 | random_test = random_lengths.sum(axis=1).reshape(-1) > 1.0 53 | random_lengths[random_test] -= 1.0 54 | random_lengths = np.abs(random_lengths) 55 | 56 | # multiply triangle edge vectors by the random lengths and sum 57 | sample_vector = (tri_vectors * random_lengths).sum(axis=1) 58 | 59 | # finally, offset by the origin to generate 60 | # (n,3) points in space on the triangle 61 | samples = sample_vector + tri_origins 62 | 63 | if sample_color: 64 | if hasattr(mesh.visual, "uv"): 65 | sample_uv_vector = (uv_vectors * random_lengths).sum(axis=1) 66 | uv_samples = sample_uv_vector + uv_origins 67 | try: 68 | texture = mesh.visual.material.baseColorTexture 69 | except: 70 | texture = mesh.visual.material.image 71 | colors = 
trimesh.visual.color.uv_to_interpolated_color(uv_samples, texture) 72 | else: 73 | colors = mesh.visual.face_colors[face_index] 74 | 75 | return samples, face_index, colors 76 | 77 | return samples, face_index 78 | 79 | 80 | def get_ray_directions(W, H, fx, fy, cx, cy, use_pixel_centers=True): 81 | pixel_center = 0.5 if use_pixel_centers else 0 82 | i, j = np.meshgrid( 83 | np.arange(W, dtype=np.float32) + pixel_center, 84 | np.arange(H, dtype=np.float32) + pixel_center, 85 | indexing="xy", 86 | ) 87 | directions = np.stack( 88 | [(i - cx) / fx, -(j - cy) / fy, -np.ones_like(i)], -1 89 | ) 90 | 91 | return directions 92 | 93 | 94 | def gen_pcd(depth, c2w_opengl, camera_angle_x): 95 | 96 | h, w = depth.shape 97 | 98 | depth_valid = depth < 65500.0  # mask out invalid / background depth values 99 | depth = depth[depth_valid] 100 | focal = ( 101 | 0.5 * w / math.tan(0.5 * camera_angle_x) 102 | ) # scaled focal length 103 | ray_directions = get_ray_directions(w, h, focal, focal, w // 2, h // 2) 104 | points_c = ray_directions[depth_valid] * depth[:, None] 105 | points_c_homo = np.concatenate( 106 | [points_c, np.ones_like(points_c[..., :1])], axis=-1 107 | ) 108 | org_points = (points_c_homo @ c2w_opengl.T)[..., :3] 109 | 110 | return org_points 111 | 112 | 113 | def save_point_cloud(coord, color=None, file_path="pc.ply", logger=None): 114 | os.makedirs(os.path.dirname(file_path), exist_ok=True) 115 | coord = np.array(coord) 116 | if color is not None: 117 | color = np.array(color) 118 | pcd = o3d.geometry.PointCloud() 119 | pcd.points = o3d.utility.Vector3dVector(coord) 120 | pcd.colors = o3d.utility.Vector3dVector(np.ones_like(coord) if color is None else color) 121 | o3d.io.write_point_cloud(file_path, pcd) 122 | if logger is not None: 123 | logger.info(f"Saved point cloud to: {file_path}") 124 | 125 | 126 | def vis_pcd_feat(coord, point_feat, save_path):  # visualize point features as RGB via a 3-component PCA 127 | class TorchPCA(object): 128 | 129 | def __init__(self, n_components): 130 | self.n_components = n_components 131 | 132 | def fit(self, X): 133 | self.mean_ = X.mean(dim=0) 134 | unbiased = X - self.mean_.unsqueeze(0) 135 | U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=4) 136 | self.components_ = V.T 137 | self.singular_values_ = S 138 | return self 139 | 140 | def transform(self, X): 141 | t0 = X - self.mean_.unsqueeze(0) 142 | projected = t0 @ self.components_.T 143 | return projected 144 | 145 | fit_pca = TorchPCA(n_components=3).fit(point_feat) 146 | x_red = fit_pca.transform(point_feat) 147 | if isinstance(x_red, np.ndarray): 148 | x_red = torch.from_numpy(x_red) 149 | x_red -= x_red.min(dim=0, keepdim=True).values 150 | x_red /= x_red.max(dim=0, keepdim=True).values 151 | 152 | save_point_cloud(coord.detach().cpu(), x_red.detach().cpu(), save_path) -------------------------------------------------------------------------------- /pointcept/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections.abc import Mapping, Sequence 3 | import numpy as np 4 | import torch 5 | # import trimesh 6 | from torch.utils.data.dataloader import default_collate 7 | 8 | 9 | def collate_fn_dummy(batch): 10 | return batch 11 | 12 | 13 | def collate_fn(batch): 14 | """ 15 | Collate function for point clouds; supports dicts and lists. 16 | 'coord' is necessary to determine 'offset'. 17 | """ 18 | if not isinstance(batch, Sequence): 19 | raise TypeError(f"{type(batch)} is not supported.") 20 | 21 | if isinstance(batch[0], torch.Tensor): 22 | return torch.cat(list(batch)) 23 | elif isinstance(batch[0], str): 24 | # str is also a kind of Sequence; this check must come before the Sequence branch 25 | return list(batch) 26 | elif isinstance(batch[0], Sequence): 27 | for data in batch: 28 | data.append(torch.tensor([data[0].shape[0]])) 29 | batch = [collate_fn(samples) for samples in zip(*batch)] 30 | batch[-1] = torch.cumsum(batch[-1], dim=0).int() 31 | return batch 32 | elif isinstance(batch[0], Mapping): 33 | batch = {key: collate_fn([d[key] for d in batch]) for key in batch[0]} 34 | for key in batch.keys(): 35 | if "offset" in key: 36 | batch[key] = torch.cumsum(batch[key], dim=0) 37 | return batch 38 | else: 39 | return default_collate(batch) 40 | 41 | 42 | def point_collate_fn(batch, mix_prob=0): 43 | assert isinstance( 44 | batch[0], Mapping 45 | ) # currently, only support input_dict, rather than input_list 46 | batch = collate_fn(batch) 47 | if "offset" in batch.keys(): 48 | # Mix3d (https://arxiv.org/pdf/2110.02210.pdf) 49 | if random.random() < mix_prob: 50 | batch["offset"] = torch.cat( 51 | [batch["offset"][1:-1:2], batch["offset"][-1].unsqueeze(0)], dim=0 52 | ) 53 | return batch 54 | 55 | 56 | def gaussian_kernel(dist2: np.ndarray, a: float = 1, c: float = 5): 57 | return a * np.exp(-dist2 / (2 * c**2)) 58 | 59 | -------------------------------------------------------------------------------- /pointcept/engines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pointcept/SAMPart3D/6a508d145be5cecd7e80bce6e4248ea7abbd71b7/pointcept/engines/__init__.py -------------------------------------------------------------------------------- /pointcept/engines/defaults.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import multiprocessing as mp 5 | from torch.nn.parallel import DistributedDataParallel 6 | 7 | 8 | import pointcept.utils.comm as comm 9 | from pointcept.utils.env import get_random_seed, set_seed 10 | from pointcept.utils.config import Config, DictAction 11 | 12 | 13 | def create_ddp_model(model, *, fp16_compression=False, **kwargs): 14 | """ 15 | Create a DistributedDataParallel model if there are >1 processes. 16 | Args: 17 | model: a torch.nn.Module 18 | fp16_compression: add fp16 compression hooks to the ddp object. 19 | See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook 20 | kwargs: other arguments of :class:`torch.nn.parallel.DistributedDataParallel`. 21 | """ 22 | if comm.get_world_size() == 1: 23 | return model 24 | # kwargs['find_unused_parameters'] = True 25 | if "device_ids" not in kwargs: 26 | kwargs["device_ids"] = [comm.get_local_rank()] 27 | if "output_device" not in kwargs: 28 | kwargs["output_device"] = [comm.get_local_rank()] 29 | ddp = DistributedDataParallel(model, **kwargs) 30 | if fp16_compression: 31 | from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks 32 | 33 | ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook) 34 | return ddp 35 | 36 | 37 | def worker_init_fn(worker_id, num_workers, rank, seed):
 38 | """Worker init func for dataloader. 39 | 40 | The seed of each worker equals num_workers * rank + worker_id + seed. 41 | 42 | Args: 43 | worker_id (int): Worker id. 44 | num_workers (int): Number of workers. 45 | rank (int): The rank of the current process. 46 | seed (int): The random seed to use. 
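Example (illustrative; values assumed for demonstration): with seed=1234, num_workers=4 and rank=1, worker 2 is seeded with 4 * 1 + 2 + 1234 = 1240, so every dataloader worker on every rank draws from a distinct, reproducible stream. 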
47 | """ 48 | 49 | worker_seed = num_workers * rank + worker_id + seed 50 | set_seed(worker_seed) 51 | 52 | 53 | def default_argument_parser(epilog=None): 54 | parser = argparse.ArgumentParser( 55 | epilog=epilog 56 | or f""" 57 | Examples: 58 | Run on single machine: 59 | $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml 60 | Change some config options: 61 | $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001 62 | Run on multiple machines: 63 | (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url [--other-flags] 64 | (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url [--other-flags] 65 | """, 66 | formatter_class=argparse.RawDescriptionHelpFormatter, 67 | ) 68 | parser.add_argument( 69 | "--config-file", default="", metavar="FILE", help="path to config file" 70 | ) 71 | parser.add_argument( 72 | "--num-gpus", type=int, default=1, help="number of gpus *per machine*" 73 | ) 74 | parser.add_argument( 75 | "--num-machines", type=int, default=1, help="total number of machines" 76 | ) 77 | parser.add_argument( 78 | "--machine-rank", 79 | type=int, 80 | default=0, 81 | help="the rank of this machine (unique per machine)", 82 | ) 83 | # PyTorch still may leave orphan processes in multi-gpu training. 84 | # Therefore we use a deterministic way to obtain port, 85 | # so that users are aware of orphan processes by seeing the port occupied. 86 | # port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14 87 | parser.add_argument( 88 | "--dist-url", 89 | # default="tcp://127.0.0.1:{}".format(port), 90 | default="auto", 91 | help="initialization URL for pytorch distributed backend. See " 92 | "https://pytorch.org/docs/stable/distributed.html for details.", 93 | ) 94 | parser.add_argument( 95 | "--options", nargs="+", action=DictAction, help="custom options" 96 | ) 97 | return parser 98 | 99 | 100 | def default_config_parser(file_path, options): 101 | # config name protocol: dataset_name/model_name-exp_name 102 | if os.path.isfile(file_path): 103 | cfg = Config.fromfile(file_path) 104 | else: 105 | sep = file_path.find("-") 106 | cfg = Config.fromfile(os.path.join(file_path[:sep], file_path[sep + 1 :])) 107 | 108 | if options is not None: 109 | cfg.merge_from_dict(options) 110 | 111 | if cfg.seed is None: 112 | cfg.seed = get_random_seed() 113 | 114 | cfg.data.train.loop = cfg.epoch // cfg.eval_epoch 115 | 116 | os.makedirs(os.path.join(cfg.save_path, "model"), exist_ok=True) 117 | if not cfg.resume: 118 | cfg.dump(os.path.join(cfg.save_path, "config.py")) 119 | return cfg 120 | 121 | 122 | def default_setup(cfg): 123 | # scalar by world size 124 | world_size = comm.get_world_size() 125 | cfg.num_worker = cfg.num_worker if cfg.num_worker is not None else mp.cpu_count() 126 | cfg.num_worker_per_gpu = cfg.num_worker // world_size 127 | assert cfg.batch_size % world_size == 0 128 | assert cfg.batch_size_val is None or cfg.batch_size_val % world_size == 0 129 | assert cfg.batch_size_test is None or cfg.batch_size_test % world_size == 0 130 | cfg.batch_size_per_gpu = cfg.batch_size // world_size 131 | cfg.batch_size_val_per_gpu = ( 132 | cfg.batch_size_val // world_size if cfg.batch_size_val is not None else 1 133 | ) 134 | cfg.batch_size_test_per_gpu = ( 135 | cfg.batch_size_test // world_size if cfg.batch_size_test is not None else 1 136 | ) 137 | # update data loop 138 | assert cfg.epoch % cfg.eval_epoch == 0 139 | # settle random seed 140 | rank = comm.get_rank() 141 | seed = None if cfg.seed 
is None else cfg.seed * cfg.num_worker_per_gpu + rank 142 | set_seed(seed) 143 | return cfg 144 | -------------------------------------------------------------------------------- /pointcept/engines/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | from .default import HookBase 2 | from .misc import * 3 | from .evaluator import * 4 | # from .partseg import * 5 | 6 | from .builder import build_hooks 7 | -------------------------------------------------------------------------------- /pointcept/engines/hooks/builder.py: -------------------------------------------------------------------------------- 1 | from pointcept.utils.registry import Registry 2 | 3 | 4 | HOOKS = Registry("hooks") 5 | 6 | 7 | def build_hooks(cfg): 8 | hooks = [] 9 | for hook_cfg in cfg: 10 | hooks.append(HOOKS.build(hook_cfg)) 11 | return hooks 12 | -------------------------------------------------------------------------------- /pointcept/engines/hooks/default.py: -------------------------------------------------------------------------------- 1 | class HookBase: 2 | """ 3 | Base class for hooks that can be registered with :class:`TrainerBase`. 4 | """ 5 | 6 | trainer = None # A weak reference to the trainer object. 7 | 8 | def before_train(self): 9 | pass 10 | 11 | def before_epoch(self): 12 | pass 13 | 14 | def before_step(self): 15 | pass 16 | 17 | def after_step(self): 18 | pass 19 | 20 | def after_epoch(self): 21 | pass 22 | 23 | def after_train(self): 24 | pass 25 | -------------------------------------------------------------------------------- /pointcept/engines/launch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from datetime import timedelta 4 | import torch 5 | import torch.distributed as dist 6 | import torch.multiprocessing as mp 7 | 8 | from pointcept.utils import comm 9 | 10 | __all__ = ["DEFAULT_TIMEOUT", "launch"] 11 | 12 | DEFAULT_TIMEOUT = timedelta(minutes=60) 13 | 14 | 15 | def _find_free_port(): 16 | import socket 17 | 18 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 19 | # Binding to port 0 will cause the OS to find an available port for us 20 | sock.bind(("", 0)) 21 | port = sock.getsockname()[1] 22 | sock.close() 23 | # NOTE: there is still a chance the port could be taken by other processes. 24 | return port 25 | 26 | 27 | def launch( 28 | main_func, 29 | num_gpus_per_machine, 30 | num_machines=1, 31 | machine_rank=0, 32 | dist_url=None, 33 | cfg=(), 34 | timeout=DEFAULT_TIMEOUT, 35 | ): 36 | """ 37 | Launch multi-gpu or distributed training. 38 | This function must be called on all machines involved in the training. 39 | It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine. 40 | Args: 41 | main_func: a function that will be called by `main_func(*args)` 42 | num_gpus_per_machine (int): number of GPUs per machine 43 | num_machines (int): the total number of machines 44 | machine_rank (int): the rank of this machine 45 | dist_url (str): url to connect to for distributed jobs, including protocol 46 | e.g. "tcp://127.0.0.1:8686". 47 | Can be set to "auto" to automatically select a free port on localhost 48 | timeout (timedelta): timeout of the distributed workers 49 | args (tuple): arguments passed to main_func 50 | """ 51 | world_size = num_machines * num_gpus_per_machine 52 | if world_size > 1: 53 | if dist_url == "auto": 54 | assert ( 55 | num_machines == 1 56 | ), "dist_url=auto not supported in multi-machine jobs." 
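# Illustrative single-machine call (the entry point and config names here are hypothetical): # launch(main_worker, num_gpus_per_machine=4, dist_url="auto", cfg=(cfg,)) # "auto" is resolved below to a free localhost port, so the spawned workers rendezvous at tcp://127.0.0.1:<port>. 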
57 | port = _find_free_port() 58 | dist_url = f"tcp://127.0.0.1:{port}" 59 | if num_machines > 1 and dist_url.startswith("file://"): 60 | logger = logging.getLogger(__name__) 61 | logger.warning( 62 | "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://" 63 | ) 64 | 65 | mp.spawn( 66 | _distributed_worker, 67 | nprocs=num_gpus_per_machine, 68 | args=( 69 | main_func, 70 | world_size, 71 | num_gpus_per_machine, 72 | machine_rank, 73 | dist_url, 74 | cfg, 75 | timeout, 76 | ), 77 | daemon=False, 78 | ) 79 | else: 80 | main_func(*cfg) 81 | 82 | 83 | def _distributed_worker( 84 | local_rank, 85 | main_func, 86 | world_size, 87 | num_gpus_per_machine, 88 | machine_rank, 89 | dist_url, 90 | cfg, 91 | timeout=DEFAULT_TIMEOUT, 92 | ): 93 | assert ( 94 | torch.cuda.is_available() 95 | ), "cuda is not available. Please check your installation." 96 | global_rank = machine_rank * num_gpus_per_machine + local_rank 97 | try: 98 | dist.init_process_group( 99 | backend="NCCL", 100 | init_method=dist_url, 101 | world_size=world_size, 102 | rank=global_rank, 103 | timeout=timeout, 104 | ) 105 | except Exception as e: 106 | logger = logging.getLogger(__name__) 107 | logger.error("Process group URL: {}".format(dist_url)) 108 | raise e 109 | 110 | # Setup the local process group (which contains ranks within the same machine) 111 | assert comm._LOCAL_PROCESS_GROUP is None 112 | num_machines = world_size // num_gpus_per_machine 113 | for i in range(num_machines): 114 | ranks_on_i = list( 115 | range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine) 116 | ) 117 | pg = dist.new_group(ranks_on_i) 118 | if i == machine_rank: 119 | comm._LOCAL_PROCESS_GROUP = pg 120 | 121 | assert num_gpus_per_machine <= torch.cuda.device_count() 122 | torch.cuda.set_device(local_rank) 123 | 124 | # synchronize is needed here to prevent a possible timeout after calling init_process_group 125 | # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172 126 | comm.synchronize() 127 | 128 | main_func(*cfg) 129 | -------------------------------------------------------------------------------- /pointcept/models/SAMPart3D.py: -------------------------------------------------------------------------------- 1 | from addict import Dict 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import spconv.pytorch as spconv 6 | 7 | try: 8 | import flash_attn 9 | except ImportError: 10 | flash_attn = None 11 | 12 | from pointcept.models.builder import MODELS, build_model 13 | from pointcept.models.utils.structure import Point 14 | import tinycudann as tcnn 15 | from pointcept.datasets.sampart3d_util import * 16 | 17 | 18 | @MODELS.register_module("SAMPart3D") 19 | class SAMPart3D(nn.Module): 20 | 21 | def __init__(self, 22 | backbone=None, 23 | backbone_dim=None, 24 | output_dim=None, 25 | pcd_feat_dim=None, 26 | use_hierarchy_losses=True, 27 | max_grouping_scale=2, 28 | freeze_backbone=True, 29 | **kwargs): 30 | super().__init__() 31 | 32 | self.use_hierarchy_losses = use_hierarchy_losses 33 | self.max_grouping_scale = max_grouping_scale 34 | self.device = "cuda" 35 | self.quantile_transformer = None 36 | 37 | self.backbone = build_model(backbone) 38 | self.init_feat = None 39 | 40 | self.instance_net = tcnn.Network( 41 | n_input_dims=backbone_dim+1, 42 | n_output_dims=output_dim, 43 | network_config={ 44 | "otype": "CutlassMLP", 45 | "activation": "ReLU", 46 | "output_activation": "None", 47 | "n_neurons": 384, 48 | "n_hidden_layers": 6, 49 | }, 50 | ) 51 | 52 | 
self.pos_net = tcnn.Network( 53 | n_input_dims=pcd_feat_dim+1, 54 | n_output_dims=output_dim, 55 | network_config={ 56 | "otype": "CutlassMLP", 57 | "activation": "ReLU", 58 | "output_activation": "None", 59 | "n_neurons": 384, 60 | "n_hidden_layers": 4, 61 | }, 62 | ) 63 | 64 | if freeze_backbone: 65 | for name, param in self.named_parameters(): 66 | if 'instance_net' not in name and 'pos_net' not in name: 67 | param.requires_grad = False 68 | 69 | def get_mlp(self, point_feat, scales): 70 | scales = self.quantile_transformer(scales) 71 | # n = point_feat.shape[0] 72 | point_feat = torch.cat((point_feat, scales), dim=-1) 73 | instance_pass = self.instance_net(point_feat) 74 | 75 | epsilon = 1e-5 76 | norms = instance_pass.norm(dim=-1, keepdim=True) 77 | instance_pass = instance_pass / (norms + epsilon) 78 | 79 | return instance_pass 80 | 81 | def pos_emb(self, point_feat, scales): 82 | scales = self.quantile_transformer(scales) 83 | # n = point_feat.shape[0] 84 | point_feat = torch.cat((point_feat, scales), dim=-1) 85 | instance_pass = self.pos_net(point_feat) 86 | 87 | epsilon = 1e-5 88 | norms = instance_pass.norm(dim=-1, keepdim=True) 89 | instance_pass = instance_pass / (norms + epsilon) 90 | 91 | return instance_pass 92 | 93 | def get_loss(self, input_dict, pcd_dict): 94 | if self.init_feat is None: 95 | with torch.no_grad(): 96 | self.backbone.eval() 97 | point = self.backbone(pcd_dict) 98 | point_feat = point.feat 99 | self.init_feat = point_feat 100 | del self.backbone 101 | 102 | point_orgfeat_mapping = pcd_dict["feat"][input_dict["mapping"]] 103 | point_selected_feat = self.init_feat[input_dict["mapping"]] 104 | 105 | loss_dict = {} 106 | margin = 1.0 107 | 108 | #################################################################################### 109 | # Calculate GT labels for the positive and negative pairs 110 | #################################################################################### 111 | input_id1 = input_id2 = input_dict["mask_id"] 112 | 113 | # Expand labels 114 | labels1_expanded = input_id1.unsqueeze(1).expand(-1, input_id1.shape[0]) 115 | labels2_expanded = input_id2.unsqueeze(0).expand(input_id2.shape[0], -1) 116 | 117 | # Mask for positive/negative pairs across the entire matrix 118 | mask_full_positive = labels1_expanded == labels2_expanded 119 | mask_full_negative = ~mask_full_positive 120 | 121 | # Create a block mask to only consider pairs within the same image -- no cross-image pairs 122 | chunk_size = input_dict["nPxImg"] # i.e., the number of rays per image 123 | num_chunks = input_id1.shape[0] // chunk_size # i.e., # of images in the batch 124 | block_mask = torch.kron( 125 | torch.eye(num_chunks, device=self.device, dtype=bool), 126 | torch.ones((chunk_size, chunk_size), device=self.device, dtype=bool), 127 | ) # block-diagonal matrix, to consider only pairs within the same image 128 | 129 | # Only consider upper triangle to avoid double-counting 130 | block_mask = torch.triu(block_mask, diagonal=0) 131 | # Only consider pairs where both points are valid (-1 means not in mask / invalid) 132 | block_mask = block_mask * (labels1_expanded != -1) * (labels2_expanded != -1) 133 | diag_mask = torch.eye(block_mask.shape[0], device=self.device, dtype=bool) 134 | scale = input_dict["scale"].view(-1, 1) 135 | 136 | #################################################################################### 137 | # Grouping supervision 138 | #################################################################################### 139 | total_loss = 0 140 | 141 | # 1. 
If (A, s_A) and (A', s_A) in same group, then supervise the features to be similar 142 | instance = self.get_mlp(point_selected_feat, scale) 143 | pose_emb = self.pos_emb(point_orgfeat_mapping, scale) 144 | instance = instance + pose_emb 145 | 146 | # instance = instance.float() 147 | mask = torch.where(mask_full_positive * block_mask * (~diag_mask)) 148 | instance_loss_1 = torch.norm( 149 | instance[mask[0]] - instance[mask[1]], p=2, dim=-1 150 | ).nan_to_num(0).mean() 151 | loss_weight_pos = torch.sum(mask_full_positive * block_mask * (~diag_mask)) / torch.sum(block_mask) 152 | total_loss += instance_loss_1 * loss_weight_pos 153 | 154 | # 2. If (A, s_A) and (A', s_A) in same group, then also supervise them to be similar at s > s_A 155 | if self.use_hierarchy_losses: 156 | scale_diff = torch.max( 157 | torch.zeros_like(scale), (self.max_grouping_scale - scale) 158 | ) 159 | larger_scale = scale + scale_diff * torch.rand( 160 | size=(1,), device=scale.device 161 | ) 162 | instance = self.get_mlp(point_selected_feat, larger_scale) 163 | pose_emb = self.pos_emb(point_orgfeat_mapping, larger_scale) 164 | instance = instance + pose_emb 165 | # instance = instance.float() 166 | mask = torch.where(mask_full_positive * block_mask * (~diag_mask)) 167 | instance_loss_2 = torch.norm( 168 | instance[mask[0]] - instance[mask[1]], p=2, dim=-1 169 | ).nan_to_num(0).mean() 170 | total_loss += instance_loss_2 * loss_weight_pos 171 | 172 | # 3. Also supervising A, B to be dissimilar at scales s_A, s_B respectively seems to help. 173 | instance = self.get_mlp(point_selected_feat, scale) 174 | pose_emb = self.pos_emb(point_orgfeat_mapping, scale) 175 | instance = instance + pose_emb 176 | # instance = instance.float() 177 | mask = torch.where(mask_full_negative * block_mask) 178 | instance_loss_3 = ( 179 | F.relu( 180 | margin - torch.norm(instance[mask[0]] - instance[mask[1]], p=2, dim=-1) 181 | ) 182 | ).nan_to_num(0).mean() 183 | loss_weight_neg = torch.sum(mask_full_negative * block_mask) / torch.sum(block_mask) 184 | total_loss += instance_loss_3 * loss_weight_neg 185 | 186 | 187 | loss_dict["instance_loss"] = total_loss 188 | loss_dict["instance_loss_1"] = instance_loss_1 189 | loss_dict["instance_loss_2"] = instance_loss_2 190 | loss_dict["instance_loss_3"] = instance_loss_3 191 | 192 | return loss_dict 193 | 194 | def forward(self, input_dict): 195 | for k, v in input_dict.items(): 196 | if isinstance(v, torch.Tensor): 197 | input_dict[k] = v.cuda() 198 | # print(k, v.shape) 199 | data_dict = input_dict["obj"] 200 | for k, v in data_dict.items(): 201 | if isinstance(v, torch.Tensor): 202 | data_dict[k] = v.cuda() 203 | data_dict["grid_size"] = 0.01 204 | if self.training: 205 | loss_dict = self.get_loss(input_dict, data_dict) 206 | return loss_dict 207 | else: 208 | if self.init_feat is None: 209 | with torch.no_grad(): 210 | self.backbone.eval() 211 | point = self.backbone(data_dict) 212 | point_feat = point.feat 213 | 214 | else: 215 | point_feat = self.init_feat 216 | 217 | scale = input_dict["scale"] 218 | n = data_dict["feat"].shape[0] 219 | scale_column = torch.full((n, 1), scale, device=point_feat.device) 220 | instance_feat = self.get_mlp(point_feat, scale_column) 221 | pose_emb = self.pos_emb(data_dict["feat"], scale_column) 222 | instance_feat = instance_feat + pose_emb 223 | 224 | return instance_feat 225 | 226 | 227 | -------------------------------------------------------------------------------- /pointcept/models/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .builder import build_model 2 | 3 | from .SAMPart3D import SAMPart3D 4 | from .PTv3Object import PointTransformerV3Object 5 | -------------------------------------------------------------------------------- /pointcept/models/builder.py: -------------------------------------------------------------------------------- 1 | from pointcept.utils.registry import Registry 2 | 3 | MODELS = Registry("models") 4 | MODULES = Registry("modules") 5 | 6 | 7 | def build_model(cfg): 8 | """Build models.""" 9 | return MODELS.build(cfg) 10 | -------------------------------------------------------------------------------- /pointcept/models/modules.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch.nn as nn 3 | import spconv.pytorch as spconv 4 | from collections import OrderedDict 5 | from pointcept.models.utils.structure import Point 6 | 7 | 8 | class PointModule(nn.Module): 9 | r"""PointModule 10 | placeholder, all module subclass from this will take Point in PointSequential. 11 | """ 12 | 13 | def __init__(self, *args, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | 16 | 17 | class PointSequential(PointModule): 18 | r"""A sequential container. 19 | Modules will be added to it in the order they are passed in the constructor. 20 | Alternatively, an ordered dict of modules can also be passed in. 21 | """ 22 | 23 | def __init__(self, *args, **kwargs): 24 | super().__init__() 25 | if len(args) == 1 and isinstance(args[0], OrderedDict): 26 | for key, module in args[0].items(): 27 | self.add_module(key, module) 28 | else: 29 | for idx, module in enumerate(args): 30 | self.add_module(str(idx), module) 31 | for name, module in kwargs.items(): 32 | if sys.version_info < (3, 6): 33 | raise ValueError("kwargs only supported in py36+") 34 | if name in self._modules: 35 | raise ValueError("name exists.") 36 | self.add_module(name, module) 37 | 38 | def __getitem__(self, idx): 39 | if not (-len(self) <= idx < len(self)): 40 | raise IndexError("index {} is out of range".format(idx)) 41 | if idx < 0: 42 | idx += len(self) 43 | it = iter(self._modules.values()) 44 | for i in range(idx): 45 | next(it) 46 | return next(it) 47 | 48 | def __len__(self): 49 | return len(self._modules) 50 | 51 | def add(self, module, name=None): 52 | if name is None: 53 | name = str(len(self._modules)) 54 | if name in self._modules: 55 | raise KeyError("name exists") 56 | self.add_module(name, module) 57 | 58 | def forward(self, input): 59 | for k, module in self._modules.items(): 60 | # Point module 61 | if isinstance(module, PointModule): 62 | input = module(input) 63 | # Spconv module 64 | elif spconv.modules.is_spconv_module(module): 65 | if isinstance(input, Point): 66 | input.sparse_conv_feat = module(input.sparse_conv_feat) 67 | input.feat = input.sparse_conv_feat.features 68 | else: 69 | input = module(input) 70 | # PyTorch module 71 | else: 72 | if isinstance(input, Point): 73 | input.feat = module(input.feat) 74 | if "sparse_conv_feat" in input.keys(): 75 | input.sparse_conv_feat = input.sparse_conv_feat.replace_feature( 76 | input.feat 77 | ) 78 | elif isinstance(input, spconv.SparseConvTensor): 79 | if input.indices.shape[0] != 0: 80 | input = input.replace_feature(module(input.features)) 81 | else: 82 | input = module(input) 83 | return input 84 | -------------------------------------------------------------------------------- /pointcept/models/utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .misc import offset2batch, offset2bincount, batch2offset, off_diagonal 2 | from .checkpoint import checkpoint 3 | from .serialization import encode, decode 4 | from .structure import Point 5 | -------------------------------------------------------------------------------- /pointcept/models/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class CheckpointFunction(torch.autograd.Function): 5 | @staticmethod 6 | def forward(ctx, run_function, length, *args): 7 | ctx.run_function = run_function 8 | ctx.input_tensors = list(args[:length]) 9 | ctx.input_params = list(args[length:]) 10 | 11 | with torch.no_grad(): 12 | output_tensors = ctx.run_function(*ctx.input_tensors) 13 | return output_tensors 14 | 15 | @staticmethod 16 | def backward(ctx, *output_grads): 17 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 18 | with torch.enable_grad(): 19 | # Fixes a bug where the first op in run_function modifies the 20 | # Tensor storage in place, which is not allowed for detach()'d 21 | # Tensors. 22 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 23 | output_tensors = ctx.run_function(*shallow_copies) 24 | input_grads = torch.autograd.grad( 25 | output_tensors, 26 | ctx.input_tensors + ctx.input_params, 27 | output_grads, 28 | allow_unused=True, 29 | ) 30 | del ctx.input_tensors 31 | del ctx.input_params 32 | del output_tensors 33 | return (None, None) + input_grads 34 | 35 | 36 | def checkpoint(func, inputs, params, flag): 37 | """ 38 | Evaluate a function without caching intermediate activations, allowing for 39 | reduced memory at the expense of extra compute in the backward pass. 40 | :param func: the function to evaluate. 41 | :param inputs: the argument sequence to pass to `func`. 42 | :param params: a sequence of parameters `func` depends on but does not 43 | explicitly take as arguments. 44 | :param flag: if False, disable gradient checkpointing. 
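Example (a minimal sketch; `block` and the tensor shapes are assumed for illustration): >>> block = torch.nn.Linear(16, 16) >>> x = torch.randn(4, 16, requires_grad=True) >>> y = checkpoint(block, (x,), tuple(block.parameters()), True) 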
45 | """ 46 | if flag: 47 | args = tuple(inputs) + tuple(params) 48 | return CheckpointFunction.apply(func, len(inputs), *args) 49 | else: 50 | return func(*inputs) 51 | -------------------------------------------------------------------------------- /pointcept/models/utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | @torch.inference_mode() 5 | def offset2bincount(offset): 6 | return torch.diff( 7 | offset, prepend=torch.tensor([0], device=offset.device, dtype=torch.long) 8 | ) 9 | 10 | 11 | @torch.inference_mode() 12 | def offset2batch(offset): 13 | bincount = offset2bincount(offset) 14 | return torch.arange( 15 | len(bincount), device=offset.device, dtype=torch.long 16 | ).repeat_interleave(bincount) 17 | 18 | 19 | @torch.inference_mode() 20 | def batch2offset(batch): 21 | return torch.cumsum(batch.bincount(), dim=0).long() 22 | 23 | 24 | def off_diagonal(x): 25 | # return a flattened view of the off-diagonal elements of a square matrix 26 | n, m = x.shape 27 | assert n == m 28 | return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() 29 | -------------------------------------------------------------------------------- /pointcept/models/utils/serialization/__init__.py: -------------------------------------------------------------------------------- 1 | from .default import ( 2 | encode, 3 | decode, 4 | z_order_encode, 5 | z_order_decode, 6 | hilbert_encode, 7 | hilbert_decode, 8 | ) 9 | -------------------------------------------------------------------------------- /pointcept/models/utils/serialization/default.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .z_order import xyz2key as z_order_encode_ 3 | from .z_order import key2xyz as z_order_decode_ 4 | from .hilbert import encode as hilbert_encode_ 5 | from .hilbert import decode as hilbert_decode_ 6 | 7 | 8 | @torch.inference_mode() 9 | def encode(grid_coord, batch=None, depth=16, order="z"): 10 | assert order in {"z", "z-trans", "hilbert", "hilbert-trans"} 11 | if order == "z": 12 | code = z_order_encode(grid_coord, depth=depth) 13 | elif order == "z-trans": 14 | code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth) 15 | elif order == "hilbert": 16 | code = hilbert_encode(grid_coord, depth=depth) 17 | elif order == "hilbert-trans": 18 | code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth) 19 | else: 20 | raise NotImplementedError 21 | if batch is not None: 22 | batch = batch.long() 23 | code = batch << depth * 3 | code 24 | return code 25 | 26 | 27 | @torch.inference_mode() 28 | def decode(code, depth=16, order="z"): 29 | assert order in {"z", "hilbert"} 30 | batch = code >> depth * 3 31 | code = code & ((1 << depth * 3) - 1) 32 | if order == "z": 33 | grid_coord = z_order_decode(code, depth=depth) 34 | elif order == "hilbert": 35 | grid_coord = hilbert_decode(code, depth=depth) 36 | else: 37 | raise NotImplementedError 38 | return grid_coord, batch 39 | 40 | 41 | def z_order_encode(grid_coord: torch.Tensor, depth: int = 16): 42 | x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long() 43 | # we block the support to batch, maintain batched code in Point class 44 | code = z_order_encode_(x, y, z, b=None, depth=depth) 45 | return code 46 | 47 | 48 | def z_order_decode(code: torch.Tensor, depth): 49 | x, y, z = z_order_decode_(code, depth=depth) 50 | grid_coord = torch.stack([x, y, z], dim=-1) # (N, 3) 51 | return grid_coord 52 | 53 | 54 | def 
hilbert_encode(grid_coord: torch.Tensor, depth: int = 16): 55 | return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth) 56 | 57 | 58 | def hilbert_decode(code: torch.Tensor, depth: int = 16): 59 | return hilbert_decode_(code, num_dims=3, num_bits=depth) 60 | -------------------------------------------------------------------------------- /pointcept/models/utils/serialization/z_order.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Union 3 | 4 | 5 | class KeyLUT: 6 | def __init__(self): 7 | r256 = torch.arange(256, dtype=torch.int64) 8 | r512 = torch.arange(512, dtype=torch.int64) 9 | zero = torch.zeros(256, dtype=torch.int64) 10 | device = torch.device("cpu") 11 | 12 | self._encode = { 13 | device: ( 14 | self.xyz2key(r256, zero, zero, 8), 15 | self.xyz2key(zero, r256, zero, 8), 16 | self.xyz2key(zero, zero, r256, 8), 17 | ) 18 | } 19 | self._decode = {device: self.key2xyz(r512, 9)} 20 | 21 | def encode_lut(self, device=torch.device("cpu")): 22 | if device not in self._encode: 23 | cpu = torch.device("cpu") 24 | self._encode[device] = tuple(e.to(device) for e in self._encode[cpu]) 25 | return self._encode[device] 26 | 27 | def decode_lut(self, device=torch.device("cpu")): 28 | if device not in self._decode: 29 | cpu = torch.device("cpu") 30 | self._decode[device] = tuple(e.to(device) for e in self._decode[cpu]) 31 | return self._decode[device] 32 | 33 | def xyz2key(self, x, y, z, depth): 34 | key = torch.zeros_like(x) 35 | for i in range(depth): 36 | mask = 1 << i 37 | key = ( 38 | key 39 | | ((x & mask) << (2 * i + 2)) 40 | | ((y & mask) << (2 * i + 1)) 41 | | ((z & mask) << (2 * i + 0)) 42 | ) 43 | return key 44 | 45 | def key2xyz(self, key, depth): 46 | x = torch.zeros_like(key) 47 | y = torch.zeros_like(key) 48 | z = torch.zeros_like(key) 49 | for i in range(depth): 50 | x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2)) 51 | y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1)) 52 | z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0)) 53 | return x, y, z 54 | 55 | 56 | _key_lut = KeyLUT() 57 | 58 | 59 | def xyz2key( 60 | x: torch.Tensor, 61 | y: torch.Tensor, 62 | z: torch.Tensor, 63 | b: Optional[Union[torch.Tensor, int]] = None, 64 | depth: int = 16, 65 | ): 66 | r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys 67 | based on pre-computed look up tables. The speed of this function is much 68 | faster than the method based on for-loop. 69 | 70 | Args: 71 | x (torch.Tensor): The x coordinate. 72 | y (torch.Tensor): The y coordinate. 73 | z (torch.Tensor): The z coordinate. 74 | b (torch.Tensor or int): The batch index of the coordinates, and should be 75 | smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of 76 | :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`. 77 | depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). 
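Example (illustrative, for a single voxel coordinate): >>> xyz2key(torch.tensor([1]), torch.tensor([0]), torch.tensor([1]), depth=2) interleaves the bits as x1 y1 z1 x0 y0 z0 = 0b000101 and returns tensor([5]). 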
78 | """ 79 | 80 | EX, EY, EZ = _key_lut.encode_lut(x.device) 81 | x, y, z = x.long(), y.long(), z.long() 82 | 83 | mask = 255 if depth > 8 else (1 << depth) - 1 84 | key = EX[x & mask] | EY[y & mask] | EZ[z & mask] 85 | if depth > 8: 86 | mask = (1 << (depth - 8)) - 1 87 | key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask] 88 | key = key16 << 24 | key 89 | 90 | if b is not None: 91 | b = b.long() 92 | key = b << 48 | key 93 | 94 | return key 95 | 96 | 97 | def key2xyz(key: torch.Tensor, depth: int = 16): 98 | r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates 99 | and the batch index based on pre-computed look up tables. 100 | 101 | Args: 102 | key (torch.Tensor): The shuffled key. 103 | depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). 104 | """ 105 | 106 | DX, DY, DZ = _key_lut.decode_lut(key.device) 107 | x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key) 108 | 109 | b = key >> 48 110 | key = key & ((1 << 48) - 1) 111 | 112 | n = (depth + 2) // 3 113 | for i in range(n): 114 | k = key >> (i * 9) & 511 115 | x = x | (DX[k] << (i * 3)) 116 | y = y | (DY[k] << (i * 3)) 117 | z = z | (DZ[k] << (i * 3)) 118 | 119 | return x, y, z, b 120 | -------------------------------------------------------------------------------- /pointcept/models/utils/structure.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import spconv.pytorch as spconv 3 | 4 | try: 5 | import ocnn 6 | except ImportError: 7 | ocnn = None 8 | from addict import Dict 9 | 10 | from pointcept.models.utils.serialization import encode, decode 11 | from pointcept.models.utils import offset2batch, batch2offset 12 | import torch_scatter 13 | 14 | class Point(Dict): 15 | """ 16 | Point Structure of Pointcept 17 | 18 | A Point (point cloud) in Pointcept is a dictionary that contains various properties of 19 | a batched point cloud. 
The properties with the following names have specific definitions 20 | as follows: 21 | 22 | - "coord": original coordinate of point cloud; 23 | - "grid_coord": grid coordinate for specific grid size (related to GridSampling); 24 | Point also supports the following optional attributes: 25 | - "offset": if absent, initialized assuming a batch size of 1; 26 | - "batch": if absent, initialized assuming a batch size of 1; 27 | - "feat": feature of point cloud, default input of model; 28 | - "grid_size": Grid size of point cloud (related to GridSampling); 29 | (related to Serialization) 30 | - "serialized_depth": depth of serialization; 2 ** depth * grid_size describes the maximum point cloud range; 31 | - "serialized_code": a list of serialization codes; 32 | - "serialized_order": a list of serialization orders determined by code; 33 | - "serialized_inverse": a list of inverse mappings determined by code; 34 | (related to Sparsify: SpConv) 35 | - "sparse_shape": Sparse shape for Sparse Conv Tensor; 36 | - "sparse_conv_feat": SparseConvTensor initialized with information provided by Point; 37 | """ 38 | 39 | def __init__(self, *args, **kwargs): 40 | super().__init__(*args, **kwargs) 41 | # If one of "offset" or "batch" does not exist, generate it from the existing one 42 | if "batch" not in self.keys() and "offset" in self.keys(): 43 | self["batch"] = offset2batch(self.offset) 44 | elif "offset" not in self.keys() and "batch" in self.keys(): 45 | self["offset"] = batch2offset(self.batch) 46 | 47 | def serialization(self, order="z", depth=None, shuffle_orders=False): 48 | """ 49 | Point Cloud Serialization 50 | 51 | relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"] 52 | """ 53 | assert "batch" in self.keys() 54 | # if "grid_coord" not in self.keys(): 55 | # # if you don't want to operate GridSampling in data augmentation, 56 | # # please add the following augmentation into your pipeline: 57 | # # dict(type="Copy", keys_dict={"grid_size": 0.01}), 58 | # # (adjust `grid_size` to what you want) 59 | # assert {"grid_size", "coord"}.issubset(self.keys()) 60 | # self["grid_coord"] = torch.div( 61 | # self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc" 62 | # ).int() 63 | if "grid_coord" not in self.keys(): 64 | # if you don't want to operate GridSampling in data augmentation, 65 | # please add the following augmentation into your pipeline: 66 | # dict(type="Copy", keys_dict={"grid_size": 0.01}), 67 | # (adjust `grid_size` to what you want) 68 | assert {"grid_size", "coord"}.issubset(self.keys()) 69 | idx_ptr = torch.nn.functional.pad(self.offset, (1, 0), value=0) 70 | min_coord = torch_scatter.segment_csr(self.coord, idx_ptr, reduce="min") 71 | self["grid_coord"] = torch.div( 72 | self.coord - min_coord[self.batch], 73 | self.grid_size, 74 | rounding_mode="trunc", 75 | ).int() 76 | 77 | # print(self.grid_coord.max()) 78 | # print(int(self.grid_coord.max()).bit_length()) 79 | 80 | if depth is None: 81 | # Adaptively measure the depth of the serialization cube (side length = 2 ^ depth) 82 | depth = int(self.grid_coord.max()).bit_length() 83 | self["serialized_depth"] = depth 84 | # Maximum bit length for serialization code is 63 (int64) 85 | assert depth * 3 + len(self.offset).bit_length() <= 63 86 | # Here we follow OCNN and set the depth limitation to 16 (48bit) for the point position. 87 | # Although depth is limited to at most 16, we can encode a 655.36^3 (2^16 * 0.01) meter^3 88 | # cube with a grid size of 0.01 meter. We consider this sufficient for the current stage. 
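# Worked example (illustrative): a batch of 8 point clouds gives len(self.offset).bit_length() == 4, # so the int64 budget requires depth * 3 + 4 <= 63, i.e. depth <= 19; the OCNN-style cap below tightens this to 16. 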
89 | # We can unlock the limitation by optimizing the z-order encoding function if necessary. 90 | assert depth <= 16 91 | 92 | # The serialization codes are arranged as the following structure: 93 | # [Order1 ([n]), 94 | # Order2 ([n]), 95 | # ... 96 | # OrderN ([n])] (k, n) 97 | code = [ 98 | encode(self.grid_coord, self.batch, depth, order=order_) for order_ in order 99 | ] 100 | code = torch.stack(code) 101 | order = torch.argsort(code) 102 | inverse = torch.zeros_like(order).scatter_( 103 | dim=1, 104 | index=order, 105 | src=torch.arange(0, code.shape[1], device=order.device).repeat( 106 | code.shape[0], 1 107 | ), 108 | ) 109 | 110 | if shuffle_orders: 111 | perm = torch.randperm(code.shape[0]) 112 | code = code[perm] 113 | order = order[perm] 114 | inverse = inverse[perm] 115 | 116 | self["serialized_code"] = code 117 | self["serialized_order"] = order 118 | self["serialized_inverse"] = inverse 119 | 120 | def sparsify(self, pad=96): 121 | """ 122 | Point Cloud Sparsification 123 | 124 | Point clouds are sparse; here we use "sparsify" to specifically refer to 125 | preparing "spconv.SparseConvTensor" for SpConv. 126 | 127 | relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"] 128 | 129 | pad: padding added to the sparse shape. 130 | """ 131 | assert {"feat", "batch"}.issubset(self.keys()) 132 | # if "grid_coord" not in self.keys(): 133 | # # if you don't want to operate GridSampling in data augmentation, 134 | # # please add the following augmentation into your pipeline: 135 | # # dict(type="Copy", keys_dict={"grid_size": 0.01}), 136 | # # (adjust `grid_size` to what you want) 137 | # assert {"grid_size", "coord"}.issubset(self.keys()) 138 | # self["grid_coord"] = torch.div( 139 | # self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc" 140 | # ).int() 141 | if "grid_coord" not in self.keys(): 142 | # if you don't want to operate GridSampling in data augmentation, 143 | # please add the following augmentation into your pipeline: 144 | # dict(type="Copy", keys_dict={"grid_size": 0.01}), 145 | # (adjust `grid_size` to what you want) 146 | assert {"grid_size", "coord"}.issubset(self.keys()) 147 | idx_ptr = torch.nn.functional.pad(self.offset, (1, 0), value=0) 148 | min_coord = torch_scatter.segment_csr(self.coord, idx_ptr, reduce="min") 149 | self["grid_coord"] = torch.div( 150 | self.coord - min_coord[self.batch], 151 | self.grid_size, 152 | rounding_mode="trunc", 153 | ).int() 154 | if "sparse_shape" in self.keys(): 155 | sparse_shape = self.sparse_shape 156 | else: 157 | sparse_shape = torch.add( 158 | torch.max(self.grid_coord, dim=0).values, pad 159 | ).tolist() 160 | sparse_conv_feat = spconv.SparseConvTensor( 161 | features=self.feat, 162 | indices=torch.cat( 163 | [self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1 164 | ).contiguous(), 165 | spatial_shape=sparse_shape, 166 | batch_size=self.batch[-1].tolist() + 1, 167 | ) 168 | self["sparse_shape"] = sparse_shape 169 | self["sparse_conv_feat"] = sparse_conv_feat 170 | 171 | def octreetization(self, depth=None, full_depth=None): 172 | """ 173 | Point Cloud Octreetization 174 | 175 | Generate an octree with OCNN. 176 | relies on ["grid_coord", "batch", "feat"] 177 | """ 178 | assert ( 179 | ocnn is not None 180 | ), "Please follow https://github.com/octree-nn/ocnn-pytorch to install ocnn." 
181 | assert {"grid_coord", "feat", "batch"}.issubset(self.keys()) 182 | # add 1 to make grid space support shift order 183 | if depth is None: 184 | if "depth" in self.keys(): 185 | depth = self.depth 186 | else: 187 | depth = int(self.grid_coord.max() + 1).bit_length() 188 | if full_depth is None: 189 | full_depth = 2 190 | self["depth"] = depth 191 | assert depth <= 16 # maximum in ocnn 192 | 193 | # [0, 2**depth] -> [0, 2] -> [-1, 1] 194 | coord = self.grid_coord / 2 ** (self.depth - 1) - 1.0 195 | point = ocnn.octree.Points( 196 | points=coord, 197 | features=self.feat, 198 | batch_id=self.batch.unsqueeze(-1), 199 | batch_size=self.batch[-1] + 1, 200 | ) 201 | octree = ocnn.octree.Octree( 202 | depth=depth, 203 | full_depth=full_depth, 204 | batch_size=self.batch[-1] + 1, 205 | device=coord.device, 206 | ) 207 | octree.build_octree(point) 208 | octree.construct_all_neigh() 209 | self["octree"] = octree 210 | -------------------------------------------------------------------------------- /pointcept/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pointcept/SAMPart3D/6a508d145be5cecd7e80bce6e4248ea7abbd71b7/pointcept/utils/__init__.py -------------------------------------------------------------------------------- /pointcept/utils/cache.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data Cache Utils 3 | 4 | Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) 5 | Please cite our work if the code is helpful to you. 6 | """ 7 | 8 | import os 9 | # import SharedArray 10 | 11 | try: 12 | from multiprocessing.shared_memory import ShareableList 13 | except ImportError: 14 | import warnings 15 | 16 | warnings.warn("Please update python version >= 3.8 to enable shared_memory") 17 | import numpy as np 18 | 19 | 20 | def shared_array(name, var=None): 21 | if var is not None: 22 | # check exist 23 | if os.path.exists(f"/dev/shm/{name}"): 24 | return SharedArray.attach(f"shm://{name}") 25 | # create shared_array 26 | data = SharedArray.create(f"shm://{name}", var.shape, dtype=var.dtype) 27 | data[...] = var[...] 28 | data.flags.writeable = False 29 | else: 30 | data = SharedArray.attach(f"shm://{name}").copy() 31 | return data 32 | 33 | 34 | def shared_dict(name, var=None): 35 | name = str(name) 36 | assert "." not in name # '.' is used as sep flag 37 | data = {} 38 | if var is not None: 39 | assert isinstance(var, dict) 40 | keys = var.keys() 41 | # current version only cache np.array 42 | keys_valid = [] 43 | for key in keys: 44 | if isinstance(var[key], np.ndarray): 45 | keys_valid.append(key) 46 | keys = keys_valid 47 | 48 | ShareableList(sequence=keys, name=name + ".keys") 49 | for key in keys: 50 | if isinstance(var[key], np.ndarray): 51 | data[key] = shared_array(name=f"{name}.{key}", var=var[key]) 52 | else: 53 | keys = list(ShareableList(name=name + ".keys")) 54 | for key in keys: 55 | data[key] = shared_array(name=f"{name}.{key}") 56 | return data 57 | -------------------------------------------------------------------------------- /pointcept/utils/comm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | """ 3 | This file contains primitives for multi-gpu communication. 4 | This is useful when doing distributed training. 5 | Modified from detectron2(https://github.com/facebookresearch/detectron2) 6 | 7 | Copyright (c) Xiaoyang Wu (xiaoyang.wu@connect.hku.hk). 
All Rights Reserved. 8 | Please cite our work if you use any part of the code. 9 | """ 10 | 11 | import functools 12 | import numpy as np 13 | import torch 14 | import torch.distributed as dist 15 | 16 | _LOCAL_PROCESS_GROUP = None 17 | """ 18 | A torch process group which only includes processes that on the same machine as the current process. 19 | This variable is set when processes are spawned by `launch()` in "engine/launch.py". 20 | """ 21 | 22 | 23 | def get_world_size() -> int: 24 | if not dist.is_available(): 25 | return 1 26 | if not dist.is_initialized(): 27 | return 1 28 | return dist.get_world_size() 29 | 30 | 31 | def get_rank() -> int: 32 | if not dist.is_available(): 33 | return 0 34 | if not dist.is_initialized(): 35 | return 0 36 | return dist.get_rank() 37 | 38 | 39 | def get_local_rank() -> int: 40 | """ 41 | Returns: 42 | The rank of the current process within the local (per-machine) process group. 43 | """ 44 | if not dist.is_available(): 45 | return 0 46 | if not dist.is_initialized(): 47 | return 0 48 | assert ( 49 | _LOCAL_PROCESS_GROUP is not None 50 | ), "Local process group is not created! Please use launch() to spawn processes!" 51 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 52 | 53 | 54 | def get_local_size() -> int: 55 | """ 56 | Returns: 57 | The size of the per-machine process group, 58 | i.e. the number of processes per machine. 59 | """ 60 | if not dist.is_available(): 61 | return 1 62 | if not dist.is_initialized(): 63 | return 1 64 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 65 | 66 | 67 | def is_main_process() -> bool: 68 | return get_rank() == 0 69 | 70 | 71 | def synchronize(): 72 | """ 73 | Helper function to synchronize (barrier) among all processes when 74 | using distributed training 75 | """ 76 | if not dist.is_available(): 77 | return 78 | if not dist.is_initialized(): 79 | return 80 | world_size = dist.get_world_size() 81 | if world_size == 1: 82 | return 83 | if dist.get_backend() == dist.Backend.NCCL: 84 | # This argument is needed to avoid warnings. 85 | # It's valid only for NCCL backend. 86 | dist.barrier(device_ids=[torch.cuda.current_device()]) 87 | else: 88 | dist.barrier() 89 | 90 | 91 | @functools.lru_cache() 92 | def _get_global_gloo_group(): 93 | """ 94 | Return a process group based on gloo backend, containing all the ranks 95 | The result is cached. 96 | """ 97 | if dist.get_backend() == "nccl": 98 | return dist.new_group(backend="gloo") 99 | else: 100 | return dist.group.WORLD 101 | 102 | 103 | def all_gather(data, group=None): 104 | """ 105 | Run all_gather on arbitrary picklable data (not necessarily tensors). 106 | Args: 107 | data: any picklable object 108 | group: a torch process group. By default, will use a group which 109 | contains all ranks on gloo backend. 110 | Returns: 111 | list[data]: list of data gathered from each rank 112 | """ 113 | if get_world_size() == 1: 114 | return [data] 115 | if group is None: 116 | group = ( 117 | _get_global_gloo_group() 118 | ) # use CPU group by default, to reduce GPU RAM usage. 119 | world_size = dist.get_world_size(group) 120 | if world_size == 1: 121 | return [data] 122 | 123 | output = [None for _ in range(world_size)] 124 | dist.all_gather_object(output, data, group=group) 125 | return output 126 | 127 | 128 | def gather(data, dst=0, group=None): 129 | """ 130 | Run gather on arbitrary picklable data (not necessarily tensors). 131 | Args: 132 | data: any picklable object 133 | dst (int): destination rank 134 | group: a torch process group. 
By default, will use a group which 135 | contains all ranks on gloo backend. 136 | Returns: 137 | list[data]: on dst, a list of data gathered from each rank. Otherwise, 138 | an empty list. 139 | """ 140 | if get_world_size() == 1: 141 | return [data] 142 | if group is None: 143 | group = _get_global_gloo_group() 144 | world_size = dist.get_world_size(group=group) 145 | if world_size == 1: 146 | return [data] 147 | rank = dist.get_rank(group=group) 148 | 149 | if rank == dst: 150 | output = [None for _ in range(world_size)] 151 | dist.gather_object(data, output, dst=dst, group=group) 152 | return output 153 | else: 154 | dist.gather_object(data, None, dst=dst, group=group) 155 | return [] 156 | 157 | 158 | def shared_random_seed(): 159 | """ 160 | Returns: 161 | int: a random number that is the same across all workers. 162 | If workers need a shared RNG, they can use this shared seed to 163 | create one. 164 | All workers must call this function, otherwise it will deadlock. 165 | """ 166 | ints = np.random.randint(2**31) 167 | all_ints = all_gather(ints) 168 | return all_ints[0] 169 | 170 | 171 | def reduce_dict(input_dict, average=True): 172 | """ 173 | Reduce the values in the dictionary from all processes so that process with rank 174 | 0 has the reduced results. 175 | Args: 176 | input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. 177 | average (bool): whether to do average or sum 178 | Returns: 179 | a dict with the same keys as input_dict, after reduction. 180 | """ 181 | world_size = get_world_size() 182 | if world_size < 2: 183 | return input_dict 184 | with torch.no_grad(): 185 | names = [] 186 | values = [] 187 | # sort the keys so that they are consistent across processes 188 | for k in sorted(input_dict.keys()): 189 | names.append(k) 190 | values.append(input_dict[k]) 191 | values = torch.stack(values, dim=0) 192 | dist.reduce(values, dst=0) 193 | if dist.get_rank() == 0 and average: 194 | # only main process gets accumulated, so only divide by 195 | # world_size in this case 196 | values /= world_size 197 | reduced_dict = {k: v for k, v in zip(names, values)} 198 | return reduced_dict 199 | -------------------------------------------------------------------------------- /pointcept/utils/env.py: -------------------------------------------------------------------------------- 1 | """ 2 | Environment Utils 3 | 4 | Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) 5 | Please cite our work if the code is helpful to you. 
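Example (a minimal sketch of intended usage): set_seed(42) seeds Python's random, NumPy and PyTorch (CPU and CUDA) and switches cuDNN to deterministic mode; calling set_seed() with no argument draws a fresh seed from get_random_seed(). 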
6 | """ 7 | 8 | import os 9 | import random 10 | import numpy as np 11 | import torch 12 | import torch.backends.cudnn as cudnn 13 | 14 | from datetime import datetime 15 | 16 | 17 | def get_random_seed(): 18 | seed = ( 19 | os.getpid() 20 | + int(datetime.now().strftime("%S%f")) 21 | + int.from_bytes(os.urandom(2), "big") 22 | ) 23 | return seed 24 | 25 | 26 | def set_seed(seed=None): 27 | if seed is None: 28 | seed = get_random_seed() 29 | random.seed(seed) 30 | np.random.seed(seed) 31 | torch.manual_seed(seed) 32 | torch.cuda.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) 34 | cudnn.benchmark = False 35 | cudnn.deterministic = True 36 | os.environ["PYTHONHASHSEED"] = str(seed) 37 | -------------------------------------------------------------------------------- /pointcept/utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger Utils 3 | 4 | Modified from mmcv 5 | 6 | Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) 7 | Please cite our work if the code is helpful to you. 8 | """ 9 | 10 | import logging 11 | import torch 12 | import torch.distributed as dist 13 | 14 | from termcolor import colored 15 | 16 | logger_initialized = {} 17 | root_status = 0 18 | 19 | 20 | class _ColorfulFormatter(logging.Formatter): 21 | def __init__(self, *args, **kwargs): 22 | self._root_name = kwargs.pop("root_name") + "." 23 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 24 | 25 | def formatMessage(self, record): 26 | log = super(_ColorfulFormatter, self).formatMessage(record) 27 | if record.levelno == logging.WARNING: 28 | prefix = colored("WARNING", "red", attrs=["blink"]) 29 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 30 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 31 | else: 32 | return log 33 | return prefix + " " + log 34 | 35 | 36 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode="a", color=False): 37 | """Initialize and get a logger by name. 38 | 39 | If the logger has not been initialized, this method will initialize the 40 | logger by adding one or two handlers, otherwise the initialized logger will 41 | be directly returned. During initialization, a StreamHandler will always be 42 | added. If `log_file` is specified and the process rank is 0, a FileHandler 43 | will also be added. 44 | 45 | Args: 46 | name (str): Logger name. 47 | log_file (str | None): The log filename. If specified, a FileHandler 48 | will be added to the logger. 49 | log_level (int): The logger level. Note that only the process of 50 | rank 0 is affected, and other processes will set the level to 51 | "Error" thus be silent most of the time. 52 | file_mode (str): The file mode used in opening log file. 53 | Defaults to 'a'. 54 | color (bool): Colorful log output. Defaults to True 55 | 56 | Returns: 57 | logging.Logger: The expected logger. 58 | """ 59 | logger = logging.getLogger(name) 60 | 61 | if name in logger_initialized: 62 | return logger 63 | # handle hierarchical names 64 | # e.g., logger "a" is initialized, then logger "a.b" will skip the 65 | # initialization since it is a child of "a". 
66 | for logger_name in logger_initialized: 67 | if name.startswith(logger_name): 68 | return logger 69 | 70 | logger.propagate = False 71 | 72 | stream_handler = logging.StreamHandler() 73 | handlers = [stream_handler] 74 | 75 | if dist.is_available() and dist.is_initialized(): 76 | rank = dist.get_rank() 77 | else: 78 | rank = 0 79 | 80 | # only rank 0 will add a FileHandler 81 | if rank == 0 and log_file is not None: 82 | # Here, the default behaviour of the official logger is 'a'. Thus, we 83 | # provide an interface to change the file mode to the default 84 | # behaviour. 85 | file_handler = logging.FileHandler(log_file, file_mode) 86 | handlers.append(file_handler) 87 | 88 | plain_formatter = logging.Formatter( 89 | "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s" 90 | ) 91 | if color: 92 | formatter = _ColorfulFormatter( 93 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 94 | datefmt="%m/%d %H:%M:%S", 95 | root_name=name, 96 | ) 97 | else: 98 | formatter = plain_formatter 99 | for handler in handlers: 100 | handler.setFormatter(formatter) 101 | handler.setLevel(log_level) 102 | logger.addHandler(handler) 103 | 104 | if rank == 0: 105 | logger.setLevel(log_level) 106 | else: 107 | logger.setLevel(logging.ERROR) 108 | 109 | logger_initialized[name] = True 110 | 111 | return logger 112 | 113 | 114 | def print_log(msg, logger=None, level=logging.INFO): 115 | """Print a log message. 116 | 117 | Args: 118 | msg (str): The message to be logged. 119 | logger (logging.Logger | str | None): The logger to be used. 120 | Some special loggers are: 121 | - "silent": no message will be printed. 122 | - other str: the logger obtained with `get_root_logger(logger)`. 123 | - None: The `print()` method will be used to print log messages. 124 | level (int): Logging level. Only available when `logger` is a Logger 125 | object or "root". 126 | """ 127 | if logger is None: 128 | print(msg) 129 | elif isinstance(logger, logging.Logger): 130 | logger.log(level, msg) 131 | elif logger == "silent": 132 | pass 133 | elif isinstance(logger, str): 134 | _logger = get_logger(logger) 135 | _logger.log(level, msg) 136 | else: 137 | raise TypeError( 138 | "logger should be either a logging.Logger object, str, " 139 | f'"silent" or None, but got {type(logger)}' 140 | ) 141 | 142 | 143 | def get_root_logger(log_file=None, log_level=logging.INFO, file_mode="a"): 144 | """Get the root logger. 145 | 146 | The logger will be initialized if it has not been initialized. By default a 147 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 148 | also be added. The name of the root logger is the top-level package name. 149 | 150 | Args: 151 | log_file (str | None): The log filename. If specified, a FileHandler 152 | will be added to the root logger. 153 | log_level (int): The root logger level. Note that only the process of 154 | rank 0 is affected, while other processes will set the level to 155 | "Error" and be silent most of the time. 156 | file_mode (str): File Mode of logger. (w or a) 157 | 158 | Returns: 159 | logging.Logger: The root logger. 160 | """ 161 | logger = get_logger( 162 | name="pointcept", log_file=log_file, log_level=log_level, file_mode=file_mode 163 | ) 164 | return logger 165 | 166 | 167 | def _log_api_usage(identifier: str): 168 | """ 169 | Internal function used to log the usage of different detectron2 components 170 | inside facebook's infra. 171 | """ 172 | torch._C._log_api_usage_once("pointcept." 
+ identifier) 173 | -------------------------------------------------------------------------------- /pointcept/utils/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc 3 | 4 | Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) 5 | Please cite our work if the code is helpful to you. 6 | """ 7 | 8 | import os 9 | import warnings 10 | from collections import abc 11 | import numpy as np 12 | import torch 13 | from importlib import import_module 14 | 15 | 16 | class AverageMeter(object): 17 | """Computes and stores the average and current value""" 18 | 19 | def __init__(self): 20 | self.val = 0 21 | self.avg = 0 22 | self.sum = 0 23 | self.count = 0 24 | 25 | def reset(self): 26 | self.val = 0 27 | self.avg = 0 28 | self.sum = 0 29 | self.count = 0 30 | 31 | def update(self, val, n=1): 32 | self.val = val 33 | self.sum += val * n 34 | self.count += n 35 | self.avg = self.sum / self.count 36 | 37 | 38 | def intersection_and_union(output, target, K, ignore_index=-1): 39 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. 40 | assert output.ndim in [1, 2, 3] 41 | assert output.shape == target.shape 42 | output = output.reshape(output.size).copy() 43 | target = target.reshape(target.size) 44 | output[np.where(target == ignore_index)[0]] = ignore_index 45 | intersection = output[np.where(output == target)[0]] 46 | area_intersection, _ = np.histogram(intersection, bins=np.arange(K + 1)) 47 | area_output, _ = np.histogram(output, bins=np.arange(K + 1)) 48 | area_target, _ = np.histogram(target, bins=np.arange(K + 1)) 49 | area_union = area_output + area_target - area_intersection 50 | return area_intersection, area_union, area_target 51 | 52 | 53 | def intersection_and_union_gpu(output, target, k, ignore_index=-1): 54 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. 55 | assert output.dim() in [1, 2, 3] 56 | assert output.shape == target.shape 57 | output = output.view(-1) 58 | target = target.view(-1) 59 | output[target == ignore_index] = ignore_index 60 | intersection = output[output == target] 61 | area_intersection = torch.histc(intersection, bins=k, min=0, max=k - 1) 62 | area_output = torch.histc(output, bins=k, min=0, max=k - 1) 63 | area_target = torch.histc(target, bins=k, min=0, max=k - 1) 64 | area_union = area_output + area_target - area_intersection 65 | return area_intersection, area_union, area_target 66 | 67 | 68 | def make_dirs(dir_name): 69 | if not os.path.exists(dir_name): 70 | os.makedirs(dir_name, exist_ok=True) 71 | 72 | 73 | def find_free_port(): 74 | import socket 75 | 76 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 77 | # Binding to port 0 will cause the OS to find an available port for us 78 | sock.bind(("", 0)) 79 | port = sock.getsockname()[1] 80 | sock.close() 81 | # NOTE: there is still a chance the port could be taken by other processes. 82 | return port 83 | 84 | 85 | def is_seq_of(seq, expected_type, seq_type=None): 86 | """Check whether it is a sequence of some type. 87 | 88 | Args: 89 | seq (Sequence): The sequence to be checked. 90 | expected_type (type): Expected type of sequence items. 91 | seq_type (type, optional): Expected sequence type. 92 | 93 | Returns: 94 | bool: Whether the sequence is valid. 
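
    Examples:
        >>> # illustrative examples (added); behavior follows the checks below
        >>> is_seq_of([1, 2, 3], int)
        True
        >>> is_seq_of((1, 2), int, seq_type=tuple)
        True
        >>> is_seq_of([1, "a"], int)
        False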
95 | """ 96 | if seq_type is None: 97 | exp_seq_type = abc.Sequence 98 | else: 99 | assert isinstance(seq_type, type) 100 | exp_seq_type = seq_type 101 | if not isinstance(seq, exp_seq_type): 102 | return False 103 | for item in seq: 104 | if not isinstance(item, expected_type): 105 | return False 106 | return True 107 | 108 | 109 | def is_str(x): 110 | """Whether the input is an string instance. 111 | 112 | Note: This method is deprecated since python 2 is no longer supported. 113 | """ 114 | return isinstance(x, str) 115 | 116 | 117 | def import_modules_from_strings(imports, allow_failed_imports=False): 118 | """Import modules from the given list of strings. 119 | 120 | Args: 121 | imports (list | str | None): The given module names to be imported. 122 | allow_failed_imports (bool): If True, the failed imports will return 123 | None. Otherwise, an ImportError is raise. Default: False. 124 | 125 | Returns: 126 | list[module] | module | None: The imported modules. 127 | 128 | Examples: 129 | >>> osp, sys = import_modules_from_strings( 130 | ... ['os.path', 'sys']) 131 | >>> import os.path as osp_ 132 | >>> import sys as sys_ 133 | >>> assert osp == osp_ 134 | >>> assert sys == sys_ 135 | """ 136 | if not imports: 137 | return 138 | single_import = False 139 | if isinstance(imports, str): 140 | single_import = True 141 | imports = [imports] 142 | if not isinstance(imports, list): 143 | raise TypeError(f"custom_imports must be a list but got type {type(imports)}") 144 | imported = [] 145 | for imp in imports: 146 | if not isinstance(imp, str): 147 | raise TypeError(f"{imp} is of type {type(imp)} and cannot be imported.") 148 | try: 149 | imported_tmp = import_module(imp) 150 | except ImportError: 151 | if allow_failed_imports: 152 | warnings.warn(f"{imp} failed to import and is ignored.", UserWarning) 153 | imported_tmp = None 154 | else: 155 | raise ImportError 156 | imported.append(imported_tmp) 157 | if single_import: 158 | imported = imported[0] 159 | return imported 160 | -------------------------------------------------------------------------------- /pointcept/utils/optimizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Optimizer 3 | 4 | Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) 5 | Please cite our work if the code is helpful to you. 
6 | """ 7 | 8 | import torch 9 | from pointcept.utils.logger import get_root_logger 10 | from pointcept.utils.registry import Registry 11 | 12 | OPTIMIZERS = Registry("optimizers") 13 | 14 | 15 | OPTIMIZERS.register_module(module=torch.optim.SGD, name="SGD") 16 | OPTIMIZERS.register_module(module=torch.optim.Adam, name="Adam") 17 | OPTIMIZERS.register_module(module=torch.optim.AdamW, name="AdamW") 18 | 19 | 20 | def build_optimizer(cfg, model, param_dicts=None): 21 | if param_dicts is None: 22 | # cfg.params = model.parameters() 23 | cfg.params = [dict(names=[], params=[], lr=cfg.lr)] 24 | for n, p in model.named_parameters(): 25 | if not p.requires_grad: 26 | continue 27 | cfg.params[0]["names"].append(n) 28 | cfg.params[0]["params"].append(p) 29 | else: 30 | cfg.params = [dict(names=[], params=[], lr=cfg.lr)] 31 | for i in range(len(param_dicts)): 32 | param_group = dict(names=[], params=[]) 33 | if "lr" in param_dicts[i].keys(): 34 | param_group["lr"] = param_dicts[i].lr 35 | if "momentum" in param_dicts[i].keys(): 36 | param_group["momentum"] = param_dicts[i].momentum 37 | if "weight_decay" in param_dicts[i].keys(): 38 | param_group["weight_decay"] = param_dicts[i].weight_decay 39 | cfg.params.append(param_group) 40 | 41 | for n, p in model.named_parameters(): 42 | # !!! requires_grad is a must 43 | if not p.requires_grad: 44 | continue 45 | flag = False 46 | for i in range(len(param_dicts)): 47 | if param_dicts[i].keyword in n: 48 | cfg.params[i + 1]["names"].append(n) 49 | cfg.params[i + 1]["params"].append(p) 50 | flag = True 51 | break 52 | if not flag: 53 | cfg.params[0]["names"].append(n) 54 | cfg.params[0]["params"].append(p) 55 | 56 | logger = get_root_logger() 57 | 58 | for i in range(len(cfg.params)): 59 | param_names = cfg.params[i].pop("names") 60 | message = "" 61 | for key in cfg.params[i].keys(): 62 | if key != "params": 63 | message += f" {key}: {cfg.params[i][key]};" 64 | logger.info(f"Params Group {i+1} -{message} Params: {param_names}.") 65 | # print(111) 66 | # exit(0) 67 | return OPTIMIZERS.build(cfg=cfg) 68 | -------------------------------------------------------------------------------- /pointcept/utils/path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os 3 | import os.path as osp 4 | from pathlib import Path 5 | 6 | from .misc import is_str 7 | 8 | 9 | def is_filepath(x): 10 | return is_str(x) or isinstance(x, Path) 11 | 12 | 13 | def fopen(filepath, *args, **kwargs): 14 | if is_str(filepath): 15 | return open(filepath, *args, **kwargs) 16 | elif isinstance(filepath, Path): 17 | return filepath.open(*args, **kwargs) 18 | raise ValueError("`filepath` should be a string or a Path") 19 | 20 | 21 | def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): 22 | if not osp.isfile(filename): 23 | raise FileNotFoundError(msg_tmpl.format(filename)) 24 | 25 | 26 | def mkdir_or_exist(dir_name, mode=0o777): 27 | if dir_name == "": 28 | return 29 | dir_name = osp.expanduser(dir_name) 30 | os.makedirs(dir_name, mode=mode, exist_ok=True) 31 | 32 | 33 | def symlink(src, dst, overwrite=True, **kwargs): 34 | if os.path.lexists(dst) and overwrite: 35 | os.remove(dst) 36 | os.symlink(src, dst, **kwargs) 37 | 38 | 39 | def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True): 40 | """Scan a directory to find the interested files. 41 | 42 | Args: 43 | dir_path (str | obj:`Path`): Path of the directory. 
44 | suffix (str | tuple(str), optional): File suffix that we are 45 | interested in. Default: None. 46 | recursive (bool, optional): If set to True, recursively scan the 47 | directory. Default: False. 48 | case_sensitive (bool, optional) : If set to False, ignore the case of 49 | suffix. Default: True. 50 | 51 | Returns: 52 | A generator for all the interested files with relative paths. 53 | """ 54 | if isinstance(dir_path, (str, Path)): 55 | dir_path = str(dir_path) 56 | else: 57 | raise TypeError('"dir_path" must be a string or Path object') 58 | 59 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 60 | raise TypeError('"suffix" must be a string or tuple of strings') 61 | 62 | if suffix is not None and not case_sensitive: 63 | suffix = ( 64 | suffix.lower() 65 | if isinstance(suffix, str) 66 | else tuple(item.lower() for item in suffix) 67 | ) 68 | 69 | root = dir_path 70 | 71 | def _scandir(dir_path, suffix, recursive, case_sensitive): 72 | for entry in os.scandir(dir_path): 73 | if not entry.name.startswith(".") and entry.is_file(): 74 | rel_path = osp.relpath(entry.path, root) 75 | _rel_path = rel_path if case_sensitive else rel_path.lower() 76 | if suffix is None or _rel_path.endswith(suffix): 77 | yield rel_path 78 | elif recursive and os.path.isdir(entry.path): 79 | # scan recursively if entry.path is a directory 80 | yield from _scandir(entry.path, suffix, recursive, case_sensitive) 81 | 82 | return _scandir(dir_path, suffix, recursive, case_sensitive) 83 | 84 | 85 | def find_vcs_root(path, markers=(".git",)): 86 | """Finds the root directory (including itself) of specified markers. 87 | 88 | Args: 89 | path (str): Path of directory or file. 90 | markers (list[str], optional): List of file or directory names. 91 | 92 | Returns: 93 | The directory contained one of the markers or None if not found. 94 | """ 95 | if osp.isfile(path): 96 | path = osp.dirname(path) 97 | 98 | prev, cur = None, osp.abspath(osp.expanduser(path)) 99 | while cur != prev: 100 | if any(osp.exists(osp.join(cur, marker)) for marker in markers): 101 | return cur 102 | prev, cur = cur, osp.split(cur)[0] 103 | return None 104 | -------------------------------------------------------------------------------- /pointcept/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Scheduler 3 | 4 | Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) 5 | Please cite our work if the code is helpful to you. 
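
Example (illustrative sketch, not from the original file; cfg handling
mirrors build_optimizer in optimizer.py):

    sched_cfg = Config(dict(type="CosineAnnealingLR", total_steps=10000))
    scheduler = build_scheduler(cfg=sched_cfg, optimizer=optimizer)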
6 | """ 7 | 8 | import torch.optim.lr_scheduler as lr_scheduler 9 | from .registry import Registry 10 | 11 | SCHEDULERS = Registry("schedulers") 12 | 13 | 14 | @SCHEDULERS.register_module() 15 | class MultiStepLR(lr_scheduler.MultiStepLR): 16 | def __init__( 17 | self, 18 | optimizer, 19 | milestones, 20 | total_steps, 21 | gamma=0.1, 22 | last_epoch=-1, 23 | verbose=False, 24 | ): 25 | super().__init__( 26 | optimizer=optimizer, 27 | milestones=[rate * total_steps for rate in milestones], 28 | gamma=gamma, 29 | last_epoch=last_epoch, 30 | verbose=verbose, 31 | ) 32 | 33 | 34 | @SCHEDULERS.register_module() 35 | class MultiStepWithWarmupLR(lr_scheduler.LambdaLR): 36 | def __init__( 37 | self, 38 | optimizer, 39 | milestones, 40 | total_steps, 41 | gamma=0.1, 42 | warmup_rate=0.05, 43 | warmup_scale=1e-6, 44 | last_epoch=-1, 45 | verbose=False, 46 | ): 47 | milestones = [rate * total_steps for rate in milestones] 48 | 49 | def multi_step_with_warmup(s): 50 | factor = 1.0 51 | for i in range(len(milestones)): 52 | if s < milestones[i]: 53 | break 54 | factor *= gamma 55 | 56 | if s <= warmup_rate * total_steps: 57 | warmup_coefficient = 1 - (1 - s / warmup_rate / total_steps) * ( 58 | 1 - warmup_scale 59 | ) 60 | else: 61 | warmup_coefficient = 1.0 62 | return warmup_coefficient * factor 63 | 64 | super().__init__( 65 | optimizer=optimizer, 66 | lr_lambda=multi_step_with_warmup, 67 | last_epoch=last_epoch, 68 | verbose=verbose, 69 | ) 70 | 71 | 72 | @SCHEDULERS.register_module() 73 | class PolyLR(lr_scheduler.LambdaLR): 74 | def __init__(self, optimizer, total_steps, power=0.9, last_epoch=-1, verbose=False): 75 | super().__init__( 76 | optimizer=optimizer, 77 | lr_lambda=lambda s: (1 - s / (total_steps + 1)) ** power, 78 | last_epoch=last_epoch, 79 | verbose=verbose, 80 | ) 81 | 82 | 83 | @SCHEDULERS.register_module() 84 | class ExpLR(lr_scheduler.LambdaLR): 85 | def __init__(self, optimizer, total_steps, gamma=0.9, last_epoch=-1, verbose=False): 86 | super().__init__( 87 | optimizer=optimizer, 88 | lr_lambda=lambda s: gamma ** (s / total_steps), 89 | last_epoch=last_epoch, 90 | verbose=verbose, 91 | ) 92 | 93 | 94 | @SCHEDULERS.register_module() 95 | class CosineAnnealingLR(lr_scheduler.CosineAnnealingLR): 96 | def __init__(self, optimizer, total_steps, eta_min=0, last_epoch=-1, verbose=False): 97 | super().__init__( 98 | optimizer=optimizer, 99 | T_max=total_steps, 100 | eta_min=eta_min, 101 | last_epoch=last_epoch, 102 | verbose=verbose, 103 | ) 104 | 105 | 106 | @SCHEDULERS.register_module() 107 | class OneCycleLR(lr_scheduler.OneCycleLR): 108 | r""" 109 | torch.optim.lr_scheduler.OneCycleLR, Block total_steps 110 | """ 111 | 112 | def __init__( 113 | self, 114 | optimizer, 115 | max_lr, 116 | total_steps=None, 117 | pct_start=0.3, 118 | anneal_strategy="cos", 119 | cycle_momentum=True, 120 | base_momentum=0.85, 121 | max_momentum=0.95, 122 | div_factor=25.0, 123 | final_div_factor=1e4, 124 | three_phase=False, 125 | last_epoch=-1, 126 | verbose=False, 127 | ): 128 | super().__init__( 129 | optimizer=optimizer, 130 | max_lr=max_lr, 131 | total_steps=total_steps, 132 | pct_start=pct_start, 133 | anneal_strategy=anneal_strategy, 134 | cycle_momentum=cycle_momentum, 135 | base_momentum=base_momentum, 136 | max_momentum=max_momentum, 137 | div_factor=div_factor, 138 | final_div_factor=final_div_factor, 139 | three_phase=three_phase, 140 | last_epoch=last_epoch, 141 | verbose=verbose, 142 | ) 143 | 144 | 145 | def build_scheduler(cfg, optimizer): 146 | cfg.optimizer = optimizer 147 | return 
SCHEDULERS.build(cfg=cfg) 148 | -------------------------------------------------------------------------------- /pointcept/utils/timer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # -*- coding: utf-8 -*- 3 | 4 | from time import perf_counter 5 | from typing import Optional 6 | 7 | 8 | class Timer: 9 | """ 10 | A timer which computes the time elapsed since the start/reset of the timer. 11 | """ 12 | 13 | def __init__(self) -> None: 14 | self.reset() 15 | 16 | def reset(self) -> None: 17 | """ 18 | Reset the timer. 19 | """ 20 | self._start = perf_counter() 21 | self._paused: Optional[float] = None 22 | self._total_paused = 0 23 | self._count_start = 1 24 | 25 | def pause(self) -> None: 26 | """ 27 | Pause the timer. 28 | """ 29 | if self._paused is not None: 30 | raise ValueError("Trying to pause a Timer that is already paused!") 31 | self._paused = perf_counter() 32 | 33 | def is_paused(self) -> bool: 34 | """ 35 | Returns: 36 | bool: whether the timer is currently paused 37 | """ 38 | return self._paused is not None 39 | 40 | def resume(self) -> None: 41 | """ 42 | Resume the timer. 43 | """ 44 | if self._paused is None: 45 | raise ValueError("Trying to resume a Timer that is not paused!") 46 | # pyre-fixme[58]: `-` is not supported for operand types `float` and 47 | # `Optional[float]`. 48 | self._total_paused += perf_counter() - self._paused 49 | self._paused = None 50 | self._count_start += 1 51 | 52 | def seconds(self) -> float: 53 | """ 54 | Returns: 55 | (float): the total number of seconds since the start/reset of the 56 | timer, excluding the time when the timer is paused. 57 | """ 58 | if self._paused is not None: 59 | end_time: float = self._paused # type: ignore 60 | else: 61 | end_time = perf_counter() 62 | return end_time - self._start - self._total_paused 63 | 64 | def avg_seconds(self) -> float: 65 | """ 66 | Returns: 67 | (float): the average number of seconds between every start/reset and 68 | pause. 69 | """ 70 | return self.seconds() / self._count_start 71 | -------------------------------------------------------------------------------- /pointcept/utils/visualization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualization Utils 3 | 4 | Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) 5 | Please cite our work if the code is helpful to you. 
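
Example (illustrative sketch, not from the original file):

    # coord: (N, 3) array or tensor; color: optional (N, 3) values in [0, 1]
    save_point_cloud(coord, color=None, file_path="exp/vis/pc.ply")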
6 | """ 7 | 8 | import os 9 | import open3d as o3d 10 | import numpy as np 11 | import torch 12 | 13 | 14 | def to_numpy(x): 15 | if isinstance(x, torch.Tensor): 16 | x = x.clone().detach().cpu().numpy() 17 | assert isinstance(x, np.ndarray) 18 | return x 19 | 20 | 21 | def save_point_cloud(coord, color=None, file_path="pc.ply", logger=None): 22 | os.makedirs(os.path.dirname(file_path), exist_ok=True) 23 | coord = to_numpy(coord) 24 | if color is not None: 25 | color = to_numpy(color) 26 | pcd = o3d.geometry.PointCloud() 27 | pcd.points = o3d.utility.Vector3dVector(coord) 28 | pcd.colors = o3d.utility.Vector3dVector( 29 | np.ones_like(coord) if color is None else color 30 | ) 31 | o3d.io.write_point_cloud(file_path, pcd) 32 | if logger is not None: 33 | logger.info(f"Save Point Cloud to: {file_path}") 34 | 35 | 36 | def save_bounding_boxes( 37 | bboxes_corners, color=(1.0, 0.0, 0.0), file_path="bbox.ply", logger=None 38 | ): 39 | bboxes_corners = to_numpy(bboxes_corners) 40 | # point list 41 | points = bboxes_corners.reshape(-1, 3) 42 | # line list 43 | box_lines = np.array( 44 | [ 45 | [0, 1], 46 | [1, 2], 47 | [2, 3], 48 | [3, 0], 49 | [4, 5], 50 | [5, 6], 51 | [6, 7], 52 | [7, 0], 53 | [0, 4], 54 | [1, 5], 55 | [2, 6], 56 | [3, 7], 57 | ] 58 | ) 59 | lines = [] 60 | for i, _ in enumerate(bboxes_corners): 61 | lines.append(box_lines + i * 8) 62 | lines = np.concatenate(lines) 63 | # color list 64 | color = np.array([color for _ in range(len(lines))]) 65 | # generate line set 66 | line_set = o3d.geometry.LineSet() 67 | line_set.points = o3d.utility.Vector3dVector(points) 68 | line_set.lines = o3d.utility.Vector2iVector(lines) 69 | line_set.colors = o3d.utility.Vector3dVector(color) 70 | o3d.io.write_line_set(file_path, line_set) 71 | 72 | if logger is not None: 73 | logger.info(f"Save Boxes to: {file_path}") 74 | 75 | 76 | def save_lines( 77 | points, lines, color=(1.0, 0.0, 0.0), file_path="lines.ply", logger=None 78 | ): 79 | points = to_numpy(points) 80 | lines = to_numpy(lines) 81 | colors = np.array([color for _ in range(len(lines))]) 82 | line_set = o3d.geometry.LineSet() 83 | line_set.points = o3d.utility.Vector3dVector(points) 84 | line_set.lines = o3d.utility.Vector2iVector(lines) 85 | line_set.colors = o3d.utility.Vector3dVector(colors) 86 | o3d.io.write_line_set(file_path, line_set) 87 | 88 | if logger is not None: 89 | logger.info(f"Save Lines to: {file_path}") 90 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow 2 | opencv-python 3 | transformers 4 | einops 5 | scikit-learn 6 | tensorboard 7 | tensorboardx 8 | yapf 9 | addict 10 | scipy 11 | timm 12 | open3d 13 | trimesh 14 | torch-scatter -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | cd $(dirname $(dirname "$0")) || exit 4 | PYTHON=python 5 | 6 | TEST_CODE=eval.py 7 | 8 | DATASET=scannet 9 | CONFIG="None" 10 | EXP_NAME=debug 11 | WEIGHT=model_best 12 | GPU=None 13 | 14 | while getopts "p:d:c:n:w:g:" opt; do 15 | case $opt in 16 | p) 17 | PYTHON=$OPTARG 18 | ;; 19 | d) 20 | DATASET=$OPTARG 21 | ;; 22 | c) 23 | CONFIG=$OPTARG 24 | ;; 25 | n) 26 | EXP_NAME=$OPTARG 27 | ;; 28 | w) 29 | WEIGHT=$OPTARG 30 | ;; 31 | g) 32 | GPU=$OPTARG 33 | ;; 34 | \?) 
35 | echo "Invalid option: -$OPTARG" 36 | ;; 37 | esac 38 | done 39 | 40 | if [ "${NUM_GPU}" = 'None' ] 41 | then 42 | NUM_GPU=`$PYTHON -c 'import torch; print(torch.cuda.device_count())'` 43 | fi 44 | 45 | echo "Experiment name: $EXP_NAME" 46 | echo "Python interpreter dir: $PYTHON" 47 | echo "Dataset: $DATASET" 48 | echo "GPU Num: $GPU" 49 | 50 | EXP_DIR=exp/${DATASET}/${EXP_NAME} 51 | MODEL_DIR=${EXP_DIR}/model 52 | CODE_DIR=${EXP_DIR}/code 53 | CONFIG_DIR=${EXP_DIR}/config.py 54 | 55 | if [ "${CONFIG}" = "None" ] 56 | then 57 | CONFIG_DIR=${EXP_DIR}/config.py 58 | else 59 | CONFIG_DIR=configs/${DATASET}/${CONFIG}.py 60 | fi 61 | 62 | echo "Loading config in:" $CONFIG_DIR 63 | export PYTHONPATH=./$CODE_DIR 64 | # export PYTHONPATH=./ 65 | echo "Running code in: $CODE_DIR" 66 | 67 | 68 | echo " =========> RUN TASK <=========" 69 | 70 | #$PYTHON -u "$CODE_DIR"/tools/$TEST_CODE \ 71 | $PYTHON -u launch/$TEST_CODE \ 72 | --config-file "$CONFIG_DIR" \ 73 | --num-gpus "$GPU" \ 74 | --options save_path="$EXP_DIR" weight="${MODEL_DIR}"/"${WEIGHT}".pth 75 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | cd $(dirname $(dirname "$0")) || exit 4 | ROOT_DIR=$(pwd) 5 | PYTHON=python 6 | 7 | TRAIN_CODE=train.py 8 | 9 | DATASET=scannet 10 | CONFIG="None" 11 | EXP_NAME=debug 12 | WEIGHT="None" 13 | RESUME=false 14 | GPU=None 15 | OID="None" 16 | LABEL="None" 17 | 18 | 19 | while getopts "p:d:c:n:w:g:r:o:l:" opt; do 20 | case $opt in 21 | p) 22 | PYTHON=$OPTARG 23 | ;; 24 | d) 25 | DATASET=$OPTARG 26 | ;; 27 | c) 28 | CONFIG=$OPTARG 29 | ;; 30 | n) 31 | EXP_NAME=$OPTARG 32 | ;; 33 | w) 34 | WEIGHT=$OPTARG 35 | ;; 36 | r) 37 | RESUME=$OPTARG 38 | ;; 39 | g) 40 | GPU=$OPTARG 41 | ;; 42 | o) 43 | OID=$OPTARG 44 | ;; 45 | l) 46 | LABEL=$OPTARG 47 | ;; 48 | \?) 
49 | echo "Invalid option: -$OPTARG" 50 | ;; 51 | esac 52 | done 53 | 54 | if [ "${NUM_GPU}" = 'None' ] 55 | then 56 | NUM_GPU=`$PYTHON -c 'import torch; print(torch.cuda.device_count())'` 57 | fi 58 | 59 | echo "Experiment name: $EXP_NAME" 60 | echo "Python interpreter dir: $PYTHON" 61 | echo "Dataset: $DATASET" 62 | echo "Config: $CONFIG" 63 | echo "GPU Num: $GPU" 64 | 65 | EXP_DIR=exp/${DATASET}/${EXP_NAME} 66 | MODEL_DIR=${EXP_DIR}/model 67 | CODE_DIR=${EXP_DIR}/code 68 | CONFIG_DIR=configs/${DATASET}/${CONFIG}.py 69 | 70 | 71 | echo " =========> CREATE EXP DIR <=========" 72 | echo "Experiment dir: $ROOT_DIR/$EXP_DIR" 73 | if ${RESUME} 74 | then 75 | CONFIG_DIR=${EXP_DIR}/config.py 76 | WEIGHT=$MODEL_DIR/model_last.pth 77 | else 78 | mkdir -p "$MODEL_DIR" "$CODE_DIR" 79 | cp -r scripts launch pointcept "$CODE_DIR" 80 | fi 81 | 82 | echo "Loading config in:" $CONFIG_DIR 83 | # export PYTHONPATH=./$CODE_DIR:/usr/local/lib/python3.8/dist-packages 84 | export PYTHONPATH=./$CODE_DIR:/opt/conda/envs/part/lib/python3.10/site-packages 85 | echo "Running code in: $CODE_DIR" 86 | 87 | 88 | echo " =========> RUN TASK <=========" 89 | 90 | if [ "${WEIGHT}" = "None" ] 91 | then 92 | $PYTHON "$CODE_DIR"/launch/$TRAIN_CODE \ 93 | --config-file "$CONFIG_DIR" \ 94 | --num-gpus "$GPU" \ 95 | --options save_path="$EXP_DIR" oid="$OID" label="$LABEL" 96 | else 97 | $PYTHON "$CODE_DIR"/launch/$TRAIN_CODE \ 98 | --config-file "$CONFIG_DIR" \ 99 | --num-gpus "$GPU" \ 100 | --options save_path="$EXP_DIR" resume="$RESUME" weight="$WEIGHT" 101 | fi -------------------------------------------------------------------------------- /tools/highlight_parts.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import trimesh 4 | import json 5 | import cv2 6 | import os 7 | from os.path import join 8 | from PIL import Image, ImageDraw 9 | from scipy import ndimage as ndi 10 | import pointops 11 | import sys 12 | sys.path.append(os.path.abspath("..")) 13 | 14 | from pointcept.datasets.sampart3d_util import * 15 | 16 | 17 | def cal_mapping_2d_3d(render_dir, mesh_path): 18 | mesh = trimesh.load(mesh_path) 19 | if isinstance(mesh, trimesh.Scene): 20 | mesh = mesh.dump(concatenate=True) 21 | samples, face_index, colors = sample_surface(mesh, 50000, sample_color=True) 22 | face_index = torch.from_numpy(face_index).int() 23 | face_index = torch.concat([face_index, torch.tensor([-1]).int()]) 24 | 25 | meta_data = json.load(open(join(render_dir, "meta.json"))) 26 | mesh_scale = meta_data["scaling_factor"] 27 | mesh_center_offset = meta_data["mesh_offset"] 28 | 29 | object_org_coord = samples 30 | rotation_matrix = np.array([ 31 | [1, 0, 0], 32 | [0, 0, 1], 33 | [0, -1, 0]]) 34 | object_org_coord = np.dot(object_org_coord, rotation_matrix) 35 | object_org_coord = object_org_coord * mesh_scale + mesh_center_offset 36 | object_org_coord = torch.from_numpy(object_org_coord).to("cuda").contiguous().float() 37 | obj_offset = torch.tensor(object_org_coord.shape[0]).to("cuda") 38 | 39 | mapping_list = [] 40 | camera_angle_x = meta_data['camera_angle_x'] 41 | for i, c2w_opengl in enumerate(meta_data["transforms"]): 42 | c2w_opengl = np.array(c2w_opengl) 43 | rgb_path = join(render_dir, f"render_{i:04d}.webp") 44 | img = np.array(Image.open(rgb_path)) 45 | if img.shape[-1] == 4: 46 | mask_img = img[..., 3] == 0 47 | img[mask_img] = [255, 255, 255, 255] 48 | img = img[..., :3] 49 | img = Image.fromarray(img.astype('uint8')) 50 | 51 | # Calculate mapping 52 | 
52 |         depth_path = join(render_dir, f"depth_{i:04d}.exr")
53 |         depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
54 |         depth = depth[..., 0]
55 |         depth_valid = torch.tensor(depth < 65500.0)
56 |
57 |         org_points = gen_pcd(depth, c2w_opengl, camera_angle_x)
58 |
59 |         org_points = torch.from_numpy(org_points)
60 |         points_tensor = org_points.to("cuda").contiguous().float()
61 |         offset = torch.tensor(points_tensor.shape[0]).to("cuda")
62 |         indices, distances = pointops.knn_query(1, object_org_coord, obj_offset, points_tensor, offset)
63 |         mapping = torch.zeros((depth.shape[0], depth.shape[1]), dtype=torch.int) - 1
64 |
65 |         mask_dis = distances[..., 0] < 0.03
66 |         indices[~mask_dis] = -1
67 |         mapping[depth_valid] = face_index[indices.cpu().flatten()]
68 |
69 |         mapping_list.append(mapping.cpu().numpy())
70 |     return np.stack(mapping_list)
71 |
72 |
73 | def highlight_parts_in_multi_views(render_dir, mesh_path, results_dir, save_dir, img_num=1):
74 |
75 |     print(f"Processing {mesh_path}")
76 |     obj_mapping = cal_mapping_2d_3d(render_dir, mesh_path)
77 |     scale_list = ["0.0", "0.5", "1.0", "1.5", "2.0"]
78 |     for scale in scale_list:
79 |         ins_pred = np.load(join(results_dir, f"mesh_{scale}.npy"))
80 |         # Get the number of images and the number of classes
81 |         num_images = obj_mapping.shape[0]
82 |         num_classes = np.max(ins_pred) + 1
83 |         # Initialize an array to store the pixel count for each class in each image
84 |         pixel_count = np.zeros((num_images, num_classes), dtype=np.int32)
85 |         # Iterate over each image
86 |         for i in range(num_images):
87 |             # Get the group numbers for each pixel in the image
88 |             valid_areas = obj_mapping[i] != -1
89 |             groups = ins_pred[obj_mapping[i][valid_areas]]
90 |             # Count the number of pixels for each group
91 |             pixel_count[i], _ = np.histogram(groups, bins=np.arange(num_classes + 1) - 0.5)
92 |         # Find the top img_num images for each class
93 |         top_image_ids = np.argsort(-pixel_count, axis=0)[:img_num]
94 |
95 |
96 |         save_path = join(save_dir, scale)
97 |         os.makedirs(save_path, exist_ok=True)
98 |         for part_id in range(ins_pred.max()+1):
99 |             img_id_list = top_image_ids[:, part_id]
100 |             for topj, img_id in enumerate(img_id_list):
101 |                 image = np.array(Image.open(join(render_dir, f"render_{img_id:04d}.webp")))
102 |                 if image.shape[-1] == 4:
103 |                     mask_img = image[..., 3] == 0
104 |                     image[mask_img] = [255, 255, 255, 255]
105 |                     image = image[..., :3]
106 |                 image = Image.fromarray(image)
107 |                 valid_areas = obj_mapping[img_id] != -1
108 |                 mask = np.zeros_like(obj_mapping[img_id], dtype=bool)
109 |                 mask[valid_areas] = (ins_pred[obj_mapping[img_id][valid_areas]] == part_id)
110 |
111 |                 # Find the edges of the mask
112 |                 edges = ndi.binary_dilation(mask, iterations=1) ^ mask
113 |                 # Draw a red circle around the edges
114 |                 draw = ImageDraw.Draw(image)
115 |                 for y, x in np.argwhere(edges):
116 |                     draw.ellipse([x-2, y-2, x+2, y+2], fill='red')
117 |                 image.save(join(save_path, f"{part_id}-{topj}.png"))
118 |
119 |
120 | if __name__ == '__main__':
121 |     render_dir = ""
122 |     mesh_path = ""
123 |     results_dir = ""
124 |     save_dir = ""
125 |     highlight_parts_in_multi_views(render_dir, mesh_path, results_dir, save_dir)
--------------------------------------------------------------------------------
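A minimal invocation sketch for the tool above (illustrative only: every path below is a hypothetical placeholder, and the expected layout is inferred from the reads in this file — meta.json plus render_XXXX.webp/depth_XXXX.exr in render_dir, and per-scale mesh_{scale}.npy files in results_dir):

```python
# Hypothetical usage of tools/highlight_parts.py; all paths are placeholders.
from highlight_parts import highlight_parts_in_multi_views

highlight_parts_in_multi_views(
    render_dir="data/renders/<uid>",            # e.g. output of tools/blender_render_16views.py
    mesh_path="data/meshes/<uid>.glb",          # source mesh
    results_dir="exp/sampart3d/<uid>/results",  # contains mesh_0.0.npy ... mesh_2.0.npy
    save_dir="exp/sampart3d/<uid>/highlight",   # highlighted views are written here
    img_num=1,                                  # number of views saved per part
)
```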