├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── config
│   └── hais_run1_scannet.yaml
├── data
│   └── scannetv2_inst.py
├── dataset
│   └── scannetv2
│       ├── prepare_data_inst.py
│       ├── prepare_data_inst_gttxt.py
│       └── scannet_util.py
├── docs
│   ├── STPLS3D_leaderboard.png
│   ├── framework.png
│   ├── scannet_leaderboard.png
│   ├── scene0249_00_output_2.gif
│   └── scene0430_00_output_2.gif
├── lib
│   └── hais_ops
│       ├── functions
│       │   └── hais_ops.py
│       ├── setup.py
│       └── src
│           ├── bfs_cluster
│           │   ├── bfs_cluster.cpp
│           │   ├── bfs_cluster.cu
│           │   └── bfs_cluster.h
│           ├── cal_iou_and_masklabel
│           │   ├── cal_iou_and_masklabel.cpp
│           │   ├── cal_iou_and_masklabel.cu
│           │   └── cal_iou_and_masklabel.h
│           ├── cuda.cu
│           ├── cuda_utils.h
│           ├── datatype
│           │   ├── datatype.cpp
│           │   └── datatype.h
│           ├── get_iou
│           │   ├── get_iou.cpp
│           │   ├── get_iou.cu
│           │   └── get_iou.h
│           ├── hais_ops.cpp
│           ├── hais_ops.h
│           ├── hais_ops_api.cpp
│           ├── hierarchical_aggregation
│           │   ├── hierarchical_aggregation.cpp
│           │   ├── hierarchical_aggregation.cu
│           │   └── hierarchical_aggregation.h
│           ├── roipool
│           │   ├── roipool.cpp
│           │   ├── roipool.cu
│           │   └── roipool.h
│           ├── sec_mean
│           │   ├── sec_mean.cpp
│           │   ├── sec_mean.cu
│           │   └── sec_mean.h
│           └── voxelize
│               ├── voxelize.cpp
│               ├── voxelize.cu
│               └── voxelize.h
├── model
│   └── hais
│       └── hais.py
├── requirements.txt
├── test.py
├── train.py
├── util
│   ├── config.py
│   ├── eval.py
│   ├── log.py
│   ├── utils.py
│   └── utils_3d.py
└── visualize_open3d.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ## General
2 |
3 |
4 | # Compiled Object files
5 | *.slo
6 | *.lo
7 | *.o
8 | *.cuo
9 |
10 | # Compiled Dynamic libraries
11 | *.so
12 | *.dylib
13 |
14 | # Compiled Static libraries
15 | *.lai
16 | *.la
17 | *.a
18 |
19 | # Compiled protocol buffers
20 | *.pb.h
21 | *.pb.cc
22 | *_pb2.py
23 |
24 | # Compiled python
25 | *.pyc
26 |
27 | # Compiled MATLAB
28 | *.mex*
29 |
30 | # IPython notebook checkpoints
31 | .ipynb_checkpoints
32 |
33 | # Editor temporaries
34 | *.swp
35 | *~
36 |
37 | # Sublime Text settings
38 | *.sublime-workspace
39 | *.sublime-project
40 |
41 | # Eclipse Project settings
42 | *.*project
43 | .settings
44 |
45 | # QtCreator files
46 | *.user
47 |
48 | # PyCharm files
49 | .idea
50 |
51 | # Visual Studio Code files
52 | .vscode
53 |
54 | # OSX dir files
55 | .DS_Store
56 |
57 | # personal
58 | __pycache__/
59 | exp/
60 | *.egg-info/
61 | build/
62 | dist/
63 |
64 | *.tsv
65 | *.npy
66 | *.zip
67 | dataset/scannetv2/train
68 | dataset/scannetv2/val
69 | dataset/scannetv2/test
70 | dataset/scannetv2/val_gt
71 | dataset/scannetv2/scannetv2-labels.combined.tsv
72 |
73 |
74 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "lib/spconv"]
2 | path = lib/spconv
3 | url = https://github.com/llijiang/spconv
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Hust Visual Learning Team
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # HAIS
4 | **Hierarchical Aggregation for 3D Instance Segmentation [ICCV 2021]**
5 |
6 |
7 |
8 | by
9 | Shaoyu Chen, Jiemin Fang, Qian Zhang, Wenyu Liu, Xinggang Wang†. (†: corresponding author)
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 | [Papers with Code: 3D instance segmentation on ScanNetV2](https://paperswithcode.com/sota/3d-instance-segmentation-on-scannetv2?p=hierarchical-aggregation-for-3d-instance) [Papers with Code: 3D instance segmentation on S3DIS](https://paperswithcode.com/sota/3d-instance-segmentation-on-s3dis?p=hierarchical-aggregation-for-3d-instance)
24 |
25 |
26 |
27 | ## Update
28 | #### 2022.4.28:
29 | * HAIS serves as a baseline for the [STPLS3D](https://www.stpls3d.com/) dataset. Code of HAIS on [STPLS3D](https://www.stpls3d.com/) is available in [this GitHub repo].
30 | [STPLS3D](https://www.stpls3d.com/) is a large-scale photogrammetry 3D point cloud dataset, composed of high-quality, richly annotated point clouds from real-world and synthetic environments.
31 |
34 |
35 |
36 | ![STPLS3D_leaderboard](./docs/STPLS3D_leaderboard.png)
37 |
38 |
39 |
40 | #### 2021.9.30:
41 | * Code is released.
42 | * With better CUDA optimization, HAIS now takes only 339 ms on TITAN X, much lower than the latency reported in the paper (410 ms on TITAN X).
43 |
44 |
45 |
46 | ## Introduction
47 | * HAIS is an efficient and concise bottom-up framework (NMS-free and single-forward) for point cloud instance segmentation. It adopts hierarchical aggregation (point aggregation and set aggregation) to generate instances, and intra-instance prediction for outlier filtering and mask quality scoring.
48 |
49 | ![framework](./docs/framework.png)
50 |
51 | * **High performance**. HAIS [ranks 1st](http://kaldir.vc.in.tum.de/scannet_benchmark/semantic_instance_3d) on the [ScanNet benchmark](http://kaldir.vc.in.tum.de/scannet_benchmark/semantic_instance_3d) (Aug. 8th, 2021).
52 |
53 | ![scannet_leaderboard](./docs/scannet_leaderboard.png)
54 |
55 | * **High speed**. Thanks to the NMS-free and single-forward inference design, HAIS achieves the best inference speed among all existing methods. HAIS only takes **206 ms** on RTX 3090 and **339 ms** on TITAN X.
56 |
57 | | Method | Per-frame latency on TITAN X|
58 | | :-: | :-: |
59 | |ASIS|181913 ms|
60 | |SGPN|158439 ms|
61 | |3D-SIS|124490 ms|
62 | |GSPN|12702 ms|
63 | |3D-BoNet|9202 ms|
64 | |GICN|8615 ms|
65 | |OccuSeg|1904 ms|
66 | |PointGroup|452 ms|
67 | |**HAIS**|**339 ms**|
68 |
69 |
70 |
71 | ## Installation
72 |
73 | 1\) Environment
74 |
75 | * Python 3.x
76 | * PyTorch 1.1 or higher
77 | * CUDA 9.2 or higher
78 | * gcc-5.4 or higher
79 |
80 | Create a conda virtual environment and activate it.
81 | ```
82 | conda create -n hais python=3.7
83 | conda activate hais
84 | ```
85 |
86 |
87 | 2\) Clone the repository.
88 | ```
89 | git clone https://github.com/hustvl/HAIS.git --recursive
90 | ```
91 |
92 |
93 | 3\) Install the requirements.
94 | ```
95 | cd HAIS
96 | pip install -r requirements.txt
97 | conda install -c bioconda google-sparsehash
98 | ```
99 |
100 | 4\) Install spconv
101 |
102 | * Verify the version of spconv.
103 |
104 | spconv 1.0, compatible with CUDA < 11 and PyTorch < 1.5, is already recursively cloned into `HAIS/lib/spconv` in step 2) by default.
105 |
106 | For newer versions of CUDA and PyTorch, spconv 1.2 is suggested. Replace `HAIS/lib/spconv` with the following fork of spconv.
107 |
108 | ```
109 | git clone https://github.com/outsidercsy/spconv.git --recursive
110 | ```
111 |
112 | Note: in the provided spconv 1.0 and 1.2, `spconv/spconv/functional.py` is modified to make `grad_output` contiguous. Make sure you use the modified spconv rather than the original one; otherwise training will not optimize correctly.
113 |
114 |
115 | * Install the dependent libraries.
116 | ```
117 | conda install libboost
118 | conda install -c daleydeng gcc-5 # (optional, install gcc-5.4 in conda env)
119 | ```
120 |
121 | * Compile the spconv library.
122 | ```
123 | cd HAIS/lib/spconv
124 | python setup.py bdist_wheel
125 | ```
126 |
127 | * Install the generated .whl file.
128 | ```
129 | cd HAIS/lib/spconv/dist
130 | pip install {wheel_file_name}.whl
131 | ```
132 |
133 |
134 | 5\) Compile the external C++ and CUDA ops.
135 | ```
136 | cd HAIS/lib/hais_ops
137 | export CPLUS_INCLUDE_PATH={conda_env_path}/hais/include:$CPLUS_INCLUDE_PATH
138 | python setup.py build_ext develop
139 | ```
140 | `{conda_env_path}` is the directory containing the created conda environments, e.g., `/anaconda3/envs`.
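
If the compilation succeeds, the ops are importable from the `hais` environment. A minimal sanity check (a sketch, assuming it is run from the HAIS repo root with the environment activated):
```
import torch                                  # load torch's shared libraries first
import HAIS_OP                                # extension name defined in lib/hais_ops/setup.py
from lib.hais_ops.functions import hais_ops   # Python wrappers used throughout HAIS
print(HAIS_OP.__name__, hais_ops.voxelization_idx)
```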
141 |
142 |
143 |
144 | ## Data Preparation
145 |
146 | 1\) Download the [ScanNet](http://www.scan-net.org/) v2 dataset.
147 |
148 | 2\) Put the data in the corresponding folders.
149 | * Copy the files `[scene_id]_vh_clean_2.ply`, `[scene_id]_vh_clean_2.labels.ply`, `[scene_id]_vh_clean_2.0.010000.segs.json` and `[scene_id].aggregation.json` into the `dataset/scannetv2/train` and `dataset/scannetv2/val` folders according to the ScanNet v2 train/val [split](https://github.com/ScanNet/ScanNet/tree/master/Tasks/Benchmark).
150 |
151 | * Copy the files `[scene_id]_vh_clean_2.ply` into the `dataset/scannetv2/test` folder according to the ScanNet v2 test [split](https://github.com/ScanNet/ScanNet/tree/master/Tasks/Benchmark).
152 |
153 | * Put the file `scannetv2-labels.combined.tsv` in the `dataset/scannetv2` folder.
154 |
155 | The dataset files are organized as follows.
156 | ```
157 | HAIS
158 | ├── dataset
159 | │ ├── scannetv2
160 | │ │ ├── train
161 | │ │ │ ├── [scene_id]_vh_clean_2.ply & [scene_id]_vh_clean_2.labels.ply & [scene_id]_vh_clean_2.0.010000.segs.json & [scene_id].aggregation.json
162 | │ │ ├── val
163 | │ │ │ ├── [scene_id]_vh_clean_2.ply & [scene_id]_vh_clean_2.labels.ply & [scene_id]_vh_clean_2.0.010000.segs.json & [scene_id].aggregation.json
164 | │ │ ├── test
165 | │ │ │ ├── [scene_id]_vh_clean_2.ply
166 | │ │ ├── scannetv2-labels.combined.tsv
167 | ```
168 |
169 | 3\) Generate input files `[scene_id]_inst_nostuff.pth` for instance segmentation.
170 | ```
171 | cd HAIS/dataset/scannetv2
172 | python prepare_data_inst.py --data_split train
173 | python prepare_data_inst.py --data_split val
174 | python prepare_data_inst.py --data_split test
175 | ```
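
Each generated `.pth` file stores a plain tuple saved with `torch.save` (see `prepare_data_inst.py`). A quick way to inspect one (a sketch; the scene id below is only an example):
```
import torch

# train/val files hold (coords, colors, sem_labels, instance_labels); test files hold (coords, colors)
coords, colors, sem_labels, instance_labels = torch.load('train/scene0000_00_inst_nostuff.pth')
print(coords.shape, colors.shape)               # (N, 3) mean-centered xyz, (N, 3) colors scaled to [-1, 1]
print(sem_labels.max(), instance_labels.max())  # semantic ids 0-19 (-100 = ignored), instance ids 0..nInst-1
```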
176 |
177 | ## Training
178 | ```
179 | CUDA_VISIBLE_DEVICES=0 python train.py --config config/hais_run1_scannet.yaml
180 | ```
181 |
182 |
183 | ## Inference
184 |
185 | 1\) To evaluate on validation set,
186 |
187 | * prepare the `.txt` instance ground-truth files as follows (their per-point encoding is shown in the sketch after this list).
188 | ```
189 | cd dataset/scannetv2
190 | python prepare_data_inst_gttxt.py
191 | ```
192 |
193 | * set `split` and `eval` in the config file as `val` and `True`.
194 |
195 | * Run the inference and evaluation code.
196 | ```
197 | CUDA_VISIBLE_DEVICES=0 python test.py --config config/hais_run1_scannet.yaml --pretrain $PATH_TO_PRETRAIN_MODEL$
198 | ```
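
The ground-truth files written by `prepare_data_inst_gttxt.py` store one integer per point, encoded as `semantic_label_id * 1000 + instance_id + 1` (0 for unannotated points). A small decoding sketch (the scene name is only an example):
```
import numpy as np

gt = np.loadtxt('dataset/scannetv2/val_gt/scene0011_00.txt', dtype=np.int64)
sem_id = gt // 1000     # ScanNet semantic label id (0 = unannotated)
inst_id = gt % 1000     # per-scene instance id, starting from 1
print(np.unique(sem_id), inst_id.max())
```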
199 |
200 |
201 | Pretrained model: [Google Drive](https://drive.google.com/file/d/1XGNswNrbjm33SwpemYxVEoK4o46EOazd/view?usp=sharing) / [Baidu Cloud](https://pan.baidu.com/s/12dx-39jBOyU9QzGlpgJ8OQ) [code: sh4t].
202 | mAP/mAP50/mAP25 is 44.1/64.4/75.7.
203 |
204 |
205 |
206 | 2\) To evaluate on test set,
207 |
208 | * Set (`split`, `eval`, `save_instance`) as (`test`, `False`, `True`).
209 | * Run the inference code. Prediction results are saved in `HAIS/exp` by default.
210 | ```
211 | CUDA_VISIBLE_DEVICES=0 python test.py --config config/hais_run1_scannet.yaml --pretrain $PATH_TO_PRETRAIN_MODEL$
212 | ```
213 |
214 | * Transform the prediction results into the [submission format](http://kaldir.vc.in.tum.de/scannet_benchmark/documentation).
215 | * Submit the results to the [official evaluation server](http://kaldir.vc.in.tum.de/scannet_benchmark/submissions).
216 |
217 |
218 |
219 |
220 | ## Visualization
221 | We provide visualization tools based on Open3D (tested on Open3D 0.8.0).
222 | ```
223 | pip install open3d==0.8.0
224 | python visualize_open3d.py --data_path {} --prediction_path {} --data_split {} --room_name {} --task {}
225 | ```
226 | Please refer to `visualize_open3d.py` for more details.
227 |
228 | Demo:
229 |
230 |
231 | ![scene0249_00](./docs/scene0249_00_output_2.gif) ![scene0430_00](./docs/scene0430_00_output_2.gif)
232 |
233 |
234 |
235 |
236 | ## Acknowledgement
237 | The code is based on [PointGroup](https://github.com/dvlab-research/PointGroup) and [spconv](https://github.com/traveller59/spconv). Thanks to [STPLS3D](https://www.stpls3d.com/) for extending HAIS.
238 |
239 |
240 | ## Contact
241 | If you have any questions or suggestions about this repo, please feel free to contact me (shaoyuchen@hust.edu.cn).
242 |
243 |
244 | ## Citing HAIS
245 | If you find HAIS useful in your research or applications, please consider giving us a star 🌟 and citing HAIS with the following BibTeX entry.
246 |
247 | ```BibTeX
248 | @InProceedings{Chen_HAIS_2021_ICCV,
249 | author = {Chen, Shaoyu and Fang, Jiemin and Zhang, Qian and Liu, Wenyu and Wang, Xinggang},
250 | title = {Hierarchical Aggregation for 3D Instance Segmentation},
251 | booktitle = {ICCV},
252 | year = {2021},
253 | }
254 | ```
255 |
--------------------------------------------------------------------------------
/config/hais_run1_scannet.yaml:
--------------------------------------------------------------------------------
1 | GENERAL:
2 | task: train # train, test
3 | manual_seed: 123
4 | model_dir: model/hais/hais.py
5 | dataset_dir: data/scannetv2_inst.py
6 |
7 | DATA:
8 | data_root: dataset
9 | dataset: scannetv2
10 | filename_suffix: _inst_nostuff.pth
11 |
12 | classes: 20
13 | ignore_label: -100
14 |
15 | input_channel: 3
16 | scale: 50 # voxel_size = 1 / scale, scale 50 -> voxel_size 0.02m
17 | batch_size: 4
18 | full_scale: [128, 512]
19 | max_npoint: 250000
20 | mode: 4 # 4=mean
21 |
22 | STRUCTURE:
23 | model_name: hais
24 | width: 32
25 | block_residual: True
26 | block_reps: 2
27 | use_coords: True
28 |
29 | TRAIN:
30 | epochs: 500
31 | train_workers: 8 # data loader workers
32 | optim: Adam # Adam or SGD
33 | lr: 0.001
34 | step_epoch: 200
35 | multiplier: 0.5
36 | momentum: 0.9
37 | weight_decay: 0.0001
38 | save_freq: 16 # also eval_freq
39 | loss_weight: [1.0, 1.0, 1.0, 1.0] # semantic_loss, offset_norm_loss, score_loss, mask_loss
40 | fg_thresh: 1.
41 | bg_thresh: 0.
42 | score_scale: 50 # the minimal voxel size is 2cm
43 | score_fullscale: 20
44 | score_mode: 4 # mean
45 | pretrain_path:
46 | pretrain_module: []
47 | fix_module: []
48 |
49 |
50 | point_aggr_radius: 0.03
51 | cluster_shift_meanActive: 300
52 | prepare_epochs: 100
53 |
54 | cal_iou_based_on_mask: True
55 | cal_iou_based_on_mask_start_epoch: 200
56 |
57 | use_mask_filter_score_feature: True
58 | use_mask_filter_score_feature_start_epoch: 200
59 | mask_filter_score_feature_thre: 0.5
60 |
61 | using_set_aggr_in_training: False
62 | using_set_aggr_in_testing: True
63 |
64 | max_proposal_num: 200
65 |
66 | TEST:
67 | split: val
68 | test_epoch: 500
69 | test_workers: 16
70 | test_seed: 567
71 |
72 | using_NMS: False
73 | TEST_NMS_THRESH: 0.3
74 | TEST_SCORE_THRESH: 0.09
75 | TEST_NPOINT_THRESH: 100
76 |
77 | eval: True
78 | save_semantic: False
79 | save_pt_offsets: False
80 | save_instance: False
81 |
82 | test_mask_score_thre: -0.5 # bias fg << bg
83 |
84 |
85 |
86 |
87 |
88 |
89 |
--------------------------------------------------------------------------------
/data/scannetv2_inst.py:
--------------------------------------------------------------------------------
1 | import os, sys, glob, math, numpy as np
2 | import scipy.ndimage
3 | import scipy.interpolate
4 | import torch
5 | from torch.utils.data import DataLoader
6 |
7 | sys.path.append('../')
8 |
9 | from util.config import cfg
10 | from util.log import logger
11 | from lib.hais_ops.functions import hais_ops
12 |
13 | import torch.distributed as dist
14 |
15 | class Dataset:
16 | def __init__(self, test=False):
17 | self.data_root = cfg.data_root
18 | self.dataset = cfg.dataset
19 | self.filename_suffix = cfg.filename_suffix
20 |
21 | self.batch_size = cfg.batch_size
22 | self.train_workers = cfg.train_workers
23 | self.val_workers = cfg.train_workers
24 |
25 | self.full_scale = cfg.full_scale
26 | self.scale = cfg.scale
27 | self.max_npoint = cfg.max_npoint
28 | self.mode = cfg.mode
29 |
30 | self.train_split = getattr(cfg, 'train_split', 'train')
31 |
32 | if test:
33 | self.test_split = cfg.split # val or test
34 | self.test_workers = cfg.test_workers
35 | cfg.batch_size = 1
36 |
37 |
38 | def trainLoader(self):
39 | if self.train_split == 'trainval':
40 | train_file_names = sorted(glob.glob(os.path.join(self.data_root, self.dataset, 'train', '*' + self.filename_suffix))
41 | + glob.glob(os.path.join(self.data_root, self.dataset, 'val', '*' + self.filename_suffix))
42 | )
43 | elif self.train_split == 'train':
44 | train_file_names = sorted(glob.glob(os.path.join(self.data_root, self.dataset, 'train', '*' + self.filename_suffix)))
45 | else:
46 | raise Exception
47 |
48 | self.train_files = [torch.load(i) for i in train_file_names]
49 |
50 | logger.info('Training samples: {}'.format(len(self.train_files)))
51 |
52 | train_set = list(range(len(self.train_files)))
53 | self.train_data_loader = DataLoader(train_set, batch_size=self.batch_size, collate_fn=self.trainMerge, num_workers=self.train_workers,
54 | shuffle=True, sampler=None, drop_last=True, pin_memory=True)
55 |
56 |
57 | def dist_trainLoader(self):
58 | train_file_names = sorted(glob.glob(os.path.join(self.data_root, self.dataset, 'train', '*' + self.filename_suffix)))
59 | self.train_files = [torch.load(i) for i in train_file_names]
60 |
61 | logger.info('Training samples: {}'.format(len(self.train_files)))
62 |
63 | train_set = list(range(len(self.train_files)))
64 | # self.train_data_loader = DataLoader(train_set, batch_size=self.batch_size, collate_fn=self.trainMerge, num_workers=self.train_workers,
65 | # shuffle=True, sampler=None, drop_last=True, pin_memory=True)
66 |
67 | # world_size = dist.get_world_size()
68 | # rank = dist.get_rank()
69 | # self.data_sampler = torch.utils.data.distributed.DistributedSampler(train_set, num_replicas=world_size, rank=rank)
70 | self.data_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
71 |
72 | self.train_data_loader = DataLoader(train_set, batch_size=self.batch_size,
73 | collate_fn=self.trainMerge,
74 | num_workers=self.train_workers,
75 | shuffle=False, sampler=self.data_sampler,
76 | drop_last=False, pin_memory=True)
77 |
78 |
79 |
80 | def valLoader(self):
81 | val_file_names = sorted(glob.glob(os.path.join(self.data_root, self.dataset, 'val', '*' + self.filename_suffix)))
82 | self.val_files = [torch.load(i) for i in val_file_names]
83 |
84 | logger.info('Validation samples: {}'.format(len(self.val_files)))
85 |
86 | val_set = list(range(len(self.val_files)))
87 | self.val_data_loader = DataLoader(val_set, batch_size=self.batch_size, collate_fn=self.valMerge, num_workers=self.val_workers,
88 | shuffle=False, drop_last=False, pin_memory=True)
89 |
90 |
91 | def testLoader(self):
92 | self.test_file_names = sorted(glob.glob(os.path.join(self.data_root, self.dataset, self.test_split, '*' + self.filename_suffix)))
93 | self.test_files = [torch.load(i) for i in self.test_file_names]
94 |
95 | logger.info('Testing samples ({}): {}'.format(self.test_split, len(self.test_files)))
96 |
97 | test_set = list(np.arange(len(self.test_files)))
98 | self.test_data_loader = DataLoader(test_set, batch_size=1, collate_fn=self.testMerge, num_workers=self.test_workers,
99 | shuffle=False, drop_last=False, pin_memory=True)
100 |
101 | # Elastic distortion
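# Builds a coarse random noise grid at granularity `gran`, smooths it with repeated
# box-filter convolutions, interpolates it at every point, and displaces the points
# by `mag` times the interpolated noise.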
102 | def elastic(self, x, gran, mag):
103 | blur0 = np.ones((3, 1, 1)).astype('float32') / 3
104 | blur1 = np.ones((1, 3, 1)).astype('float32') / 3
105 | blur2 = np.ones((1, 1, 3)).astype('float32') / 3
106 |
107 | bb = np.abs(x).max(0).astype(np.int32)//gran + 3
108 | noise = [np.random.randn(bb[0], bb[1], bb[2]).astype('float32') for _ in range(3)]
109 | noise = [scipy.ndimage.filters.convolve(n, blur0, mode='constant', cval=0) for n in noise]
110 | noise = [scipy.ndimage.filters.convolve(n, blur1, mode='constant', cval=0) for n in noise]
111 | noise = [scipy.ndimage.filters.convolve(n, blur2, mode='constant', cval=0) for n in noise]
112 | noise = [scipy.ndimage.filters.convolve(n, blur0, mode='constant', cval=0) for n in noise]
113 | noise = [scipy.ndimage.filters.convolve(n, blur1, mode='constant', cval=0) for n in noise]
114 | noise = [scipy.ndimage.filters.convolve(n, blur2, mode='constant', cval=0) for n in noise]
115 | ax = [np.linspace(-(b-1)*gran, (b-1)*gran, b) for b in bb]
116 | interp = [scipy.interpolate.RegularGridInterpolator(ax, n, bounds_error=0, fill_value=0) for n in noise]
117 | def g(x_):
118 | return np.hstack([i(x_)[:,None] for i in interp])
119 | return x + g(x) * mag
120 |
121 |
122 | def getInstanceInfo(self, xyz, instance_label):
123 | '''
124 | :param xyz: (n, 3)
125 | :param instance_label: (n), int, (0~nInst-1, -100)
126 | :return: instance_num, dict
127 | '''
128 | instance_info = np.ones((xyz.shape[0], 9), dtype=np.float32) * -100.0 # (n, 9), float, (cx, cy, cz, minx, miny, minz, maxx, maxy, maxz)
129 | instance_pointnum = [] # (nInst), int
130 | instance_num = int(instance_label.max()) + 1
131 | for i_ in range(instance_num):
132 | inst_idx_i = np.where(instance_label == i_)
133 |
134 | # instance_info
135 | xyz_i = xyz[inst_idx_i]
136 | min_xyz_i = xyz_i.min(0)
137 | max_xyz_i = xyz_i.max(0)
138 | mean_xyz_i = xyz_i.mean(0)
139 | instance_info_i = instance_info[inst_idx_i]
140 | instance_info_i[:, 0:3] = mean_xyz_i
141 | instance_info_i[:, 3:6] = min_xyz_i
142 | instance_info_i[:, 6:9] = max_xyz_i
143 | instance_info[inst_idx_i] = instance_info_i
144 |
145 | # instance_pointnum
146 | instance_pointnum.append(inst_idx_i[0].size)
147 |
148 | return instance_num, {"instance_info": instance_info, "instance_pointnum": instance_pointnum}
149 |
150 |
151 | def dataAugment(self, xyz, jitter=False, flip=False, rot=False):
152 | m = np.eye(3)
153 | if jitter:
154 | m += np.random.randn(3, 3) * 0.1
155 | if flip:
156 | m[0][0] *= np.random.randint(0, 2) * 2 - 1 # flip x randomly
157 | if rot:
158 | theta = np.random.rand() * 2 * math.pi
159 | m = np.matmul(m, [[math.cos(theta), math.sin(theta), 0], [-math.sin(theta), math.cos(theta), 0], [0, 0, 1]]) # rotation
160 | return np.matmul(xyz, m)
161 |
162 |
163 | def crop(self, xyz):
164 | '''
165 | :param xyz: (n, 3) >= 0
166 | '''
167 | xyz_offset = xyz.copy()
168 | valid_idxs = (xyz_offset.min(1) >= 0)
169 | assert valid_idxs.sum() == xyz.shape[0]
170 |
171 | full_scale = np.array([self.full_scale[1]] * 3)
172 | room_range = xyz.max(0) - xyz.min(0)
173 | while (valid_idxs.sum() > self.max_npoint):
174 | offset = np.clip(full_scale - room_range + 0.001, None, 0) * np.random.rand(3)
175 | xyz_offset = xyz + offset
176 | valid_idxs = (xyz_offset.min(1) >= 0) * ((xyz_offset < full_scale).sum(1) == 3)
177 | full_scale[:2] -= 32
178 |
179 | return xyz_offset, valid_idxs
180 |
181 |
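# Compact instance ids after cropping: ids whose points were entirely cropped away are
# reused by relabelling the current maximum id into the gap, keeping ids contiguous.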
182 | def getCroppedInstLabel(self, instance_label, valid_idxs):
183 | instance_label = instance_label[valid_idxs]
184 | j = 0
185 | while (j < instance_label.max()):
186 | if (len(np.where(instance_label == j)[0]) == 0):
187 | instance_label[instance_label == instance_label.max()] = j
188 | j += 1
189 | return instance_label
190 |
191 |
192 | def trainMerge(self, id):
193 | locs = []
194 | locs_float = []
195 | feats = []
196 | labels = []
197 | instance_labels = []
198 |
199 | instance_infos = [] # (N, 9)
200 | instance_pointnum = [] # (total_nInst), int
201 |
202 | batch_offsets = [0]
203 |
204 | total_inst_num = 0
205 | for i, idx in enumerate(id):
206 | xyz_origin, rgb, label, instance_label = self.train_files[idx]
207 |
208 |
209 | # jitter / flip x / rotation
210 | xyz_middle = self.dataAugment(xyz_origin, True, True, True)
211 |
212 | # scale
213 | xyz = xyz_middle * self.scale
214 |
215 | # elastic
216 | xyz = self.elastic(xyz, 6 * self.scale // 50, 40 * self.scale / 50)
217 | xyz = self.elastic(xyz, 20 * self.scale // 50, 160 * self.scale / 50)
218 |
219 | # offset
220 | xyz -= xyz.min(0)
221 |
222 | # crop
223 | xyz, valid_idxs = self.crop(xyz)
224 |
225 | xyz_middle = xyz_middle[valid_idxs]
226 | xyz = xyz[valid_idxs]
227 | rgb = rgb[valid_idxs]
228 | label = label[valid_idxs]
229 | instance_label = self.getCroppedInstLabel(instance_label, valid_idxs)
230 |
231 | # get instance information
232 | inst_num, inst_infos = self.getInstanceInfo(xyz_middle, instance_label.astype(np.int32))
233 | inst_info = inst_infos["instance_info"] # (n, 9), (cx, cy, cz, minx, miny, minz, maxx, maxy, maxz)
234 | inst_pointnum = inst_infos["instance_pointnum"] # (nInst), list
235 |
236 | instance_label[np.where(instance_label != -100)] += total_inst_num
237 | total_inst_num += inst_num
238 |
239 | # merge the scene to the batch
240 | batch_offsets.append(batch_offsets[-1] + xyz.shape[0])
241 |
242 | locs.append(torch.cat([torch.LongTensor(xyz.shape[0], 1).fill_(i), torch.from_numpy(xyz).long()], 1))
243 | locs_float.append(torch.from_numpy(xyz_middle))
244 | feats.append(torch.from_numpy(rgb) + torch.randn(3) * 0.1)
245 | labels.append(torch.from_numpy(label))
246 | instance_labels.append(torch.from_numpy(instance_label))
247 |
248 | instance_infos.append(torch.from_numpy(inst_info))
249 | instance_pointnum.extend(inst_pointnum)
250 |
251 | # merge all the scenes in the batch
252 | batch_offsets = torch.tensor(batch_offsets, dtype=torch.int) # int (B+1)
253 |
254 | locs = torch.cat(locs, 0) # long (N, 1 + 3), the batch item idx is put in locs[:, 0]
255 | locs_float = torch.cat(locs_float, 0).to(torch.float32) # float (N, 3)
256 | feats = torch.cat(feats, 0) # float (N, C)
257 | labels = torch.cat(labels, 0).long() # long (N)
258 | instance_labels = torch.cat(instance_labels, 0).long() # long (N)
259 |
260 | instance_infos = torch.cat(instance_infos, 0).to(torch.float32) # float (N, 9) (meanxyz, minxyz, maxxyz)
261 | instance_pointnum = torch.tensor(instance_pointnum, dtype=torch.int) # int (total_nInst)
262 |
263 | spatial_shape = np.clip((locs.max(0)[0][1:] + 1).numpy(), self.full_scale[0], None) # long (3)
264 |
265 | # voxelize
266 | voxel_locs, p2v_map, v2p_map = hais_ops.voxelization_idx(locs, self.batch_size, self.mode)
267 |
268 | return {'locs': locs, 'voxel_locs': voxel_locs, 'p2v_map': p2v_map, 'v2p_map': v2p_map,
269 | 'locs_float': locs_float, 'feats': feats, 'labels': labels, 'instance_labels': instance_labels,
270 | 'instance_info': instance_infos, 'instance_pointnum': instance_pointnum,
271 | 'id': id, 'offsets': batch_offsets, 'spatial_shape': spatial_shape}
272 |
273 |
274 | def valMerge(self, id):
275 | locs = []
276 | locs_float = []
277 | feats = []
278 | labels = []
279 | instance_labels = []
280 |
281 | instance_infos = [] # (N, 9)
282 | instance_pointnum = [] # (total_nInst), int
283 |
284 | batch_offsets = [0]
285 |
286 | total_inst_num = 0
287 | for i, idx in enumerate(id):
288 | xyz_origin, rgb, label, instance_label = self.val_files[idx]
289 |
290 | # flip x / rotation
291 | xyz_middle = self.dataAugment(xyz_origin, False, True, True)
292 |
293 | # scale
294 | xyz = xyz_middle * self.scale
295 |
296 | # offset
297 | xyz -= xyz.min(0)
298 |
299 | # crop
300 | xyz, valid_idxs = self.crop(xyz)
301 |
302 | xyz_middle = xyz_middle[valid_idxs]
303 | xyz = xyz[valid_idxs]
304 | rgb = rgb[valid_idxs]
305 | label = label[valid_idxs]
306 | instance_label = self.getCroppedInstLabel(instance_label, valid_idxs)
307 |
308 | # get instance information
309 | inst_num, inst_infos = self.getInstanceInfo(xyz_middle, instance_label.astype(np.int32))
310 | inst_info = inst_infos["instance_info"] # (n, 9), (cx, cy, cz, minx, miny, minz, maxx, maxy, maxz)
311 | inst_pointnum = inst_infos["instance_pointnum"] # (nInst), list
312 |
313 | instance_label[np.where(instance_label != -100)] += total_inst_num
314 | total_inst_num += inst_num
315 |
316 | # merge the scene to the batch
317 | batch_offsets.append(batch_offsets[-1] + xyz.shape[0])
318 |
319 | locs.append(torch.cat([torch.LongTensor(xyz.shape[0], 1).fill_(i), torch.from_numpy(xyz).long()], 1))
320 | locs_float.append(torch.from_numpy(xyz_middle))
321 | feats.append(torch.from_numpy(rgb))
322 | labels.append(torch.from_numpy(label))
323 | instance_labels.append(torch.from_numpy(instance_label))
324 |
325 | instance_infos.append(torch.from_numpy(inst_info))
326 | instance_pointnum.extend(inst_pointnum)
327 |
328 | # merge all the scenes in the batch
329 | batch_offsets = torch.tensor(batch_offsets, dtype=torch.int) # int (B+1)
330 |
331 | locs = torch.cat(locs, 0) # long (N, 1 + 3), the batch item idx is put in locs[:, 0]
332 | locs_float = torch.cat(locs_float, 0).to(torch.float32) # float (N, 3)
333 | feats = torch.cat(feats, 0) # float (N, C)
334 | labels = torch.cat(labels, 0).long() # long (N)
335 | instance_labels = torch.cat(instance_labels, 0).long() # long (N)
336 |
337 | instance_infos = torch.cat(instance_infos, 0).to(torch.float32) # float (N, 9) (meanxyz, minxyz, maxxyz)
338 | instance_pointnum = torch.tensor(instance_pointnum, dtype=torch.int) # int (total_nInst)
339 |
340 | spatial_shape = np.clip((locs.max(0)[0][1:] + 1).numpy(), self.full_scale[0], None) # long (3)
341 |
342 | # voxelize
343 | voxel_locs, p2v_map, v2p_map = hais_ops.voxelization_idx(locs, self.batch_size, self.mode)
344 |
345 | return {'locs': locs, 'voxel_locs': voxel_locs, 'p2v_map': p2v_map, 'v2p_map': v2p_map,
346 | 'locs_float': locs_float, 'feats': feats, 'labels': labels, 'instance_labels': instance_labels,
347 | 'instance_info': instance_infos, 'instance_pointnum': instance_pointnum,
348 | 'id': id, 'offsets': batch_offsets, 'spatial_shape': spatial_shape}
349 |
350 |
351 | def testMerge(self, id):
352 | locs = []
353 | locs_float = []
354 | feats = []
355 |
356 | labels = []  # only filled for the 'val' split
357 |
358 | batch_offsets = [0]
359 | for i, idx in enumerate(id):
360 |
361 | if self.test_split == 'val':
362 | xyz_origin, rgb, label, instance_label = self.test_files[idx]
363 | elif self.test_split == 'test':
364 | xyz_origin, rgb = self.test_files[idx]
365 | else:
366 | print("Wrong test split: {}!".format(self.test_split))
367 | exit(0)
368 |
369 | # flip x / rotation
370 | xyz_middle = self.dataAugment(xyz_origin, False, True, True)
371 |
372 | # scale
373 | xyz = xyz_middle * self.scale
374 |
375 | # offset
376 | xyz -= xyz.min(0)
377 |
378 | # merge the scene to the batch
379 | batch_offsets.append(batch_offsets[-1] + xyz.shape[0])
380 |
381 | locs.append(torch.cat([torch.LongTensor(xyz.shape[0], 1).fill_(i), torch.from_numpy(xyz).long()], 1))
382 | locs_float.append(torch.from_numpy(xyz_middle))
383 | feats.append(torch.from_numpy(rgb))
384 |
385 | if self.test_split == 'val':
386 | labels.append(torch.from_numpy(label))
387 |
388 | if self.test_split == 'val':
389 | labels = torch.cat(labels, 0).long() # long (N)
390 |
391 | # merge all the scenes in the batch
392 | batch_offsets = torch.tensor(batch_offsets, dtype=torch.int) # int (B+1)
393 |
394 | locs = torch.cat(locs, 0) # long (N, 1 + 3), the batch item idx is put in locs[:, 0]
395 | locs_float = torch.cat(locs_float, 0).to(torch.float32) # float (N, 3)
396 | feats = torch.cat(feats, 0) # float (N, C)
397 |
398 | spatial_shape = np.clip((locs.max(0)[0][1:] + 1).numpy(), self.full_scale[0], None) # long (3)
399 |
400 | # voxelize
401 | voxel_locs, p2v_map, v2p_map = hais_ops.voxelization_idx(locs, self.batch_size, self.mode)
402 |
403 | if self.test_split == 'val':
404 | return {'locs': locs, 'voxel_locs': voxel_locs, 'p2v_map': p2v_map, 'v2p_map': v2p_map,
405 | 'locs_float': locs_float, 'feats': feats,
406 | 'id': id, 'offsets': batch_offsets, 'spatial_shape': spatial_shape,
407 | 'labels': labels}
408 |
409 | elif self.test_split == 'test':
410 | return {'locs': locs, 'voxel_locs': voxel_locs, 'p2v_map': p2v_map, 'v2p_map': v2p_map,
411 | 'locs_float': locs_float, 'feats': feats,
412 | 'id': id, 'offsets': batch_offsets, 'spatial_shape': spatial_shape}
413 | else:
414 | raise Exception('Wrong test split: {}!'.format(self.test_split))
415 |
416 |
--------------------------------------------------------------------------------
/dataset/scannetv2/prepare_data_inst.py:
--------------------------------------------------------------------------------
1 | '''
2 | Modified from SparseConvNet data preparation: https://github.com/facebookresearch/SparseConvNet/blob/master/examples/ScanNet/prepare_data.py
3 | '''
4 |
5 | import glob, plyfile, numpy as np, multiprocessing as mp, torch, json, argparse
6 |
7 | import scannet_util
8 |
9 | # Map relevant classes to {0,1,...,19}, and ignored classes to -100
10 | remapper = np.ones(150) * (-100)
11 | for i, x in enumerate([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39]):
12 | remapper[x] = i
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--data_split', help='data split (train / val / test)', default='train')
16 | opt = parser.parse_args()
17 |
18 | split = opt.data_split
19 | print('data split: {}'.format(split))
20 | files = sorted(glob.glob(split + '/*_vh_clean_2.ply'))
21 | if opt.data_split != 'test':
22 | files2 = sorted(glob.glob(split + '/*_vh_clean_2.labels.ply'))
23 | files3 = sorted(glob.glob(split + '/*_vh_clean_2.0.010000.segs.json'))
24 | files4 = sorted(glob.glob(split + '/*[0-9].aggregation.json'))
25 | assert len(files) == len(files2)
26 | assert len(files) == len(files3)
27 | assert len(files) == len(files4), "{} {}".format(len(files), len(files4))
28 |
29 | def f_test(fn):
30 | print(fn)
31 |
32 | f = plyfile.PlyData().read(fn)
33 | points = np.array([list(x) for x in f.elements[0]])
34 | coords = np.ascontiguousarray(points[:, :3] - points[:, :3].mean(0))
35 | colors = np.ascontiguousarray(points[:, 3:6]) / 127.5 - 1
36 |
37 | torch.save((coords, colors), fn[:-15] + '_inst_nostuff.pth')
38 | print('Saving to ' + fn[:-15] + '_inst_nostuff.pth')
39 |
40 |
41 | def f(fn):
42 | fn2 = fn[:-3] + 'labels.ply'
43 | fn3 = fn[:-15] + '_vh_clean_2.0.010000.segs.json'
44 | fn4 = fn[:-15] + '.aggregation.json'
45 | print(fn)
46 |
47 | f = plyfile.PlyData().read(fn)
48 | points = np.array([list(x) for x in f.elements[0]])
49 | coords = np.ascontiguousarray(points[:, :3] - points[:, :3].mean(0))
50 | colors = np.ascontiguousarray(points[:, 3:6]) / 127.5 - 1
51 |
52 | f2 = plyfile.PlyData().read(fn2)
53 | sem_labels = remapper[np.array(f2.elements[0]['label'])]
54 |
55 | with open(fn3) as jsondata:
56 | d = json.load(jsondata)
57 | seg = d['segIndices']
58 | segid_to_pointid = {}
59 | for i in range(len(seg)):
60 | if seg[i] not in segid_to_pointid:
61 | segid_to_pointid[seg[i]] = []
62 | segid_to_pointid[seg[i]].append(i)
63 |
64 | instance_segids = []
65 | labels = []
66 | with open(fn4) as jsondata:
67 | d = json.load(jsondata)
68 | for x in d['segGroups']:
69 | if scannet_util.g_raw2scannetv2[x['label']] != 'wall' and scannet_util.g_raw2scannetv2[x['label']] != 'floor':
70 | instance_segids.append(x['segments'])
71 | labels.append(x['label'])
72 | assert(x['label'] in scannet_util.g_raw2scannetv2.keys())
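# The aggregation file of scene0217_00 lists every instance twice; keep only the first half.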
73 | if(fn == 'val/scene0217_00_vh_clean_2.ply' and instance_segids[0] == instance_segids[int(len(instance_segids) / 2)]):
74 | instance_segids = instance_segids[: int(len(instance_segids) / 2)]
75 | check = []
76 | for i in range(len(instance_segids)): check += instance_segids[i]
77 | assert len(np.unique(check)) == len(check)
78 |
79 | instance_labels = np.ones(sem_labels.shape[0]) * -100
80 | for i in range(len(instance_segids)):
81 | segids = instance_segids[i]
82 | pointids = []
83 | for segid in segids:
84 | pointids += segid_to_pointid[segid]
85 | instance_labels[pointids] = i
86 | assert(len(np.unique(sem_labels[pointids])) == 1)
87 |
88 | torch.save((coords, colors, sem_labels, instance_labels), fn[:-15]+'_inst_nostuff.pth')
89 | print('Saving to ' + fn[:-15]+'_inst_nostuff.pth')
90 |
91 | # for fn in files:
92 | # f(fn)
93 |
94 | p = mp.Pool(processes=mp.cpu_count())
95 | if opt.data_split == 'test':
96 | p.map(f_test, files)
97 | else:
98 | p.map(f, files)
99 | p.close()
100 | p.join()
--------------------------------------------------------------------------------
/dataset/scannetv2/prepare_data_inst_gttxt.py:
--------------------------------------------------------------------------------
1 | '''
2 | Generate instance groundtruth .txt files (for evaluation)
3 | '''
4 |
5 | import numpy as np
6 | import glob
7 | import torch
8 | import os
9 |
10 | semantic_label_idxs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39]
11 | semantic_label_names = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'shower curtain', 'toilet', 'sink', 'bathtub', 'otherfurniture']
12 |
13 |
14 | if __name__ == '__main__':
15 | split = 'val'
16 | files = sorted(glob.glob('{}/scene*_inst_nostuff.pth'.format(split)))
17 | rooms = [torch.load(i) for i in files]
18 |
19 | if not os.path.exists(split + '_gt'):
20 | os.mkdir(split + '_gt')
21 |
22 | for i in range(len(rooms)):
23 | xyz, rgb, label, instance_label = rooms[i] # label 0~19 -100; instance_label 0~instance_num-1 -100
24 | scene_name = files[i].split('/')[-1][:12]
25 | print('{}/{} {}'.format(i + 1, len(rooms), scene_name))
26 |
27 | instance_label_new = np.zeros(instance_label.shape, dtype=np.int32) # 0 for unannotated, xx00y: x for semantic_label, y for inst_id (1~instance_num)
28 |
29 | instance_num = int(instance_label.max()) + 1
30 | for inst_id in range(instance_num):
31 | instance_mask = np.where(instance_label == inst_id)[0]
32 | sem_id = int(label[instance_mask[0]])
33 | if(sem_id == -100): sem_id = 0
34 | semantic_label = semantic_label_idxs[sem_id]
35 | instance_label_new[instance_mask] = semantic_label * 1000 + inst_id + 1
36 |
37 | np.savetxt(os.path.join(split + '_gt', scene_name + '.txt'), instance_label_new, fmt='%d')
38 |
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/dataset/scannetv2/scannet_util.py:
--------------------------------------------------------------------------------
1 | g_label_names = ['unannotated', 'wall', 'floor', 'chair', 'table', 'desk', 'bed', 'bookshelf', 'sofa', 'sink', 'bathtub', 'toilet', 'curtain', 'counter', 'door', 'window', 'shower curtain', 'refridgerator', 'picture', 'cabinet', 'otherfurniture']
2 |
3 | def get_raw2scannetv2_label_map():
4 | lines = [line.rstrip() for line in open('scannetv2-labels.combined.tsv')]
5 | lines_0 = lines[0].split('\t')
6 | print(lines_0)
7 | print(len(lines))
8 | lines = lines[1:]
9 | raw2scannet = {}
10 | for i in range(len(lines)):
11 | label_classes_set = set(g_label_names)
12 | elements = lines[i].split('\t')
13 | raw_name = elements[1]
14 | if (elements[1] != elements[2]):
15 | print('{}: {} {}'.format(i, elements[1], elements[2]))
16 | nyu40_name = elements[7]
17 | if nyu40_name not in label_classes_set:
18 | raw2scannet[raw_name] = 'unannotated'
19 | else:
20 | raw2scannet[raw_name] = nyu40_name
21 | return raw2scannet
22 |
23 | g_raw2scannetv2 = get_raw2scannetv2_label_map()
--------------------------------------------------------------------------------
/docs/STPLS3D_leaderboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/HAIS/a82f629c5a7b42109fd32fc76f053a941dfbf77a/docs/STPLS3D_leaderboard.png
--------------------------------------------------------------------------------
/docs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/HAIS/a82f629c5a7b42109fd32fc76f053a941dfbf77a/docs/framework.png
--------------------------------------------------------------------------------
/docs/scannet_leaderboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/HAIS/a82f629c5a7b42109fd32fc76f053a941dfbf77a/docs/scannet_leaderboard.png
--------------------------------------------------------------------------------
/docs/scene0249_00_output_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/HAIS/a82f629c5a7b42109fd32fc76f053a941dfbf77a/docs/scene0249_00_output_2.gif
--------------------------------------------------------------------------------
/docs/scene0430_00_output_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/HAIS/a82f629c5a7b42109fd32fc76f053a941dfbf77a/docs/scene0430_00_output_2.gif
--------------------------------------------------------------------------------
/lib/hais_ops/functions/hais_ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 |
4 | import HAIS_OP
5 |
6 | class HierarchicalAggregation(Function):
7 | @staticmethod
8 | def forward(ctx, semantic_label, coord_shift, ball_query_idxs, start_len, batch_idxs, training_mode, using_set_aggr):
9 | '''
10 | :param ctx:
11 | :param semantic_label: (N_fg), int
12 | :param coord_shift: (N_fg, 3), float
13 | :param ball_query_idxs: (nActive), int
14 | :param start_len: (N_fg, 2), int
15 | :param batch_idxs: (N_fg), int
16 |
17 | :return: cluster_idxs: int (sumNPoint, 2), [:, 0] for cluster_id, [:, 1] for corresponding point idxs in N
18 | :return: cluster_offsets: int (nCluster + 1)
19 | '''
20 | N = start_len.size(0)
21 |
22 | assert semantic_label.is_contiguous()
23 | assert coord_shift.is_contiguous()
24 | assert ball_query_idxs.is_contiguous()
25 | assert start_len.is_contiguous()
26 |
27 | fragment_idxs = semantic_label.new()
28 | fragment_offsets = semantic_label.new()
29 | fragment_centers = coord_shift.new() # float
30 |
31 | cluster_idxs_kept = semantic_label.new()
32 | cluster_offsets_kept = semantic_label.new()
33 | cluster_centers_kept = coord_shift.new() # float
34 |
35 | primary_idxs = semantic_label.new()
36 | primary_offsets = semantic_label.new()
37 | primary_centers = coord_shift.new() # float
38 |
39 | primary_idxs_post = semantic_label.new()
40 | primary_offsets_post = semantic_label.new()
41 |
42 | training_mode_ = 1 if training_mode == 'train' else 0
43 | using_set_aggr_ = int(using_set_aggr)
44 |
45 | HAIS_OP.hierarchical_aggregation(semantic_label, coord_shift, batch_idxs, ball_query_idxs, start_len,
46 | fragment_idxs, fragment_offsets, fragment_centers,
47 | cluster_idxs_kept, cluster_offsets_kept, cluster_centers_kept,
48 | primary_idxs, primary_offsets, primary_centers,
49 | primary_idxs_post, primary_offsets_post,
50 | N, training_mode_, using_set_aggr_)
51 |
52 | if using_set_aggr_ == 0: # not set aggr
53 | pass
54 | else:
55 | # cut off tails
56 | primary_idxs_post = primary_idxs_post[:primary_offsets_post[-1]]
57 | primary_idxs = primary_idxs_post
58 | primary_offsets = primary_offsets_post
59 |
60 | cluster_idxs = cluster_idxs_kept
61 | cluster_offsets = cluster_offsets_kept
62 |
63 | if primary_idxs.shape[0] != 0:
64 | # append primary instances after the kept clusters (shift their ids and offsets)
65 | primary_idxs[:, 0] += (cluster_offsets.size(0) - 1)
66 | primary_offsets += cluster_offsets[-1]
67 | cluster_idxs = torch.cat((cluster_idxs, primary_idxs), dim=0).cpu()
68 | cluster_offsets = torch.cat((cluster_offsets, primary_offsets[1:])).cpu()
69 |
70 | return cluster_idxs, cluster_offsets
71 |
72 |
73 | @staticmethod
74 | def backward(ctx, a=None):
75 | return None
76 |
77 | hierarchical_aggregation = HierarchicalAggregation.apply
78 |
79 |
80 | class CalIoUAndMasklabel(Function):
81 | @staticmethod
82 | def forward(ctx, proposals_idx, proposals_offset, instance_labels, instance_pointnum, mask_scores_sigmoid, mode):
83 | '''
84 | :param ctx:
85 | :param proposals_idx: (sumNPoint), int
86 | :param proposals_offset: (nProposal + 1), int
87 | :param instance_labels: (N), long, 0~total_nInst-1, -100
88 | :param instance_pointnum: (total_nInst), int
89 | :param mask_scores_sigmoid: (sumNPoint), float
90 | :param mode: int, mode = 1 if cal IoU based on mask else mode = 0
91 |
92 | :return: proposals_iou: (nProposal, total_nInst), float
93 | :return mask_label:
94 | '''
95 |
96 | nInstance = instance_pointnum.size(0)
97 | nProposal = proposals_offset.size(0) - 1
98 | proposals_iou = torch.cuda.FloatTensor(nProposal, nInstance).zero_()
99 | mask_label = torch.cuda.FloatTensor(mask_scores_sigmoid.shape).zero_() - 1.
100 |
101 | assert proposals_idx.is_contiguous() and proposals_idx.is_cuda
102 | assert proposals_offset.is_contiguous() and proposals_offset.is_cuda
103 | assert instance_labels.is_contiguous() and instance_labels.is_cuda
104 | assert instance_pointnum.is_contiguous() and instance_pointnum.is_cuda
105 | assert mask_scores_sigmoid.is_contiguous() and mask_scores_sigmoid.is_cuda
106 |
107 | HAIS_OP.cal_iou_and_masklabel(proposals_idx, proposals_offset, instance_labels, instance_pointnum, proposals_iou, nInstance, nProposal, mask_scores_sigmoid, mask_label, mode)
108 |
109 | return proposals_iou, mask_label
110 |
111 | @staticmethod
112 | def backward(ctx, a=None):
113 | return None, None, None, None
114 |
115 | cal_iou_and_masklabel = CalIoUAndMasklabel.apply
116 |
117 |
118 | class Voxelization_Idx(Function):
119 | @staticmethod
120 | def forward(ctx, coords, batchsize, mode=4):
121 | '''
122 | :param ctx:
123 | :param coords: long (N, dimension + 1) or (N, dimension) dimension = 3
124 | :param batchsize
125 | :param mode: int 4=mean
126 | :param dimension: int
127 | :return: output_coords: long (M, dimension + 1) (M <= N)
128 | :return: output_map: int M * (maxActive + 1)
129 | :return: input_map: int N
130 | '''
131 | assert coords.is_contiguous()
132 | N = coords.size(0)
133 | output_coords = coords.new()
134 |
135 | input_map = torch.IntTensor(N).zero_()
136 | output_map = input_map.new()
137 |
138 | HAIS_OP.voxelize_idx(coords, output_coords, input_map, output_map, batchsize, mode)
139 | return output_coords, input_map, output_map
140 |
141 | @staticmethod
142 | def backward(ctx, a=None, b=None, c=None):
143 | return None
144 |
145 | voxelization_idx = Voxelization_Idx.apply
146 |
147 |
148 | class Voxelization(Function):
149 | @staticmethod
150 | def forward(ctx, feats, map_rule, mode=4):
151 | '''
152 | :param ctx:
153 | :param map_rule: cuda int M * (maxActive + 1)
154 | :param feats: cuda float N * C
155 | :return: output_feats: cuda float M * C
156 | '''
157 | assert map_rule.is_contiguous()
158 | assert feats.is_contiguous()
159 | N, C = feats.size()
160 | M = map_rule.size(0)
161 | maxActive = map_rule.size(1) - 1
162 |
163 | output_feats = torch.cuda.FloatTensor(M, C).zero_()
164 |
165 | ctx.for_backwards = (map_rule, mode, maxActive, N)
166 |
167 | HAIS_OP.voxelize_fp(feats, output_feats, map_rule, mode, M, maxActive, C)
168 | return output_feats
169 |
170 |
171 | @staticmethod
172 | def backward(ctx, d_output_feats):
173 | map_rule, mode, maxActive, N = ctx.for_backwards
174 | M, C = d_output_feats.size()
175 |
176 | d_feats = torch.cuda.FloatTensor(N, C).zero_()
177 |
178 | HAIS_OP.voxelize_bp(d_output_feats.contiguous(), d_feats, map_rule, mode, M, maxActive, C)
179 | return d_feats, None, None
180 |
181 | voxelization = Voxelization.apply
182 |
183 |
184 | class PointRecover(Function):
185 | @staticmethod
186 | def forward(ctx, feats, map_rule, nPoint):
187 | '''
188 | :param ctx:
189 | :param feats: cuda float M * C
190 | :param map_rule: cuda int M * (maxActive + 1)
191 | :param nPoint: int
192 | :return: output_feats: cuda float N * C
193 | '''
194 | assert map_rule.is_contiguous()
195 | assert feats.is_contiguous()
196 | M, C = feats.size()
197 | maxActive = map_rule.size(1) - 1
198 |
199 | output_feats = torch.cuda.FloatTensor(nPoint, C).zero_()
200 |
201 | ctx.for_backwards = (map_rule, maxActive, M)
202 |
203 | HAIS_OP.point_recover_fp(feats, output_feats, map_rule, M, maxActive, C)
204 |
205 | return output_feats
206 |
207 | @staticmethod
208 | def backward(ctx, d_output_feats):
209 | map_rule, maxActive, M = ctx.for_backwards
210 | N, C = d_output_feats.size()
211 |
212 | d_feats = torch.cuda.FloatTensor(M, C).zero_()
213 |
214 | HAIS_OP.point_recover_bp(d_output_feats.contiguous(), d_feats, map_rule, M, maxActive, C)
215 |
216 | return d_feats, None, None
217 |
218 | point_recover = PointRecover.apply
219 |
220 |
221 | class BallQueryBatchP(Function):
222 | @staticmethod
223 | def forward(ctx, coords, batch_idxs, batch_offsets, radius, meanActive):
224 | '''
225 | :param ctx:
226 | :param coords: (n, 3) float
227 | :param batch_idxs: (n) int
228 | :param batch_offsets: (B+1) int
229 | :param radius: float
230 | :param meanActive: int
231 | :return: idx (nActive), int
232 | :return: start_len (n, 2), int
233 | '''
234 |
235 | n = coords.size(0)
236 |
237 | assert coords.is_contiguous() and coords.is_cuda
238 | assert batch_idxs.is_contiguous() and batch_idxs.is_cuda
239 | assert batch_offsets.is_contiguous() and batch_offsets.is_cuda
240 |
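# Retry with a larger meanActive if the preallocated idx buffer (n * meanActive) overflows.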
241 | while True:
242 | idx = torch.cuda.IntTensor(n * meanActive).zero_()
243 | start_len = torch.cuda.IntTensor(n, 2).zero_()
244 | nActive = HAIS_OP.ballquery_batch_p(coords, batch_idxs, batch_offsets, idx, start_len, n, meanActive, radius)
245 | if nActive <= n * meanActive:
246 | break
247 | meanActive = int(nActive // n + 1)
248 | idx = idx[:nActive]
249 |
250 | return idx, start_len
251 |
252 | @staticmethod
253 | def backward(ctx, a=None, b=None):
254 | return None, None, None
255 |
256 | ballquery_batch_p = BallQueryBatchP.apply
257 |
258 |
259 | class BFSCluster(Function):
260 | @staticmethod
261 | def forward(ctx, semantic_label, ball_query_idxs, start_len, threshold):
262 | '''
263 | :param ctx:
264 | :param semantic_label: (N), int
265 | :param ball_query_idxs: (nActive), int
266 | :param start_len: (N, 2), int
267 | :return: cluster_idxs: int (sumNPoint, 2), dim 0 for cluster_id, dim 1 for corresponding point idxs in N
268 | :return: cluster_offsets: int (nCluster + 1)
269 | '''
270 |
271 | N = start_len.size(0)
272 |
273 | assert semantic_label.is_contiguous()
274 | assert ball_query_idxs.is_contiguous()
275 | assert start_len.is_contiguous()
276 |
277 | cluster_idxs = semantic_label.new()
278 | cluster_offsets = semantic_label.new()
279 |
280 | HAIS_OP.bfs_cluster(semantic_label, ball_query_idxs, start_len, cluster_idxs, cluster_offsets, N, threshold)
281 |
282 | return cluster_idxs, cluster_offsets
283 |
284 | @staticmethod
285 | def backward(ctx, a=None):
286 | return None
287 |
288 | bfs_cluster = BFSCluster.apply
289 |
290 |
291 | class RoiPool(Function):
292 | @staticmethod
293 | def forward(ctx, feats, proposals_offset):
294 | '''
295 | :param ctx:
296 | :param feats: (sumNPoint, C) float
297 | :param proposals_offset: (nProposal + 1) int
298 | :return: output_feats (nProposal, C) float
299 | '''
300 | nProposal = proposals_offset.size(0) - 1
301 | sumNPoint, C = feats.size()
302 |
303 | assert feats.is_contiguous()
304 | assert proposals_offset.is_contiguous()
305 |
306 | output_feats = torch.cuda.FloatTensor(nProposal, C).zero_()
307 | output_maxidx = torch.cuda.IntTensor(nProposal, C).zero_()
308 |
309 | HAIS_OP.roipool_fp(feats, proposals_offset, output_feats, output_maxidx, nProposal, C)
310 |
311 | ctx.for_backwards = (output_maxidx, proposals_offset, sumNPoint)
312 |
313 | return output_feats
314 |
315 | @staticmethod
316 | def backward(ctx, d_output_feats):
317 | nProposal, C = d_output_feats.size()
318 |
319 | output_maxidx, proposals_offset, sumNPoint = ctx.for_backwards
320 |
321 | d_feats = torch.cuda.FloatTensor(sumNPoint, C).zero_()
322 |
323 | HAIS_OP.roipool_bp(d_feats, proposals_offset, output_maxidx, d_output_feats.contiguous(), nProposal, C)
324 |
325 | return d_feats, None
326 |
327 | roipool = RoiPool.apply
328 |
329 |
330 | class GetIoU(Function):
331 | @staticmethod
332 | def forward(ctx, proposals_idx, proposals_offset, instance_labels, instance_pointnum):
333 | '''
334 | :param ctx:
335 | :param proposals_idx: (sumNPoint), int
336 | :param proposals_offset: (nProposal + 1), int
337 | :param instance_labels: (N), long, 0~total_nInst-1, -100
338 | :param instance_pointnum: (total_nInst), int
339 | :return: proposals_iou: (nProposal, total_nInst), float
340 | '''
341 | nInstance = instance_pointnum.size(0)
342 | nProposal = proposals_offset.size(0) - 1
343 |
344 | assert proposals_idx.is_contiguous() and proposals_idx.is_cuda
345 | assert proposals_offset.is_contiguous() and proposals_offset.is_cuda
346 | assert instance_labels.is_contiguous() and instance_labels.is_cuda
347 | assert instance_pointnum.is_contiguous() and instance_pointnum.is_cuda
348 |
349 | proposals_iou = torch.cuda.FloatTensor(nProposal, nInstance).zero_()
350 |
351 | HAIS_OP.get_iou(proposals_idx, proposals_offset, instance_labels, instance_pointnum, proposals_iou, nInstance, nProposal)
352 |
353 | return proposals_iou
354 |
355 | @staticmethod
356 | def backward(ctx, a=None):
357 | return None, None, None, None
358 |
359 | get_iou = GetIoU.apply
360 |
361 |
362 | class SecMean(Function):
363 | @staticmethod
364 | def forward(ctx, inp, offsets):
365 | '''
366 | :param ctx:
367 | :param inp: (N, C) float
368 | :param offsets: (nProposal + 1) int
369 | :return: out (nProposal, C) float
370 | '''
371 | nProposal = offsets.size(0) - 1
372 | C = inp.size(1)
373 |
374 | assert inp.is_contiguous()
375 | assert offsets.is_contiguous()
376 |
377 | out = torch.cuda.FloatTensor(nProposal, C).zero_()
378 |
379 | HAIS_OP.sec_mean(inp, offsets, out, nProposal, C)
380 |
381 | return out
382 |
383 | @staticmethod
384 | def backward(ctx, a=None):
385 | return None, None
386 |
387 | sec_mean = SecMean.apply
388 |
389 |
390 | class SecMin(Function):
391 | @staticmethod
392 | def forward(ctx, inp, offsets):
393 | '''
394 | :param ctx:
395 | :param inp: (N, C) float
396 | :param offsets: (nProposal + 1) int
397 | :return: out (nProposal, C) float
398 | '''
399 | nProposal = offsets.size(0) - 1
400 | C = inp.size(1)
401 |
402 | assert inp.is_contiguous()
403 | assert offsets.is_contiguous()
404 |
405 | out = torch.cuda.FloatTensor(nProposal, C).zero_()
406 |
407 | HAIS_OP.sec_min(inp, offsets, out, nProposal, C)
408 |
409 | return out
410 |
411 | @staticmethod
412 | def backward(ctx, a=None):
413 | return None, None
414 |
415 | sec_min = SecMin.apply
416 |
417 |
418 | class SecMax(Function):
419 | @staticmethod
420 | def forward(ctx, inp, offsets):
421 | '''
422 | :param ctx:
423 | :param inp: (N, C) float
424 | :param offsets: (nProposal + 1) int
425 | :return: out (nProposal, C) float
426 | '''
427 | nProposal = offsets.size(0) - 1
428 | C = inp.size(1)
429 |
430 | assert inp.is_contiguous()
431 | assert offsets.is_contiguous()
432 |
433 | out = torch.cuda.FloatTensor(nProposal, C).zero_()
434 |
435 | HAIS_OP.sec_max(inp, offsets, out, nProposal, C)
436 |
437 | return out
438 |
439 | @staticmethod
440 | def backward(ctx, a=None):
441 | return None, None
442 |
443 | sec_max = SecMax.apply
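
# Usage sketch (hedged; based on the docstrings above and the voxelization_idx call in
# data/scannetv2_inst.py; the feature pooling calls are an assumed pairing):
#
#   locs = torch.cat([batch_idx_col, voxel_coords], 1).long()             # (N, 1 + 3), CPU
#   voxel_locs, p2v_map, v2p_map = voxelization_idx(locs, batch_size, 4)  # mode 4 = mean
#   # voxel_locs: (M, 1 + 3) unique voxels; p2v_map: (N,) per-point map;
#   # v2p_map: (M, maxActive + 1) voxel -> points rule
#   voxel_feats = voxelization(point_feats.cuda(), v2p_map.cuda(), 4)     # (M, C) mean-pooled
#   points_back = point_recover(voxel_feats, v2p_map.cuda(), N)           # (N, C) scatter back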
--------------------------------------------------------------------------------
/lib/hais_ops/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 | setup(
5 | name='HAIS_OP',
6 | ext_modules=[
7 | CUDAExtension('HAIS_OP', [
8 | 'src/hais_ops_api.cpp',
9 |
10 | 'src/hais_ops.cpp',
11 | 'src/cuda.cu'
12 | ], extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']})
13 | ],
14 | cmdclass={'build_ext': BuildExtension}
15 | )
--------------------------------------------------------------------------------
/lib/hais_ops/src/bfs_cluster/bfs_cluster.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | Ball Query with BatchIdx & Clustering Algorithm
3 | Written by Li Jiang
4 | All Rights Reserved 2020.
5 | */
6 |
7 | #include "bfs_cluster.h"
8 |
9 | /* ================================== ballquery_batch_p ================================== */
10 | // input xyz: (n, 3) float
11 | // input batch_idxs: (n) int
12 | // input batch_offsets: (B+1) int, batch_offsets[-1]
13 | // output idx: (n * meanActive) dim 0 for number of points in the ball, idx in n
14 | // output start_len: (n, 2), int
15 | int ballquery_batch_p(at::Tensor xyz_tensor, at::Tensor batch_idxs_tensor, at::Tensor batch_offsets_tensor, at::Tensor idx_tensor, at::Tensor start_len_tensor, int n, int meanActive, float radius){
16 | const float *xyz = xyz_tensor.data<float>();
17 | const int *batch_idxs = batch_idxs_tensor.data<int>();
18 | const int *batch_offsets = batch_offsets_tensor.data<int>();
19 | int *idx = idx_tensor.data<int>();
20 | int *start_len = start_len_tensor.data<int>();
21 |
22 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
23 | int cumsum = ballquery_batch_p_cuda(n, meanActive, radius, xyz, batch_idxs, batch_offsets, idx, start_len, stream);
24 | return cumsum;
25 | }
26 |
27 | /* ================================== bfs_cluster ================================== */
28 | ConnectedComponent find_cc(Int idx, int *semantic_label, Int *ball_query_idxs, int *start_len, int *visited){
29 | ConnectedComponent cc;
30 | cc.addPoint(idx);
31 | visited[idx] = 1;
32 |
33 | std::queue<Int> Q;
34 | assert(Q.empty());
35 | Q.push(idx);
36 |
37 | while(!Q.empty()){
38 | Int cur = Q.front(); Q.pop();
39 | int start = start_len[cur * 2];
40 | int len = start_len[cur * 2 + 1];
41 | int label_cur = semantic_label[cur];
42 | for(Int i = start; i < start + len; i++){
43 | Int idx_i = ball_query_idxs[i];
44 | if(semantic_label[idx_i] != label_cur) continue;
45 | if(visited[idx_i] == 1) continue;
46 |
47 | cc.addPoint(idx_i);
48 | visited[idx_i] = 1;
49 |
50 | Q.push(idx_i);
51 | }
52 | }
53 | return cc;
54 | }
55 |
56 | //input: semantic_label, int, N
57 | //input: ball_query_idxs, Int, (nActive)
58 | //input: start_len, int, (N, 2)
59 | //output: clusters, CCs
60 | int get_clusters(int *semantic_label, Int *ball_query_idxs, int *start_len, const Int nPoint, int threshold, ConnectedComponents &clusters){
61 | int visited[nPoint] = {0};
62 |
63 | int sumNPoint = 0;
64 | for(Int i = 0; i < nPoint; i++){
65 | if(visited[i] == 0){
66 | ConnectedComponent CC = find_cc(i, semantic_label, ball_query_idxs, start_len, visited);
67 | if((int)CC.pt_idxs.size() >= threshold){
68 | clusters.push_back(CC);
69 | sumNPoint += (int)CC.pt_idxs.size();
70 | }
71 | }
72 | }
73 |
74 | return sumNPoint;
75 | }
76 |
77 | void fill_cluster_idxs_(ConnectedComponents &CCs, int *cluster_idxs, int *cluster_offsets){
78 | for(int i = 0; i < (int)CCs.size(); i++){
79 | cluster_offsets[i + 1] = cluster_offsets[i] + (int)CCs[i].pt_idxs.size();
80 | for(int j = 0; j < (int)CCs[i].pt_idxs.size(); j++){
81 | int idx = CCs[i].pt_idxs[j];
82 | cluster_idxs[(cluster_offsets[i] + j) * 2 + 0] = i;
83 | cluster_idxs[(cluster_offsets[i] + j) * 2 + 1] = idx;
84 | }
85 | }
86 | }
87 |
88 | //input: semantic_label, int, N
89 | //input: ball_query_idxs, int, (nActive)
90 | //input: start_len, int, (N, 2)
91 | //output: cluster_idxs, int (sumNPoint, 2), dim 0 for cluster_id, dim 1 for corresponding point idxs in N
92 | //output: cluster_offsets, int (nCluster + 1)
93 | void bfs_cluster(at::Tensor semantic_label_tensor, at::Tensor ball_query_idxs_tensor, at::Tensor start_len_tensor,
94 | at::Tensor cluster_idxs_tensor, at::Tensor cluster_offsets_tensor, const int N, int threshold){
95 | int *semantic_label = semantic_label_tensor.data<int>();
96 | Int *ball_query_idxs = ball_query_idxs_tensor.data<Int>();
97 | int *start_len = start_len_tensor.data<int>();
98 |
99 | ConnectedComponents CCs;
100 | int sumNPoint = get_clusters(semantic_label, ball_query_idxs, start_len, N, threshold, CCs);
101 |
102 | int nCluster = (int)CCs.size();
103 | cluster_idxs_tensor.resize_({sumNPoint, 2});
104 | cluster_offsets_tensor.resize_({nCluster + 1});
105 | cluster_idxs_tensor.zero_();
106 | cluster_offsets_tensor.zero_();
107 |
108 | int *cluster_idxs = cluster_idxs_tensor.data<int>();
109 | int *cluster_offsets = cluster_offsets_tensor.data<int>();
110 |
111 | fill_cluster_idxs_(CCs, cluster_idxs, cluster_offsets);
112 | }
--------------------------------------------------------------------------------
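
A small sketch of how the (cluster_idxs, cluster_offsets) output of bfs_cluster is meant to be read, following the comments above (pure Python, variable names assumed):

def points_of_cluster(k, cluster_idxs, cluster_offsets):
    # cluster_idxs: (sumNPoint, 2) int, column 0 = cluster id, column 1 = point index in N
    # cluster_offsets: (nCluster + 1) int, CSR-style offsets into cluster_idxs
    s, e = int(cluster_offsets[k]), int(cluster_offsets[k + 1])
    return cluster_idxs[s:e, 1]
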
/lib/hais_ops/src/bfs_cluster/bfs_cluster.cu:
--------------------------------------------------------------------------------
1 | /*
2 | Ball Query with BatchIdx
3 | Written by Li Jiang
4 | All Rights Reserved 2020.
5 | */
6 | #include "bfs_cluster.h"
7 | #include "../cuda_utils.h"
8 |
9 | #include <stdio.h>
10 | #include <stdlib.h>
11 | #include <assert.h>
12 |
13 |
14 | /* ================================== ballquery_batch_p ================================== */
15 | __global__ void ballquery_batch_p_cuda_(int n, int meanActive, float radius, const float *xyz, const int *batch_idxs, const int *batch_offsets, int *idx, int *start_len, int *cumsum) {
16 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
17 | if (pt_idx >= n) return;
18 |
19 | start_len += (pt_idx * 2);
20 | int idx_temp[1000];
21 |
22 | float radius2 = radius * radius;
23 | float o_x = xyz[pt_idx * 3 + 0];
24 | float o_y = xyz[pt_idx * 3 + 1];
25 | float o_z = xyz[pt_idx * 3 + 2];
26 |
27 | int batch_idx = batch_idxs[pt_idx];
28 | int start = batch_offsets[batch_idx];
29 | int end = batch_offsets[batch_idx + 1];
30 |
31 | int cnt = 0;
32 | for(int k = start; k < end; k++){
33 | float x = xyz[k * 3 + 0];
34 | float y = xyz[k * 3 + 1];
35 | float z = xyz[k * 3 + 2];
36 | float d2 = (o_x - x) * (o_x - x) + (o_y - y) * (o_y - y) + (o_z - z) * (o_z - z);
37 | if(d2 < radius2){
38 | if(cnt < 1000){
39 | idx_temp[cnt] = k;
40 | }
41 | else{
42 | break;
43 | }
44 | ++cnt;
45 | }
46 | }
47 |
48 | start_len[0] = atomicAdd(cumsum, cnt);
49 | start_len[1] = cnt;
50 |
51 | int thre = n * meanActive;
52 | if(start_len[0] >= thre) return;
53 |
54 | idx += start_len[0];
55 | if(start_len[0] + cnt >= thre) cnt = thre - start_len[0];
56 |
57 | for(int k = 0; k < cnt; k++){
58 | idx[k] = idx_temp[k];
59 | }
60 | }
61 |
62 |
63 | int ballquery_batch_p_cuda(int n, int meanActive, float radius, const float *xyz, const int *batch_idxs, const int *batch_offsets, int *idx, int *start_len, cudaStream_t stream) {
64 | // param xyz: (n, 3)
65 | // param batch_idxs: (n)
66 | // param batch_offsets: (B + 1)
67 | // output idx: (n * meanActive) dim 0 for number of points in the ball, idx in n
68 | // output start_len: (n, 2), int
69 |
70 | cudaError_t err;
71 |
72 | dim3 blocks(DIVUP(n, MAX_THREADS_PER_BLOCK));
73 | dim3 threads(MAX_THREADS_PER_BLOCK);
74 |
75 | int cumsum = 0;
76 | int* p_cumsum;
77 | cudaMalloc((void**)&p_cumsum, sizeof(int));
78 | cudaMemcpy(p_cumsum, &cumsum, sizeof(int), cudaMemcpyHostToDevice);
79 |
80 | ballquery_batch_p_cuda_<<<blocks, threads, 0, stream>>>(n, meanActive, radius, xyz, batch_idxs, batch_offsets, idx, start_len, p_cumsum);
81 |
82 | err = cudaGetLastError();
83 | if (cudaSuccess != err) {
84 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
85 | exit(-1);
86 | }
87 |
88 | cudaMemcpy(&cumsum, p_cumsum, sizeof(int), cudaMemcpyDeviceToHost);
89 | return cumsum;
90 | }
--------------------------------------------------------------------------------
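
The ball-query kernel writes a flat neighbor buffer plus a per-point (start, length) pair; a hedged sketch of decoding it on the Python side (names assumed):

def neighbors_of(p, idx, start_len):
    # idx: (n * meanActive,) int flat neighbor indices
    # start_len: (n, 2) int, start offset into idx and neighbor count for point p
    start, length = int(start_len[p, 0]), int(start_len[p, 1])
    return idx[start:start + length]  # may be truncated if the n * meanActive budget overflowed
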
/lib/hais_ops/src/bfs_cluster/bfs_cluster.h:
--------------------------------------------------------------------------------
1 | /*
2 | Ball Query with BatchIdx & Clustering Algorithm
3 | Written by Li Jiang
4 | All Rights Reserved 2020.
5 | */
6 |
7 | #ifndef BFS_CLUSTER_H
8 | #define BFS_CLUSTER_H
9 | #include <torch/serialize/tensor.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <queue>
12 |
13 | #include "../datatype/datatype.h"
14 |
15 | int ballquery_batch_p(at::Tensor xyz_tensor, at::Tensor batch_idxs_tensor, at::Tensor batch_offsets_tensor, at::Tensor idx_tensor, at::Tensor start_len_tensor, int n, int meanActive, float radius);
16 | int ballquery_batch_p_cuda(int n, int meanActive, float radius, const float *xyz, const int *batch_idxs, const int *batch_offsets, int *idx, int *start_len, cudaStream_t stream);
17 |
18 | void bfs_cluster(at::Tensor semantic_label_tensor, at::Tensor ball_query_idxs_tensor, at::Tensor start_len_tensor, at::Tensor cluster_idxs_tensor, at::Tensor cluster_offsets_tensor, const int N, int threshold);
19 |
20 | #endif //BFS_CLUSTER_H
--------------------------------------------------------------------------------
/lib/hais_ops/src/cal_iou_and_masklabel/cal_iou_and_masklabel.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | Get the IoU between predictions and gt masks
3 | */
4 |
5 | #include "cal_iou_and_masklabel.h"
6 |
7 | void cal_iou_and_masklabel(at::Tensor proposals_idx_tensor, at::Tensor proposals_offset_tensor, \
8 | at::Tensor instance_labels_tensor, at::Tensor instance_pointnum_tensor, \
9 | at::Tensor proposals_iou_tensor, int nInstance, int nProposal,
10 | at::Tensor mask_scores_sigmoid_tensor, at::Tensor mask_labels_tensor,
11 | int mode){
12 | int *proposals_idx = proposals_idx_tensor.data<int>();
13 | int *proposals_offset = proposals_offset_tensor.data<int>();
14 | long *instance_labels = instance_labels_tensor.data<long>();
15 | int *instance_pointnum = instance_pointnum_tensor.data<int>();
16 |
17 | float *proposals_iou = proposals_iou_tensor.data<float>();
18 |
19 | float *mask_scores_sigmoid = mask_scores_sigmoid_tensor.data<float>();
20 | float *mask_label = mask_labels_tensor.data<float>();
21 |
22 |
23 |
24 |
25 | //input: nInstance (1,), int
26 | //input: nProposal (1,), int
27 | //input: proposals_idx (sumNPoint), int
28 | //input: proposals_offset (nProposal + 1), int
29 | //input: instance_labels (N), long, 0~total_nInst-1, -100
30 | //input: instance_pointnum (total_nInst), int
31 | //input: mask_scores_sigmoid (sumNPoint, 1), float
32 | //output: proposals_iou (nProposal, total_nInst), float
33 | //output: mask_label (sumNPoint, 1), float
34 | cal_iou_and_masklabel_cuda(nInstance, nProposal, proposals_idx, proposals_offset, instance_labels,
35 | instance_pointnum, proposals_iou, mask_scores_sigmoid, mask_label, mode);
36 | }
--------------------------------------------------------------------------------
/lib/hais_ops/src/cal_iou_and_masklabel/cal_iou_and_masklabel.cu:
--------------------------------------------------------------------------------
1 | /*
2 | Calculate the IoU between predictions and GTs and generate mask labels
3 | */
4 |
5 | #include <stdio.h>
6 | #include <math.h>
7 | #include "cal_iou_and_masklabel.h"
8 |
9 | #define MAX_BLOCKS_PER_GRID 32768
10 | #define MAX_THREADS_PER_BLOCK 512
11 |
12 |
13 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
14 |
15 | __global__ void get_iou_mask_cuda_(int nInstance, int nProposal, int *proposals_idx, int *proposals_offset, long *instance_labels, int *instance_pointnum, float *proposals_iou, float *mask_scores_sigmoid, float *mask_label, int mode){
16 |
17 | if (mode == 0){ // cal iou based on clustering result
18 | for(int proposal_id = blockIdx.x; proposal_id < nProposal; proposal_id += gridDim.x){
19 | int start = proposals_offset[proposal_id];
20 | int end = proposals_offset[proposal_id + 1];
21 | int proposal_total = end - start;
22 | for(int instance_id = threadIdx.x; instance_id < nInstance; instance_id += blockDim.x){
23 | int instance_total = instance_pointnum[instance_id];
24 | int intersection = 0;
25 | for(int i = start; i < end; i++){
26 | int idx = proposals_idx[i];
27 | if((int)instance_labels[idx] == instance_id){
28 | intersection += 1;
29 | }
30 | }
31 | proposals_iou[proposal_id * nInstance + instance_id] = (float)intersection / ((float)(proposal_total + instance_total - intersection) + 1e-5);
32 | }
33 | }
34 | }
35 | else if(mode == 1){ // cal iou based on mask result
36 | for(int proposal_id = blockIdx.x; proposal_id < nProposal; proposal_id += gridDim.x){
37 | int start = proposals_offset[proposal_id];
38 | int end = proposals_offset[proposal_id + 1];
39 | int proposal_total = 0;
40 |
41 | for(int i = start; i < end; i++)
42 | if(mask_scores_sigmoid[i] > 0.5)
43 | proposal_total += 1;
44 |
45 | for(int instance_id = threadIdx.x; instance_id < nInstance; instance_id += blockDim.x){
46 | int instance_total = instance_pointnum[instance_id];
47 | int intersection = 0;
48 | for(int i = start; i < end; i++){
49 | int idx = proposals_idx[i];
50 | if(mask_scores_sigmoid[i] > 0.5){
51 | if((int)instance_labels[idx] == instance_id)
52 | intersection += 1;
53 | }
54 | }
55 | proposals_iou[proposal_id * nInstance + instance_id] = (float)intersection / ((float)(proposal_total + instance_total - intersection) + 1e-5);
56 | }
57 | }
58 | }
59 | }
60 |
61 |
62 | __global__ void get_mask_label_cuda_(int nInstance, int nProposal, int *proposals_idx, int *proposals_offset, long *instance_labels, int *instance_pointnum, float *proposals_iou, float *mask_scores_sigmoid, float *mask_label){
63 | for(int proposal_id = blockIdx.x; proposal_id < nProposal; proposal_id += gridDim.x){
64 | int start = proposals_offset[proposal_id];
65 | int end = proposals_offset[proposal_id + 1];
66 | // int proposal_total = end - start;
67 |
68 | //find the instance with max iou
69 | float max_iou = 0.;
70 | int max_ind = 0;
71 | for(int instance_id = 0; instance_id < nInstance; instance_id++){
72 | if (proposals_iou[proposal_id * nInstance + instance_id] > max_iou) {
73 | max_iou = proposals_iou[proposal_id * nInstance + instance_id];
74 | max_ind = instance_id;
75 | }
76 | }
77 | //mask_label initilized with -1 (-1 means ignored)
78 | if (max_iou > 0.5) {
79 | for(int i = start; i < end; i++){
80 | int idx = proposals_idx[i];
81 | if((int)instance_labels[idx] == max_ind){
82 | mask_label[i] = 1.;
83 | }
84 | else {
85 | mask_label[i] = 0.;
86 | }
87 | }
88 | }
89 | }
90 | }
91 |
92 |
93 | //input: nInstance (1,), int
94 | //input: nProposal (1,), int
95 | //input: proposals_idx (sumNPoint), int
96 | //input: proposals_offset (nProposal + 1), int
97 | //input: instance_labels (N), long, 0~total_nInst-1, -100
98 | //input: instance_pointnum (total_nInst), int
99 | //input: mask_scores_sigmoid (sumNPoint, 1), float
100 | //output: proposals_iou (nProposal, total_nInst), float
101 | //output: mask_label (sumNPoint, 1), float
102 | void cal_iou_and_masklabel_cuda(int nInstance, int nProposal, int *proposals_idx, int *proposals_offset, long *instance_labels, int *instance_pointnum, float *proposals_iou, float *mask_scores_sigmoid, float *mask_label, int mode){
103 | get_iou_mask_cuda_<<<std::min(nProposal, MAX_BLOCKS_PER_GRID), std::min(nInstance, MAX_THREADS_PER_BLOCK)>>>(nInstance, nProposal, proposals_idx, proposals_offset, instance_labels, instance_pointnum, proposals_iou, mask_scores_sigmoid, mask_label, mode);
104 | cudaDeviceSynchronize();
105 | get_mask_label_cuda_<<<std::min(nProposal, MAX_BLOCKS_PER_GRID), std::min(nInstance, MAX_THREADS_PER_BLOCK)>>>(nInstance, nProposal, proposals_idx, proposals_offset, instance_labels, instance_pointnum, proposals_iou, mask_scores_sigmoid, mask_label);
106 |
107 | }
--------------------------------------------------------------------------------
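
A pure-PyTorch reference for the mode-0 IoU computed by get_iou_mask_cuda_, useful for checking the op on small inputs (a sketch, assuming proposals_idx holds point indices into instance_labels, as the kernel uses it):

import torch

def proposal_instance_iou(proposals_idx, proposals_offset, instance_labels, instance_pointnum):
    nProposal = proposals_offset.numel() - 1
    nInstance = instance_pointnum.numel()
    iou = torch.zeros(nProposal, nInstance)
    for p in range(nProposal):
        s, e = int(proposals_offset[p]), int(proposals_offset[p + 1])
        labels = instance_labels[proposals_idx[s:e].long()]
        for i in range(nInstance):
            inter = (labels == i).sum().item()
            union = (e - s) + instance_pointnum[i].item() - inter
            iou[p, i] = inter / (union + 1e-5)
    return iou
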
/lib/hais_ops/src/cal_iou_and_masklabel/cal_iou_and_masklabel.h:
--------------------------------------------------------------------------------
1 | /*
2 | Get the IoU between predictions and gt masks
3 | */
4 |
5 | #ifndef CAL_IOU_AND_MASKLABEL_H
6 | #define CAL_IOU_AND_MASKLABEL_H
7 | #include <torch/serialize/tensor.h>
8 | #include <ATen/cuda/CUDAContext.h>
9 |
10 | #include "../datatype/datatype.h"
11 |
12 | //
13 | void cal_iou_and_masklabel_cuda(int nInstance, int nProposal, int *proposals_idx, int *proposals_offset, \
14 | long *instance_labels, int *instance_pointnum, float *proposals_iou, float *mask_scores_sigmoid, float *mask_label, int mode);
15 |
16 | void cal_iou_and_masklabel(at::Tensor proposals_idx_tensor, at::Tensor proposals_offset_tensor, \
17 | at::Tensor instance_labels_tensor, at::Tensor instance_pointnum_tensor, \
18 | at::Tensor proposals_iou_tensor, int nInstance, int nProposal, at::Tensor mask_scores_sigmoid_tensor, at::Tensor mask_labels_tensor, int mode);
19 |
20 | #endif //CAL_IOU_AND_MASKLABEL_H
--------------------------------------------------------------------------------
/lib/hais_ops/src/cuda.cu:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | #include "datatype/datatype.h"
3 |
4 | #include "hierarchical_aggregation/hierarchical_aggregation.cu"
5 | #include "cal_iou_and_masklabel/cal_iou_and_masklabel.cu"
6 |
7 | #include "voxelize/voxelize.cu"
8 | #include "bfs_cluster/bfs_cluster.cu"
9 | #include "roipool/roipool.cu"
10 | #include "get_iou/get_iou.cu"
11 | #include "sec_mean/sec_mean.cu"
12 |
13 | template void voxelize_fp_cuda<float>(Int nOutputRows, Int maxActive, Int nPlanes, float *feats, float *output_feats, Int *rules, bool average);
14 |
15 | template void voxelize_bp_cuda<float>(Int nOutputRows, Int maxActive, Int nPlanes, float *d_output_feats, float *d_feats, Int *rules, bool average);
--------------------------------------------------------------------------------
/lib/hais_ops/src/cuda_utils.h:
--------------------------------------------------------------------------------
1 | #ifndef _CUDA_UTILS_H
2 | #define _CUDA_UTILS_H
3 |
4 | #include <cmath>
5 |
6 | #define TOTAL_THREADS 1024
7 |
8 | #define MAX_THREADS_PER_BLOCK 512
9 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
10 |
11 | inline int opt_n_threads(int work_size) {
12 | const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
13 | return max(min(1 << pow_2, TOTAL_THREADS), 1);
14 | }
15 |
16 | inline dim3 opt_block_config(int x, int y) {
17 | const int x_threads = opt_n_threads(x);
18 | const int y_threads = max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
19 | dim3 block_config(x_threads, y_threads, 1);
20 | return block_config;
21 | }
22 |
23 | #endif
--------------------------------------------------------------------------------
/lib/hais_ops/src/datatype/datatype.cpp:
--------------------------------------------------------------------------------
1 | #include "datatype.h"
2 |
3 | template <Int dimension> SparseGrid<dimension>::SparseGrid() : ctr(0) {
4 | // Sparsehash needs a key to be set aside and never used
5 | Point empty_key;
6 | for(Int i = 0; i < dimension; i++){
7 | empty_key[i] = std::numeric_limits<Int>::min();
8 | }
9 | mp.set_empty_key(empty_key);
10 | }
11 |
12 | ConnectedComponent::ConnectedComponent(){}
13 |
14 | void ConnectedComponent::addPoint(Int pt_idx){
15 | pt_idxs.push_back(pt_idx);
16 | }
17 |
--------------------------------------------------------------------------------
/lib/hais_ops/src/datatype/datatype.h:
--------------------------------------------------------------------------------
1 | #ifndef DATATYPE_H
2 | #define DATATYPE_H
3 | #include <array>
4 | #include <vector>
5 | #include <queue>
6 | #include <google/dense_hash_map>
7 | #include <cstdint>
8 |
9 | using Int = int32_t;
10 |
11 | template <Int dimension> using Point = std::array<Int, dimension>;
12 |
13 | template <Int dimension> struct IntArrayHash{
14 | std::size_t operator()(Point<dimension> const &p) const{
15 | Int hash = 16777619;
16 | for(auto x : p){
17 | hash *= 2166136261;
18 | hash ^= x;
19 | }
20 | return hash;
21 | }
22 | };
23 |
24 | template <Int dimension> using SparseGridMap = google::dense_hash_map<Point<dimension>, Int, IntArrayHash<dimension>, std::equal_to<Point<dimension>>>;
25 |
26 | template <Int dimension> class SparseGrid{
27 | public:
28 | Int ctr;
29 | SparseGridMap<dimension> mp;
30 | SparseGrid();
31 | };
32 |
33 | template <Int dimension> using SparseGrids = std::vector<SparseGrid<dimension>>;
34 |
35 | using RuleBook = std::vector<std::vector<Int>>;
36 |
37 |
38 | class ConnectedComponent{
39 | public:
40 | std::vector<Int> pt_idxs;
41 | float accum_x = 0.;
42 | float accum_y = 0.;
43 | float accum_z = 0.;
44 | int cls_label = -100;
45 | int batch_idx = -1;
46 | // int npoint = 0;
47 |
48 | ConnectedComponent();
49 | void addPoint(Int pt_idx);
50 | };
51 |
52 | using ConnectedComponents = std::vector;
53 |
54 | #endif //DATATYPE_H
--------------------------------------------------------------------------------
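
SparseGridMap above is essentially a hash map from integer voxel coordinates to a running voxel index; a minimal Python analogue of how voxelize_idx uses it to build the point-to-voxel input_map (a sketch, assuming coords rows are (batch, x, y, z) integers as in the "long N*4" layout):

def build_voxel_maps(coords):
    voxel_of = {}    # (b, x, y, z) -> voxel id, the role SparseGridMap plays per batch
    input_map = []   # point id -> voxel id (the "N points -> M voxels" map)
    for c in coords:
        key = tuple(int(v) for v in c)
        if key not in voxel_of:
            voxel_of[key] = len(voxel_of)
        input_map.append(voxel_of[key])
    return voxel_of, input_map
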
/lib/hais_ops/src/get_iou/get_iou.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | Get the IoU between predictions and gt masks
3 | Written by Li Jiang
4 | All Rights Reserved 2020.
5 | */
6 |
7 | #include "get_iou.h"
8 |
9 | void get_iou(at::Tensor proposals_idx_tensor, at::Tensor proposals_offset_tensor, at::Tensor instance_labels_tensor, at::Tensor instance_pointnum_tensor, at::Tensor proposals_iou_tensor, int nInstance, int nProposal){
10 | int *proposals_idx = proposals_idx_tensor.data<int>();
11 | int *proposals_offset = proposals_offset_tensor.data<int>();
12 | long *instance_labels = instance_labels_tensor.data<long>();
13 | int *instance_pointnum = instance_pointnum_tensor.data<int>();
14 |
15 | float *proposals_iou = proposals_iou_tensor.data<float>();
16 |
17 | get_iou_cuda(nInstance, nProposal, proposals_idx, proposals_offset, instance_labels, instance_pointnum, proposals_iou);
18 | }
--------------------------------------------------------------------------------
/lib/hais_ops/src/get_iou/get_iou.cu:
--------------------------------------------------------------------------------
1 | /*
2 | Get the IoU between predictions and gt masks
3 | Written by Li Jiang
4 | All Rights Reserved 2020.
5 | */
6 |
7 | #include <math.h>
8 | #include <stdio.h>
9 | #include "get_iou.h"
10 |
11 |
12 | __global__ void get_iou_cuda_(int nInstance, int nProposal, int *proposals_idx, int *proposals_offset, long *instance_labels, int *instance_pointnum, float *proposals_iou){
13 | for(int proposal_id = blockIdx.x; proposal_id < nProposal; proposal_id += gridDim.x){
14 | int start = proposals_offset[proposal_id];
15 | int end = proposals_offset[proposal_id + 1];
16 | int proposal_total = end - start;
17 | for(int instance_id = threadIdx.x; instance_id < nInstance; instance_id += blockDim.x){
18 | int instance_total = instance_pointnum[instance_id];
19 | int intersection = 0;
20 | for(int i = start; i < end; i++){
21 | int idx = proposals_idx[i];
22 | if((int)instance_labels[idx] == instance_id){
23 | intersection += 1;
24 | }
25 | }
26 | proposals_iou[proposal_id * nInstance + instance_id] = (float)intersection / ((float)(proposal_total + instance_total - intersection) + 1e-5);
27 | }
28 | }
29 | }
30 |
31 | //input: proposals_idx (sumNPoint), int
32 | //input: proposals_offset (nProposal + 1), int
33 | //input: instance_labels (N), long, 0~total_nInst-1, -100
34 | //input: instance_pointnum (total_nInst), int
35 | //output: proposals_iou (nProposal, total_nInst), float
36 | void get_iou_cuda(int nInstance, int nProposal, int *proposals_idx, int *proposals_offset, long *instance_labels, int *instance_pointnum, float *proposals_iou){
37 | get_iou_cuda_<<<std::min(nProposal, (int)32768), std::min(nInstance, (int)512)>>>(nInstance, nProposal, proposals_idx, proposals_offset, instance_labels, instance_pointnum, proposals_iou);
38 | }
--------------------------------------------------------------------------------
/lib/hais_ops/src/get_iou/get_iou.h:
--------------------------------------------------------------------------------
1 | /*
2 | Get the IoU between predictions and gt masks
3 | Written by Li Jiang
4 | All Rights Reserved 2020.
5 | */
6 |
7 | #ifndef GET_IOU_H
8 | #define GET_IOU_H
9 | #include <torch/serialize/tensor.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 |
12 | #include "../datatype/datatype.h"
13 |
14 | //
15 | void get_iou_cuda(int nInstance, int nProposal, int *proposals_idx, int *proposals_offset, long *instance_labels, int *instance_pointnum, float *proposals_iou);
16 | void get_iou(at::Tensor proposals_idx_tensor, at::Tensor proposals_offset_tensor, at::Tensor instance_labels_tensor, at::Tensor instance_pointnum_tensor, at::Tensor proposals_iou_tensor, int nInstance, int nProposal);
17 |
18 | #endif //GET_IOU_H
--------------------------------------------------------------------------------
/lib/hais_ops/src/hais_ops.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <torch/serialize/tensor.h>
3 | #include <ATen/cuda/CUDAContext.h>
4 |
5 | #include "datatype/datatype.cpp"
6 |
7 | #include "hierarchical_aggregation/hierarchical_aggregation.cpp"
8 | #include "cal_iou_and_masklabel/cal_iou_and_masklabel.cpp"
9 |
10 | #include "voxelize/voxelize.cpp"
11 | #include "bfs_cluster/bfs_cluster.cpp"
12 | #include "roipool/roipool.cpp"
13 | #include "get_iou/get_iou.cpp"
14 | #include "sec_mean/sec_mean.cpp"
15 |
16 | void voxelize_idx_3d(/* long N*4 */ at::Tensor coords, /* long M*4 */ at::Tensor output_coords,
17 | /* Int N */ at::Tensor input_map, /* Int M*(maxActive+1) */ at::Tensor output_map, Int batchSize, Int mode){
18 | voxelize_idx<3>(coords, output_coords, input_map, output_map, batchSize, mode);
19 | }
20 |
21 | void voxelize_fp_feat(/* cuda float N*C */ at::Tensor feats, // N * 3 -> M * 3 (N >= M)
22 | /* cuda float M*C */ at::Tensor output_feats,
23 | /* cuda Int M*(maxActive+1) */ at::Tensor output_map, Int mode, Int nActive, Int maxActive, Int nPlane){
24 | voxelize_fp<float>(feats, output_feats, output_map, mode, nActive, maxActive, nPlane);
25 | }
26 |
27 |
28 | void voxelize_bp_feat(/* cuda float M*C */ at::Tensor d_output_feats, /* cuda float N*C */ at::Tensor d_feats, /* cuda Int M*(maxActive+1) */ at::Tensor output_map,
29 | Int mode, Int nActive, Int maxActive, Int nPlane){
30 | voxelize_bp<float>(d_output_feats, d_feats, output_map, mode, nActive, maxActive, nPlane);
31 | }
32 |
33 | void point_recover_fp_feat(/* cuda float M*C */ at::Tensor feats, /* cuda float N*C */ at::Tensor output_feats, /* cuda Int M*(maxActive+1) */ at::Tensor idx_map,
34 | Int nActive, Int maxActive, Int nPlane){
35 | point_recover_fp<float>(feats, output_feats, idx_map, nActive, maxActive, nPlane);
36 | }
37 |
38 | void point_recover_bp_feat(/* cuda float N*C */ at::Tensor d_output_feats, /* cuda float M*C */ at::Tensor d_feats, /* cuda Int M*(maxActive+1) */ at::Tensor idx_map,
39 | Int nActive, Int maxActive, Int nPlane){
40 | point_recover_bp<float>(d_output_feats, d_feats, idx_map, nActive, maxActive, nPlane);
41 | }
42 |
--------------------------------------------------------------------------------
/lib/hais_ops/src/hais_ops.h:
--------------------------------------------------------------------------------
1 | #ifndef HAIS_H
2 | #define HAIS_H
3 | #include "datatype/datatype.h"
4 |
5 | #include "hierarchical_aggregation/hierarchical_aggregation.h"
6 | #include "cal_iou_and_masklabel/cal_iou_and_masklabel.h"
7 |
8 | #include "bfs_cluster/bfs_cluster.h"
9 | #include "roipool/roipool.h"
10 | #include "get_iou/get_iou.h"
11 | #include "sec_mean/sec_mean.h"
12 |
13 | void voxelize_idx_3d(/* long N*4 */ at::Tensor coords, /* long M*4 */ at::Tensor output_coords,
14 | /* Int N */ at::Tensor input_map, /* Int M*(maxActive+1) */ at::Tensor output_map, Int batchSize, Int mode);
15 |
16 | void voxelize_fp_feat(/* cuda float N*C */ at::Tensor feats, // N * 3 -> M * 3 (N >= M)
17 | /* cuda float M*C */ at::Tensor output_feats,
18 | /* cuda Int M*(maxActive+1) */ at::Tensor output_map, Int mode, Int nActive, Int maxActive, Int nPlane);
19 |
20 | void voxelize_bp_feat(/* cuda float M*C */ at::Tensor d_output_feats, /* cuda float N*C */ at::Tensor d_feats, /* cuda Int M*(maxActive+1) */ at::Tensor output_map,
21 | Int mode, Int nActive, Int maxActive, Int nPlane);
22 |
23 | void point_recover_fp_feat(/* cuda float M*C */ at::Tensor feats, /* cuda float N*C */ at::Tensor output_feats, /* cuda Int M*(maxActive+1) */ at::Tensor idx_map,
24 | Int nActive, Int maxActive, Int nPlane);
25 |
26 | void point_recover_bp_feat(/* cuda float N*C */ at::Tensor d_output_feats, /* cuda float M*C */ at::Tensor d_feats, /* cuda Int M*(maxActive+1) */ at::Tensor idx_map,
27 | Int nActive, Int maxActive, Int nPlane);
28 |
29 |
30 | #endif // HAIS_H
--------------------------------------------------------------------------------
/lib/hais_ops/src/hais_ops_api.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/serialize/tensor.h>
2 | #include <torch/extension.h>
3 |
4 | #include "hais_ops.h"
5 |
6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
7 |
8 | m.def("hierarchical_aggregation", &hierarchical_aggregation, "hierarchical_aggregation");
9 | m.def("cal_iou_and_masklabel", &cal_iou_and_masklabel, "cal_iou_and_masklabel");
10 |
11 |
12 | m.def("voxelize_idx", &voxelize_idx_3d, "voxelize_idx");
13 | m.def("voxelize_fp", &voxelize_fp_feat, "voxelize_fp");
14 | m.def("voxelize_bp", &voxelize_bp_feat, "voxelize_bp");
15 | m.def("point_recover_fp", &point_recover_fp_feat, "point_recover_fp");
16 | m.def("point_recover_bp", &point_recover_bp_feat, "point_recover_bp");
17 |
18 | m.def("ballquery_batch_p", &ballquery_batch_p, "ballquery_batch_p");
19 | m.def("bfs_cluster", &bfs_cluster, "bfs_cluster");
20 |
21 | m.def("roipool_fp", &roipool_fp, "roipool_fp");
22 | m.def("roipool_bp", &roipool_bp, "roipool_bp");
23 |
24 | m.def("get_iou", &get_iou, "get_iou");
25 |
26 | m.def("sec_mean", &sec_mean, "sec_mean");
27 | m.def("sec_min", &sec_min, "sec_min");
28 | m.def("sec_max", &sec_max, "sec_max");
29 | }
--------------------------------------------------------------------------------
/lib/hais_ops/src/hierarchical_aggregation/hierarchical_aggregation.cpp:
--------------------------------------------------------------------------------
1 | #include "hierarchical_aggregation.h"
2 | #include "time.h"
3 |
4 | /* ================================== hierarchical_aggregation ================================== */
5 |
6 | // instance point num for each class, statistical data from the training set
7 | float class_numpoint_mean_dict[20] = {-1., -1., 3917., 12056., 2303., \
8 | 8331., 3948., 3166., 5629., 11719., \
9 | 1003., 3317., 4912., 10221., 3889., \
10 | 4136., 2120., 945., 3967., 2589.};
11 |
12 | ConnectedComponent find_cc(int idx, int *semantic_label, float *coord_shift, int *batch_idxs,
13 | int *ball_query_idxs, int *start_len, int *visited){
14 | ConnectedComponent cc;
15 | cc.addPoint(idx);
16 | cc.accum_x += coord_shift[idx * 3 + 0];
17 | cc.accum_y += coord_shift[idx * 3 + 1];
18 | cc.accum_z += coord_shift[idx * 3 + 2];
19 | cc.cls_label = semantic_label[idx]; // currently cc's label is the label of the start point, convert to float
20 | cc.batch_idx = batch_idxs[idx]; // record batch info
21 | visited[idx] = 1;
22 | std::queue<int> Q;
23 | assert(Q.empty());
24 | Q.push(idx);
25 | while(!Q.empty()){
26 | int cur = Q.front(); Q.pop();
27 | int start = start_len[cur * 2];
28 | int len = start_len[cur * 2 + 1];
29 | int label_cur = semantic_label[cur];
30 | for(int i = start; i < start + len; i++){
31 | int idx_i = ball_query_idxs[i];
32 | if (semantic_label[idx_i] != label_cur) continue;
33 | if (visited[idx_i] == 1) continue;
34 | cc.addPoint(idx_i);
35 | cc.accum_x += coord_shift[idx_i * 3 + 0];
36 | cc.accum_y += coord_shift[idx_i * 3 + 1];
37 | cc.accum_z += coord_shift[idx_i * 3 + 2];
38 | visited[idx_i] = 1;
39 | Q.push(idx_i);
40 | }
41 | }
42 | return cc;
43 | }
44 |
45 | // split clusters into fragment and primary based on point num
46 | void split_clusters(int *semantic_label, float *coord_shift, int *batch_idxs,
47 | int *ball_query_idxs, int *start_len, const int nPoint,
48 | ConnectedComponents &CCs_fragment, ConnectedComponents &CCs_kept, ConnectedComponents &CCs_primary,
49 | int *sumNPoint_fragment, int *sumNPoint_kept, int *sumNPoint_primary){
50 | int visited[nPoint] = {0};
51 | int _class_idx;
52 | float _class_numpoint_mean, low_thre, high_thre;
53 |
54 | for(int i = 0; i < nPoint; i++){
55 | if (visited[i] == 0){
56 | ConnectedComponent CC = find_cc(i, semantic_label, coord_shift, batch_idxs,
57 | ball_query_idxs, start_len, visited);
58 | _class_idx = CC.cls_label;
59 | _class_numpoint_mean = class_numpoint_mean_dict[_class_idx];
60 |
61 | low_thre = 0.05 * _class_numpoint_mean;
62 | high_thre = 0.3 * _class_numpoint_mean;
63 |
64 | if ((int)CC.pt_idxs.size() < high_thre){
65 | CCs_fragment.push_back(CC);
66 | *sumNPoint_fragment += (int)CC.pt_idxs.size();
67 |
68 | // keep fragments which are large enough to be independent instances
69 | if ((int)CC.pt_idxs.size() >= low_thre && (int)CC.pt_idxs.size() < high_thre){
70 | CCs_kept.push_back(CC);
71 | *sumNPoint_kept += (int)CC.pt_idxs.size();
72 | }
73 | }
74 | else {
75 | CCs_primary.push_back(CC);
76 | *sumNPoint_primary += (int)CC.pt_idxs.size();
77 | }
78 | }
79 | }
80 | return;
81 | }
82 |
83 | // convert from ConnectedComponents to (idxs, offsets) representation
84 | void fill_cluster_idxs_(ConnectedComponents &CCs, int *cluster_idxs, int *cluster_offsets, float *cluster_centers){
85 | for(int i = 0; i < (int)CCs.size(); i++){
86 | cluster_offsets[i + 1] = cluster_offsets[i] + (int)CCs[i].pt_idxs.size();
87 |
88 | cluster_centers[i * 5 + 0] = CCs[i].accum_x / (float)CCs[i].pt_idxs.size();
89 | cluster_centers[i * 5 + 1] = CCs[i].accum_y / (float)CCs[i].pt_idxs.size();
90 | cluster_centers[i * 5 + 2] = CCs[i].accum_z / (float)CCs[i].pt_idxs.size();
91 | cluster_centers[i * 5 + 3] = (float)CCs[i].cls_label;
92 | cluster_centers[i * 5 + 4] = (float)CCs[i].batch_idx;
93 |
94 | for(int j = 0; j < (int)CCs[i].pt_idxs.size(); j++){
95 | int idx = CCs[i].pt_idxs[j];
96 | cluster_idxs[(cluster_offsets[i] + j) * 2 + 0] = i;
97 | cluster_idxs[(cluster_offsets[i] + j) * 2 + 1] = idx;
98 | }
99 | }
100 | }
101 |
102 | //input: semantic_label, int, (N)
103 | //input: coord_shift, float, (N, 3)
104 | //input: batch_idxs, int, (N)
105 | //input: ball_query_idxs, int, (nActive)
106 | //input: start_len, int, (N, 2)
107 | //(fragment_idxs, fragment_offsets, fragment_centers) for fragment clusters
108 | //(cluster_idxs_kept_tensor, cluster_offsets_kept_tensor, cluster_centers_kept_tensor) for keeping some fragments
109 | //(primary_idxs_tensor, primary_offsets, primary_centers) for primary clusters
110 | //(primary_idxs_post_tensor, primary_offsets_post_tensor) for aggregated clusters
111 | void hierarchical_aggregation(at::Tensor semantic_label_tensor, at::Tensor coord_shift_tensor, at::Tensor batch_idxs_tensor,
112 | at::Tensor ball_query_idxs_tensor, at::Tensor start_len_tensor,
113 | at::Tensor fragment_idxs_tensor, at::Tensor fragment_offsets_tensor, at::Tensor fragment_centers_tensor,
114 | at::Tensor cluster_idxs_kept_tensor, at::Tensor cluster_offsets_kept_tensor, at::Tensor cluster_centers_kept_tensor,
115 | at::Tensor primary_idxs_tensor, at::Tensor primary_offsets_tensor, at::Tensor primary_centers_tensor,
116 | at::Tensor primary_idxs_post_tensor, at::Tensor primary_offsets_post_tensor,
117 | const int N, const int training_mode_, const int using_set_aggr_){
118 | int *semantic_label = semantic_label_tensor.data<int>();
119 | float *coord_shift = coord_shift_tensor.data<float>();
120 | int *batch_idxs = batch_idxs_tensor.data<int>();
121 | int *ball_query_idxs = ball_query_idxs_tensor.data<int>();
122 | int *start_len = start_len_tensor.data<int>();
123 |
124 | ConnectedComponents CCs_fragment;
125 | ConnectedComponents CCs_kept;
126 | ConnectedComponents CCs_primary;
127 |
128 | int sumNPoint_fragment = 0, sumNPoint_kept = 0, sumNPoint_primary = 0;
129 | split_clusters(semantic_label, coord_shift, batch_idxs, ball_query_idxs, start_len, N,
130 | CCs_fragment, CCs_kept, CCs_primary,
131 | & sumNPoint_fragment, & sumNPoint_kept, & sumNPoint_primary);
132 |
133 | cluster_idxs_kept_tensor.resize_({sumNPoint_kept, 2});
134 | cluster_offsets_kept_tensor.resize_({(int)CCs_kept.size() + 1});
135 | cluster_centers_kept_tensor.resize_({(int)CCs_kept.size(), 5});
136 | cluster_idxs_kept_tensor.zero_();
137 | cluster_offsets_kept_tensor.zero_();
138 | cluster_centers_kept_tensor.zero_();
139 | int *cluster_idxs_kept = cluster_idxs_kept_tensor.data<int>();
140 | int *cluster_offsets_kept = cluster_offsets_kept_tensor.data<int>();
141 | float *cluster_centers_kept = cluster_centers_kept_tensor.data<float>();
142 | fill_cluster_idxs_(CCs_kept, cluster_idxs_kept, cluster_offsets_kept, cluster_centers_kept);
143 |
144 | primary_idxs_tensor.resize_({sumNPoint_primary, 2});
145 | primary_offsets_tensor.resize_({(int)CCs_primary.size() + 1});
146 | primary_centers_tensor.resize_({(int)CCs_primary.size(), 5});
147 | primary_idxs_tensor.zero_();
148 | primary_offsets_tensor.zero_();
149 | primary_centers_tensor.zero_();
150 | int *primary_idxs = primary_idxs_tensor.data<int>();
151 | int *primary_offsets = primary_offsets_tensor.data<int>();
152 | float *primary_centers = primary_centers_tensor.data<float>();
153 | fill_cluster_idxs_(CCs_primary, primary_idxs, primary_offsets, primary_centers);
154 |
155 | if (using_set_aggr_ == 0) { // only point aggr
156 | return;
157 | }
158 |
159 | fragment_idxs_tensor.resize_({sumNPoint_fragment, 2});
160 | fragment_offsets_tensor.resize_({(int)CCs_fragment.size() + 1});
161 | fragment_centers_tensor.resize_({(int)CCs_fragment.size(), 5}); //[:, -2] for cls_label, [:, -1] for batch_idx
162 | fragment_idxs_tensor.zero_();
163 | fragment_offsets_tensor.zero_();
164 | fragment_centers_tensor.zero_();
165 | int *fragment_idxs = fragment_idxs_tensor.data<int>();
166 | int *fragment_offsets = fragment_offsets_tensor.data<int>();
167 | float *fragment_centers = fragment_centers_tensor.data<float>();
168 | fill_cluster_idxs_(CCs_fragment, fragment_idxs, fragment_offsets, fragment_centers);
169 |
170 |
171 | // prepare tensors for storing the post-aggregation primary clusters
172 | primary_idxs_post_tensor.resize_({sumNPoint_fragment + sumNPoint_primary, 2}); //never overflow, but need to cut off tails
173 | primary_offsets_post_tensor.resize_({(int)CCs_primary.size() + 1});
174 | primary_idxs_post_tensor.zero_();
175 | primary_offsets_post_tensor.zero_();
176 | int *primary_idxs_post = primary_idxs_post_tensor.data<int>();
177 | int *primary_offsets_post = primary_offsets_post_tensor.data<int>();
178 |
179 | // set aggr
180 | hierarchical_aggregation_cuda(sumNPoint_fragment, (int)CCs_fragment.size(), fragment_idxs, fragment_offsets, fragment_centers,
181 | sumNPoint_primary, (int)CCs_primary.size(), primary_idxs, primary_offsets, primary_centers,
182 | primary_idxs_post, primary_offsets_post);
183 |
184 | }
--------------------------------------------------------------------------------
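
The split in split_clusters is driven by two per-class thresholds on cluster size; a compact restatement in Python (thresholds taken from the code above, per-class mean sizes from class_numpoint_mean_dict):

def split_role(num_points, class_numpoint_mean):
    low, high = 0.05 * class_numpoint_mean, 0.3 * class_numpoint_mean
    if num_points >= high:
        return 'primary'              # large enough to act as an instance skeleton
    if num_points >= low:
        return 'fragment, also kept'  # absorbed by a primary, but also kept as its own candidate
    return 'fragment'                 # small: only eligible for absorption
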
/lib/hais_ops/src/hierarchical_aggregation/hierarchical_aggregation.cu:
--------------------------------------------------------------------------------
1 | #include "hierarchical_aggregation.h"
2 | #include <math.h>
3 | #include <stdio.h>
4 | #include <stdlib.h>
5 | #include <assert.h>
6 | #include <algorithm>
7 |
8 | #define MAX_PRIMARY_NUM 1024
9 | #define MAX_PER_PRIMARY_ABSORB_FRAGMENT_NUM 1024
10 | #define INFINITY_DIS_SQUARE 10000
11 | #define MAX_PER_PRIMARY_ABSORB_POINT_NUM 8192
12 | #define MAX_THREADS_PER_BLOCK 512
13 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
14 |
15 |
16 |
17 | // input: cuda_fragment_centers (fragment_num * 5,), 5 for (x, y, z, cls_label, batch_idx)
18 | // input: cuda_primary_centers (primary_num * 5,), 5 for (x, y, z, cls_label, batch_idx)
19 | // input: ...
20 | // output: cuda_primary_absorb_fragment_idx
21 | // output: cuda_primary_absorb_fragment_cnt
22 | __global__ void fragment_find_primary_(int primary_num, int *cuda_primary_offsets, float *cuda_primary_centers,
23 | int fragment_num, int *cuda_fragment_offsets, float *cuda_fragment_centers,
24 | int *cuda_primary_absorb_fragment_idx, int *cuda_primary_absorb_fragment_cnt){
25 |
26 | int fragment_idx = blockIdx.x * blockDim.x + threadIdx.x;
27 | if (fragment_idx >= fragment_num) return;
28 |
29 | // find the nearest primary for each fragment
30 | float nearest_dis_square = INFINITY_DIS_SQUARE;
31 | int nearest_idx = -1; // primary_idx
32 | for( int i = 0; i < primary_num; i++){
33 | if (abs(cuda_primary_centers[i * 5 + 3] - cuda_fragment_centers[fragment_idx * 5 + 3]) > 0.1){ //judge same cls_label or not
34 | continue;
35 | }
36 | if (abs(cuda_primary_centers[i * 5 + 4] - cuda_fragment_centers[fragment_idx * 5 + 4]) > 0.1){ //judge same batch_idx or not
37 | continue;
38 | }
39 | float temp_dis_square = pow((cuda_primary_centers[i * 5 + 0] - cuda_fragment_centers[fragment_idx * 5 + 0]), 2)
40 | + pow((cuda_primary_centers[i * 5 + 1] - cuda_fragment_centers[fragment_idx * 5 + 1]), 2)
41 | + pow((cuda_primary_centers[i * 5 + 2] - cuda_fragment_centers[fragment_idx * 5 + 2]), 2);
42 | if (temp_dis_square < nearest_dis_square){
43 | nearest_dis_square = temp_dis_square;
44 | nearest_idx = i;
45 | }
46 | }
47 | if (nearest_idx == -1) return; // fragment not belong to any primary
48 |
49 | // r_size
50 | int primary_point_num = cuda_primary_offsets[nearest_idx + 1] - cuda_primary_offsets[nearest_idx];
51 | float r_size = 0.01 * sqrt(float(primary_point_num));
52 |
53 | // r_cls
54 | // instance radius for each class, statistical data from the training set
55 | float class_radius_mean[20] = {-1., -1., 0.7047687683952325, 1.1732690381942337, 0.39644035821116036, \
56 | 1.011516629020215, 0.7260155292902369, 0.8674973999335017, 0.8374931435447094, 1.0454153869133096, \
57 | 0.32879464797430913, 1.1954566226966346, 0.8628817944400078, 1.0416287916782507, 0.6602697958671507, \
58 | 0.8541363897836871, 0.38055290598206537, 0.3011878752684007, 0.7420871812436316, 0.4474268644407741};
59 | int _class_idx = (int)cuda_fragment_centers[fragment_idx * 5 + 3];
60 | float r_cls = class_radius_mean[_class_idx] * 1.;
61 |
62 | // r_set
63 | float r_set = max(r_size, r_cls);
64 |
65 | // judge
66 | if ( nearest_dis_square < r_set * r_set ){
67 | int _offect = atomicAdd(cuda_primary_absorb_fragment_cnt + nearest_idx, 1);
68 | if (_offect < MAX_PER_PRIMARY_ABSORB_FRAGMENT_NUM)
69 | cuda_primary_absorb_fragment_idx[nearest_idx * MAX_PER_PRIMARY_ABSORB_FRAGMENT_NUM + _offect] = fragment_idx;
70 | else {
71 | ;
72 | }
73 | }
74 | }
75 |
76 | // input: ...
77 | // output: cuda_concat_idxs
78 | // output: cuda_concat_point_num,
79 | __global__ void concat_fragments_(
80 | int *cuda_fragment_idxs, int *cuda_fragment_offsets,
81 | int *cuda_primary_idxs, int *cuda_primary_offsets,
82 | int *cuda_primary_absorb_fragment_idx, int *cuda_primary_absorb_fragment_cnt,
83 | int *cuda_concat_idxs, int *cuda_concat_point_num,
84 | int primary_num){
85 |
86 | int primary_idx = blockIdx.x;
87 | if (primary_idx >= primary_num) return;
88 |
89 | int _accu_offset = 0; // unit is point
90 | for (int i=0; i>>(
146 | primary_num, cuda_primary_offsets, cuda_primary_centers,
147 | fragment_num, cuda_fragment_offsets, cuda_fragment_centers,
148 | cuda_primary_absorb_fragment_idx, cuda_primary_absorb_fragment_cnt);
149 | cudaDeviceSynchronize();
150 |
151 | // concatenate fragments belonging to the same primary
152 | int *cuda_concat_idxs;
153 | int *cuda_concat_point_num;
154 | cudaMalloc((void**)&cuda_concat_idxs, primary_num * MAX_PER_PRIMARY_ABSORB_POINT_NUM * 2 * sizeof(int) + sizeof(int));
155 | cudaMalloc((void**)&cuda_concat_point_num, primary_num * sizeof(int) + sizeof(int));
156 | assert(primary_num <= MAX_PRIMARY_NUM);
157 | concat_fragments_<<<primary_num, 1>>>(
158 | cuda_fragment_idxs, cuda_fragment_offsets,
159 | cuda_primary_idxs, cuda_primary_offsets,
160 | cuda_primary_absorb_fragment_idx, cuda_primary_absorb_fragment_cnt,
161 | cuda_concat_idxs, cuda_concat_point_num,
162 | primary_num);
163 | cudaDeviceSynchronize();
164 |
165 | // merge primary instances and fragments
166 | int *concat_point_num = new int [primary_num + 1]; // allocate on host
167 | cudaMemcpy(concat_point_num, cuda_concat_point_num, primary_num * sizeof(int), cudaMemcpyDeviceToHost);
168 | int _accu_offset = 0;
169 | for (int i=0; i < primary_num; i++){
170 | // add primary instances
171 | cudaMemcpy(primary_idxs_post + _accu_offset * 2,
172 | cuda_primary_idxs + primary_offsets[i] * 2,
173 | (primary_offsets[i + 1] - primary_offsets[i]) * 2 * sizeof(int),
174 | cudaMemcpyDeviceToHost);
175 | _accu_offset += (primary_offsets[i + 1] - primary_offsets[i]);
176 |
177 | // add absorbed fragments
178 | cudaMemcpy(primary_idxs_post + _accu_offset * 2,
179 | cuda_concat_idxs + i * MAX_PER_PRIMARY_ABSORB_POINT_NUM * 2,
180 | concat_point_num[i] * 2 * sizeof(int),
181 | cudaMemcpyDeviceToHost);
182 | _accu_offset += concat_point_num[i];
183 |
184 | // writing offsets
185 | primary_offsets_post[i + 1] = _accu_offset;
186 | }
187 | cudaDeviceSynchronize();
188 |
189 | cudaError_t err;
190 | err = cudaGetLastError();
191 | if (cudaSuccess != err) {
192 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
193 | exit(-1);
194 | }
195 | }
--------------------------------------------------------------------------------
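
fragment_find_primary_ assigns each fragment to its nearest same-class, same-batch primary, and absorbs it only within a radius r_set = max(0.01 * sqrt(primary point count), per-class mean radius); a hedged host-side restatement:

import math

def absorbing_primary(frag_center, primary_centers, primary_sizes, class_radius_mean):
    # centers are (x, y, z, cls_label, batch_idx) rows, as in the *_centers tensors
    fx, fy, fz, fcls, fbatch = frag_center
    best, best_d2 = None, float('inf')
    for i, (px, py, pz, pcls, pbatch) in enumerate(primary_centers):
        if abs(pcls - fcls) > 0.1 or abs(pbatch - fbatch) > 0.1:
            continue  # must share semantic class and batch index
        d2 = (px - fx) ** 2 + (py - fy) ** 2 + (pz - fz) ** 2
        if d2 < best_d2:
            best, best_d2 = i, d2
    if best is None:
        return None
    r_set = max(0.01 * math.sqrt(primary_sizes[best]), class_radius_mean[int(fcls)])
    return best if best_d2 < r_set ** 2 else None
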
/lib/hais_ops/src/hierarchical_aggregation/hierarchical_aggregation.h:
--------------------------------------------------------------------------------
1 | /*
2 | Hierarchical Aggregation Algorithm
3 | */
4 |
5 | #ifndef HIERARCHICAL_AGGREGATION_H
6 | #define HIERARCHICAL_AGGREGATION_H
7 | #include <torch/serialize/tensor.h>
8 | #include <ATen/cuda/CUDAContext.h>
9 | #include <queue>
10 |
11 | #include "../datatype/datatype.h"
12 |
13 |
14 | void hierarchical_aggregation(at::Tensor semantic_label_tensor, at::Tensor coord_shift_tensor, at::Tensor batch_idxs_tensor,
15 | at::Tensor ball_query_idxs_tensor, at::Tensor start_len_tensor,
16 | at::Tensor fragment_idxs_tensor, at::Tensor fragment_offsets_tensor, at::Tensor fragment_centers_tensor,
17 | at::Tensor cluster_idxs_kept_tensor, at::Tensor cluster_offsets_kept_tensor, at::Tensor cluster_centers_kept_tensor,
18 | at::Tensor primary_idxs_tensor, at::Tensor primary_offsets_tensor, at::Tensor primary_centers_tensor,
19 | at::Tensor primary_idxs_post_tensor, at::Tensor primary_offsets_post_tensor,
20 | const int N, const int training_mode_, const int using_set_aggr_);
21 |
22 |
23 | void hierarchical_aggregation_cuda(
24 | int fragment_total_point_num, int fragment_num, int *fragment_idxs, int *fragment_offsets, float *fragment_centers,
25 | int primary_total_point_num, int primary_num, int *primary_idxs, int *primary_offsets, float *primary_centers,
26 | int *primary_idxs_post, int *primary_offsets_post
27 | );
28 | #endif //HIERARCHICAL_AGGREGATION_H
29 |
30 |
--------------------------------------------------------------------------------
/lib/hais_ops/src/roipool/roipool.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | ROI Max Pool
3 | Written by Li Jiang
4 | All Rights Reserved 2020.
5 | */
6 |
7 | #include "roipool.h"
8 |
9 | void roipool_fp(at::Tensor feats_tensor, at::Tensor proposals_offset_tensor, at::Tensor output_feats_tensor, at::Tensor output_maxidx_tensor, int nProposal, int C){
10 | float *feats = feats_tensor.data<float>();
11 | int *proposals_offset = proposals_offset_tensor.data<int>();
12 | float *output_feats = output_feats_tensor.data<float>();
13 | int *output_maxidx = output_maxidx_tensor.data<int>();
14 |
15 | roipool_fp_cuda(nProposal, C, feats, proposals_offset, output_feats, output_maxidx);
16 | }
17 |
18 |
19 | void roipool_bp(at::Tensor d_feats_tensor, at::Tensor proposals_offset_tensor, at::Tensor output_maxidx_tensor, at::Tensor d_output_feats_tensor, int nProposal, int C){
20 | float *d_feats = d_feats_tensor.data<float>();
21 | int *proposals_offset = proposals_offset_tensor.data<int>();
22 | int *output_maxidx = output_maxidx_tensor.data<int>();
23 | float *d_output_feats = d_output_feats_tensor.data<float>();
24 |
25 | roipool_bp_cuda(nProposal, C, d_feats, proposals_offset, output_maxidx, d_output_feats);
26 | }
--------------------------------------------------------------------------------
/lib/hais_ops/src/roipool/roipool.cu:
--------------------------------------------------------------------------------
1 | /*
2 | ROI Max Pool
3 | Written by Li Jiang
4 | All Rights Reserved 2020.
5 | */
6 |
7 | #include <stdio.h>
8 | #include <math.h>
9 | #include "roipool.h"
10 |
11 | // fp
12 | __global__ void roipool_fp_cuda_(int nProposal, int C, float *feats, int *proposals_offset, float *output_feats, int *output_maxidx){
13 | for(int pp_id = blockIdx.x; pp_id < nProposal; pp_id += gridDim.x){
14 | int start = proposals_offset[pp_id];
15 | int end = proposals_offset[pp_id + 1];
16 |
17 | for(int plane = threadIdx.x; plane < C; plane += blockDim.x){
18 | int argmax_idx = -1;
19 | float max_val = -1e50;
20 |
21 | for(int i = start; i < end; i++){
22 | if(feats[i * C + plane] > max_val){
23 | argmax_idx = i;
24 | max_val = feats[i * C + plane];
25 | }
26 | }
27 | output_maxidx[pp_id * C + plane] = argmax_idx;
28 | output_feats[pp_id * C + plane] = max_val;
29 | }
30 | }
31 | }
32 |
33 | //input: feats (sumNPoint, C) float
34 | //input: proposals_offset (nProposal + 1) int
35 | //output: output_feats (nProposal, C) float
36 | //output: output_maxidx (nProposal, C) int
37 | void roipool_fp_cuda(int nProposal, int C, float *feats, int *proposals_offset, float *output_feats, int *output_maxidx){
38 | roipool_fp_cuda_<<<std::min(nProposal, (int)32768), std::min(C, (int)32)>>>(nProposal, C, feats, proposals_offset, output_feats, output_maxidx);
39 | }
40 |
41 | // bp
42 | __global__ void roipool_bp_cuda_(int nProposal, int C, float *d_feats, int *proposals_offset, int *output_maxidx, float *d_output_feats){
43 | for(int pp_id = blockIdx.x; pp_id < nProposal; pp_id += gridDim.x){
44 | for(int plane = threadIdx.x; plane < C; plane += blockDim.x){
45 | int argmax_idx = output_maxidx[pp_id * C + plane];
46 | atomicAdd(&d_feats[argmax_idx * C + plane], d_output_feats[pp_id * C + plane]);
47 | }
48 | }
49 | }
50 |
51 | //input: d_output_feats (nProposal, C) float
52 | //input: output_maxidx (nProposal, C) int
53 | //input: proposals_offset (nProposal + 1) int
54 | //output: d_feats (sumNPoint, C) float
55 | void roipool_bp_cuda(int nProposal, int C, float *d_feats, int *proposals_offset, int *output_maxidx, float *d_output_feats){
56 | roipool_bp_cuda_<<<std::min(nProposal, (int)32768), std::min(C, (int)32)>>>(nProposal, C, d_feats, proposals_offset, output_maxidx, d_output_feats);
57 | }
--------------------------------------------------------------------------------
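
A pure-PyTorch reference for roipool_fp (channel-wise max over each proposal's point features, with argmax indices stored as absolute rows, which is what the backward pass scatters gradients into):

import torch

def roipool_fp_reference(feats, proposals_offset):
    # feats: (sumNPoint, C) float; proposals_offset: (nProposal + 1) int
    outs, argmaxes = [], []
    for s, e in zip(proposals_offset[:-1].tolist(), proposals_offset[1:].tolist()):
        vals, idxs = feats[s:e].max(dim=0)
        outs.append(vals)
        argmaxes.append(idxs + s)  # absolute row index into feats, as output_maxidx stores it
    return torch.stack(outs), torch.stack(argmaxes)
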
/lib/hais_ops/src/roipool/roipool.h:
--------------------------------------------------------------------------------
1 | /*
2 | ROI Max Pool
3 | Written by Li Jiang
4 | All Rights Reserved 2020.
5 | */
6 |
7 | #ifndef ROIPOOL_H
8 | #define ROIPOOL_H
9 | #include <torch/serialize/tensor.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 |
12 | #include "../datatype/datatype.h"
13 |
14 | //
15 | void roipool_fp(at::Tensor feats_tensor, at::Tensor proposals_offset_tensor, at::Tensor output_feats_tensor, at::Tensor output_maxidx_tensor, int nProposal, int C);
16 |
17 | void roipool_fp_cuda(int nProposal, int C, float *feats, int *proposals_offset, float *output_feats, int *output_maxidx);
18 |
19 |
20 | //
21 | void roipool_bp(at::Tensor d_feats_tensor, at::Tensor proposals_offset_tensor, at::Tensor output_maxidx_tensor, at::Tensor d_output_feats_tensor, int nProposal, int C);
22 |
23 | void roipool_bp_cuda(int nProposal, int C, float *d_feats, int *proposals_offset, int *output_maxidx, float *d_output_feats);
24 |
25 | #endif //ROIPOOL_H
26 |
--------------------------------------------------------------------------------
/lib/hais_ops/src/sec_mean/sec_mean.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | Segment Operations (mean, max, min)
3 | Written by Li Jiang
4 | All Rights Reserved 2020.
5 | */
6 |
7 | #include "sec_mean.h"
8 |
9 | void sec_mean(at::Tensor inp_tensor, at::Tensor offsets_tensor, at::Tensor out_tensor, int nProposal, int C){
10 | int *offsets = offsets_tensor.data<int>();
11 | float *inp = inp_tensor.data<float>();
12 | float *out = out_tensor.data<float>();
13 |
14 | sec_mean_cuda(nProposal, C, inp, offsets, out);
15 | }
16 |
17 | void sec_min(at::Tensor inp_tensor, at::Tensor offsets_tensor, at::Tensor out_tensor, int nProposal, int C){
18 | int *offsets = offsets_tensor.data<int>();
19 | float *inp = inp_tensor.data<float>();
20 | float *out = out_tensor.data<float>();
21 |
22 | sec_min_cuda(nProposal, C, inp, offsets, out);
23 | }
24 |
25 | void sec_max(at::Tensor inp_tensor, at::Tensor offsets_tensor, at::Tensor out_tensor, int nProposal, int C){
26 | int *offsets = offsets_tensor.data<int>();
27 | float *inp = inp_tensor.data<float>();
28 | float *out = out_tensor.data<float>();
29 |
30 | sec_max_cuda(nProposal, C, inp, offsets, out);
31 | }
--------------------------------------------------------------------------------
/lib/hais_ops/src/sec_mean/sec_mean.cu:
--------------------------------------------------------------------------------
1 | /*
2 | Segment Operations (mean, max, min) (no bp)
3 | Written by Li Jiang
4 | All Rights Reserved 2020.
5 | */
6 |
7 | #include <stdio.h>
8 | #include <math.h>
9 | #include "sec_mean.h"
10 |
11 | /* ================================== sec_mean ================================== */
12 | __global__ void sec_mean_cuda_(int nProposal, int C, float *inp, int *offsets, float *out){
13 | for(int p_id = blockIdx.x; p_id < nProposal; p_id += gridDim.x){
14 | int start = offsets[p_id];
15 | int end = offsets[p_id + 1];
16 |
17 | float count = (float)(end - start);
18 |
19 | for(int plane = threadIdx.x; plane < C; plane += blockDim.x){
20 | float mean = 0;
21 | for(int i = start; i < end; i++){
22 | mean += (inp[i * C + plane] / count);
23 | }
24 | out[p_id * C + plane] = mean;
25 | }
26 | }
27 | }
28 |
29 | //input: inp (N, C) float
30 | //input: offsets (nProposal + 1) int
31 | //output: out (nProposal, C) float
32 | void sec_mean_cuda(int nProposal, int C, float *inp, int *offsets, float *out){
33 | sec_mean_cuda_<<<std::min(nProposal, (int)32768), std::min(C, (int)32)>>>(nProposal, C, inp, offsets, out);
34 | }
35 |
36 |
37 | /* ================================== sec_min ================================== */
38 | __global__ void sec_min_cuda_(int nProposal, int C, float *inp, int *offsets, float *out){
39 | for(int p_id = blockIdx.x; p_id < nProposal; p_id += gridDim.x){
40 | int start = offsets[p_id];
41 | int end = offsets[p_id + 1];
42 |
43 | for(int plane = threadIdx.x; plane < C; plane += blockDim.x){
44 | float min_val = 1e50;
45 | for(int i = start; i < end; i++){
46 | if(inp[i * C + plane] < min_val){
47 | min_val = inp[i * C + plane];
48 | }
49 | }
50 | out[p_id * C + plane] = min_val;
51 | }
52 | }
53 | }
54 |
55 | //input: inp (N, C) float
56 | //input: offsets (nProposal + 1) int
57 | //output: out (nProposal, C) float
58 | void sec_min_cuda(int nProposal, int C, float *inp, int *offsets, float *out){
59 | sec_min_cuda_<<<std::min(nProposal, (int)32768), std::min(C, (int)32)>>>(nProposal, C, inp, offsets, out);
60 | }
61 |
62 |
63 | /* ================================== sec_max ================================== */
64 | __global__ void sec_max_cuda_(int nProposal, int C, float *inp, int *offsets, float *out){
65 | for(int p_id = blockIdx.x; p_id < nProposal; p_id += gridDim.x){
66 | int start = offsets[p_id];
67 | int end = offsets[p_id + 1];
68 |
69 | for(int plane = threadIdx.x; plane < C; plane += blockDim.x){
70 | float max_val = -1e50;
71 | for(int i = start; i < end; i++){
72 | if(inp[i * C + plane] > max_val){
73 | max_val = inp[i * C + plane];
74 | }
75 | }
76 | out[p_id * C + plane] = max_val;
77 | }
78 | }
79 | }
80 |
81 | //input: inp (N, C) float
82 | //input: offsets (nProposal + 1) int
83 | //output: out (nProposal, C) float
84 | void sec_max_cuda(int nProposal, int C, float *inp, int *offsets, float *out){
85 | sec_max_cuda_<<<std::min(nProposal, (int)32768), std::min(C, (int)32)>>>(nProposal, C, inp, offsets, out);
86 | }
--------------------------------------------------------------------------------
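
These segment kernels have straightforward (but loop-heavy) PyTorch equivalents; a sketch that is handy for sanity-checking sec_mean / sec_min / sec_max on small inputs:

import torch

def sec_reduce_reference(inp, offsets, reduce='mean'):
    segs = [inp[s:e] for s, e in zip(offsets[:-1].tolist(), offsets[1:].tolist())]
    if reduce == 'mean':
        return torch.stack([seg.mean(dim=0) for seg in segs])
    if reduce == 'min':
        return torch.stack([seg.min(dim=0).values for seg in segs])
    return torch.stack([seg.max(dim=0).values for seg in segs])
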
/lib/hais_ops/src/sec_mean/sec_mean.h:
--------------------------------------------------------------------------------
1 | /*
2 | Segment Operations (mean, max, min)
3 | Written by Li Jiang
4 | All Rights Reserved 2020.
5 | */
6 |
7 | #ifndef SEC_MEAN_H
8 | #define SEC_MEAN_H
9 | #include <torch/serialize/tensor.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 |
12 | #include "../datatype/datatype.h"
13 |
14 | void sec_mean(at::Tensor inp_tensor, at::Tensor offsets_tensor, at::Tensor out_tensor, int nProposal, int C);
15 | void sec_mean_cuda(int nProposal, int C, float *inp, int *offsets, float *out);
16 |
17 | void sec_min(at::Tensor inp_tensor, at::Tensor offsets_tensor, at::Tensor out_tensor, int nProposal, int C);
18 | void sec_min_cuda(int nProposal, int C, float *inp, int *offsets, float *out);
19 |
20 | void sec_max(at::Tensor inp_tensor, at::Tensor offsets_tensor, at::Tensor out_tensor, int nProposal, int C);
21 | void sec_max_cuda(int nProposal, int C, float *inp, int *offsets, float *out);
22 |
23 |
24 | #endif //SEC_MEAN_H
25 |
--------------------------------------------------------------------------------
/lib/hais_ops/src/voxelize/voxelize.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | Points to Voxels & Voxels to Points (Modified from SparseConv)
3 | Written by Li Jiang
4 | All Rights Reserved 2020.
5 | */
6 |
7 | #include "voxelize.h"
8 |
9 | /* ================================== voxelize_idx ================================== */
10 | template <Int dimension>
11 | void voxelize_idx(/* long N*4 */ at::Tensor coords, /* long M*4 */ at::Tensor output_coords,
12 | /* Int N */ at::Tensor input_map, /* Int M*(maxActive+1) */ at::Tensor output_map, Int batchSize, Int mode){
13 | assert(coords.ndimension() == 2);
14 | assert(coords.size(1) >= dimension and coords.size(1) <= dimension + 1);
15 |
16 | RuleBook voxelizeRuleBook; // rule[1]: M voxels -> N points output_map
17 | SparseGrids<dimension> inputSGs; // voxel_coords -> voxel_idx in M voxels input_map: N points -> M voxels
18 | Int nActive = 0;
19 |
20 | Int maxActive = voxelize_inputmap(inputSGs, input_map.data<Int>(), voxelizeRuleBook, nActive, coords.data<long>(), coords.size(0), coords.size(1), batchSize, mode);
21 |
22 | output_map.resize_({nActive, maxActive + 1});
23 | output_map.zero_();
24 |
25 | output_coords.resize_({nActive, coords.size(1)});
26 | output_coords.zero_();
27 |
28 | Int *oM = output_map.data<Int>();
29 | long *oC = output_coords.data<long>();
30 | voxelize_outputmap<dimension>(coords.data<long>(), oC, oM, &voxelizeRuleBook[1][0], nActive, maxActive);
31 | }
32 |
33 |
34 | template <Int dimension>
35 | void voxelize_outputmap(long *coords, long *output_coords, Int *output_map, Int *rule, Int nOutputRows, Int maxActive){
36 | for(Int i = 0; i < nOutputRows; i++){
37 | for(Int j = 0; j <= maxActive; j++)
38 | output_map[j] = rule[j];
39 | Int inputIdx = rule[1];
40 | rule += (1 + maxActive);
41 | output_map += (1 + maxActive);
42 |
43 | long *coord = coords + inputIdx * (dimension + 1);
44 | long *output_coord = output_coords + i * (dimension + 1);
45 | for(Int j = 0; j <= dimension; j++){
46 | output_coord[j] = coord[j];
47 | }
48 | }
49 | }
50 |
51 | //mode 0=guaranteed unique 1=last item(overwrite) 2=first item(keep) 3=sum, 4=mean
52 | //input: coords
53 | //output: SGs: one map for each batch: map from voxel_coord to voxel_idx(in M voxels)
54 | //output: input_map: N, N points -> M voxels
55 | //output: rules
56 | //output: nActive
57 | //output: maxActive
58 | template <Int dimension>
59 | Int voxelize_inputmap(SparseGrids<dimension> &SGs, Int *input_map, RuleBook &rules, Int &nActive, long *coords, Int nInputRows, Int nInputColumns, Int batchSize, Int mode){
60 | assert(nActive == 0);
61 | assert(rules.size() == 0);
62 | assert(SGs.size() == 0);
63 |
64 | SGs.resize(batchSize);
65 | Point p;
66 |
67 | std::vector<std::vector<Int>> outputRows;
68 | if(nInputColumns == dimension){
69 | SGs.resize(1);
70 | auto &sg = SGs[0];
71 | for(Int i = 0; i < nInputRows; i++){
72 | for(Int j = 0; j < dimension; j++)
73 | p[j] = coords[j];
74 | coords += dimension;
75 | auto iter = sg.mp.find(p);
76 | if (iter == sg.mp.end()){
77 | sg.mp[p] = nActive++;
78 | outputRows.resize(nActive);
79 | }
80 | outputRows[sg.mp[p]].push_back(i);
81 |
82 | input_map[i] = sg.mp[p];
83 | }
84 | }
85 | else{ // nInputColumns == dimension + 1 (1 in index 0 for batchidx)
86 | Int batchIdx;
87 | for(Int i = 0; i < nInputRows; i++){
88 | batchIdx = coords[0];
89 | for(Int j = 0; j < dimension; j++)
90 | p[j] = coords[j + 1];
91 | coords += (dimension + 1);
92 | if(batchIdx + 1 >= (Int)SGs.size()){
93 | SGs.resize(batchIdx + 1);
94 | }
95 | auto &sg = SGs[batchIdx];
96 | auto iter = sg.mp.find(p);
97 | if(iter == sg.mp.end()){
98 | sg.mp[p] = nActive++;
99 | outputRows.resize(nActive);
100 | }
101 | outputRows[sg.mp[p]].push_back(i);
102 |
103 | input_map[i] = sg.mp[p];
104 | }
105 | }
106 |
107 | // Rulebook Format
108 | // rules[0][0] == mode
109 | // rules[0][1] == maxActive per spatial location (==1 for modes 0,1,2)
110 | // rules[0][2] == nInputRows
111 | // rules[0][3] == nOutputRows
112 | // rules[1] nOutputRows x (1+maxActive)
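// e.g. (illustrative) with maxActive = 3, a voxel covering input rows {7, 12} occupies one 1+maxActive slot in rules[1]: [2, 7, 12, 0] (zero-padded for modes 3/4)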
113 | rules.resize(2);
114 | rules[0].push_back(mode);
115 | rules[0].push_back(1);
116 | rules[0].push_back(nInputRows);
117 | rules[0].push_back(outputRows.size());
118 | auto &rule = rules[1];
119 | if(mode == 0){
120 | assert(nInputRows == (Int)outputRows.size());
121 | for(Int i = 0; i < nActive; i++){
122 | rule.push_back(1);
123 | assert((Int)outputRows[i].size() == 1);
124 | rule.push_back(outputRows[i][0]);
125 | }
126 | }
127 | if(mode == 1){
128 | for(Int i = 0; i < nActive; i++){
129 | rule.push_back(1);
130 | rule.push_back(outputRows[i].front());
131 | }
132 | }
133 | if(mode == 2){
134 | for(Int i = 0; i < nActive; i++){
135 | rule.push_back(1);
136 | rule.push_back(outputRows[i].back());
137 | }
138 | }
139 | Int maxActive = 1;
140 | if(mode == 3 or mode == 4){
141 | for(auto &row: outputRows)
142 | maxActive = std::max(maxActive, (Int)row.size());
143 | rules[0][1] = maxActive;
144 | for(auto &row: outputRows){
145 | rule.push_back(row.size());
146 | for(auto &r: row)
147 | rule.push_back(r);
148 | rule.resize((rule.size() + maxActive) / (maxActive + 1) * (maxActive + 1));
149 | }
150 | }
151 | return maxActive;
152 | }
153 |
154 |
155 | /* ================================== voxelize ================================== */
156 | template <typename T>
157 | void voxelize_fp(/* cuda float N*C */ at::Tensor feats, // N * 3 -> M * 3 (N >= M)
158 | /* cuda float M*C */ at::Tensor output_feats,
159 | /* cuda Int M*(maxActive+1) */ at::Tensor output_map, Int mode, Int nActive, Int maxActive, Int nPlane){
160 |
161 | auto iF = feats.data<T>();
162 | auto oF = output_feats.data<T>();
163 |
164 | Int *rules = output_map.data<Int>();
165 |
166 | voxelize_fp_cuda(nActive, maxActive, nPlane, iF, oF, rules, mode==4);
167 | }
168 |
169 | template <typename T>
170 | void voxelize_bp(/* cuda float M*C */ at::Tensor d_output_feats, /* cuda float N*C */ at::Tensor d_feats, /* cuda Int M*(maxActive+1) */ at::Tensor output_map,
171 | Int mode, Int nActive, Int maxActive, Int nPlane){
172 | auto d_oF = d_output_feats.data<T>();
173 | auto d_iF = d_feats.data<T>();
174 |
175 | Int *rules = output_map.data<Int>();
176 |
177 | voxelize_bp_cuda(nActive, maxActive, nPlane, d_oF, d_iF, rules, mode==4);
178 | }
179 |
180 | /* ================================== point_recover ================================== */
181 | template <typename T>
182 | void point_recover_fp(/* cuda float M*C */ at::Tensor feats, /* cuda float N*C */ at::Tensor output_feats, /* cuda Int M*(maxActive+1) */ at::Tensor idx_map,
183 | Int nActive, Int maxActive, Int nPlane){
184 | auto iF = feats.data<T>();
185 | auto oF = output_feats.data<T>();
186 |
187 | Int *rules = idx_map.data<Int>();
188 |
189 | voxelize_bp_cuda(nActive, maxActive, nPlane, iF, oF, rules, false);
190 | }
191 |
192 |
193 | template <typename T>
194 | void point_recover_bp(/* cuda float N*C */ at::Tensor d_output_feats, /* cuda float M*C */ at::Tensor d_feats, /* cuda Int M*(maxActive+1) */ at::Tensor idx_map,
195 | Int nActive, Int maxActive, Int nPlane){
196 | auto d_oF = d_output_feats.data<T>();
197 | auto d_iF = d_feats.data<T>();
198 |
199 | Int *rules = idx_map.data<Int>();
200 |
201 | voxelize_fp_cuda(nActive, maxActive, nPlane, d_oF, d_iF, rules, false);
202 | }
--------------------------------------------------------------------------------
/lib/hais_ops/src/voxelize/voxelize.cu:
--------------------------------------------------------------------------------
1 | /*
2 | Points to Voxels & Voxels to Points (Modified from SparseConv)
3 | Written by Li Jiang
4 | All Rights Reserved 2020.
5 | */
6 |
7 | #include "voxelize.h"
8 |
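// Forward scatter kernel: each block grid-strides over output voxel rows; threads stride over feature planes and atomically add (or average, when `average` is set) the contributing point features into the voxel row.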
9 | template <typename T>
10 | __global__ void voxelize_fp_cuda_(Int nOutputRows, Int maxActive, Int nPlanes, T *feats, T *output_feats, Int *rules, bool average){
11 | for(int row = blockIdx.x; row < nOutputRows; row += gridDim.x){
12 | T *out = output_feats + row * nPlanes;
13 | Int *r = rules + row * (maxActive + 1);
14 | Int nActive = r[0];
15 | T multiplier = (average and nActive > 0) ? (T) 1 / nActive : (T) 1;
16 | for(int i = 1; i <= nActive; i++){
17 | T *inp = feats + r[i] * nPlanes;
18 | for(int plane = threadIdx.x; plane < nPlanes; plane += blockDim.x){
19 | atomicAdd(&out[plane], multiplier * inp[plane]);
20 | }
21 | }
22 | }
23 | }
24 |
25 | // input: feats N * C
26 | // input: rules M * (1 + maxActive)
27 | // output: output_feats M * C
28 | template <typename T>
29 | void voxelize_fp_cuda(Int nOutputRows, Int maxActive, Int nPlanes, T *feats, T *output_feats, Int *rules, bool average){
30 | voxelize_fp_cuda_<T><<<std::min(nOutputRows, (Int)32768), std::min(nPlanes, (Int)32)>>>(nOutputRows, maxActive, nPlanes, feats, output_feats, rules, average); // launch config reconstructed: grid over voxel rows, block over feature planes
31 | }
32 |
33 |
34 | template <typename T>
35 | __global__ void voxelize_bp_cuda_(Int nOutputRows, Int maxActive, Int nPlanes, T *d_output_feats, T *d_feats, Int *rules, bool average){
36 | for(int row = blockIdx.x; row < nOutputRows; row += gridDim.x){
37 | T *out = d_output_feats + row * nPlanes;
38 | Int *r = rules + row * (maxActive + 1);
39 | Int nActive = r[0];
40 | T multiplier = (average and nActive > 0) ? (T) 1 / nActive : (T) 1;
41 | for(int i = 1; i <= nActive; i++){
42 | T *inp = d_feats + r[i] * nPlanes;
43 | for(int plane = threadIdx.x; plane < nPlanes; plane += blockDim.x){
44 | atomicAdd(&inp[plane], multiplier * out[plane]);
45 | }
46 | }
47 | }
48 | }
49 |
50 | template <typename T>
51 | void voxelize_bp_cuda(Int nOutputRows, Int maxActive, Int nPlanes, T *d_output_feats, T *d_feats, Int *rules, bool average){
52 | voxelize_bp_cuda_<T><<<std::min(nOutputRows, (Int)32768), std::min(nPlanes, (Int)32)>>>(nOutputRows, maxActive, nPlanes, d_output_feats, d_feats, rules, average); // launch config reconstructed: grid over voxel rows, block over feature planes
53 | }
54 |
55 |
56 |
57 |
--------------------------------------------------------------------------------
/lib/hais_ops/src/voxelize/voxelize.h:
--------------------------------------------------------------------------------
1 | /*
2 | Points to Voxels & Voxels to Points (Modified from SparseConv)
3 | Written by Li Jiang
4 | All Rights Reserved 2020.
5 | */
6 |
7 | #ifndef VOXELIZE_H
8 | #define VOXELIZE_H
9 | #include <torch/serialize/tensor.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 |
12 | #include "../datatype/datatype.h"
13 |
14 | /* ================================== voxelize_idx ================================== */
15 | template <Int dimension>
16 | void voxelize_idx(/* long N*4 */ at::Tensor coords, /* long M*4 */ at::Tensor output_coords,
17 | /* Int N */ at::Tensor input_map, /* Int M*(maxActive+1) */ at::Tensor output_map, Int batchSize, Int mode);
18 |
19 | template <Int dimension>
20 | void voxelize_outputmap(long *coords, long *output_coords, Int *output_map, Int *rule, Int nOutputRows, Int maxActive);
21 |
22 | template <Int dimension>
23 | Int voxelize_inputmap(SparseGrids &SGs, Int *input_map, RuleBook &rules, Int &nActive, long *coords, Int nInputRows, Int nInputColumns, Int batchSize, Int mode);
24 |
25 | /* ================================== voxelize ================================== */
26 | template <typename T>
27 | void voxelize_fp(/* cuda float N*C */ at::Tensor feats, // N * 3 -> M * 3 (N >= M)
28 | /* cuda float M*C */ at::Tensor output_feats,
29 | /* cuda Int M*(maxActive+1) */ at::Tensor output_map, Int mode, Int nActive, Int maxActive, Int nPlane);
30 |
31 | template <typename T>
32 | void voxelize_fp_cuda(Int nOutputRows, Int maxActive, Int nPlanes, T *feats, T *output_feats, Int *rules, bool average);
33 |
34 |
35 | //
36 | template <typename T>
37 | void voxelize_bp(/* cuda float M*C */ at::Tensor d_output_feats, /* cuda float N*C */ at::Tensor d_feats, /* cuda Int M*(maxActive+1) */ at::Tensor output_map,
38 | Int mode, Int nActive, Int maxActive, Int nPlane);
39 |
40 | template <typename T>
41 | void voxelize_bp_cuda(Int nOutputRows, Int maxActive, Int nPlanes, T *d_output_feats, T *d_feats, Int *rules, bool average);
42 |
43 |
44 | /* ================================== point_recover ================================== */
45 | template <typename T>
46 | void point_recover_fp(/* cuda float M*C */ at::Tensor feats, /* cuda float N*C */ at::Tensor output_feats, /* cuda Int M*(maxActive+1) */ at::Tensor idx_map,
47 | Int nActive, Int maxActive, Int nPlane);
48 |
49 | //
50 | template <typename T>
51 | void point_recover_bp(/* cuda float N*C */ at::Tensor d_output_feats, /* cuda float M*C */ at::Tensor d_feats, /* cuda Int M*(maxActive+1) */ at::Tensor idx_map,
52 | Int nActive, Int maxActive, Int nPlane);
53 |
54 |
55 | #endif //VOXELIZE_H
56 |
--------------------------------------------------------------------------------
/model/hais/hais.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import spconv
4 | from spconv.modules import SparseModule
5 | import functools
6 | from collections import OrderedDict
7 | import sys
8 | sys.path.append('../../')
9 |
10 | from lib.hais_ops.functions import hais_ops
11 | from util import utils
12 |
13 |
14 | class ResidualBlock(SparseModule):
15 | def __init__(self, in_channels, out_channels, norm_fn, indice_key=None):
16 | super().__init__()
17 |
18 | if in_channels == out_channels:
19 | self.i_branch = spconv.SparseSequential(
20 | nn.Identity()
21 | )
22 | else:
23 | self.i_branch = spconv.SparseSequential(
24 | spconv.SubMConv3d(in_channels, out_channels, kernel_size=1, bias=False)
25 | )
26 |
27 | self.conv_branch = spconv.SparseSequential(
28 | norm_fn(in_channels),
29 | nn.ReLU(),
30 | spconv.SubMConv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False, indice_key=indice_key),
31 | norm_fn(out_channels),
32 | nn.ReLU(),
33 | spconv.SubMConv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False, indice_key=indice_key)
34 | )
35 |
36 | def forward(self, input):
37 | identity = spconv.SparseConvTensor(input.features, input.indices, input.spatial_shape, input.batch_size)
38 | output = self.conv_branch(input)
39 | output.features += self.i_branch(identity).features
40 |
41 | return output
42 |
43 |
44 | class VGGBlock(SparseModule):
45 | def __init__(self, in_channels, out_channels, norm_fn, indice_key=None):
46 | super().__init__()
47 |
48 | self.conv_layers = spconv.SparseSequential(
49 | norm_fn(in_channels),
50 | nn.ReLU(),
51 | spconv.SubMConv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False, indice_key=indice_key)
52 | )
53 |
54 | def forward(self, input):
55 | return self.conv_layers(input)
56 |
57 |
58 | class UBlock(nn.Module):
59 | def __init__(self, nPlanes, norm_fn, block_reps, block, indice_key_id=1):
60 |
61 | super().__init__()
62 |
63 | self.nPlanes = nPlanes
64 |
65 | blocks = {'block{}'.format(i): block(nPlanes[0], nPlanes[0], norm_fn, indice_key='subm{}'.format(indice_key_id)) for i in range(block_reps)}
66 | blocks = OrderedDict(blocks)
67 | self.blocks = spconv.SparseSequential(blocks)
68 |
69 | if len(nPlanes) > 1:
70 | self.conv = spconv.SparseSequential(
71 | norm_fn(nPlanes[0]),
72 | nn.ReLU(),
73 | spconv.SparseConv3d(nPlanes[0], nPlanes[1], kernel_size=2, stride=2, bias=False, indice_key='spconv{}'.format(indice_key_id))
74 | )
75 |
76 | self.u = UBlock(nPlanes[1:], norm_fn, block_reps, block, indice_key_id=indice_key_id+1)
77 |
78 | self.deconv = spconv.SparseSequential(
79 | norm_fn(nPlanes[1]),
80 | nn.ReLU(),
81 | spconv.SparseInverseConv3d(nPlanes[1], nPlanes[0], kernel_size=2, bias=False, indice_key='spconv{}'.format(indice_key_id))
82 | )
83 |
84 | blocks_tail = {}
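# the first tail block consumes the concatenated skip + decoder features (2 * nPlanes[0] channels); later ones keep nPlanes[0]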
85 | for i in range(block_reps):
86 | blocks_tail['block{}'.format(i)] = block(nPlanes[0] * (2 - i), nPlanes[0], norm_fn, indice_key='subm{}'.format(indice_key_id))
87 | blocks_tail = OrderedDict(blocks_tail)
88 | self.blocks_tail = spconv.SparseSequential(blocks_tail)
89 |
90 | def forward(self, input):
91 |
92 | output = self.blocks(input)
93 | identity = spconv.SparseConvTensor(output.features, output.indices, output.spatial_shape, output.batch_size)
94 | if len(self.nPlanes) > 1:
95 | output_decoder = self.conv(output)
96 | output_decoder = self.u(output_decoder)
97 | output_decoder = self.deconv(output_decoder)
98 | output.features = torch.cat((identity.features, output_decoder.features), dim=1)
99 | output = self.blocks_tail(output)
100 | return output
101 |
102 | class HAIS(nn.Module):
103 | def __init__(self, cfg):
104 | super().__init__()
105 |
106 | input_c = cfg.input_channel
107 | width = cfg.width
108 | classes = cfg.classes
109 | block_reps = cfg.block_reps
110 | block_residual = cfg.block_residual
111 |
112 | self.point_aggr_radius = cfg.point_aggr_radius
113 | self.cluster_shift_meanActive = cfg.cluster_shift_meanActive
114 |
115 | self.score_scale = cfg.score_scale
116 | self.score_fullscale = cfg.score_fullscale
117 | self.score_mode = cfg.score_mode
118 |
119 | self.prepare_epochs = cfg.prepare_epochs
120 | self.pretrain_path = cfg.pretrain_path
121 | self.pretrain_module = cfg.pretrain_module
122 | self.fix_module = cfg.fix_module
123 |
124 |
125 | norm_fn = functools.partial(nn.BatchNorm1d, eps=1e-4, momentum=0.1)
126 |
127 | if block_residual:
128 | block = ResidualBlock
129 | else:
130 | block = VGGBlock
131 |
132 | if cfg.use_coords:
133 | input_c += 3
134 |
135 | self.cfg = cfg
136 |
137 | # backbone
138 | self.input_conv = spconv.SparseSequential(
139 | spconv.SubMConv3d(input_c, width, kernel_size=3, padding=1, bias=False, indice_key='subm1')
140 | )
141 | self.unet = UBlock([width, 2*width, 3*width, 4*width, 5*width, 6*width, 7*width], norm_fn, block_reps, block, indice_key_id=1)
142 | self.output_layer = spconv.SparseSequential(
143 | norm_fn(width),
144 | nn.ReLU()
145 | )
146 |
147 | # semantic segmentation branch
148 | self.semantic_linear = nn.Sequential(
149 | nn.Linear(width, width, bias=True),
150 | norm_fn(width),
151 | nn.ReLU(),
152 | nn.Linear(width, classes)
153 | )
154 |
155 | # center shift vector branch
156 | self.offset_linear = nn.Sequential(
157 | nn.Linear(width, width, bias=True),
158 | norm_fn(width),
159 | nn.ReLU(),
160 | nn.Linear(width, 3, bias=True)
161 | )
162 |
163 | # intra-instance network
164 | self.intra_ins_unet = UBlock([width, 2*width], norm_fn, 2, block, indice_key_id=11)
165 | self.intra_ins_outputlayer = spconv.SparseSequential(
166 | norm_fn(width),
167 | nn.ReLU()
168 | )
169 |
170 | # proposal score
171 | self.score_linear = nn.Linear(width, 1)
172 |
173 | # proposal mask
174 | self.mask_linear = nn.Sequential(
175 | nn.Linear(width, width),
176 | nn.ReLU(),
177 | nn.Linear(width, 1))
178 |
179 | self.apply(self.set_bn_init)
180 |
181 |
182 | # fix module
183 | module_map = {'input_conv': self.input_conv, 'unet': self.unet, 'output_layer': self.output_layer,
184 | 'semantic_linear': self.semantic_linear, 'offset_linear': self.offset_linear,
185 | 'intra_ins_unet': self.intra_ins_unet, 'intra_ins_outputlayer': self.intra_ins_outputlayer,
186 | 'score_linear': self.score_linear, 'mask_linear': self.mask_linear}
187 | for m in self.fix_module:
188 | mod = module_map[m]
189 | for param in mod.parameters():
190 | param.requires_grad = False
191 |
192 | # load pretrain weights
193 | if self.pretrain_path is not None:
194 | pretrain_dict = torch.load(self.pretrain_path)
195 | for m in self.pretrain_module:
196 | print("Load pretrained " + m + ": %d/%d" % utils.load_model_param(module_map[m], pretrain_dict, prefix=m))
197 |
198 |
199 | @staticmethod
200 | def set_bn_init(m):
201 | classname = m.__class__.__name__
202 | if classname.find('BatchNorm') != -1:
203 | m.weight.data.fill_(1.0)
204 | m.bias.data.fill_(0.0)
205 |
206 |
207 | def clusters_voxelization(self, clusters_idx, clusters_offset, feats, coords, fullscale, scale, mode):
208 | '''
209 | :param clusters_idx: (SumNPoint, 2), int, [:, 0] for cluster_id, [:, 1] for corresponding point idxs in N, cpu
210 | :param clusters_offset: (nCluster + 1), int, cpu
211 | :param feats: (N, C), float, cuda
212 | :param coords: (N, 3), float, cuda
213 | :return:
214 | '''
215 | c_idxs = clusters_idx[:, 1].cuda()
216 | clusters_feats = feats[c_idxs.long()]
217 | clusters_coords = coords[c_idxs.long()]
218 |
219 | clusters_coords_mean = hais_ops.sec_mean(clusters_coords, clusters_offset.cuda()) # (nCluster, 3), float
220 | clusters_coords_mean = torch.index_select(clusters_coords_mean, 0, clusters_idx[:, 0].cuda().long()) # (sumNPoint, 3), float
221 | clusters_coords -= clusters_coords_mean
222 |
223 | clusters_coords_min = hais_ops.sec_min(clusters_coords, clusters_offset.cuda()) # (nCluster, 3), float
224 | clusters_coords_max = hais_ops.sec_max(clusters_coords, clusters_offset.cuda()) # (nCluster, 3), float
225 |
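# scale each cluster so that its largest extent fits within `fullscale` voxels (capped at `scale`), then randomly place it inside the [0, fullscale) cube below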
226 | clusters_scale = 1 / ((clusters_coords_max - clusters_coords_min) / fullscale).max(1)[0] - 0.01 # (nCluster), float
227 | clusters_scale = torch.clamp(clusters_scale, min=None, max=scale)
228 |
229 | min_xyz = clusters_coords_min * clusters_scale.unsqueeze(-1) # (nCluster, 3), float
230 | max_xyz = clusters_coords_max * clusters_scale.unsqueeze(-1)
231 |
232 | clusters_scale = torch.index_select(clusters_scale, 0, clusters_idx[:, 0].cuda().long())
233 |
234 | clusters_coords = clusters_coords * clusters_scale.unsqueeze(-1)
235 |
236 | range = max_xyz - min_xyz
237 | offset = - min_xyz + torch.clamp(fullscale - range - 0.001, min=0) * torch.rand(3).cuda() + torch.clamp(fullscale - range + 0.001, max=0) * torch.rand(3).cuda()
238 | offset = torch.index_select(offset, 0, clusters_idx[:, 0].cuda().long())
239 | clusters_coords += offset
240 | assert clusters_coords.shape.numel() == ((clusters_coords >= 0) * (clusters_coords < fullscale)).sum()
241 |
242 | clusters_coords = clusters_coords.long()
243 | clusters_coords = torch.cat([clusters_idx[:, 0].view(-1, 1).long(), clusters_coords.cpu()], 1) # (sumNPoint, 1 + 3)
244 |
245 | out_coords, inp_map, out_map = hais_ops.voxelization_idx(clusters_coords, int(clusters_idx[-1, 0]) + 1, mode)
246 | # output_coords: M * (1 + 3) long
247 | # input_map: sumNPoint int
248 | # output_map: M * (maxActive + 1) int
249 |
250 | out_feats = hais_ops.voxelization(clusters_feats, out_map.cuda(), mode) # (M, C), float, cuda
251 |
252 | spatial_shape = [fullscale] * 3
253 | voxelization_feats = spconv.SparseConvTensor(out_feats, out_coords.int().cuda(), spatial_shape, int(clusters_idx[-1, 0]) + 1)
254 |
255 | return voxelization_feats, inp_map
256 |
257 | def forward(self, input, input_map, coords, batch_idxs, batch_offsets, epoch, training_mode):
258 | '''
259 | :param input_map: (N), int, cuda
260 | :param coords: (N, 3), float, cuda
261 | :param batch_idxs: (N), int, cuda
262 | :param batch_offsets: (B + 1), int, cuda
263 | '''
264 | ret = {}
265 | output = self.input_conv(input)
266 | output = self.unet(output)
267 | output = self.output_layer(output)
268 | output_feats = output.features[input_map.long()]
269 |
270 | # semantic segmentation
271 | semantic_scores = self.semantic_linear(output_feats) # (N, nClass), float
272 |
273 | semantic_preds = semantic_scores.max(1)[1] # (N), long
274 |
275 | ret['semantic_scores'] = semantic_scores
276 |
277 | # center shift vector
278 | pt_offsets = self.offset_linear(output_feats) # (N, 3), float32
279 | ret['pt_offsets'] = pt_offsets
280 |
281 | if(epoch > self.prepare_epochs):
282 |
283 | if self.cfg.dataset == 'scannetv2':
284 | object_idxs = torch.nonzero(semantic_preds > 1).view(-1) # floor idx 0, wall idx 1
285 | else:
286 | raise Exception
287 |
288 | # filter out floor and wall points
289 | batch_idxs_ = batch_idxs[object_idxs]
290 | batch_offsets_ = utils.get_batch_offsets(batch_idxs_, input.batch_size)
291 | coords_ = coords[object_idxs]
292 | pt_offsets_ = pt_offsets[object_idxs] # (N_fg, 3), float32
293 |
294 | semantic_preds_cpu = semantic_preds[object_idxs].int().cpu()
295 |
296 | idx, start_len = hais_ops.ballquery_batch_p(coords_ + pt_offsets_, \
297 | batch_idxs_, batch_offsets_, self.point_aggr_radius, self.cluster_shift_meanActive)
298 |
299 | using_set_aggr_in_training = getattr(self.cfg, 'using_set_aggr_in_training', True)
300 | using_set_aggr_in_testing = getattr(self.cfg, 'using_set_aggr_in_testing', True)
301 | using_set_aggr = using_set_aggr_in_training if training_mode == 'train' else using_set_aggr_in_testing
302 |
303 | proposals_idx, proposals_offset = hais_ops.hierarchical_aggregation(
304 | semantic_preds_cpu, (coords_ + pt_offsets_).cpu(), idx.cpu(), start_len.cpu(),
305 | batch_idxs_.cpu(), training_mode, using_set_aggr)
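# proposals_idx: (sumNPoint, 2) int, [:, 0] cluster_id, [:, 1] point idx within object_idxs (remapped to indices in N below)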
306 |
307 | proposals_idx[:, 1] = object_idxs[proposals_idx[:, 1].long()].int()
308 |
309 |
310 | # restrict the number of training proposals to avoid OOM
311 | max_proposal_num = getattr(self.cfg, 'max_proposal_num', 200)
312 | if training_mode == 'train' and proposals_offset.shape[0] > max_proposal_num:
313 | proposals_offset = proposals_offset[:max_proposal_num + 1]
314 | proposals_idx = proposals_idx[: proposals_offset[-1]]
315 | assert proposals_idx.shape[0] == proposals_offset[-1]
316 | print('selected proposal num', proposals_offset.shape[0] - 1)
317 |
318 | # proposals voxelization again
319 | input_feats, inp_map = self.clusters_voxelization(proposals_idx, proposals_offset, output_feats, coords, self.score_fullscale, self.score_scale, self.score_mode)
320 |
321 | # predict instance scores
322 | score = self.intra_ins_unet(input_feats)
323 | score = self.intra_ins_outputlayer(score)
324 | score_feats = score.features[inp_map.long()] # (sumNPoint, C)
325 |
326 | # predict mask scores
327 | # first linear, then voxel-to-point mapping; more efficient (because voxel num < point num)
328 | mask_scores = self.mask_linear(score.features)
329 | mask_scores = mask_scores[inp_map.long()]
330 |
331 | # predict instance scores
332 | if getattr(self.cfg, 'use_mask_filter_score_feature', False) and \
333 | epoch > self.cfg.use_mask_filter_score_feature_start_epoch:
334 | mask_index_select = torch.ones_like(mask_scores)
335 | mask_index_select[torch.sigmoid(mask_scores) < self.cfg.mask_filter_score_feature_thre] = 0.
336 | score_feats = score_feats * mask_index_select
337 | score_feats = hais_ops.roipool(score_feats, proposals_offset.cuda()) # (nProposal, C)
338 | scores = self.score_linear(score_feats) # (nProposal, 1)
339 |
340 | ret['proposal_scores'] = (scores, proposals_idx, proposals_offset, mask_scores)
341 |
342 | return ret
343 |
344 |
345 | def model_fn_decorator(test=False):
346 | # config
347 | from util.config import cfg
348 |
349 |
350 | semantic_criterion = nn.CrossEntropyLoss(ignore_index=cfg.ignore_label).cuda()
351 | score_criterion = nn.BCELoss(reduction='none').cuda()
352 |
353 | def test_model_fn(batch, model, epoch):
354 | coords = batch['locs'].cuda() # (N, 1 + 3), long, cuda, dimension 0 for batch_idx
355 | voxel_coords = batch['voxel_locs'].cuda() # (M, 1 + 3), long, cuda
356 | p2v_map = batch['p2v_map'].cuda() # (N), int, cuda
357 | v2p_map = batch['v2p_map'].cuda() # (M, 1 + maxActive), int, cuda
358 |
359 | coords_float = batch['locs_float'].cuda() # (N, 3), float32, cuda
360 | feats = batch['feats'].cuda() # (N, C), float32, cuda
361 | batch_offsets = batch['offsets'].cuda() # (B + 1), int, cuda
362 | spatial_shape = batch['spatial_shape']
363 |
364 | if cfg.use_coords:
365 | feats = torch.cat((feats, coords_float), 1)
366 |
367 | voxel_feats = hais_ops.voxelization(feats, v2p_map, cfg.mode) # (M, C), float, cuda
368 |
369 | input_ = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, cfg.batch_size)
370 |
371 | ret = model(input_, p2v_map, coords_float, coords[:, 0].int(), batch_offsets, epoch, 'test')
372 | semantic_scores = ret['semantic_scores'] # (N, nClass) float32, cuda
373 | pt_offsets = ret['pt_offsets'] # (N, 3), float32, cuda
374 |
375 | if (epoch > cfg.prepare_epochs):
376 | scores, proposals_idx, proposals_offset, mask_scores = ret['proposal_scores']
377 |
378 | # preds
379 | with torch.no_grad():
380 | preds = {}
381 | preds['semantic'] = semantic_scores
382 | preds['pt_offsets'] = pt_offsets
383 | if (epoch > cfg.prepare_epochs):
384 | preds['score'] = scores
385 | preds['proposals'] = (proposals_idx, proposals_offset, mask_scores)
386 |
387 | return preds
388 |
389 | def model_fn(batch, model, epoch):
390 | # batch {'locs': locs, 'voxel_locs': voxel_locs, 'p2v_map': p2v_map, 'v2p_map': v2p_map,
391 | # 'locs_float': locs_float, 'feats': feats, 'labels': labels, 'instance_labels': instance_labels,
392 | # 'instance_info': instance_infos, 'instance_pointnum': instance_pointnum,
393 | # 'id': tbl, 'offsets': batch_offsets, 'spatial_shape': spatial_shape}
394 | coords = batch['locs'].cuda() # (N, 1 + 3), long, cuda, dimension 0 for batch_idx
395 | voxel_coords = batch['voxel_locs'].cuda() # (M, 1 + 3), long, cuda
396 | p2v_map = batch['p2v_map'].cuda() # (N), int, cuda
397 | v2p_map = batch['v2p_map'].cuda() # (M, 1 + maxActive), int, cuda
398 |
399 | coords_float = batch['locs_float'].cuda() # (N, 3), float32, cuda
400 | feats = batch['feats'].cuda() # (N, C), float32, cuda
401 | labels = batch['labels'].cuda() # (N), long, cuda
402 | instance_labels = batch['instance_labels'].cuda() # (N), long, cuda, 0~total_nInst, -100
403 |
404 | instance_info = batch['instance_info'].cuda() # (N, 9), float32, cuda, (meanxyz, minxyz, maxxyz)
405 | instance_pointnum = batch['instance_pointnum'].cuda() # (total_nInst), int, cuda
406 | batch_offsets = batch['offsets'].cuda() # (B + 1), int, cuda
407 | spatial_shape = batch['spatial_shape']
408 |
409 | if cfg.use_coords:
410 | feats = torch.cat((feats, coords_float), 1)
411 |
412 | voxel_feats = hais_ops.voxelization(feats, v2p_map, cfg.mode) # (M, C), float, cuda
413 |
414 | input_ = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, cfg.batch_size)
415 |
416 | ret = model(input_, p2v_map, coords_float, coords[:, 0].int(), batch_offsets, epoch, 'train')
417 | semantic_scores = ret['semantic_scores'] # (N, nClass) float32, cuda
418 | pt_offsets = ret['pt_offsets'] # (N, 3), float32, cuda
419 |
420 | if(epoch > cfg.prepare_epochs):
421 | scores, proposals_idx, proposals_offset, mask_scores = ret['proposal_scores']
422 | # scores: (nProposal, 1) float, cuda
423 | # proposals_idx: (sumNPoint, 2), int, cpu, [:, 0] for cluster_id, [:, 1] for corresponding point idxs in N
424 | # proposals_offset: (nProposal + 1), int, cpu
425 | # mask_scores: (sumNPoint, 1), float, cuda
426 |
427 | loss_inp = {}
428 |
429 | loss_inp['semantic_scores'] = (semantic_scores, labels)
430 | loss_inp['pt_offsets'] = (pt_offsets, coords_float, instance_info, instance_labels)
431 |
432 | if(epoch > cfg.prepare_epochs):
433 | loss_inp['proposal_scores'] = (scores, proposals_idx, proposals_offset, instance_pointnum, mask_scores)
434 |
435 | loss, loss_out = loss_fn(loss_inp, epoch)
436 |
437 | # accuracy / visual_dict / meter_dict
438 | with torch.no_grad():
439 | preds = {}
440 | preds['semantic'] = semantic_scores
441 | preds['pt_offsets'] = pt_offsets
442 | if(epoch > cfg.prepare_epochs):
443 | preds['score'] = scores
444 | preds['proposals'] = (proposals_idx, proposals_offset)
445 |
446 | visual_dict = {}
447 | visual_dict['loss'] = loss
448 | for k, v in loss_out.items():
449 | visual_dict[k] = v[0]
450 |
451 | meter_dict = {}
452 | meter_dict['loss'] = (loss.item(), coords.shape[0])
453 | for k, v in loss_out.items():
454 | meter_dict[k] = (float(v[0]), v[1])
455 |
456 | return loss, preds, visual_dict, meter_dict
457 |
458 |
459 | def loss_fn(loss_inp, epoch):
460 |
461 | loss_out = {}
462 |
463 | '''semantic loss'''
464 | semantic_scores, semantic_labels = loss_inp['semantic_scores']
465 | # semantic_scores: (N, nClass), float32, cuda
466 | # semantic_labels: (N), long, cuda
467 |
468 | semantic_loss = semantic_criterion(semantic_scores, semantic_labels)
469 |
470 | loss_out['semantic_loss'] = (semantic_loss, semantic_scores.shape[0])
471 |
472 | '''offset loss'''
473 | pt_offsets, coords, instance_info, instance_labels = loss_inp['pt_offsets']
474 | # pt_offsets: (N, 3), float, cuda
475 | # coords: (N, 3), float32
476 | # instance_info: (N, 9), float32 tensor (meanxyz, minxyz, maxxyz)
477 | # instance_labels: (N), long
478 |
479 |
480 | gt_offsets = instance_info[:, 0:3] - coords # (N, 3)
481 | pt_diff = pt_offsets - gt_offsets # (N, 3)
482 | pt_dist = torch.sum(torch.abs(pt_diff), dim=-1) # (N)
483 |
484 | valid = (instance_labels != cfg.ignore_label).float()
485 |
486 | offset_norm_loss = torch.sum(pt_dist * valid) / (torch.sum(valid) + 1e-6)
487 | loss_out['offset_norm_loss'] = (offset_norm_loss, valid.sum())
488 |
489 | if (epoch > cfg.prepare_epochs):
490 | '''score and mask loss'''
491 |
492 | scores, proposals_idx, proposals_offset, instance_pointnum, mask_scores = loss_inp['proposal_scores']
493 | # scores: (nProposal, 1), float32
494 | # proposals_idx: (sumNPoint, 2), int, cpu, [:, 0] for cluster_id, [:, 1] for corresponding point idxs in N
495 | # proposals_offset: (nProposal + 1), int, cpu
496 | # instance_pointnum: (total_nInst), int
497 |
498 | # get iou and calculate mask label and mask loss
499 | mask_scores_sigmoid = torch.sigmoid(mask_scores)
500 |
501 | if getattr(cfg, 'cal_iou_based_on_mask', False) \
502 | and (epoch > cfg.cal_iou_based_on_mask_start_epoch):
503 | ious, mask_label = hais_ops.cal_iou_and_masklabel(proposals_idx[:, 1].cuda(), \
504 | proposals_offset.cuda(), instance_labels, instance_pointnum, mask_scores_sigmoid.detach(), 1)
505 | else:
506 | ious, mask_label = hais_ops.cal_iou_and_masklabel(proposals_idx[:, 1].cuda(), \
507 | proposals_offset.cuda(), instance_labels, instance_pointnum, mask_scores_sigmoid.detach(), 0)
508 | # ious: (nProposal, nInstance)
509 | # mask_label: (sumNPoint, 1)
510 |
511 | mask_label_weight = (mask_label != -1).float()
512 | mask_label[mask_label==-1.] = 0.5 # any value is ok
513 | mask_loss = torch.nn.functional.binary_cross_entropy(mask_scores_sigmoid, mask_label, weight=mask_label_weight, reduction='none')
514 | mask_loss = mask_loss.mean()
515 | loss_out['mask_loss'] = (mask_loss, mask_label_weight.sum())
516 | gt_ious, _ = ious.max(1) # gt_ious: (nProposal) float, long
517 |
518 |
519 | gt_scores = get_segmented_scores(gt_ious, cfg.fg_thresh, cfg.bg_thresh)
520 |
521 | score_loss = score_criterion(torch.sigmoid(scores.view(-1)), gt_scores)
522 | score_loss = score_loss.mean()
523 |
524 | loss_out['score_loss'] = (score_loss, gt_ious.shape[0])
525 |
526 | '''total loss'''
527 | loss = cfg.loss_weight[0] * semantic_loss + cfg.loss_weight[1] * offset_norm_loss
528 | if(epoch > cfg.prepare_epochs):
529 | loss += (cfg.loss_weight[2] * score_loss)
530 | loss += (cfg.loss_weight[3] * mask_loss)
531 |
532 | return loss, loss_out
533 |
534 |
535 | def get_segmented_scores(scores, fg_thresh=1.0, bg_thresh=0.0):
536 | '''
537 | :param scores: (N), float, 0~1
538 | :return: segmented_scores: (N), float 0~1, >fg_thresh: 1, <bg_thresh: 0, mid: linear
539 | '''
540 | fg_mask = scores > fg_thresh
541 | bg_mask = scores < bg_thresh
542 | interval_mask = (fg_mask == 0) & (bg_mask == 0)
543 |
544 | segmented_scores = (fg_mask > 0).float()
545 | k = 1 / (fg_thresh - bg_thresh + 1e-5)
546 | b = bg_thresh / (bg_thresh - fg_thresh + 1e-5)
547 | segmented_scores[interval_mask] = scores[interval_mask] * k + b
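# e.g. with fg_thresh=0.75, bg_thresh=0.25 (illustrative values): 0.25 -> 0, 0.5 -> 0.5, 0.75 -> 1, linear in between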
548 |
549 | return segmented_scores
550 |
551 | if test:
552 | fn = test_model_fn
553 | else:
554 | fn = model_fn
555 |
556 | return fn
557 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.1
2 | cmake>=3.13.2
3 | plyfile
4 | tensorboardX
5 | pyyaml
6 | scipy
7 | six
8 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import numpy as np
4 | import random
5 | import os
6 |
7 | from util.config import cfg
8 | cfg.task = 'test'
9 | from util.log import logger
10 | import util.utils as utils
11 | import util.eval as eval
12 |
13 | def init():
14 | global result_dir
15 | result_dir = os.path.join(cfg.exp_path, 'result', cfg.split)
16 | backup_dir = os.path.join(result_dir, 'backup_files')
17 | os.makedirs(backup_dir, exist_ok=True)
18 | os.makedirs(os.path.join(result_dir, 'predicted_masks'), exist_ok=True)
19 | os.system('cp test.py {}'.format(backup_dir))
20 | os.system('cp {} {}'.format(cfg.model_dir, backup_dir))
21 | os.system('cp {} {}'.format(cfg.dataset_dir, backup_dir))
22 | os.system('cp {} {}'.format(cfg.config, backup_dir))
23 |
24 | global semantic_label_idx
25 | semantic_label_idx = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39]
26 |
27 | logger.info(cfg)
28 |
29 | random.seed(cfg.test_seed)
30 | np.random.seed(cfg.test_seed)
31 | torch.manual_seed(cfg.test_seed)
32 | torch.cuda.manual_seed_all(cfg.test_seed)
33 |
34 |
35 | def test(model, model_fn, data_name, epoch):
36 | logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>')
37 |
38 | if cfg.dataset == 'scannetv2':
39 | if data_name == 'scannet':
40 | from data.scannetv2_inst import Dataset
41 | dataset = Dataset(test=True)
42 | dataset.testLoader()
43 | else:
44 | print("Error: no data loader - " + data_name)
45 | exit(0)
46 | dataloader = dataset.test_data_loader
47 |
48 | with torch.no_grad():
49 | model = model.eval()
50 |
51 | total_end1 = 0.
52 | matches = {}
53 | for i, batch in enumerate(dataloader):
54 |
55 | # inference
56 | start1 = time.time()
57 | preds = model_fn(batch, model, epoch)
58 | end1 = time.time() - start1
59 |
60 | # decode results for evaluation
61 | N = batch['feats'].shape[0]
62 | test_scene_name = dataset.test_file_names[int(batch['id'][0])].split('/')[-1][:12]
63 | semantic_scores = preds['semantic'] # (N, nClass=20) float32, cuda
64 | semantic_pred = semantic_scores.max(1)[1] # (N) long, cuda
65 | pt_offsets = preds['pt_offsets'] # (N, 3), float32, cuda
66 | if (epoch > cfg.prepare_epochs):
67 | scores = preds['score'] # (nProposal, 1) float, cuda
68 | scores_pred = torch.sigmoid(scores.view(-1))
69 |
70 | proposals_idx, proposals_offset, mask_scores = preds['proposals']
71 | # proposals_idx: (sumNPoint, 2), int, cpu, [:, 0] for cluster_id, [:, 1] for corresponding point idxs in N
72 | # proposals_offset: (nProposal + 1), int, cpu
73 | proposals_pred = torch.zeros((proposals_offset.shape[0] - 1, N), dtype=torch.int, device=scores_pred.device)
74 | # (nProposal, N), int, cuda
75 |
76 | # outlier filtering
77 | test_mask_score_thre = getattr(cfg, 'test_mask_score_thre', -0.5)
78 | _mask = mask_scores.squeeze(1) > test_mask_score_thre
79 | proposals_pred[proposals_idx[_mask][:, 0].long(), proposals_idx[_mask][:, 1].long()] = 1
80 |
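# take each proposal's predicted semantic class from its first point and map it to the ScanNet-v2 label id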
81 | semantic_id = torch.tensor(semantic_label_idx, device=scores_pred.device) \
82 | [semantic_pred[proposals_idx[:, 1][proposals_offset[:-1].long()].long()]] # (nProposal), long
83 | # semantic_id_idx = semantic_pred[proposals_idx[:, 1][proposals_offset[:-1].long()].long()]
84 |
85 | # score threshold
86 | score_mask = (scores_pred > cfg.TEST_SCORE_THRESH)
87 | scores_pred = scores_pred[score_mask]
88 | proposals_pred = proposals_pred[score_mask]
89 | semantic_id = semantic_id[score_mask]
90 | # semantic_id_idx = semantic_id_idx[score_mask]
91 |
92 | # npoint threshold
93 | proposals_pointnum = proposals_pred.sum(1)
94 | npoint_mask = (proposals_pointnum >= cfg.TEST_NPOINT_THRESH)
95 | scores_pred = scores_pred[npoint_mask]
96 | proposals_pred = proposals_pred[npoint_mask]
97 | semantic_id = semantic_id[npoint_mask]
98 |
99 |
100 | # nms (no need)
101 | if getattr(cfg, 'using_NMS', False):
102 | if semantic_id.shape[0] == 0:
103 | pick_idxs = np.empty(0)
104 | else:
105 | proposals_pred_f = proposals_pred.float() # (nProposal, N), float, cuda
106 | intersection = torch.mm(proposals_pred_f, proposals_pred_f.t()) # (nProposal, nProposal), float, cuda
107 | proposals_pointnum = proposals_pred_f.sum(1) # (nProposal), float, cuda
108 | proposals_pn_h = proposals_pointnum.unsqueeze(-1).repeat(1, proposals_pointnum.shape[0])
109 | proposals_pn_v = proposals_pointnum.unsqueeze(0).repeat(proposals_pointnum.shape[0], 1)
110 | cross_ious = intersection / (proposals_pn_h + proposals_pn_v - intersection)
111 | pick_idxs = non_max_suppression(cross_ious.cpu().numpy(), scores_pred.cpu().numpy(), cfg.TEST_NMS_THRESH)
112 | # int, (nCluster, N)
113 | clusters = proposals_pred[pick_idxs]
114 | cluster_scores = scores_pred[pick_idxs]
115 | cluster_semantic_id = semantic_id[pick_idxs]
116 | else:
117 | clusters = proposals_pred
118 | cluster_scores = scores_pred
119 | cluster_semantic_id = semantic_id
120 |
121 | nclusters = clusters.shape[0]
122 |
123 |
124 | # prepare for evaluation
125 | if cfg.eval:
126 | pred_info = {}
127 | pred_info['conf'] = cluster_scores.cpu().numpy()
128 | pred_info['label_id'] = cluster_semantic_id.cpu().numpy()
129 | pred_info['mask'] = clusters.cpu().numpy()
130 | gt_file = os.path.join(cfg.data_root, cfg.dataset, cfg.split + '_gt', test_scene_name + '.txt')
131 | gt2pred, pred2gt = eval.assign_instances_for_scan(test_scene_name, pred_info, gt_file)
132 |
133 | matches[test_scene_name] = {}
134 | matches[test_scene_name]['gt'] = gt2pred
135 | matches[test_scene_name]['pred'] = pred2gt
136 |
137 | if cfg.split == 'val':
138 | matches[test_scene_name]['seg_gt'] = batch['labels']
139 | matches[test_scene_name]['seg_pred'] = semantic_pred
140 |
141 |
142 | # save files
143 | if cfg.save_semantic:
144 | os.makedirs(os.path.join(result_dir, 'semantic'), exist_ok=True)
145 | semantic_np = semantic_pred.cpu().numpy()
146 | np.save(os.path.join(result_dir, 'semantic', test_scene_name + '.npy'), semantic_np)
147 |
148 | if cfg.save_pt_offsets:
149 | os.makedirs(os.path.join(result_dir, 'coords_offsets'), exist_ok=True)
150 | pt_offsets_np = pt_offsets.cpu().numpy()
151 | coords_np = batch['locs_float'].numpy()
152 | coords_offsets = np.concatenate((coords_np, pt_offsets_np), 1) # (N, 6)
153 | np.save(os.path.join(result_dir, 'coords_offsets', test_scene_name + '.npy'), coords_offsets)
154 |
155 | if(epoch > cfg.prepare_epochs and cfg.save_instance):
156 | f = open(os.path.join(result_dir, test_scene_name + '.txt'), 'w')
157 | for proposal_id in range(nclusters):
158 | clusters_i = clusters[proposal_id].cpu().numpy() # (N)
159 | semantic_label = np.argmax(np.bincount(semantic_pred[np.where(clusters_i == 1)[0]].cpu()))
160 | score = cluster_scores[proposal_id]
161 | f.write('predicted_masks/{}_{:03d}.txt {} {:.4f}'.format( \
162 | test_scene_name, proposal_id, semantic_label_idx[semantic_label], score))
163 | if proposal_id < nclusters - 1:
164 | f.write('\n')
165 | np.savetxt(os.path.join(result_dir, 'predicted_masks', test_scene_name + '_%03d.txt' % (proposal_id)), clusters_i, fmt='%d')
166 | f.close()
167 |
168 |
169 | logger.info("instance iter: {}/{} point_num: {} ncluster: {} inference time: {:.2f}s".format( \
170 | batch['id'][0] + 1, len(dataset.test_files), N, nclusters, end1))
171 | total_end1 += end1
172 |
173 | # evaluation
174 | if cfg.eval:
175 | ap_scores = eval.evaluate_matches(matches)
176 | avgs = eval.compute_averages(ap_scores)
177 | eval.print_results(avgs)
178 |
179 | logger.info("whole set inference time: {:.2f}s, latency per frame: {:.2f}ms".format(total_end1, total_end1 / len(dataloader) * 1000))
180 |
181 | # evaluate semantic segmentation accuracy and mIoU
182 | if cfg.split == 'val':
183 | seg_accuracy = evaluate_semantic_segmantation_accuracy(matches)
184 | logger.info("semantic_segmantation_accuracy: {:.4f}".format(seg_accuracy))
185 | miou = evaluate_semantic_segmantation_miou(matches)
186 | logger.info("semantic_segmantation_mIoU: {:.4f}".format(miou))
187 |
188 | def evaluate_semantic_segmantation_accuracy(matches):
189 | seg_gt_list = []
190 | seg_pred_list = []
191 | for k, v in matches.items():
192 | seg_gt_list.append(v['seg_gt'])
193 | seg_pred_list.append(v['seg_pred'])
194 | seg_gt_all = torch.cat(seg_gt_list, dim=0).cuda()
195 | seg_pred_all = torch.cat(seg_pred_list, dim=0).cuda()
196 | assert seg_gt_all.shape == seg_pred_all.shape
197 | correct = (seg_gt_all[seg_gt_all != -100] == seg_pred_all[seg_gt_all != -100]).sum()
198 | whole = (seg_gt_all != -100).sum()
199 | seg_accuracy = correct.float() / whole.float()
200 | return seg_accuracy
201 |
202 | def evaluate_semantic_segmantation_miou(matches):
203 | seg_gt_list = []
204 | seg_pred_list = []
205 | for k, v in matches.items():
206 | seg_gt_list.append(v['seg_gt'])
207 | seg_pred_list.append(v['seg_pred'])
208 | seg_gt_all = torch.cat(seg_gt_list, dim=0).cuda()
209 | seg_pred_all = torch.cat(seg_pred_list, dim=0).cuda()
210 | assert seg_gt_all.shape == seg_pred_all.shape
211 | iou_list = []
212 | for _index in seg_gt_all.unique():
213 | if _index != -100:
214 | intersection = ((seg_gt_all == _index) & (seg_pred_all == _index)).sum()
215 | union = ((seg_gt_all == _index) | (seg_pred_all == _index)).sum()
216 | iou = intersection.float() / union
217 | iou_list.append(iou)
218 | iou_tensor = torch.tensor(iou_list)
219 | miou = iou_tensor.mean()
220 | return miou
221 |
222 | def non_max_suppression(ious, scores, threshold):
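# greedy NMS: repeatedly keep the highest-scoring proposal and drop remaining proposals whose IoU with it exceeds `threshold`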
223 | ixs = scores.argsort()[::-1]
224 | pick = []
225 | while len(ixs) > 0:
226 | i = ixs[0]
227 | pick.append(i)
228 | iou = ious[i, ixs[1:]]
229 | remove_ixs = np.where(iou > threshold)[0] + 1
230 | ixs = np.delete(ixs, remove_ixs)
231 | ixs = np.delete(ixs, 0)
232 | return np.array(pick, dtype=np.int32)
233 |
234 |
235 | if __name__ == '__main__':
236 | init()
237 |
238 | exp_name = cfg.config.split('/')[-1][:-5]
239 | model_name = exp_name.split('_')[0]
240 | data_name = exp_name.split('_')[-1]
241 |
242 | logger.info('=> creating model ...')
243 | logger.info('Classes: {}'.format(cfg.classes))
244 |
245 | if model_name == 'hais':
246 | from model.hais.hais import HAIS as Network
247 | from model.hais.hais import model_fn_decorator
248 |
249 | else:
250 | print("Error: no model version " + model_name)
251 | exit(0)
252 | model = Network(cfg)
253 |
254 | use_cuda = torch.cuda.is_available()
255 | logger.info('cuda available: {}'.format(use_cuda))
256 | assert use_cuda
257 | model = model.cuda()
258 |
259 | logger.info('#classifier parameters (model): {}'.format(sum([x.nelement() for x in model.parameters()])))
260 | model_fn = model_fn_decorator(test=True)
261 |
262 | # load model
263 | utils.checkpoint_restore(cfg, model, None, cfg.exp_path, cfg.config.split('/')[-1][:-5],
264 | use_cuda, cfg.test_epoch, dist=False, f=cfg.pretrain)
265 | # resume from the latest epoch, or specify the epoch to restore
266 |
267 | # evaluate
268 | test(model, model_fn, data_name, cfg.test_epoch)
269 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | import time, sys, os, random
4 | from tensorboardX import SummaryWriter
5 | import numpy as np
6 |
7 | from util.config import cfg
8 |
9 | import torch.distributed as dist
10 |
11 |
12 | def init():
13 | # copy important files to backup
14 | backup_dir = os.path.join(cfg.exp_path, 'backup_files')
15 | os.makedirs(backup_dir, exist_ok=True)
16 | os.system('cp train.py {}'.format(backup_dir))
17 | os.system('cp {} {}'.format(cfg.model_dir, backup_dir))
18 | os.system('cp {} {}'.format(cfg.dataset_dir, backup_dir))
19 | os.system('cp {} {}'.format(cfg.config, backup_dir))
20 |
21 | # log the config
22 | logger.info(cfg)
23 |
24 | # summary writer
25 | global writer
26 | writer = SummaryWriter(cfg.exp_path)
27 |
28 | # random seed
29 | random.seed(cfg.manual_seed)
30 | np.random.seed(cfg.manual_seed)
31 | torch.manual_seed(cfg.manual_seed)
32 | torch.cuda.manual_seed_all(cfg.manual_seed)
33 |
34 | # epoch counts from 1 to N
35 | def train_epoch(train_loader, model, model_fn, optimizer, epoch):
36 | iter_time = utils.AverageMeter()
37 | data_time = utils.AverageMeter()
38 | am_dict = {}
39 |
40 | model.train()
41 | start_epoch = time.time()
42 | end = time.time()
43 |
44 | if train_loader.sampler is not None and cfg.dist == True:
45 | train_loader.sampler.set_epoch(epoch)
46 |
47 | for i, batch in enumerate(train_loader):
48 |
49 | if batch['locs'].shape[0] < 20000:
50 | logger.info("point num < 20000, continue")
51 | continue
52 |
53 | data_time.update(time.time() - end)
54 | torch.cuda.empty_cache()
55 |
56 | # adjust learning rate
57 | utils.cosine_lr_after_step(optimizer, cfg.lr, epoch - 1, cfg.step_epoch, cfg.epochs)
58 |
59 |
60 | # prepare input and forward
61 | loss, _, visual_dict, meter_dict = model_fn(batch, model, epoch)
62 |
63 | # meter_dict
64 | for k, v in meter_dict.items():
65 | if k not in am_dict.keys():
66 | am_dict[k] = utils.AverageMeter()
67 | am_dict[k].update(v[0], v[1])
68 |
69 | # backward
70 | optimizer.zero_grad()
71 | loss.backward()
72 | optimizer.step()
73 |
74 | # time and print
75 | current_iter = (epoch - 1) * len(train_loader) + i + 1
76 | max_iter = cfg.epochs * len(train_loader)
77 | remain_iter = max_iter - current_iter
78 |
79 | iter_time.update(time.time() - end)
80 | end = time.time()
81 |
82 | remain_time = remain_iter * iter_time.avg
83 | t_m, t_s = divmod(remain_time, 60)
84 | t_h, t_m = divmod(t_m, 60)
85 | remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m), int(t_s))
86 |
87 | if cfg.local_rank == 0 and i % 10 == 0:
88 | sys.stdout.write(
89 | "epoch: {}/{} iter: {}/{} loss: {:.4f}({:.4f}) data_time: {:.2f}({:.2f}) iter_time: {:.2f}({:.2f}) remain_time: {remain_time}\n".format
90 | (epoch, cfg.epochs, i + 1, len(train_loader), am_dict['loss'].val, am_dict['loss'].avg,
91 | data_time.val, data_time.avg, iter_time.val, iter_time.avg, remain_time=remain_time))
92 |
93 | if (i == len(train_loader) - 1): print()
94 |
95 |
96 | logger.info("epoch: {}/{}, train loss: {:.4f}, time: {}s".format(epoch, cfg.epochs, am_dict['loss'].avg, time.time() - start_epoch))
97 |
98 | if cfg.local_rank == 0:
99 | utils.checkpoint_save(model, optimizer, cfg.exp_path, cfg.config.split('/')[-1][:-5], epoch, cfg.save_freq, use_cuda)
100 |
101 | for k in am_dict.keys():
102 | if k in visual_dict.keys():
103 | writer.add_scalar(k+'_train', am_dict[k].avg, epoch)
104 |
105 |
106 | def eval_epoch(val_loader, model, model_fn, epoch):
107 | logger.info('>>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>')
108 | am_dict = {}
109 |
110 | with torch.no_grad():
111 | model.eval()
112 | start_epoch = time.time()
113 | for i, batch in enumerate(val_loader):
114 |
115 | # prepare input and forward
116 | loss, preds, visual_dict, meter_dict = model_fn(batch, model, epoch)
117 |
118 | for k, v in meter_dict.items():
119 | if k not in am_dict.keys():
120 | am_dict[k] = utils.AverageMeter()
121 | am_dict[k].update(v[0], v[1])
122 | sys.stdout.write("\riter: {}/{} loss: {:.4f}({:.4f})".format(i + 1, len(val_loader), am_dict['loss'].val, am_dict['loss'].avg))
123 | if (i == len(val_loader) - 1): print()
124 |
125 | logger.info("epoch: {}/{}, val loss: {:.4f}, time: {}s".format(epoch, cfg.epochs, am_dict['loss'].avg, time.time() - start_epoch))
126 |
127 | for k in am_dict.keys():
128 | if k in visual_dict.keys():
129 | writer.add_scalar(k + '_eval', am_dict[k].avg, epoch)
130 |
131 |
132 | if __name__ == '__main__':
133 | if cfg.dist == True:
134 | raise NotImplementedError
135 | # num_gpus = torch.cuda.device_count()
136 | # dist.init_process_group(backend='nccl', rank=cfg.local_rank,
137 | # world_size=num_gpus)
138 | # torch.cuda.set_device(cfg.local_rank)
139 |
140 | from util.log import logger
141 | import util.utils as utils
142 |
143 | init()
144 |
145 | exp_name = cfg.config.split('/')[-1][:-5]
146 | model_name = exp_name.split('_')[0]
147 | data_name = exp_name.split('_')[-1]
148 |
149 | # model
150 | logger.info('=> creating model ...')
151 | if model_name == 'hais':
152 | from model.hais.hais import HAIS as Network
153 | from model.hais.hais import model_fn_decorator
154 | else:
155 | print("Error: no model - " + model_name)
156 | exit(0)
157 |
158 | model = Network(cfg)
159 |
160 | use_cuda = torch.cuda.is_available()
161 | logger.info('cuda available: {}'.format(use_cuda))
162 | assert use_cuda
163 | model = model.cuda()
164 |
165 | logger.info('#classifier parameters: {}'.format(sum([x.nelement() for x in model.parameters()])))
166 |
167 | # optimizer
168 | if cfg.optim == 'Adam':
169 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.lr)
170 | elif cfg.optim == 'SGD':
171 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
172 | lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
173 |
174 |
175 | model_fn = model_fn_decorator()
176 |
177 | # dataset
178 | if cfg.dataset == 'scannetv2':
179 | if data_name == 'scannet':
180 | import data.scannetv2_inst
181 | dataset = data.scannetv2_inst.Dataset()
182 | if cfg.dist:
183 | dataset.dist_trainLoader()
184 | else:
185 | dataset.trainLoader()
186 | dataset.valLoader()
187 | else:
188 | print("Error: no data loader - " + data_name)
189 | exit(0)
190 | else:
191 | raise NotImplementedError("Not yet supported")
192 |
193 |
194 | # resume from the latest epoch, or specify the epoch to restore
195 | start_epoch = utils.checkpoint_restore(cfg, model, optimizer, cfg.exp_path,
196 | cfg.config.split('/')[-1][:-5], use_cuda)
197 |
198 |
199 | if cfg.dist:
200 | # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
201 | model = torch.nn.parallel.DistributedDataParallel(
202 | model.cuda(cfg.local_rank),
203 | device_ids=[cfg.local_rank],
204 | output_device=cfg.local_rank,
205 | find_unused_parameters=True)
206 |
207 |
208 | # train and val
209 | for epoch in range(start_epoch, cfg.epochs + 1):
210 | train_epoch(dataset.train_data_loader, model, model_fn, optimizer, epoch)
211 |
212 | if utils.is_multiple(epoch, cfg.save_freq) or utils.is_power2(epoch):
213 | eval_epoch(dataset.val_data_loader, model, model_fn, epoch)
214 |
--------------------------------------------------------------------------------
/util/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import yaml
3 | import os
4 |
5 | def get_parser():
6 | parser = argparse.ArgumentParser(description='Point Cloud Segmentation')
7 | parser.add_argument('--config', type=str, help='path to config file')
8 |
9 | # pretrain
10 | parser.add_argument('--pretrain', type=str, help='path to pretrain model')
11 |
12 | parser.add_argument('--save_dir', type=str, default='exp', help='path to save model')
13 |
14 | parser.add_argument('--dist', action='store_true', default=False, help='dist train')
15 |
16 | parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training')
17 |
18 |
19 | args_cfg = parser.parse_args()
20 | assert args_cfg.config is not None
21 | with open(args_cfg.config, 'r') as f:
22 | config = yaml.load(f)
23 | for key in config:
24 | for k, v in config[key].items():
25 | setattr(args_cfg, k, v)
26 |
27 | return args_cfg
28 |
29 |
30 | cfg = get_parser()
31 | setattr(cfg, 'exp_path', os.path.join(cfg.save_dir, cfg.dataset, cfg.model_name, cfg.config.split('/')[-1][:-5]))
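# example invocation (config path is illustrative): python train.py --config config/hais_run1_scannet.yaml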
32 |
--------------------------------------------------------------------------------
/util/eval.py:
--------------------------------------------------------------------------------
1 | # Modified from ScanNet evaluation script: https://github.com/ScanNet/ScanNet/blob/master/BenchmarkScripts/3d_evaluation/evaluate_semantic_instance.py
2 |
3 | import os, sys, numpy as np
4 | import util.utils_3d as util_3d
5 | import util.utils as util
6 |
7 | # ---------- Label info ---------- #
8 | CLASS_LABELS = ['cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'shower curtain', 'toilet', 'sink', 'bathtub', 'otherfurniture']
9 | VALID_CLASS_IDS = np.array([3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])
10 | ID_TO_LABEL = {}
11 | LABEL_TO_ID = {}
12 | for i in range(len(VALID_CLASS_IDS)):
13 | LABEL_TO_ID[CLASS_LABELS[i]] = VALID_CLASS_IDS[i]
14 | ID_TO_LABEL[VALID_CLASS_IDS[i]] = CLASS_LABELS[i]
15 | # ---------- Evaluation params ---------- #
16 | # overlaps for evaluation
17 | OVERLAPS = np.append(np.arange(0.5,0.95,0.05), 0.25)
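# mean AP is computed over IoU thresholds 0.50-0.90 in steps of 0.05; AP@0.50 and AP@0.25 are reported separately (see compute_averages)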
18 | # minimum region size for evaluation [verts]
19 | MIN_REGION_SIZES = np.array( [ 100 ] )
20 | # distance thresholds [m]
21 | DISTANCE_THRESHES = np.array( [ float('inf') ] )
22 | # distance confidences
23 | DISTANCE_CONFS = np.array( [ -float('inf') ] )
24 |
25 |
26 | def evaluate_matches(matches):
27 | overlaps = OVERLAPS
28 | min_region_sizes = [MIN_REGION_SIZES[0]]
29 | dist_threshes = [DISTANCE_THRESHES[0]]
30 | dist_confs = [DISTANCE_CONFS[0]]
31 |
32 | # results: class x overlap
33 | ap = np.zeros((len(dist_threshes), len(CLASS_LABELS), len(overlaps)), np.float)
34 | for di, (min_region_size, distance_thresh, distance_conf) in enumerate(zip(min_region_sizes, dist_threshes, dist_confs)):
35 | for oi, overlap_th in enumerate(overlaps):
36 | pred_visited = {}
37 | for m in matches:
38 | for p in matches[m]['pred']:
39 | for label_name in CLASS_LABELS:
40 | for p in matches[m]['pred'][label_name]:
41 | if 'filename' in p:
42 | pred_visited[p['filename']] = False
43 | for li, label_name in enumerate(CLASS_LABELS):
44 | y_true = np.empty(0)
45 | y_score = np.empty(0)
46 | hard_false_negatives = 0
47 | has_gt = False
48 | has_pred = False
49 | for m in matches:
50 | pred_instances = matches[m]['pred'][label_name]
51 | gt_instances = matches[m]['gt'][label_name]
52 | # filter groups in ground truth
53 | gt_instances = [gt for gt in gt_instances if
54 | gt['instance_id'] >= 1000 and gt['vert_count'] >= min_region_size and gt['med_dist'] <= distance_thresh and gt['dist_conf'] >= distance_conf]
55 | if gt_instances:
56 | has_gt = True
57 | if pred_instances:
58 | has_pred = True
59 |
60 | cur_true = np.ones(len(gt_instances))
61 | cur_score = np.ones(len(gt_instances)) * (-float("inf"))
62 | cur_match = np.zeros(len(gt_instances), dtype=np.bool)
63 | # collect matches
64 | for (gti, gt) in enumerate(gt_instances):
65 | found_match = False
66 | num_pred = len(gt['matched_pred'])
67 | for pred in gt['matched_pred']:
68 | # greedy assignments
69 | if pred_visited[pred['filename']]:
70 | continue
71 | overlap = float(pred['intersection']) / (
72 | gt['vert_count'] + pred['vert_count'] - pred['intersection'])
73 | if overlap > overlap_th:
74 | confidence = pred['confidence']
75 | # if already have a prediction for this gt,
76 | # the prediction with the lower score is automatically a false positive
77 | if cur_match[gti]:
78 | max_score = max(cur_score[gti], confidence)
79 | min_score = min(cur_score[gti], confidence)
80 | cur_score[gti] = max_score
81 | # append false positive
82 | cur_true = np.append(cur_true, 0)
83 | cur_score = np.append(cur_score, min_score)
84 | cur_match = np.append(cur_match, True)
85 | # otherwise set score
86 | else:
87 | found_match = True
88 | cur_match[gti] = True
89 | cur_score[gti] = confidence
90 | pred_visited[pred['filename']] = True
91 | if not found_match:
92 | hard_false_negatives += 1
93 | # remove non-matched ground truth instances
94 | cur_true = cur_true[cur_match == True]
95 | cur_score = cur_score[cur_match == True]
96 |
97 | # collect non-matched predictions as false positive
98 | for pred in pred_instances:
99 | found_gt = False
100 | for gt in pred['matched_gt']:
101 | overlap = float(gt['intersection']) / (
102 | gt['vert_count'] + pred['vert_count'] - gt['intersection'])
103 | if overlap > overlap_th:
104 | found_gt = True
105 | break
106 | if not found_gt:
107 | num_ignore = pred['void_intersection']
108 | for gt in pred['matched_gt']:
109 | # group?
110 | if gt['instance_id'] < 1000:
111 | num_ignore += gt['intersection']
112 | # small ground truth instances
113 | if gt['vert_count'] < min_region_size or gt['med_dist'] > distance_thresh or gt['dist_conf'] < distance_conf:
114 | num_ignore += gt['intersection']
115 | proportion_ignore = float(num_ignore) / pred['vert_count']
116 | # if not ignored append false positive
117 | if proportion_ignore <= overlap_th:
118 | cur_true = np.append(cur_true, 0)
119 | confidence = pred["confidence"]
120 | cur_score = np.append(cur_score, confidence)
121 |
122 | # append to overall results
123 | y_true = np.append(y_true, cur_true)
124 | y_score = np.append(y_score, cur_score)
125 |
126 | # compute average precision
127 | if has_gt and has_pred:
128 | # compute precision recall curve first
129 |
130 | # sorting and cumsum
131 | score_arg_sort = np.argsort(y_score)
132 | y_score_sorted = y_score[score_arg_sort]
133 | y_true_sorted = y_true[score_arg_sort]
134 | y_true_sorted_cumsum = np.cumsum(y_true_sorted)
135 |
136 | # unique thresholds
137 | (thresholds, unique_indices) = np.unique(y_score_sorted, return_index=True)
138 | num_prec_recall = len(unique_indices) + 1
139 |
140 | # prepare precision recall
141 | num_examples = len(y_score_sorted)
142 | if(len(y_true_sorted_cumsum) == 0):
143 | num_true_examples = 0
144 | else:
145 | num_true_examples = y_true_sorted_cumsum[-1]
146 | precision = np.zeros(num_prec_recall)
147 | recall = np.zeros(num_prec_recall)
148 |
149 | # deal with the first point
150 | y_true_sorted_cumsum = np.append(y_true_sorted_cumsum, 0)
151 | # deal with remaining
152 | for idx_res, idx_scores in enumerate(unique_indices):
153 | cumsum = y_true_sorted_cumsum[idx_scores - 1]
154 | tp = num_true_examples - cumsum
155 | fp = num_examples - idx_scores - tp
156 | fn = cumsum + hard_false_negatives
157 | p = float(tp) / (tp + fp)
158 | r = float(tp) / (tp + fn)
159 | precision[idx_res] = p
160 | recall[idx_res] = r
161 |
162 | # first point in curve is artificial
163 | precision[-1] = 1.
164 | recall[-1] = 0.
165 |
166 | # compute average of precision-recall curve
167 | recall_for_conv = np.copy(recall)
168 | recall_for_conv = np.append(recall_for_conv[0], recall_for_conv)
169 | recall_for_conv = np.append(recall_for_conv, 0.)
170 |
171 | stepWidths = np.convolve(recall_for_conv, [-0.5, 0, 0.5], 'valid')
172 | # integrate is now simply a dot product
173 | ap_current = np.dot(precision, stepWidths)
174 |
175 | elif has_gt:
176 | ap_current = 0.0
177 | else:
178 | ap_current = float('nan')
179 | ap[di, li, oi] = ap_current
180 | return ap
181 |
182 |
183 | def compute_averages(aps):
184 | d_inf = 0
185 | o50 = np.where(np.isclose(OVERLAPS,0.5))
186 | o25 = np.where(np.isclose(OVERLAPS,0.25))
187 | oAllBut25 = np.where(np.logical_not(np.isclose(OVERLAPS,0.25)))
188 | avg_dict = {}
189 | #avg_dict['all_ap'] = np.nanmean(aps[ d_inf,:,: ])
190 | avg_dict['all_ap'] = np.nanmean(aps[ d_inf,:,oAllBut25])
191 | avg_dict['all_ap_50%'] = np.nanmean(aps[ d_inf,:,o50])
192 | avg_dict['all_ap_25%'] = np.nanmean(aps[ d_inf,:,o25])
193 | avg_dict["classes"] = {}
194 | for (li,label_name) in enumerate(CLASS_LABELS):
195 | avg_dict["classes"][label_name] = {}
196 | #avg_dict["classes"][label_name]["ap"] = np.average(aps[ d_inf,li, :])
197 | avg_dict["classes"][label_name]["ap"] = np.average(aps[ d_inf,li,oAllBut25])
198 | avg_dict["classes"][label_name]["ap50%"] = np.average(aps[ d_inf,li,o50])
199 | avg_dict["classes"][label_name]["ap25%"] = np.average(aps[ d_inf,li,o25])
200 | return avg_dict
201 |
202 |
203 | def assign_instances_for_scan(scene_name, pred_info, gt_file):
204 |
205 | try:
206 | gt_ids = util_3d.load_ids(gt_file)
207 | except Exception as e:
208 | util.print_error('unable to load ' + gt_file + ': ' + str(e))
209 |
210 | # get gt instances
211 | gt_instances = util_3d.get_instances(gt_ids, VALID_CLASS_IDS, CLASS_LABELS, ID_TO_LABEL)
212 |
213 |
214 | # gt instance statistics
215 | # for key, item in gt_instances.items():
216 | # print('key', key)
217 | # for _ins in item:
218 | # print(_ins['vert_count'])
219 |
220 | # associate
221 | gt2pred = gt_instances.copy()
222 | for label in gt2pred:
223 | for gt in gt2pred[label]:
224 | gt['matched_pred'] = []
225 | pred2gt = {}
226 | for label in CLASS_LABELS:
227 | pred2gt[label] = []
228 | num_pred_instances = 0
229 | # mask of void labels in the groundtruth
230 | bool_void = np.logical_not(np.in1d(gt_ids//1000, VALID_CLASS_IDS))
231 | # go thru all prediction masks
232 | nMask = pred_info['label_id'].shape[0]
233 | for i in range(nMask):
234 | label_id = int(pred_info['label_id'][i])
235 | conf = pred_info['conf'][i]
236 | if label_id not in ID_TO_LABEL:
237 | continue
238 | label_name = ID_TO_LABEL[label_id]
239 | # read the mask
240 | pred_mask = pred_info['mask'][i] # (N), long
241 | if len(pred_mask) != len(gt_ids):
242 | util.print_error('wrong number of lines in mask#%d: ' % (i) + '(%d) vs #mesh vertices (%d)' % (len(pred_mask), len(gt_ids)))
243 | # convert to binary
244 | pred_mask = np.not_equal(pred_mask, 0)
245 | num = np.count_nonzero(pred_mask)
246 | if num < MIN_REGION_SIZES[0]:
247 | continue # skip predictions below the minimum region size
248 |
249 | pred_instance = {}
250 | pred_instance['filename'] = '{}_{:03d}'.format(scene_name, num_pred_instances)
251 | pred_instance['pred_id'] = num_pred_instances
252 | pred_instance['label_id'] = label_id
253 | pred_instance['vert_count'] = num
254 | pred_instance['confidence'] = conf
255 | pred_instance['void_intersection'] = np.count_nonzero(np.logical_and(bool_void, pred_mask))
256 |
257 | # matched gt instances
258 | matched_gt = []
259 | # go thru all gt instances with matching label
260 | for (gt_num, gt_inst) in enumerate(gt2pred[label_name]):
261 | intersection = np.count_nonzero(np.logical_and(gt_ids == gt_inst['instance_id'], pred_mask))
262 | if intersection > 0:
263 | gt_copy = gt_inst.copy()
264 | pred_copy = pred_instance.copy()
265 | gt_copy['intersection'] = intersection
266 | pred_copy['intersection'] = intersection
267 | matched_gt.append(gt_copy)
268 | gt2pred[label_name][gt_num]['matched_pred'].append(pred_copy)
269 |
270 | pred_instance['matched_gt'] = matched_gt
271 | num_pred_instances += 1
272 | pred2gt[label_name].append(pred_instance)
273 |
274 | return gt2pred, pred2gt
275 |
276 |
277 | def print_results(avgs):
278 | from util.log import logger
279 | sep = ""
280 | col1 = ":"
281 | lineLen = 64
282 |
283 | logger.info("")
284 | logger.info("#" * lineLen)
285 | line = ""
286 | line += "{:<15}".format("what" ) + sep + col1
287 | line += "{:>15}".format("AP" ) + sep
288 | line += "{:>15}".format("AP_50%" ) + sep
289 | line += "{:>15}".format("AP_25%" ) + sep
290 | logger.info(line)
291 | logger.info("#" * lineLen)
292 |
293 | for (li,label_name) in enumerate(CLASS_LABELS):
294 | ap_avg = avgs["classes"][label_name]["ap"]
295 | ap_50o = avgs["classes"][label_name]["ap50%"]
296 | ap_25o = avgs["classes"][label_name]["ap25%"]
297 | line = "{:<15}".format(label_name) + sep + col1
298 | line += sep + "{:>15.3f}".format(ap_avg ) + sep
299 | line += sep + "{:>15.3f}".format(ap_50o ) + sep
300 | line += sep + "{:>15.3f}".format(ap_25o ) + sep
301 | logger.info(line)
302 |
303 | all_ap_avg = avgs["all_ap"]
304 | all_ap_50o = avgs["all_ap_50%"]
305 | all_ap_25o = avgs["all_ap_25%"]
306 |
307 | logger.info("-"*lineLen)
308 | line = "{:<15}".format("average") + sep + col1
309 | line += "{:>15.3f}".format(all_ap_avg) + sep
310 | line += "{:>15.3f}".format(all_ap_50o) + sep
311 | line += "{:>15.3f}".format(all_ap_25o) + sep
312 | logger.info(line)
313 | logger.info("")
--------------------------------------------------------------------------------
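The evaluation helpers above are meant to be driven per scene and then aggregated. A minimal sketch follows, assuming `scenes`, `pred_infos` and `gt_files` are hypothetical inputs prepared by the test script (each `pred_info` holding `label_id`, `conf` and per-point `mask` arrays); it is not the repository's own driver code.

# Hedged usage sketch for the ScanNet-style AP evaluation above.
matches = {}
for scene_name, pred_info, gt_file in zip(scenes, pred_infos, gt_files):   # hypothetical inputs
    gt2pred, pred2gt = assign_instances_for_scan(scene_name, pred_info, gt_file)
    matches[scene_name] = {'gt': gt2pred, 'pred': pred2gt}

ap_scores = evaluate_matches(matches)   # shape: (dist_thresh, class, overlap)
avgs = compute_averages(ap_scores)      # AP / AP_50% / AP_25% per class and overall
print_results(avgs)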
/util/log.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import time
5 |
6 | sys.path.append('../')
7 |
8 | from util.config import cfg
9 |
10 |
11 | def create_logger(log_file):
12 | logger = logging.getLogger(__name__)
13 | logger.setLevel(logging.DEBUG if cfg.local_rank == 0 else logging.ERROR)
14 |
15 | handler = logging.StreamHandler()
16 | log_format = '[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s'
17 | handler.setFormatter(logging.Formatter(log_format))
18 | logger.addHandler(handler)
19 |
20 | logging.basicConfig(level=logging.DEBUG, format=log_format, filename=log_file) # passing filename attaches a FileHandler that writes to log_file
21 | return logger
22 |
23 |
24 | if cfg.task == 'train':
25 | log_file = os.path.join(
26 | cfg.exp_path,
27 | 'train-{}.log'.format(time.strftime("%Y%m%d_%H%M%S", time.localtime()))
28 | )
29 | elif cfg.task == 'test':
30 | log_file = os.path.join(
31 | cfg.exp_path, 'result', 'epoch{}_nmst{}_scoret{}_npointt{}'.format(cfg.test_epoch, cfg.TEST_NMS_THRESH, cfg.TEST_SCORE_THRESH, cfg.TEST_NPOINT_THRESH),
32 | cfg.split, 'test-{}.log'.format(time.strftime("%Y%m%d_%H%M%S", time.localtime()))
33 | )
34 | if not os.path.exists(os.path.dirname(log_file)):
35 | os.makedirs(os.path.dirname(log_file), exist_ok=True)
36 |
37 |
38 | logger = create_logger(log_file)
39 | logger.info('************************ Start Logging ************************')
--------------------------------------------------------------------------------
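Because the logger is created once at import time, other modules only need the import below. A minimal sketch, assuming `cfg.task` and `cfg.exp_path` have already been populated by `util/config.py` before `util.log` is first imported:

# Hedged sketch: reuse the shared, module-level logger defined above.
from util.log import logger

logger.info('Start of epoch 1')   # written to the console handler and to the train-*/test-*.log file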
/util/utils.py:
--------------------------------------------------------------------------------
1 | import torch, glob, os, numpy as np
2 | import sys
3 | sys.path.append('../')
4 | from math import cos, pi
5 | from util.log import logger
6 |
7 | class AverageMeter(object):
8 | """Computes and stores the average and current value"""
9 | def __init__(self):
10 | self.reset()
11 |
12 | def reset(self):
13 | self.val = 0
14 | self.avg = 0
15 | self.sum = 0
16 | self.count = 0
17 |
18 | def update(self, val, n=1):
19 | self.val = val
20 | self.sum += val * n
21 | self.count += n
22 | self.avg = self.sum / self.count
23 |
24 |
25 | # Epoch counts from 0 to N-1
26 | def cosine_lr_after_step(optimizer, base_lr, epoch, step_epoch, total_epochs, clip=1e-6):
27 | if epoch < step_epoch:
28 | lr = base_lr
29 | else:
30 | lr = clip + 0.5 * (base_lr - clip) * \
31 | (1 + cos(pi * ( (epoch - step_epoch) / (total_epochs - step_epoch))))
32 |
33 | for param_group in optimizer.param_groups:
34 | param_group['lr'] = lr
35 |
36 |
37 |
38 | def intersectionAndUnion(output, target, K, ignore_index=255):
39 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
40 | assert (output.ndim in [1, 2, 3])
41 | assert output.shape == target.shape
42 | output = output.reshape(output.size).copy()
43 | target = target.reshape(target.size)
44 | output[np.where(target == ignore_index)[0]] = ignore_index
45 | intersection = output[np.where(output == target)[0]]
46 | area_intersection, _ = np.histogram(intersection, bins=np.arange(K+1)) # area_intersection: K, indicates the number of members in each class in intersection
47 | area_output, _ = np.histogram(output, bins=np.arange(K+1))
48 | area_target, _ = np.histogram(target, bins=np.arange(K+1))
49 | area_union = area_output + area_target - area_intersection
50 | return area_intersection, area_union, area_target
51 |
52 |
53 | def checkpoint_restore(cfg, model, optimizer, exp_path, exp_name, use_cuda=True, epoch=0, dist=False, f=''):
54 | if use_cuda:
55 | model.cpu()
56 | if not f:
57 | if epoch > 0:
58 | f = os.path.join(exp_path, exp_name + '-%09d'%epoch + '.pth')
59 | assert os.path.isfile(f)
60 | else:
61 | f = sorted(glob.glob(os.path.join(exp_path, exp_name + '-*.pth')))
62 | if len(f) > 0:
63 | f = f[-1]
64 | epoch = int(f[len(exp_path) + len(exp_name) + 2 : -4])
65 |
66 | if len(f) > 0:
67 | logger.info('Restore from ' + f)
68 | checkpoint = torch.load(f)
69 |
70 |
71 | if 'net' in checkpoint.keys() and 'optimizer' in checkpoint.keys():
72 | net_checkpoint = checkpoint['net']
73 | optimizer_checkpoint = checkpoint['optimizer']
74 |
75 | #load net
76 | for k, v in net_checkpoint.items():
77 | if 'module.' in k:
78 | net_checkpoint = {k[len('module.'):]: v for k, v in net_checkpoint.items()}
79 | break
80 | if dist:
81 | model.module.load_state_dict(net_checkpoint)
82 | else:
83 | model.load_state_dict(net_checkpoint)
84 |
85 | # load optimizer
86 | load_optimizer = getattr(cfg, 'load_optimizer', True)
87 | if optimizer is not None and load_optimizer:
88 | optimizer.load_state_dict(optimizer_checkpoint)
89 | for k in optimizer.state.keys():
90 | optimizer.state[k]['exp_avg'] = optimizer.state[k]['exp_avg'].cuda()
91 | optimizer.state[k]['exp_avg_sq'] = optimizer.state[k]['exp_avg_sq'].cuda()
92 |
93 | else: # legacy checkpoint format that stores only the model weights (no optimizer state)
94 | for k, v in checkpoint.items():
95 | if 'module.' in k:
96 | checkpoint = {k[len('module.'):]: v for k, v in checkpoint.items()}
97 | break
98 | if dist:
99 | model.module.load_state_dict(checkpoint)
100 | else:
101 | model.load_state_dict(checkpoint)
102 |
103 | if use_cuda:
104 | model.cuda()
105 |
106 | return epoch + 1
107 |
108 |
109 | def is_power2(num):
110 | return num != 0 and ((num & (num - 1)) == 0)
111 |
112 |
113 | def is_multiple(num, multiple):
114 | return num != 0 and num % multiple == 0
115 |
116 |
117 | def checkpoint_save(model, optimizer, exp_path, exp_name, epoch, save_freq=16, use_cuda=True):
118 | f = os.path.join(exp_path, exp_name + '-%09d'%epoch + '.pth')
119 | logger.info('Saving ' + f)
120 | model.cpu()
121 |
122 | checkpoint = {'net': model.state_dict(), 'optimizer': optimizer.state_dict()}
123 | torch.save(checkpoint, f)
124 |
125 | if use_cuda:
126 | model.cuda()
127 |
128 | # remove the previous checkpoint unless its epoch is a power of 2 or a multiple of save_freq, to save disk space
129 | epoch = epoch - 1
130 | f = os.path.join(exp_path, exp_name + '-%09d'%epoch + '.pth')
131 | if os.path.isfile(f):
132 | if not is_multiple(epoch, save_freq) and not is_power2(epoch):
133 | os.remove(f)
134 |
135 |
136 | def load_model_param(model, pretrained_dict, prefix=""):
137 | # assume every parameter of the model exists in pretrained_dict, though its name may carry an extra prefix
138 | # e.g. model_dict: "0.conv.weight"  vs.  pretrained_dict: "FC_layer.0.conv.weight"
139 | model_dict = model.state_dict()
140 | len_prefix = 0 if len(prefix) == 0 else len(prefix) + 1
141 | pretrained_dict_filter = {k[len_prefix:]: v for k, v in pretrained_dict.items() if k[len_prefix:] in model_dict and prefix in k}
142 | assert len(pretrained_dict_filter) > 0
143 | model_dict.update(pretrained_dict_filter)
144 | model.load_state_dict(model_dict)
145 | return len(pretrained_dict_filter), len(model_dict)
146 |
147 |
148 | def write_obj(points, colors, out_filename):
149 | N = points.shape[0]
150 | fout = open(out_filename, 'w')
151 | for i in range(N):
152 | c = colors[i]
153 | fout.write('v %f %f %f %d %d %d\n' % (points[i,0],points[i,1],points[i,2],c[0],c[1],c[2]))
154 | fout.close()
155 |
156 |
157 | def get_batch_offsets(batch_idxs, bs):
158 | '''
159 | :param batch_idxs: (N), int
160 | :param bs: int
161 | :return: batch_offsets: (bs + 1)
162 | '''
163 | batch_offsets = torch.zeros(bs + 1).int().cuda()
164 | for i in range(bs):
165 | batch_offsets[i + 1] = batch_offsets[i] + (batch_idxs == i).sum()
166 | assert batch_offsets[-1] == batch_idxs.shape[0]
167 | return batch_offsets
168 |
169 |
170 | def print_error(message, user_fault=False):
171 | sys.stderr.write('ERROR: ' + str(message) + '\n')
172 | if user_fault:
173 | sys.exit(2)
174 | sys.exit(-1)
175 |
176 |
177 |
178 |
--------------------------------------------------------------------------------
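The helpers above are intended to be composed in a training loop: `AverageMeter` tracks running losses, `cosine_lr_after_step` schedules the learning rate, and `checkpoint_save` / `checkpoint_restore` persist and resume training state. A minimal sketch follows; `model`, `optimizer`, `train_loader`, `compute_loss` and the `cfg` fields are illustrative assumptions, not the repository's actual train.py.

# Hedged training-loop sketch using the utilities defined above.
from util.log import logger
from util.utils import AverageMeter, cosine_lr_after_step, checkpoint_restore, checkpoint_save

exp_name = 'hais_run1_scannet'                      # illustrative experiment name
loss_meter = AverageMeter()
start_epoch = checkpoint_restore(cfg, model, optimizer, cfg.exp_path, exp_name)   # resumes if a *.pth exists

for epoch in range(start_epoch, cfg.epochs + 1):
    cosine_lr_after_step(optimizer, cfg.lr, epoch - 1, cfg.step_epoch, cfg.epochs)   # helper expects 0-based epochs
    for batch in train_loader:                      # hypothetical dataloader
        loss = compute_loss(model, batch)           # placeholder for the real forward pass / criterion
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_meter.update(loss.item())
    logger.info('epoch {}: avg loss {:.4f}'.format(epoch, loss_meter.avg))
    checkpoint_save(model, optimizer, cfg.exp_path, exp_name, epoch, cfg.save_freq)
    loss_meter.reset()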
/util/utils_3d.py:
--------------------------------------------------------------------------------
1 | # ScanNet util_3d: https://github.com/ScanNet/ScanNet/blob/master/BenchmarkScripts/util_3d.py
2 |
3 | import json, numpy as np
4 |
5 | def load_ids(filename):
6 | ids = open(filename).read().splitlines()
7 | ids = np.array(ids, dtype=np.int64)
8 | return ids
9 |
10 |
11 | # ------------ Instance Utils ------------ #
12 |
13 | class Instance(object):
14 | instance_id = 0
15 | label_id = 0
16 | vert_count = 0
17 | med_dist = -1
18 | dist_conf = 0.0
19 |
20 | def __init__(self, mesh_vert_instances, instance_id):
21 | if (instance_id == -1):
22 | return
23 | self.instance_id = int(instance_id)
24 | self.label_id = int(self.get_label_id(instance_id))
25 | self.vert_count = int(self.get_instance_verts(mesh_vert_instances, instance_id))
26 |
27 | def get_label_id(self, instance_id):
28 | return int(instance_id // 1000)
29 |
30 | def get_instance_verts(self, mesh_vert_instances, instance_id):
31 | return (mesh_vert_instances == instance_id).sum()
32 |
33 | def to_json(self):
34 | return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4)
35 |
36 | def to_dict(self):
37 | d = {}
38 | d["instance_id"] = self.instance_id
39 | d["label_id"] = self.label_id
40 | d["vert_count"] = self.vert_count
41 | d["med_dist"] = self.med_dist
42 | d["dist_conf"] = self.dist_conf
43 | return d
44 |
45 | def from_json(self, data):
46 | self.instance_id = int(data["instance_id"])
47 | self.label_id = int(data["label_id"])
48 | self.vert_count = int(data["vert_count"])
49 | if ("med_dist" in data):
50 | self.med_dist = float(data["med_dist"])
51 | self.dist_conf = float(data["dist_conf"])
52 |
53 | def __str__(self):
54 | return "("+str(self.instance_id)+")"
55 |
56 |
57 | def get_instances(ids, class_ids, class_labels, id2label):
58 | instances = {}
59 | for label in class_labels:
60 | instances[label] = []
61 | instance_ids = np.unique(ids)
62 |
63 | for id in instance_ids:
64 | if id == 0:
65 | continue
66 | inst = Instance(ids, id)
67 | if inst.label_id in class_ids:
68 | instances[id2label[inst.label_id]].append(inst.to_dict())
69 | return instances
70 |
71 |
72 |
73 |
74 |
75 |
--------------------------------------------------------------------------------
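As used by `util/eval.py`, ground-truth instance ids are encoded as `label_id * 1000 + instance_index`, so `get_instances` can split a flat per-vertex id array into per-class instance records. A small hedged example, with a two-class subset standing in for the full ScanNet tables defined in `util/eval.py` and a hypothetical ground-truth file name:

# Hedged sketch: parse a ground-truth id file into per-class instance dictionaries.
import numpy as np
from util.utils_3d import load_ids, get_instances

VALID_CLASS_IDS = np.array([5, 39])                    # illustrative subset of the ScanNet class ids
CLASS_LABELS = ['chair', 'otherfurniture']
ID_TO_LABEL = {5: 'chair', 39: 'otherfurniture'}

gt_ids = load_ids('scene0000_00.txt')                  # hypothetical file: one encoded id per mesh vertex
instances = get_instances(gt_ids, VALID_CLASS_IDS, CLASS_LABELS, ID_TO_LABEL)
for label, insts in instances.items():
    print(label, [(i['instance_id'], i['vert_count']) for i in insts])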
/visualize_open3d.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os, glob, argparse
3 | import torch
4 | from operator import itemgetter
5 | import cv2
6 | import open3d as o3d
7 |
8 |
9 | COLOR_DETECTRON2 = np.array(
10 | [
11 | 0.000, 0.447, 0.741,
12 | 0.850, 0.325, 0.098,
13 | 0.929, 0.694, 0.125,
14 | 0.494, 0.184, 0.556,
15 | 0.466, 0.674, 0.188,
16 | 0.301, 0.745, 0.933,
17 | 0.635, 0.078, 0.184,
18 | # 0.300, 0.300, 0.300,
19 | 0.600, 0.600, 0.600,
20 | 1.000, 0.000, 0.000,
21 | 1.000, 0.500, 0.000,
22 | 0.749, 0.749, 0.000,
23 | 0.000, 1.000, 0.000,
24 | 0.000, 0.000, 1.000,
25 | 0.667, 0.000, 1.000,
26 | 0.333, 0.333, 0.000,
27 | 0.333, 0.667, 0.000,
28 | 0.333, 1.000, 0.000,
29 | 0.667, 0.333, 0.000,
30 | 0.667, 0.667, 0.000,
31 | 0.667, 1.000, 0.000,
32 | 1.000, 0.333, 0.000,
33 | 1.000, 0.667, 0.000,
34 | 1.000, 1.000, 0.000,
35 | 0.000, 0.333, 0.500,
36 | 0.000, 0.667, 0.500,
37 | 0.000, 1.000, 0.500,
38 | 0.333, 0.000, 0.500,
39 | 0.333, 0.333, 0.500,
40 | 0.333, 0.667, 0.500,
41 | 0.333, 1.000, 0.500,
42 | 0.667, 0.000, 0.500,
43 | 0.667, 0.333, 0.500,
44 | 0.667, 0.667, 0.500,
45 | 0.667, 1.000, 0.500,
46 | 1.000, 0.000, 0.500,
47 | 1.000, 0.333, 0.500,
48 | 1.000, 0.667, 0.500,
49 | 1.000, 1.000, 0.500,
50 | 0.000, 0.333, 1.000,
51 | 0.000, 0.667, 1.000,
52 | 0.000, 1.000, 1.000,
53 | 0.333, 0.000, 1.000,
54 | 0.333, 0.333, 1.000,
55 | 0.333, 0.667, 1.000,
56 | 0.333, 1.000, 1.000,
57 | 0.667, 0.000, 1.000,
58 | 0.667, 0.333, 1.000,
59 | 0.667, 0.667, 1.000,
60 | 0.667, 1.000, 1.000,
61 | 1.000, 0.000, 1.000,
62 | 1.000, 0.333, 1.000,
63 | 1.000, 0.667, 1.000,
64 | # 0.333, 0.000, 0.000,
65 | 0.500, 0.000, 0.000,
66 | 0.667, 0.000, 0.000,
67 | 0.833, 0.000, 0.000,
68 | 1.000, 0.000, 0.000,
69 | 0.000, 0.167, 0.000,
70 | # 0.000, 0.333, 0.000,
71 | 0.000, 0.500, 0.000,
72 | 0.000, 0.667, 0.000,
73 | 0.000, 0.833, 0.000,
74 | 0.000, 1.000, 0.000,
75 | 0.000, 0.000, 0.167,
76 | # 0.000, 0.000, 0.333,
77 | 0.000, 0.000, 0.500,
78 | 0.000, 0.000, 0.667,
79 | 0.000, 0.000, 0.833,
80 | 0.000, 0.000, 1.000,
81 | # 0.000, 0.000, 0.000,
82 | 0.143, 0.143, 0.143,
83 | 0.857, 0.857, 0.857,
84 | # 1.000, 1.000, 1.000
85 | ]).astype(np.float32).reshape(-1, 3) * 255
86 |
87 | SEMANTIC_IDXS = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])
88 | SEMANTIC_NAMES = np.array(['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 'counter',
89 | 'desk', 'curtain', 'refridgerator', 'shower curtain', 'toilet', 'sink', 'bathtub', 'otherfurniture'])
90 | CLASS_COLOR = {
91 | 'unannotated': [0, 0, 0],
92 | 'floor': [143, 223, 142],
93 | 'wall': [171, 198, 230],
94 | 'cabinet': [0, 120, 177],
95 | 'bed': [255, 188, 126],
96 | 'chair': [189, 189, 57],
97 | 'sofa': [144, 86, 76],
98 | 'table': [255, 152, 153],
99 | 'door': [222, 40, 47],
100 | 'window': [197, 176, 212],
101 | 'bookshelf': [150, 103, 185],
102 | 'picture': [200, 156, 149],
103 | 'counter': [0, 190, 206],
104 | 'desk': [252, 183, 210],
105 | 'curtain': [219, 219, 146],
106 | 'refridgerator': [255, 127, 43],
107 | 'bathtub': [234, 119, 192],
108 | 'shower curtain': [150, 218, 228],
109 | 'toilet': [0, 160, 55],
110 | 'sink': [110, 128, 143],
111 | 'otherfurniture': [80, 83, 160]
112 | }
113 | SEMANTIC_IDX2NAME = {1: 'wall', 2: 'floor', 3: 'cabinet', 4: 'bed', 5: 'chair', 6: 'sofa', 7: 'table', 8: 'door', 9: 'window', 10: 'bookshelf', 11: 'picture',
114 | 12: 'counter', 14: 'desk', 16: 'curtain', 24: 'refridgerator', 28: 'shower curtain', 33: 'toilet', 34: 'sink', 36: 'bathtub', 39: 'otherfurniture'}
115 |
116 |
117 | def get_coords_color(opt):
118 | input_file = os.path.join(opt.data_path, opt.data_split, opt.room_name + '_inst_nostuff.pth')
119 | assert os.path.isfile(input_file), 'File does not exist - {}.'.format(input_file)
120 | if opt.data_split == 'test':
121 | xyz, rgb = torch.load(input_file)
122 | else:
123 | xyz, rgb, label, inst_label = torch.load(input_file)
124 |
125 | rgb = (rgb + 1) * 127.5
126 |
127 | if (opt.task == 'semantic_gt'):
128 | assert opt.data_split != 'test'
129 | label = label.astype(int)
130 | label_rgb = np.zeros(rgb.shape)
131 | label_rgb[label >= 0] = np.array(itemgetter(*SEMANTIC_NAMES[label[label >= 0]])(CLASS_COLOR))
132 | rgb = label_rgb
133 |
134 | elif (opt.task == 'semantic_pred'):
135 | assert opt.data_split != 'train'
136 | semantic_file = os.path.join(opt.prediction_path, opt.data_split, 'semantic', opt.room_name + '.npy')
137 | assert os.path.isfile(semantic_file), 'No semantic result - {}.'.format(semantic_file)
138 | label_pred = np.load(semantic_file).astype(int) # 0~19
139 | label_pred_rgb = np.array(itemgetter(*SEMANTIC_NAMES[label_pred])(CLASS_COLOR))
140 | rgb = label_pred_rgb
141 |
142 | elif (opt.task == 'offset_semantic_pred'):
143 | assert opt.data_split != 'train'
144 | semantic_file = os.path.join(opt.prediction_path, opt.data_split, 'semantic', opt.room_name + '.npy')
145 | assert os.path.isfile(semantic_file), 'No semantic result - {}.'.format(semantic_file)
146 | label_pred = np.load(semantic_file).astype(int) # 0~19
147 | label_pred_rgb = np.array(itemgetter(*SEMANTIC_NAMES[label_pred])(CLASS_COLOR))
148 | rgb = label_pred_rgb
149 |
150 | offset_file = os.path.join(opt.prediction_path, opt.data_split, 'coords_offsets', opt.room_name + '.npy')
151 | assert os.path.isfile(offset_file), 'No offset result - {}.'.format(offset_file)
152 | offset_coords = np.load(offset_file)
153 | xyz = offset_coords[:, :3] + offset_coords[:, 3:]
154 |
155 | # assign colors in a consistent order, sorted by instance point count
156 | elif (opt.task == 'instance_gt'):
157 | assert opt.data_split != 'test'
158 | inst_label = inst_label.astype(int)
159 | print("Instance number: {}".format(inst_label.max() + 1))
160 | inst_label_rgb = np.zeros(rgb.shape)
161 | object_idx = (inst_label >= 0)
162 | ins_num = inst_label.max() + 1
163 | ins_pointnum = np.zeros(ins_num)
164 | for _ins_id in range(ins_num):
165 | ins_pointnum[_ins_id] = (inst_label == _ins_id).sum()
166 | sort_idx = np.argsort(ins_pointnum)[::-1]
167 | for _sort_id in range(ins_num):
168 | inst_label_rgb[inst_label == sort_idx[_sort_id] ] = COLOR_DETECTRON2[_sort_id % len(COLOR_DETECTRON2)]
169 | rgb = inst_label_rgb
170 |
171 | # assign colors in a consistent order, sorted by instance point count
172 | elif (opt.task == 'instance_pred'):
173 | assert opt.data_split != 'train'
174 | instance_file = os.path.join(opt.prediction_path, opt.data_split, opt.room_name + '.txt')
175 | assert os.path.isfile(instance_file), 'No instance result - {}.'.format(instance_file)
176 | f = open(instance_file, 'r')
177 | masks = f.readlines()
178 | masks = [mask.rstrip().split() for mask in masks]
179 | inst_label_pred_rgb = np.zeros(rgb.shape) # np.ones(rgb.shape) * 255 #
180 |
181 | ins_num = len(masks)
182 | ins_pointnum = np.zeros(ins_num)
183 | inst_label = -100 * np.ones(rgb.shape[0]).astype(int)
184 |
185 | for i in range(len(masks) - 1, -1, -1):
186 | mask_path = os.path.join(opt.prediction_path, opt.data_split, masks[i][0])
187 | assert os.path.isfile(mask_path), mask_path
188 | if (float(masks[i][2]) < 0.09):
189 | continue
190 | mask = np.loadtxt(mask_path).astype(int)
191 | print('{} {}: {} pointnum: {}'.format(i, masks[i], SEMANTIC_IDX2NAME[int(masks[i][1])], mask.sum()))
192 | ins_pointnum[i] = mask.sum()
193 | inst_label[mask == 1] = i
194 | sort_idx = np.argsort(ins_pointnum)[::-1]
195 | for _sort_id in range(ins_num):
196 | inst_label_pred_rgb[inst_label == sort_idx[_sort_id] ] = COLOR_DETECTRON2[_sort_id % len(COLOR_DETECTRON2)]
197 | rgb = inst_label_pred_rgb
198 |
199 |
200 | if opt.data_split != 'test':
201 | sem_valid = (label != -100)
202 | xyz = xyz[sem_valid]
203 | rgb = rgb[sem_valid]
204 |
205 | return xyz, rgb
206 |
207 |
208 | if __name__ == '__main__':
209 | parser = argparse.ArgumentParser()
210 | parser.add_argument('--data_path', help='path to the dataset files')
211 | parser.add_argument('--prediction_path', help='path to the prediction results')
212 | parser.add_argument('--data_split', help='train / val / test', default='val')
213 | parser.add_argument('--room_name', help='room_name', default='scene0146_01')
214 | parser.add_argument('--task', help='input / semantic_gt / semantic_pred / offset_semantic_pred / instance_gt / instance_pred', default='input')
215 | opt = parser.parse_args()
216 |
217 |
218 |
219 | xyz, rgb = get_coords_color(opt)
220 | points = xyz[:, :3]
221 | colors = rgb / 255
222 |
223 | pc = o3d.geometry.PointCloud()
224 | pc.points = o3d.utility.Vector3dVector(points)
225 | pc.colors = o3d.utility.Vector3dVector(colors)
226 |
227 | vis = o3d.visualization.Visualizer()
228 | vis.create_window()
229 | vis.add_geometry(pc)
230 | vis.get_render_option().point_size = 1.5
231 | vis.run()
232 | vis.destroy_window()
233 |
234 |
235 |
236 |
237 |
238 |
239 |
240 |
241 |
--------------------------------------------------------------------------------
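The visualizer is normally launched from the command line, but `get_coords_color` can also be driven programmatically. A hedged sketch, with placeholder paths that depend on where your dataset and prediction results live:

# Hedged sketch: build the colored point cloud without going through argparse.
from argparse import Namespace
from visualize_open3d import get_coords_color

opt = Namespace(data_path='dataset/scannetv2',           # placeholder dataset root
                prediction_path='exp/result',            # placeholder prediction directory
                data_split='val',
                room_name='scene0146_01',
                task='instance_pred')
xyz, rgb = get_coords_color(opt)
print(xyz.shape, rgb.shape)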