├── .gitignore ├── LICENSE ├── README.md ├── config └── default.yaml ├── data └── scannetv2 │ ├── .gitignore │ ├── README.md │ └── prepare_data.sh ├── docs ├── benchmark.png └── overview.png ├── requirements.txt ├── sstnet ├── __init__.py ├── data │ ├── __init__.py │ ├── scannetv2.py │ └── utils.py ├── lib │ ├── .gitignore │ ├── cluster │ │ ├── __init__.py │ │ ├── _hierarchy.c │ │ ├── _hierarchy.pyx │ │ ├── _hierarchy_distance_update.pxi │ │ ├── _optimal_leaf_ordering.c │ │ ├── _optimal_leaf_ordering.pyx │ │ ├── _structures.pxi │ │ ├── _vq.c │ │ ├── _vq.pyx │ │ ├── hierarchy.py │ │ ├── setup.py │ │ ├── tests │ │ │ ├── __init__.py │ │ │ ├── hierarchy_test_data.py │ │ │ ├── test_hierarchy.py │ │ │ └── test_vq.py │ │ └── vq.py │ ├── htree │ │ ├── setup.py │ │ └── src │ │ │ ├── api.cpp │ │ │ ├── tree.cpp │ │ │ └── tree.h │ └── pointgroup_ops │ │ ├── pointgroup_ops │ │ ├── __init__.py │ │ ├── pointgroup_ops.py │ │ └── src │ │ │ ├── bfs_cluster │ │ │ ├── bfs_cluster.cpp │ │ │ ├── bfs_cluster.cu │ │ │ └── bfs_cluster.h │ │ │ ├── cuda.cu │ │ │ ├── cuda_utils.h │ │ │ ├── datatype │ │ │ ├── datatype.cpp │ │ │ └── datatype.h │ │ │ ├── get_iou │ │ │ ├── get_iou.cpp │ │ │ ├── get_iou.cu │ │ │ └── get_iou.h │ │ │ ├── pointgroup_ops.cpp │ │ │ ├── pointgroup_ops.h │ │ │ ├── pointgroup_ops_api.cpp │ │ │ ├── roipool │ │ │ ├── roipool.cpp │ │ │ ├── roipool.cu │ │ │ └── roipool.h │ │ │ ├── sec_mean │ │ │ ├── sec_mean.cpp │ │ │ ├── sec_mean.cu │ │ │ └── sec_mean.h │ │ │ └── voxelize │ │ │ ├── voxelize.cpp │ │ │ ├── voxelize.cu │ │ │ └── voxelize.h │ │ └── setup.py └── model │ ├── __init__.py │ ├── func_helper.py │ ├── losses.py │ └── sstnet.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | log 2 | __pycache__ 3 | pretrain 4 | results 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 
MIT License 2 | 3 | Copyright (c) 2020 lzhhnb 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SSTNet 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/instance-segmentation-in-3d-scenes-using/3d-instance-segmentation-on-scannetv2)](https://paperswithcode.com/sota/3d-instance-segmentation-on-scannetv2?p=instance-segmentation-in-3d-scenes-using) 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/instance-segmentation-in-3d-scenes-using/3d-instance-segmentation-on-s3dis)](https://paperswithcode.com/sota/3d-instance-segmentation-on-s3dis?p=instance-segmentation-in-3d-scenes-using) 5 | 6 | ![overview](docs/overview.png) 7 | **Instance Segmentation in 3D Scenes using Semantic Superpoint Tree Networks(ICCV2021)** by [Zhihao Liang](https://lzhnb.github.io/), Zhihao Li, Songcen Xu, Mingkui Tan, [Kui Jia*](http://kuijia.site/). (\*) Corresponding author. 8 | [[arxiv]](https://arxiv.org/abs/2108.07478) 9 | [[ICCV2021]](https://openaccess.thecvf.com/content/ICCV2021/papers/Liang_Instance_Segmentation_in_3D_Scenes_Using_Semantic_Superpoint_Tree_Networks_ICCV_2021_paper.pdf) 10 | 11 | 12 | ## Introduction 13 | Instance segmentation in 3D scenes is fundamental in many applications of scene understanding. It is yet challenging due to the compound factors of data irregularity and uncertainty in the numbers of instances. State-of-the-art methods largely rely on a general pipeline that first learns point-wise features discriminative at semantic and instance levels, followed by a separate step of point grouping for proposing object instances. While promising, they have the shortcomings that (1) the second step is not supervised by the main objective of instance segmentation, and (2) their point-wise feature learning and grouping are less effective to deal with data irregularities, possibly resulting in fragmented segmentations. 
To address these issues, we propose in this work an end-to-end solution of Semantic Superpoint Tree Network (SSTNet) for proposing object instances from scene points. Key in SSTNet is an intermediate, semantic superpoint tree (SST), which is constructed based on the learned semantic features of superpoints, and which will be traversed and split at intermediate tree nodes for proposals of object instances. We also design in SSTNet a refinement module, termed CliqueNet, to prune superpoints that may be wrongly grouped into instance proposals. 14 | 15 | ## Installation 16 | 17 | ### Requirements 18 | * Python 3.8.5 19 | * PyTorch 1.7.1 20 | * torchvision 0.8.2 21 | * CUDA 11.1 22 | 23 | Then install the requirements: 24 | ```sh 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | ### SparseConv 29 | To install SparseConv, please refer to [PointGroup's spconv](https://github.com/llijiang/spconv). 30 | 31 | ### Extension 32 | This project is based on our Gorilla-Lab deep learning toolkit `gorilla-core` and our 3D toolkit `gorilla-3d`. 33 | 34 | For `gorilla-core`, you can install it by running: 35 | ```sh 36 | pip install gorilla-core==0.2.7.6 37 | ``` 38 | or by building from source (recommended): 39 | ```sh 40 | git clone https://github.com/Gorilla-Lab-SCUT/gorilla-core 41 | cd gorilla-core 42 | python setup.py install # or: python setup.py develop 43 | ``` 44 | 45 | For `gorilla-3d`, you should install it by building from source: 46 | ```sh 47 | git clone https://github.com/Gorilla-Lab-SCUT/gorilla-3d 48 | cd gorilla-3d 49 | python setup.py develop 50 | ``` 51 | > Tip: with newer versions of `torch`, `BuildExtension` may fail when it uses ninja as the build system. 
If you run into this problem, change `cmdclass={"build_ext": BuildExtension}` to `cmdclass={"build_ext": BuildExtension.with_options(use_ninja=False)}`. 52 | 53 | This project also needs several other extensions. We use `pointgroup_ops` for voxelization and `segmentator` to generate superpoints for ScanNet scenes. We use `htree` to construct the **Semantic Superpoint Tree**, and the **hierarchical node-inheriting relations** are realized with a modified version of the `cluster.hierarchy.linkage` function from `scipy`. 54 | 55 | - For `pointgroup_ops`, we modified the package from `PointGroup` so that its function calls no longer depend on absolute paths. You can install it by running: 56 | ```sh 57 | conda install -c bioconda google-sparsehash 58 | cd $PROJECT_ROOT$ 59 | cd sstnet/lib/pointgroup_ops 60 | python setup.py develop 61 | ``` 62 | Then, you can call the function like: 63 | ```python 64 | import pointgroup_ops 65 | pointgroup_ops.voxelization 66 | >>> 67 | ``` 68 | - For `htree`, it can be seen as a supplement to the `treelib` Python package, and I abstract the **SST** through both of them. You can install it by running: 69 | ```sh 70 | cd $PROJECT_ROOT$ 71 | cd sstnet/lib/htree 72 | python setup.py install 73 | ``` 74 | > Tip: The interaction between this piece of code and `treelib` is a bit messy. I have not had time to organize it, so it may be hard to follow. I am sorry for this, and improvements are welcome. 75 | - For `cluster`, it is originally a sub-module of `scipy`; the `SST` construction relies on `cluster.hierarchy.linkage`. However, the original implementation does not consider the sizes of the clustering nodes (each superpoint contains a different number of points). We therefore modified this function to take these sizes into account. 
To use it, install it by running: 76 | ```sh 77 | cd $PROJECT_ROOT$ 78 | cd sstnet/lib/cluster 79 | python setup.py install 80 | ``` 81 | - For `segmentator`, please refer to [this repository](https://github.com/Karbo123/segmentator) for installation. (It wraps the [segmentator](https://github.com/ScanNet/ScanNet/tree/master/Segmentator) from ScanNet.) 82 | 83 | ## Data Preparation 84 | Please refer to the `README.md` in `data/scannetv2` for data preparation. 85 | 86 | ## Training 87 | ``` 88 | CUDA_VISIBLE_DEVICES=0 python train.py --config config/default.yaml 89 | ``` 90 | You can start a tensorboard session by 91 | ``` 92 | tensorboard --logdir=./log --port=6666 93 | ``` 94 | > Tip: For the logging directory, please refer to the implementation of the function `gorilla.collect_logger`. 95 | 96 | ## Inference and Evaluation 97 | ``` 98 | CUDA_VISIBLE_DEVICES=0 python test.py --config config/default.yaml --pretrain pretrain.pth --eval 99 | ``` 100 | - `--split` selects the dataset split to evaluate. 101 | - `--save` saves the instance segmentation results. 102 | - `--eval` evaluates the segmentation results. 103 | - `--semantic` evaluates semantic segmentation only (takes effect in `--eval` mode). 104 | - `--log-file` sets the log file in which the evaluation results are saved (for the default, please refer to `gorilla.collect_logger`). 105 | - `--visual` saves a visualization of the instance segmentation. (It will be mentioned in the next section.) 106 | 107 | ## Results on ScanNet Benchmark 108 | Ranked 1st on the ScanNet benchmark 109 | ![benchmark](docs/benchmark.png) 110 | 111 | ## Pretrained 112 | We provide a pretrained model trained on the ScanNet(v2) dataset. 
113 | [[Google Drive]](https://drive.google.com/file/d/1UYT5QOjQQYB8QFzZi4cNyncCPNTSVDnu/view?usp=sharing) [[Baidu Cloud]](https://pan.baidu.com/s/19tNxhwO5UkGn7C3E8asMsQ) (extraction code: f3az) 114 | Its performance on the ScanNet(v2) validation set is 49.4/64.9/74.4 in terms of mAP/mAP50/mAP25. 115 | 116 | ## Acknowledgement 117 | This repo is built upon several repos, e.g., [PointGroup](https://github.com/dvlab-research/PointGroupt), [spconv](https://github.com/traveller59/spconv) and [ScanNet](https://github.com/ScanNet/ScanNet). 118 | 119 | 120 | ## Contact 121 | If you have any questions or suggestions about this repo or paper, please feel free to contact me via an issue or email (eezhihaoliang@mail.scut.edu.cn). 122 | 123 | ## TODO 124 | - [ ] Distributed training (not verified) 125 | - [ ] Batch inference 126 | - [x] Multi-processing for getting superpoints 127 | 128 | ## Citation 129 | If you find this work useful in your research, please cite: 130 | ``` 131 | @inproceedings{liang2021instance, 132 | title={Instance Segmentation in 3D Scenes using Semantic Superpoint Tree Networks}, 133 | author={Liang, Zhihao and Li, Zhihao and Xu, Songcen and Tan, Mingkui and Jia, Kui}, 134 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 135 | pages={2783--2792}, 136 | year={2021} 137 | } 138 | ``` 139 | 140 | 141 | -------------------------------------------------------------------------------- /config/default.yaml: -------------------------------------------------------------------------------- 1 | task: train # train, test 2 | seed: 123 3 | 4 | data: 5 | ignore_label: -100 6 | mode: 4 # 4=mean 7 | 8 | # train mode 9 | epochs: 512 10 | save_freq: 8 # also eval_freq 11 | 12 | # test mode 13 | test_seed: 567 14 | test_workers: 8 # data loader workers 15 | 16 | TEST_NMS_THRESH: 0.3 17 | TEST_SCORE_THRESH: 0.00 18 | TEST_NPOINT_THRESH: 100 19 | 20 | split: val 21 | test_epoch: 512 22 | 23 | dataloader: 24 | batch_size: 4 25 | num_workers: 8 # data 
loader workers 26 | 27 | dataset: 28 | type: ScanNetV2Inst 29 | data_root: data/scannetv2 30 | full_scale: [128, 512] 31 | scale: 50 # voxel_size = 1 / scale, scale 50(2cm) 32 | max_npoint: 250000 33 | task: train 34 | with_elastic: False 35 | prefetch_superpoints: False 36 | 37 | model: 38 | type: SSTNet 39 | input_channel: 3 40 | use_coords: True 41 | blocks: 5 42 | block_reps: 2 43 | media: 32 # 16 or 32 44 | classes: 20 45 | score_scale: 50 # the minimal voxel size is 2cm 46 | score_fullscale: 14 47 | score_mode: 4 # mean 48 | detach: True 49 | affinity_weight: [1.0, 1.0] 50 | with_refine: False 51 | fusion_epochs: 128 52 | score_epochs: 160 53 | fix_module: [] 54 | 55 | loss: 56 | type: SSTLoss 57 | ignore_label: -100 58 | fusion_epochs: 128 59 | score_epochs: 160 60 | bg_thresh: 0.25 61 | fg_thresh: 0.75 62 | semantic_dice: True 63 | loss_weight: [1.0, 1.0, 1.0, 1.0, 1.0] # semantic_loss, offset_norm_loss, offset_dir_loss, score_loss 64 | 65 | # optimizer 66 | optimizer: 67 | lr: 0.001 68 | # type: Adam 69 | type: AdamW 70 | weight_decay: 0.0001 71 | # amsgrad: False 72 | 73 | # lr_scheduler 74 | lr_scheduler: 75 | type: PolyLR 76 | # max_iters: 153600 77 | # max_iters: 614912 78 | max_iters: 512 79 | power: 0.9 80 | constant_ending: 0.0 81 | 82 | -------------------------------------------------------------------------------- /data/scannetv2/.gitignore: -------------------------------------------------------------------------------- 1 | scans* 2 | train* 3 | val* 4 | test* 5 | -------------------------------------------------------------------------------- /data/scannetv2/README.md: -------------------------------------------------------------------------------- 1 | # Prepare ScanNet Data 2 | - Download origin [ScanNet](https://github.com/ScanNet/ScanNet) v2 data 3 | ```sh 4 | dataset 5 | └── scannetv2 6 | ├── meta_data(unnecessary, we have moved into our source code) 7 | │ ├── scannetv2_train.txt 8 | │ ├── scannetv2_val.txt 9 | │ ├── scannetv2_test.txt 10 | 
│ └── scannetv2-labels.combined.tsv 11 | ├── scans 12 | │ ├── ... 13 | │ ├── [scene_id] 14 | | | └── [scene_id]_vh_clean_2.ply & [scene_id]_vh_clean_2.labels.ply & [scene_id]_vh_clean_2.0.010000.segs.json & [scene_id].aggregation.json 15 | | └── ... 16 | └── scans_test 17 | ├── ... 18 | ├── [scene_id] 19 | | └── [scene_id]_vh_clean_2.ply & [scene_id].txt 20 | └── ... 21 | ``` 22 | 23 | - Following [PointGroup](https://github.com/Jia-Research-Lab/PointGroup), we modified the preprocessing code so that it generates the input files `[scene_id]_inst_nostuff.pth` for instance segmentation directly; you do not need to split the original data into `train/val/test`. The scripts are in `gorilla3d/preprocessing/scannetv2/segmentation`. 24 | - We packaged these commands into one script. Just run: 25 | ```sh 26 | sh prepare_data.sh 27 | ``` 28 | - After running this command, the directory structure is as follows: 29 | ```sh 30 | dataset 31 | └── scannetv2 32 | ├── meta_data(unnecessary, we have moved into our source code) 33 | │ └── ... 34 | ├── scans 35 | | └── ... 36 | ├── scans_test 37 | | └── ... 38 | | (generated by data preparation, as follows) 39 | ├── train 40 | | ├── [scene_id]_inst_nostuff.pth 41 | | └── ... 42 | ├── test 43 | | ├── [scene_id]_inst_nostuff.pth 44 | | └── ... 45 | ├── val 46 | | ├── [scene_id]_inst_nostuff.pth 47 | | └── ... 48 | └── val_gt 49 | ├── [scene_id].txt 50 | └── ... 
51 | ``` 52 | -------------------------------------------------------------------------------- /data/scannetv2/prepare_data.sh: -------------------------------------------------------------------------------- 1 | # preprocess scannet dataset inputs 2 | python -m gorilla3d.preprocessing.scannetv2.segmentation.prepare_data_inst --data-split train 3 | python -m gorilla3d.preprocessing.scannetv2.segmentation.prepare_data_inst --data-split val 4 | python -m gorilla3d.preprocessing.scannetv2.segmentation.prepare_data_inst --data-split test 5 | # prepare validation dataset gt 6 | python -m gorilla3d.preprocessing.scannetv2.segmentation.prepare_data_inst_gttxt --data-split val 7 | -------------------------------------------------------------------------------- /docs/benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/SSTNet/53edd4436ae60171a8031fa8709b996277e89835/docs/benchmark.png -------------------------------------------------------------------------------- /docs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/SSTNet/53edd4436ae60171a8031fa8709b996277e89835/docs/overview.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pybind11 2 | scipy 3 | treelib 4 | torch_scatter 5 | open3d 6 | -------------------------------------------------------------------------------- /sstnet/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Gorilla-Lab. All rights reserved. 
2 | from .data import * 3 | from .model import * 4 | 5 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 6 | -------------------------------------------------------------------------------- /sstnet/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Gorilla-Lab. All rights reserved. 2 | from .scannetv2 import ScanNetV2Inst 3 | -------------------------------------------------------------------------------- /sstnet/data/scannetv2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Gorilla-Lab. All rights reserved. 2 | import os 3 | import time 4 | import math 5 | import glob 6 | import multiprocessing as mp 7 | from typing import Dict, List, Sequence, Tuple, Union 8 | 9 | import gorilla 10 | import open3d as o3d 11 | import numpy as np 12 | import torch 13 | from torch.utils.data import Dataset 14 | 15 | import segmentator 16 | import pointgroup_ops 17 | from .utils import elastic 18 | 19 | 20 | class GetSuperpoint(mp.Process): 21 | def __init__(self, path: str, scene: str, mdict: Dict): 22 | # must call this before anything else 23 | mp.Process.__init__(self) 24 | self.path = path 25 | self.scene = scene 26 | self.mdict = mdict 27 | 28 | def run(self): 29 | mesh_file = os.path.join(os.path.join(self.path, self.scene, self.scene+"_vh_clean_2.ply")) 30 | mesh = o3d.io.read_triangle_mesh(mesh_file) 31 | vertices = torch.from_numpy(np.array(mesh.vertices).astype(np.float32)) 32 | faces = torch.from_numpy(np.array(mesh.triangles).astype(np.int64)) 33 | superpoint = segmentator.segment_mesh(vertices, faces).numpy() 34 | self.mdict.update({self.scene: superpoint}) 35 | 36 | 37 | @gorilla.DATASETS.register_module(force=True) 38 | class ScanNetV2Inst(Dataset): 39 | def __init__(self, 40 | data_root: str, 41 | full_scale: List[int]=[128, 512], 42 | scale: int=50, 43 | max_npoint: int=250000, 44 | task: str="train", 45 | with_elastic: 
bool=False, 46 | test_mode: bool=False, 47 | prefetch_superpoints: bool=True, 48 | **kwargs): 49 | # initialize dataset parameters 50 | self.logger = gorilla.derive_logger(__name__) 51 | self.data_root = data_root 52 | self.full_scale = full_scale 53 | self.scale = scale 54 | self.max_npoint = max_npoint 55 | self.test_mode = test_mode 56 | self.with_elastic = with_elastic 57 | self.prefetch_superpoints = prefetch_superpoints 58 | self.task = task 59 | self.aug_flag = "train" in self.task 60 | 61 | # load files 62 | self.load_files() 63 | 64 | def load_files(self): 65 | file_names = sorted(glob.glob(os.path.join(self.data_root, self.task, "*.pth"))) 66 | self.files = [torch.load(i) for i in gorilla.track(file_names)] 67 | self.logger.info(f"{self.task} samples: {len(self.files)}") 68 | self.superpoints = {} 69 | 70 | if self.prefetch_superpoints: 71 | self.logger.info("begin prefetch superpoints...") 72 | sub_dir = "scans_test" if "test" in self.task else "scans" 73 | path = os.path.join(self.data_root, sub_dir) 74 | with gorilla.Timer("prefetch superpoints:"): 75 | workers = [] 76 | mdict = mp.Manager().dict() 77 | # multi-processing generate superpoints 78 | for f in self.files: 79 | workers.append(GetSuperpoint(path, f[-1], mdict)) 80 | for worker in workers: 81 | worker.start() 82 | # wait for multi-processing 83 | while len(mdict) != len(self.files): 84 | time.sleep(0.1) 85 | self.superpoints.update(mdict) 86 | 87 | # # single processing (comparison) 88 | # if self.prefetch_superpoints: 89 | # self.logger.info("prefetch superpoints:") 90 | # for f in gorilla.utils.track(self.files): 91 | # self.get_superpoint(f[-1]) 92 | # import ipdb; ipdb.set_trace() 93 | 94 | 95 | def get_superpoint(self, scene: str): 96 | if scene in self.superpoints: 97 | return 98 | sub_dir = "scans_test" if "test" in self.task else "scans" 99 | mesh_file = os.path.join(self.data_root, sub_dir, scene, scene+"_vh_clean_2.ply") 100 | mesh = o3d.io.read_triangle_mesh(mesh_file) 101 | 
vertices = torch.from_numpy(np.array(mesh.vertices).astype(np.float32)) 102 | faces = torch.from_numpy(np.array(mesh.triangles).astype(np.int64)) 103 | superpoint = segmentator.segment_mesh(vertices, faces).numpy() 104 | self.superpoints[scene] = superpoint 105 | 106 | 107 | def __len__(self): 108 | return len(self.files) 109 | 110 | def __getitem__(self, index: int) -> Tuple: 111 | if "test" in self.task: 112 | xyz_origin, rgb, faces, scene = self.files[index] 113 | # construct placeholder labels for the unlabeled test set 114 | semantic_label = np.zeros(xyz_origin.shape[0], dtype=np.int32) 115 | instance_label = np.zeros(xyz_origin.shape[0], dtype=np.int32) 116 | else: 117 | xyz_origin, rgb, faces, semantic_label, instance_label, coords_shift, scene = self.files[index] 118 | 119 | if not self.prefetch_superpoints: 120 | self.get_superpoint(scene) 121 | superpoint = self.superpoints[scene] 122 | 123 | ### jitter / flip x / rotation 124 | if self.aug_flag: 125 | xyz_middle = self.data_aug(xyz_origin, True, True, True) 126 | else: 127 | xyz_middle = self.data_aug(xyz_origin, False, False, False) 128 | 129 | ### scale 130 | xyz = xyz_middle * self.scale 131 | 132 | ### elastic 133 | if self.with_elastic: 134 | xyz = elastic(xyz, 6 * self.scale // 50, 40 * self.scale / 50) 135 | xyz = elastic(xyz, 20 * self.scale // 50, 160 * self.scale / 50) 136 | 137 | ### offset 138 | xyz_offset = xyz.min(0) 139 | xyz -= xyz_offset 140 | 141 | ### crop 142 | valid_idxs = np.ones(len(xyz_middle), dtype=bool) # builtin bool: the np.bool alias was removed in NumPy 1.24 143 | if not self.test_mode: 144 | xyz, valid_idxs = self.crop(xyz) 145 | 146 | xyz_middle = xyz_middle[valid_idxs] 147 | xyz = xyz[valid_idxs] 148 | rgb = rgb[valid_idxs] 149 | semantic_label = semantic_label[valid_idxs] 150 | superpoint = np.unique(superpoint[valid_idxs], return_inverse=True)[1] 151 | instance_label = self.get_cropped_inst_label(instance_label, valid_idxs) 152 | 153 | ### get instance information 154 | inst_num, inst_infos = self.get_instance_info(xyz_middle, 
instance_label.astype(np.int32)) 155 | inst_info = inst_infos["instance_info"] # [n, 9], (cx, cy, cz, minx, miny, minz, maxx, maxy, maxz) 156 | inst_pointnum = inst_infos["instance_pointnum"] # [num_inst], list 157 | 158 | loc = torch.from_numpy(xyz).long() 159 | loc_offset = torch.from_numpy(xyz_offset).long() 160 | loc_float = torch.from_numpy(xyz_middle) 161 | feat = torch.from_numpy(rgb) 162 | if self.aug_flag: 163 | feat += torch.randn(3) * 0.1 164 | semantic_label = torch.from_numpy(semantic_label) 165 | instance_label = torch.from_numpy(instance_label) 166 | superpoint = torch.from_numpy(superpoint) 167 | 168 | inst_info = torch.from_numpy(inst_info) 169 | 170 | return scene, loc, loc_offset, loc_float, feat, semantic_label, instance_label, superpoint, inst_num, inst_info, inst_pointnum 171 | 172 | def data_aug(self, xyz, jitter=False, flip=False, rot=False): 173 | m = np.eye(3) 174 | if jitter: 175 | m += np.random.randn(3, 3) * 0.1 176 | if flip: 177 | m[0][0] *= np.random.randint(0, 2) * 2 - 1 # flip x randomly 178 | if rot: 179 | theta = np.random.rand() * 2 * math.pi 180 | m = np.matmul(m, [[math.cos(theta), math.sin(theta), 0], [-math.sin(theta), math.cos(theta), 0], [0, 0, 1]]) # rotation 181 | return np.matmul(xyz, m) 182 | 183 | def crop(self, xyz: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 184 | r""" 185 | Crop the point cloud to reduce training complexity. 186 | 187 | Args: 188 | xyz (np.ndarray, [N, 3]): input point cloud to be cropped 189 | 190 | Returns: 191 | Tuple[np.ndarray, np.ndarray]: processed point cloud and boolean valid indices 192 | """ 193 | xyz_offset = xyz.copy() 194 | valid_idxs = (xyz_offset.min(1) >= 0) 195 | assert valid_idxs.sum() == xyz.shape[0] 196 | 197 | full_scale = np.array([self.full_scale[1]] * 3) 198 | room_range = xyz.max(0) - xyz.min(0) 199 | while (valid_idxs.sum() > self.max_npoint): 200 | offset = np.clip(full_scale - room_range + 0.001, None, 0) * np.random.rand(3) 201 | xyz_offset = xyz + offset 202 | 
valid_idxs = (xyz_offset.min(1) >= 0) * ((xyz_offset < full_scale).sum(1) == 3) 203 | full_scale[:2] -= 32 204 | 205 | return xyz_offset, valid_idxs 206 | 207 | def get_instance_info(self, 208 | xyz: np.ndarray, 209 | instance_label: np.ndarray) -> Tuple[int, Dict]: 210 | r""" 211 | Get information about the instances (count and coordinates). 212 | 213 | Args: 214 | xyz (np.ndarray, [N, 3]): input point cloud data 215 | instance_label (np.ndarray, [N]): instance ids of point cloud 216 | 217 | Returns: 218 | Tuple[int, Dict]: the number of instances and the per-instance information 219 | (coordinates and the number of points) of instances 220 | """ 221 | instance_info = np.ones((xyz.shape[0], 9), dtype=np.float32) * -100.0 # [n, 9], float, (cx, cy, cz, minx, miny, minz, maxx, maxy, maxz) 222 | instance_pointnum = [] # [num_inst], int 223 | instance_num = int(instance_label.max()) + 1 224 | for i_ in range(instance_num): 225 | inst_idx_i = np.where(instance_label == i_) 226 | 227 | ### instance_info 228 | xyz_i = xyz[inst_idx_i] 229 | min_xyz_i = xyz_i.min(0) 230 | max_xyz_i = xyz_i.max(0) 231 | mean_xyz_i = xyz_i.mean(0) 232 | instance_info_i = instance_info[inst_idx_i] 233 | instance_info_i[:, 0:3] = mean_xyz_i 234 | instance_info_i[:, 3:6] = min_xyz_i 235 | instance_info_i[:, 6:9] = max_xyz_i 236 | instance_info[inst_idx_i] = instance_info_i 237 | 238 | ### instance_pointnum 239 | instance_pointnum.append(inst_idx_i[0].size) 240 | 241 | return instance_num, {"instance_info": instance_info, "instance_pointnum": instance_pointnum} 242 | 243 | def get_cropped_inst_label(self, 244 | instance_label: np.ndarray, 245 | valid_idxs: np.ndarray) -> np.ndarray: 246 | r""" 247 | Get the instance labels after cropping and re-compact them into consecutive ids. 248 | 249 | Args: 250 | instance_label (np.ndarray, [N]): instance label ids of point cloud 251 | valid_idxs (np.ndarray, [N]): boolean valid indices 252 | 253 | Returns: 254 | np.ndarray: processed instance labels 255 | """ 256 | instance_label = 
instance_label[valid_idxs] 257 | j = 0 258 | while (j < instance_label.max()): 259 | if (len(np.where(instance_label == j)[0]) == 0): 260 | instance_label[instance_label == instance_label.max()] = j 261 | j += 1 262 | return instance_label 263 | 264 | def collate_fn(self, batch: Sequence[Sequence]) -> Dict: 265 | locs = [] 266 | loc_offset_list = [] 267 | locs_float = [] 268 | feats = [] 269 | semantic_labels = [] 270 | instance_labels = [] 271 | 272 | instance_infos = [] # [N, 9] 273 | instance_pointnum = [] # [total_num_inst], int 274 | 275 | batch_offsets = [0] 276 | scene_list = [] 277 | superpoint_list = [] 278 | superpoint_bias = 0 279 | 280 | total_inst_num = 0 281 | for i, data in enumerate(batch): 282 | scene, loc, loc_offset, loc_float, feat, semantic_label, instance_label, superpoint, inst_num, inst_info, inst_pointnum = data 283 | 284 | scene_list.append(scene) 285 | superpoint += superpoint_bias 286 | superpoint_bias += (superpoint.max() + 1) 287 | 288 | invalid_ids = np.where(instance_label != -100) 289 | instance_label[invalid_ids] += total_inst_num 290 | total_inst_num += inst_num 291 | 292 | ### merge the scene to the batch 293 | batch_offsets.append(batch_offsets[-1] + loc.shape[0]) 294 | 295 | locs.append(torch.cat([torch.LongTensor(loc.shape[0], 1).fill_(i), loc], 1)) 296 | loc_offset_list.append(loc_offset) 297 | locs_float.append(loc_float) 298 | feats.append(feat) 299 | semantic_labels.append(semantic_label) 300 | instance_labels.append(instance_label) 301 | superpoint_list.append(superpoint) 302 | 303 | instance_infos.append(inst_info) 304 | instance_pointnum.extend(inst_pointnum) 305 | 306 | ### merge all the scenes in the batchd 307 | batch_offsets = torch.tensor(batch_offsets, dtype=torch.int) # int [B+1] 308 | 309 | locs = torch.cat(locs, 0) # long [N, 1 + 3], the batch item idx is put in locs[:, 0] 310 | locs_float = torch.cat(locs_float, 0).to(torch.float32) # float [N, 3] 311 | superpoint = torch.cat(superpoint_list, 0).long() # 
long[N] 312 | feats = torch.cat(feats, 0) # float [N, C] 313 | semantic_labels = torch.cat(semantic_labels, 0).long() # long [N] 314 | instance_labels = torch.cat(instance_labels, 0).long() # long [N] 315 | locs_offset = torch.stack(loc_offset_list) # long [B, 3] 316 | 317 | instance_infos = torch.cat(instance_infos, 0).to(torch.float32) # float [N, 9] (meanxyz, minxyz, maxxyz) 318 | instance_pointnum = torch.tensor(instance_pointnum, dtype=torch.int) # int [total_num_inst] 319 | 320 | spatial_shape = np.clip((locs.max(0)[0][1:] + 1).numpy(), self.full_scale[0], None) # long [3] 321 | 322 | ### voxelize 323 | batch_size = len(batch) 324 | voxel_locs, p2v_map, v2p_map = pointgroup_ops.voxelization_idx(locs, batch_size, 4) 325 | 326 | return {"locs": locs, "locs_offset": locs_offset, "voxel_locs": voxel_locs, 327 | "scene_list": scene_list, "p2v_map": p2v_map, "v2p_map": v2p_map, 328 | "locs_float": locs_float, "feats": feats, 329 | "semantic_labels": semantic_labels, "instance_labels": instance_labels, 330 | "instance_info": instance_infos, "instance_pointnum": instance_pointnum, 331 | "offsets": batch_offsets, "spatial_shape": spatial_shape, "superpoint": superpoint} 332 | 333 | -------------------------------------------------------------------------------- /sstnet/data/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Gorilla-Lab. All rights reserved. 
2 | import math 3 | 4 | import numpy as np 5 | import scipy.ndimage as ndimage 6 | import scipy.interpolate as interpolate 7 | import transforms3d.euler as euler 8 | 9 | def elastic(xyz, gran, mag): 10 | """Elastic distortion (from PointGroup) 11 | 12 | Args: 13 | xyz (np.ndarray): input point cloud 14 | gran (float): distortion granularity (grid spacing of the noise field) 15 | mag (float): distortion magnitude 16 | 17 | Returns: 18 | xyz: point cloud with elastic distortion 19 | """ 20 | blur0 = np.ones((3, 1, 1)).astype("float32") / 3 21 | blur1 = np.ones((1, 3, 1)).astype("float32") / 3 22 | blur2 = np.ones((1, 1, 3)).astype("float32") / 3 23 | 24 | bb = np.abs(xyz).max(0).astype(np.int32)//gran + 3 25 | noise = [np.random.randn(bb[0], bb[1], bb[2]).astype("float32") for _ in range(3)] 26 | noise = [ndimage.convolve(n, blur0, mode="constant", cval=0) for n in noise] 27 | noise = [ndimage.convolve(n, blur1, mode="constant", cval=0) for n in noise] 28 | noise = [ndimage.convolve(n, blur2, mode="constant", cval=0) for n in noise] 29 | noise = [ndimage.convolve(n, blur0, mode="constant", cval=0) for n in noise] 30 | noise = [ndimage.convolve(n, blur1, mode="constant", cval=0) for n in noise] 31 | noise = [ndimage.convolve(n, blur2, mode="constant", cval=0) for n in noise] 32 | ax = [np.linspace(-(b-1)*gran, (b-1)*gran, b) for b in bb] 33 | interp = [interpolate.RegularGridInterpolator(ax, n, bounds_error=False, fill_value=0) for n in noise] 34 | def g(xyz_): 35 | return np.hstack([i(xyz_)[:,None] for i in interp]) 36 | return xyz + g(xyz) * mag 37 | 38 | 39 | # modified from PointGroup 40 | def pc_aug(xyz, scale=False, flip=False, rot=False): 41 | if scale: 42 | scale = np.random.uniform(0.8, 1.2) 43 | xyz = xyz * scale 44 | if flip: 45 | # m[0][0] *= np.random.randint(0, 2) * 2 - 1 # flip x randomly 46 | flag = np.random.randint(0, 2) 47 | if flag: 48 | xyz[:, 0] = -xyz[:, 0] 49 | if rot: 50 | theta = np.random.uniform() * np.pi 51 | # theta = np.random.randn() * 
np.pi 52 | rot_mat = np.eye(3) 53 | c, s = np.cos(theta), np.sin(theta) 54 | rot_mat[0, 0] = c 55 | rot_mat[0, 1] = -s 56 | rot_mat[1, 1] = c 57 | rot_mat[1, 0] = s 58 | xyz = xyz @ rot_mat.T 59 | 60 | return xyz 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /sstnet/lib/.gitignore: -------------------------------------------------------------------------------- 1 | scipy* -------------------------------------------------------------------------------- /sstnet/lib/cluster/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ========================================= 3 | Clustering package (:mod:`scipy.cluster`) 4 | ========================================= 5 | 6 | .. currentmodule:: scipy.cluster 7 | 8 | :mod:`scipy.cluster.vq` 9 | 10 | Clustering algorithms are useful in information theory, target detection, 11 | communications, compression, and other areas. The `vq` module only 12 | supports vector quantization and the k-means algorithms. 13 | 14 | :mod:`scipy.cluster.hierarchy` 15 | 16 | The `hierarchy` module provides functions for hierarchical and 17 | agglomerative clustering. Its features include generating hierarchical 18 | clusters from distance matrices, 19 | calculating statistics on clusters, cutting linkages 20 | to generate flat clusters, and visualizing clusters with dendrograms. 21 | 22 | """ 23 | __all__ = ['vq', 'hierarchy'] 24 | 25 | from . 
import vq, hierarchy 26 | 27 | from scipy._lib._testutils import PytestTester 28 | test = PytestTester(__name__) 29 | del PytestTester 30 | -------------------------------------------------------------------------------- /sstnet/lib/cluster/_hierarchy_distance_update.pxi: -------------------------------------------------------------------------------- 1 | """ 2 | A `linkage_distance_update` function calculates the distance from cluster i 3 | to the new cluster xy after merging cluster x and cluster y 4 | 5 | Parameters 6 | ---------- 7 | d_xi : double 8 | Distance from cluster x to cluster i 9 | d_yi : double 10 | Distance from cluster y to cluster i 11 | d_xy : double 12 | Distance from cluster x to cluster y 13 | size_x : int 14 | Size of cluster x 15 | size_y : int 16 | Size of cluster y 17 | size_i : int 18 | Size of cluster i 19 | 20 | Returns 21 | ------- 22 | d_xyi : double 23 | Distance from the new cluster xy to cluster i 24 | """ 25 | ctypedef double (*linkage_distance_update)(double d_xi, double d_yi, 26 | double d_xy, int size_x, 27 | int size_y, int size_i) 28 | 29 | 30 | cdef double _single(double d_xi, double d_yi, double d_xy, 31 | int size_x, int size_y, int size_i): 32 | return min(d_xi, d_yi) 33 | 34 | 35 | cdef double _complete(double d_xi, double d_yi, double d_xy, 36 | int size_x, int size_y, int size_i): 37 | return max(d_xi, d_yi) 38 | 39 | 40 | cdef double _average(double d_xi, double d_yi, double d_xy, 41 | int size_x, int size_y, int size_i): 42 | return (size_x * d_xi + size_y * d_yi) / (size_x + size_y) 43 | 44 | 45 | cdef double _centroid(double d_xi, double d_yi, double d_xy, 46 | int size_x, int size_y, int size_i): 47 | return sqrt((((size_x * d_xi * d_xi) + (size_y * d_yi * d_yi)) - 48 | (size_x * size_y * d_xy * d_xy) / (size_x + size_y)) / 49 | (size_x + size_y)) 50 | 51 | 52 | cdef double _median(double d_xi, double d_yi, double d_xy, 53 | int size_x, int size_y, int size_i): 54 | return sqrt(0.5 * (d_xi * d_xi + d_yi * d_yi) - 
0.25 * d_xy * d_xy) 55 | 56 | 57 | cdef double _ward(double d_xi, double d_yi, double d_xy, 58 | int size_x, int size_y, int size_i): 59 | cdef double t = 1.0 / (size_x + size_y + size_i) 60 | return sqrt((size_i + size_x) * t * d_xi * d_xi + 61 | (size_i + size_y) * t * d_yi * d_yi - 62 | size_i * t * d_xy * d_xy) 63 | 64 | 65 | cdef double _weighted(double d_xi, double d_yi, double d_xy, 66 | int size_x, int size_y, int size_i): 67 | return 0.5 * (d_xi + d_yi) 68 | -------------------------------------------------------------------------------- /sstnet/lib/cluster/_optimal_leaf_ordering.pyx: -------------------------------------------------------------------------------- 1 | # cython: profile=False 2 | # cython: linetrace=False 3 | # distutils: define_macros=CYTHON_TRACE_NOGIL=1 4 | 5 | # Code adapted from github.com/adrianveres/Polo, licensed: 6 | # 7 | # The MIT License (MIT) 8 | # Copyright (c) 2016 Adrian Veres 9 | # 10 | # Permission is hereby granted, free of charge, to any person obtaining a copy 11 | # of this software and associated documentation files (the "Software"), to deal 12 | # in the Software without restriction, including without limitation the rights 13 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | # copies of the Software, and to permit persons to whom the Software is 15 | # furnished to do so, subject to the following conditions: 16 | # 17 | # The above copyright notice and this permission notice shall be included in 18 | # all copies or substantial portions of the Software. 19 | # 20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 23 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | # SOFTWARE. 27 | 28 | import numpy as np 29 | cimport numpy as np 30 | cimport cython 31 | from libc.stdlib cimport malloc, free 32 | 33 | from scipy.spatial.distance import squareform, is_valid_y, is_valid_dm 34 | 35 | 36 | @cython.profile(False) 37 | @cython.boundscheck(False) 38 | @cython.wraparound(False) 39 | cdef inline void dual_swap(float* darr, int* iarr, 40 | int i1, int i2): 41 | """ 42 | [Taken from Scikit-learn.] 43 | 44 | Swap the values at index i1 and i2 of both darr and iarr""" 45 | cdef float dtmp = darr[i1] 46 | darr[i1] = darr[i2] 47 | darr[i2] = dtmp 48 | 49 | cdef int itmp = iarr[i1] 50 | iarr[i1] = iarr[i2] 51 | iarr[i2] = itmp 52 | 53 | 54 | @cython.profile(False) 55 | @cython.boundscheck(False) 56 | @cython.wraparound(False) 57 | cdef int _simultaneous_sort(float* dist, int* idx, 58 | int size) except -1: 59 | """ 60 | [Taken from Scikit-learn.] 61 | 62 | 63 | Perform a recursive quicksort on the dist array, simultaneously 64 | performing the same swaps on the idx array. The equivalent in 65 | numpy (though quite a bit slower) is 66 | def simultaneous_sort(dist, idx): 67 | i = np.argsort(dist) 68 | return dist[i], idx[i] 69 | """ 70 | cdef int pivot_idx, i, store_idx 71 | cdef float pivot_val 72 | 73 | # in the small-array case, do things efficiently 74 | if size <= 1: 75 | pass 76 | elif size == 2: 77 | if dist[0] > dist[1]: 78 | dual_swap(dist, idx, 0, 1) 79 | elif size == 3: 80 | if dist[0] > dist[1]: 81 | dual_swap(dist, idx, 0, 1) 82 | if dist[1] > dist[2]: 83 | dual_swap(dist, idx, 1, 2) 84 | if dist[0] > dist[1]: 85 | dual_swap(dist, idx, 0, 1) 86 | else: 87 | # Determine the pivot using the median-of-three rule. 
88 | # The smallest of the three is moved to the beginning of the array, 89 | # the middle (the pivot value) is moved to the end, and the largest 90 | # is moved to the pivot index. 91 | pivot_idx = size // 2 92 | if dist[0] > dist[size - 1]: 93 | dual_swap(dist, idx, 0, size - 1) 94 | if dist[size - 1] > dist[pivot_idx]: 95 | dual_swap(dist, idx, size - 1, pivot_idx) 96 | if dist[0] > dist[size - 1]: 97 | dual_swap(dist, idx, 0, size - 1) 98 | pivot_val = dist[size - 1] 99 | 100 | # partition indices about pivot. At the end of this operation, 101 | # pivot_idx will contain the pivot value, everything to the left 102 | # will be smaller, and everything to the right will be larger. 103 | store_idx = 0 104 | for i in range(size - 1): 105 | if dist[i] < pivot_val: 106 | dual_swap(dist, idx, i, store_idx) 107 | store_idx += 1 108 | dual_swap(dist, idx, store_idx, size - 1) 109 | pivot_idx = store_idx 110 | 111 | # recursively sort each side of the pivot 112 | if pivot_idx > 1: 113 | _simultaneous_sort(dist, idx, pivot_idx) 114 | if pivot_idx + 2 < size: 115 | _simultaneous_sort(dist + pivot_idx + 1, 116 | idx + pivot_idx + 1, 117 | size - pivot_idx - 1) 118 | return 0 119 | 120 | 121 | cdef inline void _sort_M_slice(float[:, ::1] M, 122 | float* vals, int* idx, 123 | int dim1_min, int dim1_max, int dim2_val): 124 | """ 125 | Simultaneously sort indices and values of M[{m}, u] using 126 | `_simultaneous_sort` 127 | 128 | This is equivalent to : 129 | m_sort = M[dim1_min:dim1_max, dim2_val].argsort() 130 | m_iter = np.arange(dim1_min, dim1_max)[m_sort] 131 | 132 | but much faster because we don't have to pay the numpy overhead. This 133 | matters a lot for the sorting of M[{k}, w] which is executed many times. 
134 | """ 135 | cdef int i 136 | for i in range(0, dim1_max - dim1_min): 137 | vals[i] = M[dim1_min + i, dim2_val] 138 | idx[i] = dim1_min + i 139 | _simultaneous_sort(vals, idx, dim1_max - dim1_min) 140 | 141 | 142 | @cython.boundscheck(False) 143 | @cython.wraparound(False) 144 | cdef int[:] identify_swaps(int[:, ::1] sorted_Z, 145 | double[:, ::1] sorted_D, 146 | int[:, ::1] cluster_ranges): 147 | """ 148 | Implements the Optimal Leaf Ordering algorithm described in 149 | "Fast Optimal leaf ordering for hierarchical clustering" 150 | Ziv Bar-Joseph, David K. Gifford, Tommi S. Jaakkola 151 | Bioinformatics, 2001, :doi:`10.1093/bioinformatics/17.suppl_1.S22` 152 | 153 | `sorted_Z` : Linkage list, with 'height' column removed. 154 | 155 | """ 156 | cdef int n_points = len(sorted_Z) + 1 157 | 158 | cdef: 159 | # (n x n) floats 160 | float[:, ::1] M = np.zeros((n_points, n_points), dtype=np.float32) 161 | # (n x n x 2) booleans 162 | int[:, :, :] swap_status = np.zeros((n_points, n_points, 2), 163 | dtype=np.intc) 164 | int[:] must_swap = np.zeros((len(sorted_Z),), dtype=np.intc) 165 | 166 | int i, v_l, v_r, v_size, 167 | int v_l_min, v_l_max, v_r_min, v_r_max 168 | 169 | int u_clusters[2] 170 | int m_clusters[2] 171 | int w_clusters[2] 172 | int k_clusters[2] 173 | int total_u_clusters, total_w_clusters 174 | 175 | int u, w, m, k 176 | int u_min, u_max, m_min, m_max, w_min, w_max, k_min, k_max 177 | int swap_L, swap_R 178 | 179 | float* m_vals 180 | int* m_idx 181 | float* k_vals 182 | int* k_idx 183 | int mi, ki 184 | 185 | float min_km_dist 186 | float cur_min_M, current_M 187 | int best_m, best_k 188 | 189 | int best_u, best_w 190 | 191 | for i in range(len(sorted_Z)): 192 | # Iterate over the linkage list instead of recursion. 193 | # v_l = sorted_Z[i, 0] 194 | # v_r = sorted_Z[i, 1] 195 | # are indices of the left and right children for node i. 196 | # 197 | # If the v_l or v_r are < n_points, then v_l or v_r are singleton 198 | # clusters. 
Otherwise, they refer to row (v_l - n_points) or (v_r - n_points) of sorted_Z. 199 | # 200 | # V 201 | # / \ 202 | # -- -- 203 | # / \ 204 | # V_l V_r 205 | # / \ / \ 206 | # V_l1 V_l2 V_r1 V_r2 207 | # (u) (m) (k) (w) 208 | # 209 | # Briefly, for every node V, the algorithm finds the left-most and 210 | # right-most nodes u, w that minimize U[u, w], the sum of distances 211 | # between neighboring singleton nodes in the linear ordering. 212 | # 213 | # This is done recursively, by optimizing the ordering of 214 | # v_l (bounded by nodes u, m) and v_r (bounded by k, w) 215 | # such that 216 | # U[u, w] = U[u, m] + U[k, w] + D[m, k] 217 | # is then minimized. 218 | # 219 | # Part of the optimization is that at every search step, 220 | # if (u) ~ V_l1, then (m) ~ V_l2 (and vice-versa) 221 | # likewise, if (k) ~ V_r1, then (w) ~ V_r2 (and vice-versa). 222 | # 223 | # This means we need to search 4 pairs of (V_li, V_rj) combinations. 224 | # If V_l or V_r is a singleton then, for example, V_l = u = m. 225 | 226 | v_l = sorted_Z[i, 0] 227 | v_r = sorted_Z[i, 1] 228 | v_size = sorted_Z[i, 2] 229 | 230 | v_l_min = cluster_ranges[v_l, 0]; v_l_max = cluster_ranges[v_l, 1] 231 | v_r_min = cluster_ranges[v_r, 0]; v_r_max = cluster_ranges[v_r, 1] 232 | 233 | if v_l < n_points: 234 | # V_l is a singleton, so U = M = V_L. 235 | total_u_clusters = 1 236 | 237 | # This could be handled more efficiently, but in practice the code 238 | # would get longer for no speed gain. 239 | u_clusters[0] = v_l 240 | m_clusters[0] = v_l 241 | 242 | else: 243 | total_u_clusters = 2 244 | 245 | # First look for U from V_LL and M from V_LR 246 | u_clusters[0] = sorted_Z[v_l - n_points, 0] 247 | m_clusters[0] = sorted_Z[v_l - n_points, 1] 248 | 249 | # Then look for U from V_LR and M from V_LL 250 | u_clusters[1] = sorted_Z[v_l - n_points, 1] 251 | m_clusters[1] = sorted_Z[v_l - n_points, 0] 252 | 253 | if v_r < n_points: 254 | total_w_clusters = 1 255 | # V_r is a singleton, so W = K = V_R. 
256 | w_clusters[0] = v_r 257 | k_clusters[0] = v_r 258 | 259 | else: 260 | total_w_clusters = 2 261 | 262 | # Candidates for W: first V_RR, then V_RL 263 | w_clusters[0] = sorted_Z[v_r - n_points, 1] 264 | w_clusters[1] = sorted_Z[v_r - n_points, 0] 265 | 266 | # Matching candidates for K: first V_RL, then V_RR 267 | k_clusters[0] = sorted_Z[v_r - n_points, 0] 268 | k_clusters[1] = sorted_Z[v_r - n_points, 1] 269 | 270 | for swap_L in range(total_u_clusters): 271 | for swap_R in range(total_w_clusters): 272 | # Get bounds for the clusters from which we'll sample u, m, w, k 273 | # (see note above for details). 274 | # If in the chosen ordering, 275 | # U came from V_ll : Don't swap V_l. 276 | # U came from V_lr : Swap V_l. 277 | # W came from V_rl : Swap V_r. 278 | # W came from V_rr : Don't swap V_r. 279 | 280 | u_min = cluster_ranges[u_clusters[swap_L], 0] 281 | u_max = cluster_ranges[u_clusters[swap_L], 1] 282 | m_min = cluster_ranges[m_clusters[swap_L], 0] 283 | m_max = cluster_ranges[m_clusters[swap_L], 1] 284 | w_min = cluster_ranges[w_clusters[swap_R], 0] 285 | w_max = cluster_ranges[w_clusters[swap_R], 1] 286 | k_min = cluster_ranges[k_clusters[swap_R], 0] 287 | k_max = cluster_ranges[k_clusters[swap_R], 1] 288 | 289 | # Find the minimum of D[m, k] for the appropriate sets {m}, {k}. 290 | # This is C[{m}, {k}] in the paper's notation. 
291 | min_km_dist = 1073741824 #2^30 292 | for m in range(m_min, m_max): 293 | for k in range(k_min, k_max): 294 | if sorted_D[m, k] < min_km_dist: 295 | min_km_dist = sorted_D[m, k] 296 | 297 | m_vals = malloc(sizeof(float) * (m_max - m_min)) 298 | m_idx = malloc(sizeof(int) * (m_max - m_min)) 299 | k_vals = malloc(sizeof(float) * (k_max - k_min)) 300 | k_idx = malloc(sizeof(int) * (k_max - k_min)) 301 | if not m_vals or not m_idx or not k_vals or not k_idx: 302 | free(m_vals) 303 | free(m_idx) 304 | free(k_vals) 305 | free(k_idx) 306 | raise MemoryError("failed to allocate memory in identify_swaps().") 307 | 308 | for u in range(u_min, u_max): 309 | # Sort the values of M[{m}, u] 310 | _sort_M_slice(M, m_vals, m_idx, m_min, m_max, u) 311 | 312 | for w in range(w_min, w_max): 313 | # Sort the values of M[{k}, w] 314 | _sort_M_slice(M, k_vals, k_idx, k_min, k_max, w) 315 | 316 | # Set initial value for cur_min_M. 317 | # I used a large number. 318 | cur_min_M = 1073741824.0 #2^30 319 | 320 | for mi in range(0, m_max - m_min): 321 | m = m_idx[mi] 322 | 323 | if (M[u, m] + M[w, k_idx[0]] + min_km_dist 324 | >= cur_min_M): 325 | # Terminate the outer loop early, there will not 326 | # be a better 'k' in the current k list. 327 | break 328 | for ki in range(0, k_max - k_min): 329 | k = k_idx[ki] 330 | 331 | if M[u, m] + M[w, k] + min_km_dist >= cur_min_M: 332 | # Terminate the inner loop early 333 | break 334 | 335 | current_M = M[u, m] + M[w, k] + sorted_D[m, k] 336 | if current_M < cur_min_M: 337 | # We found a better m, k than previously. 338 | cur_min_M = current_M 339 | best_m = m 340 | best_k = k 341 | 342 | # For the chosen (u, w), record the resulting minimal 343 | # M[u, w] = M[u, m] + M[k, w] + D[m, k] 344 | M[u, w] = cur_min_M 345 | M[w, u] = cur_min_M 346 | # whether we need to swap V_l and V_r given the current 347 | # chosen (m, k) (see note above). This saves us from 348 | # storing (m, k) and doing back-tracking later. 
349 | swap_status[u, w, 0] = swap_L 350 | swap_status[w, u, 0] = swap_L 351 | swap_status[u, w, 1] = swap_R 352 | swap_status[w, u, 1] = swap_R 353 | 354 | # We are getting a fresh `w` and `u` so need to resort 355 | # M[{k}, w] and M[{m}, u] 356 | free(m_vals) 357 | free(m_idx) 358 | free(k_vals) 359 | free(k_idx) 360 | 361 | # We are now ready to find the best minimal value for M[{u}, {w}] 362 | cur_min_M = 1073741824.0 #2^30 363 | for u in range(v_l_min, v_l_max): 364 | for w in range(v_r_min, v_r_max): 365 | if M[u, w] < cur_min_M: 366 | cur_min_M = M[u, w] 367 | best_u = u 368 | best_w = w 369 | 370 | # If v_l, v_r are not singletons, record whether our choice of (u, w) 371 | # for V requires a swap of its children. 372 | if v_l >= n_points: 373 | must_swap[v_l - n_points] = int(swap_status[best_u, best_w, 0]) 374 | if v_r >= n_points: 375 | must_swap[v_r - n_points] = int(swap_status[best_u, best_w, 1]) 376 | 377 | return must_swap 378 | 379 | 380 | def optimal_leaf_ordering(Z, D): 381 | """ 382 | Compute the optimal leaf order for Z (according to D) and return an 383 | optimally sorted Z. 384 | 385 | We start by sorting and relabelling Z and D according to the current leaf 386 | order in Z. 387 | 388 | This is because when everything is sorted each cluster (including 389 | singletons) can be defined by its range over (0...n_points). 390 | 391 | This is used extensively to loop efficiently over the various arrays in the 392 | algorithm. 
393 | 394 | """ 395 | # Import here to avoid import cycles 396 | from scipy.cluster.hierarchy import leaves_list, is_valid_linkage 397 | 398 | is_valid_linkage(Z, throw=True, name='Z') 399 | 400 | if is_valid_y(D): 401 | sorted_D = squareform(D) 402 | elif is_valid_dm(D): 403 | sorted_D = D 404 | else: 405 | raise ValueError("Not a valid distance matrix (neither condensed nor square form)") 406 | 407 | n_points = Z.shape[0] + 1 408 | n_clusters = 2*n_points - 1 409 | 410 | # Get the current linear ordering 411 | sorted_leaves = leaves_list(Z) 412 | 413 | # Create map from original order to sorted order. 414 | original_order_to_sorted_order = dict((orig_i, sorted_i) for sorted_i,orig_i 415 | in enumerate(sorted_leaves)) 416 | 417 | 418 | # Re-write linkage map so it refers to sorted positions, rather than input 419 | # positions. Remove the 'height' column so we can cast the whole thing as 420 | # integer and simplify passing to C function above. 421 | sorted_Z = [] 422 | for (v_l, v_r, _, v_size) in Z: 423 | if v_l < n_points: 424 | v_l = original_order_to_sorted_order[int(v_l)] 425 | if v_r < n_points: 426 | v_r = original_order_to_sorted_order[int(v_r)] 427 | 428 | sorted_Z.append([v_l, v_r, v_size]) 429 | sorted_Z = np.array(sorted_Z).astype(np.int32).copy(order='C') 430 | 431 | 432 | # Sort distance matrix D by the leaf order 433 | sorted_D = sorted_D[sorted_leaves, :] 434 | sorted_D = sorted_D[:, sorted_leaves].copy(order='C') 435 | 436 | # Defines the range of each cluster over (0... n_points) as explained above. 
437 | cluster_ranges = np.zeros((n_clusters, 2)) 438 | cluster_ranges[np.arange(n_points), 0] = np.arange(n_points) 439 | cluster_ranges[np.arange(n_points), 1] = np.arange(n_points) + 1 440 | for link_i, (v_l, v_r, v_size) in enumerate(sorted_Z): 441 | v = link_i + n_points 442 | cluster_ranges[v, 0] = cluster_ranges[v_l, 0] 443 | cluster_ranges[v, 1] = cluster_ranges[v_r, 1] 444 | cluster_ranges = cluster_ranges.astype(np.int32).copy(order='C') 445 | 446 | # Get Swaps 447 | must_swap = identify_swaps(sorted_Z, sorted_D, cluster_ranges) 448 | 449 | # To 'rotate' around the axis of a node, we need to consider the left-right 450 | # children of every descendant of this target node. 451 | # 452 | # To do so efficiently, we record how many total times a given node must be 453 | # swapped (once if it needs to be swapped itself, once for each parent that 454 | # needs to be swapped) and take modulo 2 to find whether it needs to be 455 | # swapped at all. 456 | is_descendant = np.zeros((n_clusters - n_points, n_clusters - n_points), 457 | dtype=int) 458 | for i, (v_l, v_r, v_size) in enumerate(sorted_Z): 459 | is_descendant[i, i] = 1 460 | if v_l >= n_points: 461 | is_descendant[i, v_l - n_points] = 1 462 | is_descendant[i, :] += is_descendant[v_l - n_points, :] 463 | if v_r >= n_points: 464 | is_descendant[i, v_r - n_points] = 1 465 | is_descendant[i, :] += is_descendant[v_r - n_points, :] 466 | 467 | 468 | # To "rotate" a tree node, we need to 'swap' its left-right children, 469 | # and do the same to all its children. 470 | applied_swap = (np.array(is_descendant).astype(bool) 471 | * np.array(must_swap).reshape(-1, 1)) 472 | final_swap = applied_swap.sum(axis=0) % 2 473 | 474 | # Create a new linkage matrix by applying swaps where needed. 
475 | swapped_Z = [] 476 | for i, (in_l, in_r, h, v_size) in enumerate(Z): 477 | if final_swap[i]: 478 | out_l = in_r 479 | out_r = in_l 480 | else: 481 | out_r = in_r 482 | out_l = in_l 483 | swapped_Z.append((out_l, out_r, h, v_size)) 484 | swapped_Z = np.array(swapped_Z) 485 | 486 | return swapped_Z 487 | 488 | 489 | 490 | -------------------------------------------------------------------------------- /sstnet/lib/cluster/_structures.pxi: -------------------------------------------------------------------------------- 1 | # cython: boundscheck=False, wraparound=False, cdivision=True 2 | import numpy as np 3 | 4 | 5 | ctypedef struct Pair: 6 | int key 7 | double value 8 | 9 | 10 | cdef class Heap: 11 | """Binary heap. 12 | 13 | Heap stores values and keys. Values are passed explicitly, whereas keys 14 | are assigned implicitly to natural numbers (from 0 to n - 1). 15 | 16 | The supported operations (all have O(log n) time complexity): 17 | 18 | * Return the current minimum value and the corresponding key. 19 | * Remove the current minimum value. 20 | * Change the value of the given key. Note that the key must be still 21 | in the heap. 22 | 23 | The heap is stored as an array, where children of parent i have indices 24 | 2 * i + 1 and 2 * i + 2. All public methods are based on `sift_down` and 25 | `sift_up` methods, which restore the heap property by moving an element 26 | down or up in the heap. 27 | """ 28 | cdef int[:] index_by_key 29 | cdef int[:] key_by_index 30 | cdef double[:] values 31 | cdef int size 32 | 33 | def __init__(self, double[:] values): 34 | self.size = values.shape[0] 35 | self.index_by_key = np.arange(self.size, dtype=np.intc) 36 | self.key_by_index = np.arange(self.size, dtype=np.intc) 37 | self.values = values.copy() 38 | cdef int i 39 | 40 | # Create the heap in a linear time. The algorithm sequentially sifts 41 | # down items starting from lower levels. 
42 | for i in reversed(range(self.size // 2)): 43 | self.sift_down(i) 44 | 45 | cpdef Pair get_min(self): 46 | return Pair(self.key_by_index[0], self.values[0]) 47 | 48 | cpdef void remove_min(self): 49 | self.swap(0, self.size - 1) 50 | self.size -= 1 51 | self.sift_down(0) 52 | 53 | cpdef void change_value(self, int key, double value): 54 | cdef int index = self.index_by_key[key] 55 | cdef double old_value = self.values[index] 56 | self.values[index] = value 57 | if value < old_value: 58 | self.sift_up(index) 59 | else: 60 | self.sift_down(index) 61 | 62 | cdef void sift_up(self, int index): 63 | cdef int parent = Heap.parent(index) 64 | while index > 0 and self.values[parent] > self.values[index]: 65 | self.swap(index, parent) 66 | index = parent 67 | parent = Heap.parent(index) 68 | 69 | cdef void sift_down(self, int index): 70 | cdef int child = Heap.left_child(index) 71 | while child < self.size: 72 | if (child + 1 < self.size and 73 | self.values[child + 1] < self.values[child]): 74 | child += 1 75 | 76 | if self.values[index] > self.values[child]: 77 | self.swap(index, child) 78 | index = child 79 | child = Heap.left_child(index) 80 | else: 81 | break 82 | 83 | @staticmethod 84 | cdef inline int left_child(int parent): 85 | return (parent << 1) + 1 86 | 87 | @staticmethod 88 | cdef inline int parent(int child): 89 | return (child - 1) >> 1 90 | 91 | cdef void swap(self, int i, int j): 92 | self.values[i], self.values[j] = self.values[j], self.values[i] 93 | cdef int key_i = self.key_by_index[i] 94 | cdef int key_j = self.key_by_index[j] 95 | self.key_by_index[i] = key_j 96 | self.key_by_index[j] = key_i 97 | self.index_by_key[key_i] = j 98 | self.index_by_key[key_j] = i 99 | -------------------------------------------------------------------------------- /sstnet/lib/cluster/_vq.pyx: -------------------------------------------------------------------------------- 1 | """ 2 | Cython rewrite of the vector quantization module, originally written 3 | in C at 
src/vq.c and the wrapper at src/vq_module.c. This should be 4 | easier to maintain than old SWIG output. 5 | 6 | Original C version by Damian Eads. 7 | Translated to Cython by David Warde-Farley, October 2009. 8 | """ 9 | 10 | cimport cython 11 | import numpy as np 12 | cimport numpy as np 13 | from scipy.linalg.cython_blas cimport dgemm, sgemm 14 | 15 | from libc.math cimport sqrt 16 | 17 | ctypedef np.float64_t float64_t 18 | ctypedef np.float32_t float32_t 19 | ctypedef np.int32_t int32_t 20 | 21 | # Use Cython fused types for templating 22 | # Define supported data types as vq_type 23 | ctypedef fused vq_type: 24 | float32_t 25 | float64_t 26 | 27 | # When the number of features is less than this number, 28 | # switch back to the naive algorithm to avoid high overhead. 29 | DEF NFEATURES_CUTOFF=5 30 | 31 | # Initialize the NumPy C API 32 | np.import_array() 33 | 34 | 35 | cdef inline vq_type vec_sqr(int n, vq_type *p): 36 | cdef vq_type result = 0.0 37 | cdef int i 38 | for i in range(n): 39 | result += p[i] * p[i] 40 | return result 41 | 42 | 43 | cdef inline void cal_M(int nobs, int ncodes, int nfeat, vq_type *obs, 44 | vq_type *code_book, vq_type *M): 45 | """ 46 | Calculate M = obs * code_book.T 47 | """ 48 | cdef vq_type alpha = -2.0, beta = 0.0 49 | 50 | # Call BLAS functions with Fortran ABI 51 | # Note that BLAS Fortran ABI uses column-major order 52 | if vq_type is float32_t: 53 | sgemm("T", "N", &ncodes, &nobs, &nfeat, 54 | &alpha, code_book, &nfeat, obs, &nfeat, &beta, M, &ncodes) 55 | else: 56 | dgemm("T", "N", &ncodes, &nobs, &nfeat, 57 | &alpha, code_book, &nfeat, obs, &nfeat, &beta, M, &ncodes) 58 | 59 | 60 | cdef int _vq(vq_type *obs, vq_type *code_book, 61 | int ncodes, int nfeat, int nobs, 62 | int32_t *codes, vq_type *low_dist) except -1: 63 | """ 64 | The underlying function (template) of _vq.vq. 65 | 66 | Parameters 67 | ---------- 68 | obs : vq_type* 69 | The pointer to the observation matrix. 
70 | code_book : vq_type* 71 | The pointer to the code book matrix. 72 | ncodes : int 73 | The number of centroids (codes). 74 | nfeat : int 75 | The number of features of each observation. 76 | nobs : int 77 | The number of observations. 78 | codes : int32_t* 79 | The pointer to the new codes array. 80 | low_dist : vq_type* 81 | low_dist[i] is the Euclidean distance from obs[i] to the corresponding 82 | centroid. 83 | """ 84 | # Naive algorithm is preferred when nfeat is small 85 | if nfeat < NFEATURES_CUTOFF: 86 | _vq_small_nf(obs, code_book, ncodes, nfeat, nobs, codes, low_dist) 87 | return 0 88 | 89 | cdef np.npy_intp i, j 90 | cdef vq_type *p_obs 91 | cdef vq_type *p_codes 92 | cdef vq_type dist_sqr 93 | cdef np.ndarray[vq_type, ndim=1] obs_sqr, codes_sqr 94 | cdef np.ndarray[vq_type, ndim=2] M 95 | 96 | if vq_type is float32_t: 97 | obs_sqr = np.ndarray(nobs, np.float32) 98 | codes_sqr = np.ndarray(ncodes, np.float32) 99 | M = np.ndarray((nobs, ncodes), np.float32) 100 | else: 101 | obs_sqr = np.ndarray(nobs, np.float64) 102 | codes_sqr = np.ndarray(ncodes, np.float64) 103 | M = np.ndarray((nobs, ncodes), np.float64) 104 | 105 | p_obs = obs 106 | for i in range(nobs): 107 | # obs_sqr[i] is the inner product of the i-th observation with itself 108 | obs_sqr[i] = vec_sqr(nfeat, p_obs) 109 | p_obs += nfeat 110 | 111 | p_codes = code_book 112 | for i in range(ncodes): 113 | # codes_sqr[i] is the inner product of the i-th code with itself 114 | codes_sqr[i] = vec_sqr(nfeat, p_codes) 115 | p_codes += nfeat 116 | 117 | # M[i][j] is the inner product of the i-th obs and j-th code 118 | # M = obs * codes.T 119 | cal_M(nobs, ncodes, nfeat, obs, code_book, M.data) 120 | 121 | for i in range(nobs): 122 | for j in range(ncodes): 123 | dist_sqr = (M[i, j] + 124 | obs_sqr[i] + codes_sqr[j]) 125 | if dist_sqr < low_dist[i]: 126 | codes[i] = j 127 | low_dist[i] = dist_sqr 128 | 129 | # dist_sqr may be negative due to float point errors 130 | if low_dist[i] > 0: 131 | 
low_dist[i] = sqrt(low_dist[i]) 132 | else: 133 | low_dist[i] = 0 134 | 135 | return 0 136 | 137 | 138 | cdef void _vq_small_nf(vq_type *obs, vq_type *code_book, 139 | int ncodes, int nfeat, int nobs, 140 | int32_t *codes, vq_type *low_dist): 141 | """ 142 | Vector quantization using naive algorithm. 143 | This is preferred when nfeat is small. 144 | The parameters are the same as those of _vq. 145 | """ 146 | # Temporary variables 147 | cdef vq_type dist_sqr, diff 148 | cdef np.npy_intp i, j, k, obs_offset = 0, code_offset 149 | 150 | # Index and pointer to keep track of the current position in 151 | # both arrays so that we don't have to always do index * nfeat. 152 | cdef vq_type *current_obs 153 | cdef vq_type *current_code 154 | 155 | for i in range(nobs): 156 | code_offset = 0 157 | current_obs = &(obs[obs_offset]) 158 | 159 | for j in range(ncodes): 160 | dist_sqr = 0 161 | current_code = &(code_book[code_offset]) 162 | 163 | # Distance between code_book[j] and obs[i] 164 | for k in range(nfeat): 165 | diff = current_code[k] - current_obs[k] 166 | dist_sqr += diff * diff 167 | code_offset += nfeat 168 | 169 | # Replace the code assignment and record distance if necessary 170 | if dist_sqr < low_dist[i]: 171 | codes[i] = j 172 | low_dist[i] = dist_sqr 173 | 174 | low_dist[i] = sqrt(low_dist[i]) 175 | 176 | # Update the offset of the current observation 177 | obs_offset += nfeat 178 | 179 | 180 | def vq(np.ndarray obs, np.ndarray codes): 181 | """ 182 | Vector quantization ndarray wrapper. Only support float32 and float64. 183 | 184 | Parameters 185 | ---------- 186 | obs : ndarray 187 | The observation matrix. Each row is an observation. 188 | codes : ndarray 189 | The code book matrix. 190 | 191 | Notes 192 | ----- 193 | The observation matrix and code book matrix should have same ndim and 194 | same number of columns (features). Only 1-dimensional and 2-dimensional 195 | arrays are supported. 
196 | """ 197 | cdef int nobs, ncodes, nfeat 198 | cdef np.ndarray outcodes, outdists 199 | 200 | # Ensure the arrays are contiguous 201 | obs = np.ascontiguousarray(obs) 202 | codes = np.ascontiguousarray(codes) 203 | 204 | if obs.dtype != codes.dtype: 205 | raise TypeError('observation and code should have same dtype') 206 | if obs.dtype not in (np.float32, np.float64): 207 | raise TypeError('type other than float or double not supported') 208 | if obs.ndim != codes.ndim: 209 | raise ValueError( 210 | 'observation and code should have same number of dimensions') 211 | 212 | if obs.ndim == 1: 213 | nfeat = 1 214 | nobs = obs.shape[0] 215 | ncodes = codes.shape[0] 216 | elif obs.ndim == 2: 217 | nfeat = obs.shape[1] 218 | nobs = obs.shape[0] 219 | ncodes = codes.shape[0] 220 | if nfeat != codes.shape[1]: 221 | raise ValueError('obs and code should have same number of ' 222 | 'features (columns)') 223 | else: 224 | raise ValueError('ndim different than 1 or 2 are not supported') 225 | 226 | # Initialize outdists and outcodes array. 227 | # Outdists should be initialized as INF. 228 | outdists = np.empty((nobs,), dtype=obs.dtype) 229 | outcodes = np.empty((nobs,), dtype=np.int32) 230 | outdists.fill(np.inf) 231 | 232 | if obs.dtype.type is np.float32: 233 | _vq(obs.data, codes.data, 234 | ncodes, nfeat, nobs, outcodes.data, 235 | outdists.data) 236 | elif obs.dtype.type is np.float64: 237 | _vq(obs.data, codes.data, 238 | ncodes, nfeat, nobs, outcodes.data, 239 | outdists.data) 240 | 241 | return outcodes, outdists 242 | 243 | 244 | @cython.cdivision(True) 245 | cdef np.ndarray _update_cluster_means(vq_type *obs, int32_t *labels, 246 | vq_type *cb, int nobs, int nc, int nfeat): 247 | """ 248 | The underlying function (template) of _vq.update_cluster_means. 249 | 250 | Parameters 251 | ---------- 252 | obs : vq_type* 253 | The pointer to the observation matrix. 254 | labels : int32_t* 255 | The pointer to the array of the labels (codes) of the observations. 
256 | cb : vq_type* 257 | The pointer to the new code book matrix. 258 | nobs : int 259 | The number of observations. 260 | nc : int 261 | The number of centroids (codes). 262 | nfeat : int 263 | The number of features of each observation. 264 | 265 | Returns 266 | ------- 267 | has_members : ndarray 268 | A boolean array indicating which clusters have members. 269 | """ 270 | cdef np.npy_intp i, j, cluster_size, label 271 | cdef vq_type *obs_p 272 | cdef vq_type *cb_p 273 | cdef np.ndarray[int, ndim=1] obs_count 274 | 275 | # Accumulate the per-cluster feature sums and observation counts 276 | obs_count = np.zeros(nc, np.intc) 277 | obs_p = obs 278 | for i in range(nobs): 279 | label = labels[i] 280 | cb_p = cb + nfeat * label 281 | 282 | for j in range(nfeat): 283 | cb_p[j] += obs_p[j] 284 | 285 | # Count the obs in each cluster 286 | obs_count[label] += 1 287 | obs_p += nfeat 288 | 289 | cb_p = cb 290 | for i in range(nc): 291 | cluster_size = obs_count[i] 292 | 293 | if cluster_size > 0: 294 | # Calculate the centroid of each cluster 295 | for j in range(nfeat): 296 | cb_p[j] /= cluster_size 297 | 298 | cb_p += nfeat 299 | 300 | # Return a boolean array indicating which clusters have members 301 | return obs_count > 0 302 | 303 | 304 | def update_cluster_means(np.ndarray obs, np.ndarray labels, int nc): 305 | """ 306 | The update-step of K-means. Calculate the mean of observations in each 307 | cluster. 308 | 309 | Parameters 310 | ---------- 311 | obs : ndarray 312 | The observation matrix. Each row is an observation. Its dtype must be 313 | float32 or float64. 314 | labels : ndarray 315 | The label of each observation. Must be a 1-D array. 316 | nc : int 317 | The number of centroids. 318 | 319 | Returns 320 | ------- 321 | cb : ndarray 322 | The new code book. 323 | has_members : ndarray 324 | A boolean array indicating which clusters have members. 
325 | 326 | Notes 327 | ----- 328 | The empty clusters will be set to all zeros and the corresponding elements 329 | in `has_members` will be `False`. The upper level function should decide 330 | how to deal with them. 331 | """ 332 | cdef np.ndarray has_members, cb 333 | cdef int nfeat 334 | 335 | # Ensure the arrays are contiguous 336 | obs = np.ascontiguousarray(obs) 337 | labels = np.ascontiguousarray(labels) 338 | 339 | if obs.dtype not in (np.float32, np.float64): 340 | raise TypeError('type other than float or double not supported') 341 | if labels.dtype.type is not np.int32: 342 | labels = labels.astype(np.int32) 343 | if labels.ndim != 1: 344 | raise ValueError('labels must be an 1d array') 345 | 346 | if obs.ndim == 1: 347 | nfeat = 1 348 | cb = np.zeros(nc, dtype=obs.dtype) 349 | elif obs.ndim == 2: 350 | nfeat = obs.shape[1] 351 | cb = np.zeros((nc, nfeat), dtype=obs.dtype) 352 | else: 353 | raise ValueError('ndim different than 1 or 2 are not supported') 354 | 355 | if obs.dtype.type is np.float32: 356 | has_members = _update_cluster_means(<float32_t *>obs.data, 357 | <int32_t *>labels.data, 358 | <float32_t *>cb.data, 359 | obs.shape[0], nc, nfeat) 360 | elif obs.dtype.type is np.float64: 361 | has_members = _update_cluster_means(<float64_t *>obs.data, 362 | <int32_t *>labels.data, 363 | <float64_t *>cb.data, 364 | obs.shape[0], nc, nfeat) 365 | 366 | return cb, has_members 367 | -------------------------------------------------------------------------------- /sstnet/lib/cluster/setup.py: -------------------------------------------------------------------------------- 1 | DEFINE_MACROS = [("SCIPY_PY3K", None)] 2 | 3 | 4 | def configuration(parent_package='', top_path=None): 5 | from numpy.distutils.misc_util import Configuration, get_numpy_include_dirs 6 | config = Configuration('cluster', parent_package, top_path) 7 | 8 | config.add_data_dir('tests') 9 | 10 | config.add_extension('_vq', 11 | sources=[('_vq.c')], 12 | include_dirs=[get_numpy_include_dirs()]) 13 | 14 | config.add_extension('_hierarchy', 15 | 
sources=[('_hierarchy.c')], 16 | include_dirs=[get_numpy_include_dirs()]) 17 | 18 | config.add_extension('_optimal_leaf_ordering', 19 | sources=[('_optimal_leaf_ordering.c')], 20 | include_dirs=[get_numpy_include_dirs()]) 21 | 22 | return config 23 | 24 | 25 | if __name__ == '__main__': 26 | from numpy.distutils.core import setup 27 | setup(**configuration(top_path='').todict()) 28 | -------------------------------------------------------------------------------- /sstnet/lib/cluster/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/SSTNet/53edd4436ae60171a8031fa8709b996277e89835/sstnet/lib/cluster/tests/__init__.py -------------------------------------------------------------------------------- /sstnet/lib/cluster/tests/hierarchy_test_data.py: -------------------------------------------------------------------------------- 1 | from numpy import array 2 | 3 | 4 | Q_X = array([[5.26563660e-01, 3.14160190e-01, 8.00656370e-02], 5 | [7.50205180e-01, 4.60299830e-01, 8.98696460e-01], 6 | [6.65461230e-01, 6.94011420e-01, 9.10465700e-01], 7 | [9.64047590e-01, 1.43082200e-03, 7.39874220e-01], 8 | [1.08159060e-01, 5.53028790e-01, 6.63804780e-02], 9 | [9.31359130e-01, 8.25424910e-01, 9.52315440e-01], 10 | [6.78086960e-01, 3.41903970e-01, 5.61481950e-01], 11 | [9.82730940e-01, 7.04605210e-01, 8.70978630e-02], 12 | [6.14691610e-01, 4.69989230e-02, 6.02406450e-01], 13 | [5.80161260e-01, 9.17354970e-01, 5.88163850e-01], 14 | [1.38246310e+00, 1.96358160e+00, 1.94437880e+00], 15 | [2.10675860e+00, 1.67148730e+00, 1.34854480e+00], 16 | [1.39880070e+00, 1.66142050e+00, 1.32224550e+00], 17 | [1.71410460e+00, 1.49176380e+00, 1.45432170e+00], 18 | [1.54102340e+00, 1.84374950e+00, 1.64658950e+00], 19 | [2.08512480e+00, 1.84524350e+00, 2.17340850e+00], 20 | [1.30748740e+00, 1.53801650e+00, 2.16007740e+00], 21 | [1.41447700e+00, 1.99329070e+00, 1.99107420e+00], 22 | [1.61943490e+00, 
1.47703280e+00, 1.89788160e+00], 23 | [1.59880600e+00, 1.54988980e+00, 1.57563350e+00], 24 | [3.37247380e+00, 2.69635310e+00, 3.39981700e+00], 25 | [3.13705120e+00, 3.36528090e+00, 3.06089070e+00], 26 | [3.29413250e+00, 3.19619500e+00, 2.90700170e+00], 27 | [2.65510510e+00, 3.06785900e+00, 2.97198540e+00], 28 | [3.30941040e+00, 2.59283970e+00, 2.57714110e+00], 29 | [2.59557220e+00, 3.33477370e+00, 3.08793190e+00], 30 | [2.58206180e+00, 3.41615670e+00, 3.26441990e+00], 31 | [2.71127000e+00, 2.77032450e+00, 2.63466500e+00], 32 | [2.79617850e+00, 3.25473720e+00, 3.41801560e+00], 33 | [2.64741750e+00, 2.54538040e+00, 3.25354110e+00]]) 34 | 35 | ytdist = array([662., 877., 255., 412., 996., 295., 468., 268., 400., 754., 36 | 564., 138., 219., 869., 669.]) 37 | 38 | linkage_ytdist_single = array([[2., 5., 138., 2.], 39 | [3., 4., 219., 2.], 40 | [0., 7., 255., 3.], 41 | [1., 8., 268., 4.], 42 | [6., 9., 295., 6.]]) 43 | 44 | linkage_ytdist_complete = array([[2., 5., 138., 2.], 45 | [3., 4., 219., 2.], 46 | [1., 6., 400., 3.], 47 | [0., 7., 412., 3.], 48 | [8., 9., 996., 6.]]) 49 | 50 | linkage_ytdist_average = array([[2., 5., 138., 2.], 51 | [3., 4., 219., 2.], 52 | [0., 7., 333.5, 3.], 53 | [1., 6., 347.5, 3.], 54 | [8., 9., 680.77777778, 6.]]) 55 | 56 | linkage_ytdist_weighted = array([[2., 5., 138., 2.], 57 | [3., 4., 219., 2.], 58 | [0., 7., 333.5, 3.], 59 | [1., 6., 347.5, 3.], 60 | [8., 9., 670.125, 6.]]) 61 | 62 | # the optimal leaf ordering of linkage_ytdist_single 63 | linkage_ytdist_single_olo = array([[5., 2., 138., 2.], 64 | [4., 3., 219., 2.], 65 | [7., 0., 255., 3.], 66 | [1., 8., 268., 4.], 67 | [6., 9., 295., 6.]]) 68 | 69 | X = array([[1.43054825, -7.5693489], 70 | [6.95887839, 6.82293382], 71 | [2.87137846, -9.68248579], 72 | [7.87974764, -6.05485803], 73 | [8.24018364, -6.09495602], 74 | [7.39020262, 8.54004355]]) 75 | 76 | linkage_X_centroid = array([[3., 4., 0.36265956, 2.], 77 | [1., 5., 1.77045373, 2.], 78 | [0., 2., 2.55760419, 2.], 79 | [6., 8., 
6.43614494, 4.], 80 | [7., 9., 15.17363237, 6.]]) 81 | 82 | linkage_X_median = array([[3., 4., 0.36265956, 2.], 83 | [1., 5., 1.77045373, 2.], 84 | [0., 2., 2.55760419, 2.], 85 | [6., 8., 6.43614494, 4.], 86 | [7., 9., 15.17363237, 6.]]) 87 | 88 | linkage_X_ward = array([[3., 4., 0.36265956, 2.], 89 | [1., 5., 1.77045373, 2.], 90 | [0., 2., 2.55760419, 2.], 91 | [6., 8., 9.10208346, 4.], 92 | [7., 9., 24.7784379, 6.]]) 93 | 94 | # the optimal leaf ordering of linkage_X_ward 95 | linkage_X_ward_olo = array([[4., 3., 0.36265956, 2.], 96 | [5., 1., 1.77045373, 2.], 97 | [2., 0., 2.55760419, 2.], 98 | [6., 8., 9.10208346, 4.], 99 | [7., 9., 24.7784379, 6.]]) 100 | 101 | inconsistent_ytdist = { 102 | 1: array([[138., 0., 1., 0.], 103 | [219., 0., 1., 0.], 104 | [255., 0., 1., 0.], 105 | [268., 0., 1., 0.], 106 | [295., 0., 1., 0.]]), 107 | 2: array([[138., 0., 1., 0.], 108 | [219., 0., 1., 0.], 109 | [237., 25.45584412, 2., 0.70710678], 110 | [261.5, 9.19238816, 2., 0.70710678], 111 | [233.66666667, 83.9424406, 3., 0.7306594]]), 112 | 3: array([[138., 0., 1., 0.], 113 | [219., 0., 1., 0.], 114 | [237., 25.45584412, 2., 0.70710678], 115 | [247.33333333, 25.38372182, 3., 0.81417007], 116 | [239., 69.36377537, 4., 0.80733783]]), 117 | 4: array([[138., 0., 1., 0.], 118 | [219., 0., 1., 0.], 119 | [237., 25.45584412, 2., 0.70710678], 120 | [247.33333333, 25.38372182, 3., 0.81417007], 121 | [235., 60.73302232, 5., 0.98793042]])} 122 | 123 | fcluster_inconsistent = { 124 | 0.8: array([6, 2, 2, 4, 6, 2, 3, 7, 3, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 1, 125 | 1, 1, 1, 1, 1, 1, 1, 1, 1]), 126 | 1.0: array([6, 2, 2, 4, 6, 2, 3, 7, 3, 5, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 1, 127 | 1, 1, 1, 1, 1, 1, 1, 1, 1]), 128 | 2.0: array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 129 | 1, 1, 1, 1, 1, 1, 1, 1, 1])} 130 | 131 | fcluster_distance = { 132 | 0.6: array([4, 4, 4, 4, 4, 4, 4, 5, 4, 4, 6, 6, 6, 6, 6, 7, 6, 6, 6, 6, 3, 133 | 1, 1, 1, 2, 1, 1, 1, 1, 1]), 134 | 1.0: array([2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 135 | 1, 1, 1, 1, 1, 1, 1, 1, 1]), 136 | 2.0: array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 137 | 1, 1, 1, 1, 1, 1, 1, 1, 1])} 138 | 139 | fcluster_maxclust = { 140 | 8.0: array([5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 7, 7, 7, 7, 7, 8, 7, 7, 7, 7, 4, 141 | 1, 1, 1, 3, 1, 1, 1, 1, 2]), 142 | 4.0: array([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 143 | 1, 1, 1, 1, 1, 1, 1, 1, 1]), 144 | 1.0: array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 145 | 1, 1, 1, 1, 1, 1, 1, 1, 1])} 146 | -------------------------------------------------------------------------------- /sstnet/lib/cluster/tests/test_vq.py: -------------------------------------------------------------------------------- 1 | 2 | import warnings 3 | import sys 4 | 5 | import numpy as np 6 | from numpy.testing import (assert_array_equal, assert_array_almost_equal, 7 | assert_allclose, assert_equal, assert_, 8 | suppress_warnings) 9 | import pytest 10 | from pytest import raises as assert_raises 11 | 12 | from scipy.cluster.vq import (kmeans, kmeans2, py_vq, vq, whiten, 13 | ClusterError, _krandinit) 14 | from scipy.cluster import _vq 15 | from scipy.sparse.sputils import matrix 16 | 17 | 18 | TESTDATA_2D = np.array([ 19 | -2.2, 1.17, -1.63, 1.69, -2.04, 4.38, -3.09, 0.95, -1.7, 4.79, -1.68, 0.68, 20 | -2.26, 3.34, -2.29, 2.55, -1.72, -0.72, -1.99, 2.34, -2.75, 3.43, -2.45, 21 | 2.41, -4.26, 3.65, -1.57, 1.87, -1.96, 4.03, -3.01, 3.86, -2.53, 1.28, 22 | -4.0, 3.95, -1.62, 1.25, -3.42, 3.17, -1.17, 0.12, -3.03, -0.27, -2.07, 23 | -0.55, -1.17, 1.34, -2.82, 3.08, -2.44, 0.24, -1.71, 2.48, -5.23, 4.29, 24 | -2.08, 3.69, -1.89, 3.62, -2.09, 0.26, -0.92, 1.07, -2.25, 0.88, -2.25, 25 | 2.02, -4.31, 3.86, -2.03, 3.42, -2.76, 0.3, -2.48, -0.29, -3.42, 3.21, 26 | -2.3, 1.73, -2.84, 0.69, -1.81, 2.48, -5.24, 4.52, -2.8, 1.31, -1.67, 27 | -2.34, -1.18, 2.17, -2.17, 2.82, -1.85, 2.25, -2.45, 1.86, -6.79, 3.94, 28 
| -2.33, 1.89, -1.55, 2.08, -1.36, 0.93, -2.51, 2.74, -2.39, 3.92, -3.33, 29 | 2.99, -2.06, -0.9, -2.83, 3.35, -2.59, 3.05, -2.36, 1.85, -1.69, 1.8, 30 | -1.39, 0.66, -2.06, 0.38, -1.47, 0.44, -4.68, 3.77, -5.58, 3.44, -2.29, 31 | 2.24, -1.04, -0.38, -1.85, 4.23, -2.88, 0.73, -2.59, 1.39, -1.34, 1.75, 32 | -1.95, 1.3, -2.45, 3.09, -1.99, 3.41, -5.55, 5.21, -1.73, 2.52, -2.17, 33 | 0.85, -2.06, 0.49, -2.54, 2.07, -2.03, 1.3, -3.23, 3.09, -1.55, 1.44, 34 | -0.81, 1.1, -2.99, 2.92, -1.59, 2.18, -2.45, -0.73, -3.12, -1.3, -2.83, 35 | 0.2, -2.77, 3.24, -1.98, 1.6, -4.59, 3.39, -4.85, 3.75, -2.25, 1.71, -3.28, 36 | 3.38, -1.74, 0.88, -2.41, 1.92, -2.24, 1.19, -2.48, 1.06, -1.68, -0.62, 37 | -1.3, 0.39, -1.78, 2.35, -3.54, 2.44, -1.32, 0.66, -2.38, 2.76, -2.35, 38 | 3.95, -1.86, 4.32, -2.01, -1.23, -1.79, 2.76, -2.13, -0.13, -5.25, 3.84, 39 | -2.24, 1.59, -4.85, 2.96, -2.41, 0.01, -0.43, 0.13, -3.92, 2.91, -1.75, 40 | -0.53, -1.69, 1.69, -1.09, 0.15, -2.11, 2.17, -1.53, 1.22, -2.1, -0.86, 41 | -2.56, 2.28, -3.02, 3.33, -1.12, 3.86, -2.18, -1.19, -3.03, 0.79, -0.83, 42 | 0.97, -3.19, 1.45, -1.34, 1.28, -2.52, 4.22, -4.53, 3.22, -1.97, 1.75, 43 | -2.36, 3.19, -0.83, 1.53, -1.59, 1.86, -2.17, 2.3, -1.63, 2.71, -2.03, 44 | 3.75, -2.57, -0.6, -1.47, 1.33, -1.95, 0.7, -1.65, 1.27, -1.42, 1.09, -3.0, 45 | 3.87, -2.51, 3.06, -2.6, 0.74, -1.08, -0.03, -2.44, 1.31, -2.65, 2.99, 46 | -1.84, 1.65, -4.76, 3.75, -2.07, 3.98, -2.4, 2.67, -2.21, 1.49, -1.21, 47 | 1.22, -5.29, 2.38, -2.85, 2.28, -5.6, 3.78, -2.7, 0.8, -1.81, 3.5, -3.75, 48 | 4.17, -1.29, 2.99, -5.92, 3.43, -1.83, 1.23, -1.24, -1.04, -2.56, 2.37, 49 | -3.26, 0.39, -4.63, 2.51, -4.52, 3.04, -1.7, 0.36, -1.41, 0.04, -2.1, 1.0, 50 | -1.87, 3.78, -4.32, 3.59, -2.24, 1.38, -1.99, -0.22, -1.87, 1.95, -0.84, 51 | 2.17, -5.38, 3.56, -1.27, 2.9, -1.79, 3.31, -5.47, 3.85, -1.44, 3.69, 52 | -2.02, 0.37, -1.29, 0.33, -2.34, 2.56, -1.74, -1.27, -1.97, 1.22, -2.51, 53 | -0.16, -1.64, -0.96, -2.99, 1.4, -1.53, 3.31, -2.24, 0.45, -2.46, 
1.71, 54 | -2.88, 1.56, -1.63, 1.46, -1.41, 0.68, -1.96, 2.76, -1.61, 55 | 2.11]).reshape((200, 2)) 56 | 57 | 58 | # Global data 59 | X = np.array([[3.0, 3], [4, 3], [4, 2], 60 | [9, 2], [5, 1], [6, 2], [9, 4], 61 | [5, 2], [5, 4], [7, 4], [6, 5]]) 62 | 63 | CODET1 = np.array([[3.0000, 3.0000], 64 | [6.2000, 4.0000], 65 | [5.8000, 1.8000]]) 66 | 67 | CODET2 = np.array([[11.0/3, 8.0/3], 68 | [6.7500, 4.2500], 69 | [6.2500, 1.7500]]) 70 | 71 | LABEL1 = np.array([0, 1, 2, 2, 2, 2, 1, 2, 1, 1, 1]) 72 | 73 | 74 | class TestWhiten(object): 75 | def test_whiten(self): 76 | desired = np.array([[5.08738849, 2.97091878], 77 | [3.19909255, 0.69660580], 78 | [4.51041982, 0.02640918], 79 | [4.38567074, 0.95120889], 80 | [2.32191480, 1.63195503]]) 81 | for tp in np.array, matrix: 82 | obs = tp([[0.98744510, 0.82766775], 83 | [0.62093317, 0.19406729], 84 | [0.87545741, 0.00735733], 85 | [0.85124403, 0.26499712], 86 | [0.45067590, 0.45464607]]) 87 | assert_allclose(whiten(obs), desired, rtol=1e-5) 88 | 89 | def test_whiten_zero_std(self): 90 | desired = np.array([[0., 1.0, 2.86666544], 91 | [0., 1.0, 1.32460034], 92 | [0., 1.0, 3.74382172]]) 93 | for tp in np.array, matrix: 94 | obs = tp([[0., 1., 0.74109533], 95 | [0., 1., 0.34243798], 96 | [0., 1., 0.96785929]]) 97 | with warnings.catch_warnings(record=True) as w: 98 | warnings.simplefilter('always') 99 | assert_allclose(whiten(obs), desired, rtol=1e-5) 100 | assert_equal(len(w), 1) 101 | assert_(issubclass(w[-1].category, RuntimeWarning)) 102 | 103 | def test_whiten_not_finite(self): 104 | for tp in np.array, matrix: 105 | for bad_value in np.nan, np.inf, -np.inf: 106 | obs = tp([[0.98744510, bad_value], 107 | [0.62093317, 0.19406729], 108 | [0.87545741, 0.00735733], 109 | [0.85124403, 0.26499712], 110 | [0.45067590, 0.45464607]]) 111 | assert_raises(ValueError, whiten, obs) 112 | 113 | 114 | class TestVq(object): 115 | def test_py_vq(self): 116 | initc = np.concatenate(([[X[0]], [X[1]], [X[2]]])) 117 | for tp in np.array, 
matrix: 118 | label1 = py_vq(tp(X), tp(initc))[0] 119 | assert_array_equal(label1, LABEL1) 120 | 121 | def test_vq(self): 122 | initc = np.concatenate(([[X[0]], [X[1]], [X[2]]])) 123 | for tp in np.array, matrix: 124 | label1, dist = _vq.vq(tp(X), tp(initc)) 125 | assert_array_equal(label1, LABEL1) 126 | tlabel1, tdist = vq(tp(X), tp(initc)) 127 | 128 | def test_vq_1d(self): 129 | # Test special rank 1 vq algo, python implementation. 130 | data = X[:, 0] 131 | initc = data[:3] 132 | a, b = _vq.vq(data, initc) 133 | ta, tb = py_vq(data[:, np.newaxis], initc[:, np.newaxis]) 134 | assert_array_equal(a, ta) 135 | assert_array_equal(b, tb) 136 | 137 | def test__vq_sametype(self): 138 | a = np.array([1.0, 2.0], dtype=np.float64) 139 | b = a.astype(np.float32) 140 | assert_raises(TypeError, _vq.vq, a, b) 141 | 142 | def test__vq_invalid_type(self): 143 | a = np.array([1, 2], dtype=int) 144 | assert_raises(TypeError, _vq.vq, a, a) 145 | 146 | def test_vq_large_nfeat(self): 147 | X = np.random.rand(20, 20) 148 | code_book = np.random.rand(3, 20) 149 | 150 | codes0, dis0 = _vq.vq(X, code_book) 151 | codes1, dis1 = py_vq(X, code_book) 152 | assert_allclose(dis0, dis1, 1e-5) 153 | assert_array_equal(codes0, codes1) 154 | 155 | X = X.astype(np.float32) 156 | code_book = code_book.astype(np.float32) 157 | 158 | codes0, dis0 = _vq.vq(X, code_book) 159 | codes1, dis1 = py_vq(X, code_book) 160 | assert_allclose(dis0, dis1, 1e-5) 161 | assert_array_equal(codes0, codes1) 162 | 163 | def test_vq_large_features(self): 164 | X = np.random.rand(10, 5) * 1000000 165 | code_book = np.random.rand(2, 5) * 1000000 166 | 167 | codes0, dis0 = _vq.vq(X, code_book) 168 | codes1, dis1 = py_vq(X, code_book) 169 | assert_allclose(dis0, dis1, 1e-5) 170 | assert_array_equal(codes0, codes1) 171 | 172 | 173 | class TestKMean(object): 174 | def test_large_features(self): 175 | # Generate a data set with large values, and run kmeans on it to 176 | # (regression for 1077). 
177 | d = 300 178 | n = 100 179 | 180 | m1 = np.random.randn(d) 181 | m2 = np.random.randn(d) 182 | x = 10000 * np.random.randn(n, d) - 20000 * m1 183 | y = 10000 * np.random.randn(n, d) + 20000 * m2 184 | 185 | data = np.empty((x.shape[0] + y.shape[0], d), np.double) 186 | data[:x.shape[0]] = x 187 | data[x.shape[0]:] = y 188 | 189 | kmeans(data, 2) 190 | 191 | def test_kmeans_simple(self): 192 | np.random.seed(54321) 193 | initc = np.concatenate(([[X[0]], [X[1]], [X[2]]])) 194 | for tp in np.array, matrix: 195 | code1 = kmeans(tp(X), tp(initc), iter=1)[0] 196 | assert_array_almost_equal(code1, CODET2) 197 | 198 | def test_kmeans_lost_cluster(self): 199 | # This will cause kmeans to have a cluster with no points. 200 | data = TESTDATA_2D 201 | initk = np.array([[-1.8127404, -0.67128041], 202 | [2.04621601, 0.07401111], 203 | [-2.31149087, -0.05160469]]) 204 | 205 | kmeans(data, initk) 206 | with suppress_warnings() as sup: 207 | sup.filter(UserWarning, 208 | "One of the clusters is empty. 
Re-run kmeans with a " 209 | "different initialization") 210 | kmeans2(data, initk, missing='warn') 211 | 212 | assert_raises(ClusterError, kmeans2, data, initk, missing='raise') 213 | 214 | def test_kmeans2_simple(self): 215 | np.random.seed(12345678) 216 | initc = np.concatenate(([[X[0]], [X[1]], [X[2]]])) 217 | for tp in np.array, matrix: 218 | code1 = kmeans2(tp(X), tp(initc), iter=1)[0] 219 | code2 = kmeans2(tp(X), tp(initc), iter=2)[0] 220 | 221 | assert_array_almost_equal(code1, CODET1) 222 | assert_array_almost_equal(code2, CODET2) 223 | 224 | def test_kmeans2_rank1(self): 225 | data = TESTDATA_2D 226 | data1 = data[:, 0] 227 | 228 | initc = data1[:3] 229 | code = initc.copy() 230 | kmeans2(data1, code, iter=1)[0] 231 | kmeans2(data1, code, iter=2)[0] 232 | 233 | def test_kmeans2_rank1_2(self): 234 | data = TESTDATA_2D 235 | data1 = data[:, 0] 236 | kmeans2(data1, 2, iter=1) 237 | 238 | def test_kmeans2_high_dim(self): 239 | # test kmeans2 when the number of dimensions exceeds the number 240 | # of input points 241 | data = TESTDATA_2D 242 | data = data.reshape((20, 20))[:10] 243 | kmeans2(data, 2) 244 | 245 | def test_kmeans2_init(self): 246 | np.random.seed(12345) 247 | data = TESTDATA_2D 248 | 249 | kmeans2(data, 3, minit='points') 250 | kmeans2(data[:, :1], 3, minit='points') # special case (1-D) 251 | 252 | kmeans2(data, 3, minit='++') 253 | kmeans2(data[:, :1], 3, minit='++') # special case (1-D) 254 | 255 | # minit='random' can give warnings, filter those 256 | with suppress_warnings() as sup: 257 | sup.filter(message="One of the clusters is empty. 
Re-run.") 258 | kmeans2(data, 3, minit='random') 259 | kmeans2(data[:, :1], 3, minit='random') # special case (1-D) 260 | 261 | @pytest.mark.skipif(sys.platform == 'win32', 262 | reason='Fails with MemoryError in Wine.') 263 | def test_krandinit(self): 264 | data = TESTDATA_2D 265 | datas = [data.reshape((200, 2)), data.reshape((20, 20))[:10]] 266 | k = int(1e6) 267 | for data in datas: 268 | np.random.seed(1234) 269 | init = _krandinit(data, k) 270 | orig_cov = np.cov(data, rowvar=0) 271 | init_cov = np.cov(init, rowvar=0) 272 | assert_allclose(orig_cov, init_cov, atol=1e-2) 273 | 274 | def test_kmeans2_empty(self): 275 | # Regression test for gh-1032. 276 | assert_raises(ValueError, kmeans2, [], 2) 277 | 278 | def test_kmeans_0k(self): 279 | # Regression test for gh-1073: fail when k arg is 0. 280 | assert_raises(ValueError, kmeans, X, 0) 281 | assert_raises(ValueError, kmeans2, X, 0) 282 | assert_raises(ValueError, kmeans2, X, np.array([])) 283 | 284 | def test_kmeans_large_thres(self): 285 | # Regression test for gh-1774 286 | x = np.array([1, 2, 3, 4, 10], dtype=float) 287 | res = kmeans(x, 1, thresh=1e16) 288 | assert_allclose(res[0], np.array([4.])) 289 | assert_allclose(res[1], 2.3999999999999999) 290 | 291 | def test_kmeans2_kpp_low_dim(self): 292 | # Regression test for gh-11462 293 | prev_res = np.array([[-1.95266667, 0.898], 294 | [-3.153375, 3.3945]]) 295 | np.random.seed(42) 296 | res, _ = kmeans2(TESTDATA_2D, 2, minit='++') 297 | assert_allclose(res, prev_res) 298 | 299 | def test_kmeans2_kpp_high_dim(self): 300 | # Regression test for gh-11462 301 | n_dim = 100 302 | size = 10 303 | centers = np.vstack([5 * np.ones(n_dim), 304 | -5 * np.ones(n_dim)]) 305 | np.random.seed(42) 306 | data = np.vstack([ 307 | np.random.multivariate_normal(centers[0], np.eye(n_dim), size=size), 308 | np.random.multivariate_normal(centers[1], np.eye(n_dim), size=size) 309 | ]) 310 | res, _ = kmeans2(data, 2, minit='++') 311 | assert_array_almost_equal(res, centers, 
decimal=0) 312 | -------------------------------------------------------------------------------- /sstnet/lib/htree/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Gorilla-Lab. All rights reserved. 2 | from setuptools import setup 3 | from pybind11.setup_helpers import Pybind11Extension, build_ext 4 | # from torch.utils import cpp_extension 5 | 6 | setup( 7 | name="htree", 8 | ext_modules=[ 9 | Pybind11Extension("htree", [ 10 | "src/tree.cpp", 11 | "src/api.cpp", 12 | ]) 13 | ], 14 | cmdclass={"build_ext": build_ext} 15 | ) -------------------------------------------------------------------------------- /sstnet/lib/htree/src/api.cpp: -------------------------------------------------------------------------------- 1 | #include "tree.h" 2 | // #include 3 | 4 | namespace py = pybind11; 5 | 6 | #define TORCH_EXTENSION_NAME htree 7 | 8 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ 9 | m.doc() = "hierarchy construction"; 10 | 11 | py::class_<Tree>(m, "Tree") 12 | .def(py::init<DoubleList &>()) 13 | .def("num", &Tree::getNum, "get the number of nodes") 14 | .def("root", &Tree::getRoot, "get the root id") 15 | .def("is_leaf", &Tree::isLeaf, "judge whether is leaf or not") 16 | .def("get_leaves", &Tree::getLeaves, "get leaves according to given id") 17 | .def("fusion_record", &Tree::fusionRecord, "fusion and record the process"); 18 | } 19 | 20 | -------------------------------------------------------------------------------- /sstnet/lib/htree/src/tree.cpp: -------------------------------------------------------------------------------- 1 | #include "tree.h" 2 | 3 | 4 | // build hierarchical tree according to the connection 5 | Tree::Tree(DoubleList &c) 6 | { 7 | connection = c; 8 | numLeaves = connection.size() + 1; 9 | amount = numLeaves + connection.size(); 10 | } 11 | 12 | int Tree::getNum() 13 | { return amount; } 14 | 15 | int Tree::getRoot() 16 | { return amount - 1; } 17 | 18 | bool Tree::isLeaf(int id) 19 | { 20 | if (id >= 
amount) 21 | { throw "id is out of range!"; } 22 | return (id < numLeaves); 23 | } 24 | 25 | 26 | NumList Tree::getLeaves(int id) 27 | { 28 | NumList leaves = {}; 29 | if (isLeaf(id)) { 30 | leaves.push_back(id); 31 | } else { 32 | // traverse children 33 | int idx = id - numLeaves; 34 | NumList childrenIds = connection[idx]; // children's ids 35 | for (Int i = 0; i < 2; ++i) { 36 | Int childId = childrenIds[i]; 37 | NumList childLeaves = getLeaves(childId); // get children's leaves 38 | leaves.insert(leaves.end(), childLeaves.begin(), childLeaves.end()); 39 | } 40 | } 41 | return leaves; 42 | } 43 | 44 | 45 | std::tuple<NumList, NumList, NumList, NumList, NumList, NumList> Tree::fusionRecord() 46 | { 47 | NumList leftList; 48 | NumList rightList; 49 | NumList fusionList; 50 | NumList leftIds; 51 | NumList rightIds; 52 | NumList fusionIds; 53 | for (Int i = 0; i < connection.size(); ++i) { 54 | NumList connect = connection[i]; 55 | Int c0 = connect[0]; 56 | Int c1 = connect[1]; 57 | // get both leaves 58 | NumList leavesLeft = getLeaves(c0); 59 | NumList leavesRight = getLeaves(c1); 60 | // copy and concat 61 | NumList leavesFusion; 62 | leavesFusion = leavesLeft; 63 | // std::copy(leavesLeft.begin(), leavesLeft.end(), std::back_insert_iterator<NumList>(leavesFusion)); 64 | leavesFusion.insert(leavesFusion.end(), leavesRight.begin(), leavesRight.end()); 65 | // concat to record (TODO wrap a function) 66 | leftList.insert(leftList.end(), leavesLeft.begin(), leavesLeft.end()); 67 | rightList.insert(rightList.end(), leavesRight.begin(), leavesRight.end()); 68 | fusionList.insert(fusionList.end(), leavesFusion.begin(), leavesFusion.end()); 69 | NumList tempIdsLeaf (leavesLeft.size(), i); 70 | NumList tempIdsRight (leavesRight.size(), i); 71 | NumList tempIdsFusion (leavesFusion.size(), i); 72 | leftIds.insert(leftIds.end(), tempIdsLeaf.begin(), tempIdsLeaf.end()); 73 | rightIds.insert(rightIds.end(), tempIdsRight.begin(), tempIdsRight.end()); 74 | fusionIds.insert(fusionIds.end(), tempIdsFusion.begin(), tempIdsFusion.end()); 75 | } 
76 | return std::tuple(leftList, leftIds, rightList, rightIds, fusionList, fusionIds); 77 | } 78 | 79 | -------------------------------------------------------------------------------- /sstnet/lib/htree/src/tree.h: -------------------------------------------------------------------------------- 1 | #ifndef TREE_H 2 | #define TREE_H 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include // everything needed for embedding 8 | #include 9 | #include 10 | #include 11 | 12 | using Int = int32_t; 13 | using NumList = std::vector<Int>; 14 | using DoubleList = std::vector<NumList>; 15 | 16 | class Tree 17 | { 18 | private: 19 | DoubleList connection; 20 | int numLeaves; 21 | int amount; 22 | public: 23 | Tree(DoubleList &c); 24 | int getNum(); 25 | int getRoot(); 26 | bool isLeaf(int id); 27 | NumList getLeaves(int id); 28 | std::tuple<NumList, NumList, NumList, NumList, NumList, NumList> fusionRecord(); 29 | }; 30 | 31 | #endif 32 | 33 | -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .pointgroup_ops import (voxelization_idx, voxelization, point_recover, 2 | ballquery_batch_p, bfs_cluster, roipool, get_iou, 3 | sec_mean, sec_min, sec_max) 4 | 5 | -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/pointgroup_ops.py: -------------------------------------------------------------------------------- 1 | ''' 2 | PointGroup operations 3 | Written by Li Jiang 4 | ''' 5 | 6 | import torch 7 | from torch.autograd import Function 8 | 9 | # import pointgroup_ops_ext 10 | from . 
import pointgroup_ops_ext 11 | 12 | class Voxelization_Idx(Function): 13 | @staticmethod 14 | def forward(ctx, coords, batchsize, mode=4): 15 | ''' 16 | :param ctx: 17 | :param coords: long (N, dimension + 1) or (N, dimension) dimension = 3 18 | :param batchsize 19 | :param mode: int 4=mean 20 | :param dimension: int 21 | :return: output_coords: long (M, dimension + 1) (M <= N) 22 | :return: output_map: int M * (maxActive + 1) 23 | :return: input_map: int N 24 | ''' 25 | assert coords.is_contiguous() 26 | N = coords.size(0) 27 | output_coords = coords.new() 28 | 29 | input_map = torch.IntTensor(N).zero_() 30 | output_map = input_map.new() 31 | 32 | pointgroup_ops_ext.voxelize_idx(coords, output_coords, input_map, output_map, batchsize, mode) 33 | return output_coords, input_map, output_map 34 | 35 | 36 | @staticmethod 37 | def backward(ctx, a=None, b=None, c=None): 38 | return None 39 | 40 | voxelization_idx = Voxelization_Idx.apply 41 | 42 | 43 | class Voxelization(Function): 44 | @staticmethod 45 | def forward(ctx, feats, map_rule, mode=4): 46 | ''' 47 | :param ctx: 48 | :param map_rule: cuda int M * (maxActive + 1) 49 | :param feats: cuda float N * C 50 | :return: output_feats: cuda float M * C 51 | ''' 52 | assert map_rule.is_contiguous() 53 | assert feats.is_contiguous() 54 | N, C = feats.size() 55 | M = map_rule.size(0) 56 | maxActive = map_rule.size(1) - 1 57 | 58 | output_feats = torch.cuda.FloatTensor(M, C).zero_() 59 | 60 | ctx.for_backwards = (map_rule, mode, maxActive, N) 61 | 62 | pointgroup_ops_ext.voxelize_fp(feats, output_feats, map_rule, mode, M, maxActive, C) 63 | return output_feats 64 | 65 | 66 | @staticmethod 67 | def backward(ctx, d_output_feats): 68 | map_rule, mode, maxActive, N = ctx.for_backwards 69 | M, C = d_output_feats.size() 70 | 71 | d_feats = torch.cuda.FloatTensor(N, C).zero_() 72 | 73 | pointgroup_ops_ext.voxelize_bp(d_output_feats.contiguous(), d_feats, map_rule, mode, M, maxActive, C) 74 | return d_feats, None, None 75 | 76 | 
voxelization = Voxelization.apply 77 | 78 | 79 | class PointRecover(Function): 80 | @staticmethod 81 | def forward(ctx, feats, map_rule, nPoint): 82 | ''' 83 | :param ctx: 84 | :param feats: cuda float M * C 85 | :param map_rule: cuda int M * (maxActive + 1) 86 | :param nPoint: int 87 | :return: output_feats: cuda float N * C 88 | ''' 89 | assert map_rule.is_contiguous() 90 | assert feats.is_contiguous() 91 | M, C = feats.size() 92 | maxActive = map_rule.size(1) - 1 93 | 94 | output_feats = torch.cuda.FloatTensor(nPoint, C).zero_() 95 | 96 | ctx.for_backwards = (map_rule, maxActive, M) 97 | 98 | pointgroup_ops_ext.point_recover_fp(feats, output_feats, map_rule, M, maxActive, C) 99 | 100 | return output_feats 101 | 102 | @staticmethod 103 | def backward(ctx, d_output_feats): 104 | map_rule, maxActive, M = ctx.for_backwards 105 | N, C = d_output_feats.size() 106 | 107 | d_feats = torch.cuda.FloatTensor(M, C).zero_() 108 | 109 | pointgroup_ops_ext.point_recover_bp(d_output_feats.contiguous(), d_feats, map_rule, M, maxActive, C) 110 | 111 | return d_feats, None, None 112 | 113 | point_recover = PointRecover.apply 114 | 115 | 116 | class BallQueryBatchP(Function): 117 | @staticmethod 118 | def forward(ctx, coords, batch_idxs, batch_offsets, radius, meanActive): 119 | ''' 120 | :param ctx: 121 | :param coords: (n, 3) float 122 | :param batch_idxs: (n) int 123 | :param batch_offsets: (B+1) int 124 | :param radius: float 125 | :param meanActive: int 126 | :return: idx (nActive), int 127 | :return: start_len (n, 2), int 128 | ''' 129 | 130 | n = coords.size(0) 131 | 132 | assert coords.is_contiguous() and coords.is_cuda 133 | assert batch_idxs.is_contiguous() and batch_idxs.is_cuda 134 | assert batch_offsets.is_contiguous() and batch_offsets.is_cuda 135 | 136 | while True: 137 | idx = torch.cuda.IntTensor(n * meanActive).zero_() 138 | start_len = torch.cuda.IntTensor(n, 2).zero_() 139 | nActive = pointgroup_ops_ext.ballquery_batch_p(coords, batch_idxs, batch_offsets, idx, 
start_len, n, meanActive, radius) 140 | if nActive <= n * meanActive: 141 | break 142 | meanActive = int(nActive // n + 1) 143 | idx = idx[:nActive] 144 | 145 | return idx, start_len 146 | 147 | @staticmethod 148 | def backward(ctx, a=None, b=None): 149 | return None, None, None 150 | 151 | ballquery_batch_p = BallQueryBatchP.apply 152 | 153 | 154 | class BFSCluster(Function): 155 | @staticmethod 156 | def forward(ctx, semantic_label, ball_query_idxs, start_len, threshold): 157 | ''' 158 | :param ctx: 159 | :param semantic_label: (N), int 160 | :param ball_query_idxs: (nActive), int 161 | :param start_len: (N, 2), int 162 | :return: cluster_idxs: int (sumNPoint, 2), dim 0 for cluster_id, dim 1 for corresponding point idxs in N 163 | :return: cluster_offsets: int (nCluster + 1) 164 | ''' 165 | 166 | N = start_len.size(0) 167 | 168 | assert semantic_label.is_contiguous() 169 | assert ball_query_idxs.is_contiguous() 170 | assert start_len.is_contiguous() 171 | 172 | cluster_idxs = semantic_label.new() 173 | cluster_offsets = semantic_label.new() 174 | 175 | pointgroup_ops_ext.bfs_cluster(semantic_label, ball_query_idxs, start_len, cluster_idxs, cluster_offsets, N, threshold) 176 | 177 | return cluster_idxs, cluster_offsets 178 | 179 | @staticmethod 180 | def backward(ctx, a=None): 181 | return None 182 | 183 | bfs_cluster = BFSCluster.apply 184 | 185 | 186 | class RoiPool(Function): 187 | @staticmethod 188 | def forward(ctx, feats, proposals_offset): 189 | ''' 190 | :param ctx: 191 | :param feats: (sumNPoint, C) float 192 | :param proposals_offset: (nProposal + 1) int 193 | :return: output_feats (nProposal, C) float 194 | ''' 195 | nProposal = proposals_offset.size(0) - 1 196 | sumNPoint, C = feats.size() 197 | 198 | assert feats.is_contiguous() 199 | assert proposals_offset.is_contiguous() 200 | 201 | output_feats = torch.cuda.FloatTensor(nProposal, C).zero_() 202 | output_maxidx = torch.cuda.IntTensor(nProposal, C).zero_() 203 | 204 | 
pointgroup_ops_ext.roipool_fp(feats, proposals_offset, output_feats, output_maxidx, nProposal, C) 205 | 206 | ctx.for_backwards = (output_maxidx, proposals_offset, sumNPoint) 207 | 208 | return output_feats 209 | 210 | @staticmethod 211 | def backward(ctx, d_output_feats): 212 | nProposal, C = d_output_feats.size() 213 | 214 | output_maxidx, proposals_offset, sumNPoint = ctx.for_backwards 215 | 216 | d_feats = torch.cuda.FloatTensor(sumNPoint, C).zero_() 217 | 218 | pointgroup_ops_ext.roipool_bp(d_feats, proposals_offset, output_maxidx, d_output_feats.contiguous(), nProposal, C) 219 | 220 | return d_feats, None 221 | 222 | roipool = RoiPool.apply 223 | 224 | 225 | class GetIoU(Function): 226 | @staticmethod 227 | def forward(ctx, proposals_idx, proposals_offset, instance_labels, instance_pointnum): 228 | ''' 229 | :param ctx: 230 | :param proposals_idx: (sumNPoint), int 231 | :param proposals_offset: (nProposal + 1), int 232 | :param instance_labels: (N), long, 0~total_nInst-1, -100 233 | :param instance_pointnum: (total_nInst), int 234 | :return: proposals_iou: (nProposal, total_nInst), float 235 | ''' 236 | nInstance = instance_pointnum.size(0) 237 | nProposal = proposals_offset.size(0) - 1 238 | 239 | assert proposals_idx.is_contiguous() and proposals_idx.is_cuda 240 | assert proposals_offset.is_contiguous() and proposals_offset.is_cuda 241 | assert instance_labels.is_contiguous() and instance_labels.is_cuda 242 | assert instance_pointnum.is_contiguous() and instance_pointnum.is_cuda 243 | 244 | proposals_iou = torch.cuda.FloatTensor(nProposal, nInstance).zero_() 245 | 246 | pointgroup_ops_ext.get_iou(proposals_idx, proposals_offset, instance_labels, instance_pointnum, proposals_iou, nInstance, nProposal) 247 | 248 | return proposals_iou 249 | 250 | @staticmethod 251 | def backward(ctx, a=None): 252 | return None, None, None, None 253 | 254 | get_iou = GetIoU.apply 255 | 256 | 257 | class SecMean(Function): 258 | @staticmethod 259 | def forward(ctx, inp, 
offsets): 260 | ''' 261 | :param ctx: 262 | :param inp: (N, C) float 263 | :param offsets: (nProposal + 1) int 264 | :return: out (nProposal, C) float 265 | ''' 266 | nProposal = offsets.size(0) - 1 267 | C = inp.size(1) 268 | 269 | assert inp.is_contiguous() 270 | assert offsets.is_contiguous() 271 | 272 | out = torch.cuda.FloatTensor(nProposal, C).zero_() 273 | 274 | pointgroup_ops_ext.sec_mean(inp, offsets, out, nProposal, C) 275 | 276 | return out 277 | 278 | @staticmethod 279 | def backward(ctx, a=None): 280 | return None, None 281 | 282 | sec_mean = SecMean.apply 283 | 284 | 285 | class SecMin(Function): 286 | @staticmethod 287 | def forward(ctx, inp, offsets): 288 | ''' 289 | :param ctx: 290 | :param inp: (N, C) float 291 | :param offsets: (nProposal + 1) int 292 | :return: out (nProposal, C) float 293 | ''' 294 | nProposal = offsets.size(0) - 1 295 | C = inp.size(1) 296 | 297 | assert inp.is_contiguous() 298 | assert offsets.is_contiguous() 299 | 300 | out = torch.cuda.FloatTensor(nProposal, C).zero_() 301 | 302 | pointgroup_ops_ext.sec_min(inp, offsets, out, nProposal, C) 303 | 304 | return out 305 | 306 | @staticmethod 307 | def backward(ctx, a=None): 308 | return None, None 309 | 310 | sec_min = SecMin.apply 311 | 312 | 313 | class SecMax(Function): 314 | @staticmethod 315 | def forward(ctx, inp, offsets): 316 | ''' 317 | :param ctx: 318 | :param inp: (N, C) float 319 | :param offsets: (nProposal + 1) int 320 | :return: out (nProposal, C) float 321 | ''' 322 | nProposal = offsets.size(0) - 1 323 | C = inp.size(1) 324 | 325 | assert inp.is_contiguous() 326 | assert offsets.is_contiguous() 327 | 328 | out = torch.cuda.FloatTensor(nProposal, C).zero_() 329 | 330 | pointgroup_ops_ext.sec_max(inp, offsets, out, nProposal, C) 331 | 332 | return out 333 | 334 | @staticmethod 335 | def backward(ctx, a=None): 336 | return None, None 337 | 338 | sec_max = SecMax.apply -------------------------------------------------------------------------------- 
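The `bfs_cluster` wrapper above hands three flat buffers to the C++ extension and gets back `cluster_idxs` and `cluster_offsets`. As a reading aid, the following is a minimal pure-Python sketch of what that clustering step appears to compute, based on the docstring and the C++ source that follows. The function name `bfs_cluster_reference` and the use of plain lists instead of tensors are assumptions made for illustration; this is not part of the repository and is not a drop-in replacement for the extension.

```python
from collections import deque


def bfs_cluster_reference(semantic_label, ball_query_idxs, start_len, threshold):
    """Illustrative reference for the CPU clustering step (hypothetical helper).

    semantic_label:  list of N ints, one predicted class per point
    ball_query_idxs: flat list of neighbour indices (length nActive)
    start_len:       list of N (start, length) pairs into ball_query_idxs
    threshold:       minimum number of points for a cluster to be kept
    Returns (cluster_idxs, cluster_offsets): cluster_idxs is a list of
    (cluster_id, point_idx) pairs, cluster_offsets has nCluster + 1 entries.
    """
    n = len(start_len)
    visited = [False] * n
    cluster_idxs, cluster_offsets = [], [0]

    for seed in range(n):
        if visited[seed]:
            continue
        # Breadth-first search over ball-query neighbours with the same label.
        visited[seed] = True
        members = [seed]
        queue = deque([seed])
        while queue:
            cur = queue.popleft()
            start, length = start_len[cur]
            for nb in ball_query_idxs[start:start + length]:
                if semantic_label[nb] != semantic_label[cur] or visited[nb]:
                    continue
                visited[nb] = True
                members.append(nb)
                queue.append(nb)
        # Small components are dropped, but their points stay marked as visited.
        if len(members) >= threshold:
            cluster_id = len(cluster_offsets) - 1
            cluster_idxs.extend((cluster_id, p) for p in members)
            cluster_offsets.append(cluster_offsets[-1] + len(members))
    return cluster_idxs, cluster_offsets


# Five points: 0 and 1 are neighbours, 2 and 3 are neighbours, 4 is isolated.
labels = [0, 0, 1, 1, 0]
neighbours = [0, 1, 0, 1, 2, 3, 2, 3, 4]
spans = [(0, 2), (2, 2), (4, 2), (6, 2), (8, 1)]
idxs, offsets = bfs_cluster_reference(labels, neighbours, spans, threshold=2)
print(idxs)     # [(0, 0), (0, 1), (1, 2), (1, 3)]
print(offsets)  # [0, 2, 4]
```

The offsets array is a prefix sum over cluster sizes, so the points of cluster `i` occupy rows `offsets[i]` to `offsets[i + 1] - 1` of `cluster_idxs`; the same layout is what `roipool` and the `sec_*` functions expect as `proposals_offset` / `offsets`.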
/sstnet/lib/pointgroup_ops/pointgroup_ops/src/bfs_cluster/bfs_cluster.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Ball Query with BatchIdx & Clustering Algorithm 3 | Written by Li Jiang 4 | All Rights Reserved 2020. 5 | */ 6 | 7 | #include "bfs_cluster.h" 8 | 9 | /* ================================== ballquery_batch_p ================================== */ 10 | // input xyz: (n, 3) float 11 | // input batch_idxs: (n) int 12 | // input batch_offsets: (B+1) int, batch_offsets[-1] 13 | // output idx: (n * meanActive) dim 0 for number of points in the ball, idx in n 14 | // output start_len: (n, 2), int 15 | int ballquery_batch_p(at::Tensor xyz_tensor, at::Tensor batch_idxs_tensor, at::Tensor batch_offsets_tensor, at::Tensor idx_tensor, at::Tensor start_len_tensor, int n, int meanActive, float radius){ 16 | const float *xyz = xyz_tensor.data<float>(); 17 | const int *batch_idxs = batch_idxs_tensor.data<int>(); 18 | const int *batch_offsets = batch_offsets_tensor.data<int>(); 19 | int *idx = idx_tensor.data<int>(); 20 | int *start_len = start_len_tensor.data<int>(); 21 | 22 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 23 | int cumsum = ballquery_batch_p_cuda(n, meanActive, radius, xyz, batch_idxs, batch_offsets, idx, start_len, stream); 24 | return cumsum; 25 | } 26 | 27 | /* ================================== bfs_cluster ================================== */ 28 | ConnectedComponent find_cc(Int idx, int *semantic_label, Int *ball_query_idxs, int *start_len, int *visited){ 29 | ConnectedComponent cc; 30 | cc.addPoint(idx); 31 | visited[idx] = 1; 32 | 33 | std::queue<Int> Q; 34 | assert(Q.empty()); 35 | Q.push(idx); 36 | 37 | while(!Q.empty()){ 38 | Int cur = Q.front(); Q.pop(); 39 | int start = start_len[cur * 2]; 40 | int len = start_len[cur * 2 + 1]; 41 | int label_cur = semantic_label[cur]; 42 | for(Int i = start; i < start + len; i++){ 43 | Int idx_i = ball_query_idxs[i]; 44 | if(semantic_label[idx_i] != label_cur) continue; 45 | 
if(visited[idx_i] == 1) continue; 46 | 47 | cc.addPoint(idx_i); 48 | visited[idx_i] = 1; 49 | 50 | Q.push(idx_i); 51 | } 52 | } 53 | return cc; 54 | } 55 | 56 | //input: semantic_label, int, N 57 | //input: ball_query_idxs, Int, (nActive) 58 | //input: start_len, int, (N, 2) 59 | //output: clusters, CCs 60 | int get_clusters(int *semantic_label, Int *ball_query_idxs, int *start_len, const Int nPoint, int threshold, ConnectedComponents &clusters){ 61 | std::vector<int> visited(nPoint, 0); // heap-allocated; a variable-length stack array is non-standard C++ and can overflow on large scenes 62 | 63 | int sumNPoint = 0; 64 | for(Int i = 0; i < nPoint; i++){ 65 | if(visited[i] == 0){ 66 | ConnectedComponent CC = find_cc(i, semantic_label, ball_query_idxs, start_len, visited.data()); 67 | if((int)CC.pt_idxs.size() >= threshold){ 68 | clusters.push_back(CC); 69 | sumNPoint += (int)CC.pt_idxs.size(); 70 | } 71 | } 72 | } 73 | 74 | return sumNPoint; 75 | } 76 | 77 | void fill_cluster_idxs_(ConnectedComponents &CCs, int *cluster_idxs, int *cluster_offsets){ 78 | for(int i = 0; i < (int)CCs.size(); i++){ 79 | cluster_offsets[i + 1] = cluster_offsets[i] + (int)CCs[i].pt_idxs.size(); 80 | for(int j = 0; j < (int)CCs[i].pt_idxs.size(); j++){ 81 | int idx = CCs[i].pt_idxs[j]; 82 | cluster_idxs[(cluster_offsets[i] + j) * 2 + 0] = i; 83 | cluster_idxs[(cluster_offsets[i] + j) * 2 + 1] = idx; 84 | } 85 | } 86 | } 87 | 88 | //input: semantic_label, int, N 89 | //input: ball_query_idxs, int, (nActive) 90 | //input: start_len, int, (N, 2) 91 | //output: cluster_idxs, int (sumNPoint, 2), dim 0 for cluster_id, dim 1 for corresponding point idxs in N 92 | //output: cluster_offsets, int (nCluster + 1) 93 | void bfs_cluster(at::Tensor semantic_label_tensor, at::Tensor ball_query_idxs_tensor, at::Tensor start_len_tensor, 94 | at::Tensor cluster_idxs_tensor, at::Tensor cluster_offsets_tensor, const int N, int threshold){ 95 | int *semantic_label = semantic_label_tensor.data<int>(); 96 | Int *ball_query_idxs = ball_query_idxs_tensor.data<Int>(); 97 | int *start_len = start_len_tensor.data<int>(); 98 | 99 | 
ConnectedComponents CCs; 100 | int sumNPoint = get_clusters(semantic_label, ball_query_idxs, start_len, N, threshold, CCs); 101 | 102 | int nCluster = (int)CCs.size(); 103 | cluster_idxs_tensor.resize_({sumNPoint, 2}); 104 | cluster_offsets_tensor.resize_({nCluster + 1}); 105 | cluster_idxs_tensor.zero_(); 106 | cluster_offsets_tensor.zero_(); 107 | 108 | int *cluster_idxs = cluster_idxs_tensor.data(); 109 | int *cluster_offsets = cluster_offsets_tensor.data(); 110 | 111 | fill_cluster_idxs_(CCs, cluster_idxs, cluster_offsets); 112 | } -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/bfs_cluster/bfs_cluster.cu: -------------------------------------------------------------------------------- 1 | /* 2 | Ball Query with BatchIdx 3 | Written by Li Jiang 4 | All Rights Reserved 2020. 5 | */ 6 | #include "bfs_cluster.h" 7 | #include "../cuda_utils.h" 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | 14 | /* ================================== ballquery_batch_p ================================== */ 15 | __global__ void ballquery_batch_p_cuda_(int n, int meanActive, float radius, const float *xyz, const int *batch_idxs, const int *batch_offsets, int *idx, int *start_len, int *cumsum) { 16 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 17 | if (pt_idx >= n) return; 18 | 19 | start_len += (pt_idx * 2); 20 | int idx_temp[1000]; 21 | 22 | float radius2 = radius * radius; 23 | float o_x = xyz[pt_idx * 3 + 0]; 24 | float o_y = xyz[pt_idx * 3 + 1]; 25 | float o_z = xyz[pt_idx * 3 + 2]; 26 | 27 | int batch_idx = batch_idxs[pt_idx]; 28 | int start = batch_offsets[batch_idx]; 29 | int end = batch_offsets[batch_idx + 1]; 30 | 31 | int cnt = 0; 32 | for(int k = start; k < end; k++){ 33 | float x = xyz[k * 3 + 0]; 34 | float y = xyz[k * 3 + 1]; 35 | float z = xyz[k * 3 + 2]; 36 | float d2 = (o_x - x) * (o_x - x) + (o_y - y) * (o_y - y) + (o_z - z) * (o_z - z); 37 | if(d2 < radius2){ 38 | if(cnt < 
1000){ 39 | idx_temp[cnt] = k; 40 | } 41 | else{ 42 | break; 43 | } 44 | ++cnt; 45 | } 46 | } 47 | 48 | start_len[0] = atomicAdd(cumsum, cnt); 49 | start_len[1] = cnt; 50 | 51 | int thre = n * meanActive; 52 | if(start_len[0] >= thre) return; 53 | 54 | idx += start_len[0]; 55 | if(start_len[0] + cnt >= thre) cnt = thre - start_len[0]; 56 | 57 | for(int k = 0; k < cnt; k++){ 58 | idx[k] = idx_temp[k]; 59 | } 60 | } 61 | 62 | 63 | int ballquery_batch_p_cuda(int n, int meanActive, float radius, const float *xyz, const int *batch_idxs, const int *batch_offsets, int *idx, int *start_len, cudaStream_t stream) { 64 | // param xyz: (n, 3) 65 | // param batch_idxs: (n) 66 | // param batch_offsets: (B + 1) 67 | // output idx: (n * meanActive) dim 0 for number of points in the ball, idx in n 68 | // output start_len: (n, 2), int 69 | 70 | cudaError_t err; 71 | 72 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK)); 73 | dim3 threads(THREADS_PER_BLOCK); 74 | 75 | int cumsum = 0; 76 | int* p_cumsum; 77 | cudaMalloc((void**)&p_cumsum, sizeof(int)); 78 | cudaMemcpy(p_cumsum, &cumsum, sizeof(int), cudaMemcpyHostToDevice); 79 | 80 | ballquery_batch_p_cuda_<<<blocks, threads, 0, stream>>>(n, meanActive, radius, xyz, batch_idxs, batch_offsets, idx, start_len, p_cumsum); 81 | 82 | err = cudaGetLastError(); 83 | if (cudaSuccess != err) { 84 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 85 | exit(-1); 86 | } 87 | 88 | cudaMemcpy(&cumsum, p_cumsum, sizeof(int), cudaMemcpyDeviceToHost); 89 | cudaFree(p_cumsum); // release the device counter; without this every call leaks sizeof(int) bytes of device memory 90 | return cumsum; 91 | } -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/bfs_cluster/bfs_cluster.h: -------------------------------------------------------------------------------- 1 | /* 2 | Ball Query with BatchIdx & Clustering Algorithm 3 | Written by Li Jiang 4 | All Rights Reserved 2020. 
5 | */ 6 | 7 | #ifndef BFS_CLUSTER_H 8 | #define BFS_CLUSTER_H 9 | #include 10 | #include 11 | #include 12 | 13 | #include "../datatype/datatype.h" 14 | 15 | int ballquery_batch_p(at::Tensor xyz_tensor, at::Tensor batch_idxs_tensor, at::Tensor batch_offsets_tensor, at::Tensor idx_tensor, at::Tensor start_len_tensor, int n, int meanActive, float radius); 16 | int ballquery_batch_p_cuda(int n, int meanActive, float radius, const float *xyz, const int *batch_idxs, const int *batch_offsets, int *idx, int *start_len, cudaStream_t stream); 17 | 18 | void bfs_cluster(at::Tensor semantic_label_tensor, at::Tensor ball_query_idxs_tensor, at::Tensor start_len_tensor, at::Tensor cluster_idxs_tensor, at::Tensor cluster_offsets_tensor, const int N, int threshold); 19 | 20 | #endif //BFS_CLUSTER_H -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "datatype/datatype.h" 3 | 4 | #include "voxelize/voxelize.cu" 5 | #include "bfs_cluster/bfs_cluster.cu" 6 | #include "roipool/roipool.cu" 7 | #include "get_iou/get_iou.cu" 8 | #include "sec_mean/sec_mean.cu" 9 | 10 | template void voxelize_fp_cuda(Int nOutputRows, Int maxActive, Int nPlanes, float *feats, float *output_feats, Int *rules, bool average); 11 | 12 | template void voxelize_bp_cuda(Int nOutputRows, Int maxActive, Int nPlanes, float *d_output_feats, float *d_feats, Int *rules, bool average); -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | 6 | #define TOTAL_THREADS 1024 7 | 8 | #define THREADS_PER_BLOCK 512 9 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) 10 | 11 | inline int opt_n_threads(int 
work_size) { 12 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 13 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 14 | } 15 | 16 | inline dim3 opt_block_config(int x, int y) { 17 | const int x_threads = opt_n_threads(x); 18 | const int y_threads = max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 19 | dim3 block_config(x_threads, y_threads, 1); 20 | return block_config; 21 | } 22 | 23 | #endif -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/datatype/datatype.cpp: -------------------------------------------------------------------------------- 1 | #include "datatype.h" 2 | 3 | template SparseGrid::SparseGrid() : ctr(0) { 4 | // Sparsehash needs a key to be set aside and never used 5 | Point empty_key; 6 | for(Int i = 0; i < dimension; i++){ 7 | empty_key[i] = std::numeric_limits::min(); 8 | } 9 | mp.set_empty_key(empty_key); 10 | } 11 | 12 | ConnectedComponent::ConnectedComponent(){} 13 | 14 | void ConnectedComponent::addPoint(Int pt_idx){ 15 | pt_idxs.push_back(pt_idx); 16 | } 17 | -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/datatype/datatype.h: -------------------------------------------------------------------------------- 1 | #ifndef DATATYPE_H 2 | #define DATATYPE_H 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | using Int = int32_t; 10 | 11 | template using Point = std::array; 12 | 13 | template struct IntArrayHash{ 14 | std::size_t operator()(Point const &p) const{ 15 | Int hash = 16777619; 16 | for(auto x : p){ 17 | hash *= 2166136261; 18 | hash ^= x; 19 | } 20 | return hash; 21 | } 22 | }; 23 | 24 | template using SparseGridMap = google::dense_hash_map, Int, IntArrayHash, std::equal_to>>; // 25 | 26 | template class SparseGrid{ 27 | public: 28 | Int ctr; 29 | SparseGridMap mp; 30 | SparseGrid(); 31 | }; 32 | 33 | template using SparseGrids = 
std::vector>; 34 | 35 | using RuleBook = std::vector>; 36 | 37 | class ConnectedComponent{ 38 | public: 39 | std::vector pt_idxs; 40 | 41 | ConnectedComponent(); 42 | void addPoint(Int pt_idx); 43 | }; 44 | 45 | using ConnectedComponents = std::vector; 46 | 47 | #endif //DATATYPE_H -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/get_iou/get_iou.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Get the IoU between predictions and gt masks 3 | Written by Li Jiang 4 | All Rights Reserved 2020. 5 | */ 6 | 7 | #include "get_iou.h" 8 | 9 | void get_iou(at::Tensor proposals_idx_tensor, at::Tensor proposals_offset_tensor, at::Tensor instance_labels_tensor, at::Tensor instance_pointnum_tensor, at::Tensor proposals_iou_tensor, int nInstance, int nProposal){ 10 | int *proposals_idx = proposals_idx_tensor.data(); 11 | int *proposals_offset = proposals_offset_tensor.data(); 12 | long *instance_labels = instance_labels_tensor.data(); 13 | int *instance_pointnum = instance_pointnum_tensor.data(); 14 | 15 | float *proposals_iou = proposals_iou_tensor.data(); 16 | 17 | get_iou_cuda(nInstance, nProposal, proposals_idx, proposals_offset, instance_labels, instance_pointnum, proposals_iou); 18 | } -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/get_iou/get_iou.cu: -------------------------------------------------------------------------------- 1 | /* 2 | Get the IoU between predictions and gt masks 3 | Written by Li Jiang 4 | All Rights Reserved 2020. 
5 | */ 6 | 7 | #include 8 | #include 9 | #include "get_iou.h" 10 | 11 | 12 | __global__ void get_iou_cuda_(int nInstance, int nProposal, int *proposals_idx, int *proposals_offset, long *instance_labels, int *instance_pointnum, float *proposals_iou){ 13 | for(int proposal_id = blockIdx.x; proposal_id < nProposal; proposal_id += gridDim.x){ 14 | int start = proposals_offset[proposal_id]; 15 | int end = proposals_offset[proposal_id + 1]; 16 | int proposal_total = end - start; 17 | for(int instance_id = threadIdx.x; instance_id < nInstance; instance_id += blockDim.x){ 18 | int instance_total = instance_pointnum[instance_id]; 19 | int intersection = 0; 20 | for(int i = start; i < end; i++){ 21 | int idx = proposals_idx[i]; 22 | if((int)instance_labels[idx] == instance_id){ 23 | intersection += 1; 24 | } 25 | } 26 | proposals_iou[proposal_id * nInstance + instance_id] = (float)intersection / ((float)(proposal_total + instance_total - intersection) + 1e-5); 27 | } 28 | } 29 | } 30 | 31 | //input: proposals_idx (sumNPoint), int 32 | //input: proposals_offset (nProposal + 1), int 33 | //input: instance_labels (N), long, 0~total_nInst-1, -100 34 | //input: instance_pointnum (total_nInst), int 35 | //output: proposals_iou (nProposal, total_nInst), float 36 | void get_iou_cuda(int nInstance, int nProposal, int *proposals_idx, int *proposals_offset, long *instance_labels, int *instance_pointnum, float *proposals_iou){ 37 | get_iou_cuda_<<>>(nInstance, nProposal, proposals_idx, proposals_offset, instance_labels, instance_pointnum, proposals_iou); 38 | } -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/get_iou/get_iou.h: -------------------------------------------------------------------------------- 1 | /* 2 | Get the IoU between predictions and gt masks 3 | Written by Li Jiang 4 | All Rights Reserved 2020. 
5 | */ 6 | 7 | #ifndef GET_IOU_H 8 | #define GET_IOU_H 9 | #include 10 | #include 11 | 12 | #include "../datatype/datatype.h" 13 | 14 | // 15 | void get_iou_cuda(int nInstance, int nProposal, int *proposals_idx, int *proposals_offset, long *instance_labels, int *instance_pointnum, float *proposals_iou); 16 | void get_iou(at::Tensor proposals_idx_tensor, at::Tensor proposals_offset_tensor, at::Tensor instance_labels_tensor, at::Tensor instance_pointnum_tensor, at::Tensor proposals_iou_tensor, int nInstance, int nProposal); 17 | 18 | #endif //GET_IOU_H -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/pointgroup_ops.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "datatype/datatype.cpp" 6 | 7 | #include "voxelize/voxelize.cpp" 8 | #include "bfs_cluster/bfs_cluster.cpp" 9 | #include "roipool/roipool.cpp" 10 | #include "get_iou/get_iou.cpp" 11 | #include "sec_mean/sec_mean.cpp" 12 | 13 | void voxelize_idx_3d(/* long N*4 */ at::Tensor coords, /* long M*4 */ at::Tensor output_coords, 14 | /* Int N */ at::Tensor input_map, /* Int M*(maxActive+1) */ at::Tensor output_map, Int batchSize, Int mode){ 15 | voxelize_idx<3>(coords, output_coords, input_map, output_map, batchSize, mode); 16 | } 17 | 18 | void voxelize_fp_feat(/* cuda float N*C */ at::Tensor feats, // N * 3 -> M * 3 (N >= M) 19 | /* cuda float M*C */ at::Tensor output_feats, 20 | /* cuda Int M*(maxActive+1) */ at::Tensor output_map, Int mode, Int nActive, Int maxActive, Int nPlane){ 21 | voxelize_fp(feats, output_feats, output_map, mode, nActive, maxActive, nPlane); 22 | } 23 | 24 | 25 | void voxelize_bp_feat(/* cuda float M*C */ at::Tensor d_output_feats, /* cuda float N*C */ at::Tensor d_feats, /* cuda Int M*(maxActive+1) */ at::Tensor output_map, 26 | Int mode, Int nActive, Int maxActive, Int nPlane){ 27 | voxelize_bp(d_output_feats, d_feats, 
output_map, mode, nActive, maxActive, nPlane); 28 | } 29 | 30 | void point_recover_fp_feat(/* cuda float M*C */ at::Tensor feats, /* cuda float N*C */ at::Tensor output_feats, /* cuda Int M*(maxActive+1) */ at::Tensor idx_map, 31 | Int nActive, Int maxActive, Int nPlane){ 32 | point_recover_fp(feats, output_feats, idx_map, nActive, maxActive, nPlane); 33 | } 34 | 35 | void point_recover_bp_feat(/* cuda float N*C */ at::Tensor d_output_feats, /* cuda float M*C */ at::Tensor d_feats, /* cuda Int M*(maxActive+1) */ at::Tensor idx_map, 36 | Int nActive, Int maxActive, Int nPlane){ 37 | point_recover_bp(d_output_feats, d_feats, idx_map, nActive, maxActive, nPlane); 38 | } 39 | -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/pointgroup_ops.h: -------------------------------------------------------------------------------- 1 | #ifndef POINTGROUP_H 2 | #define POINTGROUP_H 3 | #include "datatype/datatype.h" 4 | 5 | #include "bfs_cluster/bfs_cluster.h" 6 | #include "roipool/roipool.h" 7 | #include "get_iou/get_iou.h" 8 | #include "sec_mean/sec_mean.h" 9 | 10 | void voxelize_idx_3d(/* long N*4 */ at::Tensor coords, /* long M*4 */ at::Tensor output_coords, 11 | /* Int N */ at::Tensor input_map, /* Int M*(maxActive+1) */ at::Tensor output_map, Int batchSize, Int mode); 12 | 13 | void voxelize_fp_feat(/* cuda float N*C */ at::Tensor feats, // N * 3 -> M * 3 (N >= M) 14 | /* cuda float M*C */ at::Tensor output_feats, 15 | /* cuda Int M*(maxActive+1) */ at::Tensor output_map, Int mode, Int nActive, Int maxActive, Int nPlane); 16 | 17 | void voxelize_bp_feat(/* cuda float M*C */ at::Tensor d_output_feats, /* cuda float N*C */ at::Tensor d_feats, /* cuda Int M*(maxActive+1) */ at::Tensor output_map, 18 | Int mode, Int nActive, Int maxActive, Int nPlane); 19 | 20 | void point_recover_fp_feat(/* cuda float M*C */ at::Tensor feats, /* cuda float N*C */ at::Tensor output_feats, /* cuda Int M*(maxActive+1) */ 
at::Tensor idx_map, 21 | Int nActive, Int maxActive, Int nPlane); 22 | 23 | void point_recover_bp_feat(/* cuda float N*C */ at::Tensor d_output_feats, /* cuda float M*C */ at::Tensor d_feats, /* cuda Int M*(maxActive+1) */ at::Tensor idx_map, 24 | Int nActive, Int maxActive, Int nPlane); 25 | 26 | 27 | #endif // POINTGROUP_H -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/pointgroup_ops_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "pointgroup_ops.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ 7 | m.def("voxelize_idx", &voxelize_idx_3d, "voxelize_idx"); 8 | m.def("voxelize_fp", &voxelize_fp_feat, "voxelize_fp"); 9 | m.def("voxelize_bp", &voxelize_bp_feat, "voxelize_bp"); 10 | m.def("point_recover_fp", &point_recover_fp_feat, "point_recover_fp"); 11 | m.def("point_recover_bp", &point_recover_bp_feat, "point_recover_bp"); 12 | 13 | m.def("ballquery_batch_p", &ballquery_batch_p, "ballquery_batch_p"); 14 | m.def("bfs_cluster", &bfs_cluster, "bfs_cluster"); 15 | 16 | m.def("roipool_fp", &roipool_fp, "roipool_fp"); 17 | m.def("roipool_bp", &roipool_bp, "roipool_bp"); 18 | 19 | m.def("get_iou", &get_iou, "get_iou"); 20 | 21 | m.def("sec_mean", &sec_mean, "sec_mean"); 22 | m.def("sec_min", &sec_min, "sec_min"); 23 | m.def("sec_max", &sec_max, "sec_max"); 24 | } -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/roipool/roipool.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | ROI Max Pool 3 | Written by Li Jiang 4 | All Rights Reserved 2020. 
5 | */ 6 | 7 | #include "roipool.h" 8 | 9 | void roipool_fp(at::Tensor feats_tensor, at::Tensor proposals_offset_tensor, at::Tensor output_feats_tensor, at::Tensor output_maxidx_tensor, int nProposal, int C){ 10 | float *feats = feats_tensor.data(); 11 | int *proposals_offset = proposals_offset_tensor.data(); 12 | float *output_feats = output_feats_tensor.data(); 13 | int *output_maxidx = output_maxidx_tensor.data(); 14 | 15 | roipool_fp_cuda(nProposal, C, feats, proposals_offset, output_feats, output_maxidx); 16 | } 17 | 18 | 19 | void roipool_bp(at::Tensor d_feats_tensor, at::Tensor proposals_offset_tensor, at::Tensor output_maxidx_tensor, at::Tensor d_output_feats_tensor, int nProposal, int C){ 20 | float *d_feats = d_feats_tensor.data(); 21 | int *proposals_offset = proposals_offset_tensor.data(); 22 | int *output_maxidx = output_maxidx_tensor.data(); 23 | float *d_output_feats = d_output_feats_tensor.data(); 24 | 25 | roipool_bp_cuda(nProposal, C, d_feats, proposals_offset, output_maxidx, d_output_feats); 26 | } -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/roipool/roipool.cu: -------------------------------------------------------------------------------- 1 | /* 2 | ROI Max Pool 3 | Written by Li Jiang 4 | All Rights Reserved 2020. 
5 | */ 6 | 7 | #include 8 | #include 9 | #include "roipool.h" 10 | 11 | // fp 12 | __global__ void roipool_fp_cuda_(int nProposal, int C, float *feats, int *proposals_offset, float *output_feats, int *output_maxidx){ 13 | for(int pp_id = blockIdx.x; pp_id < nProposal; pp_id += gridDim.x){ 14 | int start = proposals_offset[pp_id]; 15 | int end = proposals_offset[pp_id + 1]; 16 | 17 | for(int plane = threadIdx.x; plane < C; plane += blockDim.x){ 18 | int argmax_idx = -1; 19 | float max_val = -1e50; 20 | 21 | for(int i = start; i < end; i++){ 22 | if(feats[i * C + plane] > max_val){ 23 | argmax_idx = i; 24 | max_val = feats[i * C + plane]; 25 | } 26 | } 27 | output_maxidx[pp_id * C + plane] = argmax_idx; 28 | output_feats[pp_id * C + plane] = max_val; 29 | } 30 | } 31 | } 32 | 33 | //input: feats (sumNPoint, C) float 34 | //input: proposals_offset (nProposal + 1) int 35 | //output: output_feats (nProposal, C) float 36 | //output: output_maxidx (nProposal, C) int 37 | void roipool_fp_cuda(int nProposal, int C, float *feats, int *proposals_offset, float *output_feats, int *output_maxidx){ 38 | roipool_fp_cuda_<<>>(nProposal, C, feats, proposals_offset, output_feats, output_maxidx); 39 | } 40 | 41 | // bp 42 | __global__ void roipool_bp_cuda_(int nProposal, int C, float *d_feats, int *proposals_offset, int *output_maxidx, float *d_output_feats){ 43 | for(int pp_id = blockIdx.x; pp_id < nProposal; pp_id += gridDim.x){ 44 | for(int plane = threadIdx.x; plane < C; plane += blockDim.x){ 45 | int argmax_idx = output_maxidx[pp_id * C + plane]; 46 | atomicAdd(&d_feats[argmax_idx * C + plane], d_output_feats[pp_id * C + plane]); 47 | } 48 | } 49 | } 50 | 51 | //input: d_output_feats (nProposal, C) float 52 | //input: output_maxidx (nProposal, C) int 53 | //input: proposals_offset (nProposal + 1) int 54 | //output: d_feats (sumNPoint, C) float 55 | void roipool_bp_cuda(int nProposal, int C, float *d_feats, int *proposals_offset, int *output_maxidx, float *d_output_feats){ 56 | 
roipool_bp_cuda_<<>>(nProposal, C, d_feats, proposals_offset, output_maxidx, d_output_feats); 57 | } -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/roipool/roipool.h: -------------------------------------------------------------------------------- 1 | /* 2 | ROI Max Pool 3 | Written by Li Jiang 4 | All Rights Reserved 2020. 5 | */ 6 | 7 | #ifndef ROIPOOL_H 8 | #define ROIPOOL_H 9 | #include 10 | #include 11 | 12 | #include "../datatype/datatype.h" 13 | 14 | // 15 | void roipool_fp(at::Tensor feats_tensor, at::Tensor proposals_offset_tensor, at::Tensor output_feats_tensor, at::Tensor output_maxidx_tensor, int nProposal, int C); 16 | 17 | void roipool_fp_cuda(int nProposal, int C, float *feats, int *proposals_offset, float *output_feats, int *output_maxidx); 18 | 19 | 20 | // 21 | void roipool_bp(at::Tensor d_feats_tensor, at::Tensor proposals_offset_tensor, at::Tensor output_maxidx_tensor, at::Tensor d_output_feats_tensor, int nProposal, int C); 22 | 23 | void roipool_bp_cuda(int nProposal, int C, float *d_feats, int *proposals_offset, int *output_maxidx, float *d_output_feats); 24 | 25 | #endif //ROIPOOL_H 26 | -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/sec_mean/sec_mean.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Segment Operations (mean, max, min) 3 | Written by Li Jiang 4 | All Rights Reserved 2020. 
5 | */ 6 | 7 | #include "sec_mean.h" 8 | 9 | void sec_mean(at::Tensor inp_tensor, at::Tensor offsets_tensor, at::Tensor out_tensor, int nProposal, int C){ 10 | int *offsets = offsets_tensor.data(); 11 | float *inp = inp_tensor.data(); 12 | float *out = out_tensor.data(); 13 | 14 | sec_mean_cuda(nProposal, C, inp, offsets, out); 15 | } 16 | 17 | void sec_min(at::Tensor inp_tensor, at::Tensor offsets_tensor, at::Tensor out_tensor, int nProposal, int C){ 18 | int *offsets = offsets_tensor.data(); 19 | float *inp = inp_tensor.data(); 20 | float *out = out_tensor.data(); 21 | 22 | sec_min_cuda(nProposal, C, inp, offsets, out); 23 | } 24 | 25 | void sec_max(at::Tensor inp_tensor, at::Tensor offsets_tensor, at::Tensor out_tensor, int nProposal, int C){ 26 | int *offsets = offsets_tensor.data(); 27 | float *inp = inp_tensor.data(); 28 | float *out = out_tensor.data(); 29 | 30 | sec_max_cuda(nProposal, C, inp, offsets, out); 31 | } -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/sec_mean/sec_mean.cu: -------------------------------------------------------------------------------- 1 | /* 2 | Segment Operations (mean, max, min) (no bp) 3 | Written by Li Jiang 4 | All Rights Reserved 2020. 
5 | */ 6 | 7 | #include 8 | #include 9 | #include "sec_mean.h" 10 | 11 | /* ================================== sec_mean ================================== */ 12 | __global__ void sec_mean_cuda_(int nProposal, int C, float *inp, int *offsets, float *out){ 13 | for(int p_id = blockIdx.x; p_id < nProposal; p_id += gridDim.x){ 14 | int start = offsets[p_id]; 15 | int end = offsets[p_id + 1]; 16 | 17 | float count = (float)(end - start); 18 | 19 | for(int plane = threadIdx.x; plane < C; plane += blockDim.x){ 20 | float mean = 0; 21 | for(int i = start; i < end; i++){ 22 | mean += (inp[i * C + plane] / count); 23 | } 24 | out[p_id * C + plane] = mean; 25 | } 26 | } 27 | } 28 | 29 | //input: inp (N, C) float 30 | //input: offsets (nProposal + 1) int 31 | //output: out (nProposal, C) float 32 | void sec_mean_cuda(int nProposal, int C, float *inp, int *offsets, float *out){ 33 | sec_mean_cuda_<<>>(nProposal, C, inp, offsets, out); 34 | } 35 | 36 | 37 | /* ================================== sec_min ================================== */ 38 | __global__ void sec_min_cuda_(int nProposal, int C, float *inp, int *offsets, float *out){ 39 | for(int p_id = blockIdx.x; p_id < nProposal; p_id += gridDim.x){ 40 | int start = offsets[p_id]; 41 | int end = offsets[p_id + 1]; 42 | 43 | for(int plane = threadIdx.x; plane < C; plane += blockDim.x){ 44 | float min_val = 1e50; 45 | for(int i = start; i < end; i++){ 46 | if(inp[i * C + plane] < min_val){ 47 | min_val = inp[i * C + plane]; 48 | } 49 | } 50 | out[p_id * C + plane] = min_val; 51 | } 52 | } 53 | } 54 | 55 | //input: inp (N, C) float 56 | //input: offsets (nProposal + 1) int 57 | //output: out (nProposal, C) float 58 | void sec_min_cuda(int nProposal, int C, float *inp, int *offsets, float *out){ 59 | sec_min_cuda_<<>>(nProposal, C, inp, offsets, out); 60 | } 61 | 62 | 63 | /* ================================== sec_max ================================== */ 64 | __global__ void sec_max_cuda_(int nProposal, int C, float *inp, int 
*offsets, float *out){ 65 | for(int p_id = blockIdx.x; p_id < nProposal; p_id += gridDim.x){ 66 | int start = offsets[p_id]; 67 | int end = offsets[p_id + 1]; 68 | 69 | for(int plane = threadIdx.x; plane < C; plane += blockDim.x){ 70 | float max_val = -1e50; 71 | for(int i = start; i < end; i++){ 72 | if(inp[i * C + plane] > max_val){ 73 | max_val = inp[i * C + plane]; 74 | } 75 | } 76 | out[p_id * C + plane] = max_val; 77 | } 78 | } 79 | } 80 | 81 | //input: inp (N, C) float 82 | //input: offsets (nProposal + 1) int 83 | //output: out (nProposal, C) float 84 | void sec_max_cuda(int nProposal, int C, float *inp, int *offsets, float *out){ 85 | sec_max_cuda_<<>>(nProposal, C, inp, offsets, out); 86 | } -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/sec_mean/sec_mean.h: -------------------------------------------------------------------------------- 1 | /* 2 | Segment Operations (mean, max, min) 3 | Written by Li Jiang 4 | All Rights Reserved 2020. 
5 | */ 6 | 7 | #ifndef SEC_MEAN_H 8 | #define SEC_MEAN_H 9 | #include 10 | #include 11 | 12 | #include "../datatype/datatype.h" 13 | 14 | void sec_mean(at::Tensor inp_tensor, at::Tensor offsets_tensor, at::Tensor out_tensor, int nProposal, int C); 15 | void sec_mean_cuda(int nProposal, int C, float *inp, int *offsets, float *out); 16 | 17 | void sec_min(at::Tensor inp_tensor, at::Tensor offsets_tensor, at::Tensor out_tensor, int nProposal, int C); 18 | void sec_min_cuda(int nProposal, int C, float *inp, int *offsets, float *out); 19 | 20 | void sec_max(at::Tensor inp_tensor, at::Tensor offsets_tensor, at::Tensor out_tensor, int nProposal, int C); 21 | void sec_max_cuda(int nProposal, int C, float *inp, int *offsets, float *out); 22 | 23 | 24 | #endif //SEC_MEAN_H 25 | -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/voxelize/voxelize.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Points to Voxels & Voxels to Points (Modified from SparseConv) 3 | Written by Li Jiang 4 | All Rights Reserved 2020. 
5 | */ 6 | 7 | #include "voxelize.h" 8 | 9 | /* ================================== voxelize_idx ================================== */ 10 | template 11 | void voxelize_idx(/* long N*4 */ at::Tensor coords, /* long M*4 */ at::Tensor output_coords, 12 | /* Int N */ at::Tensor input_map, /* Int M*(maxActive+1) */ at::Tensor output_map, Int batchSize, Int mode){ 13 | assert(coords.ndimension() == 2); 14 | assert(coords.size(1) >= dimension and coords.size(1) <= dimension + 1); 15 | 16 | RuleBook voxelizeRuleBook; // rule[1]: M voxels -> N points output_map 17 | SparseGrids inputSGs; // voxel_coords -> voxel_idx in M voxels input_map: N points -> M voxels 18 | Int nActive = 0; 19 | 20 | Int maxActive = voxelize_inputmap(inputSGs, input_map.data(), voxelizeRuleBook, nActive, coords.data(), coords.size(0), coords.size(1), batchSize, mode); 21 | 22 | output_map.resize_({nActive, maxActive + 1}); 23 | output_map.zero_(); 24 | 25 | output_coords.resize_({nActive, coords.size(1)}); 26 | output_coords.zero_(); 27 | 28 | Int *oM = output_map.data(); 29 | long *oC = output_coords.data(); 30 | voxelize_outputmap(coords.data(), oC, oM, &voxelizeRuleBook[1][0], nActive, maxActive); 31 | } 32 | 33 | 34 | template 35 | void voxelize_outputmap(long *coords, long *output_coords, Int *output_map, Int *rule, Int nOutputRows, Int maxActive){ 36 | for(Int i = 0; i < nOutputRows; i++){ 37 | for(Int j = 0; j <= maxActive; j++) 38 | output_map[j] = rule[j]; 39 | Int inputIdx = rule[1]; 40 | rule += (1 + maxActive); 41 | output_map += (1 + maxActive); 42 | 43 | long *coord = coords + inputIdx * (dimension + 1); 44 | long *output_coord = output_coords + i * (dimension + 1); 45 | for(Int j = 0; j <= dimension; j++){ 46 | output_coord[j] = coord[j]; 47 | } 48 | } 49 | } 50 | 51 | //mode 0=guaranteed unique 1=last item(overwrite) 2=first item(keep) 3=sum, 4=mean 52 | //input: coords 53 | //output: SGs: one map for each batch: map from voxel_coord to voxel_idx(in M voxels) 54 | //output: input_map: 
N, N points -> M voxels 55 | //output: rules 56 | //output: nActive 57 | //output: maxActive 58 | template 59 | Int voxelize_inputmap(SparseGrids &SGs, Int *input_map, RuleBook &rules, Int &nActive, long *coords, Int nInputRows, Int nInputColumns, Int batchSize, Int mode){ 60 | assert(nActive == 0); 61 | assert(rules.size() == 0); 62 | assert(SGs.size() == 0); 63 | 64 | SGs.resize(batchSize); 65 | Point p; 66 | 67 | std::vector> outputRows; 68 | if(nInputColumns == dimension){ 69 | SGs.resize(1); 70 | auto &sg = SGs[0]; 71 | for(Int i = 0; i < nInputRows; i++){ 72 | for(Int j = 0; j < dimension; j++) 73 | p[j] = coords[j]; 74 | coords += dimension; 75 | auto iter = sg.mp.find(p); 76 | if (iter == sg.mp.end()){ 77 | sg.mp[p] = nActive++; 78 | outputRows.resize(nActive); 79 | } 80 | outputRows[sg.mp[p]].push_back(i); 81 | 82 | input_map[i] = sg.mp[p]; 83 | } 84 | } 85 | else{ // nInputColumns == dimension + 1 (1 in index 0 for batchidx) 86 | Int batchIdx; 87 | for(Int i = 0; i < nInputRows; i++){ 88 | batchIdx = coords[0]; 89 | for(Int j = 0; j < dimension; j++) 90 | p[j] = coords[j + 1]; 91 | coords += (dimension + 1); 92 | if(batchIdx + 1 >= (Int)SGs.size()){ 93 | SGs.resize(batchIdx + 1); 94 | } 95 | auto &sg = SGs[batchIdx]; 96 | auto iter = sg.mp.find(p); 97 | if(iter == sg.mp.end()){ 98 | sg.mp[p] = nActive++; 99 | outputRows.resize(nActive); 100 | } 101 | outputRows[sg.mp[p]].push_back(i); 102 | 103 | input_map[i] = sg.mp[p]; 104 | } 105 | } 106 | 107 | // Rulebook Format 108 | // rules[0][0] == mode 109 | // rules[0][1] == maxActive per spatial location (==1 for modes 0,1,2) 110 | // rules[0][2] == nInputRows 111 | // rules[0][3] == nOutputRows 112 | // rules[1] nOutputRows x (1+maxActive) 113 | rules.resize(2); 114 | rules[0].push_back(mode); 115 | rules[0].push_back(1); 116 | rules[0].push_back(nInputRows); 117 | rules[0].push_back(outputRows.size()); 118 | auto &rule = rules[1]; 119 | if(mode == 0){ 120 | assert(nInputRows == (Int)outputRows.size()); 121 | 
for(Int i = 0; i < nActive; i++){ 122 | rule.push_back(1); 123 | assert((Int)outputRows[i].size() == 1); 124 | rule.push_back(outputRows[i][0]); 125 | } 126 | } 127 | if(mode == 1){ 128 | for(Int i = 0; i < nActive; i++){ 129 | rule.push_back(1); 130 | rule.push_back(outputRows[i].front()); 131 | } 132 | } 133 | if(mode == 2){ 134 | for(Int i = 0; i < nActive; i++){ 135 | rule.push_back(1); 136 | rule.push_back(outputRows[i].back()); 137 | } 138 | } 139 | Int maxActive = 1; 140 | if(mode == 3 or mode == 4){ 141 | for(auto &row: outputRows) 142 | maxActive = std::max(maxActive, (Int)row.size()); 143 | rules[0][1] = maxActive; 144 | for(auto &row: outputRows){ 145 | rule.push_back(row.size()); 146 | for(auto &r: row) 147 | rule.push_back(r); 148 | rule.resize((rule.size() + maxActive) / (maxActive + 1) * (maxActive + 1)); 149 | } 150 | } 151 | return maxActive; 152 | } 153 | 154 | 155 | /* ================================== voxelize ================================== */ 156 | template 157 | void voxelize_fp(/* cuda float N*C */ at::Tensor feats, // N * 3 -> M * 3 (N >= M) 158 | /* cuda float M*C */ at::Tensor output_feats, 159 | /* cuda Int M*(maxActive+1) */ at::Tensor output_map, Int mode, Int nActive, Int maxActive, Int nPlane){ 160 | 161 | auto iF = feats.data(); 162 | auto oF = output_feats.data(); 163 | 164 | Int *rules = output_map.data(); 165 | 166 | voxelize_fp_cuda(nActive, maxActive, nPlane, iF, oF, rules, mode==4); 167 | } 168 | 169 | template 170 | void voxelize_bp(/* cuda float M*C */ at::Tensor d_output_feats, /* cuda float N*C */ at::Tensor d_feats, /* cuda Int M*(maxActive+1) */ at::Tensor output_map, 171 | Int mode, Int nActive, Int maxActive, Int nPlane){ 172 | auto d_oF = d_output_feats.data(); 173 | auto d_iF = d_feats.data(); 174 | 175 | Int *rules = output_map.data(); 176 | 177 | voxelize_bp_cuda(nActive, maxActive, nPlane, d_oF, d_iF, rules, mode==4); 178 | } 179 | 180 | /* ================================== point_recover 
================================== */ 181 | template 182 | void point_recover_fp(/* cuda float M*C */ at::Tensor feats, /* cuda float N*C */ at::Tensor output_feats, /* cuda Int M*(maxActive+1) */ at::Tensor idx_map, 183 | Int nActive, Int maxActive, Int nPlane){ 184 | auto iF = feats.data(); 185 | auto oF = output_feats.data(); 186 | 187 | Int *rules = idx_map.data(); 188 | 189 | voxelize_bp_cuda(nActive, maxActive, nPlane, iF, oF, rules, false); 190 | } 191 | 192 | 193 | template 194 | void point_recover_bp(/* cuda float N*C */ at::Tensor d_output_feats, /* cuda float M*C */ at::Tensor d_feats, /* cuda Int M*(maxActive+1) */ at::Tensor idx_map, 195 | Int nActive, Int maxActive, Int nPlane){ 196 | auto d_oF = d_output_feats.data(); 197 | auto d_iF = d_feats.data(); 198 | 199 | Int *rules = idx_map.data(); 200 | 201 | voxelize_fp_cuda(nActive, maxActive, nPlane, d_oF, d_iF, rules, false); 202 | } -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/voxelize/voxelize.cu: -------------------------------------------------------------------------------- 1 | /* 2 | Points to Voxels & Voxels to Points (Modified from SparseConv) 3 | Written by Li Jiang 4 | All Rights Reserved 2020. 5 | */ 6 | 7 | #include "voxelize.h" 8 | 9 | template 10 | __global__ void voxelize_fp_cuda_(Int nOutputRows, Int maxActive, Int nPlanes, T *feats, T *output_feats, Int *rules, bool average){ 11 | for(int row = blockIdx.x; row < nOutputRows; row += gridDim.x){ 12 | T *out = output_feats + row * nPlanes; 13 | Int *r = rules + row * (maxActive + 1); 14 | Int nActive = r[0]; 15 | T multiplier = (average and nActive > 0) ? 
(T) 1 / nActive : (T) 1; 16 | for(int i = 1; i <= nActive; i++){ 17 | T *inp = feats + r[i] * nPlanes; 18 | for(int plane = threadIdx.x; plane < nPlanes; plane += blockDim.x){ 19 | atomicAdd(&out[plane], multiplier * inp[plane]); 20 | } 21 | } 22 | } 23 | } 24 | 25 | // input: feats N * C 26 | // input: rules M * (1 + maxActive) 27 | // output: output_feats M * C 28 | template 29 | void voxelize_fp_cuda(Int nOutputRows, Int maxActive, Int nPlanes, T *feats, T *output_feats, Int *rules, bool average){ 30 | voxelize_fp_cuda_<<>>(nOutputRows, maxActive, nPlanes, feats, output_feats, rules, average); 31 | } 32 | 33 | 34 | template 35 | __global__ void voxelize_bp_cuda_(Int nOutputRows, Int maxActive, Int nPlanes, T *d_output_feats, T *d_feats, Int *rules, bool average){ 36 | for(int row = blockIdx.x; row < nOutputRows; row += gridDim.x){ 37 | T *out = d_output_feats + row * nPlanes; 38 | Int *r = rules + row * (maxActive + 1); 39 | Int nActive = r[0]; 40 | T multiplier = (average and nActive > 0) ? (T) 1 / nActive : (T) 1; 41 | for(int i = 1; i <= nActive; i++){ 42 | T *inp = d_feats + r[i] * nPlanes; 43 | for(int plane = threadIdx.x; plane < nPlanes; plane += blockDim.x){ 44 | atomicAdd(&inp[plane], multiplier * out[plane]); 45 | } 46 | } 47 | } 48 | } 49 | 50 | template 51 | void voxelize_bp_cuda(Int nOutputRows, Int maxActive, Int nPlanes, T *d_output_feats, T *d_feats, Int *rules, bool average){ 52 | voxelize_bp_cuda_<<>>(nOutputRows, maxActive, nPlanes, d_output_feats, d_feats, rules, average); 53 | } 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/pointgroup_ops/src/voxelize/voxelize.h: -------------------------------------------------------------------------------- 1 | /* 2 | Points to Voxels & Voxels to Points (Modified from SparseConv) 3 | Written by Li Jiang 4 | All Rights Reserved 2020. 
5 | */ 6 | 7 | #ifndef VOXELIZE_H 8 | #define VOXELIZE_H 9 | #include 10 | #include 11 | 12 | #include "../datatype/datatype.h" 13 | 14 | /* ================================== voxelize_idx ================================== */ 15 | template 16 | void voxelize_idx(/* long N*4 */ at::Tensor coords, /* long M*4 */ at::Tensor output_coords, 17 | /* Int N */ at::Tensor input_map, /* Int M*(maxActive+1) */ at::Tensor output_map, Int batchSize, Int mode); 18 | 19 | template 20 | void voxelize_outputmap(long *coords, long *output_coords, Int *output_map, Int *rule, Int nOutputRows, Int maxActive); 21 | 22 | template 23 | Int voxelize_inputmap(SparseGrids &SGs, Int *input_map, RuleBook &rules, Int &nActive, long *coords, Int nInputRows, Int nInputColumns, Int batchSize, Int mode); 24 | 25 | /* ================================== voxelize ================================== */ 26 | template 27 | void voxelize_fp(/* cuda float N*C */ at::Tensor feats, // N * 3 -> M * 3 (N >= M) 28 | /* cuda float M*C */ at::Tensor output_feats, 29 | /* cuda Int M*(maxActive+1) */ at::Tensor output_map, Int mode, Int nActive, Int maxActive, Int nPlane); 30 | 31 | template 32 | void voxelize_fp_cuda(Int nOutputRows, Int maxActive, Int nPlanes, T *feats, T *output_feats, Int *rules, bool average); 33 | 34 | 35 | // 36 | template 37 | void voxelize_bp(/* cuda float M*C */ at::Tensor d_output_feats, /* cuda float N*C */ at::Tensor d_feats, /* cuda Int M*(maxActive+1) */ at::Tensor output_map, 38 | Int mode, Int nActive, Int maxActive, Int nPlane); 39 | 40 | template 41 | void voxelize_bp_cuda(Int nOutputRows, Int maxActive, Int nPlanes, T *d_output_feats, T *d_feats, Int *rules, bool average); 42 | 43 | 44 | /* ================================== point_recover ================================== */ 45 | template 46 | void point_recover_fp(/* cuda float M*C */ at::Tensor feats, /* cuda float N*C */ at::Tensor output_feats, /* cuda Int M*(maxActive+1) */ at::Tensor idx_map, 47 | Int nActive, Int 
maxActive, Int nPlane); 48 | 49 | // 50 | template 51 | void point_recover_bp(/* cuda float N*C */ at::Tensor d_output_feats, /* cuda float M*C */ at::Tensor d_feats, /* cuda Int M*(maxActive+1) */ at::Tensor idx_map, 52 | Int nActive, Int maxActive, Int nPlane); 53 | 54 | 55 | #endif //VOXELIZE_H 56 | -------------------------------------------------------------------------------- /sstnet/lib/pointgroup_ops/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Gorilla-Lab. All rights reserved. 2 | import os.path as osp 3 | from glob import glob 4 | 5 | import torch 6 | from setuptools import setup, find_packages 7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 8 | 9 | 10 | def get_sources(module, surfix="*.c*"): 11 | src_dir = osp.join(*module.split("."), "src") 12 | cuda_dir = osp.join(src_dir, "cuda") 13 | cpu_dir = osp.join(src_dir, "cpu") 14 | return glob(osp.join(src_dir, surfix)) + \ 15 | glob(osp.join(cuda_dir, surfix)) + \ 16 | glob(osp.join(cpu_dir, surfix)) 17 | 18 | 19 | def get_include_dir(module): 20 | include_dir = osp.join(*module.split("."), "include") 21 | if osp.exists(include_dir): 22 | return [osp.abspath(include_dir)] 23 | else: 24 | return [] 25 | 26 | def make_extension(name, module): 27 | if not torch.cuda.is_available(): return 28 | extersion = CUDAExtension 29 | return extersion(name=".".join([module, name]), 30 | sources=get_sources(module), 31 | include_dirs=get_include_dir(module), 32 | extra_compile_args={ 33 | "cxx": ["-g"], 34 | "nvcc": [ 35 | "-D__CUDA_NO_HALF_OPERATORS__", 36 | "-D__CUDA_NO_HALF_CONVERSIONS__", 37 | "-D__CUDA_NO_HALF2_OPERATORS__", 38 | ], 39 | }, 40 | define_macros=[("WITH_CUDA", None)]) 41 | 42 | setup( 43 | name="pointgroup_ops", 44 | ext_modules=[make_extension(name="pointgroup_ops_ext", 45 | module="pointgroup_ops")], 46 | packages=find_packages(), 47 | cmdclass={"build_ext": BuildExtension} 48 | ) 
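The rule layout consumed by the `voxelize_fp_cuda_` kernel in voxelize.cu above (each row holds a point count followed by that many point row indices) can be checked against a small CPU reference. The following is a minimal pure-Python sketch, not part of the repository: the name `voxelize_fp_reference` and the toy inputs are assumptions made for illustration, and it only mirrors the sum/mean pooling that the kernel's `average` flag selects.

```python
def voxelize_fp_reference(feats, rules, average=False):
    """CPU reference for pooling point features into voxel features.

    feats: list of N rows, each a list of C floats (per-point features).
    rules: list of M rows of length 1 + maxActive; rule[0] is the number of
           points in the voxel and rule[1:1 + rule[0]] are their row indices.
    average: if True, take the mean of the points (mode 4); otherwise the sum.
    """
    num_planes = len(feats[0]) if feats else 0
    output = []
    for rule in rules:
        n_active = rule[0]
        point_rows = rule[1:1 + n_active]
        # Same multiplier as the kernel: 1 / nActive when averaging, else 1.
        scale = 1.0 / n_active if (average and n_active > 0) else 1.0
        pooled = [0.0] * num_planes
        for row in point_rows:
            for plane in range(num_planes):
                pooled[plane] += scale * feats[row][plane]
        output.append(pooled)
    return output


# Toy input (hypothetical): three points, two voxels, maxActive == 2.
toy_feats = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
toy_rules = [[2, 0, 2],   # voxel 0 pools points 0 and 2
             [1, 1, 0]]   # voxel 1 pools point 1; the trailing 0 is padding
voxel_sum = voxelize_fp_reference(toy_feats, toy_rules, average=False)
voxel_mean = voxelize_fp_reference(toy_feats, toy_rules, average=True)
```

Padding entries beyond `rule[0]` are never read, which is why the zero-filled `output_map` rows produced by `voxelize_idx` are safe to pass through unchanged.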
-------------------------------------------------------------------------------- /sstnet/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Gorilla-Lab. All rights reserved. 2 | from .sstnet import SSTNet 3 | from .losses import SSTLoss 4 | from .func_helper import * 5 | 6 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 7 | -------------------------------------------------------------------------------- /sstnet/model/func_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Gorilla-Lab. All rights reserved. 2 | from typing import Optional, List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | from scipy.sparse import coo_matrix 8 | from torch_scatter import scatter_add 9 | from treelib import Tree 10 | 11 | import htree 12 | from cluster.hierarchy import linkage 13 | 14 | 15 | class Node: 16 | def __init__(self, 17 | feature: torch.Tensor, 18 | center: torch.Tensor, 19 | soft_label: Optional[torch.Tensor]=None, 20 | num: Optional[int]=1) -> None: 21 | super().__init__() 22 | self.feature = feature 23 | self.center = center 24 | self.soft_label = soft_label 25 | self.num = num 26 | 27 | 28 | def build_hierarchical_tree(affinity: torch.Tensor, 29 | features: torch.Tensor, 30 | centers: torch.Tensor, 31 | affinity_count: torch.Tensor, 32 | batch_idxs: torch.Tensor, 33 | soft_label: Optional[torch.Tensor]=None): 34 | r""" 35 | build the hierarchical tree 36 | 37 | Args: 38 | affinity (torch.Tensor, [num_leaves, C]): affinity of nodes 39 | features (torch.Tensor, [num_leaves, C']): features of nodes 40 | centers (torch.Tensor, [num_leaves, 3]): centers of nodes 41 | affinity_count (torch.Tensor, [num_leaves]): point count of nodes 42 | batch_idxs (torch.Tensor, [num_leaves]): batch idxs of nodes 43 | soft_label (Optional[torch.Tensor], [num_leaves, num_label + 1]): soft label of nodes. 
Default to None 44 | 45 | Returns: 46 | hierarchical_tree_list, tree_list, scores_features_list, labels_list, nodes_list (one entry per batch) 47 | """ 48 | tree_list = [] 49 | hierarchical_tree_list = [] 50 | scores_features_list = [] 51 | labels_list = [] 52 | nodes_list = [] 53 | # build hierarchical tree for each batch 54 | for batch_idx in torch.unique(batch_idxs): 55 | ids = (batch_idxs == batch_idx) 56 | num_batch = ids.sum() 57 | batch_centers = centers[ids] # [num_batch, 3] 58 | batch_affinity = affinity[ids] # [num_batch, C] 59 | batch_features = features[ids] # [num_batch, C'] 60 | batch_soft_label = soft_label[ids] 61 | 62 | # build tree by affinity 63 | batch_affinity_count = affinity_count[ids] # [num_batch] 64 | affinity_np = batch_affinity.detach().cpu().numpy() 65 | affinity_np = np.concatenate([affinity_np, batch_affinity_count[:, None].cpu().numpy()], axis=1) # [num_leaves, C+1] 66 | tree_connection = linkage(affinity_np, method="average", with_observation=True) 67 | tree_connection = tree_connection[:, :2].astype(int) # np.int was removed in NumPy 1.24; builtin int is the same dtype 68 | 69 | # add leaf nodes 70 | node_list = [Node(batch_features[i], 71 | batch_centers[i], 72 | batch_soft_label[i], 73 | batch_affinity_count[i]) for i in range(num_batch)] 74 | 75 | num_nodes = tree_connection.max() + 1 76 | 77 | connection = tree_connection.tolist() 78 | hierarchical_tree = htree.Tree(connection) 79 | hierarchical_tree_list.append(hierarchical_tree) 80 | 81 | # get the fusion record and move it to the device of affinity 82 | left_leaves, left_ids, right_leaves, right_ids, fusion_leaves, fusion_ids = hierarchical_tree.fusion_record() 83 | left_leaves = torch.Tensor(left_leaves).long().to(affinity.device) 84 | left_ids = torch.Tensor(left_ids).long().to(affinity.device) 85 | right_leaves = torch.Tensor(right_leaves).long().to(affinity.device) 86 | right_ids = torch.Tensor(right_ids).long().to(affinity.device) 87 | fusion_leaves = torch.Tensor(fusion_leaves).long().to(affinity.device) 88 | fusion_ids = torch.Tensor(fusion_ids).long().to(affinity.device) 89 | 90 | # record the fusion nodes 91 | 
fusion_affinity_count = scatter_add(batch_affinity_count[fusion_leaves], fusion_ids, dim=0) # [num_fusion] 92 | node_features = get_fusion_property(batch_features, batch_affinity_count, fusion_leaves, fusion_ids, fusion_affinity_count) 93 | node_centers = get_fusion_property(batch_centers, batch_affinity_count, fusion_leaves, fusion_ids, fusion_affinity_count) 94 | node_soft_label = get_fusion_property(batch_soft_label, batch_affinity_count, fusion_leaves, fusion_ids, fusion_affinity_count) 95 | 96 | # add intermediate tree nodes 97 | for i in range(len(connection)): 98 | node_list.append( 99 | Node( 100 | node_features[i], 101 | node_centers[i], 102 | node_soft_label[i], 103 | fusion_affinity_count[i])) 104 | 105 | # get the properties of the left and right children of each node 106 | left_affinity_count = scatter_add(batch_affinity_count[left_leaves], left_ids, dim=0) # [num_fusion] 107 | left_features = get_fusion_property(batch_features, batch_affinity_count, left_leaves, left_ids, left_affinity_count) 108 | left_soft_label = get_fusion_property(batch_soft_label, batch_affinity_count, left_leaves, left_ids, left_affinity_count) 109 | right_affinity_count = scatter_add(batch_affinity_count[right_leaves], right_ids, dim=0) # [num_fusion] 110 | right_features = get_fusion_property(batch_features, batch_affinity_count, right_leaves, right_ids, right_affinity_count) 111 | right_soft_label = get_fusion_property(batch_soft_label, batch_affinity_count, right_leaves, right_ids, right_affinity_count) 112 | 113 | features_list = torch.cat([torch.cat([left_features, right_features], dim=1)[:, None, :], 114 | torch.cat([right_features, left_features], dim=1)[:, None, :]], 115 | dim=1) # [num_nodes, 2, C * 2] 116 | scores_features = features_list.view(-1, features_list.shape[-1]) # [num_nodes * 2, C * 2] 117 | fusion_scores = (left_soft_label * right_soft_label).sum(dim=1) # [num_nodes] 118 | labels = fusion_scores[:, None].repeat(1, 2).view(-1) # [num_nodes * 2] 119 | 
node_id_list = list(range(num_batch, num_batch + len(tree_connection))) # [num_nodes] 120 | 121 | # reverse the order so that the nodes are traversed top-down 122 | num_all_nodes = len(node_id_list) 123 | scores_features = scores_features[range(-1, -(num_all_nodes*2+1), -1)] # [num_nodes * 2, C * 2] 124 | labels = labels[range(-1, -(num_all_nodes*2+1), -1)] # [num_nodes * 2] 125 | nodes = torch.Tensor(node_id_list).to(scores_features.device)[range(-1, -(num_all_nodes+1), -1)] # [num_nodes] 126 | scores_features_list.append(scores_features) 127 | labels_list.append(labels) 128 | nodes_list.append(nodes) 129 | 130 | tree = Tree() 131 | tree.create_node(num_nodes, num_nodes, data=node_list[num_nodes]) # root node 132 | for connection in tree_connection[::-1]: 133 | c0, c1 = connection 134 | tree.create_node(c0, c0, parent=num_nodes, data=node_list[c0]) 135 | tree.create_node(c1, c1, parent=num_nodes, data=node_list[c1]) 136 | num_nodes -= 1 137 | 138 | tree_list.append(tree) 139 | 140 | return hierarchical_tree_list, tree_list, scores_features_list, labels_list, nodes_list 141 | 142 | 143 | def get_fusion_property(properties: torch.Tensor, 144 | count: torch.Tensor, 145 | leaves: torch.Tensor, 146 | ids: torch.Tensor, 147 | nodes_count: torch.Tensor) -> torch.Tensor: 148 | r"""get the fused properties of nodes as the point-count-weighted mean of their leaves (HNIR) 149 | 150 | Args: 151 | properties (torch.Tensor, [N, C]): properties to be fused 152 | count (torch.Tensor, [num_leaves]): number of points of each leaf 153 | leaves (torch.Tensor, [num_leaves]): leaf ids used to index properties 154 | ids (torch.Tensor, [num_leaves]): node ids of each leaf 155 | nodes_count (torch.Tensor, [num_nodes]): number of points of each node 156 | 157 | Returns: 158 | torch.Tensor: [num_nodes, C], fused properties of each node 159 | """ 160 | num_leaves = leaves.shape[0] 161 | properties = properties[leaves].view(num_leaves, -1) # [num_leaves, C] 162 | property_gain = properties * count[leaves].view(num_leaves, 1) # [num_leaves, C] 163 | properties = scatter_add(property_gain, ids, 
dim=0) # [num_nodes, C] 164 | properties = properties / nodes_count.view(-1, 1) # [num_nodes, C] 165 | return properties 166 | 167 | 168 | def align_superpoint_label(labels: torch.Tensor, 169 | superpoint: torch.Tensor, 170 | num_label: int=20, 171 | ignore_label: int=-100): 172 | r"""align point-wise semantic labels to superpoints by majority vote 173 | 174 | Args: 175 | labels (torch.Tensor, [N]): semantic label of points 176 | superpoint (torch.Tensor, [N]): superpoint cluster id of points 177 | num_label (int): number of valid label categories 178 | ignore_label (int): the ignore label id 179 | 180 | Returns: 181 | label: (torch.Tensor, [num_superpoint]): superpoint's label 182 | label_scores: (torch.Tensor, [num_superpoint, num_label + 1]): superpoint's label scores 183 | """ 184 | row = superpoint.cpu().numpy() # superpoint ids have already been compressed 185 | col = labels.cpu().numpy() 186 | col[col < 0] = num_label 187 | data = np.ones(len(superpoint)) 188 | shape = (len(np.unique(row)), num_label + 1) 189 | label_map = coo_matrix((data, (row, col)), shape=shape).toarray() # [num_superpoint, num_label + 1] 190 | label = torch.Tensor(np.argmax(label_map, axis=1)).long().to(labels.device) # [num_superpoint] 191 | label[label == num_label] = ignore_label # ignore_label 192 | label_scores = torch.Tensor(label_map / label_map.sum(axis=1)[:, None]).to(labels.device) # [num_superpoint, num_label + 1] 193 | 194 | return label, label_scores 195 | 196 | 197 | def voting_semantic_segmentation(semantic_preds: torch.Tensor, 198 | superpoint: torch.Tensor, 199 | num_semantic: int=20): 200 | r"""get semantic segmentation by superpoint voting 201 | 202 | Args: 203 | semantic_preds (torch.Tensor, [N]): semantic label of points 204 | superpoint (torch.Tensor, [N]): superpoint cluster id of points 205 | num_semantic (int): the number of semantic labels 206 | 207 | Returns: 208 | replace_semantic: (torch.Tensor, [N]): refined semantic label of points 209 | """ 210 | _, row = 
np.unique(superpoint.cpu().numpy(), return_inverse=True) 211 | col = semantic_preds.cpu().numpy() 212 | data = np.ones(len(superpoint)) 213 | shape = (len(np.unique(row)), num_semantic) 214 | semantic_map = coo_matrix((data, (row, col)), shape=shape).toarray() # [num_superpoint, num_semantic] 215 | semantic_map = torch.Tensor(np.argmax(semantic_map, axis=1)).to(semantic_preds.device) # [num_superpoint] 216 | replace_semantic = semantic_map[torch.Tensor(row).to(semantic_preds.device).long()] 217 | 218 | return replace_semantic 219 | 220 | 221 | def traversal_cluster(tree: Tree, 222 | nodes: List[int], 223 | fusion_labels: List[bool]): 224 | r""" 225 | get the cluster result by top-down bfs traversal of the hierarchical tree 226 | 227 | Args: 228 | tree (treelib.Tree): hierarchical tree to traverse 229 | nodes (List[int], [num_nodes]): node ids, index-aligned with fusion_labels 230 | fusion_labels (List[bool], [num_nodes]): whether the node at the same index in nodes is kept as one cluster 231 | 232 | Returns: 233 | cluster_list (List[List[int]]), node_id_list (List[int]) and refine_labels (torch.Tensor); all three are None if stacking the soft labels fails 234 | """ 235 | queue = [tree.root] 236 | 237 | cluster_list = [] 238 | node_id_list = [] 239 | # refine_labels = [] 240 | nodes_ids = [] 241 | leaves_ids = [] 242 | nodes_soft_label = [] 243 | leaves_soft_labels = [] 244 | while (len(queue) > 0): 245 | # pop the next node id from the queue 246 | node_id = queue.pop(0) 247 | idx = nodes.index(node_id) 248 | if fusion_labels[idx]: 249 | leaves = [l.tag for l in tree.leaves(node_id)] 250 | cluster_list.append(leaves) 251 | node_id_list.append(node_id) 252 | nodes_ids.extend([node_id] * len(leaves)) 253 | leaves_ids.extend(leaves) 254 | nodes_soft_label.extend([tree.get_node(node_id).data.soft_label] * len(leaves)) 255 | leaves_soft_labels.extend([tree.get_node(l).data.soft_label for l in leaves]) 256 | else: 257 | child = tree.children(node_id) 258 | for c in child: 259 | nid = c.tag 260 | # only enqueue non-leaf children 261 | if len(tree.children(nid)) > 0: 262 | queue.append(nid) 263 | 264 | try: 265 | nodes_soft_label = torch.stack(nodes_soft_label) 266 
| leaves_soft_labels = torch.stack(leaves_soft_labels) 267 | refine_labels = (nodes_soft_label * leaves_soft_labels).sum(1) 268 | except Exception: 269 | cluster_list = None 270 | node_id_list = None 271 | refine_labels = None 272 | 273 | return cluster_list, node_id_list, refine_labels 274 | 275 | 276 | def build_superpoint_clique(tree: Tree, 277 | node_id_list: List[List[int]]): 278 | r"""build the superpoint clique for refinement 279 | 280 | Args: 281 | tree (Tree): input semantic superpoint tree 282 | node_id_list (List[List[int]]): node ids of each proposal 283 | """ 284 | num_leaves = len(tree.leaves(tree.root)) 285 | num_graph_nodes = num_leaves + len(node_id_list) 286 | # self connection 287 | dense_matrix = torch.eye(num_graph_nodes).float() 288 | # convert each sub-tree into graph edges 289 | for idx, node_id in enumerate(node_id_list): 290 | leaves = tree.leaves(node_id) 291 | root_id = num_leaves + idx 292 | for leaf in leaves: 293 | nid = leaf.tag 294 | dense_matrix[nid, root_id] = 1 295 | dense_matrix[root_id, nid] = 1 296 | dense_matrix = F.normalize(dense_matrix, p=1, dim=-2) 297 | 298 | # construct sparse matrix to represent graph connection 299 | indices = torch.where(dense_matrix > 0) 300 | i = torch.stack(indices) 301 | v = dense_matrix[indices] 302 | adjacency_matrix = torch.sparse.FloatTensor(i, v, torch.Size([num_graph_nodes, num_graph_nodes])) 303 | return adjacency_matrix 304 | 305 | 306 | def get_proposals_idx(superpoint: torch.Tensor, cluster_list: List[List[int]]): 307 | r""" 308 | get proposals idx(mask) from superpoint clusters 309 | 310 | Args: 311 | superpoint (torch.Tensor): superpoint ids 312 | cluster_list (List[List[int]]): List of cluster ids 313 | 314 | Returns: 315 | proposals_idx 316 | """ 317 | superpoint_np = superpoint.cpu().numpy() 318 | proposals_idx_list = [] 319 | cluster_id = 0 320 | for cluster in cluster_list: 321 | proposals_idx = np.where(np.isin(superpoint_np, cluster))[0] 322 | clusters_id = np.ones_like(proposals_idx) * cluster_id 323 | proposals_idx = 
np.stack([clusters_id, proposals_idx], axis=1) 324 | if len(proposals_idx) < 50: 325 | continue 326 | proposals_idx_list.append(proposals_idx) 327 | cluster_id += 1 328 | proposals_idx = np.concatenate(proposals_idx_list) 329 | proposals_idx = torch.from_numpy(proposals_idx) 330 | 331 | return proposals_idx 332 | 333 | -------------------------------------------------------------------------------- /sstnet/model/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Gorilla-Lab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch_scatter import scatter_mean 6 | 7 | import pointgroup_ops 8 | import gorilla 9 | from gorilla.losses import dice_loss_multi_calsses, iou_guided_loss 10 | from .func_helper import * 11 | 12 | 13 | @gorilla.LOSSES.register_module() 14 | class SSTLoss(nn.Module): 15 | def __init__(self, 16 | ignore_label: int, 17 | fusion_epochs: int=128, 18 | score_epochs: int=160, 19 | bg_thresh: float=0.25, 20 | fg_thresh: float=0.75, 21 | semantic_dice: bool=True, 22 | loss_weight: List[float]=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]): 23 | super().__init__() 24 | self.ignore_label = ignore_label 25 | self.fusion_epochs = fusion_epochs 26 | self.score_epochs = score_epochs 27 | self.bg_thresh = bg_thresh 28 | self.fg_thresh = fg_thresh 29 | self.semantic_dice = semantic_dice 30 | self.loss_weight = loss_weight 31 | 32 | #### criterion 33 | self.semantic_criterion = nn.CrossEntropyLoss(ignore_index=self.ignore_label) 34 | self.score_criterion = nn.BCELoss(reduction="none") 35 | 36 | def forward(self, loss_inp, epoch): 37 | loss_out = {} 38 | fusion_flag = (epoch > self.fusion_epochs) 39 | prepare_flag = (epoch > self.score_epochs) 40 | 41 | """semantic loss""" 42 | semantic_scores, semantic_labels = loss_inp["semantic_scores"] 43 | # semantic_scores: (N, nClass), float32, cuda 44 | # semantic_labels: (N), long, cuda 45 | 46 | semantic_loss = 
self.semantic_criterion(semantic_scores, semantic_labels) 47 | if self.semantic_dice: 48 | filter_ids = (semantic_labels != self.ignore_label) 49 | semantic_scores = semantic_scores[filter_ids] 50 | semantic_scores = F.softmax(semantic_scores, dim=-1) 51 | semantic_labels = semantic_labels[filter_ids] 52 | one_hot_labels = F.one_hot(semantic_labels, num_classes=20) 53 | semantic_loss += dice_loss_multi_calsses(semantic_scores, one_hot_labels).mean() 54 | loss_out["semantic_loss"] = (semantic_loss, semantic_scores.shape[0]) 55 | 56 | """offset loss""" 57 | pt_offsets, coords, instance_info, instance_labels, instance_pointnum = loss_inp["pt_offsets"] 58 | # pt_offsets: (N, 3), float, cuda 59 | # coords: (N, 3), float32 60 | # instance_info: (N, 9), float32 tensor (meanxyz, minxyz, maxxyz) 61 | # instance_labels: (N), long 62 | # instance_pointnum: (total_num_inst), int 63 | 64 | gt_offsets = instance_info[:, 0:3] - coords # [N, 3] 65 | pt_diff = pt_offsets - gt_offsets # [N, 3] 66 | pt_dist = torch.sum(torch.abs(pt_diff), dim=-1) # [N] 67 | valid = (instance_labels != self.ignore_label) 68 | 69 | offset_norm_loss = torch.sum(pt_dist * valid) / (valid.sum() + 1e-6) 70 | 71 | gt_offsets_norm = torch.norm(gt_offsets, p=2, dim=1) # [N], float 72 | gt_offsets_ = gt_offsets / (gt_offsets_norm.unsqueeze(-1) + 1e-8) 73 | pt_offsets_norm = torch.norm(pt_offsets, p=2, dim=1) 74 | pt_offsets_ = pt_offsets / (pt_offsets_norm.unsqueeze(-1) + 1e-8) 75 | direction_diff = - (gt_offsets_ * pt_offsets_).sum(-1) # [N] 76 | offset_dir_loss = torch.sum(direction_diff * valid) / (valid.sum() + 1e-6) 77 | 78 | loss_out["offset_norm_loss"] = (offset_norm_loss, valid.sum()) 79 | loss_out["offset_dir_loss"] = (offset_dir_loss, valid.sum()) 80 | 81 | empty_flag = loss_inp["empty_flag"] 82 | """superpoint clustering loss""" 83 | if fusion_flag: 84 | fusion_scores, fusion_labels = loss_inp["fusion"] 85 | # fusion_scores: [num_superpoint - 1], float 86 | # fusion_labels: [num_superpoint - 1], 
float 87 | fusion_loss = F.binary_cross_entropy_with_logits(fusion_scores, fusion_labels) 88 | fusion_count = fusion_labels.shape[0] 89 | loss_out["fusion_loss"] = (fusion_loss, fusion_count) 90 | 91 | if "refine" in loss_inp and not empty_flag: 92 | """refine loss""" 93 | (refine_scores, refine_labels) = loss_inp["refine"] 94 | refine_loss = F.binary_cross_entropy_with_logits(refine_scores, refine_labels) 95 | refine_count = refine_labels.shape[0] 96 | loss_out["refine_loss"] = (refine_loss, refine_count) 97 | if prepare_flag and not empty_flag: 98 | proposals_idx, proposals_offset = loss_inp["proposals"] 99 | # scores: (num_prop, 1), float32 100 | # proposals_idx: (sum_points, 2), int, cpu, dim 0 for cluster_id, dim 1 for corresponding point idxs in N 101 | # proposals_offset: (num_prop + 1), int, cpu 102 | ious = pointgroup_ops.get_iou(proposals_idx[:, 1].int().cuda(), 103 | proposals_offset.int().cuda(), 104 | instance_labels, 105 | instance_pointnum.int()) # [num_prop, num_inst], float 106 | gt_ious, gt_inst_idxs = ious.max(1) # [num_prop] float, long 107 | """score loss""" 108 | scores = loss_inp["proposal_scores"] 109 | gt_ious, gt_inst_idxs = ious.max(1) # [num_prop] float, long 110 | score_loss = iou_guided_loss(scores.view(-1), gt_ious, self.fg_thresh, self.bg_thresh, use_sigmoid=False) 111 | score_loss = score_loss.mean() 112 | 113 | loss_out["score_loss"] = (score_loss, gt_ious.shape[0]) 114 | 115 | """total loss""" 116 | # loss = fusion_loss 117 | loss = self.loss_weight[0] * semantic_loss + self.loss_weight[1] * offset_norm_loss + self.loss_weight[2] * offset_dir_loss 118 | 119 | if fusion_flag: 120 | loss += fusion_loss 121 | 122 | if prepare_flag and not empty_flag: 123 | loss += (self.loss_weight[3] * score_loss) 124 | 125 | if "refine" in loss_inp and not empty_flag: 126 | loss += refine_loss 127 | 128 | return loss, loss_out 129 | 130 | 131 | -------------------------------------------------------------------------------- /test.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Gorilla-Lab. All rights reserved. 2 | import argparse 3 | import numpy as np 4 | import os 5 | 6 | import torch 7 | import spconv 8 | import scipy.stats as stats 9 | 10 | import pointgroup_ops 11 | import gorilla 12 | import gorilla3d 13 | import gorilla3d.datasets as datasets 14 | import sstnet 15 | 16 | def get_parser(): 17 | parser = argparse.ArgumentParser(description="SSTNet for Point Cloud Instance Segmentation") 18 | parser.add_argument("--config", 19 | type=str, 20 | default="config/default.yaml", 21 | help="path to config file") 22 | ### pretrain 23 | parser.add_argument("--pretrain", 24 | type=str, 25 | default="", 26 | help="path to pretrain model") 27 | ### split 28 | parser.add_argument("--split", 29 | type=str, 30 | default="val", 31 | help="dataset split to test") 32 | ### semantic only 33 | parser.add_argument("--semantic", 34 | action="store_true", 35 | help="only evaluate semantic segmentation") 36 | ### log file path 37 | parser.add_argument("--log-file", 38 | type=str, 39 | default=None, 40 | help="log_file path") 41 | ### test script operation 42 | parser.add_argument("--eval", 43 | action="store_true", 44 | help="evaluate or not") 45 | parser.add_argument("--save", 46 | action="store_true", 47 | help="save results or not") 48 | parser.add_argument("--visual", 49 | type=str, 50 | default=None, 51 | help="visual path, give to save visualization results") 52 | 53 | args_cfg = parser.parse_args() 54 | 55 | return args_cfg 56 | 57 | 58 | def init(): 59 | args = get_parser() 60 | cfg = gorilla.Config.fromfile(args.config) 61 | cfg.pretrain = args.pretrain 62 | cfg.semantic = args.semantic 63 | cfg.dataset.task = args.split # change tasks 64 | cfg.data.visual = args.visual 65 | cfg.data.eval = args.eval 66 | cfg.data.save = args.save 67 | 68 | gorilla.set_random_seed(cfg.data.test_seed) 69 | 70 | #### get logger file 71 | params_dict = dict( 72 | 
epoch=cfg.data.test_epoch, 73 | optim=cfg.optimizer.type, 74 | lr=cfg.optimizer.lr, 75 | scheduler=cfg.lr_scheduler.type 76 | ) 77 | if "test" in args.split: 78 | params_dict["suffix"] = "test" 79 | 80 | log_dir, logger = gorilla.collect_logger( 81 | prefix=os.path.splitext(args.config.split("/")[-1])[0], 82 | log_name="test", 83 | log_file=args.log_file, 84 | # **params_dict 85 | ) 86 | 87 | logger.info( 88 | "************************ Start Logging ************************") 89 | 90 | # log the config 91 | logger.info(cfg) 92 | 93 | global result_dir 94 | result_dir = os.path.join( 95 | log_dir, "result", 96 | "epoch{}_nmst{}_scoret{}_npointt{}".format(cfg.data.test_epoch, 97 | cfg.data.TEST_NMS_THRESH, 98 | cfg.data.TEST_SCORE_THRESH, 99 | cfg.data.TEST_NPOINT_THRESH), 100 | args.split) 101 | os.makedirs(os.path.join(result_dir, "predicted_masks"), exist_ok=True) 102 | 103 | global semantic_label_idx 104 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 105 | semantic_label_idx = torch.tensor([ 106 | 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39 107 | ]).cuda() 108 | 109 | return logger, cfg 110 | 111 | 112 | def test(model, cfg, logger): 113 | logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>") 114 | 115 | epoch = cfg.data.test_epoch 116 | semantic = cfg.semantic 117 | 118 | cfg.dataset.test_mode = True 119 | cfg.dataloader.batch_size = 1 120 | cfg.dataloader.num_workers = 2 121 | test_dataset = gorilla.build_dataset(cfg.dataset) 122 | test_dataloader = gorilla.build_dataloader(test_dataset, cfg.dataloader) 123 | 124 | with torch.no_grad(): 125 | model = model.eval() 126 | 127 | # init timer to calculate time 128 | timer = gorilla.Timer() 129 | 130 | # define evaluator 131 | # get the real data root 132 | data_root = os.path.join(os.path.dirname(__file__), cfg.dataset.data_root) 133 | sub_dir = "scans_test" if "test" in cfg.dataset.task else "scans" 134 | 135 | label_root = os.path.join(data_root, cfg.dataset.task + "_gt") 136 | 
evaluator = gorilla3d.ScanNetSemanticEvaluator(label_root) 137 | inst_evaluator = gorilla3d.ScanNetInstanceEvaluator(label_root) 138 | 139 | for i, batch in enumerate(test_dataloader): 140 | torch.cuda.empty_cache() 141 | timer.reset() 142 | N = batch["feats"].shape[0] 143 | test_scene_name = batch["scene_list"][0] 144 | 145 | coords = batch["locs"].cuda() # [N, 1 + 3] dimension 0 for batch_idx 146 | locs_offset = batch["locs_offset"].cuda() # [B, 3] 147 | voxel_coords = batch["voxel_locs"].cuda() # [M, 1 + 3] 148 | p2v_map = batch["p2v_map"].cuda() # [N] 149 | v2p_map = batch["v2p_map"].cuda() # [M, 1 + maxActive] 150 | 151 | coords_float = batch["locs_float"].cuda() # [N, 3] 152 | feats = batch["feats"].cuda() # [N, C] 153 | 154 | batch_offsets = batch["offsets"].cuda() # [B + 1] 155 | scene_list = batch["scene_list"] 156 | superpoint = batch["superpoint"].cuda() # [N] 157 | _, superpoint = torch.unique(superpoint, return_inverse=True) # [N] 158 | 159 | extra_data = {"batch_idxs": coords[:, 0].int(), 160 | "superpoint": superpoint, 161 | "locs_offset": locs_offset, 162 | "scene_list": scene_list} 163 | 164 | spatial_shape = batch["spatial_shape"] 165 | 166 | if cfg.model.use_coords: 167 | feats = torch.cat((feats, coords_float), 1) 168 | voxel_feats = pointgroup_ops.voxelization(feats, v2p_map, cfg.data.mode) # [M, C] 169 | 170 | input_ = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, cfg.dataloader.batch_size) 171 | 172 | data_time = timer.since_last() 173 | 174 | ret = model(input_, 175 | p2v_map, 176 | coords_float, 177 | epoch, 178 | extra_data, 179 | mode="test", 180 | semantic_only=semantic) 181 | 182 | semantic_scores = ret["semantic_scores"] # [N, nClass] 183 | pt_offsets = ret["pt_offsets"] # [N, 3] 184 | 185 | score_epochs = cfg.model.score_epochs 186 | prepare_flag = epoch > score_epochs 187 | if prepare_flag and not semantic: 188 | scores = ret["proposal_scores"] 189 | 190 | ##### preds 191 | with torch.no_grad(): 192 | preds = 
{} 193 | preds["semantic"] = semantic_scores 194 | preds["pt_offsets"] = pt_offsets 195 | if prepare_flag and not semantic: 196 | proposals_idx, proposals_offset = ret["proposals"] 197 | preds["score"] = scores 198 | preds["proposals"] = (proposals_idx, proposals_offset) 199 | 200 | 201 | ##### get predictions (#1 semantic_pred, pt_offsets; #2 scores, proposals_pred) 202 | semantic_scores = preds["semantic"] # [N, nClass=20] 203 | semantic_pred = semantic_scores.max(1)[1] # [N] 204 | 205 | pt_offsets = preds["pt_offsets"] # [N, 3] 206 | 207 | ##### semantic segmentation evaluation 208 | if cfg.data.eval: 209 | inputs = [{"scene_name": test_scene_name}] 210 | outputs = [{"semantic_pred": semantic_pred}] 211 | evaluator.process(inputs, outputs) 212 | 213 | if prepare_flag and not semantic: 214 | scores = preds["score"] # [num_prop, 1] 215 | scores_pred = torch.sigmoid(scores.view(-1)) 216 | 217 | proposals_idx, proposals_offset = preds["proposals"] 218 | # proposals_idx: (sumNPoint, 2) dim 0 for cluster_id, dim 1 for corresponding point idxs in N 219 | # proposals_offset: (num_prop + 1) 220 | proposals_pred = torch.zeros( 221 | (proposals_offset.shape[0] - 1, N), 222 | dtype=torch.int, 223 | device=scores_pred.device) # [num_prop, N] 224 | proposals_pred[proposals_idx[:, 0].long(), 225 | proposals_idx[:, 1].long()] = 1 226 | semantic_pred_list = [] 227 | for start, end in zip(proposals_offset[:-1], 228 | proposals_offset[1:]): 229 | semantic_label, _ = stats.mode( 230 | semantic_pred[proposals_idx[start:end, 231 | 1].long()].cpu().numpy()) 232 | semantic_label = semantic_label[0] 233 | semantic_pred_list.append(semantic_label) 234 | 235 | semantic_id = semantic_label_idx[semantic_pred_list] 236 | 237 | ##### score threshold 238 | score_mask = (scores_pred > cfg.data.TEST_SCORE_THRESH) 239 | scores_pred = scores_pred[score_mask] 240 | proposals_pred = proposals_pred[score_mask] 241 | semantic_id = semantic_id[score_mask] 242 | 243 | ##### npoint threshold 244 | 
proposals_pointnum = proposals_pred.sum(1) 245 | npoint_mask = (proposals_pointnum > cfg.data.TEST_NPOINT_THRESH) 246 | scores_pred = scores_pred[npoint_mask] 247 | proposals_pred = proposals_pred[npoint_mask] 248 | semantic_id = semantic_id[npoint_mask] 249 | 250 | ##### nms 251 | if semantic_id.shape[0] == 0: 252 | pick_idxs = np.empty(0, dtype=np.int64) # integer dtype so the empty array is a valid index 253 | else: 254 | proposals_pred_f = proposals_pred.float( 255 | ) # [num_prop, N], float, cuda 256 | intersection = torch.mm( 257 | proposals_pred_f, proposals_pred_f.t( 258 | )) # [num_prop, num_prop], float, cuda 259 | proposals_pointnum = proposals_pred_f.sum( 260 | 1) # [num_prop], float, cuda 261 | proposals_pn_h = proposals_pointnum.unsqueeze(-1).repeat( 262 | 1, proposals_pointnum.shape[0]) 263 | proposals_pn_v = proposals_pointnum.unsqueeze(0).repeat( 264 | proposals_pointnum.shape[0], 1) 265 | cross_ious = intersection / (proposals_pn_h + 266 | proposals_pn_v - intersection) 267 | 268 | pick_idxs = gorilla3d.non_max_suppression( 269 | cross_ious.cpu().numpy(), 270 | scores_pred.cpu().numpy(), 271 | cfg.data.TEST_NMS_THRESH) # int, (nCluster, N) 272 | clusters = proposals_pred[pick_idxs] 273 | cluster_scores = scores_pred[pick_idxs] 274 | cluster_semantic_id = semantic_id[pick_idxs] 275 | 276 | nclusters = clusters.shape[0] 277 | 278 | ##### prepare for evaluation 279 | if cfg.data.eval: 280 | pred_info = {} 281 | pred_info["scene_name"] = test_scene_name 282 | pred_info["conf"] = cluster_scores.cpu().numpy() 283 | pred_info["label_id"] = cluster_semantic_id.cpu().numpy() 284 | pred_info["mask"] = clusters.cpu().numpy() 285 | inst_evaluator.process(inputs, [pred_info]) 286 | 287 | inference_time = timer.since_last() 288 | 289 | ##### visual 290 | if cfg.data.visual is not None: 291 | # visual semantic result 292 | gorilla.check_dir(cfg.data.visual) 293 | if cfg.semantic: 294 | pass 295 | # visual instance result 296 | else: 297 | datasets.visualize_instance_mask( 298 | clusters.cpu().numpy(), 299 | test_scene_name, 300 
| cfg.data.visual, 301 | os.path.join(data_root, sub_dir), 302 | cluster_scores.cpu().numpy(), 303 | semantic_pred.cpu().numpy(),) 304 | 305 | ##### save files 306 | if (prepare_flag and cfg.data.save): 307 | f = open(os.path.join(result_dir, test_scene_name + ".txt"), "w") 308 | for proposal_id in range(nclusters): 309 | clusters_i = clusters[proposal_id].cpu().numpy() # [N] 310 | semantic_label = np.argmax( 311 | np.bincount( 312 | semantic_pred[np.where(clusters_i == 1)[0]].cpu())) 313 | score = cluster_scores[proposal_id] 314 | f.write(f"predicted_masks/{test_scene_name}_{proposal_id:03d}.txt " 315 | f"{semantic_label_idx[semantic_label]} {score:.4f}") 316 | if proposal_id < nclusters - 1: 317 | f.write("\n") 318 | content = list(map(lambda x: str(x), clusters_i.tolist())) 319 | content = "\n".join(content) 320 | with open( 321 | os.path.join( 322 | result_dir, "predicted_masks", 323 | test_scene_name + "_%03d.txt" % (proposal_id)), 324 | "w") as cf: 325 | cf.write(content) 326 | # np.savetxt(os.path.join(result_dir, "predicted_masks", test_scene_name + "_%03d.txt" % (proposal_id)), clusters_i, fmt="%d") 327 | f.close() 328 | 329 | save_time = timer.since_last() 330 | total_time = timer.since_start() 331 | 332 | ##### print 333 | if semantic: 334 | logger.info( 335 | f"instance iter: {i + 1}/{len(test_dataloader)} point_num: {N} " 336 | f"time: total {total_time:.2f}s data: {data_time:.2f}s " 337 | f"inference {inference_time:.2f}s save {save_time:.2f}s") 338 | else: 339 | logger.info( 340 | f"instance iter: {i + 1}/{len(test_dataloader)} point_num: {N} " 341 | f"ncluster: {nclusters} time: total {total_time:.2f}s data: {data_time:.2f}s " 342 | f"inference {inference_time:.2f}s save {save_time:.2f}s") 343 | 344 | ##### evaluation 345 | if cfg.data.eval: 346 | if not semantic: 347 | inst_evaluator.evaluate(prec_rec=False) 348 | evaluator.evaluate() 349 | 350 | 351 | if __name__ == "__main__": 352 | logger, cfg = init() 353 | 354 | ##### model 355 | 
logger.info("=> creating model ...") 356 | logger.info(f"Classes: {cfg.model.classes}") 357 | 358 | model = gorilla.build_model(cfg.model) 359 | 360 | use_cuda = torch.cuda.is_available() 361 | logger.info(f"cuda available: {use_cuda}") 362 | assert use_cuda 363 | model = model.cuda() 364 | 365 | # logger.info(model) 366 | logger.info(f"#classifier parameters (model): {sum([x.nelement() for x in model.parameters()])}") 367 | 368 | ##### load model 369 | gorilla.load_checkpoint( 370 | model, cfg.pretrain 371 | ) # resume from the latest epoch, or specify the epoch to restore 372 | 373 | ##### evaluate 374 | test(model, cfg, logger) 375 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Gorilla-Lab. All rights reserved. 2 | import glob 3 | import os.path as osp 4 | 5 | import torch 6 | import gorilla 7 | import gorilla3d 8 | import spconv 9 | 10 | import sstnet 11 | import pointgroup_ops 12 | 13 | def get_parser(): 14 | # the default argument parser contains some 15 | # essential parameters for distributed training 16 | parser = gorilla.core.default_argument_parser() 17 | parser.add_argument("--config", 18 | type=str, 19 | default="config/default.yaml", 20 | help="path to config file") 21 | 22 | args_cfg = parser.parse_args() 23 | 24 | return args_cfg 25 | 26 | 27 | def do_train(model, cfg, logger): 28 | model.train() 29 | # initialize optimizer and scheduler (the scheduler is optional; the learning rate can be adjusted manually) 30 | optimizer = gorilla.build_optimizer(model, cfg.optimizer) 31 | lr_scheduler = gorilla.build_lr_scheduler(optimizer, cfg.lr_scheduler) 32 | 33 | # initialize criterion (Optional, can calculate in model forward) 34 | criterion = gorilla.build_loss(cfg.loss) 35 | 36 | # resume model/optimizer/scheduler 37 | iter = 1 38 | checkpoint, epoch = get_checkpoint(cfg.log_dir) 39 | if gorilla.is_filepath(checkpoint): # read valid 
checkpoint file 40 | # meta is a dict that saves some necessary information (last epoch/iteration, acc, loss) 41 | meta = gorilla.resume(model=model, 42 | filename=checkpoint, 43 | optimizer=optimizer, # optimizer and scheduler are optional 44 | scheduler=lr_scheduler, # to resume (these parameters may be omitted) 45 | resume_optimizer=True, 46 | resume_scheduler=True, 47 | strict=False, 48 | ) 49 | # get epoch from meta (Optional) 50 | epoch = meta.get("epoch", epoch) + 1 51 | iter = meta.get("iter", iter) + 1 52 | 53 | # initialize train dataset 54 | train_dataset = gorilla.build_dataset(cfg.dataset) 55 | train_dataloader = gorilla.build_dataloader(train_dataset, 56 | cfg.dataloader, 57 | shuffle=True, 58 | pin_memory=True, 59 | drop_last=True) 60 | 61 | # initialize tensorboard (Optional) TODO: integrate the tensorboard manager 62 | writer = gorilla.TensorBoardWriter(cfg.log_dir) # tensorboard writer 63 | 64 | # initialize timers (Optional) 65 | iter_timer = gorilla.Timer() 66 | epoch_timer = gorilla.Timer() 67 | 68 | # loss/time buffer for epoch record (Optional) 69 | loss_buffer = gorilla.HistoryBuffer() 70 | iter_time = gorilla.HistoryBuffer() 71 | data_time = gorilla.HistoryBuffer() 72 | 73 | while epoch <= cfg.data.epochs: 74 | for i, batch in enumerate(train_dataloader): 75 | torch.cuda.empty_cache() # (empty cuda cache, Optional) 76 | # calculate data loading time 77 | data_time.update(iter_timer.since_last()) 78 | # cuda manually (TODO: integrate the data cuda operation) 79 | ##### prepare input and forward 80 | coords = batch["locs"].cuda() # [N, 1 + 3], long, cuda, dimension 0 for batch_idx 81 | locs_offset = batch["locs_offset"].cuda() # [B, 3], long, cuda 82 | voxel_coords = batch["voxel_locs"].cuda() # [M, 1 + 3], long, cuda 83 | p2v_map = batch["p2v_map"].cuda() # [N], int, cuda 84 | v2p_map = batch["v2p_map"].cuda() # [M, 1 + maxActive], int, cuda 85 | 86 | coords_float = batch["locs_float"].cuda() # [N, 3], float32, cuda 87 | feats = 
batch["feats"].cuda() # [N, C], float32, cuda 88 | semantic_labels = batch["semantic_labels"].cuda() # [N], long, cuda 89 | instance_labels = batch["instance_labels"].cuda( 90 | ) # [N], long, cuda, 0~total_num_inst, -100 91 | 92 | instance_info = batch["instance_info"].cuda( 93 | ) # [N, 9], float32, cuda, (meanxyz, minxyz, maxxyz) 94 | instance_pointnum = batch["instance_pointnum"].cuda( 95 | ) # [total_num_inst], int, cuda 96 | 97 | batch_offsets = batch["offsets"].cuda() # [B + 1], int, cuda 98 | superpoint = batch["superpoint"].cuda() # [N], long, cuda 99 | _, superpoint = torch.unique(superpoint, return_inverse=True) # [N], long, cuda 100 | 101 | fusion_epochs = cfg.model.fusion_epochs 102 | score_epochs = cfg.model.score_epochs 103 | prepare_flag = (epoch > score_epochs) 104 | fusion_flag = (epoch > fusion_epochs) 105 | with_refine = cfg.model.with_refine 106 | scene_list = batch["scene_list"] 107 | spatial_shape = batch["spatial_shape"] 108 | 109 | extra_data = { 110 | "batch_idxs": coords[:, 0].int(), 111 | "superpoint": superpoint, 112 | "locs_offset": locs_offset, 113 | "scene_list": scene_list, 114 | "instance_labels": instance_labels, 115 | "instance_pointnum": instance_pointnum 116 | } 117 | 118 | if cfg.model.use_coords: 119 | feats = torch.cat((feats, coords_float), 1) 120 | voxel_feats = pointgroup_ops.voxelization( 121 | feats, v2p_map, cfg.data.mode) # [M, C] 122 | 123 | input_ = spconv.SparseConvTensor(voxel_feats, 124 | voxel_coords.int(), 125 | spatial_shape, 126 | cfg.dataloader.batch_size) 127 | 128 | ret = model(input_, 129 | p2v_map, 130 | coords_float, 131 | epoch, 132 | extra_data) 133 | 134 | semantic_scores = ret["semantic_scores"] # [N, nClass] float32, cuda 135 | pt_offsets = ret["pt_offsets"] # [N, 3], float32, cuda 136 | 137 | loss_inp = {} 138 | loss_inp["batch_idxs"] = coords[:, 0].int() 139 | loss_inp["feats"] = feats 140 | loss_inp["scene_list"] = scene_list 141 | 142 | loss_inp["semantic_scores"] = (semantic_scores, 
semantic_labels) 143 | loss_inp["pt_offsets"] = (pt_offsets, 144 | coords_float, 145 | instance_info, 146 | instance_labels, 147 | instance_pointnum) 148 | 149 | loss_inp["superpoint"] = superpoint 150 | loss_inp["empty_flag"] = ret["empty_flag"] # avoid stack error 151 | 152 | if fusion_flag: 153 | loss_inp["fusion"] = ret["fusion"] 154 | 155 | if with_refine: 156 | loss_inp["refine"] = ret["refine"] 157 | 158 | if prepare_flag: 159 | loss_inp["proposals"] = ret["proposals"] 160 | scores = ret["proposal_scores"] 161 | # scores: (num_prop, 1) float, cuda 162 | # proposals_idx: (sum_points, 2), int, cpu, dim 0 for cluster_id, dim 1 for corresponding point idxs in N 163 | # proposals_offset: (num_prop + 1), int, cpu 164 | 165 | loss_inp["proposal_scores"] = scores 166 | 167 | loss, loss_out = criterion(loss_inp, epoch) 168 | loss_buffer.update(loss) 169 | 170 | # sample the learning rate(Optional) 171 | lr = optimizer.param_groups[0]["lr"] 172 | # write tensorboard 173 | loss_out.update({"loss": loss, "lr": lr}) 174 | writer.update(loss_out, iter) 175 | # # equivalent write operation 176 | # writer.add_scalar(f"train/loss", loss, iter) 177 | # writer.add_scalar(f"lr", lr, iter) 178 | # # (NOTE: `loss_out` supports multiple losses and stores each loss item) 179 | # for k, v in loss_out.items(): 180 | # writer.add_scalar(f"train/{k}", v[0], iter) 181 | 182 | # backward 183 | optimizer.zero_grad() 184 | loss.backward() 185 | optimizer.step() 186 | iter += 1 187 | 188 | # calculate time and reset timer(Optional) 189 | iter_time.update(iter_timer.since_start()) 190 | iter_timer.reset() # record the iteration time and reset timer 191 | 192 | # TODO: the time manager will be integrated into gorilla-core 193 | # calculate remaining time(Optional): iterations left in this and all later epochs 194 | remain_iter = (cfg.data.epochs - epoch + 1) * len(train_dataloader) - (i + 1) 195 | remain_time = gorilla.convert_seconds(remain_iter * iter_time.avg) # convert seconds into "hours:minutes:seconds" 196 | 197 | print(f"epoch: 
{epoch}/{cfg.data.epochs} iter: {i + 1}/{len(train_dataloader)} " 198 | f"lr: {lr:4f} loss: {loss_buffer.latest:.4f}({loss_buffer.avg:.4f}) " 199 | f"data_time: {data_time.latest:.2f}({data_time.avg:.2f}) " 200 | f"iter_time: {iter_time.latest:.2f}({iter_time.avg:.2f}) eta: {remain_time}") 201 | 202 | # update learning rate scheduler and epoch 203 | lr_scheduler.step() 204 | 205 | # log the epoch information 206 | logger.info(f"epoch: {epoch}/{cfg.data.epochs}, train loss: {loss_buffer.avg}, time: {epoch_timer.since_start()}s") 207 | iter_time.clear() 208 | data_time.clear() 209 | loss_buffer.clear() 210 | 211 | # write the important information into meta 212 | meta = {"epoch": epoch, 213 | "iter": iter} 214 | 215 | # save checkpoint (the original condition tested fusion_epochs twice; score_epochs is assumed to be the intended second stage boundary) 216 | checkpoint = osp.join(cfg.log_dir, "epoch_{0:05d}.pth".format(epoch)) 217 | if (epoch == fusion_epochs) or (epoch == score_epochs): 218 | gorilla.save_checkpoint(model=model, 219 | filename=checkpoint, 220 | optimizer=optimizer, 221 | scheduler=lr_scheduler, 222 | meta=meta) 223 | else: 224 | gorilla.save_checkpoint(model=model, 225 | filename=checkpoint, 226 | meta=meta) 227 | logger.info("Saving " + checkpoint) 228 | # save as latest checkpoint 229 | latest_checkpoint = osp.join(cfg.log_dir, "epoch_latest.pth") 230 | gorilla.save_checkpoint(model=model, 231 | filename=latest_checkpoint, 232 | optimizer=optimizer, 233 | scheduler=lr_scheduler, 234 | meta=meta) 235 | 236 | epoch += 1 237 | 238 | 239 | def get_checkpoint(log_dir, epoch=0, checkpoint=""): 240 | if not checkpoint: 241 | if epoch > 0: 242 | checkpoint = osp.join(log_dir, "epoch_{0:05d}.pth".format(epoch)) 243 | assert osp.isfile(checkpoint) 244 | else: 245 | latest_checkpoint = glob.glob(osp.join(log_dir, "*latest*.pth")) 246 | if len(latest_checkpoint) > 0: 247 | checkpoint = latest_checkpoint[0] 248 | else: 249 | checkpoint = sorted(glob.glob(osp.join(log_dir, "*.pth"))) 250 | if len(checkpoint) > 0: 251 | checkpoint = checkpoint[-1] 252 | epoch = 
int(checkpoint.split("_")[-1].split(".")[0]) 253 | 254 | return checkpoint, epoch + 1 255 | 256 | def main(args): 257 | # read config file 258 | cfg = gorilla.Config.fromfile(args.config) 259 | 260 | # get logger file 261 | log_dir, logger = gorilla.collect_logger( 262 | prefix=osp.splitext(osp.basename(args.config))[0]) 263 | #### NOTE: the logger can also be initialized manually 264 | # logger = gorilla.get_logger(log_file) 265 | 266 | # back up the necessary files and directories (Optional, see the source code for details) 267 | backup_list = ["train.py", "test.py", "sstnet", args.config] 268 | backup_dir = osp.join(log_dir, "backup") 269 | gorilla.backup(backup_dir, backup_list) 270 | 271 | # merge the parameters in args into cfg 272 | cfg = gorilla.config.merge_cfg_and_args(cfg, args) 273 | 274 | cfg.log_dir = log_dir 275 | 276 | # set random seed 277 | seed = cfg.get("seed", 0) 278 | gorilla.set_random_seed(seed) 279 | 280 | # model 281 | logger.info("=> creating model ...") 282 | 283 | # create model 284 | model = gorilla.build_model(cfg.model) 285 | model = model.cuda() 286 | if args.num_gpus > 1: 287 | # convert the BatchNorm layers in the model to SyncBatchNorm (NOTE: this raises an error on old PyTorch versions) 
288 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 289 | # DDP wrap model 290 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gorilla.get_local_rank()], find_unused_parameters=True) 291 | 292 | # logger.info("Model:\n{}".format(model)) (Optional, print the model) 293 | 294 | # count the parameters of the model (Optional) 295 | count_parameters = sum(gorilla.parameter_count(model).values()) 296 | logger.info(f"#classifier parameters new: {count_parameters}") 297 | 298 | # start training 299 | do_train(model, cfg, logger) 300 | 301 | 302 | if __name__ == "__main__": 303 | # get the args 304 | args = get_parser() 305 | 306 | # # automatically use the free gpus 307 | # gorilla.set_cuda_visible_devices(num_gpu=args.num_gpus) 308 | 309 | gorilla.launch( 310 | main, 311 | args.num_gpus, 312 | num_machines=args.num_machines, 313 | machine_rank=args.machine_rank, 314 | dist_url=args.dist_url, 315 | args=(args,) # use tuple to wrap 316 | ) 317 | --------------------------------------------------------------------------------
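Note on checkpoint resolution in train.py: `get_checkpoint` prefers `epoch_latest.pth` and otherwise falls back to the last `*.pth` file in sorted order, reading the epoch from the file name. The sketch below is not part of the repository; it restates that lookup with only the standard library so the behaviour can be checked in isolation. The names `parse_epoch` and `pick_checkpoint` are hypothetical, and the sketch assumes the `epoch_{0:05d}.pth` / `epoch_latest.pth` naming used above. It picks the numerically highest epoch rather than relying on lexicographic order, which gives the same result for zero-padded names.

```python
import re
from pathlib import Path
from typing import Optional, Tuple

# Hypothetical helpers for illustration; only the file naming scheme
# ("epoch_00012.pth", "epoch_latest.pth") is taken from train.py.
_EPOCH_RE = re.compile(r"epoch_(\d+)\.pth$")


def parse_epoch(filename: str) -> Optional[int]:
    """Return the epoch encoded in a checkpoint name, or None if there is none."""
    match = _EPOCH_RE.search(filename)
    return int(match.group(1)) if match else None


def pick_checkpoint(log_dir: str) -> Tuple[Optional[str], int]:
    """Return (checkpoint path or None, epoch to start from).

    Mirrors the fallback order of get_checkpoint: the latest checkpoint
    first, then the highest numbered one, then nothing.
    """
    root = Path(log_dir)
    latest = root / "epoch_latest.pth"
    if latest.is_file():
        # The epoch is not in the file name here; train.py reads it
        # from the checkpoint meta after resuming.
        return str(latest), 1
    numbered = [(parse_epoch(p.name), p) for p in root.glob("*.pth")]
    numbered = [(e, p) for e, p in numbered if e is not None]
    if not numbered:
        return None, 1
    epoch, path = max(numbered, key=lambda item: item[0])
    return str(path), epoch + 1
```

Unlike the original, this sketch returns `None` instead of an empty list when no checkpoint exists, which makes the "nothing to resume" case explicit for the caller.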