├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── demos
│   ├── lipschitz_mlp
│   │   ├── README.md
│   │   ├── lipschitz_mlp_interpolation.mp4
│   │   ├── lipschitz_mlp_loss_history.jpg
│   │   ├── lipschitz_mlp_params.pkl
│   │   ├── main_lipmlp.py
│   │   └── model.py
│   ├── nerf
│   │   ├── nerf_interpolation.mp4
│   │   ├── nerf_loss_history.jpg
│   │   ├── nerf_params.pkl
│   │   ├── test_nerf.py
│   │   ├── tiny_nerf_data.npz
│   │   ├── train_nerf.py
│   │   └── utils.py
│   ├── neural_sdf
│   │   ├── ground truth (t=0).png
│   │   ├── ground truth (t=1).png
│   │   ├── loss_history.jpg
│   │   ├── main.py
│   │   ├── mlp_params.pkl
│   │   ├── model.py
│   │   ├── network output (t=0).png
│   │   └── network output (t=1).png
│   └── normal_driven_stylization
│       ├── .polyscope.ini
│       ├── imgui.ini
│       ├── loss_history.jpg
│       ├── main.py
│       ├── opt.obj
│       └── spot.obj
├── differentiable
│   ├── angle_defect.py
│   ├── cotangent_weights.py
│   ├── dihedral_angles.py
│   ├── dotrow.py
│   ├── face_areas.py
│   ├── face_normals.py
│   ├── fit_rotations_cayley.py
│   ├── halfedge_lengths.py
│   ├── normalize_unit_box.py
│   ├── normalizerow.py
│   ├── normrow.py
│   ├── ramp_smooth.py
│   ├── tip_angles.py
│   ├── vertex_areas.py
│   └── vertex_normals.py
├── external
│   ├── read_mesh.py
│   ├── signed_distance.py
│   └── write_obj.py
├── general
│   ├── adjacency_edge_face.py
│   ├── adjacency_list_edge_face.py
│   ├── adjacency_list_face_face.py
│   ├── adjacency_list_vertex_face.py
│   ├── adjacency_list_vertex_vertex.py
│   ├── adjacency_vertex_vertex.py
│   ├── boundary_vertices.py
│   ├── cotmatrix.py
│   ├── edge_flaps.py
│   ├── edges.py
│   ├── edges_with_mapping.py
│   ├── find_index.py
│   ├── he_initialization.py
│   ├── knn_search.py
│   ├── list_remove_indices.py
│   ├── massmatrix.py
│   ├── mid_point_curve_simplification.py
│   ├── ordered_outline.py
│   ├── outline.py
│   ├── remove_unreferenced.py
│   ├── sample_2D_grid.py
│   ├── sdf_circle.py
│   ├── sdf_cross.py
│   ├── sdf_star.py
│   └── sdf_triangle.py
└── logo.png
/.gitignore:
--------------------------------------------------------------------------------
1 | unit_test
2 | not_ready
3 | old
4 | demos/shapeup_with_GD
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 | cover/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | .pybuilder/
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | # For a library or package, you might want to ignore these files since the code is
92 | # intended to run in multiple environments; otherwise, check them in:
93 | # .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
103 | __pypackages__/
104 |
105 | # Celery stuff
106 | celerybeat-schedule
107 | celerybeat.pid
108 |
109 | # SageMath parsed files
110 | *.sage.py
111 |
112 | # Environments
113 | .env
114 | .venv
115 | env/
116 | venv/
117 | ENV/
118 | env.bak/
119 | venv.bak/
120 |
121 | # Spyder project settings
122 | .spyderproject
123 | .spyproject
124 |
125 | # Rope project settings
126 | .ropeproject
127 |
128 | # mkdocs documentation
129 | /site
130 |
131 | # mypy
132 | .mypy_cache/
133 | .dmypy.json
134 | dmypy.json
135 |
136 | # Pyre type checker
137 | .pyre/
138 |
139 | # pytype static type analyzer
140 | .pytype/
141 |
142 | # Cython debug symbols
143 | cython_debug/
144 |
145 | # Specific apple ignores
146 | # General
147 | .DS_Store
148 | .AppleDouble
149 | .LSOverride
150 |
151 | # Icon must end with two \r
152 | Icon
153 |
154 | # Thumbnails
155 | ._*
156 |
157 | # Files that might appear in the root of a volume
158 | .DocumentRevisions-V100
159 | .fseventsd
160 | .Spotlight-V100
161 | .TemporaryItems
162 | .Trashes
163 | .VolumeIcon.icns
164 | .com.apple.timemachine.donotpresent
165 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The code in this repository is licensed under the Apache license.
2 | By contributing to this project, you agree to license your code under the
3 | same license.
4 |
5 |
6 | ===========================================================================
7 |
8 |
9 | Apache License
10 | Version 2.0, January 2004
11 | http://www.apache.org/licenses/
12 |
13 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
14 |
15 | 1. Definitions.
16 |
17 | "License" shall mean the terms and conditions for use, reproduction,
18 | and distribution as defined by Sections 1 through 9 of this document.
19 |
20 | "Licensor" shall mean the copyright owner or entity authorized by
21 | the copyright owner that is granting the License.
22 |
23 | "Legal Entity" shall mean the union of the acting entity and all
24 | other entities that control, are controlled by, or are under common
25 | control with that entity. For the purposes of this definition,
26 | "control" means (i) the power, direct or indirect, to cause the
27 | direction or management of such entity, whether by contract or
28 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
29 | outstanding shares, or (iii) beneficial ownership of such entity.
30 |
31 | "You" (or "Your") shall mean an individual or Legal Entity
32 | exercising permissions granted by this License.
33 |
34 | "Source" form shall mean the preferred form for making modifications,
35 | including but not limited to software source code, documentation
36 | source, and configuration files.
37 |
38 | "Object" form shall mean any form resulting from mechanical
39 | transformation or translation of a Source form, including but
40 | not limited to compiled object code, generated documentation,
41 | and conversions to other media types.
42 |
43 | "Work" shall mean the work of authorship, whether in Source or
44 | Object form, made available under the License, as indicated by a
45 | copyright notice that is included in or attached to the work
46 | (an example is provided in the Appendix below).
47 |
48 | "Derivative Works" shall mean any work, whether in Source or Object
49 | form, that is based on (or derived from) the Work and for which the
50 | editorial revisions, annotations, elaborations, or other modifications
51 | represent, as a whole, an original work of authorship. For the purposes
52 | of this License, Derivative Works shall not include works that remain
53 | separable from, or merely link (or bind by name) to the interfaces of,
54 | the Work and Derivative Works thereof.
55 |
56 | "Contribution" shall mean any work of authorship, including
57 | the original version of the Work and any modifications or additions
58 | to that Work or Derivative Works thereof, that is intentionally
59 | submitted to Licensor for inclusion in the Work by the copyright owner
60 | or by an individual or Legal Entity authorized to submit on behalf of
61 | the copyright owner. For the purposes of this definition, "submitted"
62 | means any form of electronic, verbal, or written communication sent
63 | to the Licensor or its representatives, including but not limited to
64 | communication on electronic mailing lists, source code control systems,
65 | and issue tracking systems that are managed by, or on behalf of, the
66 | Licensor for the purpose of discussing and improving the Work, but
67 | excluding communication that is conspicuously marked or otherwise
68 | designated in writing by the copyright owner as "Not a Contribution."
69 |
70 | "Contributor" shall mean Licensor and any individual or Legal Entity
71 | on behalf of whom a Contribution has been received by Licensor and
72 | subsequently incorporated within the Work.
73 |
74 | 2. Grant of Copyright License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | copyright license to reproduce, prepare Derivative Works of,
78 | publicly display, publicly perform, sublicense, and distribute the
79 | Work and such Derivative Works in Source or Object form.
80 |
81 | 3. Grant of Patent License. Subject to the terms and conditions of
82 | this License, each Contributor hereby grants to You a perpetual,
83 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
84 | (except as stated in this section) patent license to make, have made,
85 | use, offer to sell, sell, import, and otherwise transfer the Work,
86 | where such license applies only to those patent claims licensable
87 | by such Contributor that are necessarily infringed by their
88 | Contribution(s) alone or by combination of their Contribution(s)
89 | with the Work to which such Contribution(s) was submitted. If You
90 | institute patent litigation against any entity (including a
91 | cross-claim or counterclaim in a lawsuit) alleging that the Work
92 | or a Contribution incorporated within the Work constitutes direct
93 | or contributory patent infringement, then any patent licenses
94 | granted to You under this License for that Work shall terminate
95 | as of the date such litigation is filed.
96 |
97 | 4. Redistribution. You may reproduce and distribute copies of the
98 | Work or Derivative Works thereof in any medium, with or without
99 | modifications, and in Source or Object form, provided that You
100 | meet the following conditions:
101 |
102 | (a) You must give any other recipients of the Work or
103 | Derivative Works a copy of this License; and
104 |
105 | (b) You must cause any modified files to carry prominent notices
106 | stating that You changed the files; and
107 |
108 | (c) You must retain, in the Source form of any Derivative Works
109 | that You distribute, all copyright, patent, trademark, and
110 | attribution notices from the Source form of the Work,
111 | excluding those notices that do not pertain to any part of
112 | the Derivative Works; and
113 |
114 | (d) If the Work includes a "NOTICE" text file as part of its
115 | distribution, then any Derivative Works that You distribute must
116 | include a readable copy of the attribution notices contained
117 | within such NOTICE file, excluding those notices that do not
118 | pertain to any part of the Derivative Works, in at least one
119 | of the following places: within a NOTICE text file distributed
120 | as part of the Derivative Works; within the Source form or
121 | documentation, if provided along with the Derivative Works; or,
122 | within a display generated by the Derivative Works, if and
123 | wherever such third-party notices normally appear. The contents
124 | of the NOTICE file are for informational purposes only and
125 | do not modify the License. You may add Your own attribution
126 | notices within Derivative Works that You distribute, alongside
127 | or as an addendum to the NOTICE text from the Work, provided
128 | that such additional attribution notices cannot be construed
129 | as modifying the License.
130 |
131 | You may add Your own copyright statement to Your modifications and
132 | may provide additional or different license terms and conditions
133 | for use, reproduction, or distribution of Your modifications, or
134 | for any such Derivative Works as a whole, provided Your use,
135 | reproduction, and distribution of the Work otherwise complies with
136 | the conditions stated in this License.
137 |
138 | 5. Submission of Contributions. Unless You explicitly state otherwise,
139 | any Contribution intentionally submitted for inclusion in the Work
140 | by You to the Licensor shall be under the terms and conditions of
141 | this License, without any additional terms or conditions.
142 | Notwithstanding the above, nothing herein shall supersede or modify
143 | the terms of any separate license agreement you may have executed
144 | with Licensor regarding such Contributions.
145 |
146 | 6. Trademarks. This License does not grant permission to use the trade
147 | names, trademarks, service marks, or product names of the Licensor,
148 | except as required for reasonable and customary use in describing the
149 | origin of the Work and reproducing the content of the NOTICE file.
150 |
151 | 7. Disclaimer of Warranty. Unless required by applicable law or
152 | agreed to in writing, Licensor provides the Work (and each
153 | Contributor provides its Contributions) on an "AS IS" BASIS,
154 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
155 | implied, including, without limitation, any warranties or conditions
156 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
157 | PARTICULAR PURPOSE. You are solely responsible for determining the
158 | appropriateness of using or redistributing the Work and assume any
159 | risks associated with Your exercise of permissions under this License.
160 |
161 | 8. Limitation of Liability. In no event and under no legal theory,
162 | whether in tort (including negligence), contract, or otherwise,
163 | unless required by applicable law (such as deliberate and grossly
164 | negligent acts) or agreed to in writing, shall any Contributor be
165 | liable to You for damages, including any direct, indirect, special,
166 | incidental, or consequential damages of any character arising as a
167 | result of this License or out of the use or inability to use the
168 | Work (including but not limited to damages for loss of goodwill,
169 | work stoppage, computer failure or malfunction, or any and all
170 | other commercial damages or losses), even if such Contributor
171 | has been advised of the possibility of such damages.
172 |
173 | 9. Accepting Warranty or Additional Liability. While redistributing
174 | the Work or Derivative Works thereof, You may choose to offer,
175 | and charge a fee for, acceptance of support, warranty, indemnity,
176 | or other liability obligations and/or rights consistent with this
177 | License. However, in accepting such obligations, You may act only
178 | on Your own behalf and on Your sole responsibility, not on behalf
179 | of any other Contributor, and only if You agree to indemnify,
180 | defend, and hold each Contributor harmless for any liability
181 | incurred by, or claims asserted against, such Contributor by reason
182 | of your accepting any such warranty or additional liability.
183 |
184 | END OF TERMS AND CONDITIONS
185 |
186 | APPENDIX: How to apply the Apache License to your work.
187 |
188 | To apply the Apache License to your work, attach the following
189 | boilerplate notice, with the fields enclosed by brackets "[]"
190 | replaced with your own identifying information. (Don't include
191 | the brackets!) The text should be enclosed in the appropriate
192 | comment syntax for the file format. We also recommend that a
193 | file or class name and description of purpose be included on the
194 | same "printed page" as the copyright notice for easier
195 | identification within third-party archives.
196 |
197 | Copyright [yyyy] [name of copyright owner]
198 |
199 | Licensed under the Apache License, Version 2.0 (the "License");
200 | you may not use this file except in compliance with the License.
201 | You may obtain a copy of the License at
202 |
203 | http://www.apache.org/licenses/LICENSE-2.0
204 |
205 | Unless required by applicable law or agreed to in writing, software
206 | distributed under the License is distributed on an "AS IS" BASIS,
207 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
208 | See the License for the specific language governing permissions and
209 | limitations under the License.
210 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # jaxgptoolbox
6 |
7 | This is a collection of basic geometry processing functions, constructed to work with [jax](https://github.com/google/jax)'s autodifferentiation feature for applications in machine learning. We split these functions into _non-differentiable_ ones in the `general` folder and _differentiable_ ones in the `differentiable` folder. We also include wrappers for some third-party functions in the `external` folder for convenience. To use these utility functions, simply import this package and use it as
8 | ```python
9 | import jaxgptoolbox as jgp
10 | import polyscope as ps
11 | V,F = jgp.read_mesh('path_to_OBJ')
12 | ps.init()
13 | ps.register_surface_mesh('my_mesh',V,F)
14 | ps.show()
15 | ```
16 |
17 | ### Dependencies
18 |
19 | This library depends on [jax](https://github.com/google/jax) and the common Python libraries [numpy](https://github.com/numpy/numpy) and [scipy](https://github.com/scipy/scipy). Our `demos` rely on [matplotlib](https://github.com/matplotlib/matplotlib) and [polyscope](https://polyscope.run/py/) for visualization. Some functions in the `external` folder depend on [libigl](https://libigl.github.io/libigl-python-bindings/). Please make sure to install all dependencies (for example, with [conda](https://docs.conda.io/projects/conda/en/latest/index.html)) before using the library.
20 |
21 | ### Contacts & Warnings
22 |
23 | The toolbox grew out of [Oded Stein](https://odedstein.com)'s and [Hsueh-Ti Derek Liu](https://www.dgp.toronto.edu/~hsuehtil/)'s private research codebase during their PhD studies. Some of these functions are not fully tested or optimized, so please use them with caution. If you're interested in contributing or notice any issues, please contact us (ostein@mit.edu, hsuehtil@cs.toronto.edu) or submit a pull request.
24 |
25 | ### License
26 |
27 | Consult the [LICENSE](LICENSE) file for details about the license of this project.
28 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .general.adjacency_edge_face import adjacency_edge_face
2 | from .general.adjacency_list_edge_face import adjacency_list_edge_face
3 | from .general.adjacency_list_face_face import adjacency_list_face_face
4 | from .general.adjacency_list_vertex_face import adjacency_list_vertex_face
5 | from .general.adjacency_list_vertex_vertex import adjacency_list_vertex_vertex
6 | from .general.adjacency_vertex_vertex import adjacency_vertex_vertex
7 | from .general.boundary_vertices import boundary_vertices
8 | from .general.cotmatrix import cotmatrix
9 | from .general.edges import edges
10 | from .general.edge_flaps import edge_flaps
11 | from .general.edges_with_mapping import edges_with_mapping
12 | from .general.find_index import find_index
13 | from .general.he_initialization import he_initialization
14 | from .general.knn_search import knn_search
15 | from .general.massmatrix import massmatrix
16 | from .general.mid_point_curve_simplification import mid_point_curve_simplification
17 | from .general.ordered_outline import ordered_outline
18 | from .general.outline import outline
19 | from .general.remove_unreferenced import remove_unreferenced
20 | from .general.sample_2D_grid import sample_2D_grid
21 | from .general.sdf_circle import sdf_circle
22 | from .general.sdf_cross import sdf_cross
23 | from .general.sdf_star import sdf_star
24 | from .general.sdf_triangle import sdf_triangle
25 |
26 | from .differentiable.angle_defect import angle_defect
27 | from .differentiable.angle_defect import angle_defect_intrinsic
28 | from .differentiable.cotangent_weights import cotangent_weights
29 | from .differentiable.dihedral_angles import dihedral_angles
30 | from .differentiable.dihedral_angles import dihedral_angles_from_normals
31 | from .differentiable.dotrow import dotrow
32 | from .differentiable.face_areas import face_areas
33 | from .differentiable.face_normals import face_normals
34 | from .differentiable.fit_rotations_cayley import fit_rotations_cayley
35 | from .differentiable.halfedge_lengths import halfedge_lengths
36 | from .differentiable.halfedge_lengths import halfedge_lengths_squared
37 | from .differentiable.normalize_unit_box import normalize_unit_box
38 | from .differentiable.normalizerow import normalizerow
39 | from .differentiable.normrow import normrow
40 | from .differentiable.ramp_smooth import ramp_smooth
41 | from .differentiable.tip_angles import tip_angles
42 | from .differentiable.tip_angles import tip_angles_intrinsic
43 | from .differentiable.vertex_areas import vertex_areas
44 | from .differentiable.vertex_normals import vertex_normals
45 |
46 | from .external.signed_distance import signed_distance
47 | from .external.read_mesh import read_mesh
48 | from .external.write_obj import write_obj
--------------------------------------------------------------------------------
/demos/lipschitz_mlp/README.md:
--------------------------------------------------------------------------------
1 | # Learning Smooth Neural Functions via Lipschitz Regularization
2 |
3 | This is the demo code for the Lipschitz MLP:
4 |
5 | **Learning Smooth Neural Functions via Lipschitz Regularization**
6 | _Hsueh-Ti Derek Liu, Francis Williams, Alec Jacobson, Sanja Fidler, Or Litany_
7 | SIGGRAPH (North America), 2022
8 | [[Project Page](https://nv-tlabs.github.io/lip-mlp/)] [[Preprint](https://www.dgp.toronto.edu/~hsuehtil/pdf/lipmlp.pdf)]
9 |
10 | ### Dependencies
11 | Our method depends on [JAX](https://github.com/google/jax) and some common Python dependencies (numpy, tqdm, matplotlib, etc.). Some functions in the script, such as those for generating analytical signed distance functions, depend on other parts of this repository -- [jaxgptoolbox](https://github.com/ml-for-gp/jaxgptoolbox).
12 |
13 | ### Repository Structure
14 | - `main_lipmlp.py` is the main training script. This is a self-contained script to train a Lipschitz MLP to interpolate 2D signed distance functions of a star and a circle. To train the model from scratch, one can simply run
15 | ```bash
16 | python main_lipmlp.py
17 | ```
18 | After training (~15 min on a CPU), you should see the interpolation results in `lipschitz_mlp_interpolation.mp4` and the model parameters in `lipschitz_mlp_params.pkl`.
19 | - `model.py` contains the Lipschitz MLP model. One can simply use it as
20 | ```python
21 | model = lipmlp(hyper_params) # build the model
22 | params = model.initialize_weights() # initialize weights
23 | y = model.forward(params, latent_code, x) # forward pass
24 | ```
25 |
26 |
--------------------------------------------------------------------------------
/demos/lipschitz_mlp/lipschitz_mlp_interpolation.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/lipschitz_mlp/lipschitz_mlp_interpolation.mp4
--------------------------------------------------------------------------------
/demos/lipschitz_mlp/lipschitz_mlp_loss_history.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/lipschitz_mlp/lipschitz_mlp_loss_history.jpg
--------------------------------------------------------------------------------
/demos/lipschitz_mlp/lipschitz_mlp_params.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/lipschitz_mlp/lipschitz_mlp_params.pkl
--------------------------------------------------------------------------------
/demos/lipschitz_mlp/main_lipmlp.py:
--------------------------------------------------------------------------------
1 | from model import *
2 |
3 | # implementation of "Learning Smooth Neural Functions via Lipschitz Regularization" by Liu et al. 2022
4 | if __name__ == '__main__':
5 | random.seed(1)
6 |
7 | # hyper parameters
8 | hyper_params = {
9 | "dim_in": 2,
10 | "dim_t": 1,
11 | "dim_out": 1,
12 | "h_mlp": [64,64,64,64,64],
13 | "step_size": 1e-4,
14 | "grid_size": 32,
15 | "num_epochs": 200000,
16 | "samples_per_epoch": 512
17 | }
18 | alpha = 1e-6
19 |
20 | # initialize a mlp
21 | model = lipmlp(hyper_params)
22 | params = model.initialize_weights()
23 |
24 | # optimizer
25 | opt_init, opt_update, get_params = optimizers.adam(step_size=hyper_params["step_size"])
26 | opt_state = opt_init(params)
27 |
28 | # define loss function and update function
29 | def loss(params_, alpha, x_, y0_, y1_):
30 | out0 = model.forward(params_, np.array([0.0]), x_) # star when t = 0.0
31 | out1 = model.forward(params_, np.array([1.0]), x_) # circle when t = 1.0
32 | loss_sdf = np.mean((out0 - y0_)**2) + np.mean((out1 - y1_)**2)
33 | loss_lipschitz = model.get_lipschitz_loss(params_)
34 | return loss_sdf + alpha * loss_lipschitz
35 |
36 | @jit
37 | def update(epoch, opt_state, alpha, x_, y0_, y1_):
38 | params_ = get_params(opt_state)
39 | value, grads = value_and_grad(loss, argnums = 0)(params_, alpha, x_, y0_, y1_)
40 | opt_state = opt_update(epoch, grads, opt_state)
41 | return value, opt_state
42 |
43 | # training
44 | loss_history = onp.zeros(hyper_params["num_epochs"])
45 | pbar = tqdm.tqdm(range(hyper_params["num_epochs"]))
46 | for epoch in pbar:
47 | # sample a bunch of random points
48 | x = np.array(random.rand(hyper_params["samples_per_epoch"], hyper_params["dim_in"]))
49 | y0 = jgp.sdf_star(x)
50 | y1 = jgp.sdf_circle(x)
51 |
52 | # update
53 | loss_value, opt_state = update(epoch, opt_state, alpha, x, y0, y1)
54 | loss_history[epoch] = loss_value
55 | pbar.set_postfix({"loss": loss_value})
56 |
57 | if epoch % 1000 == 0: # plot loss history every 1000 iter
58 | plt.close(1)
59 | plt.figure(1)
60 | plt.semilogy(loss_history[:epoch])
61 | plt.title('Reconstruction loss + Lipschitz loss')
62 | plt.grid()
63 | plt.savefig("lipschitz_mlp_loss_history.jpg")
64 |
65 | # save final parameters
66 | params = get_params(opt_state)
67 | with open("lipschitz_mlp_params.pkl", 'wb') as handle:
68 | pickle.dump(params, handle, protocol=pickle.HIGHEST_PROTOCOL)
69 |
70 | # normalize weights during test time
71 | params_final = model.normalize_params(params)
72 |
73 | # save result as a video
74 | sdf_cm = mpl.colors.LinearSegmentedColormap.from_list('SDF', [(0,'#eff3ff'),(0.5,'#3182bd'),(0.5,'#31a354'),(1,'#e5f5e0')], N=256)
75 |
76 | # create video
77 | fig = plt.figure()
78 | x = jgp.sample_2D_grid(hyper_params["grid_size"]) # sample on unit grid for visualization
79 | def animate(t):
80 | plt.cla()
81 | out = model.forward_eval(params_final, np.array([t]), x)
82 | levels = onp.linspace(-0.5, 0.5, 21)
83 | im = plt.contourf(out.reshape(hyper_params['grid_size'],hyper_params['grid_size']), levels = levels, cmap=sdf_cm)
84 | plt.axis('equal')
85 | plt.axis("off")
86 | return im
87 | anim = animation.FuncAnimation(fig, animate, frames=np.linspace(0, 1, 50), interval=50)
88 | anim.save("lipschitz_mlp_interpolation.mp4")
--------------------------------------------------------------------------------
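For reference, the objective optimized by the training loop in `main_lipmlp.py` combines the two SDF reconstruction terms with the Lipschitz regularizer from `model.get_lipschitz_loss`. A sketch in formula form, matching the code above (θ collects all `[W, b, c]` parameters, L is the number of layers):

```latex
\mathcal{L}(\theta) =
  \mathbb{E}_{x}\!\left[\big(f_\theta(x, 0) - \mathrm{sdf}_{\mathrm{star}}(x)\big)^2
                      + \big(f_\theta(x, 1) - \mathrm{sdf}_{\mathrm{circle}}(x)\big)^2\right]
  + \alpha \prod_{i=1}^{L} \mathrm{softplus}(c_i)
```

Since `weight_normalization` clips each layer's absolute row sums to at most softplus(c_i), the product above upper-bounds the Lipschitz constant of the network under the L-infinity norm, which is what makes it a meaningful regularizer.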
/demos/lipschitz_mlp/model.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../../../')
3 | import jaxgptoolbox as jgp
4 |
5 | import jax
6 | import jax.numpy as np
7 | from jax import jit, value_and_grad
8 | from jax.experimental import optimizers
9 |
10 | import numpy as onp
11 | import numpy.random as random
12 | import matplotlib as mpl
13 | import matplotlib.pyplot as plt
14 | import matplotlib.animation as animation
15 | import tqdm
16 | import pickle
17 |
18 | class lipmlp:
19 | def __init__(self, hyperParams):
20 | self.hyperParams = hyperParams
21 |
22 | def initialize_weights(self):
23 | """
24 | Initialize the parameters of the Lipschitz mlp
25 |
26 | Inputs
27 | hyperParams: hyper parameter dictionary
28 |
29 | Outputs
30 | params_net: parameters of the network (weight, bias, initial lipschitz bound)
31 | """
32 | def init_W(size_out, size_in):
33 | W = onp.random.randn(size_out, size_in) * onp.sqrt(2 / size_in)
34 | return np.array(W)
35 | sizes = self.hyperParams["h_mlp"]
36 | sizes.insert(0, self.hyperParams["dim_in"] + self.hyperParams["dim_t"])
37 | sizes.append(self.hyperParams["dim_out"])
38 | params_net = []
39 | for ii in range(len(sizes) - 1):
40 | W = init_W(sizes[ii+1], sizes[ii])
41 | b = np.zeros(sizes[ii+1])
42 | c = np.max(np.sum(np.abs(W), axis=1))
43 | params_net.append([W, b, c])
44 | return params_net
45 |
46 | def weight_normalization(self, W, softplus_c):
47 | """
48 | Lipschitz weight normalization based on the L-infinity norm
49 | """
50 | absrowsum = np.sum(np.abs(W), axis=1)
51 | scale = np.minimum(1.0, softplus_c/absrowsum)
52 | return W * scale[:,None]
53 |
54 | def forward_single(self, params_net, t, x):
55 | """
56 | Forward pass of a lipschitz MLP
57 |
58 | Inputs
59 | params_net: parameters of the network
60 | t: the input feature of the shape
61 | x: a query location in the space
62 |
63 | Outputs
64 | out: implicit function value at x
65 | """
66 | # concatenate coordinate and latent code
67 | x = np.append(x, t)
68 |
69 | # forward pass
70 | for ii in range(len(params_net) - 1):
71 | W, b, c = params_net[ii]
72 | W = self.weight_normalization(W, jax.nn.softplus(c))
73 | x = jax.nn.relu(np.dot(W, x) + b)
74 |
75 | # final layer
76 | W, b, c = params_net[-1]
77 | W = self.weight_normalization(W, jax.nn.softplus(c))
78 | out = np.dot(W, x) + b
79 | return out[0]
80 | forward = jax.vmap(forward_single, in_axes=(None, None, None, 0), out_axes=0)
81 |
82 | def get_lipschitz_loss(self, params_net):
83 | """
84 | This function computes the Lipschitz regularization
85 | """
86 | loss_lip = 1.0
87 | for ii in range(len(params_net)):
88 | W, b, c = params_net[ii]
89 | loss_lip = loss_lip * jax.nn.softplus(c)
90 | return loss_lip
91 |
92 | def normalize_params(self, params_net):
93 | """
94 |     (Optional) After training, this function clips the network weights [W, b] using the learned Lipschitz constants. One can then run a standard MLP forward pass at test time, which is slightly faster.
95 | """
96 | params_final = []
97 | for ii in range(len(params_net)):
98 | W, b, c = params_net[ii]
99 | W = self.weight_normalization(W, jax.nn.softplus(c))
100 | params_final.append([W, b])
101 | return params_final
102 |
103 | def forward_eval_single(self, params_final, t, x):
104 | """
105 |     (Optional) A standard forward pass of an MLP; useful for speeding up inference at test time.
106 | """
107 | # concatenate coordinate and latent code
108 | x = np.append(x, t)
109 |
110 | # forward pass
111 | for ii in range(len(params_final) - 1):
112 | W, b = params_final[ii]
113 | x = jax.nn.relu(np.dot(W, x) + b)
114 | W, b = params_final[-1] # final layer
115 | out = np.dot(W, x) + b
116 | return out[0]
117 | forward_eval = jax.vmap(forward_eval_single, in_axes=(None, None, None, 0), out_axes=0)
--------------------------------------------------------------------------------
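As a quick illustration of the weight normalization in `model.py` above: the per-row rescaling guarantees that every absolute row sum of the normalized matrix (i.e., its L-infinity operator norm) is at most softplus(c). A self-contained numpy sketch with hypothetical values:

```python
import numpy as onp

def softplus(c):
    return onp.log1p(onp.exp(c))

def weight_normalization(W, softplus_c):
    # same per-row rescaling as lipmlp.weight_normalization above
    absrowsum = onp.sum(onp.abs(W), axis=1)
    scale = onp.minimum(1.0, softplus_c / absrowsum)
    return W * scale[:, None]

W = onp.random.randn(4, 3)
c = 0.5  # hypothetical per-layer Lipschitz parameter
Wn = weight_normalization(W, softplus(c))
# every absolute row sum is now bounded by softplus(c)
print(onp.abs(Wn).sum(axis=1).max() <= softplus(c) + 1e-8)  # True
```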
/demos/nerf/nerf_interpolation.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/nerf/nerf_interpolation.mp4
--------------------------------------------------------------------------------
/demos/nerf/nerf_loss_history.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/nerf/nerf_loss_history.jpg
--------------------------------------------------------------------------------
/demos/nerf/nerf_params.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/nerf/nerf_params.pkl
--------------------------------------------------------------------------------
/demos/nerf/test_nerf.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 |
3 | if __name__ == "__main__":
4 | poses = generate_test_poses()
5 | imgs, _, focal = load_tiny_nerf('./tiny_nerf_data.npz')
6 |
7 | # hyper parameters
8 | hyper_params = {
9 | "n_in": 3,
10 | "n_out": 4,
11 | "n_pos_encode": 6,
12 | "n_samples_per_path": 128,
13 | "near_plane": 2.0,
14 | "far_plane": 6.0,
15 | "h_mlp": [128,128,128],
16 | "step_size": 1e-4,
17 | "num_epochs": 1000,
18 | "batch_size": 200
19 | }
20 |
21 | # initialize a nerf network
22 | model = NeRF(hyper_params)
23 | with open('nerf_params.pkl', 'rb') as handle:
24 | params = pickle.load(handle)
25 |
26 | # create video
27 | n_imgs = poses.shape[0]
28 | img_H = imgs.shape[1]
29 | img_W = imgs.shape[2]
30 | fig = plt.figure()
31 | def animate(ii):
32 | plt.cla()
33 | orig, dirs = generate_rays_from_camera(img_H, img_W, focal, poses[ii])
34 | dirs = np.reshape(dirs, (-1, 3))
35 |
36 | out = model.path_integral(params, orig, dirs)
37 | out = onp.reshape(onp.array(out), (img_H, img_W, 3))
38 | im = plt.imshow(out)
39 | return im
40 | anim = animation.FuncAnimation(fig, animate, frames=np.arange(n_imgs), interval=50)
41 | anim.save("nerf_interpolation.mp4")
42 |
43 |
--------------------------------------------------------------------------------
/demos/nerf/tiny_nerf_data.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/nerf/tiny_nerf_data.npz
--------------------------------------------------------------------------------
/demos/nerf/train_nerf.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 |
3 | # Re-implementation of "NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis" by Mildenhall et al., 2020
4 | if __name__ == "__main__":
5 | imgs, poses, focal = load_tiny_nerf('./tiny_nerf_data.npz')
6 | # jax_imshow(imgs[0])
7 |
8 | # hyper parameters
9 | hyper_params = {
10 | "n_in": 3,
11 | "n_out": 4,
12 | "n_pos_encode": 6,
13 | "n_samples_per_path": 256,
14 | "near_plane": 2.0,
15 | "far_plane": 6.0,
16 | "h_mlp": [128,128,128,128,128,128],
17 | "step_size": 1e-4,
18 | "num_epochs": 300,
19 | "batch_size": 200
20 | }
21 |
22 | # initialize a nerf network
23 | model = NeRF(hyper_params)
24 | params = model.initialize_weights()
25 |
26 | # optimizer
27 | opt_init, opt_update, get_params = optimizers.adam(step_size=hyper_params["step_size"])
28 | opt_state = opt_init(params)
29 |
30 | def loss(params_, origin_, directions_, img_):
31 | out = model.path_integral(params_, origin_, directions_)
32 | out = np.clip(out, 0.0, 1.0)
33 | loss_val = np.mean((out - img_)**2)
34 | return loss_val
35 |
36 | @jit
37 | def update(epoch, opt_state, origin_, directions_, img_):
38 | params_ = get_params(opt_state)
39 | value, grads = value_and_grad(loss, argnums = 0)(params_, origin_, directions_, img_)
40 | opt_state = opt_update(epoch, grads, opt_state)
41 | return value, opt_state
42 |
43 | # training
44 | loss_history = onp.zeros(hyper_params["num_epochs"])
45 | pbar = tqdm.tqdm(range(hyper_params["num_epochs"]))
46 | # data for training
47 | n_imgs = imgs.shape[0]
48 | img_H = imgs.shape[1]
49 | img_W = imgs.shape[2]
50 | batch_size = hyper_params["batch_size"]
51 | for epoch in pbar:
52 | for ii in range(n_imgs):
53 | # gradient step
54 | orig, dirs = generate_rays_from_camera(img_H, img_W, focal, poses[ii])
55 | dirs = np.reshape(dirs, (-1, 3))
56 | img = np.reshape(imgs[ii], (-1, 3))
57 |
58 | # split rays into batches
59 | img_batches = np.array_split(img, batch_size)
60 | dirs_batches = np.array_split(dirs, batch_size)
61 | for b in range(batch_size):
62 | img_b = img_batches[b]
63 | dirs_b = dirs_batches[b]
64 | loss_value, opt_state = update(epoch, opt_state, orig, dirs_b, img_b)
65 |
66 | # save loss
67 | loss_history[epoch] += loss_value / n_imgs / batch_size
68 | pbar.set_postfix({"loss": loss_history[epoch]})
69 |
70 | if epoch % 1 == 0: # plot loss history every epoch
71 | plt.close(1)
72 | plt.figure(1)
73 | plt.semilogy(loss_history[:epoch])
74 | plt.title('reconstruction loss')
75 | plt.grid()
76 | plt.savefig("nerf_loss_history.jpg")
77 |
78 | # save final parameters
79 | params = get_params(opt_state)
80 | with open("nerf_params.pkl", 'wb') as handle:
81 |     pickle.dump(params, handle, protocol=pickle.HIGHEST_PROTOCOL)
--------------------------------------------------------------------------------
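One easy-to-miss detail in the batching loop above: with `jax.numpy.array_split`, the second argument is the number of chunks, so `"batch_size": 200` produces 200 ray batches per image rather than batches of 200 rays. A tiny standalone check:

```python
import jax.numpy as np

rays = np.zeros((100 * 100, 3))        # e.g., all rays of a hypothetical 100x100 image
batches = np.array_split(rays, 200)    # 200 nearly-equal chunks
print(len(batches), batches[0].shape)  # 200 (50, 3)
```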
/demos/nerf/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../../../')
3 | import jaxgptoolbox as jgp
4 | import numpy as onp
5 | import jax
6 | import jax.numpy as np
7 | from jax import jit, value_and_grad
8 | from jax.experimental import optimizers
9 | import pickle
10 |
11 | from matplotlib import pyplot as plt
12 | import matplotlib.animation as animation
13 | import cv2
14 | import tqdm
15 |
16 | def load_tiny_nerf(path):
17 | data = onp.load(path)
18 | images = data["images"] # n_img x H x W x 3
19 | poses = data["poses"] # n_img x 4 x 4
20 | focal = float(data["focal"])
21 | return np.array(images), np.array(poses), focal
22 |
23 | def jax_imshow(jax_img):
24 | img = onp.array(jax_img)
25 | img = img[:,:,[2,1,0]]
26 | scale = int(600. / img.shape[0])
27 | dim = (img.shape[0]*scale, img.shape[1]*scale)
28 | img_resize = cv2.resize(img, dim)
29 | cv2.imshow('image', img_resize)
30 | cv2.waitKey(0)
31 |
32 | class NeRF:
33 | def __init__(self, hyper_params):
34 | self.hyper_params = hyper_params
35 |
36 | def initialize_weights(self):
37 | """
38 | initialize network weights
39 | """
40 | sizes = self.hyper_params["h_mlp"]
41 |
42 | # add input dimension
43 | n_raw_in = self.hyper_params["n_in"] # point dim
44 | n_encode = self.hyper_params["n_pos_encode"] # positional encoding with sin/cos
45 | n_in = n_raw_in + n_raw_in * 2 * n_encode # "2" is due to sin/cos
46 | sizes.insert(0, n_in)
47 |
48 | sizes.append( self.hyper_params["n_out"]) # add output dimension
49 |
50 |     # initialize network parameters
51 | params = []
52 | for ii in range(len(sizes) - 1):
53 | if ii == (len(sizes) - 2):
54 | # last layer only outputs RGB, n_out - 1
55 | # last layer inputs additional view dir, n_in + 3
56 | W, b = jgp.he_initialization(sizes[ii+1]-1, sizes[ii]+3)
57 | elif ii == (len(sizes) - 3):
58 | # second last layer outputs additional volume density
59 | W, b = jgp.he_initialization(sizes[ii+1]+1, sizes[ii])
60 | else:
61 | W, b = jgp.he_initialization(sizes[ii+1], sizes[ii])
62 | print(W.shape)
63 | params.append([W, b])
64 | # last layer add view direction
65 | return params
66 |
67 | def positional_encoding(self, x_in):
68 | """
69 |     positional encoding: x -> [x, sin(2*pi*w*x), cos(2*pi*w*x)]
70 | where w = 2^[0, 1, ..., n_encode-1]
71 | """
72 | n_encode = self.hyper_params["n_pos_encode"]
73 | w = np.power(2., np.arange(0., n_encode))
74 | x = w[:,None] * x_in[None,:]
75 | x = x.flatten()
76 | return np.concatenate((x_in, np.sin(2*np.pi*x), np.cos(2*np.pi*x)))
77 |
78 | def activation(self, x):
79 | return jax.nn.leaky_relu(x)
80 |
81 | def forward_single(self, params, dir, x):
82 | x = self.positional_encoding(x)
83 | for ii in range(len(params) - 1):
84 | W, b = params[ii]
85 | x = self.activation(np.dot(W, x) + b)
86 | # second last layer outputs volume density (sigma)
87 | sigma = jax.nn.relu(x[0])
88 | x = x[1:]
89 | # last layer append view direction
90 | x = np.concatenate((x, dir))
91 | W, b = params[-1]
92 | rgb = jax.nn.sigmoid(np.dot(W, x) + b)
93 | return np.append(rgb, sigma)
94 | forward = jax.vmap(forward_single, in_axes=(None, None, None, 0), out_axes=0)
95 |
96 | def path_integral_single(self, params, orig, dir):
97 | # generate query locations
98 | n_samples = self.hyper_params["n_samples_per_path"]
99 | near = self.hyper_params["near_plane"]
100 | far = self.hyper_params["far_plane"]
101 | dists = np.linspace(near, far, n_samples)
102 | x = orig[None,:] + dir[None,:] * dists[:,None]
103 |
104 | # forward pass to get (r,g,b,density)
105 | rgbd = self.forward(params, dir, x)
106 |
107 | # volume integral to get colors (see NeRF paper Eq.3)
108 | d = dists[1] - dists[0] # distance between samples
109 | cumsum_density = np.cumsum(rgbd[:,3])
110 | T = np.exp(-cumsum_density*d)
111 | w = T * (1.-np.exp(-rgbd[:,3]*d)) # weights of each evaluation
112 | rgb_out = w.dot(rgbd[:,:3]) # integrate along the path
113 | return rgb_out
114 | path_integral = jax.vmap(path_integral_single, in_axes=(None, None, None, 0), out_axes=0)
115 |
116 | def generate_rays_from_camera(image_height, image_width, focal, pose):
117 | """
118 | generate camera rays in the world space for each pixel
119 |
120 | Inputs
121 | image_height: number of pixels in H direction
122 | image_width: number of pixels in W direction
123 | focal: focal length of a camera
124 | pose: 4x4 array of camera pose matrix
125 |
126 | Outputs
127 |       ray_origins: (3,) array of the shared ray origin
128 |       ray_directions: HxWx3 array of ray directions
129 | """
130 | i, j = np.meshgrid(np.arange(image_width), np.arange(image_height), indexing="xy")
131 | k = -np.ones_like(i)
132 | i = (i - image_width * 0.5) / focal
133 | j = -(j - image_height * 0.5) / focal
134 | directions = np.stack([i, j, k], axis=-1)
135 | camera_matrix = pose[:3, :3]
136 | ray_directions = np.einsum("ijl,kl", directions, camera_matrix)
137 | # ray_origins = np.broadcast_to(pose[:3, -1], ray_directions.shape)
138 | ray_origins = pose[:3, -1]
139 | return ray_origins, ray_directions
140 |
141 | def angles_to_pose(theta, phi, radius):
142 | translation = lambda t: np.asarray(
143 | [
144 | [1, 0, 0, 0],
145 | [0, 1, 0, 0],
146 | [0, 0, 1, t],
147 | [0, 0, 0, 1],
148 | ]
149 | )
150 | rotation_phi = lambda phi: np.asarray(
151 | [
152 | [1, 0, 0, 0],
153 | [0, np.cos(phi), -np.sin(phi), 0],
154 | [0, np.sin(phi), np.cos(phi), 0],
155 | [0, 0, 0, 1],
156 | ]
157 | )
158 | rotation_theta = lambda th: np.asarray(
159 | [
160 | [np.cos(th), 0, -np.sin(th), 0],
161 | [0, 1, 0, 0],
162 | [np.sin(th), 0, np.cos(th), 0],
163 | [0, 0, 0, 1],
164 | ]
165 | )
166 |
167 | pose = translation(radius)
168 | pose = rotation_phi(phi / 180.0 * np.pi) @ pose
169 | pose = rotation_theta(theta / 180.0 * np.pi) @ pose
170 | return (
171 | np.array([
172 | [-1, 0, 0, 0],
173 | [0, 0, 1, 0],
174 | [0, 1, 0, 0],
175 | [0, 0, 0, 1]
176 | ]) @ pose
177 | )
178 |
179 | def generate_test_poses():
180 | video_angle = onp.linspace(0.0, 360.0, 120, endpoint=False)
181 | poses = onp.zeros((len(video_angle), 4, 4))
182 | for ii in range(len(video_angle)):
183 | poses[ii] = angles_to_pose(video_angle[ii], -30, 4.0)
184 | return poses
--------------------------------------------------------------------------------
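`path_integral_single` above implements a uniform-sample discretization of the volume rendering integral (Eq. 3 in the NeRF paper). With δ = (far − near)/(n_samples − 1) the sample spacing used in the code, σ_i the predicted density, and c_i the predicted color at sample i, the code computes:

```latex
T_i = \exp\Big(-\delta \sum_{j \le i} \sigma_j\Big), \qquad
w_i = T_i \left(1 - e^{-\sigma_i \delta}\right), \qquad
\hat{C} = \sum_{i} w_i \, c_i
```

Note that the `np.cumsum` includes the current sample in T_i, a minor variant of the paper's strict prefix sum T_i = exp(−Σ_{j<i} σ_j δ_j).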
/demos/neural_sdf/ground truth (t=0).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/neural_sdf/ground truth (t=0).png
--------------------------------------------------------------------------------
/demos/neural_sdf/ground truth (t=1).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/neural_sdf/ground truth (t=1).png
--------------------------------------------------------------------------------
/demos/neural_sdf/loss_history.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/neural_sdf/loss_history.jpg
--------------------------------------------------------------------------------
/demos/neural_sdf/main.py:
--------------------------------------------------------------------------------
1 | from model import *
2 |
3 | if __name__ == '__main__':
4 | random.seed(1)
5 |
6 | # hyper parameters
7 | hyper_params = {
8 | "dim_in": 2,
9 | "dim_t": 1,
10 | "dim_out": 1,
11 | "h_mlp": [64,64,64],
12 | "step_size": 1e-4,
13 | "grid_size": 32,
14 | "num_epochs": 50000,
15 | "samples_per_epoch": 512
16 | }
17 |
18 | # initialize a mlp
19 | model = mlp(hyper_params)
20 | params = model.initialize_weights()
21 |
22 | # optimizer
23 | opt_init, opt_update, get_params = optimizers.adam(step_size=hyper_params["step_size"])
24 | opt_state = opt_init(params)
25 |
26 | # define loss function and update function
27 | def loss(params_, x_, y0_, y1_):
28 | out0 = model.forward(params_, np.array([0.0]), x_) # star when t = 0.0
29 | out1 = model.forward(params_, np.array([1.0]), x_) # circle when t = 1.0
30 | loss_sdf = np.mean((out0 - y0_)**2) + np.mean((out1 - y1_)**2)
31 | return loss_sdf
32 |
33 | @jit
34 | def update(epoch, opt_state, x_, y0_, y1_):
35 | params_ = get_params(opt_state)
36 | value, grads = value_and_grad(loss, argnums = 0)(params_, x_, y0_, y1_)
37 | opt_state = opt_update(epoch, grads, opt_state)
38 | return value, opt_state
39 |
40 | # training
41 | loss_history = onp.zeros(hyper_params["num_epochs"])
42 | pbar = tqdm.tqdm(range(hyper_params["num_epochs"])) # progress bar
43 | for epoch in pbar:
44 | # sample a bunch of random points
45 | x = np.array(random.rand(hyper_params["samples_per_epoch"], hyper_params["dim_in"]))
46 | y0 = jgp.sdf_star(x) # target SDF values at x
47 | y1 = jgp.sdf_circle(x) # target SDF values at x
48 |
49 | # update network parameters
50 | loss_value, opt_state = update(epoch, opt_state, x, y0, y1)
51 | loss_history[epoch] = loss_value
52 | pbar.set_postfix({"loss": loss_value})
53 |
54 | if epoch % 1000 == 0: # plot loss history every 1000 iter
55 | plt.close(1)
56 | plt.figure(1)
57 | plt.semilogy(loss_history[:epoch])
58 | plt.title('Reconstruction loss')
59 | plt.grid()
60 | plt.savefig("loss_history.jpg")
61 |
62 | # save final parameters
63 | params = get_params(opt_state)
64 | with open("mlp_params.pkl", 'wb') as handle:
65 | pickle.dump(params, handle, protocol=pickle.HIGHEST_PROTOCOL)
66 |
67 | # save results
68 | sdf_cm = mpl.colors.LinearSegmentedColormap.from_list('SDF', [(0,'#eff3ff'),(0.5,'#3182bd'),(0.5,'#31a354'),(1,'#e5f5e0')], N=256) # color map
69 | levels = onp.linspace(-0.5, 0.5, 21) # isoline
70 | x = jgp.sample_2D_grid(hyper_params["grid_size"]) # sample on unit grid for visualization
71 |
72 | fig = plt.figure()
73 | y0 = jgp.sdf_star(x)
74 | im = plt.contourf(y0.reshape(hyper_params['grid_size'],hyper_params['grid_size']), levels = levels, cmap=sdf_cm)
75 | plt.axis('equal')
76 | plt.axis("off")
77 | plt.savefig('ground truth (t=0)')
78 |
79 | plt.clf()
80 | y0_pred = model.forward(params, np.array([0.0]), x)
81 | im = plt.contourf(y0_pred.reshape(hyper_params['grid_size'],hyper_params['grid_size']), levels = levels, cmap=sdf_cm)
82 | plt.axis('equal')
83 | plt.axis("off")
84 | plt.savefig('network output (t=0)')
85 |
86 | plt.clf()
87 | y1 = jgp.sdf_circle(x)
88 | im = plt.contourf(y1.reshape(hyper_params['grid_size'],hyper_params['grid_size']), levels = levels, cmap=sdf_cm)
89 | plt.axis('equal')
90 | plt.axis("off")
91 | plt.savefig('ground truth (t=1)')
92 |
93 | plt.clf()
94 | y1_pred = model.forward(params, np.array([1.0]), x)
95 | im = plt.contourf(y1_pred.reshape(hyper_params['grid_size'],hyper_params['grid_size']), levels = levels, cmap=sdf_cm)
96 | plt.axis('equal')
97 | plt.axis("off")
98 | plt.savefig('network output (t=1)')
99 |
--------------------------------------------------------------------------------
/demos/neural_sdf/mlp_params.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/neural_sdf/mlp_params.pkl
--------------------------------------------------------------------------------
/demos/neural_sdf/model.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../../../')
3 | import jaxgptoolbox as jgp
4 |
5 | import jax
6 | import jax.numpy as np
7 | from jax import jit, value_and_grad
8 | from jax.experimental import optimizers
9 |
10 | import numpy as onp
11 | import numpy.random as random
12 | import matplotlib as mpl
13 | import matplotlib.pyplot as plt
14 | import matplotlib.animation as animation
15 | import tqdm
16 | import pickle
17 |
18 | class mlp:
19 | def __init__(self, hyperParams):
20 | self.hyperParams = hyperParams
21 |
22 | def initialize_weights(self):
23 | """
24 | Initialize the parameters of the mlp
25 |
26 | Inputs
27 | hyperParams: hyper parameter dictionary
28 |
29 | Outputs
30 | params_net: parameters of the network (weight, bias) as a list of jax.numpy arrays
31 | """
32 | def init_W(size_out, size_in):
33 | W = onp.random.randn(size_out, size_in) * onp.sqrt(2 / size_in)
34 | return np.array(W)
35 | sizes = self.hyperParams["h_mlp"]
36 | sizes.insert(0, self.hyperParams["dim_in"] + self.hyperParams["dim_t"])
37 | sizes.append(self.hyperParams["dim_out"])
38 | params_net = []
39 | for ii in range(len(sizes) - 1):
40 | W = init_W(sizes[ii+1], sizes[ii])
41 | b = np.zeros(sizes[ii+1])
42 | params_net.append([W, b])
43 | return params_net
44 |
45 | def forward_single(self, params_net, t, x):
46 | """
47 | Forward pass of a MLP
48 |
49 | Inputs
50 | params_net: parameters of the network
51 | t: the latent code of the shape
52 | x: a query location in the space
53 |
54 | Outputs
55 | out: implicit function value at x (signed distance in this case)
56 | """
57 | # concatenate coordinate and latent code
58 | x = np.append(x, t)
59 |
60 | # forward pass
61 | for ii in range(len(params_net) - 1):
62 | W, b = params_net[ii]
63 | x = jax.nn.relu(np.dot(W, x) + b)
64 |
65 | # final layer
66 | W, b = params_net[-1]
67 | out = np.dot(W, x) + b
68 | return out[0]
69 |
70 | # vectorize the "forward_single" function
71 | forward = jax.vmap(forward_single, in_axes=(None, None, None, 0), out_axes=0)
--------------------------------------------------------------------------------
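The `jax.vmap` wrapper at the bottom of `model.py` batches `forward_single` over query points only (`in_axes=(None, None, None, 0)` leaves `self`, `params_net`, and `t` unbatched). A minimal usage sketch with hypothetical shapes, assuming the `mlp` class above is in scope:

```python
import jax.numpy as np

hyper_params = {"dim_in": 2, "dim_t": 1, "dim_out": 1,
                "h_mlp": [64, 64, 64]}  # hypothetical layer sizes
model = mlp(hyper_params)
params = model.initialize_weights()

x = np.zeros((512, 2))             # 512 query locations in 2D
t = np.array([0.0])                # latent code (t=0 selects the star SDF)
sdf = model.forward(params, t, x)  # one scalar SDF value per query point
print(sdf.shape)                   # (512,)
```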
/demos/neural_sdf/network output (t=0).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/neural_sdf/network output (t=0).png
--------------------------------------------------------------------------------
/demos/neural_sdf/network output (t=1).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/neural_sdf/network output (t=1).png
--------------------------------------------------------------------------------
/demos/normal_driven_stylization/.polyscope.ini:
--------------------------------------------------------------------------------
1 | {
2 | "windowHeight": 720,
3 | "windowPosX": 48,
4 | "windowPosY": 53,
5 | "windowWidth": 1280
6 | }
7 |
--------------------------------------------------------------------------------
/demos/normal_driven_stylization/imgui.ini:
--------------------------------------------------------------------------------
1 | [Window][Debug##Default]
2 | Pos=60,60
3 | Size=400,400
4 | Collapsed=0
5 |
6 | [Window][Polyscope]
7 | Pos=10,10
8 | Size=305,156
9 | Collapsed=0
10 |
11 | [Window][Structures]
12 | Pos=10,186
13 | Size=305,524
14 | Collapsed=0
15 |
16 | [Window][Selection]
17 | Pos=770,30
18 | Size=500,98
19 | Collapsed=0
20 |
21 |
--------------------------------------------------------------------------------
/demos/normal_driven_stylization/loss_history.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/demos/normal_driven_stylization/loss_history.jpg
--------------------------------------------------------------------------------
/demos/normal_driven_stylization/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../../../')
3 | import jaxgptoolbox as jgp
4 |
5 | import jax
6 | import jax.numpy as np
7 | from jax import jit, value_and_grad
8 | from jax.example_libraries import optimizers
9 |
10 | import numpy as onp
11 | import tqdm
12 | import matplotlib.pyplot as plt
13 | import polyscope as ps
14 |
15 | def spokes_rims(V,F):
16 | """
17 | build spokes and rims edge indices and edge weights
18 |
19 | Inputs:
20 | V: |V|x3 array of vertex list
21 | F: |F|x3 array of face list
22 |
23 | Outputs:
24 | Ek_all: |V| list of arrays, where Ek_all[v] = |Ek|x2 array of edge indices
25 | Wk_all: |V| list of arrays, where Wk_all[v] = |Ek| array of cotan weights
26 | """
27 |   V2F = jgp.adjacency_list_vertex_face(F)
28 | C = jgp.cotangent_weights(V,F)
29 |
30 | # construct spokes and rims for each vertex
31 | F = onp.array(F)
32 | nV = V.shape[0]
33 |
34 | # find the max number of neighbors
35 | max_Nk = 0
36 | for kk in range(nV):
37 | len_Nk = len(V2F[kk])
38 | if len_Nk > max_Nk:
39 | max_Nk = len_Nk
40 | print("max #neighbors is: %d" % max_Nk)
41 |
42 | Ek_np = onp.zeros((nV,max_Nk*3,2), dtype=onp.int32)
43 | Wk_np = onp.zeros((nV,max_Nk*3), dtype=onp.float32)
44 | for kk in range(nV):
45 | # get neighbors
46 | Nk = V2F[kk]
47 |
48 | # construct edge list
49 | Ek0 = onp.concatenate((F[Nk,0], F[Nk,1], F[Nk,2]))
50 | Ek1 = onp.concatenate((F[Nk,1], F[Nk,2], F[Nk,0]))
51 | Ek = onp.concatenate((Ek0[:,None], Ek1[:,None]), axis = 1)
52 | # pad with the ghost vertex index so that Ek becomes a numpy array with size (max_Nk, 2)
53 | Ek_pad = onp.pad(Ek, pad_width = ((0, max_Nk*3 - Ek.shape[0]), (0, 0)), mode = 'constant', constant_values = nV)
54 | Ek_np[kk,:,:] = Ek_pad
55 |
56 | # get all the cotan weights
57 | Wk = onp.concatenate((C[Nk,2],C[Nk,0],C[Nk,1]))
58 | # pad with 0 so that Wk becomes a numpy array with size (max_Nk,)
59 | Wk_pad = onp.pad(Wk, pad_width = ((0, max_Nk*3 - Wk.shape[0])), mode = 'constant', constant_values = 0)
60 | Wk_np[kk,:] = Wk_pad
61 |
62 | return np.array(Ek_np), np.array(Wk_np)
63 |
64 | def compute_target_normal_single(n):
65 | cIdx = np.argmax(np.abs(n))
66 | tar_n = np.zeros((3,), dtype=np.float32)
67 | tar_n = tar_n.at[cIdx].set(np.sign(n[cIdx]))
68 | return tar_n
69 | compute_target_normals = jax.vmap(compute_target_normal_single, in_axes=(0), out_axes=0)
70 |
71 | def normal_driven_energy_single(U,V,lam,n,tar_n,a,E,W):
72 | v_ghost = np.array([[0.,0.,0.]])
73 | Ug = np.concatenate((U, v_ghost), axis = 0)
74 | Vg = np.concatenate((V, v_ghost), axis = 0)
75 |
76 | dV = (Vg[E[:,1],:] - Vg[E[:,0],:]).T
77 | dU = (Ug[E[:,1],:] - Ug[E[:,0],:]).T
78 |
79 | # orthogonal procrustes
80 | S = (dV * W).dot(dU.T) + lam*a*n[:,None].dot(tar_n[None,:])
81 |
82 | # fit rotation
83 | R = jgp.fit_rotations_cayley(S)
84 |
85 | # compute loss
86 | RdV_dU = R.dot(dV) - dU
87 | Rn_tar_n = R.dot(n) - tar_n
88 | return np.trace((RdV_dU * W).dot(RdV_dU.T)) + lam*a*Rn_tar_n.dot(Rn_tar_n)
89 | normal_driven_energy = jax.vmap(normal_driven_energy_single, in_axes=(None,None,None,0,0,0,0,0), out_axes=0)
90 |
91 | # define loss function and update function
92 | def loss(U,V,lam,N,tar_N,VA,Ek_all,Wk_all):
93 | loss = normal_driven_energy(U,V,lam,N,tar_N,VA,Ek_all,Wk_all)
94 | return loss.mean()
95 |
96 | @jit
97 | def update(epoch, opt_state, V,lam,N,tar_N,VA,Ek_all,Wk_all):
98 | U = get_params(opt_state)
99 | value, grads = value_and_grad(loss, argnums = 0)(U,V,lam,N,tar_N,VA,Ek_all,Wk_all)
100 | opt_state = opt_update(epoch, grads, opt_state)
101 | return value, opt_state
102 |
103 | if __name__ == "__main__":
104 | # hyper parameters
105 | hyper_params = {
106 | "step_size": 1e-4,
107 | "num_epochs": 1000,
108 | }
109 |
110 | V,F = jgp.read_mesh("./spot.obj")
111 | N = jgp.vertex_normals(V,F)
112 | VA = jgp.vertex_areas(V,F)
113 | Ek_all, Wk_all = spokes_rims(V,F)
114 | tar_N = compute_target_normals(N)
115 |
116 | U = V.copy()
117 | lam = 1.0
118 |
119 | # optimizer
120 | opt_init, opt_update, get_params = optimizers.adam(step_size=hyper_params["step_size"])
121 | opt_state = opt_init(U)
122 |
123 | # training
124 | loss_history = onp.zeros(hyper_params["num_epochs"])
125 | pbar = tqdm.tqdm(range(hyper_params["num_epochs"]))
126 | for epoch in pbar:
127 | loss_value, opt_state = update(epoch, opt_state, V,lam,N,tar_N,VA,Ek_all,Wk_all)
128 | loss_history[epoch] = loss_value
129 | pbar.set_postfix({"loss": loss_value})
130 |
131 | U = get_params(opt_state)
132 | jgp.write_obj("opt.obj", U,F)
133 |
134 | plt.semilogy(loss_history)
135 | plt.title('normal driven energy')
136 | plt.grid()
137 | plt.savefig("loss_history.jpg")
138 |
139 | ps.init()
140 | ps.register_surface_mesh('input mesh',V,F)
141 | ps.register_surface_mesh('optimized mesh',U,F)
142 | ps.show()
143 |
144 |
145 |
146 |
147 |
--------------------------------------------------------------------------------
/differentiable/angle_defect.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 | from .tip_angles import tip_angles
3 | from ..general.boundary_vertices import boundary_vertices
4 | from jax import jit
5 |
6 | # TODO: these are not able to be @jit-ed yet
7 | def angle_defect(V,F,b=None,n=None):
8 | '''
9 | ANGLE_DEFECT computes the angle defect (integrated Gauss curvature) at each
10 | vertex
11 |
12 | Inputs:
13 | V: (|V|,3) numpy ndarray of vertex positions.
14 | F: (|F|,3) numpy ndarray of face indices, must be static.
15 | b: (|b|,) numpy ndarray of boundary vertex indices (they have 0 defect).
16 | Will be computed if not provided, at the price of JITability.
17 | n: number of vertices in the mesh.
18 | Will be computed if not provided, at the price of JITability.
19 |
20 | Outputs:
21 |     k: (|V|,) numpy array of angle defect at each vertex
22 | '''
23 | #k = tip_angles(V,F)
24 |
25 | if b is None:
26 | bv = boundary_vertices(F)
27 | else:
28 | bv = b
29 |
30 | if n is None:
31 | nv,_ = V.shape
32 | else:
33 | nv = n
34 |
35 | k = angle_defect_intrinsic(F,tip_angles(V,F),bv,nv)
36 |
37 |     #Pad return value if there are vertices in V not occurring in F
38 | # nv = V.shape[0]
39 | # nk = k.size
40 | # if nv>nk:
41 | # k = np.concatenate((k,np.zeros(nv-nk)))
42 |
43 | return k
44 |
45 | def angle_defect_intrinsic(F,A,b=np.empty(0,dtype=int),n=None):
46 | '''
47 | ANGLE_DEFECT_INTRINSIC computes the angle defect (integrated Gauss
48 | curvature) at each vertex given intrinsic tip angles
49 |
50 | Inputs:
51 | F: (|F|,3) numpy ndarray of face indices
52 | A: (|F|,3) numpy ndarray of tip angles at each corner of each face (in [0,pi))
53 | b: (|b|,) numpy ndarray of boundary vertex indices (they have 0 defect)
54 | n: number of vertices in the mesh.
55 | Will be computed if not provided, at the price of JITability.
56 |
57 | Outputs:
58 |     k: (|V|,) numpy array of angle defect at each vertex
59 | '''
60 |
61 | if n is None:
62 | nv = F.max() + 1
63 | else:
64 | nv = n
65 |
66 | k = 2.*np.pi - np.bincount(np.ravel(F), weights=np.ravel(A), length=nv)
67 | # k = k.at[b].set(0)
68 |
69 | return k
70 |
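71 | # A minimal usage sketch (hypothetical data, not part of the original file).
72 | # Because this module uses relative imports, run it from the repository's
73 | # parent directory, e.g. `python -m jaxgptoolbox.differentiable.angle_defect`.
74 | # By Gauss-Bonnet, the defects of a closed tetrahedron sum to 2*pi*chi = 4*pi.
75 | if __name__ == "__main__":
76 |   V = np.array([[0.,0.,0.],[1.,0.,0.],[0.,1.,0.],[0.,0.,1.]])
77 |   F = np.array([[0,2,1],[0,1,3],[0,3,2],[1,2,3]])
78 |   print(np.sum(angle_defect(V,F)), 4*np.pi)  # the two printed values should match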
--------------------------------------------------------------------------------
/differentiable/cotangent_weights.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 | from jax import jit
3 |
4 | @jit
5 | def cotangent_weights(V,F):
6 | """
7 | computes cotangent weight for each half edge
8 |
9 | Input:
10 | V (|V|,3) array of vertex positions
11 | F (|F|,3) array of face indices
12 | Output:
13 |     C (|F|,3) array of cotangent weights
14 | """
15 | i0 = F[:,0]
16 | i1 = F[:,1]
17 | i2 = F[:,2]
18 | l0 = np.sqrt(np.sum(np.power(V[i1,:] - V[i2,:],2),1))
19 | l1 = np.sqrt(np.sum(np.power(V[i2,:] - V[i0,:],2),1))
20 | l2 = np.sqrt(np.sum(np.power(V[i0,:] - V[i1,:],2),1))
21 |
22 | # Heron's formula for area
23 | s = (l0 + l1 + l2) / 2.
24 | area = np.sqrt(s * (s-l0) * (s-l1) * (s-l2))
25 |
26 | # cotangent weights
27 | c0 = (l1*l1 + l2*l2 - l0*l0) / area / 8
28 | c1 = (l0*l0 + l2*l2 - l1*l1) / area / 8
29 | c2 = (l0*l0 + l1*l1 - l2*l2) / area / 8
30 | C = np.concatenate((c0[:,None], c1[:,None], c2[:,None]), axis = 1)
31 | return C
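32 |
33 | # A minimal usage sketch (hypothetical data, not part of the original file):
34 | # for a right triangle, the cotangent weight opposite the right angle is
35 | # cot(90 deg)/2 = 0 and the other two are cot(45 deg)/2 = 0.5
36 | if __name__ == "__main__":
37 |   V = np.array([[0.,0.,0.],[1.,0.,0.],[0.,1.,0.]])
38 |   F = np.array([[0,1,2]])
39 |   print(cotangent_weights(V,F))  # expected roughly: [[0. 0.5 0.5]]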
--------------------------------------------------------------------------------
/differentiable/dihedral_angles.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 | from .face_normals import face_normals
3 | from .dotrow import dotrow
4 | from ..general.adjacency_list_edge_face import adjacency_list_edge_face
5 | from jax import jit
6 |
7 | # TODO: these are not able to be @jit-ed yet
8 | def dihedral_angles(V, F):
9 | '''
10 | DIHEDRAL_ANGLES computes the dihedral angles of a mesh
11 |
12 | Inputs:
13 | V: (|V|,3) numpy ndarray of vertex positions
14 | F: (|F|,3) numpy ndarray of face indices
15 |
16 | Outputs:
17 | dihedral_angles: (|E|,) numpy array of dihedral angles (0 ~ pi)
18 | E: (|E|,2) numpy array of edge indices
19 | '''
20 | # TODO: double check whether this is differentiable
21 |
22 | FN = face_normals(V,F)
23 | E2F, E = adjacency_list_edge_face(F)
24 | return dihedral_angles_from_normals(FN,E2F), E
25 |
26 | def dihedral_angles_from_normals(N, E2F):
27 | '''
28 | DIHEDRAL_ANGLES_FROM_NORMALS computes the dihedral angles of a mesh, given
29 | precomputed normals and edge-to-face map
30 |
31 | Inputs:
32 | N: (|F|,3) numpy ndarray of unit face normals
33 |         E2F: (|E|,2) numpy ndarray of adjacency information between edges
34 |              and faces, for example as produced by
35 |              adjacency_list_edge_face
36 |
37 | Outputs:
38 | dihedral_angles: (|E|,) numpy array of dihedral angles at each edge
39 | '''
40 |
41 | dotN = dotrow(N[E2F[:,0],:], N[E2F[:,1],:]).clip(-1,1)
42 | return np.pi - np.arccos(dotN)
43 |
--------------------------------------------------------------------------------
/differentiable/dotrow.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 | from jax import jit
3 |
4 | @jit
5 | def dotrow(X,Y):
6 | '''
7 | DOTROW computes the row-wise dot product of the rows of two matrices
8 |
9 | Inputs:
10 | X: (n,m) numpy ndarray
11 | Y: (n,m) numpy ndarray
12 |
13 | Outputs:
14 | d: (n,) numpy array of rowwise dot product of X and Y
15 | '''
16 |
17 | return np.sum(X * Y, axis = 1)
18 |
19 |
--------------------------------------------------------------------------------
/differentiable/face_areas.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 | from jax import jit
3 |
4 | @jit
5 | def face_areas(V, F):
6 | """
7 | FACEAREAS computes area per face
8 |
9 | Input:
10 | V (|V|,3) numpy array of vertex positions
11 | F (|F|,3) numpy array of face indices
12 | Output:
13 | FA (|F|,) numpy array of face areas
14 | """
15 | vec1 = V[F[:,1],:] - V[F[:,0],:]
16 | vec2 = V[F[:,2],:] - V[F[:,0],:]
17 | FN = np.cross(vec1, vec2) / 2
18 | FA = np.sqrt(np.sum(FN**2,1))
19 | return FA
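20 |
21 | # A minimal usage sketch (hypothetical data, not part of the original file):
22 | # a right triangle with unit legs has area 0.5
23 | if __name__ == "__main__":
24 |   V = np.array([[0.,0.,0.],[1.,0.,0.],[0.,1.,0.]])
25 |   F = np.array([[0,1,2]])
26 |   print(face_areas(V,F))  # expected: [0.5]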
--------------------------------------------------------------------------------
/differentiable/face_normals.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 | from .normalizerow import normalizerow
3 | from jax import jit
4 |
5 | @jit
6 | def face_normals(V, F):
7 | '''
8 | FACENORMALS computes unit face normal of a triangle mesh
9 |
10 | Inputs:
11 | V: (|V|,3) numpy ndarray of vertex positions
12 | F: (|F|,3) numpy ndarray of face indices
13 |
14 | Outputs:
15 | FN_normalized: (|F|,3) numpy ndarray of unit face normal
16 | '''
17 | vec1 = V[F[:,1],:] - V[F[:,0],:]
18 | vec2 = V[F[:,2],:] - V[F[:,0],:]
19 | FN = np.cross(vec1, vec2) / 2
20 | # l2Norm = np.sqrt((FN * FN).sum(axis=1))
21 | # FN_normalized = FN / (l2Norm.reshape(FN.shape[0],1))
22 | FN_normalized = normalizerow(FN)
23 | return FN_normalized
--------------------------------------------------------------------------------
/differentiable/fit_rotations_cayley.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 |
3 | def fit_rotations_cayley(S, R = np.eye(3)):
4 | """
5 | given a cross-covariance matrix S, this function outputs the closest rotation R. This method is based on "Fast Updates for Least-Squares Rotational Alignment" by Zhang et al 2021
6 |
7 | Input:
8 |     S: 3x3 array of the cross covariance matrix
9 | R: (optional) 3x3 initial rotation matrix (default identity)
10 |
11 |   Output:
12 |     R: rotation matrix which maximizes np.trace(S@R)
13 |
14 | Warning:
15 |   - I haven't added an input check for whether S is a valid cross covariance matrix; if it is not, this method will not output a rotation matrix
16 |   - I haven't implemented a stopping criterion yet; I still have to figure out how to implement that in JAX
17 | """
18 | for iter in range(3): # empirically, 3 iterations seem enough
19 | MM = S @ R
20 | z = cayley_step(MM)
21 | s = z.dot(z)
22 | z_col = np.expand_dims(z, 1)
23 | zzT = z_col @ z_col.T
24 | Z = np.array([[0., -z[2], z[1]], [z[2], 0., -z[0]], [-z[1], z[0], 0.]])
25 | R = 1./(1.+s) * R @ ((1.-s)*np.eye(3) + 2*zzT + 2*Z)
26 | # if s < 1e-8:
27 | # return R
28 | return R
29 |
30 | def cayley_step(M):
31 | m = np.array([M[1,2]-M[2,1], M[2,0]-M[0,2], M[0,1]-M[1,0]])
32 | i,j,k = 0,1,2
33 | two_lam0 = 2*M[i,i] + np.abs(M[i,j]+M[j,i]) + np.abs(M[i,k]+M[k,i])
34 | i,j,k = 1,2,0
35 | two_lam1 = 2*M[i,i] + np.abs(M[i,j]+M[j,i]) + np.abs(M[i,k]+M[k,i])
36 | i,j,k = 2,0,1
37 | two_lam2 = 2*M[i,i] + np.abs(M[i,j]+M[j,i]) + np.abs(M[i,k]+M[k,i])
38 | two_lam = np.maximum(np.maximum(two_lam0, two_lam1), two_lam2)
39 | t = np.trace(M)
40 | gs = np.maximum(t, two_lam - t)
41 | H = M + M.T - (t + np.sqrt(gs*gs + m.dot(m)))*np.eye(3)
42 |
43 | h_inv = 1. / np.linalg.det(H)
44 | d0 = np.linalg.det( np.array([-m, H[:,1], H[:,2]]) )
45 | d1 = np.linalg.det( np.array([H[:,0], -m, H[:,2]]) )
46 | d2 = np.linalg.det( np.array([H[:,0], H[:,1], -m]) )
47 | return h_inv * np.array([d0, d1, d2])
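48 |
49 | # A minimal usage sketch (hypothetical data, not part of the original file):
50 | # with S = R_true.T, trace(S @ R) is maximized at R = R_true, so the fitted
51 | # rotation should recover R_true after a few Cayley iterations
52 | if __name__ == "__main__":
53 |   import numpy as onp
54 |   theta = 0.3
55 |   R_true = onp.array([[onp.cos(theta), -onp.sin(theta), 0.],
56 |                       [onp.sin(theta),  onp.cos(theta), 0.],
57 |                       [0., 0., 1.]])
58 |   R = fit_rotations_cayley(np.array(R_true.T))
59 |   print(onp.allclose(onp.array(R), R_true, atol=1e-4))  # expected: True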
--------------------------------------------------------------------------------
/differentiable/halfedge_lengths.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 | from .dotrow import dotrow
3 | from jax import jit
4 |
5 | @jit
6 | def halfedge_lengths(V, F):
7 | '''
8 | HALFEDGE_LENGTHS computes the lengths of all halfedges in the mesh
9 |
10 | Inputs:
11 | V: (|V|,3) numpy ndarray of vertex positions
12 | F: (|F|,3) numpy ndarray of face indices
13 |
14 | Outputs:
15 |         l: (|F|,3) numpy ndarray of halfedge lengths.
16 | Our halfedge convention identifies each halfedge by the index
17 | of the face and the opposite vertex within the face:
18 | (face, opposite vertex)
19 | '''
20 |
21 | return np.sqrt(halfedge_lengths_squared(V,F))
22 |
23 | @jit
24 | def halfedge_lengths_squared(V, F):
25 | '''
26 | HALFEDGE_LENGTHS_SQUARED computes the lengths of all halfedges in the mesh,
27 | squared (this is often preferable to just the
28 | lengths, since it's easier to differentiate
29 | through)
30 |
31 | Inputs:
32 | V: (|V|,3) numpy ndarray of vertex positions
33 | F: (|F|,3) numpy ndarray of face indices
34 |
35 | Outputs:
36 |         l_sq: (|F|,3) numpy ndarray of squared halfedge lengths.
37 | Our halfedge convention identifies each halfedge by the index
38 | of the face and the opposite vertex within the face:
39 | (face, opposite vertex)
40 | '''
41 |
42 | he0 = V[F[:,2],:] - V[F[:,1],:]
43 | he1 = V[F[:,0],:] - V[F[:,2],:]
44 | he2 = V[F[:,1],:] - V[F[:,0],:]
45 |
46 | lhe0 = np.expand_dims(dotrow(he0,he0), axis=1)
47 | lhe1 = np.expand_dims(dotrow(he1,he1), axis=1)
48 | lhe2 = np.expand_dims(dotrow(he2,he2), axis=1)
49 |
50 | l_sq = np.concatenate((lhe0, lhe1, lhe2), axis=1)
51 |
52 | return l_sq
53 |
--------------------------------------------------------------------------------
/differentiable/normalize_unit_box.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 | from jax import jit
3 |
4 | @jit
5 | def normalize_unit_box(V, margin = 0.0):
6 | """
7 |   NORMALIZE_UNIT_BOX normalizes a set of points to a unit bounding box with a user-specified margin
8 |
9 | Input:
10 | V: (n,3) numpy array of point locations
11 |     margin: a user-specified margin between the points and the box boundary
12 | Output:
13 | V: (n,3) numpy array of point locations bounded by margin ~ 1-margin
14 | """
15 |
16 | V = V - V.min(0)
17 | V = V / V.max()
18 | V = V - 0.5
19 | V = V * (1.0 - margin*2)
20 | V = V + 0.5
21 | return V
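22 |
23 | # A minimal usage sketch (hypothetical points, not part of the original file):
24 | # with margin 0.1 the longest axis of the output spans 0.1 ~ 0.9
25 | if __name__ == "__main__":
26 |   V = np.array([[0.,0.,0.],[2.,1.,0.]])
27 |   print(normalize_unit_box(V, 0.1))  # expected: [[0.1 0.1 0.1] [0.9 0.5 0.1]]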
--------------------------------------------------------------------------------
/differentiable/normalizerow.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 | from .normrow import normrow
3 | from jax import jit
4 |
5 | @jit
6 | def normalizerow(X):
7 | """
8 |   NORMALIZEROW normalizes each row of a np array to unit l2-norm
9 |
10 | Input:
11 | X: (n,m) numpy array
12 | Output:
13 | X_normalized: (n,m) row normalized numpy array
14 | """
15 | l2Norm = normrow(X)
16 | X_normalized = X / (l2Norm.reshape(X.shape[0],1))
17 | return X_normalized
18 |
--------------------------------------------------------------------------------
/differentiable/normrow.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 | from .dotrow import dotrow
3 | from jax import jit
4 |
5 | @jit
6 | def normrow(X):
7 | """
8 | NORMROW computes the l2-norm of each row in a np array
9 |
10 | Input:
11 | X: (n,m) numpy array
12 | Output:
13 | nX: (n,) numpy array of l2 norm of each row in X
14 | """
15 |
16 | return np.sqrt(dotrow(X,X))
--------------------------------------------------------------------------------
/differentiable/ramp_smooth.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 | from jax import jit
3 |
4 | @jit
5 | def ramp_smooth(d, d0 = 1.0):
6 | """
7 | RAMP function so that:
8 | - 1, when d/d0 >= 1
9 |   - -1, when d/d0 <= -1
10 | - smooth decay when -1 < d/d0 < 1
11 | (following the notation of "Curl-Noise for Procedural Fluid Flow" Bridson et al 2007)
12 |
13 | Input:
14 | d distance of a query point to its closest point on boundary
15 | d0 scale of the decay
16 | Output:
17 | scale ramp scaling factor
18 | """
19 | r = d / d0
20 | val = 15./8.*r - 10./8.*(r**3) + 3./8.*(r**5)
21 | return np.minimum(np.abs(val), 1.0) * np.sign(val)
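22 |
23 | # A minimal usage sketch (hypothetical values, not part of the original file):
24 | # the ramp clamps to -1 and 1 outside [-d0, d0] and decays smoothly inside
25 | if __name__ == "__main__":
26 |   d = np.array([-2., -1., 0., 0.5, 1., 2.])
27 |   print(ramp_smooth(d))  # expected roughly: [-1. -1. 0. 0.793 1. 1.]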
--------------------------------------------------------------------------------
/differentiable/tip_angles.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 | from jax import jit
3 | from .halfedge_lengths import halfedge_lengths_squared
4 |
5 | @jit
6 | def tip_angles(V, F):
7 | '''
8 | TIP_ANGLES computes the tip angles of a mesh (the angle at every corner)
9 |
10 | Inputs:
11 | V: (|V|,3) numpy ndarray of vertex positions
12 | F: (|F|,3) numpy ndarray of face indices
13 |
14 | Outputs:
15 | A: (|F|,3) numpy ndarray of tip angles at each corner of each face (in [0,pi))
16 | '''
17 |
18 | return tip_angles_intrinsic(halfedge_lengths_squared(V,F))
19 |
20 |
21 | def tip_angles_intrinsic(l_sq):
22 | '''
23 | TIP_ANGLES_INTRINSIC computes the tip angles of a mesh (the angle at every
24 | corner) given squared halfedge lengths
25 |
26 | Inputs:
27 |         l_sq: (|F|,3) numpy ndarray of squared halfedge lengths.
28 | Our halfedge convention identifies each halfedge by the index
29 | of the face and the opposite vertex within the face:
30 | (face, opposite vertex)
31 |
32 | Outputs:
33 | A: (|F|,3) numpy ndarray of tip angles at each corner of each face (in [0,pi))
34 | '''
35 |
36 | #Use the cosine rule
37 | a_sq = np.expand_dims(l_sq[:,0], axis=1)
38 | b_sq = np.expand_dims(l_sq[:,1], axis=1)
39 | c_sq = np.expand_dims(l_sq[:,2], axis=1)
40 | a = np.sqrt(a_sq)
41 | b = np.sqrt(b_sq)
42 | c = np.sqrt(c_sq)
43 |
44 | cos_angles = np.concatenate((
45 | (b_sq + c_sq - a_sq) / (2.*b*c),
46 | (a_sq + c_sq - b_sq) / (2.*a*c),
47 | (a_sq + b_sq - c_sq) / (2.*a*b)
48 | ), axis=1).clip(-1,1)
49 |
50 | return np.arccos(cos_angles)
51 |
--------------------------------------------------------------------------------
/differentiable/vertex_areas.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 | from jax import jit
3 |
4 | @jit
5 | def vertex_areas(V,F):
6 | """
7 | computes area per vertex
8 |
9 | Input:
10 | V (|V|,3) numpy array of vertex positions
11 | F (|F|,3) numpy array of face indices
12 | Output:
13 | VA (|V|,) numpy array of vertex areas
14 | """
15 | l1 = np.sqrt(np.sum((V[F[:,1],:]-V[F[:,2],:])**2,1))
16 | l2 = np.sqrt(np.sum((V[F[:,2],:]-V[F[:,0],:])**2,1))
17 | l3 = np.sqrt(np.sum((V[F[:,0],:]-V[F[:,1],:])**2,1))
18 |
19 | cos1 = (l3**2+l2**2-l1**2) / (2*l2*l3)
20 | cos2 = (l1**2+l3**2-l2**2) / (2*l1*l3)
21 | cos3 = (l1**2+l2**2-l3**2) / (2*l1*l2)
22 |
23 | cosMat = np.concatenate( (cos1[:,None], cos2[:,None], cos3[:,None]), axis =1)
24 | lMat = np.concatenate( (l1[:,None], l2[:,None], l3[:,None]), axis =1)
25 | barycentric = cosMat * lMat
26 | normalized_barycentric = barycentric / np.sum(barycentric,1)[:,None]
27 | areas = 0.25 * np.sqrt( (l1+l2-l3)*(l1-l2+l3)*(-l1+l2+l3)*(l1+l2+l3) )
28 | partArea = normalized_barycentric * areas[:,None]
29 |
30 | quad0 = (partArea[:,1]+partArea[:,2]) * 0.5
31 | quad1 = (partArea[:,0]+partArea[:,2]) * 0.5
32 | quad2 = (partArea[:,0]+partArea[:,1]) * 0.5
33 |
34 |   # if the angle at a corner is obtuse (cos < 0), that face's area is
35 |   # redistributed as 1/2 to the obtuse corner and 1/4 to each other corner
36 |   quad0 = np.where(cos1<0, areas * 0.5,  quad0)
37 |   quad1 = np.where(cos1<0, areas * 0.25, quad1)
38 |   quad2 = np.where(cos1<0, areas * 0.25, quad2)
39 |
40 |   quad0 = np.where(cos2<0, areas * 0.25, quad0)
41 |   quad1 = np.where(cos2<0, areas * 0.5,  quad1)
42 |   quad2 = np.where(cos2<0, areas * 0.25, quad2)
43 |
44 |   quad0 = np.where(cos3<0, areas * 0.25, quad0)
45 |   quad1 = np.where(cos3<0, areas * 0.25, quad1)
46 |   quad2 = np.where(cos3<0, areas * 0.5,  quad2)
47 |
48 | quads = np.concatenate( (quad0[:,None], quad1[:,None], quad2[:,None]), axis =1).flatten()
49 |
50 | VA = np.zeros((V.shape[0],))
51 | VA = VA.at[F.flatten()].add(quads)
52 | return VA
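53 |
54 | # A minimal usage sketch (hypothetical data, not part of the original file):
55 | # vertex areas of a unit square split into two triangles should sum to the
56 | # total surface area, 1.0
57 | if __name__ == "__main__":
58 |   V = np.array([[0.,0.,0.],[1.,0.,0.],[1.,1.,0.],[0.,1.,0.]])
59 |   F = np.array([[0,1,2],[0,2,3]])
60 |   print(np.sum(vertex_areas(V,F)))  # expected: ~1.0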
--------------------------------------------------------------------------------
/differentiable/vertex_normals.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 | from jax import jit
3 | from .normalizerow import normalizerow
4 |
5 | @jit
6 | def vertex_normals(V,F):
7 | """
8 | Computes face area weighted vertex normals
9 |
10 | Input:
11 | V |V|x3 numpy array of vertex positions
12 | F |F|x3 numpy array of face indices
13 |
14 | Output:
15 | N |V|x3 array of normalized vertex normal (weighted by face areas)
16 | """
17 | vec1 = V[F[:,1],:] - V[F[:,0],:]
18 | vec2 = V[F[:,2],:] - V[F[:,0],:]
19 | FN_unnormalized = np.cross(vec1, vec2) / 2
20 |
21 | VN = np.zeros((V.shape[0],3), dtype=np.float32)
22 | VN = VN.at[F[:,0]].add(FN_unnormalized)
23 | VN = VN.at[F[:,1]].add(FN_unnormalized)
24 | VN = VN.at[F[:,2]].add(FN_unnormalized)
25 | return normalizerow(VN)
26 |
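27 | # A minimal usage sketch (hypothetical data, not part of the original file).
28 | # Because this module uses a relative import, run it from the repository's
29 | # parent directory via `python -m`. Every vertex normal of a flat,
30 | # counterclockwise square should be (0, 0, 1).
31 | if __name__ == "__main__":
32 |   V = np.array([[0.,0.,0.],[1.,0.,0.],[1.,1.,0.],[0.,1.,0.]])
33 |   F = np.array([[0,1,2],[0,2,3]])
34 |   print(vertex_normals(V,F))  # expected: four rows of [0. 0. 1.]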
--------------------------------------------------------------------------------
/external/read_mesh.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 | import numpy as onp
3 | import igl
4 |
5 | def read_mesh(path):
6 | '''
7 | just a jax wrapper for igl.read_triangle_mesh
8 | '''
9 | V, F = igl.read_triangle_mesh(path)
10 | return np.array(V), np.array(F)
--------------------------------------------------------------------------------
/external/signed_distance.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 | import numpy as onp
3 | import igl
4 |
5 | def signed_distance(P, V, F):
6 | '''
7 | SIGNED_DISTANCE computes signed distance from given points to a mesh
8 |
9 | Inputs:
10 | P: (|P|,3) numpy ndarray of point positions
11 | V: (|V|,3) numpy ndarray of vertex positions
12 | F: (|F|,3) numpy ndarray of face indices
13 |
14 | Outputs:
15 | S: (|P|,) numpy array of signed distance
16 | pF:(|P|,) numpy array of projected face indices
17 | pV:(|P|,3) numpy array of projected point locations
18 |
19 | Notes:
20 |         Signed distance can be implemented differentiably, but this faster C++-based version is not
21 | '''
22 | P_np = onp.array(P)
23 | V_np = onp.array(V)
24 | F_np = onp.array(F)
25 | S,pF,pV = igl.signed_distance(P_np, V_np, F_np)
26 |
27 | return np.array(S), np.array(pF), np.array(pV)
--------------------------------------------------------------------------------
/external/write_obj.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as np
2 | import numpy as onp
3 | import igl
4 |
5 | def write_obj(path, V, F):
6 | '''
7 | just a jax wrapper for igl.write_obj
8 | '''
9 | igl.write_obj(path,onp.array(V),onp.array(F))
--------------------------------------------------------------------------------
/general/adjacency_edge_face.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy
3 | import jax
4 | from . edges_with_mapping import edges_with_mapping
5 |
6 | def adjacency_edge_face(F):
7 | '''
8 | ADJACENCY_EDGE_FACE computes edge-face adjacency matrix
9 |
10 | Input:
11 | F (|F|,3) numpy array of face indices
12 | Output:
13 | E2F (|E|, |F|) scipy sparse matrix of adjacency information between edges and faces
14 | E (|E|,2) numpy array of edge indices
15 | '''
16 | E, F2E = edges_with_mapping(F)
17 | IC = F2E.T.reshape(F.shape[1]*F.shape[0])
18 |
19 | row = IC
20 | col = np.tile(np.arange(F.shape[0]), 3)
21 |     val = np.ones(len(IC), dtype=int)
22 | E2F = scipy.sparse.coo_matrix((val,(row, col)), shape=(E.shape[0], F.shape[0])).tocsr()
23 |
24 | return E2F, jax.numpy.asarray(E)
25 |
--------------------------------------------------------------------------------
/general/adjacency_list_edge_face.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import jax
3 | from . adjacency_edge_face import adjacency_edge_face
4 |
5 | def adjacency_list_edge_face(F):
6 | '''
7 | ADJACENCY_LIST_EDGE_FACE computes edge-face adjacency list
8 |
9 | Input:
10 | F (|F|,3) numpy array of face indices
11 | Output:
12 | adjList (|E|,2) numpy array of adjacency list between edges and faces
13 | E (|E|,2) numpy array of edge indices
14 | '''
15 |
16 | E2F, E = adjacency_edge_face(F)
17 |     adjList = np.zeros((E.shape[0], 2), dtype=int) # assume manifold
18 | E2F_bool = E2F.astype('bool')
19 |
20 | idx = np.arange(F.shape[0])
21 |
22 | for e in range(E.shape[0]):
23 | E2F_bool_row = np.squeeze(np.asarray(E2F_bool[e,:].todense()))
24 | adjList[e,:] = idx[E2F_bool_row]
25 |
26 | return jax.numpy.asarray(adjList), E
--------------------------------------------------------------------------------
/general/adjacency_list_face_face.py:
--------------------------------------------------------------------------------
1 | import numpy as onp
2 | from scipy.sparse import coo_matrix
3 |
4 | def adjacency_list_face_face(F):
5 | """
6 | build a face-face adjacency list such that F2F[face_index] = [adjacent_face_indices]. Note that neighboring faces are determined by whether two faces share an edge.
7 |
8 | Inputs
9 | F: |F|x3 array of face indices
10 |
11 | Outputs
12 |     F2F: list of lists such that F2F[f] = [fi, fj, ...]
13 | """
14 |
15 | F = onp.array(F)
16 |
17 | # build adjacency matrix
18 | E_all = onp.vstack((F[:,[1,2]], F[:,[2,0]], F[:,[0,1]]))
19 | sort_E_all = onp.sort(E_all,-1)
20 |
21 | # extract unique edges with inverse indices s.t. E[inv_idx,:] = sort_E_all
22 | E, inv_idx = onp.unique(sort_E_all, axis=0, return_inverse=True)
23 |
24 | nF = F.shape[0]
25 | nE = E.shape[0]
26 |
27 | row = inv_idx
28 | col = onp.repeat(onp.arange(nF),F.shape[1])
29 | data = onp.ones_like(inv_idx)
30 | uE2F = coo_matrix((data, (row, col)), shape=(nE, nF))
31 | F2F_mat = uE2F.T @ uE2F
32 | F2F_mat.setdiag(0)
33 |
34 | F2F = [[] for _ in range(nF)]
35 | for f in range(nF):
36 | _, col = F2F_mat[f,:].nonzero()
37 | F2F[f] = col.tolist()
38 |
39 | return F2F
--------------------------------------------------------------------------------
/general/adjacency_list_vertex_face.py:
--------------------------------------------------------------------------------
1 | import numpy as onp
2 |
3 | def adjacency_list_vertex_face(F):
4 | """
5 | build a vertex-face adjacency list such that V2F[vertex_index] = [adjacent_face_indices]
6 |
7 | Inputs
8 | F: |F|x3 array of face indices
9 |
10 | Outputs
11 |     V2F: list of lists such that V2F[v] = [fi, fj, ...]
12 | """
13 | F = onp.array(F)
14 | nV = F.max() + 1
15 | V2F = [[] for _ in range(nV)]
16 | for f in range(F.shape[0]):
17 | for s in range(F.shape[1]):
18 | V2F[F[f,s]].append(f)
19 | return V2F
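20 |
21 | # A minimal usage sketch (hypothetical data, not part of the original file):
22 | # two triangles sharing vertices 0 and 2
23 | if __name__ == "__main__":
24 |   F = onp.array([[0,1,2],[0,2,3]])
25 |   print(adjacency_list_vertex_face(F))  # expected: [[0, 1], [0], [0, 1], [1]]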
--------------------------------------------------------------------------------
/general/adjacency_list_vertex_vertex.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .adjacency_vertex_vertex import adjacency_vertex_vertex
3 |
4 | def adjacency_list_vertex_vertex(F):
5 | """
6 |     This function computes a vertex-to-vertex adjacency list
7 |
8 | inputs:
9 | F: |F|x3 list of face indices
10 |
11 | outputs:
12 | AList: |V| list of adjacent vertex indices
13 | """
14 | A = adjacency_vertex_vertex(F)
15 | AList = []
16 |
17 | for ii in range(A.shape[0]):
18 | Aii_nonzero = A[ii,:].nonzero()[1]
19 | AList.append(Aii_nonzero)
20 |
21 | return AList
--------------------------------------------------------------------------------
/general/adjacency_vertex_vertex.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy
3 | from .edges import edges
4 |
5 | def adjacency_vertex_vertex(F):
6 | """
7 | This function computes vertex to vertex adjacency matrix
8 |
9 | inputs:
10 | F: |F|x3 list of face indices
11 |
12 | outputs:
13 | A: |V|x|V| scipy sparse matrix of vertex adjacency matrix
14 | """
15 | E = edges(F)
16 | nV = F.max() + 1
17 |
18 | row = np.concatenate((E[:,0],E[:,1]))
19 | col = np.concatenate((E[:,1],E[:,0]))
20 |     val = np.ones(len(row), dtype=int)
21 | A = scipy.sparse.coo_matrix((val,(row, col)), shape=(nV,nV)).tocsr()
22 | return A
--------------------------------------------------------------------------------
/general/boundary_vertices.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import jax
3 | from .outline import outline
4 |
5 | def boundary_vertices(F):
6 | '''
7 |     BOUNDARY_VERTICES computes the boundary vertex indices of a triangle mesh
8 |
9 | Input:
10 | F (|F|,3) numpy array of face indices
11 | Output:
12 |       b (|b|,) numpy array of boundary vertex indices
13 | '''
14 |
15 | return np.unique(outline(F))
16 |
--------------------------------------------------------------------------------
/general/cotmatrix.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy
3 | import scipy.sparse as sparse
4 |
5 | def cotmatrix(V,F):
6 | '''
7 | COTMATRIX computes the cotangent laplace matrix of a triangle mesh
8 |
9 | Inputs:
10 | V: |V|-by-3 numpy ndarray of vertex positions
11 | F: |F|-by-3 numpy ndarray of face indices
12 |
13 | Output:
14 | L: |V|-by-|V| matrix
15 | '''
16 | V = np.array(V)
17 | F = np.array(F)
18 | numVert = V.shape[0]
19 | numFace = F.shape[0]
20 |
21 | temp1 = np.zeros((numFace, 3))
22 | temp2 = np.zeros((numFace, 3))
23 | angles = np.zeros((numFace, 3))
24 |
25 | # compute angle
26 | for i in range(3):
27 | i1 = (i ) % 3
28 | i2 = (i+1) % 3
29 | i3 = (i+2) % 3
30 | temp1 = V[F[:,i2],:] - V[F[:,i1],:]
31 | temp2 = V[F[:,i3],:] - V[F[:,i1],:]
32 |
33 | # normalize the vectors
34 | norm_temp1 = np.sqrt(np.power(temp1,2).sum(axis = 1))
35 | norm_temp2 = np.sqrt(np.power(temp2,2).sum(axis = 1))
36 | temp1 = np.divide(temp1, np.repeat([norm_temp1], 3, axis=0).transpose())
37 | temp2 = np.divide(temp2, np.repeat([norm_temp2], 3, axis=0).transpose())
38 |
39 | # compute angles
40 | dotProd = np.multiply(temp1, temp2).sum(axis = 1)
41 | angles[:,i1] = np.arccos(dotProd)
42 |
43 | # compute cotan laplace
44 | L = sparse.lil_matrix((numVert,numVert), dtype=np.float64)
45 | for i in range(3):
46 | i1 = (i ) % 3
47 | i2 = (i+1) % 3
48 | i3 = (i+2) % 3
49 | L[F[:,i1],F[:,i2]] += -1 / np.tan(angles[:,i3])
50 | L = (L + L.transpose()) / 2
51 | temp = np.array(L.sum(axis=1)).reshape((numVert))
52 | L -= sparse.diags(temp)
53 | return -L.tocsr()
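54 |
55 | # A minimal usage sketch (hypothetical data, not part of the original file):
56 | # the cotangent Laplacian of a constant function is zero, so each row of L
57 | # sums to 0
58 | if __name__ == "__main__":
59 |   V = np.array([[0.,0.,0.],[1.,0.,0.],[1.,1.,0.],[0.,1.,0.]])
60 |   F = np.array([[0,1,2],[0,2,3]])
61 |   L = cotmatrix(V,F)
62 |   print(np.abs(L @ np.ones(4)).max())  # expected: ~0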
--------------------------------------------------------------------------------
/general/edge_flaps.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from . edges_with_mapping import edges_with_mapping
3 |
4 | def edge_flaps(F):
5 | '''
6 |     EDGEFLAPS computes flap edge indices for each edge
7 |
8 | Input:
9 | F (|F|,3) numpy array of face indices
10 | Output:
11 | E (|E|,2) numpy array of edge indices
12 | flapEdges (|E|, 4 or 2) numpy array of edge indices
13 | '''
14 | # Notes:
15 | # Each flapEdges[e,:] = [a,b,c,d] edges indices
16 | # / \
17 | # b a
18 | # / \
19 | # - e - -
20 | # \ /
21 | # c d
22 | # \ /
23 | E, F2E = edges_with_mapping(F)
24 | flapEdges = [[] for i in range(E.shape[0])]
25 | for f in range(F.shape[0]):
26 | e0 = F2E[f,0]
27 | e1 = F2E[f,1]
28 | e2 = F2E[f,2]
29 | flapEdges[e0].extend([e1,e2])
30 | flapEdges[e1].extend([e2,e0])
31 | flapEdges[e2].extend([e0,e1])
32 | return E, np.array(flapEdges)
--------------------------------------------------------------------------------
/general/edges.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import jax
3 | from . edges_with_mapping import edges_with_mapping
4 |
5 | def edges(F):
6 | '''
7 |     EDGES computes edges from face information
8 |
9 | Input:
10 | F (|F|,3) numpy array of face indices
11 | Output:
12 | E (|E|,2) numpy array of edge indices
13 | '''
14 | edge, _ = edges_with_mapping(F)
15 | return edge
--------------------------------------------------------------------------------
/general/edges_with_mapping.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def edges_with_mapping(F):
4 | '''
5 |     EDGES_WITH_MAPPING computes unique edges and a halfedge-to-edge map from face information
6 |
7 | Input:
8 | F (|F|,3) numpy array of face indices
9 | Output:
10 | uE (|E|,2) numpy array of edge indices
11 | F2E (|F|,3) numpy array mapping from halfedges to unique edges.
12 | Our halfedge convention identifies each halfedge by the index
13 | of the face and the opposite vertex within the face:
14 | (face, opposite vertex)
15 | '''
16 | F12 = F[:, np.array([1,2])]
17 | F20 = F[:, np.array([2,0])]
18 | F01 = F[:, np.array([0,1])]
19 | EAll = np.concatenate( (F12, F20, F01), axis = 0)
20 |
21 | EAll_sortrow = np.sort(EAll, axis = 1)
22 | uE, F2E_vec = np.unique(EAll_sortrow, return_inverse=True, axis=0)
23 | F2E = F2E_vec.reshape(F.shape[1], F.shape[0]).T
24 | return uE, F2E
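25 |
26 | # A minimal usage sketch (hypothetical data, not part of the original file):
27 | # two triangles sharing one edge have 5 unique edges
28 | if __name__ == "__main__":
29 |   F = np.array([[0,1,2],[0,2,3]])
30 |   uE, F2E = edges_with_mapping(F)
31 |   print(uE.shape[0], F2E.shape)  # expected: 5 (2, 3)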
--------------------------------------------------------------------------------
/general/find_index.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def find_index(F, VIdx):
4 | '''
5 | FINDINDEX finds desired indices in the ndarray
6 |
7 | Inputs:
8 | F: |F|-by-dim numpy ndarray
9 | VIdx: a list of indices
10 |
11 | Output:
12 |       r, c: row/column indices in the ndarray
13 | '''
14 | mask = np.in1d(F.flatten(),VIdx)
15 | try:
16 | nDim = F.shape[1]
17 | except:
18 | nDim = 1
19 | r = np.floor(np.where(mask)[0] / (nDim*1.0) ).astype(int)
20 | c = np.where(mask)[0] % nDim
21 | return r,c
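22 |
23 | # A minimal usage sketch (hypothetical data, not part of the original file):
24 | # locate every occurrence of vertex 2 in the face list
25 | if __name__ == "__main__":
26 |   F = np.array([[0,1,2],[0,2,3]])
27 |   r, c = find_index(F, [2])
28 |   print(r, c)  # expected: [0 1] [2 1]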
--------------------------------------------------------------------------------
/general/he_initialization.py:
--------------------------------------------------------------------------------
1 | import numpy as onp
2 | import jax.numpy as np
3 | def he_initialization(n_out, n_in):
4 | """
5 |     initialization based on He et al. 2015, often recommended for ReLU-like networks
6 |
7 | Inputs
8 | n_out: dimension of the output of a linear layer
9 | n_in: dimension of the input of a linear layer
10 |
11 | Outputs
12 | W: jax array of weight matrix
13 |     b: jax array of bias vector
14 | """
15 | W = np.array(onp.random.randn(n_out, n_in) * onp.sqrt(2 / n_in))
16 | b = np.zeros(n_out)
17 | return W, b
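18 |
19 | # A minimal usage sketch (hypothetical sizes, not part of the original file):
20 | # weights and biases for a 64 -> 128 linear layer
21 | if __name__ == "__main__":
22 |   W, b = he_initialization(128, 64)
23 |   print(W.shape, b.shape)  # expected: (128, 64) (128,)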
--------------------------------------------------------------------------------
/general/knn_search.py:
--------------------------------------------------------------------------------
1 | import scipy
2 | import scipy.spatial
3 |
4 | def knn_search(query_points, source_points, k):
5 | """
6 |     KNNSEARCH finds the k nearest neighbors of query_points in source_points
7 |
8 | Inputs:
9 | query_points: N-by-D numpy array of query points
10 | source_points: M-by-D numpy array existing points
11 | k: number of neighbors to return
12 |
13 | Output:
14 |       dist: distances from each query point to its k nearest neighbors
15 |       NNIdx: indices into source_points of those nearest neighbors
16 | """
17 | kdtree = scipy.spatial.cKDTree(source_points)
18 | dist, NNIdx = kdtree.query(query_points, k)
19 | return dist, NNIdx
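20 |
21 | # A minimal usage sketch (hypothetical data, not part of the original file):
22 | # the nearest source point to the query is point 1
23 | if __name__ == "__main__":
24 |   import numpy as onp
25 |   source_points = onp.array([[0.,0.],[1.,0.],[0.,1.]])
26 |   query_points = onp.array([[0.9,0.1]])
27 |   dist, NNIdx = knn_search(query_points, source_points, 1)
28 |   print(dist, NNIdx)  # expected: distance ~0.1414 and index 1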
--------------------------------------------------------------------------------
/general/list_remove_indices.py:
--------------------------------------------------------------------------------
1 | def list_remove_indices(l, indices):
2 | """
3 |     this method removes multiple elements from a list by indices
4 |
5 | inputs
6 | l: list
7 | indices: indices to be removed
8 |
9 |     outputs
10 |     l: the input list, modified in place
11 | """
12 | indices = sorted(indices, reverse=True)
13 | for idx in indices:
14 | if idx < len(l):
15 | l.pop(idx)
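16 |
17 | # A minimal usage sketch (hypothetical data, not part of the original file):
18 | # remove the elements at indices 0 and 2
19 | if __name__ == "__main__":
20 |   l = ['a', 'b', 'c', 'd']
21 |   list_remove_indices(l, [0, 2])
22 |   print(l)  # expected: ['b', 'd']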
--------------------------------------------------------------------------------
/general/massmatrix.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy
3 | import scipy.sparse
4 |
5 | def massmatrix(V, F):
6 | """
7 | construct voronoi area mass matrix of vertex areas
8 |
9 | Input:
10 | V (|V|,3) array of vertex positions
11 | F (|F|,3) array of face indices
12 |
13 | Output:
14 | M (|V|,|V|) scipy sparse diagonal matrix of vertex area
15 | """
16 | # in case input is not numpy array
17 | V = np.array(V)
18 | F = np.array(F)
19 |
20 | l1 = np.sqrt(np.sum((V[F[:,1],:]-V[F[:,2],:])**2,1))
21 | l2 = np.sqrt(np.sum((V[F[:,2],:]-V[F[:,0],:])**2,1))
22 | l3 = np.sqrt(np.sum((V[F[:,0],:]-V[F[:,1],:])**2,1))
23 | lMat = np.concatenate( (l1[:,None], l2[:,None], l3[:,None]), axis =1)
24 |
25 | cos1 = (l3**2+l2**2-l1**2) / (2*l2*l3)
26 | cos2 = (l1**2+l3**2-l2**2) / (2*l1*l3)
27 | cos3 = (l1**2+l2**2-l3**2) / (2*l1*l2)
28 | cosMat = np.concatenate( (cos1[:,None], cos2[:,None], cos3[:,None]), axis =1)
29 |
30 | barycentric = cosMat * lMat
31 | normalized_barycentric = barycentric / np.sum(barycentric,1)[:,None]
32 | areas = 0.25 * np.sqrt( (l1+l2-l3)*(l1-l2+l3)*(-l1+l2+l3)*(l1+l2+l3) )
33 | partArea = normalized_barycentric * areas[:,None]
34 | quad1 = (partArea[:,1]+partArea[:,2]) * 0.5
35 | quad2 = (partArea[:,0]+partArea[:,2]) * 0.5
36 | quad3 = (partArea[:,0]+partArea[:,1]) * 0.5
37 | quads = np.concatenate( (quad1[:,None], quad2[:,None], quad3[:,None]), axis =1)
38 |
39 | boolM = cosMat[:,0]<0
40 | quads[boolM,0] = areas[boolM]*0.5
41 | quads[boolM,1] = areas[boolM]*0.25
42 | quads[boolM,2] = areas[boolM]*0.25
43 |
44 | boolM = cosMat[:,1]<0
45 | quads[boolM,0] = areas[boolM]*0.25
46 | quads[boolM,1] = areas[boolM]*0.5
47 | quads[boolM,2] = areas[boolM]*0.25
48 |
49 | boolM = cosMat[:,2]<0
50 | quads[boolM,0] = areas[boolM]*0.25
51 | quads[boolM,1] = areas[boolM]*0.25
52 | quads[boolM,2] = areas[boolM]*0.5
53 |
54 | rIdx = F.flatten()
55 | cIdx = F.flatten()
56 | val = quads.flatten()
57 | nV = V.shape[0]
58 | M = scipy.sparse.csr_matrix( (val,(rIdx,cIdx)), shape=(nV,nV), dtype=np.float64 )
59 | return M
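60 |
61 | # A minimal usage sketch (hypothetical data, not part of the original file):
62 | # voronoi vertex areas of a unit square sum to its total area, 1.0
63 | if __name__ == "__main__":
64 |   V = np.array([[0.,0.,0.],[1.,0.,0.],[1.,1.,0.],[0.,1.,0.]])
65 |   F = np.array([[0,1,2],[0,2,3]])
66 |   print(massmatrix(V,F).sum())  # expected: ~1.0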
--------------------------------------------------------------------------------
/general/mid_point_curve_simplification.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .ordered_outline import ordered_outline
3 | from .remove_unreferenced import remove_unreferenced
4 |
5 | def mid_point_curve_simplification(V,O,tarE):
6 | """
7 |     this function simplifies a single closed curve by repeatedly collapsing the shortest edge
8 |
9 | Inputs
10 | V: |V|x3 vertex list
11 | O: |O|x2 (unordered) boundary edges
12 | tarE: target number of edges in the simplified curve
13 |
14 | Outputs
15 | V: |Vc|x3 simplified vertex list
16 |       E: tarEx2 simplified edge list of the boundary curve
17 |
18 | Warning:
19 |     - This only supports a single closed curve
20 |     - This is a simple collapse which neither preserves geometry nor avoids collisions
21 | """
22 | V = np.array(V)
23 | L = ordered_outline(O)
24 | E_list = np.array(L[0])
25 | E = np.array([E_list, np.roll(E_list,-1)]).T
26 |
27 | # mid point collapse ordered outline
28 | ECost = np.sqrt(np.sum((V[E[:,0],:] - V[E[:,1],:])**2,1)) # cost is outline edge lengths
29 |
30 | total_collapses = E.shape[0] - tarE
31 | num_collapses = 0
32 |
33 | while True:
34 | if num_collapses % 100 == 0:
35 | print("collapse progress %d / %d\n" % (num_collapses, total_collapses))
36 |
37 | # get the minimum cost (slow)
38 | # note: a faster version should use a priority queue
39 | e = np.argmin(ECost)
40 |
41 | # check if the edge is degenerated
42 | if E[e,0] == E[e,1]:
43 | E = np.delete(E,e,0)
44 | ECost = np.delete(ECost,e)
45 | continue
46 |
47 | # move vertex vi
48 | vi, vj = E[e,:]
49 | V[vi,:] = (V[vi,:] + V[vj,:]) / 2.
50 |
51 | # reconnect edges
52 | prev_e = (e-1) % E.shape[0]
53 | next_e = (e+1) % E.shape[0]
54 | E[next_e,0] = vi # keep E[e,0] and unreference vj
55 |
56 | # update edge costs
57 | ECost[prev_e] = np.sqrt(np.sum((V[E[prev_e,0],:] - V[E[prev_e,1],:])**2))
58 | ECost[next_e] = np.sqrt(np.sum((V[E[next_e,0],:] - V[E[next_e,1],:])**2))
59 |
60 | # post collapse update
61 | E = np.delete(E,e,0)
62 | ECost = np.delete(ECost,e)
63 | num_collapses += 1
64 |
65 | # stopping
66 | if num_collapses == total_collapses:
67 | break
68 | V,E,_ = remove_unreferenced(V,E)
69 | return V, E
--------------------------------------------------------------------------------
/general/ordered_outline.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .outline import outline
3 |
4 | def ordered_outline(ForO):
5 | """
6 |     this function computes the ordered boundary curves of a mesh
7 |
8 | Inputs:
9 | F: |F|x3 face index list
10 | or
11 | O: |E|x2 unordered outline
12 |
13 | Outputs
14 |       L: list of lists of ordered boundary vertex indices, such that L[0] is the vertex list of the 0th boundary curve
15 | """
16 | if ForO.shape[1] == 3: # input is a face list
17 | F = np.array(ForO)
18 | O = outline(F)
19 | O = np.array(O)
20 | elif ForO.shape[1] == 2: # input is a (unordered) boundary curve list
21 | O = np.array(ForO)
22 |
23 | # index map for vertices such that IMV[old_vIdx] = new_vIdx
24 | uV = np.unique(O)
25 | nV = O.max() + 1
26 | IMV = np.zeros(nV, dtype = np.int64)
27 | IMV[uV] = np.arange(len(uV))
28 |
29 | # inverse index map such that invIMV[new_vIdx] = old_vIdx
30 | invIMV = uV
31 |
32 | # index map to O[:,0] such that old_vIdx = O[IMO[old_vIdx],0]
33 | IMO = np.zeros(nV, dtype = np.int64)
34 | IMO[O[:,0]] = np.arange(len(uV))
35 |
36 | L = [] # loop for multiple boundary loops
37 | visited = np.full((len(uV),),False) # whether visited (stored in new vIdx)
38 | while not np.all(visited):
39 | # get a vertex for a Loop
40 | vnew = np.where(visited == False)[0][0]
41 | v = invIMV[vnew]
42 | next_v = get_next_v(v,O,IMO)
43 |
44 | start_v = v # track starting vertex
45 | B = [] # to store each boundary loop
46 | B.append(start_v) # add the starting vertex
47 |
48 | # update visited list
49 | visited[IMV[start_v]] = True
50 |
51 |
52 | # find all the vertices of this loop
53 | while next_v != start_v:
54 | B.append(next_v) # add the next vertex
55 | visited[IMV[next_v]] = True
56 | v = next_v
57 | next_v = get_next_v(v,O,IMO)
58 | L.append(B)
59 | return L
60 |
61 | def get_next_v(v,O,IMO):
62 | return O[IMO[v],1]
--------------------------------------------------------------------------------
/general/outline.py:
--------------------------------------------------------------------------------
1 | import numpy as onp
2 | import scipy
3 | import jax.numpy as np
4 | def outline(F_jnp):
5 | """
6 |     this function extracts (unordered) boundary edges of a mesh
7 |
8 | Inputs
9 | F_jnp: |F|x3 jax numpy array of the face indices
10 |
11 | Outputs
12 | O: |E|x2 jax numpy array of unordered boundary edges
13 |
14 | Reference:
15 | this code is adapted from https://github.com/alecjacobson/gptoolbox/blob/master/mesh/outline.m
16 | """
17 |     F = onp.array(F_jnp) # convert to numpy for efficiency
18 | nV = F.max()+1
19 | row = F.flatten()
20 | col = F[:,[1,2,0]].flatten()
21 | data = onp.ones(len(row), dtype=np.int32)
22 | A = scipy.sparse.csr_matrix((data, (row, col)), shape=(nV,nV)) # build directed adj matrix
23 | AA = A - A.transpose() # figure out edges that only have one half edge
24 | AA.eliminate_zeros()
25 | I,J,V = scipy.sparse.find(AA) # get the non-zeros
26 | O = np.array([I[V>0], J[V>0]]).T # construct the boundary edge list
27 | return O
28 |
29 |
30 | # =====
31 | # I switched to the gptoolbox version above because a short profiling run showed it is 5x faster than the version below on a mesh with 1200 faces
32 | # =====
33 | # import numpy as np
34 | # import jax
35 |
36 | # def outline(F):
37 | # '''
38 | # OUTLINE compute the unordered outline edges of a triangle mesh
39 |
40 | # Input:
41 | # F (|F|,3) numpy array of face indices
42 | # Output:
43 | # O (|bE|,2) numpy array of boundary vertex indices, one edge per row
44 | # '''
45 |
46 | # # All halfedges
47 | # he = np.stack((np.ravel(F[:,[1,2,0]]),np.ravel(F[:,[2,0,1]])), axis=1)
48 |
49 | # # Sort hes to be able to find duplicates later
50 | # # inds = np.argsort(he, axis=1)
51 | # # he_sorted = np.sort(he, inds, axis=1)
52 | # he_sorted = np.sort(he, axis=1)
53 |
54 | # # Extract unique rows
55 | # _,unique_indices,unique_counts = np.unique(he_sorted, axis=0, return_index=True, return_counts=True)
56 |
57 | # # All the indices with only one unique count are the original boundary edges
58 | # # in he
59 | # O = he[unique_indices[unique_counts==1],:]
60 |
61 | # return O
62 |
--------------------------------------------------------------------------------
/general/remove_unreferenced.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def remove_unreferenced(V,F):
4 | """
5 |     remove unreferenced vertices from the mesh and reindex the face list accordingly
6 |
7 | Inputs
8 | V: |V|x3 array of vertex locations
9 | F: |F|x3 array of face list
10 |
11 | Outputs
12 | V,F: new mesh
13 | IMV: index map for vertices such that IMV[old_vIdx] = new_vIdx
14 | """
15 | V = np.array(V)
16 | F = np.array(F)
17 |
18 | # removed unreferenced vertices/faces
19 | nV = V.shape[0]
20 |
21 | # get a list of unique vertex indices from face list
22 | uV = np.unique(F)
23 |
24 | # index map for vertices such that IMV[old_vIdx] = new_vIdx
25 | IMV = np.zeros(nV, dtype = np.int64)
26 | IMV[uV] = np.arange(len(uV))
27 |
28 | # return the new mesh
29 | V = V[uV]
30 | F = IMV[F]
31 | return V,F,IMV
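32 |
33 | # A minimal usage sketch (hypothetical data, not part of the original file):
34 | # vertex 1 is never referenced by F, so it is dropped and F is reindexed
35 | if __name__ == "__main__":
36 |   V = np.array([[0.,0.,0.],[9.,9.,9.],[1.,0.,0.],[0.,1.,0.]])
37 |   F = np.array([[0,2,3]])
38 |   V2, F2, IMV = remove_unreferenced(V,F)
39 |   print(V2.shape[0], F2)  # expected: 3 [[0 1 2]]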
--------------------------------------------------------------------------------
/general/sample_2D_grid.py:
--------------------------------------------------------------------------------
1 | import numpy as onp
2 | import jax.numpy as np
3 |
4 | def sample_2D_grid(resolution, low = 0, high = 1):
5 | idx = onp.linspace(low,high,num=resolution)
6 | x, y = onp.meshgrid(idx, idx)
7 | V = onp.concatenate((x.reshape((-1,1)), y.reshape((-1,1))), 1)
8 | return np.array(V)
--------------------------------------------------------------------------------
/general/sdf_circle.py:
--------------------------------------------------------------------------------
1 | import numpy as onp
2 | import jax.numpy as np
3 |
4 | def sdf_circle(x, r = 0.282, center = np.array([0.5,0.5])):
5 | """
6 | output the SDF value of a circle in 2D
7 |
8 | Inputs
9 | x: nx2 array of locations
10 | r: radius of the circle
11 | center: center point of the circle
12 |
13 | Outputs
14 | array of signed distance values at x
15 | """
16 | dx = x - center
17 | return np.sqrt(np.sum((dx)**2, axis = 1)) - r
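18 |
19 | # A minimal usage sketch (hypothetical points, not part of the original file):
20 | # the SDF is -r at the center, ~0 on the circle, and positive outside
21 | if __name__ == "__main__":
22 |   x = np.array([[0.5,0.5],[0.782,0.5],[1.,1.]])
23 |   print(sdf_circle(x))  # expected roughly: [-0.282 0. 0.425]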
--------------------------------------------------------------------------------
/general/sdf_cross.py:
--------------------------------------------------------------------------------
1 | import numpy as onp
2 | import jax.numpy as np
3 |
4 | def sdf_cross(p, bx=0.35, by=0.12, r=0.):
5 | """
6 | output the signed distance value of a cross in 2D
7 |
8 | Inputs
9 | p: nx2 array of locations
10 | bx, by, r: parameters of the cross (please see the reference for more details)
11 |
12 | Outputs
13 | array of signed distance values at x
14 |
15 | Reference
16 | https://iquilezles.org/www/articles/distfunctions2d/distfunctions2d.htm
17 | """
18 | p = onp.array(p - 0.5)
19 | p = onp.abs(p)
20 | p = onp.sort(p,1)[:,[1,0]]
21 | b = onp.array([bx, by])
22 | q = p - b
23 | k = onp.max(q, 1)
24 | w = q
25 | w[k<=0,0] = b[1] - p[k<=0,0]
26 | w[k<=0,1] = -k[k<=0]
27 | w = onp.maximum(w, 0.0)
28 | length_w = onp.sqrt(onp.sum(w*w, 1))
29 | out = onp.sign(k) * length_w + r
30 | return np.array(out)
31 |
--------------------------------------------------------------------------------
/general/sdf_star.py:
--------------------------------------------------------------------------------
1 | import numpy as onp
2 | import jax.numpy as np
3 |
4 | def sdf_star(x, r = 0.22):
5 | """
6 | output the signed distance value of a star in 2D
7 |
8 | Inputs
9 | x: nx2 array of locations
10 | r: size of the star
11 |
12 | Outputs
13 | array of signed distance values at x
14 |
15 | Reference:
16 | https://iquilezles.org/www/articles/distfunctions2d/distfunctions2d.htm
17 | """
18 | x = onp.array(x)
19 | kxy = onp.array([-0.5,0.86602540378])
20 | kyx = onp.array([0.86602540378,-0.5])
21 | kz = 0.57735026919
22 | kw = 1.73205080757
23 |
24 | x = onp.abs(x - 0.5)
25 | x -= 2.0 * onp.minimum(x.dot(kxy), 0.0)[:,None] * kxy[None,:]
26 | x -= 2.0 * onp.minimum(x.dot(kyx), 0.0)[:,None] * kyx[None,:]
27 | x[:,0] -= onp.clip(x[:,0],r*kz,r*kw)
28 | x[:,1] -= r
29 | length_x = onp.sqrt(onp.sum(x*x, 1))
30 | return np.array(length_x*onp.sign(x[:,1]))
--------------------------------------------------------------------------------
/general/sdf_triangle.py:
--------------------------------------------------------------------------------
1 | import numpy as onp
2 | import jax.numpy as np
3 |
4 | def sdf_triangle(p, p0 = onp.array([.2,.2]), p1 = onp.array([.8,.2]), p2 = onp.array([.5,.8])):
5 | """
6 | output the signed distance value of a triangle in 2D
7 |
8 | Inputs
9 | p: nx2 array of locations
10 | p0,p1,p2: locations of the triangle corners
11 |
12 | Outputs
13 | array of signed distance values at x
14 | """
15 | p = onp.array(p)
16 | e0 = p1 - p0
17 | e1 = p2 - p1
18 | e2 = p0 - p2
19 | v0 = p - p0
20 | v1 = p - p1
21 | v2 = p - p2
22 | pq0 = v0 - e0[None,:] * onp.clip( v0.dot(e0) / e0.dot(e0), 0.0, 1.0 )[:,None]
23 | pq1 = v1 - e1[None,:] * onp.clip( v1.dot(e1) / e1.dot(e1), 0.0, 1.0 )[:,None]
24 | pq2 = v2 - e2[None,:] * onp.clip( v2.dot(e2) / e2.dot(e2), 0.0, 1.0 )[:,None]
25 | s = onp.sign( e0[0]*e2[1] - e0[1]*e2[0] )
26 |
27 | pq0pq0 = onp.sum(pq0*pq0, 1)
28 | d0 = onp.array([pq0pq0, s*(v0[:,0]*e0[1]-v0[:,1]*e0[0])]).T
29 | pq1pq1 = onp.sum(pq1*pq1, 1)
30 | d1 = onp.array([pq1pq1, s*(v1[:,0]*e1[1]-v1[:,1]*e1[0])]).T
31 | pq2pq2 = onp.sum(pq2*pq2, 1)
32 | d2 = onp.array([pq2pq2, s*(v2[:,0]*e2[1]-v2[:,1]*e2[0])]).T
33 | d = onp.minimum(onp.minimum(d0, d1),d2)
34 | out = -onp.sqrt(d[:,0]) * onp.sign(d[:,1])
35 | return np.array(out)
36 |
--------------------------------------------------------------------------------
/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-for-gp/jaxgptoolbox/7048aada5db1e6603a3d13fb1bc1ee2c61762985/logo.png
--------------------------------------------------------------------------------