├── .gitignore
├── README.md
├── figs
│   ├── framework.jpg
│   └── scannet_visualization.jpg
└── pointdc_mk
    ├── README.md
    ├── data_prepare
    │   ├── S3DIS_anno_paths.txt
    │   ├── S3DIS_class_names.txt
    │   ├── ScanNet_splits
    │   │   ├── scannetv2_test.txt
    │   │   ├── scannetv2_train.txt
    │   │   ├── scannetv2_trainval.txt
    │   │   └── scannetv2_val.txt
    │   ├── data_prepare_S3DIS.py
    │   ├── data_prepare_ScanNet.py
    │   ├── initialSP_prepare_S3DIS.py
    │   ├── initialSP_prepare_ScanNet.py
    │   └── semantic-kitti.yaml
    ├── datasets
    │   ├── S3DIS.py
    │   ├── ScanNet.py
    │   └── SemanticKITTI.py
    ├── env.yaml
    ├── eval_S3DIS.py
    ├── eval_ScanNet.py
    ├── figs
    │   ├── framework.jpg
    │   └── scannet_visualization.jpg
    ├── lib
    │   ├── aug_tools.py
    │   ├── helper_ply.py
    │   ├── utils.py
    │   └── utils_s3dis.py
    ├── models
    │   ├── __init__.py
    │   ├── api_modules.py
    │   ├── common.py
    │   ├── fpn.py
    │   ├── modules.py
    │   ├── networks.py
    │   ├── pretrain_models.py
    │   ├── res16unet.py
    │   └── resunet.py
    ├── train_S3DIS.py
    └── train_ScanNet.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 |
3 | pointdc_mk/ckpt
4 | pointdc_mk/ckpt_old
5 | pointdc_mk/data
6 |
7 | */__pycache__
8 | */*/__pycache__
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [arXiv paper](https://arxiv.org/abs/2304.08965)
2 | [License: CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode)
3 |
4 | ## PointDC: Unsupervised Semantic Segmentation of 3D Point Clouds via Cross-modal Distillation and Super-Voxel Clustering (ICCV 2023)
5 |
6 | ### Overview
7 |
8 | We propose an unsupervised point cloud semantic segmentation framework called **PointDC**.
9 |
10 |
11 |
12 |
13 |
14 | ## NOTE
15 | Two implementations are planned here. [pointdc_mk](https://github.com/SCUT-BIP-Lab/PointDC/tree/main/pointdc_mk) is based on [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine); the SpConv-based version is still to be released (see the TODO below).
16 |
17 | ## TODO
18 | - [x] Release code based on MinkowskiEngine and model weight files
19 | - [ ] Release code based on SpConv and model weight files
20 | - [x] Release Sparse Feature Volume files
21 |
22 |
23 | ### Citation
24 | If this paper is helpful to you, please cite:
25 | ```
26 | @article{chen2023unsupervised,
27 | title={Unsupervised Semantic Segmentation of 3D Point Clouds via Cross-modal Distillation and Super-Voxel Clustering},
28 | author={Chen, Zisheng and Xu, Hongbin},
29 | journal={arXiv preprint arXiv:2304.08965},
30 | year={2023}
31 | }
32 | ```
--------------------------------------------------------------------------------
/figs/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCUT-BIP-Lab/PointDC/99f4f30ee4e2fa2d173c3881196afec4b02bb606/figs/framework.jpg
--------------------------------------------------------------------------------
/figs/scannet_visualization.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCUT-BIP-Lab/PointDC/99f4f30ee4e2fa2d173c3881196afec4b02bb606/figs/scannet_visualization.jpg
--------------------------------------------------------------------------------
/pointdc_mk/README.md:
--------------------------------------------------------------------------------
1 | [arXiv paper](https://arxiv.org/abs/2304.08965)
2 | [License: CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode)
3 |
4 | ## PointDC: Unsupervised Semantic Segmentation of 3D Point Clouds via Cross-modal Distillation and Super-Voxel Clustering (ICCV 2023)
5 |
6 | ### Overview
7 |
8 | We propose an unsupervised point cloud semantic segmentation framework called **PointDC**.
9 |
10 |
11 |
12 |
13 |
14 | ## NOTE
15 | This project is built on MinkowskiEngine and reuses code from [growsp](https://github.com/vLAR-group/GrowSP), but the methods used are consistent with the original paper.
16 |
17 | ## TODO
18 | - [x] Release code deployed on the ScanNet dataset and model weight files
19 | - [x] Release code deployed on the S3DIS dataset and model weight files
20 | - [x] Release Sparse Feature Volume files
21 |
22 | ## 1. Setup
23 | Setting up this project involves installing the dependencies below.
24 |
25 | ### Installing dependencies
26 | To install all the dependencies, please run the following:
27 | ```shell script
28 | sudo apt install build-essential python3-dev libopenblas-dev
29 | conda env create -f env.yaml
30 | conda activate pointdc_mk
31 | pip install -U MinkowskiEngine --install-option="--blas=openblas" -v --no-deps
32 | ```
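
To quickly verify that MinkowskiEngine imports correctly afterwards (an optional sanity check, not part of the original instructions):
```shell script
python -c "import MinkowskiEngine as ME; print(ME.__version__)"
```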
33 | ## 2. Running codes
34 | ### 2.1 ScanNet
35 | Download the ScanNet dataset from [the official website](http://kaldir.vc.in.tum.de/scannet_benchmark/documentation).
36 | You need to sign the terms of use. Uncompress the folder and move it to
37 | `${your_ScanNet}`.
38 | - Download the superpoint feature (sp feats) files from [here](https://pan.baidu.com/s/1ibxoq3HyxRJa3KrnPafCWw?pwd=6666), and put them in the expected path.
39 |
40 |
41 | - Preparing the dataset:
42 | ```shell script
43 | python data_prepare/data_prepare_ScanNet.py --data_path ${your_ScanNet}
44 | ```
45 | This script preprocesses ScanNet and writes the results under `./data/ScanNet/processed` (the expected raw layout is sketched below).
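
For reference, `data_prepare/data_prepare_ScanNet.py` looks for the official `scans`/`scans_test` folders under `${your_ScanNet}`; scene names below are illustrative:
```
${your_ScanNet}
├── scans
│   └── scene0000_00
│       ├── scene0000_00_vh_clean_2.ply
│       ├── scene0000_00_vh_clean_2.labels.ply
│       └── scene0000_00_vh_clean_2.0.010000.segs.json   # used later for superpoints
└── scans_test
    └── scene0707_00
        └── scene0707_00_vh_clean_2.ply
```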
46 |
47 | - Construct initial superpoints:
48 | ```shell script
49 | python data_prepare/initialSP_prepare_ScanNet.py
50 | ```
51 | This script constructs superpoints on ScanNet and writes them under `./data/ScanNet/initial_superpoints` (see the quick check below).
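
Each scene is saved as `<scene>_superpoint.npy` with one superpoint id per point. A minimal sketch for inspecting one file, assuming the default paths (the scene name is illustrative):
```python
import numpy as np

# Superpoint ids are remapped to be contiguous, starting from 0.
sp = np.load('data/ScanNet/initial_superpoints/scene0000_00_superpoint.npy')
print('points:', sp.shape[0], 'superpoints:', int(sp.max()) + 1)
```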
52 |
53 | - Training:
54 | ```shell script
55 | CUDA_VISIBLE_DEVICES=0 python train_ScanNet.py --expname ${your_experiment_name}
56 | ```
57 | The output model and log file will be saved in `./ckpt/ScanNet` by default.
58 |
59 | - Evaluating:
60 | Revise the experiment name (`expnames=[eval_experiment_name]`) on line 141 of `eval_ScanNet.py`.
61 | ```shell script
62 | CUDA_VISIBLE_DEVICES=0 python eval_ScanNet.py
63 | ```
64 |
65 | ### 2.2 S3DIS
66 | Download the S3DIS dataset from [the official website](https://docs.google.com/forms/d/e/1FAIpQLScDimvNMCGhy_rmBA2gHfDu3naktRm6A8BPwAWWDv-Uhm6Shw/viewform?c=0&w=1&pli=1); download the file named **Stanford3dDataset_v1.2.zip**.
67 | Uncompress the folder and move it to `${your_S3DIS}`. Note that there is an error at line 180389 of `Area_5/hallway_6/Annotations/ceiling_1.txt`; it needs to be fixed manually.
68 | - Download the superpoint feature (sp feats) and superpoint (sp) files from [here](https://pan.baidu.com/s/1ibxoq3HyxRJa3KrnPafCWw?pwd=6666), and put them in the expected paths.
69 | >Due to the randomness in super-voxel construction, different super-voxels lead to different super-voxel features. Therefore, we provide both the super-voxel features and the corresponding super-voxels. This only affects the distillation stage.
70 |
71 | - Preparing the dataset:
72 | ```shell script
73 | python data_prepare/data_prepare_S3DIS.py --data_path ${your_S3DIS}
74 | ```
75 | This script preprocesses S3DIS and writes the results under `./data/S3DIS/processed` (see the sketch below).
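
Each room is saved as `Area_<i>_<room>.ply` with fields `x, y, z, red, green, blue, class`. A minimal sketch for inspecting one room with the repo's own PLY helper, run from the `pointdc_mk` root (the room name is illustrative):
```python
import numpy as np
from lib.helper_ply import read_ply

# Field names match the write_ply call in data_prepare/data_prepare_S3DIS.py.
data = read_ply('data/S3DIS/processed/Area_1_office_1.ply')
coords = np.vstack((data['x'], data['y'], data['z'])).T
colors = np.vstack((data['red'], data['green'], data['blue'])).T
print(coords.shape, colors.shape, np.unique(data['class']))
```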
76 |
77 | - Construct initial superpoints:
78 | ```shell script
79 | python data_prepare/initialSP_prepare_S3DIS.py
80 | ```
81 | This script constructs superpoints on S3DIS and writes them under `./data/S3DIS/initial_superpoints`
82 |
83 | - Training:
84 | ```shell script
85 | CUDA_VISIBLE_DEVICES=0 python train_S3DIS.py --expname ${your_experiment_name}
86 | ```
87 | The output model and log file will be saved in `./ckpt/S3DIS` by default.
88 |
89 | - Evaluating:
90 | Revise the experiment name `expnames=[eval_experiment_name]` in `eval_S3DIS.py`.
91 | ```shell script
92 | CUDA_VISIBLE_DEVICES=0 python eval_S3DIS.py
93 | ```
94 |
95 | ## 3. Model Weights Files
96 | The trained models and other processed files can be found [here](https://pan.baidu.com/s/1ibxoq3HyxRJa3KrnPafCWw?pwd=6666).
97 |
98 | ## Acknowledgement
99 | [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine)
100 |
101 | [growsp](https://github.com/vLAR-group/GrowSP)
102 |
--------------------------------------------------------------------------------
/pointdc_mk/data_prepare/S3DIS_anno_paths.txt:
--------------------------------------------------------------------------------
1 | Area_1/conferenceRoom_1/Annotations
2 | Area_1/conferenceRoom_2/Annotations
3 | Area_1/copyRoom_1/Annotations
4 | Area_1/hallway_1/Annotations
5 | Area_1/hallway_2/Annotations
6 | Area_1/hallway_3/Annotations
7 | Area_1/hallway_4/Annotations
8 | Area_1/hallway_5/Annotations
9 | Area_1/hallway_6/Annotations
10 | Area_1/hallway_7/Annotations
11 | Area_1/hallway_8/Annotations
12 | Area_1/office_10/Annotations
13 | Area_1/office_11/Annotations
14 | Area_1/office_12/Annotations
15 | Area_1/office_13/Annotations
16 | Area_1/office_14/Annotations
17 | Area_1/office_15/Annotations
18 | Area_1/office_16/Annotations
19 | Area_1/office_17/Annotations
20 | Area_1/office_18/Annotations
21 | Area_1/office_19/Annotations
22 | Area_1/office_1/Annotations
23 | Area_1/office_20/Annotations
24 | Area_1/office_21/Annotations
25 | Area_1/office_22/Annotations
26 | Area_1/office_23/Annotations
27 | Area_1/office_24/Annotations
28 | Area_1/office_25/Annotations
29 | Area_1/office_26/Annotations
30 | Area_1/office_27/Annotations
31 | Area_1/office_28/Annotations
32 | Area_1/office_29/Annotations
33 | Area_1/office_2/Annotations
34 | Area_1/office_30/Annotations
35 | Area_1/office_31/Annotations
36 | Area_1/office_3/Annotations
37 | Area_1/office_4/Annotations
38 | Area_1/office_5/Annotations
39 | Area_1/office_6/Annotations
40 | Area_1/office_7/Annotations
41 | Area_1/office_8/Annotations
42 | Area_1/office_9/Annotations
43 | Area_1/pantry_1/Annotations
44 | Area_1/WC_1/Annotations
45 | Area_2/auditorium_1/Annotations
46 | Area_2/auditorium_2/Annotations
47 | Area_2/conferenceRoom_1/Annotations
48 | Area_2/hallway_10/Annotations
49 | Area_2/hallway_11/Annotations
50 | Area_2/hallway_12/Annotations
51 | Area_2/hallway_1/Annotations
52 | Area_2/hallway_2/Annotations
53 | Area_2/hallway_3/Annotations
54 | Area_2/hallway_4/Annotations
55 | Area_2/hallway_5/Annotations
56 | Area_2/hallway_6/Annotations
57 | Area_2/hallway_7/Annotations
58 | Area_2/hallway_8/Annotations
59 | Area_2/hallway_9/Annotations
60 | Area_2/office_10/Annotations
61 | Area_2/office_11/Annotations
62 | Area_2/office_12/Annotations
63 | Area_2/office_13/Annotations
64 | Area_2/office_14/Annotations
65 | Area_2/office_1/Annotations
66 | Area_2/office_2/Annotations
67 | Area_2/office_3/Annotations
68 | Area_2/office_4/Annotations
69 | Area_2/office_5/Annotations
70 | Area_2/office_6/Annotations
71 | Area_2/office_7/Annotations
72 | Area_2/office_8/Annotations
73 | Area_2/office_9/Annotations
74 | Area_2/storage_1/Annotations
75 | Area_2/storage_2/Annotations
76 | Area_2/storage_3/Annotations
77 | Area_2/storage_4/Annotations
78 | Area_2/storage_5/Annotations
79 | Area_2/storage_6/Annotations
80 | Area_2/storage_7/Annotations
81 | Area_2/storage_8/Annotations
82 | Area_2/storage_9/Annotations
83 | Area_2/WC_1/Annotations
84 | Area_2/WC_2/Annotations
85 | Area_3/conferenceRoom_1/Annotations
86 | Area_3/hallway_1/Annotations
87 | Area_3/hallway_2/Annotations
88 | Area_3/hallway_3/Annotations
89 | Area_3/hallway_4/Annotations
90 | Area_3/hallway_5/Annotations
91 | Area_3/hallway_6/Annotations
92 | Area_3/lounge_1/Annotations
93 | Area_3/lounge_2/Annotations
94 | Area_3/office_10/Annotations
95 | Area_3/office_1/Annotations
96 | Area_3/office_2/Annotations
97 | Area_3/office_3/Annotations
98 | Area_3/office_4/Annotations
99 | Area_3/office_5/Annotations
100 | Area_3/office_6/Annotations
101 | Area_3/office_7/Annotations
102 | Area_3/office_8/Annotations
103 | Area_3/office_9/Annotations
104 | Area_3/storage_1/Annotations
105 | Area_3/storage_2/Annotations
106 | Area_3/WC_1/Annotations
107 | Area_3/WC_2/Annotations
108 | Area_4/conferenceRoom_1/Annotations
109 | Area_4/conferenceRoom_2/Annotations
110 | Area_4/conferenceRoom_3/Annotations
111 | Area_4/hallway_10/Annotations
112 | Area_4/hallway_11/Annotations
113 | Area_4/hallway_12/Annotations
114 | Area_4/hallway_13/Annotations
115 | Area_4/hallway_14/Annotations
116 | Area_4/hallway_1/Annotations
117 | Area_4/hallway_2/Annotations
118 | Area_4/hallway_3/Annotations
119 | Area_4/hallway_4/Annotations
120 | Area_4/hallway_5/Annotations
121 | Area_4/hallway_6/Annotations
122 | Area_4/hallway_7/Annotations
123 | Area_4/hallway_8/Annotations
124 | Area_4/hallway_9/Annotations
125 | Area_4/lobby_1/Annotations
126 | Area_4/lobby_2/Annotations
127 | Area_4/office_10/Annotations
128 | Area_4/office_11/Annotations
129 | Area_4/office_12/Annotations
130 | Area_4/office_13/Annotations
131 | Area_4/office_14/Annotations
132 | Area_4/office_15/Annotations
133 | Area_4/office_16/Annotations
134 | Area_4/office_17/Annotations
135 | Area_4/office_18/Annotations
136 | Area_4/office_19/Annotations
137 | Area_4/office_1/Annotations
138 | Area_4/office_20/Annotations
139 | Area_4/office_21/Annotations
140 | Area_4/office_22/Annotations
141 | Area_4/office_2/Annotations
142 | Area_4/office_3/Annotations
143 | Area_4/office_4/Annotations
144 | Area_4/office_5/Annotations
145 | Area_4/office_6/Annotations
146 | Area_4/office_7/Annotations
147 | Area_4/office_8/Annotations
148 | Area_4/office_9/Annotations
149 | Area_4/storage_1/Annotations
150 | Area_4/storage_2/Annotations
151 | Area_4/storage_3/Annotations
152 | Area_4/storage_4/Annotations
153 | Area_4/WC_1/Annotations
154 | Area_4/WC_2/Annotations
155 | Area_4/WC_3/Annotations
156 | Area_4/WC_4/Annotations
157 | Area_5/conferenceRoom_1/Annotations
158 | Area_5/conferenceRoom_2/Annotations
159 | Area_5/conferenceRoom_3/Annotations
160 | Area_5/hallway_10/Annotations
161 | Area_5/hallway_11/Annotations
162 | Area_5/hallway_12/Annotations
163 | Area_5/hallway_13/Annotations
164 | Area_5/hallway_14/Annotations
165 | Area_5/hallway_15/Annotations
166 | Area_5/hallway_1/Annotations
167 | Area_5/hallway_2/Annotations
168 | Area_5/hallway_3/Annotations
169 | Area_5/hallway_4/Annotations
170 | Area_5/hallway_5/Annotations
171 | Area_5/hallway_6/Annotations
172 | Area_5/hallway_7/Annotations
173 | Area_5/hallway_8/Annotations
174 | Area_5/hallway_9/Annotations
175 | Area_5/lobby_1/Annotations
176 | Area_5/office_10/Annotations
177 | Area_5/office_11/Annotations
178 | Area_5/office_12/Annotations
179 | Area_5/office_13/Annotations
180 | Area_5/office_14/Annotations
181 | Area_5/office_15/Annotations
182 | Area_5/office_16/Annotations
183 | Area_5/office_17/Annotations
184 | Area_5/office_18/Annotations
185 | Area_5/office_19/Annotations
186 | Area_5/office_1/Annotations
187 | Area_5/office_20/Annotations
188 | Area_5/office_21/Annotations
189 | Area_5/office_22/Annotations
190 | Area_5/office_23/Annotations
191 | Area_5/office_24/Annotations
192 | Area_5/office_25/Annotations
193 | Area_5/office_26/Annotations
194 | Area_5/office_27/Annotations
195 | Area_5/office_28/Annotations
196 | Area_5/office_29/Annotations
197 | Area_5/office_2/Annotations
198 | Area_5/office_30/Annotations
199 | Area_5/office_31/Annotations
200 | Area_5/office_32/Annotations
201 | Area_5/office_33/Annotations
202 | Area_5/office_34/Annotations
203 | Area_5/office_35/Annotations
204 | Area_5/office_36/Annotations
205 | Area_5/office_37/Annotations
206 | Area_5/office_38/Annotations
207 | Area_5/office_39/Annotations
208 | Area_5/office_3/Annotations
209 | Area_5/office_40/Annotations
210 | Area_5/office_41/Annotations
211 | Area_5/office_42/Annotations
212 | Area_5/office_4/Annotations
213 | Area_5/office_5/Annotations
214 | Area_5/office_6/Annotations
215 | Area_5/office_7/Annotations
216 | Area_5/office_8/Annotations
217 | Area_5/office_9/Annotations
218 | Area_5/pantry_1/Annotations
219 | Area_5/storage_1/Annotations
220 | Area_5/storage_2/Annotations
221 | Area_5/storage_3/Annotations
222 | Area_5/storage_4/Annotations
223 | Area_5/WC_1/Annotations
224 | Area_5/WC_2/Annotations
225 | Area_6/conferenceRoom_1/Annotations
226 | Area_6/copyRoom_1/Annotations
227 | Area_6/hallway_1/Annotations
228 | Area_6/hallway_2/Annotations
229 | Area_6/hallway_3/Annotations
230 | Area_6/hallway_4/Annotations
231 | Area_6/hallway_5/Annotations
232 | Area_6/hallway_6/Annotations
233 | Area_6/lounge_1/Annotations
234 | Area_6/office_10/Annotations
235 | Area_6/office_11/Annotations
236 | Area_6/office_12/Annotations
237 | Area_6/office_13/Annotations
238 | Area_6/office_14/Annotations
239 | Area_6/office_15/Annotations
240 | Area_6/office_16/Annotations
241 | Area_6/office_17/Annotations
242 | Area_6/office_18/Annotations
243 | Area_6/office_19/Annotations
244 | Area_6/office_1/Annotations
245 | Area_6/office_20/Annotations
246 | Area_6/office_21/Annotations
247 | Area_6/office_22/Annotations
248 | Area_6/office_23/Annotations
249 | Area_6/office_24/Annotations
250 | Area_6/office_25/Annotations
251 | Area_6/office_26/Annotations
252 | Area_6/office_27/Annotations
253 | Area_6/office_28/Annotations
254 | Area_6/office_29/Annotations
255 | Area_6/office_2/Annotations
256 | Area_6/office_30/Annotations
257 | Area_6/office_31/Annotations
258 | Area_6/office_32/Annotations
259 | Area_6/office_33/Annotations
260 | Area_6/office_34/Annotations
261 | Area_6/office_35/Annotations
262 | Area_6/office_36/Annotations
263 | Area_6/office_37/Annotations
264 | Area_6/office_3/Annotations
265 | Area_6/office_4/Annotations
266 | Area_6/office_5/Annotations
267 | Area_6/office_6/Annotations
268 | Area_6/office_7/Annotations
269 | Area_6/office_8/Annotations
270 | Area_6/office_9/Annotations
271 | Area_6/openspace_1/Annotations
272 | Area_6/pantry_1/Annotations
273 |
--------------------------------------------------------------------------------
/pointdc_mk/data_prepare/S3DIS_class_names.txt:
--------------------------------------------------------------------------------
1 | ceiling
2 | floor
3 | wall
4 | beam
5 | column
6 | window
7 | door
8 | table
9 | chair
10 | sofa
11 | bookcase
12 | board
13 | clutter
14 |
15 |
--------------------------------------------------------------------------------
/pointdc_mk/data_prepare/ScanNet_splits/scannetv2_test.txt:
--------------------------------------------------------------------------------
1 | scene0707_00_vh_clean_2.ply
2 | scene0708_00_vh_clean_2.ply
3 | scene0709_00_vh_clean_2.ply
4 | scene0710_00_vh_clean_2.ply
5 | scene0711_00_vh_clean_2.ply
6 | scene0712_00_vh_clean_2.ply
7 | scene0713_00_vh_clean_2.ply
8 | scene0714_00_vh_clean_2.ply
9 | scene0715_00_vh_clean_2.ply
10 | scene0716_00_vh_clean_2.ply
11 | scene0717_00_vh_clean_2.ply
12 | scene0718_00_vh_clean_2.ply
13 | scene0719_00_vh_clean_2.ply
14 | scene0720_00_vh_clean_2.ply
15 | scene0721_00_vh_clean_2.ply
16 | scene0722_00_vh_clean_2.ply
17 | scene0723_00_vh_clean_2.ply
18 | scene0724_00_vh_clean_2.ply
19 | scene0725_00_vh_clean_2.ply
20 | scene0726_00_vh_clean_2.ply
21 | scene0727_00_vh_clean_2.ply
22 | scene0728_00_vh_clean_2.ply
23 | scene0729_00_vh_clean_2.ply
24 | scene0730_00_vh_clean_2.ply
25 | scene0731_00_vh_clean_2.ply
26 | scene0732_00_vh_clean_2.ply
27 | scene0733_00_vh_clean_2.ply
28 | scene0734_00_vh_clean_2.ply
29 | scene0735_00_vh_clean_2.ply
30 | scene0736_00_vh_clean_2.ply
31 | scene0737_00_vh_clean_2.ply
32 | scene0738_00_vh_clean_2.ply
33 | scene0739_00_vh_clean_2.ply
34 | scene0740_00_vh_clean_2.ply
35 | scene0741_00_vh_clean_2.ply
36 | scene0742_00_vh_clean_2.ply
37 | scene0743_00_vh_clean_2.ply
38 | scene0744_00_vh_clean_2.ply
39 | scene0745_00_vh_clean_2.ply
40 | scene0746_00_vh_clean_2.ply
41 | scene0747_00_vh_clean_2.ply
42 | scene0748_00_vh_clean_2.ply
43 | scene0749_00_vh_clean_2.ply
44 | scene0750_00_vh_clean_2.ply
45 | scene0751_00_vh_clean_2.ply
46 | scene0752_00_vh_clean_2.ply
47 | scene0753_00_vh_clean_2.ply
48 | scene0754_00_vh_clean_2.ply
49 | scene0755_00_vh_clean_2.ply
50 | scene0756_00_vh_clean_2.ply
51 | scene0757_00_vh_clean_2.ply
52 | scene0758_00_vh_clean_2.ply
53 | scene0759_00_vh_clean_2.ply
54 | scene0760_00_vh_clean_2.ply
55 | scene0761_00_vh_clean_2.ply
56 | scene0762_00_vh_clean_2.ply
57 | scene0763_00_vh_clean_2.ply
58 | scene0764_00_vh_clean_2.ply
59 | scene0765_00_vh_clean_2.ply
60 | scene0766_00_vh_clean_2.ply
61 | scene0767_00_vh_clean_2.ply
62 | scene0768_00_vh_clean_2.ply
63 | scene0769_00_vh_clean_2.ply
64 | scene0770_00_vh_clean_2.ply
65 | scene0771_00_vh_clean_2.ply
66 | scene0772_00_vh_clean_2.ply
67 | scene0773_00_vh_clean_2.ply
68 | scene0774_00_vh_clean_2.ply
69 | scene0775_00_vh_clean_2.ply
70 | scene0776_00_vh_clean_2.ply
71 | scene0777_00_vh_clean_2.ply
72 | scene0778_00_vh_clean_2.ply
73 | scene0779_00_vh_clean_2.ply
74 | scene0780_00_vh_clean_2.ply
75 | scene0781_00_vh_clean_2.ply
76 | scene0782_00_vh_clean_2.ply
77 | scene0783_00_vh_clean_2.ply
78 | scene0784_00_vh_clean_2.ply
79 | scene0785_00_vh_clean_2.ply
80 | scene0786_00_vh_clean_2.ply
81 | scene0787_00_vh_clean_2.ply
82 | scene0788_00_vh_clean_2.ply
83 | scene0789_00_vh_clean_2.ply
84 | scene0790_00_vh_clean_2.ply
85 | scene0791_00_vh_clean_2.ply
86 | scene0792_00_vh_clean_2.ply
87 | scene0793_00_vh_clean_2.ply
88 | scene0794_00_vh_clean_2.ply
89 | scene0795_00_vh_clean_2.ply
90 | scene0796_00_vh_clean_2.ply
91 | scene0797_00_vh_clean_2.ply
92 | scene0798_00_vh_clean_2.ply
93 | scene0799_00_vh_clean_2.ply
94 | scene0800_00_vh_clean_2.ply
95 | scene0801_00_vh_clean_2.ply
96 | scene0802_00_vh_clean_2.ply
97 | scene0803_00_vh_clean_2.ply
98 | scene0804_00_vh_clean_2.ply
99 | scene0805_00_vh_clean_2.ply
100 | scene0806_00_vh_clean_2.ply
101 |
--------------------------------------------------------------------------------
/pointdc_mk/data_prepare/ScanNet_splits/scannetv2_val.txt:
--------------------------------------------------------------------------------
1 | scene0568_00.ply
2 | scene0568_01.ply
3 | scene0568_02.ply
4 | scene0304_00.ply
5 | scene0488_00.ply
6 | scene0488_01.ply
7 | scene0412_00.ply
8 | scene0412_01.ply
9 | scene0217_00.ply
10 | scene0019_00.ply
11 | scene0019_01.ply
12 | scene0414_00.ply
13 | scene0575_00.ply
14 | scene0575_01.ply
15 | scene0575_02.ply
16 | scene0426_00.ply
17 | scene0426_01.ply
18 | scene0426_02.ply
19 | scene0426_03.ply
20 | scene0549_00.ply
21 | scene0549_01.ply
22 | scene0578_00.ply
23 | scene0578_01.ply
24 | scene0578_02.ply
25 | scene0665_00.ply
26 | scene0665_01.ply
27 | scene0050_00.ply
28 | scene0050_01.ply
29 | scene0050_02.ply
30 | scene0257_00.ply
31 | scene0025_00.ply
32 | scene0025_01.ply
33 | scene0025_02.ply
34 | scene0583_00.ply
35 | scene0583_01.ply
36 | scene0583_02.ply
37 | scene0701_00.ply
38 | scene0701_01.ply
39 | scene0701_02.ply
40 | scene0580_00.ply
41 | scene0580_01.ply
42 | scene0565_00.ply
43 | scene0169_00.ply
44 | scene0169_01.ply
45 | scene0655_00.ply
46 | scene0655_01.ply
47 | scene0655_02.ply
48 | scene0063_00.ply
49 | scene0221_00.ply
50 | scene0221_01.ply
51 | scene0591_00.ply
52 | scene0591_01.ply
53 | scene0591_02.ply
54 | scene0678_00.ply
55 | scene0678_01.ply
56 | scene0678_02.ply
57 | scene0462_00.ply
58 | scene0427_00.ply
59 | scene0595_00.ply
60 | scene0193_00.ply
61 | scene0193_01.ply
62 | scene0164_00.ply
63 | scene0164_01.ply
64 | scene0164_02.ply
65 | scene0164_03.ply
66 | scene0598_00.ply
67 | scene0598_01.ply
68 | scene0598_02.ply
69 | scene0599_00.ply
70 | scene0599_01.ply
71 | scene0599_02.ply
72 | scene0328_00.ply
73 | scene0300_00.ply
74 | scene0300_01.ply
75 | scene0354_00.ply
76 | scene0458_00.ply
77 | scene0458_01.ply
78 | scene0423_00.ply
79 | scene0423_01.ply
80 | scene0423_02.ply
81 | scene0307_00.ply
82 | scene0307_01.ply
83 | scene0307_02.ply
84 | scene0606_00.ply
85 | scene0606_01.ply
86 | scene0606_02.ply
87 | scene0432_00.ply
88 | scene0432_01.ply
89 | scene0608_00.ply
90 | scene0608_01.ply
91 | scene0608_02.ply
92 | scene0651_00.ply
93 | scene0651_01.ply
94 | scene0651_02.ply
95 | scene0430_00.ply
96 | scene0430_01.ply
97 | scene0689_00.ply
98 | scene0357_00.ply
99 | scene0357_01.ply
100 | scene0574_00.ply
101 | scene0574_01.ply
102 | scene0574_02.ply
103 | scene0329_00.ply
104 | scene0329_01.ply
105 | scene0329_02.ply
106 | scene0153_00.ply
107 | scene0153_01.ply
108 | scene0616_00.ply
109 | scene0616_01.ply
110 | scene0671_00.ply
111 | scene0671_01.ply
112 | scene0618_00.ply
113 | scene0382_00.ply
114 | scene0382_01.ply
115 | scene0490_00.ply
116 | scene0621_00.ply
117 | scene0607_00.ply
118 | scene0607_01.ply
119 | scene0149_00.ply
120 | scene0695_00.ply
121 | scene0695_01.ply
122 | scene0695_02.ply
123 | scene0695_03.ply
124 | scene0389_00.ply
125 | scene0377_00.ply
126 | scene0377_01.ply
127 | scene0377_02.ply
128 | scene0342_00.ply
129 | scene0139_00.ply
130 | scene0629_00.ply
131 | scene0629_01.ply
132 | scene0629_02.ply
133 | scene0496_00.ply
134 | scene0633_00.ply
135 | scene0633_01.ply
136 | scene0518_00.ply
137 | scene0652_00.ply
138 | scene0406_00.ply
139 | scene0406_01.ply
140 | scene0406_02.ply
141 | scene0144_00.ply
142 | scene0144_01.ply
143 | scene0494_00.ply
144 | scene0278_00.ply
145 | scene0278_01.ply
146 | scene0316_00.ply
147 | scene0609_00.ply
148 | scene0609_01.ply
149 | scene0609_02.ply
150 | scene0609_03.ply
151 | scene0084_00.ply
152 | scene0084_01.ply
153 | scene0084_02.ply
154 | scene0696_00.ply
155 | scene0696_01.ply
156 | scene0696_02.ply
157 | scene0351_00.ply
158 | scene0351_01.ply
159 | scene0643_00.ply
160 | scene0644_00.ply
161 | scene0645_00.ply
162 | scene0645_01.ply
163 | scene0645_02.ply
164 | scene0081_00.ply
165 | scene0081_01.ply
166 | scene0081_02.ply
167 | scene0647_00.ply
168 | scene0647_01.ply
169 | scene0535_00.ply
170 | scene0353_00.ply
171 | scene0353_01.ply
172 | scene0353_02.ply
173 | scene0559_00.ply
174 | scene0559_01.ply
175 | scene0559_02.ply
176 | scene0593_00.ply
177 | scene0593_01.ply
178 | scene0246_00.ply
179 | scene0653_00.ply
180 | scene0653_01.ply
181 | scene0064_00.ply
182 | scene0064_01.ply
183 | scene0356_00.ply
184 | scene0356_01.ply
185 | scene0356_02.ply
186 | scene0030_00.ply
187 | scene0030_01.ply
188 | scene0030_02.ply
189 | scene0222_00.ply
190 | scene0222_01.ply
191 | scene0338_00.ply
192 | scene0338_01.ply
193 | scene0338_02.ply
194 | scene0378_00.ply
195 | scene0378_01.ply
196 | scene0378_02.ply
197 | scene0660_00.ply
198 | scene0553_00.ply
199 | scene0553_01.ply
200 | scene0553_02.ply
201 | scene0527_00.ply
202 | scene0663_00.ply
203 | scene0663_01.ply
204 | scene0663_02.ply
205 | scene0664_00.ply
206 | scene0664_01.ply
207 | scene0664_02.ply
208 | scene0334_00.ply
209 | scene0334_01.ply
210 | scene0334_02.ply
211 | scene0046_00.ply
212 | scene0046_01.ply
213 | scene0046_02.ply
214 | scene0203_00.ply
215 | scene0203_01.ply
216 | scene0203_02.ply
217 | scene0088_00.ply
218 | scene0088_01.ply
219 | scene0088_02.ply
220 | scene0088_03.ply
221 | scene0086_00.ply
222 | scene0086_01.ply
223 | scene0086_02.ply
224 | scene0670_00.ply
225 | scene0670_01.ply
226 | scene0256_00.ply
227 | scene0256_01.ply
228 | scene0256_02.ply
229 | scene0249_00.ply
230 | scene0441_00.ply
231 | scene0658_00.ply
232 | scene0704_00.ply
233 | scene0704_01.ply
234 | scene0187_00.ply
235 | scene0187_01.ply
236 | scene0131_00.ply
237 | scene0131_01.ply
238 | scene0131_02.ply
239 | scene0207_00.ply
240 | scene0207_01.ply
241 | scene0207_02.ply
242 | scene0461_00.ply
243 | scene0011_00.ply
244 | scene0011_01.ply
245 | scene0343_00.ply
246 | scene0251_00.ply
247 | scene0077_00.ply
248 | scene0077_01.ply
249 | scene0684_00.ply
250 | scene0684_01.ply
251 | scene0550_00.ply
252 | scene0686_00.ply
253 | scene0686_01.ply
254 | scene0686_02.ply
255 | scene0208_00.ply
256 | scene0500_00.ply
257 | scene0500_01.ply
258 | scene0552_00.ply
259 | scene0552_01.ply
260 | scene0648_00.ply
261 | scene0648_01.ply
262 | scene0435_00.ply
263 | scene0435_01.ply
264 | scene0435_02.ply
265 | scene0435_03.ply
266 | scene0690_00.ply
267 | scene0690_01.ply
268 | scene0693_00.ply
269 | scene0693_01.ply
270 | scene0693_02.ply
271 | scene0700_00.ply
272 | scene0700_01.ply
273 | scene0700_02.ply
274 | scene0699_00.ply
275 | scene0231_00.ply
276 | scene0231_01.ply
277 | scene0231_02.ply
278 | scene0697_00.ply
279 | scene0697_01.ply
280 | scene0697_02.ply
281 | scene0697_03.ply
282 | scene0474_00.ply
283 | scene0474_01.ply
284 | scene0474_02.ply
285 | scene0474_03.ply
286 | scene0474_04.ply
287 | scene0474_05.ply
288 | scene0355_00.ply
289 | scene0355_01.ply
290 | scene0146_00.ply
291 | scene0146_01.ply
292 | scene0146_02.ply
293 | scene0196_00.ply
294 | scene0702_00.ply
295 | scene0702_01.ply
296 | scene0702_02.ply
297 | scene0314_00.ply
298 | scene0277_00.ply
299 | scene0277_01.ply
300 | scene0277_02.ply
301 | scene0095_00.ply
302 | scene0095_01.ply
303 | scene0015_00.ply
304 | scene0100_00.ply
305 | scene0100_01.ply
306 | scene0100_02.ply
307 | scene0558_00.ply
308 | scene0558_01.ply
309 | scene0558_02.ply
310 | scene0685_00.ply
311 | scene0685_01.ply
312 | scene0685_02.ply
313 |
--------------------------------------------------------------------------------
/pointdc_mk/data_prepare/data_prepare_S3DIS.py:
--------------------------------------------------------------------------------
1 | import MinkowskiEngine as ME
2 | from os.path import join, exists, dirname, abspath
3 | import numpy as np
4 | import pandas as pd
5 | import os, sys, glob
6 | import argparse
7 | from tqdm import tqdm
8 |
9 | BASE_DIR = dirname(abspath(__file__))
10 | ROOT_DIR = dirname(BASE_DIR)
11 | sys.path.append(BASE_DIR)
12 | sys.path.append(ROOT_DIR)
13 | from lib.helper_ply import write_ply
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--data_path', type=str, default='data/Stanford3dDataset_v1.2', help='raw data path')
17 | parser.add_argument('--processed_data_path', type=str, default='data/S3DIS/processed')
18 | args = parser.parse_args()
19 |
20 | args.data_path = join(ROOT_DIR, args.data_path)
21 | args.processed_data_path = join(ROOT_DIR, args.processed_data_path)
22 |
23 | anno_paths = [line.rstrip() for line in open(join(BASE_DIR, 'S3DIS_anno_paths.txt'))]
24 | anno_paths = [join(args.data_path, p) for p in anno_paths]
25 |
26 | gt_class = [x.rstrip() for x in open(join(BASE_DIR, 'S3DIS_class_names.txt'))]
27 | gt_class2label = {cls: i for i, cls in enumerate(gt_class)}
28 |
29 | sub_grid_size = 0.010
30 | if not exists(args.processed_data_path):
31 | os.makedirs(args.processed_data_path)
32 | out_format = '.ply'
33 |
34 | def convert_pc2ply(anno_path, file_name):
35 | sub_ply_file = join(args.processed_data_path, file_name)
36 | if os.path.exists(sub_ply_file): return
37 | data_list = []
38 |
39 | for f in glob.glob(join(anno_path, '*.txt')):
40 | class_name = os.path.basename(f).split('_')[0]
41 | if class_name not in gt_class: # note: some rooms contain a 'stairs' class that is not in the label set
42 | class_name = 'clutter'
43 | pc = pd.read_csv(f, header=None, delim_whitespace=True).values
44 | labels = np.ones((pc.shape[0], 1)) * gt_class2label[class_name]
45 | data_list.append(np.concatenate([pc, labels], 1))
46 |
47 | pc_info = np.concatenate(data_list, 0)
48 |
49 | coords = pc_info[:, :3]
50 | colors = pc_info[:, 3:6].astype(np.uint8)
51 | labels = pc_info[:, 6]
52 |
53 | _, _, collabels, inds = ME.utils.sparse_quantize(np.ascontiguousarray(coords), colors, labels, return_index=True, ignore_label=-1, quantization_size=sub_grid_size)
54 | sub_coords, sub_colors, sub_labels = coords[inds], colors[inds], collabels
55 |
56 | write_ply(sub_ply_file, [sub_coords, sub_colors, sub_labels[:,None]], ['x', 'y', 'z', 'red', 'green', 'blue', 'class'])
57 |
58 | print('start preprocess')
59 | # Note: there is an extra character in the v1.2 data in Area_5/hallway_6. It's fixed manually.
60 | anno_paths_bar = tqdm(anno_paths)
61 | for annotation_path in anno_paths_bar:
62 | # print(annotation_path)
63 | anno_paths_bar.set_description(annotation_path.split('/')[-3] + '_' + annotation_path.split('/')[-2])
64 | elements = str(annotation_path).split('/')
65 | out_file_name = elements[-3] + '_' + elements[-2] + out_format
66 | convert_pc2ply(annotation_path, out_file_name)
67 |
--------------------------------------------------------------------------------
/pointdc_mk/data_prepare/data_prepare_ScanNet.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from os.path import join, exists, dirname, abspath
3 | import os, sys, glob
4 |
5 | import numpy as np
6 | BASE_DIR = dirname(abspath(__file__))
7 | ROOT_DIR = dirname(BASE_DIR)
8 | sys.path.append(BASE_DIR)
9 | sys.path.append(ROOT_DIR)
10 | from lib.helper_ply import read_ply, write_ply
11 | from concurrent.futures import ProcessPoolExecutor
12 | import argparse
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--data_path', type=str, default='data/ScanNet/', help='raw data path')
16 | parser.add_argument('--processed_data_path', type=str, default='data/ScanNet/')
17 | args = parser.parse_args()
18 |
19 | SCANNET_RAW_PATH = Path(join(ROOT_DIR, args.data_path))
20 | SCANNET_OUT_PATH = Path(join(ROOT_DIR, args.processed_data_path))
21 | TRAIN_DEST = 'train'
22 | TEST_DEST = 'test'
23 | SUBSETS = {TRAIN_DEST: 'scans', TEST_DEST: 'scans_test'}
24 | POINTCLOUD_FILE = '_vh_clean_2.ply'
25 | BUGS = {
26 | 'scene0270_00': 50,
27 | 'scene0270_02': 50,
28 | 'scene0384_00': 149,
29 | }
30 | CLASS_LABELS = ('wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
31 | 'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refrigerator',
32 | 'shower curtain', 'toilet', 'sink', 'bathtub', 'otherfurniture')
33 | VALID_CLASS_IDS = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39)
34 | NUM_LABELS = 41
35 | IGNORE_LABELS = tuple(set(range(41)) - set(VALID_CLASS_IDS))
36 |
37 | ''' Set Invalid Label to -1'''
38 | label_map = {}
39 | n_used = 0
40 | for l in range(NUM_LABELS):
41 | if l in IGNORE_LABELS:
42 | label_map[l] = -1  # ignore label
43 | else:
44 | label_map[l] = n_used
45 | n_used += 1
46 |
47 | def handle_process(path):
48 | f = Path(path.split(',')[0])
49 | phase_out_path = Path(path.split(',')[1])
50 | # pointcloud = read_ply(f)
51 | data = read_ply(str(f), triangular_mesh=True)[0]
52 | coords = np.vstack((data['x'], data['y'], data['z'])).T
53 | colors = np.vstack((data['red'], data['green'], data['blue'])).T
54 | # Load label file.
55 | label_f = f.parent / (f.stem + '.labels' + f.suffix)
56 | if label_f.is_file():
57 | label = read_ply(str(label_f), triangular_mesh=True)[0]['label'].T.squeeze()
58 | else: # Label may not exist in test case.
59 | label = np.zeros(data.shape)  # test scans have no labels; 0 is in IGNORE_LABELS and maps to -1 below
60 | out_f = phase_out_path / (f.name[:-len(POINTCLOUD_FILE)] + f.suffix)
61 |
62 | '''Fix Data Bug'''
63 | for item, bug_index in BUGS.items():
64 | if item in path:
65 | print('Fixing {} bugged label'.format(item))
66 | bug_mask = label == bug_index
67 | label[bug_mask] = 0
68 |
69 | label = np.array([label_map[x] for x in label])
70 | write_ply(str(out_f), [coords.astype(np.float64), colors, label[:, None].astype(np.float64)], ['x', 'y', 'z', 'red', 'green', 'blue', 'class'])
71 |
72 |
73 | print('start preprocess')
74 |
75 | path_list = []
76 | for out_path, in_path in SUBSETS.items():
77 | phase_out_path = SCANNET_OUT_PATH / out_path
78 | # phase_out_path = SCANNET_OUT_PATH
79 | phase_out_path.mkdir(parents=True, exist_ok=True)
80 | for f in (SCANNET_RAW_PATH / in_path).glob('*/*' + POINTCLOUD_FILE):
81 | path_list.append(str(f) + ',' + str(phase_out_path))
82 |
83 | pool = ProcessPoolExecutor(max_workers=30)
84 | result = list(pool.map(handle_process, path_list))
--------------------------------------------------------------------------------
/pointdc_mk/data_prepare/initialSP_prepare_S3DIS.py:
--------------------------------------------------------------------------------
1 | from pclpy import pcl
2 | import pclpy
3 | import numpy as np
4 | from scipy import stats
5 | import os
6 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
7 | from os.path import join, exists, dirname, abspath
8 | import sys, glob
9 |
10 | BASE_DIR = dirname(abspath(__file__))
11 | ROOT_DIR = dirname(BASE_DIR)
12 | sys.path.append(BASE_DIR)
13 | sys.path.append(ROOT_DIR)
14 | from lib.helper_ply import read_ply, write_ply
15 | import time
16 | import MinkowskiEngine as ME
17 | import matplotlib.pyplot as plt
18 | from concurrent.futures import ProcessPoolExecutor
19 | from pathlib import Path
20 | from tqdm import tqdm
21 |
22 | colormap = []
23 | for _ in range(1000):
24 | for k in range(12):
25 | colormap.append(plt.cm.Set3(k))
26 | for k in range(9):
27 | colormap.append(plt.cm.Set1(k))
28 | for k in range(8):
29 | colormap.append(plt.cm.Set2(k))
30 | colormap.append((0, 0, 0, 0))
31 | colormap = np.array(colormap)
32 | import argparse
33 |
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument('--input_path', type=str, default='data/S3DIS/processed', help='raw data path')
36 | parser.add_argument('--sp_path', type=str, default='data/S3DIS/initial_superpoints')
37 | args = parser.parse_args()
38 |
39 | args.input_path = join(ROOT_DIR, args.input_path)
40 | args.sp_path = join(ROOT_DIR, args.sp_path)
41 |
42 | # ignore_label = 12
43 | voxel_size = 0.05
44 | vis = True
45 |
46 | def supervoxel_clustering(coords, rgb=None):
47 | pc = pcl.PointCloud.PointXYZRGBA(coords, rgb)
48 | normals = pc.compute_normals(radius=3, num_threads=2)
49 | vox = pcl.segmentation.SupervoxelClustering.PointXYZRGBA(voxel_resolution=1, seed_resolution=10)
50 | vox.setInputCloud(pc)
51 | vox.setNormalCloud(normals)
52 | vox.setSpatialImportance(0.4)
53 | vox.setNormalImportance(1)
54 | vox.setColorImportance(0.2)
55 | output = pcl.vectors.map_uint32t_PointXYZRGBA()
56 | vox.extract(output)
57 | return list(output.items())
58 |
59 | def region_growing_simple(coords):
60 | pc = pcl.PointCloud.PointXYZ(coords)
61 | normals = pc.compute_normals(radius=3, num_threads=2)
62 | clusters = pclpy.region_growing(pc, normals=normals, min_size=1, max_size=100000, n_neighbours=15,
63 | smooth_threshold=3, curvature_threshold=1, residual_threshold=1)
64 | return clusters, normals.normals
65 |
66 |
67 | def construct_superpoints(path):
68 | f = Path(path)
69 | data = read_ply(f)
70 | coords = np.vstack((data['x'], data['y'], data['z'])).T.copy()
71 | feats = np.vstack((data['red'], data['green'], data['blue'])).T.copy()
72 | labels = data['class'].copy()
73 | coords = coords.astype(np.float32)
74 | coords -= coords.mean(0)
75 |
76 | time_start = time.time()
77 | '''Voxelize'''
78 | scale = 1 / voxel_size
79 | coords = np.floor(coords * scale)
80 | coords, feats, labels, unique_map, inverse_map = ME.utils.sparse_quantize(np.ascontiguousarray(coords),
81 | feats, labels=labels, ignore_label=-1, return_index=True, return_inverse=True)
82 | coords = coords.numpy().astype(np.float32)
83 |
84 | '''VCCS'''
85 | out = supervoxel_clustering(coords, feats)
86 | voxel_idx = -np.ones_like(labels)
87 | voxel_num = 0
88 | for voxel in range(len(out)):
89 | if out[voxel][1].voxels_.xyz.shape[0] >= 0:
90 | for xyz_voxel in out[voxel][1].voxels_.xyz:
91 | index_colum = np.where((xyz_voxel == coords).all(1))
92 | voxel_idx[index_colum] = voxel_num
93 | voxel_num += 1
94 |
95 | '''Region Growing'''
96 | clusters = region_growing_simple(coords)[0]
97 | region_idx = -1 * np.ones_like(labels)
98 | for region in range(len(clusters)):
99 | for point_idx in clusters[region].indices:
100 | region_idx[point_idx] = region
101 |
102 | '''Merging'''
103 | merged = -np.ones_like(labels)
104 | voxel_idx[voxel_idx != -1] += len(clusters)
105 | for v in np.unique(voxel_idx):
106 | if v != -1:
107 | voxel_mask = v == voxel_idx
108 | voxel2region = region_idx[voxel_mask] ### count which regions are appeared in current voxel
109 | dominant_region = stats.mode(voxel2region)[0][0]
110 | if (dominant_region == voxel2region).sum() > voxel2region.shape[0] * 0.5:
111 | merged[voxel_mask] = dominant_region
112 | else:
113 | merged[voxel_mask] = v
114 |
115 | '''Make Superpoint Labels Continuous'''
116 | sp_labels = -np.ones_like(merged)
117 | count_num = 0
118 | for m in np.unique(merged):
119 | if m != -1:
120 | sp_labels[merged == m] = count_num
121 | count_num += 1
122 |
123 | '''ReProject to Input Point Cloud'''
124 | out_sp_labels = sp_labels[inverse_map]
125 | out_coords = np.vstack((data['x'], data['y'], data['z'])).T
126 | out_labels = data['class'].squeeze()
127 | #
128 | if not exists(args.sp_path):
129 | os.makedirs(args.sp_path)
130 | np.save(args.sp_path + '/' + f.name[:-4] + '_superpoint.npy', out_sp_labels)
131 |
132 | if vis:
133 | vis_path = args.sp_path +'/vis/'
134 | if not os.path.exists(vis_path):
135 | os.makedirs(vis_path)
136 | colors = np.zeros_like(out_coords)
137 | for p in range(colors.shape[0]):
138 | colors[p] = 255 * (colormap[out_sp_labels[p].astype(np.int32)])[:3]
139 | colors = colors.astype(np.uint8)
140 | write_ply(vis_path + '/' + f.name, [out_coords, colors], ['x', 'y', 'z', 'red', 'green', 'blue'])
141 |
142 | sp2gt = -np.ones_like(out_labels)
143 | for sp in np.unique(out_sp_labels):
144 | if sp != -1:
145 | sp_mask = sp == out_sp_labels
146 | sp2gt[sp_mask] = stats.mode(out_labels[sp_mask])[0][0]
147 |
148 | print('completed scene: {}, used time: {:.2f}s'.format(f.name, time.time() - time_start))
149 | return (out_labels, sp2gt)
150 |
151 |
152 |
153 | print('start constructing initial superpoints')
154 | path_list = []
155 | folders = sorted(glob.glob(args.input_path + '/*.ply'))
156 | for _, file in enumerate(folders):
157 | path_list.append(file)
158 |
159 | # construct_superpoints(path_list[0])  # debug: process a single scene
160 | pool = ProcessPoolExecutor(max_workers=40)
161 | result = list(pool.map(construct_superpoints, path_list))
162 | print('end constructing initial superpoints')
163 |
--------------------------------------------------------------------------------
/pointdc_mk/data_prepare/initialSP_prepare_ScanNet.py:
--------------------------------------------------------------------------------
1 | from pclpy import pcl
2 | import pclpy
3 | import numpy as np
4 | from scipy import stats
5 | from os.path import join, exists, dirname, abspath
6 | import sys, glob
7 | import json
8 |
9 | BASE_DIR = dirname(abspath(__file__))
10 | ROOT_DIR = dirname(BASE_DIR)
11 | sys.path.append(BASE_DIR)
12 | sys.path.append(ROOT_DIR)
13 | from lib.helper_ply import read_ply, write_ply
14 | import os
15 | import matplotlib.pyplot as plt
16 | from concurrent.futures import ProcessPoolExecutor
17 | from pathlib import Path
18 |
19 | trainval_file = [line.rstrip() for line in open(join(BASE_DIR, 'ScanNet_splits/scannetv2_trainval.txt'))]
20 |
21 | colormap = []
22 | for _ in range(1000):
23 | for k in range(12):
24 | colormap.append(plt.cm.Set3(k))
25 | for k in range(9):
26 | colormap.append(plt.cm.Set1(k))
27 | for k in range(8):
28 | colormap.append(plt.cm.Set2(k))
29 | colormap.append((0, 0, 0, 0))
30 | colormap = np.array(colormap)
31 |
32 | import argparse
33 |
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument('--input_path', type=str, default='data/ScanNet/scans/', help='raw data path') # path to the *.segs.json files
36 | parser.add_argument('--sp_path', type=str, default='data/ScanNet/initial_superpoints/') # output path for the superpoint files
37 | parser.add_argument('--pc_path', type=str, default='data/ScanNet/train/') # path to the processed point clouds
38 | args = parser.parse_args()
39 |
40 | args.input_path = join(ROOT_DIR, args.input_path)
41 | args.sp_path = join(ROOT_DIR, args.sp_path)
42 | args.pc_path = join(ROOT_DIR, args.pc_path)
43 |
44 | vis = True
45 |
46 | def read_superpoints(path):
47 | # read the JSON file
48 | with open(path, 'r', encoding='utf-8') as f:
49 | json_data = json.load(f)
50 | # read the superpoint ids and remap them to contiguous values
51 | ori_sp = json_data['segIndices']
52 | unique_vals = sorted(np.unique(ori_sp))
53 | sp_labels = np.searchsorted(unique_vals, ori_sp)
54 | # save the superpoint file
55 | if not os.path.exists(args.sp_path):
56 | os.makedirs(args.sp_path)
57 | np.save(args.sp_path + path.split('/')[-2] + '_superpoint.npy', sp_labels)
58 | # visualization output
59 | if vis:
60 | vis_path = args.sp_path+'/vis/'
61 | if not os.path.exists(vis_path):
62 | os.makedirs(vis_path)
63 | pc = Path(join(args.pc_path, path.split('/')[-2]+'.ply')) # path to the processed point cloud file
64 | data = read_ply(pc)
65 | coords = np.vstack((data['x'], data['y'], data['z'])).T.copy().astype(np.float32)
66 | colors = np.zeros_like(coords)
67 | for p in range(colors.shape[0]):
68 | colors[p] = 255 * (colormap[sp_labels[p].astype(np.int32)])[:3]
69 | colors = colors.astype(np.uint8)
70 | write_ply(vis_path + '/' + path.split('/')[-2], [coords, colors], ['x', 'y', 'z', 'red', 'green', 'blue'])
71 |
72 | print('start constructing initial superpoints')
73 | folders = sorted(glob.glob(args.input_path + '*/*.segs.json'))
74 | # read_superpoints(folders[0])
75 | pool = ProcessPoolExecutor(max_workers=40)
76 | result = list(pool.map(read_superpoints, folders))
--------------------------------------------------------------------------------
/pointdc_mk/data_prepare/semantic-kitti.yaml:
--------------------------------------------------------------------------------
1 | # This file is covered by the LICENSE file in the root of this project.
2 | labels:
3 | 0 : "unlabeled"
4 | 1 : "outlier"
5 | 10: "car"
6 | 11: "bicycle"
7 | 13: "bus"
8 | 15: "motorcycle"
9 | 16: "on-rails"
10 | 18: "truck"
11 | 20: "other-vehicle"
12 | 30: "person"
13 | 31: "bicyclist"
14 | 32: "motorcyclist"
15 | 40: "road"
16 | 44: "parking"
17 | 48: "sidewalk"
18 | 49: "other-ground"
19 | 50: "building"
20 | 51: "fence"
21 | 52: "other-structure"
22 | 60: "lane-marking"
23 | 70: "vegetation"
24 | 71: "trunk"
25 | 72: "terrain"
26 | 80: "pole"
27 | 81: "traffic-sign"
28 | 99: "other-object"
29 | 252: "moving-car"
30 | 253: "moving-bicyclist"
31 | 254: "moving-person"
32 | 255: "moving-motorcyclist"
33 | 256: "moving-on-rails"
34 | 257: "moving-bus"
35 | 258: "moving-truck"
36 | 259: "moving-other-vehicle"
37 | color_map: # bgr
38 | 0 : [0, 0, 0]
39 | 1 : [0, 0, 255]
40 | 10: [245, 150, 100]
41 | 11: [245, 230, 100]
42 | 13: [250, 80, 100]
43 | 15: [150, 60, 30]
44 | 16: [255, 0, 0]
45 | 18: [180, 30, 80]
46 | 20: [255, 0, 0]
47 | 30: [30, 30, 255]
48 | 31: [200, 40, 255]
49 | 32: [90, 30, 150]
50 | 40: [255, 0, 255]
51 | 44: [255, 150, 255]
52 | 48: [75, 0, 75]
53 | 49: [75, 0, 175]
54 | 50: [0, 200, 255]
55 | 51: [50, 120, 255]
56 | 52: [0, 150, 255]
57 | 60: [170, 255, 150]
58 | 70: [0, 175, 0]
59 | 71: [0, 60, 135]
60 | 72: [80, 240, 150]
61 | 80: [150, 240, 255]
62 | 81: [0, 0, 255]
63 | 99: [255, 255, 50]
64 | 252: [245, 150, 100]
65 | 256: [255, 0, 0]
66 | 253: [200, 40, 255]
67 | 254: [30, 30, 255]
68 | 255: [90, 30, 150]
69 | 257: [250, 80, 100]
70 | 258: [180, 30, 80]
71 | 259: [255, 0, 0]
72 | content: # as a ratio with the total number of points
73 | 0: 0.018889854628292943
74 | 1: 0.0002937197336781505
75 | 10: 0.040818519255974316
76 | 11: 0.00016609538710764618
77 | 13: 2.7879693665067774e-05
78 | 15: 0.00039838616015114444
79 | 16: 0.0
80 | 18: 0.0020633612104619787
81 | 20: 0.0016218197275284021
82 | 30: 0.00017698551338515307
83 | 31: 1.1065903904919655e-08
84 | 32: 5.532951952459828e-09
85 | 40: 0.1987493871255525
86 | 44: 0.014717169549888214
87 | 48: 0.14392298360372
88 | 49: 0.0039048553037472045
89 | 50: 0.1326861944777486
90 | 51: 0.0723592229456223
91 | 52: 0.002395131480328884
92 | 60: 4.7084144280367186e-05
93 | 70: 0.26681502148037506
94 | 71: 0.006035012012626033
95 | 72: 0.07814222006271769
96 | 80: 0.002855498193863172
97 | 81: 0.0006155958086189918
98 | 99: 0.009923127583046915
99 | 252: 0.001789309418528068
100 | 253: 0.00012709999297008662
101 | 254: 0.00016059776092534436
102 | 255: 3.745553104802113e-05
103 | 256: 0.0
104 | 257: 0.00011351574470342043
105 | 258: 0.00010157861367183268
106 | 259: 4.3840131989471124e-05
107 | # classes that are indistinguishable from single scan or inconsistent in
108 | # ground truth are mapped to their closest equivalent
109 | learning_map:
110 | 0 : 0 # "unlabeled"
111 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped
112 | 10: 1 # "car"
113 | 11: 2 # "bicycle"
114 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped
115 | 15: 3 # "motorcycle"
116 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped
117 | 18: 4 # "truck"
118 | 20: 5 # "other-vehicle"
119 | 30: 6 # "person"
120 | 31: 7 # "bicyclist"
121 | 32: 8 # "motorcyclist"
122 | 40: 9 # "road"
123 | 44: 10 # "parking"
124 | 48: 11 # "sidewalk"
125 | 49: 12 # "other-ground"
126 | 50: 13 # "building"
127 | 51: 14 # "fence"
128 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped
129 | 60: 9 # "lane-marking" to "road" ---------------------------------mapped
130 | 70: 15 # "vegetation"
131 | 71: 16 # "trunk"
132 | 72: 17 # "terrain"
133 | 80: 18 # "pole"
134 | 81: 19 # "traffic-sign"
135 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped
136 | 252: 1 # "moving-car" to "car" ------------------------------------mapped
137 | 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped
138 | 254: 6 # "moving-person" to "person" ------------------------------mapped
139 | 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped
140 | 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped
141 | 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped
142 | 258: 4 # "moving-truck" to "truck" --------------------------------mapped
143 | 259: 5 # "moving-other"-vehicle to "other-vehicle" ----------------mapped
144 | learning_map_inv: # inverse of previous map
145 | 0: 0 # "unlabeled", and others ignored
146 | 1: 10 # "car"
147 | 2: 11 # "bicycle"
148 | 3: 15 # "motorcycle"
149 | 4: 18 # "truck"
150 | 5: 20 # "other-vehicle"
151 | 6: 30 # "person"
152 | 7: 31 # "bicyclist"
153 | 8: 32 # "motorcyclist"
154 | 9: 40 # "road"
155 | 10: 44 # "parking"
156 | 11: 48 # "sidewalk"
157 | 12: 49 # "other-ground"
158 | 13: 50 # "building"
159 | 14: 51 # "fence"
160 | 15: 70 # "vegetation"
161 | 16: 71 # "trunk"
162 | 17: 72 # "terrain"
163 | 18: 80 # "pole"
164 | 19: 81 # "traffic-sign"
165 | learning_ignore: # Ignore classes
166 | 0: True # "unlabeled", and others ignored
167 | 1: False # "car"
168 | 2: False # "bicycle"
169 | 3: False # "motorcycle"
170 | 4: False # "truck"
171 | 5: False # "other-vehicle"
172 | 6: False # "person"
173 | 7: False # "bicyclist"
174 | 8: False # "motorcyclist"
175 | 9: False # "road"
176 | 10: False # "parking"
177 | 11: False # "sidewalk"
178 | 12: False # "other-ground"
179 | 13: False # "building"
180 | 14: False # "fence"
181 | 15: False # "vegetation"
182 | 16: False # "trunk"
183 | 17: False # "terrain"
184 | 18: False # "pole"
185 | 19: False # "traffic-sign"
186 | split: # sequence numbers
187 | train:
188 | - 0
189 | - 1
190 | - 2
191 | - 3
192 | - 4
193 | - 5
194 | - 6
195 | - 7
196 | - 9
197 | - 10
198 | valid:
199 | - 8
200 | test:
201 | - 11
202 | - 12
203 | - 13
204 | - 14
205 | - 15
206 | - 16
207 | - 17
208 | - 18
209 | - 19
210 | - 20
211 | - 21
212 |
--------------------------------------------------------------------------------
/pointdc_mk/datasets/SemanticKITTI.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from lib.helper_ply import read_ply, write_ply
4 | from torch.utils.data import Dataset
5 | import MinkowskiEngine as ME
6 | import random
7 | import os
8 | import open3d as o3d
9 | from lib.aug_tools import rota_coords, scale_coords, trans_coords
10 |
11 | class cfl_collate_fn:
12 |
13 | def __call__(self, list_data):
14 | coords, feats, normals, labels, inverse_map, pseudo, inds, region, index = list(zip(*list_data))
15 | coords_batch, feats_batch, normal_batch, labels_batch, inverse_batch, pseudo_batch, inds_batch = [], [], [], [], [], [], []
16 | region_batch = []
17 | accm_num = 0
18 | for batch_id, _ in enumerate(coords):
19 | num_points = coords[batch_id].shape[0]
20 | coords_batch.append(torch.cat((torch.ones(num_points, 1).int() * batch_id, torch.from_numpy(coords[batch_id]).int()), 1))
21 | feats_batch.append(torch.from_numpy(feats[batch_id]))
22 | normal_batch.append(torch.from_numpy(normals[batch_id]))
23 | labels_batch.append(torch.from_numpy(labels[batch_id]).int())
24 | inverse_batch.append(torch.from_numpy(inverse_map[batch_id]))
25 | pseudo_batch.append(torch.from_numpy(pseudo[batch_id]))
26 | inds_batch.append(torch.from_numpy(inds[batch_id] + accm_num).int())
27 | region_batch.append(torch.from_numpy(region[batch_id])[:,None])
28 | accm_num += coords[batch_id].shape[0]
29 |
30 | # Concatenate all lists
31 | coords_batch = torch.cat(coords_batch, 0).float()#.int()
32 | feats_batch = torch.cat(feats_batch, 0).float()
33 | normal_batch = torch.cat(normal_batch, 0).float()
34 | labels_batch = torch.cat(labels_batch, 0).float()
35 | inverse_batch = torch.cat(inverse_batch, 0).int()
36 | pseudo_batch = torch.cat(pseudo_batch, -1)
37 | inds_batch = torch.cat(inds_batch, 0)
38 | region_batch = torch.cat(region_batch, 0)
39 |
40 | return coords_batch, feats_batch, normal_batch, labels_batch, inverse_batch, pseudo_batch, inds_batch, region_batch, index
41 |
42 |
43 |
44 |
45 | class KITTItrain(Dataset):
46 | def __init__(self, args, scene_idx, split='train'):
47 | self.args = args
48 | self.label_to_names = {0: 'unlabeled',
49 | 1: 'car',
50 | 2: 'bicycle',
51 | 3: 'motorcycle',
52 | 4: 'truck',
53 | 5: 'other-vehicle',
54 | 6: 'person',
55 | 7: 'bicyclist',
56 | 8: 'motorcyclist',
57 | 9: 'road',
58 | 10: 'parking',
59 | 11: 'sidewalk',
60 | 12: 'other-ground',
61 | 13: 'building',
62 | 14: 'fence',
63 | 15: 'vegetation',
64 | 16: 'trunk',
65 | 17: 'terrain',
66 | 18: 'pole',
67 | 19: 'traffic-sign'}
68 | self.mode = 'train'
69 | self.split = split
70 | self.val_split = '08'
71 | self.file = []
72 |
73 | seq_list = np.sort(os.listdir(self.args.data_path))
74 | for seq_id in seq_list:
75 | seq_path = os.path.join(self.args.data_path, seq_id)
76 | if self.split == 'train':
77 | if seq_id in ['00', '01', '02', '03', '04', '05', '06', '07', '09', '10']:
78 | for f in np.sort(os.listdir(seq_path)):
79 | self.file.append(os.path.join(seq_path, f))
80 |
81 | elif self.split == 'val':
82 | if seq_id == '08':
83 | for f in np.sort(os.listdir(seq_path)):
84 | self.file.append(os.path.join(seq_path, f))
85 | scene_idx = range(len(self.file))
86 |
87 | '''Initial Augmentations'''
88 | self.trans_coords = trans_coords(shift_ratio=50) ### 50%
89 | self.rota_coords = rota_coords(rotation_bound = ((-np.pi/32, np.pi/32), (-np.pi/32, np.pi/32), (-np.pi, np.pi)))
90 | self.scale_coords = scale_coords(scale_bound=(0.9, 1.1))
91 |
92 | self.random_select_sample(scene_idx)
93 |
94 | def random_select_sample(self, scene_idx):
95 | self.name = []
96 | self.file_selected = []
97 | for i in scene_idx:
98 | self.file_selected.append(self.file[i])
99 | self.name.append(self.file[i][0:-4].replace(self.args.data_path, ''))
100 |
101 |
102 | def augs(self, coords):
103 | coords = self.rota_coords(coords)
104 | coords = self.trans_coords(coords)
105 | coords = self.scale_coords(coords)
106 | return coords
107 |
108 |
109 | def augment_coords_to_feats(self, coords, feats, labels=None):
110 | coords_center = coords.mean(0, keepdims=True)
111 | coords_center[0, 2] = 0
112 | norm_coords = (coords - coords_center)
113 | return norm_coords, feats, labels
114 |
115 | def voxelize(self, coords, feats, labels):
116 | scale = 1 / self.args.voxel_size
117 | coords = np.floor(coords * scale)
118 | coords, feats, labels, unique_map, inverse_map = ME.utils.sparse_quantize(np.ascontiguousarray(coords), feats, labels=labels, ignore_label=-1, return_index=True, return_inverse=True)
119 | return coords.numpy(), feats, labels, unique_map, inverse_map.numpy()
120 |
121 |
122 | def __len__(self):
123 | return len(self.file_selected)
124 |
125 | def __getitem__(self, index):
126 | file = self.file_selected[index]
127 | data = read_ply(file)
128 | coords = np.array([data['x'], data['y'], data['z']], dtype=np.float32).T
129 | feats = np.array(data['remission'])[:, np.newaxis]
130 | labels = np.array(data['class'])
131 | coords = coords.astype(np.float32)
132 | coords -= coords.mean(0)
133 |
134 | coords, feats, labels, unique_map, inverse_map = self.voxelize(coords, feats, labels)
135 | coords = coords.astype(np.float32)
136 |
137 | mask = np.sqrt(((coords*self.args.voxel_size)**2).sum(-1))< self.args.r_crop
138 | coords, feats, labels = coords[mask], feats[mask], labels[mask]
139 |
140 | region_file = self.args.sp_path + '/' +self.name[index] + '_superpoint.npy'
141 | region = np.load(region_file)
142 | region = region[unique_map]
143 | region = region[mask]
144 |
145 | coords = self.augs(coords)
146 |
147 | ''' Take Mixup as an Augmentation'''
148 | inds = np.arange(coords.shape[0])
149 | mix = random.randint(0, len(self.name)-1)
150 |
151 | data_mix = read_ply(self.file_selected[mix])
152 | coords_mix = np.array([data_mix['x'], data_mix['y'], data_mix['z']], dtype=np.float32).T
153 | feats_mix = np.array(data_mix['remission'])[:, np.newaxis]
154 | labels_mix = np.array(data_mix['class'])
155 | feats_mix = feats_mix.astype(np.float32)
156 | coords_mix = coords_mix.astype(np.float32)
157 | coords_mix -= coords_mix.mean(0)
158 |
159 | coords_mix, feats_mix, _, unique_map_mix, _ = self.voxelize(coords_mix, feats_mix, labels_mix)
160 | coords_mix = coords_mix.astype(np.float32)
161 |
162 | mask_mix = np.sqrt(((coords_mix * self.args.voxel_size) ** 2).sum(-1)) < self.args.r_crop
163 | coords_mix, feats_mix = coords_mix[mask_mix], feats_mix[mask_mix]
164 | #
165 | coords_mix = self.augs(coords_mix)
166 | coords = np.concatenate((coords, coords_mix), axis=0)
167 | ''' End Mixup'''
168 |
169 | coords, feats, labels = self.augment_coords_to_feats(coords, feats, labels)
170 | labels -= 1
171 |
172 | '''mode must be cluster or train'''
173 | if self.mode == 'cluster':
174 | pcd = o3d.geometry.PointCloud()
175 | pcd.points = o3d.utility.Vector3dVector(coords[inds])
176 | pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=10, max_nn=30))
177 | normals = np.array(pcd.normals)
178 |
179 | region[labels==-1] = -1
180 |
181 | for q in np.unique(region):
182 | mask = q == region
183 | if mask.sum() < self.args.drop_threshold and q != -1:
184 | region[mask] = -1
185 |
186 | valid_region = region[region != -1]
187 | unique_vals = np.unique(valid_region)
188 | unique_vals.sort()
189 | valid_region = np.searchsorted(unique_vals, valid_region)
190 |
191 | region[region != -1] = valid_region
192 |
193 |             pseudo = -np.ones_like(labels).astype(np.int64)  # np.long is a deprecated alias of np.int64
194 |
195 | else:
196 | normals = np.zeros_like(coords)
197 | scene_name = self.name[index]
198 | file_path = self.args.pseudo_label_path + '/' + scene_name + '.npy'
199 |             pseudo = np.array(np.load(file_path), dtype=np.int64)  # np.long is deprecated
200 |
201 |
202 | return coords, feats, normals, labels, inverse_map, pseudo, inds, region, index
203 |
204 |
205 |
206 | class KITTIval(Dataset):
207 | def __init__(self, args, split='val'):
208 | self.args = args
209 | self.label_to_names = {0: 'unlabeled',
210 | 1: 'car',
211 | 2: 'bicycle',
212 | 3: 'motorcycle',
213 | 4: 'truck',
214 | 5: 'other-vehicle',
215 | 6: 'person',
216 | 7: 'bicyclist',
217 | 8: 'motorcyclist',
218 | 9: 'road',
219 | 10: 'parking',
220 | 11: 'sidewalk',
221 | 12: 'other-ground',
222 | 13: 'building',
223 | 14: 'fence',
224 | 15: 'vegetation',
225 | 16: 'trunk',
226 | 17: 'terrain',
227 | 18: 'pole',
228 | 19: 'traffic-sign'}
229 | self.name = []
230 | self.mode = 'val'
231 | self.split = split
232 | self.val_split = '08'
233 | self.file = []
234 |
235 | seq_list = np.sort(os.listdir(self.args.data_path))
236 | for seq_id in seq_list:
237 | seq_path = os.path.join(self.args.data_path, seq_id)
238 | if self.split == 'val':
239 | if seq_id == '08':
240 | for f in np.sort(os.listdir(seq_path)):
241 | self.file.append(os.path.join(seq_path, f))
242 | self.name.append(os.path.join(seq_path, f)[0:-4].replace(self.args.data_path, ''))
243 |
244 |
245 | def augment_coords_to_feats(self, coords, feats, labels=None):
246 | coords_center = coords.mean(0, keepdims=True)
247 | coords_center[0, 2] = 0
248 | norm_coords = (coords - coords_center)
249 | return norm_coords, feats, labels
250 |
251 | def voxelize(self, coords, feats, labels):
252 | scale = 1 / self.args.voxel_size
253 | coords = np.floor(coords * scale)
254 | coords, feats, labels, unique_map, inverse_map = ME.utils.sparse_quantize(np.ascontiguousarray(coords), feats, labels=labels, ignore_label=-1, return_index=True, return_inverse=True)
255 | return coords.numpy(), feats, labels, unique_map, inverse_map.numpy()
256 |
257 |
258 | def __len__(self):
259 | return len(self.file)
260 |
261 | def __getitem__(self, index):
262 | file = self.file[index]
263 | data = read_ply(file)
264 | coords = np.array([data['x'], data['y'], data['z']], dtype=np.float32).T
265 | feats = np.array(data['remission'])[:, np.newaxis]
266 | labels = np.array(data['class'])
267 | coords = coords.astype(np.float32)
268 | coords -= coords.mean(0)
269 |
270 | coords, feats, _, unique_map, inverse_map = self.voxelize(coords, feats, labels)
271 | coords = coords.astype(np.float32)
272 |
273 |         region_file = self.args.sp_path + '/' + self.name[index] + '_superpoint.npy'
274 | region = np.load(region_file)
275 | region = region[unique_map]
276 |
277 | coords, feats, labels = self.augment_coords_to_feats(coords, feats, labels)
278 |         labels = labels - 1
279 |
280 | return coords, feats, np.ascontiguousarray(labels), inverse_map, region, index
281 |
282 |
283 | class cfl_collate_fn_val:
284 |
285 | def __call__(self, list_data):
286 | coords, feats, labels, inverse_map, region, index = list(zip(*list_data))
287 | coords_batch, feats_batch, inverse_batch, labels_batch = [], [], [], []
288 | region_batch = []
289 | for batch_id, _ in enumerate(coords):
290 | num_points = coords[batch_id].shape[0]
291 | coords_batch.append(
292 | torch.cat((torch.ones(num_points, 1).int() * batch_id, torch.from_numpy(coords[batch_id]).int()), 1))
293 | feats_batch.append(torch.from_numpy(feats[batch_id]))
294 | inverse_batch.append(torch.from_numpy(inverse_map[batch_id]))
295 | labels_batch.append(torch.from_numpy(labels[batch_id]).int())
296 | region_batch.append(torch.from_numpy(region[batch_id])[:, None])
297 | #
298 | # Concatenate all lists
299 | coords_batch = torch.cat(coords_batch, 0).float()
300 | feats_batch = torch.cat(feats_batch, 0).float()
301 | inverse_batch = torch.cat(inverse_batch, 0).int()
302 | labels_batch = torch.cat(labels_batch, 0).int()
303 | region_batch = torch.cat(region_batch, 0)
304 |
305 | return coords_batch, feats_batch, inverse_batch, labels_batch, index, region_batch
306 |
--------------------------------------------------------------------------------
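
A minimal sketch of the voxelize round trip used by the KITTI loaders above, on synthetic inputs and assuming MinkowskiEngine 0.5.x semantics: `unique_map` keeps one representative point per occupied voxel, and `inverse_map` broadcasts voxel-level results back to every original point.

```python
import numpy as np
import MinkowskiEngine as ME

voxel_size = 0.05
coords = (np.random.rand(1000, 3) * 10).astype(np.float32)  # synthetic cloud
feats = np.random.rand(1000, 1).astype(np.float32)          # e.g. remission
labels = np.random.randint(0, 19, size=1000)

# Quantize to the voxel grid: one representative per occupied voxel;
# voxels containing conflicting labels receive ignore_label.
disc = np.floor(coords / voxel_size)
v_coords, v_feats, v_labels, unique_map, inverse_map = ME.utils.sparse_quantize(
    np.ascontiguousarray(disc), feats, labels=labels, ignore_label=-1,
    return_index=True, return_inverse=True)

# Round trip: voxel-level labels broadcast back to all input points.
point_labels = np.asarray(v_labels)[np.asarray(inverse_map)]
assert point_labels.shape[0] == coords.shape[0]
```
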
/pointdc_mk/env.yaml:
--------------------------------------------------------------------------------
1 | name: pointdc_mk
2 | channels:
3 | - davidcaron
4 | - pytorch/label/nightly
5 | - pytorch
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - _libgcc_mutex=0.1=conda_forge
10 | - _openmp_mutex=4.5=2_kmp_llvm
11 | - appdirs=1.4.4=pyhd3eb1b0_0
12 | - blas=1.0=mkl
13 | - boost-cpp=1.72.0=he72f1d9_7
14 | - brotlipy=0.7.0=py38h27cfd23_1003
15 | - bzip2=1.0.8=h7f98852_4
16 | - c-ares=1.19.1=hd590300_0
17 | - ca-certificates=2023.05.30=h06a4308_0
18 | - cffi=1.15.1=py38h5eee18b_3
19 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
20 | - cryptography=41.0.2=py38h22a60cf_0
21 | - cudatoolkit=11.3.1=hb98b00a_12
22 | - eigen=3.4.1=h00ab1b0_0
23 | - ffmpeg=4.3=hf484d3e_0
24 | - flann=1.9.1=h941a29b_1013
25 | - freetype=2.12.1=hca18f0e_1
26 | - gmp=6.2.1=h58526e2_0
27 | - gnutls=3.6.13=h85f3911_1
28 | - hdf5=1.14.1=nompi_h4f84152_100
29 | - icu=70.1=h27087fc_0
30 | - idna=3.4=py38h06a4308_0
31 | - intel-openmp=2021.4.0=h06a4308_3561
32 | - jpeg=9e=h0b41bf4_3
33 | - keyutils=1.6.1=h166bdaf_0
34 | - krb5=1.21.1=h659d440_0
35 | - lame=3.100=h166bdaf_1003
36 | - laspy=2.3.0=pyha21a80b_0
37 | - lcms2=2.15=hfd0df8a_0
38 | - ld_impl_linux-64=2.40=h41732ed_0
39 | - lerc=4.0.0=h27087fc_0
40 | - libaec=1.0.6=hcb278e6_1
41 | - libcurl=8.2.0=hca28451_0
42 | - libdeflate=1.17=h0b41bf4_0
43 | - libedit=3.1.20191231=he28a2e2_2
44 | - libev=4.33=h516909a_1
45 | - libfaiss=1.7.3=hfc2d529_97_cuda11.3_nightly
46 | - libffi=3.4.2=h7f98852_5
47 | - libgcc-ng=13.1.0=he5830b7_0
48 | - libgfortran-ng=13.1.0=h69a702a_0
49 | - libgfortran5=13.1.0=h15d22d2_0
50 | - libhwloc=2.9.1=hd6dc26d_0
51 | - libiconv=1.17=h166bdaf_0
52 | - libnghttp2=1.52.0=h61bc06f_0
53 | - libnsl=2.0.0=h7f98852_0
54 | - libpng=1.6.39=h753d276_0
55 | - libsqlite=3.42.0=h2797004_0
56 | - libssh2=1.11.0=h0841786_0
57 | - libstdcxx-ng=13.1.0=hfd8a6a1_0
58 | - libtiff=4.5.0=h6adf6a1_2
59 | - libuv=1.44.2=h166bdaf_0
60 | - libwebp-base=1.3.1=hd590300_0
61 | - libxcb=1.13=h7f98852_1004
62 | - libxml2=2.10.3=hca2bb57_4
63 | - libzlib=1.2.13=hd590300_5
64 | - llvm-openmp=16.0.6=h4dfa4b3_0
65 | - mkl=2021.4.0=h06a4308_640
66 | - mkl-include=2022.1.0=h84fe81f_915
67 | - mkl-service=2.4.0=py38h7f8727e_0
68 | - mkl_fft=1.3.1=py38hd3c417c_0
69 | - mkl_random=1.2.2=py38h51133e4_0
70 | - ncurses=6.4=hcb278e6_0
71 | - nettle=3.6=he412f7d_0
72 | - numpy=1.21.2=py38h20f2e39_0
73 | - numpy-base=1.21.2=py38h79a1101_0
74 | - openh264=2.1.1=h780b84a_0
75 | - openjpeg=2.5.0=hfec8fc6_2
76 | - openssl=3.1.1=hd590300_1
77 | - pcl=1.9.1=h2dfa329_1005
78 | - pclpy=0.12.0=py38_1
79 | - pillow=9.4.0=py38hde6dc18_1
80 | - pooch=1.4.0=pyhd3eb1b0_0
81 | - pthread-stubs=0.4=h36c2ea0_1001
82 | - pycparser=2.21=pyhd3eb1b0_0
83 | - pyopenssl=23.2.0=py38h06a4308_0
84 | - pysocks=1.7.1=py38h06a4308_0
85 | - python=3.8.12=h0744224_3_cpython
86 | - python_abi=3.8=3_cp38
87 | - pytorch=1.10.2=py3.8_cuda11.3_cudnn8.2.0_0
88 | - pytorch-mutex=1.0=cuda
89 | - qhull=2015.2=hc9558a2_1001
90 | - readline=8.2=h8228510_1
91 | - requests=2.31.0=py38h06a4308_0
92 | - setuptools=68.0.0=pyhd8ed1ab_0
93 | - six=1.16.0=pyhd3eb1b0_1
94 | - sqlite=3.42.0=h2c6b66d_0
95 | - tbb=2021.9.0=hf52228f_0
96 | - threadpoolctl=2.2.0=pyh0d69192_0
97 | - tk=8.6.12=h27826a3_0
98 | - torchvision=0.11.3=py38_cu113
99 | - typing_extensions=4.7.1=pyha770c72_0
100 | - wheel=0.40.0=pyhd8ed1ab_1
101 | - xorg-libxau=1.0.11=hd590300_0
102 | - xorg-libxdmcp=1.1.3=h7f98852_0
103 | - xz=5.2.6=h166bdaf_0
104 | - zlib=1.2.13=hd590300_5
105 | - zstd=1.5.2=hfc55251_7
106 | - pip:
107 | - anyio==3.7.1
108 | - argon2-cffi==21.3.0
109 | - argon2-cffi-bindings==21.2.0
110 | - arrow==1.2.3
111 | - asttokens==2.2.1
112 | - async-lru==2.0.3
113 | - attrs==23.1.0
114 | - babel==2.12.1
115 | - backcall==0.2.0
116 | - beautifulsoup4==4.12.2
117 | - bleach==6.0.0
118 | - blessed==1.20.0
119 | - cachetools==5.3.1
120 | - certifi==2023.5.7
121 | - clip==1.0
122 | - comm==0.1.3
123 | - cycler==0.11.0
124 | - debugpy==1.6.7
125 | - decorator==5.1.1
126 | - defusedxml==0.7.1
127 | - exceptiongroup==1.1.2
128 | - executing==1.2.0
129 | - faiss-gpu==1.7.2
130 | - fastjsonschema==2.17.1
131 | - fonttools==4.41.0
132 | - fqdn==1.5.1
133 | - ftfy==6.1.1
134 | - gpustat==1.1
135 | - importlib-metadata==6.8.0
136 | - importlib-resources==6.0.0
137 | - ipykernel==6.24.0
138 | - ipython==8.12.2
139 | - ipywidgets==8.0.7
140 | - isoduration==20.11.0
141 | - jedi==0.18.2
142 | - jinja2==3.1.2
143 | - joblib==1.3.1
144 | - json5==0.9.14
145 | - jsonpointer==2.4
146 | - jsonschema==4.18.4
147 | - jsonschema-specifications==2023.7.1
148 | - jupyter-client==8.3.0
149 | - jupyter-core==5.3.1
150 | - jupyter-events==0.6.3
151 | - jupyter-lsp==2.2.0
152 | - jupyter-server==2.7.0
153 | - jupyter-server-terminals==0.4.4
154 | - jupyterlab==4.0.3
155 | - jupyterlab-pygments==0.2.2
156 | - jupyterlab-server==2.23.0
157 | - jupyterlab-widgets==3.0.8
158 | - kiwisolver==1.4.4
159 | - markupsafe==2.1.3
160 | - matplotlib==3.5.1
161 | - matplotlib-inline==0.1.6
162 | - minkowskiengine==0.5.4
163 | - mistune==3.0.1
164 | - nbclient==0.8.0
165 | - nbconvert==7.7.2
166 | - nbformat==5.9.1
167 | - nest-asyncio==1.5.6
168 | - ninja==1.10.2.3
169 | - notebook==7.0.0
170 | - notebook-shim==0.2.3
171 | - nvidia-ml-py==12.535.77
172 | - nvitop==1.2.0
173 | - open3d==0.10.0.0
174 | - opencv-python==4.8.0.74
175 | - overrides==7.3.1
176 | - packaging==23.1
177 | - pandas==2.0.3
178 | - pandocfilters==1.5.0
179 | - parso==0.8.3
180 | - pexpect==4.8.0
181 | - pickleshare==0.7.5
182 | - pip==23.2.1
183 | - pkgutil-resolve-name==1.3.10
184 | - platformdirs==3.9.1
185 | - prometheus-client==0.17.1
186 | - prompt-toolkit==3.0.39
187 | - protobuf==4.24.4
188 | - psutil==5.9.5
189 | - ptyprocess==0.7.0
190 | - pure-eval==0.2.2
191 | - pygments==2.15.1
192 | - pyparsing==3.1.0
193 | - python-dateutil==2.8.2
194 | - python-json-logger==2.0.7
195 | - pytz==2023.3
196 | - pyyaml==6.0.1
197 | - pyzmq==25.1.0
198 | - rama==0.0.7
199 | - referencing==0.30.0
200 | - regex==2023.10.3
201 | - rfc3339-validator==0.1.4
202 | - rfc3986-validator==0.1.1
203 | - rpds-py==0.9.2
204 | - scikit-learn==0.22.2
205 | - scipy==1.8.0
206 | - seaborn==0.11.2
207 | - send2trash==1.8.2
208 | - sniffio==1.3.0
209 | - soupsieve==2.4.1
210 | - stack-data==0.6.2
211 | - tensorboardx==2.6.2.2
212 | - termcolor==2.3.0
213 | - terminado==0.17.1
214 | - tinycss2==1.2.1
215 | - tomli==2.0.1
216 | - torch-scatter==2.0.9
217 | - tornado==6.3.2
218 | - tqdm==4.65.0
219 | - traitlets==5.9.0
220 | - tzdata==2023.3
221 | - uri-template==1.3.0
222 | - urllib3==2.0.4
223 | - wcwidth==0.2.6
224 | - webcolors==1.13
225 | - webencodings==0.5.1
226 | - websocket-client==1.6.1
227 | - widgetsnbextension==4.0.8
228 | - zipp==3.16.2
--------------------------------------------------------------------------------
/pointdc_mk/eval_S3DIS.py:
--------------------------------------------------------------------------------
1 | import torch, os, argparse, faiss
2 | import torch.nn.functional as F
3 | from datasets.S3DIS import S3DIStest, S3DIStrain, cfl_collate_fn_test, cfl_collate_fn
4 | import numpy as np
5 | import MinkowskiEngine as ME
6 | from torch.utils.data import DataLoader
7 | from sklearn.utils.linear_assignment_ import linear_assignment # pip install scikit-learn==0.22.2
8 | from sklearn.cluster import KMeans
9 | from models.fpn import Res16FPN18
10 | from lib.utils_s3dis import *
11 | from tqdm import tqdm
12 | from os.path import join
13 | from datetime import datetime
14 | from sklearn.cluster._kmeans import k_means
15 | from models.pretrain_models import SubModel
16 | ###
17 | def parse_args():
18 | '''PARAMETERS'''
19 | parser = argparse.ArgumentParser(description='PyTorch Unsuper_3D_Seg')
20 |     parser.add_argument('--data_path', type=str, default='data/S3DIS/', help='point cloud data path')
21 | parser.add_argument('--sp_path', type=str, default= 'data/S3DIS/', help='initial sp path')
22 | parser.add_argument('--expname', type=str, default= 'zdefalut', help='expname for logger')
23 | ###
24 | parser.add_argument('--save_path', type=str, default='ckpt/S3DIS/', help='model savepath')
25 | ###
26 | parser.add_argument('--bn_momentum', type=float, default=0.02, help='batchnorm parameters')
27 | parser.add_argument('--conv1_kernel_size', type=int, default=5, help='kernel size of 1st conv layers')
28 | ###
29 | parser.add_argument('--workers', type=int, default=8, help='how many workers for loading data')
30 | parser.add_argument('--seed', type=int, default=2023, help='random seed')
31 | parser.add_argument('--log-interval', type=int, default=150, help='log interval')
32 | parser.add_argument('--batch_size', type=int, default=8, help='batchsize in training')
33 | parser.add_argument('--voxel_size', type=float, default=0.05, help='voxel size in SparseConv')
34 | parser.add_argument('--input_dim', type=int, default=6, help='network input dimension')### 6 for XYZGB
35 | parser.add_argument('--primitive_num', type=int, default=13, help='how many primitives used in training')
36 | parser.add_argument('--semantic_class', type=int, default=13, help='ground truth semantic class')
37 | parser.add_argument('--feats_dim', type=int, default=128, help='output feature dimension')
38 | parser.add_argument('--ignore_label', type=int, default=-1, help='invalid label')
39 | parser.add_argument('--drop_threshold', type=int, default=50, help='mask counts')
40 |
41 | return parser.parse_args()
42 |
43 |
44 |
45 | def eval_once(args, model, test_loader, classifier, use_sp=False):
46 | model.mode = 'train'
47 | all_preds, all_label = [], []
48 | test_loader_bar = tqdm(test_loader)
49 | for data in test_loader_bar:
50 | test_loader_bar.set_description('Start eval...')
51 | with torch.no_grad():
52 | coords, features, inverse_map, labels, index, region = data
53 |
54 | in_field = ME.TensorField(features, coords, device=0)
55 | feats_nonorm = model(in_field)
56 | feats_norm = F.normalize(feats_nonorm)
57 |
58 | region = region.squeeze()
59 | if use_sp:
60 | region_inds = torch.unique(region).long()
61 | region_feats = []
62 | for id in region_inds:
63 | if id != -1:
64 | valid_mask = id == region
65 | region_feats.append(feats_norm[valid_mask].mean(0, keepdim=True))
66 | region_feats = torch.cat(region_feats, dim=0)
67 |
68 | scores = F.linear(F.normalize(feats_norm), F.normalize(classifier.weight))
69 | preds = torch.argmax(scores, dim=1).cpu()
70 |
71 | region_scores = F.linear(F.normalize(region_feats), F.normalize(classifier.weight))
72 |                 for id in region_inds: # skip id 0
73 | if id != 0:
74 | valid_mask = id == region
75 | preds[valid_mask] = torch.argmax(region_scores, dim=1).cpu()[id]
76 | else:
77 | scores = F.linear(F.normalize(feats_nonorm), F.normalize(classifier.weight))
78 | preds = torch.argmax(scores, dim=1).cpu()
79 |
80 | preds = preds[inverse_map.long()]
81 |             all_preds.append(preds[labels!=args.ignore_label]), all_label.append(labels[labels!=args.ignore_label])
82 |
83 | return all_preds, all_label
84 |
85 | def eval(epoch, args, mode='svc'):
86 | ## Model
87 | model = Res16FPN18(in_channels=args.input_dim, out_channels=args.primitive_num, conv1_kernel_size=args.conv1_kernel_size, config=args, mode='train').cuda()
88 | model.load_state_dict(torch.load(os.path.join(args.save_path, mode, 'model_' + str(epoch) + '_checkpoint.pth')))
89 | model.eval()
90 | ## Merge Cluster Centers
91 | primitive_centers = torch.load(os.path.join(args.save_path, mode, 'cls_' + str(epoch) + '_checkpoint.pth'))
92 | centroids, _, _ = k_means(primitive_centers.cpu(), n_clusters=args.semantic_class, random_state=None, n_init=20)
93 | centroids = F.normalize(torch.FloatTensor(centroids), dim=1).cuda()
94 | cls = get_fixclassifier(in_channel=args.feats_dim, centroids_num=args.semantic_class, centroids=centroids).cuda()
95 | cls.eval()
96 |
97 | val_dataset = S3DIStest(args)
98 |     val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=cfl_collate_fn_test(), num_workers=args.workers, pin_memory=True)  # parse_args defines --workers, not --cluster_workers
99 |
100 | preds, labels = eval_once(args, model, val_loader, cls, use_sp=True)
101 | all_preds = torch.cat(preds).numpy()
102 | all_labels = torch.cat(labels).numpy()
103 |
104 | o_Acc, m_Acc, s = compute_seg_results(args, all_labels, all_preds)
105 |
106 | return o_Acc, m_Acc, s
107 |
108 | def eval_by_cluster(args, epoch, mode='svc'):
109 | ## Prepare Data
110 | trainset = S3DIStrain(args, areas=['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6'])
111 | cluster_loader = DataLoader(trainset, batch_size=1, shuffle=True, collate_fn=cfl_collate_fn(), \
112 | num_workers=args.workers, pin_memory=True, worker_init_fn=worker_init_fn(args.seed))
113 | val_dataset = S3DIStest(args)
114 | val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=cfl_collate_fn_test(), num_workers=args.workers, pin_memory=True)
115 | ## Define model
116 | model = Res16FPN18(in_channels=args.input_dim, out_channels=args.primitive_num, conv1_kernel_size=args.conv1_kernel_size, config=args, mode='train').cuda()
117 | model.load_state_dict(torch.load(os.path.join(args.save_path, mode, 'model_' + str(epoch) + '_checkpoint.pth')))
118 | model.eval()
119 | ## Get sp features
120 | print('Start to get sp features...')
121 | sp_feats_list = init_get_sp_feature(args, cluster_loader, model)
122 | sp_feats = torch.cat(sp_feats_list, dim=0)
123 | print('Start to train faiss...')
124 | ## Train faiss module
125 | _, primitive_centers = faiss_cluster(args, sp_feats.cpu().numpy())
126 | ## Merge Primitive
127 | centroids, _, _ = k_means(primitive_centers.cpu(), n_clusters=args.semantic_class, random_state=None, n_init=20, n_jobs=20)
128 | centroids = F.normalize(torch.FloatTensor(centroids), dim=1).cuda()
129 | ## Get cls
130 | cls = get_fixclassifier(in_channel=args.feats_dim, centroids_num=args.semantic_class, centroids=centroids).cuda()
131 | cls.eval()
132 | ## eval
133 | preds, labels = eval_once(args, model, val_loader, cls, use_sp=True)
134 | all_preds, all_labels = torch.cat(preds).numpy(), torch.cat(labels).numpy()
135 | o_Acc, m_Acc, s = compute_seg_results(args, all_labels, all_preds)
136 | return o_Acc, m_Acc, s
137 |
138 | if __name__ == '__main__':
139 | args = parse_args()
140 | expnames = ['240524_trainall']
141 | epoches = [70, 80, 90]
142 | seeds = [12, 43, 56, 78, 90]
143 | for expname in expnames:
144 | args.save_path = 'ckpt/S3DIS'
145 | args.save_path = join(args.save_path, expname)
146 | assert os.path.exists(args.save_path), 'There is no {} !!!'.format(expname)
147 | if not os.path.exists(join('results', expname)):
148 | os.makedirs(join('results', expname))
149 | for epoch in epoches:
150 | results = []
151 | results.append('Eval exp {}\n'.format(expname))
152 | results.append("Eval time: {}\n".format(datetime.now().strftime("%Y-%m-%d %H:%M")))
153 | results_file = join('results', expname, 'eval_{}.txt'.format(str(epoch)))
154 | print('Eval {}, save file to {}'.format(expname, results_file))
155 | for seed in seeds:
156 | args.seed = seed
157 | set_seed(args.seed)
158 | o_Acc, m_Acc, s = eval_by_cluster(args, epoch, mode='svc')
159 | print('Epoch {:02d} Seed {}: oAcc {:.2f} mAcc {:.2f} IoUs'.format(epoch, seed, o_Acc, m_Acc) + s)
160 | results.append('Epoch {:02d} Seed {}: oAcc {:.2f} mAcc {:.2f} IoUs'.format(epoch, seed, o_Acc, m_Acc) + s +'\n')
161 | write_list(results_file, results)
162 | print('\n')
--------------------------------------------------------------------------------
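
Both eval scripts share one scoring path: primitive centroids are merged by k-means, frozen into a bias-free linear layer (`get_fixclassifier`), and normalized features are scored by cosine similarity. A minimal sketch with toy shapes (the 13 classes and 128-dim features here are placeholders):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def make_cosine_classifier(centroids: torch.Tensor) -> nn.Linear:
    """Frozen linear layer whose rows are L2-normalized centroids."""
    centroids = F.normalize(centroids, dim=1)
    cls = nn.Linear(centroids.shape[1], centroids.shape[0], bias=False)
    cls.weight.data = centroids
    cls.weight.requires_grad_(False)
    return cls

cls = make_cosine_classifier(torch.randn(13, 128))   # toy centroids
feats = F.normalize(torch.randn(4096, 128), dim=1)   # toy point features
scores = F.linear(feats, F.normalize(cls.weight))    # cosine similarity
preds = scores.argmax(dim=1)                         # per-point class ids
```
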
/pointdc_mk/eval_ScanNet.py:
--------------------------------------------------------------------------------
1 | import torch, os, argparse, faiss
2 | import torch.nn.functional as F
3 | from datasets.ScanNet import Scannetval, cfl_collate_fn_val
4 | import numpy as np
5 | import MinkowskiEngine as ME
6 | from torch.utils.data import DataLoader
7 | from sklearn.utils.linear_assignment_ import linear_assignment # pip install scikit-learn==0.22.2
8 | from sklearn.cluster import KMeans
9 | from models.fpn import Res16FPN18
10 | from lib.utils import get_fixclassifier, init_get_sp_feature, faiss_cluster, worker_init_fn, set_seed, compute_seg_results, write_list
11 | from tqdm import tqdm
12 | from os.path import join
13 | from datasets.ScanNet import Scannettrain, Scannetdistill, cfl_collate_fn, cfl_collate_fn_distill  # Scannetval, cfl_collate_fn_val already imported above
14 | from datetime import datetime
15 | from sklearn.cluster._kmeans import k_means
16 | from models.pretrain_models import SubModel
17 | ###
18 | def parse_args():
19 | parser = argparse.ArgumentParser(description='PyTorch Unsuper_3D_Seg')
20 | parser.add_argument('--data_path', type=str, default='data/ScanNet4growsp/train',
21 |                         help='point cloud data path')
22 | parser.add_argument('--feats_path', type=str, default='data/ScanNet4growsp/traindatas',
23 |                         help='point cloud features path')
24 | parser.add_argument('--sp_path', type=str, default= 'data/scans_growed_sp',
25 | help='initial sp path')
26 | parser.add_argument('--save_path', type=str, default='ckpt/ScanNet/',
27 | help='model savepath')
28 | ###
29 | parser.add_argument('--bn_momentum', type=float, default=0.02, help='batchnorm parameters')
30 | parser.add_argument('--conv1_kernel_size', type=int, default=5, help='kernel size of 1st conv layers')
31 | ####
32 | parser.add_argument('--workers', type=int, default=10, help='how many workers for loading data')
33 | parser.add_argument('--seed', type=int, default=2023, help='random seed')
34 | parser.add_argument('--voxel_size', type=float, default=0.02, help='voxel size in SparseConv')
35 | parser.add_argument('--input_dim', type=int, default=6, help='network input dimension')### 6 for XYZGB
36 | parser.add_argument('--primitive_num', type=int, default=30, help='how many primitives used in training')
37 | parser.add_argument('--semantic_class', type=int, default=20, help='ground truth semantic class')
38 | parser.add_argument('--feats_dim', type=int, default=128, help='output feature dimension')
39 | parser.add_argument('--ignore_label', type=int, default=-1, help='invalid label')
40 | return parser.parse_args()
41 |
42 |
43 | def eval_once(args, model, test_loader, classifier, use_sp=False):
44 | model.mode = 'train'
45 | all_preds, all_label = [], []
46 | test_loader_bar = tqdm(test_loader)
47 | for data in test_loader_bar:
48 | test_loader_bar.set_description('Start eval...')
49 | with torch.no_grad():
50 | coords, features, inverse_map, labels, index, region = data
51 |
52 | in_field = ME.TensorField(features, coords, device=0)
53 | feats_nonorm = model(in_field)
54 | feats_norm = F.normalize(feats_nonorm)
55 |
56 | region = region.squeeze()
57 | if use_sp:
58 | region_inds = torch.unique(region)
59 | region_feats = []
60 | for id in region_inds:
61 | if id != -1:
62 | valid_mask = id == region
63 | region_feats.append(feats_norm[valid_mask].mean(0, keepdim=True))
64 | region_feats = torch.cat(region_feats, dim=0)
65 |
66 | scores = F.linear(F.normalize(feats_norm), F.normalize(classifier.weight))
67 | preds = torch.argmax(scores, dim=1).cpu()
68 |
69 | region_scores = F.linear(F.normalize(region_feats), F.normalize(classifier.weight))
70 | region_no = 0
71 | for id in region_inds:
72 | if id != -1:
73 | valid_mask = id == region
74 | preds[valid_mask] = torch.argmax(region_scores, dim=1).cpu()[region_no]
75 | region_no +=1
76 | else:
77 | scores = F.linear(F.normalize(feats_nonorm), F.normalize(classifier.weight))
78 | preds = torch.argmax(scores, dim=1).cpu()
79 |
80 | preds = preds[inverse_map.long()]
81 |             all_preds.append(preds[labels!=args.ignore_label]), all_label.append(labels[labels!=args.ignore_label])
82 |
83 | return all_preds, all_label
84 |
85 | def eval(epoch, args, mode='svc'):
86 | ## Model
87 | model = Res16FPN18(in_channels=args.input_dim, out_channels=args.primitive_num, conv1_kernel_size=args.conv1_kernel_size, config=args, mode='train').cuda()
88 | model.load_state_dict(torch.load(os.path.join(args.save_path, mode, 'model_' + str(epoch) + '_checkpoint.pth')))
89 | model.eval()
90 | ## Merge Cluster Centers
91 | primitive_centers = torch.load(os.path.join(args.save_path, mode, 'cls_' + str(epoch) + '_checkpoint.pth'))
92 | centroids, _, _ = k_means(primitive_centers.cpu(), n_clusters=args.semantic_class, random_state=None, n_init=20)
93 | centroids = F.normalize(torch.FloatTensor(centroids), dim=1).cuda()
94 | cls = get_fixclassifier(in_channel=args.feats_dim, centroids_num=args.semantic_class, centroids=centroids).cuda()
95 | cls.eval()
96 |
97 | val_dataset = Scannetval(args)
98 | val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=cfl_collate_fn_val(), num_workers=args.workers, pin_memory=True)
99 |
100 | preds, labels = eval_once(args, model, val_loader, cls, use_sp=True)
101 | all_preds = torch.cat(preds).numpy()
102 | all_labels = torch.cat(labels).numpy()
103 |
104 | o_Acc, m_Acc, s = compute_seg_results(args, all_labels, all_preds)
105 |
106 | return o_Acc, m_Acc, s
107 |
108 | def eval_by_cluster(args, epoch, mode='svc'):
109 | ## Prepare Data
110 | trainset = Scannettrain(args)
111 | cluster_loader = DataLoader(trainset, batch_size=1, shuffle=True, collate_fn=cfl_collate_fn(), \
112 | num_workers=args.workers, pin_memory=True, worker_init_fn=worker_init_fn(args.seed))
113 | val_dataset = Scannetval(args)
114 | val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=cfl_collate_fn_val(), num_workers=args.workers, pin_memory=True)
115 | ## Define model
116 | model = Res16FPN18(in_channels=args.input_dim, out_channels=args.primitive_num, conv1_kernel_size=args.conv1_kernel_size, config=args, mode='train').cuda()
117 | model.load_state_dict(torch.load(os.path.join(args.save_path, mode, 'model_' + str(epoch) + '_checkpoint.pth')))
118 | model.eval()
119 | ## Get sp features
120 | print('Start to get sp features...')
121 | sp_feats_list = init_get_sp_feature(args, cluster_loader, model)
122 | sp_feats = torch.cat(sp_feats_list, dim=0)
123 | print('Start to train faiss...')
124 | ## Train faiss module
125 | _, primitive_centers = faiss_cluster(args, sp_feats.cpu().numpy())
126 | ## Merge Primitive
127 | centroids, _, _ = k_means(primitive_centers.cpu(), n_clusters=args.semantic_class, random_state=None, n_init=20, n_jobs=20)
128 | centroids = F.normalize(torch.FloatTensor(centroids), dim=1).cuda()
129 | ## Get cls
130 | cls = get_fixclassifier(in_channel=args.feats_dim, centroids_num=args.semantic_class, centroids=centroids).cuda()
131 | cls.eval()
132 | ## eval
133 | preds, labels = eval_once(args, model, val_loader, cls, use_sp=True)
134 | all_preds, all_labels = torch.cat(preds).numpy(), torch.cat(labels).numpy()
135 | o_Acc, m_Acc, s = compute_seg_results(args, all_labels, all_preds)
136 | return o_Acc, m_Acc, s
137 |
138 | if __name__ == '__main__':
139 | args = parse_args()
140 | expnames = ['~'] # your exp name
141 | epoches = []
142 | seeds = [12, 43, 56, 78, 90]
143 | for expname in expnames:
144 | args.save_path = 'ckpt/ScanNet'
145 | args.save_path = join(args.save_path, expname)
146 | assert os.path.exists(args.save_path), 'There is no {} !!!'.format(expname)
147 | if not os.path.exists(join('results', expname)):
148 | os.makedirs(join('results', expname))
149 | for epoch in epoches:
150 | results = []
151 | results.append('Eval exp {}\n'.format(expname))
152 | results.append("Eval time: {}\n".format(datetime.now().strftime("%Y-%m-%d %H:%M")))
153 | results_file = join('results', expname, 'eval_{}.txt'.format(str(epoch)))
154 | print('Eval {}, save file to {}'.format(expname, results_file))
155 | for seed in seeds:
156 | args.seed = seed
157 | set_seed(args.seed)
158 | o_Acc, m_Acc, s = eval_by_cluster(args, epoch, mode='svc')
159 | print('Epoch {:02d} Seed {}: oAcc {:.2f} mAcc {:.2f} IoUs'.format(epoch, seed, o_Acc, m_Acc) + s)
160 | results.append('Epoch {:02d} Seed {}: oAcc {:.2f} mAcc {:.2f} IoUs'.format(epoch, seed, o_Acc, m_Acc) + s +'\n')
161 | write_list(results_file, results)
162 | print('\n')
--------------------------------------------------------------------------------
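
The `use_sp` branch in `eval_once` first predicts per point, then overwrites each superpoint with the argmax of its mean feature. A compact, self-contained sketch of that voting step (the function name and shapes are illustrative, not part of the repo):

```python
import torch
import torch.nn.functional as F

def superpoint_vote(feats_norm, region, weight):
    """feats_norm: (N, D) L2-normalized point features;
    region: (N,) superpoint id per point, -1 = unassigned;
    weight: (C, D) L2-normalized classifier rows."""
    preds = torch.argmax(F.linear(feats_norm, weight), dim=1)  # per-point
    for rid in torch.unique(region):
        if rid == -1:
            continue  # unassigned points keep their per-point prediction
        mask = region == rid
        sp_feat = F.normalize(feats_norm[mask].mean(0, keepdim=True))
        preds[mask] = torch.argmax(F.linear(sp_feat, weight), dim=1)
    return preds

# Toy usage: 100 points, 8-dim features, 4 classes, 5 superpoints (plus -1s)
preds = superpoint_vote(F.normalize(torch.randn(100, 8), dim=1),
                        torch.randint(-1, 5, (100,)),
                        F.normalize(torch.randn(4, 8), dim=1))
```
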
/pointdc_mk/figs/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCUT-BIP-Lab/PointDC/99f4f30ee4e2fa2d173c3881196afec4b02bb606/pointdc_mk/figs/framework.jpg
--------------------------------------------------------------------------------
/pointdc_mk/figs/scannet_visualization.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCUT-BIP-Lab/PointDC/99f4f30ee4e2fa2d173c3881196afec4b02bb606/pointdc_mk/figs/scannet_visualization.jpg
--------------------------------------------------------------------------------
/pointdc_mk/lib/aug_tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.linalg import expm, norm
3 | import scipy.ndimage, scipy.interpolate  # RegularGridInterpolator lives in scipy.interpolate
4 |
5 | def M(axis, theta):
6 | return expm(np.cross(np.eye(3), axis / norm(axis) * theta))
7 |
8 |
9 | class trans_coords:
10 | def __init__(self, shift_ratio):
11 | self.ratio = shift_ratio
12 |
13 | def __call__(self, coords):
14 | shift = (np.random.uniform(0, 1, 3) * self.ratio)
15 | return coords + shift
16 |
17 |
18 | class rota_coords:
19 | def __init__(self, rotation_bound = ((-np.pi/32, np.pi/32), (-np.pi/32, np.pi/32), (-np.pi, np.pi))):
20 | self.rotation_bound = rotation_bound
21 |
22 | def __call__(self, coords):
23 | rot_mats = []
24 | for axis_ind, rot_bound in enumerate(self.rotation_bound):
25 | theta = 0
26 | axis = np.zeros(3)
27 | axis[axis_ind] = 1
28 | if rot_bound is not None:
29 | theta = np.random.uniform(*rot_bound)
30 | rot_mats.append(M(axis, theta))
31 | # Use random order
32 | np.random.shuffle(rot_mats)
33 | rot_mat = rot_mats[0] @ rot_mats[1] @ rot_mats[2]
34 | return coords.dot(rot_mat)
35 |
36 |
37 | class scale_coords:
38 | def __init__(self, scale_bound=(0.8, 1.25)):
39 | self.scale_bound = scale_bound
40 |
41 | def __call__(self, coords):
42 | scale = np.random.uniform(*self.scale_bound)
43 | return coords*scale
44 |
45 | class elastic_coords:
46 | def __init__(self, voxel_size):
47 | self.voxel_size = voxel_size
48 |
49 | def __call__(self, coords, gran, mag):
50 | blur0 = np.ones((3, 1, 1)).astype('float32') / 3
51 | blur1 = np.ones((1, 3, 1)).astype('float32') / 3
52 | blur2 = np.ones((1, 1, 3)).astype('float32') / 3
53 |
54 | bb = (np.abs(coords).max(0).astype(np.int32) // gran + 3).astype(np.int32)
55 | noise = [np.random.randn(bb[0], bb[1], bb[2]).astype('float32') for _ in range(3)]
56 | noise = [scipy.ndimage.filters.convolve(n, blur0, mode='constant', cval=0) for n in noise]
57 | noise = [scipy.ndimage.filters.convolve(n, blur1, mode='constant', cval=0) for n in noise]
58 | noise = [scipy.ndimage.filters.convolve(n, blur2, mode='constant', cval=0) for n in noise]
59 | noise = [scipy.ndimage.filters.convolve(n, blur0, mode='constant', cval=0) for n in noise]
60 | noise = [scipy.ndimage.filters.convolve(n, blur1, mode='constant', cval=0) for n in noise]
61 | noise = [scipy.ndimage.filters.convolve(n, blur2, mode='constant', cval=0) for n in noise]
62 | ax = [np.linspace(-(b - 1) * gran, (b - 1) * gran, b) for b in bb]
63 | interp = [scipy.interpolate.RegularGridInterpolator(ax, n, bounds_error=0, fill_value=0) for n in noise]
64 |
65 | def g(x_):
66 | return np.hstack([i(x_)[:, None] for i in interp])
67 |
68 | return coords + g(coords) * mag
--------------------------------------------------------------------------------
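
A toy usage of the callables above; `M` builds a proper rotation as the matrix exponential of a scaled cross-product (skew-symmetric) matrix, so the sanity check below should hold.

```python
import numpy as np
from lib.aug_tools import M, trans_coords, rota_coords, scale_coords  # assuming pointdc_mk/ is the cwd

coords = np.random.rand(2048, 3).astype(np.float32)   # synthetic cloud

aug_rot = rota_coords()                                # default per-axis bounds
aug_shift = trans_coords(shift_ratio=50)
aug_scale = scale_coords(scale_bound=(0.9, 1.1))
coords = aug_scale(aug_shift(aug_rot(coords)))

# Sanity check: M(axis, theta) is orthonormal, i.e. a rotation matrix.
R = M(np.array([0.0, 0.0, 1.0]), np.pi / 4)
assert np.allclose(R @ R.T, np.eye(3), atol=1e-6)
```
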
/pointdc_mk/lib/helper_ply.py:
--------------------------------------------------------------------------------
1 | #
2 | #
3 | # 0===============================0
4 | # | PLY files reader/writer |
5 | # 0===============================0
6 | #
7 | #
8 | # ----------------------------------------------------------------------------------------------------------------------
9 | #
10 | # function to read/write .ply files
11 | #
12 | # ----------------------------------------------------------------------------------------------------------------------
13 | #
14 | # Hugues THOMAS - 10/02/2017
15 | #
16 |
17 |
18 | # ----------------------------------------------------------------------------------------------------------------------
19 | #
20 | # Imports and global variables
21 | # \**********************************/
22 | #
23 |
24 |
25 | # Basic libs
26 | import numpy as np
27 | import sys
28 |
29 |
30 | # Define PLY types
31 | ply_dtypes = dict([
32 | (b'int8', 'i1'),
33 | (b'char', 'i1'),
34 | (b'uint8', 'u1'),
35 | (b'uchar', 'u1'),
36 | (b'int16', 'i2'),
37 | (b'short', 'i2'),
38 | (b'uint16', 'u2'),
39 | (b'ushort', 'u2'),
40 | (b'int32', 'i4'),
41 | (b'int', 'i4'),
42 | (b'uint32', 'u4'),
43 | (b'uint', 'u4'),
44 | (b'float32', 'f4'),
45 | (b'float', 'f4'),
46 | (b'float64', 'f8'),
47 | (b'double', 'f8')
48 | ])
49 |
50 | # Numpy reader format
51 | valid_formats = {'ascii': '', 'binary_big_endian': '>',
52 | 'binary_little_endian': '<'}
53 |
54 |
55 | # ----------------------------------------------------------------------------------------------------------------------
56 | #
57 | # Functions
58 | # \***************/
59 | #
60 |
61 |
62 | def parse_header(plyfile, ext):
63 | # Variables
64 | line = []
65 | properties = []
66 | num_points = None
67 |
68 | while b'end_header' not in line and line != b'':
69 | line = plyfile.readline()
70 |
71 | if b'element' in line:
72 | line = line.split()
73 | num_points = int(line[2])
74 |
75 | elif b'property' in line:
76 | line = line.split()
77 | properties.append((line[2].decode(), ext + ply_dtypes[line[1]]))
78 |
79 | return num_points, properties
80 |
81 |
82 | def parse_mesh_header(plyfile, ext):
83 | # Variables
84 | line = []
85 | vertex_properties = []
86 | num_points = None
87 | num_faces = None
88 | current_element = None
89 |
90 |
91 | while b'end_header' not in line and line != b'':
92 | line = plyfile.readline()
93 |
94 | # Find point element
95 | if b'element vertex' in line:
96 | current_element = 'vertex'
97 | line = line.split()
98 | num_points = int(line[2])
99 |
100 | elif b'element face' in line:
101 | current_element = 'face'
102 | line = line.split()
103 | num_faces = int(line[2])
104 |
105 | elif b'property' in line:
106 | if current_element == 'vertex':
107 | line = line.split()
108 | vertex_properties.append((line[2].decode(), ext + ply_dtypes[line[1]]))
109 |             elif current_element == 'face':
110 |                 if not line.startswith(b'property list uchar int'):
111 |                     raise ValueError('Unsupported faces property : ' + str(line))
112 |
113 | return num_points, num_faces, vertex_properties
114 |
115 |
116 | def read_ply(filename, triangular_mesh=False):
117 | """
118 | Read ".ply" files
119 |
120 | Parameters
121 | ----------
122 | filename : string
123 | the name of the file to read.
124 |
125 | Returns
126 | -------
127 | result : array
128 | data stored in the file
129 |
130 | Examples
131 | --------
132 | Store data in file
133 |
134 | >>> points = np.random.rand(5, 3)
135 | >>> values = np.random.randint(2, size=10)
136 | >>> write_ply('example.ply', [points, values], ['x', 'y', 'z', 'values'])
137 |
138 | Read the file
139 |
140 | >>> data = read_ply('example.ply')
141 | >>> values = data['values']
142 | array([0, 0, 1, 1, 0])
143 |
144 | >>> points = np.vstack((data['x'], data['y'], data['z'])).T
145 | array([[ 0.466 0.595 0.324]
146 | [ 0.538 0.407 0.654]
147 | [ 0.850 0.018 0.988]
148 | [ 0.395 0.394 0.363]
149 | [ 0.873 0.996 0.092]])
150 |
151 | """
152 |
153 | with open(filename, 'rb') as plyfile:
154 |
155 |
156 | # Check if the file start with ply
157 | if b'ply' not in plyfile.readline():
158 |             raise ValueError('The file does not start with the word ply')
159 |
160 | # get binary_little/big or ascii
161 | fmt = plyfile.readline().split()[1].decode()
162 | if fmt == "ascii":
163 | raise ValueError('The file is not binary')
164 |
165 | # get extension for building the numpy dtypes
166 | ext = valid_formats[fmt]
167 |
168 | # PointCloud reader vs mesh reader
169 | if triangular_mesh:
170 |
171 | # Parse header
172 | num_points, num_faces, properties = parse_mesh_header(plyfile, ext)
173 |
174 | # Get point data
175 | vertex_data = np.fromfile(plyfile, dtype=properties, count=num_points)
176 |
177 | # Get face data
178 | face_properties = [('k', ext + 'u1'),
179 | ('v1', ext + 'i4'),
180 | ('v2', ext + 'i4'),
181 | ('v3', ext + 'i4')]
182 | faces_data = np.fromfile(plyfile, dtype=face_properties, count=num_faces)
183 |
184 | # Return vertex data and concatenated faces
185 | faces = np.vstack((faces_data['v1'], faces_data['v2'], faces_data['v3'])).T
186 | data = [vertex_data, faces]
187 |
188 | else:
189 |
190 | # Parse header
191 | num_points, properties = parse_header(plyfile, ext)
192 |
193 | # Get data
194 | data = np.fromfile(plyfile, dtype=properties, count=num_points)
195 |
196 | return data
197 |
198 |
199 | def header_properties(field_list, field_names):
200 |
201 | # List of lines to write
202 | lines = []
203 |
204 | # First line describing element vertex
205 | lines.append('element vertex %d' % field_list[0].shape[0])
206 |
207 | # Properties lines
208 | i = 0
209 | for fields in field_list:
210 | for field in fields.T:
211 | lines.append('property %s %s' % (field.dtype.name, field_names[i]))
212 | i += 1
213 |
214 | return lines
215 |
216 |
217 | def write_ply(filename, field_list, field_names, triangular_faces=None):
218 | """
219 | Write ".ply" files
220 |
221 | Parameters
222 | ----------
223 | filename : string
224 | the name of the file to which the data is saved. A '.ply' extension will be appended to the
225 |         file name if it does not already have one.
226 |
227 | field_list : list, tuple, numpy array
228 | the fields to be saved in the ply file. Either a numpy array, a list of numpy arrays or a
229 | tuple of numpy arrays. Each 1D numpy array and each column of 2D numpy arrays are considered
230 | as one field.
231 |
232 | field_names : list
233 | the name of each fields as a list of strings. Has to be the same length as the number of
234 | fields.
235 |
236 | Examples
237 | --------
238 | >>> points = np.random.rand(10, 3)
239 | >>> write_ply('example1.ply', points, ['x', 'y', 'z'])
240 |
241 | >>> values = np.random.randint(2, size=10)
242 | >>> write_ply('example2.ply', [points, values], ['x', 'y', 'z', 'values'])
243 |
244 | >>> colors = np.random.randint(255, size=(10,3), dtype=np.uint8)
245 | >>> field_names = ['x', 'y', 'z', 'red', 'green', 'blue', 'values']
246 | >>> write_ply('example3.ply', [points, colors, values], field_names)
247 |
248 | """
249 |
250 | # Format list input to the right form
251 | field_list = list(field_list) if (type(field_list) == list or type(field_list) == tuple) else list((field_list,))
252 | for i, field in enumerate(field_list):
253 | if field.ndim < 2:
254 | field_list[i] = field.reshape(-1, 1)
255 | if field.ndim > 2:
256 | print('fields have more than 2 dimensions')
257 | return False
258 |
259 | # check all fields have the same number of data
260 | n_points = [field.shape[0] for field in field_list]
261 | if not np.all(np.equal(n_points, n_points[0])):
262 | print('wrong field dimensions')
263 | return False
264 |
265 | # Check if field_names and field_list have same nb of column
266 | n_fields = np.sum([field.shape[1] for field in field_list])
267 | if (n_fields != len(field_names)):
268 | print('wrong number of field names')
269 | return False
270 |
271 | # Add extension if not there
272 | if not filename.endswith('.ply'):
273 | filename += '.ply'
274 |
275 | # open in text mode to write the header
276 | with open(filename, 'w') as plyfile:
277 |
278 | # First magical word
279 | header = ['ply']
280 |
281 | # Encoding format
282 | header.append('format binary_' + sys.byteorder + '_endian 1.0')
283 |
284 | # Points properties description
285 | header.extend(header_properties(field_list, field_names))
286 |
287 |         # Add faces if needed
288 | if triangular_faces is not None:
289 | header.append('element face {:d}'.format(triangular_faces.shape[0]))
290 | header.append('property list uchar int vertex_indices')
291 |
292 | # End of header
293 | header.append('end_header')
294 |
295 | # Write all lines
296 | for line in header:
297 | plyfile.write("%s\n" % line)
298 |
299 | # open in binary/append to use tofile
300 | with open(filename, 'ab') as plyfile:
301 |
302 | # Create a structured array
303 | i = 0
304 | type_list = []
305 | for fields in field_list:
306 | for field in fields.T:
307 | type_list += [(field_names[i], field.dtype.str)]
308 | i += 1
309 | data = np.empty(field_list[0].shape[0], dtype=type_list)
310 | i = 0
311 | for fields in field_list:
312 | for field in fields.T:
313 | data[field_names[i]] = field
314 | i += 1
315 |
316 | data.tofile(plyfile)
317 |
318 | if triangular_faces is not None:
319 | triangular_faces = triangular_faces.astype(np.int32)
320 | type_list = [('k', 'uint8')] + [(str(ind), 'int32') for ind in range(3)]
321 | data = np.empty(triangular_faces.shape[0], dtype=type_list)
322 | data['k'] = np.full((triangular_faces.shape[0],), 3, dtype=np.uint8)
323 | data['0'] = triangular_faces[:, 0]
324 | data['1'] = triangular_faces[:, 1]
325 | data['2'] = triangular_faces[:, 2]
326 | data.tofile(plyfile)
327 |
328 | return True
329 |
330 |
331 | def describe_element(name, df):
332 | """ Takes the columns of the dataframe and builds a ply-like description
333 |
334 | Parameters
335 | ----------
336 | name: str
337 | df: pandas DataFrame
338 |
339 | Returns
340 | -------
341 | element: list[str]
342 | """
343 | property_formats = {'f': 'float', 'u': 'uchar', 'i': 'int'}
344 | element = ['element ' + name + ' ' + str(len(df))]
345 |
346 | if name == 'face':
347 | element.append("property list uchar int points_indices")
348 |
349 | else:
350 | for i in range(len(df.columns)):
351 | # get first letter of dtype to infer format
352 | f = property_formats[str(df.dtypes[i])[0]]
353 | element.append('property ' + f + ' ' + df.columns.values[i])
354 |
355 | return element
356 |
357 |
--------------------------------------------------------------------------------
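
A small round-trip sketch for the helpers above, mirroring the fields the SemanticKITTI loader reads (`x/y/z`, `remission`, `class`); the file name is arbitrary.

```python
import numpy as np
from lib.helper_ply import write_ply, read_ply  # assuming pointdc_mk/ is the cwd

points = np.random.rand(100, 3).astype(np.float32)
remission = np.random.rand(100).astype(np.float32)
classes = np.random.randint(0, 20, 100).astype(np.int32)

# write_ply reshapes 1D fields into columns and emits a binary PLY.
write_ply('toy_scan.ply', [points, remission, classes],
          ['x', 'y', 'z', 'remission', 'class'])

data = read_ply('toy_scan.ply')  # structured array keyed by field name
xyz = np.vstack((data['x'], data['y'], data['z'])).T
assert np.allclose(xyz, points)
```
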
/pointdc_mk/lib/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import os, random
6 | import faiss
7 | import shutil
8 | from sklearn.cluster import KMeans
9 | import MinkowskiEngine as ME
10 | from sklearn.utils.linear_assignment_ import linear_assignment # pip install scikit-learn==0.22.2
11 |
12 | from tqdm import tqdm
13 | from torch_scatter import scatter
14 |
15 | class AverageMeter(object):
16 | """Computes and stores the average and current value"""
17 | def __init__(self):
18 | self.reset()
19 |
20 | def reset(self):
21 | self.val = 0
22 | self.avg = 0
23 | self.sum = 0
24 | self.count = 0
25 |
26 | def update(self, val, n=1):
27 | self.val = val
28 | self.sum += val * n
29 | self.count += n
30 | self.avg = self.sum / self.count
31 |
32 | class FocalLoss(nn.Module):
33 | def __init__(self, weight=None, ignore_index=None, gamma=2):
34 | super(FocalLoss,self).__init__()
35 | self.gamma = gamma
36 |         self.weight = weight # per-class weights, given as a tensor
37 | self.ignore_index = ignore_index
38 |
39 | def forward(self, preds, labels):
40 | """
41 |         preds: raw logits
42 |         labels: ground-truth labels
43 | """
44 | if self.ignore_index is not None:
45 | mask = torch.nonzero(labels!=self.ignore_index)
46 | preds = preds[mask].squeeze(1)
47 | labels = labels[mask].squeeze(1)
48 |
49 | preds = F.softmax(preds,dim=1)
50 | eps = 1e-7
51 |
52 | target = self.one_hot(preds.size(1), labels)
53 |
54 | ce = -torch.log(preds+eps) * target
55 | floss = torch.pow((1-preds), self.gamma) * ce
56 | if self.weight is not None:
57 | floss = torch.mul(floss, self.weight)
58 | floss = torch.sum(floss, dim=1)
59 | return torch.mean(floss)
60 |
61 | def one_hot(self, num, labels):
62 | one = torch.zeros((labels.size(0),num)).cuda()
63 | one[range(labels.size(0)),labels] = 1
64 | return one
65 |
66 | class MseMaskLoss(nn.Module):
67 | def __init__(self):
68 | super(MseMaskLoss, self).__init__()
69 |
70 |     def forward(self, sourcefeats, targetfeats):
71 | targetfeats = F.normalize(targetfeats, dim=1, p=2)
72 | sourcefeats = F.normalize(sourcefeats, dim=1, p=2)
73 | compute_index = torch.where(torch.abs(targetfeats).sum(dim=1)>0)
74 | mseloss = (sourcefeats[compute_index] - targetfeats[compute_index])**2
75 |
76 | return mseloss.mean()
77 |
78 | def worker_init_fn(seed):
79 | return lambda x: np.random.seed(seed + x)
80 |
81 | def set_seed(seed):
82 | """
83 | Unfortunately, backward() of [interpolate] functional seems to be never deterministic.
84 |
85 | Below are related threads:
86 | https://github.com/pytorch/pytorch/issues/7068
87 | https://discuss.pytorch.org/t/non-deterministic-behavior-of-pytorch-upsample-interpolate/42842?u=sbelharbi
88 | """
89 | # Use random seed.
90 | random.seed(seed)
91 | os.environ['PYTHONHASHSEED'] = str(seed)
92 | np.random.seed(seed)
93 | torch.manual_seed(seed)
94 | torch.cuda.manual_seed(seed)
95 | torch.cuda.manual_seed_all(seed)
96 | torch.backends.cudnn.deterministic = True
97 | torch.backends.cudnn.benchmark = False
98 | torch.backends.cudnn.enabled = False
99 |
100 | def init_get_sp_feature(args, loader, model, submodel=None):
101 | loader.dataset.mode = 'cluster'
102 |
103 | region_feats_list = []
104 | model.eval()
105 | with torch.no_grad():
106 | for batch_idx, data in enumerate(loader):
107 | coords, features, _, labels, inverse_map, pseudo_labels, inds, region, index, scenenames = data
108 | region = region.squeeze()
109 | raw_region = region.clone()
110 |
111 | in_field = ME.TensorField(features, coords, device=0)
112 |
113 | if submodel is not None:
114 | feats_TensorField = model(in_field)
115 | feats_nonorm = submodel(feats_TensorField)
116 | else:
117 | feats_nonorm = model(in_field)
118 |             feats_nonorm = feats_nonorm[inverse_map.long()] # map voxel feats back to points
119 |             feats_norm = F.normalize(feats_nonorm, dim=1) ## NOTE normalize may be unnecessary; makes little difference
120 |
121 | region_feats = scatter(feats_norm, raw_region.cuda(), dim=0, reduce='mean')
122 |
123 |             valid_mask = labels!=-1 # mask of labeled points
124 | # labels = labels[valid_mask]
125 | region_masked = region[valid_mask].long()
126 | region_masked_num = torch.unique(region_masked)
127 | region_masked_feats = region_feats[region_masked_num]
128 | region_masked_feats_norm = F.normalize(region_masked_feats, dim=1).cpu()
129 |
130 | region_feats_list.append(region_masked_feats_norm)
131 |
132 | torch.cuda.empty_cache()
133 | torch.cuda.synchronize(torch.device("cuda"))
134 |
135 | return region_feats_list
136 |
137 | def init_get_pseudo(args, loader, model, centroids_norm, submodel=None):
138 |
139 | pseudo_label_folder = args.pseudo_path + '/'
140 | if not os.path.exists(pseudo_label_folder): os.makedirs(pseudo_label_folder)
141 |
142 | all_pseudo = []
143 | all_label = []
144 | model.eval()
145 | with torch.no_grad():
146 | for batch_idx, data in enumerate(loader):
147 | coords, features, _, labels, inverse_map, pseudo_labels, \
148 | inds, region, index, scenenames = data
149 | region = region.squeeze()
150 | raw_region = region.clone()
151 |
152 | in_field = ME.TensorField(features, coords, device=0)
153 |
154 | if submodel is not None:
155 | feats_TensorField = model(in_field)
156 | feats_nonorm = submodel(feats_TensorField)
157 | else:
158 | feats_nonorm = model(in_field)
159 |                 feats_nonorm = feats_nonorm[inverse_map.long()] # map voxel feats back to points
160 |             feats_norm = F.normalize(feats_nonorm, dim=1) ## NOTE normalize may be unnecessary
161 |
162 | region_feats = scatter(feats_norm, raw_region.cuda(), dim=0, reduce='mean')
163 |
164 |             ## predict with the fixed centroids
165 | region_scores= F.linear(F.normalize(region_feats), centroids_norm)
166 | region_scores_maxs, region_preds = torch.max(region_scores, dim=1)
167 | region_scores_maxs, region_preds = region_scores_maxs.cpu(), region_preds.cpu()
168 | preds = region_preds[region] ## all point preds
169 |             region_scores_maxs = region_scores_maxs[region] ## all point max scores
170 |
171 | pseudo_label_file = pseudo_label_folder + '/' + scenenames[0] + '.npy'
172 | np.save(pseudo_label_file, preds)
173 |
174 | all_label.append(labels)
175 | all_pseudo.append(preds)
176 |
177 | torch.cuda.empty_cache()
178 | torch.cuda.synchronize(torch.device("cuda"))
179 |
180 | all_pseudo = np.concatenate(all_pseudo)
181 | all_label = np.concatenate(all_label)
182 |
183 | return all_pseudo, all_label
184 |
185 | def get_fixclassifier(in_channel, centroids_num, centroids):
186 | classifier = nn.Linear(in_features=in_channel, out_features=centroids_num, bias=False)
187 | centroids = F.normalize(centroids, dim=1)
188 | classifier.weight.data = centroids
189 | for para in classifier.parameters():
190 | para.requires_grad = False
191 | return classifier
192 |
193 | def compute_hist(normal, bins=10, min=-1, max=1):
194 | ## normal : [N, 3]
195 | normal = F.normalize(normal)
196 | relation = torch.mm(normal, normal.t())
197 | relation = torch.triu(relation, diagonal=0) # top-half matrix
198 | hist = torch.histc(relation, bins, min, max)
199 | # hist = torch.histogram(relation, bins, range=(-1, 1))
200 | hist /= hist.sum()
201 |
202 | return hist
203 |
204 | def faiss_cluster(args, sp_feats, metric='cosin'):
205 | dim = sp_feats.shape[-1]
206 |
207 | # define faiss module
208 | res = faiss.StandardGpuResources()
209 | fcfg = faiss.GpuIndexFlatConfig()
210 | fcfg.useFloat16 = False
211 | fcfg.device = 0 #NOTE: Single GPU only.
212 | if metric == 'l2':
213 |         faiss_module = faiss.GpuIndexFlatL2(res, dim, fcfg) # Euclidean (L2) distance
214 |     elif metric == 'cosin':
215 |         faiss_module = faiss.GpuIndexFlatIP(res, dim, fcfg) # inner product, i.e. cosine on normalized feats
216 | clus = faiss.Clustering(dim, args.primitive_num)
217 | clus.seed = np.random.randint(args.seed)
218 | clus.niter = 80
219 |
220 | # train
221 | clus.train(sp_feats, faiss_module)
222 | centroids = faiss.vector_float_to_array(clus.centroids).reshape(args.primitive_num, dim).astype('float32')
223 | centroids_norm = F.normalize(torch.tensor(centroids), dim=1)
224 | # D, I = faiss_module.search(sp_feats, 1)
225 |
226 | return None, centroids_norm
227 |
228 | def cache_codes(args):
229 | tardir = os.path.join(args.save_path, 'cache_code')
230 | if not os.path.exists(tardir):
231 | os.makedirs(tardir)
232 | try:
233 | all_files = os.listdir('./')
234 | pyfile_list = [file for file in all_files if file.endswith(".py")]
235 | shutil.copytree(r'./lib', os.path.join(tardir, r'lib'))
236 | shutil.copytree(r'./data_prepare', os.path.join(tardir, r'data_prepare'))
237 | shutil.copytree(r'./datasets', os.path.join(tardir, r'datasets'))
238 | for pyfile in pyfile_list:
239 | shutil.copy(pyfile, os.path.join(tardir, pyfile))
240 | except:
241 | pass
242 |
243 | def compute_seg_results(args, all_labels, all_preds):
244 | '''Unsupervised, Match pred to gt'''
245 | sem_num = args.semantic_class
246 | mask = (all_labels >= 0) & (all_labels < sem_num)
247 | histogram = np.bincount(sem_num * all_labels[mask] + all_preds[mask], minlength=sem_num ** 2).reshape(sem_num, sem_num)
248 | '''Hungarian Matching'''
249 | m = linear_assignment(histogram.max() - histogram)
250 | o_Acc = histogram[m[:, 0], m[:, 1]].sum() / histogram.sum()*100.
251 | m_Acc = np.mean(histogram[m[:, 0], m[:, 1]] / histogram.sum(1))*100
252 | hist_new = np.zeros((sem_num, sem_num))
253 | for idx in range(sem_num):
254 | hist_new[:, idx] = histogram[:, m[idx, 1]]
255 | '''Final Metrics'''
256 | tp = np.diag(hist_new)
257 | fp = np.sum(hist_new, 0) - tp
258 | fn = np.sum(hist_new, 1) - tp
259 | IoUs = tp / (tp + fp + fn + 1e-8)
260 | m_IoU = np.nanmean(IoUs)
261 | s = '| mIoU {:5.2f} | '.format(100 * m_IoU)
262 | for IoU in IoUs:
263 | s += '{:5.2f} '.format(100 * IoU)
264 |
265 | return o_Acc, m_Acc, s
266 |
267 | def write_list(file_path, contents):
268 | with open(file_path, 'w') as file:
269 | for content in contents:
270 | file.write(content)
271 |
--------------------------------------------------------------------------------
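
A worked toy example of the Hungarian matching inside `compute_seg_results`, using `scipy.optimize.linear_sum_assignment` as a stand-in for the pinned `sklearn.utils.linear_assignment_` helper (same optimum; it returns a `(row, col)` tuple instead of a single array):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

sem_num = 3
gt = np.array([0, 0, 1, 1, 2, 2])
pred = np.array([2, 2, 0, 0, 1, 1])   # a pure relabeling of gt

# Confusion histogram: rows = gt classes, cols = predicted cluster ids.
hist = np.bincount(sem_num * gt + pred,
                   minlength=sem_num ** 2).reshape(sem_num, sem_num)

# Maximize matched counts by minimizing (max - hist).
row, col = linear_sum_assignment(hist.max() - hist)
o_acc = hist[row, col].sum() / hist.sum() * 100.0
assert o_acc == 100.0                  # perfect once clusters are matched
```
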
/pointdc_mk/lib/utils_s3dis.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import os, random
6 | import faiss
7 | import shutil
8 | from sklearn.cluster import KMeans
9 | import MinkowskiEngine as ME
10 | from sklearn.utils.linear_assignment_ import linear_assignment # pip install scikit-learn==0.22.2
11 |
12 | from tqdm import tqdm
13 | from torch_scatter import scatter
14 |
15 | class AverageMeter(object):
16 | """Computes and stores the average and current value"""
17 | def __init__(self):
18 | self.reset()
19 |
20 | def reset(self):
21 | self.val = 0
22 | self.avg = 0
23 | self.sum = 0
24 | self.count = 0
25 |
26 | def update(self, val, n=1):
27 | self.val = val
28 | self.sum += val * n
29 | self.count += n
30 | self.avg = self.sum / self.count
31 |
32 | class FocalLoss(nn.Module):
33 | def __init__(self, weight=None, ignore_index=None, gamma=2):
34 | super(FocalLoss,self).__init__()
35 | self.gamma = gamma
36 |         self.weight = weight # per-class weights, given as a tensor
37 | self.ignore_index = ignore_index
38 |
39 | def forward(self, preds, labels):
40 | """
41 |         preds: raw logits
42 |         labels: ground-truth labels
43 | """
44 | if self.ignore_index is not None:
45 | mask = torch.nonzero(labels!=self.ignore_index)
46 | preds = preds[mask].squeeze(1)
47 | labels = labels[mask].squeeze(1)
48 |
49 | preds = F.softmax(preds,dim=1)
50 | eps = 1e-7
51 |
52 | target = self.one_hot(preds.size(1), labels)
53 |
54 | ce = -torch.log(preds+eps) * target
55 | floss = torch.pow((1-preds), self.gamma) * ce
56 | if self.weight is not None:
57 | floss = torch.mul(floss, self.weight)
58 | floss = torch.sum(floss, dim=1)
59 | return torch.mean(floss)
60 |
61 | def one_hot(self, num, labels):
62 | one = torch.zeros((labels.size(0),num)).cuda()
63 | one[range(labels.size(0)),labels] = 1
64 | return one
65 |
66 | class MseMaskLoss(nn.Module):
67 | def __init__(self):
68 | super(MseMaskLoss, self).__init__()
69 |
70 |     def forward(self, sourcefeats, targetfeats):
71 | targetfeats = F.normalize(targetfeats, dim=1, p=2)
72 | sourcefeats = F.normalize(sourcefeats, dim=1, p=2)
73 | mseloss = (sourcefeats - targetfeats)**2
74 |
75 | return mseloss.mean()
76 |
77 | def worker_init_fn(seed):
78 | return lambda x: np.random.seed(seed + x)
79 |
80 | def set_seed(seed):
81 | """
82 | Unfortunately, backward() of [interpolate] functional seems to be never deterministic.
83 |
84 | Below are related threads:
85 | https://github.com/pytorch/pytorch/issues/7068
86 | https://discuss.pytorch.org/t/non-deterministic-behavior-of-pytorch-upsample-interpolate/42842?u=sbelharbi
87 | """
88 | # Use random seed.
89 | random.seed(seed)
90 | os.environ['PYTHONHASHSEED'] = str(seed)
91 | np.random.seed(seed)
92 | torch.manual_seed(seed)
93 | torch.cuda.manual_seed(seed)
94 | torch.cuda.manual_seed_all(seed)
95 | torch.backends.cudnn.deterministic = True
96 | torch.backends.cudnn.benchmark = False
97 | torch.backends.cudnn.enabled = False
98 |
99 | def init_get_sp_feature(args, loader, model, submodel=None):
100 | loader.dataset.mode = 'cluster'
101 |
102 | region_feats_list = []
103 | model.eval()
104 | with torch.no_grad():
105 | for batch_idx, data in enumerate(loader):
106 | coords, features, _, labels, inverse_map, pseudo_labels, inds, region, index, scenenames = data
107 | region = region.squeeze()
108 |
109 | in_field = ME.TensorField(features, coords, device=0)
110 |
111 | if submodel is not None:
112 | feats_TensorField = model(in_field)
113 | feats_nonorm = submodel(feats_TensorField)
114 | else:
115 | feats_nonorm = model(in_field)
116 |             feats_nonorm = feats_nonorm[inverse_map.long()] # gather per-point features
117 |             feats_norm = F.normalize(feats_nonorm, dim=1) ## NOTE normalization may be unnecessary; makes little difference
118 |
119 |             valid_mask = region!=-1 # mask of points belonging to valid super-voxels
120 | region_feats = scatter(feats_norm[valid_mask], region[valid_mask].cuda(), dim=0, reduce='mean')
121 |
122 | region_feats_norm = F.normalize(region_feats, dim=1).cpu()
123 | region_feats_list.append(region_feats_norm)
124 |
125 | torch.cuda.empty_cache()
126 | torch.cuda.synchronize(torch.device("cuda"))
127 |
128 | return region_feats_list
129 |
130 | def init_get_pseudo(args, loader, model, centroids_norm, submodel=None):
131 |
132 | pseudo_label_folder = args.pseudo_path + '/'
133 | if not os.path.exists(pseudo_label_folder): os.makedirs(pseudo_label_folder)
134 |
135 | all_pseudo = []
136 | all_label = []
137 | model.eval()
138 | with torch.no_grad():
139 | for batch_idx, data in enumerate(loader):
140 | coords, features, _, labels, inverse_map, pseudo_labels, inds, region, index, scenenames = data
141 | region = region.squeeze()+1
142 | raw_region = region.clone()
143 |
144 | in_field = ME.TensorField(features, coords, device=0)
145 |
146 | if submodel is not None:
147 | feats_TensorField = model(in_field)
148 | feats_nonorm = submodel(feats_TensorField)
149 | else:
150 | feats_nonorm = model(in_field)
151 |             feats_nonorm = feats_nonorm[inverse_map.long()] # gather per-point features
152 |             feats_norm = F.normalize(feats_nonorm, dim=1) ## NOTE normalization may be unnecessary
153 |
154 | scores = F.linear(feats_norm, F.normalize(centroids_norm))
155 |             preds = torch.argmax(scores, dim=1).cpu() # point-wise predictions
156 |
157 | region_feats = scatter(feats_norm, raw_region.cuda(), dim=0, reduce='mean')
158 |
159 | region_inds = torch.unique(region, sorted=True)
160 |             region_scores = F.linear(F.normalize(region_feats), F.normalize(centroids_norm)) # super-voxel predictions
161 |             for id in region_inds: # region id 0 marks invalid points; skip it
162 | if id != 0:
163 | valid_mask = id == region
164 | preds[valid_mask] = torch.argmax(region_scores, dim=1).cpu()[id]
165 |
166 | preds[labels==-1] = -1
167 | pseudo_label_file = pseudo_label_folder + '/' + scenenames[0] + '.npy'
168 | np.save(pseudo_label_file, preds)
169 |
170 | all_label.append(labels)
171 | all_pseudo.append(preds)
172 |
173 | torch.cuda.empty_cache()
174 | torch.cuda.synchronize(torch.device("cuda"))
175 |
176 | all_pseudo = np.concatenate(all_pseudo)
177 | all_label = np.concatenate(all_label)
178 |
179 | return all_pseudo.astype('int64'), all_label.astype('int64')
180 |
181 | def get_fixclassifier(in_channel, centroids_num, centroids):
182 | classifier = nn.Linear(in_features=in_channel, out_features=centroids_num, bias=False)
183 | centroids = F.normalize(centroids, dim=1)
184 | classifier.weight.data = centroids
185 | for para in classifier.parameters():
186 | para.requires_grad = False
187 | return classifier
188 |
189 | def compute_hist(normal, bins=10, min=-1, max=1):
190 | ## normal : [N, 3]
191 | normal = F.normalize(normal)
192 | relation = torch.mm(normal, normal.t())
193 |         relation = torch.triu(relation, diagonal=0) # upper-triangular part
194 | hist = torch.histc(relation, bins, min, max)
195 | # hist = torch.histogram(relation, bins, range=(-1, 1))
196 | hist /= hist.sum()
197 |
198 | return hist
199 |
200 | def faiss_cluster(args, sp_feats, metric='cosin'):
201 | dim = sp_feats.shape[-1]
202 |
203 | # define faiss module
204 | res = faiss.StandardGpuResources()
205 | fcfg = faiss.GpuIndexFlatConfig()
206 | fcfg.useFloat16 = False
207 | fcfg.device = 0 #NOTE: Single GPU only.
208 | if metric == 'l2':
209 |         faiss_module = faiss.GpuIndexFlatL2(res, dim, fcfg) # Euclidean (L2) distance
210 | elif metric == 'cosin':
211 |         faiss_module = faiss.GpuIndexFlatIP(res, dim, fcfg) # inner product (cosine similarity on normalized features)
212 | clus = faiss.Clustering(dim, args.primitive_num)
213 | clus.seed = np.random.randint(args.seed)
214 | clus.niter = 80
215 |
216 | # train
217 | clus.train(sp_feats, faiss_module)
218 | centroids = faiss.vector_float_to_array(clus.centroids).reshape(args.primitive_num, dim).astype('float32')
219 | centroids_norm = F.normalize(torch.tensor(centroids), dim=1)
220 | # D, I = faiss_module.search(sp_feats, 1)
221 |
222 | return None, centroids_norm
223 |
224 | def cache_codes(args):
225 | tardir = os.path.join(args.save_path, 'cache_code')
226 | if not os.path.exists(tardir):
227 | os.makedirs(tardir)
228 | try:
229 | all_files = os.listdir('./')
230 | pyfile_list = [file for file in all_files if file.endswith(".py")]
231 | shutil.copytree(r'./lib', os.path.join(tardir, r'lib'))
232 | shutil.copytree(r'./data_prepare', os.path.join(tardir, r'data_prepare'))
233 | shutil.copytree(r'./datasets', os.path.join(tardir, r'datasets'))
234 | for pyfile in pyfile_list:
235 | shutil.copy(pyfile, os.path.join(tardir, pyfile))
236 |     except Exception: # best-effort code snapshot; ignore copy failures
237 |         pass
238 |
239 | def compute_seg_results(args, all_labels, all_preds):
240 | '''Unsupervised, Match pred to gt'''
241 | sem_num = args.semantic_class
242 | mask = (all_labels >= 0) & (all_labels < sem_num)
243 | histogram = np.bincount(sem_num * all_labels[mask] + all_preds[mask], minlength=sem_num ** 2).reshape(sem_num, sem_num)
244 | '''Hungarian Matching'''
245 | m = linear_assignment(histogram.max() - histogram)
246 | o_Acc = histogram[m[:, 0], m[:, 1]].sum() / histogram.sum()*100.
247 | m_Acc = np.mean(histogram[m[:, 0], m[:, 1]] / histogram.sum(1))*100
248 | hist_new = np.zeros((sem_num, sem_num))
249 | for idx in range(sem_num):
250 | hist_new[:, idx] = histogram[:, m[idx, 1]]
251 | '''Final Metrics'''
252 | tp = np.diag(hist_new)
253 | fp = np.sum(hist_new, 0) - tp
254 | fn = np.sum(hist_new, 1) - tp
255 | IoUs = tp / (tp + fp + fn + 1e-8)
256 | m_IoU = np.nanmean(IoUs)
257 | s = '| mIoU {:5.2f} | '.format(100 * m_IoU)
258 | for IoU in IoUs:
259 | s += '{:5.2f} '.format(100 * IoU)
260 |
261 | return o_Acc, m_Acc, s
262 |
263 | def write_list(file_path, contents):
264 | with open(file_path, 'w') as file:
265 | for content in contents:
266 | file.write(content)
267 |
--------------------------------------------------------------------------------
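A note on the evaluation utilities above: because the clustering is unsupervised, `compute_seg_results` first matches predicted cluster ids to ground-truth classes via Hungarian matching and only then computes oAcc/mAcc/mIoU. Below is a minimal self-contained sketch of that matching step, using the maintained `scipy.optimize.linear_sum_assignment` in place of the pinned `sklearn.utils.linear_assignment_`; the class count and arrays are illustrative:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

sem_num = 4                                    # illustrative class count
labels = np.random.randint(0, sem_num, 1000)   # ground-truth class ids
preds = np.random.randint(0, sem_num, 1000)    # unsupervised cluster ids

# Confusion histogram: rows = gt class, cols = predicted cluster.
hist = np.bincount(sem_num * labels + preds,
                   minlength=sem_num ** 2).reshape(sem_num, sem_num)

# Hungarian matching maximizes the matched mass (scipy minimizes, so negate).
row_ind, col_ind = linear_sum_assignment(hist.max() - hist)
o_acc = hist[row_ind, col_ind].sum() / hist.sum() * 100.
print(f"oAcc after matching: {o_acc:.2f}%")
```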
/pointdc_mk/models/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 |
4 | try:
5 | from .networks import *
6 | from .res16unet import *
7 | from .resunet import *
8 |
9 | _custom_models = sys.modules[__name__]
10 |
11 | def initialize_minkowski_unet(
12 | model_name, in_channels, out_channels, D=3, conv1_kernel_size=3, dilations=[1, 1, 1, 1], **kwargs
13 | ):
14 | net_cls = getattr(_custom_models, model_name)
15 | return net_cls(
16 |             in_channels=in_channels, out_channels=out_channels, D=D, conv1_kernel_size=conv1_kernel_size, dilations=dilations, **kwargs
17 | )
18 |
19 |
20 | except Exception:
21 | import logging
22 |
23 | log = logging.getLogger(__name__)
24 | log.warning("Could not load Minkowski Engine, please check that it is installed correctly")
25 |
--------------------------------------------------------------------------------
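For reference, a hedged usage sketch of the factory above, assuming it is run from the `pointdc_mk` root; the model name and channel sizes are illustrative and must name a class exported by `networks`, `res16unet`, or `resunet`:

```python
from models import initialize_minkowski_unet

# Build a backbone by class name; the sizes here are examples only.
model = initialize_minkowski_unet(
    'MinkUNet34C',     # any exported model class
    in_channels=3,     # e.g. RGB features per point
    out_channels=20,   # output feature/class dimension
)
```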
/pointdc_mk/models/api_modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import MinkowskiEngine as ME
3 | import sys
4 |
5 | from .common import NormType, get_norm
6 | from torch_points3d.core.common_modules import Seq, Identity
7 |
8 |
9 | class ResBlock(ME.MinkowskiNetwork):
10 | """
11 | Basic ResNet type block
12 |
13 | Parameters
14 | ----------
15 | input_nc:
16 | Number of input channels
17 | output_nc:
18 |         Number of output channels
19 |     convolution:
20 |         Either MinkowskiConvolution or MinkowskiConvolutionTranspose
21 | dimension:
22 | Dimension of the spatial grid
23 | """
24 |
25 | def __init__(self, input_nc, output_nc, convolution, dimension=3):
26 | ME.MinkowskiNetwork.__init__(self, dimension)
27 | self.block = (
28 | Seq()
29 | .append(
30 | convolution(
31 | in_channels=input_nc,
32 | out_channels=output_nc,
33 | kernel_size=3,
34 | stride=1,
35 | dilation=1,
36 | bias=False,
37 | dimension=dimension,
38 | )
39 | )
40 | .append(ME.MinkowskiBatchNorm(output_nc))
41 | .append(ME.MinkowskiReLU())
42 | .append(
43 | convolution(
44 | in_channels=output_nc,
45 | out_channels=output_nc,
46 | kernel_size=3,
47 | stride=1,
48 | dilation=1,
49 | bias=False,
50 | dimension=dimension,
51 | )
52 | )
53 | .append(ME.MinkowskiBatchNorm(output_nc))
54 | .append(ME.MinkowskiReLU())
55 | )
56 |
57 | if input_nc != output_nc:
58 | self.downsample = (
59 | Seq()
60 | .append(
61 | convolution(
62 | in_channels=input_nc,
63 | out_channels=output_nc,
64 | kernel_size=1,
65 | stride=1,
66 | dilation=1,
67 | bias=False,
68 | dimension=dimension,
69 | )
70 | )
71 | .append(ME.MinkowskiBatchNorm(output_nc))
72 | )
73 | else:
74 | self.downsample = None
75 |
76 | def forward(self, x):
77 | out = self.block(x)
78 | if self.downsample:
79 | out += self.downsample(x)
80 | else:
81 | out += x
82 | return out
83 |
84 |
85 | class BottleneckBlock(ME.MinkowskiNetwork):
86 | """
87 | Bottleneck block with residual
88 | """
89 |
90 | def __init__(self, input_nc, output_nc, convolution, dimension=3, reduction=4):
91 | self.block = (
92 | Seq()
93 | .append(
94 | convolution(
95 | in_channels=input_nc,
96 | out_channels=output_nc // reduction,
97 | kernel_size=1,
98 | stride=1,
99 | dilation=1,
100 | bias=False,
101 | dimension=dimension,
102 | )
103 | )
104 | .append(ME.MinkowskiBatchNorm(output_nc // reduction))
105 | .append(ME.MinkowskiReLU())
106 | .append(
107 | convolution(
108 | output_nc // reduction,
109 | output_nc // reduction,
110 | kernel_size=3,
111 | stride=1,
112 | dilation=1,
113 | bias=False,
114 | dimension=dimension,
115 | )
116 | )
117 | .append(ME.MinkowskiBatchNorm(output_nc // reduction))
118 | .append(ME.MinkowskiReLU())
119 | .append(
120 | convolution(
121 | output_nc // reduction,
122 | output_nc,
123 | kernel_size=1,
124 | stride=1,
125 | dilation=1,
126 | bias=False,
127 | dimension=dimension,
128 | )
129 | )
130 | .append(ME.MinkowskiBatchNorm(output_nc))
131 | .append(ME.MinkowskiReLU())
132 | )
133 |
134 | if input_nc != output_nc:
135 | self.downsample = (
136 | Seq()
137 | .append(
138 | convolution(
139 | in_channels=input_nc,
140 | out_channels=output_nc,
141 | kernel_size=1,
142 | stride=1,
143 | dilation=1,
144 | bias=False,
145 | dimension=dimension,
146 | )
147 | )
148 | .append(ME.MinkowskiBatchNorm(output_nc))
149 | )
150 | else:
151 | self.downsample = None
152 |
153 | def forward(self, x):
154 | out = self.block(x)
155 | if self.downsample:
156 | out += self.downsample(x)
157 | else:
158 | out += x
159 | return out
160 |
161 |
162 | class SELayer(torch.nn.Module):
163 | """
164 | Squeeze and excite layer
165 |
166 | Parameters
167 | ----------
168 | channel:
169 | size of the input and output
170 | reduction:
171 | magnitude of the compression
172 |     dimension:
173 | dimension of the kernels
174 | """
175 |
176 | def __init__(self, channel, reduction=16, dimension=3):
177 | # Global coords does not require coords_key
178 | super(SELayer, self).__init__()
179 | self.fc = torch.nn.Sequential(
180 | ME.MinkowskiLinear(channel, channel // reduction),
181 | ME.MinkowskiReLU(),
182 | ME.MinkowskiLinear(channel // reduction, channel),
183 | ME.MinkowskiSigmoid(),
184 | )
185 | self.pooling = ME.MinkowskiGlobalPooling()
186 | self.broadcast_mul = ME.MinkowskiBroadcastMultiplication()
187 |
188 | def forward(self, x):
189 | y = self.pooling(x)
190 | y = self.fc(y)
191 | return self.broadcast_mul(x, y)
192 |
193 |
194 | class SEBlock(ResBlock):
195 | """
196 | ResBlock with SE layer
197 | """
198 |
199 | def __init__(self, input_nc, output_nc, convolution, dimension=3, reduction=16):
200 |         super().__init__(input_nc, output_nc, convolution, dimension=dimension)
201 | self.SE = SELayer(output_nc, reduction=reduction, dimension=dimension)
202 |
203 | def forward(self, x):
204 | out = self.block(x)
205 | out = self.SE(out)
206 | if self.downsample:
207 | out += self.downsample(x)
208 | else:
209 | out += x
210 | return out
211 |
212 |
213 | class SEBottleneckBlock(BottleneckBlock):
214 | """
215 | BottleneckBlock with SE layer
216 | """
217 |
218 | def __init__(self, input_nc, output_nc, convolution, dimension=3, reduction=16):
219 |         super().__init__(input_nc, output_nc, convolution, dimension=dimension, reduction=4)
220 | self.SE = SELayer(output_nc, reduction=reduction, dimension=dimension)
221 |
222 | def forward(self, x):
223 | out = self.block(x)
224 | out = self.SE(out)
225 | if self.downsample:
226 | out += self.downsample(x)
227 | else:
228 | out += x
229 | return out
230 |
231 |
232 | _res_blocks = sys.modules[__name__]
233 |
234 |
235 | class ResNetDown(ME.MinkowskiNetwork):
236 | """
237 | Resnet block that looks like
238 |
239 | in --- strided conv ---- Block ---- sum --[... N times]
240 | | |
241 | |-- 1x1 - BN --|
242 | """
243 |
244 | CONVOLUTION = ME.MinkowskiConvolution
245 |
246 | def __init__(
247 | self, down_conv_nn=[], kernel_size=2, dilation=1, dimension=3, stride=2, N=1, block="ResBlock", **kwargs
248 | ):
249 | block = getattr(_res_blocks, block)
250 | ME.MinkowskiNetwork.__init__(self, dimension)
251 | if stride > 1:
252 | conv1_output = down_conv_nn[0]
253 | else:
254 | conv1_output = down_conv_nn[1]
255 |
256 | self.conv_in = (
257 | Seq()
258 | .append(
259 | self.CONVOLUTION(
260 | in_channels=down_conv_nn[0],
261 | out_channels=conv1_output,
262 | kernel_size=kernel_size,
263 | stride=stride,
264 | dilation=dilation,
265 | bias=False,
266 | dimension=dimension,
267 | )
268 | )
269 | .append(ME.MinkowskiBatchNorm(conv1_output))
270 | .append(ME.MinkowskiReLU())
271 | )
272 |
273 | if N > 0:
274 | self.blocks = Seq()
275 | for _ in range(N):
276 | self.blocks.append(block(conv1_output, down_conv_nn[1], self.CONVOLUTION, dimension=dimension))
277 | conv1_output = down_conv_nn[1]
278 | else:
279 | self.blocks = None
280 |
281 | def forward(self, x):
282 | out = self.conv_in(x)
283 | if self.blocks:
284 | out = self.blocks(out)
285 | return out
286 |
287 |
288 | class ResNetUp(ResNetDown):
289 | """
290 | Same as Down conv but for the Decoder
291 | """
292 |
293 | CONVOLUTION = ME.MinkowskiConvolutionTranspose
294 |
295 | def __init__(self, up_conv_nn=[], kernel_size=2, dilation=1, dimension=3, stride=2, N=1, **kwargs):
296 | super().__init__(
297 | down_conv_nn=up_conv_nn,
298 | kernel_size=kernel_size,
299 | dilation=dilation,
300 | dimension=dimension,
301 | stride=stride,
302 | N=N,
303 | **kwargs
304 | )
305 |
306 | def forward(self, x, skip):
307 | if skip is not None:
308 | inp = ME.cat(x, skip)
309 | else:
310 | inp = x
311 | return super().forward(inp)
312 |
--------------------------------------------------------------------------------
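A minimal sketch of how `ResNetDown` and `ResNetUp` pair up into one encoder-decoder step; the channel widths and toy coordinates are illustrative:

```python
import torch
import MinkowskiEngine as ME
from models.api_modules import ResNetDown, ResNetUp

down = ResNetDown(down_conv_nn=[3, 32], stride=2)   # strided conv + ResBlock
up = ResNetUp(up_conv_nn=[32, 32], stride=2)        # transposed conv + ResBlock

# Toy batched input: one point cloud with 100 random voxel coordinates.
coords = ME.utils.batched_coordinates([torch.randint(0, 50, (100, 3))])
x = ME.SparseTensor(torch.ones(100, 3), coords)

enc = down(x)              # tensor_stride 1 -> 2
dec = up(enc, skip=None)   # tensor_stride 2 -> 1; pass a skip tensor to concatenate
```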
/pointdc_mk/models/common.py:
--------------------------------------------------------------------------------
1 | import collections.abc
2 | from enum import Enum
3 | import torch.nn as nn
4 |
5 | import MinkowskiEngine as ME
6 |
7 |
8 | class NormType(Enum):
9 | BATCH_NORM = 0
10 | INSTANCE_NORM = 1
11 | INSTANCE_BATCH_NORM = 2
12 |
13 |
14 | def get_norm(norm_type, n_channels, D, bn_momentum=0.1):
15 | if norm_type == NormType.BATCH_NORM:
16 | return ME.MinkowskiBatchNorm(n_channels, momentum=bn_momentum)
17 | elif norm_type == NormType.INSTANCE_NORM:
18 | return ME.MinkowskiInstanceNorm(n_channels)
19 | elif norm_type == NormType.INSTANCE_BATCH_NORM:
20 | return nn.Sequential(
21 | ME.MinkowskiInstanceNorm(n_channels), ME.MinkowskiBatchNorm(n_channels, momentum=bn_momentum)
22 | )
23 | else:
24 | raise ValueError(f"Norm type: {norm_type} not supported")
25 |
26 |
27 | def get_nonlinearity(non_type):
28 | if non_type == 'ReLU':
29 | return ME.MinkowskiReLU()
30 | elif non_type == 'ELU':
31 | # return ME.MinkowskiInstanceNorm(num_feats, dimension=dimension)
32 | return ME.MinkowskiELU()
33 | else:
34 | raise ValueError(f'Type {non_type}, not defined')
35 |
36 |
37 | class ConvType(Enum):
38 | """
39 | Define the kernel region type
40 | """
41 |
42 | HYPERCUBE = 0, "HYPERCUBE"
43 | SPATIAL_HYPERCUBE = 1, "SPATIAL_HYPERCUBE"
44 | SPATIO_TEMPORAL_HYPERCUBE = 2, "SPATIO_TEMPORAL_HYPERCUBE"
45 | HYPERCROSS = 3, "HYPERCROSS"
46 | SPATIAL_HYPERCROSS = 4, "SPATIAL_HYPERCROSS"
47 | SPATIO_TEMPORAL_HYPERCROSS = 5, "SPATIO_TEMPORAL_HYPERCROSS"
48 |     SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS = 6, "SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS"
49 |
50 | def __new__(cls, value, name):
51 | member = object.__new__(cls)
52 | member._value_ = value
53 | member.fullname = name
54 | return member
55 |
56 | def __int__(self):
57 | return self.value
58 |
59 |
60 | # Convert a ConvType value to the corresponding RegionType
61 | conv_to_region_type = {
62 | # kernel_size = [k, k, k, 1]
63 | ConvType.HYPERCUBE: ME.RegionType.HYPER_CUBE,
64 | ConvType.SPATIAL_HYPERCUBE: ME.RegionType.HYPER_CUBE,
65 | ConvType.SPATIO_TEMPORAL_HYPERCUBE: ME.RegionType.HYPER_CUBE,
66 | ConvType.HYPERCROSS: ME.RegionType.HYPER_CROSS,
67 | ConvType.SPATIAL_HYPERCROSS: ME.RegionType.HYPER_CROSS,
68 | ConvType.SPATIO_TEMPORAL_HYPERCROSS: ME.RegionType.HYPER_CROSS,
69 | ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS: ME.RegionType.CUSTOM,
70 | }
71 |
72 | int_to_region_type = {0: ME.RegionType.HYPER_CUBE, 1: ME.RegionType.HYPER_CROSS, 2: ME.RegionType.CUSTOM}
73 |
74 |
75 | def convert_region_type(region_type):
76 | """
77 | Convert the integer region_type to the corresponding RegionType enum object.
78 | """
79 | return int_to_region_type[region_type]
80 |
81 |
82 | def convert_conv_type(conv_type, kernel_size, D):
83 | assert isinstance(conv_type, ConvType), "conv_type must be of ConvType"
84 | region_type = conv_to_region_type[conv_type]
85 | axis_types = None
86 | if conv_type == ConvType.SPATIAL_HYPERCUBE:
87 | # No temporal convolution
88 |         if isinstance(kernel_size, collections.abc.Sequence):
89 | kernel_size = kernel_size[:3]
90 | else:
91 | kernel_size = [kernel_size,] * 3
92 | if D == 4:
93 | kernel_size.append(1)
94 | elif conv_type == ConvType.SPATIO_TEMPORAL_HYPERCUBE:
95 | # conv_type conversion already handled
96 | assert D == 4
97 | elif conv_type == ConvType.HYPERCUBE:
98 | # conv_type conversion already handled
99 | pass
100 | elif conv_type == ConvType.SPATIAL_HYPERCROSS:
101 |         if isinstance(kernel_size, collections.abc.Sequence):
102 | kernel_size = kernel_size[:3]
103 | else:
104 | kernel_size = [kernel_size,] * 3
105 | if D == 4:
106 | kernel_size.append(1)
107 | elif conv_type == ConvType.HYPERCROSS:
108 | # conv_type conversion already handled
109 | pass
110 | elif conv_type == ConvType.SPATIO_TEMPORAL_HYPERCROSS:
111 | # conv_type conversion already handled
112 | assert D == 4
113 | elif conv_type == ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS:
114 | # Define the CUBIC conv kernel for spatial dims and CROSS conv for temp dim
115 | if D < 4:
116 | region_type = ME.RegionType.HYPER_CUBE
117 | else:
118 | axis_types = [ME.RegionType.HYPER_CUBE,] * 3
119 | if D == 4:
120 | axis_types.append(ME.RegionType.HYPER_CROSS)
121 | return region_type, axis_types, kernel_size
122 |
123 |
124 | def conv(in_planes, out_planes, kernel_size, stride=1, dilation=1, bias=False, conv_type=ConvType.HYPERCUBE, D=-1):
125 | assert D > 0, "Dimension must be a positive integer"
126 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D)
127 | kernel_generator = ME.KernelGenerator(
128 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D
129 | )
130 |
131 | return ME.MinkowskiConvolution(
132 | in_channels=in_planes,
133 | out_channels=out_planes,
134 | kernel_size=kernel_size,
135 | stride=stride,
136 | dilation=dilation,
137 | bias=bias,
138 | kernel_generator=kernel_generator,
139 | dimension=D,
140 | )
141 |
142 |
143 | def conv_tr(
144 | in_planes, out_planes, kernel_size, upsample_stride=1, dilation=1, bias=False, conv_type=ConvType.HYPERCUBE, D=-1
145 | ):
146 | assert D > 0, "Dimension must be a positive integer"
147 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D)
148 | kernel_generator = ME.KernelGenerator(
149 | kernel_size, upsample_stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D
150 | )
151 |
152 | return ME.MinkowskiConvolutionTranspose(
153 | in_channels=in_planes,
154 | out_channels=out_planes,
155 | kernel_size=kernel_size,
156 | stride=upsample_stride,
157 | dilation=dilation,
158 | bias=bias,
159 | kernel_generator=kernel_generator,
160 | dimension=D,
161 | )
162 |
163 |
164 | def avg_pool(kernel_size, stride=1, dilation=1, conv_type=ConvType.HYPERCUBE, in_coords_key=None, D=-1):
165 | assert D > 0, "Dimension must be a positive integer"
166 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D)
167 | kernel_generator = ME.KernelGenerator(
168 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D
169 | )
170 |
171 | return ME.MinkowskiAvgPooling(
172 | kernel_size=kernel_size, stride=stride, dilation=dilation, kernel_generator=kernel_generator, dimension=D
173 | )
174 |
175 |
176 | def avg_unpool(kernel_size, stride=1, dilation=1, conv_type=ConvType.HYPERCUBE, D=-1):
177 | assert D > 0, "Dimension must be a positive integer"
178 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D)
179 | kernel_generator = ME.KernelGenerator(
180 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D
181 | )
182 |
183 | return ME.MinkowskiAvgUnpooling(
184 | kernel_size=kernel_size, stride=stride, dilation=dilation, kernel_generator=kernel_generator, dimension=D
185 | )
186 |
187 |
188 | def sum_pool(kernel_size, stride=1, dilation=1, conv_type=ConvType.HYPERCUBE, D=-1):
189 | assert D > 0, "Dimension must be a positive integer"
190 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D)
191 | kernel_generator = ME.KernelGenerator(
192 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D
193 | )
194 |
195 | return ME.MinkowskiSumPooling(
196 | kernel_size=kernel_size, stride=stride, dilation=dilation, kernel_generator=kernel_generator, dimension=D
197 | )
198 |
--------------------------------------------------------------------------------
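For context, a short sketch of the `conv`/`conv_tr` factories above: `conv_type` selects the kernel region (cubic vs. cross-shaped neighbourhoods); the channel sizes are illustrative:

```python
from models.common import ConvType, conv, conv_tr

# Dense 3x3x3 cubic kernel, the default used throughout the backbones.
c_cube = conv(32, 64, kernel_size=3, stride=2, conv_type=ConvType.HYPERCUBE, D=3)

# Cross-shaped kernel: only axis-aligned neighbours, fewer parameters.
c_cross = conv(32, 64, kernel_size=3, conv_type=ConvType.HYPERCROSS, D=3)

# Matching transposed convolution for a decoder path.
c_up = conv_tr(64, 32, kernel_size=2, upsample_stride=2, D=3)
```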
/pointdc_mk/models/fpn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import MinkowskiEngine as ME
3 | from MinkowskiEngine import MinkowskiNetwork
4 | from MinkowskiEngine import MinkowskiReLU, MinkowskiUnion
5 | import torch.nn.functional as F
6 |
7 | from .common import ConvType, NormType, conv, conv_tr, get_norm, sum_pool
8 |
9 |
10 | class BasicBlockBase(nn.Module):
11 | expansion = 1
12 | NORM_TYPE = NormType.BATCH_NORM
13 |
14 | def __init__(
15 | self,
16 | inplanes,
17 | planes,
18 | stride=1,
19 | dilation=1,
20 | downsample=None,
21 | conv_type=ConvType.HYPERCUBE,
22 | bn_momentum=0.1,
23 | D=3,
24 | ):
25 | super(BasicBlockBase, self).__init__()
26 |
27 | self.conv1 = conv(inplanes, planes, kernel_size=3, stride=stride, dilation=dilation, conv_type=conv_type, D=D)
28 | self.norm1 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum)
29 | self.conv2 = conv(
30 | planes, planes, kernel_size=3, stride=1, dilation=dilation, bias=False, conv_type=conv_type, D=D
31 | )
32 | self.norm2 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum)
33 | self.relu = MinkowskiReLU(inplace=True)
34 | self.downsample = downsample
35 |
36 | def forward(self, x):
37 | residual = x
38 | out = self.conv1(x)
39 | out = self.norm1(out)
40 | out = self.relu(out)
41 |
42 | out = self.conv2(out)
43 | out = self.norm2(out)
44 |
45 | if self.downsample is not None:
46 | residual = self.downsample(x)
47 |
48 | out += residual
49 | out = self.relu(out)
50 |
51 | return out
52 |
53 |
54 | class BasicBlock(BasicBlockBase):
55 | NORM_TYPE = NormType.BATCH_NORM
56 |
57 |
58 | class BasicBlockIN(BasicBlockBase):
59 | NORM_TYPE = NormType.INSTANCE_NORM
60 |
61 |
62 | class BasicBlockINBN(BasicBlockBase):
63 | NORM_TYPE = NormType.INSTANCE_BATCH_NORM
64 |
65 |
66 | class BottleneckBase(nn.Module):
67 | expansion = 4
68 | NORM_TYPE = NormType.BATCH_NORM
69 |
70 | def __init__(
71 | self,
72 | inplanes,
73 | planes,
74 | stride=1,
75 | dilation=1,
76 | downsample=None,
77 | conv_type=ConvType.HYPERCUBE,
78 | bn_momentum=0.1,
79 | D=3,
80 | ):
81 | super(BottleneckBase, self).__init__()
82 | self.conv1 = conv(inplanes, planes, kernel_size=1, D=D)
83 | self.norm1 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum)
84 |
85 | self.conv2 = conv(planes, planes, kernel_size=3, stride=stride, dilation=dilation, conv_type=conv_type, D=D)
86 | self.norm2 = get_norm(self.NORM_TYPE, planes, D, bn_momentum=bn_momentum)
87 |
88 | self.conv3 = conv(planes, planes * self.expansion, kernel_size=1, D=D)
89 | self.norm3 = get_norm(self.NORM_TYPE, planes * self.expansion, D, bn_momentum=bn_momentum)
90 |
91 | self.relu = MinkowskiReLU(inplace=True)
92 | self.downsample = downsample
93 |
94 | def forward(self, x):
95 | residual = x
96 |
97 | out = self.conv1(x)
98 | out = self.norm1(out)
99 | out = self.relu(out)
100 |
101 | out = self.conv2(out)
102 | out = self.norm2(out)
103 | out = self.relu(out)
104 |
105 | out = self.conv3(out)
106 | out = self.norm3(out)
107 |
108 | if self.downsample is not None:
109 | residual = self.downsample(x)
110 |
111 | out += residual
112 | out = self.relu(out)
113 |
114 | return out
115 |
116 |
117 | class Bottleneck(BottleneckBase):
118 | NORM_TYPE = NormType.BATCH_NORM
119 |
120 |
121 | class BottleneckIN(BottleneckBase):
122 | NORM_TYPE = NormType.INSTANCE_NORM
123 |
124 |
125 | class BottleneckINBN(BottleneckBase):
126 | NORM_TYPE = NormType.INSTANCE_BATCH_NORM
127 |
128 |
129 | class ResNetBase(MinkowskiNetwork):
130 | BLOCK = None
131 | LAYERS = ()
132 | INIT_DIM = 64
133 | PLANES = (64, 128, 256, 512)
134 | OUT_PIXEL_DIST = 32
135 | HAS_LAST_BLOCK = False
136 | CONV_TYPE = ConvType.HYPERCUBE
137 |
138 | def __init__(self, in_channels, out_channels, D, conv1_kernel_size=3, dilations=[1, 1, 1, 1], **kwargs):
139 | super(ResNetBase, self).__init__(D)
140 | self.in_channels = in_channels
141 | self.out_channels = out_channels
142 | self.conv1_kernel_size = conv1_kernel_size
143 | self.dilations = dilations
144 | assert self.BLOCK is not None
145 | assert self.OUT_PIXEL_DIST > 0
146 |
147 | self.network_initialization(in_channels, out_channels, D)
148 | self.weight_initialization()
149 |
150 | def network_initialization(self, in_channels, out_channels, D):
151 | def space_n_time_m(n, m):
152 | return n if D == 3 else [n, n, n, m]
153 |
154 | if D == 4:
155 | self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1)
156 |
157 | dilations = self.dilations
158 | bn_momentum = 1
159 | self.inplanes = self.INIT_DIM
160 | self.conv1 = conv(
161 | in_channels, self.inplanes, kernel_size=space_n_time_m(self.conv1_kernel_size, 1), stride=1, D=D
162 | )
163 |
164 | self.bn1 = get_norm(NormType.BATCH_NORM, self.inplanes, D=self.D, bn_momentum=bn_momentum)
165 | self.relu = ME.MinkowskiReLU(inplace=True)
166 | self.pool = sum_pool(kernel_size=space_n_time_m(2, 1), stride=space_n_time_m(2, 1), D=D)
167 |
168 | self.layer1 = self._make_layer(
169 | self.BLOCK,
170 | self.PLANES[0],
171 | self.LAYERS[0],
172 | stride=space_n_time_m(2, 1),
173 | dilation=space_n_time_m(dilations[0], 1),
174 | )
175 | self.layer2 = self._make_layer(
176 | self.BLOCK,
177 | self.PLANES[1],
178 | self.LAYERS[1],
179 | stride=space_n_time_m(2, 1),
180 | dilation=space_n_time_m(dilations[1], 1),
181 | )
182 | self.layer3 = self._make_layer(
183 | self.BLOCK,
184 | self.PLANES[2],
185 | self.LAYERS[2],
186 | stride=space_n_time_m(2, 1),
187 | dilation=space_n_time_m(dilations[2], 1),
188 | )
189 | self.layer4 = self._make_layer(
190 | self.BLOCK,
191 | self.PLANES[3],
192 | self.LAYERS[3],
193 | stride=space_n_time_m(2, 1),
194 | dilation=space_n_time_m(dilations[3], 1),
195 | )
196 |
197 | self.final = conv(self.PLANES[3] * self.BLOCK.expansion, out_channels, kernel_size=1, bias=True, D=D)
198 |
199 | def weight_initialization(self):
200 | for m in self.modules():
201 | if isinstance(m, ME.MinkowskiBatchNorm):
202 | nn.init.constant_(m.bn.weight, 1)
203 | nn.init.constant_(m.bn.bias, 0)
204 |
205 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_type=NormType.BATCH_NORM, bn_momentum=0.1):
206 | downsample = None
207 | if stride != 1 or self.inplanes != planes * block.expansion:
208 | downsample = nn.Sequential(
209 | conv(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False, D=self.D),
210 | get_norm(norm_type, planes * block.expansion, D=self.D, bn_momentum=bn_momentum),
211 | )
212 | layers = []
213 | layers.append(
214 | block(
215 | self.inplanes,
216 | planes,
217 | stride=stride,
218 | dilation=dilation,
219 | downsample=downsample,
220 | conv_type=self.CONV_TYPE,
221 | D=self.D,
222 | )
223 | )
224 | self.inplanes = planes * block.expansion
225 | for i in range(1, blocks):
226 | layers.append(block(self.inplanes, planes, stride=1, dilation=dilation, conv_type=self.CONV_TYPE, D=self.D))
227 |
228 | return nn.Sequential(*layers)
229 |
230 | def forward(self, x):
231 | x = self.conv1(x)
232 | x = self.bn1(x)
233 | x = self.relu(x)
234 | x = self.pool(x)
235 |
236 | x = self.layer1(x)
237 | x = self.layer2(x)
238 | x = self.layer3(x)
239 | x = self.layer4(x)
240 |
241 | x = self.final(x)
242 | return x
243 |
244 |
245 | class Res16FPNBase(ResNetBase):
246 | BLOCK = None
247 | PLANES = (32, 64, 128, 256, 256, 256, 256, 256)
248 | DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1)
249 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2)
250 | INIT_DIM = 32
251 | OUT_PIXEL_DIST = 1
252 | NORM_TYPE = NormType.BATCH_NORM
253 | NON_BLOCK_CONV_TYPE = ConvType.SPATIAL_HYPERCUBE
254 | CONV_TYPE = ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS
255 |
256 | # To use the model, must call initialize_coords before forward pass.
257 | # Once data is processed, call clear to reset the model before calling initialize_coords
258 | def __init__(self, in_channels, out_channels, D=3, conv1_kernel_size=5, **kwargs):
259 | super(Res16FPNBase, self).__init__(in_channels, out_channels, D, conv1_kernel_size)
260 |
261 | self.mode = kwargs['mode']
262 |
263 | def network_initialization(self, in_channels, out_channels, D):
264 | # Setup net_metadata
265 | dilations = self.DILATIONS
266 | bn_momentum = 0.02
267 |
268 | def space_n_time_m(n, m):
269 | return n if D == 3 else [n, n, n, m]
270 |
271 | if D == 4:
272 | self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1)
273 |
274 | self.inplanes = self.INIT_DIM
275 | self.conv0p1s1 = conv(
276 | in_channels,
277 | self.inplanes,
278 | kernel_size=space_n_time_m(self.conv1_kernel_size, 1),
279 | stride=1,
280 | dilation=1,
281 | conv_type=self.NON_BLOCK_CONV_TYPE,
282 | D=D,
283 | )
284 |
285 | self.bn0 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
286 |
287 | self.conv1p1s2 = conv(
288 | self.inplanes,
289 | self.inplanes,
290 | kernel_size=space_n_time_m(2, 1),
291 | stride=space_n_time_m(2, 1),
292 | dilation=1,
293 | conv_type=self.NON_BLOCK_CONV_TYPE,
294 | D=D,
295 | )
296 | self.bn1 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
297 | self.block1 = self._make_layer(
298 | self.BLOCK,
299 | self.PLANES[0],
300 | self.LAYERS[0],
301 | dilation=dilations[0],
302 | norm_type=self.NORM_TYPE,
303 | bn_momentum=bn_momentum,
304 | )
305 |
306 | self.conv2p2s2 = conv(
307 | self.inplanes,
308 | self.inplanes,
309 | kernel_size=space_n_time_m(2, 1),
310 | stride=space_n_time_m(2, 1),
311 | dilation=1,
312 | conv_type=self.NON_BLOCK_CONV_TYPE,
313 | D=D,
314 | )
315 | self.bn2 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
316 | self.block2 = self._make_layer(
317 | self.BLOCK,
318 | self.PLANES[1],
319 | self.LAYERS[1],
320 | dilation=dilations[1],
321 | norm_type=self.NORM_TYPE,
322 | bn_momentum=bn_momentum,
323 | )
324 |
325 | self.conv3p4s2 = conv(
326 | self.inplanes,
327 | self.inplanes,
328 | kernel_size=space_n_time_m(2, 1),
329 | stride=space_n_time_m(2, 1),
330 | dilation=1,
331 | conv_type=self.NON_BLOCK_CONV_TYPE,
332 | D=D,
333 | )
334 | self.bn3 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
335 | self.block3 = self._make_layer(
336 | self.BLOCK,
337 | self.PLANES[2],
338 | self.LAYERS[2],
339 | dilation=dilations[2],
340 | norm_type=self.NORM_TYPE,
341 | bn_momentum=bn_momentum,
342 | )
343 |
344 | self.conv4p8s2 = conv(
345 | self.inplanes,
346 | self.inplanes,
347 | kernel_size=space_n_time_m(2, 1),
348 | stride=space_n_time_m(2, 1),
349 | dilation=1,
350 | conv_type=self.NON_BLOCK_CONV_TYPE,
351 | D=D,
352 | )
353 | self.bn4 = get_norm(self.NORM_TYPE, self.inplanes, D, bn_momentum=bn_momentum)
354 | self.block4 = self._make_layer(
355 | self.BLOCK,
356 | self.PLANES[3],
357 | self.LAYERS[3],
358 | dilation=dilations[3],
359 | norm_type=self.NORM_TYPE,
360 | bn_momentum=bn_momentum,
361 | )
362 |
363 | self.delayer1 = ME.MinkowskiLinear(256, 128, bias=False)
364 | self.delayer2 = ME.MinkowskiLinear(128, 128, bias=False)
365 | self.delayer3 = ME.MinkowskiLinear(64, 128, bias=False)
366 | self.delayer4 = ME.MinkowskiLinear(32, 128, bias=False)
367 |
368 | self.relu = MinkowskiReLU(inplace=True)
369 |
370 |
371 |     def forward(self, x):
372 | y = x.sparse()
373 | out = self.conv0p1s1(y)
374 | out = self.bn0(out)
375 |         out_p1 = self.relu(out)  # 32 channels
376 |
377 | out = self.conv1p1s2(out_p1)
378 | out = self.bn1(out)
379 | out = self.relu(out)
380 |         out_b1p2 = self.block1(out)  # 32 channels
381 |
382 | out = self.conv2p2s2(out_b1p2)
383 | out = self.bn2(out)
384 | out = self.relu(out)
385 |         out_b2p4 = self.block2(out)  # 64 channels
386 |
387 | out = self.conv3p4s2(out_b2p4)
388 | out = self.bn3(out)
389 | out = self.relu(out)
390 |         out_b3p8 = self.block3(out)  # 128 channels
391 |
392 | out = self.conv4p8s2(out_b3p8)
393 | out = self.bn4(out)
394 | out = self.relu(out)
395 |         out = self.block4(out)  # 256 channels
396 |
397 |
398 | out = self.delayer1(out)
399 | out = out.interpolate(x)
400 |
401 | dout_b3p8 = self.delayer2(out_b3p8)
402 | dout_b3p8 = dout_b3p8.interpolate(x)
403 |
404 | dout_b2p4 = self.delayer3(out_b2p4)
405 | dout_b2p4 = dout_b2p4.interpolate(x)
406 |
407 | dout_b1p2 = self.delayer4(out_b1p2)
408 | dout_b1p2 = dout_b1p2.interpolate(x)
409 |
410 | if self.mode == 'distill':
411 | output = MinkowskiUnion()(out.sparse(), dout_b3p8.sparse(), dout_b2p4.sparse(), dout_b1p2.sparse()).slice(x)
412 | else:
413 |             output = out.F + dout_b3p8.F + dout_b2p4.F + dout_b1p2.F # summed per-point features, shape [N, 128]
414 | return output
415 |
416 |
417 |
418 | class Res16FPN14(Res16FPNBase):
419 | BLOCK = BasicBlock
420 | LAYERS = (1, 1, 1, 1, 1, 1, 1, 1)
421 |
422 |
423 | class Res16FPN18(Res16FPNBase):
424 | BLOCK = BasicBlock
425 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2)
426 |
427 |
428 | class Res16FPN34(Res16FPNBase):
429 | BLOCK = BasicBlock
430 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2)
431 |
432 |
433 | class Res16FPN50(Res16FPNBase):
434 | BLOCK = Bottleneck
435 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2)
436 |
437 |
438 | class Res16UFPN101(Res16FPNBase):
439 | BLOCK = Bottleneck
440 | LAYERS = (2, 3, 4, 23, 2, 2, 2, 2)
441 |
442 |
443 | class Res16FPN14A(Res16FPN14):
444 | PLANES = (32, 64, 128, 256, 128, 128, 96, 96)
445 |
446 |
447 | class Res16FPN14A2(Res16FPN14A):
448 | LAYERS = (1, 1, 1, 1, 2, 2, 2, 2)
449 |
450 |
451 | class Res16FPN14B(Res16FPN14):
452 | PLANES = (32, 64, 128, 256, 128, 128, 128, 128)
453 |
454 |
455 | class Res16FPN14B2(Res16FPN14B):
456 | LAYERS = (1, 1, 1, 1, 2, 2, 2, 2)
457 |
458 |
459 | class Res16FPN14B3(Res16FPN14B):
460 | LAYERS = (2, 2, 2, 2, 1, 1, 1, 1)
461 |
462 |
463 | class Res16FPN14C(Res16FPN14):
464 | PLANES = (32, 64, 128, 256, 192, 192, 128, 128)
465 |
466 |
467 | class Res16FPN14D(Res16FPN14):
468 | PLANES = (32, 64, 128, 256, 384, 384, 384, 384)
469 |
470 |
471 | class Res16FPN18A(Res16FPN18):
472 | PLANES = (32, 64, 128, 256, 128, 128, 96, 96)
473 |
474 |
475 | class Res16FPN18B(Res16FPN18):
476 | PLANES = (32, 64, 128, 256, 128, 128, 128, 128)
477 |
478 |
479 | class Res16FPN18D(Res16FPN18):
480 | PLANES = (32, 64, 128, 256, 384, 384, 384, 384)
481 |
482 |
483 | class Res16FPN32B(Res16FPN34):
484 | PLANES = (32, 64, 128, 256, 256, 64, 64, 64)
485 |
486 |
487 | class Res16FPN34A(Res16FPN34):
488 | PLANES = (32, 64, 128, 256, 256, 128, 64, 64)
489 |
490 |
491 | class Res16FPN34B(Res16FPN34):
492 | PLANES = (32, 64, 128, 256, 256, 128, 64, 32)
493 |
494 |
495 | class Res16FPN34C(Res16FPN34):
496 | PLANES = (32, 64, 128, 256, 256, 128, 96, 96)
497 |
498 |
499 |
500 | def get_block(norm_type, inplanes, planes, stride=1, dilation=1, downsample=None, bn_momentum=0.1, D=3):
501 | if norm_type == NormType.BATCH_NORM:
502 | return BasicBlock(
503 | inplanes=inplanes,
504 | planes=planes,
505 | stride=stride,
506 | dilation=dilation,
507 | downsample=downsample,
508 | bn_momentum=bn_momentum,
509 | D=D,
510 | )
511 | elif norm_type == NormType.INSTANCE_NORM:
512 |         return BasicBlockIN(inplanes, planes, stride=stride, dilation=dilation, downsample=downsample, bn_momentum=bn_momentum, D=D)
513 | else:
514 | raise ValueError(f"Type {norm_type}, not defined")
515 |
--------------------------------------------------------------------------------
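A hedged sketch of driving `Res16FPN18` the way `init_get_sp_feature` in `lib/utils_s3dis.py` does: wrap the points in a `TensorField`, run the FPN, and (in any mode other than 'distill') get summed 128-d per-point features back. The coordinates and sizes are illustrative, and the import assumes the `pointdc_mk` root is on the path:

```python
import torch
import MinkowskiEngine as ME
from models.fpn import Res16FPN18

net = Res16FPN18(in_channels=3, out_channels=128, conv1_kernel_size=5,
                 mode='cluster')  # anything other than 'distill' sums the FPN levels

coords = ME.utils.batched_coordinates([torch.randint(0, 100, (500, 3))])
feats = torch.rand(500, 3)
field = ME.TensorField(feats, coords.float())

out = net(field)  # torch.Tensor of shape [500, 128], one feature per input point
```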
/pointdc_mk/models/modules.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import MinkowskiEngine as ME
3 | from .common import ConvType, NormType
4 |
5 |
6 | class BasicBlock(nn.Module):
7 | """This module implements a basic residual convolution block using MinkowskiEngine
8 |
9 | Parameters
10 | ----------
11 | inplanes: int
12 | Input dimension
13 | planes: int
14 | Output dimension
15 | dilation: int
16 | Dilation value
17 | downsample: nn.Module
18 | If provided, downsample will be applied on input before doing residual addition
19 | bn_momentum: float
20 | Input dimension
21 | """
22 |
23 | EXPANSION = 1
24 |
25 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, bn_momentum=0.1, dimension=-1):
26 | super(BasicBlock, self).__init__()
27 | assert dimension > 0
28 |
29 | self.conv1 = ME.MinkowskiConvolution(
30 | inplanes, planes, kernel_size=3, stride=stride, dilation=dilation, dimension=dimension
31 | )
32 | self.norm1 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum)
33 | self.conv2 = ME.MinkowskiConvolution(
34 | planes, planes, kernel_size=3, stride=1, dilation=dilation, dimension=dimension
35 | )
36 | self.norm2 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum)
37 | self.relu = ME.MinkowskiReLU(inplace=True)
38 | self.downsample = downsample
39 |
40 | def forward(self, x):
41 | residual = x
42 |
43 | out = self.conv1(x)
44 | out = self.norm1(out)
45 | out = self.relu(out)
46 |
47 | out = self.conv2(out)
48 | out = self.norm2(out)
49 |
50 | if self.downsample is not None:
51 | residual = self.downsample(x)
52 |
53 | out += residual
54 | out = self.relu(out)
55 |
56 | return out
57 |
58 |
59 | class Bottleneck(nn.Module):
60 | EXPANSION = 4
61 |
62 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, bn_momentum=0.1, dimension=-1):
63 | super(Bottleneck, self).__init__()
64 | assert dimension > 0
65 |
66 | self.conv1 = ME.MinkowskiConvolution(inplanes, planes, kernel_size=1, dimension=dimension)
67 | self.norm1 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum)
68 |
69 | self.conv2 = ME.MinkowskiConvolution(
70 | planes, planes, kernel_size=3, stride=stride, dilation=dilation, dimension=dimension
71 | )
72 | self.norm2 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum)
73 |
74 | self.conv3 = ME.MinkowskiConvolution(planes, planes * self.EXPANSION, kernel_size=1, dimension=dimension)
75 | self.norm3 = ME.MinkowskiBatchNorm(planes * self.EXPANSION, momentum=bn_momentum)
76 |
77 | self.relu = ME.MinkowskiReLU(inplace=True)
78 | self.downsample = downsample
79 |
80 | def forward(self, x):
81 | residual = x
82 |
83 | out = self.conv1(x)
84 | out = self.norm1(out)
85 | out = self.relu(out)
86 |
87 | out = self.conv2(out)
88 | out = self.norm2(out)
89 | out = self.relu(out)
90 |
91 | out = self.conv3(out)
92 | out = self.norm3(out)
93 |
94 | if self.downsample is not None:
95 | residual = self.downsample(x)
96 |
97 | out += residual
98 | out = self.relu(out)
99 |
100 | return out
101 |
102 |
103 | class BaseResBlock(nn.Module):
104 | def __init__(
105 | self,
106 | feat_in,
107 | feat_mid,
108 | feat_out,
109 | kernel_sizes=[],
110 | strides=[],
111 | dilations=[],
112 | has_biases=[],
113 | kernel_generators=[],
114 | kernel_size=3,
115 | stride=1,
116 | dilation=1,
117 | bias=False,
118 | kernel_generator=None,
119 | norm_layer=ME.MinkowskiBatchNorm,
120 | activation=ME.MinkowskiReLU,
121 | bn_momentum=0.1,
122 | dimension=-1,
123 | **kwargs
124 | ):
125 |
126 | super(BaseResBlock, self).__init__()
127 | assert dimension > 0
128 |
129 | modules = []
130 |
131 | convolutions_dim = [[feat_in, feat_mid], [feat_mid, feat_mid], [feat_mid, feat_out]]
132 |
133 | kernel_sizes = self.create_arguments_list(kernel_sizes, kernel_size)
134 | strides = self.create_arguments_list(strides, stride)
135 | dilations = self.create_arguments_list(dilations, dilation)
136 | has_biases = self.create_arguments_list(has_biases, bias)
137 | kernel_generators = self.create_arguments_list(kernel_generators, kernel_generator)
138 |
139 | for conv_dim, kernel_size, stride, dilation, has_bias, kernel_generator in zip(
140 | convolutions_dim, kernel_sizes, strides, dilations, has_biases, kernel_generators
141 | ):
142 |
143 | modules.append(
144 | ME.MinkowskiConvolution(
145 | conv_dim[0],
146 | conv_dim[1],
147 | kernel_size=kernel_size,
148 | stride=stride,
149 | dilation=dilation,
150 | bias=has_bias,
151 | kernel_generator=kernel_generator,
152 | dimension=dimension,
153 | )
154 | )
155 |
156 | if norm_layer:
157 | modules.append(norm_layer(conv_dim[1], momentum=bn_momentum))
158 |
159 | if activation:
160 | modules.append(activation(inplace=True))
161 |
162 | self.conv = nn.Sequential(*modules)
163 |
164 | @staticmethod
165 | def create_arguments_list(arg_list, arg):
166 | if len(arg_list) == 3:
167 | return arg_list
168 | return [arg for _ in range(3)]
169 |
170 | def forward(self, x):
171 | return x, self.conv(x)
172 |
173 |
174 | class ResnetBlockDown(BaseResBlock):
175 | def __init__(
176 | self,
177 | down_conv_nn=[],
178 | kernel_sizes=[],
179 | strides=[],
180 | dilations=[],
181 | kernel_size=3,
182 | stride=1,
183 | dilation=1,
184 | norm_layer=ME.MinkowskiBatchNorm,
185 | activation=ME.MinkowskiReLU,
186 | bn_momentum=0.1,
187 | dimension=-1,
188 | down_stride=2,
189 | **kwargs
190 | ):
191 |
192 | super(ResnetBlockDown, self).__init__(
193 | down_conv_nn[0],
194 | down_conv_nn[1],
195 | down_conv_nn[2],
196 | kernel_sizes=kernel_sizes,
197 | strides=strides,
198 | dilations=dilations,
199 | kernel_size=kernel_size,
200 | stride=stride,
201 | dilation=dilation,
202 | norm_layer=norm_layer,
203 | activation=activation,
204 | bn_momentum=bn_momentum,
205 | dimension=dimension,
206 | )
207 |
208 | self.downsample = nn.Sequential(
209 | ME.MinkowskiConvolution(
210 | down_conv_nn[0], down_conv_nn[2], kernel_size=2, stride=down_stride, dimension=dimension
211 | ),
212 | ME.MinkowskiBatchNorm(down_conv_nn[2]),
213 | )
214 |
215 | def forward(self, x):
216 |
217 | residual, x = super().forward(x)
218 |
219 | return self.downsample(residual) + x
220 |
221 |
222 | class ResnetBlockUp(BaseResBlock):
223 | def __init__(
224 | self,
225 | up_conv_nn=[],
226 | kernel_sizes=[],
227 | strides=[],
228 | dilations=[],
229 | kernel_size=3,
230 | stride=1,
231 | dilation=1,
232 | norm_layer=ME.MinkowskiBatchNorm,
233 | activation=ME.MinkowskiReLU,
234 | bn_momentum=0.1,
235 | dimension=-1,
236 | up_stride=2,
237 | skip=True,
238 | **kwargs
239 | ):
240 |
241 | self.skip = skip
242 |
243 | super(ResnetBlockUp, self).__init__(
244 | up_conv_nn[0],
245 | up_conv_nn[1],
246 | up_conv_nn[2],
247 | kernel_sizes=kernel_sizes,
248 | strides=strides,
249 | dilations=dilations,
250 | kernel_size=kernel_size,
251 | stride=stride,
252 | dilation=dilation,
253 | norm_layer=norm_layer,
254 | activation=activation,
255 | bn_momentum=bn_momentum,
256 | dimension=dimension,
257 | )
258 |
259 | self.upsample = ME.MinkowskiConvolutionTranspose(
260 | up_conv_nn[0], up_conv_nn[2], kernel_size=2, stride=up_stride, dimension=dimension
261 | )
262 |
263 | def forward(self, x, x_skip):
264 | residual, x = super().forward(x)
265 |
266 | x = self.upsample(residual) + x
267 |
268 | if self.skip:
269 | return ME.cat(x, x_skip)
270 | else:
271 | return x
272 |
273 |
274 | class SELayer(nn.Module):
275 | def __init__(self, channel, reduction=16, D=-1):
276 | # Global coords does not require coords_key
277 | super(SELayer, self).__init__()
278 | self.fc = nn.Sequential(
279 | ME.MinkowskiLinear(channel, channel // reduction),
280 | ME.MinkowskiReLU(inplace=True),
281 | ME.MinkowskiLinear(channel // reduction, channel),
282 | ME.MinkowskiSigmoid(),
283 | )
284 | self.pooling = ME.MinkowskiGlobalPooling(dimension=D)
285 | self.broadcast_mul = ME.MinkowskiBroadcastMultiplication(dimension=D)
286 |
287 | def forward(self, x):
288 | y = self.pooling(x)
289 | y = self.fc(y)
290 | return self.broadcast_mul(x, y)
291 |
292 |
293 | class SEBasicBlock(BasicBlock):
294 | def __init__(
295 | self, inplanes, planes, stride=1, dilation=1, downsample=None, conv_type=ConvType.HYPERCUBE, reduction=16, D=-1
296 | ):
297 | super(SEBasicBlock, self).__init__(
298 |             inplanes, planes, stride=stride, dilation=dilation, downsample=downsample, dimension=D  # BasicBlock takes `dimension`, not conv_type/D
299 | )
300 | self.se = SELayer(planes, reduction=reduction, D=D)
301 |
302 | def forward(self, x):
303 | residual = x
304 |
305 | out = self.conv1(x)
306 | out = self.norm1(out)
307 | out = self.relu(out)
308 |
309 | out = self.conv2(out)
310 | out = self.norm2(out)
311 | out = self.se(out)
312 |
313 | if self.downsample is not None:
314 | residual = self.downsample(x)
315 |
316 | out += residual
317 | out = self.relu(out)
318 |
319 | return out
320 |
321 |
322 | class SEBasicBlockBN(SEBasicBlock):
323 | NORM_TYPE = NormType.BATCH_NORM
324 |
325 |
326 | class SEBasicBlockIN(SEBasicBlock):
327 | NORM_TYPE = NormType.INSTANCE_NORM
328 |
329 |
330 | class SEBasicBlockIBN(SEBasicBlock):
331 | NORM_TYPE = NormType.INSTANCE_BATCH_NORM
332 |
333 |
334 | class SEBottleneck(Bottleneck):
335 | def __init__(
336 | self, inplanes, planes, stride=1, dilation=1, downsample=None, conv_type=ConvType.HYPERCUBE, D=3, reduction=16
337 | ):
338 | super(SEBottleneck, self).__init__(
339 |             inplanes, planes, stride=stride, dilation=dilation, downsample=downsample, dimension=D  # Bottleneck takes `dimension`, not conv_type/D
340 | )
341 |         self.se = SELayer(planes * self.EXPANSION, reduction=reduction, D=D)
342 |
343 | def forward(self, x):
344 | residual = x
345 |
346 | out = self.conv1(x)
347 | out = self.norm1(out)
348 | out = self.relu(out)
349 |
350 | out = self.conv2(out)
351 | out = self.norm2(out)
352 | out = self.relu(out)
353 |
354 | out = self.conv3(out)
355 | out = self.norm3(out)
356 | out = self.se(out)
357 |
358 | if self.downsample is not None:
359 | residual = self.downsample(x)
360 |
361 | out += residual
362 | out = self.relu(out)
363 |
364 | return out
365 |
366 |
367 | class SEBottleneckBN(SEBottleneck):
368 | NORM_TYPE = NormType.BATCH_NORM
369 |
370 |
371 | class SEBottleneckIN(SEBottleneck):
372 | NORM_TYPE = NormType.INSTANCE_NORM
373 |
374 |
375 | class SEBottleneckIBN(SEBottleneck):
376 | NORM_TYPE = NormType.INSTANCE_BATCH_NORM
377 |
--------------------------------------------------------------------------------
/pointdc_mk/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | import MinkowskiEngine as ME
4 | from .modules import BasicBlock, Bottleneck
5 |
6 |
7 | class ResNetBase(nn.Module):
8 | BLOCK = None
9 | LAYERS = ()
10 | INIT_DIM = 64
11 | PLANES = (64, 128, 256, 512)
12 |
13 | def __init__(self, in_channels, out_channels, D=3, **kwargs):
14 | nn.Module.__init__(self)
15 | self.D = D
16 | assert self.BLOCK is not None, "BLOCK is not defined"
17 | assert self.PLANES is not None, "PLANES is not defined"
18 | self.network_initialization(in_channels, out_channels, D)
19 | self.weight_initialization()
20 |
21 | def network_initialization(self, in_channels, out_channels, D):
22 |
23 | self.inplanes = self.INIT_DIM
24 | self.conv1 = ME.MinkowskiConvolution(in_channels, self.inplanes, kernel_size=5, stride=2, dimension=D)
25 |
26 | self.bn1 = ME.MinkowskiBatchNorm(self.inplanes)
27 | self.relu = ME.MinkowskiReLU(inplace=True)
28 |
29 | self.pool = ME.MinkowskiAvgPooling(kernel_size=2, stride=2, dimension=D)
30 |
31 | self.layer1 = self._make_layer(self.BLOCK, self.PLANES[0], self.LAYERS[0], stride=2)
32 | self.layer2 = self._make_layer(self.BLOCK, self.PLANES[1], self.LAYERS[1], stride=2)
33 | self.layer3 = self._make_layer(self.BLOCK, self.PLANES[2], self.LAYERS[2], stride=2)
34 | self.layer4 = self._make_layer(self.BLOCK, self.PLANES[3], self.LAYERS[3], stride=2)
35 |
36 | self.conv5 = ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=3, stride=3, dimension=D)
37 | self.bn5 = ME.MinkowskiBatchNorm(self.inplanes)
38 |
39 | self.glob_avg = ME.MinkowskiGlobalMaxPooling(dimension=D)
40 |
41 | self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True)
42 |
43 | def weight_initialization(self):
44 | for m in self.modules():
45 | if isinstance(m, ME.MinkowskiConvolution):
46 | ME.utils.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu")
47 |
48 | if isinstance(m, ME.MinkowskiBatchNorm):
49 | nn.init.constant_(m.bn.weight, 1)
50 | nn.init.constant_(m.bn.bias, 0)
51 |
52 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, bn_momentum=0.1):
53 | downsample = None
54 | if stride != 1 or self.inplanes != planes * block.EXPANSION:
55 | downsample = nn.Sequential(
56 | ME.MinkowskiConvolution(
57 | self.inplanes, planes * block.EXPANSION, kernel_size=1, stride=stride, dimension=self.D
58 | ),
59 | ME.MinkowskiBatchNorm(planes * block.EXPANSION),
60 | )
61 | layers = []
62 | layers.append(
63 | block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample, dimension=self.D)
64 | )
65 | self.inplanes = planes * block.EXPANSION
66 | for i in range(1, blocks):
67 | layers.append(block(self.inplanes, planes, stride=1, dilation=dilation, dimension=self.D))
68 |
69 | return nn.Sequential(*layers)
70 |
71 | def forward(self, x):
72 | x = self.conv1(x)
73 | x = self.bn1(x)
74 | x = self.relu(x)
75 | x = self.pool(x)
76 |
77 | x = self.layer1(x)
78 | x = self.layer2(x)
79 | x = self.layer3(x)
80 | x = self.layer4(x)
81 |
82 | x = self.conv5(x)
83 | x = self.bn5(x)
84 | x = self.relu(x)
85 |
86 | x = self.glob_avg(x)
87 | return self.final(x)
88 |
89 |
90 | class ResNet14(ResNetBase):
91 | BLOCK = BasicBlock
92 | LAYERS = (1, 1, 1, 1)
93 |
94 |
95 | class ResNet18(ResNetBase):
96 | BLOCK = BasicBlock
97 | LAYERS = (2, 2, 2, 2)
98 |
99 |
100 | class ResNet34(ResNetBase):
101 | BLOCK = BasicBlock
102 | LAYERS = (3, 4, 6, 3)
103 |
104 |
105 | class ResNet50(ResNetBase):
106 | BLOCK = Bottleneck
107 | LAYERS = (3, 4, 6, 3)
108 |
109 |
110 | class ResNet101(ResNetBase):
111 | BLOCK = Bottleneck
112 | LAYERS = (3, 4, 23, 3)
113 |
114 |
115 | class MinkUNetBase(ResNetBase):
116 | BLOCK = None
117 | PLANES = None
118 | DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1)
119 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2)
120 | INIT_DIM = 32
121 | OUT_TENSOR_STRIDE = 1
122 |
123 | # To use the model, must call initialize_coords before forward pass.
124 | # Once data is processed, call clear to reset the model before calling
125 | # initialize_coords
126 | def __init__(self, in_channels, out_channels, D=3, **kwargs):
127 | ResNetBase.__init__(self, in_channels, out_channels, D)
128 |
129 | def network_initialization(self, in_channels, out_channels, D):
130 |         # Output of the first conv is concatenated to conv6
131 | self.inplanes = self.INIT_DIM
132 | self.conv0p1s1 = ME.MinkowskiConvolution(in_channels, self.inplanes, kernel_size=5, dimension=D)
133 |
134 | self.bn0 = ME.MinkowskiBatchNorm(self.inplanes)
135 |
136 | self.conv1p1s2 = ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
137 | self.bn1 = ME.MinkowskiBatchNorm(self.inplanes)
138 |
139 | self.block1 = self._make_layer(self.BLOCK, self.PLANES[0], self.LAYERS[0])
140 |
141 | self.conv2p2s2 = ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
142 | self.bn2 = ME.MinkowskiBatchNorm(self.inplanes)
143 |
144 | self.block2 = self._make_layer(self.BLOCK, self.PLANES[1], self.LAYERS[1])
145 |
146 | self.conv3p4s2 = ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
147 |
148 | self.bn3 = ME.MinkowskiBatchNorm(self.inplanes)
149 | self.block3 = self._make_layer(self.BLOCK, self.PLANES[2], self.LAYERS[2])
150 |
151 | self.conv4p8s2 = ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
152 | self.bn4 = ME.MinkowskiBatchNorm(self.inplanes)
153 | self.block4 = self._make_layer(self.BLOCK, self.PLANES[3], self.LAYERS[3])
154 |
155 | self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose(
156 | self.inplanes, self.PLANES[4], kernel_size=2, stride=2, dimension=D
157 | )
158 | self.bntr4 = ME.MinkowskiBatchNorm(self.PLANES[4])
159 |
160 | self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.EXPANSION
161 | self.block5 = self._make_layer(self.BLOCK, self.PLANES[4], self.LAYERS[4])
162 | self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose(
163 | self.inplanes, self.PLANES[5], kernel_size=2, stride=2, dimension=D
164 | )
165 | self.bntr5 = ME.MinkowskiBatchNorm(self.PLANES[5])
166 |
167 | self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.EXPANSION
168 | self.block6 = self._make_layer(self.BLOCK, self.PLANES[5], self.LAYERS[5])
169 | self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose(
170 | self.inplanes, self.PLANES[6], kernel_size=2, stride=2, dimension=D
171 | )
172 | self.bntr6 = ME.MinkowskiBatchNorm(self.PLANES[6])
173 |
174 | self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.EXPANSION
175 | self.block7 = self._make_layer(self.BLOCK, self.PLANES[6], self.LAYERS[6])
176 | self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose(
177 | self.inplanes, self.PLANES[7], kernel_size=2, stride=2, dimension=D
178 | )
179 | self.bntr7 = ME.MinkowskiBatchNorm(self.PLANES[7])
180 |
181 | self.inplanes = self.PLANES[7] + self.INIT_DIM
182 | self.block8 = self._make_layer(self.BLOCK, self.PLANES[7], self.LAYERS[7])
183 |
184 | self.final = ME.MinkowskiConvolution(self.PLANES[7], out_channels, kernel_size=1, bias=True, dimension=D)
185 | self.relu = ME.MinkowskiReLU(inplace=True)
186 |
187 | def forward(self, x):
188 | out = self.conv0p1s1(x)
189 | out = self.bn0(out)
190 | out_p1 = self.relu(out)
191 |
192 | out = self.conv1p1s2(out_p1)
193 | out = self.bn1(out)
194 | out = self.relu(out)
195 | out_b1p2 = self.block1(out)
196 |
197 | out = self.conv2p2s2(out_b1p2)
198 | out = self.bn2(out)
199 | out = self.relu(out)
200 | out_b2p4 = self.block2(out)
201 |
202 | out = self.conv3p4s2(out_b2p4)
203 | out = self.bn3(out)
204 | out = self.relu(out)
205 | out_b3p8 = self.block3(out)
206 |
207 | # tensor_stride=16
208 | out = self.conv4p8s2(out_b3p8)
209 | out = self.bn4(out)
210 | out = self.relu(out)
211 | out = self.block4(out)
212 |
213 | # tensor_stride=8
214 | out = self.convtr4p16s2(out)
215 | out = self.bntr4(out)
216 | out = self.relu(out)
217 |
218 | out = ME.cat(out, out_b3p8)
219 | out = self.block5(out)
220 |
221 | # tensor_stride=4
222 | out = self.convtr5p8s2(out)
223 | out = self.bntr5(out)
224 | out = self.relu(out)
225 |
226 | out = ME.cat(out, out_b2p4)
227 | out = self.block6(out)
228 |
229 | # tensor_stride=2
230 | out = self.convtr6p4s2(out)
231 | out = self.bntr6(out)
232 | out = self.relu(out)
233 |
234 | out = ME.cat(out, out_b1p2)
235 | out = self.block7(out)
236 |
237 | # tensor_stride=1
238 | out = self.convtr7p2s2(out)
239 | out = self.bntr7(out)
240 | out = self.relu(out)
241 |
242 | out = ME.cat(out, out_p1)
243 | out = self.block8(out)
244 |
245 | return self.final(out)
246 |
247 |
248 | class MinkUNet14(MinkUNetBase):
249 | BLOCK = BasicBlock
250 | LAYERS = (1, 1, 1, 1, 1, 1, 1, 1)
251 |
252 |
253 | class MinkUNet18(MinkUNetBase):
254 | BLOCK = BasicBlock
255 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2)
256 |
257 |
258 | class MinkUNet34(MinkUNetBase):
259 | BLOCK = BasicBlock
260 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2)
261 |
262 |
263 | class MinkUNet50(MinkUNetBase):
264 | BLOCK = Bottleneck
265 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2)
266 |
267 |
268 | class MinkUNet101(MinkUNetBase):
269 | BLOCK = Bottleneck
270 | LAYERS = (2, 3, 4, 23, 2, 2, 2, 2)
271 |
272 |
273 | class MinkUNet14A(MinkUNet14):
274 | PLANES = (32, 64, 128, 256, 128, 128, 96, 96)
275 |
276 |
277 | class MinkUNet14B(MinkUNet14):
278 | PLANES = (32, 64, 128, 256, 128, 128, 128, 128)
279 |
280 |
281 | class MinkUNet14C(MinkUNet14):
282 | PLANES = (32, 64, 128, 256, 192, 192, 128, 128)
283 |
284 |
285 | class MinkUNet14D(MinkUNet14):
286 | PLANES = (32, 64, 128, 256, 384, 384, 384, 384)
287 |
288 |
289 | class MinkUNet18A(MinkUNet18):
290 | PLANES = (32, 64, 128, 256, 128, 128, 96, 96)
291 |
292 |
293 | class MinkUNet18B(MinkUNet18):
294 | PLANES = (32, 64, 128, 256, 128, 128, 128, 128)
295 |
296 |
297 | class MinkUNet18D(MinkUNet18):
298 | PLANES = (32, 64, 128, 256, 384, 384, 384, 384)
299 |
300 |
301 | class MinkUNet34A(MinkUNet34):
302 | PLANES = (32, 64, 128, 256, 256, 128, 64, 64)
303 |
304 |
305 | class MinkUNet34B(MinkUNet34):
306 | PLANES = (32, 64, 128, 256, 256, 128, 64, 32)
307 |
308 |
309 | class MinkUNet34C(MinkUNet34):
310 | PLANES = (32, 64, 128, 256, 256, 128, 96, 96)
311 |
--------------------------------------------------------------------------------
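
A minimal usage sketch for the MinkUNet variants defined above — illustrative only, assuming MinkowskiEngine ≥ 0.5 is installed and that the constructor follows the standard MinkowskiEngine example signature (`in_channels, out_channels, D`):

```python
# Minimal sketch: forward a toy sparse tensor through MinkUNet34C
# (as defined in res16unet.py above). Coordinates are integer
# (batch, x, y, z) rows; features are per-voxel vectors.
import torch
import MinkowskiEngine as ME

net = MinkUNet34C(in_channels=3, out_channels=20, D=3)
net.eval()  # BatchNorm needs >1 sample per channel in train mode; toy input is tiny
coords = torch.IntTensor([[0, 0, 0, 0],
                          [0, 1, 0, 0],
                          [0, 0, 2, 1]])   # one batch item, three voxels
feats = torch.rand(3, 3)                   # e.g. RGB per voxel
x = ME.SparseTensor(features=feats, coordinates=coords)
out = net(x)                               # SparseTensor with 20 logits per voxel
print(out.F.shape)                         # torch.Size([3, 20])
```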
/pointdc_mk/models/pretrain_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import MinkowskiEngine as ME
4 |
5 | from MinkowskiEngine import MinkowskiNetwork
6 | from MinkowskiEngine import MinkowskiReLU
7 | import torch.nn.functional as F
8 |
9 | from os.path import dirname, abspath
10 | import sys
11 | sys.path.append(dirname(abspath(__file__)))
12 | from common import ConvType, NormType, conv, conv_tr, get_norm, sum_pool
13 |
14 | class SegHead(nn.Module):
15 | def __init__(self, in_channels=128, out_channels=20):
16 | super(SegHead, self).__init__()
17 | self.cluster = torch.nn.parameter.Parameter(data=torch.randn(out_channels, in_channels), requires_grad=True)
18 |
19 | def forward(self, feats):
20 | normed_clusters = F.normalize(self.cluster, dim=1)
21 | normed_features = F.normalize(feats, dim=1)
22 | logits = F.linear(normed_features, normed_clusters)
23 |
24 | return logits
25 |
26 | class alignlayer(MinkowskiNetwork):
27 | def __init__(self, in_channels=128, out_channels=70, bn_momentum=0.02, norm_layer=True, D=3):
28 | super(alignlayer, self).__init__(D)
29 | def space_n_time_m(n, m):
30 | return n if D == 3 else [n, n, n, m]
31 |
32 | self.conv_align = conv(
33 | in_channels, out_channels, kernel_size=space_n_time_m(1, 1), stride=1, D=D
34 | )
35 |
36 | def forward(self, feats_tensorfield):
37 | x = feats_tensorfield.sparse()
38 |         aligned_out = self.conv_align(x).slice(feats_tensorfield).F  # slice back to per-point (field) features
39 |
40 | return aligned_out
41 |
42 | class SubModel(MinkowskiNetwork):
43 | def __init__(self, args, D=3):
44 | super(SubModel, self).__init__(D)
45 | self.args = args
46 | self.distill_layer = alignlayer(in_channels=args.feats_dim)
47 |
48 | def forward(self, feats_tensorfield):
49 | feats_aligned_nonorm = self.distill_layer(feats_tensorfield)
50 |
51 | return feats_aligned_nonorm
--------------------------------------------------------------------------------
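
`SegHead` above is a cosine-similarity classifier: each row of `cluster` acts as a learnable class prototype, and `F.linear` on the L2-normalized tensors yields cosine similarities. A quick illustrative check of that equivalence (sketch, not repo code):

```python
# Sketch: SegHead logits are cosine similarities, since
# F.linear(a, b) computes a @ b.T on the normalized tensors.
import torch
import torch.nn.functional as F

head = SegHead(in_channels=128, out_channels=20)   # SegHead from above
feats = torch.randn(5, 128)
logits = head(feats)                               # (5, 20), each value in [-1, 1]
manual = F.normalize(feats, dim=1) @ F.normalize(head.cluster, dim=1).T
assert torch.allclose(logits, manual, atol=1e-6)
```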
/pointdc_mk/models/resunet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import MinkowskiEngine as ME
3 | import MinkowskiEngine.MinkowskiFunctional as MEF
4 | from .common import get_norm
5 |
6 | from .res16unet import get_block
7 | from .common import NormType
8 |
9 |
10 | class ResUNet2(ME.MinkowskiNetwork):
11 | NORM_TYPE = None
12 | BLOCK_NORM_TYPE = NormType.BATCH_NORM
13 | CHANNELS = [None, 32, 64, 128, 256]
14 | TR_CHANNELS = [None, 32, 64, 64, 128]
15 |
16 |     # To use the model, call initialize_coords before the forward pass.
17 |     # Once the data is processed, call clear to reset the model before calling initialize_coords again.
18 | def __init__(
19 | self, in_channels=3, out_channels=32, bn_momentum=0.01, normalize_feature=True, conv1_kernel_size=5, D=3
20 | ):
21 | ME.MinkowskiNetwork.__init__(self, D)
22 | NORM_TYPE = self.NORM_TYPE
23 | BLOCK_NORM_TYPE = self.BLOCK_NORM_TYPE
24 | CHANNELS = self.CHANNELS
25 | TR_CHANNELS = self.TR_CHANNELS
26 | # print(D, in_channels, out_channels, conv1_kernel_size)
27 | self.normalize_feature = normalize_feature
28 | self.conv1 = ME.MinkowskiConvolution(
29 | in_channels=in_channels,
30 | out_channels=CHANNELS[1],
31 | kernel_size=conv1_kernel_size,
32 | stride=1,
33 | dilation=1,
34 | bias=False,
35 | dimension=D,
36 | )
37 | self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, D=D)
38 |
39 | self.block1 = get_block(BLOCK_NORM_TYPE, CHANNELS[1], CHANNELS[1], bn_momentum=bn_momentum, D=D)
40 |
41 | self.conv2 = ME.MinkowskiConvolution(
42 | in_channels=CHANNELS[1],
43 | out_channels=CHANNELS[2],
44 | kernel_size=3,
45 | stride=2,
46 | dilation=1,
47 | bias=False,
48 | dimension=D,
49 | )
50 | self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, D=D)
51 |
52 | self.block2 = get_block(BLOCK_NORM_TYPE, CHANNELS[2], CHANNELS[2], bn_momentum=bn_momentum, D=D)
53 |
54 | self.conv3 = ME.MinkowskiConvolution(
55 | in_channels=CHANNELS[2],
56 | out_channels=CHANNELS[3],
57 | kernel_size=3,
58 | stride=2,
59 | dilation=1,
60 | bias=False,
61 | dimension=D,
62 | )
63 | self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, D=D)
64 |
65 | self.block3 = get_block(BLOCK_NORM_TYPE, CHANNELS[3], CHANNELS[3], bn_momentum=bn_momentum, D=D)
66 |
67 | self.conv4 = ME.MinkowskiConvolution(
68 | in_channels=CHANNELS[3],
69 | out_channels=CHANNELS[4],
70 | kernel_size=3,
71 | stride=2,
72 | dilation=1,
73 | bias=False,
74 | dimension=D,
75 | )
76 | self.norm4 = get_norm(NORM_TYPE, CHANNELS[4], bn_momentum=bn_momentum, D=D)
77 |
78 | self.block4 = get_block(BLOCK_NORM_TYPE, CHANNELS[4], CHANNELS[4], bn_momentum=bn_momentum, D=D)
79 |
80 | self.conv4_tr = ME.MinkowskiConvolutionTranspose(
81 | in_channels=CHANNELS[4],
82 | out_channels=TR_CHANNELS[4],
83 | kernel_size=3,
84 | stride=2,
85 | dilation=1,
86 | bias=False,
87 | dimension=D,
88 | )
89 | self.norm4_tr = get_norm(NORM_TYPE, TR_CHANNELS[4], bn_momentum=bn_momentum, D=D)
90 |
91 | self.block4_tr = get_block(BLOCK_NORM_TYPE, TR_CHANNELS[4], TR_CHANNELS[4], bn_momentum=bn_momentum, D=D)
92 |
93 | self.conv3_tr = ME.MinkowskiConvolutionTranspose(
94 | in_channels=CHANNELS[3] + TR_CHANNELS[4],
95 | out_channels=TR_CHANNELS[3],
96 | kernel_size=3,
97 | stride=2,
98 | dilation=1,
99 | bias=False,
100 | dimension=D,
101 | )
102 | self.norm3_tr = get_norm(NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, D=D)
103 |
104 | self.block3_tr = get_block(BLOCK_NORM_TYPE, TR_CHANNELS[3], TR_CHANNELS[3], bn_momentum=bn_momentum, D=D)
105 |
106 | self.conv2_tr = ME.MinkowskiConvolutionTranspose(
107 | in_channels=CHANNELS[2] + TR_CHANNELS[3],
108 | out_channels=TR_CHANNELS[2],
109 | kernel_size=3,
110 | stride=2,
111 | dilation=1,
112 | bias=False,
113 | dimension=D,
114 | )
115 | self.norm2_tr = get_norm(NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, D=D)
116 |
117 | self.block2_tr = get_block(BLOCK_NORM_TYPE, TR_CHANNELS[2], TR_CHANNELS[2], bn_momentum=bn_momentum, D=D)
118 |
119 | self.conv1_tr = ME.MinkowskiConvolution(
120 | in_channels=CHANNELS[1] + TR_CHANNELS[2],
121 | out_channels=TR_CHANNELS[1],
122 | kernel_size=1,
123 | stride=1,
124 | dilation=1,
125 | bias=False,
126 | dimension=D,
127 | )
128 |
129 | # self.block1_tr = BasicBlockBN(TR_CHANNELS[1], TR_CHANNELS[1], bn_momentum=bn_momentum, D=D)
130 |
131 | self.final = ME.MinkowskiConvolution(
132 | in_channels=TR_CHANNELS[1],
133 | out_channels=out_channels,
134 | kernel_size=1,
135 | stride=1,
136 | dilation=1,
137 | bias=True,
138 | dimension=D,
139 | )
140 |
141 | def forward(self, x):
142 | out_s1 = self.conv1(x)
143 | out_s1 = self.norm1(out_s1)
144 | out_s1 = self.block1(out_s1)
145 | out = MEF.relu(out_s1)
146 |
147 | out_s2 = self.conv2(out)
148 | out_s2 = self.norm2(out_s2)
149 | out_s2 = self.block2(out_s2)
150 | out = MEF.relu(out_s2)
151 |
152 | out_s4 = self.conv3(out)
153 | out_s4 = self.norm3(out_s4)
154 | out_s4 = self.block3(out_s4)
155 | out = MEF.relu(out_s4)
156 |
157 | out_s8 = self.conv4(out)
158 | out_s8 = self.norm4(out_s8)
159 | out_s8 = self.block4(out_s8)
160 | out = MEF.relu(out_s8)
161 |
162 | out = self.conv4_tr(out)
163 | out = self.norm4_tr(out)
164 | out = self.block4_tr(out)
165 | out_s4_tr = MEF.relu(out)
166 |
167 | out = ME.cat(out_s4_tr, out_s4)
168 |
169 | out = self.conv3_tr(out)
170 | out = self.norm3_tr(out)
171 | out = self.block3_tr(out)
172 | out_s2_tr = MEF.relu(out)
173 |
174 | out = ME.cat(out_s2_tr, out_s2)
175 |
176 | out = self.conv2_tr(out)
177 | out = self.norm2_tr(out)
178 | out = self.block2_tr(out)
179 | out_s1_tr = MEF.relu(out)
180 |
181 | out = ME.cat(out_s1_tr, out_s1)
182 | out = self.conv1_tr(out)
183 | out = MEF.relu(out)
184 | out = self.final(out)
185 |
186 | if self.normalize_feature:
187 | return ME.SparseTensor(
188 | out.F / torch.norm(out.F, p=2, dim=1, keepdim=True),
189 | coordinate_map_key=out.coordinate_map_key,
190 | coordinate_manager=out.coordinate_manager,
191 | )
192 | else:
193 | return out
194 |
195 |
196 | class ResUNetBN2(ResUNet2):
197 | NORM_TYPE = NormType.BATCH_NORM
198 |
199 |
200 | class ResUNetBN2B(ResUNet2):
201 | NORM_TYPE = NormType.BATCH_NORM
202 | CHANNELS = [None, 32, 64, 128, 256]
203 | TR_CHANNELS = [None, 64, 64, 64, 64]
204 |
205 |
206 | class ResUNetBN2C(ResUNet2):
207 | NORM_TYPE = NormType.BATCH_NORM
208 | CHANNELS = [None, 32, 64, 128, 256]
209 | TR_CHANNELS = [None, 64, 64, 64, 128]
210 |
211 |
212 | class ResUNetBN2D(ResUNet2):
213 | NORM_TYPE = NormType.BATCH_NORM
214 | CHANNELS = [None, 32, 64, 128, 256]
215 | TR_CHANNELS = [None, 64, 64, 128, 128]
216 |
217 |
218 | class ResUNetBN2E(ResUNet2):
219 | NORM_TYPE = NormType.BATCH_NORM
220 | CHANNELS = [None, 128, 128, 128, 256]
221 | TR_CHANNELS = [None, 64, 128, 128, 128]
222 |
223 |
224 | class Res2BlockDown(ME.MinkowskiNetwork):
225 |
226 | """
227 |     Downsampling block for an unwrapped ResNet.
228 | """
229 |
230 | def __init__(
231 | self,
232 | down_conv_nn,
233 | kernel_size,
234 | stride,
235 | dilation,
236 | dimension=3,
237 | bn_momentum=0.01,
238 | norm_type=NormType.BATCH_NORM,
239 | block_norm_type=NormType.BATCH_NORM,
240 | **kwargs
241 | ):
242 | ME.MinkowskiNetwork.__init__(self, dimension)
243 | self.conv = ME.MinkowskiConvolution(
244 | in_channels=down_conv_nn[0],
245 | out_channels=down_conv_nn[1],
246 | kernel_size=kernel_size,
247 | stride=stride,
248 | dilation=dilation,
249 | bias=False,
250 | dimension=dimension,
251 | )
252 | self.norm = get_norm(norm_type, down_conv_nn[1], bn_momentum=bn_momentum, D=dimension)
253 | self.block = get_block(block_norm_type, down_conv_nn[1], down_conv_nn[1], bn_momentum=bn_momentum, D=dimension)
254 |
255 | def forward(self, x):
256 |
257 | out_s = self.conv(x)
258 | out_s = self.norm(out_s)
259 | out = self.block(out_s)
260 | return out
261 |
262 |
263 | class Res2BlockUp(ME.MinkowskiNetwork):
264 |
265 | """
266 | block for unwrapped Resnet
267 | """
268 |
269 | def __init__(
270 | self,
271 | up_conv_nn,
272 | kernel_size,
273 | stride,
274 | dilation,
275 | dimension=3,
276 | bn_momentum=0.01,
277 | norm_type=NormType.BATCH_NORM,
278 | block_norm_type=NormType.BATCH_NORM,
279 | **kwargs
280 | ):
281 | ME.MinkowskiNetwork.__init__(self, dimension)
282 | self.conv = ME.MinkowskiConvolutionTranspose(
283 | in_channels=up_conv_nn[0],
284 | out_channels=up_conv_nn[1],
285 | kernel_size=kernel_size,
286 | stride=stride,
287 | dilation=dilation,
288 | bias=False,
289 | dimension=dimension,
290 | )
291 | if len(up_conv_nn) == 3:
292 | self.final = ME.MinkowskiConvolution(
293 | in_channels=up_conv_nn[1],
294 | out_channels=up_conv_nn[2],
295 | kernel_size=kernel_size,
296 | stride=stride,
297 | dilation=dilation,
298 | bias=True,
299 | dimension=dimension,
300 | )
301 | else:
302 | self.norm = get_norm(norm_type, up_conv_nn[1], bn_momentum=bn_momentum, D=dimension)
303 | self.block = get_block(block_norm_type, up_conv_nn[1], up_conv_nn[1], bn_momentum=bn_momentum, D=dimension)
304 | self.final = None
305 |
306 | def forward(self, x, x_skip):
307 | if x_skip is not None:
308 | x = ME.cat(x, x_skip)
309 | out_s = self.conv(x)
310 | if self.final is None:
311 | out_s = self.norm(out_s)
312 | out = self.block(out_s)
313 | return out
314 | else:
315 | out_s = MEF.relu(out_s)
316 | out = self.final(out_s)
317 | return out
318 |
--------------------------------------------------------------------------------
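
With `normalize_feature=True` (the default), `ResUNet2` returns L2-normalized per-voxel features, so downstream cosine similarity or k-means depends only on feature direction. A toy check (sketch, reusing the pattern from the MinkUNet note above):

```python
# Sketch: ResUNetBN2C (defined above) emits unit-length features
# when normalize_feature=True.
import torch
import MinkowskiEngine as ME

net = ResUNetBN2C(in_channels=3, out_channels=32)
net.eval()  # avoid train-mode BatchNorm on a tiny toy input
coords = torch.IntTensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 2, 0, 1]])
x = ME.SparseTensor(features=torch.rand(3, 3), coordinates=coords)
out = net(x)
print(out.F.norm(p=2, dim=1))  # ~1.0 for every voxel
```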
/pointdc_mk/train_S3DIS.py:
--------------------------------------------------------------------------------
1 | import os, random, time, argparse, logging, warnings, torch
2 | import numpy as np
3 | from sklearn.utils.linear_assignment_ import linear_assignment # removed in scikit-learn>=0.23: pip install scikit-learn==0.22.2 (or port to scipy.optimize.linear_sum_assignment)
4 | from datasets.S3DIS import S3DISdistill, S3DIStrain, S3DIScluster, cfl_collate_fn_distill, cfl_collate_fn
5 | import MinkowskiEngine as ME
6 | import torch.nn.functional as F
7 | from torch.utils.data import DataLoader
8 | from models.fpn import Res16FPN18
9 | from models.pretrain_models import SubModel, SegHead
10 | from eval_S3DIS import eval, eval_once, eval_by_cluster
11 | from lib.utils_s3dis import *
12 | from sklearn.cluster import KMeans
13 | from os.path import join
14 | from tqdm import tqdm
15 | from torch.optim import lr_scheduler
16 | warnings.filterwarnings('ignore')
17 |
18 | def parse_args():
19 | '''PARAMETERS'''
20 | parser = argparse.ArgumentParser(description='PyTorch Unsuper_3D_Seg')
21 |     parser.add_argument('--data_path', type=str, default='data/S3DIS/', help='point cloud data path')
22 |     parser.add_argument('--sp_path', type=str, default='data/S3DIS/', help='initial superpoint path')
23 |     parser.add_argument('--expname', type=str, default='zdefault', help='expname for logger')
24 | ###
25 | parser.add_argument('--save_path', type=str, default='ckpt/S3DIS/', help='model savepath')
26 |     parser.add_argument('--max_epoch', type=int, nargs=3, default=[200, 30, 60], help='epochs for the distill / warm-up / iterative stages')
27 | ###
28 | parser.add_argument('--bn_momentum', type=float, default=0.02, help='batchnorm parameters')
29 | parser.add_argument('--conv1_kernel_size', type=int, default=5, help='kernel size of 1st conv layers')
30 | ####
31 |     parser.add_argument('--lrs', type=float, nargs=3, default=[1e-3, 3e-2, 3e-2], help='learning rates for the distill / warm-up / iterative stages')
32 | parser.add_argument('--momentum', type=float, default=0.9, help='SGD parameters')
33 | parser.add_argument('--dampening', type=float, default=0.1, help='SGD parameters')
34 | parser.add_argument('--weight-decay', type=float, default=1e-4, help='SGD parameters')
35 | parser.add_argument('--workers', type=int, default=8, help='how many workers for loading data')
36 | parser.add_argument('--cluster_workers', type=int, default=4, help='how many workers for loading data in clustering')
37 | parser.add_argument('--seed', type=int, default=2023, help='random seed')
38 | parser.add_argument('--log-interval', type=int, default=150, help='log interval')
39 | parser.add_argument('--batch_size', type=int, default=8, help='batchsize in training')
40 | parser.add_argument('--voxel_size', type=float, default=0.05, help='voxel size in SparseConv')
41 |     parser.add_argument('--input_dim', type=int, default=6, help='network input dimension')### 6 for XYZRGB
42 | parser.add_argument('--primitive_num', type=int, default=13, help='how many primitives used in training')
43 | parser.add_argument('--semantic_class', type=int, default=13, help='ground truth semantic class')
44 | parser.add_argument('--feats_dim', type=int, default=128, help='output feature dimension')
45 | parser.add_argument('--ignore_label', type=int, default=-1, help='invalid label')
46 | parser.add_argument('--drop_threshold', type=int, default=50, help='mask counts')
47 |
48 | return parser.parse_args()
49 |
50 | def main(args, logger):
51 |     # Cross-modal distillation
52 |     logger.info('**************Start Cross-Modal Distillation**************')
53 | ## Prepare Model/Optimizer
54 | model = Res16FPN18(in_channels=args.input_dim, out_channels=args.primitive_num, \
55 | conv1_kernel_size=args.conv1_kernel_size, args=args, mode='distill')
56 | submodel = SubModel(args)
57 | adam = torch.optim.Adam([{'params':model.parameters()}, {'params': submodel.parameters()}], \
58 | lr=args.lrs[0])
59 | model, submodel = model.cuda(), submodel.cuda()
60 | distill_loss = MseMaskLoss().cuda()
61 |
62 | # Prepare Data
63 | distillset = S3DISdistill(args)
64 | distill_loader = DataLoader(distillset, batch_size=args.batch_size, shuffle=True, collate_fn=cfl_collate_fn_distill(), \
65 | num_workers=args.workers, pin_memory=True, worker_init_fn=worker_init_fn(seed))
66 | clusterset = S3DIScluster(args, areas=['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6'])
67 | cluster_loader = DataLoader(clusterset, batch_size=1, shuffle=True, collate_fn=cfl_collate_fn(), \
68 | num_workers=args.workers, pin_memory=True, worker_init_fn=worker_init_fn(seed))
69 | # Distill
70 | for epoch in range(1, args.max_epoch[0]+1):
71 | distill(distill_loader, logger, model, submodel, adam, distill_loss, epoch, args.max_epoch[0])
72 | if epoch % 10 == 0:
73 | torch.save(model.state_dict(), join(args.save_path, 'cmd', 'model_' + str(epoch) + '_checkpoint.pth'))
74 | torch.save(submodel.state_dict(), join(args.save_path, 'cmd', 'submodule_' + str(epoch) + '_checkpoint.pth'))
75 |
76 |     # Cluster & compute pseudo labels
77 | model, submodel = model.cuda(), submodel.cuda()
78 |
79 | centroids_norm = init_cluster(args, logger, cluster_loader, model, submodel=submodel)
80 |
81 | del adam, distill_loader, distillset, submodel
82 |     logger.info('====>End Cross-Modal Distillation !!!\n')
83 |
84 | # Super Voxel Clustering
85 | logger.info('**************Start Super Voxel Clustering**************')
86 | ## Prepare Data
87 | trainset = S3DIStrain(args, areas=['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6'])
88 | train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, collate_fn=cfl_collate_fn(), \
89 | num_workers=args.workers, pin_memory=True, worker_init_fn=worker_init_fn(seed))
90 | ## Warm Up
91 | model.mode = 'train'
92 | ## Prepare Model/Loss/Optimizer
93 | seghead = SegHead(args.feats_dim, args.primitive_num)
94 | seghead = seghead.cuda()
95 | loss = torch.nn.CrossEntropyLoss(ignore_index=-1).cuda()
96 | warmup_optimizer = torch.optim.SGD([{"params": seghead.parameters()}, {"params": model.parameters()}], \
97 | lr=args.lrs[1], momentum=args.momentum, \
98 | dampening=args.dampening, weight_decay=args.weight_decay)
99 |
100 | logger.info('====>Start Warm Up.')
101 | for epoch in range(1, args.max_epoch[1]+1):
102 | train(train_loader, logger, model, warmup_optimizer, loss, epoch, seghead, args.max_epoch[1])
103 |         ### Evaluation and save checkpoint
104 | if epoch % 5 == 0:
105 | torch.save(model.state_dict(), join(args.save_path, 'svc', 'model_' + str(epoch) + '_checkpoint.pth'))
106 | torch.save(seghead.state_dict()['cluster'], join(args.save_path, 'svc', 'cls_' + str(epoch) + '_checkpoint.pth'))
107 | with torch.no_grad():
108 | o_Acc, m_Acc, s = eval(epoch, args)
109 | logger.info('WarmUp--Eval Epoch: {:02d}, oAcc {:.2f} mAcc {:.2f} IoUs'.format(epoch, o_Acc, m_Acc) + s+'\n')
110 | logger.info('====>End Warm Up !!!\n')
111 |
112 | # Iterative Training
113 | del seghead, warmup_optimizer # NOTE
114 | logger.info('====>Start Iterative Training.')
115 | iter_optimizer = torch.optim.SGD(model.parameters(), \
116 | lr=args.lrs[2], momentum=args.momentum, \
117 | dampening=args.dampening, weight_decay=args.weight_decay)
118 |
119 | scheduler = lr_scheduler.StepLR(iter_optimizer, step_size=5, gamma=0.8) # step lr
120 | logger.info('====>Update pseudo labels.')
121 | centroids_norm = init_cluster(args, logger, cluster_loader, model)
122 | seghead = get_fixclassifier(args.feats_dim, args.primitive_num, centroids_norm).cuda()
123 | for epoch in range(args.max_epoch[1]+1, args.max_epoch[1]+args.max_epoch[2]+1):
124 | logger.info('Update Optimizer lr:{:.2e}'.format(scheduler.get_last_lr()[0]))
125 | train(train_loader, logger, model, iter_optimizer, loss, epoch, seghead, args.max_epoch[1]+args.max_epoch[2]) ### train
126 | scheduler.step()
127 | if epoch % 5 == 0:
128 | torch.save(model.state_dict(), join(args.save_path, 'svc', 'model_' + str(epoch) + '_checkpoint.pth'))
129 | torch.save(seghead.state_dict()['weight'], join(args.save_path, 'svc', 'cls_' + str(epoch) + '_checkpoint.pth'))
130 | with torch.no_grad():
131 | o_Acc, m_Acc, s = eval(epoch, args, mode='svc')
132 | logger.info('Iter--Eval Epoch{:02d}: oAcc {:.2f} mAcc {:.2f} IoUs'.format(epoch, o_Acc, m_Acc) + s+'\n')
133 | ### Update pseudo labels
134 |         if epoch != args.max_epoch[1]+args.max_epoch[2]: # skip the redundant update after the final epoch
135 | logger.info('Update pseudo labels')
136 | centroids_norm = init_cluster(args, logger, cluster_loader, model)
137 | seghead.weight.data = centroids_norm.requires_grad_(False)
138 |
139 | logger.info('====>End Super Voxel Clustering !!!\n')
140 |
141 |
142 | def init_cluster(args, logger, cluster_loader, model, submodel=None):
143 | time_start = time.time()
144 |
145 | ## Extract Superpoints Feature
146 | sp_feats_list = init_get_sp_feature(args, cluster_loader, model, submodel)
147 | sp_feats = torch.cat(sp_feats_list, dim=0) ### will do Kmeans with l2 distance
148 | _, centroids_norm = faiss_cluster(args, sp_feats.cpu().numpy())
149 | centroids_norm = centroids_norm.cuda()
150 |
151 | ## Compute and Save Pseudo Labels
152 | all_pseudo, all_labels = init_get_pseudo(args, cluster_loader, model, centroids_norm, submodel)
153 | o_Acc, m_Acc, s = compute_seg_results(args, all_labels, all_pseudo)
154 | logger.info('clustering time: %.2fs', (time.time() - time_start))
155 | logger.info('Trainset: oAcc {:.2f} mAcc {:.2f} IoUs'.format(o_Acc, m_Acc) + s+'\n')
156 |
157 | return centroids_norm
158 |
159 | def distill(distill_loader, logger, model, submodel, optimizer, loss, epoch, maxepochs):
160 | distill_loader.dataset.mode = 'distill'
161 | model.train()
162 | submodel.train()
163 | loss_display = AverageMeter()
164 |
165 | trainloader_bar = tqdm(distill_loader)
166 | for batch_idx, data in enumerate(trainloader_bar):
167 | ## Prepare data
168 | trainloader_bar.set_description('Epoch {}'.format(epoch))
169 | coords, features, dinofeats, normals, labels, inverse_map, region, index, scenenames = data
170 | ## Forward
171 | in_field = ME.TensorField(features, coords, device=0)
172 | feats = model(in_field)
173 | feats_aligned = submodel(feats)
174 | ## Loss
175 | mask = region.squeeze() >= 0
176 | loss_distill = loss(F.normalize(feats_aligned[mask]), F.normalize(dinofeats[mask].cuda().detach()))
177 | loss_display.update(loss_distill.item())
178 | optimizer.zero_grad()
179 | loss_distill.backward()
180 | optimizer.step()
181 |
182 | torch.cuda.empty_cache()
183 | torch.cuda.synchronize(torch.device("cuda"))
184 |         if batch_idx % 5 == 0:
185 | trainloader_bar.set_postfix(trainloss='{:.3e}'.format(loss_display.avg))
186 | if epoch % 10 == 0:
187 | logger.info('Epoch: {}/{} Train loss: {:.3e}'.format(epoch, maxepochs, loss_display.avg))
188 |
189 | def train(train_loader, logger, model, optimizer, loss, epoch, classifier, maxepochs):
190 | train_loader.dataset.mode = 'train'
191 | model.train()
192 | classifier.train()
193 | loss_display = AverageMeter()
194 |
195 | trainloader_bar = tqdm(train_loader)
196 | for batch_idx, data in enumerate(trainloader_bar):
197 |
198 | trainloader_bar.set_description('Epoch {}/{}'.format(epoch, maxepochs))
199 | coords, features, normals, labels, inverse_map, pseudo_labels, inds, region, index, scenenames = data
200 |
201 | in_field = ME.TensorField(features, coords, device=0)
202 | feats_nonorm = model(in_field)
203 | logits = classifier(feats_nonorm)
204 |
205 | ## loss
206 | pseudo_labels_comp = pseudo_labels.long().cuda()
207 | loss_sem = loss(logits, pseudo_labels_comp).mean()
208 | loss_display.update(loss_sem.item())
209 | optimizer.zero_grad()
210 | loss_sem.backward()
211 | optimizer.step()
212 |
213 | torch.cuda.empty_cache()
214 | torch.cuda.synchronize(torch.device("cuda"))
215 |
216 | if batch_idx % 20 == 0:
217 | trainloader_bar.set_postfix(trainloss='{:.3e}'.format(loss_display.avg))
218 |
219 | logger.info('Epoch {}/{}: Train loss: {:.3e}'.format(epoch, maxepochs, loss_display.avg))
220 |
221 | def set_logger(log_path):
222 | logger = logging.getLogger()
223 | logger.setLevel(logging.INFO)
224 |
225 | # Logging to a file
226 | file_handler = logging.FileHandler(log_path)
227 | file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
228 | logger.addHandler(file_handler)
229 |
230 | # Logging to console
231 | stream_handler = logging.StreamHandler()
232 | stream_handler.setFormatter(logging.Formatter('%(message)s'))
233 | logger.addHandler(stream_handler)
234 |
235 | return logger
236 |
237 | if __name__ == '__main__':
238 | args = parse_args()
239 |
240 | args.save_path = os.path.join(args.save_path, args.expname)
241 | args.pseudo_path = os.path.join(args.save_path, 'pseudo_labels')
242 |
243 | '''Setup logger'''
244 | if not os.path.exists(args.save_path):
245 | os.makedirs(args.save_path)
246 | os.makedirs(args.pseudo_path)
247 | os.makedirs(join(args.save_path, 'cmd'))
248 | os.makedirs(join(args.save_path, 'svc'))
249 | logger = set_logger(os.path.join(args.save_path, 'train.log'))
250 | logger.info(args)
251 |
252 | '''Cache code'''
253 | cache_codes(args)
254 |
255 | '''Random Seed'''
256 | seed = args.seed
257 | set_seed(seed)
258 |
259 | main(args, logger)
260 |
--------------------------------------------------------------------------------
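
`train_S3DIS.py` pins scikit-learn 0.22.2 because `sklearn.utils.linear_assignment_` was removed in 0.23. For reference, a hedged sketch of the Hungarian matching such evaluations typically perform — mapping cluster ids to ground-truth classes through a confusion matrix — using the maintained `scipy.optimize.linear_sum_assignment` (the exact logic in `eval_S3DIS.py` may differ):

```python
# Sketch: assign each predicted cluster id to the ground-truth class that
# maximizes the total number of matched points (Hungarian algorithm).
import numpy as np
from scipy.optimize import linear_sum_assignment  # maintained replacement

def map_clusters_to_classes(pseudo, gt, num_classes):
    """Return m with m[cluster_id] = matched ground-truth class id."""
    overlap = np.zeros((num_classes, num_classes), dtype=np.int64)
    for c in range(num_classes):
        for k in range(num_classes):
            overlap[c, k] = np.sum((pseudo == c) & (gt == k))
    rows, cols = linear_sum_assignment(-overlap)  # negate to maximize overlap
    mapping = np.zeros(num_classes, dtype=np.int64)
    mapping[rows] = cols
    return mapping

# Usage: remapped = map_clusters_to_classes(pseudo, gt, 13)[pseudo]
```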
/pointdc_mk/train_ScanNet.py:
--------------------------------------------------------------------------------
1 | import os, random, time, argparse, logging, warnings, torch
2 | import numpy as np
3 | from sklearn.utils.linear_assignment_ import linear_assignment # removed in scikit-learn>=0.23: pip install scikit-learn==0.22.2 (or port to scipy.optimize.linear_sum_assignment)
4 | from datasets.ScanNet import Scannettrain, Scannetdistill, Scannetval, cfl_collate_fn, cfl_collate_fn_distill, cfl_collate_fn_val
5 | import MinkowskiEngine as ME
6 | import torch.nn.functional as F
7 | from torch.utils.data import DataLoader
8 | from models.fpn import Res16FPN18
9 | from models.pretrain_models import SubModel, SegHead
10 | from eval_ScanNet import eval, eval_once, eval_by_cluster
11 | from lib.utils import *
12 | from sklearn.cluster import KMeans
13 | from os.path import join
14 | from tqdm import tqdm
15 | from torch.optim import lr_scheduler
16 | warnings.filterwarnings('ignore')
17 |
18 | def parse_args():
19 | '''PARAMETERS'''
20 | parser = argparse.ArgumentParser(description='PointDC')
21 |     parser.add_argument('--data_path', type=str, default='data/ScanNet/train', help='point cloud data path') # point cloud file path
22 |     parser.add_argument('--feats_path', type=str, default='data/ScanNet/train_feats', help='feature volume path') # feature volume file path
23 |     parser.add_argument('--sp_path', type=str, default='data/ScanNet/initial_superpoints', help='initial superpoint path') # initial supervoxel file path
24 | parser.add_argument('--expname', type=str, default= 'default', help='expname for logger')
25 | ###
26 | parser.add_argument('--save_path', type=str, default='ckpt/ScanNet/', help='model savepath')
27 |     parser.add_argument('--max_epoch', type=int, nargs=3, default=[200, 30, 60], help='epochs for the distill / warm-up / iterative stages')
28 | ###
29 | parser.add_argument('--bn_momentum', type=float, default=0.02, help='batchnorm parameters')
30 | parser.add_argument('--conv1_kernel_size', type=int, default=5, help='kernel size of 1st conv layers')
31 | ####
32 |     parser.add_argument('--lrs', type=float, nargs=3, default=[1e-3, 3e-2, 3e-2], help='learning rates for the distill / warm-up / iterative stages')
33 | parser.add_argument('--momentum', type=float, default=0.9, help='SGD parameters')
34 | parser.add_argument('--dampening', type=float, default=0.1, help='SGD parameters')
35 | parser.add_argument('--weight-decay', type=float, default=1e-4, help='SGD parameters')
36 | parser.add_argument('--workers', type=int, default=8, help='how many workers for loading data')
37 | parser.add_argument('--cluster_workers', type=int, default=4, help='how many workers for loading data in clustering')
38 | parser.add_argument('--seed', type=int, default=2023, help='random seed')
39 | parser.add_argument('--log-interval', type=int, default=150, help='log interval')
40 | parser.add_argument('--batch_size', type=int, default=8, help='batchsize in training')
41 | parser.add_argument('--voxel_size', type=float, default=0.02, help='voxel size in SparseConv')
42 |     parser.add_argument('--input_dim', type=int, default=6, help='network input dimension')### 6 for XYZRGB
43 | parser.add_argument('--primitive_num', type=int, default=20, help='how many primitives used in training')
44 | parser.add_argument('--semantic_class', type=int, default=20, help='ground truth semantic class')
45 | parser.add_argument('--feats_dim', type=int, default=128, help='output feature dimension')
46 | parser.add_argument('--ignore_label', type=int, default=-1, help='invalid label')
47 |
48 | return parser.parse_args()
49 |
50 |
51 | def main(args, logger):
52 |     # Cross-modal distillation
53 |     logger.info('**************Start Cross-Modal Distillation**************')
54 | ## Prepare Model/Optimizer
55 | model = Res16FPN18(in_channels=args.input_dim, out_channels=args.primitive_num, \
56 | conv1_kernel_size=args.conv1_kernel_size, args=args, mode='distill')
57 | submodel = SubModel(args)
58 | model, submodel= model.cuda(), submodel.cuda()
59 | distill_loss = MseMaskLoss().cuda()
60 | adam = torch.optim.Adam([{'params':model.parameters()}, {'params': submodel.parameters()}], \
61 | lr=args.lrs[0])
62 | ## Prepare Data
63 | distillset = Scannetdistill(args)
64 | distill_loader = DataLoader(distillset, batch_size=args.batch_size, shuffle=True, collate_fn=cfl_collate_fn_distill(), \
65 | num_workers=args.workers, pin_memory=True, worker_init_fn=worker_init_fn(seed))
66 | ## Distill
67 | for epoch in range(1, args.max_epoch[0]+1):
68 | distill(distill_loader, logger, model, submodel, adam, distill_loss, epoch, args.max_epoch[0])
69 | if epoch % 10 == 0:
70 | torch.save(model.state_dict(), join(args.save_path, 'cmd', 'model_' + str(epoch) + '_checkpoint.pth'))
71 | torch.save(submodel.state_dict(), join(args.save_path, 'cmd', 'submodule_' + str(epoch) + '_checkpoint.pth'))
72 |
73 | ## Cluster & Compute pseudo labels
74 | trainset = Scannettrain(args)
75 | cluster_loader = DataLoader(trainset, batch_size=1, shuffle=True, collate_fn=cfl_collate_fn(), \
76 | num_workers=args.workers, pin_memory=True, worker_init_fn=worker_init_fn(seed))
77 | model, submodel = model.cuda(), submodel.cuda()
78 | centroids_norm = init_cluster(args, logger, cluster_loader, model, submodel=submodel)
79 |
80 | del adam, distill_loader, distillset, submodel
81 |     logger.info('====>End Cross-Modal Distillation !!!\n')
82 |
83 | # Super Voxel Clustering
84 | logger.info('**************Start Super Voxel Clustering**************')
85 | ## Prepare Data
86 | train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, collate_fn=cfl_collate_fn(), \
87 | num_workers=args.workers, pin_memory=True, worker_init_fn=worker_init_fn(seed))
88 | ## Warm Up
89 | model.mode = 'train'
90 | ### Prepare Model/Loss/Optimizer
91 | seghead = SegHead(args.feats_dim, args.primitive_num)
92 | seghead = seghead.cuda()
93 | loss = torch.nn.CrossEntropyLoss(ignore_index=-1).cuda()
94 | warmup_optimizer = torch.optim.SGD([{"params": seghead.parameters()}, {"params": model.parameters()}], \
95 | lr=args.lrs[1], momentum=args.momentum, dampening=args.dampening, weight_decay=args.weight_decay)
96 | scheduler = lr_scheduler.StepLR(warmup_optimizer, step_size=5, gamma=0.8) # step lr
97 | logger.info('====>Start Warm Up.')
98 | for epoch in range(1, args.max_epoch[1]+1):
99 | logger.info('Update Optimizer lr:{:.2e}'.format(scheduler.get_last_lr()[0]))
100 | train(train_loader, logger, model, warmup_optimizer, loss, epoch, seghead, args.max_epoch[1])
101 |         ### Evaluation and save checkpoint
102 | if epoch % 5 == 0:
103 | torch.save(model.state_dict(), join(args.save_path, 'svc', 'model_' + str(epoch) + '_checkpoint.pth'))
104 | torch.save(seghead.state_dict()['cluster'], join(args.save_path, 'svc', 'cls_' + str(epoch) + '_checkpoint.pth'))
105 | with torch.no_grad():
106 | o_Acc, m_Acc, s = eval(epoch, args)
107 | logger.info('WarmUp--Eval Epoch: {:02d}, oAcc {:.2f} mAcc {:.2f} IoUs'.format(epoch, o_Acc, m_Acc) + s+'\n')
108 | logger.info('====>End Warm Up !!!\n')
109 | ## Iterative Training
110 | del seghead, warmup_optimizer # NOTE
111 | logger.info('====>Start Iterative Training.')
112 | iter_optimizer = torch.optim.SGD(model.parameters(), \
113 | lr=args.lrs[2], momentum=args.momentum, dampening=args.dampening, weight_decay=args.weight_decay)
114 | scheduler = lr_scheduler.StepLR(iter_optimizer, step_size=5, gamma=0.8) # step lr
115 | logger.info('====>Update pseudo labels.')
116 | centroids_norm = init_cluster(args, logger, cluster_loader, model)
117 | seghead = get_fixclassifier(args.feats_dim, args.primitive_num, centroids_norm).cuda()
118 | for epoch in range(args.max_epoch[1]+1, args.max_epoch[1]+args.max_epoch[2]+1):
119 | logger.info('Update Optimizer lr:{:.2e}'.format(scheduler.get_last_lr()[0]))
120 | train(train_loader, logger, model, iter_optimizer, loss, epoch, seghead, args.max_epoch[1]+args.max_epoch[2]) ### train
121 | scheduler.step()
122 | if epoch % 5 == 0:
123 | torch.save(model.state_dict(), join(args.save_path, 'svc', 'model_' + str(epoch) + '_checkpoint.pth'))
124 | torch.save(seghead.state_dict()['weight'], join(args.save_path, 'svc', 'cls_' + str(epoch) + '_checkpoint.pth'))
125 | with torch.no_grad():
126 | o_Acc, m_Acc, s = eval(epoch, args, mode='svc')
127 | logger.info('Iter--Eval Epoch{:02d}: oAcc {:.2f} mAcc {:.2f} IoUs'.format(epoch, o_Acc, m_Acc) + s+'\n')
128 | ### Update pseudo labels
129 |         if epoch != args.max_epoch[1]+args.max_epoch[2]: # skip the redundant update after the final epoch
130 | logger.info('Update pseudo labels')
131 | centroids_norm = init_cluster(args, logger, cluster_loader, model)
132 | seghead.weight.data = centroids_norm.requires_grad_(False)
133 |
134 | logger.info('====>End Super Voxel Clustering !!!\n')
135 |
136 | def init_cluster(args, logger, cluster_loader, model, submodel=None):
137 | time_start = time.time()
138 | cluster_loader.dataset.mode = 'cluster'
139 |
140 | ## Extract Superpoints Feature
141 | sp_feats_list = init_get_sp_feature(args, cluster_loader, model, submodel)
142 |     sp_feats = torch.cat(sp_feats_list, dim=0) ### will do Kmeans with l2 distance
143 | _, centroids_norm = faiss_cluster(args, sp_feats.cpu().numpy())
144 | centroids_norm = centroids_norm.cuda()
145 |
146 | ## Compute and Save Pseudo Labels
147 | all_pseudo, all_labels = init_get_pseudo(args, cluster_loader, model, centroids_norm, submodel)
148 | logger.info('clustering time: %.2fs', (time.time() - time_start))
149 |
150 | return centroids_norm
151 |
152 | def distill(distill_loader, logger, model, submodel, optimizer, loss, epoch, maxepochs):
153 | distill_loader.dataset.mode = 'distill'
154 | model.train()
155 | submodel.train()
156 | loss_display = AverageMeter()
157 |
158 | trainloader_bar = tqdm(distill_loader)
159 | for batch_idx, data in enumerate(trainloader_bar):
160 | ## Prepare data
161 | trainloader_bar.set_description('Epoch {}'.format(epoch))
162 | coords, features, normals, labels, inverse_map, pseudo_labels, inds, region, index, scenenames, spfeats = data
163 | ## Forward
164 | in_field = ME.TensorField(features, coords, device=0)
165 | feats = model(in_field)
166 | feats_aligned = submodel(feats)
167 | ## Loss
168 | loss_distill = loss(feats_aligned, spfeats.cuda().detach())
169 | loss_display.update(loss_distill.item())
170 | optimizer.zero_grad()
171 | loss_distill.backward()
172 | optimizer.step()
173 |
174 | torch.cuda.empty_cache()
175 | # torch.cuda.synchronize(torch.device("cuda"))
176 | if batch_idx % 10 == 0:
177 | trainloader_bar.set_postfix(trainloss='{:.3e}'.format(loss_display.avg))
178 | if epoch % 10 == 0:
179 | logger.info('Epoch: {}/{} Train loss: {:.3e}'.format(epoch, maxepochs, loss_display.avg))
180 |
181 | def train(train_loader, logger, model, optimizer, loss, epoch, classifier, maxepochs):
182 | train_loader.dataset.mode = 'train'
183 | model.train()
184 | classifier.train()
185 | loss_display = AverageMeter()
186 |
187 | trainloader_bar = tqdm(train_loader)
188 | for batch_idx, data in enumerate(trainloader_bar):
189 |
190 | trainloader_bar.set_description('Epoch {}/{}'.format(epoch, maxepochs))
191 | coords, features, normals, labels, inverse_map, pseudo_labels, inds, region, index, scenenames = data
192 |
193 | in_field = ME.TensorField(features, coords, device=0)
194 | feats_nonorm = model(in_field)
195 | logits = classifier(feats_nonorm)
196 |
197 | ## loss
198 | pseudo_labels_comp = pseudo_labels.long().cuda()
199 | loss_sem = loss(logits, pseudo_labels_comp).mean()
200 | loss_display.update(loss_sem.item())
201 | optimizer.zero_grad()
202 | loss_sem.backward()
203 | optimizer.step()
204 |
205 | torch.cuda.empty_cache()
206 | torch.cuda.synchronize(torch.device("cuda"))
207 |
208 | if batch_idx % 20 == 0:
209 | trainloader_bar.set_postfix(trainloss='{:.3e}'.format(loss_display.avg))
210 |
211 | logger.info('Epoch {}/{}: Train loss: {:.3e}'.format(epoch, maxepochs, loss_display.avg))
212 |
213 | def set_logger(log_path):
214 | logger = logging.getLogger()
215 | logger.setLevel(logging.INFO)
216 |
217 | # Logging to a file
218 | file_handler = logging.FileHandler(log_path)
219 | file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
220 | logger.addHandler(file_handler)
221 |
222 | # Logging to console
223 | stream_handler = logging.StreamHandler()
224 | stream_handler.setFormatter(logging.Formatter('%(message)s'))
225 | logger.addHandler(stream_handler)
226 |
227 | return logger
228 |
229 | if __name__ == '__main__':
230 | args = parse_args()
231 |
232 | args.save_path = os.path.join(args.save_path, args.expname)
233 | args.pseudo_path = os.path.join(args.save_path, 'pseudo_labels')
234 |
235 | '''Setup logger'''
236 | if not os.path.exists(args.save_path):
237 | os.makedirs(args.save_path)
238 | os.makedirs(args.pseudo_path)
239 | os.makedirs(join(args.save_path, 'cmd'))
240 | os.makedirs(join(args.save_path, 'svc'))
241 | logger = set_logger(os.path.join(args.save_path, 'train.log'))
242 | logger.info(args)
243 |
244 | '''Cache code'''
245 | cache_codes(args)
246 |
247 | '''Random Seed'''
248 | seed = args.seed
249 | set_seed(seed)
250 |
251 | main(args, logger)
252 |
--------------------------------------------------------------------------------
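
Both training scripts decay the SGD learning rate with `StepLR(step_size=5, gamma=0.8)` during iterative training (train_ScanNet.py also applies it during warm-up). A worked check of what that schedule does to the default `lrs[2] = 3e-2` across the 60 iterative epochs:

```python
# Sketch: StepLR multiplies the lr by gamma every step_size epochs,
# i.e. lr(e) = lr0 * gamma ** (e // step_size).
lr0, gamma, step_size = 3e-2, 0.8, 5
for epoch in (0, 5, 10, 30, 55):
    print(epoch, lr0 * gamma ** (epoch // step_size))
# 0 -> 3.00e-2, 5 -> 2.40e-2, 10 -> 1.92e-2, 30 -> 7.86e-3, 55 -> 2.58e-3
```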