├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── demos ├── lipschitz_mlp │ ├── README.md │ ├── lipschitz_mlp_interpolation.mp4 │ ├── lipschitz_mlp_loss_history.jpg │ ├── lipschitz_mlp_params.pkl │ ├── main_lipmlp.py │ └── model.py ├── nerf │ ├── nerf_interpolation.mp4 │ ├── nerf_loss_history.jpg │ ├── nerf_params.pkl │ ├── test_nerf.py │ ├── tiny_nerf_data.npz │ ├── train_nerf.py │ └── utils.py ├── neural_sdf │ ├── ground truth (t=0).png │ ├── ground truth (t=1).png │ ├── loss_history.jpg │ ├── main.py │ ├── mlp_params.pkl │ ├── model.py │ ├── network output (t=0).png │ └── network output (t=1).png └── normal_driven_stylization │ ├── .polyscope.ini │ ├── imgui.ini │ ├── loss_history.jpg │ ├── main.py │ ├── opt.obj │ └── spot.obj ├── differentiable ├── angle_defect.py ├── cotangent_weights.py ├── dihedral_angles.py ├── dotrow.py ├── face_areas.py ├── face_normals.py ├── fit_rotations_cayley.py ├── halfedge_lengths.py ├── normalize_unit_box.py ├── normalizerow.py ├── normrow.py ├── ramp_smooth.py ├── tip_angles.py ├── vertex_areas.py └── vertex_normals.py ├── external ├── read_mesh.py ├── signed_distance.py └── write_obj.py ├── general ├── adjacency_edge_face.py ├── adjacency_list_edge_face.py ├── adjacency_list_face_face.py ├── adjacency_list_vertex_face.py ├── adjacency_list_vertex_vertex.py ├── adjacency_vertex_vertex.py ├── boundary_vertices.py ├── cotmatrix.py ├── edge_flaps.py ├── edges.py ├── edges_with_mapping.py ├── find_index.py ├── he_initialization.py ├── knn_search.py ├── list_remove_indices.py ├── massmatrix.py ├── mid_point_curve_simplification.py ├── ordered_outline.py ├── outline.py ├── remove_unreferenced.py ├── sample_2D_grid.py ├── sdf_circle.py ├── sdf_cross.py ├── sdf_star.py └── sdf_triangle.py └── logo.png /.gitignore: -------------------------------------------------------------------------------- 1 | unit_test 2 | not_ready 3 | old 4 | demos/shapeup_with_GD 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | 
*.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | 139 | # pytype static type analyzer 140 | .pytype/ 141 | 142 | # Cython debug symbols 143 | cython_debug/ 144 | 145 | # Specific apple ignores 146 | # General 147 | .DS_Store 148 | .AppleDouble 149 | .LSOverride 150 | 151 | # Icon must end with two \r 152 | Icon 153 | 154 | # Thumbnails 155 | ._* 156 | 157 | # Files that might appear in the root of a volume 158 | .DocumentRevisions-V100 159 | .fseventsd 160 | .Spotlight-V100 161 | .TemporaryItems 162 | .Trashes 163 | .VolumeIcon.icns 164 | .com.apple.timemachine.donotpresent 165 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The code in this repository is licensed under the Apache license. 2 | By contributing to this project, you agree to license your code under the 3 | same license. 4 | 5 | 6 | =========================================================================== 7 | 8 | 9 | Apache License 10 | Version 2.0, January 2004 11 | http://www.apache.org/licenses/ 12 | 13 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 14 | 15 | 1. Definitions. 16 | 17 | "License" shall mean the terms and conditions for use, reproduction, 18 | and distribution as defined by Sections 1 through 9 of this document. 
19 | 20 | "Licensor" shall mean the copyright owner or entity authorized by 21 | the copyright owner that is granting the License. 22 | 23 | "Legal Entity" shall mean the union of the acting entity and all 24 | other entities that control, are controlled by, or are under common 25 | control with that entity. For the purposes of this definition, 26 | "control" means (i) the power, direct or indirect, to cause the 27 | direction or management of such entity, whether by contract or 28 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 29 | outstanding shares, or (iii) beneficial ownership of such entity. 30 | 31 | "You" (or "Your") shall mean an individual or Legal Entity 32 | exercising permissions granted by this License. 33 | 34 | "Source" form shall mean the preferred form for making modifications, 35 | including but not limited to software source code, documentation 36 | source, and configuration files. 37 | 38 | "Object" form shall mean any form resulting from mechanical 39 | transformation or translation of a Source form, including but 40 | not limited to compiled object code, generated documentation, 41 | and conversions to other media types. 42 | 43 | "Work" shall mean the work of authorship, whether in Source or 44 | Object form, made available under the License, as indicated by a 45 | copyright notice that is included in or attached to the work 46 | (an example is provided in the Appendix below). 47 | 48 | "Derivative Works" shall mean any work, whether in Source or Object 49 | form, that is based on (or derived from) the Work and for which the 50 | editorial revisions, annotations, elaborations, or other modifications 51 | represent, as a whole, an original work of authorship. For the purposes 52 | of this License, Derivative Works shall not include works that remain 53 | separable from, or merely link (or bind by name) to the interfaces of, 54 | the Work and Derivative Works thereof. 
55 | 56 | "Contribution" shall mean any work of authorship, including 57 | the original version of the Work and any modifications or additions 58 | to that Work or Derivative Works thereof, that is intentionally 59 | submitted to Licensor for inclusion in the Work by the copyright owner 60 | or by an individual or Legal Entity authorized to submit on behalf of 61 | the copyright owner. For the purposes of this definition, "submitted" 62 | means any form of electronic, verbal, or written communication sent 63 | to the Licensor or its representatives, including but not limited to 64 | communication on electronic mailing lists, source code control systems, 65 | and issue tracking systems that are managed by, or on behalf of, the 66 | Licensor for the purpose of discussing and improving the Work, but 67 | excluding communication that is conspicuously marked or otherwise 68 | designated in writing by the copyright owner as "Not a Contribution." 69 | 70 | "Contributor" shall mean Licensor and any individual or Legal Entity 71 | on behalf of whom a Contribution has been received by Licensor and 72 | subsequently incorporated within the Work. 73 | 74 | 2. Grant of Copyright License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | copyright license to reproduce, prepare Derivative Works of, 78 | publicly display, publicly perform, sublicense, and distribute the 79 | Work and such Derivative Works in Source or Object form. 80 | 81 | 3. Grant of Patent License. 
Subject to the terms and conditions of 82 | this License, each Contributor hereby grants to You a perpetual, 83 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 84 | (except as stated in this section) patent license to make, have made, 85 | use, offer to sell, sell, import, and otherwise transfer the Work, 86 | where such license applies only to those patent claims licensable 87 | by such Contributor that are necessarily infringed by their 88 | Contribution(s) alone or by combination of their Contribution(s) 89 | with the Work to which such Contribution(s) was submitted. If You 90 | institute patent litigation against any entity (including a 91 | cross-claim or counterclaim in a lawsuit) alleging that the Work 92 | or a Contribution incorporated within the Work constitutes direct 93 | or contributory patent infringement, then any patent licenses 94 | granted to You under this License for that Work shall terminate 95 | as of the date such litigation is filed. 96 | 97 | 4. Redistribution. 
You may reproduce and distribute copies of the 98 | Work or Derivative Works thereof in any medium, with or without 99 | modifications, and in Source or Object form, provided that You 100 | meet the following conditions: 101 | 102 | (a) You must give any other recipients of the Work or 103 | Derivative Works a copy of this License; and 104 | 105 | (b) You must cause any modified files to carry prominent notices 106 | stating that You changed the files; and 107 | 108 | (c) You must retain, in the Source form of any Derivative Works 109 | that You distribute, all copyright, patent, trademark, and 110 | attribution notices from the Source form of the Work, 111 | excluding those notices that do not pertain to any part of 112 | the Derivative Works; and 113 | 114 | (d) If the Work includes a "NOTICE" text file as part of its 115 | distribution, then any Derivative Works that You distribute must 116 | include a readable copy of the attribution notices contained 117 | within such NOTICE file, excluding those notices that do not 118 | pertain to any part of the Derivative Works, in at least one 119 | of the following places: within a NOTICE text file distributed 120 | as part of the Derivative Works; within the Source form or 121 | documentation, if provided along with the Derivative Works; or, 122 | within a display generated by the Derivative Works, if and 123 | wherever such third-party notices normally appear. The contents 124 | of the NOTICE file are for informational purposes only and 125 | do not modify the License. You may add Your own attribution 126 | notices within Derivative Works that You distribute, alongside 127 | or as an addendum to the NOTICE text from the Work, provided 128 | that such additional attribution notices cannot be construed 129 | as modifying the License. 
130 | 131 | You may add Your own copyright statement to Your modifications and 132 | may provide additional or different license terms and conditions 133 | for use, reproduction, or distribution of Your modifications, or 134 | for any such Derivative Works as a whole, provided Your use, 135 | reproduction, and distribution of the Work otherwise complies with 136 | the conditions stated in this License. 137 | 138 | 5. Submission of Contributions. Unless You explicitly state otherwise, 139 | any Contribution intentionally submitted for inclusion in the Work 140 | by You to the Licensor shall be under the terms and conditions of 141 | this License, without any additional terms or conditions. 142 | Notwithstanding the above, nothing herein shall supersede or modify 143 | the terms of any separate license agreement you may have executed 144 | with Licensor regarding such Contributions. 145 | 146 | 6. Trademarks. This License does not grant permission to use the trade 147 | names, trademarks, service marks, or product names of the Licensor, 148 | except as required for reasonable and customary use in describing the 149 | origin of the Work and reproducing the content of the NOTICE file. 150 | 151 | 7. Disclaimer of Warranty. Unless required by applicable law or 152 | agreed to in writing, Licensor provides the Work (and each 153 | Contributor provides its Contributions) on an "AS IS" BASIS, 154 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 155 | implied, including, without limitation, any warranties or conditions 156 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 157 | PARTICULAR PURPOSE. You are solely responsible for determining the 158 | appropriateness of using or redistributing the Work and assume any 159 | risks associated with Your exercise of permissions under this License. 160 | 161 | 8. Limitation of Liability. 
In no event and under no legal theory, 162 | whether in tort (including negligence), contract, or otherwise, 163 | unless required by applicable law (such as deliberate and grossly 164 | negligent acts) or agreed to in writing, shall any Contributor be 165 | liable to You for damages, including any direct, indirect, special, 166 | incidental, or consequential damages of any character arising as a 167 | result of this License or out of the use or inability to use the 168 | Work (including but not limited to damages for loss of goodwill, 169 | work stoppage, computer failure or malfunction, or any and all 170 | other commercial damages or losses), even if such Contributor 171 | has been advised of the possibility of such damages. 172 | 173 | 9. Accepting Warranty or Additional Liability. While redistributing 174 | the Work or Derivative Works thereof, You may choose to offer, 175 | and charge a fee for, acceptance of support, warranty, indemnity, 176 | or other liability obligations and/or rights consistent with this 177 | License. However, in accepting such obligations, You may act only 178 | on Your own behalf and on Your sole responsibility, not on behalf 179 | of any other Contributor, and only if You agree to indemnify, 180 | defend, and hold each Contributor harmless for any liability 181 | incurred by, or claims asserted against, such Contributor by reason 182 | of your accepting any such warranty or additional liability. 183 | 184 | END OF TERMS AND CONDITIONS 185 | 186 | APPENDIX: How to apply the Apache License to your work. 187 | 188 | To apply the Apache License to your work, attach the following 189 | boilerplate notice, with the fields enclosed by brackets "[]" 190 | replaced with your own identifying information. (Don't include 191 | the brackets!) The text should be enclosed in the appropriate 192 | comment syntax for the file format. 
We also recommend that a 193 | file or class name and description of purpose be included on the 194 | same "printed page" as the copyright notice for easier 195 | identification within third-party archives. 196 | 197 | Copyright [yyyy] [name of copyright owner] 198 | 199 | Licensed under the Apache License, Version 2.0 (the "License"); 200 | you may not use this file except in compliance with the License. 201 | You may obtain a copy of the License at 202 | 203 | http://www.apache.org/licenses/LICENSE-2.0 204 | 205 | Unless required by applicable law or agreed to in writing, software 206 | distributed under the License is distributed on an "AS IS" BASIS, 207 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 208 | See the License for the specific language governing permissions and 209 | limitations under the License. 210 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | # jaxgptoolbox 6 | 7 | This is a collection of basic geometry processing functions, constructed to work with [jax](https://github.com/google/jax)'s autodifferentiation feature for applications in machine learning. We split these functions into _non-differentiable_ ones in the `general` folder and differentiable ones in the `differentiable` folder. We also include some wrappers of third-party functions in the `external` folder for convenience. To use these utility functions, one can simply import this package and use it as follows: 8 | ``` 9 | import jaxgptoolbox as jgp 10 | import polyscope as ps 11 | V,F = jgp.read_mesh('path_to_OBJ') 12 | ps.init() 13 | ps.register_surface_mesh('my_mesh',V,F) 14 | ps.show() 15 | ``` 16 | 17 | ### Dependencies 18 | 19 | This library depends on [jax](https://github.com/google/jax) and some common Python libraries, such as [numpy](https://github.com/numpy/numpy) and [scipy](https://github.com/scipy/scipy). Our `demos` rely on [matplotlib](https://github.com/matplotlib/matplotlib) and [polyscope](https://polyscope.run/py/) for visualization. Some functions in the `external` folder depend on [libigl](https://libigl.github.io/libigl-python-bindings/). Please make sure to install all dependencies (for example, with [conda](https://docs.conda.io/projects/conda/en/latest/index.html)) before using the library. 20 | 21 | ### Contacts & Warnings 22 | 23 | The toolbox grew out of [Oded Stein](https://odedstein.com)'s and [Hsueh-Ti Derek Liu](https://www.dgp.toronto.edu/~hsuehtil/)'s private research codebase during their PhD studies. Some of these functions are neither fully tested nor optimized, so please use them with caution. If you are interested in contributing or notice any issues, please contact us (ostein@mit.edu, hsuehtil@cs.toronto.edu) or submit a pull request. 24 | 25 | ### License 26 | 27 | Consult the [LICENSE](LICENSE) file for details about the license of this project.
28 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .general.adjacency_edge_face import adjacency_edge_face 2 | from .general.adjacency_list_edge_face import adjacency_list_edge_face 3 | from .general.adjacency_list_face_face import adjacency_list_face_face 4 | from .general.adjacency_list_vertex_face import adjacency_list_vertex_face 5 | from .general.adjacency_list_vertex_vertex import adjacency_list_vertex_vertex 6 | from .general.adjacency_vertex_vertex import adjacency_vertex_vertex 7 | from .general.boundary_vertices import boundary_vertices 8 | from .general.cotmatrix import cotmatrix 9 | from .general.edges import edges 10 | from .general.edge_flaps import edge_flaps 11 | from .general.edges_with_mapping import edges_with_mapping 12 | from .general.find_index import find_index 13 | from .general.he_initialization import he_initialization 14 | from .general.knn_search import knn_search 15 | from .general.massmatrix import massmatrix 16 | from .general.mid_point_curve_simplification import mid_point_curve_simplification 17 | from .general.ordered_outline import ordered_outline 18 | from .general.outline import outline 19 | from .general.remove_unreferenced import remove_unreferenced 20 | from .general.sample_2D_grid import sample_2D_grid 21 | from .general.sdf_circle import sdf_circle 22 | from .general.sdf_cross import sdf_cross 23 | from .general.sdf_star import sdf_star 24 | from .general.sdf_triangle import sdf_triangle 25 | 26 | from .differentiable.angle_defect import angle_defect 27 | from .differentiable.angle_defect import angle_defect_intrinsic 28 | from .differentiable.cotangent_weights import cotangent_weights 29 | from .differentiable.dihedral_angles import dihedral_angles 30 | from .differentiable.dihedral_angles import dihedral_angles_from_normals 31 | from .differentiable.dotrow import dotrow 32 | from 
.differentiable.face_areas import face_areas 33 | from .differentiable.face_normals import face_normals 34 | from .differentiable.fit_rotations_cayley import fit_rotations_cayley 35 | from .differentiable.halfedge_lengths import halfedge_lengths 36 | from .differentiable.halfedge_lengths import halfedge_lengths_squared 37 | from .differentiable.normalize_unit_box import normalize_unit_box 38 | from .differentiable.normalizerow import normalizerow 39 | from .differentiable.normrow import normrow 40 | from .differentiable.ramp_smooth import ramp_smooth 41 | from .differentiable.tip_angles import tip_angles 42 | from .differentiable.tip_angles import tip_angles_intrinsic 43 | from .differentiable.vertex_areas import vertex_areas 44 | from .differentiable.vertex_normals import vertex_normals 45 | 46 | from .external.signed_distance import signed_distance 47 | from .external.read_mesh import read_mesh 48 | from .external.write_obj import write_obj -------------------------------------------------------------------------------- /demos/lipschitz_mlp/README.md: -------------------------------------------------------------------------------- 1 | # Learning Smooth Neural Functions via Lipschitz Regularization 2 | 3 | This is the demo code for the Lipschitz MLP: 4 | 5 | **Learning Smooth Neural Functions via Lipschitz Regularization** 6 | _Hsueh-Ti Derek Liu, Francis Williams, Alec Jacobson, Sanja Fidler, Or Litany_ 7 | SIGGRAPH (North America), 2022 8 | [[Project Page](https://nv-tlabs.github.io/lip-mlp/)] [[Preprint](https://www.dgp.toronto.edu/~hsuehtil/pdf/lipmlp.pdf)] 9 | 10 | ### Dependencies 11 | Our method depends on [JAX](https://github.com/google/jax) and some common python dependencies (e.g., numpy, tqdm, matplotlib, etc.). Some functions in the script, such as generating analytical signed distance functions, depend on other parts in the repository -- [jaxgptoolbox](https://github.com/ml-for-gp/jaxgptoolbox). 
12 | 13 | ### Repository Structure 14 | - `main_lipmlp.py` is the main training script. It trains a Lipschitz MLP to interpolate between the 2D signed distance functions of a star and a circle. To train the model from scratch, simply run 15 | ```bash 16 | python main_lipmlp.py 17 | ``` 18 | After training (~15 min on a CPU), you should see the interpolation results in `lipschitz_mlp_interpolation.mp4` and the model parameters in `lipschitz_mlp_params.pkl`. 19 | - `model.py` contains the Lipschitz MLP model, which can be used as follows: 20 | ```python 21 | model = lipmlp(hyper_params) # build the model 22 | params = model.initialize_weights() # initialize weights 23 | y = model.forward(params, latent_code, x) # forward pass 24 | ``` 25 | 26 | -------------------------------------------------------------------------------- /demos/lipschitz_mlp/lipschitz_mlp_interpolation.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/lipschitz_mlp/lipschitz_mlp_interpolation.mp4 -------------------------------------------------------------------------------- /demos/lipschitz_mlp/lipschitz_mlp_loss_history.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/lipschitz_mlp/lipschitz_mlp_loss_history.jpg -------------------------------------------------------------------------------- /demos/lipschitz_mlp/lipschitz_mlp_params.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/lipschitz_mlp/lipschitz_mlp_params.pkl -------------------------------------------------------------------------------- /demos/lipschitz_mlp/main_lipmlp.py:
-------------------------------------------------------------------------------- 1 | from model import * 2 | 3 | # implementation of "Learning Smooth Neural Functions via Lipschitz Regularization" by Liu et al. 2022 4 | if __name__ == '__main__': 5 | random.seed(1) 6 | 7 | # hyper parameters 8 | hyper_params = { 9 | "dim_in": 2, 10 | "dim_t": 1, 11 | "dim_out": 1, 12 | "h_mlp": [64,64,64,64,64], 13 | "step_size": 1e-4, 14 | "grid_size": 32, 15 | "num_epochs": 200000, 16 | "samples_per_epoch": 512 17 | } 18 | alpha = 1e-6 19 | 20 | # initialize a mlp 21 | model = lipmlp(hyper_params) 22 | params = model.initialize_weights() 23 | 24 | # optimizer 25 | opt_init, opt_update, get_params = optimizers.adam(step_size=hyper_params["step_size"]) 26 | opt_state = opt_init(params) 27 | 28 | # define loss function and update function 29 | def loss(params_, alpha, x_, y0_, y1_): 30 | out0 = model.forward(params_, np.array([0.0]), x_) # star when t = 0.0 31 | out1 = model.forward(params_, np.array([1.0]), x_) # circle when t = 1.0 32 | loss_sdf = np.mean((out0 - y0_)**2) + np.mean((out1 - y1_)**2) 33 | loss_lipschitz = model.get_lipschitz_loss(params_) 34 | return loss_sdf + alpha * loss_lipschitz 35 | 36 | @jit 37 | def update(epoch, opt_state, alpha, x_, y0_, y1_): 38 | params_ = get_params(opt_state) 39 | value, grads = value_and_grad(loss, argnums = 0)(params_, alpha, x_, y0_, y1_) 40 | opt_state = opt_update(epoch, grads, opt_state) 41 | return value, opt_state 42 | 43 | # training 44 | loss_history = onp.zeros(hyper_params["num_epochs"]) 45 | pbar = tqdm.tqdm(range(hyper_params["num_epochs"])) 46 | for epoch in pbar: 47 | # sample a bunch of random points 48 | x = np.array(random.rand(hyper_params["samples_per_epoch"], hyper_params["dim_in"])) 49 | y0 = jgp.sdf_star(x) 50 | y1 = jgp.sdf_circle(x) 51 | 52 | # update 53 | loss_value, opt_state = update(epoch, opt_state, alpha, x, y0, y1) 54 | loss_history[epoch] = loss_value 55 | pbar.set_postfix({"loss": loss_value}) 56 | 57 
| if epoch % 1000 == 0: # plot loss history every 1000 iter 58 | plt.close(1) 59 | plt.figure(1) 60 | plt.semilogy(loss_history[:epoch]) 61 | plt.title('Reconstruction loss + Lipschitz loss') 62 | plt.grid() 63 | plt.savefig("lipschitz_mlp_loss_history.jpg") 64 | 65 | # save final parameters 66 | params = get_params(opt_state) 67 | with open("lipschitz_mlp_params.pkl", 'wb') as handle: 68 | pickle.dump(params, handle, protocol=pickle.HIGHEST_PROTOCOL) 69 | 70 | # normalize weights during test time 71 | params_final = model.normalize_params(params) 72 | 73 | # save result as a video 74 | sdf_cm = mpl.colors.LinearSegmentedColormap.from_list('SDF', [(0,'#eff3ff'),(0.5,'#3182bd'),(0.5,'#31a354'),(1,'#e5f5e0')], N=256) 75 | 76 | # create video 77 | fig = plt.figure() 78 | x = jgp.sample_2D_grid(hyper_params["grid_size"]) # sample on unit grid for visualization 79 | def animate(t): 80 | plt.cla() 81 | out = model.forward_eval(params_final, np.array([t]), x) 82 | levels = onp.linspace(-0.5, 0.5, 21) 83 | im = plt.contourf(out.reshape(hyper_params['grid_size'],hyper_params['grid_size']), levels = levels, cmap=sdf_cm) 84 | plt.axis('equal') 85 | plt.axis("off") 86 | return im 87 | anim = animation.FuncAnimation(fig, animate, frames=np.linspace(0, 1, 50), interval=50) 88 | anim.save("lipschitz_mlp_interpolation.mp4") -------------------------------------------------------------------------------- /demos/lipschitz_mlp/model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../../../') 3 | import jaxgptoolbox as jgp 4 | 5 | import jax 6 | import jax.numpy as np 7 | from jax import jit, value_and_grad 8 | from jax.experimental import optimizers 9 | 10 | import numpy as onp 11 | import numpy.random as random 12 | import matplotlib as mpl 13 | import matplotlib.pyplot as plt 14 | import matplotlib.animation as animation 15 | import tqdm 16 | import pickle 17 | 18 | class lipmlp: 19 | def __init__(self, 
hyperParams): 20 | self.hyperParams = hyperParams 21 | 22 | def initialize_weights(self): 23 | """ 24 | Initialize the parameters of the Lipschitz mlp 25 | 26 | Inputs 27 | hyperParams: hyper parameter dictionary 28 | 29 | Outputs 30 | params_net: parameters of the network (weight, bias, initial lipschitz bound) 31 | """ 32 | def init_W(size_out, size_in): 33 | W = onp.random.randn(size_out, size_in) * onp.sqrt(2 / size_in) 34 | return np.array(W) 35 | sizes = self.hyperParams["h_mlp"] 36 | sizes.insert(0, self.hyperParams["dim_in"] + self.hyperParams["dim_t"]) 37 | sizes.append(self.hyperParams["dim_out"]) 38 | params_net = [] 39 | for ii in range(len(sizes) - 1): 40 | W = init_W(sizes[ii+1], sizes[ii]) 41 | b = np.zeros(sizes[ii+1]) 42 | c = np.max(np.sum(np.abs(W), axis=1)) 43 | params_net.append([W, b, c]) 44 | return params_net 45 | 46 | def weight_normalization(self, W, softplus_c): 47 | """ 48 | Lipschitz weight normalization based on the L-infinity norm 49 | """ 50 | absrowsum = np.sum(np.abs(W), axis=1) 51 | scale = np.minimum(1.0, softplus_c/absrowsum) 52 | return W * scale[:,None] 53 | 54 | def forward_single(self, params_net, t, x): 55 | """ 56 | Forward pass of a lipschitz MLP 57 | 58 | Inputs 59 | params_net: parameters of the network 60 | t: the input feature of the shape 61 | x: a query location in the space 62 | 63 | Outputs 64 | out: implicit function value at x 65 | """ 66 | # concatenate coordinate and latent code 67 | x = np.append(x, t) 68 | 69 | # forward pass 70 | for ii in range(len(params_net) - 1): 71 | W, b, c = params_net[ii] 72 | W = self.weight_normalization(W, jax.nn.softplus(c)) 73 | x = jax.nn.relu(np.dot(W, x) + b) 74 | 75 | # final layer 76 | W, b, c = params_net[-1] 77 | W = self.weight_normalization(W, jax.nn.softplus(c)) 78 | out = np.dot(W, x) + b 79 | return out[0] 80 | forward = jax.vmap(forward_single, in_axes=(None, None, None, 0), out_axes=0) 81 | 82 | def get_lipschitz_loss(self, params_net): 83 | """ 84 | This function 
computes the Lipschitz regularization 85 | """ 86 | loss_lip = 1.0 87 | for ii in range(len(params_net)): 88 | W, b, c = params_net[ii] 89 | loss_lip = loss_lip * jax.nn.softplus(c) 90 | return loss_lip 91 | 92 | def normalize_params(self, params_net): 93 | """ 94 | (Optional) After training, this function will clip network [W, b] based on learned lipschitz constants. Thus, one can use normal MLP forward pass during test time, which is a little bit faster. 95 | """ 96 | params_final = [] 97 | for ii in range(len(params_net)): 98 | W, b, c = params_net[ii] 99 | W = self.weight_normalization(W, jax.nn.softplus(c)) 100 | params_final.append([W, b]) 101 | return params_final 102 | 103 | def forward_eval_single(self, params_final, t, x): 104 | """ 105 | (Optional) this is a standard forward pass of a mlp. This is useful to speed up the performance during test time 106 | """ 107 | # concatenate coordinate and latent code 108 | x = np.append(x, t) 109 | 110 | # forward pass 111 | for ii in range(len(params_final) - 1): 112 | W, b = params_final[ii] 113 | x = jax.nn.relu(np.dot(W, x) + b) 114 | W, b = params_final[-1] # final layer 115 | out = np.dot(W, x) + b 116 | return out[0] 117 | forward_eval = jax.vmap(forward_eval_single, in_axes=(None, None, None, 0), out_axes=0) -------------------------------------------------------------------------------- /demos/nerf/nerf_interpolation.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/nerf/nerf_interpolation.mp4 -------------------------------------------------------------------------------- /demos/nerf/nerf_loss_history.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/nerf/nerf_loss_history.jpg 
# --------------------------------------------------------------------------------
# /demos/nerf/nerf_params.pkl:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/nerf/nerf_params.pkl
# --------------------------------------------------------------------------------
# /demos/nerf/test_nerf.py:
# --------------------------------------------------------------------------------
# Render a turntable video from the NeRF weights stored in nerf_params.pkl
# (written by train_nerf.py) and save it as nerf_interpolation.mp4.
from utils import *

if __name__ == "__main__":
    # orbit camera poses (generate_test_poses in utils.py: 120 angles over 360 degrees)
    poses = generate_test_poses()
    # only the image size and the focal length are needed; training poses are discarded
    imgs, _, focal = load_tiny_nerf('./tiny_nerf_data.npz')

    # hyper parameters
    # NOTE(review): weights are loaded from disk below, so only "n_pos_encode",
    # "n_samples_per_path", "near_plane" and "far_plane" are read while rendering.
    # "n_samples_per_path" is 128 here but 256 in train_nerf.py -- confirm intended.
    hyper_params = {
        "n_in": 3,
        "n_out": 4,
        "n_pos_encode": 6,
        "n_samples_per_path": 128,
        "near_plane": 2.0,
        "far_plane": 6.0,
        "h_mlp": [128,128,128],
        "step_size": 1e-4,
        "num_epochs": 1000,
        "batch_size": 200
    }

    # initialize a nerf network
    model = NeRF(hyper_params)
    # pickle can execute code on load: only use checkpoints produced locally
    with open('nerf_params.pkl', 'rb') as handle:
        params = pickle.load(handle)

    # create video
    n_imgs = poses.shape[0]
    img_H = imgs.shape[1]
    img_W = imgs.shape[2]
    fig = plt.figure()
    def animate(ii):
        # FuncAnimation callback: render frame ii with one ray per pixel
        plt.cla()
        orig, dirs = generate_rays_from_camera(img_H, img_W, focal, poses[ii])
        dirs = np.reshape(dirs, (-1, 3))

        # all H*W rays of the frame are rendered in a single vmapped call
        out = model.path_integral(params, orig, dirs)
        out = onp.reshape(onp.array(out), (img_H, img_W, 3))
        im = plt.imshow(out)
        return im
    anim = animation.FuncAnimation(fig, animate, frames=np.arange(n_imgs), interval=50)
    anim.save("nerf_interpolation.mp4")

# --------------------------------------------------------------------------------
# /demos/nerf/tiny_nerf_data.npz:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/nerf/tiny_nerf_data.npz
# --------------------------------------------------------------------------------
# /demos/nerf/train_nerf.py:
# --------------------------------------------------------------------------------
from utils import *

# Re-implementation of "NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis" by Mildenhall et al 2020
if __name__ == "__main__":
    imgs, poses, focal = load_tiny_nerf('./tiny_nerf_data.npz')
    # jax_imshow(imgs[0])

    # hyper parameters
    hyper_params = {
        "n_in": 3,
        "n_out": 4,
        "n_pos_encode": 6,
        "n_samples_per_path": 256,
        "near_plane": 2.0,
        "far_plane": 6.0,
        "h_mlp": [128,128,128,128,128,128],
        "step_size": 1e-4,
        "num_epochs": 300,
        # number of chunks each image's rays are split into (np.array_split),
        # not the number of rays per chunk
        "batch_size": 200
    }

    # initialize a nerf network
    model = NeRF(hyper_params)
    params = model.initialize_weights()

    # optimizer
    opt_init, opt_update, get_params = optimizers.adam(step_size=hyper_params["step_size"])
    opt_state = opt_init(params)

    def loss(params_, origin_, directions_, img_):
        """Mean squared error between rendered and ground-truth colors of a ray chunk."""
        out = model.path_integral(params_, origin_, directions_)
        out = np.clip(out, 0.0, 1.0)
        loss_val = np.mean((out - img_)**2)
        return loss_val

    @jit
    def update(step, opt_state, origin_, directions_, img_):
        """
        One Adam step. `step` must be the global iteration count: Adam uses it
        for bias correction, so passing a per-epoch index would be wrong.
        """
        params_ = get_params(opt_state)
        value, grads = value_and_grad(loss, argnums = 0)(params_, origin_, directions_, img_)
        opt_state = opt_update(step, grads, opt_state)
        return value, opt_state

    # training
    loss_history = onp.zeros(hyper_params["num_epochs"])
    pbar = tqdm.tqdm(range(hyper_params["num_epochs"]))
    # data for training
    n_imgs = imgs.shape[0]
    img_H = imgs.shape[1]
    img_W = imgs.shape[2]
    batch_size = hyper_params["batch_size"]
    step = 0  # global optimizer iteration counter
    # NOTE(review): loop nesting reconstructed from a copy that lost its indentation
    for epoch in pbar:
        for ii in range(n_imgs):
            # gradient step
            orig, dirs = generate_rays_from_camera(img_H, img_W, focal, poses[ii])
            dirs = np.reshape(dirs, (-1, 3))
            img = np.reshape(imgs[ii], (-1, 3))

            # split the image's rays into `batch_size` chunks
            img_batches = np.array_split(img, batch_size)
            dirs_batches = np.array_split(dirs, batch_size)
            for img_b, dirs_b in zip(img_batches, dirs_batches):
                loss_value, opt_state = update(step, opt_state, orig, dirs_b, img_b)
                step += 1

                # save loss: mean over all chunks of all images in this epoch
                loss_history[epoch] += loss_value / n_imgs / batch_size
        pbar.set_postfix({"loss": loss_history[epoch]})

        # plot loss history and checkpoint the parameters once per epoch
        plt.close(1)
        plt.figure(1)
        plt.semilogy(loss_history[:epoch + 1])
        plt.title('reconstruction loss')
        plt.grid()
        plt.savefig("nerf_loss_history.jpg")

        params = get_params(opt_state)
        with open("nerf_params.pkl", 'wb') as handle:
            pickle.dump(params, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # save final parameters
    params = get_params(opt_state)
    with open("nerf_params.pkl", 'wb') as handle:
        pickle.dump(params, handle, protocol=pickle.HIGHEST_PROTOCOL)
# --------------------------------------------------------------------------------
# /demos/nerf/utils.py:
# --------------------------------------------------------------------------------
import sys
sys.path.append('../../../')
import jaxgptoolbox as jgp
import numpy as onp
import jax
import jax.numpy as np
from jax import jit, value_and_grad
try:
    # current location; jax.experimental.optimizers was removed from newer JAX
    from jax.example_libraries import optimizers
except ImportError:  # older JAX releases
    from jax.experimental import optimizers
import pickle

from matplotlib import pyplot as plt
import matplotlib.animation as animation
import cv2
import tqdm


def load_tiny_nerf(path):
    """
    Load the tiny NeRF dataset stored as an .npz archive.

    Inputs
        path: path to the .npz file with keys "images", "poses", "focal"

    Outputs
        images: jax array, n_img x H x W x 3
        poses: jax array, n_img x 4 x 4 camera pose matrices
        focal: focal length as a Python float
    """
    data = onp.load(path)
    images = data["images"] # n_img x H x W x 3
    poses = data["poses"] # n_img x 4 x 4
    focal = float(data["focal"])
    return np.array(images), np.array(poses), focal


def jax_imshow(jax_img):
    """
    Show an H x W x 3 RGB image in an OpenCV window and wait for a key press.

    The image is upscaled by an integer factor so that its height is about 600 px.
    """
    img = onp.array(jax_img)
    img = img[:,:,[2,1,0]]  # RGB -> BGR, the channel order OpenCV expects
    height, width = img.shape[:2]
    scale = max(1, int(600. / height))  # never 0, which cv2.resize rejects
    dim = (width*scale, height*scale)  # cv2.resize takes dsize as (width, height)
    img_resize = cv2.resize(img, dim)
    cv2.imshow('image', img_resize)
    cv2.waitKey(0)


class NeRF:
    """
    Minimal NeRF: an MLP maps a positionally-encoded 3D point (plus the view
    direction at its last layer) to (r, g, b, density), and colors are obtained
    by numerical volume integration along camera rays.
    """
    def __init__(self, hyper_params):
        self.hyper_params = hyper_params

    def initialize_weights(self):
        """
        initialize network weights

        Layer sizes are [encoded input] + h_mlp + [n_out]. The second-to-last
        layer emits one extra output (volume density); the last layer receives
        the 3D view direction as three extra inputs and outputs RGB only.
        hyper_params is not modified, so this can be called repeatedly.

        Outputs
            params: list of [W, b] pairs, one per layer
        """
        # copy: the original code inserted into / appended to hyper_params["h_mlp"]
        hidden = list(self.hyper_params["h_mlp"])

        # add input dimension
        n_raw_in = self.hyper_params["n_in"] # point dim
        n_encode = self.hyper_params["n_pos_encode"] # positional encoding with sin/cos
        n_in = n_raw_in + n_raw_in * 2 * n_encode # "2" is due to sin/cos
        sizes = [n_in] + hidden + [self.hyper_params["n_out"]]

        # initialization network parameters
        params = []
        n_layers = len(sizes) - 1
        for ii in range(n_layers):
            if ii == n_layers - 1:
                # last layer only outputs RGB, n_out - 1
                # last layer inputs additional view dir, n_in + 3
                W, b = jgp.he_initialization(sizes[ii+1]-1, sizes[ii]+3)
            elif ii == n_layers - 2:
                # second last layer outputs additional volume density
                W, b = jgp.he_initialization(sizes[ii+1]+1, sizes[ii])
            else:
                W, b = jgp.he_initialization(sizes[ii+1], sizes[ii])
            params.append([W, b])
        return params

    def positional_encoding(self, x_in):
        """
        positional encoding: x -> [x, sin(2*pi*w*x), cos(2*pi*w*x)]
        where w = 2^[0, 1, ..., n_encode-1]
        """
        n_encode = self.hyper_params["n_pos_encode"]
        w = np.power(2., np.arange(0., n_encode))
        x = w[:,None] * x_in[None,:]
        x = x.flatten()
        return np.concatenate((x_in, np.sin(2*np.pi*x), np.cos(2*np.pi*x)))

    def activation(self, x):
        return jax.nn.leaky_relu(x)

    def forward_single(self, params, dir, x):
        """
        Evaluate the network at one point.

        Inputs
            params: list of [W, b] from initialize_weights
            dir: (3,) view direction, appended before the last layer
            x: (n_in,) query point

        Outputs
            (4,) array [r, g, b, sigma]; rgb in (0, 1) via sigmoid, sigma >= 0 via relu
        """
        x = self.positional_encoding(x)
        for ii in range(len(params) - 1):
            W, b = params[ii]
            x = self.activation(np.dot(W, x) + b)
        # second last layer outputs volume density (sigma)
        sigma = jax.nn.relu(x[0])
        x = x[1:]
        # last layer append view direction
        x = np.concatenate((x, dir))
        W, b = params[-1]
        rgb = jax.nn.sigmoid(np.dot(W, x) + b)
        return np.append(rgb, sigma)
    # vmapped over x only; in_axes entries correspond to (self, params, dir, x)
    forward = jax.vmap(forward_single, in_axes=(None, None, None, 0), out_axes=0)

    def path_integral_single(self, params, orig, dir):
        """
        Render the color of one ray by volume integration (NeRF paper Eq. 3).

        Inputs
            params: network parameters
            orig: (3,) ray origin
            dir: (3,) ray direction

        Outputs
            rgb_out: (3,) integrated color

        Note: n_samples_per_path must be >= 2 (the sample spacing is dists[1]-dists[0]).
        Checkpoints trained before the exclusive-sum fix below will render slightly
        differently; retrain for matching results.
        """
        # generate query locations
        n_samples = self.hyper_params["n_samples_per_path"]
        near = self.hyper_params["near_plane"]
        far = self.hyper_params["far_plane"]
        dists = np.linspace(near, far, n_samples)
        x = orig[None,:] + dir[None,:] * dists[:,None]

        # forward pass to get (r,g,b,density)
        rgbd = self.forward(params, dir, x)
        sigma = rgbd[:,3]

        # volume integral to get colors (see NeRF paper Eq.3)
        d = dists[1] - dists[0] # distance between samples
        # transmittance T_i = exp(-sum_{j<i} sigma_j * d): an *exclusive* cumulative
        # sum, so a sample does not attenuate its own contribution
        cumsum_density = np.concatenate((np.zeros(1), np.cumsum(sigma)[:-1]))
        T = np.exp(-cumsum_density*d)
        w = T * (1.-np.exp(-sigma*d)) # weights of each evaluation
        rgb_out = w.dot(rgbd[:,:3]) # integrate along the path
        return rgb_out
    # vmapped over ray directions; all rays share the same origin
    path_integral = jax.vmap(path_integral_single, in_axes=(None, None, None, 0), out_axes=0)


def generate_rays_from_camera(image_height, image_width, focal, pose):
    """
    generate camera rays in the world space for each pixel

    Inputs
        image_height: number of pixels in H direction
        image_width: number of pixels in W direction
        focal: focal length of a camera
        pose: 4x4 array of camera pose matrix

    Outputs
        ray_origin (3,) array of ray origin
        ray_directions HxWx3 array of ray directions
    """
    i, j = np.meshgrid(np.arange(image_width), np.arange(image_height), indexing="xy")
    k = -np.ones_like(i)  # the camera looks along its local -z axis
    i = (i - image_width * 0.5) / focal
    j = -(j - image_height * 0.5) / focal
    directions = np.stack([i, j, k], axis=-1)
    camera_matrix = pose[:3, :3]
    # rotate every pixel direction by the pose's 3x3 block
    ray_directions = np.einsum("ijl,kl", directions, camera_matrix)
    # ray_origins = np.broadcast_to(pose[:3, -1], ray_directions.shape)
    ray_origins = pose[:3, -1]
    return ray_origins, ray_directions


def angles_to_pose(theta, phi, radius):
    """
    Build a 4x4 camera pose at distance `radius` from the origin, rotated by
    `phi` (about x) and `theta` (about y), both in degrees, followed by a fixed
    axis permutation.
    """
    translation = lambda t: np.asarray(
        [
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 1, t],
            [0, 0, 0, 1],
        ]
    )
    rotation_phi = lambda phi: np.asarray(
        [
            [1, 0, 0, 0],
            [0, np.cos(phi), -np.sin(phi), 0],
            [0, np.sin(phi), np.cos(phi), 0],
            [0, 0, 0, 1],
        ]
    )
    rotation_theta = lambda th: np.asarray(
        [
            [np.cos(th), 0, -np.sin(th), 0],
            [0, 1, 0, 0],
            [np.sin(th), 0, np.cos(th), 0],
            [0, 0, 0, 1],
        ]
    )

    pose = translation(radius)
    pose = rotation_phi(phi / 180.0 * np.pi) @ pose
    pose = rotation_theta(theta / 180.0 * np.pi) @ pose
    return (
        np.array([
            [-1, 0, 0, 0],
            [0, 0, 1, 0],
            [0, 1, 0, 0],
            [0, 0, 0, 1]
        ]) @ pose
    )


def generate_test_poses():
    """Return a (120, 4, 4) numpy array of poses orbiting at elevation -30 deg, radius 4."""
    video_angle = onp.linspace(0.0, 360.0, 120, endpoint=False)
    poses = onp.zeros((len(video_angle), 4, 4))
    for ii in range(len(video_angle)):
        poses[ii] = angles_to_pose(video_angle[ii], -30, 4.0)
    return poses
# --------------------------------------------------------------------------------
# /demos/neural_sdf/ground truth (t=0).png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/neural_sdf/ground truth (t=0).png
# --------------------------------------------------------------------------------
# /demos/neural_sdf/ground truth (t=1).png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/neural_sdf/ground truth (t=1).png
# --------------------------------------------------------------------------------
# /demos/neural_sdf/loss_history.jpg:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/neural_sdf/loss_history.jpg
# --------------------------------------------------------------------------------
# /demos/neural_sdf/main.py:
# --------------------------------------------------------------------------------
# Train one MLP to represent two 2D signed distance fields selected by a latent
# code t: a star at t = 0 and a circle at t = 1.
from model import *

if __name__ == '__main__':
    random.seed(1)  # numpy.random (imported as `random` in model.py); also seeds weight init

    # hyper parameters
    hyper_params = {
        "dim_in": 2,
        "dim_t": 1,
        "dim_out": 1,
        "h_mlp": [64,64,64],
        "step_size": 1e-4,
        "grid_size": 32,
        "num_epochs": 50000,
        "samples_per_epoch": 512
    }

    # initialize a mlp
    model = mlp(hyper_params)
    params = model.initialize_weights()

    # optimizer
    opt_init, opt_update, get_params = optimizers.adam(step_size=hyper_params["step_size"])
    opt_state = opt_init(params)

    # define loss function and update function
    def loss(params_, x_, y0_, y1_):
        out0 = model.forward(params_, np.array([0.0]), x_) # star when t = 0.0
        out1 = model.forward(params_, np.array([1.0]), x_) # circle when t = 1.0
        loss_sdf = np.mean((out0 - y0_)**2) + np.mean((out1 - y1_)**2)
        return loss_sdf

    @jit
    def update(epoch, opt_state, x_, y0_, y1_):
        # one optimizer step per epoch, so `epoch` is also Adam's step counter
        params_ = get_params(opt_state)
        value, grads = value_and_grad(loss, argnums = 0)(params_, x_, y0_, y1_)
        opt_state = opt_update(epoch, grads, opt_state)
        return value, opt_state

    # training
    loss_history = onp.zeros(hyper_params["num_epochs"])
    pbar = tqdm.tqdm(range(hyper_params["num_epochs"])) # progress bar
    for epoch in pbar:
        # sample a bunch of random points (uniform in [0, 1)^dim_in)
        x = np.array(random.rand(hyper_params["samples_per_epoch"], hyper_params["dim_in"]))
        y0 = jgp.sdf_star(x) # target SDF values at x
        y1 = jgp.sdf_circle(x) # target SDF values at x

        # update network parameters
        loss_value, opt_state = update(epoch, opt_state, x, y0, y1)
        loss_history[epoch] = loss_value
        pbar.set_postfix({"loss": loss_value})

        if epoch % 1000 == 0: # plot loss history every 1000 iter
            plt.close(1)
            plt.figure(1)
            plt.semilogy(loss_history[:epoch])
            plt.title('Reconstruction loss')
            plt.grid()
            plt.savefig("loss_history.jpg")

    # save final parameters
    params = get_params(opt_state)
    with open("mlp_params.pkl", 'wb') as handle:
        pickle.dump(params, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # save results
    sdf_cm = mpl.colors.LinearSegmentedColormap.from_list('SDF', [(0,'#eff3ff'),(0.5,'#3182bd'),(0.5,'#31a354'),(1,'#e5f5e0')], N=256) # color map
    levels = onp.linspace(-0.5, 0.5, 21) # isoline
    x = jgp.sample_2D_grid(hyper_params["grid_size"]) # sample on unit grid for visualization

    # file names have no extension, so matplotlib uses its default format
    fig = plt.figure()
    y0 = jgp.sdf_star(x)
    im = plt.contourf(y0.reshape(hyper_params['grid_size'],hyper_params['grid_size']), levels = levels, cmap=sdf_cm)
    plt.axis('equal')
    plt.axis("off")
    plt.savefig('ground truth (t=0)')

    plt.clf()
    y0_pred = model.forward(params, np.array([0.0]), x)
    im = plt.contourf(y0_pred.reshape(hyper_params['grid_size'],hyper_params['grid_size']), levels = levels, cmap=sdf_cm)
    plt.axis('equal')
    plt.axis("off")
    plt.savefig('network output (t=0)')

    plt.clf()
    y1 = jgp.sdf_circle(x)
    im = plt.contourf(y1.reshape(hyper_params['grid_size'],hyper_params['grid_size']), levels = levels, cmap=sdf_cm)
    plt.axis('equal')
    plt.axis("off")
    plt.savefig('ground truth (t=1)')

    plt.clf()
    y1_pred = model.forward(params, np.array([1.0]), x)
    im = plt.contourf(y1_pred.reshape(hyper_params['grid_size'],hyper_params['grid_size']), levels = levels, cmap=sdf_cm)
    plt.axis('equal')
    plt.axis("off")
    plt.savefig('network output (t=1)')

# --------------------------------------------------------------------------------
# /demos/neural_sdf/mlp_params.pkl:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/neural_sdf/mlp_params.pkl
# --------------------------------------------------------------------------------
# /demos/neural_sdf/model.py:
# --------------------------------------------------------------------------------
import sys
sys.path.append('../../../')
import jaxgptoolbox as jgp

import jax
import jax.numpy as np
from jax import jit, value_and_grad
try:
    # current location; jax.experimental.optimizers was removed from newer JAX
    from jax.example_libraries import optimizers
except ImportError:  # older JAX releases
    from jax.experimental import optimizers

import numpy as onp
import numpy.random as random
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import tqdm
import pickle

class mlp:
    """Plain ReLU MLP that maps (x, t) -> scalar implicit function value."""
    def __init__(self, hyperParams):
        self.hyperParams = hyperParams

    def initialize_weights(self):
        """
        Initialize the parameters of the mlp

        Layer sizes are [dim_in + dim_t] + h_mlp + [dim_out]. hyperParams is not
        modified, so this can be called repeatedly.

        Inputs
            hyperParams: hyper parameter dictionary (read from self)

        Outputs
            params_net: parameters of the network (weight, bias) as a list of jax.numpy arrays
        """
        def init_W(size_out, size_in):
            # He initialization: standard deviation sqrt(2 / fan_in)
            W = onp.random.randn(size_out, size_in) * onp.sqrt(2 / size_in)
            return np.array(W)
        # build a fresh list; the original inserted into / appended to hyperParams["h_mlp"]
        sizes = ([self.hyperParams["dim_in"] + self.hyperParams["dim_t"]]
                 + list(self.hyperParams["h_mlp"])
                 + [self.hyperParams["dim_out"]])
        params_net = []
        for ii in range(len(sizes) - 1):
            W = init_W(sizes[ii+1], sizes[ii])
            b = np.zeros(sizes[ii+1])
            params_net.append([W, b])
        return params_net

    def forward_single(self, params_net, t, x):
        """
        Forward pass of a MLP

        Inputs
            params_net: parameters of the network
            t: the latent code of the shape
            x: a query location in the space

        Outputs
            out: implicit function value at x (signed distance in this case)
        """
        # concatenate coordinate and latent code
        x = np.append(x, t)

        # forward pass
        for ii in range(len(params_net) - 1):
            W, b = params_net[ii]
            x = jax.nn.relu(np.dot(W, x) + b)

        # final layer
        W, b = params_net[-1]
        out = np.dot(W, x) + b
        return out[0]

    # vectorize the "forward_single" function over query points x;
    # in_axes entries correspond to (self, params_net, t, x)
    forward = jax.vmap(forward_single, in_axes=(None, None, None, 0), out_axes=0)
# --------------------------------------------------------------------------------
# /demos/neural_sdf/network output (t=0).png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/neural_sdf/network output (t=0).png
# --------------------------------------------------------------------------------
# /demos/neural_sdf/network output (t=1).png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/neural_sdf/network output (t=1).png
# --------------------------------------------------------------------------------
# /demos/normal_driven_stylization/.polyscope.ini:
# --------------------------------------------------------------------------------
# {
#   "windowHeight": 720,
#   "windowPosX": 48,
#   "windowPosY": 53,
#   "windowWidth": 1280
# }
#
# --------------------------------------------------------------------------------
# /demos/normal_driven_stylization/imgui.ini:
# --------------------------------------------------------------------------------
# [Window][Debug##Default]
# Pos=60,60
# Size=400,400
# Collapsed=0
#
# [Window][Polyscope]
# Pos=10,10
# Size=305,156
# Collapsed=0
#
# [Window][Structures]
# Pos=10,186
# Size=305,524
# Collapsed=0
#
# [Window][Selection]
# Pos=770,30
# Size=500,98
# Collapsed=0
#
#
# --------------------------------------------------------------------------------
# /demos/normal_driven_stylization/loss_history.jpg:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/normal_driven_stylization/loss_history.jpg
# --------------------------------------------------------------------------------
# /demos/normal_driven_stylization/main.py:
# --------------------------------------------------------------------------------
import sys
sys.path.append('../../../')
import jaxgptoolbox as jgp

import jax
import jax.numpy as np
from jax import jit, value_and_grad
from jax.example_libraries import optimizers

import numpy as onp
import tqdm
import matplotlib.pyplot as plt
import polyscope as ps

def spokes_rims(V,F):
    """
    build spokes and rims edge indices and edge weights

    For every vertex k, gathers the three directed edges
    (F[f,0]->F[f,1], F[f,1]->F[f,2], F[f,2]->F[f,0]) of each face f listed in
    V2F[k], together with the cotangent weight of the corner opposite each edge.
    Rows are padded to a common length so that the result can be vmapped.

    Inputs:
        V: |V|x3 array of vertex list
        F: |F|x3 array of face list

    Outputs:
        Ek_all: (|V|, 3*max_Nk, 2) int32 jax array of edge indices per vertex, where
                max_Nk is the largest len(V2F[k]); padding rows hold the ghost index |V|
        Wk_all: (|V|, 3*max_Nk) float32 jax array of matching cotan weights, padded with 0
    """
    # presumably V2F[k] lists the faces incident to vertex k -- confirm in jgp
    V2F = jgp.vertex_face_adjacency_list(F)
    C = jgp.cotangent_weights(V,F)

    # construct spokes and rims for each vertex
    F = onp.array(F)
    nV = V.shape[0]

    # find the max number of neighbors
    max_Nk = 0
    for kk in range(nV):
        len_Nk = len(V2F[kk])
        if len_Nk > max_Nk:
            max_Nk = len_Nk
    print("max #neighbors is: %d" % max_Nk)

    Ek_np = onp.zeros((nV,max_Nk*3,2), dtype=onp.int32)
    Wk_np = onp.zeros((nV,max_Nk*3), dtype=onp.float32)
    for kk in range(nV):
        # get neighbors
        Nk = V2F[kk]

        # construct edge list
        Ek0 = onp.concatenate((F[Nk,0], F[Nk,1], F[Nk,2]))
        Ek1 = onp.concatenate((F[Nk,1], F[Nk,2], F[Nk,0]))
        Ek = onp.concatenate((Ek0[:,None], Ek1[:,None]), axis = 1)
        # pad with the ghost vertex index nV so that Ek becomes an array of size (max_Nk*3, 2)
        Ek_pad = onp.pad(Ek, pad_width = ((0, max_Nk*3 - Ek.shape[0]), (0, 0)), mode = 'constant', constant_values = nV)
        Ek_np[kk,:,:] = Ek_pad

        # get all the cotan weights; edge (F0,F1) is opposite corner 2, (F1,F2) corner 0,
        # (F2,F0) corner 1, hence the column order 2, 0, 1
        Wk = onp.concatenate((C[Nk,2],C[Nk,0],C[Nk,1]))
        # pad with 0 so that Wk becomes an array of size (max_Nk*3,)
        Wk_pad = onp.pad(Wk, pad_width = ((0, max_Nk*3 - Wk.shape[0])), mode = 'constant', constant_values = 0)
        Wk_np[kk,:] = Wk_pad

    return np.array(Ek_np), np.array(Wk_np)

def compute_target_normal_single(n):
    # snap a normal to the signed coordinate axis of its largest-magnitude component
    cIdx = np.argmax(np.abs(n))
    tar_n = np.zeros((3,), dtype=np.float32)
    tar_n = tar_n.at[cIdx].set(np.sign(n[cIdx]))
    return tar_n
compute_target_normals = jax.vmap(compute_target_normal_single, in_axes=(0), out_axes=0)

def normal_driven_energy_single(U,V,lam,n,tar_n,a,E,W):
    """
    Local energy of one vertex neighborhood: cotan-weighted edge rotation fit
    plus lam * a * |R n - tar_n|^2, with R the best-fit rotation.

    Inputs:
        U: |V|x3 deformed vertex positions (optimized)
        V: |V|x3 rest vertex positions
        lam: weight of the normal-alignment term
        n, tar_n: (3,) current normal and target normal of this vertex
        a: vertex area of this vertex
        E, W: padded edge indices and cotan weights from spokes_rims
    """
    # ghost vertex at the origin: padded edges (nV, nV) have zero length and zero weight
    v_ghost = np.array([[0.,0.,0.]])
    Ug = np.concatenate((U, v_ghost), axis = 0)
    Vg = np.concatenate((V, v_ghost), axis = 0)

    dV = (Vg[E[:,1],:] - Vg[E[:,0],:]).T
    dU = (Ug[E[:,1],:] - Ug[E[:,0],:]).T

    # orthogonal procrustes
    S = (dV * W).dot(dU.T) + lam*a*n[:,None].dot(tar_n[None,:])

    # fit rotation
    R = jgp.fit_rotations_cayley(S)

    # compute loss
    RdV_dU = R.dot(dV) - dU
    Rn_tar_n = R.dot(n) - tar_n
    return np.trace((RdV_dU * W).dot(RdV_dU.T)) + lam*a*Rn_tar_n.dot(Rn_tar_n)
# vmapped over vertices: U, V and lam are shared; n, tar_n, a, E, W are per vertex
normal_driven_energy = jax.vmap(normal_driven_energy_single, in_axes=(None,None,None,0,0,0,0,0), out_axes=0)

# define loss function and update function
def loss(U,V,lam,N,tar_N,VA,Ek_all,Wk_all):
    # mean of the per-vertex energies
    loss = normal_driven_energy(U,V,lam,N,tar_N,VA,Ek_all,Wk_all)
    return loss.mean()

# NOTE(review): `get_params` and `opt_update` are module globals created only inside
# the `if __name__ == "__main__":` block below; calling `update` after importing this
# module (instead of running it as a script) raises NameError.
@jit
def update(epoch, opt_state, V,lam,N,tar_N,VA,Ek_all,Wk_all):
    # one Adam step on the deformed vertex positions U
    U = get_params(opt_state)
    value, grads = value_and_grad(loss, argnums = 0)(U,V,lam,N,tar_N,VA,Ek_all,Wk_all)
    opt_state = opt_update(epoch, grads, opt_state)
    return value, opt_state

if __name__ == "__main__":
    # hyper parameters
    hyper_params = {
        "step_size": 1e-4,
        "num_epochs": 1000,
    }

    V,F = jgp.read_mesh("./spot.obj")
    N = jgp.vertex_normals(V,F)
    VA = jgp.vertex_areas(V,F)
    Ek_all, Wk_all = spokes_rims(V,F)
    tar_N = compute_target_normals(N)

    # start the optimization from the input mesh
    U = V.copy()
    lam = 1.0

    # optimizer
    opt_init, opt_update, get_params = optimizers.adam(step_size=hyper_params["step_size"])
    opt_state = opt_init(U)

    # training
    loss_history = onp.zeros(hyper_params["num_epochs"])
    pbar = tqdm.tqdm(range(hyper_params["num_epochs"]))
    for epoch in pbar:
        loss_value, opt_state = update(epoch, opt_state, V,lam,N,tar_N,VA,Ek_all,Wk_all)
        loss_history[epoch] = loss_value
        pbar.set_postfix({"loss": loss_value})

    U = get_params(opt_state)
    jgp.write_obj("opt.obj", U,F)

    plt.semilogy(loss_history)
    plt.title('normal driven energy')
    plt.grid()
    plt.savefig("loss_history.jpg")

    # interactive side-by-side view of the input and optimized meshes
    ps.init()
    ps.register_surface_mesh('input mesh',V,F)
    ps.register_surface_mesh('optimized mesh',U,F)
    ps.show()




# --------------------------------------------------------------------------------
# /differentiable/angle_defect.py:
# --------------------------------------------------------------------------------
import jax.numpy as np
from .tip_angles import tip_angles
from ..general.boundary_vertices import boundary_vertices
from jax import jit

# TODO: these are not able to @jit yet
def angle_defect(V,F,b=None,n=None):
    '''
    ANGLE_DEFECT computes the angle defect (integrated Gauss curvature) at each
    vertex

    Inputs:
        V: (|V|,3) numpy ndarray of vertex positions.
        F: (|F|,3) numpy ndarray of face indices, must be static.
        b: (|b|,) numpy ndarray of boundary vertex indices (they have 0 defect).
           Will be computed if not provided, at the price of JITability.
        n: number of vertices in the mesh.
           Will be computed if not provided, at the price of JITability.

    Outputs:
        k: (n,) array of angle defect at each vertex; vertices that appear in no
           face get the full 2*pi (bincount is padded to length n)
    '''
    if b is None:
        bv = boundary_vertices(F)
    else:
        bv = b

    if n is None:
        nv,_ = V.shape
    else:
        nv = n

    return angle_defect_intrinsic(F,tip_angles(V,F),bv,nv)

def angle_defect_intrinsic(F,A,b=np.empty(0,dtype=int),n=None):
    '''
    ANGLE_DEFECT_INTRINSIC computes the angle defect (integrated Gauss
    curvature) at each vertex given intrinsic tip angles

    Inputs:
        F: (|F|,3) numpy ndarray of face indices
        A: (|F|,3) numpy ndarray of tip angles at each corner of each face (in [0,pi))
        b: (|b|,) numpy ndarray of boundary vertex indices (they have 0 defect)
        n: number of vertices in the mesh.
           Will be computed if not provided, at the price of JITability.

    Outputs:
        k: (n,) array of angle defect at each vertex, 2*pi minus the sum of the
           incident tip angles, set to 0 at the indices in b
    '''

    if n is None:
        nv = F.max() + 1
    else:
        nv = n

    k = 2.*np.pi - np.bincount(np.ravel(F), weights=np.ravel(A), length=nv)

    # boundary vertices are defined to have zero defect (this was commented out,
    # which silently ignored `b` despite the documented contract)
    b = np.asarray(b).astype(np.int32)  # also turns an empty list into a valid index array
    if b.size > 0:
        k = k.at[b].set(0.)

    return k

# --------------------------------------------------------------------------------
# /differentiable/cotangent_weights.py:
# --------------------------------------------------------------------------------
import jax.numpy as np
from jax import jit

@jit
def cotangent_weights(V,F):
    """
    computes cotangent weight for each half edge

    C[f,i] = cot(angle at corner i of face f) / 2, i.e. the half-cotangent weight
    of the edge opposite corner i: column 0 -> edge (F1,F2), column 1 -> edge
    (F2,F0), column 2 -> edge (F0,F1).

    Input:
        V (|V|,3) array of vertex positions
        F (|F|,3) array of face indices
    Output:
        C (|F|,3) array of cotangent weights
    """
    i0 = F[:,0]
    i1 = F[:,1]
    i2 = F[:,2]
    # l_i is the length of the edge opposite corner i
    l0 = np.sqrt(np.sum(np.power(V[i1,:] - V[i2,:],2),1))
    l1 = np.sqrt(np.sum(np.power(V[i2,:] - V[i0,:],2),1))
    l2 = np.sqrt(np.sum(np.power(V[i0,:] - V[i1,:],2),1))

    # Heron's formula for area (loses precision on needle-shaped triangles)
    s = (l0 + l1 + l2) / 2.
    area = np.sqrt(s * (s-l0) * (s-l1) * (s-l2))

    # cotangent weights: law of cosines gives (l1^2+l2^2-l0^2)/(4*area) = cot(angle 0);
    # dividing by 8 instead of 4 yields cot/2
    c0 = (l1*l1 + l2*l2 - l0*l0) / area / 8
    c1 = (l0*l0 + l2*l2 - l1*l1) / area / 8
    c2 = (l0*l0 + l1*l1 - l2*l2) / area / 8
    C = np.concatenate((c0[:,None], c1[:,None], c2[:,None]), axis = 1)
    return C
# --------------------------------------------------------------------------------
# /differentiable/dihedral_angles.py:
# --------------------------------------------------------------------------------
import jax.numpy as np
from .face_normals import face_normals
from .dotrow import dotrow
from ..general.adjacency_list_edge_face import adjacency_list_edge_face
from jax import jit

# TODO: these are not able to @jit yet
def dihedral_angles(V, F):
    '''
    DIHEDRAL_ANGLES computes the dihedral angles of a mesh

    Inputs:
        V: (|V|,3) numpy ndarray of vertex positions
        F: (|F|,3) numpy ndarray of face indices

    Outputs:
        dihedral_angles: (|E|,) numpy array of dihedral angles (0 ~ pi);
                         pi means the two adjacent faces are coplanar
        E: (|E|,2) numpy array of edge indices
    '''
    # TODO: double check whether this is differentiable

    FN = face_normals(V,F)
    E2F, E = adjacency_list_edge_face(F)
    return dihedral_angles_from_normals(FN,E2F), E

def dihedral_angles_from_normals(N, E2F):
    '''
    DIHEDRAL_ANGLES_FROM_NORMALS computes the dihedral angles of a mesh, given
    precomputed normals and edge-to-face map

    Inputs:
        N: (|F|,3) numpy ndarray of unit face normals
        E2F: (|E|,2) integer array; row e holds the indices of the two faces
             adjacent to edge e, for example as produced by adjacency_list_edge_face

    Outputs:
        dihedral_angles: (|E|,) numpy array of dihedral angles at each edge,
                         pi - arccos(n1 . n2), so coplanar faces give pi
    '''

    # clip guards arccos against dot products slightly outside [-1, 1] from rounding
    dotN = dotrow(N[E2F[:,0],:], N[E2F[:,1],:]).clip(-1,1)
    return np.pi - np.arccos(dotN)

# --------------------------------------------------------------------------------
# /differentiable/dotrow.py
# --------------------------------------------------------------------------------
import jax.numpy as np
from jax import jit


@jit
def dotrow(X, Y):
    '''
    DOTROW computes the row-wise dot product of two matrices

    Inputs:
        X: (n,m) numpy ndarray
        Y: (n,m) numpy ndarray

    Outputs:
        d: (n,) numpy array with d[i] = dot(X[i,:], Y[i,:])
    '''
    return np.sum(X * Y, axis=1)


# --------------------------------------------------------------------------------
# /differentiable/face_areas.py
# --------------------------------------------------------------------------------
import jax.numpy as np
from jax import jit


@jit
def face_areas(V, F):
    """
    FACE_AREAS computes the area of every face of a triangle mesh

    Input:
        V (|V|,3) numpy array of vertex positions
        F (|F|,3) numpy array of face indices
    Output:
        FA (|F|,) numpy array of face areas
    """
    vec1 = V[F[:, 1], :] - V[F[:, 0], :]
    vec2 = V[F[:, 2], :] - V[F[:, 0], :]
    # half the cross product of two triangle edges has the triangle area as its length
    FN = np.cross(vec1, vec2) / 2
    FA = np.sqrt(np.sum(FN**2, 1))
    return FA


# --------------------------------------------------------------------------------
# /differentiable/face_normals.py
# --------------------------------------------------------------------------------
import jax.numpy as np
from .normalizerow import normalizerow
from jax import jit


@jit
def face_normals(V, F):
    '''
    FACE_NORMALS computes the unit normal of every face of a triangle mesh

    Inputs:
        V: (|V|,3) numpy ndarray of vertex positions
        F: (|F|,3) numpy ndarray of face indices

    Outputs:
        FN_normalized: (|F|,3) numpy ndarray of unit face normals, oriented by
            the right-hand rule on each face's vertex order. Zero-area faces
            yield NaN normals (0/0 in normalizerow).
    '''
    vec1 = V[F[:, 1], :] - V[F[:, 0], :]
    vec2 = V[F[:, 2], :] - V[F[:, 0], :]
    FN = np.cross(vec1, vec2) / 2
    FN_normalized = normalizerow(FN)
    return FN_normalized


# --------------------------------------------------------------------------------
# /differentiable/fit_rotations_cayley.py
# --------------------------------------------------------------------------------
import jax.numpy as np


def fit_rotations_cayley(S, R=np.eye(3), num_iters=3):
    """
    FIT_ROTATIONS_CAYLEY returns the rotation closest to a cross-covariance
    matrix S, i.e. the rotation R that maximizes np.trace(S @ R). This method
    is based on "Fast Updates for Least-Squares Rotational Alignment" by
    Zhang et al. 2021.

    Input:
        S: 3x3 array of the cross-covariance matrix
        R: (optional) 3x3 initial rotation matrix (default identity)
        num_iters: (optional) number of Cayley updates to apply (default 3,
            which empirically seems enough)

    Output:
        R: 3x3 rotation matrix which maximizes np.trace(S @ R)

    Warning:
        - S is not validated; if it is not a valid cross-covariance matrix,
          the output may not be a rotation matrix
        - there is no convergence-based stopping criterion yet: exactly
          num_iters updates are always applied
    """
    eye = np.eye(3)
    for _ in range(num_iters):
        MM = S @ R
        z = cayley_step(MM)
        s = z.dot(z)
        z_col = np.expand_dims(z, 1)
        zzT = z_col @ z_col.T
        # skew-symmetric cross-product matrix of z
        Z = np.array([[0., -z[2], z[1]], [z[2], 0., -z[0]], [-z[1], z[0], 0.]])
        # apply the rotation given by the Cayley parameters z
        R = 1. / (1. + s) * R @ ((1. - s) * eye + 2 * zzT + 2 * Z)
    return R


def cayley_step(M):
    """
    CAYLEY_STEP computes the Cayley-parameter update z for the current
    product M = S @ R (see fit_rotations_cayley)

    Input:
        M: 3x3 array
    Output:
        z: (3,) array of Cayley parameters
    """
    m = np.array([M[1, 2] - M[2, 1], M[2, 0] - M[0, 2], M[0, 1] - M[1, 0]])
    i, j, k = 0, 1, 2
    two_lam0 = 2 * M[i, i] + np.abs(M[i, j] + M[j, i]) + np.abs(M[i, k] + M[k, i])
    i, j, k = 1, 2, 0
    two_lam1 = 2 * M[i, i] + np.abs(M[i, j] + M[j, i]) + np.abs(M[i, k] + M[k, i])
    i, j, k = 2, 0, 1
    two_lam2 = 2 * M[i, i] + np.abs(M[i, j] + M[j, i]) + np.abs(M[i, k] + M[k, i])
    two_lam = np.maximum(np.maximum(two_lam0, two_lam1), two_lam2)
    t = np.trace(M)
    gs = np.maximum(t, two_lam - t)
    H = M + M.T - (t + np.sqrt(gs * gs + m.dot(m))) * np.eye(3)

    # solve H z = -m with Cramer's rule; the arrays below hold H's columns as
    # rows (with column i replaced by -m), which is fine since det(A) == det(A.T)
    h_inv = 1. / np.linalg.det(H)
    d0 = np.linalg.det(np.array([-m, H[:, 1], H[:, 2]]))
    d1 = np.linalg.det(np.array([H[:, 0], -m, H[:, 2]]))
    d2 = np.linalg.det(np.array([H[:, 0], H[:, 1], -m]))
    return h_inv * np.array([d0, d1, d2])


# --------------------------------------------------------------------------------
# /differentiable/halfedge_lengths.py
# --------------------------------------------------------------------------------
import jax.numpy as np
from .dotrow import dotrow
from jax import jit


@jit
def halfedge_lengths(V, F):
    '''
    HALFEDGE_LENGTHS computes the lengths of all halfedges in the mesh

    Inputs:
        V: (|V|,3) numpy ndarray of vertex positions
        F: (|F|,3) numpy ndarray of face indices

    Outputs:
        l: (|F|,3) numpy ndarray of halfedge lengths.
            Our halfedge convention identifies each halfedge by the index
            of the face and the opposite vertex within the face:
            (face, opposite vertex)
    '''
    return np.sqrt(halfedge_lengths_squared(V, F))


@jit
def halfedge_lengths_squared(V, F):
    '''
    HALFEDGE_LENGTHS_SQUARED computes the squared lengths of all halfedges in
    the mesh (this is often preferable to the plain lengths, since it is
    easier to differentiate through)

    Inputs:
        V: (|V|,3) numpy ndarray of vertex positions
        F: (|F|,3) numpy ndarray of face indices

    Outputs:
        l_sq: (|F|,3) numpy ndarray of squared halfedge lengths.
            Our halfedge convention identifies each halfedge by the index
            of the face and the opposite vertex within the face:
            (face, opposite vertex)
    '''
    # halfedge i is the edge opposite corner i
    he0 = V[F[:, 2], :] - V[F[:, 1], :]
    he1 = V[F[:, 0], :] - V[F[:, 2], :]
    he2 = V[F[:, 1], :] - V[F[:, 0], :]

    lhe0 = np.expand_dims(dotrow(he0, he0), axis=1)
    lhe1 = np.expand_dims(dotrow(he1, he1), axis=1)
    lhe2 = np.expand_dims(dotrow(he2, he2), axis=1)

    l_sq = np.concatenate((lhe0, lhe1, lhe2), axis=1)

    return l_sq


# --------------------------------------------------------------------------------
# /differentiable/normalize_unit_box.py
# --------------------------------------------------------------------------------
import jax.numpy as np
from jax import jit


@jit
def normalize_unit_box(V, margin=0.0):
    """
    NORMALIZE_UNIT_BOX normalizes a set of points into the unit box with a
    user-specified margin. The points are translated so the bounding-box
    minimum sits at the origin, scaled uniformly so the largest coordinate is
    1 (aspect ratio is preserved), then shrunk about 0.5 by (1 - 2*margin).

    Input:
        V: (n,3) numpy array of point locations
        margin: a constant user-specified margin
    Output:
        V: (n,3) numpy array of point locations within [margin, 1-margin]
    """
    V = V - V.min(0)
    V = V / V.max()
    V = V - 0.5
    V = V * (1.0 - margin * 2)
    V = V + 0.5
    return V


# --------------------------------------------------------------------------------
# /differentiable/normalizerow.py
# --------------------------------------------------------------------------------
import jax.numpy as np
from .normrow import normrow
from jax import jit


@jit
def normalizerow(X):
    """
    NORMALIZEROW normalizes each row of an array to unit l2-norm

    Input:
        X: (n,m) numpy array
    Output:
        X_normalized: (n,m) row-normalized numpy array (zero rows become NaN)
    """
    l2Norm = normrow(X)
    X_normalized = X / (l2Norm.reshape(X.shape[0], 1))
    return X_normalized


# --------------------------------------------------------------------------------
# /differentiable/normrow.py
# --------------------------------------------------------------------------------
import jax.numpy as np
from .dotrow import dotrow
from jax import jit


@jit
def normrow(X):
    """
    NORMROW computes the l2-norm of each row of an array

    Input:
        X: (n,m) numpy array
    Output:
        nX: (n,) numpy array of the l2 norm of each row of X
    """
    return np.sqrt(dotrow(X, X))


# --------------------------------------------------------------------------------
# /differentiable/ramp_smooth.py
# --------------------------------------------------------------------------------
import jax.numpy as np
from jax import jit


@jit
def ramp_smooth(d, d0=1.0):
    """
    RAMP_SMOOTH smooth ramp function such that:
        - 1 when d/d0 >= 1
        - -1 when d/d0 <= -1
        - smooth monotone transition when -1 < d/d0 < 1
    (following the notation of "Curl-Noise for Procedural Fluid Flow",
    Bridson et al. 2007)

    Input:
        d   distance of a query point to its closest point on the boundary
        d0  scale of the decay
    Output:
        scale   ramp scaling factor
    """
    r = d / d0
    # odd quintic with p(1) = 1 and p'(+-1) = 0; outside [-1,1] it is clamped
    val = 15. / 8. * r - 10. / 8. * (r**3) + 3. / 8. * (r**5)
    return np.minimum(np.abs(val), 1.0) * np.sign(val)


# --------------------------------------------------------------------------------
# /differentiable/tip_angles.py
# --------------------------------------------------------------------------------
import jax.numpy as np
from jax import jit
from .halfedge_lengths import halfedge_lengths_squared


@jit
def tip_angles(V, F):
    '''
    TIP_ANGLES computes the tip angles of a mesh (the angle at every corner)

    Inputs:
        V: (|V|,3) numpy ndarray of vertex positions
        F: (|F|,3) numpy ndarray of face indices

    Outputs:
        A: (|F|,3) numpy ndarray of tip angles at each corner of each face (in [0,pi])
    '''
    return tip_angles_intrinsic(halfedge_lengths_squared(V, F))


def tip_angles_intrinsic(l_sq):
    '''
    TIP_ANGLES_INTRINSIC computes the tip angles of a mesh (the angle at every
    corner) given squared halfedge lengths

    Inputs:
        l_sq: (|F|,3) numpy ndarray of squared halfedge lengths.
            Our halfedge convention identifies each halfedge by the index
            of the face and the opposite vertex within the face:
            (face, opposite vertex)

    Outputs:
        A: (|F|,3) numpy ndarray of tip angles at each corner of each face (in [0,pi])
    '''
    # use the law of cosines
    a_sq = np.expand_dims(l_sq[:, 0], axis=1)
    b_sq = np.expand_dims(l_sq[:, 1], axis=1)
    c_sq = np.expand_dims(l_sq[:, 2], axis=1)
    a = np.sqrt(a_sq)
    b = np.sqrt(b_sq)
    c = np.sqrt(c_sq)

    # clip guards arccos against round-off slightly outside [-1,1]
    cos_angles = np.concatenate((
        (b_sq + c_sq - a_sq) / (2. * b * c),
        (a_sq + c_sq - b_sq) / (2. * a * c),
        (a_sq + b_sq - c_sq) / (2. * a * b)
    ), axis=1).clip(-1, 1)

    return np.arccos(cos_angles)


# --------------------------------------------------------------------------------
# /differentiable/vertex_areas.py
# --------------------------------------------------------------------------------
import jax.numpy as np
from jax import jit


@jit
def vertex_areas(V, F):
    """
    VERTEX_AREAS computes the mixed Voronoi area of every vertex

    Each triangle's area is split among its corners by the Voronoi regions of
    its circumcenter. For an obtuse triangle, the obtuse corner instead gets
    half of the triangle area and the two other corners a quarter each.

    Input:
        V (|V|,3) numpy array of vertex positions
        F (|F|,3) numpy array of face indices
    Output:
        VA (|V|,) numpy array of vertex areas (per-corner areas summed over
            all faces incident to each vertex)
    """
    # l1, l2, l3: lengths of the edges opposite corners 0, 1, 2
    l1 = np.sqrt(np.sum((V[F[:, 1], :] - V[F[:, 2], :])**2, 1))
    l2 = np.sqrt(np.sum((V[F[:, 2], :] - V[F[:, 0], :])**2, 1))
    l3 = np.sqrt(np.sum((V[F[:, 0], :] - V[F[:, 1], :])**2, 1))

    # cosine of the interior angle at corners 0, 1, 2 (law of cosines)
    cos1 = (l3**2 + l2**2 - l1**2) / (2 * l2 * l3)
    cos2 = (l1**2 + l3**2 - l2**2) / (2 * l1 * l3)
    cos3 = (l1**2 + l2**2 - l3**2) / (2 * l1 * l2)

    cosMat = np.concatenate((cos1[:, None], cos2[:, None], cos3[:, None]), axis=1)
    lMat = np.concatenate((l1[:, None], l2[:, None], l3[:, None]), axis=1)
    # proportional to the barycentric coordinates of the circumcenter
    barycentric = cosMat * lMat
    normalized_barycentric = barycentric / np.sum(barycentric, 1)[:, None]
    # Heron's formula
    areas = 0.25 * np.sqrt((l1 + l2 - l3) * (l1 - l2 + l3) * (-l1 + l2 + l3) * (l1 + l2 + l3))
    # partArea[:, i]: area of the sub-triangle formed by the circumcenter and
    # the edge opposite corner i
    partArea = normalized_barycentric * areas[:, None]

    # each corner receives half of the two sub-triangles touching it
    quads = np.concatenate((
        ((partArea[:, 1] + partArea[:, 2]) * 0.5)[:, None],
        ((partArea[:, 0] + partArea[:, 2]) * 0.5)[:, None],
        ((partArea[:, 0] + partArea[:, 1]) * 0.5)[:, None],
    ), axis=1)

    # Obtuse triangles: override with the fixed 1/2, 1/4, 1/4 split. The mask
    # is applied per face with np.where, which also keeps this jit-compatible.
    obtuse_shares = ((0.5, 0.25, 0.25), (0.25, 0.5, 0.25), (0.25, 0.25, 0.5))
    for cos_corner, shares in zip((cos1, cos2, cos3), obtuse_shares):
        obtuse = (cos_corner < 0)[:, None]
        quads = np.where(obtuse, areas[:, None] * np.array(shares)[None, :], quads)

    VA = np.zeros((V.shape[0],))
    VA = VA.at[F.flatten()].add(quads.flatten())
    return VA


# --------------------------------------------------------------------------------
# /differentiable/vertex_normals.py
# --------------------------------------------------------------------------------
import jax.numpy as np
from jax import jit
from .normalizerow import normalizerow


@jit
def vertex_normals(V, F):
    """
    VERTEX_NORMALS computes face-area-weighted unit vertex normals

    Input:
        V |V|x3 numpy array of vertex positions
        F |F|x3 numpy array of face indices

    Output:
        N |V|x3 array of unit vertex normals (weighted by face areas)
    """
    vec1 = V[F[:, 1], :] - V[F[:, 0], :]
    vec2 = V[F[:, 2], :] - V[F[:, 0], :]
    # half the cross product has the face area as its length, so summing
    # these vectors weights each face normal by its area
    FN_unnormalized = np.cross(vec1, vec2) / 2

    # accumulate in the face-normal dtype instead of forcing float32, so
    # double-precision inputs keep their precision
    VN = np.zeros((V.shape[0], 3), dtype=FN_unnormalized.dtype)
    for corner in range(3):
        VN = VN.at[F[:, corner]].add(FN_unnormalized)
    return normalizerow(VN)


# --------------------------------------------------------------------------------
# /external/read_mesh.py
# --------------------------------------------------------------------------------
import jax.numpy as np
import numpy as onp
import igl


def read_mesh(path):
    '''
    READ_MESH is a jax wrapper for igl.read_triangle_mesh: it reads a triangle
    mesh from path and returns its vertices and faces as jax arrays
    '''
    V, F = igl.read_triangle_mesh(path)
    return np.array(V), np.array(F)


# --------------------------------------------------------------------------------
# /external/signed_distance.py
# --------------------------------------------------------------------------------
import jax.numpy as np
import numpy as onp
import igl


def signed_distance(P, V, F):
    '''
    SIGNED_DISTANCE computes the signed distance from given points to a mesh

    Inputs:
        P: (|P|,3) numpy ndarray of point positions
        V: (|V|,3) numpy ndarray of vertex positions
        F: (|F|,3) numpy ndarray of face indices

    Outputs:
        S: (|P|,) numpy array of signed distances
        pF:(|P|,) numpy array of projected face indices
        pV:(|P|,3) numpy array of projected point locations

    Notes:
        Not differentiable: the computation runs in libigl's C++
        implementation on numpy copies of the inputs.
    '''
    P_np = onp.array(P)
    V_np = onp.array(V)
    F_np = onp.array(F)
    S, pF, pV = igl.signed_distance(P_np, V_np, F_np)

    return np.array(S), np.array(pF), np.array(pV)


# --------------------------------------------------------------------------------
# /external/write_obj.py
# --------------------------------------------------------------------------------
import jax.numpy as np
import numpy as onp
import igl


def write_obj(path, V, F):
    '''
    WRITE_OBJ is a jax wrapper for igl.write_obj: converts V and F to numpy
    arrays and writes them to an .obj file at path
    '''
    igl.write_obj(path, onp.array(V), onp.array(F))


# --------------------------------------------------------------------------------
# /general/adjacency_edge_face.py
# --------------------------------------------------------------------------------
import numpy as np
import scipy
import scipy.sparse
import jax
from .edges_with_mapping import edges_with_mapping


def adjacency_edge_face(F):
    '''
    ADJACENCY_EDGE_FACE computes the edge-face adjacency matrix

    Input:
        F (|F|,3) numpy array of face indices
    Output:
        E2F (|E|,|F|) scipy sparse csr matrix; E2F[e,f] is nonzero iff edge e
            is an edge of face f
        E (|E|,2) jax array of edge indices
    '''
    E, F2E = edges_with_mapping(F)
    # transpose before flattening so all faces' first halfedges come first,
    # then all second halfedges, ... matching the tiled face indices below
    IC = F2E.T.reshape(F.shape[1] * F.shape[0])

    row = IC
    col = np.tile(np.arange(F.shape[0]), 3)
    # builtin int: the np.int alias was removed in NumPy 1.24
    val = np.ones(len(IC), dtype=int)
    E2F = scipy.sparse.coo_matrix((val, (row, col)), shape=(E.shape[0], F.shape[0])).tocsr()

    return E2F, jax.numpy.asarray(E)


# --------------------------------------------------------------------------------
# /general/adjacency_list_edge_face.py
# --------------------------------------------------------------------------------
import numpy as np
import jax
from .adjacency_edge_face import adjacency_edge_face


def adjacency_list_edge_face(F):
    '''
    ADJACENCY_LIST_EDGE_FACE computes the edge-face adjacency list

    Input:
        F (|F|,3) numpy array of face indices
    Output:
        adjList (|E|,2) jax array; adjList[e] holds the (ascending) indices of
            the faces adjacent to edge e. A boundary edge has a single face,
            which is repeated in both columns.
        E (|E|,2) jax array of edge indices

    Raises:
        ValueError: if an edge is shared by more than two faces (non-manifold)
    '''
    E2F, E = adjacency_edge_face(F)
    E2F = E2F.tocsr()
    # canonical CSR: face indices sorted within each row, no duplicates
    E2F.sum_duplicates()

    faces_per_edge = np.diff(E2F.indptr)
    if np.any(faces_per_edge > 2):
        raise ValueError("adjacency_list_edge_face assumes a manifold mesh "
                         "(at most two faces per edge)")

    # first and last stored face of each row (identical for boundary edges);
    # reading the CSR structure avoids densifying one |F|-long row per edge
    first = E2F.indices[E2F.indptr[:-1]]
    last = E2F.indices[E2F.indptr[1:] - 1]
    adjList = np.stack((first, last), axis=1).astype(int)

    return jax.numpy.asarray(adjList), E


# --------------------------------------------------------------------------------
# /general/adjacency_list_face_face.py
# --------------------------------------------------------------------------------
import numpy as onp
from scipy.sparse import coo_matrix


def adjacency_list_face_face(F):
    """
    ADJACENCY_LIST_FACE_FACE builds a face-face adjacency list such that
    F2F[face_index] = [adjacent_face_indices]. Two faces are neighbors when
    they share an edge.

    Inputs
        F: |F|x3 array of face indices

    Outputs
        F2F: list of lists so that F2F[f] = [fi, fj, ...]
    """
    F = onp.array(F)

    # all face edges: rows [0,nF) are the (1,2) edges of every face, rows
    # [nF,2nF) the (2,0) edges and rows [2nF,3nF) the (0,1) edges
    E_all = onp.vstack((F[:, [1, 2]], F[:, [2, 0]], F[:, [0, 1]]))
    sort_E_all = onp.sort(E_all, -1)

    # extract unique edges with inverse indices s.t. E[inv_idx,:] = sort_E_all
    E, inv_idx = onp.unique(sort_E_all, axis=0, return_inverse=True)
    # guard against NumPy versions that return a non-1-D inverse when axis is given
    inv_idx = inv_idx.reshape(-1)

    nF = F.shape[0]
    nE = E.shape[0]

    # row k of E_all belongs to face k % nF, hence tile (not repeat)
    row = inv_idx
    col = onp.tile(onp.arange(nF), 3)
    data = onp.ones_like(inv_idx)
    uE2F = coo_matrix((data, (row, col)), shape=(nE, nF))
    # F2F_mat[f,g] counts the edges shared by faces f and g
    F2F_mat = (uE2F.T @ uE2F).tocsr()
    F2F_mat.setdiag(0)

    F2F = [[] for _ in range(nF)]
    for f in range(nF):
        _, col = F2F_mat[f, :].nonzero()
        F2F[f] = col.tolist()

    return F2F


# --------------------------------------------------------------------------------
# /general/adjacency_list_vertex_face.py
# --------------------------------------------------------------------------------
import numpy as onp


def adjacency_list_vertex_face(F):
    """
    ADJACENCY_LIST_VERTEX_FACE builds a vertex-face adjacency list such that
    V2F[vertex_index] = [adjacent_face_indices]

    Inputs
        F: |F|x3 array of face indices

    Outputs
        V2F: list of F.max()+1 lists so that V2F[v] = [fi, fj, ...]
    """
    F = onp.array(F)
    nV = F.max() + 1
    V2F = [[] for _ in range(nV)]
    for f in range(F.shape[0]):
        for s in range(F.shape[1]):
            V2F[F[f, s]].append(f)
    return V2F


# --------------------------------------------------------------------------------
# /general/adjacency_list_vertex_vertex.py
# --------------------------------------------------------------------------------
import numpy as np
from .adjacency_vertex_vertex import adjacency_vertex_vertex


def adjacency_list_vertex_vertex(F):
    """
    ADJACENCY_LIST_VERTEX_VERTEX computes, for every vertex, the indices of
    its adjacent vertices

    inputs:
        F: |F|x3 array of face indices

    outputs:
        AList: list of F.max()+1 numpy arrays; AList[i] holds the indices of
            the vertices adjacent to vertex i
    """
    A = adjacency_vertex_vertex(F)
    AList = []

    for ii in range(A.shape[0]):
        Aii_nonzero = A[ii, :].nonzero()[1]
        AList.append(Aii_nonzero)

    return AList


# --------------------------------------------------------------------------------
# /general/adjacency_vertex_vertex.py
# --------------------------------------------------------------------------------
import numpy as np
import scipy
import scipy.sparse
from .edges import edges


def adjacency_vertex_vertex(F):
    """
    ADJACENCY_VERTEX_VERTEX computes the vertex-vertex adjacency matrix

    inputs:
        F: |F|x3 array of face indices

    outputs:
        A: |V|x|V| scipy sparse csr matrix with A[i,j] = 1 iff (i,j) is a mesh
           edge, where |V| = F.max()+1
    """
    E = edges(F)
    nV = F.max() + 1

    # store each undirected edge in both directions
    row = np.concatenate((E[:, 0], E[:, 1]))
    col = np.concatenate((E[:, 1], E[:, 0]))
    # builtin int: the np.int alias was removed in NumPy 1.24
    val = np.ones(len(row), dtype=int)
    A = scipy.sparse.coo_matrix((val, (row, col)), shape=(nV, nV)).tocsr()
    return A


# --------------------------------------------------------------------------------
# /general/boundary_vertices.py
# --------------------------------------------------------------------------------
import numpy as np
import jax
from .outline import outline


def boundary_vertices(F):
    '''
    BOUNDARY_VERTICES computes the indices of the vertices on the mesh boundary

    Input:
        F (|F|,3) numpy array of face indices
    Output:
        b (|b|,) numpy array of sorted, unique boundary vertex indices
    '''
    return np.unique(outline(F))


# --------------------------------------------------------------------------------
# /general/cotmatrix.py
# --------------------------------------------------------------------------------
import numpy as np
import scipy
import scipy.sparse as sparse


def cotmatrix(V, F):
    '''
    COTMATRIX computes the cotangent Laplace matrix of a triangle mesh

    Inputs:
        V: |V|-by-3 numpy ndarray of vertex positions
        F: |F|-by-3 numpy ndarray of face indices

    Output:
        L: |V|-by-|V| scipy sparse csr matrix. Off-diagonal entry (i,j) is half
           the sum of the cotangents of the angles opposite edge (i,j); each
           diagonal entry is minus the sum of the off-diagonal entries of its row.
    '''
    V = np.array(V)
    F = np.array(F)
    numVert = V.shape[0]

    # interior angle at every corner of every face
    angles = np.zeros((F.shape[0], 3))
    for i in range(3):
        i1, i2, i3 = i, (i + 1) % 3, (i + 2) % 3
        temp1 = V[F[:, i2], :] - V[F[:, i1], :]
        temp2 = V[F[:, i3], :] - V[F[:, i1], :]
        temp1 = temp1 / np.linalg.norm(temp1, axis=1)[:, None]
        temp2 = temp2 / np.linalg.norm(temp2, axis=1)[:, None]
        # clip guards arccos against round-off pushing the dot product outside [-1,1]
        dotProd = np.clip(np.sum(temp1 * temp2, axis=1), -1.0, 1.0)
        angles[:, i1] = np.arccos(dotProd)

    # halfedge (F[:,i], F[:,i+1]) is opposite corner i+2
    rows = np.concatenate([F[:, i] for i in range(3)])
    cols = np.concatenate([F[:, (i + 1) % 3] for i in range(3)])
    vals = np.concatenate([-1.0 / np.tan(angles[:, (i + 2) % 3]) for i in range(3)])
    # COO -> CSR sums repeated (i,j) entries (lil fancy "+=" kept only the last one)
    L = sparse.coo_matrix((vals, (rows, cols)), shape=(numVert, numVert), dtype=np.float64).tocsr()
    L = (L + L.transpose()) / 2
    temp = np.array(L.sum(axis=1)).reshape(numVert)
    L = L - sparse.diags(temp)
    return -L.tocsr()
# --------------------------------------------------------------------------------
# /general/edge_flaps.py
# --------------------------------------------------------------------------------
import numpy as np
from .edges_with_mapping import edges_with_mapping


def edge_flaps(F):
    '''
    EDGE_FLAPS computes the flap edge indices of each edge

    Input:
        F (|F|,3) numpy array of face indices
    Output:
        E (|E|,2) numpy array of edge indices
        flapEdges numpy array of flap edge indices per edge: a (|E|,4) or
            (|E|,2) integer array when every edge has the same number of
            flaps, otherwise an (|E|,) object array of lists (boundary edges
            have 2 flap edges, interior edges 4)
    '''
    # Notes:
    # Each flapEdges[e,:] = [a,b,c,d] edge indices
    #     / \
    #    b   a
    #   /     \
    #  - e - -
    #   \     /
    #    c   d
    #     \ /
    E, F2E = edges_with_mapping(F)
    flapEdges = [[] for _ in range(E.shape[0])]
    for e0, e1, e2 in F2E:
        flapEdges[e0].extend([e1, e2])
        flapEdges[e1].extend([e2, e0])
        flapEdges[e2].extend([e0, e1])

    if len({len(flaps) for flaps in flapEdges}) <= 1:
        return E, np.array(flapEdges)

    # Mixed flap counts: NumPy >= 1.24 refuses to build ragged arrays
    # implicitly, so build the object array of lists explicitly.
    ragged = np.empty(len(flapEdges), dtype=object)
    for e, flaps in enumerate(flapEdges):
        ragged[e] = flaps
    return E, ragged


# --------------------------------------------------------------------------------
# /general/edges.py
# --------------------------------------------------------------------------------
import numpy as np
import jax
from .edges_with_mapping import edges_with_mapping


def edges(F):
    '''
    EDGES computes the unique undirected edges of a mesh

    Input:
        F (|F|,3) numpy array of face indices
    Output:
        E (|E|,2) numpy array of edge indices (each row sorted ascending)
    '''
    edge, _ = edges_with_mapping(F)
    return edge


# --------------------------------------------------------------------------------
# /general/edges_with_mapping.py
# --------------------------------------------------------------------------------
import numpy as np


def edges_with_mapping(F):
    '''
    EDGES_WITH_MAPPING computes the unique edges of a mesh together with the
    map from halfedges to unique edges

    Input:
        F (|F|,3) numpy array of face indices
    Output:
        uE (|E|,2) numpy array of edge indices (each row sorted ascending)
        F2E (|F|,3) numpy array mapping from halfedges to unique edges.
            Our halfedge convention identifies each halfedge by the index
            of the face and the opposite vertex within the face:
            (face, opposite vertex)
    '''
    # halfedge opposite corner 0, 1, 2 of every face
    F12 = F[:, np.array([1, 2])]
    F20 = F[:, np.array([2, 0])]
    F01 = F[:, np.array([0, 1])]
    EAll = np.concatenate((F12, F20, F01), axis=0)

    EAll_sortrow = np.sort(EAll, axis=1)
    uE, F2E_vec = np.unique(EAll_sortrow, return_inverse=True, axis=0)
    # EAll is stacked corner-block by corner-block, so reshape then transpose
    F2E = F2E_vec.reshape(F.shape[1], F.shape[0]).T
    return uE, F2E


# --------------------------------------------------------------------------------
# /general/find_index.py
# --------------------------------------------------------------------------------
import numpy as np


def find_index(F, VIdx):
    '''
    FIND_INDEX finds where the given values occur in an ndarray

    Inputs:
        F: |F|-by-dim numpy ndarray (a 1-D array is treated as one column)
        VIdx: a list of values to look for

    Output:
        r, c: row/column indices of the matches in F
    '''
    mask = np.isin(F.flatten(), VIdx)
    nDim = F.shape[1] if F.ndim > 1 else 1
    flat_idx = np.where(mask)[0]
    r = flat_idx // nDim
    c = flat_idx % nDim
    return r, c
# --------------------------------------------------------------------------------
# /general/he_initialization.py
# --------------------------------------------------------------------------------
import numpy as onp
import jax.numpy as np


def he_initialization(n_out, n_in):
    """
    HE_INITIALIZATION initializes a linear layer following He et al. 2015
    (often recommended for ReLU-like networks)

    Inputs
        n_out: dimension of the output of the linear layer
        n_in: dimension of the input of the linear layer

    Outputs
        W: (n_out, n_in) jax array of weights drawn from N(0, 2/n_in)
        b: (n_out,) jax array of zero biases
    """
    W = np.array(onp.random.randn(n_out, n_in) * onp.sqrt(2 / n_in))
    b = np.zeros(n_out)
    return W, b


# --------------------------------------------------------------------------------
# /general/knn_search.py
# --------------------------------------------------------------------------------
import scipy
import scipy.spatial


def knn_search(query_points, source_points, k):
    """
    KNN_SEARCH finds the k nearest neighbors of query_points in source_points

    Inputs:
        query_points: N-by-D numpy array of query points
        source_points: M-by-D numpy array of existing points
        k: number of neighbors to return

    Output:
        dist: distances from each query point to its k nearest source points
        NNIdx: indices into source_points of those nearest neighbors
    """
    kdtree = scipy.spatial.cKDTree(source_points)
    dist, NNIdx = kdtree.query(query_points, k)
    return dist, NNIdx


# --------------------------------------------------------------------------------
# /general/list_remove_indices.py
# --------------------------------------------------------------------------------
def list_remove_indices(l, indices):
    """
    LIST_REMOVE_INDICES removes multiple elements from a list by index, in place

    inputs
        l: list
        indices: indices to be removed; a repeated index removes one element
            only, and indices >= len(l) are ignored

    outputs
        None (l is changed in place)
    """
    # Pop from the back so earlier pops do not shift the remaining indices.
    # Deduplicate so a repeated index does not remove a second, unrelated element.
    for idx in sorted(set(indices), reverse=True):
        if idx < len(l):
            l.pop(idx)


# --------------------------------------------------------------------------------
# /general/massmatrix.py
# --------------------------------------------------------------------------------
import numpy as np
import scipy
import scipy.sparse


def massmatrix(V, F):
    """
    MASSMATRIX constructs the (mixed) Voronoi-area mass matrix of vertex areas

    Input:
        V (|V|,3) array of vertex positions
        F (|F|,3) array of face indices

    Output:
        M (|V|,|V|) scipy sparse diagonal matrix of vertex areas
    """
    # in case the input is not a numpy array
    V = np.array(V)
    F = np.array(F)

    # l1, l2, l3: lengths of the edges opposite corners 0, 1, 2
    l1 = np.sqrt(np.sum((V[F[:, 1], :] - V[F[:, 2], :])**2, 1))
    l2 = np.sqrt(np.sum((V[F[:, 2], :] - V[F[:, 0], :])**2, 1))
    l3 = np.sqrt(np.sum((V[F[:, 0], :] - V[F[:, 1], :])**2, 1))
    lMat = np.concatenate((l1[:, None], l2[:, None], l3[:, None]), axis=1)

    # cosine of the interior angle at corners 0, 1, 2 (law of cosines)
    cos1 = (l3**2 + l2**2 - l1**2) / (2 * l2 * l3)
    cos2 = (l1**2 + l3**2 - l2**2) / (2 * l1 * l3)
    cos3 = (l1**2 + l2**2 - l3**2) / (2 * l1 * l2)
    cosMat = np.concatenate((cos1[:, None], cos2[:, None], cos3[:, None]), axis=1)

    barycentric = cosMat * lMat
    normalized_barycentric = barycentric / np.sum(barycentric, 1)[:, None]
    # Heron's formula
    areas = 0.25 * np.sqrt((l1 + l2 - l3) * (l1 - l2 + l3) * (-l1 + l2 + l3) * (l1 + l2 + l3))
    partArea = normalized_barycentric * areas[:, None]
    quad1 = (partArea[:, 1] + partArea[:, 2]) * 0.5
    quad2 = (partArea[:, 0] + partArea[:, 2]) * 0.5
    quad3 = (partArea[:, 0] + partArea[:, 1]) * 0.5
    quads = np.concatenate((quad1[:, None], quad2[:, None], quad3[:, None]), axis=1)

    # obtuse triangles: the obtuse corner gets half the area, the others a quarter
    boolM = cosMat[:, 0] < 0
    quads[boolM, 0] = areas[boolM] * 0.5
    quads[boolM, 1] = areas[boolM] * 0.25
    quads[boolM, 2] = areas[boolM] * 0.25

    boolM = cosMat[:, 1] < 0
    quads[boolM, 0] = areas[boolM] * 0.25
    quads[boolM, 1] = areas[boolM] * 0.5
    quads[boolM, 2] = areas[boolM] * 0.25

    boolM = cosMat[:, 2] < 0
    quads[boolM, 0] = areas[boolM] * 0.25
    quads[boolM, 1] = areas[boolM] * 0.25
    quads[boolM, 2] = areas[boolM] * 0.5

    # duplicate (v,v) entries are summed by the csr constructor
    rIdx = F.flatten()
    cIdx = F.flatten()
    val = quads.flatten()
    nV = V.shape[0]
    M = scipy.sparse.csr_matrix((val, (rIdx, cIdx)), shape=(nV, nV), dtype=np.float64)
    return M
quads[boolM,2] = areas[boolM]*0.5 53 | 54 | rIdx = F.flatten() 55 | cIdx = F.flatten() 56 | val = quads.flatten() 57 | nV = V.shape[0] 58 | M = scipy.sparse.csr_matrix( (val,(rIdx,cIdx)), shape=(nV,nV), dtype=np.float64 ) 59 | return M -------------------------------------------------------------------------------- /general/mid_point_curve_simplification.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .ordered_outline import ordered_outline 3 | from .remove_unreferenced import remove_unreferenced 4 | 5 | def mid_point_curve_simplification(V,O,tarE): 6 | """ 7 | this function simplify a single closed curve via collapsing the shortest edge 8 | 9 | Inputs 10 | V: |V|x3 vertex list 11 | O: |O|x2 (unordered) boundary edges 12 | tarE: target number of edges in the simplified curve 13 | 14 | Outputs 15 | V: |Vc|x3 simplified vertex list 16 | O: tarEx3 simplified boundary curve 17 | 18 | Warning: 19 | - This only support single closed curve 20 | - This is a simple collapst which does not preserve geometry nor avoid collision 21 | """ 22 | V = np.array(V) 23 | L = ordered_outline(O) 24 | E_list = np.array(L[0]) 25 | E = np.array([E_list, np.roll(E_list,-1)]).T 26 | 27 | # mid point collapse ordered outline 28 | ECost = np.sqrt(np.sum((V[E[:,0],:] - V[E[:,1],:])**2,1)) # cost is outline edge lengths 29 | 30 | total_collapses = E.shape[0] - tarE 31 | num_collapses = 0 32 | 33 | while True: 34 | if num_collapses % 100 == 0: 35 | print("collapse progress %d / %d\n" % (num_collapses, total_collapses)) 36 | 37 | # get the minimum cost (slow) 38 | # note: a faster version should use a priority queue 39 | e = np.argmin(ECost) 40 | 41 | # check if the edge is degenerated 42 | if E[e,0] == E[e,1]: 43 | E = np.delete(E,e,0) 44 | ECost = np.delete(ECost,e) 45 | continue 46 | 47 | # move vertex vi 48 | vi, vj = E[e,:] 49 | V[vi,:] = (V[vi,:] + V[vj,:]) / 2. 
50 | 51 | # reconnect edges 52 | prev_e = (e-1) % E.shape[0] 53 | next_e = (e+1) % E.shape[0] 54 | E[next_e,0] = vi # keep E[e,0] and unreference vj 55 | 56 | # update edge costs 57 | ECost[prev_e] = np.sqrt(np.sum((V[E[prev_e,0],:] - V[E[prev_e,1],:])**2)) 58 | ECost[next_e] = np.sqrt(np.sum((V[E[next_e,0],:] - V[E[next_e,1],:])**2)) 59 | 60 | # post collapse update 61 | E = np.delete(E,e,0) 62 | ECost = np.delete(ECost,e) 63 | num_collapses += 1 64 | 65 | # stopping 66 | if num_collapses == total_collapses: 67 | break 68 | V,E,_ = remove_unreferenced(V,E) 69 | return V, E -------------------------------------------------------------------------------- /general/ordered_outline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .outline import outline 3 | 4 | def ordered_outline(ForO): 5 | """ 6 | this function computes an ordered boundary curves of a mesh 7 | 8 | Inputs: 9 | F: |F|x3 face index list 10 | or 11 | O: |E|x2 unordered outline 12 | 13 | Outputs 14 | L: list of list of order boundary indices such that L[0] is the list of vertices of the 0th boundary curve 15 | """ 16 | if ForO.shape[1] == 3: # input is a face list 17 | F = np.array(ForO) 18 | O = outline(F) 19 | O = np.array(O) 20 | elif ForO.shape[1] == 2: # input is a (unordered) boundary curve list 21 | O = np.array(ForO) 22 | 23 | # index map for vertices such that IMV[old_vIdx] = new_vIdx 24 | uV = np.unique(O) 25 | nV = O.max() + 1 26 | IMV = np.zeros(nV, dtype = np.int64) 27 | IMV[uV] = np.arange(len(uV)) 28 | 29 | # inverse index map such that invIMV[new_vIdx] = old_vIdx 30 | invIMV = uV 31 | 32 | # index map to O[:,0] such that old_vIdx = O[IMO[old_vIdx],0] 33 | IMO = np.zeros(nV, dtype = np.int64) 34 | IMO[O[:,0]] = np.arange(len(uV)) 35 | 36 | L = [] # loop for multiple boundary loops 37 | visited = np.full((len(uV),),False) # whether visited (stored in new vIdx) 38 | while not np.all(visited): 39 | # get a vertex for a Loop 40 
| vnew = np.where(visited == False)[0][0] 41 | v = invIMV[vnew] 42 | next_v = get_next_v(v,O,IMO) 43 | 44 | start_v = v # track starting vertex 45 | B = [] # to store each boundary loop 46 | B.append(start_v) # add the starting vertex 47 | 48 | # update visited list 49 | visited[IMV[start_v]] = True 50 | 51 | 52 | # find all the vertices of this loop 53 | while next_v != start_v: 54 | B.append(next_v) # add the next vertex 55 | visited[IMV[next_v]] = True 56 | v = next_v 57 | next_v = get_next_v(v,O,IMO) 58 | L.append(B) 59 | return L 60 | 61 | def get_next_v(v,O,IMO): 62 | return O[IMO[v],1] -------------------------------------------------------------------------------- /general/outline.py: -------------------------------------------------------------------------------- 1 | import numpy as onp 2 | import scipy 3 | import jax.numpy as np 4 | def outline(F_jnp): 5 | """ 6 | this function extract (unordered) boundary edges of a mesh 7 | 8 | Inputs 9 | F_jnp: |F|x3 jax numpy array of the face indices 10 | 11 | Outputs 12 | O: |E|x2 jax numpy array of unordered boundary edges 13 | 14 | Reference: 15 | this code is adapted from https://github.com/alecjacobson/gptoolbox/blob/master/mesh/outline.m 16 | """ 17 | F = onp.array(F_jnp) # convert the numpy for efficiency 18 | nV = F.max()+1 19 | row = F.flatten() 20 | col = F[:,[1,2,0]].flatten() 21 | data = onp.ones(len(row), dtype=np.int32) 22 | A = scipy.sparse.csr_matrix((data, (row, col)), shape=(nV,nV)) # build directed adj matrix 23 | AA = A - A.transpose() # figure out edges that only have one half edge 24 | AA.eliminate_zeros() 25 | I,J,V = scipy.sparse.find(AA) # get the non-zeros 26 | O = np.array([I[V>0], J[V>0]]).T # construct the boundary edge list 27 | return O 28 | 29 | 30 | # ===== 31 | # I switch to the gptoolbox version because a short prifiling show that the above version is 5x faster than the bottom version on a mesh with 1200 faces 32 | # ===== 33 | # import numpy as np 34 | # import jax 35 | 36 | # def 
# def outline(F):  (legacy implementation, kept commented out)
#     '''
#     OUTLINE compute the unordered outline edges of a triangle mesh
#
#     Input:
#       F (|F|,3) numpy array of face indices
#     Output:
#       O (|bE|,2) numpy array of boundary vertex indices, one edge per row
#     '''
#
#     # All halfedges
#     he = np.stack((np.ravel(F[:,[1,2,0]]),np.ravel(F[:,[2,0,1]])), axis=1)
#
#     # Sort hes to be able to find duplicates later
#     # inds = np.argsort(he, axis=1)
#     # he_sorted = np.sort(he, inds, axis=1)
#     he_sorted = np.sort(he, axis=1)
#
#     # Extract unique rows
#     _,unique_indices,unique_counts = np.unique(he_sorted, axis=0, return_index=True, return_counts=True)
#
#     # All the indices with only one unique count are the original boundary edges
#     # in he
#     O = he[unique_indices[unique_counts==1],:]
#
#     return O

# /general/remove_unreferenced.py:
import numpy as np

def remove_unreferenced(V,F):
    """
    remove invalid elements (rows containing a negative index, e.g. [-1,-1,-1])
    and unreferenced vertices from a mesh

    Inputs
        V: |V|x3 array of vertex locations
        F: |F|xk array of element indices (e.g. |F|x3 faces or |E|x2 edges);
           for a 2D F, every row holding a negative index is dropped

    Outputs
        V: array of only the vertices referenced by the kept elements,
           in increasing order of their old index
        F: kept elements, re-indexed into the new V
        IMV: length-|V| index map such that IMV[old_vIdx] = new_vIdx;
             entries of removed vertices are 0 and carry no meaning
    """
    V = np.array(V)
    F = np.array(F)

    # Drop invalid elements first. Left in place, a negative index would
    # appear in np.unique(F) and wrap around under numpy indexing: V[-1] would
    # be kept as a spurious vertex and IMV[-1] would overwrite the last entry.
    if F.ndim == 2:
        F = F[np.all(F >= 0, axis=1)]

    nV = V.shape[0]

    # sorted unique vertex indices still referenced by the element list
    uV = np.unique(F)

    # index map for vertices such that IMV[old_vIdx] = new_vIdx
    IMV = np.zeros(nV, dtype = np.int64)
    IMV[uV] = np.arange(len(uV))

    # return the new mesh
    V = V[uV]
    F = IMV[F]
    return V,F,IMV
# /general/sample_2D_grid.py:
import numpy as onp
import jax.numpy as np

def sample_2D_grid(resolution, low = 0, high = 1):
    """
    sample a regular resolution x resolution grid of 2D points on [low,high]^2

    Inputs
        resolution: number of samples along each axis
        low, high: range of the samples along both axes

    Outputs
        (resolution^2)x2 jax numpy array of points; the x coordinate varies fastest
    """
    ticks = onp.linspace(low, high, num=resolution)
    grid_x, grid_y = onp.meshgrid(ticks, ticks)
    points = onp.stack((grid_x.ravel(), grid_y.ravel()), axis=1)
    return np.array(points)

# /general/sdf_circle.py:
import numpy as onp
import jax.numpy as np

def sdf_circle(x, r = 0.282, center = np.array([0.5,0.5])):
    """
    evaluate the signed distance to a circle in 2D (negative inside)

    Inputs
        x: nx2 array of query locations
        r: circle radius
        center: circle center

    Outputs
        length-n jax numpy array holding |x - center| - r for each row of x
    """
    offset = x - center
    radial = np.sqrt(np.sum(offset ** 2, axis = 1))
    return radial - r

# /general/sdf_cross.py:
import numpy as onp
import jax.numpy as np

def sdf_cross(p, bx=0.35, by=0.12, r=0.):
    """
    evaluate the signed distance to a cross centered at (0.5, 0.5) in 2D

    Inputs
        p: nx2 array of query locations
        bx, by: half-length and half-width of each arm of the cross
        r: value added to the distance (see the reference for details)

    Outputs
        length-n jax numpy array of signed distance values at p

    Reference
        https://iquilezles.org/www/articles/distfunctions2d/distfunctions2d.htm
    """
    half_size = onp.array([bx, by])

    # fold every query into one sector: |p - 0.5| with the larger coordinate first
    folded = onp.abs(onp.array(p - 0.5))
    folded = onp.sort(folded, 1)[:, [1, 0]]

    q = folded - half_size
    k = onp.max(q, 1)

    # k <= 0: the folded point lies within [0,bx]x[0,by]; use the reference's
    # interior branch there. w aliases q (q is not needed afterwards).
    interior = k <= 0
    w = q
    w[interior, 0] = half_size[1] - folded[interior, 0]
    w[interior, 1] = -k[interior]

    clamped = onp.maximum(w, 0.0)
    distance = onp.sqrt(onp.sum(clamped * clamped, 1))
    return np.array(onp.sign(k) * distance + r)
# /general/sdf_star.py:
import numpy as onp
import jax.numpy as np

def sdf_star(x, r = 0.22):
    """
    output the signed distance value of a star in 2D

    The star is centered at (0.5, 0.5); the constants below match the
    hexagram (six-pointed star) formula of the reference.

    Inputs
        x: nx2 array of locations
        r: size of the star

    Outputs
        length-n jax numpy array of signed distance values at x (negative inside)

    Reference:
        https://iquilezles.org/www/articles/distfunctions2d/distfunctions2d.htm
    """
    x = onp.array(x)
    kxy = onp.array([-0.5,0.86602540378])
    kyx = onp.array([0.86602540378,-0.5])
    kz = 0.57735026919
    kw = 1.73205080757

    x = onp.abs(x - 0.5)
    # reflect points lying on the negative side of the lines with unit
    # normals kxy and kyx, folding all queries into one symmetric sector
    x -= 2.0 * onp.minimum(x.dot(kxy), 0.0)[:,None] * kxy[None,:]
    x -= 2.0 * onp.minimum(x.dot(kyx), 0.0)[:,None] * kyx[None,:]
    x[:,0] -= onp.clip(x[:,0],r*kz,r*kw)
    x[:,1] -= r
    length_x = onp.sqrt(onp.sum(x*x, 1))
    return np.array(length_x*onp.sign(x[:,1]))

# /general/sdf_triangle.py:
import numpy as onp
import jax.numpy as np

def sdf_triangle(p, p0 = (.2, .2), p1 = (.8, .2), p2 = (.5, .8)):
    """
    output the signed distance value of a triangle in 2D

    Inputs
        p: nx2 array of locations
        p0,p1,p2: locations of the triangle corners (any length-2 array-like;
                  the defaults are immutable tuples so nothing is shared
                  between calls)

    Outputs
        length-n jax numpy array of signed distance values at p (negative
        inside); the factor s below makes the sign independent of the
        winding order of p0, p1, p2

    Reference
        follows the sdTriangle formula of
        https://iquilezles.org/www/articles/distfunctions2d/distfunctions2d.htm
    """
    p = onp.array(p)
    # accept lists/tuples as well as arrays (a plain list would make `p1 - p0` fail)
    p0 = onp.asarray(p0)
    p1 = onp.asarray(p1)
    p2 = onp.asarray(p2)

    # edge vectors and query offsets from each corner
    e0 = p1 - p0
    e1 = p2 - p1
    e2 = p0 - p2
    v0 = p - p0
    v1 = p - p1
    v2 = p - p2

    # vector from each query to its closest point on each edge segment
    # NOTE: a zero-length edge (coincident corners) divides by zero here
    pq0 = v0 - e0[None,:] * onp.clip( v0.dot(e0) / e0.dot(e0), 0.0, 1.0 )[:,None]
    pq1 = v1 - e1[None,:] * onp.clip( v1.dot(e1) / e1.dot(e1), 0.0, 1.0 )[:,None]
    pq2 = v2 - e2[None,:] * onp.clip( v2.dot(e2) / e2.dot(e2), 0.0, 1.0 )[:,None]
    s = onp.sign( e0[0]*e2[1] - e0[1]*e2[0] )

    # column 0: squared distance to each edge; column 1: signed side test
    pq0pq0 = onp.sum(pq0*pq0, 1)
    d0 = onp.array([pq0pq0, s*(v0[:,0]*e0[1]-v0[:,1]*e0[0])]).T
    pq1pq1 = onp.sum(pq1*pq1, 1)
    d1 = onp.array([pq1pq1, s*(v1[:,0]*e1[1]-v1[:,1]*e1[0])]).T
    pq2pq2 = onp.sum(pq2*pq2, 1)
    d2 = onp.array([pq2pq2, s*(v2[:,0]*e2[1]-v2[:,1]*e2[0])]).T
    d = onp.minimum(onp.minimum(d0, d1),d2)
    out = -onp.sqrt(d[:,0]) * onp.sign(d[:,1])
    return np.array(out)

# /logo.png: https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/logo.png