├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── assets
└── overview.png
├── config
├── default.gin
├── s3dis
│ ├── eval_default.gin
│ ├── eval_fpt.gin
│ ├── eval_res16unet34c.gin
│ ├── train_default.gin
│ ├── train_fpt.gin
│ └── train_res16unet34c.gin
└── scannet
│ ├── eval_default.gin
│ ├── eval_fpt.gin
│ ├── eval_res16unet34c.gin
│ ├── train_default.gin
│ ├── train_fpt.gin
│ └── train_res16unet34c.gin
├── environment.yaml
├── eval.py
├── setup.sh
├── src
├── cscore
│ ├── calculate.py
│ └── prepare.py
├── cuda_ops
│ ├── functions
│ │ └── sparse_ops.py
│ ├── setup.py
│ └── src
│ │ ├── cuda_ops_api.cpp
│ │ ├── cuda_utils.h
│ │ ├── dot_product
│ │ ├── dot_product.cpp
│ │ ├── dot_product_kernel.cu
│ │ └── dot_product_kernel.h
│ │ └── scalar_attention
│ │ ├── scalar_attention.cpp
│ │ ├── scalar_attention_kernel.cu
│ │ └── scalar_attention_kernel.h
├── data
│ ├── __init__.py
│ ├── collate.py
│ ├── meta_data
│ │ ├── s3dis
│ │ │ ├── area1.txt
│ │ │ ├── area2.txt
│ │ │ ├── area3.txt
│ │ │ ├── area4.txt
│ │ │ ├── area5.txt
│ │ │ └── area6.txt
│ │ └── scannet
│ │ │ ├── scannetv2_test.txt
│ │ │ ├── scannetv2_train.txt
│ │ │ ├── scannetv2_trainval.txt
│ │ │ └── scannetv2_val.txt
│ ├── preprocess_s3dis.py
│ ├── preprocess_scannet.py
│ ├── s3dis_loader.py
│ ├── sampler.py
│ ├── scannet_loader.py
│ └── transforms.py
├── models
│ ├── __init__.py
│ ├── common.py
│ ├── fast_point_transformer.py
│ ├── resnet.py
│ ├── resunet.py
│ ├── spvcnn.py
│ └── transformer_base.py
├── modules
│ ├── __init__.py
│ └── segmentation.py
└── utils
│ ├── file.py
│ ├── logger.py
│ ├── metric.py
│ ├── misc.py
│ └── visualization.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/
2 | .empty/
3 | __pycache__/
4 | wandb/
5 | tensorboard/
6 |
7 | src/cuda_ops/build/
8 | src/cuda_ops/cuda_ops.egg-info/
9 | src/cuda_ops/cuda_sparse_ops.egg-info/
10 | src/cuda_ops/dist/
11 |
12 | Open3D/
13 | *ipynb*
14 |
15 | *.ply
16 | *.pyc
17 | *.png
18 | *.ckpt
19 |
20 | thirdparty/MinkowskiEngine/MinkowskiEngine.egg-info/
21 | thirdparty/MinkowskiEngine/build/
22 | thirdparty/MinkowskiEngine/dist/
23 | thirdparty-temp/
24 |
25 | experiments/
26 | checkpoints/
27 | consistency_outputs/
28 | votenet-logs-temp/
29 | sbatch/*.out
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "thirdparty/MinkowskiEngine"]
2 | path = thirdparty/MinkowskiEngine
3 | url = git@github.com:chrockey/MinkowskiEngine.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Chunghyun Park
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fast Point Transformer
2 | ### [Project Page](http://cvlab.postech.ac.kr/research/FPT/) | [Paper](https://arxiv.org/abs/2112.04702)
3 | This repository contains the official source code and data for our paper:
4 |
5 | >[Fast Point Transformer](https://arxiv.org/abs/2112.04702)
6 | > [Chunghyun Park](https://chrockey.github.io/),
7 | > [Yoonwoo Jeong](https://yoonwoojeong.medium.com/about),
8 | > [Minsu Cho](http://cvlab.postech.ac.kr/~mcho/), and
9 | > [Jaesik Park](http://jaesik.info/)
10 | > POSTECH GSAI & CSE
11 | > CVPR, New Orleans, 2022.
12 |
13 |
14 |

15 |
16 |
17 | ## Overview
18 | This work introduces *Fast Point Transformer*, which consists of a new lightweight self-attention layer. Our approach encodes continuous 3D coordinates, and the voxel hashing-based architecture boosts computational efficiency. The proposed method is demonstrated with 3D semantic segmentation and 3D detection. The accuracy of our approach is competitive with the best voxel-based method, and our network achieves 129 times faster inference than the state-of-the-art method, Point Transformer, with a reasonable accuracy trade-off in 3D semantic segmentation on the S3DIS dataset.
19 |
20 | ## Citation
21 | If you find our code or paper useful, please consider citing our paper:
22 |
23 | ```BibTeX
24 | @inproceedings{park2022fast,
25 | title={Fast Point Transformer},
26 | author={Park, Chunghyun and Jeong, Yoonwoo and Cho, Minsu and Park, Jaesik},
27 | booktitle={Proceedings of the {IEEE/CVF} Conference on Computer Vision and Pattern Recognition (CVPR)},
28 | month={June},
29 | year={2022},
30 | pages={16949-16958}
31 | }
32 | ```
33 |
34 | ## Experiments
35 | ### 1. S3DIS Area 5 test
36 | We denote MinkowskiNet42 trained with this repository as MinkowskiNet42†.
37 | We use voxel size 4cm for both MinkowskiNet42† and our Fast Point Transformer.
38 |
39 | | Model | Latency (sec) | mAcc (%) | mIoU (%) | Reference |
40 | |:----------------------------------|--------------------:|:--------:|:--------:|:---------:|
41 | | PointTransformer | 18.07 | 76.5 | 70.4 | [Codes from the authors](https://github.com/POSTECH-CVLab/point-transformer) |
42 | | MinkowskiNet42† | 0.08 | 74.1 | 67.2 | [Checkpoint](https://postechackr-my.sharepoint.com/:u:/g/personal/p0125ch_postech_ac_kr/EZcO0DH6QeNGgIwGFZsmL-4BAlikmHAHlBs4JBcS5XfpVQ?download=1) |
43 | | + rotation average | 0.66 | 75.1 | 69.0 | - |
44 | | FastPointTransformer | 0.14 | 76.6 | 69.2 | [Checkpoint](https://postechackr-my.sharepoint.com/:u:/g/personal/p0125ch_postech_ac_kr/ER8KwMTzqAxAvK9KeOZ9U_IBuCAuv4hP6zOWD-3HNO6Xeg?download=1) |
45 | | + rotation average | 1.13 | 77.6 | 71.0 | - |
46 |
47 | ### 2. ScanNetV2 validation
48 | | Model | Voxel Size | mAcc (%) | mIoU (%) | Reference |
49 | |:----------------------------------|:-----------:|:--------:|:--------:|:---------:|
50 | | MinkowskiNet42 | 2cm | 80.4 | 72.2 | [Official GitHub](https://github.com/chrischoy/SpatioTemporalSegmentation) |
51 | | MinkowskiNet42† | 2cm | 81.4 | 72.1 | [Checkpoint](https://postechackr-my.sharepoint.com/:u:/g/personal/p0125ch_postech_ac_kr/EXmE1pWDZ8lEtJU7SQMjkXcBnhSMXFTdHWXkMAAF7KeiuA?download=1) |
52 | | FastPointTransformer | 2cm | 81.2 | 72.5 | [Checkpoint](https://postechackr-my.sharepoint.com/:u:/g/personal/p0125ch_postech_ac_kr/EX_xAyhoNXdJg4eSg2vS_bYB8eFAP7A8FPCYfKOS2T13LQ?download=1) |
53 | | MinkowskiNet42† | 5cm | 76.3 | 67.0 | [Checkpoint](https://postechackr-my.sharepoint.com/:u:/g/personal/p0125ch_postech_ac_kr/EZLG00u5JXJDvOi3sYziOIMB1l6HNN5OW9gTQRFWc6EwzA?download=1) |
54 | | FastPointTransformer | 5cm | 78.9 | 70.0 | [Checkpoint](https://postechackr-my.sharepoint.com/:u:/g/personal/p0125ch_postech_ac_kr/EXbXclfXZGtMpBZY93zi7M8B_tl8rwM65NK1cumN7QM_8g?download=1) |
55 | | MinkowskiNet42† | 10cm | 70.8 | 60.7 | [Checkpoint](https://postechackr-my.sharepoint.com/:u:/g/personal/p0125ch_postech_ac_kr/EVLn0f5noY1Al6Kos9l-0yABM0qZLFt6d4a3yFgTcQ2Vmw?download=1) |
56 | | FastPointTransformer | 10cm | 76.1 | 66.5 | [Checkpoint](https://postechackr-my.sharepoint.com/:u:/g/personal/p0125ch_postech_ac_kr/ESO1jLNHO89ApdjguUauqsMBCx_TijA26UOeGbF4XxQwoA?download=1) |
57 |
58 | ## Installation
59 | This repository is developed and tested on
60 |
61 | - Ubuntu 18.04 and 20.04
62 | - Conda 4.11.0
63 | - CUDA 11.1 and 11.3
64 | - Python 3.8.13
65 | - PyTorch 1.7.1, 1.10.0, and 1.12.1
66 | - MinkowskiEngine 0.5.4
67 |
68 | ### Environment Setup
69 | You can install the environment by using the provided shell script:
70 | ```bash
71 | ~$ git clone --recursive git@github.com:POSTECH-CVLab/FastPointTransformer.git
72 | ~$ cd FastPointTransformer
73 | ~/FastPointTransformer$ bash setup.sh fpt
74 | ~/FastPointTransformer$ conda activate fpt
75 | ```
76 |
77 | ## Training & Evaluation
78 | First, download the datasets (ScanNetV2 and S3DIS) and preprocess them as follows:
79 | ```bash
80 | (fpt) ~/FastPointTransformer$ python src/data/preprocess_scannet.py # you need to modify the data path
81 | (fpt) ~/FastPointTransformer$ python src/data/preprocess_s3dis.py # you need to modify the data path
82 | ```
83 | Then, place the provided metadata of each dataset (`src/data/meta_data`) alongside the preprocessed dataset, following the structure below:
84 |
85 | ```
86 | ${data_dir}
87 | ├── scannetv2
88 | │ ├── meta_data
89 | │ │ ├── scannetv2_train.txt
90 | │ │ ├── scannetv2_val.txt
91 | │ │ └── ...
92 | │ └── scannet_processed
93 | │ ├── train
94 | │ │ ├── scene0000_00.ply
95 | │ │ ├── scene0000_01.ply
96 | │ │ └── ...
97 | │ └── test
98 | └── s3dis
99 | ├── meta_data
100 | │ ├── area1.txt
101 | │ ├── area2.txt
102 | │ └── ...
103 | └── s3dis_processed
104 | ├── Area_1
105 | │ ├── conferenceRoom_1.ply
106 | │ ├── conferenceRoom_2.ply
107 | │ └── ...
108 | ├── Area_2
109 | └── ...
110 | ```
111 |
112 | After that, you can train and evaluate a model by using the provided Python scripts (`train.py` and `eval.py`) with the configuration files in the `config` directory.
113 | For example, you can train and evaluate Fast Point Transformer with voxel size 4cm on S3DIS dataset via the following commands:
114 | ```bash
115 | (fpt) ~/FastPointTransformer$ python train.py config/s3dis/train_fpt.gin
116 | (fpt) ~/FastPointTransformer$ python eval.py config/s3dis/eval_fpt.gin {checkpoint_file} # use -r option for rotation averaging.
117 | ```
118 |
119 | ### Consistency Score
120 | You need to generate predictions via the following command:
121 | ```bash
122 | (fpt) ~/FastPointTransformer$ python -m src.cscore.prepare {checkpoint_file} -m {model_name} -v {voxel_size} # This takes hours.
123 | ```
124 | Then, you can calculate the consistency score (CScore) with:
125 | ```bash
126 | (fpt) ~/FastPointTransformer$ python -m src.cscore.calculate {prediction_dir} # This takes seconds.
127 | ```
128 |
129 | ### 3D Object Detection using VoteNet
130 | Please refer to [this repository](https://github.com/chrockey/FastPointTransformer-VoteNet).
131 |
132 | ## Acknowledgement
133 |
134 | Our code is based on the [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine).
135 | We also thank [Hengshuang Zhao](https://hszhao.github.io/) for providing [the code](https://github.com/POSTECH-CVLab/point-transformer) of [Point Transformer](https://arxiv.org/abs/2012.09164).
136 | If you use our model, please consider citing them as well.
137 |
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/POSTECH-CVLab/FastPointTransformer/9b2cec39bc05f4921e4856d32fc563a6b43087cf/assets/overview.png
--------------------------------------------------------------------------------
/config/default.gin:
--------------------------------------------------------------------------------
1 | # Training
2 | train.project_name = "FastPointTransformer-release"
--------------------------------------------------------------------------------
/config/s3dis/eval_default.gin:
--------------------------------------------------------------------------------
1 | # Constants
2 | in_channels = 3
3 | out_channels = 13
4 |
5 | # Data module
6 | S3DISArea5RGBDataModule.data_root = "/root/data/s3dis" # you need to modify this according to your data.
7 | S3DISArea5RGBDataModule.train_batch_size = None
8 | S3DISArea5RGBDataModule.val_batch_size = 1
9 | S3DISArea5RGBDataModule.train_num_workers = None
10 | S3DISArea5RGBDataModule.val_num_workers = 4
11 | S3DISArea5RGBDataModule.collation_type = "collate_minkowski"
12 | S3DISArea5RGBDataModule.train_transforms = None
13 | S3DISArea5RGBDataModule.eval_transforms = [
14 | "DimensionlessCoordinates",
15 | "NormalizeColor",
16 | ]
17 |
18 | # Augmentation
19 | DimensionlessCoordinates.voxel_size = 0.04
20 |
21 | # Evaluation
22 | eval.data_module_name = "S3DISArea5RGBDataModule"
--------------------------------------------------------------------------------
/config/s3dis/eval_fpt.gin:
--------------------------------------------------------------------------------
1 | include "./config/s3dis/eval_res16unet34c.gin"
2 |
3 | # Model
4 | eval.model_name = "FastPointTransformer"
5 | FastPointTransformer.in_channels = %in_channels
6 | FastPointTransformer.out_channels = %out_channels
--------------------------------------------------------------------------------
/config/s3dis/eval_res16unet34c.gin:
--------------------------------------------------------------------------------
1 | include "./config/s3dis/eval_default.gin"
2 |
3 | # Model
4 | eval.model_name = "Res16UNet34C"
5 | Res16UNet34C.in_channels = %in_channels
6 | Res16UNet34C.out_channels = %out_channels
--------------------------------------------------------------------------------
/config/s3dis/train_default.gin:
--------------------------------------------------------------------------------
1 | include "./config/default.gin"
2 |
3 | # Constants
4 | in_channels = 3
5 | out_channels = 13
6 |
7 | # Data module
8 | S3DISArea5RGBDataModule.data_root = "/root/data/s3dis" # you need to modify this according to your data.
9 | S3DISArea5RGBDataModule.train_batch_size = 8
10 | S3DISArea5RGBDataModule.val_batch_size = 1
11 | S3DISArea5RGBDataModule.train_num_workers = 8
12 | S3DISArea5RGBDataModule.val_num_workers = 4
13 | S3DISArea5RGBDataModule.collation_type = "collate_minkowski"
14 | S3DISArea5RGBDataModule.train_transforms = [
15 | "DimensionlessCoordinates",
16 | "RandomRotation",
17 | "RandomCrop",
18 | "RandomAffine", # affine to rotate the rectangular crop
19 | "CoordinateDropout",
20 | "ChromaticTranslation",
21 | "ChromaticJitter",
22 | "RandomHorizontalFlip",
23 | "RandomTranslation",
24 | "ElasticDistortion",
25 | "NormalizeColor",
26 | ]
27 | S3DISArea5RGBDataModule.eval_transforms = [
28 | "DimensionlessCoordinates",
29 | "NormalizeColor",
30 | ]
31 |
32 | # Augmentation
33 | DimensionlessCoordinates.voxel_size = 0.04
34 | RandomCrop.x = 100
35 | RandomCrop.y = 100
36 | RandomCrop.z = 100
37 | RandomCrop.min_cardinality = 100
38 | RandomCrop.max_retries = 40
39 | RandomHorizontalFlip.upright_axis = "z"
40 | RandomAffine.upright_axis = "z"
41 | RandomAffine.application_ratio = 0.7
42 | ChromaticJitter.std = 0.01
43 | ChromaticJitter.application_ratio = 0.7
44 | ElasticDistortion.distortion_params = [(4, 16)]
45 | ElasticDistortion.application_ratio = 0.7
46 |
47 | # Pytorch lightning module
48 | LitSegmentationModuleBase.num_classes = %out_channels
49 | LitSegmentationModuleBase.lr = 0.1
50 | LitSegmentationModuleBase.momentum = 0.9
51 | LitSegmentationModuleBase.weight_decay = 1e-4
52 | LitSegmentationModuleBase.warmup_steps_ratio = 0.01
53 | LitSegmentationModuleBase.best_metric_type = "maximize"
54 |
55 | # Training
56 | train.data_module_name = "S3DISArea5RGBDataModule"
57 | train.gpus = 1
58 | train.log_every_n_steps = 10
59 | train.check_val_every_n_epoch = 1
60 | train.refresh_rate_per_second = 1
61 | train.best_metric = "val_mIoU"
62 | train.max_epoch = None
63 | train.max_step = 40000
64 |
65 | # Logging
66 | logged_hparams.keys = [
67 | "train.model_name",
68 | "train.data_module_name",
69 | "DimensionlessCoordinates.voxel_size",
70 | "S3DISArea5RGBDataModule.train_transforms",
71 | "S3DISArea5RGBDataModule.eval_transforms",
72 | "S3DISArea5RGBDataModule.train_batch_size",
73 | "S3DISArea5RGBDataModule.val_batch_size",
74 | "S3DISArea5RGBDataModule.train_num_workers",
75 | "S3DISArea5RGBDataModule.val_num_workers",
76 | "RandomCrop.x",
77 | "RandomHorizontalFlip.upright_axis",
78 | "RandomAffine.upright_axis",
79 | "RandomAffine.application_ratio",
80 | "ChromaticJitter.std",
81 | "ChromaticJitter.application_ratio",
82 | "ElasticDistortion.distortion_params",
83 | "ElasticDistortion.application_ratio",
84 | "LitSegmentationModuleBase.lr",
85 | "LitSegmentationModuleBase.momentum",
86 | "LitSegmentationModuleBase.weight_decay",
87 | "LitSegmentationModuleBase.warmup_steps_ratio",
88 | "train.max_step",
89 | ]
--------------------------------------------------------------------------------
/config/s3dis/train_fpt.gin:
--------------------------------------------------------------------------------
1 | # The code should be run on a GPU with at least 80GB memory (e.g., A100-80GB).
2 | include "./config/s3dis/train_res16unet34c.gin"
3 |
4 | # Model
5 | train.model_name = "FastPointTransformer"
6 | FastPointTransformer.in_channels = %in_channels
7 | FastPointTransformer.out_channels = %out_channels
--------------------------------------------------------------------------------
/config/s3dis/train_res16unet34c.gin:
--------------------------------------------------------------------------------
1 | include "./config/s3dis/train_default.gin"
2 |
3 | # Model
4 | train.lightning_module_name = "LitSegMinkowskiModule"
5 | train.model_name = "Res16UNet34C"
6 | Res16UNetBase.in_channels = %in_channels
7 | Res16UNetBase.out_channels = %out_channels
--------------------------------------------------------------------------------
/config/scannet/eval_default.gin:
--------------------------------------------------------------------------------
1 | # Constants
2 | in_channels = 3
3 | out_channels = 20
4 |
5 | # Data module
6 | ScanNetRGBDataModule.data_root = "/root/data/scannetv2" # you need to modify this according to your data.
7 | ScanNetRGBDataModule.train_batch_size = None
8 | ScanNetRGBDataModule.val_batch_size = 1
9 | ScanNetRGBDataModule.train_num_workers = None
10 | ScanNetRGBDataModule.val_num_workers = 4
11 | ScanNetRGBDataModule.collation_type = "collate_minkowski"
12 | ScanNetRGBDataModule.train_transforms = None
13 | ScanNetRGBDataModule.eval_transforms = [
14 | "DimensionlessCoordinates",
15 | "NormalizeColor",
16 | ]
17 |
18 | # Augmentation
19 | DimensionlessCoordinates.voxel_size = 0.02
20 |
21 | # Evaluation
22 | eval.data_module_name = "ScanNetRGBDataModule"
--------------------------------------------------------------------------------
/config/scannet/eval_fpt.gin:
--------------------------------------------------------------------------------
1 | include "./config/scannet/eval_res16unet34c.gin"
2 |
3 | # Model
4 | eval.model_name = "FastPointTransformer"
5 | FastPointTransformer.in_channels = %in_channels
6 | FastPointTransformer.out_channels = %out_channels
--------------------------------------------------------------------------------
/config/scannet/eval_res16unet34c.gin:
--------------------------------------------------------------------------------
1 | include "./config/scannet/eval_default.gin"
2 |
3 | # Model
4 | eval.model_name = "Res16UNet34C"
5 | Res16UNet34C.in_channels = %in_channels
6 | Res16UNet34C.out_channels = %out_channels
--------------------------------------------------------------------------------
/config/scannet/train_default.gin:
--------------------------------------------------------------------------------
1 | include "./config/default.gin"
2 |
3 | # Constants
4 | in_channels = 3
5 | out_channels = 20
6 |
7 | # Data module
8 | ScanNetRGBDataModule.data_root = "/root/data/scannetv2" # you need to modify this according to your data.
9 | ScanNetRGBDataModule.train_batch_size = 8
10 | ScanNetRGBDataModule.val_batch_size = 2
11 | ScanNetRGBDataModule.train_num_workers = 8
12 | ScanNetRGBDataModule.val_num_workers = 4
13 | ScanNetRGBDataModule.collation_type = "collate_minkowski"
14 | ScanNetRGBDataModule.train_transforms = [
15 | "DimensionlessCoordinates",
16 | "RandomRotation",
17 | "RandomCrop",
18 | "RandomAffine", # affine to rotate the rectangular crop
19 | "CoordinateDropout",
20 | "ChromaticTranslation",
21 | "ChromaticJitter",
22 | "RandomHorizontalFlip",
23 | "RandomTranslation",
24 | "ElasticDistortion",
25 | "NormalizeColor",
26 | ]
27 | ScanNetRGBDataModule.eval_transforms = [
28 | "DimensionlessCoordinates",
29 | "NormalizeColor",
30 | ]
31 |
32 | # Augmentation
33 | DimensionlessCoordinates.voxel_size = 0.02
34 | RandomCrop.x = 225
35 | RandomCrop.y = 225
36 | RandomCrop.z = 225
37 | RandomHorizontalFlip.upright_axis = "z"
38 | RandomAffine.upright_axis = "z"
39 | RandomAffine.application_ratio = 0.7
40 | ChromaticJitter.std = 0.01
41 | ChromaticJitter.application_ratio = 0.7
42 | ElasticDistortion.distortion_params = [(4, 16)]
43 | ElasticDistortion.application_ratio = 0.7
44 |
45 | # Pytorch lightning module
46 | LitSegmentationModuleBase.num_classes = %out_channels
47 | LitSegmentationModuleBase.lr = 0.1
48 | LitSegmentationModuleBase.momentum = 0.9
49 | LitSegmentationModuleBase.weight_decay = 1e-4
50 | LitSegmentationModuleBase.warmup_steps_ratio = 0.1
51 | LitSegmentationModuleBase.best_metric_type = "maximize"
52 |
53 | # Training
54 | train.data_module_name = "ScanNetRGBDataModule"
55 | train.gpus = 1
56 | train.log_every_n_steps = 10
57 | train.check_val_every_n_epoch = 1
58 | train.refresh_rate_per_second = 1
59 | train.best_metric = "val_mIoU"
60 | train.max_epoch = None
61 | train.max_step = 100000
62 |
63 | # Logging
64 | logged_hparams.keys = [
65 | "train.model_name",
66 | "train.data_module_name",
67 | "DimensionlessCoordinates.voxel_size",
68 | "ScanNetRGBDataModule.train_transforms",
69 | "ScanNetRGBDataModule.eval_transforms",
70 | "ScanNetRGBDataModule.train_batch_size",
71 | "ScanNetRGBDataModule.val_batch_size",
72 | "ScanNetRGBDataModule.train_num_workers",
73 | "ScanNetRGBDataModule.val_num_workers",
74 | "RandomCrop.x",
75 | "RandomHorizontalFlip.upright_axis",
76 | "RandomAffine.upright_axis",
77 | "RandomAffine.application_ratio",
78 | "ChromaticJitter.std",
79 | "ChromaticJitter.application_ratio",
80 | "ElasticDistortion.distortion_params",
81 | "ElasticDistortion.application_ratio",
82 | "LitSegmentationModuleBase.lr",
83 | "LitSegmentationModuleBase.momentum",
84 | "LitSegmentationModuleBase.weight_decay",
85 | "LitSegmentationModuleBase.warmup_steps_ratio",
86 | "train.max_step",
87 | ]
--------------------------------------------------------------------------------
/config/scannet/train_fpt.gin:
--------------------------------------------------------------------------------
1 | # The code should be run on a GPU with at least 80GB memory (e.g., A100-80GB).
2 | include "./config/scannet/train_res16unet34c.gin"
3 |
4 | # Model
5 | train.model_name = "FastPointTransformer"
6 | FastPointTransformer.in_channels = %in_channels
7 | FastPointTransformer.out_channels = %out_channels
--------------------------------------------------------------------------------
/config/scannet/train_res16unet34c.gin:
--------------------------------------------------------------------------------
1 | # The code should be run on a GPU with at least 24GB memory (e.g., A5000).
2 | include "./config/scannet/train_default.gin"
3 |
4 | # Model
5 | train.lightning_module_name = "LitSegMinkowskiModule"
6 | train.model_name = "Res16UNet34C"
7 | Res16UNet34C.in_channels = %in_channels
8 | Res16UNet34C.out_channels = %out_channels
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: fpt
2 | channels:
3 | - anaconda
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - _openmp_mutex=5.1=1_gnu
8 | - blas=1.0=openblas
9 | - ca-certificates=2022.07.19=h06a4308_0
10 | - certifi=2022.6.15=py38h06a4308_0
11 | - ld_impl_linux-64=2.38=h1181459_1
12 | - libffi=3.4.2=h6a678d5_6
13 | - libgcc-ng=11.2.0=h1234567_1
14 | - libgfortran-ng=11.2.0=h00389a5_1
15 | - libgfortran5=11.2.0=h1234567_1
16 | - libgomp=11.2.0=h1234567_1
17 | - libopenblas=0.3.20=h043d6bf_1
18 | - libstdcxx-ng=11.2.0=h1234567_1
19 | - ncurses=6.4=h6a678d5_0
20 | - nomkl=3.0=0
21 | - openblas-devel=0.3.20=h06a4308_1
22 | - openssl=1.1.1s=h7f8727e_0
23 | - pip=22.3.1=py38h06a4308_0
24 | - python=3.8.16=h7a1cb2a_2
25 | - readline=8.2=h5eee18b_0
26 | - setuptools=65.6.3=py38h06a4308_0
27 | - sqlite=3.40.1=h5082296_0
28 | - tk=8.6.12=h1ccaba5_0
29 | - wheel=0.37.1=pyhd3eb1b0_0
30 | - xz=5.2.10=h5eee18b_1
31 | - zlib=1.2.13=h5eee18b_0
32 | - pip:
33 | - absl-py==1.4.0
34 | - addict==2.4.0
35 | - aiohttp==3.8.3
36 | - aiosignal==1.3.1
37 | - appdirs==1.4.4
38 | - asttokens==2.2.1
39 | - async-timeout==4.0.2
40 | - attrs==22.2.0
41 | - backcall==0.2.0
42 | - cachetools==5.3.0
43 | - charset-normalizer==2.1.1
44 | - click==8.1.3
45 | - comm==0.1.2
46 | - configargparse==1.5.3
47 | - contourpy==1.0.7
48 | - cuda-sparse-ops==0.1.0
49 | - cycler==0.11.0
50 | - dash==2.8.1
51 | - dash-core-components==2.0.0
52 | - dash-html-components==2.0.0
53 | - dash-table==5.0.0
54 | - debugpy==1.6.6
55 | - decorator==5.1.1
56 | - docker-pycreds==0.4.0
57 | - einops==0.6.0
58 | - executing==1.2.0
59 | - fastjsonschema==2.16.2
60 | - fire==0.5.0
61 | - flask==2.2.2
62 | - fonttools==4.38.0
63 | - frozenlist==1.3.3
64 | - fsspec==2023.1.0
65 | - gin-config==0.5.0
66 | - gitdb==4.0.10
67 | - gitpython==3.1.30
68 | - google-auth==2.16.0
69 | - google-auth-oauthlib==0.4.6
70 | - grpcio==1.51.1
71 | - h5py==3.8.0
72 | - idna==3.4
73 | - importlib-metadata==6.0.0
74 | - importlib-resources==5.10.2
75 | - ipykernel==6.21.0
76 | - ipython==8.9.0
77 | - ipywidgets==8.0.4
78 | - itsdangerous==2.1.2
79 | - jedi==0.18.2
80 | - jinja2==3.1.2
81 | - joblib==1.2.0
82 | - jsonschema==4.17.3
83 | - jupyter-client==8.0.2
84 | - jupyter-core==5.2.0
85 | - jupyterlab-widgets==3.0.5
86 | - kiwisolver==1.4.4
87 | - lightning-bolts==0.6.0.post1
88 | - lightning-utilities==0.3.0
89 | - markdown==3.4.1
90 | - markdown-it-py==2.1.0
91 | - markupsafe==2.1.2
92 | - matplotlib==3.6.3
93 | - matplotlib-inline==0.1.6
94 | - mdurl==0.1.2
95 | - minkowskiengine==0.5.4
96 | - multidict==6.0.4
97 | - nbformat==5.5.0
98 | - numpy==1.24.1
99 | - oauthlib==3.2.2
100 | - open3d==0.16.0
101 | - packaging==23.0
102 | - pandas==1.5.3
103 | - parso==0.8.3
104 | - pathtools==0.1.2
105 | - pexpect==4.8.0
106 | - pickleshare==0.7.5
107 | - pillow==9.4.0
108 | - pkgutil-resolve-name==1.3.10
109 | - platformdirs==2.6.2
110 | - plotly==5.13.0
111 | - plyfile==0.7.4
112 | - prompt-toolkit==3.0.36
113 | - protobuf==3.20.3
114 | - psutil==5.9.4
115 | - ptyprocess==0.7.0
116 | - pure-eval==0.2.2
117 | - pyasn1==0.4.8
118 | - pyasn1-modules==0.2.8
119 | - pygments==2.14.0
120 | - pyparsing==3.0.9
121 | - pyquaternion==0.9.9
122 | - pyrsistent==0.19.3
123 | - python-dateutil==2.8.2
124 | - pytorch-lightning==1.8.2
125 | - pytz==2022.7.1
126 | - pyyaml==6.0
127 | - pyzmq==25.0.0
128 | - requests==2.28.2
129 | - requests-oauthlib==1.3.1
130 | - rich==13.3.1
131 | - rsa==4.9
132 | - scikit-learn==1.2.1
133 | - scipy==1.10.0
134 | - sentry-sdk==1.14.0
135 | - setproctitle==1.3.2
136 | - six==1.16.0
137 | - smmap==5.0.0
138 | - stack-data==0.6.2
139 | - tenacity==8.1.0
140 | - tensorboard==2.11.2
141 | - tensorboard-data-server==0.6.1
142 | - tensorboard-plugin-wit==1.8.1
143 | - termcolor==2.2.0
144 | - threadpoolctl==3.1.0
145 | - torch==1.12.1+cu113
146 | - torch-scatter==2.1.0+pt112cu113
147 | - torchmetrics==0.11.0
148 | - torchvision==0.13.1+cu113
149 | - tornado==6.2
150 | - tqdm==4.64.1
151 | - traitlets==5.9.0
152 | - typing-extensions==4.4.0
153 | - urllib3==1.26.14
154 | - wandb==0.13.9
155 | - wcwidth==0.2.6
156 | - werkzeug==2.2.2
157 | - widgetsnbextension==4.0.5
158 | - yarl==1.8.2
159 | - zipp==3.12.0
160 | prefix: /opt/anaconda3/envs/fpt
161 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import argparse
3 |
4 | import gin
5 | import torch
6 | import torchmetrics
7 | import MinkowskiEngine as ME
8 | import numpy as np
9 | from rich.console import Console
10 | from rich.progress import track
11 | from rich.table import Table
12 |
13 | from src.models import get_model
14 | from src.data import get_data_module
15 | from src.utils.metric import per_class_iou
16 | import src.data.transforms as T
17 |
18 |
def print_results(classnames, confusion_matrix):
    """Print mean accuracy, mean IoU and per-class IoU as a rich table.

    Args:
        classnames: Indexable sequence of class names; one IoU column is
            emitted per entry, in index order.
        confusion_matrix: Square numpy array of counts. Its row sums are used
            as the denominator of per-class accuracy, so rows presumably index
            the ground-truth class (TODO confirm against the producer).
    """
    iou_per_class = per_class_iou(confusion_matrix) * 100
    acc_per_class = confusion_matrix.diagonal() / confusion_matrix.sum(1) * 100
    # nanmean skips classes whose score is NaN (e.g. a 0/0 ratio).
    mean_iou = np.nanmean(iou_per_class)
    mean_acc = np.nanmean(acc_per_class)

    header = ["mAcc", "mIoU"] + [classnames[idx] for idx in range(len(classnames))]
    table = Table(show_header=True, header_style="bold")
    for title in header:
        table.add_column(title)

    values = [mean_acc, mean_iou] + iou_per_class.tolist()
    table.add_row(*(f"{value:.2f}" for value in values))
    Console().print(table)
40 |
41 |
def get_rotation_matrices(num_rotations=8):
    """Build homogeneous 4x4 rotation matrices about the z axis.

    Angle k is ``2 * pi / num_rotations * k`` for k in ``range(num_rotations)``,
    so the first matrix is the identity and the angles evenly cover one turn.

    Args:
        num_rotations: Number of matrices to generate.

    Returns:
        List of ``torch.Tensor`` objects of shape (4, 4), created with the
        default tensor dtype.
    """
    matrices = []
    for k in range(num_rotations):
        # Angle computed per iteration so num_rotations == 0 never divides.
        theta = 2 * np.pi / num_rotations * k
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        matrices.append(
            torch.Tensor(
                [
                    [cos_t, -sin_t, 0, 0],
                    [sin_t, cos_t, 0, 0],
                    [0, 0, 1, 0],
                    [0, 0, 0, 1],
                ]
            )
        )
    return matrices
55 |
56 |
@torch.no_grad()
def infer(model, batch, device):
    """Run one forward pass and return per-row class predictions.

    Args:
        model: Network exposing a ``QMODE`` attribute, passed to
            MinkowskiEngine as the quantization mode.
        batch: Mapping providing "features" and "coordinates" entries.
        device: Device on which the TensorField is built.

    Returns:
        CPU tensor with the argmax over dim 1 of the model output.
    """
    field = ME.TensorField(
        coordinates=batch["coordinates"],
        features=batch["features"],
        quantization_mode=model.QMODE,
        device=device,
    )
    scores = model(field)
    return scores.argmax(dim=1).cpu()
67 |
68 |
@torch.no_grad()
def infer_with_rotation_average(model, batch, device):
    """Predict by summing model outputs over several z-axis rotations.

    Each matrix from ``get_rotation_matrices()`` is applied to the xyz part of
    ``batch["coordinates"]``. The model output for every rotated copy is
    accumulated on the CPU, and the argmax of the sum is returned.

    Args:
        model: Network exposing ``QMODE`` and ``out_channels`` attributes.
        batch: Mapping with "coordinates" (first column carried through
            unchanged, next three rotated), "features" and "labels" entries.
        device: Device on which each TensorField is built.

    Returns:
        CPU tensor of argmax indices with one entry per element of
        ``batch["labels"]``.
    """
    num_rows = len(batch["labels"])
    accumulated = torch.zeros((num_rows, model.out_channels), dtype=torch.float32)

    # Column 0 is not rotated (presumably the batch index - TODO confirm);
    # the remaining three columns are the spatial coordinates.
    lead_col, xyz = torch.split(batch["coordinates"], [1, 3], dim=1)

    for rot in get_rotation_matrices():
        rotated_xyz = (T.homogeneous_coords(xyz) @ rot)[:, :3].float()
        field = ME.TensorField(
            features=batch["features"],
            coordinates=torch.cat([lead_col, rotated_xyz], dim=1),
            quantization_mode=model.QMODE,
            device=device,
        )
        accumulated += model(field).cpu()

        # Release unreferenced objects and cached GPU memory between passes.
        gc.collect()
        torch.cuda.empty_cache()

    return accumulated.argmax(dim=1)
91 |
92 |
@gin.configurable
def eval(
    checkpoint_path,
    model_name,
    data_module_name,
    use_rotation_average,
):
    """Evaluate a checkpoint on the validation split and print the metrics.

    NOTE: this shadows the builtin ``eval``. The name is kept because the gin
    configs bind parameters as ``eval.<param>``.

    Args:
        checkpoint_path: Path to a checkpoint dict with a "state_dict" entry.
            A leading "model." on its keys is stripped before loading.
        model_name: Name resolved through ``get_model``.
        data_module_name: Name resolved through ``get_data_module``.
        use_rotation_average: If True, predictions come from
            ``infer_with_rotation_average``; otherwise from ``infer``.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for evaluation, but no CUDA device is available.")
    device = torch.device("cuda")

    # Load onto the CPU: the model is moved to the GPU explicitly below, so the
    # rest of the checkpoint (e.g. optimizer state) never needs GPU memory.
    ckpt = torch.load(checkpoint_path, map_location="cpu")

    def remove_prefix(k, prefix):
        return k[len(prefix):] if k.startswith(prefix) else k

    state_dict = {remove_prefix(k, "model."): v for k, v in ckpt["state_dict"].items()}
    del ckpt  # drop the remaining checkpoint contents before building the model
    model = get_model(model_name)()
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    data_module = get_data_module(data_module_name)()
    data_module.setup("test")
    val_loader = data_module.val_dataloader()
    dset_val = data_module.dset_val
    num_classes = dset_val.NUM_CLASSES
    ignore_label = dset_val.ignore_label

    # Confusion matrix accumulated directly: row = ground-truth class,
    # column = predicted class. This avoids the torchmetrics.ConfusionMatrix
    # constructor, whose signature changed across versions (0.11 requires
    # `task` and no longer accepts `compute_on_step`).
    confmat = torch.zeros((num_classes, num_classes), dtype=torch.long)
    infer_fn = infer_with_rotation_average if use_rotation_average else infer
    with torch.inference_mode(mode=True):
        for batch in track(val_loader):
            pred = infer_fn(model, batch, device)
            labels = batch["labels"]
            mask = labels != ignore_label
            target = labels[mask].long()
            valid_pred = pred[mask].long()
            # Encode each (target, pred) pair as one flat index and count.
            confmat += torch.bincount(
                target * num_classes + valid_pred,
                minlength=num_classes * num_classes,
            ).reshape(num_classes, num_classes)
            torch.cuda.empty_cache()
    confmat = confmat.numpy()

    cnames = dset_val.get_classnames()
    print_results(cnames, confmat)
132 |
133 |
if __name__ == "__main__":
    # Positional: gin config file and checkpoint path. Optional: rotation
    # averaging and an override of the configured voxel size.
    parser = argparse.ArgumentParser()
    for positional in ("config", "ckpt_path"):
        parser.add_argument(positional, type=str)
    parser.add_argument("-r", "--use_rotation_average", action="store_true")
    parser.add_argument("-v", "--voxel_size", type=float, default=None)
    args = parser.parse_args()

    gin.parse_config_file(args.config)
    voxel_override = args.voxel_size
    if voxel_override is not None:
        gin.bind_parameter("DimensionlessCoordinates.voxel_size", voxel_override)

    eval(args.ckpt_path, use_rotation_average=args.use_rotation_average)
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Create a conda environment containing every dependency of this project.
#
# Usage: bash setup.sh ENV_NAME [SERVER]
#   ENV_NAME  name of the conda environment to create (required).
#   SERVER    defaults to "local". A value containing "local" uses the local
#             conda/CUDA paths below; any other value loads the cluster's
#             environment modules.
#
# Bash (not POSIX sh) is required: the script uses [[ ... ]] and `source`.

ENV_NAME="$1"
SERVER=${2:-local}

if [[ -z "$ENV_NAME" ]]; then
    echo "[FPT INFO] Usage: $0 ENV_NAME [SERVER]"
    exit 1
fi

if [[ $SERVER = *local* ]]; then
    echo "[FPT INFO] Running on Local: You should manually load modules..."
    conda init zsh
    source /opt/anaconda3/etc/profile.d/conda.sh # you may need to modify the conda path.
    export CUDA_HOME=/usr/local/cuda-11.1
else
    echo "[FPT INFO] Running on Server..."
    conda init bash
    source ~/anaconda3/etc/profile.d/conda.sh

    module purge
    module load autotools
    module load prun/1.3
    module load gnu8/8.3.0
    module load singularity

    module load cuDNN/cuda/11.1/8.0.4.30
    module load cuda/11.1
    module load nccl/cuda/11.1/2.8.3

    echo "[FPT INFO] Loaded all modules."
fi

# Compare against exact environment names (first column of `conda env list`).
# A substring test would also match e.g. "fpt_old" when asked for "fpt".
if conda env list | awk '{print $1}' | grep -Fxq -- "$ENV_NAME"; then
    echo "[FPT INFO] \"$ENV_NAME\" already exists. Pass the installation."
else
    echo "[FPT INFO] Creating $ENV_NAME..."
    conda create -n "$ENV_NAME" python=3.8 -y
    # Abort instead of installing everything into whatever env is active.
    conda activate "$ENV_NAME" || exit 1
    echo "[FPT INFO] Done."

    echo "[FPT INFO] Installing OpenBLAS and PyTorch..."
    conda install pytorch=1.10.0 torchvision cudatoolkit=11.1 -c pytorch -c nvidia -y
    conda install numpy -y
    conda install openblas-devel -c anaconda -y
    echo "[FPT INFO] Done."

    echo "[FPT INFO] Installing other dependencies..."
    conda install -c anaconda pandas scipy h5py scikit-learn -y
    conda install -c conda-forge plyfile pytorch-lightning torchmetrics wandb wrapt gin-config rich einops -y
    conda install -c open3d-admin -c conda-forge open3d -y
    pip install lightning-bolts
    pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu111.html
    echo "[FPT INFO] Done."

    echo "[FPT INFO] Installing MinkowskiEngine..."
    cd thirdparty/MinkowskiEngine || exit 1
    python setup.py install --blas_include_dirs="${CONDA_PREFIX}/include" --blas=openblas --force_cuda
    cd ../.. || exit 1
    echo "[FPT INFO] Done."

    echo "[FPT INFO] Installing cuda_ops..."
    cd src/cuda_ops || exit 1
    pip3 install .
    cd ../.. || exit 1
    echo "[FPT INFO] Done."

    TORCH="$(python -c "import torch; print(torch.__version__)")"
    ME="$(python -c "import MinkowskiEngine as ME; print(ME.__version__)")"

    echo "[FPT INFO] Finished the installation!"
    echo "[FPT INFO] ========== Configurations =========="
    echo "[FPT INFO] PyTorch version: $TORCH"
    echo "[FPT INFO] MinkowskiEngine version: $ME"
    echo "[FPT INFO] ===================================="
fi
--------------------------------------------------------------------------------
/src/cscore/calculate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import argparse
4 |
5 | import numpy as np
6 | from tqdm import tqdm
7 |
# Counts of the rigid transforms whose predictions are compared to the reference.
# With prepare.py's TransformGenerator defaults (trans_granularity=3,
# rot_granularity=8) these are 3**3 - 1 = 26 non-identity translations and
# 15 z-rotations (k * pi / 8 for k = 1..15).
N_t = 26  # translation
N_r = 15  # rotation
N = N_t + N_r  # the number of rigid transforms (41); derived so the counts stay consistent
M = 312  # the number of validation scenes
12 |
13 |
def _pointwise_scores(ref_fnames, refs, pred_dirs, ignore_label=255):
    """Return per-point agreement rates between references and saved predictions.

    For each scene, count at how many of ``pred_dirs`` the saved prediction
    (``<pred_dir>/<fname>``) equals the reference at each point, then divide
    by ``len(pred_dirs)``. Points whose reference equals ``ignore_label`` are
    set to -1.

    Returns:
        list of arrays, one per scene, aligned with ``ref_fnames``.
    """
    num_dirs = len(pred_dirs)
    scores = []
    for fname, ref in tqdm(zip(ref_fnames, refs)):
        hits = np.zeros_like(ref)
        for pred_dir in pred_dirs:
            pred = np.load(osp.join(pred_dir, fname))
            hits[ref == pred] += 1
        score = hits / num_dirs
        score[ref == ignore_label] = -1
        scores.append(score)
    return scores


def _cloudwise_mean(scores):
    """Return, per scene, the mean point-wise score over non-ignored (> -0.5) points."""
    return [np.mean(score[score > -0.5]) for score in scores]


def main(args):
    """Compute CScores (translation, rotation, full) from saved predictions.

    ``args.dir`` must contain ``reference/`` with one ``.npy`` file per
    validation scene (exactly ``M`` files). It must also contain
    ``transform{i}/`` for every ``i`` in ``range(N)``, holding predictions for
    the same scenes under the i-th transform. The first ``N_t`` of these are
    treated as translations and the next ``N_r`` as rotations.

    Prints the mean scores in percent. When ``args.save`` is set, the full
    point-wise scores are also written to ``<dir>/pointwise_scores``.
    """
    ref_dir = osp.join(args.dir, "reference")
    ref_fnames = sorted(os.listdir(ref_dir))
    assert len(ref_fnames) == M
    pred_dirs = [osp.join(args.dir, f"transform{i}") for i in range(N)]

    print(">>> Loading references...")
    refs = [np.load(osp.join(ref_dir, fname)) for fname in ref_fnames]
    print(" done!")

    # translation: the first N_t transforms
    print(">>> Calculating point-wise CScores for translation...")
    score_trans = _pointwise_scores(ref_fnames, refs, pred_dirs[:N_t])
    print(" done!")

    # rotation: the next N_r transforms
    print(">>> Calculating point-wise CScores for rotation...")
    score_rots = _pointwise_scores(ref_fnames, refs, pred_dirs[N_t:N_t + N_r])
    print(" done!")

    # full: average weighted by the number of transforms of each kind
    print(">>> Calculating point-wise CScores for full rigid transformations...")
    score_fulls = [
        (N_t * score_tran + N_r * score_rot) / N
        for score_tran, score_rot in tqdm(zip(score_trans, score_rots))
    ]
    print(" done!")

    # final calculation
    cloudwise_score_trans = _cloudwise_mean(score_trans)
    cloudwise_score_rots = _cloudwise_mean(score_rots)
    cloudwise_score_fulls = _cloudwise_mean(score_fulls)

    if args.save:
        # save pointwise scores for visualization
        output_dir = osp.join(args.dir, "pointwise_scores")
        os.makedirs(output_dir, exist_ok=True)
        print(">>> Saving point-wise CScores for full rigid transformations...")
        for fname, cloudwise_score_full, score_full in tqdm(zip(ref_fnames, cloudwise_score_fulls, score_fulls)):
            scene_id = fname.split('.')[0]
            # int(): round() on a NumPy scalar may return a float, which ':d' rejects.
            score_tag = int(round(1000 * cloudwise_score_full))
            np.save(osp.join(output_dir, f'{scene_id}-score={score_tag:d}.npy'), score_full)
        print(" done!")

    # final logging
    total_score_tran = np.mean(cloudwise_score_trans)
    total_score_rot = np.mean(cloudwise_score_rots)
    total_score_full = np.mean(cloudwise_score_fulls)
    print(">>> Results:")
    print(f" Rotation: {100 * total_score_rot:.1f}")
    print(f" Translation: {100 * total_score_tran:.1f}")
    print(f" Full: {100 * total_score_full:.1f}")
83 |
84 |
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--dir", type=str)
    # --save additionally writes point-wise CScores for visualization.
    cli.add_argument("--save", action="store_true", default=False)
    main(cli.parse_args())
--------------------------------------------------------------------------------
/src/cscore/prepare.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gc
3 | import argparse
4 | import os.path as osp
5 |
6 | import torch
7 | import numpy as np
8 | import pytorch_lightning as pl
9 | import MinkowskiEngine as ME
10 | from tqdm import tqdm
11 |
12 | from src.models import get_model
13 | from src.data.scannet_loader import ScanNetRGBDataset_
14 | from src.data.transforms import NormalizeColor, homogeneous_coords
15 | from src.utils.misc import load_from_pl_state_dict
16 |
17 |
class SimpleConfig:
    """Plain attribute holder for ScanNet loading options.

    Every constructor argument is stored under the same attribute name.
    ``limit_numpoints`` is always set to -1.
    """

    def __init__(
        self,
        scannet_path="/root/data/scannetv2",
        ignore_label=255,
        voxel_size=0.1,
        cache_data=False
    ):
        self.limit_numpoints = -1
        self.cache_data = cache_data
        self.ignore_label = ignore_label
        self.voxel_size = voxel_size
        self.scannet_path = scannet_path
31 |
32 |
class TransformGenerator:
    """Build the 4x4 homogeneous rigid transforms used for the consistency study.

    Translations: every (dx, dy, dz) whose components are taken from
    ``k * voxel_size / trans_granularity`` for ``k in range(trans_granularity)``,
    excluding the all-zero offset.
    Rotations: rotations about the z-axis by ``k * pi / rot_granularity`` for
    ``k = 1 .. 2 * rot_granularity - 1``.

    After ``generate_transforms()``, ``transform_list`` holds all translations
    followed by all rotations. ``num_trans`` and ``num_rot`` hold their counts.
    """

    def __init__(self, voxel_size=0.1, trans_granularity=3, rot_granularity=8):
        self.voxel_size = voxel_size
        step_t = 1. / trans_granularity
        # 0, 1, 2, ..., granularity - 1 steps of voxel_size / granularity
        self.trans_list = [step_t * k * voxel_size for k in range(trans_granularity)]
        step_r = 1. / rot_granularity
        self.rot_list = [step_r * k * np.pi for k in range(1, 2 * rot_granularity)]
        self._cleanup()

    def generate_transforms(self):
        """(Re)build ``transform_list``: translations first, then rotations."""
        self._cleanup()
        translations = self._get_trans_mtx_list()
        rotations = self._get_rot_mtx_list()
        self.transform_list += translations + rotations

    def _get_trans_mtx_list(self):
        """Return translation matrices for all non-zero offset triples."""
        offsets = [
            (dx, dy, dz)
            for dx in self.trans_list
            for dy in self.trans_list
            for dz in self.trans_list
            if not (dx == 0 and dy == 0 and dz == 0)
        ]
        mtx_list = []
        for offset in offsets:
            T = np.eye(4)
            T[:3, 3] = offset
            mtx_list.append(T)
        self.num_trans += len(offsets)
        return mtx_list

    def _get_rot_mtx_list(self):
        """Return one z-axis rotation matrix per angle in ``rot_list``."""
        mtx_list = []
        for theta in self.rot_list:
            c, s = np.cos(theta), np.sin(theta)
            mtx_list.append(np.array([
                [c, -s, 0, 0],
                [s, c, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 0, 1],
            ]))
        self.num_rot += len(mtx_list)
        return mtx_list

    def _cleanup(self):
        """Reset the counters and the transform list."""
        self.num_trans = 0
        self.num_rot = 0
        self.transform_list = []
83 |
84 |
def save_outputs(args, model, name, tmatrix, device=None):
    """Run ``model`` over the ScanNet validation split and save per-scene predictions.

    Predictions go to
    ``<out_dir>/<model_name>-voxel=<voxel_size>[<postfix>]/<name>/<scene_id>.npy``.
    Points labelled 255 are overwritten with 255. When ``tmatrix`` is given,
    its transpose is right-multiplied with each scene's homogeneous
    coordinates before inference, and the matrix itself is saved once as
    ``tmatrix.npy`` in the same directory.

    Args:
        args: parsed CLI namespace (``out_dir``, ``model_name``, ``voxel_size``,
            ``postfix``, ``scannet_path``).
        model: network exposing ``QMODE``.
        name: output sub-directory name. It must not exist yet.
        tmatrix: 4x4 transform, or None for the untransformed reference run.
        device: device for the input fields. Defaults to the device of the
            model's parameters, so the function no longer depends on a
            script-level global.
    """
    if device is None:
        device = next(model.parameters()).device

    output_dir = osp.join(args.out_dir, args.model_name + f"-voxel={args.voxel_size}")
    if args.postfix is not None:
        output_dir = output_dir + args.postfix
    transform = NormalizeColor()
    dset = ScanNetRGBDataset_("val", args.scannet_path, transform)
    num_samples = len(dset)
    output_name = osp.join(output_dir, name)
    assert not osp.isdir(output_name)
    os.makedirs(output_name, exist_ok=True)
    if tmatrix is not None:
        # The transform is identical for every scene; write it once.
        np.save(osp.join(output_name, "tmatrix.npy"), tmatrix)

    print(f'>>> Saving predictions for {num_samples} val samples...')
    with torch.inference_mode(mode=True):
        for coords_, feats, labels, fname in dset:
            if tmatrix is not None:
                coords = (homogeneous_coords(coords_) @ tmatrix.T)[:, :3]
            else:
                coords = coords_
            coords, feats = ME.utils.sparse_collate(
                [coords / args.voxel_size],
                [feats],
                dtype=torch.float32
            )
            in_field = ME.TensorField(
                features=feats,
                coordinates=coords,
                quantization_mode=model.QMODE,
                device=device
            )
            pred = model(in_field).argmax(dim=1, keepdim=False).cpu().numpy()
            assert len(pred) == len(labels)
            pred[np.where(labels.numpy() == 255)] = 255  # ignore labels
            scene_id = fname.split('/')[-1].split('.')[0]
            np.save(osp.join(output_name, f'{scene_id}.npy'), pred)
            gc.collect()
            torch.cuda.empty_cache()
    print(' done!')
124 |
125 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # ckpt / model_name / voxel_size have no usable default: None would only fail
    # later (torch.load, TransformGenerator) or silently select the "fpt" branch.
    parser.add_argument("-c", "--ckpt", type=str, required=True)
    parser.add_argument("-m", "--model_name", type=str, choices=["mink", "fpt"], required=True)
    parser.add_argument("-v", "--voxel_size", type=float, choices=[0.1, 0.05, 0.02], required=True)
    parser.add_argument("--scannet_path", type=str, default="/root/data/scannetv2")
    parser.add_argument("--out_dir", type=str, default="consistency_outputs")
    parser.add_argument("-p", "--postfix", type=str, default=None)
    args = parser.parse_args()

    assert torch.cuda.is_available()
    device = torch.device("cuda")
    print(f">>> Loading the checkpoint from {args.ckpt}...")
    # Map to CPU; the model is moved to `device` after the weights are loaded.
    ckpt = torch.load(args.ckpt, map_location="cpu")
    pl.seed_everything(7777)
    print(" done!")

    print(">>> Loading the model...")
    if args.model_name == "mink":
        model = get_model("Res16UNet34C")(ScanNetRGBDataset_.IN_CHANNELS, ScanNetRGBDataset_.NUM_CLASSES)
    else:
        model = get_model("FastPointTransformer")(ScanNetRGBDataset_.IN_CHANNELS, ScanNetRGBDataset_.NUM_CLASSES)
    model = load_from_pl_state_dict(model, ckpt["state_dict"])
    model = model.to(device)
    model.eval()
    print(" done!")

    print(">>> Generating rigid transformations...")
    tgenerator = TransformGenerator(voxel_size=args.voxel_size)
    tgenerator.generate_transforms()
    print(f' {len(tgenerator.transform_list)} rigid transformations generated!')

    print(">>> Evaluating the model...")
    save_outputs(args, model, "reference", None)
    for t_idx, tmatrix in enumerate(tqdm(tgenerator.transform_list)):
        print(tmatrix)
        save_outputs(args, model, f"transform{t_idx}", tmatrix)
    print(" done!")
--------------------------------------------------------------------------------
/src/cuda_ops/functions/sparse_ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 |
4 | import cuda_sparse_ops
5 |
6 |
class DotProduct(Function):
    """Autograd Function around the ``cuda_sparse_ops`` dot-product kernels.

    Per the CUDA kernel of this extension, ``forward`` adds, for every mapping
    ``j`` and head ``k``::

        out_F[j, k] += sum_c query[kq_map[1, j], k, c] * pos_enc[kq_map[0, j] % kkk, k, c]

    in place, and returns ``out_F``. ``query`` is unpacked as ``(N, h, c)``,
    ``pos_enc`` has ``kkk`` rows and ``kq_map`` has ``m`` columns.
    NOTE(review): the C++ side reads float32 data and int32 indices; confirm
    that callers pass those dtypes.
    """

    @staticmethod
    def forward(ctx, query, pos_enc, out_F, kq_map):
        assert (query.is_contiguous() and pos_enc.is_contiguous() and out_F.is_contiguous())
        # The kernel indexes kq_map as a flat (2, m) buffer; no-op if already contiguous.
        kq_map = kq_map.contiguous()
        ctx.m = kq_map.shape[1]
        _, ctx.h, ctx.c = query.shape
        ctx.kkk = pos_enc.shape[0]
        ctx.save_for_backward(query, pos_enc, kq_map)
        cuda_sparse_ops.dot_product_forward(ctx.m, ctx.h, ctx.kkk, ctx.c, query, pos_enc,
                                            out_F, kq_map)
        return out_F

    @staticmethod
    def backward(ctx, grad_out_F):
        query, pos_enc, kq_map = ctx.saved_tensors
        # Autograd may hand over a non-contiguous gradient (e.g. an expanded
        # tensor after `.sum()`), but the kernel reads it as a dense buffer.
        grad_out_F = grad_out_F.contiguous()
        grad_query = torch.zeros_like(query)
        grad_pos = torch.zeros_like(pos_enc)
        cuda_sparse_ops.dot_product_backward(ctx.m, ctx.h, ctx.kkk, ctx.c, query, pos_enc,
                                             kq_map, grad_query, grad_pos, grad_out_F)
        # out_F and kq_map receive no gradient.
        return grad_query, grad_pos, None, None


dot_product_cuda = DotProduct.apply
29 |
30 |
class ScalarAttention(Function):
    """Autograd Function around the ``cuda_sparse_ops`` scalar-attention kernels.

    Per the CUDA kernel of this extension, ``forward`` adds, for every mapping
    ``j``, head ``k`` and channel ``c``::

        out_F[kq_indices[1, j], k, c] += weight[j, k] * value[kq_indices[0, j], k, c]

    in place, and returns ``out_F``. ``value`` is unpacked as ``(N, h, c)`` and
    ``kq_indices`` has ``m`` columns.
    NOTE(review): the C++ side reads float32 data and int32 indices; confirm
    that callers pass those dtypes.
    """

    @staticmethod
    def forward(ctx, weight, value, out_F, kq_indices):
        assert (weight.is_contiguous() and value.is_contiguous() and out_F.is_contiguous())
        # The kernel indexes kq_indices as a flat (2, m) buffer; no-op if already contiguous.
        kq_indices = kq_indices.contiguous()
        ctx.m = kq_indices.shape[1]
        _, ctx.h, ctx.c = value.shape
        ctx.save_for_backward(weight, value, kq_indices)
        cuda_sparse_ops.scalar_attention_forward(ctx.m, ctx.h, ctx.c, weight, value, out_F,
                                                 kq_indices)
        return out_F

    @staticmethod
    def backward(ctx, grad_out_F):
        weight, value, kq_indices = ctx.saved_tensors
        # Autograd may hand over a non-contiguous gradient (e.g. an expanded
        # tensor after `.sum()`), but the kernel reads it as a dense buffer.
        grad_out_F = grad_out_F.contiguous()
        grad_weight = torch.zeros_like(weight)
        grad_value = torch.zeros_like(value)
        cuda_sparse_ops.scalar_attention_backward(ctx.m, ctx.h, ctx.c, weight, value,
                                                  kq_indices, grad_weight, grad_value,
                                                  grad_out_F)
        # out_F and kq_indices receive no gradient.
        return grad_weight, grad_value, None, None


scalar_attention_cuda = ScalarAttention.apply
53 |
--------------------------------------------------------------------------------
/src/cuda_ops/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
# Sources of the `cuda_sparse_ops` extension: the pybind11 API plus one
# host wrapper (.cpp) and one kernel file (.cu) per op.
_SOURCES = [
    'src/cuda_ops_api.cpp',
    'src/dot_product/dot_product.cpp',
    'src/dot_product/dot_product_kernel.cu',
    'src/scalar_attention/scalar_attention.cpp',
    'src/scalar_attention/scalar_attention_kernel.cu',
]
_COMPILE_ARGS = {
    'cxx': ['-g'],
    'nvcc': ['-O2'],
}

setup(
    name='cuda_sparse_ops',
    version="0.1.0",
    author='Chunghyun Park and Yoonwoo Jeong',
    ext_modules=[
        CUDAExtension(
            'cuda_sparse_ops',
            _SOURCES,
            extra_compile_args=_COMPILE_ARGS,
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)
22 |
--------------------------------------------------------------------------------
/src/cuda_ops/src/cuda_ops_api.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | #include "dot_product/dot_product_kernel.h"
5 | #include "scalar_attention/scalar_attention_kernel.h"
6 |
// Python bindings for the CUDA sparse ops.
// NOTE(review): TORCH_EXTENSION_NAME is presumably defined by torch's build
// tooling as the extension name given in setup.py ("cuda_sparse_ops", which
// sparse_ops.py imports); confirm against the build flags.
// Each binding exposes a host-side wrapper that unpacks at::Tensor arguments
// into raw pointers and calls the matching kernel launcher.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Dot product between query rows and positional-encoding rows.
    m.def("dot_product_forward", &dot_product_forward, "dot_product_forward");
    m.def("dot_product_backward", &dot_product_backward, "dot_product_backward");
    // Attention-weighted scatter-add of value rows into output rows.
    m.def("scalar_attention_forward", &scalar_attention_forward, "scalar_attention_forward");
    m.def("scalar_attention_backward", &scalar_attention_backward, "scalar_attention_backward");
}
--------------------------------------------------------------------------------
/src/cuda_ops/src/cuda_utils.h:
--------------------------------------------------------------------------------
1 | #ifndef _CUDA_UTILS_H
2 | #define _CUDA_UTILS_H
3 |
4 | #include
5 | #include
6 |
7 | #define TOTAL_THREADS 1024
8 | #define THREADS_PER_BLOCK 256
9 | #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
10 |
11 | inline int opt_n_threads(int work_size) {
12 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0);
13 | return std::max(std::min(1 << pow_2, TOTAL_THREADS), 1);
14 | }
15 |
16 | inline dim3 opt_block_config(int x, int y) {
17 | const int x_threads = opt_n_threads(x);
18 | const int y_threads = std::max(std::min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
19 | dim3 block_config(x_threads, y_threads, 1);
20 | return block_config;
21 | }
22 |
23 | #endif
24 |
--------------------------------------------------------------------------------
/src/cuda_ops/src/dot_product/dot_product.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include "dot_product_kernel.h"
6 |
7 | void dot_product_forward(
8 | int m, int h, int kkk, int c, AT query_tensor, AT pos_tensor, AT out_F_tensor, AT kq_map_tensor
9 | )
10 | {
11 | const float* query = query_tensor.data_ptr();
12 | const float* pos = pos_tensor.data_ptr();
13 | float* out_F = out_F_tensor.data_ptr();
14 | const int* kq_map = kq_map_tensor.data_ptr();
15 |
16 | dot_product_forward_launcher(
17 | m, h, kkk, c, query, pos, out_F, kq_map
18 | );
19 | }
20 |
21 | void dot_product_backward(
22 | int m, int h, int kkk, int c, AT query_tensor, AT pos_tensor, AT kq_map_tensor,
23 | AT grad_query_tensor, AT grad_pos_tensor, AT grad_out_F_tensor
24 | )
25 | {
26 | const float* query = query_tensor.data_ptr();
27 | const float* pos = pos_tensor.data_ptr();
28 | const int* kq_map = kq_map_tensor.data_ptr();
29 |
30 | float* grad_query = grad_query_tensor.data_ptr();
31 | float* grad_pos = grad_pos_tensor.data_ptr();
32 | const float* grad_out_F = grad_out_F_tensor.data_ptr();
33 |
34 | dot_product_backward_launcher(
35 | m, h, kkk, c, query, pos, kq_map,
36 | grad_query, grad_pos, grad_out_F
37 | );
38 | }
--------------------------------------------------------------------------------
/src/cuda_ops/src/dot_product/dot_product_kernel.cu:
--------------------------------------------------------------------------------
1 | #include "../cuda_utils.h"
2 | #include "dot_product_kernel.h"
3 |
4 |
// One thread per (mapping, head) pair. The thread computes the c-channel dot
// product between query row kq_map[1][map_idx] and positional-encoding row
// kq_map[0][map_idx] % kkk for that head, and adds it to out_F[map_idx * h + head].
__global__ void dot_product_forward_kernel(
    int m, int h, int kkk, int c, const float* query, const float* pos, float* out_F, const int* kq_map
)
{
    // m: # of total mappings
    // h: # of attention heads
    // kkk: # of keys (kernel volume)
    // c: # of attention channels

    int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index >= m * h) return;

    int map_idx = index / h;
    int head_idx = index % h;

    int query_idx_ = kq_map[m + map_idx]; // kq_map[1][map_idx]
    int kernel_idx = kq_map[map_idx] % kkk;

    const float* q = query + query_idx_ * h * c + head_idx * c;
    const float* p = pos + kernel_idx * h * c + head_idx * c;

    // Accumulate in a register and touch out_F in global memory only once,
    // instead of a read-modify-write per channel. Each index is owned by
    // exactly one thread, so no atomics are needed.
    float acc = 0.0f;
    for (int i = 0; i < c; i++) {
        acc += q[i] * p[i];
    }
    out_F[index] += acc;
}
31 |
// One thread per (mapping, channel) pair; the thread loops over heads.
// For each head it scatters grad_out_F[map_idx, head] times the other operand
// into grad_query and grad_pos. Different mappings can share query / pos rows,
// hence atomicAdd.
__global__ void dot_product_backward_kernel(
    int m, int h, int kkk, int c, const float* query, const float* pos, const int* kq_map,
    float* grad_query, float* grad_pos, const float* grad_out_F
)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= m * c) return;

    const int map_idx = tid / c;
    const int ch = tid % c;

    const int q_row = kq_map[m + map_idx];    // kq_map[1][map_idx]
    const int k_row = kq_map[map_idx] % kkk;  // kq_map[0][map_idx] % kkk

    for (int head = 0; head < h; ++head) {
        const float g = grad_out_F[map_idx * h + head];
        const int qi = (q_row * h + head) * c + ch;
        const int pi = (k_row * h + head) * c + ch;
        atomicAdd(grad_query + qi, g * pos[pi]);
        atomicAdd(grad_pos + pi, g * query[qi]);
    }
}
61 |
// Launches dot_product_forward_kernel with one thread per (mapping, head) pair
// on PyTorch's current CUDA stream.
void dot_product_forward_launcher(
    int m, int h, int kkk, int c, const float* query, const float* pos, float* out_F, const int* kq_map
) {
    // A grid with zero blocks is an invalid launch configuration; nothing to do.
    if (m * h <= 0) return;
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    dim3 blocks(DIVUP(m * h, THREADS_PER_BLOCK));
    dim3 threads(THREADS_PER_BLOCK);
    dot_product_forward_kernel<<<blocks, threads, 0, stream>>>(
        m, h, kkk, c, query, pos, out_F, kq_map
    );
}
72 |
// Launches dot_product_backward_kernel with one thread per (mapping, channel)
// pair on PyTorch's current CUDA stream.
void dot_product_backward_launcher(
    int m, int h, int kkk, int c, const float* query, const float* pos, const int* kq_map,
    float* grad_query, float* grad_pos, const float* grad_out_F
) {
    // A grid with zero blocks is an invalid launch configuration; nothing to do.
    if (m * c <= 0) return;
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    dim3 blocks(DIVUP(m * c, THREADS_PER_BLOCK));
    dim3 threads(THREADS_PER_BLOCK);
    dot_product_backward_kernel<<<blocks, threads, 0, stream>>>(
        m, h, kkk, c, query, pos, kq_map,
        grad_query, grad_pos, grad_out_F
    );
}
85 |
--------------------------------------------------------------------------------
/src/cuda_ops/src/dot_product/dot_product_kernel.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #ifndef _dot_product_KERNEL
3 | #define _dot_product_KERNEL
4 | #include
5 | #include
6 | #include
7 |
8 | #define AT at::Tensor
9 |
10 | void dot_product_forward(
11 | int m, int h, int kkk, int c, AT query_tensor, AT pos_tensor, AT out_F_tensor, AT kq_map_tensor
12 | );
13 | void dot_product_backward(
14 | int m, int h, int kkk, int c, AT query_tensor, AT pos_tensor, AT kq_map_tensor,
15 | AT grad_query_tensor, AT grad_pos_tensor, AT grad_out_F_tensor
16 | );
17 |
18 | #ifdef __cplusplus
19 | extern "C" {
20 | #endif
21 |
22 | void dot_product_forward_launcher(
23 | int m, int h, int kkk, int c, const float* query, const float* pos, float* out_F, const int* kq_map
24 | );
25 | void dot_product_backward_launcher(
26 | int m, int h, int kkk, int c, const float* query, const float* pos, const int* kq_map,
27 | float* grad_query, float* grad_pos, const float* grad_out_F
28 | );
29 |
30 | #ifdef __cplusplus
31 | }
32 | #endif
33 | #endif
34 |
--------------------------------------------------------------------------------
/src/cuda_ops/src/scalar_attention/scalar_attention.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include "scalar_attention_kernel.h"
6 |
7 | void scalar_attention_forward(
8 | int m, int h, int c, AT weight_tensor, AT value_tensor, AT out_F_tensor, AT kq_indices_tensor
9 | )
10 | {
11 | const float* weight = weight_tensor.data_ptr();
12 | const float* value = value_tensor.data_ptr();
13 | float* out_F = out_F_tensor.data_ptr();
14 | const int* kq_indices = kq_indices_tensor.data_ptr();
15 |
16 | scalar_attention_forward_launcher(
17 | m, h, c, weight, value, out_F, kq_indices
18 | );
19 | }
20 |
21 | void scalar_attention_backward(
22 | int m, int h, int c, AT weight_tensor, AT value_tensor, AT kq_indices_tensor,
23 | AT grad_weight_tensor, AT grad_value_tensor, AT grad_out_F_tensor
24 | )
25 | {
26 | const float* weight = weight_tensor.data_ptr();
27 | const float* value = value_tensor.data_ptr();
28 | const int* kq_indices = kq_indices_tensor.data_ptr();
29 |
30 | float* grad_weight = grad_weight_tensor.data_ptr();
31 | float* grad_value = grad_value_tensor.data_ptr();
32 | const float* grad_out_F = grad_out_F_tensor.data_ptr();
33 |
34 | scalar_attention_backward_launcher(
35 | m, h, c, weight, value, kq_indices,
36 | grad_weight, grad_value, grad_out_F
37 | );
38 | }
--------------------------------------------------------------------------------
/src/cuda_ops/src/scalar_attention/scalar_attention_kernel.cu:
--------------------------------------------------------------------------------
1 | #include "../cuda_utils.h"
2 | #include "scalar_attention_kernel.h"
3 |
4 |
// One thread per (mapping, channel) pair. For every head the thread adds
// weight[map_idx, head] * value[kq_indices[0][map_idx], head, ch] into
// out_F[kq_indices[1][map_idx], head, ch].
// atomicAdd is used presumably because several mappings may share an output row.
__global__ void scalar_attention_forward_kernel(
    int m, int h, int c, const float* weight, const float* value, float* out_F, const int* kq_indices
)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= m * c) return;

    const int map_idx = tid / c;
    const int ch = tid % c;

    const int out_row = kq_indices[m + map_idx];  // kq_indices[1][map_idx]
    const int val_row = kq_indices[map_idx];      // kq_indices[0][map_idx]

    const float* w = weight + map_idx * h;
    for (int head = 0; head < h; ++head) {
        const int off = head * c + ch;
        atomicAdd(out_F + out_row * h * c + off, w[head] * value[val_row * h * c + off]);
    }
}
31 |
// One thread per (mapping, channel) pair; the thread loops over heads.
// grad_weight[map_idx, head] receives contributions from all c channel threads
// of that mapping, and value rows can be shared across mappings, so both
// gradients are accumulated with atomicAdd.
__global__ void scalar_attention_backward_kernel(
    int m, int h, int c, const float* weight, const float* value, const int* kq_indices,
    float* grad_weight, float* grad_value, const float* grad_out_F
)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= m * c) return;

    const int map_idx = tid / c;
    const int ch = tid % c;

    const int out_row = kq_indices[m + map_idx];  // kq_indices[1][map_idx]
    const int val_row = kq_indices[map_idx];      // kq_indices[0][map_idx]

    for (int head = 0; head < h; ++head) {
        const int w_idx = map_idx * h + head;
        const int off = head * c + ch;
        const float g = grad_out_F[out_row * h * c + off];
        const int v_idx = val_row * h * c + off;
        atomicAdd(grad_weight + w_idx, g * value[v_idx]);
        atomicAdd(grad_value + v_idx, g * weight[w_idx]);
    }
}
60 |
// Launches scalar_attention_forward_kernel with one thread per (mapping, channel)
// pair on PyTorch's current CUDA stream.
void scalar_attention_forward_launcher(
    int m, int h, int c, const float* weight, const float* value, float* out_F, const int* kq_indices
) {
    // A grid with zero blocks is an invalid launch configuration; nothing to do.
    if (m * c <= 0) return;
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    dim3 blocks(DIVUP(m * c, THREADS_PER_BLOCK));
    dim3 threads(THREADS_PER_BLOCK);
    scalar_attention_forward_kernel<<<blocks, threads, 0, stream>>>(
        m, h, c, weight, value, out_F, kq_indices
    );
}
71 |
// Launches scalar_attention_backward_kernel with one thread per (mapping, channel)
// pair on PyTorch's current CUDA stream.
void scalar_attention_backward_launcher(
    int m, int h, int c, const float* weight, const float* value, const int* kq_indices,
    float* grad_weight, float* grad_value, const float* grad_out_F
) {
    // A grid with zero blocks is an invalid launch configuration; nothing to do.
    if (m * c <= 0) return;
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    dim3 blocks(DIVUP(m * c, THREADS_PER_BLOCK));
    dim3 threads(THREADS_PER_BLOCK);
    scalar_attention_backward_kernel<<<blocks, threads, 0, stream>>>(
        m, h, c, weight, value, kq_indices,
        grad_weight, grad_value, grad_out_F
    );
}
84 |
--------------------------------------------------------------------------------
/src/cuda_ops/src/scalar_attention/scalar_attention_kernel.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #ifndef _scalar_attention_KERNEL
3 | #define _scalar_attention_KERNEL
4 | #include
5 | #include
6 | #include
7 |
8 | #define AT at::Tensor
9 |
10 | void scalar_attention_forward(
11 | int m, int h, int c, AT weight_tensor, AT value_tensor, AT out_F_tensor, AT kq_indices_tensor
12 | );
13 | void scalar_attention_backward(
14 | int m, int h, int c, AT weight_tensor, AT value_tensor, AT kq_indices_tensor,
15 | AT grad_weight_tensor, AT grad_value_tensor, AT grad_out_F_tensor
16 | );
17 |
18 | #ifdef __cplusplus
19 | extern "C" {
20 | #endif
21 |
22 | void scalar_attention_forward_launcher(
23 | int m, int h, int c, const float* weight, const float* value, float* out_F, const int* kq_indices
24 | );
25 | void scalar_attention_backward_launcher(
26 | int m, int h, int c, const float* weight, const float* value, const int* kq_indices,
27 | float* grad_weight, float* grad_value, const float* grad_out_F
28 | );
29 |
30 | #ifdef __cplusplus
31 | }
32 | #endif
33 | #endif
34 |
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from src.data.scannet_loader import *
4 | from src.data.s3dis_loader import *
5 |
# Registered data modules and datasets, looked up by class name through
# get_data_module / get_dataset.
ALL_DATA_MODULES = [
    ScanNetRGBDataModule,
    S3DISArea5RGBDataModule,
]
ALL_DATASETS = [
    ScanNetRGBDataset,
    S3DISArea5RGBDataset,
    ScanNetRGBDataset_,
]
data_module_str_mapping = dict((cls.__name__, cls) for cls in ALL_DATA_MODULES)
dataset_str_mapping = dict((cls.__name__, cls) for cls in ALL_DATASETS)
17 |
18 |
def get_data_module(name: str):
    """Return the data-module class registered under ``name``.

    Raises:
        KeyError: if ``name`` is not registered. An error listing the
            available names is logged first.
    """
    if name not in data_module_str_mapping:
        logging.error(
            f"data_module {name} does not exist in "
            f"[{', '.join(data_module_str_mapping)}]"
        )
    return data_module_str_mapping[name]
27 |
28 |
def get_dataset(name: str):
    """Return the dataset class registered under ``name``.

    Raises:
        KeyError: if ``name`` is not registered. An error listing the
            available names is logged first.
    """
    if name not in dataset_str_mapping:
        logging.error(
            f"dataset {name} does not exist in "
            f"[{', '.join(dataset_str_mapping)}]"
        )
    return dataset_str_mapping[name]
37 |
--------------------------------------------------------------------------------
/src/data/collate.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import gin
4 | import MinkowskiEngine as ME
5 |
6 |
@gin.configurable
class CollationFunctionFactory:
    """Callable that collates a list of samples according to ``collation_type``.

    * ``"collate_default"``: return the list unchanged.
    * ``"collate_minkowski"``: drop ``None`` samples, then sparse-collate
      coordinates, features and labels with MinkowskiEngine.

    Raises:
        ValueError: for an unknown ``collation_type``.
    """

    def __init__(self, collation_type="collate_default"):
        if collation_type == "collate_default":
            self.collation_fn = self.collate_default
        elif collation_type == "collate_minkowski":
            self.collation_fn = self.collate_minkowski
        else:
            raise ValueError(f"collation_type {collation_type} not found")

    def __call__(self, list_data):
        return self.collation_fn(list_data)

    def collate_default(self, list_data):
        return list_data

    def collate_minkowski(self, list_data):
        """Collate ``(coords, feats, labels, extra)`` samples into one sparse batch.

        ``None`` entries are dropped first. ``batch_size`` reports the number
        of samples actually collated, which matches ``row_splits`` and
        ``extra_packages``.

        Raises:
            ValueError: if every sample is ``None``.
        """
        num_requested = len(list_data)
        list_data = [data for data in list_data if data is not None]
        B = len(list_data)
        if B != num_requested:
            logging.info(f"Retain {B} from {num_requested} data.")
        if B == 0:
            raise ValueError("No data in the batch")

        coords, feats, labels, extra_packages = list(zip(*list_data))
        row_splits = [c.shape[0] for c in coords]
        coords_batch, feats_batch, labels_batch = ME.utils.sparse_collate(
            coords, feats, labels, dtype=coords[0].dtype
        )
        return {
            "coordinates": coords_batch,
            "features": feats_batch,
            "labels": labels_batch,
            "row_splits": row_splits,
            "batch_size": B,
            "extra_packages": extra_packages,
        }
--------------------------------------------------------------------------------
/src/data/meta_data/s3dis/area1.txt:
--------------------------------------------------------------------------------
1 | Area_1/office_26.ply
2 | Area_1/conferenceRoom_2.ply
3 | Area_1/hallway_6.ply
4 | Area_1/office_21.ply
5 | Area_1/hallway_3.ply
6 | Area_1/office_24.ply
7 | Area_1/hallway_8.ply
8 | Area_1/office_20.ply
9 | Area_1/office_22.ply
10 | Area_1/office_13.ply
11 | Area_1/office_6.ply
12 | Area_1/office_23.ply
13 | Area_1/office_7.ply
14 | Area_1/hallway_5.ply
15 | Area_1/office_11.ply
16 | Area_1/copyRoom_1.ply
17 | Area_1/office_30.ply
18 | Area_1/office_28.ply
19 | Area_1/pantry_1.ply
20 | Area_1/office_9.ply
21 | Area_1/office_29.ply
22 | Area_1/office_14.ply
23 | Area_1/office_18.ply
24 | Area_1/office_16.ply
25 | Area_1/office_17.ply
26 | Area_1/office_1.ply
27 | Area_1/office_3.ply
28 | Area_1/office_31.ply
29 | Area_1/office_25.ply
30 | Area_1/office_15.ply
31 | Area_1/hallway_1.ply
32 | Area_1/office_10.ply
33 | Area_1/office_5.ply
34 | Area_1/conferenceRoom_1.ply
35 | Area_1/office_4.ply
36 | Area_1/hallway_4.ply
37 | Area_1/office_2.ply
38 | Area_1/WC_1.ply
39 | Area_1/office_27.ply
40 | Area_1/office_8.ply
41 | Area_1/hallway_7.ply
42 | Area_1/office_12.ply
43 | Area_1/office_19.ply
44 | Area_1/hallway_2.ply
--------------------------------------------------------------------------------
/src/data/meta_data/s3dis/area2.txt:
--------------------------------------------------------------------------------
1 | Area_2/storage_9.ply
2 | Area_2/storage_5.ply
3 | Area_2/hallway_6.ply
4 | Area_2/hallway_9.ply
5 | Area_2/storage_8.ply
6 | Area_2/hallway_3.ply
7 | Area_2/hallway_8.ply
8 | Area_2/storage_2.ply
9 | Area_2/storage_7.ply
10 | Area_2/office_13.ply
11 | Area_2/office_6.ply
12 | Area_2/storage_3.ply
13 | Area_2/office_7.ply
14 | Area_2/storage_6.ply
15 | Area_2/hallway_5.ply
16 | Area_2/hallway_11.ply
17 | Area_2/office_11.ply
18 | Area_2/hallway_12.ply
19 | Area_2/WC_2.ply
20 | Area_2/office_9.ply
21 | Area_2/storage_1.ply
22 | Area_2/office_14.ply
23 | Area_2/auditorium_2.ply
24 | Area_2/office_1.ply
25 | Area_2/office_3.ply
26 | Area_2/hallway_1.ply
27 | Area_2/office_10.ply
28 | Area_2/office_5.ply
29 | Area_2/conferenceRoom_1.ply
30 | Area_2/office_4.ply
31 | Area_2/hallway_4.ply
32 | Area_2/office_2.ply
33 | Area_2/WC_1.ply
34 | Area_2/storage_4.ply
35 | Area_2/auditorium_1.ply
36 | Area_2/office_8.ply
37 | Area_2/hallway_7.ply
38 | Area_2/office_12.ply
39 | Area_2/hallway_2.ply
40 | Area_2/hallway_10.ply
--------------------------------------------------------------------------------
/src/data/meta_data/s3dis/area3.txt:
--------------------------------------------------------------------------------
1 | Area_3/hallway_6.ply
2 | Area_3/hallway_3.ply
3 | Area_3/storage_2.ply
4 | Area_3/office_6.ply
5 | Area_3/office_7.ply
6 | Area_3/hallway_5.ply
7 | Area_3/WC_2.ply
8 | Area_3/office_9.ply
9 | Area_3/storage_1.ply
10 | Area_3/office_1.ply
11 | Area_3/lounge_1.ply
12 | Area_3/office_3.ply
13 | Area_3/hallway_1.ply
14 | Area_3/office_10.ply
15 | Area_3/office_5.ply
16 | Area_3/conferenceRoom_1.ply
17 | Area_3/office_4.ply
18 | Area_3/hallway_4.ply
19 | Area_3/office_2.ply
20 | Area_3/WC_1.ply
21 | Area_3/office_8.ply
22 | Area_3/hallway_2.ply
23 | Area_3/lounge_2.ply
--------------------------------------------------------------------------------
/src/data/meta_data/s3dis/area4.txt:
--------------------------------------------------------------------------------
1 | Area_4/conferenceRoom_2.ply
2 | Area_4/hallway_6.ply
3 | Area_4/office_21.ply
4 | Area_4/hallway_9.ply
5 | Area_4/hallway_3.ply
6 | Area_4/hallway_8.ply
7 | Area_4/office_20.ply
8 | Area_4/office_22.ply
9 | Area_4/hallway_13.ply
10 | Area_4/storage_2.ply
11 | Area_4/office_13.ply
12 | Area_4/office_6.ply
13 | Area_4/storage_3.ply
14 | Area_4/lobby_1.ply
15 | Area_4/office_7.ply
16 | Area_4/hallway_5.ply
17 | Area_4/hallway_11.ply
18 | Area_4/office_11.ply
19 | Area_4/hallway_12.ply
20 | Area_4/WC_4.ply
21 | Area_4/WC_2.ply
22 | Area_4/lobby_2.ply
23 | Area_4/office_9.ply
24 | Area_4/hallway_14.ply
25 | Area_4/WC_3.ply
26 | Area_4/conferenceRoom_3.ply
27 | Area_4/storage_1.ply
28 | Area_4/office_14.ply
29 | Area_4/office_18.ply
30 | Area_4/office_16.ply
31 | Area_4/office_17.ply
32 | Area_4/office_1.ply
33 | Area_4/office_3.ply
34 | Area_4/office_15.ply
35 | Area_4/hallway_1.ply
36 | Area_4/office_10.ply
37 | Area_4/office_5.ply
38 | Area_4/conferenceRoom_1.ply
39 | Area_4/office_4.ply
40 | Area_4/hallway_4.ply
41 | Area_4/office_2.ply
42 | Area_4/WC_1.ply
43 | Area_4/storage_4.ply
44 | Area_4/office_8.ply
45 | Area_4/hallway_7.ply
46 | Area_4/office_12.ply
47 | Area_4/office_19.ply
48 | Area_4/hallway_2.ply
49 | Area_4/hallway_10.ply
--------------------------------------------------------------------------------
/src/data/meta_data/s3dis/area5.txt:
--------------------------------------------------------------------------------
1 | Area_5/office_26.ply
2 | Area_5/conferenceRoom_2.ply
3 | Area_5/hallway_6.ply
4 | Area_5/office_21.ply
5 | Area_5/hallway_9.ply
6 | Area_5/hallway_3.ply
7 | Area_5/office_39.ply
8 | Area_5/office_24.ply
9 | Area_5/hallway_8.ply
10 | Area_5/office_20.ply
11 | Area_5/office_22.ply
12 | Area_5/hallway_13.ply
13 | Area_5/storage_2.ply
14 | Area_5/office_13.ply
15 | Area_5/office_6.ply
16 | Area_5/office_41.ply
17 | Area_5/office_23.ply
18 | Area_5/office_34.ply
19 | Area_5/office_35.ply
20 | Area_5/office_36.ply
21 | Area_5/storage_3.ply
22 | Area_5/lobby_1.ply
23 | Area_5/office_7.ply
24 | Area_5/hallway_5.ply
25 | Area_5/hallway_11.ply
26 | Area_5/office_11.ply
27 | Area_5/office_37.ply
28 | Area_5/hallway_12.ply
29 | Area_5/office_30.ply
30 | Area_5/office_28.ply
31 | Area_5/pantry_1.ply
32 | Area_5/WC_2.ply
33 | Area_5/office_9.ply
34 | Area_5/office_38.ply
35 | Area_5/office_42.ply
36 | Area_5/hallway_14.ply
37 | Area_5/conferenceRoom_3.ply
38 | Area_5/storage_1.ply
39 | Area_5/office_29.ply
40 | Area_5/office_14.ply
41 | Area_5/office_18.ply
42 | Area_5/office_16.ply
43 | Area_5/office_17.ply
44 | Area_5/office_1.ply
45 | Area_5/office_3.ply
46 | Area_5/office_31.ply
47 | Area_5/office_25.ply
48 | Area_5/hallway_15.ply
49 | Area_5/office_15.ply
50 | Area_5/hallway_1.ply
51 | Area_5/office_10.ply
52 | Area_5/office_5.ply
53 | Area_5/conferenceRoom_1.ply
54 | Area_5/office_4.ply
55 | Area_5/hallway_4.ply
56 | Area_5/office_2.ply
57 | Area_5/office_33.ply
58 | Area_5/WC_1.ply
59 | Area_5/office_27.ply
60 | Area_5/storage_4.ply
61 | Area_5/office_40.ply
62 | Area_5/office_8.ply
63 | Area_5/hallway_7.ply
64 | Area_5/office_12.ply
65 | Area_5/office_32.ply
66 | Area_5/office_19.ply
67 | Area_5/hallway_2.ply
68 | Area_5/hallway_10.ply
--------------------------------------------------------------------------------
/src/data/meta_data/s3dis/area6.txt:
--------------------------------------------------------------------------------
1 | Area_6/office_26.ply
2 | Area_6/hallway_6.ply
3 | Area_6/office_21.ply
4 | Area_6/hallway_3.ply
5 | Area_6/office_24.ply
6 | Area_6/office_20.ply
7 | Area_6/office_22.ply
8 | Area_6/office_13.ply
9 | Area_6/office_6.ply
10 | Area_6/office_23.ply
11 | Area_6/office_34.ply
12 | Area_6/office_35.ply
13 | Area_6/office_36.ply
14 | Area_6/office_7.ply
15 | Area_6/hallway_5.ply
16 | Area_6/office_11.ply
17 | Area_6/copyRoom_1.ply
18 | Area_6/office_37.ply
19 | Area_6/office_30.ply
20 | Area_6/office_28.ply
21 | Area_6/pantry_1.ply
22 | Area_6/openspace_1.ply
23 | Area_6/office_9.ply
24 | Area_6/office_29.ply
25 | Area_6/office_14.ply
26 | Area_6/office_18.ply
27 | Area_6/office_16.ply
28 | Area_6/office_17.ply
29 | Area_6/office_1.ply
30 | Area_6/lounge_1.ply
31 | Area_6/office_3.ply
32 | Area_6/office_31.ply
33 | Area_6/office_25.ply
34 | Area_6/office_15.ply
35 | Area_6/hallway_1.ply
36 | Area_6/office_10.ply
37 | Area_6/office_5.ply
38 | Area_6/conferenceRoom_1.ply
39 | Area_6/office_4.ply
40 | Area_6/hallway_4.ply
41 | Area_6/office_2.ply
42 | Area_6/office_33.ply
43 | Area_6/office_27.ply
44 | Area_6/office_8.ply
45 | Area_6/office_12.ply
46 | Area_6/office_32.ply
47 | Area_6/office_19.ply
48 | Area_6/hallway_2.ply
--------------------------------------------------------------------------------
/src/data/meta_data/scannet/scannetv2_test.txt:
--------------------------------------------------------------------------------
1 | scene0707_00
2 | scene0708_00
3 | scene0709_00
4 | scene0710_00
5 | scene0711_00
6 | scene0712_00
7 | scene0713_00
8 | scene0714_00
9 | scene0715_00
10 | scene0716_00
11 | scene0717_00
12 | scene0718_00
13 | scene0719_00
14 | scene0720_00
15 | scene0721_00
16 | scene0722_00
17 | scene0723_00
18 | scene0724_00
19 | scene0725_00
20 | scene0726_00
21 | scene0727_00
22 | scene0728_00
23 | scene0729_00
24 | scene0730_00
25 | scene0731_00
26 | scene0732_00
27 | scene0733_00
28 | scene0734_00
29 | scene0735_00
30 | scene0736_00
31 | scene0737_00
32 | scene0738_00
33 | scene0739_00
34 | scene0740_00
35 | scene0741_00
36 | scene0742_00
37 | scene0743_00
38 | scene0744_00
39 | scene0745_00
40 | scene0746_00
41 | scene0747_00
42 | scene0748_00
43 | scene0749_00
44 | scene0750_00
45 | scene0751_00
46 | scene0752_00
47 | scene0753_00
48 | scene0754_00
49 | scene0755_00
50 | scene0756_00
51 | scene0757_00
52 | scene0758_00
53 | scene0759_00
54 | scene0760_00
55 | scene0761_00
56 | scene0762_00
57 | scene0763_00
58 | scene0764_00
59 | scene0765_00
60 | scene0766_00
61 | scene0767_00
62 | scene0768_00
63 | scene0769_00
64 | scene0770_00
65 | scene0771_00
66 | scene0772_00
67 | scene0773_00
68 | scene0774_00
69 | scene0775_00
70 | scene0776_00
71 | scene0777_00
72 | scene0778_00
73 | scene0779_00
74 | scene0780_00
75 | scene0781_00
76 | scene0782_00
77 | scene0783_00
78 | scene0784_00
79 | scene0785_00
80 | scene0786_00
81 | scene0787_00
82 | scene0788_00
83 | scene0789_00
84 | scene0790_00
85 | scene0791_00
86 | scene0792_00
87 | scene0793_00
88 | scene0794_00
89 | scene0795_00
90 | scene0796_00
91 | scene0797_00
92 | scene0798_00
93 | scene0799_00
94 | scene0800_00
95 | scene0801_00
96 | scene0802_00
97 | scene0803_00
98 | scene0804_00
99 | scene0805_00
100 | scene0806_00
101 |
--------------------------------------------------------------------------------
/src/data/meta_data/scannet/scannetv2_train.txt:
--------------------------------------------------------------------------------
1 | scene0191_00
2 | scene0191_01
3 | scene0191_02
4 | scene0119_00
5 | scene0230_00
6 | scene0528_00
7 | scene0528_01
8 | scene0705_00
9 | scene0705_01
10 | scene0705_02
11 | scene0415_00
12 | scene0415_01
13 | scene0415_02
14 | scene0007_00
15 | scene0141_00
16 | scene0141_01
17 | scene0141_02
18 | scene0515_00
19 | scene0515_01
20 | scene0515_02
21 | scene0447_00
22 | scene0447_01
23 | scene0447_02
24 | scene0531_00
25 | scene0503_00
26 | scene0285_00
27 | scene0069_00
28 | scene0584_00
29 | scene0584_01
30 | scene0584_02
31 | scene0581_00
32 | scene0581_01
33 | scene0581_02
34 | scene0620_00
35 | scene0620_01
36 | scene0263_00
37 | scene0263_01
38 | scene0481_00
39 | scene0481_01
40 | scene0020_00
41 | scene0020_01
42 | scene0291_00
43 | scene0291_01
44 | scene0291_02
45 | scene0469_00
46 | scene0469_01
47 | scene0469_02
48 | scene0659_00
49 | scene0659_01
50 | scene0024_00
51 | scene0024_01
52 | scene0024_02
53 | scene0564_00
54 | scene0117_00
55 | scene0027_00
56 | scene0027_01
57 | scene0027_02
58 | scene0028_00
59 | scene0330_00
60 | scene0418_00
61 | scene0418_01
62 | scene0418_02
63 | scene0233_00
64 | scene0233_01
65 | scene0673_00
66 | scene0673_01
67 | scene0673_02
68 | scene0673_03
69 | scene0673_04
70 | scene0673_05
71 | scene0585_00
72 | scene0585_01
73 | scene0362_00
74 | scene0362_01
75 | scene0362_02
76 | scene0362_03
77 | scene0035_00
78 | scene0035_01
79 | scene0358_00
80 | scene0358_01
81 | scene0358_02
82 | scene0037_00
83 | scene0194_00
84 | scene0321_00
85 | scene0293_00
86 | scene0293_01
87 | scene0623_00
88 | scene0623_01
89 | scene0592_00
90 | scene0592_01
91 | scene0569_00
92 | scene0569_01
93 | scene0413_00
94 | scene0313_00
95 | scene0313_01
96 | scene0313_02
97 | scene0480_00
98 | scene0480_01
99 | scene0401_00
100 | scene0517_00
101 | scene0517_01
102 | scene0517_02
103 | scene0032_00
104 | scene0032_01
105 | scene0613_00
106 | scene0613_01
107 | scene0613_02
108 | scene0306_00
109 | scene0306_01
110 | scene0052_00
111 | scene0052_01
112 | scene0052_02
113 | scene0053_00
114 | scene0444_00
115 | scene0444_01
116 | scene0055_00
117 | scene0055_01
118 | scene0055_02
119 | scene0560_00
120 | scene0589_00
121 | scene0589_01
122 | scene0589_02
123 | scene0610_00
124 | scene0610_01
125 | scene0610_02
126 | scene0364_00
127 | scene0364_01
128 | scene0383_00
129 | scene0383_01
130 | scene0383_02
131 | scene0006_00
132 | scene0006_01
133 | scene0006_02
134 | scene0275_00
135 | scene0451_00
136 | scene0451_01
137 | scene0451_02
138 | scene0451_03
139 | scene0451_04
140 | scene0451_05
141 | scene0135_00
142 | scene0065_00
143 | scene0065_01
144 | scene0065_02
145 | scene0104_00
146 | scene0674_00
147 | scene0674_01
148 | scene0448_00
149 | scene0448_01
150 | scene0448_02
151 | scene0502_00
152 | scene0502_01
153 | scene0502_02
154 | scene0440_00
155 | scene0440_01
156 | scene0440_02
157 | scene0071_00
158 | scene0072_00
159 | scene0072_01
160 | scene0072_02
161 | scene0509_00
162 | scene0509_01
163 | scene0509_02
164 | scene0649_00
165 | scene0649_01
166 | scene0602_00
167 | scene0694_00
168 | scene0694_01
169 | scene0101_00
170 | scene0101_01
171 | scene0101_02
172 | scene0101_03
173 | scene0101_04
174 | scene0101_05
175 | scene0218_00
176 | scene0218_01
177 | scene0579_00
178 | scene0579_01
179 | scene0579_02
180 | scene0039_00
181 | scene0039_01
182 | scene0493_00
183 | scene0493_01
184 | scene0242_00
185 | scene0242_01
186 | scene0242_02
187 | scene0083_00
188 | scene0083_01
189 | scene0127_00
190 | scene0127_01
191 | scene0662_00
192 | scene0662_01
193 | scene0662_02
194 | scene0018_00
195 | scene0087_00
196 | scene0087_01
197 | scene0087_02
198 | scene0332_00
199 | scene0332_01
200 | scene0332_02
201 | scene0628_00
202 | scene0628_01
203 | scene0628_02
204 | scene0134_00
205 | scene0134_01
206 | scene0134_02
207 | scene0238_00
208 | scene0238_01
209 | scene0092_00
210 | scene0092_01
211 | scene0092_02
212 | scene0092_03
213 | scene0092_04
214 | scene0022_00
215 | scene0022_01
216 | scene0467_00
217 | scene0392_00
218 | scene0392_01
219 | scene0392_02
220 | scene0424_00
221 | scene0424_01
222 | scene0424_02
223 | scene0646_00
224 | scene0646_01
225 | scene0646_02
226 | scene0098_00
227 | scene0098_01
228 | scene0044_00
229 | scene0044_01
230 | scene0044_02
231 | scene0510_00
232 | scene0510_01
233 | scene0510_02
234 | scene0571_00
235 | scene0571_01
236 | scene0166_00
237 | scene0166_01
238 | scene0166_02
239 | scene0563_00
240 | scene0172_00
241 | scene0172_01
242 | scene0388_00
243 | scene0388_01
244 | scene0215_00
245 | scene0215_01
246 | scene0252_00
247 | scene0287_00
248 | scene0668_00
249 | scene0572_00
250 | scene0572_01
251 | scene0572_02
252 | scene0026_00
253 | scene0224_00
254 | scene0113_00
255 | scene0113_01
256 | scene0551_00
257 | scene0381_00
258 | scene0381_01
259 | scene0381_02
260 | scene0371_00
261 | scene0371_01
262 | scene0460_00
263 | scene0118_00
264 | scene0118_01
265 | scene0118_02
266 | scene0417_00
267 | scene0008_00
268 | scene0634_00
269 | scene0521_00
270 | scene0123_00
271 | scene0123_01
272 | scene0123_02
273 | scene0045_00
274 | scene0045_01
275 | scene0511_00
276 | scene0511_01
277 | scene0114_00
278 | scene0114_01
279 | scene0114_02
280 | scene0070_00
281 | scene0029_00
282 | scene0029_01
283 | scene0029_02
284 | scene0129_00
285 | scene0103_00
286 | scene0103_01
287 | scene0002_00
288 | scene0002_01
289 | scene0132_00
290 | scene0132_01
291 | scene0132_02
292 | scene0124_00
293 | scene0124_01
294 | scene0143_00
295 | scene0143_01
296 | scene0143_02
297 | scene0604_00
298 | scene0604_01
299 | scene0604_02
300 | scene0507_00
301 | scene0105_00
302 | scene0105_01
303 | scene0105_02
304 | scene0428_00
305 | scene0428_01
306 | scene0311_00
307 | scene0140_00
308 | scene0140_01
309 | scene0182_00
310 | scene0182_01
311 | scene0182_02
312 | scene0142_00
313 | scene0142_01
314 | scene0399_00
315 | scene0399_01
316 | scene0012_00
317 | scene0012_01
318 | scene0012_02
319 | scene0060_00
320 | scene0060_01
321 | scene0370_00
322 | scene0370_01
323 | scene0370_02
324 | scene0310_00
325 | scene0310_01
326 | scene0310_02
327 | scene0661_00
328 | scene0650_00
329 | scene0152_00
330 | scene0152_01
331 | scene0152_02
332 | scene0158_00
333 | scene0158_01
334 | scene0158_02
335 | scene0482_00
336 | scene0482_01
337 | scene0600_00
338 | scene0600_01
339 | scene0600_02
340 | scene0393_00
341 | scene0393_01
342 | scene0393_02
343 | scene0562_00
344 | scene0174_00
345 | scene0174_01
346 | scene0157_00
347 | scene0157_01
348 | scene0161_00
349 | scene0161_01
350 | scene0161_02
351 | scene0159_00
352 | scene0254_00
353 | scene0254_01
354 | scene0115_00
355 | scene0115_01
356 | scene0115_02
357 | scene0162_00
358 | scene0163_00
359 | scene0163_01
360 | scene0523_00
361 | scene0523_01
362 | scene0523_02
363 | scene0459_00
364 | scene0459_01
365 | scene0175_00
366 | scene0085_00
367 | scene0085_01
368 | scene0279_00
369 | scene0279_01
370 | scene0279_02
371 | scene0201_00
372 | scene0201_01
373 | scene0201_02
374 | scene0283_00
375 | scene0456_00
376 | scene0456_01
377 | scene0429_00
378 | scene0043_00
379 | scene0043_01
380 | scene0419_00
381 | scene0419_01
382 | scene0419_02
383 | scene0368_00
384 | scene0368_01
385 | scene0348_00
386 | scene0348_01
387 | scene0348_02
388 | scene0442_00
389 | scene0178_00
390 | scene0380_00
391 | scene0380_01
392 | scene0380_02
393 | scene0165_00
394 | scene0165_01
395 | scene0165_02
396 | scene0181_00
397 | scene0181_01
398 | scene0181_02
399 | scene0181_03
400 | scene0333_00
401 | scene0614_00
402 | scene0614_01
403 | scene0614_02
404 | scene0404_00
405 | scene0404_01
406 | scene0404_02
407 | scene0185_00
408 | scene0126_00
409 | scene0126_01
410 | scene0126_02
411 | scene0519_00
412 | scene0236_00
413 | scene0236_01
414 | scene0189_00
415 | scene0075_00
416 | scene0267_00
417 | scene0192_00
418 | scene0192_01
419 | scene0192_02
420 | scene0281_00
421 | scene0420_00
422 | scene0420_01
423 | scene0420_02
424 | scene0195_00
425 | scene0195_01
426 | scene0195_02
427 | scene0597_00
428 | scene0597_01
429 | scene0597_02
430 | scene0041_00
431 | scene0041_01
432 | scene0111_00
433 | scene0111_01
434 | scene0111_02
435 | scene0666_00
436 | scene0666_01
437 | scene0666_02
438 | scene0200_00
439 | scene0200_01
440 | scene0200_02
441 | scene0536_00
442 | scene0536_01
443 | scene0536_02
444 | scene0390_00
445 | scene0280_00
446 | scene0280_01
447 | scene0280_02
448 | scene0344_00
449 | scene0344_01
450 | scene0205_00
451 | scene0205_01
452 | scene0205_02
453 | scene0484_00
454 | scene0484_01
455 | scene0009_00
456 | scene0009_01
457 | scene0009_02
458 | scene0302_00
459 | scene0302_01
460 | scene0209_00
461 | scene0209_01
462 | scene0209_02
463 | scene0210_00
464 | scene0210_01
465 | scene0395_00
466 | scene0395_01
467 | scene0395_02
468 | scene0683_00
469 | scene0601_00
470 | scene0601_01
471 | scene0214_00
472 | scene0214_01
473 | scene0214_02
474 | scene0477_00
475 | scene0477_01
476 | scene0439_00
477 | scene0439_01
478 | scene0468_00
479 | scene0468_01
480 | scene0468_02
481 | scene0546_00
482 | scene0466_00
483 | scene0466_01
484 | scene0220_00
485 | scene0220_01
486 | scene0220_02
487 | scene0122_00
488 | scene0122_01
489 | scene0130_00
490 | scene0110_00
491 | scene0110_01
492 | scene0110_02
493 | scene0327_00
494 | scene0156_00
495 | scene0266_00
496 | scene0266_01
497 | scene0001_00
498 | scene0001_01
499 | scene0228_00
500 | scene0199_00
501 | scene0219_00
502 | scene0464_00
503 | scene0232_00
504 | scene0232_01
505 | scene0232_02
506 | scene0299_00
507 | scene0299_01
508 | scene0530_00
509 | scene0363_00
510 | scene0453_00
511 | scene0453_01
512 | scene0570_00
513 | scene0570_01
514 | scene0570_02
515 | scene0183_00
516 | scene0239_00
517 | scene0239_01
518 | scene0239_02
519 | scene0373_00
520 | scene0373_01
521 | scene0241_00
522 | scene0241_01
523 | scene0241_02
524 | scene0188_00
525 | scene0622_00
526 | scene0622_01
527 | scene0244_00
528 | scene0244_01
529 | scene0691_00
530 | scene0691_01
531 | scene0206_00
532 | scene0206_01
533 | scene0206_02
534 | scene0247_00
535 | scene0247_01
536 | scene0061_00
537 | scene0061_01
538 | scene0082_00
539 | scene0250_00
540 | scene0250_01
541 | scene0250_02
542 | scene0501_00
543 | scene0501_01
544 | scene0501_02
545 | scene0320_00
546 | scene0320_01
547 | scene0320_02
548 | scene0320_03
549 | scene0631_00
550 | scene0631_01
551 | scene0631_02
552 | scene0255_00
553 | scene0255_01
554 | scene0255_02
555 | scene0047_00
556 | scene0265_00
557 | scene0265_01
558 | scene0265_02
559 | scene0004_00
560 | scene0336_00
561 | scene0336_01
562 | scene0058_00
563 | scene0058_01
564 | scene0260_00
565 | scene0260_01
566 | scene0260_02
567 | scene0243_00
568 | scene0603_00
569 | scene0603_01
570 | scene0093_00
571 | scene0093_01
572 | scene0093_02
573 | scene0109_00
574 | scene0109_01
575 | scene0434_00
576 | scene0434_01
577 | scene0434_02
578 | scene0290_00
579 | scene0627_00
580 | scene0627_01
581 | scene0470_00
582 | scene0470_01
583 | scene0137_00
584 | scene0137_01
585 | scene0137_02
586 | scene0270_00
587 | scene0270_01
588 | scene0270_02
589 | scene0271_00
590 | scene0271_01
591 | scene0504_00
592 | scene0274_00
593 | scene0274_01
594 | scene0274_02
595 | scene0036_00
596 | scene0036_01
597 | scene0276_00
598 | scene0276_01
599 | scene0272_00
600 | scene0272_01
601 | scene0499_00
602 | scene0698_00
603 | scene0698_01
604 | scene0051_00
605 | scene0051_01
606 | scene0051_02
607 | scene0051_03
608 | scene0108_00
609 | scene0245_00
610 | scene0369_00
611 | scene0369_01
612 | scene0369_02
613 | scene0284_00
614 | scene0289_00
615 | scene0289_01
616 | scene0286_00
617 | scene0286_01
618 | scene0286_02
619 | scene0286_03
620 | scene0031_00
621 | scene0031_01
622 | scene0031_02
623 | scene0545_00
624 | scene0545_01
625 | scene0545_02
626 | scene0557_00
627 | scene0557_01
628 | scene0557_02
629 | scene0533_00
630 | scene0533_01
631 | scene0116_00
632 | scene0116_01
633 | scene0116_02
634 | scene0611_00
635 | scene0611_01
636 | scene0688_00
637 | scene0294_00
638 | scene0294_01
639 | scene0294_02
640 | scene0295_00
641 | scene0295_01
642 | scene0296_00
643 | scene0296_01
644 | scene0596_00
645 | scene0596_01
646 | scene0596_02
647 | scene0532_00
648 | scene0532_01
649 | scene0637_00
650 | scene0638_00
651 | scene0121_00
652 | scene0121_01
653 | scene0121_02
654 | scene0040_00
655 | scene0040_01
656 | scene0197_00
657 | scene0197_01
658 | scene0197_02
659 | scene0410_00
660 | scene0410_01
661 | scene0305_00
662 | scene0305_01
663 | scene0615_00
664 | scene0615_01
665 | scene0703_00
666 | scene0703_01
667 | scene0555_00
668 | scene0297_00
669 | scene0297_01
670 | scene0297_02
671 | scene0582_00
672 | scene0582_01
673 | scene0582_02
674 | scene0023_00
675 | scene0094_00
676 | scene0013_00
677 | scene0013_01
678 | scene0013_02
679 | scene0136_00
680 | scene0136_01
681 | scene0136_02
682 | scene0407_00
683 | scene0407_01
684 | scene0062_00
685 | scene0062_01
686 | scene0062_02
687 | scene0386_00
688 | scene0318_00
689 | scene0554_00
690 | scene0554_01
691 | scene0497_00
692 | scene0213_00
693 | scene0258_00
694 | scene0323_00
695 | scene0323_01
696 | scene0324_00
697 | scene0324_01
698 | scene0016_00
699 | scene0016_01
700 | scene0016_02
701 | scene0681_00
702 | scene0398_00
703 | scene0398_01
704 | scene0227_00
705 | scene0090_00
706 | scene0066_00
707 | scene0262_00
708 | scene0262_01
709 | scene0155_00
710 | scene0155_01
711 | scene0155_02
712 | scene0352_00
713 | scene0352_01
714 | scene0352_02
715 | scene0038_00
716 | scene0038_01
717 | scene0038_02
718 | scene0335_00
719 | scene0335_01
720 | scene0335_02
721 | scene0261_00
722 | scene0261_01
723 | scene0261_02
724 | scene0261_03
725 | scene0640_00
726 | scene0640_01
727 | scene0640_02
728 | scene0080_00
729 | scene0080_01
730 | scene0080_02
731 | scene0403_00
732 | scene0403_01
733 | scene0282_00
734 | scene0282_01
735 | scene0282_02
736 | scene0682_00
737 | scene0173_00
738 | scene0173_01
739 | scene0173_02
740 | scene0522_00
741 | scene0687_00
742 | scene0345_00
743 | scene0345_01
744 | scene0612_00
745 | scene0612_01
746 | scene0411_00
747 | scene0411_01
748 | scene0411_02
749 | scene0625_00
750 | scene0625_01
751 | scene0211_00
752 | scene0211_01
753 | scene0211_02
754 | scene0211_03
755 | scene0676_00
756 | scene0676_01
757 | scene0179_00
758 | scene0498_00
759 | scene0498_01
760 | scene0498_02
761 | scene0547_00
762 | scene0547_01
763 | scene0547_02
764 | scene0269_00
765 | scene0269_01
766 | scene0269_02
767 | scene0366_00
768 | scene0680_00
769 | scene0680_01
770 | scene0588_00
771 | scene0588_01
772 | scene0588_02
773 | scene0588_03
774 | scene0346_00
775 | scene0346_01
776 | scene0359_00
777 | scene0359_01
778 | scene0014_00
779 | scene0120_00
780 | scene0120_01
781 | scene0212_00
782 | scene0212_01
783 | scene0212_02
784 | scene0176_00
785 | scene0049_00
786 | scene0259_00
787 | scene0259_01
788 | scene0586_00
789 | scene0586_01
790 | scene0586_02
791 | scene0309_00
792 | scene0309_01
793 | scene0125_00
794 | scene0455_00
795 | scene0177_00
796 | scene0177_01
797 | scene0177_02
798 | scene0326_00
799 | scene0372_00
800 | scene0171_00
801 | scene0171_01
802 | scene0374_00
803 | scene0654_00
804 | scene0654_01
805 | scene0445_00
806 | scene0445_01
807 | scene0475_00
808 | scene0475_01
809 | scene0475_02
810 | scene0349_00
811 | scene0349_01
812 | scene0234_00
813 | scene0669_00
814 | scene0669_01
815 | scene0375_00
816 | scene0375_01
817 | scene0375_02
818 | scene0387_00
819 | scene0387_01
820 | scene0387_02
821 | scene0312_00
822 | scene0312_01
823 | scene0312_02
824 | scene0384_00
825 | scene0385_00
826 | scene0385_01
827 | scene0385_02
828 | scene0000_00
829 | scene0000_01
830 | scene0000_02
831 | scene0376_00
832 | scene0376_01
833 | scene0376_02
834 | scene0301_00
835 | scene0301_01
836 | scene0301_02
837 | scene0322_00
838 | scene0542_00
839 | scene0079_00
840 | scene0079_01
841 | scene0099_00
842 | scene0099_01
843 | scene0476_00
844 | scene0476_01
845 | scene0476_02
846 | scene0394_00
847 | scene0394_01
848 | scene0147_00
849 | scene0147_01
850 | scene0067_00
851 | scene0067_01
852 | scene0067_02
853 | scene0397_00
854 | scene0397_01
855 | scene0337_00
856 | scene0337_01
857 | scene0337_02
858 | scene0431_00
859 | scene0223_00
860 | scene0223_01
861 | scene0223_02
862 | scene0010_00
863 | scene0010_01
864 | scene0402_00
865 | scene0268_00
866 | scene0268_01
867 | scene0268_02
868 | scene0679_00
869 | scene0679_01
870 | scene0405_00
871 | scene0128_00
872 | scene0408_00
873 | scene0408_01
874 | scene0190_00
875 | scene0107_00
876 | scene0076_00
877 | scene0167_00
878 | scene0361_00
879 | scene0361_01
880 | scene0361_02
881 | scene0216_00
882 | scene0202_00
883 | scene0303_00
884 | scene0303_01
885 | scene0303_02
886 | scene0446_00
887 | scene0446_01
888 | scene0089_00
889 | scene0089_01
890 | scene0089_02
891 | scene0360_00
892 | scene0150_00
893 | scene0150_01
894 | scene0150_02
895 | scene0421_00
896 | scene0421_01
897 | scene0421_02
898 | scene0454_00
899 | scene0626_00
900 | scene0626_01
901 | scene0626_02
902 | scene0186_00
903 | scene0186_01
904 | scene0538_00
905 | scene0479_00
906 | scene0479_01
907 | scene0479_02
908 | scene0656_00
909 | scene0656_01
910 | scene0656_02
911 | scene0656_03
912 | scene0525_00
913 | scene0525_01
914 | scene0525_02
915 | scene0308_00
916 | scene0396_00
917 | scene0396_01
918 | scene0396_02
919 | scene0624_00
920 | scene0292_00
921 | scene0292_01
922 | scene0632_00
923 | scene0253_00
924 | scene0021_00
925 | scene0325_00
926 | scene0325_01
927 | scene0437_00
928 | scene0437_01
929 | scene0438_00
930 | scene0590_00
931 | scene0590_01
932 | scene0400_00
933 | scene0400_01
934 | scene0541_00
935 | scene0541_01
936 | scene0541_02
937 | scene0677_00
938 | scene0677_01
939 | scene0677_02
940 | scene0443_00
941 | scene0315_00
942 | scene0288_00
943 | scene0288_01
944 | scene0288_02
945 | scene0422_00
946 | scene0672_00
947 | scene0672_01
948 | scene0184_00
949 | scene0449_00
950 | scene0449_01
951 | scene0449_02
952 | scene0048_00
953 | scene0048_01
954 | scene0138_00
955 | scene0452_00
956 | scene0452_01
957 | scene0452_02
958 | scene0667_00
959 | scene0667_01
960 | scene0667_02
961 | scene0463_00
962 | scene0463_01
963 | scene0078_00
964 | scene0078_01
965 | scene0078_02
966 | scene0636_00
967 | scene0457_00
968 | scene0457_01
969 | scene0457_02
970 | scene0465_00
971 | scene0465_01
972 | scene0577_00
973 | scene0151_00
974 | scene0151_01
975 | scene0339_00
976 | scene0573_00
977 | scene0573_01
978 | scene0154_00
979 | scene0096_00
980 | scene0096_01
981 | scene0096_02
982 | scene0235_00
983 | scene0168_00
984 | scene0168_01
985 | scene0168_02
986 | scene0594_00
987 | scene0587_00
988 | scene0587_01
989 | scene0587_02
990 | scene0587_03
991 | scene0229_00
992 | scene0229_01
993 | scene0229_02
994 | scene0512_00
995 | scene0106_00
996 | scene0106_01
997 | scene0106_02
998 | scene0472_00
999 | scene0472_01
1000 | scene0472_02
1001 | scene0489_00
1002 | scene0489_01
1003 | scene0489_02
1004 | scene0425_00
1005 | scene0425_01
1006 | scene0641_00
1007 | scene0526_00
1008 | scene0526_01
1009 | scene0317_00
1010 | scene0317_01
1011 | scene0544_00
1012 | scene0017_00
1013 | scene0017_01
1014 | scene0017_02
1015 | scene0042_00
1016 | scene0042_01
1017 | scene0042_02
1018 | scene0576_00
1019 | scene0576_01
1020 | scene0576_02
1021 | scene0347_00
1022 | scene0347_01
1023 | scene0347_02
1024 | scene0436_00
1025 | scene0226_00
1026 | scene0226_01
1027 | scene0485_00
1028 | scene0486_00
1029 | scene0487_00
1030 | scene0487_01
1031 | scene0619_00
1032 | scene0097_00
1033 | scene0367_00
1034 | scene0367_01
1035 | scene0491_00
1036 | scene0492_00
1037 | scene0492_01
1038 | scene0005_00
1039 | scene0005_01
1040 | scene0543_00
1041 | scene0543_01
1042 | scene0543_02
1043 | scene0657_00
1044 | scene0341_00
1045 | scene0341_01
1046 | scene0534_00
1047 | scene0534_01
1048 | scene0319_00
1049 | scene0273_00
1050 | scene0273_01
1051 | scene0225_00
1052 | scene0198_00
1053 | scene0003_00
1054 | scene0003_01
1055 | scene0003_02
1056 | scene0409_00
1057 | scene0409_01
1058 | scene0331_00
1059 | scene0331_01
1060 | scene0505_00
1061 | scene0505_01
1062 | scene0505_02
1063 | scene0505_03
1064 | scene0505_04
1065 | scene0506_00
1066 | scene0057_00
1067 | scene0057_01
1068 | scene0074_00
1069 | scene0074_01
1070 | scene0074_02
1071 | scene0091_00
1072 | scene0112_00
1073 | scene0112_01
1074 | scene0112_02
1075 | scene0240_00
1076 | scene0102_00
1077 | scene0102_01
1078 | scene0513_00
1079 | scene0514_00
1080 | scene0514_01
1081 | scene0537_00
1082 | scene0516_00
1083 | scene0516_01
1084 | scene0495_00
1085 | scene0617_00
1086 | scene0133_00
1087 | scene0520_00
1088 | scene0520_01
1089 | scene0635_00
1090 | scene0635_01
1091 | scene0054_00
1092 | scene0473_00
1093 | scene0473_01
1094 | scene0524_00
1095 | scene0524_01
1096 | scene0379_00
1097 | scene0471_00
1098 | scene0471_01
1099 | scene0471_02
1100 | scene0566_00
1101 | scene0248_00
1102 | scene0248_01
1103 | scene0248_02
1104 | scene0529_00
1105 | scene0529_01
1106 | scene0529_02
1107 | scene0391_00
1108 | scene0264_00
1109 | scene0264_01
1110 | scene0264_02
1111 | scene0675_00
1112 | scene0675_01
1113 | scene0350_00
1114 | scene0350_01
1115 | scene0350_02
1116 | scene0450_00
1117 | scene0068_00
1118 | scene0068_01
1119 | scene0237_00
1120 | scene0237_01
1121 | scene0365_00
1122 | scene0365_01
1123 | scene0365_02
1124 | scene0605_00
1125 | scene0605_01
1126 | scene0539_00
1127 | scene0539_01
1128 | scene0539_02
1129 | scene0540_00
1130 | scene0540_01
1131 | scene0540_02
1132 | scene0170_00
1133 | scene0170_01
1134 | scene0170_02
1135 | scene0433_00
1136 | scene0340_00
1137 | scene0340_01
1138 | scene0340_02
1139 | scene0160_00
1140 | scene0160_01
1141 | scene0160_02
1142 | scene0160_03
1143 | scene0160_04
1144 | scene0059_00
1145 | scene0059_01
1146 | scene0059_02
1147 | scene0056_00
1148 | scene0056_01
1149 | scene0478_00
1150 | scene0478_01
1151 | scene0548_00
1152 | scene0548_01
1153 | scene0548_02
1154 | scene0204_00
1155 | scene0204_01
1156 | scene0204_02
1157 | scene0033_00
1158 | scene0145_00
1159 | scene0483_00
1160 | scene0508_00
1161 | scene0508_01
1162 | scene0508_02
1163 | scene0180_00
1164 | scene0148_00
1165 | scene0556_00
1166 | scene0556_01
1167 | scene0416_00
1168 | scene0416_01
1169 | scene0416_02
1170 | scene0416_03
1171 | scene0416_04
1172 | scene0073_00
1173 | scene0073_01
1174 | scene0073_02
1175 | scene0073_03
1176 | scene0034_00
1177 | scene0034_01
1178 | scene0034_02
1179 | scene0639_00
1180 | scene0561_00
1181 | scene0561_01
1182 | scene0298_00
1183 | scene0692_00
1184 | scene0692_01
1185 | scene0692_02
1186 | scene0692_03
1187 | scene0692_04
1188 | scene0642_00
1189 | scene0642_01
1190 | scene0642_02
1191 | scene0642_03
1192 | scene0630_00
1193 | scene0630_01
1194 | scene0630_02
1195 | scene0630_03
1196 | scene0630_04
1197 | scene0630_05
1198 | scene0630_06
1199 | scene0706_00
1200 | scene0567_00
1201 | scene0567_01
1202 |
--------------------------------------------------------------------------------
/src/data/meta_data/scannet/scannetv2_trainval.txt:
--------------------------------------------------------------------------------
1 | scene0000_00
2 | scene0000_01
3 | scene0000_02
4 | scene0001_00
5 | scene0001_01
6 | scene0002_00
7 | scene0002_01
8 | scene0003_00
9 | scene0003_01
10 | scene0003_02
11 | scene0004_00
12 | scene0005_00
13 | scene0005_01
14 | scene0006_00
15 | scene0006_01
16 | scene0006_02
17 | scene0007_00
18 | scene0008_00
19 | scene0009_00
20 | scene0009_01
21 | scene0009_02
22 | scene0010_00
23 | scene0010_01
24 | scene0011_00
25 | scene0011_01
26 | scene0012_00
27 | scene0012_01
28 | scene0012_02
29 | scene0013_00
30 | scene0013_01
31 | scene0013_02
32 | scene0014_00
33 | scene0015_00
34 | scene0016_00
35 | scene0016_01
36 | scene0016_02
37 | scene0017_00
38 | scene0017_01
39 | scene0017_02
40 | scene0018_00
41 | scene0019_00
42 | scene0019_01
43 | scene0020_00
44 | scene0020_01
45 | scene0021_00
46 | scene0022_00
47 | scene0022_01
48 | scene0023_00
49 | scene0024_00
50 | scene0024_01
51 | scene0024_02
52 | scene0025_00
53 | scene0025_01
54 | scene0025_02
55 | scene0026_00
56 | scene0027_00
57 | scene0027_01
58 | scene0027_02
59 | scene0028_00
60 | scene0029_00
61 | scene0029_01
62 | scene0029_02
63 | scene0030_00
64 | scene0030_01
65 | scene0030_02
66 | scene0031_00
67 | scene0031_01
68 | scene0031_02
69 | scene0032_00
70 | scene0032_01
71 | scene0033_00
72 | scene0034_00
73 | scene0034_01
74 | scene0034_02
75 | scene0035_00
76 | scene0035_01
77 | scene0036_00
78 | scene0036_01
79 | scene0037_00
80 | scene0038_00
81 | scene0038_01
82 | scene0038_02
83 | scene0039_00
84 | scene0039_01
85 | scene0040_00
86 | scene0040_01
87 | scene0041_00
88 | scene0041_01
89 | scene0042_00
90 | scene0042_01
91 | scene0042_02
92 | scene0043_00
93 | scene0043_01
94 | scene0044_00
95 | scene0044_01
96 | scene0044_02
97 | scene0045_00
98 | scene0045_01
99 | scene0046_00
100 | scene0046_01
101 | scene0046_02
102 | scene0047_00
103 | scene0048_00
104 | scene0048_01
105 | scene0049_00
106 | scene0050_00
107 | scene0050_01
108 | scene0050_02
109 | scene0051_00
110 | scene0051_01
111 | scene0051_02
112 | scene0051_03
113 | scene0052_00
114 | scene0052_01
115 | scene0052_02
116 | scene0053_00
117 | scene0054_00
118 | scene0055_00
119 | scene0055_01
120 | scene0055_02
121 | scene0056_00
122 | scene0056_01
123 | scene0057_00
124 | scene0057_01
125 | scene0058_00
126 | scene0058_01
127 | scene0059_00
128 | scene0059_01
129 | scene0059_02
130 | scene0060_00
131 | scene0060_01
132 | scene0061_00
133 | scene0061_01
134 | scene0062_00
135 | scene0062_01
136 | scene0062_02
137 | scene0063_00
138 | scene0064_00
139 | scene0064_01
140 | scene0065_00
141 | scene0065_01
142 | scene0065_02
143 | scene0066_00
144 | scene0067_00
145 | scene0067_01
146 | scene0067_02
147 | scene0068_00
148 | scene0068_01
149 | scene0069_00
150 | scene0070_00
151 | scene0071_00
152 | scene0072_00
153 | scene0072_01
154 | scene0072_02
155 | scene0073_00
156 | scene0073_01
157 | scene0073_02
158 | scene0073_03
159 | scene0074_00
160 | scene0074_01
161 | scene0074_02
162 | scene0075_00
163 | scene0076_00
164 | scene0077_00
165 | scene0077_01
166 | scene0078_00
167 | scene0078_01
168 | scene0078_02
169 | scene0079_00
170 | scene0079_01
171 | scene0080_00
172 | scene0080_01
173 | scene0080_02
174 | scene0081_00
175 | scene0081_01
176 | scene0081_02
177 | scene0082_00
178 | scene0083_00
179 | scene0083_01
180 | scene0084_00
181 | scene0084_01
182 | scene0084_02
183 | scene0085_00
184 | scene0085_01
185 | scene0086_00
186 | scene0086_01
187 | scene0086_02
188 | scene0087_00
189 | scene0087_01
190 | scene0087_02
191 | scene0088_00
192 | scene0088_01
193 | scene0088_02
194 | scene0088_03
195 | scene0089_00
196 | scene0089_01
197 | scene0089_02
198 | scene0090_00
199 | scene0091_00
200 | scene0092_00
201 | scene0092_01
202 | scene0092_02
203 | scene0092_03
204 | scene0092_04
205 | scene0093_00
206 | scene0093_01
207 | scene0093_02
208 | scene0094_00
209 | scene0095_00
210 | scene0095_01
211 | scene0096_00
212 | scene0096_01
213 | scene0096_02
214 | scene0097_00
215 | scene0098_00
216 | scene0098_01
217 | scene0099_00
218 | scene0099_01
219 | scene0100_00
220 | scene0100_01
221 | scene0100_02
222 | scene0101_00
223 | scene0101_01
224 | scene0101_02
225 | scene0101_03
226 | scene0101_04
227 | scene0101_05
228 | scene0102_00
229 | scene0102_01
230 | scene0103_00
231 | scene0103_01
232 | scene0104_00
233 | scene0105_00
234 | scene0105_01
235 | scene0105_02
236 | scene0106_00
237 | scene0106_01
238 | scene0106_02
239 | scene0107_00
240 | scene0108_00
241 | scene0109_00
242 | scene0109_01
243 | scene0110_00
244 | scene0110_01
245 | scene0110_02
246 | scene0111_00
247 | scene0111_01
248 | scene0111_02
249 | scene0112_00
250 | scene0112_01
251 | scene0112_02
252 | scene0113_00
253 | scene0113_01
254 | scene0114_00
255 | scene0114_01
256 | scene0114_02
257 | scene0115_00
258 | scene0115_01
259 | scene0115_02
260 | scene0116_00
261 | scene0116_01
262 | scene0116_02
263 | scene0117_00
264 | scene0118_00
265 | scene0118_01
266 | scene0118_02
267 | scene0119_00
268 | scene0120_00
269 | scene0120_01
270 | scene0121_00
271 | scene0121_01
272 | scene0121_02
273 | scene0122_00
274 | scene0122_01
275 | scene0123_00
276 | scene0123_01
277 | scene0123_02
278 | scene0124_00
279 | scene0124_01
280 | scene0125_00
281 | scene0126_00
282 | scene0126_01
283 | scene0126_02
284 | scene0127_00
285 | scene0127_01
286 | scene0128_00
287 | scene0129_00
288 | scene0130_00
289 | scene0131_00
290 | scene0131_01
291 | scene0131_02
292 | scene0132_00
293 | scene0132_01
294 | scene0132_02
295 | scene0133_00
296 | scene0134_00
297 | scene0134_01
298 | scene0134_02
299 | scene0135_00
300 | scene0136_00
301 | scene0136_01
302 | scene0136_02
303 | scene0137_00
304 | scene0137_01
305 | scene0137_02
306 | scene0138_00
307 | scene0139_00
308 | scene0140_00
309 | scene0140_01
310 | scene0141_00
311 | scene0141_01
312 | scene0141_02
313 | scene0142_00
314 | scene0142_01
315 | scene0143_00
316 | scene0143_01
317 | scene0143_02
318 | scene0144_00
319 | scene0144_01
320 | scene0145_00
321 | scene0146_00
322 | scene0146_01
323 | scene0146_02
324 | scene0147_00
325 | scene0147_01
326 | scene0148_00
327 | scene0149_00
328 | scene0150_00
329 | scene0150_01
330 | scene0150_02
331 | scene0151_00
332 | scene0151_01
333 | scene0152_00
334 | scene0152_01
335 | scene0152_02
336 | scene0153_00
337 | scene0153_01
338 | scene0154_00
339 | scene0155_00
340 | scene0155_01
341 | scene0155_02
342 | scene0156_00
343 | scene0157_00
344 | scene0157_01
345 | scene0158_00
346 | scene0158_01
347 | scene0158_02
348 | scene0159_00
349 | scene0160_00
350 | scene0160_01
351 | scene0160_02
352 | scene0160_03
353 | scene0160_04
354 | scene0161_00
355 | scene0161_01
356 | scene0161_02
357 | scene0162_00
358 | scene0163_00
359 | scene0163_01
360 | scene0164_00
361 | scene0164_01
362 | scene0164_02
363 | scene0164_03
364 | scene0165_00
365 | scene0165_01
366 | scene0165_02
367 | scene0166_00
368 | scene0166_01
369 | scene0166_02
370 | scene0167_00
371 | scene0168_00
372 | scene0168_01
373 | scene0168_02
374 | scene0169_00
375 | scene0169_01
376 | scene0170_00
377 | scene0170_01
378 | scene0170_02
379 | scene0171_00
380 | scene0171_01
381 | scene0172_00
382 | scene0172_01
383 | scene0173_00
384 | scene0173_01
385 | scene0173_02
386 | scene0174_00
387 | scene0174_01
388 | scene0175_00
389 | scene0176_00
390 | scene0177_00
391 | scene0177_01
392 | scene0177_02
393 | scene0178_00
394 | scene0179_00
395 | scene0180_00
396 | scene0181_00
397 | scene0181_01
398 | scene0181_02
399 | scene0181_03
400 | scene0182_00
401 | scene0182_01
402 | scene0182_02
403 | scene0183_00
404 | scene0184_00
405 | scene0185_00
406 | scene0186_00
407 | scene0186_01
408 | scene0187_00
409 | scene0187_01
410 | scene0188_00
411 | scene0189_00
412 | scene0190_00
413 | scene0191_00
414 | scene0191_01
415 | scene0191_02
416 | scene0192_00
417 | scene0192_01
418 | scene0192_02
419 | scene0193_00
420 | scene0193_01
421 | scene0194_00
422 | scene0195_00
423 | scene0195_01
424 | scene0195_02
425 | scene0196_00
426 | scene0197_00
427 | scene0197_01
428 | scene0197_02
429 | scene0198_00
430 | scene0199_00
431 | scene0200_00
432 | scene0200_01
433 | scene0200_02
434 | scene0201_00
435 | scene0201_01
436 | scene0201_02
437 | scene0202_00
438 | scene0203_00
439 | scene0203_01
440 | scene0203_02
441 | scene0204_00
442 | scene0204_01
443 | scene0204_02
444 | scene0205_00
445 | scene0205_01
446 | scene0205_02
447 | scene0206_00
448 | scene0206_01
449 | scene0206_02
450 | scene0207_00
451 | scene0207_01
452 | scene0207_02
453 | scene0208_00
454 | scene0209_00
455 | scene0209_01
456 | scene0209_02
457 | scene0210_00
458 | scene0210_01
459 | scene0211_00
460 | scene0211_01
461 | scene0211_02
462 | scene0211_03
463 | scene0212_00
464 | scene0212_01
465 | scene0212_02
466 | scene0213_00
467 | scene0214_00
468 | scene0214_01
469 | scene0214_02
470 | scene0215_00
471 | scene0215_01
472 | scene0216_00
473 | scene0217_00
474 | scene0218_00
475 | scene0218_01
476 | scene0219_00
477 | scene0220_00
478 | scene0220_01
479 | scene0220_02
480 | scene0221_00
481 | scene0221_01
482 | scene0222_00
483 | scene0222_01
484 | scene0223_00
485 | scene0223_01
486 | scene0223_02
487 | scene0224_00
488 | scene0225_00
489 | scene0226_00
490 | scene0226_01
491 | scene0227_00
492 | scene0228_00
493 | scene0229_00
494 | scene0229_01
495 | scene0229_02
496 | scene0230_00
497 | scene0231_00
498 | scene0231_01
499 | scene0231_02
500 | scene0232_00
501 | scene0232_01
502 | scene0232_02
503 | scene0233_00
504 | scene0233_01
505 | scene0234_00
506 | scene0235_00
507 | scene0236_00
508 | scene0236_01
509 | scene0237_00
510 | scene0237_01
511 | scene0238_00
512 | scene0238_01
513 | scene0239_00
514 | scene0239_01
515 | scene0239_02
516 | scene0240_00
517 | scene0241_00
518 | scene0241_01
519 | scene0241_02
520 | scene0242_00
521 | scene0242_01
522 | scene0242_02
523 | scene0243_00
524 | scene0244_00
525 | scene0244_01
526 | scene0245_00
527 | scene0246_00
528 | scene0247_00
529 | scene0247_01
530 | scene0248_00
531 | scene0248_01
532 | scene0248_02
533 | scene0249_00
534 | scene0250_00
535 | scene0250_01
536 | scene0250_02
537 | scene0251_00
538 | scene0252_00
539 | scene0253_00
540 | scene0254_00
541 | scene0254_01
542 | scene0255_00
543 | scene0255_01
544 | scene0255_02
545 | scene0256_00
546 | scene0256_01
547 | scene0256_02
548 | scene0257_00
549 | scene0258_00
550 | scene0259_00
551 | scene0259_01
552 | scene0260_00
553 | scene0260_01
554 | scene0260_02
555 | scene0261_00
556 | scene0261_01
557 | scene0261_02
558 | scene0261_03
559 | scene0262_00
560 | scene0262_01
561 | scene0263_00
562 | scene0263_01
563 | scene0264_00
564 | scene0264_01
565 | scene0264_02
566 | scene0265_00
567 | scene0265_01
568 | scene0265_02
569 | scene0266_00
570 | scene0266_01
571 | scene0267_00
572 | scene0268_00
573 | scene0268_01
574 | scene0268_02
575 | scene0269_00
576 | scene0269_01
577 | scene0269_02
578 | scene0270_00
579 | scene0270_01
580 | scene0270_02
581 | scene0271_00
582 | scene0271_01
583 | scene0272_00
584 | scene0272_01
585 | scene0273_00
586 | scene0273_01
587 | scene0274_00
588 | scene0274_01
589 | scene0274_02
590 | scene0275_00
591 | scene0276_00
592 | scene0276_01
593 | scene0277_00
594 | scene0277_01
595 | scene0277_02
596 | scene0278_00
597 | scene0278_01
598 | scene0279_00
599 | scene0279_01
600 | scene0279_02
601 | scene0280_00
602 | scene0280_01
603 | scene0280_02
604 | scene0281_00
605 | scene0282_00
606 | scene0282_01
607 | scene0282_02
608 | scene0283_00
609 | scene0284_00
610 | scene0285_00
611 | scene0286_00
612 | scene0286_01
613 | scene0286_02
614 | scene0286_03
615 | scene0287_00
616 | scene0288_00
617 | scene0288_01
618 | scene0288_02
619 | scene0289_00
620 | scene0289_01
621 | scene0290_00
622 | scene0291_00
623 | scene0291_01
624 | scene0291_02
625 | scene0292_00
626 | scene0292_01
627 | scene0293_00
628 | scene0293_01
629 | scene0294_00
630 | scene0294_01
631 | scene0294_02
632 | scene0295_00
633 | scene0295_01
634 | scene0296_00
635 | scene0296_01
636 | scene0297_00
637 | scene0297_01
638 | scene0297_02
639 | scene0298_00
640 | scene0299_00
641 | scene0299_01
642 | scene0300_00
643 | scene0300_01
644 | scene0301_00
645 | scene0301_01
646 | scene0301_02
647 | scene0302_00
648 | scene0302_01
649 | scene0303_00
650 | scene0303_01
651 | scene0303_02
652 | scene0304_00
653 | scene0305_00
654 | scene0305_01
655 | scene0306_00
656 | scene0306_01
657 | scene0307_00
658 | scene0307_01
659 | scene0307_02
660 | scene0308_00
661 | scene0309_00
662 | scene0309_01
663 | scene0310_00
664 | scene0310_01
665 | scene0310_02
666 | scene0311_00
667 | scene0312_00
668 | scene0312_01
669 | scene0312_02
670 | scene0313_00
671 | scene0313_01
672 | scene0313_02
673 | scene0314_00
674 | scene0315_00
675 | scene0316_00
676 | scene0317_00
677 | scene0317_01
678 | scene0318_00
679 | scene0319_00
680 | scene0320_00
681 | scene0320_01
682 | scene0320_02
683 | scene0320_03
684 | scene0321_00
685 | scene0322_00
686 | scene0323_00
687 | scene0323_01
688 | scene0324_00
689 | scene0324_01
690 | scene0325_00
691 | scene0325_01
692 | scene0326_00
693 | scene0327_00
694 | scene0328_00
695 | scene0329_00
696 | scene0329_01
697 | scene0329_02
698 | scene0330_00
699 | scene0331_00
700 | scene0331_01
701 | scene0332_00
702 | scene0332_01
703 | scene0332_02
704 | scene0333_00
705 | scene0334_00
706 | scene0334_01
707 | scene0334_02
708 | scene0335_00
709 | scene0335_01
710 | scene0335_02
711 | scene0336_00
712 | scene0336_01
713 | scene0337_00
714 | scene0337_01
715 | scene0337_02
716 | scene0338_00
717 | scene0338_01
718 | scene0338_02
719 | scene0339_00
720 | scene0340_00
721 | scene0340_01
722 | scene0340_02
723 | scene0341_00
724 | scene0341_01
725 | scene0342_00
726 | scene0343_00
727 | scene0344_00
728 | scene0344_01
729 | scene0345_00
730 | scene0345_01
731 | scene0346_00
732 | scene0346_01
733 | scene0347_00
734 | scene0347_01
735 | scene0347_02
736 | scene0348_00
737 | scene0348_01
738 | scene0348_02
739 | scene0349_00
740 | scene0349_01
741 | scene0350_00
742 | scene0350_01
743 | scene0350_02
744 | scene0351_00
745 | scene0351_01
746 | scene0352_00
747 | scene0352_01
748 | scene0352_02
749 | scene0353_00
750 | scene0353_01
751 | scene0353_02
752 | scene0354_00
753 | scene0355_00
754 | scene0355_01
755 | scene0356_00
756 | scene0356_01
757 | scene0356_02
758 | scene0357_00
759 | scene0357_01
760 | scene0358_00
761 | scene0358_01
762 | scene0358_02
763 | scene0359_00
764 | scene0359_01
765 | scene0360_00
766 | scene0361_00
767 | scene0361_01
768 | scene0361_02
769 | scene0362_00
770 | scene0362_01
771 | scene0362_02
772 | scene0362_03
773 | scene0363_00
774 | scene0364_00
775 | scene0364_01
776 | scene0365_00
777 | scene0365_01
778 | scene0365_02
779 | scene0366_00
780 | scene0367_00
781 | scene0367_01
782 | scene0368_00
783 | scene0368_01
784 | scene0369_00
785 | scene0369_01
786 | scene0369_02
787 | scene0370_00
788 | scene0370_01
789 | scene0370_02
790 | scene0371_00
791 | scene0371_01
792 | scene0372_00
793 | scene0373_00
794 | scene0373_01
795 | scene0374_00
796 | scene0375_00
797 | scene0375_01
798 | scene0375_02
799 | scene0376_00
800 | scene0376_01
801 | scene0376_02
802 | scene0377_00
803 | scene0377_01
804 | scene0377_02
805 | scene0378_00
806 | scene0378_01
807 | scene0378_02
808 | scene0379_00
809 | scene0380_00
810 | scene0380_01
811 | scene0380_02
812 | scene0381_00
813 | scene0381_01
814 | scene0381_02
815 | scene0382_00
816 | scene0382_01
817 | scene0383_00
818 | scene0383_01
819 | scene0383_02
820 | scene0384_00
821 | scene0385_00
822 | scene0385_01
823 | scene0385_02
824 | scene0386_00
825 | scene0387_00
826 | scene0387_01
827 | scene0387_02
828 | scene0388_00
829 | scene0388_01
830 | scene0389_00
831 | scene0390_00
832 | scene0391_00
833 | scene0392_00
834 | scene0392_01
835 | scene0392_02
836 | scene0393_00
837 | scene0393_01
838 | scene0393_02
839 | scene0394_00
840 | scene0394_01
841 | scene0395_00
842 | scene0395_01
843 | scene0395_02
844 | scene0396_00
845 | scene0396_01
846 | scene0396_02
847 | scene0397_00
848 | scene0397_01
849 | scene0398_00
850 | scene0398_01
851 | scene0399_00
852 | scene0399_01
853 | scene0400_00
854 | scene0400_01
855 | scene0401_00
856 | scene0402_00
857 | scene0403_00
858 | scene0403_01
859 | scene0404_00
860 | scene0404_01
861 | scene0404_02
862 | scene0405_00
863 | scene0406_00
864 | scene0406_01
865 | scene0406_02
866 | scene0407_00
867 | scene0407_01
868 | scene0408_00
869 | scene0408_01
870 | scene0409_00
871 | scene0409_01
872 | scene0410_00
873 | scene0410_01
874 | scene0411_00
875 | scene0411_01
876 | scene0411_02
877 | scene0412_00
878 | scene0412_01
879 | scene0413_00
880 | scene0414_00
881 | scene0415_00
882 | scene0415_01
883 | scene0415_02
884 | scene0416_00
885 | scene0416_01
886 | scene0416_02
887 | scene0416_03
888 | scene0416_04
889 | scene0417_00
890 | scene0418_00
891 | scene0418_01
892 | scene0418_02
893 | scene0419_00
894 | scene0419_01
895 | scene0419_02
896 | scene0420_00
897 | scene0420_01
898 | scene0420_02
899 | scene0421_00
900 | scene0421_01
901 | scene0421_02
902 | scene0422_00
903 | scene0423_00
904 | scene0423_01
905 | scene0423_02
906 | scene0424_00
907 | scene0424_01
908 | scene0424_02
909 | scene0425_00
910 | scene0425_01
911 | scene0426_00
912 | scene0426_01
913 | scene0426_02
914 | scene0426_03
915 | scene0427_00
916 | scene0428_00
917 | scene0428_01
918 | scene0429_00
919 | scene0430_00
920 | scene0430_01
921 | scene0431_00
922 | scene0432_00
923 | scene0432_01
924 | scene0433_00
925 | scene0434_00
926 | scene0434_01
927 | scene0434_02
928 | scene0435_00
929 | scene0435_01
930 | scene0435_02
931 | scene0435_03
932 | scene0436_00
933 | scene0437_00
934 | scene0437_01
935 | scene0438_00
936 | scene0439_00
937 | scene0439_01
938 | scene0440_00
939 | scene0440_01
940 | scene0440_02
941 | scene0441_00
942 | scene0442_00
943 | scene0443_00
944 | scene0444_00
945 | scene0444_01
946 | scene0445_00
947 | scene0445_01
948 | scene0446_00
949 | scene0446_01
950 | scene0447_00
951 | scene0447_01
952 | scene0447_02
953 | scene0448_00
954 | scene0448_01
955 | scene0448_02
956 | scene0449_00
957 | scene0449_01
958 | scene0449_02
959 | scene0450_00
960 | scene0451_00
961 | scene0451_01
962 | scene0451_02
963 | scene0451_03
964 | scene0451_04
965 | scene0451_05
966 | scene0452_00
967 | scene0452_01
968 | scene0452_02
969 | scene0453_00
970 | scene0453_01
971 | scene0454_00
972 | scene0455_00
973 | scene0456_00
974 | scene0456_01
975 | scene0457_00
976 | scene0457_01
977 | scene0457_02
978 | scene0458_00
979 | scene0458_01
980 | scene0459_00
981 | scene0459_01
982 | scene0460_00
983 | scene0461_00
984 | scene0462_00
985 | scene0463_00
986 | scene0463_01
987 | scene0464_00
988 | scene0465_00
989 | scene0465_01
990 | scene0466_00
991 | scene0466_01
992 | scene0467_00
993 | scene0468_00
994 | scene0468_01
995 | scene0468_02
996 | scene0469_00
997 | scene0469_01
998 | scene0469_02
999 | scene0470_00
1000 | scene0470_01
1001 | scene0471_00
1002 | scene0471_01
1003 | scene0471_02
1004 | scene0472_00
1005 | scene0472_01
1006 | scene0472_02
1007 | scene0473_00
1008 | scene0473_01
1009 | scene0474_00
1010 | scene0474_01
1011 | scene0474_02
1012 | scene0474_03
1013 | scene0474_04
1014 | scene0474_05
1015 | scene0475_00
1016 | scene0475_01
1017 | scene0475_02
1018 | scene0476_00
1019 | scene0476_01
1020 | scene0476_02
1021 | scene0477_00
1022 | scene0477_01
1023 | scene0478_00
1024 | scene0478_01
1025 | scene0479_00
1026 | scene0479_01
1027 | scene0479_02
1028 | scene0480_00
1029 | scene0480_01
1030 | scene0481_00
1031 | scene0481_01
1032 | scene0482_00
1033 | scene0482_01
1034 | scene0483_00
1035 | scene0484_00
1036 | scene0484_01
1037 | scene0485_00
1038 | scene0486_00
1039 | scene0487_00
1040 | scene0487_01
1041 | scene0488_00
1042 | scene0488_01
1043 | scene0489_00
1044 | scene0489_01
1045 | scene0489_02
1046 | scene0490_00
1047 | scene0491_00
1048 | scene0492_00
1049 | scene0492_01
1050 | scene0493_00
1051 | scene0493_01
1052 | scene0494_00
1053 | scene0495_00
1054 | scene0496_00
1055 | scene0497_00
1056 | scene0498_00
1057 | scene0498_01
1058 | scene0498_02
1059 | scene0499_00
1060 | scene0500_00
1061 | scene0500_01
1062 | scene0501_00
1063 | scene0501_01
1064 | scene0501_02
1065 | scene0502_00
1066 | scene0502_01
1067 | scene0502_02
1068 | scene0503_00
1069 | scene0504_00
1070 | scene0505_00
1071 | scene0505_01
1072 | scene0505_02
1073 | scene0505_03
1074 | scene0505_04
1075 | scene0506_00
1076 | scene0507_00
1077 | scene0508_00
1078 | scene0508_01
1079 | scene0508_02
1080 | scene0509_00
1081 | scene0509_01
1082 | scene0509_02
1083 | scene0510_00
1084 | scene0510_01
1085 | scene0510_02
1086 | scene0511_00
1087 | scene0511_01
1088 | scene0512_00
1089 | scene0513_00
1090 | scene0514_00
1091 | scene0514_01
1092 | scene0515_00
1093 | scene0515_01
1094 | scene0515_02
1095 | scene0516_00
1096 | scene0516_01
1097 | scene0517_00
1098 | scene0517_01
1099 | scene0517_02
1100 | scene0518_00
1101 | scene0519_00
1102 | scene0520_00
1103 | scene0520_01
1104 | scene0521_00
1105 | scene0522_00
1106 | scene0523_00
1107 | scene0523_01
1108 | scene0523_02
1109 | scene0524_00
1110 | scene0524_01
1111 | scene0525_00
1112 | scene0525_01
1113 | scene0525_02
1114 | scene0526_00
1115 | scene0526_01
1116 | scene0527_00
1117 | scene0528_00
1118 | scene0528_01
1119 | scene0529_00
1120 | scene0529_01
1121 | scene0529_02
1122 | scene0530_00
1123 | scene0531_00
1124 | scene0532_00
1125 | scene0532_01
1126 | scene0533_00
1127 | scene0533_01
1128 | scene0534_00
1129 | scene0534_01
1130 | scene0535_00
1131 | scene0536_00
1132 | scene0536_01
1133 | scene0536_02
1134 | scene0537_00
1135 | scene0538_00
1136 | scene0539_00
1137 | scene0539_01
1138 | scene0539_02
1139 | scene0540_00
1140 | scene0540_01
1141 | scene0540_02
1142 | scene0541_00
1143 | scene0541_01
1144 | scene0541_02
1145 | scene0542_00
1146 | scene0543_00
1147 | scene0543_01
1148 | scene0543_02
1149 | scene0544_00
1150 | scene0545_00
1151 | scene0545_01
1152 | scene0545_02
1153 | scene0546_00
1154 | scene0547_00
1155 | scene0547_01
1156 | scene0547_02
1157 | scene0548_00
1158 | scene0548_01
1159 | scene0548_02
1160 | scene0549_00
1161 | scene0549_01
1162 | scene0550_00
1163 | scene0551_00
1164 | scene0552_00
1165 | scene0552_01
1166 | scene0553_00
1167 | scene0553_01
1168 | scene0553_02
1169 | scene0554_00
1170 | scene0554_01
1171 | scene0555_00
1172 | scene0556_00
1173 | scene0556_01
1174 | scene0557_00
1175 | scene0557_01
1176 | scene0557_02
1177 | scene0558_00
1178 | scene0558_01
1179 | scene0558_02
1180 | scene0559_00
1181 | scene0559_01
1182 | scene0559_02
1183 | scene0560_00
1184 | scene0561_00
1185 | scene0561_01
1186 | scene0562_00
1187 | scene0563_00
1188 | scene0564_00
1189 | scene0565_00
1190 | scene0566_00
1191 | scene0567_00
1192 | scene0567_01
1193 | scene0568_00
1194 | scene0568_01
1195 | scene0568_02
1196 | scene0569_00
1197 | scene0569_01
1198 | scene0570_00
1199 | scene0570_01
1200 | scene0570_02
1201 | scene0571_00
1202 | scene0571_01
1203 | scene0572_00
1204 | scene0572_01
1205 | scene0572_02
1206 | scene0573_00
1207 | scene0573_01
1208 | scene0574_00
1209 | scene0574_01
1210 | scene0574_02
1211 | scene0575_00
1212 | scene0575_01
1213 | scene0575_02
1214 | scene0576_00
1215 | scene0576_01
1216 | scene0576_02
1217 | scene0577_00
1218 | scene0578_00
1219 | scene0578_01
1220 | scene0578_02
1221 | scene0579_00
1222 | scene0579_01
1223 | scene0579_02
1224 | scene0580_00
1225 | scene0580_01
1226 | scene0581_00
1227 | scene0581_01
1228 | scene0581_02
1229 | scene0582_00
1230 | scene0582_01
1231 | scene0582_02
1232 | scene0583_00
1233 | scene0583_01
1234 | scene0583_02
1235 | scene0584_00
1236 | scene0584_01
1237 | scene0584_02
1238 | scene0585_00
1239 | scene0585_01
1240 | scene0586_00
1241 | scene0586_01
1242 | scene0586_02
1243 | scene0587_00
1244 | scene0587_01
1245 | scene0587_02
1246 | scene0587_03
1247 | scene0588_00
1248 | scene0588_01
1249 | scene0588_02
1250 | scene0588_03
1251 | scene0589_00
1252 | scene0589_01
1253 | scene0589_02
1254 | scene0590_00
1255 | scene0590_01
1256 | scene0591_00
1257 | scene0591_01
1258 | scene0591_02
1259 | scene0592_00
1260 | scene0592_01
1261 | scene0593_00
1262 | scene0593_01
1263 | scene0594_00
1264 | scene0595_00
1265 | scene0596_00
1266 | scene0596_01
1267 | scene0596_02
1268 | scene0597_00
1269 | scene0597_01
1270 | scene0597_02
1271 | scene0598_00
1272 | scene0598_01
1273 | scene0598_02
1274 | scene0599_00
1275 | scene0599_01
1276 | scene0599_02
1277 | scene0600_00
1278 | scene0600_01
1279 | scene0600_02
1280 | scene0601_00
1281 | scene0601_01
1282 | scene0602_00
1283 | scene0603_00
1284 | scene0603_01
1285 | scene0604_00
1286 | scene0604_01
1287 | scene0604_02
1288 | scene0605_00
1289 | scene0605_01
1290 | scene0606_00
1291 | scene0606_01
1292 | scene0606_02
1293 | scene0607_00
1294 | scene0607_01
1295 | scene0608_00
1296 | scene0608_01
1297 | scene0608_02
1298 | scene0609_00
1299 | scene0609_01
1300 | scene0609_02
1301 | scene0609_03
1302 | scene0610_00
1303 | scene0610_01
1304 | scene0610_02
1305 | scene0611_00
1306 | scene0611_01
1307 | scene0612_00
1308 | scene0612_01
1309 | scene0613_00
1310 | scene0613_01
1311 | scene0613_02
1312 | scene0614_00
1313 | scene0614_01
1314 | scene0614_02
1315 | scene0615_00
1316 | scene0615_01
1317 | scene0616_00
1318 | scene0616_01
1319 | scene0617_00
1320 | scene0618_00
1321 | scene0619_00
1322 | scene0620_00
1323 | scene0620_01
1324 | scene0621_00
1325 | scene0622_00
1326 | scene0622_01
1327 | scene0623_00
1328 | scene0623_01
1329 | scene0624_00
1330 | scene0625_00
1331 | scene0625_01
1332 | scene0626_00
1333 | scene0626_01
1334 | scene0626_02
1335 | scene0627_00
1336 | scene0627_01
1337 | scene0628_00
1338 | scene0628_01
1339 | scene0628_02
1340 | scene0629_00
1341 | scene0629_01
1342 | scene0629_02
1343 | scene0630_00
1344 | scene0630_01
1345 | scene0630_02
1346 | scene0630_03
1347 | scene0630_04
1348 | scene0630_05
1349 | scene0630_06
1350 | scene0631_00
1351 | scene0631_01
1352 | scene0631_02
1353 | scene0632_00
1354 | scene0633_00
1355 | scene0633_01
1356 | scene0634_00
1357 | scene0635_00
1358 | scene0635_01
1359 | scene0636_00
1360 | scene0637_00
1361 | scene0638_00
1362 | scene0639_00
1363 | scene0640_00
1364 | scene0640_01
1365 | scene0640_02
1366 | scene0641_00
1367 | scene0642_00
1368 | scene0642_01
1369 | scene0642_02
1370 | scene0642_03
1371 | scene0643_00
1372 | scene0644_00
1373 | scene0645_00
1374 | scene0645_01
1375 | scene0645_02
1376 | scene0646_00
1377 | scene0646_01
1378 | scene0646_02
1379 | scene0647_00
1380 | scene0647_01
1381 | scene0648_00
1382 | scene0648_01
1383 | scene0649_00
1384 | scene0649_01
1385 | scene0650_00
1386 | scene0651_00
1387 | scene0651_01
1388 | scene0651_02
1389 | scene0652_00
1390 | scene0653_00
1391 | scene0653_01
1392 | scene0654_00
1393 | scene0654_01
1394 | scene0655_00
1395 | scene0655_01
1396 | scene0655_02
1397 | scene0656_00
1398 | scene0656_01
1399 | scene0656_02
1400 | scene0656_03
1401 | scene0657_00
1402 | scene0658_00
1403 | scene0659_00
1404 | scene0659_01
1405 | scene0660_00
1406 | scene0661_00
1407 | scene0662_00
1408 | scene0662_01
1409 | scene0662_02
1410 | scene0663_00
1411 | scene0663_01
1412 | scene0663_02
1413 | scene0664_00
1414 | scene0664_01
1415 | scene0664_02
1416 | scene0665_00
1417 | scene0665_01
1418 | scene0666_00
1419 | scene0666_01
1420 | scene0666_02
1421 | scene0667_00
1422 | scene0667_01
1423 | scene0667_02
1424 | scene0668_00
1425 | scene0669_00
1426 | scene0669_01
1427 | scene0670_00
1428 | scene0670_01
1429 | scene0671_00
1430 | scene0671_01
1431 | scene0672_00
1432 | scene0672_01
1433 | scene0673_00
1434 | scene0673_01
1435 | scene0673_02
1436 | scene0673_03
1437 | scene0673_04
1438 | scene0673_05
1439 | scene0674_00
1440 | scene0674_01
1441 | scene0675_00
1442 | scene0675_01
1443 | scene0676_00
1444 | scene0676_01
1445 | scene0677_00
1446 | scene0677_01
1447 | scene0677_02
1448 | scene0678_00
1449 | scene0678_01
1450 | scene0678_02
1451 | scene0679_00
1452 | scene0679_01
1453 | scene0680_00
1454 | scene0680_01
1455 | scene0681_00
1456 | scene0682_00
1457 | scene0683_00
1458 | scene0684_00
1459 | scene0684_01
1460 | scene0685_00
1461 | scene0685_01
1462 | scene0685_02
1463 | scene0686_00
1464 | scene0686_01
1465 | scene0686_02
1466 | scene0687_00
1467 | scene0688_00
1468 | scene0689_00
1469 | scene0690_00
1470 | scene0690_01
1471 | scene0691_00
1472 | scene0691_01
1473 | scene0692_00
1474 | scene0692_01
1475 | scene0692_02
1476 | scene0692_03
1477 | scene0692_04
1478 | scene0693_00
1479 | scene0693_01
1480 | scene0693_02
1481 | scene0694_00
1482 | scene0694_01
1483 | scene0695_00
1484 | scene0695_01
1485 | scene0695_02
1486 | scene0695_03
1487 | scene0696_00
1488 | scene0696_01
1489 | scene0696_02
1490 | scene0697_00
1491 | scene0697_01
1492 | scene0697_02
1493 | scene0697_03
1494 | scene0698_00
1495 | scene0698_01
1496 | scene0699_00
1497 | scene0700_00
1498 | scene0700_01
1499 | scene0700_02
1500 | scene0701_00
1501 | scene0701_01
1502 | scene0701_02
1503 | scene0702_00
1504 | scene0702_01
1505 | scene0702_02
1506 | scene0703_00
1507 | scene0703_01
1508 | scene0704_00
1509 | scene0704_01
1510 | scene0705_00
1511 | scene0705_01
1512 | scene0705_02
1513 | scene0706_00
1514 |
--------------------------------------------------------------------------------
/src/data/meta_data/scannet/scannetv2_val.txt:
--------------------------------------------------------------------------------
1 | scene0568_00
2 | scene0568_01
3 | scene0568_02
4 | scene0304_00
5 | scene0488_00
6 | scene0488_01
7 | scene0412_00
8 | scene0412_01
9 | scene0217_00
10 | scene0019_00
11 | scene0019_01
12 | scene0414_00
13 | scene0575_00
14 | scene0575_01
15 | scene0575_02
16 | scene0426_00
17 | scene0426_01
18 | scene0426_02
19 | scene0426_03
20 | scene0549_00
21 | scene0549_01
22 | scene0578_00
23 | scene0578_01
24 | scene0578_02
25 | scene0665_00
26 | scene0665_01
27 | scene0050_00
28 | scene0050_01
29 | scene0050_02
30 | scene0257_00
31 | scene0025_00
32 | scene0025_01
33 | scene0025_02
34 | scene0583_00
35 | scene0583_01
36 | scene0583_02
37 | scene0701_00
38 | scene0701_01
39 | scene0701_02
40 | scene0580_00
41 | scene0580_01
42 | scene0565_00
43 | scene0169_00
44 | scene0169_01
45 | scene0655_00
46 | scene0655_01
47 | scene0655_02
48 | scene0063_00
49 | scene0221_00
50 | scene0221_01
51 | scene0591_00
52 | scene0591_01
53 | scene0591_02
54 | scene0678_00
55 | scene0678_01
56 | scene0678_02
57 | scene0462_00
58 | scene0427_00
59 | scene0595_00
60 | scene0193_00
61 | scene0193_01
62 | scene0164_00
63 | scene0164_01
64 | scene0164_02
65 | scene0164_03
66 | scene0598_00
67 | scene0598_01
68 | scene0598_02
69 | scene0599_00
70 | scene0599_01
71 | scene0599_02
72 | scene0328_00
73 | scene0300_00
74 | scene0300_01
75 | scene0354_00
76 | scene0458_00
77 | scene0458_01
78 | scene0423_00
79 | scene0423_01
80 | scene0423_02
81 | scene0307_00
82 | scene0307_01
83 | scene0307_02
84 | scene0606_00
85 | scene0606_01
86 | scene0606_02
87 | scene0432_00
88 | scene0432_01
89 | scene0608_00
90 | scene0608_01
91 | scene0608_02
92 | scene0651_00
93 | scene0651_01
94 | scene0651_02
95 | scene0430_00
96 | scene0430_01
97 | scene0689_00
98 | scene0357_00
99 | scene0357_01
100 | scene0574_00
101 | scene0574_01
102 | scene0574_02
103 | scene0329_00
104 | scene0329_01
105 | scene0329_02
106 | scene0153_00
107 | scene0153_01
108 | scene0616_00
109 | scene0616_01
110 | scene0671_00
111 | scene0671_01
112 | scene0618_00
113 | scene0382_00
114 | scene0382_01
115 | scene0490_00
116 | scene0621_00
117 | scene0607_00
118 | scene0607_01
119 | scene0149_00
120 | scene0695_00
121 | scene0695_01
122 | scene0695_02
123 | scene0695_03
124 | scene0389_00
125 | scene0377_00
126 | scene0377_01
127 | scene0377_02
128 | scene0342_00
129 | scene0139_00
130 | scene0629_00
131 | scene0629_01
132 | scene0629_02
133 | scene0496_00
134 | scene0633_00
135 | scene0633_01
136 | scene0518_00
137 | scene0652_00
138 | scene0406_00
139 | scene0406_01
140 | scene0406_02
141 | scene0144_00
142 | scene0144_01
143 | scene0494_00
144 | scene0278_00
145 | scene0278_01
146 | scene0316_00
147 | scene0609_00
148 | scene0609_01
149 | scene0609_02
150 | scene0609_03
151 | scene0084_00
152 | scene0084_01
153 | scene0084_02
154 | scene0696_00
155 | scene0696_01
156 | scene0696_02
157 | scene0351_00
158 | scene0351_01
159 | scene0643_00
160 | scene0644_00
161 | scene0645_00
162 | scene0645_01
163 | scene0645_02
164 | scene0081_00
165 | scene0081_01
166 | scene0081_02
167 | scene0647_00
168 | scene0647_01
169 | scene0535_00
170 | scene0353_00
171 | scene0353_01
172 | scene0353_02
173 | scene0559_00
174 | scene0559_01
175 | scene0559_02
176 | scene0593_00
177 | scene0593_01
178 | scene0246_00
179 | scene0653_00
180 | scene0653_01
181 | scene0064_00
182 | scene0064_01
183 | scene0356_00
184 | scene0356_01
185 | scene0356_02
186 | scene0030_00
187 | scene0030_01
188 | scene0030_02
189 | scene0222_00
190 | scene0222_01
191 | scene0338_00
192 | scene0338_01
193 | scene0338_02
194 | scene0378_00
195 | scene0378_01
196 | scene0378_02
197 | scene0660_00
198 | scene0553_00
199 | scene0553_01
200 | scene0553_02
201 | scene0527_00
202 | scene0663_00
203 | scene0663_01
204 | scene0663_02
205 | scene0664_00
206 | scene0664_01
207 | scene0664_02
208 | scene0334_00
209 | scene0334_01
210 | scene0334_02
211 | scene0046_00
212 | scene0046_01
213 | scene0046_02
214 | scene0203_00
215 | scene0203_01
216 | scene0203_02
217 | scene0088_00
218 | scene0088_01
219 | scene0088_02
220 | scene0088_03
221 | scene0086_00
222 | scene0086_01
223 | scene0086_02
224 | scene0670_00
225 | scene0670_01
226 | scene0256_00
227 | scene0256_01
228 | scene0256_02
229 | scene0249_00
230 | scene0441_00
231 | scene0658_00
232 | scene0704_00
233 | scene0704_01
234 | scene0187_00
235 | scene0187_01
236 | scene0131_00
237 | scene0131_01
238 | scene0131_02
239 | scene0207_00
240 | scene0207_01
241 | scene0207_02
242 | scene0461_00
243 | scene0011_00
244 | scene0011_01
245 | scene0343_00
246 | scene0251_00
247 | scene0077_00
248 | scene0077_01
249 | scene0684_00
250 | scene0684_01
251 | scene0550_00
252 | scene0686_00
253 | scene0686_01
254 | scene0686_02
255 | scene0208_00
256 | scene0500_00
257 | scene0500_01
258 | scene0552_00
259 | scene0552_01
260 | scene0648_00
261 | scene0648_01
262 | scene0435_00
263 | scene0435_01
264 | scene0435_02
265 | scene0435_03
266 | scene0690_00
267 | scene0690_01
268 | scene0693_00
269 | scene0693_01
270 | scene0693_02
271 | scene0700_00
272 | scene0700_01
273 | scene0700_02
274 | scene0699_00
275 | scene0231_00
276 | scene0231_01
277 | scene0231_02
278 | scene0697_00
279 | scene0697_01
280 | scene0697_02
281 | scene0697_03
282 | scene0474_00
283 | scene0474_01
284 | scene0474_02
285 | scene0474_03
286 | scene0474_04
287 | scene0474_05
288 | scene0355_00
289 | scene0355_01
290 | scene0146_00
291 | scene0146_01
292 | scene0146_02
293 | scene0196_00
294 | scene0702_00
295 | scene0702_01
296 | scene0702_02
297 | scene0314_00
298 | scene0277_00
299 | scene0277_01
300 | scene0277_02
301 | scene0095_00
302 | scene0095_01
303 | scene0015_00
304 | scene0100_00
305 | scene0100_01
306 | scene0100_02
307 | scene0558_00
308 | scene0558_01
309 | scene0558_02
310 | scene0685_00
311 | scene0685_01
312 | scene0685_02
313 |
--------------------------------------------------------------------------------
/src/data/preprocess_s3dis.py:
--------------------------------------------------------------------------------
1 | import os
2 | import errno
3 | import glob
4 |
5 | import numpy as np
6 | from tqdm import tqdm
7 | from plyfile import PlyData, PlyElement
8 |
9 |
# Location of the raw S3DIS release and of the converted PLY files.
# You may need to modify these paths for your machine.
STANFORD_3D_IN_PATH = '/root/data/s3dis/Stanford3dDataset_v1.2/'
STANFORD_3D_OUT_PATH = '/root/data/s3dis/s3dis_processed'

# Label written for points whose object class is not trained on (e.g. 'stairs').
IGNORE_LABEL = 255
13 |
14 |
def mkdir_p(path):
    """Create ``path`` and any missing parent directories, like ``mkdir -p``.

    Does nothing if ``path`` already exists as a directory.

    Raises:
        OSError: if the directory cannot be created, including
            ``FileExistsError`` when ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
23 |
24 |
def save_point_cloud(points_3d, filename, binary=True, with_label=False, verbose=True):
    """Save an RGB point cloud, optionally with per-point labels, as a PLY file.

    Args:
        points_3d: 2-D array. Columns 0-2 are x, y, z and columns 3-5 are
            red, green, blue. With ``with_label`` a seventh column holds the
            label (N x 7). Without labels the array must be N x 6, or N x 3,
            in which case every point is written with gray color [128, 128, 128].
        filename: destination PLY path.
        binary: only True is supported; False raises NotImplementedError.
        with_label: whether the last column is stored as a ``label`` property.
        verbose: print the destination path after writing.
    """
    assert points_3d.ndim == 2
    if with_label:
        assert points_3d.shape[1] == 7
        npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'),
                     ('blue', 'u1'), ('label', 'u1')]
    else:
        if points_3d.shape[1] == 3:
            gray_concat = np.tile(np.array([128], dtype=np.uint8), (points_3d.shape[0], 3))
            points_3d = np.hstack((points_3d, gray_concat))
        assert points_3d.shape[1] == 6
        npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'),
                     ('blue', 'u1')]
    if not binary:
        raise NotImplementedError

    # Fill the structured array one column at a time instead of building a
    # Python tuple per point, which is very slow for millions of points.
    vertices_array = np.empty(points_3d.shape[0], dtype=npy_types)
    for col, (name, fmt) in enumerate(npy_types):
        column = points_3d[:, col]
        if fmt == 'f4':
            vertices_array[name] = column.astype(np.float32)
        else:
            # Truncate toward zero (like int()) and then narrow to uint8.
            # NOTE(review): values outside 0-255 wrap modulo 256.
            vertices_array[name] = column.astype(np.int64).astype(np.uint8)
    el = PlyElement.describe(vertices_array, 'vertex')

    # Write
    PlyData([el]).write(filename)
    if verbose:
        print('Saved point cloud to: %s' % filename)
60 |
61 |
class Stanford3DDatasetConverter:
    """Convert the raw S3DIS text annotations into one labeled PLY per room."""

    CLASSES = [
        'ceiling', 'floor', 'wall', 'beam', 'column',
        'window', 'door', 'chair', 'table', 'bookcase',
        'sofa', 'board', 'clutter'
    ]
    TRAIN_TEXT = 'train'
    VAL_TEXT = 'val'
    TEST_TEXT = 'test'

    @classmethod
    def read_txt(cls, txtfile):
        """Read one S3DIS annotation text file.

        Args:
            txtfile: path to a whitespace-separated text file with six values
                per line: x, y, z, red, green, blue.

        Returns:
            ``(xyz, rgb)``: a float32 array of shape (N, 3) and a uint8 array of
            shape (N, 3).

        Raises:
            OSError, ValueError: if the file cannot be read or parsed (the
                offending path is printed first).
            ValueError: if the file does not have exactly six columns.
        """
        obj_name = os.path.basename(txtfile)
        if obj_name == 'ceiling_1.txt':
            # One ceiling annotation contains a corrupted number; patch it on disk.
            # https://github.com/zeliu98/CloserLook3D/issues/15
            with open(txtfile, 'r') as f:
                lines = f.readlines()
            patched = False
            for l_i, line in enumerate(lines):
                if '103.0\x100000' in line:
                    print(f'Fix bug in {txtfile}')
                    print(f'Bug line: {line}')
                    lines[l_i] = line.replace('103.0\x100000', '103.000000')
                    patched = True
            if patched:
                with open(txtfile, 'w') as f:
                    f.writelines(lines)
        try:
            # ndmin=2 keeps single-point files two-dimensional.
            pointcloud = np.loadtxt(txtfile, dtype=np.float32, ndmin=2)
        except (OSError, ValueError):
            print('Bug!!!!!!!!!!!!!!!!!!!!!')
            print(obj_name)
            print(txtfile)
            # Previously execution continued with `pointcloud` unbound; surface
            # the real parsing error instead.
            raise

        if pointcloud.shape[1] != 6:
            raise ValueError(
                f'{txtfile}: expected 6 columns (x y z r g b), got {pointcloud.shape[1]}'
            )
        xyz = pointcloud[:, :3].astype(np.float32)
        rgb = pointcloud[:, 3:].astype(np.uint8)
        return xyz, rgb

    @classmethod
    def convert_to_ply(cls, root_path, out_path):
        """Convert Stanford3DDataset to labeled PLY files.

        Expects the raw layout ``root_path/<area>/<room>/<room>.txt``, with the
        per-object files in ``root_path/<area>/<room>/Annotations/*.txt``. The
        object's class is the file-name prefix before the first underscore.
        Writes ``out_path/<area>/<room>.ply`` with columns x, y, z, r, g, b,
        label. Rooms whose output already exists are skipped.
        """
        txtfiles = glob.glob(os.path.join(root_path, '*/*/*.txt'))
        for txtfile in tqdm(txtfiles):
            file_sp = os.path.normpath(txtfile).split(os.path.sep)
            target_path = os.path.join(out_path, file_sp[-3])
            out_file = os.path.join(target_path, file_sp[-2] + '.ply')

            if os.path.exists(out_file):
                print(out_file, ' exists')
                continue

            room_dir = os.path.dirname(txtfile)
            subclouds = glob.glob(os.path.join(room_dir, 'Annotations/*.txt'))
            coords, feats, labels = [], [], []
            for subcloud in subclouds:
                xyz, rgb = cls.read_txt(subcloud)
                clsname = os.path.basename(subcloud).split('_')[0]
                # https://github.com/chrischoy/SpatioTemporalSegmentation/blob/4afee296ebe387d9a06fc1b168c4af212a2b4804/lib/datasets/stanford.py#L20
                if clsname == 'stairs':
                    print('Ignore stairs')
                    clsidx = IGNORE_LABEL
                else:
                    clsidx = cls.CLASSES.index(clsname)

                coords.append(xyz)
                feats.append(rgb)
                labels.append(np.full((len(xyz), 1), clsidx, dtype=np.int32))

            if not coords:
                print(txtfile, ' has 0 files.')
                continue

            pointcloud = np.concatenate(
                (np.concatenate(coords, 0), np.concatenate(feats, 0), np.concatenate(labels, 0)),
                axis=1,
            )

            # Write ply file.
            mkdir_p(target_path)
            save_point_cloud(pointcloud, out_file, with_label=True, verbose=False)
152 |
153 |
if __name__ == '__main__':
    # Convert every room found under the input root into labeled PLY files.
    Stanford3DDatasetConverter.convert_to_ply(
        root_path=STANFORD_3D_IN_PATH,
        out_path=STANFORD_3D_OUT_PATH,
    )
156 |
--------------------------------------------------------------------------------
/src/data/preprocess_scannet.py:
--------------------------------------------------------------------------------
1 | import json
2 | from concurrent.futures import ProcessPoolExecutor
3 | from pathlib import Path
4 |
5 | import pandas as pd
6 | import numpy as np
7 | from plyfile import PlyData, PlyElement
8 |
9 |
# Raw ScanNet v2 download and destination of the processed PLY files.
# You may need to modify these paths for your machine.
SCANNET_RAW_PATH = Path('/root/data/scannetv2_raw')
SCANNET_OUT_PATH = Path('/root/data/scannet_processed')

# Output split directory -> raw ScanNet directory it is built from.
TRAIN_DEST, TEST_DEST = 'train', 'test'
SUBSETS = {
    TRAIN_DEST: 'scans',
    TEST_DEST: 'scans_test',
}

# Suffix of the per-scene mesh file that gets converted.
POINTCLOUD_FILE = '_vh_clean_2.ply'

# Processed file (relative to SCANNET_OUT_PATH) -> semantic label id that is
# invalid in that file; main() resets those points to label 0.
# https://github.com/ScanNet/ScanNet/issues/20
BUGS = {
    'train/scene0270_00.ply': 50,
    'train/scene0270_02.ply': 50,
    'train/scene0384_00.ply': 149,
}
21 |
22 |
def read_plyfile(path):
    """Read the first element of a PLY file as a 2-D NumPy array.

    Returns:
        One row per entry and one column per property of the first element,
        or None when the file declares no elements.
    """
    with open(path, 'rb') as stream:
        ply = PlyData.read(stream)
        if not ply.elements:
            return None
        return pd.DataFrame(ply.elements[0].data).values
28 |
29 |
def save_point_cloud(points_3d, filename, verbose=True):
    """Write a labeled ScanNet point cloud to ``filename`` as a binary PLY.

    Args:
        points_3d: N x 8 array with columns x, y, z, red, green, blue,
            semantic_label, instance_label.
        filename: destination path.
        verbose: print the destination path after writing.
    """
    assert points_3d.ndim == 2
    assert points_3d.shape[1] == 8  # x, y, z, r, g, b, semantic_label, instance_label

    npy_types = [
        ('x', 'f4'),
        ('y', 'f4'),
        ('z', 'f4'),
        ('red', 'u1'),
        ('green', 'u1'),
        ('blue', 'u1'),
        ('semantic_label', 'u1'),
        ('instance_label', 'u1')
    ]

    # Fill the structured array one column at a time instead of building a
    # Python tuple per point.
    vertices_array = np.empty(points_3d.shape[0], dtype=npy_types)
    for col, (name, fmt) in enumerate(npy_types):
        column = points_3d[:, col]
        if fmt == 'f4':
            vertices_array[name] = column.astype(np.float32)
        else:
            # Truncate toward zero (like int()) and then narrow to uint8.
            # NOTE(review): labels are stored as u1, so ids > 255 (e.g. a
            # scene with more than 255 instances) wrap modulo 256 - confirm.
            vertices_array[name] = column.astype(np.int64).astype(np.uint8)
    el = PlyElement.describe(vertices_array, 'vertex')
    # Write
    PlyData([el]).write(filename)

    if verbose:
        print(f'Saved point cloud to: {filename}')
57 |
58 |
def _instance_labels(segment, segment_groups):
    """Return a float64 per-point instance id (1-based; 0 means unassigned).

    Args:
        segment: 1-D array with each point's over-segment id.
        segment_groups: list of dicts whose ``'segments'`` entry lists the
            segment ids of one instance; group ``i`` receives id ``i + 1``.

    A segment listed in several groups keeps the id of the last such group,
    exactly as a sequential per-group assignment would.
    """
    segment_to_instance = {}
    for group_idx, segment_group in enumerate(segment_groups):
        for segment_idx in segment_group['segments']:
            segment_to_instance[segment_idx] = group_idx + 1
    # Resolve each distinct segment id once, then broadcast back to points.
    unique_segments, inverse = np.unique(segment, return_inverse=True)
    per_segment = np.array(
        [segment_to_instance.get(seg_id, 0) for seg_id in unique_segments.tolist()],
        dtype=np.float64,
    )
    return per_segment[inverse.reshape(-1)]


def handle_process(paths):
    """Convert one raw ScanNet scene into a processed, labeled PLY file.

    Args:
        paths: pair ``(f, phase_out_path)``; ``f`` is the scene's
            ``*_vh_clean_2.ply`` mesh and ``phase_out_path`` the split's
            output directory.

    The output ``<scene>.ply`` holds the first six mesh columns plus the
    semantic label and instance id. When the ``.labels.ply`` file is missing
    (test split), both labels are written as 0.
    """
    f, phase_out_path = paths[0], paths[1]
    scene_name = f.name[:-len(POINTCLOUD_FILE)]
    out_f = phase_out_path / (scene_name + f.suffix)
    pointcloud = read_plyfile(f)
    num_points = pointcloud.shape[0]
    # Make sure alpha value is meaningless.
    assert np.unique(pointcloud[:, -1]).size == 1
    # Label files sitting next to the mesh.
    segment_f = f.with_suffix('.0.010000.segs.json')
    segment_group_f = (f.parent / scene_name).with_suffix('.aggregation.json')
    semantic_f = f.parent / (f.stem + '.labels' + f.suffix)

    if semantic_f.is_file():
        semantic_label = read_plyfile(semantic_f)
        # Sanity check that the pointcloud and its label has same vertices.
        assert pointcloud.shape[0] == semantic_label.shape[0]
        assert np.allclose(pointcloud[:, :3], semantic_label[:, :3])
        semantic_label = semantic_label[:, -1]
        # Load instance label (distinct handle names keep `f` bound to the mesh path).
        with open(segment_f) as segment_file:
            segment = np.array(json.load(segment_file)['segIndices'])
        with open(segment_group_f) as group_file:
            segment_groups = json.load(group_file)['segGroups']
        assert segment.size == num_points
        instance_label = _instance_labels(segment, segment_groups)
    else:  # Label may not exist in test case.
        semantic_label = np.zeros(num_points)
        instance_label = np.zeros(num_points)

    processed = np.hstack((pointcloud[:, :6], semantic_label[:, None], instance_label[:, None]))
    save_point_cloud(processed, out_f, verbose=False)
94 |
95 |
def main(max_workers=20):
    """Preprocess every ScanNet scene in parallel, then patch known label bugs.

    Args:
        max_workers: number of worker processes used for the conversion.
    """
    path_list = []
    for out_dir, in_dir in SUBSETS.items():
        phase_out_path = SCANNET_OUT_PATH / out_dir
        phase_out_path.mkdir(parents=True, exist_ok=True)
        for f in (SCANNET_RAW_PATH / in_dir).glob('*/*' + POINTCLOUD_FILE):
            path_list.append([f, phase_out_path])

    # The context manager shuts the pool down; list() forces every task to
    # finish and re-raises any worker exception here.
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        list(pool.map(handle_process, path_list))

    # Fix bug in the data.
    for files, bug_index in BUGS.items():
        print(files)

        for f in SCANNET_OUT_PATH.glob(files):
            pointcloud = read_plyfile(f)
            bug_mask = pointcloud[:, -2] == bug_index
            print(f'Fixing {f} bugged label {bug_index} x {bug_mask.sum()}...')
            pointcloud[bug_mask, -2] = 0
            save_point_cloud(pointcloud, f, verbose=False)
117 |
118 |
if __name__ == '__main__':
    # Script entry point: preprocess both splits with default settings.
    banner = 'Preprocessing ScanNetV2 dataset...'
    print(banner)
    main()
122 |
--------------------------------------------------------------------------------
/src/data/s3dis_loader.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | from typing import Optional
3 |
4 | import gin
5 | import numpy as np
6 | import torch
7 | import pytorch_lightning as pl
8 |
9 | from src.data.scannet_loader import read_ply
10 | from src.data.collate import CollationFunctionFactory
11 | from src.data.sampler import InfSampler
12 | import src.data.transforms as T
13 |
# The 13 S3DIS semantic classes, in the same order as the label ids written
# by preprocess_s3dis.py (label id == index in this list).
CLASSES = [
    'ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door',
    'chair', 'table', 'bookcase', 'sofa', 'board', 'clutter',
]
29 |
30 |
def read_txt(path):
    """Return every line of the text file at ``path`` with surrounding whitespace stripped."""
    with open(path) as handle:
        return [line.strip() for line in handle]
38 |
39 |
@gin.configurable
class S3DISArea5DatasetBase(torch.utils.data.Dataset):
    """S3DIS dataset that holds out Area 5 for validation/test.

    Room lists are read from ``<data_root>/meta_data/area*.txt``, and each entry
    is resolved relative to ``<data_root>/s3dis_processed``. Subclasses must set
    ``IN_CHANNELS`` and implement ``get_cfl_from_data``.
    """
    IN_CHANNELS = None
    NUM_CLASSES = 13
    SPLIT_FILES = {
        'train': ['area1.txt', 'area2.txt', 'area3.txt', 'area4.txt', 'area6.txt'],
        'val': ['area5.txt'],
        'test': ['area5.txt']
    }

    def __init__(self, phase, data_root, transform=None, ignore_label=255):
        assert self.IN_CHANNELS is not None
        assert phase in ['train', 'val', 'test']
        super(S3DISArea5DatasetBase, self).__init__()

        self.phase = phase
        self.data_root = data_root
        self.transform = transform
        self.ignore_label = ignore_label
        self.split_files = self.SPLIT_FILES[phase]

        meta_dir = osp.join(self.data_root, 'meta_data')
        ply_dir = osp.join(self.data_root, 's3dis_processed')
        self.filenames = [
            osp.join(ply_dir, room)
            for split_file in self.split_files
            for room in read_txt(osp.join(meta_dir, split_file))
        ]

    def __len__(self):
        return len(self.filenames)

    def get_classnames(self):
        return CLASSES

    def __getitem__(self, idx):
        """Return ``(coords, feats, labels, None)`` as float/float/long tensors."""
        coords, feats, labels = self.get_cfl_from_data(self._load_data(idx))
        if self.transform is not None:
            coords, feats, labels = self.transform(coords, feats, labels)
        return (
            torch.from_numpy(coords).float(),
            torch.from_numpy(feats).float(),
            torch.from_numpy(labels).long(),
            None,
        )

    def get_cfl_from_data(self, data):
        """Split a loaded array into (coords, feats, labels); implemented by subclasses."""
        raise NotImplementedError

    def _load_data(self, idx):
        """Read the PLY file of room ``idx`` as a 2-D array."""
        return read_ply(self.filenames[idx])
91 |
92 |
@gin.configurable
class S3DISArea5RGBDataset(S3DISArea5DatasetBase):
    """S3DIS Area-5 split whose input features are the three RGB channels."""
    IN_CHANNELS = 3

    def __init__(self, phase, data_root, transform=None, ignore_label=255):
        super(S3DISArea5RGBDataset, self).__init__(
            phase, data_root, transform, ignore_label
        )

    def get_cfl_from_data(self, data):
        """Split ``data`` (columns x, y, z, r, g, b, label, ...) into typed arrays.

        Returns:
            float32 xyz (N, 3), float32 rgb (N, 3) and int64 labels (N,).
        """
        coords = data[:, 0:3].astype(np.float32)
        colors = data[:, 3:6].astype(np.float32)
        labels = data[:, 6].astype(np.int64)
        return coords, colors, labels
105 |
106 |
@gin.configurable
class S3DISArea5RGBDataModule(pl.LightningDataModule):
    """Lightning data module serving the S3DIS Area-5 RGB split.

    Training draws batches endlessly through a shuffled ``InfSampler``;
    validation iterates Area 5 once, in order. Transform names are looked up
    in ``src.data.transforms`` and instantiated without arguments.
    """

    def __init__(
        self,
        data_root,
        train_batch_size,
        val_batch_size,
        train_num_workers,
        val_num_workers,
        collation_type,
        train_transforms,
        eval_transforms,
    ):
        super(S3DISArea5RGBDataModule, self).__init__()
        self.data_root = data_root
        self.train_batch_size, self.val_batch_size = train_batch_size, val_batch_size
        self.train_num_workers, self.val_num_workers = train_num_workers, val_num_workers
        self.collate_fn = CollationFunctionFactory(collation_type)
        self.train_transforms_ = train_transforms
        self.eval_transforms_ = eval_transforms

    @staticmethod
    def _compose(names):
        """Instantiate each transform named in ``names`` and chain them (None -> no transforms)."""
        transforms = [] if names is None else [getattr(T, name)() for name in names]
        return T.Compose(transforms)

    def setup(self, stage: Optional[str] = None):
        if stage in ("fit", None):
            self.dset_train = S3DISArea5RGBDataset(
                "train", self.data_root, self._compose(self.train_transforms_)
            )
        self.dset_val = S3DISArea5RGBDataset(
            "val", self.data_root, self._compose(self.eval_transforms_)
        )

    def train_dataloader(self):
        sampler = InfSampler(self.dset_train, True)
        return torch.utils.data.DataLoader(
            self.dset_train,
            batch_size=self.train_batch_size,
            sampler=sampler,
            num_workers=self.train_num_workers,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.dset_val,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.val_num_workers,
            drop_last=False,
            collate_fn=self.collate_fn,
        )
--------------------------------------------------------------------------------
/src/data/sampler.py:
--------------------------------------------------------------------------------
1 | import gin
2 | import torch
3 | from torch.utils.data.sampler import Sampler
4 |
5 |
@gin.configurable
class InfSampler(Sampler):
    """Endless sampler that keeps yielding dataset indices.

    Each pass draws a permutation of ``range(len(data_source))`` and hands out
    its entries from the end. With ``shuffle`` the permutation is random;
    otherwise indices come out in descending order (n-1, ..., 0). When a pass is
    exhausted a new permutation is drawn, so iteration never stops. With an
    empty ``data_source``, ``next()`` raises IndexError.

    Arguments:
        data_source (Dataset): dataset to sample from
        shuffle (bool): draw a random permutation for every pass
    """

    def __init__(self, data_source, shuffle=False):
        self.data_source = data_source
        self.shuffle = shuffle
        self.reset_permutation()

    def reset_permutation(self):
        n = len(self.data_source)
        order = torch.randperm(n) if self.shuffle else torch.arange(n)
        self._perm = order.tolist()

    def __iter__(self):
        return self

    def __next__(self):
        if not self._perm:
            self.reset_permutation()
        return self._perm.pop()

    def __len__(self):
        return len(self.data_source)
--------------------------------------------------------------------------------
/src/data/scannet_loader.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | from typing import Optional
3 |
4 | import gin
5 | import numpy as np
6 | from plyfile import PlyData
7 | from pandas import DataFrame
8 | import torch
9 | import pytorch_lightning as pl
10 |
11 | from src.data.collate import CollationFunctionFactory
12 | import src.data.transforms as T
13 |
# ScanNet semantic class id (0-40) -> display color as an (R, G, B) tuple of
# floats in [0, 255]. Ids 13 and 31 have no entry.
SCANNET_COLOR_MAP = {
    class_id: tuple(float(channel) for channel in rgb)
    for class_id, rgb in (
        (0, (0, 0, 0)),
        (1, (174, 199, 232)),
        (2, (152, 223, 138)),
        (3, (31, 119, 180)),
        (4, (255, 187, 120)),
        (5, (188, 189, 34)),
        (6, (140, 86, 75)),
        (7, (255, 152, 150)),
        (8, (214, 39, 40)),
        (9, (197, 176, 213)),
        (10, (148, 103, 189)),
        (11, (196, 156, 148)),
        (12, (23, 190, 207)),
        (14, (247, 182, 210)),
        (15, (66, 188, 102)),
        (16, (219, 219, 141)),
        (17, (140, 57, 197)),
        (18, (202, 185, 52)),
        (19, (51, 176, 203)),
        (20, (200, 54, 131)),
        (21, (92, 193, 61)),
        (22, (78, 71, 183)),
        (23, (172, 114, 82)),
        (24, (255, 127, 14)),
        (25, (91, 163, 138)),
        (26, (153, 98, 156)),
        (27, (140, 153, 101)),
        (28, (158, 218, 229)),
        (29, (100, 125, 154)),
        (30, (178, 127, 135)),
        (32, (146, 111, 194)),
        (33, (44, 160, 44)),
        (34, (112, 128, 144)),
        (35, (96, 207, 209)),
        (36, (227, 119, 194)),
        (37, (213, 92, 176)),
        (38, (94, 106, 211)),
        (39, (82, 84, 163)),
        (40, (100, 85, 144)),
    )
}

# The 20 raw class ids used as training/evaluation classes, and their names in
# the same order.
VALID_CLASS_LABELS = (
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39,
)
VALID_CLASS_NAMES = (
    'wall', 'floor', 'cabinet', 'bed', 'chair',
    'sofa', 'table', 'door', 'window', 'bookshelf',
    'picture', 'counter', 'desk', 'curtain', 'refrigerator',
    'shower curtain', 'toilet', 'sink', 'bathtub', 'otherfurniture',
)
60 |
61 |
def read_ply(filename):
    """Load the first element of a PLY file as a 2-D NumPy array (one row per entry).

    Raises:
        AssertionError: if the file declares no elements.
    """
    with open(filename, 'rb') as stream:
        ply = PlyData.read(stream)
        assert ply.elements
        table = DataFrame(ply.elements[0].data)
        return table.values
68 |
69 |
70 | @gin.configurable
71 | class ScanNetDatasetBase(torch.utils.data.Dataset):
72 | IN_CHANNELS = None
73 | CLASS_LABELS = None
74 | SPLIT_FILES = {
75 | 'train': 'scannetv2_train.txt',
76 | 'val': 'scannetv2_val.txt',
77 | 'trainval': 'scannetv2_trainval.txt',
78 | 'test': 'scannetv2_test.txt',
79 | 'overfit': 'scannetv2_overfit.txt'
80 | }
81 |
82 | def __init__(self, phase, data_root, transform=None, ignore_label=255):
83 | assert self.IN_CHANNELS is not None
84 | assert self.CLASS_LABELS is not None
85 | assert phase in self.SPLIT_FILES.keys()
86 | super(ScanNetDatasetBase, self).__init__()
87 |
88 | self.phase = phase
89 | self.data_root = data_root
90 | self.transform = transform
91 | self.ignore_label = ignore_label
92 | self.split_file = self.SPLIT_FILES[phase]
93 | self.ignore_class_labels = tuple(set(range(41)) - set(self.CLASS_LABELS))
94 | self.labelmap = self.get_labelmap()
95 | self.labelmap_inverse = self.get_labelmap_inverse()
96 |
97 | with open(osp.join(self.data_root, 'meta_data', self.split_file), 'r') as f:
98 | filenames = f.read().splitlines()
99 |
100 | sub_dir = 'test' if phase == 'test' else 'train'
101 | self.filenames = [
102 | osp.join(self.data_root, 'scannet_processed', sub_dir, f'{filename}.ply')
103 | for filename in filenames
104 | ]
105 |
106 | def __len__(self):
107 | return len(self.filenames)
108 |
109 | def get_classnames(self):
110 | classnames = {}
111 | for class_id in self.CLASS_LABELS:
112 | classnames[self.labelmap[class_id]] = VALID_CLASS_NAMES[VALID_CLASS_LABELS.index(class_id)]
113 | return classnames
114 |
115 | def get_colormaps(self):
116 | colormaps = {}
117 | for class_id in self.CLASS_LABELS:
118 | colormaps[self.labelmap[class_id]] = SCANNET_COLOR_MAP[class_id]
119 | return colormaps
120 |
121 | def get_labelmap(self):
122 | labelmap = {}
123 | for k in range(41):
124 | if k in self.ignore_class_labels:
125 | labelmap[k] = self.ignore_label
126 | else:
127 | labelmap[k] = self.CLASS_LABELS.index(k)
128 | return labelmap
129 |
130 | def get_labelmap_inverse(self):
131 | labelmap_inverse = {}
132 | for k, v in self.labelmap.items():
133 | labelmap_inverse[v] = self.ignore_label if v == self.ignore_label else k
134 | return labelmap_inverse
135 |
136 |
@gin.configurable
class ScanNetRGBDataset(ScanNetDatasetBase):
    """ScanNet v2 scenes with RGB input features and the 20 ``VALID_CLASS_LABELS``."""
    IN_CHANNELS = 3
    CLASS_LABELS = VALID_CLASS_LABELS
    NUM_CLASSES = len(VALID_CLASS_LABELS)  # 20

    def __getitem__(self, idx):
        """Return ``(coords, feats, labels, None)`` as float/float/long tensors."""
        data = self._load_data(idx)
        coords, feats, labels = self.get_cfl_from_data(data)
        if self.transform is not None:
            coords, feats, labels = self.transform(coords, feats, labels)
        coords = torch.from_numpy(coords)
        feats = torch.from_numpy(feats)
        labels = torch.from_numpy(labels)
        return coords.float(), feats.float(), labels.long(), None

    def _load_data(self, idx):
        """Read the processed PLY of scene ``idx`` as a 2-D array."""
        filename = self.filenames[idx]
        data = read_ply(filename)
        return data

    def get_cfl_from_data(self, data):
        """Split a processed scene array into coordinates, colors and training labels.

        Args:
            data: 2-D array whose first six columns are x, y, z, r, g, b and
                whose second-to-last column holds the raw semantic class id.

        Returns:
            float32 xyz (N, 3), float32 rgb (N, 3) and int64 labels (N,), where
            raw ids are remapped through ``self.labelmap``.

        Raises:
            KeyError: if a raw id has no entry in ``self.labelmap``.
        """
        xyz, rgb, raw_labels = data[:, :3], data[:, 3:6], data[:, -2]
        # Look up each distinct raw id once (same dict semantics as a
        # per-point lookup) and broadcast the results back to all points.
        unique_ids, inverse = np.unique(raw_labels, return_inverse=True)
        mapped = np.array([self.labelmap[x] for x in unique_ids.tolist()], dtype=np.int64)
        labels = mapped[inverse.reshape(raw_labels.shape)]
        return (
            xyz.astype(np.float32),
            rgb.astype(np.float32),
            labels.astype(np.int64)
        )
166 |
167 |
@gin.configurable
class ScanNetRGBDataModule(pl.LightningDataModule):
    """Lightning data module serving ScanNet v2 RGB train/val splits.

    Transform names are looked up in ``src.data.transforms`` and instantiated
    without arguments. Training shuffles each epoch; validation keeps order.
    """

    def __init__(
        self,
        data_root,
        train_batch_size,
        val_batch_size,
        train_num_workers,
        val_num_workers,
        collation_type,
        train_transforms,
        eval_transforms,
    ):
        super(ScanNetRGBDataModule, self).__init__()
        self.data_root = data_root
        self.train_batch_size, self.val_batch_size = train_batch_size, val_batch_size
        self.train_num_workers, self.val_num_workers = train_num_workers, val_num_workers
        self.collate_fn = CollationFunctionFactory(collation_type)
        self.train_transforms_ = train_transforms
        self.eval_transforms_ = eval_transforms

    @staticmethod
    def _compose(names):
        """Instantiate each transform named in ``names`` and chain them (None -> no transforms)."""
        transforms = [] if names is None else [getattr(T, name)() for name in names]
        return T.Compose(transforms)

    def setup(self, stage: Optional[str] = None):
        if stage in ("fit", None):
            self.dset_train = ScanNetRGBDataset(
                "train", self.data_root, self._compose(self.train_transforms_)
            )
        self.dset_val = ScanNetRGBDataset(
            "val", self.data_root, self._compose(self.eval_transforms_)
        )

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.dset_train,
            batch_size=self.train_batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=self.train_num_workers,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.dset_val,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.val_num_workers,
            drop_last=False,
            collate_fn=self.collate_fn,
        )
217 |
218 |
@gin.configurable
class ScanNetRGBDataset_(ScanNetRGBDataset):
    """Variant of ``ScanNetRGBDataset`` whose samples also carry the source PLY path."""

    def __getitem__(self, idx):
        """Return ``(coords, feats, labels, filename)``; tensors are float/float/long."""
        data, filename = self._load_data(idx)
        coords, feats, labels = self.get_cfl_from_data(data)
        if self.transform is not None:
            coords, feats, labels = self.transform(coords, feats, labels)
        coords_t, feats_t, labels_t = (torch.from_numpy(a) for a in (coords, feats, labels))
        return coords_t.float(), feats_t.float(), labels_t.long(), filename

    def _load_data(self, idx):
        """Return the scene array of ``idx`` together with its file path."""
        path = self.filenames[idx]
        return read_ply(path), path
--------------------------------------------------------------------------------
/src/data/transforms.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import random
3 |
4 | import gin
5 | import numpy as np
6 | import scipy
7 | import scipy.ndimage
8 | import scipy.interpolate
9 | from scipy.linalg import expm, norm
10 | import torch
11 | import MinkowskiEngine as ME
12 |
13 |
def homogeneous_coords(coords):
    """Append a column of ones to an (N, 3) tensor, giving (N, 4) homogeneous coordinates.

    The ones column is created on ``coords.device``, so CUDA inputs no longer
    fail in ``torch.cat``. Its dtype is left at the torch default, as before.
    """
    assert isinstance(coords, torch.Tensor) and coords.dim() == 2 and coords.shape[1] == 3
    ones = torch.ones((len(coords), 1), device=coords.device)
    return torch.cat([coords, ones], dim=1)
17 |
18 |
class Compose(object):
    """Chain transforms: each one receives the unpacked output of the previous one.

    With no transforms, calling returns the positional arguments as a tuple.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, *args):
        result = args
        for transform in self.transforms:
            result = transform(*result)
        return result
29 |
30 |
31 | # A sparse tensor consists of coordinates and associated features.
32 | # You must apply augmentation to both.
33 | # In 2D, flip, shear, scale, and rotation of images are coordinate transformation
34 | # color jitter, hue, etc., are feature transformations
35 | ##############################
36 | # Coordinate transformations
37 | ##############################
@gin.configurable
class ElasticDistortion:
    """Randomly perturb point coordinates with smooth (elastic) noise.

    Args:
        distortion_params: sequence of ``(granularity, magnitude)`` pairs applied
            in order, or None to disable the transform.
        application_ratio: probability that a given sample is distorted.
    """

    def __init__(self, distortion_params=((4, 16), (8, 24)), application_ratio=0.9):
        # Tuple default avoids a shared mutable default argument.
        self.application_ratio = application_ratio
        self.distortion_params = distortion_params
        logging.info(
            f"{self.__class__.__name__} distortion_params:{distortion_params} with application_ratio:{application_ratio}"
        )

    def elastic_distortion(self, coords, feats, labels, granularity, magnitude):
        """Apply elastic distortion on sparse coordinate space.

        Args:
            coords: numpy array of (number of points, at least 3 spatial dims).
                It is updated in place and also returned.
            feats, labels: passed through unchanged.
            granularity: size of the noise grid (in same scale[m/cm] as the voxel grid).
            magnitude: noise multiplier.
        """
        coords_min = coords.min(0)

        # Create Gaussian noise tensor of the size given by granularity.
        noise_dim = ((coords - coords_min).max(0) // granularity).astype(int) + 3
        noise = np.random.randn(*noise_dim, 3).astype(np.float32)

        # Smoothing: two passes of 3-tap box filters along x, y and z.
        blur_kernels = (
            np.ones((3, 1, 1, 1), dtype=np.float32) / 3,
            np.ones((1, 3, 1, 1), dtype=np.float32) / 3,
            np.ones((1, 1, 3, 1), dtype=np.float32) / 3,
        )
        for _ in range(2):
            for kernel in blur_kernels:
                # scipy.ndimage.filters is a deprecated alias namespace.
                noise = scipy.ndimage.convolve(noise, kernel, mode="constant", cval=0)

        # Trilinear interpolate noise filters for each spatial dimensions.
        ax = [
            np.linspace(d_min, d_max, d)
            for d_min, d_max, d in zip(
                coords_min - granularity,
                coords_min + granularity * (noise_dim - 2),
                noise_dim,
            )
        ]
        interp = scipy.interpolate.RegularGridInterpolator(
            ax, noise, bounds_error=False, fill_value=0
        )
        coords += interp(coords) * magnitude
        return coords, feats, labels

    def __call__(self, coords, feats, labels):
        if self.distortion_params is not None:
            if random.random() < self.application_ratio:
                for granularity, magnitude in self.distortion_params:
                    coords, feats, labels = self.elastic_distortion(
                        coords, feats, labels, granularity, magnitude
                    )
        return coords, feats, labels
97 |
98 |
# Rotation matrix about `axis` by angle `theta` (radians).
def M(axis, theta):
    """Return the 3x3 rotation matrix for a rotation of ``theta`` about ``axis``.

    ``axis`` need not be unit length; it is normalized first. The result is
    the matrix exponential of the skew-symmetric cross-product matrix of
    ``theta * axis / |axis|``.
    """
    unit_axis = axis / norm(axis)
    skew = np.cross(np.eye(3), unit_axis * theta)
    return expm(skew)
102 |
103 |
@gin.configurable
class RandomRotation(object):
    """Rotate the point cloud by a random angle about a jittered upright axis.

    Args:
        upright_axis: name of the up axis, "x", "y" or "z" (case-insensitive).
        axis_std: std of the Gaussian jitter added to the unit up axis.
        application_ratio: probability of rotating on a given call.
    """

    def __init__(self, upright_axis="z", axis_std=0.01, application_ratio=0.9):
        axis_lookup = {"x": 0, "y": 1, "z": 2}
        self.upright_axis = axis_lookup[upright_axis.lower()]
        self.D = 3
        # Every axis except the upright one.
        self.horz_axes = set(range(self.D)) - set([self.upright_axis])
        self.application_ratio = application_ratio
        self.axis_std = axis_std
        logging.info(
            f"{self.__class__.__name__} upright_axis:{upright_axis}, axis_std:{axis_std} with application_ratio:{application_ratio}"
        )

    def __call__(self, coords, feats, labels):
        if random.random() >= self.application_ratio:
            return coords, feats, labels
        rot_axis = self.axis_std * np.random.randn(3)
        rot_axis[self.upright_axis] += 1
        theta = random.random() * 2 * np.pi
        return coords @ M(rot_axis, theta), feats, labels
124 |
125 |
@gin.configurable
class RandomTranslation(object):
    """Shift every point by one random offset.

    Each axis offset is uniform in ``[-max_translation, max_translation)``.
    The shift is applied with probability ``application_ratio``, in place on
    the caller's array.
    """

    def __init__(
        self,
        max_translation=3,
        application_ratio=0.9,
    ):
        self.max_translation = max_translation
        self.application_ratio = application_ratio
        logging.info(
            f"{self.__class__.__name__} max_translation:{max_translation} with application_ratio:{application_ratio}"
        )

    def __call__(self, coords, feats, labels):
        if random.random() >= self.application_ratio:
            return coords, feats, labels
        offset = 2 * (np.random.rand(1, 3) - 0.5) * self.max_translation
        coords += offset
        return coords, feats, labels
143 |
144 |
@gin.configurable
class RandomScale(object):
    """Scale all coordinates by a single random factor.

    The factor is uniform in ``[1 - scale_ratio, 1 + scale_ratio)``. It is
    applied with probability ``application_ratio``.
    """

    def __init__(self, scale_ratio=0.1, application_ratio=0.9):
        self.scale_ratio = scale_ratio
        self.application_ratio = application_ratio
        logging.info(f"{self.__class__.__name__}(scale_ratio={scale_ratio})")

    def __call__(self, coords, feats, labels):
        if random.random() >= self.application_ratio:
            return coords, feats, labels
        factor = np.random.uniform(
            low=1 - self.scale_ratio, high=1 + self.scale_ratio
        )
        return coords * factor, feats, labels
158 |
159 |
@gin.configurable
class RandomCrop(object):
    """Randomly crop an axis-aligned box of size ``(x, y, z)``.

    The box origin is sampled uniformly so that the box stays inside the
    bounding box of the points. On axes where the cloud is smaller than the
    box, the box starts at the cloud minimum.

    Up to ``max_retries`` boxes are tried (at least one) until a box contains
    more than ``min_cardinality`` points. The input is returned unchanged if
    no box qualifies, or if the box already covers the cloud on every axis.
    """

    def __init__(self, x, y, z, application_ratio=1, min_cardinality=100, max_retries=10):
        assert x > 0
        assert y > 0
        assert z > 0
        self.application_ratio = application_ratio
        self.max_size = np.array([[x, y, z]])
        self.min_cardinality = min_cardinality
        self.max_retries = max_retries
        logging.info(f"{self.__class__.__name__} with size {self.max_size}")

    def __call__(self, coords: np.array, feats, labels):
        if random.random() > self.application_ratio:
            return coords, feats, labels

        norm_coords = coords - coords.min(0, keepdims=True)
        max_coords = norm_coords.max(0, keepdims=True)
        # Range of valid box origins per axis; 0 where the box is at least as
        # large as the cloud along that axis.
        coord_range = np.clip(max_coords - self.max_size, a_min=0, a_max=None)
        # The box covers the cloud on every axis: nothing to crop.
        if np.all(coord_range == 0):
            return coords, feats, labels

        for retries in range(1, max(self.max_retries, 1) + 1):
            min_box = np.random.rand(1, 3) * coord_range
            max_box = min_box + self.max_size
            # Inclusive lower bound: on zero-range axes min_box is exactly 0,
            # and a strict ">" would drop every point lying on the minimum.
            sel = np.all((norm_coords >= min_box) & (norm_coords < max_box), axis=1)
            if np.sum(sel) > self.min_cardinality:
                return (
                    coords[sel],
                    feats if feats is None else feats[sel],
                    labels if labels is None else labels[sel],
                )
            # Warn only about failed attempts (logging.warn is a deprecated alias).
            if retries % 2 == 0:
                logging.warning(f"RandomCrop retries: {retries}. crop_range={coord_range}")
        return coords, feats, labels
209 |
210 |
@gin.configurable
class RandomAffine(object):
    """Random rotation combined with random scaling and linear perturbation.

    The rotation is about a Gaussian-jittered upright axis, by an angle
    uniform in [-pi, pi). It is composed with a diagonal scale
    ``1 + U(-scale_range, scale_range)`` per axis plus an additive 3x3
    perturbation ``U(-affine_range, affine_range)``. The transform is applied
    with probability ``application_ratio``.
    """

    def __init__(
        self,
        upright_axis="z",
        axis_std=0.1,
        scale_range=0.2,
        affine_range=0.1,
        application_ratio=0.9,
    ):
        self.upright_axis = {"x": 0, "y": 1, "z": 2}[upright_axis.lower()]
        self.D = 3
        self.scale_range = scale_range
        self.affine_range = affine_range
        # Every axis except the upright one.
        self.horz_axes = set(range(self.D)) - set([self.upright_axis])
        self.application_ratio = application_ratio
        self.axis_std = axis_std
        logging.info(
            f"{self.__class__.__name__} upright_axis:{upright_axis}, axis_std:{axis_std}, scale_range:{scale_range}, affine_range:{affine_range} with application_ratio:{application_ratio}"
        )

    def __call__(self, coords, feats, labels):
        if random.random() >= self.application_ratio:
            return coords, feats, labels
        rot_axis = self.axis_std * np.random.randn(3)
        rot_axis[self.upright_axis] += 1
        theta = 2 * (random.random() - 0.5) * np.pi
        # Draw order matters for reproducibility: scales (3) before the 3x3 perturbation.
        scaling = np.diag(2 * (np.random.rand(3) - 0.5) * self.scale_range + 1)
        perturbation = 2 * (np.random.rand(3, 3) - 0.5) * self.affine_range
        transform = M(rot_axis, theta) @ (scaling + perturbation)
        return coords @ transform, feats, labels
244 |
245 |
@gin.configurable
class RandomHorizontalFlip(object):
    """Mirror the cloud along each horizontal axis with probability 0.5.

    A flip maps ``c`` to ``max(c) - c`` on that axis, in place.
    """

    def __init__(self, upright_axis="z", application_ratio=0.9):
        """
        upright_axis: name of the vertical axis, "x", "y" or "z"
            (case-insensitive); only the other two axes may be flipped.
        application_ratio: probability of considering flips on a call.
        """
        self.D = 3
        self.upright_axis = {"x": 0, "y": 1, "z": 2}[upright_axis.lower()]
        self.horz_axes = set(range(self.D)) - set([self.upright_axis])
        self.application_ratio = application_ratio
        logging.info(
            f"{self.__class__.__name__} upright_axis:{upright_axis} with application_ratio:{application_ratio}"
        )

    def __call__(self, coords, feats, labels):
        if random.random() >= self.application_ratio:
            return coords, feats, labels
        for axis in self.horz_axes:
            if random.random() >= 0.5:
                continue
            column = coords[:, axis]
            coords[:, axis] = column.max() - column
        return coords, feats, labels
268 |
269 |
@gin.configurable
class CoordinateDropout(object):
    """Randomly drop a fraction ``dropout_ratio`` of the points.

    Applied with probability ``application_ratio``. ``feats``/``labels`` may
    be ``None``; otherwise they are subsampled with the same indices.
    """

    def __init__(self, dropout_ratio=0.2, application_ratio=0.2):
        self.dropout_ratio = dropout_ratio
        self.application_ratio = application_ratio
        logging.info(
            f"{self.__class__.__name__} dropout:{dropout_ratio} with application_ratio:{application_ratio}"
        )

    def __call__(self, coords, feats, labels):
        if random.random() >= self.application_ratio:
            return coords, feats, labels
        num_points = len(coords)
        num_keep = int(num_points * (1 - self.dropout_ratio))
        keep = np.random.choice(num_points, num_keep, replace=False)
        kept_feats = None if feats is None else feats[keep]
        kept_labels = None if labels is None else labels[keep]
        return coords[keep], kept_feats, kept_labels
289 |
290 |
@gin.configurable
class CoordinateJitter(object):
    """Add independent uniform noise to each coordinate, in place.

    Despite its name, ``jitter_std`` is the half-width of the uniform noise
    range ``[-jitter_std, jitter_std)``. Applied with probability
    ``application_ratio``.
    """

    def __init__(self, jitter_std=0.5, application_ratio=0.7):
        self.jitter_std = jitter_std
        self.application_ratio = application_ratio
        logging.info(
            f"{self.__class__.__name__} jitter_std:{jitter_std} with application_ratio:{application_ratio}"
        )

    def __call__(self, coords, feats, labels):
        if random.random() >= self.application_ratio:
            return coords, feats, labels
        amplitude = 2 * self.jitter_std
        coords += amplitude * (np.random.rand(len(coords), 3) - 0.5)
        return coords, feats, labels
305 |
306 |
@gin.configurable
class CoordinateUniformTranslation:
    """Shift all points by one offset uniform in ``[-max_translation, max_translation)``.

    The shift is done in place; a non-positive ``max_translation`` disables it.
    """

    def __init__(self, max_translation=0.2):
        self.max_translation = max_translation

    def __call__(self, coords, feats, labels):
        if not self.max_translation > 0:
            return coords, feats, labels
        bound = self.max_translation
        coords += np.random.uniform(low=-bound, high=bound, size=[1, 3])
        return coords, feats, labels
318 |
319 |
@gin.configurable
class RegionDropout(object):
    """Remove all points inside one random axis-aligned box.

    Args:
        box_center_range: per-axis half-width of the uniform offset of the
            box center from the mean of the points.
        max_region_size, min_region_size: per-axis bounds of the box side
            lengths; each side is drawn uniformly from [min, max).
        application_ratio: probability of applying the transform on a call.
        max_retries: number of boxes to try (at least one). A box is accepted
            only if more than half of the points lie outside it; if none is
            accepted, the input is returned unchanged.
    """

    def __init__(
        self,
        box_center_range=[100, 100, 10],
        max_region_size=[300, 300, 300],
        min_region_size=[100, 100, 100],
        application_ratio=0.3,
        max_retries=10,
    ):
        # The list defaults are only copied into new arrays, never mutated.
        self.max_region_size = np.array(max_region_size)
        self.min_region_size = np.array(min_region_size)
        self.box_range = self.max_region_size - self.min_region_size
        self.box_center_range = np.array([box_center_range])
        self.application_ratio = application_ratio
        self.max_retries = max_retries
        logging.info(
            f"{self.__class__.__name__} max_region_size:{max_region_size} min_region_size:{min_region_size} box_center_range:{box_center_range} with application_ratio:{application_ratio}"
        )

    def __call__(self, coords, feats, labels):
        if random.random() >= self.application_ratio:
            return coords, feats, labels
        center = np.mean(coords, axis=0, keepdims=True)
        # Bounded retries: an unconditional `while True` never terminates
        # when no box can keep half of the points (e.g. empty input).
        for _ in range(max(self.max_retries, 1)):
            box_center = self.box_center_range * (np.random.rand(1, 3) - 0.5) * 2 + center
            # Offset by min_region_size so that sides span [min, max).
            box_size = self.min_region_size + self.box_range * np.random.rand(1, 3)
            min_xyz = box_center - box_size / 2
            max_xyz = box_center + box_size / 2
            inside = np.all((coords > min_xyz) & (coords < max_xyz), axis=1)
            sel = np.logical_not(inside)
            if sel.sum() > len(coords) * 0.5:
                return (
                    coords[sel],
                    feats if feats is None else feats[sel],
                    labels if labels is None else labels[sel],
                )
        return coords, feats, labels
355 |
356 |
@gin.configurable
class DimensionlessCoordinates(object):
    """Express coordinates in units of ``voxel_size`` (``coords / voxel_size``)."""

    def __init__(self, voxel_size=0.02):
        self.voxel_size = voxel_size
        logging.info(f"{self.__class__.__name__} with voxel_size:{voxel_size}")

    def __call__(self, coords, feats, labels):
        scaled_coords = coords / self.voxel_size
        return scaled_coords, feats, labels
365 |
366 |
@gin.configurable
class PerlinNoise:
    """Perturb coordinates with smoothed noise sampled on a sparse grid.

    For each ``(quantization_size, noise_std)`` pair in ``noise_params``, the
    transform does the following:

    - Place random 3-vectors on the grid cells (of size ``quantization_size``)
      around the points.
    - Smooth them with a 3x3x3 per-channel mean filter.
    - Interpolate them back at the points with ``ME.MinkowskiInterpolation``.
    - Add them to the points, scaled by ``noise_std``.

    All pairs are applied together with probability ``application_ratio``.

    Args:
        noise_params: sequence of ``(quantization_size, noise_std)`` pairs, or
            ``None`` to disable the transform.
        application_ratio: probability of applying the transform on a call.
        device: torch device for the computation. An explicit value is used
            as given; ``None`` selects ``"cuda"`` if available, else ``"cpu"``.
    """

    def __init__(
        self, noise_params=((4, 4), (16, 16)), application_ratio=0.9, device="cpu"
    ):
        self.application_ratio = application_ratio
        self.noise_params = noise_params
        logging.info(
            f"{self.__class__.__name__} noise_params:{noise_params} with application_ratio:{application_ratio}"
        )
        self.interp = ME.MinkowskiInterpolation()
        # Offsets of the 8 corners of a unit cell.
        self.corners = torch.Tensor(
            [
                [0, 0, 0],
                [0, 0, 1],
                [0, 1, 0],
                [1, 0, 0],
                [0, 1, 1],
                [1, 1, 0],
                [1, 0, 1],
                [1, 1, 1],
            ]
        )
        self.smooth = ME.MinkowskiConvolution(
            in_channels=3,
            out_channels=3,
            kernel_size=3,
            bias=False,
            dimension=3,
        )
        # Per-channel 3x3x3 mean filter: each output channel averages only its
        # own input channel. Filling the whole (volume, 3, 3) kernel with 1/27
        # would sum all channels and make the x/y/z noise identical. The
        # kernel is a leaf Parameter requiring grad, so an in-place write must
        # happen under no_grad.
        with torch.no_grad():
            self.smooth.kernel.copy_(torch.eye(3) / 27)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.smooth = self.smooth.to(self.device)
        self.corners = self.corners.to(self.device)

    def perlin_noise(self, coordinates, noise_quantization_size, noise_std):
        """Return ``coordinates`` plus one level of smoothed, interpolated noise."""
        # Each point plus the 8 corner offsets of a cell: shape (N, 8, 3).
        aug_coordinates = coordinates.reshape(-1, 1, 3) + (
            self.corners * noise_quantization_size
        ).reshape(1, 8, 3)
        bcoords = ME.utils.batched_coordinates(
            [aug_coordinates.reshape(-1, 3) / noise_quantization_size],
            dtype=torch.float32,
        )
        noise_tensor = ME.SparseTensor(
            features=torch.randn((len(bcoords), 3), device=self.device),
            coordinates=bcoords,
            device=self.device,
        )
        noise_tensor = self.smooth(noise_tensor)
        interp_noise = self.interp(
            noise_tensor,
            ME.utils.batched_coordinates(
                [coordinates / noise_quantization_size],
                dtype=torch.float32,
                device=self.device,
            ),
        )
        return coordinates + noise_std * interp_noise

    def __call__(self, coords, feats, labels):
        if self.noise_params is not None:
            if random.random() < self.application_ratio:
                coords = torch.from_numpy(coords).to(self.device)
                with torch.no_grad():
                    for quantization_size, noise_std in self.noise_params:
                        coords = self.perlin_noise(coords, quantization_size, noise_std)
                coords = coords.cpu().numpy()
        return coords, feats, labels
438 |
439 |
440 | ##############################
441 | # Feature transformations
442 | ##############################
@gin.configurable
class ChromaticTranslation(object):
    """Shift the RGB channels (``feats[:, :3]``, values in [0, 255]) by one
    random per-channel offset shared by all points, then clip to [0, 255]."""

    def __init__(self, translation_range_ratio=1e-1, application_ratio=0.9):
        """
        translation_range_ratio: each channel offset is
            255 * 2 * ratio * U(-0.5, 0.5).
        application_ratio: probability of applying the shift on a call.
        """
        self.trans_range_ratio = translation_range_ratio
        self.application_ratio = application_ratio
        logging.info(
            f"{self.__class__.__name__} with translation_range_ratio:{translation_range_ratio} with application_ratio:{application_ratio}"
        )

    def __call__(self, coords, feats, labels):
        if random.random() >= self.application_ratio:
            return coords, feats, labels
        offset = (np.random.rand(1, 3) - 0.5) * 255 * 2 * self.trans_range_ratio
        shifted = offset + feats[:, :3]
        feats[:, :3] = np.clip(shifted, 0, 255)
        return coords, feats, labels
462 |
463 |
@gin.configurable
class ChromaticJitter(object):
    """Add independent Gaussian noise (std ``std * 255``) to each RGB value.

    Results are clipped to [0, 255] and written back into ``feats[:, :3]``.
    Applied with probability ``application_ratio``.
    """

    def __init__(self, std=0.01, application_ratio=0.9):
        self.std = std
        self.application_ratio = application_ratio
        logging.info(
            f"{self.__class__.__name__} with std:{std} with application_ratio:{application_ratio}"
        )

    def __call__(self, coords, feats, labels):
        if random.random() >= self.application_ratio:
            return coords, feats, labels
        sigma = self.std * 255
        jitter = np.random.randn(feats.shape[0], 3)
        jitter *= sigma
        feats[:, :3] = np.clip(jitter + feats[:, :3], 0, 255)
        return coords, feats, labels
479 |
480 |
@gin.configurable
class ChromaticAutoContrast(object):
    """Blend the RGB features with a per-channel contrast-stretched copy.

    The stretched copy maps each channel's [min, max] onto [0, 255]. The blend
    weight is random in [0, 1) when ``randomize_blend_factor`` is true;
    otherwise ``blend_factor`` is used. The transform is applied with
    probability ``application_ratio``, and ``feats[:, :3]`` is overwritten.
    Extra feature columns are left untouched.
    """

    def __init__(
        self, randomize_blend_factor=True, blend_factor=0.5, application_ratio=0.2
    ):
        self.randomize_blend_factor = randomize_blend_factor
        self.blend_factor = blend_factor
        self.application_ratio = application_ratio
        logging.info(
            f"{self.__class__.__name__} with randomize_blend_factor:{randomize_blend_factor}, blend_factor:{blend_factor} with application_ratio:{application_ratio}"
        )

    def __call__(self, coords, feats, labels):
        if random.random() < self.application_ratio:
            rgb = feats[:, :3]
            lo = rgb.min(0, keepdims=True)
            hi = rgb.max(0, keepdims=True)
            assert hi.max() > 1, "invalid color value. Color is supposed to be [0-255]"

            # Skip when any channel is constant: its stretch would divide by zero.
            if np.prod(hi - lo):
                scale = 255 / (hi - lo)
                contrast_feats = (rgb - lo) * scale
                blend_factor = (
                    random.random() if self.randomize_blend_factor else self.blend_factor
                )
                # Blend the RGB columns only; blending the full `feats` fails
                # with a shape mismatch whenever feats has more than 3 columns.
                feats[:, :3] = (1 - blend_factor) * rgb + blend_factor * contrast_feats
        return coords, feats, labels
511 |
512 |
@gin.configurable
class NormalizeColor(object):
    """Standardize features as ``(feats - mean) / std``.

    If ``pre_norm`` is set, features are first divided by 255.
    """

    def __init__(self, mean=[128, 128, 128], std=[256, 256, 256], pre_norm=False):
        self.mean = np.array([mean], dtype=np.float32)
        self.std = np.array([std], dtype=np.float32)
        self.pre_norm = pre_norm
        logging.info(f"{self.__class__.__name__} mean:{mean} std:{std}")

    def __call__(self, coords, feats, labels):
        scaled = feats / 255. if self.pre_norm else feats
        normalized = (scaled - self.mean) / self.std
        return coords, normalized, labels
525 |
526 |
@gin.configurable
class HueSaturationTranslation(object):
    """Randomly shift hue and scale saturation of the RGB features.

    Each call does two things:

    - Shifts hue (in turns, [0, 1)) by a value uniform in
      ``[-hue_max, hue_max)``.
    - Scales saturation by a factor uniform in
      ``[1 - saturation_max, 1 + saturation_max)``, clipped to [0, 1].

    ``feats[:, :3]`` is assumed to hold RGB values in [0, 255] and is
    overwritten with the result.
    """

    @staticmethod
    def rgb_to_hsv(rgb):
        # Translated from source of colorsys.rgb_to_hsv
        # r,g,b should be a numpy arrays with values between 0 and 255
        # rgb_to_hsv returns an array of floats between 0.0 and 1.0
        # for h and s; v keeps the input scale.
        rgb = rgb.astype("float")
        hsv = np.zeros_like(rgb)
        # in case an RGBA array was passed, just copy the A channel
        hsv[..., 3:] = rgb[..., 3:]
        r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
        maxc = np.max(rgb[..., :3], axis=-1)
        minc = np.min(rgb[..., :3], axis=-1)
        hsv[..., 2] = maxc
        # Gray pixels (maxc == minc) keep h = s = 0.
        mask = maxc != minc
        hsv[mask, 1] = (maxc - minc)[mask] / maxc[mask]
        rc = np.zeros_like(r)
        gc = np.zeros_like(g)
        bc = np.zeros_like(b)
        rc[mask] = (maxc - r)[mask] / (maxc - minc)[mask]
        gc[mask] = (maxc - g)[mask] / (maxc - minc)[mask]
        bc[mask] = (maxc - b)[mask] / (maxc - minc)[mask]
        hsv[..., 0] = np.select(
            [r == maxc, g == maxc], [bc - gc, 2.0 + rc - bc], default=4.0 + gc - rc
        )
        hsv[..., 0] = (hsv[..., 0] / 6.0) % 1.0
        return hsv

    @staticmethod
    def hsv_to_rgb(hsv):
        # Translated from source of colorsys.hsv_to_rgb
        # h,s should be a numpy arrays with values between 0.0 and 1.0
        # v should be a numpy array with values between 0.0 and 255.0
        # hsv_to_rgb returns an array of uints between 0 and 255.
        rgb = np.empty_like(hsv)
        rgb[..., 3:] = hsv[..., 3:]
        h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
        i = (h * 6.0).astype("uint8")
        f = (h * 6.0) - i
        p = v * (1.0 - s)
        q = v * (1.0 - s * f)
        t = v * (1.0 - s * (1.0 - f))
        i = i % 6
        conditions = [s == 0.0, i == 1, i == 2, i == 3, i == 4, i == 5]
        rgb[..., 0] = np.select(conditions, [v, q, p, p, t, v], default=v)
        rgb[..., 1] = np.select(conditions, [v, v, v, q, p, p], default=t)
        rgb[..., 2] = np.select(conditions, [v, p, t, v, v, q], default=p)
        # Round rather than truncate: float error in p/q/t (e.g. 29.9999999)
        # would otherwise lose an intensity level and break RGB->HSV->RGB
        # round trips.
        return np.clip(np.rint(rgb), 0, 255).astype("uint8")

    def __init__(self, hue_max, saturation_max):
        self.hue_max = hue_max
        self.saturation_max = saturation_max
        logging.info(
            f"{self.__class__.__name__} hue_max:{hue_max} saturation_max:{saturation_max}"
        )

    def __call__(self, coords, feats, labels):
        # Assume feat[:, :3] is rgb
        hsv = HueSaturationTranslation.rgb_to_hsv(feats[:, :3])
        hue_val = (random.random() - 0.5) * 2 * self.hue_max
        sat_ratio = 1 + (random.random() - 0.5) * 2 * self.saturation_max
        hsv[..., 0] = np.remainder(hue_val + hsv[..., 0] + 1, 1)
        hsv[..., 1] = np.clip(sat_ratio * hsv[..., 1], 0, 1)
        # hsv_to_rgb already returns uint8 values in [0, 255].
        feats[:, :3] = HueSaturationTranslation.hsv_to_rgb(hsv)
        return coords, feats, labels
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from src.models.spvcnn import SPVCNN
4 | import src.models.resnet as resnets
5 | import src.models.resunet as resunets
6 | import src.models.fast_point_transformer as transformers
7 |
# Registry of model classes that get_model() searches by class name;
# add_models() extends it with classes from the model modules.
MODELS = [SPVCNN]
9 |
10 |
def add_models(module):
    """Append to MODELS every attribute of ``module`` whose name contains
    "Net" or "Transformer", in ``dir()`` (alphabetical) order."""
    matching_names = [
        attr_name
        for attr_name in dir(module)
        if "Net" in attr_name or "Transformer" in attr_name
    ]
    MODELS.extend([getattr(module, attr_name) for attr_name in matching_names])
13 |
14 |
15 | add_models(resnets)
16 | add_models(resunets)
17 | add_models(transformers)
18 |
19 |
def get_model(name):
    """Look up a registered model class by its class name.

    Args:
        name: the ``__name__`` of a class registered in ``MODELS``.

    Returns:
        The model class itself (not an instance), or ``None`` when ``name``
        is not registered. In that case the valid names are logged.
    """
    mdict = {model.__name__: model for model in MODELS}
    model_class = mdict.get(name)
    if model_class is None:
        # An unknown name is a configuration error; log it at ERROR so an
        # INFO-level threshold does not hide it.
        logging.error(f"Invalid model name. You put {name}. Options are:")
        for option in mdict:
            logging.info(f"\t* {option}")
    return model_class
--------------------------------------------------------------------------------
/src/models/common.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import MinkowskiEngine as ME
3 |
4 |
@torch.no_grad()
def downsample_points(points, tensor_map, field_map, size):
    """Average point positions per voxel.

    Args:
        points: per-point positions.
        tensor_map, field_map: paired voxel / point indices, as produced by
            ``field_to_sparse_map`` at the call sites.
        size: ``(num_voxels, num_points)``.

    Returns:
        ``(voxel_points, voxel_counts)``. ``voxel_points`` are the averaged
        positions. ``voxel_counts`` is a column of per-voxel point counts with
        the same dtype as ``voxel_points``.
    """
    voxel_points = ME.MinkowskiSPMMAverageFunction().apply(
        tensor_map, field_map, size, points
    )
    # torch.unique sorts, so counts follow ascending voxel index; this assumes
    # every voxel index occurs in tensor_map (TODO confirm).
    _, voxel_counts = torch.unique(tensor_map, return_counts=True)
    return voxel_points, voxel_counts.unsqueeze(1).type_as(voxel_points)
12 |
13 |
@torch.no_grad()
def stride_centroids(points, counts, rows, cols, size):
    """Merge fine-stride centroids into coarse-stride ones, weighted by counts.

    ``size`` is ``(num_coarse, num_fine)``; ``rows``/``cols`` index coarse and
    fine entries respectively. This assumes ``MinkowskiSPMMFunction`` sums
    ``vals * mat[col]`` per row (TODO confirm against ME docs).

    Returns:
        ``(centroids, counts)`` for the coarse stride. The centroids are
        count-weighted means; the summed counts are clamped to at least 1.
    """
    weighted_sum = ME.MinkowskiSPMMFunction().apply(rows, cols, counts, size, points)
    unit_weights = torch.ones(size[1], dtype=points.dtype, device=points.device)
    merged_counts = ME.MinkowskiSPMMFunction().apply(rows, cols, unit_weights, size, counts)
    # Avoid dividing by zero for coarse entries that received no points.
    merged_counts.clamp_(min=1)
    return torch.true_divide(weighted_sum, merged_counts), merged_counts
21 |
22 |
def downsample_embeddings(embeddings, inverse_map, size, mode="avg"):
    """Pool per-point embeddings into their voxels.

    Args:
        embeddings: one row per point; ``len(embeddings)`` must equal ``size[1]``.
        inverse_map: voxel index of each point.
        size: ``(num_voxels, num_points)``.
        mode: ``"avg"`` (mean) or ``"max"`` pooling.

    Returns:
        The pooled per-voxel embeddings.
    """
    assert len(embeddings) == size[1]
    assert mode in ["avg", "max"]
    point_ids = torch.arange(size[1], dtype=inverse_map.dtype, device=inverse_map.device)
    if mode == "avg":
        return ME.MinkowskiSPMMAverageFunction().apply(
            inverse_map, point_ids, size, embeddings
        )
    return ME.MinkowskiDirectMaxPoolingFunction().apply(
        point_ids, inverse_map, embeddings, size[0]
    )
--------------------------------------------------------------------------------
/src/models/fast_point_transformer.py:
--------------------------------------------------------------------------------
1 | import gin
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import MinkowskiEngine as ME
6 |
7 | from src.models.transformer_base import LocalSelfAttentionBase, ResidualBlockWithPointsBase
8 | from src.models.common import stride_centroids, downsample_points, downsample_embeddings
9 | import src.cuda_ops.functions.sparse_ops as ops
10 |
11 |
class MaxPoolWithPoints(nn.Module):
    """2x2x2 max pooling that also pools voxel centroids and point counts.

    Only ``kernel_size == 2`` and ``stride == 2`` are supported.
    """

    def __init__(self, kernel_size=2, stride=2):
        assert kernel_size == 2 and stride == 2
        super().__init__()
        self.pool = ME.MinkowskiMaxPooling(kernel_size=kernel_size, stride=stride, dimension=3)

    def forward(self, stensor, points, counts):
        """Pool ``stensor``; merge ``points``/``counts`` (one row per voxel) accordingly.

        Returns:
            ``(pooled_tensor, pooled_points, pooled_counts)``.
        """
        assert isinstance(stensor, ME.SparseTensor)
        assert len(stensor) == len(points)
        pooled = self.pool(stensor)
        cols, rows = stensor.coordinate_manager.stride_map(
            stensor.coordinate_map_key, pooled.coordinate_map_key
        )
        shape = torch.Size([len(pooled), len(stensor)])
        pooled_points, pooled_counts = stride_centroids(points, counts, rows, cols, shape)
        return pooled, pooled_points, pooled_counts
27 |
28 |
29 | ####################################
30 | # Layers
31 | ####################################
@gin.configurable
class LightweightSelfAttentionLayer(LocalSelfAttentionBase):
    """Local multi-head self-attention on a sparse voxel tensor.

    The layer works in four steps:

    1. Offset the features by an encoding of ``norm_points`` (``intra_pos_mlp``).
    2. Compute attention weights with the custom op ``ops.dot_product_cuda``
       from L2-normalized queries and the L2-normalized learned
       per-kernel-offset encodings ``inter_pos_enc`` (a cosine similarity,
       per the comment in ``forward``).
    3. Aggregate values with those weights via ``ops.scalar_attention_cuda``.
    4. Project the result with ``to_out``.

    Neighborhoods come from the kernel map helpers of ``LocalSelfAttentionBase``
    (``src.models.transformer_base``, not visible here). Only stride 1,
    dilation 1 and odd kernel sizes are supported.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        stride=1,
        dilation=1,
        num_heads=8,
    ):
        """
        Args:
            in_channels: number of input feature channels.
            out_channels: number of output channels; defaults to
                ``in_channels`` and must be divisible by ``num_heads``.
            kernel_size: neighborhood size; must be odd.
            stride: must be 1.
            dilation: must be 1.
            num_heads: number of attention heads; each head uses
                ``out_channels // num_heads`` channels.
        """
        out_channels = in_channels if out_channels is None else out_channels
        assert out_channels % num_heads == 0
        assert kernel_size % 2 == 1
        assert stride == 1, "Currently, this layer only supports stride == 1"
        assert dilation == 1,"Currently, this layer only supports dilation == 1"
        super(LightweightSelfAttentionLayer, self).__init__(kernel_size, stride, dilation, dimension=3)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.num_heads = num_heads
        self.attn_channels = out_channels // num_heads

        # Query / value projections. MinkowskiToFeature presumably unwraps the
        # SparseTensor into its feature matrix so forward() can call .view().
        self.to_query = nn.Sequential(
            ME.MinkowskiLinear(in_channels, out_channels),
            ME.MinkowskiToFeature()
        )
        self.to_value = nn.Sequential(
            ME.MinkowskiLinear(in_channels, out_channels),
            ME.MinkowskiToFeature()
        )
        self.to_out = nn.Linear(out_channels, out_channels)

        # Learned encoding per kernel offset and head, shape
        # (kernel_volume, num_heads, attn_channels); kernel_volume is provided
        # by the base class. Allocated uninitialized, filled from N(0, 1) below.
        self.inter_pos_enc = nn.Parameter(torch.FloatTensor(self.kernel_volume, self.num_heads, self.attn_channels))
        # MLP mapping a 3-D position offset to an in_channels feature offset.
        self.intra_pos_mlp = nn.Sequential(
            nn.Linear(3, 3, bias=False),
            nn.BatchNorm1d(3),
            nn.ReLU(inplace=True),
            nn.Linear(3, in_channels, bias=False),
            nn.BatchNorm1d(in_channels),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels, in_channels)
        )
        nn.init.normal_(self.inter_pos_enc, 0, 1)

    def forward(self, stensor, norm_points):
        """Apply local self-attention.

        Args:
            stensor: ``ME.SparseTensor`` with ``in_channels`` features.
            norm_points: 3-D position offsets, one row per voxel. Assumed to be
                row-aligned with ``stensor`` (TODO confirm).

        Returns:
            ``ME.SparseTensor`` with ``out_channels`` features on the output
            coordinate key, sharing ``stensor``'s coordinate manager.
        """
        dtype = stensor._F.dtype
        device = stensor._F.device

        # query, key, value, and relative positional encoding
        intra_pos_enc = self.intra_pos_mlp(norm_points)
        stensor = stensor + intra_pos_enc
        # q, v: (num_voxels, num_heads, attn_channels)
        q = self.to_query(stensor).view(-1, self.num_heads, self.attn_channels).contiguous()
        v = self.to_value(stensor).view(-1, self.num_heads, self.attn_channels).contiguous()

        # key-query map
        kernel_map, out_key = self.get_kernel_map_and_out_key(stensor)
        kq_map = self.key_query_map_from_kernel_map(kernel_map)

        # attention weights with cosine similarity
        # Buffer of shape (kq_map.shape[1], num_heads), presumably one weight
        # per key-query pair and head, filled by the CUDA op.
        attn = torch.zeros((kq_map.shape[1], self.num_heads), dtype=dtype, device=device)
        norm_q = F.normalize(q, p=2, dim=-1)
        norm_pos_enc = F.normalize(self.inter_pos_enc, p=2, dim=-1)
        attn = ops.dot_product_cuda(norm_q, norm_pos_enc, attn, kq_map)

        # aggregation & the output
        out_F = torch.zeros((len(q), self.num_heads, self.attn_channels),
                            dtype=dtype,
                            device=device)
        kq_indices = self.key_query_indices_from_key_query_map(kq_map)
        out_F = ops.scalar_attention_cuda(attn, v, out_F, kq_indices)
        # Merge heads back to (num_voxels, out_channels), then project.
        out_F = self.to_out(out_F.view(-1, self.out_channels).contiguous())
        return ME.SparseTensor(out_F,
                               coordinate_map_key=out_key,
                               coordinate_manager=stensor.coordinate_manager)
110 |
111 |
112 | ####################################
113 | # Blocks
114 | ####################################
@gin.configurable
class LightweightSelfAttentionBlock(ResidualBlockWithPointsBase):
    """Residual block whose layer type is LightweightSelfAttentionLayer.

    The residual structure lives in ResidualBlockWithPointsBase (not visible
    here), which presumably builds its layers from ``LAYER`` (TODO confirm).
    """
    LAYER = LightweightSelfAttentionLayer
118 |
119 |
120 | ####################################
121 | # Models
122 | ####################################
@gin.configurable
class FastPointTransformer(nn.Module):
    """U-Net-style encoder-decoder of local self-attention layers over voxels.

    Voxelization: the input ``ME.TensorField`` is voxelized. Each voxel's
    features are concatenated with the average embedding (``enc_mlp``) of its
    points' offsets from the voxel centroid.

    Encoder: pools by 2 four times (stride 1 -> 16; the suffix ``pK`` in
    attribute names appears to denote tensor stride K).

    Decoder: upsamples with ``MinkowskiPoolingTranspose`` and concatenates
    the encoder output of the same stride at each level.

    Output: voxel features are sliced back to the input points, concatenated
    with each point's embedding and mapped by ``final`` to ``out_channels``
    values per point.

    Class attributes (overridable by subclasses):
        INIT_DIM: output channels of the stem layer ``attn0p1``.
        ENC_DIM: width of the centroid-offset embedding.
        LAYERS: number of residual blocks in block1..block8.
        PLANES: output channels of attn1p1..attn8p1.
        QMODE: quantization mode. Not referenced inside this class; presumably
            read by the caller when building the input field (TODO confirm).
        LAYER, BLOCK: attention layer and residual block classes.
    """
    INIT_DIM = 32
    ENC_DIM = 32
    LAYERS = (2, 3, 4, 6, 2, 2, 2, 2)
    PLANES = (64, 128, 384, 640, 384, 384, 256, 128)
    QMODE = ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE
    LAYER = LightweightSelfAttentionLayer
    BLOCK = LightweightSelfAttentionBlock

    def __init__(self, in_channels, out_channels):
        """
        Args:
            in_channels: number of input feature channels per point.
            out_channels: number of output values per point.
        """
        super(FastPointTransformer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # NOTE: module construction order fixes the order in which parameters
        # are randomly initialized; keep it stable for reproducibility.
        # Embedding of a point's 3-D offset from its voxel centroid.
        self.enc_mlp = nn.Sequential(
            nn.Linear(3, self.ENC_DIM, bias=False),
            nn.BatchNorm1d(self.ENC_DIM),
            nn.Tanh(),
            nn.Linear(self.ENC_DIM, self.ENC_DIM, bias=False),
            nn.BatchNorm1d(self.ENC_DIM),
            nn.Tanh()
        )
        # Stem (stride 1): input features + averaged embedding -> INIT_DIM.
        self.attn0p1 = self.LAYER(in_channels + self.ENC_DIM, self.INIT_DIM, kernel_size=5)
        self.bn0 = ME.MinkowskiBatchNorm(self.INIT_DIM)

        # Encoder.
        self.attn1p1 = self.LAYER(self.INIT_DIM, self.PLANES[0])
        self.bn1 = ME.MinkowskiBatchNorm(self.PLANES[0])
        self.block1 = nn.ModuleList([self.BLOCK(self.PLANES[0]) for _ in range(self.LAYERS[0])])

        self.attn2p2 = self.LAYER(self.PLANES[0], self.PLANES[1])
        self.bn2 = ME.MinkowskiBatchNorm(self.PLANES[1])
        self.block2 = nn.ModuleList([self.BLOCK(self.PLANES[1]) for _ in range(self.LAYERS[1])])

        self.attn3p4 = self.LAYER(self.PLANES[1], self.PLANES[2])
        self.bn3 = ME.MinkowskiBatchNorm(self.PLANES[2])
        self.block3 = nn.ModuleList([self.BLOCK(self.PLANES[2]) for _ in range(self.LAYERS[2])])

        self.attn4p8 = self.LAYER(self.PLANES[2], self.PLANES[3])
        self.bn4 = ME.MinkowskiBatchNorm(self.PLANES[3])
        self.block4 = nn.ModuleList([self.BLOCK(self.PLANES[3]) for _ in range(self.LAYERS[3])])

        # Decoder: input channels = upsampled features + same-stride skip.
        self.attn5p8 = self.LAYER(self.PLANES[3] + self.PLANES[3], self.PLANES[4])
        self.bn5 = ME.MinkowskiBatchNorm(self.PLANES[4])
        self.block5 = nn.ModuleList([self.BLOCK(self.PLANES[4]) for _ in range(self.LAYERS[4])])

        self.attn6p4 = self.LAYER(self.PLANES[4] + self.PLANES[2], self.PLANES[5])
        self.bn6 = ME.MinkowskiBatchNorm(self.PLANES[5])
        self.block6 = nn.ModuleList([self.BLOCK(self.PLANES[5]) for _ in range(self.LAYERS[5])])

        self.attn7p2 = self.LAYER(self.PLANES[5] + self.PLANES[1], self.PLANES[6])
        self.bn7 = ME.MinkowskiBatchNorm(self.PLANES[6])
        self.block7 = nn.ModuleList([self.BLOCK(self.PLANES[6]) for _ in range(self.LAYERS[6])])

        self.attn8p1 = self.LAYER(self.PLANES[6] + self.PLANES[0], self.PLANES[7])
        self.bn8 = ME.MinkowskiBatchNorm(self.PLANES[7])
        self.block8 = nn.ModuleList([self.BLOCK(self.PLANES[7]) for _ in range(self.LAYERS[7])])

        # Per-point head: voxel features + point embedding -> out_channels.
        self.final = nn.Sequential(
            nn.Linear(self.PLANES[7] + self.ENC_DIM, self.PLANES[7], bias=False),
            nn.BatchNorm1d(self.PLANES[7]),
            nn.ReLU(inplace=True),
            nn.Linear(self.PLANES[7], out_channels)
        )
        self.relu = ME.MinkowskiReLU(inplace=True)
        self.pool = MaxPoolWithPoints()
        self.pooltr = ME.MinkowskiPoolingTranspose(kernel_size=2, stride=2, dimension=3)

    @torch.no_grad()
    def normalize_points(self, points, centroids, tensor_map):
        """Offset of each point from the centroid row selected by ``tensor_map``.

        Assumes ``tensor_map[i]`` is the voxel of point ``i``, i.e. that the
        paired field_map is in point order (TODO confirm).
        """
        tensor_map = tensor_map if tensor_map.dtype == torch.int64 else tensor_map.long()
        norm_points = points - centroids[tensor_map]
        return norm_points

    @torch.no_grad()
    def normalize_centroids(self, down_points, coordinates, tensor_stride):
        """Centroid position relative to its voxel, in units of the stride, minus 0.5.

        Column 0 of ``coordinates`` is dropped (presumably the batch index).
        """
        norm_points = (down_points - coordinates[:, 1:]) / tensor_stride - 0.5
        return norm_points

    def voxelize_with_centroids(self, x: ME.TensorField):
        """Voxelize ``x`` and attach centroid-offset embeddings.

        Returns:
            ``(out, norm_points_p1, points_p1, count_p1, pos_embs)``:
            - ``out``: voxel tensor whose features are the input features
              concatenated with the per-voxel mean embedding.
            - ``norm_points_p1``: normalized voxel centroids.
            - ``points_p1``: voxel centroids.
            - ``count_p1``: per-voxel point counts.
            - ``pos_embs``: per-point embeddings.
        """
        cm = x.coordinate_manager
        points = x.C[:, 1:]

        out = x.sparse()
        size = torch.Size([len(out), len(x)])
        tensor_map, field_map = cm.field_to_sparse_map(x.coordinate_key, out.coordinate_key)
        points_p1, count_p1 = downsample_points(points, tensor_map, field_map, size)
        norm_points = self.normalize_points(points, points_p1, tensor_map)

        pos_embs = self.enc_mlp(norm_points)
        down_pos_embs = downsample_embeddings(pos_embs, tensor_map, size, mode="avg")
        out = ME.SparseTensor(torch.cat([out.F, down_pos_embs], dim=1),
                              coordinate_map_key=out.coordinate_key,
                              coordinate_manager=cm)

        norm_points_p1 = self.normalize_centroids(points_p1, out.C, out.tensor_stride[0])
        return out, norm_points_p1, points_p1, count_p1, pos_embs

    def devoxelize_with_centroids(self, out: ME.SparseTensor, x: ME.TensorField, h_embs):
        """Slice voxel features back to the points of ``x`` and apply ``final``."""
        out = self.final(torch.cat([out.slice(x).F, h_embs], dim=1))
        return out

    def forward(self, x: ME.TensorField):
        """Return per-point outputs (``out_channels`` columns) for field ``x``."""
        # Stem at stride 1.
        out, norm_points_p1, points_p1, count_p1, pos_embs = self.voxelize_with_centroids(x)
        out = self.relu(self.bn0(self.attn0p1(out, norm_points_p1)))
        out_p1 = self.relu(self.bn1(self.attn1p1(out, norm_points_p1)))

        # Encoder: pool by 2, run blocks, then a channel-changing attention
        # layer. out_pK is kept as the skip connection for the decoder.
        out, points_p2, count_p2 = self.pool(out_p1, points_p1, count_p1)
        norm_points_p2 = self.normalize_centroids(points_p2, out.C, out.tensor_stride[0])
        for module in self.block1:
            out = module(out, norm_points_p2)
        out_p2 = self.relu(self.bn2(self.attn2p2(out, norm_points_p2)))

        out, points_p4, count_p4 = self.pool(out_p2, points_p2, count_p2)
        norm_points_p4 = self.normalize_centroids(points_p4, out.C, out.tensor_stride[0])
        for module in self.block2:
            out = module(out, norm_points_p4)
        out_p4 = self.relu(self.bn3(self.attn3p4(out, norm_points_p4)))

        out, points_p8, count_p8 = self.pool(out_p4, points_p4, count_p4)
        norm_points_p8 = self.normalize_centroids(points_p8, out.C, out.tensor_stride[0])
        for module in self.block3:
            out = module(out, norm_points_p8)
        out_p8 = self.relu(self.bn4(self.attn4p8(out, norm_points_p8)))

        # Bottleneck at stride 16; counts are not needed further.
        out, points_p16 = self.pool(out_p8, points_p8, count_p8)[:2]
        norm_points_p16 = self.normalize_centroids(points_p16, out.C, out.tensor_stride[0])
        for module in self.block4:
            out = module(out, norm_points_p16)

        # Decoder: upsample by 2, concatenate the same-stride skip, fuse, refine.
        out = self.pooltr(out)
        out = ME.cat(out, out_p8)
        out = self.relu(self.bn5(self.attn5p8(out, norm_points_p8)))
        for module in self.block5:
            out = module(out, norm_points_p8)

        out = self.pooltr(out)
        out = ME.cat(out, out_p4)
        out = self.relu(self.bn6(self.attn6p4(out, norm_points_p4)))
        for module in self.block6:
            out = module(out, norm_points_p4)

        out = self.pooltr(out)
        out = ME.cat(out, out_p2)
        out = self.relu(self.bn7(self.attn7p2(out, norm_points_p2)))
        for module in self.block7:
            out = module(out, norm_points_p2)

        out = self.pooltr(out)
        out = ME.cat(out, out_p1)
        out = self.relu(self.bn8(self.attn8p1(out, norm_points_p1)))
        for module in self.block8:
            out = module(out, norm_points_p1)

        # Back to the input points.
        out = self.devoxelize_with_centroids(out, x, pos_embs)
        return out
279 |
280 |
@gin.configurable
class FastPointTransformerSmall(FastPointTransformer):
    """FastPointTransformer with every entry of ``LAYERS`` set to 2.

    ``LAYERS`` presumably sets the depth of ``block1`` ... ``block8``.
    """

    LAYERS = (2,) * 8
284 |
285 |
@gin.configurable
class FastPointTransformerSmaller(FastPointTransformer):
    """FastPointTransformer with every entry of ``LAYERS`` set to 1.

    ``LAYERS`` presumably sets the depth of ``block1`` ... ``block8``.
    """

    LAYERS = (1,) * 8
289 |
--------------------------------------------------------------------------------
/src/models/resnet.py:
--------------------------------------------------------------------------------
1 | import gin
2 | import torch.nn as nn
3 | import MinkowskiEngine as ME
4 | from MinkowskiEngine.modules.resnet_block import BasicBlock
5 |
6 |
@gin.configurable
class ResNetBase(nn.Module):
    """Sparse-convolutional ResNet classifier built on MinkowskiEngine.

    Subclasses configure the architecture via class attributes:

    * ``BLOCK``    -- residual block class exposing ``expansion`` (required).
    * ``LAYERS``   -- number of blocks per stage.
    * ``PLANES``   -- base channel width per stage.
    * ``INIT_DIM`` -- width of the stem convolution.
    * ``LAYER``    -- convolution class used by the stem and shortcut projections.
    * ``QMODE``    -- quantization mode exposed to callers building input fields.
    """

    QMODE = ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE
    LAYER = ME.MinkowskiConvolution
    BLOCK = None
    LAYERS = ()
    INIT_DIM = 64
    PLANES = (64, 128, 256, 512)

    def __init__(self, in_channels, out_channels=None, D=3):
        nn.Module.__init__(self)
        assert self.BLOCK is not None, f"{type(self).__name__} must set BLOCK"
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.D = D

        self.network_initialization(in_channels, out_channels, D)
        self.weight_initialization()

    def network_initialization(self, in_channels, out_channels, D):
        """Build stem, four residual stages, head and linear classifier."""
        self.inplanes = self.INIT_DIM
        self.conv1 = nn.Sequential(
            self.LAYER(in_channels, self.inplanes, kernel_size=3, stride=2, dimension=D),
            ME.MinkowskiBatchNorm(self.inplanes),
            ME.MinkowskiReLU(inplace=True),
            ME.MinkowskiSumPooling(kernel_size=2, stride=2, dimension=D),
        )

        self.layer1 = self._make_layer(self.BLOCK, self.PLANES[0], self.LAYERS[0], stride=2)
        self.layer2 = self._make_layer(self.BLOCK, self.PLANES[1], self.LAYERS[1], stride=2)
        self.layer3 = self._make_layer(self.BLOCK, self.PLANES[2], self.LAYERS[2], stride=2)
        self.layer4 = self._make_layer(self.BLOCK, self.PLANES[3], self.LAYERS[3], stride=2)
        self.conv5 = nn.Sequential(
            ME.MinkowskiDropout(),
            ME.MinkowskiConvolution(self.inplanes,
                                    self.inplanes,
                                    kernel_size=3,
                                    stride=3,
                                    dimension=D),
            ME.MinkowskiInstanceNorm(self.inplanes),
            ME.MinkowskiGELU(),
        )
        self.glob_pool = ME.MinkowskiGlobalMaxPooling()
        self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True)

    def weight_initialization(self):
        """Initialize norms to identity, linears with Xavier and convs with Kaiming."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, ME.MinkowskiBatchNorm):
                nn.init.constant_(m.bn.weight, 1)
                nn.init.constant_(m.bn.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, ME.MinkowskiLinear):
                nn.init.xavier_normal_(m.linear.weight)
                if m.linear.bias is not None:
                    nn.init.constant_(m.linear.bias, 0)
            elif isinstance(m, ME.MinkowskiConvolution):
                ME.utils.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu")

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, bn_momentum=0.1):
        """Build one residual stage of ``blocks`` blocks and return it as ``nn.Sequential``.

        The first block applies ``stride``; a 1x1 projection shortcut is added when
        the stride or the channel width changes. ``bn_momentum`` is applied to the
        shortcut BatchNorm and passed to every block (it used to be accepted but
        ignored; the default 0.1 equals the previous implicit value).

        Side effect: ``self.inplanes`` is updated to the stage's output width.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                self.LAYER(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    dimension=self.D,
                ),
                ME.MinkowskiBatchNorm(planes * block.expansion, momentum=bn_momentum),
            )
        layers = [
            block(
                self.inplanes,
                planes,
                stride=stride,
                dilation=dilation,
                downsample=downsample,
                bn_momentum=bn_momentum,
                dimension=self.D,
            )
        ]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      stride=1,
                      dilation=dilation,
                      bn_momentum=bn_momentum,
                      dimension=self.D))
        return nn.Sequential(*layers)

    def voxelize(self, x: ME.TensorField):
        """Quantize the point field into a sparse tensor."""
        return x.sparse()

    def forward(self, x: ME.TensorField):
        """Classify each batch element; returns the dense logits (``.F``)."""
        x = self.voxelize(x)
        x = self.conv1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.conv5(x)
        x = self.glob_pool(x)
        return self.final(x).F
113 |
114 |
@gin.configurable
class ResNet34(ResNetBase):
    """ResNet-34 layout: BasicBlock stages with 3, 4, 6 and 3 blocks."""

    LAYERS = (3, 4, 6, 3)
    BLOCK = BasicBlock
--------------------------------------------------------------------------------
/src/models/resunet.py:
--------------------------------------------------------------------------------
1 | import gin
2 | import torch
3 | import MinkowskiEngine as ME
4 | from MinkowskiEngine.modules.resnet_block import BasicBlock
5 |
6 | from src.models.resnet import ResNetBase
7 |
8 |
@gin.configurable
class Res16UNetBase(ResNetBase):
    """Sparse-convolutional U-Net whose encoder reaches tensor stride 16.

    Encoder: a stride-1 stem (``conv0p1s1``) followed by four stride-2
    convolutions, each followed by a residual stage. Decoder: four stride-2
    transposed convolutions, each concatenated with the encoder output of the
    matching resolution and refined by a residual stage. Subclasses must
    implement ``voxelize`` (returning ``(sparse_tensor, emb)``) and
    ``devoxelize``.

    NOTE: modules are created in a fixed order and ``_make_layer`` updates
    ``self.inplanes``; reordering statements changes parameter registration
    order and random initialization.
    """
    INIT_DIM = 32
    LAYERS = (2, 2, 2, 2, 2, 2, 2, 2)
    PLANES = (32, 64, 128, 256, 256, 256, 256, 256)

    def __init__(self, in_channels, out_channels, D=3):
        super(Res16UNetBase, self).__init__(in_channels, out_channels, D)

    def network_initialization(self, in_channels, out_channels, D):
        # Encoder stem at full resolution.
        self.inplanes = self.INIT_DIM
        self.conv0p1s1 = self.LAYER(in_channels, self.inplanes, kernel_size=5, dimension=D)
        self.bn0 = ME.MinkowskiBatchNorm(self.inplanes)

        self.conv1p1s2 = self.LAYER(self.inplanes,
                                    self.inplanes,
                                    kernel_size=2,
                                    stride=2,
                                    dimension=D)  # pooling
        self.bn1 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block1 = self._make_layer(self.BLOCK, self.PLANES[0], self.LAYERS[0])

        self.conv2p2s2 = self.LAYER(self.inplanes,
                                    self.inplanes,
                                    kernel_size=2,
                                    stride=2,
                                    dimension=D)  # pooling
        self.bn2 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block2 = self._make_layer(self.BLOCK, self.PLANES[1], self.LAYERS[1])

        self.conv3p4s2 = self.LAYER(self.inplanes,
                                    self.inplanes,
                                    kernel_size=2,
                                    stride=2,
                                    dimension=D)  # pooling
        self.bn3 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block3 = self._make_layer(self.BLOCK, self.PLANES[2], self.LAYERS[2])

        self.conv4p8s2 = self.LAYER(self.inplanes,
                                    self.inplanes,
                                    kernel_size=2,
                                    stride=2,
                                    dimension=D)  # pooling
        self.bn4 = ME.MinkowskiBatchNorm(self.inplanes)
        self.block4 = self._make_layer(self.BLOCK, self.PLANES[3], self.LAYERS[3])

        # Decoder: each stage's input width is upsampled width + skip width.
        self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose(self.inplanes,
                                                             self.PLANES[4],
                                                             kernel_size=2,
                                                             stride=2,
                                                             dimension=D)  # unpooling
        self.bntr4 = ME.MinkowskiBatchNorm(self.PLANES[4])
        self.inplanes = self.PLANES[
            4] + self.PLANES[2] * self.BLOCK.expansion  # concatenated dimension
        self.block5 = self._make_layer(self.BLOCK, self.PLANES[4], self.LAYERS[4])

        self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose(self.inplanes,
                                                            self.PLANES[5],
                                                            kernel_size=2,
                                                            stride=2,
                                                            dimension=D)  # unpooling
        self.bntr5 = ME.MinkowskiBatchNorm(self.PLANES[5])
        self.inplanes = self.PLANES[
            5] + self.PLANES[1] * self.BLOCK.expansion  # concatenated dimension
        self.block6 = self._make_layer(self.BLOCK, self.PLANES[5], self.LAYERS[5])

        self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose(self.inplanes,
                                                            self.PLANES[6],
                                                            kernel_size=2,
                                                            stride=2,
                                                            dimension=D)  # unpooling
        self.bntr6 = ME.MinkowskiBatchNorm(self.PLANES[6])
        self.inplanes = self.PLANES[
            6] + self.PLANES[0] * self.BLOCK.expansion  # concatenated dimension
        self.block7 = self._make_layer(self.BLOCK, self.PLANES[6], self.LAYERS[6])

        self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose(self.inplanes,
                                                            self.PLANES[7],
                                                            kernel_size=2,
                                                            stride=2,
                                                            dimension=D)  # unpooling
        self.bntr7 = ME.MinkowskiBatchNorm(self.PLANES[7])
        # The last skip is the stem output, whose width is INIT_DIM.
        self.inplanes = self.PLANES[7] + self.INIT_DIM  # concatenated dimension
        self.block8 = self._make_layer(self.BLOCK, self.PLANES[7], self.LAYERS[7])

        # 1x1 output head.
        self.final = ME.MinkowskiConvolution(self.PLANES[7] * self.BLOCK.expansion,
                                             out_channels,
                                             kernel_size=1,
                                             stride=1,
                                             bias=True,
                                             dimension=D)
        self.relu = ME.MinkowskiReLU(inplace=True)

    def voxelize(self, x: ME.TensorField):
        """Return ``(sparse_tensor, emb)``; implemented by subclasses."""
        raise NotImplementedError()

    def devoxelize(self, out: ME.SparseTensor, x: ME.TensorField, emb: torch.Tensor):
        """Map decoder output back to the points of ``x``; implemented by subclasses."""
        raise NotImplementedError()

    def forward(self, x: ME.TensorField):
        out, emb = self.voxelize(x)
        out_p1 = self.relu(self.bn0(self.conv0p1s1(out)))

        out = self.relu(self.bn1(self.conv1p1s2(out_p1)))
        out_p2 = self.block1(out)

        out = self.relu(self.bn2(self.conv2p2s2(out_p2)))
        out_p4 = self.block2(out)

        out = self.relu(self.bn3(self.conv3p4s2(out_p4)))
        out_p8 = self.block3(out)

        out = self.relu(self.bn4(self.conv4p8s2(out_p8)))
        out = self.block4(out)

        # Decoder with skip concatenations from the encoder.
        out = self.relu(self.bntr4(self.convtr4p16s2(out)))
        out = ME.cat(out, out_p8)
        out = self.block5(out)

        out = self.relu(self.bntr5(self.convtr5p8s2(out)))
        out = ME.cat(out, out_p4)
        out = self.block6(out)

        out = self.relu(self.bntr6(self.convtr6p4s2(out)))
        out = ME.cat(out, out_p2)
        out = self.block7(out)

        out = self.relu(self.bntr7(self.convtr7p2s2(out)))
        out = ME.cat(out, out_p1)
        out = self.block8(out)
        return self.devoxelize(out, x, emb)
140 |
141 |
@gin.configurable
class Res16UNet34C(Res16UNetBase):  # MinkowskiNet42
    """Res16UNet34C (MinkowskiNet42): a purely voxel-based U-Net with BasicBlocks."""

    BLOCK = BasicBlock
    LAYERS = (2, 3, 4, 6, 2, 2, 2, 2)
    PLANES = (32, 64, 128, 256, 256, 128, 96, 96)

    def voxelize(self, x: ME.TensorField):
        """Quantize the field; this model uses no extra point embedding."""
        voxels = x.sparse()
        return voxels, None

    def devoxelize(self, out: ME.SparseTensor, x: ME.TensorField, emb: torch.Tensor):
        """Apply the output head, then slice the voxel logits back onto the points."""
        logits = self.final(out)
        return logits.slice(x).F
153 |
154 |
@gin.configurable
class Res16UNet34CSmall(Res16UNet34C):
    """Res16UNet34C with two blocks in each of the eight stages."""

    LAYERS = (2,) * 8
158 |
159 |
@gin.configurable
class Res16UNet34CSmaller(Res16UNet34C):
    """Res16UNet34C with a single block in each of the eight stages."""

    LAYERS = (1,) * 8
163 |
--------------------------------------------------------------------------------
/src/models/spvcnn.py:
--------------------------------------------------------------------------------
1 | import gin
2 | import torch
3 | import torch.nn as nn
4 | import MinkowskiEngine as ME
5 |
6 | from src.models.resunet import Res16UNet34C
7 |
8 |
@gin.configurable
class SPVCNN(Res16UNet34C):
    """Sparse point-voxel CNN built on the Res16UNet34C voxel branch.

    A point branch keeps features at input resolution. At three depths the
    voxel features are interpolated onto the points (``voxel_to_point``), summed
    with the transformed point features of the previous fusion, and
    re-voxelized onto the coordinate map of the matching voxel stage.
    """

    def network_initialization(self, in_channels, out_channels, D):
        super(SPVCNN, self).network_initialization(in_channels, out_channels, D)
        expansion = self.BLOCK.expansion
        self.final = ME.MinkowskiLinear(self.PLANES[7] * expansion, out_channels)

        # Each transform maps the previous point feature width to the width of
        # the voxel features it is added to in forward():
        #   [0]: z0 (INIT_DIM)             -> block4 output (PLANES[3] * expansion)
        #   [1]: z1 (PLANES[3] * expansion) -> block6 output (PLANES[5] * expansion)
        #   [2]: z2 (PLANES[5] * expansion) -> block8 output (PLANES[7] * expansion)
        # The input width of [1] used to be PLANES[4], which only matched z1
        # because PLANES[3] == PLANES[4] in Res16UNet34C; the expansion factor
        # was ignored as well. Values are unchanged for the current config.
        enc_width = self.PLANES[3] * expansion
        mid_width = self.PLANES[5] * expansion
        out_width = self.PLANES[7] * expansion
        self.point_transforms = nn.ModuleList([
            nn.Sequential(
                ME.MinkowskiLinear(self.INIT_DIM, enc_width),
                ME.MinkowskiBatchNorm(enc_width),
                ME.MinkowskiReLU(True)
            ),
            nn.Sequential(
                ME.MinkowskiLinear(enc_width, mid_width),
                ME.MinkowskiBatchNorm(mid_width),
                ME.MinkowskiReLU(True)
            ),
            nn.Sequential(
                ME.MinkowskiLinear(mid_width, out_width),
                ME.MinkowskiBatchNorm(out_width),
                ME.MinkowskiReLU(True)
            ),
        ])
        self.dropout = ME.MinkowskiDropout(0.3, True)

    def voxel_to_point(self, s: ME.SparseTensor, f: ME.TensorField):
        """Interpolate the voxel features of ``s`` at the field coordinates of ``f``.

        The interpolated features are divided by the summed interpolation weights
        of each point (+1e-8 so points with zero total weight do not divide by
        zero). The result is a TensorField sharing ``f``'s coordinate field map.
        """
        feats, _, out_map, weights = ME.MinkowskiInterpolationFunction().apply(
            s.F, f.C, s.coordinate_key, s.coordinate_manager
        )
        denom = torch.zeros((len(f),), dtype=feats.dtype, device=feats.device)
        denom.index_add_(0, out_map.long(), weights)
        denom.unsqueeze_(1)
        norm_feats = torch.true_divide(feats, denom + 1e-8)
        return ME.TensorField(
            features=norm_feats,
            coordinate_field_map_key=f.coordinate_field_map_key,
            quantization_mode=f.quantization_mode,
            coordinate_manager=f.coordinate_manager
        )

    def forward(self, x: ME.TensorField):
        """Return per-point logits (``.F`` of the final linear layer)."""
        # Stem, then the first point-level snapshot z0.
        x0 = x.sparse()
        x0 = self.relu(self.bn0(self.conv0p1s1(x0)))
        z0 = self.voxel_to_point(x0, x)

        # Encoder.
        x1 = z0.sparse(coordinate_map_key=x0.coordinate_map_key)
        x1 = self.relu(self.bn1(self.conv1p1s2(x1)))
        x1 = self.block1(x1)

        x2 = self.relu(self.bn2(self.conv2p2s2(x1)))
        x2 = self.block2(x2)

        x3 = self.relu(self.bn3(self.conv3p4s2(x2)))
        x3 = self.block3(x3)

        x4 = self.relu(self.bn4(self.conv4p8s2(x3)))
        x4 = self.block4(x4)

        # Fusion 1 at the bottleneck.
        z1 = self.voxel_to_point(x4, x)
        z1 = z1 + self.point_transforms[0](z0).F

        # Decoder, first half.
        y1 = z1.sparse(coordinate_map_key=x4.coordinate_map_key)
        y1 = self.dropout(y1)
        y1 = self.relu(self.bntr4(self.convtr4p16s2(y1)))
        y1 = ME.cat(y1, x3)
        y1 = self.block5(y1)

        y2 = self.relu(self.bntr5(self.convtr5p8s2(y1)))
        y2 = ME.cat(y2, x2)
        y2 = self.block6(y2)
        # Fusion 2.
        z2 = self.voxel_to_point(y2, x)
        z2 = z2 + self.point_transforms[1](z1).F

        # Decoder, second half (y2 lives on x2's coordinate map).
        y3 = z2.sparse(coordinate_map_key=x2.coordinate_map_key)
        y3 = self.dropout(y3)
        y3 = self.relu(self.bntr6(self.convtr6p4s2(y3)))
        y3 = ME.cat(y3, x1)
        y3 = self.block7(y3)

        y4 = self.relu(self.bntr7(self.convtr7p2s2(y3)))
        y4 = ME.cat(y4, x0)
        y4 = self.block8(y4)
        # Fusion 3 at full resolution.
        z3 = self.voxel_to_point(y4, x)
        z3 = z3 + self.point_transforms[2](z2).F
        return self.final(z3).F
94 |
--------------------------------------------------------------------------------
/src/models/transformer_base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import MinkowskiEngine as ME
4 | from MinkowskiEngine.MinkowskiKernelGenerator import KernelGenerator
5 |
6 |
class LocalSelfAttentionBase(nn.Module):
    """Kernel-map helpers shared by local self-attention layers on sparse tensors.

    A MinkowskiEngine ``KernelGenerator`` describes the local neighbourhood
    (kernel size, stride, dilation). The helpers turn the coordinate manager's
    kernel map into flat key/query index tensors.
    """

    def __init__(self, kernel_size, stride, dilation, dimension):
        super(LocalSelfAttentionBase, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.dimension = dimension

        self.kernel_generator = KernelGenerator(kernel_size=kernel_size,
                                                stride=stride,
                                                dilation=dilation,
                                                dimension=dimension)
        # Number of neighbourhood offsets; used to pack (index, offset) pairs.
        self.kernel_volume = self.kernel_generator.kernel_volume

    def get_kernel_map_and_out_key(self, stensor):
        """Return ``(kernel_map, out_key)`` for ``stensor`` under this layer's kernel.

        ``out_key`` is the coordinate key after applying the kernel stride.
        ``kernel_map`` is what ``CoordinateManager.kernel_map`` returns (a dict
        keyed by kernel offset index).
        """
        cm = stensor.coordinate_manager
        in_key = stensor.coordinate_key
        out_key = cm.stride(in_key, self.kernel_generator.kernel_stride)
        region_type, region_offset, _ = self.kernel_generator.get_kernel(
            stensor.tensor_stride, False)
        kernel_map = cm.kernel_map(in_key,
                                   out_key,
                                   self.kernel_generator.kernel_stride,
                                   self.kernel_generator.kernel_size,
                                   self.kernel_generator.kernel_dilation,
                                   region_type=region_type,
                                   region_offset=region_offset)
        return kernel_map, out_key

    def key_query_map_from_kernel_map(self, kernel_map):
        """Concatenate a kernel map into one tensor with packed key ids.

        For each kernel offset ``k``, row 0 of its index tensor is replaced by
        ``row0 * kernel_volume + k``, so both the index and the offset can be
        recovered (see ``key_query_indices_from_key_query_map``). Results are
        concatenated along the last dimension. Assumes each value is a
        ``(2, M)`` tensor (row 0 input/key, row 1 output/query) -- TODO confirm
        against MinkowskiEngine.

        ``kernel_map`` itself is left unmodified.
        """
        kq_map = []
        for kernel_idx, in_out in kernel_map.items():
            # Work on a copy: writing into ``in_out`` directly mutated the
            # caller's kernel map, so reusing that map afterwards (e.g. with
            # key_query_indices_from_kernel_map) saw packed ids.
            packed = in_out.clone()
            packed[0] = packed[0] * self.kernel_volume + kernel_idx
            kq_map.append(packed)
        return torch.cat(kq_map, -1)

    def key_query_indices_from_kernel_map(self, kernel_map):
        """Concatenate the raw index tensors of a kernel map along the last dim."""
        kq_indices = []
        for _, in_out in kernel_map.items():
            kq_indices.append(in_out)
        kq_indices = torch.cat(kq_indices, -1)
        return kq_indices

    def key_query_indices_from_key_query_map(self, kq_map):
        """Undo the packing of ``key_query_map_from_kernel_map`` on row 0 (copy)."""
        kq_indices = kq_map.clone()
        kq_indices[0] = kq_indices[0] // self.kernel_volume
        return kq_indices
55 |
56 |
class ResidualBlockWithPointsBase(nn.Module):
    """Two-layer residual block whose layers also receive normalized points.

    Subclasses set ``LAYER`` to a callable ``LAYER(in_ch, out_ch, kernel_size)``
    whose result is invoked as ``layer(stensor, norm_points)``.

    When ``out_channels`` differs from ``in_channels``, a linear + BatchNorm
    projection is applied to the shortcut. Previously this case was accepted
    but failed in ``forward`` on ``out += residual`` because of the width
    mismatch. With equal widths no extra module is created, so parameters and
    state_dict are unchanged.
    """
    LAYER = None

    def __init__(self, in_channels, out_channels=None, kernel_size=3):
        out_channels = in_channels if out_channels is None else out_channels
        assert self.LAYER is not None
        super(ResidualBlockWithPointsBase, self).__init__()

        self.layer1 = self.LAYER(in_channels, out_channels, kernel_size)
        self.norm1 = ME.MinkowskiBatchNorm(out_channels)
        self.layer2 = self.LAYER(out_channels, out_channels, kernel_size)
        self.norm2 = ME.MinkowskiBatchNorm(out_channels)
        self.relu = ME.MinkowskiReLU(inplace=True)
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                ME.MinkowskiLinear(in_channels, out_channels, bias=False),
                ME.MinkowskiBatchNorm(out_channels),
            )
        else:
            # Plain attribute (not a registered submodule): identity shortcut.
            self.shortcut = None

    def forward(self, stensor, norm_points):
        """Apply layer-norm-relu-layer-norm, add the shortcut, then ReLU."""
        residual = stensor if self.shortcut is None else self.shortcut(stensor)
        out = self.layer1(stensor, norm_points)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.layer2(out, norm_points)
        out = self.norm2(out)
        out += residual
        out = self.relu(out)
        return out
81 |
--------------------------------------------------------------------------------
/src/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .segmentation import LitSegMinkowskiModule
2 |
modules = [LitSegMinkowskiModule]
# Registry of selectable LightningModule classes, keyed by class name.
modules_dict = {cls.__name__: cls for cls in modules}
5 |
6 |
def get_lightning_module(name):
    """Return the registered LightningModule class called ``name``.

    Raises:
        ValueError: if ``name`` is not registered. This is an explicit check
            because an ``assert`` is stripped under ``python -O``, which would
            leave only a bare KeyError.
    """
    if name not in modules_dict:
        raise ValueError(f"{name} not in {modules_dict.keys()}")
    return modules_dict[name]
--------------------------------------------------------------------------------
/src/modules/segmentation.py:
--------------------------------------------------------------------------------
1 | import gin
2 | import numpy as np
3 | import torch
4 | import pytorch_lightning as pl
5 | import pl_bolts
6 | import torchmetrics
7 | import MinkowskiEngine as ME
8 |
9 | from src.utils.metric import per_class_iou
10 |
11 |
@gin.configurable
class LitSegmentationModuleBase(pl.LightningModule):
    """LightningModule for semantic segmentation.

    Uses SGD with a linear-warmup cosine schedule stepped per optimizer step,
    cross-entropy loss, and a confusion matrix from which mIoU / mAcc are
    computed at the end of each validation epoch.

    Args:
        model: network mapping ``prepare_input_data(batch)`` to per-point logits.
        num_classes: number of semantic classes.
        lr, momentum, weight_decay: SGD hyper-parameters.
        warmup_steps_ratio: fraction of ``max_steps`` spent in linear warmup.
        max_steps: total number of scheduler steps.
        best_metric_type: ``"maximize"`` to track the highest mIoU, otherwise the lowest.
        ignore_label: label ignored by the loss and the metric.
        dist_sync_metric: forwarded to ``dist_sync_on_step`` of the metric.
        lr_eta_min: final learning rate of the cosine schedule.

    Subclasses implement ``prepare_input_data``.
    """

    def __init__(
        self,
        model,
        num_classes,
        lr,
        momentum,
        weight_decay,
        warmup_steps_ratio,
        max_steps,
        best_metric_type,
        ignore_label=255,
        dist_sync_metric=False,
        lr_eta_min=0.,
    ):
        super(LitSegmentationModuleBase, self).__init__()
        # Explicit assignments, in the same order the former ``vars()`` loop
        # used, so ``model`` is still the first registered child module. The
        # loop silently turned any local variable into an attribute and hid
        # the attributes from linters.
        self.model = model
        self.num_classes = num_classes
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.warmup_steps_ratio = warmup_steps_ratio
        self.max_steps = max_steps
        self.best_metric_type = best_metric_type
        self.ignore_label = ignore_label
        self.dist_sync_metric = dist_sync_metric
        self.lr_eta_min = lr_eta_min

        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_label)
        self.best_metric_value = -np.inf if best_metric_type == "maximize" else np.inf
        self.metric = torchmetrics.ConfusionMatrix(
            num_classes=num_classes,
            compute_on_step=False,
            dist_sync_on_step=dist_sync_metric
        )

    def configure_optimizers(self):
        """SGD plus a per-step linear-warmup cosine-annealing scheduler."""
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )
        # The scheduler's "epochs" are optimizer steps because interval="step".
        scheduler = pl_bolts.optimizers.LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=int(self.warmup_steps_ratio * self.max_steps),
            max_epochs=self.max_steps,
            eta_min=self.lr_eta_min,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step"
            }
        }

    def training_step(self, batch, batch_idx):
        in_data = self.prepare_input_data(batch)
        logits = self.model(in_data)
        loss = self.criterion(logits, batch["labels"])
        self.log("train_loss", loss.item(), batch_size=batch["batch_size"], logger=True, prog_bar=True)
        # Release cached CUDA blocks every step (presumably to limit memory
        # fragmentation from variable-size sparse tensors).
        torch.cuda.empty_cache()
        return loss

    def validation_step(self, batch, batch_idx):
        in_data = self.prepare_input_data(batch)
        logits = self.model(in_data)
        loss = self.criterion(logits, batch["labels"])
        self.log("val_loss", loss.item(), batch_size=batch["batch_size"], logger=True, prog_bar=True)
        pred = logits.argmax(dim=1, keepdim=False)
        mask = batch["labels"] != self.ignore_label
        self.metric(pred[mask], batch["labels"][mask])
        torch.cuda.empty_cache()
        return loss

    def validation_epoch_end(self, outputs):
        """Log mIoU, mAcc and the best mIoU seen so far; reset the confusion matrix."""
        confusion_matrix = self.metric.compute().cpu().numpy()
        self.metric.reset()
        ious = per_class_iou(confusion_matrix) * 100
        # Classes without ground-truth points give 0/0 -> nan, which nanmean
        # skips. Silence the RuntimeWarning, as per_class_iou already does.
        with np.errstate(divide="ignore", invalid="ignore"):
            accs = confusion_matrix.diagonal() / confusion_matrix.sum(1) * 100
        miou = np.nanmean(ious)
        macc = np.nanmean(accs)

        def compare(prev, cur):
            return prev < cur if self.best_metric_type == "maximize" else prev > cur

        if compare(self.best_metric_value, miou):
            self.best_metric_value = miou
            self.log("val_best_mIoU", self.best_metric_value, logger=True)
        self.log("val_mIoU", miou, logger=True)
        self.log("val_mAcc", macc, logger=True)

    def prepare_input_data(self, batch):
        """Convert a collated batch into the model's input; implemented by subclasses."""
        raise NotImplementedError
100 |
101 |
@gin.configurable
class LitSegMinkowskiModule(LitSegmentationModuleBase):
    """Segmentation module whose model consumes MinkowskiEngine TensorFields."""

    def prepare_input_data(self, batch):
        """Wrap the batch's features and coordinates in an ``ME.TensorField``.

        The quantization mode is taken from the wrapped model's ``QMODE``.
        """
        qmode = self.model.QMODE
        return ME.TensorField(
            features=batch["features"],
            coordinates=batch["coordinates"],
            quantization_mode=qmode,
        )
--------------------------------------------------------------------------------
/src/utils/file.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
def ensure_dir(path):
    """Create directory ``path`` (and parents, mode 0o755) unless it already exists.

    Creation is attempted directly and FileExistsError is ignored. The former
    ``exists()``-then-``makedirs()`` sequence raced when another process
    created the directory in between. As before, an existing non-directory at
    ``path`` is left alone silently.
    """
    try:
        os.makedirs(path, mode=0o755)
    except FileExistsError:
        pass
--------------------------------------------------------------------------------
/src/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | from rich.logging import RichHandler
5 |
6 |
def setup_logger(exp_name, debug):
    """Reset ``logging`` and configure a Rich-formatted root logger.

    Messages are prefixed with ``CUDA_VISIBLE_DEVICES`` (default "0") and the
    experiment name.

    Args:
        exp_name: experiment tag placed in every message prefix.
        debug: log at DEBUG level when true, INFO otherwise.
    """
    # The ``imp`` module was deprecated in Python 3.4 and removed in 3.12;
    # importlib.reload is the drop-in replacement.
    from importlib import reload

    # Re-executing the logging module gives a fresh root logger without
    # handlers, presumably so basicConfig() below always takes effect.
    reload(logging)

    CUDA_TAG = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
    EXP_TAG = exp_name

    logger_config = dict(
        level=logging.DEBUG if debug else logging.INFO,
        format=f"{CUDA_TAG}:[{EXP_TAG}] %(message)s",
        handlers=[RichHandler()],
        datefmt="[%X]",
    )
    logging.basicConfig(**logger_config)
--------------------------------------------------------------------------------
/src/utils/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
def fast_hist(pred, target, num_classes, ignore_label=255):
    """Build a ``(num_classes, num_classes)`` confusion matrix.

    Rows index the ground-truth label and columns the prediction.

    Args:
        pred: array of predicted labels.
        target: array of ground-truth labels, same shape as ``pred``.
        num_classes: number of classes; valid labels lie in ``[0, num_classes)``.
        ignore_label: target value that is skipped.

    Returns:
        Integer array of counts with shape ``(num_classes, num_classes)``.
    """
    pred = np.asarray(pred).astype(int)
    target = np.asarray(target).astype(int)
    # Also drop negative targets and out-of-range predictions. A negative target
    # made np.bincount raise. An out-of-range prediction was counted in the
    # wrong cell or broke the final reshape.
    mask = (
        (target != ignore_label)
        & (target >= 0)
        & (target < num_classes)
        & (pred >= 0)
        & (pred < num_classes)
    )
    return np.bincount(num_classes * target[mask] + pred[mask],
                       minlength=num_classes**2).reshape(num_classes, num_classes)
9 |
10 |
def per_class_iou(hist):
    """Return per-class intersection-over-union from a confusion matrix.

    For class ``c``: ``hist[c, c] / (row_sum[c] + col_sum[c] - hist[c, c])``.
    A class with an empty union yields ``nan``; the 0/0 warning is suppressed.
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        intersection = np.diag(hist)
        union = hist.sum(1) + hist.sum(0) - intersection
        return intersection / union
14 |
--------------------------------------------------------------------------------
/src/utils/misc.py:
--------------------------------------------------------------------------------
1 | import gin
2 |
3 |
@gin.configurable
def logged_hparams(keys):
    """Map each gin binding name in ``keys`` to its currently bound value."""
    return {k: gin.query_parameter(f"{k}") for k in keys}
10 |
11 |
def load_from_pl_state_dict(model, pl_state_dict, prefix="model."):
    """Load a LightningModule state dict into the bare ``model`` it wraps.

    Only entries whose key starts with ``prefix`` are used, with the prefix
    stripped. Previously the first 6 characters of every key were cut off
    blindly, so any key without the ``"model."`` prefix was mangled and made
    ``load_state_dict`` fail. Loading stays strict, so missing model weights
    are still reported.

    Args:
        model: the ``torch.nn.Module`` to load into.
        pl_state_dict: state dict of the wrapping LightningModule.
        prefix: attribute-name prefix of the wrapped model (default ``"model."``).

    Returns:
        ``model``, for chaining.
    """
    state_dict = {
        k[len(prefix):]: v for k, v in pl_state_dict.items() if k.startswith(prefix)
    }
    model.load_state_dict(state_dict)
    return model
--------------------------------------------------------------------------------
/src/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import open3d as o3d
3 |
4 |
def make_open3d_point_cloud(xyz, color=None):
    """Build an ``o3d.geometry.PointCloud`` from point positions.

    Args:
        xyz: point positions, passed to ``o3d.utility.Vector3dVector``.
        color: optional per-point colors, or a single 1-D color that is
            broadcast to every point.

    Returns:
        The point cloud.
    """
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(xyz)
    if color is not None:
        # A single 1-D color must always be tiled. Previously it was not tiled
        # when its length happened to equal the number of points (e.g. one RGB
        # triple and 3 points), and Vector3dVector then got a 1-D array.
        if np.ndim(color) == 1 or len(color) != len(xyz):
            color = np.tile(color, (len(xyz), 1))
        pcd.colors = o3d.utility.Vector3dVector(color)
    return pcd
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from datetime import datetime
4 |
5 | import gin
6 | import pytorch_lightning as pl
7 |
8 | from src.models import get_model
9 | from src.data import get_data_module
10 | from src.modules import get_lightning_module
11 | from src.utils.file import ensure_dir
12 | from src.utils.logger import setup_logger
13 | from src.utils.misc import logged_hparams
14 |
15 |
@gin.configurable
def train(
    save_path,
    project_name,
    run_name,
    lightning_module_name,
    data_module_name,
    model_name,
    gpus,
    log_every_n_steps,
    check_val_every_n_epoch,
    refresh_rate_per_second,
    best_metric,
    max_epoch,
    max_step,
):
    """Build data, model and LightningModule from gin bindings and run training.

    Outputs go to ``<save_path>/<run_name>_<timestamp>``: checkpoints, W&B files
    and the operative ``config.gin``.

    Raises:
        NotImplementedError: if ``gpus > 1``. This is checked before any side
            effect, so no empty run directory, data module or model is created
            for a configuration that will be rejected anyway.
    """
    if gpus > 1:
        raise NotImplementedError("Currently, multi-gpu training is not supported.")

    now = datetime.now().strftime('%m-%d-%H-%M-%S')
    run_name = run_name + "_" + now
    save_path = os.path.join(save_path, run_name)
    ensure_dir(save_path)

    data_module = get_data_module(data_module_name)()
    model = get_model(model_name)()
    pl_module = get_lightning_module(lightning_module_name)(model=model, max_steps=max_step)
    # Lock the gin config: everything above has been constructed.
    gin.finalize()

    hparams = logged_hparams()
    callbacks = [
        pl.callbacks.TQDMProgressBar(refresh_rate=refresh_rate_per_second),
        # NOTE: mode="max" assumes best_metric should be maximized (e.g. mIoU).
        pl.callbacks.ModelCheckpoint(
            dirpath=save_path, monitor=best_metric, save_last=True, save_top_k=1, mode="max"
        ),
        pl.callbacks.LearningRateMonitor(),
    ]
    loggers = [
        pl.loggers.WandbLogger(
            name=run_name,
            save_dir=save_path,
            project=project_name,
            log_model=True,
            entity="chrockey",
            config=hparams,
        )
    ]
    additional_kwargs = dict()  # extra Trainer options; none at present

    trainer = pl.Trainer(
        default_root_dir=save_path,
        max_epochs=max_epoch,
        max_steps=max_step,
        gpus=gpus,
        callbacks=callbacks,
        logger=loggers,
        log_every_n_steps=log_every_n_steps,
        check_val_every_n_epoch=check_val_every_n_epoch,
        **additional_kwargs
    )

    # Save the operative gin config next to the checkpoints.
    with open(os.path.join(save_path, "config.gin"), "w", encoding="utf-8") as f:
        f.write(gin.operative_config_str())

    trainer.fit(pl_module, data_module)
81 |
82 |
if __name__ == "__main__":
    # CLI: the gin config path plus a few overrides; train() reads the rest from gin.
    cli = argparse.ArgumentParser()
    cli.add_argument("config", type=str)
    cli.add_argument("--save_path", type=str, default="experiments")
    cli.add_argument("--run_name", type=str, default="default")
    cli.add_argument("--seed", type=int, default=1235)
    cli.add_argument("-v", "--voxel_size", type=float, default=None)
    cli.add_argument("--debug", action="store_true")
    opts = cli.parse_args()

    # Seed before anything is constructed.
    pl.seed_everything(opts.seed)
    gin.parse_config_file(opts.config)
    # Bound after parsing, so the flag overrides the config file's value.
    if opts.voxel_size is not None:
        gin.bind_parameter("DimensionlessCoordinates.voxel_size", opts.voxel_size)
    setup_logger(opts.run_name, opts.debug)

    train(save_path=opts.save_path, run_name=opts.run_name)
--------------------------------------------------------------------------------