├── .gitignore ├── LICENSE ├── README.md ├── configs ├── config_igr.yaml ├── config_neural_pull.yaml ├── config_siren0.033 copy.yaml └── config_siren0.33.yaml ├── diffcd ├── __init__.py ├── closest_point.py ├── datasets.py ├── evaluation │ ├── __init__.py │ ├── chamfer.py │ ├── contours.py │ └── meshing.py ├── methods.py ├── networks.py ├── newton.py ├── samplers.py ├── training.py └── utils.py ├── evaluation.ipynb ├── fit_implicit.py ├── images ├── results │ ├── max_noise │ │ ├── grid_metrics_diffcd.png │ │ ├── grid_metrics_igr.png │ │ ├── grid_metrics_neural-pull.png │ │ ├── grid_metrics_nksr.png │ │ ├── grid_metrics_siren0.033.png │ │ └── grid_metrics_siren0.33.png │ ├── medium_noise │ │ ├── grid_metrics_diffcd.png │ │ ├── grid_metrics_igr.png │ │ ├── grid_metrics_neural-pull.png │ │ ├── grid_metrics_nksr.png │ │ ├── grid_metrics_siren0.033.png │ │ └── grid_metrics_siren0.33.png │ └── no_noise │ │ ├── grid_metrics-diffcd.png │ │ ├── grid_metrics-nksr.png │ │ ├── grid_metrics_igr.png │ │ ├── grid_metrics_neural-pull.png │ │ ├── grid_metrics_siren0.033.png │ │ └── grid_metrics_siren0.33.png └── teaser.png ├── levelset_figures.ipynb ├── notebook_utils.py ├── notebook_utils_blender.py ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── test_methods.py └── test_training.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DiffCD: A Symmetric Differentiable Chamfer Distance for Neural Implicit Surface Fitting 2 | 3 | > **ECCV 2024** 4 | > 5 | > [Linus Härenstam-Nielsen](https://cvg.cit.tum.de/members/hael), [Lu Sang](https://sangluisme.github.io/), [Abhishek Saroha](https://cvg.cit.tum.de/members/saroha), [Nikita Araslanov](https://arnike.github.io/) and [Daniel Cremers](https://vision.in.tum.de/members/cremers) 6 | > 7 | > Technical University of Munich, Munich Center for Machine Learning 8 | > 9 | > [📄 Paper](https://arxiv.org/abs/2407.17058) 10 | 11 | ![teaser](images/teaser.png) 12 | 13 | This repository contains the official implementation of the methods and experiments from the paper [DiffCD: A Symmetric Differentiable Chamfer Distance for Neural Implicit Surface Fitting](https://arxiv.org/abs/2407.17058). 14 | 15 | **Abstract:** 16 | Neural implicit surfaces can be used to recover accurate 3D geometry from imperfect point clouds. In this work, we show that state-of-the-art techniques work by minimizing an approximation of a one-sided Chamfer distance. This shape metric is not symmetric, as it only ensures that the point cloud is near the surface but not vice versa. As a consequence, existing methods can produce inaccurate reconstructions with spurious surfaces. Although one approach against spurious surfaces has been widely used in the literature, we theoretically and experimentally show that it is equivalent to regularizing the surface area, resulting in over-smoothing. As a more appealing alternative, we propose DiffCD, a novel loss function corresponding to the symmetric Chamfer distance. In contrast to previous work, DiffCD also assures that the surface is near the point cloud, which eliminates spurious surfaces without the need for additional regularization. We experimentally show that DiffCD reliably recovers a high degree of shape detail, substantially outperforming existing work across varying surface complexity and noise levels. 17 | 18 | ```text 19 | @inproceedings{haerenstam2024diffcd, 20 |   title = {DiffCD: A Symmetric Differentiable Chamfer Distance for Neural Implicit Surface Fitting}, 21 |   author = {L Härenstam-Nielsen and L Sang and A Saroha and N Araslanov and D Cremers}, 22 |   booktitle = {European Conference on Computer Vision (ECCV)}, 23 |   year = {2024}, 24 |   eprinttype = {arXiv}, 25 |   eprint = {2407.17058}, 26 | } 27 | ``` 28 | 29 | ## 🛠️ Setup 30 | 31 | Install requirements 32 | 33 | ```bash 34 | pip install -r requirements.txt 35 | ``` 36 | 37 | Install jax version 0.4.14 matching your CUDA version as described [here](https://github.com/google/jax#pip-installation-gpu-cuda-installed-via-pip-easier). For example, for CUDA 12: 38 | 39 | ```bash 40 | pip install --upgrade "jax[cuda12_pip]==0.4.14" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 41 | ``` 42 | 43 | Other jax versions may also work, but have not been tested. 44 | 45 | ## 📏 Fit implicit function 46 | 47 | The main interface for surface fitting is the training script `fit_implicit.py`. It takes a `.npy` point cloud as input and generates an implicit surface stored as an orbax checkpoint. You can also optionally provide a ground truth `.ply` mesh as input, which is then used to compute evaluation metrics. 48 | The methods are also implemented so that they are easy to use in a standalone fashion, as sketched below.
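For instance, here is a minimal sketch of evaluating the IGR loss on a freshly initialized network (the random `points`/`samples`/`uniform` arrays are hypothetical stand-ins for the batches that `diffcd/samplers.py` and `fit_implicit.py` construct during real training):

```python
import jax
import jax.numpy as jnp
from diffcd.methods import IGR
from diffcd.networks import MLP

key = jax.random.PRNGKey(0)
model = MLP()
params = model.init(key, jnp.zeros(3))

# stand-in batches: point-cloud points, off-surface samples for the eikonal
# term, and uniform samples for the (here unweighted) surface-area term
k1, k2, k3 = jax.random.split(key, 3)
points = jax.random.uniform(k1, (128, 3), minval=-0.5, maxval=0.5)
samples = jax.random.uniform(k2, (128, 3), minval=-0.5, maxval=0.5)
uniform = jax.random.uniform(k3, (128, 3), minval=-0.5, maxval=0.5)

method = IGR(eikonal_weight=0.1)
loss, (dist_loss, eikonal_loss, area_loss) = method.get_loss(
    model.apply, params, points, samples, uniform)
grads = jax.grad(lambda p: method.get_loss(model.apply, p, points, samples, uniform)[0])(params)
```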
49 | 50 | The repository also includes re-implementations of the loss functions from [IGR](https://arxiv.org/abs/2002.10099), [Neural-Pull](https://arxiv.org/abs/2011.13495) and [SIREN](https://arxiv.org/abs/2006.09661). If you use those implementations for a scientific publication, please make sure to cite the corresponding papers. 51 | 52 | Command line arguments are generated using [tyro](https://brentyi.github.io/tyro/). Some example uses: 53 | 54 | ```bash 55 | # display command line options 56 | python fit_implicit.py -h 57 | 58 | # fit implicit function with diffcd 59 | python fit_implicit.py --output-dir outputs --dataset.path DATASET_PATH/bunny.xyz.npy 60 | 61 | # fit implicit function with diffcd and eikonal weight 0.5 62 | python fit_implicit.py --output-dir outputs --dataset.path DATASET_PATH/bunny.xyz.npy method:diff-cd --method.eikonal-weight 0.5 63 | 64 | # fit implicit function with IGR 65 | python fit_implicit.py --yaml-config configs/config_igr.yaml --output-dir outputs --dataset.path DATASET_PATH/bunny.xyz.npy 66 | ``` 67 | 68 | To log metrics to wandb, add `--wandb-project [project name] --wandb-entity [username]` to the command line arguments. 69 | 70 | ## 🔬 Analyzing the results 71 | 72 | The outputs of `fit_implicit.py` should look like the following: 73 | 74 | ```yaml 75 | checkpoints/ # orbax checkpoints per save step 76 |   0 77 |   1000 78 |   ... 79 |   40000 80 | meshes/ # extracted mesh per save step 81 |   mesh_0.ply 82 |   mesh_1000.ply 83 |   ... 84 |   mesh_40000.ply 85 | config.yaml # config used for reproducing results 86 | config.pickle # pickle file with the same config - yaml file loading can sometimes break when the config structure changes 87 | 88 | eval_metrics.csv # evaluation metrics per save step 89 | eval_metrics_final_40000.csv # evaluation metrics for final shape 90 | train_metrics.csv # training metrics 91 | 92 | train_points.npy # training points in normalized coordinates 93 | local_sigma.npy # local sigma used for generating sample points - for debugging purposes 94 | ``` 95 | 96 | We provide notebooks [evaluation.ipynb](evaluation.ipynb) and [levelset_figures.ipynb](levelset_figures.ipynb) for analyzing the results. They can be used to reproduce all tables and figures in the main paper (let us know if anything is missing!). 97 | 98 | The notebook [evaluation.ipynb](evaluation.ipynb) can also be used to generate a per-shape breakdown of the metrics: 99 | 100 | ![DiffCD shape grid](images/results/no_noise/grid_metrics-diffcd.png) 101 | 102 | The same grid for other methods can be found under [images/results](images/results). 103 | 104 | ### Model checkpoints 105 | 106 | The trained models and evaluation results from the paper can be downloaded via this [Google Drive link](https://drive.google.com/drive/folders/1JHbsQ2eicajG7VMi8YHmnyT4An0yAuvr?usp=sharing). 107 | 108 | The folder contains the output from `fit_implicit.py` for each input shape and parameter setting. 109 | Make sure to update the `base_dir` variable in `evaluation.ipynb` to match the download location. 110 | 111 | ### Set up 3D visualizations 112 | 113 | Generating 3D figures requires setting up a [blender-notebook kernel](https://github.com/cheng-chi/blender_notebook). 114 | First, download and install [Blender](https://www.blender.org/).
Then run: 115 | 116 | ```bash 117 | pip install blender-notebook blender-plots ipykernel ipywidgets ipympl 118 | blender_notebook install --blender-exec=BLENDER_INSTALL_PATH/blender-4.0.1-linux-x64/blender --kernel-name blender-diffcd 119 | ``` 120 | 121 | Modify `--blender-exec` to match your Blender install location. 122 | 123 | Finally, open [evaluation.ipynb](evaluation.ipynb) or [levelset_figures.ipynb](levelset_figures.ipynb) in a notebook editor (e.g. VS Code) and select the kernel `blender-diffcd`. 124 | This should launch the Blender UI as a separate window, which uses the same Python runtime as the notebook. 125 | -------------------------------------------------------------------------------- /configs/config_igr.yaml: -------------------------------------------------------------------------------- 1 | !dataclass:TrainingConfig 2 | checkpoint_options: !dataclass:CustomCheckpointManagerOptions 3 |   save_interval_steps: 1000 4 | method: !dataclass:IGR 5 |   eikonal_weight: 0.1 6 | dataset: !dataclass:PointCloud 7 |   path: '' 8 | output_dir: outputs 9 | model: !dataclass:MLP 10 |   activation_function: !enum:ActivationFunction SOFTPLUS 11 | -------------------------------------------------------------------------------- /configs/config_neural_pull.yaml: -------------------------------------------------------------------------------- 1 | !dataclass:TrainingConfig 2 | checkpoint_options: !dataclass:CustomCheckpointManagerOptions 3 |   save_interval_steps: 1000 4 | method: !dataclass:NeuralPull 5 |   sampling: !dataclass:SamplingConfig 6 |     samples_per_point: 50 7 | dataset: !dataclass:PointCloud 8 |   path: '' 9 | output_dir: outputs 10 | model: !dataclass:MLP 11 |   activation_function: !enum:ActivationFunction SOFTPLUS 12 | -------------------------------------------------------------------------------- /configs/config_siren0.033 copy.yaml: -------------------------------------------------------------------------------- 1 | !dataclass:TrainingConfig 2 | checkpoint_options: !dataclass:CustomCheckpointManagerOptions 3 |   save_interval_steps: 1000 4 | method: !dataclass:IGR 5 |   eikonal_weight: 0.1 6 |   alpha: 100 7 |   surface_area_samples: 5000 8 |   surface_area_weight: 0.033 9 | dataset: !dataclass:PointCloud 10 |   path: '' 11 | output_dir: outputs 12 | model: !dataclass:MLP 13 |   activation_function: !enum:ActivationFunction SOFTPLUS 14 | -------------------------------------------------------------------------------- /configs/config_siren0.33.yaml: -------------------------------------------------------------------------------- 1 | !dataclass:TrainingConfig 2 | checkpoint_options: !dataclass:CustomCheckpointManagerOptions 3 |   save_interval_steps: 1000 4 | method: !dataclass:IGR 5 |   eikonal_weight: 0.1 6 |   alpha: 100 7 |   surface_area_samples: 5000 8 |   surface_area_weight: 0.33 9 | dataset: !dataclass:PointCloud 10 |   path: '' 11 | output_dir: outputs 12 | model: !dataclass:MLP 13 |   activation_function: !enum:ActivationFunction SOFTPLUS 14 | -------------------------------------------------------------------------------- /diffcd/__init__.py: -------------------------------------------------------------------------------- 1 | from diffcd import closest_point, datasets, methods, networks, training, utils, evaluation 2 | -------------------------------------------------------------------------------- /diffcd/closest_point.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from diffcd.newton import newton_kkt, NewtonState, NewtonConfig 3 |
import flax 4 | import jax.numpy as jnp 5 | import jax 6 | 7 | def laplacian(f, params, query_point, z): 8 |     '''Lagrangian L(x, mu) = .5 * ||query_point - x||^2 + mu * f(x), with z = [x; mu].''' 9 |     x, mu = z[:-1], z[-1] 10 |     return .5 * sq_norm(x - query_point) + mu * f(params, x) 11 | 12 | @partial(jax.jit, static_argnames=("f", "newton_config")) 13 | def closest_point_newton(f, params, query_point, init_point, newton_config: NewtonConfig): 14 |     """Compute the point closest to query_point on the surface defined by f(params, x) = 0 using Newton's method.""" 15 |     f_val, g = jax.value_and_grad(f, argnums=1)(params, query_point) 16 | 17 |     if init_point is None: 18 |         # take initial step which will reach the surface if f is a perfect SDF 19 |         # this corresponds to a Newton step with x0=query_point and mu0=0, but avoids the hessian calculation. 20 |         x0 = query_point - g * f_val / sq_norm(g) 21 |         z0 = jnp.array([*x0, f_val / sq_norm(g)]) 22 |     else: 23 |         mu0 = jnp.linalg.norm(query_point - init_point) / jnp.linalg.norm(g) * jnp.sign((query_point - init_point)[0] * g[0]) 24 |         z0 = jnp.array([*init_point, mu0]) 25 | 26 |     z_kkt, newton_state = newton_kkt( 27 |         partial(laplacian, f), 28 |         newton_config, 29 |         jax.lax.stop_gradient(z0), 30 |         params, 31 |         query_point, 32 |     ) 33 |     valid_kkt = check_kkt( 34 |         newton_state, 35 |         f(params, query_point) 36 |     ) 37 |     valid = jnp.logical_and(valid_kkt, newton_state.converged) 38 |     return z_kkt[:-1], newton_state, valid 39 | 40 | @partial(jax.jit, static_argnames=("f", "newton_config")) 41 | def closest_point_newton_batch(f, params, query_points, newton_config: NewtonConfig): 42 |     """Compute closest point for a batch of query points, using the same config for each point.""" 43 |     return jax.vmap( 44 |         closest_point_newton, 45 |         in_axes=(None, None, 0, None, None), 46 |     )(f, params, query_points, None, newton_config) 47 | 48 | def get_curve_eigvals(g, H): 49 |     """Compute the eigenvalues of the n-by-n matrix H restricted to the subspace orthogonal to the n-dim vector g.""" 50 |     P = jnp.linalg.svd(jax.lax.stop_gradient(jnp.outer(g, g)))[0][:, 1:] 51 |     return jnp.linalg.eigvalsh(P.T @ H @ P) 52 | 53 | def check_kkt(state: NewtonState, f_query: float): 54 |     curve_eigvals = get_curve_eigvals(state.H[-1, :-1], state.H[:-1, :-1]) 55 |     is_local_min = curve_eigvals[0] > 0 56 |     valid_sign = jnp.sign(state.z[-1]) == jnp.sign(f_query) 57 |     return jnp.logical_and(is_local_min, valid_sign) 58 | 59 | def get_distance_derivative(f, theta, query_point, closest_point): 60 |     """Compute the gradient of x*(theta) = argmin_x ||x - xq||^2 s.t. f(theta, x) = 0 wrt.
theta.""" 61 |     H = jax.hessian(f, argnums=1)(theta, closest_point) 62 |     g_x = jax.grad(f, argnums=1)(theta, closest_point) 63 |     mu = (query_point - closest_point)[0] / g_x[0] 64 | 65 |     A = jnp.block([ 66 |         [jnp.eye(len(query_point)) + mu * H, g_x[:, None]], 67 |         [g_x, 0] 68 |     ]) 69 |     A_inv = jnp.linalg.inv(A)[:-1] 70 |     w = - A_inv.T @ (closest_point - query_point) 71 | 72 |     return jax.grad( 73 |         lambda theta: 2 * w[-1] * f(theta, closest_point) + 74 |         jax.jvp(partial(f, theta), (closest_point,), (w[:-1],))[1] 75 |     )(theta) 76 | 77 | def sq_norm(a, *args, **kwargs): 78 |     return (a ** 2).sum(*args, **kwargs) 79 | -------------------------------------------------------------------------------- /diffcd/datasets.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | from pathlib import Path 3 | from dataclasses import dataclass 4 | from typing import Literal, Union, Tuple, Optional 5 | import jax.numpy as jnp 6 | import jax.random as jrnd 7 | 8 | @dataclass 9 | class Circle: 10 |     n_points: int = 50 11 |     n_points_eval: int = 250 12 |     radius: float = 2. 13 |     sigma: float = 0.5 14 |     lower_bound: Tuple[float, float] = (-3., -3.) 15 |     upper_bound: Tuple[float, float] = (3., 3.) 16 |     n_dims: Literal[2] = 2 17 | 18 |     def get_data(self, key): 19 |         alphas = jnp.linspace(0, 2 * jnp.pi, self.n_points + 1)[:-1] 20 |         points = self.radius * jnp.array([jnp.cos(alphas), jnp.sin(alphas)]).T 21 |         points = points + jrnd.normal(key, points.shape) * self.sigma 22 |         return points 23 | 24 |     def get_data_eval(self, key): 25 |         alphas = jnp.linspace(0, 2 * jnp.pi, self.n_points_eval + 1)[:-1] 26 |         points = self.radius * jnp.array([jnp.cos(alphas), jnp.sin(alphas)]).T 27 |         return points 28 | 29 |     def get_train_eval_points(self, key): 30 |         key_train, key_eval = jrnd.split(key, 2) 31 |         train_points = self.get_data(key_train) 32 |         eval_points = self.get_data_eval(key_eval) 33 |         return train_points, eval_points 34 | 35 | 36 | @dataclass 37 | class PointCloud: 38 |     # Path to file containing point cloud. Either .npy with xyz coordinates, or .ply file with mesh to sample points from 39 |     path: Path 40 | 41 |     # Number of training points, defaults to all points for .npy files or n_vertices for .ply files 42 |     n_points: Optional[int] = None 43 | 44 |     # Standard deviation of Gaussian noise to add to each point 45 |     sigma: float = 0. 46 | 47 |     # If true, subtract center of bounding box from point cloud and then divide by maximum side length 48 |     auto_scale: bool = True 49 | 50 |     n_dims: Literal[3] = 3 51 |     _scale_factor = None 52 |     _center_point = None 53 | 54 |     def apply_normalization(self, points): 55 |         if len(points) > 0: 56 |             return (points - self._center_point) / self._scale_factor 57 |         else: 58 |             return points 59 | 60 |     def undo_normalization(self, points): 61 |         if len(points) > 0: 62 |             return points * self._scale_factor + self._center_point 63 |         else: 64 |             return points 65 | 66 |     def get_normalized_points(self, key): 67 |         extensions = ['.npy', '.xyz'] 68 |         if self.path.suffix in extensions: 69 |             point_cloud = jnp.load(self.path) 70 | 71 |             n_points = self.n_points if self.n_points is not None else len(point_cloud) 72 |             point_cloud = jrnd.choice(key, point_cloud, (n_points,), replace=False) 73 | 74 |             if self.auto_scale: 75 |                 lower, upper = point_cloud.min(axis=0), point_cloud.max(axis=0) 76 |                 self._center_point = (lower + upper) / 2 77 |                 self._scale_factor = (upper - lower).max() 78 |             else: 79 |                 self._center_point = jnp.zeros(3, dtype=jnp.float32) 80 |                 self._scale_factor = 1.
81 | 82 | point_cloud = self.apply_normalization(point_cloud) 83 | else: 84 | raise ValueError(f'File extension {self.path.suffix} not recognized for file {self.path}. Expected {extensions}') 85 | 86 | point_cloud += jrnd.normal(key, point_cloud.shape) * self.sigma 87 | return point_cloud 88 | 89 | 90 | @dataclass 91 | class EvaluationMesh: 92 | path: Optional[Path] = None 93 | n_samples: int = 30000 94 | 95 | _mesh = None 96 | 97 | @property 98 | def mesh(self): 99 | if (self._mesh is None) and (self.path is not None): 100 | self._mesh = trimesh.load_mesh(self.path) 101 | return self._mesh 102 | 103 | 104 | Datasets = Union[ 105 | Circle, 106 | PointCloud, 107 | ] 108 | -------------------------------------------------------------------------------- /diffcd/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from diffcd.evaluation import chamfer, contours, meshing -------------------------------------------------------------------------------- /diffcd/evaluation/chamfer.py: -------------------------------------------------------------------------------- 1 | 2 | """from https://github.com/otaheri/chamfer_distance """ 3 | 4 | from functools import partial 5 | import numpy as np 6 | from scipy.spatial import cKDTree as KDTree 7 | import trimesh 8 | 9 | 10 | __all__ = ['compute_chamfer', 'compute_hausdorff'] 11 | 12 | def one_sided_chamfer(points1, points2): 13 | """Calculate (arg)min_j \|points1[i] - points2[j]\| for each i.""" 14 | points2_kd_tree = KDTree(points2) 15 | distances12, closest_indices2 = points2_kd_tree.query(points1) 16 | return distances12, closest_indices2 17 | 18 | def normal_angle_deg(normals1, normals2): 19 | dot_products = np.einsum('...i, ...i -> ...', normalize(normals1), normalize(normals2)) 20 | return np.arccos(np.clip(dot_products, -1, 1)) * 180 / np.pi 21 | 22 | def normalize(a): 23 | return a / np.linalg.norm(a, axis=-1, keepdims=True) 24 | 25 | def compute_shape_metrics(dist_fn, metric_name, shape1: trimesh.Trimesh, shape2: trimesh.Trimesh, num_mesh_samples): 26 | """ 27 | Compute nearest-neighbor based metrics for two meshes or point clouds. 28 | 29 | distance metric: distance from each point in shape i to the closest point in shape j 30 | normal metric: angle between normal of each point in shape i and the normal of the closest point in shape j 31 | 32 | Each metric is aggregated over points using dist_fn which maps an array to a single value. 
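    For example, dist_fn=np.mean yields the Chamfer metrics and dist_fn=np.max the Hausdorff metrics (see the partials at the bottom of this file).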
33 |     """ 34 |     if (shape1 is None) or (shape2 is None) or (len(shape1.vertices) == 0) or (len(shape2.vertices) == 0): 35 |         distance_metric, distance_square_metric, normal_angle_metric = np.nan, np.nan, np.nan 36 |     else: 37 |         points1, face_indices1 = trimesh.sample.sample_surface(shape1, num_mesh_samples) 38 |         points2, face_indices2 = trimesh.sample.sample_surface(shape2, num_mesh_samples) 39 | 40 |         distances12, closest_indices2 = one_sided_chamfer(points1, points2) 41 |         distances21, closest_indices1 = one_sided_chamfer(points2, points1) 42 |         distance_metric = dist_fn([dist_fn(distances12), dist_fn(distances21)]) 43 |         distance_square_metric = dist_fn([dist_fn(distances12 ** 2), dist_fn(distances21 ** 2)]) 44 | 45 |         if (face_indices1 is not None) and (face_indices2 is not None): 46 |             normals1 = shape1.face_normals[face_indices1] 47 |             normals2 = shape2.face_normals[face_indices2] 48 |             normal_angle12 = normal_angle_deg(normals1, normals2[closest_indices2]) 49 |             normal_angle21 = normal_angle_deg(normals2, normals1[closest_indices1]) 50 |             normal_angle_metric = dist_fn([dist_fn(normal_angle12), dist_fn(normal_angle21)]) 51 | 52 |             # compute normal metric again with the normals of one of the meshes flipped, and select whichever is smallest 53 |             normal_angle12 = normal_angle_deg(normals1, -normals2[closest_indices2]) 54 |             normal_angle21 = normal_angle_deg(-normals2, normals1[closest_indices1]) 55 |             normal_angle_metric_flipped = dist_fn([dist_fn(normal_angle12), dist_fn(normal_angle21)]) 56 | 57 |             normal_angle_metric = min(normal_angle_metric, normal_angle_metric_flipped) 58 |         else: 59 |             normal_angle_metric = np.nan 60 |     return { 61 |         f'{metric_name}_distance': distance_metric, 62 |         f'{metric_name}_square_distance': distance_square_metric, 63 |         f'{metric_name}_normal_angle': normal_angle_metric, 64 |     } 65 | 66 | compute_chamfer = partial(compute_shape_metrics, np.mean, 'chamfer') 67 | compute_hausdorff = partial(compute_shape_metrics, np.max, 'hausdorff') 68 | -------------------------------------------------------------------------------- /diffcd/evaluation/contours.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from dataclasses import dataclass 3 | import contourpy 4 | 5 | @dataclass 6 | class Contour: 7 |     vertices: np.array 8 |     edges: np.array 9 | 10 |     def save(self, output_dir): 11 |         np.savez(output_dir, vertices=self.vertices, edges=self.edges) 12 | 13 |     @classmethod 14 |     def load(cls, contour_dir): 15 |         return Contour(**np.load(contour_dir)) 16 | 17 | def get_contour(inputs, sdf_values): 18 |     lines, codes = contourpy.contour_generator( 19 |         inputs[..., 0], inputs[..., 1], sdf_values, 20 |         name='mpl2014', corner_mask=True, 21 |         line_type=contourpy.LineType.SeparateCode, 22 |         fill_type=contourpy.FillType.OuterCode, 23 |         chunk_size=0 24 |     ).lines(0.)
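    # `lines` holds one vertex array per contour segment and `codes` the matching
    # matplotlib path codes; code 79 (CLOSEPOLY) marks a segment that forms a closed loop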
25 | 26 | edges, points = [], [] 27 | start_index = 0 28 | for segment_points, segment_codes in zip(lines, codes): 29 | segment_edges = np.vstack([ # TODO: handle case when line is closed 30 | np.arange(0, len(segment_points)-1), 31 | np.arange(1, len(segment_points)), 32 | ]).T + start_index 33 | points.append(segment_points) 34 | edges.append(segment_edges) 35 | if segment_codes[-1] == 79: 36 | edges.append(np.array([start_index, start_index + len(segment_points) - 1])) 37 | start_index += len(segment_points) 38 | return Contour(np.vstack(points), np.vstack(edges)) 39 | 40 | 41 | def get_sample_points(contour: Contour, n_samples, seed=0): 42 | distances = np.linalg.norm(contour.vertices[contour.edges[..., 1]] - contour.vertices[contour.edges[..., 0]], axis=-1) 43 | cumulative_distances = np.hstack([0., np.cumsum(distances)]) 44 | 45 | sample_distances = np.random.default_rng(seed).random(n_samples) * cumulative_distances[-1] 46 | edge_indices = np.searchsorted(cumulative_distances, sample_distances) - 1 47 | 48 | alphas = ((sample_distances - cumulative_distances[edge_indices]) / distances[edge_indices])[..., None] 49 | starts = contour.vertices[contour.edges[edge_indices, 0]] 50 | ends = contour.vertices[contour.edges[edge_indices, 1]] 51 | sample_points = starts * (1 - alphas) + ends * alphas 52 | 53 | diffs = ends - starts 54 | normals = np.array([-diffs[:, 1], diffs[:, 0]]).T 55 | normals = normals / np.linalg.norm(normals, axis=1, keepdims=True) 56 | 57 | return sample_points, normals -------------------------------------------------------------------------------- /diffcd/evaluation/meshing.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import flax 3 | import numpy as np 4 | from skimage.measure import marching_cubes 5 | import trimesh 6 | from dataclasses import dataclass 7 | from typing import Tuple 8 | from pathlib import Path 9 | 10 | @flax.struct.dataclass 11 | class Meshing: 12 | points_per_axis: int = 256 13 | lower_bound: Tuple[float, float, float] = (-.8, -.8, -.8) 14 | upper_bound: Tuple[float, float, float] = (.8, .8, .8) 15 | f_batch_size: int = 64 ** 3 16 | levelset: float = 0. 
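# Usage sketch (hypothetical names: `model` is a trained flax module such as
# diffcd.networks.MLP and `params` its variables from model.init):
#   from functools import partial
#   mesh = extract_mesh(Meshing(points_per_axis=128), partial(model.apply, params))
#   save_ply(mesh, 'outputs/meshes/mesh.ply')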
17 | 18 | def get_grid(lower, upper, n): 19 | axis_values = (jnp.linspace(lower[i], upper[i], n) for i in range(len(lower))) 20 | grid_values = jnp.meshgrid(*axis_values, indexing='ij') 21 | return jnp.stack(grid_values, axis=-1) 22 | 23 | def iter_grid(lower, upper, n, batch_size): 24 | n_dims = len(lower) 25 | axis_batch_size = max(1, int(batch_size / n ** (n_dims - 1))) 26 | 27 | axis_values = [jnp.linspace(lower[i], upper[i], n) for i in range(len(lower))] 28 | for i in range(0, n, axis_batch_size): 29 | grid_values = jnp.meshgrid(axis_values[0][i:i+axis_batch_size], *axis_values[1:], indexing='ij') 30 | yield jnp.stack(grid_values, axis=-1) 31 | 32 | def extract_mesh(config: Meshing, f) -> trimesh.Trimesh: 33 | lower, upper = config.lower_bound, config.upper_bound 34 | n = config.points_per_axis 35 | 36 | outputs_numpy = [] 37 | for inputs in iter_grid(lower, upper, n, config.f_batch_size): 38 | outputs_numpy.append(np.array(f(inputs))) 39 | outputs_numpy = np.concatenate(outputs_numpy).reshape(n, n, n) 40 | 41 | try: 42 | vertices, faces, normals, _ = marching_cubes( 43 | volume=outputs_numpy, 44 | level=config.levelset, 45 | spacing=( 46 | (upper[0] - lower[0]) / config.points_per_axis, 47 | (upper[1] - lower[1]) / config.points_per_axis, 48 | (upper[2] - lower[2]) / config.points_per_axis, 49 | ) 50 | ) 51 | vertices += np.array(lower)[None] 52 | except ValueError: 53 | print('marching cubes: no 0-level set found') 54 | vertices, faces, normals = np.array([]).reshape((0, 3)), [], None 55 | return trimesh.Trimesh(vertices, faces, vertex_normals=normals) 56 | 57 | def save_ply(mesh: trimesh.Trimesh, output_file): 58 | Path(output_file).parent.mkdir(exist_ok=True) 59 | with open(output_file, 'wb') as ply_file: 60 | ply_file.write(trimesh.exchange.ply.export_ply(mesh)) 61 | -------------------------------------------------------------------------------- /diffcd/methods.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import jax.random as jrnd 4 | from jax import jit, vmap 5 | from dataclasses import field 6 | from functools import partial 7 | from typing import Union, Literal, Optional 8 | import flax 9 | from scipy.spatial import cKDTree 10 | import numpy as np 11 | import trimesh 12 | 13 | from diffcd.evaluation import meshing 14 | from diffcd.closest_point import sq_norm, closest_point_newton, NewtonConfig 15 | from diffcd import samplers 16 | from diffcd.samplers import SamplingConfig, SurfaceSamplingConfig 17 | from diffcd.training import ShapeTrainState 18 | 19 | 20 | def any_nans(pytree): 21 | """Returns True if any leaf of the pytree contains a nan value.""" 22 | return jnp.array(jax.tree_util.tree_flatten(jax.tree_map(lambda a: jnp.isnan(a).any(), pytree))[0]).any() 23 | 24 | def valid_mean(array, valid_mask): 25 | return (array * valid_mask).sum() / valid_mask.sum() 26 | 27 | def safe_apply_grads(state, grads): 28 | nan_grads = any_nans(grads) 29 | state = jax.lax.cond(nan_grads, lambda: state, lambda: state.apply_gradients(grads=grads)) 30 | return state, nan_grads 31 | 32 | def soft_norm(x, eps=1e-12): 33 | """Use l2 for large values and squared l2 for small values to avoid grad=nan at x=0.""" 34 | return eps * (jnp.sqrt(sq_norm(x) / eps ** 2 + 1) - 1) 35 | 36 | def get_eikonal_loss(f, params, point): 37 | point_grad = jax.grad(f, argnums=1)(params, point) 38 | return (1 - soft_norm(point_grad)) ** 2 39 | 40 | def safe_normalize(x, eps=1e-12): 41 | x_norm = soft_norm(x) 42 | return x / 
jnp.array([x_norm, eps]).max() 43 | 44 | def _surface_point(f, params, point): 45 |     """Make point differentiable as a function of params, assuming f(params, point) = 0 and that the derivative of the point wrt. the parameters lies along the surface normal.""" 46 |     return point 47 | 48 | def grad_norm(f, inputs): 49 |     return jnp.linalg.norm(jax.grad(f)(inputs)) 50 | 51 | def surface_point_fwd(f, params, point): 52 |     return point, (point, params) 53 | 54 | def surface_point_bwd(f, info, tangent): 55 |     point, params = info 56 |     g_x = jax.grad(f, argnums=1)(params, point) 57 |     tg = tangent @ g_x / sq_norm(g_x) 58 |     g_params = jax.grad(f)(params, point) 59 |     return jax.tree_map(lambda g_param: -g_param * tg, g_params), tangent 60 | 61 | surface_point = jax.custom_vjp(_surface_point, nondiff_argnums=(0,)) 62 | surface_point.defvjp(surface_point_fwd, surface_point_bwd) 63 | 64 | @flax.struct.dataclass 65 | class SurfaceSamples: 66 |     points: jnp.array 67 |     closest_train_points: jnp.array 68 |     valid: jnp.array 69 | 70 | def implicit_distance(f, points, distance_metric): 71 |     f_values = f(points) 72 |     if distance_metric == 'squared_l2': 73 |         implicit_distance = (f_values ** 2).mean() 74 |     elif distance_metric == 'l2': 75 |         implicit_distance = jnp.abs(f_values).mean() 76 |     else: 77 |         raise ValueError(f'Unrecognized distance metric {distance_metric=}.') 78 |     return implicit_distance 79 | 80 | def distance_loss(point1, point2, distance_metric): 81 |     match distance_metric: 82 |         case 'l2': 83 |             loss = soft_norm(point1 - point2, eps=1e-12) 84 |         case 'squared_l2': 85 |             loss = sq_norm(point1 - point2) 86 |         case _: 87 |             raise ValueError(f'Unrecognized distance metric {distance_metric=}') 88 |     return loss 89 | 90 | @flax.struct.dataclass 91 | class DiffCDState: 92 |     point_cloud: jnp.array 93 |     local_sigma: jnp.array 94 |     point_cloud_kd_tree: cKDTree = flax.struct.field(pytree_node=False) 95 |     mesh_samples: jnp.array 96 | 97 | @flax.struct.dataclass 98 | class DiffCD: 99 |     newton_config: NewtonConfig = field(default_factory=lambda: NewtonConfig()) 100 |     p2s_loss: Literal['closest-point', 'implicit'] = 'implicit' 101 | 102 |     eikonal_weight: float = 0.1 103 |     s2p_weight: float = 1. 104 | 105 |     surface_area_weight: float = 0.
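    # sharpness of the surface-area term exp(-alpha * |f(x)|); see get_implicit_surface_area_loss below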
106 | alpha: float = 100 107 | surface_area_samples: int = 5000 108 | 109 | distance_metric: Literal['l2', 'squared_l2'] = 'l2' 110 | sampling: SamplingConfig = field(default_factory=lambda: SamplingConfig()) 111 | surface_sampling: SurfaceSamplingConfig = field(default_factory=lambda: SurfaceSamplingConfig()) 112 | 113 | def init_state(self, key, point_cloud, save_dir=None): 114 | local_sigma = samplers.compute_local_sigma(point_cloud, self.sampling.k) * self.sampling.local_sigma_scale 115 | if save_dir is not None: 116 | np.save(save_dir / 'local_sigma.npy', local_sigma) 117 | return DiffCDState(point_cloud, local_sigma, cKDTree(point_cloud), None) 118 | 119 | @partial(jit, static_argnames=("self", "apply_fn", "batch_size")) 120 | def get_surface_samples(self, key, apply_fn, params, mesh_samples, point_cloud, batch_size): 121 | # compute batch of surface samples and their closest point cloud points 122 | batch_mesh_samples, _ = samplers.sample_array(key, mesh_samples, batch_size) 123 | surface_samples, valid = samplers.generate_surface_samples( 124 | apply_fn, params, 125 | batch_mesh_samples, 126 | self.surface_sampling.newton, 127 | ) 128 | closest_train_points = vmap(point_cloud_closest_point, in_axes=(0, None))(surface_samples, point_cloud) 129 | return SurfaceSamples(surface_samples, closest_train_points, valid) 130 | 131 | def get_batch(self, train_state: ShapeTrainState, state: DiffCDState, key, batch_size): 132 | _, *batch = IGR.get_batch(self, train_state, state, key, batch_size) 133 | if self.s2p_weight != 0: 134 | if (train_state.step % self.surface_sampling.mesh_interval == 0): 135 | # update surface sampling mesh 136 | sampling_mesh = meshing.extract_mesh( 137 | self.surface_sampling.surface_meshing, 138 | f=partial(train_state.apply_fn, train_state.params), 139 | ) 140 | mesh_samples = trimesh.sample.sample_surface(sampling_mesh, self.surface_sampling.num_samples)[0] 141 | state = DiffCDState(state.point_cloud, state.local_sigma, state.point_cloud_kd_tree, mesh_samples) 142 | 143 | surface_samples = self.get_surface_samples( 144 | key, train_state.apply_fn, train_state.params, state.mesh_samples, state.point_cloud, batch_size 145 | ) 146 | else: 147 | surface_samples = SurfaceSamples(jnp.array([]), jnp.array([]), jnp.array([])) 148 | return state, *batch, surface_samples 149 | 150 | def closest_point_loss(self, f, params, query_point): 151 | closest_point, newton_state, valid = closest_point_newton(f, params, query_point, None, self.newton_config) 152 | loss = distance_loss(closest_point, query_point, self.distance_metric) 153 | 154 | return loss, newton_state, valid 155 | 156 | def surface_sample_loss(self, f, params, surface_sample_point, point_cloud_point): 157 | surface_sample_point = surface_point(f, params, surface_sample_point) 158 | loss = distance_loss(surface_sample_point, point_cloud_point, self.distance_metric) 159 | return loss 160 | 161 | def batch_loss(self, apply_fn, params, query_points, sample_points, uniform_samples, surface_samples: SurfaceSamples): 162 | metrics = {} 163 | if self.p2s_loss == 'closest-point': 164 | p2s_losses, newton_state, valid = jax.vmap(self.closest_point_loss, in_axes=(None, None, 0))(apply_fn, params, query_points) 165 | p2s_loss = valid_mean(p2s_losses, valid) 166 | metrics = { 167 | 'mean_n_valid': valid.mean(), 168 | 'mean_n_converged': newton_state.converged.mean(), 169 | 'mean_n_newton_steps': newton_state.step.mean(), 170 | } 171 | elif self.p2s_loss == 'implicit': 172 | p2s_loss = implicit_distance(partial(apply_fn, 
params), query_points, self.distance_metric) 173 |         else: 174 |             raise ValueError(f'Unrecognized p2s loss {self.p2s_loss=}.') 175 | 176 |         if self.s2p_weight != 0.: 177 |             s2p_losses = vmap( 178 |                 self.surface_sample_loss, in_axes=(None, None, 0, 0) 179 |             )(apply_fn, params, surface_samples.points, surface_samples.closest_train_points) 180 |             s2p_loss = valid_mean(s2p_losses, surface_samples.valid) 181 |         else: 182 |             s2p_loss = 0. 183 | 184 |         if self.eikonal_weight != 0.: 185 |             eikonal_loss = vmap(get_eikonal_loss, in_axes=(None, None, 0))( 186 |                 apply_fn, params, jnp.vstack([query_points, sample_points]) 187 |             ).mean() 188 |         else: 189 |             eikonal_loss = 0. 190 | 191 |         if self.surface_area_weight != 0.: 192 |             surface_area_loss = get_implicit_surface_area_loss(partial(apply_fn, params), uniform_samples, self.alpha).mean() 193 |         else: 194 |             surface_area_loss = 0. 195 | 196 |         loss = p2s_loss 197 |         loss += self.s2p_weight * s2p_loss 198 |         loss /= 1. + self.s2p_weight 199 |         loss += self.eikonal_weight * eikonal_loss 200 |         loss += self.surface_area_weight * surface_area_loss 201 | 202 |         metrics = { 203 |             'loss': loss, 204 |             'points_to_surface_loss': p2s_loss, 205 |             'surface_to_points_loss': s2p_loss, 206 |             'eikonal_loss': eikonal_loss, 207 |             'implicit_surface_area_loss': surface_area_loss, 208 |             'mean_n_valid_surface_sample': surface_samples.valid.mean(), 209 |             **metrics, 210 |         } 211 |         return loss, metrics 212 | 213 |     @partial(jit, static_argnames="self") 214 |     def step(self, state: ShapeTrainState, query_points, sample_points, uniform_samples, surface_points): 215 |         grads, metrics = jax.grad( 216 |             self.batch_loss, argnums=1, has_aux=True 217 |         )(state.apply_fn, state.params, query_points, sample_points, uniform_samples, surface_points) 218 |         state, nan_grads = safe_apply_grads(state, grads) 219 |         return metrics, state, nan_grads 220 | 221 | def get_implicit_surface_area_loss(f, points, alpha): 222 |     return jnp.exp(-alpha * jnp.abs(f(points))) 223 | 224 | @flax.struct.dataclass 225 | class IGRState: 226 |     point_cloud: jnp.array 227 |     local_sigma: jnp.array 228 | 229 | @flax.struct.dataclass 230 | class IGR: 231 |     eikonal_weight: float = 0.1 232 |     distance_metric: Literal['l2', 'squared_l2'] = 'l2' 233 |     sampling: SamplingConfig = field(default_factory=lambda: SamplingConfig()) 234 |     surface_area_weight: float = 0.
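    # a nonzero weight reproduces the SIREN-style off-surface penalty (see configs/config_siren0.33.yaml)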
235 | alpha: float = 100 236 | surface_area_samples: int = 5000 237 | 238 | def init_state(self, key, point_cloud, save_dir=None): 239 | local_sigma = samplers.compute_local_sigma(point_cloud, self.sampling.k) * self.sampling.local_sigma_scale 240 | if save_dir is not None: 241 | np.save(save_dir / 'local_sigma.npy', local_sigma) 242 | return IGRState(point_cloud, local_sigma) 243 | 244 | def get_loss(self, apply_fn, params, query_points, sample_points, uniform_sample_points): 245 | implicit_distance_loss = implicit_distance(partial(apply_fn, params), query_points, self.distance_metric) 246 | eikonal_loss = vmap(get_eikonal_loss, in_axes=(None, None, 0))(apply_fn, params, sample_points).mean() 247 | surface_area_loss = get_implicit_surface_area_loss(partial(apply_fn, params), uniform_sample_points, self.alpha).mean() 248 | loss = implicit_distance_loss + self.eikonal_weight * eikonal_loss + self.surface_area_weight * surface_area_loss 249 | return loss, (implicit_distance_loss, eikonal_loss, surface_area_loss) 250 | 251 | @partial(jit, static_argnames=("self", "batch_size")) 252 | def get_batch(self, train_state: ShapeTrainState, state: IGRState, key, batch_size): 253 | point_batch, batch_indices = samplers.sample_array(key, state.point_cloud, batch_size) 254 | 255 | key_local, key_global, key_area = jrnd.split(key, 3) 256 | local_samples = samplers.generate_local_samples( 257 | key_local, 258 | point_batch, 259 | self.sampling.samples_per_point, 260 | state.local_sigma[batch_indices], 261 | ) 262 | global_samples = samplers.generate_global_samples( 263 | key_global, 264 | lower=train_state.lower_bound, 265 | upper=train_state.upper_bound, 266 | n_points=self.sampling.global_samples if self.sampling.global_samples is not None else len(local_samples) // 8, 267 | n_dims=state.point_cloud.shape[-1], 268 | ) 269 | uniform_samples = samplers.generate_global_samples( 270 | key_area, 271 | lower=train_state.lower_bound, 272 | upper=train_state.upper_bound, 273 | n_points=self.surface_area_samples if self.surface_area_weight != 0 else 0, 274 | n_dims=state.point_cloud.shape[-1], 275 | ) 276 | return state, point_batch, jnp.concatenate([local_samples, global_samples]), uniform_samples 277 | 278 | @partial(jit, static_argnames="self") 279 | def step(self, state: ShapeTrainState, query_points, sample_points, uniform_sample_points): 280 | (loss, (distance_loss, eikonal_loss, surface_area_loss)), grads = jax.value_and_grad( 281 | self.get_loss, argnums=1, has_aux=True 282 | )(state.apply_fn, state.params, query_points, sample_points, uniform_sample_points) 283 | state, nan_grads = safe_apply_grads(state, grads) 284 | metrics = { 285 | 'loss': loss, 286 | f'implicit_{self.distance_metric}_loss': distance_loss, 287 | 'eikonal_loss': eikonal_loss, 288 | 'implicit_surface_area_loss': surface_area_loss, 289 | } 290 | return metrics, state, nan_grads 291 | 292 | def pull_point(f, point): 293 | f_value, f_grad = jax.value_and_grad(f)(point) 294 | return point - f_value * safe_normalize(f_grad) 295 | 296 | def point_cloud_closest_point(query_point, point_cloud): 297 | distances = sq_norm(point_cloud - query_point, axis=-1) 298 | return point_cloud[jnp.argmin(distances)] 299 | 300 | def get_neural_pull_loss(apply_fn, params, target_point, sample_point, distance_metric): 301 | pulled_point = pull_point(partial(apply_fn, params), sample_point) 302 | return distance_loss(target_point, pulled_point, distance_metric) 303 | 304 | @flax.struct.dataclass 305 | class NeuralPullState: 306 | sample_points: jnp.array 307 | 
target_points: jnp.array  # nearest point-cloud neighbor of each sample point, precomputed in init_state 308 | 309 | @flax.struct.dataclass 310 | class NeuralPull: 311 |     eikonal_weight: float = 0. 312 |     distance_metric: Literal['l2', 'squared_l2'] = 'l2' 313 |     sampling: SamplingConfig = field(default_factory=lambda: SamplingConfig(samples_per_point=10)) 314 | 315 |     def init_state(self, key, point_cloud, save_dir=None): 316 |         local_sigma = samplers.compute_local_sigma(point_cloud, self.sampling.k) * self.sampling.local_sigma_scale 317 |         sample_points = samplers.generate_local_samples( 318 |             key, 319 |             point_cloud, 320 |             self.sampling.samples_per_point, # specify total samples instead? 321 |             local_sigma, 322 |         ) 323 |         target_indices = cKDTree(point_cloud).query(sample_points)[1] 324 |         target_points = np.array(point_cloud[target_indices]) 325 |         if save_dir is not None: 326 |             np.save(save_dir / 'local_sigma.npy', local_sigma) 327 |             np.save(save_dir / 'sample_points.npy', sample_points) 328 |             np.save(save_dir / 'target_points.npy', target_points) 329 | 330 |         return NeuralPullState(sample_points, target_points) 331 | 332 |     @partial(jit, static_argnames=("self", "batch_size")) 333 |     def get_batch(self, train_state: ShapeTrainState, state: NeuralPullState, key, batch_size, *args): 334 |         batch_indices = jrnd.choice( 335 |             key, len(state.sample_points), 336 |             (min(batch_size, len(state.sample_points)),), 337 |             replace=False 338 |         ) 339 |         return state, state.sample_points[batch_indices], state.target_points[batch_indices] 340 | 341 |     def batch_loss(self, apply_fn, params, sample_points, target_points): 342 |         neural_pull_loss = vmap(get_neural_pull_loss, in_axes=(None, None, 0, 0, None))( 343 |             apply_fn, params, target_points, sample_points, self.distance_metric 344 |         ).mean() 345 | 346 |         if self.eikonal_weight != 0: 347 |             eikonal_loss = vmap(get_eikonal_loss, in_axes=(None, None, 0))(apply_fn, params, sample_points).mean() 348 |         else: 349 |             eikonal_loss = 0.
350 | 351 | return neural_pull_loss + self.eikonal_weight * eikonal_loss, (neural_pull_loss, eikonal_loss) 352 | 353 | @partial(jit, static_argnames="self") 354 | def step(self, train_state: ShapeTrainState, sample_points, target_points): 355 | (loss, (distance_loss, eikonal_loss)), grads = jax.value_and_grad(self.batch_loss, argnums=1, has_aux=True)( 356 | train_state.apply_fn, train_state.params, sample_points, target_points 357 | ) 358 | train_state, nan_grads = safe_apply_grads(train_state, grads) 359 | metrics = { 360 | 'loss': loss, 361 | 'distance_loss': distance_loss, 362 | 'eikonal_loss': eikonal_loss, 363 | } 364 | return metrics, train_state, nan_grads 365 | 366 | Methods = Union[ 367 | DiffCD, 368 | IGR, 369 | NeuralPull, 370 | ] 371 | -------------------------------------------------------------------------------- /diffcd/networks.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import jax.numpy as jnp 3 | from flax import linen as nn 4 | from flax.linen import relu, elu, linear 5 | from typing import Tuple 6 | import jax.random as jrnd 7 | import jax 8 | 9 | 10 | class ActivationFunction(enum.Enum): 11 | RELU = enum.auto() 12 | ELU = enum.auto() 13 | SIN = enum.auto() 14 | SOFTPLUS = enum.auto() 15 | 16 | 17 | def get_activation_function(activation_function: ActivationFunction): 18 | return { 19 | ActivationFunction.RELU: relu, 20 | ActivationFunction.ELU: elu, 21 | ActivationFunction.SIN: jnp.sin, 22 | ActivationFunction.SOFTPLUS: safe_softplus, 23 | }[activation_function] 24 | 25 | 26 | class MLP(nn.Module): 27 | layer_size: int = 256 28 | n_layers: int = 8 29 | skip_layers: Tuple[int, ...] = (4,) 30 | activation_function: ActivationFunction = ActivationFunction.SOFTPLUS 31 | geometry_init: bool = True 32 | init_radius: float = 0.5 33 | 34 | @nn.compact 35 | def __call__(self, x): 36 | input_x = x 37 | dim_x = x.shape[-1] 38 | 39 | activation_function = get_activation_function(self.activation_function) 40 | kernel_init = zero_mean if self.geometry_init else linear.default_kernel_init 41 | for i in range(self.n_layers): 42 | if i in self.skip_layers: 43 | x = jnp.concatenate([x, input_x], axis=-1) / jnp.sqrt(2) 44 | layer_size = self.layer_size if i + 1 not in self.skip_layers else self.layer_size - dim_x 45 | x = nn.Dense(features=layer_size, name=f'dense_{i}', kernel_init=kernel_init)(x) 46 | x = activation_function(x) 47 | kernel_init_final = non_zero_mean if self.geometry_init else linear.default_kernel_init 48 | bias_init_final = jax.nn.initializers.constant(-self.init_radius if self.geometry_init else 0.)
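    # Note (editor's addition, hedged): the initializers above appear to follow the
    # geometric initialization of Atzmon & Lipman (SAL): hidden kernels are zero-mean
    # with std sqrt(2 / fan_out) (zero_mean below), while the final layer uses a
    # near-constant kernel with mean sqrt(pi / fan_in) (non_zero_mean below) and bias
    # -init_radius, so that at initialization f(x) is approximately the sphere SDF
    # ||x|| - init_radius.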
49 | x = nn.Dense(features=1, name=f'dense_{self.n_layers}', kernel_init=kernel_init_final, bias_init=bias_init_final)(x) 50 | return x.squeeze() 51 | 52 | def non_zero_mean(key, shape, dtype=jnp.float32): 53 | normal_random_values = jrnd.normal(key, shape, dtype=dtype) 54 | mu = jnp.sqrt(jnp.pi) / jnp.sqrt(shape[0]) 55 | return mu + 0.00001 * normal_random_values 56 | 57 | def zero_mean(key, shape, dtype=jnp.float32): 58 | normal_random_values = jrnd.normal(key, shape, dtype=dtype) 59 | sigma = jnp.sqrt(2) / jnp.sqrt(shape[1]) 60 | return sigma * normal_random_values 61 | 62 | def softplus(x, beta=100): 63 | return jnp.logaddexp(0, beta * x) / beta 64 | 65 | def safe_softplus(x, beta=100): 66 | # revert to linear function for large inputs, same as pytorch 67 | return jnp.where(x * beta > 20, x, softplus(x, beta)) 68 | -------------------------------------------------------------------------------- /diffcd/newton.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import jax 3 | import jax.numpy as jnp 4 | import flax 5 | from typing import Optional 6 | 7 | @flax.struct.dataclass 8 | class NewtonState: 9 | # variables 10 | z: jnp.array 11 | 12 | # variables from previous steps (including current step) 13 | z_steps: jnp.array 14 | 15 | # current step index 16 | step: int 17 | 18 | # gradient of laplacian 19 | g: jnp.array 20 | 21 | # hessian of laplacian 22 | H: jnp.array 23 | 24 | # True if algorithm has converged 25 | converged: bool 26 | 27 | @flax.struct.dataclass 28 | class NewtonConfig: 29 | # maximum number of iterations 30 | max_iters: int = 4 31 | 32 | # converged when norm(grad(laplacian)(z)) <= eps 33 | grad_norm_eps: float = 1e-3 34 | 35 | # whether to stop when the convergence criterion is reached 36 | stop_when_converged: bool = False 37 | 38 | def newton_step(laplacian, config: NewtonConfig, args, state: NewtonState): 39 | z_next = state.z - jnp.linalg.lstsq(state.H, state.g)[0] 40 | g_next = jax.grad(laplacian, argnums=-1)(*args, z_next) 41 | H_next = jax.hessian(laplacian, argnums=-1)(*args, z_next) 42 | return NewtonState( 43 | z=z_next, 44 | z_steps=state.z_steps.at[state.step + 1].set(z_next), 45 | step=state.step + 1, 46 | g=g_next, 47 | H=H_next, 48 | converged=jnp.linalg.norm(g_next) < config.grad_norm_eps, 49 | ) 50 | 51 | def should_continue(config: NewtonConfig, state: NewtonState): 52 | return jnp.logical_and( 53 | state.step < config.max_iters, 54 | jnp.logical_not(jnp.logical_and(state.converged, config.stop_when_converged)), 55 | ) 56 | 57 | def _newton_kkt(laplacian, config: NewtonConfig, z0, *args): 58 | """Find KKT point z* where dL(params, z*)/dz = 0 using Newton's method.""" 59 | g0 = jax.grad(laplacian, argnums=-1)(*args, z0) 60 | H0 = jax.hessian(laplacian, argnums=-1)(*args, z0) 61 | init_state = NewtonState( 62 | z=z0, 63 | z_steps=jnp.repeat(jnp.zeros_like(z0)[None], config.max_iters+1, axis=0).at[0].set(z0), 64 | step=0, 65 | g=g0, 66 | H=H0, 67 | converged=jnp.linalg.norm(g0) < config.grad_norm_eps, 68 | ) 69 | final_state = jax.lax.while_loop( 70 | partial(should_continue, config), 71 | partial(newton_step, laplacian, config, args), 72 | init_state, 73 | ) 74 | return final_state.z, final_state 75 | 76 | def newton_kkt_fwd(laplacian, config: NewtonConfig, z0, *args): 77 | z, final_state = _newton_kkt(laplacian, config, z0, *args) 78 | return (z, final_state), (final_state, args) 79 | 80 | def newton_kkt_bwd(laplacian, config: NewtonConfig, info, tangent): 81 | final_state, args = info 82 | Hinvt
= jnp.linalg.lstsq(final_state.H, tangent[0])[0] 83 | 84 | # -tangent @ Hinv @ d^2L/d(theta)dz 85 | jvp = jax.grad( 86 | lambda args: jax.jvp(partial(laplacian, *args), (final_state.z,), (-Hinvt,))[1] 87 | )(args) 88 | 89 | # gradient wrt z0 is 0 90 | return (jnp.zeros_like(final_state.z), *jvp) 91 | 92 | newton_kkt = jax.custom_vjp(_newton_kkt, nondiff_argnums=(0, 1)) 93 | newton_kkt.defvjp(newton_kkt_fwd, newton_kkt_bwd) 94 | -------------------------------------------------------------------------------- /diffcd/samplers.py: -------------------------------------------------------------------------------- 1 | from scipy.spatial import cKDTree 2 | from functools import partial 3 | import numpy as np 4 | import jax 5 | from typing import Optional 6 | import jax.random as jrnd 7 | import jax.numpy as jnp 8 | import flax 9 | import trimesh 10 | from dataclasses import field 11 | 12 | from diffcd.evaluation import meshing 13 | from diffcd.newton import NewtonConfig 14 | from diffcd.closest_point import closest_point_newton_batch 15 | 16 | 17 | @flax.struct.dataclass 18 | class SamplingConfig: 19 | # Number of points to sample around each query point, standard deviation determined by local sigma 20 | samples_per_point: int = 1 21 | 22 | # Number of neighbours to use for local sigma calculation 23 | k: int = 50 24 | 25 | # Number of global samples 26 | global_samples: Optional[int] = None 27 | 28 | # scale to apply to local sigma 29 | local_sigma_scale: float = .2 30 | 31 | @flax.struct.dataclass 32 | class SurfaceSamplingConfig: 33 | # number of mesh samples to generate at each mesh interval 34 | num_samples: int = 30000 35 | 36 | # number of iterations between each recomputation of the surface mesh, which is used for initializing samples 37 | mesh_interval: int = 1000 38 | 39 | # config for closest point calculation, only max_iters is used 40 | newton: NewtonConfig = field(default_factory=lambda: NewtonConfig()) 41 | 42 | # config for mesh computation 43 | surface_meshing: meshing.Meshing = field(default_factory=lambda: meshing.Meshing()) 44 | 45 | @partial(jax.jit, static_argnames=("n_points", "n_dims")) 46 | def generate_global_samples(key, lower, upper, n_points, n_dims): 47 | return jrnd.uniform( 48 | key, 49 | shape=(n_points, n_dims), 50 | minval=jnp.array(lower), 51 | maxval=jnp.array(upper), 52 | ) 53 | 54 | @partial(jax.jit, static_argnames="samples_per_point") 55 | def generate_local_samples(key, query_points, samples_per_point, local_sigma): 56 | num_points, dims = query_points.shape 57 | noise = jrnd.normal(key, (num_points, samples_per_point, dims)) 58 | query_samples = query_points[:, None, :] + noise * local_sigma[:, None, None] 59 | return jnp.reshape(query_samples, (-1, dims)) 60 | 61 | @partial(jax.jit, static_argnames='n_samples') 62 | def sample_array(key, array, n_samples): 63 | sample_indices = jrnd.choice( 64 | key, len(array), 65 | (min(n_samples, len(array)),), 66 | replace=False, 67 | ) 68 | return array[sample_indices], sample_indices 69 | 70 | def compute_local_sigma(points, k): 71 | if k >= len(points): 72 | raise ValueError(f"Cannot find {k=} neighbours with {points.shape=}") 73 | 74 | if k == 0: 75 | return np.zeros(len(points)) 76 | sigmas = [] 77 | ptree = cKDTree(points) 78 | 79 | for points_batch in np.array_split(points, 100, axis=0): 80 | distances = ptree.query(points_batch, k + 1) 81 | sigmas.append(distances[0][:, -1]) 82 | return np.concatenate(sigmas) 83 | 84 | @flax.struct.dataclass 85 | class DescentState: 86 | i: int 87 | x: jnp.array 88 | 89 | def
step(apply_fn, params, state): 90 | f, g = jax.value_and_grad(apply_fn, argnums=1)(params, state.x) 91 | return DescentState(state.i+1, state.x - f * g / jnp.linalg.norm(g)) 92 | 93 | @partial(jax.jit, static_argnames=["apply_fn", "n_steps"]) 94 | def sdf_descent(apply_fn, params, query_point, n_steps): 95 | """Compute surface point by iterating x_{k+1} = x_k - f(params, x_k)g(params, x_k)/||g(params, x_k)||""" 96 | state = jax.lax.while_loop( 97 | lambda state: state.i < n_steps, 98 | partial(step, apply_fn, params), 99 | DescentState(0, query_point), 100 | ) 101 | f = apply_fn(params, state.x) 102 | return state.x, jnp.abs(f) < 1e-3 103 | 104 | @partial(jax.jit, static_argnames=["f", "newton_config"]) 105 | def generate_surface_samples(f, params, mesh_samples, newton_config: NewtonConfig): 106 | """Generate samples from the implicit surface defined by f(params, x) = 0 by sampling from an approximating mesh and computing a nearby surface point.""" 107 | surface_points, valid = jax.vmap(sdf_descent, in_axes=(None, None, 0, None))( 108 | f, params, mesh_samples, newton_config.max_iters, 109 | ) 110 | return surface_points, valid 111 | -------------------------------------------------------------------------------- /diffcd/training.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | from flax.training.train_state import TrainState 3 | import flax 4 | import jax.numpy as jnp 5 | 6 | class ShapeTrainState(TrainState): 7 | # train state that also includes the upper/lower bound for function inputs 8 | lower_bound: jnp.array = flax.struct.field(pytree_node=False) 9 | upper_bound: jnp.array = flax.struct.field(pytree_node=False) 10 | -------------------------------------------------------------------------------- /diffcd/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import trimesh 3 | import json 4 | import dataclasses 5 | import enum 6 | from pathlib import PosixPath, Path 7 | import pandas as pd 8 | 9 | class EnhancedJSONEncoder(json.JSONEncoder): 10 | def default(self, o): 11 | if dataclasses.is_dataclass(o): 12 | return dataclasses.asdict(o) 13 | if isinstance(o, PosixPath): 14 | return str(o) 15 | if isinstance(o, enum.Enum): 16 | return o.name 17 | if isinstance(o, frozenset): 18 | return list(o) 19 | return super().default(o) 20 | 21 | def print_config(config): 22 | for key, value in config.__dict__.items(): 23 | print(f'\033[96m{key}\033[0m: {value}') 24 | 25 | def save_metrics(metrics_dict, output_dir: Path, filename: str): 26 | pd.DataFrame(metrics_dict).to_csv(output_dir / (filename + '.csv')) 27 | 28 | def config_to_json(config): 29 | return json.loads(json.dumps(config, cls=EnhancedJSONEncoder)) 30 | 31 | def load_mesh(file_name: Path, normalize: bool=True): 32 | extension = file_name.suffix 33 | if extension == ".npz" or extension == ".npy": 34 | point_set = np.load(file_name).astype(np.float32) 35 | mesh = trimesh.points.PointCloud(point_set) 36 | elif extension == ".ply": 37 | mesh = trimesh.load(file_name, extension) 38 | else: 39 | raise NotImplementedError(f"File extension {extension} not supported") 40 | 41 | center = np.zeros(mesh.vertices.shape[-1]) # default center when normalize=False 42 | if normalize: 43 | center = np.mean(mesh.vertices, axis=0) 44 | mesh.vertices = mesh.vertices - np.expand_dims(center, axis=0) 45 | return mesh, center -------------------------------------------------------------------------------- /fit_implicit.py:
-------------------------------------------------------------------------------- 1 | import os 2 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 3 | 4 | from pathlib import Path 5 | from dataclasses import dataclass, field 6 | from typing import Optional, Callable 7 | import tyro 8 | from tqdm.auto import tqdm 9 | import jax 10 | from jax import numpy as jnp 11 | from jax import random as jrnd 12 | from datetime import datetime 13 | import pickle 14 | import time 15 | from orbax.checkpoint import CheckpointManagerOptions, CheckpointManager, PyTreeCheckpointer 16 | import optax 17 | from flax.training import orbax_utils 18 | import numpy as np 19 | import wandb 20 | from functools import partial 21 | import pandas as pd 22 | import subprocess 23 | from trimesh import Trimesh 24 | 25 | from diffcd import training, utils, methods, datasets, networks 26 | from diffcd.evaluation import chamfer, meshing 27 | 28 | # override default checkpoint manager options 29 | @dataclass 30 | class CustomCheckpointManagerOptions(CheckpointManagerOptions): 31 | save_interval_steps: int = 1000 32 | 33 | # name of metric to use for checkpointing 34 | best_metric: Optional[str] = None 35 | 36 | def replace_none(value, replace_value): 37 | """Replaces `value` with `replace_value` if `value` is None.""" 38 | return replace_value if value is None else value 39 | 40 | @dataclass 41 | class TrainingConfig: 42 | # dataset config 43 | dataset: datasets.PointCloud 44 | 45 | # Path to directory where experiment results will be saved 46 | output_dir: Path 47 | 48 | # Config for implicit function f(theta, x) 49 | model: networks.MLP 50 | 51 | # Config for model checkpointing 52 | checkpoint_options: CustomCheckpointManagerOptions 53 | 54 | # Config for ground truth mesh for computing shape metrics 55 | gt_mesh: datasets.EvaluationMesh = field(default_factory=lambda: datasets.EvaluationMesh()) 56 | 57 | # Config for converting the estimated SDF to a mesh 58 | eval_meshing: meshing.Meshing = field(default_factory=lambda: meshing.Meshing()) 59 | 60 | # Number of points per axis to use for meshing final shape 61 | final_mesh_points_per_axis: int = 512 62 | 63 | # Name of current experiment (a folder with this name will be created in output_dir) 64 | experiment_name: str = "experiment" 65 | 66 | # Whether to append a timestamp to the experiment name 67 | with_timestamp: bool = True 68 | 69 | # Path to yaml config file with default settings 70 | yaml_config: Optional[Path] = None 71 | 72 | # Whether to copy the datasets to the output directory.
73 | copy_datasets: bool = True 74 | 75 | # Whether to save a .ply file at each evaluation step (only for 3D datasets) 76 | save_ply: bool = True 77 | 78 | learning_rate: float = 1e-3 79 | learning_rate_warmup: int = 1000 80 | 81 | batch_size: int = 5000 82 | n_batches: int = 40000 83 | rng_seed: int = 0 84 | 85 | # wandb logging settings, override wandb_project to enable logging 86 | wandb_project: Optional[str] = None 87 | wandb_entity: str = 'dcp-sdf' 88 | wandb_name: Optional[str] = None 89 | 90 | method: methods.Methods = field(default_factory=lambda: methods.DiffCD()) 91 | 92 | @property 93 | def experiment_dir(self): 94 | return self.output_dir / self.experiment_name 95 | 96 | def save_config(config: TrainingConfig, output_dir: Path, name: str='config'): 97 | with open(output_dir / f'{name}.yaml', 'w') as yaml_file: 98 | yaml_file.write(tyro.extras.to_yaml(config)) 99 | 100 | 101 | # save as pickle as well since yaml loading can break between versions 102 | with open(output_dir / f'{name}.pickle', 'wb') as pickle_file: 103 | pickle.dump(config, pickle_file) 104 | 105 | def load_config(experiment_dir: Path, name: str='config'): 106 | try: 107 | with open(experiment_dir / f'{name}.yaml', 'r') as yaml_file: 108 | return tyro.extras.from_yaml(TrainingConfig, yaml_file) 109 | except Exception as e: 110 | print(f'WARNING: failed to load yaml config from {experiment_dir} due to "{e}", probably a result of a version mismatch. Loading pickle file instead.') 111 | with open(experiment_dir / f'{name}.pickle', 'rb') as pickle_file: 112 | return pickle.load(pickle_file) 113 | 114 | def check_best(metrics: list[dict], latest_metrics: dict, metric_name: Optional[str]): 115 | if (metric_name is None) or (len(metrics) == 0): 116 | return True 117 | else: 118 | return latest_metrics[metric_name] < min([m[metric_name] for m in metrics]) 119 | 120 | def get_gpu_memory(): 121 | command = "nvidia-smi --query-gpu=memory.used --format=csv" 122 | memory_used_info = subprocess.check_output(command.split()).decode('ascii').split('\n')[:-1][1:] 123 | memory_used_values = {f'gpu{i}': int(x.split()[0]) for i, x in enumerate(memory_used_info)} 124 | return memory_used_values 125 | 126 | def eval(gt_mesh: Trimesh, n_samples: int, meshing_config: meshing.Meshing, f: Callable, transform: Callable, batch_index: int): 127 | """ 128 | Evaluate implicit surface 129 | 130 | Args: 131 | gt_mesh: ground truth mesh 132 | n_samples: number of surface samples to use for metric calculations 133 | meshing_config: config for converting implicit surface to a mesh 134 | f: function defining implicit surface via f(x) = 0 135 | transform: transform to apply to vertices of extracted mesh 136 | batch_index: index of current batch 137 | """ 138 | estimated_mesh = meshing.extract_mesh(meshing_config, f) 139 | estimated_mesh.vertices = transform(estimated_mesh.vertices) 140 | 141 | chamfer_metrics = chamfer.compute_chamfer(gt_mesh, estimated_mesh, n_samples) 142 | hausdorff_metrics = chamfer.compute_hausdorff(gt_mesh, estimated_mesh, n_samples) 143 | metrics = { 144 | 'step': batch_index, 145 | **chamfer_metrics, 146 | **hausdorff_metrics, 147 | **get_gpu_memory(), 148 | } 149 | return metrics, estimated_mesh 150 | 151 | 152 | def save_checkpoint(checkpoint_manager, train_state, checkpoint_info, save_args, batch_index): 153 | checkpoint_manager.save( 154 | step=batch_index, 155 | items={'model': train_state, **checkpoint_info}, 156 | save_kwargs={'save_args': save_args}, 157 | force=True, 158 | ) 159 | 160 | def cos_with_warmup(init_lr,
warm_up, max_iters, step): 161 | lr = jnp.where(step < warm_up, step / warm_up, 0.5 * (jnp.cos((step - warm_up)/(max_iters - warm_up) * jnp.pi) + 1)) 162 | return lr * init_lr 163 | 164 | def run(config: TrainingConfig): 165 | '''Run training''' 166 | if config.yaml_config is not None: 167 | print(f"\033[92mYAML config {config.yaml_config} provided.") 168 | with open(config.yaml_config, 'r') as yaml_file: 169 | defaults = tyro.extras.from_yaml(TrainingConfig, yaml_file) 170 | config = tyro.cli(TrainingConfig, default=defaults) 171 | 172 | timestamp_str = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") if config.with_timestamp else '' 173 | config.experiment_name += timestamp_str 174 | experiment_dir = config.output_dir / config.experiment_name 175 | 176 | print(f'---\033[93m{config.experiment_name}\033[0m---') 177 | print(f'\033[96mexperiment dir\033[0m: {config.experiment_dir}') 178 | utils.print_config(config) 179 | 180 | config.experiment_dir.mkdir(parents=True, exist_ok=False) 181 | save_config(config, config.experiment_dir) 182 | 183 | np.random.seed(config.rng_seed) 184 | key = jrnd.PRNGKey(config.rng_seed) 185 | train_points = config.dataset.get_normalized_points(key) 186 | 187 | init_input = jnp.ones(config.dataset.n_dims) 188 | 189 | key, = jrnd.split(key, 1) 190 | params = config.model.init(key, init_input) 191 | 192 | train_metrics, eval_metrics = [], [] 193 | checkpoint_manager = CheckpointManager( 194 | directory=config.experiment_dir.resolve() / 'checkpoints', 195 | checkpointers=PyTreeCheckpointer(), 196 | options=config.checkpoint_options, 197 | ) 198 | 199 | lr_function = partial(cos_with_warmup, config.learning_rate, config.learning_rate_warmup, config.n_batches) 200 | 201 | train_state = training.ShapeTrainState.create( 202 | apply_fn=config.model.apply, 203 | tx=optax.adam(learning_rate=lr_function), 204 | params=params, 205 | lower_bound=config.eval_meshing.lower_bound, 206 | upper_bound=config.eval_meshing.upper_bound, 207 | ) 208 | checkpoint_info = { 209 | 'data': [init_input], 210 | 'scale_factor': config.dataset._scale_factor, 211 | 'center_point': config.dataset._center_point, 212 | } 213 | save_args = orbax_utils.save_args_from_target({'model': train_state, **checkpoint_info}) 214 | 215 | if config.wandb_project is not None: 216 | wandb.init( 217 | project=config.wandb_project, 218 | config=utils.config_to_json(config), 219 | name=replace_none(config.wandb_name, config.experiment_name), 220 | entity=config.wandb_entity, 221 | ) 222 | 223 | key, = jrnd.split(key, 1) 224 | method_state = config.method.init_state(key, train_points, config.experiment_dir if config.copy_datasets else None) 225 | if config.copy_datasets: 226 | np.save(config.experiment_dir / 'train_points.npy', train_points) 227 | 228 | apply_fn = jax.jit(train_state.apply_fn) 229 | 230 | start_time = time.time() 231 | for batch_index in tqdm(range(config.n_batches)): 232 | if checkpoint_manager.should_save(batch_index): 233 | save_checkpoint(checkpoint_manager, train_state, checkpoint_info, save_args, batch_index) 234 | utils.save_metrics(train_metrics, experiment_dir, 'train_metrics') 235 | 236 | metrics, estimated_mesh = eval(config.gt_mesh.mesh, config.gt_mesh.n_samples, config.eval_meshing, partial(apply_fn, train_state.params), config.dataset.undo_normalization, batch_index) 237 | if config.save_ply: 238 | meshing.save_ply(estimated_mesh, config.experiment_dir / f'meshes/mesh_{batch_index}.ply') 239 | if config.wandb_project is not None: 240 | wandb.log({'eval': metrics}) 241 | 242 | 
eval_metrics.append({**metrics, 'time': time.time() - start_time}) 243 | utils.save_metrics(eval_metrics, experiment_dir, 'eval_metrics') 244 | 245 | 246 | # training step 247 | train_step_time = time.time() 248 | key, = jrnd.split(key, 1) 249 | method_state, *batch = config.method.get_batch(train_state, method_state, key, config.batch_size) 250 | batch_metrics, train_state, nan_grads = config.method.step(train_state, *batch) 251 | train_step_time = time.time() - train_step_time 252 | 253 | # stop if there were nans in gradients and save state for debugging 254 | if nan_grads: 255 | save_checkpoint(checkpoint_manager, train_state, checkpoint_info, save_args, batch_index) 256 | np.save(config.experiment_dir / 'debug_key.npy', key) 257 | raise ValueError("NaN encountered in gradients. Checkpoint saved for debugging.") 258 | batch_metrics = { 259 | 'step': batch_index, 260 | **batch_metrics, 261 | 'train_step_time': train_step_time, 262 | 'learning_rate': lr_function(train_state.step), 263 | 'time': time.time() - start_time, 264 | } 265 | train_metrics.append(batch_metrics) 266 | if config.wandb_project is not None: 267 | wandb.log({'train': batch_metrics}) 268 | 269 | print('\033[92mdone!\033[0m saving metrics...') 270 | batch_index += 1 271 | save_checkpoint(checkpoint_manager, train_state, checkpoint_info, save_args, batch_index) 272 | utils.save_metrics(train_metrics, experiment_dir, 'train_metrics') 273 | 274 | metrics, estimated_mesh = eval(config.gt_mesh.mesh, config.gt_mesh.n_samples, config.eval_meshing, partial(apply_fn, train_state.params), config.dataset.undo_normalization, batch_index) 275 | 276 | if config.save_ply: 277 | meshing.save_ply(estimated_mesh, config.experiment_dir / f'meshes/mesh_{batch_index}.ply') 278 | if config.wandb_project is not None: 279 | wandb.log({'eval': metrics}) 280 | 281 | eval_metrics.append({**metrics, 'time': time.time() - start_time}) 282 | utils.save_metrics(eval_metrics, experiment_dir, 'eval_metrics') 283 | 284 | 285 | # do final eval with higher-resolution marching cubes 286 | final_mesh_config = meshing.Meshing( 287 | config.final_mesh_points_per_axis, 288 | config.eval_meshing.lower_bound, 289 | config.eval_meshing.upper_bound, 290 | ) 291 | final_metrics, final_mesh = eval(config.gt_mesh.mesh, config.gt_mesh.n_samples, final_mesh_config, partial(apply_fn, train_state.params), config.dataset.undo_normalization, batch_index) 292 | meshing.save_ply(final_mesh, config.experiment_dir / f'mesh_final_{batch_index}.ply') 293 | pd.DataFrame([final_metrics]).to_csv(config.experiment_dir / f'eval_metrics_final_{batch_index}.csv') 294 | 295 | if __name__ == '__main__': 296 | run(tyro.cli(TrainingConfig)) -------------------------------------------------------------------------------- /images/results/max_noise/grid_metrics_diffcd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/images/results/max_noise/grid_metrics_diffcd.png -------------------------------------------------------------------------------- /images/results/max_noise/grid_metrics_igr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/images/results/max_noise/grid_metrics_igr.png -------------------------------------------------------------------------------- /images/results/max_noise/grid_metrics_neural-pull.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/images/results/max_noise/grid_metrics_neural-pull.png -------------------------------------------------------------------------------- /images/results/max_noise/grid_metrics_nksr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/images/results/max_noise/grid_metrics_nksr.png -------------------------------------------------------------------------------- /images/results/max_noise/grid_metrics_siren0.033.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/images/results/max_noise/grid_metrics_siren0.033.png -------------------------------------------------------------------------------- /images/results/max_noise/grid_metrics_siren0.33.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/images/results/max_noise/grid_metrics_siren0.33.png -------------------------------------------------------------------------------- /images/results/medium_noise/grid_metrics_diffcd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/images/results/medium_noise/grid_metrics_diffcd.png -------------------------------------------------------------------------------- /images/results/medium_noise/grid_metrics_igr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/images/results/medium_noise/grid_metrics_igr.png -------------------------------------------------------------------------------- /images/results/medium_noise/grid_metrics_neural-pull.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/images/results/medium_noise/grid_metrics_neural-pull.png -------------------------------------------------------------------------------- /images/results/medium_noise/grid_metrics_nksr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/images/results/medium_noise/grid_metrics_nksr.png -------------------------------------------------------------------------------- /images/results/medium_noise/grid_metrics_siren0.033.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/images/results/medium_noise/grid_metrics_siren0.033.png -------------------------------------------------------------------------------- /images/results/medium_noise/grid_metrics_siren0.33.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/images/results/medium_noise/grid_metrics_siren0.33.png -------------------------------------------------------------------------------- 
/images/results/no_noise/grid_metrics-diffcd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/images/results/no_noise/grid_metrics-diffcd.png -------------------------------------------------------------------------------- /images/results/no_noise/grid_metrics-nksr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/images/results/no_noise/grid_metrics-nksr.png -------------------------------------------------------------------------------- /images/results/no_noise/grid_metrics_igr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/images/results/no_noise/grid_metrics_igr.png -------------------------------------------------------------------------------- /images/results/no_noise/grid_metrics_neural-pull.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/images/results/no_noise/grid_metrics_neural-pull.png -------------------------------------------------------------------------------- /images/results/no_noise/grid_metrics_siren0.033.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/images/results/no_noise/grid_metrics_siren0.033.png -------------------------------------------------------------------------------- /images/results/no_noise/grid_metrics_siren0.33.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/images/results/no_noise/grid_metrics_siren0.33.png -------------------------------------------------------------------------------- /images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/images/teaser.png -------------------------------------------------------------------------------- /notebook_utils.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | import numpy as np 3 | 4 | def srgb_to_linearrgb(c): 5 | if c < 0: return 0 6 | elif c < 0.04045: return c/12.92 7 | else: return ((c+0.055)/1.055)**2.4 8 | 9 | 10 | def hex_to_rgb(h,alpha=1): 11 | # source: https://blender.stackexchange.com/questions/153094/blender-2-8-python-how-to-set-material-color-using-hex-value-instead-of-rgb 12 | r = (h & 0xff0000) >> 16 13 | g = (h & 0x00ff00) >> 8 14 | b = (h & 0x0000ff) 15 | return tuple([srgb_to_linearrgb(c/0xff) for c in (r,g,b)] + [alpha]) 16 | 17 | def pprint(array, precision=4): 18 | with np.printoptions( 19 | precision=precision, 20 | suppress=True, 21 | ): 22 | print(array) 23 | 24 | def get_range(eps=0.1, n=100, symmetric=False, eps_negative=None): 25 | if eps_negative is None: 26 | eps_negative = eps 27 | taus = np.linspace((-1 if symmetric else 0)-eps_negative, 1+eps, n) 28 | taus = np.insert(taus, np.searchsorted(taus, [0, 1]), [0, 1]) 29 | return taus 30 | 31 | def get_grid(min, max, n=100): 32 | import jax.numpy as jnp 33 | 
if isinstance(n, int): 34 | n = [n for _ in range(len(min))] 35 | axis_values = (jnp.linspace(min[i], max[i], n[i]) for i in range(len(min))) 36 | grid_values = jnp.meshgrid(*axis_values, indexing='ij') 37 | return jnp.squeeze(jnp.stack(grid_values, axis=-1)) 38 | 39 | def get_bins(n, *values): 40 | all_values = np.hstack(values) 41 | return np.linspace(all_values.min(), all_values.max(), n) 42 | 43 | def lighten_color(color, amount=0.5): 44 | """ 45 | https://stackoverflow.com/questions/37765197/darken-or-lighten-a-color-in-matplotlib 46 | Lightens the given color by multiplying (1-luminosity) by the given amount. 47 | Input can be matplotlib color string, hex string, or RGB tuple. 48 | 49 | Examples: 50 | >> lighten_color('g', 0.3) 51 | >> lighten_color('#F034A3', 0.6) 52 | >> lighten_color((.3,.55,.1), 0.5) 53 | """ 54 | import matplotlib.colors as mc 55 | import colorsys 56 | try: 57 | c = mc.cnames[color] 58 | except: 59 | c = color 60 | c = colorsys.rgb_to_hls(*mc.to_rgb(c)) 61 | return colorsys.hls_to_rgb(c[0], 1 - amount * (1 - c[1]), c[2]) 62 | 63 | def get_df(df, **kwargs): 64 | conds = [df[col] == val for col, val in kwargs.items()] 65 | cond = conds[0] 66 | for i in range(1, len(conds)): 67 | cond = cond & conds[i] 68 | return df[cond] 69 | 70 | def get_df_single(df, **kwargs): 71 | df_subset = get_df(df, **kwargs) 72 | if len(df_subset) != 1: 73 | raise ValueError(f"Expected unique result for {kwargs=}. Got {df_subset}") 74 | return df_subset.iloc[0] 75 | 76 | def bold(text): 77 | text = text.replace('_', ' ') 78 | text = text.replace(' ', '}$ $\\bf{') 79 | return r"$\bf{" + text + r"}$" 80 | 81 | from matplotlib import colors 82 | class NonSymmetricNormalize(colors.Normalize): 83 | def __call__(self, value, clip=None): 84 | if clip is None: 85 | clip = self.clip 86 | 87 | result, is_scalar = self.process_value(value) 88 | 89 | if self.vmin is None or self.vmax is None: 90 | self.autoscale_None(result) 91 | # Convert at least to float, without losing precision. 92 | (vmin,), _ = self.process_value(self.vmin) 93 | (vmax,), _ = self.process_value(self.vmax) 94 | if vmin == vmax: 95 | result.fill(0) # Or should it be all masked? Or 0.5? 
96 | elif vmin > vmax: 97 | raise ValueError("minvalue must be less than or equal to maxvalue") 98 | else: 99 | if clip: 100 | mask = np.ma.getmask(result) 101 | result = np.ma.array(np.clip(result.filled(vmax), vmin, vmax), 102 | mask=mask) 103 | # ma division is very slow; we can take a shortcut 104 | resdat = result.data 105 | resdat[resdat < 0] /= np.abs(vmin) * 2 106 | resdat[resdat > 0] /= np.abs(vmax) * 2 107 | resdat += .5 108 | # resdat /= (vmax - vmin) 109 | result = np.ma.array(resdat, mask=result.mask, copy=False) 110 | if is_scalar: 111 | result = result[0] 112 | return result 113 | cmap = 'RdBu_r' 114 | 115 | def plot_sdf_surface(ax, inputs, sdf_values, n_positive_levels=10, n_negative_levels=10, surface_width=2, with_0_level=True, hide_ticks=True, **kwargs): 116 | contour = ax.contourf( 117 | inputs[..., 0], inputs[..., 1], NonSymmetricNormalize()(sdf_values), cmap='RdBu_r', 118 | levels=np.hstack([np.linspace(0, .5, n_negative_levels), np.linspace(0.5, 1., n_positive_levels+1)[1:]]) 119 | ) 120 | 121 | # make sure there are no white gaps between level sets 122 | for c in contour.collections: 123 | c.set_edgecolor("face") 124 | 125 | if with_0_level: 126 | ax.contour(inputs[..., 0], inputs[..., 1], sdf_values, colors='k', levels=[0], linewidths=surface_width, **kwargs) 127 | if hide_ticks: 128 | ax.set_xticks([]) 129 | ax.set_yticks([]) 130 | ax.spines['top'].set_visible(False) 131 | ax.spines['right'].set_visible(False) 132 | ax.spines['bottom'].set_visible(False) 133 | ax.spines['left'].set_visible(False) 134 | 135 | def get_unique(df, key): 136 | vals = df[key].unique() 137 | if len(vals) != 1: 138 | raise ValueError(f"Expected one {key}, got {vals}") 139 | return vals[0] 140 | 141 | 142 | def get_basis(mu0, mu1, mu2, preserve_angle=True, preserve_norm=True): 143 | u1 = mu1 - mu0 144 | u2 = mu2 - mu0 145 | if preserve_angle: 146 | u2 -= u1 * (u1 @ u2) / (u1 @ u1) 147 | 148 | if preserve_norm: 149 | u2 *= np.linalg.norm(u1) / np.linalg.norm(u2) 150 | 151 | x = np.linalg.lstsq(np.hstack([u1[:, None], u2[:, None]]), np.hstack([(mu1 - mu0)[:, None], (mu2 - mu0)[:, None]]))[0] 152 | x1, x2 = x[:, 0], x[:, 1] 153 | 154 | return u1, u2, x1, x2 155 | 156 | def tabilize(results, precisions, rank_order, suffixes=None, hlines = []): 157 | 158 | def rankify(x, order): 159 | # Turn a vector of values into a list of ranks, while handling ties. 
160 | assert len(x.shape) == 1 161 | if order == 0: 162 | return np.full_like(x, 1e5, dtype=np.int32) 163 | u = np.sort(np.unique(x)) 164 | if order == 1: 165 | u = u[::-1] 166 | r = np.zeros_like(x, dtype=np.int32) 167 | for ui, uu in enumerate(u): 168 | mask = x == uu 169 | r[mask] = ui 170 | return np.int32(r) 171 | 172 | names = results.keys() 173 | data = np.array(list(results.values())) 174 | assert len(names) == len(data) 175 | data = np.array(data) 176 | 177 | tags = [' \cellcolor{tabfirst}', 178 | '\cellcolor{tabsecond}', 179 | ' \cellcolor{tabthird}', 180 | ' '] 181 | 182 | max_len = max([len(v) for v in list(names)]) 183 | names_padded = [v + ' '*(max_len-len(v)) for v in names] 184 | 185 | data_quant = np.round((data * 10.**(np.array(precisions)[None, :]))) / 10.**(np.array(precisions)[None, :]) 186 | if suffixes is None: 187 | suffixes = [''] * len(precisions) 188 | 189 | tagranks = [] 190 | for d in range(data_quant.shape[1]): 191 | tagranks.append(np.clip(rankify(data_quant[:,d], rank_order[d]), 0, len(tags)-1)) 192 | tagranks = np.stack(tagranks, -1) 193 | 194 | for i_row in range(len(names)): 195 | line = '\t' 196 | if i_row in hlines: 197 | line += '\\hline\n' 198 | line += names_padded[i_row] 199 | for d in range(data_quant.shape[1]): 200 | line += ' & ' 201 | if rank_order[d] != 0 and not np.isnan(data[i_row,d]): 202 | line += tags[tagranks[i_row, d]] 203 | if np.isnan(data[i_row,d]): 204 | line += ' - ' 205 | else: 206 | assert precisions[d] >= 0 207 | line += ('{:' + f'0.{precisions[d]}f' + '}').format(data_quant[i_row,d]) + suffixes[d] 208 | line += ' \\\\' 209 | print(line) -------------------------------------------------------------------------------- /notebook_utils_blender.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | import blender_plots as bplt 3 | from blender_plots import blender_utils as bu 4 | import trimesh 5 | import mathutils 6 | from pathlib import Path 7 | import numpy as np 8 | import math 9 | 10 | def add_text(text, location, color=None, name=''): 11 | font_curve = bpy.data.curves.new(type="FONT", name=text) 12 | font_curve.body = text 13 | font_obj = bu.new_empty(f'text_{name}', font_curve) 14 | color_mat = bpy.data.materials.new("Text color") 15 | color_mat.diffuse_color = (0., 0, 0, 0) if color is None else color 16 | font_obj.data.materials.append(color_mat) 17 | 18 | # bpy.context.scene.collection.objects.link(font_obj) 19 | font_obj.location = location 20 | font_obj.scale = [0.2, 0.2, 0.2] 21 | 22 | modifier = font_obj.modifiers.new(type="SOLIDIFY", name="solidify") 23 | modifier.thickness = 0.05 24 | return font_obj 25 | 26 | def add_mesh(mesh: trimesh.Trimesh, color, name="mesh", center=False, offset=None, scale=1.): 27 | if offset is None: 28 | offset = np.array([0, 0, 0]) 29 | 30 | mesh = mesh.copy() # make sure not to modify original mesh 31 | mesh.vertices *= scale 32 | if center: 33 | mesh.vertices -= mesh.vertices.mean(axis=0) 34 | # if offset is not None: 35 | # mesh.vertices += offset 36 | bmesh = bpy.data.meshes.new(name=name) 37 | bmesh.from_pydata(mesh.vertices, [], mesh.faces) 38 | mesh_object = bu.new_empty(name, bmesh) 39 | mesh_object.location = offset 40 | 41 | material = bpy.data.materials.new(name) 42 | material.use_nodes = True 43 | bsdf = material.node_tree.nodes['Principled BSDF'] 44 | bsdf.inputs['Base Color'].default_value = color 45 | mesh_object.data.materials.append(material) 46 | return mesh_object 47 | 48 | def plot_mesh(mesh_path, color, offset=None, 
scale=1., name="mesh", center=False, train_points=False, text='', rotation=None, with_text=True, text_color=None): 49 | if offset is None: 50 | offset = np.zeros(3) 51 | if rotation is None: 52 | rotation = (0, 0, 0) 53 | if isinstance(mesh_path, trimesh.Trimesh): 54 | gt_mesh = mesh_path 55 | else: 56 | gt_mesh = trimesh.load(mesh_path) 57 | if hasattr(gt_mesh, 'vertices'): 58 | mesh_object = add_mesh(gt_mesh, color, name, center, offset, scale) 59 | mesh_object.rotation_euler = rotation 60 | else: 61 | mesh_object = None 62 | # gt_mesh_object.color = (1, 0, 0, 1) 63 | 64 | if with_text: 65 | font_obj = add_text(text, offset + np.array([-0., -0.8, -0.62]), name=name, color=text_color) 66 | if train_points: 67 | train_points = np.load(Path(mesh_path).parent.parent / 'train_points.npy') 68 | scatter = bplt.Scatter(train_points + offset, marker_type='ico_spheres', radius=0.01, subdivisions=2, color=[0, 0, 0, 0.1], name=f"train points {name}") 69 | else: 70 | scatter = None 71 | return mesh_object, scatter 72 | 73 | def bounding_box(lower, upper, name, rotation=None, offset=None): 74 | x1, y1, z1 = lower 75 | x2, y2, z2 = upper 76 | 77 | vertices = [ 78 | (x1, y1, z1), # Vertex 0 79 | (x2, y1, z1), # Vertex 1 80 | (x2, y2, z1), # Vertex 2 81 | (x1, y2, z1), # Vertex 3 82 | (x1, y1, z2), # Vertex 4 83 | (x2, y1, z2), # Vertex 5 84 | (x2, y2, z2), # Vertex 6 85 | (x1, y2, z2) # Vertex 7 86 | ] 87 | 88 | edges = [ 89 | (0, 1), (1, 2), (2, 3), (3, 0), # Bottom face 90 | (4, 5), (5, 6), (6, 7), (7, 4), # Top face 91 | (0, 4), (1, 5), (2, 6), (3, 7) # Connecting faces 92 | ] 93 | 94 | mesh = bpy.data.meshes.new(name) 95 | obj = bu.new_empty(name, mesh) 96 | mesh.from_pydata(vertices, edges, []) 97 | if rotation is not None: 98 | obj.rotation_euler = rotation 99 | if offset is not None: 100 | obj.location = offset 101 | return obj 102 | 103 | def create_camera(location, rotation): 104 | if "Camera" in bpy.data.objects: 105 | bpy.data.objects.remove(bpy.data.objects["Camera"]) 106 | 107 | bpy.ops.object.camera_add(enter_editmode=False, align='VIEW', location=location, rotation=rotation) 108 | bpy.context.scene.camera = bpy.data.objects['Camera'] 109 | 110 | def render_image(output_path, resolution, samples=100): 111 | bpy.context.scene.render.resolution_x = resolution[0] 112 | bpy.context.scene.render.resolution_y = resolution[1] 113 | bpy.context.scene.cycles.samples = samples 114 | bpy.context.scene.render.filepath = output_path 115 | bpy.ops.render.render(write_still=True) 116 | 117 | def setup_scene(clear=False, camera_location=None, camera_rotation=None, resolution=None): 118 | if "Cube" in bpy.data.objects: 119 | bpy.data.objects.remove(bpy.data.objects["Cube"]) 120 | 121 | if camera_location is None: 122 | camera_location = np.array([0, -5.321560, 2.042498]) * 0.6 123 | if camera_rotation is None: 124 | camera_rotation = [math.radians(68.4), 0., 0.] 125 | 126 | if clear: 127 | bpy.ops.wm.read_homefile() 128 | bpy.data.worlds["World"].node_tree.nodes["Background"].inputs[0].default_value = (1, 1, 1, 1) 129 | bpy.context.scene.render.engine = 'CYCLES' 130 | bpy.data.scenes["Scene"].cycles.samples = 256 131 | 132 | if "Sun" in bpy.data.objects: 133 | bpy.data.objects.remove(bpy.data.objects["Sun"]) 134 | bpy.ops.object.light_add(type='SUN', radius=1, align='WORLD', location=(0, 0, 0), scale=(1, 1, 1)) 135 | bpy.data.objects["Sun"].data.energy = 30. 
136 | bpy.data.objects["Sun"].data.angle = np.pi / 2 137 | bpy.data.worlds["World"].node_tree.nodes["Background"].inputs["Strength"].default_value = 0.5 138 | 139 | bpy.context.scene.render.film_transparent = True 140 | create_camera(camera_location, camera_rotation) 141 | if resolution is not None: 142 | bpy.context.scene.render.resolution_x = resolution[0] 143 | bpy.context.scene.render.resolution_y = resolution[1] 144 | 145 | def euler_to_R(euler): 146 | return np.array(mathutils.Euler(euler).to_matrix()) 147 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.25.2 2 | jax==0.4.14 3 | jaxlib==0.4.14 4 | matplotlib 5 | tqdm 6 | jaxopt==0.7 7 | flax==0.7.4 8 | typing-extensions 9 | tyro 10 | pandas 11 | orbax==0.1.9 12 | orbax-checkpoint==0.4.8 13 | plyfile 14 | trimesh==3.23.1 15 | scipy==1.11.1 16 | scikit-image==0.21.0 17 | wandb -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='diffcd', 5 | version='0.1.0', 6 | description='Official implementation of "DiffCD: A Symmetric Differentiable Chamfer Distance for Neural Implicit Surface Fitting"', 7 | author='Linus Härenstam-Nielsen', 8 | author_email='linus.nielsen@tum.de', 9 | url='https://github.com/linusnie/diffcd', 10 | packages=find_packages(), 11 | install_requires=[ 12 | "numpy", 13 | "jax[cuda12]", 14 | "jaxlib", 15 | "matplotlib", 16 | "tqdm", 17 | "jaxopt", 18 | "flax", 19 | "typing-extensions", 20 | "tyro", 21 | "pandas", 22 | "orbax", 23 | "orbax-checkpoint", 24 | "plyfile", 25 | "trimesh", 26 | "scipy", 27 | "scikit-image", 28 | ], 29 | classifiers=[ 30 | 'Programming Language :: Python :: 3', 31 | 'License :: OSI Approved :: Apache Software License', 32 | 'Operating System :: OS Independent', 33 | ], 34 | python_requires='>=3.10', 35 | ) -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linusnie/diffcd/4fa94830cd528d87abd2d59fe92f9d303b1e5788/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_methods.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax 3 | import jax.numpy as jnp 4 | import unittest 5 | 6 | import diffcd 7 | 8 | def first(f): 9 | return lambda *args, **kwargs: f(*args, **kwargs)[0] 10 | 11 | class TestMethods(unittest.TestCase): 12 | 13 | def test_point_cloud_closest_point(self): 14 | point_cloud = jnp.array([ 15 | [0., 0., 0.], 16 | [1., 1., 1.], 17 | [5., -1., 3.], 18 | ]) 19 | for point in point_cloud: 20 | closest_point = diffcd.methods.point_cloud_closest_point( 21 | point + jnp.array([.1, .1, .1]), point_cloud 22 | ) 23 | np.testing.assert_array_equal(point, closest_point) 24 | 25 | def test_eikonal_loss_exact_sdf(self): 26 | f = lambda r, x: r - jnp.linalg.norm(x) 27 | eikonal_loss = diffcd.methods.get_eikonal_loss(f, 2., jnp.array([1., 2., 3.])) 28 | self.assertAlmostEqual(eikonal_loss, 0.) 
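    # Note (editor's addition, hedged): these tests are consistent with the standard eikonal
    # penalty E(x) = (||grad_x f(r, x)|| - 1)^2. An exact SDF has unit gradient norm
    # everywhere, so the loss above is 0; for f(r, x) = r^2 - ||x||^2 the gradient is -2x,
    # giving (1 - ||2x||)^2 in the test below.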
29 | 30 | def test_eikonal_loss_approx_sdf(self): 31 | f = lambda r, x: r ** 2 - diffcd.closest_point.sq_norm(x) 32 | x = jnp.array([1., 2., 3.]) 33 | eikonal_loss = diffcd.methods.get_eikonal_loss(f, 2., x) 34 | self.assertAlmostEqual(eikonal_loss, (1 - jnp.linalg.norm(2 * x)) ** 2) 35 | 36 | x = jnp.array([0., 0., 0.]) 37 | eikonal_loss = diffcd.methods.get_eikonal_loss(f, 2., x) 38 | self.assertAlmostEqual(eikonal_loss, (1 - jnp.linalg.norm(2 * x)) ** 2) 39 | 40 | def test_igr_loss(self): 41 | f = lambda r, x: r - jnp.linalg.norm(x) 42 | x, r = jnp.array([1., 2., 3.]), 2. 43 | 44 | # l2 45 | igr_l2 = diffcd.methods.IGR(eikonal_weight=0., distance_metric='l2') 46 | igr_loss, _ = igr_l2.get_loss(f, r, x, x, x) 47 | self.assertAlmostEqual(igr_loss, jnp.linalg.norm(x - r * x / jnp.linalg.norm(x)), places=6) 48 | grad, _ = jax.grad(igr_l2.get_loss, argnums=1, has_aux=True)(f, r, x, x, x) 49 | self.assertAlmostEqual(grad, -1.) 50 | 51 | # squared l2 52 | igr_sql2 = diffcd.methods.IGR(eikonal_weight=0., distance_metric='squared_l2') 53 | igr_loss, _ = igr_sql2.get_loss(f, r, x, x, x) 54 | self.assertAlmostEqual(igr_loss, jnp.linalg.norm(x - r * x / jnp.linalg.norm(x)) ** 2, places=6) 55 | grad, _ = jax.grad(igr_sql2.get_loss, argnums=1, has_aux=True)(f, r, x, x, x) 56 | self.assertAlmostEqual(grad, 2 * (r - jnp.linalg.norm(x))) 57 | 58 | # l2 with f(x) = 0 59 | f = lambda r, x: r ** 2 - diffcd.closest_point.sq_norm(x) 60 | x = jnp.array([0., 0., 0.]) 61 | (igr_loss, _), grad = jax.value_and_grad(igr_l2.get_loss, argnums=1, has_aux=True)(f, r, x, x, x) 62 | self.assertAlmostEqual(igr_loss, r ** 2) 63 | self.assertAlmostEqual(grad, 2 * r) 64 | 65 | def test_pull_point(self): 66 | """Pulled point should equal closest point for SDF.""" 67 | radius = 2 68 | f = lambda x: radius - jnp.linalg.norm(x) 69 | x = jnp.array([1., 2., 3.]) 70 | 71 | pulled_point = diffcd.methods.pull_point(f, x) 72 | np.testing.assert_array_almost_equal( 73 | pulled_point, x / jnp.linalg.norm(x) * radius 74 | ) 75 | 76 | def test_pull_point_zero_grad(self): 77 | """Check that pull_points handles points with gradient=0.""" 78 | radius = 2 79 | f = lambda x: radius ** 2 - diffcd.closest_point.sq_norm(x) 80 | x = jnp.array([0., 0., 0.]) 81 | 82 | pulled_point = diffcd.methods.pull_point(f, x) 83 | np.testing.assert_array_almost_equal(pulled_point, x) 84 | 85 | loss, grad = jax.value_and_grad(lambda x: ((diffcd.methods.pull_point(f, x) - jnp.ones(3)) ** 2).mean())(x) 86 | np.testing.assert_almost_equal(loss, 1.) 
87 | self.assertFalse(jnp.isnan(grad).any()) 88 | 89 | def test_closest_point(self): 90 | f = lambda radius, x: radius ** 2 - diffcd.closest_point.sq_norm(x) 91 | x = jnp.array([1., 2., 3.]) 92 | 93 | radius, eps, max_iters = 2, 1e-6, 10 94 | 95 | # without stop when converged 96 | closest_point, newton_state, valid = diffcd.closest_point.closest_point_newton( 97 | f, radius, x, x, diffcd.closest_point.NewtonConfig(grad_norm_eps=eps, max_iters=max_iters, stop_when_converged=False) 98 | ) 99 | self.assertTrue(valid) 100 | np.testing.assert_array_almost_equal(closest_point, x / jnp.linalg.norm(x) * radius) 101 | self.assertTrue(newton_state.converged) 102 | self.assertEqual(newton_state.step, max_iters) 103 | laplacian_grad = jax.grad(diffcd.closest_point.laplacian, argnums=-1)(f, radius, x, newton_state.z_steps[-1]) 104 | self.assertLess(jnp.linalg.norm(laplacian_grad), eps) 105 | 106 | # with stop when converged 107 | closest_point, newton_state, valid = diffcd.closest_point.closest_point_newton( 108 | f, radius, x, x, diffcd.closest_point.NewtonConfig(grad_norm_eps=eps, max_iters=max_iters, stop_when_converged=True) 109 | ) 110 | self.assertTrue(valid) 111 | np.testing.assert_array_almost_equal(closest_point, x / jnp.linalg.norm(x) * radius) 112 | self.assertTrue(newton_state.converged) 113 | self.assertLess(newton_state.step, max_iters) 114 | laplacian_grad = jax.grad(diffcd.closest_point.laplacian, argnums=-1)(f, radius, x, newton_state.z_steps[newton_state.step]) 115 | self.assertLess(jnp.linalg.norm(laplacian_grad), eps) 116 | 117 | 118 | def test_closest_point_grad(self): 119 | f = lambda radius, x: radius ** 2 - diffcd.closest_point.sq_norm(x) 120 | x = jnp.array([1., 2., 3.]) 121 | 122 | radius = 2. 123 | closest_point_grad = jax.jacrev(first(diffcd.closest_point.closest_point_newton), argnums=1)( 124 | f, radius, x, x, diffcd.closest_point.NewtonConfig() 125 | ) 126 | np.testing.assert_array_almost_equal( 127 | closest_point_grad, x / jnp.linalg.norm(x) 128 | ) 129 | 130 | if __name__ == '__main__': 131 | unittest.main() 132 | -------------------------------------------------------------------------------- /tests/test_training.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | import unittest 4 | import jax 5 | import jax.numpy as jnp 6 | import jax.random as jrnd 7 | import optax 8 | import tempfile 9 | import trimesh 10 | import os 11 | 12 | import diffcd 13 | import fit_implicit 14 | from diffcd.evaluation import meshing 15 | 16 | def make_test_ply(output_file: Path): 17 | vertices = np.array([ 18 | [-0.5, -0.5, -0.5], 19 | [ 0.5, -0.5, -0.5], 20 | [ 0.5, 0.5, -0.5], 21 | [-0.5, 0.5, -0.5], 22 | [-0.5, -0.5, 0.5], 23 | [ 0.5, -0.5, 0.5], 24 | [ 0.5, 0.5, 0.5], 25 | [-0.5, 0.5, 0.5] 26 | ]) 27 | faces = np.array([ 28 | [0, 1, 2], 29 | [0, 2, 3], 30 | [4, 5, 6], 31 | [4, 6, 7], 32 | [0, 4, 7], 33 | [0, 7, 3], 34 | [1, 5, 6], 35 | [1, 6, 2], 36 | [0, 1, 5], 37 | [0, 5, 4], 38 | [2, 3, 7], 39 | [2, 7, 6] 40 | ]) 41 | mesh = trimesh.Trimesh(vertices=vertices, faces=faces) 42 | meshing.save_ply(mesh, output_file) 43 | 44 | def make_test_npy(output_file: Path): 45 | points = np.array([ 46 | [-0.5, -0.5, -0.5], 47 | [ 0.5, -0.5, -0.5], 48 | [ 0.5, 0.5, -0.5], 49 | [-0.5, 0.5, -0.5], 50 | [-0.5, -0.5, 0.5], 51 | [ 0.5, -0.5, 0.5], 52 | [ 0.5, 0.5, 0.5], 53 | [-0.5, 0.5, 0.5] 54 | ]) 55 | np.save(output_file, points) 56 | 57 | class TestTraining(unittest.TestCase): 58 | 59 | def 
assertFileExists(self, file_path: Path): 60 | self.assertTrue(file_path.exists(), f'File {file_path} does not exist.') 61 | 62 | def test_fit_implicit_point_cloud(self): 63 | """Run fit_implicit and check that all outputs are created correctly.""" 64 | with tempfile.TemporaryDirectory() as tmp_dir: 65 | tmp_dir = Path(tmp_dir) 66 | make_test_npy(tmp_dir / 'cube.npy') 67 | make_test_ply(tmp_dir / 'cube.ply') 68 | 69 | config = fit_implicit.TrainingConfig( 70 | output_dir=tmp_dir / 'outputs', 71 | model=diffcd.networks.MLP(10, 3), 72 | checkpoint_options=fit_implicit.CustomCheckpointManagerOptions(save_interval_steps=2), 73 | method=diffcd.methods.NeuralPull( 74 | sampling=diffcd.samplers.SamplingConfig(k=3), 75 | ), 76 | final_mesh_points_per_axis=32, 77 | batch_size=12, 78 | n_batches=4, 79 | with_timestamp=False, 80 | experiment_name='test_experiment', 81 | dataset=diffcd.datasets.PointCloud(path=tmp_dir / 'cube.npy'), 82 | gt_mesh=diffcd.datasets.EvaluationMesh(path=tmp_dir / 'cube.ply') 83 | ) 84 | fit_implicit.run(config) 85 | experiment_dir = tmp_dir / 'outputs' / 'test_experiment' 86 | self.assertFileExists(experiment_dir) 87 | 88 | file_names = [ 89 | 'config.pickle', 90 | 'config.yaml', 91 | 'eval_metrics.csv', 92 | 'eval_metrics_final_4.csv', 93 | 'local_sigma.npy', 94 | 'sample_points.npy', 95 | 'target_points.npy', 96 | 'mesh_final_4.ply', 97 | 'train_metrics.csv', 98 | 'train_points.npy', 99 | ] 100 | for file_name in file_names: 101 | self.assertFileExists(experiment_dir / file_name) 102 | 103 | # check that there are no extra files 104 | self.assertEqual(len(next(os.walk(experiment_dir))[2]), len(file_names)) 105 | 106 | for checkpoint_index in [0, 2, 4]: 107 | self.assertFileExists(experiment_dir / f'checkpoints/{checkpoint_index}') 108 | self.assertFileExists(experiment_dir / f'meshes/mesh_{checkpoint_index}.ply') 109 | 110 | def test_step(self): 111 | model = diffcd.networks.MLP( 112 | layer_size=10, 113 | n_layers=4, 114 | skip_layers=(2,) 115 | ) 116 | key = jrnd.PRNGKey(0) 117 | params = model.init(key, jnp.zeros(3) * 1.) 
118 | 119 | self.assertEqual(len(params['params']), model.n_layers + 1) 120 | for i in range(model.n_layers): 121 | in_dim = 3 if i == 0 else model.layer_size 122 | out_dim = model.layer_size if i + 1 not in model.skip_layers else model.layer_size - 3 123 | self.assertEqual(params['params'][f'dense_{i}']['kernel'].shape, (in_dim, out_dim)) 124 | self.assertEqual(params['params'][f'dense_{i}']['bias'].shape, (out_dim,)) 125 | self.assertEqual(params['params'][f'dense_{model.n_layers}']['kernel'].shape, (model.layer_size, 1)) 126 | self.assertEqual(params['params'][f'dense_{model.n_layers}']['bias'].shape, (1,)) 127 | 128 | for method_class in [ 129 | diffcd.methods.IGR, 130 | diffcd.methods.NeuralPull, 131 | lambda **kwargs: diffcd.methods.DiffCD(**kwargs, p2s_loss='closest-point'), 132 | lambda **kwargs: diffcd.methods.DiffCD(**kwargs, p2s_loss='implicit'), 133 | ]: 134 | with self.subTest(method_class.__name__): 135 | method = method_class(sampling=diffcd.samplers.SamplingConfig(k=3)) 136 | train_state = diffcd.training.ShapeTrainState.create( 137 | apply_fn=jax.jit(model.apply), 138 | tx=optax.adam(learning_rate=1e-3,), 139 | params=params, 140 | lower_bound=(-1.8, -1.8, -1.8), 141 | upper_bound=(1.8, 1.8, 1.8), 142 | ) 143 | train_points = jnp.ones((10, 3)) 144 | method_state = method.init_state(key, train_points, None) 145 | _, *batch = method.get_batch(train_state, method_state, key, 5) 146 | metrics, train_state, nan_grads = method.step( 147 | train_state, *batch 148 | ) 149 | self.assertFalse(jnp.any(nan_grads)) 150 | 151 | if __name__ == '__main__': 152 | unittest.main() 153 | --------------------------------------------------------------------------------
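For reference, a minimal end-to-end usage sketch (an editor's addition mirroring tests/test_training.py above; the data paths are hypothetical, and the point cloud is assumed to have more points than SamplingConfig.k, which defaults to 50):

from pathlib import Path

import diffcd
import fit_implicit

# fit an implicit surface to a point cloud with the Neural Pull baseline;
# swap in diffcd.methods.DiffCD() or diffcd.methods.IGR() for the other methods
config = fit_implicit.TrainingConfig(
    dataset=diffcd.datasets.PointCloud(path=Path('data/scan.npy')),  # hypothetical path
    output_dir=Path('outputs'),
    model=diffcd.networks.MLP(layer_size=256, n_layers=8),
    checkpoint_options=fit_implicit.CustomCheckpointManagerOptions(save_interval_steps=1000),
    gt_mesh=diffcd.datasets.EvaluationMesh(path=Path('data/scan_gt.ply')),  # hypothetical path
    method=diffcd.methods.NeuralPull(),
    experiment_name='scan_neural_pull',
    with_timestamp=False,
)
fit_implicit.run(config)  # writes checkpoints, meshes, and metrics under outputs/scan_neural_pull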