├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── demo_design_optimization.ipynb ├── requirements.txt └── src ├── connectivity_utils.py ├── graph_network.py ├── learned_simulator.py ├── model_utils.py ├── normalizers.py └── watercourse_env.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Contributor License Agreement 4 | 5 | Contributions to this project must be accompanied by a Contributor License 6 | Agreement. You (or your employer) retain the copyright to your contribution, 7 | this simply gives us permission to use and redistribute your contributions as 8 | part of the project. Head over to <https://cla.developers.google.com/> to see 9 | your current agreements on file or to sign a new one. 10 | 11 | You generally only need to submit a CLA once, so if you've already submitted one 12 | (even if it was for a different project), you probably don't need to do it 13 | again. 14 | 15 | ## Code reviews 16 | 17 | All submissions, including submissions by project members, require review. We 18 | use GitHub pull requests for this purpose. Consult 19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 20 | information on using pull requests. 21 | 22 | ## Community Guidelines 23 | 24 | This project follows [Google's Open Source Community 25 | Guidelines](https://opensource.google/conduct/).
26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Inverse Design for fluid-structure interactions using graph network simulators 2 | 3 | Code and parameters to accompany the NeurIPS 2022 paper 4 | **Inverse Design for Fluid-Structure Interactions using Graph Network 5 | Simulators** ([arXiv](https://arxiv.org/abs/2202.00728))
6 | _Kelsey R. Allen*, Tatiana Lopez-Guevara*, Kimberly Stachenfeld*, 7 | Alvaro Sanchez-Gonzalez, Peter Battaglia, Jessica Hamrick, Tobias Pfaff_ 8 | 9 | The code here provides an implementation of the Encode-Process-Decode 10 | graph network architecture in jax, model weights for this architecture trained 11 | on the 3D WaterCourse environment, and an example of performing gradient-based 12 | optimization in order to optimize a landscape to reroute water. 13 | 14 | ## Usage 15 | 16 | ### in a google colab 17 | Open the [google colab](https://colab.research.google.com/github/deepmind/inverse_design/blob/master/demo_design_optimization.ipynb) and run all cells. 18 | 19 | ### with jupyter notebook / locally 20 | To install the necessary requirements (run these commands from the directory 21 | that you wish to clone `inverse_design` into): 22 | 23 | ```shell 24 | git clone https://github.com/deepmind/inverse_design.git 25 | python3 -m venv id_venv 26 | source id_venv/bin/activate 27 | pip install --upgrade pip 28 | pip install -r ./inverse_design/requirements.txt 29 | ``` 30 | 31 | Additionally install jupyter notebook if not already installed with 32 | `pip install notebook` 33 | 34 | Finally, make a new directory within the `inverse_design` repository and move 35 | files there: 36 | ```shell 37 | cd inverse_design 38 | mkdir inverse_design 39 | mv src/ inverse_design/ 40 | ``` 41 | 42 | Download the dataset and model weights from google cloud: 43 | ```shell 44 | wget -O ./gns_params.pickle https://storage.googleapis.com/dm_inverse_design_watercourse/gns_params.pickle 45 | wget -O ./init_sequence.pickle https://storage.googleapis.com/dm_inverse_design_watercourse/init_sequence.pickle 46 | ``` 47 | 48 | Now you should be ready to go! Open `demo_design_optimization.ipynb` inside 49 | a jupyter notebook and run *from third cell* onwards. 
50 | 51 | ## Citing this work 52 | 53 | If you use this work, please cite the following paper 54 | ``` 55 | @misc{inversedesign_2022, 56 | title = {Inverse Design for Fluid-Structure Interactions using Graph Network Simulators}, 57 | author = {Kelsey R. Allen and 58 | Tatiana Lopez{-}Guevara and 59 | Kimberly L. Stachenfeld and 60 | Alvaro Sanchez{-}Gonzalez and 61 | Peter W. Battaglia and 62 | Jessica B. Hamrick and 63 | Tobias Pfaff}, 64 | journal = {Neural Information Processing Systems}, 65 | year = {2022}, 66 | } 67 | ``` 68 | ## License and disclaimer 69 | 70 | Copyright 2022 DeepMind Technologies Limited 71 | 72 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0); 73 | you may not use this file except in compliance with the Apache 2.0 license. 74 | You may obtain a copy of the Apache 2.0 license at: 75 | https://www.apache.org/licenses/LICENSE-2.0 76 | 77 | All other materials are licensed under the Creative Commons Attribution 4.0 78 | International License (CC-BY). You may obtain a copy of the CC-BY license at: 79 | https://creativecommons.org/licenses/by/4.0/legalcode 80 | 81 | Unless required by applicable law or agreed to in writing, all software and 82 | materials distributed here under the Apache 2.0 or CC-BY licenses are 83 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 84 | either express or implied. See the licenses for the specific language governing 85 | permissions and limitations under those licenses. 86 | 87 | This is not an official Google product. 
88 | -------------------------------------------------------------------------------- /demo_design_optimization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "SyNZcVgQPsKs" 7 | }, 8 | "source": [ 9 | "Copyright 2022 DeepMind Technologies Limited\n", 10 | "\n", 11 | "Licensed under the Apache License, Version 2.0 (the \"License\");\n", 12 | "you may not use this file except in compliance with the License.\n", 13 | "You may obtain a copy of the License at\n", 14 | "\n", 15 | " https://www.apache.org/licenses/LICENSE-2.0\n", 16 | "\n", 17 | "Unless required by applicable law or agreed to in writing, software\n", 18 | "distributed under the License is distributed on an \"AS IS\" BASIS,\n", 19 | "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", 20 | "See the License for the specific language governing permissions and\n", 21 | "limitations under the License.\n" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": { 27 | "id": "I8K_e58H8S9d" 28 | }, 29 | "source": [ 30 | "# Demo design optimization for 3D WaterCourse environment\n" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": { 37 | "id": "1jNaCf3sM1Xm" 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "#@title Installation (if not running locally)\n", 42 | "# Note, this should be skipped if running locally.\n", 43 | "!mkdir /content/inverse_design\n", 44 | "!mkdir /content/inverse_design/src\n", 45 | "!touch /content/inverse_design/__init__.py\n", 46 | "!touch /content/inverse_design/src/__init__.py\n", 47 | "\n", 48 | "!wget -O /content/inverse_design/src/connectivity_utils.py https://raw.githubusercontent.com/deepmind/inverse_design/master/src/connectivity_utils.py\n", 49 | "!wget -O /content/inverse_design/src/graph_network.py
https://raw.githubusercontent.com/deepmind/inverse_design/master/src/graph_network.py\n", 50 | "!wget -O /content/inverse_design/src/learned_simulator.py https://raw.githubusercontent.com/deepmind/inverse_design/master/src/learned_simulator.py\n", 51 | "!wget -O /content/inverse_design/src/model_utils.py https://raw.githubusercontent.com/deepmind/inverse_design/master/src/model_utils.py\n", 52 | "!wget -O /content/inverse_design/src/normalizers.py https://raw.githubusercontent.com/deepmind/inverse_design/master/src/normalizers.py\n", 53 | "!wget -O /content/inverse_design/src/watercourse_env.py https://raw.githubusercontent.com/deepmind/inverse_design/master/src/watercourse_env.py\n", 54 | "\n", 55 | "!wget -O /content/requirements.txt https://raw.githubusercontent.com/deepmind/inverse_design/master/requirements.txt\n", 56 | "!pip install -r requirements.txt" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": { 63 | "id": "Q08NCU1pds4j" 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "#@title Download Pickled Dataset \u0026 Params (if running in colab)\n", 68 | "# Note this can be skipped if following instructions for jupyter notebook\n", 69 | "from google.colab import auth\n", 70 | "auth.authenticate_user()\n", 71 | "\n", 72 | "!gsutil cp gs://dm_inverse_design_watercourse/init_sequence.pickle .\n", 73 | "!gsutil cp gs://dm_inverse_design_watercourse/gns_params.pickle .\n" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": { 80 | "id": "3gfBGlIrCkBf" 81 | }, 82 | "outputs": [], 83 | "source": [ 84 | "#@title Imports\n", 85 | "from inverse_design.src import learned_simulator\n", 86 | "from inverse_design.src import model_utils\n", 87 | "from inverse_design.src import watercourse_env" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": { 94 | "id": "SJVrIfyHnkOK" 95 | }, 96 | "outputs": [], 97 | "source": [ 98 | "#@title Open pickled
parameters + dataset\n", 99 | "import pickle\n", 100 | "\n", 101 | "with open('init_sequence.pickle', \"rb\") as f:\n", 102 | " pickled_data = pickle.loads(f.read())\n", 103 | " gt_sequence = pickled_data['gt_sequence']\n", 104 | " meta = pickled_data['meta']\n", 105 | "\n", 106 | "with open('gns_params.pickle', \"rb\") as f:\n", 107 | " pickled_params = pickle.loads(f.read())\n", 108 | " network = pickled_params['network']\n", 109 | " plan = pickled_params['plan']" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": { 116 | "id": "kdxb0hSLPRwH" 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "#@title make ID control/loss functions\n", 121 | "import jax\n", 122 | "import functools\n", 123 | "\n", 124 | "# maximum number of edges for single step of rollout to pad to\n", 125 | "MAX_EDGES = 2**16\n", 126 | "\n", 127 | "# define haiku model\n", 128 | "connectivity_radius = meta[\"connectivity_radius\"]\n", 129 | "flatten_fn = functools.partial(model_utils.flatten_features, **plan['flatten_kwargs'])\n", 130 | "haiku_model = functools.partial(learned_simulator.LearnedSimulator, connectivity_radius=connectivity_radius, flatten_features_fn=flatten_fn, **plan['model_kwargs'])\n", 131 | "\n", 132 | "# create initial landscape (obstacle) in the scene\n", 133 | "obstacle_pos = watercourse_env.make_plain_obstacles()\n", 134 | "for frame in gt_sequence:\n", 135 | " pos = frame.nodes['world_position'].copy()\n", 136 | " pos[:obstacle_pos.shape[0]] = obstacle_pos[:, None]\n", 137 | " frame.nodes['world_position'] = pos\n", 138 | "\n", 139 | "\n", 140 | "# get initial sequence of particles from dataset for initial graph\n", 141 | "obstacle_edges, inflow_stack, initial_graph = watercourse_env.build_initial_graph(gt_sequence[15:], max_edges=MAX_EDGES)\n", 142 | "\n", 143 | "# infer the landscape size from the dataset (25 x 25)\n", 144 | "# note that this is not required, it is also possible to create a smaller\n", 145 | "# or 
larger landscape (obstacle) as the design space\n", 146 | "num_side = int(jax.numpy.sqrt(initial_graph.nodes['obstacle_mask'].sum()))\n", 147 | "n_obs = num_side**2\n", 148 | "\n", 149 | "# rollout length definition (final state taken for reward computation)\n", 150 | "length = 50\n", 151 | "# radius within which to connect particles\n", 152 | "radius = 0.1\n", 153 | "# smoothing factor for loss\n", 154 | "smoothing_factor = 1e2\n", 155 | "\n", 156 | "@jax.jit\n", 157 | "def run(vars):\n", 158 | " # create landscape as graph from vars parameters\n", 159 | " graph, raw_obs = watercourse_env.design_fn(vars, initial_graph)\n", 160 | "\n", 161 | " # rollout\n", 162 | " final_graph, traj = watercourse_env.rollout(\n", 163 | " graph, inflow_stack[:length], network, haiku_model,\n", 164 | " obstacle_edges, radius=radius)\n", 165 | " \n", 166 | " # losses\n", 167 | " losses = {\n", 168 | " 'objective': watercourse_env.max_x_loss_fn(final_graph),\n", 169 | " 'smooth': smoothing_factor * watercourse_env.smooth_loss_fn(raw_obs, num_side=num_side),\n", 170 | " }\n", 171 | "\n", 172 | " # auxiliaries to keep track of for plotting\n", 173 | " aux = {\n", 174 | " 'design': vars,\n", 175 | " 'losses': losses,\n", 176 | " 'traj': traj\n", 177 | " }\n", 178 | " return sum(losses.values()), aux\n" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": { 185 | "id": "TuuHYP3dOW45" 186 | }, 187 | "outputs": [], 188 | "source": [ 189 | "from IPython.display import clear_output\n", 190 | "import jax.numpy as jnp\n", 191 | "import matplotlib.pyplot as plt\n", 192 | "import optax\n", 193 | "\n", 194 | "# set learning rate and number of optimization steps\n", 195 | "LEARNING_RATE = 0.05\n", 196 | "num_opt_steps = 100\n", 197 | "\n", 198 | "# define optimizer as adam with learning rate\n", 199 | "optimizer = optax.adam(learning_rate=LEARNING_RATE)\n", 200 | "\n", 201 | "# initialize design parameters to be zeros (flat landscape)\n", 202 | "params 
= jnp.zeros(n_obs, dtype=jnp.float32)\n", 203 | "opt_state = optimizer.init(params)\n", 204 | "\n", 205 | "# initialize empty optimization trajectory (for tracking improvements to losses and design)\n", 206 | "opt_traj = []\n", 207 | "\n", 208 | "# optimization step with current design parameters\n", 209 | "@jax.jit\n", 210 | "def opt_step(params, opt_state):\n", 211 | " grads, aux = jax.grad(run, has_aux=True)(params)\n", 212 | " updates, opt_state = optimizer.update(grads, opt_state, params)\n", 213 | " params = optax.apply_updates(params, updates)\n", 214 | " return params, opt_state, aux\n", 215 | "\n", 216 | "# run optimization loop and track progress\n", 217 | "for i in range(num_opt_steps):\n", 218 | " params, opt_state, aux = opt_step(params, opt_state)\n", 219 | " opt_traj.append(aux)\n", 220 | " clear_output(wait=True)\n", 221 | " fig, ax = plt.subplots(1,1,figsize=(10,5))\n", 222 | " for key in aux['losses'].keys():\n", 223 | " ax.plot([t['losses'][key] for t in opt_traj])\n", 224 | " ax.plot([sum(t['losses'].values()) for t in opt_traj])\n", 225 | " ax.legend(list(aux['losses'].keys())+['total'])\n", 226 | " plt.show()" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "metadata": { 233 | "id": "88voGdYuAPus" 234 | }, 235 | "outputs": [], 236 | "source": [ 237 | "import numpy as np\n", 238 | "\n", 239 | "# plot design iterations (every 10 steps)\n", 240 | "n_sam = range(0, len(opt_traj), 10)\n", 241 | "fig, ax = plt.subplots(1,len(n_sam),figsize=(len(n_sam)*10, 10), squeeze=False)\n", 242 | "\n", 243 | "for fi, idx in enumerate(n_sam):\n", 244 | " design = opt_traj[idx]['design']\n", 245 | "\n", 246 | " # control function uses tanh as transformation, so mimic here to see heightfield\n", 247 | " fld = np.tanh(design.reshape((num_side, num_side)))\n", 248 | " ax[0, fi].imshow(fld, vmin=-1, vmax=1)\n", 249 | " ax[0, fi].set_axis_off()" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 
255 | "metadata": { 256 | "id": "yPaWTUM0FL36" 257 | }, 258 | "outputs": [], 259 | "source": [ 260 | "from IPython.display import clear_output\n", 261 | "# plot video of how particles move for optimized design and initial design\n", 262 | "\n", 263 | "def _plt(ax, frame, i):\n", 264 | " pos = frame['pos'][i] \n", 265 | " p = pos[frame['mask'][i]]\n", 266 | " ax.scatter(p[:, 0], p[:, 2], p[:, 1], c='b',s=10)\n", 267 | " obs = pos[:num_side**2]\n", 268 | " ax.scatter(obs[:, 0], obs[:, 2], obs[:, 1], c='k',s=3)\n", 269 | " ax.scatter([1.5],[1.5],[0], c='g',s=20)\n", 270 | " ax.set_xlim([-0.6, 1.6])\n", 271 | " ax.set_ylim([-0.1, 1.6])\n", 272 | " ax.set_zlim([-0.1, 1.2])\n", 273 | "\n", 274 | "roll_fin0 = run(opt_traj[0]['design'])[1]['traj']\n", 275 | "roll_fin1 = run(opt_traj[-1]['design'])[1]['traj']\n", 276 | "\n", 277 | "for i in range(roll_fin0['pos'].shape[0]):\n", 278 | " clear_output(wait=True)\n", 279 | " fig = plt.figure(figsize=(20,10))\n", 280 | " ax1 = fig.add_subplot(1, 2, 1, projection='3d')\n", 281 | " ax1.set_title('Initial design, frame %d' % i)\n", 282 | " _plt(ax1, roll_fin0, i)\n", 283 | " ax2 = fig.add_subplot(1, 2, 2, projection='3d')\n", 284 | " ax2.set_title('Design at final step')\n", 285 | " _plt(ax2, roll_fin1, i)\n", 286 | " plt.show()" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": null, 292 | "metadata": { 293 | "id": "csTLqGTUPuAn" 294 | }, 295 | "outputs": [], 296 | "source": [] 297 | } 298 | ], 299 | "metadata": { 300 | "colab": { 301 | "private_outputs": true, 302 | "provenance": [ 303 | { 304 | "file_id": "1ZYL6nDmJCvzc70qi5rwIQ7BSj9a2sdoL", 305 | "timestamp": 1667473449566 306 | }, 307 | { 308 | "file_id": "1rOmaBHyQAVa6NY3nfsM1yba8GC0a8duM", 309 | "timestamp": 1667402678487 310 | }, 311 | { 312 | "file_id": "1bHb2szUEMLP2gKErlnxrigutZyhZ4reh", 313 | "timestamp": 1667386305836 314 | }, 315 | { 316 | "file_id": "1jrPkT2OlRGUpImN2cVetBtQ9PgFelNgy", 317 | "timestamp": 1665137721073 318 | }, 319 | { 320 | 
"file_id": "144LaPYCJpaCSUX6jMoampo0dTMICcXWV", 321 | "timestamp": 1640626024159 322 | }, 323 | { 324 | "file_id": "1B31Dl0yzZlW22l-X7qSR5S4w1nOHPAg0", 325 | "timestamp": 1640104963946 326 | }, 327 | { 328 | "file_id": "1MXdt8a59t0hLm7Q1Prem3kdHpnixMOeJ", 329 | "timestamp": 1639481266768 330 | }, 331 | { 332 | "file_id": "14OG1REDMq3i9phUABzuSWg3nQhUD_ozK", 333 | "timestamp": 1639156624846 334 | } 335 | ] 336 | }, 337 | "kernelspec": { 338 | "display_name": "Python 3", 339 | "name": "python3" 340 | }, 341 | "language_info": { 342 | "name": "python" 343 | } 344 | }, 345 | "nbformat": 4, 346 | "nbformat_minor": 0 347 | } 348 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.2.0 2 | dm-haiku==0.0.9 3 | ml_collections==0.1.1 4 | numpy>=1.21.0 5 | optax==0.1.3 6 | jraph>=0.0.5.dev0 7 | scikit-learn>=0.24.0 8 | matplotlib>=3.6.0 9 | -------------------------------------------------------------------------------- /src/connectivity_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# ==============================================================================

"""Tools to compute the connectivity of the graph."""
import functools

import jax
# NOTE(review): `jax.experimental.host_callback` is deprecated (and removed in
# recent JAX releases); consider migrating to `jax.pure_callback`.
from jax.experimental import host_callback as hcb
import jax.numpy as jnp
import numpy as np
from sklearn import neighbors


def _cb_radius_query(args):
  """Host callback function to compute connectivity.

  For every graph in the batch, connects each query node to every valid node
  within `radius`, removes self-edges, and makes the edge set symmetric (each
  undirected pair appears exactly once in each direction). The edges of all
  graphs are then concatenated and padded up to `max_edges`.

  Args:
    args: tuple `(padded_pos, n_node, radius, max_edges, query_mask,
      node_mask)` as packed by `compute_fixed_radius_connectivity_jax`.
      `query_mask` / `node_mask` are per-node boolean masks; `None` selects
      every node.

  Returns:
    Tuple `(n_edge, edges)`. `n_edge` is an int32 array with one entry per
    graph plus a final entry counting the padding edges. `edges` is an int32
    `[max_edges, 2]` array of global node indices with receivers in column 0
    and senders in column 1.

  Raises:
    ValueError: if the number of edges found is not strictly smaller than
      `max_edges`.
  """
  padded_pos, n_node, radius, max_edges, query_mask, node_mask = args
  # A missing mask means "use every node", as documented by the public API.
  if query_mask is None:
    query_mask = np.ones(len(padded_pos), dtype=bool)
  if node_mask is None:
    node_mask = np.ones(len(padded_pos), dtype=bool)

  edges = []
  offset = 0

  for num_nodes in n_node:
    graph_slice = slice(offset, offset + num_nodes)
    graph_pos = padded_pos[graph_slice]
    graph_query_mask = query_mask[graph_slice]
    graph_node_mask = node_mask[graph_slice]
    pos_nodes = graph_pos[graph_node_mask]
    pos_query = graph_pos[graph_query_mask]

    # indices: [num_edges, 2] array of receivers ([:, 0]) and senders ([:, 1]),
    # numbered within the masked subsets `pos_query` and `pos_nodes`.
    indices = compute_fixed_radius_connectivity_np(pos_nodes, radius, pos_query)

    # Map masked-subset indices back to per-graph node indices.
    local_ids = np.arange(num_nodes, dtype=np.int32)
    indices[:, 0] = local_ids[graph_query_mask][indices[:, 0]]
    indices[:, 1] = local_ids[graph_node_mask][indices[:, 1]]

    # Remove self-edges.
    indices = indices[indices[:, 0] != indices[:, 1]]

    # Create unique two-way edges (only necessary in the masked case): sort
    # each pair, deduplicate, then add the reversed copy.
    indices = np.stack([np.min(indices, axis=1),
                        np.max(indices, axis=1)],
                       axis=1)
    indices = np.unique(indices, axis=0)
    indices = np.concatenate([indices, indices[:, [1, 0]]], axis=0)

    edges.append(indices + offset)
    offset += num_nodes

  n_edge = [x.shape[0] for x in edges]
  # Built-in sum keeps this an int even when there are no graphs.
  total_edges = int(sum(n_edge))

  # At least one slot must remain for the padding graph.
  if total_edges >= max_edges:
    raise ValueError("%d edges found, max_edges: %d" % (total_edges, max_edges))

  # Create a [padding_size, 2] padding array, which connects the first dummy
  # padding node (with index `offset`, one past the last real node) to itself.
  padding_size = max_edges - total_edges
  padding = np.full((padding_size, 2), offset, dtype=np.int32)
  edges = np.concatenate(edges + [padding], axis=0)
  n_edge = np.array(n_edge + [padding_size], dtype=np.int32)
  return n_edge, edges


@functools.partial(jax.custom_jvp, nondiff_argnums=(4, 5))
def compute_fixed_radius_connectivity_jax(positions, n_node, query_mask,
                                          node_mask, radius, max_edges):
  """Computes connectivity for batched graphs using a jax host callback.

  Args:
    positions: concatenated `(N, D)` array of node positions for all graphs.
    n_node: array of num_nodes for each graph.
    query_mask: boolean mask selecting the nodes to query from (None=all).
    node_mask: boolean mask selecting the nodes to query to (None=all).
    radius: connectivity radius.
    max_edges: maximum total number of edges, including padding edges.

  Returns:
    Tuple `(n_edge, senders, receivers)`. `n_edge` has `len(n_node) + 1`
    entries (the last counts padding edges); `senders` and `receivers` have
    `max_edges` entries each.
  """
  callback_arg = (positions, n_node, radius, max_edges, query_mask, node_mask)
  out_shape = (jax.ShapeDtypeStruct((len(n_node) + 1,), jnp.int32),
               jax.ShapeDtypeStruct((max_edges, 2), jnp.int32))
  n_edge, indices = hcb.call(_cb_radius_query, callback_arg,
                             result_shape=out_shape)

  senders = indices[:, 1]
  receivers = indices[:, 0]
  return n_edge, senders, receivers


@compute_fixed_radius_connectivity_jax.defjvp
def _compute_fixed_radius_connectivity_jax_jvp(radius, max_edges, primals,
                                               tangents):
  """Custom zero-jvp function for compute_fixed_radius_connectivity_jax.

  Connectivity is piecewise constant in the positions, so every output
  tangent is zero.
  """
  del tangents
  primal_out = compute_fixed_radius_connectivity_jax(
      *primals, radius=radius, max_edges=max_edges)
  grad_out = tuple(jnp.zeros_like(x) for x in primal_out)
  return primal_out, grad_out


def compute_fixed_radius_connectivity_np(
    positions, radius, receiver_positions=None, remove_self_edges=False):
  """Computes connectivity between positions and receiver_positions.

  Args:
    positions: `[N, D]` array of candidate neighbour positions.
    radius: connectivity radius.
    receiver_positions: optional `[M, D]` array of query positions; defaults
      to `positions`.
    remove_self_edges: whether to drop `(i, i)` pairs. Only allowed when
      `receiver_positions` is None.

  Returns:
    int32 array of shape `[num_edges, 2]`. Column 0 indexes into
    `receiver_positions` (the query point) and column 1 indexes into
    `positions` (a neighbour within `radius` of it). An empty `positions` or
    `receiver_positions` yields a `[0, 2]` array.

  Raises:
    ValueError: if `remove_self_edges` is combined with `receiver_positions`.
  """
  # Self edges are only meaningful when both index spaces coincide.
  if remove_self_edges and receiver_positions is not None:
    raise ValueError(
        "remove_self_edges requires receiver_positions to be None.")

  if receiver_positions is None:
    receiver_positions = positions

  # Nothing to connect. This also avoids building a KDTree over zero points
  # and calling np.concatenate on an empty list of neighbour arrays.
  if len(positions) == 0 or len(receiver_positions) == 0:
    return np.zeros((0, 2), dtype=np.int32)

  # Use a KD-tree for efficient calculation of pairs within radius distance.
  kd_tree = neighbors.KDTree(positions)
  neighbors_per_query = kd_tree.query_radius(receiver_positions, r=radius)
  num_queries = len(receiver_positions)
  query_idx = np.repeat(np.arange(num_queries),
                        [len(a) for a in neighbors_per_query])
  neighbor_idx = np.concatenate(neighbors_per_query, axis=0)

  if remove_self_edges:
    keep = query_idx != neighbor_idx
    query_idx = query_idx[keep]
    neighbor_idx = neighbor_idx[keep]

  return np.stack([query_idx.astype(np.int32),
                   neighbor_idx.astype(np.int32)],
                  axis=-1)

# ------------------------------------------------------------------------------
# /src/graph_network.py:
# ------------------------------------------------------------------------------
# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""JAX implementation of Encode Process Decode."""

from typing import Optional
import haiku as hk
import jax
import jax.numpy as jnp
import jraph


class EncodeProcessDecode(hk.Module):
  """Encode-Process-Decode function approximator for learnable simulator.

  Encodes node/edge features into a latent graph, applies
  `num_message_passing_steps` unshared `jraph.InteractionNetwork` steps with
  residual connections (the whole stack repeated `num_processor_repetitions`
  times), and decodes the final latent graph with MLPs.
  """

  def __init__(
      self,
      *,
      latent_size: int,
      mlp_hidden_size: int,
      mlp_num_hidden_layers: int,
      num_message_passing_steps: int,
      num_processor_repetitions: int = 1,
      encode_nodes: bool = True,
      encode_edges: bool = True,
      node_output_size: Optional[int] = None,
      edge_output_size: Optional[int] = None,
      include_sent_messages_in_node_update: bool = False,
      use_layer_norm: bool = True,
      name: str = "EncodeProcessDecode"):
    """Inits the model.

    Args:
      latent_size: Size of the node and edge latent representations.
      mlp_hidden_size: Hidden layer size for all MLPs.
      mlp_num_hidden_layers: Number of hidden layers in all MLPs.
      num_message_passing_steps: Number of unshared message passing steps
        in the processor steps.
      num_processor_repetitions: Number of times that the same processor is
        applied sequentially.
      encode_nodes: If False, the node encoder will be omitted.
      encode_edges: If False, the edge encoder will be omitted.
      node_output_size: Output size of the decoded node representations. If
        None or 0, no node decoder is built.
      edge_output_size: Output size of the decoded edge representations. If
        None or 0, no edge decoder is built.
      include_sent_messages_in_node_update: Whether to include pooled sent
        messages from each node in the node update.
      use_layer_norm: Whether to apply layer norm after the encoder and
        processor MLPs (the decoder MLPs never use it).
      name: Name of the model.
    """

    super().__init__(name=name)

    self._latent_size = latent_size
    self._mlp_hidden_size = mlp_hidden_size
    self._mlp_num_hidden_layers = mlp_num_hidden_layers
    self._num_message_passing_steps = num_message_passing_steps
    self._num_processor_repetitions = num_processor_repetitions
    self._encode_nodes = encode_nodes
    self._encode_edges = encode_edges
    self._node_output_size = node_output_size
    self._edge_output_size = edge_output_size
    self._include_sent_messages_in_node_update = (
        include_sent_messages_in_node_update)
    self._use_layer_norm = use_layer_norm
    self._networks_builder()

  def __call__(self, input_graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Forward pass of the learnable dynamics model."""

    # Encode the input_graph.
    latent_graph_0 = self._encode(input_graph)

    # Do `m` message passing steps in the latent graphs.
    latent_graph_m = self._process(latent_graph_0)

    # Decode from the last latent graph.
    return self._decode(latent_graph_m)

  def _networks_builder(self):
    """Builds the encoder, processor and decoder networks."""

    def build_mlp(name, output_size=None):
      # ReLU MLP over the concatenation of all its positional arguments.
      if output_size is None:
        output_size = self._latent_size
      mlp = hk.nets.MLP(
          output_sizes=[self._mlp_hidden_size] * self._mlp_num_hidden_layers + [
              output_size], name=name + "_mlp", activation=jax.nn.relu)
      return jraph.concatenated_args(mlp)

    def build_mlp_with_maybe_layer_norm(name, output_size=None):
      network = build_mlp(name, output_size)
      if self._use_layer_norm:
        layer_norm = hk.LayerNorm(
            axis=-1, create_scale=True, create_offset=True,
            name=name + "_layer_norm")
        network = hk.Sequential([network, layer_norm])
      return jraph.concatenated_args(network)

    # The encoder graph network independently encodes edge and node features.
    encoder_kwargs = dict(
        embed_edge_fn=build_mlp_with_maybe_layer_norm("encoder_edges")
        if self._encode_edges else None,
        embed_node_fn=build_mlp_with_maybe_layer_norm("encoder_nodes")
        if self._encode_nodes else None,)
    self._encoder_network = jraph.GraphMapFeatures(**encoder_kwargs)

    # Create `num_message_passing_steps` graph networks with unshared parameters
    # that update the node and edge latent features.
    # Note that we can use `jraph.InteractionNetwork` because
    # it also outputs the messages as updated edge latent features.
    self._processor_networks = []
    for step_i in range(self._num_message_passing_steps):
      self._processor_networks.append(
          jraph.InteractionNetwork(
              update_edge_fn=build_mlp_with_maybe_layer_norm(
                  f"processor_edges_{step_i}"),
              update_node_fn=build_mlp_with_maybe_layer_norm(
                  f"processor_nodes_{step_i}"),
              include_sent_messages_in_node_update=(
                  self._include_sent_messages_in_node_update)))

    # The decoder MLP decodes edge/node latent features into the output sizes.
    decoder_kwargs = dict(
        embed_edge_fn=build_mlp("decoder_edges", self._edge_output_size)
        if self._edge_output_size else None,
        embed_node_fn=build_mlp("decoder_nodes", self._node_output_size)
        if self._node_output_size else None,
    )
    self._decoder_network = jraph.GraphMapFeatures(**decoder_kwargs)

  def _encode(
      self, input_graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Encodes the input graph features into a latent graph."""

    # Copy the globals to all of the nodes, if applicable.
    if input_graph.globals is not None:
      broadcasted_globals = jnp.repeat(
          input_graph.globals, input_graph.n_node, axis=0,
          total_repeat_length=input_graph.nodes.shape[0])
      input_graph = input_graph._replace(
          nodes=jnp.concatenate(
              [input_graph.nodes, broadcasted_globals], axis=-1),
          globals=None)

    # Encode the node and edge features.
    latent_graph_0 = self._encoder_network(input_graph)
    return latent_graph_0

  def _process(
      self, latent_graph_0: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Processes the latent graph with several steps of message passing."""

    # Do `num_message_passing_steps` with each of the `self._processor_networks`
    # with unshared weights, and repeat that `self._num_processor_repetitions`
    # times.
    latent_graph = latent_graph_0
    for unused_repetition_i in range(self._num_processor_repetitions):
      for processor_network in self._processor_networks:
        latent_graph = self._process_step(processor_network, latent_graph,
                                          latent_graph_0)

    return latent_graph

  def _process_step(
      self, processor_network_k,
      latent_graph_prev_k: jraph.GraphsTuple,
      latent_graph_0: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Single step of message passing with node/edge residual connections.

    `latent_graph_0` is not used by this step.
    """

    input_graph_k = latent_graph_prev_k

    # One step of message passing.
    latent_graph_k = processor_network_k(input_graph_k)

    # Add residuals.
    latent_graph_k = latent_graph_k._replace(
        nodes=latent_graph_k.nodes+latent_graph_prev_k.nodes,
        edges=latent_graph_k.edges+latent_graph_prev_k.edges)
    return latent_graph_k

  def _decode(self, latent_graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Decodes from the latent graph."""
    return self._decoder_network(latent_graph)

# ------------------------------------------------------------------------------
# /src/learned_simulator.py:
# ------------------------------------------------------------------------------
# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Graph Network Simulator implementation used in NeurIPS 2022 submission.

Inverse Design for Fluid-Structure Interactions using Graph Network Simulators

Kelsey R. Allen*, Tatiana Lopez-Guevara*, Kimberly Stachenfeld*,
Alvaro Sanchez-Gonzalez, Peter Battaglia, Jessica Hamrick, Tobias Pfaff
"""

from typing import Any, Dict

import haiku as hk
import jraph

from inverse_design.src import graph_network
from inverse_design.src import normalizers


class LearnedSimulator(hk.Module):
  """Graph Network Simulator."""

  def __init__(self,
               connectivity_radius,
               *,
               graph_network_kwargs: Dict[str, Any],
               flatten_features_fn=None,
               name="LearnedSimulator"):
    """Initialize the model.

    Args:
      connectivity_radius: Radius of connectivity within which to connect
        particles with edges.
      graph_network_kwargs: Keyword arguments to pass to the learned part of the
        graph network `graph_network.EncodeProcessDecode`.
      flatten_features_fn: Function that takes the input graph and dataset
        metadata, and returns a graph where node and edge features are a single
        array of rank 2, and without global features. The function will be
        wrapped in a haiku module, which allows the flattening fn to instantiate
        its own variable normalizers. Required, despite the `None` default.
      name: Name of the Haiku module.

    Raises:
      ValueError: if `flatten_features_fn` is None.
    """
    # Fail early with a clear message; otherwise the first call would fail
    # with an opaque "'NoneType' object is not callable".
    if flatten_features_fn is None:
      raise ValueError("`flatten_features_fn` must be provided.")
    super().__init__(name=name)
    self._connectivity_radius = connectivity_radius
    self._graph_network_kwargs = graph_network_kwargs
    self._graph_network = None

    # Wrap flatten function in a Haiku module, so any haiku modules created
    # by the function are reused in case of multiple calls.
    self._flatten_features_fn = hk.to_module(flatten_features_fn)(
        name="flatten_features_fn")

  def _maybe_build_modules(self, input_graph):
    """Lazily builds the network once the spatial dimension is known."""
    if self._graph_network is None:
      num_dimensions = input_graph.nodes["world_position"].shape[-1]
      self._graph_network = graph_network.EncodeProcessDecode(
          name="encode_process_decode",
          node_output_size=num_dimensions,
          **self._graph_network_kwargs)

      self._target_normalizer = normalizers.get_accumulated_normalizer(
          name="target_normalizer")

  def __call__(self, input_graph: jraph.GraphsTuple, padded_graph=True):
    """Predicts the next node positions.

    Args:
      input_graph: graph whose `nodes["world_position"]` holds a position
        history (indexed `[:, -1]` and `[:, -2]` during integration).
      padded_graph: whether `input_graph` contains padding elements.

    Returns:
      Tuple `(graph, {})` where `graph` is `input_graph` with nodes replaced by
      `{"p:world_position": next_position}`, empty edges/globals, no
      senders/receivers and `n_edge` set to zero.
    """
    self._maybe_build_modules(input_graph)

    flat_graphs_tuple = self._encoder_preprocessor(
        input_graph, padded_graph=padded_graph)
    normalized_prediction = self._graph_network(flat_graphs_tuple).nodes
    next_position = self._decoder_postprocessor(normalized_prediction,
                                                input_graph)
    return input_graph._replace(
        nodes={"p:world_position": next_position},
        edges={},
        globals={},
        senders=input_graph.senders[:0],
        receivers=input_graph.receivers[:0],
        n_edge=input_graph.n_edge * 0), {}

  def _encoder_preprocessor(self, input_graph, padded_graph):
    """Flattens the input graph into single node/edge feature arrays."""
    graph_with_flat_features = self._flatten_features_fn(
        input_graph,
        connectivity_radius=self._connectivity_radius,
        is_padded_graph=padded_graph)
    return graph_with_flat_features

  def _decoder_postprocessor(self, normalized_prediction, input_graph):
    """Un-normalizes the predicted acceleration and integrates it."""
    position_sequence = input_graph.nodes["world_position"]

    # The model produces the output in normalized space so we apply inverse
    # normalization.
    prediction = self._target_normalizer.inverse(normalized_prediction)

    new_position = euler_integrate_position(position_sequence, prediction)
    return new_position


def euler_integrate_position(position_sequence, finite_diff_estimate):
  """Integrates finite difference estimate to position (assuming dt=1).

  Args:
    position_sequence: per-node position history; the last two entries along
      axis 1 are used.
    finite_diff_estimate: predicted acceleration for the next step.

  Returns:
    The next position for each node.
  """
  # Uses an Euler integrator to go from acceleration to position,
  # assuming dt=1 corresponding to the size of the finite difference.
  previous_position = position_sequence[:, -1]
  previous_velocity = previous_position - position_sequence[:, -2]
  next_acceleration = finite_diff_estimate
  next_velocity = previous_velocity + next_acceleration
  next_position = previous_position + next_velocity
  return next_position


def euler_integrate_position_inverse(position_sequence, next_position):
  """Computes a finite difference estimate from current position and history.

  Inverse of `euler_integrate_position`: returns the acceleration that maps
  `position_sequence` to `next_position` with dt=1.
  """
  previous_position = position_sequence[:, -1]
  previous_velocity = previous_position - position_sequence[:, -2]
  next_velocity = next_position - previous_position
  acceleration = next_velocity - previous_velocity
  return acceleration

# ------------------------------------------------------------------------------
# /src/model_utils.py:
# ------------------------------------------------------------------------------
# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utility functions for the LearnedSimulator model."""

import jax
import jax.numpy as jnp
import jraph
import tree

from inverse_design.src import normalizers


def flatten_features(input_graph,
                     connectivity_radius,
                     is_padded_graph,
                     apply_normalization=False):
  """Returns a graph with a single array of node and edge features.

  Node features: the flattened velocity sequence (finite differences of
  `nodes["world_position"]` along axis 1) and a 9-way one-hot encoding of
  `nodes["material_type(9)"]`. Edge features: the relative world position of
  each edge (computed here if missing) and its norm.

  Args:
    input_graph: graph whose `nodes` dict contains "world_position" and
      "material_type(9)" and whose `edges` is a dict.
    connectivity_radius: radius used to scale edge features when
      `apply_normalization` is True.
    is_padded_graph: whether `input_graph` contains padding elements.
    apply_normalization: whether to normalize velocities with accumulated
      statistics and scale edge features by `connectivity_radius`.

  Returns:
    `input_graph` with `nodes`/`edges` replaced by the concatenated feature
    arrays and `globals` set to None.
  """

  # Normalize the elements of the graph.
  if apply_normalization:
    graph_elements_normalizer = normalizers.GraphElementsNormalizer(
        template_graph=input_graph,
        is_padded_graph=is_padded_graph)

  # Computing relative distances in the model.
  if "relative_world_position" not in input_graph.edges:
    input_graph = _add_relative_distances(
        input_graph)

  # Extract important features from the position_sequence.
  position_sequence = input_graph.nodes["world_position"]
  velocity_sequence = time_diff(position_sequence)  # Finite-difference.

  # Collect node features.
  node_features = []

  # Normalized velocity sequence, flattening spatial axis.
  flat_velocity_sequence = jnp.reshape(velocity_sequence,
                                       [velocity_sequence.shape[0], -1])

  if apply_normalization:
    flat_velocity_sequence = graph_elements_normalizer.normalize_node_array(
        "velocity_sequence", flat_velocity_sequence)

  node_features.append(flat_velocity_sequence)

  # Material types (one-hot, does not need normalization).
  node_features.append(jax.nn.one_hot(input_graph.nodes["material_type(9)"], 9))

  # Collect edge features.
  edge_features = []

  # Relative distances and norms.
  relative_world_position = input_graph.edges["relative_world_position"]
  relative_world_distance = safe_edge_norm(
      input_graph.edges["relative_world_position"],
      input_graph,
      is_padded_graph,
      keepdims=True)

  if apply_normalization:
    # Scale determined by connectivity radius.
    relative_world_position = relative_world_position / connectivity_radius
    relative_world_distance = relative_world_distance / connectivity_radius

  edge_features.append(relative_world_position)
  edge_features.append(relative_world_distance)

  # Concatenate the collected features.
  node_features = jnp.concatenate(node_features, axis=-1)
  edge_features = jnp.concatenate(edge_features, axis=-1)

  return input_graph._replace(
      nodes=node_features,
      edges=edge_features,
      globals=None,
  )


def time_diff(input_sequence):
  """Compute finite time difference along axis 1."""
  return input_sequence[:, 1:] - input_sequence[:, :-1]


def safe_edge_norm(array, graph, is_padded_graph, keepdims=False):
  """Compute vector norm, preventing NaNs for padding elements.

  For padded graphs, a tiny epsilon is added to the padding edges so the norm
  is never taken at exactly zero there. The input array is not modified.
  """
  if is_padded_graph:
    padding_mask = jraph.get_edge_padding_mask(graph)
    epsilon = 1e-8
    perturb = jnp.logical_not(padding_mask) * epsilon
    # Out-of-place add: `+=` would mutate a caller-owned NumPy array in place.
    array = array + jnp.expand_dims(perturb, range(1, len(array.shape)))
  return jnp.linalg.norm(array, axis=-1, keepdims=keepdims)


def _add_relative_distances(input_graph,
                            use_last_position_only=True):
  """Adds the `relative_world_position` edge feature.

  The feature is receiver position minus sender position, using only the last
  position in the history when `use_last_position_only` is True. Returns a
  copy; the input graph's dicts are not mutated.
  """

  # If these exist, there is probably something wrong.
  assert "relative_world_position" not in input_graph.edges
  assert "clipped_distance_to_walls" not in input_graph.nodes

  input_graph = tree.map_structure(lambda x: x, input_graph)  # Avoid mutating.
  particle_pos = input_graph.nodes["world_position"]

  if use_last_position_only:
    particle_pos = particle_pos[:, -1]

  input_graph.edges["relative_world_position"] = (
      particle_pos[input_graph.receivers] - particle_pos[input_graph.senders])

  return input_graph

# ------------------------------------------------------------------------------
# /src/normalizers.py:
# ------------------------------------------------------------------------------
# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""JAX module for normalization with accumulated statistics."""

import haiku as hk
import jax.numpy as jnp
import jraph


def get_accumulated_normalizer(name):
  """Returns an `AccumulatedNormalizer` with the given module name."""
  return AccumulatedNormalizer(name=name)


class AccumulatedNormalizer(hk.Module):
  """Feature normalizer based on accumulated statistics held in Haiku state.

  The running sum (`acc_sum`), sum of squares (`acc_sum_squared`) and count
  (`acc_count`) are read with `hk.get_state` as float32 values (zero-filled
  when not yet present) and turned into a mean and a standard deviation. This
  module only reads those statistics; nothing here updates them.
  NOTE(review): presumably they are restored from a trained model's state;
  confirm against the caller.

  The count is stored as a float32 for GPU compatibility, so it is only exact
  up to 16777216 (the largest float32 that keeps integer precision).
  """

  def __init__(
      self,
      *,
      std_epsilon: float = 1e-5,
      name: str = 'accumulated_normalizer',
  ):
    """Inits the module.

    Args:
      std_epsilon: minimum value of the standard deviation to use.
      name: Name of the module.
    """
    super().__init__(name=name)
    self._accumulator_shape = None
    self._std_epsilon = std_epsilon

  def __call__(self, batched_data):
    """Direct transformation of the normalizer."""
    self._set_accumulator_shape(batched_data)
    return (batched_data - self.mean) / self.std_with_epsilon

  def inverse(self, normalized_batch_data):
    """Inverse transformation of the normalizer."""
    self._set_accumulator_shape(normalized_batch_data)
    return normalized_batch_data * self.std_with_epsilon + self.mean

  def _set_accumulator_shape(self, batched_sample_data):
    # Statistics are per feature, i.e. over the last axis.
    self._accumulator_shape = batched_sample_data.shape[-1]

  def _verify_module_connected(self):
    if self._accumulator_shape is None:
      raise RuntimeError(
          'Trying to read the mean before connecting the module.')

  @property
  def _acc_sum(self):
    return hk.get_state(
        'acc_sum', self._accumulator_shape, dtype=jnp.float32, init=jnp.zeros)

  @property
  def _acc_count(self):
    return hk.get_state('acc_count', (), dtype=jnp.float32, init=jnp.zeros)

  @property
  def _acc_sum_squared(self):
    return hk.get_state(
        'acc_sum_squared',
        self._accumulator_shape,
        dtype=jnp.float32,
        init=jnp.zeros)

  @property
  def _safe_count(self):
    # To ensure count is at least one and avoid nan's.
    return jnp.maximum(self._acc_count, 1.)

  @property
  def mean(self):
    self._verify_module_connected()
    return self._acc_sum / self._safe_count

  @property
  def std(self):
    self._verify_module_connected()
    var = self._acc_sum_squared / self._safe_count - self.mean**2
    var = jnp.maximum(var, 0.)  # Prevent negatives due to numerical precision.
    return jnp.sqrt(var)

  @property
  def std_with_epsilon(self):
    # To use in case the std is too small.
    return jnp.maximum(self.std, self._std_epsilon)


class GraphElementsNormalizer(hk.Module):
  """Normalization of individual graph components of a GraphsTuple.

  Only node arrays are currently exposed (`normalize_node_array`). Each array
  gets its own `AccumulatedNormalizer`, keyed by a name that may be used only
  once per instance.
  """

  def __init__(self,
               template_graph: jraph.GraphsTuple,
               is_padded_graph: bool,
               name: str = 'graph_elements_normalizer'):
    """Inits the module.

    Args:
      template_graph: Input template graph to compute edge/node/global padding
        masks.
      is_padded_graph: Whether the graph has padding.
      name: Name of the Haiku module.
    """

    super().__init__(name=name)
    self._node_mask = None
    self._edge_mask = None
    self._graph_mask = None
    if is_padded_graph:
      self._node_mask = jraph.get_node_padding_mask(template_graph)
      self._edge_mask = jraph.get_edge_padding_mask(template_graph)
      self._graph_mask = jraph.get_graph_padding_mask(template_graph)

    self._names_used = []

  def _run_normalizer(self, name, array, mask):
    # `mask` is currently unused: the normalizer only reads statistics and
    # never accumulates them, so padding does not need to be excluded here.
    if name in self._names_used:
      raise ValueError(
          f'Attempt to reuse name {name}. Used names: {self._names_used}')
    self._names_used.append(name)

    normalizer = get_accumulated_normalizer(name)
    return normalizer(array)

  def normalize_node_array(self, name, array):
    """Normalizes a per-node array using the normalizer registered as `name`."""
    return self._run_normalizer(name, array, self._node_mask)

# ------------------------------------------------------------------------------
# /src/watercourse_env.py:
# ------------------------------------------------------------------------------
# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | # ============================================================================== 15 | 16 | """Watercourse 3D environment utils.""" 17 | import haiku as hk 18 | import jax 19 | import jax.numpy as jnp 20 | import jraph 21 | import numpy as np 22 | import tree 23 | 24 | from inverse_design.src import connectivity_utils 25 | 26 | 27 | NORMAL = 0 28 | OBSTACLE = 1 29 | INFLOW = 4 30 | 31 | # for eliminating stray particles from pipe 32 | OOB_AREA = 1.5 33 | 34 | 35 | def _update_edges(input_graph, obstacle_edges, radius): 36 | """Recomputes particle edges, adds obstacle edges.""" 37 | # get input graph nodes corresponding to fluid 38 | query_mask = ~input_graph.nodes["external_mask"] 39 | 40 | # get input graph ndoes that are either fluid or obstacle 41 | valid_mask = query_mask | input_graph.nodes["obstacle_mask"] 42 | max_edges = input_graph.senders.shape[0] 43 | num_obstacle_edges = obstacle_edges.shape[0] 44 | 45 | # compute the sender and receiver edges for fluid-fluid and fluid-obstacle 46 | # interactions. 
47 | n_edge, senders, receivers = connectivity_utils.compute_fixed_radius_connectivity_jax( 48 | input_graph.nodes["world_position"][:, -1], 49 | n_node=input_graph.n_node[:-1], max_edges=max_edges - num_obstacle_edges, 50 | radius=radius, query_mask=query_mask, node_mask=valid_mask) 51 | 52 | # update edges to include obstacle edges and new fluid-fluid edges 53 | return input_graph._replace( 54 | senders=jnp.concatenate([obstacle_edges[:, 0], senders], axis=0), 55 | receivers=jnp.concatenate([obstacle_edges[:, 1], receivers], axis=0), 56 | n_edge=n_edge.at[0].set(n_edge[0] + num_obstacle_edges)) 57 | 58 | 59 | def forward(input_graph, new_particles, network, haiku_model, obstacle_edges, 60 | radius): 61 | """Runs model and post-processing steps in jax, returns position sequence.""" 62 | @hk.transform_with_state 63 | def model(inputs): 64 | return haiku_model()(inputs) 65 | rnd_key = jax.random.PRNGKey(42) # use a fixed random key 66 | 67 | # only run for a single graph (plus one padding graph), update graph with 68 | # obstacle edges 69 | assert len(input_graph.n_node) == 2, "Not a single padded graph." 
70 | graph = tree.map_structure(lambda x: x, input_graph) 71 | graph = _update_edges(graph, obstacle_edges, radius) 72 | 73 | # build material type 74 | pattern = jnp.ones_like(graph.nodes["external_mask"], dtype=jnp.int32) 75 | inflow_mask = jnp.any(~graph.nodes["mask_stack"], axis=-1) 76 | graph.nodes["material_type(9)"] = jnp.where( 77 | graph.nodes["external_mask"], pattern * OBSTACLE, 78 | jnp.where(inflow_mask, pattern * INFLOW, 79 | pattern * NORMAL)) 80 | graph.nodes["type/particles"] = None 81 | 82 | # run model 83 | prev_pos = input_graph.nodes["world_position"] 84 | model_out = model.apply(network["params"], network["state"], rnd_key, graph) 85 | pred_pos = model_out[0][0].nodes["p:world_position"] 86 | total_nodes = jnp.sum(input_graph.n_node[:-1]) 87 | node_padding_mask = jnp.arange(prev_pos.shape[0]) < total_nodes 88 | 89 | # update history, reset external particles 90 | next_pos_seq = jnp.concatenate([prev_pos[:, 1:], pred_pos[:, None]], axis=1) 91 | mask = (~input_graph.nodes["external_mask"]) & node_padding_mask 92 | next_pos_seq = jnp.where(mask[:, None, None], next_pos_seq, prev_pos) 93 | 94 | # add new particles, remove old particles that go below the floor surface 95 | delete_particles = next_pos_seq[:, -1, 1] <= 0 96 | delete_particles &= graph.nodes["mask_stack"][:, -1] 97 | particle_mask = graph.nodes["mask_stack"][:, -1] & ~delete_particles 98 | particle_mask |= new_particles 99 | mask_stack = jnp.concatenate( 100 | [graph.nodes["mask_stack"][:, 1:], particle_mask[:, None]], axis=1) 101 | 102 | # create new node features and update graph 103 | new_node_features = { 104 | **input_graph.nodes, 105 | "world_position": next_pos_seq, 106 | "mask_stack": mask_stack, 107 | "external_mask": ~particle_mask, 108 | "deleted": graph.nodes["deleted"] | delete_particles, 109 | } 110 | return input_graph._replace(nodes=new_node_features) 111 | 112 | 113 | def build_initial_graph(input_graphs, max_edges): 114 | """Builds initial padded graphs tuple from 
typed graph.""" 115 | obstacle_edges = np.stack( 116 | [input_graphs[0].senders, input_graphs[0].receivers], axis=1) 117 | graph = tree.map_structure(lambda x: x.copy(), input_graphs[0]) 118 | 119 | # clear graph edges 120 | dummy_edge = np.zeros((0,), dtype=np.int32) 121 | graph = graph._replace( 122 | senders=dummy_edge, 123 | receivers=dummy_edge, 124 | n_edge=np.array([0], dtype=np.int32)) 125 | 126 | # build inflow stack 127 | inflow_stack = [] 128 | init_pos = graph.nodes["world_position"] 129 | for cur_graph in input_graphs: 130 | mask_stack = cur_graph.nodes["mask_stack"] 131 | cur_pos = cur_graph.nodes["world_position"] 132 | new_particles = mask_stack[:, -1] & (~mask_stack[:, -2]) 133 | init_pos[new_particles] = cur_pos[new_particles] 134 | new_particles = np.concatenate([new_particles, [False]]) 135 | inflow_stack.append(new_particles) 136 | inflow_stack = np.stack(inflow_stack[1:], axis=0) 137 | graph.nodes["world_position"] = init_pos 138 | graph.nodes["deleted"] = np.zeros(init_pos.shape[0], dtype=bool) 139 | 140 | # fix stray particles 141 | stray_particles = init_pos[:, -1, 1] > OOB_AREA 142 | graph.nodes["mask_stack"][stray_particles] = False 143 | graph.nodes["external_mask"][stray_particles] = True 144 | 145 | # pad to maximum node, edge values and add padding graph 146 | max_n_node = graph.n_node.sum() + 1 147 | graphs_tuple = jraph.pad_with_graphs(graph, n_node=max_n_node, 148 | n_edge=max_edges, n_graph=2) 149 | return obstacle_edges, inflow_stack, graphs_tuple 150 | 151 | 152 | def rollout(initial_graph, inflow_stack, network, haiku_model, obstacle_edges, 153 | radius): 154 | """Runs a jittable model rollout.""" 155 | @jax.checkpoint 156 | def _step(graph, inflow_mask): 157 | out_graph = forward(graph, inflow_mask, network, haiku_model, 158 | obstacle_edges, radius) 159 | out_data = dict( 160 | pos=out_graph.nodes["world_position"][:, -1], 161 | mask=out_graph.nodes["mask_stack"][:, -1]) 162 | return out_graph, out_data 163 | final_graph, 
trajectory = jax.lax.scan(_step, init=initial_graph, 164 | xs=inflow_stack) 165 | return final_graph, trajectory 166 | 167 | 168 | def make_plain_obstacles(num_side=25): 169 | """Create a mesh obstacle (landscape) with num_side squared control points.""" 170 | px, pz = np.meshgrid( 171 | np.linspace(-0.5, 0.5, num_side), np.linspace(-0.5, 0.5, num_side)) 172 | trans = np.array([0.5, 0.5, 0.5]) 173 | 174 | # generate height map 175 | py = np.zeros_like(px) 176 | pos = np.stack([px, py, pz], axis=-1).reshape((-1, 3)) 177 | pos += trans[None] 178 | return pos 179 | 180 | 181 | def max_x_loss_fn(graph): 182 | """Example loss function for maximizing x position of particles when they hit the ground.""" 183 | z_pos = graph.nodes["world_position"][:, -1, 2] 184 | z_var = jnp.std(z_pos, where=graph.nodes["deleted"]) 185 | x_pos = graph.nodes["world_position"][:, -1, 0] 186 | x_max = jnp.mean(-x_pos, where=graph.nodes["deleted"]) 187 | return x_max + z_var 188 | 189 | 190 | def smooth_loss_fn(obs_pos, num_side=25): 191 | """Smoothing loss function for minimizing sharp changes across obstacle.""" 192 | obs_grid = jnp.reshape(obs_pos, (num_side, num_side)) 193 | obs_dx = jnp.diff(obs_grid, axis=0) ** 2 194 | obs_dy = jnp.diff(obs_grid, axis=1) ** 2 195 | return 0.5 * (jnp.mean(obs_dx) + jnp.mean(obs_dy)) 196 | 197 | 198 | def design_fn(params, graph, height_scale=0.15): 199 | """Convert parameters in params into landscape heightfield to be represented in graph.""" 200 | graph = tree.map_structure(lambda x: x, graph) 201 | init_pos = jnp.array(graph.nodes["world_position"]) 202 | # use tanh transformation to limit height to be within [-1, 1] 203 | raw_obs_pos = jnp.tanh(params) * height_scale 204 | 205 | # tile graph to have the same time history as the fluid particles 206 | obs_pos = jnp.tile(raw_obs_pos[:, None], [1, init_pos.shape[1]]) 207 | 208 | # only controlling the height, so set other dimensions to 0 209 | obs_pos = jnp.stack( 210 | [jnp.zeros_like(obs_pos), obs_pos, 
211 | jnp.zeros_like(obs_pos)], axis=-1) 212 | 213 | # add controlled height to initial heightfield and update graph nodes 214 | pos = jnp.concatenate( 215 | [init_pos[:obs_pos.shape[0]] + obs_pos, init_pos[obs_pos.shape[0]:]], 216 | axis=0) 217 | graph.nodes["world_position"] = pos 218 | return graph, raw_obs_pos 219 | --------------------------------------------------------------------------------