├── .gitignore ├── LICENSE ├── README.md ├── assets ├── images │ └── spelunking_teaser_big.jpg └── matcaps │ ├── wax_b.png │ ├── wax_g.png │ ├── wax_k.png │ └── wax_r.png ├── environment.yml ├── sample_inputs ├── birdcage_occ.npz ├── bunny.npz ├── fox.npz └── hammer.npz └── src ├── affine.py ├── affine_layers.py ├── bucketing.py ├── extract_cell.py ├── geometry.py ├── implicit_function.py ├── implicit_mlp_utils.py ├── kd_tree.py ├── main_fit_implicit.py ├── main_intersection.py ├── main_spelunking.py ├── mlp.py ├── queries.py ├── render.py ├── sdf.py ├── slope_interval.py ├── slope_interval_layers.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Build directories 2 | build/ 3 | build_debug/ 4 | dist/ 5 | *.pyc 6 | __pycache__/ 7 | 8 | *.egg-info 9 | 10 | # Editor and OS things 11 | imgui.ini 12 | .polyscope.ini 13 | .DS_Store 14 | .vscode 15 | *.swp 16 | tags 17 | *.blend1 18 | pyrightconfig.json 19 | 20 | # Prerequisites 21 | *.d 22 | 23 | # Compiled Object files 24 | *.slo 25 | *.lo 26 | *.o 27 | *.obj 28 | 29 | # Precompiled Headers 30 | *.gch 31 | *.pch 32 | 33 | # Compiled Dynamic libraries 34 | *.so 35 | *.dylib 36 | *.dll 37 | 38 | # Fortran module files 39 | *.mod 40 | *.smod 41 | 42 | # Compiled Static libraries 43 | *.lai 44 | *.la 45 | *.a 46 | *.lib 47 | 48 | # Executables 49 | *.exe 50 | *.out 51 | *.app 52 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Nicholas Sharp and Alec Jacobson 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of 
the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Perform geometric queries on **neural implicit surfaces** like **ray casting**, **intersection testing**, **fast mesh extraction**, **closest points**, and **more**. Works on _general_ neural implicit surfaces (i.e. does not require a fitted signed distance function). Implemented in JAX. 2 | 3 |
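The core primitive behind all of these queries — computing guaranteed output bounds for a network over a whole spatial region, not just at sample points — can be sketched with plain interval arithmetic. The snippet below is an illustrative toy in pure Python, not this repository's API (the real implementation lives in `src/affine.py` and uses JAX plus tighter affine forms):

```python
# Toy range analysis: propagate per-coordinate interval bounds through a
# small ReLU MLP, then classify a box as provably-positive, provably-negative,
# or unknown. Illustrative only; function names here are not the repo's API.

def interval_dense(lo, hi, A, b):
    """Propagate interval bounds through the affine map x @ A + b.
    A is a list of rows (one per input), b a list of biases (one per output)."""
    n_out = len(A[0])
    out_lo, out_hi = [], []
    for j in range(n_out):
        l = h = b[j]
        for i in range(len(lo)):
            a = A[i][j]
            if a >= 0:
                l += a * lo[i]; h += a * hi[i]
            else:
                l += a * hi[i]; h += a * lo[i]
        out_lo.append(l); out_hi.append(h)
    return out_lo, out_hi

def interval_relu(lo, hi):
    # ReLU is monotone, so it maps interval endpoints to interval endpoints
    return [max(l, 0.0) for l in lo], [max(h, 0.0) for h in hi]

def classify_box(lo, hi, layers):
    """Return +1 / -1 / 0 if the scalar-output MLP is provably positive,
    provably negative, or unknown over the axis-aligned box [lo, hi]."""
    for A, b in layers[:-1]:
        lo, hi = interval_dense(lo, hi, A, b)
        lo, hi = interval_relu(lo, hi)
    lo, hi = interval_dense(lo, hi, *layers[-1])
    if lo[0] > 0.0: return +1
    if hi[0] < 0.0: return -1
    return 0
```

For example, for a tiny network computing relu(x) + relu(y) - 3, `classify_box` can prove the output is negative on the unit box at the origin, without evaluating the network at any individual point.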
4 |
5 |
7 | [project page] 8 | [PDF (4MB)] 9 | Authors: Nicholas Sharp & Alec Jacobson 10 |
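For intuition about why the code tracks affine coefficients rather than plain intervals (the `(base, aff, err)` tuples in `src/affine.py`), here is a tiny self-contained illustration — simplified to scalars and plain Python lists, so the names and arithmetic are a sketch of, not a substitute for, the repo's JAX implementation:

```python
# A value is modeled as (base, aff, err): `base` is the center value, `aff`
# holds one coefficient per affine symbol (one per box axis), and `err` is a
# nonnegative symmetric error term -- mirroring src/affine.py, simplified.

def radius(aff, err):
    # half-width of the representable range
    return sum(abs(a) for a in aff) + err

def may_contain_bounds(base, aff, err):
    # interval of values the quantity may take over the domain
    r = radius(aff, err)
    return base - r, base + r

def sub(x, y):
    # affine subtraction: bases and coefficients subtract, error terms add
    bx, ax, ex = x
    by, ay, ey = y
    return (bx - by, [cx - cy for cx, cy in zip(ax, ay)], ex + ey)

# x spans [-1, 1]: base 0, coefficient 1 on the single box axis, no error.
x = (0.0, [1.0], 0.0)
lo, hi = may_contain_bounds(*sub(x, x))  # tight bound: exactly [0, 0]
```

Plain interval arithmetic would report [-2, 2] for x - x; the affine form keeps the correlation between the two operands and returns the tight [0, 0], which is why it yields much smaller query regions in practice.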
11 | 12 | 13 | This code accompanies the paper **"Spelunking the Deep: Guaranteed Queries for General Neural Implicit Surfaces via Range Analysis"**, published at **SIGGRAPH 2022** (recognized with a Best Paper Award!). 14 | 15 | --- 16 | 17 | Neural implicit surface representations encode a 3D surface as a level set of a neural network applied to coordinates; this representation has many promising properties. But once you have one of these surfaces, how do you perform standard geometric queries like casting rays against the surface, or testing if two such surfaces intersect? This is especially tricky if the neural function is _not_ a signed distance function (SDF), such as occupancy functions as in popular radiance field formulations, or randomly initialized networks during training. 18 | 19 | This project introduces a technique for implementing these queries using __range analysis__, an automatic function transformation which we use to analyze a forward pass of the network and compute provably-valid bounds on the output range of the network over a spatial region. This basic operation is used as a building block for a variety of geometric queries. 20 | 21 | 22 | ## How does it work? How do I use it? 23 | 24 | This project **does not** propose any new network architectures, training procedures, etc. Instead, it takes an existing neural implicit MLP and analyzes it to perform geometric queries. Right now the code is only set up for simple feedforward MLP architectures with ReLU or tanh activations (please file an issue to tell us about other architectures you would like to see!). 25 | 26 | Exposing this machinery as a library is tricky, because the algorithm needs to analyze the internals of your neural network evaluation to work (somewhat similar to autodiff frameworks). 
For this reason, the library takes a simple specification of your MLP in a dictionary format; a convenience script is included to fit MLPs in this format, or see [below](TODO) for manually constructing dictionaries from your own data. 27 | 28 | Once your MLP is ready in the format expected by this library, the functions in `queries.py` (raycasting) and `kd_tree.py` (spatial queries) can be called to perform queries. 29 | 30 | Additionally, several [demo scripts](#demo-scripts) are included to interactively explore these queries. 31 | 32 | 33 | ## Quick guide: 34 | 35 | - Affine arithmetic rules appear in `affine.py` and `affine_layers.py` 36 | - Queries are implemented in `queries.py` (raycasting) and `kd_tree.py` (spatial queries) 37 | 38 | > **PERFORMANCE NOTE:** JAX uses JIT-compiled kernels for performance. All algorithms will be dramatically slower on the first pass due to JIT compilation (which can take up to a minute). We make use of bucketing to ensure there are only a small number of kernels that need to be JIT'd for a given algorithm, but it still takes time. All routines should be run twice to get an accurate idea of performance. 39 | 40 | 41 | ## Installation 42 | 43 | This code has been tested on both Linux and OSX machines. Windows support is unknown. 44 | 45 | Some standard Python packages are needed, all available in package managers. A conda `environment.yml` file is included to help set up the environment, but note that installing JAX may require nonstandard instructions---see the JAX documentation. Code has been tested with JAX 0.2.27 and 0.3.4. 46 | 47 | 48 | ## Demo scripts 49 | 50 | #### Spelunking 51 | 52 | This script provides an interactive GUI allowing the exploration of most of the queries described in this work. 53 | 54 | Run like: 55 | 56 | ``` 57 | python src/main_spelunking.py sample_inputs/fox.npz 58 | ``` 59 | 60 | This application can run most of the algorithms described in this work. 
Use the buttons on the right to explore them and visualize the output. 61 | 62 | Shapes are "previewed" via coarse meshes for the sake of the user interface. The coarse preview mesh is not used for any computation. 63 | 64 | To make it easier to see hierarchy trees, enable a slice plane in the upper-left menu panel under [View] --> [Slice Plane]. 65 | 66 | 67 | #### Intersection 68 | 69 | This script provides an interactive GUI allowing the exploration of intersection testing queries between two neural implicit shapes. 70 | 71 | Run like: 72 | 73 | ``` 74 | python src/main_intersection.py sample_inputs/hammer.npz sample_inputs/bunny.npz 75 | ``` 76 | 77 | To adjust the objects, click in the left menu bar [Options] --> [Transform] --> [Show Gizmo] and drag around. Don't use the scaling function of the gizmo; it is not supported. 78 | 79 | This query is configured to detect a single intersection, and exits as soon as any intersection is found. The result will be printed to the terminal, and a point will be placed at the intersection location, though this location will be inside a shape, of course. 80 | 81 | To make it easier to see intersections, try [Options] --> [Transparency] to make meshes transparent. 82 | 83 | To make it easier to see hierarchy trees, enable a slice plane in the upper-left menu panel under [View] --> [Slice Plane]. 84 | 85 | 86 | #### Fit implicit 87 | 88 | This script is a helper to quickly fit suitable test MLPs to triangle meshes and save them in the format expected by this codebase. 89 | 90 | Run like: 91 | ``` 92 | python src/main_fit_implicit.py path/to/your/favorite/mesh.obj mlp_out.npz 93 | ``` 94 | 95 | Use it to fit your own implicit functions in our format; the flags include options for SDF vs. occupancy, layer sizes, etc. 96 | 97 | 98 | 99 | ## Misc notes 100 | 101 | Currently, JAX defaults to allocating nearly all GPU memory at startup. 
This may cause problems when subroutines external to JAX attempt to allocate additional memory. One workaround is to prepend the environment variable `XLA_PYTHON_CLIENT_MEM_FRACTION=.60` to the command, which limits JAX to 60% of GPU memory. 102 | -------------------------------------------------------------------------------- /assets/images/spelunking_teaser_big.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmwsharp/neural-implicit-queries/c17e4b54f216cefb02d00ddba25c4f15b9873278/assets/images/spelunking_teaser_big.jpg -------------------------------------------------------------------------------- /assets/matcaps/wax_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmwsharp/neural-implicit-queries/c17e4b54f216cefb02d00ddba25c4f15b9873278/assets/matcaps/wax_b.png -------------------------------------------------------------------------------- /assets/matcaps/wax_g.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmwsharp/neural-implicit-queries/c17e4b54f216cefb02d00ddba25c4f15b9873278/assets/matcaps/wax_g.png -------------------------------------------------------------------------------- /assets/matcaps/wax_k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmwsharp/neural-implicit-queries/c17e4b54f216cefb02d00ddba25c4f15b9873278/assets/matcaps/wax_k.png -------------------------------------------------------------------------------- /assets/matcaps/wax_r.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmwsharp/neural-implicit-queries/c17e4b54f216cefb02d00ddba25c4f15b9873278/assets/matcaps/wax_r.png -------------------------------------------------------------------------------- /environment.yml: 
-------------------------------------------------------------------------------- 1 | name: implicit-env 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1 7 | - _openmp_mutex=4.5 8 | - absl-py=1.0.0 9 | - blas=1.0 10 | - blosc=1.21.0 11 | - brotli=1.0.9 12 | - brunsli=0.1 13 | - bzip2=1.0.8 14 | - c-ares=1.18.1 15 | - ca-certificates=2021.10.8 16 | - certifi=2021.10.8 17 | - cfitsio=3.470 18 | - charls=2.2.0 19 | - cloudpickle=2.0.0 20 | - cycler=0.11.0 21 | - cytoolz=0.11.0 22 | - dask-core=2021.10.0 23 | - dbus=1.13.18 24 | - expat=2.4.1 25 | - fontconfig=2.13.1 26 | - fonttools=4.25.0 27 | - freetype=2.11.0 28 | - fsspec=2022.1.0 29 | - giflib=5.2.1 30 | - glib=2.69.1 31 | - gst-plugins-base=1.14.0 32 | - gstreamer=1.14.0 33 | - icu=58.2 34 | - igl=2.2.1 35 | - imagecodecs=2021.8.26 36 | - imageio=2.9.0 37 | - intel-openmp=2021.4.0 38 | - jax=0.2.27 39 | - jaxlib=0.1.75 40 | - jpeg=9d 41 | - jxrlib=1.1 42 | - kiwisolver=1.3.1 43 | - krb5=1.19.2 44 | - lcms2=2.12 45 | - ld_impl_linux-64=2.35.1 46 | - lerc=3.0 47 | - libaec=1.0.4 48 | - libblas=3.9.0 49 | - libcblas=3.9.0 50 | - libcurl=7.80.0 51 | - libdeflate=1.8 52 | - libedit=3.1.20210910 53 | - libev=4.33 54 | - libffi=3.3 55 | - libgcc-ng=11.2.0 56 | - libgfortran-ng=7.5.0 57 | - libgfortran4=7.5.0 58 | - libnghttp2=1.46.0 59 | - libpng=1.6.37 60 | - libssh2=1.9.0 61 | - libstdcxx-ng=11.2.0 62 | - libtiff=4.2.0 63 | - libuuid=1.0.3 64 | - libwebp=1.2.0 65 | - libwebp-base=1.2.0 66 | - libxcb=1.14 67 | - libxml2=2.9.12 68 | - libzopfli=1.0.3 69 | - llvm-openmp=12.0.1 70 | - locket=0.2.1 71 | - lz4-c=1.9.3 72 | - matplotlib=3.5.0 73 | - matplotlib-base=3.5.0 74 | - mkl=2021.4.0 75 | - mkl-service=2.4.0 76 | - mkl_fft=1.3.1 77 | - mkl_random=1.2.2 78 | - munkres=1.1.4 79 | - ncurses=6.3 80 | - networkx=2.6.3 81 | - numpy=1.21.2 82 | - numpy-base=1.21.2 83 | - olefile=0.46 84 | - openjpeg=2.4.0 85 | - openssl=1.1.1l 86 | - opt_einsum=3.3.0 87 | - packaging=21.3 88 | - 
partd=1.2.0 89 | - pcre=8.45 90 | - pillow=8.4.0 91 | - pip=21.2.4 92 | - polyscope=1.3.0 93 | - pyparsing=3.0.4 94 | - pyqt=5.9.2 95 | - python=3.9.7 96 | - python-dateutil=2.8.2 97 | - python-flatbuffers=2.0 98 | - python_abi=3.9 99 | - pywavelets=1.1.1 100 | - pyyaml=6.0 101 | - qt=5.9.7 102 | - readline=8.1.2 103 | - scikit-image=0.18.3 104 | - scipy=1.7.3 105 | - setuptools=58.0.4 106 | - sip=4.19.13 107 | - six=1.16.0 108 | - snappy=1.1.8 109 | - sqlite=3.37.0 110 | - tifffile=2021.7.2 111 | - tk=8.6.11 112 | - toolz=0.11.2 113 | - tornado=6.1 114 | - typing_extensions=4.0.1 115 | - tzdata=2021e 116 | - wheel=0.37.1 117 | - xorg-kbproto=1.0.7 118 | - xorg-libx11=1.7.2 119 | - xorg-libxext=1.3.4 120 | - xorg-libxinerama=1.1.4 121 | - xorg-xextproto=7.3.0 122 | - xorg-xproto=7.0.31 123 | - xz=5.2.5 124 | - yaml=0.2.5 125 | - zfp=0.5.5 126 | - zlib=1.2.11 127 | - zstd=1.4.9 128 | -------------------------------------------------------------------------------- /sample_inputs/birdcage_occ.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmwsharp/neural-implicit-queries/c17e4b54f216cefb02d00ddba25c4f15b9873278/sample_inputs/birdcage_occ.npz -------------------------------------------------------------------------------- /sample_inputs/bunny.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmwsharp/neural-implicit-queries/c17e4b54f216cefb02d00ddba25c4f15b9873278/sample_inputs/bunny.npz -------------------------------------------------------------------------------- /sample_inputs/fox.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmwsharp/neural-implicit-queries/c17e4b54f216cefb02d00ddba25c4f15b9873278/sample_inputs/fox.npz -------------------------------------------------------------------------------- /sample_inputs/hammer.npz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/nmwsharp/neural-implicit-queries/c17e4b54f216cefb02d00ddba25c4f15b9873278/sample_inputs/hammer.npz -------------------------------------------------------------------------------- /src/affine.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import dataclasses 3 | from dataclasses import dataclass 4 | 5 | import numpy as np 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | 10 | import utils 11 | 12 | import implicit_function 13 | from implicit_function import SIGN_UNKNOWN, SIGN_POSITIVE, SIGN_NEGATIVE 14 | 15 | # === Function wrappers 16 | 17 | class AffineImplicitFunction(implicit_function.ImplicitFunction): 18 | 19 | def __init__(self, affine_func, ctx): 20 | super().__init__("classify-only") 21 | self.affine_func = affine_func 22 | self.ctx = ctx 23 | self.mode_dict = {'ctx' : self.ctx} 24 | 25 | 26 | def __call__(self, params, x): 27 | f = lambda x : self.affine_func(params, x, self.mode_dict) 28 | return wrap_scalar(f)(x) 29 | 30 | # the parent class automatically delegates to this 31 | # def classify_box(self, params, box_lower, box_upper): 32 | # pass 33 | 34 | def classify_general_box(self, params, box_center, box_vecs, offset=0.): 35 | 36 | d = box_center.shape[-1] 37 | v = box_vecs.shape[-2] 38 | assert box_center.shape == (d,), "bad box_vecs shape" 39 | assert box_vecs.shape == (v,d), "bad box_vecs shape" 40 | keep_ctx = dataclasses.replace(self.ctx, affine_domain_terms=v) 41 | 42 | # evaluate the function 43 | input = coordinates_in_general_box(keep_ctx, box_center, box_vecs) 44 | output = self.affine_func(params, input, {'ctx' : keep_ctx}) 45 | 46 | # compute relevant bounds 47 | may_lower, may_upper = may_contain_bounds(keep_ctx, output) 48 | # must_lower, must_upper = must_contain_bounds(keep_ctx, output) 49 | 50 | # determine the type of the region 51 | 
output_type = SIGN_UNKNOWN 52 | output_type = jnp.where(may_lower > offset, SIGN_POSITIVE, output_type) 53 | output_type = jnp.where(may_upper < -offset, SIGN_NEGATIVE, output_type) 54 | 55 | return output_type 56 | 57 | # === Affine utilities 58 | 59 | # We represent affine data as a tuple input=(base,aff,err). Base is a normal shape (d,) primal vector value, affine is a (v,d) array of affine coefficients (may be v=0), err is a centered interval error shape (d,), which must be nonnegative. 60 | # For constant values, aff == err == None. If is_const(input) == False, then it is guaranteed that aff and err are non-None. 61 | 62 | @dataclass(frozen=True) 63 | class AffineContext(): 64 | mode: str = 'affine_fixed' 65 | truncate_count: int = -777 66 | truncate_policy: str = 'absolute' 67 | affine_domain_terms: int = 0 68 | n_append: int = 0 69 | 70 | def __post_init__(self): 71 | if self.mode not in ['interval', 'affine_fixed', 'affine_truncate', 'affine_append', 'affine_all']: 72 | raise ValueError("invalid mode") 73 | 74 | if self.mode == 'affine_truncate': 75 | if self.truncate_count is None: 76 | raise ValueError("must specify truncate count") 77 | 78 | def is_const(input): 79 | base, aff, err = input 80 | if err is not None: return False 81 | return aff is None or aff.shape[0] == 0 82 | 83 | 84 | # Compute the 'radius' (width of the approximation) 85 | def radius(input): 86 | if is_const(input): return 0. 
87 | base, aff, err = input 88 | rad = jnp.sum(jnp.abs(aff), axis=0) 89 | if err is not None: 90 | rad += err 91 | return rad 92 | 93 | # Construct affine inputs for the coordinates in k-dimensional box 94 | # lower,upper should be vectors of length-k 95 | def coordinates_in_box(ctx, lower, upper): 96 | center = 0.5 * (lower+upper) 97 | vec = upper - center 98 | axis_vecs = jnp.diag(vec) 99 | return coordinates_in_general_box(ctx, center, axis_vecs) 100 | 101 | # Construct affine inputs for the coordinates in k-dimensional box, 102 | # which is not necessarily axis-aligned 103 | # - center is the center of the box 104 | # - vecs is a (V,D) array of vectors which point from the center of the box to its 105 | # edges. These will correspond to each of the affine symbols, with the direction 106 | # of the vector becoming the positive orientation for the symbol. 107 | # (this function is nearly a no-op, but giving it this name makes it easier to 108 | # reason about) 109 | def coordinates_in_general_box(ctx, center, vecs): 110 | base = center 111 | if ctx.mode == 'interval': 112 | aff = jnp.zeros((0,center.shape[-1])) 113 | err = jnp.sum(jnp.abs(vecs), axis=0) 114 | else: 115 | aff = vecs 116 | err = jnp.zeros_like(center) 117 | return base, aff, err 118 | 119 | def may_contain_bounds(ctx, input): 120 | ''' 121 | An interval range of values that `input` _may_ take along the domain 122 | ''' 123 | base, aff, err = input 124 | rad = radius(input) 125 | return base-rad, base+rad 126 | 127 | def truncate_affine(ctx, input): 128 | # do nothing if the input is a constant or we are not in truncate mode 129 | if is_const(input): return input 130 | if ctx.mode != 'affine_truncate': 131 | return input 132 | 133 | # gather values 134 | base, aff, err = input 135 | n_keep = ctx.truncate_count 136 | 137 | # if the affine list is shorter than the truncation length, nothing to do 138 | if aff.shape[0] <= n_keep: 139 | return input 140 | 141 | # compute the magnitudes of each affine value 142 | # TODO fancier policies? 143 | if ctx.truncate_policy == 'absolute': 144 | affine_mags = jnp.sum(jnp.abs(aff), axis=-1) 145 | elif ctx.truncate_policy == 'relative': 146 | affine_mags = jnp.sum(jnp.abs(aff), axis=-1) / jnp.abs(base) 147 | else: 148 | raise RuntimeError("bad policy") 149 | 150 | 151 | # sort the affine terms by magnitude 152 | sort_inds = jnp.argsort(-affine_mags, axis=-1) # sort to decreasing order 153 | aff = aff[sort_inds,:] 154 | 155 | # keep the n_keep highest-magnitude entries 156 | aff_keep = aff[:n_keep,:] 157 | aff_drop = aff[n_keep:,:] 158 | 159 | # for all the entries we aren't keeping, add their contribution to the interval error 160 | err = err + jnp.sum(jnp.abs(aff_drop), axis=0) 161 | 162 | return base, aff_keep, err 163 | 164 | def apply_linear_approx(ctx, input, alpha, beta, delta): 165 | base, aff, err = input 166 | base = alpha * base + beta 167 | if aff is not None: 168 | aff = alpha * aff 169 | 170 | # This _should_ always be positive by definition. Always be sure your 171 | # approximation routines are generating positive delta. 172 | # At most, we are defending against floating point error here. 
173 | delta = jnp.abs(delta) 174 | 175 | if ctx.mode in ['interval', 'affine_fixed']: 176 | err = alpha * err + delta 177 | elif ctx.mode in ['affine_truncate', 'affine_all']: 178 | err = alpha * err 179 | new_aff = jnp.diag(delta) 180 | aff = jnp.concatenate((aff, new_aff), axis=0) 181 | base, aff, err = truncate_affine(ctx, (base, aff, err)) 182 | 183 | elif ctx.mode in ['affine_append']: 184 | err = alpha * err 185 | 186 | keep_vals, keep_inds = jax.lax.top_k(delta, ctx.n_append) 187 | row_inds = jnp.arange(ctx.n_append) 188 | new_aff = jnp.zeros((ctx.n_append, aff.shape[-1])) 189 | new_aff = new_aff.at[row_inds, keep_inds].set(keep_vals) 190 | aff = jnp.concatenate((aff, new_aff), axis=0) 191 | err = err + (jnp.sum(delta) - jnp.sum(keep_vals)) # add in the error for the affs we didn't keep 192 | 193 | return base, aff, err 194 | 195 | # Convert to/from the affine representation from an ordinary value representing a scalar 196 | def from_scalar(x): 197 | return x, None, None 198 | def to_scalar(input): 199 | if not is_const(input): 200 | raise ValueError("non const input") 201 | return input[0] 202 | def wrap_scalar(func): 203 | return lambda x : to_scalar(func(from_scalar(x))) 204 | -------------------------------------------------------------------------------- /src/affine_layers.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from jax import lax 6 | 7 | import affine 8 | import mlp 9 | import utils 10 | 11 | def dense(input, A, b, ctx): 12 | if(affine.is_const(input)): 13 | out = jnp.dot(input[0], A) 14 | if b is not None: 15 | out += b 16 | return out, None, None 17 | 18 | base, aff, err = input 19 | 20 | def dot(x, with_abs=False): 21 | myA = jnp.abs(A) if with_abs else A 22 | return jnp.dot(x, myA) 23 | 24 | base = dot(base) 25 | aff = jax.vmap(dot)(aff) 26 | err = dot(err, with_abs=True) 27 | 28 | if b is not None: 29 | base += b 30 | 
31 | return base, aff, err 32 | mlp.apply_func['affine']['dense'] = dense 33 | 34 | def relu(input, ctx): 35 | # Chebyshev bound 36 | base, aff, err = input 37 | 38 | if affine.is_const(input): 39 | return jax.nn.relu(base), aff, err 40 | 41 | lower, upper = affine.may_contain_bounds(ctx, input) 42 | 43 | # Compute the linearized approximation 44 | alpha = (jax.nn.relu(upper) - jax.nn.relu(lower)) / (upper - lower) 45 | alpha = jnp.where(lower >= 0, 1., alpha) 46 | alpha = jnp.where(upper < 0, 0., alpha) 47 | # handle numerical badness in the denominator above 48 | alpha = jnp.nan_to_num(alpha, nan=0.0, copy=False) # necessary? 49 | alpha = jnp.clip(alpha, a_min=0., a_max=1.) 50 | 51 | # here, alpha/beta are necessarily positive, which makes this simpler 52 | beta = (jax.nn.relu(lower) - alpha * lower) / 2 53 | delta = beta 54 | 55 | output = affine.apply_linear_approx(ctx, input, alpha, beta, delta) 56 | return output 57 | mlp.apply_func['affine']['relu'] = relu 58 | 59 | def elu(input, ctx): 60 | # Chebyshev bound 61 | # Confusingly, elu has a parameter typically called 'alpha', and we also use 'alpha' for our linearization notation. Here we simply ignore elu's alpha and do not support it. 62 | base, aff, err = input 63 | 64 | if affine.is_const(input): 65 | return jax.nn.elu(base), aff, err 66 | 67 | lower, upper = affine.may_contain_bounds(ctx, input) 68 | 69 | # Compute the linearized approximation 70 | lowerF = jax.nn.elu(lower) 71 | upperF = jax.nn.elu(upper) 72 | # lowerS = jnp.where(lower < 0, lowerF + 1., 1.) 73 | # upperS = jnp.where(upper < 0, upperF + 1., 1.) 74 | lowerS = jnp.minimum(jnp.exp(lower), 1.) # more numerically stable than ^^^, but costs exp() 75 | upperS = jnp.minimum(jnp.exp(upper), 1.) 76 | 77 | alpha = (upperF - lowerF) / (upper - lower) 78 | alpha = jnp.where(lower >= 0, 1., alpha) 79 | # handle numerical badness in the denominator above 80 | alpha = jnp.nan_to_num(alpha, nan=0.0, copy=False) # necessary? 
81 | alpha = jnp.clip(alpha, a_min=lowerS, a_max=upperS) 82 | 83 | # the alpha tangent point necessarily occurs in the <= 0. part of the function 84 | r_upper = (lowerF - alpha * lower) 85 | x_lower = jnp.clip(jnp.log(alpha), a_min=lower, a_max=upper) 86 | r_lower = (alpha-1.) - alpha * x_lower 87 | beta = 0.5 * (r_upper + r_lower) 88 | # delta = r_upper - beta 89 | delta = 0.5 * jnp.abs(r_upper - r_lower) # this is very defensive, to ensure delta>=0 90 | 91 | # in strictly > 0 case, just directly set the result 92 | alpha = jnp.where(lower >= 0, 1., alpha) 93 | beta = jnp.where(lower >= 0, 0., beta) 94 | delta = jnp.where(lower >= 0, 0., delta) 95 | 96 | output = affine.apply_linear_approx(ctx, input, alpha, beta, delta) 97 | return output 98 | mlp.apply_func['affine']['elu'] = elu 99 | 100 | def sin(input, ctx): 101 | # not-quite Chebyshev bound 102 | base, aff, err = input 103 | pi = jnp.pi 104 | 105 | if affine.is_const(input): 106 | return jnp.sin(base), aff, err 107 | 108 | lower, upper = affine.may_contain_bounds(ctx, input) 109 | 110 | slope_lower, slope_upper = utils.cos_bound(lower, upper) 111 | alpha = 0.5 * (slope_lower + slope_upper) # this is NOT the Chebyshev value, but seems reasonable 112 | alpha = jnp.clip(alpha, a_min=-1., a_max=1.) # (should already be there, this is for numerics only) 113 | 114 | # We want to find the minima/maxima of (sin(x) - alpha*x) on [lower, upper] to compute our 115 | # beta and delta. In addition to the endpoints, a little calculus shows there can be interior 116 | # extrema at +-arccos(alpha) + 2kpi for some integer k. 
117 | # The extrema will occur at these offsets (shifted by multiples of 2pi), or at the endpoints. 118 | intA = jnp.arccos(alpha) 119 | intB = -intA 120 | 121 | # The first and last occurrence of a value which repeats mod 2pi on the domain [lower, upper] 122 | # (these give the only possible locations for our extrema) 123 | def first(x): return 2.*pi*jnp.ceil((lower + x) / (2.*pi)) - x 124 | def last(x): return 2.*pi*jnp.floor((upper - x) / (2.*pi)) + x 125 | 126 | extrema_locs = [lower, upper, first(intA), last(intA), first(intB), last(intB)] 127 | extrema_locs = [jnp.clip(x, a_min=lower, a_max=upper) for x in extrema_locs] 128 | extrema_vals = [jnp.sin(x) - alpha * x for x in extrema_locs] 129 | 130 | r_lower = utils.minimum_all(extrema_vals) 131 | r_upper = utils.maximum_all(extrema_vals) 132 | 133 | beta = 0.5 * (r_upper + r_lower) 134 | delta = r_upper - beta 135 | 136 | output = affine.apply_linear_approx(ctx, input, alpha, beta, delta) 137 | return output 138 | mlp.apply_func['affine']['sin'] = sin 139 | 140 | def pow2_frequency_encode(input, ctx, coefs, shift=None): 141 | base, aff, err = input 142 | 143 | # TODO debug 144 | if len(base.shape) > 1: 145 | raise ValueError("big base") 146 | 147 | # expand the length-d inputs to a length-d*c vector 148 | def s(with_shift, x): 149 | out = (x[:,None] * coefs[None,:]) 150 | if with_shift and shift is not None: 151 | out += shift 152 | return out.flatten() 153 | 154 | base = s(True, base) 155 | if affine.is_const(input): 156 | return base, None, None 157 | 158 | aff = jax.vmap(partial(s, False))(aff) 159 | err = s(False, err) 160 | 161 | return base, aff, err 162 | mlp.apply_func['affine']['pow2_frequency_encode'] = pow2_frequency_encode 163 | 164 | def squeeze_last(input, ctx): 165 | base, aff, err = input 166 | s = lambda x : jnp.squeeze(x, axis=0) 167 | base = s(base) 168 | if affine.is_const(input): 169 | return base, None, None 170 | aff = jax.vmap(s)(aff) 171 | err = s(err) 172 | return base, aff, err 173 | mlp.apply_func['affine']['squeeze_last'] = squeeze_last 174 | 175 | def 
spatial_transformation(input, R, t, ctx): 176 | # if the shape transforms by R,t, input points need the opposite transform 177 | R_inv = jnp.linalg.inv(R) 178 | t_inv = jnp.dot(R_inv, -t) 179 | return dense(input, A=R_inv, b=t_inv, ctx=ctx) 180 | mlp.apply_func['affine']['spatial_transformation'] = spatial_transformation 181 | -------------------------------------------------------------------------------- /src/bucketing.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | from functools import partial 5 | 6 | # Populate with powers of 2 128 and up 7 | bucket_sizes = [] 8 | for s in range(7,31): 9 | bucket_sizes.append(2**s) 10 | def get_next_bucket_size(s): 11 | for b in bucket_sizes: 12 | if s <= b: 13 | return b 14 | raise ValueError("max bucket size exceeded") 15 | 16 | @partial(jax.jit, static_argnames=("bucket_size")) 17 | def compactify_and_rebucket_arrays(mask, bucket_size, *arrs): 18 | N_in = mask.sum() 19 | out_mask = jnp.arange(0, bucket_size) < N_in 20 | INVALID_IND = bucket_size + 1 21 | target_inds = jnp.nonzero(mask, size=bucket_size, fill_value=INVALID_IND) 22 | 23 | out_arrs = [] 24 | for a in arrs: 25 | if a is None: 26 | out_arrs.append(a) 27 | continue 28 | 29 | out = a.at[target_inds,...].get(mode='drop').squeeze(0) 30 | out_arrs.append(out) 31 | 32 | return out_mask, N_in, *out_arrs 33 | 34 | 35 | def fits_in_smaller_bucket(size, curr_bucket_size): 36 | return get_next_bucket_size(size) < curr_bucket_size 37 | -------------------------------------------------------------------------------- /src/extract_cell.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | from functools import partial 5 | import math 6 | 7 | import numpy as np 8 | 9 | import utils 10 | 11 | # The raw table of all 256 cases (from http://www.paulbourke.net/geometry/polygonise/) 12 | MC_TRI_TABLE_RAW = [ 13 | 
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 14 | [0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 15 | [0, 1, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 16 | [1, 8, 3, 9, 8, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 17 | [1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 18 | [0, 8, 3, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 19 | [9, 2, 10, 0, 2, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 20 | [2, 8, 3, 2, 10, 8, 10, 9, 8, -1, -1, -1, -1, -1, -1, -1], 21 | [3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 22 | [0, 11, 2, 8, 11, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 23 | [1, 9, 0, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 24 | [1, 11, 2, 1, 9, 11, 9, 8, 11, -1, -1, -1, -1, -1, -1, -1], 25 | [3, 10, 1, 11, 10, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 26 | [0, 10, 1, 0, 8, 10, 8, 11, 10, -1, -1, -1, -1, -1, -1, -1], 27 | [3, 9, 0, 3, 11, 9, 11, 10, 9, -1, -1, -1, -1, -1, -1, -1], 28 | [9, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 29 | [4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 30 | [4, 3, 0, 7, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 31 | [0, 1, 9, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 32 | [4, 1, 9, 4, 7, 1, 7, 3, 1, -1, -1, -1, -1, -1, -1, -1], 33 | [1, 2, 10, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 34 | [3, 4, 7, 3, 0, 4, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1], 35 | [9, 2, 10, 9, 0, 2, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1], 36 | [2, 10, 9, 2, 9, 7, 2, 7, 3, 7, 9, 4, -1, -1, -1, -1], 37 | [8, 4, 7, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 38 | [11, 4, 7, 11, 2, 4, 2, 0, 4, -1, -1, -1, -1, -1, -1, -1], 39 | [9, 0, 1, 8, 4, 7, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1], 40 | [4, 7, 11, 9, 4, 11, 9, 11, 2, 9, 2, 1, -1, -1, -1, -1], 41 | [3, 10, 1, 3, 11, 10, 7, 8, 4, -1, -1, -1, -1, -1, -1, -1], 42 | [1, 11, 10, 1, 4, 11, 1, 0, 4, 7, 11, 4, -1, -1, -1, -1], 43 | [4, 7, 8, 9, 0, 11, 9, 11, 10, 
11, 0, 3, -1, -1, -1, -1], 44 | [4, 7, 11, 4, 11, 9, 9, 11, 10, -1, -1, -1, -1, -1, -1, -1], 45 | [9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 46 | [9, 5, 4, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 47 | [0, 5, 4, 1, 5, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 48 | [8, 5, 4, 8, 3, 5, 3, 1, 5, -1, -1, -1, -1, -1, -1, -1], 49 | [1, 2, 10, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 50 | [3, 0, 8, 1, 2, 10, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1], 51 | [5, 2, 10, 5, 4, 2, 4, 0, 2, -1, -1, -1, -1, -1, -1, -1], 52 | [2, 10, 5, 3, 2, 5, 3, 5, 4, 3, 4, 8, -1, -1, -1, -1], 53 | [9, 5, 4, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 54 | [0, 11, 2, 0, 8, 11, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1], 55 | [0, 5, 4, 0, 1, 5, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1], 56 | [2, 1, 5, 2, 5, 8, 2, 8, 11, 4, 8, 5, -1, -1, -1, -1], 57 | [10, 3, 11, 10, 1, 3, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1], 58 | [4, 9, 5, 0, 8, 1, 8, 10, 1, 8, 11, 10, -1, -1, -1, -1], 59 | [5, 4, 0, 5, 0, 11, 5, 11, 10, 11, 0, 3, -1, -1, -1, -1], 60 | [5, 4, 8, 5, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1], 61 | [9, 7, 8, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 62 | [9, 3, 0, 9, 5, 3, 5, 7, 3, -1, -1, -1, -1, -1, -1, -1], 63 | [0, 7, 8, 0, 1, 7, 1, 5, 7, -1, -1, -1, -1, -1, -1, -1], 64 | [1, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 65 | [9, 7, 8, 9, 5, 7, 10, 1, 2, -1, -1, -1, -1, -1, -1, -1], 66 | [10, 1, 2, 9, 5, 0, 5, 3, 0, 5, 7, 3, -1, -1, -1, -1], 67 | [8, 0, 2, 8, 2, 5, 8, 5, 7, 10, 5, 2, -1, -1, -1, -1], 68 | [2, 10, 5, 2, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1], 69 | [7, 9, 5, 7, 8, 9, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1], 70 | [9, 5, 7, 9, 7, 2, 9, 2, 0, 2, 7, 11, -1, -1, -1, -1], 71 | [2, 3, 11, 0, 1, 8, 1, 7, 8, 1, 5, 7, -1, -1, -1, -1], 72 | [11, 2, 1, 11, 1, 7, 7, 1, 5, -1, -1, -1, -1, -1, -1, -1], 73 | [9, 5, 8, 8, 5, 7, 10, 1, 3, 10, 3, 11, -1, -1, -1, -1], 74 | [5, 7, 0, 5, 0, 9, 7, 11, 0, 1, 0, 10, 11, 10, 0, -1], 75 | [11, 10, 0, 
11, 0, 3, 10, 5, 0, 8, 0, 7, 5, 7, 0, -1], 76 | [11, 10, 5, 7, 11, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 77 | [10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 78 | [0, 8, 3, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 79 | [9, 0, 1, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 80 | [1, 8, 3, 1, 9, 8, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1], 81 | [1, 6, 5, 2, 6, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 82 | [1, 6, 5, 1, 2, 6, 3, 0, 8, -1, -1, -1, -1, -1, -1, -1], 83 | [9, 6, 5, 9, 0, 6, 0, 2, 6, -1, -1, -1, -1, -1, -1, -1], 84 | [5, 9, 8, 5, 8, 2, 5, 2, 6, 3, 2, 8, -1, -1, -1, -1], 85 | [2, 3, 11, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 86 | [11, 0, 8, 11, 2, 0, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1], 87 | [0, 1, 9, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1], 88 | [5, 10, 6, 1, 9, 2, 9, 11, 2, 9, 8, 11, -1, -1, -1, -1], 89 | [6, 3, 11, 6, 5, 3, 5, 1, 3, -1, -1, -1, -1, -1, -1, -1], 90 | [0, 8, 11, 0, 11, 5, 0, 5, 1, 5, 11, 6, -1, -1, -1, -1], 91 | [3, 11, 6, 0, 3, 6, 0, 6, 5, 0, 5, 9, -1, -1, -1, -1], 92 | [6, 5, 9, 6, 9, 11, 11, 9, 8, -1, -1, -1, -1, -1, -1, -1], 93 | [5, 10, 6, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 94 | [4, 3, 0, 4, 7, 3, 6, 5, 10, -1, -1, -1, -1, -1, -1, -1], 95 | [1, 9, 0, 5, 10, 6, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1], 96 | [10, 6, 5, 1, 9, 7, 1, 7, 3, 7, 9, 4, -1, -1, -1, -1], 97 | [6, 1, 2, 6, 5, 1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1], 98 | [1, 2, 5, 5, 2, 6, 3, 0, 4, 3, 4, 7, -1, -1, -1, -1], 99 | [8, 4, 7, 9, 0, 5, 0, 6, 5, 0, 2, 6, -1, -1, -1, -1], 100 | [7, 3, 9, 7, 9, 4, 3, 2, 9, 5, 9, 6, 2, 6, 9, -1], 101 | [3, 11, 2, 7, 8, 4, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1], 102 | [5, 10, 6, 4, 7, 2, 4, 2, 0, 2, 7, 11, -1, -1, -1, -1], 103 | [0, 1, 9, 4, 7, 8, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1], 104 | [9, 2, 1, 9, 11, 2, 9, 4, 11, 7, 11, 4, 5, 10, 6, -1], 105 | [8, 4, 7, 3, 11, 5, 3, 5, 1, 5, 11, 6, -1, -1, -1, -1], 106 | [5, 1, 11, 5, 11, 6, 1, 0, 11, 7, 11, 4, 0, 4, 11, -1], 107 | 
[0, 5, 9, 0, 6, 5, 0, 3, 6, 11, 6, 3, 8, 4, 7, -1], 108 | [6, 5, 9, 6, 9, 11, 4, 7, 9, 7, 11, 9, -1, -1, -1, -1], 109 | [10, 4, 9, 6, 4, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 110 | [4, 10, 6, 4, 9, 10, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1], 111 | [10, 0, 1, 10, 6, 0, 6, 4, 0, -1, -1, -1, -1, -1, -1, -1], 112 | [8, 3, 1, 8, 1, 6, 8, 6, 4, 6, 1, 10, -1, -1, -1, -1], 113 | [1, 4, 9, 1, 2, 4, 2, 6, 4, -1, -1, -1, -1, -1, -1, -1], 114 | [3, 0, 8, 1, 2, 9, 2, 4, 9, 2, 6, 4, -1, -1, -1, -1], 115 | [0, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 116 | [8, 3, 2, 8, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1], 117 | [10, 4, 9, 10, 6, 4, 11, 2, 3, -1, -1, -1, -1, -1, -1, -1], 118 | [0, 8, 2, 2, 8, 11, 4, 9, 10, 4, 10, 6, -1, -1, -1, -1], 119 | [3, 11, 2, 0, 1, 6, 0, 6, 4, 6, 1, 10, -1, -1, -1, -1], 120 | [6, 4, 1, 6, 1, 10, 4, 8, 1, 2, 1, 11, 8, 11, 1, -1], 121 | [9, 6, 4, 9, 3, 6, 9, 1, 3, 11, 6, 3, -1, -1, -1, -1], 122 | [8, 11, 1, 8, 1, 0, 11, 6, 1, 9, 1, 4, 6, 4, 1, -1], 123 | [3, 11, 6, 3, 6, 0, 0, 6, 4, -1, -1, -1, -1, -1, -1, -1], 124 | [6, 4, 8, 11, 6, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 125 | [7, 10, 6, 7, 8, 10, 8, 9, 10, -1, -1, -1, -1, -1, -1, -1], 126 | [0, 7, 3, 0, 10, 7, 0, 9, 10, 6, 7, 10, -1, -1, -1, -1], 127 | [10, 6, 7, 1, 10, 7, 1, 7, 8, 1, 8, 0, -1, -1, -1, -1], 128 | [10, 6, 7, 10, 7, 1, 1, 7, 3, -1, -1, -1, -1, -1, -1, -1], 129 | [1, 2, 6, 1, 6, 8, 1, 8, 9, 8, 6, 7, -1, -1, -1, -1], 130 | [2, 6, 9, 2, 9, 1, 6, 7, 9, 0, 9, 3, 7, 3, 9, -1], 131 | [7, 8, 0, 7, 0, 6, 6, 0, 2, -1, -1, -1, -1, -1, -1, -1], 132 | [7, 3, 2, 6, 7, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 133 | [2, 3, 11, 10, 6, 8, 10, 8, 9, 8, 6, 7, -1, -1, -1, -1], 134 | [2, 0, 7, 2, 7, 11, 0, 9, 7, 6, 7, 10, 9, 10, 7, -1], 135 | [1, 8, 0, 1, 7, 8, 1, 10, 7, 6, 7, 10, 2, 3, 11, -1], 136 | [11, 2, 1, 11, 1, 7, 10, 6, 1, 6, 7, 1, -1, -1, -1, -1], 137 | [8, 9, 6, 8, 6, 7, 9, 1, 6, 11, 6, 3, 1, 3, 6, -1], 138 | [0, 9, 1, 11, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 
139 | [7, 8, 0, 7, 0, 6, 3, 11, 0, 11, 6, 0, -1, -1, -1, -1], 140 | [7, 11, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 141 | [7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 142 | [3, 0, 8, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 143 | [0, 1, 9, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 144 | [8, 1, 9, 8, 3, 1, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1], 145 | [10, 1, 2, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 146 | [1, 2, 10, 3, 0, 8, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1], 147 | [2, 9, 0, 2, 10, 9, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1], 148 | [6, 11, 7, 2, 10, 3, 10, 8, 3, 10, 9, 8, -1, -1, -1, -1], 149 | [7, 2, 3, 6, 2, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 150 | [7, 0, 8, 7, 6, 0, 6, 2, 0, -1, -1, -1, -1, -1, -1, -1], 151 | [2, 7, 6, 2, 3, 7, 0, 1, 9, -1, -1, -1, -1, -1, -1, -1], 152 | [1, 6, 2, 1, 8, 6, 1, 9, 8, 8, 7, 6, -1, -1, -1, -1], 153 | [10, 7, 6, 10, 1, 7, 1, 3, 7, -1, -1, -1, -1, -1, -1, -1], 154 | [10, 7, 6, 1, 7, 10, 1, 8, 7, 1, 0, 8, -1, -1, -1, -1], 155 | [0, 3, 7, 0, 7, 10, 0, 10, 9, 6, 10, 7, -1, -1, -1, -1], 156 | [7, 6, 10, 7, 10, 8, 8, 10, 9, -1, -1, -1, -1, -1, -1, -1], 157 | [6, 8, 4, 11, 8, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 158 | [3, 6, 11, 3, 0, 6, 0, 4, 6, -1, -1, -1, -1, -1, -1, -1], 159 | [8, 6, 11, 8, 4, 6, 9, 0, 1, -1, -1, -1, -1, -1, -1, -1], 160 | [9, 4, 6, 9, 6, 3, 9, 3, 1, 11, 3, 6, -1, -1, -1, -1], 161 | [6, 8, 4, 6, 11, 8, 2, 10, 1, -1, -1, -1, -1, -1, -1, -1], 162 | [1, 2, 10, 3, 0, 11, 0, 6, 11, 0, 4, 6, -1, -1, -1, -1], 163 | [4, 11, 8, 4, 6, 11, 0, 2, 9, 2, 10, 9, -1, -1, -1, -1], 164 | [10, 9, 3, 10, 3, 2, 9, 4, 3, 11, 3, 6, 4, 6, 3, -1], 165 | [8, 2, 3, 8, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1], 166 | [0, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 167 | [1, 9, 0, 2, 3, 4, 2, 4, 6, 4, 3, 8, -1, -1, -1, -1], 168 | [1, 9, 4, 1, 4, 2, 2, 4, 6, -1, -1, -1, -1, -1, -1, -1], 169 | [8, 1, 3, 8, 6, 1, 8, 4, 6, 6, 10, 1, -1, -1, -1, -1], 170 | 
[10, 1, 0, 10, 0, 6, 6, 0, 4, -1, -1, -1, -1, -1, -1, -1], 171 | [4, 6, 3, 4, 3, 8, 6, 10, 3, 0, 3, 9, 10, 9, 3, -1], 172 | [10, 9, 4, 6, 10, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 173 | [4, 9, 5, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 174 | [0, 8, 3, 4, 9, 5, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1], 175 | [5, 0, 1, 5, 4, 0, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1], 176 | [11, 7, 6, 8, 3, 4, 3, 5, 4, 3, 1, 5, -1, -1, -1, -1], 177 | [9, 5, 4, 10, 1, 2, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1], 178 | [6, 11, 7, 1, 2, 10, 0, 8, 3, 4, 9, 5, -1, -1, -1, -1], 179 | [7, 6, 11, 5, 4, 10, 4, 2, 10, 4, 0, 2, -1, -1, -1, -1], 180 | [3, 4, 8, 3, 5, 4, 3, 2, 5, 10, 5, 2, 11, 7, 6, -1], 181 | [7, 2, 3, 7, 6, 2, 5, 4, 9, -1, -1, -1, -1, -1, -1, -1], 182 | [9, 5, 4, 0, 8, 6, 0, 6, 2, 6, 8, 7, -1, -1, -1, -1], 183 | [3, 6, 2, 3, 7, 6, 1, 5, 0, 5, 4, 0, -1, -1, -1, -1], 184 | [6, 2, 8, 6, 8, 7, 2, 1, 8, 4, 8, 5, 1, 5, 8, -1], 185 | [9, 5, 4, 10, 1, 6, 1, 7, 6, 1, 3, 7, -1, -1, -1, -1], 186 | [1, 6, 10, 1, 7, 6, 1, 0, 7, 8, 7, 0, 9, 5, 4, -1], 187 | [4, 0, 10, 4, 10, 5, 0, 3, 10, 6, 10, 7, 3, 7, 10, -1], 188 | [7, 6, 10, 7, 10, 8, 5, 4, 10, 4, 8, 10, -1, -1, -1, -1], 189 | [6, 9, 5, 6, 11, 9, 11, 8, 9, -1, -1, -1, -1, -1, -1, -1], 190 | [3, 6, 11, 0, 6, 3, 0, 5, 6, 0, 9, 5, -1, -1, -1, -1], 191 | [0, 11, 8, 0, 5, 11, 0, 1, 5, 5, 6, 11, -1, -1, -1, -1], 192 | [6, 11, 3, 6, 3, 5, 5, 3, 1, -1, -1, -1, -1, -1, -1, -1], 193 | [1, 2, 10, 9, 5, 11, 9, 11, 8, 11, 5, 6, -1, -1, -1, -1], 194 | [0, 11, 3, 0, 6, 11, 0, 9, 6, 5, 6, 9, 1, 2, 10, -1], 195 | [11, 8, 5, 11, 5, 6, 8, 0, 5, 10, 5, 2, 0, 2, 5, -1], 196 | [6, 11, 3, 6, 3, 5, 2, 10, 3, 10, 5, 3, -1, -1, -1, -1], 197 | [5, 8, 9, 5, 2, 8, 5, 6, 2, 3, 8, 2, -1, -1, -1, -1], 198 | [9, 5, 6, 9, 6, 0, 0, 6, 2, -1, -1, -1, -1, -1, -1, -1], 199 | [1, 5, 8, 1, 8, 0, 5, 6, 8, 3, 8, 2, 6, 2, 8, -1], 200 | [1, 5, 6, 2, 1, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 201 | [1, 3, 6, 1, 6, 10, 3, 8, 6, 5, 6, 9, 8, 9, 6, -1], 202 | [10, 1, 0, 
10, 0, 6, 9, 5, 0, 5, 6, 0, -1, -1, -1, -1], 203 | [0, 3, 8, 5, 6, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 204 | [10, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 205 | [11, 5, 10, 7, 5, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 206 | [11, 5, 10, 11, 7, 5, 8, 3, 0, -1, -1, -1, -1, -1, -1, -1], 207 | [5, 11, 7, 5, 10, 11, 1, 9, 0, -1, -1, -1, -1, -1, -1, -1], 208 | [10, 7, 5, 10, 11, 7, 9, 8, 1, 8, 3, 1, -1, -1, -1, -1], 209 | [11, 1, 2, 11, 7, 1, 7, 5, 1, -1, -1, -1, -1, -1, -1, -1], 210 | [0, 8, 3, 1, 2, 7, 1, 7, 5, 7, 2, 11, -1, -1, -1, -1], 211 | [9, 7, 5, 9, 2, 7, 9, 0, 2, 2, 11, 7, -1, -1, -1, -1], 212 | [7, 5, 2, 7, 2, 11, 5, 9, 2, 3, 2, 8, 9, 8, 2, -1], 213 | [2, 5, 10, 2, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1], 214 | [8, 2, 0, 8, 5, 2, 8, 7, 5, 10, 2, 5, -1, -1, -1, -1], 215 | [9, 0, 1, 5, 10, 3, 5, 3, 7, 3, 10, 2, -1, -1, -1, -1], 216 | [9, 8, 2, 9, 2, 1, 8, 7, 2, 10, 2, 5, 7, 5, 2, -1], 217 | [1, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 218 | [0, 8, 7, 0, 7, 1, 1, 7, 5, -1, -1, -1, -1, -1, -1, -1], 219 | [9, 0, 3, 9, 3, 5, 5, 3, 7, -1, -1, -1, -1, -1, -1, -1], 220 | [9, 8, 7, 5, 9, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 221 | [5, 8, 4, 5, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1], 222 | [5, 0, 4, 5, 11, 0, 5, 10, 11, 11, 3, 0, -1, -1, -1, -1], 223 | [0, 1, 9, 8, 4, 10, 8, 10, 11, 10, 4, 5, -1, -1, -1, -1], 224 | [10, 11, 4, 10, 4, 5, 11, 3, 4, 9, 4, 1, 3, 1, 4, -1], 225 | [2, 5, 1, 2, 8, 5, 2, 11, 8, 4, 5, 8, -1, -1, -1, -1], 226 | [0, 4, 11, 0, 11, 3, 4, 5, 11, 2, 11, 1, 5, 1, 11, -1], 227 | [0, 2, 5, 0, 5, 9, 2, 11, 5, 4, 5, 8, 11, 8, 5, -1], 228 | [9, 4, 5, 2, 11, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 229 | [2, 5, 10, 3, 5, 2, 3, 4, 5, 3, 8, 4, -1, -1, -1, -1], 230 | [5, 10, 2, 5, 2, 4, 4, 2, 0, -1, -1, -1, -1, -1, -1, -1], 231 | [3, 10, 2, 3, 5, 10, 3, 8, 5, 4, 5, 8, 0, 1, 9, -1], 232 | [5, 10, 2, 5, 2, 4, 1, 9, 2, 9, 4, 2, -1, -1, -1, -1], 233 | [8, 4, 5, 8, 5, 3, 3, 5, 1, -1, -1, -1, -1, -1, 
-1, -1], 234 | [0, 4, 5, 1, 0, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 235 | [8, 4, 5, 8, 5, 3, 9, 0, 5, 0, 3, 5, -1, -1, -1, -1], 236 | [9, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 237 | [4, 11, 7, 4, 9, 11, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1], 238 | [0, 8, 3, 4, 9, 7, 9, 11, 7, 9, 10, 11, -1, -1, -1, -1], 239 | [1, 10, 11, 1, 11, 4, 1, 4, 0, 7, 4, 11, -1, -1, -1, -1], 240 | [3, 1, 4, 3, 4, 8, 1, 10, 4, 7, 4, 11, 10, 11, 4, -1], 241 | [4, 11, 7, 9, 11, 4, 9, 2, 11, 9, 1, 2, -1, -1, -1, -1], 242 | [9, 7, 4, 9, 11, 7, 9, 1, 11, 2, 11, 1, 0, 8, 3, -1], 243 | [11, 7, 4, 11, 4, 2, 2, 4, 0, -1, -1, -1, -1, -1, -1, -1], 244 | [11, 7, 4, 11, 4, 2, 8, 3, 4, 3, 2, 4, -1, -1, -1, -1], 245 | [2, 9, 10, 2, 7, 9, 2, 3, 7, 7, 4, 9, -1, -1, -1, -1], 246 | [9, 10, 7, 9, 7, 4, 10, 2, 7, 8, 7, 0, 2, 0, 7, -1], 247 | [3, 7, 10, 3, 10, 2, 7, 4, 10, 1, 10, 0, 4, 0, 10, -1], 248 | [1, 10, 2, 8, 7, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 249 | [4, 9, 1, 4, 1, 7, 7, 1, 3, -1, -1, -1, -1, -1, -1, -1], 250 | [4, 9, 1, 4, 1, 7, 0, 8, 1, 8, 7, 1, -1, -1, -1, -1], 251 | [4, 0, 3, 7, 4, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 252 | [4, 8, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 253 | [9, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 254 | [3, 0, 9, 3, 9, 11, 11, 9, 10, -1, -1, -1, -1, -1, -1, -1], 255 | [0, 1, 10, 0, 10, 8, 8, 10, 11, -1, -1, -1, -1, -1, -1, -1], 256 | [3, 1, 10, 11, 3, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 257 | [1, 2, 11, 1, 11, 9, 9, 11, 8, -1, -1, -1, -1, -1, -1, -1], 258 | [3, 0, 9, 3, 9, 11, 1, 2, 9, 2, 11, 9, -1, -1, -1, -1], 259 | [0, 2, 11, 8, 0, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 260 | [3, 2, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 261 | [2, 3, 8, 2, 8, 10, 10, 8, 9, -1, -1, -1, -1, -1, -1, -1], 262 | [9, 10, 2, 0, 9, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 263 | [2, 3, 8, 2, 8, 10, 0, 1, 8, 1, 10, 8, -1, -1, -1, -1], 264 | [1, 10, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
-1, -1, -1, -1], 265 | [1, 3, 8, 9, 1, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 266 | [0, 9, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 267 | [0, 3, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 268 | [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]] 269 | 270 | MC_TRI_TABLE_NP = np.array(MC_TRI_TABLE_RAW) 271 | 272 | 273 | # an enumeration of the edges corresponding to the table above 274 | EDGE_VERTS_RAW = [ 275 | [0 , 1], # 0 276 | [1 , 2], # 1 277 | [2 , 3], # 2 278 | [3 , 0], # 3 279 | [4 , 5], # 4 280 | [5 , 6], # 5 281 | [6 , 7], # 6 282 | [7 , 4], # 7 283 | [0 , 4], # 8 284 | [1 , 5], # 9 285 | [2 , 6], # 10 286 | [3 , 7], # 11 287 | ] 288 | 289 | EDGE_VERTS_NP = np.array(EDGE_VERTS_RAW) 290 | 291 | 292 | # which vert is which, matches the indexing above 293 | VERT_LOGICAL_COORDS_RAW = [ 294 | [0,1,0], #0 (2) 295 | [1,1,0], #1 (6) 296 | [1,0,0], #2 (4) 297 | [0,0,0], #3 (0) 298 | [0,1,1], #4 (3) 299 | [1,1,1], #5 (7) 300 | [1,0,1], #6 (5) 301 | [0,0,1], #7 (1) 302 | ] 303 | 304 | VERT_LOGICAL_COORDS_NP = np.array(VERT_LOGICAL_COORDS_RAW, dtype=bool) 305 | 306 | def get_mc_data(): 307 | tri_table = jnp.array(MC_TRI_TABLE_NP) 308 | edge_verts = jnp.array(EDGE_VERTS_NP) 309 | vert_logical_coords = jnp.array(VERT_LOGICAL_COORDS_NP) 310 | return tri_table, edge_verts, vert_logical_coords 311 | 312 | 313 | # Returns a list of triangles for this cell, as a [5,3,3] array of vertex positions, and a [5,] mask of which triangles are valid.
314 | @partial(jax.jit, static_argnames=("func")) 315 | def extract_triangles_from_cell(func, params, mc_data, cell_lower, cell_upper, vert_vals=None): 316 | 317 | tri_table, edge_verts, vert_logical_coords = mc_data 318 | 319 | # expand out positions for the 8 cube vertices 320 | vert_pos = jnp.where(vert_logical_coords, cell_upper[None,:], cell_lower[None,:]) 321 | 322 | # evaluate the function at each of the vertex locations 323 | # (or use precomputed values if we have them) 324 | if vert_vals is None: 325 | vert_vals = jax.vmap(partial(func,params))(vert_pos) 326 | 327 | # compute the crossing locations on each edge, and whether there is a crossing there 328 | def check_edge(inds): 329 | # gather values 330 | indA = inds[0] 331 | indB = inds[1] 332 | valA = vert_vals[indA] 333 | valB = vert_vals[indB] 334 | posA = vert_pos[indA,:] 335 | posB = vert_pos[indB,:] 336 | 337 | # compute the crossing 338 | t_cross = -valA / (valB - valA) 339 | t_cross = jnp.nan_to_num(t_cross) 340 | t_cross = jnp.clip(t_cross, a_min=0, a_max=1) 341 | cross_loc = (1. 
- t_cross) * posA + t_cross * posB 342 | has_cross = jnp.sign(valA) != jnp.sign(valB) 343 | return cross_loc, has_cross 344 | 345 | # compute all crossing values 346 | edge_cross_loc, edge_has_cross = jax.vmap(check_edge)(edge_verts) 347 | 348 | # enumerate which case we are in 349 | case_bits = jnp.power(2, jnp.arange(8)) 350 | case_id = jnp.sum((vert_vals < 0) * case_bits) 351 | 352 | # get the triangles for this case 353 | case_tris = jnp.reshape(tri_table[case_id,:15], (-1,3)) 354 | 355 | # for each triangle, gather the vertex positions 356 | def get_tri_pos(tri): 357 | is_valid = tri[0] != -1 358 | ind0 = jnp.clip(tri[0], a_min=0) 359 | ind1 = jnp.clip(tri[1], a_min=0) 360 | ind2 = jnp.clip(tri[2], a_min=0) 361 | return jnp.stack((edge_cross_loc[ind0,:], edge_cross_loc[ind1,:], edge_cross_loc[ind2,:]), axis=0), is_valid 362 | all_tri_pos, tri_is_valid = jax.vmap(get_tri_pos)(case_tris) 363 | 364 | return all_tri_pos, tri_is_valid 365 | 366 | @partial(jax.jit, static_argnames=("func", "n_sub_depth", "batch_eval_size")) 367 | def extract_triangles_from_subcells(func, params, mc_data, n_sub_depth, cell_lower, cell_upper, batch_eval_size=4096): 368 | 369 | tri_table, edge_verts, vert_logical_coords = mc_data 370 | 371 | # construct the grid of subcells 372 | side_n_sub_cells = 2**n_sub_depth 373 | side_n_pts = (1+side_n_sub_cells) 374 | side_coords0 = jnp.linspace(cell_lower[0], cell_upper[0], num=side_n_pts) 375 | side_coords1 = jnp.linspace(cell_lower[1], cell_upper[1], num=side_n_pts) 376 | side_coords2 = jnp.linspace(cell_lower[2], cell_upper[2], num=side_n_pts) 377 | grid_coords0, grid_coords1, grid_coords2 = jnp.meshgrid(side_coords0, side_coords1, side_coords2, indexing='ij') 378 | grid_coords = jnp.stack((grid_coords0, grid_coords1, grid_coords2), axis=-1) 379 | 380 | # evaluate the function 381 | flat_coords = jnp.reshape(grid_coords, (-1,3)) 382 | if flat_coords.shape[0] > batch_eval_size: 383 | # for very large sets, break in to batches 384 | nb = 
flat_coords.shape[0] // batch_eval_size 385 | stragglers = flat_coords[nb*batch_eval_size:,:] 386 | batched_flat_coords = jnp.reshape(flat_coords[:nb*batch_eval_size,:], (-1, batch_eval_size, 3)) 387 | vfunc = jax.vmap(partial(func, params)) 388 | batched_vals = jax.lax.map(vfunc, batched_flat_coords) 389 | batched_vals = jnp.reshape(batched_vals, (-1,)) 390 | 391 | # evaluate any stragglers in the very last batch 392 | straggler_vals = jax.vmap(partial(func,params))(stragglers) 393 | 394 | flat_vals = jnp.concatenate((batched_vals, straggler_vals)) 395 | else: 396 | flat_vals = jax.vmap(partial(func,params))(flat_coords) 397 | grid_vals = jnp.reshape(flat_vals, (side_n_pts, side_n_pts, side_n_pts)) 398 | 399 | # logical grid of subcell inds 400 | side_inds = jnp.arange(side_n_sub_cells) 401 | grid_inds0, grid_inds1, grid_inds2 = jnp.meshgrid(side_inds, side_inds, side_inds, indexing='ij') 402 | grid_inds = jnp.stack((grid_inds0, grid_inds1, grid_inds2), axis=-1) 403 | subcell_inds = jnp.reshape(grid_inds, (-1,3)) 404 | 405 | # compute the extents of each cell 406 | subcell_delta = (cell_upper - cell_lower) / side_n_sub_cells 407 | subcell_lower = cell_lower[None,:] + subcell_inds * subcell_delta[None,:] 408 | subcell_upper = subcell_lower + subcell_delta[None,:] 409 | 410 | # fetch the function values for each cell 411 | subcell_vert_inds = subcell_inds[:,None,:] + vert_logical_coords[None,:,:] 412 | subcell_vert_vals = grid_vals.at[subcell_vert_inds[:,:,0], subcell_vert_inds[:,:,1], subcell_vert_inds[:,:,2]].get() 413 | 414 | # apply the extraction routine to each subcell 415 | subcell_tri_pos, subcell_tri_is_valid = jax.vmap(partial(extract_triangles_from_cell, func, params, mc_data))(subcell_lower, subcell_upper, subcell_vert_vals) 416 | 417 | # combine all results 418 | subcell_tri_pos = jnp.reshape(subcell_tri_pos, (-1,3,3)) 419 | subcell_tri_is_valid = jnp.reshape(subcell_tri_is_valid, (-1,)) 420 | 421 | return subcell_tri_pos, subcell_tri_is_valid 422 | 
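The per-cell extraction in `extract_cell.py` above hinges on two small ingredients: the clamped linear interpolation in `check_edge`, which places a zero crossing on each cube edge, and the sign bitmask that packs the 8 vertex signs into an index selecting one of the 256 rows of `MC_TRI_TABLE_RAW`. A minimal pure-Python sketch of both (the helper names `edge_crossing` and `mc_case_id` are illustrative only, not part of this repo):

```python
def edge_crossing(valA, valB, posA, posB):
    """Linearly interpolate the zero crossing on one edge, clamped to [0,1].

    Mirrors check_edge(): t = -valA / (valB - valA), with the degenerate
    valA == valB case mapped to 0 (the jnp.nan_to_num guard in the source).
    """
    denom = valB - valA
    t = 0.0 if denom == 0.0 else -valA / denom
    t = min(max(t, 0.0), 1.0)
    return tuple((1.0 - t) * a + t * b for a, b in zip(posA, posB))

def mc_case_id(vert_vals):
    """Pack the 8 inside/outside signs into a 0..255 marching-cubes case index.

    Mirrors jnp.sum((vert_vals < 0) * jnp.power(2, jnp.arange(8))).
    """
    return sum(1 << i for i, v in enumerate(vert_vals) if v < 0)

# A symmetric sign change crosses at the edge midpoint:
print(edge_crossing(-1.0, 1.0, (0.0, 0.0, 0.0), (1.0, 0.0, 0.0)))  # (0.5, 0.0, 0.0)
# Only vertex 0 inside -> case 1; all inside -> case 255 (an all-same-sign
# case indexes an all -1 table row, i.e. no triangles are emitted).
print(mc_case_id([-1, 1, 1, 1, 1, 1, 1, 1]), mc_case_id([-1] * 8))
```

The table row for the chosen case then lists up to 5 triangles as triples of these edge indices, which is why `extract_triangles_from_cell` gathers `edge_cross_loc` per triangle and returns a `[5,]` validity mask.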
-------------------------------------------------------------------------------- /src/geometry.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | import numpy as np 5 | 6 | def norm(x): 7 | return jnp.linalg.norm(x, axis=-1) 8 | 9 | def norm2(x): 10 | return jnp.inner(x,x) 11 | 12 | def normalize(x): 13 | return x / norm(x) 14 | 15 | def orthogonal_dir(x, remove_dir): 16 | # take a vector x, remove any component in the direction of vector remove_dir, and return unit x 17 | remove_dir = normalize(remove_dir) 18 | x = x - jnp.dot(x, remove_dir) * remove_dir 19 | return normalize(x) 20 | 21 | def dot(x,y): 22 | return jnp.sum(x*y, axis=-1) 23 | 24 | def normalize_positions(pos, method='bbox'): 25 | # center and unit-scale positions in to the [-1,1] cube 26 | 27 | if method == 'mean': 28 | # center using the average point position 29 | pos = pos - jnp.mean(pos, axis=-2, keepdims=True) 30 | elif method == 'bbox': 31 | # center via the middle of the axis-aligned bounding box 32 | bbox_min = jnp.min(pos, axis=-2) 33 | bbox_max = jnp.max(pos, axis=-2) 34 | center = (bbox_max + bbox_min) / 2. 35 | pos -= center[None,:] 36 | else: 37 | raise ValueError("unrecognized method") 38 | 39 | scale = jnp.max(norm(pos), axis=-1, keepdims=True)[:,None] 40 | pos = pos / scale 41 | return pos 42 | 43 | def sample_mesh_sdf(V, F, n_sample, surface_frac=0.5, surface_perturb_sigma=0.01, ambient_expand=1.25): 44 | import igl 45 | ''' 46 | NOTE: Assumes input is scaled to lie in [-1,1] cube 47 | NOTE: RNG is handled internally, in part by an external library (libigl). Has none of the usual JAX RNG properties, may or may not yield same results, etc. 
48 | ''' 49 | 50 | n_surface = int(n_sample * surface_frac) 51 | n_ambient = n_sample - n_surface 52 | 53 | # Compute a bounding box for the mesh 54 | bbox_min = np.array([-1,-1,-1]) 55 | bbox_max = np.array([1,1,1]) 56 | center = 0.5*(bbox_max + bbox_min) 57 | 58 | # Sample ambient points 59 | key = jax.random.PRNGKey(0) 60 | key, subkey = jax.random.split(key) 61 | 62 | # Q_ambient = jax.random.normal(subkey, (n_ambient, 3)) * ambient_sigma 63 | Q_ambient = jax.random.uniform(subkey, (n_ambient, 3), minval=-ambient_expand, maxval=ambient_expand) 64 | 65 | # Sample surface points 66 | sample_b, sample_f = igl.random_points_on_mesh(n_surface, np.array(V), np.array(F)) 67 | face_verts = V[F[sample_f], :] 68 | raw_samples = np.sum(sample_b[...,np.newaxis] * face_verts, axis=1) 69 | raw_samples = jnp.array(raw_samples) 70 | 71 | # add noise to surface points 72 | key, subkey = jax.random.split(key) 73 | offsets = jax.random.normal(subkey, (n_surface, 3)) * surface_perturb_sigma 74 | Q_surface = raw_samples + offsets 75 | 76 | # Combine and shuffle 77 | Q = np.vstack((Q_ambient, Q_surface)) 78 | key, subkey = jax.random.split(key) 79 | Q = jax.random.permutation(subkey, Q, axis=0) 80 | 81 | # Get SDF value via distance & winding number 82 | sdf_vals, _, closest = igl.signed_distance(np.array(Q), np.array(V), np.array(F)) 83 | sdf_vals = jnp.array(sdf_vals) 84 | 85 | return Q, sdf_vals 86 | 87 | 88 | def sample_mesh_importance(V, F, n_sample, n_sample_full_mult=10., beta=20., ambient_range=1.25): 89 | import igl 90 | 91 | V = np.array(V) 92 | F = np.array(F) 93 | n_sample_full = int(n_sample * n_sample_full_mult) 94 | 95 | # Sample ambient points 96 | Q_ambient = np.random.uniform(size=(n_sample_full, 3), low=-ambient_range, high=ambient_range) 97 | 98 | # Assign weights 99 | dist_sq, _, _ = igl.point_mesh_squared_distance(Q_ambient, np.array(V), np.array(F)) 100 | weight = np.exp(-beta * np.sqrt(dist_sq)) 101 | weight = weight / np.sum(weight) 102 | 103 | # Sample 104 
| samp_inds = np.random.choice(n_sample_full, size=n_sample, p=weight) 105 | Q = Q_ambient[samp_inds,:] 106 | 107 | # Get SDF value via distance & winding number 108 | sdf_vals, _, closest = igl.signed_distance(Q, V, F) 109 | sdf_vals = jnp.array(sdf_vals) 110 | Q = jnp.array(Q) 111 | 112 | return Q, sdf_vals 113 | -------------------------------------------------------------------------------- /src/implicit_function.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import numpy as np 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | import utils 9 | 10 | # "enums" integer codes denoting the sign of the implicit function within a region 11 | SIGN_UNKNOWN = 0 # could be anything 12 | SIGN_POSITIVE = 1 # definitely positive throughout 13 | SIGN_NEGATIVE = 2 # definitely negative throughout 14 | 15 | class ImplicitFunction: 16 | 17 | # `eval` and `affine_eval` are functions that can be called 18 | def __init__(self, style): 19 | 20 | if style not in ['classify-only', 'classify-and-distance']: 21 | raise ValueError("unrecognized style") 22 | 23 | self.style = style 24 | 25 | def __call__(self, params, x): 26 | raise RuntimeError("ImplicitFunction does not implement a __call__() operator.
Subclasses must provide an implementation if it is to be used.") 27 | 28 | def classify_box(self, params, box_lower, box_upper, offset=0.): 29 | ''' 30 | Determine the sign of the function within a box (reports one of SIGN_UNKNOWN, etc) 31 | ''' 32 | 33 | # delegate to the more general version 34 | center = 0.5 * (box_lower + box_upper) 35 | pos_vec = box_upper - center 36 | vecs = jnp.diag(pos_vec) 37 | return self.classify_general_box(params, center, vecs, offset=offset) 38 | 39 | # General version for non-axis-aligned boxes 40 | def classify_general_box(self, params, box_center, box_vecs, offset=0.): 41 | ''' 42 | Determine the sign of the function within a general box (reports one of SIGN_UNKNOWN, etc) 43 | ''' 44 | 45 | raise RuntimeError("ImplicitFunction does not implement classify_general_box(). Subclasses must provide an implementation if it is to be used.") 46 | 47 | 48 | def min_distance_to_zero(self, params, box_center, box_axis_vec): 49 | ''' 50 | Computes a lower bound on the distance to 0 from the center of the box defined by `box_center` and `box_axis_vec`. The result is signed: a positive value means the function at the center point is positive, and likewise for negative. 51 | 52 | The query is evaluated on the axis-aligned range defined by the nonnegative values `box_axis_vec`. The min(box_axis_vec) is the largest-magnitude value which can ever be returned. 53 | 54 | If `box_axis_vec` is `None`, it will be treated as the infinite domain. (Though some subclasses may not support this.) 55 | ''' 56 | 57 | raise RuntimeError("ImplicitFunction does not implement min_distance_to_zero(). Subclasses must provide an implementation if it is to be used.") 58 | 59 | 60 | def min_distance_to_zero_in_direction(self, params, source_point, bound_vec, source_range=None, return_source_value=False): 61 | ''' 62 | Computes a lower bound on the distance to 0 from `source_point` in the direction `bound_vec`.
The query is evaluated on the range `[source_point, source_point+bound_vec]`, and the magnitude of `bound_vec` is the largest-magnitude value which can be returned. 63 | 64 | Optionally, `source_range` is a `(v,d)` array of vectors defining a general box in space over which to evaluate the query. These vectors must be orthogonal to `bound_vec`. The result is then a minimum over all direction vectors in that prism. 65 | 66 | Many methods incidentally compute the value of the function at the source as a part of evaluating this routine. If `return_source_value=True`, the return will be a tuple `value, distance` giving the value as well. 67 | ''' 68 | 69 | 70 | raise RuntimeError("ImplicitFunction does not implement min_distance_to_zero_in_direction(). Subclasses must provide an implementation if it is to be used.") 71 | 72 | 73 | -------------------------------------------------------------------------------- /src/implicit_mlp_utils.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import numpy as np 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | import utils 9 | import mlp, sdf, affine, slope_interval 10 | 11 | 12 | def generate_implicit_from_file(input_path, mode, **kwargs): 13 | 14 | ## Load the file 15 | if input_path.endswith(".npz"): 16 | params = mlp.load(input_path) 17 | else: 18 | raise ValueError("unrecognized filetype") 19 | 20 | 21 | # `params` is now populated 22 | 23 | # Construct an `ImplicitFunction` object ready to do the appropriate kind of evaluation 24 | if mode == 'sdf': 25 | implicit_func = mlp.func_from_spec(mode='default') 26 | if 'sdf_lipschitz' in kwargs: 27 | lipschitz_bound = kwargs['sdf_lipschitz'] 28 | else: 29 | lipschitz_bound = 1.
30 | return sdf.WeakSDFImplicitFunction(implicit_func, lipschitz_bound=lipschitz_bound), params 31 | 32 | elif mode == 'interval': 33 | implicit_func = mlp.func_from_spec(mode='affine') 34 | affine_ctx = affine.AffineContext('interval') 35 | return affine.AffineImplicitFunction(implicit_func, affine_ctx), params 36 | 37 | elif mode == 'affine_fixed': 38 | implicit_func = mlp.func_from_spec(mode='affine') 39 | affine_ctx = affine.AffineContext('affine_fixed') 40 | return affine.AffineImplicitFunction(implicit_func, affine_ctx), params 41 | 42 | elif mode == 'affine_truncate': 43 | implicit_func = mlp.func_from_spec(mode='affine') 44 | affine_ctx = affine.AffineContext('affine_truncate', 45 | truncate_count=kwargs['affine_n_truncate'], truncate_policy=kwargs['affine_truncate_policy']) 46 | return affine.AffineImplicitFunction(implicit_func, affine_ctx), params 47 | 48 | elif mode == 'affine_append': 49 | implicit_func = mlp.func_from_spec(mode='affine') 50 | affine_ctx = affine.AffineContext('affine_append', 51 | n_append=kwargs['affine_n_append']) 52 | return affine.AffineImplicitFunction(implicit_func, affine_ctx), params 53 | 54 | elif mode == 'affine_all': 55 | implicit_func = mlp.func_from_spec(mode='affine') 56 | affine_ctx = affine.AffineContext('affine_all') 57 | return affine.AffineImplicitFunction(implicit_func, affine_ctx), params 58 | 59 | elif mode == 'slope_interval': 60 | implicit_func = mlp.func_from_spec(mode='slope_interval') 61 | return slope_interval.SlopeIntervalImplicitFunction(implicit_func), params 62 | 63 | else: 64 | raise RuntimeError("unrecognized mode") 65 | -------------------------------------------------------------------------------- /src/kd_tree.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | from functools import partial 5 | import math 6 | 7 | import numpy as np 8 | 9 | import utils 10 | from bucketing import * 11 | import implicit_function 12 | 
from implicit_function import SIGN_UNKNOWN, SIGN_POSITIVE, SIGN_NEGATIVE 13 | import extract_cell 14 | import geometry 15 | 16 | INVALID_IND = 2**30 17 | 18 | 19 | @partial(jax.jit, static_argnames=("func","continue_splitting"), donate_argnums=(7,8,9,10)) 20 | def construct_uniform_unknown_levelset_tree_iter( 21 | func, params, continue_splitting, 22 | node_valid, node_lower, node_upper, 23 | ib, out_valid, out_lower, out_upper, out_n_valid, 24 | finished_interior_lower, finished_interior_upper, N_finished_interior, 25 | finished_exterior_lower, finished_exterior_upper, N_finished_exterior, 26 | offset=0. 27 | ): 28 | 29 | N_in = node_lower.shape[0] 30 | d = node_lower.shape[-1] 31 | 32 | def eval_one_node(lower, upper): 33 | 34 | # perform an affine evaluation 35 | node_type = func.classify_box(params, lower, upper, offset=offset) 36 | 37 | # use the largest length along any dimension as the split policy 38 | worst_dim = jnp.argmax(upper-lower, axis=-1) 39 | 40 | return node_type, worst_dim 41 | 42 | # evaluate the function inside nodes 43 | node_types, node_split_dim = jax.vmap(eval_one_node)(node_lower, node_upper) 44 | 45 | # if requested, write out interior nodes 46 | if finished_interior_lower is not None: 47 | out_mask = jnp.logical_and(node_valid, node_types == SIGN_NEGATIVE) 48 | out_inds = utils.enumerate_mask(out_mask) + N_finished_interior 49 | finished_interior_lower = finished_interior_lower.at[out_inds,:].set(node_lower, mode='drop') 50 | finished_interior_upper = finished_interior_upper.at[out_inds,:].set(node_upper, mode='drop') 51 | N_finished_interior += jnp.sum(out_mask) 52 | 53 | # if requested, write out exterior nodes 54 | if finished_exterior_lower is not None: 55 | out_mask = jnp.logical_and(node_valid, node_types == SIGN_POSITIVE) 56 | out_inds = utils.enumerate_mask(out_mask) + N_finished_exterior 57 | finished_exterior_lower = finished_exterior_lower.at[out_inds,:].set(node_lower, mode='drop') 58 | finished_exterior_upper = 
finished_exterior_upper.at[out_inds,:].set(node_upper, mode='drop') 59 | N_finished_exterior += jnp.sum(out_mask) 60 | 61 | # split the unknown nodes to children 62 | # (if continue_splitting is False this will just not create any children at all) 63 | split_mask = utils.logical_and_all([node_valid, node_types == SIGN_UNKNOWN]) 64 | N_new = jnp.sum(split_mask) # each split leads to two children (for a total of 2*N_new) 65 | 66 | ## now actually build the child nodes 67 | if continue_splitting: 68 | 69 | # extents of the new child nodes along each split dimension 70 | new_lower = node_lower 71 | new_upper = node_upper 72 | new_mid = 0.5 * (new_lower + new_upper) 73 | new_coord_mask = jnp.arange(3)[None,:] == node_split_dim[:,None] 74 | newA_lower = new_lower 75 | newA_upper = jnp.where(new_coord_mask, new_mid, new_upper) 76 | newB_lower = jnp.where(new_coord_mask, new_mid, new_lower) 77 | newB_upper = new_upper 78 | 79 | # concatenate the new children to form output arrays 80 | node_valid = jnp.concatenate((split_mask, split_mask)) 81 | node_lower = jnp.concatenate((newA_lower, newB_lower)) 82 | node_upper = jnp.concatenate((newA_upper, newB_upper)) 83 | new_N_valid = 2*N_new 84 | outL = out_valid.shape[1] 85 | 86 | else: 87 | node_valid = jnp.logical_and(node_valid, node_types == SIGN_UNKNOWN) 88 | new_N_valid = jnp.sum(node_valid) 89 | outL = node_valid.shape[0] 90 | 91 | # write the result into arrays 92 | # utils.printarr(out_valid, node_valid, out_lower, node_lower, out_upper, node_upper) 93 | out_valid = out_valid.at[ib,:outL].set(node_valid) 94 | out_lower = out_lower.at[ib,:outL,:].set(node_lower) 95 | out_upper = out_upper.at[ib,:outL,:].set(node_upper) 96 | out_n_valid = out_n_valid + new_N_valid 97 | 98 | return out_valid, out_lower, out_upper, out_n_valid, \ 99 | finished_interior_lower, finished_interior_upper, N_finished_interior, \ 100 | finished_exterior_lower, finished_exterior_upper, N_finished_exterior 101 | 102 | 103 | def
construct_uniform_unknown_levelset_tree(func, params, lower, upper, node_terminate_thresh=None, split_depth=None, compress_after=False, with_children=False, with_interior_nodes=False, with_exterior_nodes=False, offset=0., batch_process_size=2048): 104 | 105 | # Validate input 106 | # ASSUMPTION: all of our bucket sizes larger than batch_process_size must be divisible by batch_process_size 107 | for b in bucket_sizes: 108 | if b > batch_process_size and (b//batch_process_size)*batch_process_size != b: 109 | raise ValueError(f"batch_process_size must evenly divide all bucket sizes; it does not divide {b} (try a power of 2)") 110 | if node_terminate_thresh is None and split_depth is None: 111 | raise ValueError("must specify at least one of node_terminate_thresh or split_depth as a terminating condition") 112 | if node_terminate_thresh is None: 113 | node_terminate_thresh = 9999999999 114 | 115 | d = lower.shape[-1] 116 | B = batch_process_size 117 | 118 | print(f"\n == CONSTRUCTING LEVELSET TREE") 119 | # print(f" node thresh: {node_terminate_thresh}") 120 | 121 | # Initialize data 122 | node_lower = lower[None,:] 123 | node_upper = upper[None,:] 124 | node_valid = jnp.ones((1,), dtype=bool) 125 | N_curr_nodes = 1 126 | finished_interior_lower = jnp.zeros((batch_process_size,3)) if with_interior_nodes else None 127 | finished_interior_upper = jnp.zeros((batch_process_size,3)) if with_interior_nodes else None 128 | N_finished_interior = 0 129 | finished_exterior_lower = jnp.zeros((batch_process_size,3)) if with_exterior_nodes else None 130 | finished_exterior_upper = jnp.zeros((batch_process_size,3)) if with_exterior_nodes else None 131 | N_finished_exterior = 0 132 | N_func_evals = 0 133 | 134 | ## Recursively build the tree 135 | i_split = 0 136 | n_splits = 99999999 if split_depth is None else split_depth+1 # 1 extra because last round doesn't split 137 | for i_split in range(n_splits): 138 | 139 | # Reshape into batches of size <= B 140 | init_bucket_size =
node_lower.shape[0] 141 | this_b = min(B, init_bucket_size) 142 | N_func_evals += node_lower.shape[0] 143 | # utils.printarr(node_valid, node_lower, node_upper) 144 | node_valid = jnp.reshape(node_valid, (-1, this_b)) 145 | node_lower = jnp.reshape(node_lower, (-1, this_b, d)) 146 | node_upper = jnp.reshape(node_upper, (-1, this_b, d)) 147 | nb = node_lower.shape[0] 148 | n_occ = int(math.ceil(N_curr_nodes / this_b)) # only the batches which are occupied (since valid nodes are densely packed at the start) 149 | 150 | # Detect when to quit. On the last iteration we need to not do any more splitting, but still process existing nodes one last time 151 | quit_next = (N_curr_nodes >= node_terminate_thresh) or i_split+1 == n_splits 152 | do_continue_splitting = not quit_next 153 | 154 | print(f"Uniform levelset tree. iter: {i_split} N_curr_nodes: {N_curr_nodes} bucket size: {init_bucket_size} batch size: {this_b} number of batches: {nb} quit next: {quit_next} do_continue_splitting: {do_continue_splitting}") 155 | 156 | # enlarge the finished nodes if needed 157 | if with_interior_nodes: 158 | while finished_interior_lower.shape[0] - N_finished_interior < N_curr_nodes: 159 | finished_interior_lower = utils.resize_array_axis(finished_interior_lower, 2*finished_interior_lower.shape[0]) 160 | finished_interior_upper = utils.resize_array_axis(finished_interior_upper, 2*finished_interior_upper.shape[0]) 161 | if with_exterior_nodes: 162 | while finished_exterior_lower.shape[0] - N_finished_exterior < N_curr_nodes: 163 | finished_exterior_lower = utils.resize_array_axis(finished_exterior_lower, 2*finished_exterior_lower.shape[0]) 164 | finished_exterior_upper = utils.resize_array_axis(finished_exterior_upper, 2*finished_exterior_upper.shape[0]) 165 | 166 | # map over the batches 167 | out_valid = jnp.zeros((nb, 2*this_b), dtype=bool) 168 | out_lower = jnp.zeros((nb, 2*this_b, 3)) 169 | out_upper = jnp.zeros((nb, 2*this_b, 3)) 170 | total_n_valid = 0 171 | for ib in 
range(n_occ): 172 | out_valid, out_lower, out_upper, total_n_valid, \ 173 | finished_interior_lower, finished_interior_upper, N_finished_interior, \ 174 | finished_exterior_lower, finished_exterior_upper, N_finished_exterior, \ 175 | = \ 176 | construct_uniform_unknown_levelset_tree_iter(func, params, do_continue_splitting, \ 177 | node_valid[ib,...], node_lower[ib,...], node_upper[ib,...], \ 178 | ib, out_valid, out_lower, out_upper, total_n_valid, \ 179 | finished_interior_lower, finished_interior_upper, N_finished_interior, \ 180 | finished_exterior_lower, finished_exterior_upper, N_finished_exterior, \ 181 | offset=offset) 182 | 183 | node_valid = out_valid 184 | node_lower = out_lower 185 | node_upper = out_upper 186 | N_curr_nodes = total_n_valid 187 | 188 | # flatten back out 189 | node_valid = jnp.reshape(node_valid, (-1,)) 190 | node_lower = jnp.reshape(node_lower, (-1, d)) 191 | node_upper = jnp.reshape(node_upper, (-1, d)) 192 | 193 | # Compactify and rebucket arrays 194 | target_bucket_size = get_next_bucket_size(total_n_valid) 195 | node_valid, N_curr_nodes, node_lower, node_upper = compactify_and_rebucket_arrays(node_valid, target_bucket_size, node_lower, node_upper) 196 | 197 | if quit_next: 198 | break 199 | 200 | 201 | # pack the output in to a dict to support optional outputs 202 | out_dict = { 203 | 'unknown_node_valid' : node_valid, 204 | 'unknown_node_lower' : node_lower, 205 | 'unknown_node_upper' : node_upper, 206 | } 207 | 208 | if with_interior_nodes: 209 | out_dict['interior_node_valid'] = jnp.arange(finished_interior_lower.shape[0]) < N_finished_interior 210 | out_dict['interior_node_lower'] = finished_interior_lower 211 | out_dict['interior_node_upper'] = finished_interior_upper 212 | 213 | if with_exterior_nodes: 214 | out_dict['exterior_node_valid'] = jnp.arange(finished_exterior_lower.shape[0]) < N_finished_exterior 215 | out_dict['exterior_node_lower'] = finished_exterior_lower 216 | out_dict['exterior_node_upper'] = 
finished_exterior_upper 217 | 218 | return out_dict 219 | 220 | @partial(jax.jit, static_argnames=("func", "n_samples_per_round"), donate_argnums=(8,9)) 221 | def sample_surface_iter(func, params, n_samples_per_round, width, rngkey, 222 | u_node_valid, u_node_lower, u_node_upper, 223 | found_sample_points, found_start_ind): 224 | 225 | ## Generate sample points in the nodes 226 | 227 | # pick which node to sample from 228 | # (uses the fact valid nodes will always be first N) 229 | n_node_valid = jnp.sum(u_node_valid) 230 | rngkey, subkey = jax.random.split(rngkey) 231 | node_ind = jax.random.randint(subkey, (n_samples_per_round,), minval=0, maxval=n_node_valid) 232 | 233 | # fetch node data 234 | samp_lower = u_node_lower[node_ind,:] 235 | samp_upper = u_node_upper[node_ind,:] 236 | 237 | # sample points uniformly within the nodes 238 | rngkey, subkey = jax.random.split(rngkey) 239 | samp_pos = jax.random.uniform(subkey, (n_samples_per_round,3), minval=samp_lower, maxval=samp_upper) 240 | 241 | # evaluate the function and reject samples outside of the specified width 242 | samp_val = jax.vmap(partial(func, params))(samp_pos) 243 | samp_valid = jnp.abs(samp_val) < width 244 | 245 | # append these new samples on to the output array 246 | n_samp_valid = jnp.sum(samp_valid) 247 | valid_inds = utils.enumerate_mask(samp_valid, fill_value=found_sample_points.shape[0]) 248 | valid_inds_loc = valid_inds + found_start_ind 249 | found_sample_points = found_sample_points.at[valid_inds_loc,:].set(samp_pos, mode='drop', indices_are_sorted=True) 250 | found_start_ind = jnp.minimum(found_start_ind + n_samp_valid, found_sample_points.shape[0]) 251 | 252 | return found_sample_points, found_start_ind 253 | 254 | 255 | def sample_surface(func, params, lower, upper, n_samples, width, rngkey, n_node_thresh=4096): 256 | 257 | ''' 258 | - Build tree over levelset (rather than a usual 0 bound, it needs to use += width, so we know for sure that the sample region is contained in unknown 
cells) 259 | - Uniformly rejection sample from the unknown cells 260 | ''' 261 | 262 | # print(f"Sample surface: building level set tree with at least {n_node_thresh} nodes") 263 | 264 | # Build a tree over the valid nodes 265 | # By definition returned nodes are all SIGN_UNKNOWN, and all the same size 266 | # print(f"sample_surface n_node_thresh {n_node_thresh}") 267 | out_dict = construct_uniform_unknown_levelset_tree(func, params, lower, upper, node_terminate_thresh=n_node_thresh, offset=width) 268 | node_valid = out_dict['unknown_node_valid'] 269 | node_lower = out_dict['unknown_node_lower'] 270 | node_upper = out_dict['unknown_node_upper'] 271 | 272 | # sample from the unknown nodes until we get enough samples 273 | n_samples_per_round = min(3*n_samples, 100000) # enough that we usually finish in one round 274 | found_sample_points = jnp.zeros((n_samples,3)) 275 | found_start_ind = 0 276 | round_count = 0 277 | while True: 278 | round_count += 1 279 | 280 | # print(f"Have {found_start_ind} / {n_samples} samples. Performing sample round") 281 | 282 | rngkey, subkey = jax.random.split(rngkey) 283 | found_sample_points, found_start_ind = sample_surface_iter(func, params, n_samples_per_round, width, subkey, 284 | node_valid, node_lower, node_upper, found_sample_points, found_start_ind) 285 | 286 | # NOTE: assumes all nodes are same size 287 | 288 | if found_start_ind == n_samples: 289 | break 290 | 291 | # print(f"Done! 
Sampling took {round_count} rounds") 292 | return found_sample_points 293 | 294 | 295 | # This is here for comparison to the tree-based one above 296 | @partial(jax.jit, static_argnames=("func", "n_samples_per_round"), donate_argnums=(7,8)) 297 | def sample_surface_uniform_iter(func, params, n_samples_per_round, width, rngkey, 298 | lower, upper, 299 | found_sample_points, found_start_ind): 300 | 301 | ## Generate sample points 302 | 303 | # sample points uniformly within the domain 304 | rngkey, subkey = jax.random.split(rngkey) 305 | samp_pos = jax.random.uniform(subkey, (n_samples_per_round,3), minval=lower, maxval=upper) 306 | 307 | # evaluate the function and reject samples outside of the specified width 308 | samp_val = jax.vmap(partial(func, params))(samp_pos) 309 | samp_valid = jnp.abs(samp_val) < width 310 | 311 | # append these new samples onto the output array 312 | n_samp_valid = jnp.sum(samp_valid) 313 | valid_inds = utils.enumerate_mask(samp_valid, fill_value=found_sample_points.shape[0]) 314 | valid_inds_loc = valid_inds + found_start_ind 315 | found_sample_points = found_sample_points.at[valid_inds_loc,:].set(samp_pos, mode='drop', indices_are_sorted=True) 316 | found_start_ind = jnp.minimum(found_start_ind + n_samp_valid, found_sample_points.shape[0]) 317 | 318 | return found_sample_points, found_start_ind 319 | 320 | def sample_surface_uniform(func, params, lower, upper, n_samples, width, rngkey): 321 | 322 | # sample from the domain until we get enough samples 323 | n_samples_per_round = min(10*n_samples, 100000) 324 | found_sample_points = jnp.zeros((n_samples,3)) 325 | found_start_ind = 0 326 | round_count = 0 327 | while True: 328 | round_count += 1 329 | 330 | rngkey, subkey = jax.random.split(rngkey) 331 | found_sample_points, found_start_ind = sample_surface_uniform_iter(func, params, n_samples_per_round, width, subkey, lower, upper, found_sample_points, found_start_ind) 332 | 333 | if found_start_ind == n_samples: 334 |
break 335 | 336 | return found_sample_points 337 | 338 | @partial(jax.jit, static_argnames=("func","n_subcell_depth"), donate_argnums=(7,)) 339 | def hierarchical_marching_cubes_extract_iter(func, params, mc_data, n_subcell_depth, node_valid, node_lower, node_upper, tri_pos_out, n_out_written): 340 | 341 | # run the extraction routine 342 | tri_verts, tri_valid = jax.vmap(partial(extract_cell.extract_triangles_from_subcells, func, params, mc_data, n_subcell_depth))(node_lower, node_upper) 343 | tri_valid = jnp.logical_and(tri_valid, node_valid[:,None]) 344 | 345 | # flatten out the generated triangles 346 | tri_verts = jnp.reshape(tri_verts, (-1,3,3)) 347 | tri_valid = jnp.reshape(tri_valid, (-1,)) 348 | 349 | # write the result 350 | out_inds = utils.enumerate_mask(tri_valid, fill_value=tri_pos_out.shape[0]) 351 | out_inds += n_out_written 352 | tri_pos_out = tri_pos_out.at[out_inds,:,:].set(tri_verts, mode='drop') 353 | n_out_written += jnp.sum(tri_valid) 354 | 355 | return tri_pos_out, n_out_written 356 | 357 | def hierarchical_marching_cubes(func, params, lower, upper, depth, n_subcell_depth=2, extract_batch_max_tri_out=1000000): 358 | 359 | # Build a tree over the isosurface 360 | # By definition returned nodes are all SIGN_UNKNOWN, and all the same size 361 | out_dict = construct_uniform_unknown_levelset_tree(func, params, lower, upper, split_depth=3*(depth-n_subcell_depth)) 362 | node_valid = out_dict['unknown_node_valid'] 363 | node_lower = out_dict['unknown_node_lower'] 364 | node_upper = out_dict['unknown_node_upper'] 365 | 366 | # fetch the extraction data 367 | mc_data = extract_cell.get_mc_data() 368 | 369 | # Extract triangles from the valid nodes (do it in batches in case there are a lot) 370 | extract_batch_size = extract_batch_max_tri_out // (5 * (2**n_subcell_depth)**3) 371 | extract_batch_size = get_next_bucket_size(extract_batch_size) 372 | N_cell = node_valid.shape[0] 373 | N_valid = int(jnp.sum(node_valid)) 374 | n_out_written = 0 375 |
tri_pos_out = jnp.zeros((1, 3, 3)) 376 | 377 | init_bucket_size = node_lower.shape[0] 378 | this_b = min(extract_batch_size, init_bucket_size) 379 | node_valid = jnp.reshape(node_valid, (-1, this_b)) 380 | node_lower = jnp.reshape(node_lower, (-1, this_b, 3)) 381 | node_upper = jnp.reshape(node_upper, (-1, this_b, 3)) 382 | nb = node_lower.shape[0] 383 | n_occ = int(math.ceil(N_valid/ this_b)) # only the batches which are occupied (since valid nodes are densely packed at the start) 384 | max_tri_round = this_b * 5 * (2**n_subcell_depth)**3 385 | for ib in range(n_occ): 386 | 387 | print(f"Extract iter {ib} / {n_occ}. max_tri_round: {max_tri_round} n_out_written: {n_out_written}") 388 | 389 | # expand the output array only lazily as needed 390 | while(tri_pos_out.shape[0] - n_out_written < max_tri_round): 391 | tri_pos_out = utils.resize_array_axis(tri_pos_out, 2*tri_pos_out.shape[0]) 392 | 393 | tri_pos_out, n_out_written = hierarchical_marching_cubes_extract_iter(func, params, mc_data, n_subcell_depth, node_valid[ib,...], node_lower[ib,...], node_upper[ib,...], tri_pos_out, n_out_written) 394 | 395 | # clip the result triangles 396 | # TODO bucket and mask here? 
needed if we want this in a JIT loop 397 | tri_pos_out = tri_pos_out[:n_out_written,:] 398 | 399 | return tri_pos_out 400 | 401 | 402 | @partial(jax.jit, static_argnames=("func_tuple","viz_nodes")) 403 | def find_any_intersection_iter( 404 | func_tuple, params_tuple, eps, 405 | node_lower, node_upper, N_curr_nodes, 406 | viz_nodes = False 407 | ): 408 | 409 | # N_curr_nodes --> the first N nodes are valid 410 | 411 | ''' 412 | Algorithm: 413 | process_node(): 414 | 415 | for each func: 416 | 417 | classify func on the node as one of 4 categories: 418 | positive: strictly positive via the interval bound 419 | negative: strictly negative via the interval bound 420 | unknown: cannot be bounded via the interval bound 421 | near_surface: there is a sign change within +- eps/2 of the node center and node width < eps 422 | (near_surface has highest precedence if it applies) 423 | 424 | if >= 2 are (negative or near_surface): 425 | return found intersection! 426 | 427 | if >= 2 are (negative or unknown): 428 | recurse on subnodes 429 | 430 | else: 431 | return exit -- no intersection 432 | ''' 433 | 434 | N_in = node_lower.shape[0] 435 | d = node_lower.shape[-1] 436 | node_valid = jnp.arange(node_lower.shape[0]) < N_curr_nodes 437 | 438 | if len(func_tuple) != 2: 439 | raise ValueError("intersection supports pairwise only as written") 440 | funcA = func_tuple[0] 441 | funcB = func_tuple[1] 442 | paramsA = params_tuple[0] 443 | paramsB = params_tuple[1] 444 | 445 | # the side of a cube such that all points are within `eps` of each other 446 | eps_cube_width = eps / jnp.sqrt(3) 447 | 448 | 449 | def process_node(valid, lower, upper): 450 | 451 | intersection_count = 0 # nodes which definitely have an intersection in this cell 452 | possible_intersection_count = 0 # nodes which _might_ have an intersection in this cell 453 | 454 | # intersection details 455 | found_intersection = False 456 | intersection_loc = jnp.array((-777., -777., -777.)) 457 | 458 | # Node geometry 459 | node_width =
jnp.max(upper-lower) 460 | node_split_dim = jnp.argmax(upper-lower, axis=-1) 461 | node_is_small = node_width < eps_cube_width 462 | node_center = 0.5 * (lower + upper) 463 | sample_offsets = jnp.concatenate((jnp.zeros((1,d)), jnp.eye(d), -jnp.eye(d)), axis=0) 464 | sample_pts = node_center[None,:] + eps_cube_width * sample_offsets 465 | 466 | # classify the box 467 | node_interval_typeA = funcA.classify_box(paramsA, lower, upper) 468 | node_interval_typeB = funcB.classify_box(paramsB, lower, upper) 469 | 470 | # test the sample points nearby for convergence checking 471 | sample_valsA = jax.vmap(partial(funcA, paramsA))(sample_pts) 472 | sample_valsB = jax.vmap(partial(funcB, paramsB))(sample_pts) 473 | 474 | all_same_signA = utils.all_same_sign(sample_valsA) 475 | all_same_signB = utils.all_same_sign(sample_valsB) 476 | is_near_surfaceA = jnp.logical_and(node_is_small, ~all_same_signA) 477 | is_near_surfaceB = jnp.logical_and(node_is_small, ~all_same_signB) 478 | 479 | ## test if we definitely found an intersection 480 | 481 | # if both functions are close to the surface 482 | any_neg_indA = jnp.nonzero(sample_valsA < 0, size=1, fill_value=0)[0][0] 483 | any_is_negA = sample_valsA[any_neg_indA] < 0 484 | any_neg_locA = sample_pts[any_neg_indA,:] 485 | any_neg_indB = jnp.nonzero(sample_valsB < 0, size=1, fill_value=0)[0][0] 486 | any_is_negB = sample_valsB[any_neg_indB] < 0 487 | any_neg_locB = sample_pts[any_neg_indB,:] 488 | have_near_neg_samples = utils.logical_and_all((node_is_small, any_is_negA, any_is_negB)) 489 | found_intersection = jnp.logical_or(found_intersection, have_near_neg_samples) 490 | intersection_loc = jnp.where(have_near_neg_samples, 0.5 * (any_neg_locA + any_neg_locB), intersection_loc) 491 | 492 | # if some sample point is inside of both funcs 493 | # (no need to do anything for both SIGN_NEGATIVE, it will be caught by this) 494 | # (this criterion is tested second because we prefer it: it gives a point strictly inside instead 495 | of in
the blurry eps converged region) 496 | sample_both_neg = jnp.logical_and(sample_valsA < 0, sample_valsB < 0) 497 | both_neg_ind = jnp.nonzero(sample_both_neg, size=1, fill_value=0)[0][0] 498 | have_sample_both_neg = sample_both_neg[both_neg_ind] 499 | sample_both_neg_loc = sample_pts[both_neg_ind,:] 500 | found_intersection = jnp.logical_or(found_intersection, have_sample_both_neg) 501 | intersection_loc = jnp.where(have_sample_both_neg, sample_both_neg_loc, intersection_loc) 502 | 503 | 504 | ## test if we need to keep searching 505 | could_be_insideA = jnp.logical_or( 506 | node_interval_typeA == SIGN_NEGATIVE, 507 | jnp.logical_and(node_interval_typeA == SIGN_UNKNOWN, ~is_near_surfaceA) 508 | ) 509 | could_be_insideB = jnp.logical_or( 510 | node_interval_typeB == SIGN_NEGATIVE, 511 | jnp.logical_and(node_interval_typeB == SIGN_UNKNOWN, ~is_near_surfaceB) 512 | ) 513 | 514 | needs_subdiv = utils.logical_and_all((could_be_insideA, could_be_insideB, valid)) 515 | found_intersection = jnp.logical_and(found_intersection, valid) 516 | 517 | 518 | return found_intersection, intersection_loc, needs_subdiv, node_split_dim 519 | 520 | 521 | # evaluate the function inside nodes 522 | has_intersection, intersection_loc, needs_subdiv, node_split_dim = \ 523 | jax.vmap(process_node)(node_valid, node_lower, node_upper) 524 | 525 | # if there was any intersection, pull out its data right away 526 | int_ind = jnp.nonzero(has_intersection, size=1, fill_value=0)[0][0] # get any nonzero entry 527 | found_int = has_intersection[int_ind] 528 | found_int_loc = intersection_loc[int_ind, :] 529 | 530 | # no need to keep processing anything if we found an intersection 531 | needs_subdiv = jnp.logical_and(needs_subdiv, ~found_int) 532 | 533 | if viz_nodes: 534 | # if requested, dump out all of the nodes that were searched, for visualization 535 | viz_nodes_mask = jnp.logical_and(node_valid, ~needs_subdiv) 536 | else: 537 | viz_nodes_mask = None 538 | 539 | N_needs_subdiv =
jnp.sum(needs_subdiv) 540 | 541 | # get rid of all nodes that are not getting subdivided and compactify the rest 542 | N_new = jnp.sum(needs_subdiv) # before split, after splitting there will be 2*N_new nodes 543 | compact_inds = jnp.nonzero(needs_subdiv, size=needs_subdiv.shape[0], fill_value=INVALID_IND)[0] 544 | node_lower = node_lower.at[compact_inds,:].get(mode='fill', fill_value=-777.) 545 | node_upper = node_upper.at[compact_inds,:].get(mode='fill', fill_value=-777.) 546 | node_split_dim = node_split_dim.at[compact_inds].get(mode='fill', fill_value=-777) 547 | 548 | ## now actually build the child nodes 549 | 550 | # extents of the new child nodes along each split dimension 551 | new_lower = node_lower 552 | new_upper = node_upper 553 | new_mid = 0.5 * (new_lower + new_upper) 554 | new_coord_mask = jnp.arange(3)[None,:] == node_split_dim[:,None] 555 | newA_lower = new_lower 556 | newA_upper = jnp.where(new_coord_mask, new_mid, new_upper) 557 | newB_lower = jnp.where(new_coord_mask, new_mid, new_lower) 558 | newB_upper = new_upper 559 | 560 | # write the new children into the arrays (this will have twice the size of the input) 561 | node_lower = utils.interleave_arrays((newA_lower, newB_lower)) 562 | node_upper = utils.interleave_arrays((newA_upper, newB_upper)) 563 | 564 | return node_lower, node_upper, 2*N_new, found_int, 1, 2, found_int_loc, viz_nodes_mask 565 | 566 | 567 | def find_any_intersection(func_tuple, params_tuple, lower, upper, eps, viz_nodes=False): 568 | 569 | d = lower.shape[-1] 570 | 571 | print(f"\n == SEARCHING FOR INTERSECTION") 572 | # print(f" max depth: {max_depth}") 573 | 574 | # Initialize data 575 | node_lower = lower[None,:] 576 | node_upper = upper[None,:] 577 | N_curr_nodes = 1 578 | N_nodes_processed = 0 # only actual nodes, does not count fake ones due to bucketing 579 | N_bucket_nodes_processed = 0 # includes real and fake nodes due to bucketing 580 | 581 | if viz_nodes: 582 | viz_nodes_lower = jnp.zeros((0,3)) 583 |
viz_nodes_upper = jnp.zeros((0,3)) 584 | viz_nodes_type = jnp.zeros((0,), dtype=int) 585 | else: 586 | viz_nodes_lower = None 587 | viz_nodes_upper = None 588 | viz_nodes_type = None 589 | 590 | ## Recursively search the space 591 | split_round = 0 592 | while(True): 593 | 594 | ## Call the function which does all the actual work 595 | # (the node_lower/node_upper arrays that come out are twice the size due to splits) 596 | 597 | N_nodes_processed += N_curr_nodes 598 | N_bucket_nodes_processed += node_lower.shape[0] 599 | 600 | print(f"Intersection search depth {split_round}. Searching {N_curr_nodes} nodes (bucket: {node_lower.shape[0]})") 601 | 602 | if(viz_nodes): 603 | # if requested, save visualization data 604 | # (back these up so we can use them below) 605 | node_lower_prev = node_lower 606 | node_upper_prev = node_upper 607 | 608 | node_lower, node_upper, N_curr_nodes, found_int, found_int_A, found_int_B, found_int_loc, viz_mask = \ 609 | find_any_intersection_iter(func_tuple, params_tuple, eps, node_lower, node_upper, N_curr_nodes, viz_nodes) 610 | 611 | if(viz_nodes): 612 | # if requested, save visualization data 613 | node_lower_save = node_lower_prev[viz_mask,:] 614 | node_upper_save = node_upper_prev[viz_mask,:] 615 | 616 | # classify the nodes 617 | def process_node(lower, upper): 618 | node_interval_typeA = func_tuple[0].classify_box(params_tuple[0], lower, upper) 619 | node_interval_typeB = func_tuple[1].classify_box(params_tuple[1], lower, upper) 620 | type_count = 0 621 | type_count = jnp.where(node_interval_typeA == SIGN_POSITIVE, 1, type_count) 622 | type_count = jnp.where(node_interval_typeB == SIGN_POSITIVE, 2, type_count) 623 | return type_count 624 | node_type_save = jax.vmap(process_node)(node_lower_save, node_upper_save) 625 | 626 | viz_nodes_lower = jnp.concatenate((viz_nodes_lower, node_lower_save)) 627 | viz_nodes_upper = jnp.concatenate((viz_nodes_upper, node_upper_save)) 628 | viz_nodes_type = jnp.concatenate((viz_nodes_type, 
node_type_save)) 629 | 630 | N_curr_nodes = int(N_curr_nodes) 631 | 632 | # quit because we found an intersection 633 | if found_int: 634 | print(f"Found intersection between funcs {found_int_A},{found_int_B} at {found_int_loc}. Processed {N_nodes_processed} nodes ({N_bucket_nodes_processed}).") 635 | if viz_nodes: 636 | return found_int, found_int_A, found_int_B, found_int_loc, viz_nodes_lower, viz_nodes_upper, viz_nodes_type 637 | else: 638 | return found_int, found_int_A, found_int_B, found_int_loc 639 | 640 | # quit because there can be no intersection 641 | if N_curr_nodes == 0: 642 | print(f"No intersection detected. Processed {N_nodes_processed} nodes ({N_bucket_nodes_processed}).") 643 | if viz_nodes: 644 | return False, 0, 0, jnp.array((-777., -777., -777.)), viz_nodes_lower, viz_nodes_upper, viz_nodes_type 645 | else: 646 | return False, 0, 0, jnp.array((-777., -777., -777.)) 647 | 648 | # if the current nodes would fit in a smaller array, put them there 649 | new_bucket_size = get_next_bucket_size(N_curr_nodes) 650 | curr_bucket_size = node_lower.shape[0] 651 | if new_bucket_size < curr_bucket_size: 652 | node_lower = node_lower[:new_bucket_size,:] 653 | node_upper = node_upper[:new_bucket_size,:] 654 | 655 | split_round += 1 656 | 657 | 658 | # @partial(jax.jit, static_argnames=("func","batch_process_size"), donate_argnums=(3,4,5,6,7)) 659 | @partial(jax.jit, static_argnames=("func","batch_process_size")) 660 | def closest_point_iter(func, params, 661 | query_points, query_min_dist, query_min_loc, 662 | work_query_id, work_node_lower, work_node_upper, work_stack_top, 663 | eps, batch_process_size): 664 | 665 | # basic strategy: 666 | # - pop work items off queue 667 | # - discard inside/outside nodes 668 | # - discard nodes further than min dist 669 | # - for any node which spans, compute minimum distance 670 | # - reduce over minimum 671 | # - if node dist == min, set min location 672 | # - recurse into big nodes, push back on stack 673 | 674 | 675 | 
## pop off some work to do 676 | B = batch_process_size 677 | Q = query_points.shape[0] 678 | d = query_points.shape[-1] 679 | pop_ind = jnp.maximum(work_stack_top-B,0) 680 | batch_query_id = jax.lax.dynamic_slice_in_dim(work_query_id, pop_ind, B) 681 | batch_node_lower = jax.lax.dynamic_slice_in_dim(work_node_lower, pop_ind, B) 682 | batch_node_upper = jax.lax.dynamic_slice_in_dim(work_node_upper, pop_ind, B) 683 | batch_query_loc = query_points[batch_query_id,:] 684 | batch_query_min_dist = query_min_dist[batch_query_id] 685 | batch_valid = jnp.arange(B) < work_stack_top 686 | work_stack_top = pop_ind 687 | 688 | eps_cube_width = eps / jnp.sqrt(d) 689 | d = work_node_lower.shape[-1] 690 | 691 | # process each node, computing closest point data 692 | def process_one(valid, query_id, lower, upper, query_loc, query_min_dist): 693 | 694 | # compute bounds on the distance to any point in the node 695 | node_width = jnp.max(upper-lower) 696 | node_center = 0.5 * (lower + upper) 697 | node_center_dist_offset = jnp.sqrt(jnp.sum(jnp.square(upper-lower))) # maximum distance from the center to any point in the node 698 | max_dist_to_point_in_node = geometry.norm(query_loc - node_center) + node_center_dist_offset # could be tighter 699 | nearest_point_in_node = jnp.clip(query_loc, a_min=lower, a_max=upper) 700 | min_dist_to_point_in_node = geometry.norm(query_loc - nearest_point_in_node) 701 | node_split_dim = jnp.argmax(upper-lower, axis=-1) 702 | is_small = node_width < eps_cube_width 703 | sample_offsets = jnp.concatenate((jnp.zeros((1,d)), jnp.eye(d), -jnp.eye(d)), axis=0) # [7,3] 704 | sample_pts = node_center[None,:] + (upper-lower)[None,:] * sample_offsets 705 | 706 | # classify the box 707 | node_interval_type = func.classify_box(params, lower, upper) 708 | is_outside = jnp.logical_or(node_interval_type==SIGN_NEGATIVE, node_interval_type==SIGN_POSITIVE) 709 | 710 | # test the sample points nearby for convergence checking 711 | sample_vals = jax.vmap(partial(func,
params))(sample_pts) 712 | spans_surface = jnp.logical_and(~utils.all_same_sign(sample_vals), valid) 713 | 714 | # compute outputs 715 | this_closest_point_dist = jnp.where(spans_surface, max_dist_to_point_in_node, float('inf')) 716 | needs_subdiv = utils.logical_and_all((valid, ~is_outside, ~is_small, min_dist_to_point_in_node < query_min_dist)) 717 | 718 | return needs_subdiv, this_closest_point_dist, node_center, node_split_dim 719 | 720 | batch_needs_subdiv, batch_this_closest_point_dist, batch_node_center, batch_node_split_dim = \ 721 | jax.vmap(process_one)(batch_valid, batch_query_id, batch_node_lower, batch_node_upper, batch_query_loc, batch_query_min_dist) 722 | 723 | 724 | # set any newly found closest values 725 | query_min_dist = query_min_dist.at[batch_query_id].min(batch_this_closest_point_dist) 726 | batch_query_new_min_dist = query_min_dist[batch_query_id] 727 | batch_has_new_min = (batch_this_closest_point_dist == batch_query_new_min_dist) 728 | batch_target_inds = jnp.where(batch_has_new_min, batch_query_id, Q) 729 | query_min_loc = query_min_loc.at[batch_target_inds,:].set(batch_node_center, mode='drop') 730 | 731 | # compactify the nodes which need to be subdivided 732 | N_new = jnp.sum(batch_needs_subdiv) # before split, after splitting there will be 2*N_new nodes 733 | compact_inds = jnp.nonzero(batch_needs_subdiv, size=batch_needs_subdiv.shape[0], fill_value=INVALID_IND)[0] 734 | batch_node_lower = batch_node_lower.at[compact_inds,:].get(mode='fill', fill_value=-777.) 735 | batch_node_upper = batch_node_upper.at[compact_inds,:].get(mode='fill', fill_value=-777.) 736 | batch_query_id = batch_query_id.at[compact_inds].get(mode='fill', fill_value=-777.) 
737 | batch_node_split_dim = batch_node_split_dim.at[compact_inds].get(mode='fill', fill_value=-777) 738 | 739 | ## now actually build the child nodes 740 | 741 | # extents of the new child nodes along each split dimension 742 | new_batch_lower = batch_node_lower 743 | new_batch_upper = batch_node_upper 744 | new_batch_mid = 0.5 * (new_batch_lower + new_batch_upper) 745 | new_batch_coord_mask = jnp.arange(3)[None,:] == batch_node_split_dim[:,None] 746 | newA_lower = new_batch_lower 747 | newA_upper = jnp.where(new_batch_coord_mask, new_batch_mid, new_batch_upper) 748 | newB_lower = jnp.where(new_batch_coord_mask, new_batch_mid, new_batch_lower) 749 | newB_upper = new_batch_upper 750 | 751 | # write the new children into the arrays (the result has twice the size of the input) 752 | new_node_lower = utils.interleave_arrays((newA_lower, newB_lower)) 753 | new_node_upper = utils.interleave_arrays((newA_upper, newB_upper)) 754 | new_node_query_id = utils.interleave_arrays((batch_query_id, batch_query_id)) 755 | 756 | # TODO is this guaranteed to update in place like at[] does?
757 | work_query_id = jax.lax.dynamic_update_slice_in_dim(work_query_id, new_node_query_id, pop_ind, axis=0) 758 | work_node_lower = jax.lax.dynamic_update_slice_in_dim(work_node_lower, new_node_lower, pop_ind, axis=0) 759 | work_node_upper = jax.lax.dynamic_update_slice_in_dim(work_node_upper, new_node_upper, pop_ind, axis=0) 760 | work_stack_top = work_stack_top + 2*N_new 761 | 762 | return query_min_dist, query_min_loc, \ 763 | work_query_id, work_node_lower, work_node_upper, work_stack_top, 764 | 765 | 766 | def closest_point(func, params, lower, upper, query_points, eps=0.001, batch_process_size=2048): 767 | 768 | 769 | # working data 770 | B = batch_process_size 771 | Q = query_points.shape[0] 772 | work_node_lower = jnp.repeat(lower[None,:], Q, axis=0) 773 | work_node_upper = jnp.repeat(upper[None,:], Q, axis=0) 774 | work_query_id = jnp.arange(Q) 775 | query_min_dist = jnp.full((Q,), float('inf')) 776 | query_min_loc = jnp.full((Q,3), -777.) 777 | work_stack_top = query_points.shape[0] 778 | 779 | i_round = 0 780 | while work_stack_top > 0: 781 | 782 | # Ensure there is enough room on the stack (each iteration pops B nodes and pushes back at most 2*B children, a net growth of at most B entries) 783 | while work_node_lower.shape[0] < (work_stack_top + B): 784 | N = work_node_lower.shape[0] 785 | N_new = max(2*N, 8*B) 786 | work_node_lower = utils.resize_array_axis(work_node_lower, N_new) 787 | work_node_upper = utils.resize_array_axis(work_node_upper, N_new) 788 | work_query_id = utils.resize_array_axis(work_query_id, N_new) 789 | 790 | 791 | query_min_dist, query_min_loc, \ 792 | work_query_id, work_node_lower, work_node_upper, work_stack_top = \ 793 | closest_point_iter(func, params, 794 | query_points, query_min_dist, query_min_loc, 795 | work_query_id, work_node_lower, work_node_upper, work_stack_top, 796 | eps=eps, batch_process_size=batch_process_size) 797 | 798 | work_stack_top = int(work_stack_top) 799 | 800 | i_round += 1 801 | 802 | return query_min_dist, query_min_loc 803 | 804 |
@partial(jax.jit, static_argnames=("func", "n_samples",)) 805 | def bulk_properties_sample_mass(func, params, node_valid, node_lower, node_upper, n_samples, rngkey): 806 | 807 | # pick which node to sample from 808 | # (uses the fact valid nodes will always be first N) 809 | n_node_valid = jnp.sum(node_valid) 810 | rngkey, subkey = jax.random.split(rngkey) 811 | node_ind = jax.random.randint(subkey, (n_samples,), minval=0, maxval=n_node_valid) 812 | 813 | # fetch node data 814 | samp_lower = node_lower[node_ind,:] 815 | samp_upper = node_upper[node_ind,:] 816 | 817 | # sample points uniformly within the nodes 818 | rngkey, subkey = jax.random.split(rngkey) 819 | samp_pos = jax.random.uniform(subkey, (n_samples,3), minval=samp_lower, maxval=samp_upper) 820 | 821 | # evaluate the function and keep only the samples which land inside the shape 822 | samp_val = jax.vmap(partial(func, params))(samp_pos) 823 | samp_valid = samp_val < 0. 824 | 825 | # compute the contribution to mass and centroid 826 | node_vols = jnp.prod(node_upper-node_lower, axis=-1) 827 | total_vol = jnp.sum(jnp.where(node_valid, node_vols, 0.)) 828 | vol_per_sample = total_vol / n_samples 829 | 830 | mass = vol_per_sample*jnp.sum(samp_valid) 831 | centroid = vol_per_sample*jnp.sum(jnp.where(samp_valid[:,None], samp_pos, 0.), axis=0) 832 | 833 | return mass, centroid 834 | 835 | def bulk_properties(func, params, lower, upper, rngkey, n_expand=int(1e4), n_sample=int(1e6)): 836 | 837 | out_dict = construct_uniform_unknown_levelset_tree(func, params, lower, upper, with_interior_nodes=True, node_terminate_thresh=n_expand) 838 | node_valid = out_dict['unknown_node_valid'] 839 | node_lower = out_dict['unknown_node_lower'] 840 | node_upper = out_dict['unknown_node_upper'] 841 | interior_node_valid = out_dict['interior_node_valid'] 842 | interior_node_lower = out_dict['interior_node_lower'] 843 | interior_node_upper = out_dict['interior_node_upper'] 844 | 845 | # Compute mass and centroid for this demo 846 | def
compute_bulk_mass(lower, upper): 847 | mass = jnp.prod(upper-lower) 848 | c = 0.5 * (lower + upper) 849 | return mass, mass * c 850 | 851 | mass_interior, centroid_interior = jax.jit(jax.vmap(compute_bulk_mass))(interior_node_lower, interior_node_upper) 852 | 853 | mass_interior = jnp.sum(jnp.where(interior_node_valid, mass_interior, 0.)) 854 | centroid_interior = jnp.sum(jnp.where(interior_node_valid[:,None], centroid_interior, 0.), axis=0) 855 | 856 | rngkey, subkey = jax.random.split(rngkey) 857 | mass_boundary, centroid_boundary = bulk_properties_sample_mass(func, params, node_valid, node_lower, node_upper, n_sample, subkey) 858 | 859 | mass = mass_interior + mass_boundary 860 | centroid = centroid_interior + centroid_boundary 861 | centroid = centroid / mass 862 | 863 | return mass, centroid 864 | 865 | 866 | def generate_tree_viz_nodes_simple(node_lower, node_upper, shrink_factor=0.05): 867 | 868 | print("Generating viz nodes") 869 | 870 | # (global shrink) 871 | min_width = jnp.min(node_upper - node_lower) 872 | shrink = shrink_factor * min_width 873 | node_lower += shrink 874 | node_upper -= shrink 875 | 876 | # Construct the 8 corner vertices for each cell 877 | v0 = jnp.stack((node_lower[:,0], node_lower[:,1], node_lower[:,2]), axis=-1) 878 | v1 = jnp.stack((node_upper[:,0], node_lower[:,1], node_lower[:,2]), axis=-1) 879 | v2 = jnp.stack((node_upper[:,0], node_upper[:,1], node_lower[:,2]), axis=-1) 880 | v3 = jnp.stack((node_lower[:,0], node_upper[:,1], node_lower[:,2]), axis=-1) 881 | v4 = jnp.stack((node_lower[:,0], node_lower[:,1], node_upper[:,2]), axis=-1) 882 | v5 = jnp.stack((node_upper[:,0], node_lower[:,1], node_upper[:,2]), axis=-1) 883 | v6 = jnp.stack((node_upper[:,0], node_upper[:,1], node_upper[:,2]), axis=-1) 884 | v7 = jnp.stack((node_lower[:,0], node_upper[:,1], node_upper[:,2]), axis=-1) 885 | vs = [v0, v1, v2, v3, v4, v5, v6, v7] 886 | 887 | # (local shrink) 888 | centers = 0.5 * (node_lower + node_upper) 889 | for i in range(8): 890 |
vs[i] = (1. - shrink_factor) * vs[i] + shrink_factor * centers 891 | 892 | verts = utils.interleave_arrays(vs) 893 | 894 | # Construct the index array 895 | inds = jnp.arange(8*v0.shape[0]).reshape((-1,8)) 896 | 897 | return verts, inds 898 | 899 | -------------------------------------------------------------------------------- /src/main_fit_implicit.py: -------------------------------------------------------------------------------- 1 | import igl 2 | 3 | import sys 4 | from functools import partial 5 | import argparse 6 | 7 | import numpy as np 8 | import jax 9 | import jax.numpy as jnp 10 | from jax.example_libraries import optimizers 11 | 12 | # Imports from this project 13 | from utils import * 14 | import mlp 15 | import geometry 16 | import render 17 | import queries 18 | import affine 19 | 20 | SRC_DIR = os.path.dirname(os.path.realpath(__file__)) 21 | ROOT_DIR = os.path.join(SRC_DIR, "..") 22 | 23 | def main(): 24 | 25 | parser = argparse.ArgumentParser() 26 | 27 | # Build arguments 28 | parser.add_argument("input_file", type=str) 29 | parser.add_argument("output_file", type=str) 30 | 31 | # network 32 | parser.add_argument("--activation", type=str, default='elu') 33 | parser.add_argument("--n_layers", type=int, default=8) 34 | parser.add_argument("--layer_width", type=int, default=32) 35 | parser.add_argument("--positional_encoding", action='store_true') 36 | parser.add_argument("--positional_count", type=int, default=10) 37 | parser.add_argument("--positional_pow_start", type=int, default=-3) 38 | 39 | # loss / data 40 | parser.add_argument("--fit_mode", type=str, default='sdf') 41 | parser.add_argument("--n_epochs", type=int, default=100) 42 | parser.add_argument("--n_samples", type=int, default=1000000) 43 | parser.add_argument("--sample_ambient_range", type=float, default=1.25) 44 | parser.add_argument("--sample_weight_beta", type=float, default=20.) 
45 | 46 | # training 47 | parser.add_argument("--lr", type=float, default=1e-2) 48 | parser.add_argument("--batch_size", type=int, default=2048) 49 | parser.add_argument("--lr_decay_every", type=int, default=99999) 50 | parser.add_argument("--lr_decay_frac", type=float, default=.5) 51 | 52 | # jax options 53 | parser.add_argument("--log-compiles", action='store_true') 54 | parser.add_argument("--disable-jit", action='store_true') 55 | parser.add_argument("--debug-nans", action='store_true') 56 | parser.add_argument("--enable-double-precision", action='store_true') 57 | 58 | # Parse arguments 59 | args = parser.parse_args() 60 | 61 | # validate some inputs 62 | if args.activation not in ['relu', 'elu', 'cos']: 63 | raise ValueError("unrecognized activation") 64 | if args.fit_mode not in ['occupancy', 'sdf']: 65 | raise ValueError("unrecognized fit_mode") 66 | if not args.output_file.endswith('.npz'): 67 | raise ValueError("output file should end with .npz") 68 | 69 | # Force jax to initialize itself so errors get thrown early 70 | _ = jnp.zeros(()) 71 | 72 | # Set jax things 73 | if args.log_compiles: 74 | jax.config.update("jax_log_compiles", 1) 75 | if args.disable_jit: 76 | jax.config.update('jax_disable_jit', True) 77 | if args.debug_nans: 78 | jax.config.update("jax_debug_nans", True) 79 | if args.enable_double_precision: 80 | jax.config.update("jax_enable_x64", True) 81 | 82 | # load the input 83 | print(f"Loading mesh {args.input_file}") 84 | V, F = igl.read_triangle_mesh(args.input_file) 85 | V = jnp.array(V) 86 | F = jnp.array(F) 87 | print(f" ...done") 88 | 89 | # preprocess (center and scale) 90 | V = geometry.normalize_positions(V, method='bbox') 91 | 92 | # sample training points 93 | print(f"Sampling {args.n_samples} training points...") 94 | # Uses a sampling strategy which is basically the one from Davies et al. 95 | # samp, samp_SDF = geometry.sample_mesh_sdf(V, F, args.n_samples, surface_frac=args.surface_frac, surface_perturb_sigma=args.surface_perturb_sigma,
ambient_range=args.surface_ambient_range) 96 | samp, samp_SDF = geometry.sample_mesh_importance(V, F, args.n_samples, beta=args.sample_weight_beta, ambient_range=args.sample_ambient_range) 97 | 98 | if args.fit_mode == 'occupancy': 99 | samp_target = (samp_SDF > 0) * 1.0 100 | n_pos = jnp.sum(samp_target > 0) 101 | n_neg = samp_target.shape[0] - n_pos 102 | w_pos = n_neg / (n_pos + n_neg) 103 | w_neg = n_pos / (n_pos + n_neg) 104 | samp_weight = jnp.where(samp_target > 0, w_pos, w_neg) 105 | elif args.fit_mode == 'sdf': 106 | samp_target = samp_SDF 107 | samp_weight = jnp.ones_like(samp_target) 108 | else: raise ValueError("bad arg") 109 | print(f" ...done") 110 | 111 | # construct the network 112 | print(f"Constructing {args.n_layers}x{args.layer_width} {args.activation} network...") 113 | if args.positional_encoding: 114 | spec_list = [mlp.pow2_frequency_encode(args.positional_count, start_pow=args.positional_pow_start, with_shift=True), mlp.sin()] 115 | layers = [6*args.positional_count] + [args.layer_width]*args.n_layers + [1] 116 | spec_list += mlp.quick_mlp_spec(layers, args.activation) 117 | else: 118 | layers = [3] + [args.layer_width]*args.n_layers + [1] 119 | spec_list = mlp.quick_mlp_spec(layers, args.activation) 120 | orig_params = mlp.build_spec(spec_list) 121 | implicit_func = mlp.func_from_spec() 122 | 123 | 124 | # layer initialization 125 | key = jax.random.PRNGKey(0) 126 | key, subkey = jax.random.split(key) 127 | orig_params = mlp.initialize_params(orig_params, subkey) 128 | print(f" ...done") 129 | 130 | # test eval to ensure the function isn't broken 131 | print(f"Network test evaluation...") 132 | implicit_func(orig_params, jnp.array((0.1, 0.2, 0.3))) 133 | print(f"...done") 134 | 135 | # Create an optimizer 136 | print(f"Creating optimizer...") 137 | def step_func(i_epoch): 138 | out = args.lr * (args.lr_decay_frac ** (i_epoch // args.lr_decay_every)) 139 | return out 140 | opt = optimizers.adam(step_func) 141 | 142 | opt_param_keys
= mlp.opt_param_keys(orig_params) 143 | 144 | # Union our optimizable parameters with the non-optimizable ones 145 | def add_full_params(opt_params): 146 | all_params = opt_params 147 | 148 | for k in orig_params: 149 | if k not in all_params: 150 | all_params[k] = orig_params[k] 151 | 152 | # Strip out the non-optimizable parameters, keeping only the optimizable ones 153 | def filter_to_opt_params_only(all_params): 154 | for k in list(all_params.keys()): # copy the keys so we can delete while iterating 155 | if k not in opt_param_keys: 156 | del all_params[k] 157 | 158 | # Construct the optimizer over the optimizable params 159 | opt_params_only = {} 160 | for k in mlp.opt_param_keys(orig_params): 161 | opt_params_only[k] = orig_params[k] 162 | opt_state = opt.init_fn(opt_params_only) 163 | print(f"...done") 164 | 165 | best_loss = float('inf') 166 | best_params = None 167 | 168 | 169 | 170 | @jax.jit 171 | def generate_batch(rngkey, samples_in, samples_out, samples_weight): 172 | 173 | # concatenate to make processing easier 174 | samples = jnp.concatenate((samples_in, samples_out[:,None], samples_weight[:,None]), axis=-1) 175 | 176 | # shuffle 177 | samples = jax.random.permutation(rngkey, samples, axis=0) 178 | 179 | # split into batches 180 | # (discard any extra samples) 181 | batch_count = samples.shape[0] // args.batch_size 182 | n_batch_total = args.batch_size * batch_count 183 | samples = samples[:n_batch_total, :] 184 | 185 | # split back up 186 | samples_in = samples[:,:3] 187 | samples_out = samples[:,3] 188 | samples_weight = samples[:,4] 189 | 190 | batch_in = jnp.reshape(samples_in, (batch_count, args.batch_size, 3)) 191 | batch_out = jnp.reshape(samples_out, (batch_count, args.batch_size)) 192 | batch_weight = jnp.reshape(samples_weight, (batch_count, args.batch_size)) 193 | 194 | return batch_in, batch_out, batch_weight, batch_count 195 | 196 | def batch_loss_fn(params, batch_coords, batch_target, batch_weight): 197 | 198 | add_full_params(params) 199 | 200 | def loss_one(params, coords, target, weight): 201 | pred =
implicit_func(params, coords) 202 | 203 | if args.fit_mode == 'occupancy': 204 | return weight * binary_cross_entropy_loss(pred, target) 205 | elif args.fit_mode == 'sdf': 206 | # weighted L1 sdf loss (weights are all 1. in sdf mode) 207 | return weight * jnp.abs(pred - target) 208 | else: raise ValueError("bad arg") 209 | 210 | loss_terms = jax.vmap(partial(loss_one, params))(batch_coords, batch_target, batch_weight) 211 | loss_sum = jnp.mean(loss_terms) 212 | return loss_sum 213 | 214 | def batch_count_correct(params, batch_coords, batch_target): 215 | 216 | add_full_params(params) 217 | 218 | def loss_one(params, coords, target): 219 | pred = implicit_func(params, coords) 220 | 221 | if args.fit_mode == 'occupancy': 222 | is_correct_sign = jnp.sign(pred) == jnp.sign(target - .5) 223 | return is_correct_sign 224 | elif args.fit_mode == 'sdf': 225 | is_correct_sign = jnp.sign(pred) == jnp.sign(target) 226 | return is_correct_sign 227 | else: raise ValueError("bad arg") 228 | 229 | correct_sign = jax.vmap(partial(loss_one, params))(batch_coords, batch_target) 230 | correct_count = jnp.sum(correct_sign) 231 | return correct_count 232 | 233 | @jax.jit 234 | def train_step(i_epoch, i_step, opt_state, batch_in, batch_out, batch_weight): 235 | 236 | opt_params = opt.params_fn(opt_state) 237 | value, grads = jax.value_and_grad(batch_loss_fn)(opt_params, batch_in, batch_out, batch_weight) 238 | correct_count = batch_count_correct(opt_params, batch_in, batch_out) 239 | opt_state = opt.update_fn(i_epoch, grads, opt_state) 240 | 241 | return value, opt_state, correct_count 242 | 243 | print(f"Training...") 244 | i_step = 0 245 | for i_epoch in range(args.n_epochs): 246 | 247 | key, subkey = jax.random.split(key) 248 | batches_in, batches_out, batches_weight, n_batches = generate_batch(subkey, samp, samp_target, samp_weight) 249 | losses = [] 250 | n_correct = 0 251 | n_total = 0 252 | 253 | for i_b in range(n_batches): 254 | 255 | loss, opt_state, correct_count = train_step(i_epoch, i_step, opt_state, batches_in[i_b,...],
batches_out[i_b,...], batches_weight[i_b,...]) 256 | 257 | loss = float(loss) 258 | correct_count = int(correct_count) 259 | losses.append(loss) 260 | n_correct += correct_count 261 | n_total += args.batch_size 262 | i_step += 1 263 | 264 | mean_loss = np.mean(np.array(losses)) 265 | frac_correct= n_correct / n_total 266 | 267 | print(f"== Epoch {i_epoch} / {args.n_epochs} loss: {mean_loss:.6f} correct sign: {100*frac_correct:.2f}%") 268 | 269 | if mean_loss < best_loss: 270 | best_loss = mean_loss 271 | best_params = opt.params_fn(opt_state) 272 | add_full_params(best_params) 273 | print(" --> new best") 274 | 275 | print(f"Saving result to {args.output_file}") 276 | mlp.save(args.output_file, best_params) 277 | print(f" ...done") 278 | 279 | 280 | # save the result 281 | print(f"Saving result to {args.output_file}") 282 | mlp.save(args.output_file, best_params) 283 | print(f" ...done") 284 | 285 | 286 | if __name__ == '__main__': 287 | main() 288 | -------------------------------------------------------------------------------- /src/main_intersection.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | from functools import partial 3 | 4 | import numpy as np 5 | import jax 6 | import jax.numpy as jnp 7 | from skimage import measure 8 | 9 | import argparse 10 | 11 | import polyscope as ps 12 | import polyscope.imgui as psim 13 | 14 | 15 | # Imports from this project 16 | import render, geometry, queries 17 | from geometry import * 18 | from utils import * 19 | import affine 20 | import slope_interval 21 | import sdf 22 | import mlp 23 | import kd_tree 24 | from implicit_function import SIGN_UNKNOWN, SIGN_POSITIVE, SIGN_NEGATIVE 25 | 26 | import affine_layers 27 | import slope_interval_layers 28 | import implicit_mlp_utils 29 | 30 | # Config 31 | # from jax.config import config 32 | 33 | SRC_DIR = os.path.dirname(os.path.realpath(__file__)) 34 | ROOT_DIR = os.path.join(SRC_DIR, "..") 35 | 36 | 37 | def main(): 
38 | 39 | parser = argparse.ArgumentParser() 40 | 41 | # Build arguments 42 | parser.add_argument("inputA", type=str) 43 | parser.add_argument("inputB", type=str) 44 | 45 | parser.add_argument("--res", type=int, default=1024) 46 | 47 | parser.add_argument("--mode", type=str, default='affine_all') 48 | 49 | parser.add_argument("--scaleA", type=float, default=1.) 50 | parser.add_argument("--scaleB", type=float, default=1.) 51 | 52 | parser.add_argument("--log-compiles", action='store_true') 53 | parser.add_argument("--disable-jit", action='store_true') 54 | parser.add_argument("--debug-nans", action='store_true') 55 | parser.add_argument("--enable-double-precision", action='store_true') 56 | 57 | # Parse arguments 58 | args = parser.parse_args() 59 | 60 | ## Small options 61 | debug_log_compiles = False 62 | debug_disable_jit = False 63 | debug_debug_nans = False 64 | if args.log_compiles: 65 | jax.config.update("jax_log_compiles", 1) 66 | debug_log_compiles = True 67 | if args.disable_jit: 68 | jax.config.update('jax_disable_jit', True) 69 | debug_disable_jit = True 70 | if args.debug_nans: 71 | jax.config.update("jax_debug_nans", True) 72 | debug_debug_nans = True 73 | if args.enable_double_precision: 74 | jax.config.update("jax_enable_x64", True) 75 | 76 | ps.set_use_prefs_file(False) 77 | ps.init() 78 | 79 | # GUI Parameters 80 | continuously_render = False 81 | fancy_render = False 82 | continuously_intersect = False 83 | opts = queries.get_default_cast_opts() 84 | opts['data_bound'] = 1 85 | opts['res_scale'] = 1 86 | opts['intersection_eps'] = 1e-3 87 | cast_frustum = False 88 | shade_style = 'matcap_color' 89 | surf_colorA = (0.157,0.613,1.000) 90 | surf_colorB = (0.215,0.865,0.046) 91 | 92 | 93 | # Load the shapes 94 | print("Loading shapes") 95 | implicit_funcA, paramsA = implicit_mlp_utils.generate_implicit_from_file(args.inputA, mode=args.mode, affine_n_truncate=64, affine_truncate_policy='absolute') 96 | paramsA = mlp.prepend_op(paramsA, 
mlp.spatial_transformation()) 97 | 98 | implicit_funcB, paramsB = implicit_mlp_utils.generate_implicit_from_file(args.inputB, mode=args.mode, affine_n_truncate=64, affine_truncate_policy='absolute') 99 | paramsB = mlp.prepend_op(paramsB, mlp.spatial_transformation()) 100 | 101 | 102 | # Register volume quantities in Polyscope for the shapes 103 | def register_volume(name, implicit_func, params, scale=1.): 104 | 105 | # Construct the regular grid 106 | grid_res = 128 107 | ax_coords = jnp.linspace(-1., 1., grid_res) 108 | grid_x, grid_y, grid_z = jnp.meshgrid(ax_coords, ax_coords, ax_coords, indexing='ij') 109 | grid = jnp.stack((grid_x.flatten(), grid_y.flatten(), grid_z.flatten()), axis=-1) 110 | delta = (grid[1,2] - grid[0,2]).item() 111 | sdf_vals = jax.vmap(partial(implicit_func, params))(grid) 112 | sdf_vals = sdf_vals.reshape(grid_res, grid_res, grid_res) 113 | bbox_min = grid[0,:] 114 | verts, faces, normals, values = measure.marching_cubes(np.array(sdf_vals), level=0., spacing=(delta, delta, delta)) 115 | verts = verts + bbox_min[None,:] 116 | ps_surf = ps.register_surface_mesh(name, verts, faces) 117 | return ps_surf 118 | 119 | print("Registering grids") 120 | ps_vol_A = register_volume("shape A coarse preview", implicit_funcA, paramsA, args.scaleA) 121 | ps_vol_B = register_volume("shape B coarse preview", implicit_funcB, paramsB, args.scaleB) 122 | 123 | print("Loading matcaps") 124 | matcaps = render.load_matcap(os.path.join(ROOT_DIR, "assets", "matcaps", "wax_{}.png")) 125 | 126 | print("Done") 127 | def find_intersection(): 128 | 129 | func_tuple = (implicit_funcA, implicit_funcB) 130 | params_tuple = (paramsA, paramsB) 131 | data_bound = opts['data_bound'] 132 | lower = jnp.array((-data_bound, -data_bound, -data_bound)) 133 | upper = jnp.array((data_bound, data_bound, data_bound)) 134 | eps = opts['intersection_eps'] 135 | 136 | with Timer("intersection"): 137 | found_int, found_int_A, found_int_B, found_int_loc = 
kd_tree.find_any_intersection(func_tuple, params_tuple, lower, upper, eps) 138 | 139 | if found_int: 140 | pos = np.array(found_int_loc)[None,:] 141 | ps_int_cloud = ps.register_point_cloud("intersection location", pos, enabled=True, radius=0.01, color=(1., 0., 0.)) 142 | else: 143 | ps.remove_point_cloud("intersection location", error_if_absent=False) 144 | 145 | 146 | def viz_intersection_tree(): 147 | 148 | func_tuple = (implicit_funcA, implicit_funcB) 149 | params_tuple = (paramsA, paramsB) 150 | data_bound = opts['data_bound'] 151 | lower = jnp.array((-data_bound, -data_bound, -data_bound)) 152 | upper = jnp.array((data_bound, data_bound, data_bound)) 153 | eps = opts['intersection_eps'] 154 | 155 | found_int, found_int_A, found_int_B, found_int_loc, nodes_lower, nodes_upper, nodes_type = kd_tree.find_any_intersection(func_tuple, params_tuple, lower, upper, eps, viz_nodes=True) 156 | 157 | 158 | verts, inds = kd_tree.generate_tree_viz_nodes_simple(nodes_lower, nodes_upper) 159 | 160 | ps_vol_nodes = ps.register_volume_mesh("search tree nodes", np.array(verts), hexes=np.array(inds)) 161 | ps_vol_nodes.add_scalar_quantity("type", np.array(nodes_type), defined_on='cells') 162 | ps_vol_nodes.set_enabled(True) 163 | 164 | 165 | def callback(): 166 | 167 | nonlocal implicit_funcA, paramsA, implicit_funcB, paramsB, continuously_render, fancy_render, continuously_intersect, cast_frustum, debug_log_compiles, debug_disable_jit, debug_debug_nans, shade_style, surf_colorA, surf_colorB 168 | 169 | 170 | # === Update transforms from Polyscope 171 | def update_transform(ps_vol, params, scale=1.): 172 | T = ps_vol.get_transform() 173 | R = T[:3,:3] 174 | 175 | # TODO this absurdity makes the transform behave as expected.
176 | # I think there just miiiiight be a bug in the transforms Polyscope is returning 177 | R_inv = jnp.linalg.inv(R) 178 | t = R_inv @ R_inv @ T[:3,3] 179 | 180 | params["0000.spatial_transformation.R"] = R_inv * scale 181 | params["0000.spatial_transformation.t"] = t 182 | update_transform(ps_vol_A, paramsA, args.scaleA) 183 | update_transform(ps_vol_B, paramsB, args.scaleB) 184 | 185 | # === Build the UI 186 | 187 | ## Intersection & options 188 | psim.SetNextItemOpen(True, psim.ImGuiCond_FirstUseEver) 189 | if psim.TreeNode("Intersection"): 190 | psim.PushItemWidth(100) 191 | 192 | if psim.Button("Check for intersection"): 193 | find_intersection() 194 | psim.SameLine() 195 | 196 | _, continuously_intersect = psim.Checkbox("Continuously intersect", continuously_intersect) 197 | if continuously_intersect: 198 | find_intersection() 199 | 200 | 201 | _, opts['intersection_eps'] = psim.InputFloat("intersection_delta", opts['intersection_eps']) 202 | 203 | 204 | if psim.Button("Viz intersection tree"): 205 | viz_intersection_tree() 206 | 207 | psim.PopItemWidth() 208 | psim.TreePop() 209 | 210 | 211 | 212 | if psim.TreeNode("Debug"): 213 | psim.PushItemWidth(100) 214 | 215 | changed, debug_log_compiles = psim.Checkbox("debug_log_compiles", debug_log_compiles) 216 | if changed: 217 | jax.config.update("jax_log_compiles", 1 if debug_log_compiles else 0) 218 | 219 | changed, debug_disable_jit = psim.Checkbox("debug_disable_jit", debug_disable_jit) 220 | if changed: 221 | jax.config.update('jax_disable_jit', debug_disable_jit) 222 | 223 | changed, debug_debug_nans = psim.Checkbox("debug_debug_nans", debug_debug_nans) 224 | if changed: 225 | jax.config.update("jax_debug_nans", debug_debug_nans) 226 | 227 | 228 | psim.PopItemWidth() 229 | psim.TreePop() 230 | 231 | 232 | # Hand off control to the main callback 233 | ps.set_user_callback(callback) 234 | ps.show() 235 | 236 | 237 | 238 | if __name__ == '__main__': 239 | main() 240 | 
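The `spatial_transformation` op prepended to each network in `main()` above lets the GUI move a shape: a rigidly transformed copy of an implicit function is evaluated by inverse-transforming the query point first. A minimal standalone sketch of that idea (the names here are illustrative, not the repo's API):

```python
import numpy as np

def transform_implicit(f, R, t):
    # A rigidly transformed copy of implicit function f: the moved shape's
    # value at x is f evaluated at the inverse-transformed point R^-1 (x - t).
    R_inv = np.linalg.inv(R)
    return lambda x: f(R_inv @ (np.asarray(x, float) - t))

# unit sphere SDF, and a copy translated to be centered at (2, 0, 0)
unit_sphere = lambda p: np.linalg.norm(p) - 1.0
moved = transform_implicit(unit_sphere, np.eye(3), np.array([2.0, 0.0, 0.0]))
```

Evaluating `moved` at the new center returns the original's value at the origin, which is why the callback above stores an inverse rotation into the params rather than the forward transform.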
-------------------------------------------------------------------------------- /src/main_spelunking.py: -------------------------------------------------------------------------------- 1 | import igl # work around some env/packaging problems by loading this first 2 | 3 | import sys, os, time, math 4 | from functools import partial 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | 9 | import argparse 10 | import matplotlib 11 | import matplotlib.pyplot as plt 12 | import imageio 13 | from skimage import measure 14 | 15 | 16 | import polyscope as ps 17 | import polyscope.imgui as psim 18 | 19 | # Imports from this project 20 | import render, geometry, queries 21 | from geometry import * 22 | from utils import * 23 | import affine 24 | import slope_interval 25 | import sdf 26 | import mlp 27 | from kd_tree import * 28 | from implicit_function import SIGN_UNKNOWN, SIGN_POSITIVE, SIGN_NEGATIVE 29 | import implicit_mlp_utils, extract_cell 30 | import affine_layers 31 | import slope_interval_layers 32 | 33 | # Config 34 | 35 | SRC_DIR = os.path.dirname(os.path.realpath(__file__)) 36 | ROOT_DIR = os.path.join(SRC_DIR, "..") 37 | 38 | 39 | def save_render_current_view(args, implicit_func, params, cast_frustum, opts, matcaps, surf_color): 40 | 41 | root = ps.get_camera_world_position() 42 | look, up, left = ps.get_camera_frame() 43 | fov_deg = ps.get_field_of_view() 44 | res = args.res // opts['res_scale'] 45 | 46 | surf_color = tuple(surf_color) 47 | 48 | img, depth, count, _, eval_sum, raycast_time = render.render_image(implicit_func, params, root, look, up, left, res, fov_deg, cast_frustum, opts, shading='matcap_color', matcaps=matcaps, shading_color_tuple=(surf_color,)) 49 | 50 | # flip Y 51 | img = img[::-1,:,:] 52 | 53 | # append an alpha channel 54 | alpha_channel = (jnp.min(img,axis=-1) < 1.) * 1. 
55 | # alpha_channel = jnp.ones_like(img[:,:,0]) 56 | img_alpha = jnp.concatenate((img, alpha_channel[:,:,None]), axis=-1) 57 | img_alpha = jnp.clip(img_alpha, a_min=0., a_max=1.) 58 | img_alpha = np.array(img_alpha) 59 | print(f"Saving image to {args.image_write_path}") 60 | imageio.imwrite(args.image_write_path, img_alpha) 61 | 62 | 63 | def do_sample_surface(opts, implicit_func, params, n_samples, sample_width, n_node_thresh, do_viz_tree, do_uniform_sample): 64 | data_bound = opts['data_bound'] 65 | lower = jnp.array((-data_bound, -data_bound, -data_bound)) 66 | upper = jnp.array((data_bound, data_bound, data_bound)) 67 | 68 | rngkey = jax.random.PRNGKey(0) 69 | 70 | print(f"do_sample_surface n_node_thresh {n_node_thresh}") 71 | 72 | with Timer("sample points"): 73 | sample_points = sample_surface(implicit_func, params, lower, upper, n_samples, sample_width, rngkey, n_node_thresh=n_node_thresh) 74 | sample_points.block_until_ready() 75 | 76 | ps.register_point_cloud("sampled points", np.array(sample_points)) 77 | 78 | 79 | # Build the tree all over again so we can visualize it 80 | if do_viz_tree: 81 | out_dict = construct_uniform_unknown_levelset_tree(implicit_func, params, lower, upper, n_node_thresh, offset=sample_width) 82 | node_valid = out_dict['unknown_node_valid'] 83 | node_lower = out_dict['unknown_node_lower'] 84 | node_upper = out_dict['unknown_node_upper'] 85 | node_lower = node_lower[node_valid,:] 86 | node_upper = node_upper[node_valid,:] 87 | verts, inds = generate_tree_viz_nodes_simple(node_lower, node_upper, shrink_factor=0.05) 88 | ps_vol = ps.register_volume_mesh("tree nodes", np.array(verts), hexes=np.array(inds)) 89 | 90 | # If requested, also do uniform sampling 91 | if do_uniform_sample: 92 | 93 | with Timer("sample points uniform"): 94 | sample_points = sample_surface_uniform(implicit_func, params, lower, upper, n_samples, sample_width, rngkey) 95 | sample_points.block_until_ready() 96 | 97 | ps.register_point_cloud("uniform sampled 
points", np.array(sample_points)) 98 | 99 | 100 | 101 | def do_hierarchical_mc(opts, implicit_func, params, n_mc_depth, do_viz_tree, compute_dense_cost): 102 | 103 | 104 | data_bound = opts['data_bound'] 105 | lower = jnp.array((-data_bound, -data_bound, -data_bound)) 106 | upper = jnp.array((data_bound, data_bound, data_bound)) 107 | 108 | 109 | print(f"do_hierarchical_mc {n_mc_depth}") 110 | 111 | 112 | with Timer("extract mesh"): 113 | tri_pos = hierarchical_marching_cubes(implicit_func, params, lower, upper, n_mc_depth, n_subcell_depth=3) 114 | tri_pos.block_until_ready() 115 | 116 | tri_inds = jnp.reshape(jnp.arange(3*tri_pos.shape[0]), (-1,3)) 117 | tri_pos = jnp.reshape(tri_pos, (-1,3)) 118 | ps.register_surface_mesh("extracted mesh", np.array(tri_pos), np.array(tri_inds)) 119 | 120 | # Build the tree all over again so we can visualize it 121 | if do_viz_tree: 122 | n_mc_subcell=3 123 | out_dict = construct_uniform_unknown_levelset_tree(implicit_func, params, lower, upper, split_depth=3*(n_mc_depth-n_mc_subcell), with_interior_nodes=True, with_exterior_nodes=True) 124 | 125 | node_valid = out_dict['unknown_node_valid'] 126 | node_lower = out_dict['unknown_node_lower'] 127 | node_upper = out_dict['unknown_node_upper'] 128 | node_lower = node_lower[node_valid,:] 129 | node_upper = node_upper[node_valid,:] 130 | verts, inds = generate_tree_viz_nodes_simple(node_lower, node_upper, shrink_factor=0.05) 131 | ps_vol = ps.register_volume_mesh("unknown tree nodes", np.array(verts), hexes=np.array(inds)) 132 | 133 | node_valid = out_dict['interior_node_valid'] 134 | node_lower = out_dict['interior_node_lower'] 135 | node_upper = out_dict['interior_node_upper'] 136 | node_lower = node_lower[node_valid,:] 137 | node_upper = node_upper[node_valid,:] 138 | if node_lower.shape[0] > 0: 139 | verts, inds = generate_tree_viz_nodes_simple(node_lower, node_upper, shrink_factor=0.05) 140 | ps_vol = ps.register_volume_mesh("interior tree nodes", np.array(verts), 
hexes=np.array(inds)) 141 | 142 | node_valid = out_dict['exterior_node_valid'] 143 | node_lower = out_dict['exterior_node_lower'] 144 | node_upper = out_dict['exterior_node_upper'] 145 | node_lower = node_lower[node_valid,:] 146 | node_upper = node_upper[node_valid,:] 147 | if node_lower.shape[0] > 0: 148 | verts, inds = generate_tree_viz_nodes_simple(node_lower, node_upper, shrink_factor=0.05) 149 | ps_vol = ps.register_volume_mesh("exterior tree nodes", np.array(verts), hexes=np.array(inds)) 150 | 151 | def do_closest_point(opts, func, params, n_closest_point): 152 | 153 | data_bound = float(opts['data_bound']) 154 | eps = float(opts['hit_eps']) 155 | lower = jnp.array((-data_bound, -data_bound, -data_bound)) 156 | upper = jnp.array((data_bound, data_bound, data_bound)) 157 | 158 | print(f"do_closest_point {n_closest_point}") 159 | 160 | # generate some query points 161 | rngkey = jax.random.PRNGKey(n_closest_point) 162 | rngkey, subkey = jax.random.split(rngkey) 163 | query_points = jax.random.uniform(subkey, (n_closest_point,3), minval=lower, maxval=upper) 164 | 165 | with Timer("closest point"): 166 | query_dist, query_min_loc = closest_point(func, params, lower, upper, query_points, eps=eps) 167 | query_dist.block_until_ready() 168 | 169 | # visualize only the outside ones 170 | is_outside = jax.vmap(partial(func,params))(query_points) > 0 171 | query_points = query_points[is_outside,:] 172 | query_dist = query_dist[is_outside] 173 | query_min_loc = query_min_loc[is_outside,:] 174 | 175 | viz_line_nodes = jnp.reshape(jnp.stack((query_points, query_min_loc), axis=1), (-1,3)) 176 | viz_line_edges = jnp.reshape(jnp.arange(2*query_points.shape[0]), (-1,2)) 177 | ps.register_point_cloud("closest point query", np.array(query_points)) 178 | ps.register_point_cloud("closest point result", np.array(query_min_loc)) 179 | ps.register_curve_network("closest point line", np.array(viz_line_nodes), np.array(viz_line_edges)) 180 | 181 | 182 | def compute_bulk(args, 
implicit_func, params, opts): 183 | 184 | data_bound = float(opts['data_bound']) 185 | lower = jnp.array((-data_bound, -data_bound, -data_bound)) 186 | upper = jnp.array((data_bound, data_bound, data_bound)) 187 | 188 | rngkey = jax.random.PRNGKey(0) 189 | 190 | with Timer("bulk properties"): 191 | mass, centroid = bulk_properties(implicit_func, params, lower, upper, rngkey) 192 | mass.block_until_ready() 193 | 194 | print(f"Bulk properties:") 195 | print(f" Mass: {mass}") 196 | print(f" Centroid: {centroid}") 197 | 198 | ps.register_point_cloud("centroid", np.array([centroid])) 199 | 200 | def main(): 201 | 202 | parser = argparse.ArgumentParser() 203 | 204 | # Build arguments 205 | parser.add_argument("input", type=str) 206 | 207 | parser.add_argument("--res", type=int, default=1024) 208 | 209 | parser.add_argument("--image_write_path", type=str, default="render_out.png") 210 | 211 | parser.add_argument("--log-compiles", action='store_true') 212 | parser.add_argument("--disable-jit", action='store_true') 213 | parser.add_argument("--debug-nans", action='store_true') 214 | parser.add_argument("--enable-double-precision", action='store_true') 215 | 216 | # Parse arguments 217 | args = parser.parse_args() 218 | 219 | ## Small options 220 | debug_log_compiles = False 221 | debug_disable_jit = False 222 | debug_debug_nans = False 223 | if args.log_compiles: 224 | jax.config.update("jax_log_compiles", 1) 225 | debug_log_compiles = True 226 | if args.disable_jit: 227 | jax.config.update('jax_disable_jit', True) 228 | debug_disable_jit = True 229 | if args.debug_nans: 230 | jax.config.update("jax_debug_nans", True) 231 | debug_debug_nans = True 232 | if args.enable_double_precision: 233 | jax.config.update("jax_enable_x64", True) 234 | 235 | 236 | # GUI Parameters 237 | opts = queries.get_default_cast_opts() 238 | opts['data_bound'] = 1 239 | opts['res_scale'] = 1 240 | opts['tree_max_depth'] = 12 241 | opts['tree_split_aff'] = False 242 | cast_frustum = False 243 | mode 
= 'affine_fixed' 244 | modes = ['sdf', 'interval', 'affine_fixed', 'affine_truncate', 'affine_append', 'affine_all', 'slope_interval'] 245 | affine_opts = {} 246 | affine_opts['affine_n_truncate'] = 8 247 | affine_opts['affine_n_append'] = 4 248 | affine_opts['sdf_lipschitz'] = 1. 249 | truncate_policies = ['absolute', 'relative'] 250 | affine_opts['affine_truncate_policy'] = 'absolute' 251 | n_sample_pts = 100000 252 | sample_width = 0.01 253 | n_node_thresh = 4096 254 | do_uniform_sample = False 255 | do_viz_tree = False 256 | n_mc_depth = 8 257 | compute_dense_cost = False 258 | n_closest_point = 16 259 | shade_style = 'matcap_color' 260 | surf_color = (0.157,0.613,1.000) 261 | 262 | implicit_func, params = implicit_mlp_utils.generate_implicit_from_file(args.input, mode=mode, **affine_opts) 263 | 264 | # load the matcaps 265 | matcaps = render.load_matcap(os.path.join(ROOT_DIR, "assets", "matcaps", "wax_{}.png")) 266 | 267 | def callback(): 268 | 269 | nonlocal implicit_func, params, mode, modes, cast_frustum, debug_log_compiles, debug_disable_jit, debug_debug_nans, shade_style, surf_color, n_sample_pts, sample_width, n_node_thresh, do_uniform_sample, do_viz_tree, n_mc_depth, compute_dense_cost, n_closest_point 270 | 271 | 272 | ## Options for general affine evaluation 273 | psim.SetNextItemOpen(True, psim.ImGuiCond_FirstUseEver) 274 | if psim.TreeNode("Eval options"): 275 | psim.PushItemWidth(100) 276 | 277 | old_mode = mode 278 | changed, mode = utils.combo_string_picker("Method", mode, modes) 279 | if mode != old_mode: 280 | implicit_func, params = implicit_mlp_utils.generate_implicit_from_file(args.input, mode=mode, **affine_opts) 281 | 282 | if mode == 'affine_truncate': 283 | # truncate options 284 | 285 | changed, affine_opts['affine_n_truncate'] = psim.InputInt("affine_n_truncate", affine_opts['affine_n_truncate']) 286 | if changed: 287 | implicit_func, params = implicit_mlp_utils.generate_implicit_from_file(args.input, mode=mode, **affine_opts) 288 | 
289 | changed, affine_opts['affine_truncate_policy'] = utils.combo_string_picker("Truncate policy", affine_opts['affine_truncate_policy'], truncate_policies) 290 | if changed: 291 | implicit_func, params = implicit_mlp_utils.generate_implicit_from_file(args.input, mode=mode, **affine_opts) 292 | 293 | if mode == 'affine_append': 294 | # append options 295 | 296 | changed, affine_opts['affine_n_append'] = psim.InputInt("affine_n_append", affine_opts['affine_n_append']) 297 | if changed: 298 | implicit_func, params = implicit_mlp_utils.generate_implicit_from_file(args.input, mode=mode, **affine_opts) 299 | 300 | if mode == 'sdf': 301 | 302 | changed, affine_opts['sdf_lipschitz'] = psim.InputFloat("SDF Lipschitz", affine_opts['sdf_lipschitz']) 303 | if changed: 304 | implicit_func, params = implicit_mlp_utils.generate_implicit_from_file(args.input, mode=mode, **affine_opts) 305 | 306 | 307 | psim.PopItemWidth() 308 | psim.TreePop() 309 | 310 | 311 | psim.SetNextItemOpen(True, psim.ImGuiCond_FirstUseEver) 312 | if psim.TreeNode("Raycast"): 313 | psim.PushItemWidth(100) 314 | 315 | if psim.Button("Save Render"): 316 | save_render_current_view(args, implicit_func, params, cast_frustum, opts, matcaps, surf_color) 317 | 318 | 319 | _, cast_frustum = psim.Checkbox("cast frustum", cast_frustum) 320 | _, opts['hit_eps'] = psim.InputFloat("delta", opts['hit_eps']) 321 | _, opts['max_dist'] = psim.InputFloat("max_dist", opts['max_dist']) 322 | 323 | if cast_frustum: 324 | _, opts['n_side_init'] = psim.InputInt("n_side_init", opts['n_side_init']) 325 | 326 | psim.PopItemWidth() 327 | psim.TreePop() 328 | 329 | 330 | # psim.SetNextItemOpen(True, psim.ImGuiCond_FirstUseEver) 331 | if psim.TreeNode("Sample Surface "): 332 | psim.PushItemWidth(100) 333 | 334 | if psim.Button("Sample"): 335 | do_sample_surface(opts, implicit_func, params, n_sample_pts, sample_width, n_node_thresh, do_viz_tree, do_uniform_sample) 336 | 337 | _, n_sample_pts = psim.InputInt("n_sample_pts", n_sample_pts) 338 | 
339 | psim.SameLine() 340 | _, sample_width = psim.InputFloat("sample_width", sample_width) 341 | _, n_node_thresh = psim.InputInt("n_node_thresh", n_node_thresh) 342 | _, do_viz_tree = psim.Checkbox("viz tree", do_viz_tree) 343 | psim.SameLine() 344 | _, do_uniform_sample = psim.Checkbox("also uniform sample", do_uniform_sample) 345 | 346 | 347 | psim.PopItemWidth() 348 | psim.TreePop() 349 | 350 | 351 | if psim.TreeNode("Extract mesh"): 352 | psim.PushItemWidth(100) 353 | 354 | if psim.Button("Extract"): 355 | do_hierarchical_mc(opts, implicit_func, params, n_mc_depth, do_viz_tree, compute_dense_cost) 356 | 357 | psim.SameLine() 358 | _, n_mc_depth = psim.InputInt("n_mc_depth", n_mc_depth) 359 | _, do_viz_tree = psim.Checkbox("viz tree", do_viz_tree) 360 | psim.SameLine() 361 | _, compute_dense_cost = psim.Checkbox("compute dense cost", compute_dense_cost) 362 | 363 | 364 | psim.PopItemWidth() 365 | psim.TreePop() 366 | 367 | 368 | if psim.TreeNode("Closest point"): 369 | psim.PushItemWidth(100) 370 | 371 | if psim.Button("Find closest points"): 372 | do_closest_point(opts, implicit_func, params, n_closest_point) 373 | 374 | _, n_closest_point = psim.InputInt("n_closest_point", n_closest_point) 375 | 376 | psim.PopItemWidth() 377 | psim.TreePop() 378 | 379 | ## Bulk 380 | if psim.TreeNode("Bulk Properties"): 381 | psim.PushItemWidth(100) 382 | 383 | if psim.Button("Compute bulk"): 384 | compute_bulk(args, implicit_func, params, opts) 385 | 386 | psim.PopItemWidth() 387 | psim.TreePop() 388 | 389 | 390 | if psim.TreeNode("Debug"): 391 | psim.PushItemWidth(100) 392 | 393 | changed, debug_log_compiles = psim.Checkbox("debug_log_compiles", debug_log_compiles) 394 | if changed: 395 | jax.config.update("jax_log_compiles", 1 if debug_log_compiles else 0) 396 | 397 | changed, debug_disable_jit = psim.Checkbox("debug_disable_jit", debug_disable_jit) 398 | if changed: 399 | jax.config.update('jax_disable_jit', debug_disable_jit) 400 | 401 | changed, debug_debug_nans = 
psim.Checkbox("debug_debug_nans", debug_debug_nans) 402 | if changed: 403 | jax.config.update("jax_debug_nans", debug_debug_nans) 404 | 405 | 406 | psim.PopItemWidth() 407 | psim.TreePop() 408 | 409 | ps.set_use_prefs_file(False) 410 | ps.init() 411 | 412 | 413 | 414 | # Visualize the data via quick coarse marching cubes, so we have something to look at 415 | 416 | # Construct the regular grid 417 | grid_res = 128 418 | ax_coords = jnp.linspace(-1., 1., grid_res) 419 | grid_x, grid_y, grid_z = jnp.meshgrid(ax_coords, ax_coords, ax_coords, indexing='ij') 420 | grid = jnp.stack((grid_x.flatten(), grid_y.flatten(), grid_z.flatten()), axis=-1) 421 | delta = (grid[1,2] - grid[0,2]).item() 422 | sdf_vals = jax.vmap(partial(implicit_func, params))(grid) 423 | sdf_vals = sdf_vals.reshape(grid_res, grid_res, grid_res) 424 | bbox_min = grid[0,:] 425 | verts, faces, normals, values = measure.marching_cubes(np.array(sdf_vals), level=0., spacing=(delta, delta, delta)) 426 | verts = verts + bbox_min[None,:] 427 | ps.register_surface_mesh("coarse shape preview", verts, faces) 428 | 429 | print("REMEMBER: All routines will be slow on the first invocation due to JAX kernel compilation. 
Subsequent calls will be fast.") 430 | 431 | # Hand off control to the main callback 432 | ps.show(1) 433 | ps.set_user_callback(callback) 434 | ps.show() 435 | 436 | 437 | if __name__ == '__main__': 438 | main() 439 | -------------------------------------------------------------------------------- /src/mlp.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import numpy as np 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | # Imports from this project 8 | from utils import * 9 | import affine 10 | import slope_interval 11 | 12 | # ===== High-level flow 13 | 14 | def build_spec(mlp_op_list): 15 | out_params = {} 16 | 17 | # prepend an opcount to the list of operations 18 | # (entries now look like "0007.dense.A", "0003.relu", etc) 19 | for i_op, op in enumerate(mlp_op_list): 20 | for key, val in op.items(): 21 | key = f"{i_op:04d}." + key 22 | out_params[key] = val 23 | 24 | return out_params 25 | 26 | def initialize_params(params, rngkey): 27 | 28 | N_op = n_ops(params) 29 | out_params = {} 30 | 31 | # perform initialization 32 | for i_op in range(N_op): 33 | name, orig_args = get_op_data(params, i_op) 34 | if name in initialize_func: 35 | 36 | # apply the init function 37 | subkey, rngkey = jax.random.split(rngkey) 38 | init_args = initialize_func[name](rngkey=subkey, **orig_args) 39 | 40 | # replace the updated data in the array 41 | for a in init_args: 42 | a_op = f"{i_op:04d}.{name}.{a}" 43 | out_params[a_op] = init_args[a] 44 | 45 | # functions which require no initialization 46 | # (just copy the params content) 47 | else: 48 | for a in orig_args: 49 | a_op = f"{i_op:04d}.{name}.{a}" 50 | out_params[a_op] = orig_args[a] 51 | 52 | 53 | return out_params 54 | 55 | def opt_param_keys(params): 56 | 57 | N_op = n_ops(params) 58 | keys = [] 59 | 60 | for i_op in range(N_op): 61 | name, orig_args = get_op_data(params, i_op) 62 | 63 | if name in opt_params: 64 | for a in orig_args: 65 | if 
a in opt_params[name]: 66 | fullname = f"{i_op:04d}.{name}.{a}" 67 | keys.append(fullname) 68 | 69 | return set(keys) 70 | 71 | 72 | # Easy helper for defining MLP from layers and activations 73 | def quick_mlp_spec(layer_sizes, activation): 74 | 75 | spec_list = [] 76 | 77 | for i in range(len(layer_sizes)-1): 78 | d_in = layer_sizes[i] 79 | d_out = layer_sizes[i+1] 80 | 81 | spec_list.append(dense(d_in, d_out)) 82 | 83 | # apply activation 84 | is_last = (i+2 == len(layer_sizes)) 85 | if not is_last: 86 | if activation == 'relu': 87 | spec_list.append(relu()) 88 | elif activation == 'elu': 89 | spec_list.append(elu()) 90 | else: raise ValueError("unrecognized activation") 91 | 92 | spec_list.append(squeeze_last()) 93 | 94 | return spec_list 95 | 96 | def func_from_spec(mode='default'): 97 | # be careful of mutable default arg ^^^ 98 | 99 | def eval_spec(params, x, mode_dict=None): 100 | N_op = n_ops(params) 101 | 102 | # walk the list of operations, evaluating each 103 | # TODO generalize w/ data tape to not assume linear dataflow 104 | for i_op in range(N_op): 105 | name, args = get_op_data(params, i_op) 106 | if mode_dict is not None: 107 | args.update(mode_dict) 108 | if "_" in args: 109 | del args["_"] 110 | x = apply_func[mode][name](x, **args) 111 | return x 112 | 113 | return eval_spec 114 | 115 | # ===== Utilities 116 | 117 | def get_op_data(params, i_op): 118 | i_op_str = f"{i_op:04d}" 119 | name = "" 120 | args = {} 121 | for key in params: 122 | if key.startswith(i_op_str): 123 | tokens = key.split(".") 124 | name = tokens[1] 125 | if len(tokens) > 2: 126 | argname = tokens[2] 127 | args[argname] = params[key] 128 | 129 | if name == "": 130 | print(params.keys()) 131 | raise ValueError(f"didn't find op {i_op}") 132 | 133 | return name, args 134 | 135 | def n_ops(params): 136 | n = 0 137 | for key in params: 138 | vals = key.split(".") 139 | try: 140 | i_op = int(vals[0]) 141 | except ValueError: 142 | raise ValueError(f"Could not parse out key {key}. 
Is this a valid mlp spec? Did you make a mistake passing params dictionaries around?") 143 | n = max(n, i_op+1) 144 | return n 145 | 146 | # helper to add an operation to an existing MLP 147 | # call like: 148 | # params = prepend_op(params, spatial_transformation()) 149 | def prepend_op(params, op): 150 | new_params = {} 151 | 152 | # increment the op ind in the key of all existing ops 153 | for key in params: 154 | vals = key.split(".") 155 | i_op = int(vals[0]) 156 | i_op += 1 157 | vals[0] = f"{i_op:04d}" 158 | new_key = ".".join(vals) 159 | new_params[new_key] = params[key] 160 | 161 | # add the new op 162 | N = n_ops(params) 163 | for key, val in op.items(): 164 | key = f"{0:04d}." + key 165 | new_params[key] = val 166 | 167 | return new_params 168 | 169 | def check_rng_key(key): 170 | if key is None: 171 | raise ValueError("to initialize model weights, must pass an RNG key") 172 | 173 | def load(filename): 174 | out_params = {} 175 | param_count = 0 176 | with np.load(filename) as data: 177 | for key,val in data.items(): 178 | # print(f"mlp layer key: {key}") 179 | # convert numpy to jax arrays 180 | if isinstance(val, np.ndarray): 181 | param_count += val.size 182 | val = jnp.array(val) 183 | out_params[key] = val 184 | print(f"Loaded MLP with {param_count} params") 185 | return out_params 186 | 187 | def save(filename, params): 188 | 189 | np_params = {} # copy to a new dict, we will modify 190 | for key, val in params.items(): 191 | # convert jax to numpy arrays 192 | if isinstance(val, jnp.ndarray): 193 | val = np.array(val) 194 | np_params[key] = val 195 | 196 | np.savez(filename, **np_params) 197 | 198 | 199 | # ===== Listing of layer types and associated functions 200 | # These are populated below, along with the creation functions themselves 201 | 202 | # Initializes array buffers for the functions 203 | initialize_func = {} 204 | 205 | # A list of the keys which need to be optimized during training 206 | opt_params = {} 207 | 208 | # These are 
populated in 'affine_layers' and 'slope_interval_layers', respectively. 209 | # TODO bad software design: need to import affine_layers, etc later for these to get populated 210 | apply_func = { 211 | 'default' : {}, 212 | 'affine' : {}, 213 | 'slope_interval' : {} 214 | } 215 | 216 | # == Dense linear layer 217 | 218 | def dense(in_dim, out_dim, with_bias=True, A=None, b=None): 219 | if(not with_bias and b is not None): 220 | raise ValueError("cannot specify 'b' and 'with_bias=False'") 221 | 222 | # initialize A 223 | if A is None: 224 | # random initialize later 225 | A = (in_dim, out_dim) 226 | else: 227 | # use the input 228 | A = jnp.array(A) 229 | if A.shape != (in_dim,out_dim): 230 | raise ValueError(f"A should have shape ({in_dim},{out_dim}). Has shape {A.shape}.") 231 | 232 | # initialize b 233 | if b is None: 234 | # random initialize later (stays None when with_bias=False) 235 | b = (out_dim,) if with_bias else None 236 | else: 237 | # use the input 238 | b = jnp.array(b) 239 | if b.shape != (out_dim,): 240 | raise ValueError(f"b should have shape ({out_dim}). 
Has shape {b.shape}.") 241 | 242 | subdict = { 243 | "dense.A" : A, 244 | } 245 | 246 | if with_bias: 247 | subdict["dense.b"] = b 248 | 249 | return subdict 250 | 251 | opt_params['dense'] = ['A', 'b'] 252 | 253 | def default_dense(input, A, b=None): 254 | out = jnp.dot(input, A) 255 | if b is not None: 256 | out += b 257 | return out 258 | apply_func['default']['dense'] = default_dense 259 | 260 | def initialize_dense(rngkey=None, A=None, b=None): 261 | if isinstance(A, tuple): # if A needs initialization, it is a tuple giving the size 262 | check_rng_key(rngkey) 263 | subkey, rngkey = jax.random.split(rngkey) 264 | initF = jax.nn.initializers.glorot_normal() 265 | A = initF(subkey, A) 266 | if isinstance(b, tuple): # if b needs initialization, it is a tuple giving the size 267 | check_rng_key(rngkey) 268 | subkey, rngkey = jax.random.split(rngkey) 269 | initF = jax.nn.initializers.normal() 270 | b = initF(subkey, b) 271 | 272 | out_dict = { 'A' : A } 273 | if b is not None: 274 | out_dict['b'] = b 275 | 276 | return out_dict 277 | initialize_func['dense'] = initialize_dense 278 | 279 | 280 | 281 | # == Common activations 282 | 283 | def relu(): 284 | return {"relu._" : jnp.array([])} 285 | def default_relu(input): 286 | return jax.nn.relu(input) 287 | apply_func['default']['relu'] = default_relu 288 | 289 | def elu(): 290 | return {"elu._" : jnp.array([])} 291 | def default_elu(input): 292 | return jax.nn.elu(input) 293 | apply_func['default']['elu'] = default_elu 294 | 295 | 296 | def sin(): 297 | return {"sin._" : jnp.array([])} 298 | def default_sin(input): 299 | return jnp.sin(input) 300 | apply_func['default']['sin'] = default_sin 301 | 302 | # == Positional encoding 303 | 304 | def pow2_frequency_encode(count_pow2, start_pow=0, with_shift=True): 305 | pows = jax.lax.pow(2., jnp.arange(start=start_pow, stop=start_pow+count_pow2, dtype=float)) 306 | coefs = pows * jnp.pi 307 | 308 | if with_shift: 309 | coefs = jnp.repeat(coefs, 2) 310 | shift = 
jnp.zeros_like(coefs) 311 | shift = shift.at[1::2].set(jnp.pi) 312 | return {"pow2_frequency_encode.coefs" : coefs, "pow2_frequency_encode.shift" : shift} 313 | else: 314 | return {"pow2_frequency_encode.coefs" : coefs} 315 | 316 | def default_pow2_frequency_encode(input, coefs, shift=None): 317 | x = input[:,None] * coefs[None,:] 318 | if shift is not None: 319 | x += shift 320 | x = x.flatten() 321 | return x 322 | apply_func['default']['pow2_frequency_encode'] = default_pow2_frequency_encode 323 | 324 | 325 | # == Utility 326 | 327 | 328 | def squeeze_last(): 329 | return {"squeeze_last._" : jnp.array([])} 330 | def default_squeeze_last(input): 331 | return jnp.squeeze(input, axis=0) 332 | apply_func['default']['squeeze_last'] = default_squeeze_last 333 | 334 | # R,t are a transformation for the SHAPE, input points will get the opposite transform 335 | def spatial_transformation(): 336 | return { 337 | "spatial_transformation.R" : jnp.eye(3), 338 | "spatial_transformation.t" : jnp.zeros(3), 339 | } 340 | 341 | 342 | def default_spatial_transformation(input, R, t): 343 | # if the shape transforms by R,t, input points need the opposite transform 344 | R_inv = jnp.linalg.inv(R) 345 | t_inv = jnp.dot(R_inv, -t) 346 | return default_dense(input, A=R_inv, b=t_inv) 347 | apply_func['default']['spatial_transformation'] = default_spatial_transformation 348 | 349 | # TODO bad software design, see note above 350 | import affine_layers 351 | import slope_interval_layers 352 | -------------------------------------------------------------------------------- /src/queries.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from functools import partial 7 | 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | 11 | import utils 12 | import render 13 | import geometry 14 | from bucketing import * 15 | import implicit_function 16 | from implicit_function import 
SIGN_UNKNOWN, SIGN_POSITIVE, SIGN_NEGATIVE 17 | 18 | 19 | # ============================================================= 20 | # ==== Cast Rays 21 | # ============================================================= 22 | 23 | def get_default_cast_opts(): 24 | d = { 25 | 'hit_eps' : 0.001, 26 | 'max_dist' : 10., 27 | 'n_max_step' : 512, 28 | 'n_substeps' : 1, 29 | 'safety_factor' : 0.98, 30 | 'interval_grow_fac' : 1.5, 31 | 'interval_shrink_fac' : 0.5, 32 | 'interval_init_size' : 0.1, # relative, as a factor of max_dist 33 | 'refine_width_fac' : 2., 34 | 'n_side_init' : 16, 35 | } 36 | return d 37 | 38 | 39 | @partial(jax.jit, static_argnames=("funcs_tuple", "n_substeps"), donate_argnums=(5,6,8,9,10,11,12)) 40 | def cast_rays_iter(funcs_tuple, params_tuple, n_substeps, curr_roots, curr_dirs, curr_t, curr_int_size, curr_inds, curr_valid, curr_count, out_t, out_hit_id, out_count, opts): 41 | INVALID_IND = out_t.shape[0]+1 42 | 43 | # Evaluate the function 44 | def take_step(substep_ind, in_tup): 45 | root, dir, t, step_size, is_hit, hit_id, step_count = in_tup 46 | 47 | can_step = ~is_hit # this ensures that if we hit on a previous substep, we don't keep stepping 48 | 49 | step_count += ~is_hit # count another step for this ray (unless converged) 50 | 51 | # loop over all functions in the list 52 | func_id = 1 53 | for func, params in zip(funcs_tuple, params_tuple): 54 | 55 | pos_start = root + t * dir 56 | half_vec = 0.5 * step_size * dir 57 | pos_mid = pos_start + half_vec 58 | box_type = func.classify_general_box(params, pos_mid, half_vec[None,:]) 59 | 60 | # test if the step is safe 61 | can_step = jnp.logical_and( 62 | can_step, 63 | jnp.logical_or(box_type == SIGN_POSITIVE, box_type == SIGN_NEGATIVE) 64 | ) 65 | 66 | # For convergence testing, sample the function value at the start of the interval and start + eps 67 | pos_start = root + t * dir 68 | pos_eps = root + (t + opts['hit_eps']) * dir 69 | val_start = func(params, pos_start) 70 | val_eps = func(params, pos_eps) 
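As a standalone illustration (not part of the repository), the sign-change convergence test used above can be reproduced in isolation: a ray is declared a hit when the implicit function changes sign between t and t + hit_eps along the ray. The `sphere_sdf` stand-in below is an assumption for the sketch, replacing the MLP implicit function `func(params, pos)`.

```python
import numpy as np

def sphere_sdf(p, radius=1.0):
    # signed distance to a sphere at the origin; an illustrative
    # stand-in for the implicit function `func(params, pos)`
    return np.linalg.norm(p) - radius

def sign_change_hit(root, direction, t, hit_eps=1e-3):
    # mirrors the convergence test above: compare signs at t and t + hit_eps
    val_start = sphere_sdf(root + t * direction)
    val_eps = sphere_sdf(root + (t + hit_eps) * direction)
    return np.sign(val_start) != np.sign(val_eps)

root = np.array([0.0, 0.0, -2.0])
direction = np.array([0.0, 0.0, 1.0])

# the ray crosses the unit sphere at t = 1.0, so the test is False
# well before the surface and True when the interval straddles the crossing
assert not sign_change_hit(root, direction, t=0.5)
assert sign_change_hit(root, direction, t=0.9995)
```

Note that this test only detects a crossing within a hit_eps-sized window, which is why the stepping logic above never advances more than a safe step at a time.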
71 | 72 | # Check if we converged for this func 73 | this_is_hit = jnp.sign(val_start) != jnp.sign(val_eps) 74 | hit_id = jnp.where(this_is_hit, func_id, hit_id) 75 | is_hit = jnp.logical_or(is_hit, this_is_hit) 76 | 77 | func_id += 1 78 | 79 | # take a full step of step_size if it was safe, but even if not we still inch forward 80 | # (this matches our convergence/progress guarantee) 81 | this_step_size = jnp.where(can_step, step_size, opts['hit_eps']) 82 | 83 | # take the actual step (unless a previous substep hit, in which case we do nothing) 84 | t = jnp.where(is_hit, t, t + this_step_size * opts['safety_factor']) 85 | 86 | # update the step size 87 | step_size = jnp.where(can_step, 88 | step_size * opts['interval_grow_fac'], 89 | step_size * opts['interval_shrink_fac']) 90 | step_size = jnp.clip(step_size, a_min=opts['hit_eps']) 91 | 92 | return (root, dir, t, step_size, is_hit, hit_id, step_count) 93 | 94 | # substepping 95 | def take_several_steps(root, dir, t, step_size): 96 | 97 | # Perform some substeps 98 | is_hit = False 99 | hit_id = 0 100 | step_count = 0 101 | in_tup = (root, dir, t, step_size, is_hit, hit_id, step_count) 102 | 103 | out_tup = jax.lax.fori_loop(0, n_substeps, take_step, in_tup) 104 | 105 | _, _, t, step_size, is_hit, hit_id, step_count = out_tup 106 | 107 | return t, step_size, is_hit, hit_id, step_count 108 | 109 | # evaluate the substeps on all rays 110 | curr_t, curr_int_size, is_hit, hit_id, num_inner_steps \ 111 | = jax.jit(jax.vmap(take_several_steps))(curr_roots, curr_dirs, curr_t, curr_int_size) 112 | 113 | curr_count += curr_valid * num_inner_steps 114 | 115 | # Test convergence 116 | is_miss = curr_t > opts['max_dist'] 117 | is_count_terminate = curr_count >= opts['n_max_step'] 118 | terminated = jnp.logical_and( 119 | jnp.logical_or(jnp.logical_or(is_hit, is_miss), is_count_terminate), 120 | curr_valid) 121 | 122 | # Write out finished rays 123 | write_inds = jnp.where(terminated, curr_inds, INVALID_IND) 124 | out_t = 
out_t.at[write_inds].set(curr_t, mode='drop') 125 | out_hit_id = out_hit_id.at[write_inds].set(hit_id, mode='drop') 126 | out_count = out_count.at[write_inds].set(curr_count, mode='drop') 127 | 128 | # Finished rays are no longer valid 129 | curr_valid = jnp.logical_and(curr_valid, ~terminated) 130 | N_valid = curr_valid.sum() 131 | 132 | return curr_t, curr_int_size, curr_valid, curr_count, out_t, out_hit_id, out_count, N_valid 133 | 134 | def cast_rays(funcs_tuple, params_tuple, roots, dirs, opts): 135 | 136 | N = roots.shape[0] 137 | N_evals = 0 # all of the evaluations, INCLUDING those performed on unused padded array elements 138 | n_substeps = opts['n_substeps'] 139 | 140 | # Outputs go here 141 | out_t = jnp.zeros(N) 142 | out_hit_id = jnp.zeros(N, dtype=int) 143 | 144 | # Working data (we will shrink this as the algorithm proceeds and rays start terminating) 145 | curr_roots = roots 146 | curr_dirs = dirs 147 | curr_t = jnp.zeros(N) 148 | curr_int_size = jnp.ones(N) * opts['interval_init_size'] * opts['max_dist'] 149 | # curr_int_size = None # TODO don't technically need this for SDFs 150 | curr_inds = jnp.arange(0, N, dtype=int) # which original ray this working ray corresponds to 151 | curr_valid = jnp.ones(N, dtype=bool) # a mask of rays which are actually valid, in-progress rays 152 | 153 | # Also track number of evaluations 154 | out_count = jnp.zeros(N, dtype=int) 155 | curr_count = jnp.zeros(N, dtype=int) 156 | 157 | iter = 0 158 | while(True): 159 | 160 | iter += 1 161 | curr_t, curr_int_size, curr_valid, curr_count, out_t, out_hit_id, out_count, N_valid \ 162 | = cast_rays_iter(funcs_tuple, params_tuple, n_substeps, curr_roots, curr_dirs, curr_t, curr_int_size, curr_inds, curr_valid, \ 163 | curr_count, out_t, out_hit_id, out_count, opts) 164 | N_evals += curr_t.shape[0] * n_substeps 165 | 166 | N_valid = int(N_valid) 167 | if N_valid == 0: 168 | break 169 | 170 | if fits_in_smaller_bucket(N_valid, curr_valid.shape[0]): 171 | new_bucket_size = 
get_next_bucket_size(N_valid) 172 | curr_valid, empty_start, curr_roots, curr_dirs, curr_t, curr_int_size, curr_inds, curr_count = \ 173 | compactify_and_rebucket_arrays(curr_valid, new_bucket_size, curr_roots, curr_dirs, curr_t, curr_int_size, curr_inds, curr_count) 174 | 175 | return out_t, out_hit_id, out_count, N_evals 176 | 177 | 178 | @partial(jax.jit, static_argnames=("funcs_tuple", "n_substeps"), donate_argnums=(5,7,8,9,10,11,12,13,14)) 179 | def cast_rays_frustum_iter( 180 | funcs_tuple, params_tuple, cam_params, iter, n_substeps, 181 | curr_valid, 182 | curr_frust_range, 183 | curr_frust_t, 184 | curr_frust_int_size, 185 | curr_frust_count, 186 | finished_frust_range, 187 | finished_frust_t, 188 | finished_frust_hit_id, 189 | finished_frust_count, 190 | finished_start_ind, 191 | opts): 192 | 193 | 194 | N = finished_frust_range.shape[0]+1 195 | INVALID_IND = N+1 196 | 197 | root_pos, look_dir, up_dir, left_dir, fov_x, fov_y, res_x, res_y = cam_params 198 | gen_cam_ray = partial(render.camera_ray, look_dir, up_dir, left_dir, fov_x, fov_y) 199 | 200 | 201 | # x/y should be integer coordinates on [0,res], they are converted to angles internally 202 | def take_step(ray_xu_yu, ray_xu_yl, ray_xl_yu, ray_xl_yl, mid_ray, expand_fac, is_single_pixel, substep_ind, in_tup): 203 | 204 | t, step_size, is_hit, hit_id, step_demands_subd, step_count = in_tup 205 | t_upper = t + step_size 206 | t_upper_adj = t_upper*expand_fac 207 | 208 | # Construct the rectangular (but not-axis-aligned) box enclosing the frustum 209 | right_front = (ray_xu_yu - ray_xl_yu) * t_upper_adj / 2 210 | up_front = (ray_xu_yu - ray_xu_yl) * t_upper_adj / 2 211 | source_range = jnp.stack((right_front, up_front), axis=0) 212 | 213 | can_step = ~is_hit # this ensures that if we hit on a previous substep, we don't keep stepping 214 | step_count += ~is_hit # count another step for this ray (unless converged) 215 | 216 | center_mid = root_pos + 0.5 * (t + t_upper_adj) * mid_ray 217 | center_vec = 0.5 * 
(t_upper_adj - t) * mid_ray 218 | box_vecs = jnp.concatenate((center_vec[None,:], source_range), axis=0) 219 | 220 | # loop over all functions in the list 221 | func_id = 1 222 | for func, params in zip(funcs_tuple, params_tuple): 223 | 224 | # Perform the actual interval test 225 | box_type = func.classify_general_box(params, center_mid, box_vecs) 226 | 227 | # test if the step is safe 228 | can_step = jnp.logical_and( 229 | can_step, 230 | jnp.logical_or(box_type == SIGN_POSITIVE, box_type == SIGN_NEGATIVE) 231 | ) 232 | 233 | 234 | # For convergence testing, sample the function value at the start of the interval and start + eps 235 | # (this is only relevant/useful once the frustum is a single ray and we start testing hits) 236 | pos_start = root_pos + t * mid_ray 237 | pos_eps = root_pos + (t + opts['hit_eps']) * mid_ray 238 | val_start = func(params, pos_start) 239 | val_eps = func(params, pos_eps) 240 | 241 | # Check if we converged for this func 242 | # (this is only relevant/useful once the frustum is a single ray and we start testing hits) 243 | this_is_hit = jnp.sign(val_start) != jnp.sign(val_eps) 244 | hit_id = jnp.where(this_is_hit, func_id, hit_id) 245 | is_hit = jnp.logical_or(is_hit, this_is_hit) 246 | 247 | func_id += 1 248 | 249 | # take a full step of step_size if it was safe, but even if not we still inch forward 250 | # the is_single_pixel ensures that we only inch forward for single-pixel rays, we can't 251 | # be sure it's safe to do so for larger frusta. 
252 | # (this matches our convergence/progress guarantee) 253 | this_step_size = jnp.where(can_step, step_size, opts['hit_eps'] * is_single_pixel) 254 | 255 | # take the actual step (unless a previous substep hit, in which case we do nothing) 256 | t = jnp.where(is_hit, t, t + this_step_size * opts['safety_factor']) 257 | 258 | # update the step size 259 | step_size = jnp.where(can_step, 260 | step_size * opts['interval_grow_fac'], 261 | step_size * opts['interval_shrink_fac']) 262 | 263 | step_demands_subd = utils.logical_or_all((step_demands_subd, step_size < opts['hit_eps'], is_hit)) 264 | 265 | 266 | step_size = jnp.clip(step_size, a_min=opts['hit_eps']) 267 | 268 | return t, step_size, is_hit, hit_id, step_demands_subd, step_count 269 | 270 | 271 | # substepping 272 | def take_several_steps(frust_range, t, step_size): 273 | 274 | # Do all of the frustum geometry calculation here. It doesn't change 275 | # per-substep, so might as well compute it before we start substepping. 276 | 277 | x_lower = frust_range[0] 278 | x_upper = frust_range[2] 279 | y_lower = frust_range[1] 280 | y_upper = frust_range[3] 281 | is_single_pixel = jnp.logical_and(x_lower+1==x_upper, y_lower+1==y_upper) 282 | 283 | # compute bounds as coords on [-1,1] 284 | # TODO it would be awesome to handle whole-pixel frustums and get a guarantee 285 | # about not leaking/aliasing geometry. However, in some cases the bounds cannot make 286 | # progress even for a single-pixel-sized frustum, and get stuck. We would need to handle 287 | # sub-pixel frustums to guarantee progress, which we do not currently support. For 288 | # this reason we treat each pixel as a point sample, and build frustums around those 289 | # instead. The difference is the -1 on the upper coords here. 290 | xc_lower = 2.* (x_lower ) / (res_x+1.) - 1. 291 | xc_upper = 2.* (x_upper-1) / (res_x+1.) - 1. 292 | yc_lower = 2.* (y_lower ) / (res_y+1.) - 1. 293 | yc_upper = 2.* (y_upper-1) / (res_y+1.) - 1. 
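For illustration only (not part of the repository), the pixel-to-[-1,1] mapping above can be checked in isolation. The -1 on the upper coordinate is what makes a single-pixel frustum collapse to a point sample, as the comment above describes; `frust_coords` is a hypothetical helper name for this sketch.

```python
def frust_coords(x_lower, x_upper, res_x):
    # same mapping as the code above: pixel bounds -> coords on [-1,1],
    # with the -1 on the upper bound treating each pixel as a point sample
    xc_lower = 2. * (x_lower) / (res_x + 1.) - 1.
    xc_upper = 2. * (x_upper - 1) / (res_x + 1.) - 1.
    return xc_lower, xc_upper

# a single-pixel frustum (x_upper == x_lower + 1) collapses to one coordinate
lo, hi = frust_coords(4, 5, res_x=16)
assert lo == hi

# a wider frustum spans a proper interval
lo, hi = frust_coords(0, 8, res_x=16)
assert lo < hi
```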
294 | 295 | # generate rays corresponding to the four corners of the frustum 296 | ray_xu_yu = gen_cam_ray(xc_upper, yc_upper) 297 | ray_xl_yu = gen_cam_ray(xc_lower, yc_upper) 298 | ray_xu_yl = gen_cam_ray(xc_upper, yc_lower) 299 | ray_xl_yl = gen_cam_ray(xc_lower, yc_lower) 300 | 301 | # a ray down the center of the frustum 302 | mid_ray = 0.5 * (ray_xu_yu + ray_xl_yl) 303 | mid_ray_len = geometry.norm(mid_ray) 304 | mid_ray = mid_ray / mid_ray_len 305 | 306 | # Expand the box by a factor of 1/cos(theta/2) to account for the fact that the spherical frustum extends a little beyond the naive linearly interpolated box. 307 | expand_fac = 1. / mid_ray_len 308 | 309 | # Perform some substeps 310 | is_hit = False 311 | hit_id = 0 312 | step_count = 0 313 | step_demands_subd = False 314 | in_tup = (t, step_size, is_hit, hit_id, step_demands_subd, step_count) 315 | 316 | take_step_this = partial(take_step, ray_xu_yu, ray_xu_yl, ray_xl_yu, ray_xl_yl, mid_ray, expand_fac, is_single_pixel) 317 | out_tup = jax.lax.fori_loop(0, n_substeps, take_step_this, in_tup) 318 | 319 | t, step_size, is_hit, hit_id, step_demands_subd, step_count = out_tup 320 | return t, step_size, is_hit, hit_id, step_demands_subd, step_count 321 | 322 | # evaluate the substeps on all rays 323 | curr_frust_t, curr_frust_int_size, is_hit, hit_id, step_demands_subd, num_inner_steps \ 324 | = jax.jit(jax.vmap(take_several_steps))(curr_frust_range, curr_frust_t, curr_frust_int_size) 325 | 326 | # Measure frustum area in pixels, use it to track counts 327 | x_lower = curr_frust_range[:,0] 328 | x_upper = curr_frust_range[:,2] 329 | y_lower = curr_frust_range[:,1] 330 | y_upper = curr_frust_range[:,3] 331 | frust_area = (x_upper-x_lower)*(y_upper-y_lower) 332 | curr_frust_count += curr_valid * num_inner_steps * (1.
/ frust_area) 333 | 334 | # only size-1 frusta actually get to hit 335 | is_hit = jnp.logical_and(is_hit, frust_area == 1) 336 | 337 | is_miss = curr_frust_t > opts['max_dist'] 338 | is_count_terminate = iter >= opts['n_max_step'] 339 | terminated = jnp.logical_and(jnp.logical_or(jnp.logical_or(is_hit, is_miss), is_count_terminate), curr_valid) 340 | 341 | # Write out finished rays 342 | target_inds = utils.enumerate_mask(terminated, fill_value=INVALID_IND) + finished_start_ind 343 | finished_frust_range = finished_frust_range.at[target_inds,:].set(curr_frust_range, mode='drop') 344 | finished_frust_t = finished_frust_t.at[target_inds].set(curr_frust_t, mode='drop') 345 | finished_frust_hit_id = finished_frust_hit_id.at[target_inds].set(hit_id, mode='drop') 346 | finished_frust_count = finished_frust_count.at[target_inds].set(curr_frust_count, mode='drop') 347 | curr_valid = jnp.logical_and(curr_valid, ~terminated) 348 | finished_start_ind += jnp.sum(terminated) 349 | 350 | 351 | # Identify rays that need to be split 352 | # TODO some arithmetic repeated with the function raycast 353 | width_x = 2*jnp.sin(jnp.deg2rad(fov_x)/2 * (x_upper - x_lower) / res_x)*curr_frust_t 354 | width_y = 2*jnp.sin(jnp.deg2rad(fov_y)/2 * (y_upper - y_lower) / res_y)*curr_frust_t 355 | can_subd = jnp.logical_or(curr_frust_range[:,2] > (curr_frust_range[:,0]+1), curr_frust_range[:,3] > (curr_frust_range[:,1]+1)) 356 | needs_refine = utils.logical_or_all((width_x > opts['refine_width_fac']*curr_frust_int_size, 357 | width_y > opts['refine_width_fac']*curr_frust_int_size, 358 | step_demands_subd)) # last condition ensures that rays which hit but still need subd always get it 359 | needs_refine = jnp.logical_and(needs_refine, can_subd) 360 | needs_refine = jnp.logical_and(needs_refine, curr_valid) 361 | 362 | N_needs_refine = jnp.sum(needs_refine) 363 | N_valid = jnp.sum(curr_valid) 364 | 365 | return curr_valid, curr_frust_t, curr_frust_int_size, curr_frust_count, needs_refine, \ 366 |
finished_frust_range, finished_frust_t, finished_frust_hit_id, finished_frust_count, finished_start_ind, N_valid, N_needs_refine 367 | 368 | 369 | # For all frusta specified by sub_mask, split to be half the size along one axis (chosen automatically internally). 370 | # Creates sum(sub_mask) new frusta entries, in addition to updating the existing subd entries, all with half the size. 371 | # All entries specified by sub_mask MUST have index width >1 along one dimension. 372 | # Precondition: there must be space in the arrays to hold the new elements. This routine at most 373 | # doubles the size, therefore this requires frust.shape[0]-empty_start_ind > (2*sum(sub_mask)) 374 | @partial(jax.jit, donate_argnums=(2,3,4)) 375 | def subdivide_frusta(sub_mask, empty_start_ind, valid_mask, curr_frust_range, arrs): 376 | 377 | # curr_frust_t, curr_frust_int_size): 378 | N = sub_mask.shape[-1] 379 | INVALID_IND = N+1 380 | 381 | # TODO should probably just assume this 382 | sub_mask = jnp.logical_and(sub_mask, valid_mask) 383 | 384 | # Pick which direction to subdivide in 385 | x_gap = curr_frust_range[:,2] - curr_frust_range[:,0] 386 | y_gap = curr_frust_range[:,3] - curr_frust_range[:,1] 387 | # assumption: one of these gaps will always be nonempty 388 | subd_x = x_gap >= y_gap 389 | 390 | # Generate the new frustums (call the two of them 'A' and 'B') 391 | # (for the sake of vectorization, we generate these at all frusta, but will only use them at the ones which are actually being split) 392 | x_mid = (curr_frust_range[:,0] + curr_frust_range[:,2]) / 2 393 | y_mid = (curr_frust_range[:,1] + curr_frust_range[:,3]) / 2 394 | split_x_hi_A = jnp.where( subd_x, x_mid, curr_frust_range[:,2]) 395 | split_x_lo_B = jnp.where( subd_x, x_mid, curr_frust_range[:,0]) 396 | split_y_hi_A = jnp.where(~subd_x, y_mid, curr_frust_range[:,3]) 397 | split_y_lo_B = jnp.where(~subd_x, y_mid, curr_frust_range[:,1]) 398 | frust_range_A = curr_frust_range 399 | frust_range_A =
frust_range_A.at[:,2].set(split_x_hi_A) 400 | frust_range_A = frust_range_A.at[:,3].set(split_y_hi_A) 401 | frust_range_B = curr_frust_range 402 | frust_range_B = frust_range_B.at[:,0].set(split_x_lo_B) 403 | frust_range_B = frust_range_B.at[:,1].set(split_y_lo_B) 404 | 405 | arrs_out = arrs # initially this is just a copy (since B arrays inherit all the same data) 406 | 407 | # Overwrite the new A frustum on to the original entry 408 | overwrite_A = sub_mask 409 | curr_frust_range = jnp.where(overwrite_A[:,None], frust_range_A, curr_frust_range) 410 | # curr_frust_t = jnp.where(overwrite_A, frust_t_A, curr_frust_t) # optimization: this is a no-op 411 | # curr_frust_int_size = jnp.where(overwrite_A, frust_int_size_A, curr_frust_int_size) # optimization: this is a no-op 412 | 413 | # Compactify the new B entries, then roll them to their starting position in the new array 414 | compact_inds = jnp.nonzero(sub_mask, size=N, fill_value=INVALID_IND)[0] 415 | frust_range_B = frust_range_B.at[compact_inds,:].get(mode='drop') 416 | frust_range_B = jnp.roll(frust_range_B, empty_start_ind, axis=0) 417 | 418 | # Prep data arrays corresponding to all the B frusta 419 | arrs_B = [] 420 | for a in arrs: 421 | a = a.at[compact_inds,...].get(mode='drop') 422 | a = jnp.roll(a, empty_start_ind, axis=0) 423 | arrs_B.append(a) 424 | overwrite_B = jnp.roll(compact_inds < INVALID_IND, empty_start_ind) 425 | # print(f"overwrite_B:\n{overwrite_B}") 426 | 427 | # Overwrite 428 | curr_frust_range = jnp.where(overwrite_B[:,None], frust_range_B, curr_frust_range) 429 | for i in range(len(arrs_out)): 430 | arrs_out[i] = jnp.where(overwrite_B, arrs_B[i], arrs_out[i]) 431 | valid_mask = jnp.logical_or(valid_mask, overwrite_B) 432 | 433 | return valid_mask, curr_frust_range, arrs_out 434 | 435 | @jax.jit 436 | def frustum_needs_subdiv_to_pixel(frust_valid, frust_range): 437 | is_single_pixel = jnp.logical_and((frust_range[:,0]+1) == frust_range[:,2], (frust_range[:,1]+1) == frust_range[:,3]) 438 | 
needs_subd = jnp.logical_and(frust_valid, ~is_single_pixel) 439 | return needs_subd, jnp.any(needs_subd) 440 | 441 | @partial(jax.jit, static_argnames=("res_x", "res_y")) 442 | def write_frust_output(res_x, res_y, finished_frust_range, finished_frust_t, finished_frust_count, finished_frust_hit_id): 443 | 444 | ## (3) Write the result (one pixel per frustum) 445 | out_t = jnp.zeros((res_x, res_y)) 446 | out_hit_id = jnp.zeros((res_x, res_y), dtype=int) 447 | out_count = jnp.zeros((res_x, res_y), dtype=int) 448 | 449 | # not needed, all are valid 450 | # INVALID_IND = res_x+res_y+1 451 | x_coords = finished_frust_range[:,0] 452 | y_coords = finished_frust_range[:,1] 453 | 454 | at_args = {'mode' : 'promise_in_bounds', 'unique_indices' : True} 455 | out_t = out_t.at[x_coords, y_coords].set(finished_frust_t, **at_args) 456 | out_count = out_count.at[x_coords, y_coords].set(finished_frust_count, **at_args) 457 | out_hit_id = out_hit_id.at[x_coords, y_coords].set(finished_frust_hit_id, **at_args) 458 | 459 | return out_t, out_count, out_hit_id 460 | 461 | def cast_rays_frustum(funcs_tuple, params_tuple, cam_params, in_opts): 462 | 463 | root_pos, look_dir, up_dir, left_dir, fov_x, fov_y, res_x, res_y = cam_params 464 | 465 | # make sure everything is on the device 466 | cam_params = tuple([jnp.array(x) for x in cam_params]) 467 | opts = {} 468 | for k,v in in_opts.items(): 469 | if k == 'n_substeps': 470 | n_substeps = v 471 | else: 472 | opts[k] = jnp.array(v) 473 | 474 | N_out = res_x*res_y 475 | 476 | ## Steps: 477 | ## (1) March the frusta forward 478 | ## (1a) Take a step 479 | ## (1b) Split any frusta that need it 480 | ## (2) Once all frusta have terminated, subdivide any that need it until they are a single pixel 481 | ## (3) Write the result (one pixel per frustum) 482 | 483 | # TODO think about "subpixel" accuracy in this. Technically, we can guarantee that tiny points 484 | # never slip between rays.
485 | 486 | do_viz = False 487 | 488 | ## Construct the initial frustums 489 | N_side_init = opts['n_side_init'] 490 | N_init_frust = N_side_init**2 491 | N_evals = 0 492 | 493 | # This creates a grid of tiles N_side_init x N_side_init 494 | # (assumption: N_side_init <= res) 495 | x_ticks = jnp.linspace(start=0,stop=res_x,num=N_side_init+1,dtype=int) 496 | y_ticks = jnp.linspace(start=0,stop=res_y,num=N_side_init+1,dtype=int) 497 | x_start = jnp.tile(x_ticks[:-1], N_side_init) 498 | x_end = jnp.tile(x_ticks[1:], N_side_init) 499 | y_start = jnp.repeat(y_ticks[:-1], N_side_init) 500 | y_end = jnp.repeat(y_ticks[1:], N_side_init) 501 | curr_frust_range = jnp.stack((x_start, y_start, x_end, y_end), axis=-1) 502 | 503 | # All the other initial data 504 | curr_frust_t = jnp.zeros(N_init_frust) 505 | curr_frust_int_size = jnp.ones(N_init_frust) * opts['interval_init_size'] * opts['max_dist'] 506 | curr_frust_count = jnp.zeros(N_init_frust) 507 | curr_valid = jnp.ones(N_init_frust, dtype=bool) 508 | empty_start_ind = N_init_frust 509 | 510 | # As the frusta terminate, we will write them to the 'finished' category here. 511 | # Note: the upper bound of `N` is tight here, and we should never need to expand.
512 | finished_frust_range = jnp.zeros((N_out,4),dtype=int) 513 | finished_frust_t = jnp.zeros((N_out,)) 514 | finished_frust_hit_id = jnp.zeros((N_out,), dtype=int) 515 | finished_frust_count = jnp.zeros((N_out,)) 516 | finished_start_ind = 0 517 | 518 | if do_viz: 519 | prev_viz_val = {} 520 | 521 | ## (1) March the frusta forward 522 | iter = 0 523 | N_valid = N_init_frust 524 | while(True): 525 | 526 | # Take a step 527 | N_evals += curr_frust_t.shape[0] 528 | curr_valid, curr_frust_t, curr_frust_int_size, curr_frust_count, needs_refine, \ 529 | finished_frust_range, finished_frust_t, finished_frust_hit_id, finished_frust_count, finished_start_ind, N_valid, N_needs_refine = \ 530 | cast_rays_frustum_iter(funcs_tuple, params_tuple, cam_params, iter, n_substeps, \ 531 | curr_valid, curr_frust_range, curr_frust_t, curr_frust_int_size, curr_frust_count,\ 532 | finished_frust_range, finished_frust_t, finished_frust_hit_id, finished_frust_count, finished_start_ind, \ 533 | opts) 534 | 535 | iter += n_substeps 536 | 537 | N_valid = int(N_valid) 538 | N_needs_refine = int(N_needs_refine) 539 | 540 | if(N_valid == 0): break 541 | 542 | 543 | space_needed = N_valid + N_needs_refine 544 | new_bucket_size = get_next_bucket_size(space_needed) 545 | curr_bucket_size = curr_valid.shape[0] 546 | fits_in_smaller_bucket = new_bucket_size < curr_bucket_size 547 | needs_room_to_subdivide = empty_start_ind + N_needs_refine > curr_bucket_size 548 | if needs_room_to_subdivide or fits_in_smaller_bucket: 549 | 550 | # print(f"** COMPACT AND REBUCKET {curr_bucket_size} --> {new_bucket_size}") 551 | 552 | curr_valid, empty_start_ind, curr_frust_range, curr_frust_t, curr_frust_int_size, curr_frust_count, needs_refine = compactify_and_rebucket_arrays(curr_valid, new_bucket_size, curr_frust_range, curr_frust_t, curr_frust_int_size, curr_frust_count, needs_refine) 553 | empty_start_ind = int(empty_start_ind) 554 | 555 | 556 | # Do the splitting for any rays that need it 557 | curr_valid,
curr_frust_range, [curr_frust_t, curr_frust_int_size, curr_frust_count] = \ 558 | subdivide_frusta(needs_refine, empty_start_ind, curr_valid, curr_frust_range, [curr_frust_t, curr_frust_int_size, curr_frust_count]) 559 | 560 | empty_start_ind += N_needs_refine 561 | 562 | 563 | ## (2) Once all frusta have terminated, subdivide any that need it until they are a single pixel 564 | ## TODO: consider that we could write output using fori loops instead 565 | finished_valid = jnp.arange(finished_frust_t.shape[-1]) < finished_start_ind 566 | 567 | # NOTE: we can alternately compute the number needed manually. Each subdivision round splits an axis in half, so the number of rounds is the max of log_2(width_x) + log_2(width_y). (A quick test showed this didn't help performance) 568 | # i_sub = 0 569 | while(True): 570 | 571 | # Any frustum whose area is greater than 1 572 | needs_subd, any_needs_subd = frustum_needs_subdiv_to_pixel(finished_valid, finished_frust_range) 573 | 574 | if not any_needs_subd: 575 | break 576 | 577 | # Split frusta 578 | finished_valid, finished_frust_range, [finished_frust_t, finished_frust_hit_id, finished_frust_count] = \ 579 | subdivide_frusta(needs_subd, finished_start_ind, finished_valid, finished_frust_range, [finished_frust_t, finished_frust_hit_id, finished_frust_count]) 580 | finished_start_ind += jnp.sum(needs_subd) 581 | 582 | # NOTE: this will always yield exactly N frusta total (one per pixel), so there is no need to resize the 'finished' arrays 583 | 584 | ## (3) Write the result (one pixel per frustum) 585 | out_t, out_count, out_hit_id = write_frust_output(res_x, res_y, finished_frust_range, finished_frust_t, finished_frust_count, finished_frust_hit_id) 586 | 587 | return out_t, out_hit_id, out_count, N_evals 588 | 589 | 590 | 591 | -------------------------------------------------------------------------------- /src/render.py: -------------------------------------------------------------------------------- 1 | import os 2 | import
sys 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import jax.scipy 7 | import numpy as np 8 | from functools import partial 9 | import imageio 10 | 11 | import geometry 12 | import queries 13 | from utils import * 14 | import affine 15 | 16 | # theta_x/y should be in [-1,1] 17 | def camera_ray(look_dir, up_dir, left_dir, fov_deg_x, fov_deg_y, theta_x, theta_y): 18 | ray_image_plane_pos = look_dir \ 19 | + left_dir * (theta_x * jnp.tan(jnp.deg2rad(fov_deg_x)/2)) \ 20 | + up_dir * (theta_y * jnp.tan(jnp.deg2rad(fov_deg_y)/2)) 21 | 22 | ray_dir = geometry.normalize(ray_image_plane_pos) 23 | 24 | return ray_dir 25 | 26 | @partial(jax.jit, static_argnames=("res")) 27 | def generate_camera_rays(eye_pos, look_dir, up_dir, res=1024, fov_deg=30.): 28 | 29 | D = res # image dimension 30 | R = res*res # number of rays 31 | 32 | ## Generate rays according to a pinhole camera 33 | 34 | # Image coords on [-1,1] for each output pixel 35 | cam_ax_x = jnp.linspace(-1., 1., res) 36 | cam_ax_y = jnp.linspace(-1., 1., res) 37 | cam_x, cam_y = jnp.meshgrid(cam_ax_x, cam_ax_y) 38 | cam_x = cam_x.flatten() # [R] 39 | cam_y = cam_y.flatten() # [R] 40 | 41 | # Orthonormal camera frame 42 | up_dir = up_dir - jnp.dot(look_dir, up_dir) * look_dir 43 | up_dir = geometry.normalize(up_dir) 44 | left_dir = jnp.cross(look_dir, up_dir) 45 | 46 | 47 | # ray_roots, ray_dirs = jax.jit(jax.vmap(camera_ray))(cam_x, cam_y) 48 | ray_dirs = jax.vmap(partial(camera_ray, look_dir, up_dir, left_dir, fov_deg, fov_deg))(cam_x, cam_y) 49 | ray_roots = jnp.tile(eye_pos, (ray_dirs.shape[0],1)) 50 | return ray_roots, ray_dirs 51 | 52 | 53 | @partial(jax.jit, static_argnames=("funcs_tuple", "method")) 54 | def outward_normal(funcs_tuple, params_tuple, hit_pos, hit_id, eps, method='finite_differences'): 55 | 56 | grad_out = jnp.zeros(3) 57 | i_func = 1 58 | for func, params in zip(funcs_tuple, params_tuple): 59 | f = partial(func, params) 60 | 61 | if method == 'autodiff': 62 | grad_f = jax.jacfwd(f) 63 | grad =
grad_f(hit_pos) 64 | 65 | elif method == 'finite_differences': 66 | # 'tetrahedron' central differences approximation 67 | # see e.g. https://www.iquilezles.org/www/articles/normalsSDF/normalsSDF.htm 68 | offsets = jnp.array(( 69 | (+eps, -eps, -eps), 70 | (-eps, -eps, +eps), 71 | (-eps, +eps, -eps), 72 | (+eps, +eps, +eps), 73 | )) 74 | x_pts = hit_pos[None,:] + offsets 75 | samples = jax.vmap(f)(x_pts) 76 | grad = jnp.sum(offsets * samples[:,None], axis=0) 77 | 78 | else: 79 | raise ValueError("unrecognized method") 80 | 81 | grad = geometry.normalize(grad) 82 | grad_out = jnp.where(hit_id == i_func, grad, grad_out) 83 | i_func += 1 84 | 85 | return grad_out 86 | 87 | @partial(jax.jit, static_argnames=("funcs_tuple", "method")) 88 | def outward_normals(funcs_tuple, params_tuple, hit_pos, hit_ids, eps, method='finite_differences'): 89 | this_normal_one = lambda p, id : outward_normal(funcs_tuple, params_tuple, p, id, eps, method=method) 90 | return jax.vmap(this_normal_one)(hit_pos, hit_ids) 91 | 92 | 93 | # @partial(jax.jit, static_argnames=("func","res")) 94 | def render_image(funcs_tuple, params_tuple, eye_pos, look_dir, up_dir, left_dir, res, fov_deg, frustum, opts, shading="normal", shading_color_tuple=((0.157,0.613,1.000)), matcaps=None, tonemap=False, shading_color_func=None): 95 | 96 | # make sure inputs are tuples not lists (jit static args can't hash lists) 97 | if isinstance(funcs_tuple, list): funcs_tuple = tuple(funcs_tuple) 98 | if isinstance(params_tuple, list): params_tuple = tuple(params_tuple) 99 | if isinstance(shading_color_tuple, list): shading_color_tuple = tuple(shading_color_tuple) 100 | 101 | # wrap in tuples if single was passed 102 | if not isinstance(funcs_tuple, tuple): 103 | funcs_tuple = (funcs_tuple,) 104 | if not isinstance(params_tuple, tuple): 105 | params_tuple = (params_tuple,) 106 | if not isinstance(shading_color_tuple[0], tuple): 107 | shading_color_tuple = (shading_color_tuple,) 108 | 109 | L = len(funcs_tuple) 110 | if (len(params_tuple)
!= L) or (len(shading_color_tuple) != L): 111 | raise ValueError("render_image tuple arguments should all be same length") 112 | 113 | ray_roots, ray_dirs = generate_camera_rays(eye_pos, look_dir, up_dir, res=res, fov_deg=fov_deg) 114 | if frustum: 115 | # == Frustum raycasting 116 | 117 | cam_params = eye_pos, look_dir, up_dir, left_dir, fov_deg, fov_deg, res, res 118 | 119 | with Timer("frustum raycast"): 120 | t_raycast, hit_ids, counts, n_eval = queries.cast_rays_frustum(funcs_tuple, params_tuple, cam_params, opts) 121 | t_raycast.block_until_ready() 122 | 123 | # TODO transposes here due to image layout conventions. can we get rid of them? 124 | t_raycast = t_raycast.transpose().flatten() 125 | hit_ids = hit_ids.transpose().flatten() 126 | counts = counts.transpose().flatten() 127 | 128 | else: 129 | # == Standard raycasting 130 | with Timer("raycast"): 131 | t_raycast, hit_ids, counts, n_eval = queries.cast_rays(funcs_tuple, params_tuple, ray_roots, ray_dirs, opts) 132 | t_raycast.block_until_ready() 133 | 134 | 135 | hit_pos = ray_roots + t_raycast[:,np.newaxis] * ray_dirs 136 | hit_normals = outward_normals(funcs_tuple, params_tuple, hit_pos, hit_ids, opts['hit_eps']) 137 | hit_color = shade_image(shading, ray_dirs, hit_pos, hit_normals, hit_ids, up_dir, matcaps, shading_color_tuple, shading_color_func=shading_color_func) 138 | 139 | img = jnp.where(hit_ids[:,np.newaxis], hit_color, jnp.ones((res*res, 3))) 140 | 141 | if tonemap: 142 | # We intentionally tonemap before compositing in the shadow. Otherwise the white level clips the shadow and gives it a hard edge. 
143 | img = tonemap_image(img) 144 | 145 | img = img.reshape(res,res,3) 146 | depth = t_raycast.reshape(res,res) 147 | counts = counts.reshape(res,res) 148 | hit_ids = hit_ids.reshape(res,res) 149 | 150 | return img, depth, counts, hit_ids, n_eval, -1 151 | 152 | def tonemap_image(img, gamma=2.2, white_level=.75, exposure=1.): 153 | img = img * exposure 154 | num = img * (1.0 + (img / (white_level * white_level))) 155 | den = (1.0 + img) 156 | img = num / den 157 | img = jnp.power(img, 1.0/gamma) 158 | return img 159 | 160 | @partial(jax.jit, static_argnames=("shading", "shading_color_func")) 161 | def shade_image(shading, ray_dirs, hit_pos, hit_normals, hit_ids, up_dir, matcaps, shading_color_tuple, shading_color_func=None): 162 | 163 | # Simple shading 164 | if shading == "normal": 165 | hit_color = (hit_normals + 1.) / 2. # map normals to [0,1] 166 | 167 | elif shading == "matcap_color": 168 | 169 | # compute matcap coordinates 170 | ray_up = jax.vmap(partial(geometry.orthogonal_dir,up_dir))(ray_dirs) 171 | ray_left = jax.vmap(jnp.cross)(ray_dirs, ray_up) 172 | matcap_u = jax.vmap(jnp.dot)(-ray_left, hit_normals) 173 | matcap_v = jax.vmap(jnp.dot)(ray_up, hit_normals) 174 | 175 | # pull inward slightly to avoid indexing off the matcap image 176 | matcap_u *= .98 177 | matcap_v *= .98 178 | 179 | # remap to image indices 180 | matcap_x = (matcap_u + 1.) / 2. * matcaps[0].shape[0] 181 | matcap_y = (-matcap_v + 1.) / 2.
* matcaps[0].shape[1] 182 | matcap_coords = jnp.stack((matcap_x, matcap_y), axis=0) 183 | 184 | def sample_matcap(matcap, coords): 185 | m = lambda X : jax.scipy.ndimage.map_coordinates(X, coords, order=1, mode='nearest') 186 | return jax.vmap(m, in_axes=-1, out_axes=-1)(matcap) 187 | 188 | # fetch values 189 | mat_r = sample_matcap(matcaps[0], matcap_coords) 190 | mat_g = sample_matcap(matcaps[1], matcap_coords) 191 | mat_b = sample_matcap(matcaps[2], matcap_coords) 192 | mat_k = sample_matcap(matcaps[3], matcap_coords) 193 | 194 | # find the appropriate shading color 195 | def get_shade_color(hit_pos, hit_id): 196 | shading_color = jnp.ones(3) 197 | 198 | if shading_color_func is None: 199 | # use the tuple of constant colors 200 | i_func = 1 201 | for c in shading_color_tuple: 202 | shading_color = jnp.where(hit_id == i_func, jnp.array(c), shading_color) 203 | i_func += 1 204 | else: 205 | # look up varying color 206 | shading_color = shading_color_func(hit_pos) 207 | 208 | return shading_color 209 | shading_color = jax.vmap(get_shade_color)(hit_pos, hit_ids) 210 | 211 | c_r, c_g, c_b = shading_color[:,0], shading_color[:,1], shading_color[:,2] 212 | c_k = 1. 
- (c_r + c_b + c_g) 213 | 214 | c_r = c_r[:,None] 215 | c_g = c_g[:,None] 216 | c_b = c_b[:,None] 217 | c_k = c_k[:,None] 218 | 219 | hit_color = c_r*mat_r + c_b*mat_b + c_g*mat_g + c_k*mat_k 220 | 221 | else: 222 | raise RuntimeError("Unrecognized shading parameter") 223 | 224 | return hit_color 225 | 226 | 227 | 228 | # create camera parameters looking in a direction 229 | def look_at(eye_pos, target=None, up_dir='y'): 230 | 231 | if target is None: 232 | target = jnp.array((0., 0., 0.,)) 233 | if up_dir == 'y': 234 | up_dir = jnp.array((0., 1., 0.,)) 235 | elif up_dir == 'z': 236 | up_dir = jnp.array((0., 0., 1.,)) 237 | 238 | look_dir = geometry.normalize(target - eye_pos) 239 | up_dir = geometry.orthogonal_dir(up_dir, look_dir) 240 | left_dir = jnp.cross(look_dir, up_dir) 241 | 242 | return look_dir, up_dir, left_dir 243 | 244 | 245 | def load_matcap(fname_pattern): 246 | 247 | imgs = [] 248 | for c in ['r', 'g', 'b', 'k']: 249 | im = imageio.imread(fname_pattern.format(c)) 250 | im = jnp.array(im) / 256.
251 | imgs.append(im) 252 | 253 | return tuple(imgs) 254 | -------------------------------------------------------------------------------- /src/sdf.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import dataclasses 3 | from dataclasses import dataclass 4 | 5 | import numpy as np 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | 10 | import utils 11 | 12 | import implicit_function 13 | from implicit_function import SIGN_UNKNOWN, SIGN_POSITIVE, SIGN_NEGATIVE 14 | 15 | # === Function wrappers 16 | 17 | class WeakSDFImplicitFunction(implicit_function.ImplicitFunction): 18 | 19 | def __init__(self, sdf_func, lipschitz_bound=1.): 20 | super().__init__("classify-only") 21 | self.sdf_func = sdf_func 22 | self.lipschitz_bound = lipschitz_bound 23 | 24 | def __call__(self, params, x): 25 | return self.sdf_func(params, x) 26 | 27 | # the parent class automatically delegates to this 28 | # def classify_box(self, params, box_lower, box_upper): 29 | # pass 30 | 31 | def classify_general_box(self, params, box_center, box_vecs, offset=0.): 32 | 33 | # compute the radius of the box 34 | rad = jnp.sqrt(jnp.sum(jnp.square(jnp.linalg.norm(box_vecs, axis=-1)), axis=0)) 35 | 36 | d = box_center.shape[-1] 37 | v = box_vecs.shape[-2] 38 | assert box_center.shape == (d,), "bad box_vecs shape" 39 | assert box_vecs.shape == (v,d), "bad box_vecs shape" 40 | 41 | # evaluate the function 42 | val = self.sdf_func(params, box_center) 43 | can_change = jnp.abs(val) - rad * self.lipschitz_bound < 0. 
44 | 45 | # determine the type of the region 46 | output_type = SIGN_UNKNOWN 47 | output_type = jnp.where(jnp.logical_and(~can_change, val > offset), SIGN_POSITIVE, output_type) 48 | output_type = jnp.where(jnp.logical_and(~can_change, val < -offset), SIGN_NEGATIVE, output_type) 49 | 50 | return output_type 51 | -------------------------------------------------------------------------------- /src/slope_interval.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from dataclasses import dataclass 3 | 4 | import numpy as np 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | 9 | import utils 10 | from utils import printarr 11 | import implicit_function 12 | from implicit_function import SIGN_UNKNOWN, SIGN_POSITIVE, SIGN_NEGATIVE 13 | 14 | # === Function wrappers 15 | 16 | class SlopeIntervalImplicitFunction(implicit_function.ImplicitFunction): 17 | 18 | def __init__(self, slope_interval_func): 19 | super().__init__("classify-and-distance") 20 | self.slope_interval_func = slope_interval_func 21 | 22 | def __call__(self, params, x): 23 | return wrap_scalar(partial(self.slope_interval_func, params))(x) 24 | 25 | # the parent class automatically delegates to this 26 | # def classify_box(self, params, box_lower, box_upper): 27 | # pass 28 | 29 | def classify_general_box(self, params, box_center, box_vecs, offset=0.): 30 | 31 | d = box_center.shape[-1] 32 | v = box_vecs.shape[-2] 33 | assert box_center.shape == (d,), "bad box_vecs shape" 34 | assert box_vecs.shape == (v,d), "bad box_vecs shape" 35 | 36 | # evaluate the function 37 | input = coordinates_in_general_box(box_center, box_vecs) 38 | output = self.slope_interval_func(params, input) 39 | 40 | # compute relevant bounds 41 | slope_lower, slope_upper = slope_bounds(output) 42 | may_lower, may_upper = primal_may_contain_bounds(output, slope_lower, slope_upper) 43 | 44 | # determine the type of the region 45 | output_type = SIGN_UNKNOWN 46 | 
output_type = jnp.where(may_lower > offset, SIGN_POSITIVE, output_type) 47 | output_type = jnp.where(may_upper < -offset, SIGN_NEGATIVE, output_type) 48 | # output_type = jnp.where(jnp.logical_and(must_lower < 0, must_upper > 0), SIGN_STRADDLES, output_type) 49 | 50 | 51 | return output_type 52 | 53 | def min_distance_to_zero(self, params, box_center, box_axis_vec, return_source_value=False): 54 | 55 | # evaluate the function 56 | input = coordinates_in_box(box_center-box_axis_vec, box_center+box_axis_vec) 57 | output = self.slope_interval_func(params, input) 58 | raw_primal, _, _ = output 59 | 60 | # compute relevant bounds 61 | slope_lower, slope_upper = slope_bounds(output) 62 | 63 | # flip the sign, so we have only one case to handle (where the function is positive) 64 | # for the slope components, this sign won't matter once we take the abs/max below 65 | mask = raw_primal >= 0 66 | primal = jnp.where(mask, raw_primal, -raw_primal) 67 | 68 | # compute the distance to the crossing 69 | decrease_mag = jnp.maximum(jnp.abs(slope_lower), jnp.abs(slope_upper)) 70 | vec_len = jnp.abs(box_axis_vec) # we're axis-aligned here, so this is just a list of components 71 | min_len = jnp.min(vec_len) 72 | decrease_mag = decrease_mag / vec_len # rescale out of the logical [-1,1] domain to world coords 73 | axis_decrease = jnp.sum(jnp.clip(decrease_mag, a_min=0.), axis=-1) 74 | distance_to_zero = jnp.clip(primal / axis_decrease, a_max=min_len) 75 | distance_to_zero = jnp.where(distance_to_zero == 0, 0., distance_to_zero) # avoid NaN 76 | 77 | if return_source_value: 78 | return raw_primal, distance_to_zero 79 | else: 80 | return distance_to_zero 81 | 82 | 83 | def min_distance_to_zero_in_direction(self, params, source_point, bound_vec, source_range=None, return_source_value=False): 84 | 85 | # construct bounds 86 | # the "forward" direction always comes first 87 | # get the center point of the interval 88 | bound_forward = bound_vec * 0.5 89 | center = source_point +
bound_forward 90 | bound_forward = bound_forward[None,:] # append a dimension to make it (1,d) 91 | if source_range is None: 92 | bound_vecs = bound_forward 93 | else: 94 | bound_vecs = jnp.concatenate((bound_forward, source_range)) 95 | 96 | # evaluate the function for bounds 97 | input = coordinates_in_general_box(center, bound_vecs) 98 | output = self.slope_interval_func(params, input) 99 | 100 | slope_lower, slope_upper = slope_bounds(output) 101 | 102 | if source_range is not None: 103 | 104 | # do another interval evaluation to bound the value at the source along the other vecs 105 | # alternative: could use the function value from a single evaluation then re-use the slope bounds we already have. This requires one less interval function evaluation, but will give looser bounds. 106 | input_source = coordinates_in_general_box(source_point, source_range) 107 | output_source = self.slope_interval_func(params, input_source) 108 | source_slope_lower, source_slope_upper = slope_bounds(output_source) 109 | source_lower, source_upper = primal_may_contain_bounds(output_source, source_slope_lower, source_slope_upper) 110 | 111 | # unify the cases based on whether the function is positive or negative, so we only need to handle one case 112 | # (note that if the interval contains zero this produces nonsense; we handle that below) 113 | is_pos = source_lower >= 0. 114 | bound_vec_len = jnp.linalg.norm(bound_vec) 115 | val = jnp.where(is_pos, source_lower, -source_upper) 116 | slope = jnp.where(is_pos, slope_lower[0], -slope_upper[0]) 117 | 118 | # compute the distance 119 | slope = 2. * slope / bound_vec_len # remap slope from abstract [-1,1] domain to world 120 | biggest_decrease = jnp.clip(-slope, a_min=0.)
121 | distance_to_zero = val / biggest_decrease 122 | distance_to_zero = jnp.clip(distance_to_zero, a_max=bound_vec_len) 123 | 124 | # if the source could possibly be zero, we can't make any progress 125 | source_contains_zero = jnp.logical_and(source_lower <= 0, source_upper >= 0) 126 | distance_to_zero = jnp.where(source_contains_zero, 0., distance_to_zero) 127 | 128 | 129 | if return_source_value: 130 | return source_lower, source_upper, distance_to_zero 131 | else: 132 | return distance_to_zero 133 | 134 | else: # source range is None, this is just a ray 135 | 136 | # evaluate the function at the source 137 | source_val = self(params, source_point) 138 | 139 | # unify the cases based on whether the function is positive or negative, so we only need to handle one case 140 | is_pos = source_val >= 0. 141 | bound_vec_len = jnp.linalg.norm(bound_vec) 142 | val = jnp.abs(source_val) 143 | slope = jnp.where(is_pos, slope_lower[0], -slope_upper[0]) 144 | 145 | # compute the distance 146 | slope = 2. * slope / bound_vec_len # remap slope from abstract [-1,1] domain to world 147 | biggest_decrease = jnp.clip(-slope, a_min=0.) 148 | distance_to_zero = val / biggest_decrease 149 | distance_to_zero = jnp.clip(distance_to_zero, a_max=bound_vec_len) 150 | 151 | # avoid a NaN 152 | distance_to_zero = jnp.where(source_val==0., 0., distance_to_zero) 153 | 154 | # simple ray case 155 | if return_source_value: 156 | return source_val, distance_to_zero 157 | else: 158 | return distance_to_zero 159 | 160 | 161 | 162 | # === Slope interval utilities 163 | 164 | def is_const(input): 165 | primal, slope_center, slope_width = input 166 | return slope_center is None 167 | 168 | # Compute the 'radius' (width of the approximation) 169 | # NOTE: this is still in logical [-1,+1] coords 170 | def slope_radius(input): 171 | if is_const(input): return 0. 
172 | primal, slope_center, slope_width = input 173 | return slope_width 174 | 175 | # Construct slope interval inputs for the coordinates in a k-dimensional box 176 | # lower,upper should be vectors of length-k 177 | def coordinates_in_box(lower, upper): 178 | center = 0.5 * (lower+upper) 179 | vec = upper - center 180 | axis_vecs = jnp.diag(vec) 181 | return coordinates_in_general_box(center, axis_vecs) 182 | 183 | 184 | # Construct slope interval inputs for the coordinates in a k-dimensional box, 185 | # which is not necessarily axis-aligned 186 | # - center is the center of the box 187 | # - vecs is a length-D list of vectors which point from the center of the box to its 188 | # edges. 189 | # Note that this is where we remap derivatives to the logical domain [-1,+1] 190 | # (this function is nearly a no-op, but giving it this name makes it easier to 191 | # reason about) 192 | def coordinates_in_general_box(center, vecs): 193 | primal = center 194 | assert center.shape[-1] == vecs.shape[-1], "vecs last dim should be same as center" 195 | slope_center = jnp.stack(vecs, axis=0) 196 | slope_width = jnp.zeros_like(slope_center) 197 | return primal, slope_center, slope_width 198 | 199 | 200 | def slope_bounds(input): 201 | primal, slope_center, slope_width = input 202 | rad = slope_radius(input) 203 | return slope_center-rad, slope_center+rad 204 | 205 | def primal_may_contain_bounds(input, slope_lower, slope_upper): 206 | primal, _, _ = input 207 | slope_max_mag = utils.biggest_magnitude(slope_lower, slope_upper) 208 | primal_rad = jnp.sum(slope_max_mag, axis=0) 209 | primal_lower, primal_upper = primal-primal_rad, primal+primal_rad 210 | return primal_lower, primal_upper 211 | 212 | 213 | # Convert between the slope interval representation and an ordinary scalar value 214 | def from_scalar(x): 215 | return x, None, None 216 | def to_scalar(input): 217 | if not is_const(input): 218 | raise ValueError("non const input") 219 | return input[0] 220 | def wrap_scalar(func): 
221 | return lambda x : to_scalar(func(from_scalar(x))) 222 | -------------------------------------------------------------------------------- /src/slope_interval_layers.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from jax import lax 6 | 7 | import slope_interval 8 | import mlp 9 | import utils 10 | 11 | def dense(input, A, b): 12 | if slope_interval.is_const(input): 13 | out = jnp.dot(input[0], A) 14 | if b is not None: 15 | out += b 16 | return out, None, None 17 | 18 | primal, dual_center, dual_width = input 19 | assert (len(primal.shape) == 1 and dual_center.shape == dual_width.shape and dual_center.shape[-1] == primal.shape[-1]), "shape error" 20 | 21 | def dot(x): 22 | return jnp.dot(x, A) 23 | def dot_abs(x): 24 | return jnp.dot(x, jnp.abs(A)) 25 | 26 | primal = dot(primal) 27 | if b is not None: 28 | primal = primal + b 29 | 30 | if dual_center is not None: dual_center = jax.vmap(dot)(dual_center) 31 | if dual_width is not None: dual_width = jax.vmap(dot_abs)(dual_width) 32 | 33 | return primal, dual_center, dual_width 34 | mlp.apply_func['slope_interval']['dense'] = dense 35 | 36 | def relu(input): 37 | primal, dual_center, dual_width = input 38 | 39 | new_primal = jax.nn.relu(primal) 40 | 41 | if slope_interval.is_const(input): 42 | return new_primal, dual_center, dual_width 43 | 44 | slope_lower, slope_upper = slope_interval.slope_bounds(input) 45 | primal_lower, primal_upper = slope_interval.primal_may_contain_bounds(input, slope_lower, slope_upper) 46 | 47 | # compute bounds on the derivative of the nonlinearity 48 | df_lower = jnp.where(primal_lower > 0, 1., 0.) 49 | df_upper = jnp.where(primal_upper < 0, 0., 1.) 
50 | 51 | # simpler here because df is nonnegative 52 | new_slope_lower = jnp.minimum(slope_lower * df_lower[None,:], slope_lower * df_upper[None,:]) 53 | new_slope_upper = jnp.maximum(slope_upper * df_lower[None,:], slope_upper * df_upper[None,:]) 54 | new_slope_center = 0.5 * (new_slope_lower + new_slope_upper) 55 | new_slope_width = new_slope_upper - new_slope_center 56 | 57 | return new_primal, new_slope_center, new_slope_width 58 | mlp.apply_func['slope_interval']['relu'] = relu 59 | 60 | def elu(input): 61 | primal, dual_center, dual_width = input 62 | 63 | new_primal = jax.nn.elu(primal) 64 | 65 | if slope_interval.is_const(input): 66 | return new_primal, dual_center, dual_width 67 | 68 | slope_lower, slope_upper = slope_interval.slope_bounds(input) 69 | primal_lower, primal_upper = slope_interval.primal_may_contain_bounds(input, slope_lower, slope_upper) 70 | 71 | # compute bounds on the derivative of the nonlinearity 72 | df_lower = jnp.clip(jnp.exp(primal_lower), a_max=1.) 73 | df_upper = jnp.clip(jnp.exp(primal_upper), a_max=1.) 
74 | 75 | # simpler here because df is nonnegative 76 | new_slope_lower = jnp.minimum(slope_lower * df_lower[None,:], slope_lower * df_upper[None,:]) 77 | new_slope_upper = jnp.maximum(slope_upper * df_lower[None,:], slope_upper * df_upper[None,:]) 78 | new_slope_center = 0.5 * (new_slope_lower + new_slope_upper) 79 | new_slope_width = new_slope_upper - new_slope_center 80 | 81 | return new_primal, new_slope_center, new_slope_width 82 | mlp.apply_func['slope_interval']['elu'] = elu 83 | 84 | def sin(input): 85 | primal, dual_center, dual_width = input 86 | 87 | new_primal = jnp.sin(primal) 88 | 89 | if slope_interval.is_const(input): 90 | return new_primal, dual_center, dual_width 91 | 92 | slope_lower, slope_upper = slope_interval.slope_bounds(input) 93 | primal_lower, primal_upper = slope_interval.primal_may_contain_bounds(input, slope_lower, slope_upper) 94 | 95 | # compute bounds on the derivative of the nonlinearity 96 | df_lower, df_upper = utils.cos_bound(primal_lower, primal_upper) 97 | # utils.printarr(primal_lower, primal_upper, df_lower, df_upper, short=False) 98 | 99 | # df can be positive or negative; need full expression 100 | # (this is just an interval multiplication) 101 | vals = [slope_lower * df_lower[None,:], slope_lower * df_upper[None,:], 102 | slope_upper * df_lower[None,:], slope_upper * df_upper[None,:]] 103 | new_slope_lower = utils.minimum_all(vals) 104 | new_slope_upper = utils.maximum_all(vals) 105 | new_slope_center = 0.5 * (new_slope_lower + new_slope_upper) 106 | new_slope_width = new_slope_upper - new_slope_center 107 | 108 | return new_primal, new_slope_center, new_slope_width 109 | mlp.apply_func['slope_interval']['sin'] = sin 110 | 111 | def pow2_frequency_encode(input, coefs, shift=None): 112 | primal, dual_center, dual_width = input 113 | 114 | # expand the length-d inputs to a length-d*c vector 115 | def s(with_shift, x): 116 | out = (x[:,None] * coefs[None,:]) 117 | if with_shift and shift is not None: 118 | out += shift 119 | 
return out.flatten() 120 | 121 | primal = s(True, primal) 122 | if dual_center is not None: dual_center = jax.vmap(partial(s, False))(dual_center) 123 | if dual_width is not None: dual_width = jax.vmap(partial(s,False))(dual_width) 124 | 125 | return primal, dual_center, dual_width 126 | mlp.apply_func['slope_interval']['pow2_frequency_encode'] = pow2_frequency_encode 127 | 128 | def squeeze_last(input): 129 | primal, dual_center, dual_width = input 130 | s = lambda x : jnp.squeeze(x, axis=0) 131 | primal = s(primal) 132 | if dual_center is not None: 133 | dual_center = jax.vmap(s)(dual_center) 134 | dual_width = jax.vmap(s)(dual_width) 135 | return primal, dual_center, dual_width 136 | mlp.apply_func['slope_interval']['squeeze_last'] = squeeze_last 137 | 138 | def spatial_transformation(input, R, t): 139 | # if the shape transforms by R,t, input points need the opposite transform 140 | R_inv = jnp.linalg.inv(R) 141 | t_inv = jnp.dot(R_inv, -t) 142 | return dense(input, A=R_inv, b=t_inv) 143 | mlp.apply_func['slope_interval']['spatial_transformation'] = spatial_transformation 144 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import inspect 4 | import datetime 5 | 6 | from functools import partial 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | 11 | import polyscope as ps 12 | import polyscope.imgui as psim 13 | 14 | class Timer(object): 15 | def __init__(self, name=None, filename=None): 16 | self.name = name 17 | self.filename = filename 18 | 19 | def __enter__(self): 20 | self.tstart = time.time() 21 | 22 | def __exit__(self, type, value, traceback): 23 | message = 'Elapsed: %.3f seconds' % (time.time() - self.tstart) 24 | if self.name: 25 | message = '[%s] ' % self.name + message 26 | print(message) 27 | if self.filename: 28 | with open(self.filename, 'a') as file: 28 | print(str(datetime.datetime.now()) + ": ", 
message, file=file) 29 | 30 | # Extends dict{} to allow access via dot member like d.elem instead of d['elem'] 31 | class DotDict(dict): 32 | __getattr__ = dict.get 33 | __setattr__ = dict.__setitem__ 34 | __delattr__ = dict.__delitem__ 35 | 36 | 37 | def ensure_dir_exists(d): 38 | if not os.path.exists(d): 39 | os.makedirs(d) 40 | 41 | 42 | # === Polyscope 43 | 44 | def combo_string_picker(name, curr_val, vals_list): 45 | 46 | opened = psim.BeginCombo(name, curr_val) 47 | clicked = False 48 | if opened: 49 | for val in vals_list: 50 | _, selected = psim.Selectable(val, curr_val==val) 51 | if selected: 52 | curr_val=val 53 | clicked = True 54 | psim.EndCombo() 55 | 56 | return clicked, curr_val 57 | 58 | # === JAX helpers 59 | 60 | # quick printing 61 | def printarr(*arrs, data=True, short=True, max_width=200): 62 | 63 | # helper for below 64 | def compress_str(s): 65 | return s.replace('\n', '') 66 | name_align = ">" if short else "<" 67 | 68 | # get the name of the tensor as a string 69 | frame = inspect.currentframe().f_back 70 | try: 71 | # first compute some length stats 72 | name_len = -1 73 | dtype_len = -1 74 | shape_len = -1 75 | default_name = "[unnamed]" 76 | for a in arrs: 77 | name = default_name 78 | for k,v in frame.f_locals.items(): 79 | if v is a: 80 | name = k 81 | break 82 | name_len = max(name_len, len(name)) 83 | dtype_len = max(dtype_len, len(str(a.dtype))) 84 | shape_len = max(shape_len, len(str(a.shape))) 85 | len_left = max_width - name_len - dtype_len - shape_len - 5 86 | 87 | # now print the actual arrays 88 | for a in arrs: 89 | name = default_name 90 | for k,v in frame.f_locals.items(): 91 | if v is a: 92 | name = k 93 | break 94 | print(f"{name:{name_align}{name_len}} {str(a.dtype):<{dtype_len}} {str(a.shape):>{shape_len}}", end='') 95 | if data: 96 | # print the contents of the array 97 | print(": ", end='') 98 | flat_str = compress_str(str(a)) 99 | if len(flat_str) < len_left: 100 | # short arrays are easy to print 101 | 
print(flat_str) 102 | else: 103 | # long arrays 104 | if short: 105 | # print a shortened version that fits on one line 106 | if len(flat_str) > len_left - 4: 107 | flat_str = flat_str[:(len_left-4)] + " ..." 108 | print(flat_str) 109 | else: 110 | # print the full array on a new line 111 | print("") 112 | print(a) 113 | else: 114 | print("") # newline 115 | finally: 116 | del frame 117 | 118 | 119 | 120 | def logical_and_all(vals): 121 | out = vals[0] 122 | for i in range(1,len(vals)): 123 | out = jnp.logical_and(out, vals[i]) 124 | return out 125 | 126 | def logical_or_all(vals): 127 | out = vals[0] 128 | for i in range(1,len(vals)): 129 | out = jnp.logical_or(out, vals[i]) 130 | return out 131 | 132 | def minimum_all(vals): 133 | ''' 134 | Take elementwise minimum of a list of arrays 135 | ''' 136 | combined = jnp.stack(vals, axis=0) 137 | return jnp.min(combined, axis=0) 138 | 139 | def maximum_all(vals): 140 | ''' 141 | Take elementwise maximum of a list of arrays 142 | ''' 143 | combined = jnp.stack(vals, axis=0) 144 | return jnp.max(combined, axis=0) 145 | 146 | def all_same_sign(vals): 147 | ''' 148 | Test if all values in an array have (strictly) the same sign 149 | ''' 150 | return jnp.logical_or(jnp.all(vals < 0), jnp.all(vals > 0)) 151 | 152 | # Given a 1d array mask, enumerate the nonzero entries 153 | # example: 154 | # in: [0 1 1 0 1 0] 155 | # out: [X 0 1 X 2 X] 156 | # where X = fill_value 157 | # if fill_value is None, the array length + 1 is used 158 | def enumerate_mask(mask, fill_value=None): 159 | if fill_value is None: 160 | fill_value = mask.shape[-1]+1 161 | out = jnp.cumsum(mask, axis=-1)-1 162 | out = jnp.where(mask, out, fill_value) 163 | return out 164 | 165 | 166 | # Returns the first index past the last True value in a mask 167 | def empty_start_ind(mask): 168 | return jnp.max(jnp.arange(mask.shape[-1]) * mask)+1 169 | 170 | # Given a list of arrays all of the same shape, interleaves 171 | # them along the first dimension and returns an
array such that 172 | # out.shape[0] = len(arrs) * arrs[0].shape[0] 173 | def interleave_arrays(arrs): 174 | s = list(arrs[0].shape) 175 | s[0] *= len(arrs) 176 | return jnp.stack(arrs, axis=1).reshape(s) 177 | 178 | @partial(jax.jit, static_argnames=("new_size","axis")) 179 | def resize_array_axis(A, new_size, axis=0): 180 | first_N = min(new_size, A.shape[axis]) 181 | shape = list(A.shape) 182 | shape[axis] = new_size 183 | new_A = jnp.zeros(shape, dtype=A.dtype) 184 | new_A = jax.lax.dynamic_update_slice_in_dim(new_A, jax.lax.slice_in_dim(A, 0, first_N, axis=axis), 0, axis) 185 | return new_A 186 | 187 | def smoothstep(x): 188 | out = 3.*x*x - 2.*x*x*x 189 | out = jnp.where(x < 0, 0., out) 190 | out = jnp.where(x > 1, 1., out) 191 | return out 192 | 193 | def binary_cross_entropy_loss(logit_in, target): 194 | # same as the pytorch impl, allegedly numerically stable 195 | neg_abs = -jnp.abs(logit_in) 196 | loss = jnp.clip(logit_in, a_min=0) - logit_in * target + jnp.log(1 + jnp.exp(neg_abs)) 197 | return loss 198 | 199 | # interval routines 200 | def smallest_magnitude(interval_lower, interval_upper): 201 | min_mag = jnp.minimum(jnp.abs(interval_lower), jnp.abs(interval_upper)) 202 | min_mag = jnp.where(jnp.logical_and(interval_upper > 0, interval_lower < 0), 0., min_mag) 203 | return min_mag 204 | 205 | def biggest_magnitude(interval_lower, interval_upper): 206 | return jnp.maximum(interval_upper, -interval_lower) 207 | 208 | 209 | def sin_bound(lower, upper): 210 | ''' 211 | Bound sin([lower,upper]) 212 | ''' 213 | f_lower = jnp.sin(lower) 214 | f_upper = jnp.sin(upper) 215 | 216 | # test if there is an interior peak in the range 217 | lower /= 2. * jnp.pi 218 | upper /= 2. * jnp.pi 219 | contains_min = jnp.ceil(lower - .75) < (upper - .75) 220 | contains_max = jnp.ceil(lower - .25) < (upper - .25) 221 | 222 | # result is either at the endpoints or at an interior peak 223 | out_lower = jnp.minimum(f_lower, f_upper) 224 | out_lower = jnp.where(contains_min, -1., out_lower) 225 | out_upper = jnp.maximum(f_lower, f_upper) 226 | out_upper = jnp.where(contains_max, 1., out_upper) 227 | 228 | return out_lower, out_upper 229 | 230 | def cos_bound(lower, upper): 231 | return sin_bound(lower + jnp.pi/2, upper + jnp.pi/2) 232 | --------------------------------------------------------------------------------
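The interval trigonometry at the end of `src/utils.py` is compact enough to sanity-check by hand. The sketch below replays the same peak-detection arithmetic as `sin_bound` with plain Python floats instead of JAX arrays; the helper name `sin_bound_scalar` is ours, introduced only for this check.

```python
import math

def sin_bound_scalar(lower, upper):
    # mirrors sin_bound in src/utils.py: bound sin over the interval [lower, upper]
    f_lower = math.sin(lower)
    f_upper = math.sin(upper)

    # rescale so one full period of sin has length 1
    lo = lower / (2. * math.pi)
    hi = upper / (2. * math.pi)

    # sin attains its minimum -1 at phase 0.75 + k and its maximum +1 at
    # phase 0.25 + k for integer k; test whether such a point lies strictly inside
    contains_min = math.ceil(lo - .75) < (hi - .75)
    contains_max = math.ceil(lo - .25) < (hi - .25)

    out_lower = -1. if contains_min else min(f_lower, f_upper)
    out_upper = 1. if contains_max else max(f_lower, f_upper)
    return out_lower, out_upper
```

On [0, pi] this detects the interior maximum at pi/2 and returns the exact bound [0, 1]; on [pi, 2*pi] it detects the interior minimum at 3*pi/2. When a peak falls exactly on an endpoint, the strict `<` test misses it, but the endpoint evaluations `f_lower`/`f_upper` already account for that case.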