├── .vscode └── settings.json ├── LICENSE ├── README.md ├── assets ├── Teaser.pdf └── Teaser.png ├── checkpoints ├── config ├── config.yaml ├── data │ ├── base.yaml │ ├── carla_novel.yaml │ ├── carla_original.yaml │ ├── carla_patch.yaml │ ├── carla_seg_patch.yaml │ ├── mixture.yaml │ ├── scannet.yaml │ ├── scenenn.yaml │ └── synthetic.yaml └── model │ ├── base.yaml │ ├── carla_model.yaml │ ├── scannet_model.yaml │ ├── scenenn_model.yaml │ └── synthetic_model.yaml ├── data ├── eval.log ├── eval.py ├── noksr ├── callback │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ └── gpu_cache_clean_callback.cpython-310.pyc │ └── gpu_cache_clean_callback.py ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ └── data_module.cpython-310.pyc │ ├── data_module.py │ └── dataset │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── augmentation.cpython-310.pyc │ │ ├── augmentation.cpython-38.pyc │ │ ├── carla.cpython-310.pyc │ │ ├── carla_gt_geometry.cpython-310.pyc │ │ ├── general_dataset.cpython-310.pyc │ │ ├── general_dataset.cpython-38.pyc │ │ ├── mixture.cpython-310.pyc │ │ ├── multiscan.cpython-38.pyc │ │ ├── scannet.cpython-310.pyc │ │ ├── scannet.cpython-38.pyc │ │ ├── scannet_rangeudf.cpython-310.pyc │ │ ├── scannet_rangeudf.cpython-38.pyc │ │ ├── scenenet.cpython-310.pyc │ │ ├── scenenn.cpython-310.pyc │ │ ├── shapenet.cpython-310.pyc │ │ ├── synthetic.cpython-310.pyc │ │ ├── voxelization_utils.cpython-310.pyc │ │ ├── voxelization_utils.cpython-38.pyc │ │ ├── voxelizer.cpython-310.pyc │ │ └── voxelizer.cpython-38.pyc │ │ ├── augmentation.py │ │ ├── carla.py │ │ ├── carla_gt_geometry.py │ │ ├── general_dataset.py │ │ ├── mixture.py │ │ ├── scannet.py │ │ ├── scenenn.py │ │ └── synthetic.py ├── model │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── general_model.cpython-310.pyc │ │ ├── noksr_net.cpython-310.pyc │ │ └── pcs4esr_net.cpython-310.pyc │ ├── general_model.py │ ├── module │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── backbone.cpython-310.pyc │ │ │ ├── backbone.cpython-38.pyc │ │ │ ├── backbone_nocs.cpython-38.pyc │ │ │ ├── common.cpython-310.pyc │ │ │ ├── common.cpython-38.pyc │ │ │ ├── decoder.cpython-310.pyc │ │ │ ├── decoder.cpython-38.pyc │ │ │ ├── decoder2.cpython-38.pyc │ │ │ ├── encoder.cpython-310.pyc │ │ │ ├── encoder.cpython-38.pyc │ │ │ ├── generation.cpython-310.pyc │ │ │ ├── generation.cpython-38.pyc │ │ │ ├── kp_decoder.cpython-310.pyc │ │ │ ├── larger_decoder.cpython-310.pyc │ │ │ ├── point_transformer.cpython-310.pyc │ │ │ ├── tiny_unet.cpython-38.pyc │ │ │ ├── visualization.cpython-310.pyc │ │ │ └── visualization.cpython-38.pyc │ │ ├── decoder.py │ │ ├── generation.py │ │ └── point_transformer.py │ └── noksr_net.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── evaluation.cpython-310.pyc │ ├── optimizer.cpython-310.pyc │ ├── samples.cpython-310.pyc │ ├── segmentation.cpython-310.pyc │ └── transform.cpython-310.pyc │ ├── evaluation.py │ ├── optimizer.py │ ├── samples.py │ ├── segmentation.py │ ├── serialization │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── default.cpython-310.pyc │ │ ├── hilbert.cpython-310.pyc │ │ └── z_order.cpython-310.pyc │ ├── default.py │ ├── hilbert.py │ └── z_order.py │ └── transform.py ├── requirements.txt ├── scripts └── segment_carla.py ├── train.log ├── train.py └── 
wandb ├── debug-cli.zla247.log └── settings /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "workbench.colorCustomizations": { 3 | "activityBar.background": "#0B2379", 4 | "titleBar.activeBackground": "#1031AA", 5 | "titleBar.activeForeground": "#FBFCFF" 6 | } 7 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NoKSR: Kernel-Free Neural Surface Reconstruction via Point Cloud Serialization 2 | 3 | ![noksr](assets/Teaser.png) 4 | 5 | **NoKSR: Kernel-Free Neural Surface Reconstruction via Point Cloud Serialization**
6 | [Zhen Li *](https://colinzhenli.github.io/), [Weiwei Sun *†](https://wsunid.github.io/), [Shrisudhan Govindarajan](https://shrisudhan.github.io/), [Shaobo Xia](https://scholar.google.com/citations?user=eOPO9E0AAAAJ&hl=en), [Daniel Rebain](http://drebain.com/), [Kwang Moo Yi](https://www.cs.ubc.ca/~kmyi/), [Andrea Tagliasacchi](https://theialab.ca/) 7 | **[Paper](https://arxiv.org/abs/2502.12534), [Project Page](https://theialab.github.io/noksr/)** 8 | 9 | Abstract: We present a novel approach to large-scale point cloud 10 | surface reconstruction by developing an efficient framework 11 | that converts an irregular point cloud into a signed distance 12 | field (SDF). Our backbone builds upon recent transformer-based 13 | architectures (i.e. PointTransformerV3) that serialize 14 | the point cloud into a locality-preserving sequence of 15 | tokens. We efficiently predict the SDF value at a point by 16 | aggregating nearby tokens, where fast approximate neighbors 17 | can be retrieved thanks to the serialization. We serialize 18 | the point cloud at different levels/scales, and non-linearly 19 | aggregate a feature to predict the SDF value. We show 20 | that aggregating across multiple scales is critical to 21 | overcome the approximations introduced by the serialization 22 | (i.e. false negatives in the neighborhood). Our framework 23 | sets a new state of the art in terms of accuracy and 24 | efficiency (better or similar performance with half the latency 25 | of the best prior method, coupled with a simpler 26 | implementation), particularly on outdoor datasets where sparse-grid 27 | methods have shown limited performance. 28 | 29 | Contact [Zhen Li @ SFU](mailto:zla247@sfu.ca) for questions, comments, and bug reports. 30 | 31 | ## Related Package 32 | 33 | We implemented the fast approximate neighbor search algorithm in the package [`serial-neighbor`](https://pypi.org/project/serial-neighbor/) — a standalone pip package that provides fast and flexible point cloud neighbor search using serialization encoding by space-filling curves (Z-order, Hilbert, etc.). 34 | 35 | - **📦 PyPI**: [`serial-neighbor`](https://pypi.org/project/serial-neighbor/) 36 | - **🔗 GitHub**: [https://github.com/colinzhenli/serial-neighbor](https://github.com/colinzhenli/serial-neighbor) 37 | 38 | You can install it via: 39 | 40 | ```bash 41 | pip install serial-neighbor 42 | ``` 43 | 44 | ## News 45 | 46 | - [2025/03/22] The package [`serial-neighbor`](https://pypi.org/project/serial-neighbor/) is released. 47 | - [2025/02/21] The code is released. 48 | - [2025/02/19] The arXiv version is released. 49 | 50 | ## Environment setup 51 | 52 | The code is tested on Ubuntu 20.04 LTS with PyTorch 2.0.0 and CUDA 11.8 installed. Please follow the steps below to install PyTorch first.
53 | 54 | ``` 55 | # Clone the repository 56 | git clone https://github.com/theialab/noksr.git 57 | cd noksr 58 | 59 | # create and activate the conda environment 60 | conda create -n noksr python=3.10 61 | conda activate noksr 62 | 63 | # install PyTorch 2.x.x 64 | conda install pytorch==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia 65 | 66 | ``` 67 | Then, install PyTorch3D: 68 | ``` 69 | # install runtime dependencies for PyTorch3D 70 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath 71 | conda install -c bottler nvidiacub 72 | 73 | # install PyTorch3D 74 | conda install pytorch3d -c pytorch3d 75 | ``` 76 | 77 | Install the necessary packages listed in `requirements.txt`: 78 | ``` 79 | pip install -r requirements.txt 80 | ``` 81 | 82 | Install torch-scatter and nksr: 83 | ``` 84 | pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.0+cu118.html 85 | pip install nksr -f https://nksr.huangjh.tech/whl/torch-2.0.0+cu118.html 86 | ``` 87 | 88 | Detailed installation instructions for nksr can be found in the [NKSR repository](https://github.com/nv-tlabs/nksr). 89 | 90 | ## Reproducing results from the paper 91 | 92 | ### Data Preparation 93 | 94 | You can download the data from the following links and put it under `NoKSR/data/`. 95 | - ScanNet: 96 | Data is available [here](https://drive.google.com/drive/folders/1JK_6T61eQ07_y1bi1DD9Xj-XRU0EDKGS?usp=sharing). 97 | We converted the original meshes to `.pth` data, and the normals are generated using [open3d.geometry.TriangleMesh](https://www.open3d.org/html/python_api/open3d.geometry.TriangleMesh.html). The processing of the raw scannetv2 data follows [minsu3d](https://github.com/3dlg-hcvc/minsu3d). 98 | 99 | - SceneNN: 100 | Data is available [here](https://drive.google.com/file/d/1d_ILfaxpJBpiiwCZtvC4jEKnixEr9N2l/view?usp=sharing). 101 | 102 | - SyntheticRoom: 103 | Data is available [here](https://drive.google.com/drive/folders/1PosV8qyXCkjIHzVjPeOIdhCLigpXXDku?usp=sharing). It is from [ConvONet](https://github.com/autonomousvision/convolutional_occupancy_networks), which contains the processing details. 104 | 105 | - CARLA: 106 | Data is available [here](https://drive.google.com/file/d/1BFwExw7SRJaqHJ98pqqnR-k6g8XYMAqq/view?usp=sharing). It is from [NKSR](https://github.com/nv-tlabs/nksr). 107 | 108 | 109 | ### Training 110 | Note: Configuration files are managed by [Hydra](https://hydra.cc/); you can easily add or override any configuration attribute by passing it as a command-line argument. 111 | ```shell 112 | # log in to WandB 113 | wandb login 114 | 115 | # train a model from scratch 116 | # ScanNet dataset 117 | python train.py model=scannet_model data=scannet 118 | # SyntheticRoom dataset 119 | python train.py model=synthetic_model data=synthetic 120 | # CARLA dataset 121 | python train.py model=carla_model data=carla_patch 122 | ``` 123 | 124 | In addition, you can manually specify different training settings. Common flags include: 125 | - `experiment_name`: Additional experiment name to specify. 126 | - `data.dataset_root_path`: Root path of the dataset. 127 | - `output_folder`: Output folder for the results; checkpoints will be saved in `output/{dataset_name}/{experiment_name}/training`. 128 | - `model.network.default_decoder.neighboring`: Neighboring type; default is `Serial`. Options: `Serial`, `KNN`, `Mixture`. 129 | 130 | ### Inference 131 | 132 | You can run inference using either your own trained models or our pre-trained checkpoints.
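If you evaluate your own trained model, the checkpoint path follows the training output layout described above (`output/{dataset_name}/{experiment_name}/training`, with the dataset name coming from the data config, e.g. `Scannet`). Below is a minimal sketch assuming the default `experiment_name=run_1`; the file name `epoch=99.ckpt` is a placeholder, so substitute the checkpoint actually produced by your run:

```bash
# Evaluate a self-trained ScanNet checkpoint (path and epoch number are placeholders)
python eval.py model=scannet_model data=scannet \
    model.ckpt_path=output/Scannet/run_1/training/epoch=99.ckpt \
    model.inference.split=val
```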
133 | 134 | The pre-trained checkpoints on different datasets with different neighboring types are available [here](https://drive.google.com/file/d/1hMm5cnCOfNmr_PgkpOmwRnzCCG4wPqnu/view?usp=drive_link). You can download them and put them under `noksr/checkpoints/`. 135 | 136 | ```bash 137 | # For example, the CARLA original dataset with Serialization neighboring. You need more than 24GB of GPU memory to run inference on the CARLA dataset; we recommend using a server. 138 | python eval.py model=carla_model data=carla_original model.ckpt_path=checkpoints/Carla_Serial_best.ckpt 139 | # For example, the CARLA model trained with the Laplacian loss 140 | python eval.py model=carla_model data=carla_original model.ckpt_path=checkpoints/Carla_Laplacian_best.ckpt 141 | # For example, the ScanNet dataset with Serialization neighboring 142 | python eval.py model=scannet_model data=scannet model.ckpt_path=checkpoints/ScanNet_Serial_best.ckpt model.inference.split=val 143 | # For example, test on the SceneNN dataset with a model trained on ScanNet 144 | python eval.py model=scenenn_model data=scenenn model.ckpt_path=checkpoints/ScanNet_KNN_best.ckpt 145 | ``` 146 | In addition, for the CARLA dataset, you can enable reconstruction by segments. This option is slower but saves a lot of memory. Flags include: 147 | - `data.reconstruction.by_segment=True`: Enable reconstruction by segments. 148 | - `data.reconstruction.segment_num=10`: Number of segments. 149 | 150 | ### Reconstruction 151 | You can reconstruct a specific scene from the datasets above by specifying the scene index. 152 | ```bash 153 | # For example, the CARLA dataset; 0 can be replaced by any other scene index of the validation set 154 | python eval.py model=carla_model data=carla_original model.ckpt_path={path_to_checkpoint} data.over_fitting=True data.take=1 data.intake_start=0 155 | # For example, the ScanNet dataset; 308 can be replaced by any other scene index of the validation set 156 | python eval.py model=scannet_model data=scannet model.ckpt_path={path_to_checkpoint} data.over_fitting=True data.take=1 data.intake_start=308 model.inference.split=val 157 | 158 | ``` 159 | In addition, you can manually specify visualization settings. Flags include: 160 | - `data.visualization.save=True`: Whether to save the results. 161 | - `data.visualization.Mesh=True`: Whether to save the reconstructed mesh. 162 | - `data.visualization.Input_points=True`: Whether to save the input points. 163 | 164 | The results will be saved in `output/{dataset_name}/{experiment_name}/reconstruction/visualization`. 165 | 166 | ## Acknowledgement 167 | This work was supported in part by the Natural Sciences and Engineering Research Council of Canada (NSERC) Discovery Grant, NSERC Collaborative Research and Development Grant, Google DeepMind, Digital Research Alliance of Canada, the Advanced Research Computing at the University of British Columbia, and the SFU Visual Computing Research Chair program. Shaobo Xia was supported 168 | by the National Natural Science Foundation of China under Grant 42201481. We would also like to thank Jiahui Huang for valuable discussions and feedback.
169 | 170 | ## BibTeX 171 | If you find our work useful in your research, please consider citing: 172 | ```bibtex 173 | @inproceedings{li2025noksrkernelfreeneuralsurface, 174 | author = {Zhen Li and Weiwei Sun and Shrisudhan Govindarajan and Shaobo Xia and Daniel Rebain and Kwang Moo Yi and Andrea Tagliasacchi}, 175 | title = {NoKSR: Kernel-Free Neural Surface Reconstruction via Point Cloud Serialization}, 176 | year = {2025}, 177 | booktitle = {International Conference on 3D Vision (3DV)}, 178 | } 179 | ``` 180 | -------------------------------------------------------------------------------- /assets/Teaser.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/assets/Teaser.pdf -------------------------------------------------------------------------------- /assets/Teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/assets/Teaser.png -------------------------------------------------------------------------------- /checkpoints: -------------------------------------------------------------------------------- 1 | /localhome/zla247/theia2_data/output/checkpoints -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | # Managed by Hydra 2 | 3 | hydra: 4 | output_subdir: null 5 | run: 6 | dir: . 7 | 8 | defaults: 9 | - _self_ 10 | - data: base 11 | - model: base 12 | 13 | experiment_name: run_1 14 | output_folder: output 15 | exp_output_root_path: ${output_folder}/${data.dataset}/${experiment_name} 16 | 17 | global_train_seed: 12345 18 | global_test_seed: 32102 -------------------------------------------------------------------------------- /config/data/base.yaml: -------------------------------------------------------------------------------- 1 | # Managed by Hydra 2 | 3 | dataset_root_path: data 4 | 5 | batch_size: 1 6 | num_workers: 4 7 | voxel_size: 0.02 # only used to compute neighboring radius 8 | 9 | supervision: 10 | only_l1_sdf_loss: false 11 | structure_weight: 20.0 12 | 13 | gt_type: "PointTSDFVolume" 14 | 15 | on_surface: 16 | normal_loss: False # supervise normals with an nksr-style loss 17 | weight: 200.0 18 | normal_weight: 100.0 19 | subsample: 10000 20 | svh_tree_depth: 4 21 | 22 | sdf: 23 | max_dist: 0.2 24 | weight: 300.0 25 | reg_sdf_weight: 0.0 26 | svh_tree_depth: 4 27 | samplers: 28 | - type: "uniform" 29 | n_samples: 10000 30 | expand: 1 31 | expand_top: 3 32 | - type: "band" 33 | n_samples: 10000 34 | eps: 0.5 # Times voxel size. 35 | 36 | truncate: False # whether to truncate the SDF values 37 | gt_type: "l1" # or 'l1' 38 | gt_soft: true 39 | gt_band: 1.0 # times voxel size. 40 | pd_transform: true 41 | # (For AV Supervision) 42 | vol_sup: true 43 | 44 | udf: 45 | max_dist: 0.2 46 | abs_sdf: True 47 | weight: 150.0 48 | svh_tree_depth: 4 49 | samplers: 50 | - type: "uniform" 51 | n_samples: 10000 52 | expand: 1 53 | expand_top: 5 54 | - type: "band" 55 | n_samples: 10000 56 | eps: 0.5 # Times voxel size.
57 | 58 | eikonal: 59 | loss: False 60 | flip: True # flip the gradient sign 61 | weight: 10.0 62 | 63 | laplacian: 64 | loss: False 65 | weight: 0.0 66 | 67 | num_input_points: 10000 68 | uniform_sampling: True 69 | input_splats: False 70 | num_query_points: 50000 71 | std_dev: 0.00 # noise to add to the input points 72 | in_memory: True # whether to precompute voxel indices and load voxels into memory 73 | 74 | take: -1 # how many samples to take for training and validation; -1 means all 75 | intake_start: 0 76 | over_fitting: False # whether to use only one voxel for training and validation 77 | 78 | reconstruction: 79 | by_segment: False 80 | segment_num: 10 81 | trim: True 82 | gt_mask: False 83 | gt_sdf: False 84 | 85 | visualization: 86 | save: False 87 | Mesh: True 88 | Input_points: True 89 | Dense_points: False 90 | 91 | 92 | -------------------------------------------------------------------------------- /config/data/carla_novel.yaml: -------------------------------------------------------------------------------- 1 | # Managed by Hydra 2 | 3 | defaults: 4 | - base 5 | 6 | dataset: Carla 7 | 8 | base_path: ${data.dataset_root_path}/carla-lidar/dataset-no-patch 9 | input_path: ${data.dataset_root_path}/carla-lidar/dataset-p1n2-no-patch 10 | 11 | drives: ['Town03-0', 'Town03-1', 'Town03-2'] 12 | 13 | voxel_size: 0.1 14 | 15 | transforms: [] 16 | 17 | supervision: 18 | sdf: 19 | max_dist: 0.6 20 | 21 | udf: 22 | max_dist: 0.6 23 | 24 | reconstruction: 25 | mask_threshold: 0.1 26 | 27 | evaluation: 28 | evaluator: "MeshEvaluator" # align evaluation with NKSR -------------------------------------------------------------------------------- /config/data/carla_original.yaml: -------------------------------------------------------------------------------- 1 | # Managed by Hydra 2 | 3 | defaults: 4 | - base 5 | 6 | dataset: Carla 7 | 8 | base_path: ${data.dataset_root_path}/carla-lidar/dataset-no-patch 9 | input_path: ${data.dataset_root_path}/carla-lidar/dataset-p1n2-no-patch 10 | 11 | drives: ['Town01-0', 'Town01-1', 'Town01-2', 12 | 'Town02-0', 'Town02-1', 'Town02-2', 13 | 'Town10-0', 'Town10-1', 'Town10-2', 'Town10-3', 'Town10-4'] 14 | voxel_size: 0.1 15 | 16 | transforms: [] 17 | 18 | supervision: 19 | sdf: 20 | max_dist: 0.6 21 | 22 | udf: 23 | max_dist: 0.6 24 | 25 | reconstruction: 26 | mask_threshold: 0.1 27 | 28 | evaluation: 29 | evaluator: "MeshEvaluator" # align evaluation with NKSR -------------------------------------------------------------------------------- /config/data/carla_patch.yaml: -------------------------------------------------------------------------------- 1 | # Managed by Hydra 2 | 3 | defaults: 4 | - base 5 | 6 | dataset: Carla 7 | 8 | base_path: ${data.dataset_root_path}/carla-lidar/dataset 9 | input_path: ${data.dataset_root_path}/carla-lidar/dataset-p1n2 10 | 11 | drives: ['Town01-0', 'Town01-1', 'Town01-2', 12 | 'Town02-0', 'Town02-1', 'Town02-2', 13 | 'Town10-0', 'Town10-1', 'Town10-2', 'Town10-3', 'Town10-4'] 14 | 15 | voxel_size: 0.1 16 | 17 | transforms: [] 18 | 19 | supervision: 20 | sdf: 21 | max_dist: 0.6 22 | 23 | udf: 24 | max_dist: 0.6 25 | 26 | reconstruction: 27 | mask_threshold: 0.1 28 | 29 | evaluation: 30 | evaluator: "MeshEvaluator" # align evaluation with NKSR -------------------------------------------------------------------------------- /config/data/carla_seg_patch.yaml: -------------------------------------------------------------------------------- 1 | # Managed by Hydra 2 | 3 | defaults: 4 | - base 5 | 6 | dataset: Carla 7 | 8 |
base_path: ${data.dataset_root_path}/carla-lidar/dataset-seg-patch 9 | input_path: 10 | 11 | drives: ['Town01-0', 'Town01-1', 'Town01-2', 12 | 'Town02-0', 'Town02-1', 'Town02-2', 13 | 'Town10-0', 'Town10-1', 'Town10-2', 'Town10-3', 'Town10-4'] 14 | 15 | voxel_size: 0.1 16 | 17 | transforms: [] 18 | 19 | supervision: 20 | sdf: 21 | max_dist: 0.6 22 | 23 | udf: 24 | max_dist: 0.6 25 | 26 | reconstruction: 27 | mask_threshold: 0.1 28 | 29 | evaluation: 30 | evaluator: "MeshEvaluator" # align evaluation with NKSR -------------------------------------------------------------------------------- /config/data/mixture.yaml: -------------------------------------------------------------------------------- 1 | # Managed by Hydra 2 | 3 | defaults: 4 | - base 5 | 6 | dataset: Mixture 7 | 8 | validation_set: Synthetic # ScanNet, Synthetic, or both 9 | metadata: # only for ScanNet 10 | metadata_path: ${data.dataset_root_path}/scannetv2/metadata 11 | train_list: ${data.dataset_root_path}/scannetv2/metadata/scannetv2_train.txt 12 | val_list: ${data.dataset_root_path}/scannetv2/metadata/scannetv2_val.txt 13 | test_list: ${data.dataset_root_path}/scannetv2/metadata/scannetv2_test.txt 14 | combine_file: ${data.dataset_root_path}/scannetv2/metadata/scannetv2-labels.combined.tsv 15 | 16 | ScanNet: 17 | dataset_path: ${data.dataset_root_path}/scannetv2 18 | classes: 2 19 | # ignore_classes: [ 1, 2 ] 20 | class_names: [ 'floor', 'wall', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 21 | 'counter', 'desk', 'curtain', 'refrigerator', 'shower curtain', 'toilet', 'sink', 22 | 'bathtub', 'otherfurniture' ] 23 | 24 | mapping_classes_ids: [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39 ] 25 | 26 | Synthetic: 27 | path: ${data.dataset_root_path}/synthetic_data/synthetic_room_dataset 28 | 29 | input_type: pointcloud_crop 30 | classes: ['rooms_04', 'rooms_05', 'rooms_06', 'rooms_07', 'rooms_08'] 31 | pointcloud_n: 10000 32 | std_dev: 0.00 # 0.005 33 | # points_subsample: 1024 34 | points_file: points_iou 35 | points_iou_file: points_iou 36 | pointcloud_file: pointcloud 37 | pointcloud_chamfer_file: pointcloud 38 | voxels_file: null 39 | multi_files: 10 40 | unit_size: 0.005 # size of a voxel (not used) 41 | query_vol_size: 25 42 | -------------------------------------------------------------------------------- /config/data/scannet.yaml: -------------------------------------------------------------------------------- 1 | # Managed by Hydra 2 | 3 | defaults: 4 | - base 5 | 6 | dataset: Scannet 7 | dataset_path: ${data.dataset_root_path}/scannetv2 8 | 9 | metadata: 10 | metadata_path: ${data.dataset_root_path}/scannetv2/metadata 11 | train_list: ${data.dataset_root_path}/scannetv2/metadata/scannetv2_train.txt 12 | val_list: ${data.dataset_root_path}/scannetv2/metadata/scannetv2_val.txt 13 | test_list: ${data.dataset_root_path}/scannetv2/metadata/scannetv2_test.txt 14 | combine_file: ${data.dataset_root_path}/scannetv2/metadata/scannetv2-labels.combined.tsv 15 | 16 | supervision: 17 | sdf: 18 | max_dist: 0.2 19 | udf: 20 | max_dist: 0.2 21 | 22 | reconstruction: 23 | mask_threshold: 0.015 24 | 25 | evaluation: 26 | evaluator: "UnitMeshEvaluator" 27 | -------------------------------------------------------------------------------- /config/data/scenenn.yaml: -------------------------------------------------------------------------------- 1 | # Managed by Hydra 2 | 3 | defaults: 4 | - base 5 | 6 | dataset: SceneNN 7 | dataset_path:
${data.dataset_root_path}/scenenn_seg_76_raw/scenenn_sub_data 8 | train_files: ['005', '014', '015', '016', '025', '036', '038', '041', '045', 9 | '047', '052', '054', '057', '061', '062', '066', '071', '073', '078', '080', 10 | '084', '087', '089', '096', '098', '109', '201', '202', '209', '217', '223', 11 | '225', '227', '231', '234', '237', '240', '243', '249', '251', '255', '260', 12 | '263', '265', '270', '276', '279', '286', '294', '308', '522', '609', '613', 13 | '614', '623', '700'] 14 | test_files: ['011', '021', '065', '032', '093', '246', '086', '069', '206', 15 | '252', '273', '527', '621', '076', '082', '049', '207', '213', '272', '074'] 16 | 17 | supervision: 18 | sdf: 19 | max_dist: 0.2 20 | udf: 21 | max_dist: 0.2 22 | 23 | reconstruction: 24 | mask_threshold: 0.015 25 | 26 | evaluation: 27 | evaluator: "UnitMeshEvaluator" -------------------------------------------------------------------------------- /config/data/synthetic.yaml: -------------------------------------------------------------------------------- 1 | # Managed by Hydra 2 | 3 | defaults: 4 | - base 5 | 6 | dataset: Synthetic 7 | path: ${data.dataset_root_path}/synthetic_data/synthetic_room_dataset 8 | 9 | input_type: pointcloud_crop 10 | classes: ['rooms_04', 'rooms_05', 'rooms_06', 'rooms_07', 'rooms_08'] 11 | pointcloud_n: 10000 12 | # std_dev: 0.00 # 0.005 13 | # points_subsample: 1024 14 | points_file: points_iou 15 | points_iou_file: points_iou 16 | pointcloud_file: pointcloud 17 | pointcloud_chamfer_file: pointcloud 18 | voxels_file: null 19 | multi_files: 10 20 | unit_size: 0.005 # size of a voxel (not used) 21 | query_vol_size: 25 22 | 23 | supervision: 24 | sdf: 25 | max_dist: 0.2 26 | 27 | udf: 28 | max_dist: 0.2 29 | 30 | reconstruction: 31 | mask_threshold: 0.01 32 | 33 | evaluation: 34 | evaluator: "UnitMeshEvaluator" 35 | -------------------------------------------------------------------------------- /config/model/base.yaml: -------------------------------------------------------------------------------- 1 | # Managed by Hydra 2 | 3 | ckpt_path: 4 | 5 | logger: 6 | # https://pytorch-lightning.readthedocs.io/en/stable/extensions/generated/pytorch_lightning.loggers.WandbLogger.html 7 | _target_: pytorch_lightning.loggers.WandbLogger 8 | project: noksr 9 | name: ${experiment_name} 10 | 11 | # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html 12 | trainer: 13 | accelerator: gpu #cpu or gpu 14 | devices: auto 15 | num_nodes: 1 16 | max_epochs: 800 17 | max_steps: 200000 18 | num_sanity_val_steps: 8 19 | check_val_every_n_epoch: 4 20 | profiler: simple 21 | 22 | # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html 23 | checkpoint_monitor: 24 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 25 | save_top_k: -1 26 | every_n_epochs: 4 27 | filename: "{epoch}" 28 | dirpath: ${exp_output_root_path}/training 29 | 30 | 31 | optimizer: 32 | name: Adam # SGD or Adam 33 | lr: 0.001 34 | warmup_steps_ratio: 0.1 35 | 36 | lr_decay: # for Adam 37 | decay_start_epoch: 250 38 | 39 | 40 | inference: 41 | split: test -------------------------------------------------------------------------------- /config/model/carla_model.yaml: -------------------------------------------------------------------------------- 1 | 2 | # Managed by Hydra 3 | 4 | defaults: 5 | - base 6 | 7 | trainer: 8 | max_epochs: 100 9 | lr_decay: # for Adam 10 | decay_start_epoch: 60 11 | 12 | network: 13 | module: noksr 14 | use_color: False 15 | use_normal: True # True 
for SDF supervision 16 | use_xyz: True # True only for PointTransformerV3 debugging 17 | 18 | latent_dim: 32 # 19 | prepare_epochs: 350 20 | eval_algorithm: DMC # DMC or DensePC or MC 21 | grad_type: Analytical # Analytical or Numerical 22 | backbone: PointTransformerV3 # PointTransformerV3 or MinkUNet 23 | 24 | point_transformerv3: # default 25 | in_channels: 6 26 | order: ["z", "z-trans", "hilbert", "hilbert-trans"] 27 | stride: [8, 8, 4, 2] 28 | enc_depths: [2, 2, 2, 6, 2] 29 | enc_channels: [32, 64, 128, 256, 512] 30 | enc_num_head: [2, 4, 8, 16, 32] 31 | enc_patch_size: [64, 64, 64, 64, 64] 32 | 33 | 34 | dec_depths: [2, 2, 2, 2] 35 | dec_channels: [64, 64, 128, 256] 36 | dec_num_head: [4, 4, 8, 16] 37 | dec_patch_size: [64, 64, 64, 64] 38 | 39 | mlp_ratio: 4 40 | qkv_bias: true 41 | qk_scale: null 42 | attn_drop: 0.0 43 | proj_drop: 0.0 44 | drop_path: 0.3 45 | pre_norm: true 46 | shuffle_orders: true 47 | enable_rpe: false 48 | enable_flash: true 49 | upcast_attention: false 50 | upcast_softmax: false 51 | cls_mode: false 52 | pdnorm_bn: false 53 | pdnorm_ln: false 54 | pdnorm_decouple: true 55 | pdnorm_adaptive: false 56 | pdnorm_affine: true 57 | pdnorm_conditions: ["ScanNet", "S3DIS", "Structured3D"] 58 | 59 | # Define common settings as an anchor 60 | default_decoder: 61 | decoder_type: Decoder # or Decoder or SimpleDecoder or SimpleInterpolatedDecoder or InterpolatedDecoder or MultiScaleInterpolatedDecoder 62 | backbone: ${model.network.backbone} 63 | decoder_channels: ${model.network.point_transformerv3.dec_channels} 64 | stride: ${model.network.point_transformerv3.stride} 65 | coords_enc: Fourier # MLP or Fourier (16 dim MLP or 63 dim Fourier) 66 | architecture: point_nerf # attentive_pooling or transformer_encoder or point_nerf 67 | activation: LeakyReLU # LeakyReLU or ReLU or Softplus or ShiftedSoftplus 68 | negative_slope: 0.01 # for LeakyReLU, default: 0.01 69 | 70 | neighboring: Serial # KNN or Mink or Serial or Mixture 71 | serial_neighbor_layers: 3 # number of serial neighbor layers 72 | k_neighbors: 8 # 1 for no interpolation 73 | dist_factor: [1, 1, 1, 1] # N times voxel size 74 | serial_orders: ['hilbert'] # 'z', 'hilbert', 'z-trans', 'hilbert-trans' 75 | last_n_layers: 4 # decodes features from last n layers of the Unet backbone 76 | feature_dim: [8, 8, 8, 8] 77 | point_nerf_hidden_dim: 32 78 | point_nerf_before_skip: 1 79 | point_nerf_after_skip: 1 80 | 81 | num_hidden_layers_before: 2 82 | num_hidden_layers_after: 2 83 | hidden_dim: 32 84 | 85 | 86 | # Specific decoder configurations 87 | sdf_decoder: 88 | decoder_type: ${model.network.default_decoder.decoder_type} 89 | backbone: ${model.network.default_decoder.backbone} 90 | decoder_channels: ${model.network.default_decoder.decoder_channels} 91 | stride: ${model.network.default_decoder.stride} 92 | coords_enc: ${model.network.default_decoder.coords_enc} 93 | activation: ${model.network.default_decoder.activation} 94 | negative_slope: ${model.network.default_decoder.negative_slope} 95 | 96 | neighboring: ${model.network.default_decoder.neighboring} 97 | serial_neighbor_layers: ${model.network.default_decoder.serial_neighbor_layers} 98 | k_neighbors: ${model.network.default_decoder.k_neighbors} 99 | dist_factor: ${model.network.default_decoder.dist_factor} 100 | serial_orders: ${model.network.default_decoder.serial_orders} 101 | last_n_layers: ${model.network.default_decoder.last_n_layers} 102 | feature_dim: ${model.network.default_decoder.feature_dim} 103 | point_nerf_hidden_dim: 
${model.network.default_decoder.point_nerf_hidden_dim} 104 | point_nerf_before_skip: ${model.network.default_decoder.point_nerf_before_skip} 105 | point_nerf_after_skip: ${model.network.default_decoder.point_nerf_after_skip} 106 | 107 | num_hidden_layers_before: ${model.network.default_decoder.num_hidden_layers_before} 108 | num_hidden_layers_after: ${model.network.default_decoder.num_hidden_layers_after} 109 | hidden_dim: ${model.network.default_decoder.hidden_dim} 110 | 111 | 112 | mask_decoder: 113 | decoder_type: ${model.network.default_decoder.decoder_type} 114 | backbone: ${model.network.default_decoder.backbone} 115 | decoder_channels: ${model.network.default_decoder.decoder_channels} 116 | stride: ${model.network.default_decoder.stride} 117 | coords_enc: ${model.network.default_decoder.coords_enc} 118 | activation: ${model.network.default_decoder.activation} 119 | negative_slope: ${model.network.default_decoder.negative_slope} 120 | 121 | neighboring: ${model.network.default_decoder.neighboring} 122 | serial_neighbor_layers: ${model.network.default_decoder.serial_neighbor_layers} 123 | k_neighbors: ${model.network.default_decoder.k_neighbors} 124 | dist_factor: ${model.network.default_decoder.dist_factor} 125 | serial_orders: ${model.network.default_decoder.serial_orders} 126 | last_n_layers: ${model.network.default_decoder.last_n_layers} 127 | feature_dim: ${model.network.default_decoder.feature_dim} 128 | point_nerf_hidden_dim: ${model.network.default_decoder.point_nerf_hidden_dim} 129 | point_nerf_before_skip: ${model.network.default_decoder.point_nerf_before_skip} 130 | point_nerf_after_skip: ${model.network.default_decoder.point_nerf_after_skip} 131 | 132 | num_hidden_layers_before: ${model.network.default_decoder.num_hidden_layers_before} 133 | num_hidden_layers_after: ${model.network.default_decoder.num_hidden_layers_after} 134 | hidden_dim: ${model.network.default_decoder.hidden_dim} 135 | 136 | 137 | -------------------------------------------------------------------------------- /config/model/scannet_model.yaml: -------------------------------------------------------------------------------- 1 | 2 | # Managed by Hydra 3 | 4 | defaults: 5 | - base 6 | 7 | trainer: 8 | max_epochs: 160 9 | lr_decay: # for Adam 10 | decay_start_epoch: 100 11 | 12 | network: 13 | module: noksr 14 | use_color: False 15 | use_normal: True 16 | use_xyz: True 17 | 18 | latent_dim: 32 # 19 | prepare_epochs: 350 20 | eval_algorithm: DMC # DMC or DensePC or MC 21 | grad_type: Analytical # Analytical or Numerical 22 | backbone: PointTransformerV3 # PointTransformerV3 or MinkUNet 23 | 24 | point_transformerv3: # default 25 | in_channels: 6 26 | order: ["z", "z-trans", "hilbert", "hilbert-trans"] 27 | stride: [2, 2, 4, 4] 28 | enc_depths: [2, 2, 2, 6, 2] 29 | enc_channels: [32, 64, 128, 256, 512] 30 | enc_num_head: [2, 4, 8, 16, 32] 31 | enc_patch_size: [64, 64, 64, 64, 64] 32 | 33 | 34 | dec_depths: [2, 2, 2, 2] 35 | dec_channels: [64, 64, 128, 256] 36 | dec_num_head: [4, 4, 8, 16] 37 | dec_patch_size: [64, 64, 64, 64] 38 | 39 | mlp_ratio: 4 40 | qkv_bias: true 41 | qk_scale: null 42 | attn_drop: 0.0 43 | proj_drop: 0.0 44 | drop_path: 0.3 45 | pre_norm: true 46 | shuffle_orders: true 47 | enable_rpe: false 48 | enable_flash: true 49 | upcast_attention: false 50 | upcast_softmax: false 51 | cls_mode: false 52 | pdnorm_bn: false 53 | pdnorm_ln: false 54 | pdnorm_decouple: true 55 | pdnorm_adaptive: false 56 | pdnorm_affine: true 57 | pdnorm_conditions: ["ScanNet", "S3DIS", "Structured3D"] 58 | 59 | # 
Define common settings as an anchor 60 | default_decoder: 61 | decoder_type: Decoder 62 | backbone: ${model.network.backbone} 63 | decoder_channels: ${model.network.point_transformerv3.dec_channels} 64 | stride: ${model.network.point_transformerv3.stride} 65 | coords_enc: Fourier # MLP or Fourier (16 dim MLP or 63 dim Fourier) 66 | architecture: point_nerf # attentive_pooling or transformer_encoder or point_nerf 67 | negative_slope: 0.01 # for LeakyReLU, default: 0.01 68 | 69 | neighboring: Serial # KNN or Mink or Serial or Mixture 70 | serial_neighbor_layers: 3 # number of serial neighbor layers 71 | k_neighbors: 8 # 1 for no interpolation 72 | dist_factor: [4, 3, 2, 2] # N times voxel size 73 | serial_orders: ['hilbert'] # 'z', 'hilbert', 'z-trans', 'hilbert-trans' 74 | last_n_layers: 4 # decodes features from last n layers of the Unet backbone 75 | feature_dim: [8, 8, 8, 8] 76 | point_nerf_hidden_dim: 32 77 | point_nerf_before_skip: 1 78 | point_nerf_after_skip: 1 79 | 80 | num_hidden_layers_before: 2 81 | num_hidden_layers_after: 2 82 | hidden_dim: 32 83 | 84 | activation: LeakyReLU # LeakyReLU or ReLU or Softplus or ShiftedSoftplus 85 | 86 | # Specific decoder configurations 87 | sdf_decoder: 88 | decoder_type: ${model.network.default_decoder.decoder_type} 89 | backbone: ${model.network.default_decoder.backbone} 90 | decoder_channels: ${model.network.default_decoder.decoder_channels} 91 | stride: ${model.network.default_decoder.stride} 92 | coords_enc: ${model.network.default_decoder.coords_enc} 93 | negative_slope: ${model.network.default_decoder.negative_slope} 94 | 95 | neighboring: ${model.network.default_decoder.neighboring} 96 | serial_neighbor_layers: ${model.network.default_decoder.serial_neighbor_layers} 97 | k_neighbors: ${model.network.default_decoder.k_neighbors} 98 | dist_factor: ${model.network.default_decoder.dist_factor} 99 | serial_orders: ${model.network.default_decoder.serial_orders} 100 | last_n_layers: ${model.network.default_decoder.last_n_layers} 101 | feature_dim: ${model.network.default_decoder.feature_dim} 102 | point_nerf_hidden_dim: ${model.network.default_decoder.point_nerf_hidden_dim} 103 | point_nerf_before_skip: ${model.network.default_decoder.point_nerf_before_skip} 104 | point_nerf_after_skip: ${model.network.default_decoder.point_nerf_after_skip} 105 | 106 | num_hidden_layers_before: ${model.network.default_decoder.num_hidden_layers_before} 107 | num_hidden_layers_after: ${model.network.default_decoder.num_hidden_layers_after} 108 | hidden_dim: ${model.network.default_decoder.hidden_dim} 109 | 110 | activation: ${model.network.default_decoder.activation} 111 | 112 | mask_decoder: 113 | decoder_type: ${model.network.default_decoder.decoder_type} 114 | backbone: ${model.network.default_decoder.backbone} 115 | decoder_channels: ${model.network.default_decoder.decoder_channels} 116 | stride: ${model.network.default_decoder.stride} 117 | coords_enc: ${model.network.default_decoder.coords_enc} 118 | negative_slope: ${model.network.default_decoder.negative_slope} 119 | 120 | neighboring: ${model.network.default_decoder.neighboring} 121 | serial_neighbor_layers: ${model.network.default_decoder.serial_neighbor_layers} 122 | k_neighbors: ${model.network.default_decoder.k_neighbors} 123 | dist_factor: ${model.network.default_decoder.dist_factor} 124 | serial_orders: ${model.network.default_decoder.serial_orders} 125 | last_n_layers: ${model.network.default_decoder.last_n_layers} 126 | feature_dim: ${model.network.default_decoder.feature_dim} 127 | 
point_nerf_hidden_dim: ${model.network.default_decoder.point_nerf_hidden_dim} 128 | point_nerf_before_skip: ${model.network.default_decoder.point_nerf_before_skip} 129 | point_nerf_after_skip: ${model.network.default_decoder.point_nerf_after_skip} 130 | 131 | num_hidden_layers_before: ${model.network.default_decoder.num_hidden_layers_before} 132 | num_hidden_layers_after: ${model.network.default_decoder.num_hidden_layers_after} 133 | hidden_dim: ${model.network.default_decoder.hidden_dim} 134 | 135 | activation: ${model.network.default_decoder.activation} 136 | 137 | 138 | -------------------------------------------------------------------------------- /config/model/scenenn_model.yaml: -------------------------------------------------------------------------------- 1 | 2 | # Managed by Hydra 3 | 4 | defaults: 5 | - base 6 | 7 | trainer: 8 | max_epochs: 160 9 | lr_decay: # for Adam 10 | decay_start_epoch: 100 11 | 12 | network: 13 | module: noksr 14 | use_color: False 15 | use_normal: True 16 | use_xyz: True 17 | 18 | latent_dim: 32 # 19 | prepare_epochs: 350 20 | eval_algorithm: DMC # DMC or DensePC or MC 21 | grad_type: Analytical # Analytical or Numerical 22 | backbone: PointTransformerV3 # PointTransformerV3 or MinkUNet 23 | 24 | point_transformerv3: # default 25 | in_channels: 6 26 | order: ["z", "z-trans", "hilbert", "hilbert-trans"] 27 | stride: [2, 2, 4, 4] 28 | enc_depths: [2, 2, 2, 6, 2] 29 | enc_channels: [32, 64, 128, 256, 512] 30 | enc_num_head: [2, 4, 8, 16, 32] 31 | enc_patch_size: [64, 64, 64, 64, 64] 32 | 33 | 34 | dec_depths: [2, 2, 2, 2] 35 | dec_channels: [64, 64, 128, 256] 36 | dec_num_head: [4, 4, 8, 16] 37 | dec_patch_size: [64, 64, 64, 64] 38 | 39 | mlp_ratio: 4 40 | qkv_bias: true 41 | qk_scale: null 42 | attn_drop: 0.0 43 | proj_drop: 0.0 44 | drop_path: 0.3 45 | pre_norm: true 46 | shuffle_orders: true 47 | enable_rpe: false 48 | enable_flash: true 49 | upcast_attention: false 50 | upcast_softmax: false 51 | cls_mode: false 52 | pdnorm_bn: false 53 | pdnorm_ln: false 54 | pdnorm_decouple: true 55 | pdnorm_adaptive: false 56 | pdnorm_affine: true 57 | pdnorm_conditions: ["ScanNet", "S3DIS", "Structured3D"] 58 | 59 | # Define common settings as an anchor 60 | default_decoder: 61 | decoder_type: Decoder 62 | backbone: ${model.network.backbone} 63 | decoder_channels: ${model.network.point_transformerv3.dec_channels} 64 | stride: ${model.network.point_transformerv3.stride} 65 | coords_enc: Fourier # MLP or Fourier (16 dim MLP or 63 dim Fourier) 66 | architecture: point_nerf # attentive_pooling or transformer_encoder or point_nerf 67 | negative_slope: 0.01 # for LeakyReLU, default: 0.01 68 | 69 | neighboring: KNN # KNN or Mink or Serial or Mixture 70 | serial_neighbor_layers: 3 # number of serial neighbor layers 71 | k_neighbors: 4 # 1 for no interpolation 72 | dist_factor: [8, 4, 4, 4] # N times voxel size 73 | serial_orders: ['hilbert'] # 'z', 'hilbert', 'z-trans', 'hilbert-trans' 74 | last_n_layers: 4 # decodes features from last n layers of the Unet backbone 75 | feature_dim: [8, 8, 8, 8] 76 | point_nerf_hidden_dim: 32 77 | point_nerf_before_skip: 1 78 | point_nerf_after_skip: 1 79 | 80 | num_hidden_layers_before: 2 81 | num_hidden_layers_after: 2 82 | hidden_dim: 32 83 | 84 | activation: LeakyReLU # LeakyReLU or ReLU or Softplus or ShiftedSoftplus 85 | 86 | # Specific decoder configurations 87 | sdf_decoder: 88 | decoder_type: ${model.network.default_decoder.decoder_type} 89 | backbone: ${model.network.default_decoder.backbone} 90 | decoder_channels: 
${model.network.default_decoder.decoder_channels} 91 | stride: ${model.network.default_decoder.stride} 92 | coords_enc: ${model.network.default_decoder.coords_enc} 93 | negative_slope: ${model.network.default_decoder.negative_slope} 94 | 95 | neighboring: ${model.network.default_decoder.neighboring} 96 | serial_neighbor_layers: ${model.network.default_decoder.serial_neighbor_layers} 97 | k_neighbors: ${model.network.default_decoder.k_neighbors} 98 | dist_factor: ${model.network.default_decoder.dist_factor} 99 | serial_orders: ${model.network.default_decoder.serial_orders} 100 | last_n_layers: ${model.network.default_decoder.last_n_layers} 101 | feature_dim: ${model.network.default_decoder.feature_dim} 102 | point_nerf_hidden_dim: ${model.network.default_decoder.point_nerf_hidden_dim} 103 | point_nerf_before_skip: ${model.network.default_decoder.point_nerf_before_skip} 104 | point_nerf_after_skip: ${model.network.default_decoder.point_nerf_after_skip} 105 | 106 | num_hidden_layers_before: ${model.network.default_decoder.num_hidden_layers_before} 107 | num_hidden_layers_after: ${model.network.default_decoder.num_hidden_layers_after} 108 | hidden_dim: ${model.network.default_decoder.hidden_dim} 109 | 110 | activation: ${model.network.default_decoder.activation} 111 | 112 | mask_decoder: 113 | decoder_type: ${model.network.default_decoder.decoder_type} 114 | backbone: ${model.network.default_decoder.backbone} 115 | decoder_channels: ${model.network.default_decoder.decoder_channels} 116 | stride: ${model.network.default_decoder.stride} 117 | coords_enc: ${model.network.default_decoder.coords_enc} 118 | negative_slope: ${model.network.default_decoder.negative_slope} 119 | 120 | neighboring: ${model.network.default_decoder.neighboring} 121 | serial_neighbor_layers: ${model.network.default_decoder.serial_neighbor_layers} 122 | k_neighbors: ${model.network.default_decoder.k_neighbors} 123 | dist_factor: ${model.network.default_decoder.dist_factor} 124 | serial_orders: ${model.network.default_decoder.serial_orders} 125 | last_n_layers: ${model.network.default_decoder.last_n_layers} 126 | feature_dim: ${model.network.default_decoder.feature_dim} 127 | point_nerf_hidden_dim: ${model.network.default_decoder.point_nerf_hidden_dim} 128 | point_nerf_before_skip: ${model.network.default_decoder.point_nerf_before_skip} 129 | point_nerf_after_skip: ${model.network.default_decoder.point_nerf_after_skip} 130 | 131 | num_hidden_layers_before: ${model.network.default_decoder.num_hidden_layers_before} 132 | num_hidden_layers_after: ${model.network.default_decoder.num_hidden_layers_after} 133 | hidden_dim: ${model.network.default_decoder.hidden_dim} 134 | 135 | activation: ${model.network.default_decoder.activation} 136 | 137 | 138 | -------------------------------------------------------------------------------- /config/model/synthetic_model.yaml: -------------------------------------------------------------------------------- 1 | 2 | # Managed by Hydra 3 | 4 | defaults: 5 | - base 6 | 7 | trainer: 8 | max_epochs: 160 9 | lr_decay: # for Adam 10 | decay_start_epoch: 100 11 | 12 | network: 13 | module: noksr 14 | use_color: False 15 | use_normal: True # True for SDF supervision 16 | use_xyz: True # True only for PointTransformerV3 debugging 17 | 18 | latent_dim: 32 # 19 | prepare_epochs: 350 20 | eval_algorithm: DMC # DMC or DensePC or MC 21 | grad_type: Analytical # Analytical or Numerical 22 | backbone: PointTransformerV3 # PointTransformerV3 or MinkUNet 23 | 24 | point_transformerv3: # default 25 | in_channels: 
6 26 | order: ["z", "z-trans", "hilbert", "hilbert-trans"] 27 | stride: [2, 2, 4, 4] 28 | enc_depths: [2, 2, 2, 6, 2] 29 | enc_channels: [32, 64, 128, 256, 512] 30 | enc_num_head: [2, 4, 8, 16, 32] 31 | enc_patch_size: [64, 64, 64, 64, 64] 32 | 33 | 34 | dec_depths: [2, 2, 2, 2] 35 | dec_channels: [64, 64, 128, 256] 36 | dec_num_head: [4, 4, 8, 16] 37 | dec_patch_size: [64, 64, 64, 64] 38 | 39 | mlp_ratio: 4 40 | qkv_bias: true 41 | qk_scale: null 42 | attn_drop: 0.0 43 | proj_drop: 0.0 44 | drop_path: 0.3 45 | pre_norm: true 46 | shuffle_orders: true 47 | enable_rpe: false 48 | enable_flash: true 49 | upcast_attention: false 50 | upcast_softmax: false 51 | cls_mode: false 52 | pdnorm_bn: false 53 | pdnorm_ln: false 54 | pdnorm_decouple: true 55 | pdnorm_adaptive: false 56 | pdnorm_affine: true 57 | pdnorm_conditions: ["ScanNet", "S3DIS", "Structured3D"] 58 | 59 | # Define common settings as an anchor 60 | default_decoder: 61 | decoder_type: Decoder # or Decoder or SimpleDecoder or SimpleInterpolatedDecoder or InterpolatedDecoder or MultiScaleInterpolatedDecoder 62 | backbone: ${model.network.backbone} 63 | decoder_channels: ${model.network.point_transformerv3.dec_channels} 64 | stride: ${model.network.point_transformerv3.stride} 65 | coords_enc: Fourier # MLP or Fourier (16 dim MLP or 63 dim Fourier) 66 | architecture: point_nerf # attentive_pooling or transformer_encoder or point_nerf 67 | activation: LeakyReLU # LeakyReLU or ReLU or Softplus or ShiftedSoftplus 68 | negative_slope: 0.01 # for LeakyReLU, default: 0.01 69 | 70 | neighboring: Serial # KNN or Mink or Serial or Mixture 71 | serial_neighbor_layers: 3 # number of serial neighbor layers 72 | k_neighbors: 4 # 1 for no interpolation 73 | dist_factor: [4, 3, 2, 2] # N times voxel size 74 | serial_orders: ['hilbert'] # 'z', 'hilbert', 'z-trans', 'hilbert-trans' 75 | last_n_layers: 4 # decodes features from last n layers of the Unet backbone 76 | feature_dim: [8, 8, 8, 8] 77 | point_nerf_hidden_dim: 32 78 | point_nerf_before_skip: 1 79 | point_nerf_after_skip: 1 80 | 81 | num_hidden_layers_before: 2 82 | num_hidden_layers_after: 2 83 | hidden_dim: 32 84 | 85 | 86 | # Specific decoder configurations 87 | sdf_decoder: 88 | decoder_type: ${model.network.default_decoder.decoder_type} 89 | backbone: ${model.network.default_decoder.backbone} 90 | decoder_channels: ${model.network.default_decoder.decoder_channels} 91 | stride: ${model.network.default_decoder.stride} 92 | coords_enc: ${model.network.default_decoder.coords_enc} 93 | activation: ${model.network.default_decoder.activation} 94 | negative_slope: ${model.network.default_decoder.negative_slope} 95 | 96 | neighboring: ${model.network.default_decoder.neighboring} 97 | serial_neighbor_layers: ${model.network.default_decoder.serial_neighbor_layers} 98 | k_neighbors: ${model.network.default_decoder.k_neighbors} 99 | dist_factor: ${model.network.default_decoder.dist_factor} 100 | serial_orders: ${model.network.default_decoder.serial_orders} 101 | last_n_layers: ${model.network.default_decoder.last_n_layers} 102 | feature_dim: ${model.network.default_decoder.feature_dim} 103 | point_nerf_hidden_dim: ${model.network.default_decoder.point_nerf_hidden_dim} 104 | point_nerf_before_skip: ${model.network.default_decoder.point_nerf_before_skip} 105 | point_nerf_after_skip: ${model.network.default_decoder.point_nerf_after_skip} 106 | 107 | num_hidden_layers_before: ${model.network.default_decoder.num_hidden_layers_before} 108 | num_hidden_layers_after: 
${model.network.default_decoder.num_hidden_layers_after} 109 | hidden_dim: ${model.network.default_decoder.hidden_dim} 110 | 111 | 112 | mask_decoder: 113 | decoder_type: ${model.network.default_decoder.decoder_type} 114 | backbone: ${model.network.default_decoder.backbone} 115 | decoder_channels: ${model.network.default_decoder.decoder_channels} 116 | stride: ${model.network.default_decoder.stride} 117 | coords_enc: ${model.network.default_decoder.coords_enc} 118 | activation: ${model.network.default_decoder.activation} 119 | negative_slope: ${model.network.default_decoder.negative_slope} 120 | 121 | neighboring: ${model.network.default_decoder.neighboring} 122 | serial_neighbor_layers: ${model.network.default_decoder.serial_neighbor_layers} 123 | k_neighbors: ${model.network.default_decoder.k_neighbors} 124 | dist_factor: ${model.network.default_decoder.dist_factor} 125 | serial_orders: ${model.network.default_decoder.serial_orders} 126 | last_n_layers: ${model.network.default_decoder.last_n_layers} 127 | feature_dim: ${model.network.default_decoder.feature_dim} 128 | point_nerf_hidden_dim: ${model.network.default_decoder.point_nerf_hidden_dim} 129 | point_nerf_before_skip: ${model.network.default_decoder.point_nerf_before_skip} 130 | point_nerf_after_skip: ${model.network.default_decoder.point_nerf_after_skip} 131 | 132 | num_hidden_layers_before: ${model.network.default_decoder.num_hidden_layers_before} 133 | num_hidden_layers_after: ${model.network.default_decoder.num_hidden_layers_after} 134 | hidden_dim: ${model.network.default_decoder.hidden_dim} 135 | 136 | 137 | 138 | -------------------------------------------------------------------------------- /data: -------------------------------------------------------------------------------- 1 | /localhome/zla247/theia2_data -------------------------------------------------------------------------------- /eval.log: -------------------------------------------------------------------------------- 1 | [2024-11-25 21:56:55,155][pycg.exp][WARNING] - Empty pointcloud / mesh detected! Return NaN metric! 2 | [2024-11-25 23:24:31,010][pycg.exp][WARNING] - Empty pointcloud / mesh detected! Return NaN metric! 
3 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hydra 3 | import torch 4 | from tqdm import tqdm 5 | import numpy as np 6 | import open3d as o3d 7 | from collections import defaultdict 8 | from torch.utils.data import Dataset 9 | from pathlib import Path 10 | from importlib import import_module 11 | import pytorch_lightning as pl 12 | from noksr.data.data_module import DataModule 13 | from noksr.utils.evaluation import UnitMeshEvaluator, MeshEvaluator 14 | from noksr.model.module import Generator 15 | from nksr.svh import SparseFeatureHierarchy 16 | from noksr.utils.segmentation import segment_and_generate_encoder_outputs 17 | 18 | @hydra.main(version_base=None, config_path="config", config_name="config") 19 | def main(cfg): 20 | assert torch.cuda.is_available() 21 | device = torch.device("cuda") 22 | # fix the seed 23 | pl.seed_everything(cfg.global_test_seed, workers=True) 24 | 25 | output_path = os.path.join(cfg.exp_output_root_path, "reconstruction", "visualizations") 26 | os.makedirs(output_path, exist_ok=True) 27 | # output_path = cfg.exp_output_root_path 28 | 29 | print("==> initializing data ...") 30 | data_module = DataModule(cfg) 31 | data_module.setup("test") 32 | val_loader = data_module.val_dataloader() 33 | 34 | print("=> initializing model...") 35 | model = getattr(import_module("noksr.model"), cfg.model.network.module)(cfg) 36 | # Load checkpoint 37 | if os.path.isfile(cfg.model.ckpt_path): 38 | print(f"=> loading model checkpoint '{cfg.model.ckpt_path}'") 39 | checkpoint = torch.load(cfg.model.ckpt_path, map_location=device) 40 | model.load_state_dict(checkpoint['state_dict']) 41 | print("=> loaded checkpoint successfully.") 42 | else: 43 | raise FileNotFoundError(f"No checkpoint found at '{cfg.model.ckpt_path}'. 
Please ensure the path is correct.") 44 | 45 | model.to(device) 46 | model.eval() 47 | 48 | if cfg.data.reconstruction.trim: 49 | dense_generator = Generator( 50 | model.sdf_decoder, 51 | model.mask_decoder, 52 | cfg.data.voxel_size, 53 | cfg.model.network.sdf_decoder.k_neighbors, 54 | cfg.model.network.sdf_decoder.last_n_layers, 55 | cfg.data.reconstruction 56 | ) 57 | else: 58 | dense_generator = Generator( 59 | model.sdf_decoder, 60 | None, 61 | cfg.data.voxel_size, 62 | cfg.model.network.sdf_decoder.k_neighbors, 63 | cfg.model.network.sdf_decoder.last_n_layers, 64 | cfg.data.reconstruction 65 | ) 66 | 67 | # Initialize a dictionary to keep track of sums and count 68 | eval_sums = defaultdict(float) 69 | batch_count = 0 70 | import time 71 | print("=> start inference...") 72 | start_time = time.time() 73 | total_reconstruction_duration = 0.0 74 | total_neighboring_time = 0.0 75 | total_dmc_time = 0.0 76 | total_aggregation_time = 0.0 77 | total_forward_duration = 0.0 78 | total_decoder_time = 0.0 79 | results_dict = [] 80 | for batch in tqdm(val_loader, desc="Inference", unit="batch"): 81 | batch = {k: v.to(device) if hasattr(v, 'to') else v for k, v in batch.items()} 82 | if 'gt_geometry' in batch: 83 | gt_geometry = batch['gt_geometry'] 84 | batch['all_xyz'], batch['all_normal'], _ = gt_geometry[0].torch_attr() 85 | process_start = time.time() 86 | forward_start = time.time() 87 | if cfg.data.reconstruction.by_segment: 88 | encoder_outputs, encoding_codes, depth = segment_and_generate_encoder_outputs( 89 | batch=batch, 90 | model=model, 91 | device=device, 92 | segment_num=cfg.data.reconstruction.segment_num, # Example: 10 segments 93 | grid_size=0.01, 94 | serial_order='z' 95 | ) 96 | 97 | forward_end = time.time() 98 | torch.set_grad_enabled(False) 99 | dmc_mesh, time_dict = dense_generator.generate_dual_mc_mesh_by_segment( 100 | data_dict=batch, 101 | encoder_outputs=encoder_outputs, 102 | encoding_codes=encoding_codes, 103 | depth=depth, 104 | device=device 105 | ) 106 | else: 107 | if cfg.model.network.backbone == 'PointTransformerV3': 108 | pt_data = {} 109 | pt_data['feat'] = batch['point_features'] 110 | pt_data['offset'] = batch['xyz_splits'] 111 | pt_data['grid_size'] = 0.01 112 | pt_data['coord'] = batch['xyz'] 113 | encoder_outputs = model.point_transformer(pt_data) 114 | 115 | forward_end = time.time() 116 | torch.set_grad_enabled(False) 117 | 118 | dmc_mesh, time_dict = dense_generator.generate_dual_mc_mesh(batch, encoder_outputs, device) 119 | total_neighboring_time += time_dict['neighboring_time'] 120 | total_dmc_time += time_dict['dmc_time'] 121 | total_aggregation_time += time_dict['aggregation_time'] 122 | total_decoder_time += time_dict['decoder_time'] 123 | # Calculate time taken for these three steps 124 | process_end = time.time() 125 | total_forward_duration += forward_end - forward_start 126 | process_duration = process_end - process_start 127 | total_reconstruction_duration += process_duration # Accumulate the duration 128 | print("\nTotal Reconstruction Time: {:.2f} seconds".format(total_reconstruction_duration)) 129 | print("├── Total PointTransformerV3 Time: {:.2f} seconds".format(total_forward_duration)) 130 | print("├── Total Decoder Time: {:.2f} seconds".format(total_decoder_time)) 131 | print("│ ├── Total Neighboring Time: {:.2f} seconds".format(total_neighboring_time)) 132 | print("│ └── Total Aggregation Time: {:.2f} seconds" 133 | .format(total_aggregation_time)) 134 | print("├── Total Dual Marching Cube Time: {:.2f} seconds".format(total_dmc_time)) 
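    # Timing summary: the backbone forward pass, the decoder stage (whose per-batch
    # time_dict is further split into neighborhood search and feature aggregation),
    # and dual marching cubes extraction are accumulated separately, so their totals
    # can be compared across the whole run.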
135 | 136 | # Evaluate the reconstructed mesh 137 | if cfg.data.evaluation.evaluator == "UnitMeshEvaluator": 138 | evaluator = UnitMeshEvaluator(n_points=100000, metric_names=UnitMeshEvaluator.ESSENTIAL_METRICS) 139 | elif cfg.data.evaluation.evaluator == "MeshEvaluator": 140 | evaluator = MeshEvaluator(n_points=int(5e6), metric_names=MeshEvaluator.ESSENTIAL_METRICS) 141 | if "gt_onet_sample" in batch: 142 | eval_dict = evaluator.eval_mesh(dmc_mesh, batch['all_xyz'], None, onet_samples=batch['gt_onet_sample'][0]) 143 | else: 144 | eval_dict = evaluator.eval_mesh(dmc_mesh, batch['all_xyz'], None) 145 | 146 | # eval_dict['voxel_num'] = batch['voxel_nums'][0] 147 | for k, v in eval_dict.items(): 148 | eval_sums[k] += v 149 | scene_name = batch['scene_names'] 150 | eval_dict["data_id"] = batch_count 151 | eval_dict["scene_name"] = scene_name 152 | results_dict.append(eval_dict) 153 | print(f"Scene Name: {scene_name}") 154 | print(f"completeness: {eval_dict['completeness']:.4f}") 155 | print(f"accuracy: {eval_dict['accuracy']:.4f}") 156 | print(f"Chamfer-L2: {eval_dict['chamfer-L2']:.4f}") 157 | print(f"Chamfer-L1: {eval_dict['chamfer-L1']:.4f}") 158 | print(f"F-Score (1.0%): {eval_dict['f-score-10']:.4f}") 159 | print(f"F-Score (1.5%): {eval_dict['f-score-15']:.4f}") 160 | print(f"F-Score (2.0%): {eval_dict['f-score-20']:.4f}") 161 | if "o3d-iou" in eval_dict: 162 | print(f"O3D-IoU: {eval_dict['o3d-iou']:.4f}") 163 | print() # For better readability 164 | 165 | # Save the mesh 166 | if cfg.data.visualization.save: 167 | scene_name = batch['scene_names'][0].replace('/', '-') 168 | if cfg.data.visualization.Mesh: 169 | mesh_file = f"{output_path}/noksr-{cfg.data.dataset}-data_id_{cfg.data.intake_start}-{scene_name}_CD_{eval_dict['chamfer-L1']:.4f}_mesh.obj" # or "output_mesh.obj" for OBJ format 170 | o3d.io.write_triangle_mesh(mesh_file, dmc_mesh) 171 | if cfg.data.visualization.Input_points: 172 | pcd = o3d.geometry.PointCloud() 173 | pcd.points = o3d.utility.Vector3dVector(batch['xyz'].cpu().numpy()) 174 | point_cloud_file = f"{output_path}/noksr-{cfg.data.dataset}-data_id_{cfg.data.intake_start}-{scene_name}_input_pcd.ply" 175 | o3d.io.write_point_cloud(point_cloud_file, pcd) 176 | if cfg.data.visualization.Dense_points: 177 | pcd = o3d.geometry.PointCloud() 178 | pcd.points = o3d.utility.Vector3dVector(batch['all_xyz'].cpu().numpy()) 179 | point_cloud_file = f"{output_path}/noksr-{cfg.data.dataset}-data_id_{cfg.data.intake_start}-{scene_name}_dense_pcd.ply" 180 | o3d.io.write_point_cloud(point_cloud_file, pcd) 181 | 182 | batch_count += 1 183 | torch.cuda.empty_cache() 184 | 185 | end_time = time.time() 186 | total_time = end_time - start_time 187 | print(f"\n---Total reconstruction time without data loading: {total_reconstruction_duration:.2f} seconds") 188 | print(f"---Total reconstruction time including data loading for all scenes: {total_time:.2f} seconds") 189 | if batch_count > 0: 190 | print("\n--- Evaluation Metrics Averages ---") 191 | for k in eval_sums: 192 | average = eval_sums[k] / batch_count 193 | print(f"{k}: {average:.5f}") 194 | else: 195 | print("No batches were processed.") 196 | 197 | 198 | if __name__ == "__main__": 199 | main() 200 | -------------------------------------------------------------------------------- /noksr/callback/__init__.py: -------------------------------------------------------------------------------- 1 | from .gpu_cache_clean_callback import GPUCacheCleanCallback 2 | -------------------------------------------------------------------------------- 
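Note on eval.py above: it is a Hydra entry point (config_path="config", config_name="config"), so any option can be overridden from the command line. A minimal invocation sketch, using config keys that appear in the script; the checkpoint filename is an illustrative assumption, not a file shipped with the repository:

python eval.py model.ckpt_path=checkpoints/noksr_scannet.ckpt data.visualization.save=True data.evaluation.evaluator=UnitMeshEvaluator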
/noksr/callback/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/callback/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/callback/__pycache__/gpu_cache_clean_callback.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/callback/__pycache__/gpu_cache_clean_callback.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/callback/gpu_cache_clean_callback.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.callbacks import Callback 2 | import torch 3 | 4 | 5 | class GPUCacheCleanCallback(Callback): 6 | 7 | def on_train_batch_start(self, *args, **kwargs): 8 | torch.cuda.empty_cache() 9 | 10 | def on_validation_batch_start(self, *args, **kwargs): 11 | torch.cuda.empty_cache() 12 | 13 | def on_test_batch_start(self, *args, **kwargs): 14 | torch.cuda.empty_cache() 15 | 16 | def on_predict_batch_start(self, *args, **kwargs): 17 | torch.cuda.empty_cache() 18 | -------------------------------------------------------------------------------- /noksr/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_module import DataModule -------------------------------------------------------------------------------- /noksr/data/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/data/__pycache__/data_module.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/__pycache__/data_module.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/data/data_module.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import DataLoader 5 | import torch 6 | from torch.utils.data import Sampler, DistributedSampler, Dataset 7 | import pytorch_lightning as pl 8 | from arrgh import arrgh 9 | 10 | 11 | class DataModule(pl.LightningDataModule): 12 | def __init__(self, data_cfg): 13 | super().__init__() 14 | self.data_cfg = data_cfg 15 | self.dataset = getattr(import_module('noksr.data.dataset'), data_cfg.data.dataset) 16 | 17 | def setup(self, stage=None): 18 | if stage == "fit" or stage is None: 19 | self.train_set = self.dataset(self.data_cfg, "train") 20 | self.val_set = self.dataset(self.data_cfg, "val") 21 | if stage == "test" or stage is None: 22 | self.val_set = self.dataset(self.data_cfg, self.data_cfg.model.inference.split) 23 | if stage == "predict" or stage is None: 24 | self.test_set = self.dataset(self.data_cfg, "test") 25 | 26 | def train_dataloader(self): 27 | return DataLoader(self.train_set, batch_size=self.data_cfg.data.batch_size, shuffle=True, 
pin_memory=True, 28 | collate_fn=_sparse_collate_fn, num_workers=self.data_cfg.data.num_workers, drop_last=True) 29 | 30 | def val_dataloader(self): 31 | return DataLoader(self.val_set, batch_size=1, pin_memory=True, collate_fn=_sparse_collate_fn, 32 | num_workers=self.data_cfg.data.num_workers) 33 | 34 | def test_dataloader(self): 35 | return DataLoader(self.val_set, batch_size=1, pin_memory=True, collate_fn=_sparse_collate_fn, 36 | num_workers=self.data_cfg.data.num_workers) 37 | 38 | def predict_dataloader(self): 39 | return DataLoader(self.test_set, batch_size=1, pin_memory=True, collate_fn=_sparse_collate_fn, 40 | num_workers=self.data_cfg.data.num_workers) 41 | 42 | 43 | def _sparse_collate_fn(batch): 44 | if "gt_geometry" in batch[0]: 45 | """ for dataset with ground truth geometry """ 46 | data = {} 47 | xyz = [] 48 | point_features = [] 49 | gt_geometry_list = [] 50 | scene_names_list = [] 51 | 52 | for _, b in enumerate(batch): 53 | scene_names_list.append(b["scene_name"]) 54 | xyz.append(torch.from_numpy(b["xyz"])) 55 | point_features.append(torch.from_numpy(b["point_features"])) 56 | gt_geometry_list.append(b["gt_geometry"]) 57 | 58 | data['xyz'] = torch.cat(xyz, dim=0) 59 | data['point_features'] = torch.cat(point_features, dim=0) 60 | data['xyz_splits'] = torch.tensor([c.shape[0] for c in xyz]) 61 | data['gt_geometry'] = gt_geometry_list 62 | 63 | data['scene_names'] = scene_names_list 64 | 65 | return data 66 | 67 | else: 68 | data = {} 69 | xyz = [] 70 | all_xyz = [] 71 | all_normals = [] 72 | point_features = [] 73 | scene_names_list = [] 74 | for _, b in enumerate(batch): 75 | scene_names_list.append(b["scene_name"]) 76 | xyz.append(torch.from_numpy(b["xyz"])) 77 | all_xyz.append(torch.from_numpy(b["all_xyz"])) 78 | all_normals.append(torch.from_numpy(b["all_normals"])) 79 | point_features.append(torch.from_numpy(b["point_features"])) 80 | 81 | data['all_xyz'] = torch.cat(all_xyz, dim=0) 82 | data['all_normals'] = torch.cat(all_normals, dim=0) 83 | data['xyz'] = torch.cat(xyz, dim=0) 84 | data['point_features'] = torch.cat(point_features, dim=0) 85 | 86 | data['scene_names'] = scene_names_list 87 | data['row_splits'] = [c.shape[0] for c in all_xyz] 88 | data['xyz_splits'] = torch.tensor([c.shape[0] for c in xyz]) 89 | if "gt_onet_sample" in batch[0]: 90 | data['gt_onet_sample'] = [b["gt_onet_sample"] for b in batch] 91 | return data 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /noksr/data/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .general_dataset import GeneralDataset 2 | # from .scannet_rangeudf import ScannetRangeUDF 3 | from .scannet import Scannet 4 | from .carla import Carla 5 | from .synthetic import Synthetic 6 | from .mixture import Mixture 7 | from .general_dataset import DatasetSpec 8 | from .general_dataset import RandomSafeDataset 9 | from .carla_gt_geometry import get_class 10 | from .scenenn import SceneNN 11 | -------------------------------------------------------------------------------- /noksr/data/dataset/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/data/dataset/__pycache__/__init__.cpython-38.pyc: 
-------------------------------------------------------------------------------- /noksr/data/dataset/__pycache__/voxelizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/voxelizer.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/data/dataset/__pycache__/voxelizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/voxelizer.cpython-38.pyc -------------------------------------------------------------------------------- /noksr/data/dataset/augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import logging 4 | import numpy as np 5 | import scipy 6 | import scipy.ndimage 7 | import scipy.interpolate 8 | import torch 9 | 10 | 11 | # A sparse tensor consists of coordinates and associated features. 12 | # You must apply augmentation to both. 13 | # In 2D, flip, shear, scale, and rotation of images are coordinate transformation 14 | # color jitter, hue, etc., are feature transformations 15 | ############################## 16 | # Feature transformations 17 | ############################## 18 | class ChromaticTranslation(object): 19 | '''Add random color to the image, input must be an array in [0,255] or a PIL image''' 20 | 21 | def __init__(self, trans_range_ratio=1e-1): 22 | ''' 23 | trans_range_ratio: ratio of translation i.e. 255 * 2 * ratio * rand(-0.5, 0.5) 24 | ''' 25 | self.trans_range_ratio = trans_range_ratio 26 | 27 | def __call__(self, coords, feats, labels): 28 | if random.random() < 0.95: 29 | tr = (np.random.rand(1, 3) - 0.5) * 255 * 2 * self.trans_range_ratio 30 | feats[:, :3] = np.clip(tr + feats[:, :3], 0, 255) 31 | return coords, feats, labels 32 | 33 | 34 | class ChromaticAutoContrast(object): 35 | 36 | def __init__(self, randomize_blend_factor=True, blend_factor=0.5): 37 | self.randomize_blend_factor = randomize_blend_factor 38 | self.blend_factor = blend_factor 39 | 40 | def __call__(self, coords, feats, labels): 41 | if random.random() < 0.2: 42 | # mean = np.mean(feats, 0, keepdims=True) 43 | # std = np.std(feats, 0, keepdims=True) 44 | # lo = mean - std 45 | # hi = mean + std 46 | lo = np.min(feats, 0, keepdims=True) 47 | hi = np.max(feats, 0, keepdims=True) 48 | 49 | scale = 255 / (hi - lo) 50 | 51 | contrast_feats = (feats - lo) * scale 52 | 53 | blend_factor = random.random() if self.randomize_blend_factor else self.blend_factor 54 | feats = (1 - blend_factor) * feats + blend_factor * contrast_feats 55 | return coords, feats, labels 56 | 57 | 58 | class ChromaticJitter(object): 59 | 60 | def __init__(self, std=0.01): 61 | self.std = std 62 | 63 | def __call__(self, coords, feats, labels): 64 | if random.random() < 0.95: 65 | noise = np.random.randn(feats.shape[0], 3) 66 | noise *= self.std * 255 67 | feats[:, :3] = np.clip(noise + feats[:, :3], 0, 255) 68 | return coords, feats, labels 69 | 70 | 71 | class HueSaturationTranslation(object): 72 | 73 | @staticmethod 74 | def rgb_to_hsv(rgb): 75 | # Translated from source of colorsys.rgb_to_hsv 76 | # r,g,b should be a numpy arrays with values between 0 and 255 77 | # rgb_to_hsv returns an array of floats between 0.0 and 1.0. 
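        # The hue branch (which of r, g, b is the channel maximum) is resolved array-wide
        # with np.select below, and saturation is only filled in where maxc != minc, so the
        # conversion stays fully vectorized with no per-pixel Python loop.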
78 | rgb = rgb.astype('float') 79 | hsv = np.zeros_like(rgb) 80 | # in case an RGBA array was passed, just copy the A channel 81 | hsv[..., 3:] = rgb[..., 3:] 82 | r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2] 83 | maxc = np.max(rgb[..., :3], axis=-1) 84 | minc = np.min(rgb[..., :3], axis=-1) 85 | hsv[..., 2] = maxc 86 | mask = maxc != minc 87 | hsv[mask, 1] = (maxc - minc)[mask] / maxc[mask] 88 | rc = np.zeros_like(r) 89 | gc = np.zeros_like(g) 90 | bc = np.zeros_like(b) 91 | rc[mask] = (maxc - r)[mask] / (maxc - minc)[mask] 92 | gc[mask] = (maxc - g)[mask] / (maxc - minc)[mask] 93 | bc[mask] = (maxc - b)[mask] / (maxc - minc)[mask] 94 | hsv[..., 0] = np.select([r == maxc, g == maxc], [bc - gc, 2.0 + rc - bc], default=4.0 + gc - rc) 95 | hsv[..., 0] = (hsv[..., 0] / 6.0) % 1.0 96 | return hsv 97 | 98 | @staticmethod 99 | def hsv_to_rgb(hsv): 100 | # Translated from source of colorsys.hsv_to_rgb 101 | # h,s should be a numpy arrays with values between 0.0 and 1.0 102 | # v should be a numpy array with values between 0.0 and 255.0 103 | # hsv_to_rgb returns an array of uints between 0 and 255. 104 | rgb = np.empty_like(hsv) 105 | rgb[..., 3:] = hsv[..., 3:] 106 | h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2] 107 | i = (h * 6.0).astype('uint8') 108 | f = (h * 6.0) - i 109 | p = v * (1.0 - s) 110 | q = v * (1.0 - s * f) 111 | t = v * (1.0 - s * (1.0 - f)) 112 | i = i % 6 113 | conditions = [s == 0.0, i == 1, i == 2, i == 3, i == 4, i == 5] 114 | rgb[..., 0] = np.select(conditions, [v, q, p, p, t, v], default=v) 115 | rgb[..., 1] = np.select(conditions, [v, v, v, q, p, p], default=t) 116 | rgb[..., 2] = np.select(conditions, [v, p, t, v, v, q], default=p) 117 | return rgb.astype('uint8') 118 | 119 | def __init__(self, hue_max, saturation_max): 120 | self.hue_max = hue_max 121 | self.saturation_max = saturation_max 122 | 123 | def __call__(self, coords, feats, labels): 124 | # Assume feat[:, :3] is rgb 125 | hsv = HueSaturationTranslation.rgb_to_hsv(feats[:, :3]) 126 | hue_val = (random.random() - 0.5) * 2 * self.hue_max 127 | sat_ratio = 1 + (random.random() - 0.5) * 2 * self.saturation_max 128 | hsv[..., 0] = np.remainder(hue_val + hsv[..., 0] + 1, 1) 129 | hsv[..., 1] = np.clip(sat_ratio * hsv[..., 1], 0, 1) 130 | feats[:, :3] = np.clip(HueSaturationTranslation.hsv_to_rgb(hsv), 0, 255) 131 | 132 | return coords, feats, labels 133 | 134 | 135 | ############################## 136 | # Coordinate transformations 137 | ############################## 138 | class RandomHorizontalFlip(object): 139 | 140 | def __init__(self, upright_axis, is_temporal): 141 | ''' 142 | upright_axis: axis index among x,y,z, i.e. 2 for z 143 | ''' 144 | self.is_temporal = is_temporal 145 | self.D = 4 if is_temporal else 3 146 | self.upright_axis = {'x': 0, 'y': 1, 'z': 2}[upright_axis.lower()] 147 | # Use the rest of axes for flipping. 148 | self.horz_axes = set(range(self.D)) - set([self.upright_axis]) 149 | 150 | def __call__(self, coords, feats, labels): 151 | if random.random() < 0.95: 152 | for curr_ax in self.horz_axes: 153 | if random.random() < 0.5: 154 | coord_max = np.max(coords[:, curr_ax]) 155 | coords[:, curr_ax] = coord_max - coords[:, curr_ax] 156 | return coords, feats, labels 157 | 158 | 159 | class ElasticDistortion: 160 | 161 | def __init__(self, distortion_params): 162 | self.distortion_params = distortion_params 163 | 164 | def elastic_distortion(self, coords, granularity, magnitude): 165 | '''Apply elastic distortion on sparse coordinate space. 
166 | 167 | pointcloud: numpy array of (number of points, at least 3 spatial dims) 168 | granularity: size of the noise grid (in same scale[m/cm] as the voxel grid) 169 | magnitude: noise multiplier 170 | ''' 171 | blurx = np.ones((3, 1, 1, 1)).astype('float32') / 3 172 | blury = np.ones((1, 3, 1, 1)).astype('float32') / 3 173 | blurz = np.ones((1, 1, 3, 1)).astype('float32') / 3 174 | coords_min = coords.min(0) 175 | 176 | # Create Gaussian noise tensor of the size given by granularity. 177 | noise_dim = ((coords - coords_min).max(0) // granularity).astype(int) + 3 178 | noise = np.random.randn(*noise_dim, 3).astype(np.float32) 179 | 180 | # Smoothing. 181 | for _ in range(2): 182 | noise = scipy.ndimage.filters.convolve(noise, blurx, mode='constant', cval=0) 183 | noise = scipy.ndimage.filters.convolve(noise, blury, mode='constant', cval=0) 184 | noise = scipy.ndimage.filters.convolve(noise, blurz, mode='constant', cval=0) 185 | 186 | # Trilinear interpolate noise filters for each spatial dimensions. 187 | ax = [ 188 | np.linspace(d_min, d_max, d) 189 | for d_min, d_max, d in zip(coords_min - granularity, coords_min + 190 | granularity * (noise_dim - 2), noise_dim) 191 | ] 192 | interp = scipy.interpolate.RegularGridInterpolator(ax, noise, bounds_error=0, fill_value=0) 193 | coords = coords + interp(coords) * magnitude 194 | return coords 195 | 196 | def __call__(self, pointcloud): 197 | if self.distortion_params is not None: 198 | if random.random() < 0.95: 199 | for granularity, magnitude in self.distortion_params: 200 | pointcloud = self.elastic_distortion(pointcloud, granularity, magnitude) 201 | return pointcloud 202 | 203 | 204 | class Compose(object): 205 | '''Composes several transforms together.''' 206 | 207 | def __init__(self, transforms): 208 | self.transforms = transforms 209 | 210 | def __call__(self, *args): 211 | for t in self.transforms: 212 | args = t(*args) 213 | return args 214 | 215 | 216 | class cfl_collate_fn_factory: 217 | '''Generates collate function for coords, feats, labels. 218 | 219 | Args: 220 | limit_numpoints: If 0 or False, does not alter batch size. If positive integer, limits batch 221 | size so that the number of input coordinates is below limit_numpoints. 222 | ''' 223 | 224 | def __init__(self, limit_numpoints): 225 | self.limit_numpoints = limit_numpoints 226 | 227 | def __call__(self, list_data): 228 | coords, feats, labels = list(zip(*list_data)) 229 | coords_batch, feats_batch, labels_batch = [], [], [] 230 | 231 | batch_id = 0 232 | batch_num_points = 0 233 | for batch_id, _ in enumerate(coords): 234 | num_points = coords[batch_id].shape[0] 235 | batch_num_points += num_points 236 | if self.limit_numpoints and batch_num_points > self.limit_numpoints: 237 | num_full_points = sum(len(c) for c in coords) 238 | num_full_batch_size = len(coords) 239 | logging.warning( 240 | f'\t\tCannot fit {num_full_points} points into {self.limit_numpoints} points ' 241 | f'limit. Truncating batch size at {batch_id} out of {num_full_batch_size} with {batch_num_points - num_points}.' 
242 | ) 243 | break 244 | coords_batch.append( 245 | torch.cat((torch.from_numpy(coords[batch_id]).int(), 246 | torch.ones(num_points, 1).int() * batch_id), 1)) 247 | feats_batch.append(torch.from_numpy(feats[batch_id])) 248 | labels_batch.append(torch.from_numpy(labels[batch_id]).int()) 249 | 250 | batch_id += 1 251 | 252 | # Concatenate all lists 253 | coords_batch = torch.cat(coords_batch, 0).int() 254 | feats_batch = torch.cat(feats_batch, 0).float() 255 | labels_batch = torch.cat(labels_batch, 0).int() 256 | return coords_batch, feats_batch, labels_batch 257 | 258 | 259 | class cflt_collate_fn_factory: 260 | '''Generates collate function for coords, feats, labels, point_clouds, transformations. 261 | 262 | Args: 263 | limit_numpoints: If 0 or False, does not alter batch size. If positive integer, limits batch 264 | size so that the number of input coordinates is below limit_numpoints. 265 | ''' 266 | 267 | def __init__(self, limit_numpoints): 268 | self.limit_numpoints = limit_numpoints 269 | 270 | def __call__(self, list_data): 271 | coords, feats, labels, pointclouds, transformations = list(zip(*list_data)) 272 | cfl_collate_fn = cfl_collate_fn_factory(limit_numpoints=self.limit_numpoints) 273 | coords_batch, feats_batch, labels_batch = cfl_collate_fn(list(zip(coords, feats, labels))) 274 | num_truncated_batch = coords_batch[:, -1].max().item() + 1 275 | 276 | batch_id = 0 277 | pointclouds_batch, transformations_batch = [], [] 278 | for pointcloud, transformation in zip(pointclouds, transformations): 279 | if batch_id >= num_truncated_batch: 280 | break 281 | pointclouds_batch.append( 282 | torch.cat((torch.from_numpy(pointcloud), torch.ones(pointcloud.shape[0], 1) * batch_id), 283 | 1)) 284 | transformations_batch.append( 285 | torch.cat( 286 | (torch.from_numpy(transformation), torch.ones(transformation.shape[0], 1) * batch_id), 287 | 1)) 288 | batch_id += 1 289 | 290 | pointclouds_batch = torch.cat(pointclouds_batch, 0).float() 291 | transformations_batch = torch.cat(transformations_batch, 0).float() 292 | return coords_batch, feats_batch, labels_batch, pointclouds_batch, transformations_batch 293 | 294 | class RandomDropout(object): 295 | 296 | def __init__(self, dropout_ratio=0.2, dropout_application_ratio=0.5): 297 | """ 298 | upright_axis: axis index among x,y,z, i.e. 
2 for z 299 | """ 300 | self.dropout_ratio = dropout_ratio 301 | self.dropout_application_ratio = dropout_application_ratio 302 | 303 | def __call__(self, coords, feats, labels): 304 | if random.random() < self.dropout_ratio: 305 | N = len(coords) 306 | inds = np.random.choice(N, int(N * (1 - self.dropout_ratio)), replace=False) 307 | return coords[inds], feats[inds], labels[inds] 308 | return coords, feats, labels -------------------------------------------------------------------------------- /noksr/data/dataset/carla.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from torch.utils.data import Dataset 4 | 5 | import numpy as np 6 | 7 | from noksr.utils.transform import ComposedTransforms 8 | from noksr.data.dataset.carla_gt_geometry import get_class 9 | from noksr.data.dataset.general_dataset import DatasetSpec as DS 10 | from noksr.data.dataset.general_dataset import RandomSafeDataset 11 | 12 | from pycg import exp 13 | 14 | 15 | class Carla(RandomSafeDataset): 16 | def __init__(self, cfg, split): 17 | super().__init__(cfg, split) 18 | 19 | self.skip_on_error = False 20 | self.custom_name = "carla" 21 | self.cfg = cfg 22 | 23 | self.split = split # use only train set for overfitting 24 | split = self.split 25 | self.intake_start = cfg.data.intake_start 26 | self.take = cfg.data.take 27 | self.input_splats = cfg.data.input_splats 28 | 29 | self.gt_type = cfg.data.supervision.gt_type 30 | 31 | self.transforms = ComposedTransforms(cfg.data.transforms) 32 | self.use_dummy_gt = False 33 | 34 | # If drives not specified, use all sub-folders 35 | base_path = Path(cfg.data.base_path) 36 | drives = cfg.data.drives 37 | if drives is None: 38 | drives = os.listdir(base_path) 39 | drives = [c for c in drives if (base_path / c).is_dir()] 40 | self.drives = drives 41 | self.input_path = cfg.data.input_path 42 | 43 | # Get all items 44 | self.all_items = [] 45 | self.drive_base_paths = {} 46 | for c in drives: 47 | self.drive_base_paths[c] = base_path / c 48 | split_file = self.drive_base_paths[c] / (split + '.lst') 49 | with split_file.open('r') as f: 50 | models_c = f.read().split('\n') 51 | if '' in models_c: 52 | models_c.remove('') 53 | self.all_items += [{'drive': c, 'item': m} for m in models_c] 54 | 55 | if self.cfg.data.over_fitting: 56 | self.all_items = self.all_items[self.intake_start:self.take+self.intake_start] 57 | 58 | 59 | 60 | def __len__(self): 61 | return len(self.all_items) 62 | 63 | def get_name(self): 64 | return f"{self.custom_name}-cat{len(self.drives)}-{self.split}" 65 | 66 | def get_short_name(self): 67 | return self.custom_name 68 | 69 | def _get_item(self, data_id, rng): 70 | # self.num_input_points = 50000 71 | drive_name = self.all_items[data_id]['drive'] 72 | item_name = self.all_items[data_id]['item'] 73 | 74 | named_data = {} 75 | 76 | try: 77 | if self.input_path is None: 78 | input_data = np.load(self.drive_base_paths[drive_name] / item_name / 'pointcloud.npz') 79 | else: 80 | input_data = np.load(Path(self.input_path) / drive_name / item_name / 'pointcloud.npz') 81 | except FileNotFoundError: 82 | exp.logger.warning(f"File not found for AV dataset for {item_name}") 83 | raise ConnectionAbortedError 84 | 85 | named_data[DS.SHAPE_NAME] = "/".join([drive_name, item_name]) 86 | named_data[DS.INPUT_PC]= input_data['points'].astype(np.float32) 87 | named_data[DS.TARGET_NORMAL] = input_data['normals'].astype(np.float32) 88 | 89 | if self.transforms is not None: 90 | named_data = 
self.transforms(named_data, rng) 91 | 92 | point_features = np.zeros(shape=(len(named_data[DS.INPUT_PC]), 0), dtype=np.float32) 93 | if self.cfg.model.network.use_normal: 94 | point_features = np.concatenate((point_features, named_data[DS.TARGET_NORMAL]), axis=1) 95 | if self.cfg.model.network.use_xyz: 96 | point_features = np.concatenate((point_features, named_data[DS.INPUT_PC]), axis=1) # add xyz to point features 97 | 98 | xyz = named_data[DS.INPUT_PC] 99 | normals = named_data[DS.TARGET_NORMAL] 100 | 101 | 102 | geom_cls = get_class(self.gt_type) 103 | 104 | if (self.drive_base_paths[drive_name] / item_name / "groundtruth.bin").exists(): 105 | named_data[DS.GT_GEOMETRY] = geom_cls.load(self.drive_base_paths[drive_name] / item_name / "groundtruth.bin") 106 | data = { 107 | "gt_geometry": named_data[DS.GT_GEOMETRY], 108 | "xyz": xyz, # N, 3 109 | "normals": normals, # N, 3 110 | "scene_name": named_data[DS.SHAPE_NAME], 111 | "point_features": point_features, # N, K 112 | } 113 | else: 114 | data = { 115 | "all_xyz": input_data['ref_xyz'].astype(np.float32), 116 | "all_normals": input_data['ref_normals'].astype(np.float32), 117 | "xyz": xyz, # N, 3 118 | "normals": normals, # N, 3 119 | "scene_name": named_data[DS.SHAPE_NAME], 120 | "point_features": point_features, # N, K 121 | } 122 | 123 | return data 124 | -------------------------------------------------------------------------------- /noksr/data/dataset/carla_gt_geometry.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from pathlib import Path 4 | from pycg.exp import lru_cache_class, logger 5 | 6 | from pycg.isometry import Isometry 7 | import torch.nn.functional as F 8 | 9 | 10 | class AVGroundTruthGeometry: 11 | def __init__(self): 12 | pass 13 | 14 | @classmethod 15 | def load(cls, path: Path): 16 | raise NotImplementedError 17 | 18 | def save(self, path: Path): 19 | raise NotImplementedError 20 | 21 | def crop(self, bounds: np.ndarray): 22 | # bounds: (C, 2, 3) min_coords and max_coords 23 | raise NotImplementedError 24 | 25 | def transform(self, iso: Isometry = Isometry(), scale: float = 1.0): 26 | # p <- s(Rp+t) 27 | raise NotImplementedError 28 | 29 | 30 | class DensePointsGroundTruthGeometry(AVGroundTruthGeometry): 31 | def __init__(self, xyz: np.ndarray, normal: np.ndarray): 32 | super().__init__() 33 | self.xyz = xyz 34 | self.normal = normal 35 | assert self.xyz.shape[0] == self.normal.shape[0] 36 | assert self.xyz.shape[1] == self.normal.shape[1] == 3 37 | 38 | def save(self, path: Path): 39 | with path.open("wb") as f: 40 | np.savez_compressed(f, xyz=self.xyz, normal=self.normal) 41 | 42 | def transform(self, iso: Isometry = Isometry(), scale: float = 1.0): 43 | self.xyz = scale * (iso @ self.xyz) 44 | self.normal = iso.rotation @ self.normal 45 | 46 | def is_empty(self): 47 | return self.xyz.shape[0] < 64 48 | 49 | @classmethod 50 | def load(cls, path: Path): 51 | res = np.load(path, allow_pickle=True) 52 | inst = cls(res['xyz'], res['normal']) 53 | return inst 54 | 55 | @classmethod 56 | def empty(cls): 57 | return cls(np.zeros((0, 3)), np.zeros((0, 3))) 58 | 59 | @lru_cache_class(maxsize=None) 60 | def torch_attr(self): 61 | return torch.from_numpy(self.xyz).float().cuda(), torch.from_numpy(self.normal).float().cuda() 62 | 63 | def query_sdf(self, queries: torch.Tensor): 64 | import ext 65 | all_points_torch, all_normals_torch = self.torch_attr() 66 | 67 | sdf_kwargs = { 68 | 'queries': queries, 'ref_xyz': all_points_torch, 'ref_normal': 
all_normals_torch, 69 | 'nb_points': 8, 'stdv': 3.0, 'adaptive_knn': 8 70 | } 71 | try: 72 | query_sdf = -ext.sdfgen.sdf_from_points(**sdf_kwargs)[0] 73 | except MemoryError: 74 | logger.warning("Query SDF OOM. Try empty pytorch cache.") 75 | torch.cuda.empty_cache() 76 | query_sdf = -ext.sdfgen.sdf_from_points(**sdf_kwargs)[0] 77 | 78 | return query_sdf 79 | 80 | def crop(self, bounds: np.ndarray): 81 | crops = [] 82 | for cur_bound in bounds: 83 | min_bound, max_bound = cur_bound[0], cur_bound[1] 84 | crop_mask = np.logical_and.reduce([ 85 | self.xyz[:, 0] > min_bound[0], self.xyz[:, 0] < max_bound[0], 86 | self.xyz[:, 1] > min_bound[1], self.xyz[:, 1] < max_bound[1], 87 | self.xyz[:, 2] > min_bound[2], self.xyz[:, 2] < max_bound[2] 88 | ]) 89 | crop_inst = self.__class__(self.xyz[crop_mask], self.normal[crop_mask]) 90 | crops.append(crop_inst) 91 | return crops 92 | 93 | 94 | class PointTSDFVolumeGroundTruthGeometry(AVGroundTruthGeometry): 95 | def __init__(self, dense_points: DensePointsGroundTruthGeometry, 96 | volume: np.ndarray, volume_min: np.ndarray, volume_max: np.ndarray): 97 | super().__init__() 98 | self.dense_points = dense_points 99 | self.volume = volume 100 | self.volume_min = volume_min 101 | self.volume_max = volume_max 102 | assert np.all(self.volume_min < self.volume_max) 103 | 104 | @property 105 | def xyz(self): 106 | return self.dense_points.xyz 107 | 108 | @property 109 | def normal(self): 110 | return self.dense_points.normal 111 | 112 | @classmethod 113 | def empty(cls): 114 | return cls(DensePointsGroundTruthGeometry.empty(), np.ones((1, 1, 1)), np.zeros(3,), np.ones(3,)) 115 | 116 | def is_empty(self): 117 | return self.dense_points.is_empty() 118 | 119 | def save(self, path: Path): 120 | with path.open("wb") as f: 121 | np.savez_compressed(f, xyz=self.dense_points.xyz, normal=self.dense_points.normal, 122 | volume=self.volume, 123 | volume_min=self.volume_min, volume_max=self.volume_max) 124 | 125 | def transform(self, iso: Isometry = Isometry(), scale: float = 1.0): 126 | assert iso.q.is_unit(), "Volume transform does not support rotation yet" 127 | self.dense_points.transform(iso, scale) 128 | self.volume_min = scale * (self.volume_min + iso.t) 129 | self.volume_max = scale * (self.volume_max + iso.t) 130 | 131 | @classmethod 132 | def load(cls, path: Path): 133 | dense_points = DensePointsGroundTruthGeometry.load(path) 134 | res = np.load(path) 135 | return cls(dense_points, res['volume'], res['volume_min'], res['volume_max']) 136 | 137 | @lru_cache_class(maxsize=None) 138 | def torch_attr(self): 139 | return *self.dense_points.torch_attr(), torch.from_numpy(self.volume).float().cuda() 140 | 141 | def query_classification(self, queries: torch.Tensor, band: float = 1.0): 142 | """ 143 | Return integer classifications of the query points: 144 | 0 - near surface 145 | 1 - far surface empty 146 | 2 - unknown (also for query points outside volume) 147 | :param queries: torch.Tensor (N, 3) 148 | :param band: 0-1 band size to be classified as 'near-surface' 149 | :return: (N, ) ids 150 | """ 151 | _, _, volume_input = self.torch_attr() 152 | 153 | in_volume_mask = (queries[:, 0] >= self.volume_min[0]) & (queries[:, 0] <= self.volume_max[0]) & \ 154 | (queries[:, 1] >= self.volume_min[1]) & (queries[:, 1] <= self.volume_max[1]) & \ 155 | (queries[:, 2] >= self.volume_min[2]) & (queries[:, 2] <= self.volume_max[2]) 156 | 157 | queries_norm = queries[in_volume_mask].clone() 158 | for i in range(3): 159 | queries_norm[:, i] = (queries_norm[:, i] - self.volume_min[i]) 
/ \ 160 | (self.volume_max[i] - self.volume_min[i]) * 2. - 1. 161 | sample_grid = torch.fliplr(queries_norm)[None, None, None, ...] 162 | # B=1,C=1,Di=1,Hi=1,Wi x B=1,Do=1,Ho=1,Wo,3 --> B=1,C=1,Do=1,Ho=1,Wo 163 | sample_res = F.grid_sample(volume_input[None, None, ...], sample_grid, 164 | mode='nearest', padding_mode='border', align_corners=True)[0, 0, 0, 0] 165 | 166 | cls_in_volume = torch.ones_like(sample_res, dtype=torch.long) 167 | cls_in_volume[~torch.isfinite(sample_res)] = 2 168 | cls_in_volume[torch.abs(sample_res) < band] = 0 169 | 170 | cls = torch.ones(queries.size(0), dtype=torch.long, device=cls_in_volume.device) * 2 171 | cls[in_volume_mask] = cls_in_volume 172 | 173 | return cls 174 | 175 | def query_sdf(self, queries: torch.Tensor): 176 | return self.dense_points.query_sdf(queries) 177 | 178 | def crop(self, bounds: np.ndarray): 179 | point_crops = self.dense_points.crop(bounds) 180 | 181 | volume_x_ticks = np.linspace(self.volume_min[0], self.volume_max[0], self.volume.shape[0]) 182 | volume_y_ticks = np.linspace(self.volume_min[1], self.volume_max[1], self.volume.shape[1]) 183 | volume_z_ticks = np.linspace(self.volume_min[2], self.volume_max[2], self.volume.shape[2]) 184 | 185 | crops = [] 186 | for cur_point_crop, cur_bound in zip(point_crops, bounds): 187 | min_bound, max_bound = cur_bound[0], cur_bound[1] 188 | # volume_ticks[id_min] <= min_bound < max_bound <= volume_ticks[id_max] 189 | x_id_min = np.maximum(np.searchsorted(volume_x_ticks, min_bound[0], side='right') - 1, 0) 190 | x_id_max = np.minimum(np.searchsorted(volume_x_ticks, max_bound[0], side='left'), 191 | volume_x_ticks.shape[0] - 1) 192 | y_id_min = np.maximum(np.searchsorted(volume_y_ticks, min_bound[1], side='right') - 1, 0) 193 | y_id_max = np.minimum(np.searchsorted(volume_y_ticks, max_bound[1], side='left'), 194 | volume_y_ticks.shape[0] - 1) 195 | z_id_min = np.maximum(np.searchsorted(volume_z_ticks, min_bound[2], side='right') - 1, 0) 196 | z_id_max = np.minimum(np.searchsorted(volume_z_ticks, max_bound[2], side='left'), 197 | volume_z_ticks.shape[0] - 1) 198 | crops.append(self.__class__( 199 | cur_point_crop, 200 | self.volume[x_id_min:x_id_max+1, y_id_min:y_id_max+1, z_id_min:z_id_max+1], 201 | np.array([volume_x_ticks[x_id_min], volume_y_ticks[y_id_min], volume_z_ticks[z_id_min]]), 202 | np.array([volume_x_ticks[x_id_max], volume_y_ticks[y_id_max], volume_z_ticks[z_id_max]]) 203 | )) 204 | return crops 205 | 206 | 207 | def get_class(class_name): 208 | if class_name == "DensePoints": 209 | return DensePointsGroundTruthGeometry 210 | elif class_name == "PointTSDFVolume": 211 | return PointTSDFVolumeGroundTruthGeometry 212 | else: 213 | raise NotImplementedError 214 | -------------------------------------------------------------------------------- /noksr/data/dataset/general_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset 3 | import open3d as o3d 4 | from enum import Enum 5 | from numpy.random import RandomState 6 | import multiprocessing 7 | from omegaconf import DictConfig, ListConfig 8 | from pycg import exp 9 | 10 | 11 | class GeneralDataset(Dataset): 12 | """ Only used for Carla dataset """ 13 | # Augmentation arguments 14 | SCALE_AUGMENTATION_BOUND = (0.9, 1.1) 15 | ROTATION_AUGMENTATION_BOUND = ((-np.pi / 64, np.pi / 64), (-np.pi / 64, np.pi / 64), (-np.pi, 16 | np.pi)) 17 | TRANSLATION_AUGMENTATION_RATIO_BOUND = ((-0.2, 0.2), (-0.2, 0.2), (0, 0)) 18 | ELASTIC_DISTORT_PARAMS = 
((0.2, 0.4), (0.8, 1.6)) 19 | 20 | ROTATION_AXIS = 'z' 21 | LOCFEAT_IDX = 2 22 | 23 | def __init__(self, cfg, split): 24 | pass 25 | 26 | class DatasetSpec(Enum): 27 | SCENE_NAME = 100 28 | SHAPE_NAME = 0 29 | INPUT_PC = 200 30 | TARGET_NORMAL = 300 31 | INPUT_COLOR = 400 32 | INPUT_SENSOR_POS = 500 33 | GT_DENSE_PC = 600 34 | GT_DENSE_NORMAL = 700 35 | GT_DENSE_COLOR = 800 36 | GT_MESH = 900 37 | GT_MESH_SOUP = 1000 38 | GT_ONET_SAMPLE = 1100 39 | GT_GEOMETRY = 1200 40 | DATASET_CFG = 1300 41 | 42 | class RandomSafeDataset(Dataset): 43 | """ 44 | A dataset class that provides a deterministic random seed. 45 | However, in order to have consistent validation set, we need to set is_val=True for validation/test sets. 46 | Usage: First, inherent this class. 47 | Then, at the beginning of your get_item call, get an rng; 48 | Last, use this rng as the random state for your program. 49 | """ 50 | 51 | def __init__(self, cfg, split): 52 | self._seed = cfg.global_train_seed 53 | self._is_val = split in ['val', 'test'] 54 | self.skip_on_error = False 55 | if not self._is_val: 56 | self._manager = multiprocessing.Manager() 57 | self._read_count = self._manager.dict() 58 | self._rc_lock = multiprocessing.Lock() 59 | 60 | def get_rng(self, idx): 61 | if self._is_val: 62 | return RandomState(self._seed) 63 | with self._rc_lock: 64 | if idx not in self._read_count: 65 | self._read_count[idx] = 0 66 | rng = RandomState(exp.deterministic_hash((idx, self._read_count[idx], self._seed))) 67 | self._read_count[idx] += 1 68 | return rng 69 | 70 | def sanitize_specs(self, old_spec, available_spec): 71 | old_spec = set(old_spec) 72 | available_spec = set(available_spec) 73 | for os in old_spec: 74 | assert isinstance(os, DatasetSpec) 75 | new_spec = old_spec.intersection(available_spec) 76 | # lack_spec = old_spec.difference(new_spec) 77 | # if len(lack_spec) > 0: 78 | # exp.logger.warning(f"Lack spec {lack_spec}.") 79 | return new_spec 80 | 81 | def _get_item(self, data_id, rng): 82 | raise NotImplementedError 83 | 84 | def __getitem__(self, data_id): 85 | rng = self.get_rng(data_id) 86 | if self.skip_on_error: 87 | try: 88 | return self._get_item(data_id, rng) 89 | except ConnectionAbortedError: 90 | return self.__getitem__(rng.randint(0, len(self) - 1)) 91 | except Exception: 92 | # Just return a random other item. 
93 | exp.logger.warning(f"Get item {data_id} error, but handled.") 94 | return self.__getitem__(rng.randint(0, len(self) - 1)) 95 | else: 96 | try: 97 | return self._get_item(data_id, rng) 98 | except ConnectionAbortedError: 99 | return self.__getitem__(rng.randint(0, len(self) - 1)) 100 | 101 | -------------------------------------------------------------------------------- /noksr/data/dataset/mixture.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset 5 | from noksr.data.dataset.scannet import Scannet 6 | from noksr.data.dataset.synthetic import Synthetic 7 | 8 | class Mixture(Dataset): 9 | def __init__(self, cfg, split): 10 | 11 | self.scannet_dataset = Scannet(cfg, split) 12 | self.synthetic_dataset = Synthetic(cfg, split) 13 | self.cfg = cfg 14 | self.split = split 15 | self.over_fitting = cfg.data.over_fitting 16 | 17 | # Combine the lengths of both datasets 18 | self.length = len(self.scannet_dataset) + len(self.synthetic_dataset) 19 | 20 | def __len__(self): 21 | if self.split == 'val': 22 | if self.cfg.data.validation_set == "ScanNet": 23 | return len(self.scannet_dataset) 24 | elif self.cfg.data.validation_set == "Synthetic": 25 | return len(self.synthetic_dataset) 26 | 27 | return self.length 28 | 29 | def __getitem__(self, idx): 30 | # Determine which dataset to load from based on idx 31 | if self.split == 'val': 32 | if self.cfg.data.validation_set == "ScanNet": 33 | return self.scannet_dataset[idx] 34 | elif self.cfg.data.validation_set == "Synthetic": 35 | return self.synthetic_dataset[idx] 36 | 37 | if idx < len(self.scannet_dataset): 38 | return self.scannet_dataset[idx] 39 | else: 40 | return self.synthetic_dataset[idx - len(self.scannet_dataset)] 41 | -------------------------------------------------------------------------------- /noksr/data/dataset/scannet.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | from noksr.utils.serialization import encode 7 | import open3d as o3d 8 | 9 | class Scannet(Dataset): 10 | def __init__(self, cfg, split): 11 | self.cfg = cfg 12 | self.split = 'val' if self.cfg.data.over_fitting else split # use only val set for overfitting 13 | if 'ScanNet' in cfg.data: # subset of mixture data 14 | self.dataset_root_path = cfg.data.ScanNet.dataset_path 15 | self.dataset_path = cfg.data.ScanNet.dataset_path 16 | else: 17 | self.dataset_root_path = cfg.data.dataset_path 18 | self.dataset_path = cfg.data.dataset_path 19 | 20 | self.metadata = cfg.data.metadata 21 | self.num_input_points = cfg.data.num_input_points 22 | self.take = cfg.data.take 23 | self.intake_start = cfg.data.intake_start 24 | self.uniform_sampling = cfg.data.uniform_sampling 25 | self.input_splats = cfg.data.input_splats 26 | self.std_dev = cfg.data.std_dev 27 | 28 | self.in_memory = cfg.data.in_memory 29 | self.dataset_split = "test" if split == "test" else "train" # train and val scenes and all under train set 30 | self.data_map = { 31 | "train": self.metadata.train_list, 32 | "val": self.metadata.val_list, 33 | "test": self.metadata.test_list 34 | } 35 | self._load_from_disk() 36 | 37 | def _load_from_disk(self): 38 | with open(getattr(self.metadata, f"{self.split}_list")) as f: 39 | self.scene_names = [line.strip() for line in f] 40 | 41 | self.scenes = [] 42 | if self.cfg.data.over_fitting: 43 
| self.scene_names = self.scene_names[self.intake_start:self.take+self.intake_start] 44 | if len(self.scene_names) == 1: # if only one scene is taken, overfit on scene 0221_00 45 | self.scene_names = ['scene0221_00'] 46 | for scene_name in tqdm(self.scene_names, desc=f"Loading {self.split} data from disk"): 47 | scene_path = os.path.join(self.dataset_path, self.split, f"{scene_name}.pth") 48 | scene = torch.load(scene_path) 49 | scene["xyz"] = scene["xyz"].astype(np.float32) 50 | scene["rgb"] = scene["rgb"].astype(np.float32) 51 | scene['scene_name'] = scene_name 52 | point_features = np.zeros(shape=(len(scene['xyz']), 0), dtype=np.float32) 53 | if self.cfg.model.network.use_color: 54 | point_features = np.concatenate((point_features, scene['rgb']), axis=1) 55 | if self.cfg.model.network.use_normal: 56 | point_features = np.concatenate((point_features, scene['normal']), axis=1) 57 | if self.cfg.model.network.use_xyz: 58 | point_features = np.concatenate((point_features, scene['xyz']), axis=1) # add xyz to point features 59 | scene['point_features'] = point_features 60 | self.scenes.append(scene) 61 | 62 | def __len__(self): 63 | return len(self.scenes) 64 | 65 | def __getitem__(self, idx): 66 | scene = self.scenes[idx] 67 | all_xyz = scene['xyz'] 68 | all_normals = scene['normal'] 69 | scene_name = scene['scene_name'] 70 | 71 | # sample input points 72 | num_points = scene["xyz"].shape[0] 73 | num_input_points = self.num_input_points 74 | if num_input_points == -1: 75 | xyz = scene["xyz"] 76 | point_features = scene['point_features'] 77 | else: 78 | if not self.uniform_sampling: 79 | # Number of blocks along each axis 80 | num_blocks = 2 81 | total_blocks = num_blocks ** 3 82 | self.common_difference = 200 83 | # Calculate block sizes 84 | block_sizes = (all_xyz.max(axis=0) - all_xyz.min(axis=0)) / num_blocks 85 | 86 | # Create the number_per_block array with an arithmetic sequence 87 | average_points_per_block = num_input_points // total_blocks 88 | number_per_block = np.array([ 89 | average_points_per_block + (i - total_blocks // 2) * self.common_difference 90 | for i in range(total_blocks) 91 | ]) 92 | 93 | # Adjust number_per_block to ensure the sum is num_input_points 94 | total_points = np.sum(number_per_block) 95 | difference = num_input_points - total_points 96 | number_per_block[-1] += difference 97 | 98 | # Sample points from each block 99 | sample_indices = [] 100 | block_index = 0 101 | total_chosen_indices = 0 102 | remaining_points = 0 # Points to be added to the next block 103 | for i in range(num_blocks): 104 | for j in range(num_blocks): 105 | for k in range(num_blocks): 106 | block_min = all_xyz.min(axis=0) + block_sizes * np.array([i, j, k]) 107 | block_max = block_min + block_sizes 108 | block_mask = np.all((all_xyz >= block_min) & (all_xyz < block_max), axis=1) 109 | block_indices = np.where(block_mask)[0] 110 | num_samples = number_per_block[block_index] + remaining_points 111 | remaining_points = 0 # Reset remaining points 112 | block_index += 1 113 | if len(block_indices) > 0: 114 | chosen_indices = np.random.choice(block_indices, num_samples, replace=True) 115 | sample_indices.extend(chosen_indices) 116 | total_chosen_indices += len(chosen_indices) 117 | # print(f"Block {block_index} - Desired: {num_samples}, Actual: {len(chosen_indices)}") 118 | if len(chosen_indices) < num_samples: 119 | remaining_points += (num_samples - len(chosen_indices)) 120 | else: 121 | # print(f"Block {block_index} - No points available. 
Adding {num_samples} points to the next block.") 122 | remaining_points += num_samples 123 | 124 | else: 125 | if num_points < num_input_points: 126 | print(f"Scene {scene_name} has less than {num_input_points} points. Sampling with replacement.") 127 | sample_indices = np.random.choice(num_points, num_input_points, replace=True) 128 | else: 129 | sample_indices = np.random.choice(num_points, num_input_points, replace=True) 130 | 131 | xyz = scene["xyz"][sample_indices] 132 | point_features = scene['point_features'][sample_indices] 133 | noise = np.random.normal(0, self.std_dev, xyz.shape) 134 | xyz += noise 135 | 136 | data = { 137 | "all_xyz": all_xyz, 138 | "all_normals": all_normals, 139 | "xyz": xyz, # N, 3 140 | "point_features": point_features, # N, 3 141 | "scene_name": scene['scene_name'] 142 | } 143 | 144 | return data 145 | 146 | -------------------------------------------------------------------------------- /noksr/data/dataset/scenenn.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | from plyfile import PlyData 6 | 7 | class SceneNN(Dataset): 8 | def __init__(self, cfg, split): 9 | self.cfg = cfg 10 | self.split = split 11 | self.dataset_root_path = cfg.data.dataset_path 12 | self.voxel_size = cfg.data.voxel_size 13 | self.num_input_points = cfg.data.num_input_points 14 | self.std_dev = cfg.data.std_dev 15 | self.intake_start = cfg.data.intake_start 16 | self.take = cfg.data.take 17 | self.input_splats = cfg.data.input_splats 18 | 19 | self.in_memory = cfg.data.in_memory 20 | self.train_files = cfg.data.train_files 21 | self.test_files = cfg.data.test_files 22 | if self.split == 'test': 23 | scene_ids = self.test_files 24 | else: 25 | scene_ids = self.train_files + self.test_files 26 | self.filenames = sorted([os.path.join(self.dataset_root_path, scene_id, f) 27 | for scene_id in scene_ids 28 | for f in os.listdir(os.path.join(self.dataset_root_path, scene_id)) 29 | if f.endswith('.ply')]) 30 | if self.cfg.data.over_fitting: 31 | self.filenames = self.filenames[self.intake_start:self.take+self.intake_start] 32 | 33 | def __len__(self): 34 | return len(self.filenames) 35 | 36 | def __getitem__(self, idx): 37 | """Get item.""" 38 | # load the mesh 39 | scene_filename = self.filenames[idx] 40 | 41 | ply_data = PlyData.read(scene_filename) 42 | vertex = ply_data['vertex'] 43 | pos = np.stack([vertex[t] for t in ('x', 'y', 'z')], axis=1) 44 | nls = np.stack([vertex[t] for t in ('nx', 'ny', 'nz')], axis=1) if 'nx' in vertex and 'ny' in vertex and 'nz' in vertex else np.zeros_like(pos) 45 | 46 | all_xyz = pos 47 | all_normals = nls 48 | scene_name = os.path.basename(scene_filename).replace('.ply', '') 49 | 50 | all_point_features = np.zeros(shape=(len(all_xyz), 0), dtype=np.float32) 51 | if self.cfg.model.network.use_normal: 52 | all_point_features = np.concatenate((all_point_features, all_normals), axis=1) 53 | if self.cfg.model.network.use_xyz: 54 | all_point_features = np.concatenate((all_point_features, all_xyz), axis=1) # add xyz to point features 55 | 56 | # sample input points 57 | num_points = all_xyz.shape[0] 58 | if self.num_input_points == -1: 59 | xyz = all_xyz 60 | point_features = all_point_features 61 | normals = all_normals 62 | else: 63 | sample_indices = np.random.choice(num_points, self.num_input_points, replace=True) 64 | xyz = all_xyz[sample_indices] 65 | point_features = all_point_features[sample_indices] 66 | normals = 
all_normals[sample_indices] 67 | 68 | noise = np.random.normal(0, self.std_dev, xyz.shape) 69 | xyz += noise 70 | 71 | data = { 72 | "all_xyz": all_xyz, 73 | "all_normals": all_normals, 74 | "xyz": xyz, # N, 3 75 | "normals": normals, # N, 3 76 | "point_features": point_features, # N, 3 77 | "scene_name": scene_name 78 | } 79 | 80 | return data 81 | -------------------------------------------------------------------------------- /noksr/data/dataset/synthetic.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | from statistics import mode 4 | import numpy as np 5 | 6 | from torch.utils.data import Dataset 7 | from plyfile import PlyData 8 | 9 | 10 | class Synthetic(Dataset): 11 | def __init__(self, cfg, split): 12 | # Attributes 13 | if 'Synthetic' in cfg.data: # subset of mixture data 14 | categories = cfg.data.Synthetic.classes 15 | self.dataset_folder = cfg.data.Synthetic.path 16 | self.multi_files = cfg.data.Synthetic.multi_files 17 | self.file_name = cfg.data.Synthetic.pointcloud_file 18 | else: 19 | categories = cfg.data.classes 20 | self.dataset_folder = cfg.data.path 21 | self.multi_files = cfg.data.multi_files 22 | self.file_name = cfg.data.pointcloud_file 23 | self.scale = 2.2 # Emperical scale to transfer back to physical scale 24 | self.cfg = cfg 25 | self.split = split 26 | self.std_dev = cfg.data.std_dev * self.scale 27 | self.num_input_points = cfg.data.num_input_points 28 | 29 | self.no_except = True 30 | 31 | self.intake_start = cfg.data.intake_start 32 | self.take = cfg.data.take 33 | # If categories is None, use all subfolders 34 | if categories is None: 35 | categories = os.listdir(self.dataset_folder) 36 | categories = [c for c in categories 37 | if os.path.isdir(os.path.join(self.dataset_folder, c))] 38 | 39 | self.metadata = { 40 | c: {'id': c, 'name': 'n/a'} for c in categories 41 | } 42 | 43 | # Set index 44 | for c_idx, c in enumerate(categories): 45 | self.metadata[c]['idx'] = c_idx 46 | 47 | # Get all models 48 | self.models = [] 49 | for c_idx, c in enumerate(categories): 50 | subpath = os.path.join(self.dataset_folder, c) 51 | if not os.path.isdir(subpath): 52 | print('Category %s does not exist in dataset.' % c) 53 | 54 | if self.split is None: 55 | self.models += [ 56 | {'category': c, 'model': m} for m in [d for d in os.listdir(subpath) if (os.path.isdir(os.path.join(subpath, d)) and d != '') ] 57 | ] 58 | 59 | else: 60 | split_file = os.path.join(subpath, self.split + '.lst') 61 | with open(split_file, 'r') as f: 62 | models_c = f.read().split('\n') 63 | 64 | if '' in models_c: 65 | models_c.remove('') 66 | 67 | self.models += [ 68 | {'category': c, 'model': m} 69 | for m in models_c 70 | ] 71 | 72 | # overfit in one data 73 | if self.cfg.data.over_fitting: 74 | self.models = self.models[self.intake_start:self.take+self.intake_start] 75 | 76 | 77 | def __len__(self): 78 | ''' Returns the length of the dataset. 79 | ''' 80 | return len(self.models) 81 | 82 | def load(self, model_path, idx, vol): 83 | ''' Loads the data point. 
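Reads points, normals and semantics from the (optionally multi-file) pointcloud archive, rescales the points by self.scale, and shifts them so their minimum coordinate lies at the origin.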
84 | 85 | Args: 86 | model_path (str): path to model 87 | idx (int): ID of data point 88 | vol (dict): precomputed volume info 89 | ''' 90 | if self.multi_files is None: 91 | file_path = os.path.join(model_path, self.file_name) 92 | else: 93 | num = np.random.randint(self.multi_files) 94 | file_path = os.path.join(model_path, self.file_name, '%s_%02d.npz' % (self.file_name, num)) 95 | 96 | item_path = os.path.join(model_path, 'item_dict.npz') 97 | item_dict = np.load(item_path, allow_pickle=True) 98 | points_dict = np.load(file_path, allow_pickle=True) 99 | points = points_dict['points'] * self.scale # roughly transfer back to physical scale 100 | normals = points_dict['normals'] 101 | semantics = points_dict['semantics'] 102 | # Break symmetry if given in float16: 103 | if points.dtype == np.float16: 104 | points = points.astype(np.float32) 105 | normals = normals.astype(np.float32) 106 | points += 1e-4 * np.random.randn(*points.shape) 107 | normals += 1e-4 * np.random.randn(*normals.shape) 108 | 109 | min_values = np.min(points, axis=0) 110 | points -= min_values 111 | 112 | return points, normals, semantics 113 | 114 | def __getitem__(self, idx): 115 | ''' Returns an item of the dataset. 116 | 117 | Args: 118 | idx (int): ID of data point 119 | ''' 120 | category = self.models[idx]['category'] 121 | model = self.models[idx]['model'] 122 | c_idx = self.metadata[category]['idx'] 123 | 124 | model_path = os.path.join(self.dataset_folder, category, model) 125 | 126 | all_xyz, all_normals, all_semantics = self.load(model_path, idx, c_idx) 127 | scene_name = f"{category}/{model}/{idx}" 128 | 129 | all_point_features = np.zeros(shape=(len(all_xyz), 0), dtype=np.float32) 130 | if self.cfg.model.network.use_normal: 131 | all_point_features = np.concatenate((all_point_features, all_normals), axis=1) 132 | if self.cfg.model.network.use_xyz: 133 | all_point_features = np.concatenate((all_point_features, all_xyz), axis=1) # add xyz to point features 134 | # sample input points 135 | num_points = all_xyz.shape[0] 136 | if self.num_input_points == -1: 137 | xyz = all_xyz 138 | point_features = all_point_features 139 | normals = all_normals 140 | else: 141 | sample_indices = np.random.choice(num_points, self.num_input_points, replace=True) 142 | xyz = all_xyz[sample_indices] 143 | point_features = all_point_features[sample_indices] 144 | normals = all_normals[sample_indices] 145 | 146 | noise = np.random.normal(0, self.std_dev, xyz.shape) 147 | xyz += noise 148 | 149 | data = { 150 | "all_xyz": all_xyz, 151 | "all_normals": all_normals, 152 | "xyz": xyz, # N, 3 153 | "normals": normals, # N, 3 154 | "point_features": point_features, # N, 3 155 | "scene_name": scene_name 156 | } 157 | return data 158 | -------------------------------------------------------------------------------- /noksr/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .noksr_net import noksr 2 | 3 | 4 | -------------------------------------------------------------------------------- /noksr/model/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/model/__pycache__/general_model.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/__pycache__/general_model.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/model/__pycache__/noksr_net.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/__pycache__/noksr_net.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/model/__pycache__/pcs4esr_net.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/__pycache__/pcs4esr_net.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/model/general_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import open3d as o3d 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | import pl_bolts 7 | from collections import OrderedDict 8 | from typing import Mapping, Any, Optional 9 | from noksr.utils.optimizer import cosine_lr_decay 10 | # from noksr.model.module import Generator 11 | from torch.nn import functional as F 12 | from pycg import exp, image 13 | 14 | 15 | class GeneralModel(pl.LightningModule): 16 | def __init__(self, cfg): 17 | super().__init__() 18 | self.save_hyperparameters() 19 | self.val_test_step_outputs = [] 20 | # For recording test information 21 | # step -> log_name -> log_value (list of ordered-dict) 22 | self.test_logged_values = [] 23 | self.record_folder = None 24 | self.record_headers = [] 25 | self.record_data_cache = {} 26 | self.last_test_valid = False 27 | 28 | def configure_optimizers(self): 29 | params_to_optimize = self.parameters() 30 | 31 | if self.hparams.model.optimizer.name == "SGD": 32 | optimizer = torch.optim.SGD( 33 | params_to_optimize, 34 | lr=self.hparams.model.optimizer.lr, 35 | momentum=0.9, 36 | weight_decay=1e-4, 37 | ) 38 | scheduler = pl_bolts.optimizers.LinearWarmupCosineAnnealingLR( 39 | optimizer, 40 | warmup_epochs=int(self.hparams.model.optimizer.warmup_steps_ratio * self.hparams.model.trainer.max_steps), 41 | max_epochs=self.hparams.model.trainer.max_steps, 42 | eta_min=0, 43 | ) 44 | return { 45 | "optimizer": optimizer, 46 | "lr_scheduler": { 47 | "scheduler": scheduler, 48 | "interval": "step" 49 | } 50 | } 51 | 52 | elif self.hparams.model.optimizer.name == 'Adam': 53 | optimizer = torch.optim.Adam( 54 | params_to_optimize, 55 | lr=self.hparams.model.optimizer.lr, 56 | betas=(0.9, 0.999), 57 | weight_decay=1e-4, 58 | ) 59 | return optimizer 60 | 61 | else: 62 | logging.error('Optimizer type not supported') 63 | 64 | def training_step(self, data_dict): 65 | pass 66 | 67 | def on_train_epoch_end(self): 68 | if self.hparams.model.optimizer.name == 'Adam': 69 | # Update the learning rates for Adam optimizers 70 | cosine_lr_decay( 71 | self.trainer.optimizers[0], self.hparams.model.optimizer.lr, self.current_epoch, 72 | self.hparams.model.lr_decay.decay_start_epoch, self.hparams.model.trainer.max_epochs, 1e-6 73 | ) 74 | 75 | def validation_step(self, data_dict, idx): 76 | pass 77 | 78 | def validation_epoch_end(self, outputs): 79 | metrics_to_log = ['chamfer-L1', 'f-score', 'f-score-20'] 80 | if outputs: 81 | avg_metrics = {metric: np.mean([x[metric] 
for x in outputs]) for metric in metrics_to_log if metric in outputs[0]} 82 | for key, value in avg_metrics.items(): 83 | self.log(f"val_reconstruction/{key}", value, logger=True) 84 | 85 | 86 | 87 | def test_step(self, data_dict, idx): 88 | pass 89 | 90 | 91 | def log_dict_prefix( 92 | self, 93 | prefix: str, 94 | dictionary: Mapping[str, Any], 95 | prog_bar: bool = False, 96 | logger: bool = True, 97 | on_step: Optional[bool] = None, 98 | on_epoch: Optional[bool] = None 99 | ): 100 | """ 101 | This overrides fixes if dict key is not a string... 102 | """ 103 | dictionary = { 104 | prefix + "/" + str(k): v for k, v in dictionary.items() 105 | } 106 | self.log_dict(dictionary=dictionary, 107 | prog_bar=prog_bar, 108 | logger=logger, on_step=on_step, on_epoch=on_epoch) 109 | 110 | def log_image(self, name: str, img: np.ndarray): 111 | if self.trainer.logger is not None: 112 | self.trainer.logger.log_image(key=name, images=[img]) 113 | 114 | 115 | def log_geometry(self, name: str, geom, draw_color: bool = False): 116 | if self.trainer.logger is None: 117 | return 118 | if isinstance(geom, o3d.geometry.TriangleMesh): 119 | try: 120 | from pycg import render 121 | mv_img = render.multiview_image( 122 | [geom], viewport_shading='LIT' if draw_color else 'NORMAL', backend='filament') 123 | # mv_img = render.multiview_image( 124 | # [geom], viewport_shading='LIT' if draw_color else 'NORMAL', backend='opengl') 125 | self.log_image("mesh" + name, mv_img) 126 | except Exception: 127 | exp.logger.warning("Not able to render mesh during training.") 128 | else: 129 | raise NotImplementedError 130 | 131 | def test_log_data(self, data_dict: dict): 132 | self.record_data_cache.update(data_dict) -------------------------------------------------------------------------------- /noksr/model/module/__init__.py: -------------------------------------------------------------------------------- 1 | from .decoder import Decoder 2 | from .generation import Generator 3 | from .point_transformer import PointTransformerV3 4 | -------------------------------------------------------------------------------- /noksr/model/module/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/model/module/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /noksr/model/module/__pycache__/backbone.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/backbone.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/model/module/__pycache__/backbone.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/backbone.cpython-38.pyc 
-------------------------------------------------------------------------------- /noksr/model/module/__pycache__/backbone_nocs.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/backbone_nocs.cpython-38.pyc -------------------------------------------------------------------------------- /noksr/model/module/__pycache__/common.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/common.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/model/module/__pycache__/common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/common.cpython-38.pyc -------------------------------------------------------------------------------- /noksr/model/module/__pycache__/decoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/decoder.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/model/module/__pycache__/decoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/decoder.cpython-38.pyc -------------------------------------------------------------------------------- /noksr/model/module/__pycache__/decoder2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/decoder2.cpython-38.pyc -------------------------------------------------------------------------------- /noksr/model/module/__pycache__/encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/encoder.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/model/module/__pycache__/encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/encoder.cpython-38.pyc -------------------------------------------------------------------------------- /noksr/model/module/__pycache__/generation.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/generation.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/model/module/__pycache__/generation.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/generation.cpython-38.pyc -------------------------------------------------------------------------------- /noksr/model/module/__pycache__/kp_decoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/kp_decoder.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/model/module/__pycache__/larger_decoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/larger_decoder.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/model/module/__pycache__/point_transformer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/point_transformer.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/model/module/__pycache__/tiny_unet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/tiny_unet.cpython-38.pyc -------------------------------------------------------------------------------- /noksr/model/module/__pycache__/visualization.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/visualization.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/model/module/__pycache__/visualization.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/model/module/__pycache__/visualization.cpython-38.pyc -------------------------------------------------------------------------------- /noksr/model/module/generation.py: -------------------------------------------------------------------------------- 1 | from pyexpat import features 2 | import time 3 | 4 | import torch 5 | from pycg import vis 6 | from torch import Tensor, nn 7 | from torch.nn import functional as F 8 | from tqdm import tqdm 9 | from typing import Callable, Tuple 10 | from sklearn.neighbors import NearestNeighbors 11 | from pytorch3d.ops import knn_points 12 | import open3d as o3d 13 | import pytorch_lightning as pl 14 | from noksr.utils.samples import BatchedSampler 15 | from noksr.utils.serialization import encode 16 | 17 | 18 | class MeshingResult: 19 | def __init__(self, v: torch.Tensor = None, f: torch.Tensor = None, c: torch.Tensor = None): 20 | self.v = v 21 | self.f = f 22 | self.c = c 23 | 24 | class Generator(pl.LightningModule): 25 | def __init__(self, model, mask_model, voxel_size, k_neighbors, last_n_layers, reconstruction_cfg): 26 | super().__init__() 27 | self.model = model # the model should be the UDF Decoder 28 | self.mask_model = mask_model # the distance 
mask decoder 29 | self.rec_cfg = reconstruction_cfg 30 | self.voxel_size = voxel_size 31 | self.threshold = 0.4 32 | self.k_neighbors = k_neighbors 33 | self.last_n_layers = last_n_layers 34 | 35 | 36 | def compute_gt_sdf_from_pts(self, gt_xyz, gt_normals, query_pos: torch.Tensor): 37 | k = 8 38 | stdv = 0.02 39 | knn_output = knn_points(query_pos.unsqueeze(0).to(torch.device("cuda")), gt_xyz.unsqueeze(0).to(torch.device("cuda")), K=k) 40 | indices = knn_output.idx.squeeze(0) 41 | indices = torch.tensor(indices, device=query_pos.device) 42 | closest_points = gt_xyz[indices] 43 | surface_to_queries_vec = query_pos.unsqueeze(1) - closest_points #N, K, 3 44 | 45 | dot_products = torch.einsum("ijk,ijk->ij", surface_to_queries_vec, gt_normals[indices]) #N, K 46 | vec_lengths = torch.norm(surface_to_queries_vec[:, 0, :], dim=-1) 47 | use_dot_product = vec_lengths < stdv 48 | sdf = torch.where(use_dot_product, torch.abs(dot_products[:, 0]), vec_lengths) 49 | 50 | # Adjust the sign of the sdf values based on the majority of dot products 51 | num_pos = torch.sum(dot_products > 0, dim=1) 52 | inside = num_pos <= (k / 2) 53 | sdf[inside] *= -1 54 | 55 | return -sdf 56 | 57 | def generate_dual_mc_mesh(self, data_dict, encoder_outputs, device): 58 | from nksr.svh import SparseFeatureHierarchy, SparseIndexGrid 59 | from nksr.ext import meshing 60 | from nksr.meshing import MarchingCubes 61 | from nksr import utils 62 | 63 | max_depth = 100 64 | grid_upsample = 1 65 | mise_iter = 0 66 | knn_time = 0 67 | dmc_time = 0 68 | aggregation_time = 0 69 | decoder_time = 0 70 | mask_threshold = self.rec_cfg.mask_threshold 71 | 72 | pts = data_dict['xyz'].detach() 73 | self.last_n_layers = 4 74 | self.trim = self.rec_cfg.trim 75 | self.gt_mask = self.rec_cfg.gt_mask 76 | self.gt_sdf = self.rec_cfg.gt_sdf 77 | # Generate DMC grid structure 78 | nksr_svh = SparseFeatureHierarchy( 79 | voxel_size=self.voxel_size, 80 | depth=self.last_n_layers, 81 | device= pts.device 82 | ) 83 | nksr_svh.build_point_splatting(pts) 84 | 85 | flattened_grids = [] 86 | for d in range(min(nksr_svh.depth, max_depth + 1)): 87 | f_grid = meshing.build_flattened_grid( 88 | nksr_svh.grids[d]._grid, 89 | nksr_svh.grids[d - 1]._grid if d > 0 else None, 90 | d != nksr_svh.depth - 1 91 | ) 92 | if grid_upsample > 1: 93 | f_grid = f_grid.subdivided_grid(grid_upsample) 94 | flattened_grids.append(f_grid) 95 | 96 | dual_grid = meshing.build_joint_dual_grid(flattened_grids) 97 | dmc_graph = meshing.dual_cube_graph(flattened_grids, dual_grid) 98 | dmc_vertices = torch.cat([ 99 | f_grid.grid_to_world(f_grid.active_grid_coords().float()) 100 | for f_grid in flattened_grids if f_grid.num_voxels() > 0 101 | ], dim=0) 102 | del flattened_grids, dual_grid 103 | """ create a mask to trim spurious geometry """ 104 | 105 | decoder_time -= time.time() 106 | dmc_value, sdf_knn_time, sdf_aggregation_time = self.model(encoder_outputs, dmc_vertices) 107 | decoder_time += time.time() 108 | knn_time += sdf_knn_time 109 | aggregation_time += sdf_aggregation_time 110 | if self.gt_sdf: 111 | if 'gt_geometry' in data_dict: 112 | ref_xyz, ref_normal, _ = data_dict['gt_geometry'][0].torch_attr() 113 | else: 114 | ref_xyz, ref_normal = data_dict['all_xyz'], data_dict['all_normals'] 115 | dmc_value = self.compute_gt_sdf_from_pts(ref_xyz, ref_normal, dmc_vertices) 116 | 117 | for _ in range(mise_iter): 118 | cube_sign = dmc_value[dmc_graph] > 0 119 | cube_mask = ~torch.logical_or(torch.all(cube_sign, dim=1), torch.all(~cube_sign, dim=1)) 120 | dmc_graph = dmc_graph[cube_mask] 
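# One MISE refinement step: re-index the surviving surface-crossing cubes, drop vertices that are no longer referenced, subdivide the cube grid, and re-evaluate the (clamped) SDF at the refined vertex positions.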
121 | unq, dmc_graph = torch.unique(dmc_graph.view(-1), return_inverse=True) 122 | dmc_graph = dmc_graph.view(-1, 8) 123 | dmc_vertices = dmc_vertices[unq] 124 | dmc_graph, dmc_vertices = utils.subdivide_cube_indices(dmc_graph, dmc_vertices) 125 | dmc_value = torch.clamp(self.model(encoder_outputs, dmc_vertices.to(device)), max=self.threshold) 126 | 127 | dmc_time -= time.time() 128 | dual_v, dual_f = MarchingCubes().apply(dmc_graph, dmc_vertices, dmc_value) 129 | dmc_time += time.time() 130 | 131 | vert_mask = None 132 | if self.trim: 133 | if self.gt_mask: 134 | nn = NearestNeighbors(n_neighbors=1) 135 | nn.fit(data_dict['all_xyz'].cpu().numpy()) # coords is an (N, 3) array 136 | dist, indx = nn.kneighbors(dual_v.detach().cpu().numpy()) # xyz is an (M, 3) array 137 | dist = torch.from_numpy(dist).to(dual_v.device).squeeze(-1) 138 | vert_mask = dist < mask_threshold 139 | else: 140 | decoder_time -= time.time() 141 | dist, mask_knn_time, mask_aggregation_time = self.mask_model(encoder_outputs, dual_v.to(device)) 142 | decoder_time += time.time() 143 | vert_mask = dist < mask_threshold 144 | knn_time += mask_knn_time 145 | aggregation_time += mask_aggregation_time 146 | dmc_time -= time.time() 147 | dual_v, dual_f = utils.apply_vertex_mask(dual_v, dual_f, vert_mask) 148 | dmc_time += time.time() 149 | 150 | dmc_time -= time.time() 151 | mesh_res = MeshingResult(dual_v, dual_f, None) 152 | # del dual_v, dual_f 153 | mesh = vis.mesh(mesh_res.v, mesh_res.f) 154 | dmc_time += time.time() 155 | 156 | time_dict = { 157 | 'neighboring_time': knn_time, 158 | 'dmc_time': dmc_time, 159 | 'aggregation_time': aggregation_time, 160 | 'decoder_time': decoder_time, 161 | } 162 | return mesh, time_dict 163 | 164 | def generate_dual_mc_mesh_by_segment(self, data_dict, encoder_outputs, encoding_codes, depth, device): 165 | """ 166 | This function generates a dual marching cube mesh by computing the sdf values for each segment individually. 
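Query vertices are quantized to a 0.01 grid, serialized with z-order codes and sorted; the sorted codes are split at the given encoding_codes boundaries, and SDF values are predicted per segment from the matching entry of encoder_outputs, so each decoder call only sees one spatial segment. The same per-segment scheme is applied to the trimming-mask distances evaluated on the extracted mesh vertices.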
167 | """ 168 | from nksr.svh import SparseFeatureHierarchy, SparseIndexGrid 169 | from nksr.ext import meshing 170 | from nksr.meshing import MarchingCubes 171 | from nksr import utils 172 | 173 | max_depth = 100 174 | grid_upsample = 1 175 | mise_iter = 0 176 | knn_time = 0 177 | dmc_time = 0 178 | aggregation_time = 0 179 | decoder_time = 0 180 | mask_threshold = self.rec_cfg.mask_threshold 181 | 182 | pts = data_dict['xyz'].detach() 183 | self.last_n_layers = 4 184 | self.trim = self.rec_cfg.trim 185 | self.gt_mask = self.rec_cfg.gt_mask 186 | self.gt_sdf = self.rec_cfg.gt_sdf 187 | 188 | # Generate DMC grid structure 189 | nksr_svh = SparseFeatureHierarchy( 190 | voxel_size=self.voxel_size, 191 | depth=self.last_n_layers, 192 | device=pts.device 193 | ) 194 | nksr_svh.build_point_splatting(pts) 195 | 196 | flattened_grids = [] 197 | for d in range(min(nksr_svh.depth, max_depth + 1)): 198 | f_grid = meshing.build_flattened_grid( 199 | nksr_svh.grids[d]._grid, 200 | nksr_svh.grids[d - 1]._grid if d > 0 else None, 201 | d != nksr_svh.depth - 1 202 | ) 203 | if grid_upsample > 1: 204 | f_grid = f_grid.subdivided_grid(grid_upsample) 205 | flattened_grids.append(f_grid) 206 | 207 | dual_grid = meshing.build_joint_dual_grid(flattened_grids) 208 | dmc_graph = meshing.dual_cube_graph(flattened_grids, dual_grid) 209 | dmc_vertices = torch.cat([ 210 | f_grid.grid_to_world(f_grid.active_grid_coords().float()) 211 | for f_grid in flattened_grids if f_grid.num_voxels() > 0 212 | ], dim=0) 213 | del flattened_grids, dual_grid 214 | 215 | # Encode and segment `dmc_vertices` 216 | in_quant_coords = torch.floor(dmc_vertices / 0.01).to(torch.int) 217 | dmc_quant_codes = encode( 218 | in_quant_coords, 219 | torch.zeros(in_quant_coords.shape[0], dtype=torch.int64, device=in_quant_coords.device), 220 | depth, 221 | order='z' 222 | ) 223 | sorted_codes, sorted_indices = torch.sort(dmc_quant_codes) 224 | 225 | dmc_value_list = [] 226 | for idx in range(len(encoding_codes)): 227 | if idx == 0: 228 | segment_mask = (sorted_codes < encoding_codes[idx+1]) 229 | elif idx == len(encoding_codes) - 1: 230 | segment_mask = (sorted_codes >= encoding_codes[idx]) 231 | else: 232 | segment_mask = (sorted_codes >= encoding_codes[idx]) & (sorted_codes < encoding_codes[idx+1]) 233 | segment_indices = sorted_indices[segment_mask] 234 | segment_vertices = dmc_vertices[segment_indices] 235 | segment_encoder_output = encoder_outputs[idx] 236 | segment_dmc_value, sdf_knn_time, sdf_aggregation_time = self.model(segment_encoder_output, segment_vertices) 237 | dmc_value_list.append(segment_dmc_value) 238 | 239 | knn_time += sdf_knn_time 240 | aggregation_time += sdf_aggregation_time 241 | 242 | dmc_values = torch.zeros_like(sorted_codes, dtype=torch.float32, device=device) 243 | dmc_values[sorted_indices] = torch.cat(dmc_value_list) 244 | 245 | dmc_time -= time.time() 246 | dual_v, dual_f = MarchingCubes().apply(dmc_graph, dmc_vertices, dmc_values) 247 | dmc_time += time.time() 248 | 249 | """ create a mask to trim spurious geometry """ 250 | in_quant_coords = torch.floor(dual_v / 0.01).to(torch.int) 251 | dual_quant_codes = encode( 252 | in_quant_coords, 253 | torch.zeros(in_quant_coords.shape[0], dtype=torch.int64, device=in_quant_coords.device), 254 | depth, 255 | order='z' 256 | ) 257 | sorted_dual_codes, sorted_dual_indices = torch.sort(dual_quant_codes) 258 | 259 | dist_list = [] 260 | for idx in range(len(encoding_codes)): 261 | if idx == 0: 262 | segment_mask = (sorted_dual_codes < encoding_codes[idx+1]) 263 | elif idx == 
len(encoding_codes) - 1: 264 | segment_mask = (sorted_dual_codes >= encoding_codes[idx]) 265 | else: 266 | segment_mask = (sorted_dual_codes >= encoding_codes[idx]) & (sorted_dual_codes < encoding_codes[idx+1]) 267 | segment_indices = sorted_dual_indices[segment_mask] 268 | segment_dual_v = dual_v[segment_indices] 269 | 270 | if self.gt_mask: 271 | nn = NearestNeighbors(n_neighbors=1) 272 | nn.fit(data_dict['all_xyz'].cpu().numpy()) # Reference points (N, 3) 273 | segment_dist, _ = nn.kneighbors(segment_dual_v.detach().cpu().numpy()) # Query points (M, 3) 274 | segment_dist = torch.from_numpy(segment_dist).to(dual_v.device).squeeze(-1) 275 | else: 276 | decoder_time -= time.time() 277 | segment_dist, mask_knn_time, mask_aggregation_time = self.mask_model(encoder_outputs[idx], segment_dual_v.to(device)) 278 | decoder_time += time.time() 279 | 280 | knn_time += mask_knn_time 281 | aggregation_time += mask_aggregation_time 282 | 283 | dist_list.append(segment_dist) 284 | 285 | dist = torch.zeros_like(sorted_dual_codes, dtype=torch.float32, device=device) 286 | dist[sorted_dual_indices] = torch.cat(dist_list) 287 | 288 | vert_mask = dist < mask_threshold 289 | dual_v, dual_f = utils.apply_vertex_mask(dual_v, dual_f, vert_mask) 290 | 291 | dmc_time -= time.time() 292 | mesh_res = MeshingResult(dual_v, dual_f, None) 293 | mesh = vis.mesh(mesh_res.v, mesh_res.f) 294 | dmc_time += time.time() 295 | 296 | time_dict = { 297 | 'neighboring_time': knn_time, 298 | 'dmc_time': dmc_time, 299 | 'aggregation_time': aggregation_time, 300 | 'decoder_time': decoder_time, 301 | } 302 | 303 | return mesh, time_dict 304 | -------------------------------------------------------------------------------- /noksr/model/noksr_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pytorch_lightning as pl 4 | import importlib 5 | from noksr.model.module import PointTransformerV3 6 | from noksr.utils.samples import BatchedSampler 7 | from noksr.model.general_model import GeneralModel 8 | from torch.nn import functional as F 9 | 10 | class noksr(GeneralModel): 11 | def __init__(self, cfg): 12 | super().__init__(cfg) 13 | self.save_hyperparameters(cfg) 14 | 15 | self.latent_dim = cfg.model.network.latent_dim 16 | self.decoder_type = cfg.model.network.sdf_decoder.decoder_type 17 | self.eikonal = cfg.data.supervision.eikonal.loss 18 | self.laplacian = cfg.data.supervision.laplacian.loss 19 | self.surface_normal_supervision = cfg.data.supervision.on_surface.normal_loss 20 | self.flip_eikonal = cfg.data.supervision.eikonal.flip 21 | self.backbone = cfg.model.network.backbone 22 | if self.backbone == "PointTransformerV3": 23 | self.point_transformer = PointTransformerV3( 24 | backbone_cfg=cfg.model.network.point_transformerv3 25 | ) 26 | 27 | module = importlib.import_module('noksr.model.module') 28 | decoder_class = getattr(module, self.decoder_type) 29 | self.sdf_decoder = decoder_class( 30 | decoder_cfg=cfg.model.network.sdf_decoder, 31 | supervision = 'SDF', 32 | latent_dim=cfg.model.network.latent_dim, 33 | feature_dim=cfg.model.network.sdf_decoder.feature_dim, 34 | hidden_dim=cfg.model.network.sdf_decoder.hidden_dim, 35 | out_dim=1, 36 | voxel_size=cfg.data.voxel_size, 37 | activation=cfg.model.network.sdf_decoder.activation 38 | ) 39 | 40 | # if self.hparams.model.network.mask_decoder.distance_mask: 41 | if cfg.data.supervision.udf.weight > 0: 42 | self.mask_decoder = decoder_class( 43 | decoder_cfg=cfg.model.network.mask_decoder, 44 
| supervision = 'Distance', 45 | latent_dim=cfg.model.network.latent_dim, 46 | feature_dim=cfg.model.network.mask_decoder.feature_dim, 47 | hidden_dim=cfg.model.network.mask_decoder.hidden_dim, 48 | out_dim=1, 49 | voxel_size=cfg.data.voxel_size, 50 | activation=cfg.model.network.mask_decoder.activation 51 | ) 52 | 53 | self.batched_sampler = BatchedSampler(cfg) # Instantiate the batched sampler with the configurations 54 | 55 | def forward(self, data_dict): 56 | outputs = {} 57 | """ Get query samples and ground truth values """ 58 | query_xyz, query_gt_sdf = self.batched_sampler.batch_sdf_sample(data_dict) 59 | on_surface_xyz, gt_on_surface_normal = self.batched_sampler.batch_on_surface_sample(data_dict) 60 | if self.hparams.data.supervision.udf.weight > 0: 61 | mask_query_xyz, mask_query_gt_udf = self.batched_sampler.batch_udf_sample(data_dict) 62 | outputs['gt_distances'] = mask_query_gt_udf 63 | 64 | outputs['gt_values'] = query_gt_sdf 65 | outputs['gt_on_surface_normal'] = gt_on_surface_normal 66 | 67 | if self.backbone == "PointTransformerV3": 68 | pt_data = {} 69 | pt_data['feat'] = data_dict['point_features'] 70 | pt_data['offset'] = torch.cumsum(data_dict['xyz_splits'], dim=0) 71 | pt_data['grid_size'] = 0.01 72 | pt_data['coord'] = data_dict['xyz'] 73 | encoder_output = self.point_transformer(pt_data) 74 | 75 | outputs['values'], *_ = self.sdf_decoder(encoder_output, query_xyz) 76 | outputs['surface_values'], *_ = self.sdf_decoder(encoder_output, on_surface_xyz) 77 | 78 | if self.eikonal: 79 | if self.hparams.model.network.grad_type == "Numerical": 80 | interval = 0.01 * self.hparams.data.voxel_size 81 | grad_value = [] 82 | for offset in [(interval, 0, 0), (0, interval, 0), (0, 0, interval)]: 83 | offset_tensor = torch.tensor(offset, device=self.device)[None, :] 84 | res_p, *_ = self.sdf_decoder(encoder_output, query_xyz + offset_tensor) 85 | res_n, *_ = self.sdf_decoder(encoder_output, query_xyz - offset_tensor) 86 | grad_value.append((res_p - res_n) / (2 * interval)) 87 | outputs['pd_grad'] = torch.stack(grad_value, dim=1) 88 | else: 89 | xyz = torch.clone(query_xyz) 90 | xyz.requires_grad = True 91 | with torch.enable_grad(): 92 | res, *_ = self.sdf_decoder(encoder_output, xyz) 93 | outputs['pd_grad'] = torch.autograd.grad(res, [xyz], 94 | grad_outputs=torch.ones_like(res), 95 | create_graph=self.sdf_decoder.training, allow_unused=True)[0] 96 | 97 | if self.laplacian: 98 | interval = 1.0 * self.hparams.data.voxel_size 99 | laplacian_value = 0 100 | 101 | for offset in [(interval, 0, 0), (0, interval, 0), (0, 0, interval)]: 102 | offset_tensor = torch.tensor(offset, device=self.device)[None, :] 103 | 104 | # Calculate numerical gradient 105 | res, *_ = self.sdf_decoder(encoder_output, query_xyz) 106 | res_p, *_ = self.sdf_decoder(encoder_output, query_xyz + offset_tensor) 107 | res_pp, *_ = self.sdf_decoder(encoder_output, query_xyz + 2 * offset_tensor) 108 | laplacian_value += (res_pp - 2 * res_p + res) / (interval ** 2) 109 | outputs['pd_laplacian'] = laplacian_value 110 | 111 | if self.surface_normal_supervision: 112 | if self.hparams.model.network.grad_type == "Numerical": 113 | interval = 0.01 * self.hparams.data.voxel_size 114 | grad_value = [] 115 | for offset in [(interval, 0, 0), (0, interval, 0), (0, 0, interval)]: 116 | offset_tensor = torch.tensor(offset, device=self.device)[None, :] 117 | res_p, *_ = self.sdf_decoder(encoder_output, on_surface_xyz + offset_tensor) 118 | res_n, *_ = self.sdf_decoder(encoder_output, on_surface_xyz - offset_tensor) 119 | 
grad_value.append((res_p - res_n) / (2 * interval)) 120 | outputs['pd_surface_grad'] = torch.stack(grad_value, dim=1) 121 | else: 122 | xyz = torch.clone(on_surface_xyz) 123 | xyz.requires_grad = True 124 | with torch.enable_grad(): 125 | res, *_ = self.sdf_decoder(encoder_output, xyz) 126 | outputs['pd_surface_grad'] = torch.autograd.grad(res, [xyz], 127 | grad_outputs=torch.ones_like(res), 128 | create_graph=self.sdf_decoder.training, allow_unused=True)[0] 129 | 130 | if self.hparams.data.supervision.udf.weight > 0: 131 | outputs['distances'], *_ = self.mask_decoder(encoder_output, mask_query_xyz) 132 | 133 | return outputs, encoder_output 134 | 135 | def loss(self, data_dict, outputs, encoder_output): 136 | l1_loss = torch.nn.L1Loss(reduction='mean')(torch.clamp(outputs['values'], min = -self.hparams.data.supervision.sdf.max_dist, max=self.hparams.data.supervision.sdf.max_dist), torch.clamp(outputs['gt_values'], min = -self.hparams.data.supervision.sdf.max_dist, max=self.hparams.data.supervision.sdf.max_dist)) 137 | on_surface_loss = torch.abs(outputs['surface_values']).mean() 138 | 139 | mask_loss = torch.tensor(0.0, device=self.device) 140 | eikonal_loss = torch.tensor(0.0, device=self.device) 141 | normal_loss = torch.tensor(0.0, device=self.device) 142 | laplacian_loss = torch.tensor(0.0, device=self.device) 143 | 144 | # Create mask for points within max_dist 145 | valid_mask = (outputs['gt_values'] >= -self.hparams.data.supervision.sdf.max_dist/2) & (outputs['gt_values'] <= self.hparams.data.supervision.sdf.max_dist/2) 146 | 147 | # Eikonal Loss computation 148 | if self.eikonal: 149 | norms = torch.norm(outputs['pd_grad'], dim=1) # Compute the norm over the gradient vectors 150 | eikonal_loss = ((norms - 1) ** 2)[valid_mask].mean() # Masked eikonal loss 151 | 152 | if self.laplacian: 153 | laplacian_loss = torch.abs(outputs['pd_laplacian'])[valid_mask].mean() # Masked laplacian loss 154 | 155 | if self.surface_normal_supervision: 156 | normalized_pd_surface_grad = -outputs['pd_surface_grad'] / (torch.linalg.norm(outputs['pd_surface_grad'], dim=-1, keepdim=True) + 1.0e-6) 157 | normal_loss = 1.0 - torch.sum(normalized_pd_surface_grad * outputs['gt_on_surface_normal'], dim=-1).mean() 158 | 159 | # if self.mask: 160 | if self.hparams.data.supervision.udf.weight > 0: 161 | mask_loss = torch.nn.L1Loss(reduction='mean')(torch.clamp(outputs['distances'], max=self.hparams.data.supervision.udf.max_dist), torch.clamp(outputs['gt_distances'], max=self.hparams.data.supervision.udf.max_dist)) 162 | 163 | return l1_loss, on_surface_loss, mask_loss, eikonal_loss, normal_loss, laplacian_loss 164 | 165 | def training_step(self, data_dict): 166 | """ UDF auto-encoder training stage """ 167 | 168 | batch_size = self.hparams.data.batch_size 169 | outputs, encoder_output = self.forward(data_dict) 170 | 171 | l1_loss, on_surface_loss, mask_loss, eikonal_loss, normal_loss, laplacian_loss = self.loss(data_dict, outputs, encoder_output) 172 | self.log("train/l1_loss", l1_loss.float(), on_step=True, on_epoch=True, sync_dist=True, batch_size=batch_size) 173 | self.log("train/on_surface_loss", on_surface_loss.float(), on_step=True, on_epoch=True, sync_dist=True, batch_size=batch_size) 174 | self.log("train/mask_loss", mask_loss.float(), on_step=True, on_epoch=True, sync_dist=True, batch_size=batch_size) 175 | self.log("train/eikonal_loss", eikonal_loss.float(), on_step=True, on_epoch=True, sync_dist=True, batch_size=batch_size) 176 | self.log("train/laplacian_loss", laplacian_loss.float(), on_step=True, 
on_epoch=True, sync_dist=True, batch_size=batch_size) 177 | 178 | total_loss = l1_loss*self.hparams.data.supervision.sdf.weight + on_surface_loss*self.hparams.data.supervision.on_surface.weight + eikonal_loss*self.hparams.data.supervision.eikonal.weight + mask_loss*self.hparams.data.supervision.udf.weight + normal_loss*self.hparams.data.supervision.on_surface.normal_weight + laplacian_loss*self.hparams.data.supervision.laplacian.weight 179 | 180 | return total_loss 181 | 182 | def validation_step(self, data_dict, idx): 183 | batch_size = 1 184 | outputs, encoder_output = self.forward(data_dict) 185 | l1_loss, on_surface_loss, mask_loss, eikonal_loss, normal_loss, laplacian_loss = self.loss(data_dict, outputs, encoder_output) 186 | 187 | self.log("val/l1_loss", l1_loss.float(), on_step=True, on_epoch=True, sync_dist=True, batch_size=batch_size, logger=True) 188 | self.log("val/on_surface_loss", on_surface_loss.float(), on_step=True, on_epoch=True, sync_dist=True, batch_size=batch_size, logger=True) 189 | self.log("val/mask_loss", mask_loss.float(), on_step=True, on_epoch=True, sync_dist=True, batch_size=batch_size, logger=True) 190 | self.log("val/eikonal_loss", eikonal_loss.float(), on_step=True, on_epoch=True, sync_dist=True, batch_size=batch_size, logger=True) 191 | self.log("val/laplacian_loss", laplacian_loss.float(), on_step=True, on_epoch=True, sync_dist=True, batch_size=batch_size, logger=True) -------------------------------------------------------------------------------- /noksr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /noksr/utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/utils/__pycache__/evaluation.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/utils/__pycache__/evaluation.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/utils/__pycache__/optimizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/utils/__pycache__/optimizer.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/utils/__pycache__/samples.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/utils/__pycache__/samples.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/utils/__pycache__/segmentation.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/utils/__pycache__/segmentation.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/utils/__pycache__/transform.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/utils/__pycache__/transform.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/utils/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from math import cos, pi 3 | 4 | 5 | """ 6 | modified from 7 | https://github.com/thangvubk/SoftGroup/blob/a87940664d0af38a3f55cdac46f86fc83b029ecc/softgroup/util/utils.py#L55 8 | """ 9 | def cosine_lr_decay(optimizer, base_lr, current_epoch, start_epoch, total_epochs, clip): 10 | if current_epoch < start_epoch: 11 | return 12 | for param_group in optimizer.param_groups: 13 | param_group['lr'] = clip + 0.5 * (base_lr - clip) * \ 14 | (1 + cos(pi * ((current_epoch - start_epoch) / (total_epochs - start_epoch)))) 15 | 16 | def adjust_learning_rate( 17 | initial_lr, optimizer, num_iterations, total_iterations, decreased_by 18 | ): 19 | adjust_lr_every = int(total_iterations / 2) 20 | lr = initial_lr * ((1 / decreased_by) ** (num_iterations // adjust_lr_every)) 21 | for param_group in optimizer.param_groups: 22 | param_group["lr"] = lr -------------------------------------------------------------------------------- /noksr/utils/samples.py: -------------------------------------------------------------------------------- 1 | # import ext 2 | 3 | import torch 4 | from nksr.svh import SparseFeatureHierarchy 5 | from pytorch3d.ops import knn_points 6 | from scipy.spatial import KDTree 7 | 8 | 9 | import torch 10 | 11 | class Sampler: 12 | def __init__(self, **kwargs): 13 | # Default values can be set with kwargs.get('key', default_value) 14 | self.voxel_size = kwargs.get('voxel_size') 15 | self.adaptive_policy = { 16 | 'method': 'normal', 17 | 'tau': 0.1, 18 | 'depth': 2 19 | } 20 | self.cfg = kwargs.get('cfg') 21 | self.ref_xyz = kwargs.get('ref_xyz') 22 | self.ref_normal = kwargs.get('ref_normal') 23 | self.svh = self._build_gt_svh() 24 | self.kdtree = KDTree(self.ref_xyz.detach().cpu().numpy()) 25 | 26 | def _build_gt_svh(self): 27 | gt_svh = SparseFeatureHierarchy( 28 | voxel_size=self.voxel_size, 29 | depth=self.cfg.svh_tree_depth, 30 | device=self.ref_xyz.device 31 | ) 32 | if self.adaptive_policy['method'] == "normal": 33 | gt_svh.build_adaptive_normal_variation( 34 | self.ref_xyz, self.ref_normal, 35 | tau=self.adaptive_policy['tau'], 36 | adaptive_depth=self.adaptive_policy['depth'] 37 | ) 38 | return gt_svh 39 | 40 | def _get_svh_samples(self, svh, n_samples, expand=0, expand_top=0): 41 | """ 42 | Get random samples, across all layers of the decoder hierarchy 43 | :param svh: SparseFeatureHierarchy, hierarchy of spatial features 44 | :param n_samples: int, number of total samples 45 | :param expand: int, size of expansion 46 | :param expand_top: int, size of expansion of the coarsest level. 
47 | :return: (n_samples, 3) tensor of positions 48 | """ 49 | base_coords, base_scales = [], [] 50 | for d in range(svh.depth): 51 | if svh.grids[d] is None: 52 | continue 53 | ijk_coords = svh.grids[d].active_grid_coords() 54 | d_expand = expand if d != svh.depth - 1 else expand_top 55 | if d_expand >= 3: 56 | mc_offsets = torch.arange(-d_expand // 2 + 1, d_expand // 2 + 1, device=svh.device) 57 | mc_offsets = torch.stack(torch.meshgrid(mc_offsets, mc_offsets, mc_offsets, indexing='ij'), dim=3) 58 | mc_offsets = mc_offsets.view(-1, 3) 59 | ijk_coords = (ijk_coords.unsqueeze(dim=1).repeat(1, mc_offsets.size(0), 1) + 60 | mc_offsets.unsqueeze(0)).view(-1, 3) 61 | ijk_coords = torch.unique(ijk_coords, dim=0) 62 | base_coords.append(svh.grids[d].grid_to_world(ijk_coords.float())) 63 | base_scales.append(torch.full((ijk_coords.size(0),), svh.grids[d].voxel_size, device=svh.device)) 64 | base_coords, base_scales = torch.cat(base_coords), torch.cat(base_scales) 65 | local_ids = (torch.rand((n_samples,), device=svh.device) * base_coords.size(0)).long() 66 | local_coords = (torch.rand((n_samples, 3), device=svh.device) - 0.5) * base_scales[local_ids, None] 67 | query_pos = base_coords[local_ids] + local_coords 68 | return query_pos 69 | 70 | def _get_samples(self): 71 | all_samples = [] 72 | for config in self.cfg.samplers: 73 | if config.type == "uniform": 74 | all_samples.append( 75 | self._get_svh_samples(self.svh, config.n_samples, config.expand, config.expand_top) 76 | ) 77 | elif config.type == "band": 78 | band_inds = (torch.rand((config.n_samples, ), device=self.ref_xyz.device) * self.ref_xyz.size(0)).long() 79 | eps = config.eps * self.voxel_size 80 | band_pos = self.ref_xyz[band_inds] + \ 81 | self.ref_normal[band_inds] * torch.randn((config.n_samples, 1), device=self.ref_xyz.device) * eps 82 | all_samples.append(band_pos) 83 | elif config.type == 'on_surface': 84 | n_subsample = config.subsample 85 | if 0 < n_subsample < self.ref_xyz.size(0): 86 | ref_xyz_inds = (torch.rand((n_subsample,), device=self.ref_xyz.device) * 87 | self.ref_xyz.size(0)).long() 88 | else: 89 | ref_xyz_inds = (torch.rand((n_subsample,), device=self.ref_xyz.device) * 90 | self.ref_xyz.size(0)).long() 91 | all_samples.append(self.ref_xyz[ref_xyz_inds]) 92 | 93 | return torch.cat(all_samples, 0) 94 | 95 | def transform_field(self, field: torch.Tensor): 96 | sdf_config = self.cfg 97 | assert sdf_config.gt_type != "binary" 98 | truncation_size = sdf_config.gt_band * self.voxel_size 99 | if sdf_config.gt_soft: 100 | field = torch.tanh(field / truncation_size) * truncation_size 101 | else: 102 | field = torch.clone(field) 103 | field[field > truncation_size] = truncation_size 104 | field[field < -truncation_size] = -truncation_size 105 | return field 106 | 107 | def compute_gt_sdf_from_pts(self, query_pos: torch.Tensor): 108 | k = 8 109 | stdv = 0.02 110 | normals = self.ref_normal 111 | device = query_pos.device 112 | knn_output = knn_points(query_pos.unsqueeze(0).to(device), self.ref_xyz.unsqueeze(0).to(device), K=k) 113 | indices = knn_output.idx.squeeze(0) 114 | closest_points = self.ref_xyz[indices] 115 | surface_to_queries_vec = query_pos.unsqueeze(1) - closest_points #N, K, 3 116 | 117 | dot_products = torch.einsum("ijk,ijk->ij", surface_to_queries_vec, normals[indices]) #N, K 118 | vec_lengths = torch.norm(surface_to_queries_vec[:, 0, :], dim=-1) 119 | use_dot_product = vec_lengths < stdv 120 | sdf = torch.where(use_dot_product, torch.abs(dot_products[:, 0]), vec_lengths) 121 | 122 | # Adjust the sign of the 
sdf values based on the majority of dot products 123 | num_pos = torch.sum(dot_products > 0, dim=1) 124 | inside = num_pos <= (k / 2) 125 | sdf[inside] *= -1 126 | 127 | return -sdf 128 | 129 | 130 | class BatchedSampler: 131 | def __init__(self, hparams): 132 | self.hparams = hparams 133 | 134 | def batch_sdf_sample(self, data_dict): 135 | if 'gt_geometry' in data_dict: 136 | gt_geometry = data_dict['gt_geometry'] 137 | else: 138 | xyz, normal = data_dict['all_xyz'], data_dict['all_normals'] 139 | row_splits = data_dict['row_splits'] 140 | batch_size = len(data_dict['scene_names']) 141 | 142 | start = 0 143 | batch_samples_pos = [] 144 | batch_gt_sdf = [] 145 | 146 | for i in range(batch_size): 147 | if 'gt_geometry' in data_dict: 148 | ref_xyz, ref_normal, _ = gt_geometry[i].torch_attr() 149 | else: 150 | end = start + row_splits[i] 151 | ref_xyz = xyz[start:end] 152 | ref_normal = normal[start:end] 153 | start = end 154 | 155 | # Instantiate Sampler using kwargs 156 | sampler = Sampler( 157 | voxel_size=self.hparams.data.voxel_size, 158 | cfg=self.hparams.data.supervision.sdf, 159 | ref_xyz=ref_xyz, 160 | ref_normal=ref_normal 161 | ) 162 | samples_pos = sampler._get_samples() 163 | # nksr_sdf = sampler.compute_gt_chi_from_pts(samples_pos) 164 | gt_sdf = sampler.compute_gt_sdf_from_pts(samples_pos) 165 | if self.hparams.data.supervision.sdf.truncate: 166 | gt_sdf = sampler.transform_field(gt_sdf) 167 | 168 | batch_samples_pos.append(samples_pos) 169 | batch_gt_sdf.append(gt_sdf) 170 | 171 | batch_samples_pos = torch.cat(batch_samples_pos, dim=0) 172 | batch_gt_sdf = torch.cat(batch_gt_sdf, dim=0) 173 | 174 | return batch_samples_pos, batch_gt_sdf 175 | 176 | def batch_udf_sample(self, data_dict): 177 | if 'gt_geometry' in data_dict: 178 | gt_geometry = data_dict['gt_geometry'] 179 | else: 180 | xyz, normal = data_dict['all_xyz'], data_dict['all_normals'] 181 | row_splits = data_dict['row_splits'] 182 | batch_size = len(data_dict['scene_names']) 183 | 184 | start = 0 185 | batch_samples_pos = [] 186 | batch_gt_udf = [] 187 | 188 | for i in range(batch_size): 189 | if 'gt_geometry' in data_dict: 190 | ref_xyz, ref_normal, _ = gt_geometry[i].torch_attr() 191 | else: 192 | end = start + row_splits[i] 193 | ref_xyz = xyz[start:end] 194 | ref_normal = normal[start:end] 195 | start = end 196 | 197 | # Instantiate Sampler for UDF using kwargs 198 | sampler = Sampler( 199 | voxel_size=self.hparams.data.voxel_size, 200 | cfg=self.hparams.data.supervision.udf, 201 | ref_xyz=ref_xyz, 202 | ref_normal=ref_normal 203 | ) 204 | samples_pos = sampler._get_samples() 205 | if self.hparams.data.supervision.udf.abs_sdf: 206 | gt_udf = torch.abs(sampler.compute_gt_sdf_from_pts(samples_pos)) 207 | else: 208 | knn_output = knn_points(samples_pos.unsqueeze(0).to(torch.device("cuda")), 209 | ref_xyz.unsqueeze(0).to(torch.device("cuda")), 210 | K=1) 211 | gt_udf = knn_output.dists.squeeze(0).squeeze(-1) 212 | 213 | batch_samples_pos.append(samples_pos) 214 | batch_gt_udf.append(gt_udf) 215 | 216 | 217 | batch_samples_pos = torch.cat(batch_samples_pos, dim=0) 218 | batch_gt_udf = torch.cat(batch_gt_udf, dim=0) 219 | 220 | return batch_samples_pos, batch_gt_udf 221 | 222 | def batch_on_surface_sample(self, data_dict): 223 | if 'gt_geometry' in data_dict: 224 | gt_geometry = data_dict['gt_geometry'] 225 | else: 226 | xyz, normal = data_dict['all_xyz'], data_dict['all_normals'] 227 | row_splits = data_dict['row_splits'] 228 | batch_size = len(data_dict['scene_names']) 229 | 230 | start = 0 231 | batch_samples_pos = 
[] 232 | batch_samples_normal = [] 233 | batch_gt_udf = [] 234 | 235 | for i in range(batch_size): 236 | if 'gt_geometry' in data_dict: 237 | ref_xyz, ref_normal, _ = gt_geometry[i].torch_attr() 238 | else: 239 | end = start + row_splits[i] 240 | ref_xyz = xyz[start:end] 241 | ref_normal = normal[start:end] 242 | start = end 243 | 244 | n_subsample = self.hparams.data.supervision.on_surface.subsample 245 | if 0 < n_subsample < ref_xyz.size(0): 246 | ref_xyz_inds = (torch.rand((n_subsample,), device=ref_xyz.device) * 247 | ref_xyz.size(0)).long() 248 | else: 249 | ref_xyz_inds = (torch.rand((n_subsample,), device=ref_xyz.device) * 250 | ref_xyz.size(0)).long() 251 | batch_samples_pos.append(ref_xyz[ref_xyz_inds]) 252 | batch_samples_normal.append(ref_normal[ref_xyz_inds]) 253 | 254 | batch_samples_pos = torch.cat(batch_samples_pos, dim=0) 255 | batch_samples_normal = torch.cat(batch_samples_normal, dim=0) 256 | return batch_samples_pos, batch_samples_normal 257 | 258 | 259 | -------------------------------------------------------------------------------- /noksr/utils/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from noksr.utils.serialization import encode 3 | 4 | def segment_and_generate_encoder_outputs(batch, model, device, segment_num, grid_size=0.01, serial_order='z'): 5 | """ 6 | Segment input `xyz` of the batch, generate encoder outputs for each segment, and return necessary metadata. 7 | 8 | Args: 9 | batch (dict): Batch data containing `xyz` and other input features. 10 | model (torch.nn.Module): The neural network model for processing data. 11 | device (torch.device): The device to run inference on. 12 | segment_num (int): Fixed number of segments. 13 | grid_size (float): Grid size for serialization. 14 | serial_order (str): Order for serialization ('z' by default). 15 | 16 | Returns: 17 | encoder_outputs (list): List of encoder outputs for each segment. 18 | encoding_codes (list): List of (start_code, end_code) for each segment. 19 | depth (int): Depth used for serialization. 
20 | """ 21 | # Extract necessary data from the batch 22 | xyz = batch['xyz'].to(torch.float32) # Convert to float32 for processing 23 | total_points = xyz.shape[0] 24 | 25 | # Compute segment length based on fixed number of segments 26 | segment_length = max(1, total_points // segment_num) 27 | 28 | in_quant_coords = torch.floor(xyz / grid_size).to(torch.int) 29 | depth = int(torch.abs(in_quant_coords).max()).bit_length() # Calculate serialization depth 30 | in_quant_codes = encode( 31 | in_quant_coords, 32 | torch.zeros(in_quant_coords.shape[0], dtype=torch.int64, device=in_quant_coords.device), 33 | depth, 34 | order=serial_order 35 | ) 36 | in_sorted_quant_codes, in_sorted_indices = torch.sort(in_quant_codes) 37 | 38 | segments = [] 39 | encoding_codes = [] 40 | for i in range(segment_num): 41 | start_idx = i * segment_length 42 | end_idx = min(start_idx + segment_length, total_points) 43 | segment_indices = in_sorted_indices[start_idx:end_idx] 44 | segments.append(segment_indices) 45 | 46 | # Store the start and end encoding codes for the segment 47 | start_code = in_sorted_quant_codes[start_idx].item() 48 | end_code = in_sorted_quant_codes[end_idx - 1].item() 49 | encoding_codes.append(start_code) 50 | 51 | # Generate encoder outputs for each segment 52 | encoder_outputs = [] 53 | for segment_indices in segments: 54 | # Create a new batch for the current segment 55 | segment_batch = { 56 | "xyz": batch['xyz'][segment_indices], 57 | "point_features": batch['point_features'][segment_indices], 58 | "scene_names": batch['scene_names'], 59 | "xyz_splits": torch.tensor([len(segment_indices)], device=batch['xyz'].device) 60 | } 61 | 62 | pt_data = { 63 | 'feat': segment_batch['point_features'], 64 | 'offset': segment_batch['xyz_splits'], # Offset for the segment 65 | 'grid_size': grid_size, 66 | 'coord': segment_batch['xyz'] 67 | } 68 | segment_encoder_output = model.point_transformer(pt_data) 69 | encoder_outputs.append(segment_encoder_output) 70 | 71 | return encoder_outputs, encoding_codes, depth -------------------------------------------------------------------------------- /noksr/utils/serialization/__init__.py: -------------------------------------------------------------------------------- 1 | from .default import ( 2 | encode, 3 | decode, 4 | z_order_encode, 5 | z_order_decode, 6 | hilbert_encode, 7 | hilbert_decode, 8 | ) 9 | -------------------------------------------------------------------------------- /noksr/utils/serialization/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/utils/serialization/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/utils/serialization/__pycache__/default.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/utils/serialization/__pycache__/default.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/utils/serialization/__pycache__/hilbert.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/utils/serialization/__pycache__/hilbert.cpython-310.pyc 
-------------------------------------------------------------------------------- /noksr/utils/serialization/__pycache__/z_order.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/utils/serialization/__pycache__/z_order.cpython-310.pyc -------------------------------------------------------------------------------- /noksr/utils/serialization/default.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .z_order import xyz2key as z_order_encode_ 3 | from .z_order import key2xyz as z_order_decode_ 4 | from .hilbert import encode as hilbert_encode_ 5 | from .hilbert import decode as hilbert_decode_ 6 | 7 | 8 | @torch.inference_mode() 9 | def encode(grid_coord, batch=None, depth=16, order="z"): 10 | assert order in {"z", "z-trans", "hilbert", "hilbert-trans"} 11 | if order == "z": 12 | code = z_order_encode(grid_coord, depth=depth) 13 | elif order == "z-trans": 14 | code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth) 15 | elif order == "hilbert": 16 | code = hilbert_encode(grid_coord, depth=depth) 17 | elif order == "hilbert-trans": 18 | code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth) 19 | else: 20 | raise NotImplementedError 21 | if batch is not None: 22 | batch = batch.long() 23 | code = batch << depth * 3 | code 24 | return code 25 | 26 | 27 | @torch.inference_mode() 28 | def decode(code, depth=16, order="z"): 29 | assert order in {"z", "hilbert"} 30 | batch = code >> depth * 3 31 | code = code & ((1 << depth * 3) - 1) 32 | if order == "z": 33 | grid_coord = z_order_decode(code, depth=depth) 34 | elif order == "hilbert": 35 | grid_coord = hilbert_decode(code, depth=depth) 36 | else: 37 | raise NotImplementedError 38 | return grid_coord, batch 39 | 40 | 41 | def z_order_encode(grid_coord: torch.Tensor, depth: int = 16): 42 | x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long() 43 | # we block the support to batch, maintain batched code in Point class 44 | code = z_order_encode_(x, y, z, b=None, depth=depth) 45 | return code 46 | 47 | 48 | def z_order_decode(code: torch.Tensor, depth): 49 | x, y, z = z_order_decode_(code, depth=depth) 50 | grid_coord = torch.stack([x, y, z], dim=-1) # (N, 3) 51 | return grid_coord 52 | 53 | 54 | def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16): 55 | return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth) 56 | 57 | 58 | def hilbert_decode(code: torch.Tensor, depth: int = 16): 59 | return hilbert_decode_(code, num_dims=3, num_bits=depth) 60 | -------------------------------------------------------------------------------- /noksr/utils/serialization/hilbert.py: -------------------------------------------------------------------------------- 1 | """ 2 | Hilbert Order 3 | Modified from https://github.com/PrincetonLIPS/numpy-hilbert-curve 4 | 5 | Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com), Kaixin Xu 6 | Please cite our work if the code is helpful to you. 7 | """ 8 | 9 | import torch 10 | 11 | 12 | def right_shift(binary, k=1, axis=-1): 13 | """Right shift an array of binary values. 14 | 15 | Parameters: 16 | ----------- 17 | binary: An ndarray of binary values. 18 | 19 | k: The number of bits to shift. Default 1. 20 | 21 | axis: The axis along which to shift. Default -1. 
22 | 23 | Returns: 24 | -------- 25 | Returns an ndarray with zero prepended and the ends truncated, along 26 | whatever axis was specified.""" 27 | 28 | # If we're shifting the whole thing, just return zeros. 29 | if binary.shape[axis] <= k: 30 | return torch.zeros_like(binary) 31 | 32 | # Determine the padding pattern. 33 | # padding = [(0,0)] * len(binary.shape) 34 | # padding[axis] = (k,0) 35 | 36 | # Determine the slicing pattern to eliminate just the last one. 37 | slicing = [slice(None)] * len(binary.shape) 38 | slicing[axis] = slice(None, -k) 39 | shifted = torch.nn.functional.pad( 40 | binary[tuple(slicing)], (k, 0), mode="constant", value=0 41 | ) 42 | 43 | return shifted 44 | 45 | 46 | def binary2gray(binary, axis=-1): 47 | """Convert an array of binary values into Gray codes. 48 | 49 | This uses the classic X ^ (X >> 1) trick to compute the Gray code. 50 | 51 | Parameters: 52 | ----------- 53 | binary: An ndarray of binary values. 54 | 55 | axis: The axis along which to compute the gray code. Default=-1. 56 | 57 | Returns: 58 | -------- 59 | Returns an ndarray of Gray codes. 60 | """ 61 | shifted = right_shift(binary, axis=axis) 62 | 63 | # Do the X ^ (X >> 1) trick. 64 | gray = torch.logical_xor(binary, shifted) 65 | 66 | return gray 67 | 68 | 69 | def gray2binary(gray, axis=-1): 70 | """Convert an array of Gray codes back into binary values. 71 | 72 | Parameters: 73 | ----------- 74 | gray: An ndarray of gray codes. 75 | 76 | axis: The axis along which to perform Gray decoding. Default=-1. 77 | 78 | Returns: 79 | -------- 80 | Returns an ndarray of binary values. 81 | """ 82 | 83 | # Loop the log2(bits) number of times necessary, with shift and xor. 84 | shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1) 85 | while shift > 0: 86 | gray = torch.logical_xor(gray, right_shift(gray, shift)) 87 | shift = torch.div(shift, 2, rounding_mode="floor") 88 | return gray 89 | 90 | 91 | def encode(locs, num_dims, num_bits): 92 | """Decode an array of locations in a hypercube into a Hilbert integer. 93 | 94 | This is a vectorized-ish version of the Hilbert curve implementation by John 95 | Skilling as described in: 96 | 97 | Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference 98 | Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics. 99 | 100 | Params: 101 | ------- 102 | locs - An ndarray of locations in a hypercube of num_dims dimensions, in 103 | which each dimension runs from 0 to 2**num_bits-1. The shape can 104 | be arbitrary, as long as the last dimension of the same has size 105 | num_dims. 106 | 107 | num_dims - The dimensionality of the hypercube. Integer. 108 | 109 | num_bits - The number of bits for each dimension. Integer. 110 | 111 | Returns: 112 | -------- 113 | The output is an ndarray of uint64 integers with the same shape as the 114 | input, excluding the last dimension, which needs to be num_dims. 115 | """ 116 | 117 | # Keep around the original shape for later. 118 | orig_shape = locs.shape 119 | bitpack_mask = 1 << torch.arange(0, 8).to(locs.device) 120 | bitpack_mask_rev = bitpack_mask.flip(-1) 121 | 122 | if orig_shape[-1] != num_dims: 123 | raise ValueError( 124 | """ 125 | The shape of locs was surprising in that the last dimension was of size 126 | %d, but num_dims=%d. These need to be equal. 
127 | """ 128 | % (orig_shape[-1], num_dims) 129 | ) 130 | 131 | if num_dims * num_bits > 63: 132 | raise ValueError( 133 | """ 134 | num_dims=%d and num_bits=%d for %d bits total, which can't be encoded 135 | into a int64. Are you sure you need that many points on your Hilbert 136 | curve? 137 | """ 138 | % (num_dims, num_bits, num_dims * num_bits) 139 | ) 140 | 141 | # Treat the location integers as 64-bit unsigned and then split them up into 142 | # a sequence of uint8s. Preserve the association by dimension. 143 | locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1) 144 | 145 | # Now turn these into bits and truncate to num_bits. 146 | gray = ( 147 | locs_uint8.unsqueeze(-1) 148 | .bitwise_and(bitpack_mask_rev) 149 | .ne(0) 150 | .byte() 151 | .flatten(-2, -1)[..., -num_bits:] 152 | ) 153 | 154 | # Run the decoding process the other way. 155 | # Iterate forwards through the bits. 156 | for bit in range(0, num_bits): 157 | # Iterate forwards through the dimensions. 158 | for dim in range(0, num_dims): 159 | # Identify which ones have this bit active. 160 | mask = gray[:, dim, bit] 161 | 162 | # Where this bit is on, invert the 0 dimension for lower bits. 163 | gray[:, 0, bit + 1 :] = torch.logical_xor( 164 | gray[:, 0, bit + 1 :], mask[:, None] 165 | ) 166 | 167 | # Where the bit is off, exchange the lower bits with the 0 dimension. 168 | to_flip = torch.logical_and( 169 | torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1), 170 | torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]), 171 | ) 172 | gray[:, dim, bit + 1 :] = torch.logical_xor( 173 | gray[:, dim, bit + 1 :], to_flip 174 | ) 175 | gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip) 176 | 177 | # Now flatten out. 178 | gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims)) 179 | 180 | # Convert Gray back to binary. 181 | hh_bin = gray2binary(gray) 182 | 183 | # Pad back out to 64 bits. 184 | extra_dims = 64 - num_bits * num_dims 185 | padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0) 186 | 187 | # Convert binary values into uint8s. 188 | hh_uint8 = ( 189 | (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask) 190 | .sum(2) 191 | .squeeze() 192 | .type(torch.uint8) 193 | ) 194 | 195 | # Convert uint8s into uint64s. 196 | hh_uint64 = hh_uint8.view(torch.int64).squeeze() 197 | 198 | return hh_uint64 199 | 200 | 201 | def decode(hilberts, num_dims, num_bits): 202 | """Decode an array of Hilbert integers into locations in a hypercube. 203 | 204 | This is a vectorized-ish version of the Hilbert curve implementation by John 205 | Skilling as described in: 206 | 207 | Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference 208 | Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics. 209 | 210 | Params: 211 | ------- 212 | hilberts - An ndarray of Hilbert integers. Must be an integer dtype and 213 | cannot have fewer bits than num_dims * num_bits. 214 | 215 | num_dims - The dimensionality of the hypercube. Integer. 216 | 217 | num_bits - The number of bits for each dimension. Integer. 218 | 219 | Returns: 220 | -------- 221 | The output is an ndarray of unsigned integers with the same shape as hilberts 222 | but with an additional dimension of size num_dims. 223 | """ 224 | 225 | if num_dims * num_bits > 64: 226 | raise ValueError( 227 | """ 228 | num_dims=%d and num_bits=%d for %d bits total, which can't be encoded 229 | into a uint64. 
Are you sure you need that many points on your Hilbert 230 | curve? 231 | """ 232 | % (num_dims, num_bits) 233 | ) 234 | 235 | # Handle the case where we got handed a naked integer. 236 | hilberts = torch.atleast_1d(hilberts) 237 | 238 | # Keep around the shape for later. 239 | orig_shape = hilberts.shape 240 | bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device) 241 | bitpack_mask_rev = bitpack_mask.flip(-1) 242 | 243 | # Treat each of the hilberts as a s equence of eight uint8. 244 | # This treats all of the inputs as uint64 and makes things uniform. 245 | hh_uint8 = ( 246 | hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1) 247 | ) 248 | 249 | # Turn these lists of uints into lists of bits and then truncate to the size 250 | # we actually need for using Skilling's procedure. 251 | hh_bits = ( 252 | hh_uint8.unsqueeze(-1) 253 | .bitwise_and(bitpack_mask_rev) 254 | .ne(0) 255 | .byte() 256 | .flatten(-2, -1)[:, -num_dims * num_bits :] 257 | ) 258 | 259 | # Take the sequence of bits and Gray-code it. 260 | gray = binary2gray(hh_bits) 261 | 262 | # There has got to be a better way to do this. 263 | # I could index them differently, but the eventual packbits likes it this way. 264 | gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2) 265 | 266 | # Iterate backwards through the bits. 267 | for bit in range(num_bits - 1, -1, -1): 268 | # Iterate backwards through the dimensions. 269 | for dim in range(num_dims - 1, -1, -1): 270 | # Identify which ones have this bit active. 271 | mask = gray[:, dim, bit] 272 | 273 | # Where this bit is on, invert the 0 dimension for lower bits. 274 | gray[:, 0, bit + 1 :] = torch.logical_xor( 275 | gray[:, 0, bit + 1 :], mask[:, None] 276 | ) 277 | 278 | # Where the bit is off, exchange the lower bits with the 0 dimension. 279 | to_flip = torch.logical_and( 280 | torch.logical_not(mask[:, None]), 281 | torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]), 282 | ) 283 | gray[:, dim, bit + 1 :] = torch.logical_xor( 284 | gray[:, dim, bit + 1 :], to_flip 285 | ) 286 | gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip) 287 | 288 | # Pad back out to 64 bits. 289 | extra_dims = 64 - num_bits 290 | padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0) 291 | 292 | # Now chop these up into blocks of 8. 293 | locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8)) 294 | 295 | # Take those blocks and turn them unto uint8s. 296 | # from IPython import embed; embed() 297 | locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8) 298 | 299 | # Finally, treat these as uint64s. 300 | flat_locs = locs_uint8.view(torch.int64) 301 | 302 | # Return them in the expected shape. 
303 | return flat_locs.reshape((*orig_shape, num_dims)) 304 | -------------------------------------------------------------------------------- /noksr/utils/serialization/z_order.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Octree-based Sparse Convolutional Neural Networks 3 | # Copyright (c) 2022 Peng-Shuai Wang 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from typing import Optional, Union 10 | 11 | 12 | class KeyLUT: 13 | def __init__(self): 14 | r256 = torch.arange(256, dtype=torch.int64) 15 | r512 = torch.arange(512, dtype=torch.int64) 16 | zero = torch.zeros(256, dtype=torch.int64) 17 | device = torch.device("cpu") 18 | 19 | self._encode = { 20 | device: ( 21 | self.xyz2key(r256, zero, zero, 8), 22 | self.xyz2key(zero, r256, zero, 8), 23 | self.xyz2key(zero, zero, r256, 8), 24 | ) 25 | } 26 | self._decode = {device: self.key2xyz(r512, 9)} 27 | 28 | def encode_lut(self, device=torch.device("cpu")): 29 | if device not in self._encode: 30 | cpu = torch.device("cpu") 31 | self._encode[device] = tuple(e.to(device) for e in self._encode[cpu]) 32 | return self._encode[device] 33 | 34 | def decode_lut(self, device=torch.device("cpu")): 35 | if device not in self._decode: 36 | cpu = torch.device("cpu") 37 | self._decode[device] = tuple(e.to(device) for e in self._decode[cpu]) 38 | return self._decode[device] 39 | 40 | def xyz2key(self, x, y, z, depth): 41 | key = torch.zeros_like(x) 42 | for i in range(depth): 43 | mask = 1 << i 44 | key = ( 45 | key 46 | | ((x & mask) << (2 * i + 2)) 47 | | ((y & mask) << (2 * i + 1)) 48 | | ((z & mask) << (2 * i + 0)) 49 | ) 50 | return key 51 | 52 | def key2xyz(self, key, depth): 53 | x = torch.zeros_like(key) 54 | y = torch.zeros_like(key) 55 | z = torch.zeros_like(key) 56 | for i in range(depth): 57 | x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2)) 58 | y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1)) 59 | z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0)) 60 | return x, y, z 61 | 62 | 63 | _key_lut = KeyLUT() 64 | 65 | 66 | def xyz2key( 67 | x: torch.Tensor, 68 | y: torch.Tensor, 69 | z: torch.Tensor, 70 | b: Optional[Union[torch.Tensor, int]] = None, 71 | depth: int = 16, 72 | ): 73 | r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys 74 | based on pre-computed look up tables. The speed of this function is much 75 | faster than the method based on for-loop. 76 | 77 | Args: 78 | x (torch.Tensor): The x coordinate. 79 | y (torch.Tensor): The y coordinate. 80 | z (torch.Tensor): The z coordinate. 81 | b (torch.Tensor or int): The batch index of the coordinates, and should be 82 | smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of 83 | :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`. 84 | depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). 
85 | """ 86 | 87 | EX, EY, EZ = _key_lut.encode_lut(x.device) 88 | x, y, z = x.long(), y.long(), z.long() 89 | 90 | mask = 255 if depth > 8 else (1 << depth) - 1 91 | key = EX[x & mask] | EY[y & mask] | EZ[z & mask] 92 | if depth > 8: 93 | mask = (1 << (depth - 8)) - 1 94 | key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask] 95 | key = key16 << 24 | key 96 | 97 | if b is not None: 98 | b = b.long() 99 | key = b << 48 | key 100 | 101 | return key 102 | 103 | 104 | def key2xyz(key: torch.Tensor, depth: int = 16): 105 | r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates 106 | and the batch index based on pre-computed look up tables. 107 | 108 | Args: 109 | key (torch.Tensor): The shuffled key. 110 | depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17). 111 | """ 112 | 113 | DX, DY, DZ = _key_lut.decode_lut(key.device) 114 | x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key) 115 | 116 | b = key >> 48 117 | key = key & ((1 << 48) - 1) 118 | 119 | n = (depth + 2) // 3 120 | for i in range(n): 121 | k = key >> (i * 9) & 511 122 | x = x | (DX[k] << (i * 3)) 123 | y = y | (DY[k] << (i * 3)) 124 | z = z | (DZ[k] << (i * 3)) 125 | 126 | return x, y, z, b 127 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning==1.9.4 2 | lightning-bolts 3 | scipy 4 | open3d 5 | h5py 6 | ninja 7 | wandb 8 | hydra-core 9 | cython 10 | scikit-image 11 | trimesh 12 | arrgh 13 | plyfile 14 | imageio-ffmpeg 15 | gin-config 16 | torchviz 17 | thop 18 | imageio 19 | einops 20 | spconv-cu118 21 | timm 22 | flash-attn==2.6.3 23 | -------------------------------------------------------------------------------- /scripts/segment_carla.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from pathlib import Path 5 | import shutil 6 | from noksr.utils.serialization import encode 7 | from tqdm import tqdm 8 | 9 | def _serial_scene(scene, segment_length, segment_num, grid_size=0.01, serial_order='z'): 10 | # Process ground truth points 11 | in_xyz = torch.from_numpy(scene['in_xyz']).to(torch.float32) 12 | in_quant_coords = torch.floor(in_xyz / grid_size).to(torch.int) 13 | gt_xyz = torch.from_numpy(scene['gt_xyz']).to(torch.float32) 14 | gt_quant_coords = torch.floor(gt_xyz / grid_size).to(torch.int) 15 | 16 | depth = int(max(gt_quant_coords.max(), in_quant_coords.max())).bit_length() 17 | gt_quant_codes = encode(gt_quant_coords, torch.zeros(gt_quant_coords.shape[0], dtype=torch.int64, device=gt_quant_coords.device), depth, order=serial_order) 18 | gt_sorted_quant_codes, gt_sorted_indices = torch.sort(gt_quant_codes) 19 | 20 | total_gt_points = len(gt_sorted_quant_codes) 21 | 22 | in_quant_codes = encode(in_quant_coords, torch.zeros(in_quant_coords.shape[0], dtype=torch.int64, device=in_quant_coords.device), depth, order=serial_order) 23 | in_sorted_quant_codes, in_sorted_indices = torch.sort(in_quant_codes) 24 | 25 | # Segment ground truth points and find corresponding input points 26 | segments = [] 27 | for i in range(segment_num): 28 | # Ground truth segmentation 29 | gt_start_idx = i * segment_length 30 | gt_end_idx = gt_start_idx + segment_length 31 | gt_end_idx = min(gt_end_idx, total_gt_points) 32 | gt_segment_indices = gt_sorted_indices[gt_start_idx:gt_end_idx] 33 | 34 | # Input segmentation 
within the ground truth range 35 | in_segment_indices = in_sorted_indices[(in_sorted_quant_codes >= gt_sorted_quant_codes[gt_start_idx].item()) & 36 | (in_sorted_quant_codes < gt_sorted_quant_codes[min(gt_end_idx, total_gt_points - 1)].item())] 37 | 38 | segments.append({ 39 | "gt_indices": gt_segment_indices.numpy(), 40 | "in_indices": in_segment_indices.numpy() 41 | }) 42 | 43 | return segments 44 | 45 | def update_lst_files(input_drive_path, output_drive_path, segmented_scenes): 46 | """ 47 | Update .lst files in the output directory by expanding entries for segmented scenes. 48 | 49 | Args: 50 | input_drive_path (Path): Path to the input drive directory. 51 | output_drive_path (Path): Path to the output drive directory. 52 | segmented_scenes (dict): Dictionary of original scenes to their segment counts (e.g., {"rgn0000": 12}). 53 | """ 54 | lst_files = ['test.lst', 'testall.lst', 'train.lst', 'val.lst'] 55 | 56 | for lst_file in lst_files: 57 | input_lst_file = input_drive_path / lst_file 58 | output_lst_file = output_drive_path / lst_file 59 | 60 | if not input_lst_file.exists(): 61 | continue # Skip if the .lst file doesn't exist in the input 62 | 63 | with input_lst_file.open('r') as infile, output_lst_file.open('w') as outfile: 64 | for line in infile: 65 | scene_name = line.strip() # Original scene name without extension 66 | if scene_name in segmented_scenes: 67 | # Expand scene into its segments 68 | num_segments = segmented_scenes[scene_name] 69 | for i in range(num_segments): 70 | outfile.write(f"{scene_name.split('-')[0]}-crop{(i):04d}\n") 71 | else: 72 | # Write the scene as-is if it's not segmented 73 | outfile.write(line) 74 | 75 | def process_dataset(input_path, output_path, fixed_segment_length): 76 | input_base = Path(input_path) 77 | output_base = Path(output_path) 78 | output_base.mkdir(parents=True, exist_ok=True) 79 | target_drives = ['Town01-0', 'Town01-1', 'Town01-2', 80 | 'Town02-0', 'Town02-1', 'Town02-2', 81 | 'Town10-0', 'Town10-1', 'Town10-2', 'Town10-3', 'Town10-4'] 82 | 83 | # Copy list files (e.g., test.lst, val.lst) 84 | for file in input_base.glob('*.lst'): 85 | shutil.copy(file, output_base / file.name) 86 | 87 | # Count total number of scenes for progress bar initialization 88 | total_scenes = sum( 89 | 1 90 | for drive in input_base.iterdir() 91 | if drive.is_dir() and drive.name in target_drives 92 | for item in drive.iterdir() 93 | if item.is_dir() 94 | ) 95 | 96 | # Process each drive folder with progress bar 97 | with tqdm(total=total_scenes, desc="Processing Scenes", unit="scene") as pbar: 98 | for drive in input_base.iterdir(): 99 | if drive.is_dir() and drive.name in target_drives: 100 | drive_output = output_base / drive.name 101 | drive_output.mkdir(parents=True, exist_ok=True) 102 | segmented_scenes = {} 103 | 104 | for item in drive.iterdir(): 105 | if item.is_dir(): 106 | scene_name = item.name 107 | output_scene_base = drive_output 108 | 109 | # Load data 110 | data_file = item / 'pointcloud.npz' 111 | gt_file = item / 'groundtruth.bin' 112 | gt_data = np.load(gt_file, allow_pickle=True) 113 | data = np.load(data_file) 114 | scene = { 115 | 'in_xyz': data['points'], 116 | 'in_normal': data['normals'], 117 | 'gt_xyz': gt_data['xyz'], 118 | 'gt_normal': gt_data['normal'] 119 | } 120 | 121 | total_gt_points = len(scene['gt_xyz']) 122 | 123 | # Step 1: Compute segment number 124 | segment_num = max(1, total_gt_points // fixed_segment_length) 125 | 126 | # Step 2: Recompute segment length for even distribution 127 | segment_length = 
total_gt_points // segment_num 128 | 129 | # Serialize and segment scene 130 | segments = _serial_scene(scene, segment_length, segment_num) 131 | segmented_scenes[scene_name] = len(segments) # Record segment count for this scene 132 | for i, segment in enumerate(segments): 133 | segment_scene_name = f"{scene_name.split('-')[0]}-crop{(i):04d}" 134 | segment_output = output_scene_base / segment_scene_name 135 | segment_output.mkdir(parents=True, exist_ok=True) 136 | 137 | # Save segmented data 138 | segment_data = { 139 | 'points': scene['in_xyz'][segment["in_indices"]], 140 | 'normals': scene['in_normal'][segment["in_indices"]], 141 | 'ref_xyz': scene['gt_xyz'][segment["gt_indices"]], 142 | 'ref_normals': scene['gt_normal'][segment["gt_indices"]] 143 | } 144 | np.savez(segment_output / 'pointcloud.npz', **segment_data) 145 | 146 | # Update progress bar 147 | pbar.update(1) 148 | 149 | update_lst_files(drive, drive_output, segmented_scenes) 150 | 151 | if __name__ == "__main__": 152 | """ 153 | This script is used to regenerate the training segments by 1-d serialization, original Carla patches are uniformly sampled. 154 | """ 155 | input_path = "./data/carla-lidar/dataset-no-patch" # Replace with your input dataset path 156 | output_path = "./data/carla-lidar/dataset-seg-patch" # Replace with your desired output path 157 | fixed_segment_length = 300000 # Fixed segment length to start computation 158 | 159 | process_dataset(input_path, output_path, fixed_segment_length) -------------------------------------------------------------------------------- /train.log: -------------------------------------------------------------------------------- 1 | [2025-01-02 15:44:07,239][pycg.exp][WARNING] - Customized build of Open3D is not detected, to resolve this you can do: 2 | (recommended, using customized Open3D that enables view sync, animation, ...) 3 | 4 | pip install python-pycg[full] -f https://pycg.s3.ap-northeast-1.amazonaws.com/packages/index.html 5 | 6 | 7 | [2025-01-02 15:44:10,471][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 0 8 | [2025-01-02 15:44:10,471][torch.distributed.distributed_c10d][INFO] - Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes. 9 | [2025-01-02 15:46:55,468][torch.nn.parallel.distributed][INFO] - Reducer buckets have been rebuilt in this iteration. 10 | [2025-01-02 15:53:14,873][pycg.exp][WARNING] - Customized build of Open3D is not detected, to resolve this you can do: 11 | (recommended, using customized Open3D that enables view sync, animation, ...) 12 | 13 | pip install python-pycg[full] -f https://pycg.s3.ap-northeast-1.amazonaws.com/packages/index.html 14 | 15 | 16 | [2025-01-02 15:53:18,417][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 0 17 | [2025-01-02 15:53:18,417][torch.distributed.distributed_c10d][INFO] - Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes. 18 | [2025-01-02 15:53:22,497][torch.nn.parallel.distributed][INFO] - Reducer buckets have been rebuilt in this iteration. 19 | [2025-01-02 15:53:32,911][pycg.exp][WARNING] - Customized build of Open3D is not detected, to resolve this you can do: 20 | (recommended, using customized Open3D that enables view sync, animation, ...) 
21 | 22 | pip install python-pycg[full] -f https://pycg.s3.ap-northeast-1.amazonaws.com/packages/index.html 23 | 24 | 25 | [2025-01-02 15:53:36,194][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 0 26 | [2025-01-02 15:53:36,194][torch.distributed.distributed_c10d][INFO] - Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes. 27 | [2025-01-02 15:53:41,534][torch.nn.parallel.distributed][INFO] - Reducer buckets have been rebuilt in this iteration. 28 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hydra 3 | import wandb 4 | import pytorch_lightning as pl 5 | from noksr.callback import * 6 | from importlib import import_module 7 | from noksr.data.data_module import DataModule 8 | from pytorch_lightning.callbacks import LearningRateMonitor 9 | from pytorch_lightning.strategies.ddp import DDPStrategy 10 | 11 | 12 | 13 | def init_callbacks(cfg): 14 | checkpoint_monitor = hydra.utils.instantiate(cfg.model.checkpoint_monitor) 15 | gpu_cache_clean_monitor = GPUCacheCleanCallback() 16 | lr_monitor = LearningRateMonitor(logging_interval="epoch") 17 | return [checkpoint_monitor, gpu_cache_clean_monitor, lr_monitor] 18 | 19 | 20 | @hydra.main(version_base=None, config_path="config", config_name="config") 21 | def main(cfg): 22 | # fix the seed 23 | pl.seed_everything(cfg.global_train_seed, workers=True) 24 | 25 | output_path = os.path.join(cfg.exp_output_root_path, "training") 26 | os.makedirs(output_path, exist_ok=True) 27 | 28 | print("==> initializing data ...") 29 | data_module = DataModule(cfg) 30 | 31 | print("==> initializing logger ...") 32 | logger = hydra.utils.instantiate(cfg.model.logger, save_dir=output_path) 33 | 34 | print("==> initializing monitor ...") 35 | callbacks = init_callbacks(cfg) 36 | 37 | 38 | print("==> initializing trainer ...") 39 | trainer = pl.Trainer(callbacks=callbacks, logger=logger, **cfg.model.trainer, strategy=DDPStrategy(find_unused_parameters=False)) 40 | 41 | print("==> initializing model ...") 42 | model = getattr(import_module("noksr.model"), cfg.model.network.module)(cfg) 43 | 44 | 45 | print("==> start training ...") 46 | trainer.fit(model=model, datamodule=data_module, ckpt_path=cfg.model.ckpt_path) 47 | 48 | if __name__ == '__main__': 49 | main() -------------------------------------------------------------------------------- /wandb/debug-cli.zla247.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/wandb/debug-cli.zla247.log -------------------------------------------------------------------------------- /wandb/settings: -------------------------------------------------------------------------------- 1 | [default] 2 | 3 | --------------------------------------------------------------------------------
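Usage note (not part of the repository files above): the cosine_lr_decay helper in noksr/utils/optimizer.py is designed to be called once per epoch after the optimizer step. The sketch below shows how it behaves, assuming the repository root is on PYTHONPATH; the model, base learning rate, and epoch counts are hypothetical.

import torch
from noksr.utils.optimizer import cosine_lr_decay

model = torch.nn.Linear(8, 1)                        # hypothetical model
base_lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)

total_epochs, start_epoch, clip = 100, 50, 1e-6      # decay begins at epoch 50, floor at 1e-6
for epoch in range(total_epochs):
    # ... one training epoch would run here ...
    cosine_lr_decay(optimizer, base_lr, epoch, start_epoch, total_epochs, clip)
    # lr stays at base_lr until start_epoch, then follows a half cosine from base_lr down to clip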