├── .gitignore
├── .gitmodules
├── LICENSE
├── NYUv2.gif
├── README.md
├── checkpoints
│   └── .placeholder
├── iso
│   ├── config
│   │   ├── iso.yaml
│   │   ├── iso_occscannet.yaml
│   │   └── iso_occscannet_mini.yaml
│   ├── data
│   │   ├── NYU
│   │   │   ├── collate.py
│   │   │   ├── nyu_dataset.py
│   │   │   ├── nyu_dm.py
│   │   │   ├── params.py
│   │   │   └── preprocess.py
│   │   ├── OccScanNet
│   │   │   ├── collate.py
│   │   │   ├── occscannet_dataset.py
│   │   │   ├── occscannet_dm.py
│   │   │   └── params.py
│   │   └── utils
│   │       ├── fusion.py
│   │       ├── helpers.py
│   │       └── torch_util.py
│   ├── loss
│   │   ├── CRP_loss.py
│   │   ├── depth_loss.py
│   │   ├── sscMetrics.py
│   │   └── ssc_loss.py
│   ├── models
│   │   ├── CRP3D.py
│   │   ├── DDR.py
│   │   ├── depthnet.py
│   │   ├── flosp.py
│   │   ├── iso.py
│   │   ├── modules.py
│   │   ├── unet2d.py
│   │   └── unet3d_nyu.py
│   └── scripts
│       ├── eval.sh
│       ├── eval_iso.py
│       ├── generate_output.py
│       ├── train.sh
│       ├── train_iso.py
│       └── visualization
│           ├── NYU_vis_pred.py
│           ├── NYU_vis_pred_2.py
│           ├── OccScanNet_vis_pred.py
│           └── kitti_vis_pred.py
├── requirements.txt
├── setup.py
└── trained_models
    └── .placeholder
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | __pycache__/
3 | *.egg-info/
4 | outputs/
5 | trained_models/*.ckpt
6 | checkpoints/*.pt
7 | iso_output/*
8 | core*
9 | images
10 | .vscode
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "depth_anything"]
2 | path = depth_anything
3 | url = https://github.com/LiheYoung/Depth-Anything.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/NYUv2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hongxiaoy/ISO/12d30d244479d52fe64bdc6402bf3dfb4d43503b/NYUv2.gif
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Monocular Occupancy Prediction for Scalable Indoor Scenes
2 |
3 |
4 | [**Hongxiao Yu**](https://orcid.org/0009-0003-9249-2726)<sup>1,2</sup> · [**Yuqi Wang**](https://orcid.org/0000-0002-6360-1431)<sup>1,2</sup> · [**Yuntao Chen**](https://orcid.org/0000-0002-9555-1897)<sup>3</sup> · [**Zhaoxiang Zhang**](https://orcid.org/0000-0003-2648-3875)<sup>1,2,3</sup>
5 |
6 | <sup>1</sup>School of Artificial Intelligence, University of Chinese Academy of Sciences (UCAS)
7 |
8 | <sup>2</sup>NLPR, MAIS, Institute of Automation, Chinese Academy of Sciences (CASIA)
9 |
10 | <sup>3</sup>Centre for Artificial Intelligence and Robotics (HKISI_CAS)
11 |
12 | **ECCV 2024**
13 |
14 | [Paper (arXiv)](https://arxiv.org/abs/2407.11730) · [Project Page](https://hongxiaoy.github.io/ISO)
15 | [Hugging Face Demo](https://huggingface.co/spaces/hongxiaoy/ISO)
16 |
17 | [PapersWithCode: 3D Semantic Scene Completion from a Single RGB Image](https://paperswithcode.com/sota/3d-semantic-scene-completion-from-a-single?p=monocular-occupancy-prediction-for-scalable)
18 |
19 |
20 | ![NYUv2 demo](NYUv2.gif)
21 |
22 |
23 |
24 | # Performance
25 |
26 | Here we compare ISO with the previous best models, MonoScene and NDC-Scene.
27 |
28 | | Method | IoU | ceiling | floor | wall | window | chair | bed | sofa | table | tvs | furniture | object | mIoU |
29 | |:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
30 | | MonoScene | 42.51 | 8.89 | 93.50 | 12.06 | 12.57 | 13.72 | 48.19 | 36.11 | 15.13 | 15.22 | 27.96 | 12.94 | 26.94 |
31 | | NDC-Scene| 44.17 | 12.02 | **93.51** | 13.11 | 13.77 | 15.83 | 49.57 | 39.87 | 17.17 | 24.57 | 31.00 | 14.96 | 29.03 |
32 | | Ours | **47.11** | **14.21** | 93.47 | **15.89** | **15.14** | **18.35** | **50.01** | **40.82** | **18.25** | **25.90** | **34.08** | **17.67** | **31.25** |
33 |
34 | We highlight the **best** results in **bold**.
35 |
36 | Pretrained models on NYUv2 can be downloaded [here](https://huggingface.co/hongxiaoy/ISO/tree/main).
37 |
38 | # Preparing ISO
39 |
40 | ## Installation
41 |
42 | 1. Create conda environment:
43 |
44 | ```
45 | $ conda create -n iso python=3.9 -y
46 | $ conda activate iso
47 | ```
48 | 2. This code was implemented with Python 3.9, PyTorch 2.0.0, and CUDA 11.7 (the command below installs PyTorch 2.2.0 with CUDA 11.8). Please install [PyTorch](https://pytorch.org/):
49 |
50 | ```
51 | $ conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=11.8 -c pytorch -c nvidia
52 | ```
53 |
54 | 3. Install the additional dependencies:
55 |
56 | ```
57 | $ git clone --recursive https://github.com/hongxiaoy/ISO.git
58 | $ cd ISO/
59 | $ pip install -r requirements.txt
60 | ```
61 |
62 | > :bulb:Note
63 | >
64 | > Change L140 in ```depth_anything/metric_depth/zoedepth/models/base_models/dpt_dinov2/dpt.py``` to
65 | >
66 | > ```self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder), pretrained=False)```
67 | >
68 | > Then, download the Depth-Anything pre-trained [model](https://github.com/LiheYoung/Depth-Anything/tree/main#no-network-connection-cannot-load-these-models) and metric depth [model](https://github.com/LiheYoung/Depth-Anything/tree/main/metric_depth#evaluation) checkpoint files to ```checkpoints/```.
69 |
70 | 4. Install tbb:
71 |
72 | ```
73 | $ conda install -c bioconda tbb=2020.2
74 | ```
75 |
76 | 5. Finally, install ISO:
77 |
78 | ```
79 | $ pip install -e ./
80 | ```
81 |
82 | > :bulb:Note
83 | >
84 | > If you move the ISO directory to another location, you should run
85 | >
86 | > ```pip cache purge```
87 | >
88 | > then run ```pip install -e ./``` again.
89 |
90 | ## Datasets
91 |
92 | ### NYUv2
93 |
94 | 1. Download the [NYUv2 dataset](https://www.rocq.inria.fr/rits_files/computer-vision/monoscene/nyu.zip).
95 |
96 | 2. Create a folder to store NYUv2 preprocess data at `/path/to/NYU/preprocess/folder`.
97 |
98 | 3. Store paths in environment variables for faster access:
99 |
100 | ```
101 | $ export NYU_PREPROCESS=/path/to/NYU/preprocess/folder
102 | $ export NYU_ROOT=/path/to/NYU/depthbin
103 | ```
104 |
105 | > :bulb:Note
106 | >
107 | > We recommend appending the export to your shell profile, e.g.
108 | >
109 | > ```echo "export NYU_PREPROCESS=/path/to/NYU/preprocess/folder" >> ~/.bashrc```
110 | >
111 | > so that the variable persists across shell sessions.
112 |
113 | 4. Preprocess the data to generate labels at a lower scale, which are used to compute the ground truth relation matrices:
114 |
115 | ```
116 | $ cd ISO/
117 | $ python iso/data/NYU/preprocess.py NYU_root=$NYU_ROOT NYU_preprocess_root=$NYU_PREPROCESS
118 | ```
119 |
120 | ### Occ-ScanNet
121 |
122 | 1. Download the [Occ-ScanNet dataset](https://huggingface.co/datasets/hongxiaoy/OccScanNet), which includes:
123 | - `posed_images`
124 | - `gathered_data`
125 | - `train_subscenes.txt`
126 | - `val_subscenes.txt`
127 |
128 | 2. Create a root folder `/path/to/Occ/ScanNet/folder` for the Occ-ScanNet dataset, move all the dataset files into it, and extract the zip files.
129 |
130 | 3. Store paths in environment variables for faster access:
131 |
132 | ```
133 | $ export OCC_SCANNET_ROOT=/path/to/Occ/ScanNet/folder
134 | ```
135 |
136 | > :bulb:Note
137 | >
138 | > We recommend appending the export to your shell profile, e.g.
139 | >
140 | > ```echo "export OCC_SCANNET_ROOT=/path/to/Occ/ScanNet/folder" >> ~/.bashrc```
141 | >
142 | > so that the variable persists across shell sessions.
143 |
144 | ## Pretrained Models
145 |
146 | Download ISO pretrained models [on NYUv2](https://huggingface.co/hongxiaoy/ISO/tree/main), then put them in the folder `/path/to/ISO/trained_models`.
147 |
148 | ```bash
149 | huggingface-cli download --repo-type model hongxiaoy/ISO
150 | ```
151 |
152 | If you haven't installed `huggingface-cli` before, please follow the [official instructions](https://huggingface.co/docs/hub/en/models-adding-libraries#installation).
153 |
154 | # Running ISO
155 |
156 | ## Training
157 |
158 | ### NYUv2
159 |
160 | 1. Create folders to store training logs at **/path/to/NYU/logdir**.
161 |
162 | 2. Store the path in an environment variable:
163 |
164 | ```
165 | $ export NYU_LOG=/path/to/NYU/logdir
166 | ```
167 |
168 | 3. Train ISO on NYUv2 using 2 GPUs with a batch size of 4 (2 items per GPU):
169 | ```
170 | $ cd ISO/
171 | $ python iso/scripts/train_iso.py \
172 | dataset=NYU \
173 | NYU_root=$NYU_ROOT \
174 | NYU_preprocess_root=$NYU_PREPROCESS \
175 | logdir=$NYU_LOG \
176 | n_gpus=2 batch_size=4
177 | ```
178 |
179 | ### Occ-ScanNet
180 |
181 | 1. Create folders to store training logs at **/path/to/OccScanNet/logdir**.
182 |
184 | 2. Store the path in an environment variable:
184 |
185 | ```
186 | $ export OCC_SCANNET_LOG=/path/to/OccScanNet/logdir
187 | ```
188 |
189 | 3. Train ISO on Occ-ScanNet using 2 GPUs with a batch size of 4 (2 items per GPU); the dataset name should match the config file name used in train_iso.py:
190 | ```
191 | $ cd ISO/
192 | $ python iso/scripts/train_iso.py \
193 | dataset=OccScanNet \
194 | OccScanNet_root=$OCC_SCANNET_ROOT \
195 | logdir=$OCC_SCANNET_LOG \
196 | n_gpus=2 batch_size=4
197 | ```
198 |
199 | ## Evaluating
200 |
201 | ### NYUv2
202 |
203 | To evaluate ISO on NYUv2 test set, type:
204 |
205 | ```
206 | $ cd ISO/
207 | $ python iso/scripts/eval_iso.py \
208 | dataset=NYU \
209 | NYU_root=$NYU_ROOT \
210 | NYU_preprocess_root=$NYU_PREPROCESS \
211 | n_gpus=1 batch_size=1
212 | ```
213 |
214 | ## Inference
215 |
216 | Please create a folder **/path/to/iso/output** to store the ISO outputs and store its path in an environment variable:
217 |
218 | ```
219 | export ISO_OUTPUT=/path/to/iso/output
220 | ```
221 |
222 | ### NYUv2
223 |
224 | To generate the predictions on the NYUv2 test set, type:
225 |
226 | ```
227 | $ cd ISO/
228 | $ python iso/scripts/generate_output.py \
229 | +output_path=$ISO_OUTPUT \
230 | dataset=NYU \
231 | NYU_root=$NYU_ROOT \
232 | NYU_preprocess_root=$NYU_PREPROCESS \
233 | n_gpus=1 batch_size=1
234 | ```
235 |
236 | ## Visualization
237 |
238 | You need to create a new Anaconda environment for visualization.
239 |
240 | ```bash
241 | conda create -n mayavi_vis python=3.7 -y
242 | conda activate mayavi_vis
243 | pip install omegaconf hydra-core PyQt5 mayavi
244 | ```
245 |
246 | If you run into problems when installing `mayavi`, please refer to the following instructions:
247 |
248 | - [Official mayavi installation instruction](https://docs.enthought.com/mayavi/installation.html)
249 |
250 |
251 | ### NYUv2
252 | ```
253 | $ cd ISO/
254 | $ python iso/scripts/visualization/NYU_vis_pred.py +file=/path/to/output/file.pkl
255 | ```
256 |
257 |
258 | # Acknowledgement
259 |
260 | This project is built on top of MonoScene. Please refer to [MonoScene](https://github.com/astra-vision/MonoScene) for more documentation and details.
261 |
262 | We would like to thank the creators, maintainers, and contributors of [MonoScene](https://github.com/astra-vision/MonoScene), [NDC-Scene](https://github.com/Jiawei-Yao0812/NDCScene), [ZoeDepth](https://github.com/isl-org/ZoeDepth), and [Depth Anything](https://github.com/LiheYoung/Depth-Anything) for their invaluable work. Their dedication and open-source spirit have been instrumental in our development.
263 |
264 | # Citation
265 |
266 | ```
267 | @article{yu2024monocular,
268 | title={Monocular Occupancy Prediction for Scalable Indoor Scenes},
269 | author={Yu, Hongxiao and Wang, Yuqi and Chen, Yuntao and Zhang, Zhaoxiang},
270 | journal={arXiv preprint arXiv:2407.11730},
271 | year={2024}
272 | }
273 | ```
274 |
--------------------------------------------------------------------------------
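
As an alternative to the `huggingface-cli` command in the Pretrained Models section of the README above, the same checkpoints can be fetched from Python. A minimal sketch, assuming the `huggingface_hub` package is installed; `trained_models/` is the folder the README expects the weights in:

```python
from huggingface_hub import snapshot_download

# Download the ISO pretrained weights (NYUv2) into trained_models/.
# repo_id comes from the README; local_dir is the folder the scripts expect.
snapshot_download(repo_id="hongxiaoy/ISO", repo_type="model", local_dir="trained_models")
```
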
/checkpoints/.placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hongxiaoy/ISO/12d30d244479d52fe64bdc6402bf3dfb4d43503b/checkpoints/.placeholder
--------------------------------------------------------------------------------
/iso/config/iso.yaml:
--------------------------------------------------------------------------------
1 | dataset: "NYU" # "kitti", "kitti_360"
2 | #dataset: "kitti_360"
3 |
4 | n_relations: 4
5 |
6 | enable_log: true
7 | #kitti_root: '/path/to/semantic_kitti'
8 | #kitti_preprocess_root: '/path/to/kitti/preprocess/folder'
9 | #kitti_logdir: '/path/to/semantic_kitti/logdir'
10 |
11 | NYU_root: "/mnt/vdb1/hongxiao.yu/data/NYU_dataset/depthbin/"
12 | NYU_preprocess_root: "/mnt/vdb1/hongxiao.yu/data/NYU_dataset/preprocess/"
13 | logdir: "/mnt/vdb1/hongxiao.yu/logs/ISO_logs/"
14 |
15 |
16 | fp_loss: true
17 | frustum_size: 8
18 | batch_size: 1
19 | n_gpus: 1
20 | num_workers_per_gpu: 0
21 | exp_prefix: "iso"
22 | run: 1
23 | lr: 1e-4
24 | weight_decay: 1e-4
25 |
26 | context_prior: true
27 |
28 | relation_loss: true
29 | CE_ssc_loss: true
30 | sem_scal_loss: true
31 | geo_scal_loss: true
32 |
33 | project_1_2: true
34 | project_1_4: true
35 | project_1_8: true
36 |
37 | voxeldepth: true
38 | voxeldepthcfg:
39 | depth_scale_1: true
40 | depth_scale_2: false
41 | depth_scale_4: false
42 | depth_scale_8: false
43 |
44 | use_gt_depth: false
45 | use_zoedepth: false
46 | use_depthanything: true
47 | zoedepth_as_gt: false
48 | depthanything_as_gt: false
49 |
50 | add_fusion: true
51 |
52 | frozen_encoder: true
53 |
--------------------------------------------------------------------------------
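
The training/evaluation scripts consume this file through Hydra (the same pattern as the `@hydra.main(config_name=...)` decorator in `iso/data/NYU/preprocess.py`), and README tokens such as `n_gpus=2 batch_size=4` override these keys on the command line. A minimal inspection sketch, assuming `omegaconf` is available in the environment:

```python
from omegaconf import OmegaConf

# Load the raw YAML; Hydra applies command-line overrides on top of these
# defaults when the scripts run, so this only inspects the stored values.
cfg = OmegaConf.load("iso/config/iso.yaml")
print(cfg.dataset, cfg.batch_size, cfg.n_gpus)   # "NYU" 1 1
print(cfg.voxeldepthcfg.depth_scale_1)           # True -> per-scale voxel-depth switch
print(OmegaConf.to_yaml(cfg))                    # full config as YAML text
```
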
/iso/config/iso_occscannet.yaml:
--------------------------------------------------------------------------------
1 | dataset: "OccScanNet" # "NYU", "OccScanNet"
2 |
3 | n_relations: 4
4 |
5 | enable_log: true
6 |
7 | logdir: "/mnt/vdb1/hongxiao.yu/logs/ISO_occscannet_logs/"
8 | OccScanNet_root: "/mnt/vdb1/hongxiao.yu/data/NYU_dataset/depthbin/"
9 |
10 |
11 | fp_loss: true
12 | frustum_size: 8
13 | batch_size: 1
14 | n_gpus: 1
15 | num_workers_per_gpu: 0
16 | exp_prefix: "iso"
17 | run: 1
18 | lr: 1e-4
19 | weight_decay: 1e-4
20 |
21 | context_prior: true
22 |
23 | relation_loss: true
24 | CE_ssc_loss: true
25 | sem_scal_loss: true
26 | geo_scal_loss: true
27 |
28 | project_1_2: true
29 | project_1_4: true
30 | project_1_8: true
31 |
32 | voxeldepth: true
33 | voxeldepthcfg:
34 | depth_scale_1: true
35 | depth_scale_2: false
36 | depth_scale_4: false
37 | depth_scale_8: false
38 |
39 | use_gt_depth: false
40 | use_zoedepth: false
41 | use_depthanything: true
42 | zoedepth_as_gt: false
43 | depthanything_as_gt: false
44 |
45 | add_fusion: true
46 |
47 | frozen_encoder: true
48 |
--------------------------------------------------------------------------------
/iso/config/iso_occscannet_mini.yaml:
--------------------------------------------------------------------------------
1 | dataset: "OccScanNet_mini" # "NYU", "OccScanNet", "OccScanNet_mini"
2 |
3 | n_relations: 4
4 |
5 | enable_log: true
6 |
7 | logdir: "/mnt/vdb1/hongxiao.yu/logs/ISO_occscannet_logs/"
8 | OccScanNet_root: "/home/hongxiao.yu/projects/ISO_occscannet_2/occ_data_root"
9 |
10 |
11 | fp_loss: true
12 | frustum_size: 8
13 | batch_size: 1
14 | n_gpus: 1
15 | num_workers_per_gpu: 0
16 | exp_prefix: "iso"
17 | run: 1
18 | lr: 1e-4
19 | weight_decay: 1e-4
20 |
21 | context_prior: true
22 |
23 | relation_loss: true
24 | CE_ssc_loss: true
25 | sem_scal_loss: true
26 | geo_scal_loss: true
27 |
28 | project_1_2: true
29 | project_1_4: true
30 | project_1_8: true
31 |
32 | voxeldepth: true
33 | voxeldepthcfg:
34 | depth_scale_1: true
35 | depth_scale_2: false
36 | depth_scale_4: false
37 | depth_scale_8: false
38 |
39 | use_gt_depth: false
40 | use_zoedepth: false
41 | use_depthanything: true
42 | zoedepth_as_gt: false
43 | depthanything_as_gt: false
44 |
45 | add_fusion: true
46 |
47 | frozen_encoder: true
48 |
--------------------------------------------------------------------------------
/iso/data/NYU/collate.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def collate_fn(batch):
5 | data = {}
6 | imgs = []
7 | raw_imgs = []
8 | depth_gts = []
9 | targets = []
10 | names = []
11 | cam_poses = []
12 |
13 | pix_z = []
14 |
15 | vox_origins = []
16 | cam_ks = []
17 |
18 | CP_mega_matrices = []
19 |
20 | data["projected_pix_1"] = []
21 | data["fov_mask_1"] = []
22 | data["frustums_masks"] = []
23 | data["frustums_class_dists"] = []
24 |
25 | for idx, input_dict in enumerate(batch):
26 | CP_mega_matrices.append(torch.from_numpy(input_dict["CP_mega_matrix"]))
27 | for key in data:
28 | if key in input_dict:
29 | data[key].append(torch.from_numpy(input_dict[key]))
30 |
31 | cam_ks.append(torch.from_numpy(input_dict["cam_k"]).double())
32 | cam_poses.append(torch.from_numpy(input_dict["cam_pose"]).float())
33 | vox_origins.append(torch.from_numpy(input_dict["voxel_origin"]).double())
34 |
35 | pix_z.append(torch.from_numpy(input_dict['pix_z']).float())
36 |
37 | names.append(input_dict["name"])
38 |
39 | img = input_dict["img"]
40 | imgs.append(img)
41 |
42 | raw_img = input_dict['raw_img']
43 | raw_imgs.append(raw_img)
44 |
45 | depth_gt = torch.from_numpy(input_dict['depth_gt'])
46 | depth_gts.append(depth_gt)
47 |
48 | target = torch.from_numpy(input_dict["target"])
49 | targets.append(target)
50 |
51 | ret_data = {
52 | "CP_mega_matrices": CP_mega_matrices,
53 | "cam_pose": torch.stack(cam_poses),
54 | "cam_k": torch.stack(cam_ks),
55 | "vox_origin": torch.stack(vox_origins),
56 | "name": names,
57 | "img": torch.stack(imgs),
58 | # "raw_img": torch.stack(raw_imgs),
59 | "raw_img": raw_imgs,
60 | 'depth_gt': torch.stack(depth_gts),
61 | "target": torch.stack(targets),
62 | 'pix_z': torch.stack(pix_z),
63 | }
64 | for key in data:
65 | ret_data[key] = data[key]
66 | return ret_data
67 |
--------------------------------------------------------------------------------
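
`collate_fn` above stacks tensor-valued fields (`img`, `target`, `depth_gt`, the camera parameters, `pix_z`) into batch tensors while keeping `raw_img`, `CP_mega_matrices`, and the projection/frustum fields as per-sample Python lists. A minimal sketch with made-up dummy values, only to show which fields get stacked and which stay as lists (real samples come from `NYUDataset.__getitem__`):

```python
import numpy as np
import torch
from iso.data.NYU.collate import collate_fn

# One fake sample carrying the keys collate_fn reads; shapes below are dummies,
# except the 1:4 NYU label grid, which really is 60 x 36 x 60.
sample = {
    "CP_mega_matrix": np.zeros((4, 64, 64), dtype=np.float32),   # dummy shape
    "cam_k": np.eye(3),
    "cam_pose": np.eye(4, dtype=np.float32),
    "voxel_origin": np.zeros(3),
    "pix_z": np.zeros(16, dtype=np.float32),                     # dummy shape
    "name": "NYU0001_0000",
    "img": torch.zeros(3, 480, 640),
    "raw_img": None,                                             # kept as-is in a list
    "depth_gt": np.zeros((1, 480, 640), dtype=np.float32),
    "target": np.zeros((60, 36, 60), dtype=np.int64),
    "projected_pix_1": np.zeros((16, 2), dtype=np.int64),        # dummy shape
    "fov_mask_1": np.zeros(16, dtype=bool),                      # dummy shape
    "frustums_masks": np.zeros((64, 16), dtype=bool),            # dummy shape
    "frustums_class_dists": np.zeros((64, 12), dtype=np.float32),
}
batch = collate_fn([sample, sample])
print(batch["img"].shape)              # torch.Size([2, 3, 480, 640]) -- stacked
print(batch["target"].shape)           # torch.Size([2, 60, 36, 60]) -- stacked
print(len(batch["raw_img"]))           # 2 -- left as a plain list
print(len(batch["projected_pix_1"]))   # 2 -- list of per-sample tensors
```
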
/iso/data/NYU/nyu_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import glob
4 | from torch.utils.data import Dataset
5 | import numpy as np
6 | from PIL import Image, ImageTransform
7 | from torchvision import transforms
8 | from iso.data.utils.helpers import (
9 | vox2pix,
10 | compute_local_frustums,
11 | compute_CP_mega_matrix,
12 | )
13 | import pickle
14 | import torch.nn.functional as F
15 | import copy
16 |
17 |
18 | def read_depth(depth_path):
19 | depth_vis = Image.open(depth_path).convert('I;16')
20 | depth_vis_array = np.array(depth_vis)
21 |
22 | arr1 = np.right_shift(depth_vis_array, 3)
23 | arr2 = np.left_shift(depth_vis_array, 13)
24 | depth_vis_array = np.bitwise_or(arr1, arr2)
25 |
26 | depth_inpaint = depth_vis_array.astype(np.float32) / 1000.0
27 |
28 | return depth_inpaint
29 |
30 |
31 | class NYUDataset(Dataset):
32 | def __init__(
33 | self,
34 | split,
35 | root,
36 | preprocess_root,
37 | n_relations=4,
38 | color_jitter=None,
39 | frustum_size=4,
40 | fliplr=0.0,
41 | ):
42 | self.n_relations = n_relations
43 | self.frustum_size = frustum_size
44 | self.n_classes = 12
45 | self.root = os.path.join(root, "NYU" + split)
46 | self.preprocess_root = preprocess_root
47 | self.base_dir = os.path.join(preprocess_root, "base", "NYU" + split)
48 | self.fliplr = fliplr
49 |
50 | self.voxel_size = 0.08 # 0.08m
51 | self.scene_size = (4.8, 4.8, 2.88) # (4.8m, 4.8m, 2.88m)
52 | self.img_W = 640
53 | self.img_H = 480
54 | self.cam_k = np.array([[518.8579, 0, 320], [0, 518.8579, 240], [0, 0, 1]])
55 |
56 | self.color_jitter = (
57 | transforms.ColorJitter(*color_jitter) if color_jitter else None
58 | )
59 |
60 | self.scan_names = glob.glob(os.path.join(self.root, "*.bin"))
61 |
62 | self.normalize_rgb = transforms.Compose(
63 | [
64 | transforms.ToTensor(),
65 | transforms.Normalize(
66 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
67 | ),
68 | ]
69 | )
70 |
71 | def __getitem__(self, index):
72 | file_path = self.scan_names[index]
73 | filename = os.path.basename(file_path)
74 | name = filename[:-4]
75 |
76 | os.makedirs(self.base_dir, exist_ok=True)
77 | filepath = os.path.join(self.base_dir, name + ".pkl")
78 |
79 | with open(filepath, "rb") as handle:
80 | data = pickle.load(handle)
81 |
82 | cam_pose = data["cam_pose"]
83 | T_world_2_cam = np.linalg.inv(cam_pose)
84 | vox_origin = data["voxel_origin"]
85 | data["cam_k"] = self.cam_k
86 | target = data[
87 | "target_1_4"
88 | ] # Following SSC literature, the output resolution on NYUv2 is set to 1:4
89 | data["target"] = target
90 |         target_1_4 = data["target_1_16"]  # note: despite the name, these are the 1:16 labels used for the CP mega matrix
91 |
92 | CP_mega_matrix = compute_CP_mega_matrix(
93 | target_1_4, is_binary=self.n_relations == 2
94 | )
95 | data["CP_mega_matrix"] = CP_mega_matrix
96 |
97 | # compute the 3D-2D mapping
98 | projected_pix, fov_mask, pix_z = vox2pix(
99 | T_world_2_cam,
100 | self.cam_k,
101 | vox_origin,
102 | self.voxel_size,
103 | self.img_W,
104 | self.img_H,
105 | self.scene_size,
106 | )
107 |
108 | data["projected_pix_1"] = projected_pix
109 | data["fov_mask_1"] = fov_mask
110 |
111 | # compute the masks, each indicates voxels inside a frustum
112 | frustums_masks, frustums_class_dists = compute_local_frustums(
113 | projected_pix,
114 | pix_z,
115 | target,
116 | self.img_W,
117 | self.img_H,
118 | dataset="NYU",
119 | n_classes=12,
120 | size=self.frustum_size,
121 | )
122 | data["frustums_masks"] = frustums_masks
123 | data["frustums_class_dists"] = frustums_class_dists
124 |
125 | rgb_path = os.path.join(self.root, name + "_color.jpg")
126 | depth_path = os.path.join(self.root, name + ".png")
127 | img = Image.open(rgb_path).convert("RGB")
128 | raw_img_pil = copy.deepcopy(img)
129 | raw_img = np.array(raw_img_pil)
130 | depth_gt = read_depth(depth_path)
131 | depth_gt = np.array(depth_gt, dtype=np.float32, copy=False)
132 |
133 | # Image augmentation
134 | if self.color_jitter is not None:
135 | img = self.color_jitter(img)
136 | # PIL to numpy
137 | img = np.array(img, dtype=np.float32, copy=False) / 255.0
138 |
139 | # randomly fliplr the image
140 | if np.random.rand() < self.fliplr:
141 | img = np.ascontiguousarray(np.fliplr(img))
142 | # raw_img = np.ascontiguousarray(np.fliplr(raw_img))
143 | raw_img_pil = raw_img_pil.transpose(Image.FLIP_LEFT_RIGHT)
144 | data["projected_pix_1"][:, 0] = (
145 | img.shape[1] - 1 - data["projected_pix_1"][:, 0]
146 | )
147 |
148 | depth_gt = np.ascontiguousarray(np.fliplr(depth_gt))
149 |
150 | data["img"] = self.normalize_rgb(img) # (3, img_H, img_W)
151 | data['depth_gt'] = depth_gt[None] # (1, 480, 640)
152 | # data['raw_img'] = transforms.PILToTensor()(raw_img).float()
153 | # data['raw_img'] = transforms.ToTensor()(raw_img)
154 | data['raw_img'] = raw_img_pil
155 | # print(data['raw_img'].dtype)
156 | # print((data['raw_img'].numpy() - raw_img).max())
157 | data['pix_z'] = pix_z
158 |
159 | return data
160 |
161 | def __len__(self):
162 | return len(self.scan_names)
163 |
--------------------------------------------------------------------------------
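
A small worked example of the bit arithmetic in `read_depth` above: the 16-bit depth PNG values are rotated right by 3 bits (`right_shift` by 3 OR'd with `left_shift` by 13 on `uint16`) and divided by 1000 to obtain depth in meters. The pixel value below is made up, purely to illustrate the arithmetic:

```python
import numpy as np

# Made-up encoded pixel value; uint16 arithmetic wraps on overflow, which is
# what makes the left-shift-by-13 act as a 3-bit circular rotation.
raw = np.array([12345], dtype=np.uint16)
decoded = np.bitwise_or(np.right_shift(raw, 3), np.left_shift(raw, 13))
print(decoded[0])                          # 9735
print(decoded.astype(np.float32) / 1000)   # [9.735] -> depth in meters
```
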
/iso/data/NYU/nyu_dm.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.dataloader import DataLoader
2 | from iso.data.NYU.nyu_dataset import NYUDataset
3 | from iso.data.NYU.collate import collate_fn
4 | import pytorch_lightning as pl
5 | from iso.data.utils.torch_util import worker_init_fn
6 |
7 |
8 | class NYUDataModule(pl.LightningDataModule):
9 | def __init__(
10 | self,
11 | root,
12 | preprocess_root,
13 | n_relations=4,
14 | batch_size=4,
15 | frustum_size=4,
16 | num_workers=6,
17 | ):
18 | super().__init__()
19 | self.n_relations = n_relations
20 | self.preprocess_root = preprocess_root
21 | self.root = root
22 | self.batch_size = batch_size
23 | self.num_workers = num_workers
24 | self.frustum_size = frustum_size
25 |
26 | def setup(self, stage=None):
27 | self.train_ds = NYUDataset(
28 | split="train",
29 | preprocess_root=self.preprocess_root,
30 | n_relations=self.n_relations,
31 | root=self.root,
32 | fliplr=0.5,
33 | frustum_size=self.frustum_size,
34 | color_jitter=(0.4, 0.4, 0.4),
35 | )
36 | self.test_ds = NYUDataset(
37 | split="test",
38 | preprocess_root=self.preprocess_root,
39 | n_relations=self.n_relations,
40 | root=self.root,
41 | frustum_size=self.frustum_size,
42 | fliplr=0.0,
43 | color_jitter=None,
44 | )
45 |
46 | def train_dataloader(self):
47 | return DataLoader(
48 | self.train_ds,
49 | batch_size=self.batch_size,
50 | drop_last=True,
51 | num_workers=self.num_workers,
52 | shuffle=True,
53 | pin_memory=True,
54 | worker_init_fn=worker_init_fn,
55 | collate_fn=collate_fn,
56 | )
57 |
58 | def val_dataloader(self):
59 | return DataLoader(
60 | self.test_ds,
61 | batch_size=self.batch_size,
62 | num_workers=self.num_workers,
63 | drop_last=False,
64 | shuffle=False,
65 | pin_memory=True,
66 | collate_fn=collate_fn,
67 | )
68 |
69 | def test_dataloader(self):
70 | return DataLoader(
71 | self.test_ds,
72 | batch_size=self.batch_size,
73 | num_workers=self.num_workers,
74 | drop_last=False,
75 | shuffle=False,
76 | pin_memory=True,
77 | collate_fn=collate_fn,
78 | )
79 |
--------------------------------------------------------------------------------
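
A minimal usage sketch for `NYUDataModule` above, assuming the NYUv2 data has been downloaded and preprocessed as described in the README and that `NYU_ROOT` / `NYU_PREPROCESS` are set; the batch keys come from `collate_fn`:

```python
import os
from iso.data.NYU.nyu_dm import NYUDataModule

# Point the data module at the depthbin and preprocess folders from the README.
dm = NYUDataModule(
    root=os.environ["NYU_ROOT"],
    preprocess_root=os.environ["NYU_PREPROCESS"],
    batch_size=1,
    frustum_size=8,
    num_workers=0,
)
dm.setup()
batch = next(iter(dm.train_dataloader()))
print(batch["img"].shape)     # (1, 3, 480, 640)
print(batch["target"].shape)  # one 1:4 label volume per sample
```
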
/iso/data/NYU/params.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | NYU_class_names = [
5 | "empty",
6 | "ceiling",
7 | "floor",
8 | "wall",
9 | "window",
10 | "chair",
11 | "bed",
12 | "sofa",
13 | "table",
14 | "tvs",
15 | "furn",
16 | "objs",
17 | ]
18 | class_weights = torch.FloatTensor([0.05, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
19 |
20 | class_freq_1_4 = np.array(
21 | [
22 | 43744234,
23 | 80205,
24 | 1070052,
25 | 905632,
26 | 116952,
27 | 180994,
28 | 436852,
29 | 279714,
30 | 254611,
31 | 28247,
32 | 1805949,
33 | 850724,
34 | ]
35 | )
36 | class_freq_1_8 = np.array(
37 | [
38 | 5176253,
39 | 17277,
40 | 220105,
41 | 183849,
42 | 21827,
43 | 33520,
44 | 67022,
45 | 44248,
46 | 46615,
47 | 4419,
48 | 290218,
49 | 142573,
50 | ]
51 | )
52 | class_freq_1_16 = np.array(
53 | [587620, 3820, 46836, 36256, 4241, 5978, 10939, 8000, 8224, 781, 49778, 25864]
54 | )
55 |
--------------------------------------------------------------------------------
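
The frequency arrays above make the NYUv2 class imbalance explicit; the dominant `empty` class is the one down-weighted to 0.05 in `class_weights`. A small inspection sketch using only the values defined in this file:

```python
import numpy as np
from iso.data.NYU.params import NYU_class_names, class_freq_1_4

# Per-class share of voxels at the 1:4 scale, printed as percentages.
freq = class_freq_1_4 / class_freq_1_4.sum()
for name, f in zip(NYU_class_names, freq):
    print(f"{name:10s} {100 * f:6.2f}%")
```
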
/iso/data/NYU/preprocess.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | import numpy.matlib
4 | import os
5 | import glob
6 | import pickle
7 | import hydra
8 | from omegaconf import DictConfig
9 |
10 |
11 | seg_class_map = [
12 | 0,
13 | 1,
14 | 2,
15 | 3,
16 | 4,
17 | 11,
18 | 5,
19 | 6,
20 | 7,
21 | 8,
22 | 8,
23 | 10,
24 | 10,
25 | 10,
26 | 11,
27 | 11,
28 | 9,
29 | 8,
30 | 11,
31 | 11,
32 | 11,
33 | 11,
34 | 11,
35 | 11,
36 | 11,
37 | 11,
38 | 11,
39 | 10,
40 | 10,
41 | 11,
42 | 8,
43 | 10,
44 | 11,
45 | 9,
46 | 11,
47 | 11,
48 | 11,
49 | ]
50 |
51 |
52 | def _rle2voxel(rle, voxel_size=(240, 144, 240), rle_filename=""):
53 |     r"""Read voxel label data from an RLE-compressed file and expand it into a fully labeled voxel grid.
54 |     Code taken from https://github.com/waterljwant/SSC/blob/master/dataloaders/dataloader.py#L172
55 |     In the PyTorch data loader, only a single thread is allowed.
56 |     For a multi-threaded version and more details, see 'readRLE.py'.
57 |     Output: seg_label, a 3D numpy array of size 240 x 144 x 240.
58 |     """
59 | seg_label = np.zeros(
60 | int(voxel_size[0] * voxel_size[1] * voxel_size[2]), dtype=np.uint8
61 | ) # segmentation label
62 | vox_idx = 0
63 | for idx in range(int(rle.shape[0] / 2)):
64 | check_val = rle[idx * 2]
65 | check_iter = rle[idx * 2 + 1]
66 | if check_val >= 37 and check_val != 255: # 37 classes to 12 classes
67 | print("RLE {} check_val: {}".format(rle_filename, check_val))
68 | seg_label_val = (
69 | seg_class_map[check_val] if check_val != 255 else 255
70 | ) # 37 classes to 12 classes
71 | seg_label[vox_idx : vox_idx + check_iter] = np.matlib.repmat(
72 | seg_label_val, 1, check_iter
73 | )
74 | vox_idx = vox_idx + check_iter
75 | seg_label = seg_label.reshape(voxel_size) # 3D array, size 240 x 144 x 240
76 | return seg_label
77 |
78 |
79 | def _read_rle(rle_filename): # 0.0005s
80 | """Read RLE compression data
81 | code taken from https://github.com/waterljwant/SSC/blob/master/dataloaders/dataloader.py#L153
82 | Return:
83 | vox_origin,
84 | cam_pose,
85 | vox_rle, voxel label data from file
86 | Shape:
87 |         vox_rle: 1-D run-length array (decodes to a 240 x 144 x 240 label grid)
88 | """
89 | fid = open(rle_filename, "rb")
90 | vox_origin = np.fromfile(
91 | fid, np.float32, 3
92 | ).T # Read voxel origin in world coordinates
93 | cam_pose = np.fromfile(fid, np.float32, 16).reshape((4, 4)) # Read camera pose
94 | vox_rle = (
95 | np.fromfile(fid, np.uint32).reshape((-1, 1)).T
96 | ) # Read voxel label data from file
97 | vox_rle = np.squeeze(vox_rle) # 2d array: (1 x N), to 1d array: (N , )
98 | fid.close()
99 | return vox_origin, cam_pose, vox_rle
100 |
101 |
102 | def _downsample_label(label, voxel_size=(240, 144, 240), downscale=4):
103 | r"""downsample the labeled data,
104 | code taken from https://github.com/waterljwant/SSC/blob/master/dataloaders/dataloader.py#L262
105 | Shape:
106 | label, (240, 144, 240)
107 | label_downscale, if downsample==4, then (60, 36, 60)
108 | """
109 | if downscale == 1:
110 | return label
111 | ds = downscale
112 | small_size = (
113 | voxel_size[0] // ds,
114 | voxel_size[1] // ds,
115 | voxel_size[2] // ds,
116 | ) # small size
117 | label_downscale = np.zeros(small_size, dtype=np.uint8)
118 | empty_t = 0.95 * ds * ds * ds # threshold
119 | s01 = small_size[0] * small_size[1]
120 | label_i = np.zeros((ds, ds, ds), dtype=np.int32)
121 |
122 | for i in range(small_size[0] * small_size[1] * small_size[2]):
123 | z = int(i / s01)
124 | y = int((i - z * s01) / small_size[0])
125 | x = int(i - z * s01 - y * small_size[0])
126 |
127 | label_i[:, :, :] = label[
128 | x * ds : (x + 1) * ds, y * ds : (y + 1) * ds, z * ds : (z + 1) * ds
129 | ]
130 | label_bin = label_i.flatten()
131 |
132 | zero_count_0 = np.array(np.where(label_bin == 0)).size
133 | zero_count_255 = np.array(np.where(label_bin == 255)).size
134 |
135 | zero_count = zero_count_0 + zero_count_255
136 | if zero_count > empty_t:
137 | label_downscale[x, y, z] = 0 if zero_count_0 > zero_count_255 else 255
138 | else:
139 | label_i_s = label_bin[
140 | np.where(np.logical_and(label_bin > 0, label_bin < 255))
141 | ]
142 | label_downscale[x, y, z] = np.argmax(np.bincount(label_i_s))
143 | return label_downscale
144 |
145 |
146 | @hydra.main(config_name="../../config/iso.yaml")
147 | def main(config: DictConfig):
148 | scene_size = (240, 144, 240)
149 | for split in ["train", "test"]:
150 | root = os.path.join(config.NYU_root, "NYU" + split)
151 | base_dir = os.path.join(config.NYU_preprocess_root, "base", "NYU" + split)
152 | os.makedirs(base_dir, exist_ok=True)
153 |
154 | scans = glob.glob(os.path.join(root, "*.bin"))
155 | for scan in tqdm(scans):
156 | filename = os.path.basename(scan)
157 | name = filename[:-4]
158 | filepath = os.path.join(base_dir, name + ".pkl")
159 | if os.path.exists(filepath):
160 | continue
161 |
162 | vox_origin, cam_pose, rle = _read_rle(scan)
163 |
164 | target_1_1 = _rle2voxel(rle, scene_size, scan)
165 | target_1_4 = _downsample_label(target_1_1, scene_size, 4)
166 | target_1_16 = _downsample_label(target_1_1, scene_size, 16)
167 |
168 | data = {
169 | "cam_pose": cam_pose,
170 | "voxel_origin": vox_origin,
171 | "name": name,
172 | "target_1_4": target_1_4,
173 | "target_1_16": target_1_16,
174 | }
175 |
176 | with open(filepath, "wb") as handle:
177 | pickle.dump(data, handle)
178 | print("wrote to", filepath)
179 |
180 |
181 | if __name__ == "__main__":
182 | main()
183 |
--------------------------------------------------------------------------------
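
A small sanity-check sketch for the majority-vote downsampling in `_downsample_label` above, run on a tiny synthetic volume. It assumes the package has been installed with `pip install -e ./` so that `iso.data.NYU.preprocess` (and its `hydra`/`omegaconf` imports) can be imported:

```python
import numpy as np
from iso.data.NYU.preprocess import _downsample_label

# An 8x8x8 toy label volume: one 4x4x4 block fully labeled as class 3,
# everything else empty (0).
label = np.zeros((8, 8, 8), dtype=np.uint8)
label[:4, :4, :4] = 3

small = _downsample_label(label, voxel_size=(8, 8, 8), downscale=4)
print(small.shape)        # (2, 2, 2)
print(small[0, 0, 0])     # 3 -> the occupied block keeps its majority class
print(small[1, 1, 1])     # 0 -> a block that is mostly empty stays empty
```
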
/iso/data/OccScanNet/collate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def collate_fn(batch):
6 | data = {}
7 | imgs = []
8 | raw_imgs = []
9 | depth_gts = []
10 | targets = []
11 | names = []
12 | cam_poses = []
13 |
14 | pix_z = []
15 |
16 | vox_origins = []
17 | cam_ks = []
18 |
19 | CP_mega_matrices = []
20 |
21 | data["projected_pix_1"] = []
22 | data["fov_mask_1"] = []
23 | data["frustums_masks"] = []
24 | data["frustums_class_dists"] = []
25 |
26 | for idx, input_dict in enumerate(batch):
27 | try:
28 | CP_mega_matrices.append(torch.from_numpy(input_dict["CP_mega_matrix"]))
29 |         except:  # CP_mega_matrix may already be a torch tensor, which torch.from_numpy rejects
30 | CP_mega_matrices.append(input_dict["CP_mega_matrix"])
31 | for key in data:
32 | if key in input_dict:
33 | data[key].append(torch.from_numpy(input_dict[key]))
34 |
35 | cam_ks.append(torch.from_numpy(input_dict["cam_k"]).double())
36 | cam_poses.append(torch.from_numpy(input_dict["cam_pose"]).float())
37 | vox_origins.append(torch.from_numpy(np.array(input_dict["voxel_origin"])).double())
38 |
39 | pix_z.append(torch.from_numpy(input_dict['pix_z']).float())
40 |
41 | names.append(input_dict["name"])
42 |
43 | img = input_dict["img"]
44 | imgs.append(img)
45 | raw_img = input_dict['raw_img']
46 | raw_imgs.append(raw_img)
47 |
48 | depth_gt = torch.from_numpy(input_dict['depth_gt'])
49 | depth_gts.append(depth_gt)
50 |
51 | target = torch.from_numpy(input_dict["target"])
52 | targets.append(target)
53 |
54 | ret_data = {
55 | "CP_mega_matrices": CP_mega_matrices,
56 | "cam_pose": torch.stack(cam_poses),
57 | "cam_k": torch.stack(cam_ks),
58 | "vox_origin": torch.stack(vox_origins),
59 | "name": names,
60 | "img": torch.stack(imgs),
61 | "raw_img": raw_imgs,
62 | "target": torch.stack(targets),
63 | 'depth_gt': torch.stack(depth_gts),
64 | 'pix_z': torch.stack(pix_z),
65 | }
66 | for key in data:
67 | ret_data[key] = data[key]
68 |
69 | return ret_data
--------------------------------------------------------------------------------
/iso/data/OccScanNet/occscannet_dataset.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from pprint import pprint
3 | import torch
4 | import os
5 | import glob
6 | from torch.utils.data import Dataset
7 | import numpy as np
8 | from PIL import Image
9 | from torchvision import transforms
10 | from iso.data.utils.helpers import (
11 | vox2pix,
12 | compute_local_frustums,
13 | compute_CP_mega_matrix,
14 | )
15 | import pickle
16 | import torch.nn.functional as F
17 | import matplotlib.pyplot as plt
18 |
19 | import cv2
20 |
21 |
22 | class OccScanNetDataset(Dataset):
23 | def __init__(
24 | self,
25 | split,
26 | root,
27 | interval=-1,
28 | train_scenes_sample=-1,
29 | val_scenes_sample=-1,
30 | n_relations=4,
31 | color_jitter=None,
32 | frustum_size=4,
33 | fliplr=0.0,
34 | v2=False,
35 | ):
36 | # cur_dir = os.path.abspath(os.path.curdir)
37 | self.occscannet_root = root
38 |
39 | self.n_relations = n_relations
40 | self.frustum_size = frustum_size
41 | self.split = split
42 | self.fliplr = fliplr
43 |
44 | self.voxel_size = 0.08 # 0.08m
45 | self.scene_size = (4.8, 4.8, 2.88) # (4.8m, 4.8m, 2.88m)
46 |
47 | self.color_jitter = (
48 | transforms.ColorJitter(*color_jitter) if color_jitter else None
49 | )
50 |
51 | # print(os.getcwd())
52 | if v2:
53 | subscenes_list = f'{self.occscannet_root}/{self.split}_subscenes_v2.txt'
54 | else: # data/occscannet/train_subscenes.txt
55 | subscenes_list = f'{self.occscannet_root}/{self.split}_subscenes.txt'
56 | with open(subscenes_list, 'r') as f:
57 | self.used_subscenes = f.readlines()
58 | for i in range(len(self.used_subscenes)):
59 | self.used_subscenes[i] = f'{self.occscannet_root}/' + self.used_subscenes[i].strip()
60 |
61 | if "train" in self.split:
62 | # breakpoint()
63 | if train_scenes_sample != -1:
64 | self.used_subscenes = self.used_subscenes[:train_scenes_sample]
65 | # if interval != -1:
66 | # self.used_subscenes = self.used_subscenes[::interval]
67 | # print(f"Total train scenes number: {len(self.used_scan_names)}")
68 | # pprint(self.used_subscenes)
69 | print(f"Total train scenes number: {len(self.used_subscenes)}")
70 | elif "val" in self.split:
71 | # print(f"Total validation scenes number: {len(self.used_scan_names)}")
72 | # pprint(self.used_subscenes)
73 | if val_scenes_sample != -1:
74 | self.used_subscenes = self.used_subscenes[:val_scenes_sample]
75 | print(f"Total validation scenes number: {len(self.used_subscenes)}")
76 |
77 | self.normalize_rgb = transforms.Compose(
78 | [
79 | transforms.ToTensor(),
80 | transforms.Normalize(
81 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
82 | ),
83 | ]
84 | )
85 | # os.chdir(cur_dir)
86 |
87 | def __getitem__(self, index):
88 | name = self.used_subscenes[index]
89 | with open(name, 'rb') as f:
90 | data = pickle.load(f)
91 |
92 | cam_pose = data["cam_pose"]
93 | cam_intrin = data['intrinsic']
94 |
95 | img = cv2.imread(data['img'])
96 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
97 | data['img'] = img
98 | data['raw_img'] = copy.deepcopy(cv2.resize(img, (640, 480)))
99 | depth_img = Image.open(data['depth_gt']).convert('I;16')
100 | depth_img = np.array(depth_img) / 1000.0
101 | data['depth_gt'] = depth_img
102 | depth_gt = data['depth_gt']
103 | # print(depth_gt.shape)
104 | img_H, img_W = img.shape[0], img.shape[1]
105 | img = cv2.resize(img, (640, 480))
106 | # plt.imshow(img)
107 | # plt.savefig('img.jpg')
108 | W_factor = 640 / img_W
109 | H_factor = 480 / img_H
110 | img_H, img_W = img.shape[0], img.shape[1]
111 |
112 | cam_intrin[0, 0] *= W_factor
113 | cam_intrin[1, 1] *= H_factor
114 | cam_intrin[0, 2] *= W_factor
115 | cam_intrin[1, 2] *= H_factor
116 |
117 | data["cam_pose"] = cam_pose
118 | T_world_2_cam = np.linalg.inv(cam_pose)
119 | vox_origin = list(data["voxel_origin"])
120 | vox_origin = np.array(vox_origin)
121 | data["vox_origin"] = vox_origin
122 | data["cam_k"] = cam_intrin[:3, :3][None]
123 |
124 |
125 | target = data[
126 | "target_1_4"
127 | ] # Following SSC literature, the output resolution on NYUv2 is set to 1:4
128 | target = np.where(target == 255, 0, target)
129 | data["target"] = target
130 | data["target_1_4"] = target
131 |         target_1_4 = data["target_1_16"]  # note: despite the name, these are the 1:16 labels used for the CP mega matrix
132 | target_1_4 = np.where(target_1_4 == 255, 0, target_1_4)
133 | data["target_1_16"] = target_1_4
134 |
135 | CP_mega_matrix = compute_CP_mega_matrix(
136 | target_1_4, is_binary=self.n_relations == 2
137 | )
138 |
139 | data["CP_mega_matrix"] = CP_mega_matrix
140 |
141 | # compute the 3D-2D mapping
142 | projected_pix, fov_mask, pix_z = vox2pix(
143 | T_world_2_cam,
144 | cam_intrin,
145 | vox_origin,
146 | self.voxel_size,
147 | img_W,
148 | img_H,
149 | self.scene_size,
150 | )
151 | # print(projected_pix)
152 | # print(fov_mask.shape)
153 |
154 | data["projected_pix_1"] = projected_pix
155 | data["fov_mask_1"] = fov_mask
156 | data['pix_z'] = pix_z
157 |
158 | # compute the masks, each indicates voxels inside a frustum
159 |
160 | frustums_masks, frustums_class_dists = compute_local_frustums(
161 | projected_pix,
162 | pix_z,
163 | target,
164 | img_W,
165 | img_H,
166 | dataset="OccScanNet",
167 | n_classes=12,
168 | size=self.frustum_size,
169 | )
170 | data["frustums_masks"] = frustums_masks
171 | data["frustums_class_dists"] = frustums_class_dists
172 |
173 | img = Image.fromarray(img).convert('RGB')
174 |
175 | # Image augmentation
176 | if self.color_jitter is not None:
177 | img = self.color_jitter(img)
178 |
179 | # PIL to numpy
180 | img = np.array(img, dtype=np.float32, copy=False) / 255.0
181 |
182 | if np.random.rand() < self.fliplr:
183 | img = np.ascontiguousarray(np.fliplr(img))
184 | # raw_img = np.ascontiguousarray(np.fliplr(raw_img))
185 | data['raw_img'] = np.ascontiguousarray(np.fliplr(data['raw_img']))
186 | data["projected_pix_1"][:, 0] = (
187 | img.shape[1] - 1 - data["projected_pix_1"][:, 0]
188 | )
189 |
190 | depth_gt = np.ascontiguousarray(np.fliplr(depth_gt))
191 | data['depth_gt'] = depth_gt
192 | # print(depth_gt.shape)
193 |
194 | data["img"] = self.normalize_rgb(img) # (3, img_H, img_W)
195 | data["name"] = name
196 |
197 | return data
198 |
199 | def __len__(self):
200 | if 'train' in self.split:
201 | return len(self.used_subscenes)
202 | elif 'val' in self.split:
203 | return len(self.used_subscenes)
204 |
205 | def test():
206 |     # Smoke test; OccScanNetDataset requires a dataset root (see OCC_SCANNET_ROOT in the README).
207 |     train_ds = OccScanNetDataset(
208 |         split='train', root=os.environ["OCC_SCANNET_ROOT"],
209 |         fliplr=0.5, frustum_size=8,
210 |         color_jitter=(0.4, 0.4, 0.4),
211 |     )
212 |     val_ds = OccScanNetDataset(
213 |         split='val', root=os.environ["OCC_SCANNET_ROOT"],
214 |         frustum_size=8,
215 |         color_jitter=(0.4, 0.4, 0.4),
216 |     )
217 |
218 | if __name__ == "__main__":
219 | test()
--------------------------------------------------------------------------------
/iso/data/OccScanNet/occscannet_dm.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.dataloader import DataLoader
2 | from iso.data.OccScanNet.occscannet_dataset import OccScanNetDataset
3 | from iso.data.OccScanNet.collate import collate_fn
4 | import pytorch_lightning as pl
5 | from iso.data.utils.torch_util import worker_init_fn
6 |
7 |
8 | class OccScanNetDataModule(pl.LightningDataModule):
9 | def __init__(
10 | self,
11 | root,
12 | n_relations=4,
13 | batch_size=4,
14 | frustum_size=4,
15 | num_workers=6,
16 | interval=-1,
17 | train_scenes_sample=-1,
18 | val_scenes_sample=-1,
19 | v2=False,
20 | ):
21 | super().__init__()
22 |
23 | self.root = root
24 | self.n_relations = n_relations
25 |
26 | self.batch_size = batch_size
27 | self.num_workers = num_workers
28 | self.frustum_size = frustum_size
29 |
30 | self.train_scenes_sample = train_scenes_sample
31 | self.val_scenes_sample = val_scenes_sample
32 |
33 | def setup(self, stage=None):
34 | self.train_ds = OccScanNetDataset(
35 | split="train",
36 | root=self.root,
37 | n_relations=self.n_relations,
38 | fliplr=0.5,
39 | train_scenes_sample=self.train_scenes_sample,
40 | frustum_size=self.frustum_size,
41 | color_jitter=(0.4, 0.4, 0.4),
42 | )
43 | self.test_ds = OccScanNetDataset(
44 | split="val",
45 | root=self.root,
46 | n_relations=self.n_relations,
47 | val_scenes_sample = self.val_scenes_sample,
48 | frustum_size=self.frustum_size,
49 | fliplr=0.0,
50 | color_jitter=None,
51 | )
52 |
53 | def train_dataloader(self):
54 | return DataLoader(
55 | self.train_ds,
56 | batch_size=self.batch_size,
57 | drop_last=True,
58 | num_workers=self.num_workers,
59 | shuffle=True,
60 | pin_memory=True,
61 | worker_init_fn=worker_init_fn,
62 | collate_fn=collate_fn,
63 | )
64 |
65 | def val_dataloader(self):
66 | return DataLoader(
67 | self.test_ds,
68 | batch_size=self.batch_size,
69 | num_workers=self.num_workers,
70 | drop_last=False,
71 | shuffle=False,
72 | pin_memory=True,
73 | collate_fn=collate_fn,
74 | )
75 |
76 | def test_dataloader(self):
77 | return DataLoader(
78 | self.test_ds,
79 | batch_size=self.batch_size,
80 | num_workers=self.num_workers,
81 | drop_last=False,
82 | shuffle=False,
83 | pin_memory=True,
84 | collate_fn=collate_fn,
85 | )
86 |
87 |
88 | def test():
89 |     import os  # the data module requires the dataset root; use the documented env variable
90 |     datamodule = OccScanNetDataModule(
91 |         root=os.environ["OCC_SCANNET_ROOT"],
92 |         n_relations=4, batch_size=64, frustum_size=8,
93 |     )
94 |     datamodule.setup()
95 |     train_data = datamodule.train_dataloader()
96 |     print(next(iter(train_data))["img"].shape)  # batches are dicts produced by collate_fn
97 |
98 |
99 | if __name__ == "__main__":
100 | test()
--------------------------------------------------------------------------------
/iso/data/OccScanNet/params.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | OccScanNet_class_names = [
5 | "empty",
6 | "ceiling",
7 | "floor",
8 | "wall",
9 | "window",
10 | "chair",
11 | "bed",
12 | "sofa",
13 | "table",
14 | "tvs",
15 | "furn",
16 | "objs",
17 | ]
18 | class_weights = torch.FloatTensor([0.05, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
--------------------------------------------------------------------------------
/iso/data/utils/fusion.py:
--------------------------------------------------------------------------------
1 | """
2 | Most of the code is taken from https://github.com/andyzeng/tsdf-fusion-python/blob/master/fusion.py
3 |
4 | @inproceedings{zeng20163dmatch,
5 | title={3DMatch: Learning Local Geometric Descriptors from RGB-D Reconstructions},
6 | author={Zeng, Andy and Song, Shuran and Nie{\ss}ner, Matthias and Fisher, Matthew and Xiao, Jianxiong and Funkhouser, Thomas},
7 | booktitle={CVPR},
8 | year={2017}
9 | }
10 | """
11 |
12 | import numpy as np
13 |
14 | from numba import njit, prange
15 | from skimage import measure
16 |
17 | FUSION_GPU_MODE = 0  # GPU fusion disabled; the CUDA branch below would need pycuda (cuda, SourceModule)
18 |
19 |
20 | class TSDFVolume:
21 | """Volumetric TSDF Fusion of RGB-D Images."""
22 |
23 | def __init__(self, vol_bnds, voxel_size, use_gpu=True):
24 | """Constructor.
25 |
26 | Args:
27 | vol_bnds (ndarray): An ndarray of shape (3, 2). Specifies the
28 | xyz bounds (min/max) in meters.
29 | voxel_size (float): The volume discretization in meters.
30 | """
31 | vol_bnds = np.asarray(vol_bnds)
32 | assert vol_bnds.shape == (3, 2), "[!] `vol_bnds` should be of shape (3, 2)."
33 |
34 | # Define voxel volume parameters
35 | self._vol_bnds = vol_bnds
36 | self._voxel_size = float(voxel_size)
37 | self._trunc_margin = 5 * self._voxel_size # truncation on SDF
38 | # self._trunc_margin = 10 # truncation on SDF
39 | self._color_const = 256 * 256
40 |
41 | # Adjust volume bounds and ensure C-order contiguous
42 | self._vol_dim = (
43 | np.ceil((self._vol_bnds[:, 1] - self._vol_bnds[:, 0]) / self._voxel_size)
44 | .copy(order="C")
45 | .astype(int)
46 | )
47 | self._vol_bnds[:, 1] = self._vol_bnds[:, 0] + self._vol_dim * self._voxel_size
48 | self._vol_origin = self._vol_bnds[:, 0].copy(order="C").astype(np.float32)
49 |
50 | print(
51 | "Voxel volume size: {} x {} x {} - # points: {:,}".format(
52 | self._vol_dim[0],
53 | self._vol_dim[1],
54 | self._vol_dim[2],
55 | self._vol_dim[0] * self._vol_dim[1] * self._vol_dim[2],
56 | )
57 | )
58 |
59 | # Initialize pointers to voxel volume in CPU memory
60 | self._tsdf_vol_cpu = np.zeros(self._vol_dim).astype(np.float32)
61 | # for computing the cumulative moving average of observations per voxel
62 | self._weight_vol_cpu = np.zeros(self._vol_dim).astype(np.float32)
63 | self._color_vol_cpu = np.zeros(self._vol_dim).astype(np.float32)
64 |
65 | self.gpu_mode = use_gpu and FUSION_GPU_MODE
66 |
67 | # Copy voxel volumes to GPU
68 | if self.gpu_mode:
69 | self._tsdf_vol_gpu = cuda.mem_alloc(self._tsdf_vol_cpu.nbytes)
70 | cuda.memcpy_htod(self._tsdf_vol_gpu, self._tsdf_vol_cpu)
71 | self._weight_vol_gpu = cuda.mem_alloc(self._weight_vol_cpu.nbytes)
72 | cuda.memcpy_htod(self._weight_vol_gpu, self._weight_vol_cpu)
73 | self._color_vol_gpu = cuda.mem_alloc(self._color_vol_cpu.nbytes)
74 | cuda.memcpy_htod(self._color_vol_gpu, self._color_vol_cpu)
75 |
76 | # Cuda kernel function (C++)
77 | self._cuda_src_mod = SourceModule(
78 | """
79 | __global__ void integrate(float * tsdf_vol,
80 | float * weight_vol,
81 | float * color_vol,
82 | float * vol_dim,
83 | float * vol_origin,
84 | float * cam_intr,
85 | float * cam_pose,
86 | float * other_params,
87 | float * color_im,
88 | float * depth_im) {
89 | // Get voxel index
90 | int gpu_loop_idx = (int) other_params[0];
91 | int max_threads_per_block = blockDim.x;
92 | int block_idx = blockIdx.z*gridDim.y*gridDim.x+blockIdx.y*gridDim.x+blockIdx.x;
93 | int voxel_idx = gpu_loop_idx*gridDim.x*gridDim.y*gridDim.z*max_threads_per_block+block_idx*max_threads_per_block+threadIdx.x;
94 | int vol_dim_x = (int) vol_dim[0];
95 | int vol_dim_y = (int) vol_dim[1];
96 | int vol_dim_z = (int) vol_dim[2];
97 | if (voxel_idx > vol_dim_x*vol_dim_y*vol_dim_z)
98 | return;
99 | // Get voxel grid coordinates (note: be careful when casting)
100 | float voxel_x = floorf(((float)voxel_idx)/((float)(vol_dim_y*vol_dim_z)));
101 | float voxel_y = floorf(((float)(voxel_idx-((int)voxel_x)*vol_dim_y*vol_dim_z))/((float)vol_dim_z));
102 | float voxel_z = (float)(voxel_idx-((int)voxel_x)*vol_dim_y*vol_dim_z-((int)voxel_y)*vol_dim_z);
103 | // Voxel grid coordinates to world coordinates
104 | float voxel_size = other_params[1];
105 | float pt_x = vol_origin[0]+voxel_x*voxel_size;
106 | float pt_y = vol_origin[1]+voxel_y*voxel_size;
107 | float pt_z = vol_origin[2]+voxel_z*voxel_size;
108 | // World coordinates to camera coordinates
109 | float tmp_pt_x = pt_x-cam_pose[0*4+3];
110 | float tmp_pt_y = pt_y-cam_pose[1*4+3];
111 | float tmp_pt_z = pt_z-cam_pose[2*4+3];
112 | float cam_pt_x = cam_pose[0*4+0]*tmp_pt_x+cam_pose[1*4+0]*tmp_pt_y+cam_pose[2*4+0]*tmp_pt_z;
113 | float cam_pt_y = cam_pose[0*4+1]*tmp_pt_x+cam_pose[1*4+1]*tmp_pt_y+cam_pose[2*4+1]*tmp_pt_z;
114 | float cam_pt_z = cam_pose[0*4+2]*tmp_pt_x+cam_pose[1*4+2]*tmp_pt_y+cam_pose[2*4+2]*tmp_pt_z;
115 | // Camera coordinates to image pixels
116 | int pixel_x = (int) roundf(cam_intr[0*3+0]*(cam_pt_x/cam_pt_z)+cam_intr[0*3+2]);
117 | int pixel_y = (int) roundf(cam_intr[1*3+1]*(cam_pt_y/cam_pt_z)+cam_intr[1*3+2]);
118 | // Skip if outside view frustum
119 | int im_h = (int) other_params[2];
120 | int im_w = (int) other_params[3];
121 | if (pixel_x < 0 || pixel_x >= im_w || pixel_y < 0 || pixel_y >= im_h || cam_pt_z<0)
122 | return;
123 | // Skip invalid depth
124 | float depth_value = depth_im[pixel_y*im_w+pixel_x];
125 | if (depth_value == 0)
126 | return;
127 | // Integrate TSDF
128 | float trunc_margin = other_params[4];
129 | float depth_diff = depth_value-cam_pt_z;
130 | if (depth_diff < -trunc_margin)
131 | return;
132 | float dist = fmin(1.0f,depth_diff/trunc_margin);
133 | float w_old = weight_vol[voxel_idx];
134 | float obs_weight = other_params[5];
135 | float w_new = w_old + obs_weight;
136 | weight_vol[voxel_idx] = w_new;
137 | tsdf_vol[voxel_idx] = (tsdf_vol[voxel_idx]*w_old+obs_weight*dist)/w_new;
138 | // Integrate color
139 | float old_color = color_vol[voxel_idx];
140 | float old_b = floorf(old_color/(256*256));
141 | float old_g = floorf((old_color-old_b*256*256)/256);
142 | float old_r = old_color-old_b*256*256-old_g*256;
143 | float new_color = color_im[pixel_y*im_w+pixel_x];
144 | float new_b = floorf(new_color/(256*256));
145 | float new_g = floorf((new_color-new_b*256*256)/256);
146 | float new_r = new_color-new_b*256*256-new_g*256;
147 | new_b = fmin(roundf((old_b*w_old+obs_weight*new_b)/w_new),255.0f);
148 | new_g = fmin(roundf((old_g*w_old+obs_weight*new_g)/w_new),255.0f);
149 | new_r = fmin(roundf((old_r*w_old+obs_weight*new_r)/w_new),255.0f);
150 | color_vol[voxel_idx] = new_b*256*256+new_g*256+new_r;
151 | }"""
152 | )
153 |
154 | self._cuda_integrate = self._cuda_src_mod.get_function("integrate")
155 |
156 | # Determine block/grid size on GPU
157 | gpu_dev = cuda.Device(0)
158 | self._max_gpu_threads_per_block = gpu_dev.MAX_THREADS_PER_BLOCK
159 | n_blocks = int(
160 | np.ceil(
161 | float(np.prod(self._vol_dim))
162 | / float(self._max_gpu_threads_per_block)
163 | )
164 | )
165 | grid_dim_x = min(gpu_dev.MAX_GRID_DIM_X, int(np.floor(np.cbrt(n_blocks))))
166 | grid_dim_y = min(
167 | gpu_dev.MAX_GRID_DIM_Y, int(np.floor(np.sqrt(n_blocks / grid_dim_x)))
168 | )
169 | grid_dim_z = min(
170 | gpu_dev.MAX_GRID_DIM_Z,
171 | int(np.ceil(float(n_blocks) / float(grid_dim_x * grid_dim_y))),
172 | )
173 | self._max_gpu_grid_dim = np.array(
174 | [grid_dim_x, grid_dim_y, grid_dim_z]
175 | ).astype(int)
176 | self._n_gpu_loops = int(
177 | np.ceil(
178 | float(np.prod(self._vol_dim))
179 | / float(
180 | np.prod(self._max_gpu_grid_dim)
181 | * self._max_gpu_threads_per_block
182 | )
183 | )
184 | )
185 |
186 | else:
187 | # Get voxel grid coordinates
188 | xv, yv, zv = np.meshgrid(
189 | range(self._vol_dim[0]),
190 | range(self._vol_dim[1]),
191 | range(self._vol_dim[2]),
192 | indexing="ij",
193 | )
194 | self.vox_coords = (
195 | np.concatenate(
196 | [xv.reshape(1, -1), yv.reshape(1, -1), zv.reshape(1, -1)], axis=0
197 | )
198 | .astype(int)
199 | .T
200 | )
201 |
202 | @staticmethod
203 | @njit(parallel=True)
204 | def vox2world(vol_origin, vox_coords, vox_size, offsets=(0.5, 0.5, 0.5)):
205 | """Convert voxel grid coordinates to world coordinates."""
206 | vol_origin = vol_origin.astype(np.float32)
207 | vox_coords = vox_coords.astype(np.float32)
208 | # print(np.min(vox_coords))
209 | cam_pts = np.empty_like(vox_coords, dtype=np.float32)
210 |
211 | for i in prange(vox_coords.shape[0]):
212 | for j in range(3):
213 | cam_pts[i, j] = (
214 | vol_origin[j]
215 | + (vox_size * vox_coords[i, j])
216 | + vox_size * offsets[j]
217 | )
218 | return cam_pts
219 |
220 | @staticmethod
221 | @njit(parallel=True)
222 | def cam2pix(cam_pts, intr):
223 | """Convert camera coordinates to pixel coordinates."""
224 | intr = intr.astype(np.float32)
225 | fx, fy = intr[0, 0], intr[1, 1]
226 | cx, cy = intr[0, 2], intr[1, 2]
227 | pix = np.empty((cam_pts.shape[0], 2), dtype=np.int64)
228 | for i in prange(cam_pts.shape[0]):
229 | pix[i, 0] = int(np.round((cam_pts[i, 0] * fx / cam_pts[i, 2]) + cx))
230 | pix[i, 1] = int(np.round((cam_pts[i, 1] * fy / cam_pts[i, 2]) + cy))
231 | return pix
232 |
233 | @staticmethod
234 | @njit(parallel=True)
235 | def integrate_tsdf(tsdf_vol, dist, w_old, obs_weight):
236 | """Integrate the TSDF volume."""
237 | tsdf_vol_int = np.empty_like(tsdf_vol, dtype=np.float32)
238 | # print(tsdf_vol.shape)
239 | w_new = np.empty_like(w_old, dtype=np.float32)
240 | for i in prange(len(tsdf_vol)):
241 | w_new[i] = w_old[i] + obs_weight
242 | tsdf_vol_int[i] = (w_old[i] * tsdf_vol[i] + obs_weight * dist[i]) / w_new[i]
243 | return tsdf_vol_int, w_new
244 |
245 | def integrate(self, color_im, depth_im, cam_intr, cam_pose, obs_weight=1.0):
246 | """Integrate an RGB-D frame into the TSDF volume.
247 |
248 | Args:
249 | color_im (ndarray): An RGB image of shape (H, W, 3).
250 | depth_im (ndarray): A depth image of shape (H, W).
251 | cam_intr (ndarray): The camera intrinsics matrix of shape (3, 3).
252 | cam_pose (ndarray): The camera pose (i.e. extrinsics) of shape (4, 4).
253 |             obs_weight (float): The weight to assign to the current observation. A higher
254 |                 value gives this observation more influence in the cumulative moving average.
255 | """
256 | im_h, im_w = depth_im.shape
257 |
258 | # Fold RGB color image into a single channel image
259 | color_im = color_im.astype(np.float32)
260 | color_im = np.floor(
261 | color_im[..., 2] * self._color_const
262 | + color_im[..., 1] * 256
263 | + color_im[..., 0]
264 | )
265 |
266 | if self.gpu_mode: # GPU mode: integrate voxel volume (calls CUDA kernel)
267 | for gpu_loop_idx in range(self._n_gpu_loops):
268 | self._cuda_integrate(
269 | self._tsdf_vol_gpu,
270 | self._weight_vol_gpu,
271 | self._color_vol_gpu,
272 | cuda.InOut(self._vol_dim.astype(np.float32)),
273 | cuda.InOut(self._vol_origin.astype(np.float32)),
274 | cuda.InOut(cam_intr.reshape(-1).astype(np.float32)),
275 | cuda.InOut(cam_pose.reshape(-1).astype(np.float32)),
276 | cuda.InOut(
277 | np.asarray(
278 | [
279 | gpu_loop_idx,
280 | self._voxel_size,
281 | im_h,
282 | im_w,
283 | self._trunc_margin,
284 | obs_weight,
285 | ],
286 | np.float32,
287 | )
288 | ),
289 | cuda.InOut(color_im.reshape(-1).astype(np.float32)),
290 | cuda.InOut(depth_im.reshape(-1).astype(np.float32)),
291 | block=(self._max_gpu_threads_per_block, 1, 1),
292 | grid=(
293 | int(self._max_gpu_grid_dim[0]),
294 | int(self._max_gpu_grid_dim[1]),
295 | int(self._max_gpu_grid_dim[2]),
296 | ),
297 | )
298 | else: # CPU mode: integrate voxel volume (vectorized implementation)
299 | # Convert voxel grid coordinates to pixel coordinates
300 | cam_pts = self.vox2world(
301 | self._vol_origin, self.vox_coords, self._voxel_size
302 | )
303 | cam_pts = rigid_transform(cam_pts, np.linalg.inv(cam_pose))
304 | pix_z = cam_pts[:, 2]
305 | pix = self.cam2pix(cam_pts, cam_intr)
306 | pix_x, pix_y = pix[:, 0], pix[:, 1]
307 |
308 | # Eliminate pixels outside view frustum
309 | valid_pix = np.logical_and(
310 | pix_x >= 0,
311 | np.logical_and(
312 | pix_x < im_w,
313 | np.logical_and(pix_y >= 0, np.logical_and(pix_y < im_h, pix_z > 0)),
314 | ),
315 | )
316 | depth_val = np.zeros(pix_x.shape)
317 | depth_val[valid_pix] = depth_im[pix_y[valid_pix], pix_x[valid_pix]]
318 |
319 | # Integrate TSDF
320 | depth_diff = depth_val - pix_z
321 |
322 | valid_pts = np.logical_and(depth_val > 0, depth_diff >= -10)
323 | dist = depth_diff
324 |
325 | valid_vox_x = self.vox_coords[valid_pts, 0]
326 | valid_vox_y = self.vox_coords[valid_pts, 1]
327 | valid_vox_z = self.vox_coords[valid_pts, 2]
328 | w_old = self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z]
329 | tsdf_vals = self._tsdf_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z]
330 | valid_dist = dist[valid_pts]
331 | tsdf_vol_new, w_new = self.integrate_tsdf(
332 | tsdf_vals, valid_dist, w_old, obs_weight
333 | )
334 | self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = w_new
335 | self._tsdf_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = tsdf_vol_new
336 |
337 | # Integrate color
338 | old_color = self._color_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z]
339 | old_b = np.floor(old_color / self._color_const)
340 | old_g = np.floor((old_color - old_b * self._color_const) / 256)
341 | old_r = old_color - old_b * self._color_const - old_g * 256
342 | new_color = color_im[pix_y[valid_pts], pix_x[valid_pts]]
343 | new_b = np.floor(new_color / self._color_const)
344 | new_g = np.floor((new_color - new_b * self._color_const) / 256)
345 | new_r = new_color - new_b * self._color_const - new_g * 256
346 | new_b = np.minimum(
347 | 255.0, np.round((w_old * old_b + obs_weight * new_b) / w_new)
348 | )
349 | new_g = np.minimum(
350 | 255.0, np.round((w_old * old_g + obs_weight * new_g) / w_new)
351 | )
352 | new_r = np.minimum(
353 | 255.0, np.round((w_old * old_r + obs_weight * new_r) / w_new)
354 | )
355 | self._color_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = (
356 | new_b * self._color_const + new_g * 256 + new_r
357 | )
358 |
359 | def get_volume(self):
360 | if self.gpu_mode:
361 | cuda.memcpy_dtoh(self._tsdf_vol_cpu, self._tsdf_vol_gpu)
362 | cuda.memcpy_dtoh(self._color_vol_cpu, self._color_vol_gpu)
363 | return self._tsdf_vol_cpu, self._color_vol_cpu
364 |
365 | def get_point_cloud(self):
366 | """Extract a point cloud from the voxel volume."""
367 | tsdf_vol, color_vol = self.get_volume()
368 |
369 | # Marching cubes
370 | verts = measure.marching_cubes_lewiner(tsdf_vol, level=0)[0]
371 | verts_ind = np.round(verts).astype(int)
372 | verts = verts * self._voxel_size + self._vol_origin
373 |
374 | # Get vertex colors
375 | rgb_vals = color_vol[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]]
376 | colors_b = np.floor(rgb_vals / self._color_const)
377 | colors_g = np.floor((rgb_vals - colors_b * self._color_const) / 256)
378 | colors_r = rgb_vals - colors_b * self._color_const - colors_g * 256
379 | colors = np.floor(np.asarray([colors_r, colors_g, colors_b])).T
380 | colors = colors.astype(np.uint8)
381 |
382 | pc = np.hstack([verts, colors])
383 | return pc
384 |
385 | def get_mesh(self):
386 | """Compute a mesh from the voxel volume using marching cubes."""
387 | tsdf_vol, color_vol = self.get_volume()
388 |
389 | # Marching cubes
390 | verts, faces, norms, vals = measure.marching_cubes_lewiner(tsdf_vol, level=0)
391 | verts_ind = np.round(verts).astype(int)
392 | verts = (
393 | verts * self._voxel_size + self._vol_origin
394 | ) # voxel grid coordinates to world coordinates
395 |
396 | # Get vertex colors
397 | rgb_vals = color_vol[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]]
398 | colors_b = np.floor(rgb_vals / self._color_const)
399 | colors_g = np.floor((rgb_vals - colors_b * self._color_const) / 256)
400 | colors_r = rgb_vals - colors_b * self._color_const - colors_g * 256
401 | colors = np.floor(np.asarray([colors_r, colors_g, colors_b])).T
402 | colors = colors.astype(np.uint8)
403 | return verts, faces, norms, colors
404 |
405 |
406 | def rigid_transform(xyz, transform):
407 | """Applies a rigid transform to an (N, 3) pointcloud."""
408 | xyz_h = np.hstack([xyz, np.ones((len(xyz), 1), dtype=np.float32)])
409 | xyz_t_h = np.dot(transform, xyz_h.T).T
410 | return xyz_t_h[:, :3]
411 |
412 |
413 | def get_view_frustum(depth_im, cam_intr, cam_pose):
414 | """Get corners of 3D camera view frustum of depth image"""
415 | im_h = depth_im.shape[0]
416 | im_w = depth_im.shape[1]
417 | max_depth = np.max(depth_im)
418 | view_frust_pts = np.array(
419 | [
420 | (np.array([0, 0, 0, im_w, im_w]) - cam_intr[0, 2])
421 | * np.array([0, max_depth, max_depth, max_depth, max_depth])
422 | / cam_intr[0, 0],
423 | (np.array([0, 0, im_h, 0, im_h]) - cam_intr[1, 2])
424 | * np.array([0, max_depth, max_depth, max_depth, max_depth])
425 | / cam_intr[1, 1],
426 | np.array([0, max_depth, max_depth, max_depth, max_depth]),
427 | ]
428 | )
429 | view_frust_pts = rigid_transform(view_frust_pts.T, cam_pose).T
430 | return view_frust_pts
431 |
432 |
433 | def meshwrite(filename, verts, faces, norms, colors):
434 | """Save a 3D mesh to a polygon .ply file."""
435 | # Write header
436 | ply_file = open(filename, "w")
437 | ply_file.write("ply\n")
438 | ply_file.write("format ascii 1.0\n")
439 | ply_file.write("element vertex %d\n" % (verts.shape[0]))
440 | ply_file.write("property float x\n")
441 | ply_file.write("property float y\n")
442 | ply_file.write("property float z\n")
443 | ply_file.write("property float nx\n")
444 | ply_file.write("property float ny\n")
445 | ply_file.write("property float nz\n")
446 | ply_file.write("property uchar red\n")
447 | ply_file.write("property uchar green\n")
448 | ply_file.write("property uchar blue\n")
449 | ply_file.write("element face %d\n" % (faces.shape[0]))
450 | ply_file.write("property list uchar int vertex_index\n")
451 | ply_file.write("end_header\n")
452 |
453 | # Write vertex list
454 | for i in range(verts.shape[0]):
455 | ply_file.write(
456 | "%f %f %f %f %f %f %d %d %d\n"
457 | % (
458 | verts[i, 0],
459 | verts[i, 1],
460 | verts[i, 2],
461 | norms[i, 0],
462 | norms[i, 1],
463 | norms[i, 2],
464 | colors[i, 0],
465 | colors[i, 1],
466 | colors[i, 2],
467 | )
468 | )
469 |
470 | # Write face list
471 | for i in range(faces.shape[0]):
472 | ply_file.write("3 %d %d %d\n" % (faces[i, 0], faces[i, 1], faces[i, 2]))
473 |
474 | ply_file.close()
475 |
476 |
477 | def pcwrite(filename, xyzrgb):
478 | """Save a point cloud to a polygon .ply file."""
479 | xyz = xyzrgb[:, :3]
480 | rgb = xyzrgb[:, 3:].astype(np.uint8)
481 |
482 | # Write header
483 | ply_file = open(filename, "w")
484 | ply_file.write("ply\n")
485 | ply_file.write("format ascii 1.0\n")
486 | ply_file.write("element vertex %d\n" % (xyz.shape[0]))
487 | ply_file.write("property float x\n")
488 | ply_file.write("property float y\n")
489 | ply_file.write("property float z\n")
490 | ply_file.write("property uchar red\n")
491 | ply_file.write("property uchar green\n")
492 | ply_file.write("property uchar blue\n")
493 | ply_file.write("end_header\n")
494 |
495 | # Write vertex list
496 | for i in range(xyz.shape[0]):
497 | ply_file.write(
498 | "%f %f %f %d %d %d\n"
499 | % (
500 | xyz[i, 0],
501 | xyz[i, 1],
502 | xyz[i, 2],
503 | rgb[i, 0],
504 | rgb[i, 1],
505 | rgb[i, 2],
506 | )
507 | )
508 |
--------------------------------------------------------------------------------
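As a quick orientation for fusion.py above, the sketch below runs a single CPU-mode TSDF integration on a synthetic RGB-D frame. It is not part of the repository; the frame, intrinsics, and voxel size are illustrative placeholders for what the NYU/OccScanNet preprocessing normally supplies.

```python
import numpy as np
import iso.data.utils.fusion as fusion

# Synthetic RGB-D frame and pinhole camera (illustrative values only).
depth_im = np.full((480, 640), 2.0, dtype=np.float32)   # flat surface 2 m away
color_im = np.zeros((480, 640, 3), dtype=np.uint8)
cam_intr = np.array([[518.8579, 0.0, 320.0],
                     [0.0, 519.4696, 240.0],
                     [0.0, 0.0, 1.0]])
cam_pose = np.eye(4)                                     # camera at the world origin

# Derive volume bounds from this frame's view frustum.
view_frust_pts = fusion.get_view_frustum(depth_im, cam_intr, cam_pose)
vol_bnds = np.zeros((3, 2))
vol_bnds[:, 0] = view_frust_pts.min(axis=1)
vol_bnds[:, 1] = view_frust_pts.max(axis=1)

tsdf = fusion.TSDFVolume(vol_bnds, voxel_size=0.04, use_gpu=False)  # CPU / numba path
tsdf.integrate(color_im, depth_im, cam_intr, cam_pose, obs_weight=1.0)

tsdf_vol, color_vol = tsdf.get_volume()
print(tsdf_vol.shape)  # grid dimensions derived from vol_bnds and voxel_size
# tsdf.get_mesh() / meshwrite(...) would export a .ply, but they rely on the older
# skimage marching_cubes_lewiner API referenced in the file.
```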
/iso/data/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import iso.data.utils.fusion as fusion
3 | import torch
4 |
5 |
6 | def compute_CP_mega_matrix(target, is_binary=False):
7 | """
8 | Parameters
9 | ---------
10 | target: (H, W, D)
11 |         contains the semantic label of each voxel
12 |
13 | is_binary: bool
14 |         if True, return binary voxel relations; otherwise return 4-way relations
15 | """
16 | label = target.reshape(-1)
17 | label_row = label
18 | N = label.shape[0]
19 | super_voxel_size = [i//2 for i in target.shape]
20 | if is_binary:
21 | matrix = np.zeros((2, N, super_voxel_size[0] * super_voxel_size[1] * super_voxel_size[2]), dtype=np.uint8)
22 | else:
23 | matrix = np.zeros((4, N, super_voxel_size[0] * super_voxel_size[1] * super_voxel_size[2]), dtype=np.uint8)
24 |
25 | for xx in range(super_voxel_size[0]):
26 | for yy in range(super_voxel_size[1]):
27 | for zz in range(super_voxel_size[2]):
28 | col_idx = xx * (super_voxel_size[1] * super_voxel_size[2]) + yy * super_voxel_size[2] + zz
29 | label_col_megas = np.array([
30 | target[xx * 2, yy * 2, zz * 2],
31 | target[xx * 2 + 1, yy * 2, zz * 2],
32 | target[xx * 2, yy * 2 + 1, zz * 2],
33 | target[xx * 2, yy * 2, zz * 2 + 1],
34 | target[xx * 2 + 1, yy * 2 + 1, zz * 2],
35 | target[xx * 2 + 1, yy * 2, zz * 2 + 1],
36 | target[xx * 2, yy * 2 + 1, zz * 2 + 1],
37 | target[xx * 2 + 1, yy * 2 + 1, zz * 2 + 1],
38 | ])
39 | label_col_megas = label_col_megas[label_col_megas != 255]
40 | for label_col_mega in label_col_megas:
41 | label_col = np.ones(N) * label_col_mega
42 | if not is_binary:
43 | matrix[0, (label_row != 255) & (label_col == label_row) & (label_col != 0), col_idx] = 1.0 # non non same
44 | matrix[1, (label_row != 255) & (label_col != label_row) & (label_col != 0) & (label_row != 0), col_idx] = 1.0 # non non diff
45 | matrix[2, (label_row != 255) & (label_row == label_col) & (label_col == 0), col_idx] = 1.0 # empty empty
46 | matrix[3, (label_row != 255) & (label_row != label_col) & ((label_row == 0) | (label_col == 0)), col_idx] = 1.0 # nonempty empty
47 | else:
48 | matrix[0, (label_row != 255) & (label_col != label_row), col_idx] = 1.0 # diff
49 | matrix[1, (label_row != 255) & (label_col == label_row), col_idx] = 1.0 # same
50 | return matrix
51 |
52 |
53 | def vox2pix(cam_E, cam_k,
54 | vox_origin, voxel_size,
55 | img_W, img_H,
56 | scene_size):
57 | """
58 |     compute the 2D projection of voxel centroids
59 |
60 | Parameters:
61 | ----------
62 | cam_E: 4x4
63 | =camera pose in case of NYUv2 dataset
64 |        =transformation from camera to lidar coordinates in case of SemKITTI
65 | cam_k: 3x3
66 | camera intrinsics
67 | vox_origin: (3,)
68 |         world(NYU)/lidar(SemKITTI) coordinates of the voxel at index (0, 0, 0)
69 | img_W: int
70 | image width
71 | img_H: int
72 | image height
73 | scene_size: (3,)
74 |         scene size in meters: (51.2, 51.2, 6.4) for SemKITTI and (4.8, 4.8, 2.88) for NYUv2
75 |
76 | Returns
77 | -------
78 | projected_pix: (N, 2)
79 | Projected 2D positions of voxels
80 | fov_mask: (N,)
81 |         Boolean mask indicating which voxels project inside the image's FOV
82 | pix_z: (N,)
83 |         Voxels' distance to the sensor in meters
84 | """
85 |     # Compute the x, y, z bounds of the scene in meters
86 | vol_bnds = np.zeros((3,2))
87 | vol_bnds[:,0] = vox_origin
88 | vol_bnds[:,1] = vox_origin + np.array(scene_size)
89 |
90 |
91 |     # Compute the voxel centroids in world/lidar coordinates
92 |     # TODO: make sure the rounding (np.around) below has no influence on the NYUv2 result.
93 | vol_dim = np.ceil(np.around((vol_bnds[:,1]- vol_bnds[:,0])/ voxel_size, 3)).copy(order='C').astype(int)
94 | if vol_dim[0] != 60 or vol_dim[1] != 60 or vol_dim[2] != 36:
95 |         print("Unexpected voxel grid dimensions:", vol_dim, '\n', vol_bnds, vol_bnds.dtype)
96 | exit(-1)
97 | xv, yv, zv = np.meshgrid(
98 | range(vol_dim[0]),
99 | range(vol_dim[1]),
100 | range(vol_dim[2]),
101 | indexing='ij'
102 | )
103 | vox_coords = np.concatenate([
104 | xv.reshape(1,-1),
105 | yv.reshape(1,-1),
106 | zv.reshape(1,-1)
107 | ], axis=0).astype(int).T
108 |
109 |     # Project voxel centroids from world/lidar coordinates to camera coordinates
110 | cam_pts = fusion.TSDFVolume.vox2world(vox_origin, vox_coords, voxel_size)
111 | cam_pts = fusion.rigid_transform(cam_pts, cam_E)
112 |
113 | # Project camera coordinates to pixel positions
114 | projected_pix = fusion.TSDFVolume.cam2pix(cam_pts, cam_k)
115 | pix_x, pix_y = projected_pix[:, 0], projected_pix[:, 1]
116 |
117 | # Eliminate pixels outside view frustum
118 | pix_z = cam_pts[:, 2]
119 | fov_mask = np.logical_and(pix_x >= 0,
120 | np.logical_and(pix_x < img_W,
121 | np.logical_and(pix_y >= 0,
122 | np.logical_and(pix_y < img_H,
123 | pix_z > 0))))
124 |
125 |
126 | return projected_pix, fov_mask, pix_z
127 |
128 |
129 | def compute_local_frustum(pix_x, pix_y, min_x, max_x, min_y, max_y, pix_z):
130 | valid_pix = np.logical_and(pix_x >= min_x,
131 | np.logical_and(pix_x < max_x,
132 | np.logical_and(pix_y >= min_y,
133 | np.logical_and(pix_y < max_y,
134 | pix_z > 0))))
135 | return valid_pix
136 |
137 | def compute_local_frustums(projected_pix, pix_z, target, img_W, img_H, dataset, n_classes, size=4):
138 | """
139 | Compute the local frustums mask and their class frequencies
140 |
141 | Parameters:
142 | ----------
143 | projected_pix: (N, 2)
144 | 2D projected pix of all voxels
145 | pix_z: (N,)
146 | Distance of the camera sensor to voxels
147 | target: (H, W, D)
148 |         Voxelized semantic labels
149 | img_W: int
150 | Image width
151 | img_H: int
152 | Image height
153 | dataset: str
154 | ="NYU" or "kitti" (for both SemKITTI and KITTI-360)
155 | n_classes: int
156 | Number of classes (12 for NYU and 20 for SemKITTI)
157 | size: int
158 |         determines the number of local frustums, i.e. size * size
159 |
160 | Returns
161 | -------
162 | frustums_masks: (n_frustums, N)
163 |         List of frustum masks, each indicating which voxels belong to that frustum
164 | frustums_class_dists: (n_frustums, n_classes)
165 | Contains the class frequencies in each frustum
166 | """
167 | H, W, D = target.shape
168 | ranges = [(i * 1.0/size, (i * 1.0 + 1)/size) for i in range(size)]
169 | local_frustum_masks = []
170 | local_frustum_class_dists = []
171 | pix_x, pix_y = projected_pix[:, 0], projected_pix[:, 1]
172 | for y in ranges:
173 | for x in ranges:
174 | start_x = x[0] * img_W
175 | end_x = x[1] * img_W
176 | start_y = y[0] * img_H
177 | end_y = y[1] * img_H
178 | local_frustum = compute_local_frustum(pix_x, pix_y, start_x, end_x, start_y, end_y, pix_z)
179 | if dataset == "NYU":
180 | mask = (target != 255) & np.moveaxis(local_frustum.reshape(60, 60, 36), [0, 1, 2], [0, 2, 1])
181 | elif dataset == "OccScanNet" or dataset == "OccScanNet_mini":
182 | mask = (target != 255) & local_frustum.reshape(H, W, D)
183 |
184 | local_frustum_masks.append(mask)
185 | classes, cnts = np.unique(target[mask], return_counts=True)
186 | class_counts = np.zeros(n_classes)
187 | class_counts[classes.astype(int)] = cnts
188 | local_frustum_class_dists.append(class_counts)
189 | frustums_masks, frustums_class_dists = np.array(local_frustum_masks), np.array(local_frustum_class_dists)
190 | return frustums_masks, frustums_class_dists
191 |
--------------------------------------------------------------------------------
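A minimal sketch of how the helpers above fit together for an NYU-style scene, assuming an identity camera pose, typical NYU intrinsics, and the expected 60 x 60 x 36 grid (voxel_size 0.08 m, scene_size (4.8, 4.8, 2.88)); the dataset loaders supply the real values.

```python
import numpy as np
from iso.data.utils.helpers import (
    vox2pix, compute_local_frustums, compute_CP_mega_matrix
)

cam_E = np.eye(4)                                   # camera pose (identity for illustration)
cam_k = np.array([[518.8579, 0.0, 320.0],
                  [0.0, 519.4696, 240.0],
                  [0.0, 0.0, 1.0]])
vox_origin = np.array([0.0, 0.0, 0.0])
voxel_size, scene_size = 0.08, (4.8, 4.8, 2.88)     # yields the expected 60 x 60 x 36 grid

projected_pix, fov_mask, pix_z = vox2pix(
    cam_E, cam_k, vox_origin, voxel_size, img_W=640, img_H=480, scene_size=scene_size
)
print(projected_pix.shape, int(fov_mask.sum()))     # (129600, 2), #voxels inside the FOV

# Per-frustum masks and class histograms over a dummy (X, Z, Y) = (60, 36, 60) label volume.
target = np.zeros((60, 36, 60), dtype=np.uint8)
frustums_masks, frustums_class_dists = compute_local_frustums(
    projected_pix, pix_z, target, img_W=640, img_H=480,
    dataset="NYU", n_classes=12, size=4
)
print(frustums_masks.shape, frustums_class_dists.shape)   # (16, 60, 36, 60), (16, 12)

# 4-way super-voxel relation matrix, shown on a tiny synthetic volume to keep it small.
small_target = np.random.randint(0, 12, size=(4, 4, 4)).astype(np.uint8)
print(compute_CP_mega_matrix(small_target).shape)          # (4, 64, 8)
```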
/iso/data/utils/torch_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def worker_init_fn(worker_id):
6 |     """This function is designed for the PyTorch multi-process DataLoader.
7 |     It uses the PyTorch random generator to produce a base_seed, so that each
8 |     worker seeds NumPy with a distinct value derived from torch's RNG state.
9 |
10 | References:
11 | https://pytorch.org/docs/stable/notes/faq.html#dataloader-workers-random-seed
12 |
13 | """
14 | base_seed = torch.IntTensor(1).random_().item()
15 | np.random.seed(base_seed + worker_id)
16 |
--------------------------------------------------------------------------------
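Usage is a one-liner: pass worker_init_fn to a DataLoader so each worker process seeds NumPy differently. A toy sketch (dataset and sizes made up):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from iso.data.utils.torch_util import worker_init_fn

if __name__ == "__main__":
    ds = TensorDataset(torch.arange(16).float())
    loader = DataLoader(ds, batch_size=4, num_workers=2, worker_init_fn=worker_init_fn)
    for batch in loader:   # each worker gets base_seed + worker_id as its NumPy seed
        pass
```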
/iso/loss/CRP_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def compute_super_CP_multilabel_loss(pred_logits, CP_mega_matrices):
5 | logits = []
6 | labels = []
7 | bs, n_relations, _, _ = pred_logits.shape
8 | for i in range(bs):
9 | pred_logit = pred_logits[i, :, :, :].permute(
10 | 0, 2, 1
11 | ) # n_relations, N, n_mega_voxels
12 | CP_mega_matrix = CP_mega_matrices[i] # n_relations, N, n_mega_voxels
13 | logits.append(pred_logit.reshape(n_relations, -1))
14 | labels.append(CP_mega_matrix.reshape(n_relations, -1))
15 |
16 | logits = torch.cat(logits, dim=1).T # M, 4
17 | labels = torch.cat(labels, dim=1).T # M, 4
18 |
19 | cnt_neg = (labels == 0).sum(0)
20 | cnt_pos = labels.sum(0)
21 | pos_weight = cnt_neg / (cnt_pos + 1e-5)
22 | criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
23 | loss_bce = criterion(logits, labels.float())
24 | return loss_bce
25 |
--------------------------------------------------------------------------------
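A shape-only sketch of the relation loss: pred_logits follows the (bs, n_relations, n_mega_voxels, N) layout produced by CPMegaVoxels, and each CP_mega_matrix is the (n_relations, N, n_mega_voxels) supervision built by compute_CP_mega_matrix; the sizes here are made up.

```python
import torch
from iso.loss.CRP_loss import compute_super_CP_multilabel_loss

bs, n_relations = 2, 4
N, n_mega = 8 ** 3, 4 ** 3                       # voxels and 2x-downsampled mega voxels

pred_logits = torch.randn(bs, n_relations, n_mega, N)
CP_mega_matrices = [
    torch.randint(0, 2, (n_relations, N, n_mega)) for _ in range(bs)
]

loss = compute_super_CP_multilabel_loss(pred_logits, CP_mega_matrices)
print(loss.item())
```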
/iso/loss/depth_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.cuda.amp import autocast
3 | import torch.nn.functional as F
4 |
5 | from iso.models.modules import bin_depths
6 |
7 |
8 | class DepthClsLoss:
9 | def __init__(self, downsample_factor, d_bound, depth_channels=64):
10 | self.downsample_factor = downsample_factor
11 | self.depth_channels = depth_channels
12 | self.d_bound = d_bound
13 | # self.depth_channels = int((self.d_bound[1] - self.d_bound[0]) / self.d_bound[2])
14 |
15 | def _get_downsampled_gt_depth(self, gt_depths):
16 | """
17 | Input:
18 | gt_depths: [B, N, H, W]
19 | Output:
20 | gt_depths: [B*N*h*w, d]
21 | """
22 | B, N, H, W = gt_depths.shape
23 | gt_depths = gt_depths.view(
24 | B * N,
25 | H // self.downsample_factor,
26 | self.downsample_factor,
27 | W // self.downsample_factor,
28 | self.downsample_factor,
29 | 1,
30 | )
31 | gt_depths = gt_depths.permute(0, 1, 3, 5, 2, 4).contiguous()
32 | gt_depths = gt_depths.view(-1, self.downsample_factor * self.downsample_factor)
33 |
34 | gt_depths_tmp = torch.where(
35 | gt_depths == 0.0, 1e5 * torch.ones_like(gt_depths), gt_depths
36 | )
37 | gt_depths = torch.min(gt_depths_tmp, dim=-1).values
38 | gt_depths = gt_depths.view(
39 | B * N, H // self.downsample_factor, W // self.downsample_factor
40 | )
41 |
42 | # ds = torch.arange(64) # (64,)
43 | # depth_bin_pos = (10.0 / 64 / 65 * ds * (ds + 1)).reshape(1, 64, 1, 1)
44 | # # print(gt_depths.unsqueeze(1).shape)
45 | # # print(depth_bin_pos.shape)
46 | # delta_z = torch.abs(gt_depths.unsqueeze(1) - depth_bin_pos.to(gt_depths.device))
47 | # gt_depths = torch.argmin(delta_z, dim=1)
48 | gt_depths = bin_depths(gt_depths, mode='LID', depth_min=0, depth_max=10, num_bins=self.depth_channels, target=True)
49 |
50 | gt_depths = torch.where(
51 | (gt_depths < self.depth_channels + 1) & (gt_depths >= 0.0),
52 | gt_depths,
53 | torch.zeros_like(gt_depths),
54 | )
55 | gt_depths = F.one_hot(
56 | gt_depths.long(), num_classes=self.depth_channels + 1
57 | ).view(-1, self.depth_channels + 1)[:, :-1]
58 |
59 | return gt_depths.float()
60 |
61 | def get_depth_loss(self, depth_labels, depth_preds):
62 | if len(depth_labels.shape) != 4:
63 | depth_labels = depth_labels.unsqueeze(1)
64 | # print(depth_labels.shape)
65 | N_pred, n_cam_pred, D, H, W = depth_preds.shape
66 | N_gt, n_cam_label, oriH, oriW = depth_labels.shape
67 | assert (
68 | N_pred * n_cam_pred == N_gt * n_cam_label
69 | ), f"N_pred: {N_pred}, n_cam_pred: {n_cam_pred}, N_gt: {N_gt}, n_cam_label: {n_cam_label}"
70 | depth_labels = depth_labels.reshape(N_gt * n_cam_label, oriH, oriW)
71 | depth_preds = depth_preds.reshape(N_pred * n_cam_pred, D, H, W)
72 |
73 | # depth_labels = depth_labels.reshape(
74 | # N
75 | # )
76 |
77 | # depth_labels = depth_labels.unsqueeze(1)
78 | # depth_labels = depth_labels
79 | depth_labels = F.interpolate(
80 | depth_labels.unsqueeze(1),
81 | (H * self.downsample_factor, W * self.downsample_factor),
82 | mode="nearest",
83 | )
84 | depth_labels = self._get_downsampled_gt_depth(depth_labels)
85 | depth_preds = (
86 | depth_preds.permute(0, 2, 3, 1).contiguous().view(-1, self.depth_channels)
87 | )
88 | fg_mask = torch.max(depth_labels, dim=1).values > 0.0
89 |
90 | with autocast(enabled=False):
91 | depth_loss = F.binary_cross_entropy(
92 | depth_preds[fg_mask],
93 | depth_labels[fg_mask],
94 | reduction="none",
95 | ).sum() / max(1.0, fg_mask.sum())
96 | # depth_loss = torch.nan_to_num(depth_loss, nan=0.)
97 |
98 | return depth_loss
99 |
--------------------------------------------------------------------------------
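A minimal sketch of DepthClsLoss, assuming 64 LID bins over [0, 10] m, a 1/8-resolution depth distribution, and random ground-truth depths (d_bound is only stored; the bin count is passed explicitly). The real inputs come from the depth branch and the dataset.

```python
import torch
from iso.loss.depth_loss import DepthClsLoss

B, n_cam, D, H, W = 1, 1, 64, 60, 80             # depth bins at 1/8 of a 480x640 image
downsample = 8

depth_preds = torch.softmax(torch.randn(B, n_cam, D, H, W), dim=2)          # per-pixel bin distribution
depth_labels = torch.rand(B, n_cam, H * downsample, W * downsample) * 10.0  # metric depth in [0, 10)

loss_fn = DepthClsLoss(downsample_factor=downsample, d_bound=[0.0, 10.0, 0.15625], depth_channels=D)
print(loss_fn.get_depth_loss(depth_labels, depth_preds).item())
```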
/iso/loss/sscMetrics.py:
--------------------------------------------------------------------------------
1 | """
2 | Part of the code is taken from https://github.com/waterljwant/SSC/blob/master/sscMetrics.py
3 | """
4 | import numpy as np
5 | from sklearn.metrics import accuracy_score, precision_recall_fscore_support
6 |
7 |
8 | def get_iou(iou_sum, cnt_class):
9 | _C = iou_sum.shape[0] # 12
10 | iou = np.zeros(_C, dtype=np.float32) # iou for each class
11 | for idx in range(_C):
12 | iou[idx] = iou_sum[idx] / cnt_class[idx] if cnt_class[idx] else 0
13 |
14 | mean_iou = np.sum(iou[1:]) / np.count_nonzero(cnt_class[1:])
15 | return iou, mean_iou
16 |
17 |
18 | def get_accuracy(predict, target, weight=None): # 0.05s
19 | _bs = predict.shape[0] # batch size
20 | _C = predict.shape[1] # _C = 12
21 | target = np.int32(target)
22 | target = target.reshape(_bs, -1) # (_bs, 60*36*60) 129600
23 | predict = predict.reshape(_bs, _C, -1) # (_bs, _C, 60*36*60)
24 | predict = np.argmax(
25 | predict, axis=1
26 | ) # one-hot: _bs x _C x 60*36*60 --> label: _bs x 60*36*60.
27 |
28 | correct = predict == target # (_bs, 129600)
29 | if weight: # 0.04s, add class weights
30 | weight_k = np.ones(target.shape)
31 | for i in range(_bs):
32 | for n in range(target.shape[1]):
33 | idx = 0 if target[i, n] == 255 else target[i, n]
34 | weight_k[i, n] = weight[idx]
35 | correct = correct * weight_k
36 | acc = correct.sum() / correct.size
37 | return acc
38 |
39 |
40 | class SSCMetrics:
41 | def __init__(self, n_classes):
42 | self.n_classes = n_classes
43 | self.reset()
44 |
45 | def hist_info(self, n_cl, pred, gt):
46 | assert pred.shape == gt.shape
47 | k = (gt >= 0) & (gt < n_cl) # exclude 255
48 | labeled = np.sum(k)
49 | correct = np.sum((pred[k] == gt[k]))
50 |
51 | return (
52 | np.bincount(
53 | n_cl * gt[k].astype(int) + pred[k].astype(int), minlength=n_cl ** 2
54 | ).reshape(n_cl, n_cl),
55 | correct,
56 | labeled,
57 | )
58 |
59 | @staticmethod
60 | def compute_score(hist, correct, labeled):
61 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
62 | mean_IU = np.nanmean(iu)
63 | mean_IU_no_back = np.nanmean(iu[1:])
64 | freq = hist.sum(1) / hist.sum()
65 | freq_IU = (iu[freq > 0] * freq[freq > 0]).sum()
66 | mean_pixel_acc = correct / labeled if labeled != 0 else 0
67 |
68 | return iu, mean_IU, mean_IU_no_back, mean_pixel_acc
69 |
70 | def add_batch(self, y_pred, y_true, nonempty=None, nonsurface=None):
71 | self.count += 1
72 | mask = y_true != 255
73 | if nonempty is not None:
74 | mask = mask & nonempty
75 | if nonsurface is not None:
76 | mask = mask & nonsurface
77 | tp, fp, fn = self.get_score_completion(y_pred, y_true, mask)
78 |
79 | self.completion_tp += tp
80 | self.completion_fp += fp
81 | self.completion_fn += fn
82 |
83 | mask = y_true != 255
84 | if nonempty is not None:
85 | mask = mask & nonempty
86 | tp_sum, fp_sum, fn_sum = self.get_score_semantic_and_completion(
87 | y_pred, y_true, mask
88 | )
89 | self.tps += tp_sum
90 | self.fps += fp_sum
91 | self.fns += fn_sum
92 |
93 | def get_stats(self):
94 | if self.completion_tp != 0:
95 | precision = self.completion_tp / (self.completion_tp + self.completion_fp)
96 | recall = self.completion_tp / (self.completion_tp + self.completion_fn)
97 | iou = self.completion_tp / (
98 | self.completion_tp + self.completion_fp + self.completion_fn
99 | )
100 | else:
101 | precision, recall, iou = 0, 0, 0
102 | iou_ssc = self.tps / (self.tps + self.fps + self.fns + 1e-5)
103 | return {
104 | "precision": precision,
105 | "recall": recall,
106 | "iou": iou,
107 | "iou_ssc": iou_ssc,
108 | "iou_ssc_mean": np.mean(iou_ssc[1:]),
109 | }
110 |
111 | def reset(self):
112 |
113 | self.completion_tp = 0
114 | self.completion_fp = 0
115 | self.completion_fn = 0
116 | self.tps = np.zeros(self.n_classes)
117 | self.fps = np.zeros(self.n_classes)
118 | self.fns = np.zeros(self.n_classes)
119 |
120 | self.hist_ssc = np.zeros((self.n_classes, self.n_classes))
121 | self.labeled_ssc = 0
122 | self.correct_ssc = 0
123 |
124 | self.precision = 0
125 | self.recall = 0
126 | self.iou = 0
127 | self.count = 1e-8
128 | self.iou_ssc = np.zeros(self.n_classes, dtype=np.float32)
129 | self.cnt_class = np.zeros(self.n_classes, dtype=np.float32)
130 |
131 | def get_score_completion(self, predict, target, nonempty=None):
132 | predict = np.copy(predict)
133 | target = np.copy(target)
134 |
135 |         # For scene completion, treat the task as a binary problem: empty vs. occupied.
136 | _bs = predict.shape[0] # batch size
137 | # ---- ignore
138 | predict[target == 255] = 0
139 | target[target == 255] = 0
140 | # ---- flatten
141 | target = target.reshape(_bs, -1) # (_bs, 129600)
142 | predict = predict.reshape(_bs, -1) # (_bs, _C, 129600), 60*36*60=129600
143 | # ---- treat all non-empty object class as one category, set them to label 1
144 | b_pred = np.zeros(predict.shape)
145 | b_true = np.zeros(target.shape)
146 | b_pred[predict > 0] = 1
147 | b_true[target > 0] = 1
148 | p, r, iou = 0.0, 0.0, 0.0
149 | tp_sum, fp_sum, fn_sum = 0, 0, 0
150 | for idx in range(_bs):
151 | y_true = b_true[idx, :] # GT
152 | y_pred = b_pred[idx, :]
153 | if nonempty is not None:
154 | nonempty_idx = nonempty[idx, :].reshape(-1)
155 | y_true = y_true[nonempty_idx == 1]
156 | y_pred = y_pred[nonempty_idx == 1]
157 |
158 | tp = np.array(np.where(np.logical_and(y_true == 1, y_pred == 1))).size
159 | fp = np.array(np.where(np.logical_and(y_true != 1, y_pred == 1))).size
160 | fn = np.array(np.where(np.logical_and(y_true == 1, y_pred != 1))).size
161 | tp_sum += tp
162 | fp_sum += fp
163 | fn_sum += fn
164 | return tp_sum, fp_sum, fn_sum
165 |
166 | def get_score_semantic_and_completion(self, predict, target, nonempty=None):
167 | target = np.copy(target)
168 | predict = np.copy(predict)
169 | _bs = predict.shape[0] # batch size
170 | _C = self.n_classes # _C = 12
171 | # ---- ignore
172 | predict[target == 255] = 0
173 | target[target == 255] = 0
174 | # ---- flatten
175 | target = target.reshape(_bs, -1) # (_bs, 129600)
176 | predict = predict.reshape(_bs, -1) # (_bs, 129600), 60*36*60=129600
177 |
178 | cnt_class = np.zeros(_C, dtype=np.int32) # count for each class
179 | iou_sum = np.zeros(_C, dtype=np.float32) # sum of iou for each class
180 | tp_sum = np.zeros(_C, dtype=np.int32) # tp
181 | fp_sum = np.zeros(_C, dtype=np.int32) # fp
182 | fn_sum = np.zeros(_C, dtype=np.int32) # fn
183 |
184 | for idx in range(_bs):
185 | y_true = target[idx, :] # GT
186 | y_pred = predict[idx, :]
187 | if nonempty is not None:
188 | nonempty_idx = nonempty[idx, :].reshape(-1)
189 | y_pred = y_pred[
190 | np.where(np.logical_and(nonempty_idx == 1, y_true != 255))
191 | ]
192 | y_true = y_true[
193 | np.where(np.logical_and(nonempty_idx == 1, y_true != 255))
194 | ]
195 | for j in range(_C): # for each class
196 | tp = np.array(np.where(np.logical_and(y_true == j, y_pred == j))).size
197 | fp = np.array(np.where(np.logical_and(y_true != j, y_pred == j))).size
198 | fn = np.array(np.where(np.logical_and(y_true == j, y_pred != j))).size
199 |
200 | tp_sum[j] += tp
201 | fp_sum[j] += fp
202 | fn_sum[j] += fn
203 |
204 | return tp_sum, fp_sum, fn_sum
205 |
--------------------------------------------------------------------------------
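A small sketch of the metrics accumulator with random (B, 60, 36, 60) label volumes, where 255 marks ignored voxels; real usage feeds argmax predictions batch by batch and reads the stats at epoch end.

```python
import numpy as np
from iso.loss.sscMetrics import SSCMetrics

n_classes = 12
metric = SSCMetrics(n_classes)

y_true = np.random.randint(0, n_classes, size=(1, 60, 36, 60))
y_true[:, :, :4, :] = 255                        # unknown voxels are ignored
y_pred = np.random.randint(0, n_classes, size=(1, 60, 36, 60))

metric.add_batch(y_pred, y_true)
stats = metric.get_stats()
print(stats["precision"], stats["recall"], stats["iou"], stats["iou_ssc_mean"])
metric.reset()
```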
/iso/loss/ssc_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def KL_sep(p, target):
7 | """
8 | KL divergence on nonzeros classes
9 | """
10 | nonzeros = target != 0
11 | nonzero_p = p[nonzeros]
12 | kl_term = F.kl_div(torch.log(nonzero_p), target[nonzeros], reduction="sum")
13 | return kl_term
14 |
15 |
16 | def geo_scal_loss(pred, ssc_target):
17 |
18 | # Get softmax probabilities
19 | pred = F.softmax(pred, dim=1)
20 |
21 | # Compute empty and nonempty probabilities
22 | empty_probs = pred[:, 0, :, :, :]
23 | nonempty_probs = 1 - empty_probs
24 |
25 | # Remove unknown voxels
26 | mask = ssc_target != 255
27 | nonempty_target = ssc_target != 0
28 | nonempty_target = nonempty_target[mask].float()
29 | nonempty_probs = nonempty_probs[mask]
30 | empty_probs = empty_probs[mask]
31 |
32 | intersection = (nonempty_target * nonempty_probs).sum()
33 | precision = intersection / nonempty_probs.sum()
34 | recall = intersection / nonempty_target.sum()
35 | spec = ((1 - nonempty_target) * (empty_probs)).sum() / (1 - nonempty_target).sum()
36 | return (
37 | F.binary_cross_entropy(precision, torch.ones_like(precision))
38 | + F.binary_cross_entropy(recall, torch.ones_like(recall))
39 | + F.binary_cross_entropy(spec, torch.ones_like(spec))
40 | )
41 |
42 |
43 | def sem_scal_loss(pred, ssc_target):
44 | # Get softmax probabilities
45 | pred = F.softmax(pred, dim=1)
46 | loss = 0
47 | count = 0
48 | mask = ssc_target != 255
49 | n_classes = pred.shape[1]
50 | for i in range(0, n_classes):
51 |
52 | # Get probability of class i
53 | p = pred[:, i, :, :, :]
54 |
55 | # Remove unknown voxels
56 | target_ori = ssc_target
57 | p = p[mask]
58 | target = ssc_target[mask]
59 |
60 | completion_target = torch.ones_like(target)
61 | completion_target[target != i] = 0
62 | completion_target_ori = torch.ones_like(target_ori).float()
63 | completion_target_ori[target_ori != i] = 0
64 | if torch.sum(completion_target) > 0:
65 | count += 1.0
66 | nominator = torch.sum(p * completion_target)
67 | loss_class = 0
68 | if torch.sum(p) > 0:
69 | precision = nominator / (torch.sum(p))
70 | loss_precision = F.binary_cross_entropy(
71 | precision, torch.ones_like(precision)
72 | )
73 | loss_class += loss_precision
74 | if torch.sum(completion_target) > 0:
75 | recall = nominator / (torch.sum(completion_target))
76 | loss_recall = F.binary_cross_entropy(recall, torch.ones_like(recall))
77 | loss_class += loss_recall
78 | if torch.sum(1 - completion_target) > 0:
79 | specificity = torch.sum((1 - p) * (1 - completion_target)) / (
80 | torch.sum(1 - completion_target)
81 | )
82 | loss_specificity = F.binary_cross_entropy(
83 | specificity, torch.ones_like(specificity)
84 | )
85 | loss_class += loss_specificity
86 | loss += loss_class
87 | return loss / count
88 |
89 |
90 | def CE_ssc_loss(pred, target, class_weights):
91 | """
92 | :param: prediction: the predicted tensor, must be [BS, C, H, W, D]
93 | """
94 | criterion = nn.CrossEntropyLoss(
95 | weight=class_weights, ignore_index=255, reduction="mean"
96 | )
97 | loss = criterion(pred, target.long())
98 |
99 | return loss
100 |
--------------------------------------------------------------------------------
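The three SSC losses share the same (B, C, X, Y, Z) logits / (B, X, Y, Z) target interface; a sketch with random tensors (255 marks ignored voxels, class 0 is "empty"):

```python
import torch
from iso.loss.ssc_loss import CE_ssc_loss, sem_scal_loss, geo_scal_loss

B, C = 1, 12
pred = torch.randn(B, C, 60, 36, 60)             # class logits
target = torch.randint(0, C, (B, 60, 36, 60))
target[:, :, :4, :] = 255                        # unknown voxels
class_weights = torch.ones(C)

loss = (
    CE_ssc_loss(pred, target, class_weights)
    + sem_scal_loss(pred, target)
    + geo_scal_loss(pred, target)
)
print(loss.item())
```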
/iso/models/CRP3D.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from iso.models.modules import (
4 | Process,
5 | ASPP,
6 | )
7 |
8 |
9 | class CPMegaVoxels(nn.Module):
10 | def __init__(self, feature, size, n_relations=4, bn_momentum=0.0003):
11 | super().__init__()
12 | self.size = size
13 | self.n_relations = n_relations
14 | print("n_relations", self.n_relations)
15 | self.flatten_size = size[0] * size[1] * size[2]
16 | self.feature = feature
17 | self.context_feature = feature * 2
18 | self.flatten_context_size = (size[0] // 2) * (size[1] // 2) * (size[2] // 2)
19 | padding = ((size[0] + 1) % 2, (size[1] + 1) % 2, (size[2] + 1) % 2)
20 |
21 | self.mega_context = nn.Sequential(
22 | nn.Conv3d(
23 | feature, self.context_feature, stride=2, padding=padding, kernel_size=3
24 | ),
25 | )
26 | self.flatten_context_size = (size[0] // 2) * (size[1] // 2) * (size[2] // 2)
27 |
28 | self.context_prior_logits = nn.ModuleList(
29 | [
30 | nn.Sequential(
31 | nn.Conv3d(
32 | self.feature,
33 | self.flatten_context_size,
34 | padding=0,
35 | kernel_size=1,
36 | ),
37 | )
38 | for i in range(n_relations)
39 | ]
40 | )
41 | self.aspp = ASPP(feature, [1, 2, 3])
42 |
43 | self.resize = nn.Sequential(
44 | nn.Conv3d(
45 | self.context_feature * self.n_relations + feature,
46 | feature,
47 | kernel_size=1,
48 | padding=0,
49 | bias=False,
50 | ),
51 | Process(feature, nn.BatchNorm3d, bn_momentum, dilations=[1]),
52 | )
53 |
54 | def forward(self, input):
55 | ret = {}
56 | bs = input.shape[0]
57 |
58 | x_agg = self.aspp(input)
59 |
60 | # get the mega context
61 | x_mega_context_raw = self.mega_context(x_agg)
62 | x_mega_context = x_mega_context_raw.reshape(bs, self.context_feature, -1)
63 | x_mega_context = x_mega_context.permute(0, 2, 1)
64 |
65 | # get context prior map
66 | x_context_prior_logits = []
67 | x_context_rels = []
68 | for rel in range(self.n_relations):
69 |
70 | # Compute the relation matrices
71 | x_context_prior_logit = self.context_prior_logits[rel](x_agg)
72 | x_context_prior_logit = x_context_prior_logit.reshape(
73 | bs, self.flatten_context_size, self.flatten_size
74 | )
75 | x_context_prior_logits.append(x_context_prior_logit.unsqueeze(1))
76 |
77 | x_context_prior_logit = x_context_prior_logit.permute(0, 2, 1)
78 | x_context_prior = torch.sigmoid(x_context_prior_logit)
79 |
80 | # Multiply the relation matrices with the mega context to gather context features
81 | x_context_rel = torch.bmm(x_context_prior, x_mega_context) # bs, N, f
82 | x_context_rels.append(x_context_rel)
83 |
84 | x_context = torch.cat(x_context_rels, dim=2)
85 | x_context = x_context.permute(0, 2, 1)
86 | x_context = x_context.reshape(
87 | bs, x_context.shape[1], self.size[0], self.size[1], self.size[2]
88 | )
89 |
90 | x = torch.cat([input, x_context], dim=1)
91 | x = self.resize(x)
92 |
93 | x_context_prior_logits = torch.cat(x_context_prior_logits, dim=1)
94 | ret["P_logits"] = x_context_prior_logits
95 | ret["x"] = x
96 |
97 | return ret
98 |
--------------------------------------------------------------------------------
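A shape sketch of CPMegaVoxels on a 15 x 9 x 15 feature volume (the 1/4-scale NYU grid) with a reduced channel count; the real model uses the feature width configured in the 3D UNet.

```python
import torch
from iso.models.CRP3D import CPMegaVoxels

feature, size = 32, (15, 9, 15)                  # illustrative channels, 1/4-scale NYU grid
cp = CPMegaVoxels(feature, size, n_relations=4)

x = torch.randn(1, feature, *size)
out = cp(x)
print(out["x"].shape)         # (1, 32, 15, 9, 15): features enriched with mega-voxel context
print(out["P_logits"].shape)  # (1, 4, 196, 2025): relation logits, n_mega_voxels x N
```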
/iso/models/DDR.py:
--------------------------------------------------------------------------------
1 | """
2 | Most of the code in this file is taken from https://github.com/waterljwant/SSC/blob/master/models/DDR.py
3 | """
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class SimpleRB(nn.Module):
11 | def __init__(self, in_channel, norm_layer, bn_momentum):
12 | super(SimpleRB, self).__init__()
13 | self.path = nn.Sequential(
14 | nn.Conv3d(in_channel, in_channel, kernel_size=3, padding=1, bias=False),
15 | norm_layer(in_channel, momentum=bn_momentum),
16 | nn.ReLU(),
17 | nn.Conv3d(in_channel, in_channel, kernel_size=3, padding=1, bias=False),
18 | norm_layer(in_channel, momentum=bn_momentum),
19 | )
20 | self.relu = nn.ReLU()
21 |
22 | def forward(self, x):
23 | residual = x
24 | conv_path = self.path(x)
25 | out = residual + conv_path
26 | out = self.relu(out)
27 | return out
28 |
29 |
30 | """
31 | 3D residual block: one 3x3x3 conv factorized into three smaller 3D convs, adapted from DDRNet
32 | """
33 |
34 |
35 | class Bottleneck3D(nn.Module):
36 | def __init__(
37 | self,
38 | inplanes,
39 | planes,
40 | norm_layer,
41 | stride=1,
42 | dilation=[1, 1, 1],
43 | expansion=4,
44 | downsample=None,
45 | fist_dilation=1,
46 | multi_grid=1,
47 | bn_momentum=0.0003,
48 | ):
49 | super(Bottleneck3D, self).__init__()
50 |         # often, planes = inplanes // 4
51 | self.expansion = expansion
52 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
53 | self.bn1 = norm_layer(planes, momentum=bn_momentum)
54 | self.conv2 = nn.Conv3d(
55 | planes,
56 | planes,
57 | kernel_size=(1, 1, 3),
58 | stride=(1, 1, stride),
59 | dilation=(1, 1, dilation[0]),
60 | padding=(0, 0, dilation[0]),
61 | bias=False,
62 | )
63 | self.bn2 = norm_layer(planes, momentum=bn_momentum)
64 | self.conv3 = nn.Conv3d(
65 | planes,
66 | planes,
67 | kernel_size=(1, 3, 1),
68 | stride=(1, stride, 1),
69 | dilation=(1, dilation[1], 1),
70 | padding=(0, dilation[1], 0),
71 | bias=False,
72 | )
73 | self.bn3 = norm_layer(planes, momentum=bn_momentum)
74 | self.conv4 = nn.Conv3d(
75 | planes,
76 | planes,
77 | kernel_size=(3, 1, 1),
78 | stride=(stride, 1, 1),
79 | dilation=(dilation[2], 1, 1),
80 | padding=(dilation[2], 0, 0),
81 | bias=False,
82 | )
83 | self.bn4 = norm_layer(planes, momentum=bn_momentum)
84 | self.conv5 = nn.Conv3d(
85 | planes, planes * self.expansion, kernel_size=(1, 1, 1), bias=False
86 | )
87 | self.bn5 = norm_layer(planes * self.expansion, momentum=bn_momentum)
88 |
89 | self.relu = nn.ReLU(inplace=False)
90 | self.relu_inplace = nn.ReLU(inplace=True)
91 | self.downsample = downsample
92 | self.dilation = dilation
93 | self.stride = stride
94 |
95 | self.downsample2 = nn.Sequential(
96 | nn.AvgPool3d(kernel_size=(1, stride, 1), stride=(1, stride, 1)),
97 | nn.Conv3d(planes, planes, kernel_size=1, stride=1, bias=False),
98 | norm_layer(planes, momentum=bn_momentum),
99 | )
100 | self.downsample3 = nn.Sequential(
101 | nn.AvgPool3d(kernel_size=(stride, 1, 1), stride=(stride, 1, 1)),
102 | nn.Conv3d(planes, planes, kernel_size=1, stride=1, bias=False),
103 | norm_layer(planes, momentum=bn_momentum),
104 | )
105 | self.downsample4 = nn.Sequential(
106 | nn.AvgPool3d(kernel_size=(stride, 1, 1), stride=(stride, 1, 1)),
107 | nn.Conv3d(planes, planes, kernel_size=1, stride=1, bias=False),
108 | norm_layer(planes, momentum=bn_momentum),
109 | )
110 |
111 | def forward(self, x):
112 | residual = x
113 |
114 | out1 = self.relu(self.bn1(self.conv1(x)))
115 | out2 = self.bn2(self.conv2(out1))
116 | out2_relu = self.relu(out2)
117 |
118 | out3 = self.bn3(self.conv3(out2_relu))
119 | if self.stride != 1:
120 | out2 = self.downsample2(out2)
121 | out3 = out3 + out2
122 | out3_relu = self.relu(out3)
123 |
124 | out4 = self.bn4(self.conv4(out3_relu))
125 | if self.stride != 1:
126 | out2 = self.downsample3(out2)
127 | out3 = self.downsample4(out3)
128 | out4 = out4 + out2 + out3
129 |
130 | out4_relu = self.relu(out4)
131 | out5 = self.bn5(self.conv5(out4_relu))
132 |
133 | if self.downsample is not None:
134 | residual = self.downsample(x)
135 |
136 | out = out5 + residual
137 | out_relu = self.relu(out)
138 |
139 | return out_relu
140 |
--------------------------------------------------------------------------------
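Both blocks are plain residual modules over 5D tensors; a sketch with a 64-channel 15 x 9 x 15 volume (stride 1, so inplanes must equal planes * expansion for the skip connection to add up):

```python
import torch
import torch.nn as nn
from iso.models.DDR import SimpleRB, Bottleneck3D

x = torch.randn(1, 64, 15, 9, 15)

rb = SimpleRB(64, nn.BatchNorm3d, bn_momentum=0.0003)
print(rb(x).shape)            # (1, 64, 15, 9, 15)

block = Bottleneck3D(inplanes=64, planes=16, norm_layer=nn.BatchNorm3d, dilation=[1, 2, 3])
print(block(x).shape)         # (1, 64, 15, 9, 15): 16 * expansion(4) = 64 channels out
```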
/iso/models/depthnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | # from mmdet.models.backbones.resnet import BasicBlock
5 |
6 | import math
7 | # from occdepth.models.f2v.frustum_grid_generator import FrustumGridGenerator
8 | # from occdepth.models.f2v.frustum_to_voxel import FrustumToVoxel
9 | # from occdepth.models.f2v.sampler import Sampler
10 |
11 | class BasicBlock(nn.Module):
12 | expansion = 1
13 |
14 | def __init__(self,
15 | inplanes,
16 | planes,
17 | stride=1,
18 | dilation=1,
19 | downsample=None,
20 | # style='pytorch',
21 | with_cp=False,
22 | # conv_cfg=None,
23 | # norm_cfg=dict(type='BN'),
24 | dcn=None,
25 | plugins=None,
26 | init_cfg=None):
27 | super(BasicBlock, self).__init__()
28 | assert dcn is None, 'Not implemented yet.'
29 | assert plugins is None, 'Not implemented yet.'
30 |
31 | self.norm1 = nn.BatchNorm2d(planes)
32 | self.norm2 = nn.BatchNorm2d(planes)
33 |
34 | self.conv1 = nn.Conv2d(
35 | inplanes,
36 | planes,
37 | 3,
38 | stride=stride,
39 | padding=dilation,
40 | dilation=dilation,
41 | bias=False)
42 |
43 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
44 |
45 | self.relu = nn.ReLU(inplace=True)
46 | self.downsample = downsample
47 | self.stride = stride
48 | self.dilation = dilation
49 | self.with_cp = with_cp
50 |
51 | @property
52 | def norm1(self):
53 | """nn.Module: normalization layer after the first convolution layer"""
54 |         return self._modules["norm1"]  # registered directly in __init__; there is no norm1_name here
55 |
56 | @property
57 | def norm2(self):
58 | """nn.Module: normalization layer after the second convolution layer"""
59 |         return self._modules["norm2"]  # registered directly in __init__; there is no norm2_name here
60 |
61 | def forward(self, x):
62 | """Forward function."""
63 | # print(x.shape)
64 |
65 | def _inner_forward(x):
66 | identity = x
67 |
68 | out = self.conv1(x)
69 | out = self.norm1(out)
70 | out = self.relu(out)
71 |
72 | out = self.conv2(out)
73 | out = self.norm2(out)
74 |
75 | if self.downsample is not None:
76 | identity = self.downsample(x)
77 |
78 | out += identity
79 |
80 | return out
81 |
82 | out = _inner_forward(x)
83 |
84 | out = self.relu(out)
85 |
86 | return out
87 |
88 |
89 | class Mlp(nn.Module):
90 | def __init__(
91 | self,
92 | in_features,
93 | hidden_features=None,
94 | out_features=None,
95 | act_layer=nn.ReLU,
96 | drop=0.0,
97 | ):
98 | super().__init__()
99 | out_features = out_features or in_features
100 | hidden_features = hidden_features or in_features
101 | self.fc1 = nn.Linear(in_features, hidden_features)
102 | self.act = act_layer()
103 | self.drop1 = nn.Dropout(drop)
104 | self.fc2 = nn.Linear(hidden_features, out_features)
105 | self.drop2 = nn.Dropout(drop)
106 |
107 | def forward(self, x):
108 | x = self.fc1(x)
109 | x = self.act(x)
110 | x = self.drop1(x)
111 | x = self.fc2(x)
112 | x = self.drop2(x)
113 | return x
114 |
115 |
116 | class SELayer(nn.Module):
117 | def __init__(self, channels, act_layer=nn.ReLU, gate_layer=nn.Sigmoid):
118 | super().__init__()
119 | self.conv_reduce = nn.Conv2d(channels, channels, 1, bias=True)
120 | self.act1 = act_layer()
121 | self.conv_expand = nn.Conv2d(channels, channels, 1, bias=True)
122 | self.gate = gate_layer()
123 |
124 | def forward(self, x, x_se):
125 | x_se = self.conv_reduce(x_se)
126 | x_se = self.act1(x_se)
127 | x_se = self.conv_expand(x_se)
128 | return x * self.gate(x_se)
129 |
130 |
131 | class DepthNet(nn.Module):
132 | def __init__(
133 | self,
134 | in_channels,
135 | mid_channels,
136 | context_channels,
137 | depth_channels,
138 | infer_mode=False,
139 | ):
140 | super(DepthNet, self).__init__()
141 | self.reduce_conv = nn.Sequential(
142 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1),
143 | nn.BatchNorm2d(mid_channels),
144 | nn.ReLU(inplace=True),
145 | )
146 | self.mlp = Mlp(1, mid_channels, mid_channels)
147 | self.se = SELayer(mid_channels) # NOTE: add camera-aware
148 | self.depth_conv = nn.Sequential(
149 | BasicBlock(mid_channels, mid_channels),
150 | BasicBlock(mid_channels, mid_channels),
151 | BasicBlock(mid_channels, mid_channels),
152 | )
153 | # self.aspp = ASPP(mid_channels, mid_channels, BatchNorm=nn.InstanceNorm2d)
154 |
155 | self.depth_pred = nn.Conv2d(
156 | mid_channels, depth_channels, kernel_size=1, stride=1, padding=0
157 | )
158 | self.infer_mode = infer_mode
159 |
160 | def forward(
161 | self,
162 | x=None,
163 | sweep_intrins=None,
164 | scaled_pixel_size=None,
165 | scale_depth_factor=1000.0,
166 | ):
167 | # self.eval()
168 | inv_intrinsics = torch.inverse(sweep_intrins)
169 | pixel_size = torch.norm(
170 | torch.stack(
171 | [inv_intrinsics[..., 0, 0], inv_intrinsics[..., 1, 1]], dim=-1
172 | ),
173 | dim=-1,
174 | ).reshape(-1, 1).to(x.device)
175 | scaled_pixel_size = pixel_size * scale_depth_factor
176 |
177 | x = self.reduce_conv(x)
178 | # aug_scale = torch.sqrt(sweep_post_rots_ida[..., 0, 0] ** 2 + sweep_post_rots_ida[..., 0, 1] ** 2).reshape(-1, 1)
179 | x_se = self.mlp(scaled_pixel_size)[..., None, None]
180 |
181 | x = self.se(x, x_se)
182 | x = self.depth_conv(x)
183 | # x = self.aspp(x)
184 | depth = self.depth_pred(x)
185 | return depth
186 |
--------------------------------------------------------------------------------
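A sketch of DepthNet predicting a 64-bin depth distribution from a 200-channel feature map; the intrinsics only enter through the pixel-size scalar fed to the SE layer (all sizes and intrinsics illustrative).

```python
import torch
from iso.models.depthnet import DepthNet

net = DepthNet(in_channels=200, mid_channels=256, context_channels=200, depth_channels=64)

feats = torch.randn(1, 200, 60, 80)              # 2D features at 1/8 resolution
intrins = torch.tensor([[[518.8579, 0.0, 320.0],
                         [0.0, 519.4696, 240.0],
                         [0.0, 0.0, 1.0]]])      # (B, 3, 3) pinhole intrinsics

depth_logits = net(feats, sweep_intrins=intrins)
print(depth_logits.shape)                        # (1, 64, 60, 80), unnormalized bin scores
```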
/iso/models/flosp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class FLoSP(nn.Module):
6 | def __init__(self, scene_size, dataset, project_scale):
7 | super().__init__()
8 | self.scene_size = scene_size
9 | self.dataset = dataset
10 | self.project_scale = project_scale
11 |
12 | def forward(self, x2d, projected_pix, fov_mask):
13 | c, h, w = x2d.shape
14 |
15 | src = x2d.view(c, -1)
16 | zeros_vec = torch.zeros(c, 1).type_as(src)
17 | src = torch.cat([src, zeros_vec], 1)
18 |
19 | pix_x, pix_y = projected_pix[:, 0], projected_pix[:, 1]
20 | img_indices = pix_y * w + pix_x
21 | img_indices[~fov_mask] = h * w
22 | img_indices = img_indices.expand(c, -1).long() # c, HWD
23 | src_feature = torch.gather(src, 1, img_indices)
24 |
25 | if self.dataset == "NYU":
26 | x3d = src_feature.reshape(
27 | c,
28 | self.scene_size[0] // self.project_scale,
29 | self.scene_size[2] // self.project_scale,
30 | self.scene_size[1] // self.project_scale,
31 | )
32 | x3d = x3d.permute(0, 1, 3, 2)
33 | elif self.dataset == "OccScanNet" or self.dataset == "OccScanNet_mini":
34 | x3d = src_feature.reshape(
35 | c,
36 | self.scene_size[0] // self.project_scale,
37 | self.scene_size[1] // self.project_scale,
38 | self.scene_size[2] // self.project_scale,
39 | )
40 |
41 | return x3d
42 |
--------------------------------------------------------------------------------
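FLoSP scatters a 2D feature map into the 3D grid using the per-voxel pixel indices from vox2pix; a sketch with random projections on the NYU grid (full_scene_size (60, 36, 60), project_scale 1), called per sample on a (C, H, W) map:

```python
import torch
from iso.models.flosp import FLoSP

scene_size = (60, 36, 60)                        # NYU full_scene_size
flosp = FLoSP(scene_size, dataset="NYU", project_scale=1)

C, H, W = 64, 120, 160                           # one 2D feature map (per-sample call)
x2d = torch.randn(C, H, W)

N = 60 * 60 * 36                                 # number of voxels at this scale
projected_pix = torch.stack(                     # normally produced by vox2pix
    [torch.randint(0, W, (N,)), torch.randint(0, H, (N,))], dim=1
)
fov_mask = torch.rand(N) > 0.2                   # voxels outside the FOV receive zeros

x3d = flosp(x2d, projected_pix, fov_mask)
print(x3d.shape)                                 # (64, 60, 36, 60)
```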
/iso/models/iso.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torch
3 | import torch.nn as nn
4 | from iso.models.unet3d_nyu import UNet3D as UNet3DNYU
5 | from iso.loss.sscMetrics import SSCMetrics
6 | from iso.loss.ssc_loss import sem_scal_loss, CE_ssc_loss, KL_sep, geo_scal_loss
7 | from iso.models.flosp import FLoSP
8 | from iso.models.depthnet import DepthNet
9 | from iso.loss.CRP_loss import compute_super_CP_multilabel_loss
10 | import numpy as np
11 | import torch.nn.functional as F
12 | from iso.models.unet2d import UNet2D
13 | from torch.optim.lr_scheduler import MultiStepLR
14 | import sys
15 | sys.path.append('./iso')
16 | sys.path.append('./depth_anything/metric_depth')
17 | from depth_anything.metric_depth.zoedepth.models.builder import build_model as build_depthany_model
18 | from depth_anything.metric_depth.zoedepth.utils.config import get_config as get_depthany_config
19 |
20 | from iso.models.modules import sample_grid_feature, get_depth_index, sample_3d_feature, bin_depths
21 | # from iso.models.depth_utils import down_sample_depth_dist
22 | from iso.loss.depth_loss import DepthClsLoss
23 |
24 | import torch
25 |
26 | from torch.cuda.amp import autocast
27 | import torch.nn.functional as F
28 |
29 | # from transformers import pipeline
30 | from PIL import Image
31 |
32 |
33 | class ISO(pl.LightningModule):
34 | def __init__(
35 | self,
36 | n_classes,
37 | class_names,
38 | feature,
39 | class_weights,
40 | project_scale,
41 | full_scene_size,
42 | dataset,
43 | n_relations=4,
44 | context_prior=True,
45 | fp_loss=True,
46 | project_res=[],
47 | bevdepth=False,
48 | voxeldepth=False,
49 | voxeldepth_res=[],
50 | frustum_size=4,
51 | relation_loss=False,
52 | CE_ssc_loss=True,
53 | geo_scal_loss=True,
54 | sem_scal_loss=True,
55 | lr=1e-4,
56 | weight_decay=1e-4,
57 | use_gt_depth=False,
58 | add_fusion=False,
59 | use_zoedepth=True,
60 | use_depthanything=False,
61 | zoedepth_as_gt=False,
62 | depthanything_as_gt=False,
63 | frozen_encoder=False,
64 | ):
65 | super().__init__()
66 |
67 | self.project_res = project_res
68 | self.bevdepth = bevdepth
69 | self.voxeldepth = voxeldepth
70 | self.voxeldepth_res = voxeldepth_res
71 | self.fp_loss = fp_loss
72 | self.dataset = dataset
73 | self.context_prior = context_prior
74 | self.frustum_size = frustum_size
75 | self.class_names = class_names
76 | self.relation_loss = relation_loss
77 | self.CE_ssc_loss = CE_ssc_loss
78 | self.sem_scal_loss = sem_scal_loss
79 | self.geo_scal_loss = geo_scal_loss
80 | self.project_scale = project_scale
81 | self.class_weights = class_weights
82 | self.lr = lr
83 | self.weight_decay = weight_decay
84 | self.use_gt_depth = use_gt_depth
85 | self.add_fusion = add_fusion
86 | self.use_zoedepth = use_zoedepth
87 | self.use_depthanything = use_depthanything
88 | self.zoedepth_as_gt = zoedepth_as_gt
89 | self.depthanything_as_gt = depthanything_as_gt
90 | self.frozen_encoder = frozen_encoder
91 |
92 | self.projects = {}
93 | self.scale_2ds = [1, 2, 4, 8] # 2D scales
94 | for scale_2d in self.scale_2ds:
95 | self.projects[str(scale_2d)] = FLoSP(
96 | full_scene_size, project_scale=self.project_scale, dataset=self.dataset
97 | )
98 | self.projects = nn.ModuleDict(self.projects)
99 |
100 | self.n_classes = n_classes
101 | if self.dataset == "NYU" or self.dataset == "OccScanNet" or self.dataset == "OccScanNet_mini":
102 | self.net_3d_decoder = UNet3DNYU(
103 | self.n_classes,
104 | nn.BatchNorm3d,
105 | n_relations=n_relations,
106 | feature=feature,
107 | full_scene_size=full_scene_size,
108 | context_prior=context_prior,
109 | )
110 | # if self.voxeldepth:
111 | # self.depth_net_3d_decoder = UNet3DNYU(
112 | # self.n_classes,
113 | # nn.BatchNorm3d,
114 | # n_relations=n_relations,
115 | # feature=feature,
116 | # full_scene_size=full_scene_size,
117 | # context_prior=context_prior,
118 | # beforehead=True,
119 | # )
120 |         elif self.dataset == "kitti":  # NOTE: UNet3DKitti below is not imported in this file; this branch is not usable as shipped
121 | self.net_3d_decoder = UNet3DKitti(
122 | self.n_classes,
123 | nn.BatchNorm3d,
124 | project_scale=project_scale,
125 | feature=feature,
126 | full_scene_size=full_scene_size,
127 | context_prior=context_prior,
128 | )
129 | self.net_rgb = UNet2D.build(out_feature=feature, use_decoder=True, frozen_encoder=self.frozen_encoder)
130 |
131 | if self.voxeldepth and not self.use_gt_depth:
132 | self.depthnet_1_1 = DepthNet(200, 256, 200, 64)
133 | if self.use_zoedepth: # use zoedepth pretrained
134 | self.net_depth = self._init_zoedepth()
135 | self.depthnet_1_1 = DepthNet(201, 256, 200, 64)
136 | elif self.use_depthanything:
137 | self.net_depth = self._init_depthanything()
138 | self.depthnet_1_1 = DepthNet(201, 256, 200, 64)
139 | elif self.zoedepth_as_gt: # use gt and use zoedepth as gt
140 | self.net_depth = self._init_zoedepth()
141 | elif self.depthanything_as_gt:
142 | self.net_depth = self._init_depthanything()
143 | elif self.bevdepth:
144 | self.net_depth = self._init_zoedepth()
145 | else:
146 | pass # use gt and use dataset gt
147 |
148 | # log hyperparameters
149 | self.save_hyperparameters()
150 |
151 | self.train_metrics = SSCMetrics(self.n_classes)
152 | self.val_metrics = SSCMetrics(self.n_classes)
153 | self.test_metrics = SSCMetrics(self.n_classes)
154 |
155 | def _init_zoedepth(self):
156 |         conf = get_config("zoedepth", "infer")  # NOTE: get_config/build_model here are the ZoeDepth builders (zoedepth.utils.config / zoedepth.models.builder); they are not imported at the top of this file
157 | conf['img_size'] = [480, 640]
158 | model_zoe_n = build_model(conf)
159 | return model_zoe_n.cuda()
160 |
161 | def _init_depthanything(self):
162 | import sys
163 | sys.path.append('/home/hongxiao.yu/projects/ISO/depth_anything/metric_depth')
164 |         override = {"pretrained_resource": "local::/home/hongxiao.yu/projects/ISO/checkpoints/depth_anything_metric_depth_indoor.pt"}
165 |         conf = get_depthany_config("zoedepth", "infer", "nyu", **override)
166 | # conf['img_size'] = [480, 640]
167 | from pprint import pprint
168 | # pprint(conf)
169 | model = build_depthany_model(conf)
170 | return model.cuda()
171 | # return pipeline(task="depth-estimation", model="LiheYoung/depth-anything-small-hf")
172 |
173 | def _set_train_params(self, curr_epoch):
174 | if curr_epoch // 2 == 1:
175 | for k, p in self.named_parameters():
176 | if 'depthnet_1_1' in k:
177 | p.requires_grad = True
178 | else:
179 | p.requires_grad = False
180 | else:
181 | for k, p in self.named_parameters():
182 | if 'depthnet_1_1' in k:
183 | p.requires_grad = False
184 | else:
185 | p.requires_grad = True
186 |
187 | def forward(self, batch):
188 |
189 | # for k, v in self.state_dict().items():
190 | # if v.dtype != torch.float32:
191 | # print(k, v.dtype)
192 | # breakpoint()
193 |
194 | # self._set_train_params(self.current_epoch)
195 |
196 | img = batch["img"]
197 | raw_img = batch['raw_img']
198 | pix_z = batch['pix_z'] # (B, 129600)
199 | bs = len(img)
200 | # print(img)
201 |
202 | out = {}
203 |
204 | # for k, v in self.named_parameters():
205 | # print(k, ':', v)
206 | x_rgb = self.net_rgb(img)
207 |
208 | x3ds = []
209 | x3d_bevs = []
210 | if self.add_fusion:
211 | x3ds_res = []
212 | if self.voxeldepth:
213 | depth_preds = {
214 | '1_1': [],
215 | '1_2': [],
216 | '1_4': [],
217 | '1_8': [],
218 | }
219 | for i in range(bs):
220 | if self.voxeldepth:
221 | x3d_depths = {
222 | '1_1': None,
223 | '1_2': None,
224 | '1_4': None,
225 | '1_8': None,
226 | }
227 | depths = {
228 | '1_1': None,
229 | '1_2': None,
230 | '1_4': None,
231 | '1_8': None,
232 | }
233 |
234 | if not self.use_gt_depth:
235 | if self.use_zoedepth:
236 | # self.net_depth.eval() # zoe_depth
237 | # for param in self.net_depth.parameters():
238 | # param.requires_grad = False
239 | # rslts = self.net_depth(raw_img[i:i+1], return_probs=False, return_final_centers=False)
240 | # feature = rslts['metric_depth'] # (1, 1, 480, 640)
241 | self.net_depth.device = 'cuda'
242 | feature = self.net_depth.infer_pil(raw_img[i], output_type="tensor", with_flip_aug=False).cuda().unsqueeze(0).unsqueeze(0)
243 |
244 | # print(feature.shape)
245 |
246 | # import matplotlib.pyplot as plt
247 | # plt.subplot(1, 2, 1)
248 | # plt.imshow(feature[0].permute(1, 2, 0).cpu().numpy())
249 | # plt.subplot(1, 2, 2)
250 | # plt.imshow(batch['depth_gt'][i].permute(1, 2, 0).cpu().numpy())
251 | # plt.savefig('/home/hongxiao.yu/ISO/depth_compare.png')
252 |
253 | input_kwargs = {
254 | "img_feat_1_1": torch.cat([x_rgb['1_1'][i:i+1], feature], dim=1),
255 | "cam_k": batch["cam_k"][i:i+1],
256 | "T_velo_2_cam": batch["cam_pose"][i:i+1],
257 | "vox_origin": batch['vox_origin'][i:i+1],
258 | }
259 | elif self.use_depthanything:
260 | self.net_depth.device = 'cuda'
261 | feature = self.net_depth.infer_pil(raw_img[i], output_type="tensor", with_flip_aug=False).cuda().unsqueeze(0).unsqueeze(0)
262 |
263 | # print(feature.shape)
264 | # print(feature.shape)
265 |
266 | # import matplotlib.pyplot as plt
267 | # plt.subplot(1, 2, 1)
268 | # plt.imshow(feature[0].permute(1, 2, 0).cpu().numpy())
269 | # plt.subplot(1, 2, 2)
270 | # plt.imshow(batch['depth_gt'][i].permute(1, 2, 0).cpu().numpy())
271 | # plt.savefig('/home/hongxiao.yu/ISO/depth_compare.png')
272 |
273 | input_kwargs = {
274 | "img_feat_1_1": torch.cat([x_rgb['1_1'][i:i+1], feature], dim=1),
275 | "cam_k": batch["cam_k"][i:i+1],
276 | "T_velo_2_cam": batch["cam_pose"][i:i+1],
277 | "vox_origin": batch['vox_origin'][i:i+1],
278 | }
279 | else:
280 | input_kwargs = {
281 | "img_feat_1_1": x_rgb['1_1'][i:i+1],
282 | "cam_k": batch["cam_k"][i:i+1],
283 | "T_velo_2_cam": batch["cam_pose"][i:i+1],
284 | "vox_origin": batch['vox_origin'][i:i+1],
285 | }
286 | intrins_mat = input_kwargs['cam_k'].new_zeros(1, 4, 4).to(torch.float)
287 | intrins_mat[:, :3, :3] = input_kwargs['cam_k']
288 | intrins_mat[:, 3, 3] = 1 # (1, 4, 4)
289 |
290 | depth_feature_1_1 = self.depthnet_1_1(
291 | x=input_kwargs['img_feat_1_1'],
292 | sweep_intrins=intrins_mat,
293 | scaled_pixel_size=None,
294 | )
295 |                     depths['1_1'] = depth_feature_1_1.softmax(1) # the predicted depth distribution
296 | for res in self.voxeldepth_res[1:]:
297 | depths['1_'+str(res)] = down_sample_depth_dist(depths['1_1'], int(res))
298 | else:
299 | disc_cfg = {
300 | "mode": "LID",
301 | "num_bins": 64,
302 | "depth_min": 0,
303 | "depth_max": 10,
304 | }
305 | depth_1_1 = batch['depth_gt'][i:i+1]
306 | # print(depth_1_1.shape)
307 |                 if self.zoedepth_as_gt:
308 |                     self.net_depth.eval()  # freeze ZoeDepth (used as pseudo ground truth)
309 |                     for param in self.net_depth.parameters():
310 |                         param.requires_grad = False
311 |                     self.net_depth.device = 'cuda'
312 |                     rslts = self.net_depth.infer_pil(raw_img[i], output_type="tensor", with_flip_aug=False).cuda()
313 |                     depth_1_1 = rslts.unsqueeze(0).unsqueeze(0) # (1, 1, 480, 640); infer_pil(output_type="tensor") returns a plain tensor, not a dict
314 |                 elif self.depthanything_as_gt:
315 |                     self.net_depth.eval()  # freeze Depth Anything (used as pseudo ground truth)
316 |                     for param in self.net_depth.parameters():
317 |                         param.requires_grad = False
318 |                     self.net_depth.device = 'cuda'
319 |                     depth_1_1 = self.net_depth.infer_pil(raw_img[i], output_type="tensor", with_flip_aug=False).cuda().unsqueeze(0).unsqueeze(0) # (1, 1, 480, 640)
320 |                     # depth_1_1 = rslts['metric_depth'] # (1, 1, 480, 640)
321 | # print(depth_1_1.shape)
322 |                 depth_1_1 = bin_depths(depth_map=depth_1_1, target=True, **disc_cfg).cuda() # (1, 1, 480, 640)
323 |                 # print(depth_1_1.shape)
324 |                 # depth_1_1 = ((depth_1_1 - (0.1 - 0.13)) / 0.13).cuda().long()
325 |                 # print(depth_1_1.shape)
326 |                 depth_1_1 = F.one_hot(depth_1_1[:, 0, :, :], 81).permute(0, 3, 1, 2)[:, :-1, :, :] # (1, 80, 480, 640); note disc_cfg above uses 64 bins, so the 81 here looks like a leftover from an earlier 80-bin setup
327 |                 depths['1_1'] = depth_1_1.float() # the depth distribution (one-hot over bins in the GT / pseudo-GT case)
328 | for res in self.voxeldepth_res[1:]:
329 | depths['1_'+str(res)] = down_sample_depth_dist(depths['1_1'], int(res)).float()
330 |
331 | projected_pix = batch["projected_pix_{}".format(self.project_scale)][i].cuda()
332 | fov_mask = batch["fov_mask_{}".format(self.project_scale)][i].cuda()
333 |
334 | # pix_z_index = get_depth_index(pix_z[i])
335 | disc_cfg = {
336 | "mode": "LID",
337 | "num_bins": 64,
338 | "depth_min": 0,
339 | "depth_max": 10,
340 | }
341 | pix_z_index = bin_depths(depth_map=pix_z[i], target=True, **disc_cfg).to(fov_mask.device)
342 | # pix_z_index = ((pix_z[i] - (0.1 - 0.13)) / 0.13).to(fov_mask.device)
343 | # pix_z_index = torch.where(
344 | # (pix_z_index < 80 + 1) & (pix_z_index >= 0.0),
345 | # pix_z_index,
346 | # torch.zeros_like(pix_z_index),
347 | # )
348 | dist_mask = torch.logical_and(pix_z_index >= 0, pix_z_index < 64)
349 | dist_mask = torch.logical_and(dist_mask, fov_mask)
350 |
351 | for res in self.voxeldepth_res:
352 | probs = torch.zeros((129600, 1), dtype=torch.float32).to(self.device)
353 | # print(depths['1_'+str(res)].dtype, (projected_pix//int(res)).dtype, pix_z_index.dtype)
354 | probs[dist_mask] = sample_3d_feature(depths['1_'+str(res)], projected_pix//int(res), pix_z_index, dist_mask)
355 | if self.dataset == 'NYU':
356 | x3d_depths['1_'+str(res)] = probs.reshape(60, 60, 36).permute(0, 2, 1).unsqueeze(0)
357 | elif self.dataset == 'OccScanNet' or self.dataset == 'OccScanNet_mini':
358 | x3d_depths['1_'+str(res)] = probs.reshape(60, 60, 36).unsqueeze(0)
359 | depth_preds['1_'+str(res)].append(depths['1_'+str(res)]) # (1, 64, 60, 80)
360 |
361 | x3d = None
362 | if self.add_fusion:
363 | x3d_res = None
364 | for scale_2d in self.project_res:
365 |
366 | # project features at each 2D scale to target 3D scale
367 | scale_2d = int(scale_2d)
368 | projected_pix = batch["projected_pix_{}".format(self.project_scale)][i].cuda()
369 | fov_mask = batch["fov_mask_{}".format(self.project_scale)][i].cuda()
370 | if self.bevdepth and scale_2d == 4:
371 | xys = projected_pix // scale_2d
372 |
373 | D, fH, fW = 64, 480 // scale_2d, 640 // scale_2d
374 | xs = torch.linspace(0, 640 - 1, fW, dtype=torch.float).view(1, 1, fW).expand(D, fH, fW).cuda() # (64, 120, 160)
375 | ys = torch.linspace(0, 480 - 1, fH, dtype=torch.float).view(1, fH, 1).expand(D, fH, fW).cuda() # (64, 120, 160)
376 | d_xs = torch.floor(xs[0].reshape(-1)).to(torch.long) # (fH*fW,)
377 | d_ys = torch.floor(ys[0].reshape(-1)).to(torch.long) # (fH*fW,)
378 |
379 | self.net_depth.device = 'cuda'
380 | # feature = self.net_depth.infer_pil(raw_img[i], output_type="tensor", with_flip_aug=False).cuda().unsqueeze(0).unsqueeze(0)
381 | rslts = self.net_depth(img[i:i+1], return_final_centers=True, return_probs=True)
382 |                     probs = rslts['probs'].cuda() # the predicted depth distribution
383 | bin_center = rslts['bin_centers'].cuda() # (1, 64, 384, 512)
384 | # print(probs.shape)
385 | probs = probs[0, :, d_ys, d_xs].reshape(D, fH, fW).to(projected_pix.device) # (D, fH, fW) depth distribution
386 | ds = bin_center[0, :, d_ys, d_xs].reshape(D, fH, fW).to(xs.device) # (D, fH, fW) depth bins distance
387 |
388 | # x = F.interpolate(x_rgb["1_" + str(scale_2d)][i][None, ...], (384, 512), mode='bilinear', align_corners=True)[0] # (200, 384, 512)
389 | x = x_rgb["1_" + str(scale_2d)][i]
390 | x = probs.unsqueeze(0) * x.unsqueeze(1) # (1, 64, 384, 512), (200, 1, 384, 512) --> (200, 64, 384, 512)
391 | x = x.reshape(-1, fH, fW)
392 | x = self.projects[str(scale_2d)](
393 | x,
394 | xys,
395 | fov_mask,
396 | )
397 | # print(x.shape) # (200*64, 60, 60, 36)
398 | x = x.reshape(200*64, 60*36*60).T # (60*60*36, 200*64)
399 | pix_z_depth = ds[:, xys[:, 1][fov_mask], xys[:, 0][fov_mask]] # get the depth bins distance (64, K)
400 | pix_zz = pix_z[i][fov_mask] # (K,)
401 | # print(pix_z_depth.shape, pix_z.shape)
402 | pix_z_delta = torch.abs(pix_z_depth.to(pix_zz.device) - pix_zz[None, ...]) # (64, K)
403 | # print(pix_z_delta.shape)
404 | min_z = torch.argmin(pix_z_delta, dim=0) # (K,)
405 | # print(min_z.shape)
406 | temp = x[fov_mask].reshape(min_z.shape[0], 200, 64) # (K, 200, 64)
407 | temp = temp[torch.arange(min_z.shape[0]).to(torch.long), :, min_z] # (K, 200)
408 | # temp = temp.reshape(temp.shape[0], 200, 1).repeat(1, 1, 64).reshape(temp.shape[0], 200*64) # (K, 200*64)
409 | # x[fov_mask] = x[fov_mask][torch.arange(min_z.shape[0]), min_z*200:(min_z+1)*200] = temp # (K, 200*64)
410 | # x = x.T.reshape(200*64, 60, 60, 36)[:200, ...]
411 | x3d_bev = torch.zeros(60*36*60, 200).to(x.device)
412 | x3d_bev[fov_mask] = temp
413 | x3d_bev = x3d_bev.T.reshape(200, 60, 36, 60)
414 | # torch.cuda.empty_cache()
415 |
416 | if self.add_fusion:
417 | if x3d_res is None:
418 | x3d_res = self.projects[str(scale_2d)](
419 | x_rgb["1_" + str(scale_2d)][i],
420 | projected_pix // scale_2d,
421 | fov_mask,
422 | )
423 | else:
424 | x3d_res += self.projects[str(scale_2d)](
425 | x_rgb["1_" + str(scale_2d)][i],
426 | projected_pix // scale_2d,
427 | fov_mask,
428 | )
429 |
430 | # Sum all the 3D features
431 | if x3d is None:
432 | if self.voxeldepth:
433 | if len(self.voxeldepth_res) == 1:
434 | res = self.voxeldepth_res[0]
435 | x3d = self.projects[str(scale_2d)](
436 | x_rgb["1_" + str(scale_2d)][i],
437 | projected_pix // scale_2d,
438 | fov_mask,
439 | ) * x3d_depths['1_'+str(res)] * 100
440 | else:
441 | x3d = self.projects[str(scale_2d)](
442 | x_rgb["1_" + str(scale_2d)][i],
443 | projected_pix // scale_2d,
444 | fov_mask,
445 | ) * x3d_depths['1_1'] * 100
446 | else:
447 | x3d = self.projects[str(scale_2d)](
448 | x_rgb["1_" + str(scale_2d)][i],
449 | projected_pix // scale_2d,
450 | fov_mask,
451 | )
452 | else:
453 | if self.voxeldepth:
454 | if len(self.voxeldepth_res) == 1:
455 | res = self.voxeldepth_res[0]
456 | x3d = self.projects[str(scale_2d)](
457 | x_rgb["1_" + str(scale_2d)][i],
458 | projected_pix // scale_2d,
459 | fov_mask,
460 | ) * x3d_depths['1_'+str(res)] * 100
461 | else:
462 | x3d += self.projects[str(scale_2d)](
463 | x_rgb["1_" + str(scale_2d)][i],
464 | projected_pix // scale_2d,
465 | fov_mask,
466 | ) * x3d_depths['1_'+str(scale_2d)] * 100
467 | else:
468 | x3d += self.projects[str(scale_2d)](
469 | x_rgb["1_" + str(scale_2d)][i],
470 | projected_pix // scale_2d,
471 | fov_mask,
472 | )
473 | x3ds.append(x3d)
474 | if self.add_fusion:
475 | x3ds_res.append(x3d_res)
476 | if self.bevdepth:
477 | x3d_bevs.append(x3d_bev)
478 |
479 | input_dict = {
480 | "x3d": torch.stack(x3ds),
481 | }
482 | if self.add_fusion:
483 | input_dict['x3d'] += torch.stack(x3ds_res)
484 | if self.bevdepth:
485 | # print(input_dict['x3d'].shape, torch.stack(x3d_bevs).shape)
486 | input_dict['x3d'] += torch.stack(x3d_bevs)
487 |
488 | # print(input_dict["x3d"])
489 | # from pytorch_lightning import seed_everything
490 | # seed_everything(84, True)
491 | # input_dict['x3d'] = torch.randn_like(input_dict['x3d']).to(torch.float64)
492 | # print(input_dict['x3d'], 'x3d')
493 |
494 | out = self.net_3d_decoder(input_dict)
495 | # print(torch.randn((5, 5)), 'x3d')
496 | # print(torch.randn((5, 5)), 'x3d')
497 |
498 | if self.voxeldepth and not self.use_gt_depth:
499 | for res in self.voxeldepth_res:
500 | out['depth_1_'+str(res)] = torch.vstack(depth_preds['1_'+str(res)]) # (B, 64, 60, 80)
501 | out['depth_gt'] = batch['depth_gt']
502 | # print(batch['name'], out["ssc_logit"])
503 | return out
504 |
505 | def step(self, batch, step_type, metric):
506 | bs = len(batch["img"])
507 | loss = 0
508 | out_dict = self(batch)
509 | ssc_pred = out_dict["ssc_logit"]
510 | ssc_pred = torch.where(torch.isinf(ssc_pred), torch.zeros_like(ssc_pred), ssc_pred)
511 | ssc_pred = torch.where(torch.isnan(ssc_pred), torch.zeros_like(ssc_pred), ssc_pred)
512 | # print(ssc_pred.requires_grad, ssc_pred.grad_fn)
513 | # torch.cuda.manual_seed_all(42)
514 | # ssc_pred = torch.randn_like(ssc_pred)
515 | target = batch["target"]
516 |
517 | if self.voxeldepth and not self.use_gt_depth:
518 | loss_depth_dict = {}
519 | loss_depth = 0.0
520 | out_dict['depth_1_1'] = torch.where(torch.isnan(out_dict['depth_1_1']), torch.zeros_like(out_dict['depth_1_1']), out_dict['depth_1_1'], )
521 | # print(521, torch.any(torch.isnan(out_dict['depth_1_1'])))
522 | if torch.any(torch.isnan(out_dict['depth_1_1'])):
523 |                 print("===== Got NaN !!! =====")
524 | exit(-1)
525 | for res in self.voxeldepth_res:
526 | loss_depth_dict['1_'+str(res)] = DepthClsLoss(int(res), [0.0, 10, 0.16], 64).get_depth_loss(out_dict['depth_gt'],
527 | out_dict['depth_1_'+str(res)].unsqueeze(1))
528 | loss_depth += loss_depth_dict['1_'+str(res)]
529 | loss_depth = loss_depth / len(loss_depth_dict)
530 | loss += loss_depth
531 | # print(521, torch.any(torch.isnan(loss_depth)))
532 | self.log(
533 | step_type + "/loss_depth",
534 | loss_depth.detach(),
535 | on_epoch=True,
536 | sync_dist=True,
537 | )
538 |
539 | if self.context_prior:
540 | P_logits = out_dict["P_logits"]
541 | P_logits = torch.where(torch.isnan(P_logits), torch.zeros_like(P_logits), P_logits)
542 |
543 | CP_mega_matrices = batch["CP_mega_matrices"]
544 |
545 | if self.relation_loss:
546 | # print(543, torch.any(torch.isnan(P_logits)))
547 | loss_rel_ce = compute_super_CP_multilabel_loss(
548 | P_logits, CP_mega_matrices
549 | )
550 | # loss_rel_ce = torch.where(torch.isnan(loss_rel_ce), torch.zeros_like(loss_rel_ce), loss_rel_ce)
551 | loss += loss_rel_ce
552 | # print(543, torch.any(torch.isnan(loss_rel_ce)))
553 | self.log(
554 | step_type + "/loss_relation_ce_super",
555 | loss_rel_ce.detach(),
556 | on_epoch=True,
557 | sync_dist=True,
558 | )
559 |
560 | class_weight = self.class_weights.type_as(batch["img"])
561 | if self.CE_ssc_loss:
562 | loss_ssc = CE_ssc_loss(ssc_pred, target, class_weight)
563 | # print(558, torch.any(torch.isnan(loss_ssc)))
564 | loss += loss_ssc
565 | self.log(
566 | step_type + "/loss_ssc",
567 | loss_ssc.detach(),
568 | on_epoch=True,
569 | sync_dist=True,
570 | )
571 |
572 | if self.sem_scal_loss:
573 | loss_sem_scal = sem_scal_loss(ssc_pred, target)
574 | # print(569, torch.any(torch.isnan(loss_sem_scal)))
575 | loss += loss_sem_scal
576 | self.log(
577 | step_type + "/loss_sem_scal",
578 | loss_sem_scal.detach(),
579 | on_epoch=True,
580 | sync_dist=True,
581 | )
582 |
583 | if self.geo_scal_loss:
584 | loss_geo_scal = geo_scal_loss(ssc_pred, target)
585 | # print(580, torch.any(torch.isnan(loss_geo_scal)))
586 | loss += loss_geo_scal
587 | self.log(
588 | step_type + "/loss_geo_scal",
589 | loss_geo_scal.detach(),
590 | on_epoch=True,
591 | sync_dist=True,
592 | )
593 |
594 | if self.fp_loss and step_type != "test":
595 | frustums_masks = torch.stack(batch["frustums_masks"])
596 | frustums_class_dists = torch.stack(
597 | batch["frustums_class_dists"]
598 | ).float() # (bs, n_frustums, n_classes)
599 | n_frustums = frustums_class_dists.shape[1]
600 |
601 | pred_prob = F.softmax(ssc_pred, dim=1)
602 | batch_cnt = frustums_class_dists.sum(0) # (n_frustums, n_classes)
603 |
604 | frustum_loss = 0
605 | frustum_nonempty = 0
606 | for frus in range(n_frustums):
607 | frustum_mask = frustums_masks[:, frus, :, :, :].unsqueeze(1).float()
608 | prob = frustum_mask * pred_prob # bs, n_classes, H, W, D
609 | prob = prob.reshape(bs, self.n_classes, -1).permute(1, 0, 2)
610 | prob = prob.reshape(self.n_classes, -1)
611 | cum_prob = prob.sum(dim=1) # n_classes
612 |
613 | total_cnt = torch.sum(batch_cnt[frus])
614 | total_prob = prob.sum()
615 | if total_prob > 0 and total_cnt > 0:
616 | frustum_target_proportion = batch_cnt[frus] / total_cnt
617 | cum_prob = cum_prob / total_prob # n_classes
618 | frustum_loss_i = KL_sep(cum_prob, frustum_target_proportion)
619 | frustum_loss += frustum_loss_i
620 | frustum_nonempty += 1
621 | frustum_loss = frustum_loss / frustum_nonempty
622 | loss += frustum_loss
623 | # print(618, torch.any(torch.isnan(frustum_loss)))
624 | self.log(
625 | step_type + "/loss_frustums",
626 | frustum_loss.detach(),
627 | on_epoch=True,
628 | sync_dist=True,
629 | )
630 |
631 | y_true = target.cpu().numpy()
632 | y_pred = ssc_pred.detach().cpu().numpy()
633 | y_pred = np.argmax(y_pred, axis=1)
634 | metric.add_batch(y_pred, y_true)
635 |
636 | self.log(step_type + "/loss", loss.detach(), on_epoch=True, sync_dist=True, prog_bar=True)
637 |
638 | # loss_dict = {'loss': loss}
639 | # print(636, torch.any(torch.isnan(loss)))
640 |
641 | return loss
642 |
643 | def training_step(self, batch, batch_idx):
644 | return self.step(batch, "train", self.train_metrics)
645 |
646 | def validation_step(self, batch, batch_idx):
647 | self.step(batch, "val", self.val_metrics)
648 |
649 | def on_validation_epoch_end(self): #, outputs):
650 | metric_list = [("train", self.train_metrics), ("val", self.val_metrics)]
651 |
652 | for prefix, metric in metric_list:
653 | stats = metric.get_stats()
654 | for i, class_name in enumerate(self.class_names):
655 | self.log(
656 | "{}_SemIoU/{}".format(prefix, class_name),
657 | stats["iou_ssc"][i],
658 | sync_dist=True,
659 | )
660 | self.log("{}/mIoU".format(prefix), torch.tensor(stats["iou_ssc_mean"]).to(torch.float32), sync_dist=True)
661 | self.log("{}/IoU".format(prefix), torch.tensor(stats["iou"]).to(torch.float32), sync_dist=True)
662 | self.log("{}/Precision".format(prefix), torch.tensor(stats["precision"]).to(torch.float32), sync_dist=True)
663 | self.log("{}/Recall".format(prefix), torch.tensor(stats["recall"]).to(torch.float32), sync_dist=True)
664 | metric.reset()
665 |
666 | def test_step(self, batch, batch_idx):
667 | self.step(batch, "test", self.test_metrics)
668 |
669 | def on_test_epoch_end(self,):# outputs):
670 | classes = self.class_names
671 | metric_list = [("test", self.test_metrics)]
672 | for prefix, metric in metric_list:
673 | print("{}======".format(prefix))
674 | stats = metric.get_stats()
675 | print(
676 | "Precision={:.4f}, Recall={:.4f}, IoU={:.4f}".format(
677 | stats["precision"] * 100, stats["recall"] * 100, stats["iou"] * 100
678 | )
679 | )
680 | print("class IoU: {}, ".format(classes))
681 | print(
682 | " ".join(["{:.4f}, "] * len(classes)).format(
683 | *(stats["iou_ssc"] * 100).tolist()
684 | )
685 | )
686 | print("mIoU={:.4f}".format(stats["iou_ssc_mean"] * 100))
687 | metric.reset()
688 |
689 | def configure_optimizers(self):
690 | if self.dataset == "NYU" and not self.voxeldepth:
691 | optimizer = torch.optim.AdamW(
692 | self.parameters(), lr=self.lr, weight_decay=self.weight_decay
693 | )
694 | scheduler = MultiStepLR(optimizer, milestones=[20], gamma=0.1)
695 | return [optimizer], [scheduler]
696 | elif self.dataset == "NYU" and self.voxeldepth:
697 | depth_params = []
698 | other_params = []
699 | for k, p in self.named_parameters():
700 | if 'depthnet_1_1' in k:
701 | depth_params.append(p)
702 | else:
703 | other_params.append(p)
704 | params_list = [{'params': depth_params, 'lr': self.lr * 0.05}, # 0.4 high, 0.1 unstable, 0.05 ok
705 | {'params': other_params}]
706 | optimizer = torch.optim.AdamW(
707 | params_list, lr=self.lr, weight_decay=self.weight_decay
708 | )
709 | scheduler = MultiStepLR(optimizer, milestones=[20], gamma=0.1)
710 | return [optimizer], [scheduler]
711 |         elif self.dataset == "OccScanNet" or (self.dataset == "OccScanNet_mini" and self.voxeldepth):
712 | depth_params = []
713 | other_params = []
714 | for k, p in self.named_parameters():
715 | if 'depthnet_1_1' in k:
716 | depth_params.append(p)
717 | else:
718 | other_params.append(p)
719 | params_list = [{'params': depth_params, 'lr': self.lr * 0.05}, # 0.4 high, 0.1 unstable, 0.05 ok
720 | {'params': other_params}]
721 | optimizer = torch.optim.AdamW(
722 | params_list, lr=self.lr, weight_decay=self.weight_decay
723 | )
724 | scheduler = MultiStepLR(optimizer, milestones=[40], gamma=0.1)
725 | return [optimizer], [scheduler]
726 |
--------------------------------------------------------------------------------
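The voxel-depth weighting in ISO.forward above boils down to: for every voxel, look up the probability that its pixel's predicted depth distribution assigns to that voxel's own LID depth bin, and scale the FLoSP-projected features by that probability (times 100). A rough standalone sketch with toy sizes; the real code does this per resolution via sample_3d_feature over the full 60x36x60 grid:

    import torch

    n_vox, n_bins, h, w = 6, 64, 8, 8                                 # toy sizes
    depth_dist = torch.softmax(torch.randn(1, n_bins, h, w), dim=1)   # per-pixel depth distribution
    pix_xy = torch.randint(0, w, (n_vox, 2))                          # projected pixel of each voxel
    pix_z_index = torch.randint(0, n_bins, (n_vox,))                  # LID bin index of each voxel's depth

    probs = depth_dist[0, pix_z_index, pix_xy[:, 1], pix_xy[:, 0]]    # probability of each voxel's own depth bin
    x3d = torch.randn(200, n_vox)                                     # projected 2D features (flattened FLoSP output)
    x3d_weighted = x3d * probs[None, :] * 100                         # the "* x3d_depths * 100" step in forward()
    print(x3d_weighted.shape)                                         # torch.Size([200, 6])
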
/iso/models/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from iso.models.DDR import Bottleneck3D
5 |
6 |
7 | class ASPP(nn.Module):
8 | """
9 | ASPP 3D
10 | Adapt from https://github.com/cv-rits/LMSCNet/blob/main/LMSCNet/models/LMSCNet.py#L7
11 | """
12 |
13 | def __init__(self, planes, dilations_conv_list):
14 | super().__init__()
15 |
16 | # ASPP Block
17 | self.conv_list = dilations_conv_list
18 | self.conv1 = nn.ModuleList(
19 | [
20 | nn.Conv3d(
21 | planes, planes, kernel_size=3, padding=dil, dilation=dil, bias=False
22 | )
23 | for dil in dilations_conv_list
24 | ]
25 | )
26 | self.bn1 = nn.ModuleList(
27 | [nn.BatchNorm3d(planes) for dil in dilations_conv_list]
28 | )
29 | self.conv2 = nn.ModuleList(
30 | [
31 | nn.Conv3d(
32 | planes, planes, kernel_size=3, padding=dil, dilation=dil, bias=False
33 | )
34 | for dil in dilations_conv_list
35 | ]
36 | )
37 | self.bn2 = nn.ModuleList(
38 | [nn.BatchNorm3d(planes) for dil in dilations_conv_list]
39 | )
40 | self.relu = nn.ReLU()
41 |
42 | def forward(self, x_in):
43 |
44 | y = self.bn2[0](self.conv2[0](self.relu(self.bn1[0](self.conv1[0](x_in)))))
45 | for i in range(1, len(self.conv_list)):
46 | y += self.bn2[i](self.conv2[i](self.relu(self.bn1[i](self.conv1[i](x_in)))))
47 | x_in = self.relu(y + x_in) # modified
48 |
49 | return x_in
50 |
51 |
52 | class SegmentationHead(nn.Module):
53 | """
54 | 3D Segmentation heads to retrieve semantic segmentation at each scale.
55 | Formed by Dim expansion, Conv3D, ASPP block, Conv3D.
56 | Taken from https://github.com/cv-rits/LMSCNet/blob/main/LMSCNet/models/LMSCNet.py#L7
57 | """
58 |
59 | def __init__(self, inplanes, planes, nbr_classes, dilations_conv_list):
60 | super().__init__()
61 |
62 | # First convolution
63 | self.conv0 = nn.Conv3d(inplanes, planes, kernel_size=3, padding=1, stride=1)
64 |
65 | # ASPP Block
66 | self.conv_list = dilations_conv_list
67 | self.conv1 = nn.ModuleList(
68 | [
69 | nn.Conv3d(
70 | planes, planes, kernel_size=3, padding=dil, dilation=dil, bias=False
71 | )
72 | for dil in dilations_conv_list
73 | ]
74 | )
75 | self.bn1 = nn.ModuleList(
76 | [nn.BatchNorm3d(planes) for dil in dilations_conv_list]
77 | )
78 | self.conv2 = nn.ModuleList(
79 | [
80 | nn.Conv3d(
81 | planes, planes, kernel_size=3, padding=dil, dilation=dil, bias=False
82 | )
83 | for dil in dilations_conv_list
84 | ]
85 | )
86 | self.bn2 = nn.ModuleList(
87 | [nn.BatchNorm3d(planes) for dil in dilations_conv_list]
88 | )
89 | self.relu = nn.ReLU()
90 |
91 | self.conv_classes = nn.Conv3d(
92 | planes, nbr_classes, kernel_size=3, padding=1, stride=1
93 | )
94 |
95 | def forward(self, x_in):
96 |
97 | # Convolution to go from inplanes to planes features...
98 | x_in = self.relu(self.conv0(x_in))
99 |
100 | y = self.bn2[0](self.conv2[0](self.relu(self.bn1[0](self.conv1[0](x_in)))))
101 | for i in range(1, len(self.conv_list)):
102 | y += self.bn2[i](self.conv2[i](self.relu(self.bn1[i](self.conv1[i](x_in)))))
103 | x_in = self.relu(y + x_in) # modified
104 |
105 | x_in = self.conv_classes(x_in)
106 |
107 | return x_in
108 |
109 |
110 | class Process(nn.Module):
111 | def __init__(self, feature, norm_layer, bn_momentum, dilations=[1, 2, 3]):
112 | super(Process, self).__init__()
113 | self.main = nn.Sequential(
114 | *[
115 | Bottleneck3D(
116 | feature,
117 | feature // 4,
118 | bn_momentum=bn_momentum,
119 | norm_layer=norm_layer,
120 | dilation=[i, i, i],
121 | )
122 | for i in dilations
123 | ]
124 | )
125 |
126 | def forward(self, x):
127 | return self.main(x)
128 |
129 |
130 | class Upsample(nn.Module):
131 | def __init__(self, in_channels, out_channels, norm_layer, bn_momentum):
132 | super(Upsample, self).__init__()
133 | self.main = nn.Sequential(
134 | nn.ConvTranspose3d(
135 | in_channels,
136 | out_channels,
137 | kernel_size=3,
138 | stride=2,
139 | padding=1,
140 | dilation=1,
141 | output_padding=1,
142 | ),
143 | norm_layer(out_channels, momentum=bn_momentum),
144 | nn.ReLU(),
145 | )
146 |
147 | def forward(self, x):
148 | return self.main(x)
149 |
150 |
151 | class Downsample(nn.Module):
152 | def __init__(self, feature, norm_layer, bn_momentum, expansion=8):
153 | super(Downsample, self).__init__()
154 | self.main = Bottleneck3D(
155 | feature,
156 | feature // 4,
157 | bn_momentum=bn_momentum,
158 | expansion=expansion,
159 | stride=2,
160 | downsample=nn.Sequential(
161 | nn.AvgPool3d(kernel_size=2, stride=2),
162 | nn.Conv3d(
163 | feature,
164 | int(feature * expansion / 4),
165 | kernel_size=1,
166 | stride=1,
167 | bias=False,
168 | ),
169 | norm_layer(int(feature * expansion / 4), momentum=bn_momentum),
170 | ),
171 | norm_layer=norm_layer,
172 | )
173 |
174 | def forward(self, x):
175 | return self.main(x)
176 |
177 |
178 | def sample_grid_feature(feature, fho=480, fwo=640, scale=4):
179 | """
180 | Args:
181 |         feature (torch.tensor): 2D feature to be sampled, shape (B, D, H, W) or (D, H, W)
182 | """
183 | if len(feature.shape) == 4:
184 | B, D, H, W = feature.shape
185 | else:
186 | D, H, W = feature.shape
187 | fH, fW = fho // scale, fwo // scale
188 | xs = torch.linspace(0, fwo - 1, fW, dtype=torch.float).view(1, 1, fW).expand(D, fH, fW) # (64, 120, 160)
189 | ys = torch.linspace(0, fho - 1, fH, dtype=torch.float).view(1, fH, 1).expand(D, fH, fW) # (64, 120, 160)
190 | d_xs = torch.floor(xs[0].reshape(-1)).to(torch.long) # (fH*fW,)
191 | d_ys = torch.floor(ys[0].reshape(-1)).to(torch.long) # (fH*fW,)
192 | # grid_pts = torch.stack([d_xs, d_ys], dim=1) # (fH*fW, 2)
193 |
194 | if len(feature.shape) == 4:
195 |         sample_feature = feature[:, :, d_ys, d_xs].reshape(B, D, fH, fW) # (B, D, fH, fW)
196 | else:
197 | sample_feature = feature[:, d_ys, d_xs].reshape(D, fH, fW)
198 | return sample_feature
199 |
200 |
201 | def get_depth_index(pix_z):
202 | """
203 | Args:
204 | pix_z (torch.tensor): The depth in camera frame after voxel projected to pixel, shape (N,), N is
205 | total voxel number.
206 | """
207 | ds = torch.arange(64).to(pix_z.device) # (64,)
208 | ds = 10 / 64 / 65 * ds * (ds + 1) # (64,)
209 | delta_z = torch.abs(pix_z[None, ...] - ds[..., None]) # (64, N)
210 | pix_z_index = torch.argmin(delta_z, dim=0) # (N,)
211 | return pix_z_index
212 |
213 |
214 | def bin_depths(depth_map, mode, depth_min, depth_max, num_bins, target=False):
215 | """
216 | Converts depth map into bin indices
217 | Args:
218 | depth_map [torch.Tensor(H, W)]: Depth Map
219 |         mode [string]: Discretization mode (See https://arxiv.org/pdf/2005.13423.pdf for more details)
220 |             UD: Uniform discretization
221 |             LID: Linear increasing discretization
222 |             SID: Spacing increasing discretization
223 | depth_min [float]: Minimum depth value
224 | depth_max [float]: Maximum depth value
225 | num_bins [int]: Number of depth bins
226 | target [bool]: Whether the depth bins indices will be used for a target tensor in loss comparison
227 | Returns:
228 | indices [torch.Tensor(H, W)]: Depth bin indices
229 | """
230 | if mode == "UD":
231 | bin_size = (depth_max - depth_min) / num_bins
232 | indices = (depth_map - depth_min) / bin_size
233 | elif mode == "LID":
234 | bin_size = 2 * (depth_max - depth_min) / (num_bins * (1 + num_bins))
235 | indices = -0.5 + 0.5 * torch.sqrt(1 + 8 * (depth_map - depth_min) / bin_size)
236 | elif mode == "SID":
237 | indices = (
238 | num_bins
239 | * (torch.log(1 + depth_map) - math.log(1 + depth_min))
240 | / (math.log(1 + depth_max) - math.log(1 + depth_min))
241 | )
242 | else:
243 | raise NotImplementedError
244 |
245 | if target:
246 |         # Remove indices outside of bounds (-2, -1, 0, 1, ..., num_bins, num_bins +1) --> (num_bins, num_bins, 0, 1, ..., num_bins, num_bins)
247 | mask = (indices < 0) | (indices > num_bins) | (~torch.isfinite(indices))
248 | indices[mask] = num_bins
249 |
250 | # Convert to integer
251 | indices = indices.type(torch.int64)
252 | return indices.long()
253 |
254 |
255 | def sample_3d_feature(feature_3d, pix_xy, pix_z, fov_mask):
256 | """
257 | Args:
258 | feature_3d (torch.tensor): 3D feature, shape (C, D, H, W).
259 |         pix_xy (torch.tensor): Projected pixel coordinates, shape (N, 2).
260 |         pix_z (torch.tensor): Projected depth bin indices, shape (N,).
261 |         fov_mask (torch.tensor): Boolean FOV mask, shape (N,); only in-FOV entries are sampled.
262 |     Returns:
263 |         torch.tensor: Sampled feature, shape (K, C), where K = fov_mask.sum()
264 | """
265 | pix_x, pix_y = pix_xy[:, 0][fov_mask], pix_xy[:, 1][fov_mask]
266 | pix_z = pix_z[fov_mask].to(pix_y.dtype)
267 | ret = feature_3d[:, pix_z, pix_y, pix_x].T
268 | return ret
269 |
270 |
--------------------------------------------------------------------------------
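bin_depths above is the discretization ISO uses for depth (LID mode with depth_min=0, depth_max=10, num_bins=64, matching disc_cfg in iso.py). A tiny worked example; with LID, nearby depths land in finer bins and out-of-range values are clamped to num_bins:

    import torch
    from iso.models.modules import bin_depths

    depth = torch.tensor([[0.05, 1.0, 5.0, 9.9, 12.0]])   # metres; 12.0 is outside [0, 10]
    idx = bin_depths(depth, mode="LID", depth_min=0, depth_max=10, num_bins=64, target=True)
    print(idx)   # should print tensor([[ 4, 19, 45, 63, 64]]): fine bins up close, 12.0 m clamped to num_bins
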
/iso/models/unet2d.py:
--------------------------------------------------------------------------------
1 | """
2 | Code adapted from https://github.com/astra-vision/MonoScene/blob/master/monoscene/models/unet2d.py
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import os
8 |
9 |
10 | class UpSampleBN(nn.Module):
11 | def __init__(self, skip_input, output_features):
12 | super(UpSampleBN, self).__init__()
13 | self._net = nn.Sequential(
14 | nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
15 | nn.BatchNorm2d(output_features),
16 | nn.LeakyReLU(),
17 | nn.Conv2d(
18 | output_features, output_features, kernel_size=3, stride=1, padding=1
19 | ),
20 | nn.BatchNorm2d(output_features),
21 | nn.LeakyReLU(),
22 | )
23 |
24 | def forward(self, x, concat_with):
25 | up_x = F.interpolate(
26 | x,
27 | size=(concat_with.shape[2], concat_with.shape[3]),
28 | mode="bilinear",
29 | align_corners=True,
30 | )
31 | f = torch.cat([up_x, concat_with], dim=1)
32 | return self._net(f)
33 |
34 |
35 | class DecoderBN(nn.Module):
36 | def __init__(
37 | self, num_features, bottleneck_features, out_feature, use_decoder=True
38 | ):
39 | super(DecoderBN, self).__init__()
40 | features = int(num_features)
41 | self.use_decoder = use_decoder
42 |
43 | self.conv2 = nn.Conv2d(
44 | bottleneck_features, features, kernel_size=1, stride=1, padding=0
45 | )
46 |
47 | self.out_feature_1_1 = out_feature # 200
48 | self.out_feature_1_2 = out_feature # 200
49 | self.out_feature_1_4 = out_feature # 200
50 | self.out_feature_1_8 = out_feature # 200
51 | self.out_feature_1_16 = out_feature # 200
52 | self.feature_1_16 = features // 2
53 | self.feature_1_8 = features // 4
54 | self.feature_1_4 = features // 8
55 | self.feature_1_2 = features // 16
56 | self.feature_1_1 = features // 32
57 |
58 | if self.use_decoder:
59 | self.resize_output_1_1 = nn.Conv2d(
60 | self.feature_1_1, self.out_feature_1_1, kernel_size=1
61 | )
62 | self.resize_output_1_2 = nn.Conv2d(
63 | self.feature_1_2, self.out_feature_1_2, kernel_size=1
64 | )
65 | self.resize_output_1_4 = nn.Conv2d(
66 | self.feature_1_4, self.out_feature_1_4, kernel_size=1
67 | )
68 | self.resize_output_1_8 = nn.Conv2d(
69 | self.feature_1_8, self.out_feature_1_8, kernel_size=1
70 | )
71 | self.resize_output_1_16 = nn.Conv2d(
72 | self.feature_1_16, self.out_feature_1_16, kernel_size=1
73 | )
74 |
75 | self.up16 = UpSampleBN(
76 | skip_input=features + 224, output_features=self.feature_1_16
77 | )
78 | self.up8 = UpSampleBN(
79 | skip_input=self.feature_1_16 + 80, output_features=self.feature_1_8
80 | )
81 | self.up4 = UpSampleBN(
82 | skip_input=self.feature_1_8 + 48, output_features=self.feature_1_4
83 | )
84 | self.up2 = UpSampleBN(
85 | skip_input=self.feature_1_4 + 32, output_features=self.feature_1_2
86 | )
87 | self.up1 = UpSampleBN(
88 | skip_input=self.feature_1_2 + 3, output_features=self.feature_1_1
89 | )
90 | else:
91 | self.resize_output_1_1 = nn.Conv2d(3, out_feature, kernel_size=1)
92 | self.resize_output_1_2 = nn.Conv2d(32, out_feature * 2, kernel_size=1)
93 | self.resize_output_1_4 = nn.Conv2d(48, out_feature * 4, kernel_size=1)
94 |
95 | def forward(self, features):
96 | x_block0, x_block1, x_block2, x_block3, x_block4 = (
97 | features[4],
98 | features[5],
99 | features[6],
100 | features[8],
101 | features[11], # (B, 2560, 15, 20)
102 | )
103 | bs = x_block0.shape[0]
104 | x_d0 = self.conv2(x_block4) # (B, 2560, 15, 20)
105 |
106 | if self.use_decoder:
107 | x_1_16 = self.up16(x_d0, x_block3) # (B, 1280, 30, 40)
108 |
109 | x_1_8 = self.up8(x_1_16, x_block2) # (B, 640, 60, 80)
110 | x_1_4 = self.up4(x_1_8, x_block1)
111 | x_1_2 = self.up2(x_1_4, x_block0)
112 | x_1_1 = self.up1(x_1_2, features[0])
113 | return {
114 | "1_1": self.resize_output_1_1(x_1_1),
115 | "1_2": self.resize_output_1_2(x_1_2),
116 | "1_4": self.resize_output_1_4(x_1_4),
117 | "1_8": self.resize_output_1_8(x_1_8),
118 | "1_16": self.resize_output_1_16(x_1_16),
119 | }
120 | else:
121 | x_1_1 = features[0]
122 | x_1_2, x_1_4, x_1_8, x_1_16 = (
123 | features[4],
124 | features[5],
125 | features[6],
126 | features[8],
127 | )
128 | x_global = features[-1].reshape(bs, 2560, -1).mean(2)
129 | return {
130 | "1_1": self.resize_output_1_1(x_1_1),
131 | "1_2": self.resize_output_1_2(x_1_2),
132 | "1_4": self.resize_output_1_4(x_1_4),
133 | "global": x_global,
134 | }
135 |
136 |
137 | class Encoder(nn.Module):
138 | def __init__(self, backend, frozen_encoder=False):
139 | super(Encoder, self).__init__()
140 | self.original_model = backend
141 | self.frozen_encoder = frozen_encoder
142 |
143 | def forward(self, x):
144 | if self.frozen_encoder:
145 | self.eval()
146 | features = [x]
147 | for k, v in self.original_model._modules.items():
148 | if k == "blocks":
149 | for ki, vi in v._modules.items():
150 | features.append(vi(features[-1]))
151 | else:
152 | features.append(v(features[-1]))
153 | return features
154 |
155 |
156 | class UNet2D(nn.Module):
157 | def __init__(self, backend, num_features, out_feature, use_decoder=True, frozen_encoder=False):
158 | super(UNet2D, self).__init__()
159 | self.use_decoder = use_decoder
160 |
161 | self.encoder = Encoder(backend, frozen_encoder)
162 | self.decoder = DecoderBN(
163 | out_feature=out_feature,
164 | use_decoder=use_decoder,
165 | bottleneck_features=num_features,
166 | num_features=num_features,
167 | )
168 |
169 | def forward(self, x, **kwargs):
170 | encoded_feats = self.encoder(x)
171 | unet_out = self.decoder(encoded_feats, **kwargs)
172 | return unet_out
173 |
174 | def get_encoder_params(self): # lr/10 learning rate
175 | return self.encoder.parameters()
176 |
177 | def get_decoder_params(self): # lr learning rate
178 | return self.decoder.parameters()
179 |
180 | @classmethod
181 | def build(cls, **kwargs):
182 | basemodel_name = "tf_efficientnet_b7_ns"
183 | num_features = 2560
184 |
185 |         print("Loading base model ({})...".format(basemodel_name), end="")
186 | basemodel = torch.hub.load(
187 | "rwightman/gen-efficientnet-pytorch", basemodel_name, pretrained=True
188 | )
189 | print("Done.")
190 |
191 | # Remove last layer
192 | print("Removing last two layers (global_pool & classifier).")
193 | basemodel.global_pool = nn.Identity()
194 | basemodel.classifier = nn.Identity()
195 |
196 | # Building Encoder-Decoder model
197 | print("Building Encoder-Decoder model..", end="")
198 | m = cls(basemodel, num_features=num_features, **kwargs)
199 | print("Done.")
200 | return m
201 |
202 | if __name__ == '__main__':
203 | model = UNet2D.build(out_feature=256, use_decoder=True)
204 |
--------------------------------------------------------------------------------
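As a quick sanity check of the decoder above, a sketch of what UNet2D.build returns for an NYU-sized input (this pulls the tf_efficientnet_b7_ns weights from torch.hub, so it assumes network access or a cached download); the per-scale shapes follow the comments in DecoderBN.forward:

    import torch
    from iso.models.unet2d import UNet2D

    net = UNet2D.build(out_feature=200, use_decoder=True).eval()   # same settings ISO uses for NYU
    with torch.no_grad():
        feats = net(torch.randn(1, 3, 480, 640))
    for k, v in feats.items():
        print(k, tuple(v.shape))   # "1_1" -> (1, 200, 480, 640) down to "1_16" -> (1, 200, 30, 40)
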
/iso/models/unet3d_nyu.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | Code adapted from https://github.com/astra-vision/MonoScene/blob/master/monoscene/models/unet3d_nyu.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import numpy as np
9 | from iso.models.CRP3D import CPMegaVoxels
10 | from iso.models.modules import (
11 | Process,
12 | Upsample,
13 | Downsample,
14 | SegmentationHead,
15 | ASPP,
16 | # Downsample_SE,
17 | # Process_SE,
18 | )
19 |
20 |
21 | class UNet3D(nn.Module):
22 | def __init__(
23 | self,
24 | class_num,
25 | norm_layer,
26 | feature,
27 | full_scene_size,
28 | n_relations=4,
29 | project_res=[],
30 | context_prior=True,
31 | bn_momentum=0.1,
32 | ):
33 | super(UNet3D, self).__init__()
34 | self.business_layer = []
35 | self.project_res = project_res
36 |
37 | self.feature_1_4 = feature
38 | self.feature_1_8 = feature * 2
39 | self.feature_1_16 = feature * 4
40 |
41 | self.feature_1_16_dec = self.feature_1_16
42 | self.feature_1_8_dec = self.feature_1_8
43 | self.feature_1_4_dec = self.feature_1_4
44 |
45 | self.process_1_4 = nn.Sequential(
46 | Process(self.feature_1_4, norm_layer, bn_momentum, dilations=[1, 2, 3]),
47 | Downsample(self.feature_1_4, norm_layer, bn_momentum),
48 | )
49 | self.process_1_8 = nn.Sequential(
50 | Process(self.feature_1_8, norm_layer, bn_momentum, dilations=[1, 2, 3]),
51 | Downsample(self.feature_1_8, norm_layer, bn_momentum),
52 | )
53 | self.up_1_16_1_8 = Upsample(
54 | self.feature_1_16_dec, self.feature_1_8_dec, norm_layer, bn_momentum
55 | )
56 | self.up_1_8_1_4 = Upsample(
57 | self.feature_1_8_dec, self.feature_1_4_dec, norm_layer, bn_momentum
58 | )
59 | self.ssc_head_1_4 = SegmentationHead(
60 | self.feature_1_4_dec, self.feature_1_4_dec, class_num, [1, 2, 3]
61 | )
62 |
63 | self.context_prior = context_prior
64 | size_1_16 = tuple(np.ceil(i / 4).astype(int) for i in full_scene_size)
65 |
66 | if context_prior:
67 | self.CP_mega_voxels = CPMegaVoxels(
68 | self.feature_1_16,
69 | size_1_16,
70 | n_relations=n_relations,
71 | bn_momentum=bn_momentum,
72 | )
73 |
74 | #
75 | def forward(self, input_dict):
76 | res = {}
77 |
78 | x3d_1_4 = input_dict["x3d"]
79 | x3d_1_8 = self.process_1_4(x3d_1_4)
80 | x3d_1_16 = self.process_1_8(x3d_1_8)
81 |
82 | if self.context_prior:
83 | ret = self.CP_mega_voxels(x3d_1_16)
84 | x3d_1_16 = ret["x"]
85 | for k in ret.keys():
86 | res[k] = ret[k]
87 |
88 | x3d_up_1_8 = self.up_1_16_1_8(x3d_1_16) + x3d_1_8
89 | x3d_up_1_4 = self.up_1_8_1_4(x3d_up_1_8) + x3d_1_4
90 |
91 | ssc_logit_1_4 = self.ssc_head_1_4(x3d_up_1_4)
92 |
93 | res["ssc_logit"] = ssc_logit_1_4
94 |
95 | return res
96 |
--------------------------------------------------------------------------------
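A minimal shape walk-through of the 3D decoder above under the NYU settings used in train_iso.py (feature=200, 12 classes, full_scene_size=(60, 36, 60)); context_prior is switched off here purely to keep the sketch light, whereas ISO itself trains with the 3D CRP block enabled:

    import torch
    import torch.nn as nn
    from iso.models.unet3d_nyu import UNet3D

    net = UNet3D(class_num=12, norm_layer=nn.BatchNorm3d, feature=200,
                 full_scene_size=(60, 36, 60), context_prior=False).eval()
    with torch.no_grad():
        res = net({"x3d": torch.randn(1, 200, 60, 36, 60)})   # FLoSP output at 1:4 resolution
    print(res["ssc_logit"].shape)   # expected torch.Size([1, 12, 60, 36, 60])
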
/iso/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | python iso/scripts/eval_iso.py n_gpus=1 batch_size=1
--------------------------------------------------------------------------------
/iso/scripts/eval_iso.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning import Trainer
2 | from iso.models.iso import ISO
3 | from iso.data.NYU.nyu_dm import NYUDataModule
4 | import hydra
5 | from omegaconf import DictConfig
6 | import torch
7 | import os
8 | from hydra.utils import get_original_cwd
9 | from pytorch_lightning import seed_everything
10 |
11 | torch.set_float32_matmul_precision('high')
12 |
13 |
14 | @hydra.main(config_name="../config/iso.yaml")
15 | def main(config: DictConfig):
16 | torch.set_grad_enabled(False)
17 |     if config.dataset == "kitti":  # NOTE: KittiDataModule used below is not imported in this script
18 | config.batch_size = 1
19 | n_classes = 20
20 | feature = 64
21 | project_scale = 2
22 | full_scene_size = (256, 256, 32)
23 | data_module = KittiDataModule(
24 | root=config.kitti_root,
25 | preprocess_root=config.kitti_preprocess_root,
26 | frustum_size=config.frustum_size,
27 | batch_size=int(config.batch_size / config.n_gpus),
28 | num_workers=int(config.num_workers_per_gpu * config.n_gpus),
29 | )
30 |
31 | elif config.dataset == "NYU":
32 | config.batch_size = 2
33 | project_scale = 1
34 | n_classes = 12
35 | feature = 200
36 | full_scene_size = (60, 36, 60)
37 | data_module = NYUDataModule(
38 | root=config.NYU_root,
39 | preprocess_root=config.NYU_preprocess_root,
40 | n_relations=config.n_relations,
41 | frustum_size=config.frustum_size,
42 | batch_size=int(config.batch_size / config.n_gpus),
43 | num_workers=int(config.num_workers_per_gpu * config.n_gpus),
44 | )
45 |
46 | trainer = Trainer(
47 | sync_batchnorm=True, deterministic=False, devices=config.n_gpus, accelerator="gpu",
48 | )
49 |
50 | if config.dataset == "NYU":
51 | model_path = os.path.join(
52 | get_original_cwd(), "trained_models", "iso_nyu.ckpt"
53 | )
54 | else:
55 | model_path = os.path.join(
56 | get_original_cwd(), "trained_models", "iso_kitti.ckpt"
57 | )
58 |
59 | voxeldepth_res = []
60 | if config.voxeldepth:
61 | if config.voxeldepthcfg.depth_scale_1:
62 | voxeldepth_res.append('1')
63 | if config.voxeldepthcfg.depth_scale_2:
64 | voxeldepth_res.append('2')
65 | if config.voxeldepthcfg.depth_scale_4:
66 | voxeldepth_res.append('4')
67 | if config.voxeldepthcfg.depth_scale_8:
68 | voxeldepth_res.append('8')
69 |
70 | os.chdir(hydra.utils.get_original_cwd())
71 |
72 | model = ISO.load_from_checkpoint(
73 | model_path,
74 | feature=feature,
75 | project_scale=project_scale,
76 | fp_loss=config.fp_loss,
77 | full_scene_size=full_scene_size,
78 | voxeldepth=config.voxeldepth,
79 | voxeldepth_res=voxeldepth_res,
80 | #
81 | use_gt_depth=config.use_gt_depth,
82 | add_fusion=config.add_fusion,
83 | use_zoedepth=config.use_zoedepth,
84 | use_depthanything=config.use_depthanything,
85 | zoedepth_as_gt=config.zoedepth_as_gt,
86 | depthanything_as_gt=config.depthanything_as_gt,
87 | frozen_encoder=config.frozen_encoder,
88 | )
89 | model.eval()
90 | data_module.setup()
91 | val_dataloader = data_module.val_dataloader()
92 | trainer.test(model, dataloaders=val_dataloader)
93 |
94 |
95 | if __name__ == "__main__":
96 | main()
97 |
--------------------------------------------------------------------------------
/iso/scripts/generate_output.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning import Trainer
2 | from iso.models.iso import ISO
3 | from iso.data.NYU.nyu_dm import NYUDataModule
4 | # from iso.data.semantic_kitti.kitti_dm import KittiDataModule
5 | # from iso.data.kitti_360.kitti_360_dm import Kitti360DataModule
6 | import hydra
7 | from omegaconf import DictConfig
8 | import torch
9 | import numpy as np
10 | import os
11 | from hydra.utils import get_original_cwd
12 | from tqdm import tqdm
13 | import pickle
14 |
15 |
16 | @hydra.main(config_name="../config/iso.yaml")
17 | def main(config: DictConfig):
18 | torch.set_grad_enabled(False)
19 |
20 | # Setup dataloader
21 | if config.dataset == "kitti" or config.dataset == "kitti_360":
22 | feature = 64
23 | project_scale = 2
24 | full_scene_size = (256, 256, 32)
25 |
26 | if config.dataset == "kitti":
27 | data_module = KittiDataModule(
28 | root=config.kitti_root,
29 | preprocess_root=config.kitti_preprocess_root,
30 | frustum_size=config.frustum_size,
31 | batch_size=int(config.batch_size / config.n_gpus),
32 | num_workers=int(config.num_workers_per_gpu * config.n_gpus),
33 | )
34 | data_module.setup()
35 | data_loader = data_module.val_dataloader()
36 | # data_loader = data_module.test_dataloader() # use this if you want to infer on test set
37 | else:
38 | data_module = Kitti360DataModule(
39 | root=config.kitti_360_root,
40 | sequences=[config.kitti_360_sequence],
41 | n_scans=2000,
42 | batch_size=1,
43 | num_workers=3,
44 | )
45 | data_module.setup()
46 | data_loader = data_module.dataloader()
47 |
48 | elif config.dataset == "NYU":
49 | project_scale = 1
50 | feature = 200
51 | full_scene_size = (60, 36, 60)
52 | data_module = NYUDataModule(
53 | root=config.NYU_root,
54 | preprocess_root=config.NYU_preprocess_root,
55 | n_relations=config.n_relations,
56 | frustum_size=config.frustum_size,
57 | batch_size=int(config.batch_size / config.n_gpus),
58 | num_workers=int(config.num_workers_per_gpu * config.n_gpus),
59 | )
60 | data_module.setup()
61 | data_loader = data_module.val_dataloader()
62 | # data_loader = data_module.test_dataloader() # use this if you want to infer on test set
63 | else:
64 |         print("dataset not supported")
65 |
66 | # Load pretrained models
67 | if config.dataset == "NYU":
68 | model_path = os.path.join(
69 | get_original_cwd(), "trained_models", "iso_nyu.ckpt"
70 | )
71 | else:
72 | model_path = os.path.join(
73 | get_original_cwd(), "trained_models", "iso_kitti.ckpt"
74 | )
75 |
76 | voxeldepth_res = []
77 | if config.voxeldepth:
78 | if config.voxeldepthcfg.depth_scale_1:
79 | voxeldepth_res.append('1')
80 | if config.voxeldepthcfg.depth_scale_2:
81 | voxeldepth_res.append('2')
82 | if config.voxeldepthcfg.depth_scale_4:
83 | voxeldepth_res.append('4')
84 | if config.voxeldepthcfg.depth_scale_8:
85 | voxeldepth_res.append('8')
86 |
87 | os.chdir(hydra.utils.get_original_cwd())
88 |
89 | model = ISO.load_from_checkpoint(
90 | model_path,
91 | feature=feature,
92 | project_scale=project_scale,
93 | fp_loss=config.fp_loss,
94 | full_scene_size=full_scene_size,
95 | voxeldepth=config.voxeldepth,
96 | voxeldepth_res=voxeldepth_res,
97 | #
98 | use_gt_depth=config.use_gt_depth,
99 | add_fusion=config.add_fusion,
100 | use_zoedepth=config.use_zoedepth,
101 | use_depthanything=config.use_depthanything,
102 | zoedepth_as_gt=config.zoedepth_as_gt,
103 | depthanything_as_gt=config.depthanything_as_gt,
104 | frozen_encoder=config.frozen_encoder,
105 | )
106 | model.cuda()
107 | model.eval()
108 |
109 | # Save prediction and additional data
110 | # to draw the viewing frustum and remove scene outside the room for NYUv2
111 | output_path = os.path.join(config.output_path, config.dataset)
112 | with torch.no_grad():
113 | for batch in tqdm(data_loader):
114 | batch["img"] = batch["img"].cuda()
115 | pred = model(batch)
116 | y_pred = torch.softmax(pred["ssc_logit"], dim=1).detach().cpu().numpy()
117 | y_pred = np.argmax(y_pred, axis=1)
118 | for i in range(config.batch_size):
119 | out_dict = {"y_pred": y_pred[i].astype(np.uint16)}
120 | if "target" in batch:
121 | out_dict["target"] = (
122 | batch["target"][i].detach().cpu().numpy().astype(np.uint16)
123 | )
124 |
125 | if config.dataset == "NYU":
126 | write_path = output_path
127 | filepath = os.path.join(write_path, batch["name"][i] + ".pkl")
128 | out_dict["cam_pose"] = batch["cam_pose"][i].detach().cpu().numpy()
129 | out_dict["vox_origin"] = (
130 | batch["vox_origin"][i].detach().cpu().numpy()
131 | )
132 | else:
133 | write_path = os.path.join(output_path, batch["sequence"][i])
134 | filepath = os.path.join(write_path, batch["frame_id"][i] + ".pkl")
135 | out_dict["fov_mask_1"] = (
136 | batch["fov_mask_1"][i].detach().cpu().numpy()
137 | )
138 | out_dict["cam_k"] = batch["cam_k"][i].detach().cpu().numpy()
139 | out_dict["T_velo_2_cam"] = (
140 | batch["T_velo_2_cam"][i].detach().cpu().numpy()
141 | )
142 |
143 | os.makedirs(write_path, exist_ok=True)
144 | with open(filepath, "wb") as handle:
145 | pickle.dump(out_dict, handle)
146 | print("wrote to", filepath)
147 |
148 |
149 | if __name__ == "__main__":
150 | main()
151 |
--------------------------------------------------------------------------------
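The pickles written by generate_output.py above can be consumed like this for NYU; the path below is a placeholder for wherever config.output_path points, and "target" is only present when the dataloader provides it:

    import pickle

    with open("iso_output/NYU/some_scene.pkl", "rb") as handle:   # hypothetical output file
        out = pickle.load(handle)

    print(sorted(out.keys()))                          # ['cam_pose', 'target', 'vox_origin', 'y_pred']
    print(out["y_pred"].shape, out["y_pred"].dtype)    # (60, 36, 60) voxel class ids stored as uint16
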
/iso/scripts/train.sh:
--------------------------------------------------------------------------------
1 | python iso/scripts/train_iso.py n_gpus=2 batch_size=4
--------------------------------------------------------------------------------
/iso/scripts/train_iso.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('/home/hongxiao.yu/projects/ISO_occscannet/depth_anything')
3 | from iso.data.NYU.params import (
4 | class_weights as NYU_class_weights,
5 | NYU_class_names,
6 | )
7 | from iso.data.OccScanNet.params import (
8 | class_weights as OccScanNet_class_weights,
9 | OccScanNet_class_names
10 | )
11 | from iso.data.NYU.nyu_dm import NYUDataModule
12 | from iso.data.OccScanNet.occscannet_dm import OccScanNetDataModule
13 | from torch.utils.data.dataloader import DataLoader
14 | from iso.models.iso import ISO
15 | from pytorch_lightning import Trainer
16 | import pytorch_lightning as pl
17 | from pytorch_lightning.loggers import TensorBoardLogger
18 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
19 | import os
20 | import hydra
21 | from omegaconf import DictConfig
22 | import numpy as np
23 | import torch
24 |
25 | hydra.output_subdir = None
26 |
27 | pl.seed_everything(658018589) #, workers=True)
28 |
29 | @hydra.main(config_name="../config/iso_occscannet_mini.yaml", config_path='.')
30 | def main(config: DictConfig):
31 | exp_name = config.exp_prefix
32 | exp_name += "_{}_{}".format(config.dataset, config.run)
33 | exp_name += "_FrusSize_{}".format(config.frustum_size)
34 | exp_name += "_nRelations{}".format(config.n_relations)
35 | exp_name += "_WD{}_lr{}".format(config.weight_decay, config.lr)
36 |
37 | if config.use_gt_depth:
38 | exp_name += '_gtdepth'
39 | if config.add_fusion:
40 | exp_name += '_addfusion'
41 | if not config.use_zoedepth:
42 | exp_name += '_nozoedepth'
43 | if config.zoedepth_as_gt:
44 | exp_name += '_zoedepthgt'
45 | if config.frozen_encoder:
46 | exp_name += '_frozen_encoder'
47 | if config.use_depthanything:
48 | exp_name += '_depthanything'
49 |
50 | voxeldepth_res = []
51 | if config.voxeldepth:
52 | exp_name += '_VoxelDepth'
53 | if config.voxeldepthcfg.depth_scale_1:
54 | exp_name += '_1'
55 | voxeldepth_res.append('1')
56 | if config.voxeldepthcfg.depth_scale_2:
57 | exp_name += '_2'
58 | voxeldepth_res.append('2')
59 | if config.voxeldepthcfg.depth_scale_4:
60 | exp_name += '_4'
61 | voxeldepth_res.append('4')
62 | if config.voxeldepthcfg.depth_scale_8:
63 | exp_name += '_8'
64 | voxeldepth_res.append('8')
65 |
66 | if config.CE_ssc_loss:
67 | exp_name += "_CEssc"
68 | if config.geo_scal_loss:
69 | exp_name += "_geoScalLoss"
70 | if config.sem_scal_loss:
71 | exp_name += "_semScalLoss"
72 | if config.fp_loss:
73 | exp_name += "_fpLoss"
74 |
75 | if config.relation_loss:
76 | exp_name += "_CERel"
77 | if config.context_prior:
78 | exp_name += "_3DCRP"
79 |
80 | # Setup dataloaders
81 | if config.dataset == "OccScanNet":
82 | class_names = OccScanNet_class_names
83 | max_epochs = 60
84 | logdir = config.logdir
85 | full_scene_size = (60, 60, 36)
86 | project_scale = 1
87 | feature = 200
88 | n_classes = 12
89 | class_weights = OccScanNet_class_weights
90 | data_module = OccScanNetDataModule(
91 | root=config.OccScanNet_root,
92 | n_relations=config.n_relations,
93 | frustum_size=config.frustum_size,
94 | batch_size=int(config.batch_size / config.n_gpus),
95 | num_workers=int(config.num_workers_per_gpu * config.n_gpus),
96 | )
97 | elif config.dataset == "OccScanNet_mini":
98 | class_names = OccScanNet_class_names
99 | max_epochs = 60
100 | logdir = config.logdir
101 | full_scene_size = (60, 60, 36)
102 | project_scale = 1
103 | feature = 200
104 | n_classes = 12
105 | class_weights = OccScanNet_class_weights
106 | data_module = OccScanNetDataModule(
107 | root=config.OccScanNet_root,
108 | n_relations=config.n_relations,
109 | frustum_size=config.frustum_size,
110 | batch_size=int(config.batch_size / config.n_gpus),
111 | num_workers=int(config.num_workers_per_gpu * config.n_gpus),
112 | train_scenes_sample=4639,
113 | val_scenes_sample=2007,
114 | )
115 |
116 | elif config.dataset == "NYU":
117 | class_names = NYU_class_names
118 | max_epochs = 30
119 | logdir = config.logdir
120 | full_scene_size = (60, 36, 60)
121 | project_scale = 1
122 | feature = 200
123 | n_classes = 12
124 | class_weights = NYU_class_weights
125 | data_module = NYUDataModule(
126 | root=config.NYU_root,
127 | preprocess_root=config.NYU_preprocess_root,
128 | n_relations=config.n_relations,
129 | frustum_size=config.frustum_size,
130 | batch_size=int(config.batch_size / config.n_gpus),
131 | num_workers=int(config.num_workers_per_gpu * config.n_gpus),
132 | )
133 |
134 | project_res = ["1"]
135 | if config.project_1_2:
136 | exp_name += "_Proj_2"
137 | project_res.append("2")
138 | if config.project_1_4:
139 | exp_name += "_4"
140 | project_res.append("4")
141 | if config.project_1_8:
142 | exp_name += "_8"
143 | project_res.append("8")
144 |
145 | print(exp_name)
146 |
147 | os.chdir(hydra.utils.get_original_cwd())
148 |
149 | # Initialize ISO model
150 | model = ISO(
151 | dataset=config.dataset,
152 | frustum_size=config.frustum_size,
153 | project_scale=project_scale,
154 | n_relations=config.n_relations,
155 | fp_loss=config.fp_loss,
156 | feature=feature,
157 | full_scene_size=full_scene_size,
158 | project_res=project_res,
159 | voxeldepth=config.voxeldepth,
160 | voxeldepth_res=voxeldepth_res,
161 | n_classes=n_classes,
162 | class_names=class_names,
163 | context_prior=config.context_prior,
164 | relation_loss=config.relation_loss,
165 | CE_ssc_loss=config.CE_ssc_loss,
166 | sem_scal_loss=config.sem_scal_loss,
167 | geo_scal_loss=config.geo_scal_loss,
168 | lr=config.lr,
169 | weight_decay=config.weight_decay,
170 | class_weights=class_weights,
171 | use_gt_depth=config.use_gt_depth,
172 | add_fusion=config.add_fusion,
173 | use_zoedepth=config.use_zoedepth,
174 | use_depthanything=config.use_depthanything,
175 | zoedepth_as_gt=config.zoedepth_as_gt,
176 | depthanything_as_gt=config.depthanything_as_gt,
177 | frozen_encoder=config.frozen_encoder,
178 | )
179 |
180 | if config.enable_log:
181 | logger = TensorBoardLogger(save_dir=logdir, name=exp_name, version="")
182 | lr_monitor = LearningRateMonitor(logging_interval="step")
183 | checkpoint_callbacks = [
184 | ModelCheckpoint(
185 | save_last=True,
186 | monitor="val/mIoU",
187 | save_top_k=1,
188 | mode="max",
189 | filename="{epoch:03d}-{val/mIoU:.5f}",
190 | ),
191 | lr_monitor,
192 | ]
193 | else:
194 | logger = False
195 |         checkpoint_callbacks = None  # Trainer expects a list of callbacks or None
196 |
197 | model_path = os.path.join(logdir, exp_name, "checkpoints/last.ckpt")
198 | if os.path.isfile(model_path):
199 | # Continue training from last.ckpt
200 | trainer = Trainer(
201 | callbacks=checkpoint_callbacks,
202 |             # resume_from_checkpoint was removed from Trainer in pytorch-lightning 2.0; the checkpoint is passed to trainer.fit() via ckpt_path below
203 | sync_batchnorm=True,
204 | deterministic=False,
205 | max_epochs=max_epochs,
206 | devices=config.n_gpus,
207 | accelerator="gpu",
208 | logger=logger,
209 | check_val_every_n_epoch=1,
210 | log_every_n_steps=10,
211 | # flush_logs_every_n_steps=100,
212 | # strategy="ddp_find_unused_parameters_true",
213 | )
214 | else:
215 | # Train from scratch
216 | trainer = Trainer(
217 | callbacks=checkpoint_callbacks,
218 | sync_batchnorm=True,
219 | deterministic=False,
220 | max_epochs=max_epochs,
221 | devices=config.n_gpus,
222 | accelerator='gpu',
223 | logger=logger,
224 | check_val_every_n_epoch=1,
225 | log_every_n_steps=10,
226 | # flush_logs_every_n_steps=100,
227 | strategy="ddp_find_unused_parameters_true",
228 | )
229 | torch.set_float32_matmul_precision('high')
230 | # os.chdir("/home/hongxiao.yu/projects/ISO_occscannet")
231 |     trainer.fit(model, data_module, ckpt_path=model_path if os.path.isfile(model_path) else None)
232 |
233 |
234 | if __name__ == "__main__":
235 | main()
236 |
--------------------------------------------------------------------------------
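Note on the resume logic in train_iso.py: with the pytorch-lightning version pinned in requirements.txt (2.0.0), the old `Trainer(resume_from_checkpoint=...)` argument no longer exists, and the checkpoint path is instead given to `Trainer.fit()`. Below is a minimal, self-contained sketch of that pattern; the `<logdir>/<exp_name>/checkpoints/last.ckpt` layout mirrors the script, while the model/datamodule arguments and the single-device Trainer settings are placeholders, not the script's exact configuration.

    import os
    from pytorch_lightning import Trainer

    def fit_with_optional_resume(model, data_module, logdir, exp_name, max_epochs=30):
        """Resume from <logdir>/<exp_name>/checkpoints/last.ckpt when it exists."""
        ckpt = os.path.join(logdir, exp_name, "checkpoints", "last.ckpt")
        trainer = Trainer(max_epochs=max_epochs, accelerator="auto", devices=1)
        # ckpt_path=None starts from scratch; a valid path restores the weights,
        # optimizer state, and epoch counter before training continues.
        trainer.fit(model, data_module, ckpt_path=ckpt if os.path.isfile(ckpt) else None)
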
/iso/scripts/visualization/NYU_vis_pred.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import os
3 | from omegaconf import DictConfig
4 | import numpy as np
5 | import hydra
6 | from mayavi import mlab
7 |
8 |
9 | def get_grid_coords(dims, resolution):
10 | """
11 |     :param dims: the dimensions of the grid [x, y, z] (e.g. [256, 256, 32])
12 |     :return coords_grid: the center coordinates of the voxels in the grid
13 | """
14 |
15 | g_xx = np.arange(0, dims[0] + 1)
16 | g_yy = np.arange(0, dims[1] + 1)
17 |
18 | g_zz = np.arange(0, dims[2] + 1)
19 |
20 | # Obtaining the grid with coords...
21 | xx, yy, zz = np.meshgrid(g_xx[:-1], g_yy[:-1], g_zz[:-1])
22 | coords_grid = np.array([xx.flatten(), yy.flatten(), zz.flatten()]).T
23 |     coords_grid = coords_grid.astype(float)  # np.float is deprecated/removed in recent NumPy releases
24 |
25 | coords_grid = (coords_grid * resolution) + resolution / 2
26 |
27 | temp = np.copy(coords_grid)
28 | temp[:, 0] = coords_grid[:, 1]
29 | temp[:, 1] = coords_grid[:, 0]
30 | coords_grid = np.copy(temp)
31 |
32 | return coords_grid
33 |
34 |
35 | def draw(
36 | voxels,
37 | cam_pose,
38 | vox_origin,
39 | voxel_size=0.08,
40 |     d=0.75,  # 0.75 m - determines the size of the mesh representing the camera
41 | ):
42 | # Compute the coordinates of the mesh representing camera
43 | y = d * 480 / (2 * 518.8579)
44 | x = d * 640 / (2 * 518.8579)
45 | tri_points = np.array(
46 | [
47 | [0, 0, 0],
48 | [x, y, d],
49 | [-x, y, d],
50 | [-x, -y, d],
51 | [x, -y, d],
52 | ]
53 | )
54 | tri_points = np.hstack([tri_points, np.ones((5, 1))])
55 |
56 | tri_points = (cam_pose @ tri_points.T).T
57 | x = tri_points[:, 0] - vox_origin[0]
58 | y = tri_points[:, 1] - vox_origin[1]
59 | z = tri_points[:, 2] - vox_origin[2]
60 | triangles = [
61 | (0, 1, 2),
62 | (0, 1, 4),
63 | (0, 3, 4),
64 | (0, 2, 3),
65 | ]
66 |
67 | # Compute the voxels coordinates
68 | grid_coords = get_grid_coords(
69 | [voxels.shape[0], voxels.shape[2], voxels.shape[1]], voxel_size
70 | )
71 |
72 | # Attach the predicted class to every voxel
73 | grid_coords = np.vstack(
74 | (grid_coords.T, np.moveaxis(voxels, [0, 1, 2], [0, 2, 1]).reshape(-1))
75 | ).T
76 |
77 | # Remove empty and unknown voxels
78 | occupied_voxels = grid_coords[(grid_coords[:, 3] > 0) & (grid_coords[:, 3] < 255)]
79 | figure = mlab.figure(size=(1600, 900), bgcolor=(1, 1, 1))
80 |
81 | # Draw the camera
82 | mlab.triangular_mesh(
83 | x,
84 | y,
85 | z,
86 | triangles,
87 | representation="wireframe",
88 | color=(0, 0, 0),
89 | line_width=5,
90 | )
91 |
92 | # Draw occupied voxels
93 | plt_plot = mlab.points3d(
94 | occupied_voxels[:, 0],
95 | occupied_voxels[:, 1],
96 | occupied_voxels[:, 2],
97 | occupied_voxels[:, 3],
98 | colormap="viridis",
99 | scale_factor=voxel_size - 0.1 * voxel_size,
100 | mode="cube",
101 | opacity=1.0,
102 | vmin=0,
103 | vmax=12,
104 | )
105 |
106 | colors = np.array(
107 | [
108 | [22, 191, 206, 255],
109 | [214, 38, 40, 255],
110 | [43, 160, 43, 255],
111 | [158, 216, 229, 255],
112 | [114, 158, 206, 255],
113 | [204, 204, 91, 255],
114 | [255, 186, 119, 255],
115 | [147, 102, 188, 255],
116 | [30, 119, 181, 255],
117 | [188, 188, 33, 255],
118 | [255, 127, 12, 255],
119 | [196, 175, 214, 255],
120 | [153, 153, 153, 255],
121 | ]
122 | )
123 |
124 | plt_plot.glyph.scale_mode = "scale_by_vector"
125 |
126 | plt_plot.module_manager.scalar_lut_manager.lut.table = colors
127 |
128 | mlab.show()
129 |
130 |
131 | @hydra.main(config_path=None)
132 | def main(config: DictConfig):
133 | scan = config.file
134 |
135 | with open(scan, "rb") as handle:
136 | b = pickle.load(handle)
137 |
138 | cam_pose = b["cam_pose"]
139 | vox_origin = b["vox_origin"]
140 | gt_scene = b["target"]
141 | pred_scene = b["y_pred"]
142 | scan = os.path.basename(scan)[:12]
143 |
144 | pred_scene[(gt_scene == 255)] = 255 # only draw scene inside the room
145 |
146 | draw(
147 | pred_scene,
148 | cam_pose,
149 | vox_origin,
150 | voxel_size=0.08,
151 | d=0.75,
152 | )
153 |
154 |
155 | if __name__ == "__main__":
156 | main()
157 |
--------------------------------------------------------------------------------
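For quick testing of the viewer above, the pickle it loads only needs the four keys read in `main()`. The sketch below writes a dummy scene in that layout; the 60x36x60 grid size and 0.08 m voxel size follow the NYU settings used elsewhere in the repo, and the output filename is arbitrary. The `+file=` Hydra override is the usual way to pass the path since the script declares no config.

    import pickle
    import numpy as np

    # Dummy NYU-style prediction file for NYU_vis_pred.py (keys match main()).
    dummy = {
        "cam_pose": np.eye(4),                                     # 4x4 camera-to-world pose
        "vox_origin": np.zeros(3),                                 # metric origin of the voxel grid
        "target": np.zeros((60, 36, 60), dtype=np.uint8),          # ground-truth labels (255 = outside room)
        "y_pred": np.random.randint(0, 12, (60, 36, 60), dtype=np.uint8),  # predicted labels 0..11
    }
    with open("demo_scene.pkl", "wb") as f:
        pickle.dump(dummy, f)

    # Example invocation (the path is read from the Hydra config):
    #   python iso/scripts/visualization/NYU_vis_pred.py +file=demo_scene.pkl
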
/iso/scripts/visualization/NYU_vis_pred_2.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import os
3 | from omegaconf import DictConfig
4 | import numpy as np
5 | import hydra
6 | # from mayavi import mlab
7 | import open3d as o3d
8 |
9 |
10 | def get_grid_coords(dims, resolution):
11 | """
12 |     :param dims: the dimensions of the grid [x, y, z] (e.g. [256, 256, 32])
13 |     :return coords_grid: the center coordinates of the voxels in the grid
14 | """
15 |
16 |     # The sensor is centered in X (we go to dims/2 + 1 for the histogramdd)
17 | g_xx = np.arange(0, dims[0] + 1)
18 | # The sensor is in Y=0 (we go to dims + 1 for the histogramdd)
19 | g_yy = np.arange(0, dims[1] + 1)
20 |     # The sensor is in Z=1.73. I observed that the ground was two voxel levels above the grid bottom, so the Z pose is at 10
21 | # if bottom voxel is 0. If we want the sensor to be at (0, 0, 0), then the bottom in z is -10, top is 22
22 | # (we go to 22 + 1 for the histogramdd)
23 | # ATTENTION.. Is 11 for old grids.. 10 for new grids (v1.1) (https://github.com/PRBonn/semantic-kitti-api/issues/49)
24 | g_zz = np.arange(0, dims[2] + 1)
25 |
26 | # Obtaining the grid with coords...
27 | xx, yy, zz = np.meshgrid(g_xx[:-1], g_yy[:-1], g_zz[:-1])
28 | coords_grid = np.array([xx.flatten(), yy.flatten(), zz.flatten()]).T
29 |     coords_grid = coords_grid.astype(float)  # np.float is deprecated/removed in recent NumPy releases
30 |
31 | coords_grid = (coords_grid * resolution) + resolution / 2
32 |
33 | temp = np.copy(coords_grid)
34 | temp[:, 0] = coords_grid[:, 1]
35 | temp[:, 1] = coords_grid[:, 0]
36 | coords_grid = np.copy(temp)
37 |
38 | return coords_grid
39 |
40 |
41 |
42 | # def draw_semantic_open3d(
43 | # voxels,
44 | # cam_param_path="",
45 | # voxel_size=0.2):
46 |
47 |
48 |
49 | # grid_coords, _, _, _ = get_grid_coords([voxels.shape[0], voxels.shape[1], voxels.shape[2]], voxel_size)
50 |
51 | # points = np.vstack([grid_coords.T, voxels.reshape(-1)]).T
52 |
53 | # # Obtaining voxels with semantic class
54 | # points = points[(points[:, 3] != 0) & (points[:, 3] != 255)] # remove empty voxel and unknown class
55 |
56 | # colors = np.take_along_axis(colors, points[:, 3].astype(np.int32).reshape(-1, 1), axis=0)
57 |
58 | # vis = o3d.visualization.Visualizer()
59 | # vis.create_window(width=1200, height=600)
60 | # ctr = vis.get_view_control()
61 | # param = o3d.io.read_pinhole_camera_parameters(cam_param_path)
62 |
63 | # pcd = o3d.geometry.PointCloud()
64 | # pcd.points = o3d.utility.Vector3dVector(points[:, :3])
65 | # pcd.colors = o3d.utility.Vector3dVector(colors[:, :3])
66 | # pcd.estimate_normals()
67 | # vis.add_geometry(pcd)
68 |
69 | # ctr.convert_from_pinhole_camera_parameters(param)
70 |
71 | # vis.run() # user changes the view and press "q" to terminate
72 | # param = vis.get_view_control().convert_to_pinhole_camera_parameters()
73 | # o3d.io.write_pinhole_camera_parameters(cam_param_path, param)
74 |
75 |
76 |
77 |
78 |
79 | def draw(
80 | voxels,
81 | cam_pose,
82 | vox_origin,
83 | voxel_size=0.08,
84 |     d=0.75,  # 0.75 m - determines the size of the mesh representing the camera
85 | ):
86 | # Compute the coordinates of the mesh representing camera
87 | y = d * 480 / (2 * 518.8579)
88 | x = d * 640 / (2 * 518.8579)
89 | tri_points = np.array(
90 | [
91 | [0, 0, 0],
92 | [x, y, d],
93 | [-x, y, d],
94 | [-x, -y, d],
95 | [x, -y, d],
96 | ]
97 | )
98 | tri_points = np.hstack([tri_points, np.ones((5, 1))])
99 |
100 | tri_points = (cam_pose @ tri_points.T).T
101 | x = tri_points[:, 0] - vox_origin[0]
102 | y = tri_points[:, 1] - vox_origin[1]
103 | z = tri_points[:, 2] - vox_origin[2]
104 | triangles = [
105 | (0, 1, 2),
106 | (0, 1, 4),
107 | (0, 3, 4),
108 | (0, 2, 3),
109 | ]
110 |
111 |
112 | # Compute the voxels coordinates
113 | grid_coords = get_grid_coords(
114 | [voxels.shape[0], voxels.shape[2], voxels.shape[1]], voxel_size
115 | )
116 |
117 | # Attach the predicted class to every voxel
118 | grid_coords = np.vstack(
119 | (grid_coords.T, np.moveaxis(voxels, [0, 1, 2], [0, 2, 1]).reshape(-1))
120 | ).T
121 |
122 | # Remove empty and unknown voxels
123 | occupied_voxels = grid_coords[(grid_coords[:, 3] > 0) & (grid_coords[:, 3] < 255)]
124 |
125 | # # Draw the camera
126 | # mlab.triangular_mesh(
127 | # x,
128 | # y,
129 | # z,
130 | # triangles,
131 | # representation="wireframe",
132 | # color=(0, 0, 0),
133 | # line_width=5,
134 | # )
135 |
136 | # # Draw occupied voxels
137 | # plt_plot = mlab.points3d(
138 | # occupied_voxels[:, 0],
139 | # occupied_voxels[:, 1],
140 | # occupied_voxels[:, 2],
141 | # occupied_voxels[:, 3],
142 | # colormap="viridis",
143 | # scale_factor=voxel_size - 0.1 * voxel_size,
144 | # mode="cube",
145 | # opacity=1.0,
146 | # vmin=0,
147 | # vmax=12,
148 | # )
149 |
150 | colors = np.array(
151 | [
152 | [22, 191, 206, 255],
153 | [214, 38, 40, 255],
154 | [43, 160, 43, 255],
155 | [158, 216, 229, 255],
156 | [114, 158, 206, 255],
157 | [204, 204, 91, 255],
158 | [255, 186, 119, 255],
159 | [147, 102, 188, 255],
160 | [30, 119, 181, 255],
161 | [188, 188, 33, 255],
162 | [255, 127, 12, 255],
163 | [196, 175, 214, 255],
164 | [153, 153, 153, 255],
165 | ]
166 | ) / 255.0
167 |
168 | colors = np.take_along_axis(colors, occupied_voxels[:, 3].astype(np.int32).reshape(-1, 1), axis=0)
169 |
170 | vis = o3d.visualization.Visualizer()
171 | vis.create_window(width=1200, height=600)
172 | ctr = vis.get_view_control()
173 |     # param = o3d.io.read_pinhole_camera_parameters(cam_pose)  # cam_pose is a 4x4 extrinsic matrix here, not a saved viewpoint file
174 |
175 | pcd = o3d.geometry.PointCloud()
176 | pcd.points = o3d.utility.Vector3dVector(occupied_voxels[:, :3])
177 | pcd.colors = o3d.utility.Vector3dVector(colors[:, :3])
178 | pcd.estimate_normals()
179 | vis.add_geometry(pcd)
180 |
181 |     # ctr.convert_from_pinhole_camera_parameters(cam_pose)  # requires a PinholeCameraParameters object, see the note above
182 |
183 | vis.run() # user changes the view and press "q" to terminate
184 | param = vis.get_view_control().convert_to_pinhole_camera_parameters()
185 | # o3d.io.write_pinhole_camera_parameters(cam_param_path, param)
186 |
187 |
188 | @hydra.main(config_path=None)
189 | def main(config: DictConfig):
190 | scan = config.file
191 |
192 | with open(scan, "rb") as handle:
193 | b = pickle.load(handle)
194 |
195 | cam_pose = b["cam_pose"]
196 | vox_origin = b["vox_origin"]
197 | gt_scene = b["target"]
198 | pred_scene = b["y_pred"]
199 | scan = os.path.basename(scan)[:12]
200 |
201 | pred_scene[(gt_scene == 255)] = 255 # only draw scene inside the room
202 |
203 | draw(
204 | pred_scene,
205 | cam_pose,
206 | vox_origin,
207 | voxel_size=0.08,
208 | d=0.75,
209 | )
210 |
211 |
212 | if __name__ == "__main__":
213 | main()
214 |
--------------------------------------------------------------------------------
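The Open3D variant above can only restore a stored viewpoint from a pinhole-camera JSON file, which is what the commented-out `draw_semantic_open3d` expected. A small sketch of how such a file could be produced once and then reused via `read_pinhole_camera_parameters` / `convert_from_pinhole_camera_parameters`; the random point cloud and the `viewpoint.json` name are placeholders.

    import numpy as np
    import open3d as o3d

    # Open a viewer on some geometry, let the user frame the view, then save it.
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np.random.rand(100, 3))

    vis = o3d.visualization.Visualizer()
    vis.create_window(width=1200, height=600)
    vis.add_geometry(pcd)
    vis.run()  # adjust the camera, then press "q"

    # Persist the current viewpoint so later runs can restore it.
    param = vis.get_view_control().convert_to_pinhole_camera_parameters()
    o3d.io.write_pinhole_camera_parameters("viewpoint.json", param)
    vis.destroy_window()
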
/iso/scripts/visualization/OccScanNet_vis_pred.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mayavi import mlab
3 | import argparse
4 |
5 |
6 | def load_voxels(path):
7 | """Load voxel labels from file.
8 |
9 | Args:
10 | path (str): The path of the voxel labels file.
11 |
12 | Returns:
13 | ndarray: The voxel labels with shape (N, 4), 4 is for [x, y, z, label].
14 | """
15 | labels = np.load(path)
16 | if labels.shape[1] == 7:
17 | labels = labels[:, [0, 1, 2, 6]]
18 |
19 | return labels
20 |
21 |
22 | def draw(voxel_label, voxel_size=0.05, intrinsic=None, cam_pose=None, d=0.5):
23 | """Visualize the gt or predicted voxel labels.
24 |
25 | Args:
26 |         voxel_label (ndarray): The gt or predicted voxel labels, with shape (N, 4), where N is the number
27 |             of voxels and 4 is for [x, y, z, label].
28 |         voxel_size (float): The size of each voxel.
29 |         intrinsic (ndarray): The camera intrinsics.
30 |         cam_pose (ndarray): The camera pose.
31 |         d (float): The depth of the camera model visualization.
32 | """
33 | figure = mlab.figure(size=(1600*0.8, 900*0.8), bgcolor=(1, 1, 1))
34 |
35 | # voxel_origin = np.array([-0.6619388,
36 | # -2.3863946,
37 | # -0.05 ])
38 | # voxel_1 = voxel_origin + np.array([4.8, 0, 0])
39 | # voxel_2 = voxel_origin + np.array([0, 4.8, 0])
40 | # voxel_3 = voxel_origin + np.array([0, 0, 4.8])
41 | # voxel_4 = voxel_origin + np.array([4.8, 4.8, 0])
42 | # voxel_5 = voxel_origin + np.array([4.8, 0, 4.8])
43 | # voxel_6 = voxel_origin + np.array([0, 4.8, 4.8])
44 | # voxel_7 = voxel_origin + np.array([4.8, 4.8, 4.8])
45 | # voxels = np.vstack([voxel_origin, voxel_1, voxel_2, voxel_3, voxel_4, voxel_5, voxel_6, voxel_7])
46 | # print(voxels.shape)
47 | # x = voxels[:, 0]
48 | # y = voxels[:, 1]
49 | # z = voxels[:, 2]
50 | # sqs = [
51 | # (0, 1, 2),
52 | # (0, 1, 3),
53 | # (0, 2, 3),
54 | # (1, 2, 4),
55 | # (1, 3, 5),
56 | # (2, 3, 6),
57 | # (3, 5, 6),
58 | # (5, 6, 7),
59 | # ]
60 |
61 | # # draw cam model
62 | # mlab.triangular_mesh(
63 | # x,
64 | # y,
65 | # z,
66 | # sqs,
67 | # representation="wireframe",
68 | # color=(1, 0, 0),
69 | # line_width=7.5,
70 | # )
71 |
72 |
73 | if intrinsic is not None and cam_pose is not None:
74 |         assert d > 0, 'camera model depth d should be > 0'
75 | fx = intrinsic[0, 0]
76 | fy = intrinsic[1, 1]
77 | cx = intrinsic[0, 2]
78 | cy = intrinsic[1, 2]
79 |
80 | # half of the image plane size
81 | y = d * 2 * cy / (2 * fy)
82 | x = d * 2 * cx / (2 * fx)
83 |
84 | # camera points (cam frame)
85 | tri_points = np.array(
86 | [
87 | [0, 0, 0],
88 | [x, y, d],
89 | [-x, y, d],
90 | [-x, -y, d],
91 | [x, -y, d],
92 | ]
93 | )
94 | tri_points = np.hstack([tri_points, np.ones((5, 1))])
95 |
96 | # camera points (world frame)
97 | tri_points = (cam_pose @ tri_points.T).T
98 | x = tri_points[:, 0]
99 | y = tri_points[:, 1]
100 | z = tri_points[:, 2]
101 | triangles = [
102 | (0, 1, 2),
103 | (0, 1, 4),
104 | (0, 3, 4),
105 | (0, 2, 3),
106 | ]
107 |
108 | # draw cam model
109 | mlab.triangular_mesh(
110 | x,
111 | y,
112 | z,
113 | triangles,
114 | representation="wireframe",
115 | color=(0, 0, 0),
116 | line_width=7.5,
117 | )
118 |
119 | # draw occupied voxels
120 | plt_plot = mlab.points3d(
121 | voxel_label[:, 0],
122 | voxel_label[:, 1],
123 | voxel_label[:, 2],
124 | voxel_label[:, 3],
125 | colormap="viridis",
126 | scale_factor=voxel_size - 0.1 * voxel_size,
127 | mode="cube",
128 | opacity=1.0,
129 | vmin=0,
130 | vmax=12,
131 | )
132 |
133 | # label colors
134 | colors = np.array(
135 | [
136 | [0, 0, 0, 255], # 0 empty
137 | [255, 202, 251, 255], # 1 ceiling
138 | [208, 192, 122, 255], # 2 floor
139 | [199, 210, 255, 255], # 3 wall
140 | [82, 42, 127, 255], # 4 window
141 | [224, 250, 30, 255], # 5 chair
142 | [255, 0, 65, 255], # 6 bed
143 | [144, 177, 144, 255], # 7 sofa
144 | [246, 110, 31, 255], # 8 table
145 | [0, 216, 0, 255], # 9 tv
146 | [135, 177, 214, 255], # 10 furniture
147 | [1, 92, 121, 255], # 11 objects
148 | [128, 128, 128, 255], # 12 occupied with semantic
149 | ]
150 | )
151 |
152 | plt_plot.glyph.scale_mode = "scale_by_vector"
153 |
154 | plt_plot.module_manager.scalar_lut_manager.lut.table = colors
155 |
156 | mlab.show()
157 |
158 |
159 | def parse_args():
160 | parser = argparse.ArgumentParser(description="CompleteScanNet dataset visualization.")
161 | parser.add_argument("--file", type=str, help="Voxel label file path.", required=True)
162 | args = parser.parse_args()
163 | return args
164 |
165 |
166 | if __name__ == "__main__":
167 | args = parse_args()
168 | voxels = load_voxels(args.file)
169 | draw(voxels, voxel_size=0.05, d=0.5)
--------------------------------------------------------------------------------
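For trying out the OccScanNet viewer above, `load_voxels()` only needs an .npy array with columns [x, y, z, label]. The sketch below builds such a file under the assumption that the first three columns are metric voxel-center coordinates spaced by `voxel_size`; the grid extent, label range, and filename are arbitrary placeholders.

    import numpy as np

    # Hypothetical (N, 4) label file for OccScanNet_vis_pred.py: [x, y, z, label].
    voxel_size = 0.05
    ii, jj, kk = np.meshgrid(np.arange(20), np.arange(20), np.arange(8), indexing="ij")
    xyz = np.stack([ii, jj, kk], axis=-1).reshape(-1, 3) * voxel_size + voxel_size / 2
    labels = np.random.randint(1, 12, size=(xyz.shape[0], 1))  # semantic ids 1..11
    np.save("demo_occ_voxels.npy", np.hstack([xyz, labels]))

    # Example invocation:
    #   python iso/scripts/visualization/OccScanNet_vis_pred.py --file demo_occ_voxels.npy
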
/iso/scripts/visualization/kitti_vis_pred.py:
--------------------------------------------------------------------------------
1 | # from operator import gt
2 | import pickle
3 | import numpy as np
4 | from omegaconf import DictConfig
5 | import hydra
6 | from mayavi import mlab
7 |
8 |
9 | def get_grid_coords(dims, resolution):
10 | """
11 |     :param dims: the dimensions of the grid [x, y, z] (e.g. [256, 256, 32])
12 |     :return coords_grid: the center coordinates of the voxels in the grid
13 | """
14 |
15 | g_xx = np.arange(0, dims[0] + 1)
16 | g_yy = np.arange(0, dims[1] + 1)
17 | sensor_pose = 10
18 | g_zz = np.arange(0, dims[2] + 1)
19 |
20 | # Obtaining the grid with coords...
21 | xx, yy, zz = np.meshgrid(g_xx[:-1], g_yy[:-1], g_zz[:-1])
22 | coords_grid = np.array([xx.flatten(), yy.flatten(), zz.flatten()]).T
23 |     coords_grid = coords_grid.astype(float)  # np.float is deprecated/removed in recent NumPy releases
24 |
25 | coords_grid = (coords_grid * resolution) + resolution / 2
26 |
27 | temp = np.copy(coords_grid)
28 | temp[:, 0] = coords_grid[:, 1]
29 | temp[:, 1] = coords_grid[:, 0]
30 | coords_grid = np.copy(temp)
31 |
32 | return coords_grid
33 |
34 |
35 | def draw(
36 | voxels,
37 | T_velo_2_cam,
38 | vox_origin,
39 | fov_mask,
40 | img_size,
41 | f,
42 | voxel_size=0.2,
43 |     d=7,  # 7 m - determines the size of the mesh representing the camera
44 | ):
45 | # Compute the coordinates of the mesh representing camera
46 | x = d * img_size[0] / (2 * f)
47 | y = d * img_size[1] / (2 * f)
48 | tri_points = np.array(
49 | [
50 | [0, 0, 0],
51 | [x, y, d],
52 | [-x, y, d],
53 | [-x, -y, d],
54 | [x, -y, d],
55 | ]
56 | )
57 | tri_points = np.hstack([tri_points, np.ones((5, 1))])
58 | tri_points = (np.linalg.inv(T_velo_2_cam) @ tri_points.T).T
59 | x = tri_points[:, 0] - vox_origin[0]
60 | y = tri_points[:, 1] - vox_origin[1]
61 | z = tri_points[:, 2] - vox_origin[2]
62 | triangles = [
63 | (0, 1, 2),
64 | (0, 1, 4),
65 | (0, 3, 4),
66 | (0, 2, 3),
67 | ]
68 |
69 | # Compute the voxels coordinates
70 | grid_coords = get_grid_coords(
71 | [voxels.shape[0], voxels.shape[1], voxels.shape[2]], voxel_size
72 | )
73 |
74 | # Attach the predicted class to every voxel
75 | grid_coords = np.vstack([grid_coords.T, voxels.reshape(-1)]).T
76 |
77 | # Get the voxels inside FOV
78 | fov_grid_coords = grid_coords[fov_mask, :]
79 |
80 | # Get the voxels outside FOV
81 | outfov_grid_coords = grid_coords[~fov_mask, :]
82 |
83 | # Remove empty and unknown voxels
84 | fov_voxels = fov_grid_coords[
85 | (fov_grid_coords[:, 3] > 0) & (fov_grid_coords[:, 3] < 255)
86 | ]
87 | outfov_voxels = outfov_grid_coords[
88 | (outfov_grid_coords[:, 3] > 0) & (outfov_grid_coords[:, 3] < 255)
89 | ]
90 |
91 | figure = mlab.figure(size=(1400, 1400), bgcolor=(1, 1, 1))
92 |
93 | # Draw the camera
94 | mlab.triangular_mesh(
95 | x, y, z, triangles, representation="wireframe", color=(0, 0, 0), line_width=5
96 | )
97 |
98 | # Draw occupied inside FOV voxels
99 | plt_plot_fov = mlab.points3d(
100 | fov_voxels[:, 0],
101 | fov_voxels[:, 1],
102 | fov_voxels[:, 2],
103 | fov_voxels[:, 3],
104 | colormap="viridis",
105 | scale_factor=voxel_size - 0.05 * voxel_size,
106 | mode="cube",
107 | opacity=1.0,
108 | vmin=1,
109 | vmax=19,
110 | )
111 |
112 | # Draw occupied outside FOV voxels
113 | plt_plot_outfov = mlab.points3d(
114 | outfov_voxels[:, 0],
115 | outfov_voxels[:, 1],
116 | outfov_voxels[:, 2],
117 | outfov_voxels[:, 3],
118 | colormap="viridis",
119 | scale_factor=voxel_size - 0.05 * voxel_size,
120 | mode="cube",
121 | opacity=1.0,
122 | vmin=1,
123 | vmax=19,
124 | )
125 |
126 | colors = np.array(
127 | [
128 | [100, 150, 245, 255],
129 | [100, 230, 245, 255],
130 | [30, 60, 150, 255],
131 | [80, 30, 180, 255],
132 | [100, 80, 250, 255],
133 | [255, 30, 30, 255],
134 | [255, 40, 200, 255],
135 | [150, 30, 90, 255],
136 | [255, 0, 255, 255],
137 | [255, 150, 255, 255],
138 | [75, 0, 75, 255],
139 | [175, 0, 75, 255],
140 | [255, 200, 0, 255],
141 | [255, 120, 50, 255],
142 | [0, 175, 0, 255],
143 | [135, 60, 0, 255],
144 | [150, 240, 80, 255],
145 | [255, 240, 150, 255],
146 | [255, 0, 0, 255],
147 | ]
148 | ).astype(np.uint8)
149 |
150 | plt_plot_fov.glyph.scale_mode = "scale_by_vector"
151 | plt_plot_outfov.glyph.scale_mode = "scale_by_vector"
152 |
153 | plt_plot_fov.module_manager.scalar_lut_manager.lut.table = colors
154 |
155 |     outfov_colors = colors.copy()  # copy so the dimming below does not also alter the in-FOV LUT
156 | outfov_colors[:, :3] = outfov_colors[:, :3] // 3 * 2
157 | plt_plot_outfov.module_manager.scalar_lut_manager.lut.table = outfov_colors
158 |
159 | mlab.show()
160 |
161 |
162 | @hydra.main(config_path=None)
163 | def main(config: DictConfig):
164 | scan = config.file
165 | with open(scan, "rb") as handle:
166 | b = pickle.load(handle)
167 |
168 | fov_mask_1 = b["fov_mask_1"]
169 | T_velo_2_cam = b["T_velo_2_cam"]
170 | vox_origin = np.array([0, -25.6, -2])
171 |
172 | y_pred = b["y_pred"]
173 |
174 | if config.dataset == "kitti_360":
175 | # Visualize KITTI-360
176 | draw(
177 | y_pred,
178 | T_velo_2_cam,
179 | vox_origin,
180 | fov_mask_1,
181 | voxel_size=0.2,
182 | f=552.55426,
183 | img_size=(1408, 376),
184 | d=7,
185 | )
186 | else:
187 | # Visualize Semantic KITTI
188 | draw(
189 | y_pred,
190 | T_velo_2_cam,
191 | vox_origin,
192 | fov_mask_1,
193 | img_size=(1220, 370),
194 | f=707.0912,
195 | voxel_size=0.2,
196 | d=7,
197 | )
198 |
199 |
200 | if __name__ == "__main__":
201 | main()
202 |
--------------------------------------------------------------------------------
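The camera wireframe drawn in these viewers is a pyramid whose base is the image plane pushed out to depth d, so its half-width is x = d * W / (2 * f) and its half-height is y = d * H / (2 * f). For the Semantic KITTI settings passed above (f = 707.0912, a 1220x370 image, d = 7 m) that comes to roughly x = 6.04 m and y = 1.83 m; a two-line check:

    # Half-extent of the camera image plane at depth d (same formula as in draw()).
    f, W, H, d = 707.0912, 1220, 370, 7
    x, y = d * W / (2 * f), d * H / (2 * f)
    print(f"half-width x = {x:.2f} m, half-height y = {y:.2f} m")  # ~6.04 m, ~1.83 m
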
/requirements.txt:
--------------------------------------------------------------------------------
1 | scikit-image==0.18.1
2 | PyYAML==5.4
3 | tqdm==4.57.0
4 | scikit-learn==0.24.0
5 | pytorch-lightning==2.0.0
6 | opencv-python==4.5.1.48
7 | hydra-core==1.0.5
8 | numpy==1.20.3
9 | numba==0.53.0
10 | imageio==2.34.1
11 | protobuf==3.19.6
12 | tensorboard
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from setuptools import find_packages
3 |
4 | # for install, do: pip install -ve .
5 |
6 | setup(name='iso', packages=find_packages())
7 |
--------------------------------------------------------------------------------
/trained_models/.placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hongxiaoy/ISO/12d30d244479d52fe64bdc6402bf3dfb4d43503b/trained_models/.placeholder
--------------------------------------------------------------------------------