├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── assets └── overview.png ├── config ├── default.gin ├── s3dis │ ├── eval_default.gin │ ├── eval_fpt.gin │ ├── eval_res16unet34c.gin │ ├── train_default.gin │ ├── train_fpt.gin │ └── train_res16unet34c.gin └── scannet │ ├── eval_default.gin │ ├── eval_fpt.gin │ ├── eval_res16unet34c.gin │ ├── train_default.gin │ ├── train_fpt.gin │ └── train_res16unet34c.gin ├── environment.yaml ├── eval.py ├── setup.sh ├── src ├── cscore │ ├── calculate.py │ └── prepare.py ├── cuda_ops │ ├── functions │ │ └── sparse_ops.py │ ├── setup.py │ └── src │ │ ├── cuda_ops_api.cpp │ │ ├── cuda_utils.h │ │ ├── dot_product │ │ ├── dot_product.cpp │ │ ├── dot_product_kernel.cu │ │ └── dot_product_kernel.h │ │ └── scalar_attention │ │ ├── scalar_attention.cpp │ │ ├── scalar_attention_kernel.cu │ │ └── scalar_attention_kernel.h ├── data │ ├── __init__.py │ ├── collate.py │ ├── meta_data │ │ ├── s3dis │ │ │ ├── area1.txt │ │ │ ├── area2.txt │ │ │ ├── area3.txt │ │ │ ├── area4.txt │ │ │ ├── area5.txt │ │ │ └── area6.txt │ │ └── scannet │ │ │ ├── scannetv2_test.txt │ │ │ ├── scannetv2_train.txt │ │ │ ├── scannetv2_trainval.txt │ │ │ └── scannetv2_val.txt │ ├── preprocess_s3dis.py │ ├── preprocess_scannet.py │ ├── s3dis_loader.py │ ├── sampler.py │ ├── scannet_loader.py │ └── transforms.py ├── models │ ├── __init__.py │ ├── common.py │ ├── fast_point_transformer.py │ ├── resnet.py │ ├── resunet.py │ ├── spvcnn.py │ └── transformer_base.py ├── modules │ ├── __init__.py │ └── segmentation.py └── utils │ ├── file.py │ ├── logger.py │ ├── metric.py │ ├── misc.py │ └── visualization.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .empty/ 3 | __pycache__/ 4 | wandb/ 5 | tensorboard/ 6 | 7 | src/cuda_ops/build/ 8 | src/cuda_ops/cuda_ops.egg-info/ 9 | src/cuda_ops/cuda_sparse_ops.egg-info/ 10 | src/cuda_ops/dist/ 11 | 12 | Open3D/ 13 | *ipynb* 14 | 15 | *.ply 16 | *.pyc 
17 | *.png 18 | *.ckpt 19 | 20 | thirdparty/MinkowskiEngine/MinkowskiEngine.egg-info/ 21 | thirdparty/MinkowskiEngine/build/ 22 | thirdparty/MinkowskiEngine/dist/ 23 | thirdparty-temp/ 24 | 25 | experiments/ 26 | checkpoints/ 27 | consistency_outputs/ 28 | votenet-logs-temp/ 29 | sbatch/*.out -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "thirdparty/MinkowskiEngine"] 2 | path = thirdparty/MinkowskiEngine 3 | url = git@github.com:chrockey/MinkowskiEngine.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Chunghyun Park 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fast Point Transformer 2 | ### [Project Page](http://cvlab.postech.ac.kr/research/FPT/) | [Paper](https://arxiv.org/abs/2112.04702) 3 | This repository contains the official source code and data for our paper: 4 | 5 | >[Fast Point Transformer](https://arxiv.org/abs/2112.04702) 6 | > [Chunghyun Park](https://chrockey.github.io/), 7 | > [Yoonwoo Jeong](https://yoonwoojeong.medium.com/about), 8 | > [Minsu Cho](http://cvlab.postech.ac.kr/~mcho/), and 9 | > [Jaesik Park](http://jaesik.info/)
10 | > POSTECH GSAI & CSE
11 | > CVPR, New Orleans, 2022. 12 | 13 |
14 | An Overview of the proposed pipeline 15 |
16 | 17 | ## Overview 18 | This work introduces *Fast Point Transformer*, which consists of a new lightweight self-attention layer. Our approach encodes continuous 3D coordinates, and the voxel hashing-based architecture boosts computational efficiency. The proposed method is demonstrated on 3D semantic segmentation and 3D detection. The accuracy of our approach is competitive with the best voxel-based method, and our network achieves 129 times faster inference than the state-of-the-art Point Transformer, with a reasonable accuracy trade-off, in 3D semantic segmentation on the S3DIS dataset. 19 | 20 | ## Citation 21 | If you find our code or paper useful, please consider citing our paper: 22 | 23 | ```BibTeX 24 | @inproceedings{park2022fast, 25 | title={Fast Point Transformer}, 26 | author={Park, Chunghyun and Jeong, Yoonwoo and Cho, Minsu and Park, Jaesik}, 27 | booktitle={Proceedings of the {IEEE/CVF} Conference on Computer Vision and Pattern Recognition (CVPR)}, 28 | month={June}, 29 | year={2022}, 30 | pages={16949-16958} 31 | } 32 | ``` 33 | 34 | ## Experiments 35 | ### 1. S3DIS Area 5 test 36 | We denote MinkowskiNet42 trained with this repository as MinkowskiNet42. 37 | We use voxel size 4cm for both MinkowskiNet42 and our Fast Point Transformer. 
38 | 39 | | Model | Latency (sec) | mAcc (%) | mIoU (%) | Reference | 40 | |:----------------------------------|--------------------:|:--------:|:--------:|:---------:| 41 | | PointTransformer | 18.07 | 76.5 | 70.4 | [Codes from the authors](https://github.com/POSTECH-CVLab/point-transformer) | 42 | | MinkowskiNet42 | 0.08 | 74.1 | 67.2 | [Checkpoint](https://postechackr-my.sharepoint.com/:u:/g/personal/p0125ch_postech_ac_kr/EZcO0DH6QeNGgIwGFZsmL-4BAlikmHAHlBs4JBcS5XfpVQ?download=1) | 43 | |   + rotation average | 0.66 | 75.1 | 69.0 | - | 44 | | FastPointTransformer | 0.14 | 76.6 | 69.2 | [Checkpoint](https://postechackr-my.sharepoint.com/:u:/g/personal/p0125ch_postech_ac_kr/ER8KwMTzqAxAvK9KeOZ9U_IBuCAuv4hP6zOWD-3HNO6Xeg?download=1) | 45 | |   + rotation average | 1.13 | 77.6 | 71.0 | - | 46 | 47 | ### 2. ScanNetV2 validation 48 | | Model | Voxel Size | mAcc (%) | mIoU (%) | Reference | 49 | |:----------------------------------|:-----------:|:--------:|:--------:|:---------:| 50 | | MinkowskiNet42 | 2cm | 80.4 | 72.2 | [Official GitHub](https://github.com/chrischoy/SpatioTemporalSegmentation) | 51 | | MinkowskiNet42 | 2cm | 81.4 | 72.1 | [Checkpoint](https://postechackr-my.sharepoint.com/:u:/g/personal/p0125ch_postech_ac_kr/EXmE1pWDZ8lEtJU7SQMjkXcBnhSMXFTdHWXkMAAF7KeiuA?download=1) | 52 | | FastPointTransformer | 2cm | 81.2 | 72.5 | [Checkpoint](https://postechackr-my.sharepoint.com/:u:/g/personal/p0125ch_postech_ac_kr/EX_xAyhoNXdJg4eSg2vS_bYB8eFAP7A8FPCYfKOS2T13LQ?download=1) | 53 | | MinkowskiNet42 | 5cm | 76.3 | 67.0 | [Checkpoint](https://postechackr-my.sharepoint.com/:u:/g/personal/p0125ch_postech_ac_kr/EZLG00u5JXJDvOi3sYziOIMB1l6HNN5OW9gTQRFWc6EwzA?download=1) | 54 | | FastPointTransformer | 5cm | 78.9 | 70.0 | [Checkpoint](https://postechackr-my.sharepoint.com/:u:/g/personal/p0125ch_postech_ac_kr/EXbXclfXZGtMpBZY93zi7M8B_tl8rwM65NK1cumN7QM_8g?download=1) | 55 | | MinkowskiNet42 | 10cm | 70.8 | 60.7 | 
[Checkpoint](https://postechackr-my.sharepoint.com/:u:/g/personal/p0125ch_postech_ac_kr/EVLn0f5noY1Al6Kos9l-0yABM0qZLFt6d4a3yFgTcQ2Vmw?download=1) | 56 | | FastPointTransformer | 10cm | 76.1 | 66.5 | [Checkpoint](https://postechackr-my.sharepoint.com/:u:/g/personal/p0125ch_postech_ac_kr/ESO1jLNHO89ApdjguUauqsMBCx_TijA26UOeGbF4XxQwoA?download=1) | 57 | 58 | ## Installation 59 | This repository was developed and tested on: 60 | 61 | - Ubuntu 18.04 and 20.04 62 | - Conda 4.11.0 63 | - CUDA 11.1 and 11.3 64 | - Python 3.8.13 65 | - PyTorch 1.7.1, 1.10.0, and 1.12.1 66 | - MinkowskiEngine 0.5.4 67 | 68 | ### Environment Setup 69 | You can install the environment by using the provided shell script: 70 | ```bash 71 | ~$ git clone --recursive git@github.com:POSTECH-CVLab/FastPointTransformer.git 72 | ~$ cd FastPointTransformer 73 | ~/FastPointTransformer$ bash setup.sh fpt 74 | ~/FastPointTransformer$ conda activate fpt 75 | ``` 76 | 77 | ## Training & Evaluation 78 | First, download the datasets (ScanNetV2 and S3DIS) and preprocess them as follows: 79 | ```bash 80 | (fpt) ~/FastPointTransformer$ python src/data/preprocess_scannet.py # you need to modify the data path 81 | (fpt) ~/FastPointTransformer$ python src/data/preprocess_s3dis.py # you need to modify the data path 82 | ``` 83 | Then, place the provided metadata of each dataset (`src/data/meta_data`) next to the preprocessed dataset, following the structure below: 84 | 85 | ``` 86 | ${data_dir} 87 | ├── scannetv2 88 | │ ├── meta_data 89 | │ │ ├── scannetv2_train.txt 90 | │ │ ├── scannetv2_val.txt 91 | │ │ └── ... 92 | │ └── scannet_processed 93 | │ ├── train 94 | │ │ ├── scene0000_00.ply 95 | │ │ ├── scene0000_01.ply 96 | │ │ └── ... 97 | │ └── test 98 | └── s3dis 99 | ├── meta_data 100 | │ ├── area1.txt 101 | │ ├── area2.txt 102 | │ └── ... 103 | └── s3dis_processed 104 | ├── Area_1 105 | │ ├── conferenceRoom_1.ply 106 | │ ├── conferenceRoom_2.ply 107 | │ └── ... 108 | ├── Area_2 109 | └── ... 
110 | ``` 111 | 112 | After that, you can train and evaluate a model by using the provided Python scripts (`train.py` and `eval.py`) with configuration files in the `config` directory. 113 | For example, you can train and evaluate Fast Point Transformer with voxel size 4cm on the S3DIS dataset via the following commands: 114 | ```bash 115 | (fpt) ~/FastPointTransformer$ python train.py config/s3dis/train_fpt.gin 116 | (fpt) ~/FastPointTransformer$ python eval.py config/s3dis/eval_fpt.gin {checkpoint_file} # use -r option for rotation averaging. 117 | ``` 118 | 119 | ### Consistency Score 120 | You need to generate predictions via the following command: 121 | ```bash 122 | (fpt) ~/FastPointTransformer$ python -m src.cscore.prepare {checkpoint_file} -m {model_name} -v {voxel_size} # This takes hours. 123 | ``` 124 | Then, you can calculate the consistency score (CScore) with: 125 | ```bash 126 | (fpt) ~/FastPointTransformer$ python -m src.cscore.calculate {prediction_dir} # This takes seconds. 127 | ``` 128 | 129 | ### 3D Object Detection using VoteNet 130 | Please refer to [this repository](https://github.com/chrockey/FastPointTransformer-VoteNet). 131 | 132 | ## Acknowledgement 133 | 134 | Our code is based on the [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine). 135 | We also thank [Hengshuang Zhao](https://hszhao.github.io/) for providing [the code](https://github.com/POSTECH-CVLab/point-transformer) of [Point Transformer](https://arxiv.org/abs/2012.09164). 136 | If you use our model, please consider citing them as well. 
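
The rotation averaging enabled by the `-r` option of `eval.py` can be illustrated outside the repository with a small NumPy sketch. It assumes, as `eval.py` does, eight evenly spaced rotations about the z-axis and an argmax over the summed per-point logits. The names `z_rotation_matrices`, `predict_with_rotation_average`, and `toy_predict_logits` are hypothetical and are not part of this codebase; the toy predictor only stands in for a trained network.

```python
import numpy as np


def z_rotation_matrices(num_rotations=8):
    """Return 3x3 rotation matrices about the z-axis at evenly spaced angles."""
    angles = 2.0 * np.pi * np.arange(num_rotations) / num_rotations
    mats = []
    for angle in angles:
        c, s = np.cos(angle), np.sin(angle)
        mats.append(np.array([[c, -s, 0.0],
                              [s, c, 0.0],
                              [0.0, 0.0, 1.0]]))
    return mats


def predict_with_rotation_average(predict_logits, coords, num_rotations=8):
    """Sum per-point logits over rotated copies of the scene, then take argmax."""
    total = None
    for rot in z_rotation_matrices(num_rotations):
        # Row-vector convention: each point (a row of coords) is multiplied by rot.
        logits = predict_logits(coords @ rot)
        total = logits if total is None else total + logits
    return total.argmax(axis=1)


def toy_predict_logits(coords):
    """Hypothetical stand-in for a network: two classes scored by height (z) only."""
    z = coords[:, 2]
    return np.stack([-z, z], axis=1)


points = np.array([[1.0, 0.0, -0.5],
                   [0.0, 2.0, 0.7]])
labels = predict_with_rotation_average(toy_predict_logits, points)
print(labels)
```

Summing logits before the argmax requires one forward pass per rotation, which is consistent with the higher latency reported for the `+ rotation average` rows in the S3DIS table.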
137 | -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/POSTECH-CVLab/FastPointTransformer/9b2cec39bc05f4921e4856d32fc563a6b43087cf/assets/overview.png -------------------------------------------------------------------------------- /config/default.gin: -------------------------------------------------------------------------------- 1 | # Training 2 | train.project_name = "FastPointTransformer-release" -------------------------------------------------------------------------------- /config/s3dis/eval_default.gin: -------------------------------------------------------------------------------- 1 | # Constants 2 | in_channels = 3 3 | out_channels = 13 4 | 5 | # Data module 6 | S3DISArea5RGBDataModule.data_root = "/root/data/s3dis" # you need to modify this according to your data. 7 | S3DISArea5RGBDataModule.train_batch_size = None 8 | S3DISArea5RGBDataModule.val_batch_size = 1 9 | S3DISArea5RGBDataModule.train_num_workers = None 10 | S3DISArea5RGBDataModule.val_num_workers = 4 11 | S3DISArea5RGBDataModule.collation_type = "collate_minkowski" 12 | S3DISArea5RGBDataModule.train_transforms = None 13 | S3DISArea5RGBDataModule.eval_transforms = [ 14 | "DimensionlessCoordinates", 15 | "NormalizeColor", 16 | ] 17 | 18 | # Augmentation 19 | DimensionlessCoordinates.voxel_size = 0.04 20 | 21 | # Evaluation 22 | eval.data_module_name = "S3DISArea5RGBDataModule" -------------------------------------------------------------------------------- /config/s3dis/eval_fpt.gin: -------------------------------------------------------------------------------- 1 | include "./config/s3dis/eval_res16unet34c.gin" 2 | 3 | # Model 4 | eval.model_name = "FastPointTransformer" 5 | FastPointTransformer.in_channels = %in_channels 6 | FastPointTransformer.out_channels = %out_channels 
-------------------------------------------------------------------------------- /config/s3dis/eval_res16unet34c.gin: -------------------------------------------------------------------------------- 1 | include "./config/s3dis/eval_default.gin" 2 | 3 | # Model 4 | eval.model_name = "Res16UNet34C" 5 | Res16UNet34C.in_channels = %in_channels 6 | Res16UNet34C.out_channels = %out_channels -------------------------------------------------------------------------------- /config/s3dis/train_default.gin: -------------------------------------------------------------------------------- 1 | include "./config/default.gin" 2 | 3 | # Constants 4 | in_channels = 3 5 | out_channels = 13 6 | 7 | # Data module 8 | S3DISArea5RGBDataModule.data_root = "/root/data/s3dis" # you need to modify this according to your data. 9 | S3DISArea5RGBDataModule.train_batch_size = 8 10 | S3DISArea5RGBDataModule.val_batch_size = 1 11 | S3DISArea5RGBDataModule.train_num_workers = 8 12 | S3DISArea5RGBDataModule.val_num_workers = 4 13 | S3DISArea5RGBDataModule.collation_type = "collate_minkowski" 14 | S3DISArea5RGBDataModule.train_transforms = [ 15 | "DimensionlessCoordinates", 16 | "RandomRotation", 17 | "RandomCrop", 18 | "RandomAffine", # affine to rotate the rectangular crop 19 | "CoordinateDropout", 20 | "ChromaticTranslation", 21 | "ChromaticJitter", 22 | "RandomHorizontalFlip", 23 | "RandomTranslation", 24 | "ElasticDistortion", 25 | "NormalizeColor", 26 | ] 27 | S3DISArea5RGBDataModule.eval_transforms = [ 28 | "DimensionlessCoordinates", 29 | "NormalizeColor", 30 | ] 31 | 32 | # Augmentation 33 | DimensionlessCoordinates.voxel_size = 0.04 34 | RandomCrop.x = 100 35 | RandomCrop.y = 100 36 | RandomCrop.z = 100 37 | RandomCrop.min_cardinality = 100 38 | RandomCrop.max_retries = 40 39 | RandomHorizontalFlip.upright_axis = "z" 40 | RandomAffine.upright_axis = "z" 41 | RandomAffine.application_ratio = 0.7 42 | ChromaticJitter.std = 0.01 43 | ChromaticJitter.application_ratio = 0.7 44 | 
ElasticDistortion.distortion_params = [(4, 16)] 45 | ElasticDistortion.application_ratio = 0.7 46 | 47 | # Pytorch lightning module 48 | LitSegmentationModuleBase.num_classes = %out_channels 49 | LitSegmentationModuleBase.lr = 0.1 50 | LitSegmentationModuleBase.momentum = 0.9 51 | LitSegmentationModuleBase.weight_decay = 1e-4 52 | LitSegmentationModuleBase.warmup_steps_ratio = 0.01 53 | LitSegmentationModuleBase.best_metric_type = "maximize" 54 | 55 | # Training 56 | train.data_module_name = "S3DISArea5RGBDataModule" 57 | train.gpus = 1 58 | train.log_every_n_steps = 10 59 | train.check_val_every_n_epoch = 1 60 | train.refresh_rate_per_second = 1 61 | train.best_metric = "val_mIoU" 62 | train.max_epoch = None 63 | train.max_step = 40000 64 | 65 | # Logging 66 | logged_hparams.keys = [ 67 | "train.model_name", 68 | "train.data_module_name", 69 | "DimensionlessCoordinates.voxel_size", 70 | "S3DISArea5RGBDataModule.train_transforms", 71 | "S3DISArea5RGBDataModule.eval_transforms", 72 | "S3DISArea5RGBDataModule.train_batch_size", 73 | "S3DISArea5RGBDataModule.val_batch_size", 74 | "S3DISArea5RGBDataModule.train_num_workers", 75 | "S3DISArea5RGBDataModule.val_num_workers", 76 | "RandomCrop.x", 77 | "RandomHorizontalFlip.upright_axis", 78 | "RandomAffine.upright_axis", 79 | "RandomAffine.application_ratio", 80 | "ChromaticJitter.std", 81 | "ChromaticJitter.application_ratio", 82 | "ElasticDistortion.distortion_params", 83 | "ElasticDistortion.application_ratio", 84 | "LitSegmentationModuleBase.lr", 85 | "LitSegmentationModuleBase.momentum", 86 | "LitSegmentationModuleBase.weight_decay", 87 | "LitSegmentationModuleBase.warmup_steps_ratio", 88 | "train.max_step", 89 | ] -------------------------------------------------------------------------------- /config/s3dis/train_fpt.gin: -------------------------------------------------------------------------------- 1 | # The code should be run on a GPU with at least 80GB memory (e.g., A100-80GB). 
2 | include "./config/s3dis/train_res16unet34c.gin" 3 | 4 | # Model 5 | train.model_name = "FastPointTransformer" 6 | FastPointTransformer.in_channels = %in_channels 7 | FastPointTransformer.out_channels = %out_channels -------------------------------------------------------------------------------- /config/s3dis/train_res16unet34c.gin: -------------------------------------------------------------------------------- 1 | include "./config/s3dis/train_default.gin" 2 | 3 | # Model 4 | train.lightning_module_name = "LitSegMinkowskiModule" 5 | train.model_name = "Res16UNet34C" 6 | Res16UNetBase.in_channels = %in_channels 7 | Res16UNetBase.out_channels = %out_channels -------------------------------------------------------------------------------- /config/scannet/eval_default.gin: -------------------------------------------------------------------------------- 1 | # Constants 2 | in_channels = 3 3 | out_channels = 20 4 | 5 | # Data module 6 | ScanNetRGBDataModule.data_root = "/root/data/scannetv2" # you need to modify this according to your data. 
7 | ScanNetRGBDataModule.train_batch_size = None 8 | ScanNetRGBDataModule.val_batch_size = 1 9 | ScanNetRGBDataModule.train_num_workers = None 10 | ScanNetRGBDataModule.val_num_workers = 4 11 | ScanNetRGBDataModule.collation_type = "collate_minkowski" 12 | ScanNetRGBDataModule.train_transforms = None 13 | ScanNetRGBDataModule.eval_transforms = [ 14 | "DimensionlessCoordinates", 15 | "NormalizeColor", 16 | ] 17 | 18 | # Augmentation 19 | DimensionlessCoordinates.voxel_size = 0.02 20 | 21 | # Evaluation 22 | eval.data_module_name = "ScanNetRGBDataModule" -------------------------------------------------------------------------------- /config/scannet/eval_fpt.gin: -------------------------------------------------------------------------------- 1 | include "./config/scannet/eval_res16unet34c.gin" 2 | 3 | # Model 4 | eval.model_name = "FastPointTransformer" 5 | FastPointTransformer.in_channels = %in_channels 6 | FastPointTransformer.out_channels = %out_channels -------------------------------------------------------------------------------- /config/scannet/eval_res16unet34c.gin: -------------------------------------------------------------------------------- 1 | include "./config/scannet/eval_default.gin" 2 | 3 | # Model 4 | eval.model_name = "Res16UNet34C" 5 | Res16UNet34C.in_channels = %in_channels 6 | Res16UNet34C.out_channels = %out_channels -------------------------------------------------------------------------------- /config/scannet/train_default.gin: -------------------------------------------------------------------------------- 1 | include "./config/default.gin" 2 | 3 | # Constants 4 | in_channels = 3 5 | out_channels = 20 6 | 7 | # Data module 8 | ScanNetRGBDataModule.data_root = "/root/data/scannetv2" # you need to modify this according to your data. 
9 | ScanNetRGBDataModule.train_batch_size = 8 10 | ScanNetRGBDataModule.val_batch_size = 2 11 | ScanNetRGBDataModule.train_num_workers = 8 12 | ScanNetRGBDataModule.val_num_workers = 4 13 | ScanNetRGBDataModule.collation_type = "collate_minkowski" 14 | ScanNetRGBDataModule.train_transforms = [ 15 | "DimensionlessCoordinates", 16 | "RandomRotation", 17 | "RandomCrop", 18 | "RandomAffine", # affine to rotate the rectangular crop 19 | "CoordinateDropout", 20 | "ChromaticTranslation", 21 | "ChromaticJitter", 22 | "RandomHorizontalFlip", 23 | "RandomTranslation", 24 | "ElasticDistortion", 25 | "NormalizeColor", 26 | ] 27 | ScanNetRGBDataModule.eval_transforms = [ 28 | "DimensionlessCoordinates", 29 | "NormalizeColor", 30 | ] 31 | 32 | # Augmentation 33 | DimensionlessCoordinates.voxel_size = 0.02 34 | RandomCrop.x = 225 35 | RandomCrop.y = 225 36 | RandomCrop.z = 225 37 | RandomHorizontalFlip.upright_axis = "z" 38 | RandomAffine.upright_axis = "z" 39 | RandomAffine.application_ratio = 0.7 40 | ChromaticJitter.std = 0.01 41 | ChromaticJitter.application_ratio = 0.7 42 | ElasticDistortion.distortion_params = [(4, 16)] 43 | ElasticDistortion.application_ratio = 0.7 44 | 45 | # Pytorch lightning module 46 | LitSegmentationModuleBase.num_classes = %out_channels 47 | LitSegmentationModuleBase.lr = 0.1 48 | LitSegmentationModuleBase.momentum = 0.9 49 | LitSegmentationModuleBase.weight_decay = 1e-4 50 | LitSegmentationModuleBase.warmup_steps_ratio = 0.1 51 | LitSegmentationModuleBase.best_metric_type = "maximize" 52 | 53 | # Training 54 | train.data_module_name = "ScanNetRGBDataModule" 55 | train.gpus = 1 56 | train.log_every_n_steps = 10 57 | train.check_val_every_n_epoch = 1 58 | train.refresh_rate_per_second = 1 59 | train.best_metric = "val_mIoU" 60 | train.max_epoch = None 61 | train.max_step = 100000 62 | 63 | # Logging 64 | logged_hparams.keys = [ 65 | "train.model_name", 66 | "train.data_module_name", 67 | "DimensionlessCoordinates.voxel_size", 68 | 
"ScanNetRGBDataModule.train_transforms", 69 | "ScanNetRGBDataModule.eval_transforms", 70 | "ScanNetRGBDataModule.train_batch_size", 71 | "ScanNetRGBDataModule.val_batch_size", 72 | "ScanNetRGBDataModule.train_num_workers", 73 | "ScanNetRGBDataModule.val_num_workers", 74 | "RandomCrop.x", 75 | "RandomHorizontalFlip.upright_axis", 76 | "RandomAffine.upright_axis", 77 | "RandomAffine.application_ratio", 78 | "ChromaticJitter.std", 79 | "ChromaticJitter.application_ratio", 80 | "ElasticDistortion.distortion_params", 81 | "ElasticDistortion.application_ratio", 82 | "LitSegmentationModuleBase.lr", 83 | "LitSegmentationModuleBase.momentum", 84 | "LitSegmentationModuleBase.weight_decay", 85 | "LitSegmentationModuleBase.warmup_steps_ratio", 86 | "train.max_step", 87 | ] -------------------------------------------------------------------------------- /config/scannet/train_fpt.gin: -------------------------------------------------------------------------------- 1 | # The code should be run on a GPU with at least 80GB memory (e.g., A100-80GB). 2 | include "./config/scannet/train_res16unet34c.gin" 3 | 4 | # Model 5 | train.model_name = "FastPointTransformer" 6 | FastPointTransformer.in_channels = %in_channels 7 | FastPointTransformer.out_channels = %out_channels -------------------------------------------------------------------------------- /config/scannet/train_res16unet34c.gin: -------------------------------------------------------------------------------- 1 | # The code should be run on a GPU with at least 24GB memory (e.g., A5000). 
2 | include "./config/scannet/train_default.gin" 3 | 4 | # Model 5 | train.lightning_module_name = "LitSegMinkowskiModule" 6 | train.model_name = "Res16UNet34C" 7 | Res16UNet34C.in_channels = %in_channels 8 | Res16UNet34C.out_channels = %out_channels -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: fpt 2 | channels: 3 | - anaconda 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - blas=1.0=openblas 9 | - ca-certificates=2022.07.19=h06a4308_0 10 | - certifi=2022.6.15=py38h06a4308_0 11 | - ld_impl_linux-64=2.38=h1181459_1 12 | - libffi=3.4.2=h6a678d5_6 13 | - libgcc-ng=11.2.0=h1234567_1 14 | - libgfortran-ng=11.2.0=h00389a5_1 15 | - libgfortran5=11.2.0=h1234567_1 16 | - libgomp=11.2.0=h1234567_1 17 | - libopenblas=0.3.20=h043d6bf_1 18 | - libstdcxx-ng=11.2.0=h1234567_1 19 | - ncurses=6.4=h6a678d5_0 20 | - nomkl=3.0=0 21 | - openblas-devel=0.3.20=h06a4308_1 22 | - openssl=1.1.1s=h7f8727e_0 23 | - pip=22.3.1=py38h06a4308_0 24 | - python=3.8.16=h7a1cb2a_2 25 | - readline=8.2=h5eee18b_0 26 | - setuptools=65.6.3=py38h06a4308_0 27 | - sqlite=3.40.1=h5082296_0 28 | - tk=8.6.12=h1ccaba5_0 29 | - wheel=0.37.1=pyhd3eb1b0_0 30 | - xz=5.2.10=h5eee18b_1 31 | - zlib=1.2.13=h5eee18b_0 32 | - pip: 33 | - absl-py==1.4.0 34 | - addict==2.4.0 35 | - aiohttp==3.8.3 36 | - aiosignal==1.3.1 37 | - appdirs==1.4.4 38 | - asttokens==2.2.1 39 | - async-timeout==4.0.2 40 | - attrs==22.2.0 41 | - backcall==0.2.0 42 | - cachetools==5.3.0 43 | - charset-normalizer==2.1.1 44 | - click==8.1.3 45 | - comm==0.1.2 46 | - configargparse==1.5.3 47 | - contourpy==1.0.7 48 | - cuda-sparse-ops==0.1.0 49 | - cycler==0.11.0 50 | - dash==2.8.1 51 | - dash-core-components==2.0.0 52 | - dash-html-components==2.0.0 53 | - dash-table==5.0.0 54 | - debugpy==1.6.6 55 | - decorator==5.1.1 56 | - docker-pycreds==0.4.0 57 | - 
einops==0.6.0 58 | - executing==1.2.0 59 | - fastjsonschema==2.16.2 60 | - fire==0.5.0 61 | - flask==2.2.2 62 | - fonttools==4.38.0 63 | - frozenlist==1.3.3 64 | - fsspec==2023.1.0 65 | - gin-config==0.5.0 66 | - gitdb==4.0.10 67 | - gitpython==3.1.30 68 | - google-auth==2.16.0 69 | - google-auth-oauthlib==0.4.6 70 | - grpcio==1.51.1 71 | - h5py==3.8.0 72 | - idna==3.4 73 | - importlib-metadata==6.0.0 74 | - importlib-resources==5.10.2 75 | - ipykernel==6.21.0 76 | - ipython==8.9.0 77 | - ipywidgets==8.0.4 78 | - itsdangerous==2.1.2 79 | - jedi==0.18.2 80 | - jinja2==3.1.2 81 | - joblib==1.2.0 82 | - jsonschema==4.17.3 83 | - jupyter-client==8.0.2 84 | - jupyter-core==5.2.0 85 | - jupyterlab-widgets==3.0.5 86 | - kiwisolver==1.4.4 87 | - lightning-bolts==0.6.0.post1 88 | - lightning-utilities==0.3.0 89 | - markdown==3.4.1 90 | - markdown-it-py==2.1.0 91 | - markupsafe==2.1.2 92 | - matplotlib==3.6.3 93 | - matplotlib-inline==0.1.6 94 | - mdurl==0.1.2 95 | - minkowskiengine==0.5.4 96 | - multidict==6.0.4 97 | - nbformat==5.5.0 98 | - numpy==1.24.1 99 | - oauthlib==3.2.2 100 | - open3d==0.16.0 101 | - packaging==23.0 102 | - pandas==1.5.3 103 | - parso==0.8.3 104 | - pathtools==0.1.2 105 | - pexpect==4.8.0 106 | - pickleshare==0.7.5 107 | - pillow==9.4.0 108 | - pkgutil-resolve-name==1.3.10 109 | - platformdirs==2.6.2 110 | - plotly==5.13.0 111 | - plyfile==0.7.4 112 | - prompt-toolkit==3.0.36 113 | - protobuf==3.20.3 114 | - psutil==5.9.4 115 | - ptyprocess==0.7.0 116 | - pure-eval==0.2.2 117 | - pyasn1==0.4.8 118 | - pyasn1-modules==0.2.8 119 | - pygments==2.14.0 120 | - pyparsing==3.0.9 121 | - pyquaternion==0.9.9 122 | - pyrsistent==0.19.3 123 | - python-dateutil==2.8.2 124 | - pytorch-lightning==1.8.2 125 | - pytz==2022.7.1 126 | - pyyaml==6.0 127 | - pyzmq==25.0.0 128 | - requests==2.28.2 129 | - requests-oauthlib==1.3.1 130 | - rich==13.3.1 131 | - rsa==4.9 132 | - scikit-learn==1.2.1 133 | - scipy==1.10.0 134 | - sentry-sdk==1.14.0 135 | - setproctitle==1.3.2 
136 | - six==1.16.0 137 | - smmap==5.0.0 138 | - stack-data==0.6.2 139 | - tenacity==8.1.0 140 | - tensorboard==2.11.2 141 | - tensorboard-data-server==0.6.1 142 | - tensorboard-plugin-wit==1.8.1 143 | - termcolor==2.2.0 144 | - threadpoolctl==3.1.0 145 | - torch==1.12.1+cu113 146 | - torch-scatter==2.1.0+pt112cu113 147 | - torchmetrics==0.11.0 148 | - torchvision==0.13.1+cu113 149 | - tornado==6.2 150 | - tqdm==4.64.1 151 | - traitlets==5.9.0 152 | - typing-extensions==4.4.0 153 | - urllib3==1.26.14 154 | - wandb==0.13.9 155 | - wcwidth==0.2.6 156 | - werkzeug==2.2.2 157 | - widgetsnbextension==4.0.5 158 | - yarl==1.8.2 159 | - zipp==3.12.0 160 | prefix: /opt/anaconda3/envs/fpt 161 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import argparse 3 | 4 | import gin 5 | import torch 6 | import torchmetrics 7 | import MinkowskiEngine as ME 8 | import numpy as np 9 | from rich.console import Console 10 | from rich.progress import track 11 | from rich.table import Table 12 | 13 | from src.models import get_model 14 | from src.data import get_data_module 15 | from src.utils.metric import per_class_iou 16 | import src.data.transforms as T 17 | 18 | 19 | def print_results(classnames, confusion_matrix): 20 | # results 21 | ious = per_class_iou(confusion_matrix) * 100 22 | accs = confusion_matrix.diagonal() / confusion_matrix.sum(1) * 100 23 | miou = np.nanmean(ious) 24 | macc = np.nanmean(accs) 25 | 26 | # print results 27 | console = Console() 28 | table = Table(show_header=True, header_style="bold") 29 | 30 | columns = ["mAcc", "mIoU"] 31 | num_classes = len(classnames) 32 | for i in range(num_classes): 33 | columns.append(classnames[i]) 34 | for col in columns: 35 | table.add_column(col) 36 | ious = ious.tolist() 37 | row = [macc, miou, *ious] 38 | table.add_row(*[f"{x:.2f}" for x in row]) 39 | console.print(table) 40 | 41 | 42 | 
def get_rotation_matrices(num_rotations=8): 43 | angles = [2 * np.pi / num_rotations * i for i in range(num_rotations)] 44 | rot_matrices = [] 45 | for angle in angles: 46 | rot_matrices.append( 47 | torch.Tensor([ 48 | [np.cos(angle), -np.sin(angle), 0, 0], 49 | [np.sin(angle), np.cos(angle), 0, 0], 50 | [0, 0, 1, 0], 51 | [0, 0, 0, 1] 52 | ]) 53 | ) 54 | return rot_matrices 55 | 56 | 57 | @torch.no_grad() 58 | def infer(model, batch, device): 59 | in_data = ME.TensorField( 60 | features=batch["features"], 61 | coordinates=batch["coordinates"], 62 | quantization_mode=model.QMODE, 63 | device=device 64 | ) 65 | pred = model(in_data).argmax(dim=1).cpu() 66 | return pred 67 | 68 | 69 | @torch.no_grad() 70 | def infer_with_rotation_average(model, batch, device): 71 | rotation_matrices = get_rotation_matrices() 72 | pred = torch.zeros((len(batch["labels"]), model.out_channels), dtype=torch.float32) 73 | for M in rotation_matrices: 74 | batch_, coords_ = torch.split(batch["coordinates"], [1, 3], dim=1) 75 | coords = T.homogeneous_coords(coords_) @ M 76 | coords = torch.cat([batch_, coords[:, :3].float()], dim=1) 77 | 78 | in_data = ME.TensorField( 79 | features=batch["features"], 80 | coordinates=coords, 81 | quantization_mode=model.QMODE, 82 | device=device 83 | ) 84 | pred += model(in_data).cpu() 85 | 86 | gc.collect() 87 | torch.cuda.empty_cache() 88 | 89 | pred = pred.argmax(dim=1) 90 | return pred 91 | 92 | 93 | @gin.configurable 94 | def eval( 95 | checkpoint_path, 96 | model_name, 97 | data_module_name, 98 | use_rotation_average, 99 | ): 100 | assert torch.cuda.is_available() 101 | device = torch.device("cuda") 102 | 103 | ckpt = torch.load(checkpoint_path) 104 | 105 | def remove_prefix(k, prefix): 106 | return k[len(prefix) :] if k.startswith(prefix) else k 107 | 108 | state_dict = {remove_prefix(k, "model."): v for k, v in ckpt["state_dict"].items()} 109 | model = get_model(model_name)() 110 | model.load_state_dict(state_dict) 111 | model = model.to(device) 112 
| model.eval() 113 | 114 | data_module = get_data_module(data_module_name)() 115 | data_module.setup("test") 116 | val_loader = data_module.val_dataloader() 117 | 118 | confmat = torchmetrics.ConfusionMatrix( 119 | num_classes=data_module.dset_val.NUM_CLASSES, compute_on_step=False 120 | ) 121 | infer_fn = infer_with_rotation_average if use_rotation_average else infer 122 | with torch.inference_mode(mode=True): 123 | for batch in track(val_loader): 124 | pred = infer_fn(model, batch, device) 125 | mask = batch["labels"] != data_module.dset_val.ignore_label 126 | confmat(pred[mask], batch["labels"][mask]) 127 | torch.cuda.empty_cache() 128 | confmat = confmat.compute().numpy() 129 | 130 | cnames = data_module.dset_val.get_classnames() 131 | print_results(cnames, confmat) 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = argparse.ArgumentParser() 136 | parser.add_argument("config", type=str) 137 | parser.add_argument("ckpt_path", type=str) 138 | parser.add_argument("-r", "--use_rotation_average", action="store_true") 139 | parser.add_argument("-v", "--voxel_size", type=float, default=None) # override voxel_size 140 | args = parser.parse_args() 141 | 142 | gin.parse_config_file(args.config) 143 | if args.voxel_size is not None: 144 | gin.bind_parameter("DimensionlessCoordinates.voxel_size", args.voxel_size) 145 | 146 | eval(args.ckpt_path, use_rotation_average=args.use_rotation_average) -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SERVER=${2:-local} 4 | 5 | if [[ $SERVER = *local* ]]; then 6 | echo "[FPT INFO] Running on Local: You should manually load modules..." 7 | conda init zsh 8 | source /opt/anaconda3/etc/profile.d/conda.sh # you may need to modify the conda path. 9 | export CUDA_HOME=/usr/local/cuda-11.1 10 | else 11 | echo "[FPT INFO] Running on Server..." 
12 | conda init bash 13 | source ~/anaconda3/etc/profile.d/conda.sh 14 | 15 | module purge 16 | module load autotools 17 | module load prun/1.3 18 | module load gnu8/8.3.0 19 | module load singularity 20 | 21 | module load cuDNN/cuda/11.1/8.0.4.30 22 | module load cuda/11.1 23 | module load nccl/cuda/11.1/2.8.3 24 | 25 | echo "[FPT INFO] Loaded all modules." 26 | fi; 27 | 28 | ENVS=$(conda env list | awk '{print $1}' ) 29 | 30 | if [[ $ENVS = *"$1"* ]]; then 31 | echo "[FPT INFO] \"$1\" already exists. Pass the installation." 32 | else 33 | echo "[FPT INFO] Creating $1..." 34 | conda create -n $1 python=3.8 -y 35 | conda activate "$1" 36 | echo "[FPT INFO] Done." 37 | 38 | echo "[FPT INFO] Installing OpenBLAS and PyTorch..." 39 | conda install pytorch=1.10.0 torchvision cudatoolkit=11.1 -c pytorch -c nvidia -y 40 | conda install numpy -y 41 | conda install openblas-devel -c anaconda -y 42 | echo "[FPT INFO] Done." 43 | 44 | echo "[FPT INFO] Installing other dependencies..." 45 | conda install -c anaconda pandas scipy h5py scikit-learn -y 46 | conda install -c conda-forge plyfile pytorch-lightning torchmetrics wandb wrapt gin-config rich einops -y 47 | conda install -c open3d-admin -c conda-forge open3d -y 48 | pip install lightning-bolts 49 | pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu111.html 50 | echo "[FPT INFO] Done." 51 | 52 | echo "[FPT INFO] Installing MinkowskiEngine..." 53 | cd thirdparty/MinkowskiEngine 54 | python setup.py install --blas_include_dirs=${CONDA_PREFIX}/include --blas=openblas --force_cuda 55 | cd ../.. 56 | echo "[FPT INFO] Done." 57 | 58 | echo "[FPT INFO] Installing cuda_ops..." 59 | cd src/cuda_ops 60 | pip3 install . 61 | cd ../.. 62 | echo "[FPT INFO] Done." 63 | 64 | TORCH="$(python -c "import torch; print(torch.__version__)")" 65 | ME="$(python -c "import MinkowskiEngine as ME; print(ME.__version__)")" 66 | 67 | echo "[FPT INFO] Finished the installation!" 
68 | echo "[FPT INFO] ========== Configurations ==========" 69 | echo "[FPT INFO] PyTorch version: $TORCH" 70 | echo "[FPT INFO] MinkowskiEngine version: $ME" 71 | echo "[FPT INFO] ====================================" 72 | fi; -------------------------------------------------------------------------------- /src/cscore/calculate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import argparse 4 | 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | N = 41 # the number of rigid transforms 9 | N_t = 26 # translation 10 | N_r = 15 # rotation 11 | M = 312 # the number of validation scenes 12 | 13 | 14 | def main(args): 15 | ref_dir = osp.join(args.dir, "reference") 16 | ref_fnames = sorted(os.listdir(ref_dir)) 17 | assert len(ref_fnames) == M 18 | pred_dirs = [osp.join(args.dir, f"transform{i}") for i in range(N)] 19 | 20 | print(">>> Loading references...") 21 | # load references 22 | refs = [np.load(osp.join(ref_dir, fname)) for fname in ref_fnames] 23 | print(" done!") 24 | 25 | # translation 26 | print(">>> Calculating point-wise CScores for translation...") 27 | score_trans = [] # pointwise scores 28 | for i, (fname, ref) in enumerate(tqdm(zip(ref_fnames, refs))): 29 | score_tran = np.zeros_like(ref) 30 | for j in range(N_t): 31 | pred = np.load(osp.join(pred_dirs[j], fname)) 32 | score_tran[np.where(ref == pred)] += 1 33 | score_tran = score_tran / N_t 34 | score_tran[np.where(ref == 255)] = -1 35 | score_trans.append(score_tran) 36 | print(" done!") 37 | 38 | # rotation 39 | print(">>> Calculating point-wise CScores for rotation...") 40 | score_rots = [] # pointwise scores 41 | for i, (fname, ref) in enumerate(tqdm(zip(ref_fnames, refs))): 42 | score_rot = np.zeros_like(ref) 43 | for j in range(N_r): 44 | pred = np.load(osp.join(pred_dirs[N_t + j], fname)) 45 | score_rot[np.where(ref == pred)] += 1 46 | score_rot = score_rot / N_r 47 | score_rot[np.where(ref == 255)] = -1 48 | 
score_rots.append(score_rot) 49 | print(" done!") 50 | 51 | # full 52 | print(">>> Calculating point-wise CScores for full rigid transformations...") 53 | score_fulls = [] 54 | for fname, score_tran, score_rot in tqdm(zip(ref_fnames, score_trans, score_rots)): 55 | score_full = (N_t*score_tran + N_r*score_rot) / N 56 | score_fulls.append(score_full) 57 | print(" done!") 58 | 59 | # final calculation 60 | cloudwise_score_trans = [np.mean(score_tran[np.where(score_tran > -0.5)]) for score_tran in score_trans] 61 | cloudwise_score_rots = [np.mean(score_rot[np.where(score_rot > -0.5)]) for score_rot in score_rots] 62 | cloudwise_score_fulls = [np.mean(score_full[np.where(score_full > -0.5)]) for score_full in score_fulls] 63 | 64 | if args.save: 65 | # save pointwise scores for visualization 66 | output_dir = osp.join(args.dir, "pointwise_scores") 67 | os.makedirs(output_dir, exist_ok=True) 68 | # save for visualization 69 | print(">>> Saving point-wise CScores for full rigid transformations...") 70 | for fname, cloudwise_score_full, score_full in tqdm(zip(ref_fnames, cloudwise_score_fulls, score_fulls)): 71 | scene_id = fname.split('.')[0] 72 | np.save(osp.join(output_dir, f'{scene_id}-score={round(1000 * cloudwise_score_full):d}.npy'), score_full) 73 | print(" done!") 74 | 75 | # final logging 76 | total_score_tran = np.mean(cloudwise_score_trans) 77 | total_score_rot = np.mean(cloudwise_score_rots) 78 | total_score_full = np.mean(cloudwise_score_fulls) 79 | print(">>> Results:") 80 | print(f" Rotation: {100 * total_score_rot:.1f}") 81 | print(f" Translation: {100 * total_score_tran:.1f}") 82 | print(f" Full: {100 * total_score_full:.1f}") 83 | 84 | 85 | if __name__ == "__main__": 86 | parser = argparse.ArgumentParser() 87 | parser.add_argument("--dir", type=str) 88 | parser.add_argument("--save", action="store_true", default=False) # for point-wise cscore visualization. 
89 | args = parser.parse_args() 90 | 91 | main(args) -------------------------------------------------------------------------------- /src/cscore/prepare.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import argparse 4 | import os.path as osp 5 | 6 | import torch 7 | import numpy as np 8 | import pytorch_lightning as pl 9 | import MinkowskiEngine as ME 10 | from tqdm import tqdm 11 | 12 | from src.models import get_model 13 | from src.data.scannet_loader import ScanNetRGBDataset_ 14 | from src.data.transforms import NormalizeColor, homogeneous_coords 15 | from src.utils.misc import load_from_pl_state_dict 16 | 17 | 18 | class SimpleConfig: 19 | def __init__( 20 | self, 21 | scannet_path="/root/data/scannetv2", 22 | ignore_label=255, 23 | voxel_size=0.1, 24 | cache_data=False 25 | ): 26 | self.scannet_path = scannet_path 27 | self.voxel_size= voxel_size 28 | self.ignore_label = ignore_label 29 | self.cache_data = cache_data 30 | self.limit_numpoints = -1 31 | 32 | 33 | class TransformGenerator: 34 | def __init__(self, voxel_size=0.1, trans_granularity=3, rot_granularity=8): 35 | self.voxel_size = voxel_size 36 | trans_step = 1. / trans_granularity 37 | self.trans_list = [trans_step * i * voxel_size for i in range(trans_granularity)] # 0, 1, 2, 3, ..., granularity - 1 38 | rot_step = 1. 
/ rot_granularity 39 | self.rot_list = [rot_step * i * np.pi for i in range(1, 2*rot_granularity)] 40 | self.transform_list = [] 41 | self.num_trans = 0 42 | self.num_rot = 0 43 | 44 | def generate_transforms(self): 45 | self._cleanup() 46 | self.transform_list.extend(self._get_trans_mtx_list()) 47 | self.transform_list.extend(self._get_rot_mtx_list()) 48 | 49 | def _get_trans_mtx_list(self): 50 | mtx_list = [] 51 | for delta_x in self.trans_list: 52 | for delta_y in self.trans_list: 53 | for delta_z in self.trans_list: 54 | if delta_x == 0 and delta_y == 0 and delta_z == 0: 55 | continue 56 | T = np.eye(4) 57 | T[0, 3] = delta_x 58 | T[1, 3] = delta_y 59 | T[2, 3] = delta_z 60 | mtx_list.append(T) 61 | self.num_trans += 1 62 | return mtx_list 63 | 64 | def _get_rot_mtx_list(self): 65 | mtx_list = [] 66 | for theta in self.rot_list: 67 | T = np.array( 68 | [ 69 | [np.cos(theta), -np.sin(theta), 0, 0], 70 | [np.sin(theta), np.cos(theta), 0, 0], 71 | [0, 0, 1, 0], 72 | [0, 0, 0, 1] 73 | ] 74 | ) 75 | mtx_list.append(T) 76 | self.num_rot += 1 77 | return mtx_list 78 | 79 | def _cleanup(self): 80 | self.num_trans = 0 81 | self.num_rot = 0 82 | self.transform_list = [] 83 | 84 | 85 | def save_outputs(args, model, name, tmatrix): 86 | output_dir = osp.join(args.out_dir, args.model_name + f"-voxel={args.voxel_size}") 87 | if args.postfix is not None: 88 | output_dir = output_dir + args.postfix 89 | transform = NormalizeColor() 90 | dset = ScanNetRGBDataset_("val", args.scannet_path, transform) 91 | num_samples = len(dset) 92 | output_name = osp.join(output_dir, name) 93 | assert not osp.isdir(output_name) 94 | os.makedirs(output_name, exist_ok=True) 95 | 96 | print(f'>>> Saving predictions for {num_samples} val samples...') 97 | with torch.inference_mode(mode=True): 98 | for coords_, feats, labels, fname in dset: 99 | if tmatrix is not None: 100 | np.save(osp.join(output_name, "tmatrix.npy"), tmatrix) 101 | coords_ = homogeneous_coords(coords_) @ tmatrix.T 102 | coords = 
coords_[:, :3] 103 | else: 104 | coords = coords_ 105 | coords, feats = ME.utils.sparse_collate( 106 | [coords / args.voxel_size], 107 | [feats], 108 | dtype=torch.float32 109 | ) 110 | in_field = ME.TensorField( 111 | features=feats, 112 | coordinates=coords, 113 | quantization_mode=model.QMODE, 114 | device=device 115 | ) 116 | pred = model(in_field).argmax(dim=1, keepdim=False).cpu().numpy() 117 | assert len(pred) == len(labels) 118 | pred[np.where(labels.numpy() == 255)] = 255 # ignore labels 119 | scene_id = fname.split('/')[-1].split('.')[0] 120 | np.save(osp.join(output_name, f'{scene_id}.npy'), pred) 121 | gc.collect() 122 | torch.cuda.empty_cache() 123 | print(' done!') 124 | 125 | 126 | if __name__ == "__main__": 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument("-c", "--ckpt", type=str) 129 | parser.add_argument("-m", "--model_name", type=str, choices=["mink", "fpt"]) 130 | parser.add_argument("-v", "--voxel_size", type=float, choices=[0.1, 0.05, 0.02]) 131 | parser.add_argument("--scannet_path", type=str, default="/root/data/scannetv2") 132 | parser.add_argument("--out_dir", type=str, default="consistency_outputs") 133 | parser.add_argument("-p", "--postfix", type=str, default=None) 134 | args = parser.parse_args() 135 | 136 | assert torch.cuda.is_available() 137 | device = torch.device("cuda") 138 | print(f">>> Loading the checkpoint from {args.ckpt}...") 139 | ckpt = torch.load(args.ckpt) 140 | pl.seed_everything(7777) 141 | print(" done!") 142 | 143 | print(">>> Loading the model...") 144 | if args.model_name == "mink": 145 | model = get_model("Res16UNet34C")(ScanNetRGBDataset_.IN_CHANNELS, ScanNetRGBDataset_.NUM_CLASSES) 146 | else: 147 | model = get_model("FastPointTransformer")(ScanNetRGBDataset_.IN_CHANNELS, ScanNetRGBDataset_.NUM_CLASSES) 148 | model = load_from_pl_state_dict(model, ckpt["state_dict"]) 149 | model = model.to(device) 150 | model.eval() 151 | print(" done!") 152 | 153 | print(">>> Generating rigid 
transformations...") 154 | tgenerator = TransformGenerator(voxel_size=args.voxel_size) 155 | tgenerator.generate_transforms() 156 | print(f' {len(tgenerator.transform_list)} rigid transformations generated!') 157 | 158 | print(f">>> Evaluating the model...") 159 | save_outputs(args, model, "reference", None) 160 | for t_idx, tmatrix in enumerate(tqdm(tgenerator.transform_list)): 161 | print(tmatrix) 162 | save_outputs(args, model, f"transform{t_idx}", tmatrix) 163 | print(" done!") -------------------------------------------------------------------------------- /src/cuda_ops/functions/sparse_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | import cuda_sparse_ops 5 | 6 | 7 | class DotProduct(Function): 8 | @staticmethod 9 | def forward(ctx, query, pos_enc, out_F, kq_map): 10 | assert (query.is_contiguous() and pos_enc.is_contiguous() and out_F.is_contiguous()) 11 | ctx.m = kq_map.shape[1] 12 | _, ctx.h, ctx.c = query.shape 13 | ctx.kkk = pos_enc.shape[0] 14 | ctx.save_for_backward(query, pos_enc, kq_map) 15 | cuda_sparse_ops.dot_product_forward(ctx.m, ctx.h, ctx.kkk, ctx.c, query, pos_enc, 16 | out_F, kq_map) 17 | return out_F 18 | 19 | @staticmethod 20 | def backward(ctx, grad_out_F): 21 | query, pos_enc, kq_map = ctx.saved_tensors 22 | grad_query = torch.zeros_like(query) 23 | grad_pos = torch.zeros_like(pos_enc) 24 | cuda_sparse_ops.dot_product_backward(ctx.m, ctx.h, ctx.kkk, ctx.c, query, pos_enc, 25 | kq_map, grad_query, grad_pos, grad_out_F) 26 | return grad_query, grad_pos, None, None 27 | 28 | dot_product_cuda = DotProduct.apply 29 | 30 | 31 | class ScalarAttention(Function): 32 | @staticmethod 33 | def forward(ctx, weight, value, out_F, kq_indices): 34 | assert (weight.is_contiguous() and value.is_contiguous() and out_F.is_contiguous()) 35 | ctx.m = kq_indices.shape[1] 36 | _, ctx.h, ctx.c = value.shape 37 | ctx.save_for_backward(weight, value, 
kq_indices) 38 | cuda_sparse_ops.scalar_attention_forward(ctx.m, ctx.h, ctx.c, weight, value, out_F, 39 | kq_indices) 40 | return out_F 41 | 42 | @staticmethod 43 | def backward(ctx, grad_out_F): 44 | weight, value, kq_indices = ctx.saved_tensors 45 | grad_weight = torch.zeros_like(weight) 46 | grad_value = torch.zeros_like(value) 47 | cuda_sparse_ops.scalar_attention_backward(ctx.m, ctx.h, ctx.c, weight, value, 48 | kq_indices, grad_weight, grad_value, 49 | grad_out_F) 50 | return grad_weight, grad_value, None, None 51 | 52 | scalar_attention_cuda = ScalarAttention.apply 53 | -------------------------------------------------------------------------------- /src/cuda_ops/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='cuda_sparse_ops', 6 | author='Chunghyun Park and Yoonwoo Jeong', 7 | version="0.1.0", 8 | ext_modules=[ 9 | CUDAExtension('cuda_sparse_ops', [ 10 | 'src/cuda_ops_api.cpp', 11 | 'src/dot_product/dot_product.cpp', 12 | 'src/dot_product/dot_product_kernel.cu', 13 | 'src/scalar_attention/scalar_attention.cpp', 14 | 'src/scalar_attention/scalar_attention_kernel.cu', 15 | ], 16 | extra_compile_args={ 17 | 'cxx': ['-g'], 18 | 'nvcc': ['-O2'] 19 | }) 20 | ], 21 | cmdclass={'build_ext': BuildExtension}) 22 | -------------------------------------------------------------------------------- /src/cuda_ops/src/cuda_ops_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "dot_product/dot_product_kernel.h" 5 | #include "scalar_attention/scalar_attention_kernel.h" 6 | 7 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 8 | m.def("dot_product_forward", &dot_product_forward, "dot_product_forward"); 9 | m.def("dot_product_backward", &dot_product_backward, "dot_product_backward"); 10 | m.def("scalar_attention_forward", 
&scalar_attention_forward, "scalar_attention_forward"); 11 | m.def("scalar_attention_backward", &scalar_attention_backward, "scalar_attention_backward"); 12 | } -------------------------------------------------------------------------------- /src/cuda_ops/src/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | #include 6 | 7 | #define TOTAL_THREADS 1024 8 | #define THREADS_PER_BLOCK 256 9 | #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) 10 | 11 | inline int opt_n_threads(int work_size) { 12 | const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0); 13 | return std::max(std::min(1 << pow_2, TOTAL_THREADS), 1); 14 | } 15 | 16 | inline dim3 opt_block_config(int x, int y) { 17 | const int x_threads = opt_n_threads(x); 18 | const int y_threads = std::max(std::min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 19 | dim3 block_config(x_threads, y_threads, 1); 20 | return block_config; 21 | } 22 | 23 | #endif 24 | -------------------------------------------------------------------------------- /src/cuda_ops/src/dot_product/dot_product.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "dot_product_kernel.h" 6 | 7 | void dot_product_forward( 8 | int m, int h, int kkk, int c, AT query_tensor, AT pos_tensor, AT out_F_tensor, AT kq_map_tensor 9 | ) 10 | { 11 | const float* query = query_tensor.data_ptr<float>(); 12 | const float* pos = pos_tensor.data_ptr<float>(); 13 | float* out_F = out_F_tensor.data_ptr<float>(); 14 | const int* kq_map = kq_map_tensor.data_ptr<int>(); 15 | 16 | dot_product_forward_launcher( 17 | m, h, kkk, c, query, pos, out_F, kq_map 18 | ); 19 | } 20 | 21 | void dot_product_backward( 22 | int m, int h, int kkk, int c, AT query_tensor, AT pos_tensor, AT kq_map_tensor, 23 | AT grad_query_tensor, AT grad_pos_tensor, AT grad_out_F_tensor 24 | ) 25 | { 26 | 
const float* query = query_tensor.data_ptr<float>(); 27 | const float* pos = pos_tensor.data_ptr<float>(); 28 | const int* kq_map = kq_map_tensor.data_ptr<int>(); 29 | 30 | float* grad_query = grad_query_tensor.data_ptr<float>(); 31 | float* grad_pos = grad_pos_tensor.data_ptr<float>(); 32 | const float* grad_out_F = grad_out_F_tensor.data_ptr<float>(); 33 | 34 | dot_product_backward_launcher( 35 | m, h, kkk, c, query, pos, kq_map, 36 | grad_query, grad_pos, grad_out_F 37 | ); 38 | } -------------------------------------------------------------------------------- /src/cuda_ops/src/dot_product/dot_product_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "dot_product_kernel.h" 3 | 4 | 5 | __global__ void dot_product_forward_kernel( 6 | int m, int h, int kkk, int c, const float* query, const float* pos, float* out_F, const int* kq_map 7 | ) 8 | { 9 | // m: # of total mappings 10 | // h: # of attention heads 11 | // kkk: # of keys (kernel volume) 12 | // c: # of attention channels 13 | 14 | int index = blockIdx.x * blockDim.x + threadIdx.x; 15 | if (index >= m * h) return; 16 | 17 | int map_idx = index / h; 18 | int head_idx = index % h; 19 | 20 | int query_idx_ = kq_map[m + map_idx]; // kq_map[1][map_idx] 21 | int kernel_idx = kq_map[map_idx] % kkk; 22 | 23 | for(int i = 0; i < c; i++){ 24 | 25 | int query_idx = query_idx_ * h * c + head_idx * c + i; 26 | int pos_idx = kernel_idx * h * c + head_idx * c + i; 27 | 28 | out_F[index] += query[query_idx] * pos[pos_idx]; 29 | } 30 | } 31 | 32 | __global__ void dot_product_backward_kernel( 33 | int m, int h, int kkk, int c, const float* query, const float* pos, const int* kq_map, 34 | float* grad_query, float* grad_pos, const float* grad_out_F 35 | ) 36 | { 37 | // m: # of total mappings 38 | // h: # of attention heads 39 | // kkk: # of keys (kernel volume) 40 | // c: # of attention channels 41 | 42 | int index = blockIdx.x * blockDim.x + threadIdx.x; 43 | if (index >= m * c) 
return; 44 | 45 | int map_idx = index / c; 46 | int i = index % c; 47 | 48 | int query_idx_ = kq_map[m + map_idx]; // kq_map[1][map_idx] 49 | int kernel_idx = kq_map[map_idx] % kkk; 50 | 51 | for(int head_idx = 0; head_idx < h; head_idx++){ 52 | 53 | int out_F_idx = map_idx * h + head_idx; 54 | int query_idx = query_idx_ * h * c + head_idx * c + i; 55 | int pos_idx = kernel_idx * h * c + head_idx * c + i; 56 | 57 | atomicAdd(grad_query + query_idx, grad_out_F[out_F_idx] * pos[pos_idx]); 58 | atomicAdd(grad_pos + pos_idx, grad_out_F[out_F_idx] * query[query_idx]); 59 | } 60 | } 61 | 62 | void dot_product_forward_launcher( 63 | int m, int h, int kkk, int c, const float* query, const float* pos, float* out_F, const int* kq_map 64 | ) { 65 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 66 | dim3 blocks(DIVUP(m * h, THREADS_PER_BLOCK)); 67 | dim3 threads(THREADS_PER_BLOCK); 68 | dot_product_forward_kernel<<<blocks, threads, 0, stream>>>( 69 | m, h, kkk, c, query, pos, out_F, kq_map 70 | ); 71 | } 72 | 73 | void dot_product_backward_launcher( 74 | int m, int h, int kkk, int c, const float* query, const float* pos, const int* kq_map, 75 | float* grad_query, float* grad_pos, const float* grad_out_F 76 | ) { 77 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 78 | dim3 blocks(DIVUP(m * c, THREADS_PER_BLOCK)); 79 | dim3 threads(THREADS_PER_BLOCK); 80 | dot_product_backward_kernel<<<blocks, threads, 0, stream>>>( 81 | m, h, kkk, c, query, pos, kq_map, 82 | grad_query, grad_pos, grad_out_F 83 | ); 84 | } 85 | -------------------------------------------------------------------------------- /src/cuda_ops/src/dot_product/dot_product_kernel.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef _dot_product_KERNEL 3 | #define _dot_product_KERNEL 4 | #include 5 | #include 6 | #include 7 | 8 | #define AT at::Tensor 9 | 10 | void dot_product_forward( 11 | int m, int h, int kkk, int c, AT query_tensor, AT pos_tensor, AT out_F_tensor, AT kq_map_tensor 12 | ); 13 | 
void dot_product_backward( 14 | int m, int h, int kkk, int c, AT query_tensor, AT pos_tensor, AT kq_map_tensor, 15 | AT grad_query_tensor, AT grad_pos_tensor, AT grad_out_F_tensor 16 | ); 17 | 18 | #ifdef __cplusplus 19 | extern "C" { 20 | #endif 21 | 22 | void dot_product_forward_launcher( 23 | int m, int h, int kkk, int c, const float* query, const float* pos, float* out_F, const int* kq_map 24 | ); 25 | void dot_product_backward_launcher( 26 | int m, int h, int kkk, int c, const float* query, const float* pos, const int* kq_map, 27 | float* grad_query, float* grad_pos, const float* grad_out_F 28 | ); 29 | 30 | #ifdef __cplusplus 31 | } 32 | #endif 33 | #endif 34 | -------------------------------------------------------------------------------- /src/cuda_ops/src/scalar_attention/scalar_attention.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "scalar_attention_kernel.h" 6 | 7 | void scalar_attention_forward( 8 | int m, int h, int c, AT weight_tensor, AT value_tensor, AT out_F_tensor, AT kq_indices_tensor 9 | ) 10 | { 11 | const float* weight = weight_tensor.data_ptr<float>(); 12 | const float* value = value_tensor.data_ptr<float>(); 13 | float* out_F = out_F_tensor.data_ptr<float>(); 14 | const int* kq_indices = kq_indices_tensor.data_ptr<int>(); 15 | 16 | scalar_attention_forward_launcher( 17 | m, h, c, weight, value, out_F, kq_indices 18 | ); 19 | } 20 | 21 | void scalar_attention_backward( 22 | int m, int h, int c, AT weight_tensor, AT value_tensor, AT kq_indices_tensor, 23 | AT grad_weight_tensor, AT grad_value_tensor, AT grad_out_F_tensor 24 | ) 25 | { 26 | const float* weight = weight_tensor.data_ptr<float>(); 27 | const float* value = value_tensor.data_ptr<float>(); 28 | const int* kq_indices = kq_indices_tensor.data_ptr<int>(); 29 | 30 | float* grad_weight = grad_weight_tensor.data_ptr<float>(); 31 | float* grad_value = grad_value_tensor.data_ptr<float>(); 32 | const float* grad_out_F = 
grad_out_F_tensor.data_ptr<float>(); 33 | 34 | scalar_attention_backward_launcher( 35 | m, h, c, weight, value, kq_indices, 36 | grad_weight, grad_value, grad_out_F 37 | ); 38 | } -------------------------------------------------------------------------------- /src/cuda_ops/src/scalar_attention/scalar_attention_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "scalar_attention_kernel.h" 3 | 4 | 5 | __global__ void scalar_attention_forward_kernel( 6 | int m, int h, int c, const float* weight, const float* value, float* out_F, const int* kq_indices 7 | ) 8 | { 9 | // m: # of total mappings 10 | // h: # of attention heads 11 | // c: # of attention channels 12 | 13 | int index = blockIdx.x * blockDim.x + threadIdx.x; 14 | if (index >= m * c) return; 15 | 16 | int map_idx = index / c; 17 | int i = index % c; 18 | 19 | int out_F_idx_ = kq_indices[m + map_idx]; // kq_indices[1][map_idx] 20 | int value_idx_ = kq_indices[map_idx]; // kq_indices[0][map_idx] 21 | 22 | for(int head_idx = 0; head_idx < h; head_idx++){ 23 | 24 | int weight_idx = map_idx * h + head_idx; 25 | int out_F_idx = out_F_idx_ * h * c + head_idx * c + i; 26 | int value_idx = value_idx_ * h * c + head_idx * c + i; 27 | 28 | atomicAdd(out_F + out_F_idx, weight[weight_idx] * value[value_idx]); 29 | } 30 | } 31 | 32 | __global__ void scalar_attention_backward_kernel( 33 | int m, int h, int c, const float* weight, const float* value, const int* kq_indices, 34 | float* grad_weight, float* grad_value, const float* grad_out_F 35 | ) 36 | { 37 | // m: # of total mappings 38 | // h: # of attention heads 39 | // c: # of attention channels 40 | 41 | int index = blockIdx.x * blockDim.x + threadIdx.x; 42 | if (index >= m * c) return; 43 | 44 | int map_idx = index / c; 45 | int i = index % c; 46 | 47 | int out_F_idx_ = kq_indices[m + map_idx]; // kq_indices[1][map_idx] 48 | int value_idx_ = kq_indices[map_idx]; // kq_indices[0][map_idx] 49 | 
50 | for(int head_idx = 0; head_idx < h; head_idx++){ 51 | 52 | int weight_idx = map_idx * h + head_idx; 53 | int out_F_idx = out_F_idx_ * h * c + head_idx * c + i; 54 | int value_idx = value_idx_ * h * c + head_idx * c + i; 55 | 56 | atomicAdd(grad_weight + weight_idx, grad_out_F[out_F_idx] * value[value_idx]); 57 | atomicAdd(grad_value + value_idx, grad_out_F[out_F_idx] * weight[weight_idx]); 58 | } 59 | } 60 | 61 | void scalar_attention_forward_launcher( 62 | int m, int h, int c, const float* weight, const float* value, float* out_F, const int* kq_indices 63 | ) { 64 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 65 | dim3 blocks(DIVUP(m * c, THREADS_PER_BLOCK)); 66 | dim3 threads(THREADS_PER_BLOCK); 67 | scalar_attention_forward_kernel<<<blocks, threads, 0, stream>>>( 68 | m, h, c, weight, value, out_F, kq_indices 69 | ); 70 | } 71 | 72 | void scalar_attention_backward_launcher( 73 | int m, int h, int c, const float* weight, const float* value, const int* kq_indices, 74 | float* grad_weight, float* grad_value, const float* grad_out_F 75 | ) { 76 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 77 | dim3 blocks(DIVUP(m * c, THREADS_PER_BLOCK)); 78 | dim3 threads(THREADS_PER_BLOCK); 79 | scalar_attention_backward_kernel<<<blocks, threads, 0, stream>>>( 80 | m, h, c, weight, value, kq_indices, 81 | grad_weight, grad_value, grad_out_F 82 | ); 83 | } 84 | -------------------------------------------------------------------------------- /src/cuda_ops/src/scalar_attention/scalar_attention_kernel.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef _scalar_attention_KERNEL 3 | #define _scalar_attention_KERNEL 4 | #include 5 | #include 6 | #include 7 | 8 | #define AT at::Tensor 9 | 10 | void scalar_attention_forward( 11 | int m, int h, int c, AT weight_tensor, AT value_tensor, AT out_F_tensor, AT kq_indices_tensor 12 | ); 13 | void scalar_attention_backward( 14 | int m, int h, int c, AT weight_tensor, AT value_tensor, AT kq_indices_tensor, 15 | AT 
grad_weight_tensor, AT grad_value_tensor, AT grad_out_F_tensor 16 | ); 17 | 18 | #ifdef __cplusplus 19 | extern "C" { 20 | #endif 21 | 22 | void scalar_attention_forward_launcher( 23 | int m, int h, int c, const float* weight, const float* value, float* out_F, const int* kq_indices 24 | ); 25 | void scalar_attention_backward_launcher( 26 | int m, int h, int c, const float* weight, const float* value, const int* kq_indices, 27 | float* grad_weight, float* grad_value, const float* grad_out_F 28 | ); 29 | 30 | #ifdef __cplusplus 31 | } 32 | #endif 33 | #endif 34 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from src.data.scannet_loader import * 4 | from src.data.s3dis_loader import * 5 | 6 | ALL_DATA_MODULES = [ 7 | ScanNetRGBDataModule, 8 | S3DISArea5RGBDataModule, 9 | ] 10 | ALL_DATASETS = [ 11 | ScanNetRGBDataset, 12 | S3DISArea5RGBDataset, 13 | ScanNetRGBDataset_, 14 | ] 15 | data_module_str_mapping = {d.__name__: d for d in ALL_DATA_MODULES} 16 | dataset_str_mapping = {d.__name__: d for d in ALL_DATASETS} 17 | 18 | 19 | def get_data_module(name: str): 20 | if name not in data_module_str_mapping.keys(): 21 | logging.error( 22 | f"data_module {name} does not exist in " + ", ".join( 23 | data_module_str_mapping.keys() 24 | ) 25 | ) 26 | return data_module_str_mapping[name] 27 | 28 | 29 | def get_dataset(name: str): 30 | if name not in dataset_str_mapping.keys(): 31 | logging.error( 32 | f"dataset {name} does not exist in " + ", ".join( 33 | dataset_str_mapping.keys() 34 | ) 35 | ) 36 | return dataset_str_mapping[name] 37 | -------------------------------------------------------------------------------- /src/data/collate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import gin 4 | import MinkowskiEngine as ME 5 | 6 | 7 | @gin.configurable 8 | 
class CollationFunctionFactory: 9 | def __init__(self, collation_type="collate_default"): 10 | if collation_type == "collate_default": 11 | self.collation_fn = self.collate_default 12 | elif collation_type == "collate_minkowski": 13 | self.collation_fn = self.collate_minkowski 14 | else: 15 | raise ValueError(f"collation_type {collation_type} not found") 16 | 17 | def __call__(self, list_data): 18 | return self.collation_fn(list_data) 19 | 20 | def collate_default(self, list_data): 21 | return list_data 22 | 23 | def collate_minkowski(self, list_data): 24 | B = len(list_data) 25 | list_data = [data for data in list_data if data is not None] 26 | if B != len(list_data): 27 | logging.info(f"Retain {len(list_data)} from {B} data.") 28 | if len(list_data) == 0: 29 | raise ValueError("No data in the batch") 30 | 31 | coords, feats, labels, extra_packages = list(zip(*list_data)) 32 | row_splits = [c.shape[0] for c in coords] 33 | coords_batch, feats_batch, labels_batch = ME.utils.sparse_collate( 34 | coords, feats, labels, dtype=coords[0].dtype 35 | ) 36 | return { 37 | "coordinates": coords_batch, 38 | "features": feats_batch, 39 | "labels": labels_batch, 40 | "row_splits": row_splits, 41 | "batch_size": B, 42 | "extra_packages": extra_packages, 43 | } -------------------------------------------------------------------------------- /src/data/meta_data/s3dis/area1.txt: -------------------------------------------------------------------------------- 1 | Area_1/office_26.ply 2 | Area_1/conferenceRoom_2.ply 3 | Area_1/hallway_6.ply 4 | Area_1/office_21.ply 5 | Area_1/hallway_3.ply 6 | Area_1/office_24.ply 7 | Area_1/hallway_8.ply 8 | Area_1/office_20.ply 9 | Area_1/office_22.ply 10 | Area_1/office_13.ply 11 | Area_1/office_6.ply 12 | Area_1/office_23.ply 13 | Area_1/office_7.ply 14 | Area_1/hallway_5.ply 15 | Area_1/office_11.ply 16 | Area_1/copyRoom_1.ply 17 | Area_1/office_30.ply 18 | Area_1/office_28.ply 19 | Area_1/pantry_1.ply 20 | Area_1/office_9.ply 21 | 
Area_1/office_29.ply 22 | Area_1/office_14.ply 23 | Area_1/office_18.ply 24 | Area_1/office_16.ply 25 | Area_1/office_17.ply 26 | Area_1/office_1.ply 27 | Area_1/office_3.ply 28 | Area_1/office_31.ply 29 | Area_1/office_25.ply 30 | Area_1/office_15.ply 31 | Area_1/hallway_1.ply 32 | Area_1/office_10.ply 33 | Area_1/office_5.ply 34 | Area_1/conferenceRoom_1.ply 35 | Area_1/office_4.ply 36 | Area_1/hallway_4.ply 37 | Area_1/office_2.ply 38 | Area_1/WC_1.ply 39 | Area_1/office_27.ply 40 | Area_1/office_8.ply 41 | Area_1/hallway_7.ply 42 | Area_1/office_12.ply 43 | Area_1/office_19.ply 44 | Area_1/hallway_2.ply -------------------------------------------------------------------------------- /src/data/meta_data/s3dis/area2.txt: -------------------------------------------------------------------------------- 1 | Area_2/storage_9.ply 2 | Area_2/storage_5.ply 3 | Area_2/hallway_6.ply 4 | Area_2/hallway_9.ply 5 | Area_2/storage_8.ply 6 | Area_2/hallway_3.ply 7 | Area_2/hallway_8.ply 8 | Area_2/storage_2.ply 9 | Area_2/storage_7.ply 10 | Area_2/office_13.ply 11 | Area_2/office_6.ply 12 | Area_2/storage_3.ply 13 | Area_2/office_7.ply 14 | Area_2/storage_6.ply 15 | Area_2/hallway_5.ply 16 | Area_2/hallway_11.ply 17 | Area_2/office_11.ply 18 | Area_2/hallway_12.ply 19 | Area_2/WC_2.ply 20 | Area_2/office_9.ply 21 | Area_2/storage_1.ply 22 | Area_2/office_14.ply 23 | Area_2/auditorium_2.ply 24 | Area_2/office_1.ply 25 | Area_2/office_3.ply 26 | Area_2/hallway_1.ply 27 | Area_2/office_10.ply 28 | Area_2/office_5.ply 29 | Area_2/conferenceRoom_1.ply 30 | Area_2/office_4.ply 31 | Area_2/hallway_4.ply 32 | Area_2/office_2.ply 33 | Area_2/WC_1.ply 34 | Area_2/storage_4.ply 35 | Area_2/auditorium_1.ply 36 | Area_2/office_8.ply 37 | Area_2/hallway_7.ply 38 | Area_2/office_12.ply 39 | Area_2/hallway_2.ply 40 | Area_2/hallway_10.ply -------------------------------------------------------------------------------- /src/data/meta_data/s3dis/area3.txt: 
-------------------------------------------------------------------------------- 1 | Area_3/hallway_6.ply 2 | Area_3/hallway_3.ply 3 | Area_3/storage_2.ply 4 | Area_3/office_6.ply 5 | Area_3/office_7.ply 6 | Area_3/hallway_5.ply 7 | Area_3/WC_2.ply 8 | Area_3/office_9.ply 9 | Area_3/storage_1.ply 10 | Area_3/office_1.ply 11 | Area_3/lounge_1.ply 12 | Area_3/office_3.ply 13 | Area_3/hallway_1.ply 14 | Area_3/office_10.ply 15 | Area_3/office_5.ply 16 | Area_3/conferenceRoom_1.ply 17 | Area_3/office_4.ply 18 | Area_3/hallway_4.ply 19 | Area_3/office_2.ply 20 | Area_3/WC_1.ply 21 | Area_3/office_8.ply 22 | Area_3/hallway_2.ply 23 | Area_3/lounge_2.ply -------------------------------------------------------------------------------- /src/data/meta_data/s3dis/area4.txt: -------------------------------------------------------------------------------- 1 | Area_4/conferenceRoom_2.ply 2 | Area_4/hallway_6.ply 3 | Area_4/office_21.ply 4 | Area_4/hallway_9.ply 5 | Area_4/hallway_3.ply 6 | Area_4/hallway_8.ply 7 | Area_4/office_20.ply 8 | Area_4/office_22.ply 9 | Area_4/hallway_13.ply 10 | Area_4/storage_2.ply 11 | Area_4/office_13.ply 12 | Area_4/office_6.ply 13 | Area_4/storage_3.ply 14 | Area_4/lobby_1.ply 15 | Area_4/office_7.ply 16 | Area_4/hallway_5.ply 17 | Area_4/hallway_11.ply 18 | Area_4/office_11.ply 19 | Area_4/hallway_12.ply 20 | Area_4/WC_4.ply 21 | Area_4/WC_2.ply 22 | Area_4/lobby_2.ply 23 | Area_4/office_9.ply 24 | Area_4/hallway_14.ply 25 | Area_4/WC_3.ply 26 | Area_4/conferenceRoom_3.ply 27 | Area_4/storage_1.ply 28 | Area_4/office_14.ply 29 | Area_4/office_18.ply 30 | Area_4/office_16.ply 31 | Area_4/office_17.ply 32 | Area_4/office_1.ply 33 | Area_4/office_3.ply 34 | Area_4/office_15.ply 35 | Area_4/hallway_1.ply 36 | Area_4/office_10.ply 37 | Area_4/office_5.ply 38 | Area_4/conferenceRoom_1.ply 39 | Area_4/office_4.ply 40 | Area_4/hallway_4.ply 41 | Area_4/office_2.ply 42 | Area_4/WC_1.ply 43 | Area_4/storage_4.ply 44 | Area_4/office_8.ply 45 | 
Area_4/hallway_7.ply 46 | Area_4/office_12.ply 47 | Area_4/office_19.ply 48 | Area_4/hallway_2.ply 49 | Area_4/hallway_10.ply -------------------------------------------------------------------------------- /src/data/meta_data/s3dis/area5.txt: -------------------------------------------------------------------------------- 1 | Area_5/office_26.ply 2 | Area_5/conferenceRoom_2.ply 3 | Area_5/hallway_6.ply 4 | Area_5/office_21.ply 5 | Area_5/hallway_9.ply 6 | Area_5/hallway_3.ply 7 | Area_5/office_39.ply 8 | Area_5/office_24.ply 9 | Area_5/hallway_8.ply 10 | Area_5/office_20.ply 11 | Area_5/office_22.ply 12 | Area_5/hallway_13.ply 13 | Area_5/storage_2.ply 14 | Area_5/office_13.ply 15 | Area_5/office_6.ply 16 | Area_5/office_41.ply 17 | Area_5/office_23.ply 18 | Area_5/office_34.ply 19 | Area_5/office_35.ply 20 | Area_5/office_36.ply 21 | Area_5/storage_3.ply 22 | Area_5/lobby_1.ply 23 | Area_5/office_7.ply 24 | Area_5/hallway_5.ply 25 | Area_5/hallway_11.ply 26 | Area_5/office_11.ply 27 | Area_5/office_37.ply 28 | Area_5/hallway_12.ply 29 | Area_5/office_30.ply 30 | Area_5/office_28.ply 31 | Area_5/pantry_1.ply 32 | Area_5/WC_2.ply 33 | Area_5/office_9.ply 34 | Area_5/office_38.ply 35 | Area_5/office_42.ply 36 | Area_5/hallway_14.ply 37 | Area_5/conferenceRoom_3.ply 38 | Area_5/storage_1.ply 39 | Area_5/office_29.ply 40 | Area_5/office_14.ply 41 | Area_5/office_18.ply 42 | Area_5/office_16.ply 43 | Area_5/office_17.ply 44 | Area_5/office_1.ply 45 | Area_5/office_3.ply 46 | Area_5/office_31.ply 47 | Area_5/office_25.ply 48 | Area_5/hallway_15.ply 49 | Area_5/office_15.ply 50 | Area_5/hallway_1.ply 51 | Area_5/office_10.ply 52 | Area_5/office_5.ply 53 | Area_5/conferenceRoom_1.ply 54 | Area_5/office_4.ply 55 | Area_5/hallway_4.ply 56 | Area_5/office_2.ply 57 | Area_5/office_33.ply 58 | Area_5/WC_1.ply 59 | Area_5/office_27.ply 60 | Area_5/storage_4.ply 61 | Area_5/office_40.ply 62 | Area_5/office_8.ply 63 | Area_5/hallway_7.ply 64 | Area_5/office_12.ply 65 | 
Area_5/office_32.ply 66 | Area_5/office_19.ply 67 | Area_5/hallway_2.ply 68 | Area_5/hallway_10.ply -------------------------------------------------------------------------------- /src/data/meta_data/s3dis/area6.txt: -------------------------------------------------------------------------------- 1 | Area_6/office_26.ply 2 | Area_6/hallway_6.ply 3 | Area_6/office_21.ply 4 | Area_6/hallway_3.ply 5 | Area_6/office_24.ply 6 | Area_6/office_20.ply 7 | Area_6/office_22.ply 8 | Area_6/office_13.ply 9 | Area_6/office_6.ply 10 | Area_6/office_23.ply 11 | Area_6/office_34.ply 12 | Area_6/office_35.ply 13 | Area_6/office_36.ply 14 | Area_6/office_7.ply 15 | Area_6/hallway_5.ply 16 | Area_6/office_11.ply 17 | Area_6/copyRoom_1.ply 18 | Area_6/office_37.ply 19 | Area_6/office_30.ply 20 | Area_6/office_28.ply 21 | Area_6/pantry_1.ply 22 | Area_6/openspace_1.ply 23 | Area_6/office_9.ply 24 | Area_6/office_29.ply 25 | Area_6/office_14.ply 26 | Area_6/office_18.ply 27 | Area_6/office_16.ply 28 | Area_6/office_17.ply 29 | Area_6/office_1.ply 30 | Area_6/lounge_1.ply 31 | Area_6/office_3.ply 32 | Area_6/office_31.ply 33 | Area_6/office_25.ply 34 | Area_6/office_15.ply 35 | Area_6/hallway_1.ply 36 | Area_6/office_10.ply 37 | Area_6/office_5.ply 38 | Area_6/conferenceRoom_1.ply 39 | Area_6/office_4.ply 40 | Area_6/hallway_4.ply 41 | Area_6/office_2.ply 42 | Area_6/office_33.ply 43 | Area_6/office_27.ply 44 | Area_6/office_8.ply 45 | Area_6/office_12.ply 46 | Area_6/office_32.ply 47 | Area_6/office_19.ply 48 | Area_6/hallway_2.ply -------------------------------------------------------------------------------- /src/data/meta_data/scannet/scannetv2_test.txt: -------------------------------------------------------------------------------- 1 | scene0707_00 2 | scene0708_00 3 | scene0709_00 4 | scene0710_00 5 | scene0711_00 6 | scene0712_00 7 | scene0713_00 8 | scene0714_00 9 | scene0715_00 10 | scene0716_00 11 | scene0717_00 12 | scene0718_00 13 | scene0719_00 14 | scene0720_00 15 | 
scene0721_00 16 | scene0722_00 17 | scene0723_00 18 | scene0724_00 19 | scene0725_00 20 | scene0726_00 21 | scene0727_00 22 | scene0728_00 23 | scene0729_00 24 | scene0730_00 25 | scene0731_00 26 | scene0732_00 27 | scene0733_00 28 | scene0734_00 29 | scene0735_00 30 | scene0736_00 31 | scene0737_00 32 | scene0738_00 33 | scene0739_00 34 | scene0740_00 35 | scene0741_00 36 | scene0742_00 37 | scene0743_00 38 | scene0744_00 39 | scene0745_00 40 | scene0746_00 41 | scene0747_00 42 | scene0748_00 43 | scene0749_00 44 | scene0750_00 45 | scene0751_00 46 | scene0752_00 47 | scene0753_00 48 | scene0754_00 49 | scene0755_00 50 | scene0756_00 51 | scene0757_00 52 | scene0758_00 53 | scene0759_00 54 | scene0760_00 55 | scene0761_00 56 | scene0762_00 57 | scene0763_00 58 | scene0764_00 59 | scene0765_00 60 | scene0766_00 61 | scene0767_00 62 | scene0768_00 63 | scene0769_00 64 | scene0770_00 65 | scene0771_00 66 | scene0772_00 67 | scene0773_00 68 | scene0774_00 69 | scene0775_00 70 | scene0776_00 71 | scene0777_00 72 | scene0778_00 73 | scene0779_00 74 | scene0780_00 75 | scene0781_00 76 | scene0782_00 77 | scene0783_00 78 | scene0784_00 79 | scene0785_00 80 | scene0786_00 81 | scene0787_00 82 | scene0788_00 83 | scene0789_00 84 | scene0790_00 85 | scene0791_00 86 | scene0792_00 87 | scene0793_00 88 | scene0794_00 89 | scene0795_00 90 | scene0796_00 91 | scene0797_00 92 | scene0798_00 93 | scene0799_00 94 | scene0800_00 95 | scene0801_00 96 | scene0802_00 97 | scene0803_00 98 | scene0804_00 99 | scene0805_00 100 | scene0806_00 101 | -------------------------------------------------------------------------------- /src/data/meta_data/scannet/scannetv2_train.txt: -------------------------------------------------------------------------------- 1 | scene0191_00 2 | scene0191_01 3 | scene0191_02 4 | scene0119_00 5 | scene0230_00 6 | scene0528_00 7 | scene0528_01 8 | scene0705_00 9 | scene0705_01 10 | scene0705_02 11 | scene0415_00 12 | scene0415_01 13 | scene0415_02 14 | 
scene0007_00 15 | scene0141_00 16 | scene0141_01 17 | scene0141_02 18 | scene0515_00 19 | scene0515_01 20 | scene0515_02 21 | scene0447_00 22 | scene0447_01 23 | scene0447_02 24 | scene0531_00 25 | scene0503_00 26 | scene0285_00 27 | scene0069_00 28 | scene0584_00 29 | scene0584_01 30 | scene0584_02 31 | scene0581_00 32 | scene0581_01 33 | scene0581_02 34 | scene0620_00 35 | scene0620_01 36 | scene0263_00 37 | scene0263_01 38 | scene0481_00 39 | scene0481_01 40 | scene0020_00 41 | scene0020_01 42 | scene0291_00 43 | scene0291_01 44 | scene0291_02 45 | scene0469_00 46 | scene0469_01 47 | scene0469_02 48 | scene0659_00 49 | scene0659_01 50 | scene0024_00 51 | scene0024_01 52 | scene0024_02 53 | scene0564_00 54 | scene0117_00 55 | scene0027_00 56 | scene0027_01 57 | scene0027_02 58 | scene0028_00 59 | scene0330_00 60 | scene0418_00 61 | scene0418_01 62 | scene0418_02 63 | scene0233_00 64 | scene0233_01 65 | scene0673_00 66 | scene0673_01 67 | scene0673_02 68 | scene0673_03 69 | scene0673_04 70 | scene0673_05 71 | scene0585_00 72 | scene0585_01 73 | scene0362_00 74 | scene0362_01 75 | scene0362_02 76 | scene0362_03 77 | scene0035_00 78 | scene0035_01 79 | scene0358_00 80 | scene0358_01 81 | scene0358_02 82 | scene0037_00 83 | scene0194_00 84 | scene0321_00 85 | scene0293_00 86 | scene0293_01 87 | scene0623_00 88 | scene0623_01 89 | scene0592_00 90 | scene0592_01 91 | scene0569_00 92 | scene0569_01 93 | scene0413_00 94 | scene0313_00 95 | scene0313_01 96 | scene0313_02 97 | scene0480_00 98 | scene0480_01 99 | scene0401_00 100 | scene0517_00 101 | scene0517_01 102 | scene0517_02 103 | scene0032_00 104 | scene0032_01 105 | scene0613_00 106 | scene0613_01 107 | scene0613_02 108 | scene0306_00 109 | scene0306_01 110 | scene0052_00 111 | scene0052_01 112 | scene0052_02 113 | scene0053_00 114 | scene0444_00 115 | scene0444_01 116 | scene0055_00 117 | scene0055_01 118 | scene0055_02 119 | scene0560_00 120 | scene0589_00 121 | scene0589_01 122 | scene0589_02 123 | scene0610_00 
124 | scene0610_01 125 | scene0610_02 126 | scene0364_00 127 | scene0364_01 128 | scene0383_00 129 | scene0383_01 130 | scene0383_02 131 | scene0006_00 132 | scene0006_01 133 | scene0006_02 134 | scene0275_00 135 | scene0451_00 136 | scene0451_01 137 | scene0451_02 138 | scene0451_03 139 | scene0451_04 140 | scene0451_05 141 | scene0135_00 142 | scene0065_00 143 | scene0065_01 144 | scene0065_02 145 | scene0104_00 146 | scene0674_00 147 | scene0674_01 148 | scene0448_00 149 | scene0448_01 150 | scene0448_02 151 | scene0502_00 152 | scene0502_01 153 | scene0502_02 154 | scene0440_00 155 | scene0440_01 156 | scene0440_02 157 | scene0071_00 158 | scene0072_00 159 | scene0072_01 160 | scene0072_02 161 | scene0509_00 162 | scene0509_01 163 | scene0509_02 164 | scene0649_00 165 | scene0649_01 166 | scene0602_00 167 | scene0694_00 168 | scene0694_01 169 | scene0101_00 170 | scene0101_01 171 | scene0101_02 172 | scene0101_03 173 | scene0101_04 174 | scene0101_05 175 | scene0218_00 176 | scene0218_01 177 | scene0579_00 178 | scene0579_01 179 | scene0579_02 180 | scene0039_00 181 | scene0039_01 182 | scene0493_00 183 | scene0493_01 184 | scene0242_00 185 | scene0242_01 186 | scene0242_02 187 | scene0083_00 188 | scene0083_01 189 | scene0127_00 190 | scene0127_01 191 | scene0662_00 192 | scene0662_01 193 | scene0662_02 194 | scene0018_00 195 | scene0087_00 196 | scene0087_01 197 | scene0087_02 198 | scene0332_00 199 | scene0332_01 200 | scene0332_02 201 | scene0628_00 202 | scene0628_01 203 | scene0628_02 204 | scene0134_00 205 | scene0134_01 206 | scene0134_02 207 | scene0238_00 208 | scene0238_01 209 | scene0092_00 210 | scene0092_01 211 | scene0092_02 212 | scene0092_03 213 | scene0092_04 214 | scene0022_00 215 | scene0022_01 216 | scene0467_00 217 | scene0392_00 218 | scene0392_01 219 | scene0392_02 220 | scene0424_00 221 | scene0424_01 222 | scene0424_02 223 | scene0646_00 224 | scene0646_01 225 | scene0646_02 226 | scene0098_00 227 | scene0098_01 228 | scene0044_00 229 
| scene0044_01 230 | scene0044_02 231 | scene0510_00 232 | scene0510_01 233 | scene0510_02 234 | scene0571_00 235 | scene0571_01 236 | scene0166_00 237 | scene0166_01 238 | scene0166_02 239 | scene0563_00 240 | scene0172_00 241 | scene0172_01 242 | scene0388_00 243 | scene0388_01 244 | scene0215_00 245 | scene0215_01 246 | scene0252_00 247 | scene0287_00 248 | scene0668_00 249 | scene0572_00 250 | scene0572_01 251 | scene0572_02 252 | scene0026_00 253 | scene0224_00 254 | scene0113_00 255 | scene0113_01 256 | scene0551_00 257 | scene0381_00 258 | scene0381_01 259 | scene0381_02 260 | scene0371_00 261 | scene0371_01 262 | scene0460_00 263 | scene0118_00 264 | scene0118_01 265 | scene0118_02 266 | scene0417_00 267 | scene0008_00 268 | scene0634_00 269 | scene0521_00 270 | scene0123_00 271 | scene0123_01 272 | scene0123_02 273 | scene0045_00 274 | scene0045_01 275 | scene0511_00 276 | scene0511_01 277 | scene0114_00 278 | scene0114_01 279 | scene0114_02 280 | scene0070_00 281 | scene0029_00 282 | scene0029_01 283 | scene0029_02 284 | scene0129_00 285 | scene0103_00 286 | scene0103_01 287 | scene0002_00 288 | scene0002_01 289 | scene0132_00 290 | scene0132_01 291 | scene0132_02 292 | scene0124_00 293 | scene0124_01 294 | scene0143_00 295 | scene0143_01 296 | scene0143_02 297 | scene0604_00 298 | scene0604_01 299 | scene0604_02 300 | scene0507_00 301 | scene0105_00 302 | scene0105_01 303 | scene0105_02 304 | scene0428_00 305 | scene0428_01 306 | scene0311_00 307 | scene0140_00 308 | scene0140_01 309 | scene0182_00 310 | scene0182_01 311 | scene0182_02 312 | scene0142_00 313 | scene0142_01 314 | scene0399_00 315 | scene0399_01 316 | scene0012_00 317 | scene0012_01 318 | scene0012_02 319 | scene0060_00 320 | scene0060_01 321 | scene0370_00 322 | scene0370_01 323 | scene0370_02 324 | scene0310_00 325 | scene0310_01 326 | scene0310_02 327 | scene0661_00 328 | scene0650_00 329 | scene0152_00 330 | scene0152_01 331 | scene0152_02 332 | scene0158_00 333 | scene0158_01 334 | 
scene0158_02 335 | scene0482_00 336 | scene0482_01 337 | scene0600_00 338 | scene0600_01 339 | scene0600_02 340 | scene0393_00 341 | scene0393_01 342 | scene0393_02 343 | scene0562_00 344 | scene0174_00 345 | scene0174_01 346 | scene0157_00 347 | scene0157_01 348 | scene0161_00 349 | scene0161_01 350 | scene0161_02 351 | scene0159_00 352 | scene0254_00 353 | scene0254_01 354 | scene0115_00 355 | scene0115_01 356 | scene0115_02 357 | scene0162_00 358 | scene0163_00 359 | scene0163_01 360 | scene0523_00 361 | scene0523_01 362 | scene0523_02 363 | scene0459_00 364 | scene0459_01 365 | scene0175_00 366 | scene0085_00 367 | scene0085_01 368 | scene0279_00 369 | scene0279_01 370 | scene0279_02 371 | scene0201_00 372 | scene0201_01 373 | scene0201_02 374 | scene0283_00 375 | scene0456_00 376 | scene0456_01 377 | scene0429_00 378 | scene0043_00 379 | scene0043_01 380 | scene0419_00 381 | scene0419_01 382 | scene0419_02 383 | scene0368_00 384 | scene0368_01 385 | scene0348_00 386 | scene0348_01 387 | scene0348_02 388 | scene0442_00 389 | scene0178_00 390 | scene0380_00 391 | scene0380_01 392 | scene0380_02 393 | scene0165_00 394 | scene0165_01 395 | scene0165_02 396 | scene0181_00 397 | scene0181_01 398 | scene0181_02 399 | scene0181_03 400 | scene0333_00 401 | scene0614_00 402 | scene0614_01 403 | scene0614_02 404 | scene0404_00 405 | scene0404_01 406 | scene0404_02 407 | scene0185_00 408 | scene0126_00 409 | scene0126_01 410 | scene0126_02 411 | scene0519_00 412 | scene0236_00 413 | scene0236_01 414 | scene0189_00 415 | scene0075_00 416 | scene0267_00 417 | scene0192_00 418 | scene0192_01 419 | scene0192_02 420 | scene0281_00 421 | scene0420_00 422 | scene0420_01 423 | scene0420_02 424 | scene0195_00 425 | scene0195_01 426 | scene0195_02 427 | scene0597_00 428 | scene0597_01 429 | scene0597_02 430 | scene0041_00 431 | scene0041_01 432 | scene0111_00 433 | scene0111_01 434 | scene0111_02 435 | scene0666_00 436 | scene0666_01 437 | scene0666_02 438 | scene0200_00 439 | 
scene0200_01 440 | scene0200_02 441 | scene0536_00 442 | scene0536_01 443 | scene0536_02 444 | scene0390_00 445 | scene0280_00 446 | scene0280_01 447 | scene0280_02 448 | scene0344_00 449 | scene0344_01 450 | scene0205_00 451 | scene0205_01 452 | scene0205_02 453 | scene0484_00 454 | scene0484_01 455 | scene0009_00 456 | scene0009_01 457 | scene0009_02 458 | scene0302_00 459 | scene0302_01 460 | scene0209_00 461 | scene0209_01 462 | scene0209_02 463 | scene0210_00 464 | scene0210_01 465 | scene0395_00 466 | scene0395_01 467 | scene0395_02 468 | scene0683_00 469 | scene0601_00 470 | scene0601_01 471 | scene0214_00 472 | scene0214_01 473 | scene0214_02 474 | scene0477_00 475 | scene0477_01 476 | scene0439_00 477 | scene0439_01 478 | scene0468_00 479 | scene0468_01 480 | scene0468_02 481 | scene0546_00 482 | scene0466_00 483 | scene0466_01 484 | scene0220_00 485 | scene0220_01 486 | scene0220_02 487 | scene0122_00 488 | scene0122_01 489 | scene0130_00 490 | scene0110_00 491 | scene0110_01 492 | scene0110_02 493 | scene0327_00 494 | scene0156_00 495 | scene0266_00 496 | scene0266_01 497 | scene0001_00 498 | scene0001_01 499 | scene0228_00 500 | scene0199_00 501 | scene0219_00 502 | scene0464_00 503 | scene0232_00 504 | scene0232_01 505 | scene0232_02 506 | scene0299_00 507 | scene0299_01 508 | scene0530_00 509 | scene0363_00 510 | scene0453_00 511 | scene0453_01 512 | scene0570_00 513 | scene0570_01 514 | scene0570_02 515 | scene0183_00 516 | scene0239_00 517 | scene0239_01 518 | scene0239_02 519 | scene0373_00 520 | scene0373_01 521 | scene0241_00 522 | scene0241_01 523 | scene0241_02 524 | scene0188_00 525 | scene0622_00 526 | scene0622_01 527 | scene0244_00 528 | scene0244_01 529 | scene0691_00 530 | scene0691_01 531 | scene0206_00 532 | scene0206_01 533 | scene0206_02 534 | scene0247_00 535 | scene0247_01 536 | scene0061_00 537 | scene0061_01 538 | scene0082_00 539 | scene0250_00 540 | scene0250_01 541 | scene0250_02 542 | scene0501_00 543 | scene0501_01 544 | 
scene0501_02 545 | scene0320_00 546 | scene0320_01 547 | scene0320_02 548 | scene0320_03 549 | scene0631_00 550 | scene0631_01 551 | scene0631_02 552 | scene0255_00 553 | scene0255_01 554 | scene0255_02 555 | scene0047_00 556 | scene0265_00 557 | scene0265_01 558 | scene0265_02 559 | scene0004_00 560 | scene0336_00 561 | scene0336_01 562 | scene0058_00 563 | scene0058_01 564 | scene0260_00 565 | scene0260_01 566 | scene0260_02 567 | scene0243_00 568 | scene0603_00 569 | scene0603_01 570 | scene0093_00 571 | scene0093_01 572 | scene0093_02 573 | scene0109_00 574 | scene0109_01 575 | scene0434_00 576 | scene0434_01 577 | scene0434_02 578 | scene0290_00 579 | scene0627_00 580 | scene0627_01 581 | scene0470_00 582 | scene0470_01 583 | scene0137_00 584 | scene0137_01 585 | scene0137_02 586 | scene0270_00 587 | scene0270_01 588 | scene0270_02 589 | scene0271_00 590 | scene0271_01 591 | scene0504_00 592 | scene0274_00 593 | scene0274_01 594 | scene0274_02 595 | scene0036_00 596 | scene0036_01 597 | scene0276_00 598 | scene0276_01 599 | scene0272_00 600 | scene0272_01 601 | scene0499_00 602 | scene0698_00 603 | scene0698_01 604 | scene0051_00 605 | scene0051_01 606 | scene0051_02 607 | scene0051_03 608 | scene0108_00 609 | scene0245_00 610 | scene0369_00 611 | scene0369_01 612 | scene0369_02 613 | scene0284_00 614 | scene0289_00 615 | scene0289_01 616 | scene0286_00 617 | scene0286_01 618 | scene0286_02 619 | scene0286_03 620 | scene0031_00 621 | scene0031_01 622 | scene0031_02 623 | scene0545_00 624 | scene0545_01 625 | scene0545_02 626 | scene0557_00 627 | scene0557_01 628 | scene0557_02 629 | scene0533_00 630 | scene0533_01 631 | scene0116_00 632 | scene0116_01 633 | scene0116_02 634 | scene0611_00 635 | scene0611_01 636 | scene0688_00 637 | scene0294_00 638 | scene0294_01 639 | scene0294_02 640 | scene0295_00 641 | scene0295_01 642 | scene0296_00 643 | scene0296_01 644 | scene0596_00 645 | scene0596_01 646 | scene0596_02 647 | scene0532_00 648 | scene0532_01 649 | 
scene0637_00 650 | scene0638_00 651 | scene0121_00 652 | scene0121_01 653 | scene0121_02 654 | scene0040_00 655 | scene0040_01 656 | scene0197_00 657 | scene0197_01 658 | scene0197_02 659 | scene0410_00 660 | scene0410_01 661 | scene0305_00 662 | scene0305_01 663 | scene0615_00 664 | scene0615_01 665 | scene0703_00 666 | scene0703_01 667 | scene0555_00 668 | scene0297_00 669 | scene0297_01 670 | scene0297_02 671 | scene0582_00 672 | scene0582_01 673 | scene0582_02 674 | scene0023_00 675 | scene0094_00 676 | scene0013_00 677 | scene0013_01 678 | scene0013_02 679 | scene0136_00 680 | scene0136_01 681 | scene0136_02 682 | scene0407_00 683 | scene0407_01 684 | scene0062_00 685 | scene0062_01 686 | scene0062_02 687 | scene0386_00 688 | scene0318_00 689 | scene0554_00 690 | scene0554_01 691 | scene0497_00 692 | scene0213_00 693 | scene0258_00 694 | scene0323_00 695 | scene0323_01 696 | scene0324_00 697 | scene0324_01 698 | scene0016_00 699 | scene0016_01 700 | scene0016_02 701 | scene0681_00 702 | scene0398_00 703 | scene0398_01 704 | scene0227_00 705 | scene0090_00 706 | scene0066_00 707 | scene0262_00 708 | scene0262_01 709 | scene0155_00 710 | scene0155_01 711 | scene0155_02 712 | scene0352_00 713 | scene0352_01 714 | scene0352_02 715 | scene0038_00 716 | scene0038_01 717 | scene0038_02 718 | scene0335_00 719 | scene0335_01 720 | scene0335_02 721 | scene0261_00 722 | scene0261_01 723 | scene0261_02 724 | scene0261_03 725 | scene0640_00 726 | scene0640_01 727 | scene0640_02 728 | scene0080_00 729 | scene0080_01 730 | scene0080_02 731 | scene0403_00 732 | scene0403_01 733 | scene0282_00 734 | scene0282_01 735 | scene0282_02 736 | scene0682_00 737 | scene0173_00 738 | scene0173_01 739 | scene0173_02 740 | scene0522_00 741 | scene0687_00 742 | scene0345_00 743 | scene0345_01 744 | scene0612_00 745 | scene0612_01 746 | scene0411_00 747 | scene0411_01 748 | scene0411_02 749 | scene0625_00 750 | scene0625_01 751 | scene0211_00 752 | scene0211_01 753 | scene0211_02 754 | 
scene0211_03 755 | scene0676_00 756 | scene0676_01 757 | scene0179_00 758 | scene0498_00 759 | scene0498_01 760 | scene0498_02 761 | scene0547_00 762 | scene0547_01 763 | scene0547_02 764 | scene0269_00 765 | scene0269_01 766 | scene0269_02 767 | scene0366_00 768 | scene0680_00 769 | scene0680_01 770 | scene0588_00 771 | scene0588_01 772 | scene0588_02 773 | scene0588_03 774 | scene0346_00 775 | scene0346_01 776 | scene0359_00 777 | scene0359_01 778 | scene0014_00 779 | scene0120_00 780 | scene0120_01 781 | scene0212_00 782 | scene0212_01 783 | scene0212_02 784 | scene0176_00 785 | scene0049_00 786 | scene0259_00 787 | scene0259_01 788 | scene0586_00 789 | scene0586_01 790 | scene0586_02 791 | scene0309_00 792 | scene0309_01 793 | scene0125_00 794 | scene0455_00 795 | scene0177_00 796 | scene0177_01 797 | scene0177_02 798 | scene0326_00 799 | scene0372_00 800 | scene0171_00 801 | scene0171_01 802 | scene0374_00 803 | scene0654_00 804 | scene0654_01 805 | scene0445_00 806 | scene0445_01 807 | scene0475_00 808 | scene0475_01 809 | scene0475_02 810 | scene0349_00 811 | scene0349_01 812 | scene0234_00 813 | scene0669_00 814 | scene0669_01 815 | scene0375_00 816 | scene0375_01 817 | scene0375_02 818 | scene0387_00 819 | scene0387_01 820 | scene0387_02 821 | scene0312_00 822 | scene0312_01 823 | scene0312_02 824 | scene0384_00 825 | scene0385_00 826 | scene0385_01 827 | scene0385_02 828 | scene0000_00 829 | scene0000_01 830 | scene0000_02 831 | scene0376_00 832 | scene0376_01 833 | scene0376_02 834 | scene0301_00 835 | scene0301_01 836 | scene0301_02 837 | scene0322_00 838 | scene0542_00 839 | scene0079_00 840 | scene0079_01 841 | scene0099_00 842 | scene0099_01 843 | scene0476_00 844 | scene0476_01 845 | scene0476_02 846 | scene0394_00 847 | scene0394_01 848 | scene0147_00 849 | scene0147_01 850 | scene0067_00 851 | scene0067_01 852 | scene0067_02 853 | scene0397_00 854 | scene0397_01 855 | scene0337_00 856 | scene0337_01 857 | scene0337_02 858 | scene0431_00 859 | 
scene0223_00 860 | scene0223_01 861 | scene0223_02 862 | scene0010_00 863 | scene0010_01 864 | scene0402_00 865 | scene0268_00 866 | scene0268_01 867 | scene0268_02 868 | scene0679_00 869 | scene0679_01 870 | scene0405_00 871 | scene0128_00 872 | scene0408_00 873 | scene0408_01 874 | scene0190_00 875 | scene0107_00 876 | scene0076_00 877 | scene0167_00 878 | scene0361_00 879 | scene0361_01 880 | scene0361_02 881 | scene0216_00 882 | scene0202_00 883 | scene0303_00 884 | scene0303_01 885 | scene0303_02 886 | scene0446_00 887 | scene0446_01 888 | scene0089_00 889 | scene0089_01 890 | scene0089_02 891 | scene0360_00 892 | scene0150_00 893 | scene0150_01 894 | scene0150_02 895 | scene0421_00 896 | scene0421_01 897 | scene0421_02 898 | scene0454_00 899 | scene0626_00 900 | scene0626_01 901 | scene0626_02 902 | scene0186_00 903 | scene0186_01 904 | scene0538_00 905 | scene0479_00 906 | scene0479_01 907 | scene0479_02 908 | scene0656_00 909 | scene0656_01 910 | scene0656_02 911 | scene0656_03 912 | scene0525_00 913 | scene0525_01 914 | scene0525_02 915 | scene0308_00 916 | scene0396_00 917 | scene0396_01 918 | scene0396_02 919 | scene0624_00 920 | scene0292_00 921 | scene0292_01 922 | scene0632_00 923 | scene0253_00 924 | scene0021_00 925 | scene0325_00 926 | scene0325_01 927 | scene0437_00 928 | scene0437_01 929 | scene0438_00 930 | scene0590_00 931 | scene0590_01 932 | scene0400_00 933 | scene0400_01 934 | scene0541_00 935 | scene0541_01 936 | scene0541_02 937 | scene0677_00 938 | scene0677_01 939 | scene0677_02 940 | scene0443_00 941 | scene0315_00 942 | scene0288_00 943 | scene0288_01 944 | scene0288_02 945 | scene0422_00 946 | scene0672_00 947 | scene0672_01 948 | scene0184_00 949 | scene0449_00 950 | scene0449_01 951 | scene0449_02 952 | scene0048_00 953 | scene0048_01 954 | scene0138_00 955 | scene0452_00 956 | scene0452_01 957 | scene0452_02 958 | scene0667_00 959 | scene0667_01 960 | scene0667_02 961 | scene0463_00 962 | scene0463_01 963 | scene0078_00 964 | 
scene0078_01 965 | scene0078_02 966 | scene0636_00 967 | scene0457_00 968 | scene0457_01 969 | scene0457_02 970 | scene0465_00 971 | scene0465_01 972 | scene0577_00 973 | scene0151_00 974 | scene0151_01 975 | scene0339_00 976 | scene0573_00 977 | scene0573_01 978 | scene0154_00 979 | scene0096_00 980 | scene0096_01 981 | scene0096_02 982 | scene0235_00 983 | scene0168_00 984 | scene0168_01 985 | scene0168_02 986 | scene0594_00 987 | scene0587_00 988 | scene0587_01 989 | scene0587_02 990 | scene0587_03 991 | scene0229_00 992 | scene0229_01 993 | scene0229_02 994 | scene0512_00 995 | scene0106_00 996 | scene0106_01 997 | scene0106_02 998 | scene0472_00 999 | scene0472_01 1000 | scene0472_02 1001 | scene0489_00 1002 | scene0489_01 1003 | scene0489_02 1004 | scene0425_00 1005 | scene0425_01 1006 | scene0641_00 1007 | scene0526_00 1008 | scene0526_01 1009 | scene0317_00 1010 | scene0317_01 1011 | scene0544_00 1012 | scene0017_00 1013 | scene0017_01 1014 | scene0017_02 1015 | scene0042_00 1016 | scene0042_01 1017 | scene0042_02 1018 | scene0576_00 1019 | scene0576_01 1020 | scene0576_02 1021 | scene0347_00 1022 | scene0347_01 1023 | scene0347_02 1024 | scene0436_00 1025 | scene0226_00 1026 | scene0226_01 1027 | scene0485_00 1028 | scene0486_00 1029 | scene0487_00 1030 | scene0487_01 1031 | scene0619_00 1032 | scene0097_00 1033 | scene0367_00 1034 | scene0367_01 1035 | scene0491_00 1036 | scene0492_00 1037 | scene0492_01 1038 | scene0005_00 1039 | scene0005_01 1040 | scene0543_00 1041 | scene0543_01 1042 | scene0543_02 1043 | scene0657_00 1044 | scene0341_00 1045 | scene0341_01 1046 | scene0534_00 1047 | scene0534_01 1048 | scene0319_00 1049 | scene0273_00 1050 | scene0273_01 1051 | scene0225_00 1052 | scene0198_00 1053 | scene0003_00 1054 | scene0003_01 1055 | scene0003_02 1056 | scene0409_00 1057 | scene0409_01 1058 | scene0331_00 1059 | scene0331_01 1060 | scene0505_00 1061 | scene0505_01 1062 | scene0505_02 1063 | scene0505_03 1064 | scene0505_04 1065 | scene0506_00 
1066 | scene0057_00 1067 | scene0057_01 1068 | scene0074_00 1069 | scene0074_01 1070 | scene0074_02 1071 | scene0091_00 1072 | scene0112_00 1073 | scene0112_01 1074 | scene0112_02 1075 | scene0240_00 1076 | scene0102_00 1077 | scene0102_01 1078 | scene0513_00 1079 | scene0514_00 1080 | scene0514_01 1081 | scene0537_00 1082 | scene0516_00 1083 | scene0516_01 1084 | scene0495_00 1085 | scene0617_00 1086 | scene0133_00 1087 | scene0520_00 1088 | scene0520_01 1089 | scene0635_00 1090 | scene0635_01 1091 | scene0054_00 1092 | scene0473_00 1093 | scene0473_01 1094 | scene0524_00 1095 | scene0524_01 1096 | scene0379_00 1097 | scene0471_00 1098 | scene0471_01 1099 | scene0471_02 1100 | scene0566_00 1101 | scene0248_00 1102 | scene0248_01 1103 | scene0248_02 1104 | scene0529_00 1105 | scene0529_01 1106 | scene0529_02 1107 | scene0391_00 1108 | scene0264_00 1109 | scene0264_01 1110 | scene0264_02 1111 | scene0675_00 1112 | scene0675_01 1113 | scene0350_00 1114 | scene0350_01 1115 | scene0350_02 1116 | scene0450_00 1117 | scene0068_00 1118 | scene0068_01 1119 | scene0237_00 1120 | scene0237_01 1121 | scene0365_00 1122 | scene0365_01 1123 | scene0365_02 1124 | scene0605_00 1125 | scene0605_01 1126 | scene0539_00 1127 | scene0539_01 1128 | scene0539_02 1129 | scene0540_00 1130 | scene0540_01 1131 | scene0540_02 1132 | scene0170_00 1133 | scene0170_01 1134 | scene0170_02 1135 | scene0433_00 1136 | scene0340_00 1137 | scene0340_01 1138 | scene0340_02 1139 | scene0160_00 1140 | scene0160_01 1141 | scene0160_02 1142 | scene0160_03 1143 | scene0160_04 1144 | scene0059_00 1145 | scene0059_01 1146 | scene0059_02 1147 | scene0056_00 1148 | scene0056_01 1149 | scene0478_00 1150 | scene0478_01 1151 | scene0548_00 1152 | scene0548_01 1153 | scene0548_02 1154 | scene0204_00 1155 | scene0204_01 1156 | scene0204_02 1157 | scene0033_00 1158 | scene0145_00 1159 | scene0483_00 1160 | scene0508_00 1161 | scene0508_01 1162 | scene0508_02 1163 | scene0180_00 1164 | scene0148_00 1165 | scene0556_00 
1166 | scene0556_01 1167 | scene0416_00 1168 | scene0416_01 1169 | scene0416_02 1170 | scene0416_03 1171 | scene0416_04 1172 | scene0073_00 1173 | scene0073_01 1174 | scene0073_02 1175 | scene0073_03 1176 | scene0034_00 1177 | scene0034_01 1178 | scene0034_02 1179 | scene0639_00 1180 | scene0561_00 1181 | scene0561_01 1182 | scene0298_00 1183 | scene0692_00 1184 | scene0692_01 1185 | scene0692_02 1186 | scene0692_03 1187 | scene0692_04 1188 | scene0642_00 1189 | scene0642_01 1190 | scene0642_02 1191 | scene0642_03 1192 | scene0630_00 1193 | scene0630_01 1194 | scene0630_02 1195 | scene0630_03 1196 | scene0630_04 1197 | scene0630_05 1198 | scene0630_06 1199 | scene0706_00 1200 | scene0567_00 1201 | scene0567_01 1202 | -------------------------------------------------------------------------------- /src/data/meta_data/scannet/scannetv2_trainval.txt: -------------------------------------------------------------------------------- 1 | scene0000_00 2 | scene0000_01 3 | scene0000_02 4 | scene0001_00 5 | scene0001_01 6 | scene0002_00 7 | scene0002_01 8 | scene0003_00 9 | scene0003_01 10 | scene0003_02 11 | scene0004_00 12 | scene0005_00 13 | scene0005_01 14 | scene0006_00 15 | scene0006_01 16 | scene0006_02 17 | scene0007_00 18 | scene0008_00 19 | scene0009_00 20 | scene0009_01 21 | scene0009_02 22 | scene0010_00 23 | scene0010_01 24 | scene0011_00 25 | scene0011_01 26 | scene0012_00 27 | scene0012_01 28 | scene0012_02 29 | scene0013_00 30 | scene0013_01 31 | scene0013_02 32 | scene0014_00 33 | scene0015_00 34 | scene0016_00 35 | scene0016_01 36 | scene0016_02 37 | scene0017_00 38 | scene0017_01 39 | scene0017_02 40 | scene0018_00 41 | scene0019_00 42 | scene0019_01 43 | scene0020_00 44 | scene0020_01 45 | scene0021_00 46 | scene0022_00 47 | scene0022_01 48 | scene0023_00 49 | scene0024_00 50 | scene0024_01 51 | scene0024_02 52 | scene0025_00 53 | scene0025_01 54 | scene0025_02 55 | scene0026_00 56 | scene0027_00 57 | scene0027_01 58 | scene0027_02 59 | scene0028_00 60 | 
scene0029_00 61 | scene0029_01 62 | scene0029_02 63 | scene0030_00 64 | scene0030_01 65 | scene0030_02 66 | scene0031_00 67 | scene0031_01 68 | scene0031_02 69 | scene0032_00 70 | scene0032_01 71 | scene0033_00 72 | scene0034_00 73 | scene0034_01 74 | scene0034_02 75 | scene0035_00 76 | scene0035_01 77 | scene0036_00 78 | scene0036_01 79 | scene0037_00 80 | scene0038_00 81 | scene0038_01 82 | scene0038_02 83 | scene0039_00 84 | scene0039_01 85 | scene0040_00 86 | scene0040_01 87 | scene0041_00 88 | scene0041_01 89 | scene0042_00 90 | scene0042_01 91 | scene0042_02 92 | scene0043_00 93 | scene0043_01 94 | scene0044_00 95 | scene0044_01 96 | scene0044_02 97 | scene0045_00 98 | scene0045_01 99 | scene0046_00 100 | scene0046_01 101 | scene0046_02 102 | scene0047_00 103 | scene0048_00 104 | scene0048_01 105 | scene0049_00 106 | scene0050_00 107 | scene0050_01 108 | scene0050_02 109 | scene0051_00 110 | scene0051_01 111 | scene0051_02 112 | scene0051_03 113 | scene0052_00 114 | scene0052_01 115 | scene0052_02 116 | scene0053_00 117 | scene0054_00 118 | scene0055_00 119 | scene0055_01 120 | scene0055_02 121 | scene0056_00 122 | scene0056_01 123 | scene0057_00 124 | scene0057_01 125 | scene0058_00 126 | scene0058_01 127 | scene0059_00 128 | scene0059_01 129 | scene0059_02 130 | scene0060_00 131 | scene0060_01 132 | scene0061_00 133 | scene0061_01 134 | scene0062_00 135 | scene0062_01 136 | scene0062_02 137 | scene0063_00 138 | scene0064_00 139 | scene0064_01 140 | scene0065_00 141 | scene0065_01 142 | scene0065_02 143 | scene0066_00 144 | scene0067_00 145 | scene0067_01 146 | scene0067_02 147 | scene0068_00 148 | scene0068_01 149 | scene0069_00 150 | scene0070_00 151 | scene0071_00 152 | scene0072_00 153 | scene0072_01 154 | scene0072_02 155 | scene0073_00 156 | scene0073_01 157 | scene0073_02 158 | scene0073_03 159 | scene0074_00 160 | scene0074_01 161 | scene0074_02 162 | scene0075_00 163 | scene0076_00 164 | scene0077_00 165 | scene0077_01 166 | scene0078_00 167 | 
scene0078_01 168 | scene0078_02 169 | scene0079_00 170 | scene0079_01 171 | scene0080_00 172 | scene0080_01 173 | scene0080_02 174 | scene0081_00 175 | scene0081_01 176 | scene0081_02 177 | scene0082_00 178 | scene0083_00 179 | scene0083_01 180 | scene0084_00 181 | scene0084_01 182 | scene0084_02 183 | scene0085_00 184 | scene0085_01 185 | scene0086_00 186 | scene0086_01 187 | scene0086_02 188 | scene0087_00 189 | scene0087_01 190 | scene0087_02 191 | scene0088_00 192 | scene0088_01 193 | scene0088_02 194 | scene0088_03 195 | scene0089_00 196 | scene0089_01 197 | scene0089_02 198 | scene0090_00 199 | scene0091_00 200 | scene0092_00 201 | scene0092_01 202 | scene0092_02 203 | scene0092_03 204 | scene0092_04 205 | scene0093_00 206 | scene0093_01 207 | scene0093_02 208 | scene0094_00 209 | scene0095_00 210 | scene0095_01 211 | scene0096_00 212 | scene0096_01 213 | scene0096_02 214 | scene0097_00 215 | scene0098_00 216 | scene0098_01 217 | scene0099_00 218 | scene0099_01 219 | scene0100_00 220 | scene0100_01 221 | scene0100_02 222 | scene0101_00 223 | scene0101_01 224 | scene0101_02 225 | scene0101_03 226 | scene0101_04 227 | scene0101_05 228 | scene0102_00 229 | scene0102_01 230 | scene0103_00 231 | scene0103_01 232 | scene0104_00 233 | scene0105_00 234 | scene0105_01 235 | scene0105_02 236 | scene0106_00 237 | scene0106_01 238 | scene0106_02 239 | scene0107_00 240 | scene0108_00 241 | scene0109_00 242 | scene0109_01 243 | scene0110_00 244 | scene0110_01 245 | scene0110_02 246 | scene0111_00 247 | scene0111_01 248 | scene0111_02 249 | scene0112_00 250 | scene0112_01 251 | scene0112_02 252 | scene0113_00 253 | scene0113_01 254 | scene0114_00 255 | scene0114_01 256 | scene0114_02 257 | scene0115_00 258 | scene0115_01 259 | scene0115_02 260 | scene0116_00 261 | scene0116_01 262 | scene0116_02 263 | scene0117_00 264 | scene0118_00 265 | scene0118_01 266 | scene0118_02 267 | scene0119_00 268 | scene0120_00 269 | scene0120_01 270 | scene0121_00 271 | scene0121_01 272 | 
scene0121_02 273 | scene0122_00 274 | scene0122_01 275 | scene0123_00 276 | scene0123_01 277 | scene0123_02 278 | scene0124_00 279 | scene0124_01 280 | scene0125_00 281 | scene0126_00 282 | scene0126_01 283 | scene0126_02 284 | scene0127_00 285 | scene0127_01 286 | scene0128_00 287 | scene0129_00 288 | scene0130_00 289 | scene0131_00 290 | scene0131_01 291 | scene0131_02 292 | scene0132_00 293 | scene0132_01 294 | scene0132_02 295 | scene0133_00 296 | scene0134_00 297 | scene0134_01 298 | scene0134_02 299 | scene0135_00 300 | scene0136_00 301 | scene0136_01 302 | scene0136_02 303 | scene0137_00 304 | scene0137_01 305 | scene0137_02 306 | scene0138_00 307 | scene0139_00 308 | scene0140_00 309 | scene0140_01 310 | scene0141_00 311 | scene0141_01 312 | scene0141_02 313 | scene0142_00 314 | scene0142_01 315 | scene0143_00 316 | scene0143_01 317 | scene0143_02 318 | scene0144_00 319 | scene0144_01 320 | scene0145_00 321 | scene0146_00 322 | scene0146_01 323 | scene0146_02 324 | scene0147_00 325 | scene0147_01 326 | scene0148_00 327 | scene0149_00 328 | scene0150_00 329 | scene0150_01 330 | scene0150_02 331 | scene0151_00 332 | scene0151_01 333 | scene0152_00 334 | scene0152_01 335 | scene0152_02 336 | scene0153_00 337 | scene0153_01 338 | scene0154_00 339 | scene0155_00 340 | scene0155_01 341 | scene0155_02 342 | scene0156_00 343 | scene0157_00 344 | scene0157_01 345 | scene0158_00 346 | scene0158_01 347 | scene0158_02 348 | scene0159_00 349 | scene0160_00 350 | scene0160_01 351 | scene0160_02 352 | scene0160_03 353 | scene0160_04 354 | scene0161_00 355 | scene0161_01 356 | scene0161_02 357 | scene0162_00 358 | scene0163_00 359 | scene0163_01 360 | scene0164_00 361 | scene0164_01 362 | scene0164_02 363 | scene0164_03 364 | scene0165_00 365 | scene0165_01 366 | scene0165_02 367 | scene0166_00 368 | scene0166_01 369 | scene0166_02 370 | scene0167_00 371 | scene0168_00 372 | scene0168_01 373 | scene0168_02 374 | scene0169_00 375 | scene0169_01 376 | scene0170_00 377 | 
scene0170_01 378 | scene0170_02 379 | scene0171_00 380 | scene0171_01 381 | scene0172_00 382 | scene0172_01 383 | scene0173_00 384 | scene0173_01 385 | scene0173_02 386 | scene0174_00 387 | scene0174_01 388 | scene0175_00 389 | scene0176_00 390 | scene0177_00 391 | scene0177_01 392 | scene0177_02 393 | scene0178_00 394 | scene0179_00 395 | scene0180_00 396 | scene0181_00 397 | scene0181_01 398 | scene0181_02 399 | scene0181_03 400 | scene0182_00 401 | scene0182_01 402 | scene0182_02 403 | scene0183_00 404 | scene0184_00 405 | scene0185_00 406 | scene0186_00 407 | scene0186_01 408 | scene0187_00 409 | scene0187_01 410 | scene0188_00 411 | scene0189_00 412 | scene0190_00 413 | scene0191_00 414 | scene0191_01 415 | scene0191_02 416 | scene0192_00 417 | scene0192_01 418 | scene0192_02 419 | scene0193_00 420 | scene0193_01 421 | scene0194_00 422 | scene0195_00 423 | scene0195_01 424 | scene0195_02 425 | scene0196_00 426 | scene0197_00 427 | scene0197_01 428 | scene0197_02 429 | scene0198_00 430 | scene0199_00 431 | scene0200_00 432 | scene0200_01 433 | scene0200_02 434 | scene0201_00 435 | scene0201_01 436 | scene0201_02 437 | scene0202_00 438 | scene0203_00 439 | scene0203_01 440 | scene0203_02 441 | scene0204_00 442 | scene0204_01 443 | scene0204_02 444 | scene0205_00 445 | scene0205_01 446 | scene0205_02 447 | scene0206_00 448 | scene0206_01 449 | scene0206_02 450 | scene0207_00 451 | scene0207_01 452 | scene0207_02 453 | scene0208_00 454 | scene0209_00 455 | scene0209_01 456 | scene0209_02 457 | scene0210_00 458 | scene0210_01 459 | scene0211_00 460 | scene0211_01 461 | scene0211_02 462 | scene0211_03 463 | scene0212_00 464 | scene0212_01 465 | scene0212_02 466 | scene0213_00 467 | scene0214_00 468 | scene0214_01 469 | scene0214_02 470 | scene0215_00 471 | scene0215_01 472 | scene0216_00 473 | scene0217_00 474 | scene0218_00 475 | scene0218_01 476 | scene0219_00 477 | scene0220_00 478 | scene0220_01 479 | scene0220_02 480 | scene0221_00 481 | scene0221_01 482 | 
scene0222_00 483 | scene0222_01 484 | scene0223_00 485 | scene0223_01 486 | scene0223_02 487 | scene0224_00 488 | scene0225_00 489 | scene0226_00 490 | scene0226_01 491 | scene0227_00 492 | scene0228_00 493 | scene0229_00 494 | scene0229_01 495 | scene0229_02 496 | scene0230_00 497 | scene0231_00 498 | scene0231_01 499 | scene0231_02 500 | scene0232_00 501 | scene0232_01 502 | scene0232_02 503 | scene0233_00 504 | scene0233_01 505 | scene0234_00 506 | scene0235_00 507 | scene0236_00 508 | scene0236_01 509 | scene0237_00 510 | scene0237_01 511 | scene0238_00 512 | scene0238_01 513 | scene0239_00 514 | scene0239_01 515 | scene0239_02 516 | scene0240_00 517 | scene0241_00 518 | scene0241_01 519 | scene0241_02 520 | scene0242_00 521 | scene0242_01 522 | scene0242_02 523 | scene0243_00 524 | scene0244_00 525 | scene0244_01 526 | scene0245_00 527 | scene0246_00 528 | scene0247_00 529 | scene0247_01 530 | scene0248_00 531 | scene0248_01 532 | scene0248_02 533 | scene0249_00 534 | scene0250_00 535 | scene0250_01 536 | scene0250_02 537 | scene0251_00 538 | scene0252_00 539 | scene0253_00 540 | scene0254_00 541 | scene0254_01 542 | scene0255_00 543 | scene0255_01 544 | scene0255_02 545 | scene0256_00 546 | scene0256_01 547 | scene0256_02 548 | scene0257_00 549 | scene0258_00 550 | scene0259_00 551 | scene0259_01 552 | scene0260_00 553 | scene0260_01 554 | scene0260_02 555 | scene0261_00 556 | scene0261_01 557 | scene0261_02 558 | scene0261_03 559 | scene0262_00 560 | scene0262_01 561 | scene0263_00 562 | scene0263_01 563 | scene0264_00 564 | scene0264_01 565 | scene0264_02 566 | scene0265_00 567 | scene0265_01 568 | scene0265_02 569 | scene0266_00 570 | scene0266_01 571 | scene0267_00 572 | scene0268_00 573 | scene0268_01 574 | scene0268_02 575 | scene0269_00 576 | scene0269_01 577 | scene0269_02 578 | scene0270_00 579 | scene0270_01 580 | scene0270_02 581 | scene0271_00 582 | scene0271_01 583 | scene0272_00 584 | scene0272_01 585 | scene0273_00 586 | scene0273_01 587 | 
scene0274_00 588 | scene0274_01 589 | scene0274_02 590 | scene0275_00 591 | scene0276_00 592 | scene0276_01 593 | scene0277_00 594 | scene0277_01 595 | scene0277_02 596 | scene0278_00 597 | scene0278_01 598 | scene0279_00 599 | scene0279_01 600 | scene0279_02 601 | scene0280_00 602 | scene0280_01 603 | scene0280_02 604 | scene0281_00 605 | scene0282_00 606 | scene0282_01 607 | scene0282_02 608 | scene0283_00 609 | scene0284_00 610 | scene0285_00 611 | scene0286_00 612 | scene0286_01 613 | scene0286_02 614 | scene0286_03 615 | scene0287_00 616 | scene0288_00 617 | scene0288_01 618 | scene0288_02 619 | scene0289_00 620 | scene0289_01 621 | scene0290_00 622 | scene0291_00 623 | scene0291_01 624 | scene0291_02 625 | scene0292_00 626 | scene0292_01 627 | scene0293_00 628 | scene0293_01 629 | scene0294_00 630 | scene0294_01 631 | scene0294_02 632 | scene0295_00 633 | scene0295_01 634 | scene0296_00 635 | scene0296_01 636 | scene0297_00 637 | scene0297_01 638 | scene0297_02 639 | scene0298_00 640 | scene0299_00 641 | scene0299_01 642 | scene0300_00 643 | scene0300_01 644 | scene0301_00 645 | scene0301_01 646 | scene0301_02 647 | scene0302_00 648 | scene0302_01 649 | scene0303_00 650 | scene0303_01 651 | scene0303_02 652 | scene0304_00 653 | scene0305_00 654 | scene0305_01 655 | scene0306_00 656 | scene0306_01 657 | scene0307_00 658 | scene0307_01 659 | scene0307_02 660 | scene0308_00 661 | scene0309_00 662 | scene0309_01 663 | scene0310_00 664 | scene0310_01 665 | scene0310_02 666 | scene0311_00 667 | scene0312_00 668 | scene0312_01 669 | scene0312_02 670 | scene0313_00 671 | scene0313_01 672 | scene0313_02 673 | scene0314_00 674 | scene0315_00 675 | scene0316_00 676 | scene0317_00 677 | scene0317_01 678 | scene0318_00 679 | scene0319_00 680 | scene0320_00 681 | scene0320_01 682 | scene0320_02 683 | scene0320_03 684 | scene0321_00 685 | scene0322_00 686 | scene0323_00 687 | scene0323_01 688 | scene0324_00 689 | scene0324_01 690 | scene0325_00 691 | scene0325_01 692 | 
scene0326_00 693 | scene0327_00 694 | scene0328_00 695 | scene0329_00 696 | scene0329_01 697 | scene0329_02 698 | scene0330_00 699 | scene0331_00 700 | scene0331_01 701 | scene0332_00 702 | scene0332_01 703 | scene0332_02 704 | scene0333_00 705 | scene0334_00 706 | scene0334_01 707 | scene0334_02 708 | scene0335_00 709 | scene0335_01 710 | scene0335_02 711 | scene0336_00 712 | scene0336_01 713 | scene0337_00 714 | scene0337_01 715 | scene0337_02 716 | scene0338_00 717 | scene0338_01 718 | scene0338_02 719 | scene0339_00 720 | scene0340_00 721 | scene0340_01 722 | scene0340_02 723 | scene0341_00 724 | scene0341_01 725 | scene0342_00 726 | scene0343_00 727 | scene0344_00 728 | scene0344_01 729 | scene0345_00 730 | scene0345_01 731 | scene0346_00 732 | scene0346_01 733 | scene0347_00 734 | scene0347_01 735 | scene0347_02 736 | scene0348_00 737 | scene0348_01 738 | scene0348_02 739 | scene0349_00 740 | scene0349_01 741 | scene0350_00 742 | scene0350_01 743 | scene0350_02 744 | scene0351_00 745 | scene0351_01 746 | scene0352_00 747 | scene0352_01 748 | scene0352_02 749 | scene0353_00 750 | scene0353_01 751 | scene0353_02 752 | scene0354_00 753 | scene0355_00 754 | scene0355_01 755 | scene0356_00 756 | scene0356_01 757 | scene0356_02 758 | scene0357_00 759 | scene0357_01 760 | scene0358_00 761 | scene0358_01 762 | scene0358_02 763 | scene0359_00 764 | scene0359_01 765 | scene0360_00 766 | scene0361_00 767 | scene0361_01 768 | scene0361_02 769 | scene0362_00 770 | scene0362_01 771 | scene0362_02 772 | scene0362_03 773 | scene0363_00 774 | scene0364_00 775 | scene0364_01 776 | scene0365_00 777 | scene0365_01 778 | scene0365_02 779 | scene0366_00 780 | scene0367_00 781 | scene0367_01 782 | scene0368_00 783 | scene0368_01 784 | scene0369_00 785 | scene0369_01 786 | scene0369_02 787 | scene0370_00 788 | scene0370_01 789 | scene0370_02 790 | scene0371_00 791 | scene0371_01 792 | scene0372_00 793 | scene0373_00 794 | scene0373_01 795 | scene0374_00 796 | scene0375_00 797 | 
scene0375_01 798 | scene0375_02 799 | scene0376_00 800 | scene0376_01 801 | scene0376_02 802 | scene0377_00 803 | scene0377_01 804 | scene0377_02 805 | scene0378_00 806 | scene0378_01 807 | scene0378_02 808 | scene0379_00 809 | scene0380_00 810 | scene0380_01 811 | scene0380_02 812 | scene0381_00 813 | scene0381_01 814 | scene0381_02 815 | scene0382_00 816 | scene0382_01 817 | scene0383_00 818 | scene0383_01 819 | scene0383_02 820 | scene0384_00 821 | scene0385_00 822 | scene0385_01 823 | scene0385_02 824 | scene0386_00 825 | scene0387_00 826 | scene0387_01 827 | scene0387_02 828 | scene0388_00 829 | scene0388_01 830 | scene0389_00 831 | scene0390_00 832 | scene0391_00 833 | scene0392_00 834 | scene0392_01 835 | scene0392_02 836 | scene0393_00 837 | scene0393_01 838 | scene0393_02 839 | scene0394_00 840 | scene0394_01 841 | scene0395_00 842 | scene0395_01 843 | scene0395_02 844 | scene0396_00 845 | scene0396_01 846 | scene0396_02 847 | scene0397_00 848 | scene0397_01 849 | scene0398_00 850 | scene0398_01 851 | scene0399_00 852 | scene0399_01 853 | scene0400_00 854 | scene0400_01 855 | scene0401_00 856 | scene0402_00 857 | scene0403_00 858 | scene0403_01 859 | scene0404_00 860 | scene0404_01 861 | scene0404_02 862 | scene0405_00 863 | scene0406_00 864 | scene0406_01 865 | scene0406_02 866 | scene0407_00 867 | scene0407_01 868 | scene0408_00 869 | scene0408_01 870 | scene0409_00 871 | scene0409_01 872 | scene0410_00 873 | scene0410_01 874 | scene0411_00 875 | scene0411_01 876 | scene0411_02 877 | scene0412_00 878 | scene0412_01 879 | scene0413_00 880 | scene0414_00 881 | scene0415_00 882 | scene0415_01 883 | scene0415_02 884 | scene0416_00 885 | scene0416_01 886 | scene0416_02 887 | scene0416_03 888 | scene0416_04 889 | scene0417_00 890 | scene0418_00 891 | scene0418_01 892 | scene0418_02 893 | scene0419_00 894 | scene0419_01 895 | scene0419_02 896 | scene0420_00 897 | scene0420_01 898 | scene0420_02 899 | scene0421_00 900 | scene0421_01 901 | scene0421_02 902 | 
scene0422_00 903 | scene0423_00 904 | scene0423_01 905 | scene0423_02 906 | scene0424_00 907 | scene0424_01 908 | scene0424_02 909 | scene0425_00 910 | scene0425_01 911 | scene0426_00 912 | scene0426_01 913 | scene0426_02 914 | scene0426_03 915 | scene0427_00 916 | scene0428_00 917 | scene0428_01 918 | scene0429_00 919 | scene0430_00 920 | scene0430_01 921 | scene0431_00 922 | scene0432_00 923 | scene0432_01 924 | scene0433_00 925 | scene0434_00 926 | scene0434_01 927 | scene0434_02 928 | scene0435_00 929 | scene0435_01 930 | scene0435_02 931 | scene0435_03 932 | scene0436_00 933 | scene0437_00 934 | scene0437_01 935 | scene0438_00 936 | scene0439_00 937 | scene0439_01 938 | scene0440_00 939 | scene0440_01 940 | scene0440_02 941 | scene0441_00 942 | scene0442_00 943 | scene0443_00 944 | scene0444_00 945 | scene0444_01 946 | scene0445_00 947 | scene0445_01 948 | scene0446_00 949 | scene0446_01 950 | scene0447_00 951 | scene0447_01 952 | scene0447_02 953 | scene0448_00 954 | scene0448_01 955 | scene0448_02 956 | scene0449_00 957 | scene0449_01 958 | scene0449_02 959 | scene0450_00 960 | scene0451_00 961 | scene0451_01 962 | scene0451_02 963 | scene0451_03 964 | scene0451_04 965 | scene0451_05 966 | scene0452_00 967 | scene0452_01 968 | scene0452_02 969 | scene0453_00 970 | scene0453_01 971 | scene0454_00 972 | scene0455_00 973 | scene0456_00 974 | scene0456_01 975 | scene0457_00 976 | scene0457_01 977 | scene0457_02 978 | scene0458_00 979 | scene0458_01 980 | scene0459_00 981 | scene0459_01 982 | scene0460_00 983 | scene0461_00 984 | scene0462_00 985 | scene0463_00 986 | scene0463_01 987 | scene0464_00 988 | scene0465_00 989 | scene0465_01 990 | scene0466_00 991 | scene0466_01 992 | scene0467_00 993 | scene0468_00 994 | scene0468_01 995 | scene0468_02 996 | scene0469_00 997 | scene0469_01 998 | scene0469_02 999 | scene0470_00 1000 | scene0470_01 1001 | scene0471_00 1002 | scene0471_01 1003 | scene0471_02 1004 | scene0472_00 1005 | scene0472_01 1006 | scene0472_02 
1007 | scene0473_00 1008 | scene0473_01 1009 | scene0474_00 1010 | scene0474_01 1011 | scene0474_02 1012 | scene0474_03 1013 | scene0474_04 1014 | scene0474_05 1015 | scene0475_00 1016 | scene0475_01 1017 | scene0475_02 1018 | scene0476_00 1019 | scene0476_01 1020 | scene0476_02 1021 | scene0477_00 1022 | scene0477_01 1023 | scene0478_00 1024 | scene0478_01 1025 | scene0479_00 1026 | scene0479_01 1027 | scene0479_02 1028 | scene0480_00 1029 | scene0480_01 1030 | scene0481_00 1031 | scene0481_01 1032 | scene0482_00 1033 | scene0482_01 1034 | scene0483_00 1035 | scene0484_00 1036 | scene0484_01 1037 | scene0485_00 1038 | scene0486_00 1039 | scene0487_00 1040 | scene0487_01 1041 | scene0488_00 1042 | scene0488_01 1043 | scene0489_00 1044 | scene0489_01 1045 | scene0489_02 1046 | scene0490_00 1047 | scene0491_00 1048 | scene0492_00 1049 | scene0492_01 1050 | scene0493_00 1051 | scene0493_01 1052 | scene0494_00 1053 | scene0495_00 1054 | scene0496_00 1055 | scene0497_00 1056 | scene0498_00 1057 | scene0498_01 1058 | scene0498_02 1059 | scene0499_00 1060 | scene0500_00 1061 | scene0500_01 1062 | scene0501_00 1063 | scene0501_01 1064 | scene0501_02 1065 | scene0502_00 1066 | scene0502_01 1067 | scene0502_02 1068 | scene0503_00 1069 | scene0504_00 1070 | scene0505_00 1071 | scene0505_01 1072 | scene0505_02 1073 | scene0505_03 1074 | scene0505_04 1075 | scene0506_00 1076 | scene0507_00 1077 | scene0508_00 1078 | scene0508_01 1079 | scene0508_02 1080 | scene0509_00 1081 | scene0509_01 1082 | scene0509_02 1083 | scene0510_00 1084 | scene0510_01 1085 | scene0510_02 1086 | scene0511_00 1087 | scene0511_01 1088 | scene0512_00 1089 | scene0513_00 1090 | scene0514_00 1091 | scene0514_01 1092 | scene0515_00 1093 | scene0515_01 1094 | scene0515_02 1095 | scene0516_00 1096 | scene0516_01 1097 | scene0517_00 1098 | scene0517_01 1099 | scene0517_02 1100 | scene0518_00 1101 | scene0519_00 1102 | scene0520_00 1103 | scene0520_01 1104 | scene0521_00 1105 | scene0522_00 1106 | scene0523_00 
1107 | scene0523_01 1108 | scene0523_02 1109 | scene0524_00 1110 | scene0524_01 1111 | scene0525_00 1112 | scene0525_01 1113 | scene0525_02 1114 | scene0526_00 1115 | scene0526_01 1116 | scene0527_00 1117 | scene0528_00 1118 | scene0528_01 1119 | scene0529_00 1120 | scene0529_01 1121 | scene0529_02 1122 | scene0530_00 1123 | scene0531_00 1124 | scene0532_00 1125 | scene0532_01 1126 | scene0533_00 1127 | scene0533_01 1128 | scene0534_00 1129 | scene0534_01 1130 | scene0535_00 1131 | scene0536_00 1132 | scene0536_01 1133 | scene0536_02 1134 | scene0537_00 1135 | scene0538_00 1136 | scene0539_00 1137 | scene0539_01 1138 | scene0539_02 1139 | scene0540_00 1140 | scene0540_01 1141 | scene0540_02 1142 | scene0541_00 1143 | scene0541_01 1144 | scene0541_02 1145 | scene0542_00 1146 | scene0543_00 1147 | scene0543_01 1148 | scene0543_02 1149 | scene0544_00 1150 | scene0545_00 1151 | scene0545_01 1152 | scene0545_02 1153 | scene0546_00 1154 | scene0547_00 1155 | scene0547_01 1156 | scene0547_02 1157 | scene0548_00 1158 | scene0548_01 1159 | scene0548_02 1160 | scene0549_00 1161 | scene0549_01 1162 | scene0550_00 1163 | scene0551_00 1164 | scene0552_00 1165 | scene0552_01 1166 | scene0553_00 1167 | scene0553_01 1168 | scene0553_02 1169 | scene0554_00 1170 | scene0554_01 1171 | scene0555_00 1172 | scene0556_00 1173 | scene0556_01 1174 | scene0557_00 1175 | scene0557_01 1176 | scene0557_02 1177 | scene0558_00 1178 | scene0558_01 1179 | scene0558_02 1180 | scene0559_00 1181 | scene0559_01 1182 | scene0559_02 1183 | scene0560_00 1184 | scene0561_00 1185 | scene0561_01 1186 | scene0562_00 1187 | scene0563_00 1188 | scene0564_00 1189 | scene0565_00 1190 | scene0566_00 1191 | scene0567_00 1192 | scene0567_01 1193 | scene0568_00 1194 | scene0568_01 1195 | scene0568_02 1196 | scene0569_00 1197 | scene0569_01 1198 | scene0570_00 1199 | scene0570_01 1200 | scene0570_02 1201 | scene0571_00 1202 | scene0571_01 1203 | scene0572_00 1204 | scene0572_01 1205 | scene0572_02 1206 | scene0573_00 
1207 | scene0573_01 1208 | scene0574_00 1209 | scene0574_01 1210 | scene0574_02 1211 | scene0575_00 1212 | scene0575_01 1213 | scene0575_02 1214 | scene0576_00 1215 | scene0576_01 1216 | scene0576_02 1217 | scene0577_00 1218 | scene0578_00 1219 | scene0578_01 1220 | scene0578_02 1221 | scene0579_00 1222 | scene0579_01 1223 | scene0579_02 1224 | scene0580_00 1225 | scene0580_01 1226 | scene0581_00 1227 | scene0581_01 1228 | scene0581_02 1229 | scene0582_00 1230 | scene0582_01 1231 | scene0582_02 1232 | scene0583_00 1233 | scene0583_01 1234 | scene0583_02 1235 | scene0584_00 1236 | scene0584_01 1237 | scene0584_02 1238 | scene0585_00 1239 | scene0585_01 1240 | scene0586_00 1241 | scene0586_01 1242 | scene0586_02 1243 | scene0587_00 1244 | scene0587_01 1245 | scene0587_02 1246 | scene0587_03 1247 | scene0588_00 1248 | scene0588_01 1249 | scene0588_02 1250 | scene0588_03 1251 | scene0589_00 1252 | scene0589_01 1253 | scene0589_02 1254 | scene0590_00 1255 | scene0590_01 1256 | scene0591_00 1257 | scene0591_01 1258 | scene0591_02 1259 | scene0592_00 1260 | scene0592_01 1261 | scene0593_00 1262 | scene0593_01 1263 | scene0594_00 1264 | scene0595_00 1265 | scene0596_00 1266 | scene0596_01 1267 | scene0596_02 1268 | scene0597_00 1269 | scene0597_01 1270 | scene0597_02 1271 | scene0598_00 1272 | scene0598_01 1273 | scene0598_02 1274 | scene0599_00 1275 | scene0599_01 1276 | scene0599_02 1277 | scene0600_00 1278 | scene0600_01 1279 | scene0600_02 1280 | scene0601_00 1281 | scene0601_01 1282 | scene0602_00 1283 | scene0603_00 1284 | scene0603_01 1285 | scene0604_00 1286 | scene0604_01 1287 | scene0604_02 1288 | scene0605_00 1289 | scene0605_01 1290 | scene0606_00 1291 | scene0606_01 1292 | scene0606_02 1293 | scene0607_00 1294 | scene0607_01 1295 | scene0608_00 1296 | scene0608_01 1297 | scene0608_02 1298 | scene0609_00 1299 | scene0609_01 1300 | scene0609_02 1301 | scene0609_03 1302 | scene0610_00 1303 | scene0610_01 1304 | scene0610_02 1305 | scene0611_00 1306 | scene0611_01 
1307 | scene0612_00 1308 | scene0612_01 1309 | scene0613_00 1310 | scene0613_01 1311 | scene0613_02 1312 | scene0614_00 1313 | scene0614_01 1314 | scene0614_02 1315 | scene0615_00 1316 | scene0615_01 1317 | scene0616_00 1318 | scene0616_01 1319 | scene0617_00 1320 | scene0618_00 1321 | scene0619_00 1322 | scene0620_00 1323 | scene0620_01 1324 | scene0621_00 1325 | scene0622_00 1326 | scene0622_01 1327 | scene0623_00 1328 | scene0623_01 1329 | scene0624_00 1330 | scene0625_00 1331 | scene0625_01 1332 | scene0626_00 1333 | scene0626_01 1334 | scene0626_02 1335 | scene0627_00 1336 | scene0627_01 1337 | scene0628_00 1338 | scene0628_01 1339 | scene0628_02 1340 | scene0629_00 1341 | scene0629_01 1342 | scene0629_02 1343 | scene0630_00 1344 | scene0630_01 1345 | scene0630_02 1346 | scene0630_03 1347 | scene0630_04 1348 | scene0630_05 1349 | scene0630_06 1350 | scene0631_00 1351 | scene0631_01 1352 | scene0631_02 1353 | scene0632_00 1354 | scene0633_00 1355 | scene0633_01 1356 | scene0634_00 1357 | scene0635_00 1358 | scene0635_01 1359 | scene0636_00 1360 | scene0637_00 1361 | scene0638_00 1362 | scene0639_00 1363 | scene0640_00 1364 | scene0640_01 1365 | scene0640_02 1366 | scene0641_00 1367 | scene0642_00 1368 | scene0642_01 1369 | scene0642_02 1370 | scene0642_03 1371 | scene0643_00 1372 | scene0644_00 1373 | scene0645_00 1374 | scene0645_01 1375 | scene0645_02 1376 | scene0646_00 1377 | scene0646_01 1378 | scene0646_02 1379 | scene0647_00 1380 | scene0647_01 1381 | scene0648_00 1382 | scene0648_01 1383 | scene0649_00 1384 | scene0649_01 1385 | scene0650_00 1386 | scene0651_00 1387 | scene0651_01 1388 | scene0651_02 1389 | scene0652_00 1390 | scene0653_00 1391 | scene0653_01 1392 | scene0654_00 1393 | scene0654_01 1394 | scene0655_00 1395 | scene0655_01 1396 | scene0655_02 1397 | scene0656_00 1398 | scene0656_01 1399 | scene0656_02 1400 | scene0656_03 1401 | scene0657_00 1402 | scene0658_00 1403 | scene0659_00 1404 | scene0659_01 1405 | scene0660_00 1406 | scene0661_00 
1407 | scene0662_00 1408 | scene0662_01 1409 | scene0662_02 1410 | scene0663_00 1411 | scene0663_01 1412 | scene0663_02 1413 | scene0664_00 1414 | scene0664_01 1415 | scene0664_02 1416 | scene0665_00 1417 | scene0665_01 1418 | scene0666_00 1419 | scene0666_01 1420 | scene0666_02 1421 | scene0667_00 1422 | scene0667_01 1423 | scene0667_02 1424 | scene0668_00 1425 | scene0669_00 1426 | scene0669_01 1427 | scene0670_00 1428 | scene0670_01 1429 | scene0671_00 1430 | scene0671_01 1431 | scene0672_00 1432 | scene0672_01 1433 | scene0673_00 1434 | scene0673_01 1435 | scene0673_02 1436 | scene0673_03 1437 | scene0673_04 1438 | scene0673_05 1439 | scene0674_00 1440 | scene0674_01 1441 | scene0675_00 1442 | scene0675_01 1443 | scene0676_00 1444 | scene0676_01 1445 | scene0677_00 1446 | scene0677_01 1447 | scene0677_02 1448 | scene0678_00 1449 | scene0678_01 1450 | scene0678_02 1451 | scene0679_00 1452 | scene0679_01 1453 | scene0680_00 1454 | scene0680_01 1455 | scene0681_00 1456 | scene0682_00 1457 | scene0683_00 1458 | scene0684_00 1459 | scene0684_01 1460 | scene0685_00 1461 | scene0685_01 1462 | scene0685_02 1463 | scene0686_00 1464 | scene0686_01 1465 | scene0686_02 1466 | scene0687_00 1467 | scene0688_00 1468 | scene0689_00 1469 | scene0690_00 1470 | scene0690_01 1471 | scene0691_00 1472 | scene0691_01 1473 | scene0692_00 1474 | scene0692_01 1475 | scene0692_02 1476 | scene0692_03 1477 | scene0692_04 1478 | scene0693_00 1479 | scene0693_01 1480 | scene0693_02 1481 | scene0694_00 1482 | scene0694_01 1483 | scene0695_00 1484 | scene0695_01 1485 | scene0695_02 1486 | scene0695_03 1487 | scene0696_00 1488 | scene0696_01 1489 | scene0696_02 1490 | scene0697_00 1491 | scene0697_01 1492 | scene0697_02 1493 | scene0697_03 1494 | scene0698_00 1495 | scene0698_01 1496 | scene0699_00 1497 | scene0700_00 1498 | scene0700_01 1499 | scene0700_02 1500 | scene0701_00 1501 | scene0701_01 1502 | scene0701_02 1503 | scene0702_00 1504 | scene0702_01 1505 | scene0702_02 1506 | scene0703_00 
1507 | scene0703_01 1508 | scene0704_00 1509 | scene0704_01 1510 | scene0705_00 1511 | scene0705_01 1512 | scene0705_02 1513 | scene0706_00 1514 | -------------------------------------------------------------------------------- /src/data/meta_data/scannet/scannetv2_val.txt: -------------------------------------------------------------------------------- 1 | scene0568_00 2 | scene0568_01 3 | scene0568_02 4 | scene0304_00 5 | scene0488_00 6 | scene0488_01 7 | scene0412_00 8 | scene0412_01 9 | scene0217_00 10 | scene0019_00 11 | scene0019_01 12 | scene0414_00 13 | scene0575_00 14 | scene0575_01 15 | scene0575_02 16 | scene0426_00 17 | scene0426_01 18 | scene0426_02 19 | scene0426_03 20 | scene0549_00 21 | scene0549_01 22 | scene0578_00 23 | scene0578_01 24 | scene0578_02 25 | scene0665_00 26 | scene0665_01 27 | scene0050_00 28 | scene0050_01 29 | scene0050_02 30 | scene0257_00 31 | scene0025_00 32 | scene0025_01 33 | scene0025_02 34 | scene0583_00 35 | scene0583_01 36 | scene0583_02 37 | scene0701_00 38 | scene0701_01 39 | scene0701_02 40 | scene0580_00 41 | scene0580_01 42 | scene0565_00 43 | scene0169_00 44 | scene0169_01 45 | scene0655_00 46 | scene0655_01 47 | scene0655_02 48 | scene0063_00 49 | scene0221_00 50 | scene0221_01 51 | scene0591_00 52 | scene0591_01 53 | scene0591_02 54 | scene0678_00 55 | scene0678_01 56 | scene0678_02 57 | scene0462_00 58 | scene0427_00 59 | scene0595_00 60 | scene0193_00 61 | scene0193_01 62 | scene0164_00 63 | scene0164_01 64 | scene0164_02 65 | scene0164_03 66 | scene0598_00 67 | scene0598_01 68 | scene0598_02 69 | scene0599_00 70 | scene0599_01 71 | scene0599_02 72 | scene0328_00 73 | scene0300_00 74 | scene0300_01 75 | scene0354_00 76 | scene0458_00 77 | scene0458_01 78 | scene0423_00 79 | scene0423_01 80 | scene0423_02 81 | scene0307_00 82 | scene0307_01 83 | scene0307_02 84 | scene0606_00 85 | scene0606_01 86 | scene0606_02 87 | scene0432_00 88 | scene0432_01 89 | scene0608_00 90 | scene0608_01 91 | scene0608_02 92 | 
scene0651_00 93 | scene0651_01 94 | scene0651_02 95 | scene0430_00 96 | scene0430_01 97 | scene0689_00 98 | scene0357_00 99 | scene0357_01 100 | scene0574_00 101 | scene0574_01 102 | scene0574_02 103 | scene0329_00 104 | scene0329_01 105 | scene0329_02 106 | scene0153_00 107 | scene0153_01 108 | scene0616_00 109 | scene0616_01 110 | scene0671_00 111 | scene0671_01 112 | scene0618_00 113 | scene0382_00 114 | scene0382_01 115 | scene0490_00 116 | scene0621_00 117 | scene0607_00 118 | scene0607_01 119 | scene0149_00 120 | scene0695_00 121 | scene0695_01 122 | scene0695_02 123 | scene0695_03 124 | scene0389_00 125 | scene0377_00 126 | scene0377_01 127 | scene0377_02 128 | scene0342_00 129 | scene0139_00 130 | scene0629_00 131 | scene0629_01 132 | scene0629_02 133 | scene0496_00 134 | scene0633_00 135 | scene0633_01 136 | scene0518_00 137 | scene0652_00 138 | scene0406_00 139 | scene0406_01 140 | scene0406_02 141 | scene0144_00 142 | scene0144_01 143 | scene0494_00 144 | scene0278_00 145 | scene0278_01 146 | scene0316_00 147 | scene0609_00 148 | scene0609_01 149 | scene0609_02 150 | scene0609_03 151 | scene0084_00 152 | scene0084_01 153 | scene0084_02 154 | scene0696_00 155 | scene0696_01 156 | scene0696_02 157 | scene0351_00 158 | scene0351_01 159 | scene0643_00 160 | scene0644_00 161 | scene0645_00 162 | scene0645_01 163 | scene0645_02 164 | scene0081_00 165 | scene0081_01 166 | scene0081_02 167 | scene0647_00 168 | scene0647_01 169 | scene0535_00 170 | scene0353_00 171 | scene0353_01 172 | scene0353_02 173 | scene0559_00 174 | scene0559_01 175 | scene0559_02 176 | scene0593_00 177 | scene0593_01 178 | scene0246_00 179 | scene0653_00 180 | scene0653_01 181 | scene0064_00 182 | scene0064_01 183 | scene0356_00 184 | scene0356_01 185 | scene0356_02 186 | scene0030_00 187 | scene0030_01 188 | scene0030_02 189 | scene0222_00 190 | scene0222_01 191 | scene0338_00 192 | scene0338_01 193 | scene0338_02 194 | scene0378_00 195 | scene0378_01 196 | scene0378_02 197 | 
scene0660_00 198 | scene0553_00 199 | scene0553_01 200 | scene0553_02 201 | scene0527_00 202 | scene0663_00 203 | scene0663_01 204 | scene0663_02 205 | scene0664_00 206 | scene0664_01 207 | scene0664_02 208 | scene0334_00 209 | scene0334_01 210 | scene0334_02 211 | scene0046_00 212 | scene0046_01 213 | scene0046_02 214 | scene0203_00 215 | scene0203_01 216 | scene0203_02 217 | scene0088_00 218 | scene0088_01 219 | scene0088_02 220 | scene0088_03 221 | scene0086_00 222 | scene0086_01 223 | scene0086_02 224 | scene0670_00 225 | scene0670_01 226 | scene0256_00 227 | scene0256_01 228 | scene0256_02 229 | scene0249_00 230 | scene0441_00 231 | scene0658_00 232 | scene0704_00 233 | scene0704_01 234 | scene0187_00 235 | scene0187_01 236 | scene0131_00 237 | scene0131_01 238 | scene0131_02 239 | scene0207_00 240 | scene0207_01 241 | scene0207_02 242 | scene0461_00 243 | scene0011_00 244 | scene0011_01 245 | scene0343_00 246 | scene0251_00 247 | scene0077_00 248 | scene0077_01 249 | scene0684_00 250 | scene0684_01 251 | scene0550_00 252 | scene0686_00 253 | scene0686_01 254 | scene0686_02 255 | scene0208_00 256 | scene0500_00 257 | scene0500_01 258 | scene0552_00 259 | scene0552_01 260 | scene0648_00 261 | scene0648_01 262 | scene0435_00 263 | scene0435_01 264 | scene0435_02 265 | scene0435_03 266 | scene0690_00 267 | scene0690_01 268 | scene0693_00 269 | scene0693_01 270 | scene0693_02 271 | scene0700_00 272 | scene0700_01 273 | scene0700_02 274 | scene0699_00 275 | scene0231_00 276 | scene0231_01 277 | scene0231_02 278 | scene0697_00 279 | scene0697_01 280 | scene0697_02 281 | scene0697_03 282 | scene0474_00 283 | scene0474_01 284 | scene0474_02 285 | scene0474_03 286 | scene0474_04 287 | scene0474_05 288 | scene0355_00 289 | scene0355_01 290 | scene0146_00 291 | scene0146_01 292 | scene0146_02 293 | scene0196_00 294 | scene0702_00 295 | scene0702_01 296 | scene0702_02 297 | scene0314_00 298 | scene0277_00 299 | scene0277_01 300 | scene0277_02 301 | scene0095_00 302 | 
scene0095_01 303 | scene0015_00 304 | scene0100_00 305 | scene0100_01 306 | scene0100_02 307 | scene0558_00 308 | scene0558_01 309 | scene0558_02 310 | scene0685_00 311 | scene0685_01 312 | scene0685_02 313 | -------------------------------------------------------------------------------- /src/data/preprocess_s3dis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import glob 4 | 5 | import numpy as np 6 | from tqdm import tqdm 7 | from plyfile import PlyData, PlyElement 8 | 9 | 10 | STANFORD_3D_IN_PATH = '/root/data/s3dis/Stanford3dDataset_v1.2/' # you may need to modify this path. 11 | STANFORD_3D_OUT_PATH = '/root/data/s3dis/s3dis_processed' # you may need to modify this path. 12 | IGNORE_LABEL = 255 13 | 14 | 15 | def mkdir_p(path): 16 | try: 17 | os.makedirs(path) 18 | except OSError as exc: 19 | if exc.errno == errno.EEXIST and os.path.isdir(path): 20 | pass 21 | else: 22 | raise 23 | 24 | 25 | def save_point_cloud(points_3d, filename, binary=True, with_label=False, verbose=True): 26 | """Save an RGB point cloud as a PLY file. 27 | Args: 28 | points_3d: Nx6 matrix where points_3d[:, :3] are the XYZ coordinates and points_3d[:, 3:6] are 29 | the RGB values. If Nx3 matrix, save all points with [128, 128, 128] (gray) color. 
30 | """ 31 | assert points_3d.ndim == 2 32 | if with_label: 33 | assert points_3d.shape[1] == 7 34 | python_types = (float, float, float, int, int, int, int) 35 | npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), 36 | ('blue', 'u1'), ('label', 'u1')] 37 | else: 38 | if points_3d.shape[1] == 3: 39 | gray_concat = np.tile(np.array([128], dtype=np.uint8), (points_3d.shape[0], 3)) 40 | points_3d = np.hstack((points_3d, gray_concat)) 41 | assert points_3d.shape[1] == 6 42 | python_types = (float, float, float, int, int, int) 43 | npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), 44 | ('blue', 'u1')] 45 | if binary: 46 | # Format into NumPy structured array 47 | vertices = [] 48 | for row_idx in range(points_3d.shape[0]): 49 | cur_point = points_3d[row_idx] 50 | vertices.append(tuple(dtype(point) for dtype, point in zip(python_types, cur_point))) 51 | vertices_array = np.array(vertices, dtype=npy_types) 52 | el = PlyElement.describe(vertices_array, 'vertex') 53 | 54 | # Write 55 | PlyData([el]).write(filename) 56 | else: 57 | raise NotImplementedError 58 | if verbose is True: 59 | print('Saved point cloud to: %s' % filename) 60 | 61 | 62 | class Stanford3DDatasetConverter: 63 | 64 | CLASSES = [ 65 | 'ceiling', 'floor', 'wall', 'beam', 'column', 66 | 'window', 'door', 'chair', 'table', 'bookcase', 67 | 'sofa', 'board', 'clutter' 68 | ] 69 | TRAIN_TEXT = 'train' 70 | VAL_TEXT = 'val' 71 | TEST_TEXT = 'test' 72 | 73 | @classmethod 74 | def read_txt(cls, txtfile): 75 | # Read txt file and parse its content. 
76 | obj_name = txtfile.split('/')[-1] 77 | if obj_name == 'ceiling_1.txt': 78 | with open(txtfile, 'r') as f: 79 | lines = f.readlines() 80 | for l_i, line in enumerate(lines): 81 | # https://github.com/zeliu98/CloserLook3D/issues/15 82 | if '103.0\x100000' in line: 83 | print(f'Fix bug in {txtfile}') 84 | print(f'Bug line: {line}') 85 | lines[l_i] = line.replace('103.0\x100000', '103.000000') 86 | with open(txtfile, 'w') as f: 87 | f.writelines(lines) 88 | try: 89 | pointcloud = np.loadtxt(txtfile, dtype=np.float32) 90 | except Exception: 91 | print('Bug!!!!!!!!!!!!!!!!!!!!!') 92 | print(obj_name) 93 | print(txtfile) 94 | raise # `pointcloud` is unbound here; continuing would fail with a NameError below. 95 | # Load point cloud to named numpy array. 96 | pointcloud = np.array(pointcloud).astype(np.float32) 97 | assert pointcloud.shape[1] == 6 98 | xyz = pointcloud[:, :3].astype(np.float32) 99 | rgb = pointcloud[:, 3:].astype(np.uint8) 100 | return xyz, rgb 101 | 102 | @classmethod 103 | def convert_to_ply(cls, root_path, out_path): 104 | """Convert Stanford3DDataset to PLY format that is compatible with 105 | Synthia dataset. Assumes file structure as given by the dataset. 106 | Outputs the processed PLY files to `STANFORD_3D_OUT_PATH`. 107 | """ 108 | 109 | txtfiles = glob.glob(os.path.join(root_path, '*/*/*.txt')) 110 | for txtfile in tqdm(txtfiles): 111 | area_name = txtfile.split('/')[-3] 112 | file_sp = os.path.normpath(txtfile).split(os.path.sep) 113 | target_path = os.path.join(out_path, file_sp[-3]) 114 | out_file = os.path.join(target_path, file_sp[-2] + '.ply') 115 | 116 | if os.path.exists(out_file): 117 | print(out_file, ' exists') 118 | continue 119 | 120 | annotation, _ = os.path.split(txtfile) 121 | subclouds = glob.glob(os.path.join(annotation, 'Annotations/*.txt')) 122 | coords, feats, labels = [], [], [] 123 | for inst, subcloud in enumerate(subclouds): 124 | # Read the annotation txt file and parse its XYZ and RGB values. 
125 | xyz, rgb = cls.read_txt(subcloud) 126 | _, annotation_subfile = os.path.split(subcloud) 127 | clsname = annotation_subfile.split('_')[0] 128 | # https://github.com/chrischoy/SpatioTemporalSegmentation/blob/4afee296ebe387d9a06fc1b168c4af212a2b4804/lib/datasets/stanford.py#L20 129 | if clsname == 'stairs': 130 | print('Ignore stairs') 131 | clsidx = IGNORE_LABEL 132 | else: 133 | clsidx = cls.CLASSES.index(clsname) 134 | 135 | coords.append(xyz) 136 | feats.append(rgb) 137 | labels.append(np.full((len(xyz), 1), clsidx, dtype=np.int32)) 138 | 139 | if len(coords) == 0: 140 | print(txtfile, ' has 0 files.') 141 | else: 142 | # Concat 143 | coords = np.concatenate(coords, 0) 144 | feats = np.concatenate(feats, 0) 145 | labels = np.concatenate(labels, 0) 146 | 147 | pointcloud = np.concatenate((coords, feats, labels), axis=1) 148 | 149 | # Write ply file. 150 | mkdir_p(target_path) 151 | save_point_cloud(pointcloud, out_file, with_label=True, verbose=False) 152 | 153 | 154 | if __name__ == '__main__': 155 | Stanford3DDatasetConverter.convert_to_ply(STANFORD_3D_IN_PATH, STANFORD_3D_OUT_PATH) 156 | -------------------------------------------------------------------------------- /src/data/preprocess_scannet.py: -------------------------------------------------------------------------------- 1 | import json 2 | from concurrent.futures import ProcessPoolExecutor 3 | from pathlib import Path 4 | 5 | import pandas as pd 6 | import numpy as np 7 | from plyfile import PlyData, PlyElement 8 | 9 | 10 | SCANNET_RAW_PATH = Path('/root/data/scannetv2_raw') # you may need to modify this path. 11 | SCANNET_OUT_PATH = Path('/root/data/scannet_processed') # you may need to modify this path. 
12 | TRAIN_DEST = 'train' 13 | TEST_DEST = 'test' 14 | SUBSETS = {TRAIN_DEST: 'scans', TEST_DEST: 'scans_test'} 15 | POINTCLOUD_FILE = '_vh_clean_2.ply' 16 | BUGS = { 17 | 'train/scene0270_00.ply': 50, 18 | 'train/scene0270_02.ply': 50, 19 | 'train/scene0384_00.ply': 149, 20 | } # https://github.com/ScanNet/ScanNet/issues/20 21 | 22 | 23 | def read_plyfile(path): 24 | with open(path, 'rb') as f: 25 | data = PlyData.read(f) 26 | if data.elements: 27 | return pd.DataFrame(data.elements[0].data).values 28 | 29 | 30 | def save_point_cloud(points_3d, filename, verbose=True): 31 | assert points_3d.ndim == 2 32 | assert points_3d.shape[1] == 8 # x, y, z, r, g, b, semantic_label, instance_label 33 | 34 | python_types = (float, float, float, int, int, int, int, int) 35 | npy_types = [ 36 | ('x', 'f4'), 37 | ('y', 'f4'), 38 | ('z', 'f4'), 39 | ('red', 'u1'), 40 | ('green', 'u1'), 41 | ('blue', 'u1'), 42 | ('semantic_label', 'u1'), 43 | ('instance_label', 'u1') 44 | ] 45 | 46 | vertices = [] 47 | for row_idx in range(points_3d.shape[0]): 48 | cur_point = points_3d[row_idx] 49 | vertices.append(tuple(dtype(point) for dtype, point in zip(python_types, cur_point))) 50 | vertices_array = np.array(vertices, dtype=npy_types) 51 | el = PlyElement.describe(vertices_array, 'vertex') 52 | # Write 53 | PlyData([el]).write(filename) 54 | 55 | if verbose: 56 | print(f'Saved point cloud to: {filename}') 57 | 58 | 59 | def handle_process(paths): 60 | f = paths[0] 61 | phase_out_path = paths[1] 62 | out_f = phase_out_path / (f.name[:-len(POINTCLOUD_FILE)] + f.suffix) 63 | pointcloud = read_plyfile(f) 64 | num_points = pointcloud.shape[0] 65 | # Make sure alpha value is meaningless. 66 | assert np.unique(pointcloud[:, -1]).size == 1 67 | # Load label file. 
68 | segment_f = f.with_suffix('.0.010000.segs.json') 69 | segment_group_f = (f.parent / f.name[:-len(POINTCLOUD_FILE)]).with_suffix('.aggregation.json') 70 | semantic_f = f.parent / (f.stem + '.labels' + f.suffix) 71 | 72 | if semantic_f.is_file(): 73 | semantic_label = read_plyfile(semantic_f) 74 | # Sanity check that the point cloud and its labels have the same vertices. 75 | assert pointcloud.shape[0] == semantic_label.shape[0] 76 | assert np.allclose(pointcloud[:, :3], semantic_label[:, :3]) 77 | semantic_label = semantic_label[:, -1] 78 | # Load instance labels. 79 | with open(segment_f) as f: 80 | segment = np.array(json.load(f)['segIndices']) 81 | with open(segment_group_f) as f: 82 | segment_groups = json.load(f)['segGroups'] 83 | assert segment.size == num_points 84 | instance_label = np.zeros(num_points) 85 | for group_idx, segment_group in enumerate(segment_groups): 86 | for segment_idx in segment_group['segments']: 87 | instance_label[segment == segment_idx] = group_idx + 1 88 | else: # Labels may not exist for test scans. 89 | semantic_label = np.zeros(num_points) 90 | instance_label = np.zeros(num_points) 91 | 92 | processed = np.hstack((pointcloud[:, :6], semantic_label[:, None], instance_label[:, None])) 93 | save_point_cloud(processed, out_f, verbose=False) 94 | 95 | 96 | def main(): 97 | path_list = [] 98 | for out_path, in_path in SUBSETS.items(): 99 | phase_out_path = SCANNET_OUT_PATH / out_path 100 | phase_out_path.mkdir(parents=True, exist_ok=True) 101 | for f in (SCANNET_RAW_PATH / in_path).glob('*/*' + POINTCLOUD_FILE): 102 | path_list.append([f, phase_out_path]) 103 | 104 | pool = ProcessPoolExecutor(max_workers=20) 105 | result = list(pool.map(handle_process, path_list)) 106 | 107 | # Fix bug in the data. 
108 | for files, bug_index in BUGS.items(): 109 | print(files) 110 | 111 | for f in SCANNET_OUT_PATH.glob(files): 112 | pointcloud = read_plyfile(f) 113 | bug_mask = pointcloud[:, -2] == bug_index 114 | print(f'Fixing {f} bugged label {bug_index} x {bug_mask.sum()}...') 115 | pointcloud[bug_mask, -2] = 0 116 | save_point_cloud(pointcloud, f, verbose=False) 117 | 118 | 119 | if __name__ == '__main__': 120 | print('Preprocessing ScanNetV2 dataset...') 121 | main() 122 | -------------------------------------------------------------------------------- /src/data/s3dis_loader.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from typing import Optional 3 | 4 | import gin 5 | import numpy as np 6 | import torch 7 | import pytorch_lightning as pl 8 | 9 | from src.data.scannet_loader import read_ply 10 | from src.data.collate import CollationFunctionFactory 11 | from src.data.sampler import InfSampler 12 | import src.data.transforms as T 13 | 14 | CLASSES = [ 15 | 'ceiling', 16 | 'floor', 17 | 'wall', 18 | 'beam', 19 | 'column', 20 | 'window', 21 | 'door', 22 | 'chair', 23 | 'table', 24 | 'bookcase', 25 | 'sofa', 26 | 'board', 27 | 'clutter', 28 | ] 29 | 30 | 31 | def read_txt(path): 32 | """Read txt file into lines. 
33 | """ 34 | with open(path) as f: 35 | lines = f.readlines() 36 | lines = [x.strip() for x in lines] 37 | return lines 38 | 39 | 40 | @gin.configurable 41 | class S3DISArea5DatasetBase(torch.utils.data.Dataset): 42 | IN_CHANNELS = None 43 | NUM_CLASSES = 13 44 | SPLIT_FILES = { 45 | 'train': ['area1.txt', 'area2.txt', 'area3.txt', 'area4.txt', 'area6.txt'], 46 | 'val': ['area5.txt'], 47 | 'test': ['area5.txt'] 48 | } 49 | 50 | def __init__(self, phase, data_root, transform=None, ignore_label=255): 51 | assert self.IN_CHANNELS is not None 52 | assert phase in ['train', 'val', 'test'] 53 | super(S3DISArea5DatasetBase, self).__init__() 54 | 55 | self.phase = phase 56 | self.data_root = data_root 57 | self.transform = transform 58 | self.ignore_label = ignore_label 59 | self.split_files = self.SPLIT_FILES[phase] 60 | 61 | filenames = [] 62 | for split_file in self.split_files: 63 | filenames += read_txt(osp.join(self.data_root, 'meta_data', split_file)) 64 | self.filenames = [ 65 | osp.join(self.data_root, 's3dis_processed', fname) for fname in filenames 66 | ] 67 | 68 | def __len__(self): 69 | return len(self.filenames) 70 | 71 | def get_classnames(self): 72 | return CLASSES 73 | 74 | def __getitem__(self, idx): 75 | data = self._load_data(idx) 76 | coords, feats, labels = self.get_cfl_from_data(data) 77 | if self.transform is not None: 78 | coords, feats, labels = self.transform(coords, feats, labels) 79 | coords = torch.from_numpy(coords) 80 | feats = torch.from_numpy(feats) 81 | labels = torch.from_numpy(labels) 82 | return coords.float(), feats.float(), labels.long(), None 83 | 84 | def get_cfl_from_data(self, data): 85 | raise NotImplementedError 86 | 87 | def _load_data(self, idx): 88 | filename = self.filenames[idx] 89 | data = read_ply(filename) 90 | return data 91 | 92 | 93 | @gin.configurable 94 | class S3DISArea5RGBDataset(S3DISArea5DatasetBase): 95 | IN_CHANNELS = 3 96 | 97 | def __init__(self, phase, data_root, transform=None, ignore_label=255): 98 | 
super(S3DISArea5RGBDataset, self).__init__(phase, data_root, transform, ignore_label) 99 | 100 | def get_cfl_from_data(self, data): 101 | xyz, rgb, label = data[:, :3], data[:, 3:6], data[:, 6] 102 | return ( 103 | xyz.astype(np.float32), rgb.astype(np.float32), label.astype(np.int64) 104 | ) 105 | 106 | 107 | @gin.configurable 108 | class S3DISArea5RGBDataModule(pl.LightningDataModule): 109 | def __init__( 110 | self, 111 | data_root, 112 | train_batch_size, 113 | val_batch_size, 114 | train_num_workers, 115 | val_num_workers, 116 | collation_type, 117 | train_transforms, 118 | eval_transforms, 119 | ): 120 | super(S3DISArea5RGBDataModule, self).__init__() 121 | self.data_root = data_root 122 | self.train_batch_size = train_batch_size 123 | self.val_batch_size = val_batch_size 124 | self.train_num_workers = train_num_workers 125 | self.val_num_workers = val_num_workers 126 | self.collate_fn = CollationFunctionFactory(collation_type) 127 | self.train_transforms_ = train_transforms 128 | self.eval_transforms_ = eval_transforms 129 | 130 | def setup(self, stage: Optional[str] = None): 131 | if stage == "fit" or stage is None: 132 | train_transforms = [] 133 | if self.train_transforms_ is not None: 134 | for name in self.train_transforms_: 135 | train_transforms.append(getattr(T, name)()) 136 | train_transforms = T.Compose(train_transforms) 137 | self.dset_train = S3DISArea5RGBDataset("train", self.data_root, train_transforms) 138 | eval_transforms = [] 139 | if self.eval_transforms_ is not None: 140 | for name in self.eval_transforms_: 141 | eval_transforms.append(getattr(T, name)()) 142 | eval_transforms = T.Compose(eval_transforms) 143 | self.dset_val = S3DISArea5RGBDataset("val", self.data_root, eval_transforms) 144 | 145 | def train_dataloader(self): 146 | return torch.utils.data.DataLoader( 147 | self.dset_train, batch_size=self.train_batch_size, sampler=InfSampler(self.dset_train, True), 148 | num_workers=self.train_num_workers, collate_fn=self.collate_fn 149 | 
) 150 | 151 | def val_dataloader(self): 152 | return torch.utils.data.DataLoader( 153 | self.dset_val, batch_size=self.val_batch_size, shuffle=False, num_workers=self.val_num_workers, 154 | drop_last=False, collate_fn=self.collate_fn 155 | ) -------------------------------------------------------------------------------- /src/data/sampler.py: -------------------------------------------------------------------------------- 1 | import gin 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | @gin.configurable 7 | class InfSampler(Sampler): 8 | """Samples elements indefinitely: every pass yields each index once (in random order if `shuffle` is True), then a new pass begins. 9 | Arguments: 10 | data_source (Dataset): dataset to sample from 11 | """ 12 | 13 | def __init__(self, data_source, shuffle=False): 14 | self.data_source = data_source 15 | self.shuffle = shuffle 16 | self.reset_permutation() 17 | 18 | def reset_permutation(self): 19 | perm = len(self.data_source) 20 | if self.shuffle: 21 | perm = torch.randperm(perm) 22 | else: 23 | perm = torch.arange(perm) 24 | self._perm = perm.tolist() 25 | 26 | def __iter__(self): 27 | return self 28 | 29 | def __next__(self): 30 | if len(self._perm) == 0: 31 | self.reset_permutation() 32 | return self._perm.pop() 33 | 34 | def __len__(self): 35 | return len(self.data_source) -------------------------------------------------------------------------------- /src/data/scannet_loader.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from typing import Optional 3 | 4 | import gin 5 | import numpy as np 6 | from plyfile import PlyData 7 | from pandas import DataFrame 8 | import torch 9 | import pytorch_lightning as pl 10 | 11 | from src.data.collate import CollationFunctionFactory 12 | import src.data.transforms as T 13 | 14 | SCANNET_COLOR_MAP = { 15 | 0: (0., 0., 0.), 16 | 1: (174., 199., 232.), 17 | 2: (152., 223., 138.), 18 | 3: (31., 119., 180.), 19 | 4: (255., 187., 120.), 20 | 5: (188., 189., 34.), 21 | 6: (140., 
86., 75.), 22 | 7: (255., 152., 150.), 23 | 8: (214., 39., 40.), 24 | 9: (197., 176., 213.), 25 | 10: (148., 103., 189.), 26 | 11: (196., 156., 148.), 27 | 12: (23., 190., 207.), # No 13 28 | 14: (247., 182., 210.), 29 | 15: (66., 188., 102.), 30 | 16: (219., 219., 141.), 31 | 17: (140., 57., 197.), 32 | 18: (202., 185., 52.), 33 | 19: (51., 176., 203.), 34 | 20: (200., 54., 131.), 35 | 21: (92., 193., 61.), 36 | 22: (78., 71., 183.), 37 | 23: (172., 114., 82.), 38 | 24: (255., 127., 14.), 39 | 25: (91., 163., 138.), 40 | 26: (153., 98., 156.), 41 | 27: (140., 153., 101.), 42 | 28: (158., 218., 229.), 43 | 29: (100., 125., 154.), 44 | 30: (178., 127., 135.), # No 31 45 | 32: (146., 111., 194.), 46 | 33: (44., 160., 44.), 47 | 34: (112., 128., 144.), 48 | 35: (96., 207., 209.), 49 | 36: (227., 119., 194.), 50 | 37: (213., 92., 176.), 51 | 38: (94., 106., 211.), 52 | 39: (82., 84., 163.), 53 | 40: (100., 85., 144.), 54 | } 55 | VALID_CLASS_LABELS = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39) 56 | VALID_CLASS_NAMES = ('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 57 | 'door', 'window', 'bookshelf', 'picture', 'counter', 'desk', 58 | 'curtain', 'refrigerator', 'shower curtain', 'toilet', 'sink', 59 | 'bathtub', 'otherfurniture') 60 | 61 | 62 | def read_ply(filename): 63 | with open(osp.join(filename), 'rb') as f: 64 | plydata = PlyData.read(f) 65 | assert plydata.elements 66 | data = DataFrame(plydata.elements[0].data).values 67 | return data 68 | 69 | 70 | @gin.configurable 71 | class ScanNetDatasetBase(torch.utils.data.Dataset): 72 | IN_CHANNELS = None 73 | CLASS_LABELS = None 74 | SPLIT_FILES = { 75 | 'train': 'scannetv2_train.txt', 76 | 'val': 'scannetv2_val.txt', 77 | 'trainval': 'scannetv2_trainval.txt', 78 | 'test': 'scannetv2_test.txt', 79 | 'overfit': 'scannetv2_overfit.txt' 80 | } 81 | 82 | def __init__(self, phase, data_root, transform=None, ignore_label=255): 83 | assert self.IN_CHANNELS is not None 84 | assert 
self.CLASS_LABELS is not None 85 | assert phase in self.SPLIT_FILES.keys() 86 | super(ScanNetDatasetBase, self).__init__() 87 | 88 | self.phase = phase 89 | self.data_root = data_root 90 | self.transform = transform 91 | self.ignore_label = ignore_label 92 | self.split_file = self.SPLIT_FILES[phase] 93 | self.ignore_class_labels = tuple(set(range(41)) - set(self.CLASS_LABELS)) 94 | self.labelmap = self.get_labelmap() 95 | self.labelmap_inverse = self.get_labelmap_inverse() 96 | 97 | with open(osp.join(self.data_root, 'meta_data', self.split_file), 'r') as f: 98 | filenames = f.read().splitlines() 99 | 100 | sub_dir = 'test' if phase == 'test' else 'train' 101 | self.filenames = [ 102 | osp.join(self.data_root, 'scannet_processed', sub_dir, f'{filename}.ply') 103 | for filename in filenames 104 | ] 105 | 106 | def __len__(self): 107 | return len(self.filenames) 108 | 109 | def get_classnames(self): 110 | classnames = {} 111 | for class_id in self.CLASS_LABELS: 112 | classnames[self.labelmap[class_id]] = VALID_CLASS_NAMES[VALID_CLASS_LABELS.index(class_id)] 113 | return classnames 114 | 115 | def get_colormaps(self): 116 | colormaps = {} 117 | for class_id in self.CLASS_LABELS: 118 | colormaps[self.labelmap[class_id]] = SCANNET_COLOR_MAP[class_id] 119 | return colormaps 120 | 121 | def get_labelmap(self): 122 | labelmap = {} 123 | for k in range(41): 124 | if k in self.ignore_class_labels: 125 | labelmap[k] = self.ignore_label 126 | else: 127 | labelmap[k] = self.CLASS_LABELS.index(k) 128 | return labelmap 129 | 130 | def get_labelmap_inverse(self): 131 | labelmap_inverse = {} 132 | for k, v in self.labelmap.items(): 133 | labelmap_inverse[v] = self.ignore_label if v == self.ignore_label else k 134 | return labelmap_inverse 135 | 136 | 137 | @gin.configurable 138 | class ScanNetRGBDataset(ScanNetDatasetBase): 139 | IN_CHANNELS = 3 140 | CLASS_LABELS = VALID_CLASS_LABELS 141 | NUM_CLASSES = len(VALID_CLASS_LABELS) # 20 142 | 143 | def __getitem__(self, idx): 144 | 
data = self._load_data(idx) 145 | coords, feats, labels = self.get_cfl_from_data(data) 146 | if self.transform is not None: 147 | coords, feats, labels = self.transform(coords, feats, labels) 148 | coords = torch.from_numpy(coords) 149 | feats = torch.from_numpy(feats) 150 | labels = torch.from_numpy(labels) 151 | return coords.float(), feats.float(), labels.long(), None 152 | 153 | def _load_data(self, idx): 154 | filename = self.filenames[idx] 155 | data = read_ply(filename) 156 | return data 157 | 158 | def get_cfl_from_data(self, data): 159 | xyz, rgb, labels = data[:, :3], data[:, 3:6], data[:, -2] 160 | labels = np.array([self.labelmap[x] for x in labels]) 161 | return ( 162 | xyz.astype(np.float32), 163 | rgb.astype(np.float32), 164 | labels.astype(np.int64) 165 | ) 166 | 167 | 168 | @gin.configurable 169 | class ScanNetRGBDataModule(pl.LightningDataModule): 170 | def __init__( 171 | self, 172 | data_root, 173 | train_batch_size, 174 | val_batch_size, 175 | train_num_workers, 176 | val_num_workers, 177 | collation_type, 178 | train_transforms, 179 | eval_transforms, 180 | ): 181 | super(ScanNetRGBDataModule, self).__init__() 182 | self.data_root = data_root 183 | self.train_batch_size = train_batch_size 184 | self.val_batch_size = val_batch_size 185 | self.train_num_workers = train_num_workers 186 | self.val_num_workers = val_num_workers 187 | self.collate_fn = CollationFunctionFactory(collation_type) 188 | self.train_transforms_ = train_transforms 189 | self.eval_transforms_ = eval_transforms 190 | 191 | def setup(self, stage: Optional[str] = None): 192 | if stage == "fit" or stage is None: 193 | train_transforms = [] 194 | if self.train_transforms_ is not None: 195 | for name in self.train_transforms_: 196 | train_transforms.append(getattr(T, name)()) 197 | train_transforms = T.Compose(train_transforms) 198 | self.dset_train = ScanNetRGBDataset("train", self.data_root, train_transforms) 199 | eval_transforms = [] 200 | if self.eval_transforms_ is not None: 
201 | for name in self.eval_transforms_: 202 | eval_transforms.append(getattr(T, name)()) 203 | eval_transforms = T.Compose(eval_transforms) 204 | self.dset_val = ScanNetRGBDataset("val", self.data_root, eval_transforms) 205 | 206 | def train_dataloader(self): 207 | return torch.utils.data.DataLoader( 208 | self.dset_train, batch_size=self.train_batch_size, shuffle=True, drop_last=False, 209 | num_workers=self.train_num_workers, collate_fn=self.collate_fn 210 | ) 211 | 212 | def val_dataloader(self): 213 | return torch.utils.data.DataLoader( 214 | self.dset_val, batch_size=self.val_batch_size, shuffle=False, num_workers=self.val_num_workers, 215 | drop_last=False, collate_fn=self.collate_fn 216 | ) 217 | 218 | 219 | @gin.configurable 220 | class ScanNetRGBDataset_(ScanNetRGBDataset): 221 | def __getitem__(self, idx): 222 | data, filename = self._load_data(idx) 223 | coords, feats, labels = self.get_cfl_from_data(data) 224 | if self.transform is not None: 225 | coords, feats, labels = self.transform(coords, feats, labels) 226 | coords = torch.from_numpy(coords) 227 | feats = torch.from_numpy(feats) 228 | labels = torch.from_numpy(labels) 229 | return coords.float(), feats.float(), labels.long(), filename 230 | 231 | def _load_data(self, idx): 232 | filename = self.filenames[idx] 233 | data = read_ply(filename) 234 | return data, filename -------------------------------------------------------------------------------- /src/data/transforms.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | 4 | import gin 5 | import numpy as np 6 | import scipy 7 | import scipy.ndimage 8 | import scipy.interpolate 9 | from scipy.linalg import expm, norm 10 | import torch 11 | import MinkowskiEngine as ME 12 | 13 | 14 | def homogeneous_coords(coords): 15 | assert isinstance(coords, torch.Tensor) and coords.shape[1] == 3 16 | return torch.cat([coords, torch.ones((len(coords), 1))], dim=1) 17 | 18 | 19 | class 
Compose(object): 20 | """Composes several transforms together.""" 21 | 22 | def __init__(self, transforms): 23 | self.transforms = transforms 24 | 25 | def __call__(self, *args): 26 | for t in self.transforms: 27 | args = t(*args) 28 | return args 29 | 30 | 31 | # A sparse tensor consists of coordinates and associated features. 32 | # You must apply augmentation to both. 33 | # In 2D, flipping, shearing, scaling, and rotating an image are coordinate transformations; 34 | # color jitter, hue shifts, etc., are feature transformations. 35 | ############################## 36 | # Coordinate transformations 37 | ############################## 38 | @gin.configurable 39 | class ElasticDistortion: 40 | def __init__(self, distortion_params=[(4, 16), (8, 24)], application_ratio=0.9): 41 | self.application_ratio = application_ratio 42 | self.distortion_params = distortion_params 43 | logging.info( 44 | f"{self.__class__.__name__} distortion_params:{distortion_params} with application_ratio:{application_ratio}" 45 | ) 46 | 47 | def elastic_distortion(self, coords, feats, labels, granularity, magnitude): 48 | """Apply elastic distortion on sparse coordinate space. 49 | coords: numpy array of shape (number of points, 3 spatial dims) 50 | granularity: size of the noise grid (in same scale[m/cm] as the voxel grid) 51 | magnitude: noise multiplier 52 | """ 53 | blurx = np.ones((3, 1, 1, 1)).astype("float32") / 3 54 | blury = np.ones((1, 3, 1, 1)).astype("float32") / 3 55 | blurz = np.ones((1, 1, 3, 1)).astype("float32") / 3 56 | coords_min = coords.min(0) 57 | 58 | # Create Gaussian noise tensor of the size given by granularity. 59 | noise_dim = ((coords - coords_min).max(0) // granularity).astype(int) + 3 60 | noise = np.random.randn(*noise_dim, 3).astype(np.float32) 61 | 62 | # Smoothing. 
63 | for _ in range(2): 64 | noise = scipy.ndimage.convolve( 65 | noise, blurx, mode="constant", cval=0 66 | ) 67 | noise = scipy.ndimage.convolve( 68 | noise, blury, mode="constant", cval=0 69 | ) 70 | noise = scipy.ndimage.convolve( 71 | noise, blurz, mode="constant", cval=0 72 | ) 73 | 74 | # Trilinearly interpolate the noise field along each spatial dimension. 75 | ax = [ 76 | np.linspace(d_min, d_max, d) 77 | for d_min, d_max, d in zip( 78 | coords_min - granularity, 79 | coords_min + granularity * (noise_dim - 2), 80 | noise_dim, 81 | ) 82 | ] 83 | interp = scipy.interpolate.RegularGridInterpolator( 84 | ax, noise, bounds_error=False, fill_value=0 85 | ) 86 | coords += interp(coords) * magnitude 87 | return coords, feats, labels 88 | 89 | def __call__(self, coords, feats, labels): 90 | if self.distortion_params is not None: 91 | if random.random() < self.application_ratio: 92 | for granularity, magnitude in self.distortion_params: 93 | coords, feats, labels = self.elastic_distortion( 94 | coords, feats, labels, granularity, magnitude 95 | ) 96 | return coords, feats, labels 97 | 98 | 99 | # Rotation matrix along axis with angle theta 100 | def M(axis, theta): 101 | return expm(np.cross(np.eye(3), axis / norm(axis) * theta)) 102 | 103 | 104 | @gin.configurable 105 | class RandomRotation(object): 106 | def __init__(self, upright_axis="z", axis_std=0.01, application_ratio=0.9): 107 | self.upright_axis = {"x": 0, "y": 1, "z": 2}[upright_axis.lower()] 108 | self.D = 3 109 | # Use the rest of axes for flipping. 
110 | self.horz_axes = set(range(self.D)) - set([self.upright_axis]) 111 | self.application_ratio = application_ratio 112 | self.axis_std = axis_std 113 | logging.info( 114 | f"{self.__class__.__name__} upright_axis:{upright_axis}, axis_std:{axis_std} with application_ratio:{application_ratio}" 115 | ) 116 | 117 | def __call__(self, coords, feats, labels): 118 | if random.random() < self.application_ratio: 119 | axis = self.axis_std * np.random.randn(3) 120 | axis[self.upright_axis] += 1 121 | angle = random.random() * 2 * np.pi 122 | coords = coords @ M(axis, angle) 123 | return coords, feats, labels 124 | 125 | 126 | @gin.configurable 127 | class RandomTranslation(object): 128 | def __init__( 129 | self, 130 | max_translation=3, 131 | application_ratio=0.9, 132 | ): 133 | self.max_translation = max_translation 134 | self.application_ratio = application_ratio 135 | logging.info( 136 | f"{self.__class__.__name__} max_translation:{max_translation} with application_ratio:{application_ratio}" 137 | ) 138 | 139 | def __call__(self, coords, feats, labels): 140 | if random.random() < self.application_ratio: 141 | coords += 2 * (np.random.rand(1, 3) - 0.5) * self.max_translation 142 | return coords, feats, labels 143 | 144 | 145 | @gin.configurable 146 | class RandomScale(object): 147 | def __init__(self, scale_ratio=0.1, application_ratio=0.9): 148 | self.scale_ratio = scale_ratio 149 | self.application_ratio = application_ratio 150 | logging.info(f"{self.__class__.__name__}(scale_ratio={scale_ratio})") 151 | 152 | def __call__(self, coords, feats, labels): 153 | if random.random() < self.application_ratio: 154 | coords = coords * np.random.uniform( 155 | low=1 - self.scale_ratio, high=1 + self.scale_ratio 156 | ) 157 | return coords, feats, labels 158 | 159 | 160 | @gin.configurable 161 | class RandomCrop(object): 162 | def __init__(self, x, y, z, application_ratio=1, min_cardinality=100, max_retries=10): 163 | assert x > 0 164 | assert y > 0 165 | assert z > 0 166 | 
self.application_ratio = application_ratio 167 | self.max_size = np.array([[x, y, z]]) 168 | self.min_cardinality = min_cardinality 169 | self.max_retries = max_retries 170 | logging.info(f"{self.__class__.__name__} with size {self.max_size}") 171 | 172 | def __call__(self, coords: np.ndarray, feats, labels): 173 | if random.random() > self.application_ratio: 174 | return coords, feats, labels 175 | 176 | norm_coords = coords - coords.min(0, keepdims=True) 177 | max_coords = norm_coords.max(0, keepdims=True) 178 | # Range of valid crop start points along each axis. 179 | coord_range = max_coords - self.max_size 180 | coord_range = np.clip(coord_range, a_min=0, a_max=float("inf")) 181 | # If the crop size exceeds the point cloud extent along every axis, return the original. 182 | if np.prod(coord_range == 0): 183 | return coords, feats, labels 184 | 185 | # Sample the crop start point. 186 | valid = False 187 | retries = 0 188 | while not valid: 189 | min_box = np.random.rand(1, 3) * coord_range 190 | max_box = min_box + self.max_size 191 | sel = np.logical_and( 192 | np.prod(norm_coords > min_box, 1), np.prod(norm_coords < max_box, 1) 193 | ) 194 | if np.sum(sel) > self.min_cardinality: 195 | valid = True 196 | retries += 1 197 | if retries % 2 == 0: 198 | logging.warning(f"RandomCrop retries: {retries}. crop_range={coord_range}") 199 | if retries >= self.max_retries: 200 | break 201 | 202 | if valid: 203 | return ( 204 | coords[sel], 205 | feats if feats is None else feats[sel], 206 | labels if labels is None else labels[sel], 207 | ) 208 | return coords, feats, labels 209 | 210 | 211 | @gin.configurable 212 | class RandomAffine(object): 213 | def __init__( 214 | self, 215 | upright_axis="z", 216 | axis_std=0.1, 217 | scale_range=0.2, 218 | affine_range=0.1, 219 | application_ratio=0.9, 220 | ): 221 | self.upright_axis = {"x": 0, "y": 1, "z": 2}[upright_axis.lower()] 222 | self.D = 3 223 | self.scale_range = scale_range 224 | self.affine_range = affine_range 225 | # Use the rest of axes for flipping. 
226 | self.horz_axes = set(range(self.D)) - set([self.upright_axis]) 227 | self.application_ratio = application_ratio 228 | self.axis_std = axis_std 229 | logging.info( 230 | f"{self.__class__.__name__} upright_axis:{upright_axis}, axis_std:{axis_std}, scale_range:{scale_range}, affine_range:{affine_range} with application_ratio:{application_ratio}" 231 | ) 232 | 233 | def __call__(self, coords, feats, labels): 234 | if random.random() < self.application_ratio: 235 | axis = self.axis_std * np.random.randn(3) 236 | axis[self.upright_axis] += 1 237 | angle = 2 * (random.random() - 0.5) * np.pi 238 | T = M(axis, angle) @ ( 239 | np.diag(2 * (np.random.rand(3) - 0.5) * self.scale_range + 1) 240 | + 2 * (np.random.rand(3, 3) - 0.5) * self.affine_range 241 | ) 242 | coords = coords @ T 243 | return coords, feats, labels 244 | 245 | 246 | @gin.configurable 247 | class RandomHorizontalFlip(object): 248 | def __init__(self, upright_axis="z", application_ratio=0.9): 249 | """ 250 | upright_axis: name of the upright axis, one of "x", "y", "z" (mapped internally to index 0, 1, 2) 251 | """ 252 | self.D = 3 253 | self.upright_axis = {"x": 0, "y": 1, "z": 2}[upright_axis.lower()] 254 | # Use the rest of axes for flipping. 
255 | self.horz_axes = set(range(self.D)) - set([self.upright_axis]) 256 | self.application_ratio = application_ratio 257 | logging.info( 258 | f"{self.__class__.__name__} upright_axis:{upright_axis} with application_ratio:{application_ratio}" 259 | ) 260 | 261 | def __call__(self, coords, feats, labels): 262 | if random.random() < self.application_ratio: 263 | for curr_ax in self.horz_axes: 264 | if random.random() < 0.5: 265 | coord_max = np.max(coords[:, curr_ax]) 266 | coords[:, curr_ax] = coord_max - coords[:, curr_ax] 267 | return coords, feats, labels 268 | 269 | 270 | @gin.configurable 271 | class CoordinateDropout(object): 272 | def __init__(self, dropout_ratio=0.2, application_ratio=0.2): 273 | self.dropout_ratio = dropout_ratio 274 | self.application_ratio = application_ratio 275 | logging.info( 276 | f"{self.__class__.__name__} dropout:{dropout_ratio} with application_ratio:{application_ratio}" 277 | ) 278 | 279 | def __call__(self, coords, feats, labels): 280 | if random.random() < self.application_ratio: 281 | N = len(coords) 282 | inds = np.random.choice(N, int(N * (1 - self.dropout_ratio)), replace=False) 283 | return ( 284 | coords[inds], 285 | feats if feats is None else feats[inds], 286 | labels if labels is None else labels[inds], 287 | ) 288 | return coords, feats, labels 289 | 290 | 291 | @gin.configurable 292 | class CoordinateJitter(object): 293 | def __init__(self, jitter_std=0.5, application_ratio=0.7): 294 | self.jitter_std = jitter_std 295 | self.application_ratio = application_ratio 296 | logging.info( 297 | f"{self.__class__.__name__} jitter_std:{jitter_std} with application_ratio:{application_ratio}" 298 | ) 299 | 300 | def __call__(self, coords, feats, labels): 301 | if random.random() < self.application_ratio: 302 | N = len(coords) 303 | coords += (2 * self.jitter_std) * (np.random.rand(N, 3) - 0.5) 304 | return coords, feats, labels 305 | 306 | 307 | @gin.configurable 308 | class CoordinateUniformTranslation: 309 | def 
__init__(self, max_translation=0.2): 310 | self.max_translation = max_translation 311 | 312 | def __call__(self, coords, feats, labels): 313 | if self.max_translation > 0: 314 | coords += np.random.uniform( 315 | low=-self.max_translation, high=self.max_translation, size=[1, 3] 316 | ) 317 | return coords, feats, labels 318 | 319 | 320 | @gin.configurable 321 | class RegionDropout(object): 322 | def __init__( 323 | self, 324 | box_center_range=[100, 100, 10], 325 | max_region_size=[300, 300, 300], 326 | min_region_size=[100, 100, 100], 327 | application_ratio=0.3, 328 | ): 329 | self.max_region_size = np.array(max_region_size) 330 | self.min_region_size = np.array(min_region_size) 331 | self.box_range = self.max_region_size - self.min_region_size 332 | self.box_center_range = np.array([box_center_range]) 333 | self.application_ratio = application_ratio 334 | logging.info( 335 | f"{self.__class__.__name__} max_region_size:{max_region_size} min_region_size:{min_region_size} box_center_range:{box_center_range} with application_ratio:{application_ratio}" 336 | ) 337 | 338 | def __call__(self, coords, feats, labels): 339 | if random.random() < self.application_ratio: 340 | while True: 341 | box_center = self.box_center_range * ( 342 | np.random.rand(1, 3) - 0.5 343 | ) * 2 + np.mean(coords, axis=0, keepdims=True) 344 | box_size = self.box_range * np.random.rand(1, 3) 345 | min_xyz = box_center - box_size / 2 346 | max_xyz = box_center + box_size / 2 347 | sel = np.logical_not( 348 | np.prod(coords < max_xyz, axis=1) 349 | * np.prod(coords > min_xyz, axis=1) 350 | ) 351 | if sel.sum() > len(coords) * 0.5: 352 | break 353 | return coords[sel], (feats if feats is None else feats[sel]), (labels if labels is None else labels[sel]) 354 | return coords, feats, labels 355 | 356 | 357 | @gin.configurable 358 | class DimensionlessCoordinates(object): 359 | def __init__(self, voxel_size=0.02): 360 | self.voxel_size = voxel_size 361 | logging.info(f"{self.__class__.__name__} with voxel_size:{voxel_size}") 362 | 363 | def __call__(self,
coords, feats, labels): 364 | return coords / self.voxel_size, feats, labels 365 | 366 | 367 | @gin.configurable 368 | class PerlinNoise: 369 | def __init__( 370 | self, noise_params=[(4, 4), (16, 16)], application_ratio=0.9, device="cpu" 371 | ): 372 | self.application_ratio = application_ratio 373 | self.noise_params = noise_params 374 | logging.info( 375 | f"{self.__class__.__name__} noise_params:{noise_params} with application_ratio:{application_ratio}" 376 | ) 377 | self.interp = ME.MinkowskiInterpolation() 378 | self.corners = torch.Tensor( 379 | [ 380 | [0, 0, 0], 381 | [0, 0, 1], 382 | [0, 1, 0], 383 | [1, 0, 0], 384 | [0, 1, 1], 385 | [1, 1, 0], 386 | [1, 0, 1], 387 | [1, 1, 1], 388 | ] 389 | ) 390 | self.smooth = ME.MinkowskiConvolution( 391 | in_channels=3, 392 | out_channels=3, 393 | kernel_size=3, 394 | bias=False, 395 | dimension=3, 396 | ) 397 | self.smooth.kernel[:] = 1 / 27 398 | if device is None: 399 | device = "cuda" if torch.cuda.is_available() else "cpu" 400 | # honor an explicitly requested device instead of always falling back to CPU 401 | self.device = device 402 | self.smooth = self.smooth.to(self.device) 403 | self.corners = self.corners.to(self.device) 404 | 405 | def perlin_noise(self, coordinates, noise_quantization_size, noise_std): 406 | aug_coordinates = coordinates.reshape(-1, 1, 3) + ( 407 | self.corners * noise_quantization_size 408 | ).reshape(1, 8, 3) 409 | bcoords = ME.utils.batched_coordinates( 410 | [aug_coordinates.reshape(-1, 3) / noise_quantization_size], 411 | dtype=torch.float32, 412 | ) 413 | noise_tensor = ME.SparseTensor( 414 | features=torch.randn((len(bcoords), 3), device=self.device), 415 | coordinates=bcoords, 416 | device=self.device, 417 | ) 418 | noise_tensor = self.smooth(noise_tensor) 419 | interp_noise = self.interp( 420 | noise_tensor, 421 | ME.utils.batched_coordinates( 422 | [coordinates / noise_quantization_size], 423 | dtype=torch.float32, 424 | device=self.device, 425 | ), 426 | ) 427 | return coordinates + noise_std * interp_noise 428 | 429 | def __call__(self, coords, feats,
labels): 430 | if self.noise_params is not None: 431 | if random.random() < self.application_ratio: 432 | coords = torch.from_numpy(coords).to(self.device) 433 | with torch.no_grad(): 434 | for quantization_size, noise_std in self.noise_params: 435 | coords = self.perlin_noise(coords, quantization_size, noise_std) 436 | coords = coords.cpu().numpy() 437 | return coords, feats, labels 438 | 439 | 440 | ############################## 441 | # Feature transformations 442 | ############################## 443 | @gin.configurable 444 | class ChromaticTranslation(object): 445 | """Add random color to the image, input must be an array in [0,255] or a PIL image""" 446 | 447 | def __init__(self, translation_range_ratio=1e-1, application_ratio=0.9): 448 | """ 449 | trans_range_ratio: ratio of translation i.e. 255 * 2 * ratio * rand(-0.5, 0.5) 450 | """ 451 | self.trans_range_ratio = translation_range_ratio 452 | self.application_ratio = application_ratio 453 | logging.info( 454 | f"{self.__class__.__name__} with translation_range_ratio:{translation_range_ratio} with application_ratio:{application_ratio}" 455 | ) 456 | 457 | def __call__(self, coords, feats, labels): 458 | if random.random() < self.application_ratio: 459 | tr = (np.random.rand(1, 3) - 0.5) * 255 * 2 * self.trans_range_ratio 460 | feats[:, :3] = np.clip(tr + feats[:, :3], 0, 255) 461 | return coords, feats, labels 462 | 463 | 464 | @gin.configurable 465 | class ChromaticJitter(object): 466 | def __init__(self, std=0.01, application_ratio=0.9): 467 | self.std = std 468 | self.application_ratio = application_ratio 469 | logging.info( 470 | f"{self.__class__.__name__} with std:{std} with application_ratio:{application_ratio}" 471 | ) 472 | 473 | def __call__(self, coords, feats, labels): 474 | if random.random() < self.application_ratio: 475 | noise = np.random.randn(feats.shape[0], 3) 476 | noise *= self.std * 255 477 | feats[:, :3] = np.clip(noise + feats[:, :3], 0, 255) 478 | return coords, feats, labels 479 | 
480 | 481 | @gin.configurable 482 | class ChromaticAutoContrast(object): 483 | def __init__( 484 | self, randomize_blend_factor=True, blend_factor=0.5, application_ratio=0.2 485 | ): 486 | self.randomize_blend_factor = randomize_blend_factor 487 | self.blend_factor = blend_factor 488 | self.application_ratio = application_ratio 489 | logging.info( 490 | f"{self.__class__.__name__} with randomize_blend_factor:{randomize_blend_factor}, blend_factor:{blend_factor} with application_ratio:{application_ratio}" 491 | ) 492 | 493 | def __call__(self, coords, feats, labels): 494 | if random.random() < self.application_ratio: 495 | # mean = np.mean(feats, 0, keepdims=True) 496 | # std = np.std(feats, 0, keepdims=True) 497 | # lo = mean - std 498 | # hi = mean + std 499 | lo = feats[:, :3].min(0, keepdims=True) 500 | hi = feats[:, :3].max(0, keepdims=True) 501 | assert hi.max() > 1, "Invalid color value: colors are expected to be in [0, 255]" 502 | 503 | if np.prod(hi - lo): 504 | scale = 255 / (hi - lo) 505 | contrast_feats = (feats[:, :3] - lo) * scale 506 | blend_factor = ( 507 | random.random() if self.randomize_blend_factor else self.blend_factor 508 | ) 509 | feats[:, :3] = (1 - blend_factor) * feats[:, :3] + blend_factor * contrast_feats 510 | return coords, feats, labels 511 | 512 | 513 | @gin.configurable 514 | class NormalizeColor(object): 515 | def __init__(self, mean=[128, 128, 128], std=[256, 256, 256], pre_norm=False): 516 | self.mean = np.array([mean], dtype=np.float32) 517 | self.std = np.array([std], dtype=np.float32) 518 | self.pre_norm = pre_norm 519 | logging.info(f"{self.__class__.__name__} mean:{mean} std:{std}") 520 | 521 | def __call__(self, coords, feats, labels): 522 | if self.pre_norm: 523 | feats = feats / 255.
524 | return coords, (feats - self.mean) / self.std, labels 525 | 526 | 527 | @gin.configurable 528 | class HueSaturationTranslation(object): 529 | @staticmethod 530 | def rgb_to_hsv(rgb): 531 | # Translated from source of colorsys.rgb_to_hsv 532 | # r,g,b should be a numpy arrays with values between 0 and 255 533 | # rgb_to_hsv returns an array of floats between 0.0 and 1.0. 534 | rgb = rgb.astype("float") 535 | hsv = np.zeros_like(rgb) 536 | # in case an RGBA array was passed, just copy the A channel 537 | hsv[..., 3:] = rgb[..., 3:] 538 | r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2] 539 | maxc = np.max(rgb[..., :3], axis=-1) 540 | minc = np.min(rgb[..., :3], axis=-1) 541 | hsv[..., 2] = maxc 542 | mask = maxc != minc 543 | hsv[mask, 1] = (maxc - minc)[mask] / maxc[mask] 544 | rc = np.zeros_like(r) 545 | gc = np.zeros_like(g) 546 | bc = np.zeros_like(b) 547 | rc[mask] = (maxc - r)[mask] / (maxc - minc)[mask] 548 | gc[mask] = (maxc - g)[mask] / (maxc - minc)[mask] 549 | bc[mask] = (maxc - b)[mask] / (maxc - minc)[mask] 550 | hsv[..., 0] = np.select( 551 | [r == maxc, g == maxc], [bc - gc, 2.0 + rc - bc], default=4.0 + gc - rc 552 | ) 553 | hsv[..., 0] = (hsv[..., 0] / 6.0) % 1.0 554 | return hsv 555 | 556 | @staticmethod 557 | def hsv_to_rgb(hsv): 558 | # Translated from source of colorsys.hsv_to_rgb 559 | # h,s should be a numpy arrays with values between 0.0 and 1.0 560 | # v should be a numpy array with values between 0.0 and 255.0 561 | # hsv_to_rgb returns an array of uints between 0 and 255. 
562 | rgb = np.empty_like(hsv) 563 | rgb[..., 3:] = hsv[..., 3:] 564 | h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2] 565 | i = (h * 6.0).astype("uint8") 566 | f = (h * 6.0) - i 567 | p = v * (1.0 - s) 568 | q = v * (1.0 - s * f) 569 | t = v * (1.0 - s * (1.0 - f)) 570 | i = i % 6 571 | conditions = [s == 0.0, i == 1, i == 2, i == 3, i == 4, i == 5] 572 | rgb[..., 0] = np.select(conditions, [v, q, p, p, t, v], default=v) 573 | rgb[..., 1] = np.select(conditions, [v, v, v, q, p, p], default=t) 574 | rgb[..., 2] = np.select(conditions, [v, p, t, v, v, q], default=p) 575 | return rgb.astype("uint8") 576 | 577 | def __init__(self, hue_max, saturation_max): 578 | self.hue_max = hue_max 579 | self.saturation_max = saturation_max 580 | 581 | def __call__(self, coords, feats, labels): 582 | # Assume feat[:, :3] is rgb 583 | hsv = HueSaturationTranslation.rgb_to_hsv(feats[:, :3]) 584 | hue_val = (random.random() - 0.5) * 2 * self.hue_max 585 | sat_ratio = 1 + (random.random() - 0.5) * 2 * self.saturation_max 586 | hsv[..., 0] = np.remainder(hue_val + hsv[..., 0] + 1, 1) 587 | hsv[..., 1] = np.clip(sat_ratio * hsv[..., 1], 0, 1) 588 | feats[:, :3] = np.clip(HueSaturationTranslation.hsv_to_rgb(hsv), 0, 255) 589 | return coords, feats, labels -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from src.models.spvcnn import SPVCNN 4 | import src.models.resnet as resnets 5 | import src.models.resunet as resunets 6 | import src.models.fast_point_transformer as transformers 7 | 8 | MODELS = [SPVCNN] 9 | 10 | 11 | def add_models(module): 12 | MODELS.extend([getattr(module, a) for a in dir(module) if "Net" in a or "Transformer" in a]) 13 | 14 | 15 | add_models(resnets) 16 | add_models(resunets) 17 | add_models(transformers) 18 | 19 | 20 | def get_model(name): 21 | """Creates and returns an instance of the model given its 
class name.""" 22 | # Find the model class from its name 23 | all_models = MODELS 24 | mdict = {model.__name__: model for model in all_models} 25 | if name not in mdict: 26 | logging.info(f"Invalid model name: {name}. Options are:") 27 | # Display a list of valid model names 28 | for model in all_models: 29 | logging.info("\t* {}".format(model.__name__)) 30 | return None 31 | model_class = mdict[name] 32 | return model_class -------------------------------------------------------------------------------- /src/models/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import MinkowskiEngine as ME 3 | 4 | 5 | @torch.no_grad() 6 | def downsample_points(points, tensor_map, field_map, size): 7 | down_points = ME.MinkowskiSPMMAverageFunction().apply( 8 | tensor_map, field_map, size, points 9 | ) 10 | _, counts = torch.unique(tensor_map, return_counts=True) 11 | return down_points, counts.unsqueeze_(1).type_as(down_points) 12 | 13 | 14 | @torch.no_grad() 15 | def stride_centroids(points, counts, rows, cols, size): 16 | centroid_sums = ME.MinkowskiSPMMFunction().apply(rows, cols, counts, size, points) 17 | ones = torch.ones(size[1], dtype=points.dtype, device=points.device) 18 | stride_counts = ME.MinkowskiSPMMFunction().apply(rows, cols, ones, size, counts) 19 | stride_counts.clamp_(min=1) 20 | return torch.true_divide(centroid_sums, stride_counts), stride_counts 21 | 22 | 23 | def downsample_embeddings(embeddings, inverse_map, size, mode="avg"): 24 | assert len(embeddings) == size[1] 25 | assert mode in ["avg", "max"] 26 | if mode == "max": 27 | in_map = torch.arange(size[1], dtype=inverse_map.dtype, device=inverse_map.device) 28 | down_embeddings = ME.MinkowskiDirectMaxPoolingFunction().apply( 29 | in_map, inverse_map, embeddings, size[0] 30 | ) 31 | else: 32 | cols = torch.arange(size[1], dtype=inverse_map.dtype, device=inverse_map.device) 33 | down_embeddings =
ME.MinkowskiSPMMAverageFunction().apply( 34 | inverse_map, cols, size, embeddings 35 | ) 36 | return down_embeddings -------------------------------------------------------------------------------- /src/models/fast_point_transformer.py: -------------------------------------------------------------------------------- 1 | import gin 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import MinkowskiEngine as ME 6 | 7 | from src.models.transformer_base import LocalSelfAttentionBase, ResidualBlockWithPointsBase 8 | from src.models.common import stride_centroids, downsample_points, downsample_embeddings 9 | import src.cuda_ops.functions.sparse_ops as ops 10 | 11 | 12 | class MaxPoolWithPoints(nn.Module): 13 | def __init__(self, kernel_size=2, stride=2): 14 | assert kernel_size == 2 and stride == 2 15 | super(MaxPoolWithPoints, self).__init__() 16 | self.pool = ME.MinkowskiMaxPooling(kernel_size=kernel_size, stride=stride, dimension=3) 17 | 18 | def forward(self, stensor, points, counts): 19 | assert isinstance(stensor, ME.SparseTensor) 20 | assert len(stensor) == len(points) 21 | cm = stensor.coordinate_manager 22 | down_stensor = self.pool(stensor) 23 | cols, rows = cm.stride_map(stensor.coordinate_map_key, down_stensor.coordinate_map_key) 24 | size = torch.Size([len(down_stensor), len(stensor)]) 25 | down_points, down_counts = stride_centroids(points, counts, rows, cols, size) 26 | return down_stensor, down_points, down_counts 27 | 28 | 29 | #################################### 30 | # Layers 31 | #################################### 32 | @gin.configurable 33 | class LightweightSelfAttentionLayer(LocalSelfAttentionBase): 34 | def __init__( 35 | self, 36 | in_channels, 37 | out_channels=None, 38 | kernel_size=3, 39 | stride=1, 40 | dilation=1, 41 | num_heads=8, 42 | ): 43 | out_channels = in_channels if out_channels is None else out_channels 44 | assert out_channels % num_heads == 0 45 | assert kernel_size % 2 == 1 46 | assert stride == 
1, "Currently, this layer only supports stride == 1" 47 | assert dilation == 1,"Currently, this layer only supports dilation == 1" 48 | super(LightweightSelfAttentionLayer, self).__init__(kernel_size, stride, dilation, dimension=3) 49 | 50 | self.in_channels = in_channels 51 | self.out_channels = out_channels 52 | self.kernel_size = kernel_size 53 | self.stride = stride 54 | self.dilation = dilation 55 | self.num_heads = num_heads 56 | self.attn_channels = out_channels // num_heads 57 | 58 | self.to_query = nn.Sequential( 59 | ME.MinkowskiLinear(in_channels, out_channels), 60 | ME.MinkowskiToFeature() 61 | ) 62 | self.to_value = nn.Sequential( 63 | ME.MinkowskiLinear(in_channels, out_channels), 64 | ME.MinkowskiToFeature() 65 | ) 66 | self.to_out = nn.Linear(out_channels, out_channels) 67 | 68 | self.inter_pos_enc = nn.Parameter(torch.FloatTensor(self.kernel_volume, self.num_heads, self.attn_channels)) 69 | self.intra_pos_mlp = nn.Sequential( 70 | nn.Linear(3, 3, bias=False), 71 | nn.BatchNorm1d(3), 72 | nn.ReLU(inplace=True), 73 | nn.Linear(3, in_channels, bias=False), 74 | nn.BatchNorm1d(in_channels), 75 | nn.ReLU(inplace=True), 76 | nn.Linear(in_channels, in_channels) 77 | ) 78 | nn.init.normal_(self.inter_pos_enc, 0, 1) 79 | 80 | def forward(self, stensor, norm_points): 81 | dtype = stensor._F.dtype 82 | device = stensor._F.device 83 | 84 | # query, key, value, and relative positional encoding 85 | intra_pos_enc = self.intra_pos_mlp(norm_points) 86 | stensor = stensor + intra_pos_enc 87 | q = self.to_query(stensor).view(-1, self.num_heads, self.attn_channels).contiguous() 88 | v = self.to_value(stensor).view(-1, self.num_heads, self.attn_channels).contiguous() 89 | 90 | # key-query map 91 | kernel_map, out_key = self.get_kernel_map_and_out_key(stensor) 92 | kq_map = self.key_query_map_from_kernel_map(kernel_map) 93 | 94 | # attention weights with cosine similarity 95 | attn = torch.zeros((kq_map.shape[1], self.num_heads), dtype=dtype, device=device) 96 | norm_q 
= F.normalize(q, p=2, dim=-1) 97 | norm_pos_enc = F.normalize(self.inter_pos_enc, p=2, dim=-1) 98 | attn = ops.dot_product_cuda(norm_q, norm_pos_enc, attn, kq_map) 99 | 100 | # aggregation & the output 101 | out_F = torch.zeros((len(q), self.num_heads, self.attn_channels), 102 | dtype=dtype, 103 | device=device) 104 | kq_indices = self.key_query_indices_from_key_query_map(kq_map) 105 | out_F = ops.scalar_attention_cuda(attn, v, out_F, kq_indices) 106 | out_F = self.to_out(out_F.view(-1, self.out_channels).contiguous()) 107 | return ME.SparseTensor(out_F, 108 | coordinate_map_key=out_key, 109 | coordinate_manager=stensor.coordinate_manager) 110 | 111 | 112 | #################################### 113 | # Blocks 114 | #################################### 115 | @gin.configurable 116 | class LightweightSelfAttentionBlock(ResidualBlockWithPointsBase): 117 | LAYER = LightweightSelfAttentionLayer 118 | 119 | 120 | #################################### 121 | # Models 122 | #################################### 123 | @gin.configurable 124 | class FastPointTransformer(nn.Module): 125 | INIT_DIM = 32 126 | ENC_DIM = 32 127 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) 128 | PLANES = (64, 128, 384, 640, 384, 384, 256, 128) 129 | QMODE = ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE 130 | LAYER = LightweightSelfAttentionLayer 131 | BLOCK = LightweightSelfAttentionBlock 132 | 133 | def __init__(self, in_channels, out_channels): 134 | super(FastPointTransformer, self).__init__() 135 | self.in_channels = in_channels 136 | self.out_channels = out_channels 137 | 138 | self.enc_mlp = nn.Sequential( 139 | nn.Linear(3, self.ENC_DIM, bias=False), 140 | nn.BatchNorm1d(self.ENC_DIM), 141 | nn.Tanh(), 142 | nn.Linear(self.ENC_DIM, self.ENC_DIM, bias=False), 143 | nn.BatchNorm1d(self.ENC_DIM), 144 | nn.Tanh() 145 | ) 146 | self.attn0p1 = self.LAYER(in_channels + self.ENC_DIM, self.INIT_DIM, kernel_size=5) 147 | self.bn0 = ME.MinkowskiBatchNorm(self.INIT_DIM) 148 | 149 | self.attn1p1 = 
self.LAYER(self.INIT_DIM, self.PLANES[0]) 150 | self.bn1 = ME.MinkowskiBatchNorm(self.PLANES[0]) 151 | self.block1 = nn.ModuleList([self.BLOCK(self.PLANES[0]) for _ in range(self.LAYERS[0])]) 152 | 153 | self.attn2p2 = self.LAYER(self.PLANES[0], self.PLANES[1]) 154 | self.bn2 = ME.MinkowskiBatchNorm(self.PLANES[1]) 155 | self.block2 = nn.ModuleList([self.BLOCK(self.PLANES[1]) for _ in range(self.LAYERS[1])]) 156 | 157 | self.attn3p4 = self.LAYER(self.PLANES[1], self.PLANES[2]) 158 | self.bn3 = ME.MinkowskiBatchNorm(self.PLANES[2]) 159 | self.block3 = nn.ModuleList([self.BLOCK(self.PLANES[2]) for _ in range(self.LAYERS[2])]) 160 | 161 | self.attn4p8 = self.LAYER(self.PLANES[2], self.PLANES[3]) 162 | self.bn4 = ME.MinkowskiBatchNorm(self.PLANES[3]) 163 | self.block4 = nn.ModuleList([self.BLOCK(self.PLANES[3]) for _ in range(self.LAYERS[3])]) 164 | 165 | self.attn5p8 = self.LAYER(self.PLANES[3] + self.PLANES[3], self.PLANES[4]) 166 | self.bn5 = ME.MinkowskiBatchNorm(self.PLANES[4]) 167 | self.block5 = nn.ModuleList([self.BLOCK(self.PLANES[4]) for _ in range(self.LAYERS[4])]) 168 | 169 | self.attn6p4 = self.LAYER(self.PLANES[4] + self.PLANES[2], self.PLANES[5]) 170 | self.bn6 = ME.MinkowskiBatchNorm(self.PLANES[5]) 171 | self.block6 = nn.ModuleList([self.BLOCK(self.PLANES[5]) for _ in range(self.LAYERS[5])]) 172 | 173 | self.attn7p2 = self.LAYER(self.PLANES[5] + self.PLANES[1], self.PLANES[6]) 174 | self.bn7 = ME.MinkowskiBatchNorm(self.PLANES[6]) 175 | self.block7 = nn.ModuleList([self.BLOCK(self.PLANES[6]) for _ in range(self.LAYERS[6])]) 176 | 177 | self.attn8p1 = self.LAYER(self.PLANES[6] + self.PLANES[0], self.PLANES[7]) 178 | self.bn8 = ME.MinkowskiBatchNorm(self.PLANES[7]) 179 | self.block8 = nn.ModuleList([self.BLOCK(self.PLANES[7]) for _ in range(self.LAYERS[7])]) 180 | 181 | self.final = nn.Sequential( 182 | nn.Linear(self.PLANES[7] + self.ENC_DIM, self.PLANES[7], bias=False), 183 | nn.BatchNorm1d(self.PLANES[7]), 184 | nn.ReLU(inplace=True), 185 | 
nn.Linear(self.PLANES[7], out_channels) 186 | ) 187 | self.relu = ME.MinkowskiReLU(inplace=True) 188 | self.pool = MaxPoolWithPoints() 189 | self.pooltr = ME.MinkowskiPoolingTranspose(kernel_size=2, stride=2, dimension=3) 190 | 191 | @torch.no_grad() 192 | def normalize_points(self, points, centroids, tensor_map): 193 | tensor_map = tensor_map if tensor_map.dtype == torch.int64 else tensor_map.long() 194 | norm_points = points - centroids[tensor_map] 195 | return norm_points 196 | 197 | @torch.no_grad() 198 | def normalize_centroids(self, down_points, coordinates, tensor_stride): 199 | norm_points = (down_points - coordinates[:, 1:]) / tensor_stride - 0.5 200 | return norm_points 201 | 202 | def voxelize_with_centroids(self, x: ME.TensorField): 203 | cm = x.coordinate_manager 204 | points = x.C[:, 1:] 205 | 206 | out = x.sparse() 207 | size = torch.Size([len(out), len(x)]) 208 | tensor_map, field_map = cm.field_to_sparse_map(x.coordinate_key, out.coordinate_key) 209 | points_p1, count_p1 = downsample_points(points, tensor_map, field_map, size) 210 | norm_points = self.normalize_points(points, points_p1, tensor_map) 211 | 212 | pos_embs = self.enc_mlp(norm_points) 213 | down_pos_embs = downsample_embeddings(pos_embs, tensor_map, size, mode="avg") 214 | out = ME.SparseTensor(torch.cat([out.F, down_pos_embs], dim=1), 215 | coordinate_map_key=out.coordinate_key, 216 | coordinate_manager=cm) 217 | 218 | norm_points_p1 = self.normalize_centroids(points_p1, out.C, out.tensor_stride[0]) 219 | return out, norm_points_p1, points_p1, count_p1, pos_embs 220 | 221 | def devoxelize_with_centroids(self, out: ME.SparseTensor, x: ME.TensorField, h_embs): 222 | out = self.final(torch.cat([out.slice(x).F, h_embs], dim=1)) 223 | return out 224 | 225 | def forward(self, x: ME.TensorField): 226 | out, norm_points_p1, points_p1, count_p1, pos_embs = self.voxelize_with_centroids(x) 227 | out = self.relu(self.bn0(self.attn0p1(out, norm_points_p1))) 228 | out_p1 = 
self.relu(self.bn1(self.attn1p1(out, norm_points_p1))) 229 | 230 | out, points_p2, count_p2 = self.pool(out_p1, points_p1, count_p1) 231 | norm_points_p2 = self.normalize_centroids(points_p2, out.C, out.tensor_stride[0]) 232 | for module in self.block1: 233 | out = module(out, norm_points_p2) 234 | out_p2 = self.relu(self.bn2(self.attn2p2(out, norm_points_p2))) 235 | 236 | out, points_p4, count_p4 = self.pool(out_p2, points_p2, count_p2) 237 | norm_points_p4 = self.normalize_centroids(points_p4, out.C, out.tensor_stride[0]) 238 | for module in self.block2: 239 | out = module(out, norm_points_p4) 240 | out_p4 = self.relu(self.bn3(self.attn3p4(out, norm_points_p4))) 241 | 242 | out, points_p8, count_p8 = self.pool(out_p4, points_p4, count_p4) 243 | norm_points_p8 = self.normalize_centroids(points_p8, out.C, out.tensor_stride[0]) 244 | for module in self.block3: 245 | out = module(out, norm_points_p8) 246 | out_p8 = self.relu(self.bn4(self.attn4p8(out, norm_points_p8))) 247 | 248 | out, points_p16 = self.pool(out_p8, points_p8, count_p8)[:2] 249 | norm_points_p16 = self.normalize_centroids(points_p16, out.C, out.tensor_stride[0]) 250 | for module in self.block4: 251 | out = module(out, norm_points_p16) 252 | 253 | out = self.pooltr(out) 254 | out = ME.cat(out, out_p8) 255 | out = self.relu(self.bn5(self.attn5p8(out, norm_points_p8))) 256 | for module in self.block5: 257 | out = module(out, norm_points_p8) 258 | 259 | out = self.pooltr(out) 260 | out = ME.cat(out, out_p4) 261 | out = self.relu(self.bn6(self.attn6p4(out, norm_points_p4))) 262 | for module in self.block6: 263 | out = module(out, norm_points_p4) 264 | 265 | out = self.pooltr(out) 266 | out = ME.cat(out, out_p2) 267 | out = self.relu(self.bn7(self.attn7p2(out, norm_points_p2))) 268 | for module in self.block7: 269 | out = module(out, norm_points_p2) 270 | 271 | out = self.pooltr(out) 272 | out = ME.cat(out, out_p1) 273 | out = self.relu(self.bn8(self.attn8p1(out, norm_points_p1))) 274 | for module in 
self.block8: 275 | out = module(out, norm_points_p1) 276 | 277 | out = self.devoxelize_with_centroids(out, x, pos_embs) 278 | return out 279 | 280 | 281 | @gin.configurable 282 | class FastPointTransformerSmall(FastPointTransformer): 283 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 284 | 285 | 286 | @gin.configurable 287 | class FastPointTransformerSmaller(FastPointTransformer): 288 | LAYERS = (1, 1, 1, 1, 1, 1, 1, 1) 289 | -------------------------------------------------------------------------------- /src/models/resnet.py: -------------------------------------------------------------------------------- 1 | import gin 2 | import torch.nn as nn 3 | import MinkowskiEngine as ME 4 | from MinkowskiEngine.modules.resnet_block import BasicBlock 5 | 6 | 7 | @gin.configurable 8 | class ResNetBase(nn.Module): 9 | QMODE = ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE 10 | LAYER = ME.MinkowskiConvolution 11 | BLOCK = None 12 | LAYERS = () 13 | INIT_DIM = 64 14 | PLANES = (64, 128, 256, 512) 15 | 16 | def __init__(self, in_channels, out_channels=None, D=3): 17 | nn.Module.__init__(self) 18 | assert self.BLOCK is not None 19 | self.in_channels = in_channels 20 | self.out_channels = out_channels 21 | self.D = D 22 | 23 | self.network_initialization(in_channels, out_channels, D) 24 | self.weight_initialization() 25 | 26 | def network_initialization(self, in_channels, out_channels, D): 27 | self.inplanes = self.INIT_DIM 28 | self.conv1 = nn.Sequential( 29 | self.LAYER(in_channels, self.inplanes, kernel_size=3, stride=2, dimension=D), 30 | ME.MinkowskiBatchNorm(self.inplanes), 31 | ME.MinkowskiReLU(inplace=True), 32 | ME.MinkowskiSumPooling(kernel_size=2, stride=2, dimension=D), 33 | ) 34 | 35 | self.layer1 = self._make_layer(self.BLOCK, self.PLANES[0], self.LAYERS[0], stride=2) 36 | self.layer2 = self._make_layer(self.BLOCK, self.PLANES[1], self.LAYERS[1], stride=2) 37 | self.layer3 = self._make_layer(self.BLOCK, self.PLANES[2], self.LAYERS[2], stride=2) 38 | self.layer4 = 
self._make_layer(self.BLOCK, self.PLANES[3], self.LAYERS[3], stride=2) 39 | self.conv5 = nn.Sequential( 40 | ME.MinkowskiDropout(), 41 | ME.MinkowskiConvolution(self.inplanes, 42 | self.inplanes, 43 | kernel_size=3, 44 | stride=3, 45 | dimension=D), 46 | ME.MinkowskiInstanceNorm(self.inplanes), 47 | ME.MinkowskiGELU(), 48 | ) 49 | self.glob_pool = ME.MinkowskiGlobalMaxPooling() 50 | self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True) 51 | 52 | def weight_initialization(self): 53 | for m in self.modules(): 54 | if isinstance(m, nn.BatchNorm1d): 55 | nn.init.constant_(m.weight, 1) 56 | nn.init.constant_(m.bias, 0) 57 | elif isinstance(m, ME.MinkowskiBatchNorm): 58 | nn.init.constant_(m.bn.weight, 1) 59 | nn.init.constant_(m.bn.bias, 0) 60 | elif isinstance(m, nn.Linear): 61 | nn.init.xavier_normal_(m.weight) 62 | if m.bias is not None: 63 | nn.init.constant_(m.bias, 0) 64 | elif isinstance(m, ME.MinkowskiLinear): 65 | nn.init.xavier_normal_(m.linear.weight) 66 | if m.linear.bias is not None: 67 | nn.init.constant_(m.linear.bias, 0) 68 | elif isinstance(m, ME.MinkowskiConvolution): 69 | ME.utils.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu") 70 | 71 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, bn_momentum=0.1): 72 | downsample = None 73 | if stride != 1 or self.inplanes != planes * block.expansion: 74 | downsample = nn.Sequential( 75 | self.LAYER( 76 | self.inplanes, 77 | planes * block.expansion, 78 | kernel_size=1, 79 | stride=stride, 80 | dimension=self.D, 81 | ), 82 | ME.MinkowskiBatchNorm(planes * block.expansion), 83 | ) 84 | layers = [] 85 | layers.append( 86 | block( 87 | self.inplanes, 88 | planes, 89 | stride=stride, 90 | dilation=dilation, 91 | downsample=downsample, 92 | dimension=self.D, 93 | )) 94 | self.inplanes = planes * block.expansion 95 | for _ in range(1, blocks): 96 | layers.append( 97 | block(self.inplanes, planes, stride=1, dilation=dilation, dimension=self.D)) 98 | return 
nn.Sequential(*layers) 99 | 100 | def voxelize(self, x: ME.TensorField): 101 | return x.sparse() 102 | 103 | def forward(self, x: ME.TensorField): 104 | x = self.voxelize(x) 105 | x = self.conv1(x) 106 | x = self.layer1(x) 107 | x = self.layer2(x) 108 | x = self.layer3(x) 109 | x = self.layer4(x) 110 | x = self.conv5(x) 111 | x = self.glob_pool(x) 112 | return self.final(x).F 113 | 114 | 115 | @gin.configurable 116 | class ResNet34(ResNetBase): 117 | BLOCK = BasicBlock 118 | LAYERS = (3, 4, 6, 3) -------------------------------------------------------------------------------- /src/models/resunet.py: -------------------------------------------------------------------------------- 1 | import gin 2 | import torch 3 | import MinkowskiEngine as ME 4 | from MinkowskiEngine.modules.resnet_block import BasicBlock 5 | 6 | from src.models.resnet import ResNetBase 7 | 8 | 9 | @gin.configurable 10 | class Res16UNetBase(ResNetBase): 11 | INIT_DIM = 32 12 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 13 | PLANES = (32, 64, 128, 256, 256, 256, 256, 256) 14 | 15 | def __init__(self, in_channels, out_channels, D=3): 16 | super(Res16UNetBase, self).__init__(in_channels, out_channels, D) 17 | 18 | def network_initialization(self, in_channels, out_channels, D): 19 | self.inplanes = self.INIT_DIM 20 | self.conv0p1s1 = self.LAYER(in_channels, self.inplanes, kernel_size=5, dimension=D) 21 | self.bn0 = ME.MinkowskiBatchNorm(self.inplanes) 22 | 23 | self.conv1p1s2 = self.LAYER(self.inplanes, 24 | self.inplanes, 25 | kernel_size=2, 26 | stride=2, 27 | dimension=D) # pooling 28 | self.bn1 = ME.MinkowskiBatchNorm(self.inplanes) 29 | self.block1 = self._make_layer(self.BLOCK, self.PLANES[0], self.LAYERS[0]) 30 | 31 | self.conv2p2s2 = self.LAYER(self.inplanes, 32 | self.inplanes, 33 | kernel_size=2, 34 | stride=2, 35 | dimension=D) # pooling 36 | self.bn2 = ME.MinkowskiBatchNorm(self.inplanes) 37 | self.block2 = self._make_layer(self.BLOCK, self.PLANES[1], self.LAYERS[1]) 38 | 39 | self.conv3p4s2 = 
self.LAYER(self.inplanes, 40 | self.inplanes, 41 | kernel_size=2, 42 | stride=2, 43 | dimension=D) # pooling 44 | self.bn3 = ME.MinkowskiBatchNorm(self.inplanes) 45 | self.block3 = self._make_layer(self.BLOCK, self.PLANES[2], self.LAYERS[2]) 46 | 47 | self.conv4p8s2 = self.LAYER(self.inplanes, 48 | self.inplanes, 49 | kernel_size=2, 50 | stride=2, 51 | dimension=D) # pooling 52 | self.bn4 = ME.MinkowskiBatchNorm(self.inplanes) 53 | self.block4 = self._make_layer(self.BLOCK, self.PLANES[3], self.LAYERS[3]) 54 | 55 | self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose(self.inplanes, 56 | self.PLANES[4], 57 | kernel_size=2, 58 | stride=2, 59 | dimension=D) # unpooling 60 | self.bntr4 = ME.MinkowskiBatchNorm(self.PLANES[4]) 61 | self.inplanes = self.PLANES[ 62 | 4] + self.PLANES[2] * self.BLOCK.expansion # concatenated dimension 63 | self.block5 = self._make_layer(self.BLOCK, self.PLANES[4], self.LAYERS[4]) 64 | 65 | self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose(self.inplanes, 66 | self.PLANES[5], 67 | kernel_size=2, 68 | stride=2, 69 | dimension=D) # unpooling 70 | self.bntr5 = ME.MinkowskiBatchNorm(self.PLANES[5]) 71 | self.inplanes = self.PLANES[ 72 | 5] + self.PLANES[1] * self.BLOCK.expansion # concatenated dimension 73 | self.block6 = self._make_layer(self.BLOCK, self.PLANES[5], self.LAYERS[5]) 74 | 75 | self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose(self.inplanes, 76 | self.PLANES[6], 77 | kernel_size=2, 78 | stride=2, 79 | dimension=D) # unpooling 80 | self.bntr6 = ME.MinkowskiBatchNorm(self.PLANES[6]) 81 | self.inplanes = self.PLANES[ 82 | 6] + self.PLANES[0] * self.BLOCK.expansion # concatenated dimension 83 | self.block7 = self._make_layer(self.BLOCK, self.PLANES[6], self.LAYERS[6]) 84 | 85 | self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose(self.inplanes, 86 | self.PLANES[7], 87 | kernel_size=2, 88 | stride=2, 89 | dimension=D) # unpooling 90 | self.bntr7 = ME.MinkowskiBatchNorm(self.PLANES[7]) 91 | self.inplanes = self.PLANES[7] + 
self.INIT_DIM # concatenated dimension 92 | self.block8 = self._make_layer(self.BLOCK, self.PLANES[7], self.LAYERS[7]) 93 | 94 | self.final = ME.MinkowskiConvolution(self.PLANES[7] * self.BLOCK.expansion, 95 | out_channels, 96 | kernel_size=1, 97 | stride=1, 98 | bias=True, 99 | dimension=D) 100 | self.relu = ME.MinkowskiReLU(inplace=True) 101 | 102 | def voxelize(self, x: ME.TensorField): 103 | raise NotImplementedError() 104 | 105 | def devoxelize(self, out: ME.SparseTensor, x: ME.TensorField, emb: torch.Tensor): 106 | raise NotImplementedError() 107 | 108 | def forward(self, x: ME.TensorField): 109 | out, emb = self.voxelize(x) 110 | out_p1 = self.relu(self.bn0(self.conv0p1s1(out))) 111 | 112 | out = self.relu(self.bn1(self.conv1p1s2(out_p1))) 113 | out_p2 = self.block1(out) 114 | 115 | out = self.relu(self.bn2(self.conv2p2s2(out_p2))) 116 | out_p4 = self.block2(out) 117 | 118 | out = self.relu(self.bn3(self.conv3p4s2(out_p4))) 119 | out_p8 = self.block3(out) 120 | 121 | out = self.relu(self.bn4(self.conv4p8s2(out_p8))) 122 | out = self.block4(out) 123 | 124 | out = self.relu(self.bntr4(self.convtr4p16s2(out))) 125 | out = ME.cat(out, out_p8) 126 | out = self.block5(out) 127 | 128 | out = self.relu(self.bntr5(self.convtr5p8s2(out))) 129 | out = ME.cat(out, out_p4) 130 | out = self.block6(out) 131 | 132 | out = self.relu(self.bntr6(self.convtr6p4s2(out))) 133 | out = ME.cat(out, out_p2) 134 | out = self.block7(out) 135 | 136 | out = self.relu(self.bntr7(self.convtr7p2s2(out))) 137 | out = ME.cat(out, out_p1) 138 | out = self.block8(out) 139 | return self.devoxelize(out, x, emb) 140 | 141 | 142 | @gin.configurable 143 | class Res16UNet34C(Res16UNetBase): # MinkowskiNet42 144 | BLOCK = BasicBlock 145 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) 146 | PLANES = (32, 64, 128, 256, 256, 128, 96, 96) 147 | 148 | def voxelize(self, x: ME.TensorField): 149 | return x.sparse(), None 150 | 151 | def devoxelize(self, out: ME.SparseTensor, x: ME.TensorField, emb: torch.Tensor): 152 | 
return self.final(out).slice(x).F 153 | 154 | 155 | @gin.configurable 156 | class Res16UNet34CSmall(Res16UNet34C): 157 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 158 | 159 | 160 | @gin.configurable 161 | class Res16UNet34CSmaller(Res16UNet34C): 162 | LAYERS = (1, 1, 1, 1, 1, 1, 1, 1) 163 | -------------------------------------------------------------------------------- /src/models/spvcnn.py: -------------------------------------------------------------------------------- 1 | import gin 2 | import torch 3 | import torch.nn as nn 4 | import MinkowskiEngine as ME 5 | 6 | from src.models.resunet import Res16UNet34C 7 | 8 | 9 | @gin.configurable 10 | class SPVCNN(Res16UNet34C): 11 | def network_initialization(self, in_channels, out_channels, D): 12 | super(SPVCNN, self).network_initialization(in_channels, out_channels, D) 13 | self.final = ME.MinkowskiLinear(self.PLANES[7] * self.BLOCK.expansion, out_channels) 14 | 15 | self.point_transforms = nn.ModuleList([ 16 | nn.Sequential( 17 | ME.MinkowskiLinear(self.INIT_DIM, self.PLANES[3]), 18 | ME.MinkowskiBatchNorm(self.PLANES[3]), 19 | ME.MinkowskiReLU(True) 20 | ), 21 | nn.Sequential( 22 | ME.MinkowskiLinear(self.PLANES[4], self.PLANES[5]), 23 | ME.MinkowskiBatchNorm(self.PLANES[5]), 24 | ME.MinkowskiReLU(True) 25 | ), 26 | nn.Sequential( 27 | ME.MinkowskiLinear(self.PLANES[5], self.PLANES[7]), 28 | ME.MinkowskiBatchNorm(self.PLANES[7]), 29 | ME.MinkowskiReLU(True) 30 | ), 31 | ]) 32 | self.dropout = ME.MinkowskiDropout(0.3, True) 33 | 34 | def voxel_to_point(self, s: ME.SparseTensor, f: ME.TensorField): 35 | feats, _, out_map, weights = ME.MinkowskiInterpolationFunction().apply( 36 | s.F, f.C, s.coordinate_key, s.coordinate_manager 37 | ) 38 | denom = torch.zeros((len(f),), dtype=feats.dtype, device=feats.device) 39 | denom.index_add_(0, out_map.long(), weights) 40 | denom.unsqueeze_(1) 41 | norm_feats = torch.true_divide(feats, denom + 1e-8) 42 | return ME.TensorField( 43 | features=norm_feats, 44 | 
coordinate_field_map_key=f.coordinate_field_map_key, 45 | quantization_mode=f.quantization_mode, 46 | coordinate_manager=f.coordinate_manager 47 | ) 48 | 49 | def forward(self, x: ME.TensorField): 50 | x0 = x.sparse() 51 | x0 = self.relu(self.bn0(self.conv0p1s1(x0))) 52 | z0 = self.voxel_to_point(x0, x) 53 | 54 | x1 = z0.sparse(coordinate_map_key=x0.coordinate_map_key) 55 | x1 = self.relu(self.bn1(self.conv1p1s2(x1))) 56 | x1 = self.block1(x1) 57 | 58 | x2 = self.relu(self.bn2(self.conv2p2s2(x1))) 59 | x2 = self.block2(x2) 60 | 61 | x3 = self.relu(self.bn3(self.conv3p4s2(x2))) 62 | x3 = self.block3(x3) 63 | 64 | x4 = self.relu(self.bn4(self.conv4p8s2(x3))) 65 | x4 = self.block4(x4) 66 | 67 | z1 = self.voxel_to_point(x4, x) 68 | z1 = z1 + self.point_transforms[0](z0).F 69 | 70 | y1 = z1.sparse(coordinate_map_key=x4.coordinate_map_key) 71 | y1 = self.dropout(y1) 72 | y1 = self.relu(self.bntr4(self.convtr4p16s2(y1))) 73 | y1 = ME.cat(y1, x3) 74 | y1 = self.block5(y1) 75 | 76 | y2 = self.relu(self.bntr5(self.convtr5p8s2(y1))) 77 | y2 = ME.cat(y2, x2) 78 | y2 = self.block6(y2) 79 | z2 = self.voxel_to_point(y2, x) 80 | z2 = z2 + self.point_transforms[1](z1).F 81 | 82 | y3 = z2.sparse(coordinate_map_key=x2.coordinate_map_key) 83 | y3 = self.dropout(y3) 84 | y3 = self.relu(self.bntr6(self.convtr6p4s2(y3))) 85 | y3 = ME.cat(y3, x1) 86 | y3 = self.block7(y3) 87 | 88 | y4 = self.relu(self.bntr7(self.convtr7p2s2(y3))) 89 | y4 = ME.cat(y4, x0) 90 | y4 = self.block8(y4) 91 | z3 = self.voxel_to_point(y4, x) 92 | z3 = z3 + self.point_transforms[2](z2).F 93 | return self.final(z3).F 94 | -------------------------------------------------------------------------------- /src/models/transformer_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import MinkowskiEngine as ME 4 | from MinkowskiEngine.MinkowskiKernelGenerator import KernelGenerator 5 | 6 | 7 | class LocalSelfAttentionBase(nn.Module): 8 | 
def __init__(self, kernel_size, stride, dilation, dimension): 9 | super(LocalSelfAttentionBase, self).__init__() 10 | self.kernel_size = kernel_size 11 | self.stride = stride 12 | self.dilation = dilation 13 | self.dimension = dimension 14 | 15 | self.kernel_generator = KernelGenerator(kernel_size=kernel_size, 16 | stride=stride, 17 | dilation=dilation, 18 | dimension=dimension) 19 | self.kernel_volume = self.kernel_generator.kernel_volume 20 | 21 | def get_kernel_map_and_out_key(self, stensor): 22 | cm = stensor.coordinate_manager 23 | in_key = stensor.coordinate_key 24 | out_key = cm.stride(in_key, self.kernel_generator.kernel_stride) 25 | region_type, region_offset, _ = self.kernel_generator.get_kernel( 26 | stensor.tensor_stride, False) 27 | kernel_map = cm.kernel_map(in_key, 28 | out_key, 29 | self.kernel_generator.kernel_stride, 30 | self.kernel_generator.kernel_size, 31 | self.kernel_generator.kernel_dilation, 32 | region_type=region_type, 33 | region_offset=region_offset) 34 | return kernel_map, out_key 35 | 36 | def key_query_map_from_kernel_map(self, kernel_map): 37 | kq_map = [] 38 | for kernel_idx, in_out in kernel_map.items(): 39 | in_out[0] = in_out[0] * self.kernel_volume + kernel_idx 40 | kq_map.append(in_out) 41 | kq_map = torch.cat(kq_map, -1) 42 | return kq_map 43 | 44 | def key_query_indices_from_kernel_map(self, kernel_map): 45 | kq_indices = [] 46 | for _, in_out in kernel_map.items(): 47 | kq_indices.append(in_out) 48 | kq_indices = torch.cat(kq_indices, -1) 49 | return kq_indices 50 | 51 | def key_query_indices_from_key_query_map(self, kq_map): 52 | kq_indices = kq_map.clone() 53 | kq_indices[0] = kq_indices[0] // self.kernel_volume 54 | return kq_indices 55 | 56 | 57 | class ResidualBlockWithPointsBase(nn.Module): 58 | LAYER = None 59 | 60 | def __init__(self, in_channels, out_channels=None, kernel_size=3): 61 | out_channels = in_channels if out_channels is None else out_channels 62 | assert self.LAYER is not None 63 | 
super(ResidualBlockWithPointsBase, self).__init__() 64 | 65 | self.layer1 = self.LAYER(in_channels, out_channels, kernel_size) 66 | self.norm1 = ME.MinkowskiBatchNorm(out_channels) 67 | self.layer2 = self.LAYER(out_channels, out_channels, kernel_size) 68 | self.norm2 = ME.MinkowskiBatchNorm(out_channels) 69 | self.relu = ME.MinkowskiReLU(inplace=True) 70 | 71 | def forward(self, stensor, norm_points): 72 | residual = stensor 73 | out = self.layer1(stensor, norm_points) 74 | out = self.norm1(out) 75 | out = self.relu(out) 76 | out = self.layer2(out, norm_points) 77 | out = self.norm2(out) 78 | out += residual 79 | out = self.relu(out) 80 | return out 81 | -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .segmentation import LitSegMinkowskiModule 2 | 3 | modules = [LitSegMinkowskiModule] 4 | modules_dict = {m.__name__: m for m in modules} 5 | 6 | 7 | def get_lightning_module(name): 8 | assert ( 9 | name in modules_dict.keys() 10 | ), f"{name} not in {modules_dict.keys()}" 11 | return modules_dict[name] -------------------------------------------------------------------------------- /src/modules/segmentation.py: -------------------------------------------------------------------------------- 1 | import gin 2 | import numpy as np 3 | import torch 4 | import pytorch_lightning as pl 5 | import pl_bolts 6 | import torchmetrics 7 | import MinkowskiEngine as ME 8 | 9 | from src.utils.metric import per_class_iou 10 | 11 | 12 | @gin.configurable 13 | class LitSegmentationModuleBase(pl.LightningModule): 14 | def __init__( 15 | self, 16 | model, 17 | num_classes, 18 | lr, 19 | momentum, 20 | weight_decay, 21 | warmup_steps_ratio, 22 | max_steps, 23 | best_metric_type, 24 | ignore_label=255, 25 | dist_sync_metric=False, 26 | lr_eta_min=0., 27 | ): 28 | super(LitSegmentationModuleBase, self).__init__() 29 | for name, value in 
vars().items(): 30 | if name not in ["self", "__class__"]: 31 | setattr(self, name, value) 32 | 33 | self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_label) 34 | self.best_metric_value = -np.inf if best_metric_type == "maximize" else np.inf 35 | self.metric = torchmetrics.ConfusionMatrix( 36 | num_classes=num_classes, 37 | compute_on_step=False, 38 | dist_sync_on_step=dist_sync_metric 39 | ) 40 | 41 | def configure_optimizers(self): 42 | optimizer = torch.optim.SGD( 43 | self.model.parameters(), 44 | lr=self.lr, 45 | momentum=self.momentum, 46 | weight_decay=self.weight_decay, 47 | ) 48 | scheduler = pl_bolts.optimizers.LinearWarmupCosineAnnealingLR( 49 | optimizer, 50 | warmup_epochs=int(self.warmup_steps_ratio * self.max_steps), 51 | max_epochs=self.max_steps, 52 | eta_min=self.lr_eta_min, 53 | ) 54 | return { 55 | "optimizer": optimizer, 56 | "lr_scheduler": { 57 | "scheduler": scheduler, 58 | "interval": "step" 59 | } 60 | } 61 | 62 | def training_step(self, batch, batch_idx): 63 | in_data = self.prepare_input_data(batch) 64 | logits = self.model(in_data) 65 | loss = self.criterion(logits, batch["labels"]) 66 | self.log("train_loss", loss.item(), batch_size=batch["batch_size"], logger=True, prog_bar=True) 67 | torch.cuda.empty_cache() 68 | return loss 69 | 70 | def validation_step(self, batch, batch_idx): 71 | in_data = self.prepare_input_data(batch) 72 | logits = self.model(in_data) 73 | loss = self.criterion(logits, batch["labels"]) 74 | self.log("val_loss", loss.item(), batch_size=batch["batch_size"], logger=True, prog_bar=True) 75 | pred = logits.argmax(dim=1, keepdim=False) 76 | mask = batch["labels"] != self.ignore_label 77 | self.metric(pred[mask], batch["labels"][mask]) 78 | torch.cuda.empty_cache() 79 | return loss 80 | 81 | def validation_epoch_end(self, outputs): 82 | confusion_matrix = self.metric.compute().cpu().numpy() 83 | self.metric.reset() 84 | ious = per_class_iou(confusion_matrix) * 100 85 | accs = confusion_matrix.diagonal() / 
confusion_matrix.sum(1) * 100 86 | miou = np.nanmean(ious) 87 | macc = np.nanmean(accs) 88 | 89 | def compare(prev, cur): 90 | return prev < cur if self.best_metric_type == "maximize" else prev > cur 91 | 92 | if compare(self.best_metric_value, miou): 93 | self.best_metric_value = miou 94 | self.log("val_best_mIoU", self.best_metric_value, logger=True) 95 | self.log("val_mIoU", miou, logger=True) 96 | self.log("val_mAcc", macc, logger=True) 97 | 98 | def prepare_input_data(self, batch): 99 | raise NotImplementedError 100 | 101 | 102 | @gin.configurable 103 | class LitSegMinkowskiModule(LitSegmentationModuleBase): 104 | def prepare_input_data(self, batch): 105 | in_data = ME.TensorField( 106 | features=batch["features"], 107 | coordinates=batch["coordinates"], 108 | quantization_mode=self.model.QMODE 109 | ) 110 | return in_data -------------------------------------------------------------------------------- /src/utils/file.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def ensure_dir(path): 5 | if not os.path.exists(path): 6 | os.makedirs(path, mode=0o755) -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from rich.logging import RichHandler 5 | 6 | 7 | def setup_logger(exp_name, debug): 8 | from importlib import reload  # `imp` is deprecated and removed in Python 3.12 9 | 10 | reload(logging) 11 | 12 | CUDA_TAG = os.environ.get("CUDA_VISIBLE_DEVICES", "0") 13 | EXP_TAG = exp_name 14 | 15 | logger_config = dict( 16 | level=logging.DEBUG if debug else logging.INFO, 17 | format=f"{CUDA_TAG}:[{EXP_TAG}] %(message)s", 18 | handlers=[RichHandler()], 19 | datefmt="[%X]", 20 | ) 21 | logging.basicConfig(**logger_config) -------------------------------------------------------------------------------- /src/utils/metric.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def fast_hist(pred, target, num_classes, ignore_label=255): 6 | mask = (target != ignore_label) & (target < num_classes) 7 | return np.bincount(num_classes * target[mask].astype(int) + pred[mask], 8 | minlength=num_classes**2).reshape(num_classes, num_classes) 9 | 10 | 11 | def per_class_iou(hist): 12 | with np.errstate(divide='ignore', invalid='ignore'): 13 | return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 14 | -------------------------------------------------------------------------------- /src/utils/misc.py: -------------------------------------------------------------------------------- 1 | import gin 2 | 3 | 4 | @gin.configurable 5 | def logged_hparams(keys): 6 | C = dict() 7 | for k in keys: 8 | C[k] = gin.query_parameter(f"{k}") 9 | return C 10 | 11 | 12 | def load_from_pl_state_dict(model, pl_state_dict): 13 | state_dict = {} 14 | for k, v in pl_state_dict.items(): 15 | state_dict[k[6:]] = v  # drop the 6-character "model." prefix added by the LightningModule 16 | model.load_state_dict(state_dict) 17 | return model -------------------------------------------------------------------------------- /src/utils/visualization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | 4 | 5 | def make_open3d_point_cloud(xyz, color=None): 6 | pcd = o3d.geometry.PointCloud() 7 | pcd.points = o3d.utility.Vector3dVector(xyz) 8 | if color is not None: 9 | if len(color) != len(xyz): 10 | color = np.tile(color, (len(xyz), 1)) 11 | pcd.colors = o3d.utility.Vector3dVector(color) 12 | return pcd -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from datetime import datetime 4 | 5 | import gin 6 | import pytorch_lightning as pl 7 | 8 | from src.models import 
get_model 9 | from src.data import get_data_module 10 | from src.modules import get_lightning_module 11 | from src.utils.file import ensure_dir 12 | from src.utils.logger import setup_logger 13 | from src.utils.misc import logged_hparams 14 | 15 | 16 | @gin.configurable 17 | def train( 18 | save_path, 19 | project_name, 20 | run_name, 21 | lightning_module_name, 22 | data_module_name, 23 | model_name, 24 | gpus, 25 | log_every_n_steps, 26 | check_val_every_n_epoch, 27 | refresh_rate_per_second, 28 | best_metric, 29 | max_epoch, 30 | max_step, 31 | ): 32 | now = datetime.now().strftime('%m-%d-%H-%M-%S') 33 | run_name = run_name + "_" + now 34 | save_path = os.path.join(save_path, run_name) 35 | ensure_dir(save_path) 36 | 37 | data_module = get_data_module(data_module_name)() 38 | model = get_model(model_name)() 39 | pl_module = get_lightning_module(lightning_module_name)(model=model, max_steps=max_step) 40 | gin.finalize() 41 | 42 | hparams = logged_hparams() 43 | callbacks = [ 44 | pl.callbacks.TQDMProgressBar(refresh_rate=refresh_rate_per_second), 45 | pl.callbacks.ModelCheckpoint( 46 | dirpath=save_path, monitor=best_metric, save_last=True, save_top_k=1, mode="max" 47 | ), 48 | pl.callbacks.LearningRateMonitor(), 49 | ] 50 | loggers = [ 51 | pl.loggers.WandbLogger( 52 | name=run_name, 53 | save_dir=save_path, 54 | project=project_name, 55 | log_model=True, 56 | entity="chrockey", 57 | config=hparams, 58 | ) 59 | ] 60 | additional_kwargs = dict() 61 | if gpus > 1: 62 | raise NotImplementedError("Currently, multi-gpu training is not supported.") 63 | 64 | trainer = pl.Trainer( 65 | default_root_dir=save_path, 66 | max_epochs=max_epoch, 67 | max_steps=max_step, 68 | gpus=gpus, 69 | callbacks=callbacks, 70 | logger=loggers, 71 | log_every_n_steps=log_every_n_steps, 72 | check_val_every_n_epoch=check_val_every_n_epoch, 73 | **additional_kwargs 74 | ) 75 | 76 | # write config file 77 | with open(os.path.join(save_path, "config.gin"), "w") as f: 78 | 
f.write(gin.operative_config_str()) 79 | 80 | trainer.fit(pl_module, data_module) 81 | 82 | 83 | if __name__ == "__main__": 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument("config", type=str) 86 | parser.add_argument("--save_path", type=str, default="experiments") 87 | parser.add_argument("--run_name", type=str, default="default") 88 | parser.add_argument("--seed", type=int, default=1235) 89 | parser.add_argument("-v", "--voxel_size", type=float, default=None) 90 | parser.add_argument("--debug", action="store_true") 91 | args = parser.parse_args() 92 | 93 | pl.seed_everything(args.seed) 94 | gin.parse_config_file(args.config) 95 | if args.voxel_size is not None: 96 | gin.bind_parameter("DimensionlessCoordinates.voxel_size", args.voxel_size) 97 | setup_logger(args.run_name, args.debug) 98 | 99 | train(save_path=args.save_path, run_name=args.run_name) --------------------------------------------------------------------------------