├── .gitignore ├── README.md ├── figs ├── framework.jpg └── scannet_visualization.jpg └── pointdc_mk ├── README.md ├── data_prepare ├── S3DIS_anno_paths.txt ├── S3DIS_class_names.txt ├── ScanNet_splits │ ├── scannetv2_test.txt │ ├── scannetv2_train.txt │ ├── scannetv2_trainval.txt │ └── scannetv2_val.txt ├── data_prepare_S3DIS.py ├── data_prepare_ScanNet.py ├── initialSP_prepare_S3DIS.py ├── initialSP_prepare_ScanNet.py └── semantic-kitti.yaml ├── datasets ├── S3DIS.py ├── ScanNet.py └── SemanticKITTI.py ├── env.yaml ├── eval_S3DIS.py ├── eval_ScanNet.py ├── figs ├── framework.jpg └── scannet_visualization.jpg ├── lib ├── aug_tools.py ├── helper_ply.py ├── utils.py └── utils_s3dis.py ├── models ├── __init__.py ├── api_modules.py ├── common.py ├── fpn.py ├── modules.py ├── networks.py ├── pretrain_models.py ├── res16unet.py └── resunet.py ├── train_S3DIS.py └── train_ScanNet.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | 3 | pointdc_mk/ckpt 4 | pointdc_mk/ckpt_old 5 | pointdc_mk/data 6 | 7 | */__pycache__ 8 | */*/__pycache__ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![arXiv](https://img.shields.io/badge/arXiv-2304.08965-b31b1b.svg)](https://arxiv.org/abs/2304.08965) 2 | [![License CC BY-NC-SA 4.0](https://img.shields.io/badge/license-CC4.0-blue.svg)](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) 3 | 4 | ## PointDC: Unsupervised Semantic Segmentation of 3D Point Clouds via Cross-modal Distillation and Super-Voxel Clustering (ICCV 2023) 5 | 6 | ### Overview 7 | 8 | We propose an unsupervised point cloud semantic segmentation framework called **PointDC**. 9 | 10 |

11 | drawing 12 |

13 | 14 | ## NOTE 15 | There are two projects deployed here. [pointdc_mk](https://github.com/SCUT-BIP-Lab/PointDC/tree/main/pointdc_mk) is based on [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine). 16 | 17 | ## TODO 18 | - [x] Release code based on Minkowski Engine and model weight files 19 | - [ ] Release code based on SpConv and model weight files 20 | - [x] Release Sparse Feature Volume files 21 | 22 | 23 | ### Citation 24 | If this paper is helpful to you, please cite: 25 | ``` 26 | @article{chen2023unsupervised, 27 | title={Unsupervised Semantic Segmentation of 3D Point Clouds via Cross-modal Distillation and Super-Voxel Clustering}, 28 | author={Chen, Zisheng and Xu, Hongbin}, 29 | journal={arXiv preprint arXiv:2304.08965}, 30 | year={2023} 31 | } 32 | ``` -------------------------------------------------------------------------------- /figs/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCUT-BIP-Lab/PointDC/99f4f30ee4e2fa2d173c3881196afec4b02bb606/figs/framework.jpg -------------------------------------------------------------------------------- /figs/scannet_visualization.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCUT-BIP-Lab/PointDC/99f4f30ee4e2fa2d173c3881196afec4b02bb606/figs/scannet_visualization.jpg -------------------------------------------------------------------------------- /pointdc_mk/README.md: -------------------------------------------------------------------------------- 1 | [![arXiv](https://img.shields.io/badge/arXiv-2304.08965-b31b1b.svg)](https://arxiv.org/abs/2304.08965) 2 | [![License CC BY-NC-SA 4.0](https://img.shields.io/badge/license-CC4.0-blue.svg)](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) 3 | 4 | ## PointDC: Unsupervised Semantic Segmentation of 3D Point Clouds via Cross-modal Distillation and Super-Voxel Clustering (ICCV 2023) 5 | 6 | ### Overview 7 | 8 | We propose an unsupervised point cloud semantic segmentation framework called **PointDC**. 9 | 10 |

11 | drawing 12 |

13 | 14 | ## NOTE 15 | This project is based on the Minkowski Engine and reuses code from [growsp](https://github.com/vLAR-group/GrowSP), but the methods used are consistent with the original paper. 16 | 17 | ## TODO 18 | - [x] Release code deployed on the ScanNet dataset and model weight files 19 | - [x] Release code deployed on the S3DIS dataset and model weight files 20 | - [x] Release Sparse Feature Volume files 21 | 22 | ## 1. Setup 23 | Setting up this project involves installing its dependencies. 24 | 25 | ### Installing dependencies 26 | To install all the dependencies, please run the following: 27 | ```shell script 28 | sudo apt install build-essential python3-dev libopenblas-dev 29 | conda env create -f env.yaml 30 | conda activate pointdc_mk 31 | pip install -U MinkowskiEngine --install-option="--blas=openblas" -v --no-deps 32 | ``` 33 | ## 2. Running the code 34 | ### 2.1 ScanNet 35 | Download the ScanNet dataset from [the official website](http://kaldir.vc.in.tum.de/scannet_benchmark/documentation). 36 | You need to sign the terms of use. Uncompress the folder and move it to 37 | `${your_ScanNet}`. 38 | - Download the superpoint (sp) feature files from [here](https://pan.baidu.com/s/1ibxoq3HyxRJa3KrnPafCWw?pwd=6666), and put them in the corresponding path. 39 | 40 | 41 | - Preparing the dataset: 42 | ```shell script 43 | python data_prepare/data_prepare_ScanNet.py --data_path ${your_ScanNet} 44 | ``` 45 | This script will preprocess ScanNet and put the results under `./data/ScanNet/processed` 46 | 47 | - Construct initial superpoints: 48 | ```shell script 49 | python data_prepare/initialSP_prepare_ScanNet.py 50 | ``` 51 | This script will construct superpoints on ScanNet and put them under `./data/ScanNet/initial_superpoints` 52 | 53 | - Training: 54 | ```shell script 55 | CUDA_VISIBLE_DEVICES=0, python train_ScanNet.py --expname ${your_experiment_name} 56 | ``` 57 | The output model and log file will be saved in `./ckpt/ScanNet` by default. 58 | 59 | - Evaluating: 60 | Revise the experiment name (```expnames=[eval_experiment_name]```) in line 141 of `eval_ScanNet.py`. 61 | ```shell script 62 | CUDA_VISIBLE_DEVICES=0, python eval_ScanNet.py 63 | ``` 64 | 65 | ### 2.2 S3DIS 66 | Download the S3DIS dataset from [the official website](https://docs.google.com/forms/d/e/1FAIpQLScDimvNMCGhy_rmBA2gHfDu3naktRm6A8BPwAWWDv-Uhm6Shw/viewform?c=0&w=1&pli=1); download the file named **Stanford3dDataset_v1.2.zip**. 67 | Uncompress the folder and move it to `${your_S3DIS}`. Note that there is an error in line 180389 of the file Area_5/hallway_6/Annotations/ceiling_1.txt; it needs to be fixed manually (a fix sketch is given after the note below). 68 | - Download the sp feature and sp files from [here](https://pan.baidu.com/s/1ibxoq3HyxRJa3KrnPafCWw?pwd=6666), and put them in the corresponding path. 69 | >Due to the randomness in the construction of super-voxels, different super-voxels will lead to different super-voxel features. Therefore, we provide both the super-voxel features and the corresponding super-voxels. This only affects the distillation stage. 
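The following is a minimal sketch (not part of this repo) for cleaning the bad character mentioned above; it assumes the file is otherwise plain ASCII, and `${your_S3DIS}` is a placeholder for your dataset root:
```python
# Minimal sketch: strip the stray non-printable byte from ceiling_1.txt.
# '${your_S3DIS}' below is a placeholder for your dataset root.
path = '${your_S3DIS}/Area_5/hallway_6/Annotations/ceiling_1.txt'
with open(path, 'rb') as f:
    raw = f.read()
# Keep newlines and printable ASCII; line 180389 contains a control
# character that breaks numeric parsing during preprocessing.
clean = bytes(b for b in raw if b in (0x0A, 0x0D) or 0x20 <= b <= 0x7E)
with open(path, 'wb') as f:
    f.write(clean)
```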
70 | 71 | - Preparing the dataset: 72 | ```shell script 73 | python data_prepare/data_prepare_S3DIS.py --data_path ${your_S3DIS} 74 | ``` 75 | This script will preprocess S3DIS and put the results under `./data/S3DIS/processed` 76 | 77 | - Construct initial superpoints: 78 | ```shell script 79 | python data_prepare/initialSP_prepare_S3DIS.py 80 | ``` 81 | This script will construct superpoints on S3DIS and put them under `./data/S3DIS/initial_superpoints` 82 | 83 | - Training: 84 | ```shell script 85 | CUDA_VISIBLE_DEVICES=0, python train_S3DIS.py --expname ${your_experiment_name} 86 | ``` 87 | The output model and log file will be saved in `./ckpt/S3DIS` by default. 88 | 89 | - Evaluating: 90 | Revise the experiment name (`expnames=[eval_experiment_name]`) in `eval_S3DIS.py`. 91 | ```shell script 92 | CUDA_VISIBLE_DEVICES=0, python eval_S3DIS.py 93 | ``` 94 | 95 | ## 3. Model Weight Files 96 | The trained models and other processed files can be found [here](https://pan.baidu.com/s/1ibxoq3HyxRJa3KrnPafCWw?pwd=6666). 97 | 98 | ## Acknowledgement 99 | [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine) 100 | 101 | [growsp](https://github.com/vLAR-group/GrowSP) 102 | -------------------------------------------------------------------------------- /pointdc_mk/data_prepare/S3DIS_anno_paths.txt: -------------------------------------------------------------------------------- 1 | Area_1/conferenceRoom_1/Annotations 2 | Area_1/conferenceRoom_2/Annotations 3 | Area_1/copyRoom_1/Annotations 4 | Area_1/hallway_1/Annotations 5 | Area_1/hallway_2/Annotations 6 | Area_1/hallway_3/Annotations 7 | Area_1/hallway_4/Annotations 8 | Area_1/hallway_5/Annotations 9 | Area_1/hallway_6/Annotations 10 | Area_1/hallway_7/Annotations 11 | Area_1/hallway_8/Annotations 12 | Area_1/office_10/Annotations 13 | Area_1/office_11/Annotations 14 | Area_1/office_12/Annotations 15 | Area_1/office_13/Annotations 16 | Area_1/office_14/Annotations 17 | Area_1/office_15/Annotations 18 | Area_1/office_16/Annotations 19 | Area_1/office_17/Annotations 20 | Area_1/office_18/Annotations 21 | Area_1/office_19/Annotations 22 | Area_1/office_1/Annotations 23 | Area_1/office_20/Annotations 24 | Area_1/office_21/Annotations 25 | Area_1/office_22/Annotations 26 | Area_1/office_23/Annotations 27 | Area_1/office_24/Annotations 28 | Area_1/office_25/Annotations 29 | Area_1/office_26/Annotations 30 | Area_1/office_27/Annotations 31 | Area_1/office_28/Annotations 32 | Area_1/office_29/Annotations 33 | Area_1/office_2/Annotations 34 | Area_1/office_30/Annotations 35 | Area_1/office_31/Annotations 36 | Area_1/office_3/Annotations 37 | Area_1/office_4/Annotations 38 | Area_1/office_5/Annotations 39 | Area_1/office_6/Annotations 40 | Area_1/office_7/Annotations 41 | Area_1/office_8/Annotations 42 | Area_1/office_9/Annotations 43 | Area_1/pantry_1/Annotations 44 | Area_1/WC_1/Annotations 45 | Area_2/auditorium_1/Annotations 46 | Area_2/auditorium_2/Annotations 47 | Area_2/conferenceRoom_1/Annotations 48 | Area_2/hallway_10/Annotations 49 | Area_2/hallway_11/Annotations 50 | Area_2/hallway_12/Annotations 51 | Area_2/hallway_1/Annotations 52 | Area_2/hallway_2/Annotations 53 | Area_2/hallway_3/Annotations 54 | Area_2/hallway_4/Annotations 55 | Area_2/hallway_5/Annotations 56 | Area_2/hallway_6/Annotations 57 | Area_2/hallway_7/Annotations 58 | Area_2/hallway_8/Annotations 59 | Area_2/hallway_9/Annotations 60 | Area_2/office_10/Annotations 61 | Area_2/office_11/Annotations 62 | Area_2/office_12/Annotations 63 | Area_2/office_13/Annotations 64 | Area_2/office_14/Annotations 65 |
Area_2/office_1/Annotations 66 | Area_2/office_2/Annotations 67 | Area_2/office_3/Annotations 68 | Area_2/office_4/Annotations 69 | Area_2/office_5/Annotations 70 | Area_2/office_6/Annotations 71 | Area_2/office_7/Annotations 72 | Area_2/office_8/Annotations 73 | Area_2/office_9/Annotations 74 | Area_2/storage_1/Annotations 75 | Area_2/storage_2/Annotations 76 | Area_2/storage_3/Annotations 77 | Area_2/storage_4/Annotations 78 | Area_2/storage_5/Annotations 79 | Area_2/storage_6/Annotations 80 | Area_2/storage_7/Annotations 81 | Area_2/storage_8/Annotations 82 | Area_2/storage_9/Annotations 83 | Area_2/WC_1/Annotations 84 | Area_2/WC_2/Annotations 85 | Area_3/conferenceRoom_1/Annotations 86 | Area_3/hallway_1/Annotations 87 | Area_3/hallway_2/Annotations 88 | Area_3/hallway_3/Annotations 89 | Area_3/hallway_4/Annotations 90 | Area_3/hallway_5/Annotations 91 | Area_3/hallway_6/Annotations 92 | Area_3/lounge_1/Annotations 93 | Area_3/lounge_2/Annotations 94 | Area_3/office_10/Annotations 95 | Area_3/office_1/Annotations 96 | Area_3/office_2/Annotations 97 | Area_3/office_3/Annotations 98 | Area_3/office_4/Annotations 99 | Area_3/office_5/Annotations 100 | Area_3/office_6/Annotations 101 | Area_3/office_7/Annotations 102 | Area_3/office_8/Annotations 103 | Area_3/office_9/Annotations 104 | Area_3/storage_1/Annotations 105 | Area_3/storage_2/Annotations 106 | Area_3/WC_1/Annotations 107 | Area_3/WC_2/Annotations 108 | Area_4/conferenceRoom_1/Annotations 109 | Area_4/conferenceRoom_2/Annotations 110 | Area_4/conferenceRoom_3/Annotations 111 | Area_4/hallway_10/Annotations 112 | Area_4/hallway_11/Annotations 113 | Area_4/hallway_12/Annotations 114 | Area_4/hallway_13/Annotations 115 | Area_4/hallway_14/Annotations 116 | Area_4/hallway_1/Annotations 117 | Area_4/hallway_2/Annotations 118 | Area_4/hallway_3/Annotations 119 | Area_4/hallway_4/Annotations 120 | Area_4/hallway_5/Annotations 121 | Area_4/hallway_6/Annotations 122 | Area_4/hallway_7/Annotations 123 | Area_4/hallway_8/Annotations 124 | Area_4/hallway_9/Annotations 125 | Area_4/lobby_1/Annotations 126 | Area_4/lobby_2/Annotations 127 | Area_4/office_10/Annotations 128 | Area_4/office_11/Annotations 129 | Area_4/office_12/Annotations 130 | Area_4/office_13/Annotations 131 | Area_4/office_14/Annotations 132 | Area_4/office_15/Annotations 133 | Area_4/office_16/Annotations 134 | Area_4/office_17/Annotations 135 | Area_4/office_18/Annotations 136 | Area_4/office_19/Annotations 137 | Area_4/office_1/Annotations 138 | Area_4/office_20/Annotations 139 | Area_4/office_21/Annotations 140 | Area_4/office_22/Annotations 141 | Area_4/office_2/Annotations 142 | Area_4/office_3/Annotations 143 | Area_4/office_4/Annotations 144 | Area_4/office_5/Annotations 145 | Area_4/office_6/Annotations 146 | Area_4/office_7/Annotations 147 | Area_4/office_8/Annotations 148 | Area_4/office_9/Annotations 149 | Area_4/storage_1/Annotations 150 | Area_4/storage_2/Annotations 151 | Area_4/storage_3/Annotations 152 | Area_4/storage_4/Annotations 153 | Area_4/WC_1/Annotations 154 | Area_4/WC_2/Annotations 155 | Area_4/WC_3/Annotations 156 | Area_4/WC_4/Annotations 157 | Area_5/conferenceRoom_1/Annotations 158 | Area_5/conferenceRoom_2/Annotations 159 | Area_5/conferenceRoom_3/Annotations 160 | Area_5/hallway_10/Annotations 161 | Area_5/hallway_11/Annotations 162 | Area_5/hallway_12/Annotations 163 | Area_5/hallway_13/Annotations 164 | Area_5/hallway_14/Annotations 165 | Area_5/hallway_15/Annotations 166 | Area_5/hallway_1/Annotations 167 | Area_5/hallway_2/Annotations 
168 | Area_5/hallway_3/Annotations 169 | Area_5/hallway_4/Annotations 170 | Area_5/hallway_5/Annotations 171 | Area_5/hallway_6/Annotations 172 | Area_5/hallway_7/Annotations 173 | Area_5/hallway_8/Annotations 174 | Area_5/hallway_9/Annotations 175 | Area_5/lobby_1/Annotations 176 | Area_5/office_10/Annotations 177 | Area_5/office_11/Annotations 178 | Area_5/office_12/Annotations 179 | Area_5/office_13/Annotations 180 | Area_5/office_14/Annotations 181 | Area_5/office_15/Annotations 182 | Area_5/office_16/Annotations 183 | Area_5/office_17/Annotations 184 | Area_5/office_18/Annotations 185 | Area_5/office_19/Annotations 186 | Area_5/office_1/Annotations 187 | Area_5/office_20/Annotations 188 | Area_5/office_21/Annotations 189 | Area_5/office_22/Annotations 190 | Area_5/office_23/Annotations 191 | Area_5/office_24/Annotations 192 | Area_5/office_25/Annotations 193 | Area_5/office_26/Annotations 194 | Area_5/office_27/Annotations 195 | Area_5/office_28/Annotations 196 | Area_5/office_29/Annotations 197 | Area_5/office_2/Annotations 198 | Area_5/office_30/Annotations 199 | Area_5/office_31/Annotations 200 | Area_5/office_32/Annotations 201 | Area_5/office_33/Annotations 202 | Area_5/office_34/Annotations 203 | Area_5/office_35/Annotations 204 | Area_5/office_36/Annotations 205 | Area_5/office_37/Annotations 206 | Area_5/office_38/Annotations 207 | Area_5/office_39/Annotations 208 | Area_5/office_3/Annotations 209 | Area_5/office_40/Annotations 210 | Area_5/office_41/Annotations 211 | Area_5/office_42/Annotations 212 | Area_5/office_4/Annotations 213 | Area_5/office_5/Annotations 214 | Area_5/office_6/Annotations 215 | Area_5/office_7/Annotations 216 | Area_5/office_8/Annotations 217 | Area_5/office_9/Annotations 218 | Area_5/pantry_1/Annotations 219 | Area_5/storage_1/Annotations 220 | Area_5/storage_2/Annotations 221 | Area_5/storage_3/Annotations 222 | Area_5/storage_4/Annotations 223 | Area_5/WC_1/Annotations 224 | Area_5/WC_2/Annotations 225 | Area_6/conferenceRoom_1/Annotations 226 | Area_6/copyRoom_1/Annotations 227 | Area_6/hallway_1/Annotations 228 | Area_6/hallway_2/Annotations 229 | Area_6/hallway_3/Annotations 230 | Area_6/hallway_4/Annotations 231 | Area_6/hallway_5/Annotations 232 | Area_6/hallway_6/Annotations 233 | Area_6/lounge_1/Annotations 234 | Area_6/office_10/Annotations 235 | Area_6/office_11/Annotations 236 | Area_6/office_12/Annotations 237 | Area_6/office_13/Annotations 238 | Area_6/office_14/Annotations 239 | Area_6/office_15/Annotations 240 | Area_6/office_16/Annotations 241 | Area_6/office_17/Annotations 242 | Area_6/office_18/Annotations 243 | Area_6/office_19/Annotations 244 | Area_6/office_1/Annotations 245 | Area_6/office_20/Annotations 246 | Area_6/office_21/Annotations 247 | Area_6/office_22/Annotations 248 | Area_6/office_23/Annotations 249 | Area_6/office_24/Annotations 250 | Area_6/office_25/Annotations 251 | Area_6/office_26/Annotations 252 | Area_6/office_27/Annotations 253 | Area_6/office_28/Annotations 254 | Area_6/office_29/Annotations 255 | Area_6/office_2/Annotations 256 | Area_6/office_30/Annotations 257 | Area_6/office_31/Annotations 258 | Area_6/office_32/Annotations 259 | Area_6/office_33/Annotations 260 | Area_6/office_34/Annotations 261 | Area_6/office_35/Annotations 262 | Area_6/office_36/Annotations 263 | Area_6/office_37/Annotations 264 | Area_6/office_3/Annotations 265 | Area_6/office_4/Annotations 266 | Area_6/office_5/Annotations 267 | Area_6/office_6/Annotations 268 | Area_6/office_7/Annotations 269 | Area_6/office_8/Annotations 270 | 
Area_6/office_9/Annotations 271 | Area_6/openspace_1/Annotations 272 | Area_6/pantry_1/Annotations 273 | -------------------------------------------------------------------------------- /pointdc_mk/data_prepare/S3DIS_class_names.txt: -------------------------------------------------------------------------------- 1 | ceiling 2 | floor 3 | wall 4 | beam 5 | column 6 | window 7 | door 8 | table 9 | chair 10 | sofa 11 | bookcase 12 | board 13 | clutter 14 | 15 | -------------------------------------------------------------------------------- /pointdc_mk/data_prepare/ScanNet_splits/scannetv2_test.txt: -------------------------------------------------------------------------------- 1 | scene0707_00_vh_clean_2.ply 2 | scene0708_00_vh_clean_2.ply 3 | scene0709_00_vh_clean_2.ply 4 | scene0710_00_vh_clean_2.ply 5 | scene0711_00_vh_clean_2.ply 6 | scene0712_00_vh_clean_2.ply 7 | scene0713_00_vh_clean_2.ply 8 | scene0714_00_vh_clean_2.ply 9 | scene0715_00_vh_clean_2.ply 10 | scene0716_00_vh_clean_2.ply 11 | scene0717_00_vh_clean_2.ply 12 | scene0718_00_vh_clean_2.ply 13 | scene0719_00_vh_clean_2.ply 14 | scene0720_00_vh_clean_2.ply 15 | scene0721_00_vh_clean_2.ply 16 | scene0722_00_vh_clean_2.ply 17 | scene0723_00_vh_clean_2.ply 18 | scene0724_00_vh_clean_2.ply 19 | scene0725_00_vh_clean_2.ply 20 | scene0726_00_vh_clean_2.ply 21 | scene0727_00_vh_clean_2.ply 22 | scene0728_00_vh_clean_2.ply 23 | scene0729_00_vh_clean_2.ply 24 | scene0730_00_vh_clean_2.ply 25 | scene0731_00_vh_clean_2.ply 26 | scene0732_00_vh_clean_2.ply 27 | scene0733_00_vh_clean_2.ply 28 | scene0734_00_vh_clean_2.ply 29 | scene0735_00_vh_clean_2.ply 30 | scene0736_00_vh_clean_2.ply 31 | scene0737_00_vh_clean_2.ply 32 | scene0738_00_vh_clean_2.ply 33 | scene0739_00_vh_clean_2.ply 34 | scene0740_00_vh_clean_2.ply 35 | scene0741_00_vh_clean_2.ply 36 | scene0742_00_vh_clean_2.ply 37 | scene0743_00_vh_clean_2.ply 38 | scene0744_00_vh_clean_2.ply 39 | scene0745_00_vh_clean_2.ply 40 | scene0746_00_vh_clean_2.ply 41 | scene0747_00_vh_clean_2.ply 42 | scene0748_00_vh_clean_2.ply 43 | scene0749_00_vh_clean_2.ply 44 | scene0750_00_vh_clean_2.ply 45 | scene0751_00_vh_clean_2.ply 46 | scene0752_00_vh_clean_2.ply 47 | scene0753_00_vh_clean_2.ply 48 | scene0754_00_vh_clean_2.ply 49 | scene0755_00_vh_clean_2.ply 50 | scene0756_00_vh_clean_2.ply 51 | scene0757_00_vh_clean_2.ply 52 | scene0758_00_vh_clean_2.ply 53 | scene0759_00_vh_clean_2.ply 54 | scene0760_00_vh_clean_2.ply 55 | scene0761_00_vh_clean_2.ply 56 | scene0762_00_vh_clean_2.ply 57 | scene0763_00_vh_clean_2.ply 58 | scene0764_00_vh_clean_2.ply 59 | scene0765_00_vh_clean_2.ply 60 | scene0766_00_vh_clean_2.ply 61 | scene0767_00_vh_clean_2.ply 62 | scene0768_00_vh_clean_2.ply 63 | scene0769_00_vh_clean_2.ply 64 | scene0770_00_vh_clean_2.ply 65 | scene0771_00_vh_clean_2.ply 66 | scene0772_00_vh_clean_2.ply 67 | scene0773_00_vh_clean_2.ply 68 | scene0774_00_vh_clean_2.ply 69 | scene0775_00_vh_clean_2.ply 70 | scene0776_00_vh_clean_2.ply 71 | scene0777_00_vh_clean_2.ply 72 | scene0778_00_vh_clean_2.ply 73 | scene0779_00_vh_clean_2.ply 74 | scene0780_00_vh_clean_2.ply 75 | scene0781_00_vh_clean_2.ply 76 | scene0782_00_vh_clean_2.ply 77 | scene0783_00_vh_clean_2.ply 78 | scene0784_00_vh_clean_2.ply 79 | scene0785_00_vh_clean_2.ply 80 | scene0786_00_vh_clean_2.ply 81 | scene0787_00_vh_clean_2.ply 82 | scene0788_00_vh_clean_2.ply 83 | scene0789_00_vh_clean_2.ply 84 | scene0790_00_vh_clean_2.ply 85 | scene0791_00_vh_clean_2.ply 86 | scene0792_00_vh_clean_2.ply 87 | scene0793_00_vh_clean_2.ply 88 | 
scene0794_00_vh_clean_2.ply 89 | scene0795_00_vh_clean_2.ply 90 | scene0796_00_vh_clean_2.ply 91 | scene0797_00_vh_clean_2.ply 92 | scene0798_00_vh_clean_2.ply 93 | scene0799_00_vh_clean_2.ply 94 | scene0800_00_vh_clean_2.ply 95 | scene0801_00_vh_clean_2.ply 96 | scene0802_00_vh_clean_2.ply 97 | scene0803_00_vh_clean_2.ply 98 | scene0804_00_vh_clean_2.ply 99 | scene0805_00_vh_clean_2.ply 100 | scene0806_00_vh_clean_2.ply 101 | -------------------------------------------------------------------------------- /pointdc_mk/data_prepare/ScanNet_splits/scannetv2_val.txt: -------------------------------------------------------------------------------- 1 | scene0568_00.ply 2 | scene0568_01.ply 3 | scene0568_02.ply 4 | scene0304_00.ply 5 | scene0488_00.ply 6 | scene0488_01.ply 7 | scene0412_00.ply 8 | scene0412_01.ply 9 | scene0217_00.ply 10 | scene0019_00.ply 11 | scene0019_01.ply 12 | scene0414_00.ply 13 | scene0575_00.ply 14 | scene0575_01.ply 15 | scene0575_02.ply 16 | scene0426_00.ply 17 | scene0426_01.ply 18 | scene0426_02.ply 19 | scene0426_03.ply 20 | scene0549_00.ply 21 | scene0549_01.ply 22 | scene0578_00.ply 23 | scene0578_01.ply 24 | scene0578_02.ply 25 | scene0665_00.ply 26 | scene0665_01.ply 27 | scene0050_00.ply 28 | scene0050_01.ply 29 | scene0050_02.ply 30 | scene0257_00.ply 31 | scene0025_00.ply 32 | scene0025_01.ply 33 | scene0025_02.ply 34 | scene0583_00.ply 35 | scene0583_01.ply 36 | scene0583_02.ply 37 | scene0701_00.ply 38 | scene0701_01.ply 39 | scene0701_02.ply 40 | scene0580_00.ply 41 | scene0580_01.ply 42 | scene0565_00.ply 43 | scene0169_00.ply 44 | scene0169_01.ply 45 | scene0655_00.ply 46 | scene0655_01.ply 47 | scene0655_02.ply 48 | scene0063_00.ply 49 | scene0221_00.ply 50 | scene0221_01.ply 51 | scene0591_00.ply 52 | scene0591_01.ply 53 | scene0591_02.ply 54 | scene0678_00.ply 55 | scene0678_01.ply 56 | scene0678_02.ply 57 | scene0462_00.ply 58 | scene0427_00.ply 59 | scene0595_00.ply 60 | scene0193_00.ply 61 | scene0193_01.ply 62 | scene0164_00.ply 63 | scene0164_01.ply 64 | scene0164_02.ply 65 | scene0164_03.ply 66 | scene0598_00.ply 67 | scene0598_01.ply 68 | scene0598_02.ply 69 | scene0599_00.ply 70 | scene0599_01.ply 71 | scene0599_02.ply 72 | scene0328_00.ply 73 | scene0300_00.ply 74 | scene0300_01.ply 75 | scene0354_00.ply 76 | scene0458_00.ply 77 | scene0458_01.ply 78 | scene0423_00.ply 79 | scene0423_01.ply 80 | scene0423_02.ply 81 | scene0307_00.ply 82 | scene0307_01.ply 83 | scene0307_02.ply 84 | scene0606_00.ply 85 | scene0606_01.ply 86 | scene0606_02.ply 87 | scene0432_00.ply 88 | scene0432_01.ply 89 | scene0608_00.ply 90 | scene0608_01.ply 91 | scene0608_02.ply 92 | scene0651_00.ply 93 | scene0651_01.ply 94 | scene0651_02.ply 95 | scene0430_00.ply 96 | scene0430_01.ply 97 | scene0689_00.ply 98 | scene0357_00.ply 99 | scene0357_01.ply 100 | scene0574_00.ply 101 | scene0574_01.ply 102 | scene0574_02.ply 103 | scene0329_00.ply 104 | scene0329_01.ply 105 | scene0329_02.ply 106 | scene0153_00.ply 107 | scene0153_01.ply 108 | scene0616_00.ply 109 | scene0616_01.ply 110 | scene0671_00.ply 111 | scene0671_01.ply 112 | scene0618_00.ply 113 | scene0382_00.ply 114 | scene0382_01.ply 115 | scene0490_00.ply 116 | scene0621_00.ply 117 | scene0607_00.ply 118 | scene0607_01.ply 119 | scene0149_00.ply 120 | scene0695_00.ply 121 | scene0695_01.ply 122 | scene0695_02.ply 123 | scene0695_03.ply 124 | scene0389_00.ply 125 | scene0377_00.ply 126 | scene0377_01.ply 127 | scene0377_02.ply 128 | scene0342_00.ply 129 | scene0139_00.ply 130 | scene0629_00.ply 131 | 
scene0629_01.ply 132 | scene0629_02.ply 133 | scene0496_00.ply 134 | scene0633_00.ply 135 | scene0633_01.ply 136 | scene0518_00.ply 137 | scene0652_00.ply 138 | scene0406_00.ply 139 | scene0406_01.ply 140 | scene0406_02.ply 141 | scene0144_00.ply 142 | scene0144_01.ply 143 | scene0494_00.ply 144 | scene0278_00.ply 145 | scene0278_01.ply 146 | scene0316_00.ply 147 | scene0609_00.ply 148 | scene0609_01.ply 149 | scene0609_02.ply 150 | scene0609_03.ply 151 | scene0084_00.ply 152 | scene0084_01.ply 153 | scene0084_02.ply 154 | scene0696_00.ply 155 | scene0696_01.ply 156 | scene0696_02.ply 157 | scene0351_00.ply 158 | scene0351_01.ply 159 | scene0643_00.ply 160 | scene0644_00.ply 161 | scene0645_00.ply 162 | scene0645_01.ply 163 | scene0645_02.ply 164 | scene0081_00.ply 165 | scene0081_01.ply 166 | scene0081_02.ply 167 | scene0647_00.ply 168 | scene0647_01.ply 169 | scene0535_00.ply 170 | scene0353_00.ply 171 | scene0353_01.ply 172 | scene0353_02.ply 173 | scene0559_00.ply 174 | scene0559_01.ply 175 | scene0559_02.ply 176 | scene0593_00.ply 177 | scene0593_01.ply 178 | scene0246_00.ply 179 | scene0653_00.ply 180 | scene0653_01.ply 181 | scene0064_00.ply 182 | scene0064_01.ply 183 | scene0356_00.ply 184 | scene0356_01.ply 185 | scene0356_02.ply 186 | scene0030_00.ply 187 | scene0030_01.ply 188 | scene0030_02.ply 189 | scene0222_00.ply 190 | scene0222_01.ply 191 | scene0338_00.ply 192 | scene0338_01.ply 193 | scene0338_02.ply 194 | scene0378_00.ply 195 | scene0378_01.ply 196 | scene0378_02.ply 197 | scene0660_00.ply 198 | scene0553_00.ply 199 | scene0553_01.ply 200 | scene0553_02.ply 201 | scene0527_00.ply 202 | scene0663_00.ply 203 | scene0663_01.ply 204 | scene0663_02.ply 205 | scene0664_00.ply 206 | scene0664_01.ply 207 | scene0664_02.ply 208 | scene0334_00.ply 209 | scene0334_01.ply 210 | scene0334_02.ply 211 | scene0046_00.ply 212 | scene0046_01.ply 213 | scene0046_02.ply 214 | scene0203_00.ply 215 | scene0203_01.ply 216 | scene0203_02.ply 217 | scene0088_00.ply 218 | scene0088_01.ply 219 | scene0088_02.ply 220 | scene0088_03.ply 221 | scene0086_00.ply 222 | scene0086_01.ply 223 | scene0086_02.ply 224 | scene0670_00.ply 225 | scene0670_01.ply 226 | scene0256_00.ply 227 | scene0256_01.ply 228 | scene0256_02.ply 229 | scene0249_00.ply 230 | scene0441_00.ply 231 | scene0658_00.ply 232 | scene0704_00.ply 233 | scene0704_01.ply 234 | scene0187_00.ply 235 | scene0187_01.ply 236 | scene0131_00.ply 237 | scene0131_01.ply 238 | scene0131_02.ply 239 | scene0207_00.ply 240 | scene0207_01.ply 241 | scene0207_02.ply 242 | scene0461_00.ply 243 | scene0011_00.ply 244 | scene0011_01.ply 245 | scene0343_00.ply 246 | scene0251_00.ply 247 | scene0077_00.ply 248 | scene0077_01.ply 249 | scene0684_00.ply 250 | scene0684_01.ply 251 | scene0550_00.ply 252 | scene0686_00.ply 253 | scene0686_01.ply 254 | scene0686_02.ply 255 | scene0208_00.ply 256 | scene0500_00.ply 257 | scene0500_01.ply 258 | scene0552_00.ply 259 | scene0552_01.ply 260 | scene0648_00.ply 261 | scene0648_01.ply 262 | scene0435_00.ply 263 | scene0435_01.ply 264 | scene0435_02.ply 265 | scene0435_03.ply 266 | scene0690_00.ply 267 | scene0690_01.ply 268 | scene0693_00.ply 269 | scene0693_01.ply 270 | scene0693_02.ply 271 | scene0700_00.ply 272 | scene0700_01.ply 273 | scene0700_02.ply 274 | scene0699_00.ply 275 | scene0231_00.ply 276 | scene0231_01.ply 277 | scene0231_02.ply 278 | scene0697_00.ply 279 | scene0697_01.ply 280 | scene0697_02.ply 281 | scene0697_03.ply 282 | scene0474_00.ply 283 | scene0474_01.ply 284 | scene0474_02.ply 285 | 
scene0474_03.ply 286 | scene0474_04.ply 287 | scene0474_05.ply 288 | scene0355_00.ply 289 | scene0355_01.ply 290 | scene0146_00.ply 291 | scene0146_01.ply 292 | scene0146_02.ply 293 | scene0196_00.ply 294 | scene0702_00.ply 295 | scene0702_01.ply 296 | scene0702_02.ply 297 | scene0314_00.ply 298 | scene0277_00.ply 299 | scene0277_01.ply 300 | scene0277_02.ply 301 | scene0095_00.ply 302 | scene0095_01.ply 303 | scene0015_00.ply 304 | scene0100_00.ply 305 | scene0100_01.ply 306 | scene0100_02.ply 307 | scene0558_00.ply 308 | scene0558_01.ply 309 | scene0558_02.ply 310 | scene0685_00.ply 311 | scene0685_01.ply 312 | scene0685_02.ply 313 | -------------------------------------------------------------------------------- /pointdc_mk/data_prepare/data_prepare_S3DIS.py: -------------------------------------------------------------------------------- 1 | import MinkowskiEngine as ME 2 | from os.path import join, exists, dirname, abspath 3 | import numpy as np 4 | import pandas as pd 5 | import os, sys, glob 6 | import argparse 7 | from tqdm import tqdm 8 | 9 | BASE_DIR = dirname(abspath(__file__)) 10 | ROOT_DIR = dirname(BASE_DIR) 11 | sys.path.append(BASE_DIR) 12 | sys.path.append(ROOT_DIR) 13 | from lib.helper_ply import write_ply 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--data_path', type=str, default='data/Stanford3dDataset_v1.2', help='raw data path') 17 | parser.add_argument('--processed_data_path', type=str, default='data/S3DIS/processed') 18 | args = parser.parse_args() 19 | 20 | args.data_path = join(ROOT_DIR, args.data_path) 21 | args.processed_data_path = join(ROOT_DIR, args.processed_data_path) 22 | 23 | anno_paths = [line.rstrip() for line in open(join(BASE_DIR, 'S3DIS_anno_paths.txt'))] 24 | anno_paths = [join(args.data_path, p) for p in anno_paths] 25 | 26 | gt_class = [x.rstrip() for x in open(join(BASE_DIR, 'S3DIS_class_names.txt'))] 27 | gt_class2label = {cls: i for i, cls in enumerate(gt_class)} 28 | 29 | sub_grid_size = 0.010 30 | if not exists(args.processed_data_path): 31 | os.makedirs(args.processed_data_path) 32 | out_format = '.ply' 33 | 34 | def convert_pc2ply(anno_path, file_name): 35 | sub_ply_file = join(args.processed_data_path, file_name) 36 | if os.path.exists(sub_ply_file): return 37 | data_list = [] 38 | 39 | for f in glob.glob(join(anno_path, '*.txt')): 40 | class_name = os.path.basename(f).split('_')[0] 41 | if class_name not in gt_class: # note: in some room there is 'staris' class.. 42 | class_name = 'clutter' 43 | pc = pd.read_csv(f, header=None, delim_whitespace=True).values 44 | labels = np.ones((pc.shape[0], 1)) * gt_class2label[class_name] 45 | data_list.append(np.concatenate([pc, labels], 1)) 46 | 47 | pc_info = np.concatenate(data_list, 0) 48 | 49 | coords = pc_info[:, :3] 50 | colors = pc_info[:, 3:6].astype(np.uint8) 51 | labels = pc_info[:, 6] 52 | 53 | _, _, collabels, inds = ME.utils.sparse_quantize(np.ascontiguousarray(coords), colors, labels, return_index=True, ignore_label=-1, quantization_size=sub_grid_size) 54 | sub_coords, sub_colors, sub_labels = coords[inds], colors[inds], collabels 55 | 56 | write_ply(sub_ply_file, [sub_coords, sub_colors, sub_labels[:,None]], ['x', 'y', 'z', 'red', 'green', 'blue', 'class']) 57 | 58 | print('start preprocess') 59 | # Note: there is an extra character in the v1.2 data in Area_5/hallway_6. It's fixed manually. 
60 | anno_paths_bar = tqdm(anno_paths) 61 | for annotation_path in anno_paths_bar: 62 | # print(annotation_path) 63 | anno_paths_bar.set_description(annotation_path.split('/')[-3] + '_' + annotation_path.split('/')[-2]) 64 | elements = str(annotation_path).split('/') 65 | out_file_name = elements[-3] + '_' + elements[-2] + out_format 66 | convert_pc2ply(annotation_path, out_file_name) 67 | -------------------------------------------------------------------------------- /pointdc_mk/data_prepare/data_prepare_ScanNet.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from os.path import join, exists, dirname, abspath 3 | import os, sys, glob 4 | 5 | import numpy as np 6 | BASE_DIR = dirname(abspath(__file__)) 7 | ROOT_DIR = dirname(BASE_DIR) 8 | sys.path.append(BASE_DIR) 9 | sys.path.append(ROOT_DIR) 10 | from lib.helper_ply import read_ply, write_ply 11 | from concurrent.futures import ProcessPoolExecutor 12 | import argparse 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--data_path', type=str, default='data/ScanNet/', help='raw data path') 16 | parser.add_argument('--processed_data_path', type=str, default='data/ScanNet/') 17 | args = parser.parse_args() 18 | 19 | SCANNET_RAW_PATH = Path(join(ROOT_DIR, args.data_path)) 20 | SCANNET_OUT_PATH = Path(join(ROOT_DIR, args.processed_data_path)) 21 | TRAIN_DEST = 'train' 22 | TEST_DEST = 'test' 23 | SUBSETS = {TRAIN_DEST: 'scans', TEST_DEST: 'scans_test'} 24 | POINTCLOUD_FILE = '_vh_clean_2.ply' 25 | BUGS = { 26 | 'scene0270_00': 50, 27 | 'scene0270_02': 50, 28 | 'scene0384_00': 149, 29 | } 30 | CLASS_LABELS = ('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 31 | 'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator', 32 | 'shower curtain', 'toilet', 'sink', 'bathtub', 'otherfurniture') 33 | VALID_CLASS_IDS = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39) 34 | NUM_LABELS = 41 35 | IGNORE_LABELS = tuple(set(range(41)) - set(VALID_CLASS_IDS)) 36 | 37 | ''' Set Invalid Label to -1''' 38 | label_map = {} 39 | n_used = 0 40 | for l in range(NUM_LABELS): 41 | if l in IGNORE_LABELS: 42 | label_map[l] = -1#ignore label 43 | else: 44 | label_map[l] = n_used 45 | n_used += 1 46 | 47 | def handle_process(path): 48 | f = Path(path.split(',')[0]) 49 | phase_out_path = Path(path.split(',')[1]) 50 | # pointcloud = read_ply(f) 51 | data = read_ply(str(f), triangular_mesh=True)[0] 52 | coords = np.vstack((data['x'], data['y'], data['z'])).T 53 | colors = np.vstack((data['red'], data['green'], data['blue'])).T 54 | # Load label file. 55 | label_f = f.parent / (f.stem + '.labels' + f.suffix) 56 | if label_f.is_file(): 57 | label = read_ply(str(label_f), triangular_mesh=True)[0]['label'].T.squeeze() 58 | else: # Label may not exist in test case. 
59 | label = -np.zeros(data.shape) 60 | out_f = phase_out_path / (f.name[:-len(POINTCLOUD_FILE)] + f.suffix) 61 | 62 | '''Fix Data Bug''' 63 | for item, bug_index in BUGS.items(): 64 | if item in path: 65 | print('Fixing {} bugged label'.format(item)) 66 | bug_mask = label == bug_index 67 | label[bug_mask] = 0 68 | 69 | label = np.array([label_map[x] for x in label]) 70 | write_ply(str(out_f), [coords.astype(np.float64), colors, label[:, None].astype(np.float64)], ['x', 'y', 'z', 'red', 'green', 'blue', 'class']) 71 | 72 | 73 | print('start preprocess') 74 | 75 | path_list = [] 76 | for out_path, in_path in SUBSETS.items(): 77 | phase_out_path = SCANNET_OUT_PATH / out_path 78 | # phase_out_path = SCANNET_OUT_PATH 79 | phase_out_path.mkdir(parents=True, exist_ok=True) 80 | for f in (SCANNET_RAW_PATH / in_path).glob('*/*' + POINTCLOUD_FILE): 81 | path_list.append(str(f) + ',' + str(phase_out_path)) 82 | 83 | pool = ProcessPoolExecutor(max_workers=30) 84 | result = list(pool.map(handle_process, path_list)) -------------------------------------------------------------------------------- /pointdc_mk/data_prepare/initialSP_prepare_S3DIS.py: -------------------------------------------------------------------------------- 1 | from pclpy import pcl 2 | import pclpy 3 | import numpy as np 4 | from scipy import stats 5 | import os 6 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 7 | from os.path import join, exists, dirname, abspath 8 | import sys, glob 9 | 10 | BASE_DIR = dirname(abspath(__file__)) 11 | ROOT_DIR = dirname(BASE_DIR) 12 | sys.path.append(BASE_DIR) 13 | sys.path.append(ROOT_DIR) 14 | from lib.helper_ply import read_ply, write_ply 15 | import time 16 | import MinkowskiEngine as ME 17 | import matplotlib.pyplot as plt 18 | from concurrent.futures import ProcessPoolExecutor 19 | from pathlib import Path 20 | from tqdm import tqdm 21 | 22 | colormap = [] 23 | for _ in range(1000): 24 | for k in range(12): 25 | colormap.append(plt.cm.Set3(k)) 26 | for k in range(9): 27 | colormap.append(plt.cm.Set1(k)) 28 | for k in range(8): 29 | colormap.append(plt.cm.Set2(k)) 30 | colormap.append((0, 0, 0, 0)) 31 | colormap = np.array(colormap) 32 | import argparse 33 | 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--input_path', type=str, default='data/S3DIS/processed', help='raw data path') 36 | parser.add_argument('--sp_path', type=str, default='data/S3DIS/initial_superpoints') 37 | args = parser.parse_args() 38 | 39 | args.input_path = join(ROOT_DIR, args.input_path) 40 | args.sp_path = join(ROOT_DIR, args.sp_path) 41 | 42 | # ignore_label = 12 43 | voxel_size = 0.05 44 | vis = True 45 | 46 | def supervoxel_clustering(coords, rgb=None): 47 | pc = pcl.PointCloud.PointXYZRGBA(coords, rgb) 48 | normals = pc.compute_normals(radius=3, num_threads=2) 49 | vox = pcl.segmentation.SupervoxelClustering.PointXYZRGBA(voxel_resolution=1, seed_resolution=10) 50 | vox.setInputCloud(pc) 51 | vox.setNormalCloud(normals) 52 | vox.setSpatialImportance(0.4) 53 | vox.setNormalImportance(1) 54 | vox.setColorImportance(0.2) 55 | output = pcl.vectors.map_uint32t_PointXYZRGBA() 56 | vox.extract(output) 57 | return list(output.items()) 58 | 59 | def region_growing_simple(coords): 60 | pc = pcl.PointCloud.PointXYZ(coords) 61 | normals = pc.compute_normals(radius=3, num_threads=2) 62 | clusters = pclpy.region_growing(pc, normals=normals, min_size=1, max_size=100000, n_neighbours=15, 63 | smooth_threshold=3, curvature_threshold=1, residual_threshold=1) 64 | return clusters, normals.normals 65 | 66 | 67 | def 
construct_superpoints(path): 68 | f = Path(path) 69 | data = read_ply(f) 70 | coords = np.vstack((data['x'], data['y'], data['z'])).T.copy() 71 | feats = np.vstack((data['red'], data['green'], data['blue'])).T.copy() 72 | labels = data['class'].copy() 73 | coords = coords.astype(np.float32) 74 | coords -= coords.mean(0) 75 | 76 | time_start = time.time() 77 | '''Voxelize''' 78 | scale = 1 / voxel_size 79 | coords = np.floor(coords * scale) 80 | coords, feats, labels, unique_map, inverse_map = ME.utils.sparse_quantize(np.ascontiguousarray(coords), 81 | feats, labels=labels, ignore_label=-1, return_index=True, return_inverse=True) 82 | coords = coords.numpy().astype(np.float32) 83 | 84 | '''VCCS''' 85 | out = supervoxel_clustering(coords, feats) 86 | voxel_idx = -np.ones_like(labels) 87 | voxel_num = 0 88 | for voxel in range(len(out)): 89 | if out[voxel][1].voxels_.xyz.shape[0] >= 0: 90 | for xyz_voxel in out[voxel][1].voxels_.xyz: 91 | index_colum = np.where((xyz_voxel == coords).all(1)) 92 | voxel_idx[index_colum] = voxel_num 93 | voxel_num += 1 94 | 95 | '''Region Growing''' 96 | clusters = region_growing_simple(coords)[0] 97 | region_idx = -1 * np.ones_like(labels) 98 | for region in range(len(clusters)): 99 | for point_idx in clusters[region].indices: 100 | region_idx[point_idx] = region 101 | 102 | '''Merging''' 103 | merged = -np.ones_like(labels) 104 | voxel_idx[voxel_idx != -1] += len(clusters) 105 | for v in np.unique(voxel_idx): 106 | if v != -1: 107 | voxel_mask = v == voxel_idx 108 | voxel2region = region_idx[voxel_mask] ### count which regions appear in the current voxel 109 | dominant_region = stats.mode(voxel2region)[0][0] 110 | if (dominant_region == voxel2region).sum() > voxel2region.shape[0] * 0.5: 111 | merged[voxel_mask] = dominant_region 112 | else: 113 | merged[voxel_mask] = v 114 | 115 | '''Make Superpoint Labels Continuous''' 116 | sp_labels = -np.ones_like(merged) 117 | count_num = 0 118 | for m in np.unique(merged): 119 | if m != -1: 120 | sp_labels[merged == m] = count_num 121 | count_num += 1 122 | 123 | '''ReProject to Input Point Cloud''' 124 | out_sp_labels = sp_labels[inverse_map] 125 | out_coords = np.vstack((data['x'], data['y'], data['z'])).T 126 | out_labels = data['class'].squeeze() 127 | # 128 | if not exists(args.sp_path): 129 | os.makedirs(args.sp_path) 130 | np.save(args.sp_path + '/' + f.name[:-4] + '_superpoint.npy', out_sp_labels) 131 | 132 | if vis: 133 | vis_path = args.sp_path +'/vis/' 134 | if not os.path.exists(vis_path): 135 | os.makedirs(vis_path) 136 | colors = np.zeros_like(out_coords) 137 | for p in range(colors.shape[0]): 138 | colors[p] = 255 * (colormap[out_sp_labels[p].astype(np.int32)])[:3] 139 | colors = colors.astype(np.uint8) 140 | write_ply(vis_path + '/' + f.name, [out_coords, colors], ['x', 'y', 'z', 'red', 'green', 'blue']) 141 | 142 | sp2gt = -np.ones_like(out_labels) 143 | for sp in np.unique(out_sp_labels): 144 | if sp != -1: 145 | sp_mask = sp == out_sp_labels 146 | sp2gt[sp_mask] = stats.mode(out_labels[sp_mask])[0][0] 147 | 148 | print('completed scene: {}, used time: {:.2f}s'.format(f.name, time.time() - time_start)) 149 | return (out_labels, sp2gt) 150 | 151 | 152 | 153 | print('start constructing initial superpoints') 154 | path_list = [] 155 | folders = sorted(glob.glob(args.input_path + '/*.ply')) 156 | for _, file in enumerate(folders): 157 | path_list.append(file) 158 | 159 | # construct_superpoints(path_list[0])  # debug: process a single scene 160 | pool = ProcessPoolExecutor(max_workers=40) 161 | result = 
list(pool.map(construct_superpoints, path_list)) 162 | print('end constructing initial superpoints') 163 | -------------------------------------------------------------------------------- /pointdc_mk/data_prepare/initialSP_prepare_ScanNet.py: -------------------------------------------------------------------------------- 1 | from pclpy import pcl 2 | import pclpy 3 | import numpy as np 4 | from scipy import stats 5 | from os.path import join, exists, dirname, abspath 6 | import sys, glob 7 | import json 8 | 9 | BASE_DIR = dirname(abspath(__file__)) 10 | ROOT_DIR = dirname(BASE_DIR) 11 | sys.path.append(BASE_DIR) 12 | sys.path.append(ROOT_DIR) 13 | from lib.helper_ply import read_ply, write_ply 14 | import os 15 | import matplotlib.pyplot as plt 16 | from concurrent.futures import ProcessPoolExecutor 17 | from pathlib import Path 18 | 19 | trainval_file = [line.rstrip() for line in open(join(BASE_DIR, 'ScanNet_splits/scannetv2_trainval.txt'))] 20 | 21 | colormap = [] 22 | for _ in range(1000): 23 | for k in range(12): 24 | colormap.append(plt.cm.Set3(k)) 25 | for k in range(9): 26 | colormap.append(plt.cm.Set1(k)) 27 | for k in range(8): 28 | colormap.append(plt.cm.Set2(k)) 29 | colormap.append((0, 0, 0, 0)) 30 | colormap = np.array(colormap) 31 | 32 | import argparse 33 | 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--input_path', type=str, default='data/ScanNet/scans/', help='raw data path') # path to the *.segs.json files 36 | parser.add_argument('--sp_path', type=str, default='data/ScanNet/initial_superpoints/') # output path for the superpoint files 37 | parser.add_argument('--pc_path', type=str, default='data/ScanNet/train/') # path to the processed point cloud files 38 | args = parser.parse_args() 39 | 40 | args.input_path = join(ROOT_DIR, args.input_path) 41 | args.sp_path = join(ROOT_DIR, args.sp_path) 42 | args.pc_path = join(ROOT_DIR, args.pc_path) 43 | 44 | vis = True 45 | 46 | def read_superpoints(path): 47 | # read the json file 48 | with open(path, 'r', encoding='utf-8') as f: 49 | json_data = json.load(f) 50 | # read the superpoints and renumber them contiguously 51 | ori_sp = json_data['segIndices'] 52 | unique_vals = sorted(np.unique(ori_sp)) 53 | sp_labels = np.searchsorted(unique_vals, ori_sp) 54 | # save the superpoint labels 55 | if not os.path.exists(args.sp_path): 56 | os.makedirs(args.sp_path) 57 | np.save(args.sp_path + path.split('/')[-2] + '_superpoint.npy', sp_labels) 58 | # write a visualization file 59 | if vis: 60 | vis_path = args.sp_path+'/vis/' 61 | if not os.path.exists(vis_path): 62 | os.makedirs(vis_path) 63 | pc = Path(join(args.pc_path, path.split('/')[-2]+'.ply')) # path to the processed point cloud 64 | data = read_ply(pc) 65 | coords = np.vstack((data['x'], data['y'], data['z'])).T.copy().astype(np.float32) 66 | colors = np.zeros_like(coords) 67 | for p in range(colors.shape[0]): 68 | colors[p] = 255 * (colormap[sp_labels[p].astype(np.int32)])[:3] 69 | colors = colors.astype(np.uint8) 70 | write_ply(vis_path + '/' + path.split('/')[-2], [coords, colors], ['x', 'y', 'z', 'red', 'green', 'blue']) 71 | 72 | print('start constructing initial superpoints') 73 | folders = sorted(glob.glob(args.input_path + '*/*.segs.json')) 74 | # read_superpoints(folders[0]) 75 | pool = ProcessPoolExecutor(max_workers=40) 76 | result = list(pool.map(read_superpoints, folders)) -------------------------------------------------------------------------------- /pointdc_mk/data_prepare/semantic-kitti.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 
2 | labels: 3 | 0 : "unlabeled" 4 | 1 : "outlier" 5 | 10: "car" 6 | 11: "bicycle" 7 | 13: "bus" 8 | 15: "motorcycle" 9 | 16: "on-rails" 10 | 18: "truck" 11 | 20: "other-vehicle" 12 | 30: "person" 13 | 31: "bicyclist" 14 | 32: "motorcyclist" 15 | 40: "road" 16 | 44: "parking" 17 | 48: "sidewalk" 18 | 49: "other-ground" 19 | 50: "building" 20 | 51: "fence" 21 | 52: "other-structure" 22 | 60: "lane-marking" 23 | 70: "vegetation" 24 | 71: "trunk" 25 | 72: "terrain" 26 | 80: "pole" 27 | 81: "traffic-sign" 28 | 99: "other-object" 29 | 252: "moving-car" 30 | 253: "moving-bicyclist" 31 | 254: "moving-person" 32 | 255: "moving-motorcyclist" 33 | 256: "moving-on-rails" 34 | 257: "moving-bus" 35 | 258: "moving-truck" 36 | 259: "moving-other-vehicle" 37 | color_map: # bgr 38 | 0 : [0, 0, 0] 39 | 1 : [0, 0, 255] 40 | 10: [245, 150, 100] 41 | 11: [245, 230, 100] 42 | 13: [250, 80, 100] 43 | 15: [150, 60, 30] 44 | 16: [255, 0, 0] 45 | 18: [180, 30, 80] 46 | 20: [255, 0, 0] 47 | 30: [30, 30, 255] 48 | 31: [200, 40, 255] 49 | 32: [90, 30, 150] 50 | 40: [255, 0, 255] 51 | 44: [255, 150, 255] 52 | 48: [75, 0, 75] 53 | 49: [75, 0, 175] 54 | 50: [0, 200, 255] 55 | 51: [50, 120, 255] 56 | 52: [0, 150, 255] 57 | 60: [170, 255, 150] 58 | 70: [0, 175, 0] 59 | 71: [0, 60, 135] 60 | 72: [80, 240, 150] 61 | 80: [150, 240, 255] 62 | 81: [0, 0, 255] 63 | 99: [255, 255, 50] 64 | 252: [245, 150, 100] 65 | 256: [255, 0, 0] 66 | 253: [200, 40, 255] 67 | 254: [30, 30, 255] 68 | 255: [90, 30, 150] 69 | 257: [250, 80, 100] 70 | 258: [180, 30, 80] 71 | 259: [255, 0, 0] 72 | content: # as a ratio with the total number of points 73 | 0: 0.018889854628292943 74 | 1: 0.0002937197336781505 75 | 10: 0.040818519255974316 76 | 11: 0.00016609538710764618 77 | 13: 2.7879693665067774e-05 78 | 15: 0.00039838616015114444 79 | 16: 0.0 80 | 18: 0.0020633612104619787 81 | 20: 0.0016218197275284021 82 | 30: 0.00017698551338515307 83 | 31: 1.1065903904919655e-08 84 | 32: 5.532951952459828e-09 85 | 40: 0.1987493871255525 86 | 44: 0.014717169549888214 87 | 48: 0.14392298360372 88 | 49: 0.0039048553037472045 89 | 50: 0.1326861944777486 90 | 51: 0.0723592229456223 91 | 52: 0.002395131480328884 92 | 60: 4.7084144280367186e-05 93 | 70: 0.26681502148037506 94 | 71: 0.006035012012626033 95 | 72: 0.07814222006271769 96 | 80: 0.002855498193863172 97 | 81: 0.0006155958086189918 98 | 99: 0.009923127583046915 99 | 252: 0.001789309418528068 100 | 253: 0.00012709999297008662 101 | 254: 0.00016059776092534436 102 | 255: 3.745553104802113e-05 103 | 256: 0.0 104 | 257: 0.00011351574470342043 105 | 258: 0.00010157861367183268 106 | 259: 4.3840131989471124e-05 107 | # classes that are indistinguishable from single scan or inconsistent in 108 | # ground truth are mapped to their closest equivalent 109 | learning_map: 110 | 0 : 0 # "unlabeled" 111 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped 112 | 10: 1 # "car" 113 | 11: 2 # "bicycle" 114 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped 115 | 15: 3 # "motorcycle" 116 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped 117 | 18: 4 # "truck" 118 | 20: 5 # "other-vehicle" 119 | 30: 6 # "person" 120 | 31: 7 # "bicyclist" 121 | 32: 8 # "motorcyclist" 122 | 40: 9 # "road" 123 | 44: 10 # "parking" 124 | 48: 11 # "sidewalk" 125 | 49: 12 # "other-ground" 126 | 50: 13 # "building" 127 | 51: 14 # "fence" 128 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped 129 | 60: 9 # "lane-marking" to "road" 
---------------------------------mapped 130 | 70: 15 # "vegetation" 131 | 71: 16 # "trunk" 132 | 72: 17 # "terrain" 133 | 80: 18 # "pole" 134 | 81: 19 # "traffic-sign" 135 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped 136 | 252: 1 # "moving-car" to "car" ------------------------------------mapped 137 | 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped 138 | 254: 6 # "moving-person" to "person" ------------------------------mapped 139 | 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped 140 | 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped 141 | 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped 142 | 258: 4 # "moving-truck" to "truck" --------------------------------mapped 143 | 259: 5 # "moving-other"-vehicle to "other-vehicle" ----------------mapped 144 | learning_map_inv: # inverse of previous map 145 | 0: 0 # "unlabeled", and others ignored 146 | 1: 10 # "car" 147 | 2: 11 # "bicycle" 148 | 3: 15 # "motorcycle" 149 | 4: 18 # "truck" 150 | 5: 20 # "other-vehicle" 151 | 6: 30 # "person" 152 | 7: 31 # "bicyclist" 153 | 8: 32 # "motorcyclist" 154 | 9: 40 # "road" 155 | 10: 44 # "parking" 156 | 11: 48 # "sidewalk" 157 | 12: 49 # "other-ground" 158 | 13: 50 # "building" 159 | 14: 51 # "fence" 160 | 15: 70 # "vegetation" 161 | 16: 71 # "trunk" 162 | 17: 72 # "terrain" 163 | 18: 80 # "pole" 164 | 19: 81 # "traffic-sign" 165 | learning_ignore: # Ignore classes 166 | 0: True # "unlabeled", and others ignored 167 | 1: False # "car" 168 | 2: False # "bicycle" 169 | 3: False # "motorcycle" 170 | 4: False # "truck" 171 | 5: False # "other-vehicle" 172 | 6: False # "person" 173 | 7: False # "bicyclist" 174 | 8: False # "motorcyclist" 175 | 9: False # "road" 176 | 10: False # "parking" 177 | 11: False # "sidewalk" 178 | 12: False # "other-ground" 179 | 13: False # "building" 180 | 14: False # "fence" 181 | 15: False # "vegetation" 182 | 16: False # "trunk" 183 | 17: False # "terrain" 184 | 18: False # "pole" 185 | 19: False # "traffic-sign" 186 | split: # sequence numbers 187 | train: 188 | - 0 189 | - 1 190 | - 2 191 | - 3 192 | - 4 193 | - 5 194 | - 6 195 | - 7 196 | - 9 197 | - 10 198 | valid: 199 | - 8 200 | test: 201 | - 11 202 | - 12 203 | - 13 204 | - 14 205 | - 15 206 | - 16 207 | - 17 208 | - 18 209 | - 19 210 | - 20 211 | - 21 212 | -------------------------------------------------------------------------------- /pointdc_mk/datasets/SemanticKITTI.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from lib.helper_ply import read_ply, write_ply 4 | from torch.utils.data import Dataset 5 | import MinkowskiEngine as ME 6 | import random 7 | import os 8 | import open3d as o3d 9 | from lib.aug_tools import rota_coords, scale_coords, trans_coords 10 | 11 | class cfl_collate_fn: 12 | 13 | def __call__(self, list_data): 14 | coords, feats, normals, labels, inverse_map, pseudo, inds, region, index = list(zip(*list_data)) 15 | coords_batch, feats_batch, normal_batch, labels_batch, inverse_batch, pseudo_batch, inds_batch = [], [], [], [], [], [], [] 16 | region_batch = [] 17 | accm_num = 0 18 | for batch_id, _ in enumerate(coords): 19 | num_points = coords[batch_id].shape[0] 20 | coords_batch.append(torch.cat((torch.ones(num_points, 1).int() * batch_id, torch.from_numpy(coords[batch_id]).int()), 1)) 21 | feats_batch.append(torch.from_numpy(feats[batch_id])) 22 | 
normal_batch.append(torch.from_numpy(normals[batch_id])) 23 | labels_batch.append(torch.from_numpy(labels[batch_id]).int()) 24 | inverse_batch.append(torch.from_numpy(inverse_map[batch_id])) 25 | pseudo_batch.append(torch.from_numpy(pseudo[batch_id])) 26 | inds_batch.append(torch.from_numpy(inds[batch_id] + accm_num).int()) 27 | region_batch.append(torch.from_numpy(region[batch_id])[:,None]) 28 | accm_num += coords[batch_id].shape[0] 29 | 30 | # Concatenate all lists 31 | coords_batch = torch.cat(coords_batch, 0).float()#.int() 32 | feats_batch = torch.cat(feats_batch, 0).float() 33 | normal_batch = torch.cat(normal_batch, 0).float() 34 | labels_batch = torch.cat(labels_batch, 0).float() 35 | inverse_batch = torch.cat(inverse_batch, 0).int() 36 | pseudo_batch = torch.cat(pseudo_batch, -1) 37 | inds_batch = torch.cat(inds_batch, 0) 38 | region_batch = torch.cat(region_batch, 0) 39 | 40 | return coords_batch, feats_batch, normal_batch, labels_batch, inverse_batch, pseudo_batch, inds_batch, region_batch, index 41 | 42 | 43 | 44 | 45 | class KITTItrain(Dataset): 46 | def __init__(self, args, scene_idx, split='train'): 47 | self.args = args 48 | self.label_to_names = {0: 'unlabeled', 49 | 1: 'car', 50 | 2: 'bicycle', 51 | 3: 'motorcycle', 52 | 4: 'truck', 53 | 5: 'other-vehicle', 54 | 6: 'person', 55 | 7: 'bicyclist', 56 | 8: 'motorcyclist', 57 | 9: 'road', 58 | 10: 'parking', 59 | 11: 'sidewalk', 60 | 12: 'other-ground', 61 | 13: 'building', 62 | 14: 'fence', 63 | 15: 'vegetation', 64 | 16: 'trunk', 65 | 17: 'terrain', 66 | 18: 'pole', 67 | 19: 'traffic-sign'} 68 | self.mode = 'train' 69 | self.split = split 70 | self.val_split = '08' 71 | self.file = [] 72 | 73 | seq_list = np.sort(os.listdir(self.args.data_path)) 74 | for seq_id in seq_list: 75 | seq_path = os.path.join(self.args.data_path, seq_id) 76 | if self.split == 'train': 77 | if seq_id in ['00', '01', '02', '03', '04', '05', '06', '07', '09', '10']: 78 | for f in np.sort(os.listdir(seq_path)): 79 | self.file.append(os.path.join(seq_path, f)) 80 | 81 | elif self.split == 'val': 82 | if seq_id == '08': 83 | for f in np.sort(os.listdir(seq_path)): 84 | self.file.append(os.path.join(seq_path, f)) 85 | scene_idx = range(len(self.file)) 86 | 87 | '''Initial Augmentations''' 88 | self.trans_coords = trans_coords(shift_ratio=50) ### 50% 89 | self.rota_coords = rota_coords(rotation_bound = ((-np.pi/32, np.pi/32), (-np.pi/32, np.pi/32), (-np.pi, np.pi))) 90 | self.scale_coords = scale_coords(scale_bound=(0.9, 1.1)) 91 | 92 | self.random_select_sample(scene_idx) 93 | 94 | def random_select_sample(self, scene_idx): 95 | self.name = [] 96 | self.file_selected = [] 97 | for i in scene_idx: 98 | self.file_selected.append(self.file[i]) 99 | self.name.append(self.file[i][0:-4].replace(self.args.data_path, '')) 100 | 101 | 102 | def augs(self, coords): 103 | coords = self.rota_coords(coords) 104 | coords = self.trans_coords(coords) 105 | coords = self.scale_coords(coords) 106 | return coords 107 | 108 | 109 | def augment_coords_to_feats(self, coords, feats, labels=None): 110 | coords_center = coords.mean(0, keepdims=True) 111 | coords_center[0, 2] = 0 112 | norm_coords = (coords - coords_center) 113 | return norm_coords, feats, labels 114 | 115 | def voxelize(self, coords, feats, labels): 116 | scale = 1 / self.args.voxel_size 117 | coords = np.floor(coords * scale) 118 | coords, feats, labels, unique_map, inverse_map = ME.utils.sparse_quantize(np.ascontiguousarray(coords), feats, labels=labels, ignore_label=-1, return_index=True, return_inverse=True) 119 | 
return coords.numpy(), feats, labels, unique_map, inverse_map.numpy() 120 | 121 | 122 | def __len__(self): 123 | return len(self.file_selected) 124 | 125 | def __getitem__(self, index): 126 | file = self.file_selected[index] 127 | data = read_ply(file) 128 | coords = np.array([data['x'], data['y'], data['z']], dtype=np.float32).T 129 | feats = np.array(data['remission'])[:, np.newaxis] 130 | labels = np.array(data['class']) 131 | coords = coords.astype(np.float32) 132 | coords -= coords.mean(0) 133 | 134 | coords, feats, labels, unique_map, inverse_map = self.voxelize(coords, feats, labels) 135 | coords = coords.astype(np.float32) 136 | 137 | mask = np.sqrt(((coords*self.args.voxel_size)**2).sum(-1))< self.args.r_crop 138 | coords, feats, labels = coords[mask], feats[mask], labels[mask] 139 | 140 | region_file = self.args.sp_path + '/' +self.name[index] + '_superpoint.npy' 141 | region = np.load(region_file) 142 | region = region[unique_map] 143 | region = region[mask] 144 | 145 | coords = self.augs(coords) 146 | 147 | ''' Take Mixup as an Augmentation''' 148 | inds = np.arange(coords.shape[0]) 149 | mix = random.randint(0, len(self.name)-1) 150 | 151 | data_mix = read_ply(self.file_selected[mix]) 152 | coords_mix = np.array([data_mix['x'], data_mix['y'], data_mix['z']], dtype=np.float32).T 153 | feats_mix = np.array(data_mix['remission'])[:, np.newaxis] 154 | labels_mix = np.array(data_mix['class']) 155 | feats_mix = feats_mix.astype(np.float32) 156 | coords_mix = coords_mix.astype(np.float32) 157 | coords_mix -= coords_mix.mean(0) 158 | 159 | coords_mix, feats_mix, _, unique_map_mix, _ = self.voxelize(coords_mix, feats_mix, labels_mix) 160 | coords_mix = coords_mix.astype(np.float32) 161 | 162 | mask_mix = np.sqrt(((coords_mix * self.args.voxel_size) ** 2).sum(-1)) < self.args.r_crop 163 | coords_mix, feats_mix = coords_mix[mask_mix], feats_mix[mask_mix] 164 | # 165 | coords_mix = self.augs(coords_mix) 166 | coords = np.concatenate((coords, coords_mix), axis=0) 167 | ''' End Mixup''' 168 | 169 | coords, feats, labels = self.augment_coords_to_feats(coords, feats, labels) 170 | labels -= 1 171 | 172 | '''mode must be cluster or train''' 173 | if self.mode == 'cluster': 174 | pcd = o3d.geometry.PointCloud() 175 | pcd.points = o3d.utility.Vector3dVector(coords[inds]) 176 | pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=10, max_nn=30)) 177 | normals = np.array(pcd.normals) 178 | 179 | region[labels==-1] = -1 180 | 181 | for q in np.unique(region): 182 | mask = q == region 183 | if mask.sum() < self.args.drop_threshold and q != -1: 184 | region[mask] = -1 185 | 186 | valid_region = region[region != -1] 187 | unique_vals = np.unique(valid_region) 188 | unique_vals.sort() 189 | valid_region = np.searchsorted(unique_vals, valid_region) 190 | 191 | region[region != -1] = valid_region 192 | 193 | pseudo = -np.ones_like(labels).astype(np.long) 194 | 195 | else: 196 | normals = np.zeros_like(coords) 197 | scene_name = self.name[index] 198 | file_path = self.args.pseudo_label_path + '/' + scene_name + '.npy' 199 | pseudo = np.array(np.load(file_path), dtype=np.long) 200 | 201 | 202 | return coords, feats, normals, labels, inverse_map, pseudo, inds, region, index 203 | 204 | 205 | 206 | class KITTIval(Dataset): 207 | def __init__(self, args, split='val'): 208 | self.args = args 209 | self.label_to_names = {0: 'unlabeled', 210 | 1: 'car', 211 | 2: 'bicycle', 212 | 3: 'motorcycle', 213 | 4: 'truck', 214 | 5: 'other-vehicle', 215 | 6: 'person', 216 | 7: 'bicyclist', 217 | 8: 
217 |                                8: 'motorcyclist',
218 |                                9: 'road',
219 |                                10: 'parking',
220 |                                11: 'sidewalk',
221 |                                12: 'other-ground',
222 |                                13: 'building',
223 |                                14: 'fence',
224 |                                15: 'vegetation',
225 |                                16: 'trunk',
226 |                                17: 'terrain',
227 |                                18: 'pole',
228 |                                19: 'traffic-sign'}
229 |         self.name = []
230 |         self.mode = 'val'
231 |         self.split = split
232 |         self.val_split = '08'
233 |         self.file = []
234 | 
235 |         seq_list = np.sort(os.listdir(self.args.data_path))
236 |         for seq_id in seq_list:
237 |             seq_path = os.path.join(self.args.data_path, seq_id)
238 |             if self.split == 'val':
239 |                 if seq_id == '08':
240 |                     for f in np.sort(os.listdir(seq_path)):
241 |                         self.file.append(os.path.join(seq_path, f))
242 |                         self.name.append(os.path.join(seq_path, f)[0:-4].replace(self.args.data_path, ''))
243 | 
244 | 
245 |     def augment_coords_to_feats(self, coords, feats, labels=None):
246 |         coords_center = coords.mean(0, keepdims=True)
247 |         coords_center[0, 2] = 0
248 |         norm_coords = (coords - coords_center)
249 |         return norm_coords, feats, labels
250 | 
251 |     def voxelize(self, coords, feats, labels):
252 |         scale = 1 / self.args.voxel_size
253 |         coords = np.floor(coords * scale)
254 |         coords, feats, labels, unique_map, inverse_map = ME.utils.sparse_quantize(np.ascontiguousarray(coords), feats, labels=labels, ignore_label=-1, return_index=True, return_inverse=True)
255 |         return coords.numpy(), feats, labels, unique_map, inverse_map.numpy()
256 | 
257 | 
258 |     def __len__(self):
259 |         return len(self.file)
260 | 
261 |     def __getitem__(self, index):
262 |         file = self.file[index]
263 |         data = read_ply(file)
264 |         coords = np.array([data['x'], data['y'], data['z']], dtype=np.float32).T
265 |         feats = np.array(data['remission'])[:, np.newaxis]
266 |         labels = np.array(data['class'])
267 |         coords = coords.astype(np.float32)
268 |         coords -= coords.mean(0)
269 | 
270 |         coords, feats, _, unique_map, inverse_map = self.voxelize(coords, feats, labels)
271 |         coords = coords.astype(np.float32)
272 | 
273 |         region_file = self.args.sp_path + '/' + self.name[index] + '_superpoint.npy'
274 |         region = np.load(region_file)
275 |         region = region[unique_map]
276 | 
277 |         coords, feats, labels = self.augment_coords_to_feats(coords, feats, labels)
278 |         labels = labels - 1
279 | 
280 |         return coords, feats, np.ascontiguousarray(labels), inverse_map, region, index
281 | 
282 | 
283 | class cfl_collate_fn_val:
284 | 
285 |     def __call__(self, list_data):
286 |         coords, feats, labels, inverse_map, region, index = list(zip(*list_data))
287 |         coords_batch, feats_batch, inverse_batch, labels_batch = [], [], [], []
288 |         region_batch = []
289 |         for batch_id, _ in enumerate(coords):
290 |             num_points = coords[batch_id].shape[0]
291 |             coords_batch.append(
292 |                 torch.cat((torch.ones(num_points, 1).int() * batch_id, torch.from_numpy(coords[batch_id]).int()), 1))
293 |             feats_batch.append(torch.from_numpy(feats[batch_id]))
294 |             inverse_batch.append(torch.from_numpy(inverse_map[batch_id]))
295 |             labels_batch.append(torch.from_numpy(labels[batch_id]).int())
296 |             region_batch.append(torch.from_numpy(region[batch_id])[:, None])
297 |         #
298 |         # Concatenate all lists
299 |         coords_batch = torch.cat(coords_batch, 0).float()
300 |         feats_batch = torch.cat(feats_batch, 0).float()
301 |         inverse_batch = torch.cat(inverse_batch, 0).int()
302 |         labels_batch = torch.cat(labels_batch, 0).int()
303 |         region_batch = torch.cat(region_batch, 0)
304 | 
305 |         return coords_batch, feats_batch, inverse_batch, labels_batch, index, region_batch
306 | 
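How these pieces fit together at evaluation time — a minimal wiring sketch, not a file from the repo; the `args` fields and path values here are assumptions mirroring the eval scripts:

```python
from types import SimpleNamespace
from torch.utils.data import DataLoader

args = SimpleNamespace(data_path='data/SemanticKITTI/dataset/sequences/',  # illustrative paths/values
                       sp_path='data/SemanticKITTI/initial_superpoints/',
                       voxel_size=0.15)
val_dataset = KITTIval(args, split='val')                    # walks sequence '08' only
val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=cfl_collate_fn_val())
for coords, feats, inverse_map, labels, index, region in val_loader:
    break  # coords[:, 0] is the batch index; inverse_map maps voxels back to raw points
```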
-------------------------------------------------------------------------------- /pointdc_mk/env.yaml: -------------------------------------------------------------------------------- 1 | name: pointdc_mk 2 | channels: 3 | - davidcaron 4 | - pytorch/label/nightly 5 | - pytorch 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=conda_forge 10 | - _openmp_mutex=4.5=2_kmp_llvm 11 | - appdirs=1.4.4=pyhd3eb1b0_0 12 | - blas=1.0=mkl 13 | - boost-cpp=1.72.0=he72f1d9_7 14 | - brotlipy=0.7.0=py38h27cfd23_1003 15 | - bzip2=1.0.8=h7f98852_4 16 | - c-ares=1.19.1=hd590300_0 17 | - ca-certificates=2023.05.30=h06a4308_0 18 | - cffi=1.15.1=py38h5eee18b_3 19 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 20 | - cryptography=41.0.2=py38h22a60cf_0 21 | - cudatoolkit=11.3.1=hb98b00a_12 22 | - eigen=3.4.1=h00ab1b0_0 23 | - ffmpeg=4.3=hf484d3e_0 24 | - flann=1.9.1=h941a29b_1013 25 | - freetype=2.12.1=hca18f0e_1 26 | - gmp=6.2.1=h58526e2_0 27 | - gnutls=3.6.13=h85f3911_1 28 | - hdf5=1.14.1=nompi_h4f84152_100 29 | - icu=70.1=h27087fc_0 30 | - idna=3.4=py38h06a4308_0 31 | - intel-openmp=2021.4.0=h06a4308_3561 32 | - jpeg=9e=h0b41bf4_3 33 | - keyutils=1.6.1=h166bdaf_0 34 | - krb5=1.21.1=h659d440_0 35 | - lame=3.100=h166bdaf_1003 36 | - laspy=2.3.0=pyha21a80b_0 37 | - lcms2=2.15=hfd0df8a_0 38 | - ld_impl_linux-64=2.40=h41732ed_0 39 | - lerc=4.0.0=h27087fc_0 40 | - libaec=1.0.6=hcb278e6_1 41 | - libcurl=8.2.0=hca28451_0 42 | - libdeflate=1.17=h0b41bf4_0 43 | - libedit=3.1.20191231=he28a2e2_2 44 | - libev=4.33=h516909a_1 45 | - libfaiss=1.7.3=hfc2d529_97_cuda11.3_nightly 46 | - libffi=3.4.2=h7f98852_5 47 | - libgcc-ng=13.1.0=he5830b7_0 48 | - libgfortran-ng=13.1.0=h69a702a_0 49 | - libgfortran5=13.1.0=h15d22d2_0 50 | - libhwloc=2.9.1=hd6dc26d_0 51 | - libiconv=1.17=h166bdaf_0 52 | - libnghttp2=1.52.0=h61bc06f_0 53 | - libnsl=2.0.0=h7f98852_0 54 | - libpng=1.6.39=h753d276_0 55 | - libsqlite=3.42.0=h2797004_0 56 | - libssh2=1.11.0=h0841786_0 57 | - libstdcxx-ng=13.1.0=hfd8a6a1_0 58 | - libtiff=4.5.0=h6adf6a1_2 59 | - libuv=1.44.2=h166bdaf_0 60 | - libwebp-base=1.3.1=hd590300_0 61 | - libxcb=1.13=h7f98852_1004 62 | - libxml2=2.10.3=hca2bb57_4 63 | - libzlib=1.2.13=hd590300_5 64 | - llvm-openmp=16.0.6=h4dfa4b3_0 65 | - mkl=2021.4.0=h06a4308_640 66 | - mkl-include=2022.1.0=h84fe81f_915 67 | - mkl-service=2.4.0=py38h7f8727e_0 68 | - mkl_fft=1.3.1=py38hd3c417c_0 69 | - mkl_random=1.2.2=py38h51133e4_0 70 | - ncurses=6.4=hcb278e6_0 71 | - nettle=3.6=he412f7d_0 72 | - numpy=1.21.2=py38h20f2e39_0 73 | - numpy-base=1.21.2=py38h79a1101_0 74 | - openh264=2.1.1=h780b84a_0 75 | - openjpeg=2.5.0=hfec8fc6_2 76 | - openssl=3.1.1=hd590300_1 77 | - pcl=1.9.1=h2dfa329_1005 78 | - pclpy=0.12.0=py38_1 79 | - pillow=9.4.0=py38hde6dc18_1 80 | - pooch=1.4.0=pyhd3eb1b0_0 81 | - pthread-stubs=0.4=h36c2ea0_1001 82 | - pycparser=2.21=pyhd3eb1b0_0 83 | - pyopenssl=23.2.0=py38h06a4308_0 84 | - pysocks=1.7.1=py38h06a4308_0 85 | - python=3.8.12=h0744224_3_cpython 86 | - python_abi=3.8=3_cp38 87 | - pytorch=1.10.2=py3.8_cuda11.3_cudnn8.2.0_0 88 | - pytorch-mutex=1.0=cuda 89 | - qhull=2015.2=hc9558a2_1001 90 | - readline=8.2=h8228510_1 91 | - requests=2.31.0=py38h06a4308_0 92 | - setuptools=68.0.0=pyhd8ed1ab_0 93 | - six=1.16.0=pyhd3eb1b0_1 94 | - sqlite=3.42.0=h2c6b66d_0 95 | - tbb=2021.9.0=hf52228f_0 96 | - threadpoolctl=2.2.0=pyh0d69192_0 97 | - tk=8.6.12=h27826a3_0 98 | - torchvision=0.11.3=py38_cu113 99 | - typing_extensions=4.7.1=pyha770c72_0 100 | - wheel=0.40.0=pyhd8ed1ab_1 101 | - xorg-libxau=1.0.11=hd590300_0 102 | - 
xorg-libxdmcp=1.1.3=h7f98852_0 103 | - xz=5.2.6=h166bdaf_0 104 | - zlib=1.2.13=hd590300_5 105 | - zstd=1.5.2=hfc55251_7 106 | - pip: 107 | - anyio==3.7.1 108 | - argon2-cffi==21.3.0 109 | - argon2-cffi-bindings==21.2.0 110 | - arrow==1.2.3 111 | - asttokens==2.2.1 112 | - async-lru==2.0.3 113 | - attrs==23.1.0 114 | - babel==2.12.1 115 | - backcall==0.2.0 116 | - beautifulsoup4==4.12.2 117 | - bleach==6.0.0 118 | - blessed==1.20.0 119 | - cachetools==5.3.1 120 | - certifi==2023.5.7 121 | - clip==1.0 122 | - comm==0.1.3 123 | - cycler==0.11.0 124 | - debugpy==1.6.7 125 | - decorator==5.1.1 126 | - defusedxml==0.7.1 127 | - exceptiongroup==1.1.2 128 | - executing==1.2.0 129 | - faiss-gpu==1.7.2 130 | - fastjsonschema==2.17.1 131 | - fonttools==4.41.0 132 | - fqdn==1.5.1 133 | - ftfy==6.1.1 134 | - gpustat==1.1 135 | - importlib-metadata==6.8.0 136 | - importlib-resources==6.0.0 137 | - ipykernel==6.24.0 138 | - ipython==8.12.2 139 | - ipywidgets==8.0.7 140 | - isoduration==20.11.0 141 | - jedi==0.18.2 142 | - jinja2==3.1.2 143 | - joblib==1.3.1 144 | - json5==0.9.14 145 | - jsonpointer==2.4 146 | - jsonschema==4.18.4 147 | - jsonschema-specifications==2023.7.1 148 | - jupyter-client==8.3.0 149 | - jupyter-core==5.3.1 150 | - jupyter-events==0.6.3 151 | - jupyter-lsp==2.2.0 152 | - jupyter-server==2.7.0 153 | - jupyter-server-terminals==0.4.4 154 | - jupyterlab==4.0.3 155 | - jupyterlab-pygments==0.2.2 156 | - jupyterlab-server==2.23.0 157 | - jupyterlab-widgets==3.0.8 158 | - kiwisolver==1.4.4 159 | - markupsafe==2.1.3 160 | - matplotlib==3.5.1 161 | - matplotlib-inline==0.1.6 162 | - minkowskiengine==0.5.4 163 | - mistune==3.0.1 164 | - nbclient==0.8.0 165 | - nbconvert==7.7.2 166 | - nbformat==5.9.1 167 | - nest-asyncio==1.5.6 168 | - ninja==1.10.2.3 169 | - notebook==7.0.0 170 | - notebook-shim==0.2.3 171 | - nvidia-ml-py==12.535.77 172 | - nvitop==1.2.0 173 | - open3d==0.10.0.0 174 | - opencv-python==4.8.0.74 175 | - overrides==7.3.1 176 | - packaging==23.1 177 | - pandas==2.0.3 178 | - pandocfilters==1.5.0 179 | - parso==0.8.3 180 | - pexpect==4.8.0 181 | - pickleshare==0.7.5 182 | - pip==23.2.1 183 | - pkgutil-resolve-name==1.3.10 184 | - platformdirs==3.9.1 185 | - prometheus-client==0.17.1 186 | - prompt-toolkit==3.0.39 187 | - protobuf==4.24.4 188 | - psutil==5.9.5 189 | - ptyprocess==0.7.0 190 | - pure-eval==0.2.2 191 | - pygments==2.15.1 192 | - pyparsing==3.1.0 193 | - python-dateutil==2.8.2 194 | - python-json-logger==2.0.7 195 | - pytz==2023.3 196 | - pyyaml==6.0.1 197 | - pyzmq==25.1.0 198 | - rama==0.0.7 199 | - referencing==0.30.0 200 | - regex==2023.10.3 201 | - rfc3339-validator==0.1.4 202 | - rfc3986-validator==0.1.1 203 | - rpds-py==0.9.2 204 | - scikit-learn==0.22.2 205 | - scipy==1.8.0 206 | - seaborn==0.11.2 207 | - send2trash==1.8.2 208 | - sniffio==1.3.0 209 | - soupsieve==2.4.1 210 | - stack-data==0.6.2 211 | - tensorboardx==2.6.2.2 212 | - termcolor==2.3.0 213 | - terminado==0.17.1 214 | - tinycss2==1.2.1 215 | - tomli==2.0.1 216 | - torch-scatter==2.0.9 217 | - tornado==6.3.2 218 | - tqdm==4.65.0 219 | - traitlets==5.9.0 220 | - tzdata==2023.3 221 | - uri-template==1.3.0 222 | - urllib3==2.0.4 223 | - wcwidth==0.2.6 224 | - webcolors==1.13 225 | - webencodings==0.5.1 226 | - websocket-client==1.6.1 227 | - widgetsnbextension==4.0.8 228 | - zipp==3.16.2 -------------------------------------------------------------------------------- /pointdc_mk/eval_S3DIS.py: -------------------------------------------------------------------------------- 1 | import torch, os, 
argparse, faiss
2 | import torch.nn.functional as F
3 | from datasets.S3DIS import S3DIStest, S3DIStrain, cfl_collate_fn_test, cfl_collate_fn
4 | import numpy as np
5 | import MinkowskiEngine as ME
6 | from torch.utils.data import DataLoader
7 | from sklearn.utils.linear_assignment_ import linear_assignment # pip install scikit-learn==0.22.2
8 | from sklearn.cluster import KMeans
9 | from models.fpn import Res16FPN18
10 | from lib.utils_s3dis import *
11 | from tqdm import tqdm
12 | from os.path import join
13 | from datetime import datetime
14 | from sklearn.cluster._kmeans import k_means
15 | from models.pretrain_models import SubModel
16 | ###
17 | def parse_args():
18 |     '''PARAMETERS'''
19 |     parser = argparse.ArgumentParser(description='PyTorch Unsuper_3D_Seg')
20 |     parser.add_argument('--data_path', type=str, default='data/S3DIS/', help='point cloud data path')
21 |     parser.add_argument('--sp_path', type=str, default= 'data/S3DIS/', help='initial sp path')
22 |     parser.add_argument('--expname', type=str, default='zdefault', help='expname for logger')
23 |     ###
24 |     parser.add_argument('--save_path', type=str, default='ckpt/S3DIS/', help='model savepath')
25 |     ###
26 |     parser.add_argument('--bn_momentum', type=float, default=0.02, help='batchnorm parameters')
27 |     parser.add_argument('--conv1_kernel_size', type=int, default=5, help='kernel size of 1st conv layers')
28 |     ###
29 |     parser.add_argument('--workers', type=int, default=8, help='how many workers for loading data')
30 |     parser.add_argument('--seed', type=int, default=2023, help='random seed')
31 |     parser.add_argument('--log-interval', type=int, default=150, help='log interval')
32 |     parser.add_argument('--batch_size', type=int, default=8, help='batchsize in training')
33 |     parser.add_argument('--voxel_size', type=float, default=0.05, help='voxel size in SparseConv')
34 |     parser.add_argument('--input_dim', type=int, default=6, help='network input dimension')### 6 for XYZRGB
35 |     parser.add_argument('--primitive_num', type=int, default=13, help='how many primitives used in training')
36 |     parser.add_argument('--semantic_class', type=int, default=13, help='ground truth semantic class')
37 |     parser.add_argument('--feats_dim', type=int, default=128, help='output feature dimension')
38 |     parser.add_argument('--ignore_label', type=int, default=-1, help='invalid label')
39 |     parser.add_argument('--drop_threshold', type=int, default=50, help='minimum point count for a superpoint to be kept')
40 | 
41 |     return parser.parse_args()
42 | 
43 | 
44 | 
45 | def eval_once(args, model, test_loader, classifier, use_sp=False):
46 |     model.mode = 'train'
47 |     all_preds, all_label = [], []
48 |     test_loader_bar = tqdm(test_loader)
49 |     for data in test_loader_bar:
50 |         test_loader_bar.set_description('Start eval...')
51 |         with torch.no_grad():
52 |             coords, features, inverse_map, labels, index, region = data
53 | 
54 |             in_field = ME.TensorField(features, coords, device=0)
55 |             feats_nonorm = model(in_field)
56 |             feats_norm = F.normalize(feats_nonorm)
57 | 
58 |             region = region.squeeze()
59 |             if use_sp:
60 |                 region_inds = torch.unique(region).long()
61 |                 region_feats = []
62 |                 for id in region_inds:
63 |                     if id != -1:
64 |                         valid_mask = id == region
65 |                         region_feats.append(feats_norm[valid_mask].mean(0, keepdim=True))
66 |                 region_feats = torch.cat(region_feats, dim=0)
67 | 
68 |                 scores = F.linear(F.normalize(feats_norm), F.normalize(classifier.weight))
69 |                 preds = torch.argmax(scores, dim=1).cpu()
70 | 
71 |                 region_scores = F.linear(F.normalize(region_feats), F.normalize(classifier.weight))
72 |                 for region_no, id in enumerate(region_inds[region_inds != -1]):  # one vote per superpoint; enumerate keeps row indices aligned with region_feats (mirrors eval_ScanNet)
73 |                     valid_mask = id == region
74 |                     preds[valid_mask] = torch.argmax(region_scores, dim=1).cpu()[region_no]
75 |                     # points in superpoint -1 keep their per-point predictions
76 |             else:
77 |                 scores = F.linear(F.normalize(feats_nonorm), F.normalize(classifier.weight))
78 |                 preds = torch.argmax(scores, dim=1).cpu()
79 | 
80 |             preds = preds[inverse_map.long()]
81 |             all_preds.append(preds[labels!=args.ignore_label]), all_label.append(labels[labels!=args.ignore_label])
82 | 
83 |     return all_preds, all_label
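The superpoint vote in `eval_once` reduces to: average the normalized features inside each superpoint, score that average against the fixed classifier, and stamp the winning label onto every member point. A toy standalone version of the same idea (all tensors here are illustrative):

```python
import torch
import torch.nn.functional as F

feats  = F.normalize(torch.randn(6, 4), dim=1)   # 6 points, 4-dim features
region = torch.tensor([0, 0, 1, 1, 1, -1])       # superpoint id per point (-1 = unassigned)
weight = F.normalize(torch.randn(3, 4), dim=1)   # 3 fixed class centroids

preds = F.linear(feats, weight).argmax(1)        # per-point fallback predictions
for sp in region.unique():
    if sp == -1:
        continue                                 # unassigned points keep their own label
    mask = region == sp
    sp_feat = feats[mask].mean(0, keepdim=True)  # pooled superpoint feature
    preds[mask] = F.linear(sp_feat, weight).argmax(1)  # one vote for the whole superpoint
```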
84 | 
85 | def eval(epoch, args, mode='svc'):
86 |     ## Model
87 |     model = Res16FPN18(in_channels=args.input_dim, out_channels=args.primitive_num, conv1_kernel_size=args.conv1_kernel_size, config=args, mode='train').cuda()
88 |     model.load_state_dict(torch.load(os.path.join(args.save_path, mode, 'model_' + str(epoch) + '_checkpoint.pth')))
89 |     model.eval()
90 |     ## Merge Cluster Centers
91 |     primitive_centers = torch.load(os.path.join(args.save_path, mode, 'cls_' + str(epoch) + '_checkpoint.pth'))
92 |     centroids, _, _ = k_means(primitive_centers.cpu(), n_clusters=args.semantic_class, random_state=None, n_init=20)
93 |     centroids = F.normalize(torch.FloatTensor(centroids), dim=1).cuda()
94 |     cls = get_fixclassifier(in_channel=args.feats_dim, centroids_num=args.semantic_class, centroids=centroids).cuda()
95 |     cls.eval()
96 | 
97 |     val_dataset = S3DIStest(args)
98 |     val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=cfl_collate_fn_test(), num_workers=args.workers, pin_memory=True)  # parse_args defines --workers only (no cluster_workers)
99 | 
100 |     preds, labels = eval_once(args, model, val_loader, cls, use_sp=True)
101 |     all_preds = torch.cat(preds).numpy()
102 |     all_labels = torch.cat(labels).numpy()
103 | 
104 |     o_Acc, m_Acc, s = compute_seg_results(args, all_labels, all_preds)
105 | 
106 |     return o_Acc, m_Acc, s
107 | 
108 | def eval_by_cluster(args, epoch, mode='svc'):
109 |     ## Prepare Data
110 |     trainset = S3DIStrain(args, areas=['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6'])
111 |     cluster_loader = DataLoader(trainset, batch_size=1, shuffle=True, collate_fn=cfl_collate_fn(), \
112 |                                 num_workers=args.workers, pin_memory=True, worker_init_fn=worker_init_fn(args.seed))
113 |     val_dataset = S3DIStest(args)
114 |     val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=cfl_collate_fn_test(), num_workers=args.workers, pin_memory=True)
115 |     ## Define model
116 |     model = Res16FPN18(in_channels=args.input_dim, out_channels=args.primitive_num, conv1_kernel_size=args.conv1_kernel_size, config=args, mode='train').cuda()
117 |     model.load_state_dict(torch.load(os.path.join(args.save_path, mode, 'model_' + str(epoch) + '_checkpoint.pth')))
118 |     model.eval()
119 |     ## Get sp features
120 |     print('Start to get sp features...')
121 |     sp_feats_list = init_get_sp_feature(args, cluster_loader, model)
122 |     sp_feats = torch.cat(sp_feats_list, dim=0)
123 |     print('Start to train faiss...')
124 |     ## Train faiss module
125 |     _, primitive_centers = faiss_cluster(args, sp_feats.cpu().numpy())
126 |     ## Merge Primitive
127 |     centroids, _, _ = k_means(primitive_centers.cpu(), n_clusters=args.semantic_class, random_state=None, n_init=20, n_jobs=20)
128 |     centroids = F.normalize(torch.FloatTensor(centroids), dim=1).cuda()
129 |     ## Get cls
130 |     cls = get_fixclassifier(in_channel=args.feats_dim, centroids_num=args.semantic_class, centroids=centroids).cuda()
131 |     cls.eval()
132 |     ## eval
133 |     preds, labels = eval_once(args, model, val_loader, cls, use_sp=True)
134 |     all_preds, all_labels = torch.cat(preds).numpy(), torch.cat(labels).numpy()
135 |     o_Acc, m_Acc, s = compute_seg_results(args, all_labels, all_preds)
136 |     return o_Acc, m_Acc, s
137 | 
138 | if __name__ == '__main__':
139 |     args = parse_args()
140 |     expnames = ['240524_trainall']
141 |     epoches = [70, 80, 90]
142 |     seeds = [12, 43, 56, 78, 90]
143 |     for expname in expnames:
144 |         args.save_path = 'ckpt/S3DIS'
145 |         args.save_path = join(args.save_path, expname)
146 |         assert os.path.exists(args.save_path), 'There is no {} !!!'.format(expname)
147 |         if not os.path.exists(join('results', expname)):
148 |             os.makedirs(join('results', expname))
149 |         for epoch in epoches:
150 |             results = []
151 |             results.append('Eval exp {}\n'.format(expname))
152 |             results.append("Eval time: {}\n".format(datetime.now().strftime("%Y-%m-%d %H:%M")))
153 |             results_file = join('results', expname, 'eval_{}.txt'.format(str(epoch)))
154 |             print('Eval {}, save file to {}'.format(expname, results_file))
155 |             for seed in seeds:
156 |                 args.seed = seed
157 |                 set_seed(args.seed)
158 |                 o_Acc, m_Acc, s = eval_by_cluster(args, epoch, mode='svc')
159 |                 print('Epoch {:02d} Seed {}: oAcc {:.2f} mAcc {:.2f} IoUs'.format(epoch, seed, o_Acc, m_Acc) + s)
160 |                 results.append('Epoch {:02d} Seed {}: oAcc {:.2f} mAcc {:.2f} IoUs'.format(epoch, seed, o_Acc, m_Acc) + s + '\n')
161 |             write_list(results_file, results)
162 |             print('\n')
--------------------------------------------------------------------------------
/pointdc_mk/eval_ScanNet.py:
--------------------------------------------------------------------------------
1 | import torch, os, argparse, faiss
2 | import torch.nn.functional as F
3 | from datasets.ScanNet import Scannetval, cfl_collate_fn_val
4 | import numpy as np
5 | import MinkowskiEngine as ME
6 | from torch.utils.data import DataLoader
7 | from sklearn.utils.linear_assignment_ import linear_assignment # pip install scikit-learn==0.22.2
8 | from sklearn.cluster import KMeans
9 | from models.fpn import Res16FPN18
10 | from lib.utils import get_fixclassifier, init_get_sp_feature, faiss_cluster, worker_init_fn, set_seed, compute_seg_results, write_list
11 | from tqdm import tqdm
12 | from os.path import join
13 | from datasets.ScanNet import Scannettrain, Scannetdistill, Scannetval, cfl_collate_fn, cfl_collate_fn_distill, cfl_collate_fn_val
14 | from datetime import datetime
15 | from sklearn.cluster._kmeans import k_means
16 | from models.pretrain_models import SubModel
17 | ###
18 | def parse_args():
19 |     parser = argparse.ArgumentParser(description='PyTorch Unsuper_3D_Seg')
20 |     parser.add_argument('--data_path', type=str, default='data/ScanNet4growsp/train',
21 |                         help='point cloud data path')
22 |     parser.add_argument('--feats_path', type=str, default='data/ScanNet4growsp/traindatas',
23 |                         help='point cloud feats path')
24 |     parser.add_argument('--sp_path', type=str, default= 'data/scans_growed_sp',
25 |                         help='initial sp path')
26 |     parser.add_argument('--save_path', type=str, default='ckpt/ScanNet/',
27 |                         help='model savepath')
28 |     ###
29 |     parser.add_argument('--bn_momentum', type=float, default=0.02, help='batchnorm parameters')
30 |     parser.add_argument('--conv1_kernel_size', type=int, default=5, help='kernel size of 1st conv layers')
31 |     ####
32 |     parser.add_argument('--workers', type=int, default=10, help='how many workers for loading data')
33 |     parser.add_argument('--seed', type=int, default=2023, help='random seed')
34 |     parser.add_argument('--voxel_size', type=float, default=0.02, help='voxel size in SparseConv')
35 |     parser.add_argument('--input_dim', type=int, default=6, help='network input dimension')### 6 for XYZRGB
36 |     parser.add_argument('--primitive_num', type=int, default=30, help='how many primitives used in training')
37 |     parser.add_argument('--semantic_class', type=int, default=20, help='ground truth semantic class')
38 |     parser.add_argument('--feats_dim', type=int, default=128, help='output feature dimension')
39 |     parser.add_argument('--ignore_label', type=int, default=-1, help='invalid label')
40 |     return parser.parse_args()
41 | 
42 | 
43 | def eval_once(args, model, test_loader, classifier, use_sp=False):
44 |     model.mode = 'train'
45 |     all_preds, all_label = [], []
46 |     test_loader_bar = tqdm(test_loader)
47 |     for data in test_loader_bar:
48 |         test_loader_bar.set_description('Start eval...')
49 |         with torch.no_grad():
50 |             coords, features, inverse_map, labels, index, region = data
51 | 
52 |             in_field = ME.TensorField(features, coords, device=0)
53 |             feats_nonorm = model(in_field)
54 |             feats_norm = F.normalize(feats_nonorm)
55 | 
56 |             region = region.squeeze()
57 |             if use_sp:
58 |                 region_inds = torch.unique(region)
59 |                 region_feats = []
60 |                 for id in region_inds:
61 |                     if id != -1:
62 |                         valid_mask = id == region
63 |                         region_feats.append(feats_norm[valid_mask].mean(0, keepdim=True))
64 |                 region_feats = torch.cat(region_feats, dim=0)
65 | 
66 |                 scores = F.linear(F.normalize(feats_norm), F.normalize(classifier.weight))
67 |                 preds = torch.argmax(scores, dim=1).cpu()
68 | 
69 |                 region_scores = F.linear(F.normalize(region_feats), F.normalize(classifier.weight))
70 |                 region_no = 0
71 |                 for id in region_inds:
72 |                     if id != -1:
73 |                         valid_mask = id == region
74 |                         preds[valid_mask] = torch.argmax(region_scores, dim=1).cpu()[region_no]
75 |                         region_no += 1
76 |             else:
77 |                 scores = F.linear(F.normalize(feats_nonorm), F.normalize(classifier.weight))
78 |                 preds = torch.argmax(scores, dim=1).cpu()
79 | 
80 |             preds = preds[inverse_map.long()]
81 |             all_preds.append(preds[labels!=args.ignore_label]), all_label.append(labels[labels!=args.ignore_label])
82 | 
83 |     return all_preds, all_label
84 | 
85 | def eval(epoch, args, mode='svc'):
86 |     ## Model
87 |     model = Res16FPN18(in_channels=args.input_dim, out_channels=args.primitive_num, conv1_kernel_size=args.conv1_kernel_size, config=args, mode='train').cuda()
88 |     model.load_state_dict(torch.load(os.path.join(args.save_path, mode, 'model_' + str(epoch) + '_checkpoint.pth')))
89 |     model.eval()
90 |     ## Merge Cluster Centers
91 |     primitive_centers = torch.load(os.path.join(args.save_path, mode, 'cls_' + str(epoch) + '_checkpoint.pth'))
92 |     centroids, _, _ = k_means(primitive_centers.cpu(), n_clusters=args.semantic_class, random_state=None, n_init=20)
93 |     centroids = F.normalize(torch.FloatTensor(centroids), dim=1).cuda()
94 |     cls = get_fixclassifier(in_channel=args.feats_dim, centroids_num=args.semantic_class, centroids=centroids).cuda()
95 |     cls.eval()
96 | 
97 |     val_dataset = Scannetval(args)
98 |     val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=cfl_collate_fn_val(), num_workers=args.workers, pin_memory=True)
99 | 
100 |     preds, labels = eval_once(args, model, val_loader, cls, use_sp=True)
101 |     all_preds = torch.cat(preds).numpy()
102 |     all_labels = torch.cat(labels).numpy()
103 | 
104 |     o_Acc, m_Acc, s = compute_seg_results(args, all_labels, all_preds)
105 | 
106 |     return o_Acc, m_Acc, s
107 | 
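`eval()` above compresses to three moves: k-means the saved primitive centers down to `semantic_class` groups, L2-normalize the merged centroids, and freeze them as a bias-free linear head scored by cosine similarity. A toy version with illustrative shapes (128-dim features, 30 primitives, 20 classes, matching the defaults):

```python
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans

primitive_centers = torch.randn(30, 128)                 # stand-in for the cls_*.pth checkpoint
km = KMeans(n_clusters=20, n_init=20).fit(primitive_centers.numpy())
centroids = F.normalize(torch.from_numpy(km.cluster_centers_).float(), dim=1)

head = torch.nn.Linear(128, 20, bias=False)              # what get_fixclassifier builds
head.weight.data = centroids
for p in head.parameters():
    p.requires_grad = False

scores = F.linear(F.normalize(torch.randn(5, 128)), head.weight)  # cosine scores for 5 points
```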
108 | def eval_by_cluster(args, epoch, mode='svc'):
109 |     ## Prepare Data
110 |     trainset = Scannettrain(args)
111 |     cluster_loader = DataLoader(trainset, batch_size=1, shuffle=True, collate_fn=cfl_collate_fn(), \
112 |                                 num_workers=args.workers, pin_memory=True, worker_init_fn=worker_init_fn(args.seed))
113 |     val_dataset = Scannetval(args)
114 |     val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=cfl_collate_fn_val(), num_workers=args.workers, pin_memory=True)
115 |     ## Define model
116 |     model = Res16FPN18(in_channels=args.input_dim, out_channels=args.primitive_num, conv1_kernel_size=args.conv1_kernel_size, config=args, mode='train').cuda()
117 |     model.load_state_dict(torch.load(os.path.join(args.save_path, mode, 'model_' + str(epoch) + '_checkpoint.pth')))
118 |     model.eval()
119 |     ## Get sp features
120 |     print('Start to get sp features...')
121 |     sp_feats_list = init_get_sp_feature(args, cluster_loader, model)
122 |     sp_feats = torch.cat(sp_feats_list, dim=0)
123 |     print('Start to train faiss...')
124 |     ## Train faiss module
125 |     _, primitive_centers = faiss_cluster(args, sp_feats.cpu().numpy())
126 |     ## Merge Primitive
127 |     centroids, _, _ = k_means(primitive_centers.cpu(), n_clusters=args.semantic_class, random_state=None, n_init=20, n_jobs=20)
128 |     centroids = F.normalize(torch.FloatTensor(centroids), dim=1).cuda()
129 |     ## Get cls
130 |     cls = get_fixclassifier(in_channel=args.feats_dim, centroids_num=args.semantic_class, centroids=centroids).cuda()
131 |     cls.eval()
132 |     ## eval
133 |     preds, labels = eval_once(args, model, val_loader, cls, use_sp=True)
134 |     all_preds, all_labels = torch.cat(preds).numpy(), torch.cat(labels).numpy()
135 |     o_Acc, m_Acc, s = compute_seg_results(args, all_labels, all_preds)
136 |     return o_Acc, m_Acc, s
137 | 
138 | if __name__ == '__main__':
139 |     args = parse_args()
140 |     expnames = ['~'] # your exp name
141 |     epoches = [] # checkpoint epochs to evaluate, e.g. [70, 80, 90]
142 |     seeds = [12, 43, 56, 78, 90]
143 |     for expname in expnames:
144 |         args.save_path = 'ckpt/ScanNet'
145 |         args.save_path = join(args.save_path, expname)
146 |         assert os.path.exists(args.save_path), 'There is no {} !!!'.format(expname)
147 |         if not os.path.exists(join('results', expname)):
148 |             os.makedirs(join('results', expname))
149 |         for epoch in epoches:
150 |             results = []
151 |             results.append('Eval exp {}\n'.format(expname))
152 |             results.append("Eval time: {}\n".format(datetime.now().strftime("%Y-%m-%d %H:%M")))
153 |             results_file = join('results', expname, 'eval_{}.txt'.format(str(epoch)))
154 |             print('Eval {}, save file to {}'.format(expname, results_file))
155 |             for seed in seeds:
156 |                 args.seed = seed
157 |                 set_seed(args.seed)
158 |                 o_Acc, m_Acc, s = eval_by_cluster(args, epoch, mode='svc')
159 |                 print('Epoch {:02d} Seed {}: oAcc {:.2f} mAcc {:.2f} IoUs'.format(epoch, seed, o_Acc, m_Acc) + s)
160 |                 results.append('Epoch {:02d} Seed {}: oAcc {:.2f} mAcc {:.2f} IoUs'.format(epoch, seed, o_Acc, m_Acc) + s + '\n')
161 |             write_list(results_file, results)
162 |             print('\n')
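Assuming checkpoints live under `ckpt/ScanNet/<expname>/svc/` and the default data paths above, evaluation is driven entirely by this `__main__` block: fill in `expnames` and `epoches`, run the script, and the seed-averaged oAcc/mAcc/mIoU lines are written to `results/<expname>/eval_<epoch>.txt`.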
--------------------------------------------------------------------------------
/pointdc_mk/figs/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCUT-BIP-Lab/PointDC/99f4f30ee4e2fa2d173c3881196afec4b02bb606/pointdc_mk/figs/framework.jpg
--------------------------------------------------------------------------------
/pointdc_mk/figs/scannet_visualization.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCUT-BIP-Lab/PointDC/99f4f30ee4e2fa2d173c3881196afec4b02bb606/pointdc_mk/figs/scannet_visualization.jpg
--------------------------------------------------------------------------------
/pointdc_mk/lib/aug_tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.linalg import expm, norm
3 | import scipy.ndimage, scipy.interpolate  # scipy.interpolate is used by elastic_coords below
4 | 
5 | def M(axis, theta):  # rotation matrix about `axis` by angle `theta`, via the matrix exponential
6 |     return expm(np.cross(np.eye(3), axis / norm(axis) * theta))
7 | 
8 | 
9 | class trans_coords:
10 |     def __init__(self, shift_ratio):
11 |         self.ratio = shift_ratio
12 | 
13 |     def __call__(self, coords):
14 |         shift = (np.random.uniform(0, 1, 3) * self.ratio)
15 |         return coords + shift
16 | 
17 | 
18 | class rota_coords:
19 |     def __init__(self, rotation_bound = ((-np.pi/32, np.pi/32), (-np.pi/32, np.pi/32), (-np.pi, np.pi))):
20 |         self.rotation_bound = rotation_bound
21 | 
22 |     def __call__(self, coords):
23 |         rot_mats = []
24 |         for axis_ind, rot_bound in enumerate(self.rotation_bound):
25 |             theta = 0
26 |             axis = np.zeros(3)
27 |             axis[axis_ind] = 1
28 |             if rot_bound is not None:
29 |                 theta = np.random.uniform(*rot_bound)
30 |             rot_mats.append(M(axis, theta))
31 |         # Use random order
32 |         np.random.shuffle(rot_mats)
33 |         rot_mat = rot_mats[0] @ rot_mats[1] @ rot_mats[2]
34 |         return coords.dot(rot_mat)
35 | 
36 | 
37 | class scale_coords:
38 |     def __init__(self, scale_bound=(0.8, 1.25)):
39 |         self.scale_bound = scale_bound
40 | 
41 |     def __call__(self, coords):
42 |         scale = np.random.uniform(*self.scale_bound)
43 |         return coords*scale
44 | 
45 | class elastic_coords:
46 |     def __init__(self, voxel_size):
47 |         self.voxel_size = voxel_size
48 | 
49 |     def __call__(self, coords, gran, mag):
50 |         blur0 = np.ones((3, 1, 1)).astype('float32') / 3
51 |         blur1 = np.ones((1, 3, 1)).astype('float32') / 3
52 |         blur2 = np.ones((1, 1, 3)).astype('float32') / 3
53 | 
54 |         bb = (np.abs(coords).max(0).astype(np.int32) // gran + 3).astype(np.int32)
55 |         noise = [np.random.randn(bb[0], bb[1], bb[2]).astype('float32') for _ in range(3)]
56 |         noise = [scipy.ndimage.filters.convolve(n, blur0, mode='constant', cval=0) for n in noise]
57 |         noise = [scipy.ndimage.filters.convolve(n, blur1, mode='constant', cval=0) for n in noise]
58 |         noise = [scipy.ndimage.filters.convolve(n, blur2, mode='constant', cval=0) for n in noise]
59 |         noise = [scipy.ndimage.filters.convolve(n, blur0, mode='constant', cval=0) for n in noise]
60 |         noise = [scipy.ndimage.filters.convolve(n, blur1, mode='constant', cval=0) for n in noise]
61 |         noise = [scipy.ndimage.filters.convolve(n, blur2, mode='constant', cval=0) for n in noise]
62 |         ax = [np.linspace(-(b - 1) * gran, (b - 1) * gran, b) for b in bb]
63 |         interp = [scipy.interpolate.RegularGridInterpolator(ax, n, bounds_error=0, fill_value=0) for n in noise]
64 | 
65 |         def g(x_):
66 |             return np.hstack([i(x_)[:, None] for i in interp])
67 | 
68 |         return coords + g(coords) * mag
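A small composition sketch for the transforms above (the concrete ratios and bounds here are illustrative; the training scripts configure their own):

```python
import numpy as np

coords = (np.random.rand(1000, 3) * 10).astype(np.float32)
transforms = [rota_coords(), scale_coords((0.8, 1.25)), trans_coords(shift_ratio=2.5)]
for t in transforms:
    coords = t(coords)   # each transform maps an (N, 3) array to an (N, 3) array
```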
--------------------------------------------------------------------------------
/pointdc_mk/lib/helper_ply.py:
--------------------------------------------------------------------------------
1 | #
2 | #
3 | #      0===============================0
4 | #      |    PLY files reader/writer    |
5 | #      0===============================0
6 | #
7 | #
8 | # ----------------------------------------------------------------------------------------------------------------------
9 | #
10 | #      function to read/write .ply files
11 | #
12 | # ----------------------------------------------------------------------------------------------------------------------
13 | #
14 | #      Hugues THOMAS - 10/02/2017
15 | #
16 | 
17 | 
18 | # ----------------------------------------------------------------------------------------------------------------------
19 | #
20 | #          Imports and global variables
21 | #      \**********************************/
22 | #
23 | 
24 | 
25 | # Basic libs
26 | import numpy as np
27 | import sys
28 | 
29 | 
30 | # Define PLY types
31 | ply_dtypes = dict([
32 |     (b'int8', 'i1'),
33 |     (b'char', 'i1'),
34 |     (b'uint8', 'u1'),
35 |     (b'uchar', 'u1'),
36 |     (b'int16', 'i2'),
37 |     (b'short', 'i2'),
38 |     (b'uint16', 'u2'),
39 |     (b'ushort', 'u2'),
40 |     (b'int32', 'i4'),
41 |     (b'int', 'i4'),
42 |     (b'uint32', 'u4'),
43 |     (b'uint', 'u4'),
44 |     (b'float32', 'f4'),
45 |     (b'float', 'f4'),
46 |     (b'float64', 'f8'),
47 |     (b'double', 'f8')
48 | ])
49 | 
50 | # Numpy reader format
51 | valid_formats = {'ascii': '', 'binary_big_endian': '>',
52 |                  'binary_little_endian': '<'}
53 | 
54 | 
55 | # ----------------------------------------------------------------------------------------------------------------------
56 | #
57 | #           Functions
58 | #       \***************/
59 | #
60 | 
61 | 
62 | def parse_header(plyfile, ext):
63 |     # Variables
64 |     line = []
65 |     properties = []
66 |     num_points = None
67 | 
68 |     while b'end_header' not in line and line != b'':
69 |         line = plyfile.readline()
70 | 
71 |         if b'element' in line:
72 |             line = line.split()
73 |             num_points = int(line[2])
74 | 
75 |         elif b'property' in line:
76 |             line = line.split()
77 |             properties.append((line[2].decode(), ext + ply_dtypes[line[1]]))
78 | 
79 |     return num_points, properties
80 | 
81 | 
82 | def parse_mesh_header(plyfile, ext):
83 |     # Variables
84 |     line = []
85 |     vertex_properties = []
86 |     num_points = None
87 |     num_faces = None
88 |     current_element = None
89 | 
90 | 
91 |     while b'end_header' not in line and line != b'':
92 |         line = plyfile.readline()
93 | 
94 |         # Find point element
95 |         if b'element vertex' in line:
96 |             current_element = 'vertex'
97 |             line = line.split()
98 |             num_points = int(line[2])
99 | 
100 |         elif b'element face' in line:
101 |             current_element = 'face'
102 |             line = line.split()
103 |             num_faces = int(line[2])
104 | 
105 |         elif b'property' in line:
106 |             if current_element == 'vertex':
107 |                 line = line.split()
108 |                 vertex_properties.append((line[2].decode(), ext + ply_dtypes[line[1]]))
109 |             elif current_element == 'face':
110 |                 if not line.startswith(b'property list uchar int'):
111 |                     raise ValueError('Unsupported faces property : ' + str(line))
112 | 
113 |     return num_points, num_faces, vertex_properties
114 | 
115 | 
116 | def read_ply(filename, triangular_mesh=False):
117 |     """
118 |     Read ".ply" files
119 | 
120 |     Parameters
121 |     ----------
122 |     filename : string
123 |         the name of the file to read.
124 | 
125 |     Returns
126 |     -------
127 |     result : array
128 |         data stored in the file
129 | 
130 |     Examples
131 |     --------
132 |     Store data in file
133 | 
134 |     >>> points = np.random.rand(5, 3)
135 |     >>> values = np.random.randint(2, size=5)
136 |     >>> write_ply('example.ply', [points, values], ['x', 'y', 'z', 'values'])
137 | 
138 |     Read the file
139 | 
140 |     >>> data = read_ply('example.ply')
141 |     >>> values = data['values']
142 |     array([0, 0, 1, 1, 0])
143 | 
144 |     >>> points = np.vstack((data['x'], data['y'], data['z'])).T
145 |     array([[ 0.466  0.595  0.324]
146 |            [ 0.538  0.407  0.654]
147 |            [ 0.850  0.018  0.988]
148 |            [ 0.395  0.394  0.363]
149 |            [ 0.873  0.996  0.092]])
150 | 
151 |     """
152 | 
153 |     with open(filename, 'rb') as plyfile:
154 | 
155 | 
156 |         # Check if the file starts with ply
157 |         if b'ply' not in plyfile.readline():
158 |             raise ValueError('The file does not start with the word ply')
159 | 
160 |         # get binary_little/big or ascii
161 |         fmt = plyfile.readline().split()[1].decode()
162 |         if fmt == "ascii":
163 |             raise ValueError('The file is not binary')
164 | 
165 |         # get extension for building the numpy dtypes
166 |         ext = valid_formats[fmt]
167 | 
168 |         # PointCloud reader vs mesh reader
169 |         if triangular_mesh:
170 | 
171 |             # Parse header
172 |             num_points, num_faces, properties = parse_mesh_header(plyfile, ext)
173 | 
174 |             # Get point data
175 |             vertex_data = np.fromfile(plyfile, dtype=properties, count=num_points)
176 | 
177 |             # Get face data
178 |             face_properties = [('k', ext + 'u1'),
179 |                                ('v1', ext + 'i4'),
180 |                                ('v2', ext + 'i4'),
181 |                                ('v3', ext + 'i4')]
182 |             faces_data = np.fromfile(plyfile, dtype=face_properties, count=num_faces)
183 | 
184 |             # Return vertex data and concatenated faces
185 |             faces = np.vstack((faces_data['v1'], faces_data['v2'], faces_data['v3'])).T
186 |             data = [vertex_data, faces]
187 | 
188 |         else:
189 | 
190 |             # Parse header
191 |             num_points, properties = parse_header(plyfile, ext)
192 | 
193 |             # Get data
194 |             data = np.fromfile(plyfile, dtype=properties, count=num_points)
195 | 
196 |     return data
197 | 
198 | 
199 | def header_properties(field_list, field_names):
200 | 
201 |     # List of lines to write
202 |     lines = []
203 | 
204 |     # First line describing element vertex
205 |     lines.append('element vertex %d' % field_list[0].shape[0])
206 | 
207 |     # Properties lines
208 |     i = 0
209 |     for fields in field_list:
210 |         for field in fields.T:
211 |             lines.append('property %s %s' % (field.dtype.name, field_names[i]))
212 |             i += 1
213 | 
214 |     return lines
215 | 
216 | 
217 | def write_ply(filename, field_list, field_names, triangular_faces=None):
218 |     """
219 |     Write ".ply" files
220 | 
221 |     Parameters
222 |     ----------
223 |     filename : string
224 |         the name of the file to which the data is saved. A '.ply' extension will be appended to the
225 |         file name if it does not already have one.
226 | 
227 |     field_list : list, tuple, numpy array
228 |         the fields to be saved in the ply file. Either a numpy array, a list of numpy arrays or a
229 |         tuple of numpy arrays. Each 1D numpy array and each column of 2D numpy arrays are considered
230 |         as one field.
231 | 
232 |     field_names : list
233 |         the name of each fields as a list of strings. Has to be the same length as the number of
234 |         fields.
235 | 
236 |     Examples
237 |     --------
238 |     >>> points = np.random.rand(10, 3)
239 |     >>> write_ply('example1.ply', points, ['x', 'y', 'z'])
240 | 
241 |     >>> values = np.random.randint(2, size=10)
242 |     >>> write_ply('example2.ply', [points, values], ['x', 'y', 'z', 'values'])
243 | 
244 |     >>> colors = np.random.randint(255, size=(10,3), dtype=np.uint8)
245 |     >>> field_names = ['x', 'y', 'z', 'red', 'green', 'blue', 'values']
246 |     >>> write_ply('example3.ply', [points, colors, values], field_names)
247 | 
248 |     """
249 | 
250 |     # Format list input to the right form
251 |     field_list = list(field_list) if (type(field_list) == list or type(field_list) == tuple) else list((field_list,))
252 |     for i, field in enumerate(field_list):
253 |         if field.ndim < 2:
254 |             field_list[i] = field.reshape(-1, 1)
255 |         if field.ndim > 2:
256 |             print('fields have more than 2 dimensions')
257 |             return False
258 | 
259 |     # check all fields have the same number of data
260 |     n_points = [field.shape[0] for field in field_list]
261 |     if not np.all(np.equal(n_points, n_points[0])):
262 |         print('wrong field dimensions')
263 |         return False
264 | 
265 |     # Check if field_names and field_list have same nb of column
266 |     n_fields = np.sum([field.shape[1] for field in field_list])
267 |     if (n_fields != len(field_names)):
268 |         print('wrong number of field names')
269 |         return False
270 | 
271 |     # Add extension if not there
272 |     if not filename.endswith('.ply'):
273 |         filename += '.ply'
274 | 
275 |     # open in text mode to write the header
276 |     with open(filename, 'w') as plyfile:
277 | 
278 |         # First magical word
279 |         header = ['ply']
280 | 
281 |         # Encoding format
282 |         header.append('format binary_' + sys.byteorder + '_endian 1.0')
283 | 
284 |         # Points properties description
285 |         header.extend(header_properties(field_list, field_names))
286 | 
287 |         # Add faces if needed
288 |         if triangular_faces is not None:
289 |             header.append('element face {:d}'.format(triangular_faces.shape[0]))
290 |             header.append('property list uchar int vertex_indices')
291 | 
292 |         # End of header
293 |         header.append('end_header')
294 | 
295 |         # Write all lines
296 |         for line in header:
297 |             plyfile.write("%s\n" % line)
298 | 
299 |     # open in binary/append to use tofile
300 |     with open(filename, 'ab') as plyfile:
301 | 
302 |         # Create a structured array
303 |         i = 0
304 |         type_list = []
305 |         for fields in field_list:
306 |             for field in fields.T:
307 |                 type_list += [(field_names[i], field.dtype.str)]
308 |                 i += 1
309 |         data = np.empty(field_list[0].shape[0], dtype=type_list)
310 |         i = 0
311 |         for fields in field_list:
312 |             for field in fields.T:
313 |                 data[field_names[i]] = field
314 |                 i += 1
315 | 
316 |         data.tofile(plyfile)
317 | 
318 |         if triangular_faces is not None:
319 |             triangular_faces = triangular_faces.astype(np.int32)
320 |             type_list = [('k', 'uint8')] + [(str(ind), 'int32') for ind in range(3)]
321 |             data = np.empty(triangular_faces.shape[0], dtype=type_list)
322 |             data['k'] = np.full((triangular_faces.shape[0],), 3, dtype=np.uint8)
323 |             data['0'] = triangular_faces[:, 0]
324 |             data['1'] = triangular_faces[:, 1]
325 |             data['2'] = triangular_faces[:, 2]
326 |             data.tofile(plyfile)
327 | 
328 |     return True
329 | 
330 | 
331 | def describe_element(name, df):
332 |     """ Takes the columns of the dataframe and builds a ply-like description
333 | 
334 |     Parameters
335 |     ----------
336 |     name: str
337 |     df: pandas DataFrame
338 | 
339 |     Returns
340 |     -------
341 |     element: list[str]
342 |     """
343 |     property_formats = {'f': 'float', 'u': 'uchar', 'i': 'int'}
344 |     element = ['element ' + name + ' ' + str(len(df))]
345 | 
346 |     if name == 'face':
347 |         element.append("property list uchar int points_indices")
348 | 
349 |     else:
350 |         for i in range(len(df.columns)):
351 |             # get first letter of dtype to infer format
352 |             f = property_formats[str(df.dtypes[i])[0]]
353 |             element.append('property ' + f + ' ' + df.columns.values[i])
354 | 
355 |     return element
356 | 
--------------------------------------------------------------------------------
/pointdc_mk/lib/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import os, faiss, shutil, random
6 | import faiss
7 | import shutil
8 | from sklearn.cluster import KMeans
9 | import MinkowskiEngine as ME
10 | from sklearn.utils.linear_assignment_ import linear_assignment # pip install scikit-learn==0.22.2
11 | 
12 | from tqdm import tqdm
13 | from torch_scatter import scatter
14 | 
15 | class AverageMeter(object):
16 |     """Computes and stores the average and current value"""
17 |     def __init__(self):
18 |         self.reset()
19 | 
20 |     def reset(self):
21 |         self.val = 0
22 |         self.avg = 0
23 |         self.sum = 0
24 |         self.count = 0
25 | 
26 |     def update(self, val, n=1):
27 |         self.val = val
28 |         self.sum += val * n
29 |         self.count += n
30 |         self.avg = self.sum / self.count
31 | 
32 | class FocalLoss(nn.Module):
33 |     def __init__(self, weight=None, ignore_index=None, gamma=2):
34 |         super(FocalLoss,self).__init__()
35 |         self.gamma = gamma
36 |         self.weight = weight # per-class weights, given as a tensor
37 |         self.ignore_index = ignore_index
38 | 
39 |     def forward(self, preds, labels):
40 |         """
41 |         preds: raw logits
42 |         labels: ground-truth class indices
43 |         """
44 |         if self.ignore_index is not None:
45 |             mask = torch.nonzero(labels!=self.ignore_index)
46 |             preds = preds[mask].squeeze(1)
47 |             labels = labels[mask].squeeze(1)
48 | 
49 |         preds = F.softmax(preds,dim=1)
50 |         eps = 1e-7
51 | 
52 |         target = self.one_hot(preds.size(1), labels)
53 | 
54 |         ce = -torch.log(preds+eps) * target
55 |         floss = torch.pow((1-preds), self.gamma) * ce
56 |         if self.weight is not None:
57 |             floss = torch.mul(floss, self.weight)
58 |         floss = torch.sum(floss, dim=1)
59 |         return torch.mean(floss)
60 | 
61 |     def one_hot(self, num, labels):
62 |         one = torch.zeros((labels.size(0),num)).cuda()
63 |         one[range(labels.size(0)),labels] = 1
64 |         return one
65 | 
66 | class MseMaskLoss(nn.Module):
67 |     def __init__(self):
68 |         super(MseMaskLoss, self).__init__()
69 | 
70 |     def forward(self, sourcefeats, targetfeats):
71 |         targetfeats = F.normalize(targetfeats, dim=1, p=2)
72 |         sourcefeats = F.normalize(sourcefeats, dim=1, p=2)
73 |         compute_index = torch.where(torch.abs(targetfeats).sum(dim=1)>0)
74 |         mseloss = (sourcefeats[compute_index] - targetfeats[compute_index])**2
75 | 
76 |         return mseloss.mean()
77 | 
78 | def worker_init_fn(seed):
79 |     return lambda x: np.random.seed(seed + x)
80 | 
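A quick shape check for `FocalLoss` (tensors are illustrative; `one_hot` allocates on the GPU, so a CUDA device is assumed, as everywhere else in this codebase):

```python
import torch

criterion = FocalLoss(ignore_index=-1, gamma=2)
logits = torch.randn(8, 13).cuda()           # 8 points, 13 primitive classes
labels = torch.randint(0, 13, (8,)).cuda()   # entries set to -1 would be ignored
loss = criterion(logits, labels)             # scalar focal loss over the valid points
```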
81 | def set_seed(seed):
82 |     """
83 |     Unfortunately, backward() of [interpolate] functional seems to be never deterministic.
84 | 
85 |     Below are related threads:
86 |     https://github.com/pytorch/pytorch/issues/7068
87 |     https://discuss.pytorch.org/t/non-deterministic-behavior-of-pytorch-upsample-interpolate/42842?u=sbelharbi
88 |     """
89 |     # Use random seed.
90 |     random.seed(seed)
91 |     os.environ['PYTHONHASHSEED'] = str(seed)
92 |     np.random.seed(seed)
93 |     torch.manual_seed(seed)
94 |     torch.cuda.manual_seed(seed)
95 |     torch.cuda.manual_seed_all(seed)
96 |     torch.backends.cudnn.deterministic = True
97 |     torch.backends.cudnn.benchmark = False
98 |     torch.backends.cudnn.enabled = False
99 | 
100 | def init_get_sp_feature(args, loader, model, submodel=None):
101 |     loader.dataset.mode = 'cluster'
102 | 
103 |     region_feats_list = []
104 |     model.eval()
105 |     with torch.no_grad():
106 |         for batch_idx, data in enumerate(loader):
107 |             coords, features, _, labels, inverse_map, pseudo_labels, inds, region, index, scenenames = data
108 |             region = region.squeeze()
109 |             raw_region = region.clone()
110 | 
111 |             in_field = ME.TensorField(features, coords, device=0)
112 | 
113 |             if submodel is not None:
114 |                 feats_TensorField = model(in_field)
115 |                 feats_nonorm = submodel(feats_TensorField)
116 |             else:
117 |                 feats_nonorm = model(in_field)
118 |             feats_nonorm = feats_nonorm[inverse_map.long()] # map voxel features back to raw points
119 |             feats_norm = F.normalize(feats_nonorm, dim=1) ## NOTE optional normalization; makes little difference in practice
120 | 
121 |             region_feats = scatter(feats_norm, raw_region.cuda(), dim=0, reduce='mean')
122 | 
123 |             valid_mask = labels!=-1 # keep only the labeled points used for training
124 |             # labels = labels[valid_mask]
125 |             region_masked = region[valid_mask].long()
126 |             region_masked_num = torch.unique(region_masked)
127 |             region_masked_feats = region_feats[region_masked_num]
128 |             region_masked_feats_norm = F.normalize(region_masked_feats, dim=1).cpu()
129 | 
130 |             region_feats_list.append(region_masked_feats_norm)
131 | 
132 |             torch.cuda.empty_cache()
133 |             torch.cuda.synchronize(torch.device("cuda"))
134 | 
135 |     return region_feats_list
136 | 
137 | def init_get_pseudo(args, loader, model, centroids_norm, submodel=None):
138 | 
139 |     pseudo_label_folder = args.pseudo_path + '/'
140 |     if not os.path.exists(pseudo_label_folder): os.makedirs(pseudo_label_folder)
141 | 
142 |     all_pseudo = []
143 |     all_label = []
144 |     model.eval()
145 |     with torch.no_grad():
146 |         for batch_idx, data in enumerate(loader):
147 |             coords, features, _, labels, inverse_map, pseudo_labels, \
148 |                 inds, region, index, scenenames = data
149 |             region = region.squeeze()
150 |             raw_region = region.clone()
151 | 
152 |             in_field = ME.TensorField(features, coords, device=0)
153 | 
154 |             if submodel is not None:
155 |                 feats_TensorField = model(in_field)
156 |                 feats_nonorm = submodel(feats_TensorField)
157 |             else:
158 |                 feats_nonorm = model(in_field)
159 |             feats_nonorm = feats_nonorm[inverse_map.long()] # map voxel features back to raw points
160 |             feats_norm = F.normalize(feats_nonorm, dim=1) ## NOTE optional normalization
161 | 
162 |             region_feats = scatter(feats_norm, raw_region.cuda(), dim=0, reduce='mean')
163 | 
164 |             ## predict per-superpoint labels
165 |             region_scores = F.linear(F.normalize(region_feats), centroids_norm)
166 |             region_scores_maxs, region_preds = torch.max(region_scores, dim=1)
167 |             region_scores_maxs, region_preds = region_scores_maxs.cpu(), region_preds.cpu()
168 |             preds = region_preds[region] ## all point preds
169 |             region_scores_maxs = region_scores_maxs[region] ## all point scores
170 | 
171 |             pseudo_label_file = pseudo_label_folder + '/' + scenenames[0] + '.npy'
172 |             np.save(pseudo_label_file, preds)
173 | 
174 |             all_label.append(labels)
175 |             all_pseudo.append(preds)
176 | 
177 |             torch.cuda.empty_cache()
178 |             torch.cuda.synchronize(torch.device("cuda"))
179 | 
180 |     all_pseudo = np.concatenate(all_pseudo)
181 |     all_label = np.concatenate(all_label)
182 | 
183 |     return all_pseudo, all_label
184 | 
185 | def get_fixclassifier(in_channel, centroids_num, centroids):
186 |     classifier = nn.Linear(in_features=in_channel, out_features=centroids_num, bias=False)
187 |     centroids = F.normalize(centroids, dim=1)
188 |     classifier.weight.data = centroids
189 |     for para in classifier.parameters():
190 |         para.requires_grad = False
191 |     return classifier
192 | 
193 | def compute_hist(normal, bins=10, min=-1, max=1):
194 |     ## normal : [N, 3]
195 |     normal = F.normalize(normal)
196 |     relation = torch.mm(normal, normal.t())
197 |     relation = torch.triu(relation, diagonal=0) # top-half matrix
198 |     hist = torch.histc(relation, bins, min, max)
199 |     # hist = torch.histogram(relation, bins, range=(-1, 1))
200 |     hist /= hist.sum()
201 | 
202 |     return hist
203 | 
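`compute_hist` summarizes a patch's geometry as a normalized histogram of pairwise normal-vector similarities; a quick CPU-only call (inputs illustrative):

```python
import torch

normals = torch.randn(50, 3)                         # random stand-ins for estimated normals
hist = compute_hist(normals, bins=10, min=-1, max=1)
print(hist.sum())                                    # tensor(1.) -- normalized to a distribution
```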
204 | def faiss_cluster(args, sp_feats, metric='cosin'):
205 |     dim = sp_feats.shape[-1]
206 | 
207 |     # define faiss module
208 |     res = faiss.StandardGpuResources()
209 |     fcfg = faiss.GpuIndexFlatConfig()
210 |     fcfg.useFloat16 = False
211 |     fcfg.device = 0 #NOTE: Single GPU only.
212 |     if metric == 'l2':
213 |         faiss_module = faiss.GpuIndexFlatL2(res, dim, fcfg) # Euclidean (L2) distance
214 |     elif metric == 'cosin':
215 |         faiss_module = faiss.GpuIndexFlatIP(res, dim, fcfg) # cosine similarity (inner product on normalized feats)
216 |     clus = faiss.Clustering(dim, args.primitive_num)
217 |     clus.seed = np.random.randint(args.seed)
218 |     clus.niter = 80
219 | 
220 |     # train
221 |     clus.train(sp_feats, faiss_module)
222 |     centroids = faiss.vector_float_to_array(clus.centroids).reshape(args.primitive_num, dim).astype('float32')
223 |     centroids_norm = F.normalize(torch.tensor(centroids), dim=1)
224 |     # D, I = faiss_module.search(sp_feats, 1)
225 | 
226 |     return None, centroids_norm
227 | 
228 | def cache_codes(args):
229 |     tardir = os.path.join(args.save_path, 'cache_code')
230 |     if not os.path.exists(tardir):
231 |         os.makedirs(tardir)
232 |     try:
233 |         all_files = os.listdir('./')
234 |         pyfile_list = [file for file in all_files if file.endswith(".py")]
235 |         shutil.copytree(r'./lib', os.path.join(tardir, r'lib'))
236 |         shutil.copytree(r'./data_prepare', os.path.join(tardir, r'data_prepare'))
237 |         shutil.copytree(r'./datasets', os.path.join(tardir, r'datasets'))
238 |         for pyfile in pyfile_list:
239 |             shutil.copy(pyfile, os.path.join(tardir, pyfile))
240 |     except Exception: # best-effort code snapshot; ignore copy failures
241 |         pass
242 | 
243 | def compute_seg_results(args, all_labels, all_preds):
244 |     '''Unsupervised, Match pred to gt'''
245 |     sem_num = args.semantic_class
246 |     mask = (all_labels >= 0) & (all_labels < sem_num)
247 |     histogram = np.bincount(sem_num * all_labels[mask] + all_preds[mask], minlength=sem_num ** 2).reshape(sem_num, sem_num)
248 |     '''Hungarian Matching'''
249 |     m = linear_assignment(histogram.max() - histogram)
250 |     o_Acc = histogram[m[:, 0], m[:, 1]].sum() / histogram.sum()*100.
251 |     m_Acc = np.mean(histogram[m[:, 0], m[:, 1]] / histogram.sum(1))*100
252 |     hist_new = np.zeros((sem_num, sem_num))
253 |     for idx in range(sem_num):
254 |         hist_new[:, idx] = histogram[:, m[idx, 1]]
255 |     '''Final Metrics'''
256 |     tp = np.diag(hist_new)
257 |     fp = np.sum(hist_new, 0) - tp
258 |     fn = np.sum(hist_new, 1) - tp
259 |     IoUs = tp / (tp + fp + fn + 1e-8)
260 |     m_IoU = np.nanmean(IoUs)
261 |     s = '| mIoU {:5.2f} | '.format(100 * m_IoU)
262 |     for IoU in IoUs:
263 |         s += '{:5.2f} '.format(100 * IoU)
264 | 
265 |     return o_Acc, m_Acc, s
266 | 
267 | def write_list(file_path, contents):
268 |     with open(file_path, 'w') as file:
269 |         for content in contents:
270 |             file.write(content)
271 | 
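The Hungarian matching inside `compute_seg_results` is what makes unsupervised accuracy well-defined: predicted cluster ids are permuted to best align with ground-truth classes before any metric is computed. A toy version using SciPy's equivalent of the pinned sklearn `linear_assignment` (arrays illustrative):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

labels = np.array([0, 0, 1, 1, 2, 2])
preds  = np.array([2, 2, 0, 0, 1, 1])                # a perfect clustering with permuted ids
hist = np.bincount(3 * labels + preds, minlength=9).reshape(3, 3)
row, col = linear_sum_assignment(hist.max() - hist)  # maximize matched mass
print(hist[row, col].sum() / hist.sum())             # 1.0 after optimal matching
```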
--------------------------------------------------------------------------------
/pointdc_mk/lib/utils_s3dis.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import os, faiss, shutil, random
6 | import faiss
7 | import shutil
8 | from sklearn.cluster import KMeans
9 | import MinkowskiEngine as ME
10 | from sklearn.utils.linear_assignment_ import linear_assignment # pip install scikit-learn==0.22.2
11 | 
12 | from tqdm import tqdm
13 | from torch_scatter import scatter
14 | 
15 | class AverageMeter(object):
16 |     """Computes and stores the average and current value"""
17 |     def __init__(self):
18 |         self.reset()
19 | 
20 |     def reset(self):
21 |         self.val = 0
22 |         self.avg = 0
23 |         self.sum = 0
24 |         self.count = 0
25 | 
26 |     def update(self, val, n=1):
27 |         self.val = val
28 |         self.sum += val * n
29 |         self.count += n
30 |         self.avg = self.sum / self.count
31 | 
32 | class FocalLoss(nn.Module):
33 |     def __init__(self, weight=None, ignore_index=None, gamma=2):
34 |         super(FocalLoss,self).__init__()
35 |         self.gamma = gamma
36 |         self.weight = weight # per-class weights, given as a tensor
37 |         self.ignore_index = ignore_index
38 | 
39 |     def forward(self, preds, labels):
40 |         """
41 |         preds: raw logits
42 |         labels: ground-truth class indices
43 |         """
44 |         if self.ignore_index is not None:
45 |             mask = torch.nonzero(labels!=self.ignore_index)
46 |             preds = preds[mask].squeeze(1)
47 |             labels = labels[mask].squeeze(1)
48 | 
49 |         preds = F.softmax(preds,dim=1)
50 |         eps = 1e-7
51 | 
52 |         target = self.one_hot(preds.size(1), labels)
53 | 
54 |         ce = -torch.log(preds+eps) * target
55 |         floss = torch.pow((1-preds), self.gamma) * ce
56 |         if self.weight is not None:
57 |             floss = torch.mul(floss, self.weight)
58 |         floss = torch.sum(floss, dim=1)
59 |         return torch.mean(floss)
60 | 
61 |     def one_hot(self, num, labels):
62 |         one = torch.zeros((labels.size(0),num)).cuda()
63 |         one[range(labels.size(0)),labels] = 1
64 |         return one
65 | 
66 | class MseMaskLoss(nn.Module):
67 |     def __init__(self):
68 |         super(MseMaskLoss, self).__init__()
69 | 
70 |     def forward(self, sourcefeats, targetfeats):
71 |         targetfeats = F.normalize(targetfeats, dim=1, p=2)
72 |         sourcefeats = F.normalize(sourcefeats, dim=1, p=2)
73 |         mseloss = (sourcefeats - targetfeats)**2
74 | 
75 |         return mseloss.mean()
76 | 
77 | def worker_init_fn(seed):
78 |     return lambda x: np.random.seed(seed + x)
79 | 
80 | def set_seed(seed):
81 |     """
82 |     Unfortunately, backward() of [interpolate] functional seems to be never deterministic.
83 | 
84 |     Below are related threads:
85 |     https://github.com/pytorch/pytorch/issues/7068
86 |     https://discuss.pytorch.org/t/non-deterministic-behavior-of-pytorch-upsample-interpolate/42842?u=sbelharbi
87 |     """
88 |     # Use random seed.
89 |     random.seed(seed)
90 |     os.environ['PYTHONHASHSEED'] = str(seed)
91 |     np.random.seed(seed)
92 |     torch.manual_seed(seed)
93 |     torch.cuda.manual_seed(seed)
94 |     torch.cuda.manual_seed_all(seed)
95 |     torch.backends.cudnn.deterministic = True
96 |     torch.backends.cudnn.benchmark = False
97 |     torch.backends.cudnn.enabled = False
98 | 
99 | def init_get_sp_feature(args, loader, model, submodel=None):
100 |     loader.dataset.mode = 'cluster'
101 | 
102 |     region_feats_list = []
103 |     model.eval()
104 |     with torch.no_grad():
105 |         for batch_idx, data in enumerate(loader):
106 |             coords, features, _, labels, inverse_map, pseudo_labels, inds, region, index, scenenames = data
107 |             region = region.squeeze()
108 | 
109 |             in_field = ME.TensorField(features, coords, device=0)
110 | 
111 |             if submodel is not None:
112 |                 feats_TensorField = model(in_field)
113 |                 feats_nonorm = submodel(feats_TensorField)
114 |             else:
115 |                 feats_nonorm = model(in_field)
116 |             feats_nonorm = feats_nonorm[inverse_map.long()] # map voxel features back to raw points
117 |             feats_norm = F.normalize(feats_nonorm, dim=1) ## NOTE optional normalization; makes little difference in practice
118 | 
119 |             valid_mask = region!=-1 # keep only points assigned to a superpoint
120 |             region_feats = scatter(feats_norm[valid_mask], region[valid_mask].cuda(), dim=0, reduce='mean')
121 | 
122 |             region_feats_norm = F.normalize(region_feats, dim=1).cpu()
123 |             region_feats_list.append(region_feats_norm)
124 | 
125 |             torch.cuda.empty_cache()
126 |             torch.cuda.synchronize(torch.device("cuda"))
127 | 
128 |     return region_feats_list
129 | 
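The superpoint pooling above is a single `torch_scatter` call; a toy CPU example of what `scatter(..., reduce='mean')` produces:

```python
import torch
from torch_scatter import scatter

feats  = torch.arange(8.).view(4, 2)          # 4 points with 2-dim features
region = torch.tensor([0, 0, 1, 1])           # superpoint id of each point
sp_feats = scatter(feats, region, dim=0, reduce='mean')
print(sp_feats)                               # tensor([[1., 2.], [5., 6.]])
```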
130 | def init_get_pseudo(args, loader, model, centroids_norm, submodel=None):
131 | 
132 |     pseudo_label_folder = args.pseudo_path + '/'
133 |     if not os.path.exists(pseudo_label_folder): os.makedirs(pseudo_label_folder)
134 | 
135 |     all_pseudo = []
136 |     all_label = []
137 |     model.eval()
138 |     with torch.no_grad():
139 |         for batch_idx, data in enumerate(loader):
140 |             coords, features, _, labels, inverse_map, pseudo_labels, inds, region, index, scenenames = data
141 |             region = region.squeeze()+1
142 |             raw_region = region.clone()
143 | 
144 |             in_field = ME.TensorField(features, coords, device=0)
145 | 
146 |             if submodel is not None:
147 |                 feats_TensorField = model(in_field)
148 |                 feats_nonorm = submodel(feats_TensorField)
149 |             else:
150 |                 feats_nonorm = model(in_field)
151 |             feats_nonorm = feats_nonorm[inverse_map.long()] # map voxel features back to raw points
152 |             feats_norm = F.normalize(feats_nonorm, dim=1) ## NOTE optional normalization
153 | 
154 |             scores = F.linear(feats_norm, F.normalize(centroids_norm))
155 |             preds = torch.argmax(scores, dim=1).cpu() # per-point predictions
156 | 
157 |             region_feats = scatter(feats_norm, raw_region.cuda(), dim=0, reduce='mean')
158 | 
159 |             region_inds = torch.unique(region, sorted=True)
160 |             region_scores = F.linear(F.normalize(region_feats), F.normalize(centroids_norm)) # per-superpoint predictions
161 |             for id in region_inds: # skip id 0, i.e. the shifted "-1" (points without a superpoint)
162 |                 if id != 0:
163 |                     valid_mask = id == region
164 |                     preds[valid_mask] = torch.argmax(region_scores, dim=1).cpu()[id]
165 | 
166 |             preds[labels==-1] = -1
167 |             pseudo_label_file = pseudo_label_folder + '/' + scenenames[0] + '.npy'
168 |             np.save(pseudo_label_file, preds)
169 | 
170 |             all_label.append(labels)
171 |             all_pseudo.append(preds)
172 | 
173 |             torch.cuda.empty_cache()
174 |             torch.cuda.synchronize(torch.device("cuda"))
175 | 
176 |     all_pseudo = np.concatenate(all_pseudo)
177 |     all_label = np.concatenate(all_label)
178 | 
179 |     return all_pseudo.astype('int64'), all_label.astype('int64')
180 | 
181 | def get_fixclassifier(in_channel, centroids_num, centroids):
182 |     classifier = nn.Linear(in_features=in_channel, out_features=centroids_num, bias=False)
183 |     centroids = F.normalize(centroids, dim=1)
184 |     classifier.weight.data = centroids
185 |     for para in classifier.parameters():
186 |         para.requires_grad = False
187 |     return classifier
188 | 
189 | def compute_hist(normal, bins=10, min=-1, max=1):
190 |     ## normal : [N, 3]
191 |     normal = F.normalize(normal)
192 |     relation = torch.mm(normal, normal.t())
193 |     relation = torch.triu(relation, diagonal=0) # top-half matrix
194 |     hist = torch.histc(relation, bins, min, max)
195 |     # hist = torch.histogram(relation, bins, range=(-1, 1))
196 |     hist /= hist.sum()
197 | 
198 |     return hist
199 | 
200 | def faiss_cluster(args, sp_feats, metric='cosin'):
201 |     dim = sp_feats.shape[-1]
202 | 
203 |     # define faiss module
204 |     res = faiss.StandardGpuResources()
205 |     fcfg = faiss.GpuIndexFlatConfig()
206 |     fcfg.useFloat16 = False
207 |     fcfg.device = 0 #NOTE: Single GPU only.
208 | if metric == 'l2': 209 | faiss_module = faiss.GpuIndexFlatL2(res, dim, fcfg) # Euclidean distance 210 | elif metric == 'cosin': 211 | faiss_module = faiss.GpuIndexFlatIP(res, dim, fcfg) # cosine similarity (inner product on L2-normalized features) 212 | clus = faiss.Clustering(dim, args.primitive_num) 213 | clus.seed = np.random.randint(args.seed) 214 | clus.niter = 80 215 | 216 | # train 217 | clus.train(sp_feats, faiss_module) 218 | centroids = faiss.vector_float_to_array(clus.centroids).reshape(args.primitive_num, dim).astype('float32') 219 | centroids_norm = F.normalize(torch.tensor(centroids), dim=1) 220 | # D, I = faiss_module.search(sp_feats, 1) 221 | 222 | return None, centroids_norm 223 | 224 | def cache_codes(args): 225 | tardir = os.path.join(args.save_path, 'cache_code') 226 | if not os.path.exists(tardir): 227 | os.makedirs(tardir) 228 | try: 229 | all_files = os.listdir('./') 230 | pyfile_list = [file for file in all_files if file.endswith(".py")] 231 | shutil.copytree(r'./lib', os.path.join(tardir, r'lib')) 232 | shutil.copytree(r'./data_prepare', os.path.join(tardir, r'data_prepare')) 233 | shutil.copytree(r'./datasets', os.path.join(tardir, r'datasets')) 234 | for pyfile in pyfile_list: 235 | shutil.copy(pyfile, os.path.join(tardir, pyfile)) 236 | except: # best-effort snapshot of the code; ignore copy failures 237 | pass 238 | 239 | def compute_seg_results(args, all_labels, all_preds): 240 | '''Unsupervised evaluation: match predicted clusters to ground-truth classes''' 241 | sem_num = args.semantic_class 242 | mask = (all_labels >= 0) & (all_labels < sem_num) 243 | histogram = np.bincount(sem_num * all_labels[mask] + all_preds[mask], minlength=sem_num ** 2).reshape(sem_num, sem_num) 244 | '''Hungarian Matching''' 245 | m = linear_assignment(histogram.max() - histogram) 246 | o_Acc = histogram[m[:, 0], m[:, 1]].sum() / histogram.sum()*100. 247 | m_Acc = np.mean(histogram[m[:, 0], m[:, 1]] / histogram.sum(1))*100 248 | hist_new = np.zeros((sem_num, sem_num)) 249 | for idx in range(sem_num): 250 | hist_new[:, idx] = histogram[:, m[idx, 1]] 251 | '''Final Metrics''' 252 | tp = np.diag(hist_new) 253 | fp = np.sum(hist_new, 0) - tp 254 | fn = np.sum(hist_new, 1) - tp 255 | IoUs = tp / (tp + fp + fn + 1e-8) 256 | m_IoU = np.nanmean(IoUs) 257 | s = '| mIoU {:5.2f} | '.format(100 * m_IoU) 258 | for IoU in IoUs: 259 | s += '{:5.2f} '.format(100 * IoU) 260 | 261 | return o_Acc, m_Acc, s 262 | 263 | def write_list(file_path, contents): 264 | with open(file_path, 'w') as file: 265 | for content in contents: 266 | file.write(content) 267 | -------------------------------------------------------------------------------- /pointdc_mk/models/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | try: 5 | from .networks import * 6 | from .res16unet import * 7 | from .resunet import * 8 | 9 | _custom_models = sys.modules[__name__] 10 | 11 | def initialize_minkowski_unet( 12 | model_name, in_channels, out_channels, D=3, conv1_kernel_size=3, dilations=[1, 1, 1, 1], **kwargs 13 | ): 14 | net_cls = getattr(_custom_models, model_name) 15 | return net_cls( 16 | in_channels=in_channels, out_channels=out_channels, D=D, conv1_kernel_size=conv1_kernel_size, **kwargs 17 | ) 18 | 19 | 20 | except: 21 | import logging 22 | 23 | log = logging.getLogger(__name__) 24 | log.warning("Could not load Minkowski Engine, please check that it is installed correctly") 25 | -------------------------------------------------------------------------------- /pointdc_mk/models/api_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import MinkowskiEngine as ME
3 | import sys 4 | 5 | from .common import NormType, get_norm 6 | from torch_points3d.core.common_modules import Seq, Identity 7 | 8 | 9 | class ResBlock(ME.MinkowskiNetwork): 10 | """ 11 | Basic ResNet type block 12 | 13 | Parameters 14 | ---------- 15 | input_nc: 16 | Number of input channels 17 | output_nc: 18 | number of output channels 19 | convolution: 20 | Either MinkowskiConvolution or MinkowskiConvolutionTranspose 21 | dimension: 22 | Dimension of the spatial grid 23 | """ 24 | 25 | def __init__(self, input_nc, output_nc, convolution, dimension=3): 26 | ME.MinkowskiNetwork.__init__(self, dimension) 27 | self.block = ( 28 | Seq() 29 | .append( 30 | convolution( 31 | in_channels=input_nc, 32 | out_channels=output_nc, 33 | kernel_size=3, 34 | stride=1, 35 | dilation=1, 36 | bias=False, 37 | dimension=dimension, 38 | ) 39 | ) 40 | .append(ME.MinkowskiBatchNorm(output_nc)) 41 | .append(ME.MinkowskiReLU()) 42 | .append( 43 | convolution( 44 | in_channels=output_nc, 45 | out_channels=output_nc, 46 | kernel_size=3, 47 | stride=1, 48 | dilation=1, 49 | bias=False, 50 | dimension=dimension, 51 | ) 52 | ) 53 | .append(ME.MinkowskiBatchNorm(output_nc)) 54 | .append(ME.MinkowskiReLU()) 55 | ) 56 | 57 | if input_nc != output_nc: 58 | self.downsample = ( 59 | Seq() 60 | .append( 61 | convolution( 62 | in_channels=input_nc, 63 | out_channels=output_nc, 64 | kernel_size=1, 65 | stride=1, 66 | dilation=1, 67 | bias=False, 68 | dimension=dimension, 69 | ) 70 | ) 71 | .append(ME.MinkowskiBatchNorm(output_nc)) 72 | ) 73 | else: 74 | self.downsample = None 75 | 76 | def forward(self, x): 77 | out = self.block(x) 78 | if self.downsample: 79 | out += self.downsample(x) 80 | else: 81 | out += x 82 | return out 83 | 84 | 85 | class BottleneckBlock(ME.MinkowskiNetwork): 86 | """ 87 | Bottleneck block with residual 88 | """ 89 | 90 | def __init__(self, input_nc, output_nc, convolution, dimension=3, reduction=4): 91 | ME.MinkowskiNetwork.__init__(self, dimension); self.block = ( # the parent module must be initialized before submodules are assigned 92 | Seq() 93 | .append( 94 | convolution( 95 | in_channels=input_nc, 96 | out_channels=output_nc // reduction, 97 | kernel_size=1, 98 | stride=1, 99 | dilation=1, 100 | bias=False, 101 | dimension=dimension, 102 | ) 103 | ) 104 | .append(ME.MinkowskiBatchNorm(output_nc // reduction)) 105 | .append(ME.MinkowskiReLU()) 106 | .append( 107 | convolution( 108 | output_nc // reduction, 109 | output_nc // reduction, 110 | kernel_size=3, 111 | stride=1, 112 | dilation=1, 113 | bias=False, 114 | dimension=dimension, 115 | ) 116 | ) 117 | .append(ME.MinkowskiBatchNorm(output_nc // reduction)) 118 | .append(ME.MinkowskiReLU()) 119 | .append( 120 | convolution( 121 | output_nc // reduction, 122 | output_nc, 123 | kernel_size=1, 124 | stride=1, 125 | dilation=1, 126 | bias=False, 127 | dimension=dimension, 128 | ) 129 | ) 130 | .append(ME.MinkowskiBatchNorm(output_nc)) 131 | .append(ME.MinkowskiReLU()) 132 | ) 133 | 134 | if input_nc != output_nc: 135 | self.downsample = ( 136 | Seq() 137 | .append( 138 | convolution( 139 | in_channels=input_nc, 140 | out_channels=output_nc, 141 | kernel_size=1, 142 | stride=1, 143 | dilation=1, 144 | bias=False, 145 | dimension=dimension, 146 | ) 147 | ) 148 | .append(ME.MinkowskiBatchNorm(output_nc)) 149 | ) 150 | else: 151 | self.downsample = None 152 | 153 | def forward(self, x): 154 | out = self.block(x) 155 | if self.downsample: 156 | out += self.downsample(x) 157 | else: 158 | out += x 159 | return out 160 | 161 | 162 | class SELayer(torch.nn.Module): 163 | """ 164 | Squeeze and excite layer 165 | 166 | Parameters 167 | ---------- 168 | channel:
169 | size of the input and output 170 | reduction: 171 | magnitude of the compression 172 | dimension: 173 | dimension of the kernels 174 | """ 175 | 176 | def __init__(self, channel, reduction=16, dimension=3): 177 | # Global coords does not require coords_key 178 | super(SELayer, self).__init__() 179 | self.fc = torch.nn.Sequential( 180 | ME.MinkowskiLinear(channel, channel // reduction), 181 | ME.MinkowskiReLU(), 182 | ME.MinkowskiLinear(channel // reduction, channel), 183 | ME.MinkowskiSigmoid(), 184 | ) 185 | self.pooling = ME.MinkowskiGlobalPooling() 186 | self.broadcast_mul = ME.MinkowskiBroadcastMultiplication() 187 | 188 | def forward(self, x): 189 | y = self.pooling(x) 190 | y = self.fc(y) 191 | return self.broadcast_mul(x, y) 192 | 193 | 194 | class SEBlock(ResBlock): 195 | """ 196 | ResBlock with SE layer 197 | """ 198 | 199 | def __init__(self, input_nc, output_nc, convolution, dimension=3, reduction=16): 200 | super().__init__(input_nc, output_nc, convolution, dimension=dimension) # forward the caller's dimension instead of hardcoding 3 201 | self.SE = SELayer(output_nc, reduction=reduction, dimension=dimension) 202 | 203 | def forward(self, x): 204 | out = self.block(x) 205 | out = self.SE(out) 206 | if self.downsample: 207 | out += self.downsample(x) 208 | else: 209 | out += x 210 | return out 211 | 212 | 213 | class SEBottleneckBlock(BottleneckBlock): 214 | """ 215 | BottleneckBlock with SE layer 216 | """ 217 | 218 | def __init__(self, input_nc, output_nc, convolution, dimension=3, reduction=16): 219 | super().__init__(input_nc, output_nc, convolution, dimension=dimension, reduction=4) # forward the caller's dimension; the bottleneck keeps its own channel reduction of 4 220 | self.SE = SELayer(output_nc, reduction=reduction, dimension=dimension) 221 | 222 | def forward(self, x): 223 | out = self.block(x) 224 | out = self.SE(out) 225 | if self.downsample: 226 | out += self.downsample(x) 227 | else: 228 | out += x 229 | return out 230 | 231 | 232 | _res_blocks = sys.modules[__name__] 233 | 234 | 235 | class ResNetDown(ME.MinkowskiNetwork): 236 | """ 237 | Resnet block that looks like 238 | 239 | in --- strided conv ---- Block ---- sum --[...
N times] 240 | | | 241 | |-- 1x1 - BN --| 242 | """ 243 | 244 | CONVOLUTION = ME.MinkowskiConvolution 245 | 246 | def __init__( 247 | self, down_conv_nn=[], kernel_size=2, dilation=1, dimension=3, stride=2, N=1, block="ResBlock", **kwargs 248 | ): 249 | block = getattr(_res_blocks, block) 250 | ME.MinkowskiNetwork.__init__(self, dimension) 251 | if stride > 1: 252 | conv1_output = down_conv_nn[0] 253 | else: 254 | conv1_output = down_conv_nn[1] 255 | 256 | self.conv_in = ( 257 | Seq() 258 | .append( 259 | self.CONVOLUTION( 260 | in_channels=down_conv_nn[0], 261 | out_channels=conv1_output, 262 | kernel_size=kernel_size, 263 | stride=stride, 264 | dilation=dilation, 265 | bias=False, 266 | dimension=dimension, 267 | ) 268 | ) 269 | .append(ME.MinkowskiBatchNorm(conv1_output)) 270 | .append(ME.MinkowskiReLU()) 271 | ) 272 | 273 | if N > 0: 274 | self.blocks = Seq() 275 | for _ in range(N): 276 | self.blocks.append(block(conv1_output, down_conv_nn[1], self.CONVOLUTION, dimension=dimension)) 277 | conv1_output = down_conv_nn[1] 278 | else: 279 | self.blocks = None 280 | 281 | def forward(self, x): 282 | out = self.conv_in(x) 283 | if self.blocks: 284 | out = self.blocks(out) 285 | return out 286 | 287 | 288 | class ResNetUp(ResNetDown): 289 | """ 290 | Same as Down conv but for the Decoder 291 | """ 292 | 293 | CONVOLUTION = ME.MinkowskiConvolutionTranspose 294 | 295 | def __init__(self, up_conv_nn=[], kernel_size=2, dilation=1, dimension=3, stride=2, N=1, **kwargs): 296 | super().__init__( 297 | down_conv_nn=up_conv_nn, 298 | kernel_size=kernel_size, 299 | dilation=dilation, 300 | dimension=dimension, 301 | stride=stride, 302 | N=N, 303 | **kwargs 304 | ) 305 | 306 | def forward(self, x, skip): 307 | if skip is not None: 308 | inp = ME.cat(x, skip) 309 | else: 310 | inp = x 311 | return super().forward(inp) 312 | -------------------------------------------------------------------------------- /pointdc_mk/models/common.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from enum import Enum 3 | import torch.nn as nn 4 | 5 | import MinkowskiEngine as ME 6 | 7 | 8 | class NormType(Enum): 9 | BATCH_NORM = 0 10 | INSTANCE_NORM = 1 11 | INSTANCE_BATCH_NORM = 2 12 | 13 | 14 | def get_norm(norm_type, n_channels, D, bn_momentum=0.1): 15 | if norm_type == NormType.BATCH_NORM: 16 | return ME.MinkowskiBatchNorm(n_channels, momentum=bn_momentum) 17 | elif norm_type == NormType.INSTANCE_NORM: 18 | return ME.MinkowskiInstanceNorm(n_channels) 19 | elif norm_type == NormType.INSTANCE_BATCH_NORM: 20 | return nn.Sequential( 21 | ME.MinkowskiInstanceNorm(n_channels), ME.MinkowskiBatchNorm(n_channels, momentum=bn_momentum) 22 | ) 23 | else: 24 | raise ValueError(f"Norm type: {norm_type} not supported") 25 | 26 | 27 | def get_nonlinearity(non_type): 28 | if non_type == 'ReLU': 29 | return ME.MinkowskiReLU() 30 | elif non_type == 'ELU': 31 | # return ME.MinkowskiInstanceNorm(num_feats, dimension=dimension) 32 | return ME.MinkowskiELU() 33 | else: 34 | raise ValueError(f'Type {non_type}, not defined') 35 | 36 | 37 | class ConvType(Enum): 38 | """ 39 | Define the kernel region type 40 | """ 41 | 42 | HYPERCUBE = 0, "HYPERCUBE" 43 | SPATIAL_HYPERCUBE = 1, "SPATIAL_HYPERCUBE" 44 | SPATIO_TEMPORAL_HYPERCUBE = 2, "SPATIO_TEMPORAL_HYPERCUBE" 45 | HYPERCROSS = 3, "HYPERCROSS" 46 | SPATIAL_HYPERCROSS = 4, "SPATIAL_HYPERCROSS" 47 | SPATIO_TEMPORAL_HYPERCROSS = 5, "SPATIO_TEMPORAL_HYPERCROSS" 48 | SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS = 6, 
"SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS " 49 | 50 | def __new__(cls, value, name): 51 | member = object.__new__(cls) 52 | member._value_ = value 53 | member.fullname = name 54 | return member 55 | 56 | def __int__(self): 57 | return self.value 58 | 59 | 60 | # Covert the ConvType var to a RegionType var 61 | conv_to_region_type = { 62 | # kernel_size = [k, k, k, 1] 63 | ConvType.HYPERCUBE: ME.RegionType.HYPER_CUBE, 64 | ConvType.SPATIAL_HYPERCUBE: ME.RegionType.HYPER_CUBE, 65 | ConvType.SPATIO_TEMPORAL_HYPERCUBE: ME.RegionType.HYPER_CUBE, 66 | ConvType.HYPERCROSS: ME.RegionType.HYPER_CROSS, 67 | ConvType.SPATIAL_HYPERCROSS: ME.RegionType.HYPER_CROSS, 68 | ConvType.SPATIO_TEMPORAL_HYPERCROSS: ME.RegionType.HYPER_CROSS, 69 | ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS: ME.RegionType.CUSTOM, 70 | } 71 | 72 | int_to_region_type = {0: ME.RegionType.HYPER_CUBE, 1: ME.RegionType.HYPER_CROSS, 2: ME.RegionType.CUSTOM} 73 | 74 | 75 | def convert_region_type(region_type): 76 | """ 77 | Convert the integer region_type to the corresponding RegionType enum object. 78 | """ 79 | return int_to_region_type[region_type] 80 | 81 | 82 | def convert_conv_type(conv_type, kernel_size, D): 83 | assert isinstance(conv_type, ConvType), "conv_type must be of ConvType" 84 | region_type = conv_to_region_type[conv_type] 85 | axis_types = None 86 | if conv_type == ConvType.SPATIAL_HYPERCUBE: 87 | # No temporal convolution 88 | if isinstance(kernel_size, collections.Sequence): 89 | kernel_size = kernel_size[:3] 90 | else: 91 | kernel_size = [kernel_size,] * 3 92 | if D == 4: 93 | kernel_size.append(1) 94 | elif conv_type == ConvType.SPATIO_TEMPORAL_HYPERCUBE: 95 | # conv_type conversion already handled 96 | assert D == 4 97 | elif conv_type == ConvType.HYPERCUBE: 98 | # conv_type conversion already handled 99 | pass 100 | elif conv_type == ConvType.SPATIAL_HYPERCROSS: 101 | if isinstance(kernel_size, collections.Sequence): 102 | kernel_size = kernel_size[:3] 103 | else: 104 | kernel_size = [kernel_size,] * 3 105 | if D == 4: 106 | kernel_size.append(1) 107 | elif conv_type == ConvType.HYPERCROSS: 108 | # conv_type conversion already handled 109 | pass 110 | elif conv_type == ConvType.SPATIO_TEMPORAL_HYPERCROSS: 111 | # conv_type conversion already handled 112 | assert D == 4 113 | elif conv_type == ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS: 114 | # Define the CUBIC conv kernel for spatial dims and CROSS conv for temp dim 115 | if D < 4: 116 | region_type = ME.RegionType.HYPER_CUBE 117 | else: 118 | axis_types = [ME.RegionType.HYPER_CUBE,] * 3 119 | if D == 4: 120 | axis_types.append(ME.RegionType.HYPER_CROSS) 121 | return region_type, axis_types, kernel_size 122 | 123 | 124 | def conv(in_planes, out_planes, kernel_size, stride=1, dilation=1, bias=False, conv_type=ConvType.HYPERCUBE, D=-1): 125 | assert D > 0, "Dimension must be a positive integer" 126 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 127 | kernel_generator = ME.KernelGenerator( 128 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D 129 | ) 130 | 131 | return ME.MinkowskiConvolution( 132 | in_channels=in_planes, 133 | out_channels=out_planes, 134 | kernel_size=kernel_size, 135 | stride=stride, 136 | dilation=dilation, 137 | bias=bias, 138 | kernel_generator=kernel_generator, 139 | dimension=D, 140 | ) 141 | 142 | 143 | def conv_tr( 144 | in_planes, out_planes, kernel_size, upsample_stride=1, dilation=1, bias=False, conv_type=ConvType.HYPERCUBE, D=-1 145 | ): 146 | assert 
D > 0, "Dimension must be a positive integer" 147 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 148 | kernel_generator = ME.KernelGenerator( 149 | kernel_size, upsample_stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D 150 | ) 151 | 152 | return ME.MinkowskiConvolutionTranspose( 153 | in_channels=in_planes, 154 | out_channels=out_planes, 155 | kernel_size=kernel_size, 156 | stride=upsample_stride, 157 | dilation=dilation, 158 | bias=bias, 159 | kernel_generator=kernel_generator, 160 | dimension=D, 161 | ) 162 | 163 | 164 | def avg_pool(kernel_size, stride=1, dilation=1, conv_type=ConvType.HYPERCUBE, in_coords_key=None, D=-1): 165 | assert D > 0, "Dimension must be a positive integer" 166 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 167 | kernel_generator = ME.KernelGenerator( 168 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D 169 | ) 170 | 171 | return ME.MinkowskiAvgPooling( 172 | kernel_size=kernel_size, stride=stride, dilation=dilation, kernel_generator=kernel_generator, dimension=D 173 | ) 174 | 175 | 176 | def avg_unpool(kernel_size, stride=1, dilation=1, conv_type=ConvType.HYPERCUBE, D=-1): 177 | assert D > 0, "Dimension must be a positive integer" 178 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 179 | kernel_generator = ME.KernelGenerator( 180 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D 181 | ) 182 | 183 | return ME.MinkowskiAvgUnpooling( 184 | kernel_size=kernel_size, stride=stride, dilation=dilation, kernel_generator=kernel_generator, dimension=D 185 | ) 186 | 187 | 188 | def sum_pool(kernel_size, stride=1, dilation=1, conv_type=ConvType.HYPERCUBE, D=-1): 189 | assert D > 0, "Dimension must be a positive integer" 190 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 191 | kernel_generator = ME.KernelGenerator( 192 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D 193 | ) 194 | 195 | return ME.MinkowskiSumPooling( 196 | kernel_size=kernel_size, stride=stride, dilation=dilation, kernel_generator=kernel_generator, dimension=D 197 | ) 198 | -------------------------------------------------------------------------------- /pointdc_mk/models/fpn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import MinkowskiEngine as ME 3 | from MinkowskiEngine import MinkowskiNetwork 4 | from MinkowskiEngine import MinkowskiReLU, MinkowskiUnion 5 | import torch.nn.functional as F 6 | 7 | from .common import ConvType, NormType, conv, conv_tr, get_norm, sum_pool 8 | 9 | 10 | class BasicBlockBase(nn.Module): 11 | expansion = 1 12 | NORM_TYPE = NormType.BATCH_NORM 13 | 14 | def __init__( 15 | self, 16 | inplanes, 17 | planes, 18 | stride=1, 19 | dilation=1, 20 | downsample=None, 21 | conv_type=ConvType.HYPERCUBE, 22 | bn_momentum=0.1, 23 | D=3, 24 | ): 25 | super(BasicBlockBase, self).__init__() 26 | 27 | self.conv1 = conv(inplanes, planes, kernel_size=3, stride=stride, dilation=dilation, conv_type=conv_type, D=D) 28 | self.norm1 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) 29 | self.conv2 = conv( 30 | planes, planes, kernel_size=3, stride=1, dilation=dilation, bias=False, conv_type=conv_type, D=D 31 | ) 32 | self.norm2 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) 33 | self.relu = 
MinkowskiReLU(inplace=True) 34 | self.downsample = downsample 35 | 36 | def forward(self, x): 37 | residual = x 38 | out = self.conv1(x) 39 | out = self.norm1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.norm2(out) 44 | 45 | if self.downsample is not None: 46 | residual = self.downsample(x) 47 | 48 | out += residual 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class BasicBlock(BasicBlockBase): 55 | NORM_TYPE = NormType.BATCH_NORM 56 | 57 | 58 | class BasicBlockIN(BasicBlockBase): 59 | NORM_TYPE = NormType.INSTANCE_NORM 60 | 61 | 62 | class BasicBlockINBN(BasicBlockBase): 63 | NORM_TYPE = NormType.INSTANCE_BATCH_NORM 64 | 65 | 66 | class BottleneckBase(nn.Module): 67 | expansion = 4 68 | NORM_TYPE = NormType.BATCH_NORM 69 | 70 | def __init__( 71 | self, 72 | inplanes, 73 | planes, 74 | stride=1, 75 | dilation=1, 76 | downsample=None, 77 | conv_type=ConvType.HYPERCUBE, 78 | bn_momentum=0.1, 79 | D=3, 80 | ): 81 | super(BottleneckBase, self).__init__() 82 | self.conv1 = conv(inplanes, planes, kernel_size=1, D=D) 83 | self.norm1 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) 84 | 85 | self.conv2 = conv(planes, planes, kernel_size=3, stride=stride, dilation=dilation, conv_type=conv_type, D=D) 86 | self.norm2 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum) 87 | 88 | self.conv3 = conv(planes, planes * self.expansion, kernel_size=1, D=D) 89 | self.norm3 = get_norm(self.NORM_TYPE, planes * self.expansion, D, bn_momentum=bn_momentum) 90 | 91 | self.relu = MinkowskiReLU(inplace=True) 92 | self.downsample = downsample 93 | 94 | def forward(self, x): 95 | residual = x 96 | 97 | out = self.conv1(x) 98 | out = self.norm1(out) 99 | out = self.relu(out) 100 | 101 | out = self.conv2(out) 102 | out = self.norm2(out) 103 | out = self.relu(out) 104 | 105 | out = self.conv3(out) 106 | out = self.norm3(out) 107 | 108 | if self.downsample is not None: 109 | residual = self.downsample(x) 110 | 111 | out += residual 112 | out = self.relu(out) 113 | 114 | return out 115 | 116 | 117 | class Bottleneck(BottleneckBase): 118 | NORM_TYPE = NormType.BATCH_NORM 119 | 120 | 121 | class BottleneckIN(BottleneckBase): 122 | NORM_TYPE = NormType.INSTANCE_NORM 123 | 124 | 125 | class BottleneckINBN(BottleneckBase): 126 | NORM_TYPE = NormType.INSTANCE_BATCH_NORM 127 | 128 | 129 | class ResNetBase(MinkowskiNetwork): 130 | BLOCK = None 131 | LAYERS = () 132 | INIT_DIM = 64 133 | PLANES = (64, 128, 256, 512) 134 | OUT_PIXEL_DIST = 32 135 | HAS_LAST_BLOCK = False 136 | CONV_TYPE = ConvType.HYPERCUBE 137 | 138 | def __init__(self, in_channels, out_channels, D, conv1_kernel_size=3, dilations=[1, 1, 1, 1], **kwargs): 139 | super(ResNetBase, self).__init__(D) 140 | self.in_channels = in_channels 141 | self.out_channels = out_channels 142 | self.conv1_kernel_size = conv1_kernel_size 143 | self.dilations = dilations 144 | assert self.BLOCK is not None 145 | assert self.OUT_PIXEL_DIST > 0 146 | 147 | self.network_initialization(in_channels, out_channels, D) 148 | self.weight_initialization() 149 | 150 | def network_initialization(self, in_channels, out_channels, D): 151 | def space_n_time_m(n, m): 152 | return n if D == 3 else [n, n, n, m] 153 | 154 | if D == 4: 155 | self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1) 156 | 157 | dilations = self.dilations 158 | bn_momentum = 1 159 | self.inplanes = self.INIT_DIM 160 | self.conv1 = conv( 161 | in_channels, self.inplanes, kernel_size=space_n_time_m(self.conv1_kernel_size, 1), stride=1, D=D 162 | ) 163 
| 164 | self.bn1 = get_norm(NormType.BATCH_NORM, self.inplanes, D=self.D, bn_momentum=bn_momentum) 165 | self.relu = ME.MinkowskiReLU(inplace=True) 166 | self.pool = sum_pool(kernel_size=space_n_time_m(2, 1), stride=space_n_time_m(2, 1), D=D) 167 | 168 | self.layer1 = self._make_layer( 169 | self.BLOCK, 170 | self.PLANES[0], 171 | self.LAYERS[0], 172 | stride=space_n_time_m(2, 1), 173 | dilation=space_n_time_m(dilations[0], 1), 174 | ) 175 | self.layer2 = self._make_layer( 176 | self.BLOCK, 177 | self.PLANES[1], 178 | self.LAYERS[1], 179 | stride=space_n_time_m(2, 1), 180 | dilation=space_n_time_m(dilations[1], 1), 181 | ) 182 | self.layer3 = self._make_layer( 183 | self.BLOCK, 184 | self.PLANES[2], 185 | self.LAYERS[2], 186 | stride=space_n_time_m(2, 1), 187 | dilation=space_n_time_m(dilations[2], 1), 188 | ) 189 | self.layer4 = self._make_layer( 190 | self.BLOCK, 191 | self.PLANES[3], 192 | self.LAYERS[3], 193 | stride=space_n_time_m(2, 1), 194 | dilation=space_n_time_m(dilations[3], 1), 195 | ) 196 | 197 | self.final = conv(self.PLANES[3] * self.BLOCK.expansion, out_channels, kernel_size=1, bias=True, D=D) 198 | 199 | def weight_initialization(self): 200 | for m in self.modules(): 201 | if isinstance(m, ME.MinkowskiBatchNorm): 202 | nn.init.constant_(m.bn.weight, 1) 203 | nn.init.constant_(m.bn.bias, 0) 204 | 205 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_type=NormType.BATCH_NORM, bn_momentum=0.1): 206 | downsample = None 207 | if stride != 1 or self.inplanes != planes * block.expansion: 208 | downsample = nn.Sequential( 209 | conv(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False, D=self.D), 210 | get_norm(norm_type, planes * block.expansion, D=self.D, bn_momentum=bn_momentum), 211 | ) 212 | layers = [] 213 | layers.append( 214 | block( 215 | self.inplanes, 216 | planes, 217 | stride=stride, 218 | dilation=dilation, 219 | downsample=downsample, 220 | conv_type=self.CONV_TYPE, 221 | D=self.D, 222 | ) 223 | ) 224 | self.inplanes = planes * block.expansion 225 | for i in range(1, blocks): 226 | layers.append(block(self.inplanes, planes, stride=1, dilation=dilation, conv_type=self.CONV_TYPE, D=self.D)) 227 | 228 | return nn.Sequential(*layers) 229 | 230 | def forward(self, x): 231 | x = self.conv1(x) 232 | x = self.bn1(x) 233 | x = self.relu(x) 234 | x = self.pool(x) 235 | 236 | x = self.layer1(x) 237 | x = self.layer2(x) 238 | x = self.layer3(x) 239 | x = self.layer4(x) 240 | 241 | x = self.final(x) 242 | return x 243 | 244 | 245 | class Res16FPNBase(ResNetBase): 246 | BLOCK = None 247 | PLANES = (32, 64, 128, 256, 256, 256, 256, 256) 248 | DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1) 249 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 250 | INIT_DIM = 32 251 | OUT_PIXEL_DIST = 1 252 | NORM_TYPE = NormType.BATCH_NORM 253 | NON_BLOCK_CONV_TYPE = ConvType.SPATIAL_HYPERCUBE 254 | CONV_TYPE = ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS 255 | 256 | # To use the model, must call initialize_coords before forward pass. 
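    # Res16FPNBase keeps the Res16UNet encoder but replaces the decoder with per-scale
    # MinkowskiLinear projections (delayer1-4): each scale is interpolated back onto the
    # input TensorField and the four results are summed, FPN-style, or merged with
    # MinkowskiUnion in 'distill' mode (see forward()).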
257 | # Once data is processed, call clear to reset the model before calling initialize_coords 258 | def __init__(self, in_channels, out_channels, D=3, conv1_kernel_size=5, **kwargs): 259 | super(Res16FPNBase, self).__init__(in_channels, out_channels, D, conv1_kernel_size) 260 | 261 | self.mode = kwargs['mode'] 262 | 263 | def network_initialization(self, in_channels, out_channels, D): 264 | # Setup net_metadata 265 | dilations = self.DILATIONS 266 | bn_momentum = 0.02 267 | 268 | def space_n_time_m(n, m): 269 | return n if D == 3 else [n, n, n, m] 270 | 271 | if D == 4: 272 | self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1) 273 | 274 | self.inplanes = self.INIT_DIM 275 | self.conv0p1s1 = conv( 276 | in_channels, 277 | self.inplanes, 278 | kernel_size=space_n_time_m(self.conv1_kernel_size, 1), 279 | stride=1, 280 | dilation=1, 281 | conv_type=self.NON_BLOCK_CONV_TYPE, 282 | D=D, 283 | ) 284 | 285 | self.bn0 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) 286 | 287 | self.conv1p1s2 = conv( 288 | self.inplanes, 289 | self.inplanes, 290 | kernel_size=space_n_time_m(2, 1), 291 | stride=space_n_time_m(2, 1), 292 | dilation=1, 293 | conv_type=self.NON_BLOCK_CONV_TYPE, 294 | D=D, 295 | ) 296 | self.bn1 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) 297 | self.block1 = self._make_layer( 298 | self.BLOCK, 299 | self.PLANES[0], 300 | self.LAYERS[0], 301 | dilation=dilations[0], 302 | norm_type=self.NORM_TYPE, 303 | bn_momentum=bn_momentum, 304 | ) 305 | 306 | self.conv2p2s2 = conv( 307 | self.inplanes, 308 | self.inplanes, 309 | kernel_size=space_n_time_m(2, 1), 310 | stride=space_n_time_m(2, 1), 311 | dilation=1, 312 | conv_type=self.NON_BLOCK_CONV_TYPE, 313 | D=D, 314 | ) 315 | self.bn2 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) 316 | self.block2 = self._make_layer( 317 | self.BLOCK, 318 | self.PLANES[1], 319 | self.LAYERS[1], 320 | dilation=dilations[1], 321 | norm_type=self.NORM_TYPE, 322 | bn_momentum=bn_momentum, 323 | ) 324 | 325 | self.conv3p4s2 = conv( 326 | self.inplanes, 327 | self.inplanes, 328 | kernel_size=space_n_time_m(2, 1), 329 | stride=space_n_time_m(2, 1), 330 | dilation=1, 331 | conv_type=self.NON_BLOCK_CONV_TYPE, 332 | D=D, 333 | ) 334 | self.bn3 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) 335 | self.block3 = self._make_layer( 336 | self.BLOCK, 337 | self.PLANES[2], 338 | self.LAYERS[2], 339 | dilation=dilations[2], 340 | norm_type=self.NORM_TYPE, 341 | bn_momentum=bn_momentum, 342 | ) 343 | 344 | self.conv4p8s2 = conv( 345 | self.inplanes, 346 | self.inplanes, 347 | kernel_size=space_n_time_m(2, 1), 348 | stride=space_n_time_m(2, 1), 349 | dilation=1, 350 | conv_type=self.NON_BLOCK_CONV_TYPE, 351 | D=D, 352 | ) 353 | self.bn4 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum) 354 | self.block4 = self._make_layer( 355 | self.BLOCK, 356 | self.PLANES[3], 357 | self.LAYERS[3], 358 | dilation=dilations[3], 359 | norm_type=self.NORM_TYPE, 360 | bn_momentum=bn_momentum, 361 | ) 362 | 363 | self.delayer1 = ME.MinkowskiLinear(256, 128, bias=False) 364 | self.delayer2 = ME.MinkowskiLinear(128, 128, bias=False) 365 | self.delayer3 = ME.MinkowskiLinear(64, 128, bias=False) 366 | self.delayer4 = ME.MinkowskiLinear(32, 128, bias=False) 367 | 368 | self.relu = MinkowskiReLU(inplace=True) 369 | 370 | 371 | def forward(self, x): # 372 | y = x.sparse() 373 | out = self.conv0p1s1(y) 374 | out = self.bn0(out) 375 | out_p1 = self.relu(out)###32 376 | 377 | out = 
self.conv1p1s2(out_p1) 378 | out = self.bn1(out) 379 | out = self.relu(out) 380 | out_b1p2 = self.block1(out)###32 381 | 382 | out = self.conv2p2s2(out_b1p2) 383 | out = self.bn2(out) 384 | out = self.relu(out) 385 | out_b2p4 = self.block2(out)###64 386 | 387 | out = self.conv3p4s2(out_b2p4) 388 | out = self.bn3(out) 389 | out = self.relu(out) 390 | out_b3p8 = self.block3(out)###128 391 | 392 | out = self.conv4p8s2(out_b3p8) 393 | out = self.bn4(out) 394 | out = self.relu(out) 395 | out = self.block4(out)###256 396 | 397 | 398 | out = self.delayer1(out) 399 | out = out.interpolate(x) 400 | 401 | dout_b3p8 = self.delayer2(out_b3p8) 402 | dout_b3p8 = dout_b3p8.interpolate(x) 403 | 404 | dout_b2p4 = self.delayer3(out_b2p4) 405 | dout_b2p4 = dout_b2p4.interpolate(x) 406 | 407 | dout_b1p2 = self.delayer4(out_b1p2) 408 | dout_b1p2 = dout_b1p2.interpolate(x) 409 | 410 | if self.mode == 'distill': 411 | output = MinkowskiUnion()(out.sparse(), dout_b3p8.sparse(), dout_b2p4.sparse(), dout_b1p2.sparse()).slice(x) 412 | else: 413 | output = out.F + dout_b3p8.F + dout_b2p4.F + dout_b1p2.F # dim=128 type: TensorField 414 | return output 415 | 416 | 417 | 418 | class Res16FPN14(Res16FPNBase): 419 | BLOCK = BasicBlock 420 | LAYERS = (1, 1, 1, 1, 1, 1, 1, 1) 421 | 422 | 423 | class Res16FPN18(Res16FPNBase): 424 | BLOCK = BasicBlock 425 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 426 | 427 | 428 | class Res16FPN34(Res16FPNBase): 429 | BLOCK = BasicBlock 430 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) 431 | 432 | 433 | class Res16FPN50(Res16FPNBase): 434 | BLOCK = Bottleneck 435 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) 436 | 437 | 438 | class Res16UFPN101(Res16FPNBase): 439 | BLOCK = Bottleneck 440 | LAYERS = (2, 3, 4, 23, 2, 2, 2, 2) 441 | 442 | 443 | class Res16FPN14A(Res16FPN14): 444 | PLANES = (32, 64, 128, 256, 128, 128, 96, 96) 445 | 446 | 447 | class Res16FPN14A2(Res16FPN14A): 448 | LAYERS = (1, 1, 1, 1, 2, 2, 2, 2) 449 | 450 | 451 | class Res16FPN14B(Res16FPN14): 452 | PLANES = (32, 64, 128, 256, 128, 128, 128, 128) 453 | 454 | 455 | class Res16FPN14B2(Res16FPN14B): 456 | LAYERS = (1, 1, 1, 1, 2, 2, 2, 2) 457 | 458 | 459 | class Res16FPN14B3(Res16FPN14B): 460 | LAYERS = (2, 2, 2, 2, 1, 1, 1, 1) 461 | 462 | 463 | class Res16FPN14C(Res16FPN14): 464 | PLANES = (32, 64, 128, 256, 192, 192, 128, 128) 465 | 466 | 467 | class Res16FPN14D(Res16FPN14): 468 | PLANES = (32, 64, 128, 256, 384, 384, 384, 384) 469 | 470 | 471 | class Res16FPN18A(Res16FPN18): 472 | PLANES = (32, 64, 128, 256, 128, 128, 96, 96) 473 | 474 | 475 | class Res16FPN18B(Res16FPN18): 476 | PLANES = (32, 64, 128, 256, 128, 128, 128, 128) 477 | 478 | 479 | class Res16FPN18D(Res16FPN18): 480 | PLANES = (32, 64, 128, 256, 384, 384, 384, 384) 481 | 482 | 483 | class Res16FPN32B(Res16FPN34): 484 | PLANES = (32, 64, 128, 256, 256, 64, 64, 64) 485 | 486 | 487 | class Res16FPN34A(Res16FPN34): 488 | PLANES = (32, 64, 128, 256, 256, 128, 64, 64) 489 | 490 | 491 | class Res16FPN34B(Res16FPN34): 492 | PLANES = (32, 64, 128, 256, 256, 128, 64, 32) 493 | 494 | 495 | class Res16FPN34C(Res16FPN34): 496 | PLANES = (32, 64, 128, 256, 256, 128, 96, 96) 497 | 498 | 499 | 500 | def get_block(norm_type, inplanes, planes, stride=1, dilation=1, downsample=None, bn_momentum=0.1, D=3): 501 | if norm_type == NormType.BATCH_NORM: 502 | return BasicBlock( 503 | inplanes=inplanes, 504 | planes=planes, 505 | stride=stride, 506 | dilation=dilation, 507 | downsample=downsample, 508 | bn_momentum=bn_momentum, 509 | D=D, 510 | ) 511 | elif norm_type == NormType.INSTANCE_NORM: 512 | return 
BasicBlockIN(inplanes=inplanes, planes=planes, stride=stride, dilation=dilation, downsample=downsample, bn_momentum=bn_momentum, D=D) # keyword arguments: a positional call would bind bn_momentum to conv_type and D to bn_momentum 513 | else: 514 | raise ValueError(f"Type {norm_type}, not defined") 515 | -------------------------------------------------------------------------------- /pointdc_mk/models/modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import MinkowskiEngine as ME 3 | from .common import ConvType, NormType 4 | 5 | 6 | class BasicBlock(nn.Module): 7 | """This module implements a basic residual convolution block using MinkowskiEngine 8 | 9 | Parameters 10 | ---------- 11 | inplanes: int 12 | Input dimension 13 | planes: int 14 | Output dimension 15 | dilation: int 16 | Dilation value 17 | downsample: nn.Module 18 | If provided, downsample will be applied on input before doing residual addition 19 | bn_momentum: float 20 | Momentum of the batch-normalization layers 21 | """ 22 | 23 | EXPANSION = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, bn_momentum=0.1, dimension=-1): 26 | super(BasicBlock, self).__init__() 27 | assert dimension > 0 28 | 29 | self.conv1 = ME.MinkowskiConvolution( 30 | inplanes, planes, kernel_size=3, stride=stride, dilation=dilation, dimension=dimension 31 | ) 32 | self.norm1 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum) 33 | self.conv2 = ME.MinkowskiConvolution( 34 | planes, planes, kernel_size=3, stride=1, dilation=dilation, dimension=dimension 35 | ) 36 | self.norm2 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum) 37 | self.relu = ME.MinkowskiReLU(inplace=True) 38 | self.downsample = downsample 39 | 40 | def forward(self, x): 41 | residual = x 42 | 43 | out = self.conv1(x) 44 | out = self.norm1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.norm2(out) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(x) 52 | 53 | out += residual 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | 59 | class Bottleneck(nn.Module): 60 | EXPANSION = 4 61 | 62 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, bn_momentum=0.1, dimension=-1): 63 | super(Bottleneck, self).__init__() 64 | assert dimension > 0 65 | 66 | self.conv1 = ME.MinkowskiConvolution(inplanes, planes, kernel_size=1, dimension=dimension) 67 | self.norm1 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum) 68 | 69 | self.conv2 = ME.MinkowskiConvolution( 70 | planes, planes, kernel_size=3, stride=stride, dilation=dilation, dimension=dimension 71 | ) 72 | self.norm2 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum) 73 | 74 | self.conv3 = ME.MinkowskiConvolution(planes, planes * self.EXPANSION, kernel_size=1, dimension=dimension) 75 | self.norm3 = ME.MinkowskiBatchNorm(planes * self.EXPANSION, momentum=bn_momentum) 76 | 77 | self.relu = ME.MinkowskiReLU(inplace=True) 78 | self.downsample = downsample 79 | 80 | def forward(self, x): 81 | residual = x 82 | 83 | out = self.conv1(x) 84 | out = self.norm1(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv2(out) 88 | out = self.norm2(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv3(out) 92 | out = self.norm3(out) 93 | 94 | if self.downsample is not None: 95 | residual = self.downsample(x) 96 | 97 | out += residual 98 | out = self.relu(out) 99 | 100 | return out 101 | 102 | 103 | class BaseResBlock(nn.Module): 104 | def __init__( 105 | self, 106 | feat_in, 107 | feat_mid, 108 | feat_out, 109 | kernel_sizes=[], 110 | strides=[], 111 | dilations=[], 112 | has_biases=[], 113 | kernel_generators=[], 114 |
kernel_size=3, 115 | stride=1, 116 | dilation=1, 117 | bias=False, 118 | kernel_generator=None, 119 | norm_layer=ME.MinkowskiBatchNorm, 120 | activation=ME.MinkowskiReLU, 121 | bn_momentum=0.1, 122 | dimension=-1, 123 | **kwargs 124 | ): 125 | 126 | super(BaseResBlock, self).__init__() 127 | assert dimension > 0 128 | 129 | modules = [] 130 | 131 | convolutions_dim = [[feat_in, feat_mid], [feat_mid, feat_mid], [feat_mid, feat_out]] 132 | 133 | kernel_sizes = self.create_arguments_list(kernel_sizes, kernel_size) 134 | strides = self.create_arguments_list(strides, stride) 135 | dilations = self.create_arguments_list(dilations, dilation) 136 | has_biases = self.create_arguments_list(has_biases, bias) 137 | kernel_generators = self.create_arguments_list(kernel_generators, kernel_generator) 138 | 139 | for conv_dim, kernel_size, stride, dilation, has_bias, kernel_generator in zip( 140 | convolutions_dim, kernel_sizes, strides, dilations, has_biases, kernel_generators 141 | ): 142 | 143 | modules.append( 144 | ME.MinkowskiConvolution( 145 | conv_dim[0], 146 | conv_dim[1], 147 | kernel_size=kernel_size, 148 | stride=stride, 149 | dilation=dilation, 150 | bias=has_bias, 151 | kernel_generator=kernel_generator, 152 | dimension=dimension, 153 | ) 154 | ) 155 | 156 | if norm_layer: 157 | modules.append(norm_layer(conv_dim[1], momentum=bn_momentum)) 158 | 159 | if activation: 160 | modules.append(activation(inplace=True)) 161 | 162 | self.conv = nn.Sequential(*modules) 163 | 164 | @staticmethod 165 | def create_arguments_list(arg_list, arg): 166 | if len(arg_list) == 3: 167 | return arg_list 168 | return [arg for _ in range(3)] 169 | 170 | def forward(self, x): 171 | return x, self.conv(x) 172 | 173 | 174 | class ResnetBlockDown(BaseResBlock): 175 | def __init__( 176 | self, 177 | down_conv_nn=[], 178 | kernel_sizes=[], 179 | strides=[], 180 | dilations=[], 181 | kernel_size=3, 182 | stride=1, 183 | dilation=1, 184 | norm_layer=ME.MinkowskiBatchNorm, 185 | activation=ME.MinkowskiReLU, 186 | bn_momentum=0.1, 187 | dimension=-1, 188 | down_stride=2, 189 | **kwargs 190 | ): 191 | 192 | super(ResnetBlockDown, self).__init__( 193 | down_conv_nn[0], 194 | down_conv_nn[1], 195 | down_conv_nn[2], 196 | kernel_sizes=kernel_sizes, 197 | strides=strides, 198 | dilations=dilations, 199 | kernel_size=kernel_size, 200 | stride=stride, 201 | dilation=dilation, 202 | norm_layer=norm_layer, 203 | activation=activation, 204 | bn_momentum=bn_momentum, 205 | dimension=dimension, 206 | ) 207 | 208 | self.downsample = nn.Sequential( 209 | ME.MinkowskiConvolution( 210 | down_conv_nn[0], down_conv_nn[2], kernel_size=2, stride=down_stride, dimension=dimension 211 | ), 212 | ME.MinkowskiBatchNorm(down_conv_nn[2]), 213 | ) 214 | 215 | def forward(self, x): 216 | 217 | residual, x = super().forward(x) 218 | 219 | return self.downsample(residual) + x 220 | 221 | 222 | class ResnetBlockUp(BaseResBlock): 223 | def __init__( 224 | self, 225 | up_conv_nn=[], 226 | kernel_sizes=[], 227 | strides=[], 228 | dilations=[], 229 | kernel_size=3, 230 | stride=1, 231 | dilation=1, 232 | norm_layer=ME.MinkowskiBatchNorm, 233 | activation=ME.MinkowskiReLU, 234 | bn_momentum=0.1, 235 | dimension=-1, 236 | up_stride=2, 237 | skip=True, 238 | **kwargs 239 | ): 240 | 241 | self.skip = skip 242 | 243 | super(ResnetBlockUp, self).__init__( 244 | up_conv_nn[0], 245 | up_conv_nn[1], 246 | up_conv_nn[2], 247 | kernel_sizes=kernel_sizes, 248 | strides=strides, 249 | dilations=dilations, 250 | kernel_size=kernel_size, 251 | stride=stride, 252 | 
dilation=dilation, 253 | norm_layer=norm_layer, 254 | activation=activation, 255 | bn_momentum=bn_momentum, 256 | dimension=dimension, 257 | ) 258 | 259 | self.upsample = ME.MinkowskiConvolutionTranspose( 260 | up_conv_nn[0], up_conv_nn[2], kernel_size=2, stride=up_stride, dimension=dimension 261 | ) 262 | 263 | def forward(self, x, x_skip): 264 | residual, x = super().forward(x) 265 | 266 | x = self.upsample(residual) + x 267 | 268 | if self.skip: 269 | return ME.cat(x, x_skip) 270 | else: 271 | return x 272 | 273 | 274 | class SELayer(nn.Module): 275 | def __init__(self, channel, reduction=16, D=-1): 276 | # Global coords does not require coords_key 277 | super(SELayer, self).__init__() 278 | self.fc = nn.Sequential( 279 | ME.MinkowskiLinear(channel, channel // reduction), 280 | ME.MinkowskiReLU(inplace=True), 281 | ME.MinkowskiLinear(channel // reduction, channel), 282 | ME.MinkowskiSigmoid(), 283 | ) 284 | self.pooling = ME.MinkowskiGlobalPooling(dimension=D) 285 | self.broadcast_mul = ME.MinkowskiBroadcastMultiplication(dimension=D) 286 | 287 | def forward(self, x): 288 | y = self.pooling(x) 289 | y = self.fc(y) 290 | return self.broadcast_mul(x, y) 291 | 292 | 293 | class SEBasicBlock(BasicBlock): 294 | def __init__( 295 | self, inplanes, planes, stride=1, dilation=1, downsample=None, conv_type=ConvType.HYPERCUBE, reduction=16, D=-1 296 | ): 297 | super(SEBasicBlock, self).__init__( 298 | inplanes, planes, stride=stride, dilation=dilation, downsample=downsample, dimension=D # BasicBlock takes `dimension`; conv_type is accepted for API compatibility but unused 299 | ) 300 | self.se = SELayer(planes, reduction=reduction, D=D) 301 | 302 | def forward(self, x): 303 | residual = x 304 | 305 | out = self.conv1(x) 306 | out = self.norm1(out) 307 | out = self.relu(out) 308 | 309 | out = self.conv2(out) 310 | out = self.norm2(out) 311 | out = self.se(out) 312 | 313 | if self.downsample is not None: 314 | residual = self.downsample(x) 315 | 316 | out += residual 317 | out = self.relu(out) 318 | 319 | return out 320 | 321 | 322 | class SEBasicBlockBN(SEBasicBlock): 323 | NORM_TYPE = NormType.BATCH_NORM 324 | 325 | 326 | class SEBasicBlockIN(SEBasicBlock): 327 | NORM_TYPE = NormType.INSTANCE_NORM 328 | 329 | 330 | class SEBasicBlockIBN(SEBasicBlock): 331 | NORM_TYPE = NormType.INSTANCE_BATCH_NORM 332 | 333 | 334 | class SEBottleneck(Bottleneck): 335 | def __init__( 336 | self, inplanes, planes, stride=1, dilation=1, downsample=None, conv_type=ConvType.HYPERCUBE, D=3, reduction=16 337 | ): 338 | super(SEBottleneck, self).__init__( 339 | inplanes, planes, stride=stride, dilation=dilation, downsample=downsample, dimension=D # Bottleneck takes `dimension`; conv_type is accepted for API compatibility but unused 340 | ) 341 | self.se = SELayer(planes * self.EXPANSION, reduction=reduction, D=D) 342 | 343 | def forward(self, x): 344 | residual = x 345 | 346 | out = self.conv1(x) 347 | out = self.norm1(out) 348 | out = self.relu(out) 349 | 350 | out = self.conv2(out) 351 | out = self.norm2(out) 352 | out = self.relu(out) 353 | 354 | out = self.conv3(out) 355 | out = self.norm3(out) 356 | out = self.se(out) 357 | 358 | if self.downsample is not None: 359 | residual = self.downsample(x) 360 | 361 | out += residual 362 | out = self.relu(out) 363 | 364 | return out 365 | 366 | 367 | class SEBottleneckBN(SEBottleneck): 368 | NORM_TYPE = NormType.BATCH_NORM 369 | 370 | 371 | class SEBottleneckIN(SEBottleneck): 372 | NORM_TYPE = NormType.INSTANCE_NORM 373 | 374 | 375 | class SEBottleneckIBN(SEBottleneck): 376 | NORM_TYPE = NormType.INSTANCE_BATCH_NORM 377 | --------------------------------------------------------------------------------
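A minimal sketch of how the residual blocks in models/modules.py above can be exercised, assuming MinkowskiEngine >= 0.5 is installed and the file is importable as `models.modules` from the repository root (both assumptions, not verified here); the coordinates and channel sizes are illustrative only.

```python
import torch
import MinkowskiEngine as ME
from models.modules import BasicBlock  # assumed import path

# A tiny sparse tensor: 3 voxels, 16 feature channels.
# Coordinate rows are (batch_index, x, y, z).
coords = torch.IntTensor([[0, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1]])
feats = torch.rand(3, 16)
x = ME.SparseTensor(features=feats, coordinates=coords)

# inplanes == planes and stride == 1, so the residual path is the identity.
block = BasicBlock(inplanes=16, planes=16, dimension=3)
block.eval()  # use running BN statistics rather than 3-sample batch statistics
y = block(x)
print(y.F.shape)  # -> torch.Size([3, 16])
```

With `inplanes != planes`, the caller is expected to pass a `downsample` module (as `get_block` in fpn.py does) so the residual addition stays shape-consistent.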
/pointdc_mk/models/networks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | import MinkowskiEngine as ME 4 | from .modules import BasicBlock, Bottleneck 5 | 6 | 7 | class ResNetBase(nn.Module): 8 | BLOCK = None 9 | LAYERS = () 10 | INIT_DIM = 64 11 | PLANES = (64, 128, 256, 512) 12 | 13 | def __init__(self, in_channels, out_channels, D=3, **kwargs): 14 | nn.Module.__init__(self) 15 | self.D = D 16 | assert self.BLOCK is not None, "BLOCK is not defined" 17 | assert self.PLANES is not None, "PLANES is not defined" 18 | self.network_initialization(in_channels, out_channels, D) 19 | self.weight_initialization() 20 | 21 | def network_initialization(self, in_channels, out_channels, D): 22 | 23 | self.inplanes = self.INIT_DIM 24 | self.conv1 = ME.MinkowskiConvolution(in_channels, self.inplanes, kernel_size=5, stride=2, dimension=D) 25 | 26 | self.bn1 = ME.MinkowskiBatchNorm(self.inplanes) 27 | self.relu = ME.MinkowskiReLU(inplace=True) 28 | 29 | self.pool = ME.MinkowskiAvgPooling(kernel_size=2, stride=2, dimension=D) 30 | 31 | self.layer1 = self._make_layer(self.BLOCK, self.PLANES[0], self.LAYERS[0], stride=2) 32 | self.layer2 = self._make_layer(self.BLOCK, self.PLANES[1], self.LAYERS[1], stride=2) 33 | self.layer3 = self._make_layer(self.BLOCK, self.PLANES[2], self.LAYERS[2], stride=2) 34 | self.layer4 = self._make_layer(self.BLOCK, self.PLANES[3], self.LAYERS[3], stride=2) 35 | 36 | self.conv5 = ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=3, stride=3, dimension=D) 37 | self.bn5 = ME.MinkowskiBatchNorm(self.inplanes) 38 | 39 | self.glob_avg = ME.MinkowskiGlobalMaxPooling(dimension=D) 40 | 41 | self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True) 42 | 43 | def weight_initialization(self): 44 | for m in self.modules(): 45 | if isinstance(m, ME.MinkowskiConvolution): 46 | ME.utils.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu") 47 | 48 | if isinstance(m, ME.MinkowskiBatchNorm): 49 | nn.init.constant_(m.bn.weight, 1) 50 | nn.init.constant_(m.bn.bias, 0) 51 | 52 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, bn_momentum=0.1): 53 | downsample = None 54 | if stride != 1 or self.inplanes != planes * block.EXPANSION: 55 | downsample = nn.Sequential( 56 | ME.MinkowskiConvolution( 57 | self.inplanes, planes * block.EXPANSION, kernel_size=1, stride=stride, dimension=self.D 58 | ), 59 | ME.MinkowskiBatchNorm(planes * block.EXPANSION), 60 | ) 61 | layers = [] 62 | layers.append( 63 | block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample, dimension=self.D) 64 | ) 65 | self.inplanes = planes * block.EXPANSION 66 | for i in range(1, blocks): 67 | layers.append(block(self.inplanes, planes, stride=1, dilation=dilation, dimension=self.D)) 68 | 69 | return nn.Sequential(*layers) 70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = self.bn1(x) 74 | x = self.relu(x) 75 | x = self.pool(x) 76 | 77 | x = self.layer1(x) 78 | x = self.layer2(x) 79 | x = self.layer3(x) 80 | x = self.layer4(x) 81 | 82 | x = self.conv5(x) 83 | x = self.bn5(x) 84 | x = self.relu(x) 85 | 86 | x = self.glob_avg(x) 87 | return self.final(x) 88 | 89 | 90 | class ResNet14(ResNetBase): 91 | BLOCK = BasicBlock 92 | LAYERS = (1, 1, 1, 1) 93 | 94 | 95 | class ResNet18(ResNetBase): 96 | BLOCK = BasicBlock 97 | LAYERS = (2, 2, 2, 2) 98 | 99 | 100 | class ResNet34(ResNetBase): 101 | BLOCK = BasicBlock 102 | LAYERS = (3, 4, 6, 3) 103 | 104 | 105 | class 
ResNet50(ResNetBase): 106 | BLOCK = Bottleneck 107 | LAYERS = (3, 4, 6, 3) 108 | 109 | 110 | class ResNet101(ResNetBase): 111 | BLOCK = Bottleneck 112 | LAYERS = (3, 4, 23, 3) 113 | 114 | 115 | class MinkUNetBase(ResNetBase): 116 | BLOCK = None 117 | PLANES = None 118 | DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1) 119 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 120 | INIT_DIM = 32 121 | OUT_TENSOR_STRIDE = 1 122 | 123 | # To use the model, must call initialize_coords before forward pass. 124 | # Once data is processed, call clear to reset the model before calling 125 | # initialize_coords 126 | def __init__(self, in_channels, out_channels, D=3, **kwargs): 127 | ResNetBase.__init__(self, in_channels, out_channels, D) 128 | 129 | def network_initialization(self, in_channels, out_channels, D): 130 | # Output of the first conv concated to conv6 131 | self.inplanes = self.INIT_DIM 132 | self.conv0p1s1 = ME.MinkowskiConvolution(in_channels, self.inplanes, kernel_size=5, dimension=D) 133 | 134 | self.bn0 = ME.MinkowskiBatchNorm(self.inplanes) 135 | 136 | self.conv1p1s2 = ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 137 | self.bn1 = ME.MinkowskiBatchNorm(self.inplanes) 138 | 139 | self.block1 = self._make_layer(self.BLOCK, self.PLANES[0], self.LAYERS[0]) 140 | 141 | self.conv2p2s2 = ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 142 | self.bn2 = ME.MinkowskiBatchNorm(self.inplanes) 143 | 144 | self.block2 = self._make_layer(self.BLOCK, self.PLANES[1], self.LAYERS[1]) 145 | 146 | self.conv3p4s2 = ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 147 | 148 | self.bn3 = ME.MinkowskiBatchNorm(self.inplanes) 149 | self.block3 = self._make_layer(self.BLOCK, self.PLANES[2], self.LAYERS[2]) 150 | 151 | self.conv4p8s2 = ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 152 | self.bn4 = ME.MinkowskiBatchNorm(self.inplanes) 153 | self.block4 = self._make_layer(self.BLOCK, self.PLANES[3], self.LAYERS[3]) 154 | 155 | self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose( 156 | self.inplanes, self.PLANES[4], kernel_size=2, stride=2, dimension=D 157 | ) 158 | self.bntr4 = ME.MinkowskiBatchNorm(self.PLANES[4]) 159 | 160 | self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.EXPANSION 161 | self.block5 = self._make_layer(self.BLOCK, self.PLANES[4], self.LAYERS[4]) 162 | self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose( 163 | self.inplanes, self.PLANES[5], kernel_size=2, stride=2, dimension=D 164 | ) 165 | self.bntr5 = ME.MinkowskiBatchNorm(self.PLANES[5]) 166 | 167 | self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.EXPANSION 168 | self.block6 = self._make_layer(self.BLOCK, self.PLANES[5], self.LAYERS[5]) 169 | self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose( 170 | self.inplanes, self.PLANES[6], kernel_size=2, stride=2, dimension=D 171 | ) 172 | self.bntr6 = ME.MinkowskiBatchNorm(self.PLANES[6]) 173 | 174 | self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.EXPANSION 175 | self.block7 = self._make_layer(self.BLOCK, self.PLANES[6], self.LAYERS[6]) 176 | self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose( 177 | self.inplanes, self.PLANES[7], kernel_size=2, stride=2, dimension=D 178 | ) 179 | self.bntr7 = ME.MinkowskiBatchNorm(self.PLANES[7]) 180 | 181 | self.inplanes = self.PLANES[7] + self.INIT_DIM 182 | self.block8 = self._make_layer(self.BLOCK, self.PLANES[7], self.LAYERS[7]) 183 | 184 | self.final = 
ME.MinkowskiConvolution(self.PLANES[7], out_channels, kernel_size=1, bias=True, dimension=D) 185 | self.relu = ME.MinkowskiReLU(inplace=True) 186 | 187 | def forward(self, x): 188 | out = self.conv0p1s1(x) 189 | out = self.bn0(out) 190 | out_p1 = self.relu(out) 191 | 192 | out = self.conv1p1s2(out_p1) 193 | out = self.bn1(out) 194 | out = self.relu(out) 195 | out_b1p2 = self.block1(out) 196 | 197 | out = self.conv2p2s2(out_b1p2) 198 | out = self.bn2(out) 199 | out = self.relu(out) 200 | out_b2p4 = self.block2(out) 201 | 202 | out = self.conv3p4s2(out_b2p4) 203 | out = self.bn3(out) 204 | out = self.relu(out) 205 | out_b3p8 = self.block3(out) 206 | 207 | # tensor_stride=16 208 | out = self.conv4p8s2(out_b3p8) 209 | out = self.bn4(out) 210 | out = self.relu(out) 211 | out = self.block4(out) 212 | 213 | # tensor_stride=8 214 | out = self.convtr4p16s2(out) 215 | out = self.bntr4(out) 216 | out = self.relu(out) 217 | 218 | out = ME.cat(out, out_b3p8) 219 | out = self.block5(out) 220 | 221 | # tensor_stride=4 222 | out = self.convtr5p8s2(out) 223 | out = self.bntr5(out) 224 | out = self.relu(out) 225 | 226 | out = ME.cat(out, out_b2p4) 227 | out = self.block6(out) 228 | 229 | # tensor_stride=2 230 | out = self.convtr6p4s2(out) 231 | out = self.bntr6(out) 232 | out = self.relu(out) 233 | 234 | out = ME.cat(out, out_b1p2) 235 | out = self.block7(out) 236 | 237 | # tensor_stride=1 238 | out = self.convtr7p2s2(out) 239 | out = self.bntr7(out) 240 | out = self.relu(out) 241 | 242 | out = ME.cat(out, out_p1) 243 | out = self.block8(out) 244 | 245 | return self.final(out) 246 | 247 | 248 | class MinkUNet14(MinkUNetBase): 249 | BLOCK = BasicBlock 250 | LAYERS = (1, 1, 1, 1, 1, 1, 1, 1) 251 | 252 | 253 | class MinkUNet18(MinkUNetBase): 254 | BLOCK = BasicBlock 255 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 256 | 257 | 258 | class MinkUNet34(MinkUNetBase): 259 | BLOCK = BasicBlock 260 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) 261 | 262 | 263 | class MinkUNet50(MinkUNetBase): 264 | BLOCK = Bottleneck 265 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) 266 | 267 | 268 | class MinkUNet101(MinkUNetBase): 269 | BLOCK = Bottleneck 270 | LAYERS = (2, 3, 4, 23, 2, 2, 2, 2) 271 | 272 | 273 | class MinkUNet14A(MinkUNet14): 274 | PLANES = (32, 64, 128, 256, 128, 128, 96, 96) 275 | 276 | 277 | class MinkUNet14B(MinkUNet14): 278 | PLANES = (32, 64, 128, 256, 128, 128, 128, 128) 279 | 280 | 281 | class MinkUNet14C(MinkUNet14): 282 | PLANES = (32, 64, 128, 256, 192, 192, 128, 128) 283 | 284 | 285 | class MinkUNet14D(MinkUNet14): 286 | PLANES = (32, 64, 128, 256, 384, 384, 384, 384) 287 | 288 | 289 | class MinkUNet18A(MinkUNet18): 290 | PLANES = (32, 64, 128, 256, 128, 128, 96, 96) 291 | 292 | 293 | class MinkUNet18B(MinkUNet18): 294 | PLANES = (32, 64, 128, 256, 128, 128, 128, 128) 295 | 296 | 297 | class MinkUNet18D(MinkUNet18): 298 | PLANES = (32, 64, 128, 256, 384, 384, 384, 384) 299 | 300 | 301 | class MinkUNet34A(MinkUNet34): 302 | PLANES = (32, 64, 128, 256, 256, 128, 64, 64) 303 | 304 | 305 | class MinkUNet34B(MinkUNet34): 306 | PLANES = (32, 64, 128, 256, 256, 128, 64, 32) 307 | 308 | 309 | class MinkUNet34C(MinkUNet34): 310 | PLANES = (32, 64, 128, 256, 256, 128, 96, 96) 311 | -------------------------------------------------------------------------------- /pointdc_mk/models/pretrain_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import MinkowskiEngine as ME 4 | 5 | from MinkowskiEngine import MinkowskiNetwork 6 | from MinkowskiEngine import 
MinkowskiReLU 7 | import torch.nn.functional as F 8 | 9 | from os.path import dirname, abspath 10 | import sys 11 | sys.path.append(dirname(abspath(__file__))) 12 | from common import ConvType, NormType, conv, conv_tr, get_norm, sum_pool 13 | 14 | class SegHead(nn.Module): 15 | def __init__(self, in_channels=128, out_channels=20): 16 | super(SegHead, self).__init__() 17 | self.cluster = torch.nn.parameter.Parameter(data=torch.randn(out_channels, in_channels), requires_grad=True) 18 | 19 | def forward(self, feats): 20 | normed_clusters = F.normalize(self.cluster, dim=1) 21 | normed_features = F.normalize(feats, dim=1) 22 | logits = F.linear(normed_features, normed_clusters) 23 | 24 | return logits 25 | 26 | class alignlayer(MinkowskiNetwork): 27 | def __init__(self, in_channels=128, out_channels=70, bn_momentum=0.02, norm_layer=True, D=3): 28 | super(alignlayer, self).__init__(D) 29 | def space_n_time_m(n, m): 30 | return n if D == 3 else [n, n, n, m] 31 | 32 | self.conv_align = conv( 33 | in_channels, out_channels, kernel_size=space_n_time_m(1, 1), stride=1, D=D 34 | ) 35 | 36 | def forward(self, feats_tensorfield): 37 | x = feats_tensorfield.sparse() 38 | aligned_out = self.conv_align(x).slice(feats_tensorfield).F # 39 | 40 | return aligned_out 41 | 42 | class SubModel(MinkowskiNetwork): 43 | def __init__(self, args, D=3): 44 | super(SubModel, self).__init__(D) 45 | self.args = args 46 | self.distill_layer = alignlayer(in_channels=args.feats_dim) 47 | 48 | def forward(self, feats_tensorfield): 49 | feats_aligned_nonorm = self.distill_layer(feats_tensorfield) 50 | 51 | return feats_aligned_nonorm -------------------------------------------------------------------------------- /pointdc_mk/models/resunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import MinkowskiEngine as ME 3 | import MinkowskiEngine.MinkowskiFunctional as MEF 4 | from .common import get_norm 5 | 6 | from .res16unet import get_block 7 | from .common import NormType 8 | 9 | 10 | class ResUNet2(ME.MinkowskiNetwork): 11 | NORM_TYPE = None 12 | BLOCK_NORM_TYPE = NormType.BATCH_NORM 13 | CHANNELS = [None, 32, 64, 128, 256] 14 | TR_CHANNELS = [None, 32, 64, 64, 128] 15 | 16 | # To use the model, must call initialize_coords before forward pass. 
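    # ResUNet2 below is a four-level sparse encoder-decoder joined by ME.cat skip
    # connections; when normalize_feature=True, forward() returns a SparseTensor
    # whose feature rows are L2-normalized.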
17 | # Once data is processed, call clear to reset the model before calling initialize_coords 18 | def __init__( 19 | self, in_channels=3, out_channels=32, bn_momentum=0.01, normalize_feature=True, conv1_kernel_size=5, D=3 20 | ): 21 | ME.MinkowskiNetwork.__init__(self, D) 22 | NORM_TYPE = self.NORM_TYPE 23 | BLOCK_NORM_TYPE = self.BLOCK_NORM_TYPE 24 | CHANNELS = self.CHANNELS 25 | TR_CHANNELS = self.TR_CHANNELS 26 | # print(D, in_channels, out_channels, conv1_kernel_size) 27 | self.normalize_feature = normalize_feature 28 | self.conv1 = ME.MinkowskiConvolution( 29 | in_channels=in_channels, 30 | out_channels=CHANNELS[1], 31 | kernel_size=conv1_kernel_size, 32 | stride=1, 33 | dilation=1, 34 | bias=False, 35 | dimension=D, 36 | ) 37 | self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, D=D) 38 | 39 | self.block1 = get_block(BLOCK_NORM_TYPE, CHANNELS[1], CHANNELS[1], bn_momentum=bn_momentum, D=D) 40 | 41 | self.conv2 = ME.MinkowskiConvolution( 42 | in_channels=CHANNELS[1], 43 | out_channels=CHANNELS[2], 44 | kernel_size=3, 45 | stride=2, 46 | dilation=1, 47 | bias=False, 48 | dimension=D, 49 | ) 50 | self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, D=D) 51 | 52 | self.block2 = get_block(BLOCK_NORM_TYPE, CHANNELS[2], CHANNELS[2], bn_momentum=bn_momentum, D=D) 53 | 54 | self.conv3 = ME.MinkowskiConvolution( 55 | in_channels=CHANNELS[2], 56 | out_channels=CHANNELS[3], 57 | kernel_size=3, 58 | stride=2, 59 | dilation=1, 60 | bias=False, 61 | dimension=D, 62 | ) 63 | self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, D=D) 64 | 65 | self.block3 = get_block(BLOCK_NORM_TYPE, CHANNELS[3], CHANNELS[3], bn_momentum=bn_momentum, D=D) 66 | 67 | self.conv4 = ME.MinkowskiConvolution( 68 | in_channels=CHANNELS[3], 69 | out_channels=CHANNELS[4], 70 | kernel_size=3, 71 | stride=2, 72 | dilation=1, 73 | bias=False, 74 | dimension=D, 75 | ) 76 | self.norm4 = get_norm(NORM_TYPE, CHANNELS[4], bn_momentum=bn_momentum, D=D) 77 | 78 | self.block4 = get_block(BLOCK_NORM_TYPE, CHANNELS[4], CHANNELS[4], bn_momentum=bn_momentum, D=D) 79 | 80 | self.conv4_tr = ME.MinkowskiConvolutionTranspose( 81 | in_channels=CHANNELS[4], 82 | out_channels=TR_CHANNELS[4], 83 | kernel_size=3, 84 | stride=2, 85 | dilation=1, 86 | bias=False, 87 | dimension=D, 88 | ) 89 | self.norm4_tr = get_norm(NORM_TYPE, TR_CHANNELS[4], bn_momentum=bn_momentum, D=D) 90 | 91 | self.block4_tr = get_block(BLOCK_NORM_TYPE, TR_CHANNELS[4], TR_CHANNELS[4], bn_momentum=bn_momentum, D=D) 92 | 93 | self.conv3_tr = ME.MinkowskiConvolutionTranspose( 94 | in_channels=CHANNELS[3] + TR_CHANNELS[4], 95 | out_channels=TR_CHANNELS[3], 96 | kernel_size=3, 97 | stride=2, 98 | dilation=1, 99 | bias=False, 100 | dimension=D, 101 | ) 102 | self.norm3_tr = get_norm(NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, D=D) 103 | 104 | self.block3_tr = get_block(BLOCK_NORM_TYPE, TR_CHANNELS[3], TR_CHANNELS[3], bn_momentum=bn_momentum, D=D) 105 | 106 | self.conv2_tr = ME.MinkowskiConvolutionTranspose( 107 | in_channels=CHANNELS[2] + TR_CHANNELS[3], 108 | out_channels=TR_CHANNELS[2], 109 | kernel_size=3, 110 | stride=2, 111 | dilation=1, 112 | bias=False, 113 | dimension=D, 114 | ) 115 | self.norm2_tr = get_norm(NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, D=D) 116 | 117 | self.block2_tr = get_block(BLOCK_NORM_TYPE, TR_CHANNELS[2], TR_CHANNELS[2], bn_momentum=bn_momentum, D=D) 118 | 119 | self.conv1_tr = ME.MinkowskiConvolution( 120 | in_channels=CHANNELS[1] + TR_CHANNELS[2], 121 | out_channels=TR_CHANNELS[1], 
122 | kernel_size=1, 123 | stride=1, 124 | dilation=1, 125 | bias=False, 126 | dimension=D, 127 | ) 128 | 129 | # self.block1_tr = BasicBlockBN(TR_CHANNELS[1], TR_CHANNELS[1], bn_momentum=bn_momentum, D=D) 130 | 131 | self.final = ME.MinkowskiConvolution( 132 | in_channels=TR_CHANNELS[1], 133 | out_channels=out_channels, 134 | kernel_size=1, 135 | stride=1, 136 | dilation=1, 137 | bias=True, 138 | dimension=D, 139 | ) 140 | 141 | def forward(self, x): 142 | out_s1 = self.conv1(x) 143 | out_s1 = self.norm1(out_s1) 144 | out_s1 = self.block1(out_s1) 145 | out = MEF.relu(out_s1) 146 | 147 | out_s2 = self.conv2(out) 148 | out_s2 = self.norm2(out_s2) 149 | out_s2 = self.block2(out_s2) 150 | out = MEF.relu(out_s2) 151 | 152 | out_s4 = self.conv3(out) 153 | out_s4 = self.norm3(out_s4) 154 | out_s4 = self.block3(out_s4) 155 | out = MEF.relu(out_s4) 156 | 157 | out_s8 = self.conv4(out) 158 | out_s8 = self.norm4(out_s8) 159 | out_s8 = self.block4(out_s8) 160 | out = MEF.relu(out_s8) 161 | 162 | out = self.conv4_tr(out) 163 | out = self.norm4_tr(out) 164 | out = self.block4_tr(out) 165 | out_s4_tr = MEF.relu(out) 166 | 167 | out = ME.cat(out_s4_tr, out_s4) 168 | 169 | out = self.conv3_tr(out) 170 | out = self.norm3_tr(out) 171 | out = self.block3_tr(out) 172 | out_s2_tr = MEF.relu(out) 173 | 174 | out = ME.cat(out_s2_tr, out_s2) 175 | 176 | out = self.conv2_tr(out) 177 | out = self.norm2_tr(out) 178 | out = self.block2_tr(out) 179 | out_s1_tr = MEF.relu(out) 180 | 181 | out = ME.cat(out_s1_tr, out_s1) 182 | out = self.conv1_tr(out) 183 | out = MEF.relu(out) 184 | out = self.final(out) 185 | 186 | if self.normalize_feature: 187 | return ME.SparseTensor( 188 | out.F / torch.norm(out.F, p=2, dim=1, keepdim=True), 189 | coordinate_map_key=out.coordinate_map_key, 190 | coordinate_manager=out.coordinate_manager, 191 | ) 192 | else: 193 | return out 194 | 195 | 196 | class ResUNetBN2(ResUNet2): 197 | NORM_TYPE = NormType.BATCH_NORM 198 | 199 | 200 | class ResUNetBN2B(ResUNet2): 201 | NORM_TYPE = NormType.BATCH_NORM 202 | CHANNELS = [None, 32, 64, 128, 256] 203 | TR_CHANNELS = [None, 64, 64, 64, 64] 204 | 205 | 206 | class ResUNetBN2C(ResUNet2): 207 | NORM_TYPE = NormType.BATCH_NORM 208 | CHANNELS = [None, 32, 64, 128, 256] 209 | TR_CHANNELS = [None, 64, 64, 64, 128] 210 | 211 | 212 | class ResUNetBN2D(ResUNet2): 213 | NORM_TYPE = NormType.BATCH_NORM 214 | CHANNELS = [None, 32, 64, 128, 256] 215 | TR_CHANNELS = [None, 64, 64, 128, 128] 216 | 217 | 218 | class ResUNetBN2E(ResUNet2): 219 | NORM_TYPE = NormType.BATCH_NORM 220 | CHANNELS = [None, 128, 128, 128, 256] 221 | TR_CHANNELS = [None, 64, 128, 128, 128] 222 | 223 | 224 | class Res2BlockDown(ME.MinkowskiNetwork): 225 | 226 | """ 227 | block for unwrapped Resnet 228 | """ 229 | 230 | def __init__( 231 | self, 232 | down_conv_nn, 233 | kernel_size, 234 | stride, 235 | dilation, 236 | dimension=3, 237 | bn_momentum=0.01, 238 | norm_type=NormType.BATCH_NORM, 239 | block_norm_type=NormType.BATCH_NORM, 240 | **kwargs 241 | ): 242 | ME.MinkowskiNetwork.__init__(self, dimension) 243 | self.conv = ME.MinkowskiConvolution( 244 | in_channels=down_conv_nn[0], 245 | out_channels=down_conv_nn[1], 246 | kernel_size=kernel_size, 247 | stride=stride, 248 | dilation=dilation, 249 | bias=False, 250 | dimension=dimension, 251 | ) 252 | self.norm = get_norm(norm_type, down_conv_nn[1], bn_momentum=bn_momentum, D=dimension) 253 | self.block = get_block(block_norm_type, down_conv_nn[1], down_conv_nn[1], bn_momentum=bn_momentum, D=dimension) 254 | 255 | def forward(self, x): 256 
| 257 | out_s = self.conv(x) 258 | out_s = self.norm(out_s) 259 | out = self.block(out_s) 260 | return out 261 | 262 | 263 | class Res2BlockUp(ME.MinkowskiNetwork): 264 | 265 | """ 266 | block for unwrapped Resnet 267 | """ 268 | 269 | def __init__( 270 | self, 271 | up_conv_nn, 272 | kernel_size, 273 | stride, 274 | dilation, 275 | dimension=3, 276 | bn_momentum=0.01, 277 | norm_type=NormType.BATCH_NORM, 278 | block_norm_type=NormType.BATCH_NORM, 279 | **kwargs 280 | ): 281 | ME.MinkowskiNetwork.__init__(self, dimension) 282 | self.conv = ME.MinkowskiConvolutionTranspose( 283 | in_channels=up_conv_nn[0], 284 | out_channels=up_conv_nn[1], 285 | kernel_size=kernel_size, 286 | stride=stride, 287 | dilation=dilation, 288 | bias=False, 289 | dimension=dimension, 290 | ) 291 | if len(up_conv_nn) == 3: 292 | self.final = ME.MinkowskiConvolution( 293 | in_channels=up_conv_nn[1], 294 | out_channels=up_conv_nn[2], 295 | kernel_size=kernel_size, 296 | stride=stride, 297 | dilation=dilation, 298 | bias=True, 299 | dimension=dimension, 300 | ) 301 | else: 302 | self.norm = get_norm(norm_type, up_conv_nn[1], bn_momentum=bn_momentum, D=dimension) 303 | self.block = get_block(block_norm_type, up_conv_nn[1], up_conv_nn[1], bn_momentum=bn_momentum, D=dimension) 304 | self.final = None 305 | 306 | def forward(self, x, x_skip): 307 | if x_skip is not None: 308 | x = ME.cat(x, x_skip) 309 | out_s = self.conv(x) 310 | if self.final is None: 311 | out_s = self.norm(out_s) 312 | out = self.block(out_s) 313 | return out 314 | else: 315 | out_s = MEF.relu(out_s) 316 | out = self.final(out_s) 317 | return out 318 | -------------------------------------------------------------------------------- /pointdc_mk/train_S3DIS.py: -------------------------------------------------------------------------------- 1 | import os, random, time, argparse, logging, warnings, torch 2 | import numpy as np 3 | from sklearn.utils.linear_assignment_ import linear_assignment # pip install scikit-learn==0.22.2 4 | from datasets.S3DIS import S3DISdistill, S3DIStrain, S3DIScluster, cfl_collate_fn_distill, cfl_collate_fn 5 | import MinkowskiEngine as ME 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader 8 | from models.fpn import Res16FPN18 9 | from models.pretrain_models import SubModel, SegHead 10 | from eval_S3DIS import eval, eval_once, eval_by_cluster 11 | from lib.utils_s3dis import * 12 | from sklearn.cluster import KMeans 13 | from os.path import join 14 | from tqdm import tqdm 15 | from torch.optim import lr_scheduler 16 | warnings.filterwarnings('ignore') 17 | 18 | def parse_args(): 19 | '''PARAMETERS''' 20 | parser = argparse.ArgumentParser(description='PyTorch Unsuper_3D_Seg') 21 | parser.add_argument('--data_path', type=str, default='data/S3DIS/', help='point cloud data path') 22 | parser.add_argument('--sp_path', type=str, default= 'data/S3DIS/', help='initial sp path') 23 | parser.add_argument('--expname', type=str, default= 'zdefault', help='expname for logger') 24 | ### 25 | parser.add_argument('--save_path', type=str, default='ckpt/S3DIS/', help='model savepath') 26 | parser.add_argument('--max_epoch', type=int, nargs=3, default=[200, 30, 60], help='max epochs for the distill, warm-up, and iterative stages') 27 | ### 28 | parser.add_argument('--bn_momentum', type=float, default=0.02, help='batchnorm parameters') 29 | parser.add_argument('--conv1_kernel_size', type=int, default=5, help='kernel size of 1st conv layers') 30 | #### 31 | parser.add_argument('--lrs', type=float, nargs=3, default=[1e-3, 3e-2, 3e-2], help='learning rates for the three training stages') 32 | 
parser.add_argument('--momentum', type=float, default=0.9, help='SGD parameters') 33 | parser.add_argument('--dampening', type=float, default=0.1, help='SGD parameters') 34 | parser.add_argument('--weight-decay', type=float, default=1e-4, help='SGD parameters') 35 | parser.add_argument('--workers', type=int, default=8, help='how many workers for loading data') 36 | parser.add_argument('--cluster_workers', type=int, default=4, help='how many workers for loading data in clustering') 37 | parser.add_argument('--seed', type=int, default=2023, help='random seed') 38 | parser.add_argument('--log-interval', type=int, default=150, help='log interval') 39 | parser.add_argument('--batch_size', type=int, default=8, help='batchsize in training') 40 | parser.add_argument('--voxel_size', type=float, default=0.05, help='voxel size in SparseConv') 41 | parser.add_argument('--input_dim', type=int, default=6, help='network input dimension') ### 6 for XYZRGB 42 | parser.add_argument('--primitive_num', type=int, default=13, help='how many primitives used in training') 43 | parser.add_argument('--semantic_class', type=int, default=13, help='ground truth semantic class') 44 | parser.add_argument('--feats_dim', type=int, default=128, help='output feature dimension') 45 | parser.add_argument('--ignore_label', type=int, default=-1, help='invalid label') 46 | parser.add_argument('--drop_threshold', type=int, default=50, help='drop superpoints containing fewer points than this threshold') 47 | 48 | return parser.parse_args() 49 | 50 | def main(args, logger): 51 | # Cross model distillation 52 | logger.info('**************Start Cross Model Distillation**************') 53 | ## Prepare Model/Optimizer 54 | model = Res16FPN18(in_channels=args.input_dim, out_channels=args.primitive_num, \ 55 | conv1_kernel_size=args.conv1_kernel_size, args=args, mode='distill') 56 | submodel = SubModel(args) 57 | adam = torch.optim.Adam([{'params':model.parameters()}, {'params': submodel.parameters()}], \ 58 | lr=args.lrs[0]) 59 | model, submodel = model.cuda(), submodel.cuda() 60 | distill_loss = MseMaskLoss().cuda() 61 | 62 | # Prepare Data 63 | distillset = S3DISdistill(args) 64 | distill_loader = DataLoader(distillset, batch_size=args.batch_size, shuffle=True, collate_fn=cfl_collate_fn_distill(), \ 65 | num_workers=args.workers, pin_memory=True, worker_init_fn=worker_init_fn(seed)) 66 | clusterset = S3DIScluster(args, areas=['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6']) 67 | cluster_loader = DataLoader(clusterset, batch_size=1, shuffle=True, collate_fn=cfl_collate_fn(), \ 68 | num_workers=args.workers, pin_memory=True, worker_init_fn=worker_init_fn(seed)) 69 | # Distill 70 | for epoch in range(1, args.max_epoch[0]+1): 71 | distill(distill_loader, logger, model, submodel, adam, distill_loss, epoch, args.max_epoch[0]) 72 | if epoch % 10 == 0: 73 | torch.save(model.state_dict(), join(args.save_path, 'cmd', 'model_' + str(epoch) + '_checkpoint.pth')) 74 | torch.save(submodel.state_dict(), join(args.save_path, 'cmd', 'submodule_' + str(epoch) + '_checkpoint.pth')) 75 | 76 | # Compute pseudo label 77 | model, submodel = model.cuda(), submodel.cuda() 78 | 79 | centroids_norm = init_cluster(args, logger, cluster_loader, model, submodel=submodel) 80 | 81 | del adam, distill_loader, distillset, submodel 82 | logger.info('====>End Cross Model Distill !!!\n') 83 | 84 | # Super Voxel Clustering 85 | logger.info('**************Start Super Voxel Clustering**************') 86 | ## Prepare Data 87 | trainset = S3DIStrain(args, areas=['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6']) 88 | 
train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, collate_fn=cfl_collate_fn(), \ 89 | num_workers=args.workers, pin_memory=True, worker_init_fn=worker_init_fn(seed)) 90 | ## Warm Up 91 | model.mode = 'train' 92 | ## Prepare Model/Loss/Optimizer 93 | seghead = SegHead(args.feats_dim, args.primitive_num) 94 | seghead = seghead.cuda() 95 | loss = torch.nn.CrossEntropyLoss(ignore_index=-1).cuda() 96 | warmup_optimizer = torch.optim.SGD([{"params": seghead.parameters()}, {"params": model.parameters()}], \ 97 | lr=args.lrs[1], momentum=args.momentum, \ 98 | dampening=args.dampening, weight_decay=args.weight_decay) 99 | 100 | logger.info('====>Start Warm Up.') 101 | for epoch in range(1, args.max_epoch[1]+1): 102 | train(train_loader, logger, model, warmup_optimizer, loss, epoch, seghead, args.max_epoch[1]) 103 | ### Evaluation and Save checkpoint 104 | if epoch % 5 == 0: 105 | torch.save(model.state_dict(), join(args.save_path, 'svc', 'model_' + str(epoch) + '_checkpoint.pth')) 106 | torch.save(seghead.state_dict()['cluster'], join(args.save_path, 'svc', 'cls_' + str(epoch) + '_checkpoint.pth')) 107 | with torch.no_grad(): 108 | o_Acc, m_Acc, s = eval(epoch, args) 109 | logger.info('WarmUp--Eval Epoch: {:02d}, oAcc {:.2f} mAcc {:.2f} IoUs'.format(epoch, o_Acc, m_Acc) + s+'\n') 110 | logger.info('====>End Warm Up !!!\n') 111 | 112 | # Iterative Training 113 | del seghead, warmup_optimizer # NOTE 114 | logger.info('====>Start Iterative Training.') 115 | iter_optimizer = torch.optim.SGD(model.parameters(), \ 116 | lr=args.lrs[2], momentum=args.momentum, \ 117 | dampening=args.dampening, weight_decay=args.weight_decay) 118 | 119 | scheduler = lr_scheduler.StepLR(iter_optimizer, step_size=5, gamma=0.8) # step lr 120 | logger.info('====>Update pseudo labels.') 121 | centroids_norm = init_cluster(args, logger, cluster_loader, model) 122 | seghead = get_fixclassifier(args.feats_dim, args.primitive_num, centroids_norm).cuda() 123 | for epoch in range(args.max_epoch[1]+1, args.max_epoch[1]+args.max_epoch[2]+1): 124 | logger.info('Update Optimizer lr:{:.2e}'.format(scheduler.get_last_lr()[0])) 125 | train(train_loader, logger, model, iter_optimizer, loss, epoch, seghead, args.max_epoch[1]+args.max_epoch[2]) ### train 126 | scheduler.step() 127 | if epoch % 5 == 0: 128 | torch.save(model.state_dict(), join(args.save_path, 'svc', 'model_' + str(epoch) + '_checkpoint.pth')) 129 | torch.save(seghead.state_dict()['weight'], join(args.save_path, 'svc', 'cls_' + str(epoch) + '_checkpoint.pth')) 130 | with torch.no_grad(): 131 | o_Acc, m_Acc, s = eval(epoch, args, mode='svc') 132 | logger.info('Iter--Eval Epoch{:02d}: oAcc {:.2f} mAcc {:.2f} IoUs'.format(epoch, o_Acc, m_Acc) + s+'\n') 133 | ### Update pseudo labels 134 | if epoch != args.max_epoch[1]+args.max_epoch[2]: # skip the update after the final epoch 135 | logger.info('Update pseudo labels') 136 | centroids_norm = init_cluster(args, logger, cluster_loader, model) 137 | seghead.weight.data = centroids_norm.requires_grad_(False) 138 | 139 | logger.info('====>End Super Voxel Clustering !!!\n') 140 | 141 | 142 | def init_cluster(args, logger, cluster_loader, model, submodel=None): 143 | time_start = time.time() 144 | 145 | ## Extract Superpoints Feature 146 | sp_feats_list = init_get_sp_feature(args, cluster_loader, model, submodel) 147 | sp_feats = torch.cat(sp_feats_list, dim=0) ### will do Kmeans with l2 distance 148 | _, centroids_norm = faiss_cluster(args, sp_feats.cpu().numpy()) 149 | centroids_norm = centroids_norm.cuda() 150 | 151 | ## Compute and Save Pseudo 
Labels 152 | all_pseudo, all_labels = init_get_pseudo(args, cluster_loader, model, centroids_norm, submodel) 153 | o_Acc, m_Acc, s = compute_seg_results(args, all_labels, all_pseudo) 154 | logger.info('clustering time: %.2fs', (time.time() - time_start)) 155 | logger.info('Trainset: oAcc {:.2f} mAcc {:.2f} IoUs'.format(o_Acc, m_Acc) + s+'\n') 156 | 157 | return centroids_norm 158 | 159 | def distill(distill_loader, logger, model, submodel, optimizer, loss, epoch, maxepochs): 160 | distill_loader.dataset.mode = 'distill' 161 | model.train() 162 | submodel.train() 163 | loss_display = AverageMeter() 164 | 165 | trainloader_bar = tqdm(distill_loader) 166 | for batch_idx, data in enumerate(trainloader_bar): 167 | ## Prepare data 168 | trainloader_bar.set_description('Epoch {}'.format(epoch)) 169 | coords, features, dinofeats, normals, labels, inverse_map, region, index, scenenames = data 170 | ## Forward 171 | in_field = ME.TensorField(features, coords, device=0) 172 | feats = model(in_field) 173 | feats_aligned = submodel(feats) 174 | ## Loss 175 | mask = region.squeeze() >= 0 176 | loss_distill = loss(F.normalize(feats_aligned[mask]), F.normalize(dinofeats[mask].cuda().detach())) 177 | loss_display.update(loss_distill.item()) 178 | optimizer.zero_grad() 179 | loss_distill.backward() 180 | optimizer.step() 181 | 182 | torch.cuda.empty_cache() 183 | torch.cuda.synchronize(torch.device("cuda")) 184 | if batch_idx %5 == 0: 185 | trainloader_bar.set_postfix(trainloss='{:.3e}'.format(loss_display.avg)) 186 | if epoch % 10 == 0: 187 | logger.info('Epoch: {}/{} Train loss: {:.3e}'.format(epoch, maxepochs, loss_display.avg)) 188 | 189 | def train(train_loader, logger, model, optimizer, loss, epoch, classifier, maxepochs): 190 | train_loader.dataset.mode = 'train' 191 | model.train() 192 | classifier.train() 193 | loss_display = AverageMeter() 194 | 195 | trainloader_bar = tqdm(train_loader) 196 | for batch_idx, data in enumerate(trainloader_bar): 197 | 198 | trainloader_bar.set_description('Epoch {}/{}'.format(epoch, maxepochs)) 199 | coords, features, normals, labels, inverse_map, pseudo_labels, inds, region, index, scenenames = data 200 | 201 | in_field = ME.TensorField(features, coords, device=0) 202 | feats_nonorm = model(in_field) 203 | logits = classifier(feats_nonorm) 204 | 205 | ## loss 206 | pseudo_labels_comp = pseudo_labels.long().cuda() 207 | loss_sem = loss(logits, pseudo_labels_comp).mean() 208 | loss_display.update(loss_sem.item()) 209 | optimizer.zero_grad() 210 | loss_sem.backward() 211 | optimizer.step() 212 | 213 | torch.cuda.empty_cache() 214 | torch.cuda.synchronize(torch.device("cuda")) 215 | 216 | if batch_idx % 20 == 0: 217 | trainloader_bar.set_postfix(trainloss='{:.3e}'.format(loss_display.avg)) 218 | 219 | logger.info('Epoch {}/{}: Train loss: {:.3e}'.format(epoch, maxepochs, loss_display.avg)) 220 | 221 | def set_logger(log_path): 222 | logger = logging.getLogger() 223 | logger.setLevel(logging.INFO) 224 | 225 | # Logging to a file 226 | file_handler = logging.FileHandler(log_path) 227 | file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 228 | logger.addHandler(file_handler) 229 | 230 | # Logging to console 231 | stream_handler = logging.StreamHandler() 232 | stream_handler.setFormatter(logging.Formatter('%(message)s')) 233 | logger.addHandler(stream_handler) 234 | 235 | return logger 236 | 237 | if __name__ == '__main__': 238 | args = parse_args() 239 | 240 | args.save_path = os.path.join(args.save_path, args.expname) 241 | 
args.pseudo_path = os.path.join(args.save_path, 'pseudo_labels') 242 | 243 | '''Setup logger''' 244 | if not os.path.exists(args.save_path): 245 | os.makedirs(args.save_path) 246 | os.makedirs(args.pseudo_path) 247 | os.makedirs(join(args.save_path, 'cmd')) 248 | os.makedirs(join(args.save_path, 'svc')) 249 | logger = set_logger(os.path.join(args.save_path, 'train.log')) 250 | logger.info(args) 251 | 252 | '''Cache code''' 253 | cache_codes(args) 254 | 255 | '''Random Seed''' 256 | seed = args.seed 257 | set_seed(seed) 258 | 259 | main(args, logger) 260 | -------------------------------------------------------------------------------- /pointdc_mk/train_ScanNet.py: -------------------------------------------------------------------------------- 1 | import os, random, time, argparse, logging, warnings, torch 2 | import numpy as np 3 | from sklearn.utils.linear_assignment_ import linear_assignment # pip install scikit-learn==0.22.2 4 | from datasets.ScanNet import Scannettrain, Scannetdistill, Scannetval, cfl_collate_fn, cfl_collate_fn_distill, cfl_collate_fn_val 5 | import MinkowskiEngine as ME 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader 8 | from models.fpn import Res16FPN18 9 | from models.pretrain_models import SubModel, SegHead 10 | from eval_ScanNet import eval, eval_once, eval_by_cluster 11 | from lib.utils import * 12 | from sklearn.cluster import KMeans 13 | from os.path import join 14 | from tqdm import tqdm 15 | from torch.optim import lr_scheduler 16 | warnings.filterwarnings('ignore') 17 | 18 | def parse_args(): 19 | '''PARAMETERS''' 20 | parser = argparse.ArgumentParser(description='PointDC') 21 | parser.add_argument('--data_path', type=str, default='data/ScanNet/train', help='point cloud data path') # point cloud file path 22 | parser.add_argument('--feats_path', type=str, default='data/ScanNet/train_feats', help='feature volume path') # feature volume file path 23 | parser.add_argument('--sp_path', type=str, default= 'data/ScanNet/initial_superpoints', help='initial sp path') # initial supervoxel file path 24 | parser.add_argument('--expname', type=str, default= 'default', help='expname for logger') 25 | ### 26 | parser.add_argument('--save_path', type=str, default='ckpt/ScanNet/', help='model savepath') 27 | parser.add_argument('--max_epoch', type=int, nargs=3, default=[200, 30, 60], help='max epochs for the distill, warm-up, and iterative stages') 28 | ### 29 | parser.add_argument('--bn_momentum', type=float, default=0.02, help='batchnorm parameters') 30 | parser.add_argument('--conv1_kernel_size', type=int, default=5, help='kernel size of 1st conv layers') 31 | #### 32 | parser.add_argument('--lrs', type=float, nargs=3, default=[1e-3, 3e-2, 3e-2], help='learning rates for the three training stages') 33 | parser.add_argument('--momentum', type=float, default=0.9, help='SGD parameters') 34 | parser.add_argument('--dampening', type=float, default=0.1, help='SGD parameters') 35 | parser.add_argument('--weight-decay', type=float, default=1e-4, help='SGD parameters') 36 | parser.add_argument('--workers', type=int, default=8, help='how many workers for loading data') 37 | parser.add_argument('--cluster_workers', type=int, default=4, help='how many workers for loading data in clustering') 38 | parser.add_argument('--seed', type=int, default=2023, help='random seed') 39 | parser.add_argument('--log-interval', type=int, default=150, help='log interval') 40 | parser.add_argument('--batch_size', type=int, default=8, help='batchsize in training') 41 | parser.add_argument('--voxel_size', type=float, default=0.02, help='voxel size in SparseConv') 42 | parser.add_argument('--input_dim', type=int, 
default=6, help='network input dimension') ### 6 for XYZRGB 43 | parser.add_argument('--primitive_num', type=int, default=20, help='how many primitives used in training') 44 | parser.add_argument('--semantic_class', type=int, default=20, help='ground truth semantic class') 45 | parser.add_argument('--feats_dim', type=int, default=128, help='output feature dimension') 46 | parser.add_argument('--ignore_label', type=int, default=-1, help='invalid label') 47 | 48 | return parser.parse_args() 49 | 50 | 51 | def main(args, logger): 52 | # Cross model distillation 53 | logger.info('**************Start Cross Model Distillation**************') 54 | ## Prepare Model/Optimizer 55 | model = Res16FPN18(in_channels=args.input_dim, out_channels=args.primitive_num, \ 56 | conv1_kernel_size=args.conv1_kernel_size, args=args, mode='distill') 57 | submodel = SubModel(args) 58 | model, submodel = model.cuda(), submodel.cuda() 59 | distill_loss = MseMaskLoss().cuda() 60 | adam = torch.optim.Adam([{'params':model.parameters()}, {'params': submodel.parameters()}], \ 61 | lr=args.lrs[0]) 62 | ## Prepare Data 63 | distillset = Scannetdistill(args) 64 | distill_loader = DataLoader(distillset, batch_size=args.batch_size, shuffle=True, collate_fn=cfl_collate_fn_distill(), \ 65 | num_workers=args.workers, pin_memory=True, worker_init_fn=worker_init_fn(seed)) 66 | ## Distill 67 | for epoch in range(1, args.max_epoch[0]+1): 68 | distill(distill_loader, logger, model, submodel, adam, distill_loss, epoch, args.max_epoch[0]) 69 | if epoch % 10 == 0: 70 | torch.save(model.state_dict(), join(args.save_path, 'cmd', 'model_' + str(epoch) + '_checkpoint.pth')) 71 | torch.save(submodel.state_dict(), join(args.save_path, 'cmd', 'submodule_' + str(epoch) + '_checkpoint.pth')) 72 | 73 | ## Cluster & Compute pseudo labels 74 | trainset = Scannettrain(args) 75 | cluster_loader = DataLoader(trainset, batch_size=1, shuffle=True, collate_fn=cfl_collate_fn(), \ 76 | num_workers=args.workers, pin_memory=True, worker_init_fn=worker_init_fn(seed)) 77 | model, submodel = model.cuda(), submodel.cuda() 78 | centroids_norm = init_cluster(args, logger, cluster_loader, model, submodel=submodel) 79 | 80 | del adam, distill_loader, distillset, submodel 81 | logger.info('====>End Cross Model Distill !!!\n') 82 | 83 | # Super Voxel Clustering 84 | logger.info('**************Start Super Voxel Clustering**************') 85 | ## Prepare Data 86 | train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, collate_fn=cfl_collate_fn(), \ 87 | num_workers=args.workers, pin_memory=True, worker_init_fn=worker_init_fn(seed)) 88 | ## Warm Up 89 | model.mode = 'train' 90 | ### Prepare Model/Loss/Optimizer 91 | seghead = SegHead(args.feats_dim, args.primitive_num) 92 | seghead = seghead.cuda() 93 | loss = torch.nn.CrossEntropyLoss(ignore_index=-1).cuda() 94 | warmup_optimizer = torch.optim.SGD([{"params": seghead.parameters()}, {"params": model.parameters()}], \ 95 | lr=args.lrs[1], momentum=args.momentum, dampening=args.dampening, weight_decay=args.weight_decay) 96 | scheduler = lr_scheduler.StepLR(warmup_optimizer, step_size=5, gamma=0.8) # step lr 97 | logger.info('====>Start Warm Up.') 98 | for epoch in range(1, args.max_epoch[1]+1): 99 | logger.info('Update Optimizer lr:{:.2e}'.format(scheduler.get_last_lr()[0])) 100 | train(train_loader, logger, model, warmup_optimizer, loss, epoch, seghead, args.max_epoch[1]) 101 | ### Evaluation and Save checkpoint 102 | if epoch % 5 == 0: 103 | torch.save(model.state_dict(), join(args.save_path, 'svc', 
'model_' + str(epoch) + '_checkpoint.pth')) 104 | torch.save(seghead.state_dict()['cluster'], join(args.save_path, 'svc', 'cls_' + str(epoch) + '_checkpoint.pth')) 105 | with torch.no_grad(): 106 | o_Acc, m_Acc, s = eval(epoch, args) 107 | logger.info('WarmUp--Eval Epoch: {:02d}, oAcc {:.2f} mAcc {:.2f} IoUs'.format(epoch, o_Acc, m_Acc) + s+'\n') 108 | logger.info('====>End Warm Up !!!\n') 109 | ## Iterative Training 110 | del seghead, warmup_optimizer # NOTE 111 | logger.info('====>Start Iterative Training.') 112 | iter_optimizer = torch.optim.SGD(model.parameters(), \ 113 | lr=args.lrs[2], momentum=args.momentum, dampening=args.dampening, weight_decay=args.weight_decay) 114 | scheduler = lr_scheduler.StepLR(iter_optimizer, step_size=5, gamma=0.8) # step lr 115 | logger.info('====>Update pseudo labels.') 116 | centroids_norm = init_cluster(args, logger, cluster_loader, model) 117 | seghead = get_fixclassifier(args.feats_dim, args.primitive_num, centroids_norm).cuda() 118 | for epoch in range(args.max_epoch[1]+1, args.max_epoch[1]+args.max_epoch[2]+1): 119 | logger.info('Update Optimizer lr:{:.2e}'.format(scheduler.get_last_lr()[0])) 120 | train(train_loader, logger, model, iter_optimizer, loss, epoch, seghead, args.max_epoch[1]+args.max_epoch[2]) ### train 121 | scheduler.step() 122 | if epoch % 5 == 0: 123 | torch.save(model.state_dict(), join(args.save_path, 'svc', 'model_' + str(epoch) + '_checkpoint.pth')) 124 | torch.save(seghead.state_dict()['weight'], join(args.save_path, 'svc', 'cls_' + str(epoch) + '_checkpoint.pth')) 125 | with torch.no_grad(): 126 | o_Acc, m_Acc, s = eval(epoch, args, mode='svc') 127 | logger.info('Iter--Eval Epoch{:02d}: oAcc {:.2f} mAcc {:.2f} IoUs'.format(epoch, o_Acc, m_Acc) + s+'\n') 128 | ### Update pseudo labels 129 | if epoch != args.max_epoch[1]+args.max_epoch[2]: # skip the update after the final epoch 130 | logger.info('Update pseudo labels') 131 | centroids_norm = init_cluster(args, logger, cluster_loader, model) 132 | seghead.weight.data = centroids_norm.requires_grad_(False) 133 | 134 | logger.info('====>End Super Voxel Clustering !!!\n') 135 | 136 | def init_cluster(args, logger, cluster_loader, model, submodel=None): 137 | time_start = time.time() 138 | cluster_loader.dataset.mode = 'cluster' 139 | 140 | ## Extract Superpoints Feature 141 | sp_feats_list = init_get_sp_feature(args, cluster_loader, model, submodel) 142 | sp_feats = torch.cat(sp_feats_list, dim=0) ### will do Kmeans with geometric distance 143 | _, centroids_norm = faiss_cluster(args, sp_feats.cpu().numpy()) 144 | centroids_norm = centroids_norm.cuda() 145 | 146 | ## Compute and Save Pseudo Labels 147 | all_pseudo, all_labels = init_get_pseudo(args, cluster_loader, model, centroids_norm, submodel) 148 | logger.info('clustering time: %.2fs', (time.time() - time_start)) 149 | 150 | return centroids_norm 151 | 152 | def distill(distill_loader, logger, model, submodel, optimizer, loss, epoch, maxepochs): 153 | distill_loader.dataset.mode = 'distill' 154 | model.train() 155 | submodel.train() 156 | loss_display = AverageMeter() 157 | 158 | trainloader_bar = tqdm(distill_loader) 159 | for batch_idx, data in enumerate(trainloader_bar): 160 | ## Prepare data 161 | trainloader_bar.set_description('Epoch {}'.format(epoch)) 162 | coords, features, normals, labels, inverse_map, pseudo_labels, inds, region, index, scenenames, spfeats = data 163 | ## Forward 164 | in_field = ME.TensorField(features, coords, device=0) 165 | feats = model(in_field) 166 | feats_aligned = submodel(feats) 167 | ## Loss 168 | loss_distill = 
loss(feats_aligned, spfeats.cuda().detach()) 169 | loss_display.update(loss_distill.item()) 170 | optimizer.zero_grad() 171 | loss_distill.backward() 172 | optimizer.step() 173 | 174 | torch.cuda.empty_cache() 175 | # torch.cuda.synchronize(torch.device("cuda")) 176 | if batch_idx % 10 == 0: 177 | trainloader_bar.set_postfix(trainloss='{:.3e}'.format(loss_display.avg)) 178 | if epoch % 10 == 0: 179 | logger.info('Epoch: {}/{} Train loss: {:.3e}'.format(epoch, maxepochs, loss_display.avg)) 180 | 181 | def train(train_loader, logger, model, optimizer, loss, epoch, classifier, maxepochs): 182 | train_loader.dataset.mode = 'train' 183 | model.train() 184 | classifier.train() 185 | loss_display = AverageMeter() 186 | 187 | trainloader_bar = tqdm(train_loader) 188 | for batch_idx, data in enumerate(trainloader_bar): 189 | 190 | trainloader_bar.set_description('Epoch {}/{}'.format(epoch, maxepochs)) 191 | coords, features, normals, labels, inverse_map, pseudo_labels, inds, region, index, scenenames = data 192 | 193 | in_field = ME.TensorField(features, coords, device=0) 194 | feats_nonorm = model(in_field) 195 | logits = classifier(feats_nonorm) 196 | 197 | ## loss 198 | pseudo_labels_comp = pseudo_labels.long().cuda() 199 | loss_sem = loss(logits, pseudo_labels_comp).mean() 200 | loss_display.update(loss_sem.item()) 201 | optimizer.zero_grad() 202 | loss_sem.backward() 203 | optimizer.step() 204 | 205 | torch.cuda.empty_cache() 206 | torch.cuda.synchronize(torch.device("cuda")) 207 | 208 | if batch_idx % 20 == 0: 209 | trainloader_bar.set_postfix(trainloss='{:.3e}'.format(loss_display.avg)) 210 | 211 | logger.info('Epoch {}/{}: Train loss: {:.3e}'.format(epoch, maxepochs, loss_display.avg)) 212 | 213 | def set_logger(log_path): 214 | logger = logging.getLogger() 215 | logger.setLevel(logging.INFO) 216 | 217 | # Logging to a file 218 | file_handler = logging.FileHandler(log_path) 219 | file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 220 | logger.addHandler(file_handler) 221 | 222 | # Logging to console 223 | stream_handler = logging.StreamHandler() 224 | stream_handler.setFormatter(logging.Formatter('%(message)s')) 225 | logger.addHandler(stream_handler) 226 | 227 | return logger 228 | 229 | if __name__ == '__main__': 230 | args = parse_args() 231 | 232 | args.save_path = os.path.join(args.save_path, args.expname) 233 | args.pseudo_path = os.path.join(args.save_path, 'pseudo_labels') 234 | 235 | '''Setup logger''' 236 | if not os.path.exists(args.save_path): 237 | os.makedirs(args.save_path) 238 | os.makedirs(args.pseudo_path) 239 | os.makedirs(join(args.save_path, 'cmd')) 240 | os.makedirs(join(args.save_path, 'svc')) 241 | logger = set_logger(os.path.join(args.save_path, 'train.log')) 242 | logger.info(args) 243 | 244 | '''Cache code''' 245 | cache_codes(args) 246 | 247 | '''Random Seed''' 248 | seed = args.seed 249 | set_seed(seed) 250 | 251 | main(args, logger) 252 | --------------------------------------------------------------------------------
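For reference: both training scripts above classify point features by cosine similarity. `SegHead` L2-normalizes its learnable `cluster` matrix and the incoming features before a bias-free linear map, and the iterative stage reuses the same pattern with fixed, normalized K-means centroids as classifier weights (`seghead.weight.data = centroids_norm`). The snippet below is a minimal, self-contained sketch of that scoring step, not code from the repository; the tensor shapes are illustrative, matching the defaults `feats_dim=128` and `primitive_num=20`.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only: 4096 points with 128-dim features (feats_dim),
# scored against 20 centroids (primitive_num).
feats = torch.randn(4096, 128)     # per-point features from the backbone
centroids = torch.randn(20, 128)   # cluster centroids acting as classifier weights

# L2-normalize both sides; a plain linear map then yields cosine-similarity
# logits in [-1, 1], mirroring SegHead.forward above. Taking the argmax over
# centroids gives a nearest-centroid assignment.
logits = F.linear(F.normalize(feats, dim=1), F.normalize(centroids, dim=1))
assignments = logits.argmax(dim=1)

print(logits.shape, assignments.shape)  # torch.Size([4096, 20]) torch.Size([4096])
```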