├── .gitignore ├── .gitmodules ├── LICENSE ├── NYUv2.gif ├── README.md ├── checkpoints └── .placeholder ├── iso ├── config │ ├── iso.yaml │ ├── iso_occscannet.yaml │ └── iso_occscannet_mini.yaml ├── data │ ├── NYU │ │ ├── collate.py │ │ ├── nyu_dataset.py │ │ ├── nyu_dm.py │ │ ├── params.py │ │ └── preprocess.py │ ├── OccScanNet │ │ ├── collate.py │ │ ├── occscannet_dataset.py │ │ ├── occscannet_dm.py │ │ └── params.py │ └── utils │ │ ├── fusion.py │ │ ├── helpers.py │ │ └── torch_util.py ├── loss │ ├── CRP_loss.py │ ├── depth_loss.py │ ├── sscMetrics.py │ └── ssc_loss.py ├── models │ ├── CRP3D.py │ ├── DDR.py │ ├── depthnet.py │ ├── flosp.py │ ├── iso.py │ ├── modules.py │ ├── unet2d.py │ └── unet3d_nyu.py └── scripts │ ├── eval.sh │ ├── eval_iso.py │ ├── generate_output.py │ ├── train.sh │ ├── train_iso.py │ └── visualization │ ├── NYU_vis_pred.py │ ├── NYU_vis_pred_2.py │ ├── OccScanNet_vis_pred.py │ └── kitti_vis_pred.py ├── requirements.txt ├── setup.py └── trained_models └── .placeholder /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__/ 3 | *.egg-info/ 4 | outputs/ 5 | trained_models/*.ckpt 6 | checkpoints/*.pt 7 | iso_output/* 8 | core* 9 | images 10 | .vscode -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "depth_anything"] 2 | path = depth_anything 3 | url = https://github.com/LiheYoung/Depth-Anything.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /NYUv2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hongxiaoy/ISO/12d30d244479d52fe64bdc6402bf3dfb4d43503b/NYUv2.gif -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # Monocular Occupancy Prediction for Scalable Indoor Scenes
3 | 4 | [**Hongxiao Yu**](https://orcid.org/0009-0003-9249-2726)1,2 · [**Yuqi Wang**](https://orcid.org/0000-0002-6360-1431)1,2 · [**Yuntao Chen**](https://orcid.org/0000-0002-9555-1897)3 · [**Zhaoxiang Zhang**](https://orcid.org/0000-0003-2648-3875)1,2,3 5 | 6 | 1School of Artificial Intelligence, University of Chinese Academy of Sciences (UCAS) 7 | 8 | 2NLPR, MAIS, Institute of Automation, Chinese Academy of Sciences (CASIA) 9 | 10 | 3Centre for Artificial Intelligence and Robotics (HKISI_CAS) 11 | 12 | **ECCV 2024** 13 | 14 | [![Static Badge](https://img.shields.io/badge/arXiv-2407.11730-red)](https://arxiv.org/abs/2407.11730) [![Static Badge](https://img.shields.io/badge/Project%20Page-ISO-blue)](https://hongxiaoy.github.io/ISO) 15 | [![Static Badge](https://img.shields.io/badge/Demo-Hugging%20Face-yellow)](https://huggingface.co/spaces/hongxiaoy/ISO) 16 | 17 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/monocular-occupancy-prediction-for-scalable/3d-semantic-scene-completion-from-a-single)](https://paperswithcode.com/sota/3d-semantic-scene-completion-from-a-single?p=monocular-occupancy-prediction-for-scalable) 18 | 19 | 20 | 21 | 22 |
23 |
24 | # Performance
25 |
26 | Here we compare our ISO with the previous best models, NDC-Scene and MonoScene.
27 |
28 | | Method | IoU | ceiling | floor | wall | window | chair | bed | sofa | table | tvs | furniture | object | mIoU |
29 | |:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
30 | | MonoScene | 42.51 | 8.89 | 93.50 | 12.06 | 12.57 | 13.72 | 48.19 | 36.11 | 15.13 | 15.22 | 27.96 | 12.94 | 26.94 |
31 | | NDC-Scene | 44.17 | 12.02 | **93.51** | 13.11 | 13.77 | 15.83 | 49.57 | 39.87 | 17.17 | 24.57 | 31.00 | 14.96 | 29.03 |
32 | | Ours | **47.11** | **14.21** | 93.47 | **15.89** | **15.14** | **18.35** | **50.01** | **40.82** | **18.25** | **25.90** | **34.08** | **17.67** | **31.25** |
33 |
34 | We highlight the **best** results in **bold**.
35 |
36 | Pretrained models on NYUv2 can be downloaded [here](https://huggingface.co/hongxiaoy/ISO/tree/main).
37 |
38 | # Preparing ISO
39 |
40 | ## Installation
41 |
42 | 1. Create a conda environment:
43 |
44 | ```
45 | $ conda create -n iso python=3.9 -y
46 | $ conda activate iso
47 | ```
48 | 2. This code was implemented with Python 3.9, PyTorch 2.0.0, and CUDA 11.7 (the command below installs the newer PyTorch 2.2.0 / CUDA 11.8 build). Please install [PyTorch](https://pytorch.org/):
49 |
50 | ```
51 | $ conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=11.8 -c pytorch -c nvidia
52 | ```
53 |
54 | 3. Install the additional dependencies:
55 |
56 | ```
57 | $ git clone --recursive https://github.com/hongxiaoy/ISO.git
58 | $ cd ISO/
59 | $ pip install -r requirements.txt
60 | ```
61 |
62 | > :bulb:Note
63 | >
64 | > Change L140 in ```depth_anything/metric_depth/zoedepth/models/base_models/dpt_dinov2/dpt.py``` to
65 | >
66 | > ```self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder), pretrained=False)```
67 | >
68 | > Then, download the Depth-Anything pre-trained [model](https://github.com/LiheYoung/Depth-Anything/tree/main#no-network-connection-cannot-load-these-models) and metric depth [model](https://github.com/LiheYoung/Depth-Anything/tree/main/metric_depth#evaluation) checkpoint files to ```checkpoints/```.
69 |
70 | 4. Install tbb:
71 |
72 | ```
73 | $ conda install -c bioconda tbb=2020.2
74 | ```
75 |
76 | 5. Finally, install ISO:
77 |
78 | ```
79 | $ pip install -e ./
80 | ```
81 |
82 | > :bulb:Note
83 | >
84 | > If you move the ISO directory to another location, run
85 | >
86 | > ```pip cache purge```
87 | >
88 | > and then run ```pip install -e ./``` again.
89 |
90 | ## Datasets
91 |
92 | ### NYUv2
93 |
94 | 1. Download the [NYUv2 dataset](https://www.rocq.inria.fr/rits_files/computer-vision/monoscene/nyu.zip).
95 |
96 | 2. Create a folder to store the preprocessed NYUv2 data at `/path/to/NYU/preprocess/folder`.
97 |
98 | 3. Store paths in environment variables for faster access:
99 |
100 | ```
101 | $ export NYU_PREPROCESS=/path/to/NYU/preprocess/folder
102 | $ export NYU_ROOT=/path/to/NYU/depthbin
103 | ```
104 |
105 | > :bulb:Note
106 | >
107 | > We recommend running, e.g.,
108 | >
109 | > ```echo "export NYU_PREPROCESS=/path/to/NYU/preprocess/folder" >> ~/.bashrc```
110 | >
111 | > so the variable persists across shell sessions.
112 |
113 | 4. Preprocess the data to generate labels at a lower scale, which are used to compute the ground truth relation matrices:
114 |
115 | ```
116 | $ cd ISO/
117 | $ python iso/data/NYU/preprocess.py NYU_root=$NYU_ROOT NYU_preprocess_root=$NYU_PREPROCESS
118 | ```
119 |
120 | ### Occ-ScanNet
121 |
122 | 1. Download the [Occ-ScanNet dataset](https://huggingface.co/datasets/hongxiaoy/OccScanNet), which includes:
123 | - `posed_images`
124 | - `gathered_data`
125 | - `train_subscenes.txt`
126 | - `val_subscenes.txt`
127 |
128 | 2. Create a root folder for the Occ-ScanNet dataset at `/path/to/Occ/ScanNet/folder`, move all dataset files into it, and extract the zip files.
129 |
130 | 3. Store paths in environment variables for faster access:
131 |
132 | ```
133 | $ export OCC_SCANNET_ROOT=/path/to/Occ/ScanNet/folder
134 | ```
135 |
136 | > :bulb:Note
137 | >
138 | > We recommend running, e.g.,
139 | >
140 | > ```echo "export OCC_SCANNET_ROOT=/path/to/Occ/ScanNet/folder" >> ~/.bashrc```
141 | >
142 | > so the variable persists across shell sessions.
143 |
144 | ## Pretrained Models
145 |
146 | Download the ISO pretrained models [on NYUv2](https://huggingface.co/hongxiaoy/ISO/tree/main), then put them in the folder `/path/to/ISO/trained_models`.
147 |
148 | ```bash
149 | huggingface-cli download --repo-type model hongxiaoy/ISO
150 | ```
151 |
152 | If you haven't installed `huggingface-cli` before, please follow the [official instructions](https://huggingface.co/docs/hub/en/models-adding-libraries#installation). Note that the command above downloads to the Hugging Face cache by default; you can pass `--local-dir trained_models/` to place the files in that folder directly.
153 |
154 | # Running ISO
155 |
156 | ## Training
157 |
158 | ### NYUv2
159 |
160 | 1. Create a folder to store training logs at **/path/to/NYU/logdir**.
161 |
162 | 2. Store it in an environment variable:
163 |
164 | ```
165 | $ export NYU_LOG=/path/to/NYU/logdir
166 | ```
167 |
168 | 3. Train ISO on NYUv2 using 2 GPUs with a batch_size of 4 (2 items per GPU):
169 | ```
170 | $ cd ISO/
171 | $ python iso/scripts/train_iso.py \
172 |     dataset=NYU \
173 |     NYU_root=$NYU_ROOT \
174 |     NYU_preprocess_root=$NYU_PREPROCESS \
175 |     logdir=$NYU_LOG \
176 |     n_gpus=2 batch_size=4
177 | ```
178 |
179 | ### Occ-ScanNet
180 |
181 | 1. Create a folder to store training logs at **/path/to/OccScanNet/logdir**.
182 |
183 | 2. Store it in an environment variable:
184 |
185 | ```
186 | $ export OCC_SCANNET_LOG=/path/to/OccScanNet/logdir
187 | ```
188 |
189 | 3. Train ISO on Occ-ScanNet using 2 GPUs with a batch_size of 4 (2 items per GPU); the dataset name should match the config file name used in train_iso.py:
190 | ```
191 | $ cd ISO/
192 | $ python iso/scripts/train_iso.py \
193 |     dataset=OccScanNet \
194 |     OccScanNet_root=$OCC_SCANNET_ROOT \
195 |     logdir=$OCC_SCANNET_LOG \
196 |     n_gpus=2 batch_size=4
197 | ```
198 |
199 | ## Evaluating
200 |
201 | ### NYUv2
202 |
203 | To evaluate ISO on the NYUv2 test set, type:
204 |
205 | ```
206 | $ cd ISO/
207 | $ python iso/scripts/eval_iso.py \
208 |     dataset=NYU \
209 |     NYU_root=$NYU_ROOT \
210 |     NYU_preprocess_root=$NYU_PREPROCESS \
211 |     n_gpus=1 batch_size=1
212 | ```
213 |
214 | ## Inference
215 |
216 | Please create a folder **/path/to/iso/output** to store the ISO outputs and store its path in an environment variable:
217 |
218 | ```
219 | export ISO_OUTPUT=/path/to/iso/output
220 | ```
221 |
222 | ### NYUv2
223 |
224 | To generate the predictions on the NYUv2 test set, type:
225 |
226 | ```
227 | $ cd ISO/
228 | $ python iso/scripts/generate_output.py \
229 |     +output_path=$ISO_OUTPUT \
230 |     dataset=NYU \
231 |     NYU_root=$NYU_ROOT \
232 |     NYU_preprocess_root=$NYU_PREPROCESS \
233 |     n_gpus=1 batch_size=1
234 | ```
235 |
236 | ## Visualization
237 |
238 | You need to create a new Anaconda environment for visualization.
239 |
240 | ```bash
241 | conda create -n mayavi_vis python=3.7 -y
242 | conda activate mayavi_vis
243 | pip install omegaconf hydra-core PyQt5 mayavi
244 | ```
245 |
246 | If you run into problems when installing `mayavi`, please refer to the following instructions:
247 |
248 | - [Official mayavi installation instruction](https://docs.enthought.com/mayavi/installation.html)
249 |
250 |
251 | ### NYUv2
252 | ```
253 | $ cd ISO/
254 | $ python iso/scripts/visualization/NYU_vis_pred.py +file=/path/to/output/file.pkl
255 | ```
256 |
257 |
258 | # Acknowledgement
259 |
260 | This project is built on MonoScene. Please refer to https://github.com/astra-vision/MonoScene for more documentation and details.
261 |
262 | We would like to thank the creators, maintainers, and contributors of [MonoScene](https://github.com/astra-vision/MonoScene), [NDC-Scene](https://github.com/Jiawei-Yao0812/NDCScene), [ZoeDepth](https://github.com/isl-org/ZoeDepth), and [Depth Anything](https://github.com/LiheYoung/Depth-Anything) for their invaluable work. Their dedication and open-source spirit have been instrumental in our development.
263 |
264 | # Citation
265 |
266 | ```
267 | @article{yu2024monocular,
268 |   title={Monocular Occupancy Prediction for Scalable Indoor Scenes},
269 |   author={Yu, Hongxiao and Wang, Yuqi and Chen, Yuntao and Zhang, Zhaoxiang},
270 |   journal={arXiv preprint arXiv:2407.11730},
271 |   year={2024}
272 | }
273 | ```
274 |
--------------------------------------------------------------------------------
/checkpoints/.placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hongxiaoy/ISO/12d30d244479d52fe64bdc6402bf3dfb4d43503b/checkpoints/.placeholder
--------------------------------------------------------------------------------
/iso/config/iso.yaml:
--------------------------------------------------------------------------------
1 | dataset: "NYU" # "kitti", "kitti_360"
2 | #dataset: "kitti_360"
3 |
4 | n_relations: 4
5 |
6 | enable_log: true
7 | #kitti_root: '/path/to/semantic_kitti'
8 | #kitti_preprocess_root: '/path/to/kitti/preprocess/folder'
9 | #kitti_logdir: '/path/to/semantic_kitti/logdir'
10 |
11 | NYU_root: "/mnt/vdb1/hongxiao.yu/data/NYU_dataset/depthbin/"
12 | NYU_preprocess_root: "/mnt/vdb1/hongxiao.yu/data/NYU_dataset/preprocess/"
13 | logdir: "/mnt/vdb1/hongxiao.yu/logs/ISO_logs/"
14 |
15 |
16 | fp_loss: true
17 | frustum_size: 8
18 | batch_size: 1
19 | n_gpus: 1
20 | num_workers_per_gpu: 0
21 | exp_prefix: "iso"
22 | run: 1
23 | lr: 1e-4
24 | weight_decay: 1e-4
25 |
26 | context_prior: true
27 |
28 | relation_loss: true
29 | CE_ssc_loss: true
30 | sem_scal_loss: true
31 | geo_scal_loss: true
32 |
33 | project_1_2: true
34 | project_1_4: true
35 | project_1_8: true
36 |
37 | voxeldepth: true
38 | voxeldepthcfg:
39 |   depth_scale_1: true
40 |   depth_scale_2: false
41 |   depth_scale_4: false
42 |   depth_scale_8: false
43 |
44 | use_gt_depth: false
45 | use_zoedepth: false
46 | use_depthanything: true
47 | zoedepth_as_gt: false
48 | depthanything_as_gt: false
49 |
50 | add_fusion: true
51 |
52 | frozen_encoder: true
53 |
--------------------------------------------------------------------------------
/iso/config/iso_occscannet.yaml:
--------------------------------------------------------------------------------
1 | dataset: "OccScanNet" # "NYU", "OccScanNet"
2 |
3 | n_relations: 4
4 |
5 | enable_log: true
6 |
7 | logdir: "/mnt/vdb1/hongxiao.yu/logs/ISO_occscannet_logs/"
8 | OccScanNet_root:
"/mnt/vdb1/hongxiao.yu/data/NYU_dataset/depthbin/" 9 | 10 | 11 | fp_loss: true 12 | frustum_size: 8 13 | batch_size: 1 14 | n_gpus: 1 15 | num_workers_per_gpu: 0 16 | exp_prefix: "iso" 17 | run: 1 18 | lr: 1e-4 19 | weight_decay: 1e-4 20 | 21 | context_prior: true 22 | 23 | relation_loss: true 24 | CE_ssc_loss: true 25 | sem_scal_loss: true 26 | geo_scal_loss: true 27 | 28 | project_1_2: true 29 | project_1_4: true 30 | project_1_8: true 31 | 32 | voxeldepth: true 33 | voxeldepthcfg: 34 | depth_scale_1: true 35 | depth_scale_2: false 36 | depth_scale_4: false 37 | depth_scale_8: false 38 | 39 | use_gt_depth: false 40 | use_zoedepth: false 41 | use_depthanything: true 42 | zoedepth_as_gt: false 43 | depthanything_as_gt: false 44 | 45 | add_fusion: true 46 | 47 | frozen_encoder: true 48 | -------------------------------------------------------------------------------- /iso/config/iso_occscannet_mini.yaml: -------------------------------------------------------------------------------- 1 | dataset: "OccScanNet_mini" # "NYU", "OccScanNet", "OccScanNet_mini" 2 | 3 | n_relations: 4 4 | 5 | enable_log: true 6 | 7 | logdir: "/mnt/vdb1/hongxiao.yu/logs/ISO_occscannet_logs/" 8 | OccScanNet_root: "/home/hongxiao.yu/projects/ISO_occscannet_2/occ_data_root" 9 | 10 | 11 | fp_loss: true 12 | frustum_size: 8 13 | batch_size: 1 14 | n_gpus: 1 15 | num_workers_per_gpu: 0 16 | exp_prefix: "iso" 17 | run: 1 18 | lr: 1e-4 19 | weight_decay: 1e-4 20 | 21 | context_prior: true 22 | 23 | relation_loss: true 24 | CE_ssc_loss: true 25 | sem_scal_loss: true 26 | geo_scal_loss: true 27 | 28 | project_1_2: true 29 | project_1_4: true 30 | project_1_8: true 31 | 32 | voxeldepth: true 33 | voxeldepthcfg: 34 | depth_scale_1: true 35 | depth_scale_2: false 36 | depth_scale_4: false 37 | depth_scale_8: false 38 | 39 | use_gt_depth: false 40 | use_zoedepth: false 41 | use_depthanything: true 42 | zoedepth_as_gt: false 43 | depthanything_as_gt: false 44 | 45 | add_fusion: true 46 | 47 | frozen_encoder: true 48 | -------------------------------------------------------------------------------- /iso/data/NYU/collate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def collate_fn(batch): 5 | data = {} 6 | imgs = [] 7 | raw_imgs = [] 8 | depth_gts = [] 9 | targets = [] 10 | names = [] 11 | cam_poses = [] 12 | 13 | pix_z = [] 14 | 15 | vox_origins = [] 16 | cam_ks = [] 17 | 18 | CP_mega_matrices = [] 19 | 20 | data["projected_pix_1"] = [] 21 | data["fov_mask_1"] = [] 22 | data["frustums_masks"] = [] 23 | data["frustums_class_dists"] = [] 24 | 25 | for idx, input_dict in enumerate(batch): 26 | CP_mega_matrices.append(torch.from_numpy(input_dict["CP_mega_matrix"])) 27 | for key in data: 28 | if key in input_dict: 29 | data[key].append(torch.from_numpy(input_dict[key])) 30 | 31 | cam_ks.append(torch.from_numpy(input_dict["cam_k"]).double()) 32 | cam_poses.append(torch.from_numpy(input_dict["cam_pose"]).float()) 33 | vox_origins.append(torch.from_numpy(input_dict["voxel_origin"]).double()) 34 | 35 | pix_z.append(torch.from_numpy(input_dict['pix_z']).float()) 36 | 37 | names.append(input_dict["name"]) 38 | 39 | img = input_dict["img"] 40 | imgs.append(img) 41 | 42 | raw_img = input_dict['raw_img'] 43 | raw_imgs.append(raw_img) 44 | 45 | depth_gt = torch.from_numpy(input_dict['depth_gt']) 46 | depth_gts.append(depth_gt) 47 | 48 | target = torch.from_numpy(input_dict["target"]) 49 | targets.append(target) 50 | 51 | ret_data = { 52 | "CP_mega_matrices": CP_mega_matrices, 53 
| "cam_pose": torch.stack(cam_poses), 54 | "cam_k": torch.stack(cam_ks), 55 | "vox_origin": torch.stack(vox_origins), 56 | "name": names, 57 | "img": torch.stack(imgs), 58 | # "raw_img": torch.stack(raw_imgs), 59 | "raw_img": raw_imgs, 60 | 'depth_gt': torch.stack(depth_gts), 61 | "target": torch.stack(targets), 62 | 'pix_z': torch.stack(pix_z), 63 | } 64 | for key in data: 65 | ret_data[key] = data[key] 66 | return ret_data 67 | -------------------------------------------------------------------------------- /iso/data/NYU/nyu_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import glob 4 | from torch.utils.data import Dataset 5 | import numpy as np 6 | from PIL import Image, ImageTransform 7 | from torchvision import transforms 8 | from iso.data.utils.helpers import ( 9 | vox2pix, 10 | compute_local_frustums, 11 | compute_CP_mega_matrix, 12 | ) 13 | import pickle 14 | import torch.nn.functional as F 15 | import copy 16 | 17 | 18 | def read_depth(depth_path): 19 | depth_vis = Image.open(depth_path).convert('I;16') 20 | depth_vis_array = np.array(depth_vis) 21 | 22 | arr1 = np.right_shift(depth_vis_array, 3) 23 | arr2 = np.left_shift(depth_vis_array, 13) 24 | depth_vis_array = np.bitwise_or(arr1, arr2) 25 | 26 | depth_inpaint = depth_vis_array.astype(np.float32) / 1000.0 27 | 28 | return depth_inpaint 29 | 30 | 31 | class NYUDataset(Dataset): 32 | def __init__( 33 | self, 34 | split, 35 | root, 36 | preprocess_root, 37 | n_relations=4, 38 | color_jitter=None, 39 | frustum_size=4, 40 | fliplr=0.0, 41 | ): 42 | self.n_relations = n_relations 43 | self.frustum_size = frustum_size 44 | self.n_classes = 12 45 | self.root = os.path.join(root, "NYU" + split) 46 | self.preprocess_root = preprocess_root 47 | self.base_dir = os.path.join(preprocess_root, "base", "NYU" + split) 48 | self.fliplr = fliplr 49 | 50 | self.voxel_size = 0.08 # 0.08m 51 | self.scene_size = (4.8, 4.8, 2.88) # (4.8m, 4.8m, 2.88m) 52 | self.img_W = 640 53 | self.img_H = 480 54 | self.cam_k = np.array([[518.8579, 0, 320], [0, 518.8579, 240], [0, 0, 1]]) 55 | 56 | self.color_jitter = ( 57 | transforms.ColorJitter(*color_jitter) if color_jitter else None 58 | ) 59 | 60 | self.scan_names = glob.glob(os.path.join(self.root, "*.bin")) 61 | 62 | self.normalize_rgb = transforms.Compose( 63 | [ 64 | transforms.ToTensor(), 65 | transforms.Normalize( 66 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 67 | ), 68 | ] 69 | ) 70 | 71 | def __getitem__(self, index): 72 | file_path = self.scan_names[index] 73 | filename = os.path.basename(file_path) 74 | name = filename[:-4] 75 | 76 | os.makedirs(self.base_dir, exist_ok=True) 77 | filepath = os.path.join(self.base_dir, name + ".pkl") 78 | 79 | with open(filepath, "rb") as handle: 80 | data = pickle.load(handle) 81 | 82 | cam_pose = data["cam_pose"] 83 | T_world_2_cam = np.linalg.inv(cam_pose) 84 | vox_origin = data["voxel_origin"] 85 | data["cam_k"] = self.cam_k 86 | target = data[ 87 | "target_1_4" 88 | ] # Following SSC literature, the output resolution on NYUv2 is set to 1:4 89 | data["target"] = target 90 | target_1_4 = data["target_1_16"] 91 | 92 | CP_mega_matrix = compute_CP_mega_matrix( 93 | target_1_4, is_binary=self.n_relations == 2 94 | ) 95 | data["CP_mega_matrix"] = CP_mega_matrix 96 | 97 | # compute the 3D-2D mapping 98 | projected_pix, fov_mask, pix_z = vox2pix( 99 | T_world_2_cam, 100 | self.cam_k, 101 | vox_origin, 102 | self.voxel_size, 103 | self.img_W, 104 | self.img_H, 105 | 
self.scene_size, 106 | ) 107 | 108 | data["projected_pix_1"] = projected_pix 109 | data["fov_mask_1"] = fov_mask 110 | 111 | # compute the masks, each indicates voxels inside a frustum 112 | frustums_masks, frustums_class_dists = compute_local_frustums( 113 | projected_pix, 114 | pix_z, 115 | target, 116 | self.img_W, 117 | self.img_H, 118 | dataset="NYU", 119 | n_classes=12, 120 | size=self.frustum_size, 121 | ) 122 | data["frustums_masks"] = frustums_masks 123 | data["frustums_class_dists"] = frustums_class_dists 124 | 125 | rgb_path = os.path.join(self.root, name + "_color.jpg") 126 | depth_path = os.path.join(self.root, name + ".png") 127 | img = Image.open(rgb_path).convert("RGB") 128 | raw_img_pil = copy.deepcopy(img) 129 | raw_img = np.array(raw_img_pil) 130 | depth_gt = read_depth(depth_path) 131 | depth_gt = np.array(depth_gt, dtype=np.float32, copy=False) 132 | 133 | # Image augmentation 134 | if self.color_jitter is not None: 135 | img = self.color_jitter(img) 136 | # PIL to numpy 137 | img = np.array(img, dtype=np.float32, copy=False) / 255.0 138 | 139 | # randomly fliplr the image 140 | if np.random.rand() < self.fliplr: 141 | img = np.ascontiguousarray(np.fliplr(img)) 142 | # raw_img = np.ascontiguousarray(np.fliplr(raw_img)) 143 | raw_img_pil = raw_img_pil.transpose(Image.FLIP_LEFT_RIGHT) 144 | data["projected_pix_1"][:, 0] = ( 145 | img.shape[1] - 1 - data["projected_pix_1"][:, 0] 146 | ) 147 | 148 | depth_gt = np.ascontiguousarray(np.fliplr(depth_gt)) 149 | 150 | data["img"] = self.normalize_rgb(img) # (3, img_H, img_W) 151 | data['depth_gt'] = depth_gt[None] # (1, 480, 640) 152 | # data['raw_img'] = transforms.PILToTensor()(raw_img).float() 153 | # data['raw_img'] = transforms.ToTensor()(raw_img) 154 | data['raw_img'] = raw_img_pil 155 | # print(data['raw_img'].dtype) 156 | # print((data['raw_img'].numpy() - raw_img).max()) 157 | data['pix_z'] = pix_z 158 | 159 | return data 160 | 161 | def __len__(self): 162 | return len(self.scan_names) 163 | -------------------------------------------------------------------------------- /iso/data/NYU/nyu_dm.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataloader import DataLoader 2 | from iso.data.NYU.nyu_dataset import NYUDataset 3 | from iso.data.NYU.collate import collate_fn 4 | import pytorch_lightning as pl 5 | from iso.data.utils.torch_util import worker_init_fn 6 | 7 | 8 | class NYUDataModule(pl.LightningDataModule): 9 | def __init__( 10 | self, 11 | root, 12 | preprocess_root, 13 | n_relations=4, 14 | batch_size=4, 15 | frustum_size=4, 16 | num_workers=6, 17 | ): 18 | super().__init__() 19 | self.n_relations = n_relations 20 | self.preprocess_root = preprocess_root 21 | self.root = root 22 | self.batch_size = batch_size 23 | self.num_workers = num_workers 24 | self.frustum_size = frustum_size 25 | 26 | def setup(self, stage=None): 27 | self.train_ds = NYUDataset( 28 | split="train", 29 | preprocess_root=self.preprocess_root, 30 | n_relations=self.n_relations, 31 | root=self.root, 32 | fliplr=0.5, 33 | frustum_size=self.frustum_size, 34 | color_jitter=(0.4, 0.4, 0.4), 35 | ) 36 | self.test_ds = NYUDataset( 37 | split="test", 38 | preprocess_root=self.preprocess_root, 39 | n_relations=self.n_relations, 40 | root=self.root, 41 | frustum_size=self.frustum_size, 42 | fliplr=0.0, 43 | color_jitter=None, 44 | ) 45 | 46 | def train_dataloader(self): 47 | return DataLoader( 48 | self.train_ds, 49 | batch_size=self.batch_size, 50 | drop_last=True, 51 | 
num_workers=self.num_workers, 52 | shuffle=True, 53 | pin_memory=True, 54 | worker_init_fn=worker_init_fn, 55 | collate_fn=collate_fn, 56 | ) 57 | 58 | def val_dataloader(self): 59 | return DataLoader( 60 | self.test_ds, 61 | batch_size=self.batch_size, 62 | num_workers=self.num_workers, 63 | drop_last=False, 64 | shuffle=False, 65 | pin_memory=True, 66 | collate_fn=collate_fn, 67 | ) 68 | 69 | def test_dataloader(self): 70 | return DataLoader( 71 | self.test_ds, 72 | batch_size=self.batch_size, 73 | num_workers=self.num_workers, 74 | drop_last=False, 75 | shuffle=False, 76 | pin_memory=True, 77 | collate_fn=collate_fn, 78 | ) 79 | -------------------------------------------------------------------------------- /iso/data/NYU/params.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | NYU_class_names = [ 5 | "empty", 6 | "ceiling", 7 | "floor", 8 | "wall", 9 | "window", 10 | "chair", 11 | "bed", 12 | "sofa", 13 | "table", 14 | "tvs", 15 | "furn", 16 | "objs", 17 | ] 18 | class_weights = torch.FloatTensor([0.05, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) 19 | 20 | class_freq_1_4 = np.array( 21 | [ 22 | 43744234, 23 | 80205, 24 | 1070052, 25 | 905632, 26 | 116952, 27 | 180994, 28 | 436852, 29 | 279714, 30 | 254611, 31 | 28247, 32 | 1805949, 33 | 850724, 34 | ] 35 | ) 36 | class_freq_1_8 = np.array( 37 | [ 38 | 5176253, 39 | 17277, 40 | 220105, 41 | 183849, 42 | 21827, 43 | 33520, 44 | 67022, 45 | 44248, 46 | 46615, 47 | 4419, 48 | 290218, 49 | 142573, 50 | ] 51 | ) 52 | class_freq_1_16 = np.array( 53 | [587620, 3820, 46836, 36256, 4241, 5978, 10939, 8000, 8224, 781, 49778, 25864] 54 | ) 55 | -------------------------------------------------------------------------------- /iso/data/NYU/preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import numpy.matlib 4 | import os 5 | import glob 6 | import pickle 7 | import hydra 8 | from omegaconf import DictConfig 9 | 10 | 11 | seg_class_map = [ 12 | 0, 13 | 1, 14 | 2, 15 | 3, 16 | 4, 17 | 11, 18 | 5, 19 | 6, 20 | 7, 21 | 8, 22 | 8, 23 | 10, 24 | 10, 25 | 10, 26 | 11, 27 | 11, 28 | 9, 29 | 8, 30 | 11, 31 | 11, 32 | 11, 33 | 11, 34 | 11, 35 | 11, 36 | 11, 37 | 11, 38 | 11, 39 | 10, 40 | 10, 41 | 11, 42 | 8, 43 | 10, 44 | 11, 45 | 9, 46 | 11, 47 | 11, 48 | 11, 49 | ] 50 | 51 | 52 | def _rle2voxel(rle, voxel_size=(240, 144, 240), rle_filename=""): 53 | r"""Read voxel label data from file (RLE compression), and convert it to fully occupancy labeled voxels. 54 | code taken from https://github.com/waterljwant/SSC/blob/master/dataloaders/dataloader.py#L172 55 | In the data loader of pytorch, only single thread is allowed. 56 | For multi-threads version and more details, see 'readRLE.py'. 
57 | output: seg_label: 3D numpy array, size 240 x 144 x 240 58 | """ 59 | seg_label = np.zeros( 60 | int(voxel_size[0] * voxel_size[1] * voxel_size[2]), dtype=np.uint8 61 | ) # segmentation label 62 | vox_idx = 0 63 | for idx in range(int(rle.shape[0] / 2)): 64 | check_val = rle[idx * 2] 65 | check_iter = rle[idx * 2 + 1] 66 | if check_val >= 37 and check_val != 255: # 37 classes to 12 classes 67 | print("RLE {} check_val: {}".format(rle_filename, check_val)) 68 | seg_label_val = ( 69 | seg_class_map[check_val] if check_val != 255 else 255 70 | ) # 37 classes to 12 classes 71 | seg_label[vox_idx : vox_idx + check_iter] = np.matlib.repmat( 72 | seg_label_val, 1, check_iter 73 | ) 74 | vox_idx = vox_idx + check_iter 75 | seg_label = seg_label.reshape(voxel_size) # 3D array, size 240 x 144 x 240 76 | return seg_label 77 | 78 | 79 | def _read_rle(rle_filename): # 0.0005s 80 | """Read RLE compression data 81 | code taken from https://github.com/waterljwant/SSC/blob/master/dataloaders/dataloader.py#L153 82 | Return: 83 | vox_origin, 84 | cam_pose, 85 | vox_rle, voxel label data from file 86 | Shape: 87 | vox_rle, (240, 144, 240) 88 | """ 89 | fid = open(rle_filename, "rb") 90 | vox_origin = np.fromfile( 91 | fid, np.float32, 3 92 | ).T # Read voxel origin in world coordinates 93 | cam_pose = np.fromfile(fid, np.float32, 16).reshape((4, 4)) # Read camera pose 94 | vox_rle = ( 95 | np.fromfile(fid, np.uint32).reshape((-1, 1)).T 96 | ) # Read voxel label data from file 97 | vox_rle = np.squeeze(vox_rle) # 2d array: (1 x N), to 1d array: (N , ) 98 | fid.close() 99 | return vox_origin, cam_pose, vox_rle 100 | 101 | 102 | def _downsample_label(label, voxel_size=(240, 144, 240), downscale=4): 103 | r"""downsample the labeled data, 104 | code taken from https://github.com/waterljwant/SSC/blob/master/dataloaders/dataloader.py#L262 105 | Shape: 106 | label, (240, 144, 240) 107 | label_downscale, if downsample==4, then (60, 36, 60) 108 | """ 109 | if downscale == 1: 110 | return label 111 | ds = downscale 112 | small_size = ( 113 | voxel_size[0] // ds, 114 | voxel_size[1] // ds, 115 | voxel_size[2] // ds, 116 | ) # small size 117 | label_downscale = np.zeros(small_size, dtype=np.uint8) 118 | empty_t = 0.95 * ds * ds * ds # threshold 119 | s01 = small_size[0] * small_size[1] 120 | label_i = np.zeros((ds, ds, ds), dtype=np.int32) 121 | 122 | for i in range(small_size[0] * small_size[1] * small_size[2]): 123 | z = int(i / s01) 124 | y = int((i - z * s01) / small_size[0]) 125 | x = int(i - z * s01 - y * small_size[0]) 126 | 127 | label_i[:, :, :] = label[ 128 | x * ds : (x + 1) * ds, y * ds : (y + 1) * ds, z * ds : (z + 1) * ds 129 | ] 130 | label_bin = label_i.flatten() 131 | 132 | zero_count_0 = np.array(np.where(label_bin == 0)).size 133 | zero_count_255 = np.array(np.where(label_bin == 255)).size 134 | 135 | zero_count = zero_count_0 + zero_count_255 136 | if zero_count > empty_t: 137 | label_downscale[x, y, z] = 0 if zero_count_0 > zero_count_255 else 255 138 | else: 139 | label_i_s = label_bin[ 140 | np.where(np.logical_and(label_bin > 0, label_bin < 255)) 141 | ] 142 | label_downscale[x, y, z] = np.argmax(np.bincount(label_i_s)) 143 | return label_downscale 144 | 145 | 146 | @hydra.main(config_name="../../config/iso.yaml") 147 | def main(config: DictConfig): 148 | scene_size = (240, 144, 240) 149 | for split in ["train", "test"]: 150 | root = os.path.join(config.NYU_root, "NYU" + split) 151 | base_dir = os.path.join(config.NYU_preprocess_root, "base", "NYU" + split) 152 | os.makedirs(base_dir, 
exist_ok=True) 153 | 154 | scans = glob.glob(os.path.join(root, "*.bin")) 155 | for scan in tqdm(scans): 156 | filename = os.path.basename(scan) 157 | name = filename[:-4] 158 | filepath = os.path.join(base_dir, name + ".pkl") 159 | if os.path.exists(filepath): 160 | continue 161 | 162 | vox_origin, cam_pose, rle = _read_rle(scan) 163 | 164 | target_1_1 = _rle2voxel(rle, scene_size, scan) 165 | target_1_4 = _downsample_label(target_1_1, scene_size, 4) 166 | target_1_16 = _downsample_label(target_1_1, scene_size, 16) 167 | 168 | data = { 169 | "cam_pose": cam_pose, 170 | "voxel_origin": vox_origin, 171 | "name": name, 172 | "target_1_4": target_1_4, 173 | "target_1_16": target_1_16, 174 | } 175 | 176 | with open(filepath, "wb") as handle: 177 | pickle.dump(data, handle) 178 | print("wrote to", filepath) 179 | 180 | 181 | if __name__ == "__main__": 182 | main() 183 | -------------------------------------------------------------------------------- /iso/data/OccScanNet/collate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def collate_fn(batch): 6 | data = {} 7 | imgs = [] 8 | raw_imgs = [] 9 | depth_gts = [] 10 | targets = [] 11 | names = [] 12 | cam_poses = [] 13 | 14 | pix_z = [] 15 | 16 | vox_origins = [] 17 | cam_ks = [] 18 | 19 | CP_mega_matrices = [] 20 | 21 | data["projected_pix_1"] = [] 22 | data["fov_mask_1"] = [] 23 | data["frustums_masks"] = [] 24 | data["frustums_class_dists"] = [] 25 | 26 | for idx, input_dict in enumerate(batch): 27 | try: 28 | CP_mega_matrices.append(torch.from_numpy(input_dict["CP_mega_matrix"])) 29 | except: 30 | CP_mega_matrices.append(input_dict["CP_mega_matrix"]) 31 | for key in data: 32 | if key in input_dict: 33 | data[key].append(torch.from_numpy(input_dict[key])) 34 | 35 | cam_ks.append(torch.from_numpy(input_dict["cam_k"]).double()) 36 | cam_poses.append(torch.from_numpy(input_dict["cam_pose"]).float()) 37 | vox_origins.append(torch.from_numpy(np.array(input_dict["voxel_origin"])).double()) 38 | 39 | pix_z.append(torch.from_numpy(input_dict['pix_z']).float()) 40 | 41 | names.append(input_dict["name"]) 42 | 43 | img = input_dict["img"] 44 | imgs.append(img) 45 | raw_img = input_dict['raw_img'] 46 | raw_imgs.append(raw_img) 47 | 48 | depth_gt = torch.from_numpy(input_dict['depth_gt']) 49 | depth_gts.append(depth_gt) 50 | 51 | target = torch.from_numpy(input_dict["target"]) 52 | targets.append(target) 53 | 54 | ret_data = { 55 | "CP_mega_matrices": CP_mega_matrices, 56 | "cam_pose": torch.stack(cam_poses), 57 | "cam_k": torch.stack(cam_ks), 58 | "vox_origin": torch.stack(vox_origins), 59 | "name": names, 60 | "img": torch.stack(imgs), 61 | "raw_img": raw_imgs, 62 | "target": torch.stack(targets), 63 | 'depth_gt': torch.stack(depth_gts), 64 | 'pix_z': torch.stack(pix_z), 65 | } 66 | for key in data: 67 | ret_data[key] = data[key] 68 | 69 | return ret_data -------------------------------------------------------------------------------- /iso/data/OccScanNet/occscannet_dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from pprint import pprint 3 | import torch 4 | import os 5 | import glob 6 | from torch.utils.data import Dataset 7 | import numpy as np 8 | from PIL import Image 9 | from torchvision import transforms 10 | from iso.data.utils.helpers import ( 11 | vox2pix, 12 | compute_local_frustums, 13 | compute_CP_mega_matrix, 14 | ) 15 | import pickle 16 | import torch.nn.functional as F 17 | import 
matplotlib.pyplot as plt 18 | 19 | import cv2 20 | 21 | 22 | class OccScanNetDataset(Dataset): 23 | def __init__( 24 | self, 25 | split, 26 | root, 27 | interval=-1, 28 | train_scenes_sample=-1, 29 | val_scenes_sample=-1, 30 | n_relations=4, 31 | color_jitter=None, 32 | frustum_size=4, 33 | fliplr=0.0, 34 | v2=False, 35 | ): 36 | # cur_dir = os.path.abspath(os.path.curdir) 37 | self.occscannet_root = root 38 | 39 | self.n_relations = n_relations 40 | self.frustum_size = frustum_size 41 | self.split = split 42 | self.fliplr = fliplr 43 | 44 | self.voxel_size = 0.08 # 0.08m 45 | self.scene_size = (4.8, 4.8, 2.88) # (4.8m, 4.8m, 2.88m) 46 | 47 | self.color_jitter = ( 48 | transforms.ColorJitter(*color_jitter) if color_jitter else None 49 | ) 50 | 51 | # print(os.getcwd()) 52 | if v2: 53 | subscenes_list = f'{self.occscannet_root}/{self.split}_subscenes_v2.txt' 54 | else: # data/occscannet/train_subscenes.txt 55 | subscenes_list = f'{self.occscannet_root}/{self.split}_subscenes.txt' 56 | with open(subscenes_list, 'r') as f: 57 | self.used_subscenes = f.readlines() 58 | for i in range(len(self.used_subscenes)): 59 | self.used_subscenes[i] = f'{self.occscannet_root}/' + self.used_subscenes[i].strip() 60 | 61 | if "train" in self.split: 62 | # breakpoint() 63 | if train_scenes_sample != -1: 64 | self.used_subscenes = self.used_subscenes[:train_scenes_sample] 65 | # if interval != -1: 66 | # self.used_subscenes = self.used_subscenes[::interval] 67 | # print(f"Total train scenes number: {len(self.used_scan_names)}") 68 | # pprint(self.used_subscenes) 69 | print(f"Total train scenes number: {len(self.used_subscenes)}") 70 | elif "val" in self.split: 71 | # print(f"Total validation scenes number: {len(self.used_scan_names)}") 72 | # pprint(self.used_subscenes) 73 | if val_scenes_sample != -1: 74 | self.used_subscenes = self.used_subscenes[:val_scenes_sample] 75 | print(f"Total validation scenes number: {len(self.used_subscenes)}") 76 | 77 | self.normalize_rgb = transforms.Compose( 78 | [ 79 | transforms.ToTensor(), 80 | transforms.Normalize( 81 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 82 | ), 83 | ] 84 | ) 85 | # os.chdir(cur_dir) 86 | 87 | def __getitem__(self, index): 88 | name = self.used_subscenes[index] 89 | with open(name, 'rb') as f: 90 | data = pickle.load(f) 91 | 92 | cam_pose = data["cam_pose"] 93 | cam_intrin = data['intrinsic'] 94 | 95 | img = cv2.imread(data['img']) 96 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 97 | data['img'] = img 98 | data['raw_img'] = copy.deepcopy(cv2.resize(img, (640, 480))) 99 | depth_img = Image.open(data['depth_gt']).convert('I;16') 100 | depth_img = np.array(depth_img) / 1000.0 101 | data['depth_gt'] = depth_img 102 | depth_gt = data['depth_gt'] 103 | # print(depth_gt.shape) 104 | img_H, img_W = img.shape[0], img.shape[1] 105 | img = cv2.resize(img, (640, 480)) 106 | # plt.imshow(img) 107 | # plt.savefig('img.jpg') 108 | W_factor = 640 / img_W 109 | H_factor = 480 / img_H 110 | img_H, img_W = img.shape[0], img.shape[1] 111 | 112 | cam_intrin[0, 0] *= W_factor 113 | cam_intrin[1, 1] *= H_factor 114 | cam_intrin[0, 2] *= W_factor 115 | cam_intrin[1, 2] *= H_factor 116 | 117 | data["cam_pose"] = cam_pose 118 | T_world_2_cam = np.linalg.inv(cam_pose) 119 | vox_origin = list(data["voxel_origin"]) 120 | vox_origin = np.array(vox_origin) 121 | data["vox_origin"] = vox_origin 122 | data["cam_k"] = cam_intrin[:3, :3][None] 123 | 124 | 125 | target = data[ 126 | "target_1_4" 127 | ] # Following SSC literature, the output resolution on NYUv2 is set to 
1:4 128 | target = np.where(target == 255, 0, target) 129 | data["target"] = target 130 | data["target_1_4"] = target 131 | target_1_4 = data["target_1_16"] 132 | target_1_4 = np.where(target_1_4 == 255, 0, target_1_4) 133 | data["target_1_16"] = target_1_4 134 | 135 | CP_mega_matrix = compute_CP_mega_matrix( 136 | target_1_4, is_binary=self.n_relations == 2 137 | ) 138 | 139 | data["CP_mega_matrix"] = CP_mega_matrix 140 | 141 | # compute the 3D-2D mapping 142 | projected_pix, fov_mask, pix_z = vox2pix( 143 | T_world_2_cam, 144 | cam_intrin, 145 | vox_origin, 146 | self.voxel_size, 147 | img_W, 148 | img_H, 149 | self.scene_size, 150 | ) 151 | # print(projected_pix) 152 | # print(fov_mask.shape) 153 | 154 | data["projected_pix_1"] = projected_pix 155 | data["fov_mask_1"] = fov_mask 156 | data['pix_z'] = pix_z 157 | 158 | # compute the masks, each indicates voxels inside a frustum 159 | 160 | frustums_masks, frustums_class_dists = compute_local_frustums( 161 | projected_pix, 162 | pix_z, 163 | target, 164 | img_W, 165 | img_H, 166 | dataset="OccScanNet", 167 | n_classes=12, 168 | size=self.frustum_size, 169 | ) 170 | data["frustums_masks"] = frustums_masks 171 | data["frustums_class_dists"] = frustums_class_dists 172 | 173 | img = Image.fromarray(img).convert('RGB') 174 | 175 | # Image augmentation 176 | if self.color_jitter is not None: 177 | img = self.color_jitter(img) 178 | 179 | # PIL to numpy 180 | img = np.array(img, dtype=np.float32, copy=False) / 255.0 181 | 182 | if np.random.rand() < self.fliplr: 183 | img = np.ascontiguousarray(np.fliplr(img)) 184 | # raw_img = np.ascontiguousarray(np.fliplr(raw_img)) 185 | data['raw_img'] = np.ascontiguousarray(np.fliplr(data['raw_img'])) 186 | data["projected_pix_1"][:, 0] = ( 187 | img.shape[1] - 1 - data["projected_pix_1"][:, 0] 188 | ) 189 | 190 | depth_gt = np.ascontiguousarray(np.fliplr(depth_gt)) 191 | data['depth_gt'] = depth_gt 192 | # print(depth_gt.shape) 193 | 194 | data["img"] = self.normalize_rgb(img) # (3, img_H, img_W) 195 | data["name"] = name 196 | 197 | return data 198 | 199 | def __len__(self): 200 | if 'train' in self.split: 201 | return len(self.used_subscenes) 202 | elif 'val' in self.split: 203 | return len(self.used_subscenes) 204 | 205 | def test(): 206 | datset = OccScanNetDataset( 207 | split='train', 208 | fliplr=0.5, 209 | frustum_size=8, 210 | color_jitter=(0.4, 0.4, 0.4), 211 | ) 212 | datset = OccScanNetDataset( 213 | split='val', 214 | frustum_size=8, 215 | color_jitter=(0.4, 0.4, 0.4), 216 | ) 217 | 218 | if __name__ == "__main__": 219 | test() -------------------------------------------------------------------------------- /iso/data/OccScanNet/occscannet_dm.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataloader import DataLoader 2 | from iso.data.OccScanNet.occscannet_dataset import OccScanNetDataset 3 | from iso.data.OccScanNet.collate import collate_fn 4 | import pytorch_lightning as pl 5 | from iso.data.utils.torch_util import worker_init_fn 6 | 7 | 8 | class OccScanNetDataModule(pl.LightningDataModule): 9 | def __init__( 10 | self, 11 | root, 12 | n_relations=4, 13 | batch_size=4, 14 | frustum_size=4, 15 | num_workers=6, 16 | interval=-1, 17 | train_scenes_sample=-1, 18 | val_scenes_sample=-1, 19 | v2=False, 20 | ): 21 | super().__init__() 22 | 23 | self.root = root 24 | self.n_relations = n_relations 25 | 26 | self.batch_size = batch_size 27 | self.num_workers = num_workers 28 | self.frustum_size = frustum_size 29 | 30 | 
self.train_scenes_sample = train_scenes_sample 31 | self.val_scenes_sample = val_scenes_sample 32 | 33 | def setup(self, stage=None): 34 | self.train_ds = OccScanNetDataset( 35 | split="train", 36 | root=self.root, 37 | n_relations=self.n_relations, 38 | fliplr=0.5, 39 | train_scenes_sample=self.train_scenes_sample, 40 | frustum_size=self.frustum_size, 41 | color_jitter=(0.4, 0.4, 0.4), 42 | ) 43 | self.test_ds = OccScanNetDataset( 44 | split="val", 45 | root=self.root, 46 | n_relations=self.n_relations, 47 | val_scenes_sample = self.val_scenes_sample, 48 | frustum_size=self.frustum_size, 49 | fliplr=0.0, 50 | color_jitter=None, 51 | ) 52 | 53 | def train_dataloader(self): 54 | return DataLoader( 55 | self.train_ds, 56 | batch_size=self.batch_size, 57 | drop_last=True, 58 | num_workers=self.num_workers, 59 | shuffle=True, 60 | pin_memory=True, 61 | worker_init_fn=worker_init_fn, 62 | collate_fn=collate_fn, 63 | ) 64 | 65 | def val_dataloader(self): 66 | return DataLoader( 67 | self.test_ds, 68 | batch_size=self.batch_size, 69 | num_workers=self.num_workers, 70 | drop_last=False, 71 | shuffle=False, 72 | pin_memory=True, 73 | collate_fn=collate_fn, 74 | ) 75 | 76 | def test_dataloader(self): 77 | return DataLoader( 78 | self.test_ds, 79 | batch_size=self.batch_size, 80 | num_workers=self.num_workers, 81 | drop_last=False, 82 | shuffle=False, 83 | pin_memory=True, 84 | collate_fn=collate_fn, 85 | ) 86 | 87 | 88 | def test(): 89 | datamodule = OccScanNetDataModule( 90 | n_relations=4, 91 | batch_size=64, 92 | frustum_size=8, 93 | ) 94 | datamodule.setup() 95 | train_data = datamodule.train_dataloader() 96 | print(next(iter(train_data)).shape) 97 | 98 | 99 | if __name__ == "__main__": 100 | test() -------------------------------------------------------------------------------- /iso/data/OccScanNet/params.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | OccScanNet_class_names = [ 5 | "empty", 6 | "ceiling", 7 | "floor", 8 | "wall", 9 | "window", 10 | "chair", 11 | "bed", 12 | "sofa", 13 | "table", 14 | "tvs", 15 | "furn", 16 | "objs", 17 | ] 18 | class_weights = torch.FloatTensor([0.05, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) -------------------------------------------------------------------------------- /iso/data/utils/fusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Most of the code is taken from https://github.com/andyzeng/tsdf-fusion-python/blob/master/fusion.py 3 | 4 | @inproceedings{zeng20163dmatch, 5 | title={3DMatch: Learning Local Geometric Descriptors from RGB-D Reconstructions}, 6 | author={Zeng, Andy and Song, Shuran and Nie{\ss}ner, Matthias and Fisher, Matthew and Xiao, Jianxiong and Funkhouser, Thomas}, 7 | booktitle={CVPR}, 8 | year={2017} 9 | } 10 | """ 11 | 12 | import numpy as np 13 | 14 | from numba import njit, prange 15 | from skimage import measure 16 | 17 | FUSION_GPU_MODE = 0 18 | 19 | 20 | class TSDFVolume: 21 | """Volumetric TSDF Fusion of RGB-D Images.""" 22 | 23 | def __init__(self, vol_bnds, voxel_size, use_gpu=True): 24 | """Constructor. 25 | 26 | Args: 27 | vol_bnds (ndarray): An ndarray of shape (3, 2). Specifies the 28 | xyz bounds (min/max) in meters. 29 | voxel_size (float): The volume discretization in meters. 30 | """ 31 | vol_bnds = np.asarray(vol_bnds) 32 | assert vol_bnds.shape == (3, 2), "[!] `vol_bnds` should be of shape (3, 2)." 
33 | 34 | # Define voxel volume parameters 35 | self._vol_bnds = vol_bnds 36 | self._voxel_size = float(voxel_size) 37 | self._trunc_margin = 5 * self._voxel_size # truncation on SDF 38 | # self._trunc_margin = 10 # truncation on SDF 39 | self._color_const = 256 * 256 40 | 41 | # Adjust volume bounds and ensure C-order contiguous 42 | self._vol_dim = ( 43 | np.ceil((self._vol_bnds[:, 1] - self._vol_bnds[:, 0]) / self._voxel_size) 44 | .copy(order="C") 45 | .astype(int) 46 | ) 47 | self._vol_bnds[:, 1] = self._vol_bnds[:, 0] + self._vol_dim * self._voxel_size 48 | self._vol_origin = self._vol_bnds[:, 0].copy(order="C").astype(np.float32) 49 | 50 | print( 51 | "Voxel volume size: {} x {} x {} - # points: {:,}".format( 52 | self._vol_dim[0], 53 | self._vol_dim[1], 54 | self._vol_dim[2], 55 | self._vol_dim[0] * self._vol_dim[1] * self._vol_dim[2], 56 | ) 57 | ) 58 | 59 | # Initialize pointers to voxel volume in CPU memory 60 | self._tsdf_vol_cpu = np.zeros(self._vol_dim).astype(np.float32) 61 | # for computing the cumulative moving average of observations per voxel 62 | self._weight_vol_cpu = np.zeros(self._vol_dim).astype(np.float32) 63 | self._color_vol_cpu = np.zeros(self._vol_dim).astype(np.float32) 64 | 65 | self.gpu_mode = use_gpu and FUSION_GPU_MODE 66 | 67 | # Copy voxel volumes to GPU 68 | if self.gpu_mode: 69 | self._tsdf_vol_gpu = cuda.mem_alloc(self._tsdf_vol_cpu.nbytes) 70 | cuda.memcpy_htod(self._tsdf_vol_gpu, self._tsdf_vol_cpu) 71 | self._weight_vol_gpu = cuda.mem_alloc(self._weight_vol_cpu.nbytes) 72 | cuda.memcpy_htod(self._weight_vol_gpu, self._weight_vol_cpu) 73 | self._color_vol_gpu = cuda.mem_alloc(self._color_vol_cpu.nbytes) 74 | cuda.memcpy_htod(self._color_vol_gpu, self._color_vol_cpu) 75 | 76 | # Cuda kernel function (C++) 77 | self._cuda_src_mod = SourceModule( 78 | """ 79 | __global__ void integrate(float * tsdf_vol, 80 | float * weight_vol, 81 | float * color_vol, 82 | float * vol_dim, 83 | float * vol_origin, 84 | float * cam_intr, 85 | float * cam_pose, 86 | float * other_params, 87 | float * color_im, 88 | float * depth_im) { 89 | // Get voxel index 90 | int gpu_loop_idx = (int) other_params[0]; 91 | int max_threads_per_block = blockDim.x; 92 | int block_idx = blockIdx.z*gridDim.y*gridDim.x+blockIdx.y*gridDim.x+blockIdx.x; 93 | int voxel_idx = gpu_loop_idx*gridDim.x*gridDim.y*gridDim.z*max_threads_per_block+block_idx*max_threads_per_block+threadIdx.x; 94 | int vol_dim_x = (int) vol_dim[0]; 95 | int vol_dim_y = (int) vol_dim[1]; 96 | int vol_dim_z = (int) vol_dim[2]; 97 | if (voxel_idx > vol_dim_x*vol_dim_y*vol_dim_z) 98 | return; 99 | // Get voxel grid coordinates (note: be careful when casting) 100 | float voxel_x = floorf(((float)voxel_idx)/((float)(vol_dim_y*vol_dim_z))); 101 | float voxel_y = floorf(((float)(voxel_idx-((int)voxel_x)*vol_dim_y*vol_dim_z))/((float)vol_dim_z)); 102 | float voxel_z = (float)(voxel_idx-((int)voxel_x)*vol_dim_y*vol_dim_z-((int)voxel_y)*vol_dim_z); 103 | // Voxel grid coordinates to world coordinates 104 | float voxel_size = other_params[1]; 105 | float pt_x = vol_origin[0]+voxel_x*voxel_size; 106 | float pt_y = vol_origin[1]+voxel_y*voxel_size; 107 | float pt_z = vol_origin[2]+voxel_z*voxel_size; 108 | // World coordinates to camera coordinates 109 | float tmp_pt_x = pt_x-cam_pose[0*4+3]; 110 | float tmp_pt_y = pt_y-cam_pose[1*4+3]; 111 | float tmp_pt_z = pt_z-cam_pose[2*4+3]; 112 | float cam_pt_x = cam_pose[0*4+0]*tmp_pt_x+cam_pose[1*4+0]*tmp_pt_y+cam_pose[2*4+0]*tmp_pt_z; 113 | float cam_pt_y = 
cam_pose[0*4+1]*tmp_pt_x+cam_pose[1*4+1]*tmp_pt_y+cam_pose[2*4+1]*tmp_pt_z; 114 | float cam_pt_z = cam_pose[0*4+2]*tmp_pt_x+cam_pose[1*4+2]*tmp_pt_y+cam_pose[2*4+2]*tmp_pt_z; 115 | // Camera coordinates to image pixels 116 | int pixel_x = (int) roundf(cam_intr[0*3+0]*(cam_pt_x/cam_pt_z)+cam_intr[0*3+2]); 117 | int pixel_y = (int) roundf(cam_intr[1*3+1]*(cam_pt_y/cam_pt_z)+cam_intr[1*3+2]); 118 | // Skip if outside view frustum 119 | int im_h = (int) other_params[2]; 120 | int im_w = (int) other_params[3]; 121 | if (pixel_x < 0 || pixel_x >= im_w || pixel_y < 0 || pixel_y >= im_h || cam_pt_z<0) 122 | return; 123 | // Skip invalid depth 124 | float depth_value = depth_im[pixel_y*im_w+pixel_x]; 125 | if (depth_value == 0) 126 | return; 127 | // Integrate TSDF 128 | float trunc_margin = other_params[4]; 129 | float depth_diff = depth_value-cam_pt_z; 130 | if (depth_diff < -trunc_margin) 131 | return; 132 | float dist = fmin(1.0f,depth_diff/trunc_margin); 133 | float w_old = weight_vol[voxel_idx]; 134 | float obs_weight = other_params[5]; 135 | float w_new = w_old + obs_weight; 136 | weight_vol[voxel_idx] = w_new; 137 | tsdf_vol[voxel_idx] = (tsdf_vol[voxel_idx]*w_old+obs_weight*dist)/w_new; 138 | // Integrate color 139 | float old_color = color_vol[voxel_idx]; 140 | float old_b = floorf(old_color/(256*256)); 141 | float old_g = floorf((old_color-old_b*256*256)/256); 142 | float old_r = old_color-old_b*256*256-old_g*256; 143 | float new_color = color_im[pixel_y*im_w+pixel_x]; 144 | float new_b = floorf(new_color/(256*256)); 145 | float new_g = floorf((new_color-new_b*256*256)/256); 146 | float new_r = new_color-new_b*256*256-new_g*256; 147 | new_b = fmin(roundf((old_b*w_old+obs_weight*new_b)/w_new),255.0f); 148 | new_g = fmin(roundf((old_g*w_old+obs_weight*new_g)/w_new),255.0f); 149 | new_r = fmin(roundf((old_r*w_old+obs_weight*new_r)/w_new),255.0f); 150 | color_vol[voxel_idx] = new_b*256*256+new_g*256+new_r; 151 | }""" 152 | ) 153 | 154 | self._cuda_integrate = self._cuda_src_mod.get_function("integrate") 155 | 156 | # Determine block/grid size on GPU 157 | gpu_dev = cuda.Device(0) 158 | self._max_gpu_threads_per_block = gpu_dev.MAX_THREADS_PER_BLOCK 159 | n_blocks = int( 160 | np.ceil( 161 | float(np.prod(self._vol_dim)) 162 | / float(self._max_gpu_threads_per_block) 163 | ) 164 | ) 165 | grid_dim_x = min(gpu_dev.MAX_GRID_DIM_X, int(np.floor(np.cbrt(n_blocks)))) 166 | grid_dim_y = min( 167 | gpu_dev.MAX_GRID_DIM_Y, int(np.floor(np.sqrt(n_blocks / grid_dim_x))) 168 | ) 169 | grid_dim_z = min( 170 | gpu_dev.MAX_GRID_DIM_Z, 171 | int(np.ceil(float(n_blocks) / float(grid_dim_x * grid_dim_y))), 172 | ) 173 | self._max_gpu_grid_dim = np.array( 174 | [grid_dim_x, grid_dim_y, grid_dim_z] 175 | ).astype(int) 176 | self._n_gpu_loops = int( 177 | np.ceil( 178 | float(np.prod(self._vol_dim)) 179 | / float( 180 | np.prod(self._max_gpu_grid_dim) 181 | * self._max_gpu_threads_per_block 182 | ) 183 | ) 184 | ) 185 | 186 | else: 187 | # Get voxel grid coordinates 188 | xv, yv, zv = np.meshgrid( 189 | range(self._vol_dim[0]), 190 | range(self._vol_dim[1]), 191 | range(self._vol_dim[2]), 192 | indexing="ij", 193 | ) 194 | self.vox_coords = ( 195 | np.concatenate( 196 | [xv.reshape(1, -1), yv.reshape(1, -1), zv.reshape(1, -1)], axis=0 197 | ) 198 | .astype(int) 199 | .T 200 | ) 201 | 202 | @staticmethod 203 | @njit(parallel=True) 204 | def vox2world(vol_origin, vox_coords, vox_size, offsets=(0.5, 0.5, 0.5)): 205 | """Convert voxel grid coordinates to world coordinates.""" 206 | vol_origin = 
vol_origin.astype(np.float32) 207 | vox_coords = vox_coords.astype(np.float32) 208 | # print(np.min(vox_coords)) 209 | cam_pts = np.empty_like(vox_coords, dtype=np.float32) 210 | 211 | for i in prange(vox_coords.shape[0]): 212 | for j in range(3): 213 | cam_pts[i, j] = ( 214 | vol_origin[j] 215 | + (vox_size * vox_coords[i, j]) 216 | + vox_size * offsets[j] 217 | ) 218 | return cam_pts 219 | 220 | @staticmethod 221 | @njit(parallel=True) 222 | def cam2pix(cam_pts, intr): 223 | """Convert camera coordinates to pixel coordinates.""" 224 | intr = intr.astype(np.float32) 225 | fx, fy = intr[0, 0], intr[1, 1] 226 | cx, cy = intr[0, 2], intr[1, 2] 227 | pix = np.empty((cam_pts.shape[0], 2), dtype=np.int64) 228 | for i in prange(cam_pts.shape[0]): 229 | pix[i, 0] = int(np.round((cam_pts[i, 0] * fx / cam_pts[i, 2]) + cx)) 230 | pix[i, 1] = int(np.round((cam_pts[i, 1] * fy / cam_pts[i, 2]) + cy)) 231 | return pix 232 | 233 | @staticmethod 234 | @njit(parallel=True) 235 | def integrate_tsdf(tsdf_vol, dist, w_old, obs_weight): 236 | """Integrate the TSDF volume.""" 237 | tsdf_vol_int = np.empty_like(tsdf_vol, dtype=np.float32) 238 | # print(tsdf_vol.shape) 239 | w_new = np.empty_like(w_old, dtype=np.float32) 240 | for i in prange(len(tsdf_vol)): 241 | w_new[i] = w_old[i] + obs_weight 242 | tsdf_vol_int[i] = (w_old[i] * tsdf_vol[i] + obs_weight * dist[i]) / w_new[i] 243 | return tsdf_vol_int, w_new 244 | 245 | def integrate(self, color_im, depth_im, cam_intr, cam_pose, obs_weight=1.0): 246 | """Integrate an RGB-D frame into the TSDF volume. 247 | 248 | Args: 249 | color_im (ndarray): An RGB image of shape (H, W, 3). 250 | depth_im (ndarray): A depth image of shape (H, W). 251 | cam_intr (ndarray): The camera intrinsics matrix of shape (3, 3). 252 | cam_pose (ndarray): The camera pose (i.e. extrinsics) of shape (4, 4). 253 | obs_weight (float): The weight to assign for the current observation. 
A higher 254 | value 255 | """ 256 | im_h, im_w = depth_im.shape 257 | 258 | # Fold RGB color image into a single channel image 259 | color_im = color_im.astype(np.float32) 260 | color_im = np.floor( 261 | color_im[..., 2] * self._color_const 262 | + color_im[..., 1] * 256 263 | + color_im[..., 0] 264 | ) 265 | 266 | if self.gpu_mode: # GPU mode: integrate voxel volume (calls CUDA kernel) 267 | for gpu_loop_idx in range(self._n_gpu_loops): 268 | self._cuda_integrate( 269 | self._tsdf_vol_gpu, 270 | self._weight_vol_gpu, 271 | self._color_vol_gpu, 272 | cuda.InOut(self._vol_dim.astype(np.float32)), 273 | cuda.InOut(self._vol_origin.astype(np.float32)), 274 | cuda.InOut(cam_intr.reshape(-1).astype(np.float32)), 275 | cuda.InOut(cam_pose.reshape(-1).astype(np.float32)), 276 | cuda.InOut( 277 | np.asarray( 278 | [ 279 | gpu_loop_idx, 280 | self._voxel_size, 281 | im_h, 282 | im_w, 283 | self._trunc_margin, 284 | obs_weight, 285 | ], 286 | np.float32, 287 | ) 288 | ), 289 | cuda.InOut(color_im.reshape(-1).astype(np.float32)), 290 | cuda.InOut(depth_im.reshape(-1).astype(np.float32)), 291 | block=(self._max_gpu_threads_per_block, 1, 1), 292 | grid=( 293 | int(self._max_gpu_grid_dim[0]), 294 | int(self._max_gpu_grid_dim[1]), 295 | int(self._max_gpu_grid_dim[2]), 296 | ), 297 | ) 298 | else: # CPU mode: integrate voxel volume (vectorized implementation) 299 | # Convert voxel grid coordinates to pixel coordinates 300 | cam_pts = self.vox2world( 301 | self._vol_origin, self.vox_coords, self._voxel_size 302 | ) 303 | cam_pts = rigid_transform(cam_pts, np.linalg.inv(cam_pose)) 304 | pix_z = cam_pts[:, 2] 305 | pix = self.cam2pix(cam_pts, cam_intr) 306 | pix_x, pix_y = pix[:, 0], pix[:, 1] 307 | 308 | # Eliminate pixels outside view frustum 309 | valid_pix = np.logical_and( 310 | pix_x >= 0, 311 | np.logical_and( 312 | pix_x < im_w, 313 | np.logical_and(pix_y >= 0, np.logical_and(pix_y < im_h, pix_z > 0)), 314 | ), 315 | ) 316 | depth_val = np.zeros(pix_x.shape) 317 | depth_val[valid_pix] = depth_im[pix_y[valid_pix], pix_x[valid_pix]] 318 | 319 | # Integrate TSDF 320 | depth_diff = depth_val - pix_z 321 | 322 | valid_pts = np.logical_and(depth_val > 0, depth_diff >= -10) 323 | dist = depth_diff 324 | 325 | valid_vox_x = self.vox_coords[valid_pts, 0] 326 | valid_vox_y = self.vox_coords[valid_pts, 1] 327 | valid_vox_z = self.vox_coords[valid_pts, 2] 328 | w_old = self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] 329 | tsdf_vals = self._tsdf_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] 330 | valid_dist = dist[valid_pts] 331 | tsdf_vol_new, w_new = self.integrate_tsdf( 332 | tsdf_vals, valid_dist, w_old, obs_weight 333 | ) 334 | self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = w_new 335 | self._tsdf_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = tsdf_vol_new 336 | 337 | # Integrate color 338 | old_color = self._color_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] 339 | old_b = np.floor(old_color / self._color_const) 340 | old_g = np.floor((old_color - old_b * self._color_const) / 256) 341 | old_r = old_color - old_b * self._color_const - old_g * 256 342 | new_color = color_im[pix_y[valid_pts], pix_x[valid_pts]] 343 | new_b = np.floor(new_color / self._color_const) 344 | new_g = np.floor((new_color - new_b * self._color_const) / 256) 345 | new_r = new_color - new_b * self._color_const - new_g * 256 346 | new_b = np.minimum( 347 | 255.0, np.round((w_old * old_b + obs_weight * new_b) / w_new) 348 | ) 349 | new_g = np.minimum( 350 | 255.0, np.round((w_old * old_g + obs_weight * 
new_g) / w_new) 351 | ) 352 | new_r = np.minimum( 353 | 255.0, np.round((w_old * old_r + obs_weight * new_r) / w_new) 354 | ) 355 | self._color_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = ( 356 | new_b * self._color_const + new_g * 256 + new_r 357 | ) 358 | 359 | def get_volume(self): 360 | if self.gpu_mode: 361 | cuda.memcpy_dtoh(self._tsdf_vol_cpu, self._tsdf_vol_gpu) 362 | cuda.memcpy_dtoh(self._color_vol_cpu, self._color_vol_gpu) 363 | return self._tsdf_vol_cpu, self._color_vol_cpu 364 | 365 | def get_point_cloud(self): 366 | """Extract a point cloud from the voxel volume.""" 367 | tsdf_vol, color_vol = self.get_volume() 368 | 369 | # Marching cubes 370 | verts = measure.marching_cubes_lewiner(tsdf_vol, level=0)[0] 371 | verts_ind = np.round(verts).astype(int) 372 | verts = verts * self._voxel_size + self._vol_origin 373 | 374 | # Get vertex colors 375 | rgb_vals = color_vol[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]] 376 | colors_b = np.floor(rgb_vals / self._color_const) 377 | colors_g = np.floor((rgb_vals - colors_b * self._color_const) / 256) 378 | colors_r = rgb_vals - colors_b * self._color_const - colors_g * 256 379 | colors = np.floor(np.asarray([colors_r, colors_g, colors_b])).T 380 | colors = colors.astype(np.uint8) 381 | 382 | pc = np.hstack([verts, colors]) 383 | return pc 384 | 385 | def get_mesh(self): 386 | """Compute a mesh from the voxel volume using marching cubes.""" 387 | tsdf_vol, color_vol = self.get_volume() 388 | 389 | # Marching cubes 390 | verts, faces, norms, vals = measure.marching_cubes_lewiner(tsdf_vol, level=0) 391 | verts_ind = np.round(verts).astype(int) 392 | verts = ( 393 | verts * self._voxel_size + self._vol_origin 394 | ) # voxel grid coordinates to world coordinates 395 | 396 | # Get vertex colors 397 | rgb_vals = color_vol[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]] 398 | colors_b = np.floor(rgb_vals / self._color_const) 399 | colors_g = np.floor((rgb_vals - colors_b * self._color_const) / 256) 400 | colors_r = rgb_vals - colors_b * self._color_const - colors_g * 256 401 | colors = np.floor(np.asarray([colors_r, colors_g, colors_b])).T 402 | colors = colors.astype(np.uint8) 403 | return verts, faces, norms, colors 404 | 405 | 406 | def rigid_transform(xyz, transform): 407 | """Applies a rigid transform to an (N, 3) pointcloud.""" 408 | xyz_h = np.hstack([xyz, np.ones((len(xyz), 1), dtype=np.float32)]) 409 | xyz_t_h = np.dot(transform, xyz_h.T).T 410 | return xyz_t_h[:, :3] 411 | 412 | 413 | def get_view_frustum(depth_im, cam_intr, cam_pose): 414 | """Get corners of 3D camera view frustum of depth image""" 415 | im_h = depth_im.shape[0] 416 | im_w = depth_im.shape[1] 417 | max_depth = np.max(depth_im) 418 | view_frust_pts = np.array( 419 | [ 420 | (np.array([0, 0, 0, im_w, im_w]) - cam_intr[0, 2]) 421 | * np.array([0, max_depth, max_depth, max_depth, max_depth]) 422 | / cam_intr[0, 0], 423 | (np.array([0, 0, im_h, 0, im_h]) - cam_intr[1, 2]) 424 | * np.array([0, max_depth, max_depth, max_depth, max_depth]) 425 | / cam_intr[1, 1], 426 | np.array([0, max_depth, max_depth, max_depth, max_depth]), 427 | ] 428 | ) 429 | view_frust_pts = rigid_transform(view_frust_pts.T, cam_pose).T 430 | return view_frust_pts 431 | 432 | 433 | def meshwrite(filename, verts, faces, norms, colors): 434 | """Save a 3D mesh to a polygon .ply file.""" 435 | # Write header 436 | ply_file = open(filename, "w") 437 | ply_file.write("ply\n") 438 | ply_file.write("format ascii 1.0\n") 439 | ply_file.write("element vertex %d\n" % (verts.shape[0])) 440 | 
ply_file.write("property float x\n") 441 | ply_file.write("property float y\n") 442 | ply_file.write("property float z\n") 443 | ply_file.write("property float nx\n") 444 | ply_file.write("property float ny\n") 445 | ply_file.write("property float nz\n") 446 | ply_file.write("property uchar red\n") 447 | ply_file.write("property uchar green\n") 448 | ply_file.write("property uchar blue\n") 449 | ply_file.write("element face %d\n" % (faces.shape[0])) 450 | ply_file.write("property list uchar int vertex_index\n") 451 | ply_file.write("end_header\n") 452 | 453 | # Write vertex list 454 | for i in range(verts.shape[0]): 455 | ply_file.write( 456 | "%f %f %f %f %f %f %d %d %d\n" 457 | % ( 458 | verts[i, 0], 459 | verts[i, 1], 460 | verts[i, 2], 461 | norms[i, 0], 462 | norms[i, 1], 463 | norms[i, 2], 464 | colors[i, 0], 465 | colors[i, 1], 466 | colors[i, 2], 467 | ) 468 | ) 469 | 470 | # Write face list 471 | for i in range(faces.shape[0]): 472 | ply_file.write("3 %d %d %d\n" % (faces[i, 0], faces[i, 1], faces[i, 2])) 473 | 474 | ply_file.close() 475 | 476 | 477 | def pcwrite(filename, xyzrgb): 478 | """Save a point cloud to a polygon .ply file.""" 479 | xyz = xyzrgb[:, :3] 480 | rgb = xyzrgb[:, 3:].astype(np.uint8) 481 | 482 | # Write header 483 | ply_file = open(filename, "w") 484 | ply_file.write("ply\n") 485 | ply_file.write("format ascii 1.0\n") 486 | ply_file.write("element vertex %d\n" % (xyz.shape[0])) 487 | ply_file.write("property float x\n") 488 | ply_file.write("property float y\n") 489 | ply_file.write("property float z\n") 490 | ply_file.write("property uchar red\n") 491 | ply_file.write("property uchar green\n") 492 | ply_file.write("property uchar blue\n") 493 | ply_file.write("end_header\n") 494 | 495 | # Write vertex list 496 | for i in range(xyz.shape[0]): 497 | ply_file.write( 498 | "%f %f %f %d %d %d\n" 499 | % ( 500 | xyz[i, 0], 501 | xyz[i, 1], 502 | xyz[i, 2], 503 | rgb[i, 0], 504 | rgb[i, 1], 505 | rgb[i, 2], 506 | ) 507 | ) 508 | -------------------------------------------------------------------------------- /iso/data/utils/helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import iso.data.utils.fusion as fusion 3 | import torch 4 | 5 | 6 | def compute_CP_mega_matrix(target, is_binary=False): 7 | """ 8 | Parameters 9 | --------- 10 | target: (H, W, D) 11 | contains voxels semantic labels 12 | 13 | is_binary: bool 14 | if True, return binary voxels relations else return 4-way relations 15 | """ 16 | label = target.reshape(-1) 17 | label_row = label 18 | N = label.shape[0] 19 | super_voxel_size = [i//2 for i in target.shape] 20 | if is_binary: 21 | matrix = np.zeros((2, N, super_voxel_size[0] * super_voxel_size[1] * super_voxel_size[2]), dtype=np.uint8) 22 | else: 23 | matrix = np.zeros((4, N, super_voxel_size[0] * super_voxel_size[1] * super_voxel_size[2]), dtype=np.uint8) 24 | 25 | for xx in range(super_voxel_size[0]): 26 | for yy in range(super_voxel_size[1]): 27 | for zz in range(super_voxel_size[2]): 28 | col_idx = xx * (super_voxel_size[1] * super_voxel_size[2]) + yy * super_voxel_size[2] + zz 29 | label_col_megas = np.array([ 30 | target[xx * 2, yy * 2, zz * 2], 31 | target[xx * 2 + 1, yy * 2, zz * 2], 32 | target[xx * 2, yy * 2 + 1, zz * 2], 33 | target[xx * 2, yy * 2, zz * 2 + 1], 34 | target[xx * 2 + 1, yy * 2 + 1, zz * 2], 35 | target[xx * 2 + 1, yy * 2, zz * 2 + 1], 36 | target[xx * 2, yy * 2 + 1, zz * 2 + 1], 37 | target[xx * 2 + 1, yy * 2 + 1, zz * 2 + 1], 38 | ]) 39 | label_col_megas = 
label_col_megas[label_col_megas != 255] 40 | for label_col_mega in label_col_megas: 41 | label_col = np.ones(N) * label_col_mega 42 | if not is_binary: 43 | matrix[0, (label_row != 255) & (label_col == label_row) & (label_col != 0), col_idx] = 1.0 # non non same 44 | matrix[1, (label_row != 255) & (label_col != label_row) & (label_col != 0) & (label_row != 0), col_idx] = 1.0 # non non diff 45 | matrix[2, (label_row != 255) & (label_row == label_col) & (label_col == 0), col_idx] = 1.0 # empty empty 46 | matrix[3, (label_row != 255) & (label_row != label_col) & ((label_row == 0) | (label_col == 0)), col_idx] = 1.0 # nonempty empty 47 | else: 48 | matrix[0, (label_row != 255) & (label_col != label_row), col_idx] = 1.0 # diff 49 | matrix[1, (label_row != 255) & (label_col == label_row), col_idx] = 1.0 # same 50 | return matrix 51 | 52 | 53 | def vox2pix(cam_E, cam_k, 54 | vox_origin, voxel_size, 55 | img_W, img_H, 56 | scene_size): 57 | """ 58 | compute the 2D projection of voxels centroids 59 | 60 | Parameters: 61 | ---------- 62 | cam_E: 4x4 63 | =camera pose in case of NYUv2 dataset 64 | =Transformation from camera to lidar coordinate in case of SemKITTI 65 | cam_k: 3x3 66 | camera intrinsics 67 | vox_origin: (3,) 68 | world(NYU)/lidar(SemKITTI) cooridnates of the voxel at index (0, 0, 0) 69 | img_W: int 70 | image width 71 | img_H: int 72 | image height 73 | scene_size: (3,) 74 | scene size in meter: (51.2, 51.2, 6.4) for SemKITTI and (4.8, 4.8, 2.88) for NYUv2 75 | 76 | Returns 77 | ------- 78 | projected_pix: (N, 2) 79 | Projected 2D positions of voxels 80 | fov_mask: (N,) 81 | Voxels mask indice voxels inside image's FOV 82 | pix_z: (N,) 83 | Voxels'distance to the sensor in meter 84 | """ 85 | # Compute the x, y, z bounding of the scene in meter 86 | vol_bnds = np.zeros((3,2)) 87 | vol_bnds[:,0] = vox_origin 88 | vol_bnds[:,1] = vox_origin + np.array(scene_size) 89 | 90 | 91 | # Compute the voxels centroids in lidar cooridnates 92 | # TODO: Make sure the around process has no influence on the NYUv2 result. 
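    # Added note (editor's reading of the TODO above, stated as an assumption):
    # np.around(..., 3) in the next line appears to guard against floating-point
    # error in the division, so a quotient that lands a hair above an integer is
    # not pushed up by np.ceil; the check that follows expects exactly a
    # 60 x 60 x 36 grid.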
93 | vol_dim = np.ceil(np.around((vol_bnds[:,1]- vol_bnds[:,0])/ voxel_size, 3)).copy(order='C').astype(int) 94 | if vol_dim[0] != 60 or vol_dim[1] != 60 or vol_dim[2] != 36: 95 | print("Find it:", vol_dim, '\n', vol_bnds, vol_bnds.dtype) 96 | exit(-1) 97 | xv, yv, zv = np.meshgrid( 98 | range(vol_dim[0]), 99 | range(vol_dim[1]), 100 | range(vol_dim[2]), 101 | indexing='ij' 102 | ) 103 | vox_coords = np.concatenate([ 104 | xv.reshape(1,-1), 105 | yv.reshape(1,-1), 106 | zv.reshape(1,-1) 107 | ], axis=0).astype(int).T 108 | 109 | # Project voxels'centroid from lidar coordinates to camera coordinates 110 | cam_pts = fusion.TSDFVolume.vox2world(vox_origin, vox_coords, voxel_size) 111 | cam_pts = fusion.rigid_transform(cam_pts, cam_E) 112 | 113 | # Project camera coordinates to pixel positions 114 | projected_pix = fusion.TSDFVolume.cam2pix(cam_pts, cam_k) 115 | pix_x, pix_y = projected_pix[:, 0], projected_pix[:, 1] 116 | 117 | # Eliminate pixels outside view frustum 118 | pix_z = cam_pts[:, 2] 119 | fov_mask = np.logical_and(pix_x >= 0, 120 | np.logical_and(pix_x < img_W, 121 | np.logical_and(pix_y >= 0, 122 | np.logical_and(pix_y < img_H, 123 | pix_z > 0)))) 124 | 125 | 126 | return projected_pix, fov_mask, pix_z 127 | 128 | 129 | def compute_local_frustum(pix_x, pix_y, min_x, max_x, min_y, max_y, pix_z): 130 | valid_pix = np.logical_and(pix_x >= min_x, 131 | np.logical_and(pix_x < max_x, 132 | np.logical_and(pix_y >= min_y, 133 | np.logical_and(pix_y < max_y, 134 | pix_z > 0)))) 135 | return valid_pix 136 | 137 | def compute_local_frustums(projected_pix, pix_z, target, img_W, img_H, dataset, n_classes, size=4): 138 | """ 139 | Compute the local frustums mask and their class frequencies 140 | 141 | Parameters: 142 | ---------- 143 | projected_pix: (N, 2) 144 | 2D projected pix of all voxels 145 | pix_z: (N,) 146 | Distance of the camera sensor to voxels 147 | target: (H, W, D) 148 | Voxelized sematic labels 149 | img_W: int 150 | Image width 151 | img_H: int 152 | Image height 153 | dataset: str 154 | ="NYU" or "kitti" (for both SemKITTI and KITTI-360) 155 | n_classes: int 156 | Number of classes (12 for NYU and 20 for SemKITTI) 157 | size: int 158 | determine the number of local frustums i.e. 
size * size 159 | 160 | Returns 161 | ------- 162 | frustums_masks: (n_frustums, N) 163 | List of frustums_masks, each indicates the belonging voxels 164 | frustums_class_dists: (n_frustums, n_classes) 165 | Contains the class frequencies in each frustum 166 | """ 167 | H, W, D = target.shape 168 | ranges = [(i * 1.0/size, (i * 1.0 + 1)/size) for i in range(size)] 169 | local_frustum_masks = [] 170 | local_frustum_class_dists = [] 171 | pix_x, pix_y = projected_pix[:, 0], projected_pix[:, 1] 172 | for y in ranges: 173 | for x in ranges: 174 | start_x = x[0] * img_W 175 | end_x = x[1] * img_W 176 | start_y = y[0] * img_H 177 | end_y = y[1] * img_H 178 | local_frustum = compute_local_frustum(pix_x, pix_y, start_x, end_x, start_y, end_y, pix_z) 179 | if dataset == "NYU": 180 | mask = (target != 255) & np.moveaxis(local_frustum.reshape(60, 60, 36), [0, 1, 2], [0, 2, 1]) 181 | elif dataset == "OccScanNet" or dataset == "OccScanNet_mini": 182 | mask = (target != 255) & local_frustum.reshape(H, W, D) 183 | 184 | local_frustum_masks.append(mask) 185 | classes, cnts = np.unique(target[mask], return_counts=True) 186 | class_counts = np.zeros(n_classes) 187 | class_counts[classes.astype(int)] = cnts 188 | local_frustum_class_dists.append(class_counts) 189 | frustums_masks, frustums_class_dists = np.array(local_frustum_masks), np.array(local_frustum_class_dists) 190 | return frustums_masks, frustums_class_dists 191 | -------------------------------------------------------------------------------- /iso/data/utils/torch_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def worker_init_fn(worker_id): 6 | """The function is designed for pytorch multi-process dataloader. 7 | Note that we use the pytorch random generator to generate a base_seed. 8 | Please try to be consistent. 
9 | 10 | References: 11 | https://pytorch.org/docs/stable/notes/faq.html#dataloader-workers-random-seed 12 | 13 | """ 14 | base_seed = torch.IntTensor(1).random_().item() 15 | np.random.seed(base_seed + worker_id) 16 | -------------------------------------------------------------------------------- /iso/loss/CRP_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def compute_super_CP_multilabel_loss(pred_logits, CP_mega_matrices): 5 | logits = [] 6 | labels = [] 7 | bs, n_relations, _, _ = pred_logits.shape 8 | for i in range(bs): 9 | pred_logit = pred_logits[i, :, :, :].permute( 10 | 0, 2, 1 11 | ) # n_relations, N, n_mega_voxels 12 | CP_mega_matrix = CP_mega_matrices[i] # n_relations, N, n_mega_voxels 13 | logits.append(pred_logit.reshape(n_relations, -1)) 14 | labels.append(CP_mega_matrix.reshape(n_relations, -1)) 15 | 16 | logits = torch.cat(logits, dim=1).T # M, 4 17 | labels = torch.cat(labels, dim=1).T # M, 4 18 | 19 | cnt_neg = (labels == 0).sum(0) 20 | cnt_pos = labels.sum(0) 21 | pos_weight = cnt_neg / (cnt_pos + 1e-5) 22 | criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight) 23 | loss_bce = criterion(logits, labels.float()) 24 | return loss_bce 25 | -------------------------------------------------------------------------------- /iso/loss/depth_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.cuda.amp import autocast 3 | import torch.nn.functional as F 4 | 5 | from iso.models.modules import bin_depths 6 | 7 | 8 | class DepthClsLoss: 9 | def __init__(self, downsample_factor, d_bound, depth_channels=64): 10 | self.downsample_factor = downsample_factor 11 | self.depth_channels = depth_channels 12 | self.d_bound = d_bound 13 | # self.depth_channels = int((self.d_bound[1] - self.d_bound[0]) / self.d_bound[2]) 14 | 15 | def _get_downsampled_gt_depth(self, gt_depths): 16 | """ 17 | Input: 18 | gt_depths: [B, N, H, W] 19 | Output: 20 | gt_depths: [B*N*h*w, d] 21 | """ 22 | B, N, H, W = gt_depths.shape 23 | gt_depths = gt_depths.view( 24 | B * N, 25 | H // self.downsample_factor, 26 | self.downsample_factor, 27 | W // self.downsample_factor, 28 | self.downsample_factor, 29 | 1, 30 | ) 31 | gt_depths = gt_depths.permute(0, 1, 3, 5, 2, 4).contiguous() 32 | gt_depths = gt_depths.view(-1, self.downsample_factor * self.downsample_factor) 33 | 34 | gt_depths_tmp = torch.where( 35 | gt_depths == 0.0, 1e5 * torch.ones_like(gt_depths), gt_depths 36 | ) 37 | gt_depths = torch.min(gt_depths_tmp, dim=-1).values 38 | gt_depths = gt_depths.view( 39 | B * N, H // self.downsample_factor, W // self.downsample_factor 40 | ) 41 | 42 | # ds = torch.arange(64) # (64,) 43 | # depth_bin_pos = (10.0 / 64 / 65 * ds * (ds + 1)).reshape(1, 64, 1, 1) 44 | # # print(gt_depths.unsqueeze(1).shape) 45 | # # print(depth_bin_pos.shape) 46 | # delta_z = torch.abs(gt_depths.unsqueeze(1) - depth_bin_pos.to(gt_depths.device)) 47 | # gt_depths = torch.argmin(delta_z, dim=1) 48 | gt_depths = bin_depths(gt_depths, mode='LID', depth_min=0, depth_max=10, num_bins=self.depth_channels, target=True) 49 | 50 | gt_depths = torch.where( 51 | (gt_depths < self.depth_channels + 1) & (gt_depths >= 0.0), 52 | gt_depths, 53 | torch.zeros_like(gt_depths), 54 | ) 55 | gt_depths = F.one_hot( 56 | gt_depths.long(), num_classes=self.depth_channels + 1 57 | ).view(-1, self.depth_channels + 1)[:, :-1] 58 | 59 | return gt_depths.float() 60 | 61 | def get_depth_loss(self, depth_labels, depth_preds): 62 
| if len(depth_labels.shape) != 4: 63 | depth_labels = depth_labels.unsqueeze(1) 64 | # print(depth_labels.shape) 65 | N_pred, n_cam_pred, D, H, W = depth_preds.shape 66 | N_gt, n_cam_label, oriH, oriW = depth_labels.shape 67 | assert ( 68 | N_pred * n_cam_pred == N_gt * n_cam_label 69 | ), f"N_pred: {N_pred}, n_cam_pred: {n_cam_pred}, N_gt: {N_gt}, n_cam_label: {n_cam_label}" 70 | depth_labels = depth_labels.reshape(N_gt * n_cam_label, oriH, oriW) 71 | depth_preds = depth_preds.reshape(N_pred * n_cam_pred, D, H, W) 72 | 73 | # depth_labels = depth_labels.reshape( 74 | # N 75 | # ) 76 | 77 | # depth_labels = depth_labels.unsqueeze(1) 78 | # depth_labels = depth_labels 79 | depth_labels = F.interpolate( 80 | depth_labels.unsqueeze(1), 81 | (H * self.downsample_factor, W * self.downsample_factor), 82 | mode="nearest", 83 | ) 84 | depth_labels = self._get_downsampled_gt_depth(depth_labels) 85 | depth_preds = ( 86 | depth_preds.permute(0, 2, 3, 1).contiguous().view(-1, self.depth_channels) 87 | ) 88 | fg_mask = torch.max(depth_labels, dim=1).values > 0.0 89 | 90 | with autocast(enabled=False): 91 | depth_loss = F.binary_cross_entropy( 92 | depth_preds[fg_mask], 93 | depth_labels[fg_mask], 94 | reduction="none", 95 | ).sum() / max(1.0, fg_mask.sum()) 96 | # depth_loss = torch.nan_to_num(depth_loss, nan=0.) 97 | 98 | return depth_loss 99 | -------------------------------------------------------------------------------- /iso/loss/sscMetrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Part of the code is taken from https://github.com/waterljwant/SSC/blob/master/sscMetrics.py 3 | """ 4 | import numpy as np 5 | from sklearn.metrics import accuracy_score, precision_recall_fscore_support 6 | 7 | 8 | def get_iou(iou_sum, cnt_class): 9 | _C = iou_sum.shape[0] # 12 10 | iou = np.zeros(_C, dtype=np.float32) # iou for each class 11 | for idx in range(_C): 12 | iou[idx] = iou_sum[idx] / cnt_class[idx] if cnt_class[idx] else 0 13 | 14 | mean_iou = np.sum(iou[1:]) / np.count_nonzero(cnt_class[1:]) 15 | return iou, mean_iou 16 | 17 | 18 | def get_accuracy(predict, target, weight=None): # 0.05s 19 | _bs = predict.shape[0] # batch size 20 | _C = predict.shape[1] # _C = 12 21 | target = np.int32(target) 22 | target = target.reshape(_bs, -1) # (_bs, 60*36*60) 129600 23 | predict = predict.reshape(_bs, _C, -1) # (_bs, _C, 60*36*60) 24 | predict = np.argmax( 25 | predict, axis=1 26 | ) # one-hot: _bs x _C x 60*36*60 --> label: _bs x 60*36*60. 
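    # Added illustration of the argmax collapse above (hypothetical shapes):
    #   predict = np.array([[[0.1, 0.7], [0.8, 0.2], [0.1, 0.1]]])  # (_bs=1, _C=3, 2 voxels)
    #   np.argmax(predict, axis=1)  ->  array([[1, 0]])             # per-voxel class labels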
27 | 28 | correct = predict == target # (_bs, 129600) 29 | if weight: # 0.04s, add class weights 30 | weight_k = np.ones(target.shape) 31 | for i in range(_bs): 32 | for n in range(target.shape[1]): 33 | idx = 0 if target[i, n] == 255 else target[i, n] 34 | weight_k[i, n] = weight[idx] 35 | correct = correct * weight_k 36 | acc = correct.sum() / correct.size 37 | return acc 38 | 39 | 40 | class SSCMetrics: 41 | def __init__(self, n_classes): 42 | self.n_classes = n_classes 43 | self.reset() 44 | 45 | def hist_info(self, n_cl, pred, gt): 46 | assert pred.shape == gt.shape 47 | k = (gt >= 0) & (gt < n_cl) # exclude 255 48 | labeled = np.sum(k) 49 | correct = np.sum((pred[k] == gt[k])) 50 | 51 | return ( 52 | np.bincount( 53 | n_cl * gt[k].astype(int) + pred[k].astype(int), minlength=n_cl ** 2 54 | ).reshape(n_cl, n_cl), 55 | correct, 56 | labeled, 57 | ) 58 | 59 | @staticmethod 60 | def compute_score(hist, correct, labeled): 61 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 62 | mean_IU = np.nanmean(iu) 63 | mean_IU_no_back = np.nanmean(iu[1:]) 64 | freq = hist.sum(1) / hist.sum() 65 | freq_IU = (iu[freq > 0] * freq[freq > 0]).sum() 66 | mean_pixel_acc = correct / labeled if labeled != 0 else 0 67 | 68 | return iu, mean_IU, mean_IU_no_back, mean_pixel_acc 69 | 70 | def add_batch(self, y_pred, y_true, nonempty=None, nonsurface=None): 71 | self.count += 1 72 | mask = y_true != 255 73 | if nonempty is not None: 74 | mask = mask & nonempty 75 | if nonsurface is not None: 76 | mask = mask & nonsurface 77 | tp, fp, fn = self.get_score_completion(y_pred, y_true, mask) 78 | 79 | self.completion_tp += tp 80 | self.completion_fp += fp 81 | self.completion_fn += fn 82 | 83 | mask = y_true != 255 84 | if nonempty is not None: 85 | mask = mask & nonempty 86 | tp_sum, fp_sum, fn_sum = self.get_score_semantic_and_completion( 87 | y_pred, y_true, mask 88 | ) 89 | self.tps += tp_sum 90 | self.fps += fp_sum 91 | self.fns += fn_sum 92 | 93 | def get_stats(self): 94 | if self.completion_tp != 0: 95 | precision = self.completion_tp / (self.completion_tp + self.completion_fp) 96 | recall = self.completion_tp / (self.completion_tp + self.completion_fn) 97 | iou = self.completion_tp / ( 98 | self.completion_tp + self.completion_fp + self.completion_fn 99 | ) 100 | else: 101 | precision, recall, iou = 0, 0, 0 102 | iou_ssc = self.tps / (self.tps + self.fps + self.fns + 1e-5) 103 | return { 104 | "precision": precision, 105 | "recall": recall, 106 | "iou": iou, 107 | "iou_ssc": iou_ssc, 108 | "iou_ssc_mean": np.mean(iou_ssc[1:]), 109 | } 110 | 111 | def reset(self): 112 | 113 | self.completion_tp = 0 114 | self.completion_fp = 0 115 | self.completion_fn = 0 116 | self.tps = np.zeros(self.n_classes) 117 | self.fps = np.zeros(self.n_classes) 118 | self.fns = np.zeros(self.n_classes) 119 | 120 | self.hist_ssc = np.zeros((self.n_classes, self.n_classes)) 121 | self.labeled_ssc = 0 122 | self.correct_ssc = 0 123 | 124 | self.precision = 0 125 | self.recall = 0 126 | self.iou = 0 127 | self.count = 1e-8 128 | self.iou_ssc = np.zeros(self.n_classes, dtype=np.float32) 129 | self.cnt_class = np.zeros(self.n_classes, dtype=np.float32) 130 | 131 | def get_score_completion(self, predict, target, nonempty=None): 132 | predict = np.copy(predict) 133 | target = np.copy(target) 134 | 135 | """for scene completion, treat the task as two-classes problem, just empty or occupancy""" 136 | _bs = predict.shape[0] # batch size 137 | # ---- ignore 138 | predict[target == 255] = 0 139 | target[target == 255] = 0 140 | # 
---- flatten 141 | target = target.reshape(_bs, -1) # (_bs, 129600) 142 | predict = predict.reshape(_bs, -1) # (_bs, _C, 129600), 60*36*60=129600 143 | # ---- treat all non-empty object class as one category, set them to label 1 144 | b_pred = np.zeros(predict.shape) 145 | b_true = np.zeros(target.shape) 146 | b_pred[predict > 0] = 1 147 | b_true[target > 0] = 1 148 | p, r, iou = 0.0, 0.0, 0.0 149 | tp_sum, fp_sum, fn_sum = 0, 0, 0 150 | for idx in range(_bs): 151 | y_true = b_true[idx, :] # GT 152 | y_pred = b_pred[idx, :] 153 | if nonempty is not None: 154 | nonempty_idx = nonempty[idx, :].reshape(-1) 155 | y_true = y_true[nonempty_idx == 1] 156 | y_pred = y_pred[nonempty_idx == 1] 157 | 158 | tp = np.array(np.where(np.logical_and(y_true == 1, y_pred == 1))).size 159 | fp = np.array(np.where(np.logical_and(y_true != 1, y_pred == 1))).size 160 | fn = np.array(np.where(np.logical_and(y_true == 1, y_pred != 1))).size 161 | tp_sum += tp 162 | fp_sum += fp 163 | fn_sum += fn 164 | return tp_sum, fp_sum, fn_sum 165 | 166 | def get_score_semantic_and_completion(self, predict, target, nonempty=None): 167 | target = np.copy(target) 168 | predict = np.copy(predict) 169 | _bs = predict.shape[0] # batch size 170 | _C = self.n_classes # _C = 12 171 | # ---- ignore 172 | predict[target == 255] = 0 173 | target[target == 255] = 0 174 | # ---- flatten 175 | target = target.reshape(_bs, -1) # (_bs, 129600) 176 | predict = predict.reshape(_bs, -1) # (_bs, 129600), 60*36*60=129600 177 | 178 | cnt_class = np.zeros(_C, dtype=np.int32) # count for each class 179 | iou_sum = np.zeros(_C, dtype=np.float32) # sum of iou for each class 180 | tp_sum = np.zeros(_C, dtype=np.int32) # tp 181 | fp_sum = np.zeros(_C, dtype=np.int32) # fp 182 | fn_sum = np.zeros(_C, dtype=np.int32) # fn 183 | 184 | for idx in range(_bs): 185 | y_true = target[idx, :] # GT 186 | y_pred = predict[idx, :] 187 | if nonempty is not None: 188 | nonempty_idx = nonempty[idx, :].reshape(-1) 189 | y_pred = y_pred[ 190 | np.where(np.logical_and(nonempty_idx == 1, y_true != 255)) 191 | ] 192 | y_true = y_true[ 193 | np.where(np.logical_and(nonempty_idx == 1, y_true != 255)) 194 | ] 195 | for j in range(_C): # for each class 196 | tp = np.array(np.where(np.logical_and(y_true == j, y_pred == j))).size 197 | fp = np.array(np.where(np.logical_and(y_true != j, y_pred == j))).size 198 | fn = np.array(np.where(np.logical_and(y_true == j, y_pred != j))).size 199 | 200 | tp_sum[j] += tp 201 | fp_sum[j] += fp 202 | fn_sum[j] += fn 203 | 204 | return tp_sum, fp_sum, fn_sum 205 | -------------------------------------------------------------------------------- /iso/loss/ssc_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def KL_sep(p, target): 7 | """ 8 | KL divergence on nonzeros classes 9 | """ 10 | nonzeros = target != 0 11 | nonzero_p = p[nonzeros] 12 | kl_term = F.kl_div(torch.log(nonzero_p), target[nonzeros], reduction="sum") 13 | return kl_term 14 | 15 | 16 | def geo_scal_loss(pred, ssc_target): 17 | 18 | # Get softmax probabilities 19 | pred = F.softmax(pred, dim=1) 20 | 21 | # Compute empty and nonempty probabilities 22 | empty_probs = pred[:, 0, :, :, :] 23 | nonempty_probs = 1 - empty_probs 24 | 25 | # Remove unknown voxels 26 | mask = ssc_target != 255 27 | nonempty_target = ssc_target != 0 28 | nonempty_target = nonempty_target[mask].float() 29 | nonempty_probs = nonempty_probs[mask] 30 | empty_probs = 
empty_probs[mask] 31 | 32 | intersection = (nonempty_target * nonempty_probs).sum() 33 | precision = intersection / nonempty_probs.sum() 34 | recall = intersection / nonempty_target.sum() 35 | spec = ((1 - nonempty_target) * (empty_probs)).sum() / (1 - nonempty_target).sum() 36 | return ( 37 | F.binary_cross_entropy(precision, torch.ones_like(precision)) 38 | + F.binary_cross_entropy(recall, torch.ones_like(recall)) 39 | + F.binary_cross_entropy(spec, torch.ones_like(spec)) 40 | ) 41 | 42 | 43 | def sem_scal_loss(pred, ssc_target): 44 | # Get softmax probabilities 45 | pred = F.softmax(pred, dim=1) 46 | loss = 0 47 | count = 0 48 | mask = ssc_target != 255 49 | n_classes = pred.shape[1] 50 | for i in range(0, n_classes): 51 | 52 | # Get probability of class i 53 | p = pred[:, i, :, :, :] 54 | 55 | # Remove unknown voxels 56 | target_ori = ssc_target 57 | p = p[mask] 58 | target = ssc_target[mask] 59 | 60 | completion_target = torch.ones_like(target) 61 | completion_target[target != i] = 0 62 | completion_target_ori = torch.ones_like(target_ori).float() 63 | completion_target_ori[target_ori != i] = 0 64 | if torch.sum(completion_target) > 0: 65 | count += 1.0 66 | nominator = torch.sum(p * completion_target) 67 | loss_class = 0 68 | if torch.sum(p) > 0: 69 | precision = nominator / (torch.sum(p)) 70 | loss_precision = F.binary_cross_entropy( 71 | precision, torch.ones_like(precision) 72 | ) 73 | loss_class += loss_precision 74 | if torch.sum(completion_target) > 0: 75 | recall = nominator / (torch.sum(completion_target)) 76 | loss_recall = F.binary_cross_entropy(recall, torch.ones_like(recall)) 77 | loss_class += loss_recall 78 | if torch.sum(1 - completion_target) > 0: 79 | specificity = torch.sum((1 - p) * (1 - completion_target)) / ( 80 | torch.sum(1 - completion_target) 81 | ) 82 | loss_specificity = F.binary_cross_entropy( 83 | specificity, torch.ones_like(specificity) 84 | ) 85 | loss_class += loss_specificity 86 | loss += loss_class 87 | return loss / count 88 | 89 | 90 | def CE_ssc_loss(pred, target, class_weights): 91 | """ 92 | :param: prediction: the predicted tensor, must be [BS, C, H, W, D] 93 | """ 94 | criterion = nn.CrossEntropyLoss( 95 | weight=class_weights, ignore_index=255, reduction="mean" 96 | ) 97 | loss = criterion(pred, target.long()) 98 | 99 | return loss 100 | -------------------------------------------------------------------------------- /iso/models/CRP3D.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from iso.models.modules import ( 4 | Process, 5 | ASPP, 6 | ) 7 | 8 | 9 | class CPMegaVoxels(nn.Module): 10 | def __init__(self, feature, size, n_relations=4, bn_momentum=0.0003): 11 | super().__init__() 12 | self.size = size 13 | self.n_relations = n_relations 14 | print("n_relations", self.n_relations) 15 | self.flatten_size = size[0] * size[1] * size[2] 16 | self.feature = feature 17 | self.context_feature = feature * 2 18 | self.flatten_context_size = (size[0] // 2) * (size[1] // 2) * (size[2] // 2) 19 | padding = ((size[0] + 1) % 2, (size[1] + 1) % 2, (size[2] + 1) % 2) 20 | 21 | self.mega_context = nn.Sequential( 22 | nn.Conv3d( 23 | feature, self.context_feature, stride=2, padding=padding, kernel_size=3 24 | ), 25 | ) 26 | self.flatten_context_size = (size[0] // 2) * (size[1] // 2) * (size[2] // 2) 27 | 28 | self.context_prior_logits = nn.ModuleList( 29 | [ 30 | nn.Sequential( 31 | nn.Conv3d( 32 | self.feature, 33 | self.flatten_context_size, 34 | padding=0, 35 | 
kernel_size=1, 36 | ), 37 | ) 38 | for i in range(n_relations) 39 | ] 40 | ) 41 | self.aspp = ASPP(feature, [1, 2, 3]) 42 | 43 | self.resize = nn.Sequential( 44 | nn.Conv3d( 45 | self.context_feature * self.n_relations + feature, 46 | feature, 47 | kernel_size=1, 48 | padding=0, 49 | bias=False, 50 | ), 51 | Process(feature, nn.BatchNorm3d, bn_momentum, dilations=[1]), 52 | ) 53 | 54 | def forward(self, input): 55 | ret = {} 56 | bs = input.shape[0] 57 | 58 | x_agg = self.aspp(input) 59 | 60 | # get the mega context 61 | x_mega_context_raw = self.mega_context(x_agg) 62 | x_mega_context = x_mega_context_raw.reshape(bs, self.context_feature, -1) 63 | x_mega_context = x_mega_context.permute(0, 2, 1) 64 | 65 | # get context prior map 66 | x_context_prior_logits = [] 67 | x_context_rels = [] 68 | for rel in range(self.n_relations): 69 | 70 | # Compute the relation matrices 71 | x_context_prior_logit = self.context_prior_logits[rel](x_agg) 72 | x_context_prior_logit = x_context_prior_logit.reshape( 73 | bs, self.flatten_context_size, self.flatten_size 74 | ) 75 | x_context_prior_logits.append(x_context_prior_logit.unsqueeze(1)) 76 | 77 | x_context_prior_logit = x_context_prior_logit.permute(0, 2, 1) 78 | x_context_prior = torch.sigmoid(x_context_prior_logit) 79 | 80 | # Multiply the relation matrices with the mega context to gather context features 81 | x_context_rel = torch.bmm(x_context_prior, x_mega_context) # bs, N, f 82 | x_context_rels.append(x_context_rel) 83 | 84 | x_context = torch.cat(x_context_rels, dim=2) 85 | x_context = x_context.permute(0, 2, 1) 86 | x_context = x_context.reshape( 87 | bs, x_context.shape[1], self.size[0], self.size[1], self.size[2] 88 | ) 89 | 90 | x = torch.cat([input, x_context], dim=1) 91 | x = self.resize(x) 92 | 93 | x_context_prior_logits = torch.cat(x_context_prior_logits, dim=1) 94 | ret["P_logits"] = x_context_prior_logits 95 | ret["x"] = x 96 | 97 | return ret 98 | -------------------------------------------------------------------------------- /iso/models/DDR.py: -------------------------------------------------------------------------------- 1 | """ 2 | Most of the code in this file is taken from https://github.com/waterljwant/SSC/blob/master/models/DDR.py 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class SimpleRB(nn.Module): 11 | def __init__(self, in_channel, norm_layer, bn_momentum): 12 | super(SimpleRB, self).__init__() 13 | self.path = nn.Sequential( 14 | nn.Conv3d(in_channel, in_channel, kernel_size=3, padding=1, bias=False), 15 | norm_layer(in_channel, momentum=bn_momentum), 16 | nn.ReLU(), 17 | nn.Conv3d(in_channel, in_channel, kernel_size=3, padding=1, bias=False), 18 | norm_layer(in_channel, momentum=bn_momentum), 19 | ) 20 | self.relu = nn.ReLU() 21 | 22 | def forward(self, x): 23 | residual = x 24 | conv_path = self.path(x) 25 | out = residual + conv_path 26 | out = self.relu(out) 27 | return out 28 | 29 | 30 | """ 31 | 3D Residual Block,3x3x3 conv ==> 3 smaller 3D conv, refered from DDRNet 32 | """ 33 | 34 | 35 | class Bottleneck3D(nn.Module): 36 | def __init__( 37 | self, 38 | inplanes, 39 | planes, 40 | norm_layer, 41 | stride=1, 42 | dilation=[1, 1, 1], 43 | expansion=4, 44 | downsample=None, 45 | fist_dilation=1, 46 | multi_grid=1, 47 | bn_momentum=0.0003, 48 | ): 49 | super(Bottleneck3D, self).__init__() 50 | # often,planes = inplanes // 4 51 | self.expansion = expansion 52 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) 53 | self.bn1 = 
norm_layer(planes, momentum=bn_momentum) 54 | self.conv2 = nn.Conv3d( 55 | planes, 56 | planes, 57 | kernel_size=(1, 1, 3), 58 | stride=(1, 1, stride), 59 | dilation=(1, 1, dilation[0]), 60 | padding=(0, 0, dilation[0]), 61 | bias=False, 62 | ) 63 | self.bn2 = norm_layer(planes, momentum=bn_momentum) 64 | self.conv3 = nn.Conv3d( 65 | planes, 66 | planes, 67 | kernel_size=(1, 3, 1), 68 | stride=(1, stride, 1), 69 | dilation=(1, dilation[1], 1), 70 | padding=(0, dilation[1], 0), 71 | bias=False, 72 | ) 73 | self.bn3 = norm_layer(planes, momentum=bn_momentum) 74 | self.conv4 = nn.Conv3d( 75 | planes, 76 | planes, 77 | kernel_size=(3, 1, 1), 78 | stride=(stride, 1, 1), 79 | dilation=(dilation[2], 1, 1), 80 | padding=(dilation[2], 0, 0), 81 | bias=False, 82 | ) 83 | self.bn4 = norm_layer(planes, momentum=bn_momentum) 84 | self.conv5 = nn.Conv3d( 85 | planes, planes * self.expansion, kernel_size=(1, 1, 1), bias=False 86 | ) 87 | self.bn5 = norm_layer(planes * self.expansion, momentum=bn_momentum) 88 | 89 | self.relu = nn.ReLU(inplace=False) 90 | self.relu_inplace = nn.ReLU(inplace=True) 91 | self.downsample = downsample 92 | self.dilation = dilation 93 | self.stride = stride 94 | 95 | self.downsample2 = nn.Sequential( 96 | nn.AvgPool3d(kernel_size=(1, stride, 1), stride=(1, stride, 1)), 97 | nn.Conv3d(planes, planes, kernel_size=1, stride=1, bias=False), 98 | norm_layer(planes, momentum=bn_momentum), 99 | ) 100 | self.downsample3 = nn.Sequential( 101 | nn.AvgPool3d(kernel_size=(stride, 1, 1), stride=(stride, 1, 1)), 102 | nn.Conv3d(planes, planes, kernel_size=1, stride=1, bias=False), 103 | norm_layer(planes, momentum=bn_momentum), 104 | ) 105 | self.downsample4 = nn.Sequential( 106 | nn.AvgPool3d(kernel_size=(stride, 1, 1), stride=(stride, 1, 1)), 107 | nn.Conv3d(planes, planes, kernel_size=1, stride=1, bias=False), 108 | norm_layer(planes, momentum=bn_momentum), 109 | ) 110 | 111 | def forward(self, x): 112 | residual = x 113 | 114 | out1 = self.relu(self.bn1(self.conv1(x))) 115 | out2 = self.bn2(self.conv2(out1)) 116 | out2_relu = self.relu(out2) 117 | 118 | out3 = self.bn3(self.conv3(out2_relu)) 119 | if self.stride != 1: 120 | out2 = self.downsample2(out2) 121 | out3 = out3 + out2 122 | out3_relu = self.relu(out3) 123 | 124 | out4 = self.bn4(self.conv4(out3_relu)) 125 | if self.stride != 1: 126 | out2 = self.downsample3(out2) 127 | out3 = self.downsample4(out3) 128 | out4 = out4 + out2 + out3 129 | 130 | out4_relu = self.relu(out4) 131 | out5 = self.bn5(self.conv5(out4_relu)) 132 | 133 | if self.downsample is not None: 134 | residual = self.downsample(x) 135 | 136 | out = out5 + residual 137 | out_relu = self.relu(out) 138 | 139 | return out_relu 140 | -------------------------------------------------------------------------------- /iso/models/depthnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | # from mmdet.models.backbones.resnet import BasicBlock 5 | 6 | import math 7 | # from occdepth.models.f2v.frustum_grid_generator import FrustumGridGenerator 8 | # from occdepth.models.f2v.frustum_to_voxel import FrustumToVoxel 9 | # from occdepth.models.f2v.sampler import Sampler 10 | 11 | class BasicBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, 15 | inplanes, 16 | planes, 17 | stride=1, 18 | dilation=1, 19 | downsample=None, 20 | # style='pytorch', 21 | with_cp=False, 22 | # conv_cfg=None, 23 | # norm_cfg=dict(type='BN'), 24 | dcn=None, 25 | plugins=None, 
26 | init_cfg=None): 27 | super(BasicBlock, self).__init__() 28 | assert dcn is None, 'Not implemented yet.' 29 | assert plugins is None, 'Not implemented yet.' 30 | 31 | self.norm1 = nn.BatchNorm2d(planes) 32 | self.norm2 = nn.BatchNorm2d(planes) 33 | 34 | self.conv1 = nn.Conv2d( 35 | inplanes, 36 | planes, 37 | 3, 38 | stride=stride, 39 | padding=dilation, 40 | dilation=dilation, 41 | bias=False) 42 | 43 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 44 | 45 | self.relu = nn.ReLU(inplace=True) 46 | self.downsample = downsample 47 | self.stride = stride 48 | self.dilation = dilation 49 | self.with_cp = with_cp 50 | 51 | @property 52 | def norm1(self): 53 | """nn.Module: normalization layer after the first convolution layer""" 54 | return getattr(self, self.norm1_name) 55 | 56 | @property 57 | def norm2(self): 58 | """nn.Module: normalization layer after the second convolution layer""" 59 | return getattr(self, self.norm2_name) 60 | 61 | def forward(self, x): 62 | """Forward function.""" 63 | # print(x.shape) 64 | 65 | def _inner_forward(x): 66 | identity = x 67 | 68 | out = self.conv1(x) 69 | out = self.norm1(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv2(out) 73 | out = self.norm2(out) 74 | 75 | if self.downsample is not None: 76 | identity = self.downsample(x) 77 | 78 | out += identity 79 | 80 | return out 81 | 82 | out = _inner_forward(x) 83 | 84 | out = self.relu(out) 85 | 86 | return out 87 | 88 | 89 | class Mlp(nn.Module): 90 | def __init__( 91 | self, 92 | in_features, 93 | hidden_features=None, 94 | out_features=None, 95 | act_layer=nn.ReLU, 96 | drop=0.0, 97 | ): 98 | super().__init__() 99 | out_features = out_features or in_features 100 | hidden_features = hidden_features or in_features 101 | self.fc1 = nn.Linear(in_features, hidden_features) 102 | self.act = act_layer() 103 | self.drop1 = nn.Dropout(drop) 104 | self.fc2 = nn.Linear(hidden_features, out_features) 105 | self.drop2 = nn.Dropout(drop) 106 | 107 | def forward(self, x): 108 | x = self.fc1(x) 109 | x = self.act(x) 110 | x = self.drop1(x) 111 | x = self.fc2(x) 112 | x = self.drop2(x) 113 | return x 114 | 115 | 116 | class SELayer(nn.Module): 117 | def __init__(self, channels, act_layer=nn.ReLU, gate_layer=nn.Sigmoid): 118 | super().__init__() 119 | self.conv_reduce = nn.Conv2d(channels, channels, 1, bias=True) 120 | self.act1 = act_layer() 121 | self.conv_expand = nn.Conv2d(channels, channels, 1, bias=True) 122 | self.gate = gate_layer() 123 | 124 | def forward(self, x, x_se): 125 | x_se = self.conv_reduce(x_se) 126 | x_se = self.act1(x_se) 127 | x_se = self.conv_expand(x_se) 128 | return x * self.gate(x_se) 129 | 130 | 131 | class DepthNet(nn.Module): 132 | def __init__( 133 | self, 134 | in_channels, 135 | mid_channels, 136 | context_channels, 137 | depth_channels, 138 | infer_mode=False, 139 | ): 140 | super(DepthNet, self).__init__() 141 | self.reduce_conv = nn.Sequential( 142 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1), 143 | nn.BatchNorm2d(mid_channels), 144 | nn.ReLU(inplace=True), 145 | ) 146 | self.mlp = Mlp(1, mid_channels, mid_channels) 147 | self.se = SELayer(mid_channels) # NOTE: add camera-aware 148 | self.depth_conv = nn.Sequential( 149 | BasicBlock(mid_channels, mid_channels), 150 | BasicBlock(mid_channels, mid_channels), 151 | BasicBlock(mid_channels, mid_channels), 152 | ) 153 | # self.aspp = ASPP(mid_channels, mid_channels, BatchNorm=nn.InstanceNorm2d) 154 | 155 | self.depth_pred = nn.Conv2d( 156 | mid_channels, depth_channels, 
kernel_size=1, stride=1, padding=0 157 | ) 158 | self.infer_mode = infer_mode 159 | 160 | def forward( 161 | self, 162 | x=None, 163 | sweep_intrins=None, 164 | scaled_pixel_size=None, 165 | scale_depth_factor=1000.0, 166 | ): 167 | # self.eval() 168 | inv_intrinsics = torch.inverse(sweep_intrins) 169 | pixel_size = torch.norm( 170 | torch.stack( 171 | [inv_intrinsics[..., 0, 0], inv_intrinsics[..., 1, 1]], dim=-1 172 | ), 173 | dim=-1, 174 | ).reshape(-1, 1).to(x.device) 175 | scaled_pixel_size = pixel_size * scale_depth_factor 176 | 177 | x = self.reduce_conv(x) 178 | # aug_scale = torch.sqrt(sweep_post_rots_ida[..., 0, 0] ** 2 + sweep_post_rots_ida[..., 0, 1] ** 2).reshape(-1, 1) 179 | x_se = self.mlp(scaled_pixel_size)[..., None, None] 180 | 181 | x = self.se(x, x_se) 182 | x = self.depth_conv(x) 183 | # x = self.aspp(x) 184 | depth = self.depth_pred(x) 185 | return depth 186 | -------------------------------------------------------------------------------- /iso/models/flosp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class FLoSP(nn.Module): 6 | def __init__(self, scene_size, dataset, project_scale): 7 | super().__init__() 8 | self.scene_size = scene_size 9 | self.dataset = dataset 10 | self.project_scale = project_scale 11 | 12 | def forward(self, x2d, projected_pix, fov_mask): 13 | c, h, w = x2d.shape 14 | 15 | src = x2d.view(c, -1) 16 | zeros_vec = torch.zeros(c, 1).type_as(src) 17 | src = torch.cat([src, zeros_vec], 1) 18 | 19 | pix_x, pix_y = projected_pix[:, 0], projected_pix[:, 1] 20 | img_indices = pix_y * w + pix_x 21 | img_indices[~fov_mask] = h * w 22 | img_indices = img_indices.expand(c, -1).long() # c, HWD 23 | src_feature = torch.gather(src, 1, img_indices) 24 | 25 | if self.dataset == "NYU": 26 | x3d = src_feature.reshape( 27 | c, 28 | self.scene_size[0] // self.project_scale, 29 | self.scene_size[2] // self.project_scale, 30 | self.scene_size[1] // self.project_scale, 31 | ) 32 | x3d = x3d.permute(0, 1, 3, 2) 33 | elif self.dataset == "OccScanNet" or self.dataset == "OccScanNet_mini": 34 | x3d = src_feature.reshape( 35 | c, 36 | self.scene_size[0] // self.project_scale, 37 | self.scene_size[1] // self.project_scale, 38 | self.scene_size[2] // self.project_scale, 39 | ) 40 | 41 | return x3d 42 | -------------------------------------------------------------------------------- /iso/models/iso.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import torch.nn as nn 4 | from iso.models.unet3d_nyu import UNet3D as UNet3DNYU 5 | from iso.loss.sscMetrics import SSCMetrics 6 | from iso.loss.ssc_loss import sem_scal_loss, CE_ssc_loss, KL_sep, geo_scal_loss 7 | from iso.models.flosp import FLoSP 8 | from iso.models.depthnet import DepthNet 9 | from iso.loss.CRP_loss import compute_super_CP_multilabel_loss 10 | import numpy as np 11 | import torch.nn.functional as F 12 | from iso.models.unet2d import UNet2D 13 | from torch.optim.lr_scheduler import MultiStepLR 14 | import sys 15 | sys.path.append('./iso') 16 | sys.path.append('./depth_anything/metric_depth') 17 | from depth_anything.metric_depth.zoedepth.models.builder import build_model as build_depthany_model 18 | from depth_anything.metric_depth.zoedepth.utils.config import get_config as get_depthany_config 19 | 20 | from iso.models.modules import sample_grid_feature, get_depth_index, sample_3d_feature, bin_depths 21 | # from 
iso.models.depth_utils import down_sample_depth_dist 22 | from iso.loss.depth_loss import DepthClsLoss 23 | 24 | import torch 25 | 26 | from torch.cuda.amp import autocast 27 | import torch.nn.functional as F 28 | 29 | # from transformers import pipeline 30 | from PIL import Image 31 | 32 | 33 | class ISO(pl.LightningModule): 34 | def __init__( 35 | self, 36 | n_classes, 37 | class_names, 38 | feature, 39 | class_weights, 40 | project_scale, 41 | full_scene_size, 42 | dataset, 43 | n_relations=4, 44 | context_prior=True, 45 | fp_loss=True, 46 | project_res=[], 47 | bevdepth=False, 48 | voxeldepth=False, 49 | voxeldepth_res=[], 50 | frustum_size=4, 51 | relation_loss=False, 52 | CE_ssc_loss=True, 53 | geo_scal_loss=True, 54 | sem_scal_loss=True, 55 | lr=1e-4, 56 | weight_decay=1e-4, 57 | use_gt_depth=False, 58 | add_fusion=False, 59 | use_zoedepth=True, 60 | use_depthanything=False, 61 | zoedepth_as_gt=False, 62 | depthanything_as_gt=False, 63 | frozen_encoder=False, 64 | ): 65 | super().__init__() 66 | 67 | self.project_res = project_res 68 | self.bevdepth = bevdepth 69 | self.voxeldepth = voxeldepth 70 | self.voxeldepth_res = voxeldepth_res 71 | self.fp_loss = fp_loss 72 | self.dataset = dataset 73 | self.context_prior = context_prior 74 | self.frustum_size = frustum_size 75 | self.class_names = class_names 76 | self.relation_loss = relation_loss 77 | self.CE_ssc_loss = CE_ssc_loss 78 | self.sem_scal_loss = sem_scal_loss 79 | self.geo_scal_loss = geo_scal_loss 80 | self.project_scale = project_scale 81 | self.class_weights = class_weights 82 | self.lr = lr 83 | self.weight_decay = weight_decay 84 | self.use_gt_depth = use_gt_depth 85 | self.add_fusion = add_fusion 86 | self.use_zoedepth = use_zoedepth 87 | self.use_depthanything = use_depthanything 88 | self.zoedepth_as_gt = zoedepth_as_gt 89 | self.depthanything_as_gt = depthanything_as_gt 90 | self.frozen_encoder = frozen_encoder 91 | 92 | self.projects = {} 93 | self.scale_2ds = [1, 2, 4, 8] # 2D scales 94 | for scale_2d in self.scale_2ds: 95 | self.projects[str(scale_2d)] = FLoSP( 96 | full_scene_size, project_scale=self.project_scale, dataset=self.dataset 97 | ) 98 | self.projects = nn.ModuleDict(self.projects) 99 | 100 | self.n_classes = n_classes 101 | if self.dataset == "NYU" or self.dataset == "OccScanNet" or self.dataset == "OccScanNet_mini": 102 | self.net_3d_decoder = UNet3DNYU( 103 | self.n_classes, 104 | nn.BatchNorm3d, 105 | n_relations=n_relations, 106 | feature=feature, 107 | full_scene_size=full_scene_size, 108 | context_prior=context_prior, 109 | ) 110 | # if self.voxeldepth: 111 | # self.depth_net_3d_decoder = UNet3DNYU( 112 | # self.n_classes, 113 | # nn.BatchNorm3d, 114 | # n_relations=n_relations, 115 | # feature=feature, 116 | # full_scene_size=full_scene_size, 117 | # context_prior=context_prior, 118 | # beforehead=True, 119 | # ) 120 | elif self.dataset == "kitti": 121 | self.net_3d_decoder = UNet3DKitti( 122 | self.n_classes, 123 | nn.BatchNorm3d, 124 | project_scale=project_scale, 125 | feature=feature, 126 | full_scene_size=full_scene_size, 127 | context_prior=context_prior, 128 | ) 129 | self.net_rgb = UNet2D.build(out_feature=feature, use_decoder=True, frozen_encoder=self.frozen_encoder) 130 | 131 | if self.voxeldepth and not self.use_gt_depth: 132 | self.depthnet_1_1 = DepthNet(200, 256, 200, 64) 133 | if self.use_zoedepth: # use zoedepth pretrained 134 | self.net_depth = self._init_zoedepth() 135 | self.depthnet_1_1 = DepthNet(201, 256, 200, 64) 136 | elif self.use_depthanything: 137 | self.net_depth = 
self._init_depthanything() 138 | self.depthnet_1_1 = DepthNet(201, 256, 200, 64) 139 | elif self.zoedepth_as_gt: # use gt and use zoedepth as gt 140 | self.net_depth = self._init_zoedepth() 141 | elif self.depthanything_as_gt: 142 | self.net_depth = self._init_depthanything() 143 | elif self.bevdepth: 144 | self.net_depth = self._init_zoedepth() 145 | else: 146 | pass # use gt and use dataset gt 147 | 148 | # log hyperparameters 149 | self.save_hyperparameters() 150 | 151 | self.train_metrics = SSCMetrics(self.n_classes) 152 | self.val_metrics = SSCMetrics(self.n_classes) 153 | self.test_metrics = SSCMetrics(self.n_classes) 154 | 155 | def _init_zoedepth(self): 156 | conf = get_config("zoedepth", "infer") 157 | conf['img_size'] = [480, 640] 158 | model_zoe_n = build_model(conf) 159 | return model_zoe_n.cuda() 160 | 161 | def _init_depthanything(self): 162 | import sys 163 | sys.path.append('/home/hongxiao.yu/projects/ISO/depth_anything/metric_depth') 164 | overrite = {"pretrained_resource": "local::/home/hongxiao.yu/projects/ISO/checkpoints/depth_anything_metric_depth_indoor.pt"} 165 | conf = get_depthany_config("zoedepth", "infer", "nyu", **overrite) 166 | # conf['img_size'] = [480, 640] 167 | from pprint import pprint 168 | # pprint(conf) 169 | model = build_depthany_model(conf) 170 | return model.cuda() 171 | # return pipeline(task="depth-estimation", model="LiheYoung/depth-anything-small-hf") 172 | 173 | def _set_train_params(self, curr_epoch): 174 | if curr_epoch // 2 == 1: 175 | for k, p in self.named_parameters(): 176 | if 'depthnet_1_1' in k: 177 | p.requires_grad = True 178 | else: 179 | p.requires_grad = False 180 | else: 181 | for k, p in self.named_parameters(): 182 | if 'depthnet_1_1' in k: 183 | p.requires_grad = False 184 | else: 185 | p.requires_grad = True 186 | 187 | def forward(self, batch): 188 | 189 | # for k, v in self.state_dict().items(): 190 | # if v.dtype != torch.float32: 191 | # print(k, v.dtype) 192 | # breakpoint() 193 | 194 | # self._set_train_params(self.current_epoch) 195 | 196 | img = batch["img"] 197 | raw_img = batch['raw_img'] 198 | pix_z = batch['pix_z'] # (B, 129600) 199 | bs = len(img) 200 | # print(img) 201 | 202 | out = {} 203 | 204 | # for k, v in self.named_parameters(): 205 | # print(k, ':', v) 206 | x_rgb = self.net_rgb(img) 207 | 208 | x3ds = [] 209 | x3d_bevs = [] 210 | if self.add_fusion: 211 | x3ds_res = [] 212 | if self.voxeldepth: 213 | depth_preds = { 214 | '1_1': [], 215 | '1_2': [], 216 | '1_4': [], 217 | '1_8': [], 218 | } 219 | for i in range(bs): 220 | if self.voxeldepth: 221 | x3d_depths = { 222 | '1_1': None, 223 | '1_2': None, 224 | '1_4': None, 225 | '1_8': None, 226 | } 227 | depths = { 228 | '1_1': None, 229 | '1_2': None, 230 | '1_4': None, 231 | '1_8': None, 232 | } 233 | 234 | if not self.use_gt_depth: 235 | if self.use_zoedepth: 236 | # self.net_depth.eval() # zoe_depth 237 | # for param in self.net_depth.parameters(): 238 | # param.requires_grad = False 239 | # rslts = self.net_depth(raw_img[i:i+1], return_probs=False, return_final_centers=False) 240 | # feature = rslts['metric_depth'] # (1, 1, 480, 640) 241 | self.net_depth.device = 'cuda' 242 | feature = self.net_depth.infer_pil(raw_img[i], output_type="tensor", with_flip_aug=False).cuda().unsqueeze(0).unsqueeze(0) 243 | 244 | # print(feature.shape) 245 | 246 | # import matplotlib.pyplot as plt 247 | # plt.subplot(1, 2, 1) 248 | # plt.imshow(feature[0].permute(1, 2, 0).cpu().numpy()) 249 | # plt.subplot(1, 2, 2) 250 | # plt.imshow(batch['depth_gt'][i].permute(1, 2, 
0).cpu().numpy()) 251 | # plt.savefig('/home/hongxiao.yu/ISO/depth_compare.png') 252 | 253 | input_kwargs = { 254 | "img_feat_1_1": torch.cat([x_rgb['1_1'][i:i+1], feature], dim=1), 255 | "cam_k": batch["cam_k"][i:i+1], 256 | "T_velo_2_cam": batch["cam_pose"][i:i+1], 257 | "vox_origin": batch['vox_origin'][i:i+1], 258 | } 259 | elif self.use_depthanything: 260 | self.net_depth.device = 'cuda' 261 | feature = self.net_depth.infer_pil(raw_img[i], output_type="tensor", with_flip_aug=False).cuda().unsqueeze(0).unsqueeze(0) 262 | 263 | # print(feature.shape) 264 | # print(feature.shape) 265 | 266 | # import matplotlib.pyplot as plt 267 | # plt.subplot(1, 2, 1) 268 | # plt.imshow(feature[0].permute(1, 2, 0).cpu().numpy()) 269 | # plt.subplot(1, 2, 2) 270 | # plt.imshow(batch['depth_gt'][i].permute(1, 2, 0).cpu().numpy()) 271 | # plt.savefig('/home/hongxiao.yu/ISO/depth_compare.png') 272 | 273 | input_kwargs = { 274 | "img_feat_1_1": torch.cat([x_rgb['1_1'][i:i+1], feature], dim=1), 275 | "cam_k": batch["cam_k"][i:i+1], 276 | "T_velo_2_cam": batch["cam_pose"][i:i+1], 277 | "vox_origin": batch['vox_origin'][i:i+1], 278 | } 279 | else: 280 | input_kwargs = { 281 | "img_feat_1_1": x_rgb['1_1'][i:i+1], 282 | "cam_k": batch["cam_k"][i:i+1], 283 | "T_velo_2_cam": batch["cam_pose"][i:i+1], 284 | "vox_origin": batch['vox_origin'][i:i+1], 285 | } 286 | intrins_mat = input_kwargs['cam_k'].new_zeros(1, 4, 4).to(torch.float) 287 | intrins_mat[:, :3, :3] = input_kwargs['cam_k'] 288 | intrins_mat[:, 3, 3] = 1 # (1, 4, 4) 289 | 290 | depth_feature_1_1 = self.depthnet_1_1( 291 | x=input_kwargs['img_feat_1_1'], 292 | sweep_intrins=intrins_mat, 293 | scaled_pixel_size=None, 294 | ) 295 | depths['1_1'] = depth_feature_1_1.softmax(1) # 得到depth的分布 296 | for res in self.voxeldepth_res[1:]: 297 | depths['1_'+str(res)] = down_sample_depth_dist(depths['1_1'], int(res)) 298 | else: 299 | disc_cfg = { 300 | "mode": "LID", 301 | "num_bins": 64, 302 | "depth_min": 0, 303 | "depth_max": 10, 304 | } 305 | depth_1_1 = batch['depth_gt'][i:i+1] 306 | # print(depth_1_1.shape) 307 | if self.zoedepth_as_gt: 308 | self.net_depth.eval() # zoe_depth 309 | for param in self.net_depth.parameters(): 310 | param.requires_grad = False 311 | self.net_depth.device = 'cuda' 312 | rslts = self.net_depth.infer_pil(raw_img[i], output_type="tensor", with_flip_aug=False).cuda() # TODO: Need Check 313 | depth_1_1 = rslts['metric_depth'] # (1, 1, 480, 640) 314 | elif self.depthanything_as_gt: 315 | self.net_depth.eval() # zoe_depth 316 | for param in self.net_depth.parameters(): 317 | param.requires_grad = False 318 | self.net_depth.device = 'cuda' 319 | depth_1_1 = self.net_depth.infer_pil(raw_img[i], output_type="tensor", with_flip_aug=False).cuda().unsqueeze(0).unsqueeze(0) # TODO: Need Check 320 | # depth_1_1 = rslts['metric_depth'] # (1, 1, 480, 640) 321 | # print(depth_1_1.shape) 322 | depth_1_1 = bin_depths(depth_map=depth_1_1, target=True, **disc_cfg).cuda() # (1, 1, 480, 640) 323 | print(depth_1_1.shape) 324 | # depth_1_1 = ((depth_1_1 - (0.1 - 0.13)) / 0.13).cuda().long() 325 | # print(depth_1_1.shape) 326 | depth_1_1 = F.one_hot(depth_1_1[:, 0, :, :], 81).permute(0, 3, 1, 2)[:, :-1, :, :] # (1, 81, 480, 640) 327 | depths['1_1'] = depth_1_1.float() # 得到depth的分布 328 | for res in self.voxeldepth_res[1:]: 329 | depths['1_'+str(res)] = down_sample_depth_dist(depths['1_1'], int(res)).float() 330 | 331 | projected_pix = batch["projected_pix_{}".format(self.project_scale)][i].cuda() 332 | fov_mask = 
batch["fov_mask_{}".format(self.project_scale)][i].cuda() 333 | 334 | # pix_z_index = get_depth_index(pix_z[i]) 335 | disc_cfg = { 336 | "mode": "LID", 337 | "num_bins": 64, 338 | "depth_min": 0, 339 | "depth_max": 10, 340 | } 341 | pix_z_index = bin_depths(depth_map=pix_z[i], target=True, **disc_cfg).to(fov_mask.device) 342 | # pix_z_index = ((pix_z[i] - (0.1 - 0.13)) / 0.13).to(fov_mask.device) 343 | # pix_z_index = torch.where( 344 | # (pix_z_index < 80 + 1) & (pix_z_index >= 0.0), 345 | # pix_z_index, 346 | # torch.zeros_like(pix_z_index), 347 | # ) 348 | dist_mask = torch.logical_and(pix_z_index >= 0, pix_z_index < 64) 349 | dist_mask = torch.logical_and(dist_mask, fov_mask) 350 | 351 | for res in self.voxeldepth_res: 352 | probs = torch.zeros((129600, 1), dtype=torch.float32).to(self.device) 353 | # print(depths['1_'+str(res)].dtype, (projected_pix//int(res)).dtype, pix_z_index.dtype) 354 | probs[dist_mask] = sample_3d_feature(depths['1_'+str(res)], projected_pix//int(res), pix_z_index, dist_mask) 355 | if self.dataset == 'NYU': 356 | x3d_depths['1_'+str(res)] = probs.reshape(60, 60, 36).permute(0, 2, 1).unsqueeze(0) 357 | elif self.dataset == 'OccScanNet' or self.dataset == 'OccScanNet_mini': 358 | x3d_depths['1_'+str(res)] = probs.reshape(60, 60, 36).unsqueeze(0) 359 | depth_preds['1_'+str(res)].append(depths['1_'+str(res)]) # (1, 64, 60, 80) 360 | 361 | x3d = None 362 | if self.add_fusion: 363 | x3d_res = None 364 | for scale_2d in self.project_res: 365 | 366 | # project features at each 2D scale to target 3D scale 367 | scale_2d = int(scale_2d) 368 | projected_pix = batch["projected_pix_{}".format(self.project_scale)][i].cuda() 369 | fov_mask = batch["fov_mask_{}".format(self.project_scale)][i].cuda() 370 | if self.bevdepth and scale_2d == 4: 371 | xys = projected_pix // scale_2d 372 | 373 | D, fH, fW = 64, 480 // scale_2d, 640 // scale_2d 374 | xs = torch.linspace(0, 640 - 1, fW, dtype=torch.float).view(1, 1, fW).expand(D, fH, fW).cuda() # (64, 120, 160) 375 | ys = torch.linspace(0, 480 - 1, fH, dtype=torch.float).view(1, fH, 1).expand(D, fH, fW).cuda() # (64, 120, 160) 376 | d_xs = torch.floor(xs[0].reshape(-1)).to(torch.long) # (fH*fW,) 377 | d_ys = torch.floor(ys[0].reshape(-1)).to(torch.long) # (fH*fW,) 378 | 379 | self.net_depth.device = 'cuda' 380 | # feature = self.net_depth.infer_pil(raw_img[i], output_type="tensor", with_flip_aug=False).cuda().unsqueeze(0).unsqueeze(0) 381 | rslts = self.net_depth(img[i:i+1], return_final_centers=True, return_probs=True) 382 | probs = rslts['probs'].cuda() # 得到depth distribution 383 | bin_center = rslts['bin_centers'].cuda() # (1, 64, 384, 512) 384 | # print(probs.shape) 385 | probs = probs[0, :, d_ys, d_xs].reshape(D, fH, fW).to(projected_pix.device) # (D, fH, fW) depth distribution 386 | ds = bin_center[0, :, d_ys, d_xs].reshape(D, fH, fW).to(xs.device) # (D, fH, fW) depth bins distance 387 | 388 | # x = F.interpolate(x_rgb["1_" + str(scale_2d)][i][None, ...], (384, 512), mode='bilinear', align_corners=True)[0] # (200, 384, 512) 389 | x = x_rgb["1_" + str(scale_2d)][i] 390 | x = probs.unsqueeze(0) * x.unsqueeze(1) # (1, 64, 384, 512), (200, 1, 384, 512) --> (200, 64, 384, 512) 391 | x = x.reshape(-1, fH, fW) 392 | x = self.projects[str(scale_2d)]( 393 | x, 394 | xys, 395 | fov_mask, 396 | ) 397 | # print(x.shape) # (200*64, 60, 60, 36) 398 | x = x.reshape(200*64, 60*36*60).T # (60*60*36, 200*64) 399 | pix_z_depth = ds[:, xys[:, 1][fov_mask], xys[:, 0][fov_mask]] # get the depth bins distance (64, K) 400 | pix_zz = pix_z[i][fov_mask] # 
(K,) 401 | # print(pix_z_depth.shape, pix_z.shape) 402 | pix_z_delta = torch.abs(pix_z_depth.to(pix_zz.device) - pix_zz[None, ...]) # (64, K) 403 | # print(pix_z_delta.shape) 404 | min_z = torch.argmin(pix_z_delta, dim=0) # (K,) 405 | # print(min_z.shape) 406 | temp = x[fov_mask].reshape(min_z.shape[0], 200, 64) # (K, 200, 64) 407 | temp = temp[torch.arange(min_z.shape[0]).to(torch.long), :, min_z] # (K, 200) 408 | # temp = temp.reshape(temp.shape[0], 200, 1).repeat(1, 1, 64).reshape(temp.shape[0], 200*64) # (K, 200*64) 409 | # x[fov_mask] = x[fov_mask][torch.arange(min_z.shape[0]), min_z*200:(min_z+1)*200] = temp # (K, 200*64) 410 | # x = x.T.reshape(200*64, 60, 60, 36)[:200, ...] 411 | x3d_bev = torch.zeros(60*36*60, 200).to(x.device) 412 | x3d_bev[fov_mask] = temp 413 | x3d_bev = x3d_bev.T.reshape(200, 60, 36, 60) 414 | # torch.cuda.empty_cache() 415 | 416 | if self.add_fusion: 417 | if x3d_res is None: 418 | x3d_res = self.projects[str(scale_2d)]( 419 | x_rgb["1_" + str(scale_2d)][i], 420 | projected_pix // scale_2d, 421 | fov_mask, 422 | ) 423 | else: 424 | x3d_res += self.projects[str(scale_2d)]( 425 | x_rgb["1_" + str(scale_2d)][i], 426 | projected_pix // scale_2d, 427 | fov_mask, 428 | ) 429 | 430 | # Sum all the 3D features 431 | if x3d is None: 432 | if self.voxeldepth: 433 | if len(self.voxeldepth_res) == 1: 434 | res = self.voxeldepth_res[0] 435 | x3d = self.projects[str(scale_2d)]( 436 | x_rgb["1_" + str(scale_2d)][i], 437 | projected_pix // scale_2d, 438 | fov_mask, 439 | ) * x3d_depths['1_'+str(res)] * 100 440 | else: 441 | x3d = self.projects[str(scale_2d)]( 442 | x_rgb["1_" + str(scale_2d)][i], 443 | projected_pix // scale_2d, 444 | fov_mask, 445 | ) * x3d_depths['1_1'] * 100 446 | else: 447 | x3d = self.projects[str(scale_2d)]( 448 | x_rgb["1_" + str(scale_2d)][i], 449 | projected_pix // scale_2d, 450 | fov_mask, 451 | ) 452 | else: 453 | if self.voxeldepth: 454 | if len(self.voxeldepth_res) == 1: 455 | res = self.voxeldepth_res[0] 456 | x3d = self.projects[str(scale_2d)]( 457 | x_rgb["1_" + str(scale_2d)][i], 458 | projected_pix // scale_2d, 459 | fov_mask, 460 | ) * x3d_depths['1_'+str(res)] * 100 461 | else: 462 | x3d += self.projects[str(scale_2d)]( 463 | x_rgb["1_" + str(scale_2d)][i], 464 | projected_pix // scale_2d, 465 | fov_mask, 466 | ) * x3d_depths['1_'+str(scale_2d)] * 100 467 | else: 468 | x3d += self.projects[str(scale_2d)]( 469 | x_rgb["1_" + str(scale_2d)][i], 470 | projected_pix // scale_2d, 471 | fov_mask, 472 | ) 473 | x3ds.append(x3d) 474 | if self.add_fusion: 475 | x3ds_res.append(x3d_res) 476 | if self.bevdepth: 477 | x3d_bevs.append(x3d_bev) 478 | 479 | input_dict = { 480 | "x3d": torch.stack(x3ds), 481 | } 482 | if self.add_fusion: 483 | input_dict['x3d'] += torch.stack(x3ds_res) 484 | if self.bevdepth: 485 | # print(input_dict['x3d'].shape, torch.stack(x3d_bevs).shape) 486 | input_dict['x3d'] += torch.stack(x3d_bevs) 487 | 488 | # print(input_dict["x3d"]) 489 | # from pytorch_lightning import seed_everything 490 | # seed_everything(84, True) 491 | # input_dict['x3d'] = torch.randn_like(input_dict['x3d']).to(torch.float64) 492 | # print(input_dict['x3d'], 'x3d') 493 | 494 | out = self.net_3d_decoder(input_dict) 495 | # print(torch.randn((5, 5)), 'x3d') 496 | # print(torch.randn((5, 5)), 'x3d') 497 | 498 | if self.voxeldepth and not self.use_gt_depth: 499 | for res in self.voxeldepth_res: 500 | out['depth_1_'+str(res)] = torch.vstack(depth_preds['1_'+str(res)]) # (B, 64, 60, 80) 501 | out['depth_gt'] = batch['depth_gt'] 502 | # print(batch['name'], 
out["ssc_logit"]) 503 | return out 504 | 505 | def step(self, batch, step_type, metric): 506 | bs = len(batch["img"]) 507 | loss = 0 508 | out_dict = self(batch) 509 | ssc_pred = out_dict["ssc_logit"] 510 | ssc_pred = torch.where(torch.isinf(ssc_pred), torch.zeros_like(ssc_pred), ssc_pred) 511 | ssc_pred = torch.where(torch.isnan(ssc_pred), torch.zeros_like(ssc_pred), ssc_pred) 512 | # print(ssc_pred.requires_grad, ssc_pred.grad_fn) 513 | # torch.cuda.manual_seed_all(42) 514 | # ssc_pred = torch.randn_like(ssc_pred) 515 | target = batch["target"] 516 | 517 | if self.voxeldepth and not self.use_gt_depth: 518 | loss_depth_dict = {} 519 | loss_depth = 0.0 520 | out_dict['depth_1_1'] = torch.where(torch.isnan(out_dict['depth_1_1']), torch.zeros_like(out_dict['depth_1_1']), out_dict['depth_1_1'], ) 521 | # print(521, torch.any(torch.isnan(out_dict['depth_1_1']))) 522 | if torch.any(torch.isnan(out_dict['depth_1_1'])): 523 | print("===== Got Inf !!! =====") 524 | exit(-1) 525 | for res in self.voxeldepth_res: 526 | loss_depth_dict['1_'+str(res)] = DepthClsLoss(int(res), [0.0, 10, 0.16], 64).get_depth_loss(out_dict['depth_gt'], 527 | out_dict['depth_1_'+str(res)].unsqueeze(1)) 528 | loss_depth += loss_depth_dict['1_'+str(res)] 529 | loss_depth = loss_depth / len(loss_depth_dict) 530 | loss += loss_depth 531 | # print(521, torch.any(torch.isnan(loss_depth))) 532 | self.log( 533 | step_type + "/loss_depth", 534 | loss_depth.detach(), 535 | on_epoch=True, 536 | sync_dist=True, 537 | ) 538 | 539 | if self.context_prior: 540 | P_logits = out_dict["P_logits"] 541 | P_logits = torch.where(torch.isnan(P_logits), torch.zeros_like(P_logits), P_logits) 542 | 543 | CP_mega_matrices = batch["CP_mega_matrices"] 544 | 545 | if self.relation_loss: 546 | # print(543, torch.any(torch.isnan(P_logits))) 547 | loss_rel_ce = compute_super_CP_multilabel_loss( 548 | P_logits, CP_mega_matrices 549 | ) 550 | # loss_rel_ce = torch.where(torch.isnan(loss_rel_ce), torch.zeros_like(loss_rel_ce), loss_rel_ce) 551 | loss += loss_rel_ce 552 | # print(543, torch.any(torch.isnan(loss_rel_ce))) 553 | self.log( 554 | step_type + "/loss_relation_ce_super", 555 | loss_rel_ce.detach(), 556 | on_epoch=True, 557 | sync_dist=True, 558 | ) 559 | 560 | class_weight = self.class_weights.type_as(batch["img"]) 561 | if self.CE_ssc_loss: 562 | loss_ssc = CE_ssc_loss(ssc_pred, target, class_weight) 563 | # print(558, torch.any(torch.isnan(loss_ssc))) 564 | loss += loss_ssc 565 | self.log( 566 | step_type + "/loss_ssc", 567 | loss_ssc.detach(), 568 | on_epoch=True, 569 | sync_dist=True, 570 | ) 571 | 572 | if self.sem_scal_loss: 573 | loss_sem_scal = sem_scal_loss(ssc_pred, target) 574 | # print(569, torch.any(torch.isnan(loss_sem_scal))) 575 | loss += loss_sem_scal 576 | self.log( 577 | step_type + "/loss_sem_scal", 578 | loss_sem_scal.detach(), 579 | on_epoch=True, 580 | sync_dist=True, 581 | ) 582 | 583 | if self.geo_scal_loss: 584 | loss_geo_scal = geo_scal_loss(ssc_pred, target) 585 | # print(580, torch.any(torch.isnan(loss_geo_scal))) 586 | loss += loss_geo_scal 587 | self.log( 588 | step_type + "/loss_geo_scal", 589 | loss_geo_scal.detach(), 590 | on_epoch=True, 591 | sync_dist=True, 592 | ) 593 | 594 | if self.fp_loss and step_type != "test": 595 | frustums_masks = torch.stack(batch["frustums_masks"]) 596 | frustums_class_dists = torch.stack( 597 | batch["frustums_class_dists"] 598 | ).float() # (bs, n_frustums, n_classes) 599 | n_frustums = frustums_class_dists.shape[1] 600 | 601 | pred_prob = F.softmax(ssc_pred, dim=1) 602 | batch_cnt = 
frustums_class_dists.sum(0) # (n_frustums, n_classes) 603 | 604 | frustum_loss = 0 605 | frustum_nonempty = 0 606 | for frus in range(n_frustums): 607 | frustum_mask = frustums_masks[:, frus, :, :, :].unsqueeze(1).float() 608 | prob = frustum_mask * pred_prob # bs, n_classes, H, W, D 609 | prob = prob.reshape(bs, self.n_classes, -1).permute(1, 0, 2) 610 | prob = prob.reshape(self.n_classes, -1) 611 | cum_prob = prob.sum(dim=1) # n_classes 612 | 613 | total_cnt = torch.sum(batch_cnt[frus]) 614 | total_prob = prob.sum() 615 | if total_prob > 0 and total_cnt > 0: 616 | frustum_target_proportion = batch_cnt[frus] / total_cnt 617 | cum_prob = cum_prob / total_prob # n_classes 618 | frustum_loss_i = KL_sep(cum_prob, frustum_target_proportion) 619 | frustum_loss += frustum_loss_i 620 | frustum_nonempty += 1 621 | frustum_loss = frustum_loss / frustum_nonempty 622 | loss += frustum_loss 623 | # print(618, torch.any(torch.isnan(frustum_loss))) 624 | self.log( 625 | step_type + "/loss_frustums", 626 | frustum_loss.detach(), 627 | on_epoch=True, 628 | sync_dist=True, 629 | ) 630 | 631 | y_true = target.cpu().numpy() 632 | y_pred = ssc_pred.detach().cpu().numpy() 633 | y_pred = np.argmax(y_pred, axis=1) 634 | metric.add_batch(y_pred, y_true) 635 | 636 | self.log(step_type + "/loss", loss.detach(), on_epoch=True, sync_dist=True, prog_bar=True) 637 | 638 | # loss_dict = {'loss': loss} 639 | # print(636, torch.any(torch.isnan(loss))) 640 | 641 | return loss 642 | 643 | def training_step(self, batch, batch_idx): 644 | return self.step(batch, "train", self.train_metrics) 645 | 646 | def validation_step(self, batch, batch_idx): 647 | self.step(batch, "val", self.val_metrics) 648 | 649 | def on_validation_epoch_end(self): #, outputs): 650 | metric_list = [("train", self.train_metrics), ("val", self.val_metrics)] 651 | 652 | for prefix, metric in metric_list: 653 | stats = metric.get_stats() 654 | for i, class_name in enumerate(self.class_names): 655 | self.log( 656 | "{}_SemIoU/{}".format(prefix, class_name), 657 | stats["iou_ssc"][i], 658 | sync_dist=True, 659 | ) 660 | self.log("{}/mIoU".format(prefix), torch.tensor(stats["iou_ssc_mean"]).to(torch.float32), sync_dist=True) 661 | self.log("{}/IoU".format(prefix), torch.tensor(stats["iou"]).to(torch.float32), sync_dist=True) 662 | self.log("{}/Precision".format(prefix), torch.tensor(stats["precision"]).to(torch.float32), sync_dist=True) 663 | self.log("{}/Recall".format(prefix), torch.tensor(stats["recall"]).to(torch.float32), sync_dist=True) 664 | metric.reset() 665 | 666 | def test_step(self, batch, batch_idx): 667 | self.step(batch, "test", self.test_metrics) 668 | 669 | def on_test_epoch_end(self,):# outputs): 670 | classes = self.class_names 671 | metric_list = [("test", self.test_metrics)] 672 | for prefix, metric in metric_list: 673 | print("{}======".format(prefix)) 674 | stats = metric.get_stats() 675 | print( 676 | "Precision={:.4f}, Recall={:.4f}, IoU={:.4f}".format( 677 | stats["precision"] * 100, stats["recall"] * 100, stats["iou"] * 100 678 | ) 679 | ) 680 | print("class IoU: {}, ".format(classes)) 681 | print( 682 | " ".join(["{:.4f}, "] * len(classes)).format( 683 | *(stats["iou_ssc"] * 100).tolist() 684 | ) 685 | ) 686 | print("mIoU={:.4f}".format(stats["iou_ssc_mean"] * 100)) 687 | metric.reset() 688 | 689 | def configure_optimizers(self): 690 | if self.dataset == "NYU" and not self.voxeldepth: 691 | optimizer = torch.optim.AdamW( 692 | self.parameters(), lr=self.lr, weight_decay=self.weight_decay 693 | ) 694 | scheduler = 
MultiStepLR(optimizer, milestones=[20], gamma=0.1) 695 | return [optimizer], [scheduler] 696 | elif self.dataset == "NYU" and self.voxeldepth: 697 | depth_params = [] 698 | other_params = [] 699 | for k, p in self.named_parameters(): 700 | if 'depthnet_1_1' in k: 701 | depth_params.append(p) 702 | else: 703 | other_params.append(p) 704 | params_list = [{'params': depth_params, 'lr': self.lr * 0.05}, # 0.4 high, 0.1 unstable, 0.05 ok 705 | {'params': other_params}] 706 | optimizer = torch.optim.AdamW( 707 | params_list, lr=self.lr, weight_decay=self.weight_decay 708 | ) 709 | scheduler = MultiStepLR(optimizer, milestones=[20], gamma=0.1) 710 | return [optimizer], [scheduler] 711 | elif self.dataset == "OccScanNet" or self.dataset == "OccScanNet_mini" and self.voxeldepth: 712 | depth_params = [] 713 | other_params = [] 714 | for k, p in self.named_parameters(): 715 | if 'depthnet_1_1' in k: 716 | depth_params.append(p) 717 | else: 718 | other_params.append(p) 719 | params_list = [{'params': depth_params, 'lr': self.lr * 0.05}, # 0.4 high, 0.1 unstable, 0.05 ok 720 | {'params': other_params}] 721 | optimizer = torch.optim.AdamW( 722 | params_list, lr=self.lr, weight_decay=self.weight_decay 723 | ) 724 | scheduler = MultiStepLR(optimizer, milestones=[40], gamma=0.1) 725 | return [optimizer], [scheduler] 726 | -------------------------------------------------------------------------------- /iso/models/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from iso.models.DDR import Bottleneck3D 5 | 6 | 7 | class ASPP(nn.Module): 8 | """ 9 | ASPP 3D 10 | Adapt from https://github.com/cv-rits/LMSCNet/blob/main/LMSCNet/models/LMSCNet.py#L7 11 | """ 12 | 13 | def __init__(self, planes, dilations_conv_list): 14 | super().__init__() 15 | 16 | # ASPP Block 17 | self.conv_list = dilations_conv_list 18 | self.conv1 = nn.ModuleList( 19 | [ 20 | nn.Conv3d( 21 | planes, planes, kernel_size=3, padding=dil, dilation=dil, bias=False 22 | ) 23 | for dil in dilations_conv_list 24 | ] 25 | ) 26 | self.bn1 = nn.ModuleList( 27 | [nn.BatchNorm3d(planes) for dil in dilations_conv_list] 28 | ) 29 | self.conv2 = nn.ModuleList( 30 | [ 31 | nn.Conv3d( 32 | planes, planes, kernel_size=3, padding=dil, dilation=dil, bias=False 33 | ) 34 | for dil in dilations_conv_list 35 | ] 36 | ) 37 | self.bn2 = nn.ModuleList( 38 | [nn.BatchNorm3d(planes) for dil in dilations_conv_list] 39 | ) 40 | self.relu = nn.ReLU() 41 | 42 | def forward(self, x_in): 43 | 44 | y = self.bn2[0](self.conv2[0](self.relu(self.bn1[0](self.conv1[0](x_in))))) 45 | for i in range(1, len(self.conv_list)): 46 | y += self.bn2[i](self.conv2[i](self.relu(self.bn1[i](self.conv1[i](x_in))))) 47 | x_in = self.relu(y + x_in) # modified 48 | 49 | return x_in 50 | 51 | 52 | class SegmentationHead(nn.Module): 53 | """ 54 | 3D Segmentation heads to retrieve semantic segmentation at each scale. 55 | Formed by Dim expansion, Conv3D, ASPP block, Conv3D. 
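    Each dilation branch is Conv3d -> BatchNorm3d -> ReLU -> Conv3d -> BatchNorm3d; the
    branch outputs are summed, added residually to the input, and passed through a final
    3x3x3 convolution that maps to nbr_classes. In unet3d_nyu.py this head is applied at
    the 1:4 scale with inplanes = planes = feature and dilations [1, 2, 3].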
56 | Taken from https://github.com/cv-rits/LMSCNet/blob/main/LMSCNet/models/LMSCNet.py#L7 57 | """ 58 | 59 | def __init__(self, inplanes, planes, nbr_classes, dilations_conv_list): 60 | super().__init__() 61 | 62 | # First convolution 63 | self.conv0 = nn.Conv3d(inplanes, planes, kernel_size=3, padding=1, stride=1) 64 | 65 | # ASPP Block 66 | self.conv_list = dilations_conv_list 67 | self.conv1 = nn.ModuleList( 68 | [ 69 | nn.Conv3d( 70 | planes, planes, kernel_size=3, padding=dil, dilation=dil, bias=False 71 | ) 72 | for dil in dilations_conv_list 73 | ] 74 | ) 75 | self.bn1 = nn.ModuleList( 76 | [nn.BatchNorm3d(planes) for dil in dilations_conv_list] 77 | ) 78 | self.conv2 = nn.ModuleList( 79 | [ 80 | nn.Conv3d( 81 | planes, planes, kernel_size=3, padding=dil, dilation=dil, bias=False 82 | ) 83 | for dil in dilations_conv_list 84 | ] 85 | ) 86 | self.bn2 = nn.ModuleList( 87 | [nn.BatchNorm3d(planes) for dil in dilations_conv_list] 88 | ) 89 | self.relu = nn.ReLU() 90 | 91 | self.conv_classes = nn.Conv3d( 92 | planes, nbr_classes, kernel_size=3, padding=1, stride=1 93 | ) 94 | 95 | def forward(self, x_in): 96 | 97 | # Convolution to go from inplanes to planes features... 98 | x_in = self.relu(self.conv0(x_in)) 99 | 100 | y = self.bn2[0](self.conv2[0](self.relu(self.bn1[0](self.conv1[0](x_in))))) 101 | for i in range(1, len(self.conv_list)): 102 | y += self.bn2[i](self.conv2[i](self.relu(self.bn1[i](self.conv1[i](x_in))))) 103 | x_in = self.relu(y + x_in) # modified 104 | 105 | x_in = self.conv_classes(x_in) 106 | 107 | return x_in 108 | 109 | 110 | class Process(nn.Module): 111 | def __init__(self, feature, norm_layer, bn_momentum, dilations=[1, 2, 3]): 112 | super(Process, self).__init__() 113 | self.main = nn.Sequential( 114 | *[ 115 | Bottleneck3D( 116 | feature, 117 | feature // 4, 118 | bn_momentum=bn_momentum, 119 | norm_layer=norm_layer, 120 | dilation=[i, i, i], 121 | ) 122 | for i in dilations 123 | ] 124 | ) 125 | 126 | def forward(self, x): 127 | return self.main(x) 128 | 129 | 130 | class Upsample(nn.Module): 131 | def __init__(self, in_channels, out_channels, norm_layer, bn_momentum): 132 | super(Upsample, self).__init__() 133 | self.main = nn.Sequential( 134 | nn.ConvTranspose3d( 135 | in_channels, 136 | out_channels, 137 | kernel_size=3, 138 | stride=2, 139 | padding=1, 140 | dilation=1, 141 | output_padding=1, 142 | ), 143 | norm_layer(out_channels, momentum=bn_momentum), 144 | nn.ReLU(), 145 | ) 146 | 147 | def forward(self, x): 148 | return self.main(x) 149 | 150 | 151 | class Downsample(nn.Module): 152 | def __init__(self, feature, norm_layer, bn_momentum, expansion=8): 153 | super(Downsample, self).__init__() 154 | self.main = Bottleneck3D( 155 | feature, 156 | feature // 4, 157 | bn_momentum=bn_momentum, 158 | expansion=expansion, 159 | stride=2, 160 | downsample=nn.Sequential( 161 | nn.AvgPool3d(kernel_size=2, stride=2), 162 | nn.Conv3d( 163 | feature, 164 | int(feature * expansion / 4), 165 | kernel_size=1, 166 | stride=1, 167 | bias=False, 168 | ), 169 | norm_layer(int(feature * expansion / 4), momentum=bn_momentum), 170 | ), 171 | norm_layer=norm_layer, 172 | ) 173 | 174 | def forward(self, x): 175 | return self.main(x) 176 | 177 | 178 | def sample_grid_feature(feature, fho=480, fwo=640, scale=4): 179 | """ 180 | Args: 181 | feature (torch.tensor): 2D feature to be sampled, shape (B, C, H, W) 182 | """ 183 | if len(feature.shape) == 4: 184 | B, D, H, W = feature.shape 185 | else: 186 | D, H, W = feature.shape 187 | fH, fW = fho // scale, fwo // scale 188 | xs = 
torch.linspace(0, fwo - 1, fW, dtype=torch.float).view(1, 1, fW).expand(D, fH, fW) # (64, 120, 160) 189 | ys = torch.linspace(0, fho - 1, fH, dtype=torch.float).view(1, fH, 1).expand(D, fH, fW) # (64, 120, 160) 190 | d_xs = torch.floor(xs[0].reshape(-1)).to(torch.long) # (fH*fW,) 191 | d_ys = torch.floor(ys[0].reshape(-1)).to(torch.long) # (fH*fW,) 192 | # grid_pts = torch.stack([d_xs, d_ys], dim=1) # (fH*fW, 2) 193 | 194 | if len(feature.shape) == 4: 195 | sample_feature = feature[:, :, d_ys, d_xs].reshape(B, D, fH, fW) # (D, fH, fW) 196 | else: 197 | sample_feature = feature[:, d_ys, d_xs].reshape(D, fH, fW) 198 | return sample_feature 199 | 200 | 201 | def get_depth_index(pix_z): 202 | """ 203 | Args: 204 | pix_z (torch.tensor): The depth in camera frame after voxel projected to pixel, shape (N,), N is 205 | total voxel number. 206 | """ 207 | ds = torch.arange(64).to(pix_z.device) # (64,) 208 | ds = 10 / 64 / 65 * ds * (ds + 1) # (64,) 209 | delta_z = torch.abs(pix_z[None, ...] - ds[..., None]) # (64, N) 210 | pix_z_index = torch.argmin(delta_z, dim=0) # (N,) 211 | return pix_z_index 212 | 213 | 214 | def bin_depths(depth_map, mode, depth_min, depth_max, num_bins, target=False): 215 | """ 216 | Converts depth map into bin indices 217 | Args: 218 | depth_map [torch.Tensor(H, W)]: Depth Map 219 | mode [string]: Discretiziation mode (See https://arxiv.org/pdf/2005.13423.pdf for more details) 220 | UD: Uniform discretiziation 221 | LID: Linear increasing discretiziation 222 | SID: Spacing increasing discretiziation 223 | depth_min [float]: Minimum depth value 224 | depth_max [float]: Maximum depth value 225 | num_bins [int]: Number of depth bins 226 | target [bool]: Whether the depth bins indices will be used for a target tensor in loss comparison 227 | Returns: 228 | indices [torch.Tensor(H, W)]: Depth bin indices 229 | """ 230 | if mode == "UD": 231 | bin_size = (depth_max - depth_min) / num_bins 232 | indices = (depth_map - depth_min) / bin_size 233 | elif mode == "LID": 234 | bin_size = 2 * (depth_max - depth_min) / (num_bins * (1 + num_bins)) 235 | indices = -0.5 + 0.5 * torch.sqrt(1 + 8 * (depth_map - depth_min) / bin_size) 236 | elif mode == "SID": 237 | indices = ( 238 | num_bins 239 | * (torch.log(1 + depth_map) - math.log(1 + depth_min)) 240 | / (math.log(1 + depth_max) - math.log(1 + depth_min)) 241 | ) 242 | else: 243 | raise NotImplementedError 244 | 245 | if target: 246 | # Remove indicies outside of bounds (-2, -1, 0, 1, ..., num_bins, num_bins +1) --> (num_bins, num_bins, 0, 1, ..., num_bins, num_bins) 247 | mask = (indices < 0) | (indices > num_bins) | (~torch.isfinite(indices)) 248 | indices[mask] = num_bins 249 | 250 | # Convert to integer 251 | indices = indices.type(torch.int64) 252 | return indices.long() 253 | 254 | 255 | def sample_3d_feature(feature_3d, pix_xy, pix_z, fov_mask): 256 | """ 257 | Args: 258 | feature_3d (torch.tensor): 3D feature, shape (C, D, H, W). 259 | pix_xy (torch.tensor): Projected pix coordinate, shape (N, 2). 260 | pix_z (torch.tensor): Projected pix depth coordinate, shape (N,). 
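        fov_mask (torch.tensor): Boolean mask over the N voxels marking those inside the
            camera frustum; only the masked entries are gathered.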
261 | 262 | Returns: 263 | torch.tensor: Sampled feature, shape (N, C) 264 | """ 265 | pix_x, pix_y = pix_xy[:, 0][fov_mask], pix_xy[:, 1][fov_mask] 266 | pix_z = pix_z[fov_mask].to(pix_y.dtype) 267 | ret = feature_3d[:, pix_z, pix_y, pix_x].T 268 | return ret 269 | 270 | -------------------------------------------------------------------------------- /iso/models/unet2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from https://github.com/astra-vision/MonoScene/blob/master/monoscene/models/unet2d.py 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import os 8 | 9 | 10 | class UpSampleBN(nn.Module): 11 | def __init__(self, skip_input, output_features): 12 | super(UpSampleBN, self).__init__() 13 | self._net = nn.Sequential( 14 | nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1), 15 | nn.BatchNorm2d(output_features), 16 | nn.LeakyReLU(), 17 | nn.Conv2d( 18 | output_features, output_features, kernel_size=3, stride=1, padding=1 19 | ), 20 | nn.BatchNorm2d(output_features), 21 | nn.LeakyReLU(), 22 | ) 23 | 24 | def forward(self, x, concat_with): 25 | up_x = F.interpolate( 26 | x, 27 | size=(concat_with.shape[2], concat_with.shape[3]), 28 | mode="bilinear", 29 | align_corners=True, 30 | ) 31 | f = torch.cat([up_x, concat_with], dim=1) 32 | return self._net(f) 33 | 34 | 35 | class DecoderBN(nn.Module): 36 | def __init__( 37 | self, num_features, bottleneck_features, out_feature, use_decoder=True 38 | ): 39 | super(DecoderBN, self).__init__() 40 | features = int(num_features) 41 | self.use_decoder = use_decoder 42 | 43 | self.conv2 = nn.Conv2d( 44 | bottleneck_features, features, kernel_size=1, stride=1, padding=0 45 | ) 46 | 47 | self.out_feature_1_1 = out_feature # 200 48 | self.out_feature_1_2 = out_feature # 200 49 | self.out_feature_1_4 = out_feature # 200 50 | self.out_feature_1_8 = out_feature # 200 51 | self.out_feature_1_16 = out_feature # 200 52 | self.feature_1_16 = features // 2 53 | self.feature_1_8 = features // 4 54 | self.feature_1_4 = features // 8 55 | self.feature_1_2 = features // 16 56 | self.feature_1_1 = features // 32 57 | 58 | if self.use_decoder: 59 | self.resize_output_1_1 = nn.Conv2d( 60 | self.feature_1_1, self.out_feature_1_1, kernel_size=1 61 | ) 62 | self.resize_output_1_2 = nn.Conv2d( 63 | self.feature_1_2, self.out_feature_1_2, kernel_size=1 64 | ) 65 | self.resize_output_1_4 = nn.Conv2d( 66 | self.feature_1_4, self.out_feature_1_4, kernel_size=1 67 | ) 68 | self.resize_output_1_8 = nn.Conv2d( 69 | self.feature_1_8, self.out_feature_1_8, kernel_size=1 70 | ) 71 | self.resize_output_1_16 = nn.Conv2d( 72 | self.feature_1_16, self.out_feature_1_16, kernel_size=1 73 | ) 74 | 75 | self.up16 = UpSampleBN( 76 | skip_input=features + 224, output_features=self.feature_1_16 77 | ) 78 | self.up8 = UpSampleBN( 79 | skip_input=self.feature_1_16 + 80, output_features=self.feature_1_8 80 | ) 81 | self.up4 = UpSampleBN( 82 | skip_input=self.feature_1_8 + 48, output_features=self.feature_1_4 83 | ) 84 | self.up2 = UpSampleBN( 85 | skip_input=self.feature_1_4 + 32, output_features=self.feature_1_2 86 | ) 87 | self.up1 = UpSampleBN( 88 | skip_input=self.feature_1_2 + 3, output_features=self.feature_1_1 89 | ) 90 | else: 91 | self.resize_output_1_1 = nn.Conv2d(3, out_feature, kernel_size=1) 92 | self.resize_output_1_2 = nn.Conv2d(32, out_feature * 2, kernel_size=1) 93 | self.resize_output_1_4 = nn.Conv2d(48, out_feature * 4, kernel_size=1) 94 | 95 | def 
forward(self, features): 96 | x_block0, x_block1, x_block2, x_block3, x_block4 = ( 97 | features[4], 98 | features[5], 99 | features[6], 100 | features[8], 101 | features[11], # (B, 2560, 15, 20) 102 | ) 103 | bs = x_block0.shape[0] 104 | x_d0 = self.conv2(x_block4) # (B, 2560, 15, 20) 105 | 106 | if self.use_decoder: 107 | x_1_16 = self.up16(x_d0, x_block3) # (B, 1280, 30, 40) 108 | 109 | x_1_8 = self.up8(x_1_16, x_block2) # (B, 640, 60, 80) 110 | x_1_4 = self.up4(x_1_8, x_block1) 111 | x_1_2 = self.up2(x_1_4, x_block0) 112 | x_1_1 = self.up1(x_1_2, features[0]) 113 | return { 114 | "1_1": self.resize_output_1_1(x_1_1), 115 | "1_2": self.resize_output_1_2(x_1_2), 116 | "1_4": self.resize_output_1_4(x_1_4), 117 | "1_8": self.resize_output_1_8(x_1_8), 118 | "1_16": self.resize_output_1_16(x_1_16), 119 | } 120 | else: 121 | x_1_1 = features[0] 122 | x_1_2, x_1_4, x_1_8, x_1_16 = ( 123 | features[4], 124 | features[5], 125 | features[6], 126 | features[8], 127 | ) 128 | x_global = features[-1].reshape(bs, 2560, -1).mean(2) 129 | return { 130 | "1_1": self.resize_output_1_1(x_1_1), 131 | "1_2": self.resize_output_1_2(x_1_2), 132 | "1_4": self.resize_output_1_4(x_1_4), 133 | "global": x_global, 134 | } 135 | 136 | 137 | class Encoder(nn.Module): 138 | def __init__(self, backend, frozen_encoder=False): 139 | super(Encoder, self).__init__() 140 | self.original_model = backend 141 | self.frozen_encoder = frozen_encoder 142 | 143 | def forward(self, x): 144 | if self.frozen_encoder: 145 | self.eval() 146 | features = [x] 147 | for k, v in self.original_model._modules.items(): 148 | if k == "blocks": 149 | for ki, vi in v._modules.items(): 150 | features.append(vi(features[-1])) 151 | else: 152 | features.append(v(features[-1])) 153 | return features 154 | 155 | 156 | class UNet2D(nn.Module): 157 | def __init__(self, backend, num_features, out_feature, use_decoder=True, frozen_encoder=False): 158 | super(UNet2D, self).__init__() 159 | self.use_decoder = use_decoder 160 | 161 | self.encoder = Encoder(backend, frozen_encoder) 162 | self.decoder = DecoderBN( 163 | out_feature=out_feature, 164 | use_decoder=use_decoder, 165 | bottleneck_features=num_features, 166 | num_features=num_features, 167 | ) 168 | 169 | def forward(self, x, **kwargs): 170 | encoded_feats = self.encoder(x) 171 | unet_out = self.decoder(encoded_feats, **kwargs) 172 | return unet_out 173 | 174 | def get_encoder_params(self): # lr/10 learning rate 175 | return self.encoder.parameters() 176 | 177 | def get_decoder_params(self): # lr learning rate 178 | return self.decoder.parameters() 179 | 180 | @classmethod 181 | def build(cls, **kwargs): 182 | basemodel_name = "tf_efficientnet_b7_ns" 183 | num_features = 2560 184 | 185 | print("Loading base model ()...".format(basemodel_name), end="") 186 | basemodel = torch.hub.load( 187 | "rwightman/gen-efficientnet-pytorch", basemodel_name, pretrained=True 188 | ) 189 | print("Done.") 190 | 191 | # Remove last layer 192 | print("Removing last two layers (global_pool & classifier).") 193 | basemodel.global_pool = nn.Identity() 194 | basemodel.classifier = nn.Identity() 195 | 196 | # Building Encoder-Decoder model 197 | print("Building Encoder-Decoder model..", end="") 198 | m = cls(basemodel, num_features=num_features, **kwargs) 199 | print("Done.") 200 | return m 201 | 202 | if __name__ == '__main__': 203 | model = UNet2D.build(out_feature=256, use_decoder=True) 204 | -------------------------------------------------------------------------------- /iso/models/unet3d_nyu.py: 
-------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | Code adapted from https://github.com/astra-vision/MonoScene/blob/master/monoscene/models/unet3d_nyu.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | from iso.models.CRP3D import CPMegaVoxels 10 | from iso.models.modules import ( 11 | Process, 12 | Upsample, 13 | Downsample, 14 | SegmentationHead, 15 | ASPP, 16 | # Downsample_SE, 17 | # Process_SE, 18 | ) 19 | 20 | 21 | class UNet3D(nn.Module): 22 | def __init__( 23 | self, 24 | class_num, 25 | norm_layer, 26 | feature, 27 | full_scene_size, 28 | n_relations=4, 29 | project_res=[], 30 | context_prior=True, 31 | bn_momentum=0.1, 32 | ): 33 | super(UNet3D, self).__init__() 34 | self.business_layer = [] 35 | self.project_res = project_res 36 | 37 | self.feature_1_4 = feature 38 | self.feature_1_8 = feature * 2 39 | self.feature_1_16 = feature * 4 40 | 41 | self.feature_1_16_dec = self.feature_1_16 42 | self.feature_1_8_dec = self.feature_1_8 43 | self.feature_1_4_dec = self.feature_1_4 44 | 45 | self.process_1_4 = nn.Sequential( 46 | Process(self.feature_1_4, norm_layer, bn_momentum, dilations=[1, 2, 3]), 47 | Downsample(self.feature_1_4, norm_layer, bn_momentum), 48 | ) 49 | self.process_1_8 = nn.Sequential( 50 | Process(self.feature_1_8, norm_layer, bn_momentum, dilations=[1, 2, 3]), 51 | Downsample(self.feature_1_8, norm_layer, bn_momentum), 52 | ) 53 | self.up_1_16_1_8 = Upsample( 54 | self.feature_1_16_dec, self.feature_1_8_dec, norm_layer, bn_momentum 55 | ) 56 | self.up_1_8_1_4 = Upsample( 57 | self.feature_1_8_dec, self.feature_1_4_dec, norm_layer, bn_momentum 58 | ) 59 | self.ssc_head_1_4 = SegmentationHead( 60 | self.feature_1_4_dec, self.feature_1_4_dec, class_num, [1, 2, 3] 61 | ) 62 | 63 | self.context_prior = context_prior 64 | size_1_16 = tuple(np.ceil(i / 4).astype(int) for i in full_scene_size) 65 | 66 | if context_prior: 67 | self.CP_mega_voxels = CPMegaVoxels( 68 | self.feature_1_16, 69 | size_1_16, 70 | n_relations=n_relations, 71 | bn_momentum=bn_momentum, 72 | ) 73 | 74 | # 75 | def forward(self, input_dict): 76 | res = {} 77 | 78 | x3d_1_4 = input_dict["x3d"] 79 | x3d_1_8 = self.process_1_4(x3d_1_4) 80 | x3d_1_16 = self.process_1_8(x3d_1_8) 81 | 82 | if self.context_prior: 83 | ret = self.CP_mega_voxels(x3d_1_16) 84 | x3d_1_16 = ret["x"] 85 | for k in ret.keys(): 86 | res[k] = ret[k] 87 | 88 | x3d_up_1_8 = self.up_1_16_1_8(x3d_1_16) + x3d_1_8 89 | x3d_up_1_4 = self.up_1_8_1_4(x3d_up_1_8) + x3d_1_4 90 | 91 | ssc_logit_1_4 = self.ssc_head_1_4(x3d_up_1_4) 92 | 93 | res["ssc_logit"] = ssc_logit_1_4 94 | 95 | return res 96 | -------------------------------------------------------------------------------- /iso/scripts/eval.sh: -------------------------------------------------------------------------------- 1 | python iso/scripts/eval_iso.py n_gpus=1 batch_size=1 -------------------------------------------------------------------------------- /iso/scripts/eval_iso.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import Trainer 2 | from iso.models.iso import ISO 3 | from iso.data.NYU.nyu_dm import NYUDataModule 4 | import hydra 5 | from omegaconf import DictConfig 6 | import torch 7 | import os 8 | from hydra.utils import get_original_cwd 9 | from pytorch_lightning import seed_everything 10 | 11 | torch.set_float32_matmul_precision('high') 12 | 13 | 14 | 
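# Usage (see iso/scripts/eval.sh): Hydra overrides are passed on the command line, e.g.
#     python iso/scripts/eval_iso.py n_gpus=1 batch_size=1
# Any other key in iso/config/iso.yaml (dataset, voxeldepth, use_gt_depth, ...) can be
# overridden the same way. Note that the "kitti" branch below references KittiDataModule,
# which is not imported in this file, so that path would need the import before it could run.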
@hydra.main(config_name="../config/iso.yaml") 15 | def main(config: DictConfig): 16 | torch.set_grad_enabled(False) 17 | if config.dataset == "kitti": 18 | config.batch_size = 1 19 | n_classes = 20 20 | feature = 64 21 | project_scale = 2 22 | full_scene_size = (256, 256, 32) 23 | data_module = KittiDataModule( 24 | root=config.kitti_root, 25 | preprocess_root=config.kitti_preprocess_root, 26 | frustum_size=config.frustum_size, 27 | batch_size=int(config.batch_size / config.n_gpus), 28 | num_workers=int(config.num_workers_per_gpu * config.n_gpus), 29 | ) 30 | 31 | elif config.dataset == "NYU": 32 | config.batch_size = 2 33 | project_scale = 1 34 | n_classes = 12 35 | feature = 200 36 | full_scene_size = (60, 36, 60) 37 | data_module = NYUDataModule( 38 | root=config.NYU_root, 39 | preprocess_root=config.NYU_preprocess_root, 40 | n_relations=config.n_relations, 41 | frustum_size=config.frustum_size, 42 | batch_size=int(config.batch_size / config.n_gpus), 43 | num_workers=int(config.num_workers_per_gpu * config.n_gpus), 44 | ) 45 | 46 | trainer = Trainer( 47 | sync_batchnorm=True, deterministic=False, devices=config.n_gpus, accelerator="gpu", 48 | ) 49 | 50 | if config.dataset == "NYU": 51 | model_path = os.path.join( 52 | get_original_cwd(), "trained_models", "iso_nyu.ckpt" 53 | ) 54 | else: 55 | model_path = os.path.join( 56 | get_original_cwd(), "trained_models", "iso_kitti.ckpt" 57 | ) 58 | 59 | voxeldepth_res = [] 60 | if config.voxeldepth: 61 | if config.voxeldepthcfg.depth_scale_1: 62 | voxeldepth_res.append('1') 63 | if config.voxeldepthcfg.depth_scale_2: 64 | voxeldepth_res.append('2') 65 | if config.voxeldepthcfg.depth_scale_4: 66 | voxeldepth_res.append('4') 67 | if config.voxeldepthcfg.depth_scale_8: 68 | voxeldepth_res.append('8') 69 | 70 | os.chdir(hydra.utils.get_original_cwd()) 71 | 72 | model = ISO.load_from_checkpoint( 73 | model_path, 74 | feature=feature, 75 | project_scale=project_scale, 76 | fp_loss=config.fp_loss, 77 | full_scene_size=full_scene_size, 78 | voxeldepth=config.voxeldepth, 79 | voxeldepth_res=voxeldepth_res, 80 | # 81 | use_gt_depth=config.use_gt_depth, 82 | add_fusion=config.add_fusion, 83 | use_zoedepth=config.use_zoedepth, 84 | use_depthanything=config.use_depthanything, 85 | zoedepth_as_gt=config.zoedepth_as_gt, 86 | depthanything_as_gt=config.depthanything_as_gt, 87 | frozen_encoder=config.frozen_encoder, 88 | ) 89 | model.eval() 90 | data_module.setup() 91 | val_dataloader = data_module.val_dataloader() 92 | trainer.test(model, dataloaders=val_dataloader) 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /iso/scripts/generate_output.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import Trainer 2 | from iso.models.iso import ISO 3 | from iso.data.NYU.nyu_dm import NYUDataModule 4 | # from iso.data.semantic_kitti.kitti_dm import KittiDataModule 5 | # from iso.data.kitti_360.kitti_360_dm import Kitti360DataModule 6 | import hydra 7 | from omegaconf import DictConfig 8 | import torch 9 | import numpy as np 10 | import os 11 | from hydra.utils import get_original_cwd 12 | from tqdm import tqdm 13 | import pickle 14 | 15 | 16 | @hydra.main(config_name="../config/iso.yaml") 17 | def main(config: DictConfig): 18 | torch.set_grad_enabled(False) 19 | 20 | # Setup dataloader 21 | if config.dataset == "kitti" or config.dataset == "kitti_360": 22 | feature = 64 23 | project_scale = 2 24 | full_scene_size = 
(256, 256, 32) 25 | 26 | if config.dataset == "kitti": 27 | data_module = KittiDataModule( 28 | root=config.kitti_root, 29 | preprocess_root=config.kitti_preprocess_root, 30 | frustum_size=config.frustum_size, 31 | batch_size=int(config.batch_size / config.n_gpus), 32 | num_workers=int(config.num_workers_per_gpu * config.n_gpus), 33 | ) 34 | data_module.setup() 35 | data_loader = data_module.val_dataloader() 36 | # data_loader = data_module.test_dataloader() # use this if you want to infer on test set 37 | else: 38 | data_module = Kitti360DataModule( 39 | root=config.kitti_360_root, 40 | sequences=[config.kitti_360_sequence], 41 | n_scans=2000, 42 | batch_size=1, 43 | num_workers=3, 44 | ) 45 | data_module.setup() 46 | data_loader = data_module.dataloader() 47 | 48 | elif config.dataset == "NYU": 49 | project_scale = 1 50 | feature = 200 51 | full_scene_size = (60, 36, 60) 52 | data_module = NYUDataModule( 53 | root=config.NYU_root, 54 | preprocess_root=config.NYU_preprocess_root, 55 | n_relations=config.n_relations, 56 | frustum_size=config.frustum_size, 57 | batch_size=int(config.batch_size / config.n_gpus), 58 | num_workers=int(config.num_workers_per_gpu * config.n_gpus), 59 | ) 60 | data_module.setup() 61 | data_loader = data_module.val_dataloader() 62 | # data_loader = data_module.test_dataloader() # use this if you want to infer on test set 63 | else: 64 | print("dataset not support") 65 | 66 | # Load pretrained models 67 | if config.dataset == "NYU": 68 | model_path = os.path.join( 69 | get_original_cwd(), "trained_models", "iso_nyu.ckpt" 70 | ) 71 | else: 72 | model_path = os.path.join( 73 | get_original_cwd(), "trained_models", "iso_kitti.ckpt" 74 | ) 75 | 76 | voxeldepth_res = [] 77 | if config.voxeldepth: 78 | if config.voxeldepthcfg.depth_scale_1: 79 | voxeldepth_res.append('1') 80 | if config.voxeldepthcfg.depth_scale_2: 81 | voxeldepth_res.append('2') 82 | if config.voxeldepthcfg.depth_scale_4: 83 | voxeldepth_res.append('4') 84 | if config.voxeldepthcfg.depth_scale_8: 85 | voxeldepth_res.append('8') 86 | 87 | os.chdir(hydra.utils.get_original_cwd()) 88 | 89 | model = ISO.load_from_checkpoint( 90 | model_path, 91 | feature=feature, 92 | project_scale=project_scale, 93 | fp_loss=config.fp_loss, 94 | full_scene_size=full_scene_size, 95 | voxeldepth=config.voxeldepth, 96 | voxeldepth_res=voxeldepth_res, 97 | # 98 | use_gt_depth=config.use_gt_depth, 99 | add_fusion=config.add_fusion, 100 | use_zoedepth=config.use_zoedepth, 101 | use_depthanything=config.use_depthanything, 102 | zoedepth_as_gt=config.zoedepth_as_gt, 103 | depthanything_as_gt=config.depthanything_as_gt, 104 | frozen_encoder=config.frozen_encoder, 105 | ) 106 | model.cuda() 107 | model.eval() 108 | 109 | # Save prediction and additional data 110 | # to draw the viewing frustum and remove scene outside the room for NYUv2 111 | output_path = os.path.join(config.output_path, config.dataset) 112 | with torch.no_grad(): 113 | for batch in tqdm(data_loader): 114 | batch["img"] = batch["img"].cuda() 115 | pred = model(batch) 116 | y_pred = torch.softmax(pred["ssc_logit"], dim=1).detach().cpu().numpy() 117 | y_pred = np.argmax(y_pred, axis=1) 118 | for i in range(config.batch_size): 119 | out_dict = {"y_pred": y_pred[i].astype(np.uint16)} 120 | if "target" in batch: 121 | out_dict["target"] = ( 122 | batch["target"][i].detach().cpu().numpy().astype(np.uint16) 123 | ) 124 | 125 | if config.dataset == "NYU": 126 | write_path = output_path 127 | filepath = os.path.join(write_path, batch["name"][i] + ".pkl") 128 | 
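                    # These extra keys (cam_pose, vox_origin) are what
                    # iso/scripts/visualization/NYU_vis_pred.py expects alongside
                    # y_pred/target when rendering NYU predictions.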
out_dict["cam_pose"] = batch["cam_pose"][i].detach().cpu().numpy() 129 | out_dict["vox_origin"] = ( 130 | batch["vox_origin"][i].detach().cpu().numpy() 131 | ) 132 | else: 133 | write_path = os.path.join(output_path, batch["sequence"][i]) 134 | filepath = os.path.join(write_path, batch["frame_id"][i] + ".pkl") 135 | out_dict["fov_mask_1"] = ( 136 | batch["fov_mask_1"][i].detach().cpu().numpy() 137 | ) 138 | out_dict["cam_k"] = batch["cam_k"][i].detach().cpu().numpy() 139 | out_dict["T_velo_2_cam"] = ( 140 | batch["T_velo_2_cam"][i].detach().cpu().numpy() 141 | ) 142 | 143 | os.makedirs(write_path, exist_ok=True) 144 | with open(filepath, "wb") as handle: 145 | pickle.dump(out_dict, handle) 146 | print("wrote to", filepath) 147 | 148 | 149 | if __name__ == "__main__": 150 | main() 151 | -------------------------------------------------------------------------------- /iso/scripts/train.sh: -------------------------------------------------------------------------------- 1 | python iso/scripts/train_iso.py n_gpus=2 batch_size=4 -------------------------------------------------------------------------------- /iso/scripts/train_iso.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/home/hongxiao.yu/projects/ISO_occscannet/depth_anything') 3 | from iso.data.NYU.params import ( 4 | class_weights as NYU_class_weights, 5 | NYU_class_names, 6 | ) 7 | from iso.data.OccScanNet.params import ( 8 | class_weights as OccScanNet_class_weights, 9 | OccScanNet_class_names 10 | ) 11 | from iso.data.NYU.nyu_dm import NYUDataModule 12 | from iso.data.OccScanNet.occscannet_dm import OccScanNetDataModule 13 | from torch.utils.data.dataloader import DataLoader 14 | from iso.models.iso import ISO 15 | from pytorch_lightning import Trainer 16 | import pytorch_lightning as pl 17 | from pytorch_lightning.loggers import TensorBoardLogger 18 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 19 | import os 20 | import hydra 21 | from omegaconf import DictConfig 22 | import numpy as np 23 | import torch 24 | 25 | hydra.output_subdir = None 26 | 27 | pl.seed_everything(658018589) #, workers=True) 28 | 29 | @hydra.main(config_name="../config/iso_occscannet_mini.yaml", config_path='.') 30 | def main(config: DictConfig): 31 | exp_name = config.exp_prefix 32 | exp_name += "_{}_{}".format(config.dataset, config.run) 33 | exp_name += "_FrusSize_{}".format(config.frustum_size) 34 | exp_name += "_nRelations{}".format(config.n_relations) 35 | exp_name += "_WD{}_lr{}".format(config.weight_decay, config.lr) 36 | 37 | if config.use_gt_depth: 38 | exp_name += '_gtdepth' 39 | if config.add_fusion: 40 | exp_name += '_addfusion' 41 | if not config.use_zoedepth: 42 | exp_name += '_nozoedepth' 43 | if config.zoedepth_as_gt: 44 | exp_name += '_zoedepthgt' 45 | if config.frozen_encoder: 46 | exp_name += '_frozen_encoder' 47 | if config.use_depthanything: 48 | exp_name += '_depthanything' 49 | 50 | voxeldepth_res = [] 51 | if config.voxeldepth: 52 | exp_name += '_VoxelDepth' 53 | if config.voxeldepthcfg.depth_scale_1: 54 | exp_name += '_1' 55 | voxeldepth_res.append('1') 56 | if config.voxeldepthcfg.depth_scale_2: 57 | exp_name += '_2' 58 | voxeldepth_res.append('2') 59 | if config.voxeldepthcfg.depth_scale_4: 60 | exp_name += '_4' 61 | voxeldepth_res.append('4') 62 | if config.voxeldepthcfg.depth_scale_8: 63 | exp_name += '_8' 64 | voxeldepth_res.append('8') 65 | 66 | if config.CE_ssc_loss: 67 | exp_name += "_CEssc" 68 | if config.geo_scal_loss: 
69 | exp_name += "_geoScalLoss" 70 | if config.sem_scal_loss: 71 | exp_name += "_semScalLoss" 72 | if config.fp_loss: 73 | exp_name += "_fpLoss" 74 | 75 | if config.relation_loss: 76 | exp_name += "_CERel" 77 | if config.context_prior: 78 | exp_name += "_3DCRP" 79 | 80 | # Setup dataloaders 81 | if config.dataset == "OccScanNet": 82 | class_names = OccScanNet_class_names 83 | max_epochs = 60 84 | logdir = config.logdir 85 | full_scene_size = (60, 60, 36) 86 | project_scale = 1 87 | feature = 200 88 | n_classes = 12 89 | class_weights = OccScanNet_class_weights 90 | data_module = OccScanNetDataModule( 91 | root=config.OccScanNet_root, 92 | n_relations=config.n_relations, 93 | frustum_size=config.frustum_size, 94 | batch_size=int(config.batch_size / config.n_gpus), 95 | num_workers=int(config.num_workers_per_gpu * config.n_gpus), 96 | ) 97 | elif config.dataset == "OccScanNet_mini": 98 | class_names = OccScanNet_class_names 99 | max_epochs = 60 100 | logdir = config.logdir 101 | full_scene_size = (60, 60, 36) 102 | project_scale = 1 103 | feature = 200 104 | n_classes = 12 105 | class_weights = OccScanNet_class_weights 106 | data_module = OccScanNetDataModule( 107 | root=config.OccScanNet_root, 108 | n_relations=config.n_relations, 109 | frustum_size=config.frustum_size, 110 | batch_size=int(config.batch_size / config.n_gpus), 111 | num_workers=int(config.num_workers_per_gpu * config.n_gpus), 112 | train_scenes_sample=4639, 113 | val_scenes_sample=2007, 114 | ) 115 | 116 | elif config.dataset == "NYU": 117 | class_names = NYU_class_names 118 | max_epochs = 30 119 | logdir = config.logdir 120 | full_scene_size = (60, 36, 60) 121 | project_scale = 1 122 | feature = 200 123 | n_classes = 12 124 | class_weights = NYU_class_weights 125 | data_module = NYUDataModule( 126 | root=config.NYU_root, 127 | preprocess_root=config.NYU_preprocess_root, 128 | n_relations=config.n_relations, 129 | frustum_size=config.frustum_size, 130 | batch_size=int(config.batch_size / config.n_gpus), 131 | num_workers=int(config.num_workers_per_gpu * config.n_gpus), 132 | ) 133 | 134 | project_res = ["1"] 135 | if config.project_1_2: 136 | exp_name += "_Proj_2" 137 | project_res.append("2") 138 | if config.project_1_4: 139 | exp_name += "_4" 140 | project_res.append("4") 141 | if config.project_1_8: 142 | exp_name += "_8" 143 | project_res.append("8") 144 | 145 | print(exp_name) 146 | 147 | os.chdir(hydra.utils.get_original_cwd()) 148 | 149 | # Initialize ISO model 150 | model = ISO( 151 | dataset=config.dataset, 152 | frustum_size=config.frustum_size, 153 | project_scale=project_scale, 154 | n_relations=config.n_relations, 155 | fp_loss=config.fp_loss, 156 | feature=feature, 157 | full_scene_size=full_scene_size, 158 | project_res=project_res, 159 | voxeldepth=config.voxeldepth, 160 | voxeldepth_res=voxeldepth_res, 161 | n_classes=n_classes, 162 | class_names=class_names, 163 | context_prior=config.context_prior, 164 | relation_loss=config.relation_loss, 165 | CE_ssc_loss=config.CE_ssc_loss, 166 | sem_scal_loss=config.sem_scal_loss, 167 | geo_scal_loss=config.geo_scal_loss, 168 | lr=config.lr, 169 | weight_decay=config.weight_decay, 170 | class_weights=class_weights, 171 | use_gt_depth=config.use_gt_depth, 172 | add_fusion=config.add_fusion, 173 | use_zoedepth=config.use_zoedepth, 174 | use_depthanything=config.use_depthanything, 175 | zoedepth_as_gt=config.zoedepth_as_gt, 176 | depthanything_as_gt=config.depthanything_as_gt, 177 | frozen_encoder=config.frozen_encoder, 178 | ) 179 | 180 | if config.enable_log: 181 | 
logger = TensorBoardLogger(save_dir=logdir, name=exp_name, version="") 182 | lr_monitor = LearningRateMonitor(logging_interval="step") 183 | checkpoint_callbacks = [ 184 | ModelCheckpoint( 185 | save_last=True, 186 | monitor="val/mIoU", 187 | save_top_k=1, 188 | mode="max", 189 | filename="{epoch:03d}-{val/mIoU:.5f}", 190 | ), 191 | lr_monitor, 192 | ] 193 | else: 194 | logger = False 195 | checkpoint_callbacks = False 196 | 197 | model_path = os.path.join(logdir, exp_name, "checkpoints/last.ckpt") 198 | if os.path.isfile(model_path): 199 | # Continue training from last.ckpt 200 | trainer = Trainer( 201 | callbacks=checkpoint_callbacks, 202 | resume_from_checkpoint=model_path, 203 | sync_batchnorm=True, 204 | deterministic=False, 205 | max_epochs=max_epochs, 206 | devices=config.n_gpus, 207 | accelerator="gpu", 208 | logger=logger, 209 | check_val_every_n_epoch=1, 210 | log_every_n_steps=10, 211 | # flush_logs_every_n_steps=100, 212 | # strategy="ddp_find_unused_parameters_true", 213 | ) 214 | else: 215 | # Train from scratch 216 | trainer = Trainer( 217 | callbacks=checkpoint_callbacks, 218 | sync_batchnorm=True, 219 | deterministic=False, 220 | max_epochs=max_epochs, 221 | devices=config.n_gpus, 222 | accelerator='gpu', 223 | logger=logger, 224 | check_val_every_n_epoch=1, 225 | log_every_n_steps=10, 226 | # flush_logs_every_n_steps=100, 227 | strategy="ddp_find_unused_parameters_true", 228 | ) 229 | torch.set_float32_matmul_precision('high') 230 | # os.chdir("/home/hongxiao.yu/projects/ISO_occscannet") 231 | trainer.fit(model, data_module) 232 | 233 | 234 | if __name__ == "__main__": 235 | main() 236 | -------------------------------------------------------------------------------- /iso/scripts/visualization/NYU_vis_pred.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | from omegaconf import DictConfig 4 | import numpy as np 5 | import hydra 6 | from mayavi import mlab 7 | 8 | 9 | def get_grid_coords(dims, resolution): 10 | """ 11 | :param dims: the dimensions of the grid [x, y, z] (i.e. [256, 256, 32]) 12 | :return coords_grid: is the center coords of voxels in the grid 13 | """ 14 | 15 | g_xx = np.arange(0, dims[0] + 1) 16 | g_yy = np.arange(0, dims[1] + 1) 17 | 18 | g_zz = np.arange(0, dims[2] + 1) 19 | 20 | # Obtaining the grid with coords... 
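    # The lines below enumerate one (x, y, z) index per voxel, shift each index to the
    # voxel center via index * resolution + resolution / 2, and then swap the first two
    # columns, presumably to match the axis permutation applied to `voxels` in draw().
    # Note: `np.float` is deprecated and was removed in NumPy 1.24; `float` or
    # `np.float64` would be needed on recent NumPy versions.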
21 | xx, yy, zz = np.meshgrid(g_xx[:-1], g_yy[:-1], g_zz[:-1]) 22 | coords_grid = np.array([xx.flatten(), yy.flatten(), zz.flatten()]).T 23 | coords_grid = coords_grid.astype(np.float) 24 | 25 | coords_grid = (coords_grid * resolution) + resolution / 2 26 | 27 | temp = np.copy(coords_grid) 28 | temp[:, 0] = coords_grid[:, 1] 29 | temp[:, 1] = coords_grid[:, 0] 30 | coords_grid = np.copy(temp) 31 | 32 | return coords_grid 33 | 34 | 35 | def draw( 36 | voxels, 37 | cam_pose, 38 | vox_origin, 39 | voxel_size=0.08, 40 | d=0.75, # 0.75m - determine the size of the mesh representing the camera 41 | ): 42 | # Compute the coordinates of the mesh representing camera 43 | y = d * 480 / (2 * 518.8579) 44 | x = d * 640 / (2 * 518.8579) 45 | tri_points = np.array( 46 | [ 47 | [0, 0, 0], 48 | [x, y, d], 49 | [-x, y, d], 50 | [-x, -y, d], 51 | [x, -y, d], 52 | ] 53 | ) 54 | tri_points = np.hstack([tri_points, np.ones((5, 1))]) 55 | 56 | tri_points = (cam_pose @ tri_points.T).T 57 | x = tri_points[:, 0] - vox_origin[0] 58 | y = tri_points[:, 1] - vox_origin[1] 59 | z = tri_points[:, 2] - vox_origin[2] 60 | triangles = [ 61 | (0, 1, 2), 62 | (0, 1, 4), 63 | (0, 3, 4), 64 | (0, 2, 3), 65 | ] 66 | 67 | # Compute the voxels coordinates 68 | grid_coords = get_grid_coords( 69 | [voxels.shape[0], voxels.shape[2], voxels.shape[1]], voxel_size 70 | ) 71 | 72 | # Attach the predicted class to every voxel 73 | grid_coords = np.vstack( 74 | (grid_coords.T, np.moveaxis(voxels, [0, 1, 2], [0, 2, 1]).reshape(-1)) 75 | ).T 76 | 77 | # Remove empty and unknown voxels 78 | occupied_voxels = grid_coords[(grid_coords[:, 3] > 0) & (grid_coords[:, 3] < 255)] 79 | figure = mlab.figure(size=(1600, 900), bgcolor=(1, 1, 1)) 80 | 81 | # Draw the camera 82 | mlab.triangular_mesh( 83 | x, 84 | y, 85 | z, 86 | triangles, 87 | representation="wireframe", 88 | color=(0, 0, 0), 89 | line_width=5, 90 | ) 91 | 92 | # Draw occupied voxels 93 | plt_plot = mlab.points3d( 94 | occupied_voxels[:, 0], 95 | occupied_voxels[:, 1], 96 | occupied_voxels[:, 2], 97 | occupied_voxels[:, 3], 98 | colormap="viridis", 99 | scale_factor=voxel_size - 0.1 * voxel_size, 100 | mode="cube", 101 | opacity=1.0, 102 | vmin=0, 103 | vmax=12, 104 | ) 105 | 106 | colors = np.array( 107 | [ 108 | [22, 191, 206, 255], 109 | [214, 38, 40, 255], 110 | [43, 160, 43, 255], 111 | [158, 216, 229, 255], 112 | [114, 158, 206, 255], 113 | [204, 204, 91, 255], 114 | [255, 186, 119, 255], 115 | [147, 102, 188, 255], 116 | [30, 119, 181, 255], 117 | [188, 188, 33, 255], 118 | [255, 127, 12, 255], 119 | [196, 175, 214, 255], 120 | [153, 153, 153, 255], 121 | ] 122 | ) 123 | 124 | plt_plot.glyph.scale_mode = "scale_by_vector" 125 | 126 | plt_plot.module_manager.scalar_lut_manager.lut.table = colors 127 | 128 | mlab.show() 129 | 130 | 131 | @hydra.main(config_path=None) 132 | def main(config: DictConfig): 133 | scan = config.file 134 | 135 | with open(scan, "rb") as handle: 136 | b = pickle.load(handle) 137 | 138 | cam_pose = b["cam_pose"] 139 | vox_origin = b["vox_origin"] 140 | gt_scene = b["target"] 141 | pred_scene = b["y_pred"] 142 | scan = os.path.basename(scan)[:12] 143 | 144 | pred_scene[(gt_scene == 255)] = 255 # only draw scene inside the room 145 | 146 | draw( 147 | pred_scene, 148 | cam_pose, 149 | vox_origin, 150 | voxel_size=0.08, 151 | d=0.75, 152 | ) 153 | 154 | 155 | if __name__ == "__main__": 156 | main() 157 | -------------------------------------------------------------------------------- /iso/scripts/visualization/NYU_vis_pred_2.py: 
-------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | from omegaconf import DictConfig 4 | import numpy as np 5 | import hydra 6 | # from mayavi import mlab 7 | import open3d as o3d 8 | 9 | 10 | def get_grid_coords(dims, resolution): 11 | """ 12 | :param dims: the dimensions of the grid [x, y, z] (i.e. [256, 256, 32]) 13 | :return coords_grid: is the center coords of voxels in the grid 14 | """ 15 | 16 | # The sensor in centered in X (we go to dims/2 + 1 for the histogramdd) 17 | g_xx = np.arange(0, dims[0] + 1) 18 | # The sensor is in Y=0 (we go to dims + 1 for the histogramdd) 19 | g_yy = np.arange(0, dims[1] + 1) 20 | # The sensor is in Z=1.73. I observed that the ground was to voxel levels above the grid bottom, so Z pose is at 10 21 | # if bottom voxel is 0. If we want the sensor to be at (0, 0, 0), then the bottom in z is -10, top is 22 22 | # (we go to 22 + 1 for the histogramdd) 23 | # ATTENTION.. Is 11 for old grids.. 10 for new grids (v1.1) (https://github.com/PRBonn/semantic-kitti-api/issues/49) 24 | g_zz = np.arange(0, dims[2] + 1) 25 | 26 | # Obtaining the grid with coords... 27 | xx, yy, zz = np.meshgrid(g_xx[:-1], g_yy[:-1], g_zz[:-1]) 28 | coords_grid = np.array([xx.flatten(), yy.flatten(), zz.flatten()]).T 29 | coords_grid = coords_grid.astype(np.float) 30 | 31 | coords_grid = (coords_grid * resolution) + resolution / 2 32 | 33 | temp = np.copy(coords_grid) 34 | temp[:, 0] = coords_grid[:, 1] 35 | temp[:, 1] = coords_grid[:, 0] 36 | coords_grid = np.copy(temp) 37 | 38 | return coords_grid 39 | 40 | 41 | 42 | # def draw_semantic_open3d( 43 | # voxels, 44 | # cam_param_path="", 45 | # voxel_size=0.2): 46 | 47 | 48 | 49 | # grid_coords, _, _, _ = get_grid_coords([voxels.shape[0], voxels.shape[1], voxels.shape[2]], voxel_size) 50 | 51 | # points = np.vstack([grid_coords.T, voxels.reshape(-1)]).T 52 | 53 | # # Obtaining voxels with semantic class 54 | # points = points[(points[:, 3] != 0) & (points[:, 3] != 255)] # remove empty voxel and unknown class 55 | 56 | # colors = np.take_along_axis(colors, points[:, 3].astype(np.int32).reshape(-1, 1), axis=0) 57 | 58 | # vis = o3d.visualization.Visualizer() 59 | # vis.create_window(width=1200, height=600) 60 | # ctr = vis.get_view_control() 61 | # param = o3d.io.read_pinhole_camera_parameters(cam_param_path) 62 | 63 | # pcd = o3d.geometry.PointCloud() 64 | # pcd.points = o3d.utility.Vector3dVector(points[:, :3]) 65 | # pcd.colors = o3d.utility.Vector3dVector(colors[:, :3]) 66 | # pcd.estimate_normals() 67 | # vis.add_geometry(pcd) 68 | 69 | # ctr.convert_from_pinhole_camera_parameters(param) 70 | 71 | # vis.run() # user changes the view and press "q" to terminate 72 | # param = vis.get_view_control().convert_to_pinhole_camera_parameters() 73 | # o3d.io.write_pinhole_camera_parameters(cam_param_path, param) 74 | 75 | 76 | 77 | 78 | 79 | def draw( 80 | voxels, 81 | cam_pose, 82 | vox_origin, 83 | voxel_size=0.08, 84 | d=0.75, # 0.75m - determine the size of the mesh representing the camera 85 | ): 86 | # Compute the coordinates of the mesh representing camera 87 | y = d * 480 / (2 * 518.8579) 88 | x = d * 640 / (2 * 518.8579) 89 | tri_points = np.array( 90 | [ 91 | [0, 0, 0], 92 | [x, y, d], 93 | [-x, y, d], 94 | [-x, -y, d], 95 | [x, -y, d], 96 | ] 97 | ) 98 | tri_points = np.hstack([tri_points, np.ones((5, 1))]) 99 | 100 | tri_points = (cam_pose @ tri_points.T).T 101 | x = tri_points[:, 0] - vox_origin[0] 102 | y = tri_points[:, 1] - vox_origin[1] 103 | z = tri_points[:, 2] 
- vox_origin[2] 104 | triangles = [ 105 | (0, 1, 2), 106 | (0, 1, 4), 107 | (0, 3, 4), 108 | (0, 2, 3), 109 | ] 110 | 111 | 112 | # Compute the voxels coordinates 113 | grid_coords = get_grid_coords( 114 | [voxels.shape[0], voxels.shape[2], voxels.shape[1]], voxel_size 115 | ) 116 | 117 | # Attach the predicted class to every voxel 118 | grid_coords = np.vstack( 119 | (grid_coords.T, np.moveaxis(voxels, [0, 1, 2], [0, 2, 1]).reshape(-1)) 120 | ).T 121 | 122 | # Remove empty and unknown voxels 123 | occupied_voxels = grid_coords[(grid_coords[:, 3] > 0) & (grid_coords[:, 3] < 255)] 124 | 125 | # # Draw the camera 126 | # mlab.triangular_mesh( 127 | # x, 128 | # y, 129 | # z, 130 | # triangles, 131 | # representation="wireframe", 132 | # color=(0, 0, 0), 133 | # line_width=5, 134 | # ) 135 | 136 | # # Draw occupied voxels 137 | # plt_plot = mlab.points3d( 138 | # occupied_voxels[:, 0], 139 | # occupied_voxels[:, 1], 140 | # occupied_voxels[:, 2], 141 | # occupied_voxels[:, 3], 142 | # colormap="viridis", 143 | # scale_factor=voxel_size - 0.1 * voxel_size, 144 | # mode="cube", 145 | # opacity=1.0, 146 | # vmin=0, 147 | # vmax=12, 148 | # ) 149 | 150 | colors = np.array( 151 | [ 152 | [22, 191, 206, 255], 153 | [214, 38, 40, 255], 154 | [43, 160, 43, 255], 155 | [158, 216, 229, 255], 156 | [114, 158, 206, 255], 157 | [204, 204, 91, 255], 158 | [255, 186, 119, 255], 159 | [147, 102, 188, 255], 160 | [30, 119, 181, 255], 161 | [188, 188, 33, 255], 162 | [255, 127, 12, 255], 163 | [196, 175, 214, 255], 164 | [153, 153, 153, 255], 165 | ] 166 | ) / 255.0 167 | 168 | colors = np.take_along_axis(colors, occupied_voxels[:, 3].astype(np.int32).reshape(-1, 1), axis=0) 169 | 170 | vis = o3d.visualization.Visualizer() 171 | vis.create_window(width=1200, height=600) 172 | ctr = vis.get_view_control() 173 | param = o3d.io.read_pinhole_camera_parameters(cam_pose) 174 | 175 | pcd = o3d.geometry.PointCloud() 176 | pcd.points = o3d.utility.Vector3dVector(occupied_voxels[:, :3]) 177 | pcd.colors = o3d.utility.Vector3dVector(colors[:, :3]) 178 | pcd.estimate_normals() 179 | vis.add_geometry(pcd) 180 | 181 | ctr.convert_from_pinhole_camera_parameters(cam_pose) 182 | 183 | vis.run() # user changes the view and press "q" to terminate 184 | param = vis.get_view_control().convert_to_pinhole_camera_parameters() 185 | # o3d.io.write_pinhole_camera_parameters(cam_param_path, param) 186 | 187 | 188 | @hydra.main(config_path=None) 189 | def main(config: DictConfig): 190 | scan = config.file 191 | 192 | with open(scan, "rb") as handle: 193 | b = pickle.load(handle) 194 | 195 | cam_pose = b["cam_pose"] 196 | vox_origin = b["vox_origin"] 197 | gt_scene = b["target"] 198 | pred_scene = b["y_pred"] 199 | scan = os.path.basename(scan)[:12] 200 | 201 | pred_scene[(gt_scene == 255)] = 255 # only draw scene inside the room 202 | 203 | draw( 204 | pred_scene, 205 | cam_pose, 206 | vox_origin, 207 | voxel_size=0.08, 208 | d=0.75, 209 | ) 210 | 211 | 212 | if __name__ == "__main__": 213 | main() 214 | -------------------------------------------------------------------------------- /iso/scripts/visualization/OccScanNet_vis_pred.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mayavi import mlab 3 | import argparse 4 | 5 | 6 | def load_voxels(path): 7 | """Load voxel labels from file. 8 | 9 | Args: 10 | path (str): The path of the voxel labels file. 11 | 12 | Returns: 13 | ndarray: The voxel labels with shape (N, 4), 4 is for [x, y, z, label]. 
14 | """ 15 | labels = np.load(path) 16 | if labels.shape[1] == 7: 17 | labels = labels[:, [0, 1, 2, 6]] 18 | 19 | return labels 20 | 21 | 22 | def draw(voxel_label, voxel_size=0.05, intrinsic=None, cam_pose=None, d=0.5): 23 | """Visualize the gt or predicted voxel labels. 24 | 25 | Args: 26 | voxel_label (ndarray): The gt or predicted voxel label, with shape (N, 4), N is for number 27 | of voxels, 7 is for [x, y, z, label]. 28 | voxel_size (double): The size of each voxel. 29 | intrinsic (ndarray): The camera intrinsics. 30 | cam_pose (ndarray): The camera pose. 31 | d (double): The depth of camera model visualization. 32 | """ 33 | figure = mlab.figure(size=(1600*0.8, 900*0.8), bgcolor=(1, 1, 1)) 34 | 35 | # voxel_origin = np.array([-0.6619388, 36 | # -2.3863946, 37 | # -0.05 ]) 38 | # voxel_1 = voxel_origin + np.array([4.8, 0, 0]) 39 | # voxel_2 = voxel_origin + np.array([0, 4.8, 0]) 40 | # voxel_3 = voxel_origin + np.array([0, 0, 4.8]) 41 | # voxel_4 = voxel_origin + np.array([4.8, 4.8, 0]) 42 | # voxel_5 = voxel_origin + np.array([4.8, 0, 4.8]) 43 | # voxel_6 = voxel_origin + np.array([0, 4.8, 4.8]) 44 | # voxel_7 = voxel_origin + np.array([4.8, 4.8, 4.8]) 45 | # voxels = np.vstack([voxel_origin, voxel_1, voxel_2, voxel_3, voxel_4, voxel_5, voxel_6, voxel_7]) 46 | # print(voxels.shape) 47 | # x = voxels[:, 0] 48 | # y = voxels[:, 1] 49 | # z = voxels[:, 2] 50 | # sqs = [ 51 | # (0, 1, 2), 52 | # (0, 1, 3), 53 | # (0, 2, 3), 54 | # (1, 2, 4), 55 | # (1, 3, 5), 56 | # (2, 3, 6), 57 | # (3, 5, 6), 58 | # (5, 6, 7), 59 | # ] 60 | 61 | # # draw cam model 62 | # mlab.triangular_mesh( 63 | # x, 64 | # y, 65 | # z, 66 | # sqs, 67 | # representation="wireframe", 68 | # color=(1, 0, 0), 69 | # line_width=7.5, 70 | # ) 71 | 72 | 73 | if intrinsic is not None and cam_pose is not None: 74 | assert d > 0, 'camera model d should > 0' 75 | fx = intrinsic[0, 0] 76 | fy = intrinsic[1, 1] 77 | cx = intrinsic[0, 2] 78 | cy = intrinsic[1, 2] 79 | 80 | # half of the image plane size 81 | y = d * 2 * cy / (2 * fy) 82 | x = d * 2 * cx / (2 * fx) 83 | 84 | # camera points (cam frame) 85 | tri_points = np.array( 86 | [ 87 | [0, 0, 0], 88 | [x, y, d], 89 | [-x, y, d], 90 | [-x, -y, d], 91 | [x, -y, d], 92 | ] 93 | ) 94 | tri_points = np.hstack([tri_points, np.ones((5, 1))]) 95 | 96 | # camera points (world frame) 97 | tri_points = (cam_pose @ tri_points.T).T 98 | x = tri_points[:, 0] 99 | y = tri_points[:, 1] 100 | z = tri_points[:, 2] 101 | triangles = [ 102 | (0, 1, 2), 103 | (0, 1, 4), 104 | (0, 3, 4), 105 | (0, 2, 3), 106 | ] 107 | 108 | # draw cam model 109 | mlab.triangular_mesh( 110 | x, 111 | y, 112 | z, 113 | triangles, 114 | representation="wireframe", 115 | color=(0, 0, 0), 116 | line_width=7.5, 117 | ) 118 | 119 | # draw occupied voxels 120 | plt_plot = mlab.points3d( 121 | voxel_label[:, 0], 122 | voxel_label[:, 1], 123 | voxel_label[:, 2], 124 | voxel_label[:, 3], 125 | colormap="viridis", 126 | scale_factor=voxel_size - 0.1 * voxel_size, 127 | mode="cube", 128 | opacity=1.0, 129 | vmin=0, 130 | vmax=12, 131 | ) 132 | 133 | # label colors 134 | colors = np.array( 135 | [ 136 | [0, 0, 0, 255], # 0 empty 137 | [255, 202, 251, 255], # 1 ceiling 138 | [208, 192, 122, 255], # 2 floor 139 | [199, 210, 255, 255], # 3 wall 140 | [82, 42, 127, 255], # 4 window 141 | [224, 250, 30, 255], # 5 chair 142 | [255, 0, 65, 255], # 6 bed 143 | [144, 177, 144, 255], # 7 sofa 144 | [246, 110, 31, 255], # 8 table 145 | [0, 216, 0, 255], # 9 tv 146 | [135, 177, 214, 255], # 10 furniture 147 | [1, 92, 121, 255], # 11 
objects 148 | [128, 128, 128, 255], # 12 occupied with semantic 149 | ] 150 | ) 151 | 152 | plt_plot.glyph.scale_mode = "scale_by_vector" 153 | 154 | plt_plot.module_manager.scalar_lut_manager.lut.table = colors 155 | 156 | mlab.show() 157 | 158 | 159 | def parse_args(): 160 | parser = argparse.ArgumentParser(description="CompleteScanNet dataset visualization.") 161 | parser.add_argument("--file", type=str, help="Voxel label file path.", required=True) 162 | args = parser.parse_args() 163 | return args 164 | 165 | 166 | if __name__ == "__main__": 167 | args = parse_args() 168 | voxels = load_voxels(args.file) 169 | draw(voxels, voxel_size=0.05, d=0.5) -------------------------------------------------------------------------------- /iso/scripts/visualization/kitti_vis_pred.py: -------------------------------------------------------------------------------- 1 | # from operator import gt 2 | import pickle 3 | import numpy as np 4 | from omegaconf import DictConfig 5 | import hydra 6 | from mayavi import mlab 7 | 8 | 9 | def get_grid_coords(dims, resolution): 10 | """ 11 | :param dims: the dimensions of the grid [x, y, z] (i.e. [256, 256, 32]) 12 | :return coords_grid: is the center coords of voxels in the grid 13 | """ 14 | 15 | g_xx = np.arange(0, dims[0] + 1) 16 | g_yy = np.arange(0, dims[1] + 1) 17 | sensor_pose = 10 18 | g_zz = np.arange(0, dims[2] + 1) 19 | 20 | # Obtaining the grid with coords... 21 | xx, yy, zz = np.meshgrid(g_xx[:-1], g_yy[:-1], g_zz[:-1]) 22 | coords_grid = np.array([xx.flatten(), yy.flatten(), zz.flatten()]).T 23 | coords_grid = coords_grid.astype(np.float) 24 | 25 | coords_grid = (coords_grid * resolution) + resolution / 2 26 | 27 | temp = np.copy(coords_grid) 28 | temp[:, 0] = coords_grid[:, 1] 29 | temp[:, 1] = coords_grid[:, 0] 30 | coords_grid = np.copy(temp) 31 | 32 | return coords_grid 33 | 34 | 35 | def draw( 36 | voxels, 37 | T_velo_2_cam, 38 | vox_origin, 39 | fov_mask, 40 | img_size, 41 | f, 42 | voxel_size=0.2, 43 | d=7, # 7m - determine the size of the mesh representing the camera 44 | ): 45 | # Compute the coordinates of the mesh representing camera 46 | x = d * img_size[0] / (2 * f) 47 | y = d * img_size[1] / (2 * f) 48 | tri_points = np.array( 49 | [ 50 | [0, 0, 0], 51 | [x, y, d], 52 | [-x, y, d], 53 | [-x, -y, d], 54 | [x, -y, d], 55 | ] 56 | ) 57 | tri_points = np.hstack([tri_points, np.ones((5, 1))]) 58 | tri_points = (np.linalg.inv(T_velo_2_cam) @ tri_points.T).T 59 | x = tri_points[:, 0] - vox_origin[0] 60 | y = tri_points[:, 1] - vox_origin[1] 61 | z = tri_points[:, 2] - vox_origin[2] 62 | triangles = [ 63 | (0, 1, 2), 64 | (0, 1, 4), 65 | (0, 3, 4), 66 | (0, 2, 3), 67 | ] 68 | 69 | # Compute the voxels coordinates 70 | grid_coords = get_grid_coords( 71 | [voxels.shape[0], voxels.shape[1], voxels.shape[2]], voxel_size 72 | ) 73 | 74 | # Attach the predicted class to every voxel 75 | grid_coords = np.vstack([grid_coords.T, voxels.reshape(-1)]).T 76 | 77 | # Get the voxels inside FOV 78 | fov_grid_coords = grid_coords[fov_mask, :] 79 | 80 | # Get the voxels outside FOV 81 | outfov_grid_coords = grid_coords[~fov_mask, :] 82 | 83 | # Remove empty and unknown voxels 84 | fov_voxels = fov_grid_coords[ 85 | (fov_grid_coords[:, 3] > 0) & (fov_grid_coords[:, 3] < 255) 86 | ] 87 | outfov_voxels = outfov_grid_coords[ 88 | (outfov_grid_coords[:, 3] > 0) & (outfov_grid_coords[:, 3] < 255) 89 | ] 90 | 91 | figure = mlab.figure(size=(1400, 1400), bgcolor=(1, 1, 1)) 92 | 93 | # Draw the camera 94 | mlab.triangular_mesh( 95 | x, y, z, triangles, 
representation="wireframe", color=(0, 0, 0), line_width=5 96 | ) 97 | 98 | # Draw occupied inside FOV voxels 99 | plt_plot_fov = mlab.points3d( 100 | fov_voxels[:, 0], 101 | fov_voxels[:, 1], 102 | fov_voxels[:, 2], 103 | fov_voxels[:, 3], 104 | colormap="viridis", 105 | scale_factor=voxel_size - 0.05 * voxel_size, 106 | mode="cube", 107 | opacity=1.0, 108 | vmin=1, 109 | vmax=19, 110 | ) 111 | 112 | # Draw occupied outside FOV voxels 113 | plt_plot_outfov = mlab.points3d( 114 | outfov_voxels[:, 0], 115 | outfov_voxels[:, 1], 116 | outfov_voxels[:, 2], 117 | outfov_voxels[:, 3], 118 | colormap="viridis", 119 | scale_factor=voxel_size - 0.05 * voxel_size, 120 | mode="cube", 121 | opacity=1.0, 122 | vmin=1, 123 | vmax=19, 124 | ) 125 | 126 | colors = np.array( 127 | [ 128 | [100, 150, 245, 255], 129 | [100, 230, 245, 255], 130 | [30, 60, 150, 255], 131 | [80, 30, 180, 255], 132 | [100, 80, 250, 255], 133 | [255, 30, 30, 255], 134 | [255, 40, 200, 255], 135 | [150, 30, 90, 255], 136 | [255, 0, 255, 255], 137 | [255, 150, 255, 255], 138 | [75, 0, 75, 255], 139 | [175, 0, 75, 255], 140 | [255, 200, 0, 255], 141 | [255, 120, 50, 255], 142 | [0, 175, 0, 255], 143 | [135, 60, 0, 255], 144 | [150, 240, 80, 255], 145 | [255, 240, 150, 255], 146 | [255, 0, 0, 255], 147 | ] 148 | ).astype(np.uint8) 149 | 150 | plt_plot_fov.glyph.scale_mode = "scale_by_vector" 151 | plt_plot_outfov.glyph.scale_mode = "scale_by_vector" 152 | 153 | plt_plot_fov.module_manager.scalar_lut_manager.lut.table = colors 154 | 155 | outfov_colors = colors 156 | outfov_colors[:, :3] = outfov_colors[:, :3] // 3 * 2 157 | plt_plot_outfov.module_manager.scalar_lut_manager.lut.table = outfov_colors 158 | 159 | mlab.show() 160 | 161 | 162 | @hydra.main(config_path=None) 163 | def main(config: DictConfig): 164 | scan = config.file 165 | with open(scan, "rb") as handle: 166 | b = pickle.load(handle) 167 | 168 | fov_mask_1 = b["fov_mask_1"] 169 | T_velo_2_cam = b["T_velo_2_cam"] 170 | vox_origin = np.array([0, -25.6, -2]) 171 | 172 | y_pred = b["y_pred"] 173 | 174 | if config.dataset == "kitti_360": 175 | # Visualize KITTI-360 176 | draw( 177 | y_pred, 178 | T_velo_2_cam, 179 | vox_origin, 180 | fov_mask_1, 181 | voxel_size=0.2, 182 | f=552.55426, 183 | img_size=(1408, 376), 184 | d=7, 185 | ) 186 | else: 187 | # Visualize Semantic KITTI 188 | draw( 189 | y_pred, 190 | T_velo_2_cam, 191 | vox_origin, 192 | fov_mask_1, 193 | img_size=(1220, 370), 194 | f=707.0912, 195 | voxel_size=0.2, 196 | d=7, 197 | ) 198 | 199 | 200 | if __name__ == "__main__": 201 | main() 202 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-image==0.18.1 2 | PyYAML==5.4 3 | tqdm==4.57.0 4 | scikit-learn==0.24.0 5 | pytorch-lightning==2.0.0 6 | opencv-python==4.5.1.48 7 | hydra-core==1.0.5 8 | numpy==1.20.3 9 | numba==0.53.0 10 | imageio==2.34.1 11 | protobuf==3.19.6 12 | tensorboard -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from setuptools import find_packages 3 | 4 | # for install, do: pip install -ve . 
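# (`-e` installs the package in editable/development mode, so the `iso` modules are
# imported straight from this source tree; `-v` just makes pip's output verbose.)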
5 | 6 | setup(name='iso', packages=find_packages()) 7 | -------------------------------------------------------------------------------- /trained_models/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hongxiaoy/ISO/12d30d244479d52fe64bdc6402bf3dfb4d43503b/trained_models/.placeholder --------------------------------------------------------------------------------
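The visualization scripts above (NYU_vis_pred.py, NYU_vis_pred_2.py and kitti_vis_pred.py) call coords_grid.astype(np.float). On the pinned numpy==1.20.3 this only emits a DeprecationWarning, but the np.float alias was removed in NumPy 1.24, so the scripts fail on newer environments. Below is a minimal sketch of an alias-free get_grid_coords with the same dims/resolution semantics as the versions above; the 60 x 60 x 36 grid and the `centers` variable in the usage lines are illustrative only, not values taken from the repository.

import numpy as np


def get_grid_coords(dims, resolution):
    """Return the metric center coordinate of every voxel in a dims[0] x dims[1] x dims[2] grid."""
    # One integer index per voxel along each axis.
    g_xx = np.arange(0, dims[0])
    g_yy = np.arange(0, dims[1])
    g_zz = np.arange(0, dims[2])
    xx, yy, zz = np.meshgrid(g_xx, g_yy, g_zz)
    coords_grid = np.array([xx.flatten(), yy.flatten(), zz.flatten()]).T
    coords_grid = coords_grid.astype(np.float64)  # np.float64 instead of the removed np.float alias

    # Scale the indices to metric units and shift to the voxel centers.
    coords_grid = coords_grid * resolution + resolution / 2

    # Exchange the first two coordinate columns, as the original scripts do.
    coords_grid[:, [0, 1]] = coords_grid[:, [1, 0]]
    return coords_grid


# Usage (illustrative numbers only): a 60 x 60 x 36 grid of 0.08 m voxels.
centers = get_grid_coords([60, 60, 36], 0.08)
print(centers.shape)  # (129600, 3)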