├── .gitignore
├── README.md
├── masonry_vault_cad_design.gif
├── pyproject.toml
├── scripts
│   ├── Vera.ttf
│   ├── bezier.yml
│   ├── brickify.py
│   ├── camera.py
│   ├── optimize.py
│   ├── predict.py
│   ├── predict_optimize.py
│   ├── shapes.py
│   ├── sweep.py
│   ├── sweep_bezier.yml
│   ├── sweep_tower.yml
│   ├── text_2_mesh.py
│   ├── tower.yml
│   ├── train.py
│   ├── visualize.py
│   └── visualize_tower_task.py
└── src
    └── neural_fdm
        ├── __init__.py
        ├── builders.py
        ├── generators
        │   ├── __init__.py
        │   ├── bezier.py
        │   ├── generator.py
        │   ├── generator_bezier.py
        │   ├── grids.py
        │   └── tubes.py
        ├── helpers.py
        ├── losses.py
        ├── mesh.py
        ├── models.py
        ├── plotting.py
        ├── serialization.py
        └── training.py

/.gitignore:
--------------------------------------------------------------------------------
 1 | .DS_Store
 2 | .idea
 3 | *.log
 4 | tmp/
 5 | 
 6 | *.py[cod]
 7 | *.egg
 8 | build
 9 | htmlcov
10 | 
11 | src/neural_fdm.egg-info/
12 | 
13 | temp/
14 | 
15 | data/
16 | data/backup/
17 | data/cad/
18 | data/figures/
19 | data/trained/
20 | 
21 | figures/
22 | 
23 | scripts/_archive
24 | scripts/*.png
25 | scripts/*.ai
26 | scripts/*.svg
27 | scripts/*.indd
28 | scripts/*.eqx
29 | 
30 | **/wandb/
31 | 
32 | **.gh
33 | **.3dm
34 | **.ipynb
35 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Real-time design of architectural structures with differentiable mechanics and neural networks
 2 | 
 3 | [![arXiv](https://img.shields.io/badge/arXiv-2409.02606-b31b1b.svg)](https://arxiv.org/abs/2409.02606)
 4 | 
 5 | > Code for the [paper](https://arxiv.org/abs/2409.02606) published at ICLR 2025.
 6 | 
 7 | ![Our trained model, deployed in Rhino3D](masonry_vault_cad_design.gif)
 8 | 
 9 | ## Abstract
10 | Designing mechanically efficient geometry for architectural structures like shells, towers, and bridges is an expensive iterative process. Existing techniques for solving such inverse problems rely on traditional optimization methods, which are slow and computationally expensive, limiting iteration speed and design exploration. Neural networks would seem to offer a solution via data-driven amortized optimization, but they often require extensive fine-tuning and cannot ensure that important design criteria, such as mechanical integrity, are met.
11 | 
12 | In this work, we combine neural networks with a differentiable mechanics simulator to develop a model that accelerates the solution of shape approximation problems for architectural structures represented as bar systems. This model explicitly guarantees compliance with mechanical constraints while generating designs that closely match target geometries. We validate our approach in two tasks, the design of masonry shells and cable-net towers. Our model achieves better accuracy and generalization than fully neural alternatives, and comparable accuracy to direct optimization but in real time, enabling fast and reliable design exploration. We further demonstrate its advantages by integrating it into 3D modeling software and fabricating a physical prototype. Our work opens up new opportunities for accelerated mechanical design enhanced by neural networks for the built environment.
13 | 
14 | ## Table of Contents
15 | 
16 | - [Pretrained models](#pretrained-models)
17 | - [Repository structure](#repository-structure)
18 | - [Installation](#installation)
19 | - [Configuration](#configuration)
20 | - [Data generation](#data-generation)
21 | - [Building a model](#building-a-model)
22 | - [Loss and optimizer](#loss-and-optimizer)
23 | - [Training](#training)
24 | - [Testing](#testing)
25 | - [Visualization](#visualization)
26 | - [Direct optimization](#direct-optimization)
27 | - [Predict then optimize](#predict-then-optimize)
28 | - [Citation](#citation)
29 | - [Contact](#contact)
30 | 
31 | ## Pretrained models
32 | 
33 | Want to skip training? We got you 💪🏽
34 | Our trained model weights are publicly available at this [link](https://drive.google.com/drive/folders/1BL_g5ikNh1s0fxsNp4PzKl84fQFUpm0L?usp=share_link).
35 | Once downloaded, you can [test](#testing) the models at inference time and [display](#visualization) their predictions.
36 | 
37 | [Table of contents](#table-of-contents) ⬆️
38 | 
39 | ## Repository structure
40 | 
41 | This repository contains two folders with the meat of our work: `src` and `scripts`.
42 | 
43 | The first folder, `src`, defines all the code infrastructure we need to build, train, serialize, and visualize our model and the baselines.
44 | The second one, `scripts`, groups a list of routines to execute the code in `src`, and more importantly, to reproduce our experiments at inference time.
45 | 
46 | With the scripts, you can even tessellate and 3D print your own masonry vault from one of our model predictions if you fancy!
47 | 
48 | [Table of contents](#table-of-contents) ⬆️
49 | 
50 | ## Installation
51 | 
52 | >We only support installation on a CPU. Our paper does not use any GPUs. Crazy, right? 🪄
53 | 
54 | Create a new [Anaconda](https://www.anaconda.com/) environment with Python 3.9 and then activate it:
55 | 
56 | ```bash
57 | conda create -n neural python=3.9
58 | conda activate neural
59 | ```
60 | 
61 | ### Basic Installation
62 | 
63 | 1. Install the required Conda dependencies:
64 | ```bash
65 | conda install -c conda-forge compas==1.17.10 compas_view2==0.7.0
66 | ```
67 | 
68 | 2. Install the package and its pip dependencies:
69 | ```bash
70 | pip install -e .
71 | ```
72 | The `-e .` flag installs the package in "editable" mode, which means changes to the source code take effect immediately without reinstalling.
73 | 
74 | ### Advanced Installation
75 | 
76 | If you need additional development tools (testing, formatting, etc.), are interested in making data plots, or want to generate bricks for a shell, follow these steps:
77 | 
78 | 1. Install the necessary and additional Conda dependencies:
79 | ```bash
80 | conda install -c conda-forge compas==1.17.10 compas_view2==0.7.0 compas_cgal==0.5.0
81 | ```
82 | 
83 | 2. Install the package with development dependencies:
84 | ```bash
85 | pip install -e ".[dev]"
86 | ```
87 | [Table of contents](#table-of-contents) ⬆️
88 | 
89 | ## Configuration
90 | 
91 | Our work focuses on two structural design tasks: compression-only shells and cable-net towers.
92 | 
93 | We thus create a `.yml` file with all the configuration hyperparameters per task.
94 | The files are stored in the `scripts` folder as:
95 | - `bezier.yml`, and
96 | - `tower.yml`
97 | 
98 | for the first and the second task, respectively.
99 | The hyperparameters exposed in the configuration files range from choosing a data generator to prescribing the model architecture and the optimization scheme.
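For orientation, here is a minimal sketch (not part of the repository's scripts) of how such a task configuration can be loaded and queried with PyYAML, which is what the scripts do under the hood; the keys shown come from `bezier.yml`:

```python
# Illustrative only: load a task configuration the way the scripts do (PyYAML),
# assuming we run from the `scripts` folder where the .yml files live.
import yaml

with open("bezier.yml") as file:
    config = yaml.load(file, Loader=yaml.FullLoader)

# a few of the hyperparameters described in the following sections
print(config["generator"]["name"])       # e.g., "bezier_symmetric_double"
print(config["fdm"]["load"])             # vertical area load, e.g., -0.5
print(config["training"]["batch_size"])  # e.g., 64
```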
100 | We'll be fiddling with these hyperparameters to steer our experiments.
101 | 
102 | 
103 | ### Data generation
104 | 
105 | An advantage of our work is that we only need to define target shapes, without pairing them with a vector of force densities as ground-truth labels.
106 | Such labels would only be required in a fully supervised setting, which is not the case here.
107 | Our model figures these labels out automatically.
108 | This allows us to generate a dataset of target shapes on the fly at train time by specifying a `generator` configuration and a random `seed` to create pseudo-random keys.
109 | 
110 | #### Shells
111 | The target shell shapes are parametrized by a square Bezier patch.
112 | 
113 | - `name`: The name of the generator to instantiate. One of `bezier_symmetric_double`, `bezier_symmetric`, `bezier_asymmetric`, and `bezier_lerp`. The first two options constrain the shapes to be symmetric along two or one axis, respectively. The third option does not enforce symmetry. The last option, `bezier_lerp`, is used to linearly interpolate between a batch of doubly-symmetric and asymmetric shapes (i.e., shapes generated by `bezier_symmetric_double` and `bezier_asymmetric`).
114 | - `num_points`: The number of control points on one side of the grid that parametrizes a Bezier patch.
115 | - `bounds`: It specifies how to wiggle the control points of the patch on a `num_points x num_points` grid. The option `pillow` only moves the internal control point up and down, while `dome` additionally jitters the two control points on the boundary in and out. `saddle` is an extension of `dome` in that it lets one of the control points on the boundary move up and down too.
116 | - `num_uv`: The number of spans to evaluate on the Bezier along the *u* and *v* directions. A value of `10`, for example, would result in a `10x10` grid of target points. These are the points to be matched during training.
117 | - `size`: The length of the sides of the patch. It defines the scale of the task.
118 | - `lerp_factor`: A scalar factor in [0, 1] to interpolate between two target surfaces. Only employed for `bezier_lerp`.
119 | 
120 | #### Towers
121 | The target tower shapes are described in turn by a vertical sequence of planar circles.
122 | The tower rings are deformed and rotated depending on the generator `name` and `bounds`.
123 | 
124 | - `name`: The generator name. Use `tower_ellipse` to make target shapes with elliptical rings, and `tower_circles` to keep the rings as circles.
125 | - `bounds`: Either `straight` or `twisted`. The former only scales the rings on the plane at random. The latter scales and rotates the rings at random.
126 | - `height`: The tower height.
127 | - `radius`: The starting radius of all the generated circles.
128 | - `num_sides`: The number of segments to discretize each circle with.
129 | - `num_levels`: The number of circles to create along the tower's height, equidistantly spaced.
130 | - `num_rings`: The number of circles to be morphed during training. Must be `>2` since two of these rings are, by default, at the top and bottom of the tower.
131 | 
132 | ### Building a model
133 | 
134 | We specify the architecture of a model in the configuration file, which, for the most part, resembles an autoencoder.
135 | The configuration scheme is the same for any task.
136 | 
137 | #### Neural networks
138 | 
139 | Our experiments use multilayer perceptrons (MLPs) for the encoder that maps shapes to simulation parameters, although we are by no means restricted to that.
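Purely as an illustration of what such an encoder could look like, here is a minimal sketch built with Equinox's stock `MLP` (the repository's actual models are assembled by the builders in `src/neural_fdm`); the input and output sizes below are hypothetical:

```python
# A sketch only, with hypothetical sizes -- the real encoder is assembled by
# the builders in `src/neural_fdm`, not by this snippet.
import jax
import jax.numpy as jnp
import jax.random as jrn
import equinox as eqx

in_size = 300   # e.g., a 10x10 grid of target points, flattened xyz coordinates
out_size = 180  # e.g., one simulation parameter (force density) per edge

encoder = eqx.nn.MLP(
    in_size=in_size,
    out_size=out_size,
    width_size=256,                    # hidden_layer_size
    depth=3,                           # roughly hidden_layer_num
    activation=jax.nn.elu,             # activation_fn_name
    final_activation=jax.nn.softplus,  # final_activation_fn_name: strictly positive output
    key=jrn.PRNGKey(0),
)

q = encoder(jnp.zeros(in_size))  # the `shift` (tau) would then offset this output
```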
140 | An MLP also serves as the decoder for our fully neural baselines.
141 | We employ one of the simplest possible neural networks, the MLP, to quantify the benefits of having a physics simulator inside a neural network on large-scale mechanical design tasks.
142 | This sets a baseline we can build upon with beefier architectures like graph neural networks, transformers, and beyond.
143 | 
144 | The encoder hyperparameters are:
145 | - `shift`: The lower-bound shift applied to the output of the last layer of the encoder. This is what we call `tau` in the [paper](https://arxiv.org/abs/2409.02606).
146 | - `hidden_layer_size`: The width of every fully-connected hidden layer. We restrict the size to `256` in all the experiments.
147 | - `hidden_layer_num`: The number of hidden layers, output layer included.
148 | - `activation_fn_name`: The name of the activation function after each hidden layer. We typically resort to `elu`.
149 | - `final_activation_fn_name`: The activation function name after the output layer. We use `softplus` to ensure a strictly positive output, as needed by the simulator decoder.
150 | 
151 | The neural decoder's setup mirrors the encoder's, except for the `include_params_xl` flag.
152 | If set to `True`, then the decoder expects the latents and boundary conditions as inputs.
153 | Otherwise, it only decodes the latents.
154 | We fix this hyperparameter to `True` in the [paper](https://arxiv.org/abs/2409.02606).
155 | 
156 | #### Simulator
157 | 
158 | For the simulator, the force density method (FDM), we only have `load` as a hyperparameter, which sets the magnitude of a vertical **area** load applied to the structures in the direction of gravity (hello Isaac Newton! 🍎).
159 | 
160 | If this value is nonzero, then the model will convert the area load into point loads to be compatible with our physics simulator.
161 | 
162 | ### Loss and Optimizer
163 | 
164 | The training setup is also defined in the configuration file of the task, including the `loss` function to optimize for, the `optimizer` that updates the model parameters, and the `training` schedule that pretty much allocates the compute budget.
165 | 
166 | The `loss` function is the sum of multiple terms, which for the most part are a shape loss and a physics loss, as we explain in the [paper](https://arxiv.org/abs/2409.02606).
167 | We allow for more refined control over the scaling of each loss term in the file:
168 | - `include`: Whether or not to include the loss term during training. If set to `False`, then the value of the loss term is not calculated, saving some computational resources. By default, `include=True`.
169 | - `weight`: The scalar weight of the loss term used for calibrating model performance, called `kappa` in the [paper](https://arxiv.org/abs/2409.02606). It is particularly useful to tune the scale of the physics loss when training the PINN baseline. By default, `weight=1.0`.
170 | 
171 | The `optimizer` hyperparameters are (see the sketch after this list):
172 | - `name`: The name of the gradient-based optimizer. We currently support `adam` and `sgd` from the `optax` library, but only use `adam` in the [paper](https://arxiv.org/abs/2409.02606).
173 | - `learning_rate`: The constant learning rate. The rate is fixed; we omit schedulers because it is more elegant.
174 | - `clip_norm`: The global norm for gradient clipping. If set to `0.0`, then gradient clipping is ignored.
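For intuition, a minimal sketch of how these optimizer settings could translate into `optax` calls is shown below. This is an illustration only, not the repository's actual builder; the values mirror the defaults in `bezier.yml`.

```python
# Illustrative sketch: map the `optimizer` config block to optax calls.
import optax

optimizer_config = {"name": "adam", "learning_rate": 3.0e-5, "clip_norm": 0.0}

def build_optimizer(cfg):
    # pick the gradient transformation by name
    make_optimizer = {"adam": optax.adam, "sgd": optax.sgd}[cfg["name"]]
    optimizer = make_optimizer(learning_rate=cfg["learning_rate"])
    # a zero `clip_norm` means gradient clipping is skipped
    if cfg["clip_norm"] > 0.0:
        optimizer = optax.chain(optax.clip_by_global_norm(cfg["clip_norm"]), optimizer)
    return optimizer

optimizer = build_optimizer(optimizer_config)
```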
175 | 
176 | And for the `training` routine:
177 | - `steps`: The number of optimization steps for model training (i.e., the number of times the model parameters are updated). We mostly train the models for `10000` steps.
178 | - `batch_size`: The batch size of the input data.
179 | 
180 | [Table of contents](#table-of-contents) ⬆️
181 | 
182 | ## Training
183 | 
184 | After setting up the config files, it's time to make that CPU go brrrrr.
185 | Execute the `train.py` script from your terminal:
186 | 
187 | ```bash
188 | python train.py <model_name> <task_name>
189 | ```
190 | 
191 | Where `task_name` is either `bezier` for the shells task or `tower` for the towers task.
192 | Task-specific configuration details are given in the [paper](https://arxiv.org/abs/2409.02606).
193 | 
194 | The `model_name` is where things get interesting.
195 | In summary:
196 | 
197 | - Ours: `formfinder`
198 | - NN and PINN baseline: `autoencoder`
199 | 
200 | If `autoencoder` is trained with the `residual` loss term (i.e., the physics loss is included and active), this model becomes a PINN baseline and is internally renamed `autoencoder_pinn` (sorry, naming is hard).
201 | 
202 | We invite you to check the docstring of the `train.py` script to see all the input options.
203 | These options allow you to warmstart the training from an existing pretrained model, checkpoint every so often, as well as plot and export the loss history for your inspection.
204 | 
205 | > A note on hyperparameter tuning. We utilized WandB to run hyperparameter sweeps. The sweeps are in turn handled by the `sweep.py` script in tandem with the `sweep_bezier.yml` or `sweep_tower.yml` files, depending on the task. The structure of these sweep files mimics that of the configuration files described herein. We trust you'll be able to find your way around them if you really want to fiddle with them.
206 | 
207 | [Table of contents](#table-of-contents) ⬆️
208 | 
209 | ## Testing
210 | 
211 | To evaluate the trained models on a test batch, run:
212 | 
213 | ```bash
214 | python predict.py <model_name> <task_name> --batch_size=<batch_size> --seed=<test_seed>
215 | ```
216 | where we set `--batch_size=100` during inference to match what we do in the [paper](https://arxiv.org/abs/2409.02606).
217 | The test set is created by a generator that follows the same configuration as the train set, except for the random seed.
218 | We set `test_seed` to `90` in the `bezier` task and `test_seed` to `92` in the `tower` task.
219 | Feel free to specify other seed values to test the model on different test datasets.
220 | 
221 | [Table of contents](#table-of-contents) ⬆️
222 | 
223 | ## Visualization
224 | 
225 | An image is worth more than a thousand words, or in this case, more than a thousand numbers in a JAX array.
226 | 
227 | You can visualize the prediction a model makes, either ours or the baselines, with a dedicated script that lets you take control over the style of the rendered prediction:
228 | 
229 | ```bash
230 | python visualize.py <model_name> <task_name> --shape_index=<shape_index> --seed=<test_seed>
231 | ```
232 | 
233 | The shape to display is selected by inputting its index relative to the batch size with the `shape_index` argument.
234 | Check out the docstring of `visualize.py` for the nitty-gritty details of how to control color palettes, linewidths, and arrow scales for making pretty pictures.
235 | 
236 | [Table of contents](#table-of-contents) ⬆️
237 | 
238 | ## Direct optimization
239 | 
240 | So far we've only discussed how to create neural models for shape-matching tasks.
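Putting the last three sections together, a typical session for the shells task might look as follows. This is only an illustration: the positional argument order (model name, then task name) is an assumption about the Fire-based CLIs, and the seed and batch size are the values quoted above.

```bash
# Illustrative only: train, test, and visualize our model on the shells task.
python train.py formfinder bezier                                # train
python predict.py formfinder bezier --batch_size=100 --seed=90   # evaluate on a test batch
python visualize.py formfinder bezier --shape_index=0 --seed=90  # render one prediction
```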
241 | Direct gradient-based optimization is another baseline that merits its own section, as it is the ground truth in traditional design optimization in structural engineering.
242 | Take an optimizer for a ride via:
243 | 
244 | ```
245 | python optimize.py <optimizer_name> <task_name> --batch_size=<batch_size> --seed=<seed> --blow=<blow> --bup=<bup> --param_init=<param_init> --maxiter=<maxiter> --tol=<tol>
246 | ```
247 | 
248 | We support two constrained gradient-based algorithms as implemented in `jaxopt`.
249 | Select one of them through their `optimizer_name`:
250 | - `slsqp`: The sequential least squares programming algorithm.
251 | - `lbfgsb`: The limited-memory Broyden–Fletcher–Goldfarb–Shanno algorithm with box constraints.
252 | 
253 | The algorithms support box constraints on the simulation parameters.
254 | We take advantage of this feature to constrain the parameters to a specific sign and to a range of reasonable values, depending on the task.
255 | The lower box constraint is equivalent to the effect that `tau` has on the decoder's output in the [paper](https://arxiv.org/abs/2409.02606), by prescribing a minimum output value.
256 | Both optimizers run for `maxiter=5000` iterations at most and stop early if they hit the convergence tolerance of `tol=1e-6`.
257 | 
258 | We're in the business of local optimization, so the initialization affects convergence.
259 | You can pick between two initialization schemes with `param_init` that respect the force density signs of a task (compression or tension):
260 | 
261 | - If specified as a scalar, it determines the starting constant value of all the simulation parameters.
262 | - If set to `None`, the initialization samples starting parameters between `blow` and `bup` from a uniform distribution.
263 | 
264 | In the shells task, we apply `slsqp`, set `blow=0.0` and `bup=20.0`, and `param_init=None`.
265 | 
266 | In contrast, the towers task uses `lbfgsb` and `blow=1.0` to match the value of `tau` we used in this task in the paper.
267 | The towers task is more nuanced because we explore three different initialization schemes:
268 | - Randomized: `param_init=None`
269 | - Expert: `param_init=1.0`
270 | 
271 | The third initialization type relies on the predictions of a pretrained model and, to use it, we need to invoke a different script.
272 | See [Predict then optimize](#predict-then-optimize) below.
273 | 
274 | [Table of contents](#table-of-contents) ⬆️
275 | 
276 | ## Predict then optimize
277 | 
278 | There is enormous potential in combining neural networks with traditional optimization techniques to expedite mechanical design.
279 | An opportunity in this space is to leverage the prediction made by one of our models and refine that prediction with direct optimization to unlock the best-performing designs.
280 | 
281 | Our key to open a tiny (very tiny) door into this space is the `predict_optimize.py` script:
282 | 
283 | ```
284 | python predict_optimize.py <model_name> <task_name> --batch_size=<batch_size> --seed=<seed> --blow=<blow> --bup=<bup> --maxiter=<maxiter> --tol=<tol>
285 | ```
286 | 
287 | The difference from the `optimize.py` script is that you now have to specify the name of a trained model via `model_name`.
288 | Its predictions will warmstart the optimization, replacing any of the `param_init` schemes described earlier.
289 | The rest of the inputs work the same way as in `optimize.py`.
290 | 
291 | [Table of contents](#table-of-contents) ⬆️
292 | 
293 | ## Citation
294 | 
295 | Consider citing our [paper](https://arxiv.org/abs/2409.02606) if this work was helpful to your research.
296 | Don't worry, it's free.
297 | 298 | ```bibtex 299 | @inproceedings{ 300 | pastrana_2025_diffmechanics, 301 | title={Real-time design of architectural structures with differentiable mechanics and neural networks}, 302 | author={Rafael Pastrana and Eder Medina and Isabel M. de Oliveira and Sigrid Adriaenssens and Ryan P. Adams}, 303 | booktitle={The Thirteenth International Conference on Learning Representations}, 304 | year={2025}, 305 | url={https://openreview.net/forum?id=Tpjq66xwTq} 306 | } 307 | ``` 308 | 309 | [Table of contents](#table-of-contents) ⬆️ 310 | 311 | ## Contact 312 | 313 | Reach out! If you have questions or find bugs in our code, please open an issue on Github or email the authors at arpastrana@princeton.edu. 314 | 315 | [Table of contents](#table-of-contents) ⬆️ 316 | -------------------------------------------------------------------------------- /masonry_vault_cad_design.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrincetonLIPS/neural_fdm/15672aed937aba7c53c5d59734833d92c8d5c56b/masonry_vault_cad_design.gif -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "neural_fdm" 7 | version = "0.1.0" 8 | description = "Combining differentiable mechanics simulations with neural networks" 9 | readme = "README.md" 10 | requires-python = ">=3.9" 11 | license = "MIT" 12 | authors = [ 13 | { name = "Rafael Pastrana" } 14 | ] 15 | dependencies = [ 16 | "numpy<2", 17 | "scipy<1.13", 18 | "jax==0.4.23", 19 | "jaxlib==0.4.23", 20 | "equinox==0.11.3", 21 | "jax-fdm==0.8.6", 22 | "optax==0.1.5", 23 | "pyyaml==6.0.1", 24 | "tqdm==4.66.1", 25 | "fire==0.6.0", 26 | "matplotlib>=3.0" 27 | ] 28 | 29 | [project.optional-dependencies] 30 | dev = [ 31 | "freetype-py", 32 | "pytest>=7.0", 33 | "pytest-cov>=4.0", 34 | "black>=23.0", 35 | "isort>=5.0", 36 | "mypy>=1.0", 37 | "ruff>=0.1.0", 38 | "pre-commit>=3.0", 39 | "jupyter>=1.0", 40 | "ipykernel>=6.0", 41 | "seaborn==0.13.2" 42 | ] 43 | 44 | [tool.hatch.build.targets.wheel] 45 | packages = ["src/neural_fdm"] -------------------------------------------------------------------------------- /scripts/Vera.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PrincetonLIPS/neural_fdm/15672aed937aba7c53c5d59734833d92c8d5c56b/scripts/Vera.ttf -------------------------------------------------------------------------------- /scripts/bezier.yml: -------------------------------------------------------------------------------- 1 | # randomness 2 | seed: 91 # 91 3 | # dataset 4 | generator: 5 | # options: "bezier_symmetric_double", "bezier_symmetric", "bezier_asymmetric", "bezier_lerp" 6 | name: "bezier_symmetric_double" # "bezier_symmetric_double" 7 | bounds: "saddle" # options: pillow, dome, saddle 8 | num_uv: 10 # 10, 16, 23 9 | size: 10.0 10 | num_points: 4 # grid points 11 | lerp_factor: 0.5 # scalar factor in [0, 1] to interpolate between 2 surfaces, only for bezier_lerp 12 | # simulator 13 | fdm: 14 | load: -0.5 # -0.5, scale of vertical area load 15 | # neural networks 16 | encoder: 17 | shift: 0.0 18 | hidden_layer_size: 256 19 | hidden_layer_num: 3 20 | activation_fn_name: "elu" 21 | final_activation_fn_name: "softplus" # needs softplus to ensure positive output 22 | decoder: 23 | # If true, 
the decoder maps (z, boundary conditions) -> x. Otherwise, z -> x. 24 | include_params_xl: True 25 | hidden_layer_size: 256 26 | hidden_layer_num: 3 27 | activation_fn_name: "elu" 28 | # loss function 29 | loss: 30 | shape: 31 | include: True 32 | weight: 1.0 # weight of the shape error term in the loss function 33 | residual: # PINN term 34 | include: True 35 | weight: 1.0 # weight of the residual error term in the loss function 36 | # optimization 37 | optimizer: 38 | name: "adam" 39 | learning_rate: 3.0e-5 # 3.0e-5 (formfinder), 5.0e-5 (others). Be careful with scientific notation in YAML! 40 | clip_norm: 0.0 41 | # training 42 | training: 43 | steps: 10000 # 10000 44 | batch_size: 64 45 | -------------------------------------------------------------------------------- /scripts/brickify.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate bricks off of the faces of an optimized mesh. 3 | """ 4 | 5 | import os 6 | from random import randint 7 | 8 | from compas.colors import Color 9 | 10 | from compas.datastructures import Mesh 11 | from compas.datastructures import mesh_dual 12 | from compas.datastructures import mesh_thicken 13 | from compas.datastructures import mesh_offset 14 | from compas.datastructures import mesh_delete_duplicate_vertices 15 | 16 | from compas.geometry import Line 17 | from compas.geometry import Sphere 18 | from compas.geometry import Box 19 | from compas.geometry import centroid_points 20 | from compas.geometry import oriented_bounding_box_numpy 21 | from compas.geometry import Transformation 22 | from compas.geometry import Scale 23 | from compas.geometry import Frame 24 | from compas.geometry import Plane 25 | from compas.geometry import intersection_line_plane 26 | from compas.geometry import area_triangle 27 | from compas.geometry import normal_triangle 28 | from compas.geometry import circle_from_points 29 | from compas.geometry import normalize_vector 30 | from compas.geometry import add_vectors 31 | from compas.geometry import subtract_vectors 32 | 33 | from compas.utilities import pairwise 34 | from compas.utilities import geometric_key 35 | 36 | from compas_cgal.booleans import boolean_difference as boolean_difference_mesh_mesh 37 | from compas_cgal.meshing import remesh as mesh_remesh 38 | 39 | from jax_fdm.datastructures import FDMesh 40 | from jax_fdm.visualization import Viewer 41 | 42 | from neural_fdm import DATA 43 | 44 | from text_2_mesh import text_2_mesh 45 | 46 | 47 | # =============================================================================== 48 | # Helper functions 49 | # =============================================================================== 50 | 51 | def triangulate_face_quad(face, reverse=False): 52 | """ 53 | Triangulate a mesh quad face. 54 | 55 | Parameters 56 | ___________ 57 | face: `list` of `int` 58 | The face vertices. 59 | reverse: `bool`, optional 60 | If `True`, the face is reversed. 61 | 62 | Returns 63 | _______ 64 | new_faces: `list` of `list` of `int` 65 | The two triangulated faces. 66 | """ 67 | a, b, c, d = face 68 | if not reverse: 69 | face_a = [a, b, d] 70 | face_b = [b, c, d] 71 | else: 72 | face_a = [a, b, c] 73 | face_b = [c, d, a] 74 | 75 | return [face_a, face_b] 76 | 77 | 78 | def triangulate_face_ngon(face, vertices): 79 | """ 80 | Triangulate a mesh polygonal face. 81 | 82 | Parameters 83 | ___________ 84 | face: `list` of `int` 85 | The face vertices. 86 | vertices: `list` of `list` of `float` 87 | The xyz coordinates of the face vertices. 
88 | 89 | Returns 90 | _______ 91 | new_faces: `list` of `list` of `int` 92 | The triangulated faces. 93 | """ 94 | midpoint = centroid_points([vertices[vkey] for vkey in face]) 95 | vertices.append(midpoint) 96 | ckey = len(vertices) - 1 97 | 98 | # create new faces 99 | new_faces = [] 100 | for a, b in pairwise(face + face[:1]): 101 | _face = [a, b, ckey] 102 | new_faces.append(_face) 103 | 104 | return new_faces 105 | 106 | 107 | def triangulate_face(face, vertices, reverse=False): 108 | """ 109 | Triangulate a mesh face based on its vertex count. 110 | 111 | The face is a list of indices pointing to a list with the vertices xyz coordinates. 112 | 113 | Parameters 114 | ___________ 115 | face: `list` of `int` 116 | The face vertices. 117 | vertices: `list` of `list` of `float` 118 | The xyz coordinates of the face vertices. 119 | reverse: `bool`, optional 120 | If `True`, the face is reversed. 121 | 122 | Returns 123 | _______ 124 | new_faces: `list` of `list` of `int` 125 | The triangulated faces. 126 | """ 127 | assert len(face) > 2 128 | 129 | new_faces = [] 130 | # triangle 131 | if len(face) == 3: 132 | new_faces = [face] 133 | # quad 134 | elif len(face) == 4: 135 | new_faces = triangulate_face_quad(face, reverse) 136 | # ngon 137 | else: 138 | new_faces = triangulate_face_ngon(face, vertices) 139 | 140 | return new_faces 141 | 142 | 143 | def calculate_brick_thicknesses(thickness): 144 | """ 145 | Calculate the top and bottom thicknesses of a brick. 146 | 147 | Parameters 148 | ___________ 149 | thickness: `float` 150 | The brick thickness. 151 | 152 | Returns 153 | _______ 154 | thickness_bottom: `float` 155 | The bottom thickness of the brick. 156 | thickness_top: `float` 157 | The top thickness of the brick. 158 | """ 159 | thickness_bottom = thickness / 3.0 160 | thickness_top = 2.0 * thickness / 3.0 161 | 162 | return thickness_bottom, thickness_top 163 | 164 | 165 | def generate_bricks(mesh, thickness): 166 | """ 167 | Generate a solid brick per mesh face. 168 | 169 | Parameters 170 | ___________ 171 | mesh: `compas.datastructures.Mesh` 172 | The mesh whose faces will be turned into bricks. 173 | thickness: `float` 174 | The global brick thickness. 175 | 176 | Returns 177 | _______ 178 | bricks: `dict` of `compas.datastructures.Mesh` 179 | The bricks meshes (closed, watertight). 180 | meshes: `tuple` of `compas.datastructures.Mesh` 181 | The meshes of the bottom and top faces of the bricks to create the scaffolding. 
182 | """ 183 | thick_bottom, thick_top = calculate_brick_thicknesses(thickness) 184 | 185 | mesh_bottom = mesh_offset(mesh, thick_bottom) 186 | mesh_top = mesh_offset(mesh, thick_top * -1.0) 187 | 188 | bricks = {} 189 | halfedges_visited = set() 190 | 191 | for i, fkey in enumerate(mesh.faces()): 192 | 193 | # adding side faces 194 | faces = [] 195 | 196 | num_vertices = len(mesh.face_vertices(fkey)) 197 | face_bottom = list(range(num_vertices)) 198 | face_top = list(range(num_vertices, 2 * num_vertices)) 199 | 200 | xyz_bottom = mesh_bottom.face_coordinates(fkey) 201 | xyz_top = mesh_top.face_coordinates(fkey) 202 | vertices = xyz_bottom + xyz_top 203 | 204 | halfedges = mesh.face_halfedges(fkey) 205 | iterable = zip( 206 | pairwise(face_top + face_top[:1]), 207 | pairwise(face_bottom + face_bottom[:1]), 208 | halfedges 209 | ) 210 | 211 | # triangulate 212 | for edge_1, edge_2, halfedge in iterable: 213 | a, b = edge_1 214 | d, c = edge_2 215 | face = [a, b, c, d] 216 | is_visited = halfedge in halfedges_visited 217 | tri_faces = triangulate_face_quad(face, is_visited) 218 | faces.extend(tri_faces) 219 | 220 | # face half edges 221 | halfedges_visited.update(halfedges) 222 | halfedges_reversed = [(v, u) for u, v in halfedges] 223 | halfedges_visited.update(halfedges_reversed) 224 | 225 | # triangulate top and bottom faces 226 | tri_faces = triangulate_face(face_bottom, vertices) 227 | faces.extend(tri_faces) 228 | 229 | tri_faces = triangulate_face(face_top, vertices) 230 | faces.extend([face[::-1] for face in tri_faces]) 231 | 232 | # make mesh from face 233 | brick = Mesh.from_vertices_and_faces(vertices, faces) 234 | 235 | # store brick 236 | bricks[fkey] = brick 237 | 238 | return bricks, (mesh_bottom, mesh_top) 239 | 240 | 241 | def add_text_engraving(fkey, bricks, mesh_bottom, text, depth=1): 242 | """ 243 | Engrave the underside of a brick mesh with text. 244 | 245 | Parameters 246 | ___________ 247 | fkey: `int` 248 | The face key of the brick to engrave. 249 | bricks: `dict` of `compas.datastructures.Mesh` 250 | The brick meshes. 251 | mesh_bottom: `compas.datastructures.Mesh` 252 | The mesh of the bottom face of the brick. 253 | text: `str` 254 | The text to engrave. 255 | depth: `float`, optional 256 | The depth of the engraving. 257 | 258 | Returns 259 | _______ 260 | text_mesh: `compas.datastructures.Mesh` 261 | The thickened text mesh. 262 | bbox_mesh: `compas.datastructures.Mesh` 263 | The bounding box of the text. 264 | face_mesh: `compas.datastructures.Mesh` 265 | A mesh with the face of the brick where the text is engraved. 266 | brick: `compas.datastructures.Mesh` 267 | The engraved brick mesh. 
268 | """ 269 | # generate text mesh 270 | text_mesh = text_2_mesh(text) 271 | vertices, _ = text_mesh.to_vertices_and_faces() 272 | 273 | # calculate bounding box and properties 274 | bbox = oriented_bounding_box_numpy(vertices) 275 | bbox = Box.from_bounding_box(bbox) 276 | bbox_largest = max([bbox.width, bbox.depth, bbox.height]) 277 | 278 | # find face in brick 279 | brick = bricks[fkey] 280 | 281 | # generate face plane 282 | vertices = mesh_bottom.face_coordinates(fkey) 283 | face = list(range(len(vertices))) 284 | tri_faces = triangulate_face(face, vertices) 285 | 286 | # take face with largest area 287 | tri_faces = sorted(tri_faces, key=lambda x: area_triangle([vertices[v] for v in x])) 288 | big_face = tri_faces[-1] 289 | big_face_vertices = [vertices[v] for v in big_face] 290 | _, radius = circle_from_points(*big_face_vertices) 291 | 292 | big_face_mesh = Mesh.from_vertices_and_faces(vertices, [big_face]) 293 | 294 | center = centroid_points(big_face_vertices) 295 | normal = normal_triangle(big_face_vertices) 296 | plane = Plane(center, normal) 297 | brick_frame = Frame.from_plane(plane) 298 | 299 | # box again 300 | ratio = (0.4 * radius) / bbox_largest 301 | S = Scale.from_factors(factors=[ratio] * 3) 302 | bbox.transform(S) 303 | 304 | # text again 305 | text_origin = [x for x in bbox.frame.point] 306 | text_frame = Frame(text_origin, [-1.0, 0.0, 0.0], [0.0, -1.0, 0.0]) 307 | text_mesh.transform(S) 308 | T = Transformation.from_frame_to_frame(text_frame, brick_frame) 309 | text_mesh.transform(T) 310 | 311 | B = text_mesh.to_vertices_and_faces(triangulated=True) 312 | B = mesh_remesh(B, radius/30.0, 10) 313 | text_mesh = Mesh.from_vertices_and_faces(*B) 314 | 315 | text_mesh = mesh_thicken(text_mesh, depth) 316 | 317 | vertices, faces = bbox.to_vertices_and_faces() 318 | bbox_mesh = Mesh.from_vertices_and_faces(vertices, faces) 319 | 320 | A = brick.to_vertices_and_faces() 321 | B = text_mesh.to_vertices_and_faces(triangulated=True) 322 | 323 | V, F = boolean_difference_mesh_mesh(A, B) 324 | brick = Mesh.from_vertices_and_faces(V, F) 325 | 326 | return text_mesh, bbox_mesh, big_face_mesh, brick 327 | 328 | 329 | def generate_scaffolding(mesh, thickness): 330 | """ 331 | Generate the scaffolding platform under the bricks. 332 | 333 | Parameters 334 | ---------- 335 | mesh: `compas.datastructures.Mesh` 336 | The shell the scaffolding will support. 337 | thickness: `float` 338 | The scaffolding thickness. 339 | """ 340 | # convert 341 | scaffold_mesh = mesh_offset(mesh, 1.0 * thickness / 2.0) 342 | vertices, faces = scaffold_mesh.to_vertices_and_faces() 343 | 344 | # triangulate existing faces 345 | faces_tri = [] 346 | for face in faces: 347 | faces_tri.extend(triangulate_face(face, vertices)) 348 | 349 | # add boundary face 350 | key_index = scaffold_mesh.key_index() 351 | face_bnd = [key_index[fkey] for fkey in scaffold_mesh.vertices_on_boundary()][::-1] 352 | faces_tri.extend(triangulate_face(face_bnd, vertices)) 353 | 354 | # create mesh 355 | return Mesh.from_vertices_and_faces(vertices, faces_tri) 356 | 357 | 358 | def calculate_vertex_nbr_line(vkey, mesh): 359 | """ 360 | Find the line connecting a boundary vertex to its 'perpendicular' neighbor on the interior. 361 | 362 | Parameters 363 | ---------- 364 | vkey: `int` 365 | The vertex key. 366 | mesh: `compas.datastructures.Mesh` 367 | The mesh. 368 | 369 | Returns 370 | ------- 371 | line: `tuple` of `list` of `float` 372 | The line connecting the vertex to its neighbor. 
373 | """ 374 | # sift neighbors 375 | nbrs_boundary = [] 376 | nbrs_interior = [] 377 | for nkey in mesh.vertex_neighbors(vkey): 378 | if nkey == vkey: 379 | continue 380 | if mesh.is_vertex_on_boundary(nkey): 381 | nbrs_boundary.append(nkey) 382 | else: 383 | nbrs_interior.append(nkey) 384 | # cases 385 | num_nbrs_interior = len(nbrs_interior) 386 | if num_nbrs_interior == 1: 387 | start = mesh.vertex_coordinates(vkey) 388 | end = mesh.vertex_coordinates(nbrs_interior[0]) 389 | line = (start, end) 390 | 391 | elif num_nbrs_interior == 0: 392 | assert len(nbrs_boundary) == 2 393 | start = mesh.vertex_coordinates(vkey) 394 | lines = [] 395 | for nkey in nbrs_boundary: 396 | line = calculate_vertex_nbr_line(nkey, mesh) 397 | lines.append(line) 398 | 399 | assert len(lines) > 0 400 | 401 | end = start 402 | for line in lines: 403 | a, b = line 404 | vector = normalize_vector(subtract_vectors(b, a)) 405 | end = add_vectors(end, vector) 406 | line = (start, end) 407 | else: 408 | raise ValueError 409 | 410 | return line 411 | 412 | 413 | def generate_boundary_support(mesh, thickness): 414 | """ 415 | Generate the mesh of the support ring bearing the bricks at the boundary. 416 | 417 | Parameters 418 | ---------- 419 | mesh: `compas.datastructures.Mesh` 420 | The middle surface of the masonry shell. 421 | thickness: `float` 422 | The thickness of the ring. 423 | 424 | Returns 425 | ------- 426 | support_mesh: `compas.datastructures.Mesh` 427 | The support ring mesh. 428 | """ 429 | thick_bottom, thick_top = calculate_brick_thicknesses(thickness) 430 | 431 | mesh_bottom = mesh_offset(mesh, thick_bottom) 432 | mesh_top = mesh_offset(mesh, thick_top * -2.0) 433 | 434 | # generate xyz polygon bottom 435 | polygon_bottom = [mesh_bottom.vertex_coordinates(vkey) for vkey in mesh.vertices_on_boundary()] 436 | 437 | # generate xyz polygon top 438 | polygon_top = [mesh_top.vertex_coordinates(vkey) for vkey in mesh.vertices_on_boundary()] 439 | 440 | # generate xyz polygon that intersects with ground plane 441 | lines = [] 442 | polygon_offset = [] 443 | polygon_squashed = [] 444 | ground_level = -thick_bottom 445 | plane = Plane([0.0, 0.0, ground_level], [0.0, 0.0, 1.0]) 446 | 447 | for vkey in mesh.vertices_on_boundary(): 448 | 449 | start = mesh_top.vertex_coordinates(vkey) 450 | line = calculate_vertex_nbr_line(vkey, mesh_top) 451 | point = intersection_line_plane(line, plane) 452 | if not point: 453 | raise ValueError("No intersection found!") 454 | 455 | line = Line(start, point) 456 | point = line.point(0.75) 457 | 458 | line = Line(start, point) 459 | lines.append(line) 460 | 461 | polygon_offset.append(point) 462 | point = point[:] 463 | point[2] = ground_level 464 | polygon_squashed.append(point) 465 | 466 | # now, weave polygons to make mesh 467 | polygons = [ 468 | polygon_bottom, 469 | polygon_top, 470 | polygon_offset, 471 | polygon_squashed, 472 | ] 473 | 474 | # add polygons' points to main vertex list 475 | max_vkey = 0 476 | gkey_vkey = {} 477 | points = [] 478 | 479 | for polygon in polygons: 480 | for point in polygon: 481 | gkey = geometric_key(point) 482 | if gkey in gkey_vkey: 483 | continue 484 | points.append(point) 485 | gkey_vkey[gkey] = max_vkey 486 | max_vkey += 1 487 | 488 | # weave faces 489 | faces = [] 490 | for polyline in zip(*(pairwise(polygon) for polygon in polygons)): 491 | for line_a, line_b in pairwise(polyline + polyline[:1]): 492 | a, b = (gkey_vkey[geometric_key(pt)] for pt in line_a) 493 | d, c = (gkey_vkey[geometric_key(pt)] for pt in line_b) 494 | face = [a, 
b, c, d] 495 | faces.append(face) 496 | 497 | support_mesh = Mesh.from_vertices_and_faces(points, faces) 498 | 499 | return support_mesh 500 | 501 | 502 | # =============================================================================== 503 | # Script function 504 | # =============================================================================== 505 | 506 | def brickify( 507 | name, 508 | thickness, 509 | scale=1.0, 510 | dual=True, 511 | do_bricks=True, 512 | do_label=True, 513 | do_scaffold=False, 514 | do_support=False, 515 | save=False 516 | ): 517 | """ 518 | Generate bricks on the faces of a mesh. One face = one brick. 519 | 520 | Parameters 521 | ---------- 522 | name: `str` 523 | The mesh name (without extension). 524 | thickness: `float` 525 | The brick thickness. 526 | scale: `float`, optional 527 | The mesh scale, whether to scale it down or up to fit in a printer's bed. 528 | dual: `bool`, optional 529 | If `True`, the script will work on the dual of the input mesh. 530 | do_bricks: `bool`, optional 531 | If `True`, generate the bricks as closed, watertight meshes. 532 | do_label: `bool`, optional 533 | If `True`, engrave the bricks with labels via mesh boolean differences. 534 | do_scaffold: `bool`, optional 535 | If `True`, generate scaffolding platform. 536 | do_support: `bool`, optional 537 | If `True`, create the perimetral support for the bricks. 538 | save: `bool`, optional 539 | If `True`, save all generated data as both JSON and OBJ files. 540 | """ 541 | CAMERA_CONFIG = { 542 | "position": (30.34, 30.28, 42.94), 543 | "target": (0.956, 0.727, 1.287), 544 | "distance": 20.0, 545 | } 546 | 547 | # load mesh 548 | filepath = os.path.join(DATA, f"{name}.json") 549 | mesh = FDMesh.from_json(filepath) 550 | print(mesh) 551 | 552 | # calculate dual mesh 553 | if dual: 554 | mesh = mesh_dual(mesh, include_boundary=True) 555 | mesh_delete_duplicate_vertices(mesh) 556 | 557 | # scale mesh 558 | if scale != 1.0: 559 | S = Scale.from_factors(factors=[scale] * 3) 560 | mesh.transform(S) 561 | 562 | # do bricks 563 | if do_bricks: 564 | bricks, meshes = generate_bricks(mesh, thickness) 565 | mesh_bottom, mesh_top = meshes 566 | 567 | if save: 568 | filepath = os.path.join(DATA, f"{name}_top.json") 569 | mesh_top.to_json(filepath) 570 | print(f"Saved mesh top to {filepath}") 571 | filepath = os.path.join(DATA, f"{name}_bottom.json") 572 | mesh_bottom.to_json(filepath) 573 | print(f"Saved mesh bottom to {filepath}") 574 | 575 | # generate scaffolding 576 | if do_scaffold: 577 | scaffold_mesh = generate_scaffolding(mesh, thickness) 578 | 579 | # generate perimetral support base 580 | if do_support: 581 | support_mesh = generate_boundary_support(mesh, thickness) 582 | 583 | # engrave labels 584 | if do_bricks and do_label: 585 | for i, fkey in enumerate(mesh.faces()): 586 | data = add_text_engraving( 587 | fkey, 588 | bricks, 589 | mesh_bottom, 590 | f"{i}", 591 | depth=thickness/2.0 592 | ) 593 | text_mesh, bbox_mesh, face_mesh, brick = data 594 | bricks[fkey] = brick 595 | 596 | if save: 597 | if do_bricks: 598 | for i, brick in enumerate(bricks.values()): 599 | filepath = os.path.join(DATA, f"brick_{i}.json") 600 | brick.to_json(filepath) 601 | print(f"Saved brick to {filepath}") 602 | filepath = os.path.join(DATA, f"brick_{i}.obj") 603 | brick.to_obj(filepath) 604 | print(f"Saved brick to {filepath}") 605 | 606 | if do_scaffold: 607 | filepath = os.path.join(DATA, "scaffold.json") 608 | scaffold_mesh.to_json(filepath) 609 | print(f"Saved scaffold to {filepath}") 610 | filepath = 
os.path.join(DATA, "scaffold.obj") 611 | scaffold_mesh.to_obj(filepath) 612 | print(f"Saved scaffold to {filepath}") 613 | 614 | if do_support: 615 | filepath = os.path.join(DATA, "support.json") 616 | support_mesh.to_json(filepath) 617 | print(f"Saved scaffold to {filepath}") 618 | filepath = os.path.join(DATA, "support.obj") 619 | support_mesh.to_obj(filepath) 620 | print(f"Saved scaffold to {filepath}") 621 | 622 | # visualization 623 | viewer = Viewer( 624 | width=900, 625 | height=900, 626 | show_grid=True, 627 | viewmode="lighted" 628 | ) 629 | 630 | # modify view 631 | viewer.view.camera.position = CAMERA_CONFIG["position"] 632 | viewer.view.camera.target = CAMERA_CONFIG["target"] 633 | viewer.view.camera.distance = CAMERA_CONFIG["distance"] 634 | 635 | if do_scaffold: 636 | viewer.add(scaffold_mesh) 637 | 638 | if do_support: 639 | viewer.add(support_mesh) 640 | 641 | if do_bricks: 642 | for fkey, brick in bricks.items(): 643 | r, g, b = [randint(0, 255) for _ in range(3)] 644 | color = Color.from_rgb255(r, g, b) 645 | 646 | viewer.add( 647 | brick, 648 | color=color, 649 | show_points=False, 650 | show_edges=True 651 | ) 652 | 653 | # show le crème 654 | viewer.show() 655 | 656 | 657 | # =============================================================================== 658 | # Main 659 | # =============================================================================== 660 | 661 | if __name__ == "__main__": 662 | 663 | from fire import Fire 664 | 665 | Fire(brickify) 666 | -------------------------------------------------------------------------------- /scripts/camera.py: -------------------------------------------------------------------------------- 1 | """ 2 | Camera configurations to view the outputs of the bezier and tower tasks. 3 | """ 4 | 5 | CAMERA_CONFIG_BEZIER = { 6 | "position": (30.34, 30.28, 42.94), 7 | "target": (0.956, 0.727, 1.287), 8 | "distance": 20.0, 9 | } 10 | 11 | CAMERA_CONFIG_BEZIER_TOP = { 12 | "position": (0.3, 0.85, 7.5), 13 | "target": (0.3, 0.85, 0.000), 14 | "distance": 7.5, 15 | "rotation": (0.000, 0.000, 0.000), 16 | "use_top_view": True 17 | } 18 | 19 | CAMERA_CONFIG_TOWER = { 20 | "position": (10.718, 10.883, 14.159), 21 | "target": (-0.902, -0.873, 3.846), 22 | "distance": 19.482960680274577, 23 | "rotation": (1.013, 0.000, 2.362), 24 | } -------------------------------------------------------------------------------- /scripts/optimize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Optimize force densities to match a batch of target shapes with gradient-based optimization, one shape at a time (no vectorization). 
3 | """ 4 | import os 5 | from math import fabs 6 | import yaml 7 | 8 | import warnings 9 | 10 | import matplotlib.pyplot as plt 11 | 12 | from time import perf_counter 13 | from statistics import mean 14 | from statistics import stdev 15 | 16 | import jax 17 | from jax import vmap 18 | import jax.numpy as jnp 19 | 20 | import jax.random as jrn 21 | 22 | import equinox as eqx 23 | 24 | from jaxopt import ScipyBoundedMinimize 25 | 26 | from compas.colors import Color 27 | from compas.colors import ColorMap 28 | from compas.geometry import Polyline 29 | from compas.utilities import remap_values 30 | 31 | from jax_fdm.datastructures import FDNetwork 32 | from jax_fdm.equilibrium import datastructure_updated 33 | from jax_fdm.visualization import Viewer 34 | 35 | from neural_fdm import DATA 36 | 37 | from neural_fdm.builders import build_mesh_from_generator 38 | from neural_fdm.builders import build_data_generator 39 | from neural_fdm.builders import build_connectivity_structure_from_generator 40 | from neural_fdm.builders import build_fd_decoder_parametrized 41 | from neural_fdm.builders import build_loss_function 42 | 43 | from neural_fdm.losses import print_loss_summary 44 | 45 | from camera import CAMERA_CONFIG_BEZIER 46 | from camera import CAMERA_CONFIG_TOWER 47 | 48 | from shapes import BEZIERS 49 | 50 | 51 | # =============================================================================== 52 | # Script function 53 | # =============================================================================== 54 | 55 | def optimize_batch( 56 | optimizer, 57 | task_name, 58 | shape_name=None, 59 | param_init=None, 60 | blow=0.0, 61 | bup=20.0, 62 | maxiter=5000, 63 | tol=1e-6, 64 | seed=None, 65 | batch_size=None, 66 | slice=(0, -1), 67 | save=False, 68 | view=False, 69 | edgecolor=None, 70 | show_reactions=False, 71 | edgewidth=(0.01, 0.25), 72 | fmax=None, 73 | fmax_tens=None, 74 | fmax_comp=None, 75 | qmin=None, 76 | qmax=None, 77 | verbose=True, 78 | record=False, 79 | save_metrics=False, 80 | ): 81 | """ 82 | Solve the prediction task on a batch target shapes with gradient-based optimization and box constraints. 83 | The box constraints help generating compression-only or tension-only solutions. 84 | 85 | This script optimizes and visualizes. This is probably not the best idea, but oh well. 86 | 87 | Parameters 88 | ---------- 89 | optimizer: `str` 90 | The name gradient-based optimizer used to solve this task. 91 | Supported methods are slsqp and lbfgsb. 92 | task_name: `str` 93 | The name of the YAML config file with the task hyperparameters. 94 | shape_name: `str` or `None`, optional 95 | The name of the shape to optimize. 96 | Supported shell shapes are pillow, dome, saddle, hypar, pringle, and cannon; 97 | and require of a `bezier_symmetric_double` generator. 98 | Supported tower shapes are either named by an integer or a float scalar. 99 | If the name is an integer, the generator should be `tower_ellipse`, and `tower_circle` if the name is a float. 100 | In general, if a name is provided, the optimization is performed on this shape, ignoring the batch. 101 | param_init: `float` or `None`, optional 102 | If specified, it determines the starting value of all the model parameters. 103 | If `None`, then it samples parameters between `blow` and `bup` from a uniform distribution. 104 | The sampling respects the force density signs of a task (compression or tension, currently hardcoded). 105 | blow: `float`, optional 106 | The lower bound of the box constraints on the model parameters. 
107 |         The bounds respect the force density signs of a task (compression or tension, currently hardcoded).
108 |     bup: `float`, optional
109 |         The upper bound of the box constraints on the model parameters.
110 |         The bounds respect the force density signs of a task (compression or tension, currently hardcoded).
111 |     maxiter: `int`, optional
112 |         The maximum number of optimization iterations.
113 |     tol: `float`, optional
114 |         The tolerance for the optimization.
115 |     seed: `int` or `None`, optional
116 |         The random seed to generate a batch of target shapes.
117 |         If `None`, it defaults to the task hyperparameters file.
118 |     batch_size: `int` or `None`, optional
119 |         The size of the batch of target shapes.
120 |         If `None`, it defaults to the task hyperparameters file.
121 |     slice: `tuple`, optional
122 |         The start and stop indices of the slice of the batch for saving and viewing.
123 |         Defaults to all the shapes in the batch.
124 |     save: `bool`, optional
125 |         If `True`, save the predicted shapes as JSON files.
126 |     view: `bool`, optional
127 |         If `True`, view the predicted shapes.
128 |     show_reactions: `bool`, optional
129 |         If `True`, show the reactions on the predicted shapes upon display.
130 |     edgewidth: `tuple`, optional
131 |         The minimum and maximum width of the edges for visualization.
132 |     fmax: `float` or `None`, optional
133 |         The maximum force for the visualization.
134 |     fmax_tens: `float` or `None`, optional
135 |         The maximum tensile force for the visualization.
136 |     fmax_comp: `float` or `None`, optional
137 |         The maximum compressive force for the visualization.
138 |     qmin: `float` or `None`, optional
139 |         The minimum force density for the visualization.
140 |     qmax: `float` or `None`, optional
141 |         The maximum force density for the visualization.
142 |     verbose: `bool`, optional
143 |         If `True`, print intermediate results to stdout.
144 |     record: `bool`, optional
145 |         If `True`, record the loss history.
146 |     edgecolor: `str` or `None`, optional
147 |         The color palette for the edges.
148 |         Supported color palettes are `fd` to display force densities, and `force` to show forces.
149 |         If `None`, the edges are colored by the force density in the shells task, and by the force in the towers task.
150 |     save_metrics: `bool`, optional
151 |         If `True`, saves the calculated batch metrics in text files.
152 | """ 153 | if edgecolor is None: 154 | if task_name == "bezier": 155 | EDGECOLOR = "fd" 156 | elif task_name == "tower": 157 | EDGECOLOR = "force" 158 | else: 159 | EDGECOLOR = edgecolor 160 | 161 | START, STOP = slice 162 | SAVE = save 163 | QMIN = blow 164 | QMAX = bup 165 | EDGEWIDTH = edgewidth 166 | 167 | # pick camera configuration for task 168 | if task_name == "bezier": 169 | CAMERA_CONFIG = CAMERA_CONFIG_BEZIER 170 | _width = 900 171 | elif task_name == "tower": 172 | CAMERA_CONFIG = CAMERA_CONFIG_TOWER 173 | _width = 450 174 | 175 | # pick optimizer name 176 | optimizer_names = {"lbfgsb": "L-BFGS-B", "slsqp": "SLSQP"} 177 | optimizer_name = optimizer_names[optimizer] 178 | 179 | # load yaml file with hyperparameters 180 | with open(f"{task_name}.yml") as file: 181 | config = yaml.load(file, Loader=yaml.FullLoader) 182 | 183 | # unpack parameters 184 | if seed is None: 185 | seed = config["seed"] 186 | training_params = config["training"] 187 | if batch_size is None: 188 | batch_size = training_params["batch_size"] 189 | 190 | generator_name = config['generator']['name'] 191 | bounds_name = config['generator']['bounds'] 192 | fd_params = config["fdm"] 193 | 194 | # randomness 195 | key = jrn.PRNGKey(seed) 196 | _, generator_key = jax.random.split(key, 2) 197 | 198 | # create data generator 199 | generator = build_data_generator(config) 200 | compute_loss = build_loss_function(config, generator) 201 | structure = build_connectivity_structure_from_generator(config, generator) 202 | mesh = build_mesh_from_generator(config, generator) 203 | 204 | # generate initial model parameters 205 | q0 = calculate_params_init(mesh, param_init, key, QMIN, QMAX) 206 | 207 | # create model 208 | print(f"Directly optimizing with {optimizer_name} for {generator_name} dataset with {bounds_name} bounds on seed {seed}") 209 | model = build_fd_decoder_parametrized(q0, mesh, fd_params) 210 | 211 | # sample data batch 212 | xyz_batch = vmap(generator)(jrn.split(generator_key, batch_size)) 213 | 214 | # split model 215 | diff_model, static_model = eqx.partition(model, eqx.is_inexact_array) 216 | 217 | # wrap loss function to meet jax and jaxopt's ideosyncracies 218 | @eqx.filter_jit 219 | @eqx.debug.assert_max_traces(max_traces=1) # ensure this function is compiled at most once 220 | @eqx.filter_value_and_grad 221 | def compute_loss_diffable(diff_model, xyz_target): 222 | """ 223 | """ 224 | _model = eqx.combine(diff_model, static_model) 225 | return compute_loss(_model, structure, xyz_target, aux_data=False) 226 | 227 | # warmstart loss function to eliminate jit compilation time from perf measurements 228 | start_time = perf_counter() 229 | _ = compute_loss_diffable(diff_model, xyz_target=xyz_batch[None, 0]) 230 | end_time = perf_counter() - start_time 231 | print(f"JIT compilation time (loss): {end_time:.4f} s") 232 | 233 | # define callback function 234 | history = [] 235 | recorder = lambda x: history.append(x) if record else None 236 | 237 | opt = ScipyBoundedMinimize( 238 | fun=compute_loss_diffable, 239 | method=optimizer_name, 240 | jit=True, 241 | tol=tol, 242 | maxiter=maxiter, 243 | options={"disp": False}, 244 | value_and_grad=True, 245 | callback=recorder 246 | ) 247 | 248 | # disable scipy warnings about hitting the box constraints 249 | warnings.filterwarnings("ignore") 250 | 251 | # define parameter bounds 252 | bound_low, bound_up = calculate_params_bounds(mesh, q0, QMIN, QMAX) 253 | bound_low_tree = eqx.tree_at(lambda tree: tree.q, diff_model, replace=(bound_low)) 254 | bound_up_tree = 
eqx.tree_at(lambda tree: tree.q, diff_model, replace=(bound_up)) 255 | bounds = (bound_low_tree, bound_up_tree) 256 | 257 | # optimize 258 | print("\nOptimizing shapes in sequence") 259 | qs = [] 260 | opt_times = [] 261 | loss_terms_batch = [] 262 | 263 | were_successful = 0 264 | if STOP == -1: 265 | STOP = batch_size 266 | 267 | xyz_slice = xyz_batch[START:STOP] 268 | 269 | # sample target points from prescribed shape name 270 | if shape_name is not None and "bezier" in task_name: 271 | transform = BEZIERS[shape_name] 272 | transform = jnp.array(transform) 273 | xyz = generator.evaluate_points(transform) 274 | xyz_slice = xyz[None, :] 275 | 276 | # Warmstart optimization 277 | _xyz_ = xyz_batch[0][None, :] 278 | start_time = perf_counter() 279 | diff_model_opt, opt_res = opt.run(diff_model, bounds, _xyz_) 280 | end_time = perf_counter() - start_time 281 | print(f"\tJIT compilation time (optimizer): {end_time:.4f} s") 282 | 283 | num_opts = 0 284 | for i, xyz in enumerate(xyz_slice): 285 | 286 | num_opts += 1 287 | xyz = xyz[None, :] 288 | 289 | # report start losses 290 | _, loss_terms = compute_loss(model, structure, xyz, aux_data=True) 291 | if verbose: 292 | print(f"\nShape {i + 1}") 293 | print_loss_summary(loss_terms, prefix="\tStart") 294 | 295 | # optimize 296 | start_time = perf_counter() 297 | diff_model_opt, opt_res = opt.run(diff_model, bounds, xyz) 298 | opt_time = perf_counter() - start_time 299 | 300 | # unite optimal and static submodels 301 | model_opt = eqx.combine(diff_model_opt, static_model) 302 | 303 | # assemble datastructure for post-processing 304 | eqstate_hat, fd_params_hat = model_opt.predict_states(xyz, structure) 305 | mesh_hat = datastructure_updated(mesh, eqstate_hat, fd_params_hat) 306 | network_hat = FDNetwork.from_mesh(mesh_hat) 307 | 308 | # evaluate loss function at optimum point 309 | _, loss_terms = compute_loss(model_opt, structure, xyz, aux_data=True) 310 | 311 | if verbose: 312 | print_loss_summary(loss_terms, prefix="\tEnd") 313 | print(f"\tOpt success?: {opt_res.success}") 314 | print(f"\tOpt iters: {opt_res.iter_num}") 315 | print(f"\tOpt time: {opt_time:.4f} sec") 316 | 317 | if record: 318 | _losses = [] 319 | for xk in history: 320 | _loss, _ = compute_loss_diffable(xk, xyz) 321 | _losses.append(_loss) 322 | 323 | plt.figure() 324 | plt.plot(jnp.array(_losses)) 325 | plt.xlabel("Steps") 326 | plt.ylabel("Loss") 327 | plt.yscale("log") 328 | plt.grid() 329 | plt.show() 330 | 331 | if opt_res.success: 332 | were_successful += 1 333 | 334 | qs.extend([_q.item() for _q in fd_params_hat.q]) 335 | opt_times.append(opt_time) 336 | loss_terms_batch.append(loss_terms) 337 | 338 | # export prediction 339 | if SAVE: 340 | filename = f"mesh_{i}" 341 | filepath = os.path.join(DATA, f"{filename}.json") 342 | mesh_hat.to_json(filepath) 343 | print(f"Saved prediction to {filepath}") 344 | 345 | # visualization 346 | if view: 347 | # create target mesh 348 | mesh_target = mesh.copy() 349 | _xyz = jnp.reshape(xyz, (-1, 3)).tolist() 350 | for idx, key in mesh.index_key().items(): 351 | mesh_target.vertex_attributes(key, "xyz", _xyz[idx]) 352 | 353 | viewer = Viewer( 354 | width=_width, 355 | height=900, 356 | show_grid=False, 357 | viewmode="lighted" 358 | ) 359 | 360 | # modify view 361 | viewer.view.camera.position = CAMERA_CONFIG["position"] 362 | viewer.view.camera.target = CAMERA_CONFIG["target"] 363 | viewer.view.camera.distance = CAMERA_CONFIG["distance"] 364 | _rotation = CAMERA_CONFIG.get("rotation") 365 | if _rotation: 366 | 
viewer.view.camera.rotation = _rotation 367 | 368 | # edges to view 369 | # NOTE: we are not visualizing edges on boundaries since they are supported 370 | edges_2_view = [edge for edge in mesh.edges() if not mesh.is_edge_on_boundary(*edge)] 371 | 372 | # compute stats 373 | forces_all = [] 374 | forces_comp_all = [] 375 | forces_tens_all = [] 376 | 377 | for edge in network_hat.edges(): 378 | 379 | force = network_hat.edge_force(edge) 380 | force_abs = fabs(force) 381 | forces_all.append(force_abs) 382 | if force <= 0.0: 383 | forces_comp_all.append(force_abs) 384 | else: 385 | forces_tens_all.append(force_abs) 386 | 387 | fmin = 0.0 388 | fmin_comp = 0.0 389 | fmin_tens = 0.0 390 | if fmax is None: 391 | fmax = max(forces_all) 392 | if fmax_tens is None: 393 | if forces_tens_all: 394 | fmax_tens = max(forces_tens_all) 395 | else: 396 | fmax_tens = 0.0 397 | if fmax_comp is None: 398 | if forces_comp_all: 399 | fmax_comp = max(forces_comp_all) 400 | else: 401 | fmax_comp = 0.0 402 | 403 | # edge width 404 | width_min, width_max = EDGEWIDTH 405 | _forces = [fabs(mesh_hat.edge_force(edge)) for edge in mesh_hat.edges()] 406 | _forces = remap_values(_forces, original_min=fmin, original_max=fmax) 407 | _widths = remap_values(_forces, width_min, width_max) 408 | edgewidth = {edge: width for edge, width in zip(mesh_hat.edges(), _widths)} 409 | 410 | # edge colors 411 | edgecolor = EDGECOLOR 412 | if edgecolor == "force": 413 | edgecolor = {} 414 | 415 | color_start = Color.white() 416 | color_comp_end = Color.from_rgb255(12, 119, 184) 417 | cmap_comp = ColorMap.from_two_colors(color_start, color_comp_end) 418 | color_tens_end = Color.from_rgb255(227, 6, 75) 419 | cmap_tens = ColorMap.from_two_colors(color_start, color_tens_end) 420 | 421 | for edge in mesh_hat.edges(): 422 | 423 | force = mesh_hat.edge_force(edge) 424 | 425 | if force == 0.0: 426 | edgecolor[edge] = color_start 427 | else: 428 | if force < 0.0: 429 | _cmap = cmap_comp 430 | _fmin = fmin_comp 431 | _fmax = fmax_comp 432 | else: 433 | _cmap = cmap_tens 434 | _fmin = fmin_tens 435 | _fmax = fmax_tens 436 | 437 | value = (fabs(force) - _fmin) / (_fmax - _fmin) 438 | edgecolor[edge] = _cmap(value) 439 | 440 | viewer.add( 441 | network_hat, 442 | edgewidth=edgewidth, 443 | edgecolor=edgecolor, 444 | show_edges=True, 445 | edges=edges_2_view, 446 | nodes=[node for node in mesh.vertices() if len(mesh.vertex_neighbors(node)) > 2], 447 | show_loads=False, 448 | loadscale=1.0, 449 | show_reactions=show_reactions, 450 | reactionscale=1.0, 451 | reactioncolor=Color.from_rgb255(0, 150, 10), 452 | ) 453 | 454 | viewer.add( 455 | mesh_hat, 456 | show_points=False, 457 | show_edges=False, 458 | opacity=0.7, 459 | color=Color.grey().lightened(100), 460 | ) 461 | 462 | for _vertices in mesh.vertices_on_boundaries(): 463 | viewer.add( 464 | Polyline([mesh_hat.vertex_coordinates(vkey) for vkey in _vertices]), 465 | linewidth=4.0, 466 | color=Color.black().lightened() 467 | ) 468 | 469 | # show le crème 470 | viewer.show() 471 | 472 | # report statistics 473 | print(f"\nSuccessful optimizations: {were_successful}/{num_opts}") 474 | if num_opts > 1: 475 | print(f"Optimization time over {num_opts} optimizations (s): {mean(opt_times):.4f} (+-{stdev(opt_times):.4f})") 476 | 477 | labels = loss_terms_batch[0].keys() 478 | for label in labels: 479 | errors = [terms[label].item() for terms in loss_terms_batch] 480 | print(f"{label.capitalize()} over {num_opts} optimizations: {mean(errors):.4f} (+-{stdev(errors):.4f})") 481 | 482 | if save_metrics: 483 | # 
Export force densities 484 | filename = f"{optimizer}_{task_name}_q_eval.txt" 485 | filepath = os.path.join(DATA, filename) 486 | 487 | metrics = [f"{_q}\n" for _q in qs] 488 | with open(filepath, 'w') as output: 489 | output.writelines(metrics) 490 | 491 | print(f"Saved batch qs to {filepath}") 492 | 493 | 494 | # =============================================================================== 495 | # Helper functions 496 | # =============================================================================== 497 | 498 | def calculate_params_init(mesh, param_init, key, minval, maxval): 499 | """ 500 | Calculate the initial force densities for the optimization. 501 | 502 | Parameters 503 | ---------- 504 | mesh: `compas.datastructures.Mesh` 505 | The mesh to optimize. 506 | param_init: `float` or `None` 507 | If specified, it determines the starting value of all the model parameters. 508 | If `None`, then it samples parameters between `b_low` and `b_up` from a uniform distribution. 509 | key: `jax.random.PRNGKey` 510 | The random seed for the uniform distribution. 511 | minval: `float` 512 | The minimum value for the uniform distribution. 513 | maxval: `float` 514 | The maximum value for the uniform distribution. 515 | 516 | Returns 517 | ------- 518 | q0: `jax.numpy.ndarray` 519 | The initial force densities. 520 | """ 521 | num_edges = mesh.number_of_edges() 522 | 523 | signs = [] 524 | for edge in mesh.edges(): 525 | sign = -1.0 # compression by default 526 | # for tower task 527 | # FIXME: this method of checking is hand-wavy! 528 | if mesh.edge_attribute(edge, "tag") == "cable": 529 | sign = 1.0 530 | signs.append(sign) 531 | 532 | signs = jnp.array(signs) 533 | 534 | if param_init is not None: 535 | q0 = jnp.ones(num_edges) * param_init 536 | else: 537 | q0 = jrn.uniform(key, shape=(num_edges, ), minval=minval, maxval=maxval) 538 | 539 | return q0 * signs 540 | 541 | 542 | def calculate_params_bounds(mesh, q0, minval, maxval): 543 | """ 544 | Calculate the box constraints for the optimization. 545 | 546 | Parameters 547 | ---------- 548 | mesh: `compas.datastructures.Mesh` 549 | The mesh to optimize. 550 | q0: `jax.numpy.ndarray` 551 | The initial force densities. 552 | minval: `float` 553 | The value of the lower bound. 554 | maxval: `float` 555 | The value of the upper bound. 556 | 557 | Returns 558 | ------- 559 | bound_low: `jax.numpy.ndarray` 560 | The lower box constraint. 561 | bound_up: `jax.numpy.ndarray` 562 | The upper box constraint. 
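    Notes
    -----
    A minimal usage sketch with illustrative values, assuming a mesh whose cable
    edges carry the attribute ``tag == "cable"`` (as in the tower task):

        bound_low, bound_up = calculate_params_bounds(mesh, q0, 0.0, 20.0)
        # compression edge -> bounds (-20.0, -0.0)
        # cable edge       -> bounds (0.0, 20.0)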
563 | """ 564 | bound_low = [] 565 | bound_up = [] 566 | for edge in mesh.edges(): 567 | # compression by default 568 | b_low = maxval * -1.0 569 | b_up = minval * -1.0 570 | # for tower task 571 | if mesh.edge_attribute(edge, "tag") == "cable": 572 | b_low = minval 573 | b_up = maxval 574 | 575 | bound_low.append(b_low) 576 | bound_up.append(b_up) 577 | 578 | bound_low = jnp.array(bound_low) 579 | bound_up = jnp.array(bound_up) 580 | 581 | return bound_low, bound_up 582 | 583 | 584 | # =============================================================================== 585 | # Main 586 | # =============================================================================== 587 | 588 | if __name__ == "__main__": 589 | 590 | from fire import Fire 591 | 592 | Fire(optimize_batch) 593 | -------------------------------------------------------------------------------- /scripts/predict.py: -------------------------------------------------------------------------------- 1 | """ 2 | Predict the force densities and shapes of a batch of target shapes with a pretrained model. 3 | """ 4 | import os 5 | from functools import partial 6 | from math import fabs 7 | import yaml 8 | 9 | from time import perf_counter 10 | from statistics import mean 11 | from statistics import stdev 12 | 13 | import jax 14 | from jax import jit 15 | from jax import vmap 16 | import jax.numpy as jnp 17 | 18 | import jax.random as jrn 19 | 20 | import equinox as eqx 21 | 22 | from compas.colors import Color 23 | from compas.colors import ColorMap 24 | from compas.geometry import Polygon 25 | from compas.geometry import Line 26 | from compas.utilities import remap_values 27 | 28 | from jax_fdm.datastructures import FDNetwork 29 | from jax_fdm.equilibrium import datastructure_updated 30 | from jax_fdm.visualization import Viewer 31 | 32 | from neural_fdm import DATA 33 | 34 | from neural_fdm.builders import build_loss_function 35 | from neural_fdm.builders import build_mesh_from_generator 36 | from neural_fdm.builders import build_data_generator 37 | from neural_fdm.builders import build_connectivity_structure_from_generator 38 | from neural_fdm.builders import build_neural_model 39 | 40 | from neural_fdm.losses import print_loss_summary 41 | 42 | from neural_fdm.serialization import load_model 43 | 44 | from camera import CAMERA_CONFIG_BEZIER 45 | from camera import CAMERA_CONFIG_TOWER 46 | 47 | from train import count_model_params 48 | 49 | 50 | # =============================================================================== 51 | # Script function 52 | # =============================================================================== 53 | 54 | def predict_batch( 55 | model_name, 56 | task_name, 57 | seed=None, 58 | batch_size=None, 59 | time_batch_inference=False, 60 | predict_in_sequence=True, 61 | slice=(0, -1), # (50, 53) for bezier 62 | view=False, 63 | save=False, 64 | save_metrics=False, 65 | edgecolor=None, 66 | ): 67 | """ 68 | Predict a batch of target shapes with a pretrained model. 69 | 70 | Parameters 71 | ---------- 72 | model_name: `str` 73 | The model name. 74 | Supported models are formfinder, autoencoder, and piggy. 75 | Append the suffix `_pinn` to load model versions that were trained with a PINN loss. 76 | task_name: `str` 77 | The name of the YAML config file with the task hyperparameters. 78 | seed: `int` or `None`, optional 79 | The random seed to generate a batch of target shapes. 80 | If `None`, it defaults to the task hyperparameters file. 
81 | batch_size: `int` or `None`, optional 82 | The size of the batch of target shapes. 83 | If `None`, it defaults to the task hyperparameters file. 84 | time_batch_inference: `bool`, optional 85 | If `True`, report the inference time over a data batch. 86 | predict_in_sequence: `bool`, optional 87 | If `True`, predict every shape in the prescribed slice of the data batch, one at a time. 88 | slice: `tuple`, optional 89 | The start and stop indices of the slice of the batch for saving and viewing. 90 | Defaults to all the shapes in the batch. 91 | view: `bool`, optional 92 | If `True`, view the predicted shapes. 93 | save: `bool`, optional 94 | If `True`, saves the predicted shapes as JSON files. 95 | save_metrics: `bool`, optional 96 | If `True`, saves the calculated batch metrics in text files. 97 | edgecolor: `str` or `None`, optional 98 | The color palette for the edges. 99 | Supported color palettes are fd to display force densities, and force to show forces. 100 | If `None`, the edges are colored by the force density in the shells tasks, and by the force in the tower tasks. 101 | """ 102 | if edgecolor is None: 103 | if task_name == "bezier": 104 | EDGECOLOR = "fd" 105 | elif task_name == "tower": 106 | EDGECOLOR = "force" 107 | else: 108 | EDGECOLOR = edgecolor 109 | 110 | START, STOP = slice 111 | 112 | # load yaml file with hyperparameters 113 | with open(f"{task_name}.yml") as file: 114 | config = yaml.load(file, Loader=yaml.FullLoader) 115 | 116 | # unpack parameters 117 | if seed is None: 118 | seed = config["seed"] 119 | training_params = config["training"] 120 | if batch_size is None: 121 | batch_size = training_params["batch_size"] 122 | 123 | if STOP == -1: 124 | STOP = batch_size 125 | 126 | generator_name = config['generator']['name'] 127 | bounds_name = config['generator']['bounds'] 128 | 129 | # randomness 130 | key = jrn.PRNGKey(seed) 131 | model_key, generator_key = jax.random.split(key, 2) 132 | 133 | # create data generator 134 | generator = build_data_generator(config) 135 | structure = build_connectivity_structure_from_generator(config, generator) 136 | mesh = build_mesh_from_generator(config, generator) 137 | compute_loss = build_loss_function(config, generator) 138 | 139 | # print info 140 | print(f"Making predictions with {model_name} on {generator_name} dataset with {bounds_name} bounds\n") 141 | print(f"Structure size: {structure.num_vertices} vertices, {structure.num_edges} edges") 142 | 143 | # load model 144 | filepath = os.path.join(DATA, f"{model_name}_{task_name}.eqx") 145 | _model_name = model_name.split("_")[0] 146 | model_skeleton = build_neural_model(_model_name, config, generator, model_key) 147 | model = load_model(filepath, model_skeleton) 148 | print(f"Model parameter count: {count_model_params(model)}") 149 | 150 | # sample data batch 151 | xyz_batch = vmap(generator)(jrn.split(generator_key, batch_size)) 152 | 153 | # inference function to time 154 | timed_fn = jit(vmap(partial(model, structure=structure))) 155 | 156 | # NOTE: Using eqx.debug to ensure this function is compiled at most once 157 | timed_fn = vmap(partial(model, structure=structure)) 158 | timed_fn = eqx.debug.assert_max_traces(timed_fn, max_traces=1) 159 | timed_fn = jit(timed_fn) 160 | 161 | # time inference time on full batch 162 | if time_batch_inference: 163 | 164 | # warmstart 165 | timed_fn(xyz_batch) 166 | 167 | # time 168 | times = [] 169 | for i in range(10): 170 | start = perf_counter() 171 | timed_fn(xyz_batch).block_until_ready() 172 | duration = 1000.0 * 
(perf_counter() - start) # time in milliseconds 173 | times.append(duration) 174 | print(f"Inference time on batch size {batch_size}: {mean(times):.5f} (+-{stdev(times):.5f}) ms") 175 | 176 | # report batch losses 177 | _, loss_terms = compute_loss(model, structure, xyz_batch, aux_data=True) 178 | print_loss_summary(loss_terms, prefix="Batch\t") 179 | 180 | # make individual predictions 181 | if not predict_in_sequence: 182 | return 183 | 184 | print("\nPredicting shapes in sequence") 185 | qs = [] 186 | opt_times = [] 187 | loss_terms_batch = [] 188 | num_predictions = 0 189 | 190 | # warmstart again, just in case 191 | _xyz_ = xyz_batch[0][None, :] 192 | start_time = perf_counter() 193 | timed_fn(_xyz_).block_until_ready() 194 | end_time = perf_counter() - start_time 195 | print(f"JIT compilation time: {end_time * 1000.0:.2f} ms") 196 | 197 | start = perf_counter() 198 | for i in range(START, STOP): 199 | 200 | xyz = xyz_batch[i] 201 | _xyz = xyz[None, :] 202 | 203 | # do inference on one design 204 | start_time = perf_counter() 205 | timed_fn(_xyz).block_until_ready() 206 | end_time = perf_counter() - start_time # time in seconds 207 | opt_times.append(end_time) 208 | num_predictions += 1 209 | 210 | # calculate loss 211 | _, loss_terms = compute_loss( 212 | model, 213 | structure, 214 | xyz[None, :], 215 | aux_data=True 216 | ) 217 | 218 | # predict equilibrium states for viz and i/o 219 | eqstate_hat, fd_params_hat = model.predict_states(xyz, structure) 220 | mesh_hat = datastructure_updated(mesh, eqstate_hat, fd_params_hat) 221 | 222 | # print loss statistics 223 | loss_terms_batch.append(loss_terms) 224 | loss_terms["time"] = jnp.array([end_time]) 225 | print_loss_summary(loss_terms, prefix=f"Shape {i} \t") 226 | 227 | # extract additional statistics 228 | qs.extend([_q.item() for _q in fd_params_hat.q]) 229 | 230 | if view or save: 231 | # assemble datastructure for post-processing 232 | network_hat = FDNetwork.from_mesh(mesh_hat) 233 | network_hat.print_stats() 234 | print() 235 | 236 | # export prediction 237 | if save: 238 | filename = f"mesh_{model_name}_{task_name}_{i}" 239 | filepath = os.path.join(DATA, f"{filename}.json") 240 | mesh_hat.to_json(filepath) 241 | print(f"Saved prediction to {filepath}") 242 | 243 | # Create target mesh 244 | mesh_target = mesh.copy() 245 | _xyz = jnp.reshape(xyz, (-1, 3)).tolist() 246 | for idx, key in mesh.index_key().items(): 247 | mesh_target.vertex_attributes(key, "xyz", _xyz[idx]) 248 | 249 | # visualization 250 | if view: 251 | # pick camera configuration for task 252 | if task_name == "bezier": 253 | _width = 900 254 | CAMERA_CONFIG = CAMERA_CONFIG_BEZIER 255 | elif task_name == "tower": 256 | _width = 450 257 | CAMERA_CONFIG = CAMERA_CONFIG_TOWER 258 | 259 | viewer = Viewer( 260 | width=_width, 261 | height=900, 262 | show_grid=False, 263 | viewmode="lighted" 264 | ) 265 | 266 | # modify view 267 | viewer.view.camera.position = CAMERA_CONFIG["position"] 268 | viewer.view.camera.target = CAMERA_CONFIG["target"] 269 | viewer.view.camera.distance = CAMERA_CONFIG["distance"] 270 | 271 | # edge colors 272 | if EDGECOLOR == "force": 273 | 274 | color_end = Color.from_rgb255(12, 119, 184) 275 | color_start = Color.white() 276 | cmap = ColorMap.from_two_colors(color_start, color_end) 277 | 278 | edgecolor = {} 279 | forces = [fabs(network_hat.edge_force(edge)) for edge in network_hat.edges()] 280 | fmin = min(forces) 281 | fmax = max(forces) 282 | 283 | for edge in network_hat.edges(): 284 | force = network_hat.edge_force(edge) * -1.0 285 | if 
force < 0.0: 286 | _color = Color.from_rgb255(227, 6, 75) 287 | else: 288 | value = (force - fmin) / (fmax - fmin) 289 | _color = cmap(value) 290 | 291 | edgecolor[edge] = _color 292 | 293 | elif task_name == "tower" and EDGECOLOR == "fd": 294 | edgecolor = {} 295 | cmap = ColorMap.from_mpl("viridis") 296 | _edges = [edge for edge in network_hat.edges() if mesh.edge_attribute(edge, "tag") == "cable"] 297 | values = [fabs(mesh_hat.edge_forcedensity(edge)) for edge in _edges] 298 | ratios = remap_values(values) 299 | edgecolor = {edge: cmap(ratio) for edge, ratio in zip(_edges, ratios)} 300 | for edge in network_hat.edges(): 301 | if mesh.edge_attribute(edge, "tag") != "cable": 302 | edgecolor[edge] = Color.pink() 303 | 304 | else: 305 | edgecolor = EDGECOLOR 306 | 307 | vertices_2_view = list(mesh.vertices()) 308 | color_load = Color.from_rgb255(0, 150, 10) 309 | reactioncolor = color_load 310 | show_reactions = True 311 | if task_name == "bezier": 312 | if EDGECOLOR == "fd": 313 | show_reactions = False 314 | vertices_2_view = [] 315 | for vkey in mesh.vertices(): 316 | if len(mesh.vertex_neighbors(vkey)) < 3: 317 | continue 318 | vertices_2_view.append(vkey) 319 | 320 | reactioncolor = {} 321 | for vkey in mesh.vertices(): 322 | _color = Color.pink() 323 | if mesh.is_vertex_on_boundary(vkey): 324 | _color = color_load 325 | reactioncolor[vkey] = _color 326 | 327 | viewer.add( 328 | network_hat, 329 | edgewidth=(0.01, 0.25), 330 | edgecolor=edgecolor, 331 | show_edges=True, 332 | edges=[edge for edge in mesh.edges() if not mesh.is_edge_on_boundary(*edge)], 333 | nodes=vertices_2_view, 334 | show_loads=False, 335 | loadscale=1.0, 336 | show_reactions=show_reactions, 337 | reactionscale=1.0, 338 | reactioncolor=reactioncolor 339 | ) 340 | 341 | if task_name == "bezier": 342 | # approximated mesh 343 | viewer.add( 344 | mesh_hat, 345 | show_points=False, 346 | show_edges=False, 347 | opacity=0.1 348 | ) 349 | 350 | # target mesh 351 | viewer.add( 352 | FDNetwork.from_mesh(mesh_target), 353 | as_wireframe=True, 354 | show_points=False, 355 | linewidth=4.0, 356 | color=Color.black().lightened() 357 | ) 358 | 359 | elif task_name == "tower": 360 | rings = jnp.reshape(xyz, generator.shape_tube)[generator.levels_rings_comp, :, :] 361 | for ring in rings: 362 | ring = Polygon(ring.tolist()) 363 | viewer.add(ring, opacity=0.5) 364 | 365 | xyz_hat = model(xyz, structure) 366 | rings_hat = jnp.reshape(xyz_hat, generator.shape_tube)[generator.levels_rings_comp, :, :] 367 | for ring_a, ring_b in zip(rings, rings_hat): 368 | for pt_a, pt_b in zip(ring_a, ring_b): 369 | viewer.add(Line(pt_a, pt_b)) 370 | 371 | # show le crème 372 | viewer.show() 373 | 374 | # report statistics 375 | opt_times = [t * 1000.0 for t in opt_times] # Convert seconds to milliseconds 376 | print(f"Inference time over {num_predictions} samples (ms): {mean(opt_times):.4f} (+-{stdev(opt_times):.4f})") 377 | 378 | labels = loss_terms_batch[0].keys() 379 | for label in labels: 380 | errors = [terms[label].item() for terms in loss_terms_batch] 381 | print(f"{label.capitalize()} over {num_predictions} samples: {mean(errors):.4f} (+-{stdev(errors):.4f})") 382 | 383 | if save_metrics: 384 | # Export force densities 385 | filename = f"{model_name}_{task_name}_q_eval.txt" 386 | filepath = os.path.join(DATA, filename) 387 | 388 | metrics = [f"{_q}\n" for _q in qs] 389 | with open(filepath, 'w') as output: 390 | output.writelines(metrics) 391 | 392 | print(f"Saved batch qs to {filepath}") 393 | 394 | 395 | # 
=============================================================================== 396 | # Main 397 | # =============================================================================== 398 | 399 | if __name__ == "__main__": 400 | 401 | from fire import Fire 402 | 403 | Fire(predict_batch) 404 | -------------------------------------------------------------------------------- /scripts/predict_optimize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Optimize the force densities and shapes of a batch of target shapes starting from a pretrained model predictions (no vectorization). 3 | """ 4 | import os 5 | from math import fabs 6 | import yaml 7 | 8 | import warnings 9 | 10 | from time import perf_counter 11 | from statistics import mean 12 | from statistics import stdev 13 | 14 | import jax 15 | from jax import vmap 16 | import jax.numpy as jnp 17 | 18 | import jax.random as jrn 19 | 20 | import equinox as eqx 21 | 22 | from jaxopt import ScipyBoundedMinimize 23 | 24 | from compas.colors import Color 25 | from compas.colors import ColorMap 26 | from compas.geometry import Polygon 27 | from compas.geometry import Line 28 | 29 | from jax_fdm.datastructures import FDNetwork 30 | from jax_fdm.equilibrium import datastructure_updated 31 | from jax_fdm.visualization import Viewer 32 | 33 | from neural_fdm import DATA 34 | 35 | from neural_fdm.builders import build_mesh_from_generator 36 | from neural_fdm.builders import build_data_generator 37 | from neural_fdm.builders import build_connectivity_structure_from_generator 38 | from neural_fdm.builders import build_fd_decoder_parametrized 39 | from neural_fdm.builders import build_loss_function 40 | from neural_fdm.builders import build_neural_model 41 | 42 | from neural_fdm.losses import print_loss_summary 43 | 44 | from neural_fdm.serialization import load_model 45 | 46 | from camera import CAMERA_CONFIG_BEZIER 47 | from camera import CAMERA_CONFIG_TOWER 48 | 49 | from optimize import calculate_params_init 50 | from optimize import calculate_params_bounds 51 | 52 | 53 | # =============================================================================== 54 | # Script function 55 | # =============================================================================== 56 | 57 | def predict_optimize_batch( 58 | model_name, 59 | optimizer_name, 60 | task_name, 61 | blow=0.0, 62 | bup=20.0, 63 | maxiter=5000, 64 | tol=1e-6, 65 | seed=None, 66 | batch_size=None, 67 | verbose=True, 68 | save=False, 69 | view=False, 70 | slice=(0, -1), # (50, 53) for bezier 71 | edgecolor=None, 72 | ): 73 | """ 74 | Solve the prediction task on a batch target shapes with gradient-based optimization 75 | and box constraints, using a neural model to warmstart the optimization. 76 | The box constraints help generating compression-only or tension-only solutions. 77 | 78 | This script optimizes and visualizes. This is probably not the best idea, but oh well. 79 | 80 | Parameters 81 | ---------- 82 | model_name: `str` 83 | The model name. 84 | Supported models are formfinder, autoencoder, and piggy. 85 | Append the suffix `_pinn` to load model versions that were trained with a PINN loss. 86 | optimizer_name: `str` 87 | The name gradient-based optimizer used to solve this task. 88 | Supported methods are slsqp and lbfgsb. 89 | task_name: `str` 90 | The name of the YAML config file with the task hyperparameters. 91 | blow: `float`, optional 92 | The lower bound of the box constraints on the model parameters. 
93 | The bounds respect the force density signs of a task (compression or tension, currently hardcoded). 94 | bup: `float`, optional 95 | The lower bound of the box constraints on the model parameters. 96 | The bounds respect the force density signs of a task (compression or tension, currently hardcoded). 97 | maxiter: `int`, optional 98 | The maximum number of optimization iterations. 99 | tol: `float`, optional 100 | The tolerance for the optimization. 101 | seed: `int` or `None`, optional 102 | The random seed to generate a batch of target shapes. 103 | If `None`, it defaults to the input hyperparameters file. 104 | batch_size: `int` or `None`, optional 105 | The size of the batch of target shapes. 106 | If `None`, it defaults to the input hyperparameters file. 107 | verbose: `bool`, optional 108 | If `True`, print to stdout intermediary results. 109 | save: `bool`, optional 110 | If `True`, save the predicted shapes as JSON files. 111 | view: `bool`, optional 112 | If `True`, view the predicted shapes. 113 | slice: `tuple`, optional 114 | The start and stop indices of the slice of the batch for saving and viewing. 115 | Defaults to all the shapes in the batch. 116 | edgecolor: `str` or `None`, optional 117 | The color palette for the edges. 118 | Supported color palettes are fd to display force densities, and force to show forces. 119 | If `None`, the edges are colored by the force density in the shells tasks, and by the force in the tower tasks. 120 | """ 121 | if edgecolor is None: 122 | if task_name == "bezier": 123 | EDGECOLOR = "fd" 124 | elif task_name == "tower": 125 | EDGECOLOR = "force" 126 | else: 127 | EDGECOLOR = edgecolor 128 | 129 | START, STOP = slice 130 | SAVE = save 131 | QMIN = blow 132 | QMAX = bup 133 | 134 | # pick camera configuration for task 135 | if task_name == "bezier": 136 | CAMERA_CONFIG = CAMERA_CONFIG_BEZIER 137 | _width = 900 138 | elif task_name == "tower": 139 | CAMERA_CONFIG = CAMERA_CONFIG_TOWER 140 | _width = 450 141 | 142 | # pick optimizer name 143 | optimizer_names = {"lbfgsb": "L-BFGS-B", "slsqp": "SLSQP"} 144 | optimizer_name = optimizer_names[optimizer_name] 145 | 146 | # load yaml file with hyperparameters 147 | with open(f"{task_name}.yml") as file: 148 | config = yaml.load(file, Loader=yaml.FullLoader) 149 | 150 | # unpack parameters 151 | if seed is None: 152 | seed = config["seed"] 153 | training_params = config["training"] 154 | if batch_size is None: 155 | batch_size = training_params["batch_size"] 156 | 157 | generator_name = config['generator']['name'] 158 | bounds_name = config['generator']['bounds'] 159 | fd_params = config["fdm"] 160 | 161 | # randomness 162 | key = jrn.PRNGKey(seed) 163 | model_key, generator_key = jax.random.split(key, 2) 164 | 165 | # create data generator 166 | generator = build_data_generator(config) 167 | compute_loss = build_loss_function(config, generator) 168 | structure = build_connectivity_structure_from_generator(config, generator) 169 | mesh = build_mesh_from_generator(config, generator) 170 | 171 | # load model 172 | filepath = os.path.join(DATA, f"{model_name}_{task_name}.eqx") 173 | _model_name = model_name.split("_")[0] 174 | model_skeleton = build_neural_model(_model_name, config, generator, model_key) 175 | model = load_model(filepath, model_skeleton) 176 | 177 | # generate initial model parameters 178 | q0 = calculate_params_init(mesh, None, key, QMIN, QMAX) 179 | 180 | # create model 181 | print(f"Directly optimizing with {optimizer_name} using {model_name} init for {generator_name} dataset 
with {bounds_name} bounds on seed {seed}") 182 | decoder = build_fd_decoder_parametrized(q0, mesh, fd_params) 183 | 184 | # sample data batch 185 | xyz_batch = vmap(generator)(jrn.split(generator_key, batch_size)) 186 | 187 | # split mode 188 | diff_decoder, static_decoder = eqx.partition(decoder, eqx.is_inexact_array) 189 | 190 | # wrap loss function to meet jax and jaxopt's ideosyncracies 191 | @eqx.filter_jit 192 | @eqx.debug.assert_max_traces(max_traces=1) # Ensure this function is compiled at most once 193 | @eqx.filter_value_and_grad 194 | def compute_loss_diffable(diff_decoder, xyz_target): 195 | """ 196 | """ 197 | _decoder = eqx.combine(diff_decoder, static_decoder) 198 | return compute_loss(_decoder, structure, xyz_target, aux_data=False) 199 | 200 | # warmstart loss function to eliminate jit compilation time from perf measurements 201 | _ = compute_loss_diffable(diff_decoder, xyz_target=xyz_batch[None, 0]) 202 | 203 | # define optimization function 204 | warnings.filterwarnings("ignore") 205 | 206 | opt = ScipyBoundedMinimize( 207 | fun=compute_loss_diffable, 208 | method=optimizer_name, 209 | jit=True, 210 | tol=tol, 211 | maxiter=maxiter, 212 | options={"disp": False}, 213 | value_and_grad=True, 214 | ) 215 | 216 | # define parameter bounds 217 | bound_low, bound_up = calculate_params_bounds(mesh, q0, QMIN, QMAX) 218 | bound_low_tree = eqx.tree_at(lambda tree: tree.q, diff_decoder, replace=(bound_low)) 219 | bound_up_tree = eqx.tree_at(lambda tree: tree.q, diff_decoder, replace=(bound_up)) 220 | bounds = (bound_low_tree, bound_up_tree) 221 | 222 | # optimize 223 | print("\nOptimizing shapes in sequence") 224 | opt_times = [] 225 | loss_terms_batch = [] 226 | 227 | were_successful = 0 228 | if STOP == -1: 229 | STOP = batch_size 230 | 231 | xyz_slice = xyz_batch[START:STOP] 232 | 233 | # Warmstart optimization 234 | _xyz_ = xyz_slice[0][None, :] 235 | start_time = perf_counter() 236 | diff_model_opt, opt_res = opt.run(diff_decoder, bounds, _xyz_) 237 | end_time = perf_counter() - start_time 238 | print(f"\tJIT compilation time (optimizer): {end_time:.4f} s") 239 | 240 | num_opts = xyz_slice.shape[0] 241 | for i, xyz in enumerate(xyz_slice): 242 | 243 | # get sample from batch, add extra dimension 244 | xyz = xyz[None, :] 245 | 246 | # predict with pretrained model 247 | if task_name == "bezier": 248 | q0 = model.encode(xyz.ravel()) 249 | else: 250 | q0 = model.encode(xyz) 251 | 252 | # reinitialize decoders with pretrained model predictions 253 | decoder = eqx.tree_at(lambda tree: tree.q, decoder, replace=q0) 254 | diff_decoder = eqx.tree_at(lambda tree: tree.q, diff_decoder, replace=q0) 255 | 256 | # report start losses 257 | _, loss_terms = compute_loss(decoder, structure, xyz, aux_data=True) 258 | if verbose: 259 | print(f"Shape {i + 1}") 260 | print_loss_summary(loss_terms, prefix="\tStart") 261 | 262 | # optimize 263 | start_time = perf_counter() 264 | diff_model_opt, opt_res = opt.run(diff_decoder, bounds, xyz) 265 | opt_time = perf_counter() - start_time 266 | 267 | # unite optimal and static submodels 268 | model_opt = eqx.combine(diff_model_opt, static_decoder) 269 | 270 | # evaluate loss function at optimum point 271 | _, loss_terms = compute_loss(model_opt, structure, xyz, aux_data=True) 272 | if verbose: 273 | print_loss_summary(loss_terms, prefix="\tEnd") 274 | print(f"\tOpt success?: {opt_res.success}") 275 | print(f"\tOpt iters: {opt_res.iter_num}") 276 | print(f"\tOpt time: {opt_time:.4f} sec") 277 | 278 | if opt_res.success: 279 | were_successful += 1 280 | 281 
| opt_times.append(opt_time) 282 | loss_terms_batch.append(loss_terms) 283 | 284 | # assemble datastructure for post-processing 285 | eqstate_hat, fd_params_hat = model_opt.predict_states(xyz, structure) 286 | mesh_hat = datastructure_updated(mesh, eqstate_hat, fd_params_hat) 287 | network_hat = FDNetwork.from_mesh(mesh_hat) 288 | 289 | # export prediction 290 | if SAVE: 291 | filename = f"mesh_{i}" 292 | filepath = os.path.join(DATA, f"{filename}.json") 293 | mesh_hat.to_json(filepath) 294 | print(f"Saved prediction to {filepath}") 295 | 296 | # visualization 297 | if view: 298 | # create target mesh 299 | mesh_target = mesh.copy() 300 | _xyz = jnp.reshape(xyz, (-1, 3)).tolist() 301 | for idx, key in mesh.index_key().items(): 302 | mesh_target.vertex_attributes(key, "xyz", _xyz[idx]) 303 | 304 | viewer = Viewer( 305 | width=_width, 306 | height=900, 307 | show_grid=False, 308 | viewmode="lighted" 309 | ) 310 | 311 | # modify view 312 | viewer.view.camera.position = CAMERA_CONFIG["position"] 313 | viewer.view.camera.target = CAMERA_CONFIG["target"] 314 | viewer.view.camera.distance = CAMERA_CONFIG["distance"] 315 | _rotation = CAMERA_CONFIG.get("rotation") 316 | if _rotation: 317 | viewer.view.camera.rotation = _rotation 318 | 319 | # edge colors 320 | if EDGECOLOR == "force": 321 | 322 | color_end = Color.from_rgb255(12, 119, 184) 323 | color_start = Color.white() 324 | cmap = ColorMap.from_two_colors(color_start, color_end) 325 | 326 | edgecolor = {} 327 | forces = [fabs(network_hat.edge_force(edge)) for edge in network_hat.edges()] 328 | fmin = min(forces) 329 | fmax = max(forces) 330 | 331 | for edge in network_hat.edges(): 332 | force = network_hat.edge_force(edge) * -1.0 333 | if force < 0.0: 334 | _color = Color.from_rgb255(227, 6, 75) 335 | else: 336 | value = (force - fmin) / (fmax - fmin) 337 | _color = cmap(value) 338 | 339 | edgecolor[edge] = _color 340 | else: 341 | edgecolor = EDGECOLOR 342 | 343 | viewer.add( 344 | network_hat, 345 | edgewidth=(0.01, 0.3), 346 | edgecolor=edgecolor, 347 | show_edges=True, 348 | edges=[edge for edge in mesh.edges() if not mesh.is_edge_on_boundary(*edge)], 349 | nodes=[node for node in mesh.vertices() if len(mesh.vertex_neighbors(node)) > 2], 350 | show_loads=False, 351 | loadscale=1.0, 352 | show_reactions=True, 353 | reactionscale=1.0, 354 | reactioncolor=Color.from_rgb255(0, 150, 10), 355 | ) 356 | 357 | if task_name == "bezier": 358 | # target mesh 359 | viewer.add( 360 | FDNetwork.from_mesh(mesh_target), 361 | as_wireframe=True, 362 | show_points=False, 363 | linewidth=4.0, 364 | color=Color.black().lightened() 365 | ) 366 | 367 | # approximated mesh 368 | viewer.add( 369 | mesh_hat, 370 | show_points=False, 371 | show_edges=False, 372 | opacity=0.2 373 | ) 374 | 375 | elif task_name == "tower": 376 | rings = jnp.reshape(xyz, generator.shape_tube)[generator.levels_rings_comp, :, :] 377 | for ring in rings: 378 | ring = Polygon(ring.tolist()) 379 | viewer.add(ring, opacity=0.5) 380 | 381 | lengths = [] 382 | xyz_hat = model_opt(xyz, structure) 383 | rings_hat = jnp.reshape(xyz_hat, generator.shape_tube)[generator.levels_rings_comp, :, :] 384 | for ring_a, ring_b in zip(rings, rings_hat): 385 | for pt_a, pt_b in zip(ring_a, ring_b): 386 | line = Line(pt_a, pt_b) 387 | viewer.add(line) 388 | lengths.append(line.length**2) 389 | 390 | # show le crème 391 | viewer.show() 392 | 393 | # report optimization statistics 394 | print(f"\nSuccessful optimizations: {were_successful}/{num_opts}") 395 | print(f"Optimization time over {num_opts} 
optimizations (s): {mean(opt_times):.4f} (+-{stdev(opt_times):.4f})") 396 | 397 | labels = loss_terms_batch[0].keys() 398 | for label in labels: 399 | errors = [terms[label].item() for terms in loss_terms_batch] 400 | print(f"{label.capitalize()} over {num_opts} optimizations: {mean(errors):.4f} (+-{stdev(errors):.4f})") 401 | 402 | 403 | # =============================================================================== 404 | # Main 405 | # =============================================================================== 406 | 407 | if __name__ == "__main__": 408 | 409 | from fire import Fire 410 | 411 | Fire(predict_optimize_batch) 412 | -------------------------------------------------------------------------------- /scripts/shapes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prescribed shapes for the shell and tower tasks. 3 | """ 4 | 5 | # =============================================================================== 6 | # Shell task - These shapes need of a `bezier_symmetric_double` generator. 7 | # =============================================================================== 8 | 9 | # pillow 10 | BEZIER_PILLOW = [ 11 | [0.0, 0.0, 10.0], 12 | [0.0, 0.0, 0.0], 13 | [0.0, 0.0, 0.0], 14 | [0.0, 0.0, 0.0] 15 | ] 16 | 17 | # circular dome 18 | BEZIER_DOME = [ 19 | [0.0, 0.0, 10.0], 20 | [2.75, 0.0, 0.0], 21 | [0.0, 2.75, 0.0], 22 | [0.0, 0.0, 0.0] 23 | ] 24 | 25 | # cute saddle 26 | BEZIER_SADDLE = [ 27 | [0.0, 0.0, 1.5], 28 | [-1.25, 0.0, 5.0], 29 | [0.0, -2.5, 0.0], 30 | [0.0, 0.0, 0.0] 31 | ] 32 | 33 | # cute hypar 34 | BEZIER_HYPAR = [ 35 | [0.0, 0.0, 1.5], 36 | [-1.25, 0.0, 7.5], 37 | [0.0, 1.25, 0.0], 38 | [0.0, 0.0, 0.0] 39 | ] 40 | 41 | # cute pringle 42 | BEZIER_PRINGLE = [ 43 | [0.0, 0.0, 1.5], 44 | [1.25, 1.25, 0.0], 45 | [-1.25, 0.0, 7.5], 46 | [0.0, 0.0, 0.0] 47 | ] 48 | 49 | # cannon vault 50 | BEZIER_CANNON = [ 51 | [0.0, 0.0, 6.0], 52 | [0.0, 0.0, 6.0], 53 | [0.0, 0.0, 0.0], 54 | [0.0, 0.0, 0.0] 55 | ] 56 | 57 | BEZIERS = { 58 | "pillow": BEZIER_PILLOW, 59 | "dome": BEZIER_DOME, 60 | "saddle": BEZIER_SADDLE, 61 | "hypar": BEZIER_HYPAR, 62 | "pringle": BEZIER_PRINGLE, 63 | "cannon": BEZIER_CANNON, 64 | } 65 | 66 | # =============================================================================== 67 | # Tower task 68 | # =============================================================================== 69 | 70 | TOWER_ANGLES = [0.0, 0.0, 0.0] 71 | TOWER_RADII_FIXED = [0.75, 0.75] 72 | TOWER_RADII = [TOWER_RADII_FIXED, [0.75, 0.75], TOWER_RADII_FIXED] 73 | 74 | TOWERS = { 75 | -30: [TOWER_RADII, [0.0, -30.0, 0.0]], 76 | -22: [TOWER_RADII, [0.0, -22.0, 0.0]], 77 | -15: [TOWER_RADII, [0.0, -15.0, 0.0]], 78 | -7: [TOWER_RADII, [0.0, -7, 0.0]], 79 | 0: [TOWER_RADII, [0.0, 0.0, 0.0]], 80 | 7: [TOWER_RADII, [0.0, 7.0, 0.0]], 81 | 15: [TOWER_RADII, [0.0, 15.0, 0.0]], 82 | 22: [TOWER_RADII, [0.0, 22.0, 0.0]], 83 | 30: [TOWER_RADII, [0.0, 30.0, 0.0]], 84 | 0.5: [[TOWER_RADII_FIXED, [0.5, 0.5], TOWER_RADII_FIXED], TOWER_ANGLES], 85 | 0.75: [[TOWER_RADII_FIXED, [0.75, 0.75], TOWER_RADII_FIXED], TOWER_ANGLES], 86 | 1.0: [[TOWER_RADII_FIXED, [1.0, 1.0], TOWER_RADII_FIXED], TOWER_ANGLES], 87 | 1.25: [[TOWER_RADII_FIXED, [1.25, 1.25], TOWER_RADII_FIXED], TOWER_ANGLES], 88 | 1.5: [[TOWER_RADII_FIXED, [1.5, 1.5], TOWER_RADII_FIXED], TOWER_ANGLES], 89 | 90 | } 91 | -------------------------------------------------------------------------------- /scripts/sweep.py: -------------------------------------------------------------------------------- 1 | 
import wandb 2 | 3 | from neural_fdm.serialization import save_model 4 | 5 | from train import train_model_from_config 6 | 7 | 8 | # =============================================================================== 9 | # Helper functions 10 | # =============================================================================== 11 | 12 | def log_to_wandb(model, opt_state, loss_vals, step): 13 | """ 14 | Record metrics in weights and biases. Used as a callback in the training loop. 15 | 16 | Parameters 17 | ---------- 18 | model: `eqx.Module` 19 | The model to record the metrics of. 20 | opt_state: `eqx.Module` 21 | The optimizer state. 22 | loss_vals: `dict` 23 | The loss values. 24 | step: `int` 25 | The current training step. 26 | """ 27 | metrics = {} 28 | for key, value in loss_vals.items(): 29 | metrics[key] = value.item() 30 | 31 | wandb.log(metrics) 32 | 33 | 34 | # =============================================================================== 35 | # Script function 36 | # =============================================================================== 37 | 38 | def sweep(**kwargs): 39 | """ 40 | Sweep a model to find adequate hyperparameters. 41 | We use weights and biases to execute and log the training loops. 42 | 43 | Parameters 44 | ---------- 45 | **kwargs: `dict` 46 | Optional keyword arguments. 47 | """ 48 | wandb.init() 49 | 50 | config = wandb.config 51 | MODEL_NAME = config.model 52 | TASK_NAME = config.generator["name"] 53 | FROM_PRETRAINED = config.from_pretrained 54 | 55 | # train model with wandb config 56 | train_data = train_model_from_config( 57 | MODEL_NAME, 58 | config, 59 | pretrained=FROM_PRETRAINED, 60 | callback=log_to_wandb 61 | ) 62 | trained_model, _ = train_data 63 | 64 | # save trained model to local folder 65 | filename = MODEL_NAME 66 | loss_params = config["loss"] 67 | if loss_params["residual"]["include"] > 0 and MODEL_NAME != "formfinder": 68 | filename += "_pinn" 69 | filename += f"_{TASK_NAME}" 70 | 71 | filepath = f"{filename}.eqx" 72 | save_model(filepath, trained_model) 73 | 74 | # save trained model to wandb 75 | wandb.save(filepath) 76 | 77 | 78 | # =============================================================================== 79 | # Main 80 | # =============================================================================== 81 | 82 | if __name__ == "__main__": 83 | 84 | from fire import Fire 85 | 86 | Fire(sweep) 87 | -------------------------------------------------------------------------------- /scripts/sweep_bezier.yml: -------------------------------------------------------------------------------- 1 | # wandb variables 2 | program: sweep.py 3 | project: neural_fofin_bezier 4 | method: random # grid, random 5 | metric: 6 | goal: minimize 7 | name: loss 8 | # parameters to be sweeped 9 | parameters: 10 | model: 11 | value: "autoencoder" 12 | from_pretrained: 13 | value: False 14 | # randomness 15 | seed: 16 | value: 91 17 | # data generator 18 | generator: 19 | # we need to specify "parameters" again in every nested field of hyperparameters 20 | parameters: 21 | # a wandb sweep would not work otherwise 22 | name: 23 | value: "bezier_symmetric_double" 24 | bounds: 25 | value: "saddle" # options: pillow, dome, saddle 26 | num_uv: 27 | value: 10 28 | size: 29 | value: 10.0 30 | num_points: 31 | value: 4 32 | # fd simulation 33 | fdm: 34 | parameters: 35 | load: 36 | value: -0.5 37 | # encoder 38 | encoder: 39 | parameters: 40 | shift: 41 | value: 0.0 42 | hidden_layer_size: 43 | values: [128, 256, 512, 1024] 44 | hidden_layer_num: 45 | values: [3, 
4, 5] 46 | activation_fn_name: 47 | values: ["elu", "relu"] 48 | final_activation_fn_name: 49 | value: "softplus" 50 | # decoder 51 | decoder: 52 | parameters: 53 | include_params_xl: 54 | value: True 55 | hidden_layer_size: 56 | values: [128, 256, 512, 1024] 57 | hidden_layer_num: 58 | values: [3, 4, 5] 59 | activation_fn_name: 60 | values: ["elu", "relu"] 61 | # loss 62 | loss: 63 | parameters: 64 | shape: 65 | parameters: 66 | include: 67 | value: True 68 | weight: 69 | value: 1.0 70 | scale: 71 | value: 1.0 72 | residual: # PINN term 73 | parameters: 74 | include: 75 | value: True 76 | weight: 77 | value: 1.0 78 | scale: 79 | value: 1.0 80 | # optimizer 81 | optimizer: 82 | parameters: 83 | name: 84 | value: "adam" 85 | learning_rate: 86 | # NOTE: Be careful with scientific notation in YAML! 87 | values: [1.0e-3, 3.0e-3, 1.0e-4, 3.0e-4, 5.0e-4, 1.0e-5, 3.0e-5, 5.0e-5] 88 | clip_norm: 89 | value: 0.0 90 | # training 91 | training: 92 | parameters: 93 | steps: 94 | value: 10000 95 | batch_size: 96 | value: 64 97 | -------------------------------------------------------------------------------- /scripts/sweep_tower.yml: -------------------------------------------------------------------------------- 1 | # wandb variables 2 | program: sweep.py 3 | project: neural_fdm_tower 4 | method: grid # grid, random 5 | metric: 6 | goal: minimize 7 | name: loss 8 | # parameters to be sweeped 9 | parameters: 10 | model: 11 | value: "autoencoder" # Supported models are formfinder, autoencoder, and piggy 12 | from_pretrained: 13 | value: False 14 | # randomness 15 | seed: 16 | values: [89, 90, 91] 17 | # data generator 18 | generator: 19 | # we need to specify "parameters" again in every nested field of hyperparameters 20 | # a wandb sweep would not work otherwise 21 | parameters: 22 | name: 23 | value: "tower_ellipse" # options: tower_ellipse, tower_circle 24 | bounds: 25 | value: "twisted" # options: straight, twisted 26 | height: 27 | value: 10.0 28 | radius: 29 | value: 2.0 30 | num_sides: 31 | value: 16 32 | num_levels: 33 | value: 21 34 | num_rings: 35 | value: 3 36 | # fd simulation 37 | fdm: 38 | parameters: 39 | load: 40 | value: 0.0 41 | # encoder 42 | encoder: 43 | parameters: 44 | shift: 45 | value: 1.0 46 | hidden_layer_size: 47 | value: 256 48 | hidden_layer_num: 49 | value: 5 50 | activation_fn_name: 51 | value: "elu" 52 | final_activation_fn_name: 53 | value: "softplus" 54 | # decoder 55 | decoder: 56 | parameters: 57 | include_params_xl: 58 | value: True 59 | hidden_layer_size: 60 | value: 256 61 | hidden_layer_num: 62 | value: 5 63 | activation_fn_name: 64 | value: "elu" 65 | # loss 66 | loss: 67 | parameters: 68 | shape: 69 | parameters: 70 | include: 71 | value: True 72 | weight: 73 | value: 1.0 74 | scale: 75 | value: 1.0 76 | height: 77 | parameters: 78 | include: 79 | value: True 80 | weight: 81 | value: 1.0 82 | scale: 83 | value: 1.0 84 | energy: 85 | parameters: 86 | include: 87 | value: False 88 | weight: 89 | value: 1.0 90 | scale: 91 | value: 1.0 92 | residual: 93 | parameters: 94 | include: 95 | value: True 96 | weight: 97 | values: [0.01, 0.1, 1.0, 10.0, 100.0] 98 | scale: 99 | value: 1.0 100 | regularization: 101 | parameters: 102 | include: 103 | value: True 104 | weight: 105 | value: 10.0 106 | # optimizer 107 | optimizer: 108 | parameters: 109 | name: 110 | value: "adam" 111 | learning_rate: 112 | # NOTE: Be careful with scientific notation in YAML! 
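          # For example, with PyYAML's default resolver, `1e-3` is read as a string,
          # whereas `1.0e-3` (decimal point and signed exponent) is read as a float.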
113 | value: 0.001 114 | clip_norm: 115 | value: 0.0 116 | # training 117 | training: 118 | parameters: 119 | steps: 120 | value: 10000 121 | batch_size: 122 | value: 16 123 | -------------------------------------------------------------------------------- /scripts/text_2_mesh.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert text to a mesh to label the bricks of a masonry shell. 3 | 4 | This script uses the FreeType library to convert text to a mesh. 5 | 6 | FreeType high-level python API - Copyright 2011 Nicolas P. Rougier. 7 | Distributed under the terms of the new BSD license. 8 | """ 9 | 10 | import numpy as np 11 | from matplotlib.path import Path 12 | 13 | from freetype import Face 14 | 15 | from compas.geometry import bounding_box 16 | from compas.geometry import Box 17 | from compas.geometry import Translation 18 | from compas.datastructures import Mesh 19 | from compas.datastructures import meshes_join 20 | 21 | from compas_cgal.triangulation import constrained_delaunay_triangulation 22 | 23 | 24 | def char_2_mesh(char, filepath="Vera.ttf"): 25 | """ 26 | Convert a single character to a mesh. 27 | 28 | Parameters 29 | ---------- 30 | char: `str` 31 | The character to convert. 32 | filepath: `str`, optional 33 | The path to the font file. 34 | 35 | Returns 36 | ------- 37 | mesh: `compas.datastructures.Mesh` 38 | The mesh representing the character. 39 | """ 40 | face = Face(filepath) 41 | face.set_char_size(48*64) 42 | face.load_char(char) 43 | slot = face.glyph 44 | 45 | outline = slot.outline 46 | points = np.array(outline.points, dtype=[('x',float), ('y',float)]) 47 | 48 | # Iterate over each contour 49 | start, end = 0, 0 50 | VERTS, CODES = [], [] 51 | for i in range(len(outline.contours)): 52 | 53 | end = outline.contours[i] 54 | points = outline.points[start:end+1] 55 | points.append(points[0]) 56 | tags = outline.tags[start:end+1] 57 | tags.append(tags[0]) 58 | 59 | segments = [ [points[0],], ] 60 | for j in range(1, len(points) ): 61 | segments[-1].append(points[j]) 62 | if tags[j] & (1 << 0) and j < (len(points)-1): 63 | segments.append( [points[j],] ) 64 | 65 | verts = [points[0], ] 66 | codes = [Path.MOVETO,] 67 | 68 | for segment in segments: 69 | 70 | if len(segment) == 2: 71 | verts.extend(segment[1:]) 72 | codes.extend([Path.LINETO]) 73 | elif len(segment) == 3: 74 | verts.extend(segment[1:]) 75 | codes.extend([Path.CURVE3, Path.CURVE3]) 76 | else: 77 | verts.append(segment[1]) 78 | codes.append(Path.CURVE3) 79 | for i in range(1,len(segment)-2): 80 | A,B = segment[i], segment[i+1] 81 | C = ((A[0]+B[0])/2.0, (A[1]+B[1])/2.0) 82 | verts.extend([ C, B ]) 83 | codes.extend([ Path.CURVE3, Path.CURVE3]) 84 | verts.append(segment[-1]) 85 | codes.append(Path.CURVE3) 86 | 87 | VERTS.extend(verts) 88 | CODES.extend(codes) 89 | 90 | start = end + 1 91 | 92 | path = Path(VERTS, CODES) 93 | polygons = path.to_polygons() 94 | points = polygons[-1] * 0.1 95 | holes = [p * 0.1 for p in polygons[:-1]] 96 | 97 | V, F = constrained_delaunay_triangulation(points, holes=holes) 98 | mesh = Mesh.from_vertices_and_faces(V, F) 99 | 100 | return mesh 101 | 102 | 103 | def text_2_mesh(text, filepath="Vera.ttf"): 104 | """ 105 | Convert a text to a mesh. 106 | 107 | Parameters 108 | ---------- 109 | text: `str` 110 | The text to convert. 111 | filepath: `str`, optional 112 | The path to the font file. 113 | 114 | Returns 115 | ------- 116 | mesh: `compas.datastructures.Mesh` 117 | The joined mesh of all characters in the text. 
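    Notes
    -----
    A minimal usage sketch, assuming the bundled ``Vera.ttf`` font file sits next
    to this script (as in the example at the bottom of the file):

        mesh = text_2_mesh("100")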
118 | """ 119 | meshes = [] 120 | xsize = 0.0 121 | for char in text: 122 | mesh = char_2_mesh(char, filepath="Vera.ttf") 123 | vertices, faces = mesh.to_vertices_and_faces() 124 | bbox = Box.from_bounding_box(bounding_box(vertices)) 125 | T = Translation.from_vector([xsize, 0.0, 0.0]) 126 | mesh.transform(T) 127 | xsize += bbox.xsize * 1.1 128 | meshes.append(mesh) 129 | 130 | mesh = meshes_join(meshes) 131 | 132 | return mesh 133 | 134 | 135 | if __name__ == '__main__': 136 | 137 | from jax_fdm.visualization import Viewer 138 | 139 | mesh = text_2_mesh("100") 140 | 141 | viewer = Viewer( 142 | width=900, 143 | height=900, 144 | show_grid=True, 145 | viewmode="ghosted" 146 | ) 147 | 148 | viewer.add(mesh) 149 | 150 | viewer.show() 151 | -------------------------------------------------------------------------------- /scripts/tower.yml: -------------------------------------------------------------------------------- 1 | # randomness 2 | seed: 90 # 90 3 | # dataset 4 | generator: 5 | name: "tower_ellipse" # options: tower_ellipse, tower_circle 6 | bounds: "twisted" # options: straight, twisted 7 | height: 10.0 8 | radius: 2.0 9 | num_sides: 16 10 | num_levels: 21 11 | num_rings: 3 # must be >=3, 2 of them are the top and bottom supports 12 | # simulator 13 | fdm: 14 | load: 0.0 # 0.0, scale of vertical area load 15 | # neural networks 16 | encoder: 17 | shift: 1.0 # 1.0 18 | hidden_layer_size: 256 # 256 19 | hidden_layer_num: 5 # 5 20 | activation_fn_name: "elu" 21 | final_activation_fn_name: "softplus" # needs softplus to ensure positive output 22 | decoder: 23 | # If true, the decoder maps (q, boundary conditions) -> x. Otherwise, q -> x. 24 | include_params_xl: True 25 | hidden_layer_size: 256 26 | hidden_layer_num: 5 27 | activation_fn_name: "elu" 28 | # loss function 29 | loss: 30 | shape: 31 | include: True 32 | weight: 1.0 # weight of the shape term in the loss function 33 | height: 34 | include: True 35 | weight: 1.0 # weight of the height term in the loss function 36 | residual: # physics term 37 | include: True 38 | weight: 1.0 # weight of the physics error term in the loss function 39 | regularization: 40 | include: True 41 | weight: 10.0 # weight of the regularization term in the loss function 42 | # optimization 43 | optimizer: 44 | name: "adam" 45 | learning_rate: 0.001 # 0.001 then 0.0001, be careful with scientific notation in YAML! 46 | clip_norm: 0.01 # 0.01 47 | training: 48 | steps: 10000 49 | batch_size: 16 # 16 50 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a model to approximate a family of arbitrary shapes with mechanically-feasible geometries. 
3 | """ 4 | 5 | import os 6 | import time 7 | import yaml 8 | 9 | from functools import partial 10 | 11 | import jax 12 | import jax.random as jrn 13 | import jax.tree_util as jtu 14 | 15 | from jax import vmap 16 | 17 | import equinox as eqx 18 | 19 | from neural_fdm import DATA 20 | 21 | from neural_fdm.training import train_model 22 | 23 | from neural_fdm.plotting import plot_losses as plot_loss_curves 24 | 25 | from neural_fdm.builders import build_loss_function 26 | from neural_fdm.builders import build_data_generator 27 | from neural_fdm.builders import build_connectivity_structure_from_generator 28 | from neural_fdm.builders import build_neural_model 29 | from neural_fdm.builders import build_optimizer 30 | 31 | from neural_fdm.serialization import load_model 32 | from neural_fdm.serialization import save_model as save_model_fn 33 | 34 | 35 | # =============================================================================== 36 | # Script function 37 | # =============================================================================== 38 | 39 | def train( 40 | model_name, 41 | task_name, 42 | from_pretrained=False, 43 | checkpoint_every=None, 44 | plot_losses=True, 45 | save_model=True, 46 | save_losses=True, 47 | ): 48 | """ 49 | Train a model to approximate a family of arbitrary shapes with mechanically-feasible geometries. 50 | 51 | Parameters 52 | ---------- 53 | model_name: `str` 54 | The model name. 55 | Supported models are formfinder, autoencoder, and piggy. 56 | task_name: `str` 57 | The name of the YAML config file with the task hyperparameters. 58 | from_pretrained: `bool`, optional 59 | If `True`, train the model starting from a pretrained version. 60 | checkpoint_every: `int` or `None`, optional 61 | If not None, save a model every checkpoint steps. 62 | plot_losses: `bool`, optional 63 | If `True`, plot the loss curves. 64 | save_model: `bool`, optional 65 | If `True`, save the trained model. 66 | save_losses: `bool`, optional 67 | If `True`, save the loss histories as text files. 
68 | """ 69 | # load yaml file with hyperparameters 70 | with open(f"{task_name}.yml") as file: 71 | config = yaml.load(file, Loader=yaml.FullLoader) 72 | 73 | # resolve model name for saving and checkpointing 74 | filename = f"{model_name}" 75 | loss_params = config["loss"] 76 | if loss_params["residual"]["include"] > 0 and model_name != "formfinder": 77 | filename += "_pinn" 78 | filename += f"_{task_name}" 79 | 80 | # pick callback 81 | callback = None 82 | if checkpoint_every: 83 | callback = partial( 84 | checkpoint_model, 85 | checkpoint_step=checkpoint_every, 86 | filename=filename 87 | ) 88 | 89 | # train model 90 | trained_model, loss_history = train_model_from_config( 91 | model_name, 92 | config, 93 | from_pretrained, 94 | callback=callback 95 | ) 96 | 97 | if plot_losses: 98 | print("\nPlotting loss curves") 99 | plot_loss_curves(loss_history, labels=["loss"]) 100 | 101 | if save_model: 102 | print("\nSaving model") 103 | 104 | # save trained model 105 | filepath = os.path.join(DATA, f"{filename}.eqx") 106 | save_model_fn(filepath, trained_model) 107 | print(f"Saved model to {filepath}") 108 | 109 | if save_losses: 110 | labels = loss_history[0].keys() 111 | for label in labels: 112 | _label = "_".join(label.split()) 113 | filename_loss = f"losses_{filename}_{_label}.txt" 114 | 115 | filepath = os.path.join(DATA, filename_loss) 116 | with open(filepath, "w") as file: 117 | for values in loss_history: 118 | _value = values[label].item() 119 | file.write(f"{_value}\n") 120 | 121 | print(f"Saved loss history to {filepath}") 122 | 123 | 124 | # =============================================================================== 125 | # Train functions 126 | # =============================================================================== 127 | 128 | def train_model_from_config(model_name, config, pretrained=False, callback=None): 129 | """ 130 | Train a model to approximate a family of arbitrary shapes with mechanically-feasible geometries. 131 | 132 | Parameters 133 | ---------- 134 | model_name: `str` 135 | The model name. 136 | Supported models are formfinder, autoencoder, and piggy. 137 | config: `dict` 138 | A dictionary with the hyperparameters configuration. 139 | task_name: `str` 140 | The name of the YAML config file with the task hyperparameters. 141 | pretrained: `bool` 142 | If `True`, train the model starting from a pretrained version of it. 143 | callback: `Callable` 144 | A callback function to call at every train step. 
145 | """ 146 | # unpack parameters 147 | seed = config["seed"] 148 | training_params = config["training"] 149 | batch_size = training_params["batch_size"] 150 | steps = training_params["steps"] 151 | generator_name = config['generator']['name'] 152 | bounds_name = config['generator']['bounds'] 153 | 154 | # randomness 155 | key = jrn.PRNGKey(seed) 156 | model_key, generator_key = jax.random.split(key, 2) 157 | 158 | # create experiment 159 | print(f"\nTraining {model_name} on {generator_name} dataset with {bounds_name} bounds") 160 | generator = build_data_generator(config) 161 | structure = build_connectivity_structure_from_generator(config, generator) 162 | compute_loss = build_loss_function(config, generator) 163 | model = build_neural_model(model_name, config, generator, model_key) 164 | optimizer = build_optimizer(config) 165 | 166 | if pretrained: 167 | print("Starting from pretrained model") 168 | task_name = generator_name.split("_")[0] 169 | filepath = os.path.join(DATA, f"{model_name}_{task_name}_pretrain.eqx") 170 | model = load_model(filepath, model) 171 | 172 | # sample initial data batch 173 | xyz = vmap(generator)(jrn.split(generator_key, batch_size)) 174 | 175 | # warmstart 176 | start_loss = compute_loss(model, structure, xyz) 177 | print(f"The structure has {structure.num_vertices} vertices and {structure.num_edges} edges") 178 | print(f"Model parameter count: {count_model_params(model)}") 179 | print(f"{model_name.capitalize()} start loss: {start_loss:.6f}") 180 | 181 | # train models 182 | print("\nTraining") 183 | start = time.perf_counter() 184 | train_data = train_model( 185 | model, 186 | structure, 187 | optimizer, 188 | generator, 189 | loss_fn=compute_loss, 190 | num_steps=steps, 191 | batch_size=batch_size, 192 | key=generator_key, 193 | callback=callback 194 | ) 195 | end = time.perf_counter() 196 | 197 | print("\nTraining completed") 198 | print(f"Training time: {end - start:.4f} s") 199 | 200 | trained_model, _ = train_data 201 | 202 | end_loss = compute_loss(trained_model, structure, xyz) 203 | print(f"{model_name.capitalize()} last loss: {end_loss}") 204 | 205 | return train_data 206 | 207 | 208 | # =============================================================================== 209 | # Helper functions 210 | # =============================================================================== 211 | 212 | def checkpoint_model( 213 | model, 214 | opt_state, 215 | loss_vals, 216 | step, 217 | checkpoint_step, 218 | filename 219 | ): 220 | """ 221 | Checkpoint a model. Function to be used as a callback in the training loop. 222 | 223 | Parameters 224 | ---------- 225 | model: `eqx.Module` 226 | The model to checkpoint. 227 | opt_state: `eqx.Module` 228 | The optimizer state. 229 | loss_vals: `dict` 230 | The loss values. 231 | step: `int` 232 | The current training step. 233 | checkpoint_step: `int` 234 | The step interval at which to checkpoint the model. 235 | filename: `str` 236 | The filename to save the model to. 237 | """ 238 | if step > 0 and step % checkpoint_step == 0: 239 | filepath = os.path.join(DATA, f"{filename}_{step}.eqx") 240 | save_model_fn(filepath, model) 241 | 242 | 243 | def count_model_params(model): 244 | """ 245 | Count the number of trainable model parameters. 246 | 247 | Parameters 248 | ---------- 249 | model: `eqx.Module` 250 | The model to count the parameters of. 251 | 252 | Returns 253 | ------- 254 | count: `int` 255 | The number of trainable model parameters. 
256 | """ 257 | spec = eqx.is_inexact_array 258 | 259 | return sum(x.size for x in jtu.tree_leaves(eqx.filter(model, spec))) 260 | 261 | 262 | # =============================================================================== 263 | # Main 264 | # =============================================================================== 265 | 266 | if __name__ == "__main__": 267 | 268 | from fire import Fire 269 | 270 | Fire(train) 271 | -------------------------------------------------------------------------------- /scripts/visualize_tower_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate diagrams for the cablenet tower prediction task. 3 | """ 4 | import yaml 5 | 6 | import jax 7 | from jax import vmap 8 | import jax.numpy as jnp 9 | 10 | import jax.random as jrn 11 | 12 | from compas.colors import Color 13 | from compas.geometry import Polygon 14 | from compas.geometry import Polyline 15 | from compas.geometry import Plane 16 | 17 | from jax_fdm.visualization import Viewer 18 | 19 | from neural_fdm.builders import build_data_generator 20 | 21 | from neural_fdm.generators import points_on_ellipse 22 | 23 | from camera import CAMERA_CONFIG_TOWER as CAMERA_CONFIG 24 | 25 | 26 | # =============================================================================== 27 | # Script function 28 | # =============================================================================== 29 | 30 | def view_tower_task(seed=None, batch_size=None, shape_index=0): 31 | """ 32 | View the description of the cablenet tower prediction task for a target shape. 33 | 34 | Parameters 35 | ---------- 36 | seed: `int` or `None` 37 | The random seed to generate a batch of target shapes. 38 | If `None`, it defaults to the task hyperparameters file. 39 | batch_size: `int` or `None` 40 | The size of the batch of target shapes. 41 | If `None`, it defaults to the task hyperparameters file. 42 | shape_index: `int` 43 | The index of the shape to view. 
44 | """ 45 | # pick camera configuration for task 46 | task_name = "tower" 47 | _width = 450 48 | 49 | # load yaml file with hyperparameters 50 | with open(f"{task_name}.yml") as file: 51 | config = yaml.load(file, Loader=yaml.FullLoader) 52 | 53 | # unpack parameters 54 | if seed is None: 55 | seed = config["seed"] 56 | if batch_size is None: 57 | training_params = config["training"] 58 | batch_size = training_params["batch_size"] 59 | 60 | # randomness 61 | key = jrn.PRNGKey(seed) 62 | _, generator_key = jax.random.split(key, 2) 63 | 64 | # create data generator 65 | generator = build_data_generator(config) 66 | 67 | # sample data batch 68 | xyz_batch = vmap(generator)(jrn.split(generator_key, batch_size)) 69 | xyz = xyz_batch[shape_index, :] 70 | 71 | # view task 72 | 73 | # create viewer 74 | viewer = Viewer( 75 | width=_width, 76 | height=900, 77 | show_grid=False, 78 | viewmode="lighted" 79 | ) 80 | 81 | # modify view 82 | viewer.view.camera.position = CAMERA_CONFIG["position"] 83 | viewer.view.camera.target = CAMERA_CONFIG["target"] 84 | viewer.view.camera.distance = CAMERA_CONFIG["distance"] 85 | _rotation = CAMERA_CONFIG.get("rotation") 86 | if _rotation: 87 | viewer.view.camera.rotation = _rotation 88 | 89 | # draw rings 90 | rings = jnp.reshape(xyz, generator.shape_tube)[generator.levels_rings_comp, :, :] 91 | 92 | for ring in rings: 93 | ring = ring.tolist() 94 | polygon = Polygon(ring) 95 | 96 | viewer.add(polygon, opacity=0.5) 97 | viewer.add( 98 | Polyline(ring + ring[:1]), 99 | linewidth=4.0, 100 | color=Color.black().lightened() 101 | ) 102 | 103 | # draw planes, transparent, thick-ish boundary 104 | heights = jnp.linspace(0.0, generator.height, generator.num_levels) 105 | 106 | for i, height in enumerate(heights): 107 | 108 | plane = Plane([0.0, 0.0, height], [0.0, 0.0, 1.0]) 109 | 110 | circle = points_on_ellipse( 111 | generator.radius, 112 | generator.radius, 113 | height, 114 | generator.num_sides 115 | ) 116 | circle = circle.tolist() 117 | 118 | if i in generator.levels_rings_comp: 119 | viewer.add( 120 | Polyline(circle + circle[:1]), 121 | linewidth=2.0, 122 | color=Color.grey().lightened() 123 | ) 124 | # skip plane drawing for compression rings to avoid overlap 125 | continue 126 | 127 | viewer.add( 128 | plane, 129 | size=1.0, 130 | linewidth=0.1, 131 | color=Color.grey().lightened(10), 132 | opacity=0.1) 133 | 134 | # show viewer 135 | viewer.show() 136 | 137 | 138 | # =============================================================================== 139 | # Main 140 | # =============================================================================== 141 | 142 | if __name__ == "__main__": 143 | 144 | from fire import Fire 145 | 146 | Fire(view_tower_task) 147 | -------------------------------------------------------------------------------- /src/neural_fdm/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | HERE = os.path.dirname(__file__) 5 | HOME = os.path.abspath(os.path.join(HERE, "../../")) 6 | DATA = os.path.abspath(os.path.join(HOME, "data")) 7 | FIGURES = os.path.abspath(os.path.join(HOME, "figures")) 8 | SCRIPTS = os.path.abspath(os.path.join(HOME, "scripts")) 9 | 10 | # Monkey patch numpy for compas_view2==0.7.0 11 | if not hasattr(np, 'int'): 12 | np.int = np.int64 # noqa: F821 -------------------------------------------------------------------------------- /src/neural_fdm/generators/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .grids import * 2 | from .bezier import * 3 | from .generator import * 4 | from .generator_bezier import * 5 | from .tubes import * 6 | -------------------------------------------------------------------------------- /src/neural_fdm/generators/bezier.py: -------------------------------------------------------------------------------- 1 | from jax import lax 2 | from jax import vmap 3 | 4 | import jax.numpy as jnp 5 | 6 | import matplotlib.pyplot as plt 7 | 8 | from neural_fdm.generators.grids import PointGridAsymmetric 9 | from neural_fdm.generators.grids import PointGridSymmetric 10 | from neural_fdm.generators.grids import PointGridSymmetricDouble 11 | 12 | 13 | # =============================================================================== 14 | # Functions 15 | # =============================================================================== 16 | 17 | def factorial(n): 18 | """ 19 | Calculate the factorial of a number. 20 | 21 | Parameters 22 | ---------- 23 | n: `int` 24 | The number to calculate the factorial of. 25 | 26 | Returns 27 | ------- 28 | factorial: `float` 29 | The factorial of the number. 30 | """ 31 | return jnp.where(n < 0, 0, lax.exp(lax.lgamma(n + 1))) 32 | 33 | 34 | def binomial_coefficient(n, i): 35 | """ 36 | Compute the binomial coefficient. 37 | 38 | Parameters 39 | ---------- 40 | n: `int` 41 | The number to calculate the binomial coefficient of. 42 | i: `int` 43 | The index of the binomial coefficient. 44 | 45 | Returns 46 | ------- 47 | coefficient: `float` 48 | The binomial coefficient of the number. 49 | """ 50 | return factorial(n) / (factorial(i) * factorial(n - i)) 51 | 52 | 53 | def bernstein_poly(n, t): 54 | """ 55 | Compute all Bernstein polynomials of degree `n` at `t` using vectorized operations. 56 | 57 | Parameters 58 | ---------- 59 | n: `int` 60 | The degree of the Bernstein polynomial. 61 | t: `jax.Array` 62 | The parameter values. 63 | 64 | Returns 65 | ------- 66 | bernstein_poly: `jax.Array` 67 | The Bernstein polynomials of the degree `n` at the parameters `t`. 68 | """ 69 | i = jnp.arange(n + 1) 70 | binomial_coeff = binomial_coefficient(n, i) 71 | 72 | return binomial_coeff * (t ** i) * ((1 - t) ** (n - i)) 73 | 74 | 75 | def degree_u(control_points): 76 | """ 77 | Get the degree along the `u` direction of the Bezier surface. 78 | 79 | Parameters 80 | ---------- 81 | control_points: `jax.Array` 82 | The control points of the Bezier surface. 83 | 84 | Returns 85 | ------- 86 | degree: `int` 87 | The degree along the `u` direction of the Bezier surface. 88 | """ 89 | n, m, _ = control_points.shape 90 | n -= 1.0 91 | 92 | return n 93 | 94 | 95 | def degree_v(control_points): 96 | """ 97 | Get the degree along the `v` direction of the Bezier surface. 98 | 99 | Parameters 100 | ---------- 101 | control_points: `jax.Array` 102 | The control points of the Bezier surface. 103 | 104 | Returns 105 | ------- 106 | degree: `int` 107 | The degree along the `v` direction of the Bezier surface. 108 | """ 109 | n, m, _ = control_points.shape 110 | m -= 1.0 111 | 112 | return m 113 | 114 | 115 | def bezier_surface_point(control_points, u, v): 116 | """ 117 | Evaluate a point on a Bezier surface using a series of dot products. 118 | 119 | Parameters 120 | ---------- 121 | control_points: `jax.Array` 122 | The control points of the Bezier surface in the shape (n+1, m+1, 3). 123 | u: `float` 124 | The parameter value along the `u` direction in the range [0, 1]. 
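A quick numerical sanity check of the Bernstein basis defined above, assuming the package is installed with `pip install -e .`. Degrees and indices are passed as floats because `factorial` is built on `lax.lgamma`:

```python
import jax.numpy as jnp
from neural_fdm.generators.bezier import bernstein_poly, binomial_coefficient

print(binomial_coefficient(3.0, jnp.arange(4.0)))  # [1. 3. 3. 1.]
print(bernstein_poly(3.0, 0.5))                    # [0.125 0.375 0.375 0.125], sums to 1
```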
125 | v: `float` 126 | The parameter value along the `v` direction in the range [0, 1]. 127 | 128 | Returns 129 | ------- 130 | point: `jax.Array` 131 | The point on the Bezier surface. 132 | """ 133 | n = degree_u(control_points) 134 | m = degree_v(control_points) 135 | 136 | # Compute the Bernstein polynomial values at u and v 137 | bernstein_u = bernstein_poly(n, u) 138 | bernstein_v = bernstein_poly(m, v) 139 | 140 | # Calculate the weighted sum of control points along the u direction 141 | weighted_u = jnp.dot(bernstein_u, control_points) # shape becomes (m+1, 3) 142 | 143 | # Calculate the final point on the surface by combining the results across v 144 | point = jnp.dot(weighted_u.T, bernstein_v) 145 | 146 | return point 147 | 148 | 149 | # Evaluate points 150 | def evaluate_bezier_surface(control_points, u, v): 151 | """ 152 | Sample a series of 3D points on a Bezier surface with `vmap`. 153 | 154 | Parameters 155 | ---------- 156 | control_points: `jax.Array` 157 | The control points of the Bezier surface in the shape (n+1, m+1, 3). 158 | u: `jax.Array` 159 | The parameter values along the `u` direction in the range [0, 1]. 160 | v: `jax.Array` 161 | The parameter values along the `v` direction in the range [0, 1]. 162 | 163 | Returns 164 | ------- 165 | points: `jax.Array` 166 | The points on the Bezier surface. 167 | """ 168 | fn = vmap(vmap(bezier_surface_point, 169 | in_axes=(None, 0, None)), 170 | in_axes=(None, None, 0)) 171 | 172 | return fn(control_points, u, v) 173 | 174 | 175 | def evaluate_bezier_surface_einsum(control_points, u, v): 176 | """ 177 | Vectorized computation of a point on a Bezier surface via `einsum`. 178 | 179 | Parameters 180 | ---------- 181 | control_points: `jax.Array` 182 | The control points of the Bezier surface in the shape (n+1, m+1, 3). 183 | u: `jax.Array` 184 | The parameter values along the `u` direction in the range [0, 1]. 185 | v: `jax.Array` 186 | The parameter values along the `v` direction in the range [0, 1]. 187 | 188 | Returns 189 | ------- 190 | points: `jax.Array` 191 | The points on the Bezier surface. 192 | """ 193 | n = degree_u(control_points) 194 | m = degree_v(control_points) 195 | 196 | # Compute the Bernstein polynomial values at u and v 197 | bernstein_u = bernstein_poly(n, u[:, :, None]) 198 | bernstein_v = bernstein_poly(m, v[:, :, None]) 199 | 200 | # Calculate the weighted sum of control points along the u and v directions 201 | surface_points = jnp.einsum('ijk,lmi,lmj->lmk', 202 | control_points, 203 | bernstein_u, 204 | bernstein_v) 205 | 206 | return surface_points 207 | 208 | 209 | # =============================================================================== 210 | # Surfaces 211 | # =============================================================================== 212 | 213 | 214 | class BezierSurface: 215 | """ 216 | A Bezier surface. 217 | 218 | Parameters 219 | ---------- 220 | grid: `PointGrid` 221 | The grid of points that define the Bezier surface. 222 | """ 223 | def __init__(self, grid): 224 | self.grid = grid 225 | 226 | def control_points(self, transform=None): 227 | """ 228 | """ 229 | return self.grid.points(transform) 230 | 231 | def evaluate_points(self, u, v, transform=None): 232 | """ 233 | """ 234 | control_points = self.control_points(transform) 235 | return evaluate_bezier_surface(control_points, u, v) 236 | 237 | 238 | class BezierSurfaceSymmetric(BezierSurface): 239 | """ 240 | A symmetric Bezier surface. 241 | 242 | Parameters 243 | ---------- 244 | size: `int` 245 | The size of the grid. 
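The two evaluation routes above should agree. A small check with a hypothetical 3x3 control grid; note that the `vmap` version takes 1D parameter vectors while the `einsum` version takes meshgrids:

```python
import jax.numpy as jnp
from neural_fdm.generators.bezier import (
    evaluate_bezier_surface,
    evaluate_bezier_surface_einsum,
)

# a 3x3 control grid for a gently domed patch (hypothetical values)
control_points = jnp.array([
    [[0.0, 0.0, 0.0], [0.0, 1.0, 0.5], [0.0, 2.0, 0.0]],
    [[1.0, 0.0, 0.5], [1.0, 1.0, 1.5], [1.0, 2.0, 0.5]],
    [[2.0, 0.0, 0.0], [2.0, 1.0, 0.5], [2.0, 2.0, 0.0]],
])

u = jnp.linspace(0.0, 1.0, 5)
v = jnp.linspace(0.0, 1.0, 7)

points = evaluate_bezier_surface(control_points, u, v)  # shape (7, 5, 3)
u_grid, v_grid = jnp.meshgrid(u, v)                     # default 'xy' indexing
points_einsum = evaluate_bezier_surface_einsum(control_points, u_grid, v_grid)

assert points.shape == (7, 5, 3)
assert jnp.allclose(points, points_einsum, atol=1e-5)
```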
246 | num_pts: `int` 247 | The number of points along one side of the grid. 248 | """ 249 | def __init__(self, size, num_pts): 250 | grid = PointGridSymmetric(size, num_pts) 251 | super().__init__(grid) 252 | 253 | 254 | class BezierSurfaceSymmetricDouble(BezierSurface): 255 | """ 256 | A Bezier surface with double symmetry. 257 | 258 | Parameters 259 | ---------- 260 | size: `int` 261 | The size of the grid. 262 | num_pts: `int` 263 | The number of points along one side of the grid. 264 | """ 265 | def __init__(self, size, num_pts): 266 | grid = PointGridSymmetricDouble(size, num_pts) 267 | super().__init__(grid) 268 | 269 | 270 | class BezierSurfaceAsymmetric(BezierSurface): 271 | """ 272 | A Bezier surface without symmetry. 273 | 274 | Parameters 275 | ---------- 276 | size: `int` 277 | The size of the grid. 278 | num_pts: `int` 279 | The number of points along one side of the grid. 280 | """ 281 | def __init__(self, size, num_pts): 282 | grid = PointGridAsymmetric(size, num_pts) 283 | super().__init__(grid) 284 | 285 | 286 | # =============================================================================== 287 | # Main 288 | # =============================================================================== 289 | 290 | if __name__ == "__main__": 291 | 292 | grid_size = 4 293 | num_u = 11 # 12 for youtube 294 | num_v = 11 295 | 296 | # Control points rhino 297 | points = [ 298 | [-5, -5, 0], 299 | [-5, -1.666667, 0], 300 | [-5, 1.666667, 0], 301 | [-5, 5, 0], 302 | [-1.666667, -5, 0], 303 | [-1.666667, -1.666667, 10], 304 | [-1.666667, 1.666667, 10], 305 | [-1.666667, 5, 0], 306 | [1.666667, -5, 0], 307 | [1.666667, -1.666667, 10], 308 | [1.666667, 1.666667, 10], 309 | [1.666667, 5, 0], 310 | [5, -5, 0], 311 | [5, -1.666667, 0], 312 | [5, 1.666667, 0], 313 | [5, 5, 0] 314 | ] 315 | # Control points Youtube 316 | # cx = [[-0.5, -2.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]] 317 | # cy = [[2.0, 1.0, 0.0], [2.0, 0.0, -1.0], [2.0, 1.0, 1.0]] 318 | # cz = [[1.0, -1.0, 2.0], [0.0, -0.5, 2.0], [0.5, 1.0, 2.0]] 319 | # cx = jnp.array(cx) 320 | # cy = jnp.array(cy) 321 | # cz = jnp.array(cz) 322 | 323 | assert len(points) % grid_size == 0 324 | points = jnp.array(points) 325 | control_points = jnp.reshape(points, (grid_size, grid_size, 3)) 326 | 327 | # U, V 328 | u = jnp.linspace(0.0, 1.0, num_u) 329 | v = jnp.linspace(0.0, 1.0, num_v) 330 | u_grid, v_grid = jnp.meshgrid(u, v) 331 | 332 | # surface_points = bezier_surface(control_points, u_grid, v_grid) 333 | surface_points = evaluate_bezier_surface(control_points, u, v) 334 | print("Surface Points Shape:", surface_points.shape) # Should be (10, 10, 3) 335 | 336 | fig = plt.figure() 337 | ax = fig.add_subplot(111, projection='3d') 338 | 339 | ax.plot_surface(surface_points[:, :, 0], 340 | surface_points[:, :, 1], 341 | surface_points[:, :, 2]) 342 | 343 | ax.scatter(control_points[:, :, 0], 344 | control_points[:, :, 1], 345 | control_points[:, :, 2], 346 | edgecolors="face") 347 | 348 | plt.show() 349 | -------------------------------------------------------------------------------- /src/neural_fdm/generators/generator.py: -------------------------------------------------------------------------------- 1 | class PointGenerator: 2 | """ 3 | A generator that samples random points on a target shape. 4 | """ 5 | def __call__(self, key, wiggle=True): 6 | """ 7 | Generate points. 8 | 9 | Parameters 10 | ---------- 11 | key: `jax.random.PRNGKey` 12 | The random key. 13 | wiggle: `bool` 14 | If True, the points are wiggled. 
15 | 16 | Returns 17 | ------- 18 | points: `jax.Array` 19 | The points on the target shape. 20 | """ 21 | raise NotImplementedError("Subclasses must implement this method.") -------------------------------------------------------------------------------- /src/neural_fdm/generators/generator_bezier.py: -------------------------------------------------------------------------------- 1 | import jax.random as jrn 2 | import jax.numpy as jnp 3 | 4 | from neural_fdm.generators.generator import PointGenerator 5 | 6 | from neural_fdm.generators.bezier import evaluate_bezier_surface 7 | from neural_fdm.generators.bezier import BezierSurfaceAsymmetric 8 | from neural_fdm.generators.bezier import BezierSurfaceSymmetric 9 | from neural_fdm.generators.bezier import BezierSurfaceSymmetricDouble 10 | 11 | 12 | # =============================================================================== 13 | # Generators 14 | # =============================================================================== 15 | 16 | class BezierSurfacePointGenerator(PointGenerator): 17 | """ 18 | A generator that outputs point evaluated on a wiggled bezier surface. 19 | 20 | Parameters 21 | ---------- 22 | surface: `BezierSurface` 23 | The surface to sample points from. 24 | u: `jax.Array` 25 | The parameter values along the `u` direction in the range [0, 1]. 26 | v: `jax.Array` 27 | The parameter values along the `v` direction in the range [0, 1]. 28 | minval: `jax.Array` 29 | The minimum values of the space of random translations. 30 | maxval: `jax.Array` 31 | The maximum values of the space of random translations. 32 | """ 33 | def __init__(self, surface, u, v, minval, maxval): 34 | self._check_array_shapes(surface, minval, maxval) 35 | 36 | self.surface = surface 37 | self.u = u 38 | self.v = v 39 | self.minval = minval 40 | self.maxval = maxval 41 | 42 | def _check_array_shapes(self, surface, minval, maxval): 43 | """ 44 | Verify that input shapes are consistent. 45 | 46 | Parameters 47 | ---------- 48 | surface: `BezierSurface` 49 | The surface to sample points from. 50 | minval: `jax.Array` 51 | The minimum values of the space of random translations. 52 | maxval: `jax.Array` 53 | The maximum values of the space of random translations. 54 | """ 55 | tile_shape = surface.grid.tile.shape 56 | minval_shape = minval.shape 57 | maxval_shape = maxval.shape 58 | 59 | assert minval_shape == tile_shape, f"{minval_shape} vs. {tile_shape}" 60 | assert maxval_shape == tile_shape, f"{maxval_shape} vs. {tile_shape}" 61 | 62 | def wiggle(self, key): 63 | """ 64 | Sample a translation vector from a uniform distribution. 65 | 66 | Parameters 67 | ---------- 68 | key: `jax.random.PRNGKey` 69 | The random key. 70 | 71 | Returns 72 | ------- 73 | transform: `jax.Array` 74 | The translation vector. 75 | """ 76 | shape = self.surface.grid.tile.shape 77 | return jrn.uniform(key, shape=shape, minval=self.minval, maxval=self.maxval) 78 | 79 | def evaluate_points(self, transform): 80 | """ 81 | Generate transformed points. 82 | 83 | Parameters 84 | ---------- 85 | transform: `jax.Array` 86 | The translation vector. 87 | 88 | Returns 89 | ------- 90 | points: `jax.Array` 91 | The transformed points. 92 | """ 93 | points = self.surface.evaluate_points(self.u, self.v, transform) 94 | 95 | return jnp.ravel(points) 96 | 97 | def __call__(self, key, wiggle=True): 98 | """ 99 | Generate (wiggled) points. 100 | 101 | Parameters 102 | ---------- 103 | key: `jax.random.PRNGKey` 104 | The random key. 
105 | wiggle: `bool`, optional 106 | Whether to wiggle the points at random. 107 | 108 | Returns 109 | ------- 110 | points: `jax.Array` 111 | The points on the surface. 112 | """ 113 | # fall back to the untransformed surface when wiggling is off 114 | transform = self.wiggle(key) if wiggle else None 115 | 116 | return self.evaluate_points(transform) 117 | 118 | 119 | class BezierSurfaceSymmetricDoublePointGenerator(BezierSurfacePointGenerator): 120 | """ 121 | A generator that outputs points evaluated on a wiggled, doubly-symmetric bezier surface. 122 | 123 | Parameters 124 | ---------- 125 | size: `int` 126 | The size of the grid. 127 | num_pts: `int` 128 | The number of points along one side of the grid. 129 | u: `jax.Array` 130 | The parameter values along the `u` direction in the range [0, 1]. 131 | v: `jax.Array` 132 | The parameter values along the `v` direction in the range [0, 1]. 133 | minval: `jax.Array` 134 | The minimum values of the space of random translations. 135 | maxval: `jax.Array` 136 | The maximum values of the space of random translations. 137 | """ 138 | def __init__(self, size, num_pts, u, v, minval, maxval, *args, **kwargs): 139 | surface = BezierSurfaceSymmetricDouble(size, num_pts) 140 | super().__init__(surface, u, v, minval, maxval) 141 | 142 | 143 | class BezierSurfaceSymmetricPointGenerator(BezierSurfacePointGenerator): 144 | """ 145 | A generator that outputs points evaluated on a wiggled, symmetric bezier surface. 146 | 147 | Parameters 148 | ---------- 149 | size: `int` 150 | The size of the grid. 151 | num_pts: `int` 152 | The number of points along one side of the grid. 153 | u: `jax.Array` 154 | The parameter values along the `u` direction in the range [0, 1]. 155 | v: `jax.Array` 156 | The parameter values along the `v` direction in the range [0, 1]. 157 | minval: `jax.Array` 158 | The minimum values of the space of random translations. 159 | maxval: `jax.Array` 160 | The maximum values of the space of random translations. 161 | """ 162 | def __init__(self, size, num_pts, u, v, minval, maxval, *args, **kwargs): 163 | surface = BezierSurfaceSymmetric(size, num_pts) 164 | super().__init__(surface, u, v, minval, maxval) 165 | 166 | 167 | class BezierSurfaceAsymmetricPointGenerator(BezierSurfacePointGenerator): 168 | """ 169 | A generator that outputs points evaluated on a wiggled bezier surface. 170 | 171 | Parameters 172 | ---------- 173 | size: `int` 174 | The size of the grid. 175 | num_pts: `int` 176 | The number of points along one side of the grid. 177 | u: `jax.Array` 178 | The parameter values along the `u` direction in the range [0, 1]. 179 | v: `jax.Array` 180 | The parameter values along the `v` direction in the range [0, 1]. 181 | minval: `jax.Array` 182 | The minimum values of the space of random translations. 183 | maxval: `jax.Array` 184 | The maximum values of the space of random translations. 185 | """ 186 | def __init__(self, size, num_pts, u, v, minval, maxval, *args, **kwargs): 187 | surface = BezierSurfaceAsymmetric(size, num_pts) 188 | super().__init__(surface, u, v, minval, maxval) 189 | 190 | 191 | class BezierSurfaceLerpPointGenerator(BezierSurfacePointGenerator): 192 | """ 193 | A generator that outputs points interpolated between two wiggled bezier surfaces. 194 | 195 | Parameters 196 | ---------- 197 | size: `int` 198 | The size of the grid. 199 | num_pts: `int` 200 | The number of points along one side of the grid. 201 | u: `jax.Array` 202 | The parameter values along the `u` direction in the range [0, 1]. 203 | v: `jax.Array` 204 | The parameter values along the `v` direction in the range [0, 1].
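A usage sketch for the doubly-symmetric generator above, with hypothetical bounds and assuming the star exports from `neural_fdm.generators`. The quarter tile holds 4 control points, so `minval`/`maxval` must be `(4, 3)` arrays, and `num_pts=4` because the grid reindexing assumes a 4x4 control grid:

```python
import jax.numpy as jnp
import jax.random as jrn
from jax import vmap
from neural_fdm.generators import BezierSurfaceSymmetricDoublePointGenerator

# hypothetical translation bounds: only the z coordinates of the quarter tile get wiggled
minval = jnp.zeros((4, 3))
maxval = jnp.zeros((4, 3)).at[:, 2].set(5.0)

u = jnp.linspace(0.0, 1.0, 11)
v = jnp.linspace(0.0, 1.0, 11)
generator = BezierSurfaceSymmetricDoublePointGenerator(
    size=10.0, num_pts=4, u=u, v=v, minval=minval, maxval=maxval
)

key = jrn.PRNGKey(0)
batch = vmap(generator)(jrn.split(key, 8))
assert batch.shape == (8, 11 * 11 * 3)  # raveled (11, 11, 3) target shapes per sample
```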
205 | minval: `jax.Array` 206 | The minimum values of the space of random translations. 207 | maxval: `jax.Array` 208 | The maximum values of the space of random translations. 209 | alpha: `float` 210 | The interpolation value. 211 | 212 | Notes 213 | ----- 214 | One surface is doubly-symmetric and the other asymmetric. 215 | """ 216 | def __init__(self, size, num_pts, u, v, minval, maxval, alpha, *args, **kwargs): 217 | minval_a, minval_b = minval 218 | maxval_a, maxval_b = maxval 219 | 220 | surface = BezierSurfaceSymmetricDouble(size, num_pts) 221 | super().__init__(surface, u, v, minval_a, maxval_a) 222 | self.generator_other = BezierSurfaceAsymmetricPointGenerator( 223 | size, 224 | num_pts, 225 | u, 226 | v, 227 | minval_b, 228 | maxval_b) 229 | self.alpha = alpha 230 | 231 | def __call__(self, key, wiggle=True): 232 | """ 233 | Generate (wiggled) points. 234 | 235 | Parameters 236 | ---------- 237 | key: `jax.random.PRNGKey` 238 | The random key. 239 | wiggle: `bool`, optional 240 | Whether to wiggle the points at random. 241 | 242 | Returns 243 | ------- 244 | points: `jax.Array` 245 | The points on the surface. 246 | """ 247 | # fall back to the untransformed surfaces when wiggling is off 248 | transform_this = self.wiggle(key) if wiggle else None 249 | transform_other = self.generator_other.wiggle(key) if wiggle else None 250 | 251 | control_points_this = self.surface.control_points(transform_this) 252 | control_points_other = self.generator_other.surface.control_points(transform_other) 253 | control_points = (1.0 - self.alpha) * control_points_this + self.alpha * control_points_other 254 | 255 | points = evaluate_bezier_surface(control_points, self.u, self.v) 256 | 257 | return jnp.ravel(points) 258 | -------------------------------------------------------------------------------- /src/neural_fdm/generators/grids.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | 4 | # =============================================================================== 5 | # Transformations 6 | # =============================================================================== 7 | 8 | def get_world_mirror_matrix(plane): 9 | """ 10 | Create a mirror matrix for a given plane. 11 | 12 | Parameters 13 | ---------- 14 | plane: `str` 15 | The plane to mirror across. Must be one of 'xy', 'yz', or 'xz'. 16 | 17 | Returns 18 | ------- 19 | mirror_matrix: `jax.Array` 20 | The mirror matrix for the given plane. 21 | """ 22 | if plane.lower() == 'xy': 23 | # Mirroring across the XY plane (change Z coordinate) 24 | mirror_matrix = jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]]) 25 | elif plane.lower() == 'yz': 26 | # Mirroring across the YZ plane (change X coordinate) 27 | mirror_matrix = jnp.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) 28 | elif plane.lower() == 'xz': 29 | # Mirroring across the XZ plane (change Y coordinate) 30 | mirror_matrix = jnp.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]]) 31 | else: 32 | raise ValueError("Invalid plane. Choose 'xy', 'yz', or 'xz'.") 33 | 34 | return mirror_matrix 35 | 36 | 37 | def mirror_points(points, mirror_matrix): 38 | """ 39 | Mirror points across a given plane. 40 | 41 | Parameters 42 | ---------- 43 | points: `jax.Array` 44 | The points to mirror. 45 | mirror_matrix: `jax.Array` 46 | The mirror matrix. 47 | 48 | Returns 49 | ------- 50 | mirrored_points: `jax.Array` 51 | The mirrored points.
52 | """ 53 | return points @ mirror_matrix 54 | 55 | 56 | # =============================================================================== 57 | # Grid functions 58 | # =============================================================================== 59 | 60 | def get_grid_tile_quarter(grid_size, grid_num_pts): 61 | """ 62 | Get the 3D coordinates of a quarter tile of a control point grid. 63 | 64 | Parameters 65 | ---------- 66 | grid_size: `int` 67 | The size of the grid. 68 | grid_num_pts: `int` 69 | The number of points along one side of the grid. 70 | 71 | Returns 72 | ------- 73 | tile: `jax.Array` 74 | The 3D coordinates. 75 | """ 76 | half_grid_size = grid_size / 2.0 77 | grid_step = half_grid_size / (grid_num_pts - 1.0) 78 | 79 | pt0 = [grid_step, grid_step, 0.0] 80 | pt1 = [half_grid_size, grid_step, 0.0] 81 | pt2 = [grid_step, half_grid_size, 0.0] 82 | pt3 = [half_grid_size, half_grid_size, 0.0] 83 | 84 | return jnp.array([pt0, pt1, pt2, pt3]) 85 | 86 | 87 | def calculate_grid_from_tile_quarter(tile): 88 | """ 89 | Generate an ordered grid of control points from a quarter tile. 90 | 91 | Parameters 92 | ---------- 93 | tile: `jax.Array` 94 | The 3D coordinates of a quarter tile. 95 | 96 | Returns 97 | ------- 98 | grid_points: `jax.Array` 99 | The 3D coordinates of the grid. 100 | """ 101 | grid_points = tile 102 | 103 | # mirror tile once 104 | mirrored_points = mirror_points(grid_points, get_world_mirror_matrix("yz")) 105 | grid_points = jnp.concatenate((grid_points, mirrored_points)) 106 | 107 | # mirror tile again 108 | mirrored_points = mirror_points(grid_points, get_world_mirror_matrix("xz")) 109 | grid_points = jnp.concatenate((grid_points, mirrored_points)) 110 | 111 | return grid_points 112 | 113 | 114 | def get_grid_tile_half(grid_size, grid_num_pts): 115 | """ 116 | Get the 3D coordinates of a half tile of a control point grid. 117 | 118 | Parameters 119 | ---------- 120 | grid_size: `int` 121 | The size of the grid. 122 | grid_num_pts: `int` 123 | The number of points along one side of the grid. 124 | 125 | Returns 126 | ------- 127 | tile: `jax.Array` 128 | The 3D coordinates. 129 | """ 130 | tile_quarter = get_grid_tile_quarter(grid_size, grid_num_pts) 131 | 132 | # mirror tile once 133 | mirrored_points = mirror_points(tile_quarter, get_world_mirror_matrix("yz")) 134 | 135 | return jnp.concatenate((tile_quarter, mirrored_points)) 136 | 137 | 138 | def calculate_grid_from_tile_half(tile): 139 | """ 140 | Generate an ordered grid of control points from a half tile. 141 | 142 | Parameters 143 | ---------- 144 | tile: `jax.Array` 145 | The 3D coordinates of a half tile. 146 | 147 | Returns 148 | ------- 149 | grid_points: `jax.Array` 150 | The 3D coordinates of the grid. 151 | """ 152 | grid_points = tile 153 | 154 | # mirror tile once 155 | mirrored_points = mirror_points(grid_points, get_world_mirror_matrix("xz")) 156 | grid_points = jnp.concatenate((grid_points, mirrored_points)) 157 | 158 | return grid_points 159 | 160 | 161 | def get_grid_tile_full(grid_size, grid_num_pts): 162 | """ 163 | Get the 3D coordinates of a full tile of a control point grid. 164 | 165 | Parameters 166 | ---------- 167 | grid_size: `int` 168 | The size of the grid. 169 | grid_num_pts: `int` 170 | The number of points along one side of the grid. 171 | 172 | Returns 173 | ------- 174 | tile: `jax.Array` 175 | The 3D coordinates. 
176 | """ 177 | tile = get_grid_tile_quarter(grid_size, grid_num_pts) 178 | 179 | return calculate_grid_from_tile_quarter(tile) 180 | 181 | 182 | def calculate_grid_from_tile_full(tile): 183 | """ 184 | Generate an ordered grid of control points from a full tile. 185 | 186 | Parameters 187 | ---------- 188 | tile: `jax.Array` 189 | The 3D coordinates of a full tile. 190 | 191 | Returns 192 | ------- 193 | grid_points: `jax.Array` 194 | The 3D coordinates of the grid. 195 | """ 196 | grid_points = tile 197 | 198 | return grid_points 199 | 200 | 201 | # =============================================================================== 202 | # Grids 203 | # =============================================================================== 204 | 205 | class PointGrid: 206 | """ 207 | A grid of control points. 208 | 209 | Parameters 210 | ---------- 211 | tile: `jax.Array` 212 | The 3D coordinates of a tile. 213 | num_pts: `int` 214 | The number of points along one side of the grid. 215 | 216 | Notes 217 | ----- 218 | The order of the points in a 4x4 grid must be: 219 | 220 | 3 7 11 15 221 | 2 6 10 14 222 | 1 5 9 13 223 | 0 4 8 12 224 | """ 225 | def __init__(self, tile, num_pts) -> None: 226 | self.tile = tile 227 | self.num_pts = num_pts 228 | 229 | # NOTE: indices are hard-coded from Rhino to match expected grid order. 230 | self.indices = [15, 13, 5, 7, 14, 12, 4, 6, 10, 8, 0, 2, 11, 9, 1, 3] 231 | 232 | def points(self, transform=None): 233 | """ 234 | Get the reindexed and transformed control points of the grid. 235 | 236 | Parameters 237 | ---------- 238 | transform: `jax.Array` or `None`, optional 239 | The translation vector. 240 | If `None`, the control points are returned without any transformation. 241 | 242 | Returns 243 | ------- 244 | points: `jax.Array` 245 | The control points. 246 | """ 247 | tile = self.tile 248 | if transform is not None: 249 | tile = self.tile + transform 250 | 251 | points = self.points_grid(tile) 252 | grid_points = self.reindex_grid(points) 253 | 254 | return jnp.reshape(grid_points, (self.num_pts, self.num_pts, 3)) 255 | 256 | def reindex_grid(self, points): 257 | """ 258 | Reconfigure the grid using hard-coded indices. 259 | 260 | Parameters 261 | ---------- 262 | points: `jax.Array` 263 | The control points. 264 | 265 | Returns 266 | ------- 267 | reindexed_points: `jax.Array` 268 | The reindexed control points. 269 | """ 270 | return points[self.indices, :] 271 | 272 | def points_grid(self, tile): 273 | """ 274 | Generate the control points of the grid from a tile. 275 | 276 | Parameters 277 | ---------- 278 | tile: `jax.Array` 279 | The 3D coordinates of a tile. 280 | 281 | Returns 282 | ------- 283 | points: `jax.Array` 284 | The control points. 285 | """ 286 | raise NotImplementedError 287 | 288 | 289 | class PointGridSymmetricDouble(PointGrid): 290 | """ 291 | A doubly-symmetric grid of control points. 292 | 293 | Parameters 294 | ---------- 295 | size: `int` 296 | The size of the grid. 297 | num_pts: `int` 298 | The number of points along one side of the grid. 299 | """ 300 | def __init__(self, size, num_pts): 301 | tile = get_grid_tile_quarter(size, num_pts) 302 | super().__init__(tile, num_pts) 303 | 304 | def points_grid(self, tile): 305 | return calculate_grid_from_tile_quarter(tile) 306 | 307 | 308 | class PointGridSymmetric(PointGrid): 309 | """ 310 | A symmetric grid of control points. 311 | 312 | Parameters 313 | ---------- 314 | size: `int` 315 | The size of the grid. 316 | num_pts: `int` 317 | The number of points along one side of the grid. 
318 | """ 319 | def __init__(self, size, num_pts): 320 | tile = get_grid_tile_half(size, num_pts) 321 | super().__init__(tile, num_pts) 322 | 323 | def points_grid(self, tile): 324 | """ 325 | Generate the control points of the grid from a tile. 326 | 327 | Parameters 328 | ---------- 329 | tile: `jax.Array` 330 | The 3D coordinates of a tile. 331 | 332 | Returns 333 | ------- 334 | points: `jax.Array` 335 | The control points. 336 | """ 337 | return calculate_grid_from_tile_half(tile) 338 | 339 | 340 | class PointGridAsymmetric(PointGrid): 341 | """ 342 | An asymmetric grid of control points. 343 | 344 | Parameters 345 | ---------- 346 | size: `int` 347 | The size of the grid. 348 | num_pts: `int` 349 | The number of points along one side of the grid. 350 | """ 351 | def __init__(self, size, num_pts): 352 | tile = get_grid_tile_full(size, num_pts) 353 | super().__init__(tile, num_pts) 354 | 355 | def points_grid(self, tile): 356 | """ 357 | Generate the control points of the grid from a tile. 358 | 359 | Parameters 360 | ---------- 361 | tile: `jax.Array` 362 | The 3D coordinates of a tile. 363 | 364 | Returns 365 | ------- 366 | points: `jax.Array` 367 | The control points. 368 | """ 369 | return calculate_grid_from_tile_full(tile) 370 | -------------------------------------------------------------------------------- /src/neural_fdm/generators/tubes.py: -------------------------------------------------------------------------------- 1 | from jax import vmap 2 | 3 | import jax.random as jrn 4 | 5 | import jax.numpy as jnp 6 | 7 | from neural_fdm.generators.generator import PointGenerator 8 | 9 | 10 | # =============================================================================== 11 | # Generators 12 | # =============================================================================== 13 | 14 | class TubePointGenerator(PointGenerator): 15 | """ 16 | A generator that outputs point evaluated on a wiggled tube. 17 | """ 18 | pass 19 | 20 | 21 | class EllipticalTubePointGenerator(TubePointGenerator): 22 | """ 23 | A generator that outputs point evaluated on a wiggled elliptical tube. 24 | 25 | Parameters 26 | ---------- 27 | height: `float` 28 | The height of the tube. 29 | radius: `float` 30 | The reference radius of the tube. 31 | num_sides: `int` 32 | The number of sides per ellipse. 33 | num_levels: `int` 34 | The number of levels along the height of the tube. 35 | num_rings: `int` 36 | The number of levels that will work as compression rings. The first and last levels are fully supported. 37 | minval: `jax.Array` 38 | The minimum values of the space of random transformations. 39 | maxval: `jax.Array` 40 | The maximum values of the space of random transformations. 41 | """ 42 | def __init__( 43 | self, 44 | height, 45 | radius, 46 | num_sides, 47 | num_levels, 48 | num_rings, 49 | minval, 50 | maxval): 51 | 52 | # sanity checks 53 | assert num_rings >= 3, "Must include at least 1 ring in the middle!" 
54 | self._check_array_shapes(num_rings, minval, maxval) 55 | 56 | self.height = height 57 | self.radius = radius 58 | 59 | self.num_sides = num_sides 60 | self.num_levels = num_levels 61 | self.num_rings = num_rings 62 | 63 | self.minval = minval 64 | self.maxval = maxval 65 | 66 | self.levels_rings_comp = self._levels_rings_compression() 67 | self.indices_rings_comp_ravel = self._indices_rings_compression_ravel() 68 | self.indices_rings_comp_interior_ravel = self._indices_rings_compression_interior_ravel() 69 | 70 | self.levels_rings_tension = self._levels_rings_tension() 71 | 72 | self.shape_tube = (num_levels, num_sides, 3) 73 | self.shape_rings = (num_rings, num_sides, 3) 74 | 75 | def __call__(self, key, wiggle=True): 76 | """ 77 | Generate points. 78 | 79 | Parameters 80 | ---------- 81 | key: `jax.random.PRNGKey` 82 | The random key. 83 | wiggle: `bool`, optional 84 | Whether to wiggle the points at random. 85 | 86 | Returns 87 | ------- 88 | points: `jax.Array` 89 | The points on the tube. 90 | """ 91 | points = self.points_on_tube(key, wiggle) 92 | 93 | return jnp.ravel(points) 94 | 95 | def _levels_rings_tension(self): 96 | """ 97 | Compute the integer indices of the levels that work as tension rings. 98 | 99 | Returns 100 | ------- 101 | indices: `jax.Array` 102 | The indices. 103 | """ 104 | indices = [i for i in range(self.num_levels) if i not in self.levels_rings_comp] 105 | indices = jnp.array(indices, dtype=jnp.int64) 106 | 107 | assert indices.size == self.num_levels - self.num_rings 108 | 109 | return indices 110 | 111 | def _levels_rings_compression(self): 112 | """ 113 | Compute the integer indices of the levels that work as compression rings. 114 | 115 | Returns 116 | ------- 117 | indices: `jax.Array` 118 | The indices. 119 | """ 120 | step = int(self.num_levels / (self.num_rings - 1)) 121 | 122 | indices = [0] + list(range(step, self.num_levels - 1, step)) + [self.num_levels - 1] 123 | indices = jnp.array(indices, dtype=jnp.int64) 124 | 125 | assert indices.size == self.num_rings 126 | 127 | return indices 128 | 129 | def _indices_rings_compression_ravel(self): 130 | """ 131 | Compute the integer indices of the vertices in the compression rings. 132 | 133 | Returns 134 | ------- 135 | indices: `jax.Array` 136 | The indices. 137 | """ 138 | indices = [] 139 | for index in self.levels_rings_comp: 140 | start = index * self.num_sides 141 | end = start + self.num_sides 142 | indices.extend(range(start, end)) 143 | 144 | indices = jnp.array(indices, dtype=jnp.int64) 145 | 146 | return indices 147 | 148 | def _indices_rings_compression_interior_ravel(self): 149 | """ 150 | Compute the integer indices of the vertices in the unsupported compression rings. 151 | 152 | Returns 153 | ------- 154 | indices: `jax.Array` 155 | The indices. 156 | """ 157 | indices = [] 158 | for index in self.levels_rings_comp[1:-1]: 159 | start = index * self.num_sides 160 | end = start + self.num_sides 161 | indices.extend(range(start, end)) 162 | 163 | indices = jnp.array(indices, dtype=jnp.int64) 164 | 165 | return indices 166 | 167 | def wiggle(self, key): 168 | """ 169 | Sample random radii and angles from a uniform distribution. 170 | 171 | Parameters 172 | ---------- 173 | key: `jax.random.PRNGKey` 174 | The random key. 175 | 176 | Returns 177 | ------- 178 | transform: tuple of `jax.Array` 179 | The transformation factors for the radii and angles. 
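A worked example of the ring-level bookkeeping above, using hypothetical tower dimensions: with `num_levels=11` and `num_rings=3`, the step is `int(11 / 2) = 5`, so levels 0, 5, and 10 become compression rings, the first and last of which are the supported ones:

```python
import jax.numpy as jnp
from neural_fdm.generators import EllipticalTubePointGenerator

generator = EllipticalTubePointGenerator(
    height=10.0, radius=2.0, num_sides=16, num_levels=11, num_rings=3,
    minval=jnp.array([0.5, 0.5, 0.0]), maxval=jnp.array([2.0, 2.0, 0.0]),
)

print(generator.levels_rings_comp)     # [ 0  5 10]
print(generator.levels_rings_tension)  # the remaining 8 levels
assert generator.shape_tube == (11, 16, 3)
```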
180 | """ 181 | return self.wiggle_radii(key), self.wiggle_angle(key) 182 | 183 | def wiggle_radii(self, key): 184 | """ 185 | Sample random radii from a uniform distribution. 186 | 187 | Parameters 188 | ---------- 189 | key: `jax.random.PRNGKey` 190 | The random key. 191 | 192 | Returns 193 | ------- 194 | radii: `jax.Array` 195 | The random radii. 196 | """ 197 | shape = (self.num_rings, 2) 198 | minval = self.minval[:2] 199 | maxval = self.maxval[:2] 200 | 201 | return jrn.uniform(key, shape=shape, minval=minval, maxval=maxval) 202 | 203 | def wiggle_angle(self, key): 204 | """ 205 | Sample random angles from a uniform distribution. 206 | 207 | Parameters 208 | ---------- 209 | key: `jax.random.PRNGKey` 210 | The random key. 211 | 212 | Returns 213 | ------- 214 | angles: `jax.Array` 215 | The random angles. 216 | """ 217 | shape = (self.num_rings,) 218 | minval = self.minval[2] 219 | maxval = self.maxval[2] 220 | 221 | return jrn.uniform(key, shape=shape, minval=minval, maxval=maxval) 222 | 223 | def evaluate_points(self, transform): 224 | """ 225 | Generate wiggled points. 226 | 227 | Parameters 228 | ---------- 229 | transform: tuple of `jax.Array` 230 | The random radii and angles. 231 | 232 | Returns 233 | ------- 234 | points: `jax.Array` 235 | The points. 236 | """ 237 | heights = jnp.linspace(0.0, self.height, self.num_levels) 238 | radii = jnp.ones(shape=(self.num_levels, 2)) * self.radius 239 | angles = jnp.ones(shape=(self.num_levels,)) 240 | 241 | wiggle_radii, wiggle_angle = transform 242 | wiggle_radii = wiggle_radii * self.radius 243 | radii = radii.at[self.levels_rings_comp, :].set(wiggle_radii) 244 | angles = angles.at[self.levels_rings_comp].set(wiggle_angle) 245 | 246 | points = points_on_ellipses( 247 | radii[:, 0], 248 | radii[:, 1], 249 | heights, 250 | self.num_sides, 251 | angles, 252 | ) 253 | 254 | return jnp.ravel(points) 255 | 256 | def points_on_tube(self, key=None, wiggle=False): 257 | """ 258 | Evaluate wiggled points on the tube. 259 | 260 | Parameters 261 | ---------- 262 | key: `jax.random.PRNGKey` 263 | The random key. 264 | wiggle: `bool`, optional 265 | Whether to wiggle the points at random. 266 | 267 | Returns 268 | ------- 269 | points: `jax.Array` 270 | The points on the tube. 271 | """ 272 | heights = jnp.linspace(0.0, self.height, self.num_levels) 273 | radii = jnp.ones(shape=(self.num_levels, 2)) * self.radius 274 | angles = jnp.ones(shape=(self.num_levels,)) 275 | 276 | if wiggle: 277 | wiggle_radii, wiggle_angle = self.wiggle(key) 278 | wiggle_radii = wiggle_radii * self.radius 279 | radii = radii.at[self.levels_rings_comp, :].set(wiggle_radii) 280 | angles = angles.at[self.levels_rings_comp].set(wiggle_angle) 281 | 282 | points = points_on_ellipses( 283 | radii[:, 0], 284 | radii[:, 1], 285 | heights, 286 | self.num_sides, 287 | angles, 288 | ) 289 | 290 | return points 291 | 292 | def _check_array_shapes(self, num_rings, minval, maxval): 293 | """ 294 | Verify that input shapes are consistent. 295 | 296 | Parameters 297 | ---------- 298 | num_rings: `int` 299 | The number of rings. 300 | minval: `jax.Array` 301 | The minimum values of the space of random transformations. 302 | maxval: `jax.Array` 303 | The maximum values of the space of random transformations. 304 | """ 305 | shape = (3, ) 306 | minval_shape = minval.shape 307 | maxval_shape = maxval.shape 308 | 309 | assert minval_shape == shape, f"{minval_shape} vs. {shape}" 310 | assert maxval_shape == shape, f"{maxval_shape} vs. 
{shape}" 311 | 312 | 313 | class CircularTubePointGenerator(EllipticalTubePointGenerator): 314 | """ 315 | A generator that outputs point evaluated on a wiggled circular tube. 316 | """ 317 | def wiggle_radii(self, key): 318 | """ 319 | Sample random radii from a uniform distribution. 320 | 321 | Parameters 322 | ---------- 323 | key: `jax.random.PRNGKey` 324 | The random key. 325 | 326 | Returns 327 | ------- 328 | radii: `jax.Array` 329 | The random radii. 330 | """ 331 | shape = (self.num_rings,) 332 | minval = self.minval[0] 333 | maxval = self.maxval[0] 334 | 335 | return jrn.uniform(key, shape=shape, minval=minval, maxval=maxval) 336 | 337 | def points_on_tube(self, key=None, wiggle=False): 338 | """ 339 | Evaluate wiggled points on the tube. 340 | 341 | Parameters 342 | ---------- 343 | key: `jax.random.PRNGKey` 344 | The random key. 345 | wiggle: `bool`, optional 346 | Whether to wiggle the points at random. 347 | 348 | Returns 349 | ------- 350 | points: `jax.Array` 351 | The points on the tube. 352 | """ 353 | heights = jnp.linspace(0.0, self.height, self.num_levels) 354 | radii = jnp.ones(shape=(self.num_levels,)) * self.radius 355 | angles = jnp.ones(shape=(self.num_levels,)) 356 | 357 | if wiggle: 358 | wiggle_radii, wiggle_angle = self.wiggle(key) 359 | wiggle_radii = wiggle_radii * self.radius 360 | radii = radii.at[self.levels_rings_comp].set(wiggle_radii) 361 | angles = angles.at[self.levels_rings_comp].set(wiggle_angle) 362 | 363 | points = points_on_ellipses( 364 | radii, 365 | radii, 366 | heights, 367 | self.num_sides, 368 | angles, 369 | ) 370 | 371 | return points 372 | 373 | 374 | # =============================================================================== 375 | # Helper functions 376 | # =============================================================================== 377 | 378 | def points_on_ellipse_xy(radius_1, radius_2, num_sides, angle=0.0): 379 | """ 380 | Sample points on an ellipse on the XY plane. 381 | 382 | Parameters 383 | ---------- 384 | radius_1: `float` 385 | The radius of the ellipse along the X axis. 386 | radius_2: `float` 387 | The radius of the ellipse along the Y axis. 388 | num_sides: `int` 389 | The number of sides of the ellipse. 390 | angle: `float`, optional 391 | The angle of the ellipse in degrees relative to the X axis. 392 | 393 | Returns 394 | ------- 395 | points: `jax.Array` 396 | The points. 397 | 398 | Notes 399 | ----- 400 | The first and last points are not equal. 401 | """ 402 | angles = 2 * jnp.pi * jnp.linspace(0.0, 1.0, num_sides + 1) 403 | angles = jnp.reshape(angles, (-1, 1)) 404 | xs = radius_1 * jnp.cos(angles) 405 | ys = radius_2 * jnp.sin(angles) 406 | 407 | points = jnp.hstack((xs, ys))[:-1] 408 | 409 | # Calculate rotation matrix 410 | theta = jnp.radians(angle) 411 | rotation_matrix = jnp.array([ 412 | [jnp.cos(theta), -jnp.sin(theta)], 413 | [jnp.sin(theta), jnp.cos(theta)] 414 | ]) 415 | 416 | # Rotate points 417 | points = points @ rotation_matrix.T 418 | 419 | return points 420 | 421 | 422 | def points_on_ellipse(radius_1, radius_2, height, num_sides, angle=0.0): 423 | """ 424 | Sample points on a planar ellipse at a given height. 425 | 426 | Parameters 427 | ---------- 428 | radius_1: `float` 429 | The radius of the ellipse along the X axis. 430 | radius_2: `float` 431 | The radius of the ellipse along the Y axis. 432 | height: `float` 433 | The height of the ellipse. 434 | num_sides: `int` 435 | The number of sides of the ellipse. 
436 | angle: `float`, optional 437 | The angle of the ellipse in degrees relative to the X axis. 438 | 439 | Returns 440 | ------- 441 | points: `jax.Array` 442 | The points. 443 | 444 | Notes 445 | ----- 446 | The first and last points are not equal. 447 | """ 448 | xy = points_on_ellipse_xy(radius_1, radius_2, num_sides, angle) 449 | z = jnp.ones((num_sides, 1)) * height 450 | 451 | return jnp.hstack((xy, z)) 452 | 453 | 454 | def points_on_ellipses(radius_1, radius_2, heights, num_sides, angles): 455 | """ 456 | Sample points on an sequence of ellipses distributed over an array of heights. 457 | 458 | Parameters 459 | ---------- 460 | radius_1: `jax.Array` 461 | The radii of the ellipses along the X axis. 462 | radius_2: `jax.Array` 463 | The radii of the ellipses along the Y axis. 464 | heights: `jax.Array` 465 | The heights of the ellipses. 466 | num_sides: `int` 467 | The number of sides of the ellipses. 468 | angles: `jax.Array` 469 | The angles of the ellipses in degrees relative to the X axis. 470 | 471 | Returns 472 | ------- 473 | points: `jax.Array` 474 | The points on the ellipses. 475 | 476 | Notes 477 | ----- 478 | The first and last points per ellipse are not equal. 479 | """ 480 | polygon_fn = vmap(points_on_ellipse, in_axes=(0, 0, 0, None, 0)) 481 | 482 | return polygon_fn(radius_1, radius_2, heights, num_sides, angles) 483 | 484 | 485 | # =============================================================================== 486 | # Main 487 | # =============================================================================== 488 | 489 | if __name__ == "__main__": 490 | 491 | from compas.geometry import Polygon 492 | from jax_fdm.visualization import Viewer 493 | from neural_fdm.generators import EllipticalTubePointGenerator 494 | 495 | height = 10.0 496 | radius = 2.0 497 | radius_1 = 1.0 498 | radius_2 = 1.25 499 | 500 | num_sides = 4 501 | num_levels = 11 # Use 11 or 21 or 31 502 | num_rings = 3 # Use 3 or 4, 2 of them will be supported 503 | 504 | xy = points_on_ellipse_xy(radius_1, radius_2, num_sides) 505 | assert xy.shape == (num_sides, 2) 506 | xyz = jnp.hstack((xy, jnp.ones((num_sides, 1)) * height)) 507 | assert xyz.shape == (num_sides, 3) 508 | 509 | xyz2 = points_on_ellipse(radius_1, radius_2, height, num_sides) 510 | assert xyz2.shape == (num_sides, 3) 511 | assert jnp.allclose(xyz, xyz2) 512 | 513 | heights = jnp.linspace(0, height, num_levels) 514 | r1 = jnp.ones_like(heights) * radius_1 515 | r2 = jnp.ones_like(heights) * radius_2 516 | angles = jnp.zeros_like(heights) 517 | xyzs = points_on_ellipses(r1, r2, heights, num_sides, angles) 518 | assert xyzs.shape == (num_levels, num_sides, 3) 519 | 520 | print("\nGenerating") 521 | generator = EllipticalTubePointGenerator( 522 | height, 523 | radius, 524 | num_sides, 525 | num_levels, 526 | num_rings, 527 | minval=jnp.array([0.5, 0.5, 0.0]), 528 | maxval=jnp.array([2.0, 2.0, 0.0]), 529 | ) 530 | 531 | print(generator.indices_rings, generator.shape_tube) 532 | 533 | # randomness 534 | seed = 91 535 | key = jrn.PRNGKey(seed) 536 | _, generator_key = jrn.split(key, 2) 537 | batch_size = 3 538 | 539 | # sample data batch 540 | xyz_batch = vmap(generator)(jrn.split(generator_key, batch_size)) 541 | print(f"{xyz_batch.shape=}") 542 | xyz_ellipse_batch = vmap(generator.points_on_ellipses)(jrn.split(generator_key, batch_size)) 543 | 544 | # xyz_batch = jnp.reshape(xyz_batch, (batch_size, num_levels, num_sides, 3)) 545 | # print(f"{xyz_batch.shape=}") 546 | 547 | for xyzs, xyzs_ellipse in zip(xyz_batch, xyz_ellipse_batch): 548 
| print(f"{xyzs.shape=}") 549 | xyzs = jnp.reshape(xyzs, generator.shape_rings) 550 | assert jnp.allclose(xyzs, xyzs_ellipse) 551 | raise 552 | # assert xyzs.shape == (num_rings, num_sides, 3) 553 | # print("Generated") 554 | 555 | # print("Viewing") 556 | # viewer = Viewer(width=1600, height=900, show_grid=True) 557 | 558 | # for xyz in xyzs: 559 | # polygon = Polygon(xyz.tolist()) 560 | # viewer.add(polygon, opacity=0.5) 561 | 562 | # viewer.show() 563 | 564 | print("Viewing") 565 | from neural_fdm.builders import build_mesh_from_generator 566 | from jax_fdm.equilibrium import fdm 567 | from jax_fdm.datastructures import FDNetwork 568 | 569 | mesh = build_mesh_from_generator(generator) 570 | mesh.edges_forcedensities(10.0) 571 | # mesh = fdm(mesh, sparse=False) 572 | 573 | viewer = Viewer(width=1600, height=900, show_grid=True) 574 | viewer.add(mesh, opacity=0.5) 575 | viewer.add(FDNetwork.from_mesh(mesh), show_nodes=True, nodesize=0.2) 576 | viewer.show() 577 | -------------------------------------------------------------------------------- /src/neural_fdm/helpers.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | from jax_fdm.equilibrium import EquilibriumParametersState as FDParametersState 4 | from jax_fdm.equilibrium import EquilibriumState 5 | from jax_fdm.equilibrium import LoadState 6 | from jax_fdm.equilibrium import nodes_load_from_faces 7 | 8 | 9 | # =============================================================================== 10 | # Load helpers 11 | # =============================================================================== 12 | 13 | def calculate_area_loads(x, structure, load): 14 | """ 15 | Convert area loads into vertex loads. 16 | 17 | Parameters 18 | ---------- 19 | x: `jax.Array` 20 | The 3D coordinates of the vertices. 21 | structure: `jax_fdm.EquilibriumStructure` 22 | A structure with the discretization of the shape. 23 | load: `float` 24 | The vertical load per unit area in the `z` direction. 25 | 26 | Returns 27 | ------- 28 | vertices_load: `jax.Array` 29 | The 3D vertex loads. 30 | """ 31 | x = jnp.reshape(x, (-1, 3)) 32 | 33 | # need to convert loads into face loads 34 | num_faces = structure.num_faces 35 | faces_load_xy = jnp.zeros(shape=(num_faces, 2)) # (num_faces, xy) 36 | faces_load_z = jnp.ones(shape=(num_faces, 1)) * load # (num_faces, xy) 37 | faces_load = jnp.hstack((faces_load_xy, faces_load_z)) 38 | 39 | vertices_load = nodes_load_from_faces( 40 | x, 41 | faces_load, 42 | structure, 43 | is_local=False 44 | ) 45 | 46 | return vertices_load 47 | 48 | 49 | def calculate_constant_loads(x, structure, load): 50 | """ 51 | Create constant vertical vertex loads. 52 | 53 | Parameters 54 | ---------- 55 | x: `jax.Array` 56 | The 3D coordinates of the vertices. 57 | structure: `jax_fdm.EquilibriumStructure` 58 | A structure with the discretization of the shape. 59 | load: `float` 60 | The vertical load per vertex in the `z` direction. 61 | 62 | Returns 63 | ------- 64 | vertices_load: `jax.Array` 65 | The 3D vertex loads. 
66 | """ 67 | num_vertices = structure.num_vertices 68 | # (num_vertices, xy) 69 | vertices_load_xy = jnp.zeros(shape=(num_vertices, 2)) 70 | # (num_vertices, xy) 71 | vertices_load_z = jnp.ones(shape=(num_vertices, 1)) * load 72 | 73 | return jnp.hstack((vertices_load_xy, vertices_load_z)) 74 | 75 | 76 | # =============================================================================== 77 | # Form-finding helpers 78 | # =============================================================================== 79 | 80 | 81 | def edges_vectors(xyz, connectivity): 82 | """ 83 | Calculate the unnormalized edge directions (nodal coordinate differences). 84 | 85 | Parameters 86 | ---------- 87 | xyz: `jax.Array` 88 | The 3D coordinates of the vertices. 89 | connectivity: `jax.Array` 90 | The connectivity matrix of the structure. 91 | 92 | Returns 93 | ------- 94 | vectors: `jax.Array` 95 | The edge vectors. 96 | """ 97 | return connectivity @ xyz 98 | 99 | 100 | def edges_lengths(vectors): 101 | """ 102 | Compute the length of the edge vectors. 103 | 104 | Parameters 105 | ---------- 106 | vectors: `jax.Array` 107 | The edge vectors. 108 | 109 | Returns 110 | ------- 111 | lengths: `jax.Array` 112 | The lengths. 113 | """ 114 | return jnp.linalg.norm(vectors, axis=1, keepdims=True) 115 | 116 | 117 | def edges_forces(q, lengths): 118 | """ 119 | Calculate the force in the edges. 120 | 121 | Parameters 122 | ---------- 123 | q: `jax.Array` 124 | The force densities. 125 | lengths: `jax.Array` 126 | The edge lengths. 127 | 128 | Returns 129 | ------- 130 | forces: `jax.Array` 131 | The forces in the edges. 132 | """ 133 | return jnp.reshape(q, (-1, 1)) * lengths 134 | 135 | 136 | def vertices_residuals(q, loads, vectors, connectivity): 137 | """ 138 | Compute the residual forces on the vertices of the structure. 139 | 140 | Parameters 141 | ---------- 142 | q: `jax.Array` 143 | The force densities. 144 | loads: `jax.Array` 145 | The loads on the vertices. 146 | vectors: `jax.Array` 147 | The edge vectors. 148 | connectivity: `jax.Array` 149 | The connectivity matrix of the structure. 150 | 151 | Returns 152 | ------- 153 | residuals: `jax.Array` 154 | The residual forces on the vertices. 155 | """ 156 | return loads - connectivity.T @ (q[:, None] * vectors) 157 | 158 | 159 | def vertices_residuals_from_xyz(q, loads, xyz, structure): 160 | """ 161 | Compute the residual forces on the vertices of the structure. 162 | 163 | Parameters 164 | ---------- 165 | q: `jax.Array` 166 | The force densities. 167 | loads: `jax.Array` 168 | The loads on the vertices. 169 | xyz: `jax.Array` 170 | The 3D coordinates of the vertices. 171 | structure: `jax_fdm.EquilibriumStructure` 172 | A structure with the discretization of the shape. 173 | 174 | Returns 175 | ------- 176 | residuals: `jax.Array` 177 | The residual forces on the vertices. 178 | """ 179 | connectivity = structure.connectivity 180 | 181 | xyz = jnp.reshape(xyz, (-1, 3)) 182 | vectors = edges_vectors(xyz, connectivity) 183 | 184 | return vertices_residuals(q, loads, vectors, connectivity) 185 | 186 | 187 | def calculate_equilibrium_state(q, xyz, loads_nodes, structure): 188 | """ 189 | Assembles an equilibrium state object. 190 | 191 | Parameters 192 | ---------- 193 | q: `jax.Array` 194 | The force densities. 195 | xyz: `jax.Array` 196 | The 3D coordinates of the vertices. 197 | loads_nodes: `jax.Array` 198 | The loads on the vertices. 199 | structure: `jax_fdm.EquilibriumStructure` 200 | A structure with the discretization of the shape. 
201 | 202 | Returns 203 | ------- 204 | state: `jax_fdm.EquilibriumState` 205 | The equilibrium state. 206 | """ 207 | connectivity = structure.connectivity 208 | 209 | vectors = edges_vectors(xyz, connectivity) 210 | lengths = edges_lengths(vectors) 211 | residuals = vertices_residuals(q, loads_nodes, vectors, connectivity) 212 | forces = edges_forces(q, lengths) 213 | 214 | return EquilibriumState( 215 | xyz=xyz, 216 | residuals=residuals, 217 | lengths=lengths, 218 | forces=forces, 219 | loads=loads_nodes, 220 | vectors=vectors 221 | ) 222 | 223 | 224 | def calculate_fd_params_state(q, xyz_fixed, loads_nodes): 225 | """ 226 | Assembles an simulation parameters state. 227 | 228 | Parameters 229 | ---------- 230 | q: `jax.Array` 231 | The force densities. 232 | xyz_fixed: `jax.Array` 233 | The 3D coordinates of the fixed vertices. 234 | loads_nodes: `jax.Array` 235 | The loads on the vertices. 236 | 237 | Returns 238 | ------- 239 | state: `jax_fdm.EquilibriumParametersState` 240 | The current state of the simulation parameters. 241 | """ 242 | return FDParametersState(q, xyz_fixed, LoadState(loads_nodes, 0.0, 0.0)) 243 | -------------------------------------------------------------------------------- /src/neural_fdm/losses.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax import vmap 3 | 4 | from neural_fdm.helpers import vertices_residuals_from_xyz 5 | from neural_fdm.models import AutoEncoderPiggy 6 | 7 | 8 | # =============================================================================== 9 | # Loss assemblers 10 | # =============================================================================== 11 | 12 | def compute_loss( 13 | model, 14 | structure, 15 | x, 16 | loss_fn, 17 | loss_params, 18 | aux_data=False, 19 | piggy_mode=False 20 | ): 21 | """ 22 | Compute the model loss according to the model type. 23 | 24 | Parameters 25 | ---------- 26 | model: `eqx.Module` 27 | The model. 28 | structure: `jax_fdm.EquilibriumStructure` 29 | A structure with the discretization of the shape. 30 | x: `jax.Array` 31 | The target shape. 32 | loss_fn: `Callable` 33 | The loss function. 34 | loss_params: `dict` 35 | The scaling parameters to combine the loss' error terms. 36 | aux_data: `bool` 37 | If true, returns auxiliary data. 38 | piggy_mode: `bool` 39 | If true, the model is a piggy autoencoder. 40 | 41 | Returns 42 | ------- 43 | loss: `float` or `tuple` 44 | The loss. If `aux_data` is `True`, returns a tuple of the loss and the loss terms. 45 | """ 46 | predict_fn = vmap(model, in_axes=(0, None, None, None)) 47 | x_hat, data_hat = predict_fn(x, structure, True, piggy_mode) 48 | 49 | # TODO: make _loss_fn an input, and make isinstance check before running this function 50 | _loss_fn = _compute_loss 51 | if isinstance(model, AutoEncoderPiggy): 52 | _loss_fn = _compute_loss_piggy 53 | 54 | loss = _loss_fn( 55 | loss_fn, 56 | loss_params, 57 | x, 58 | x_hat, 59 | data_hat, 60 | structure, 61 | aux_data, 62 | piggy_mode 63 | ) 64 | 65 | return loss 66 | 67 | 68 | def _compute_loss( 69 | loss_fn, 70 | loss_params, 71 | x, 72 | x_hat, 73 | params_hat, 74 | structure, 75 | aux_data, 76 | piggy_mode=False 77 | ): 78 | """ 79 | Compute the model loss of an autoencoder. 80 | 81 | Parameters 82 | ---------- 83 | loss_fn: `Callable` 84 | The loss function. 85 | loss_params: `dict` 86 | The scaling parameters to combine the loss' error terms. 87 | x: `jax.Array` 88 | The target shape. 
89 | x_hat: `jax.Array` 90 | The predicted shape. 91 | params_hat: tuple of `jax.Array` 92 | The predicted force densities, loads, and fixed positions. 93 | structure: `jax_fdm.EquilibriumStructure` 94 | A structure with the discretization of the shape. 95 | aux_data: `bool` 96 | If true, returns auxiliary data. 97 | piggy_mode: `bool` 98 | If true, the model is a piggy autoencoder. 99 | 100 | Returns 101 | ------- 102 | loss: `float` or `tuple` 103 | The loss. If `aux_data` is `True`, returns a tuple of the loss and the loss terms. 104 | """ 105 | return loss_fn(x, x_hat, params_hat, structure, loss_params, aux_data) 106 | 107 | 108 | def _compute_loss_piggy( 109 | loss_fn, 110 | loss_params, 111 | x, 112 | x_data_hat, 113 | y_data_hat, 114 | structure, 115 | aux_data, 116 | piggy_mode=True, 117 | ): 118 | """ 119 | Compute the loss of a piggy autoencoder. 120 | 121 | Parameters 122 | ---------- 123 | loss_fn: `Callable` 124 | The loss function. 125 | loss_params: `dict` 126 | The scaling parameters to combine the loss' error terms. 127 | x: `jax.Array` 128 | The target shape. 129 | x_data_hat: `tuple` 130 | The predicted shape and the predicted parameters. 131 | y_data_hat: `tuple` 132 | The predicted shape and the predicted parameters. 133 | structure: `jax_fdm.EquilibriumStructure` 134 | A structure with the discretization of the shape. 135 | aux_data: `bool` 136 | If true, returns auxiliary data. 137 | piggy_mode: `bool` 138 | If true, the model is a piggy autoencoder. 139 | 140 | Returns 141 | ------- 142 | loss: `float` or `tuple` 143 | The loss. If `aux_data` is `True`, returns a tuple of the loss and the loss terms. 144 | """ 145 | x_hat, x_params_hat = x_data_hat 146 | 147 | if not piggy_mode: 148 | loss_data = loss_fn(x, x_hat, x_params_hat, structure, loss_params, aux_data) 149 | else: 150 | y_hat, y_params_hat = y_data_hat 151 | loss_data = loss_fn(x_hat, y_hat, y_params_hat, structure, loss_params, aux_data) 152 | 153 | return loss_data 154 | 155 | 156 | # =============================================================================== 157 | # Task losses 158 | # =============================================================================== 159 | 160 | def compute_loss_shell( 161 | x, 162 | x_hat, 163 | params_hat, 164 | structure, 165 | loss_params, 166 | aux_data, 167 | *args 168 | ): 169 | """ 170 | Compute the loss for the shell task. 171 | 172 | Parameters 173 | ---------- 174 | x: `jax.Array` 175 | The target shape. 176 | x_hat: `jax.Array` 177 | The predicted shape. 178 | params_hat: tuple of `jax.Array` 179 | The predicted force densities, loads, and fixed positions. 180 | structure: `jax_fdm.EquilibriumStructure` 181 | A structure with the discretization of the shape. 182 | loss_params: `dict` 183 | The scaling parameters to combine the loss' error terms. 184 | aux_data: `bool` 185 | If true, returns auxiliary data. 186 | 187 | Returns 188 | ------- 189 | loss: `float` or `tuple` 190 | The loss. If `aux_data` is `True`, returns a tuple of the loss and the loss terms. 
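Notes
-----
The loss is a weighted sum of an L1 shape error and a residual (physics)
error evaluated at the free vertices; each term is added to the total only
when its `include` flag in `loss_params` is set.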
191 | """ 192 | shape_params = loss_params["shape"] 193 | factor_shape = shape_params["weight"] 194 | loss_shape = compute_error_shape_l1(x, x_hat) 195 | loss_shape = factor_shape * loss_shape 196 | 197 | indices = structure.indices_free 198 | residual_params = loss_params["residual"] 199 | factor_residual = residual_params["weight"] 200 | loss_residual = compute_error_residual(x_hat, params_hat, structure, indices) 201 | loss_residual = factor_residual * loss_residual 202 | 203 | loss = 0.0 204 | if shape_params["include"]: 205 | loss = loss + loss_shape 206 | if residual_params["include"]: 207 | loss = loss + loss_residual 208 | 209 | loss_terms = { 210 | "loss": loss, 211 | "shape error": loss_shape, 212 | "residual error": loss_residual 213 | } 214 | 215 | if aux_data: 216 | return loss, loss_terms 217 | 218 | return loss 219 | 220 | 221 | def compute_loss_tower( 222 | x, 223 | x_hat, 224 | params_hat, 225 | structure, 226 | loss_params, 227 | aux_data, 228 | *args 229 | ): 230 | """ 231 | Compute the loss for the tower task. 232 | 233 | Parameters 234 | ---------- 235 | x: `jax.Array` 236 | The target shape. 237 | x_hat: `jax.Array` 238 | The predicted shape. 239 | params_hat: tuple of `jax.Array` 240 | The predicted force densities, loads, and fixed positions. 241 | structure: `jax_fdm.EquilibriumStructure` 242 | A structure with the discretization of the shape. 243 | loss_params: `dict` 244 | The scaling parameters to combine the loss' error terms. 245 | aux_data: `bool` 246 | If true, returns auxiliary data. 247 | 248 | Returns 249 | ------- 250 | loss: `float` or `tuple` 251 | The loss. If `aux_data` is `True`, returns a tuple of the loss and the loss terms. 252 | """ 253 | # compression ring shape 254 | shape_params = loss_params["shape"] 255 | factor_shape = shape_params["weight"] 256 | shape_dims = shape_params["dims"] 257 | levels_compression = shape_params["levels_compression"] 258 | 259 | def slice_xyz_rings(_x, levels): 260 | return jnp.reshape(_x, shape_dims)[levels, :, :].ravel() 261 | 262 | slice_xyz_vmap = vmap(slice_xyz_rings, in_axes=(0, None)) 263 | xyz_slice = slice_xyz_vmap(x, levels_compression) 264 | xyz_hat_slice = slice_xyz_vmap(x_hat, levels_compression) 265 | assert xyz_slice.shape == xyz_hat_slice.shape 266 | 267 | # NOTE: Using L2 norm here because L1 does not work well 268 | loss_shape = compute_error_shape_l2(xyz_slice, xyz_hat_slice) 269 | loss_shape = factor_shape * loss_shape 270 | 271 | # tension rings height 272 | height_params = loss_params["shape"] 273 | factor_height = height_params["weight"] 274 | height_dims = height_params["dims"] 275 | levels_tension = height_params["levels_tension"] 276 | 277 | def slice_z_rings(_x, levels): 278 | return jnp.reshape(_x, height_dims)[levels, :, 2].ravel() 279 | 280 | slice_z_vmap = vmap(slice_z_rings, in_axes=(0, None)) 281 | z_slice = slice_z_vmap(x, levels_tension) 282 | z_hat_slice = slice_z_vmap(x_hat, levels_tension) 283 | assert z_slice.shape == z_hat_slice.shape 284 | 285 | # NOTE: Using L2 norm here because L1 does not work well 286 | loss_height = compute_error_shape_l2(z_slice, z_hat_slice) 287 | loss_height = factor_height * loss_height 288 | 289 | # Add the shape and height losses 290 | loss_shape = loss_shape + loss_height 291 | 292 | # residual 293 | indices = structure.indices_free 294 | residual_params = loss_params["residual"] 295 | factor_residual = residual_params["weight"] 296 | loss_residual = compute_error_residual(x_hat, params_hat, structure, indices) 297 | loss_residual = 
factor_residual * loss_residual 298 | 299 | # regularization 300 | regularization_params = loss_params["regularization"] 301 | factor_regularization = regularization_params["weight"] 302 | q = params_hat[0] 303 | regularization = compute_q_regularization(q) 304 | regularization = factor_regularization * regularization 305 | 306 | loss = 0.0 307 | if shape_params["include"]: 308 | loss = loss + loss_shape 309 | if residual_params["include"]: 310 | loss = loss + loss_residual 311 | if regularization_params["include"]: 312 | loss = loss + regularization 313 | 314 | loss_terms = { 315 | "loss": loss, 316 | "shape error": loss_shape, 317 | "residual error": loss_residual, 318 | "regularization": regularization 319 | } 320 | 321 | if aux_data: 322 | return loss, loss_terms 323 | 324 | return loss 325 | 326 | 327 | # =============================================================================== 328 | # Shape approximation error 329 | # =============================================================================== 330 | 331 | def compute_error_shape_l1(x, x_hat): 332 | """ 333 | Calculate the L1 shape reconstruction error, averaged over the batch. 334 | 335 | Parameters 336 | ---------- 337 | x: `jax.Array` 338 | The target shape. 339 | x_hat: `jax.Array` 340 | The predicted shape. 341 | 342 | Returns 343 | ------- 344 | error: `float` 345 | The reconstruction error. 346 | """ 347 | error = jnp.abs(x - x_hat) 348 | batch_error = jnp.sum(error, axis=-1) 349 | 350 | return jnp.mean(batch_error, axis=-1) 351 | 352 | 353 | def compute_error_shape_l2(x, x_hat): 354 | """ 355 | Calculate the L2 shape reconstruction error, averaged over the batch. 356 | 357 | Parameters 358 | ---------- 359 | x: `jax.Array` 360 | The target shape. 361 | x_hat: `jax.Array` 362 | The predicted shape. 363 | 364 | Returns 365 | ------- 366 | error: `float` 367 | The reconstruction error. 368 | """ 369 | error = jnp.square(x - x_hat) 370 | batch_error = jnp.sum(error, axis=-1) 371 | 372 | return jnp.mean(batch_error, axis=-1) 373 | 374 | 375 | # =============================================================================== 376 | # Residual error 377 | # =============================================================================== 378 | 379 | def compute_error_residual(x_hat, params_hat, structure, indices): 380 | """ 381 | Calculate the residual error, averaged over the batch. This is the physics loss. 382 | 383 | Parameters 384 | ---------- 385 | x_hat: `jax.Array` 386 | The predicted shape. 387 | params_hat: tuple of `jax.Array` 388 | The predicted force densities, loads, and fixed positions. 389 | structure: `jax_fdm.EquilibriumStructure` 390 | A structure with the discretization of the shape. 391 | indices: `jax.Array` 392 | The indices of the free vertices to calculate the residual at. 393 | 394 | Returns 395 | ------- 396 | error: `float` 397 | The residual error. 398 | """ 399 | def calculate_residuals(_x_hat, _params_hat): 400 | # NOTE: Not using jnp.linalg.norm because we hitted NaNs. 
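# The squared residual components are returned here; the square root is taken
# only after summation in the enclosing function. This avoids differentiating
# a vector norm directly, whose gradient is undefined for zero-length
# residuals and can produce NaNs under autodiff.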
401 | q_hat, xyz_fixed, loads = _params_hat 402 | residual_vectors = vertices_residuals_from_xyz(q_hat, loads, _x_hat, structure) 403 | residual_vectors_free = jnp.ravel(residual_vectors[indices, :]) 404 | 405 | # return jnp.linalg.norm(residual_vectors_free, axis=-1) 406 | # return jnp.sqrt(jnp.sum(jnp.square(residual_vectors_free), axis=-1)) 407 | return jnp.square(residual_vectors_free) 408 | 409 | residuals = vmap(calculate_residuals)(x_hat, params_hat) 410 | shape_residuals = jnp.sqrt(jnp.sum(residuals, axis=-1)) 411 | batch_residual = jnp.mean(shape_residuals, axis=-1) 412 | 413 | return batch_residual 414 | 415 | 416 | # =============================================================================== 417 | # Regularization 418 | # =============================================================================== 419 | 420 | def compute_q_regularization(q): 421 | """ 422 | Calculate variance of the force densities for compression and tension. 423 | 424 | Parameters 425 | ---------- 426 | q: `jax.Array` 427 | The force densities. 428 | 429 | Returns 430 | ------- 431 | result: `float` 432 | The sum of the two variances. 433 | """ 434 | sign_q = jnp.sign(q) 435 | var_q_pos = jnp.var(q, where=sign_q > 0) 436 | var_q_neg = jnp.var(q, where=sign_q < 0) 437 | 438 | # NOTE: jnp.mean is doing nothing here because the size of the variance arrays is 1 439 | result = jnp.mean(var_q_pos) + jnp.mean(var_q_neg) 440 | 441 | return result 442 | 443 | 444 | # =============================================================================== 445 | # Utilities 446 | # =============================================================================== 447 | 448 | def print_loss_summary(loss_terms, prefix=None): 449 | """ 450 | Print a summary of the loss terms. 451 | 452 | Parameters 453 | ---------- 454 | loss_terms: `dict` 455 | The loss terms. 456 | prefix: `str` or `None`, optional 457 | The prefix to add to the loss terms printed to the console. 458 | """ 459 | msg_parts = [] 460 | if prefix: 461 | msg_parts.append(prefix) 462 | 463 | for label, term in loss_terms.items(): 464 | part = f"{label.capitalize()}: {term.item():.4f}" 465 | msg_parts.append(part) 466 | 467 | msg = "\t".join(msg_parts) 468 | print(msg) 469 | -------------------------------------------------------------------------------- /src/neural_fdm/mesh.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | from compas.utilities import geometric_key 4 | from compas.utilities import pairwise 5 | 6 | import jax.numpy as jnp 7 | 8 | from jax_fdm.datastructures import FDMesh 9 | 10 | from neural_fdm.generators import evaluate_bezier_surface 11 | 12 | 13 | def create_mesh_from_tube_generator(generator, config, *args, **kwargs): 14 | """ 15 | Boundary-supported mesh on a tube. The mesh has group tags. 16 | 17 | Parameters 18 | ---------- 19 | generator: `TubePointGenerator` 20 | The tube generator. 21 | config: `dict` 22 | The configuration dictionary. 23 | 24 | Returns 25 | ------- 26 | mesh: `jax_fdm.FDMesh` 27 | The mesh. 
28 | """ 29 | # shorthands 30 | tube = generator 31 | fix_rings = not config["loss"]["shape"]["include"] 32 | 33 | # generate base FD Mesh 34 | points = tube.points_on_tube() 35 | points = jnp.reshape(points, (-1, 3)) 36 | 37 | num_u = tube.num_levels 38 | num_v = tube.num_sides 39 | faces = calculate_mesh_tube_faces(num_u - 1, num_v - 1) 40 | mesh = FDMesh.from_vertices_and_faces(points, faces) 41 | 42 | # define structural system 43 | for vertices in mesh.vertices_on_boundaries(): 44 | mesh.vertices_supports(vertices) 45 | 46 | # tag edges as either rings or cables 47 | # first assume all edges ar cables 48 | mesh.edges_attribute("tag", "cable") 49 | 50 | # then, search for ring edges by geometric key 51 | points = jnp.reshape(points, tube.shape_tube) 52 | points_rings = points[tube.levels_rings_comp, :, :].tolist() 53 | gkey_key = mesh.gkey_key() 54 | 55 | num_ring_edges = 0 56 | for points_ring in points_rings: 57 | for line in pairwise(points_ring + points_ring[:1]): 58 | edge = tuple([gkey_key[geometric_key(pt)] for pt in line]) 59 | if not mesh.has_edge(edge): 60 | u, v = edge 61 | edge = v, u 62 | assert mesh.has_edge(edge) 63 | mesh.edge_attribute(edge, "tag", "ring") 64 | num_ring_edges += 1 65 | 66 | # NOTE: fix ring supports if no shape error in loss 67 | if fix_rings: 68 | mesh.vertices_supports(edge) 69 | 70 | assert num_ring_edges == tube.num_rings * tube.num_sides 71 | 72 | return mesh 73 | 74 | 75 | def create_mesh_from_bezier_generator(generator, *args, **kwargs): 76 | """ 77 | Boundary-supported mesh on bezier surface. 78 | 79 | Parameters 80 | ---------- 81 | generator: `BezierSurfacePointGenerator` 82 | The bezier surface generator. 83 | 84 | Returns 85 | ------- 86 | mesh: `jax_fdm.FDMesh` 87 | The mesh. 88 | """ 89 | # unpack parameters 90 | bezier = generator.surface 91 | u = generator.u 92 | v = generator.v 93 | 94 | # generate base FD Mesh 95 | srf_points = bezier.evaluate_points(u, v) 96 | srf_points = jnp.reshape(srf_points, (-1, 3)) 97 | 98 | num_u = u.shape[0] 99 | num_v = v.shape[0] 100 | faces = calculate_mesh_grid_faces(num_u - 1, num_v - 1) 101 | mesh = FDMesh.from_vertices_and_faces(srf_points, faces) 102 | 103 | # define structural system 104 | mesh.vertices_supports(mesh.vertices_on_boundary()) 105 | 106 | return mesh 107 | 108 | 109 | def create_mesh_from_grid(grid, u, v): 110 | """ 111 | Boundary-supported mesh on a grid of Bezier control points. 112 | 113 | Parameters 114 | ---------- 115 | grid: `PointGrid` 116 | The grid of control points. 117 | u: `jax.Array` 118 | The parameter values along the `u` direction in the range [0, 1]. 119 | v: `jax.Array` 120 | The parameter values along the `v` direction in the range [0, 1]. 121 | 122 | Returns 123 | ------- 124 | mesh: `jax_fdm.FDMesh` 125 | The mesh. 126 | """ 127 | # generate base FD Mesh 128 | srf_points = calculate_bezier_surface_points_from_grid(grid, u, v) 129 | 130 | num_u = u.shape[0] 131 | num_v = v.shape[0] 132 | faces = calculate_mesh_grid_faces(num_u - 1, num_v - 1) 133 | mesh = FDMesh.from_vertices_and_faces(srf_points, faces) 134 | 135 | # define structural system 136 | mesh.vertices_supports(mesh.vertices_on_boundary()) 137 | 138 | return mesh 139 | 140 | 141 | def calculate_mesh_grid_faces(nx, ny): 142 | """ 143 | Generate the indices of the mesh faces of the grid. 144 | 145 | Parameters 146 | ---------- 147 | nx: `int` 148 | The number of points along the `x` direction. 149 | ny: `int` 150 | The number of points along the `y` direction. 
151 | 152 | Returns 153 | ------- 154 | faces: `list` of `list` of `int` 155 | The indices of the mesh faces. 156 | """ 157 | faces = [] 158 | for i, j in product(range(nx), range(ny)): 159 | face = [ 160 | i * (ny + 1) + j, 161 | (i + 1) * (ny + 1) + j, 162 | (i + 1) * (ny + 1) + j + 1, 163 | i * (ny + 1) + j + 1, 164 | ] 165 | faces.append(face) 166 | 167 | return faces 168 | 169 | 170 | def calculate_mesh_tube_faces(nx, ny): 171 | """ 172 | Generate the indices of the mesh faces of a tube. 173 | 174 | Parameters 175 | ---------- 176 | nx: `int` 177 | The number of points along the `x` direction. 178 | ny: `int` 179 | The number of points along the `y` direction. 180 | 181 | Returns 182 | ------- 183 | faces: `list` of `list` of `int` 184 | The indices of the mesh faces. 185 | """ 186 | faces = calculate_mesh_grid_faces(nx, ny) 187 | 188 | num_xy = (nx + 1) * (ny + 1) 189 | starts = range(0, num_xy, ny + 1) 190 | ends = range(ny, num_xy + ny, ny + 1) 191 | 192 | for (a, b), (d, c) in zip(pairwise(starts), pairwise(ends)): 193 | face = [d, c, b, a] 194 | faces.append(face) 195 | 196 | return faces 197 | 198 | 199 | def calculate_bezier_surface_points_from_grid(grid, u, v): 200 | """ 201 | Evaluate points on a Bezier surface from a grid of control points. 202 | 203 | Parameters 204 | ---------- 205 | grid: `PointGrid` 206 | The grid of control points. 207 | u: `jax.Array` 208 | The parameter values along the `u` direction in the range [0, 1]. 209 | v: `jax.Array` 210 | The parameter values along the `v` direction in the range [0, 1]. 211 | 212 | Returns 213 | ------- 214 | points: `jax.Array` 215 | The points on the surface. 216 | """ 217 | # generate control points grid 218 | control_points = grid.points() 219 | 220 | # sample surface points on bezier 221 | surface_points = evaluate_bezier_surface(control_points, u, v) 222 | 223 | return jnp.reshape(surface_points, (-1, 3)) 224 | -------------------------------------------------------------------------------- /src/neural_fdm/plotting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def moving_average(data, window_size): 7 | """ 8 | Calculate the moving average of a data series. 9 | 10 | Parameters 11 | ---------- 12 | data: `numpy.ndarray` 13 | The data series to average. 14 | window_size: `int` 15 | The size of the window to average over. 16 | 17 | Returns 18 | ------- 19 | moving_average: `numpy.ndarray` 20 | The moving average of the data series. 21 | """ 22 | 23 | weights = np.ones(window_size) / window_size 24 | return np.convolve(data, weights, mode='valid') 25 | 26 | 27 | def plot_losses(loss_history, labels): 28 | """ 29 | Plot the convergence curve of a list of loss terms. 30 | 31 | Parameters 32 | ---------- 33 | loss_history: `list` of `dict` of `float` 34 | The loss histories during training. 35 | The keys are the plot labels. 36 | The values are the loss values. 37 | labels: `list` of `str` 38 | The labels of the losses to plot. 
39 | """ 40 | # Plotting 41 | plt.figure(figsize=(10, 6)) 42 | 43 | for label in labels: 44 | loss_values = [float(vals[label]) for vals in loss_history] 45 | plt.plot(loss_values, label=label) 46 | 47 | plt.title('Loss') 48 | plt.xlabel('Step') 49 | plt.ylabel('Loss') 50 | plt.yscale('log') 51 | plt.grid() 52 | plt.legend() 53 | plt.show() 54 | 55 | 56 | def plot_smoothed_losses(loss_history, window_size, labels): 57 | """ 58 | Plot the convergence curve of a list of loss terms with a moving average. 59 | 60 | Parameters 61 | ---------- 62 | loss_history: `list` of `dict` of `float` 63 | The loss histories during training. 64 | The keys are the plot labels. 65 | The values are the loss values. 66 | window_size: `int` 67 | The size of the window to average over. 68 | labels: `list` of `str` 69 | The labels of the losses to plot. 70 | """ 71 | # Plotting 72 | plt.figure(figsize=(10, 6)) 73 | 74 | for label in labels: 75 | 76 | # Loss values 77 | loss_values = [float(vals[label]) for vals in loss_history] 78 | 79 | # Calculate the moving average 80 | smooth_loss = moving_average(loss_values, window_size) 81 | 82 | # Adjust the length of original loss values to match the smoothed array 83 | adjusted_loss_values = loss_values[:len(smooth_loss)] 84 | 85 | # Plot 86 | lines = plt.plot(adjusted_loss_values, alpha=0.5, label=label) 87 | color = lines[-1].get_color() 88 | plt.plot(smooth_loss, color=color) 89 | 90 | plt.title('Loss') 91 | plt.xlabel('Step') 92 | plt.ylabel('Loss') 93 | plt.yscale('log') 94 | plt.grid() 95 | plt.legend() 96 | plt.show() 97 | 98 | 99 | def plot_smoothed_loss(loss_history, window_size): 100 | """ 101 | Plot the convergence curve of a loss term with a moving average. 102 | 103 | Parameters 104 | ---------- 105 | loss_history: `list` of `float` 106 | The loss values during training. 107 | window_size: `int` 108 | The size of the window to average over. 109 | """ 110 | # Plotting 111 | plt.figure(figsize=(10, 6)) 112 | 113 | # Calculate the moving average 114 | smooth_loss = moving_average(loss_history, window_size) 115 | 116 | # Adjust the length of original loss values to match the smoothed array 117 | adjusted_loss_values = loss_history[:len(smooth_loss)] 118 | 119 | # Plot 120 | color = "tab:blue" 121 | plt.plot(adjusted_loss_values, alpha=0.5, color=color) 122 | plt.plot(smooth_loss, color=color) 123 | 124 | plt.title('Loss') 125 | plt.xlabel('Step') 126 | plt.ylabel('Loss') 127 | plt.yscale('log') 128 | plt.grid() 129 | plt.show() 130 | -------------------------------------------------------------------------------- /src/neural_fdm/serialization.py: -------------------------------------------------------------------------------- 1 | import equinox as eqx 2 | 3 | 4 | def save_model(filename, model): 5 | """ 6 | Serialize and save a model to a file. 7 | 8 | Parameters 9 | ---------- 10 | filename: `str` 11 | The name of the file to save the model to. 12 | The file extension must be `.eqx`. 13 | model: `eqx.Module` 14 | The model to save. 15 | """ 16 | with open(filename, "wb") as f: 17 | eqx.tree_serialise_leaves(f, model) 18 | 19 | 20 | def load_model(filename, model_skeleton): 21 | """ 22 | Load a serialized model from a file. 23 | 24 | Parameters 25 | ---------- 26 | filename: `str` 27 | The name of the file to load the model from. 28 | The file extension must be `.eqx`. 29 | model_skeleton: `eqx.Module` 30 | The reference skeleton of the model to load the model into. 31 | 32 | Returns 33 | ------- 34 | model: `eqx.Module` 35 | The loaded model. 
36 | """ 37 | with open(filename, "rb") as f: 38 | return eqx.tree_deserialise_leaves(f, model_skeleton) 39 | -------------------------------------------------------------------------------- /src/neural_fdm/training.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from jax import vmap 4 | import jax.random as jrn 5 | import jax.tree_util as jtu 6 | 7 | import equinox as eqx 8 | 9 | from tqdm import tqdm 10 | 11 | from neural_fdm.models import AutoEncoderPiggy 12 | 13 | 14 | def train_step_piggy(model, structure, optimizer, generator, opt_state, *, loss_fn, batch_size, key): 15 | """ 16 | Update the parameters of an autoencoder piggy model on a batch of data for one step. 17 | 18 | Parameters 19 | ---------- 20 | model: `eqx.Module` 21 | The model to train. 22 | structure: `jax_fdm.EquilibriumStructure` 23 | A structure with the discretization of the shape. 24 | optimizer: `optax.GradientTransformation` 25 | The optimizer to use for training. 26 | generator: `PointGenerator` 27 | The data generator. 28 | opt_state: `optax.GradientTransformationExtraArgs` 29 | The current optimizer state. 30 | loss_fn: `Callable` 31 | The loss function. 32 | batch_size: `int` 33 | The number of samples to generate in each batch. 34 | key: `jax.random.PRNGKey` 35 | The random key. 36 | 37 | Returns 38 | ------- 39 | loss_vals: `dict` of `float` 40 | The values of the loss terms. 41 | model: `eqx.Module` 42 | The updated model. 43 | opt_state: `optax.GradientTransformationExtraArgs` 44 | The updated optimizer state. 45 | """ 46 | # sample fresh data 47 | keys = jrn.split(key, batch_size) 48 | x = vmap(generator)(keys) 49 | 50 | # calculate updates for main 51 | val_grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=True) 52 | (loss, loss_vals), grads_main = val_grad_fn( 53 | model, 54 | structure, 55 | x, 56 | True, 57 | False 58 | ) 59 | 60 | # calculate updates for piggy 61 | val_grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=True) 62 | (loss, loss_vals), grads_piggy = val_grad_fn( 63 | model, 64 | structure, 65 | x, 66 | True, 67 | True 68 | ) 69 | 70 | # combine gradients 71 | grads = jtu.tree_map(lambda x, y: x + y, grads_main, grads_piggy) 72 | 73 | # apply updates 74 | updates, opt_state = optimizer.update(grads, opt_state) 75 | model = eqx.apply_updates(model, updates) 76 | 77 | return loss_vals, model, opt_state 78 | 79 | 80 | def train_step(model, structure, optimizer, generator, opt_state, *, loss_fn, batch_size, key): 81 | """ 82 | Update the parameters of an autoencoder model on a batch of data for one step. 83 | 84 | Parameters 85 | ---------- 86 | model: `eqx.Module` 87 | The model to train. 88 | structure: `jax_fdm.EquilibriumStructure` 89 | A structure with the discretization of the shape. 90 | optimizer: `optax.GradientTransformation` 91 | The optimizer to use for training. 92 | generator: `PointGenerator` 93 | The data generator. 94 | opt_state: `optax.GradientTransformationExtraArgs` 95 | The current optimizer state. 96 | loss_fn: `Callable` 97 | The loss function. 98 | batch_size: `int` 99 | The number of samples to generate in each batch. 100 | key: `jax.random.PRNGKey` 101 | The random key. 102 | 103 | Returns 104 | ------- 105 | loss_vals: `dict` of `float` 106 | The values of the loss terms. 107 | model: `eqx.Module` 108 | The updated model. 109 | opt_state: `optax.GradientTransformationExtraArgs` 110 | The updated optimizer state. 
111 | """ 112 | # sample fresh data 113 | keys = jrn.split(key, batch_size) 114 | x = vmap(generator)(keys) 115 | 116 | # calculate updates 117 | val_grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=True) 118 | (loss, loss_vals), grads = val_grad_fn(model, structure, x, aux_data=True) 119 | 120 | # apply updates 121 | updates, opt_state = optimizer.update(grads, opt_state) 122 | model = eqx.apply_updates(model, updates) 123 | 124 | return loss_vals, model, opt_state 125 | 126 | 127 | def train_model(model, structure, optimizer, generator, *, loss_fn, num_steps, batch_size, key, callback=None): 128 | """ 129 | Train a model over a number of steps. 130 | 131 | Parameters 132 | ---------- 133 | model: `eqx.Module` 134 | The model to train. 135 | structure: `jax_fdm.EquilibriumStructure` 136 | A structure with the discretization of the shape. 137 | optimizer: `optax.GradientTransformation` 138 | The optimizer to use for training. 139 | generator: `PointGenerator` 140 | The data generator. 141 | loss_fn: `Callable` 142 | The loss function. 143 | num_steps: `int` 144 | The number of steps to train for (number of parameter updates). 145 | batch_size: `int` 146 | The number of samples to generate per batch. 147 | key: `jax.random.PRNGKey` 148 | The random key. 149 | callback: `Callable`, optional 150 | A callback function to call after each step. 151 | The callback function should take the following arguments: 152 | - model: `eqx.Module` 153 | - opt_state: `optax.GradientTransformationExtraArgs` 154 | - loss_vals: `dict` of `float` 155 | - step: `int` 156 | """ 157 | # initial optimization step 158 | opt_state = optimizer.init(eqx.filter(model, eqx.is_array)) 159 | 160 | # assemble train step 161 | train_step_fn = train_step 162 | if isinstance(model, AutoEncoderPiggy): 163 | train_step_fn = train_step_piggy 164 | 165 | train_step_fn = partial(train_step_fn, loss_fn=loss_fn) 166 | train_step_fn = eqx.filter_jit(train_step_fn) 167 | 168 | # train 169 | loss_history = [] 170 | for step in tqdm(range(num_steps)): 171 | 172 | # randomnesss 173 | key, _ = jrn.split(key) 174 | 175 | # train step 176 | loss_vals, model, opt_state = train_step_fn( 177 | model, 178 | structure, 179 | optimizer, 180 | generator, 181 | opt_state, 182 | batch_size=batch_size, 183 | key=key, 184 | ) 185 | 186 | # store loss values 187 | loss_history.append(loss_vals) 188 | 189 | # callback 190 | if callback: 191 | callback(model, opt_state, loss_vals, step) 192 | 193 | return model, loss_history 194 | --------------------------------------------------------------------------------