├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── assets └── method.png ├── bevlab ├── __init__.py ├── backbones │ ├── __init__.py │ ├── minkunet.py │ ├── minkunet_me.py │ ├── minkunet_segcontrast.py │ ├── spconv_backbone.py │ └── spvcnn.py ├── dataloader.py ├── datasets.py ├── imagenet.py ├── models.py ├── pnp.py ├── resnet_encoder.py ├── transforms.py └── vision_transformer.py ├── cfgs ├── pretrain_ns_minkunet.yaml ├── pretrain_ns_spconv.yaml ├── pretrain_ns_spvcnn.yaml ├── pretrain_sk_minkunet.yaml └── pretrain_sk_spvcnn.yaml ├── downstream ├── configs │ ├── cfg │ │ ├── nuscenes_minkowski.yaml │ │ ├── nuscenes_torchsparse.yaml │ │ ├── semantickitti_minkowski.yaml │ │ ├── semantickitti_torchsparse.yaml │ │ └── semanticposs_minkowski.yaml │ └── config.yaml ├── convert_models.py ├── datasets │ ├── __init__.py │ ├── kitti360.py │ ├── kitti3d.py │ ├── kitti3d_train.txt │ ├── kitti3d_val.txt │ ├── kitti_360_train_velodynes.txt │ ├── kitti_360_val_velodynes.txt │ ├── nuscenes_category.json │ ├── nuscenes_dataset.py │ ├── once.py │ ├── percentiles_split.json │ ├── semantic-kitti.yaml │ ├── semantic_poss_segcontrast.py │ └── semantickitti_dataset.py ├── eval.py ├── eval_offset.py ├── networks │ ├── __init__.py │ └── backbone │ │ ├── __init__.py │ │ ├── minkowski_engine │ │ ├── __init__.py │ │ ├── minkunet.py │ │ ├── minkunet_segcontrast.py │ │ └── utils.py │ │ ├── spconv │ │ ├── __init__.py │ │ ├── pcdet_models.py │ │ └── utils.py │ │ └── torchsparse │ │ ├── __init__.py │ │ ├── minkunet.py │ │ ├── spvcnn.py │ │ └── utils.py ├── train_downstream_semseg.py ├── transforms │ ├── __init__.py │ ├── create_inputs.py │ ├── create_points.py │ ├── duplicate.py │ ├── get_transforms.py │ ├── random_flip.py │ ├── random_rotate.py │ ├── scaling.py │ └── voxel_decimation.py ├── utils │ ├── __init__.py │ ├── callbacks.py │ ├── confusion_matrix.py │ ├── metrics.py │ └── utils.py └── visu_downstream.py ├── requirements.txt ├── train.py └── utils ├── __init__.py ├── config.py ├── convert_spconv_model.py ├── logger.py ├── optimizer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | datasets 3 | output 4 | outputs 5 | **/__pycache__ 6 | *.pyc 7 | **/.ipynb_checkpoints 8 | downstream/results 9 | checkpoints/ 10 | *.pt 11 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.12.0-cuda11.3-cudnn8-devel 2 | 3 | # -------------------------------------------------------------------- 4 | 5 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub 6 | 7 | RUN DEBIAN_FRONTEND=noninteractive apt-get update 8 | RUN DEBIAN_FRONTEND=noninteractive apt-get install -y openssh-server sudo 9 | 10 | # ------------------------------------------------------------------- 11 | 12 | ############################################## 13 | ENV TORCH_CUDA_ARCH_LIST="7.5 8.0 8.6+PTX" 14 | ############################################## 15 | 16 | ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all" 17 | 18 | # Install dependencies 19 | RUN apt-get update 20 | RUN apt-get install -y git ninja-build cmake build-essential libopenblas-dev \ 21 | xterm xauth openssh-server tmux wget mate-desktop-environment-core 22 | 23 | RUN apt-get clean 24 | RUN rm -rf /var/lib/apt/lists/* 25 | 26 | # For faster build, use more jobs. 
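# Note (assumption): MAX_JOBS is read by PyTorch's C++/CUDA extension builder, which the
# MinkowskiEngine and torchsparse installs below use when compiling from source; if the build
# machine has enough RAM, raising it towards the number of CPU cores shortens those builds.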
27 | ENV MAX_JOBS=4 28 | RUN pip install -U git+https://github.com/NVIDIA/MinkowskiEngine --install-option="--blas=openblas" --install-option="--force_cuda" -v --no-deps 29 | 30 | RUN apt-get install libgl1-mesa-glx -y 31 | RUN pip install spconv-cu113 32 | 33 | # Torchsparse 34 | RUN apt-get update 35 | RUN apt-get install libsparsehash-dev 36 | RUN FORCE_CUDA=1 pip install --upgrade git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0 37 | 38 | # Torch Geometric 39 | RUN pip install torch_scatter -f https://data.pyg.org/whl/torch-1.12.0+cu113.html 40 | RUN pip install torch-geometric hydra-core 41 | RUN pip install tqdm easydict tensorboardX 42 | 43 | RUN pip install pytorch-lightning==1.6.5 44 | 45 | RUN pip install nuscenes-devkit==1.1.9 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | NCLR 2 | 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | https://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | 16 | 17 | 18 | Apache License 19 | Version 2.0, January 2004 20 | https://www.apache.org/licenses/ 21 | 22 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 23 | 24 | 1. Definitions. 25 | 26 | "License" shall mean the terms and conditions for use, reproduction, 27 | and distribution as defined by Sections 1 through 9 of this document. 28 | 29 | "Licensor" shall mean the copyright owner or entity authorized by 30 | the copyright owner that is granting the License. 31 | 32 | "Legal Entity" shall mean the union of the acting entity and all 33 | other entities that control, are controlled by, or are under common 34 | control with that entity. For the purposes of this definition, 35 | "control" means (i) the power, direct or indirect, to cause the 36 | direction or management of such entity, whether by contract or 37 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 38 | outstanding shares, or (iii) beneficial ownership of such entity. 39 | 40 | "You" (or "Your") shall mean an individual or Legal Entity 41 | exercising permissions granted by this License. 42 | 43 | "Source" form shall mean the preferred form for making modifications, 44 | including but not limited to software source code, documentation 45 | source, and configuration files. 46 | 47 | "Object" form shall mean any form resulting from mechanical 48 | transformation or translation of a Source form, including but 49 | not limited to compiled object code, generated documentation, 50 | and conversions to other media types. 51 | 52 | "Work" shall mean the work of authorship, whether in Source or 53 | Object form, made available under the License, as indicated by a 54 | copyright notice that is included in or attached to the work 55 | (an example is provided in the Appendix below). 
56 | 57 | "Derivative Works" shall mean any work, whether in Source or Object 58 | form, that is based on (or derived from) the Work and for which the 59 | editorial revisions, annotations, elaborations, or other modifications 60 | represent, as a whole, an original work of authorship. For the purposes 61 | of this License, Derivative Works shall not include works that remain 62 | separable from, or merely link (or bind by name) to the interfaces of, 63 | the Work and Derivative Works thereof. 64 | 65 | "Contribution" shall mean any work of authorship, including 66 | the original version of the Work and any modifications or additions 67 | to that Work or Derivative Works thereof, that is intentionally 68 | submitted to Licensor for inclusion in the Work by the copyright owner 69 | or by an individual or Legal Entity authorized to submit on behalf of 70 | the copyright owner. For the purposes of this definition, "submitted" 71 | means any form of electronic, verbal, or written communication sent 72 | to the Licensor or its representatives, including but not limited to 73 | communication on electronic mailing lists, source code control systems, 74 | and issue tracking systems that are managed by, or on behalf of, the 75 | Licensor for the purpose of discussing and improving the Work, but 76 | excluding communication that is conspicuously marked or otherwise 77 | designated in writing by the copyright owner as "Not a Contribution." 78 | 79 | "Contributor" shall mean Licensor and any individual or Legal Entity 80 | on behalf of whom a Contribution has been received by Licensor and 81 | subsequently incorporated within the Work. 82 | 83 | 2. Grant of Copyright License. Subject to the terms and conditions of 84 | this License, each Contributor hereby grants to You a perpetual, 85 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 86 | copyright license to reproduce, prepare Derivative Works of, 87 | publicly display, publicly perform, sublicense, and distribute the 88 | Work and such Derivative Works in Source or Object form. 89 | 90 | 3. Grant of Patent License. Subject to the terms and conditions of 91 | this License, each Contributor hereby grants to You a perpetual, 92 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 93 | (except as stated in this section) patent license to make, have made, 94 | use, offer to sell, sell, import, and otherwise transfer the Work, 95 | where such license applies only to those patent claims licensable 96 | by such Contributor that are necessarily infringed by their 97 | Contribution(s) alone or by combination of their Contribution(s) 98 | with the Work to which such Contribution(s) was submitted. If You 99 | institute patent litigation against any entity (including a 100 | cross-claim or counterclaim in a lawsuit) alleging that the Work 101 | or a Contribution incorporated within the Work constitutes direct 102 | or contributory patent infringement, then any patent licenses 103 | granted to You under this License for that Work shall terminate 104 | as of the date such litigation is filed. 105 | 106 | 4. Redistribution. 
You may reproduce and distribute copies of the 107 | Work or Derivative Works thereof in any medium, with or without 108 | modifications, and in Source or Object form, provided that You 109 | meet the following conditions: 110 | 111 | (a) You must give any other recipients of the Work or 112 | Derivative Works a copy of this License; and 113 | 114 | (b) You must cause any modified files to carry prominent notices 115 | stating that You changed the files; and 116 | 117 | (c) You must retain, in the Source form of any Derivative Works 118 | that You distribute, all copyright, patent, trademark, and 119 | attribution notices from the Source form of the Work, 120 | excluding those notices that do not pertain to any part of 121 | the Derivative Works; and 122 | 123 | (d) If the Work includes a "NOTICE" text file as part of its 124 | distribution, then any Derivative Works that You distribute must 125 | include a readable copy of the attribution notices contained 126 | within such NOTICE file, excluding those notices that do not 127 | pertain to any part of the Derivative Works, in at least one 128 | of the following places: within a NOTICE text file distributed 129 | as part of the Derivative Works; within the Source form or 130 | documentation, if provided along with the Derivative Works; or, 131 | within a display generated by the Derivative Works, if and 132 | wherever such third-party notices normally appear. The contents 133 | of the NOTICE file are for informational purposes only and 134 | do not modify the License. You may add Your own attribution 135 | notices within Derivative Works that You distribute, alongside 136 | or as an addendum to the NOTICE text from the Work, provided 137 | that such additional attribution notices cannot be construed 138 | as modifying the License. 139 | 140 | You may add Your own copyright statement to Your modifications and 141 | may provide additional or different license terms and conditions 142 | for use, reproduction, or distribution of Your modifications, or 143 | for any such Derivative Works as a whole, provided Your use, 144 | reproduction, and distribution of the Work otherwise complies with 145 | the conditions stated in this License. 146 | 147 | 5. Submission of Contributions. Unless You explicitly state otherwise, 148 | any Contribution intentionally submitted for inclusion in the Work 149 | by You to the Licensor shall be under the terms and conditions of 150 | this License, without any additional terms or conditions. 151 | Notwithstanding the above, nothing herein shall supersede or modify 152 | the terms of any separate license agreement you may have executed 153 | with Licensor regarding such Contributions. 154 | 155 | 6. Trademarks. This License does not grant permission to use the trade 156 | names, trademarks, service marks, or product names of the Licensor, 157 | except as required for reasonable and customary use in describing the 158 | origin of the Work and reproducing the content of the NOTICE file. 159 | 160 | 7. Disclaimer of Warranty. Unless required by applicable law or 161 | agreed to in writing, Licensor provides the Work (and each 162 | Contributor provides its Contributions) on an "AS IS" BASIS, 163 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 164 | implied, including, without limitation, any warranties or conditions 165 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 166 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 167 | appropriateness of using or redistributing the Work and assume any 168 | risks associated with Your exercise of permissions under this License. 169 | 170 | 8. Limitation of Liability. In no event and under no legal theory, 171 | whether in tort (including negligence), contract, or otherwise, 172 | unless required by applicable law (such as deliberate and grossly 173 | negligent acts) or agreed to in writing, shall any Contributor be 174 | liable to You for damages, including any direct, indirect, special, 175 | incidental, or consequential damages of any character arising as a 176 | result of this License or out of the use or inability to use the 177 | Work (including but not limited to damages for loss of goodwill, 178 | work stoppage, computer failure or malfunction, or any and all 179 | other commercial damages or losses), even if such Contributor 180 | has been advised of the possibility of such damages. 181 | 182 | 9. Accepting Warranty or Additional Liability. While redistributing 183 | the Work or Derivative Works thereof, You may choose to offer, 184 | and charge a fee for, acceptance of support, warranty, indemnity, 185 | or other liability obligations and/or rights consistent with this 186 | License. However, in accepting such obligations, You may act only 187 | on Your own behalf and on Your sole responsibility, not on behalf 188 | of any other Contributor, and only if You agree to indemnify, 189 | defend, and hold each Contributor harmless for any liability 190 | incurred by, or claims asserted against, such Contributor by reason 191 | of your accepting any such warranty or additional liability. 192 | 193 | END OF TERMS AND CONDITIONS 194 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-supervised Learning of LiDAR 3D PointClouds via 2D-3D Neural Calibration 2 | 3 | Official Pytorch implementation of the method **NCLR**. More details can be found in the paper: 4 | 5 | **Self-supervised Learning of LiDAR 3D PointClouds via 2D-3D Neural Calibration**, Arxiv 2024 [[arXiv](https://arxiv.org/abs/2401.12452)] 6 | by *Yifan Zhang, Siyu Ren, Junhui Hou, Jinjian Wu, Yixuan Yuan, Guangming Shi* 7 | 8 | ![Overview of the method](./assets/method.png) 9 | 10 | If you use NCLR in your research, please cite: 11 | ``` 12 | @article{zhang2024nclr, 13 | title={Self-supervised Learning of LiDAR 3D Point Clouds via 2D-3D Neural Calibration}, 14 | author={Zhang, Yifan and Ren, Siyu and Hou, Junhui and Wu, Jinjian and Yuan, Yixuan and Shi, Guangming}, 15 | journal={arXiv preprint arXiv:2401.12452}, 16 | year={2024} 17 | } 18 | ``` 19 | 20 | ## Dependencies 21 | 22 | To install the various dependencies, you can run ```pip install -r requirements.txt```. 23 | 24 | 25 | ## Datasets 26 | 27 | The code provided can be used with [nuScenes](https://www.nuscenes.org/lidar-segmentation), [SemanticKITTI](http://www.semantic-kitti.org/tasks.html#semseg), and [SemanticPOSS](http://www.poss.pku.edu.cn/semanticposs.html). Put the datasets you intend to use in the datasets folder (a symbolic link is accepted). 
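The provided configuration files resolve the dataset roots to `data/nuscenes`, `data/semantic_kitti` and `data/semantic_poss` relative to the repository root (the downstream configs reference the same folders through `../data/...`). Below is a minimal sketch for creating the corresponding symbolic links; the source paths are placeholders to adapt to your setup.

```python
import os

# Placeholder source locations: point them to your local copies of the datasets.
links = {
    "/path/to/nuscenes": "data/nuscenes",
    "/path/to/semantic_kitti": "data/semantic_kitti",
    "/path/to/semantic_poss": "data/semantic_poss",
}
for source, target in links.items():
    os.makedirs(os.path.dirname(target), exist_ok=True)  # create the data/ folder once
    if not os.path.exists(target):
        os.symlink(source, target)  # a plain copy of the dataset works as well
```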
28 | 29 | ## Pre-trained models 30 | 31 | ### Minkowski SR-UNet 32 | [SR-UNet pre-trained on nuScenes](todo) 33 | 34 | [SR-UNet pre-trained on SemanticKITTI](todo) 35 | 36 | ### SPconv VoxelNet 37 | [VoxelNet pre-trained on nuScenes](todo) 38 | 39 | ## Reproducing the results 40 | 41 | *When using MinkowskiEngine (on SemanticKITTI), please set the OMP_NUM_THREADS environment variable to your number of CPU cores* 42 | 43 | ### Semantic segmentation's pre-training 44 | 45 | Config file for SemanticKITTI is included for [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine) by default to keep retro-compatibility with previous work, while for nuScenes it uses [Torchsparse](https://github.com/mit-han-lab/torchsparse) which is generally faster. Switching between libraries in the config files is easy. While architectures are similar, weights from one library cannot easily be transferred to the other. 46 | 47 | - On nuScenes: 48 | 49 | ```python train.py --config_file cfgs/pretrain_ns_minkunet.yaml --name minkunet_nclr_ns``` 50 | 51 | - On SemanticKITTI: 52 | 53 | ```python train.py --config_file cfgs/pretrain_sk_minkunet.yaml --name minkunet_nclr_sk``` 54 | 55 | ### Semantic segmentation's downstream 56 | 57 | The specific code for downstream semantic segmentation has been adapted from [ALSO](https://github.com/valeoai/ALSO). 58 | 59 | #### Results on nuScenes' validation set using a Minkowski SR-Unet 34: 60 | Method | 0.1% | 1% | 10% | 50% | 100% 61 | --- |:-: |:-: |:-: |:-: |:-: 62 | Random init. | 21.6 | 35.0 | 57.3 | 69.0 | 71.2 63 | [PointContrast](https://arxiv.org/abs/2007.10985) | 27.1 | 37.0 | 58.9 | 69.4 | 71.1 64 | [DepthContrast](https://arxiv.org/abs/2101.02691) | 21.7 | 34.6 | 57.4 | 69.2 | 71.2 65 | [ALSO](https://arxiv.org/abs/2104.04687) | 26.2 | 37.4 | 59.0 | 69.8 | 71.8 66 | NCLR | **26.6** |**37.8**|**59.5**|**71.2**|**72.7** 67 | 68 | To launch a downstream experiment, with a Torchsparse SR-Unet, you can use these commands in addition with `cfg.downstream.checkpoint_dir=[checkpoint directory] cfg.downstream.checkpoint_name=[checkpoint name]` 69 | 70 | ```bash 71 | cd downstream 72 | 73 | # 100% 74 | python train_downstream_semseg.py cfg=nuscenes_torchsparse cfg.downstream.max_epochs=30 cfg.downstream.val_interval=5 cfg.downstream.skip_ratio=1 75 | 76 | # 50% 77 | python train_downstream_semseg.py cfg=nuscenes_torchsparse cfg.downstream.max_epochs=50 cfg.downstream.val_interval=5 cfg.downstream.skip_ratio=2 78 | 79 | # 10% 80 | python train_downstream_semseg.py cfg=nuscenes_torchsparse cfg.downstream.max_epochs=100 cfg.downstream.val_interval=10 cfg.downstream.skip_ratio=10 81 | 82 | # 1% 83 | python train_downstream_semseg.py cfg=nuscenes_torchsparse cfg.downstream.max_epochs=500 cfg.downstream.val_interval=50 cfg.downstream.skip_ratio=100 84 | 85 | # 0.1% 86 | python train_downstream_semseg.py cfg=nuscenes_torchsparse cfg.downstream.max_epochs=1000 cfg.downstream.val_interval=100 cfg.downstream.skip_ratio=1000 87 | ``` 88 | 89 | 90 | 101 | 102 | To launch a downstream experiment, with a Minkowski SR-Unet, you can use these commands in addition with `cfg.downstream.checkpoint_dir=[checkpoint directory] cfg.downstream.checkpoint_name=[checkpoint name]` 103 | 104 | ```bash 105 | cd downstream 106 | 107 | # 100% 108 | python train_downstream_semseg.py cfg=nuscenes_minkowski cfg.downstream.max_epochs=30 cfg.downstream.val_interval=5 cfg.downstream.skip_ratio=1 109 | 110 | # 50% 111 | python train_downstream_semseg.py cfg=nuscenes_minkowski cfg.downstream.max_epochs=50 
cfg.downstream.val_interval=5 cfg.downstream.skip_ratio=2 112 | 113 | # 10% 114 | python train_downstream_semseg.py cfg=nuscenes_minkowski cfg.downstream.max_epochs=100 cfg.downstream.val_interval=10 cfg.downstream.skip_ratio=10 115 | 116 | # 1% 117 | python train_downstream_semseg.py cfg=nuscenes_minkowski cfg.downstream.max_epochs=500 cfg.downstream.val_interval=50 cfg.downstream.skip_ratio=100 118 | 119 | # 0.1% 120 | python train_downstream_semseg.py cfg=nuscenes_minkowski cfg.downstream.max_epochs=1000 cfg.downstream.val_interval=100 cfg.downstream.skip_ratio=1000 121 | ``` 122 | 123 | ### Object detection's pre-training 124 | 125 | ```python train.py --config_file cfgs/pretrain_ns_spconv.yaml --name voxelnet_nclr_ns``` 126 | 127 | ### Object detection's downstream 128 | 129 | Please use the [OpenPCDet](https://github.com/open-mmlab/OpenPCDet) code with default parameters for SECOND or PVRCNN, and without multiprocessing, to retain compatibility with previous work and with this one. 130 | 131 | ### Panoptic segmentation baseline 132 | The panoptic segmentation baseline [MinkowskiPanoptic](https://github.com/PRBonn/MinkowskiPanoptic) is implemented on top of the MinkowskiEngine library. 133 | 134 | ## Acknowledgment 135 | 136 | Part of the codebase has been adapted from [OpenPCDet](https://github.com/open-mmlab/OpenPCDet), [ALSO](https://github.com/valeoai/ALSO), [BEVContrast](https://github.com/valeoai/BEVContrast), and [SLidR](https://github.com/valeoai/SLidR). 137 | 138 | ## TODO List 139 | 140 | - [x] Initial release. 141 | - [x] Add license. See [here](#license) for more details. 142 | - [x] Add installation details. 143 | - [x] Add data preparation details. 144 | - [x] Add evaluation details. 145 | - [x] Add training details. 146 | - [ ] LTA 147 | - [ ] Add pre-trained weights. 148 | 149 | ## Recommended Works 150 | Here are some of the methods I recommend for 3D representation learning: 151 | - SLidR: [Paper](https://arxiv.org/abs/2203.16258), [Code](https://github.com/valeoai/SLidR) 152 | - OLIVINE: [Paper](https://arxiv.org/abs/2405.14271), [Code](https://github.com/Eaphan/OLIVINE) 153 | - BEVContrast: [Paper](https://arxiv.org/abs/2310.17281), [Code](https://github.com/valeoai/BEVContrast) 154 | - MinkowskiPanoptic: [Code](https://github.com/PRBonn/MinkowskiPanoptic) 155 | 156 | ## License 157 | NCLR is released under the [Apache 2.0 license](./LICENSE). 
158 | -------------------------------------------------------------------------------- /assets/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Eaphan/NCLR/b3a944af649b64f0aed82aae0211ebc5f2fe2d13/assets/method.png -------------------------------------------------------------------------------- /bevlab/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Eaphan/NCLR/b3a944af649b64f0aed82aae0211ebc5f2fe2d13/bevlab/__init__.py -------------------------------------------------------------------------------- /bevlab/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .spconv_backbone import BEVNet 2 | from .minkunet import MinkUNet34 3 | from .minkunet_segcontrast import SegContrastMinkUNet18 4 | from .spvcnn import SPVCNN 5 | 6 | __all__ = ["BEVNet", "MinkUNet34", "SegContrastMinkUNet18", "SPVCNN"] 7 | -------------------------------------------------------------------------------- /bevlab/backbones/minkunet_segcontrast.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import MinkowskiEngine as ME 4 | import logging 5 | 6 | 7 | class BasicConvolutionBlock(nn.Module): 8 | def __init__(self, inc, outc, ks=3, stride=1, dilation=1, D=3): 9 | super().__init__() 10 | self.net = nn.Sequential( 11 | ME.MinkowskiConvolution(inc, 12 | outc, 13 | kernel_size=ks, 14 | dilation=dilation, 15 | stride=stride, 16 | dimension=D), 17 | ME.MinkowskiBatchNorm(outc), 18 | ME.MinkowskiReLU(inplace=True) 19 | ) 20 | 21 | def forward(self, x): 22 | out = self.net(x) 23 | return out 24 | 25 | 26 | class BasicDeconvolutionBlock(nn.Module): 27 | def __init__(self, inc, outc, ks=3, stride=1, D=3): 28 | super().__init__() 29 | self.net = nn.Sequential( 30 | ME.MinkowskiConvolutionTranspose(inc, 31 | outc, 32 | kernel_size=ks, 33 | stride=stride, 34 | dimension=D), 35 | ME.MinkowskiBatchNorm(outc), 36 | ME.MinkowskiReLU(inplace=True) 37 | ) 38 | 39 | def forward(self, x): 40 | return self.net(x) 41 | 42 | 43 | class ResidualBlock(nn.Module): 44 | def __init__(self, inc, outc, ks=3, stride=1, dilation=1, D=3): 45 | super().__init__() 46 | self.net = nn.Sequential( 47 | ME.MinkowskiConvolution(inc, 48 | outc, 49 | kernel_size=ks, 50 | dilation=dilation, 51 | stride=stride, 52 | dimension=D), 53 | ME.MinkowskiBatchNorm(outc), 54 | ME.MinkowskiReLU(inplace=True), 55 | ME.MinkowskiConvolution(outc, 56 | outc, 57 | kernel_size=ks, 58 | dilation=dilation, 59 | stride=1, 60 | dimension=D), 61 | ME.MinkowskiBatchNorm(outc) 62 | ) 63 | 64 | self.downsample = nn.Sequential() if (inc == outc and stride == 1) else \ 65 | nn.Sequential( 66 | ME.MinkowskiConvolution(inc, outc, kernel_size=1, dilation=1, stride=stride, dimension=D), 67 | ME.MinkowskiBatchNorm(outc) 68 | ) 69 | 70 | self.relu = ME.MinkowskiReLU(inplace=True) 71 | 72 | def forward(self, x): 73 | out = self.relu(self.net(x) + self.downsample(x)) 74 | return out 75 | 76 | 77 | class MinkUNet(nn.Module): 78 | def __init__(self, **kwargs): 79 | super().__init__() 80 | 81 | cr = kwargs.get('cr', 1.0) 82 | in_channels = kwargs.get('in_channels', 3) 83 | out_channels = kwargs.get('out_channels', 0) 84 | cs = [32, 32, 64, 128, 256, 256, 128, 96, 96] 85 | cs = [int(cr * x) for x in cs] 86 | self.run_up = kwargs.get('run_up', True) 87 | self.D = kwargs.get('D', 3) 88 | self.stem = 
nn.Sequential( 89 | ME.MinkowskiConvolution(in_channels, cs[0], kernel_size=3, stride=1, dimension=self.D), 90 | ME.MinkowskiBatchNorm(cs[0]), 91 | ME.MinkowskiReLU(True), 92 | ME.MinkowskiConvolution(cs[0], cs[0], kernel_size=3, stride=1, dimension=self.D), 93 | ME.MinkowskiBatchNorm(cs[0]), 94 | ME.MinkowskiReLU(inplace=True) 95 | ) 96 | 97 | self.stage1 = nn.Sequential( 98 | BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1, D=self.D), 99 | ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1, D=self.D), 100 | ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1, D=self.D), 101 | ) 102 | 103 | self.stage2 = nn.Sequential( 104 | BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1, D=self.D), 105 | ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1, D=self.D), 106 | ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1, D=self.D) 107 | ) 108 | 109 | self.stage3 = nn.Sequential( 110 | BasicConvolutionBlock(cs[2], cs[2], ks=2, stride=2, dilation=1, D=self.D), 111 | ResidualBlock(cs[2], cs[3], ks=3, stride=1, dilation=1, D=self.D), 112 | ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1, D=self.D), 113 | ) 114 | 115 | self.stage4 = nn.Sequential( 116 | BasicConvolutionBlock(cs[3], cs[3], ks=2, stride=2, dilation=1, D=self.D), 117 | ResidualBlock(cs[3], cs[4], ks=3, stride=1, dilation=1, D=self.D), 118 | ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1, D=self.D), 119 | ) 120 | 121 | self.up1 = nn.ModuleList([ 122 | BasicDeconvolutionBlock(cs[4], cs[5], ks=2, stride=2, D=self.D), 123 | nn.Sequential( 124 | ResidualBlock(cs[5] + cs[3], cs[5], ks=3, stride=1, 125 | dilation=1, D=self.D), 126 | ResidualBlock(cs[5], cs[5], ks=3, stride=1, dilation=1, D=self.D), 127 | ) 128 | ]) 129 | 130 | self.up2 = nn.ModuleList([ 131 | BasicDeconvolutionBlock(cs[5], cs[6], ks=2, stride=2, D=self.D), 132 | nn.Sequential( 133 | ResidualBlock(cs[6] + cs[2], cs[6], ks=3, stride=1, 134 | dilation=1, D=self.D), 135 | ResidualBlock(cs[6], cs[6], ks=3, stride=1, dilation=1, D=self.D), 136 | ) 137 | ]) 138 | 139 | self.up3 = nn.ModuleList([ 140 | BasicDeconvolutionBlock(cs[6], cs[7], ks=2, stride=2, D=self.D), 141 | nn.Sequential( 142 | ResidualBlock(cs[7] + cs[1], cs[7], ks=3, stride=1, 143 | dilation=1, D=self.D), 144 | ResidualBlock(cs[7], cs[7], ks=3, stride=1, dilation=1, D=self.D), 145 | ) 146 | ]) 147 | 148 | self.up4 = nn.ModuleList([ 149 | BasicDeconvolutionBlock(cs[7], cs[8], ks=2, stride=2, D=self.D), 150 | nn.Sequential( 151 | ResidualBlock(cs[8] + cs[0], cs[8], ks=3, stride=1, 152 | dilation=1, D=self.D), 153 | ResidualBlock(cs[8], cs[8], ks=3, stride=1, dilation=1, D=self.D), 154 | ) 155 | ]) 156 | 157 | if out_channels > 0: 158 | 159 | if 'head' in kwargs and kwargs['head'] == "bn_linear": 160 | logging.info("network - bn linear head") 161 | self.final = nn.Sequential(nn.BatchNorm1d(cs[8], affine=False), nn.Linear(cs[8], out_channels)) 162 | else: 163 | logging.info("network - linear head") 164 | self.final = nn.Sequential(nn.Linear(cs[8], out_channels)) 165 | else: 166 | self.final = None 167 | 168 | self.weight_initialization() 169 | self.dropout = nn.Dropout(0.3, True) 170 | 171 | def weight_initialization(self): 172 | for m in self.modules(): 173 | if isinstance(m, nn.BatchNorm1d): 174 | if m.weight is not None: 175 | nn.init.constant_(m.weight, 1) 176 | if m.bias is not None: 177 | nn.init.constant_(m.bias, 0) 178 | 179 | def forward(self, x): 180 | x0 = self.stem(x) 181 | x1 = self.stage1(x0) 182 | x2 = self.stage2(x1) 183 | x3 = self.stage3(x2) 184 | x4 
= self.stage4(x3) 185 | 186 | y1 = self.up1[0](x4) 187 | y1 = ME.cat(y1, x3) 188 | y1 = self.up1[1](y1) 189 | 190 | y2 = self.up2[0](y1) 191 | y2 = ME.cat(y2, x2) 192 | y2 = self.up2[1](y2) 193 | 194 | y3 = self.up3[0](y2) 195 | y3 = ME.cat(y3, x1) 196 | y3 = self.up3[1](y3) 197 | 198 | y4 = self.up4[0](y3) 199 | y4 = ME.cat(y4, x0) 200 | y4 = self.up4[1](y4) 201 | 202 | yout = self.final(y4.F) 203 | 204 | return yout 205 | 206 | 207 | class SegContrastMinkUNet18(MinkUNet): 208 | 209 | def __init__(self, in_channels, out_channels, **kwargs): 210 | super().__init__(in_channels=in_channels, out_channels=out_channels, **kwargs) 211 | 212 | def forward(self, feats, coords): 213 | 214 | # coords = torch.cat([data["voxel_coords_batch"].unsqueeze(1), data["voxel_coords"]], dim=1).int() 215 | # feats = data["voxel_x"] 216 | input = ME.SparseTensor(feats, coords) 217 | 218 | outputs = super().forward(input) 219 | 220 | # vox_num = data["voxel_number"] 221 | # increment = torch.cat([vox_num.new_zeros((1,)), vox_num[:-1]], dim=0) 222 | # increment = increment.cumsum(0) 223 | # increment = increment[data["batch"]] 224 | # inv_map = data["voxel_to_pc_id"] + increment 225 | 226 | # # interpolate the outputs 227 | # outputs = outputs[inv_map] 228 | 229 | return outputs 230 | 231 | def get_last_layer_channels(self): 232 | return self.PLANES[-1] 233 | -------------------------------------------------------------------------------- /bevlab/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cumm.tensorview as tv 4 | from functools import partial 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader 7 | from bevlab import datasets as datasets # TODO 8 | from torchsparse.utils.quantize import sparse_quantize as sparse_quantize_torchsparse 9 | import MinkowskiEngine as ME 10 | from torch.utils.data.distributed import DistributedSampler 11 | from spconv.utils import Point2VoxelCPU3d as VoxelGenerator 12 | 13 | 14 | def collate_torchsparse(voxel_size, use_coords, list_data): 15 | batch = {} 16 | for key in list_data[0]: 17 | batch[key] = [l[key] for l in list_data] 18 | 19 | batch["voxels"] = [] 20 | batch['indexes'] = [] 21 | batch['inv_indexes'] = [] 22 | batch['batch_n_points'] = [] 23 | coords = [] 24 | batch_id = 0 25 | offset = 0 26 | for group_pc in batch['points']: 27 | pairing_points = batch["pairing_points"][batch_id] 28 | pc_ = group_pc[:, :3] / voxel_size 29 | pc_min = pc_.min(0, keepdims=1) 30 | pc_ -= pc_min 31 | 32 | coordinates, indexes, inv_indexes = sparse_quantize_torchsparse(pc_, return_index=True, return_inverse=True) 33 | coords.append(F.pad(torch.from_numpy(coordinates), (0, 1, 0, 0), value=batch_id)) 34 | 35 | pairing_points = inv_indexes[pairing_points] 36 | pairing_points += offset 37 | batch['pairing_points'][batch_id] = pairing_points 38 | batch['pairing_images'][batch_id][:, 0] += batch_id * batch['images'][0].shape[0] 39 | 40 | batch['batch_n_points'].append(coordinates.shape[0]) 41 | batch['indexes'].append(torch.from_numpy(indexes)) 42 | batch['inv_indexes'].append(torch.from_numpy(inv_indexes) + offset) 43 | if use_coords: 44 | batch['voxels'].append(torch.from_numpy(group_pc[indexes])) 45 | else: 46 | batch['voxels'].append(torch.from_numpy(group_pc[indexes, 3:])) 47 | batch_id += 1 48 | offset += coordinates.shape[0] 49 | 50 | batch['coordinates'] = torch.cat(coords, 0).int() 51 | batch['voxels'] = torch.cat(batch['voxels']) 52 | batch['images'] = 
torch.cat([torch.from_numpy(img) for img in batch['images']], 0).float() 53 | batch['img_overlap_masks'] = torch.cat([torch.from_numpy(img) for img in batch['img_overlap_masks']], 0).float() 54 | 55 | batch['pairing_points'] = torch.tensor(np.concatenate(batch['pairing_points'])) 56 | batch['pairing_images'] = torch.tensor(np.concatenate(batch['pairing_images'])) 57 | # batch['pc_overlap_mask'] = [torch.from_numpy(mask) for mask in batch['pc_overlap_mask']] 58 | # batch['pc_overlap_mask'] = torch.cat(batch['pc_overlap_mask']) 59 | 60 | batch['R_data'] = torch.stack([torch.from_numpy(R) for R in batch['R_data']], axis=0).float() 61 | batch['T_data'] = torch.stack([torch.from_numpy(T) for T in batch['T_data']], axis=0).float() 62 | if 'K_data' in batch: 63 | batch['K_data'] = torch.stack([torch.from_numpy(K) for K in batch['K_data']], axis=0).float() 64 | 65 | return batch 66 | 67 | 68 | def collate_minkowski(voxel_size, use_coords, list_data): 69 | batch = {} 70 | for key in list_data[0]: 71 | batch[key] = [l[key] for l in list_data] 72 | 73 | batch["voxels"] = [] 74 | batch['indexes'] = [] 75 | batch['inv_indexes'] = [] 76 | batch['batch_n_points'] = [] 77 | batch['pc_min'] = [] 78 | coords = [] 79 | batch_id = 0 80 | offset = 0 81 | for group_pc in batch["points"]: 82 | pairing_points = batch["pairing_points"][batch_id] 83 | coords_aug = group_pc[:, :3] / voxel_size 84 | pc_ = np.round(coords_aug).astype(np.int32) 85 | pc_min = pc_.min(0, keepdims=1) 86 | pc_ -= pc_min 87 | 88 | coordinates, indexes, inv_indexes = ME.utils.quantization.sparse_quantize(pc_, return_index=True, return_inverse=True) 89 | coords.append(F.pad(coordinates, (1, 0, 0, 0), value=batch_id)) 90 | 91 | pairing_points = inv_indexes[pairing_points] 92 | pairing_points += offset 93 | batch['pairing_points'][batch_id] = pairing_points 94 | batch['pairing_images'][batch_id][:, 0] += batch_id * batch['images'][0].shape[0] 95 | 96 | batch['batch_n_points'].append(coordinates.shape[0]) 97 | pc_min = pc_min.repeat(indexes.shape[0], 0) 98 | batch['pc_min'].append(torch.from_numpy(pc_min)) 99 | batch["indexes"].append(indexes) 100 | batch['inv_indexes'].append(inv_indexes) 101 | if use_coords: 102 | batch["voxels"].append(torch.from_numpy(group_pc[indexes])) 103 | else: 104 | batch["voxels"].append(torch.from_numpy(group_pc[indexes, 3:])) 105 | batch_id += 1 106 | offset += coordinates.shape[0] 107 | batch['coordinates'] = torch.cat(coords, 0).int() 108 | # batch['pc_min'] = torch.cat(batch['pc_min']) 109 | batch['voxels'] = torch.cat(batch['voxels']) 110 | 111 | # batch['pc_overlap_mask'] = [torch.from_numpy(mask) for mask in batch['pc_overlap_mask']] 112 | # batch['pc_overlap_mask'] = torch.cat(batch['pc_overlap_mask']) 113 | 114 | batch['images'] = torch.cat([torch.from_numpy(img) for img in batch['images']], 0).float() 115 | batch['img_overlap_masks'] = torch.cat([torch.from_numpy(img) for img in batch['img_overlap_masks']], 0).float() 116 | 117 | batch['pairing_points'] = torch.tensor(np.concatenate(batch['pairing_points'])) 118 | batch['pairing_images'] = torch.tensor(np.concatenate(batch['pairing_images'])) 119 | batch['R_data'] = torch.stack([torch.from_numpy(R) for R in batch['R_data']], axis=0).float() 120 | batch['T_data'] = torch.stack([torch.from_numpy(T) for T in batch['T_data']], axis=0).float() 121 | if 'K_data' in batch: 122 | batch['K_data'] = torch.stack([torch.from_numpy(K) for K in batch['K_data']], axis=0).float() 123 | 124 | return batch 125 | 126 | 127 | class CollateSpconv: 128 | def __init__(self, 
config) -> None: 129 | 130 | self._voxel_generator = VoxelGenerator( 131 | vsize_xyz=config.DATASET.VOXEL_SIZE, 132 | coors_range_xyz=config.DATASET.POINT_CLOUD_RANGE, 133 | num_point_features=5, # ad hoc, one dim for indices of group_pc 134 | max_num_points_per_voxel=10, 135 | max_num_voxels=60000 136 | ) 137 | self.coors_range = config.DATASET.POINT_CLOUD_RANGE 138 | 139 | def generate(self, points): 140 | voxel_output = self._voxel_generator.point_to_voxel(tv.from_numpy(points)) 141 | tv_voxels, tv_coordinates, tv_num_points = voxel_output 142 | # make copy with numpy(), since numpy_view() will disappear as soon as the generator is deleted 143 | voxels = tv_voxels.numpy() 144 | coordinates = tv_coordinates.numpy() 145 | num_points = tv_num_points.numpy() 146 | return voxels, coordinates, num_points 147 | 148 | @staticmethod 149 | def mask_points_by_range(points, limit_range): 150 | mask = (points[:, 0] >= limit_range[0]) & (points[:, 0] <= limit_range[3]) \ 151 | & (points[:, 1] >= limit_range[1]) & (points[:, 1] <= limit_range[4]) 152 | return points[mask], mask 153 | 154 | def collate_spconv(self, voxel_size, use_coords, list_data): 155 | batch = {} 156 | for key in list_data[0]: 157 | batch[key] = [l[key] for l in list_data] 158 | 159 | batch["voxels"] = [] 160 | batch['indexes'] = [] 161 | batch['inv_indexes'] = [] 162 | batch['batch_n_points'] = [] 163 | batch['pc_min'] = [] 164 | coords = [] 165 | batch_id = 0 166 | offset = 0 167 | for group_pc in batch["points"]: 168 | group_pc_ori_len = len(group_pc) 169 | pairing_points = batch["pairing_points"][batch_id] 170 | pairing_images = batch['pairing_images'][batch_id] 171 | # raise ImportError("Please checkout the branch for spconv.") 172 | group_pc, mask = self.mask_points_by_range(group_pc, self.coors_range) 173 | # original_to_new_index = np.cumsum(mask) - 1 174 | # new_indexes = original_to_new_index[indexes] 175 | original_to_new_index = -np.ones(group_pc_ori_len, dtype=int) 176 | original_to_new_index[mask] = np.arange(len(group_pc)) 177 | 178 | pairing_points = original_to_new_index[pairing_points] 179 | 180 | pairing_images = pairing_images[pairing_points >= 0] 181 | pairing_points = pairing_points[pairing_points >= 0] 182 | 183 | # group_pc_masked_len = len(group_pc) 184 | group_pc=np.concatenate([group_pc, np.arange(len(group_pc)).reshape(-1, 1)], 1) 185 | voxels, coordinates, num_points = self.generate(group_pc) 186 | 187 | selection_indices = voxels[:, 0, 4].astype(np.int) 188 | voxels = voxels[:, :, :4] 189 | original_to_new_index = -np.ones(group_pc.shape[0], dtype=int) 190 | original_to_new_index[selection_indices] = np.arange(len(voxels)) 191 | mapped_indexes = original_to_new_index[pairing_points] 192 | valid_mask = mapped_indexes >= 0 193 | pairing_points = mapped_indexes[valid_mask] 194 | pairing_images = pairing_images[valid_mask] 195 | 196 | coordinates = torch.from_numpy(coordinates) 197 | points_mean = torch.from_numpy(voxels).sum(dim=1, keepdim=False) 198 | normalizer = torch.clamp_min(torch.from_numpy(num_points).view(-1, 1), min=1.0).type_as(points_mean) 199 | points_mean = points_mean / normalizer 200 | voxels = points_mean.contiguous() 201 | 202 | coords.append(F.pad(coordinates, (1, 0, 0, 0), value=batch_id)) 203 | 204 | # todo 205 | # pairing_points = inv_indexes[pairing_points] 206 | pairing_points += offset 207 | batch['pairing_points'][batch_id] = pairing_points 208 | batch['pairing_images'][batch_id] = pairing_images 209 | batch['pairing_images'][batch_id][:, 0] += batch_id * 
batch['images'][0].shape[0] 210 | 211 | pc_min = np.array([self.coors_range[0:3]]).repeat(coordinates.shape[0], 0) 212 | batch['pc_min'].append(torch.from_numpy(pc_min)) 213 | # batch["indexes"].append(indexes) 214 | # batch['inv_indexes'].append(inv_indexes) 215 | if use_coords: 216 | batch["voxels"].append(voxels) 217 | else: 218 | batch["voxels"].append(voxels[:, 3:]) 219 | batch_id += 1 220 | offset += coordinates.shape[0] 221 | batch['coordinates'] = torch.cat(coords, 0).int() 222 | # batch['pc_min'] = torch.cat(batch['pc_min']) 223 | batch['voxels'] = torch.cat(batch['voxels']) 224 | 225 | batch['images'] = torch.cat([torch.from_numpy(img) for img in batch['images']], 0).float() 226 | batch['img_overlap_masks'] = torch.cat([torch.from_numpy(img) for img in batch['img_overlap_masks']], 0).float() 227 | batch['pairing_points'] = torch.tensor(np.concatenate(batch['pairing_points'])) 228 | batch['pairing_images'] = torch.tensor(np.concatenate(batch['pairing_images'])) 229 | 230 | # batch['R_data'] = torch.stack([torch.from_numpy(np.stack(R)) for R in batch['R_data']], axis=0) 231 | # batch['T_data'] = torch.stack([torch.from_numpy(np.stack(T)) for T in batch['T_data']], axis=0) 232 | batch['R_data'] = torch.stack([torch.from_numpy(R) for R in batch['R_data']], axis=0).float() 233 | batch['T_data'] = torch.stack([torch.from_numpy(T) for T in batch['T_data']], axis=0).float() 234 | if 'K_data' in batch: 235 | batch['K_data'] = torch.stack([torch.from_numpy(K) for K in batch['K_data']], axis=0).float() 236 | 237 | if 'flow' in batch: 238 | batch['flow'] = torch.from_numpy(np.stack(batch['flow'])) 239 | 240 | return batch 241 | 242 | 243 | collate_fns = {"collate_torchsparse": collate_torchsparse, 244 | "collate_minkowski": collate_minkowski} 245 | 246 | 247 | def make_dataloader(config, phase, world_size=1, rank=0): 248 | dataset_class = getattr(datasets, config.DATASET.TRAIN) 249 | dataset = dataset_class(phase, config) 250 | try: 251 | collate_fn = collate_fns[config.ENCODER.COLLATE] 252 | except KeyError: 253 | collate_fn = CollateSpconv(config).collate_spconv 254 | 255 | use_coords = config.ENCODER.IN_CHANNELS == 4 256 | collate_fn = partial(collate_fn, config.DATASET.VOXEL_SIZE, use_coords) 257 | if world_size > 1: 258 | sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=True) 259 | return DataLoader( 260 | dataset, 261 | batch_size=config.OPTIMIZATION.BATCH_SIZE_PER_GPU, 262 | shuffle=False, 263 | num_workers=config.OPTIMIZATION.NUM_WORKERS_PER_GPU, 264 | collate_fn=collate_fn, 265 | pin_memory=True, 266 | sampler=sampler, 267 | persistent_workers=True 268 | ) 269 | return DataLoader( 270 | dataset, 271 | batch_size=config.OPTIMIZATION.BATCH_SIZE_PER_GPU, 272 | shuffle=True, 273 | num_workers=config.OPTIMIZATION.NUM_WORKERS_PER_GPU, 274 | collate_fn=collate_fn, 275 | pin_memory=True, 276 | drop_last=True, 277 | worker_init_fn=lambda id: np.random.seed( 278 | torch.initial_seed() // 2 ** 32 + id 279 | ), 280 | persistent_workers=True 281 | ) 282 | -------------------------------------------------------------------------------- /bevlab/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import requests 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.transforms as T 7 | import torch.utils.model_zoo as model_zoo 8 | from torchvision.models.resnet import model_urls 9 | from bevlab.resnet_encoder import resnet_encoders 10 | import 
bevlab.vision_transformer as dino_vit 11 | # import vision_transformer as dino_vit 12 | 13 | _MEAN_PIXEL_IMAGENET = [0.485, 0.456, 0.406] 14 | _STD_PIXEL_IMAGENET = [0.229, 0.224, 0.225] 15 | 16 | 17 | def adapt_weights(architecture): 18 | if architecture == "imagenet" or architecture is None: 19 | return 20 | 21 | weights_url = { 22 | "moco_v3": "https://dl.fbaipublicfiles.com/moco-v3/r-50-1000ep/r-50-1000ep.pth.tar", 23 | "moco_v2": "https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar", 24 | "moco_v1": "https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v1_200ep/moco_v1_200ep_pretrain.pth.tar", 25 | "swav": "https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar", 26 | "deepcluster_v2": "https://dl.fbaipublicfiles.com/deepcluster/deepclusterv2_800ep_pretrain.pth.tar", 27 | "dino": "https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth" 28 | } 29 | 30 | if not os.path.exists(f"weights/{architecture}.pt"): 31 | r = requests.get(weights_url[architecture], allow_redirects=True) 32 | os.makedirs("weights", exist_ok=True) 33 | with open(f"weights/{architecture}.pt", 'wb') as f: 34 | f.write(r.content) 35 | 36 | weights = torch.load(f"weights/{architecture}.pt") 37 | 38 | if architecture == "obow": 39 | return weights["network"] 40 | 41 | if architecture == "pixpro": 42 | weights = { 43 | k.replace("module.encoder.", ""): v 44 | for k, v in weights["model"].items() 45 | if k.startswith("module.encoder.") 46 | } 47 | return weights 48 | 49 | if architecture in ("moco_v1", "moco_v2", "moco_coco"): 50 | weights = { 51 | k.replace("module.encoder_q.", ""): v 52 | for k, v in weights["state_dict"].items() 53 | if k.startswith("module.encoder_q.") and not k.startswith("module.encoder_q.fc") 54 | } 55 | return weights 56 | 57 | if architecture == "moco_v3": 58 | weights = { 59 | k.replace("module.base_encoder.", ""): v 60 | for k, v in weights["state_dict"].items() 61 | if k.startswith("module.base_encoder.") and not k.startswith("module.base_encoder.fc") 62 | } 63 | return weights 64 | 65 | 66 | if architecture in ("swav", "deepcluster_v2"): 67 | weights = { 68 | k.replace("module.", ""): v 69 | for k, v in weights.items() 70 | if k.startswith("module.") and not k.startswith("module.pro") 71 | } 72 | return weights 73 | 74 | if architecture == "dino": 75 | return weights 76 | 77 | 78 | class Preprocessing: 79 | """ 80 | Use the ImageNet preprocessing. 
81 | """ 82 | 83 | def __init__(self): 84 | normalize = T.Normalize(mean=_MEAN_PIXEL_IMAGENET, std=_STD_PIXEL_IMAGENET) 85 | self.preprocessing_img = normalize 86 | 87 | def __call__(self, image): 88 | return self.preprocessing_img(image) 89 | 90 | 91 | class ImageEncoder(nn.Module): 92 | def __init__(self, image_weights="dino", preprocessing='default'): 93 | super(ImageEncoder, self).__init__() 94 | Encoder = resnet_encoders["resnet50"]["encoder"] 95 | params = resnet_encoders["resnet50"]["params"] 96 | # params.update(replace_stride_with_dilation=[True, True, True]) 97 | self.encoder = Encoder(**params) 98 | 99 | if image_weights == "imagenet": 100 | self.encoder.load_state_dict(model_zoo.load_url(model_urls["resnet50"])) 101 | 102 | weights = adapt_weights(architecture=image_weights) 103 | if weights is not None: 104 | self.encoder.load_state_dict(weights) 105 | 106 | for param in self.encoder.parameters(): 107 | param.requires_grad = False 108 | 109 | in1 = 2048 110 | 111 | self.decoder = nn.Sequential( 112 | nn.Conv2d(in1, 64, 1), # ad hoc 113 | nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True), 114 | ) 115 | if preprocessing == 'default': 116 | self.preprocessing = Preprocessing() 117 | else: 118 | raise ValueError 119 | # self.normalize_feature = config["normalize_features"] 120 | 121 | def forward(self, x): 122 | if self.preprocessing: 123 | x = self.preprocessing(x) 124 | x = self.decoder(self.encoder(x)) 125 | # if self.normalize_feature: 126 | # x = F.normalize(x, p=2, dim=1) 127 | return x 128 | 129 | 130 | # class ImageEncoder(nn.Module): 131 | # """ 132 | # DINO Vision Transformer Feature Extractor. 133 | # """ 134 | # def __init__(self, model_type='vit_small_p8', preprocessing='default'): 135 | # super(ImageEncoder, self).__init__() 136 | # dino_models = { 137 | # "vit_small_p16": ("vit_small", 16, 384), 138 | # "vit_small_p8": ("vit_small", 8, 384), 139 | # "vit_base_p16": ("vit_base", 16, 768), 140 | # "vit_base_p8": ("vit_base", 8, 768), 141 | # } 142 | 143 | # model_name, patch_size, embed_dim = dino_models[model_type] 144 | 145 | # print("Use Vision Transformer pretrained with DINO as the image encoder") 146 | # print(f"==> model_name: {model_name}") 147 | # print(f"==> patch_size: {patch_size}") 148 | # print(f"==> embed_dim: {embed_dim}") 149 | 150 | # self.patch_size = patch_size 151 | # self.embed_dim = embed_dim 152 | 153 | # self.encoder = dino_vit.__dict__[model_name](patch_size=patch_size, num_classes=0) 154 | # dino_vit.load_pretrained_weights(self.encoder, "", None, model_name, patch_size) 155 | 156 | # for param in self.encoder.parameters(): 157 | # param.requires_grad = False 158 | 159 | # self.decoder = nn.Sequential( 160 | # nn.Conv2d(embed_dim, 64, 1), # adhoc 161 | # nn.Upsample(scale_factor=patch_size, mode="bilinear", align_corners=True), 162 | # ) 163 | # if preprocessing == 'default': 164 | # self.preprocessing = Preprocessing() 165 | # else: 166 | # raise ValueError 167 | # self.normalize_feature = False 168 | 169 | # def forward(self, x): 170 | # if self.preprocessing: 171 | # x = self.preprocessing(x) 172 | # batch_size, _, height, width = x.size() 173 | # assert (height % self.patch_size) == 0 174 | # assert (width % self.patch_size) == 0 175 | # f_height = height // self.patch_size 176 | # f_width = width // self.patch_size 177 | 178 | # x = self.encoder(x, all=True) 179 | # # the output of x should be [batch_size x (1 + f_height * f_width) x self.embed_dim] 180 | # assert x.size(1) == (1 + f_height * f_width) 181 | # # Remove the 
CLS token and reshape the the patch token features. 182 | # x = x[:, 1:, :].contiguous().transpose(1, 2).contiguous().view(batch_size, self.embed_dim, f_height, f_width) 183 | 184 | # x = self.decoder(x) 185 | # if self.normalize_feature: 186 | # x = F.normalize(x, p=2, dim=1) 187 | # return x 188 | 189 | 190 | if __name__=='__main__': 191 | a=torch.rand(10,3,160,512).cuda() 192 | model=ImageEncoder() 193 | model=model.cuda() 194 | b=model(a) 195 | print(b.size()) 196 | # for i in b: 197 | # print(i.size()) 198 | '''print(b[0].size()) 199 | print(b[1].size()) 200 | print(b[2].size())''' 201 | -------------------------------------------------------------------------------- /bevlab/resnet_encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from torchvision.models.resnet import ResNet 4 | from torchvision.models.resnet import BasicBlock 5 | from torchvision.models.resnet import Bottleneck 6 | 7 | 8 | class ResNetEncoder(ResNet): 9 | def __init__(self, **kwargs): 10 | super().__init__(**kwargs) 11 | 12 | del self.fc 13 | del self.avgpool 14 | 15 | def get_stages(self): 16 | return [ 17 | nn.Identity(), 18 | nn.Sequential(self.conv1, self.bn1, self.relu), 19 | nn.Sequential(self.maxpool, self.layer1), 20 | self.layer2, 21 | self.layer3, 22 | self.layer4, 23 | ] 24 | 25 | def forward(self, x): 26 | stages = self.get_stages() 27 | 28 | features = [] 29 | for i in range(6): 30 | x = stages[i](x) 31 | features.append(x) 32 | 33 | return features[5] 34 | 35 | def load_state_dict(self, state_dict, **kwargs): 36 | state_dict.pop("fc.bias", None) 37 | state_dict.pop("fc.weight", None) 38 | super().load_state_dict(state_dict, **kwargs) 39 | 40 | 41 | resnet_encoders = { 42 | "resnet18": { 43 | "encoder": ResNetEncoder, 44 | "params": { 45 | "block": BasicBlock, 46 | "layers": [2, 2, 2, 2], 47 | }, 48 | }, 49 | "resnet50": { 50 | "encoder": ResNetEncoder, 51 | "params": { 52 | "block": Bottleneck, 53 | "layers": [3, 4, 6, 3], 54 | }, 55 | }, 56 | } 57 | -------------------------------------------------------------------------------- /bevlab/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def revtrans_rotation(pc, trans_dict): 5 | angle = np.random.random() * 2 * np.pi 6 | c = np.cos(angle) 7 | s = np.sin(angle) 8 | rotation = np.array( 9 | [[c, -s], [s, c]], dtype=np.float32 10 | ) 11 | pc[:, :2] = pc[:, :2] @ rotation 12 | rotation = np.pad(rotation, (0, 1)) 13 | rotation[2, 2] = 1. 
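    # The points above were rotated as row vectors (pc[:, :2] @ rotation), which is
    # equivalent to applying rotation.T to column vectors; the transpose is stored so
    # that trans_dict['R'] @ p maps an original point p onto its augmented position.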
14 | trans_dict['R'] = rotation.T 15 | return pc, trans_dict 16 | 17 | 18 | def revtrans_translation(pc, trans_dict): 19 | translation = np.clip(np.random.normal(size=2, scale=4.).astype(np.float32), -15, 15) # no trans along z 20 | pc[:, :2] += translation 21 | trans_dict['T'] = np.pad(translation, (0, 1)) 22 | return pc, trans_dict 23 | 24 | 25 | def revtrans_jittering(pc, trans_dict): 26 | pc[:, 3] = np.random.normal(pc[:, 3], 0.01) 27 | return pc, trans_dict 28 | 29 | 30 | def revtrans_scaling(pc, trans_dict): 31 | scale = np.random.uniform(0.95, 1.05) 32 | pc[:, :3] = pc[:, :3] * scale 33 | trans_dict['S'] = scale 34 | return pc, trans_dict 35 | -------------------------------------------------------------------------------- /cfgs/pretrain_ns_minkunet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | TRAIN: 'NuscenesDataset' 3 | VAL: null 4 | DATASET_ROOT: 'data/nuscenes' 5 | 6 | IMG_H: 224 7 | IMG_W: 416 8 | 9 | POINT_CLOUD_RANGE: [-51.2, -51.2, -3.0, 51.2, 51.2, 1.0] 10 | POINT_CLOUD_RANGE_VAL: [0, -40, -3, 70.4, 40, 1] 11 | DATA_SPLIT: { 12 | 'train': parametrizing, 13 | 'val': train, 14 | 'test': val 15 | } 16 | VOXEL_SIZE: 0.1 17 | APPLY_SCALING: False 18 | # INPUT_FRAMES: 1 19 | # OUTPUT_FRAMES: 1 20 | # SKIP_FRAMES: 1 21 | 22 | ENCODER: 23 | NAME: MinkUNet34 24 | COLLATE: collate_torchsparse # collate_torchsparse collate_minkowski 25 | IN_CHANNELS: 1 26 | OUT_CHANNELS: 64 # DO NOT CHANGE 27 | FEATURE_DIMENSION: 128 28 | 29 | OPTIMIZATION: 30 | BATCH_SIZE_PER_GPU: 10 31 | NUM_EPOCHS: 50 32 | NUM_WORKERS_PER_GPU: 1 33 | 34 | OPTIMIZER: AdamW 35 | LR: 0.001 36 | WEIGHT_DECAY: 0.001 37 | 38 | # LOSS: "contrast" 39 | # BEV_STRIDE: 6 40 | 41 | DEBUG: False 42 | -------------------------------------------------------------------------------- /cfgs/pretrain_ns_spconv.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | TRAIN: 'NuscenesDataset' 3 | VAL: null 4 | DATASET_ROOT: 'data/nuscenes' 5 | IMG_H: 224 6 | IMG_W: 416 7 | POINT_CLOUD_RANGE: [-51.2, -51.2, -3.0, 51.2, 51.2, 1.0] 8 | POINT_CLOUD_RANGE_VAL: null 9 | DATA_SPLIT: { 10 | 'train': parametrizing, 11 | 'val': train, 12 | 'test': val 13 | } 14 | VOXEL_SIZE: [0.05, 0.05, 0.1] 15 | APPLY_SCALING: False 16 | INPUT_FRAMES: 1 17 | OUTPUT_FRAMES: 1 18 | SKIP_FRAMES: 1 19 | 20 | ENCODER: 21 | NAME: BEVNet 22 | COLLATE: collate_spconv 23 | IN_CHANNELS: 4 24 | OUT_CHANNELS: 256 # DO NOT CHANGE 25 | FEATURE_DIMENSION: 64 26 | 27 | OPTIMIZATION: 28 | BATCH_SIZE_PER_GPU: 10 29 | NUM_EPOCHS: 50 30 | NUM_WORKERS_PER_GPU: 2 31 | 32 | OPTIMIZER: AdamW 33 | LR: 0.001 34 | WEIGHT_DECAY: 0.001 35 | 36 | LOSS: "contrast" 37 | BEV_STRIDE: null 38 | -------------------------------------------------------------------------------- /cfgs/pretrain_ns_spvcnn.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | TRAIN: 'NuscenesDataset' 3 | VAL: null 4 | DATASET_ROOT: 'data/nuscenes' 5 | 6 | IMG_H: 224 7 | IMG_W: 416 8 | 9 | POINT_CLOUD_RANGE: [-51.2, -51.2, -3.0, 51.2, 51.2, 1.0] 10 | POINT_CLOUD_RANGE_VAL: [0, -40, -3, 70.4, 40, 1] 11 | DATA_SPLIT: { 12 | 'train': parametrizing, 13 | 'val': train, 14 | 'test': val 15 | } 16 | VOXEL_SIZE: 0.1 17 | APPLY_SCALING: False 18 | # INPUT_FRAMES: 1 19 | # OUTPUT_FRAMES: 1 20 | # SKIP_FRAMES: 1 21 | 22 | ENCODER: 23 | NAME: SPVCNN 24 | COLLATE: collate_torchsparse 25 | IN_CHANNELS: 1 26 | OUT_CHANNELS: 64 # DO NOT CHANGE 27 | FEATURE_DIMENSION: 128 28 | 29 | 
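# Note: IN_CHANNELS: 1 feeds only the per-point intensity to the encoder; 4 would additionally
# feed the xyz coordinates (see the use_coords switch in bevlab/dataloader.py).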
OPTIMIZATION: 30 | BATCH_SIZE_PER_GPU: 10 31 | NUM_EPOCHS: 50 32 | NUM_WORKERS_PER_GPU: 1 33 | 34 | OPTIMIZER: AdamW 35 | LR: 0.001 36 | WEIGHT_DECAY: 0.001 37 | 38 | # LOSS: "contrast" 39 | # BEV_STRIDE: 6 40 | 41 | DEBUG: False 42 | -------------------------------------------------------------------------------- /cfgs/pretrain_sk_minkunet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | TRAIN: 'SemanticKITTIDataset' 3 | VAL: null 4 | DATASET_ROOT: 'data/semantic_kitti' 5 | 6 | IMG_H: 160 7 | IMG_W: 512 8 | 9 | # POINT_CLOUD_RANGE: [-70, -70, -3.0, 70, 70, 1.0] 10 | POINT_CLOUD_RANGE: [-51.2, -51.2, -3.0, 51.2, 51.2, 1.0] 11 | DATA_SPLIT: { 12 | 'train': train, 13 | 'val': val, 14 | 'test': val 15 | } 16 | VOXEL_SIZE: 0.05 17 | APPLY_SCALING: False 18 | # INPUT_FRAMES: 1 19 | # OUTPUT_FRAMES: 1 20 | # SKIP_FRAMES: 6 21 | 22 | ENCODER: 23 | NAME: SegContrastMinkUNet18 24 | COLLATE: collate_minkowski 25 | IN_CHANNELS: 1 26 | OUT_CHANNELS: 64 # DO NOT CHANGE 27 | FEATURE_DIMENSION: 128 28 | 29 | OPTIMIZATION: 30 | BATCH_SIZE_PER_GPU: 4 31 | NUM_EPOCHS: 50 32 | NUM_WORKERS_PER_GPU: 1 33 | 34 | OPTIMIZER: AdamW 35 | LR: 0.001 36 | WEIGHT_DECAY: 0.001 37 | 38 | # LOSS: "contrast" 39 | # BEV_STRIDE: 4 40 | -------------------------------------------------------------------------------- /cfgs/pretrain_sk_spvcnn.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | TRAIN: 'SemanticKITTIDataset' 3 | VAL: null 4 | DATASET_ROOT: 'data/semantic_kitti' 5 | 6 | IMG_H: 160 7 | IMG_W: 512 8 | 9 | # POINT_CLOUD_RANGE: [-70, -70, -3.0, 70, 70, 1.0] 10 | POINT_CLOUD_RANGE: [-51.2, -51.2, -3.0, 51.2, 51.2, 1.0] 11 | DATA_SPLIT: { 12 | 'train': train, 13 | 'val': val, 14 | 'test': val 15 | } 16 | VOXEL_SIZE: 0.05 17 | APPLY_SCALING: False 18 | # INPUT_FRAMES: 1 19 | # OUTPUT_FRAMES: 1 20 | # SKIP_FRAMES: 6 21 | 22 | ENCODER: 23 | NAME: SPVCNN 24 | COLLATE: collate_torchsparse 25 | IN_CHANNELS: 1 26 | OUT_CHANNELS: 64 # DO NOT CHANGE 27 | FEATURE_DIMENSION: 128 28 | 29 | OPTIMIZATION: 30 | BATCH_SIZE_PER_GPU: 4 31 | NUM_EPOCHS: 50 32 | NUM_WORKERS_PER_GPU: 1 33 | 34 | OPTIMIZER: AdamW 35 | LR: 0.001 36 | WEIGHT_DECAY: 0.001 37 | 38 | # LOSS: "contrast" 39 | # BEV_STRIDE: 4 -------------------------------------------------------------------------------- /downstream/configs/cfg/nuscenes_minkowski.yaml: -------------------------------------------------------------------------------- 1 | 2 | name: NuScenes 3 | dataset_name: NuScenes 4 | dataset_root: ../data/nuscenes 5 | desc: inI_predI 6 | save_dir: results/nuScenes/ 7 | 8 | # splits 9 | train_split: parametrizing 10 | val_split: verifying 11 | test_split: val 12 | 13 | # inputs 14 | inputs: ["intensities"] # ["x"] would be only ones 15 | 16 | # optimization 17 | training: 18 | max_epochs: 200 19 | batch_size: 16 20 | val_interval: 5 21 | 22 | optimizer: torch.optim.AdamW 23 | optimizer_params: 24 | lr: 0.001 25 | weight_decay: 0.0001 26 | 27 | # network 28 | network: 29 | framework: torchsparse 30 | backbone: MinkUNet34 31 | backbone_params: 32 | quantization_params: 33 | voxel_size: 0.1 34 | decoder: InterpNet 35 | decoder_params: 36 | radius: 1.0 37 | out_channels: 2 # 1 for reconstruction, 1 for intensity 38 | intensity_loss: true 39 | radius_search: true 40 | latent_size: 128 41 | 42 | # losses 43 | loss: 44 | recons_loss_lambda: 1 45 | intensity_loss_lambda: 1 46 | 47 | # misc 48 | device: cuda 49 | num_device: 1 50 | threads: 1 51 | interactive_log: 
false 52 | logging: INFO 53 | resume: null 54 | 55 | # sampling 56 | manifold_points: 16384 57 | non_manifold_points: 2048 58 | 59 | # data augmentation 60 | transforms: 61 | voxel_decimation: 0.1 62 | scaling_intensities: false 63 | random_rotation_z: true 64 | random_flip: true 65 | 66 | downstream: 67 | checkpoint_dir: null 68 | checkpoint_name: null 69 | batch_size: 8 70 | num_classes: 17 71 | max_epochs: 30 72 | val_interval: 5 73 | skip_ratio: 1 74 | seed_offset: 0 75 | ignore_index: 0 76 | -------------------------------------------------------------------------------- /downstream/configs/cfg/nuscenes_torchsparse.yaml: -------------------------------------------------------------------------------- 1 | 2 | name: NuScenes 3 | dataset_name: NuScenes 4 | dataset_root: ../data/nuscenes 5 | desc: inI_predI 6 | save_dir: results/nuScenes/ 7 | 8 | # splits 9 | train_split: parametrizing 10 | val_split: verifying 11 | test_split: val 12 | 13 | # inputs 14 | inputs: ["intensities"] # ["x"] would be only ones 15 | 16 | # optimization 17 | training: 18 | max_epochs: 200 19 | batch_size: 16 20 | val_interval: 5 21 | 22 | optimizer: torch.optim.AdamW 23 | optimizer_params: 24 | lr: 0.001 25 | weight_decay: 0.0001 26 | 27 | # network 28 | network: 29 | framework: torchsparse 30 | backbone: SPVCNN 31 | backbone_params: 32 | quantization_params: 33 | voxel_size: 0.1 34 | decoder: InterpNet 35 | decoder_params: 36 | radius: 1.0 37 | out_channels: 2 # 1 for reconstruction, 1 for intensity 38 | intensity_loss: true 39 | radius_search: true 40 | latent_size: 128 41 | 42 | # losses 43 | loss: 44 | recons_loss_lambda: 1 45 | intensity_loss_lambda: 1 46 | 47 | # misc 48 | device: cuda 49 | num_device: 1 50 | threads: 1 51 | interactive_log: false 52 | logging: INFO 53 | resume: null 54 | 55 | # sampling 56 | manifold_points: 16384 57 | non_manifold_points: 2048 58 | 59 | # data augmentation 60 | transforms: 61 | voxel_decimation: 0.1 62 | scaling_intensities: false 63 | random_rotation_z: true 64 | random_flip: true 65 | 66 | downstream: 67 | checkpoint_dir: null 68 | checkpoint_name: null 69 | batch_size: 8 70 | num_classes: 17 71 | max_epochs: 30 72 | val_interval: 5 73 | skip_ratio: 1 74 | seed_offset: 0 75 | ignore_index: 0 76 | -------------------------------------------------------------------------------- /downstream/configs/cfg/semantickitti_minkowski.yaml: -------------------------------------------------------------------------------- 1 | 2 | name: SemanticKITTI 3 | dataset_name: SemanticKITTI 4 | dataset_root: ../data/semantic_kitti/ 5 | desc: inI_predI 6 | save_dir: results/SemanticKITTI/ 7 | 8 | # splits 9 | train_split: train 10 | val_split: val 11 | test_split: val 12 | 13 | # inputs 14 | inputs: ["intensities"] #["x"] # x is ones, intensities, dirs, normals, pos, rgb in the desired order 15 | 16 | # optimization 17 | training: 18 | max_epochs: 50 19 | batch_size: 4 20 | val_interval: 5 21 | 22 | optimizer: torch.optim.AdamW 23 | optimizer_params: 24 | lr: 0.001 25 | 26 | # network 27 | network: 28 | framework: minkowski_engine 29 | backbone: SegContrastMinkUNet18 30 | backbone_params: 31 | quantization_params: 32 | voxel_size: 0.05 33 | latent_size: 128 34 | 35 | # misc 36 | device: cuda 37 | num_device: 1 38 | threads: 1 39 | interactive_log: false 40 | logging: INFO 41 | resume: null 42 | 43 | 44 | # sampling 45 | manifold_points: 80000 46 | non_manifold_points: 4096 47 | 48 | # data augmentation 49 | transforms: 50 | voxel_decimation: 0.05 51 | scaling_intensities: false 52 | 
random_rotation_z: true 53 | random_flip: true 54 | 55 | downstream: 56 | checkpoint_dir: null 57 | checkpoint_name: null 58 | batch_size: 8 59 | num_classes: 20 60 | max_epochs: 30 61 | val_interval: 5 62 | skip_ratio: 1 63 | seed_offset: 0 64 | ignore_index: 0 65 | -------------------------------------------------------------------------------- /downstream/configs/cfg/semantickitti_torchsparse.yaml: -------------------------------------------------------------------------------- 1 | 2 | name: SemanticKITTI 3 | dataset_name: SemanticKITTI 4 | dataset_root: ../data/semantic_kitti/ 5 | desc: inI_predI 6 | save_dir: results/SemanticKITTI/ 7 | 8 | # splits 9 | train_split: train 10 | val_split: val 11 | test_split: val 12 | 13 | # inputs 14 | inputs: ["intensities"] # ["x"] would be only ones 15 | 16 | # optimization 17 | training: 18 | max_epochs: 50 19 | batch_size: 4 20 | val_interval: 5 21 | 22 | optimizer: torch.optim.AdamW 23 | optimizer_params: 24 | lr: 0.001 25 | 26 | # network 27 | network: 28 | framework: torchsparse 29 | backbone: MinkUNet18SC 30 | backbone_params: 31 | quantization_params: 32 | voxel_size: 0.05 33 | latent_size: 128 34 | 35 | # misc 36 | device: cuda 37 | num_device: 1 38 | threads: 4 39 | interactive_log: false 40 | logging: INFO 41 | resume: null 42 | 43 | 44 | # sampling 45 | manifold_points: 80000 46 | non_manifold_points: 4096 47 | 48 | # data augmentation 49 | transforms: 50 | voxel_decimation: 0.05 51 | scaling_intensities: false 52 | random_rotation_z: true 53 | random_flip: true 54 | 55 | downstream: 56 | checkpoint_dir: null 57 | checkpoint_name: null 58 | batch_size: 8 59 | num_classes: 20 60 | max_epochs: 30 61 | val_interval: 5 62 | skip_ratio: 1 63 | seed_offset: 0 64 | ignore_index: 0 65 | -------------------------------------------------------------------------------- /downstream/configs/cfg/semanticposs_minkowski.yaml: -------------------------------------------------------------------------------- 1 | 2 | name: SemanticPOSS 3 | dataset_name: SemanticPOSS 4 | dataset_root: ../data/semantic_poss 5 | desc: inI_predI 6 | save_dir: results/SemanticPOSS/ 7 | 8 | # splits 9 | train_split: train 10 | val_split: val 11 | test_split: val 12 | 13 | # inputs 14 | inputs: ["intensities"] #["x"] # x is ones, intensities, dirs, normals, pos, rgb in the desired order 15 | 16 | # optimization 17 | training: 18 | max_epochs: 50 19 | batch_size: 2 20 | val_interval: 5 21 | 22 | optimizer: torch.optim.AdamW 23 | optimizer_params: 24 | lr: 0.001 25 | 26 | # network 27 | network: 28 | framework: minkowski_engine 29 | backbone: SegContrastMinkUNet18 30 | backbone_params: 31 | quantization_params: 32 | voxel_size: 0.05 33 | latent_size: 128 34 | 35 | # misc 36 | device: cuda 37 | num_device: 1 38 | threads: 4 39 | interactive_log: false 40 | logging: INFO 41 | resume: null 42 | 43 | 44 | # sampling 45 | manifold_points: 80000 46 | non_manifold_points: 2048 47 | 48 | # data augmentation 49 | transforms: 50 | voxel_decimation: 0.05 51 | scaling_intensities: false 52 | random_rotation_z: true 53 | random_flip: true 54 | 55 | downstream: 56 | checkpoint_dir: null 57 | checkpoint_name: null 58 | batch_size: 2 59 | num_classes: 20 60 | max_epochs: 30 61 | val_interval: 5 62 | skip_ratio: 1 63 | seed_offset: 0 64 | ignore_index: 0 65 | -------------------------------------------------------------------------------- /downstream/configs/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - cfg: nuscenes 
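Note on using these configs: the defaults list above composes exactly one file from the cfg/ group into the final configuration; with Hydra, a different group member can be selected on the command line of whichever entry script consumes this config (for example cfg=semantickitti_torchsparse). The downstream configs also describe the optimizer as a dotted import path (torch.optim.AdamW) plus an optimizer_params mapping. The following is a minimal, self-contained sketch of how such a pair can be resolved into an optimizer instance; the helper name build_optimizer and the use of importlib are illustrative assumptions, not necessarily the repository's actual mechanism.

import importlib

import torch


def build_optimizer(parameters, optimizer_path, optimizer_params):
    # split "torch.optim.AdamW" into module "torch.optim" and class name "AdamW"
    module_name, class_name = optimizer_path.rsplit(".", 1)
    optimizer_cls = getattr(importlib.import_module(module_name), class_name)
    # instantiate with the keyword arguments taken from the YAML config
    return optimizer_cls(parameters, **(optimizer_params or {}))


# usage with values matching the configs above (the Linear layer is only a placeholder)
net = torch.nn.Linear(8, 2)
optimizer = build_optimizer(net.parameters(), "torch.optim.AdamW", {"lr": 0.001})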
-------------------------------------------------------------------------------- /downstream/convert_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import argparse 4 | import warnings 5 | import logging 6 | 7 | 8 | if __name__ == "__main__": 9 | warnings.filterwarnings("ignore", category=UserWarning) 10 | 11 | parser = argparse.ArgumentParser(description='Self supervised.') 12 | parser.add_argument('--downstream', action="store_true") 13 | parser.add_argument('--ckpt', '-c', type=str, required=True) 14 | opts = parser.parse_args() 15 | 16 | logging.getLogger().setLevel("INFO") 17 | 18 | logging.info("Loading the checkpoint") 19 | state_dict = torch.load(opts.ckpt, map_location="cpu")["state_dict"] 20 | 21 | ckpt_dir = os.path.dirname(opts.ckpt) 22 | ckpt_name = os.path.basename(opts.ckpt) 23 | 24 | if opts.downstream: 25 | logging.info("Filtering the state dict") 26 | # filter the state dict 27 | trained_dict = {} 28 | for k, v in state_dict.items(): 29 | if k[:4] != "net.": # keep only the weights of the backbone 30 | continue 31 | trained_dict[k[4:]] = v 32 | 33 | logging.info("Saving the weights") 34 | torch.save(trained_dict, os.path.join(ckpt_dir, "trained_model_" + ckpt_name)) 35 | else: 36 | 37 | logging.info("Filtering the state dict") 38 | # filter the state dict 39 | pretrained_dict = {} 40 | classifier_dict = {} 41 | for k, v in state_dict.items(): 42 | if "backbone." not in k: # keep only the weights of the backbone 43 | continue 44 | if "classifier." in k: # do not keep the weights of the classifier 45 | classifier_dict[k.replace("backbone.", "")] = v 46 | continue 47 | # print("backbone", k) 48 | pretrained_dict[k.replace("backbone.", "")] = v 49 | 50 | logging.info("Saving the weights") 51 | torch.save(pretrained_dict, os.path.join(ckpt_dir, "pretrained_backbone_" + ckpt_name)) 52 | torch.save(classifier_dict, os.path.join(ckpt_dir, "pretrained_classifier_" + ckpt_name)) 53 | 54 | logging.info("Done") 55 | -------------------------------------------------------------------------------- /downstream/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .nuscenes_dataset import NuScenes 2 | from .semantickitti_dataset import SemanticKITTI 3 | from .kitti3d import KITTI3D 4 | from .kitti360 import KITTI360 5 | from .once import ONCE 6 | 7 | from .semantic_poss_segcontrast import SemanticPOSS_SegContrast as SemanticPOSS 8 | # from .livox import LivoxSimu 9 | -------------------------------------------------------------------------------- /downstream/datasets/kitti360.py: -------------------------------------------------------------------------------- 1 | from fileinput import filename 2 | import os 3 | import sys 4 | import numpy as np 5 | import torch 6 | from torch_geometric.data import Dataset 7 | from torch_geometric.data import Data 8 | import logging 9 | from tqdm import tqdm 10 | 11 | class KITTI360(Dataset): 12 | 13 | def __init__(self, 14 | root, 15 | split="training", 16 | transform=None, 17 | dataset_size=None, 18 | multiframe_range=None, 19 | skip_ratio=1, 20 | **kwargs): 21 | 22 | 23 | super().__init__(root, transform, None) 24 | 25 | self.split = split 26 | self.n_frames = 1 27 | self.multiframe_range = multiframe_range 28 | 29 | logging.info(f"KITTI360 - {split}") 30 | 31 | if split=="train": 32 | filenames_list = "kitti_360_train_velodynes.txt" 33 | elif split=="val": 34 | filenames_list = "kitti_360_val_velodynes.txt" 35 | 36 | 
with open(os.path.join("datasets", filenames_list), "r") as f: 37 | files = f.readlines() 38 | 39 | if split=="val": # for fast validation 40 | files = files[::20] 41 | 42 | self.files = [os.path.join(self.root, f.split("\n")[0]) for f in files] 43 | 44 | 45 | logging.info(f"KITTI360 - {len(self.files)} files") 46 | 47 | def _download(self): # override _download to remove makedirs 48 | pass 49 | 50 | def download(self): 51 | pass 52 | 53 | def process(self): 54 | pass 55 | 56 | def _process(self): 57 | pass 58 | 59 | def len(self): 60 | return len(self.files) 61 | 62 | def get(self, idx): 63 | 64 | fname_points = self.files[idx] 65 | frame_points = np.fromfile(fname_points, dtype=np.float32) 66 | 67 | pos = frame_points.reshape((-1, 4)) 68 | intensities = pos[:,3:] 69 | pos = pos[:,:3] 70 | 71 | pos = torch.tensor(pos, dtype=torch.float) 72 | intensities = torch.tensor(intensities, dtype=torch.float) 73 | x = torch.ones((pos.shape[0],1), dtype=torch.float) 74 | 75 | return Data(x=x, intensities=intensities, pos=pos, shape_id=idx, ) 76 | 77 | 78 | 79 | if __name__ == "__main__": 80 | 81 | print("Creating the list of training frames") 82 | 83 | with open("2013_05_28_drive_train.txt", "r") as f: 84 | lines = f.readlines() 85 | 86 | print("writing train files...") 87 | with open("kitti_360_train_velodynes.txt", "w") as f: 88 | 89 | for line in tqdm(lines, ncols=100): 90 | 91 | line = line.split("\n")[0] 92 | 93 | directory = line.split("/")[2] 94 | 95 | filename = os.path.basename(line) 96 | 97 | first_file = int(os.path.splitext(filename)[0].split("_")[0]) 98 | second_file = int(os.path.splitext(filename)[0].split("_")[1]) 99 | 100 | fnames = [] 101 | for i in range(first_file, second_file+1): 102 | fname = f"{i:010d}.bin" 103 | fname = os.path.join("data_3d_raw", directory, "velodyne_points/data", fname) 104 | fname = str(fname) 105 | fnames.append(fname) 106 | 107 | for item in fnames: 108 | # write each item on a new line 109 | f.write("%s\n" % item) 110 | print('Done') 111 | 112 | print("Creating the list of val frames") 113 | 114 | with open("2013_05_28_drive_val.txt", "r") as f: 115 | lines = f.readlines() 116 | 117 | print("writing val files...") 118 | with open("kitti_360_val_velodynes.txt", "w") as f: 119 | 120 | for line in tqdm(lines, ncols=100): 121 | 122 | line = line.split("\n")[0] 123 | 124 | directory = line.split("/")[2] 125 | 126 | filename = os.path.basename(line) 127 | 128 | first_file = int(os.path.splitext(filename)[0].split("_")[0]) 129 | second_file = int(os.path.splitext(filename)[0].split("_")[1]) 130 | 131 | fnames = [] 132 | for i in range(first_file, second_file+1): 133 | fname = f"{i:010d}.bin" 134 | fname = os.path.join("data_3d_raw", directory, "velodyne_points/data", fname) 135 | fname = str(fname) 136 | fnames.append(fname) 137 | 138 | for item in fnames: 139 | # write each item on a new line 140 | f.write("%s\n" % item) 141 | print('Done') -------------------------------------------------------------------------------- /downstream/datasets/kitti3d.py: -------------------------------------------------------------------------------- 1 | from fileinput import filename 2 | import os 3 | import sys 4 | import numpy as np 5 | import torch 6 | from torch_geometric.data import Dataset 7 | from torch_geometric.data import Data 8 | import logging 9 | from tqdm import tqdm 10 | 11 | class KITTI3D(Dataset): 12 | 13 | def __init__(self, 14 | root, 15 | split="training", 16 | transform=None, 17 | dataset_size=None, 18 | multiframe_range=None, 19 | skip_ratio=1, 20 | 
**kwargs): 21 | 22 | 23 | super().__init__(root, transform, None) 24 | 25 | self.split = split 26 | self.n_frames = 1 27 | self.multiframe_range = multiframe_range 28 | 29 | logging.info(f"KITTI3D - {split}") 30 | 31 | if split=="train": 32 | filenames_list = "kitti3d_train.txt" 33 | elif split=="val": 34 | filenames_list = "kitti3d_val.txt" 35 | 36 | with open(os.path.join("datasets", filenames_list), "r") as f: 37 | files = f.readlines() 38 | 39 | if split=="val": # for fast validation 40 | files = files[::20] 41 | 42 | 43 | self.files = [os.path.join(self.root, "training/velodyne", f.split("\n")[0]+".bin") for f in files] 44 | 45 | 46 | logging.info(f"KITTI3D - {len(self.files)} files") 47 | 48 | def _download(self): # override _download to remove makedirs 49 | pass 50 | 51 | def download(self): 52 | pass 53 | 54 | def process(self): 55 | pass 56 | 57 | def _process(self): 58 | pass 59 | 60 | def len(self): 61 | return len(self.files) 62 | 63 | def get(self, idx): 64 | 65 | fname_points = self.files[idx] 66 | frame_points = np.fromfile(fname_points, dtype=np.float32) 67 | 68 | pos = frame_points.reshape((-1, 4)) 69 | intensities = pos[:,3:] 70 | pos = pos[:,:3] 71 | 72 | pos = torch.tensor(pos, dtype=torch.float) 73 | intensities = torch.tensor(intensities, dtype=torch.float) 74 | x = torch.ones((pos.shape[0],1), dtype=torch.float) 75 | 76 | return Data(x=x, intensities=intensities, pos=pos, shape_id=idx, ) -------------------------------------------------------------------------------- /downstream/datasets/nuscenes_category.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "token": "aaddc3454ccbefbb2d8d8461f8f7f481", 4 | "name": "noise", 5 | "description": "Any lidar return that does not correspond to a physical object, such as dust, vapor, noise, fog, raindrops, smoke and reflections.", 6 | "index": 0 7 | }, 8 | { 9 | "token": "63a94dfa99bb47529567cd90d3b58384", 10 | "name": "animal", 11 | "description": "All animals, e.g. cats, rats, dogs, deer, birds.", 12 | "index": 1 13 | }, 14 | { 15 | "token": "1fa93b757fc74fb197cdd60001ad8abf", 16 | "name": "human.pedestrian.adult", 17 | "description": "Adult subcategory.", 18 | "index": 2 19 | }, 20 | { 21 | "token": "b1c6de4c57f14a5383d9f963fbdcb5cb", 22 | "name": "human.pedestrian.child", 23 | "description": "Child subcategory.", 24 | "index": 3 25 | }, 26 | { 27 | "token": "909f1237d34a49d6bdd27c2fe4581d79", 28 | "name": "human.pedestrian.construction_worker", 29 | "description": "Construction worker", 30 | "index": 4 31 | }, 32 | { 33 | "token": "403fede16c88426885dd73366f16c34a", 34 | "name": "human.pedestrian.personal_mobility", 35 | "description": "A small electric or self-propelled vehicle, e.g. skateboard, segway, or scooters, on which the person typically travels in a upright position. Driver and (if applicable) rider should be included in the bounding box along with the vehicle.", 36 | "index": 5 37 | }, 38 | { 39 | "token": "bb867e2064014279863c71a29b1eb381", 40 | "name": "human.pedestrian.police_officer", 41 | "description": "Police officer.", 42 | "index": 6 43 | }, 44 | { 45 | "token": "6a5888777ca14867a8aee3fe539b56c4", 46 | "name": "human.pedestrian.stroller", 47 | "description": "Strollers. If a person is in the stroller, include in the annotation.", 48 | "index": 7 49 | }, 50 | { 51 | "token": "b2d7c6c701254928a9e4d6aac9446d79", 52 | "name": "human.pedestrian.wheelchair", 53 | "description": "Wheelchairs. 
If a person is in the wheelchair, include in the annotation.", 54 | "index": 8 55 | }, 56 | { 57 | "token": "653f7efbb9514ce7b81d44070d6208c1", 58 | "name": "movable_object.barrier", 59 | "description": "Temporary road barrier placed in the scene in order to redirect traffic. Commonly used at construction sites. This includes concrete barrier, metal barrier and water barrier. No fences.", 60 | "index": 9 61 | }, 62 | { 63 | "token": "063c5e7f638343d3a7230bc3641caf97", 64 | "name": "movable_object.debris", 65 | "description": "Movable object that is left on the driveable surface that is too large to be driven over safely, e.g tree branch, full trash bag etc.", 66 | "index": 10 67 | }, 68 | { 69 | "token": "d772e4bae20f493f98e15a76518b31d7", 70 | "name": "movable_object.pushable_pullable", 71 | "description": "Objects that a pedestrian may push or pull. For example dolleys, wheel barrows, garbage-bins, or shopping carts.", 72 | "index": 11 73 | }, 74 | { 75 | "token": "85abebdccd4d46c7be428af5a6173947", 76 | "name": "movable_object.trafficcone", 77 | "description": "All types of traffic cone.", 78 | "index": 12 79 | }, 80 | { 81 | "token": "0a30519ee16a4619b4f4acfe2d78fb55", 82 | "name": "static_object.bicycle_rack", 83 | "description": "Area or device intended to park or secure the bicycles in a row. It includes all the bikes parked in it and any empty slots that are intended for parking bikes.", 84 | "index": 13 85 | }, 86 | { 87 | "token": "fc95c87b806f48f8a1faea2dcc2222a4", 88 | "name": "vehicle.bicycle", 89 | "description": "Human or electric powered 2-wheeled vehicle designed to travel at lower speeds either on road surface, sidewalks or bike paths.", 90 | "index": 14 91 | }, 92 | { 93 | "token": "003edbfb9ca849ee8a7496e9af3025d4", 94 | "name": "vehicle.bus.bendy", 95 | "description": "Bendy bus subcategory. Annotate each section of the bendy bus individually.", 96 | "index": 15 97 | }, 98 | { 99 | "token": "fedb11688db84088883945752e480c2c", 100 | "name": "vehicle.bus.rigid", 101 | "description": "Rigid bus subcategory.", 102 | "index": 16 103 | }, 104 | { 105 | "token": "fd69059b62a3469fbaef25340c0eab7f", 106 | "name": "vehicle.car", 107 | "description": "Vehicle designed primarily for personal use, e.g. sedans, hatch-backs, wagons, vans, mini-vans, SUVs and jeeps. If the vehicle is designed to carry more than 10 people use vehicle.bus. If it is primarily designed to haul cargo use vehicle.truck. ", 108 | "index": 17 109 | }, 110 | { 111 | "token": "5b3cd6f2bca64b83aa3d0008df87d0e4", 112 | "name": "vehicle.construction", 113 | "description": "Vehicles primarily designed for construction. Typically very slow moving or stationary. Cranes and extremities of construction vehicles are only included in annotations if they interfere with traffic. 
Trucks used to haul rocks or building materials are considered vehicle.truck rather than construction vehicles.", 114 | "index": 18 115 | }, 116 | { 117 | "token": "732cce86872640628788ff1bb81006d4", 118 | "name": "vehicle.emergency.ambulance", 119 | "description": "All types of ambulances.", 120 | "index": 19 121 | }, 122 | { 123 | "token": "7b2ff083a64e4d53809ae5d9be563504", 124 | "name": "vehicle.emergency.police", 125 | "description": "All types of police vehicles including police bicycles and motorcycles.", 126 | "index": 20 127 | }, 128 | { 129 | "token": "dfd26f200ade4d24b540184e16050022", 130 | "name": "vehicle.motorcycle", 131 | "description": "Gasoline or electric powered 2-wheeled vehicle designed to move rapidly (at the speed of standard cars) on the road surface. This category includes all motorcycles, vespas and scooters.", 132 | "index": 21 133 | }, 134 | { 135 | "token": "90d0f6f8e7c749149b1b6c3a029841a8", 136 | "name": "vehicle.trailer", 137 | "description": "Any vehicle trailer, both for trucks, cars and bikes.", 138 | "index": 22 139 | }, 140 | { 141 | "token": "6021b5187b924d64be64a702e5570edf", 142 | "name": "vehicle.truck", 143 | "description": "Vehicles primarily designed to haul cargo including pick-ups, lorrys, trucks and semi-tractors. Trailers hauled after a semi-tractor should be labeled as vehicle.trailer", 144 | "index": 23 145 | }, 146 | { 147 | "token": "89d20ff31e1fbdc844a74ff50f90c65c", 148 | "name": "flat.driveable_surface", 149 | "description": "All paved or unpaved surfaces that a car can drive on with no concern of traffic rules.", 150 | "index": 24 151 | }, 152 | { 153 | "token": "65deb30a3b9481422af8ad8adc983d63", 154 | "name": "flat.other", 155 | "description": "All other forms of horizontal ground-level structures that do not belong to any of driveable_surface, curb, sidewalk and terrain. Includes elevated parts of traffic islands, delimiters, rail tracks, stairs with at most 3 steps and larger bodies of water (lakes, rivers).", 156 | "index": 25 157 | }, 158 | { 159 | "token": "bf7b16f053ff2ea504a3d083fed223dd", 160 | "name": "flat.sidewalk", 161 | "description": "Sidewalk, pedestrian walkways, bike paths, etc. Part of the ground designated for pedestrians or cyclists. Sidewalks do **not** have to be next to a road.", 162 | "index": 26 163 | }, 164 | { 165 | "token": "2cfb3bdc510a4d28a4a1a78e611e4dfc", 166 | "name": "flat.terrain", 167 | "description": "Natural horizontal surfaces such as ground level horizontal vegetation (< 20 cm tall), grass, rolling hills, soil, sand and gravel.", 168 | "index": 27 169 | }, 170 | { 171 | "token": "a6773ab08859eb7037c36acc6d302a57", 172 | "name": "static.manmade", 173 | "description": "Includes man-made structures but not limited to: buildings, walls, guard rails, fences, poles, drainages, hydrants, flags, banners, street signs, electric circuit boxes, traffic lights, parking meters and stairs with more than 3 steps.", 174 | "index": 28 175 | }, 176 | { 177 | "token": "0d35abf67670c9b13a4fe6550c698e73", 178 | "name": "static.other", 179 | "description": "Points in the background that are not distinguishable, or objects that do not match any of the above labels.", 180 | "index": 29 181 | }, 182 | { 183 | "token": "e8fc03c4a3ce3cd25c9bc1c808197861", 184 | "name": "static.vegetation", 185 | "description": "Any vegetation in the frame that is higher than the ground, including bushes, plants, potted plants, trees, etc. 
Only tall grass (> 20cm) is part of this, ground level grass is part of `terrain`.", 186 | "index": 30 187 | }, 188 | { 189 | "token": "3847caf8adb16ed747535b76fdb9fd05", 190 | "name": "vehicle.ego", 191 | "description": "The vehicle on which the cameras, radar and lidar are mounted, that is sometimes visible at the bottom of the image.", 192 | "index": 31 193 | } 194 | ] -------------------------------------------------------------------------------- /downstream/datasets/nuscenes_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | import logging 5 | 6 | import torch 7 | 8 | from torch_geometric.data import Dataset 9 | from torch_geometric.data import Data 10 | 11 | from nuscenes import NuScenes as NuScenes_ 12 | from nuscenes.utils.splits import create_splits_scenes 13 | from nuscenes.utils.data_io import load_bin_file 14 | from nuscenes.utils.data_classes import LidarPointCloud 15 | 16 | CUSTOM_SPLIT = [ 17 | "scene-0008", "scene-0009", "scene-0019", "scene-0029", "scene-0032", "scene-0042", 18 | "scene-0045", "scene-0049", "scene-0052", "scene-0054", "scene-0056", "scene-0066", 19 | "scene-0067", "scene-0073", "scene-0131", "scene-0152", "scene-0166", "scene-0168", 20 | "scene-0183", "scene-0190", "scene-0194", "scene-0208", "scene-0210", "scene-0211", 21 | "scene-0241", "scene-0243", "scene-0248", "scene-0259", "scene-0260", "scene-0261", 22 | "scene-0287", "scene-0292", "scene-0297", "scene-0305", "scene-0306", "scene-0350", 23 | "scene-0352", "scene-0358", "scene-0361", "scene-0365", "scene-0368", "scene-0377", 24 | "scene-0388", "scene-0391", "scene-0395", "scene-0413", "scene-0427", "scene-0428", 25 | "scene-0438", "scene-0444", "scene-0452", "scene-0453", "scene-0459", "scene-0463", 26 | "scene-0464", "scene-0475", "scene-0513", "scene-0533", "scene-0544", "scene-0575", 27 | "scene-0587", "scene-0589", "scene-0642", "scene-0652", "scene-0658", "scene-0669", 28 | "scene-0678", "scene-0687", "scene-0701", "scene-0703", "scene-0706", "scene-0710", 29 | "scene-0715", "scene-0726", "scene-0735", "scene-0740", "scene-0758", "scene-0786", 30 | "scene-0790", "scene-0804", "scene-0806", "scene-0847", "scene-0856", "scene-0868", 31 | "scene-0882", "scene-0897", "scene-0899", "scene-0976", "scene-0996", "scene-1012", 32 | "scene-1015", "scene-1016", "scene-1018", "scene-1020", "scene-1024", "scene-1044", 33 | "scene-1058", "scene-1094", "scene-1098", "scene-1107", 34 | ] 35 | 36 | 37 | class NuScenes(Dataset): 38 | 39 | N_LABELS=17 40 | 41 | def __init__(self, 42 | root, 43 | split="training", 44 | transform=None, 45 | skip_ratio=1, 46 | skip_for_visu=1, 47 | **kwargs): 48 | 49 | super().__init__(root, transform, None) 50 | 51 | self.nusc = NuScenes_(version='v1.0-trainval', dataroot=self.root, verbose=True) 52 | self.split = split 53 | 54 | self.label_to_name = {0: 'noise', 55 | 1: 'animal', 56 | 2: 'human.pedestrian.adult', 57 | 3: 'human.pedestrian.child', 58 | 4: 'human.pedestrian.construction_worker', 59 | 5: 'human.pedestrian.personal_mobility', 60 | 6: 'human.pedestrian.police_officer', 61 | 7: 'human.pedestrian.stroller', 62 | 8: 'human.pedestrian.wheelchair', 63 | 9: 'movable_object.barrier', 64 | 10: 'movable_object.debris', 65 | 11: 'movable_object.pushable_pullable', 66 | 12: 'movable_object.trafficcone', 67 | 13: 'static_object.bicycle_rack', 68 | 14: 'vehicle.bicycle', 69 | 15: 'vehicle.bus.bendy', 70 | 16: 'vehicle.bus.rigid', 71 | 17: 'vehicle.car', 72 | 18: 'vehicle.construction', 73 | 19: 
'vehicle.emergency.ambulance', 74 | 20: 'vehicle.emergency.police', 75 | 21: 'vehicle.motorcycle', 76 | 22: 'vehicle.trailer', 77 | 23: 'vehicle.truck', 78 | 24: 'flat.driveable_surface', 79 | 25: 'flat.other', 80 | 26: 'flat.sidewalk', 81 | 27: 'flat.terrain', 82 | 28: 'static.manmade', 83 | 29: 'static.other', 84 | 30: 'static.vegetation', 85 | 31: 'vehicle.ego' 86 | } 87 | 88 | self.label_to_name_reduced = { 89 | 0: 'noise', 90 | 1: 'barrier', 91 | 2: 'bicycle', 92 | 3: 'bus', 93 | 4: 'car', 94 | 5: 'construction_vehicle', 95 | 6: 'motorcycle', 96 | 7: 'pedestrian', 97 | 8: 'traffic_cone', 98 | 9: 'trailer', 99 | 10: 'truck', 100 | 11: 'driveable_surface', 101 | 12: 'other_flat', 102 | 13: 'sidewalk', 103 | 14: 'terrain', 104 | 15: 'manmade', 105 | 16: 'vegetation', 106 | } 107 | 108 | self.label_to_reduced = { 109 | 1: 0, 110 | 5: 0, 111 | 7: 0, 112 | 8: 0, 113 | 10: 0, 114 | 11: 0, 115 | 13: 0, 116 | 19: 0, 117 | 20: 0, 118 | 0: 0, 119 | 29: 0, 120 | 31: 0, 121 | 9: 1, 122 | 14: 2, 123 | 15: 3, 124 | 16: 3, 125 | 17: 4, 126 | 18: 5, 127 | 21: 6, 128 | 2: 7, 129 | 3: 7, 130 | 4: 7, 131 | 6: 7, 132 | 12: 8, 133 | 22: 9, 134 | 23: 10, 135 | 24: 11, 136 | 25: 12, 137 | 26: 13, 138 | 27: 14, 139 | 28: 15, 140 | 30: 16 141 | } 142 | 143 | self.label_to_reduced_np = np.zeros(32, dtype=np.int) 144 | for i in range(32): 145 | self.label_to_reduced_np[i] = self.label_to_reduced[i] 146 | 147 | self.reduced_colors = np.array([ 148 | [0, 0, 0], 149 | [112, 128, 144], 150 | [220, 20, 60], # Crimson 151 | [255, 127, 80], # Coral 152 | [255, 158, 0], # Orange 153 | [233, 150, 70], # Darksalmon 154 | [255, 61, 99], # Red 155 | [0, 0, 230], # Blue 156 | [47, 79, 79], # Darkslategrey 157 | [255, 140, 0], # Darkorange 158 | [255, 99, 71], # Tomato 159 | [0, 207, 191], # nuTonomy green 160 | [175, 0, 75], 161 | [75, 0, 75], 162 | [112, 180, 60], 163 | [222, 184, 135], # Burlywood 164 | [0, 175, 0], # Green 165 | ], dtype=np.uint8) 166 | 167 | 168 | ############## 169 | logging.info(f"Nuscenes dataset - creating splits - split {split}") 170 | # from nuscenes.utils import splits 171 | 172 | # get the scenes 173 | assert(split in ["train", "val", "test", "verifying", "parametrizing"]) 174 | if split == "verifying": 175 | phase_scenes = CUSTOM_SPLIT 176 | elif split == "parametrizing": 177 | phase_scenes = list( set(create_splits_scenes()["train"]) - set(CUSTOM_SPLIT) ) 178 | else: 179 | phase_scenes = create_splits_scenes()[split] 180 | 181 | 182 | # create a list of camera & lidar scans 183 | skip_counter = 0 184 | self.list_keyframes = [] 185 | for scene_idx in range(len(self.nusc.scene)): 186 | scene = self.nusc.scene[scene_idx] 187 | if scene["name"] in phase_scenes: 188 | 189 | skip_counter += 1 190 | if skip_counter % skip_ratio == 0: 191 | current_sample_token = scene["first_sample_token"] 192 | 193 | # Loop to get all successive keyframes 194 | list_data = [] 195 | while current_sample_token != "": 196 | current_sample = self.nusc.get("sample", current_sample_token) 197 | list_data.append(current_sample) 198 | current_sample_token = current_sample["next"] 199 | 200 | if skip_for_visu > 1: 201 | break 202 | 203 | # Add new scans in the list 204 | self.list_keyframes.extend(list_data) 205 | 206 | self.list_keyframes = self.list_keyframes[::skip_for_visu] 207 | 208 | if len(self.list_keyframes)==0: 209 | # add only one scene 210 | # scenes with all labels (parametrizing split) "scene-0392", "scene-0517", "scene-0656", "scene-0730", "scene-0738" 211 | for scene_idx in range(len(self.nusc.scene)): 212 | 
scene = self.nusc.scene[scene_idx] 213 | if scene["name"] in phase_scenes and scene["name"] in ["scene-0392"]: 214 | 215 | current_sample_token = scene["first_sample_token"] 216 | 217 | # Loop to get all successive keyframes 218 | list_data = [] 219 | while current_sample_token != "": 220 | current_sample = self.nusc.get("sample", current_sample_token) 221 | list_data.append(current_sample) 222 | current_sample_token = current_sample["next"] 223 | 224 | # Add new scans in the list 225 | self.list_keyframes.extend(list_data) 226 | 227 | # if split == 'verifying': 228 | # self.list_keyframes = self.list_keyframes[::10] 229 | 230 | logging.info(f"Nuscenes dataset split {split} - {len(self.list_keyframes)} frames") 231 | 232 | 233 | def get_weights(self): 234 | weights = torch.ones(self.N_LABELS) 235 | weights[0] = 0 236 | return weights 237 | 238 | @staticmethod 239 | def get_mask_filter_valid_labels(y): 240 | return (y>0) 241 | 242 | @staticmethod 243 | def get_ignore_index(): 244 | return 0 245 | 246 | def get_colors(self, labels): 247 | return self.reduced_colors[labels] 248 | 249 | 250 | def get_filename(self, index): 251 | 252 | # get sample 253 | sample = self.list_keyframes[index] 254 | 255 | # get the lidar token 256 | lidar_token = sample["data"]["LIDAR_TOP"] 257 | 258 | return str(lidar_token) 259 | 260 | @property 261 | def raw_file_names(self): 262 | return [] 263 | 264 | def _download(self): # override _download to remove makedirs 265 | pass 266 | 267 | def download(self): 268 | pass 269 | 270 | def process(self): 271 | pass 272 | 273 | def _process(self): 274 | pass 275 | 276 | def len(self): 277 | return len(self.list_keyframes) 278 | 279 | def get(self, idx): 280 | """Get item.""" 281 | 282 | # get sample 283 | sample = self.list_keyframes[idx] 284 | 285 | # get the lidar token 286 | lidar_token = sample["data"]["LIDAR_TOP"] 287 | 288 | # the lidar record 289 | lidar_rec = self.nusc.get('sample_data', sample['data']["LIDAR_TOP"]) 290 | 291 | # get intensities 292 | pc = LidarPointCloud.from_file(os.path.join(self.nusc.dataroot, lidar_rec['filename'])) 293 | pos = pc.points.T[:,:3] 294 | intensities = pc.points.T[:,3:] / 255 # intensities 295 | 296 | # get the labels 297 | lidarseg_label_filename = os.path.join(self.nusc.dataroot, self.nusc.get('lidarseg', lidar_token)['filename']) 298 | y_complete_labels = load_bin_file(lidarseg_label_filename) 299 | y = self.label_to_reduced_np[y_complete_labels] 300 | 301 | # convert to torch 302 | pos = torch.tensor(pos, dtype=torch.float) 303 | y = torch.tensor(y, dtype=torch.long) 304 | intensities = torch.tensor(intensities, dtype=torch.float) 305 | x = torch.ones((pos.shape[0],1), dtype=torch.float) 306 | 307 | return Data(x=x, intensities=intensities, pos=pos, y=y, shape_id=idx, ) 308 | -------------------------------------------------------------------------------- /downstream/datasets/once.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from torch_geometric.data import Dataset 3 | from torch_geometric.data import Data 4 | import numpy as np 5 | import os 6 | import pickle 7 | import copy 8 | import torch 9 | 10 | class ONCE(Dataset): 11 | 12 | INFO_PATH= { 13 | 'train': "once_infos_train.pkl", 14 | 'val': "once_infos_val.pkl", 15 | 'test': "once_infos_test.pkl", 16 | 'raw_small': "once_infos_raw_small.pkl", 17 | 'raw_medium': "once_infos_raw_medium.pkl", 18 | 'raw_large': "once_infos_raw_large.pkl", 19 | } 20 | 21 | def __init__(self, 22 | root, 23 | split="training", 24 | 
transform=None, 25 | skip_ratio=1, 26 | **kwargs): 27 | 28 | super().__init__(root, transform, None) 29 | 30 | logging.info(f"ONCE - split {split}") 31 | 32 | info_path = os.path.join(self.root, self.INFO_PATH[split]) 33 | with open(info_path, 'rb') as f: 34 | self.once_infos = pickle.load(f) 35 | 36 | 37 | if split in ["val", "valiation"]: 38 | self.once_infos = self.once_infos[:100] 39 | 40 | logging.info(f"ONCE dataset {len(self.once_infos)}") 41 | 42 | 43 | def _download(self): # override _download to remove makedirs 44 | pass 45 | 46 | def download(self): 47 | pass 48 | 49 | def process(self): 50 | pass 51 | 52 | def _process(self): 53 | pass 54 | 55 | def len(self): 56 | return len(self.once_infos) 57 | 58 | def get(self, idx): 59 | """Get item.""" 60 | 61 | # get ids 62 | frame_id = self.once_infos[idx]['frame_id'] 63 | seq_id = self.once_infos[idx]['sequence_id'] 64 | 65 | # load lidar data 66 | bin_path = os.path.join(self.root, 'data', seq_id, 'lidar_roof', '{}.bin'.format(frame_id)) 67 | points = np.fromfile(bin_path, dtype=np.float32).reshape(-1, 4) 68 | 69 | # get intensities and points 70 | intensities = points[:,3:4] / 255. 71 | pos = points[:,:3] 72 | 73 | # transform to tensor 74 | pos = torch.tensor(pos, dtype=torch.float) 75 | intensities = torch.tensor(intensities, dtype=torch.float) 76 | x = torch.ones((pos.shape[0],1), dtype=torch.float) 77 | 78 | return Data(x=x, intensities=intensities, pos=pos, shape_id=idx) 79 | -------------------------------------------------------------------------------- /downstream/datasets/semantic-kitti.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 2 | labels: 3 | 0 : "unlabeled" 4 | 1 : "outlier" 5 | 10: "car" 6 | 11: "bicycle" 7 | 13: "bus" 8 | 15: "motorcycle" 9 | 16: "on-rails" 10 | 18: "truck" 11 | 20: "other-vehicle" 12 | 30: "person" 13 | 31: "bicyclist" 14 | 32: "motorcyclist" 15 | 40: "road" 16 | 44: "parking" 17 | 48: "sidewalk" 18 | 49: "other-ground" 19 | 50: "building" 20 | 51: "fence" 21 | 52: "other-structure" 22 | 60: "lane-marking" 23 | 70: "vegetation" 24 | 71: "trunk" 25 | 72: "terrain" 26 | 80: "pole" 27 | 81: "traffic-sign" 28 | 99: "other-object" 29 | 252: "moving-car" 30 | 253: "moving-bicyclist" 31 | 254: "moving-person" 32 | 255: "moving-motorcyclist" 33 | 256: "moving-on-rails" 34 | 257: "moving-bus" 35 | 258: "moving-truck" 36 | 259: "moving-other-vehicle" 37 | color_map: # bgr 38 | 0 : [0, 0, 0] 39 | 1 : [0, 0, 255] 40 | 10: [245, 150, 100] 41 | 11: [245, 230, 100] 42 | 13: [250, 80, 100] 43 | 15: [150, 60, 30] 44 | 16: [255, 0, 0] 45 | 18: [180, 30, 80] 46 | 20: [255, 0, 0] 47 | 30: [30, 30, 255] 48 | 31: [200, 40, 255] 49 | 32: [90, 30, 150] 50 | 40: [255, 0, 255] 51 | 44: [255, 150, 255] 52 | 48: [75, 0, 75] 53 | 49: [75, 0, 175] 54 | 50: [0, 200, 255] 55 | 51: [50, 120, 255] 56 | 52: [0, 150, 255] 57 | 60: [170, 255, 150] 58 | 70: [0, 175, 0] 59 | 71: [0, 60, 135] 60 | 72: [80, 240, 150] 61 | 80: [150, 240, 255] 62 | 81: [0, 0, 255] 63 | 99: [255, 255, 50] 64 | 252: [245, 150, 100] 65 | 256: [255, 0, 0] 66 | 253: [200, 40, 255] 67 | 254: [30, 30, 255] 68 | 255: [90, 30, 150] 69 | 257: [250, 80, 100] 70 | 258: [180, 30, 80] 71 | 259: [255, 0, 0] 72 | content: # as a ratio with the total number of points 73 | 0: 0.018889854628292943 74 | 1: 0.0002937197336781505 75 | 10: 0.040818519255974316 76 | 11: 0.00016609538710764618 77 | 13: 2.7879693665067774e-05 78 | 15: 0.00039838616015114444 79 | 16: 
0.0 80 | 18: 0.0020633612104619787 81 | 20: 0.0016218197275284021 82 | 30: 0.00017698551338515307 83 | 31: 1.1065903904919655e-08 84 | 32: 5.532951952459828e-09 85 | 40: 0.1987493871255525 86 | 44: 0.014717169549888214 87 | 48: 0.14392298360372 88 | 49: 0.0039048553037472045 89 | 50: 0.1326861944777486 90 | 51: 0.0723592229456223 91 | 52: 0.002395131480328884 92 | 60: 4.7084144280367186e-05 93 | 70: 0.26681502148037506 94 | 71: 0.006035012012626033 95 | 72: 0.07814222006271769 96 | 80: 0.002855498193863172 97 | 81: 0.0006155958086189918 98 | 99: 0.009923127583046915 99 | 252: 0.001789309418528068 100 | 253: 0.00012709999297008662 101 | 254: 0.00016059776092534436 102 | 255: 3.745553104802113e-05 103 | 256: 0.0 104 | 257: 0.00011351574470342043 105 | 258: 0.00010157861367183268 106 | 259: 4.3840131989471124e-05 107 | # classes that are indistinguishable from single scan or inconsistent in 108 | # ground truth are mapped to their closest equivalent 109 | learning_map: 110 | 0 : 0 # "unlabeled" 111 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped 112 | 10: 1 # "car" 113 | 11: 2 # "bicycle" 114 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped 115 | 15: 3 # "motorcycle" 116 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped 117 | 18: 4 # "truck" 118 | 20: 5 # "other-vehicle" 119 | 30: 6 # "person" 120 | 31: 7 # "bicyclist" 121 | 32: 8 # "motorcyclist" 122 | 40: 9 # "road" 123 | 44: 10 # "parking" 124 | 48: 11 # "sidewalk" 125 | 49: 12 # "other-ground" 126 | 50: 13 # "building" 127 | 51: 14 # "fence" 128 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped 129 | 60: 9 # "lane-marking" to "road" ---------------------------------mapped 130 | 70: 15 # "vegetation" 131 | 71: 16 # "trunk" 132 | 72: 17 # "terrain" 133 | 80: 18 # "pole" 134 | 81: 19 # "traffic-sign" 135 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped 136 | 252: 1 # "moving-car" to "car" ------------------------------------mapped 137 | 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped 138 | 254: 6 # "moving-person" to "person" ------------------------------mapped 139 | 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped 140 | 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped 141 | 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped 142 | 258: 4 # "moving-truck" to "truck" --------------------------------mapped 143 | 259: 5 # "moving-other"-vehicle to "other-vehicle" ----------------mapped 144 | learning_map_inv: # inverse of previous map 145 | 0: 0 # "unlabeled", and others ignored 146 | 1: 10 # "car" 147 | 2: 11 # "bicycle" 148 | 3: 15 # "motorcycle" 149 | 4: 18 # "truck" 150 | 5: 20 # "other-vehicle" 151 | 6: 30 # "person" 152 | 7: 31 # "bicyclist" 153 | 8: 32 # "motorcyclist" 154 | 9: 40 # "road" 155 | 10: 44 # "parking" 156 | 11: 48 # "sidewalk" 157 | 12: 49 # "other-ground" 158 | 13: 50 # "building" 159 | 14: 51 # "fence" 160 | 15: 70 # "vegetation" 161 | 16: 71 # "trunk" 162 | 17: 72 # "terrain" 163 | 18: 80 # "pole" 164 | 19: 81 # "traffic-sign" 165 | learning_ignore: # Ignore classes 166 | 0: True # "unlabeled", and others ignored 167 | 1: False # "car" 168 | 2: False # "bicycle" 169 | 3: False # "motorcycle" 170 | 4: False # "truck" 171 | 5: False # "other-vehicle" 172 | 6: False # "person" 173 | 7: False # "bicyclist" 174 | 8: False # "motorcyclist" 175 | 9: False # "road" 176 | 10: False # "parking" 177 | 11: 
False # "sidewalk" 178 | 12: False # "other-ground" 179 | 13: False # "building" 180 | 14: False # "fence" 181 | 15: False # "vegetation" 182 | 16: False # "trunk" 183 | 17: False # "terrain" 184 | 18: False # "pole" 185 | 19: False # "traffic-sign" 186 | split: # sequence numbers 187 | train: 188 | - 0 189 | - 1 190 | - 2 191 | - 3 192 | - 4 193 | - 5 194 | - 6 195 | - 7 196 | - 9 197 | - 10 198 | valid: 199 | - 8 200 | test: 201 | - 11 202 | - 12 203 | - 13 204 | - 14 205 | - 15 206 | - 16 207 | - 17 208 | - 18 209 | - 19 210 | - 20 211 | - 21 212 | -------------------------------------------------------------------------------- /downstream/datasets/semantic_poss_segcontrast.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import yaml 4 | from pathlib import Path 5 | import json 6 | import numpy as np 7 | import torch 8 | from torch_geometric.data import Dataset 9 | from torch_geometric.data import Data 10 | 11 | 12 | class SemanticPOSS_SegContrast(Dataset): 13 | 14 | N_LABELS = 14 15 | 16 | def __init__(self, 17 | root, 18 | split="training", 19 | transform=None, 20 | skip_ratio=1, 21 | **kwargs): 22 | 23 | super().__init__(root, transform, None) 24 | 25 | self.split = split 26 | self.n_frames = 1 27 | 28 | logging.info(f"SemanticPOSS - split {split}") 29 | 30 | # get the scenes 31 | assert(split in ["train", "val", "test"]) 32 | if split == "train": 33 | self.sequences = ['00', '01', '02', '04', '05'] 34 | elif split == "val": 35 | self.sequences = ['03'] 36 | elif split == "test": 37 | raise NotImplementedError 38 | else: 39 | raise ValueError('Unknown set for SemanticPOSS data: ', split) 40 | 41 | self.points_datapath = [] 42 | self.labels_datapath = [] 43 | # if self.split == "train" and skip_ratio > 1: 44 | # with open("datasets/percentiles_split.json", 'r') as p: 45 | # splits = json.load(p) 46 | # skip_to_percent = {2:'0.5', 4:'0.25', 10:'0.1', 100:'0.01', 1000:'0.001', 10000:'0.0001'} 47 | # if skip_ratio not in skip_to_percent: 48 | # raise ValueError 49 | # percentage = skip_to_percent[skip_ratio] 50 | 51 | # for seq in splits[percentage]: 52 | # self.points_datapath += splits[percentage][seq]['points'] 53 | # self.labels_datapath += splits[percentage][seq]['labels'] 54 | 55 | # for i in range(len(self.points_datapath)): 56 | # self.points_datapath[i] = self.points_datapath[i].replace("Datasets/SemanticKITTI/", "") 57 | # self.points_datapath[i] = os.path.join(self.root,self.points_datapath[i]) 58 | # self.labels_datapath[i] = self.labels_datapath[i].replace("Datasets/SemanticKITTI/", "") 59 | # self.labels_datapath[i] = os.path.join(self.root,self.labels_datapath[i]) 60 | # else: 61 | 62 | 63 | for sequence in self.sequences: 64 | self.points_datapath += [path for path in Path(os.path.join(self.root, "dataset", "sequences", sequence, "velodyne")).rglob('*.bin')] 65 | 66 | for fname in self.points_datapath: 67 | fname = str(fname).replace("/velodyne/", "/labels/") 68 | fname = str(fname).replace(".bin", ".label") 69 | self.labels_datapath.append(fname) 70 | 71 | 72 | if skip_ratio > 1: 73 | self.points_datapath = self.points_datapath[::skip_ratio] 74 | self.labels_datapath = self.labels_datapath[::skip_ratio] 75 | 76 | 77 | # Read labels 78 | # config_file = 'datasets/semantic-kitti.yaml' 79 | 80 | # with open(config_file, 'r') as stream: 81 | # doc = yaml.safe_load(stream) 82 | # all_labels = doc['labels'] 83 | # learning_map_inv = doc['learning_map_inv'] 84 | # learning_map = doc['learning_map'] 85 | # 
self.learning_map = np.zeros((np.max([k for k in learning_map.keys()]) + 1), dtype=np.int32) 86 | # for k, v in learning_map.items(): 87 | # self.learning_map[k] = v 88 | 89 | # self.learning_map_inv = np.zeros((np.max([k for k in learning_map_inv.keys()]) + 1), dtype=np.int32) 90 | # for k, v in learning_map_inv.items(): 91 | # self.learning_map_inv[k] = v 92 | 93 | learning_map = { 94 | 0: 0, 95 | 4: 1, 96 | 5: 1, 97 | 6: 2, 98 | 7: 3, 99 | 8: 4, 100 | 9: 5, 101 | 10: 6, 102 | 11: 6, 103 | 12: 6, 104 | 13: 7, 105 | 14: 8, 106 | 15: 9, 107 | 16: 10, 108 | 17: 11, 109 | 21: 12, 110 | 22: 13, 111 | } 112 | 113 | learning_map_inv = { 114 | 0: 0, 115 | 1: 4, 116 | 1: 5, 117 | 2: 6, 118 | 3: 7, 119 | 4: 8, 120 | 5: 9, 121 | 6:10, 122 | 6:11, 123 | 6:12, 124 | 7:13, 125 | 8:14, 126 | 9:15, 127 | 10:16, 128 | 11:17, 129 | 12:21, 130 | 13:22 131 | } 132 | 133 | self.learning_map = np.zeros((np.max([k for k in learning_map.keys()]) + 1), dtype=np.int32) 134 | for k, v in learning_map.items(): 135 | self.learning_map[k] = v 136 | 137 | self.learning_map_inv = np.zeros((np.max([k for k in learning_map_inv.keys()]) + 1), dtype=np.int32) 138 | for k, v in learning_map_inv.items(): 139 | self.learning_map_inv[k] = v 140 | 141 | self.class_colors = np.array([ 142 | [0, 0, 0], 143 | [245, 150, 100], 144 | [245, 230, 100], 145 | [150, 60, 30], 146 | [180, 30, 80], 147 | [255, 0, 0], 148 | [30, 30, 255], 149 | [200, 40, 255], 150 | [90, 30, 150], 151 | [255, 0, 255], 152 | [255, 150, 255], 153 | [75, 0, 75], 154 | [75, 0, 175], 155 | [0, 200, 255], 156 | # [50, 120, 255], 157 | # [0, 175, 0], 158 | # [0, 60, 135], 159 | # [80, 240, 150], 160 | # [150, 240, 255], 161 | # [0, 0, 255], 162 | ], dtype=np.uint8) 163 | 164 | 165 | logging.info(f"SemanticPOSS dataset {len(self.points_datapath)}") 166 | 167 | def get_weights(self): 168 | weights = torch.ones(self.N_LABELS) 169 | weights[0] = 0 170 | return weights 171 | 172 | @staticmethod 173 | def get_mask_filter_valid_labels(y): 174 | return (y>0) 175 | 176 | def get_colors(self, labels): 177 | return self.class_colors[labels] 178 | 179 | def get_filename(self, index): 180 | fname = self.points_datapath[index] 181 | fname = str(fname).split("/") 182 | fname = ("_").join([fname[-3], fname[-1]]) 183 | fname = fname[:-4] 184 | return fname 185 | 186 | @staticmethod 187 | def get_ignore_index(): 188 | return 0 189 | 190 | def _download(self): # override _download to remove makedirs 191 | pass 192 | 193 | def download(self): 194 | pass 195 | 196 | def process(self): 197 | pass 198 | 199 | def _process(self): 200 | pass 201 | 202 | def len(self): 203 | return len(self.points_datapath) 204 | 205 | def get(self, idx): 206 | """Get item.""" 207 | 208 | fname_points = self.points_datapath[idx] 209 | frame_points = np.fromfile(fname_points, dtype=np.float32) 210 | 211 | pos = frame_points.reshape((-1, 4)) 212 | intensities = pos[:,3:] 213 | pos = pos[:,:3] 214 | 215 | # Read labels 216 | label_file = self.labels_datapath[idx] 217 | frame_labels = np.fromfile(label_file, dtype=np.int32) 218 | frame_labels = frame_labels.reshape((-1)) 219 | y = frame_labels & 0xFFFF # semantic label in lower half 220 | 221 | # get unlabeled data 222 | y = self.learning_map[y] 223 | unlabeled = (y == 0) 224 | 225 | # remove unlabeled points 226 | y = np.delete(y, unlabeled, axis=0) 227 | pos = np.delete(pos, unlabeled, axis=0) 228 | intensities = np.delete(intensities, unlabeled, axis=0) 229 | 230 | pos = torch.tensor(pos, dtype=torch.float) 231 | y = torch.tensor(y, dtype=torch.long) 232 | 
intensities = torch.tensor(intensities, dtype=torch.float) 233 | x = torch.ones((pos.shape[0],1), dtype=torch.float) 234 | 235 | return Data(x=x, intensities=intensities, pos=pos, y=y, 236 | shape_id=idx, ) 237 | -------------------------------------------------------------------------------- /downstream/datasets/semantickitti_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import yaml 4 | from pathlib import Path 5 | import json 6 | import numpy as np 7 | import torch 8 | from torch_geometric.data import Dataset 9 | from torch_geometric.data import Data 10 | 11 | 12 | class SemanticKITTI(Dataset): 13 | 14 | N_LABELS = 20 15 | 16 | def __init__(self, 17 | root, 18 | split="training", 19 | transform=None, 20 | skip_ratio=1, 21 | **kwargs): 22 | 23 | super().__init__(root, transform, None) 24 | 25 | self.split = split 26 | self.n_frames = 1 27 | 28 | logging.info(f"SemanticKITTI - split {split}") 29 | 30 | # get the scenes 31 | assert(split in ["train", "val", "test"]) 32 | if split == "train": 33 | self.sequences = ['00', '01', '02', '03', '04', '05', '06', '07', '09', '10'] 34 | elif split == "val": 35 | self.sequences = ['08'] 36 | elif split == "test": 37 | raise NotImplementedError 38 | else: 39 | raise ValueError('Unknown set for SemanticKitti data: ', split) 40 | 41 | self.points_datapath = [] 42 | self.labels_datapath = [] 43 | if self.split == "train" and skip_ratio > 1: 44 | with open("datasets/percentiles_split.json", 'r') as p: 45 | splits = json.load(p) 46 | skip_to_percent = {2:'0.5', 4:'0.25', 10:'0.1', 100:'0.01', 1000:'0.001', 10000:'0.0001'} 47 | if skip_ratio not in skip_to_percent: 48 | raise ValueError 49 | percentage = skip_to_percent[skip_ratio] 50 | 51 | for seq in splits[percentage]: 52 | self.points_datapath += splits[percentage][seq]['points'] 53 | self.labels_datapath += splits[percentage][seq]['labels'] 54 | 55 | for i in range(len(self.points_datapath)): 56 | self.points_datapath[i] = self.points_datapath[i].replace("Datasets/SemanticKITTI/", "") 57 | self.points_datapath[i] = os.path.join(self.root,self.points_datapath[i]) 58 | self.labels_datapath[i] = self.labels_datapath[i].replace("Datasets/SemanticKITTI/", "") 59 | self.labels_datapath[i] = os.path.join(self.root,self.labels_datapath[i]) 60 | else: 61 | 62 | 63 | for sequence in self.sequences: 64 | self.points_datapath += [path for path in Path(os.path.join(self.root, "dataset", "sequences", sequence, "velodyne")).rglob('*.bin')] 65 | 66 | for fname in self.points_datapath: 67 | fname = str(fname).replace("/velodyne/", "/labels/") 68 | fname = str(fname).replace(".bin", ".label") 69 | self.labels_datapath.append(fname) 70 | 71 | 72 | # Read labels 73 | config_file = 'datasets/semantic-kitti.yaml' 74 | 75 | with open(config_file, 'r') as stream: 76 | doc = yaml.safe_load(stream) 77 | all_labels = doc['labels'] 78 | learning_map_inv = doc['learning_map_inv'] 79 | learning_map = doc['learning_map'] 80 | self.learning_map = np.zeros((np.max([k for k in learning_map.keys()]) + 1), dtype=np.int32) 81 | for k, v in learning_map.items(): 82 | self.learning_map[k] = v 83 | 84 | self.learning_map_inv = np.zeros((np.max([k for k in learning_map_inv.keys()]) + 1), dtype=np.int32) 85 | for k, v in learning_map_inv.items(): 86 | self.learning_map_inv[k] = v 87 | 88 | self.class_colors = np.array([ 89 | [0, 0, 0], 90 | [245, 150, 100], 91 | [245, 230, 100], 92 | [150, 60, 30], 93 | [180, 30, 80], 94 | [255, 0, 0], 95 | [30, 30, 255], 96 | 
[200, 40, 255], 97 | [90, 30, 150], 98 | [255, 0, 255], 99 | [255, 150, 255], 100 | [75, 0, 75], 101 | [75, 0, 175], 102 | [0, 200, 255], 103 | [50, 120, 255], 104 | [0, 175, 0], 105 | [0, 60, 135], 106 | [80, 240, 150], 107 | [150, 240, 255], 108 | [0, 0, 255], 109 | ], dtype=np.uint8) 110 | 111 | # if split == 'val': 112 | # self.points_datapath = self.points_datapath[::10] 113 | # self.labels_datapath = self.labels_datapath[::10] 114 | 115 | logging.info(f"SemanticKITTI dataset {len(self.points_datapath)}") 116 | 117 | def get_weights(self): 118 | weights = torch.ones(self.N_LABELS) 119 | weights[0] = 0 120 | return weights 121 | 122 | @staticmethod 123 | def get_mask_filter_valid_labels(y): 124 | return (y>0) 125 | 126 | def get_colors(self, labels): 127 | return self.class_colors[labels] 128 | 129 | def get_filename(self, index): 130 | fname = self.points_datapath[index] 131 | fname = str(fname).split("/") 132 | fname = ("_").join([fname[-3], fname[-1]]) 133 | fname = fname[:-4] 134 | return fname 135 | 136 | @staticmethod 137 | def get_ignore_index(): 138 | return 0 139 | 140 | def _download(self): # override _download to remove makedirs 141 | pass 142 | 143 | def download(self): 144 | pass 145 | 146 | def process(self): 147 | pass 148 | 149 | def _process(self): 150 | pass 151 | 152 | def len(self): 153 | return len(self.points_datapath) 154 | 155 | def get(self, idx): 156 | """Get item.""" 157 | 158 | fname_points = self.points_datapath[idx] 159 | frame_points = np.fromfile(fname_points, dtype=np.float32) 160 | 161 | pos = frame_points.reshape((-1, 4)) 162 | intensities = pos[:,3:] 163 | pos = pos[:,:3] 164 | 165 | # Read labels 166 | label_file = self.labels_datapath[idx] 167 | frame_labels = np.fromfile(label_file, dtype=np.int32) 168 | frame_labels = frame_labels.reshape((-1)) 169 | y = frame_labels & 0xFFFF # semantic label in lower half 170 | 171 | # get unlabeled data 172 | y = self.learning_map[y] 173 | unlabeled = (y == 0) 174 | 175 | # remove unlabeled points 176 | y = np.delete(y, unlabeled, axis=0) 177 | pos = np.delete(pos, unlabeled, axis=0) 178 | intensities = np.delete(intensities, unlabeled, axis=0) 179 | 180 | pos = torch.tensor(pos, dtype=torch.float) 181 | y = torch.tensor(y, dtype=torch.long) 182 | intensities = torch.tensor(intensities, dtype=torch.float) 183 | x = torch.ones((pos.shape[0],1), dtype=torch.float) 184 | 185 | return Data(x=x, intensities=intensities, pos=pos, y=y, 186 | shape_id=idx, ) -------------------------------------------------------------------------------- /downstream/eval.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import logging 3 | import argparse 4 | import warnings 5 | import importlib 6 | 7 | from tqdm import tqdm 8 | 9 | from scipy.spatial import KDTree 10 | 11 | import torch 12 | 13 | from torch_geometric.data import DataLoader 14 | 15 | from utils.utils import wgreen 16 | from utils.confusion_matrix import ConfusionMatrix 17 | from transforms import get_transforms, get_input_channels 18 | 19 | import datasets 20 | import networks 21 | from networks.backbone import * 22 | 23 | 24 | if __name__ == "__main__": 25 | warnings.filterwarnings("ignore", category=UserWarning) 26 | 27 | logging.getLogger().setLevel("INFO") 28 | 29 | parser = argparse.ArgumentParser(description='Process some integers.') 30 | parser.add_argument('--config', type=str, required=True) 31 | parser.add_argument('--ckpt', type=str, required=True) 32 | parser.add_argument('--split', type=str, 
required=True) 33 | opts = parser.parse_args() 34 | 35 | config = yaml.load(open(opts.config, "r"), yaml.FullLoader) 36 | 37 | logging.info("Dataset") 38 | DatasetClass = eval("datasets." + config["dataset_name"]) 39 | test_transforms = get_transforms(config, train=False, downstream=True, keep_orignal_data=True) 40 | test_dataset = DatasetClass(config["dataset_root"], 41 | split=opts.split, 42 | transform=test_transforms, 43 | ) 44 | 45 | logging.info("Dataloader") 46 | test_loader = DataLoader( 47 | test_dataset, 48 | batch_size=1, 49 | shuffle=False, 50 | num_workers=config["threads"], 51 | follow_batch=["voxel_coords"] 52 | ) 53 | 54 | num_classes = config["downstream"]["num_classes"] 55 | device = torch.device("cuda") 56 | 57 | logging.info("Network") 58 | if config["network"]["backbone_params"] is None: 59 | config["network"]["backbone_params"] = {} 60 | config["network"]["backbone_params"]["in_channels"] = get_input_channels(config["inputs"]) 61 | config["network"]["backbone_params"]["out_channels"] = config["downstream"]["num_classes"] 62 | 63 | backbone_name = "networks.backbone." 64 | if config["network"]["framework"] is not None: 65 | backbone_name += config["network"]["framework"] 66 | importlib.import_module(backbone_name) 67 | backbone_name += "." + config["network"]["backbone"] 68 | net = eval(backbone_name)(**config["network"]["backbone_params"]) 69 | net.to(device) 70 | net.eval() 71 | 72 | logging.info("Loading the weights from pretrained network") 73 | try: 74 | net.load_state_dict(torch.load(opts.ckpt), strict=True) 75 | except RuntimeError: 76 | ckpt = torch.load(opts.ckpt) 77 | ckpt = {k[4:]: v for k, v in ckpt['state_dict'].items()} 78 | net.load_state_dict(ckpt, strict=True) 79 | 80 | cm = ConfusionMatrix(num_classes, 0) 81 | with torch.no_grad(): 82 | t = tqdm(test_loader, ncols=100) 83 | for data in t: 84 | 85 | data = data.to(device) 86 | 87 | # predictions 88 | predictions = net(data) 89 | predictions = torch.nn.functional.softmax(predictions[:, 1:], dim=1).max(dim=1)[1] 90 | predictions = predictions.cpu().numpy() + 1 91 | 92 | # interpolate to original point cloud 93 | original_pos_np = data["original_pos"].cpu().numpy() 94 | pos_np = data["pos"].cpu().numpy() 95 | tree = KDTree(pos_np) 96 | _, indices = tree.query(original_pos_np, k=1) 97 | predictions = predictions[indices] 98 | 99 | # update the confusion matric 100 | targets_np = data["original_y"].cpu().numpy() 101 | cm.update(targets_np, predictions) 102 | 103 | # compute metrics 104 | iou_per_class = cm.get_per_class_iou() 105 | miou = cm.get_mean_iou() 106 | freqweighted_iou = cm.get_freqweighted_iou() 107 | description = f"Val. 
| mIoU {miou*100:.2f} - fIoU {freqweighted_iou*100:.2f}" 108 | t.set_description_str(wgreen(description)) 109 | 110 | torch.cuda.empty_cache() 111 | 112 | logging.info(f"MIoU: {miou}") 113 | logging.info(f"FIoU: {freqweighted_iou}") 114 | logging.info(f"IoU per class: {iou_per_class}") 115 | -------------------------------------------------------------------------------- /downstream/eval_offset.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import logging 3 | import argparse 4 | import warnings 5 | import importlib 6 | 7 | from tqdm import tqdm 8 | 9 | from scipy.spatial import KDTree 10 | 11 | import torch 12 | 13 | from torch_geometric.data import DataLoader 14 | 15 | from utils.utils import wgreen 16 | from utils.confusion_matrix import ConfusionMatrix 17 | from transforms import get_transforms, get_input_channels 18 | 19 | import datasets 20 | import networks 21 | from networks.backbone import * 22 | 23 | 24 | if __name__ == "__main__": 25 | warnings.filterwarnings("ignore", category=UserWarning) 26 | 27 | logging.getLogger().setLevel("INFO") 28 | 29 | parser = argparse.ArgumentParser(description='Process some integers.') 30 | parser.add_argument('--config', type=str, required=True) 31 | parser.add_argument('--ckpt', type=str, required=True) 32 | parser.add_argument('--split', type=str, required=True) 33 | parser.add_argument('--skip_ratio', type=int, required=True) 34 | 35 | opts = parser.parse_args() 36 | 37 | config = yaml.load(open(opts.config, "r"), yaml.FullLoader) 38 | 39 | logging.info("Dataset") 40 | DatasetClass = eval("datasets." + config["dataset_name"]) 41 | test_transforms = get_transforms(config, train=False, downstream=True, keep_orignal_data=True) 42 | test_dataset = DatasetClass(config["dataset_root"], 43 | split=opts.split, 44 | transform=test_transforms, 45 | skip_ratio=opts.skip_ratio, 46 | complementary=True, 47 | ) 48 | 49 | logging.info("Dataloader") 50 | test_loader = DataLoader( 51 | test_dataset, 52 | batch_size=1, 53 | shuffle=False, 54 | num_workers=config["threads"], 55 | follow_batch=["voxel_coords"] 56 | ) 57 | 58 | num_classes = config["downstream"]["num_classes"] 59 | device = torch.device("cuda") 60 | 61 | logging.info("Network") 62 | if config["network"]["backbone_params"] is None: 63 | config["network"]["backbone_params"] = {} 64 | config["network"]["backbone_params"]["in_channels"] = get_input_channels(config["inputs"]) 65 | config["network"]["backbone_params"]["out_channels"] = config["downstream"]["num_classes"] 66 | 67 | backbone_name = "networks.backbone." 68 | if config["network"]["framework"] is not None: 69 | backbone_name += config["network"]["framework"] 70 | importlib.import_module(backbone_name) 71 | backbone_name += "." 
+ config["network"]["backbone"] 72 | net = eval(backbone_name)(**config["network"]["backbone_params"]) 73 | net.to(device) 74 | net.eval() 75 | 76 | logging.info("Loading the weights from pretrained network") 77 | try: 78 | net.load_state_dict(torch.load(opts.ckpt), strict=True) 79 | except RuntimeError: 80 | ckpt = torch.load(opts.ckpt) 81 | ckpt = {k[4:]: v for k, v in ckpt['state_dict'].items()} 82 | net.load_state_dict(ckpt, strict=True) 83 | 84 | cm = ConfusionMatrix(num_classes, 0) 85 | with torch.no_grad(): 86 | t = tqdm(test_loader, ncols=100) 87 | for data in t: 88 | 89 | data = data.to(device) 90 | 91 | # predictions 92 | predictions = net(data) 93 | predictions = torch.nn.functional.softmax(predictions[:, 1:], dim=1).max(dim=1)[1] 94 | predictions = predictions.cpu().numpy() + 1 95 | 96 | # interpolate to original point cloud 97 | original_pos_np = data["original_pos"].cpu().numpy() 98 | pos_np = data["pos"].cpu().numpy() 99 | tree = KDTree(pos_np) 100 | _, indices = tree.query(original_pos_np, k=1) 101 | predictions = predictions[indices] 102 | 103 | # update the confusion matric 104 | targets_np = data["original_y"].cpu().numpy() 105 | cm.update(targets_np, predictions) 106 | 107 | # compute metrics 108 | iou_per_class = cm.get_per_class_iou() 109 | miou = cm.get_mean_iou() 110 | freqweighted_iou = cm.get_freqweighted_iou() 111 | description = f"Val. | mIoU {miou*100:.2f} - fIoU {freqweighted_iou*100:.2f}" 112 | t.set_description_str(wgreen(description)) 113 | 114 | torch.cuda.empty_cache() 115 | 116 | logging.info(f"MIoU: {miou}") 117 | logging.info(f"FIoU: {freqweighted_iou}") 118 | logging.info(f"IoU per class: {iou_per_class}") 119 | -------------------------------------------------------------------------------- /downstream/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Eaphan/NCLR/b3a944af649b64f0aed82aae0211ebc5f2fe2d13/downstream/networks/__init__.py -------------------------------------------------------------------------------- /downstream/networks/backbone/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Eaphan/NCLR/b3a944af649b64f0aed82aae0211ebc5f2fe2d13/downstream/networks/backbone/__init__.py -------------------------------------------------------------------------------- /downstream/networks/backbone/minkowski_engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .minkunet import MinkUNet34, MinkUNet18 2 | from .utils import Quantize 3 | from .minkunet_segcontrast import SegContrastMinkUNet18 -------------------------------------------------------------------------------- /downstream/networks/backbone/minkowski_engine/minkunet_segcontrast.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import MinkowskiEngine as ME 4 | import logging 5 | from contextlib import nullcontext 6 | 7 | 8 | class BasicConvolutionBlock(nn.Module): 9 | def __init__(self, inc, outc, ks=3, stride=1, dilation=1, D=3): 10 | super().__init__() 11 | self.net = nn.Sequential( 12 | ME.MinkowskiConvolution(inc, 13 | outc, 14 | kernel_size=ks, 15 | dilation=dilation, 16 | stride=stride, 17 | dimension=D), 18 | ME.MinkowskiBatchNorm(outc), 19 | ME.MinkowskiReLU(inplace=True) 20 | ) 21 | 22 | def forward(self, x): 23 | out = self.net(x) 24 | return out 25 | 26 | 27 | class 
BasicDeconvolutionBlock(nn.Module): 28 | def __init__(self, inc, outc, ks=3, stride=1, D=3): 29 | super().__init__() 30 | self.net = nn.Sequential( 31 | ME.MinkowskiConvolutionTranspose(inc, 32 | outc, 33 | kernel_size=ks, 34 | stride=stride, 35 | dimension=D), 36 | ME.MinkowskiBatchNorm(outc), 37 | ME.MinkowskiReLU(inplace=True) 38 | ) 39 | 40 | def forward(self, x): 41 | return self.net(x) 42 | 43 | 44 | class ResidualBlock(nn.Module): 45 | def __init__(self, inc, outc, ks=3, stride=1, dilation=1, D=3): 46 | super().__init__() 47 | self.net = nn.Sequential( 48 | ME.MinkowskiConvolution(inc, 49 | outc, 50 | kernel_size=ks, 51 | dilation=dilation, 52 | stride=stride, 53 | dimension=D), 54 | ME.MinkowskiBatchNorm(outc), 55 | ME.MinkowskiReLU(inplace=True), 56 | ME.MinkowskiConvolution(outc, 57 | outc, 58 | kernel_size=ks, 59 | dilation=dilation, 60 | stride=1, 61 | dimension=D), 62 | ME.MinkowskiBatchNorm(outc) 63 | ) 64 | 65 | self.downsample = nn.Sequential() if (inc == outc and stride == 1) else \ 66 | nn.Sequential( 67 | ME.MinkowskiConvolution(inc, outc, kernel_size=1, dilation=1, stride=stride, dimension=D), 68 | ME.MinkowskiBatchNorm(outc) 69 | ) 70 | 71 | self.relu = ME.MinkowskiReLU(inplace=True) 72 | 73 | def forward(self, x): 74 | out = self.relu(self.net(x) + self.downsample(x)) 75 | return out 76 | 77 | 78 | class MinkUNet(nn.Module): 79 | def __init__(self, **kwargs): 80 | super().__init__() 81 | 82 | cr = kwargs.get('cr', 1.0) 83 | in_channels = kwargs.get('in_channels', 3) 84 | out_channels = kwargs.get('out_channels', 0) 85 | cs = [32, 32, 64, 128, 256, 256, 128, 96, 96] 86 | cs = [int(cr * x) for x in cs] 87 | self.run_up = kwargs.get('run_up', True) 88 | self.D = kwargs.get('D', 3) 89 | self.stem = nn.Sequential( 90 | ME.MinkowskiConvolution(in_channels, cs[0], kernel_size=3, stride=1, dimension=self.D), 91 | ME.MinkowskiBatchNorm(cs[0]), 92 | ME.MinkowskiReLU(True), 93 | ME.MinkowskiConvolution(cs[0], cs[0], kernel_size=3, stride=1, dimension=self.D), 94 | ME.MinkowskiBatchNorm(cs[0]), 95 | ME.MinkowskiReLU(inplace=True) 96 | ) 97 | 98 | self.stage1 = nn.Sequential( 99 | BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1, D=self.D), 100 | ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1, D=self.D), 101 | ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1, D=self.D), 102 | ) 103 | 104 | self.stage2 = nn.Sequential( 105 | BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1, D=self.D), 106 | ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1, D=self.D), 107 | ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1, D=self.D) 108 | ) 109 | 110 | self.stage3 = nn.Sequential( 111 | BasicConvolutionBlock(cs[2], cs[2], ks=2, stride=2, dilation=1, D=self.D), 112 | ResidualBlock(cs[2], cs[3], ks=3, stride=1, dilation=1, D=self.D), 113 | ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1, D=self.D), 114 | ) 115 | 116 | self.stage4 = nn.Sequential( 117 | BasicConvolutionBlock(cs[3], cs[3], ks=2, stride=2, dilation=1, D=self.D), 118 | ResidualBlock(cs[3], cs[4], ks=3, stride=1, dilation=1, D=self.D), 119 | ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1, D=self.D), 120 | ) 121 | 122 | self.up1 = nn.ModuleList([ 123 | BasicDeconvolutionBlock(cs[4], cs[5], ks=2, stride=2, D=self.D), 124 | nn.Sequential( 125 | ResidualBlock(cs[5] + cs[3], cs[5], ks=3, stride=1, 126 | dilation=1, D=self.D), 127 | ResidualBlock(cs[5], cs[5], ks=3, stride=1, dilation=1, D=self.D), 128 | ) 129 | ]) 130 | 131 | self.up2 = nn.ModuleList([ 132 | 
BasicDeconvolutionBlock(cs[5], cs[6], ks=2, stride=2, D=self.D), 133 | nn.Sequential( 134 | ResidualBlock(cs[6] + cs[2], cs[6], ks=3, stride=1, 135 | dilation=1, D=self.D), 136 | ResidualBlock(cs[6], cs[6], ks=3, stride=1, dilation=1, D=self.D), 137 | ) 138 | ]) 139 | 140 | self.up3 = nn.ModuleList([ 141 | BasicDeconvolutionBlock(cs[6], cs[7], ks=2, stride=2, D=self.D), 142 | nn.Sequential( 143 | ResidualBlock(cs[7] + cs[1], cs[7], ks=3, stride=1, 144 | dilation=1, D=self.D), 145 | ResidualBlock(cs[7], cs[7], ks=3, stride=1, dilation=1, D=self.D), 146 | ) 147 | ]) 148 | 149 | self.up4 = nn.ModuleList([ 150 | BasicDeconvolutionBlock(cs[7], cs[8], ks=2, stride=2, D=self.D), 151 | nn.Sequential( 152 | ResidualBlock(cs[8] + cs[0], cs[8], ks=3, stride=1, 153 | dilation=1, D=self.D), 154 | ResidualBlock(cs[8], cs[8], ks=3, stride=1, dilation=1, D=self.D), 155 | ) 156 | ]) 157 | 158 | if out_channels > 0: 159 | 160 | if 'head' in kwargs and kwargs['head'] == "bn_linear": 161 | logging.info("network - bn linear head") 162 | self.final = nn.Sequential(nn.BatchNorm1d(cs[8], affine=False), nn.Linear(cs[8], out_channels)) 163 | else: 164 | logging.info("network - linear head") 165 | self.final = nn.Sequential(nn.Linear(cs[8], out_channels)) 166 | else: 167 | self.final = None 168 | 169 | self.weight_initialization() 170 | self.dropout = nn.Dropout(0.3, True) 171 | self.linear_probing = False 172 | self.context_manager = nullcontext() 173 | 174 | def set_linear_probing(self): 175 | self.linear_probing = True 176 | self.context_manager = torch.no_grad() 177 | 178 | def weight_initialization(self): 179 | for m in self.modules(): 180 | if isinstance(m, nn.BatchNorm1d): 181 | if m.weight is not None: 182 | nn.init.constant_(m.weight, 1) 183 | if m.bias is not None: 184 | nn.init.constant_(m.bias, 0) 185 | 186 | # modified train function to take into account the linear probing 187 | def train(self, mode: bool = True): 188 | if not isinstance(mode, bool): 189 | raise ValueError("training mode is expected to be boolean") 190 | self.training = mode 191 | if self.linear_probing: 192 | for module in self.children(): 193 | module.train(False) 194 | self.final.train(mode) 195 | else: 196 | for module in self.children(): 197 | module.train(mode) 198 | return self 199 | 200 | # modified parameters function to take into account the linear probing 201 | def parameters(self, recurse: bool = True): 202 | if self.linear_probing: 203 | for name, param in self.named_parameters(recurse=recurse): 204 | if "classifier" in name: 205 | yield param 206 | else: 207 | for name, param in self.named_parameters(recurse=recurse): 208 | yield param 209 | 210 | def forward(self, x): 211 | with self.context_manager: 212 | x0 = self.stem(x) 213 | x1 = self.stage1(x0) 214 | x2 = self.stage2(x1) 215 | x3 = self.stage3(x2) 216 | x4 = self.stage4(x3) 217 | 218 | y1 = self.up1[0](x4) 219 | y1 = ME.cat(y1, x3) 220 | y1 = self.up1[1](y1) 221 | 222 | y2 = self.up2[0](y1) 223 | y2 = ME.cat(y2, x2) 224 | y2 = self.up2[1](y2) 225 | 226 | y3 = self.up3[0](y2) 227 | y3 = ME.cat(y3, x1) 228 | y3 = self.up3[1](y3) 229 | 230 | y4 = self.up4[0](y3) 231 | y4 = ME.cat(y4, x0) 232 | y4 = self.up4[1](y4) 233 | 234 | yout = self.final(y4.F) 235 | 236 | return yout 237 | 238 | 239 | class SegContrastMinkUNet18(MinkUNet): 240 | 241 | def __init__(self, in_channels, out_channels, **kwargs): 242 | super().__init__(in_channels=in_channels, out_channels=out_channels, **kwargs) 243 | 244 | def forward(self, data, downstream=False): 245 | 246 | coords = 
torch.cat([data["voxel_coords_batch"].unsqueeze(1), data["voxel_coords"]], dim=1).int() 247 | feats = data["voxel_x"] 248 | input = ME.SparseTensor(feats, coords) 249 | 250 | outputs = super().forward(input) 251 | 252 | vox_num = data["voxel_number"] 253 | increment = torch.cat([vox_num.new_zeros((1,)), vox_num[:-1]], dim=0) 254 | increment = increment.cumsum(0) 255 | increment = increment[data["batch"]] 256 | inv_map = data["voxel_to_pc_id"] + increment 257 | 258 | # interpolate the outputs 259 | outputs = outputs[inv_map] 260 | 261 | return outputs 262 | 263 | def get_last_layer_channels(self): 264 | return self.PLANES[-1] 265 | -------------------------------------------------------------------------------- /downstream/networks/backbone/minkowski_engine/utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # me_found = importlib.util.find_spec("MinkowskiEngine") is not None 4 | # logging.info(f"ME found - {torchsparse_found}") 5 | # if me_found: 6 | # from MinkowskiEngine.utils import sparse_quantize as me_sparse_quantize 7 | # from MinkowskiEngine import SparseTensor as MESparseTensor 8 | 9 | 10 | from MinkowskiEngine.utils import sparse_quantize as me_sparse_quantize 11 | from MinkowskiEngine import SparseTensor as MESparseTensor 12 | import torch 13 | import math 14 | 15 | def cart2polar(input_xyz): 16 | rho = torch.sqrt(input_xyz[:, 0] ** 2 + input_xyz[:, 1] ** 2) 17 | phi = torch.atan2(input_xyz[:, 1], input_xyz[:, 0]) 18 | return torch.stack((rho, phi, input_xyz[:, 2]), dim=1) 19 | 20 | class Quantize(object): 21 | 22 | def __init__(self, voxel_size, **kwargs): 23 | self.voxel_size = voxel_size 24 | self.cylinder_coords = kwargs["cylinder_coords"] if "cylinder_coords" in kwargs else False 25 | 26 | def __call__(self, data): 27 | 28 | if self.cylinder_coords: 29 | 30 | pc_ = cart2polar(data["pos"].clone()) 31 | pc_[:,0] = pc_[:,0]/self.voxel_size # radius 32 | pc_[:,1] = pc_[:,1]/math.pi * 180 # angle (-180, 180) 33 | pc_[:,2] = pc_[:,2]/self.voxel_size # height 34 | pc_ = torch.round(pc_.clone()) 35 | 36 | else: 37 | 38 | pc_ = torch.round(data["pos"].clone() / self.voxel_size) 39 | 40 | pc_ -= pc_.min(0, keepdim=True)[0] 41 | 42 | coords, indices, inverse_map = me_sparse_quantize(pc_, 43 | return_index=True, 44 | return_inverse=True) 45 | 46 | feats = data["x"][indices] 47 | 48 | data["voxel_coords"] = coords 49 | data["voxel_x"] = feats 50 | data["voxel_to_pc_id"] = inverse_map 51 | data["voxel_number"] = int(coords.shape[0]) 52 | 53 | return data 54 | 55 | 56 | # class MEQuantizeCylindrical(object): 57 | 58 | # def __init__(self, voxel_size) -> None: 59 | 60 | 61 | # self.voxel_size = voxel_size 62 | 63 | # def __call__(self, data): 64 | 65 | # pc_ = data["pos"].clone() 66 | # x, y, z = pc_[:,0], pc_[:,1], pc_[:,2] 67 | # rho = torch.sqrt(x ** 2 + y ** 2) / self.voxel_size 68 | # # corresponds to a split each 1° 69 | # phi = torch.atan2(y, x) * 180 / np.pi 70 | # z = z / self.voxel_size 71 | # pc_[:,0] = rho 72 | # pc_[:,1] = phi 73 | # pc_[:,2] = z 74 | 75 | # data["vox_pos"] = pc_ 76 | 77 | # return data -------------------------------------------------------------------------------- /downstream/networks/backbone/spconv/__init__.py: -------------------------------------------------------------------------------- 1 | from .pcdet_models import SECOND 2 | 3 | from .utils import Quantize 4 | -------------------------------------------------------------------------------- /downstream/networks/backbone/spconv/pcdet_models.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from pcdet.models import build_network 3 | from pcdet.config import cfg, cfg_from_yaml_file 4 | import numpy as np 5 | 6 | class custom_point_feature_encoder: 7 | def __init__(self, num): 8 | self.num_point_features = num 9 | 10 | class custom_dataset: 11 | 12 | def __init__(self, cfg, class_names) -> None: 13 | 14 | self.dataset_cfg = cfg 15 | 16 | self.point_cloud_range = np.array(self.dataset_cfg.POINT_CLOUD_RANGE, dtype=np.float32) 17 | self.point_feature_encoder = custom_point_feature_encoder(len(self.dataset_cfg.POINT_FEATURE_ENCODING["used_feature_list"])) 18 | 19 | self.class_names = class_names 20 | 21 | for cur in self.dataset_cfg.DATA_PROCESSOR: 22 | print(cur) 23 | if "VOXEL_SIZE" in cur: 24 | self.voxel_size = cur["VOXEL_SIZE"] 25 | 26 | grid_size = (self.point_cloud_range[3:6] - self.point_cloud_range[0:3]) / np.array(self.voxel_size) 27 | self.grid_size = np.round(grid_size).astype(np.int64) 28 | 29 | 30 | self.depth_downsample_factor = None 31 | 32 | class SECOND(torch.nn.Module): 33 | 34 | def __init__(self, in_channels, out_channels, 35 | **kwargs): 36 | 37 | super().__init__() 38 | 39 | config_path= kwargs["config"] 40 | 41 | cfg_from_yaml_file(config_path, cfg) 42 | 43 | self.train_set = custom_dataset(cfg.DATA_CONFIG, cfg.CLASS_NAMES) 44 | self.model = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES), dataset=self.train_set) 45 | 46 | self.classifier = torch.nn.Conv2d(512, out_channels, kernel_size=1) 47 | 48 | 49 | def forward(self, data): 50 | 51 | coords = torch.cat([data["voxel_coords_batch"].unsqueeze(1), data["voxel_coords"]], dim=1).int() 52 | features = data["voxel_x"] 53 | batch_size = data["voxel_coords_batch"][-1] + 1 54 | 55 | outputs = self.model.backbone_3d.forward( 56 | {"batch_size":batch_size, "voxel_features":features, "voxel_coords":coords}) 57 | 58 | outputs = self.model.map_to_bev_module(outputs) 59 | 60 | outputs = self.model.backbone_2d(outputs) 61 | 62 | outputs = outputs["spatial_features_2d"] 63 | outputs = self.classifier(outputs) 64 | 65 | y = torch.arange(outputs.shape[2]) 66 | x = torch.arange(outputs.shape[3]) 67 | 68 | grid_y, grid_x = torch.meshgrid(y, x, indexing='ij') 69 | grid_x = grid_x.float().to(outputs.device)/grid_x.max()*(self.train_set.point_cloud_range[3] - self.train_set.point_cloud_range[0]) + self.train_set.point_cloud_range[0] 70 | grid_y = grid_y.float().to(outputs.device)/grid_y.max()*(self.train_set.point_cloud_range[4] - self.train_set.point_cloud_range[1]) + self.train_set.point_cloud_range[1] 71 | grid_z = torch.zeros_like(grid_x, dtype=torch.float, device=outputs.device) 72 | 73 | points = torch.stack([grid_x, grid_y, grid_z], dim=-1).reshape(-1, 3) 74 | points = points.repeat(outputs.shape[0], 1) 75 | 76 | points_batch = torch.arange(batch_size, device=outputs.device, dtype=torch.long).reshape(-1,1,1).expand((outputs.shape[0], outputs.shape[2], outputs.shape[3])) 77 | points_batch = points_batch.reshape(-1) 78 | 79 | points_outputs = outputs.permute(0,2,3,1).reshape(-1, outputs.shape[1]) 80 | 81 | return { 82 | 'latents': points_outputs, 83 | 'latents_pos': points, 84 | 'latents_batch': points_batch 85 | } -------------------------------------------------------------------------------- /downstream/networks/backbone/spconv/utils.py: -------------------------------------------------------------------------------- 1 | 2 | from spconv.pytorch.utils import PointToVoxel 3 | # from spconv.utils import 
Point2VoxelCPU3d as PointToVoxel 4 | from torch_geometric.nn import global_mean_pool 5 | import torch 6 | import re 7 | import numpy as np 8 | 9 | ##################################################################### 10 | # transformation between Cartesian coordinates and polar coordinates 11 | # code from https://github.com/xinge008/Cylinder3D/ 12 | # please refer to the repo 13 | 14 | def cart2polar(input_xyz): 15 | rho = torch.sqrt(input_xyz[:, 0] ** 2 + input_xyz[:, 1] ** 2) 16 | phi = torch.atan2(input_xyz[:, 1], input_xyz[:, 0]) 17 | return torch.stack((rho, phi, input_xyz[:, 2]), dim=1) 18 | 19 | 20 | def polar2cart(input_xyz_polar): 21 | # print(input_xyz_polar.shape) 22 | x = input_xyz_polar[0] * torch.cos(input_xyz_polar[1]) 23 | y = input_xyz_polar[0] * torch.sin(input_xyz_polar[1]) 24 | return torch.stack((x, y, input_xyz_polar[2]), dim=0) 25 | ########################################################## 26 | 27 | 28 | class SpatialExtentCrop(object): 29 | 30 | 31 | def __init__(self, spatial_extent, item_list=None, croping_ref_field="pos"): 32 | self.spatial_extent = spatial_extent 33 | self.item_list = item_list 34 | self.croping_ref_field = croping_ref_field 35 | 36 | def __call__(self, data): 37 | 38 | if self.item_list is None: 39 | num_nodes = data.num_nodes 40 | else: 41 | num_nodes = data[self.item_list[0]].shape[0] 42 | 43 | ref_field = data[self.croping_ref_field] 44 | 45 | mask_x = torch.logical_and(ref_field[:,0]>self.spatial_extent[0], ref_field[:,0]<self.spatial_extent[3]) 46 | mask_y = torch.logical_and(ref_field[:,1]>self.spatial_extent[1], ref_field[:,1]<self.spatial_extent[4]) 47 | mask_z = torch.logical_and(ref_field[:,2]>self.spatial_extent[2], ref_field[:,2]<self.spatial_extent[5]) 73 | class Quantize(object): 74 | def __init__(self, voxel_size, spatial_extent, **kwargs) -> None: 75 | 76 | print("Quantize - init") 77 | print(voxel_size) 78 | print(spatial_extent) 79 | 80 | self.voxel_size = voxel_size 81 | self.spatial_extent = spatial_extent 82 | self.cylinder_coords = kwargs["cylinder_coords"] if "cylinder_coords" in kwargs else False 83 | self.voxel_num = kwargs["voxel_num"] if "voxel_num" in kwargs else None 84 | 85 | self.num_features = kwargs["num_features"] if "num_features" in kwargs else 4 86 | 87 | if self.voxel_num is not None: 88 | self.voxel_sizes = [ 89 | (spatial_extent[3]-spatial_extent[0])/self.voxel_num[0], 90 | (spatial_extent[4]-spatial_extent[1])/self.voxel_num[1], 91 | (spatial_extent[5]-spatial_extent[2])/self.voxel_num[2], 92 | ] 93 | self.gen = PointToVoxel( 94 | vsize_xyz=self.voxel_sizes, 95 | coors_range_xyz=[ 96 | self.spatial_extent[0], 97 | self.spatial_extent[1], 98 | self.spatial_extent[2], 99 | self.spatial_extent[3]+1, 100 | self.spatial_extent[4]+1, 101 | self.spatial_extent[5]+1, 102 | ], 103 | num_point_features=self.num_features, 104 | max_num_voxels=100000, 105 | max_num_points_per_voxel=5) 106 | 107 | else: 108 | if isinstance(self.voxel_size, list): 109 | self.gen = PointToVoxel( 110 | vsize_xyz=self.voxel_size, 111 | coors_range_xyz=self.spatial_extent, 112 | num_point_features=4, 113 | max_num_voxels=100000, 114 | max_num_points_per_voxel=5) 115 | else: 116 | self.gen = PointToVoxel( 117 | vsize_xyz=[self.voxel_size, self.voxel_size, self.voxel_size], 118 | coors_range_xyz=self.spatial_extent, 119 | num_point_features=4, 120 | max_num_voxels=100000, 121 | max_num_points_per_voxel=5) 122 | 123 | if self.cylinder_coords: 124 | self.prior_crop = SpatialExtentCrop(spatial_extent, croping_ref_field="pos_pol") 125 | else: 126 | self.prior_crop = SpatialExtentCrop(spatial_extent, croping_ref_field="pos") 127 | 128 | def __call__(self, data): 129 | 130 | data = self.prior_crop(data) 131 | 132 | voxels, coords, num_points_per_voxel, pc_voxel_id = 
self.gen.generate_voxel_with_id(data["x"], empty_mean=True) 133 | 134 | x_pool = voxels[:, :, :].sum(dim=1, keepdim=False) 135 | normalizer = torch.clamp_min(num_points_per_voxel.view(-1, 1), min=1.0).type_as(x_pool) 136 | x_pool = x_pool / normalizer 137 | x_pool = x_pool.contiguous() 138 | 139 | data["voxel_coords"] = coords 140 | data["voxel_x"] = x_pool # features 141 | data["voxel_to_pc_id"] = pc_voxel_id # if index in the key, will be incremented automatically by pytorch geometric 142 | data["voxel_number"] = coords.shape[0] # number of voxels 143 | 144 | return data -------------------------------------------------------------------------------- /downstream/networks/backbone/torchsparse/__init__.py: -------------------------------------------------------------------------------- 1 | from .minkunet import MinkUNet34, MinkUNet18, MinkUNet18SC 2 | from .utils import Quantize 3 | from .spvcnn import SPVCNN, SPVCNN0p5 -------------------------------------------------------------------------------- /downstream/networks/backbone/torchsparse/spvcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchsparse 3 | import torchsparse.nn as spnn 4 | import torchsparse.nn.functional as F 5 | from torch import nn 6 | from torchsparse import PointTensor, SparseTensor 7 | from torchsparse.nn.utils import get_kernel_offsets 8 | 9 | __all__ = ['SPVCNN', 'SPVCNN0p5'] 10 | 11 | # z: PointTensor 12 | # return: SparseTensor 13 | def initial_voxelize(z, init_res, after_res): 14 | new_float_coord = torch.cat( 15 | [(z.C[:, :3] * init_res) / after_res, z.C[:, -1].view(-1, 1)], 1) 16 | 17 | pc_hash = F.sphash(torch.floor(new_float_coord).int()) 18 | sparse_hash = torch.unique(pc_hash) 19 | idx_query = F.sphashquery(pc_hash, sparse_hash) 20 | counts = F.spcount(idx_query.int(), len(sparse_hash)) 21 | 22 | inserted_coords = F.spvoxelize(torch.floor(new_float_coord), idx_query, 23 | counts) 24 | inserted_coords = torch.round(inserted_coords).int() 25 | inserted_feat = F.spvoxelize(z.F, idx_query, counts) 26 | 27 | new_tensor = SparseTensor(inserted_feat, inserted_coords, 1) 28 | new_tensor.cmaps.setdefault(new_tensor.stride, new_tensor.coords) 29 | z.additional_features['idx_query'][1] = idx_query 30 | z.additional_features['counts'][1] = counts 31 | z.C = new_float_coord 32 | 33 | return new_tensor 34 | 35 | 36 | # x: SparseTensor, z: PointTensor 37 | # return: SparseTensor 38 | def point_to_voxel(x, z): 39 | if z.additional_features is None or z.additional_features.get( 40 | 'idx_query') is None or z.additional_features['idx_query'].get( 41 | x.s) is None: 42 | pc_hash = F.sphash( 43 | torch.cat([ 44 | torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0], 45 | z.C[:, -1].int().view(-1, 1) 46 | ], 1)) 47 | sparse_hash = F.sphash(x.C) 48 | idx_query = F.sphashquery(pc_hash, sparse_hash) 49 | counts = F.spcount(idx_query.int(), x.C.shape[0]) 50 | z.additional_features['idx_query'][x.s] = idx_query 51 | z.additional_features['counts'][x.s] = counts 52 | else: 53 | idx_query = z.additional_features['idx_query'][x.s] 54 | counts = z.additional_features['counts'][x.s] 55 | 56 | inserted_feat = F.spvoxelize(z.F, idx_query, counts) 57 | new_tensor = SparseTensor(inserted_feat, x.C, x.s) 58 | new_tensor.cmaps = x.cmaps 59 | new_tensor.kmaps = x.kmaps 60 | 61 | return new_tensor 62 | 63 | 64 | # x: SparseTensor, z: PointTensor 65 | # return: PointTensor 66 | def voxel_to_point(x, z, nearest=False): 67 | if z.idx_query is None or z.weights is None or 
z.idx_query.get( 68 | x.s) is None or z.weights.get(x.s) is None: 69 | off = get_kernel_offsets(2, x.s, 1, device=z.F.device) 70 | old_hash = F.sphash( 71 | torch.cat([ 72 | torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0], 73 | z.C[:, -1].int().view(-1, 1) 74 | ], 1), off) 75 | pc_hash = F.sphash(x.C.to(z.F.device)) 76 | idx_query = F.sphashquery(old_hash, pc_hash) 77 | weights = F.calc_ti_weights(z.C, idx_query, 78 | scale=x.s[0]).transpose(0, 1).contiguous() 79 | idx_query = idx_query.transpose(0, 1).contiguous() 80 | if nearest: 81 | weights[:, 1:] = 0. 82 | idx_query[:, 1:] = -1 83 | new_feat = F.spdevoxelize(x.F, idx_query, weights) 84 | new_tensor = PointTensor(new_feat, 85 | z.C, 86 | idx_query=z.idx_query, 87 | weights=z.weights) 88 | new_tensor.additional_features = z.additional_features 89 | new_tensor.idx_query[x.s] = idx_query 90 | new_tensor.weights[x.s] = weights 91 | z.idx_query[x.s] = idx_query 92 | z.weights[x.s] = weights 93 | 94 | else: 95 | new_feat = F.spdevoxelize(x.F, z.idx_query.get(x.s), z.weights.get(x.s)) 96 | new_tensor = PointTensor(new_feat, 97 | z.C, 98 | idx_query=z.idx_query, 99 | weights=z.weights) 100 | new_tensor.additional_features = z.additional_features 101 | 102 | return new_tensor 103 | 104 | 105 | 106 | 107 | class BasicConvolutionBlock(nn.Module): 108 | 109 | def __init__(self, inc, outc, ks=3, stride=1, dilation=1): 110 | super().__init__() 111 | self.net = nn.Sequential( 112 | spnn.Conv3d(inc, 113 | outc, 114 | kernel_size=ks, 115 | dilation=dilation, 116 | stride=stride), 117 | spnn.BatchNorm(outc), 118 | spnn.ReLU(True), 119 | ) 120 | 121 | def forward(self, x): 122 | out = self.net(x) 123 | return out 124 | 125 | 126 | class BasicDeconvolutionBlock(nn.Module): 127 | 128 | def __init__(self, inc, outc, ks=3, stride=1): 129 | super().__init__() 130 | self.net = nn.Sequential( 131 | spnn.Conv3d(inc, 132 | outc, 133 | kernel_size=ks, 134 | stride=stride, 135 | transposed=True), 136 | spnn.BatchNorm(outc), 137 | spnn.ReLU(True), 138 | ) 139 | 140 | def forward(self, x): 141 | return self.net(x) 142 | 143 | 144 | class ResidualBlock(nn.Module): 145 | 146 | def __init__(self, inc, outc, ks=3, stride=1, dilation=1): 147 | super().__init__() 148 | self.net = nn.Sequential( 149 | spnn.Conv3d(inc, 150 | outc, 151 | kernel_size=ks, 152 | dilation=dilation, 153 | stride=stride), 154 | spnn.BatchNorm(outc), 155 | spnn.ReLU(True), 156 | spnn.Conv3d(outc, outc, kernel_size=ks, dilation=dilation, 157 | stride=1), 158 | spnn.BatchNorm(outc), 159 | ) 160 | 161 | if inc == outc and stride == 1: 162 | self.downsample = nn.Identity() 163 | else: 164 | self.downsample = nn.Sequential( 165 | spnn.Conv3d(inc, outc, kernel_size=1, dilation=1, 166 | stride=stride), 167 | spnn.BatchNorm(outc), 168 | ) 169 | 170 | self.relu = spnn.ReLU(True) 171 | 172 | def forward(self, x): 173 | out = self.relu(self.net(x) + self.downsample(x)) 174 | return out 175 | 176 | 177 | class SPVCNN(nn.Module): 178 | 179 | def __init__(self, 180 | in_channels, out_channels, 181 | **kwargs): 182 | super().__init__() 183 | 184 | cr = kwargs.get('cr', 1.0) 185 | cs = [32, 32, 64, 128, 256, 256, 128, 96, 96] 186 | cs = [int(cr * x) for x in cs] 187 | 188 | # if 'pres' in kwargs and 'vres' in kwargs: 189 | # self.pres = kwargs['pres'] 190 | # self.vres = kwargs['vres'] 191 | 192 | self.pres = kwargs["quantization_params"]["voxel_size"] 193 | self.vres = kwargs["quantization_params"]["voxel_size"] 194 | 195 | 196 | self.stem = nn.Sequential( 197 | spnn.Conv3d(in_channels, cs[0], kernel_size=3, 
stride=1), 198 | spnn.BatchNorm(cs[0]), spnn.ReLU(True), 199 | spnn.Conv3d(cs[0], cs[0], kernel_size=3, stride=1), 200 | spnn.BatchNorm(cs[0]), spnn.ReLU(True)) 201 | 202 | self.stage1 = nn.Sequential( 203 | BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1), 204 | ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1), 205 | ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1), 206 | ) 207 | 208 | self.stage2 = nn.Sequential( 209 | BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1), 210 | ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1), 211 | ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1), 212 | ) 213 | 214 | self.stage3 = nn.Sequential( 215 | BasicConvolutionBlock(cs[2], cs[2], ks=2, stride=2, dilation=1), 216 | ResidualBlock(cs[2], cs[3], ks=3, stride=1, dilation=1), 217 | ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1), 218 | ) 219 | 220 | self.stage4 = nn.Sequential( 221 | BasicConvolutionBlock(cs[3], cs[3], ks=2, stride=2, dilation=1), 222 | ResidualBlock(cs[3], cs[4], ks=3, stride=1, dilation=1), 223 | ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1), 224 | ) 225 | 226 | self.up1 = nn.ModuleList([ 227 | BasicDeconvolutionBlock(cs[4], cs[5], ks=2, stride=2), 228 | nn.Sequential( 229 | ResidualBlock(cs[5] + cs[3], cs[5], ks=3, stride=1, dilation=1), 230 | ResidualBlock(cs[5], cs[5], ks=3, stride=1, dilation=1), 231 | ) 232 | ]) 233 | 234 | self.up2 = nn.ModuleList([ 235 | BasicDeconvolutionBlock(cs[5], cs[6], ks=2, stride=2), 236 | nn.Sequential( 237 | ResidualBlock(cs[6] + cs[2], cs[6], ks=3, stride=1, dilation=1), 238 | ResidualBlock(cs[6], cs[6], ks=3, stride=1, dilation=1), 239 | ) 240 | ]) 241 | 242 | self.up3 = nn.ModuleList([ 243 | BasicDeconvolutionBlock(cs[6], cs[7], ks=2, stride=2), 244 | nn.Sequential( 245 | ResidualBlock(cs[7] + cs[1], cs[7], ks=3, stride=1, dilation=1), 246 | ResidualBlock(cs[7], cs[7], ks=3, stride=1, dilation=1), 247 | ) 248 | ]) 249 | 250 | self.up4 = nn.ModuleList([ 251 | BasicDeconvolutionBlock(cs[7], cs[8], ks=2, stride=2), 252 | nn.Sequential( 253 | ResidualBlock(cs[8] + cs[0], cs[8], ks=3, stride=1, dilation=1), 254 | ResidualBlock(cs[8], cs[8], ks=3, stride=1, dilation=1), 255 | ) 256 | ]) 257 | 258 | if out_channels > 0: 259 | self.classifier = nn.Sequential(nn.Linear(cs[8], out_channels)) 260 | else: 261 | self.classifier = nn.Identity() 262 | 263 | self.point_transforms = nn.ModuleList([ 264 | nn.Sequential( 265 | nn.Linear(cs[0], cs[4]), 266 | nn.BatchNorm1d(cs[4]), 267 | nn.ReLU(True), 268 | ), 269 | nn.Sequential( 270 | nn.Linear(cs[4], cs[6]), 271 | nn.BatchNorm1d(cs[6]), 272 | nn.ReLU(True), 273 | ), 274 | nn.Sequential( 275 | nn.Linear(cs[6], cs[8]), 276 | nn.BatchNorm1d(cs[8]), 277 | nn.ReLU(True), 278 | ) 279 | ]) 280 | 281 | self.weight_initialization() 282 | self.dropout = nn.Dropout(0.3, True) 283 | 284 | def weight_initialization(self): 285 | for m in self.modules(): 286 | if isinstance(m, nn.BatchNorm1d): 287 | nn.init.constant_(m.weight, 1) 288 | nn.init.constant_(m.bias, 0) 289 | 290 | # def forward(self, x): 291 | 292 | def forward(self, data): 293 | 294 | coords = torch.cat([data["voxel_coords"], data["voxel_coords_batch"].unsqueeze(1)], dim=1).int() 295 | feats = data["voxel_x"] 296 | x = torchsparse.SparseTensor(coords=coords, feats=feats) 297 | 298 | ######## SPVCNN forward 299 | 300 | # x: SparseTensor z: PointTensor 301 | z = PointTensor(x.F, x.C.float()) 302 | 303 | x0 = initial_voxelize(z, self.pres, self.vres) 304 | 305 | x0 = self.stem(x0) 306 | z0 = 
voxel_to_point(x0, z, nearest=False) 307 | z0.F = z0.F 308 | 309 | x1 = point_to_voxel(x0, z0) 310 | x1 = self.stage1(x1) 311 | x2 = self.stage2(x1) 312 | x3 = self.stage3(x2) 313 | x4 = self.stage4(x3) 314 | z1 = voxel_to_point(x4, z0) 315 | z1.F = z1.F + self.point_transforms[0](z0.F) 316 | 317 | y1 = point_to_voxel(x4, z1) 318 | y1.F = self.dropout(y1.F) 319 | y1 = self.up1[0](y1) 320 | y1 = torchsparse.cat([y1, x3]) 321 | y1 = self.up1[1](y1) 322 | 323 | y2 = self.up2[0](y1) 324 | y2 = torchsparse.cat([y2, x2]) 325 | y2 = self.up2[1](y2) 326 | z2 = voxel_to_point(y2, z1) 327 | z2.F = z2.F + self.point_transforms[1](z1.F) 328 | 329 | y3 = point_to_voxel(y2, z2) 330 | y3.F = self.dropout(y3.F) 331 | y3 = self.up3[0](y3) 332 | y3 = torchsparse.cat([y3, x1]) 333 | y3 = self.up3[1](y3) 334 | 335 | y4 = self.up4[0](y3) 336 | y4 = torchsparse.cat([y4, x0]) 337 | y4 = self.up4[1](y4) 338 | z3 = voxel_to_point(y4, z2) 339 | z3.F = z3.F + self.point_transforms[2](z2.F) 340 | 341 | outputs = self.classifier(z3.F) 342 | 343 | ###### 344 | 345 | vox_num = data["voxel_number"] 346 | increment = torch.cat([vox_num.new_zeros((1,)), vox_num[:-1]], dim=0) 347 | increment = increment.cumsum(0) 348 | increment = increment[data["batch"]] 349 | inv_map = data["voxel_to_pc_id"] + increment 350 | 351 | # interpolate the outputs 352 | outputs = outputs[inv_map] 353 | 354 | return outputs 355 | 356 | 357 | class SPVCNN0p5(SPVCNN): 358 | def __init__(self, 359 | in_channels, out_channels, 360 | **kwargs): 361 | super().__init__(in_channels=in_channels, out_channels=out_channels, cr=0.5, **kwargs) 362 | -------------------------------------------------------------------------------- /downstream/networks/backbone/torchsparse/utils.py: -------------------------------------------------------------------------------- 1 | from torchsparse.utils.quantize import sparse_quantize 2 | import torch 3 | import numpy as np 4 | 5 | class Quantize(object): 6 | 7 | def __init__(self, voxel_size, **kwargs) -> None: 8 | self.voxel_size = voxel_size 9 | 10 | def __call__(self, data): 11 | 12 | pc_ = np.round(data["pos"].numpy() / self.voxel_size).astype(np.int32) 13 | 14 | pc_ -= pc_.min(0, keepdims=1) 15 | 16 | coords, indices, inverse_map = sparse_quantize(pc_, 17 | return_index=True, 18 | return_inverse=True) 19 | 20 | coords = torch.tensor(coords, dtype=torch.int) 21 | 22 | indices = torch.tensor(indices) 23 | feats = data["x"][indices] 24 | 25 | inverse_map = torch.tensor(inverse_map, dtype=torch.long) 26 | 27 | data["voxel_coords"] = coords 28 | data["voxel_x"] = feats 29 | data["voxel_to_pc_id"] = inverse_map 30 | data["voxel_number"] = coords.shape[0] 31 | 32 | return data -------------------------------------------------------------------------------- /downstream/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .get_transforms import get_transforms, get_input_channels -------------------------------------------------------------------------------- /downstream/transforms/create_inputs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | 4 | 5 | class CreateInputs(object): 6 | 7 | def __init__(self, item_list): 8 | 9 | if isinstance(item_list, list): 10 | self.item_list = item_list 11 | elif isinstance(item_list, str): 12 | if item_list[0] == "[": 13 | item_list = item_list[1:] 14 | if item_list[-1] == "]": 15 | item_list = item_list[:-1] 16 | item_list = item_list.split(",") 17 | 
self.item_list = item_list 18 | logging.info(f"CreateInputs -- {item_list}") 19 | 20 | def __call__(self, data): 21 | 22 | features = [] 23 | for key in self.item_list: 24 | features.append(data[key]) 25 | 26 | data["x"] = torch.cat(features, dim=1) 27 | return data 28 | -------------------------------------------------------------------------------- /downstream/transforms/create_points.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | import logging 5 | import re 6 | 7 | 8 | class CreatePoints(object): 9 | 10 | def __init__(self, npts=None, exact_number_of_points=None, pts_item_list=None, n_non_manifold_pts=None, non_manifold_dist=0.1): 11 | 12 | logging.info(f"Transforms - CreatePoints - npts {npts} - exact_number_of_points {exact_number_of_points} - non_manifold {n_non_manifold_pts} - non_manifold_dist {non_manifold_dist}") 13 | self.npts = npts 14 | self.exact_number_of_points = exact_number_of_points 15 | self.n_non_manifold_pts = n_non_manifold_pts 16 | self.pts_item_list = pts_item_list 17 | self.non_manifold_dist = non_manifold_dist 18 | 19 | def __call__(self, data): 20 | 21 | num_nodes = data.num_nodes if (self.pts_item_list is None) else data[self.pts_item_list[0]].shape[0] 22 | 23 | # select the points 24 | choice = None 25 | if self.npts < num_nodes: 26 | choice = torch.randperm(num_nodes)[:self.npts] 27 | elif self.npts > num_nodes and self.exact_number_of_points: 28 | choice = np.random.choice(num_nodes, self.npts, replace=True) 29 | choice = torch.from_numpy(choice).to(torch.long) 30 | 31 | # non manifold points 32 | if self.n_non_manifold_pts is not None: 33 | 34 | # nmp -> non_manifold points 35 | if "pos2" in data.keys: 36 | 37 | n_nmp2 = self.n_non_manifold_pts // 2 38 | n_nmp = self.n_non_manifold_pts - n_nmp2 39 | 40 | 41 | n_nmp2_out = n_nmp2 // 3 42 | n_nmp2_out_far = n_nmp2 // 3 43 | n_nmp2_in = n_nmp2 - 2 * (n_nmp2//3) 44 | nmp2_choice_in = torch.randperm(data["pos2"].shape[0])[:n_nmp2_in] 45 | nmp2_choice_out = torch.randperm(data["pos2"].shape[0])[:n_nmp2_out] 46 | nmp2_choice_out_far = torch.randperm(data["pos2"].shape[0])[:n_nmp2_out_far] 47 | 48 | else: 49 | n_nmp2 = 0 50 | n_nmp = self.n_non_manifold_pts 51 | 52 | # select the points for the current frame 53 | n_nmp_out = n_nmp // 3 54 | n_nmp_out_far = n_nmp // 3 55 | n_nmp_in = n_nmp - 2 * (n_nmp//3) 56 | nmp_choice_in = torch.randperm(data["pos"].shape[0])[:n_nmp_in] 57 | nmp_choice_out = torch.randperm(data["pos"].shape[0])[:n_nmp_out] 58 | nmp_choice_out_far = torch.randperm(data["pos"].shape[0])[:n_nmp_out_far] 59 | 60 | # center 61 | center = torch.zeros((1,3), dtype=torch.float) 62 | 63 | # in points 64 | pos = data["pos"][nmp_choice_in] 65 | dirs = F.normalize(pos, dim=1) 66 | pos_in = pos + self.non_manifold_dist * dirs * torch.rand((pos.shape[0],1)) 67 | occ_in = torch.ones(pos_in.shape[0], dtype=torch.long) 68 | 69 | # out points 70 | pos = data["pos"][nmp_choice_out] 71 | dirs = F.normalize(pos, dim=1) 72 | pos_out = pos - self.non_manifold_dist * dirs * torch.rand((pos.shape[0],1)) 73 | occ_out = torch.zeros(pos_out.shape[0], dtype=torch.long) 74 | 75 | # out far points 76 | pos = data["pos"][nmp_choice_out_far] 77 | dirs = F.normalize(pos, dim=1) 78 | pos_out_far = (pos - center) * torch.rand((pos.shape[0],1)) + center 79 | occ_out_far = torch.zeros(pos_out_far.shape[0], dtype=torch.long) 80 | 81 | 82 | pos_non_manifold = torch.cat([pos_in, pos_out, pos_out_far], dim=0) 83 | 
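# occupancy targets for the sampled (non-manifold) points: 1 for samples pushed past the measured point along its viewing ray (occupied space), 0 for samples taken in front of the point or towards the sensor origin (free space)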
occupancies = torch.cat([occ_in, occ_out, occ_out_far], dim=0) 84 | intensities = None 85 | rgb = None 86 | 87 | if "intensities" in data: 88 | intensities_in = data["intensities"][nmp_choice_in] 89 | intensities_out = data["intensities"][nmp_choice_out] 90 | intensities_out_far = torch.full((pos_out_far.shape[0],1), fill_value=-1) 91 | intensities = torch.cat([intensities_in, intensities_out, intensities_out_far], dim=0) 92 | 93 | if "rgb" in data: 94 | rgb_in = data["rgb"][nmp_choice_in] 95 | rgb_out = data["rgb"][nmp_choice_out] 96 | rgb_out_far = torch.full((pos_out_far.shape[0],3), fill_value=-1) 97 | rgb = torch.cat([rgb_in, rgb_out, rgb_out_far], dim=0) 98 | 99 | 100 | if n_nmp2 > 0: 101 | # multiframe setting 102 | 103 | # in points 104 | pos = data["pos2"][nmp2_choice_in] 105 | dirs = F.normalize(pos - data["sensors2"][nmp2_choice_in], dim=1) 106 | pos_in = pos + self.non_manifold_dist * dirs * torch.rand((pos.shape[0],1)) 107 | occ_in = torch.ones(pos_in.shape[0], dtype=torch.long) 108 | 109 | # out points 110 | pos = data["pos2"][nmp2_choice_out] 111 | dirs = F.normalize(pos - data["sensors2"][nmp2_choice_out], dim=1) 112 | pos_out = pos - self.non_manifold_dist * dirs * torch.rand((pos.shape[0],1)) 113 | occ_out = torch.zeros(pos_out.shape[0], dtype=torch.long) 114 | 115 | # out far points 116 | pos = data["pos2"][nmp2_choice_out_far] 117 | dirs = F.normalize(pos - data["sensors2"][nmp2_choice_out_far], dim=1) 118 | pos_out_far = (pos - center) * torch.rand((pos.shape[0],1)) + center 119 | occ_out_far = torch.zeros(pos_out_far.shape[0], dtype=torch.long) 120 | 121 | 122 | pos_non_manifold2 = torch.cat([pos_in, pos_out, pos_out_far], dim=0) 123 | occupancies2 = torch.cat([occ_in, occ_out, occ_out_far], dim=0) 124 | intensities2 = None 125 | 126 | pos_non_manifold = torch.cat([pos_non_manifold, pos_non_manifold2], dim=0) 127 | occupancies = torch.cat([occupancies, occupancies2], dim=0) 128 | 129 | if "intensities2" in data: 130 | intensities_in = data["intensities2"][nmp2_choice_in] 131 | intensities_out = data["intensities2"][nmp2_choice_out] 132 | intensities_out_far = torch.full((pos_out_far.shape[0],1), fill_value=-1) 133 | intensities2 = torch.cat([intensities_in, intensities_out, intensities_out_far], dim=0) 134 | intensities = torch.cat([intensities, intensities2], dim=0) 135 | 136 | data["pos_non_manifold"] = pos_non_manifold 137 | data["occupancies"] = occupancies 138 | if intensities is not None: 139 | data["intensities_non_manifold"] = intensities 140 | 141 | if rgb is not None: 142 | data["rgb_non_manifold"] = rgb 143 | 144 | 145 | 146 | 147 | # replace in data 148 | if choice is not None: 149 | 150 | # selecting elements 151 | if self.pts_item_list is None: 152 | for key, item in data: 153 | if bool(re.search('edge', key)): 154 | continue 155 | if (torch.is_tensor(item) and item.size(0) == num_nodes 156 | and item.size(0) != 1): 157 | data[key] = item[choice] 158 | else: 159 | for key, item in data: 160 | if key in self.pts_item_list: 161 | if bool(re.search('edge', key)): 162 | continue 163 | if (torch.is_tensor(item) and item.size(0) != 1): 164 | data[key] = item[choice] 165 | 166 | return data -------------------------------------------------------------------------------- /downstream/transforms/duplicate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | class Duplicate(object): 5 | 6 | def __init__(self, item_list, prefix) -> None: 7 | 8 | logging.info(f"Transforms - Duplicate {item_list} - 
{prefix}") 9 | self.item_list = item_list 10 | self.prefix = prefix 11 | 12 | def __call__(self, data): 13 | 14 | for item in self.item_list: 15 | data[self.prefix + item] = data[item].clone() 16 | 17 | return data 18 | -------------------------------------------------------------------------------- /downstream/transforms/get_transforms.py: -------------------------------------------------------------------------------- 1 | import torch_geometric.transforms as T 2 | import logging 3 | 4 | import importlib 5 | 6 | from transforms.create_inputs import CreateInputs 7 | from transforms.create_points import CreatePoints 8 | from transforms.duplicate import Duplicate 9 | from transforms.random_rotate import RandomRotate 10 | from transforms.random_flip import RandomFlip 11 | from transforms.scaling import Scaling 12 | from transforms.voxel_decimation import VoxelDecimation 13 | 14 | 15 | class CleanData(object): 16 | 17 | def __init__(self, prefixes=[], item_list=[]): 18 | self.prefixes = prefixes 19 | self.item_list = item_list 20 | 21 | def __call__(self, data): 22 | 23 | for prefix in self.prefixes: 24 | for key in data.keys: 25 | if key.startswith(prefix): 26 | data[key] = None 27 | 28 | for key in self.item_list: 29 | if key in data.keys: 30 | data[key] = None 31 | 32 | return data 33 | 34 | 35 | class ToDict(object): 36 | 37 | def __call__(self, data): 38 | 39 | d = {} 40 | for key in data.keys: 41 | d[key] = data[key] 42 | return d 43 | 44 | def __repr__(self): 45 | return '{}'.format(self.__class__.__name__) 46 | 47 | 48 | def get_input_channels(input_config): 49 | 50 | # compute the input size: 51 | in_channels = 0 52 | for key in input_config: 53 | if key in ["intensities", "x"]: 54 | in_channels += 1 55 | elif key in ["pos", "rgb", "normals", "dirs"]: 56 | in_channels += 3 57 | 58 | return in_channels 59 | 60 | 61 | def get_transforms(config, network_function=None, train=True, downstream=False, keep_orignal_data=False): 62 | 63 | logging.info(f"Transforms - Train {train} - Downstream {downstream}") 64 | 65 | augmentations = config["transforms"] 66 | print(augmentations) 67 | transforms = [] 68 | 69 | if keep_orignal_data: 70 | transforms.append(Duplicate(["pos", "y"], "original_")) 71 | 72 | # if augmentations['voxel_decimation'] is not None: 73 | # transforms.append(VoxelDecimation(augmentations["voxel_decimation"])) 74 | 75 | exact_number_of_points = (config["network"]["backbone"] in ["FKAConv", "DGCNN"]) 76 | n_non_manifold_pts = config["non_manifold_points"] if (not downstream) else None 77 | non_manifold_dist = config["non_manifold_dist"] if "non_manifold_dist" in config else 0.1 78 | 79 | # transforms.append(CreatePoints(npts=config["manifold_points"], 80 | # exact_number_of_points=exact_number_of_points, 81 | # pts_item_list=["x", "pos", "y", "intensities"], 82 | # n_non_manifold_pts=n_non_manifold_pts, non_manifold_dist=non_manifold_dist)) 83 | 84 | if augmentations["scaling_intensities"]: 85 | logging.info("Transforms - Scale intensities") 86 | transforms.append(Scaling(255., item_list=["intensities", "intensities_non_manifold"])) 87 | 88 | transforms.append(CleanData(prefixes=[], item_list=["pos2", "intensities2", "sensors2"])) 89 | 90 | if train: 91 | if augmentations["random_rotation_z"]: 92 | transforms.append(RandomRotate(180, axis=2, item_list=["pos"])) 93 | 94 | if augmentations["random_flip"]: 95 | logging.info("Transforms - Flip") 96 | transforms.append(RandomFlip(["pos"])) 97 | 98 | transforms.append(CreateInputs(config["inputs"])) 99 | 100 | if 
config["network"]["framework"] is not None: 101 | logging.info(f"Transforms - Quantize - {config['network']['framework']}") 102 | model_module = importlib.import_module("networks.backbone." + config["network"]["framework"]) 103 | transforms.append(model_module.Quantize(**config["network"]["backbone_params"]["quantization_params"])) 104 | 105 | transforms = T.Compose(transforms) 106 | 107 | return transforms 108 | -------------------------------------------------------------------------------- /downstream/transforms/random_flip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | class RandomFlip(object): 5 | 6 | def __init__(self, item_list) -> None: 7 | self.item_list = item_list 8 | 9 | def __call__(self, data): 10 | 11 | # if torch.randint(0, 2, size=(1,)).item(): 12 | # for item in self.item_list: 13 | # if item not in data: 14 | # continue 15 | # if len(data[item].shape) == 2: 16 | # data[item][:, 0] = -data[item][:, 0] 17 | # elif len(data[item].shape) == 1: 18 | # data[item][0] = -data[item][0] 19 | # else: 20 | # raise NotImplementedError 21 | 22 | for item in self.item_list: 23 | for curr_ax in range(2): 24 | if random.random() < 0.5: 25 | data[item][:, curr_ax] = -data[item][:, curr_ax] 26 | 27 | return data 28 | -------------------------------------------------------------------------------- /downstream/transforms/random_rotate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import numbers 4 | import math 5 | import random 6 | 7 | 8 | 9 | class RandomRotate(object): 10 | def __init__(self, degrees, axis=0, item_list=["pos"]): 11 | if isinstance(degrees, numbers.Number): 12 | degrees = (-abs(degrees), abs(degrees)) 13 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2 14 | 15 | logging.info(f"Transforms - Axis {axis} - {item_list}") 16 | self.degrees = degrees 17 | self.axis = axis 18 | self.item_list = item_list 19 | 20 | def __call__(self, data): 21 | degree = math.pi * random.uniform(*self.degrees) / 180.0 22 | sin, cos = math.sin(degree), math.cos(degree) 23 | 24 | if data.pos.size(-1) == 2: 25 | matrix = [[cos, sin], [-sin, cos]] 26 | else: 27 | if self.axis == 0: 28 | matrix = [[1, 0, 0], [0, cos, sin], [0, -sin, cos]] 29 | elif self.axis == 1: 30 | matrix = [[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]] 31 | else: 32 | matrix = [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]] 33 | 34 | matrix = torch.tensor(matrix) 35 | 36 | for key, item in data: 37 | if key in self.item_list: 38 | if torch.is_tensor(item): 39 | data[key] = torch.matmul(item, matrix.to(item.dtype).to(item.device)) 40 | if ("second_" + key) in data.keys: 41 | data["second_" + key] = torch.matmul(data["second_" + key], matrix.to(item.dtype).to(item.device)) 42 | 43 | return data 44 | -------------------------------------------------------------------------------- /downstream/transforms/scaling.py: -------------------------------------------------------------------------------- 1 | class Scaling(object): 2 | 3 | def __init__(self, scale, item_list=["pos"]): 4 | self.scale = scale 5 | self.item_list = item_list 6 | 7 | def __call__(self, data): 8 | 9 | for key in self.item_list: 10 | if key in data.keys: 11 | data[key] = data[key] * self.scale 12 | 13 | return data 14 | -------------------------------------------------------------------------------- /downstream/transforms/voxel_decimation.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import logging 4 | 5 | 6 | class VoxelDecimation(object): 7 | 8 | def __init__(self, voxel_size) -> None: 9 | logging.info(f"Transforms - VoxelDecimation - {voxel_size}") 10 | self.v_size = voxel_size 11 | 12 | def __call__(self, data): 13 | 14 | pos = data["pos"] 15 | pos = (pos / self.v_size).long() 16 | num_pts = pos.shape[0] 17 | 18 | # Numpy version 19 | pos, indices = np.unique(pos.cpu().numpy(), return_index=True, axis=0) 20 | 21 | for key in data.keys: 22 | if isinstance(data[key], torch.Tensor) and ("second" not in key) and data[key].shape[0] == num_pts: 23 | data[key] = data[key][indices] 24 | 25 | # if second frame --> decimation of the second frame 26 | if "second_pos" in data.keys: 27 | pos = data["second_pos"] 28 | pos = (pos / self.v_size).long() 29 | num_pts = pos.shape[0] 30 | pos, indices = np.unique(pos.cpu().numpy(), return_index=True, axis=0) 31 | 32 | for key in data.keys: 33 | if isinstance(data[key], torch.Tensor) and ("second" in key) and data[key].shape[0] == num_pts: 34 | data[key] = data[key][indices] 35 | 36 | return data 37 | -------------------------------------------------------------------------------- /downstream/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Eaphan/NCLR/b3a944af649b64f0aed82aae0211ebc5f2fe2d13/downstream/utils/__init__.py -------------------------------------------------------------------------------- /downstream/utils/callbacks.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.callbacks import TQDMProgressBar 2 | 3 | 4 | class CustomProgressBar(TQDMProgressBar): 5 | def on_train_epoch_start(self, trainer, *args): 6 | super().on_train_epoch_start(trainer, *args) 7 | self.main_progress_bar.reset() 8 | 9 | def get_metrics(self, trainer, model): 10 | # don't show the version number 11 | items = super().get_metrics(trainer, model) 12 | items.pop("v_num", None) 13 | return items -------------------------------------------------------------------------------- /downstream/utils/confusion_matrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Dict, List, Tuple 3 | 4 | class ConfusionMatrix: 5 | """ 6 | Class for confusion matrix with various convenient methods. 7 | """ 8 | def __init__(self, num_classes: int, ignore_idx: int = None): 9 | """ 10 | Initialize a ConfusionMatrix object. 11 | :param num_classes: Number of classes in the confusion matrix. 12 | :param ignore_idx: Index of the class to be ignored in the confusion matrix. 13 | """ 14 | self.num_classes = num_classes 15 | self.ignore_idx = ignore_idx 16 | 17 | self.global_cm = None 18 | 19 | def update(self, gt_array: np.ndarray, pred_array: np.ndarray) -> None: 20 | """ 21 | Updates the global confusion matrix. 22 | :param gt_array: An array containing the ground truth labels. 23 | :param pred_array: An array containing the predicted labels. 24 | """ 25 | cm = self._get_confusion_matrix(gt_array, pred_array) 26 | 27 | if self.global_cm is None: 28 | self.global_cm = cm 29 | else: 30 | self.global_cm += cm 31 | 32 | def _get_confusion_matrix(self, gt_array: np.ndarray, pred_array: np.ndarray) -> np.ndarray: 33 | """ 34 | Obtains the confusion matrix for the segmentation of a single point cloud. 
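Each (gt, pred) pair is encoded as gt * num_classes + pred, tallied with np.bincount and reshaped into an N x N matrix (rows = ground truth, columns = predictions).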
35 | :param gt_array: An array containing the ground truth labels. 36 | :param pred_array: An array containing the predicted labels. 37 | :return: N x N array where N is the number of classes. 38 | """ 39 | assert all((gt_array >= 0) & (gt_array < self.num_classes)), \ 40 | "Error: Array for ground truth must be between 0 and {} (inclusive).".format(self.num_classes - 1) 41 | assert all((pred_array > 0) & (pred_array < self.num_classes)), \ 42 | "Error: Array for predictions must be between 1 and {} (inclusive).".format(self.num_classes - 1) 43 | 44 | label = self.num_classes * gt_array.astype('int') + pred_array 45 | count = np.bincount(label, minlength=self.num_classes ** 2) 46 | 47 | # Make confusion matrix (rows = gt, cols = preds). 48 | confusion_matrix = count.reshape(self.num_classes, self.num_classes) 49 | 50 | # For the class to be ignored, set both the row and column to 0 (adapted from 51 | # https://github.com/davidtvs/PyTorch-ENet/blob/master/metric/iou.py). 52 | if self.ignore_idx is not None: 53 | confusion_matrix[self.ignore_idx, :] = 0 54 | confusion_matrix[:, self.ignore_idx] = 0 55 | 56 | return confusion_matrix 57 | 58 | def get_per_class_iou(self) -> List[float]: 59 | """ 60 | Gets the IOU of each class in a confusion matrix. 61 | :return: An array in which the IOU of a particular class sits at the array index corresponding to the 62 | class index. 63 | """ 64 | conf = self.global_cm.copy() 65 | 66 | # Get the intersection for each class. 67 | intersection = np.diagonal(conf) 68 | 69 | # Get the union for each class. 70 | ground_truth_set = conf.sum(axis=1) 71 | predicted_set = conf.sum(axis=0) 72 | union = ground_truth_set + predicted_set - intersection 73 | 74 | # Get the IOU for each class. 75 | # In case we get a division by 0, ignore / hide the error(adapted from 76 | # https://github.com/davidtvs/PyTorch-ENet/blob/master/metric/iou.py). 77 | with np.errstate(divide='ignore', invalid='ignore'): 78 | iou_per_class = intersection / (union.astype(np.float32)) 79 | 80 | return iou_per_class 81 | 82 | def get_mean_iou(self) -> float: 83 | """ 84 | Gets the mean IOU (mIOU) over the classes. 85 | :return: mIOU over the classes. 86 | """ 87 | iou_per_class = self.get_per_class_iou() 88 | miou = float(np.nanmean(iou_per_class)) 89 | return miou 90 | 91 | def get_freqweighted_iou(self) -> float: 92 | """ 93 | Gets the frequency-weighted IOU over the classes. 94 | :return: Frequency-weighted IOU over the classes. 95 | """ 96 | conf = self.global_cm.copy() 97 | 98 | # Get the number of points per class (based on ground truth). 99 | num_points_per_class = conf.sum(axis=1) 100 | 101 | # Get the total number of points in the eval set. 102 | num_points_total = conf.sum() 103 | 104 | # Get the IOU per class. 105 | iou_per_class = self.get_per_class_iou() 106 | 107 | # Weight the IOU by frequency and sum across the classes. 108 | freqweighted_iou = float(np.nansum(num_points_per_class * iou_per_class) / num_points_total) 109 | 110 | return freqweighted_iou -------------------------------------------------------------------------------- /downstream/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def stats_overall_accuracy(cm): 5 | """Computes the overall accuracy. 6 | 7 | # Arguments: 8 | cm: 2-D numpy array. 9 | Confusion matrix. 10 | """ 11 | return np.trace(cm) / cm.sum() 12 | 13 | 14 | def stats_pfa_per_class(cm): 15 | """Computes the probability of false alarms. 
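For each class this is FP / (TP + FP), i.e. the fraction of predictions of that class that do not match the ground truth; classes that are never predicted are assigned -1, and the returned average is taken over the classes that are predicted at least once.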
16 | 17 | # Arguments: 18 | cm: 2-D numpy array. 19 | Confusion matrix. 20 | """ 21 | sums = np.sum(cm, axis=0) 22 | mask = sums > 0 23 | sums[sums == 0] = 1 24 | pfa_per_class = (cm.sum(axis=0) - np.diag(cm)) / sums 25 | pfa_per_class[np.logical_not(mask)] = -1 26 | average_pfa = pfa_per_class[mask].mean() 27 | return average_pfa, pfa_per_class 28 | 29 | 30 | def stats_accuracy_per_class(cm): 31 | """Computes the accuracy per class and average accuracy. 32 | 33 | # Arguments: 34 | cm: 2-D numpy array. 35 | Confusion matrix. 36 | 37 | # Returns 38 | average_accuracy: float. 39 | The average accuracy. 40 | accuracy_per_class: 1-D numpy array. 41 | The accuracy per class. 42 | """ 43 | sums = np.sum(cm, axis=1) 44 | mask = sums > 0 45 | sums[sums == 0] = 1 46 | accuracy_per_class = np.diag(cm) / sums # sum over rows 47 | accuracy_per_class[np.logical_not(mask)] = -1 48 | average_accuracy = accuracy_per_class[mask].mean() 49 | return average_accuracy, accuracy_per_class 50 | 51 | 52 | def stats_iou_per_class(cm): 53 | """Computes the IoU per class and average IoU. 54 | 55 | # Arguments: 56 | cm: 2-D numpy array. 57 | Confusion matrix. 58 | 59 | # Returns 60 | average_iou: float. 61 | The average IoU. 62 | iou_per_class: 1-D numpy array. 63 | The IoU per class. 64 | """ 65 | 66 | # compute TP, FN and FP 67 | TP = np.diagonal(cm, axis1=-2, axis2=-1) 68 | TP_plus_FN = np.sum(cm, axis=-1) 69 | TP_plus_FP = np.sum(cm, axis=-2) 70 | 71 | # compute IoU 72 | mask = TP_plus_FN == 0 73 | IoU = TP / (TP_plus_FN + TP_plus_FP - TP + mask) 74 | 75 | # replace the IoU of classes absent from the ground truth (currently 0) with the average IoU 76 | aIoU = IoU[np.logical_not(mask)].mean(axis=-1, keepdims=True) 77 | IoU += mask * aIoU 78 | 79 | return IoU.mean(axis=-1), IoU 80 | 81 | 82 | def stats_f1score_per_class(cm): 83 | """Computes the F1 per class and average F1. 84 | 85 | # Arguments: 86 | cm: 2-D numpy array. 87 | Confusion matrix. 88 | 89 | # Returns 90 | average_f1_score: float. 91 | The average F1. 92 | f1score_per_class: 1-D numpy array. 93 | The F1 per class. 
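Note: since F1 = 2 * precision * recall / (precision + recall), with precision = TP / (TP + FP) and recall = TP / (TP + FN), the per-class F1 reduces to 2 * TP / ((TP + FP) + (TP + FN)), i.e. 2 * diag(cm) / (row sum + column sum), which is what the implementation below computes.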
94 | """ 95 | # defined as 2 * recall * prec / recall + prec 96 | sums = np.sum(cm, axis=1) + np.sum(cm, axis=0) 97 | mask = sums > 0 98 | sums[sums == 0] = 1 99 | f1score_per_class = 2 * np.diag(cm) / sums 100 | f1score_per_class[np.logical_not(mask)] = -1 101 | average_f1_score = f1score_per_class[mask].mean() 102 | return average_f1_score, f1score_per_class 103 | -------------------------------------------------------------------------------- /downstream/utils/utils.py: -------------------------------------------------------------------------------- 1 | class bcolors: 2 | HEADER = '\033[95m' 3 | OKBLUE = '\033[94m' 4 | OKGREEN = '\033[92m' 5 | WARNING = '\033[93m' 6 | FAIL = '\033[91m' 7 | ENDC = '\033[0m' 8 | BOLD = '\033[1m' 9 | UNDERLINE = '\033[4m' 10 | 11 | # wrap blue / green 12 | def wblue(str): 13 | return bcolors.OKBLUE+str+bcolors.ENDC 14 | def wgreen(str): 15 | return bcolors.OKGREEN+str+bcolors.ENDC 16 | def wred(str): 17 | return bcolors.FAIL+str+bcolors.ENDC 18 | -------------------------------------------------------------------------------- /downstream/visu_downstream.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import yaml 4 | import logging 5 | import argparse 6 | import importlib 7 | 8 | from tqdm import tqdm 9 | 10 | from scipy.spatial import KDTree 11 | 12 | # torch imports 13 | import torch 14 | 15 | 16 | from torch_geometric.data import DataLoader 17 | 18 | from transforms import get_transforms, get_input_channels 19 | 20 | import datasets 21 | import networks 22 | from networks.backbone import * 23 | 24 | 25 | if __name__ == "__main__": 26 | logging.getLogger().setLevel("INFO") 27 | np.random.seed(0) 28 | torch.manual_seed(0) 29 | torch.cuda.manual_seed_all(0) 30 | 31 | parser = argparse.ArgumentParser(description='Self supervised.') 32 | parser.add_argument('--ckpt', type=str, required=True) 33 | parser.add_argument('--resultsDir', type=str, default="visus") 34 | parser.add_argument('--config', type=str, default="config.yaml") 35 | parser.add_argument('--split', type=str, default="val") 36 | opts = parser.parse_args() 37 | 38 | logging.info("loading the config file") 39 | config = yaml.load(open(opts.config, "r"), yaml.FullLoader) 40 | 41 | logging.info("Dataset") 42 | DatasetClass = eval("datasets." + config["dataset_name"]) 43 | test_transforms = get_transforms(config, train=False, downstream=True, keep_orignal_data=True) 44 | test_dataset = DatasetClass(config["dataset_root"], 45 | split=opts.split, 46 | transform=test_transforms, 47 | ) 48 | 49 | logging.info("Dataloader") 50 | test_loader = DataLoader( 51 | test_dataset, 52 | batch_size=1, 53 | shuffle=False, 54 | num_workers=config["threads"], 55 | follow_batch=["voxel_coords"] 56 | ) 57 | 58 | num_classes = config["downstream"]["num_classes"] 59 | device = torch.device("cuda") 60 | 61 | logging.info("Network") 62 | if config["network"]["backbone_params"] is None: 63 | config["network"]["backbone_params"] = {} 64 | config["network"]["backbone_params"]["in_channels"] = get_input_channels(config["inputs"]) 65 | config["network"]["backbone_params"]["out_channels"] = config["downstream"]["num_classes"] 66 | 67 | backbone_name = "networks.backbone." 68 | if config["network"]["framework"] is not None: 69 | backbone_name += config["network"]["framework"] 70 | importlib.import_module(backbone_name) 71 | backbone_name += "." 
+ config["network"]["backbone"] 72 | net = eval(backbone_name)(**config["network"]["backbone_params"]) 73 | net.to(device) 74 | net.eval() 75 | 76 | logging.info("Loading the weights from pretrained network") 77 | try: 78 | net.load_state_dict(torch.load(opts.ckpt), strict=True) 79 | except RuntimeError: 80 | ckpt = torch.load(opts.ckpt) 81 | ckpt = {k[4:]: v for k, v in ckpt['state_dict'].items()} 82 | net.load_state_dict(ckpt, strict=True) 83 | 84 | with torch.no_grad(): 85 | t = tqdm(test_loader, ncols=100, disable=True) 86 | for data in t: 87 | 88 | data = data.to(device) 89 | 90 | # predictions 91 | predictions = net(data) 92 | predictions = torch.nn.functional.softmax(predictions[:, 1:], dim=1).max(dim=1)[1] 93 | predictions = predictions.cpu().numpy() + 1 94 | 95 | # interpolate to original point cloud 96 | original_pos_np = data["original_pos"].cpu().numpy() 97 | pos_np = data["pos"].cpu().numpy() 98 | tree = KDTree(pos_np) 99 | _, indices = tree.query(original_pos_np, k=1) 100 | predictions = predictions[indices] 101 | 102 | # update the confusion matric 103 | targets_np = data["original_y"].cpu().numpy() 104 | 105 | # create the colors 106 | prediction_colors = test_dataset.get_colors(predictions) 107 | target_colors = test_dataset.get_colors(targets_np) 108 | 109 | # good / bad predictions 110 | good_bad_pred = (predictions == targets_np).astype(np.uint8) 111 | 112 | fname = test_dataset.get_filename(data["shape_id"].item()) + ".xyz" 113 | 114 | # save everything 115 | predictions_dir = os.path.join(opts.resultsDir, "predictions") 116 | targets_dir = os.path.join(opts.resultsDir, "ground_truth") 117 | good_bad_pred_dir = os.path.join(opts.resultsDir, "good_bad") 118 | 119 | 120 | os.makedirs(predictions_dir, exist_ok=True) 121 | os.makedirs(targets_dir, exist_ok=True) 122 | os.makedirs(good_bad_pred_dir, exist_ok=True) 123 | 124 | 125 | np.savetxt(os.path.join(predictions_dir, fname), 126 | np.concatenate([original_pos_np, prediction_colors], axis=1), 127 | fmt=["%.3f", "%.3f", "%.3f", "%u", "%u", "%u"] 128 | ) 129 | np.savetxt(os.path.join(targets_dir, fname), 130 | np.concatenate([original_pos_np, target_colors], axis=1), 131 | fmt=["%.3f", "%.3f", "%.3f", "%u", "%u", "%u"] 132 | ) 133 | np.savetxt(os.path.join(good_bad_pred_dir, fname), 134 | np.concatenate([original_pos_np, good_bad_pred[:,np.newaxis]], axis=1), 135 | fmt=["%.3f", "%.3f", "%.3f", "%u"] 136 | ) 137 | 138 | torch.cuda.empty_cache() 139 | 140 | print(fname, good_bad_pred.sum()/good_bad_pred.shape[0]) 141 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | spconv==2.3.6 3 | torchsparse @ git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0 4 | MinkowskiEngine==0.5.4 5 | torch-scatter==2.1.1 6 | torch-geometric==2.3.1 7 | nuscenes-devkit 8 | easydict 9 | hydra-core==1.3.2 10 | pytorch-lightning==1.6.5 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 - Valeo Comfort and Driving Assistance - Corentin Sautier @ valeo.ai 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Code inspired by OpenPCDet. 16 | # Credit goes to OpenMMLab: https://github.com/open-mmlab/OpenPCDet 17 | 18 | import os 19 | import tqdm 20 | import torch 21 | import argparse 22 | import numpy as np 23 | from pathlib import Path 24 | import torch.distributed as dist 25 | import torch.multiprocessing as mp 26 | from datetime import datetime as dt 27 | from tensorboardX import SummaryWriter 28 | 29 | from utils.logger import make_logger 30 | from bevlab.models import make_models 31 | from bevlab.dataloader import make_dataloader 32 | from utils.config import generate_config, log_config 33 | from utils.optimizer import make_optimizer, make_scheduler 34 | from torch.nn.parallel import DistributedDataParallel as DDP 35 | 36 | import warnings 37 | warnings.filterwarnings("ignore", category=UserWarning) 38 | 39 | 40 | def ddp_setup(rank: int, world_size: int): 41 | """ 42 | Args: 43 | rank: Unique identifier of each process 44 | world_size: Total number of processes 45 | """ 46 | os.environ["MASTER_ADDR"] = "localhost" 47 | os.environ["MASTER_PORT"] = "12355" 48 | torch.cuda.set_device(rank % world_size) 49 | dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) 50 | 51 | def load_pretrained_weights(model, pretrained_path): 52 | # load the pretrained weights 53 | load_dict = torch.load(pretrained_path, map_location='cpu') 54 | pretrained_state_dict = load_dict['state_dict'] 55 | model_state_dict = model.state_dict() 56 | 57 | # create a new dict to store the adjusted weights 58 | new_state_dict = {} 59 | 60 | for key in model_state_dict.keys(): 61 | if key in pretrained_state_dict: 62 | # get the shapes of the model weight and the pretrained weight 63 | model_shape = model_state_dict[key].shape 64 | pretrained_shape = pretrained_state_dict[key].shape 65 | 66 | if model_shape == pretrained_shape: 67 | # if the shapes match, use the pretrained weight directly 68 | new_state_dict[key] = pretrained_state_dict[key] 69 | else: 70 | # if the shapes do not match, adjust the weight 71 | print(f"Shape mismatch for {key}, model shape: {model_shape}, pretrained shape: {pretrained_shape}") 72 | # adjust as needed here, e.g. crop, pad, or skip this weight 73 | # example: simply skip the weight when the shapes do not match 74 | new_state_dict[key] = model_state_dict[key] 75 | else: 76 | # if the key is not in the pretrained weights, keep the model's default weight 77 | print(f"Key {key} not found in pretrained weights, using model default weights") 78 | new_state_dict[key] = model_state_dict[key] 79 | 80 | # load the adjusted weights 81 | model.load_state_dict(new_state_dict, strict=False) 82 | 83 | def parse_config(): 84 | parser = argparse.ArgumentParser(description='arg parser') 85 | parser.add_argument('--config_file', type=str, default=None, help='specify the config for training') 86 | 87 | parser.add_argument('--batch_size_per_gpu', type=int, default=None, required=False, help='batch size for training') 88 | parser.add_argument('--lr', type=float, default=None, required=False, help='learning rate for training') 89 | parser.add_argument('--epochs', type=int, default=None, required=False, help='number of epochs to train for') 90 | parser.add_argument('--num_workers_per_gpu', type=int, default=None, help='number of workers for dataloader') 91 | parser.add_argument('--name', type=str, default='default', help='name of the experiment') 92 | 
parser.add_argument('--debug', action='store_true', default=False, help='') 93 | parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training') 94 | parser.add_argument('--resume_path', type=str, default=None, help='checkpoint to resume training from') 95 | parser.add_argument('--pretrain_path', type=str, default=None, help='checkpoint to load weights from') 96 | # parser.add_argument('--fix_random_seed', action='store_true', default=False, help='') 97 | 98 | args = parser.parse_args() 99 | 100 | config = generate_config(args.config_file) 101 | config.SAVE_FOLDER = Path('output', args.name, dt.today().strftime("%d%m%y-%H%M")) 102 | return args, config 103 | 104 | 105 | def main(rank, world_size): 106 | multigpu = world_size > 1 107 | if multigpu: 108 | ddp_setup(rank, world_size) 109 | args, config = parse_config() 110 | 111 | if args.batch_size_per_gpu is not None: 112 | config.OPTIMIZATION.BATCH_SIZE_PER_GPU = args.batch_size_per_gpu 113 | if args.epochs is not None: 114 | config.OPTIMIZATION.NUM_EPOCHS = args.epochs 115 | if args.num_workers_per_gpu is not None: 116 | config.OPTIMIZATION.NUM_WORKERS_PER_GPU = args.num_workers_per_gpu 117 | if args.lr is not None: 118 | config.OPTIMIZATION.LR = args.lr 119 | config.DEBUG = args.debug 120 | 121 | config.LOCAL_RANK = rank 122 | 123 | # if args.fix_random_seed: 124 | # # unfortunately as grid_sampler_2d_backward_cuda is non-deterministic, reproductibility isn't possible 125 | # torch.use_deterministic_algorithms(True) 126 | # torch.backends.cudnn.benchmark = False 127 | # os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' 128 | # random.seed(0) 129 | # np.random.seed(0) 130 | # torch.manual_seed(0) 131 | 132 | ckpt_dir = config.SAVE_FOLDER / 'ckpt' 133 | if rank == 0: 134 | ckpt_dir.mkdir(parents=True, exist_ok=True) 135 | log_file = config.SAVE_FOLDER / 'log_train.txt' 136 | logger = make_logger(log_file, rank=rank) 137 | 138 | logger.info("==============Logging config==============") 139 | log_config(config, logger) 140 | 141 | logger.info('World size : %s' % world_size) 142 | 143 | if rank == 0: 144 | os.system('cp %s %s' % (args.config_file, config.SAVE_FOLDER)) 145 | 146 | train_dataloader = make_dataloader( 147 | config=config, 148 | phase=config.DATASET.DATA_SPLIT['train'], 149 | world_size=world_size, 150 | rank=rank 151 | ) 152 | 153 | model = make_models(config=config) 154 | if multigpu and not config.ENCODER.COLLATE == "collate_torchsparse": 155 | # sync batchnorm doesn't work with torchsparse 156 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 157 | 158 | optimizer = make_optimizer(model, config) 159 | scheduler = make_scheduler( 160 | config, total_iters=len(train_dataloader) * config.OPTIMIZATION.NUM_EPOCHS 161 | ) 162 | 163 | model.train() 164 | # model.img_encoder.eval() 165 | # model.img_encoder.decoder.train() 166 | 167 | model = model.to(rank) 168 | 169 | # load checkpoint if it is possible 170 | if args.resume_path is not None: 171 | logger.warning(f"Continuing previous training: {args.resume_path}") 172 | load_dict = torch.load(args.resume_path, 'cpu') 173 | model.load_state_dict(load_dict['state_dict'], strict=False) 174 | optimizer.load_state_dict(load_dict['optimizer']) 175 | start_epoch = load_dict['epoch'] 176 | train_iter = load_dict['iters'] 177 | elif args.pretrain_path is not None: 178 | # load_dict = torch.load(args.pretrain_path, 'cpu') 179 | # model.load_state_dict(load_dict['state_dict'], strict=False) 180 | load_pretrained_weights(model, 
args.pretrain_path) 181 | start_epoch = 0 182 | train_iter = 0 183 | else: 184 | start_epoch = 0 185 | train_iter = 0 186 | 187 | if multigpu: 188 | model = DDP(model, device_ids=[rank], find_unused_parameters=True) 189 | 190 | train( 191 | model, 192 | start_epoch, 193 | train_iter, 194 | train_dataloader, 195 | optimizer, 196 | scheduler=scheduler, 197 | config=config, 198 | rank=rank, 199 | multigpu=multigpu 200 | ) 201 | 202 | if multigpu: 203 | dist.destroy_process_group() 204 | 205 | 206 | def train(model, start_epoch, train_iter, train_dataloader, optimizer, scheduler, config, rank, multigpu): 207 | debug = config.DEBUG 208 | if not debug: 209 | tb_log = SummaryWriter(log_dir=str(config.SAVE_FOLDER / 'tensorboard')) if rank == 0 else None 210 | 211 | total_epochs = config.OPTIMIZATION.NUM_EPOCHS 212 | disp_dict = {} 213 | with tqdm.trange(start_epoch, total_epochs, desc='epochs', dynamic_ncols=True, leave=(rank == 0)) as tbar: 214 | total_it_each_epoch = len(train_dataloader) 215 | 216 | for cur_epoch in tbar: 217 | if multigpu: 218 | train_dataloader.sampler.set_epoch(cur_epoch) 219 | train_dataloader_iter = iter(train_dataloader) 220 | statistics = {"losses": []} 221 | 222 | if rank == 0: 223 | pbar = tqdm.tqdm(total=total_it_each_epoch, leave=False, desc='train', dynamic_ncols=True) 224 | 225 | for cur_it in range(len(train_dataloader)): 226 | cur_lr = scheduler[train_iter] 227 | batch = next(train_dataloader_iter) 228 | batch['cur_epoch'] = cur_epoch 229 | batch['voxels'] = batch['voxels'].to(rank, non_blocking=True) 230 | batch['pairing_points'] = batch['pairing_points'].to(rank, non_blocking=True) 231 | batch['pairing_images'] = batch['pairing_images'].to(rank, non_blocking=True) 232 | batch['coordinates'] = batch['coordinates'].to(rank, non_blocking=True) 233 | # batch['cam_coords'] = batch['cam_coords'].to(rank, non_blocking=True) 234 | batch['images'] = batch['images'].to(rank, non_blocking=True) 235 | batch['R_data'] = batch['R_data'].to(rank, non_blocking=True) 236 | batch['T_data'] = batch['T_data'].to(rank, non_blocking=True) 237 | if 'K_data' in batch: 238 | batch['K_data'] = batch['K_data'].to(rank, non_blocking=True) 239 | batch['img_overlap_masks'] = batch['img_overlap_masks'].to(rank, non_blocking=True) 240 | # batch['pc_overlap_masks'] = batch['pc_overlap_masks'].to(rank, non_blocking=True) 241 | for param_group in optimizer.param_groups: 242 | param_group["lr"] = cur_lr 243 | 244 | optimizer.zero_grad(set_to_none=True) 245 | 246 | loss, metrics = model(batch) 247 | statistics["losses"].append(loss.item()) 248 | 249 | loss.backward() 250 | optimizer.step() 251 | 252 | # log to console and tensorboard 253 | if rank == 0: 254 | 255 | pbar.update() 256 | pbar.set_postfix(dict(total_it=train_iter, loss=loss.item(), **metrics)) 257 | 258 | if not debug: 259 | tb_log.add_scalar('train/loss', loss.item(), train_iter) 260 | for key, value in metrics.items(): 261 | tb_log.add_scalar(f'train/{key}', value, train_iter) 262 | tb_log.add_scalar('meta_data/learning_rate', cur_lr, train_iter) 263 | 264 | del loss, metrics 265 | train_iter += 1 266 | 267 | if rank == 0: 268 | loss = np.mean(statistics['losses']) 269 | disp_dict.update({'loss': np.mean(statistics['losses'])}) 270 | tbar.set_postfix(disp_dict) 271 | if not debug: 272 | tb_log.add_scalar('epoch/loss', loss, cur_epoch) 273 | pbar.close() 274 | 275 | # save trained model 276 | if isinstance(model, DDP): 277 | torch.save({ 278 | "state_dict": model.module.state_dict(), 279 | "optimizer": optimizer.state_dict(), 280 | 
"epoch": cur_epoch+1, 281 | "iters": train_iter, 282 | "config": config}, 283 | config.SAVE_FOLDER / 'ckpt' / f'model_{cur_epoch+1}.pt') 284 | else: 285 | torch.save({ 286 | "state_dict": model.state_dict(), 287 | "optimizer": optimizer.state_dict(), 288 | "epoch": cur_epoch+1, 289 | "iters": train_iter, 290 | "config": config}, 291 | config.SAVE_FOLDER / 'ckpt' / f'model_{cur_epoch+1}.pt') 292 | 293 | 294 | if __name__ == '__main__': 295 | multigpu = torch.cuda.device_count() > 1 296 | if multigpu: 297 | world_size = torch.cuda.device_count() 298 | mp.spawn(main, args=(world_size,), nprocs=world_size) 299 | else: 300 | main(0, 1) 301 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Eaphan/NCLR/b3a944af649b64f0aed82aae0211ebc5f2fe2d13/utils/__init__.py -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from easydict import EasyDict 3 | 4 | 5 | def log_config(config, logger): 6 | for key, val in config.items(): 7 | if isinstance(val, EasyDict): 8 | logger.info("===== %s =====:" % key) 9 | log_config(val, logger) 10 | else: 11 | logger.info('%s: %s' % (key, val)) 12 | 13 | 14 | def generate_config(config): 15 | with open(config, 'r') as f: 16 | config = yaml.safe_load(f) 17 | 18 | return EasyDict(config) 19 | -------------------------------------------------------------------------------- /utils/convert_spconv_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | 4 | 5 | if __name__ == "__main__": 6 | path = sys.argv[1] 7 | ckpt = torch.load(path) 8 | ckpt = ckpt['state_dict'] 9 | ckpt = {k.replace('encoder.', ''): v for k, v in ckpt.items()} 10 | del ckpt['final.weight'] 11 | torch.save({'model_state': ckpt}, path.replace('.pt', '_converted.pt')) 12 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def make_logger(logfile, rank): 5 | return CustomLogger(logfile, rank) 6 | 7 | 8 | class CustomLogger: 9 | def __init__(self, logfile, rank) -> None: 10 | self.rank = rank 11 | if rank == 0: 12 | self.logger = logging.getLogger() 13 | self.logger.setLevel(logging.DEBUG) 14 | ch = logging.StreamHandler() 15 | ch.setLevel(logging.DEBUG) 16 | fh = logging.FileHandler(logfile) 17 | fh.setLevel(logging.DEBUG) 18 | formatter = logging.Formatter('%(levelname)s - %(message)s', datefmt='%d%m%y-%H%M') 19 | ch.setFormatter(formatter) 20 | fh.setFormatter(formatter) 21 | self.logger.addHandler(ch) 22 | self.logger.addHandler(fh) 23 | else: 24 | self.logger = None 25 | 26 | def debug(self, message): 27 | if self.rank == 0: 28 | self.logger.debug(message) 29 | 30 | def info(self, message): 31 | if self.rank == 0: 32 | self.logger.info(message) 33 | 34 | def warning(self, message): 35 | if self.rank == 0: 36 | self.logger.warning(message) 37 | 38 | def error(self, message): 39 | if self.rank == 0: 40 | self.logger.error(message) 41 | 42 | def critical(self, message): 43 | if self.rank == 0: 44 | self.logger.critical(message) 45 | -------------------------------------------------------------------------------- /utils/optimizer.py: 
-------------------------------------------------------------------------------- 1 | from torch import optim 2 | from utils.utils import cosine_scheduler 3 | 4 | 5 | def make_optimizer(model, config): 6 | optimizer_class = getattr(optim, config.OPTIMIZATION.OPTIMIZER) 7 | optimizer_params = { 8 | 'lr': config.OPTIMIZATION.LR, 9 | 'weight_decay': config.OPTIMIZATION.WEIGHT_DECAY 10 | } 11 | if config.OPTIMIZATION.OPTIMIZER == 'SGD': 12 | optimizer_params['momentum'] = config.OPTIMIZATION.SGD_MOMENTUM 13 | optimizer_params['dampening'] = config.OPTIMIZATION.SGD_DAMPENING 14 | optimizer = optimizer_class( 15 | model.parameters(), 16 | **optimizer_params 17 | ) 18 | return optimizer 19 | 20 | 21 | def make_scheduler(config, total_iters): 22 | return cosine_scheduler(config.OPTIMIZATION.LR, config.OPTIMIZATION.LR / 100, total_iters, total_iters // 20) 23 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def confusion_matrix(preds, labels, num_classes): 6 | hist = ( 7 | torch.bincount( 8 | num_classes * labels + preds, 9 | minlength=num_classes ** 2, 10 | ) 11 | .reshape(num_classes, num_classes) 12 | .float() 13 | ) 14 | return hist 15 | 16 | 17 | def compute_IoU_from_cmatrix(hist, ignore_index=None): 18 | """Computes the Intersection over Union (IoU). 19 | Args: 20 | hist: confusion matrix. 21 | Returns: 22 | m_IoU, fw_IoU, and matrix IoU 23 | """ 24 | if ignore_index is not None: 25 | hist[ignore_index] = 0.0 26 | intersection = torch.diag(hist) 27 | union = hist.sum(dim=1) + hist.sum(dim=0) - intersection 28 | IoU = intersection.float() / union.float() 29 | IoU[union == 0] = 1.0 30 | if ignore_index is not None: 31 | IoU = torch.cat((IoU[:ignore_index], IoU[ignore_index + 1:])) 32 | m_IoU = torch.mean(IoU).item() 33 | fw_IoU = ( 34 | torch.sum(intersection) / (2 * torch.sum(hist) - torch.sum(intersection)) 35 | ).item() 36 | return m_IoU, fw_IoU, IoU 37 | 38 | 39 | def knn_classifier(train_features, train_labels, test_features, k, T, num_classes=1000, num_chunks=5000): 40 | preds = [] 41 | train_features = train_features.t() 42 | num_test_images = test_features.shape[0] 43 | imgs_per_chunk = num_test_images // num_chunks 44 | retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device, non_blocking=True) 45 | for idx in range(0, num_test_images, imgs_per_chunk): 46 | # get the features for test images 47 | features = test_features[ 48 | idx: min((idx + imgs_per_chunk), num_test_images), : 49 | ] 50 | batch_size = features.shape[0] 51 | 52 | # calculate the dot product and compute top-k neighbors 53 | similarity = torch.mm(features, train_features) 54 | distances, indices = similarity.topk(k, largest=True, sorted=True) 55 | del similarity 56 | candidates = train_labels.view(1, -1).expand(batch_size, -1) 57 | retrieved_neighbors = torch.gather(candidates, 1, indices) 58 | 59 | # preds.append(torch.mode(retrieved_neighbors, dim=1)[0]) 60 | 61 | retrieval_one_hot.resize_(batch_size * k, num_classes).zero_() 62 | retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1) 63 | # distances_transform = distances.clone().div_(T).exp_() 64 | distances = distances.div_(T).exp_() 65 | probs = torch.sum( 66 | torch.mul( 67 | retrieval_one_hot.view(batch_size, -1, num_classes), 68 | distances.view(batch_size, -1, 1), 69 | ), 70 | 1, 71 | ) 72 | _, predictions = probs.sort(1, True) 73 | 
preds.append(predictions[:, 0]) 74 | 75 | return torch.cat(preds, 0) 76 | 77 | 78 | def cosine_scheduler(base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0): 79 | # Code taken from https://github.com/facebookresearch/dino 80 | # Copyright (c) Facebook, Inc. and its affiliates. 81 | warmup_schedule = np.array([]) 82 | if warmup_iters > 0: 83 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 84 | 85 | iters = np.arange(total_iters - warmup_iters) 86 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 87 | 88 | schedule = np.concatenate((warmup_schedule, schedule)) 89 | assert len(schedule) == total_iters 90 | return schedule 91 | 92 | 93 | def transform_rotation(pc, seed): 94 | angle = seed * 2 * np.pi 95 | c = np.cos(angle) 96 | s = np.sin(angle) 97 | rotation = np.array( 98 | [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=np.float32 99 | ) 100 | pc[:, :3] = pc[:, :3] @ rotation 101 | return pc 102 | 103 | 104 | def transform_dilation(pc, seed): 105 | dilation = seed * 0.1 + 0.95 106 | pc[:, :3] = pc[:, :3] * dilation 107 | return pc 108 | 109 | 110 | def transform_jittering(pc, seed): 111 | pc[:, 3] = np.random.normal(pc[:, 3], 0.01) 112 | return pc 113 | 114 | 115 | def mask_points_outside_range(points, range): 116 | mask = (points[:, 0] >= range[0]) & (points[:, 0] <= range[3]) \ 117 | & (points[:, 1] >= range[1]) & (points[:, 1] <= range[4]) 118 | return points[mask] 119 | 120 | 121 | def det_3x3(mat): 122 | 123 | a, b, c = mat[:, 0, 0], mat[:, 0, 1], mat[:, 0, 2] 124 | d, e, f = mat[:, 1, 0], mat[:, 1, 1], mat[:, 1, 2] 125 | g, h, i = mat[:, 2, 0], mat[:, 2, 1], mat[:, 2, 2] 126 | 127 | det = a * e * i + b * f * g + c * d * h 128 | det = det - c * e * g - b * d * i - a * f * h 129 | 130 | return det 131 | 132 | 133 | def det_2x2(mat): 134 | 135 | a, b = mat[:, 0, 0], mat[:, 0, 1] 136 | c, d = mat[:, 1, 0], mat[:, 1, 1] 137 | 138 | det = a * d - c * b 139 | 140 | return det 141 | 142 | 143 | # def estimate_rot_trans(x, y, w): 144 | # # if threshold is not None: 145 | # # w = w * (w > self.threshold).float() 146 | # w = torch.nn.functional.normalize(w, dim=-1, p=1) 147 | 148 | # # Center point clouds 149 | # mean_x = (w * x).sum(dim=-1, keepdim=True) 150 | # mean_y = (w * y).sum(dim=-1, keepdim=True) 151 | # x_centered = x - mean_x 152 | # y_centered = y - mean_y 153 | 154 | # # Covariance 155 | # cov = torch.bmm(y_centered, (w * x_centered).transpose(1, 2)) 156 | 157 | # # Rotation 158 | # U, _, V = torch.svd(cov) 159 | # det = det_3x3(U) * det_3x3(V) 160 | # S = torch.eye(3, device=U.device).unsqueeze(0).repeat(x.shape[0], 1, 1) 161 | # S[:, -1, -1] = det 162 | # R = torch.bmm(U, torch.bmm(S, V.transpose(1, 2))) 163 | 164 | # # Translation 165 | # T = mean_y - torch.bmm(R, mean_x) 166 | 167 | # return R, T, w 168 | 169 | 170 | def estimate_rot_trans(x, y, w=None): 171 | if w is None: 172 | w = torch.ones(size=(x.shape[0], 1, x.shape[2]), device=x.device) 173 | # if threshold is not None: 174 | # w = w * (w > self.threshold).float() 175 | w = torch.nn.functional.normalize(w, dim=-1, p=1) 176 | 177 | # Center point clouds 178 | mean_x = (w * x).sum(dim=-1, keepdim=True) 179 | mean_y = (w * y).sum(dim=-1, keepdim=True) 180 | x_centered = x - mean_x 181 | y_centered = y - mean_y 182 | 183 | # Covariance 184 | cov = torch.bmm(y_centered, (w * x_centered).transpose(1, 2)) 185 | 186 | # Rotation 187 | U, _, V = torch.svd(cov) 188 | det = det_2x2(U) * det_2x2(V) 189 | S = torch.eye(2, 
device=U.device).unsqueeze(0).repeat(x.shape[0], 1, 1) 190 | S[:, -1, -1] = det 191 | R = torch.bmm(U, torch.bmm(S, V.transpose(1, 2))) 192 | 193 | # Translation 194 | T = mean_y - torch.bmm(R, mean_x) 195 | 196 | return R, T 197 | 198 | 199 | def compute_rte(t, t_est): 200 | 201 | t = t.squeeze().detach().cpu().numpy() 202 | t_est = t_est.squeeze().detach().cpu().numpy() 203 | 204 | return np.linalg.norm(t - t_est) 205 | 206 | 207 | def compute_rre(R_est, R): 208 | 209 | eps = 1e-16 210 | 211 | R = R.squeeze().detach().cpu().numpy() 212 | R_est = R_est.squeeze().detach().cpu().numpy() 213 | 214 | return np.arccos( 215 | np.clip( 216 | np.trace(R_est.T @ R) / 2, 217 | # (np.trace(R_est.T @ R) - 1) / 2, 218 | -1 + eps, 219 | 1 - eps 220 | ) 221 | ) * 180. / np.pi 222 | --------------------------------------------------------------------------------