├── .vscode
│   └── settings.json
├── LICENSE
├── README.md
├── assets
│   ├── Teaser.pdf
│   └── Teaser.png
├── checkpoints
├── config
│   ├── config.yaml
│   ├── data
│   │   ├── base.yaml
│   │   ├── carla_novel.yaml
│   │   ├── carla_original.yaml
│   │   ├── carla_patch.yaml
│   │   ├── carla_seg_patch.yaml
│   │   ├── mixture.yaml
│   │   ├── scannet.yaml
│   │   ├── scenenn.yaml
│   │   └── synthetic.yaml
│   └── model
│       ├── base.yaml
│       ├── carla_model.yaml
│       ├── scannet_model.yaml
│       ├── scenenn_model.yaml
│       └── synthetic_model.yaml
├── data
├── eval.log
├── eval.py
├── noksr
│   ├── callback
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   └── gpu_cache_clean_callback.cpython-310.pyc
│   │   └── gpu_cache_clean_callback.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   └── data_module.cpython-310.pyc
│   │   ├── data_module.py
│   │   └── dataset
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-310.pyc
│   │       │   ├── __init__.cpython-38.pyc
│   │       │   ├── augmentation.cpython-310.pyc
│   │       │   ├── augmentation.cpython-38.pyc
│   │       │   ├── carla.cpython-310.pyc
│   │       │   ├── carla_gt_geometry.cpython-310.pyc
│   │       │   ├── general_dataset.cpython-310.pyc
│   │       │   ├── general_dataset.cpython-38.pyc
│   │       │   ├── mixture.cpython-310.pyc
│   │       │   ├── multiscan.cpython-38.pyc
│   │       │   ├── scannet.cpython-310.pyc
│   │       │   ├── scannet.cpython-38.pyc
│   │       │   ├── scannet_rangeudf.cpython-310.pyc
│   │       │   ├── scannet_rangeudf.cpython-38.pyc
│   │       │   ├── scenenet.cpython-310.pyc
│   │       │   ├── scenenn.cpython-310.pyc
│   │       │   ├── shapenet.cpython-310.pyc
│   │       │   ├── synthetic.cpython-310.pyc
│   │       │   ├── voxelization_utils.cpython-310.pyc
│   │       │   ├── voxelization_utils.cpython-38.pyc
│   │       │   ├── voxelizer.cpython-310.pyc
│   │       │   └── voxelizer.cpython-38.pyc
│   │       ├── augmentation.py
│   │       ├── carla.py
│   │       ├── carla_gt_geometry.py
│   │       ├── general_dataset.py
│   │       ├── mixture.py
│   │       ├── scannet.py
│   │       ├── scenenn.py
│   │       └── synthetic.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── general_model.cpython-310.pyc
│   │   │   ├── noksr_net.cpython-310.pyc
│   │   │   └── pcs4esr_net.cpython-310.pyc
│   │   ├── general_model.py
│   │   ├── module
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── backbone.cpython-310.pyc
│   │   │   │   ├── backbone.cpython-38.pyc
│   │   │   │   ├── backbone_nocs.cpython-38.pyc
│   │   │   │   ├── common.cpython-310.pyc
│   │   │   │   ├── common.cpython-38.pyc
│   │   │   │   ├── decoder.cpython-310.pyc
│   │   │   │   ├── decoder.cpython-38.pyc
│   │   │   │   ├── decoder2.cpython-38.pyc
│   │   │   │   ├── encoder.cpython-310.pyc
│   │   │   │   ├── encoder.cpython-38.pyc
│   │   │   │   ├── generation.cpython-310.pyc
│   │   │   │   ├── generation.cpython-38.pyc
│   │   │   │   ├── kp_decoder.cpython-310.pyc
│   │   │   │   ├── larger_decoder.cpython-310.pyc
│   │   │   │   ├── point_transformer.cpython-310.pyc
│   │   │   │   ├── tiny_unet.cpython-38.pyc
│   │   │   │   ├── visualization.cpython-310.pyc
│   │   │   │   └── visualization.cpython-38.pyc
│   │   │   ├── decoder.py
│   │   │   ├── generation.py
│   │   │   └── point_transformer.py
│   │   └── noksr_net.py
│   └── utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-310.pyc
│       │   ├── evaluation.cpython-310.pyc
│       │   ├── optimizer.cpython-310.pyc
│       │   ├── samples.cpython-310.pyc
│       │   ├── segmentation.cpython-310.pyc
│       │   └── transform.cpython-310.pyc
│       ├── evaluation.py
│       ├── optimizer.py
│       ├── samples.py
│       ├── segmentation.py
│       ├── serialization
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   ├── __init__.cpython-310.pyc
│       │   │   ├── default.cpython-310.pyc
│       │   │   ├── hilbert.cpython-310.pyc
│       │   │   └── z_order.cpython-310.pyc
│       │   ├── default.py
│       │   ├── hilbert.py
│       │   └── z_order.py
│       └── transform.py
├── requirements.txt
├── scripts
│   └── segment_carla.py
├── train.log
├── train.py
└── wandb
    ├── debug-cli.zla247.log
    └── settings
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "workbench.colorCustomizations": {
3 | "activityBar.background": "#0B2379",
4 | "titleBar.activeBackground": "#1031AA",
5 | "titleBar.activeForeground": "#FBFCFF"
6 | }
7 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NoKSR: Kernel-Free Neural Surface Reconstruction via Point Cloud Serialization
2 |
3 | 
4 |
5 | **NoKSR: Kernel-Free Neural Surface Reconstruction via Point Cloud Serialization**
6 | [Zhen Li *](https://colinzhenli.github.io/), [Weiwei Sun *†](https://wsunid.github.io/), [Shrisudhan Govindarajan](https://shrisudhan.github.io/), [Shaobo Xia](https://scholar.google.com/citations?user=eOPO9E0AAAAJ&hl=en), [Daniel Rebain](http://drebain.com/), [Kwang Moo Yi](https://www.cs.ubc.ca/~kmyi/), [Andrea Tagliasacchi](https://theialab.ca/)
7 | **[Paper](https://arxiv.org/abs/2502.12534), [Project Page](https://theialab.github.io/noksr/)**
8 |
9 | Abstract: We present a novel approach to large-scale point cloud
10 | surface reconstruction by developing an efficient framework
11 | that converts an irregular point cloud into a signed distance
12 | field (SDF). Our backbone builds upon recent transformer-based
13 | architectures (i.e. PointTransformerV3) that serialize the
14 | point cloud into a locality-preserving sequence of tokens. We
15 | efficiently predict the SDF value at a point by aggregating
16 | nearby tokens, where fast approximate neighbors can be
17 | retrieved thanks to the serialization. We serialize the point
18 | cloud at different levels/scales, and non-linearly aggregate
19 | a feature to predict the SDF value. We show that aggregating
20 | across multiple scales is critical to overcome the
21 | approximations introduced by the serialization (i.e. false
22 | negatives in the neighborhood). Our framework sets the new
23 | state of the art in terms of accuracy and efficiency (better
24 | or similar performance with half the latency of the best prior
25 | method, coupled with a simpler implementation), particularly
26 | on outdoor datasets where sparse-grid methods have shown
27 | limited performance.
28 |
29 | Contact [Zhen Li @ SFU](mailto:zla247@sfu.ca) for questions, comments, and bug reports.
30 |
31 | ## Related Package
32 |
33 | We implemented the fast approximate neighbor search algorithm in the standalone pip package [`serial-neighbor`](https://pypi.org/project/serial-neighbor/), which provides fast and flexible point cloud neighbor search using serialization codes from space-filling curves (Z-order, Hilbert, etc.).
34 |
35 | - **📦 PyPI**: [`serial-neighbor`](https://pypi.org/project/serial-neighbor/)
36 | - **🔗 GitHub**: [https://github.com/colinzhenli/serial-neighbor](https://github.com/colinzhenli/serial-neighbor)
37 |
38 | You can install it via:
39 |
40 | ```bash
41 | pip install serial-neighbor
42 | ```
43 |
44 | ## News
45 |
46 | - [2025/03/22] The package [`serial-neighbor`](https://pypi.org/project/serial-neighbor/) is released.
47 | - [2025/02/21] The code is released.
48 | - [2025/02/19] The arXiv version is released.
49 |
50 | ## Environment setup
51 |
52 | The code is tested on Ubuntu 20.04 LTS with PyTorch 2.0.0 and CUDA 11.8. Follow the steps below to install PyTorch first.
53 |
54 | ```bash
55 | # Clone the repository
56 | git clone https://github.com/theialab/noksr.git
57 | cd noksr
58 |
59 | # create and activate the conda environment
60 | conda create -n noksr python=3.10
61 | conda activate noksr
62 |
63 | # install PyTorch 2.0.0 with CUDA 11.8
64 | conda install pytorch==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia
65 |
66 | ```
67 | Then, install PyTorch3D:
68 | ```bash
69 | # install runtime dependencies for PyTorch3D
70 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath
71 | conda install -c bottler nvidiacub
72 |
73 | # install PyTorch3D
74 | conda install pytorch3d -c pytorch3d
75 | ```
76 |
77 | Install the necessary packages listed in `requirements.txt`:
78 | ```bash
79 | pip install -r requirements.txt
80 | ```
81 |
82 | Install torch-scatter and nksr:
83 | ```bash
84 | pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
85 | pip install nksr -f https://nksr.huangjh.tech/whl/torch-2.0.0+cu118.html
86 | ```
87 |
88 | The detailed installation instructions for nksr can be found in the [NKSR](https://github.com/nv-tlabs/nksr) repository.
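
After installation, a quick optional sanity check can confirm that the core dependencies import correctly and that CUDA is visible (the exact version strings depend on your setup):
```bash
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "import pytorch3d, torch_scatter, nksr; print('pytorch3d, torch-scatter and nksr imported successfully')"
```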
89 |
90 | ## Reproducing results from the paper
91 |
92 | ### Data Preparation
93 |
94 | You can download the data from the following links and put it under `NoKSR/data/`. The expected directory layout is sketched after the list.
95 | - ScanNet:
96 | Data is available [here](https://drive.google.com/drive/folders/1JK_6T61eQ07_y1bi1DD9Xj-XRU0EDKGS?usp=sharing).
97 | We converted the original meshes to `.pth` data, and the normals are generated using [open3d.geometry.TriangleMesh](https://www.open3d.org/html/python_api/open3d.geometry.TriangleMesh.html). The processing pipeline from the raw ScanNetv2 data follows [minsu3d](https://github.com/3dlg-hcvc/minsu3d).
98 |
99 | - SceneNN:
100 | Data is available [here](https://drive.google.com/file/d/1d_ILfaxpJBpiiwCZtvC4jEKnixEr9N2l/view?usp=sharing).
101 |
102 | - SyntheticRoom:
103 | Data is available [here](https://drive.google.com/drive/folders/1PosV8qyXCkjIHzVjPeOIdhCLigpXXDku?usp=sharing); it comes from [ConvONet](https://github.com/autonomousvision/convolutional_occupancy_networks), which contains the processing details.
104 |
105 | - CARLA:
106 | Data is available [here](https://drive.google.com/file/d/1BFwExw7SRJaqHJ98pqqnR-k6g8XYMAqq/view?usp=sharing); it comes from [NKSR](https://github.com/nv-tlabs/nksr).
107 |
108 |
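For reference, the default dataset paths in `config/data/*.yaml` assume the downloads are extracted under `data/` roughly as follows (directory names are taken from those config files; adjust `data.dataset_root_path` or the per-dataset paths if your layout differs):
```
data
├── scannetv2/                                # ScanNet (.pth scenes + metadata/)
├── scenenn_seg_76_raw/scenenn_sub_data/      # SceneNN
├── synthetic_data/synthetic_room_dataset/    # SyntheticRoom
└── carla-lidar/                              # CARLA (dataset*, dataset-p1n2*)
```
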
109 | ### Training
110 | Note: Configuration files are managed by [Hydra](https://hydra.cc/); you can easily add or override any configuration attribute by passing it as a command-line argument.
111 | ```shell
112 | # log in to WandB
113 | wandb login
114 |
115 | # train a model from scratch
116 | # ScanNet dataset
117 | python train.py model=scannet_model data=scannet
118 | # SyntheticRoom dataset
119 | python train.py model=synthetic_model data=synthetic
120 | # CARLA dataset
121 | python train.py model=carla_model data=carla_patch
122 | ```
123 |
124 | In addition, you can manually specify different training settings; an example command follows the list. Common flags include:
125 | - `experiment_name`: Additional experiment name to specify.
126 | - `data.dataset_root_path`: Root path of the dataset.
127 | - `output_folder`: Output folder to save the results; checkpoints will be saved in `output/{dataset_name}/{experiment_name}/training`.
128 | - `model.network.default_decoder.neighboring`: Neighboring type, default is `Serial`. Options: `Serial`, `KNN`, `Mixture`
129 |
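For example, the following command (paths are illustrative) trains the ScanNet model with a custom experiment name, dataset root, and KNN neighboring instead of the default serialization-based neighboring:
```bash
python train.py model=scannet_model data=scannet \
  experiment_name=scannet_knn_run \
  data.dataset_root_path=/path/to/data \
  model.network.default_decoder.neighboring=KNN
```
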
130 | ### Inference
131 |
132 | You can run inference either with your own trained models or with our pre-trained checkpoints.
133 |
134 | The pre-trained checkpoints for the different datasets and neighboring types are available [here](https://drive.google.com/file/d/1hMm5cnCOfNmr_PgkpOmwRnzCCG4wPqnu/view?usp=drive_link); download them and put them under `noksr/checkpoints/`.
135 |
136 | ```bash
137 | # For example, the CARLA original dataset with Serial neighboring. Inference on the CARLA dataset needs more than 24 GB of GPU memory; we recommend using a server.
138 | python eval.py model=carla_model data=carla_original model.ckpt_path=checkpoints/Carla_Serial_best.ckpt
139 | # For example, Carla model with Laplacian loss
140 | python eval.py model=carla_model data=carla_original model.ckpt_path=checkpoints/Carla_Laplacian_best.ckpt
141 | # For example, ScanNet dataset with Serialization neighboring
142 | python eval.py model=scannet_model data=scannet model.ckpt_path=checkpoints/ScanNet_Serial_best.ckpt model.inference.split=val
143 | # For example, test on the SceneNN dataset with a model trained on ScanNet.
144 | python eval.py model=scenenn_model data=scenenn model.ckpt_path=checkpoints/ScanNet_KNN_best.ckpt
145 | ```
146 | In addition, on the CARLA dataset you can enable reconstruction by segments. This option is slower but saves a lot of memory. Flags include (an example command follows the list):
147 | - `data.reconstruction.by_segment=True`: Enable reconstruction from segments.
148 | - `data.reconstruction.segment_num=10`: Number of segments.
149 |
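For example, a segment-based CARLA reconstruction run (using the checkpoint from the examples above) could look like:
```bash
python eval.py model=carla_model data=carla_original \
  model.ckpt_path=checkpoints/Carla_Serial_best.ckpt \
  data.reconstruction.by_segment=True data.reconstruction.segment_num=10
```
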
150 | ### Reconstruction
151 | You can reconstruct a specific scene from the datasets above by specifying the scene index.
152 | ```bash
153 | # For example, the CARLA dataset; 0 can be replaced by any other validation-set scene index
154 | python eval.py model=carla_model data=carla_original model.ckpt_path={path_to_checkpoint} data.over_fitting=True data.take=1 data.intake_start=0
155 | # For example, the ScanNet dataset; 308 can be replaced by any other validation-set scene index
156 | python eval.py model=scannet_model data=scannet model.ckpt_path={path_to_checkpoint} data.over_fitting=True data.take=1 data.intake_start=308 model.inference.split=val
157 |
158 | ```
159 | In addition, you can manually specify visualization settings; an example command is shown at the end of this section. Flags include:
160 | - `data.visualization.save=True`: Whether to save the results.
161 | - `data.visualization.Mesh=True`: Whether to save the reconstructed mesh.
162 | - `data.visualization.Input_points=True`: Whether to save the input points.
163 |
164 | The results will be saved in `output/{dataset_name}/{experiment_name}/reconstruction/visualization`.
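
For example, to reconstruct a single ScanNet validation scene and save both the mesh and the input points (`{path_to_checkpoint}` is a placeholder):
```bash
python eval.py model=scannet_model data=scannet model.ckpt_path={path_to_checkpoint} \
  data.over_fitting=True data.take=1 data.intake_start=308 model.inference.split=val \
  data.visualization.save=True data.visualization.Mesh=True data.visualization.Input_points=True
```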
165 |
166 | ## Acknowledgement
167 | This work was supported in part by the Natural Sciences and Engineering Research Council of Canada (NSERC) Discovery Grant, NSERC Collaborative Research and Development Grant, Google DeepMind, Digital Research Alliance of Canada, the Advanced Research Computing at the University of British Columbia, and the SFU Visual Computing Research Chair program. Shaobo Xia was supported
168 | by the National Natural Science Foundation of China under Grant 42201481. We would also like to thank Jiahui Huang for valuable discussions and feedback.
169 |
170 | ## BibTex
171 | If you find our work useful in your research, please consider citing:
172 | ```bibtex
173 | @inproceedings{li2025noksrkernelfreeneuralsurface,
174 | author = {Zhen Li and Weiwei Sun and Shrisudhan Govindarajan and Shaobo Xia and Daniel Rebain and Kwang Moo Yi and Andrea Tagliasacchi},
175 | title = {NoKSR: Kernel-Free Neural Surface Reconstruction via Point Cloud Serialization},
176 | year = {2025},
177 | booktitle = {International Conference on 3D Vision (3DV)},
178 | }
179 | ```
180 |
--------------------------------------------------------------------------------
/assets/Teaser.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/assets/Teaser.pdf
--------------------------------------------------------------------------------
/assets/Teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/assets/Teaser.png
--------------------------------------------------------------------------------
/checkpoints:
--------------------------------------------------------------------------------
1 | /localhome/zla247/theia2_data/output/checkpoints
--------------------------------------------------------------------------------
/config/config.yaml:
--------------------------------------------------------------------------------
1 | # Managed by Hydra
2 |
3 | hydra:
4 | output_subdir: null
5 | run:
6 | dir: .
7 |
8 | defaults:
9 | - _self_
10 | - data: base
11 | - model: base
12 |
13 | experiment_name: run_1
14 | output_folder: output
15 | exp_output_root_path: ${output_folder}/${data.dataset}/${experiment_name}
16 |
17 | global_train_seed: 12345
18 | global_test_seed: 32102
--------------------------------------------------------------------------------
/config/data/base.yaml:
--------------------------------------------------------------------------------
1 | # Managed by Hydra
2 |
3 | dataset_root_path: data
4 |
5 | batch_size: 1
6 | num_workers: 4
7 | voxel_size: 0.02 # only used to compute neighboring radius
8 |
9 | supervision:
10 | only_l1_sdf_loss: false
11 | structure_weight: 20.0
12 |
13 | gt_type: "PointTSDFVolume"
14 |
15 | on_surface:
16 | normal_loss: False # supervised normal by nksr style loss
17 | weight: 200.0
18 | normal_weight: 100.0
19 | subsample: 10000
20 | svh_tree_depth: 4
21 |
22 | sdf:
23 | max_dist: 0.2
24 | weight: 300.0
25 | reg_sdf_weight: 0.0
26 | svh_tree_depth: 4
27 | samplers:
28 | - type: "uniform"
29 | n_samples: 10000
30 | expand: 1
31 | expand_top: 3
32 | - type: "band"
33 | n_samples: 10000
34 | eps: 0.5 # Times voxel size.
35 |
36 | truncate: False # whether to truncate the SDF values
37 | gt_type: "l1" # or 'l1'
38 | gt_soft: true
39 | gt_band: 1.0 # times voxel size.
40 | pd_transform: true
41 | # (For AV Supervision)
42 | vol_sup: true
43 |
44 | udf:
45 | max_dist: 0.2
46 | abs_sdf: True
47 | weight: 150.0
48 | svh_tree_depth: 4
49 | samplers:
50 | - type: "uniform"
51 | n_samples: 10000
52 | expand: 1
53 | expand_top: 5
54 | - type: "band"
55 | n_samples: 10000
56 | eps: 0.5 # Times voxel size.
57 |
58 | eikonal:
59 | loss: False
60 | flip: True # flip the gradient sign
61 | weight: 10.0
62 |
63 | laplacian:
64 | loss: False
65 | weight: 0.0
66 |
67 | num_input_points: 10000
68 | uniform_sampling: True
69 | input_splats: False
70 | num_query_points: 50000
71 | std_dev: 0.00 # noise to add to the input points
72 | in_memory: True # whether precompute voxel indices and load voxel into memory
73 |
74 | take: -1 # how many data to take for training and validation, -1 means all
75 | intake_start: 0
76 | over_fitting: False # whether to use only one voxel for training and validation
77 |
78 | reconstruction:
79 | by_segment: False
80 | segment_num: 10
81 | trim: True
82 | gt_mask: False
83 | gt_sdf: False
84 |
85 | visualization:
86 | save: False
87 | Mesh: True
88 | Input_points: True
89 | Dense_points: False
90 |
91 |
92 |
--------------------------------------------------------------------------------
/config/data/carla_novel.yaml:
--------------------------------------------------------------------------------
1 | # Managed by Hydra
2 |
3 | defaults:
4 | - base
5 |
6 | dataset: Carla
7 |
8 | base_path: ${data.dataset_root_path}/carla-lidar/dataset-no-patch
9 | input_path: ${data.dataset_root_path}/carla-lidar/dataset-p1n2-no-patch
10 |
11 | drives: ['Town03-0', 'Town03-1', 'Town03-2']
12 |
13 | voxel_size: 0.1
14 |
15 | transforms: []
16 |
17 | supervision:
18 | sdf:
19 | max_dist: 0.6
20 |
21 | udf:
22 | max_dist: 0.6
23 |
24 | reconstruction:
25 | mask_threshold: 0.1
26 |
27 | evaluation:
28 | evaluator: "MeshEvaluator" # align evaluation with NKSR
--------------------------------------------------------------------------------
/config/data/carla_original.yaml:
--------------------------------------------------------------------------------
1 | # Managed by Hydra
2 |
3 | defaults:
4 | - base
5 |
6 | dataset: Carla
7 |
8 | base_path: ${data.dataset_root_path}/carla-lidar/dataset-no-patch
9 | input_path: ${data.dataset_root_path}/carla-lidar/dataset-p1n2-no-patch
10 |
11 | drives: ['Town01-0', 'Town01-1', 'Town01-2',
12 | 'Town02-0', 'Town02-1', 'Town02-2',
13 | 'Town10-0', 'Town10-1', 'Town10-2', 'Town10-3', 'Town10-4']
14 | voxel_size: 0.1
15 |
16 | transforms: []
17 |
18 | supervision:
19 | sdf:
20 | max_dist: 0.6
21 |
22 | udf:
23 | max_dist: 0.6
24 |
25 | reconstruction:
26 | mask_threshold: 0.1
27 |
28 | evaluation:
29 | evaluator: "MeshEvaluator" # align evaluation with NKSR
--------------------------------------------------------------------------------
/config/data/carla_patch.yaml:
--------------------------------------------------------------------------------
1 | # Managed by Hydra
2 |
3 | defaults:
4 | - base
5 |
6 | dataset: Carla
7 |
8 | base_path: ${data.dataset_root_path}/carla-lidar/dataset
9 | input_path: ${data.dataset_root_path}/carla-lidar/dataset-p1n2
10 |
11 | drives: ['Town01-0', 'Town01-1', 'Town01-2',
12 | 'Town02-0', 'Town02-1', 'Town02-2',
13 | 'Town10-0', 'Town10-1', 'Town10-2', 'Town10-3', 'Town10-4']
14 |
15 | voxel_size: 0.1
16 |
17 | transforms: []
18 |
19 | supervision:
20 | sdf:
21 | max_dist: 0.6
22 |
23 | udf:
24 | max_dist: 0.6
25 |
26 | reconstruction:
27 | mask_threshold: 0.1
28 |
29 | evaluation:
30 | evaluator: "MeshEvaluator" # align evaluation with NKSR
--------------------------------------------------------------------------------
/config/data/carla_seg_patch.yaml:
--------------------------------------------------------------------------------
1 | # Managed by Hydra
2 |
3 | defaults:
4 | - base
5 |
6 | dataset: Carla
7 |
8 | base_path: ${data.dataset_root_path}/carla-lidar/dataset-seg-patch
9 | input_path:
10 |
11 | drives: ['Town01-0', 'Town01-1', 'Town01-2',
12 | 'Town02-0', 'Town02-1', 'Town02-2',
13 | 'Town10-0', 'Town10-1', 'Town10-2', 'Town10-3', 'Town10-4']
14 |
15 | voxel_size: 0.1
16 |
17 | transforms: []
18 |
19 | supervision:
20 | sdf:
21 | max_dist: 0.6
22 |
23 | udf:
24 | max_dist: 0.6
25 |
26 | reconstruction:
27 | mask_threshold: 0.1
28 |
29 | evaluation:
30 | evaluator: "MeshEvaluator" # align evaluation with NKSR
--------------------------------------------------------------------------------
/config/data/mixture.yaml:
--------------------------------------------------------------------------------
1 | # Managed by Hydra
2 |
3 | defaults:
4 | - base
5 |
6 | dataset: Mixture
7 |
8 | validation_set: Synthetic # ScanNet, Synthetic, or both
9 | metadata: # only for ScanNet
10 | metadata_path: ${data.dataset_root_path}/scannetv2/metadata
11 | train_list: ${data.dataset_root_path}/scannetv2/metadata/scannetv2_train.txt
12 | val_list: ${data.dataset_root_path}/scannetv2/metadata/scannetv2_val.txt
13 | test_list: ${data.dataset_root_path}/scannetv2/metadata/scannetv2_test.txt
14 | combine_file: ${data.dataset_root_path}/scannetv2/metadata/scannetv2-labels.combined.tsv
15 |
16 | ScanNet:
17 | dataset_path: ${data.dataset_root_path}/scannetv2
18 | classes: 2
19 | # ignore_classes: [ 1, 2 ]
20 | class_names: [ 'floor', 'wall', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture',
21 | 'counter', 'desk', 'curtain', 'refrigerator', 'shower curtain', 'toilet', 'sink',
22 | 'bathtub', 'otherfurniture' ]
23 |
24 | mapping_classes_ids: [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39 ]
25 |
26 | Synthetic:
27 | path: ${data.dataset_root_path}/synthetic_data/synthetic_room_dataset
28 |
29 | input_type: pointcloud_crop
30 | classes: ['rooms_04', 'rooms_05', 'rooms_06', 'rooms_07', 'rooms_08']
31 | pointcloud_n: 10000
32 | std_dev: 0.00 # 0.005
33 | # points_subsample: 1024
34 | points_file: points_iou
35 | points_iou_file: points_iou
36 | pointcloud_file: pointcloud
37 | pointcloud_chamfer_file: pointcloud
38 | voxels_file: null
39 | multi_files: 10
40 | unit_size: 0.005 # size of a voxel (not used)
41 | query_vol_size: 25
42 |
--------------------------------------------------------------------------------
/config/data/scannet.yaml:
--------------------------------------------------------------------------------
1 | # Managed by Hydra
2 |
3 | defaults:
4 | - base
5 |
6 | dataset: Scannet
7 | dataset_path: ${data.dataset_root_path}/scannetv2
8 |
9 | metadata:
10 | metadata_path: ${data.dataset_root_path}/scannetv2/metadata
11 | train_list: ${data.dataset_root_path}/scannetv2/metadata/scannetv2_train.txt
12 | val_list: ${data.dataset_root_path}/scannetv2/metadata/scannetv2_val.txt
13 | test_list: ${data.dataset_root_path}/scannetv2/metadata/scannetv2_test.txt
14 | combine_file: ${data.dataset_root_path}/scannetv2/metadata/scannetv2-labels.combined.tsv
15 |
16 | supervision:
17 | sdf:
18 | max_dist: 0.2
19 | udf:
20 | max_dist: 0.2
21 |
22 | reconstruction:
23 | mask_threshold: 0.015
24 |
25 | evaluation:
26 | evaluator: "UnitMeshEvaluator"
27 |
--------------------------------------------------------------------------------
/config/data/scenenn.yaml:
--------------------------------------------------------------------------------
1 | # Managed by Hydra
2 |
3 | defaults:
4 | - base
5 |
6 | dataset: SceneNN
7 | dataset_path: ${data.dataset_root_path}/scenenn_seg_76_raw/scenenn_sub_data
8 | train_files: ['005', '014', '015', '016', '025', '036', '038', '041', '045',
9 | '047', '052', '054', '057', '061', '062', '066', '071', '073', '078', '080',
10 | '084', '087', '089', '096', '098', '109', '201', '202', '209', '217', '223',
11 | '225', '227', '231', '234', '237', '240', '243', '249', '251', '255', '260',
12 | '263', '265', '270', '276', '279', '286', '294', '308', '522', '609', '613',
13 | '614', '623', '700']
14 | test_files: ['011', '021', '065', '032', '093', '246', '086', '069', '206',
15 | '252', '273', '527', '621', '076', '082', '049', '207', '213', '272', '074']
16 |
17 | supervision:
18 | sdf:
19 | max_dist: 0.2
20 | udf:
21 | max_dist: 0.2
22 |
23 | reconstruction:
24 | mask_threshold: 0.015
25 |
26 | evaluation:
27 | evaluator: "UnitMeshEvaluator"
--------------------------------------------------------------------------------
/config/data/synthetic.yaml:
--------------------------------------------------------------------------------
1 | # Managed by Hydra
2 |
3 | defaults:
4 | - base
5 |
6 | dataset: Synthetic
7 | path: ${data.dataset_root_path}/synthetic_data/synthetic_room_dataset
8 |
9 | input_type: pointcloud_crop
10 | classes: ['rooms_04', 'rooms_05', 'rooms_06', 'rooms_07', 'rooms_08']
11 | pointcloud_n: 10000
12 | # std_dev: 0.00 # 0.005
13 | # points_subsample: 1024
14 | points_file: points_iou
15 | points_iou_file: points_iou
16 | pointcloud_file: pointcloud
17 | pointcloud_chamfer_file: pointcloud
18 | voxels_file: null
19 | multi_files: 10
20 | unit_size: 0.005 # size of a voxel (not used)
21 | query_vol_size: 25
22 |
23 | supervision:
24 | sdf:
25 | max_dist: 0.2
26 |
27 | udf:
28 | max_dist: 0.2
29 |
30 | reconstruction:
31 | mask_threshold: 0.01
32 |
33 | evaluation:
34 | evaluator: "UnitMeshEvaluator"
35 |
--------------------------------------------------------------------------------
/config/model/base.yaml:
--------------------------------------------------------------------------------
1 | # Managed by Hydra
2 |
3 | ckpt_path:
4 |
5 | logger:
6 | # https://pytorch-lightning.readthedocs.io/en/stable/extensions/generated/pytorch_lightning.loggers.WandbLogger.html
7 | _target_: pytorch_lightning.loggers.WandbLogger
8 | project: noksr
9 | name: ${experiment_name}
10 |
11 | # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
12 | trainer:
13 | accelerator: gpu #cpu or gpu
14 | devices: auto
15 | num_nodes: 1
16 | max_epochs: 800
17 | max_steps: 200000
18 | num_sanity_val_steps: 8
19 | check_val_every_n_epoch: 4
20 | profiler: simple
21 |
22 | # https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html
23 | checkpoint_monitor:
24 | _target_: pytorch_lightning.callbacks.ModelCheckpoint
25 | save_top_k: -1
26 | every_n_epochs: 4
27 | filename: "{epoch}"
28 | dirpath: ${exp_output_root_path}/training
29 |
30 |
31 | optimizer:
32 | name: Adam # SGD or Adam
33 | lr: 0.001
34 | warmup_steps_ratio: 0.1
35 |
36 | lr_decay: # for Adam
37 | decay_start_epoch: 250
38 |
39 |
40 | inference:
41 | split: test
--------------------------------------------------------------------------------
/config/model/carla_model.yaml:
--------------------------------------------------------------------------------
1 |
2 | # Managed by Hydra
3 |
4 | defaults:
5 | - base
6 |
7 | trainer:
8 | max_epochs: 100
9 | lr_decay: # for Adam
10 | decay_start_epoch: 60
11 |
12 | network:
13 | module: noksr
14 | use_color: False
15 | use_normal: True # True for SDF supervision
16 | use_xyz: True # True only for PointTransformerV3 debugging
17 |
18 | latent_dim: 32 #
19 | prepare_epochs: 350
20 | eval_algorithm: DMC # DMC or DensePC or MC
21 | grad_type: Analytical # Analytical or Numerical
22 | backbone: PointTransformerV3 # PointTransformerV3 or MinkUNet
23 |
24 | point_transformerv3: # default
25 | in_channels: 6
26 | order: ["z", "z-trans", "hilbert", "hilbert-trans"]
27 | stride: [8, 8, 4, 2]
28 | enc_depths: [2, 2, 2, 6, 2]
29 | enc_channels: [32, 64, 128, 256, 512]
30 | enc_num_head: [2, 4, 8, 16, 32]
31 | enc_patch_size: [64, 64, 64, 64, 64]
32 |
33 |
34 | dec_depths: [2, 2, 2, 2]
35 | dec_channels: [64, 64, 128, 256]
36 | dec_num_head: [4, 4, 8, 16]
37 | dec_patch_size: [64, 64, 64, 64]
38 |
39 | mlp_ratio: 4
40 | qkv_bias: true
41 | qk_scale: null
42 | attn_drop: 0.0
43 | proj_drop: 0.0
44 | drop_path: 0.3
45 | pre_norm: true
46 | shuffle_orders: true
47 | enable_rpe: false
48 | enable_flash: true
49 | upcast_attention: false
50 | upcast_softmax: false
51 | cls_mode: false
52 | pdnorm_bn: false
53 | pdnorm_ln: false
54 | pdnorm_decouple: true
55 | pdnorm_adaptive: false
56 | pdnorm_affine: true
57 | pdnorm_conditions: ["ScanNet", "S3DIS", "Structured3D"]
58 |
59 | # Define common settings as an anchor
60 | default_decoder:
61 | decoder_type: Decoder # Decoder or SimpleDecoder or SimpleInterpolatedDecoder or InterpolatedDecoder or MultiScaleInterpolatedDecoder
62 | backbone: ${model.network.backbone}
63 | decoder_channels: ${model.network.point_transformerv3.dec_channels}
64 | stride: ${model.network.point_transformerv3.stride}
65 | coords_enc: Fourier # MLP or Fourier (16 dim MLP or 63 dim Fourier)
66 | architecture: point_nerf # attentive_pooling or transformer_encoder or point_nerf
67 | activation: LeakyReLU # LeakyReLU or ReLU or Softplus or ShiftedSoftplus
68 | negative_slope: 0.01 # for LeakyReLU, default: 0.01
69 |
70 | neighboring: Serial # KNN or Mink or Serial or Mixture
71 | serial_neighbor_layers: 3 # number of serial neighbor layers
72 | k_neighbors: 8 # 1 for no interpolation
73 | dist_factor: [1, 1, 1, 1] # N times voxel size
74 | serial_orders: ['hilbert'] # 'z', 'hilbert', 'z-trans', 'hilbert-trans'
75 | last_n_layers: 4 # decodes features from last n layers of the Unet backbone
76 | feature_dim: [8, 8, 8, 8]
77 | point_nerf_hidden_dim: 32
78 | point_nerf_before_skip: 1
79 | point_nerf_after_skip: 1
80 |
81 | num_hidden_layers_before: 2
82 | num_hidden_layers_after: 2
83 | hidden_dim: 32
84 |
85 |
86 | # Specific decoder configurations
87 | sdf_decoder:
88 | decoder_type: ${model.network.default_decoder.decoder_type}
89 | backbone: ${model.network.default_decoder.backbone}
90 | decoder_channels: ${model.network.default_decoder.decoder_channels}
91 | stride: ${model.network.default_decoder.stride}
92 | coords_enc: ${model.network.default_decoder.coords_enc}
93 | activation: ${model.network.default_decoder.activation}
94 | negative_slope: ${model.network.default_decoder.negative_slope}
95 |
96 | neighboring: ${model.network.default_decoder.neighboring}
97 | serial_neighbor_layers: ${model.network.default_decoder.serial_neighbor_layers}
98 | k_neighbors: ${model.network.default_decoder.k_neighbors}
99 | dist_factor: ${model.network.default_decoder.dist_factor}
100 | serial_orders: ${model.network.default_decoder.serial_orders}
101 | last_n_layers: ${model.network.default_decoder.last_n_layers}
102 | feature_dim: ${model.network.default_decoder.feature_dim}
103 | point_nerf_hidden_dim: ${model.network.default_decoder.point_nerf_hidden_dim}
104 | point_nerf_before_skip: ${model.network.default_decoder.point_nerf_before_skip}
105 | point_nerf_after_skip: ${model.network.default_decoder.point_nerf_after_skip}
106 |
107 | num_hidden_layers_before: ${model.network.default_decoder.num_hidden_layers_before}
108 | num_hidden_layers_after: ${model.network.default_decoder.num_hidden_layers_after}
109 | hidden_dim: ${model.network.default_decoder.hidden_dim}
110 |
111 |
112 | mask_decoder:
113 | decoder_type: ${model.network.default_decoder.decoder_type}
114 | backbone: ${model.network.default_decoder.backbone}
115 | decoder_channels: ${model.network.default_decoder.decoder_channels}
116 | stride: ${model.network.default_decoder.stride}
117 | coords_enc: ${model.network.default_decoder.coords_enc}
118 | activation: ${model.network.default_decoder.activation}
119 | negative_slope: ${model.network.default_decoder.negative_slope}
120 |
121 | neighboring: ${model.network.default_decoder.neighboring}
122 | serial_neighbor_layers: ${model.network.default_decoder.serial_neighbor_layers}
123 | k_neighbors: ${model.network.default_decoder.k_neighbors}
124 | dist_factor: ${model.network.default_decoder.dist_factor}
125 | serial_orders: ${model.network.default_decoder.serial_orders}
126 | last_n_layers: ${model.network.default_decoder.last_n_layers}
127 | feature_dim: ${model.network.default_decoder.feature_dim}
128 | point_nerf_hidden_dim: ${model.network.default_decoder.point_nerf_hidden_dim}
129 | point_nerf_before_skip: ${model.network.default_decoder.point_nerf_before_skip}
130 | point_nerf_after_skip: ${model.network.default_decoder.point_nerf_after_skip}
131 |
132 | num_hidden_layers_before: ${model.network.default_decoder.num_hidden_layers_before}
133 | num_hidden_layers_after: ${model.network.default_decoder.num_hidden_layers_after}
134 | hidden_dim: ${model.network.default_decoder.hidden_dim}
135 |
136 |
137 |
--------------------------------------------------------------------------------
/config/model/scannet_model.yaml:
--------------------------------------------------------------------------------
1 |
2 | # Managed by Hydra
3 |
4 | defaults:
5 | - base
6 |
7 | trainer:
8 | max_epochs: 160
9 | lr_decay: # for Adam
10 | decay_start_epoch: 100
11 |
12 | network:
13 | module: noksr
14 | use_color: False
15 | use_normal: True
16 | use_xyz: True
17 |
18 | latent_dim: 32 #
19 | prepare_epochs: 350
20 | eval_algorithm: DMC # DMC or DensePC or MC
21 | grad_type: Analytical # Analytical or Numerical
22 | backbone: PointTransformerV3 # PointTransformerV3 or MinkUNet
23 |
24 | point_transformerv3: # default
25 | in_channels: 6
26 | order: ["z", "z-trans", "hilbert", "hilbert-trans"]
27 | stride: [2, 2, 4, 4]
28 | enc_depths: [2, 2, 2, 6, 2]
29 | enc_channels: [32, 64, 128, 256, 512]
30 | enc_num_head: [2, 4, 8, 16, 32]
31 | enc_patch_size: [64, 64, 64, 64, 64]
32 |
33 |
34 | dec_depths: [2, 2, 2, 2]
35 | dec_channels: [64, 64, 128, 256]
36 | dec_num_head: [4, 4, 8, 16]
37 | dec_patch_size: [64, 64, 64, 64]
38 |
39 | mlp_ratio: 4
40 | qkv_bias: true
41 | qk_scale: null
42 | attn_drop: 0.0
43 | proj_drop: 0.0
44 | drop_path: 0.3
45 | pre_norm: true
46 | shuffle_orders: true
47 | enable_rpe: false
48 | enable_flash: true
49 | upcast_attention: false
50 | upcast_softmax: false
51 | cls_mode: false
52 | pdnorm_bn: false
53 | pdnorm_ln: false
54 | pdnorm_decouple: true
55 | pdnorm_adaptive: false
56 | pdnorm_affine: true
57 | pdnorm_conditions: ["ScanNet", "S3DIS", "Structured3D"]
58 |
59 | # Define common settings as an anchor
60 | default_decoder:
61 | decoder_type: Decoder
62 | backbone: ${model.network.backbone}
63 | decoder_channels: ${model.network.point_transformerv3.dec_channels}
64 | stride: ${model.network.point_transformerv3.stride}
65 | coords_enc: Fourier # MLP or Fourier (16 dim MLP or 63 dim Fourier)
66 | architecture: point_nerf # attentive_pooling or transformer_encoder or point_nerf
67 | negative_slope: 0.01 # for LeakyReLU, default: 0.01
68 |
69 | neighboring: Serial # KNN or Mink or Serial or Mixture
70 | serial_neighbor_layers: 3 # number of serial neighbor layers
71 | k_neighbors: 8 # 1 for no interpolation
72 | dist_factor: [4, 3, 2, 2] # N times voxel size
73 | serial_orders: ['hilbert'] # 'z', 'hilbert', 'z-trans', 'hilbert-trans'
74 | last_n_layers: 4 # decodes features from last n layers of the Unet backbone
75 | feature_dim: [8, 8, 8, 8]
76 | point_nerf_hidden_dim: 32
77 | point_nerf_before_skip: 1
78 | point_nerf_after_skip: 1
79 |
80 | num_hidden_layers_before: 2
81 | num_hidden_layers_after: 2
82 | hidden_dim: 32
83 |
84 | activation: LeakyReLU # LeakyReLU or ReLU or Softplus or ShiftedSoftplus
85 |
86 | # Specific decoder configurations
87 | sdf_decoder:
88 | decoder_type: ${model.network.default_decoder.decoder_type}
89 | backbone: ${model.network.default_decoder.backbone}
90 | decoder_channels: ${model.network.default_decoder.decoder_channels}
91 | stride: ${model.network.default_decoder.stride}
92 | coords_enc: ${model.network.default_decoder.coords_enc}
93 | negative_slope: ${model.network.default_decoder.negative_slope}
94 |
95 | neighboring: ${model.network.default_decoder.neighboring}
96 | serial_neighbor_layers: ${model.network.default_decoder.serial_neighbor_layers}
97 | k_neighbors: ${model.network.default_decoder.k_neighbors}
98 | dist_factor: ${model.network.default_decoder.dist_factor}
99 | serial_orders: ${model.network.default_decoder.serial_orders}
100 | last_n_layers: ${model.network.default_decoder.last_n_layers}
101 | feature_dim: ${model.network.default_decoder.feature_dim}
102 | point_nerf_hidden_dim: ${model.network.default_decoder.point_nerf_hidden_dim}
103 | point_nerf_before_skip: ${model.network.default_decoder.point_nerf_before_skip}
104 | point_nerf_after_skip: ${model.network.default_decoder.point_nerf_after_skip}
105 |
106 | num_hidden_layers_before: ${model.network.default_decoder.num_hidden_layers_before}
107 | num_hidden_layers_after: ${model.network.default_decoder.num_hidden_layers_after}
108 | hidden_dim: ${model.network.default_decoder.hidden_dim}
109 |
110 | activation: ${model.network.default_decoder.activation}
111 |
112 | mask_decoder:
113 | decoder_type: ${model.network.default_decoder.decoder_type}
114 | backbone: ${model.network.default_decoder.backbone}
115 | decoder_channels: ${model.network.default_decoder.decoder_channels}
116 | stride: ${model.network.default_decoder.stride}
117 | coords_enc: ${model.network.default_decoder.coords_enc}
118 | negative_slope: ${model.network.default_decoder.negative_slope}
119 |
120 | neighboring: ${model.network.default_decoder.neighboring}
121 | serial_neighbor_layers: ${model.network.default_decoder.serial_neighbor_layers}
122 | k_neighbors: ${model.network.default_decoder.k_neighbors}
123 | dist_factor: ${model.network.default_decoder.dist_factor}
124 | serial_orders: ${model.network.default_decoder.serial_orders}
125 | last_n_layers: ${model.network.default_decoder.last_n_layers}
126 | feature_dim: ${model.network.default_decoder.feature_dim}
127 | point_nerf_hidden_dim: ${model.network.default_decoder.point_nerf_hidden_dim}
128 | point_nerf_before_skip: ${model.network.default_decoder.point_nerf_before_skip}
129 | point_nerf_after_skip: ${model.network.default_decoder.point_nerf_after_skip}
130 |
131 | num_hidden_layers_before: ${model.network.default_decoder.num_hidden_layers_before}
132 | num_hidden_layers_after: ${model.network.default_decoder.num_hidden_layers_after}
133 | hidden_dim: ${model.network.default_decoder.hidden_dim}
134 |
135 | activation: ${model.network.default_decoder.activation}
136 |
137 |
138 |
--------------------------------------------------------------------------------
/config/model/scenenn_model.yaml:
--------------------------------------------------------------------------------
1 |
2 | # Managed by Hydra
3 |
4 | defaults:
5 | - base
6 |
7 | trainer:
8 | max_epochs: 160
9 | lr_decay: # for Adam
10 | decay_start_epoch: 100
11 |
12 | network:
13 | module: noksr
14 | use_color: False
15 | use_normal: True
16 | use_xyz: True
17 |
18 | latent_dim: 32 #
19 | prepare_epochs: 350
20 | eval_algorithm: DMC # DMC or DensePC or MC
21 | grad_type: Analytical # Analytical or Numerical
22 | backbone: PointTransformerV3 # PointTransformerV3 or MinkUNet
23 |
24 | point_transformerv3: # default
25 | in_channels: 6
26 | order: ["z", "z-trans", "hilbert", "hilbert-trans"]
27 | stride: [2, 2, 4, 4]
28 | enc_depths: [2, 2, 2, 6, 2]
29 | enc_channels: [32, 64, 128, 256, 512]
30 | enc_num_head: [2, 4, 8, 16, 32]
31 | enc_patch_size: [64, 64, 64, 64, 64]
32 |
33 |
34 | dec_depths: [2, 2, 2, 2]
35 | dec_channels: [64, 64, 128, 256]
36 | dec_num_head: [4, 4, 8, 16]
37 | dec_patch_size: [64, 64, 64, 64]
38 |
39 | mlp_ratio: 4
40 | qkv_bias: true
41 | qk_scale: null
42 | attn_drop: 0.0
43 | proj_drop: 0.0
44 | drop_path: 0.3
45 | pre_norm: true
46 | shuffle_orders: true
47 | enable_rpe: false
48 | enable_flash: true
49 | upcast_attention: false
50 | upcast_softmax: false
51 | cls_mode: false
52 | pdnorm_bn: false
53 | pdnorm_ln: false
54 | pdnorm_decouple: true
55 | pdnorm_adaptive: false
56 | pdnorm_affine: true
57 | pdnorm_conditions: ["ScanNet", "S3DIS", "Structured3D"]
58 |
59 | # Define common settings as an anchor
60 | default_decoder:
61 | decoder_type: Decoder
62 | backbone: ${model.network.backbone}
63 | decoder_channels: ${model.network.point_transformerv3.dec_channels}
64 | stride: ${model.network.point_transformerv3.stride}
65 | coords_enc: Fourier # MLP or Fourier (16 dim MLP or 63 dim Fourier)
66 | architecture: point_nerf # attentive_pooling or transformer_encoder or point_nerf
67 | negative_slope: 0.01 # for LeakyReLU, default: 0.01
68 |
69 | neighboring: KNN # KNN or Mink or Serial or Mixture
70 | serial_neighbor_layers: 3 # number of serial neighbor layers
71 | k_neighbors: 4 # 1 for no interpolation
72 | dist_factor: [8, 4, 4, 4] # N times voxel size
73 | serial_orders: ['hilbert'] # 'z', 'hilbert', 'z-trans', 'hilbert-trans'
74 | last_n_layers: 4 # decodes features from last n layers of the Unet backbone
75 | feature_dim: [8, 8, 8, 8]
76 | point_nerf_hidden_dim: 32
77 | point_nerf_before_skip: 1
78 | point_nerf_after_skip: 1
79 |
80 | num_hidden_layers_before: 2
81 | num_hidden_layers_after: 2
82 | hidden_dim: 32
83 |
84 | activation: LeakyReLU # LeakyReLU or ReLU or Softplus or ShiftedSoftplus
85 |
86 | # Specific decoder configurations
87 | sdf_decoder:
88 | decoder_type: ${model.network.default_decoder.decoder_type}
89 | backbone: ${model.network.default_decoder.backbone}
90 | decoder_channels: ${model.network.default_decoder.decoder_channels}
91 | stride: ${model.network.default_decoder.stride}
92 | coords_enc: ${model.network.default_decoder.coords_enc}
93 | negative_slope: ${model.network.default_decoder.negative_slope}
94 |
95 | neighboring: ${model.network.default_decoder.neighboring}
96 | serial_neighbor_layers: ${model.network.default_decoder.serial_neighbor_layers}
97 | k_neighbors: ${model.network.default_decoder.k_neighbors}
98 | dist_factor: ${model.network.default_decoder.dist_factor}
99 | serial_orders: ${model.network.default_decoder.serial_orders}
100 | last_n_layers: ${model.network.default_decoder.last_n_layers}
101 | feature_dim: ${model.network.default_decoder.feature_dim}
102 | point_nerf_hidden_dim: ${model.network.default_decoder.point_nerf_hidden_dim}
103 | point_nerf_before_skip: ${model.network.default_decoder.point_nerf_before_skip}
104 | point_nerf_after_skip: ${model.network.default_decoder.point_nerf_after_skip}
105 |
106 | num_hidden_layers_before: ${model.network.default_decoder.num_hidden_layers_before}
107 | num_hidden_layers_after: ${model.network.default_decoder.num_hidden_layers_after}
108 | hidden_dim: ${model.network.default_decoder.hidden_dim}
109 |
110 | activation: ${model.network.default_decoder.activation}
111 |
112 | mask_decoder:
113 | decoder_type: ${model.network.default_decoder.decoder_type}
114 | backbone: ${model.network.default_decoder.backbone}
115 | decoder_channels: ${model.network.default_decoder.decoder_channels}
116 | stride: ${model.network.default_decoder.stride}
117 | coords_enc: ${model.network.default_decoder.coords_enc}
118 | negative_slope: ${model.network.default_decoder.negative_slope}
119 |
120 | neighboring: ${model.network.default_decoder.neighboring}
121 | serial_neighbor_layers: ${model.network.default_decoder.serial_neighbor_layers}
122 | k_neighbors: ${model.network.default_decoder.k_neighbors}
123 | dist_factor: ${model.network.default_decoder.dist_factor}
124 | serial_orders: ${model.network.default_decoder.serial_orders}
125 | last_n_layers: ${model.network.default_decoder.last_n_layers}
126 | feature_dim: ${model.network.default_decoder.feature_dim}
127 | point_nerf_hidden_dim: ${model.network.default_decoder.point_nerf_hidden_dim}
128 | point_nerf_before_skip: ${model.network.default_decoder.point_nerf_before_skip}
129 | point_nerf_after_skip: ${model.network.default_decoder.point_nerf_after_skip}
130 |
131 | num_hidden_layers_before: ${model.network.default_decoder.num_hidden_layers_before}
132 | num_hidden_layers_after: ${model.network.default_decoder.num_hidden_layers_after}
133 | hidden_dim: ${model.network.default_decoder.hidden_dim}
134 |
135 | activation: ${model.network.default_decoder.activation}
136 |
137 |
138 |
--------------------------------------------------------------------------------
/config/model/synthetic_model.yaml:
--------------------------------------------------------------------------------
1 |
2 | # Managed by Hydra
3 |
4 | defaults:
5 | - base
6 |
7 | trainer:
8 | max_epochs: 160
9 | lr_decay: # for Adam
10 | decay_start_epoch: 100
11 |
12 | network:
13 | module: noksr
14 | use_color: False
15 | use_normal: True # True for SDF supervision
16 | use_xyz: True # True only for PointTransformerV3 debugging
17 |
18 | latent_dim: 32 #
19 | prepare_epochs: 350
20 | eval_algorithm: DMC # DMC or DensePC or MC
21 | grad_type: Analytical # Analytical or Numerical
22 | backbone: PointTransformerV3 # PointTransformerV3 or MinkUNet
23 |
24 | point_transformerv3: # default
25 | in_channels: 6
26 | order: ["z", "z-trans", "hilbert", "hilbert-trans"]
27 | stride: [2, 2, 4, 4]
28 | enc_depths: [2, 2, 2, 6, 2]
29 | enc_channels: [32, 64, 128, 256, 512]
30 | enc_num_head: [2, 4, 8, 16, 32]
31 | enc_patch_size: [64, 64, 64, 64, 64]
32 |
33 |
34 | dec_depths: [2, 2, 2, 2]
35 | dec_channels: [64, 64, 128, 256]
36 | dec_num_head: [4, 4, 8, 16]
37 | dec_patch_size: [64, 64, 64, 64]
38 |
39 | mlp_ratio: 4
40 | qkv_bias: true
41 | qk_scale: null
42 | attn_drop: 0.0
43 | proj_drop: 0.0
44 | drop_path: 0.3
45 | pre_norm: true
46 | shuffle_orders: true
47 | enable_rpe: false
48 | enable_flash: true
49 | upcast_attention: false
50 | upcast_softmax: false
51 | cls_mode: false
52 | pdnorm_bn: false
53 | pdnorm_ln: false
54 | pdnorm_decouple: true
55 | pdnorm_adaptive: false
56 | pdnorm_affine: true
57 | pdnorm_conditions: ["ScanNet", "S3DIS", "Structured3D"]
58 |
59 | # Define common settings as an anchor
60 | default_decoder:
61 | decoder_type: Decoder # Decoder or SimpleDecoder or SimpleInterpolatedDecoder or InterpolatedDecoder or MultiScaleInterpolatedDecoder
62 | backbone: ${model.network.backbone}
63 | decoder_channels: ${model.network.point_transformerv3.dec_channels}
64 | stride: ${model.network.point_transformerv3.stride}
65 | coords_enc: Fourier # MLP or Fourier (16 dim MLP or 63 dim Fourier)
66 | architecture: point_nerf # attentive_pooling or transformer_encoder or point_nerf
67 | activation: LeakyReLU # LeakyReLU or ReLU or Softplus or ShiftedSoftplus
68 | negative_slope: 0.01 # for LeakyReLU, default: 0.01
69 |
70 | neighboring: Serial # KNN or Mink or Serial or Mixture
71 | serial_neighbor_layers: 3 # number of serial neighbor layers
72 | k_neighbors: 4 # 1 for no interpolation
73 | dist_factor: [4, 3, 2, 2] # N times voxel size
74 | serial_orders: ['hilbert'] # 'z', 'hilbert', 'z-trans', 'hilbert-trans'
75 | last_n_layers: 4 # decodes features from last n layers of the Unet backbone
76 | feature_dim: [8, 8, 8, 8]
77 | point_nerf_hidden_dim: 32
78 | point_nerf_before_skip: 1
79 | point_nerf_after_skip: 1
80 |
81 | num_hidden_layers_before: 2
82 | num_hidden_layers_after: 2
83 | hidden_dim: 32
84 |
85 |
86 | # Specific decoder configurations
87 | sdf_decoder:
88 | decoder_type: ${model.network.default_decoder.decoder_type}
89 | backbone: ${model.network.default_decoder.backbone}
90 | decoder_channels: ${model.network.default_decoder.decoder_channels}
91 | stride: ${model.network.default_decoder.stride}
92 | coords_enc: ${model.network.default_decoder.coords_enc}
93 | activation: ${model.network.default_decoder.activation}
94 | negative_slope: ${model.network.default_decoder.negative_slope}
95 |
96 | neighboring: ${model.network.default_decoder.neighboring}
97 | serial_neighbor_layers: ${model.network.default_decoder.serial_neighbor_layers}
98 | k_neighbors: ${model.network.default_decoder.k_neighbors}
99 | dist_factor: ${model.network.default_decoder.dist_factor}
100 | serial_orders: ${model.network.default_decoder.serial_orders}
101 | last_n_layers: ${model.network.default_decoder.last_n_layers}
102 | feature_dim: ${model.network.default_decoder.feature_dim}
103 | point_nerf_hidden_dim: ${model.network.default_decoder.point_nerf_hidden_dim}
104 | point_nerf_before_skip: ${model.network.default_decoder.point_nerf_before_skip}
105 | point_nerf_after_skip: ${model.network.default_decoder.point_nerf_after_skip}
106 |
107 | num_hidden_layers_before: ${model.network.default_decoder.num_hidden_layers_before}
108 | num_hidden_layers_after: ${model.network.default_decoder.num_hidden_layers_after}
109 | hidden_dim: ${model.network.default_decoder.hidden_dim}
110 |
111 |
112 | mask_decoder:
113 | decoder_type: ${model.network.default_decoder.decoder_type}
114 | backbone: ${model.network.default_decoder.backbone}
115 | decoder_channels: ${model.network.default_decoder.decoder_channels}
116 | stride: ${model.network.default_decoder.stride}
117 | coords_enc: ${model.network.default_decoder.coords_enc}
118 | activation: ${model.network.default_decoder.activation}
119 | negative_slope: ${model.network.default_decoder.negative_slope}
120 |
121 | neighboring: ${model.network.default_decoder.neighboring}
122 | serial_neighbor_layers: ${model.network.default_decoder.serial_neighbor_layers}
123 | k_neighbors: ${model.network.default_decoder.k_neighbors}
124 | dist_factor: ${model.network.default_decoder.dist_factor}
125 | serial_orders: ${model.network.default_decoder.serial_orders}
126 | last_n_layers: ${model.network.default_decoder.last_n_layers}
127 | feature_dim: ${model.network.default_decoder.feature_dim}
128 | point_nerf_hidden_dim: ${model.network.default_decoder.point_nerf_hidden_dim}
129 | point_nerf_before_skip: ${model.network.default_decoder.point_nerf_before_skip}
130 | point_nerf_after_skip: ${model.network.default_decoder.point_nerf_after_skip}
131 |
132 | num_hidden_layers_before: ${model.network.default_decoder.num_hidden_layers_before}
133 | num_hidden_layers_after: ${model.network.default_decoder.num_hidden_layers_after}
134 | hidden_dim: ${model.network.default_decoder.hidden_dim}
135 |
136 |
137 |
138 |
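139 | # Note: every ${model.network.default_decoder.*} entry above is an OmegaConf interpolation, so a
140 | # single override of default_decoder propagates to both sdf_decoder and mask_decoder. A minimal
141 | # override sketch (illustrative; the entry point and config-group names follow config/config.yaml):
142 | #   python train.py model=synthetic_model model.network.default_decoder.k_neighbors=8
143 | #   # -> sdf_decoder.k_neighbors and mask_decoder.k_neighbors both resolve to 8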
--------------------------------------------------------------------------------
/data:
--------------------------------------------------------------------------------
1 | /localhome/zla247/theia2_data
--------------------------------------------------------------------------------
/eval.log:
--------------------------------------------------------------------------------
1 | [2024-11-25 21:56:55,155][pycg.exp][WARNING] - Empty pointcloud / mesh detected! Return NaN metric!
2 | [2024-11-25 23:24:31,010][pycg.exp][WARNING] - Empty pointcloud / mesh detected! Return NaN metric!
3 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import hydra
3 | import torch
4 | from tqdm import tqdm
5 | import numpy as np
6 | import open3d as o3d
7 | from collections import defaultdict
8 | from torch.utils.data import Dataset
9 | from pathlib import Path
10 | from importlib import import_module
11 | import pytorch_lightning as pl
12 | from noksr.data.data_module import DataModule
13 | from noksr.utils.evaluation import UnitMeshEvaluator, MeshEvaluator
14 | from noksr.model.module import Generator
15 | from nksr.svh import SparseFeatureHierarchy
16 | from noksr.utils.segmentation import segment_and_generate_encoder_outputs
17 |
18 | @hydra.main(version_base=None, config_path="config", config_name="config")
19 | def main(cfg):
20 | assert torch.cuda.is_available()
21 | device = torch.device("cuda")
22 | # fix the seed
23 | pl.seed_everything(cfg.global_test_seed, workers=True)
24 |
25 | output_path = os.path.join(cfg.exp_output_root_path, "reconstruction", "visualizations")
26 | os.makedirs(output_path, exist_ok=True)
27 | # output_path = cfg.exp_output_root_path
28 |
29 | print("==> initializing data ...")
30 | data_module = DataModule(cfg)
31 | data_module.setup("test")
32 | val_loader = data_module.val_dataloader()
33 |
34 | print("=> initializing model...")
35 | model = getattr(import_module("noksr.model"), cfg.model.network.module)(cfg)
36 | # Load checkpoint
37 | if os.path.isfile(cfg.model.ckpt_path):
38 | print(f"=> loading model checkpoint '{cfg.model.ckpt_path}'")
39 | checkpoint = torch.load(cfg.model.ckpt_path, map_location=device)
40 | model.load_state_dict(checkpoint['state_dict'])
41 | print("=> loaded checkpoint successfully.")
42 | else:
43 | raise FileNotFoundError(f"No checkpoint found at '{cfg.model.ckpt_path}'. Please ensure the path is correct.")
44 |
45 | model.to(device)
46 | model.eval()
47 |
48 | if cfg.data.reconstruction.trim:
49 | dense_generator = Generator(
50 | model.sdf_decoder,
51 | model.mask_decoder,
52 | cfg.data.voxel_size,
53 | cfg.model.network.sdf_decoder.k_neighbors,
54 | cfg.model.network.sdf_decoder.last_n_layers,
55 | cfg.data.reconstruction
56 | )
57 | else:
58 | dense_generator = Generator(
59 | model.sdf_decoder,
60 | None,
61 | cfg.data.voxel_size,
62 | cfg.model.network.sdf_decoder.k_neighbors,
63 | cfg.model.network.sdf_decoder.last_n_layers,
64 | cfg.data.reconstruction
65 | )
66 |
67 | # Initialize a dictionary to keep track of sums and count
68 | eval_sums = defaultdict(float)
69 | batch_count = 0
70 | import time
71 | print("=> start inference...")
72 | start_time = time.time()
73 | total_reconstruction_duration = 0.0
74 | total_neighboring_time = 0.0
75 | total_dmc_time = 0.0
76 | total_aggregation_time = 0.0
77 | total_forward_duration = 0.0
78 | total_decoder_time = 0.0
79 | results_dict = []
80 | for batch in tqdm(val_loader, desc="Inference", unit="batch"):
81 | batch = {k: v.to(device) if hasattr(v, 'to') else v for k, v in batch.items()}
82 | if 'gt_geometry' in batch:
83 | gt_geometry = batch['gt_geometry']
84 | batch['all_xyz'], batch['all_normal'], _ = gt_geometry[0].torch_attr()
85 | process_start = time.time()
86 | forward_start = time.time()
87 | if cfg.data.reconstruction.by_segment:
88 | encoder_outputs, encoding_codes, depth = segment_and_generate_encoder_outputs(
89 | batch=batch,
90 | model=model,
91 | device=device,
92 | segment_num=cfg.data.reconstruction.segment_num, # Example: 10 segments
93 | grid_size=0.01,
94 | serial_order='z'
95 | )
96 |
97 | forward_end = time.time()
98 | torch.set_grad_enabled(False)
99 | dmc_mesh, time_dict = dense_generator.generate_dual_mc_mesh_by_segment(
100 | data_dict=batch,
101 | encoder_outputs=encoder_outputs,
102 | encoding_codes=encoding_codes,
103 | depth=depth,
104 | device=device
105 | )
106 | else:
107 | if cfg.model.network.backbone == 'PointTransformerV3':
108 | pt_data = {}
109 | pt_data['feat'] = batch['point_features']
110 | pt_data['offset'] = batch['xyz_splits']
111 | pt_data['grid_size'] = 0.01
112 | pt_data['coord'] = batch['xyz']
113 | encoder_outputs = model.point_transformer(pt_data)
114 |
115 | forward_end = time.time()
116 | torch.set_grad_enabled(False)
117 |
118 | dmc_mesh, time_dict = dense_generator.generate_dual_mc_mesh(batch, encoder_outputs, device)
119 | total_neighboring_time += time_dict['neighboring_time']
120 | total_dmc_time += time_dict['dmc_time']
121 | total_aggregation_time += time_dict['aggregation_time']
122 | total_decoder_time += time_dict['decoder_time']
123 | # Calculate time taken for these three steps
124 | process_end = time.time()
125 | total_forward_duration += forward_end - forward_start
126 | process_duration = process_end - process_start
127 | total_reconstruction_duration += process_duration # Accumulate the duration
128 | print("\nTotal Reconstruction Time: {:.2f} seconds".format(total_reconstruction_duration))
129 | print("├── Total PointTransformerV3 Time: {:.2f} seconds".format(total_forward_duration))
130 | print("├── Total Decoder Time: {:.2f} seconds".format(total_decoder_time))
131 | print("│ ├── Total Neighboring Time: {:.2f} seconds".format(total_neighboring_time))
132 | print("│ └── Total Aggregation Time: {:.2f} seconds"
133 | .format(total_aggregation_time))
134 | print("├── Total Dual Marching Cube Time: {:.2f} seconds".format(total_dmc_time))
135 |
136 | # Evaluate the reconstructed mesh
137 | if cfg.data.evaluation.evaluator == "UnitMeshEvaluator":
138 | evaluator = UnitMeshEvaluator(n_points=100000, metric_names=UnitMeshEvaluator.ESSENTIAL_METRICS)
139 | elif cfg.data.evaluation.evaluator == "MeshEvaluator":
140 | evaluator = MeshEvaluator(n_points=int(5e6), metric_names=MeshEvaluator.ESSENTIAL_METRICS)
141 | if "gt_onet_sample" in batch:
142 | eval_dict = evaluator.eval_mesh(dmc_mesh, batch['all_xyz'], None, onet_samples=batch['gt_onet_sample'][0])
143 | else:
144 | eval_dict = evaluator.eval_mesh(dmc_mesh, batch['all_xyz'], None)
145 |
146 | # eval_dict['voxel_num'] = batch['voxel_nums'][0]
147 | for k, v in eval_dict.items():
148 | eval_sums[k] += v
149 | scene_name = batch['scene_names']
150 | eval_dict["data_id"] = batch_count
151 | eval_dict["scene_name"] = scene_name
152 | results_dict.append(eval_dict)
153 | print(f"Scene Name: {scene_name}")
154 | print(f"completeness: {eval_dict['completeness']:.4f}")
155 | print(f"accuracy: {eval_dict['accuracy']:.4f}")
156 | print(f"Chamfer-L2: {eval_dict['chamfer-L2']:.4f}")
157 | print(f"Chamfer-L1: {eval_dict['chamfer-L1']:.4f}")
158 | print(f"F-Score (1.0%): {eval_dict['f-score-10']:.4f}")
159 | print(f"F-Score (1.5%): {eval_dict['f-score-15']:.4f}")
160 | print(f"F-Score (2.0%): {eval_dict['f-score-20']:.4f}")
161 | if "o3d-iou" in eval_dict:
162 | print(f"O3D-IoU: {eval_dict['o3d-iou']:.4f}")
163 | print() # For better readability
164 |
165 | # Save the mesh
166 | if cfg.data.visualization.save:
167 | scene_name = batch['scene_names'][0].replace('/', '-')
168 | if cfg.data.visualization.Mesh:
169 | mesh_file = f"{output_path}/noksr-{cfg.data.dataset}-data_id_{cfg.data.intake_start}-{scene_name}_CD_{eval_dict['chamfer-L1']:.4f}_mesh.obj"  # .obj extension makes Open3D write an OBJ mesh
170 | o3d.io.write_triangle_mesh(mesh_file, dmc_mesh)
171 | if cfg.data.visualization.Input_points:
172 | pcd = o3d.geometry.PointCloud()
173 | pcd.points = o3d.utility.Vector3dVector(batch['xyz'].cpu().numpy())
174 | point_cloud_file = f"{output_path}/noksr-{cfg.data.dataset}-data_id_{cfg.data.intake_start}-{scene_name}_input_pcd.ply"
175 | o3d.io.write_point_cloud(point_cloud_file, pcd)
176 | if cfg.data.visualization.Dense_points:
177 | pcd = o3d.geometry.PointCloud()
178 | pcd.points = o3d.utility.Vector3dVector(batch['all_xyz'].cpu().numpy())
179 | point_cloud_file = f"{output_path}/noksr-{cfg.data.dataset}-data_id_{cfg.data.intake_start}-{scene_name}_dense_pcd.ply"
180 | o3d.io.write_point_cloud(point_cloud_file, pcd)
181 |
182 | batch_count += 1
183 | torch.cuda.empty_cache()
184 |
185 | end_time = time.time()
186 | total_time = end_time - start_time
187 | print(f"\n---Total reconstruction time without data loading: {total_reconstruction_duration:.2f} seconds")
188 | print(f"---Total reconstruction time including data loading for all scenes: {total_time:.2f} seconds")
189 | if batch_count > 0:
190 | print("\n--- Evaluation Metrics Averages ---")
191 | for k in eval_sums:
192 | average = eval_sums[k] / batch_count
193 | print(f"{k}: {average:.5f}")
194 | else:
195 | print("No batches were processed.")
196 |
197 |
198 | if __name__ == "__main__":
199 | main()
200 |
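201 | # Typical invocation sketch (illustrative; the data/model group names come from config/, and the
202 | # checkpoint paths are placeholders to be replaced with real files):
203 | #   python eval.py data=scannet model=scannet_model model.ckpt_path=checkpoints/scannet.ckpt
204 | #   python eval.py data=carla_original model=carla_model model.ckpt_path=checkpoints/carla.ckpt data.visualization.save=True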
--------------------------------------------------------------------------------
/noksr/callback/__init__.py:
--------------------------------------------------------------------------------
1 | from .gpu_cache_clean_callback import GPUCacheCleanCallback
2 |
--------------------------------------------------------------------------------
/noksr/callback/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/callback/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/callback/__pycache__/gpu_cache_clean_callback.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/callback/__pycache__/gpu_cache_clean_callback.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/callback/gpu_cache_clean_callback.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning.callbacks import Callback
2 | import torch
3 |
4 |
5 | class GPUCacheCleanCallback(Callback):
6 |
7 | def on_train_batch_start(self, *args, **kwargs):
8 | torch.cuda.empty_cache()
9 |
10 | def on_validation_batch_start(self, *args, **kwargs):
11 | torch.cuda.empty_cache()
12 |
13 | def on_test_batch_start(self, *args, **kwargs):
14 | torch.cuda.empty_cache()
15 |
16 | def on_predict_batch_start(self, *args, **kwargs):
17 | torch.cuda.empty_cache()
18 |
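19 | # Usage sketch (assumed wiring, for illustration; the callback is meant to be handed to the
20 | # Lightning Trainer so the CUDA cache is flushed before every batch):
21 | #   from noksr.callback import GPUCacheCleanCallback
22 | #   trainer = pl.Trainer(callbacks=[GPUCacheCleanCallback()])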
--------------------------------------------------------------------------------
/noksr/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_module import DataModule
--------------------------------------------------------------------------------
/noksr/data/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/data/__pycache__/data_module.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/__pycache__/data_module.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/data/data_module.py:
--------------------------------------------------------------------------------
1 | from importlib import import_module
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import DataLoader
5 | from torch.utils.data import Sampler, DistributedSampler, Dataset
6 |
7 | import pytorch_lightning as pl
8 | from arrgh import arrgh
9 |
10 |
11 | class DataModule(pl.LightningDataModule):
12 | def __init__(self, data_cfg):
13 | super().__init__()
14 | self.data_cfg = data_cfg
15 | self.dataset = getattr(import_module('noksr.data.dataset'), data_cfg.data.dataset)
16 |
17 | def setup(self, stage=None):
18 | if stage == "fit" or stage is None:
19 | self.train_set = self.dataset(self.data_cfg, "train")
20 | self.val_set = self.dataset(self.data_cfg, "val")
21 | if stage == "test" or stage is None:
22 | self.val_set = self.dataset(self.data_cfg, self.data_cfg.model.inference.split)
23 | if stage == "predict" or stage is None:
24 | self.test_set = self.dataset(self.data_cfg, "test")
25 |
26 | def train_dataloader(self):
27 | return DataLoader(self.train_set, batch_size=self.data_cfg.data.batch_size, shuffle=True, pin_memory=True,
28 | collate_fn=_sparse_collate_fn, num_workers=self.data_cfg.data.num_workers, drop_last=True)
29 |
30 | def val_dataloader(self):
31 | return DataLoader(self.val_set, batch_size=1, pin_memory=True, collate_fn=_sparse_collate_fn,
32 | num_workers=self.data_cfg.data.num_workers)
33 |
34 | def test_dataloader(self):
35 | return DataLoader(self.val_set, batch_size=1, pin_memory=True, collate_fn=_sparse_collate_fn,
36 | num_workers=self.data_cfg.data.num_workers)
37 |
38 | def predict_dataloader(self):
39 | return DataLoader(self.test_set, batch_size=1, pin_memory=True, collate_fn=_sparse_collate_fn,
40 | num_workers=self.data_cfg.data.num_workers)
41 |
42 |
43 | def _sparse_collate_fn(batch):
44 | if "gt_geometry" in batch[0]:
45 | """ for dataset with ground truth geometry """
46 | data = {}
47 | xyz = []
48 | point_features = []
49 | gt_geometry_list = []
50 | scene_names_list = []
51 |
52 | for _, b in enumerate(batch):
53 | scene_names_list.append(b["scene_name"])
54 | xyz.append(torch.from_numpy(b["xyz"]))
55 | point_features.append(torch.from_numpy(b["point_features"]))
56 | gt_geometry_list.append(b["gt_geometry"])
57 |
58 | data['xyz'] = torch.cat(xyz, dim=0)
59 | data['point_features'] = torch.cat(point_features, dim=0)
60 | data['xyz_splits'] = torch.tensor([c.shape[0] for c in xyz])
61 | data['gt_geometry'] = gt_geometry_list
62 |
63 | data['scene_names'] = scene_names_list
64 |
65 | return data
66 |
67 | else:
68 | data = {}
69 | xyz = []
70 | all_xyz = []
71 | all_normals = []
72 | point_features = []
73 | scene_names_list = []
74 | for _, b in enumerate(batch):
75 | scene_names_list.append(b["scene_name"])
76 | xyz.append(torch.from_numpy(b["xyz"]))
77 | all_xyz.append(torch.from_numpy(b["all_xyz"]))
78 | all_normals.append(torch.from_numpy(b["all_normals"]))
79 | point_features.append(torch.from_numpy(b["point_features"]))
80 |
81 | data['all_xyz'] = torch.cat(all_xyz, dim=0)
82 | data['all_normals'] = torch.cat(all_normals, dim=0)
83 | data['xyz'] = torch.cat(xyz, dim=0)
84 | data['point_features'] = torch.cat(point_features, dim=0)
85 |
86 | data['scene_names'] = scene_names_list
87 | data['row_splits'] = [c.shape[0] for c in all_xyz]
88 | data['xyz_splits'] = torch.tensor([c.shape[0] for c in xyz])
89 | if "gt_onet_sample" in batch[0]:
90 | data['gt_onet_sample'] = [b["gt_onet_sample"] for b in batch]
91 | return data
92 |
93 |
94 |
95 |
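96 | # Usage sketch (mirrors eval.py): batches are plain dicts produced by _sparse_collate_fn, with
97 | # per-scene point counts kept in 'xyz_splits' (and 'row_splits' when dense GT points are present):
98 | #   data_module = DataModule(cfg)
99 | #   data_module.setup("test")
100 | #   batch = next(iter(data_module.val_dataloader()))  # keys include 'xyz', 'point_features', 'scene_names'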
--------------------------------------------------------------------------------
/noksr/data/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .general_dataset import GeneralDataset
2 | # from .scannet_rangeudf import ScannetRangeUDF
3 | from .scannet import Scannet
4 | from .carla import Carla
5 | from .synthetic import Synthetic
6 | from .mixture import Mixture
7 | from .general_dataset import DatasetSpec
8 | from .general_dataset import RandomSafeDataset
9 | from .carla_gt_geometry import get_class
10 | from .scenenn import SceneNN
11 |
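12 | # DataModule looks the dataset class up by name, i.e.
13 | # getattr(import_module('noksr.data.dataset'), cfg.data.dataset), so cfg.data.dataset must match
14 | # one of the class names exported above (e.g. "Scannet", "Carla", "Synthetic", "Mixture", "SceneNN").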
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/augmentation.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/augmentation.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/augmentation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/augmentation.cpython-38.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/carla.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/carla.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/carla_gt_geometry.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/carla_gt_geometry.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/general_dataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/general_dataset.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/general_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/general_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/mixture.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/mixture.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/multiscan.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/multiscan.cpython-38.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/scannet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/scannet.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/scannet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/scannet.cpython-38.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/scannet_rangeudf.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/scannet_rangeudf.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/scannet_rangeudf.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/scannet_rangeudf.cpython-38.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/scenenet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/scenenet.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/scenenn.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/scenenn.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/shapenet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/shapenet.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/synthetic.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/synthetic.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/voxelization_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/voxelization_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/voxelization_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/voxelization_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/voxelizer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/voxelizer.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/__pycache__/voxelizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/data/dataset/__pycache__/voxelizer.cpython-38.pyc
--------------------------------------------------------------------------------
/noksr/data/dataset/augmentation.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import logging
4 | import numpy as np
5 | import scipy
6 | import scipy.ndimage
7 | import scipy.interpolate
8 | import torch
9 |
10 |
11 | # A sparse tensor consists of coordinates and associated features.
12 | # You must apply augmentation to both.
13 | # In 2D, flip, shear, scale, and rotation of images are coordinate transformations;
14 | # color jitter, hue, etc., are feature transformations.
15 | ##############################
16 | # Feature transformations
17 | ##############################
18 | class ChromaticTranslation(object):
19 | '''Add a random color translation to the features; input must be an array of RGB values in [0, 255]'''
20 |
21 | def __init__(self, trans_range_ratio=1e-1):
22 | '''
23 | trans_range_ratio: ratio of translation i.e. 255 * 2 * ratio * rand(-0.5, 0.5)
24 | '''
25 | self.trans_range_ratio = trans_range_ratio
26 |
27 | def __call__(self, coords, feats, labels):
28 | if random.random() < 0.95:
29 | tr = (np.random.rand(1, 3) - 0.5) * 255 * 2 * self.trans_range_ratio
30 | feats[:, :3] = np.clip(tr + feats[:, :3], 0, 255)
31 | return coords, feats, labels
32 |
33 |
34 | class ChromaticAutoContrast(object):
35 |
36 | def __init__(self, randomize_blend_factor=True, blend_factor=0.5):
37 | self.randomize_blend_factor = randomize_blend_factor
38 | self.blend_factor = blend_factor
39 |
40 | def __call__(self, coords, feats, labels):
41 | if random.random() < 0.2:
42 | # mean = np.mean(feats, 0, keepdims=True)
43 | # std = np.std(feats, 0, keepdims=True)
44 | # lo = mean - std
45 | # hi = mean + std
46 | lo = np.min(feats, 0, keepdims=True)
47 | hi = np.max(feats, 0, keepdims=True)
48 |
49 | scale = 255 / (hi - lo)
50 |
51 | contrast_feats = (feats - lo) * scale
52 |
53 | blend_factor = random.random() if self.randomize_blend_factor else self.blend_factor
54 | feats = (1 - blend_factor) * feats + blend_factor * contrast_feats
55 | return coords, feats, labels
56 |
57 |
58 | class ChromaticJitter(object):
59 |
60 | def __init__(self, std=0.01):
61 | self.std = std
62 |
63 | def __call__(self, coords, feats, labels):
64 | if random.random() < 0.95:
65 | noise = np.random.randn(feats.shape[0], 3)
66 | noise *= self.std * 255
67 | feats[:, :3] = np.clip(noise + feats[:, :3], 0, 255)
68 | return coords, feats, labels
69 |
70 |
71 | class HueSaturationTranslation(object):
72 |
73 | @staticmethod
74 | def rgb_to_hsv(rgb):
75 | # Translated from source of colorsys.rgb_to_hsv
76 | # r, g, b should be numpy arrays with values between 0 and 255
77 | # rgb_to_hsv returns an array of floats between 0.0 and 1.0.
78 | rgb = rgb.astype('float')
79 | hsv = np.zeros_like(rgb)
80 | # in case an RGBA array was passed, just copy the A channel
81 | hsv[..., 3:] = rgb[..., 3:]
82 | r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
83 | maxc = np.max(rgb[..., :3], axis=-1)
84 | minc = np.min(rgb[..., :3], axis=-1)
85 | hsv[..., 2] = maxc
86 | mask = maxc != minc
87 | hsv[mask, 1] = (maxc - minc)[mask] / maxc[mask]
88 | rc = np.zeros_like(r)
89 | gc = np.zeros_like(g)
90 | bc = np.zeros_like(b)
91 | rc[mask] = (maxc - r)[mask] / (maxc - minc)[mask]
92 | gc[mask] = (maxc - g)[mask] / (maxc - minc)[mask]
93 | bc[mask] = (maxc - b)[mask] / (maxc - minc)[mask]
94 | hsv[..., 0] = np.select([r == maxc, g == maxc], [bc - gc, 2.0 + rc - bc], default=4.0 + gc - rc)
95 | hsv[..., 0] = (hsv[..., 0] / 6.0) % 1.0
96 | return hsv
97 |
98 | @staticmethod
99 | def hsv_to_rgb(hsv):
100 | # Translated from source of colorsys.hsv_to_rgb
101 | # h, s should be numpy arrays with values between 0.0 and 1.0
102 | # v should be a numpy array with values between 0.0 and 255.0
103 | # hsv_to_rgb returns an array of uints between 0 and 255.
104 | rgb = np.empty_like(hsv)
105 | rgb[..., 3:] = hsv[..., 3:]
106 | h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
107 | i = (h * 6.0).astype('uint8')
108 | f = (h * 6.0) - i
109 | p = v * (1.0 - s)
110 | q = v * (1.0 - s * f)
111 | t = v * (1.0 - s * (1.0 - f))
112 | i = i % 6
113 | conditions = [s == 0.0, i == 1, i == 2, i == 3, i == 4, i == 5]
114 | rgb[..., 0] = np.select(conditions, [v, q, p, p, t, v], default=v)
115 | rgb[..., 1] = np.select(conditions, [v, v, v, q, p, p], default=t)
116 | rgb[..., 2] = np.select(conditions, [v, p, t, v, v, q], default=p)
117 | return rgb.astype('uint8')
118 |
119 | def __init__(self, hue_max, saturation_max):
120 | self.hue_max = hue_max
121 | self.saturation_max = saturation_max
122 |
123 | def __call__(self, coords, feats, labels):
124 | # Assume feat[:, :3] is rgb
125 | hsv = HueSaturationTranslation.rgb_to_hsv(feats[:, :3])
126 | hue_val = (random.random() - 0.5) * 2 * self.hue_max
127 | sat_ratio = 1 + (random.random() - 0.5) * 2 * self.saturation_max
128 | hsv[..., 0] = np.remainder(hue_val + hsv[..., 0] + 1, 1)
129 | hsv[..., 1] = np.clip(sat_ratio * hsv[..., 1], 0, 1)
130 | feats[:, :3] = np.clip(HueSaturationTranslation.hsv_to_rgb(hsv), 0, 255)
131 |
132 | return coords, feats, labels
133 |
134 |
135 | ##############################
136 | # Coordinate transformations
137 | ##############################
138 | class RandomHorizontalFlip(object):
139 |
140 | def __init__(self, upright_axis, is_temporal):
141 | '''
142 | upright_axis: axis index among x,y,z, i.e. 2 for z
143 | '''
144 | self.is_temporal = is_temporal
145 | self.D = 4 if is_temporal else 3
146 | self.upright_axis = {'x': 0, 'y': 1, 'z': 2}[upright_axis.lower()]
147 | # Use the rest of axes for flipping.
148 | self.horz_axes = set(range(self.D)) - set([self.upright_axis])
149 |
150 | def __call__(self, coords, feats, labels):
151 | if random.random() < 0.95:
152 | for curr_ax in self.horz_axes:
153 | if random.random() < 0.5:
154 | coord_max = np.max(coords[:, curr_ax])
155 | coords[:, curr_ax] = coord_max - coords[:, curr_ax]
156 | return coords, feats, labels
157 |
158 |
159 | class ElasticDistortion:
160 |
161 | def __init__(self, distortion_params):
162 | self.distortion_params = distortion_params
163 |
164 | def elastic_distortion(self, coords, granularity, magnitude):
165 | '''Apply elastic distortion on sparse coordinate space.
166 |
167 | coords: numpy array of (number of points, at least 3 spatial dims)
168 | granularity: size of the noise grid (in same scale[m/cm] as the voxel grid)
169 | magnitude: noise multiplier
170 | '''
171 | blurx = np.ones((3, 1, 1, 1)).astype('float32') / 3
172 | blury = np.ones((1, 3, 1, 1)).astype('float32') / 3
173 | blurz = np.ones((1, 1, 3, 1)).astype('float32') / 3
174 | coords_min = coords.min(0)
175 |
176 | # Create Gaussian noise tensor of the size given by granularity.
177 | noise_dim = ((coords - coords_min).max(0) // granularity).astype(int) + 3
178 | noise = np.random.randn(*noise_dim, 3).astype(np.float32)
179 |
180 | # Smoothing.
181 | for _ in range(2):
182 | noise = scipy.ndimage.filters.convolve(noise, blurx, mode='constant', cval=0)
183 | noise = scipy.ndimage.filters.convolve(noise, blury, mode='constant', cval=0)
184 | noise = scipy.ndimage.filters.convolve(noise, blurz, mode='constant', cval=0)
185 |
186 | # Trilinearly interpolate the noise filters for each spatial dimension.
187 | ax = [
188 | np.linspace(d_min, d_max, d)
189 | for d_min, d_max, d in zip(coords_min - granularity, coords_min +
190 | granularity * (noise_dim - 2), noise_dim)
191 | ]
192 | interp = scipy.interpolate.RegularGridInterpolator(ax, noise, bounds_error=0, fill_value=0)
193 | coords = coords + interp(coords) * magnitude
194 | return coords
195 |
196 | def __call__(self, pointcloud):
197 | if self.distortion_params is not None:
198 | if random.random() < 0.95:
199 | for granularity, magnitude in self.distortion_params:
200 | pointcloud = self.elastic_distortion(pointcloud, granularity, magnitude)
201 | return pointcloud
202 |
203 |
204 | class Compose(object):
205 | '''Composes several transforms together.'''
206 |
207 | def __init__(self, transforms):
208 | self.transforms = transforms
209 |
210 | def __call__(self, *args):
211 | for t in self.transforms:
212 | args = t(*args)
213 | return args
214 |
215 |
216 | class cfl_collate_fn_factory:
217 | '''Generates collate function for coords, feats, labels.
218 |
219 | Args:
220 | limit_numpoints: If 0 or False, does not alter batch size. If positive integer, limits batch
221 | size so that the number of input coordinates is below limit_numpoints.
222 | '''
223 |
224 | def __init__(self, limit_numpoints):
225 | self.limit_numpoints = limit_numpoints
226 |
227 | def __call__(self, list_data):
228 | coords, feats, labels = list(zip(*list_data))
229 | coords_batch, feats_batch, labels_batch = [], [], []
230 |
231 | batch_id = 0
232 | batch_num_points = 0
233 | for batch_id, _ in enumerate(coords):
234 | num_points = coords[batch_id].shape[0]
235 | batch_num_points += num_points
236 | if self.limit_numpoints and batch_num_points > self.limit_numpoints:
237 | num_full_points = sum(len(c) for c in coords)
238 | num_full_batch_size = len(coords)
239 | logging.warning(
240 | f'\t\tCannot fit {num_full_points} points into {self.limit_numpoints} points '
241 | f'limit. Truncating batch size at {batch_id} out of {num_full_batch_size} with {batch_num_points - num_points}.'
242 | )
243 | break
244 | coords_batch.append(
245 | torch.cat((torch.from_numpy(coords[batch_id]).int(),
246 | torch.ones(num_points, 1).int() * batch_id), 1))
247 | feats_batch.append(torch.from_numpy(feats[batch_id]))
248 | labels_batch.append(torch.from_numpy(labels[batch_id]).int())
249 |
250 | batch_id += 1
251 |
252 | # Concatenate all lists
253 | coords_batch = torch.cat(coords_batch, 0).int()
254 | feats_batch = torch.cat(feats_batch, 0).float()
255 | labels_batch = torch.cat(labels_batch, 0).int()
256 | return coords_batch, feats_batch, labels_batch
257 |
258 |
259 | class cflt_collate_fn_factory:
260 | '''Generates collate function for coords, feats, labels, point_clouds, transformations.
261 |
262 | Args:
263 | limit_numpoints: If 0 or False, does not alter batch size. If positive integer, limits batch
264 | size so that the number of input coordinates is below limit_numpoints.
265 | '''
266 |
267 | def __init__(self, limit_numpoints):
268 | self.limit_numpoints = limit_numpoints
269 |
270 | def __call__(self, list_data):
271 | coords, feats, labels, pointclouds, transformations = list(zip(*list_data))
272 | cfl_collate_fn = cfl_collate_fn_factory(limit_numpoints=self.limit_numpoints)
273 | coords_batch, feats_batch, labels_batch = cfl_collate_fn(list(zip(coords, feats, labels)))
274 | num_truncated_batch = coords_batch[:, -1].max().item() + 1
275 |
276 | batch_id = 0
277 | pointclouds_batch, transformations_batch = [], []
278 | for pointcloud, transformation in zip(pointclouds, transformations):
279 | if batch_id >= num_truncated_batch:
280 | break
281 | pointclouds_batch.append(
282 | torch.cat((torch.from_numpy(pointcloud), torch.ones(pointcloud.shape[0], 1) * batch_id),
283 | 1))
284 | transformations_batch.append(
285 | torch.cat(
286 | (torch.from_numpy(transformation), torch.ones(transformation.shape[0], 1) * batch_id),
287 | 1))
288 | batch_id += 1
289 |
290 | pointclouds_batch = torch.cat(pointclouds_batch, 0).float()
291 | transformations_batch = torch.cat(transformations_batch, 0).float()
292 | return coords_batch, feats_batch, labels_batch, pointclouds_batch, transformations_batch
293 |
294 | class RandomDropout(object):
295 |
296 | def __init__(self, dropout_ratio=0.2, dropout_application_ratio=0.5):
297 | """
298 | dropout_ratio: probability of applying dropout and the fraction of points dropped when it is applied
299 | """
300 | self.dropout_ratio = dropout_ratio
301 | self.dropout_application_ratio = dropout_application_ratio
302 |
303 | def __call__(self, coords, feats, labels):
304 | if random.random() < self.dropout_ratio:
305 | N = len(coords)
306 | inds = np.random.choice(N, int(N * (1 - self.dropout_ratio)), replace=False)
307 | return coords[inds], feats[inds], labels[inds]
308 | return coords, feats, labels
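309 |
310 | # Composition sketch using the classes above (the particular chain is illustrative, not the
311 | # training configuration): each feature transform maps (coords, feats, labels) -> (coords, feats, labels).
312 | #   feat_aug = Compose([ChromaticAutoContrast(), ChromaticTranslation(0.1), ChromaticJitter(0.05)])
313 | #   coords, feats, labels = feat_aug(coords, feats, labels)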
--------------------------------------------------------------------------------
/noksr/data/dataset/carla.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from torch.utils.data import Dataset
4 |
5 | import numpy as np
6 |
7 | from noksr.utils.transform import ComposedTransforms
8 | from noksr.data.dataset.carla_gt_geometry import get_class
9 | from noksr.data.dataset.general_dataset import DatasetSpec as DS
10 | from noksr.data.dataset.general_dataset import RandomSafeDataset
11 |
12 | from pycg import exp
13 |
14 |
15 | class Carla(RandomSafeDataset):
16 | def __init__(self, cfg, split):
17 | super().__init__(cfg, split)
18 |
19 | self.skip_on_error = False
20 | self.custom_name = "carla"
21 | self.cfg = cfg
22 |
23 | self.split = split # use only train set for overfitting
24 | split = self.split
25 | self.intake_start = cfg.data.intake_start
26 | self.take = cfg.data.take
27 | self.input_splats = cfg.data.input_splats
28 |
29 | self.gt_type = cfg.data.supervision.gt_type
30 |
31 | self.transforms = ComposedTransforms(cfg.data.transforms)
32 | self.use_dummy_gt = False
33 |
34 | # If drives not specified, use all sub-folders
35 | base_path = Path(cfg.data.base_path)
36 | drives = cfg.data.drives
37 | if drives is None:
38 | drives = os.listdir(base_path)
39 | drives = [c for c in drives if (base_path / c).is_dir()]
40 | self.drives = drives
41 | self.input_path = cfg.data.input_path
42 |
43 | # Get all items
44 | self.all_items = []
45 | self.drive_base_paths = {}
46 | for c in drives:
47 | self.drive_base_paths[c] = base_path / c
48 | split_file = self.drive_base_paths[c] / (split + '.lst')
49 | with split_file.open('r') as f:
50 | models_c = f.read().split('\n')
51 | if '' in models_c:
52 | models_c.remove('')
53 | self.all_items += [{'drive': c, 'item': m} for m in models_c]
54 |
55 | if self.cfg.data.over_fitting:
56 | self.all_items = self.all_items[self.intake_start:self.take+self.intake_start]
57 |
58 |
59 |
60 | def __len__(self):
61 | return len(self.all_items)
62 |
63 | def get_name(self):
64 | return f"{self.custom_name}-cat{len(self.drives)}-{self.split}"
65 |
66 | def get_short_name(self):
67 | return self.custom_name
68 |
69 | def _get_item(self, data_id, rng):
70 | # self.num_input_points = 50000
71 | drive_name = self.all_items[data_id]['drive']
72 | item_name = self.all_items[data_id]['item']
73 |
74 | named_data = {}
75 |
76 | try:
77 | if self.input_path is None:
78 | input_data = np.load(self.drive_base_paths[drive_name] / item_name / 'pointcloud.npz')
79 | else:
80 | input_data = np.load(Path(self.input_path) / drive_name / item_name / 'pointcloud.npz')
81 | except FileNotFoundError:
82 | exp.logger.warning(f"File not found for AV dataset for {item_name}")
83 | raise ConnectionAbortedError
84 |
85 | named_data[DS.SHAPE_NAME] = "/".join([drive_name, item_name])
86 | named_data[DS.INPUT_PC] = input_data['points'].astype(np.float32)
87 | named_data[DS.TARGET_NORMAL] = input_data['normals'].astype(np.float32)
88 |
89 | if self.transforms is not None:
90 | named_data = self.transforms(named_data, rng)
91 |
92 | point_features = np.zeros(shape=(len(named_data[DS.INPUT_PC]), 0), dtype=np.float32)
93 | if self.cfg.model.network.use_normal:
94 | point_features = np.concatenate((point_features, named_data[DS.TARGET_NORMAL]), axis=1)
95 | if self.cfg.model.network.use_xyz:
96 | point_features = np.concatenate((point_features, named_data[DS.INPUT_PC]), axis=1) # add xyz to point features
97 |
98 | xyz = named_data[DS.INPUT_PC]
99 | normals = named_data[DS.TARGET_NORMAL]
100 |
101 |
102 | geom_cls = get_class(self.gt_type)
103 |
104 | if (self.drive_base_paths[drive_name] / item_name / "groundtruth.bin").exists():
105 | named_data[DS.GT_GEOMETRY] = geom_cls.load(self.drive_base_paths[drive_name] / item_name / "groundtruth.bin")
106 | data = {
107 | "gt_geometry": named_data[DS.GT_GEOMETRY],
108 | "xyz": xyz, # N, 3
109 | "normals": normals, # N, 3
110 | "scene_name": named_data[DS.SHAPE_NAME],
111 | "point_features": point_features, # N, K
112 | }
113 | else:
114 | data = {
115 | "all_xyz": input_data['ref_xyz'].astype(np.float32),
116 | "all_normals": input_data['ref_normals'].astype(np.float32),
117 | "xyz": xyz, # N, 3
118 | "normals": normals, # N, 3
119 | "scene_name": named_data[DS.SHAPE_NAME],
120 | "point_features": point_features, # N, K
121 | }
122 |
123 | return data
124 |
--------------------------------------------------------------------------------
/noksr/data/dataset/carla_gt_geometry.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from pathlib import Path
4 | from pycg.exp import lru_cache_class, logger
5 |
6 | from pycg.isometry import Isometry
7 | import torch.nn.functional as F
8 |
9 |
10 | class AVGroundTruthGeometry:
11 | def __init__(self):
12 | pass
13 |
14 | @classmethod
15 | def load(cls, path: Path):
16 | raise NotImplementedError
17 |
18 | def save(self, path: Path):
19 | raise NotImplementedError
20 |
21 | def crop(self, bounds: np.ndarray):
22 | # bounds: (C, 2, 3) min_coords and max_coords
23 | raise NotImplementedError
24 |
25 | def transform(self, iso: Isometry = Isometry(), scale: float = 1.0):
26 | # p <- s(Rp+t)
27 | raise NotImplementedError
28 |
29 |
30 | class DensePointsGroundTruthGeometry(AVGroundTruthGeometry):
31 | def __init__(self, xyz: np.ndarray, normal: np.ndarray):
32 | super().__init__()
33 | self.xyz = xyz
34 | self.normal = normal
35 | assert self.xyz.shape[0] == self.normal.shape[0]
36 | assert self.xyz.shape[1] == self.normal.shape[1] == 3
37 |
38 | def save(self, path: Path):
39 | with path.open("wb") as f:
40 | np.savez_compressed(f, xyz=self.xyz, normal=self.normal)
41 |
42 | def transform(self, iso: Isometry = Isometry(), scale: float = 1.0):
43 | self.xyz = scale * (iso @ self.xyz)
44 | self.normal = iso.rotation @ self.normal
45 |
46 | def is_empty(self):
47 | return self.xyz.shape[0] < 64
48 |
49 | @classmethod
50 | def load(cls, path: Path):
51 | res = np.load(path, allow_pickle=True)
52 | inst = cls(res['xyz'], res['normal'])
53 | return inst
54 |
55 | @classmethod
56 | def empty(cls):
57 | return cls(np.zeros((0, 3)), np.zeros((0, 3)))
58 |
59 | @lru_cache_class(maxsize=None)
60 | def torch_attr(self):
61 | return torch.from_numpy(self.xyz).float().cuda(), torch.from_numpy(self.normal).float().cuda()
62 |
63 | def query_sdf(self, queries: torch.Tensor):
64 | import ext
65 | all_points_torch, all_normals_torch = self.torch_attr()
66 |
67 | sdf_kwargs = {
68 | 'queries': queries, 'ref_xyz': all_points_torch, 'ref_normal': all_normals_torch,
69 | 'nb_points': 8, 'stdv': 3.0, 'adaptive_knn': 8
70 | }
71 | try:
72 | query_sdf = -ext.sdfgen.sdf_from_points(**sdf_kwargs)[0]
73 | except MemoryError:
74 | logger.warning("Query SDF OOM. Try empty pytorch cache.")
75 | torch.cuda.empty_cache()
76 | query_sdf = -ext.sdfgen.sdf_from_points(**sdf_kwargs)[0]
77 |
78 | return query_sdf
79 |
80 | def crop(self, bounds: np.ndarray):
81 | crops = []
82 | for cur_bound in bounds:
83 | min_bound, max_bound = cur_bound[0], cur_bound[1]
84 | crop_mask = np.logical_and.reduce([
85 | self.xyz[:, 0] > min_bound[0], self.xyz[:, 0] < max_bound[0],
86 | self.xyz[:, 1] > min_bound[1], self.xyz[:, 1] < max_bound[1],
87 | self.xyz[:, 2] > min_bound[2], self.xyz[:, 2] < max_bound[2]
88 | ])
89 | crop_inst = self.__class__(self.xyz[crop_mask], self.normal[crop_mask])
90 | crops.append(crop_inst)
91 | return crops
92 |
93 |
94 | class PointTSDFVolumeGroundTruthGeometry(AVGroundTruthGeometry):
95 | def __init__(self, dense_points: DensePointsGroundTruthGeometry,
96 | volume: np.ndarray, volume_min: np.ndarray, volume_max: np.ndarray):
97 | super().__init__()
98 | self.dense_points = dense_points
99 | self.volume = volume
100 | self.volume_min = volume_min
101 | self.volume_max = volume_max
102 | assert np.all(self.volume_min < self.volume_max)
103 |
104 | @property
105 | def xyz(self):
106 | return self.dense_points.xyz
107 |
108 | @property
109 | def normal(self):
110 | return self.dense_points.normal
111 |
112 | @classmethod
113 | def empty(cls):
114 | return cls(DensePointsGroundTruthGeometry.empty(), np.ones((1, 1, 1)), np.zeros(3,), np.ones(3,))
115 |
116 | def is_empty(self):
117 | return self.dense_points.is_empty()
118 |
119 | def save(self, path: Path):
120 | with path.open("wb") as f:
121 | np.savez_compressed(f, xyz=self.dense_points.xyz, normal=self.dense_points.normal,
122 | volume=self.volume,
123 | volume_min=self.volume_min, volume_max=self.volume_max)
124 |
125 | def transform(self, iso: Isometry = Isometry(), scale: float = 1.0):
126 | assert iso.q.is_unit(), "Volume transform does not support rotation yet"
127 | self.dense_points.transform(iso, scale)
128 | self.volume_min = scale * (self.volume_min + iso.t)
129 | self.volume_max = scale * (self.volume_max + iso.t)
130 |
131 | @classmethod
132 | def load(cls, path: Path):
133 | dense_points = DensePointsGroundTruthGeometry.load(path)
134 | res = np.load(path)
135 | return cls(dense_points, res['volume'], res['volume_min'], res['volume_max'])
136 |
137 | @lru_cache_class(maxsize=None)
138 | def torch_attr(self):
139 | return *self.dense_points.torch_attr(), torch.from_numpy(self.volume).float().cuda()
140 |
141 | def query_classification(self, queries: torch.Tensor, band: float = 1.0):
142 | """
143 | Return integer classifications of the query points:
144 | 0 - near surface
145 | 1 - far surface empty
146 | 2 - unknown (also for query points outside volume)
147 | :param queries: torch.Tensor (N, 3)
148 | :param band: 0-1 band size to be classified as 'near-surface'
149 | :return: (N, ) ids
150 | """
151 | _, _, volume_input = self.torch_attr()
152 |
153 | in_volume_mask = (queries[:, 0] >= self.volume_min[0]) & (queries[:, 0] <= self.volume_max[0]) & \
154 | (queries[:, 1] >= self.volume_min[1]) & (queries[:, 1] <= self.volume_max[1]) & \
155 | (queries[:, 2] >= self.volume_min[2]) & (queries[:, 2] <= self.volume_max[2])
156 |
157 | queries_norm = queries[in_volume_mask].clone()
158 | for i in range(3):
159 | queries_norm[:, i] = (queries_norm[:, i] - self.volume_min[i]) / \
160 | (self.volume_max[i] - self.volume_min[i]) * 2. - 1.
161 | sample_grid = torch.fliplr(queries_norm)[None, None, None, ...]
162 | # B=1,C=1,Di=1,Hi=1,Wi x B=1,Do=1,Ho=1,Wo,3 --> B=1,C=1,Do=1,Ho=1,Wo
163 | sample_res = F.grid_sample(volume_input[None, None, ...], sample_grid,
164 | mode='nearest', padding_mode='border', align_corners=True)[0, 0, 0, 0]
165 |
166 | cls_in_volume = torch.ones_like(sample_res, dtype=torch.long)
167 | cls_in_volume[~torch.isfinite(sample_res)] = 2
168 | cls_in_volume[torch.abs(sample_res) < band] = 0
169 |
170 | cls = torch.ones(queries.size(0), dtype=torch.long, device=cls_in_volume.device) * 2
171 | cls[in_volume_mask] = cls_in_volume
172 |
173 | return cls
174 |
175 | def query_sdf(self, queries: torch.Tensor):
176 | return self.dense_points.query_sdf(queries)
177 |
178 | def crop(self, bounds: np.ndarray):
179 | point_crops = self.dense_points.crop(bounds)
180 |
181 | volume_x_ticks = np.linspace(self.volume_min[0], self.volume_max[0], self.volume.shape[0])
182 | volume_y_ticks = np.linspace(self.volume_min[1], self.volume_max[1], self.volume.shape[1])
183 | volume_z_ticks = np.linspace(self.volume_min[2], self.volume_max[2], self.volume.shape[2])
184 |
185 | crops = []
186 | for cur_point_crop, cur_bound in zip(point_crops, bounds):
187 | min_bound, max_bound = cur_bound[0], cur_bound[1]
188 | # volume_ticks[id_min] <= min_bound < max_bound <= volume_ticks[id_max]
189 | x_id_min = np.maximum(np.searchsorted(volume_x_ticks, min_bound[0], side='right') - 1, 0)
190 | x_id_max = np.minimum(np.searchsorted(volume_x_ticks, max_bound[0], side='left'),
191 | volume_x_ticks.shape[0] - 1)
192 | y_id_min = np.maximum(np.searchsorted(volume_y_ticks, min_bound[1], side='right') - 1, 0)
193 | y_id_max = np.minimum(np.searchsorted(volume_y_ticks, max_bound[1], side='left'),
194 | volume_y_ticks.shape[0] - 1)
195 | z_id_min = np.maximum(np.searchsorted(volume_z_ticks, min_bound[2], side='right') - 1, 0)
196 | z_id_max = np.minimum(np.searchsorted(volume_z_ticks, max_bound[2], side='left'),
197 | volume_z_ticks.shape[0] - 1)
198 | crops.append(self.__class__(
199 | cur_point_crop,
200 | self.volume[x_id_min:x_id_max+1, y_id_min:y_id_max+1, z_id_min:z_id_max+1],
201 | np.array([volume_x_ticks[x_id_min], volume_y_ticks[y_id_min], volume_z_ticks[z_id_min]]),
202 | np.array([volume_x_ticks[x_id_max], volume_y_ticks[y_id_max], volume_z_ticks[z_id_max]])
203 | ))
204 | return crops
205 |
206 |
207 | def get_class(class_name):
208 | if class_name == "DensePoints":
209 | return DensePointsGroundTruthGeometry
210 | elif class_name == "PointTSDFVolume":
211 | return PointTSDFVolumeGroundTruthGeometry
212 | else:
213 | raise NotImplementedError
214 |
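215 | # Usage sketch (mirrors carla.py / eval.py): cfg.data.supervision.gt_type selects the geometry class,
216 | # and torch_attr() returns CUDA tensors (xyz, normal) for DensePoints, plus the TSDF volume for
217 | # PointTSDFVolume:
218 | #   geom_cls = get_class(cfg.data.supervision.gt_type)
219 | #   gt = geom_cls.load(item_dir / "groundtruth.bin")   # item_dir is a placeholder Path
220 | #   xyz_t, normal_t = gt.torch_attr()[:2]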
--------------------------------------------------------------------------------
/noksr/data/dataset/general_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import Dataset
3 | import open3d as o3d
4 | from enum import Enum
5 | from numpy.random import RandomState
6 | import multiprocessing
7 | from omegaconf import DictConfig, ListConfig
8 | from pycg import exp
9 |
10 |
11 | class GeneralDataset(Dataset):
12 | """ Only used for Carla dataset """
13 | # Augmentation arguments
14 | SCALE_AUGMENTATION_BOUND = (0.9, 1.1)
15 | ROTATION_AUGMENTATION_BOUND = ((-np.pi / 64, np.pi / 64), (-np.pi / 64, np.pi / 64), (-np.pi,
16 | np.pi))
17 | TRANSLATION_AUGMENTATION_RATIO_BOUND = ((-0.2, 0.2), (-0.2, 0.2), (0, 0))
18 | ELASTIC_DISTORT_PARAMS = ((0.2, 0.4), (0.8, 1.6))
19 |
20 | ROTATION_AXIS = 'z'
21 | LOCFEAT_IDX = 2
22 |
23 | def __init__(self, cfg, split):
24 | pass
25 |
26 | class DatasetSpec(Enum):
27 | SCENE_NAME = 100
28 | SHAPE_NAME = 0
29 | INPUT_PC = 200
30 | TARGET_NORMAL = 300
31 | INPUT_COLOR = 400
32 | INPUT_SENSOR_POS = 500
33 | GT_DENSE_PC = 600
34 | GT_DENSE_NORMAL = 700
35 | GT_DENSE_COLOR = 800
36 | GT_MESH = 900
37 | GT_MESH_SOUP = 1000
38 | GT_ONET_SAMPLE = 1100
39 | GT_GEOMETRY = 1200
40 | DATASET_CFG = 1300
41 |
42 | class RandomSafeDataset(Dataset):
43 | """
44 | A dataset class that provides a deterministic random seed.
45 | To keep the validation set consistent, is_val is set to True for validation/test splits.
46 | Usage: First, inherit from this class.
47 | Then, implement _get_item(data_id, rng);
48 | Last, use the provided rng as the random state for all sampling in that call.
49 | """
50 |
51 | def __init__(self, cfg, split):
52 | self._seed = cfg.global_train_seed
53 | self._is_val = split in ['val', 'test']
54 | self.skip_on_error = False
55 | if not self._is_val:
56 | self._manager = multiprocessing.Manager()
57 | self._read_count = self._manager.dict()
58 | self._rc_lock = multiprocessing.Lock()
59 |
60 | def get_rng(self, idx):
61 | if self._is_val:
62 | return RandomState(self._seed)
63 | with self._rc_lock:
64 | if idx not in self._read_count:
65 | self._read_count[idx] = 0
66 | rng = RandomState(exp.deterministic_hash((idx, self._read_count[idx], self._seed)))
67 | self._read_count[idx] += 1
68 | return rng
69 |
70 | def sanitize_specs(self, old_spec, available_spec):
71 | old_spec = set(old_spec)
72 | available_spec = set(available_spec)
73 | for os in old_spec:
74 | assert isinstance(os, DatasetSpec)
75 | new_spec = old_spec.intersection(available_spec)
76 | # lack_spec = old_spec.difference(new_spec)
77 | # if len(lack_spec) > 0:
78 | # exp.logger.warning(f"Lack spec {lack_spec}.")
79 | return new_spec
80 |
81 | def _get_item(self, data_id, rng):
82 | raise NotImplementedError
83 |
84 | def __getitem__(self, data_id):
85 | rng = self.get_rng(data_id)
86 | if self.skip_on_error:
87 | try:
88 | return self._get_item(data_id, rng)
89 | except ConnectionAbortedError:
90 | return self.__getitem__(rng.randint(0, len(self) - 1))
91 | except Exception:
92 | # Just return a random other item.
93 |                 exp.logger.warning(f"Error while getting item {data_id}; returning a random item instead.")
94 | return self.__getitem__(rng.randint(0, len(self) - 1))
95 | else:
96 | try:
97 | return self._get_item(data_id, rng)
98 | except ConnectionAbortedError:
99 | return self.__getitem__(rng.randint(0, len(self) - 1))
100 |
101 |
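
As the RandomSafeDataset docstring above describes, subclasses implement _get_item and draw all randomness from the per-index rng. A minimal hypothetical subclass following that protocol (illustrative only, not part of the repository; assumes the cfg object provides global_train_seed):

import numpy as np

class ToyRandomDataset(RandomSafeDataset):
    def __init__(self, cfg, split, n_items=100):
        super().__init__(cfg, split)
        self.n_items = n_items

    def __len__(self):
        return self.n_items

    def _get_item(self, data_id, rng):
        # All randomness goes through the provided rng so that repeated reads are reproducible.
        noise = rng.normal(0.0, 0.01, size=(1024, 3)).astype(np.float32)
        return {"xyz": noise, "scene_name": f"toy_{data_id}"}
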
--------------------------------------------------------------------------------
/noksr/data/dataset/mixture.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import Dataset
5 | from noksr.data.dataset.scannet import Scannet
6 | from noksr.data.dataset.synthetic import Synthetic
7 |
8 | class Mixture(Dataset):
9 | def __init__(self, cfg, split):
10 |
11 | self.scannet_dataset = Scannet(cfg, split)
12 | self.synthetic_dataset = Synthetic(cfg, split)
13 | self.cfg = cfg
14 | self.split = split
15 | self.over_fitting = cfg.data.over_fitting
16 |
17 | # Combine the lengths of both datasets
18 | self.length = len(self.scannet_dataset) + len(self.synthetic_dataset)
19 |
20 | def __len__(self):
21 | if self.split == 'val':
22 | if self.cfg.data.validation_set == "ScanNet":
23 | return len(self.scannet_dataset)
24 | elif self.cfg.data.validation_set == "Synthetic":
25 | return len(self.synthetic_dataset)
26 |
27 | return self.length
28 |
29 | def __getitem__(self, idx):
30 | # Determine which dataset to load from based on idx
31 | if self.split == 'val':
32 | if self.cfg.data.validation_set == "ScanNet":
33 | return self.scannet_dataset[idx]
34 | elif self.cfg.data.validation_set == "Synthetic":
35 | return self.synthetic_dataset[idx]
36 |
37 | if idx < len(self.scannet_dataset):
38 | return self.scannet_dataset[idx]
39 | else:
40 | return self.synthetic_dataset[idx - len(self.scannet_dataset)]
41 |
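
The routing above simply concatenates the two training sets and offsets indices into the second one. A worked example with hypothetical sizes:

# Illustrative routing, assuming len(scannet_dataset) == 1200 and len(synthetic_dataset) == 800:
# mixture[42]   -> scannet_dataset[42]
# mixture[1500] -> synthetic_dataset[1500 - 1200] == synthetic_dataset[300]
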
--------------------------------------------------------------------------------
/noksr/data/dataset/scannet.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import Dataset
6 | from noksr.utils.serialization import encode
7 | import open3d as o3d
8 |
9 | class Scannet(Dataset):
10 | def __init__(self, cfg, split):
11 | self.cfg = cfg
12 | self.split = 'val' if self.cfg.data.over_fitting else split # use only val set for overfitting
13 | if 'ScanNet' in cfg.data: # subset of mixture data
14 | self.dataset_root_path = cfg.data.ScanNet.dataset_path
15 | self.dataset_path = cfg.data.ScanNet.dataset_path
16 | else:
17 | self.dataset_root_path = cfg.data.dataset_path
18 | self.dataset_path = cfg.data.dataset_path
19 |
20 | self.metadata = cfg.data.metadata
21 | self.num_input_points = cfg.data.num_input_points
22 | self.take = cfg.data.take
23 | self.intake_start = cfg.data.intake_start
24 | self.uniform_sampling = cfg.data.uniform_sampling
25 | self.input_splats = cfg.data.input_splats
26 | self.std_dev = cfg.data.std_dev
27 |
28 | self.in_memory = cfg.data.in_memory
29 |         self.dataset_split = "test" if split == "test" else "train" # train and val scenes are all stored under the train split
30 | self.data_map = {
31 | "train": self.metadata.train_list,
32 | "val": self.metadata.val_list,
33 | "test": self.metadata.test_list
34 | }
35 | self._load_from_disk()
36 |
37 | def _load_from_disk(self):
38 | with open(getattr(self.metadata, f"{self.split}_list")) as f:
39 | self.scene_names = [line.strip() for line in f]
40 |
41 | self.scenes = []
42 | if self.cfg.data.over_fitting:
43 | self.scene_names = self.scene_names[self.intake_start:self.take+self.intake_start]
44 | if len(self.scene_names) == 1: # if only one scene is taken, overfit on scene 0221_00
45 | self.scene_names = ['scene0221_00']
46 | for scene_name in tqdm(self.scene_names, desc=f"Loading {self.split} data from disk"):
47 | scene_path = os.path.join(self.dataset_path, self.split, f"{scene_name}.pth")
48 | scene = torch.load(scene_path)
49 | scene["xyz"] = scene["xyz"].astype(np.float32)
50 | scene["rgb"] = scene["rgb"].astype(np.float32)
51 | scene['scene_name'] = scene_name
52 | point_features = np.zeros(shape=(len(scene['xyz']), 0), dtype=np.float32)
53 | if self.cfg.model.network.use_color:
54 | point_features = np.concatenate((point_features, scene['rgb']), axis=1)
55 | if self.cfg.model.network.use_normal:
56 | point_features = np.concatenate((point_features, scene['normal']), axis=1)
57 | if self.cfg.model.network.use_xyz:
58 | point_features = np.concatenate((point_features, scene['xyz']), axis=1) # add xyz to point features
59 | scene['point_features'] = point_features
60 | self.scenes.append(scene)
61 |
62 | def __len__(self):
63 | return len(self.scenes)
64 |
65 | def __getitem__(self, idx):
66 | scene = self.scenes[idx]
67 | all_xyz = scene['xyz']
68 | all_normals = scene['normal']
69 | scene_name = scene['scene_name']
70 |
71 | # sample input points
72 | num_points = scene["xyz"].shape[0]
73 | num_input_points = self.num_input_points
74 | if num_input_points == -1:
75 | xyz = scene["xyz"]
76 | point_features = scene['point_features']
77 | else:
78 | if not self.uniform_sampling:
79 | # Number of blocks along each axis
80 | num_blocks = 2
81 | total_blocks = num_blocks ** 3
82 | self.common_difference = 200
83 | # Calculate block sizes
84 | block_sizes = (all_xyz.max(axis=0) - all_xyz.min(axis=0)) / num_blocks
85 |
86 | # Create the number_per_block array with an arithmetic sequence
87 | average_points_per_block = num_input_points // total_blocks
88 | number_per_block = np.array([
89 | average_points_per_block + (i - total_blocks // 2) * self.common_difference
90 | for i in range(total_blocks)
91 | ])
92 |
93 | # Adjust number_per_block to ensure the sum is num_input_points
94 | total_points = np.sum(number_per_block)
95 | difference = num_input_points - total_points
96 | number_per_block[-1] += difference
97 |
98 | # Sample points from each block
99 | sample_indices = []
100 | block_index = 0
101 | total_chosen_indices = 0
102 | remaining_points = 0 # Points to be added to the next block
103 | for i in range(num_blocks):
104 | for j in range(num_blocks):
105 | for k in range(num_blocks):
106 | block_min = all_xyz.min(axis=0) + block_sizes * np.array([i, j, k])
107 | block_max = block_min + block_sizes
108 | block_mask = np.all((all_xyz >= block_min) & (all_xyz < block_max), axis=1)
109 | block_indices = np.where(block_mask)[0]
110 | num_samples = number_per_block[block_index] + remaining_points
111 | remaining_points = 0 # Reset remaining points
112 | block_index += 1
113 | if len(block_indices) > 0:
114 | chosen_indices = np.random.choice(block_indices, num_samples, replace=True)
115 | sample_indices.extend(chosen_indices)
116 | total_chosen_indices += len(chosen_indices)
117 | # print(f"Block {block_index} - Desired: {num_samples}, Actual: {len(chosen_indices)}")
118 | if len(chosen_indices) < num_samples:
119 | remaining_points += (num_samples - len(chosen_indices))
120 | else:
121 | # print(f"Block {block_index} - No points available. Adding {num_samples} points to the next block.")
122 | remaining_points += num_samples
123 |
124 | else:
125 | if num_points < num_input_points:
126 |                     print(f"Scene {scene_name} has fewer than {num_input_points} points. Sampling with replacement.")
127 | sample_indices = np.random.choice(num_points, num_input_points, replace=True)
128 | else:
129 |                     sample_indices = np.random.choice(num_points, num_input_points, replace=False)  # enough points: sample without replacement
130 |
131 | xyz = scene["xyz"][sample_indices]
132 | point_features = scene['point_features'][sample_indices]
133 | noise = np.random.normal(0, self.std_dev, xyz.shape)
134 | xyz += noise
135 |
136 | data = {
137 | "all_xyz": all_xyz,
138 | "all_normals": all_normals,
139 | "xyz": xyz, # N, 3
140 | "point_features": point_features, # N, 3
141 | "scene_name": scene['scene_name']
142 | }
143 |
144 | return data
145 |
146 |
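
The non-uniform branch above spreads num_input_points over the 2 x 2 x 2 = 8 blocks as an arithmetic sequence with common difference 200, then patches the last block so the total matches. A worked example with illustrative numbers:

# num_input_points = 10000, total_blocks = 8, common_difference = 200
average = 10000 // 8                                    # 1250
number_per_block = [average + (i - 4) * 200 for i in range(8)]
# -> [450, 650, 850, 1050, 1250, 1450, 1650, 1850], sum = 9200
number_per_block[-1] += 10000 - sum(number_per_block)   # last block becomes 2650
assert sum(number_per_block) == 10000
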
--------------------------------------------------------------------------------
/noksr/data/dataset/scenenn.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 | from plyfile import PlyData
6 |
7 | class SceneNN(Dataset):
8 | def __init__(self, cfg, split):
9 | self.cfg = cfg
10 | self.split = split
11 | self.dataset_root_path = cfg.data.dataset_path
12 | self.voxel_size = cfg.data.voxel_size
13 | self.num_input_points = cfg.data.num_input_points
14 | self.std_dev = cfg.data.std_dev
15 | self.intake_start = cfg.data.intake_start
16 | self.take = cfg.data.take
17 | self.input_splats = cfg.data.input_splats
18 |
19 | self.in_memory = cfg.data.in_memory
20 | self.train_files = cfg.data.train_files
21 | self.test_files = cfg.data.test_files
22 | if self.split == 'test':
23 | scene_ids = self.test_files
24 | else:
25 | scene_ids = self.train_files + self.test_files
26 | self.filenames = sorted([os.path.join(self.dataset_root_path, scene_id, f)
27 | for scene_id in scene_ids
28 | for f in os.listdir(os.path.join(self.dataset_root_path, scene_id))
29 | if f.endswith('.ply')])
30 | if self.cfg.data.over_fitting:
31 | self.filenames = self.filenames[self.intake_start:self.take+self.intake_start]
32 |
33 | def __len__(self):
34 | return len(self.filenames)
35 |
36 | def __getitem__(self, idx):
37 | """Get item."""
38 | # load the mesh
39 | scene_filename = self.filenames[idx]
40 |
41 | ply_data = PlyData.read(scene_filename)
42 | vertex = ply_data['vertex']
43 | pos = np.stack([vertex[t] for t in ('x', 'y', 'z')], axis=1)
44 | nls = np.stack([vertex[t] for t in ('nx', 'ny', 'nz')], axis=1) if 'nx' in vertex and 'ny' in vertex and 'nz' in vertex else np.zeros_like(pos)
45 |
46 | all_xyz = pos
47 | all_normals = nls
48 | scene_name = os.path.basename(scene_filename).replace('.ply', '')
49 |
50 | all_point_features = np.zeros(shape=(len(all_xyz), 0), dtype=np.float32)
51 | if self.cfg.model.network.use_normal:
52 | all_point_features = np.concatenate((all_point_features, all_normals), axis=1)
53 | if self.cfg.model.network.use_xyz:
54 | all_point_features = np.concatenate((all_point_features, all_xyz), axis=1) # add xyz to point features
55 |
56 | # sample input points
57 | num_points = all_xyz.shape[0]
58 | if self.num_input_points == -1:
59 | xyz = all_xyz
60 | point_features = all_point_features
61 | normals = all_normals
62 | else:
63 | sample_indices = np.random.choice(num_points, self.num_input_points, replace=True)
64 | xyz = all_xyz[sample_indices]
65 | point_features = all_point_features[sample_indices]
66 | normals = all_normals[sample_indices]
67 |
68 | noise = np.random.normal(0, self.std_dev, xyz.shape)
69 | xyz += noise
70 |
71 | data = {
72 | "all_xyz": all_xyz,
73 | "all_normals": all_normals,
74 | "xyz": xyz, # N, 3
75 | "normals": normals, # N, 3
76 | "point_features": point_features, # N, 3
77 | "scene_name": scene_name
78 | }
79 |
80 | return data
81 |
--------------------------------------------------------------------------------
/noksr/data/dataset/synthetic.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | from statistics import mode
4 | import numpy as np
5 |
6 | from torch.utils.data import Dataset
7 | from plyfile import PlyData
8 |
9 |
10 | class Synthetic(Dataset):
11 | def __init__(self, cfg, split):
12 | # Attributes
13 | if 'Synthetic' in cfg.data: # subset of mixture data
14 | categories = cfg.data.Synthetic.classes
15 | self.dataset_folder = cfg.data.Synthetic.path
16 | self.multi_files = cfg.data.Synthetic.multi_files
17 | self.file_name = cfg.data.Synthetic.pointcloud_file
18 | else:
19 | categories = cfg.data.classes
20 | self.dataset_folder = cfg.data.path
21 | self.multi_files = cfg.data.multi_files
22 | self.file_name = cfg.data.pointcloud_file
23 |         self.scale = 2.2 # Empirical scale to map back to physical scale
24 | self.cfg = cfg
25 | self.split = split
26 | self.std_dev = cfg.data.std_dev * self.scale
27 | self.num_input_points = cfg.data.num_input_points
28 |
29 | self.no_except = True
30 |
31 | self.intake_start = cfg.data.intake_start
32 | self.take = cfg.data.take
33 | # If categories is None, use all subfolders
34 | if categories is None:
35 | categories = os.listdir(self.dataset_folder)
36 | categories = [c for c in categories
37 | if os.path.isdir(os.path.join(self.dataset_folder, c))]
38 |
39 | self.metadata = {
40 | c: {'id': c, 'name': 'n/a'} for c in categories
41 | }
42 |
43 | # Set index
44 | for c_idx, c in enumerate(categories):
45 | self.metadata[c]['idx'] = c_idx
46 |
47 | # Get all models
48 | self.models = []
49 | for c_idx, c in enumerate(categories):
50 | subpath = os.path.join(self.dataset_folder, c)
51 | if not os.path.isdir(subpath):
52 | print('Category %s does not exist in dataset.' % c)
53 |
54 | if self.split is None:
55 | self.models += [
56 | {'category': c, 'model': m} for m in [d for d in os.listdir(subpath) if (os.path.isdir(os.path.join(subpath, d)) and d != '') ]
57 | ]
58 |
59 | else:
60 | split_file = os.path.join(subpath, self.split + '.lst')
61 | with open(split_file, 'r') as f:
62 | models_c = f.read().split('\n')
63 |
64 | if '' in models_c:
65 | models_c.remove('')
66 |
67 | self.models += [
68 | {'category': c, 'model': m}
69 | for m in models_c
70 | ]
71 |
72 | # overfit in one data
73 | if self.cfg.data.over_fitting:
74 | self.models = self.models[self.intake_start:self.take+self.intake_start]
75 |
76 |
77 | def __len__(self):
78 | ''' Returns the length of the dataset.
79 | '''
80 | return len(self.models)
81 |
82 | def load(self, model_path, idx, vol):
83 | ''' Loads the data point.
84 |
85 | Args:
86 | model_path (str): path to model
87 | idx (int): ID of data point
88 | vol (dict): precomputed volume info
89 | '''
90 | if self.multi_files is None:
91 | file_path = os.path.join(model_path, self.file_name)
92 | else:
93 | num = np.random.randint(self.multi_files)
94 | file_path = os.path.join(model_path, self.file_name, '%s_%02d.npz' % (self.file_name, num))
95 |
96 | item_path = os.path.join(model_path, 'item_dict.npz')
97 | item_dict = np.load(item_path, allow_pickle=True)
98 | points_dict = np.load(file_path, allow_pickle=True)
99 | points = points_dict['points'] * self.scale # roughly transfer back to physical scale
100 | normals = points_dict['normals']
101 | semantics = points_dict['semantics']
102 | # Break symmetry if given in float16:
103 | if points.dtype == np.float16:
104 | points = points.astype(np.float32)
105 | normals = normals.astype(np.float32)
106 | points += 1e-4 * np.random.randn(*points.shape)
107 | normals += 1e-4 * np.random.randn(*normals.shape)
108 |
109 | min_values = np.min(points, axis=0)
110 | points -= min_values
111 |
112 | return points, normals, semantics
113 |
114 | def __getitem__(self, idx):
115 | ''' Returns an item of the dataset.
116 |
117 | Args:
118 | idx (int): ID of data point
119 | '''
120 | category = self.models[idx]['category']
121 | model = self.models[idx]['model']
122 | c_idx = self.metadata[category]['idx']
123 |
124 | model_path = os.path.join(self.dataset_folder, category, model)
125 |
126 | all_xyz, all_normals, all_semantics = self.load(model_path, idx, c_idx)
127 | scene_name = f"{category}/{model}/{idx}"
128 |
129 | all_point_features = np.zeros(shape=(len(all_xyz), 0), dtype=np.float32)
130 | if self.cfg.model.network.use_normal:
131 | all_point_features = np.concatenate((all_point_features, all_normals), axis=1)
132 | if self.cfg.model.network.use_xyz:
133 | all_point_features = np.concatenate((all_point_features, all_xyz), axis=1) # add xyz to point features
134 | # sample input points
135 | num_points = all_xyz.shape[0]
136 | if self.num_input_points == -1:
137 | xyz = all_xyz
138 | point_features = all_point_features
139 | normals = all_normals
140 | else:
141 | sample_indices = np.random.choice(num_points, self.num_input_points, replace=True)
142 | xyz = all_xyz[sample_indices]
143 | point_features = all_point_features[sample_indices]
144 | normals = all_normals[sample_indices]
145 |
146 | noise = np.random.normal(0, self.std_dev, xyz.shape)
147 | xyz += noise
148 |
149 | data = {
150 | "all_xyz": all_xyz,
151 | "all_normals": all_normals,
152 | "xyz": xyz, # N, 3
153 | "normals": normals, # N, 3
154 | "point_features": point_features, # N, 3
155 | "scene_name": scene_name
156 | }
157 | return data
158 |
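
load() above expects each model directory to hold an item_dict.npz plus pointcloud archives containing 'points', 'normals' and 'semantics' arrays (named '<file_name>/<file_name>_XX.npz' when multi_files is set). A hypothetical snippet for writing one compatible archive with placeholder data:

import numpy as np

points = np.random.rand(50000, 3).astype(np.float16)    # float16 triggers the symmetry-breaking noise in load()
normals = np.random.randn(50000, 3)
normals = (normals / np.linalg.norm(normals, axis=1, keepdims=True)).astype(np.float16)
semantics = np.zeros(50000, dtype=np.int64)
np.savez("pointcloud_00.npz", points=points, normals=normals, semantics=semantics)
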
--------------------------------------------------------------------------------
/noksr/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .noksr_net import noksr
2 |
3 |
4 |
--------------------------------------------------------------------------------
/noksr/model/general_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import open3d as o3d
4 | import numpy as np
5 | import pytorch_lightning as pl
6 | import pl_bolts
7 | from collections import OrderedDict
8 | from typing import Mapping, Any, Optional
9 | from noksr.utils.optimizer import cosine_lr_decay
10 | # from noksr.model.module import Generator
11 | from torch.nn import functional as F
12 | from pycg import exp, image
13 |
14 |
15 | class GeneralModel(pl.LightningModule):
16 | def __init__(self, cfg):
17 | super().__init__()
18 | self.save_hyperparameters()
19 | self.val_test_step_outputs = []
20 | # For recording test information
21 | # step -> log_name -> log_value (list of ordered-dict)
22 | self.test_logged_values = []
23 | self.record_folder = None
24 | self.record_headers = []
25 | self.record_data_cache = {}
26 | self.last_test_valid = False
27 |
28 | def configure_optimizers(self):
29 | params_to_optimize = self.parameters()
30 |
31 | if self.hparams.model.optimizer.name == "SGD":
32 | optimizer = torch.optim.SGD(
33 | params_to_optimize,
34 | lr=self.hparams.model.optimizer.lr,
35 | momentum=0.9,
36 | weight_decay=1e-4,
37 | )
38 | scheduler = pl_bolts.optimizers.LinearWarmupCosineAnnealingLR(
39 | optimizer,
40 | warmup_epochs=int(self.hparams.model.optimizer.warmup_steps_ratio * self.hparams.model.trainer.max_steps),
41 | max_epochs=self.hparams.model.trainer.max_steps,
42 | eta_min=0,
43 | )
44 | return {
45 | "optimizer": optimizer,
46 | "lr_scheduler": {
47 | "scheduler": scheduler,
48 | "interval": "step"
49 | }
50 | }
51 |
52 | elif self.hparams.model.optimizer.name == 'Adam':
53 | optimizer = torch.optim.Adam(
54 | params_to_optimize,
55 | lr=self.hparams.model.optimizer.lr,
56 | betas=(0.9, 0.999),
57 | weight_decay=1e-4,
58 | )
59 | return optimizer
60 |
61 | else:
62 |             exp.logger.error('Optimizer type not supported')
63 |
64 | def training_step(self, data_dict):
65 | pass
66 |
67 | def on_train_epoch_end(self):
68 | if self.hparams.model.optimizer.name == 'Adam':
69 | # Update the learning rates for Adam optimizers
70 | cosine_lr_decay(
71 | self.trainer.optimizers[0], self.hparams.model.optimizer.lr, self.current_epoch,
72 | self.hparams.model.lr_decay.decay_start_epoch, self.hparams.model.trainer.max_epochs, 1e-6
73 | )
74 |
75 | def validation_step(self, data_dict, idx):
76 | pass
77 |
78 | def validation_epoch_end(self, outputs):
79 | metrics_to_log = ['chamfer-L1', 'f-score', 'f-score-20']
80 | if outputs:
81 | avg_metrics = {metric: np.mean([x[metric] for x in outputs]) for metric in metrics_to_log if metric in outputs[0]}
82 | for key, value in avg_metrics.items():
83 | self.log(f"val_reconstruction/{key}", value, logger=True)
84 |
85 |
86 |
87 | def test_step(self, data_dict, idx):
88 | pass
89 |
90 |
91 | def log_dict_prefix(
92 | self,
93 | prefix: str,
94 | dictionary: Mapping[str, Any],
95 | prog_bar: bool = False,
96 | logger: bool = True,
97 | on_step: Optional[bool] = None,
98 | on_epoch: Optional[bool] = None
99 | ):
100 | """
101 |         This override fixes the case where a dict key is not a string.
102 | """
103 | dictionary = {
104 | prefix + "/" + str(k): v for k, v in dictionary.items()
105 | }
106 | self.log_dict(dictionary=dictionary,
107 | prog_bar=prog_bar,
108 | logger=logger, on_step=on_step, on_epoch=on_epoch)
109 |
110 | def log_image(self, name: str, img: np.ndarray):
111 | if self.trainer.logger is not None:
112 | self.trainer.logger.log_image(key=name, images=[img])
113 |
114 |
115 | def log_geometry(self, name: str, geom, draw_color: bool = False):
116 | if self.trainer.logger is None:
117 | return
118 | if isinstance(geom, o3d.geometry.TriangleMesh):
119 | try:
120 | from pycg import render
121 | mv_img = render.multiview_image(
122 | [geom], viewport_shading='LIT' if draw_color else 'NORMAL', backend='filament')
123 | # mv_img = render.multiview_image(
124 | # [geom], viewport_shading='LIT' if draw_color else 'NORMAL', backend='opengl')
125 | self.log_image("mesh" + name, mv_img)
126 | except Exception:
127 | exp.logger.warning("Not able to render mesh during training.")
128 | else:
129 | raise NotImplementedError
130 |
131 | def test_log_data(self, data_dict: dict):
132 | self.record_data_cache.update(data_dict)
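
configure_optimizers and on_train_epoch_end above only read a handful of config fields; a minimal sketch of that subset (placeholder values only, the actual defaults live under config/):

from omegaconf import OmegaConf

cfg_sketch = OmegaConf.create({
    "model": {
        "optimizer": {"name": "Adam", "lr": 1e-3, "warmup_steps_ratio": 0.02},
        "lr_decay": {"decay_start_epoch": 200},
        "trainer": {"max_steps": 100000, "max_epochs": 400},
    }
})
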
--------------------------------------------------------------------------------
/noksr/model/module/__init__.py:
--------------------------------------------------------------------------------
1 | from .decoder import Decoder
2 | from .generation import Generator
3 | from .point_transformer import PointTransformerV3
4 |
--------------------------------------------------------------------------------
/noksr/model/module/generation.py:
--------------------------------------------------------------------------------
1 | import time
2 | 
3 |
4 | import torch
5 | from pycg import vis
6 | from torch import Tensor, nn
7 | from torch.nn import functional as F
8 | from tqdm import tqdm
9 | from typing import Callable, Tuple
10 | from sklearn.neighbors import NearestNeighbors
11 | from pytorch3d.ops import knn_points
12 | import open3d as o3d
13 | import pytorch_lightning as pl
14 | from noksr.utils.samples import BatchedSampler
15 | from noksr.utils.serialization import encode
16 |
17 |
18 | class MeshingResult:
19 | def __init__(self, v: torch.Tensor = None, f: torch.Tensor = None, c: torch.Tensor = None):
20 | self.v = v
21 | self.f = f
22 | self.c = c
23 |
24 | class Generator(pl.LightningModule):
25 | def __init__(self, model, mask_model, voxel_size, k_neighbors, last_n_layers, reconstruction_cfg):
26 | super().__init__()
27 | self.model = model # the model should be the UDF Decoder
28 | self.mask_model = mask_model # the distance mask decoder
29 | self.rec_cfg = reconstruction_cfg
30 | self.voxel_size = voxel_size
31 | self.threshold = 0.4
32 | self.k_neighbors = k_neighbors
33 | self.last_n_layers = last_n_layers
34 |
35 |
36 | def compute_gt_sdf_from_pts(self, gt_xyz, gt_normals, query_pos: torch.Tensor):
37 | k = 8
38 | stdv = 0.02
39 | knn_output = knn_points(query_pos.unsqueeze(0).to(torch.device("cuda")), gt_xyz.unsqueeze(0).to(torch.device("cuda")), K=k)
40 | indices = knn_output.idx.squeeze(0)
41 |         indices = indices.to(query_pos.device)
42 | closest_points = gt_xyz[indices]
43 | surface_to_queries_vec = query_pos.unsqueeze(1) - closest_points #N, K, 3
44 |
45 | dot_products = torch.einsum("ijk,ijk->ij", surface_to_queries_vec, gt_normals[indices]) #N, K
46 | vec_lengths = torch.norm(surface_to_queries_vec[:, 0, :], dim=-1)
47 | use_dot_product = vec_lengths < stdv
48 | sdf = torch.where(use_dot_product, torch.abs(dot_products[:, 0]), vec_lengths)
49 |
50 | # Adjust the sign of the sdf values based on the majority of dot products
51 | num_pos = torch.sum(dot_products > 0, dim=1)
52 | inside = num_pos <= (k / 2)
53 | sdf[inside] *= -1
54 |
55 | return -sdf
56 |
57 | def generate_dual_mc_mesh(self, data_dict, encoder_outputs, device):
58 | from nksr.svh import SparseFeatureHierarchy, SparseIndexGrid
59 | from nksr.ext import meshing
60 | from nksr.meshing import MarchingCubes
61 | from nksr import utils
62 |
63 | max_depth = 100
64 | grid_upsample = 1
65 | mise_iter = 0
66 | knn_time = 0
67 | dmc_time = 0
68 | aggregation_time = 0
69 | decoder_time = 0
70 | mask_threshold = self.rec_cfg.mask_threshold
71 |
72 | pts = data_dict['xyz'].detach()
73 | self.last_n_layers = 4
74 | self.trim = self.rec_cfg.trim
75 | self.gt_mask = self.rec_cfg.gt_mask
76 | self.gt_sdf = self.rec_cfg.gt_sdf
77 | # Generate DMC grid structure
78 | nksr_svh = SparseFeatureHierarchy(
79 | voxel_size=self.voxel_size,
80 | depth=self.last_n_layers,
81 | device= pts.device
82 | )
83 | nksr_svh.build_point_splatting(pts)
84 |
85 | flattened_grids = []
86 | for d in range(min(nksr_svh.depth, max_depth + 1)):
87 | f_grid = meshing.build_flattened_grid(
88 | nksr_svh.grids[d]._grid,
89 | nksr_svh.grids[d - 1]._grid if d > 0 else None,
90 | d != nksr_svh.depth - 1
91 | )
92 | if grid_upsample > 1:
93 | f_grid = f_grid.subdivided_grid(grid_upsample)
94 | flattened_grids.append(f_grid)
95 |
96 | dual_grid = meshing.build_joint_dual_grid(flattened_grids)
97 | dmc_graph = meshing.dual_cube_graph(flattened_grids, dual_grid)
98 | dmc_vertices = torch.cat([
99 | f_grid.grid_to_world(f_grid.active_grid_coords().float())
100 | for f_grid in flattened_grids if f_grid.num_voxels() > 0
101 | ], dim=0)
102 | del flattened_grids, dual_grid
103 | """ create a mask to trim spurious geometry """
104 |
105 | decoder_time -= time.time()
106 | dmc_value, sdf_knn_time, sdf_aggregation_time = self.model(encoder_outputs, dmc_vertices)
107 | decoder_time += time.time()
108 | knn_time += sdf_knn_time
109 | aggregation_time += sdf_aggregation_time
110 | if self.gt_sdf:
111 | if 'gt_geometry' in data_dict:
112 | ref_xyz, ref_normal, _ = data_dict['gt_geometry'][0].torch_attr()
113 | else:
114 | ref_xyz, ref_normal = data_dict['all_xyz'], data_dict['all_normals']
115 | dmc_value = self.compute_gt_sdf_from_pts(ref_xyz, ref_normal, dmc_vertices)
116 |
117 | for _ in range(mise_iter):
118 | cube_sign = dmc_value[dmc_graph] > 0
119 | cube_mask = ~torch.logical_or(torch.all(cube_sign, dim=1), torch.all(~cube_sign, dim=1))
120 | dmc_graph = dmc_graph[cube_mask]
121 | unq, dmc_graph = torch.unique(dmc_graph.view(-1), return_inverse=True)
122 | dmc_graph = dmc_graph.view(-1, 8)
123 | dmc_vertices = dmc_vertices[unq]
124 | dmc_graph, dmc_vertices = utils.subdivide_cube_indices(dmc_graph, dmc_vertices)
125 |             dmc_value = torch.clamp(self.model(encoder_outputs, dmc_vertices.to(device))[0], max=self.threshold)
126 |
127 | dmc_time -= time.time()
128 | dual_v, dual_f = MarchingCubes().apply(dmc_graph, dmc_vertices, dmc_value)
129 | dmc_time += time.time()
130 |
131 | vert_mask = None
132 | if self.trim:
133 | if self.gt_mask:
134 | nn = NearestNeighbors(n_neighbors=1)
135 | nn.fit(data_dict['all_xyz'].cpu().numpy()) # coords is an (N, 3) array
136 | dist, indx = nn.kneighbors(dual_v.detach().cpu().numpy()) # xyz is an (M, 3) array
137 | dist = torch.from_numpy(dist).to(dual_v.device).squeeze(-1)
138 | vert_mask = dist < mask_threshold
139 | else:
140 | decoder_time -= time.time()
141 | dist, mask_knn_time, mask_aggregation_time = self.mask_model(encoder_outputs, dual_v.to(device))
142 | decoder_time += time.time()
143 | vert_mask = dist < mask_threshold
144 | knn_time += mask_knn_time
145 | aggregation_time += mask_aggregation_time
146 | dmc_time -= time.time()
147 | dual_v, dual_f = utils.apply_vertex_mask(dual_v, dual_f, vert_mask)
148 | dmc_time += time.time()
149 |
150 | dmc_time -= time.time()
151 | mesh_res = MeshingResult(dual_v, dual_f, None)
152 | # del dual_v, dual_f
153 | mesh = vis.mesh(mesh_res.v, mesh_res.f)
154 | dmc_time += time.time()
155 |
156 | time_dict = {
157 | 'neighboring_time': knn_time,
158 | 'dmc_time': dmc_time,
159 | 'aggregation_time': aggregation_time,
160 | 'decoder_time': decoder_time,
161 | }
162 | return mesh, time_dict
163 |
164 | def generate_dual_mc_mesh_by_segment(self, data_dict, encoder_outputs, encoding_codes, depth, device):
165 | """
166 | This function generates a dual marching cube mesh by computing the sdf values for each segment individually.
167 | """
168 | from nksr.svh import SparseFeatureHierarchy, SparseIndexGrid
169 | from nksr.ext import meshing
170 | from nksr.meshing import MarchingCubes
171 | from nksr import utils
172 |
173 | max_depth = 100
174 | grid_upsample = 1
175 | mise_iter = 0
176 | knn_time = 0
177 | dmc_time = 0
178 | aggregation_time = 0
179 | decoder_time = 0
180 | mask_threshold = self.rec_cfg.mask_threshold
181 |
182 | pts = data_dict['xyz'].detach()
183 | self.last_n_layers = 4
184 | self.trim = self.rec_cfg.trim
185 | self.gt_mask = self.rec_cfg.gt_mask
186 | self.gt_sdf = self.rec_cfg.gt_sdf
187 |
188 | # Generate DMC grid structure
189 | nksr_svh = SparseFeatureHierarchy(
190 | voxel_size=self.voxel_size,
191 | depth=self.last_n_layers,
192 | device=pts.device
193 | )
194 | nksr_svh.build_point_splatting(pts)
195 |
196 | flattened_grids = []
197 | for d in range(min(nksr_svh.depth, max_depth + 1)):
198 | f_grid = meshing.build_flattened_grid(
199 | nksr_svh.grids[d]._grid,
200 | nksr_svh.grids[d - 1]._grid if d > 0 else None,
201 | d != nksr_svh.depth - 1
202 | )
203 | if grid_upsample > 1:
204 | f_grid = f_grid.subdivided_grid(grid_upsample)
205 | flattened_grids.append(f_grid)
206 |
207 | dual_grid = meshing.build_joint_dual_grid(flattened_grids)
208 | dmc_graph = meshing.dual_cube_graph(flattened_grids, dual_grid)
209 | dmc_vertices = torch.cat([
210 | f_grid.grid_to_world(f_grid.active_grid_coords().float())
211 | for f_grid in flattened_grids if f_grid.num_voxels() > 0
212 | ], dim=0)
213 | del flattened_grids, dual_grid
214 |
215 | # Encode and segment `dmc_vertices`
216 | in_quant_coords = torch.floor(dmc_vertices / 0.01).to(torch.int)
217 | dmc_quant_codes = encode(
218 | in_quant_coords,
219 | torch.zeros(in_quant_coords.shape[0], dtype=torch.int64, device=in_quant_coords.device),
220 | depth,
221 | order='z'
222 | )
223 | sorted_codes, sorted_indices = torch.sort(dmc_quant_codes)
224 |
225 | dmc_value_list = []
226 | for idx in range(len(encoding_codes)):
227 | if idx == 0:
228 | segment_mask = (sorted_codes < encoding_codes[idx+1])
229 | elif idx == len(encoding_codes) - 1:
230 | segment_mask = (sorted_codes >= encoding_codes[idx])
231 | else:
232 | segment_mask = (sorted_codes >= encoding_codes[idx]) & (sorted_codes < encoding_codes[idx+1])
233 | segment_indices = sorted_indices[segment_mask]
234 | segment_vertices = dmc_vertices[segment_indices]
235 | segment_encoder_output = encoder_outputs[idx]
236 | segment_dmc_value, sdf_knn_time, sdf_aggregation_time = self.model(segment_encoder_output, segment_vertices)
237 | dmc_value_list.append(segment_dmc_value)
238 |
239 | knn_time += sdf_knn_time
240 | aggregation_time += sdf_aggregation_time
241 |
242 | dmc_values = torch.zeros_like(sorted_codes, dtype=torch.float32, device=device)
243 | dmc_values[sorted_indices] = torch.cat(dmc_value_list)
244 |
245 | dmc_time -= time.time()
246 | dual_v, dual_f = MarchingCubes().apply(dmc_graph, dmc_vertices, dmc_values)
247 | dmc_time += time.time()
248 |
249 | """ create a mask to trim spurious geometry """
250 | in_quant_coords = torch.floor(dual_v / 0.01).to(torch.int)
251 | dual_quant_codes = encode(
252 | in_quant_coords,
253 | torch.zeros(in_quant_coords.shape[0], dtype=torch.int64, device=in_quant_coords.device),
254 | depth,
255 | order='z'
256 | )
257 | sorted_dual_codes, sorted_dual_indices = torch.sort(dual_quant_codes)
258 |
259 | dist_list = []
260 | for idx in range(len(encoding_codes)):
261 | if idx == 0:
262 | segment_mask = (sorted_dual_codes < encoding_codes[idx+1])
263 | elif idx == len(encoding_codes) - 1:
264 | segment_mask = (sorted_dual_codes >= encoding_codes[idx])
265 | else:
266 | segment_mask = (sorted_dual_codes >= encoding_codes[idx]) & (sorted_dual_codes < encoding_codes[idx+1])
267 | segment_indices = sorted_dual_indices[segment_mask]
268 | segment_dual_v = dual_v[segment_indices]
269 |
270 | if self.gt_mask:
271 | nn = NearestNeighbors(n_neighbors=1)
272 | nn.fit(data_dict['all_xyz'].cpu().numpy()) # Reference points (N, 3)
273 | segment_dist, _ = nn.kneighbors(segment_dual_v.detach().cpu().numpy()) # Query points (M, 3)
274 | segment_dist = torch.from_numpy(segment_dist).to(dual_v.device).squeeze(-1)
275 | else:
276 | decoder_time -= time.time()
277 | segment_dist, mask_knn_time, mask_aggregation_time = self.mask_model(encoder_outputs[idx], segment_dual_v.to(device))
278 | decoder_time += time.time()
279 |
280 | knn_time += mask_knn_time
281 | aggregation_time += mask_aggregation_time
282 |
283 | dist_list.append(segment_dist)
284 |
285 | dist = torch.zeros_like(sorted_dual_codes, dtype=torch.float32, device=device)
286 | dist[sorted_dual_indices] = torch.cat(dist_list)
287 |
288 | vert_mask = dist < mask_threshold
289 | dual_v, dual_f = utils.apply_vertex_mask(dual_v, dual_f, vert_mask)
290 |
291 | dmc_time -= time.time()
292 | mesh_res = MeshingResult(dual_v, dual_f, None)
293 | mesh = vis.mesh(mesh_res.v, mesh_res.f)
294 | dmc_time += time.time()
295 |
296 | time_dict = {
297 | 'neighboring_time': knn_time,
298 | 'dmc_time': dmc_time,
299 | 'aggregation_time': aggregation_time,
300 | 'decoder_time': decoder_time,
301 | }
302 |
303 | return mesh, time_dict
304 |
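
generate_dual_mc_mesh_by_segment above buckets the sorted serialization codes by the boundaries stored in encoding_codes before querying each segment's decoder. A toy illustration of that bucketing with made-up codes:

import torch

sorted_codes = torch.tensor([3, 7, 12, 18, 25, 31])
encoding_codes = [0, 10, 20]          # segment start codes (illustrative values)
for idx in range(len(encoding_codes)):
    if idx == 0:
        mask = sorted_codes < encoding_codes[idx + 1]
    elif idx == len(encoding_codes) - 1:
        mask = sorted_codes >= encoding_codes[idx]
    else:
        mask = (sorted_codes >= encoding_codes[idx]) & (sorted_codes < encoding_codes[idx + 1])
    print(idx, sorted_codes[mask].tolist())   # 0: [3, 7]   1: [12, 18]   2: [25, 31]
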
--------------------------------------------------------------------------------
/noksr/model/noksr_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import pytorch_lightning as pl
4 | import importlib
5 | from noksr.model.module import PointTransformerV3
6 | from noksr.utils.samples import BatchedSampler
7 | from noksr.model.general_model import GeneralModel
8 | from torch.nn import functional as F
9 |
10 | class noksr(GeneralModel):
11 | def __init__(self, cfg):
12 | super().__init__(cfg)
13 | self.save_hyperparameters(cfg)
14 |
15 | self.latent_dim = cfg.model.network.latent_dim
16 | self.decoder_type = cfg.model.network.sdf_decoder.decoder_type
17 | self.eikonal = cfg.data.supervision.eikonal.loss
18 | self.laplacian = cfg.data.supervision.laplacian.loss
19 | self.surface_normal_supervision = cfg.data.supervision.on_surface.normal_loss
20 | self.flip_eikonal = cfg.data.supervision.eikonal.flip
21 | self.backbone = cfg.model.network.backbone
22 | if self.backbone == "PointTransformerV3":
23 | self.point_transformer = PointTransformerV3(
24 | backbone_cfg=cfg.model.network.point_transformerv3
25 | )
26 |
27 | module = importlib.import_module('noksr.model.module')
28 | decoder_class = getattr(module, self.decoder_type)
29 | self.sdf_decoder = decoder_class(
30 | decoder_cfg=cfg.model.network.sdf_decoder,
31 | supervision = 'SDF',
32 | latent_dim=cfg.model.network.latent_dim,
33 | feature_dim=cfg.model.network.sdf_decoder.feature_dim,
34 | hidden_dim=cfg.model.network.sdf_decoder.hidden_dim,
35 | out_dim=1,
36 | voxel_size=cfg.data.voxel_size,
37 | activation=cfg.model.network.sdf_decoder.activation
38 | )
39 |
40 | # if self.hparams.model.network.mask_decoder.distance_mask:
41 | if cfg.data.supervision.udf.weight > 0:
42 | self.mask_decoder = decoder_class(
43 | decoder_cfg=cfg.model.network.mask_decoder,
44 | supervision = 'Distance',
45 | latent_dim=cfg.model.network.latent_dim,
46 | feature_dim=cfg.model.network.mask_decoder.feature_dim,
47 | hidden_dim=cfg.model.network.mask_decoder.hidden_dim,
48 | out_dim=1,
49 | voxel_size=cfg.data.voxel_size,
50 | activation=cfg.model.network.mask_decoder.activation
51 | )
52 |
53 | self.batched_sampler = BatchedSampler(cfg) # Instantiate the batched sampler with the configurations
54 |
55 | def forward(self, data_dict):
56 | outputs = {}
57 | """ Get query samples and ground truth values """
58 | query_xyz, query_gt_sdf = self.batched_sampler.batch_sdf_sample(data_dict)
59 | on_surface_xyz, gt_on_surface_normal = self.batched_sampler.batch_on_surface_sample(data_dict)
60 | if self.hparams.data.supervision.udf.weight > 0:
61 | mask_query_xyz, mask_query_gt_udf = self.batched_sampler.batch_udf_sample(data_dict)
62 | outputs['gt_distances'] = mask_query_gt_udf
63 |
64 | outputs['gt_values'] = query_gt_sdf
65 | outputs['gt_on_surface_normal'] = gt_on_surface_normal
66 |
67 | if self.backbone == "PointTransformerV3":
68 | pt_data = {}
69 | pt_data['feat'] = data_dict['point_features']
70 | pt_data['offset'] = torch.cumsum(data_dict['xyz_splits'], dim=0)
71 | pt_data['grid_size'] = 0.01
72 | pt_data['coord'] = data_dict['xyz']
73 | encoder_output = self.point_transformer(pt_data)
74 |
75 | outputs['values'], *_ = self.sdf_decoder(encoder_output, query_xyz)
76 | outputs['surface_values'], *_ = self.sdf_decoder(encoder_output, on_surface_xyz)
77 |
78 | if self.eikonal:
79 | if self.hparams.model.network.grad_type == "Numerical":
80 | interval = 0.01 * self.hparams.data.voxel_size
81 | grad_value = []
82 | for offset in [(interval, 0, 0), (0, interval, 0), (0, 0, interval)]:
83 | offset_tensor = torch.tensor(offset, device=self.device)[None, :]
84 | res_p, *_ = self.sdf_decoder(encoder_output, query_xyz + offset_tensor)
85 | res_n, *_ = self.sdf_decoder(encoder_output, query_xyz - offset_tensor)
86 | grad_value.append((res_p - res_n) / (2 * interval))
87 | outputs['pd_grad'] = torch.stack(grad_value, dim=1)
88 | else:
89 | xyz = torch.clone(query_xyz)
90 | xyz.requires_grad = True
91 | with torch.enable_grad():
92 | res, *_ = self.sdf_decoder(encoder_output, xyz)
93 | outputs['pd_grad'] = torch.autograd.grad(res, [xyz],
94 | grad_outputs=torch.ones_like(res),
95 | create_graph=self.sdf_decoder.training, allow_unused=True)[0]
96 |
97 | if self.laplacian:
98 | interval = 1.0 * self.hparams.data.voxel_size
99 | laplacian_value = 0
100 |
101 | for offset in [(interval, 0, 0), (0, interval, 0), (0, 0, interval)]:
102 | offset_tensor = torch.tensor(offset, device=self.device)[None, :]
103 |
104 | # Calculate numerical gradient
105 | res, *_ = self.sdf_decoder(encoder_output, query_xyz)
106 | res_p, *_ = self.sdf_decoder(encoder_output, query_xyz + offset_tensor)
107 | res_pp, *_ = self.sdf_decoder(encoder_output, query_xyz + 2 * offset_tensor)
108 | laplacian_value += (res_pp - 2 * res_p + res) / (interval ** 2)
109 | outputs['pd_laplacian'] = laplacian_value
110 |
111 | if self.surface_normal_supervision:
112 | if self.hparams.model.network.grad_type == "Numerical":
113 | interval = 0.01 * self.hparams.data.voxel_size
114 | grad_value = []
115 | for offset in [(interval, 0, 0), (0, interval, 0), (0, 0, interval)]:
116 | offset_tensor = torch.tensor(offset, device=self.device)[None, :]
117 | res_p, *_ = self.sdf_decoder(encoder_output, on_surface_xyz + offset_tensor)
118 | res_n, *_ = self.sdf_decoder(encoder_output, on_surface_xyz - offset_tensor)
119 | grad_value.append((res_p - res_n) / (2 * interval))
120 | outputs['pd_surface_grad'] = torch.stack(grad_value, dim=1)
121 | else:
122 | xyz = torch.clone(on_surface_xyz)
123 | xyz.requires_grad = True
124 | with torch.enable_grad():
125 | res, *_ = self.sdf_decoder(encoder_output, xyz)
126 | outputs['pd_surface_grad'] = torch.autograd.grad(res, [xyz],
127 | grad_outputs=torch.ones_like(res),
128 | create_graph=self.sdf_decoder.training, allow_unused=True)[0]
129 |
130 | if self.hparams.data.supervision.udf.weight > 0:
131 | outputs['distances'], *_ = self.mask_decoder(encoder_output, mask_query_xyz)
132 |
133 | return outputs, encoder_output
134 |
135 | def loss(self, data_dict, outputs, encoder_output):
136 | l1_loss = torch.nn.L1Loss(reduction='mean')(torch.clamp(outputs['values'], min = -self.hparams.data.supervision.sdf.max_dist, max=self.hparams.data.supervision.sdf.max_dist), torch.clamp(outputs['gt_values'], min = -self.hparams.data.supervision.sdf.max_dist, max=self.hparams.data.supervision.sdf.max_dist))
137 | on_surface_loss = torch.abs(outputs['surface_values']).mean()
138 |
139 | mask_loss = torch.tensor(0.0, device=self.device)
140 | eikonal_loss = torch.tensor(0.0, device=self.device)
141 | normal_loss = torch.tensor(0.0, device=self.device)
142 | laplacian_loss = torch.tensor(0.0, device=self.device)
143 |
144 | # Create mask for points within max_dist
145 | valid_mask = (outputs['gt_values'] >= -self.hparams.data.supervision.sdf.max_dist/2) & (outputs['gt_values'] <= self.hparams.data.supervision.sdf.max_dist/2)
146 |
147 | # Eikonal Loss computation
148 | if self.eikonal:
149 | norms = torch.norm(outputs['pd_grad'], dim=1) # Compute the norm over the gradient vectors
150 | eikonal_loss = ((norms - 1) ** 2)[valid_mask].mean() # Masked eikonal loss
151 |
152 | if self.laplacian:
153 | laplacian_loss = torch.abs(outputs['pd_laplacian'])[valid_mask].mean() # Masked laplacian loss
154 |
155 | if self.surface_normal_supervision:
156 | normalized_pd_surface_grad = -outputs['pd_surface_grad'] / (torch.linalg.norm(outputs['pd_surface_grad'], dim=-1, keepdim=True) + 1.0e-6)
157 | normal_loss = 1.0 - torch.sum(normalized_pd_surface_grad * outputs['gt_on_surface_normal'], dim=-1).mean()
158 |
159 | # if self.mask:
160 | if self.hparams.data.supervision.udf.weight > 0:
161 | mask_loss = torch.nn.L1Loss(reduction='mean')(torch.clamp(outputs['distances'], max=self.hparams.data.supervision.udf.max_dist), torch.clamp(outputs['gt_distances'], max=self.hparams.data.supervision.udf.max_dist))
162 |
163 | return l1_loss, on_surface_loss, mask_loss, eikonal_loss, normal_loss, laplacian_loss
164 |
165 | def training_step(self, data_dict):
166 | """ UDF auto-encoder training stage """
167 |
168 | batch_size = self.hparams.data.batch_size
169 | outputs, encoder_output = self.forward(data_dict)
170 |
171 | l1_loss, on_surface_loss, mask_loss, eikonal_loss, normal_loss, laplacian_loss = self.loss(data_dict, outputs, encoder_output)
172 | self.log("train/l1_loss", l1_loss.float(), on_step=True, on_epoch=True, sync_dist=True, batch_size=batch_size)
173 | self.log("train/on_surface_loss", on_surface_loss.float(), on_step=True, on_epoch=True, sync_dist=True, batch_size=batch_size)
174 | self.log("train/mask_loss", mask_loss.float(), on_step=True, on_epoch=True, sync_dist=True, batch_size=batch_size)
175 | self.log("train/eikonal_loss", eikonal_loss.float(), on_step=True, on_epoch=True, sync_dist=True, batch_size=batch_size)
176 | self.log("train/laplacian_loss", laplacian_loss.float(), on_step=True, on_epoch=True, sync_dist=True, batch_size=batch_size)
177 |
178 | total_loss = l1_loss*self.hparams.data.supervision.sdf.weight + on_surface_loss*self.hparams.data.supervision.on_surface.weight + eikonal_loss*self.hparams.data.supervision.eikonal.weight + mask_loss*self.hparams.data.supervision.udf.weight + normal_loss*self.hparams.data.supervision.on_surface.normal_weight + laplacian_loss*self.hparams.data.supervision.laplacian.weight
179 |
180 | return total_loss
181 |
182 | def validation_step(self, data_dict, idx):
183 | batch_size = 1
184 | outputs, encoder_output = self.forward(data_dict)
185 | l1_loss, on_surface_loss, mask_loss, eikonal_loss, normal_loss, laplacian_loss = self.loss(data_dict, outputs, encoder_output)
186 |
187 | self.log("val/l1_loss", l1_loss.float(), on_step=True, on_epoch=True, sync_dist=True, batch_size=batch_size, logger=True)
188 | self.log("val/on_surface_loss", on_surface_loss.float(), on_step=True, on_epoch=True, sync_dist=True, batch_size=batch_size, logger=True)
189 | self.log("val/mask_loss", mask_loss.float(), on_step=True, on_epoch=True, sync_dist=True, batch_size=batch_size, logger=True)
190 | self.log("val/eikonal_loss", eikonal_loss.float(), on_step=True, on_epoch=True, sync_dist=True, batch_size=batch_size, logger=True)
191 | self.log("val/laplacian_loss", laplacian_loss.float(), on_step=True, on_epoch=True, sync_dist=True, batch_size=batch_size, logger=True)
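
The "Numerical" branches in forward() above approximate the SDF gradient with a central difference and the Laplacian with a forward second difference along each axis. A standalone sketch of the same stencils for a generic scalar field f (hypothetical helper, not part of the repository):

import torch

def numerical_grad_and_laplacian(f, xyz, h):
    # f: callable mapping (N, 3) points to (N,) values; h: finite-difference step size.
    grads = []
    laplacian = torch.zeros(xyz.shape[0], device=xyz.device)
    for axis in range(3):
        offset = torch.zeros(1, 3, device=xyz.device)
        offset[0, axis] = h
        f_p, f_n = f(xyz + offset), f(xyz - offset)
        grads.append((f_p - f_n) / (2 * h))                             # central difference per axis
        laplacian += (f(xyz + 2 * offset) - 2 * f_p + f(xyz)) / h ** 2  # forward second difference
    return torch.stack(grads, dim=1), laplacian
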
--------------------------------------------------------------------------------
/noksr/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/noksr/utils/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from math import cos, pi
3 |
4 |
5 | """
6 | modified from
7 | https://github.com/thangvubk/SoftGroup/blob/a87940664d0af38a3f55cdac46f86fc83b029ecc/softgroup/util/utils.py#L55
8 | """
9 | def cosine_lr_decay(optimizer, base_lr, current_epoch, start_epoch, total_epochs, clip):
10 | if current_epoch < start_epoch:
11 | return
12 | for param_group in optimizer.param_groups:
13 | param_group['lr'] = clip + 0.5 * (base_lr - clip) * \
14 | (1 + cos(pi * ((current_epoch - start_epoch) / (total_epochs - start_epoch))))
15 |
16 | def adjust_learning_rate(
17 | initial_lr, optimizer, num_iterations, total_iterations, decreased_by
18 | ):
19 | adjust_lr_every = int(total_iterations / 2)
20 | lr = initial_lr * ((1 / decreased_by) ** (num_iterations // adjust_lr_every))
21 | for param_group in optimizer.param_groups:
22 | param_group["lr"] = lr
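
As a worked example of cosine_lr_decay with illustrative numbers (base_lr = 1e-3, clip = 1e-6, start_epoch = 100, total_epochs = 200): the function leaves the rate untouched before epoch 100, reaches roughly 5e-4 at the midpoint where cos(pi * 0.5) = 0, and decays to clip at epoch 200 where cos(pi) = -1:

from math import cos, pi

base_lr, clip, start, total = 1e-3, 1e-6, 100, 200   # illustrative values
for epoch in (100, 150, 200):
    lr = clip + 0.5 * (base_lr - clip) * (1 + cos(pi * (epoch - start) / (total - start)))
    print(epoch, lr)   # 100 -> 1e-3, 150 -> ~5e-4, 200 -> 1e-6
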
--------------------------------------------------------------------------------
/noksr/utils/samples.py:
--------------------------------------------------------------------------------
1 | # import ext
2 |
3 | import torch
4 | from nksr.svh import SparseFeatureHierarchy
5 | from pytorch3d.ops import knn_points
6 | from scipy.spatial import KDTree
7 |
8 |
9 | 
10 |
11 | class Sampler:
12 | def __init__(self, **kwargs):
13 | # Default values can be set with kwargs.get('key', default_value)
14 | self.voxel_size = kwargs.get('voxel_size')
15 | self.adaptive_policy = {
16 | 'method': 'normal',
17 | 'tau': 0.1,
18 | 'depth': 2
19 | }
20 | self.cfg = kwargs.get('cfg')
21 | self.ref_xyz = kwargs.get('ref_xyz')
22 | self.ref_normal = kwargs.get('ref_normal')
23 | self.svh = self._build_gt_svh()
24 | self.kdtree = KDTree(self.ref_xyz.detach().cpu().numpy())
25 |
26 | def _build_gt_svh(self):
27 | gt_svh = SparseFeatureHierarchy(
28 | voxel_size=self.voxel_size,
29 | depth=self.cfg.svh_tree_depth,
30 | device=self.ref_xyz.device
31 | )
32 | if self.adaptive_policy['method'] == "normal":
33 | gt_svh.build_adaptive_normal_variation(
34 | self.ref_xyz, self.ref_normal,
35 | tau=self.adaptive_policy['tau'],
36 | adaptive_depth=self.adaptive_policy['depth']
37 | )
38 | return gt_svh
39 |
40 | def _get_svh_samples(self, svh, n_samples, expand=0, expand_top=0):
41 | """
42 | Get random samples, across all layers of the decoder hierarchy
43 | :param svh: SparseFeatureHierarchy, hierarchy of spatial features
44 | :param n_samples: int, number of total samples
45 | :param expand: int, size of expansion
46 | :param expand_top: int, size of expansion of the coarsest level.
47 | :return: (n_samples, 3) tensor of positions
48 | """
49 | base_coords, base_scales = [], []
50 | for d in range(svh.depth):
51 | if svh.grids[d] is None:
52 | continue
53 | ijk_coords = svh.grids[d].active_grid_coords()
54 | d_expand = expand if d != svh.depth - 1 else expand_top
55 | if d_expand >= 3:
56 | mc_offsets = torch.arange(-d_expand // 2 + 1, d_expand // 2 + 1, device=svh.device)
57 | mc_offsets = torch.stack(torch.meshgrid(mc_offsets, mc_offsets, mc_offsets, indexing='ij'), dim=3)
58 | mc_offsets = mc_offsets.view(-1, 3)
59 | ijk_coords = (ijk_coords.unsqueeze(dim=1).repeat(1, mc_offsets.size(0), 1) +
60 | mc_offsets.unsqueeze(0)).view(-1, 3)
61 | ijk_coords = torch.unique(ijk_coords, dim=0)
62 | base_coords.append(svh.grids[d].grid_to_world(ijk_coords.float()))
63 | base_scales.append(torch.full((ijk_coords.size(0),), svh.grids[d].voxel_size, device=svh.device))
64 | base_coords, base_scales = torch.cat(base_coords), torch.cat(base_scales)
65 | local_ids = (torch.rand((n_samples,), device=svh.device) * base_coords.size(0)).long()
66 | local_coords = (torch.rand((n_samples, 3), device=svh.device) - 0.5) * base_scales[local_ids, None]
67 | query_pos = base_coords[local_ids] + local_coords
68 | return query_pos
69 |
70 | def _get_samples(self):
71 | all_samples = []
72 | for config in self.cfg.samplers:
73 | if config.type == "uniform":
74 | all_samples.append(
75 | self._get_svh_samples(self.svh, config.n_samples, config.expand, config.expand_top)
76 | )
77 | elif config.type == "band":
78 | band_inds = (torch.rand((config.n_samples, ), device=self.ref_xyz.device) * self.ref_xyz.size(0)).long()
79 | eps = config.eps * self.voxel_size
80 | band_pos = self.ref_xyz[band_inds] + \
81 | self.ref_normal[band_inds] * torch.randn((config.n_samples, 1), device=self.ref_xyz.device) * eps
82 | all_samples.append(band_pos)
83 | elif config.type == 'on_surface':
84 | n_subsample = config.subsample
85 |                 if 0 < n_subsample < self.ref_xyz.size(0):
86 |                     ref_xyz_inds = (torch.rand((n_subsample,), device=self.ref_xyz.device) *
87 |                                     self.ref_xyz.size(0)).long()
88 |                 else:
89 |                     # no valid subsample count: fall back to using all reference points
90 |                     ref_xyz_inds = torch.arange(self.ref_xyz.size(0), device=self.ref_xyz.device)
91 | all_samples.append(self.ref_xyz[ref_xyz_inds])
92 |
93 | return torch.cat(all_samples, 0)
94 |
95 | def transform_field(self, field: torch.Tensor):
96 | sdf_config = self.cfg
97 | assert sdf_config.gt_type != "binary"
98 | truncation_size = sdf_config.gt_band * self.voxel_size
99 | if sdf_config.gt_soft:
100 | field = torch.tanh(field / truncation_size) * truncation_size
101 | else:
102 | field = torch.clone(field)
103 | field[field > truncation_size] = truncation_size
104 | field[field < -truncation_size] = -truncation_size
105 | return field
106 |
107 | def compute_gt_sdf_from_pts(self, query_pos: torch.Tensor):
108 | k = 8
109 | stdv = 0.02
110 | normals = self.ref_normal
111 | device = query_pos.device
112 | knn_output = knn_points(query_pos.unsqueeze(0).to(device), self.ref_xyz.unsqueeze(0).to(device), K=k)
113 | indices = knn_output.idx.squeeze(0)
114 | closest_points = self.ref_xyz[indices]
115 | surface_to_queries_vec = query_pos.unsqueeze(1) - closest_points #N, K, 3
116 |
117 | dot_products = torch.einsum("ijk,ijk->ij", surface_to_queries_vec, normals[indices]) #N, K
118 | vec_lengths = torch.norm(surface_to_queries_vec[:, 0, :], dim=-1)
119 | use_dot_product = vec_lengths < stdv
120 | sdf = torch.where(use_dot_product, torch.abs(dot_products[:, 0]), vec_lengths)
121 |
122 | # Adjust the sign of the sdf values based on the majority of dot products
123 | num_pos = torch.sum(dot_products > 0, dim=1)
124 | inside = num_pos <= (k / 2)
125 | sdf[inside] *= -1
126 |
127 | return -sdf
128 |
129 |
130 | class BatchedSampler:
131 | def __init__(self, hparams):
132 | self.hparams = hparams
133 |
134 | def batch_sdf_sample(self, data_dict):
135 | if 'gt_geometry' in data_dict:
136 | gt_geometry = data_dict['gt_geometry']
137 | else:
138 | xyz, normal = data_dict['all_xyz'], data_dict['all_normals']
139 | row_splits = data_dict['row_splits']
140 | batch_size = len(data_dict['scene_names'])
141 |
142 | start = 0
143 | batch_samples_pos = []
144 | batch_gt_sdf = []
145 |
146 | for i in range(batch_size):
147 | if 'gt_geometry' in data_dict:
148 | ref_xyz, ref_normal, _ = gt_geometry[i].torch_attr()
149 | else:
150 | end = start + row_splits[i]
151 | ref_xyz = xyz[start:end]
152 | ref_normal = normal[start:end]
153 | start = end
154 |
155 | # Instantiate Sampler using kwargs
156 | sampler = Sampler(
157 | voxel_size=self.hparams.data.voxel_size,
158 | cfg=self.hparams.data.supervision.sdf,
159 | ref_xyz=ref_xyz,
160 | ref_normal=ref_normal
161 | )
162 | samples_pos = sampler._get_samples()
163 | # nksr_sdf = sampler.compute_gt_chi_from_pts(samples_pos)
164 | gt_sdf = sampler.compute_gt_sdf_from_pts(samples_pos)
165 | if self.hparams.data.supervision.sdf.truncate:
166 | gt_sdf = sampler.transform_field(gt_sdf)
167 |
168 | batch_samples_pos.append(samples_pos)
169 | batch_gt_sdf.append(gt_sdf)
170 |
171 | batch_samples_pos = torch.cat(batch_samples_pos, dim=0)
172 | batch_gt_sdf = torch.cat(batch_gt_sdf, dim=0)
173 |
174 | return batch_samples_pos, batch_gt_sdf
175 |
176 | def batch_udf_sample(self, data_dict):
177 | if 'gt_geometry' in data_dict:
178 | gt_geometry = data_dict['gt_geometry']
179 | else:
180 | xyz, normal = data_dict['all_xyz'], data_dict['all_normals']
181 | row_splits = data_dict['row_splits']
182 | batch_size = len(data_dict['scene_names'])
183 |
184 | start = 0
185 | batch_samples_pos = []
186 | batch_gt_udf = []
187 |
188 | for i in range(batch_size):
189 | if 'gt_geometry' in data_dict:
190 | ref_xyz, ref_normal, _ = gt_geometry[i].torch_attr()
191 | else:
192 | end = start + row_splits[i]
193 | ref_xyz = xyz[start:end]
194 | ref_normal = normal[start:end]
195 | start = end
196 |
197 | # Instantiate Sampler for UDF using kwargs
198 | sampler = Sampler(
199 | voxel_size=self.hparams.data.voxel_size,
200 | cfg=self.hparams.data.supervision.udf,
201 | ref_xyz=ref_xyz,
202 | ref_normal=ref_normal
203 | )
204 | samples_pos = sampler._get_samples()
205 | if self.hparams.data.supervision.udf.abs_sdf:
206 | gt_udf = torch.abs(sampler.compute_gt_sdf_from_pts(samples_pos))
207 | else:
208 | knn_output = knn_points(samples_pos.unsqueeze(0).to(torch.device("cuda")),
209 | ref_xyz.unsqueeze(0).to(torch.device("cuda")),
210 | K=1)
211 |                 gt_udf = knn_output.dists.squeeze(0).squeeze(-1)  # note: knn_points returns squared distances
212 |
213 | batch_samples_pos.append(samples_pos)
214 | batch_gt_udf.append(gt_udf)
215 |
216 |
217 | batch_samples_pos = torch.cat(batch_samples_pos, dim=0)
218 | batch_gt_udf = torch.cat(batch_gt_udf, dim=0)
219 |
220 | return batch_samples_pos, batch_gt_udf
221 |
222 | def batch_on_surface_sample(self, data_dict):
223 | if 'gt_geometry' in data_dict:
224 | gt_geometry = data_dict['gt_geometry']
225 | else:
226 | xyz, normal = data_dict['all_xyz'], data_dict['all_normals']
227 | row_splits = data_dict['row_splits']
228 | batch_size = len(data_dict['scene_names'])
229 |
230 | start = 0
231 | batch_samples_pos = []
232 | batch_samples_normal = []
233 | batch_gt_udf = []
234 |
235 | for i in range(batch_size):
236 | if 'gt_geometry' in data_dict:
237 | ref_xyz, ref_normal, _ = gt_geometry[i].torch_attr()
238 | else:
239 | end = start + row_splits[i]
240 | ref_xyz = xyz[start:end]
241 | ref_normal = normal[start:end]
242 | start = end
243 |
244 | n_subsample = self.hparams.data.supervision.on_surface.subsample
245 |             if 0 < n_subsample < ref_xyz.size(0):
246 |                 ref_xyz_inds = (torch.rand((n_subsample,), device=ref_xyz.device) *
247 |                                 ref_xyz.size(0)).long()
248 |             else:
249 |                 # no valid subsample count: fall back to using all reference points
250 |                 ref_xyz_inds = torch.arange(ref_xyz.size(0), device=ref_xyz.device)
251 | batch_samples_pos.append(ref_xyz[ref_xyz_inds])
252 | batch_samples_normal.append(ref_normal[ref_xyz_inds])
253 |
254 | batch_samples_pos = torch.cat(batch_samples_pos, dim=0)
255 | batch_samples_normal = torch.cat(batch_samples_normal, dim=0)
256 | return batch_samples_pos, batch_samples_normal
257 |
258 |
259 |
--------------------------------------------------------------------------------
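
The SDF truncation applied by Sampler.transform_field is easy to sanity-check in isolation. A small sketch of the two modes (gt_soft on and off); the truncation size and field values below are made up.

import torch

truncation_size = 0.1                      # gt_band * voxel_size in the real code
field = torch.tensor([-0.5, -0.05, 0.0, 0.05, 0.5])

soft = torch.tanh(field / truncation_size) * truncation_size   # gt_soft=True: smooth saturation
hard = field.clamp(-truncation_size, truncation_size)          # gt_soft=False: hard clipping

print(soft)   # approaches +/-0.1 asymptotically
print(hard)   # clipped exactly at +/-0.1
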
/noksr/utils/segmentation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from noksr.utils.serialization import encode
3 |
4 | def segment_and_generate_encoder_outputs(batch, model, device, segment_num, grid_size=0.01, serial_order='z'):
5 | """
6 | Segment input `xyz` of the batch, generate encoder outputs for each segment, and return necessary metadata.
7 |
8 | Args:
9 | batch (dict): Batch data containing `xyz` and other input features.
10 | model (torch.nn.Module): The neural network model for processing data.
11 | device (torch.device): The device to run inference on.
12 | segment_num (int): Fixed number of segments.
13 | grid_size (float): Grid size for serialization.
14 | serial_order (str): Order for serialization ('z' by default).
15 |
16 | Returns:
17 | encoder_outputs (list): List of encoder outputs for each segment.
18 | encoding_codes (list): List of (start_code, end_code) for each segment.
19 | depth (int): Depth used for serialization.
20 | """
21 | # Extract necessary data from the batch
22 | xyz = batch['xyz'].to(torch.float32) # Convert to float32 for processing
23 | total_points = xyz.shape[0]
24 |
25 | # Compute segment length based on fixed number of segments
26 | segment_length = max(1, total_points // segment_num)
27 |
28 | in_quant_coords = torch.floor(xyz / grid_size).to(torch.int)
29 | depth = int(torch.abs(in_quant_coords).max()).bit_length() # Calculate serialization depth
30 | in_quant_codes = encode(
31 | in_quant_coords,
32 | torch.zeros(in_quant_coords.shape[0], dtype=torch.int64, device=in_quant_coords.device),
33 | depth,
34 | order=serial_order
35 | )
36 | in_sorted_quant_codes, in_sorted_indices = torch.sort(in_quant_codes)
37 |
38 | segments = []
39 | encoding_codes = []
40 | for i in range(segment_num):
41 | start_idx = i * segment_length
42 | end_idx = min(start_idx + segment_length, total_points)
43 | segment_indices = in_sorted_indices[start_idx:end_idx]
44 | segments.append(segment_indices)
45 |
46 | # Store the start and end encoding codes for the segment
47 | start_code = in_sorted_quant_codes[start_idx].item()
48 | end_code = in_sorted_quant_codes[end_idx - 1].item()
49 |         encoding_codes.append((start_code, end_code))
50 |
51 | # Generate encoder outputs for each segment
52 | encoder_outputs = []
53 | for segment_indices in segments:
54 | # Create a new batch for the current segment
55 | segment_batch = {
56 | "xyz": batch['xyz'][segment_indices],
57 | "point_features": batch['point_features'][segment_indices],
58 | "scene_names": batch['scene_names'],
59 | "xyz_splits": torch.tensor([len(segment_indices)], device=batch['xyz'].device)
60 | }
61 |
62 | pt_data = {
63 | 'feat': segment_batch['point_features'],
64 | 'offset': segment_batch['xyz_splits'], # Offset for the segment
65 | 'grid_size': grid_size,
66 | 'coord': segment_batch['xyz']
67 | }
68 | segment_encoder_output = model.point_transformer(pt_data)
69 | encoder_outputs.append(segment_encoder_output)
70 |
71 | return encoder_outputs, encoding_codes, depth
--------------------------------------------------------------------------------
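
The segmentation logic above boils down to: quantize the coordinates, serialize them to 1-D codes, sort, and slice the sorted order into contiguous chunks. A standalone sketch of just that part; the point cloud, grid size, and segment count are made up.

import torch
from noksr.utils.serialization import encode

xyz = torch.rand(10000, 3)                         # stand-in point cloud
grid_size, segment_num = 0.01, 4

quant = torch.floor(xyz / grid_size).to(torch.int)
depth = int(torch.abs(quant).max()).bit_length()
codes = encode(quant, torch.zeros(len(quant), dtype=torch.int64), depth, order="z")

sorted_codes, sorted_idx = torch.sort(codes)
seg_len = max(1, len(xyz) // segment_num)
segments = [sorted_idx[i * seg_len: min((i + 1) * seg_len, len(xyz))]
            for i in range(segment_num)]
# each entry of segments indexes a spatially coherent chunk along the space-filling curve
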
/noksr/utils/serialization/__init__.py:
--------------------------------------------------------------------------------
1 | from .default import (
2 | encode,
3 | decode,
4 | z_order_encode,
5 | z_order_decode,
6 | hilbert_encode,
7 | hilbert_decode,
8 | )
9 |
--------------------------------------------------------------------------------
/noksr/utils/serialization/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/utils/serialization/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/utils/serialization/__pycache__/default.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/utils/serialization/__pycache__/default.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/utils/serialization/__pycache__/hilbert.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/utils/serialization/__pycache__/hilbert.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/utils/serialization/__pycache__/z_order.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/noksr/utils/serialization/__pycache__/z_order.cpython-310.pyc
--------------------------------------------------------------------------------
/noksr/utils/serialization/default.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .z_order import xyz2key as z_order_encode_
3 | from .z_order import key2xyz as z_order_decode_
4 | from .hilbert import encode as hilbert_encode_
5 | from .hilbert import decode as hilbert_decode_
6 |
7 |
8 | @torch.inference_mode()
9 | def encode(grid_coord, batch=None, depth=16, order="z"):
10 | assert order in {"z", "z-trans", "hilbert", "hilbert-trans"}
11 | if order == "z":
12 | code = z_order_encode(grid_coord, depth=depth)
13 | elif order == "z-trans":
14 | code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth)
15 | elif order == "hilbert":
16 | code = hilbert_encode(grid_coord, depth=depth)
17 | elif order == "hilbert-trans":
18 | code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth)
19 | else:
20 | raise NotImplementedError
21 | if batch is not None:
22 | batch = batch.long()
23 | code = batch << depth * 3 | code
24 | return code
25 |
26 |
27 | @torch.inference_mode()
28 | def decode(code, depth=16, order="z"):
29 | assert order in {"z", "hilbert"}
30 | batch = code >> depth * 3
31 | code = code & ((1 << depth * 3) - 1)
32 | if order == "z":
33 | grid_coord = z_order_decode(code, depth=depth)
34 | elif order == "hilbert":
35 | grid_coord = hilbert_decode(code, depth=depth)
36 | else:
37 | raise NotImplementedError
38 | return grid_coord, batch
39 |
40 |
41 | def z_order_encode(grid_coord: torch.Tensor, depth: int = 16):
42 | x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long()
43 |     # batch support is intentionally not handled here; batched codes are maintained in the Point class
44 | code = z_order_encode_(x, y, z, b=None, depth=depth)
45 | return code
46 |
47 |
48 | def z_order_decode(code: torch.Tensor, depth):
49 |     x, y, z, _ = z_order_decode_(code, depth=depth)  # key2xyz also returns a batch index, unused here
50 | grid_coord = torch.stack([x, y, z], dim=-1) # (N, 3)
51 | return grid_coord
52 |
53 |
54 | def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16):
55 | return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth)
56 |
57 |
58 | def hilbert_decode(code: torch.Tensor, depth: int = 16):
59 | return hilbert_decode_(code, num_dims=3, num_bits=depth)
60 |
--------------------------------------------------------------------------------
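
A quick round-trip check for the encode/decode pair above (z-order path). It assumes non-negative integer grid coordinates that fit into depth bits; the coordinates and batch indices below are made up.

import torch
from noksr.utils.serialization import encode, decode

grid_coord = torch.tensor([[1, 2, 3],
                           [10, 20, 30],
                           [100, 200, 300]], dtype=torch.int64)
batch = torch.tensor([0, 0, 1], dtype=torch.int64)

code = encode(grid_coord, batch, depth=16, order="z")
coord_back, batch_back = decode(code, depth=16, order="z")

print(torch.equal(coord_back, grid_coord), torch.equal(batch_back, batch))  # expected: True True
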
/noksr/utils/serialization/hilbert.py:
--------------------------------------------------------------------------------
1 | """
2 | Hilbert Order
3 | Modified from https://github.com/PrincetonLIPS/numpy-hilbert-curve
4 |
5 | Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com), Kaixin Xu
6 | Please cite our work if the code is helpful to you.
7 | """
8 |
9 | import torch
10 |
11 |
12 | def right_shift(binary, k=1, axis=-1):
13 | """Right shift an array of binary values.
14 |
15 | Parameters:
16 | -----------
17 | binary: An ndarray of binary values.
18 |
19 | k: The number of bits to shift. Default 1.
20 |
21 | axis: The axis along which to shift. Default -1.
22 |
23 | Returns:
24 | --------
25 | Returns an ndarray with zero prepended and the ends truncated, along
26 | whatever axis was specified."""
27 |
28 | # If we're shifting the whole thing, just return zeros.
29 | if binary.shape[axis] <= k:
30 | return torch.zeros_like(binary)
31 |
32 | # Determine the padding pattern.
33 | # padding = [(0,0)] * len(binary.shape)
34 | # padding[axis] = (k,0)
35 |
36 | # Determine the slicing pattern to eliminate just the last one.
37 | slicing = [slice(None)] * len(binary.shape)
38 | slicing[axis] = slice(None, -k)
39 | shifted = torch.nn.functional.pad(
40 | binary[tuple(slicing)], (k, 0), mode="constant", value=0
41 | )
42 |
43 | return shifted
44 |
45 |
46 | def binary2gray(binary, axis=-1):
47 | """Convert an array of binary values into Gray codes.
48 |
49 | This uses the classic X ^ (X >> 1) trick to compute the Gray code.
50 |
51 | Parameters:
52 | -----------
53 | binary: An ndarray of binary values.
54 |
55 | axis: The axis along which to compute the gray code. Default=-1.
56 |
57 | Returns:
58 | --------
59 | Returns an ndarray of Gray codes.
60 | """
61 | shifted = right_shift(binary, axis=axis)
62 |
63 | # Do the X ^ (X >> 1) trick.
64 | gray = torch.logical_xor(binary, shifted)
65 |
66 | return gray
67 |
68 |
69 | def gray2binary(gray, axis=-1):
70 | """Convert an array of Gray codes back into binary values.
71 |
72 | Parameters:
73 | -----------
74 | gray: An ndarray of gray codes.
75 |
76 | axis: The axis along which to perform Gray decoding. Default=-1.
77 |
78 | Returns:
79 | --------
80 | Returns an ndarray of binary values.
81 | """
82 |
83 | # Loop the log2(bits) number of times necessary, with shift and xor.
84 | shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)
85 | while shift > 0:
86 | gray = torch.logical_xor(gray, right_shift(gray, shift))
87 | shift = torch.div(shift, 2, rounding_mode="floor")
88 | return gray
89 |
90 |
91 | def encode(locs, num_dims, num_bits):
92 | """Decode an array of locations in a hypercube into a Hilbert integer.
93 |
94 | This is a vectorized-ish version of the Hilbert curve implementation by John
95 | Skilling as described in:
96 |
97 | Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
98 | Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
99 |
100 | Params:
101 | -------
102 | locs - An ndarray of locations in a hypercube of num_dims dimensions, in
103 | which each dimension runs from 0 to 2**num_bits-1. The shape can
104 |            be arbitrary, as long as the last dimension has size
105 | num_dims.
106 |
107 | num_dims - The dimensionality of the hypercube. Integer.
108 |
109 | num_bits - The number of bits for each dimension. Integer.
110 |
111 | Returns:
112 | --------
113 | The output is an ndarray of uint64 integers with the same shape as the
114 | input, excluding the last dimension, which needs to be num_dims.
115 | """
116 |
117 | # Keep around the original shape for later.
118 | orig_shape = locs.shape
119 | bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
120 | bitpack_mask_rev = bitpack_mask.flip(-1)
121 |
122 | if orig_shape[-1] != num_dims:
123 | raise ValueError(
124 | """
125 | The shape of locs was surprising in that the last dimension was of size
126 | %d, but num_dims=%d. These need to be equal.
127 | """
128 | % (orig_shape[-1], num_dims)
129 | )
130 |
131 | if num_dims * num_bits > 63:
132 | raise ValueError(
133 | """
134 | num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
135 | into a int64. Are you sure you need that many points on your Hilbert
136 | curve?
137 | """
138 | % (num_dims, num_bits, num_dims * num_bits)
139 | )
140 |
141 | # Treat the location integers as 64-bit unsigned and then split them up into
142 | # a sequence of uint8s. Preserve the association by dimension.
143 | locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)
144 |
145 | # Now turn these into bits and truncate to num_bits.
146 | gray = (
147 | locs_uint8.unsqueeze(-1)
148 | .bitwise_and(bitpack_mask_rev)
149 | .ne(0)
150 | .byte()
151 | .flatten(-2, -1)[..., -num_bits:]
152 | )
153 |
154 | # Run the decoding process the other way.
155 | # Iterate forwards through the bits.
156 | for bit in range(0, num_bits):
157 | # Iterate forwards through the dimensions.
158 | for dim in range(0, num_dims):
159 | # Identify which ones have this bit active.
160 | mask = gray[:, dim, bit]
161 |
162 | # Where this bit is on, invert the 0 dimension for lower bits.
163 | gray[:, 0, bit + 1 :] = torch.logical_xor(
164 | gray[:, 0, bit + 1 :], mask[:, None]
165 | )
166 |
167 | # Where the bit is off, exchange the lower bits with the 0 dimension.
168 | to_flip = torch.logical_and(
169 | torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
170 | torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
171 | )
172 | gray[:, dim, bit + 1 :] = torch.logical_xor(
173 | gray[:, dim, bit + 1 :], to_flip
174 | )
175 | gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)
176 |
177 | # Now flatten out.
178 | gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims))
179 |
180 | # Convert Gray back to binary.
181 | hh_bin = gray2binary(gray)
182 |
183 | # Pad back out to 64 bits.
184 | extra_dims = 64 - num_bits * num_dims
185 | padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)
186 |
187 | # Convert binary values into uint8s.
188 | hh_uint8 = (
189 | (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask)
190 | .sum(2)
191 | .squeeze()
192 | .type(torch.uint8)
193 | )
194 |
195 | # Convert uint8s into uint64s.
196 | hh_uint64 = hh_uint8.view(torch.int64).squeeze()
197 |
198 | return hh_uint64
199 |
200 |
201 | def decode(hilberts, num_dims, num_bits):
202 | """Decode an array of Hilbert integers into locations in a hypercube.
203 |
204 | This is a vectorized-ish version of the Hilbert curve implementation by John
205 | Skilling as described in:
206 |
207 | Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
208 | Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
209 |
210 | Params:
211 | -------
212 | hilberts - An ndarray of Hilbert integers. Must be an integer dtype and
213 | cannot have fewer bits than num_dims * num_bits.
214 |
215 | num_dims - The dimensionality of the hypercube. Integer.
216 |
217 | num_bits - The number of bits for each dimension. Integer.
218 |
219 | Returns:
220 | --------
221 | The output is an ndarray of unsigned integers with the same shape as hilberts
222 | but with an additional dimension of size num_dims.
223 | """
224 |
225 | if num_dims * num_bits > 64:
226 | raise ValueError(
227 | """
228 | num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
229 | into a uint64. Are you sure you need that many points on your Hilbert
230 | curve?
231 | """
232 |             % (num_dims, num_bits, num_dims * num_bits)
233 | )
234 |
235 | # Handle the case where we got handed a naked integer.
236 | hilberts = torch.atleast_1d(hilberts)
237 |
238 | # Keep around the shape for later.
239 | orig_shape = hilberts.shape
240 | bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device)
241 | bitpack_mask_rev = bitpack_mask.flip(-1)
242 |
243 |     # Treat each of the hilberts as a sequence of eight uint8.
244 | # This treats all of the inputs as uint64 and makes things uniform.
245 | hh_uint8 = (
246 | hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1)
247 | )
248 |
249 | # Turn these lists of uints into lists of bits and then truncate to the size
250 | # we actually need for using Skilling's procedure.
251 | hh_bits = (
252 | hh_uint8.unsqueeze(-1)
253 | .bitwise_and(bitpack_mask_rev)
254 | .ne(0)
255 | .byte()
256 | .flatten(-2, -1)[:, -num_dims * num_bits :]
257 | )
258 |
259 | # Take the sequence of bits and Gray-code it.
260 | gray = binary2gray(hh_bits)
261 |
262 | # There has got to be a better way to do this.
263 | # I could index them differently, but the eventual packbits likes it this way.
264 | gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2)
265 |
266 | # Iterate backwards through the bits.
267 | for bit in range(num_bits - 1, -1, -1):
268 | # Iterate backwards through the dimensions.
269 | for dim in range(num_dims - 1, -1, -1):
270 | # Identify which ones have this bit active.
271 | mask = gray[:, dim, bit]
272 |
273 | # Where this bit is on, invert the 0 dimension for lower bits.
274 | gray[:, 0, bit + 1 :] = torch.logical_xor(
275 | gray[:, 0, bit + 1 :], mask[:, None]
276 | )
277 |
278 | # Where the bit is off, exchange the lower bits with the 0 dimension.
279 | to_flip = torch.logical_and(
280 | torch.logical_not(mask[:, None]),
281 | torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
282 | )
283 | gray[:, dim, bit + 1 :] = torch.logical_xor(
284 | gray[:, dim, bit + 1 :], to_flip
285 | )
286 | gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)
287 |
288 | # Pad back out to 64 bits.
289 | extra_dims = 64 - num_bits
290 | padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0)
291 |
292 | # Now chop these up into blocks of 8.
293 | locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8))
294 |
295 | # Take those blocks and turn them unto uint8s.
296 | # from IPython import embed; embed()
297 | locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8)
298 |
299 | # Finally, treat these as uint64s.
300 | flat_locs = locs_uint8.view(torch.int64)
301 |
302 | # Return them in the expected shape.
303 | return flat_locs.reshape((*orig_shape, num_dims))
304 |
--------------------------------------------------------------------------------
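
As with the z-order path, the Hilbert encode/decode pair can be sanity-checked with a round trip. The coordinates below are made up and must lie in [0, 2**num_bits).

import torch
from noksr.utils.serialization.hilbert import encode as hilbert_encode, decode as hilbert_decode

locs = torch.tensor([[0, 1, 2],
                     [3, 7, 5],
                     [255, 128, 64]], dtype=torch.int64)

codes = hilbert_encode(locs, num_dims=3, num_bits=16)        # one int64 key per point
locs_back = hilbert_decode(codes, num_dims=3, num_bits=16)

print(torch.equal(locs_back, locs))  # expected: True
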
/noksr/utils/serialization/z_order.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Octree-based Sparse Convolutional Neural Networks
3 | # Copyright (c) 2022 Peng-Shuai Wang
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | from typing import Optional, Union
10 |
11 |
12 | class KeyLUT:
13 | def __init__(self):
14 | r256 = torch.arange(256, dtype=torch.int64)
15 | r512 = torch.arange(512, dtype=torch.int64)
16 | zero = torch.zeros(256, dtype=torch.int64)
17 | device = torch.device("cpu")
18 |
19 | self._encode = {
20 | device: (
21 | self.xyz2key(r256, zero, zero, 8),
22 | self.xyz2key(zero, r256, zero, 8),
23 | self.xyz2key(zero, zero, r256, 8),
24 | )
25 | }
26 | self._decode = {device: self.key2xyz(r512, 9)}
27 |
28 | def encode_lut(self, device=torch.device("cpu")):
29 | if device not in self._encode:
30 | cpu = torch.device("cpu")
31 | self._encode[device] = tuple(e.to(device) for e in self._encode[cpu])
32 | return self._encode[device]
33 |
34 | def decode_lut(self, device=torch.device("cpu")):
35 | if device not in self._decode:
36 | cpu = torch.device("cpu")
37 | self._decode[device] = tuple(e.to(device) for e in self._decode[cpu])
38 | return self._decode[device]
39 |
40 | def xyz2key(self, x, y, z, depth):
41 | key = torch.zeros_like(x)
42 | for i in range(depth):
43 | mask = 1 << i
44 | key = (
45 | key
46 | | ((x & mask) << (2 * i + 2))
47 | | ((y & mask) << (2 * i + 1))
48 | | ((z & mask) << (2 * i + 0))
49 | )
50 | return key
51 |
52 | def key2xyz(self, key, depth):
53 | x = torch.zeros_like(key)
54 | y = torch.zeros_like(key)
55 | z = torch.zeros_like(key)
56 | for i in range(depth):
57 | x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2))
58 | y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1))
59 | z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0))
60 | return x, y, z
61 |
62 |
63 | _key_lut = KeyLUT()
64 |
65 |
66 | def xyz2key(
67 | x: torch.Tensor,
68 | y: torch.Tensor,
69 | z: torch.Tensor,
70 | b: Optional[Union[torch.Tensor, int]] = None,
71 | depth: int = 16,
72 | ):
73 | r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
74 |     based on pre-computed look up tables. This is much faster than the
75 |     equivalent for-loop based method.
76 |
77 | Args:
78 | x (torch.Tensor): The x coordinate.
79 | y (torch.Tensor): The y coordinate.
80 | z (torch.Tensor): The z coordinate.
81 | b (torch.Tensor or int): The batch index of the coordinates, and should be
82 | smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
83 | :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
84 | depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
85 | """
86 |
87 | EX, EY, EZ = _key_lut.encode_lut(x.device)
88 | x, y, z = x.long(), y.long(), z.long()
89 |
90 | mask = 255 if depth > 8 else (1 << depth) - 1
91 | key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
92 | if depth > 8:
93 | mask = (1 << (depth - 8)) - 1
94 | key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
95 | key = key16 << 24 | key
96 |
97 | if b is not None:
98 | b = b.long()
99 | key = b << 48 | key
100 |
101 | return key
102 |
103 |
104 | def key2xyz(key: torch.Tensor, depth: int = 16):
105 | r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates
106 | and the batch index based on pre-computed look up tables.
107 |
108 | Args:
109 | key (torch.Tensor): The shuffled key.
110 | depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
111 | """
112 |
113 | DX, DY, DZ = _key_lut.decode_lut(key.device)
114 | x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key)
115 |
116 | b = key >> 48
117 | key = key & ((1 << 48) - 1)
118 |
119 | n = (depth + 2) // 3
120 | for i in range(n):
121 | k = key >> (i * 9) & 511
122 | x = x | (DX[k] << (i * 3))
123 | y = y | (DY[k] << (i * 3))
124 | z = z | (DZ[k] << (i * 3))
125 |
126 | return x, y, z, b
127 |
--------------------------------------------------------------------------------
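
For intuition, the bit interleaving performed by xyz2key can be verified on a few hand-computed keys; the coordinates below are arbitrary.

import torch
from noksr.utils.serialization.z_order import xyz2key, key2xyz

x = torch.tensor([1, 0, 0, 3])
y = torch.tensor([0, 1, 0, 0])
z = torch.tensor([0, 0, 1, 0])

key = xyz2key(x, y, z, b=None, depth=8)
print(key)                            # tensor([ 4,  2,  1, 36]): bits interleaved as x, y, z per level

xd, yd, zd, b = key2xyz(key, depth=8)
print(torch.equal(xd, x), torch.equal(yd, y), torch.equal(zd, z))  # expected: True True True
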
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch-lightning==1.9.4
2 | lightning-bolts
3 | scipy
4 | open3d
5 | h5py
6 | ninja
7 | wandb
8 | hydra-core
9 | cython
10 | scikit-image
11 | trimesh
12 | arrgh
13 | plyfile
14 | imageio-ffmpeg
15 | gin-config
16 | torchviz
17 | thop
18 | imageio
19 | einops
20 | spconv-cu118
21 | timm
22 | flash-attn==2.6.3
23 |
--------------------------------------------------------------------------------
/scripts/segment_carla.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from pathlib import Path
5 | import shutil
6 | from noksr.utils.serialization import encode
7 | from tqdm import tqdm
8 |
9 | def _serial_scene(scene, segment_length, segment_num, grid_size=0.01, serial_order='z'):
10 | # Process ground truth points
11 | in_xyz = torch.from_numpy(scene['in_xyz']).to(torch.float32)
12 | in_quant_coords = torch.floor(in_xyz / grid_size).to(torch.int)
13 | gt_xyz = torch.from_numpy(scene['gt_xyz']).to(torch.float32)
14 | gt_quant_coords = torch.floor(gt_xyz / grid_size).to(torch.int)
15 |
16 | depth = int(max(gt_quant_coords.max(), in_quant_coords.max())).bit_length()
17 | gt_quant_codes = encode(gt_quant_coords, torch.zeros(gt_quant_coords.shape[0], dtype=torch.int64, device=gt_quant_coords.device), depth, order=serial_order)
18 | gt_sorted_quant_codes, gt_sorted_indices = torch.sort(gt_quant_codes)
19 |
20 | total_gt_points = len(gt_sorted_quant_codes)
21 |
22 | in_quant_codes = encode(in_quant_coords, torch.zeros(in_quant_coords.shape[0], dtype=torch.int64, device=in_quant_coords.device), depth, order=serial_order)
23 | in_sorted_quant_codes, in_sorted_indices = torch.sort(in_quant_codes)
24 |
25 | # Segment ground truth points and find corresponding input points
26 | segments = []
27 | for i in range(segment_num):
28 | # Ground truth segmentation
29 | gt_start_idx = i * segment_length
30 | gt_end_idx = gt_start_idx + segment_length
31 | gt_end_idx = min(gt_end_idx, total_gt_points)
32 | gt_segment_indices = gt_sorted_indices[gt_start_idx:gt_end_idx]
33 |
34 | # Input segmentation within the ground truth range
35 | in_segment_indices = in_sorted_indices[(in_sorted_quant_codes >= gt_sorted_quant_codes[gt_start_idx].item()) &
36 | (in_sorted_quant_codes < gt_sorted_quant_codes[min(gt_end_idx, total_gt_points - 1)].item())]
37 |
38 | segments.append({
39 | "gt_indices": gt_segment_indices.numpy(),
40 | "in_indices": in_segment_indices.numpy()
41 | })
42 |
43 | return segments
44 |
45 | def update_lst_files(input_drive_path, output_drive_path, segmented_scenes):
46 | """
47 | Update .lst files in the output directory by expanding entries for segmented scenes.
48 |
49 | Args:
50 | input_drive_path (Path): Path to the input drive directory.
51 | output_drive_path (Path): Path to the output drive directory.
52 | segmented_scenes (dict): Dictionary of original scenes to their segment counts (e.g., {"rgn0000": 12}).
53 | """
54 | lst_files = ['test.lst', 'testall.lst', 'train.lst', 'val.lst']
55 |
56 | for lst_file in lst_files:
57 | input_lst_file = input_drive_path / lst_file
58 | output_lst_file = output_drive_path / lst_file
59 |
60 | if not input_lst_file.exists():
61 | continue # Skip if the .lst file doesn't exist in the input
62 |
63 | with input_lst_file.open('r') as infile, output_lst_file.open('w') as outfile:
64 | for line in infile:
65 | scene_name = line.strip() # Original scene name without extension
66 | if scene_name in segmented_scenes:
67 | # Expand scene into its segments
68 | num_segments = segmented_scenes[scene_name]
69 | for i in range(num_segments):
70 | outfile.write(f"{scene_name.split('-')[0]}-crop{(i):04d}\n")
71 | else:
72 | # Write the scene as-is if it's not segmented
73 | outfile.write(line)
74 |
75 | def process_dataset(input_path, output_path, fixed_segment_length):
76 | input_base = Path(input_path)
77 | output_base = Path(output_path)
78 | output_base.mkdir(parents=True, exist_ok=True)
79 | target_drives = ['Town01-0', 'Town01-1', 'Town01-2',
80 | 'Town02-0', 'Town02-1', 'Town02-2',
81 | 'Town10-0', 'Town10-1', 'Town10-2', 'Town10-3', 'Town10-4']
82 |
83 | # Copy list files (e.g., test.lst, val.lst)
84 | for file in input_base.glob('*.lst'):
85 | shutil.copy(file, output_base / file.name)
86 |
87 | # Count total number of scenes for progress bar initialization
88 | total_scenes = sum(
89 | 1
90 | for drive in input_base.iterdir()
91 | if drive.is_dir() and drive.name in target_drives
92 | for item in drive.iterdir()
93 | if item.is_dir()
94 | )
95 |
96 | # Process each drive folder with progress bar
97 | with tqdm(total=total_scenes, desc="Processing Scenes", unit="scene") as pbar:
98 | for drive in input_base.iterdir():
99 | if drive.is_dir() and drive.name in target_drives:
100 | drive_output = output_base / drive.name
101 | drive_output.mkdir(parents=True, exist_ok=True)
102 | segmented_scenes = {}
103 |
104 | for item in drive.iterdir():
105 | if item.is_dir():
106 | scene_name = item.name
107 | output_scene_base = drive_output
108 |
109 | # Load data
110 | data_file = item / 'pointcloud.npz'
111 | gt_file = item / 'groundtruth.bin'
112 | gt_data = np.load(gt_file, allow_pickle=True)
113 | data = np.load(data_file)
114 | scene = {
115 | 'in_xyz': data['points'],
116 | 'in_normal': data['normals'],
117 | 'gt_xyz': gt_data['xyz'],
118 | 'gt_normal': gt_data['normal']
119 | }
120 |
121 | total_gt_points = len(scene['gt_xyz'])
122 |
123 | # Step 1: Compute segment number
124 | segment_num = max(1, total_gt_points // fixed_segment_length)
125 |
126 | # Step 2: Recompute segment length for even distribution
127 | segment_length = total_gt_points // segment_num
128 |
129 | # Serialize and segment scene
130 | segments = _serial_scene(scene, segment_length, segment_num)
131 | segmented_scenes[scene_name] = len(segments) # Record segment count for this scene
132 | for i, segment in enumerate(segments):
133 | segment_scene_name = f"{scene_name.split('-')[0]}-crop{(i):04d}"
134 | segment_output = output_scene_base / segment_scene_name
135 | segment_output.mkdir(parents=True, exist_ok=True)
136 |
137 | # Save segmented data
138 | segment_data = {
139 | 'points': scene['in_xyz'][segment["in_indices"]],
140 | 'normals': scene['in_normal'][segment["in_indices"]],
141 | 'ref_xyz': scene['gt_xyz'][segment["gt_indices"]],
142 | 'ref_normals': scene['gt_normal'][segment["gt_indices"]]
143 | }
144 | np.savez(segment_output / 'pointcloud.npz', **segment_data)
145 |
146 | # Update progress bar
147 | pbar.update(1)
148 |
149 | update_lst_files(drive, drive_output, segmented_scenes)
150 |
151 | if __name__ == "__main__":
152 | """
153 |     This script regenerates the training segments via 1-D serialization; the original CARLA patches were uniformly sampled.
154 | """
155 | input_path = "./data/carla-lidar/dataset-no-patch" # Replace with your input dataset path
156 | output_path = "./data/carla-lidar/dataset-seg-patch" # Replace with your desired output path
157 | fixed_segment_length = 300000 # Fixed segment length to start computation
158 |
159 | process_dataset(input_path, output_path, fixed_segment_length)
--------------------------------------------------------------------------------
/train.log:
--------------------------------------------------------------------------------
1 | [2025-01-02 15:44:07,239][pycg.exp][WARNING] - Customized build of Open3D is not detected, to resolve this you can do:
2 | (recommended, using customized Open3D that enables view sync, animation, ...)
3 |
4 | pip install python-pycg[full] -f https://pycg.s3.ap-northeast-1.amazonaws.com/packages/index.html
5 |
6 |
7 | [2025-01-02 15:44:10,471][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 0
8 | [2025-01-02 15:44:10,471][torch.distributed.distributed_c10d][INFO] - Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.
9 | [2025-01-02 15:46:55,468][torch.nn.parallel.distributed][INFO] - Reducer buckets have been rebuilt in this iteration.
10 | [2025-01-02 15:53:14,873][pycg.exp][WARNING] - Customized build of Open3D is not detected, to resolve this you can do:
11 | (recommended, using customized Open3D that enables view sync, animation, ...)
12 |
13 | pip install python-pycg[full] -f https://pycg.s3.ap-northeast-1.amazonaws.com/packages/index.html
14 |
15 |
16 | [2025-01-02 15:53:18,417][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 0
17 | [2025-01-02 15:53:18,417][torch.distributed.distributed_c10d][INFO] - Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.
18 | [2025-01-02 15:53:22,497][torch.nn.parallel.distributed][INFO] - Reducer buckets have been rebuilt in this iteration.
19 | [2025-01-02 15:53:32,911][pycg.exp][WARNING] - Customized build of Open3D is not detected, to resolve this you can do:
20 | (recommended, using customized Open3D that enables view sync, animation, ...)
21 |
22 | pip install python-pycg[full] -f https://pycg.s3.ap-northeast-1.amazonaws.com/packages/index.html
23 |
24 |
25 | [2025-01-02 15:53:36,194][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 0
26 | [2025-01-02 15:53:36,194][torch.distributed.distributed_c10d][INFO] - Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.
27 | [2025-01-02 15:53:41,534][torch.nn.parallel.distributed][INFO] - Reducer buckets have been rebuilt in this iteration.
28 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import hydra
3 | import wandb
4 | import pytorch_lightning as pl
5 | from noksr.callback import *
6 | from importlib import import_module
7 | from noksr.data.data_module import DataModule
8 | from pytorch_lightning.callbacks import LearningRateMonitor
9 | from pytorch_lightning.strategies.ddp import DDPStrategy
10 |
11 |
12 |
13 | def init_callbacks(cfg):
14 | checkpoint_monitor = hydra.utils.instantiate(cfg.model.checkpoint_monitor)
15 | gpu_cache_clean_monitor = GPUCacheCleanCallback()
16 | lr_monitor = LearningRateMonitor(logging_interval="epoch")
17 | return [checkpoint_monitor, gpu_cache_clean_monitor, lr_monitor]
18 |
19 |
20 | @hydra.main(version_base=None, config_path="config", config_name="config")
21 | def main(cfg):
22 | # fix the seed
23 | pl.seed_everything(cfg.global_train_seed, workers=True)
24 |
25 | output_path = os.path.join(cfg.exp_output_root_path, "training")
26 | os.makedirs(output_path, exist_ok=True)
27 |
28 | print("==> initializing data ...")
29 | data_module = DataModule(cfg)
30 |
31 | print("==> initializing logger ...")
32 | logger = hydra.utils.instantiate(cfg.model.logger, save_dir=output_path)
33 |
34 | print("==> initializing monitor ...")
35 | callbacks = init_callbacks(cfg)
36 |
37 |
38 | print("==> initializing trainer ...")
39 | trainer = pl.Trainer(callbacks=callbacks, logger=logger, **cfg.model.trainer, strategy=DDPStrategy(find_unused_parameters=False))
40 |
41 | print("==> initializing model ...")
42 | model = getattr(import_module("noksr.model"), cfg.model.network.module)(cfg)
43 |
44 |
45 | print("==> start training ...")
46 | trainer.fit(model=model, datamodule=data_module, ckpt_path=cfg.model.ckpt_path)
47 |
48 | if __name__ == '__main__':
49 | main()
--------------------------------------------------------------------------------
/wandb/debug-cli.zla247.log:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theialab/noksr/899f827e7fbe64f2f084fbab1e57a354ed507133/wandb/debug-cli.zla247.log
--------------------------------------------------------------------------------
/wandb/settings:
--------------------------------------------------------------------------------
1 | [default]
2 |
3 |
--------------------------------------------------------------------------------