├── .gitignore ├── LICENSE ├── Mask4Former3D ├── README.md ├── conf │ ├── augmentation │ │ └── volumentations_aug.yaml │ ├── callbacks │ │ ├── callbacks_panoptic.yaml │ │ └── callbacks_panoptic3d.yaml │ ├── config_panoptic_3d.yaml │ ├── data │ │ ├── collation_functions │ │ │ ├── voxelize_collate.yaml │ │ │ └── voxelize_collate_merge.yaml │ │ ├── data_loaders │ │ │ └── simple_loader.yaml │ │ ├── datasets │ │ │ ├── semantic_kitti.yaml │ │ │ └── semantic_kitti_206.yaml │ │ ├── kitti.yaml │ │ └── kitti3d.yaml │ ├── logging │ │ └── full.yaml │ ├── loss │ │ └── set_criterion.yaml │ ├── matcher │ │ └── hungarian_matcher.yaml │ ├── metric │ │ ├── lstq.yaml │ │ └── pq.yaml │ ├── model │ │ └── mask4former3d.yaml │ ├── optimizer │ │ └── adamw.yaml │ ├── scheduler │ │ └── onecyclelr.yaml │ ├── semantic-kitti.yaml │ └── trainer │ │ └── trainer30.yaml ├── data │ └── .gitkeep ├── datasets │ ├── lidar.py │ ├── lidar_no_preprocessing.py │ ├── preprocessing │ │ └── semantic_kitti_preprocessing.py │ └── utils.py ├── main_panoptic.py ├── models │ ├── __init__.py │ ├── criterion.py │ ├── mask4former.py │ ├── matcher.py │ ├── metrics │ │ ├── __init__.py │ │ ├── panoptic_eval.py │ │ └── panoptic_quality.py │ ├── model.py │ ├── modules │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── common.py │ │ ├── helpers_3detr.py │ │ └── resnet_block.py │ ├── position_embedding.py │ ├── res16unet.py │ ├── resnet.py │ └── resunet.py ├── scripts │ ├── test.sh │ ├── train.sh │ └── val.sh ├── trainer │ ├── pq_trainer.py │ └── trainer.py └── utils │ ├── __init__.py │ └── utils.py ├── README.md ├── compute_object_level_ood.py ├── compute_point_level_ood.py ├── dataset ├── __init__.py └── stu_image_dataset.py ├── docs └── gthub_teaser.jpg ├── generate_dbscan_instances.py └── utils ├── __init__.py └── common.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Alexey Nekrasov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Mask4Former3D/README.md: -------------------------------------------------------------------------------- 1 | # Mask4Former3D 2 | 3 | This code adapts the Mask4Former codebase for 3D panoptic segmentation. 4 | 5 | For more details, please check the original Mask4Former codebase here: [Mask4Former](https://github.com/YilmazKadir/Mask4Former). 6 | This README is intended as an extension and completion of the original README. 7 | 8 | --- 9 | 10 | ## Code Dependencies and Installation 11 | The main project dependencies are: 12 | ```yaml 13 | cuda: 12.1 14 | cudnn: 9.0.0+ 15 | python: 3.10 16 | ``` 17 | In general, you should be able to install the project with any CUDA 12 and Python 3.8+ versions. 18 | I recommend using Python 3.10 and CUDA 12.1, or a slightly newer version, mainly because I don't want to deal with Python 3.11 and NumPy version 2. 19 | I have tested the code with CUDA 12.1 and 12.5 for Python 3.10 and 3.11. In principle, it should work with CUDA 11.6+ and Python 3.7, as in the original Mask4Former codebase. 20 | 21 | ```bash 22 | uv venv -p 3.10 23 | source .venv/bin/activate 24 | ``` 25 | 26 | ```bash 27 | uv pip install torch==2.3.0+cu121 torchvision==0.18.0+cu121 --extra-index-url https://download.pytorch.org/whl/cu121 28 | uv pip install volumentations PyYAML==6.0.2 numpy==1.24.4 setuptools==59.8.0 opencv-python natsort tensorboard wheel GitPython fire ninja "hydra-core<=1.0.5" python-dotenv pandas joblib flake8 "pytorch-lightning==1.9.5" loguru 29 | 30 | # might be necessary to build MinkowskiEngine 31 | uv pip install git+https://github.com/kumuji/FilePacker.git 32 | uv pip install git+https://github.com/facebookresearch/pytorch3d.git@v0.7.6 --no-deps --no-build-isolation 33 | ``` 34 | 35 | To install MinkowskiEngine with CUDA 12+, it needs a few modifications. 36 | For more details, check this comment: https://github.com/NVIDIA/MinkowskiEngine/issues/543#issuecomment-1773458776 37 | I tried building MinkowskiEngine with CUDA 12.8, but it didn't work.
38 | Typically, CUDA backward compatibility within a major version is not an issue, but it seems to be one here. 39 | The same goes for GCC/G++; I was able to build it with GCC 11. 40 | ```bash 41 | # You might need to set these environment variables to make the build work 42 | # export CC=/usr/bin/gcc-11 43 | # export CXX=/usr/bin/g++-11 44 | # export TORCH_CUDA_ARCH_LIST="6.0 6.1 6.2 7.0 7.2 7.5 8.0 8.6 8.9" 45 | git clone https://github.com/NVIDIA/MinkowskiEngine.git 46 | cd MinkowskiEngine 47 | python setup.py install 48 | ``` 49 | Since many of these projects are quite old, it may be difficult to adapt them to newer versions of other dependencies. 50 | 51 | ## Data Preparation 52 | The code assumes that a data folder exists with three subfolders for the training, validation, and test data. 53 | 54 | Scenes to train, validate, and test on are defined in the config file `conf/semantic-kitti.yaml`. 55 | ```tree 56 | data 57 | ├── train (semkitti, panoptic-cudal, stu-train) 58 | │ ├── 00 59 | │ ... 60 | │ ├── 40 61 | │ ... 62 | │ └── 206 63 | ├── validation (semkitti) 64 | │ └── 08 65 | ├── test (stu-validation) 66 | │ ├── 125 67 | │ ... 68 | │ ... 69 | │ └── 169 70 | ``` 71 | 72 | Currently, the data needs preprocessing before training to populate the instance databases. 73 | You don't need preprocessing to run inference. 74 | ```bash 75 | python -m datasets.preprocessing.semantic_kitti_preprocessing preprocess --data_dir "data" --save_dir "./data" 76 | python -m datasets.preprocessing.semantic_kitti_preprocessing make_instance_database --data_dir "data" --save_dir "./data" 77 | ``` 78 | 79 | However, the repository still expects these database files to exist when initializing the train, validation, and test datasets. 80 | Just touch these files if you don't want to train the model. 81 | ```bash 82 | touch data/train_instances_database.yaml 83 | touch data/train_database.yaml 84 | ``` 85 | 86 | ## Inference and Training 87 | To run inference: 88 | ```bash 89 | python main_panoptic.py model=mask4former3d data/datasets=semantic_kitti_206 general.ckpt_path=checkpoint.ckpt general.mode=test 90 | ``` 91 | Inference will take quite a while, but not all of the scans contain anomaly points, and those that don't can be skipped. 92 | 93 | To train the model, run: 94 | ```bash 95 | python main_panoptic.py model=mask4former3d data/datasets=semantic_kitti_206 general.mode=train 96 | ``` 97 | Unfortunately, the code doesn't support multi-GPU training, so we train all of the models using a single H100 GPU. 98 | If you want to extend it to multi-GPU / multi-node training, check the [Interactive4D codebase](https://github.com/Ilya-Fradlin/Interactive4D). 99 | You can train models using smaller GPUs, but you will have to adjust the batch size and the learning rate.
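For example, on a GPU with less memory you could roughly halve the batch size and scale the learning rate linearly. The override keys below come from the existing configs (`data.batch_size` in `conf/data/kitti3d.yaml`, `optimizer.lr` in `conf/optimizer/adamw.yaml`); the exact values are illustrative, not tested settings:

```bash
# Illustrative only: halve the default batch size (16 -> 8) and scale the
# learning rate (0.0002 -> 0.0001) accordingly.
python main_panoptic.py model=mask4former3d data/datasets=semantic_kitti_206 general.mode=train \
    data.batch_size=8 data.test_batch_size=8 optimizer.lr=0.0001
```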
100 | 101 | ## Model Checkpoints :floppy_disk: 102 | 103 | | Checkpoint | PQ on SemanticKITTI | 104 | | :-: | :-: | 105 | | [Ensemble Model 0](https://omnomnom.vision.rwth-aachen.de/data/stu_checkpoints/59p1pq_ens0.ckpt) | `59.1` | 106 | | [Ensemble Model 1](https://omnomnom.vision.rwth-aachen.de/data/stu_checkpoints/59p6pq_ens1.ckpt) | `59.6` | 107 | | [Ensemble Model 2](https://omnomnom.vision.rwth-aachen.de/data/stu_checkpoints/60p7pq_ens2.ckpt) | `60.7` | 108 | | :-: | :-: | 109 | | [Dropout Model](https://omnomnom.vision.rwth-aachen.de/data/stu_checkpoints/dropout_model.ckpt) | -- | 110 | | :-: | :-: | 111 | | [Void Model](https://omnomnom.vision.rwth-aachen.de/data/stu_checkpoints/void_model.ckpt) | `47.9` | 112 | -------------------------------------------------------------------------------- /Mask4Former3D/conf/augmentation/volumentations_aug.yaml: -------------------------------------------------------------------------------- 1 | # pi = 3.14159265358979 2 | # pi/2 = 1.57079632679489 3 | # pi/3 = 1.04719755119659 4 | # pi/6 = 0.52359877559829 5 | # pi/12 = 0.26179938779914 6 | # pi/24 = 0.13089969389957 7 | 8 | __version__: 0.1.6 9 | transform: 10 | __class_fullname__: volumentations.core.composition.Compose 11 | additional_targets: {} 12 | p: 1.0 13 | transforms: 14 | - __class_fullname__: volumentations.augmentations.transforms.Scale3d 15 | always_apply: true 16 | p: 0.5 17 | scale_limit: 18 | - - -0.1 19 | - 0.1 20 | - - -0.1 21 | - 0.1 22 | - - -0.1 23 | - 0.1 24 | - __class_fullname__: volumentations.augmentations.transforms.RotateAroundAxis3d 25 | always_apply: true 26 | axis: 27 | - 0 28 | - 0 29 | - 1 30 | p: 0.5 31 | rotation_limit: 32 | - -3.141592653589793 33 | - 3.141592653589793 34 | - __class_fullname__: volumentations.augmentations.transforms.RotateAroundAxis3d 35 | always_apply: true 36 | axis: 37 | - 0 38 | - 1 39 | - 0 40 | p: 0.5 41 | rotation_limit: 42 | - -0.13089969389957 43 | - 0.13089969389957 44 | - __class_fullname__: volumentations.augmentations.transforms.RotateAroundAxis3d 45 | always_apply: true 46 | axis: 47 | - 1 48 | - 0 49 | - 0 50 | p: 0.5 51 | rotation_limit: 52 | - -0.13089969389957 53 | - 0.13089969389957 54 | -------------------------------------------------------------------------------- /Mask4Former3D/conf/callbacks/callbacks_panoptic.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: val_mean_lstq 4 | save_last: true 5 | save_top_k: 1 6 | mode: max 7 | dirpath: ${general.save_dir} 8 | filename: "{epoch}-{val_mean_lstq:.3f}" 9 | every_n_epochs: 1 10 | 11 | - _target_: pytorch_lightning.callbacks.LearningRateMonitor 12 | -------------------------------------------------------------------------------- /Mask4Former3D/conf/callbacks/callbacks_panoptic3d.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: val_mean_pq 4 | save_last: true 5 | save_top_k: 1 6 | mode: max 7 | dirpath: ${general.save_dir} 8 | filename: "{epoch}-{val_mean_pq:.3f}" 9 | every_n_epochs: 1 10 | 11 | - _target_: pytorch_lightning.callbacks.LearningRateMonitor 12 | -------------------------------------------------------------------------------- /Mask4Former3D/conf/config_panoptic_3d.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | mode: "train" 3 | seed: null 4 
| ckpt_path: null 5 | project_name: mask4former3d 6 | workspace: null 7 | instance_population: 20 8 | dbscan_eps: null 9 | experiment_name: ${now:%Y-%m-%d_%H%M%S} 10 | save_dir: saved/${general.experiment_name} 11 | gpus: 1 12 | 13 | defaults: 14 | - data: kitti3d 15 | - data/data_loaders: simple_loader 16 | - data/datasets: semantic_kitti_206 17 | - data/collation_functions: voxelize_collate 18 | - logging: full 19 | - model: mask4former3d 20 | - optimizer: adamw 21 | - scheduler: onecyclelr 22 | - trainer: trainer30 23 | - callbacks: callbacks_panoptic3d 24 | - matcher: hungarian_matcher 25 | - loss: set_criterion 26 | - metric: pq 27 | 28 | hydra: 29 | run: 30 | dir: saved/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} 31 | sweep: 32 | dir: saved/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} 33 | # dir: ${general.save_dir} 34 | subdir: ${hydra.job.num}_${hydra.job.id} 35 | -------------------------------------------------------------------------------- /Mask4Former3D/conf/data/collation_functions/voxelize_collate.yaml: -------------------------------------------------------------------------------- 1 | # @package data 2 | 3 | train_collation: 4 | _target_: datasets.utils.VoxelizeCollate 5 | ignore_label: ${data.ignore_label} 6 | voxel_size: ${data.voxel_size} 7 | 8 | validation_collation: 9 | _target_: datasets.utils.VoxelizeCollate 10 | ignore_label: ${data.ignore_label} 11 | voxel_size: ${data.voxel_size} 12 | 13 | test_collation: 14 | _target_: datasets.utils.VoxelizeCollate 15 | ignore_label: ${data.ignore_label} 16 | voxel_size: ${data.voxel_size} 17 | -------------------------------------------------------------------------------- /Mask4Former3D/conf/data/collation_functions/voxelize_collate_merge.yaml: -------------------------------------------------------------------------------- 1 | train_collation: 2 | _target_: datasets.utils.VoxelizeCollateMerge 3 | ignore_label: ${data.ignore_label} 4 | voxel_size: ${data.voxel_size} 5 | mode: ${data.train_mode} 6 | small_crops: false 7 | very_small_crops: false 8 | scenes: 2 9 | batch_instance: false 10 | make_one_pc_noise: false 11 | place_nearby: false 12 | place_far: false 13 | proba: 1 14 | 15 | validation_collation: 16 | _target_: datasets.utils.VoxelizeCollate 17 | ignore_label: ${data.ignore_label} 18 | voxel_size: ${data.voxel_size} 19 | mode: ${data.validation_mode} 20 | 21 | test_collation: 22 | _target_: datasets.utils.VoxelizeCollate 23 | ignore_label: ${data.ignore_label} 24 | voxel_size: ${data.voxel_size} 25 | mode: ${data.test_mode} 26 | -------------------------------------------------------------------------------- /Mask4Former3D/conf/data/data_loaders/simple_loader.yaml: -------------------------------------------------------------------------------- 1 | # @package data 2 | 3 | train_dataloader: 4 | _target_: torch.utils.data.DataLoader 5 | shuffle: true 6 | pin_memory: ${data.pin_memory} 7 | num_workers: ${data.num_workers} 8 | batch_size: ${data.batch_size} 9 | 10 | validation_dataloader: 11 | _target_: torch.utils.data.DataLoader 12 | shuffle: false 13 | pin_memory: ${data.pin_memory} 14 | num_workers: ${data.num_workers} 15 | batch_size: ${data.test_batch_size} 16 | 17 | test_dataloader: 18 | _target_: torch.utils.data.DataLoader 19 | shuffle: false 20 | pin_memory: ${data.pin_memory} 21 | num_workers: ${data.num_workers} 22 | batch_size: ${data.test_batch_size} 23 | -------------------------------------------------------------------------------- /Mask4Former3D/conf/data/datasets/semantic_kitti.yaml: 
-------------------------------------------------------------------------------- 1 | # @package data 2 | train_dataset: 3 | _target_: datasets.lidar.LidarDataset 4 | data_dir: data/semantic_kitti 5 | mode: ${data.train_mode} 6 | add_distance: ${data.add_distance} 7 | sweep: ${data.sweep} 8 | instance_population: ${data.instance_population} 9 | ignore_label: ${data.ignore_label} 10 | volume_augmentations_path: conf/augmentation/volumentations_aug.yaml 11 | 12 | validation_dataset: 13 | _target_: datasets.lidar.LidarDataset 14 | data_dir: data/semantic_kitti 15 | mode: ${data.validation_mode} 16 | add_distance: ${data.add_distance} 17 | sweep: ${data.sweep} 18 | instance_population: 0 19 | ignore_label: ${data.ignore_label} 20 | volume_augmentations_path: null 21 | 22 | test_dataset: 23 | _target_: datasets.lidar.LidarDataset 24 | data_dir: data/semantic_kitti 25 | mode: ${data.test_mode} 26 | add_distance: ${data.add_distance} 27 | sweep: ${data.sweep} 28 | instance_population: 0 29 | ignore_label: ${data.ignore_label} 30 | volume_augmentations_path: null 31 | -------------------------------------------------------------------------------- /Mask4Former3D/conf/data/datasets/semantic_kitti_206.yaml: -------------------------------------------------------------------------------- 1 | # @package data 2 | train_dataset: 3 | _target_: datasets.lidar.LidarDataset 4 | data_dir: ${data.base_path} 5 | config_path: conf/mask4former3d.yaml 6 | mode: ${data.train_mode} 7 | add_distance: ${data.add_distance} 8 | sweep: ${data.sweep} 9 | instance_population: ${data.instance_population} 10 | ignore_label: ${data.ignore_label} 11 | volume_augmentations_path: conf/augmentation/volumentations_aug.yaml 12 | 13 | validation_dataset: 14 | _target_: datasets.lidar_no_preprocessing.LidarDataset 15 | data_dir: ${data.base_path}/validation 16 | config_path: conf/mask4former3d.yaml 17 | mode: ${data.validation_mode} 18 | add_distance: ${data.add_distance} 19 | sweep: ${data.sweep} 20 | instance_population: 0 21 | ignore_label: ${data.ignore_label} 22 | volume_augmentations_path: null 23 | 24 | test_dataset: 25 | _target_: datasets.lidar_no_preprocessing.LidarDataset 26 | data_dir: ${data.base_path}/test 27 | config_path: conf/mask4former3d.yaml 28 | mode: ${data.test_mode} 29 | add_distance: ${data.add_distance} 30 | sweep: ${data.sweep} 31 | instance_population: 0 32 | ignore_label: ${data.ignore_label} 33 | volume_augmentations_path: null 34 | -------------------------------------------------------------------------------- /Mask4Former3D/conf/data/kitti.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | # these parameters are inherited by datasets, data_loaders and collators 4 | # but they might be overwritten 5 | 6 | # splits 7 | train_mode: train 8 | validation_mode: validation 9 | test_mode: test 10 | 11 | # dataset 12 | ignore_label: 255 13 | add_distance: true 14 | in_channels: 2 15 | num_labels: 19 16 | instance_population: ${general.instance_population} 17 | sweep: 2 18 | min_stuff_cls_id: 9 19 | min_points: 50 20 | class_names: ['car', 'bicycle', 'motorcycle', 'truck', 'other-vehicle', 'person', 'bicyclist', 'motorcyclist', 'road', 'parking', 'sidewalk', 'other-ground', 'building', 'fence', 'vegetation', 'trunk', 'terrain', 'pole', 'traffic-sign'] 21 | 22 | # data loader 23 | pin_memory: true 24 | num_workers: 4 25 | batch_size: 4 26 | test_batch_size: 2 27 | 28 | # collation 29 | voxel_size: 0.05 30 | 
-------------------------------------------------------------------------------- /Mask4Former3D/conf/data/kitti3d.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | # these parameters are inherited by datasets, data_loaders and collators 4 | # but they might be overwritten 5 | 6 | # splits 7 | train_mode: train 8 | validation_mode: valid 9 | test_mode: test 10 | 11 | # dataset 12 | ignore_label: 255 13 | add_distance: true 14 | in_channels: 2 15 | num_labels: 19 16 | instance_population: ${general.instance_population} 17 | sweep: 1 18 | min_stuff_cls_id: 9 19 | min_points: 50 20 | class_names: ['car', 'bicycle', 'motorcycle', 'truck', 'other-vehicle', 'person', 'bicyclist', 'motorcyclist', 'road', 'parking', 'sidewalk', 'other-ground', 'building', 'fence', 'vegetation', 'trunk', 'terrain', 'pole', 'traffic-sign'] 21 | 22 | # data loader 23 | pin_memory: true 24 | num_workers: 16 25 | batch_size: 16 26 | test_batch_size: 16 27 | 28 | # collation 29 | voxel_size: 0.05 30 | 31 | base_path: ./data 32 | -------------------------------------------------------------------------------- /Mask4Former3D/conf/logging/full.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 3 | name: ${general.experiment_name} 4 | save_dir: ${general.save_dir} 5 | -------------------------------------------------------------------------------- /Mask4Former3D/conf/loss/set_criterion.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: models.criterion.SetCriterion 3 | num_classes: ${data.num_labels} 4 | eos_coef: 0.1 5 | losses: 6 | - "labels" 7 | - "masks" 8 | - "bboxs" 9 | -------------------------------------------------------------------------------- /Mask4Former3D/conf/matcher/hungarian_matcher.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: models.matcher.HungarianMatcher 3 | cost_class: 2. 4 | cost_mask: 5. 5 | cost_dice: 2. 6 | cost_box: 5. 
7 | -------------------------------------------------------------------------------- /Mask4Former3D/conf/metric/lstq.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: models.metrics.Panoptic4DEval 3 | n_classes: ${data.num_labels} 4 | min_stuff_cls_id: ${data.min_stuff_cls_id} 5 | min_points: ${data.min_points} -------------------------------------------------------------------------------- /Mask4Former3D/conf/metric/pq.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: models.metrics.PanopticEval 3 | n_classes: ${data.num_labels} 4 | ignore_label: ${data.ignore_label} 5 | min_points: ${data.min_points} 6 | -------------------------------------------------------------------------------- /Mask4Former3D/conf/model/mask4former3d.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: models.Mask4Former3D 3 | 4 | # backbone 5 | backbone: 6 | _target_: models.Res16UNet34C 7 | config: 8 | dialations: [ 1, 1, 1, 1 ] 9 | conv1_kernel_size: 5 10 | bn_momentum: 0.02 11 | in_channels: ${data.in_channels} 12 | out_channels: ${data.num_labels} 13 | 14 | # transformer parameters 15 | num_queries: 100 16 | num_heads: 8 17 | num_decoders: 3 18 | num_levels: 4 19 | sample_sizes: [4000, 8000, 16000, 32000] 20 | mask_dim: 128 21 | dim_feedforward: 1024 22 | num_labels: ${data.num_labels} 23 | -------------------------------------------------------------------------------- /Mask4Former3D/conf/optimizer/adamw.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: torch.optim.AdamW 3 | lr: 0.0002 -------------------------------------------------------------------------------- /Mask4Former3D/conf/scheduler/onecyclelr.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | scheduler: 3 | _target_: torch.optim.lr_scheduler.OneCycleLR 4 | max_lr: ${optimizer.lr} 5 | epochs: ${trainer.max_epochs} 6 | # need to set to number because of tensorboard logger 7 | steps_per_epoch: -1 8 | 9 | pytorch_lightning_params: 10 | interval: step 11 | -------------------------------------------------------------------------------- /Mask4Former3D/conf/semantic-kitti.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 
2 | labels: 3 | 0 : "unlabeled" 4 | 1 : "outlier" 5 | 2 : "anomaly" 6 | 10: "car" 7 | 11: "bicycle" 8 | 13: "bus" 9 | 15: "motorcycle" 10 | 16: "on-rails" 11 | 18: "truck" 12 | 20: "other-vehicle" 13 | 30: "person" 14 | 31: "bicyclist" 15 | 32: "motorcyclist" 16 | 40: "road" 17 | 44: "parking" 18 | 48: "sidewalk" 19 | 49: "other-ground" 20 | 50: "building" 21 | 51: "fence" 22 | 52: "other-structure" 23 | 60: "lane-marking" 24 | 70: "vegetation" 25 | 71: "trunk" 26 | 72: "terrain" 27 | 80: "pole" 28 | 81: "traffic-sign" 29 | 99: "other-object" 30 | 252: "moving-car" 31 | 253: "moving-bicyclist" 32 | 254: "moving-person" 33 | 255: "moving-motorcyclist" 34 | 256: "moving-on-rails" 35 | 257: "moving-bus" 36 | 258: "moving-truck" 37 | 259: "moving-other-vehicle" 38 | 39 | color_map: # bgr 40 | 0 : [0, 0, 0] 41 | 1 : [0, 0, 255] 42 | 2 : [255, 0, 0] 43 | 10: [245, 150, 100] 44 | 11: [245, 230, 100] 45 | 13: [250, 80, 100] 46 | 15: [150, 60, 30] 47 | 16: [255, 0, 0] 48 | 18: [180, 30, 80] 49 | 20: [255, 0, 0] 50 | 30: [30, 30, 255] 51 | 31: [200, 40, 255] 52 | 32: [90, 30, 150] 53 | 40: [255, 0, 255] 54 | 44: [255, 150, 255] 55 | 48: [75, 0, 75] 56 | 49: [75, 0, 175] 57 | 50: [0, 200, 255] 58 | 51: [50, 120, 255] 59 | 52: [0, 150, 255] 60 | 60: [170, 255, 150] 61 | 70: [0, 175, 0] 62 | 71: [0, 60, 135] 63 | 72: [80, 240, 150] 64 | 80: [150, 240, 255] 65 | 81: [0, 0, 255] 66 | 99: [255, 255, 50] 67 | 252: [245, 150, 100] 68 | 256: [255, 0, 0] 69 | 253: [200, 40, 255] 70 | 254: [30, 30, 255] 71 | 255: [90, 30, 150] 72 | 257: [250, 80, 100] 73 | 258: [180, 30, 80] 74 | 259: [255, 0, 0] 75 | 76 | content: # as a ratio with the total number of points 77 | 0: 0.018889854628292943 78 | 1: 0.0002937197336781505 79 | 10: 0.040818519255974316 80 | 11: 0.00016609538710764618 81 | 13: 2.7879693665067774e-05 82 | 15: 0.00039838616015114444 83 | 16: 0.0 84 | 18: 0.0020633612104619787 85 | 20: 0.0016218197275284021 86 | 30: 0.00017698551338515307 87 | 31: 1.1065903904919655e-08 88 | 32: 5.532951952459828e-09 89 | 40: 0.1987493871255525 90 | 44: 0.014717169549888214 91 | 48: 0.14392298360372 92 | 49: 0.0039048553037472045 93 | 50: 0.1326861944777486 94 | 51: 0.0723592229456223 95 | 52: 0.002395131480328884 96 | 60: 4.7084144280367186e-05 97 | 70: 0.26681502148037506 98 | 71: 0.006035012012626033 99 | 72: 0.07814222006271769 100 | 80: 0.002855498193863172 101 | 81: 0.0006155958086189918 102 | 99: 0.009923127583046915 103 | 252: 0.001789309418528068 104 | 253: 0.00012709999297008662 105 | 254: 0.00016059776092534436 106 | 255: 3.745553104802113e-05 107 | 256: 0.0 108 | 257: 0.00011351574470342043 109 | 258: 0.00010157861367183268 110 | 259: 4.3840131989471124e-05 111 | 112 | # classes that are indistinguishable from single scan or inconsistent in 113 | # ground truth are mapped to their closest equivalent 114 | learning_map: 115 | 0 : 0 # "unlabeled" 116 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped 117 | 2 : 0 # "anomaly" mapped to "unlabeled" --------------------------mapped 118 | 10: 1 # "car" 119 | 11: 2 # "bicycle" 120 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped 121 | 15: 3 # "motorcycle" 122 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped 123 | 18: 4 # "truck" 124 | 20: 5 # "other-vehicle" 125 | 30: 6 # "person" 126 | 31: 7 # "bicyclist" 127 | 32: 8 # "motorcyclist" 128 | 40: 9 # "road" 129 | 44: 10 # "parking" 130 | 48: 11 # "sidewalk" 131 | 49: 12 # "other-ground" 132 | 50: 13 # "building" 133 | 51: 14 # "fence" 134 | 
52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped 135 | 60: 9 # "lane-marking" to "road" ---------------------------------mapped 136 | 70: 15 # "vegetation" 137 | 71: 16 # "trunk" 138 | 72: 17 # "terrain" 139 | 80: 18 # "pole" 140 | 81: 19 # "traffic-sign" 141 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped 142 | 252: 1 # "moving-car" to "car" ------------------------------------mapped 143 | 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped 144 | 254: 6 # "moving-person" to "person" ------------------------------mapped 145 | 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped 146 | 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped 147 | 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped 148 | 258: 4 # "moving-truck" to "truck" --------------------------------mapped 149 | 259: 5 # "moving-other"-vehicle to "other-vehicle" ----------------mapped 150 | 151 | learning_map_inv: # inverse of previous map 152 | 0: 0 # "unlabeled", and others ignored 153 | 1: 10 # "car" 154 | 2: 11 # "bicycle" 155 | 3: 15 # "motorcycle" 156 | 4: 18 # "truck" 157 | 5: 20 # "other-vehicle" 158 | 6: 30 # "person" 159 | 7: 31 # "bicyclist" 160 | 8: 32 # "motorcyclist" 161 | 9: 40 # "road" 162 | 10: 44 # "parking" 163 | 11: 48 # "sidewalk" 164 | 12: 49 # "other-ground" 165 | 13: 50 # "building" 166 | 14: 51 # "fence" 167 | 15: 70 # "vegetation" 168 | 16: 71 # "trunk" 169 | 17: 72 # "terrain" 170 | 18: 80 # "pole" 171 | 19: 81 # "traffic-sign" 172 | 173 | learning_ignore: # Ignore classes 174 | 0: True # "unlabeled", and others ignored 175 | 1: False # "car" 176 | 2: False # "bicycle" 177 | 3: False # "motorcycle" 178 | 4: False # "truck" 179 | 5: False # "other-vehicle" 180 | 6: False # "person" 181 | 7: False # "bicyclist" 182 | 8: False # "motorcyclist" 183 | 9: False # "road" 184 | 10: False # "parking" 185 | 11: False # "sidewalk" 186 | 12: False # "other-ground" 187 | 13: False # "building" 188 | 14: False # "fence" 189 | 15: False # "vegetation" 190 | 16: False # "trunk" 191 | 17: False # "terrain" 192 | 18: False # "pole" 193 | 19: False # "traffic-sign" 194 | 195 | split: # sequence numbers 196 | train: 197 | # semkitti 198 | - 0 199 | - 1 200 | - 2 201 | - 3 202 | - 4 203 | - 5 204 | - 6 205 | - 7 206 | - 9 207 | - 10 208 | # panoptic-cudal 209 | - 30 210 | - 31 211 | - 36 212 | - 40 213 | - 41 214 | # stu train 215 | - 206 216 | valid: 217 | # semkitti 218 | - 8 219 | # panoptic-cudal 220 | # - 32 221 | # stu val 222 | # - 201 223 | test: 224 | # australia 225 | # - 100 226 | # - 101 227 | # - 102 228 | # - 103 229 | # - 104 230 | # - 105 231 | # - 106 232 | # - 107 233 | # - 108 234 | # - 109 235 | # - 110 236 | # - 111 237 | # - 112 238 | # - 113 239 | # - 114 240 | # - 115 241 | # - 116 242 | # - 117 243 | # - 118 244 | # - 119 245 | # - 120 246 | # - 121 247 | # - 122 248 | # - 123 249 | # - 124 250 | # - 125 251 | # - 126 252 | # - 127 253 | # - 128 254 | # - 129 255 | # - 130 256 | # - 131 257 | # - 132 258 | # - 133 259 | # - 134 260 | # - 135 261 | # - 136 262 | # - 137 263 | # - 138 264 | # - 139 265 | # - 140 266 | # - 141 267 | # - 142 268 | # - 143 269 | # - 144 270 | # - 145 271 | # - 146 272 | # - 147 273 | # - 148 274 | # - 149 275 | # - 150 276 | # - 151 277 | # - 152 278 | # - 153 279 | # - 154 280 | # - 155 281 | # - 156 282 | # - 157 283 | # - 158 284 | # - 159 285 | # - 160 286 | # - 161 287 | # - 162 288 | # - 163 289 | # - 164 290 | # - 165 291 | 
# - 166 292 | # - 167 293 | # - 168 294 | # - 169 295 | - 125 296 | - 137 297 | - 138 298 | - 139 299 | - 140 300 | - 141 301 | - 142 302 | - 143 303 | - 144 304 | - 145 305 | - 146 306 | - 147 307 | - 148 308 | - 149 309 | - 150 310 | - 151 311 | - 152 312 | - 153 313 | - 169 314 | -------------------------------------------------------------------------------- /Mask4Former3D/conf/trainer/trainer30.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | max_epochs: 30 3 | check_val_every_n_epoch: 1 4 | num_sanity_val_steps: 2 5 | -------------------------------------------------------------------------------- /Mask4Former3D/data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kumuji/stu_dataset/3812d768f3634fbb6faeb8b0bfbd5246a9798e93/Mask4Former3D/data/.gitkeep -------------------------------------------------------------------------------- /Mask4Former3D/datasets/lidar.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import volumentations as V 3 | import yaml 4 | from torch.utils.data import Dataset 5 | from pathlib import Path 6 | from typing import List, Optional, Union 7 | from random import random, choice, uniform 8 | 9 | 10 | class LidarDataset(Dataset): 11 | def __init__( 12 | self, 13 | data_dir: Optional[str] = "data/processed/semantic_kitti", 14 | config_path: Optional[str] = "conf/semantic-kitti.yaml", 15 | mode: Optional[str] = "train", 16 | add_distance: Optional[bool] = False, 17 | ignore_label: Optional[Union[int, List[int]]] = 255, 18 | volume_augmentations_path: Optional[str] = None, 19 | instance_population: Optional[int] = 0, 20 | sweep: Optional[int] = 1, 21 | ): 22 | self.mode = mode 23 | self.data_dir = data_dir 24 | self.ignore_label = ignore_label 25 | self.add_distance = add_distance 26 | self.instance_population = instance_population 27 | self.sweep = sweep 28 | self.config = self._load_yaml(config_path) 29 | 30 | # loading database file 31 | database_path = Path(self.data_dir) 32 | if not (database_path / f"{mode}_database.yaml").exists(): 33 | print(f"generate {database_path}/{mode}_database.yaml first") 34 | exit() 35 | self.data = self._load_yaml(database_path / f"{mode}_database.yaml") 36 | 37 | self.label_info = self._select_correct_labels(self.config["learning_ignore"]) 38 | # augmentations 39 | self.volume_augmentations = V.NoOp() 40 | if volume_augmentations_path is not None: 41 | self.volume_augmentations = V.load(volume_augmentations_path, data_format="yaml") 42 | # reformulating in sweeps 43 | data = [[]] 44 | last_scene = self.data[0]["scene"] 45 | for x in self.data: 46 | if x["scene"] == last_scene: 47 | data[-1].append(x) 48 | else: 49 | last_scene = x["scene"] 50 | data.append([x]) 51 | for i in range(len(data)): 52 | data[i] = list(self.chunks(data[i], sweep)) 53 | self.data = [val for sublist in data for val in sublist] 54 | 55 | if instance_population > 0: 56 | self.instance_data = self._load_yaml(database_path / f"{mode}_instances_database.yaml") 57 | 58 | def chunks(self, lst, n): 59 | if "train" in self.mode or n == 1: 60 | for i in range(len(lst) - n + 1): 61 | yield lst[i : i + n] 62 | else: 63 | for i in range(0, len(lst) - n + 1, n - 1): 64 | yield lst[i : i + n] 65 | if i != len(lst) - n: 66 | yield lst[i + n - 1 :] 67 | 68 | def __len__(self): 69 | return len(self.data) 70 | 71 | def __getitem__(self, idx: int): 72 | coordinates_list = [] 
73 | features_list = [] 74 | labels_list = [] 75 | acc_num_points = [0] 76 | for time, scan in enumerate(self.data[idx]): 77 | points = np.fromfile(scan["filepath"], dtype=np.float32).reshape(-1, 4) 78 | coordinates = points[:, :3] 79 | # rotate and translate 80 | pose = np.array(scan["pose"]).T 81 | coordinates = coordinates @ pose[:3, :3] + pose[3, :3] 82 | coordinates_list.append(coordinates) 83 | acc_num_points.append(acc_num_points[-1] + len(coordinates)) 84 | features = points[:, 3:4] 85 | time_array = np.ones((features.shape[0], 1)) * time 86 | features = np.hstack((time_array, features)) 87 | features_list.append(features) 88 | if "test" in self.mode: 89 | labels = np.zeros_like(features).astype(np.int64) 90 | labels_list.append(labels) 91 | else: 92 | panoptic_label = np.fromfile(scan["label_filepath"], dtype=np.uint32) 93 | semantic_label = panoptic_label & 0xFFFF 94 | semantic_label = np.vectorize(self.config["learning_map"].__getitem__)(semantic_label) 95 | labels = np.hstack((semantic_label[:, None], panoptic_label[:, None])) 96 | labels_list.append(labels) 97 | 98 | coordinates = np.vstack(coordinates_list) 99 | features = np.vstack(features_list) 100 | labels = np.vstack(labels_list) 101 | 102 | if "train" in self.mode and self.instance_population > 0: 103 | max_instance_id = np.amax(labels[:, 1]) 104 | pc_center = coordinates.mean(axis=0) 105 | instance_c, instance_f, instance_l = self.populate_instances( 106 | max_instance_id, pc_center, self.instance_population 107 | ) 108 | coordinates = np.vstack((coordinates, instance_c)) 109 | features = np.vstack((features, instance_f)) 110 | labels = np.vstack((labels, instance_l)) 111 | 112 | if self.add_distance: 113 | center_coordinate = coordinates.mean(0) 114 | features = np.hstack( 115 | ( 116 | features, 117 | np.linalg.norm(coordinates - center_coordinate, axis=1)[:, np.newaxis], 118 | ) 119 | ) 120 | 121 | # volume and image augmentations for train 122 | if "train" in self.mode: 123 | coordinates -= coordinates.mean(0) 124 | if 0.5 > random(): 125 | coordinates += np.random.uniform(coordinates.min(0), coordinates.max(0)) / 2 126 | aug = self.volume_augmentations(points=coordinates) 127 | coordinates = aug["points"] 128 | 129 | features = np.hstack((coordinates, features)) 130 | 131 | labels[:, 0] = np.vectorize(self.label_info.__getitem__)(labels[:, 0]) 132 | 133 | return { 134 | "num_points": acc_num_points, 135 | "coordinates": coordinates, 136 | "features": features, 137 | "labels": labels, 138 | "sequence": scan["scene"], 139 | } 140 | 141 | @staticmethod 142 | def _load_yaml(filepath): 143 | with open(filepath) as f: 144 | file = yaml.safe_load(f) 145 | return file 146 | 147 | def _select_correct_labels(self, learning_ignore): 148 | count = 0 149 | label_info = dict() 150 | for k, v in learning_ignore.items(): 151 | if v: 152 | label_info[k] = self.ignore_label 153 | else: 154 | label_info[k] = count 155 | count += 1 156 | return label_info 157 | 158 | def _remap_model_output(self, output): 159 | inv_map = {v: k for k, v in self.label_info.items()} 160 | output = np.vectorize(inv_map.__getitem__)(output) 161 | return output 162 | 163 | def populate_instances(self, max_instance_id, pc_center, instance_population): 164 | coordinates_list = [] 165 | features_list = [] 166 | labels_list = [] 167 | for _ in range(instance_population): 168 | instance_dict = choice(self.instance_data) 169 | idx = np.random.randint(len(instance_dict["filepaths"])) 170 | instance_list = [] 171 | for time in range(self.sweep): 172 | if idx < 
len(instance_dict["filepaths"]): 173 | filepath = instance_dict["filepaths"][idx] 174 | instance = np.load(filepath) 175 | time_array = np.ones((instance.shape[0], 1)) * time 176 | instance = np.hstack((instance[:, :3], time_array, instance[:, 3:4])) 177 | instance_list.append(instance) 178 | idx = idx + 1 179 | instances = np.vstack(instance_list) 180 | coordinates = instances[:, :3] - instances[:, :3].mean(0) 181 | coordinates += pc_center + np.array([uniform(-10, 10), uniform(-10, 10), uniform(-1, 1)]) 182 | features = instances[:, 3:] 183 | semantic_label = instance_dict["semantic_label"] 184 | labels = np.zeros_like(features, dtype=np.int64) 185 | labels[:, 0] = semantic_label 186 | max_instance_id = max_instance_id + 1 187 | labels[:, 1] = max_instance_id 188 | aug = self.volume_augmentations(points=coordinates) 189 | coordinates = aug["points"] 190 | coordinates_list.append(coordinates) 191 | features_list.append(features) 192 | labels_list.append(labels) 193 | return np.vstack(coordinates_list), np.vstack(features_list), np.vstack(labels_list) 194 | -------------------------------------------------------------------------------- /Mask4Former3D/datasets/lidar_no_preprocessing.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from random import choice, random, uniform 3 | from typing import List, Optional, Union 4 | 5 | import numpy as np 6 | import volumentations as V 7 | import yaml 8 | from file_packer import FilePackReader 9 | from torch.utils.data import Dataset 10 | 11 | 12 | class LidarDataset(Dataset): 13 | def __init__( 14 | self, 15 | data_dir: str, 16 | config_path: str, 17 | instance_data_fpack: Optional[str] = None, 18 | mode: Optional[str] = "train", 19 | add_distance: Optional[bool] = False, 20 | ignore_label: Optional[Union[int, List[int]]] = 255, 21 | volume_augmentations_path: Optional[str] = None, 22 | instance_population: Optional[int] = 0, 23 | sweep: Optional[int] = 1, 24 | ): 25 | self.mode = mode 26 | self.data_dir = Path(data_dir) 27 | self.ignore_label = ignore_label 28 | self.add_distance = add_distance 29 | self.instance_population = instance_population 30 | self.sweep = sweep 31 | self.config = self._load_yaml(config_path) 32 | self.label_info = self._select_correct_labels(self.config["learning_ignore"]) 33 | 34 | # Load scenes directly from directory 35 | self.scenes = self._load_scenes() 36 | # And for instance population, I used Ali's filepacker, but it is not necessery 37 | if self.instance_population > 0: 38 | self.instance_data = FilePackReader(instance_data_fpack) 39 | 40 | # Preload all poses for scenes 41 | self.pose_cache = self._preload_poses() 42 | 43 | # Augmentations 44 | self.volume_augmentations = V.NoOp() 45 | if volume_augmentations_path is not None: 46 | self.volume_augmentations = V.load( 47 | volume_augmentations_path, data_format="yaml" 48 | ) 49 | 50 | def _preload_poses(self): 51 | """Load and cache all pose transformations for each frame in each scene.""" 52 | poses = list() 53 | for scene in self.scenes: 54 | # print(scene[0]) 55 | scene_path = Path(scene[0]).parent.parent 56 | calib_file = scene_path / "calib.txt" 57 | pose_file = scene_path / "poses.txt" 58 | calibration = self.parse_calibration(calib_file) 59 | # Load poses for each frame in the scene 60 | poses.append(self.parse_poses(pose_file, calibration)) 61 | return poses 62 | 63 | def _load_scenes(self): 64 | """Load all scenes and their respective frames based on mode.""" 65 | scene_data = [] 66 | 
scene_ids = self.config["split"].get(self.mode, []) 67 | if self.mode != "train": 68 | print(f"{self.mode}: {scene_ids}, {len(scene_ids)}") 69 | for scene_id in scene_ids: 70 | scene_path = self.data_dir / f"{int(scene_id):02}" 71 | print(Path(scene_path / "velodyne")) 72 | frames = list(Path(scene_path / "velodyne").glob("*.bin")) 73 | frames = sorted(frames) 74 | scene_data.append([str(frame) for frame in frames]) 75 | return scene_data 76 | 77 | def __len__(self): 78 | return sum(len(scene) for scene in self.scenes) 79 | 80 | def __getitem__(self, idx: int): 81 | # Locate scene and frame 82 | scene_idx, frame_idx = self._find_scene_frame(idx) 83 | frame_path = Path(self.scenes[scene_idx][frame_idx]) 84 | 85 | # Load point cloud and apply transformation 86 | points = np.fromfile(frame_path, dtype=np.float32).reshape(-1, 4) 87 | coordinates = points[:, :3] 88 | features = points[:, 3:4] 89 | time_array = np.zeros((features.shape[0], 1)) 90 | features = np.hstack((time_array, features)) 91 | 92 | # Apply pose transformation 93 | pose = self.pose_cache[scene_idx][frame_idx] 94 | coordinates = coordinates @ pose[:3, :3] + pose[3, :3] 95 | acc_num_points = [0, len(coordinates)] 96 | 97 | # Get labels 98 | labels = self._load_labels(frame_path) 99 | 100 | # Add instance population if required 101 | if "train" in self.mode and self.instance_population > 0: 102 | max_instance_id = np.amax(labels[:, 1]) 103 | pc_center = coordinates.mean(axis=0) 104 | instance_coords, instance_feats, instance_labels = self.populate_instances( 105 | max_instance_id, pc_center, num_instances=self.instance_population 106 | ) 107 | coordinates = np.vstack((coordinates, instance_coords)) 108 | features = np.vstack((features, instance_feats)) 109 | labels = np.vstack((labels, instance_labels)) 110 | 111 | # Add distance if required 112 | if self.add_distance: 113 | center_coordinate = coordinates.mean(0) 114 | features = np.hstack( 115 | ( 116 | features, 117 | np.linalg.norm(coordinates - center_coordinate, axis=1)[ 118 | :, np.newaxis 119 | ], 120 | ) 121 | ) 122 | 123 | # Apply augmentations 124 | if "train" in self.mode: 125 | coordinates -= coordinates.mean(0) 126 | if 0.5 > random(): 127 | coordinates += ( 128 | np.random.uniform(coordinates.min(0), coordinates.max(0)) / 2 129 | ) 130 | aug = self.volume_augmentations(points=coordinates) 131 | coordinates = aug["points"] 132 | 133 | labels[:, 0] = np.vectorize(self.label_info.__getitem__)(labels[:, 0]) 134 | 135 | return { 136 | "num_points": acc_num_points, 137 | "coordinates": coordinates, 138 | "features": np.hstack((coordinates, features)), 139 | "labels": labels, 140 | "sequence": (str(frame_path.parent.parent.name), str(frame_path.stem)), 141 | } 142 | 143 | def _find_scene_frame(self, idx): 144 | """Determine which scene and frame corresponds to a dataset index.""" 145 | cumulative = 0 146 | for scene_idx, frames in enumerate(self.scenes): 147 | if idx < cumulative + len(frames): 148 | return scene_idx, idx - cumulative 149 | cumulative += len(frames) 150 | raise IndexError("Index out of range") 151 | 152 | def _load_labels(self, frame_path): 153 | """Load and process the labels for a given frame.""" 154 | label_path = ( 155 | str(frame_path).replace("velodyne", "labels").replace(".bin", ".label") 156 | ) 157 | panoptic_label = np.fromfile(label_path, dtype=np.uint32) 158 | semantic_label = panoptic_label & 0xFFFF 159 | semantic_label = np.vectorize(self.config["learning_map"].__getitem__)( 160 | semantic_label 161 | ) 162 | labels = 
np.hstack((semantic_label[:, None], panoptic_label[:, None])) 163 | return labels 164 | 165 | def populate_instances(self, max_instance_id, pc_center, num_instances): 166 | coordinates_list = [] 167 | features_list = [] 168 | labels_list = [] 169 | 170 | # Get all instance directories (assuming the root directory in the pack is named "instances") 171 | instance_dirs = self.instance_data.listdir(self.instance_data.base_path) 172 | for _ in range(num_instances): 173 | # Randomly select an instance directory 174 | instance_dir = choice(instance_dirs) 175 | semantic_label = int(instance_dir.split("_")[1]) # Extract semantic label 176 | # List all bin files within the chosen instance directory 177 | bin_files = sorted( 178 | self.instance_data.listdir( 179 | f"{self.instance_data.base_path}/{instance_dir}" 180 | ) 181 | ) 182 | # Choose a random starting index for the sequence of sweeps 183 | idx = np.random.randint(len(bin_files)) 184 | instance_list = [] 185 | for time in range(self.sweep): 186 | if idx < len(bin_files): 187 | instance_filepath = f"{self.instance_data.base_path}/{instance_dir}/{bin_files[idx]}" 188 | # Read the binary file directly from the file pack 189 | with self.instance_data.open(instance_filepath, mode="rb") as file: 190 | instance = np.frombuffer(file.read(), dtype=np.float32).reshape( 191 | -1, 4 192 | ) 193 | # Add a time dimension to the instance points 194 | time_array = np.ones((instance.shape[0], 1)) * time 195 | instance = np.hstack( 196 | (instance[:, :3], time_array, instance[:, 3:4]) 197 | ) 198 | instance_list.append(instance) 199 | # Increment index for the next timestep in the sweep 200 | idx = idx + 1 201 | 202 | # Aggregate instances from the list into a single array 203 | instances = np.vstack(instance_list) 204 | # Center the coordinates and apply translation 205 | coordinates = instances[:, :3] - instances[:, :3].mean(0) 206 | coordinates += pc_center + np.array( 207 | [uniform(-10, 10), uniform(-10, 10), uniform(-1, 1)] 208 | ) 209 | # Extract features 210 | features = instances[:, 3:] 211 | # Create labels with semantic and instance IDs 212 | labels = np.zeros_like(features, dtype=np.int64) 213 | labels[:, 0] = semantic_label 214 | max_instance_id += 1 215 | labels[:, 1] = max_instance_id 216 | # Apply augmentations if defined 217 | aug = self.volume_augmentations(points=coordinates) 218 | coordinates = aug["points"] 219 | 220 | # Append to output lists 221 | coordinates_list.append(coordinates) 222 | features_list.append(features) 223 | labels_list.append(labels) 224 | return ( 225 | np.vstack(coordinates_list), 226 | np.vstack(features_list), 227 | np.vstack(labels_list), 228 | ) 229 | 230 | @staticmethod 231 | def parse_calibration(filename): 232 | calib = {} 233 | with open(filename) as calib_file: 234 | for line in calib_file: 235 | key, content = line.strip().split(":") 236 | values = [float(v) for v in content.strip().split()] 237 | pose = np.eye(4) 238 | pose[:3, :4] = np.array(values).reshape(3, 4) 239 | calib[key] = pose 240 | return calib 241 | 242 | @staticmethod 243 | def parse_poses(filename, calibration): 244 | Tr = calibration["Tr"] 245 | Tr_inv = np.linalg.inv(Tr) 246 | poses = list() 247 | with open(filename) as file: 248 | for line in file: 249 | values = [float(v) for v in line.strip().split()] 250 | pose = np.eye(4) 251 | pose[:3, :4] = np.array(values).reshape(3, 4) 252 | pose = Tr_inv @ pose @ Tr 253 | poses.append(pose) 254 | return poses 255 | 256 | def _select_correct_labels(self, learning_ignore): 257 | count = 0 258 | 
label_info = dict() 259 | for k, v in learning_ignore.items(): 260 | if v: 261 | label_info[k] = self.ignore_label 262 | else: 263 | label_info[k] = count 264 | count += 1 265 | return label_info 266 | 267 | @staticmethod 268 | def _load_yaml(filepath): 269 | with open(filepath) as f: 270 | return yaml.safe_load(f) 271 | 272 | def _remap_model_output(self, output): 273 | inv_map = {v: k for k, v in self.label_info.items()} 274 | output = np.vectorize(inv_map.__getitem__)(output) 275 | return output 276 | -------------------------------------------------------------------------------- /Mask4Former3D/datasets/preprocessing/semantic_kitti_preprocessing.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | import yaml 4 | from pathlib import Path 5 | from natsort import natsorted 6 | from loguru import logger 7 | from tqdm import tqdm 8 | from fire import Fire 9 | 10 | 11 | class SemanticKittiPreprocessing: 12 | def __init__( 13 | self, 14 | data_dir: str = "/globalwork/data/SemanticKITTI/dataset", 15 | save_dir: str = "/globalwork/yilmaz/data/processed/semantic_kitti", 16 | modes: tuple = ("train", "validation", "test"), 17 | ): 18 | self.data_dir = Path(data_dir) 19 | self.save_dir = Path(save_dir) 20 | self.modes = modes 21 | 22 | if not self.data_dir.exists(): 23 | logger.error("Data folder doesn't exist") 24 | raise FileNotFoundError 25 | if self.save_dir.exists() is False: 26 | self.save_dir.mkdir(parents=True, exist_ok=True) 27 | 28 | self.files = {} 29 | for data_type in self.modes: 30 | self.files.update({data_type: []}) 31 | 32 | self.config = self._load_yaml("conf/semantic-kitti.yaml") 33 | self.create_label_database("conf/semantic-kitti.yaml") 34 | self.pose = dict() 35 | 36 | for mode in self.modes: 37 | scene_mode = "valid" if mode == "validation" else mode 38 | self.pose[mode] = dict() 39 | for scene in sorted(self.config["split"][scene_mode]): 40 | filepaths = list(self.data_dir.glob(f"*/{scene:02}/velodyne/*bin")) 41 | filepaths = [str(file) for file in filepaths] 42 | self.files[mode].extend(natsorted(filepaths)) 43 | calibration = parse_calibration(Path(filepaths[0]).parent.parent / "calib.txt") 44 | self.pose[mode].update( 45 | { 46 | scene: parse_poses( 47 | Path(filepaths[0]).parent.parent / "poses.txt", 48 | calibration, 49 | ), 50 | } 51 | ) 52 | 53 | def preprocess(self): 54 | for mode in self.modes: 55 | database = [] 56 | for filepath in tqdm(self.files[mode], unit="file"): 57 | filebase = self.process_file(filepath, mode) 58 | database.append(filebase) 59 | self.save_database(database, mode) 60 | self.joint_database() 61 | 62 | def make_instance_database(self): 63 | train_database = self._load_yaml(self.save_dir / "train_database.yaml") 64 | instance_database = {} 65 | for sample in tqdm(train_database): 66 | instances = self.extract_instance_from_file(sample) 67 | for instance in instances: 68 | scene = instance["scene"] 69 | panoptic_label = instance["panoptic_label"] 70 | unique_identifier = f"{scene}_{panoptic_label}" 71 | if unique_identifier in instance_database: 72 | instance_database[unique_identifier]["filepaths"].append(instance["instance_filepath"]) 73 | else: 74 | instance_database[unique_identifier] = { 75 | "semantic_label": instance["semantic_label"], 76 | "filepaths": [instance["instance_filepath"]], 77 | } 78 | self.save_database(list(instance_database.values()), "train_instances") 79 | 80 | validation_database = self._load_yaml(self.save_dir / "validation_database.yaml") 81 | 
for sample in tqdm(validation_database): 82 | instances = self.extract_instance_from_file(sample) 83 | for instance in instances: 84 | scene = instance["scene"] 85 | panoptic_label = instance["panoptic_label"] 86 | unique_identifier = f"{scene}_{panoptic_label}" 87 | if unique_identifier in instance_database: 88 | instance_database[unique_identifier]["filepaths"].append(instance["instance_filepath"]) 89 | else: 90 | instance_database[unique_identifier] = { 91 | "semantic_label": instance["semantic_label"], 92 | "filepaths": [instance["instance_filepath"]], 93 | } 94 | self.save_database(list(instance_database.values()), "trainval_instances") 95 | 96 | def extract_instance_from_file(self, sample): 97 | points = np.fromfile(sample["filepath"], dtype=np.float32).reshape(-1, 4) 98 | pose = np.array(sample["pose"]).T 99 | points[:, :3] = points[:, :3] @ pose[:3, :3] + pose[3, :3] 100 | label = np.fromfile(sample["label_filepath"], dtype=np.uint32) 101 | scene, sub_scene = re.search(r"(\d{2,3}).*(\d{6})", sample["filepath"]).group(1, 2) 102 | file_instances = [] 103 | for panoptic_label in np.unique(label): 104 | semantic_label = panoptic_label & 0xFFFF 105 | semantic_label = np.vectorize(self.config["learning_map"].__getitem__)(semantic_label) 106 | if np.isin(semantic_label, range(1, 9)): 107 | instance_mask = label == panoptic_label 108 | instance_points = points[instance_mask, :] 109 | filename = f"{scene}_{panoptic_label:010d}_{sub_scene}.npy" 110 | instance_filepath = self.save_dir / "instances" / filename 111 | instance = { 112 | "scene": scene, 113 | "sub_scene": sub_scene, 114 | "panoptic_label": f"{panoptic_label:010d}", 115 | "instance_filepath": str(instance_filepath), 116 | "semantic_label": semantic_label.item(), 117 | } 118 | if not instance_filepath.parent.exists(): 119 | instance_filepath.parent.mkdir(parents=True, exist_ok=True) 120 | np.save(instance_filepath, instance_points.astype(np.float32)) 121 | file_instances.append(instance) 122 | return file_instances 123 | 124 | def save_database(self, database, mode): 125 | for element in database: 126 | self._dict_to_yaml(element) 127 | self._save_yaml(self.save_dir / (mode + "_database.yaml"), database) 128 | 129 | def joint_database(self, train_modes=["train", "validation"]): 130 | joint_db = [] 131 | for mode in train_modes: 132 | joint_db.extend(self._load_yaml(self.save_dir / (mode + "_database.yaml"))) 133 | self._save_yaml(self.save_dir / "trainval_database.yaml", joint_db) 134 | 135 | @classmethod 136 | def _save_yaml(cls, path, file): 137 | with open(path, "w") as f: 138 | yaml.safe_dump(file, f, default_style=None, default_flow_style=False) 139 | 140 | @classmethod 141 | def _dict_to_yaml(cls, dictionary): 142 | if not isinstance(dictionary, dict): 143 | return 144 | for k, v in dictionary.items(): 145 | if isinstance(v, dict): 146 | cls._dict_to_yaml(v) 147 | if isinstance(v, np.ndarray): 148 | dictionary[k] = v.tolist() 149 | if isinstance(v, Path): 150 | dictionary[k] = str(v) 151 | 152 | @classmethod 153 | def _load_yaml(cls, filepath): 154 | with open(filepath) as f: 155 | file = yaml.safe_load(f) 156 | return file 157 | 158 | def create_label_database(self, config_file): 159 | if (self.save_dir / "label_database.yaml").exists(): 160 | return self._load_yaml(self.save_dir / "label_database.yaml") 161 | config = self._load_yaml(config_file) 162 | label_database = {} 163 | for key, old_key in config["learning_map_inv"].items(): 164 | label_database.update( 165 | { 166 | key: { 167 | "name": config["labels"][old_key], 
168 | "color": config["color_map"][old_key][::-1], 169 | "validation": not config["learning_ignore"][key], 170 | } 171 | } 172 | ) 173 | 174 | self._save_yaml(self.save_dir / "label_database.yaml", label_database) 175 | return label_database 176 | 177 | def process_file(self, filepath, mode): 178 | scene, sub_scene = re.search(r"(\d{2,3}).*(\d{6})", filepath).group(1, 2) 179 | sample = { 180 | "filepath": filepath, 181 | "scene": int(scene), 182 | "pose": self.pose[mode][int(scene)][int(sub_scene)].tolist(), 183 | } 184 | 185 | if mode in ["train", "validation"]: 186 | # getting label info 187 | label_filepath = filepath.replace("velodyne", "labels").replace("bin", "label") 188 | sample["label_filepath"] = label_filepath 189 | return sample 190 | 191 | 192 | def parse_calibration(filename): 193 | calib = {} 194 | 195 | with open(filename) as calib_file: 196 | for line in calib_file: 197 | key, content = line.strip().split(":") 198 | values = [float(v) for v in content.strip().split()] 199 | 200 | pose = np.zeros((4, 4)) 201 | pose[0, 0:4] = values[0:4] 202 | pose[1, 0:4] = values[4:8] 203 | pose[2, 0:4] = values[8:12] 204 | pose[3, 3] = 1.0 205 | 206 | calib[key] = pose 207 | return calib 208 | 209 | 210 | def parse_poses(filename, calibration): 211 | poses = [] 212 | 213 | Tr = calibration["Tr"] 214 | Tr_inv = np.linalg.inv(Tr) 215 | 216 | with open(filename) as file: 217 | for line in file: 218 | values = [float(v) for v in line.strip().split()] 219 | 220 | pose = np.zeros((4, 4)) 221 | pose[0, 0:4] = values[0:4] 222 | pose[1, 0:4] = values[4:8] 223 | pose[2, 0:4] = values[8:12] 224 | pose[3, 3] = 1.0 225 | 226 | poses.append(np.matmul(Tr_inv, np.matmul(pose, Tr))) 227 | 228 | return poses 229 | 230 | 231 | if __name__ == "__main__": 232 | Fire(SemanticKittiPreprocessing) 233 | -------------------------------------------------------------------------------- /Mask4Former3D/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import MinkowskiEngine as ME 2 | import numpy as np 3 | import torch 4 | 5 | 6 | class VoxelizeCollate: 7 | def __init__( 8 | self, 9 | ignore_label=255, 10 | voxel_size=1, 11 | ): 12 | self.voxel_size = voxel_size 13 | self.ignore_label = ignore_label 14 | 15 | def __call__(self, batch): 16 | (coordinates, features, labels, original_labels, inverse_maps, num_points, sequences) = ( 17 | [], 18 | [], 19 | [], 20 | [], 21 | [], 22 | [], 23 | [], 24 | ) 25 | 26 | for sample in batch: 27 | original_labels.append(sample["labels"]) 28 | num_points.append(sample["num_points"]) 29 | sequences.append(sample["sequence"]) 30 | sample_c, sample_f, sample_l, inverse_map = voxelize( 31 | sample["coordinates"], sample["features"], sample["labels"], self.voxel_size 32 | ) 33 | inverse_maps.append(inverse_map) 34 | coordinates.append(sample_c) 35 | features.append(sample_f) 36 | labels.append(sample_l) 37 | 38 | # Concatenate all lists 39 | target = generate_target(features, labels, self.ignore_label) 40 | coordinates, features = ME.utils.sparse_collate(coordinates, features) 41 | raw_coordinates = features[:, :4] 42 | features = features[:, 4:] 43 | 44 | return ( 45 | NoGpu( 46 | coordinates, features, raw_coordinates, original_labels, inverse_maps, num_points, sequences 47 | ), 48 | target, 49 | ) 50 | 51 | 52 | def voxelize(coordinates, features, labels, voxel_size): 53 | if coordinates.shape[1] == 4: 54 | voxel_size = np.array([voxel_size, voxel_size, voxel_size, 1]) 55 | sample_c, sample_f, unique_map, inverse_map = 
ME.utils.sparse_quantize( 56 | coordinates=coordinates, 57 | features=features, 58 | return_index=True, 59 | return_inverse=True, 60 | quantization_size=voxel_size, 61 | ) 62 | sample_c = sample_c 63 | sample_f = torch.from_numpy(sample_f).float() 64 | sample_l = torch.from_numpy(labels[unique_map]) 65 | return sample_c, sample_f, sample_l, inverse_map 66 | 67 | 68 | def generate_target(features, labels, ignore_label): 69 | target = [] 70 | 71 | for feat, lb in zip(features, labels): 72 | raw_coords = feat[:, :3] 73 | raw_coords = (raw_coords - raw_coords.min(0)[0]) / (raw_coords.max(0)[0] - raw_coords.min(0)[0]) 74 | mask_labels = [] 75 | binary_masks = [] 76 | bboxs = [] 77 | 78 | panoptic_labels = lb[:, 1].unique() 79 | for panoptic_label in panoptic_labels: 80 | mask = lb[:, 1] == panoptic_label 81 | 82 | if panoptic_label == 0: 83 | continue 84 | 85 | sem_labels = lb[mask, 0] 86 | if sem_labels[0] != ignore_label: 87 | mask_labels.append(sem_labels[0]) 88 | binary_masks.append(mask) 89 | mask_coords = raw_coords[mask, :] 90 | bboxs.append( 91 | torch.hstack( 92 | ( 93 | mask_coords.mean(0), 94 | mask_coords.max(0)[0] - mask_coords.min(0)[0], 95 | ) 96 | ) 97 | ) 98 | 99 | if len(mask_labels) != 0: 100 | mask_labels = torch.stack(mask_labels) 101 | binary_masks = torch.stack(binary_masks) 102 | bboxs = torch.stack(bboxs) 103 | target.append({"labels": mask_labels, "masks": binary_masks, "bboxs": bboxs}) 104 | 105 | return target 106 | 107 | 108 | class NoGpu: 109 | def __init__( 110 | self, 111 | coordinates, 112 | features, 113 | raw_coordinates, 114 | original_labels=None, 115 | inverse_maps=None, 116 | num_points=None, 117 | sequences=None, 118 | ): 119 | """helper class to prevent gpu loading on lightning""" 120 | self.coordinates = coordinates 121 | self.features = features 122 | self.raw_coordinates = raw_coordinates 123 | self.original_labels = original_labels 124 | self.inverse_maps = inverse_maps 125 | self.num_points = num_points 126 | self.sequences = sequences 127 | -------------------------------------------------------------------------------- /Mask4Former3D/main_panoptic.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import hydra 4 | import torch 5 | from dotenv import load_dotenv 6 | from omegaconf import DictConfig, OmegaConf 7 | from trainer.pq_trainer import PanopticSegmentation 8 | from utils.utils import flatten_dict, RegularCheckpointing 9 | from pytorch_lightning import Trainer, seed_everything 10 | 11 | 12 | def get_parameters(cfg: DictConfig): 13 | logger = logging.getLogger(__name__) 14 | load_dotenv(".env") 15 | 16 | # parsing input parameters 17 | seed_everything(cfg.general.seed) 18 | 19 | # getting basic configuration 20 | if cfg.general.get("gpus", None) is None: 21 | cfg.general.gpus = os.environ.get("CUDA_VISIBLE_DEVICES", None) 22 | loggers = [] 23 | 24 | if not os.path.exists(cfg.general.save_dir): 25 | os.makedirs(cfg.general.save_dir) 26 | else: 27 | print("EXPERIMENT ALREADY EXIST") 28 | cfg.general.ckpt_path = f"{cfg.general.save_dir}/last-epoch.ckpt" 29 | 30 | for log in cfg.logging: 31 | print(log) 32 | loggers.append(hydra.utils.instantiate(log)) 33 | 34 | model = PanopticSegmentation(cfg) 35 | 36 | logger.info(flatten_dict(OmegaConf.to_container(cfg, resolve=True))) 37 | return cfg, model, loggers 38 | 39 | 40 | @hydra.main(config_path="conf", config_name="config_panoptic_3d.yaml") 41 | def train(cfg: DictConfig): 42 | os.chdir(hydra.utils.get_original_cwd()) 43 | cfg, 
model, loggers = get_parameters(cfg) 44 | callbacks = [] 45 | for cb in cfg.callbacks: 46 | callbacks.append(hydra.utils.instantiate(cb)) 47 | 48 | callbacks.append(RegularCheckpointing()) 49 | # torch.use_deterministic_algorithms(True) 50 | runner = Trainer( 51 | logger=loggers, 52 | accelerator="gpu", 53 | devices=1, 54 | callbacks=callbacks, 55 | default_root_dir=str(cfg.general.save_dir), 56 | **cfg.trainer, 57 | ) 58 | runner.fit(model, ckpt_path=cfg.general.ckpt_path) 59 | 60 | 61 | @hydra.main(config_path="conf", config_name="config_panoptic_3d.yaml") 62 | def validate(cfg: DictConfig): 63 | # because hydra wants to change dir for some reason 64 | os.chdir(hydra.utils.get_original_cwd()) 65 | cfg, model, loggers = get_parameters(cfg) 66 | runner = Trainer( 67 | logger=loggers, 68 | accelerator="gpu", 69 | devices=1, 70 | default_root_dir=str(cfg.general.save_dir), 71 | ) 72 | runner.validate(model=model, ckpt_path=cfg.general.ckpt_path) 73 | 74 | 75 | @hydra.main(config_path="conf", config_name="config_panoptic_3d.yaml") 76 | def test(cfg: DictConfig): 77 | # because hydra wants to change dir for some reason 78 | os.chdir(hydra.utils.get_original_cwd()) 79 | cfg, model, loggers = get_parameters(cfg) 80 | runner = Trainer( 81 | logger=loggers, 82 | accelerator="gpu", 83 | devices=1, 84 | default_root_dir=str(cfg.general.save_dir), 85 | ) 86 | runner.test(model=model, ckpt_path=cfg.general.ckpt_path) 87 | 88 | 89 | @hydra.main(config_path="conf", config_name="config_panoptic_3d.yaml") 90 | def main(cfg: DictConfig): 91 | if cfg["general"]["mode"] == "train": 92 | train(cfg) 93 | elif cfg["general"]["mode"] == "validate": 94 | validate(cfg) 95 | else: 96 | test(cfg) 97 | 98 | 99 | if __name__ == "__main__": 100 | main() 101 | -------------------------------------------------------------------------------- /Mask4Former3D/models/__init__.py: -------------------------------------------------------------------------------- 1 | import models.resunet as resunet 2 | import models.res16unet as res16unet 3 | from models.res16unet import Res16UNet34C, STRes16UNet34C 4 | from models.mask4former import Mask4Former, Mask4Former3D 5 | 6 | MODELS = [] 7 | 8 | 9 | def add_models(module): 10 | MODELS.extend([getattr(module, a) for a in dir(module) if "Net" in a]) 11 | 12 | 13 | add_models(resunet) 14 | add_models(res16unet) 15 | 16 | 17 | def get_models(): 18 | """Returns a tuple of sample models.""" 19 | return MODELS 20 | 21 | 22 | def load_model(name): 23 | """Creates and returns an instance of the model given its class name.""" 24 | # Find the model class from its name 25 | all_models = get_models() 26 | mdict = {model.__name__: model for model in all_models} 27 | if name not in mdict: 28 | print("Invalid model index. 
Options are:") 29 | # Display a list of valid model names 30 | for model in all_models: 31 | print(f"\t* {model.__name__}") 32 | return None 33 | NetClass = mdict[name] 34 | 35 | return NetClass 36 | -------------------------------------------------------------------------------- /Mask4Former3D/models/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | def dice_loss(inputs: torch.Tensor, targets: torch.Tensor, num_masks: float): 7 | inputs = inputs.sigmoid() 8 | inputs = inputs.flatten(1) 9 | numerator = 2 * (inputs * targets).sum(-1) 10 | denominator = inputs.sum(-1) + targets.sum(-1) 11 | loss = 1 - (numerator + 1) / (denominator + 1) 12 | return loss.sum() / num_masks 13 | 14 | 15 | dice_loss_jit = torch.jit.script(dice_loss) # type: torch.jit.ScriptModule 16 | 17 | 18 | def sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor, num_masks: float): 19 | loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 20 | return loss.mean(1).sum() / num_masks 21 | 22 | 23 | sigmoid_ce_loss_jit = torch.jit.script(sigmoid_ce_loss) # type: torch.jit.ScriptModule 24 | 25 | 26 | def box_loss(inputs: torch.Tensor, targets: torch.Tensor, num_bboxs: float): 27 | loss = F.l1_loss(inputs, targets, reduction="none") 28 | return loss.mean(1).sum() / num_bboxs 29 | 30 | 31 | box_loss_jit = torch.jit.script(box_loss) # type: torch.jit.ScriptModule 32 | 33 | 34 | class SetCriterion(nn.Module): 35 | """This class computes the loss for DETR. 36 | The process happens in two steps: 37 | 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 38 | 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) 39 | """ 40 | 41 | def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses): 42 | """Create the criterion. 43 | Parameters: 44 | num_classes: number of object categories, omitting the special no-object category 45 | matcher: module able to compute a matching between targets and proposals 46 | weight_dict: dict containing as key the names of the losses and as values their relative weight. 47 | eos_coef: relative classification weight applied to the no-object category 48 | losses: list of all the losses to be applied. See get_loss for list of available losses. 
49 | """ 50 | super().__init__() 51 | self.num_classes = num_classes 52 | self.matcher = matcher 53 | self.weight_dict = weight_dict 54 | self.eos_coef = eos_coef 55 | self.losses = losses 56 | empty_weight = torch.ones(num_classes + 1) 57 | empty_weight[-1] = self.eos_coef 58 | 59 | # self.register_buffer("empty_weight", empty_weight) 60 | self.register_buffer("ce_class_weights", empty_weight) 61 | 62 | def loss_labels(self, outputs, targets, indices): 63 | src_logits = outputs["pred_logits"].float() 64 | 65 | idx = self._get_src_permutation_idx(indices) 66 | target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) 67 | target_classes = torch.full( 68 | src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device 69 | ) 70 | target_classes[idx] = target_classes_o 71 | 72 | loss_ce = F.cross_entropy( 73 | src_logits.transpose(1, 2), target_classes, self.ce_class_weights, ignore_index=255 74 | # src_logits.transpose(1, 2), target_classes, self.empty_weight, ignore_index=255 75 | ) 76 | losses = {"loss_ce": loss_ce} 77 | return losses 78 | 79 | def loss_masks(self, outputs, targets, indices): 80 | loss_masks = [] 81 | loss_dices = [] 82 | 83 | for batch_id, (map_id, target_id) in enumerate(indices): 84 | map = outputs["pred_masks"][batch_id][:, map_id].T 85 | target_mask = targets[batch_id]["masks"][target_id].float() 86 | num_masks = target_mask.shape[0] 87 | 88 | loss_masks.append(sigmoid_ce_loss_jit(map, target_mask, num_masks)) 89 | loss_dices.append(dice_loss_jit(map, target_mask, num_masks)) 90 | return { 91 | "loss_mask": torch.sum(torch.stack(loss_masks)), 92 | "loss_dice": torch.sum(torch.stack(loss_dices)), 93 | } 94 | 95 | def loss_bboxs(self, outputs, targets, indices): 96 | loss_box = torch.zeros(1, device=outputs["pred_bboxs"].device) 97 | for batch_id, (map_id, target_id) in enumerate(indices): 98 | pred_bboxs = outputs["pred_bboxs"][batch_id, map_id, :] 99 | target_bboxs = targets[batch_id]["bboxs"][target_id] 100 | target_classes = targets[batch_id]["labels"][target_id] 101 | keep_things = target_classes < 8 102 | if torch.any(keep_things): 103 | target_bboxs = target_bboxs[keep_things] 104 | pred_bboxs = pred_bboxs[keep_things] 105 | num_bboxs = target_bboxs.shape[0] 106 | loss_box += box_loss_jit(pred_bboxs, target_bboxs, num_bboxs) 107 | return { 108 | "loss_box": loss_box, 109 | } 110 | 111 | def _get_src_permutation_idx(self, indices): 112 | # permute predictions following indices 113 | batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) 114 | src_idx = torch.cat([src for (src, _) in indices]) 115 | return batch_idx, src_idx 116 | 117 | def _get_tgt_permutation_idx(self, indices): 118 | # permute targets following indices 119 | batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) 120 | tgt_idx = torch.cat([tgt for (_, tgt) in indices]) 121 | return batch_idx, tgt_idx 122 | 123 | def get_loss(self, loss, outputs, targets, indices): 124 | loss_map = {"labels": self.loss_labels, "masks": self.loss_masks, "bboxs": self.loss_bboxs} 125 | return loss_map[loss](outputs, targets, indices) 126 | 127 | def forward(self, outputs, targets): 128 | """This performs the loss computation. 129 | Parameters: 130 | outputs: dict of tensors, see the output specification of the model for the format 131 | targets: list of dicts, such that len(targets) == batch_size. 
132 | The expected keys in each dict depends on the losses applied, see each loss' doc 133 | """ 134 | outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"} 135 | 136 | # Retrieve the matching between the outputs of the last layer and the targets 137 | indices = self.matcher(outputs_without_aux, targets) 138 | 139 | # Compute all the requested losses 140 | losses = {} 141 | for loss in self.losses: 142 | losses.update(self.get_loss(loss, outputs, targets, indices)) 143 | 144 | # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 145 | if "aux_outputs" in outputs: 146 | for i, aux_outputs in enumerate(outputs["aux_outputs"]): 147 | indices = self.matcher(aux_outputs, targets) 148 | for loss in self.losses: 149 | l_dict = self.get_loss(loss, aux_outputs, targets, indices) 150 | l_dict = {k + f"_{i}": v for k, v in l_dict.items()} 151 | losses.update(l_dict) 152 | 153 | return losses 154 | -------------------------------------------------------------------------------- /Mask4Former3D/models/mask4former.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import MinkowskiEngine.MinkowskiOps as me 3 | import torch 4 | import torch.nn as nn 5 | from MinkowskiEngine.MinkowskiPooling import MinkowskiAvgPooling 6 | from models.modules.attention import CrossAttentionLayer, FFNLayer, SelfAttentionLayer 7 | from models.modules.common import conv 8 | from models.modules.helpers_3detr import GenericMLP 9 | from models.position_embedding import PositionEmbeddingCoordsSine 10 | from pytorch3d.ops import sample_farthest_points 11 | from torch.cuda.amp import autocast 12 | 13 | 14 | class Mask4Former(nn.Module): 15 | def __init__( 16 | self, 17 | backbone, 18 | num_queries, 19 | num_heads, 20 | num_decoders, 21 | num_levels, 22 | sample_sizes, 23 | mask_dim, 24 | dim_feedforward, 25 | num_labels, 26 | ): 27 | super().__init__() 28 | self.backbone = hydra.utils.instantiate(backbone) 29 | self.num_queries = num_queries 30 | self.num_heads = num_heads 31 | self.num_decoders = num_decoders 32 | self.num_levels = num_levels 33 | self.sample_sizes = sample_sizes 34 | sizes = self.backbone.PLANES[-5:] 35 | 36 | self.point_features_head = conv( 37 | self.backbone.PLANES[7], mask_dim, kernel_size=1, stride=1, bias=True, D=3 38 | ) 39 | 40 | self.query_projection = GenericMLP( 41 | input_dim=mask_dim, 42 | hidden_dims=[mask_dim], 43 | output_dim=mask_dim, 44 | use_conv=True, 45 | output_use_activation=True, 46 | hidden_use_bias=True, 47 | ) 48 | 49 | self.mask_embed_head = nn.Sequential( 50 | nn.Linear(mask_dim, mask_dim), nn.ReLU(), nn.Linear(mask_dim, mask_dim) 51 | ) 52 | 53 | self.bbox_embed_head = nn.Sequential( 54 | nn.Linear(mask_dim, mask_dim), 55 | nn.ReLU(), 56 | nn.Linear(mask_dim, mask_dim), 57 | nn.ReLU(), 58 | nn.Linear(mask_dim, 6), 59 | nn.Sigmoid(), 60 | ) 61 | 62 | self.class_embed_head = nn.Linear(mask_dim, num_labels + 1) 63 | self.pos_enc = PositionEmbeddingCoordsSine(d_pos=mask_dim) 64 | self.temporal_pos_enc = PositionEmbeddingCoordsSine(d_in=1, d_pos=mask_dim) 65 | self.pooling = MinkowskiAvgPooling(kernel_size=2, stride=2, dimension=3) 66 | 67 | self.cross_attention = nn.ModuleList() 68 | self.self_attention = nn.ModuleList() 69 | self.ffn_attention = nn.ModuleList() 70 | self.lin_squeeze = nn.ModuleList() 71 | 72 | for hlevel in range(self.num_levels): 73 | self.cross_attention.append( 74 | CrossAttentionLayer( 75 | d_model=mask_dim, 76 | nhead=self.num_heads, 77 | ) 78 
| ) 79 | self.lin_squeeze.append(nn.Linear(sizes[hlevel], mask_dim)) 80 | self.self_attention.append( 81 | SelfAttentionLayer( 82 | d_model=mask_dim, 83 | nhead=self.num_heads, 84 | ) 85 | ) 86 | self.ffn_attention.append( 87 | FFNLayer( 88 | d_model=mask_dim, 89 | dim_feedforward=dim_feedforward, 90 | ) 91 | ) 92 | 93 | self.decoder_norm = nn.LayerNorm(mask_dim) 94 | 95 | def forward(self, x, raw_coordinates=None, is_eval=False): 96 | device = x.device 97 | all_features = self.backbone(x) 98 | point_features = self.point_features_head(all_features[-1]) 99 | 100 | with torch.no_grad(): 101 | coordinates = me.SparseTensor( 102 | features=raw_coordinates, coordinates=x.C, device=device 103 | ) 104 | pos_encodings_pcd = self.get_pos_encs(coordinates) 105 | 106 | sampled_coords = [] 107 | mins = [] 108 | maxs = [] 109 | for coords, feats in zip( 110 | x.decomposed_coordinates, coordinates.decomposed_features 111 | ): 112 | _, fps_idx = sample_farthest_points( 113 | coords[None, ...].float(), K=self.num_queries 114 | ) 115 | sampled_coords.append(feats[fps_idx.squeeze(0).long(), :3]) 116 | mins.append(feats[:, :3].min(dim=0)[0]) 117 | maxs.append(feats[:, :3].max(dim=0)[0]) 118 | 119 | sampled_coords = torch.stack(sampled_coords) 120 | mins = torch.stack(mins) 121 | maxs = torch.stack(maxs) 122 | 123 | query_pos = self.pos_enc(sampled_coords.float(), input_range=[mins, maxs]) 124 | query_pos = self.query_projection(query_pos) 125 | 126 | queries = torch.zeros_like(query_pos).permute((0, 2, 1)) 127 | query_pos = query_pos.permute((2, 0, 1)) 128 | 129 | predictions_class = [] 130 | predictions_bbox = [] 131 | predictions_mask = [] 132 | 133 | for _ in range(self.num_decoders): 134 | for hlevel in range(self.num_levels): 135 | output_class, outputs_bbox, outputs_mask, attn_mask = self.mask_module( 136 | queries, point_features, self.num_levels - hlevel 137 | ) 138 | 139 | decomposed_feat = all_features[hlevel].decomposed_features 140 | decomposed_attn = attn_mask.decomposed_features 141 | 142 | pcd_sizes = [pcd.shape[0] for pcd in decomposed_feat] 143 | curr_sample_size = max(pcd_sizes) 144 | 145 | if not is_eval: 146 | curr_sample_size = min(curr_sample_size, self.sample_sizes[hlevel]) 147 | 148 | rand_idx, mask_idx = self.get_random_samples( 149 | pcd_sizes, curr_sample_size, device 150 | ) 151 | 152 | batched_feat = torch.stack( 153 | [feat[idx, :] for feat, idx in zip(decomposed_feat, rand_idx)] 154 | ) 155 | 156 | batched_attn = torch.stack( 157 | [attn[idx, :] for attn, idx in zip(decomposed_attn, rand_idx)] 158 | ) 159 | 160 | batched_pos_enc = torch.stack( 161 | [ 162 | pos_enc[idx, :] 163 | for pos_enc, idx in zip(pos_encodings_pcd[hlevel], rand_idx) 164 | ] 165 | ) 166 | 167 | batched_attn.permute((0, 2, 1))[ 168 | batched_attn.sum(1) == curr_sample_size 169 | ] = False 170 | 171 | m = torch.stack(mask_idx) 172 | batched_attn = torch.logical_or(batched_attn, m[..., None]) 173 | 174 | src_pcd = self.lin_squeeze[hlevel](batched_feat.permute((1, 0, 2))) 175 | 176 | output = self.cross_attention[ 177 | hlevel 178 | ]( 179 | queries.permute((1, 0, 2)), 180 | src_pcd, 181 | memory_mask=batched_attn.repeat_interleave( 182 | self.num_heads, dim=0 183 | ).permute((0, 2, 1)), 184 | memory_key_padding_mask=None, # here we do not apply masking on padded region 185 | pos=batched_pos_enc.permute((1, 0, 2)), 186 | query_pos=query_pos, 187 | ) 188 | 189 | output = self.self_attention[hlevel]( 190 | output, 191 | tgt_mask=None, 192 | tgt_key_padding_mask=None, 193 | query_pos=query_pos, 194 | ) 195 | 196 
| # FFN 197 | queries = self.ffn_attention[hlevel](output).permute((1, 0, 2)) 198 | 199 | predictions_class.append(output_class) 200 | predictions_bbox.append(outputs_bbox) 201 | predictions_mask.append(outputs_mask) 202 | 203 | output_class, outputs_bbox, outputs_mask = self.mask_module( 204 | queries, point_features 205 | ) 206 | predictions_class.append(output_class) 207 | predictions_bbox.append(outputs_bbox) 208 | predictions_mask.append(outputs_mask) 209 | 210 | return { 211 | "pred_logits": predictions_class[-1], 212 | "pred_bboxs": predictions_bbox[-1], 213 | "pred_masks": predictions_mask[-1], 214 | "aux_outputs": self._set_aux_loss( 215 | predictions_class, predictions_bbox, predictions_mask 216 | ), 217 | } 218 | 219 | def mask_module(self, query_feat, point_features, num_pooling_steps=0): 220 | query_feat = self.decoder_norm(query_feat) 221 | mask_embed = self.mask_embed_head(query_feat) 222 | outputs_class = self.class_embed_head(query_feat) 223 | outputs_bbox = self.bbox_embed_head(query_feat) 224 | 225 | output_masks = [] 226 | 227 | for feat, embed in zip(point_features.decomposed_features, mask_embed): 228 | output_masks.append(feat @ embed.T) 229 | 230 | output_masks = torch.cat(output_masks) 231 | outputs_mask = me.SparseTensor( 232 | features=output_masks, 233 | coordinate_manager=point_features.coordinate_manager, 234 | coordinate_map_key=point_features.coordinate_map_key, 235 | ) 236 | 237 | if num_pooling_steps != 0: 238 | attn_mask = outputs_mask 239 | for _ in range(num_pooling_steps): 240 | attn_mask = self.pooling(attn_mask.float()) 241 | 242 | attn_mask = me.SparseTensor( 243 | features=(attn_mask.F.detach().sigmoid() < 0.5), 244 | coordinate_manager=attn_mask.coordinate_manager, 245 | coordinate_map_key=attn_mask.coordinate_map_key, 246 | ) 247 | 248 | return ( 249 | outputs_class, 250 | outputs_bbox, 251 | outputs_mask.decomposed_features, 252 | attn_mask, 253 | ) 254 | 255 | return outputs_class, outputs_bbox, outputs_mask.decomposed_features 256 | 257 | def get_pos_encs(self, coordinates): 258 | pos_encodings_pcd = [] 259 | 260 | for _ in range(self.num_levels + 1): 261 | pos_encodings_pcd.append([]) 262 | 263 | for coords_batch in coordinates.decomposed_features: 264 | scene_min = coords_batch.min(dim=0)[0][None, ...] 265 | scene_max = coords_batch.max(dim=0)[0][None, ...] 
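# The xyz channels are encoded relative to the per-scan extent computed just
# above ([scene_min, scene_max]), and the 4th (time) channel receives its own
# 1-D encoding; the two are summed so every point carries a joint
# spatio-temporal position. Rough shape sketch (illustrative, assuming N
# points and mask_dim D): self.pos_enc returns [1, D, N], which becomes
# [N, D] after the squeeze(0).permute((1, 0)) below.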
266 | 267 | with autocast(enabled=False): 268 | tmp = self.pos_enc( 269 | coords_batch[None, :, :3].float(), 270 | input_range=[scene_min[:, :3], scene_max[:, :3]], 271 | ) 272 | tmp += self.temporal_pos_enc( 273 | coords_batch[None, :, 3].float(), 274 | input_range=[scene_min[:, 3:4], scene_max[:, 3:4]], 275 | ) 276 | 277 | pos_encodings_pcd[-1].append(tmp.squeeze(0).permute((1, 0))) 278 | 279 | coordinates = self.pooling(coordinates) 280 | 281 | pos_encodings_pcd.reverse() 282 | 283 | return pos_encodings_pcd 284 | 285 | def get_random_samples(self, pcd_sizes, curr_sample_size, device): 286 | rand_idx = [] 287 | mask_idx = [] 288 | for pcd_size in pcd_sizes: 289 | if pcd_size <= curr_sample_size: 290 | # we do not need to sample 291 | # take all points and pad the rest with zeroes and mask it 292 | idx = torch.zeros(curr_sample_size, dtype=torch.long, device=device) 293 | midx = torch.ones(curr_sample_size, dtype=torch.bool, device=device) 294 | idx[:pcd_size] = torch.arange(pcd_size, device=device) 295 | midx[:pcd_size] = False # attend to first points 296 | else: 297 | # we have more points in pcd as we like to sample 298 | # take a subset (no padding or masking needed) 299 | idx = torch.randperm(pcd_size, device=device)[:curr_sample_size] 300 | midx = torch.zeros(curr_sample_size, dtype=torch.bool, device=device) 301 | 302 | rand_idx.append(idx) 303 | mask_idx.append(midx) 304 | return rand_idx, mask_idx 305 | 306 | @torch.jit.unused 307 | def _set_aux_loss(self, outputs_class, outputs_bbox, outputs_seg_masks): 308 | # this is a workaround to make torchscript happy, as torchscript 309 | # doesn't support dictionary with non-homogeneous values, such 310 | # as a dict having both a Tensor and a list. 311 | return [ 312 | {"pred_logits": a, "pred_bboxs": b, "pred_masks": c} 313 | for a, b, c in zip( 314 | outputs_class[:-1], outputs_bbox[:-1], outputs_seg_masks[:-1] 315 | ) 316 | ] 317 | 318 | 319 | class Mask4Former3D(Mask4Former): 320 | def __init__(self, *args, **kwargs): 321 | super().__init__(*args, **kwargs) 322 | 323 | mask_dim = kwargs["mask_dim"] 324 | 325 | self.query_projection = nn.Sequential( 326 | nn.Conv1d(mask_dim, mask_dim, 1), 327 | nn.ReLU(), 328 | nn.Conv1d(mask_dim, mask_dim, 1), 329 | nn.ReLU(), 330 | ) 331 | 332 | def get_pos_encs(self, coordinates): 333 | pos_encodings_pcd = [] 334 | 335 | for _ in range(self.num_levels + 1): 336 | pos_encodings_pcd.append([]) 337 | 338 | for coords_batch in coordinates.decomposed_features: 339 | scene_min = coords_batch.min(dim=0)[0][None, ...] 340 | scene_max = coords_batch.max(dim=0)[0][None, ...] 341 | 342 | with autocast(enabled=False): 343 | tmp = self.pos_enc( 344 | coords_batch[None, :, :3].float(), 345 | input_range=[scene_min[:, :3], scene_max[:, :3]], 346 | ) 347 | # tmp += self.temporal_pos_enc( 348 | # coords_batch[None, :, 3].float(), input_range=[scene_min[:, 3:4], scene_max[:, 3:4]] 349 | # ) 350 | pos_encodings_pcd[-1].append(tmp.squeeze(0).permute((1, 0))) 351 | 352 | coordinates = self.pooling(coordinates) 353 | 354 | pos_encodings_pcd.reverse() 355 | 356 | return pos_encodings_pcd 357 | -------------------------------------------------------------------------------- /Mask4Former3D/models/matcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py 3 | """ 4 | Modules to compute the matching cost and solve the corresponding LSAP. 5 | """ 6 | import torch 7 | import torch.nn.functional as F 8 | from scipy.optimize import linear_sum_assignment 9 | from torch import nn 10 | from torch.cuda.amp import autocast 11 | 12 | 13 | def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor): 14 | inputs = inputs.sigmoid() 15 | inputs = inputs.flatten(1) 16 | numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets) 17 | denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :] 18 | loss = 1 - (numerator + 1) / (denominator + 1) 19 | return loss 20 | 21 | 22 | batch_dice_loss_jit = torch.jit.script(batch_dice_loss) # type: torch.jit.ScriptModule 23 | 24 | 25 | def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor): 26 | """ 27 | Args: 28 | inputs: A float tensor of arbitrary shape. 29 | The predictions for each example. 30 | targets: A float tensor with the same shape as inputs. Stores the binary 31 | classification label for each element in inputs 32 | (0 for the negative class and 1 for the positive class). 33 | Returns: 34 | Loss tensor 35 | """ 36 | hw = inputs.shape[1] 37 | 38 | pos = F.binary_cross_entropy_with_logits(inputs, torch.ones_like(inputs), reduction="none") 39 | neg = F.binary_cross_entropy_with_logits(inputs, torch.zeros_like(inputs), reduction="none") 40 | 41 | loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum("nc,mc->nm", neg, (1 - targets)) 42 | 43 | return loss / hw 44 | 45 | 46 | batch_sigmoid_ce_loss_jit = torch.jit.script(batch_sigmoid_ce_loss) # type: torch.jit.ScriptModule 47 | 48 | 49 | class HungarianMatcher(nn.Module): 50 | """This class computes an assignment between the targets and the predictions of the network 51 | 52 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 53 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 54 | while the others are un-matched (and thus treated as non-objects). 
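    A usage sketch (shapes illustrative): with 100 queries and a target holding
    7 ground-truth masks, forward() yields one (index_i, index_j) pair per batch
    element, e.g. index_i selecting the 7 matched query indices and index_j the
    corresponding target indices, so len(index_i) == len(index_j) == 7.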
55 | """ 56 | 57 | def __init__( 58 | self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, cost_box: float = 1 59 | ): 60 | """Creates the matcher 61 | 62 | Params: 63 | cost_class: This is the relative weight of the classification error in the matching cost 64 | cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost 65 | cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost 66 | """ 67 | super().__init__() 68 | self.cost_class = cost_class 69 | self.cost_mask = cost_mask 70 | self.cost_dice = cost_dice 71 | self.cost_box = cost_box 72 | 73 | @torch.no_grad() 74 | def forward(self, outputs, targets): 75 | """Performs the matching 76 | 77 | Params: 78 | outputs: This is a dict that contains at least these entries: 79 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 80 | "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks 81 | 82 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 83 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 84 | objects in the target) containing the class labels 85 | "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks 86 | 87 | Returns: 88 | A list of size batch_size, containing tuples of (index_i, index_j) where: 89 | - index_i is the indices of the selected predictions (in order) 90 | - index_j is the indices of the corresponding selected targets (in order) 91 | For each batch element, it holds: 92 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 93 | """ 94 | bs, num_queries = outputs["pred_logits"].shape[:2] 95 | 96 | indices = [] 97 | 98 | # Iterate through batch size 99 | for b in range(bs): 100 | out_prob = outputs["pred_logits"][b].softmax(-1) # [num_queries, num_classes] 101 | tgt_ids = targets[b]["labels"] 102 | 103 | cost_class = -out_prob[:, tgt_ids] 104 | 105 | out_mask = outputs["pred_masks"][b].T # [num_queries, H_pred, W_pred] 106 | # gt masks are already padded when preparing target 107 | tgt_mask = targets[b]["masks"].to(out_mask) 108 | 109 | with autocast(enabled=False): 110 | out_mask = out_mask.float() 111 | tgt_mask = tgt_mask.float() 112 | # Compute the focal loss between masks 113 | cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask) 114 | 115 | # Compute the dice loss betwen masks 116 | cost_dice = batch_dice_loss_jit(out_mask, tgt_mask) 117 | 118 | # Final cost matrix 119 | C = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice 120 | C = C.reshape(num_queries, -1).cpu() 121 | 122 | indices.append(linear_sum_assignment(C)) 123 | 124 | return [ 125 | (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices 126 | ] 127 | -------------------------------------------------------------------------------- /Mask4Former3D/models/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .panoptic_quality import Panoptic4DEval 2 | from .panoptic_eval import PanopticEval 3 | -------------------------------------------------------------------------------- /Mask4Former3D/models/metrics/panoptic_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # This file is covered by the LICENSE file in the root of 
this project. 4 | # https://github.com/PRBonn/semantic-kitti-api 5 | 6 | import math 7 | import time 8 | 9 | import numpy as np 10 | 11 | 12 | class PanopticEval: 13 | """Panoptic evaluation using numpy 14 | 15 | authors: Andres Milioto and Jens Behley 16 | 17 | """ 18 | 19 | def __init__( 20 | self, n_classes, device=None, ignore_label=0, offset=2**32, min_points=50 21 | ): 22 | self.n_classes = n_classes + 1 23 | assert device is None 24 | ignore_label = 0 25 | self.ignore = np.array(ignore_label, dtype=np.int64) 26 | self.include = np.array( 27 | [n for n in range(self.n_classes) if n not in self.ignore], dtype=np.int64 28 | ) 29 | 30 | print("[PANOPTIC EVAL] IGNORE: ", self.ignore) 31 | print("[PANOPTIC EVAL] INCLUDE: ", self.include) 32 | 33 | self.reset() 34 | self.offset = offset # largest number of instances in a given scan 35 | self.min_points = ( 36 | min_points # smallest number of points to consider instances in gt 37 | ) 38 | self.eps = 1e-15 39 | 40 | def num_classes(self): 41 | return self.n_classes 42 | 43 | def reset(self): 44 | # general things 45 | # iou stuff 46 | self.px_iou_conf_matrix = np.zeros( 47 | (self.n_classes, self.n_classes), dtype=np.int64 48 | ) 49 | # panoptic stuff 50 | self.pan_tp = np.zeros(self.n_classes, dtype=np.int64) 51 | self.pan_iou = np.zeros(self.n_classes, dtype=np.double) 52 | self.pan_fp = np.zeros(self.n_classes, dtype=np.int64) 53 | self.pan_fn = np.zeros(self.n_classes, dtype=np.int64) 54 | 55 | ################################# IoU STUFF ################################## 56 | def addBatchSemIoU(self, x_sem, y_sem): 57 | # idxs are labels and predictions 58 | idxs = np.stack([x_sem, y_sem], axis=0) 59 | 60 | # set_trace() 61 | # print(idxs) 62 | # print(idxs.shape) 63 | # print(self.px_iou_conf_matrix) 64 | # print(self.px_iou_conf_matrix.shape) 65 | 66 | # make confusion matrix (cols = gt, rows = pred) 67 | np.add.at(self.px_iou_conf_matrix, tuple(idxs), 1) 68 | 69 | def getSemIoUStats(self): 70 | # clone to avoid modifying the real deal 71 | conf = self.px_iou_conf_matrix.copy().astype(np.double) 72 | # remove fp from confusion on the ignore classes predictions 73 | # points that were predicted of another class, but were ignore 74 | # (corresponds to zeroing the cols of those classes, since the predictions 75 | # go on the rows) 76 | conf[:, self.ignore] = 0 77 | 78 | # get the clean stats 79 | tp = conf.diagonal() 80 | fp = conf.sum(axis=1) - tp 81 | fn = conf.sum(axis=0) - tp 82 | return tp, fp, fn 83 | 84 | def getSemIoU(self): 85 | tp, fp, fn = self.getSemIoUStats() 86 | # print(f"tp={tp}") 87 | # print(f"fp={fp}") 88 | # print(f"fn={fn}") 89 | intersection = tp 90 | union = tp + fp + fn 91 | union = np.maximum(union, self.eps) 92 | iou = intersection.astype(np.float64) / union.astype(np.float64) 93 | iou_mean = ( 94 | intersection[self.include].astype(np.double) 95 | / union[self.include].astype(np.double) 96 | ).mean() 97 | # prec = tp / (tp+fp) 98 | # recall = tp / (tp+fn) 99 | return iou_mean, iou # returns "iou mean", "iou per class" ALL CLASSES 100 | 101 | def getSemAcc(self): 102 | tp, fp, fn = self.getSemIoUStats() 103 | total_tp = tp.sum() 104 | total = tp[self.include].sum() + fp[self.include].sum() 105 | total = np.maximum(total, self.eps) 106 | acc_mean = total_tp.astype(np.double) / total.astype(np.double) 107 | 108 | return acc_mean # returns "acc mean" 109 | 110 | ################################# IoU STUFF ################################## 111 | 
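# Worked example for getSemIoU (illustrative numbers): with tp=80, fp=10 and
# fn=10 for one class, its IoU is 80 / (80 + 10 + 10) = 0.8; the reported mean
# only averages the classes in self.include, so the ignore class is excluded.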
############################################################################## 112 | 113 | ############################# Panoptic STUFF ################################ 114 | def addBatchPanoptic(self, x_sem_row, x_inst_row, y_sem_row, y_inst_row): 115 | # make sure instances are not zeros (it messes with my approach) 116 | x_inst_row = x_inst_row + 1 117 | y_inst_row = y_inst_row + 1 118 | 119 | # only interested in points that are outside the void area (not in excluded classes) 120 | for cl in [0]: 121 | # make a mask for this class 122 | gt_not_in_excl_mask = y_sem_row != cl 123 | # remove all other points 124 | x_sem_row = x_sem_row[gt_not_in_excl_mask] 125 | y_sem_row = y_sem_row[gt_not_in_excl_mask] 126 | x_inst_row = x_inst_row[gt_not_in_excl_mask] 127 | y_inst_row = y_inst_row[gt_not_in_excl_mask] 128 | 129 | # first step is to count intersections > 0.5 IoU for each class (except the ignored ones) 130 | for cl in self.include: 131 | # print("*"*80) 132 | # print("CLASS", cl.item()) 133 | # get a class mask 134 | x_inst_in_cl_mask = x_sem_row == cl 135 | y_inst_in_cl_mask = y_sem_row == cl 136 | 137 | # get instance points in class (makes outside stuff 0) 138 | x_inst_in_cl = x_inst_row * x_inst_in_cl_mask.astype(np.int64) 139 | y_inst_in_cl = y_inst_row * y_inst_in_cl_mask.astype(np.int64) 140 | 141 | # generate the areas for each unique instance prediction 142 | unique_pred, counts_pred = np.unique( 143 | x_inst_in_cl[x_inst_in_cl > 0], return_counts=True 144 | ) 145 | id2idx_pred = {id: idx for idx, id in enumerate(unique_pred)} 146 | matched_pred = np.array([False] * unique_pred.shape[0]) 147 | # print("Unique predictions:", unique_pred) 148 | 149 | # generate the areas for each unique instance gt_np 150 | unique_gt, counts_gt = np.unique( 151 | y_inst_in_cl[y_inst_in_cl > 0], return_counts=True 152 | ) 153 | id2idx_gt = {id: idx for idx, id in enumerate(unique_gt)} 154 | matched_gt = np.array([False] * unique_gt.shape[0]) 155 | # print("Unique ground truth:", unique_gt) 156 | 157 | # generate intersection using offset 158 | valid_combos = np.logical_and(x_inst_in_cl > 0, y_inst_in_cl > 0) 159 | offset_combo = ( 160 | x_inst_in_cl[valid_combos] + self.offset * y_inst_in_cl[valid_combos] 161 | ) 162 | unique_combo, counts_combo = np.unique(offset_combo, return_counts=True) 163 | 164 | # generate an intersection map 165 | # count the intersections with over 0.5 IoU as TP 166 | gt_labels = unique_combo // self.offset 167 | pred_labels = unique_combo % self.offset 168 | gt_areas = np.array([counts_gt[id2idx_gt[id]] for id in gt_labels]) 169 | pred_areas = np.array([counts_pred[id2idx_pred[id]] for id in pred_labels]) 170 | intersections = counts_combo 171 | unions = gt_areas + pred_areas - intersections 172 | ious = intersections.astype(np.float64) / unions.astype(np.float64) 173 | 174 | tp_indexes = ious > 0.5 175 | self.pan_tp[cl] += np.sum(tp_indexes) 176 | self.pan_iou[cl] += np.sum(ious[tp_indexes]) 177 | 178 | matched_gt[[id2idx_gt[id] for id in gt_labels[tp_indexes]]] = True 179 | matched_pred[[id2idx_pred[id] for id in pred_labels[tp_indexes]]] = True 180 | 181 | # count the FN 182 | self.pan_fn[cl] += np.sum( 183 | np.logical_and(counts_gt >= self.min_points, matched_gt == False) 184 | ) 185 | 186 | # count the FP 187 | self.pan_fp[cl] += np.sum( 188 | np.logical_and(counts_pred >= self.min_points, matched_pred == False) 189 | ) 190 | 191 | def getPQ(self): 192 | # first calculate for all classes 193 | sq_all = self.pan_iou.astype(np.float64) / np.maximum( 194 | 
self.pan_tp.astype(np.float64), self.eps 195 | ) 196 | rq_all = self.pan_tp.astype(np.float64) / np.maximum( 197 | self.pan_tp.astype(np.float64) 198 | + 0.5 * self.pan_fp.astype(np.float64) 199 | + 0.5 * self.pan_fn.astype(np.float64), 200 | self.eps, 201 | ) 202 | pq_all = sq_all * rq_all 203 | 204 | # then do the REAL mean (no ignored classes) 205 | SQ = sq_all[self.include].mean() 206 | RQ = rq_all[self.include].mean() 207 | PQ = pq_all[self.include].mean() 208 | 209 | return PQ, SQ, RQ, pq_all, sq_all, rq_all 210 | 211 | ############################# Panoptic STUFF ################################ 212 | ############################################################################## 213 | 214 | def addBatch( 215 | self, x_sem, x_inst, y_sem, y_inst, indices, seq 216 | ): # x=preds, y=targets 217 | """IMPORTANT: Inputs must be batched. Either [N,H,W], or [N, P]""" 218 | # add to IoU calculation (for checking purposes) 219 | self.addBatchSemIoU(x_sem, y_sem) 220 | 221 | # now do the panoptic stuff 222 | self.addBatchPanoptic(x_sem, x_inst, y_sem, y_inst) 223 | -------------------------------------------------------------------------------- /Mask4Former3D/models/metrics/panoptic_quality.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | 4 | 5 | class Panoptic4DEval: 6 | def __init__(self, n_classes, min_stuff_cls_id, ignore=0, offset=2**32, min_points=50): 7 | self.n_classes = n_classes + 1 8 | self.ignore = ignore 9 | self.include = np.array([n for n in range(self.n_classes) if n != self.ignore], dtype=np.int64) 10 | self.min_stuff_cls_id = min_stuff_cls_id 11 | self.reset() 12 | self.offset = offset # largest number of instances in a given scan 13 | self.min_points = min_points # smallest number of points to consider instances in gt 14 | self.eps = 1e-15 15 | 16 | def reset(self): 17 | # iou stuff 18 | self.px_iou_conf_matrix = np.zeros((self.n_classes, self.n_classes), dtype=np.int64) 19 | self.sequences = [] 20 | self.preds = {} 21 | self.gts = {} 22 | self.intersects = {} 23 | 24 | def addBatchSemIoU(self, x_sem, y_sem): 25 | # idxs are labels and predictions 26 | idxs = np.stack([x_sem, y_sem], axis=0) 27 | 28 | # make confusion matrix (cols = gt, rows = pred) 29 | np.add.at(self.px_iou_conf_matrix, tuple(idxs), 1) 30 | 31 | def getSemIoUStats(self): 32 | conf = self.px_iou_conf_matrix.copy().astype(np.double) 33 | conf[:, self.ignore] = 0 34 | 35 | # get the clean stats 36 | tp = conf.diagonal() 37 | fp = conf.sum(axis=1) - tp 38 | fn = conf.sum(axis=0) - tp 39 | return tp, fp, fn 40 | 41 | def getSemIoU(self): 42 | tp, fp, fn = self.getSemIoUStats() 43 | intersection = tp 44 | union = tp + fp + fn 45 | union = np.maximum(union, self.eps) 46 | iou = intersection[self.include].astype(np.double) / union[self.include].astype(np.double) 47 | iou_mean = iou.mean() 48 | 49 | return iou_mean, iou 50 | 51 | def update_dict_stat(self, stat_dict, unique_ids, unique_cnts): 52 | for uniqueid, counts in zip(unique_ids, unique_cnts): 53 | if uniqueid == 1: 54 | continue # 1 -- no instance 55 | if uniqueid in stat_dict: 56 | stat_dict[uniqueid] += counts 57 | else: 58 | stat_dict[uniqueid] = counts 59 | 60 | def addBatchPanoptic4D(self, seq, x_sem_row, x_inst_row, y_sem_row, y_inst_row): 61 | if seq not in self.sequences: 62 | self.sequences.append(seq) 63 | self.preds[seq] = {} 64 | self.gts[seq] = [{} for i in range(self.n_classes)] 65 | self.intersects[seq] = [{} for i in range(self.n_classes)] 66 | 67 | # make sure 
instances are not zeros (it messes with my approach) 68 | x_inst_row = x_inst_row + 1 69 | y_inst_row = y_inst_row + 1 70 | 71 | preds = self.preds[seq] 72 | # generate the areas for each unique instance prediction (i.e., set1) 73 | unique_pred, counts_pred = np.unique(x_inst_row, return_counts=True) 74 | self.update_dict_stat(preds, unique_pred, counts_pred) 75 | 76 | for cl in self.include: 77 | # Per-class accumulated stats 78 | cl_gts = self.gts[seq][cl] 79 | cl_intersects = self.intersects[seq][cl] 80 | 81 | # get a binary class mask (filter acc. to semantic class!) 82 | y_inst_in_cl_mask = y_sem_row == cl 83 | 84 | # get instance points in class (mask-out everything but _this_ class) 85 | y_inst_in_cl = y_inst_row * y_inst_in_cl_mask.astype(np.int64) 86 | 87 | # generate the areas for each unique instance gt_np (i.e., set2) 88 | unique_gt, counts_gt = np.unique(y_inst_in_cl[y_inst_in_cl > 0], return_counts=True) 89 | self.update_dict_stat( 90 | cl_gts, unique_gt[counts_gt > self.min_points], counts_gt[counts_gt > self.min_points] 91 | ) 92 | y_inst_in_cl[np.isin(y_inst_in_cl, unique_gt[counts_gt <= self.min_points])] = 0 93 | 94 | # generate intersection using offset 95 | offset_combo = x_inst_row[y_inst_in_cl > 0] + self.offset * y_inst_in_cl[y_inst_in_cl > 0] 96 | unique_combo, counts_combo = np.unique(offset_combo, return_counts=True) 97 | 98 | self.update_dict_stat(cl_intersects, unique_combo, counts_combo) 99 | 100 | def getPQ4D(self): 101 | pan_aq = np.zeros(self.n_classes, dtype=np.double) 102 | pan_aq_ovr = 0.0 103 | num_tubes = [0] * self.n_classes 104 | 105 | for seq in self.sequences: 106 | preds = self.preds[seq] 107 | for cl in range(self.n_classes): 108 | cl_gts = self.gts[seq][cl] 109 | cl_intersects = self.intersects[seq][cl] 110 | outer_sum_iou = 0.0 111 | for gt_id, gt_size in cl_gts.items(): 112 | num_tubes[cl] += 1 113 | inner_sum_iou = 0.0 114 | for pr_id, pr_size in preds.items(): 115 | TPA_key = pr_id + self.offset * gt_id 116 | if TPA_key in cl_intersects: 117 | TPA_ovr = cl_intersects[TPA_key] 118 | inner_sum_iou += TPA_ovr * (TPA_ovr / (gt_size + pr_size - TPA_ovr)) 119 | outer_sum_iou += float(inner_sum_iou) / float(gt_size) 120 | pan_aq[cl] += outer_sum_iou 121 | pan_aq_ovr += outer_sum_iou 122 | 123 | AQ_overall = np.sum(pan_aq_ovr) / np.sum(num_tubes[1 : self.min_stuff_cls_id]) 124 | AQ = pan_aq / np.maximum(num_tubes, self.eps) 125 | 126 | iou_mean, iou = self.getSemIoU() 127 | 128 | PQ4D = math.sqrt(AQ_overall * iou_mean) 129 | return PQ4D, AQ_overall, AQ[self.include], iou_mean, iou 130 | 131 | def addBatch(self, x_sem, x_inst, y_sem, y_inst, indices, seq): # x=preds, y=targets 132 | x_sem = x_sem[indices] 133 | x_inst = x_inst[indices] 134 | y_sem = y_sem[indices] 135 | y_inst = y_inst[indices] 136 | 137 | # only interested in points that are outside the void area (not in excluded classes) 138 | gt_not_in_excl_mask = y_sem != self.ignore 139 | # remove all other points 140 | x_sem = x_sem[gt_not_in_excl_mask] 141 | y_sem = y_sem[gt_not_in_excl_mask] 142 | x_inst = x_inst[gt_not_in_excl_mask] 143 | y_inst = y_inst[gt_not_in_excl_mask] 144 | 145 | # add to IoU calculation (for checking purposes) 146 | self.addBatchSemIoU(x_sem, y_sem) 147 | 148 | # now do the panoptic stuff 149 | self.addBatchPanoptic4D(seq, x_sem, x_inst, y_sem, y_inst) 150 | -------------------------------------------------------------------------------- /Mask4Former3D/models/model.py: -------------------------------------------------------------------------------- 1 | from 
MinkowskiEngine import MinkowskiNetwork 2 | 3 | 4 | class Model(MinkowskiNetwork): 5 | """ 6 | Base network for all sparse convnet 7 | 8 | By default, all networks are segmentation networks. 9 | """ 10 | 11 | OUT_PIXEL_DIST = -1 12 | 13 | def __init__(self, in_channels, out_channels, config, D, **kwargs): 14 | super().__init__(D) 15 | self.in_channels = in_channels 16 | self.out_channels = out_channels 17 | self.config = config 18 | 19 | 20 | class HighDimensionalModel(Model): 21 | """ 22 | Base network for all spatio (temporal) chromatic sparse convnet 23 | """ 24 | 25 | def __init__(self, in_channels, out_channels, config, D, **kwargs): 26 | assert D > 4, "Num dimension smaller than 5" 27 | super().__init__(in_channels, out_channels, config, D, **kwargs) 28 | -------------------------------------------------------------------------------- /Mask4Former3D/models/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kumuji/stu_dataset/3812d768f3634fbb6faeb8b0bfbd5246a9798e93/Mask4Former3D/models/modules/__init__.py -------------------------------------------------------------------------------- /Mask4Former3D/models/modules/attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import functional as F 3 | 4 | 5 | class SelfAttentionLayer(nn.Module): 6 | def __init__(self, d_model, nhead, dropout=0.0, activation="relu"): 7 | super().__init__() 8 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 9 | 10 | self.norm = nn.LayerNorm(d_model) 11 | self.dropout = nn.Dropout(dropout) 12 | 13 | self.activation = _get_activation_fn(activation) 14 | 15 | self._reset_parameters() 16 | 17 | def _reset_parameters(self): 18 | for p in self.parameters(): 19 | if p.dim() > 1: 20 | nn.init.xavier_uniform_(p) 21 | 22 | def with_pos_embed(self, tensor, pos): 23 | return tensor if pos is None else tensor + pos 24 | 25 | def forward(self, tgt, tgt_mask=None, tgt_key_padding_mask=None, query_pos=None): 26 | q = k = self.with_pos_embed(tgt, query_pos) 27 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] 28 | tgt = tgt + self.dropout(tgt2) 29 | tgt = self.norm(tgt) 30 | 31 | return tgt 32 | 33 | 34 | class CrossAttentionLayer(nn.Module): 35 | def __init__(self, d_model, nhead, dropout=0.0, activation="relu"): 36 | super().__init__() 37 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 38 | 39 | self.norm = nn.LayerNorm(d_model) 40 | self.dropout = nn.Dropout(dropout) 41 | 42 | self.activation = _get_activation_fn(activation) 43 | 44 | self._reset_parameters() 45 | 46 | def _reset_parameters(self): 47 | for p in self.parameters(): 48 | if p.dim() > 1: 49 | nn.init.xavier_uniform_(p) 50 | 51 | def with_pos_embed(self, tensor, pos): 52 | return tensor if pos is None else tensor + pos 53 | 54 | def forward(self, tgt, memory, memory_mask=None, memory_key_padding_mask=None, pos=None, query_pos=None): 55 | tgt2 = self.multihead_attn( 56 | query=self.with_pos_embed(tgt, query_pos), 57 | key=self.with_pos_embed(memory, pos), 58 | value=memory, 59 | attn_mask=memory_mask, 60 | key_padding_mask=memory_key_padding_mask, 61 | )[0] 62 | tgt = tgt + self.dropout(tgt2) 63 | tgt = self.norm(tgt) 64 | 65 | return tgt 66 | 67 | 68 | class FFNLayer(nn.Module): 69 | def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, activation="relu"): 70 | super().__init__() 
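# A plain transformer-style feed-forward block: linear -> activation ->
# dropout -> linear, added residually to the input and followed by LayerNorm
# (post-norm), mirroring the attention layers above. Minimal sketch, assuming
# d_model=128 (illustrative values):
#   ffn = FFNLayer(d_model=128, dim_feedforward=1024)
#   out = ffn(torch.rand(100, 2, 128))  # shape preserved: [100, 2, 128]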
71 | # Implementation of Feedforward model 72 | self.linear1 = nn.Linear(d_model, dim_feedforward) 73 | self.dropout = nn.Dropout(dropout) 74 | self.linear2 = nn.Linear(dim_feedforward, d_model) 75 | 76 | self.norm = nn.LayerNorm(d_model) 77 | 78 | self.activation = _get_activation_fn(activation) 79 | 80 | self._reset_parameters() 81 | 82 | def _reset_parameters(self): 83 | for p in self.parameters(): 84 | if p.dim() > 1: 85 | nn.init.xavier_uniform_(p) 86 | 87 | def with_pos_embed(self, tensor, pos): 88 | return tensor if pos is None else tensor + pos 89 | 90 | def forward(self, tgt): 91 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 92 | tgt = tgt + self.dropout(tgt2) 93 | tgt = self.norm(tgt) 94 | return tgt 95 | 96 | 97 | def _get_activation_fn(activation): 98 | """Return an activation function given a string""" 99 | if activation == "relu": 100 | return F.relu 101 | if activation == "gelu": 102 | return F.gelu 103 | if activation == "glu": 104 | return F.glu 105 | raise RuntimeError(f"activation should be relu/gelu, not {activation}.") 106 | -------------------------------------------------------------------------------- /Mask4Former3D/models/modules/common.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | if sys.version_info[:2] >= (3, 8): 4 | from collections.abc import Sequence 5 | else: 6 | from collections import Sequence 7 | 8 | from enum import Enum 9 | 10 | import torch.nn as nn 11 | import MinkowskiEngine as ME 12 | 13 | 14 | class NormType(Enum): 15 | BATCH_NORM = 0 16 | INSTANCE_NORM = 1 17 | INSTANCE_BATCH_NORM = 2 18 | 19 | 20 | def get_norm(norm_type, n_channels, D, bn_momentum=0.1): 21 | if norm_type == NormType.BATCH_NORM: 22 | return ME.MinkowskiBatchNorm(n_channels, momentum=bn_momentum) 23 | elif norm_type == NormType.INSTANCE_NORM: 24 | return ME.MinkowskiInstanceNorm(n_channels) 25 | elif norm_type == NormType.INSTANCE_BATCH_NORM: 26 | return nn.Sequential( 27 | ME.MinkowskiInstanceNorm(n_channels), 28 | ME.MinkowskiBatchNorm(n_channels, momentum=bn_momentum), 29 | ) 30 | else: 31 | raise ValueError(f"Norm type: {norm_type} not supported") 32 | 33 | 34 | class ConvType(Enum): 35 | """ 36 | Define the kernel region type 37 | """ 38 | 39 | HYPERCUBE = 0, "HYPERCUBE" 40 | SPATIAL_HYPERCUBE = 1, "SPATIAL_HYPERCUBE" 41 | SPATIO_TEMPORAL_HYPERCUBE = 2, "SPATIO_TEMPORAL_HYPERCUBE" 42 | HYPERCROSS = 3, "HYPERCROSS" 43 | SPATIAL_HYPERCROSS = 4, "SPATIAL_HYPERCROSS" 44 | SPATIO_TEMPORAL_HYPERCROSS = 5, "SPATIO_TEMPORAL_HYPERCROSS" 45 | SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS = 6, "SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS " 46 | 47 | def __new__(cls, value, name): 48 | member = object.__new__(cls) 49 | member._value_ = value 50 | member.fullname = name 51 | return member 52 | 53 | def __int__(self): 54 | return self.value 55 | 56 | 57 | # Covert the ConvType var to a RegionType var 58 | conv_to_region_type = { 59 | # kernel_size = [k, k, k, 1] 60 | ConvType.HYPERCUBE: ME.RegionType.HYPER_CUBE, 61 | ConvType.SPATIAL_HYPERCUBE: ME.RegionType.HYPER_CUBE, 62 | ConvType.SPATIO_TEMPORAL_HYPERCUBE: ME.RegionType.HYPER_CUBE, 63 | ConvType.HYPERCROSS: ME.RegionType.HYPER_CROSS, 64 | ConvType.SPATIAL_HYPERCROSS: ME.RegionType.HYPER_CROSS, 65 | ConvType.SPATIO_TEMPORAL_HYPERCROSS: ME.RegionType.HYPER_CROSS, 66 | ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS: ME.RegionType.HYPER_CUBE, # JONAS CHANGE from HYBRID 67 | } 68 | 69 | # int_to_region_type = {m.value: m for m in ME.RegionType} 70 | int_to_region_type 
= {m: ME.RegionType(m) for m in range(3)} 71 | 72 | 73 | def convert_region_type(region_type): 74 | """ 75 | Convert the integer region_type to the corresponding RegionType enum object. 76 | """ 77 | return int_to_region_type[region_type] 78 | 79 | 80 | def convert_conv_type(conv_type, kernel_size, D): 81 | assert isinstance(conv_type, ConvType), "conv_type must be of ConvType" 82 | region_type = conv_to_region_type[conv_type] 83 | axis_types = None 84 | if conv_type == ConvType.SPATIAL_HYPERCUBE: 85 | # No temporal convolution 86 | if isinstance(kernel_size, Sequence): 87 | kernel_size = kernel_size[:3] 88 | else: 89 | kernel_size = [ 90 | kernel_size, 91 | ] * 3 92 | if D == 4: 93 | kernel_size.append(1) 94 | elif conv_type == ConvType.SPATIO_TEMPORAL_HYPERCUBE: 95 | # conv_type conversion already handled 96 | assert D == 4 97 | elif conv_type == ConvType.HYPERCUBE: 98 | # conv_type conversion already handled 99 | pass 100 | elif conv_type == ConvType.SPATIAL_HYPERCROSS: 101 | if isinstance(kernel_size, Sequence): 102 | kernel_size = kernel_size[:3] 103 | else: 104 | kernel_size = [ 105 | kernel_size, 106 | ] * 3 107 | if D == 4: 108 | kernel_size.append(1) 109 | elif conv_type == ConvType.HYPERCROSS: 110 | # conv_type conversion already handled 111 | pass 112 | elif conv_type == ConvType.SPATIO_TEMPORAL_HYPERCROSS: 113 | # conv_type conversion already handled 114 | assert D == 4 115 | elif conv_type == ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS: 116 | # Define the CUBIC conv kernel for spatial dims and CROSS conv for temp dim 117 | axis_types = [ 118 | ME.RegionType.HYPER_CUBE, 119 | ] * 3 120 | if D == 4: 121 | axis_types.append(ME.RegionType.HYPER_CROSS) 122 | return region_type, axis_types, kernel_size 123 | 124 | 125 | def conv( 126 | in_planes, 127 | out_planes, 128 | kernel_size, 129 | stride=1, 130 | dilation=1, 131 | bias=False, 132 | conv_type=ConvType.HYPERCUBE, 133 | D=-1, 134 | ): 135 | assert D > 0, "Dimension must be a positive integer" 136 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 137 | kernel_generator = ME.KernelGenerator( 138 | kernel_size, 139 | stride, 140 | dilation, 141 | region_type=region_type, 142 | axis_types=None, # axis_types JONAS 143 | dimension=D, 144 | ) 145 | 146 | return ME.MinkowskiConvolution( 147 | in_channels=in_planes, 148 | out_channels=out_planes, 149 | kernel_size=kernel_size, 150 | stride=stride, 151 | dilation=dilation, 152 | bias=bias, 153 | kernel_generator=kernel_generator, 154 | dimension=D, 155 | ) 156 | 157 | 158 | def conv_tr( 159 | in_planes, 160 | out_planes, 161 | kernel_size, 162 | upsample_stride=1, 163 | dilation=1, 164 | bias=False, 165 | conv_type=ConvType.HYPERCUBE, 166 | D=-1, 167 | ): 168 | assert D > 0, "Dimension must be a positive integer" 169 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 170 | kernel_generator = ME.KernelGenerator( 171 | kernel_size, 172 | upsample_stride, 173 | dilation, 174 | region_type=region_type, 175 | axis_types=axis_types, 176 | dimension=D, 177 | ) 178 | 179 | return ME.MinkowskiConvolutionTranspose( 180 | in_channels=in_planes, 181 | out_channels=out_planes, 182 | kernel_size=kernel_size, 183 | stride=upsample_stride, 184 | dilation=dilation, 185 | bias=bias, 186 | kernel_generator=kernel_generator, 187 | dimension=D, 188 | ) 189 | 190 | 191 | def avg_pool( 192 | kernel_size, 193 | stride=1, 194 | dilation=1, 195 | conv_type=ConvType.HYPERCUBE, 196 | in_coords_key=None, 197 | D=-1, 198 | ): 199 | assert D > 
0, "Dimension must be a positive integer" 200 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 201 | kernel_generator = ME.KernelGenerator( 202 | kernel_size, 203 | stride, 204 | dilation, 205 | region_type=region_type, 206 | axis_types=axis_types, 207 | dimension=D, 208 | ) 209 | 210 | return ME.MinkowskiAvgPooling( 211 | kernel_size=kernel_size, 212 | stride=stride, 213 | dilation=dilation, 214 | kernel_generator=kernel_generator, 215 | dimension=D, 216 | ) 217 | 218 | 219 | def avg_unpool(kernel_size, stride=1, dilation=1, conv_type=ConvType.HYPERCUBE, D=-1): 220 | assert D > 0, "Dimension must be a positive integer" 221 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 222 | kernel_generator = ME.KernelGenerator( 223 | kernel_size, 224 | stride, 225 | dilation, 226 | region_type=region_type, 227 | axis_types=axis_types, 228 | dimension=D, 229 | ) 230 | 231 | return ME.MinkowskiAvgUnpooling( 232 | kernel_size=kernel_size, 233 | stride=stride, 234 | dilation=dilation, 235 | kernel_generator=kernel_generator, 236 | dimension=D, 237 | ) 238 | 239 | 240 | def sum_pool(kernel_size, stride=1, dilation=1, conv_type=ConvType.HYPERCUBE, D=-1): 241 | assert D > 0, "Dimension must be a positive integer" 242 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 243 | kernel_generator = ME.KernelGenerator( 244 | kernel_size, 245 | stride, 246 | dilation, 247 | region_type=region_type, 248 | axis_types=axis_types, 249 | dimension=D, 250 | ) 251 | 252 | return ME.MinkowskiSumPooling( 253 | kernel_size=kernel_size, 254 | stride=stride, 255 | dilation=dilation, 256 | kernel_generator=kernel_generator, 257 | dimension=D, 258 | ) 259 | -------------------------------------------------------------------------------- /Mask4Former3D/models/modules/helpers_3detr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | import torch.nn as nn 3 | from functools import partial 4 | import copy 5 | 6 | 7 | class BatchNormDim1Swap(nn.BatchNorm1d): 8 | """ 9 | Used for nn.Transformer that uses a HW x N x C rep 10 | """ 11 | 12 | def forward(self, x): 13 | """ 14 | x: HW x N x C 15 | permute to N x C x HW 16 | Apply BN on C 17 | permute back 18 | """ 19 | hw, n, c = x.shape 20 | x = x.permute(1, 2, 0) 21 | x = super(BatchNormDim1Swap, self).forward(x) 22 | # x: n x c x hw -> hw x n x c 23 | x = x.permute(2, 0, 1) 24 | return x 25 | 26 | 27 | NORM_DICT = { 28 | "bn": BatchNormDim1Swap, 29 | "bn1d": nn.BatchNorm1d, 30 | "id": nn.Identity, 31 | "ln": nn.LayerNorm, 32 | } 33 | 34 | ACTIVATION_DICT = { 35 | "relu": nn.ReLU, 36 | "gelu": nn.GELU, 37 | "leakyrelu": partial(nn.LeakyReLU, negative_slope=0.1), 38 | } 39 | 40 | WEIGHT_INIT_DICT = { 41 | "xavier_uniform": nn.init.xavier_uniform_, 42 | } 43 | 44 | 45 | class GenericMLP(nn.Module): 46 | def __init__( 47 | self, 48 | input_dim, 49 | hidden_dims, 50 | output_dim, 51 | norm_fn_name=None, 52 | activation="relu", 53 | use_conv=False, 54 | dropout=None, 55 | hidden_use_bias=False, 56 | output_use_bias=True, 57 | output_use_activation=False, 58 | output_use_norm=False, 59 | weight_init_name=None, 60 | ): 61 | super().__init__() 62 | activation = ACTIVATION_DICT[activation] 63 | norm = None 64 | if norm_fn_name is not None: 65 | norm = NORM_DICT[norm_fn_name] 66 | if norm_fn_name == "ln" and use_conv: 67 | norm = lambda x: nn.GroupNorm(1, x) # easier way to use LayerNorm 68 | 69 | if dropout is not None: 70 | if not isinstance(dropout, list): 71 | dropout = [dropout for _ in range(len(hidden_dims))] 72 | 73 | layers = [] 74 | prev_dim = input_dim 75 | for idx, x in enumerate(hidden_dims): 76 | if use_conv: 77 | layer = nn.Conv1d(prev_dim, x, 1, bias=hidden_use_bias) 78 | else: 79 | layer = nn.Linear(prev_dim, x, bias=hidden_use_bias) 80 | layers.append(layer) 81 | if norm: 82 | layers.append(norm(x)) 83 | layers.append(activation()) 84 | if dropout is not None: 85 | layers.append(nn.Dropout(p=dropout[idx])) 86 | prev_dim = x 87 | if use_conv: 88 | layer = nn.Conv1d(prev_dim, output_dim, 1, bias=output_use_bias) 89 | else: 90 | layer = nn.Linear(prev_dim, output_dim, bias=output_use_bias) 91 | layers.append(layer) 92 | 93 | if output_use_norm: 94 | layers.append(norm(output_dim)) 95 | 96 | if output_use_activation: 97 | layers.append(activation()) 98 | 99 | self.layers = nn.Sequential(*layers) 100 | 101 | if weight_init_name is not None: 102 | self.do_weight_init(weight_init_name) 103 | 104 | def do_weight_init(self, weight_init_name): 105 | func = WEIGHT_INIT_DICT[weight_init_name] 106 | for _, param in self.named_parameters(): 107 | if param.dim() > 1: # skips batchnorm/layernorm 108 | func(param) 109 | 110 | def forward(self, x): 111 | output = self.layers(x) 112 | return output 113 | 114 | 115 | def get_clones(module, N): 116 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 117 | -------------------------------------------------------------------------------- /Mask4Former3D/models/modules/resnet_block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from MinkowskiEngine import MinkowskiReLU 3 | 4 | from models.modules.common import ConvType, NormType, conv, get_norm 5 | 6 | 7 | class BasicBlockBase(nn.Module): 8 | expansion = 1 9 | NORM_TYPE = NormType.BATCH_NORM 10 | 11 | def __init__( 12 | self, 13 | inplanes, 14 | planes, 15 | stride=1, 16 | dilation=1, 17 | downsample=None, 
18 | conv_type=ConvType.HYPERCUBE, 19 | bn_momentum=0.1, 20 | D=3, 21 | ): 22 | super().__init__() 23 | 24 | self.conv1 = conv( 25 | inplanes, 26 | planes, 27 | kernel_size=3, 28 | stride=stride, 29 | dilation=dilation, 30 | conv_type=conv_type, 31 | D=D, 32 | ) 33 | self.norm1 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) 34 | self.conv2 = conv( 35 | planes, 36 | planes, 37 | kernel_size=3, 38 | stride=1, 39 | dilation=dilation, 40 | bias=False, 41 | conv_type=conv_type, 42 | D=D, 43 | ) 44 | self.norm2 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) 45 | self.relu = MinkowskiReLU(inplace=True) 46 | self.downsample = downsample 47 | 48 | def forward(self, x): 49 | residual = x 50 | 51 | out = self.conv1(x) 52 | out = self.norm1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.norm2(out) 57 | 58 | if self.downsample is not None: 59 | residual = self.downsample(x) 60 | 61 | out += residual 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class BasicBlock(BasicBlockBase): 68 | NORM_TYPE = NormType.BATCH_NORM 69 | 70 | 71 | class BasicBlockIN(BasicBlockBase): 72 | NORM_TYPE = NormType.INSTANCE_NORM 73 | 74 | 75 | class BasicBlockINBN(BasicBlockBase): 76 | NORM_TYPE = NormType.INSTANCE_BATCH_NORM 77 | 78 | 79 | class BottleneckBase(nn.Module): 80 | expansion = 4 81 | NORM_TYPE = NormType.BATCH_NORM 82 | 83 | def __init__( 84 | self, 85 | inplanes, 86 | planes, 87 | stride=1, 88 | dilation=1, 89 | downsample=None, 90 | conv_type=ConvType.HYPERCUBE, 91 | bn_momentum=0.1, 92 | D=3, 93 | ): 94 | super().__init__() 95 | self.conv1 = conv(inplanes, planes, kernel_size=1, D=D) 96 | self.norm1 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) 97 | 98 | self.conv2 = conv( 99 | planes, 100 | planes, 101 | kernel_size=3, 102 | stride=stride, 103 | dilation=dilation, 104 | conv_type=conv_type, 105 | D=D, 106 | ) 107 | self.norm2 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) 108 | 109 | self.conv3 = conv(planes, planes * self.expansion, kernel_size=1, D=D) 110 | self.norm3 = get_norm(self.NORM_TYPE, planes * self.expansion, D, bn_momentum=bn_momentum) 111 | 112 | self.relu = MinkowskiReLU(inplace=True) 113 | self.downsample = downsample 114 | 115 | def forward(self, x): 116 | residual = x 117 | 118 | out = self.conv1(x) 119 | out = self.norm1(out) 120 | out = self.relu(out) 121 | 122 | out = self.conv2(out) 123 | out = self.norm2(out) 124 | out = self.relu(out) 125 | 126 | out = self.conv3(out) 127 | out = self.norm3(out) 128 | 129 | if self.downsample is not None: 130 | residual = self.downsample(x) 131 | 132 | out += residual 133 | out = self.relu(out) 134 | 135 | return out 136 | 137 | 138 | class Bottleneck(BottleneckBase): 139 | NORM_TYPE = NormType.BATCH_NORM 140 | 141 | 142 | class BottleneckIN(BottleneckBase): 143 | NORM_TYPE = NormType.INSTANCE_NORM 144 | 145 | 146 | class BottleneckINBN(BottleneckBase): 147 | NORM_TYPE = NormType.INSTANCE_BATCH_NORM 148 | -------------------------------------------------------------------------------- /Mask4Former3D/models/position_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | 6 | def shift_scale_points(pred_xyz, input_range): 7 | """ 8 | pred_xyz: B x N x 3 9 | input_range: [[B x 3], [B x 3]] - min and max XYZ coords 10 | dst_range: [[B x 3], [B x 3]] - min and max XYZ coords 11 | """ 12 | dst_range = [ 13 | 
torch.zeros_like(input_range[0], device=input_range[0].device), 14 | torch.ones_like(input_range[0], device=input_range[0].device), 15 | ] 16 | 17 | src_diff = input_range[1][:, None, :] - input_range[0][:, None, :] 18 | dst_diff = dst_range[1][:, None, :] - dst_range[0][:, None, :] 19 | prop_xyz = (((pred_xyz - input_range[0][:, None, :]) * dst_diff) / src_diff) + dst_range[0][:, None, :] 20 | return prop_xyz 21 | 22 | 23 | class PositionEmbeddingCoordsSine(nn.Module): 24 | def __init__( 25 | self, 26 | d_in=3, 27 | d_pos=None, 28 | normalize=True, 29 | ): 30 | super().__init__() 31 | self.d_in = d_in 32 | self.d_pos = d_pos 33 | self.normalize = normalize 34 | 35 | # define a gaussian matrix input_ch -> output_ch 36 | B = torch.empty((d_in, d_pos // 2)).normal_() 37 | self.register_buffer("gauss_B", B) 38 | 39 | @torch.no_grad() 40 | def forward(self, xyz, num_channels=None, input_range=None): 41 | # xyz is batch x npoints x 3 42 | if num_channels is None: 43 | num_channels = self.gauss_B.shape[1] * 2 44 | 45 | bsize, npoints = xyz.shape[0], xyz.shape[1] 46 | d_out = num_channels // 2 47 | 48 | # clone coords so that shift/scale operations do not affect original tensor 49 | orig_xyz = xyz 50 | xyz = orig_xyz.clone() 51 | 52 | if self.normalize: 53 | xyz = shift_scale_points(xyz, input_range=input_range) 54 | 55 | xyz *= 2 * np.pi 56 | xyz_proj = torch.mm(xyz.view(-1, self.d_in), self.gauss_B[:, :d_out]).view(bsize, npoints, d_out) 57 | final_embeds = [xyz_proj.sin(), xyz_proj.cos()] 58 | 59 | # return batch x d_pos x npoints embedding 60 | final_embeds = torch.cat(final_embeds, dim=2).permute(0, 2, 1) 61 | return final_embeds 62 | -------------------------------------------------------------------------------- /Mask4Former3D/models/res16unet.py: -------------------------------------------------------------------------------- 1 | import MinkowskiEngine.MinkowskiOps as me 2 | from MinkowskiEngine import MinkowskiReLU 3 | 4 | from models.resnet import ResNetBase, get_norm 5 | from models.modules.common import ConvType, NormType, conv, conv_tr 6 | from models.modules.resnet_block import BasicBlock, Bottleneck 7 | 8 | 9 | class Res16UNetBase(ResNetBase): 10 | BLOCK = None 11 | PLANES = (32, 64, 128, 256, 256, 256, 256, 256) 12 | DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1) 13 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 14 | INIT_DIM = 32 15 | OUT_PIXEL_DIST = 1 16 | NORM_TYPE = NormType.BATCH_NORM 17 | NON_BLOCK_CONV_TYPE = ConvType.SPATIAL_HYPERCUBE 18 | CONV_TYPE = ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS 19 | 20 | # To use the model, must call initialize_coords before forward pass. 
21 | # Once data is processed, call clear to reset the model before calling initialize_coords 22 | def __init__(self, in_channels, out_channels, config, D=3, **kwargs): 23 | super().__init__(in_channels, out_channels, config, D) 24 | 25 | def network_initialization(self, in_channels, out_channels, config, D): 26 | # Setup net_metadata 27 | dilations = self.DILATIONS 28 | bn_momentum = config.bn_momentum 29 | 30 | def space_n_time_m(n, m): 31 | return n if D == 3 else [n, n, n, m] 32 | 33 | if D == 4: 34 | self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1) 35 | 36 | # Output of the first conv concated to conv6 37 | self.inplanes = self.INIT_DIM 38 | self.conv0p1s1 = conv( 39 | in_channels, 40 | self.inplanes, 41 | kernel_size=space_n_time_m(config.conv1_kernel_size, 1), 42 | stride=1, 43 | dilation=1, 44 | conv_type=self.NON_BLOCK_CONV_TYPE, 45 | D=D, 46 | ) 47 | 48 | self.bn0 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) 49 | 50 | self.conv1p1s2 = conv( 51 | self.inplanes, 52 | self.inplanes, 53 | kernel_size=space_n_time_m(2, 1), 54 | stride=space_n_time_m(2, 1), 55 | dilation=1, 56 | conv_type=self.NON_BLOCK_CONV_TYPE, 57 | D=D, 58 | ) 59 | self.bn1 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) 60 | self.block1 = self._make_layer( 61 | self.BLOCK, 62 | self.PLANES[0], 63 | self.LAYERS[0], 64 | dilation=dilations[0], 65 | norm_type=self.NORM_TYPE, 66 | bn_momentum=bn_momentum, 67 | ) 68 | 69 | self.conv2p2s2 = conv( 70 | self.inplanes, 71 | self.inplanes, 72 | kernel_size=space_n_time_m(2, 1), 73 | stride=space_n_time_m(2, 1), 74 | dilation=1, 75 | conv_type=self.NON_BLOCK_CONV_TYPE, 76 | D=D, 77 | ) 78 | self.bn2 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) 79 | self.block2 = self._make_layer( 80 | self.BLOCK, 81 | self.PLANES[1], 82 | self.LAYERS[1], 83 | dilation=dilations[1], 84 | norm_type=self.NORM_TYPE, 85 | bn_momentum=bn_momentum, 86 | ) 87 | 88 | self.conv3p4s2 = conv( 89 | self.inplanes, 90 | self.inplanes, 91 | kernel_size=space_n_time_m(2, 1), 92 | stride=space_n_time_m(2, 1), 93 | dilation=1, 94 | conv_type=self.NON_BLOCK_CONV_TYPE, 95 | D=D, 96 | ) 97 | self.bn3 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) 98 | self.block3 = self._make_layer( 99 | self.BLOCK, 100 | self.PLANES[2], 101 | self.LAYERS[2], 102 | dilation=dilations[2], 103 | norm_type=self.NORM_TYPE, 104 | bn_momentum=bn_momentum, 105 | ) 106 | 107 | self.conv4p8s2 = conv( 108 | self.inplanes, 109 | self.inplanes, 110 | kernel_size=space_n_time_m(2, 1), 111 | stride=space_n_time_m(2, 1), 112 | dilation=1, 113 | conv_type=self.NON_BLOCK_CONV_TYPE, 114 | D=D, 115 | ) 116 | self.bn4 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) 117 | self.block4 = self._make_layer( 118 | self.BLOCK, 119 | self.PLANES[3], 120 | self.LAYERS[3], 121 | dilation=dilations[3], 122 | norm_type=self.NORM_TYPE, 123 | bn_momentum=bn_momentum, 124 | ) 125 | self.convtr4p16s2 = conv_tr( 126 | self.inplanes, 127 | self.PLANES[4], 128 | kernel_size=space_n_time_m(2, 1), 129 | upsample_stride=space_n_time_m(2, 1), 130 | dilation=1, 131 | bias=False, 132 | conv_type=self.NON_BLOCK_CONV_TYPE, 133 | D=D, 134 | ) 135 | self.bntr4 = get_norm(self.NORM_TYPE, self.PLANES[4], D, bn_momentum=bn_momentum) 136 | 137 | self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion 138 | self.block5 = self._make_layer( 139 | self.BLOCK, 140 | self.PLANES[4], 141 | self.LAYERS[4], 142 | dilation=dilations[4], 143 | 
norm_type=self.NORM_TYPE, 144 | bn_momentum=bn_momentum, 145 | ) 146 | self.convtr5p8s2 = conv_tr( 147 | self.inplanes, 148 | self.PLANES[5], 149 | kernel_size=space_n_time_m(2, 1), 150 | upsample_stride=space_n_time_m(2, 1), 151 | dilation=1, 152 | bias=False, 153 | conv_type=self.NON_BLOCK_CONV_TYPE, 154 | D=D, 155 | ) 156 | self.bntr5 = get_norm(self.NORM_TYPE, self.PLANES[5], D, bn_momentum=bn_momentum) 157 | 158 | self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion 159 | self.block6 = self._make_layer( 160 | self.BLOCK, 161 | self.PLANES[5], 162 | self.LAYERS[5], 163 | dilation=dilations[5], 164 | norm_type=self.NORM_TYPE, 165 | bn_momentum=bn_momentum, 166 | ) 167 | self.convtr6p4s2 = conv_tr( 168 | self.inplanes, 169 | self.PLANES[6], 170 | kernel_size=space_n_time_m(2, 1), 171 | upsample_stride=space_n_time_m(2, 1), 172 | dilation=1, 173 | bias=False, 174 | conv_type=self.NON_BLOCK_CONV_TYPE, 175 | D=D, 176 | ) 177 | self.bntr6 = get_norm(self.NORM_TYPE, self.PLANES[6], D, bn_momentum=bn_momentum) 178 | 179 | self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion 180 | self.block7 = self._make_layer( 181 | self.BLOCK, 182 | self.PLANES[6], 183 | self.LAYERS[6], 184 | dilation=dilations[6], 185 | norm_type=self.NORM_TYPE, 186 | bn_momentum=bn_momentum, 187 | ) 188 | self.convtr7p2s2 = conv_tr( 189 | self.inplanes, 190 | self.PLANES[7], 191 | kernel_size=space_n_time_m(2, 1), 192 | upsample_stride=space_n_time_m(2, 1), 193 | dilation=1, 194 | bias=False, 195 | conv_type=self.NON_BLOCK_CONV_TYPE, 196 | D=D, 197 | ) 198 | self.bntr7 = get_norm(self.NORM_TYPE, self.PLANES[7], D, bn_momentum=bn_momentum) 199 | 200 | self.inplanes = self.PLANES[7] + self.INIT_DIM 201 | self.block8 = self._make_layer( 202 | self.BLOCK, 203 | self.PLANES[7], 204 | self.LAYERS[7], 205 | dilation=dilations[7], 206 | norm_type=self.NORM_TYPE, 207 | bn_momentum=bn_momentum, 208 | ) 209 | 210 | # self.final = conv( 211 | # self.PLANES[7], out_channels, kernel_size=1, stride=1, bias=True, D=D 212 | # ) 213 | self.relu = MinkowskiReLU(inplace=True) 214 | 215 | def forward(self, x): 216 | feature_maps = [] 217 | 218 | out = self.conv0p1s1(x) 219 | out = self.bn0(out) 220 | out_p1 = self.relu(out) 221 | 222 | out = self.conv1p1s2(out_p1) 223 | out = self.bn1(out) 224 | out = self.relu(out) 225 | out_b1p2 = self.block1(out) 226 | 227 | out = self.conv2p2s2(out_b1p2) 228 | out = self.bn2(out) 229 | out = self.relu(out) 230 | out_b2p4 = self.block2(out) 231 | 232 | out = self.conv3p4s2(out_b2p4) 233 | out = self.bn3(out) 234 | out = self.relu(out) 235 | out_b3p8 = self.block3(out) 236 | 237 | # pixel_dist=16 238 | out = self.conv4p8s2(out_b3p8) 239 | out = self.bn4(out) 240 | out = self.relu(out) 241 | out = self.block4(out) 242 | 243 | feature_maps.append(out) 244 | 245 | # pixel_dist=8 246 | out = self.convtr4p16s2(out) 247 | out = self.bntr4(out) 248 | out = self.relu(out) 249 | 250 | out = me.cat(out, out_b3p8) 251 | out = self.block5(out) 252 | 253 | feature_maps.append(out) 254 | 255 | # pixel_dist=4 256 | out = self.convtr5p8s2(out) 257 | out = self.bntr5(out) 258 | out = self.relu(out) 259 | 260 | out = me.cat(out, out_b2p4) 261 | out = self.block6(out) 262 | 263 | feature_maps.append(out) 264 | 265 | # pixel_dist=2 266 | out = self.convtr6p4s2(out) 267 | out = self.bntr6(out) 268 | out = self.relu(out) 269 | 270 | out = me.cat(out, out_b1p2) 271 | out = self.block7(out) 272 | 273 | feature_maps.append(out) 274 | 275 | # pixel_dist=1 276 | out = self.convtr7p2s2(out) 277 | out 
= self.bntr7(out) 278 | out = self.relu(out) 279 | 280 | out = me.cat(out, out_p1) 281 | out = self.block8(out) 282 | 283 | feature_maps.append(out) 284 | 285 | return feature_maps 286 | 287 | 288 | class Res16UNet14(Res16UNetBase): 289 | BLOCK = BasicBlock 290 | LAYERS = (1, 1, 1, 1, 1, 1, 1, 1) 291 | 292 | 293 | class Res16UNet18(Res16UNetBase): 294 | BLOCK = BasicBlock 295 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 296 | 297 | 298 | class Res16UNet34(Res16UNetBase): 299 | BLOCK = BasicBlock 300 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) 301 | 302 | 303 | class Res16UNet50(Res16UNetBase): 304 | BLOCK = Bottleneck 305 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) 306 | 307 | 308 | class Res16UNet101(Res16UNetBase): 309 | BLOCK = Bottleneck 310 | LAYERS = (2, 3, 4, 23, 2, 2, 2, 2) 311 | 312 | 313 | class Res16UNet14A(Res16UNet14): 314 | PLANES = (32, 64, 128, 256, 128, 128, 96, 96) 315 | 316 | 317 | class Res16UNet14A2(Res16UNet14A): 318 | LAYERS = (1, 1, 1, 1, 2, 2, 2, 2) 319 | 320 | 321 | class Res16UNet14B(Res16UNet14): 322 | PLANES = (32, 64, 128, 256, 128, 128, 128, 128) 323 | 324 | 325 | class Res16UNet14B2(Res16UNet14B): 326 | LAYERS = (1, 1, 1, 1, 2, 2, 2, 2) 327 | 328 | 329 | class Res16UNet14B3(Res16UNet14B): 330 | LAYERS = (2, 2, 2, 2, 1, 1, 1, 1) 331 | 332 | 333 | class Res16UNet14C(Res16UNet14): 334 | PLANES = (32, 64, 128, 256, 192, 192, 128, 128) 335 | 336 | 337 | class Res16UNet14D(Res16UNet14): 338 | PLANES = (32, 64, 128, 256, 384, 384, 384, 384) 339 | 340 | 341 | class Res16UNet18A(Res16UNet18): 342 | PLANES = (32, 64, 128, 256, 128, 128, 96, 96) 343 | 344 | 345 | class Res16UNet18B(Res16UNet18): 346 | PLANES = (32, 64, 128, 256, 128, 128, 128, 128) 347 | 348 | 349 | class Res16UNet18D(Res16UNet18): 350 | PLANES = (32, 64, 128, 256, 384, 384, 384, 384) 351 | 352 | 353 | class Res16UNet34A(Res16UNet34): 354 | PLANES = (32, 64, 128, 256, 256, 128, 64, 64) 355 | 356 | 357 | class Res16UNet34B(Res16UNet34): 358 | PLANES = (32, 64, 128, 256, 256, 128, 64, 32) 359 | 360 | 361 | class Res16UNet34C(Res16UNet34): 362 | PLANES = (32, 64, 128, 256, 256, 128, 96, 96) 363 | 364 | 365 | class Custom30M(Res16UNet34): 366 | PLANES = (32, 64, 128, 256, 128, 64, 64, 32) 367 | 368 | 369 | class Res16UNet34D(Res16UNet34): 370 | PLANES = (32, 64, 128, 256, 256, 128, 96, 128) 371 | 372 | 373 | class STRes16UNetBase(Res16UNetBase): 374 | CONV_TYPE = ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS 375 | 376 | def __init__(self, in_channels, out_channels, config, D=4, **kwargs): 377 | super().__init__(in_channels, out_channels, config, D, **kwargs) 378 | 379 | 380 | class STRes16UNet14(STRes16UNetBase, Res16UNet14): 381 | pass 382 | 383 | 384 | class STRes16UNet14A(STRes16UNetBase, Res16UNet14A): 385 | pass 386 | 387 | 388 | class STRes16UNet18(STRes16UNetBase, Res16UNet18): 389 | pass 390 | 391 | 392 | class STRes16UNet34(STRes16UNetBase, Res16UNet34): 393 | pass 394 | 395 | 396 | class STRes16UNet34C(STRes16UNetBase, Res16UNet34C): 397 | pass 398 | 399 | 400 | class STRes16UNet50(STRes16UNetBase, Res16UNet50): 401 | pass 402 | 403 | 404 | class STRes16UNet101(STRes16UNetBase, Res16UNet101): 405 | pass 406 | 407 | 408 | class STRes16UNet18A(STRes16UNet18): 409 | PLANES = (32, 64, 128, 256, 128, 128, 96, 96) 410 | 411 | 412 | class STResTesseract16UNetBase(STRes16UNetBase): 413 | pass 414 | # CONV_TYPE = ConvType.HYPERCUBE 415 | 416 | 417 | class STResTesseract16UNet18A(STRes16UNet18A, STResTesseract16UNetBase): 418 | pass 419 | -------------------------------------------------------------------------------- 
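For orientation, `Res16UNetBase.forward` above returns a list of five sparse feature maps at strides 16, 8, 4, 2, and 1 (coarse to fine), which the mask decoder consumes as a feature pyramid. Below is a minimal sketch of driving the 4D spatio-temporal variant with a MinkowskiEngine sparse tensor. The `config` fields, the random input, and the channel count are illustrative assumptions, and the `Model` base class in `models/model.py` is not shown here, so treat this as a sketch rather than the repository's exact entry point:

```python
from types import SimpleNamespace

import torch
import MinkowskiEngine as ME

from models.res16unet import STRes16UNet34C

# Hypothetical config carrying only the fields read by network_initialization.
config = SimpleNamespace(conv1_kernel_size=5, bn_momentum=0.1)
backbone = STRes16UNet34C(in_channels=2, out_channels=20, config=config, D=4)

# Quantized (x, y, z, t) coordinates with a leading batch index, plus features.
coords = torch.randint(0, 50, (1000, 4), dtype=torch.int32)
coords = torch.cat([torch.zeros(1000, 1, dtype=torch.int32), coords], dim=1)
feats = torch.rand(1000, 2)

x = ME.SparseTensor(features=feats, coordinates=coords)
feature_maps = backbone(x)  # list of 5 ME.SparseTensor objects, coarse to fine
print([fm.F.shape[1] for fm in feature_maps])  # per-level channel widths
```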
/Mask4Former3D/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import MinkowskiEngine as ME 3 | 4 | from models.model import Model 5 | from models.modules.common import ConvType, NormType, conv, get_norm, sum_pool 6 | from models.modules.resnet_block import BasicBlock, Bottleneck 7 | 8 | 9 | class ResNetBase(Model): 10 | BLOCK = None 11 | LAYERS = () 12 | INIT_DIM = 64 13 | PLANES = (64, 128, 256, 512) 14 | OUT_PIXEL_DIST = 32 15 | HAS_LAST_BLOCK = False 16 | CONV_TYPE = ConvType.HYPERCUBE 17 | 18 | def __init__(self, in_channels, out_channels, config, D=3, **kwargs): 19 | assert self.BLOCK is not None 20 | assert self.OUT_PIXEL_DIST > 0 21 | 22 | super().__init__(in_channels, out_channels, config, D, **kwargs) 23 | 24 | self.network_initialization(in_channels, out_channels, config, D) 25 | self.weight_initialization() 26 | 27 | def network_initialization(self, in_channels, out_channels, config, D): 28 | def space_n_time_m(n, m): 29 | return n if D == 3 else [n, n, n, m] 30 | 31 | if D == 4: 32 | self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1) 33 | 34 | dilations = config.dilations 35 | bn_momentum = config.bn_momentum 36 | self.inplanes = self.INIT_DIM 37 | self.conv1 = conv( 38 | in_channels, 39 | self.inplanes, 40 | kernel_size=space_n_time_m(config.conv1_kernel_size, 1), 41 | stride=1, 42 | D=D, 43 | ) 44 | 45 | self.bn1 = get_norm(NormType.BATCH_NORM, self.inplanes, D=self.D, bn_momentum=bn_momentum) 46 | self.relu = ME.MinkowskiReLU(inplace=True) 47 | self.pool = sum_pool(kernel_size=space_n_time_m(2, 1), stride=space_n_time_m(2, 1), D=D) 48 | 49 | self.layer1 = self._make_layer( 50 | self.BLOCK, 51 | self.PLANES[0], 52 | self.LAYERS[0], 53 | stride=space_n_time_m(2, 1), 54 | dilation=space_n_time_m(dilations[0], 1), 55 | ) 56 | self.layer2 = self._make_layer( 57 | self.BLOCK, 58 | self.PLANES[1], 59 | self.LAYERS[1], 60 | stride=space_n_time_m(2, 1), 61 | dilation=space_n_time_m(dilations[1], 1), 62 | ) 63 | self.layer3 = self._make_layer( 64 | self.BLOCK, 65 | self.PLANES[2], 66 | self.LAYERS[2], 67 | stride=space_n_time_m(2, 1), 68 | dilation=space_n_time_m(dilations[2], 1), 69 | ) 70 | self.layer4 = self._make_layer( 71 | self.BLOCK, 72 | self.PLANES[3], 73 | self.LAYERS[3], 74 | stride=space_n_time_m(2, 1), 75 | dilation=space_n_time_m(dilations[3], 1), 76 | ) 77 | 78 | self.final = conv( 79 | self.PLANES[3] * self.BLOCK.expansion, 80 | out_channels, 81 | kernel_size=1, 82 | bias=True, 83 | D=D, 84 | ) 85 | 86 | def weight_initialization(self): 87 | for m in self.modules(): 88 | if isinstance(m, ME.MinkowskiBatchNorm): 89 | nn.init.constant_(m.bn.weight, 1) 90 | nn.init.constant_(m.bn.bias, 0) 91 | 92 | def _make_layer( 93 | self, 94 | block, 95 | planes, 96 | blocks, 97 | stride=1, 98 | dilation=1, 99 | norm_type=NormType.BATCH_NORM, 100 | bn_momentum=0.1, 101 | ): 102 | downsample = None 103 | if stride != 1 or self.inplanes != planes * block.expansion: 104 | downsample = nn.Sequential( 105 | conv( 106 | self.inplanes, 107 | planes * block.expansion, 108 | kernel_size=1, 109 | stride=stride, 110 | bias=False, 111 | D=self.D, 112 | ), 113 | get_norm( 114 | norm_type, 115 | planes * block.expansion, 116 | D=self.D, 117 | bn_momentum=bn_momentum, 118 | ), 119 | ) 120 | layers = [] 121 | layers.append( 122 | block( 123 | self.inplanes, 124 | planes, 125 | stride=stride, 126 | dilation=dilation, 127 | downsample=downsample, 128 | conv_type=self.CONV_TYPE, 129 | D=self.D, 130 | ) 131 | ) 
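# The first block appended above absorbs any stride/downsample; the remaining
# blocks added below run at stride 1 on the expanded (planes * block.expansion) width.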
132 | self.inplanes = planes * block.expansion 133 | for i in range(1, blocks): 134 | layers.append( 135 | block( 136 | self.inplanes, 137 | planes, 138 | stride=1, 139 | dilation=dilation, 140 | conv_type=self.CONV_TYPE, 141 | D=self.D, 142 | ) 143 | ) 144 | 145 | return nn.Sequential(*layers) 146 | 147 | def forward(self, x): 148 | x = self.conv1(x) 149 | x = self.bn1(x) 150 | x = self.relu(x) 151 | x = self.pool(x) 152 | 153 | x = self.layer1(x) 154 | x = self.layer2(x) 155 | x = self.layer3(x) 156 | x = self.layer4(x) 157 | 158 | x = self.final(x) 159 | return x 160 | 161 | 162 | class ResNet14(ResNetBase): 163 | BLOCK = BasicBlock 164 | LAYERS = (1, 1, 1, 1) 165 | 166 | 167 | class ResNet18(ResNetBase): 168 | BLOCK = BasicBlock 169 | LAYERS = (2, 2, 2, 2) 170 | 171 | 172 | class ResNet34(ResNetBase): 173 | BLOCK = BasicBlock 174 | LAYERS = (3, 4, 6, 3) 175 | 176 | 177 | class ResNet50(ResNetBase): 178 | BLOCK = Bottleneck 179 | LAYERS = (3, 4, 6, 3) 180 | 181 | 182 | class ResNet101(ResNetBase): 183 | BLOCK = Bottleneck 184 | LAYERS = (3, 4, 23, 3) 185 | 186 | 187 | class STResNetBase(ResNetBase): 188 | CONV_TYPE = ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS 189 | 190 | def __init__(self, in_channels, out_channels, config, D=4, **kwargs): 191 | super().__init__(in_channels, out_channels, config, D, **kwargs) 192 | 193 | 194 | class STResNet14(STResNetBase, ResNet14): 195 | pass 196 | 197 | 198 | class STResNet18(STResNetBase, ResNet18): 199 | pass 200 | 201 | 202 | class STResNet34(STResNetBase, ResNet34): 203 | pass 204 | 205 | 206 | class STResNet50(STResNetBase, ResNet50): 207 | pass 208 | 209 | 210 | class STResNet101(STResNetBase, ResNet101): 211 | pass 212 | 213 | 214 | class STResTesseractNetBase(STResNetBase): 215 | CONV_TYPE = ConvType.HYPERCUBE 216 | 217 | 218 | class STResTesseractNet14(STResTesseractNetBase, STResNet14): 219 | pass 220 | 221 | 222 | class STResTesseractNet18(STResTesseractNetBase, STResNet18): 223 | pass 224 | 225 | 226 | class STResTesseractNet34(STResTesseractNetBase, STResNet34): 227 | pass 228 | 229 | 230 | class STResTesseractNet50(STResTesseractNetBase, STResNet50): 231 | pass 232 | 233 | 234 | class STResTesseractNet101(STResTesseractNetBase, STResNet101): 235 | pass 236 | -------------------------------------------------------------------------------- /Mask4Former3D/scripts/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export OMP_NUM_THREADS=12 # speeds up MinkowskiEngine 3 | export CUDA_LAUNCH_BLOCKING=1 4 | export HYDRA_FULL_ERROR=1 5 | 6 | EXPERIMENT_NAME="2024-01-01_000000" 7 | 8 | python main_panoptic.py \ 9 | general.mode="test" \ 10 | general.ckpt_path="saved/$EXPERIMENT_NAME/last-epoch.ckpt" \ 11 | general.dbscan_eps=1.0 -------------------------------------------------------------------------------- /Mask4Former3D/scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export OMP_NUM_THREADS=12 # speeds up MinkowskiEngine 3 | export CUDA_LAUNCH_BLOCKING=1 4 | export HYDRA_FULL_ERROR=1 5 | 6 | # TRAIN 7 | python main_panoptic.py -------------------------------------------------------------------------------- /Mask4Former3D/scripts/val.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export OMP_NUM_THREADS=12 # speeds up MinkowskiEngine 3 | export CUDA_LAUNCH_BLOCKING=1 4 | export HYDRA_FULL_ERROR=1 5 | 6 | EXPERIMENT_NAME="2024-01-01_000000" 7 | 8 | python 
main_panoptic.py \ 9 | general.mode="validate" \ 10 | general.ckpt_path="saved/$EXPERIMENT_NAME/last-epoch.ckpt" \ 11 | general.dbscan_eps=1.0 -------------------------------------------------------------------------------- /Mask4Former3D/trainer/pq_trainer.py: -------------------------------------------------------------------------------- 1 | import statistics 2 | from collections import defaultdict 3 | from contextlib import nullcontext 4 | from pathlib import Path 5 | 6 | import hydra 7 | import MinkowskiEngine as ME 8 | import numpy as np 9 | import pytorch_lightning as pl 10 | import torch 11 | from sklearn.cluster import DBSCAN 12 | 13 | from utils.utils import associate_instances 14 | 15 | 16 | class PanopticSegmentation(pl.LightningModule): 17 | def __init__(self, config): 18 | super().__init__() 19 | 20 | self.config = config 21 | self.save_hyperparameters() 22 | # model 23 | self.model = hydra.utils.instantiate(config.model) 24 | self.optional_freeze = nullcontext 25 | 26 | matcher = hydra.utils.instantiate(config.matcher) 27 | weight_dict = { 28 | "loss_ce": matcher.cost_class, 29 | "loss_mask": matcher.cost_mask, 30 | "loss_dice": matcher.cost_dice, 31 | "loss_box": matcher.cost_box, 32 | } 33 | 34 | aux_weight_dict = {} 35 | for i in range(self.model.num_levels * self.model.num_decoders): 36 | aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) 37 | weight_dict.update(aux_weight_dict) 38 | 39 | self.criterion = hydra.utils.instantiate( 40 | config.loss, matcher=matcher, weight_dict=weight_dict 41 | ) 42 | # metrics 43 | self.class_evaluator = hydra.utils.instantiate(config.metric) 44 | self.last_seq = None 45 | 46 | def forward(self, x, raw_coordinates=None, is_eval=False): 47 | with self.optional_freeze(): 48 | x = self.model(x, raw_coordinates=raw_coordinates, is_eval=is_eval) 49 | return x 50 | 51 | def training_step(self, batch, batch_idx): 52 | data, target = batch 53 | 54 | raw_coordinates = data.raw_coordinates 55 | data = ME.SparseTensor( 56 | coordinates=data.coordinates, features=data.features, device=self.device 57 | ) 58 | 59 | output = self.forward(data, raw_coordinates=raw_coordinates) 60 | losses = self.criterion(output, target) 61 | 62 | for k in list(losses.keys()): 63 | if k in self.criterion.weight_dict: 64 | losses[k] *= self.criterion.weight_dict[k] 65 | else: 66 | # remove this loss if not specified in `weight_dict` 67 | losses.pop(k) 68 | 69 | logs = {f"train_{k}": v.detach().cpu().item() for k, v in losses.items()} 70 | 71 | logs["train_mean_loss_ce"] = statistics.mean( 72 | [item for item in [v for k, v in logs.items() if "loss_ce" in k]] 73 | ) 74 | 75 | logs["train_mean_loss_mask"] = statistics.mean( 76 | [item for item in [v for k, v in logs.items() if "loss_mask" in k]] 77 | ) 78 | 79 | logs["train_mean_loss_dice"] = statistics.mean( 80 | [item for item in [v for k, v in logs.items() if "loss_dice" in k]] 81 | ) 82 | 83 | logs["train_mean_loss_box"] = statistics.mean( 84 | [item for item in [v for k, v in logs.items() if "loss_box" in k]] 85 | ) 86 | 87 | self.log_dict(logs) 88 | return sum(losses.values()) 89 | 90 | def test_step(self, batch, batch_idx): 91 | data, target = batch 92 | inverse_maps = data.inverse_maps 93 | # original_labels = data.original_labels 94 | raw_coordinates = data.raw_coordinates 95 | # num_points = data.num_points 96 | sequences = data.sequences 97 | 98 | data = ME.SparseTensor( 99 | coordinates=data.coordinates, features=data.features, device=self.device 100 | ) 101 | output = self.forward(data, 
raw_coordinates=raw_coordinates, is_eval=True) 102 | 103 | def get_maxlogit(logit, mask, inv_map): 104 | confid = mask.float().sigmoid().matmul(logit) 105 | max_logit = torch.max(confid, dim=1).values[inv_map] 106 | max_logit = (max_logit * -1) + 1 107 | return max_logit 108 | 109 | def get_rba(logit, mask, inv_map): 110 | confid = mask.float().sigmoid().matmul(logit) 111 | rba = -confid.tanh().sum(dim=1)[inv_map] 112 | rba[rba < -1] = -1 113 | rba = rba + 1 114 | return rba 115 | 116 | pred_logits = output["pred_logits"] 117 | pred_logits = torch.functional.F.softmax(pred_logits, dim=-1)[..., :-1] 118 | pred_masks = output["pred_masks"] 119 | 120 | save_path = Path(self.config.general.save_dir) / "prediction" 121 | for b_idx in range(len(pred_logits)): 122 | rba_score = get_rba( 123 | pred_logits[b_idx], pred_masks[b_idx], inverse_maps[b_idx] 124 | ) 125 | # here are some of the commented-out lines that were used to predict rba and max_logit 126 | # at the same time 127 | # max_logit = get_maxlogit( 128 | # pred_logits[b_idx], pred_masks[b_idx], inverse_maps[b_idx] 129 | # ) 130 | base_path = save_path / f"{sequences[b_idx][0]}" 131 | 132 | rba_save_path = base_path / f"{sequences[b_idx][1]}.txt" 133 | # rba_save_path = base_path / f"rba_{sequences[b_idx][1]}.txt" 134 | # max_logit_save_path = base_path / f"max_logit_{sequences[b_idx][1]}.txt" 135 | if not rba_save_path.parent.exists(): 136 | rba_save_path.parent.mkdir(exist_ok=True, parents=True) 137 | np.savetxt(rba_save_path, rba_score.detach().cpu().numpy()) 138 | # np.savetxt(max_logit_save_path, max_logit.detach().cpu().numpy()) 139 | 140 | # Save inlier predictions 141 | # confid = pred_masks[b_idx].float().sigmoid().matmul(pred_logits[b_idx]) 142 | # sem_preds = torch.argmax(confid, dim=1)[inverse_maps[b_idx]].cpu().numpy() 143 | # sem_preds_path = base_path / f"sem_preds_{sequences[b_idx][1]}.txt" 144 | # np.savetxt(sem_preds_path, sem_preds) 145 | 146 | return {} 147 | 148 | def validation_step(self, batch, batch_idx): 149 | data, target = batch 150 | inverse_maps = data.inverse_maps 151 | original_labels = data.original_labels 152 | raw_coordinates = data.raw_coordinates 153 | num_points = data.num_points 154 | sequences = data.sequences 155 | 156 | data = ME.SparseTensor( 157 | coordinates=data.coordinates, features=data.features, device=self.device 158 | ) 159 | output = self.forward(data, raw_coordinates=raw_coordinates, is_eval=True) 160 | losses = self.criterion(output, target) 161 | 162 | for k in list(losses.keys()): 163 | if k in self.criterion.weight_dict: 164 | losses[k] *= self.criterion.weight_dict[k] 165 | else: 166 | # remove this loss if not specified in `weight_dict` 167 | losses.pop(k) 168 | 169 | pred_logits = output["pred_logits"] 170 | pred_logits = torch.functional.F.softmax(pred_logits, dim=-1)[..., :-1] 171 | pred_masks = output["pred_masks"] 172 | offset_coords_idx = 0 173 | 174 | for logit, mask, map, label, n_point, seq in zip( 175 | pred_logits, 176 | pred_masks, 177 | inverse_maps, 178 | original_labels, 179 | num_points, 180 | sequences, 181 | ): 182 | seq = seq[0] 183 | if seq != self.last_seq: 184 | self.last_seq = seq 185 | self.previous_instances = None 186 | self.max_instance_id = self.config.model.num_queries 187 | self.scene = 0 188 | 189 | class_confidence, classes = torch.max(logit.detach().cpu(), dim=1) 190 | foreground_confidence = mask.detach().cpu().float().sigmoid() 191 | confidence = class_confidence[None, ...] 
* foreground_confidence 192 | confidence = confidence[map].numpy() 193 | 194 | ins_preds = np.argmax(confidence, axis=1) 195 | sem_preds = classes[ins_preds].numpy() + 1 196 | ins_preds += 1 197 | ins_preds[ 198 | np.isin( 199 | sem_preds, range(1, self.config.data.min_stuff_cls_id), invert=True 200 | ) 201 | ] = 0 202 | sem_labels = self.validation_dataset._remap_model_output(label[:, 0]) 203 | ins_labels = label[:, 1] >> 16 204 | 205 | db_max_instance_id = self.config.model.num_queries 206 | if self.config.general.dbscan_eps is not None: 207 | curr_coords_idx = mask.shape[0] 208 | curr_coords = raw_coordinates[ 209 | offset_coords_idx : curr_coords_idx + offset_coords_idx, :3 210 | ] 211 | curr_coords = curr_coords[map].detach().cpu().numpy() 212 | offset_coords_idx += curr_coords_idx 213 | 214 | ins_ids = np.unique(ins_preds) 215 | for ins_id in ins_ids: 216 | if ins_id != 0: 217 | instance_mask = ins_preds == ins_id 218 | clusters = ( 219 | DBSCAN( 220 | eps=self.config.general.dbscan_eps, 221 | min_samples=1, 222 | n_jobs=-1, 223 | ) 224 | .fit(curr_coords[instance_mask]) 225 | .labels_ 226 | ) 227 | new_mask = np.zeros(ins_preds.shape, dtype=np.int64) 228 | new_mask[instance_mask] = clusters + 1 229 | for cluster_id in np.unique(new_mask): 230 | if cluster_id != 0: 231 | db_max_instance_id += 1 232 | ins_preds[new_mask == cluster_id] = db_max_instance_id 233 | 234 | self.max_instance_id = max(db_max_instance_id, self.max_instance_id) 235 | for i in range(len(n_point) - 1): 236 | indices = range(n_point[i], n_point[i + 1]) 237 | if i == 0 and self.previous_instances is not None: 238 | current_instances = ins_preds[indices] 239 | associations = associate_instances( 240 | self.previous_instances, current_instances 241 | ) 242 | for id in np.unique(ins_preds): 243 | if associations.get(id) is None: 244 | self.max_instance_id += 1 245 | associations[id] = self.max_instance_id 246 | ins_preds = np.vectorize(associations.__getitem__)(ins_preds) 247 | else: 248 | self.class_evaluator.addBatch( 249 | sem_preds, ins_preds, sem_labels, ins_labels, indices, seq 250 | ) 251 | if i > 0: 252 | self.previous_instances = ins_preds[indices] 253 | 254 | return {f"val_{k}": v.detach().cpu().item() for k, v in losses.items()} 255 | 256 | def training_epoch_end(self, outputs): 257 | train_loss = sum([out["loss"].cpu().item() for out in outputs]) / len(outputs) 258 | results = {"train_loss_mean": train_loss} 259 | self.log_dict(results) 260 | 261 | def validation_epoch_end(self, outputs): 262 | self.last_seq = None 263 | class_names = self.config.data.class_names 264 | pq, sq, rq, all_pq, all_sq, all_rq = self.class_evaluator.getPQ() 265 | self.class_evaluator.reset() 266 | results = {} 267 | results["val_mean_pq"] = pq 268 | results["val_mean_sq"] = sq 269 | results["val_mean_rq"] = rq 270 | # print(class_names) 271 | # print(all_pq) 272 | # print(all_sq) 273 | # print(all_rq) 274 | for i, (pq, sq, rq) in enumerate(zip(all_pq, all_sq, all_rq)): 275 | results[f"val_{class_names[i-1]}_pq"] = pq.item() 276 | results[f"val_{class_names[i-1]}_sq"] = sq.item() 277 | results[f"val_{class_names[i-1]}_rq"] = rq.item() 278 | self.log_dict(results) 279 | print(results) 280 | 281 | dd = defaultdict(list) 282 | for output in outputs: 283 | for key, val in output.items(): 284 | dd[key].append(val) 285 | 286 | dd = {k: statistics.mean(v) for k, v in dd.items()} 287 | 288 | dd["val_mean_loss_ce"] = statistics.mean( 289 | [item for item in [v for k, v in dd.items() if "loss_ce" in k]] 290 | ) 291 | 
dd["val_mean_loss_mask"] = statistics.mean( 292 | [item for item in [v for k, v in dd.items() if "loss_mask" in k]] 293 | ) 294 | dd["val_mean_loss_dice"] = statistics.mean( 295 | [item for item in [v for k, v in dd.items() if "loss_dice" in k]] 296 | ) 297 | dd["val_mean_loss_box"] = statistics.mean( 298 | [item for item in [v for k, v in dd.items() if "loss_box" in k]] 299 | ) 300 | self.log_dict(dd) 301 | 302 | def test_epoch_end(self, outputs): 303 | return {} 304 | 305 | def configure_optimizers(self): 306 | optimizer = hydra.utils.instantiate( 307 | self.config.optimizer, params=self.parameters() 308 | ) 309 | if "steps_per_epoch" in self.config.scheduler.scheduler.keys(): 310 | self.config.scheduler.scheduler.steps_per_epoch = len( 311 | self.train_dataloader() 312 | ) 313 | lr_scheduler = hydra.utils.instantiate( 314 | self.config.scheduler.scheduler, optimizer=optimizer 315 | ) 316 | scheduler_config = {"scheduler": lr_scheduler} 317 | scheduler_config.update(self.config.scheduler.pytorch_lightning_params) 318 | return [optimizer], [scheduler_config] 319 | 320 | def prepare_data(self): 321 | self.train_dataset = hydra.utils.instantiate(self.config.data.train_dataset) 322 | self.validation_dataset = hydra.utils.instantiate( 323 | self.config.data.validation_dataset 324 | ) 325 | self.test_dataset = hydra.utils.instantiate(self.config.data.test_dataset) 326 | 327 | def train_dataloader(self): 328 | c_fn = hydra.utils.instantiate(self.config.data.train_collation) 329 | return hydra.utils.instantiate( 330 | self.config.data.train_dataloader, 331 | self.train_dataset, 332 | collate_fn=c_fn, 333 | ) 334 | 335 | def val_dataloader(self): 336 | c_fn = hydra.utils.instantiate(self.config.data.validation_collation) 337 | return hydra.utils.instantiate( 338 | self.config.data.validation_dataloader, 339 | self.validation_dataset, 340 | collate_fn=c_fn, 341 | ) 342 | 343 | def test_dataloader(self): 344 | c_fn = hydra.utils.instantiate(self.config.data.test_collation) 345 | return hydra.utils.instantiate( 346 | self.config.data.test_dataloader, 347 | self.test_dataset, 348 | collate_fn=c_fn, 349 | ) 350 | -------------------------------------------------------------------------------- /Mask4Former3D/trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import statistics 2 | import hydra 3 | import MinkowskiEngine as ME 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | import torch 7 | from sklearn.cluster import DBSCAN 8 | from contextlib import nullcontext 9 | from collections import defaultdict 10 | from utils.utils import associate_instances, save_predictions 11 | 12 | 13 | class PanopticSegmentation(pl.LightningModule): 14 | def __init__(self, config): 15 | super().__init__() 16 | 17 | self.config = config 18 | self.save_hyperparameters() 19 | # model 20 | self.model = hydra.utils.instantiate(config.model) 21 | self.optional_freeze = nullcontext 22 | 23 | matcher = hydra.utils.instantiate(config.matcher) 24 | weight_dict = { 25 | "loss_ce": matcher.cost_class, 26 | "loss_mask": matcher.cost_mask, 27 | "loss_dice": matcher.cost_dice, 28 | "loss_box": matcher.cost_box, 29 | } 30 | 31 | aux_weight_dict = {} 32 | for i in range(self.model.num_levels * self.model.num_decoders): 33 | aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) 34 | weight_dict.update(aux_weight_dict) 35 | 36 | self.criterion = hydra.utils.instantiate(config.loss, matcher=matcher, weight_dict=weight_dict) 37 | # metrics 38 | 
self.class_evaluator = hydra.utils.instantiate(config.metric) 39 | self.last_seq = None 40 | 41 | def forward(self, x, raw_coordinates=None, is_eval=False): 42 | with self.optional_freeze(): 43 | x = self.model(x, raw_coordinates=raw_coordinates, is_eval=is_eval) 44 | return x 45 | 46 | def training_step(self, batch, batch_idx): 47 | data, target = batch 48 | raw_coordinates = data.raw_coordinates 49 | data = ME.SparseTensor(coordinates=data.coordinates, features=data.features, device=self.device) 50 | 51 | output = self.forward(data, raw_coordinates=raw_coordinates) 52 | losses = self.criterion(output, target) 53 | 54 | for k in list(losses.keys()): 55 | if k in self.criterion.weight_dict: 56 | losses[k] *= self.criterion.weight_dict[k] 57 | else: 58 | # remove this loss if not specified in `weight_dict` 59 | losses.pop(k) 60 | 61 | logs = {f"train_{k}": v.detach().cpu().item() for k, v in losses.items()} 62 | 63 | logs["train_mean_loss_ce"] = statistics.mean( 64 | [item for item in [v for k, v in logs.items() if "loss_ce" in k]] 65 | ) 66 | 67 | logs["train_mean_loss_mask"] = statistics.mean( 68 | [item for item in [v for k, v in logs.items() if "loss_mask" in k]] 69 | ) 70 | 71 | logs["train_mean_loss_dice"] = statistics.mean( 72 | [item for item in [v for k, v in logs.items() if "loss_dice" in k]] 73 | ) 74 | 75 | logs["train_mean_loss_box"] = statistics.mean( 76 | [item for item in [v for k, v in logs.items() if "loss_box" in k]] 77 | ) 78 | 79 | self.log_dict(logs) 80 | return sum(losses.values()) 81 | 82 | def validation_step(self, batch, batch_idx): 83 | data, target = batch 84 | inverse_maps = data.inverse_maps 85 | original_labels = data.original_labels 86 | raw_coordinates = data.raw_coordinates 87 | num_points = data.num_points 88 | sequences = data.sequences 89 | 90 | data = ME.SparseTensor(coordinates=data.coordinates, features=data.features, device=self.device) 91 | output = self.forward(data, raw_coordinates=raw_coordinates, is_eval=True) 92 | losses = self.criterion(output, target) 93 | 94 | for k in list(losses.keys()): 95 | if k in self.criterion.weight_dict: 96 | losses[k] *= self.criterion.weight_dict[k] 97 | else: 98 | # remove this loss if not specified in `weight_dict` 99 | losses.pop(k) 100 | 101 | pred_logits = output["pred_logits"] 102 | pred_logits = torch.functional.F.softmax(pred_logits, dim=-1)[..., :-1] 103 | pred_masks = output["pred_masks"] 104 | offset_coords_idx = 0 105 | 106 | for logit, mask, map, label, n_point, seq in zip( 107 | pred_logits, pred_masks, inverse_maps, original_labels, num_points, sequences 108 | ): 109 | if seq != self.last_seq: 110 | self.last_seq = seq 111 | self.previous_instances = None 112 | self.max_instance_id = self.config.model.num_queries 113 | self.scene = 0 114 | 115 | class_confidence, classes = torch.max(logit.detach().cpu(), dim=1) 116 | foreground_confidence = mask.detach().cpu().float().sigmoid() 117 | confidence = class_confidence[None, ...] 
* foreground_confidence 118 | confidence = confidence[map].numpy() 119 | 120 | ins_preds = np.argmax(confidence, axis=1) 121 | sem_preds = classes[ins_preds].numpy() + 1 122 | ins_preds += 1 123 | ins_preds[np.isin(sem_preds, range(1, self.config.data.min_stuff_cls_id), invert=True)] = 0 124 | sem_labels = self.validation_dataset._remap_model_output(label[:, 0]) 125 | ins_labels = label[:, 1] >> 16 126 | 127 | db_max_instance_id = self.config.model.num_queries 128 | if self.config.general.dbscan_eps is not None: 129 | curr_coords_idx = mask.shape[0] 130 | curr_coords = raw_coordinates[offset_coords_idx : curr_coords_idx + offset_coords_idx, :3] 131 | curr_coords = curr_coords[map].detach().cpu().numpy() 132 | offset_coords_idx += curr_coords_idx 133 | 134 | ins_ids = np.unique(ins_preds) 135 | for ins_id in ins_ids: 136 | if ins_id != 0: 137 | instance_mask = ins_preds == ins_id 138 | clusters = ( 139 | DBSCAN(eps=self.config.general.dbscan_eps, min_samples=1, n_jobs=-1) 140 | .fit(curr_coords[instance_mask]) 141 | .labels_ 142 | ) 143 | new_mask = np.zeros(ins_preds.shape, dtype=np.int64) 144 | new_mask[instance_mask] = clusters + 1 145 | for cluster_id in np.unique(new_mask): 146 | if cluster_id != 0: 147 | db_max_instance_id += 1 148 | ins_preds[new_mask == cluster_id] = db_max_instance_id 149 | 150 | self.max_instance_id = max(db_max_instance_id, self.max_instance_id) 151 | for i in range(len(n_point) - 1): 152 | indices = range(n_point[i], n_point[i + 1]) 153 | if i == 0 and self.previous_instances is not None: 154 | current_instances = ins_preds[indices] 155 | associations = associate_instances(self.previous_instances, current_instances) 156 | for id in np.unique(ins_preds): 157 | if associations.get(id) is None: 158 | self.max_instance_id += 1 159 | associations[id] = self.max_instance_id 160 | ins_preds = np.vectorize(associations.__getitem__)(ins_preds) 161 | else: 162 | self.class_evaluator.addBatch(sem_preds, ins_preds, sem_labels, ins_labels, indices, seq) 163 | if i > 0: 164 | self.previous_instances = ins_preds[indices] 165 | 166 | return {f"val_{k}": v.detach().cpu().item() for k, v in losses.items()} 167 | 168 | def test_step(self, batch, batch_idx): 169 | data, _ = batch 170 | inverse_maps = data.inverse_maps 171 | raw_coordinates = data.raw_coordinates 172 | num_points = data.num_points 173 | sequences = data.sequences 174 | 175 | data = ME.SparseTensor(coordinates=data.coordinates, features=data.features, device=self.device) 176 | output = self.forward(data, raw_coordinates=raw_coordinates, is_eval=True) 177 | 178 | pred_logits = output["pred_logits"] 179 | pred_logits = torch.functional.F.softmax(pred_logits, dim=-1)[..., :-1] 180 | pred_masks = output["pred_masks"] 181 | 182 | offset_coords_idx = 0 183 | 184 | for logit, mask, map, n_point, seq in zip( 185 | pred_logits, pred_masks, inverse_maps, num_points, sequences 186 | ): 187 | if seq != self.last_seq: 188 | self.last_seq = seq 189 | self.previous_instances = None 190 | self.max_instance_id = self.config.model.num_queries 191 | self.scene = 0 192 | class_confidence, classes = torch.max(logit.detach().cpu(), dim=1) 193 | foreground_confidence = mask.detach().cpu().float().sigmoid() 194 | confidence = class_confidence[None, ...] 
* foreground_confidence 195 | confidence = confidence[map].numpy() 196 | 197 | ins_preds = np.argmax(confidence, axis=1) 198 | sem_preds = classes[ins_preds].numpy() + 1 199 | ins_preds += 1 200 | ins_preds[np.isin(sem_preds, range(1, self.config.data.min_stuff_cls_id), invert=True)] = 0 201 | 202 | db_max_instance_id = self.config.model.num_queries 203 | if self.config.general.dbscan_eps is not None: 204 | curr_coords_idx = mask.shape[0] 205 | curr_coords = raw_coordinates[offset_coords_idx : curr_coords_idx + offset_coords_idx, :3] 206 | curr_coords = curr_coords[map].detach().cpu().numpy() 207 | offset_coords_idx += curr_coords_idx 208 | 209 | ins_ids = np.unique(ins_preds) 210 | for ins_id in ins_ids: 211 | if ins_id != 0: 212 | instance_mask = ins_preds == ins_id 213 | clusters = ( 214 | DBSCAN(eps=self.config.general.dbscan_eps, min_samples=1, n_jobs=-1) 215 | .fit(curr_coords[instance_mask]) 216 | .labels_ 217 | ) 218 | new_mask = np.zeros(ins_preds.shape, dtype=np.int64) 219 | new_mask[instance_mask] = clusters + 1 220 | for cluster_id in np.unique(new_mask): 221 | if cluster_id != 0: 222 | db_max_instance_id += 1 223 | ins_preds[new_mask == cluster_id] = db_max_instance_id 224 | 225 | self.max_instance_id = max(db_max_instance_id, self.max_instance_id) 226 | for i in range(len(n_point) - 1): 227 | indices = range(n_point[i], n_point[i + 1]) 228 | if i == 0 and self.previous_instances is not None: 229 | current_instances = ins_preds[indices] 230 | associations = associate_instances(self.previous_instances, current_instances) 231 | for id in np.unique(ins_preds): 232 | if associations.get(id) is None: 233 | self.max_instance_id += 1 234 | associations[id] = self.max_instance_id 235 | ins_preds = np.vectorize(associations.__getitem__)(ins_preds) 236 | else: 237 | save_predictions(sem_preds[indices], ins_preds[indices], f"{seq:02}", f"{self.scene:06}") 238 | self.scene += 1 239 | if i > 0: 240 | self.previous_instances = ins_preds[indices] 241 | 242 | return {} 243 | 244 | def training_epoch_end(self, outputs): 245 | train_loss = sum([out["loss"].cpu().item() for out in outputs]) / len(outputs) 246 | results = {"train_loss_mean": train_loss} 247 | self.log_dict(results) 248 | 249 | def validation_epoch_end(self, outputs): 250 | self.last_seq = None 251 | class_names = self.config.data.class_names 252 | lstq, aq, all_aq, iou, all_iou = self.class_evaluator.getPQ4D() 253 | self.class_evaluator.reset() 254 | results = {} 255 | results["val_mean_aq"] = aq 256 | results["val_mean_iou"] = iou 257 | results["val_mean_lstq"] = lstq 258 | for i, (aq, iou) in enumerate(zip(all_aq, all_iou)): 259 | results[f"val_{class_names[i]}_aq"] = aq.item() 260 | results[f"val_{class_names[i]}_iou"] = iou.item() 261 | self.log_dict(results) 262 | 263 | dd = defaultdict(list) 264 | for output in outputs: 265 | for key, val in output.items(): 266 | dd[key].append(val) 267 | 268 | dd = {k: statistics.mean(v) for k, v in dd.items()} 269 | 270 | dd["val_mean_loss_ce"] = statistics.mean( 271 | [item for item in [v for k, v in dd.items() if "loss_ce" in k]] 272 | ) 273 | dd["val_mean_loss_mask"] = statistics.mean( 274 | [item for item in [v for k, v in dd.items() if "loss_mask" in k]] 275 | ) 276 | dd["val_mean_loss_dice"] = statistics.mean( 277 | [item for item in [v for k, v in dd.items() if "loss_dice" in k]] 278 | ) 279 | dd["val_mean_loss_box"] = statistics.mean( 280 | [item for item in [v for k, v in dd.items() if "loss_box" in k]] 281 | ) 282 | 283 | self.log_dict(dd) 284 | 285 | def test_epoch_end(self, 
outputs): 286 | return {} 287 | 288 | def configure_optimizers(self): 289 | optimizer = hydra.utils.instantiate(self.config.optimizer, params=self.parameters()) 290 | if "steps_per_epoch" in self.config.scheduler.scheduler.keys(): 291 | self.config.scheduler.scheduler.steps_per_epoch = len(self.train_dataloader()) 292 | lr_scheduler = hydra.utils.instantiate(self.config.scheduler.scheduler, optimizer=optimizer) 293 | scheduler_config = {"scheduler": lr_scheduler} 294 | scheduler_config.update(self.config.scheduler.pytorch_lightning_params) 295 | return [optimizer], [scheduler_config] 296 | 297 | def prepare_data(self): 298 | self.train_dataset = hydra.utils.instantiate(self.config.data.train_dataset) 299 | self.validation_dataset = hydra.utils.instantiate(self.config.data.validation_dataset) 300 | self.test_dataset = hydra.utils.instantiate(self.config.data.test_dataset) 301 | 302 | def train_dataloader(self): 303 | c_fn = hydra.utils.instantiate(self.config.data.train_collation) 304 | return hydra.utils.instantiate( 305 | self.config.data.train_dataloader, 306 | self.train_dataset, 307 | collate_fn=c_fn, 308 | ) 309 | 310 | def val_dataloader(self): 311 | c_fn = hydra.utils.instantiate(self.config.data.validation_collation) 312 | return hydra.utils.instantiate( 313 | self.config.data.validation_dataloader, 314 | self.validation_dataset, 315 | collate_fn=c_fn, 316 | ) 317 | 318 | def test_dataloader(self): 319 | c_fn = hydra.utils.instantiate(self.config.data.test_collation) 320 | return hydra.utils.instantiate( 321 | self.config.data.test_dataloader, 322 | self.test_dataset, 323 | collate_fn=c_fn, 324 | ) 325 | -------------------------------------------------------------------------------- /Mask4Former3D/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kumuji/stu_dataset/3812d768f3634fbb6faeb8b0bfbd5246a9798e93/Mask4Former3D/utils/__init__.py -------------------------------------------------------------------------------- /Mask4Former3D/utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from scipy.optimize import linear_sum_assignment 4 | import sys 5 | import pytorch_lightning as pl 6 | from pathlib import Path 7 | import os 8 | 9 | if sys.version_info[:2] >= (3, 8): 10 | from collections.abc import MutableMapping 11 | else: 12 | from collections import MutableMapping 13 | 14 | 15 | def flatten_dict(d, parent_key="", sep="_"): 16 | """ 17 | https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys 18 | """ 19 | items = [] 20 | for k, v in d.items(): 21 | new_key = parent_key + sep + k if parent_key else k 22 | if isinstance(v, MutableMapping): 23 | items.extend(flatten_dict(v, new_key, sep=sep).items()) 24 | else: 25 | items.append((new_key, v)) 26 | return dict(items) 27 | 28 | 29 | class RegularCheckpointing(pl.Callback): 30 | def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): 31 | general = pl_module.config.general 32 | trainer.save_checkpoint(f"{general.save_dir}/last-epoch.ckpt") 33 | print("Checkpoint created") 34 | 35 | 36 | def associate_instances(previous_ins_label, current_ins_label): 37 | previous_instance_ids, c_p = np.unique(previous_ins_label[previous_ins_label != 0], return_counts=True) 38 | current_instance_ids, c_c = np.unique(current_ins_label[current_ins_label != 0], return_counts=True) 39 | 40 | associations = {0: 0} 41 | 42 | 
large_previous_instance_ids = [] 43 | large_current_instance_ids = [] 44 | for id, count in zip(previous_instance_ids, c_p): 45 | if count > 25: 46 | large_previous_instance_ids.append(id) 47 | for id, count in zip(current_instance_ids, c_c): 48 | if count > 50: 49 | large_current_instance_ids.append(id) 50 | 51 | p_n = len(large_previous_instance_ids) 52 | c_n = len(large_current_instance_ids) 53 | 54 | association_costs = torch.zeros(p_n, c_n) 55 | for i, p_id in enumerate(large_previous_instance_ids): 56 | for j, c_id in enumerate(large_current_instance_ids): 57 | intersection = np.sum((previous_ins_label == p_id) & (current_ins_label == c_id)) 58 | union = np.sum(previous_ins_label == p_id) + np.sum(current_ins_label == c_id) - intersection 59 | iou = intersection / union 60 | cost = 1 - iou 61 | association_costs[i, j] = cost 62 | 63 | idxes_1, idxes_2 = linear_sum_assignment(association_costs) 64 | 65 | for i1, i2 in zip(idxes_1, idxes_2): 66 | if association_costs[i1][i2] < 1.0: 67 | associations[large_current_instance_ids[i2]] = large_previous_instance_ids[i1] 68 | return associations 69 | 70 | 71 | def save_predictions(sem_preds, ins_preds, seq_name, sweep_name): 72 | filename = Path("/globalwork/yilmaz/submission/sequences") / seq_name / "predictions" 73 | # assert not filename.exists(), "Path exists" 74 | filename.mkdir(parents=True, exist_ok=True) 75 | learning_map_inv = { 76 | 1: 10, # "car" 77 | 2: 11, # "bicycle" 78 | 3: 15, # "motorcycle" 79 | 4: 18, # "truck" 80 | 5: 20, # "other-vehicle" 81 | 6: 30, # "person" 82 | 7: 31, # "bicyclist" 83 | 8: 32, # "motorcyclist" 84 | 9: 40, # "road" 85 | 10: 44, # "parking" 86 | 11: 48, # "sidewalk" 87 | 12: 49, # "other-ground" 88 | 13: 50, # "building" 89 | 14: 51, # "fence" 90 | 15: 70, # "vegetation" 91 | 16: 71, # "trunk" 92 | 17: 72, # "terrain" 93 | 18: 80, # "pole" 94 | 19: 81, # "traffic-sign" 95 | } 96 | sem_preds = np.vectorize(learning_map_inv.__getitem__)(sem_preds) 97 | panoptic_preds = (ins_preds << 16) + sem_preds 98 | file_path = str(filename / sweep_name) + ".label" 99 | if not os.path.exists(file_path): 100 | with open(file_path, "wb") as f: 101 | f.write(panoptic_preds.astype(np.uint32).tobytes()) 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spotting the Unexpected (STU): A 3D LiDAR Dataset for Anomaly Segmentation in Autonomous Driving 2 |
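As a closing note on `save_predictions` in `Mask4Former3D/utils/utils.py` above: predictions are written in the SemanticKITTI panoptic `.label` format, one `uint32` per point with the semantic class id in the lower 16 bits and the instance id in the upper 16 bits, mirroring the `label[:, 1] >> 16` decoding used in the trainers. A small sketch for reading such a file back (the path is hypothetical):

```python
import numpy as np

# One uint32 per point: upper 16 bits = instance id, lower 16 bits = class id.
labels = np.fromfile("sequences/08/predictions/000000.label", dtype=np.uint32)
sem_preds = labels & 0xFFFF  # SemanticKITTI class ids (e.g. 10 = car, 40 = road)
ins_preds = labels >> 16     # per-point instance ids, 0 for stuff classes
```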