├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── demo_design_optimization.ipynb
├── requirements.txt
└── src
├── connectivity_utils.py
├── graph_network.py
├── learned_simulator.py
├── model_utils.py
├── normalizers.py
└── watercourse_env.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Distribution / packaging
7 | .Python
8 | build/
9 | develop-eggs/
10 | dist/
11 | downloads/
12 | eggs/
13 | .eggs/
14 | lib/
15 | lib64/
16 | parts/
17 | sdist/
18 | var/
19 | wheels/
20 | share/python-wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST
25 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | ## Contributor License Agreement
4 |
5 | Contributions to this project must be accompanied by a Contributor License
6 | Agreement. You (or your employer) retain the copyright to your contribution,
7 | this simply gives us permission to use and redistribute your contributions as
8 | part of the project. Head over to <https://cla.developers.google.com/> to see
9 | your current agreements on file or to sign a new one.
10 |
11 | You generally only need to submit a CLA once, so if you've already submitted one
12 | (even if it was for a different project), you probably don't need to do it
13 | again.
14 |
15 | ## Code reviews
16 |
17 | All submissions, including submissions by project members, require review. We
18 | use GitHub pull requests for this purpose. Consult
19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
20 | information on using pull requests.
21 |
22 | ## Community Guidelines
23 |
24 | This project follows [Google's Open Source Community
25 | Guidelines](https://opensource.google/conduct/).
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Inverse Design for fluid-structure interactions using graph network simulators
2 |
3 | Code and parameters to accompany the NeurIPS 2022 paper
4 | **Inverse Design for Fluid-Structure Interactions using Graph Network
5 | Simulators** ([arXiv](https://arxiv.org/abs/2202.00728))
6 | _Kelsey R. Allen*, Tatiana Lopez-Guevara*, Kimberly Stachenfeld*,
7 | Alvaro Sanchez-Gonzalez, Peter Battaglia, Jessica Hamrick, Tobias Pfaff_
8 |
9 | The code here provides an implementation of the Encode-Process-Decode
10 | graph network architecture in jax, model weights for this architecture trained
11 | on the 3D WaterCourse environment, and an example of performing gradient-based
12 | optimization in order to optimize a landscape to reroute water.
13 |
14 | ## Usage
15 |
16 | ### in a google colab
17 | Open the [google colab](https://colab.research.google.com/github/deepmind/inverse_design/blob/master/demo_design_optimization.ipynb) and run all cells.
18 |
19 | ### with jupyter notebook / locally
20 | To install the necessary requirements (run these commands from the directory
21 | that you wish to clone `inverse_design` into):
22 |
23 | ```shell
24 | git clone https://github.com/deepmind/inverse_design.git
25 | python3 -m venv id_venv
26 | source id_venv/bin/activate
27 | pip install --upgrade pip
28 | pip install -r ./inverse_design/requirements.txt
29 | ```
30 |
31 | Additionally install jupyter notebook if not already installed with
32 | `pip install notebook`
33 |
34 | Finally, make a new directory within the `inverse_design` repository and move
35 | files there:
36 | ```shell
37 | cd inverse_design
38 | mkdir inverse_design
39 | mv src/ inverse_design/
40 | ```
41 |
42 | Download the dataset and model weights from google cloud:
43 | ```shell
44 | wget -O ./gns_params.pickle https://storage.googleapis.com/dm_inverse_design_watercourse/gns_params.pickle
45 | wget -O ./init_sequence.pickle https://storage.googleapis.com/dm_inverse_design_watercourse/init_sequence.pickle
46 | ```
47 |
48 | Now you should be ready to go! Open `demo_design_optimization.ipynb` inside
49 | a jupyter notebook and run *from third cell* onwards.
50 |
51 | ## Citing this work
52 |
53 | If you use this work, please cite the following paper
54 | ```
55 | @misc{inversedesign_2022,
56 | title = {Inverse Design for Fluid-Structure Interactions using Graph Network Simulators},
57 | author = {Kelsey R. Allen and
58 | Tatiana Lopez{-}Guevara and
59 | Kimberly L. Stachenfeld and
60 | Alvaro Sanchez{-}Gonzalez and
61 | Peter W. Battaglia and
62 | Jessica B. Hamrick and
63 | Tobias Pfaff},
64 | journal = {Neural Information Processing Systems},
65 | year = {2022},
66 | }
67 | ```
68 | ## License and disclaimer
69 |
70 | Copyright 2022 DeepMind Technologies Limited
71 |
72 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0);
73 | you may not use this file except in compliance with the Apache 2.0 license.
74 | You may obtain a copy of the Apache 2.0 license at:
75 | https://www.apache.org/licenses/LICENSE-2.0
76 |
77 | All other materials are licensed under the Creative Commons Attribution 4.0
78 | International License (CC-BY). You may obtain a copy of the CC-BY license at:
79 | https://creativecommons.org/licenses/by/4.0/legalcode
80 |
81 | Unless required by applicable law or agreed to in writing, all software and
82 | materials distributed here under the Apache 2.0 or CC-BY licenses are
83 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
84 | either express or implied. See the licenses for the specific language governing
85 | permissions and limitations under those licenses.
86 |
87 | This is not an official Google product.
88 |
--------------------------------------------------------------------------------
/demo_design_optimization.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "SyNZcVgQPsKs"
7 | },
8 | "source": [
9 | "Copyright 2022 DeepMind Technologies Limited\n",
10 | "\n",
11 | "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
12 | "you may not use this file except in compliance with the License.\n",
13 | "You may obtain a copy of the License at\n",
14 | "\n",
15 | " https://www.apache.org/licenses/LICENSE-2.0\n",
16 | "\n",
17 | "Unless required by applicable law or agreed to in writing, software\n",
18 | "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
19 | "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
20 | "See the License for the specific language governing permissions and\n",
21 | "limitations under the License.\n"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {
27 | "id": "I8K_e58H8S9d"
28 | },
29 | "source": [
30 | "# Demo design optimization for 3D WaterCourse environment\n"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {
37 | "id": "1jNaCf3sM1Xm"
38 | },
39 | "outputs": [],
40 | "source": [
41 | "#@title Installation (if not running locally)\n",
42 | "# Note, this should be skipped if running locally.\n",
43 | "!mkdir /content/inverse_design\n",
44 | "!mkdir /content/inverse_design/src\n",
45 | "!touch /content/inverse_design/__init__.py\n",
46 | "!touch /content/inverse_design/src/__init__.py\n",
47 | "\n",
48 | "!wget -O /content/inverse_design/src/connectivity_utils.py https://raw.githubusercontent.com/deepmind/inverse_design/master/src/connectivity_utils.py\n",
49 | "!wget -O /content/inverse_design/src/graph_network.py https://raw.githubusercontent.com/deepmind/inverse_design/master/src/graph_network.py\n",
50 | "!wget -O /content/inverse_design/src/learned_simulator.py https://raw.githubusercontent.com/deepmind/inverse_design/master/src/learned_simulator.py\n",
51 | "!wget -O /content/inverse_design/src/model_utils.py https://raw.githubusercontent.com/deepmind/inverse_design/master/src/model_utils.py\n",
52 | "!wget -O /content/inverse_design/src/normalizers.py https://raw.githubusercontent.com/deepmind/inverse_design/master/src/normalizers.py\n",
53 | "!wget -O /content/inverse_design/src/watercourse_env.py https://raw.githubusercontent.com/deepmind/inverse_design/master/src/watercourse_env.py\n",
54 | "\n",
55 | "!wget -O /content/requirements.txt https://raw.githubusercontent.com/deepmind/inverse_design/master/requirements.txt\n",
56 | "!pip install -r requirements.txt"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {
63 | "id": "Q08NCU1pds4j"
64 | },
65 | "outputs": [],
66 | "source": [
67 | "#@title Download Pickled Dataset \u0026 Params (if running in colab)\n",
68 | "# Note this can be skipped if following instructions for jupyter notebook\n",
69 | "from google.colab import auth\n",
70 | "auth.authenticate_user()\n",
71 | "\n",
72 | "!gsutil cp gs://dm_inverse_design_watercourse/init_sequence.pickle .\n",
73 | "!gsutil cp gs://dm_inverse_design_watercourse/gns_params.pickle .\n"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "metadata": {
80 | "id": "3gfBGlIrCkBf"
81 | },
82 | "outputs": [],
83 | "source": [
84 | "#@title Imports\n",
85 | "from inverse_design.src import learned_simulator\n",
86 | "from inverse_design.src import model_utils\n",
87 | "from inverse_design.src import watercourse_env"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "metadata": {
94 | "id": "SJVrIfyHnkOK"
95 | },
96 | "outputs": [],
97 | "source": [
98 | "#@title Open pickled parameters + dataset\n",
99 | "import pickle\n",
100 | "\n",
101 | "with open('init_sequence.pickle', \"rb\") as f:\n",
102 | " pickled_data = pickle.loads(f.read())\n",
103 | " gt_sequence = pickled_data['gt_sequence']\n",
104 | " meta = pickled_data['meta']\n",
105 | "\n",
106 | "with open('gns_params.pickle', \"rb\") as f:\n",
107 | " pickled_params = pickle.loads(f.read())\n",
108 | " network = pickled_params['network']\n",
109 | " plan = pickled_params['plan']"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {
116 | "id": "kdxb0hSLPRwH"
117 | },
118 | "outputs": [],
119 | "source": [
120 | "#@title make ID control/loss functions\n",
121 | "import jax\n",
122 | "import functools\n",
123 | "\n",
124 | "# maximum number of edges for single step of rollout to pad to\n",
125 | "MAX_EDGES = 2**16\n",
126 | "\n",
127 | "# define haiku model\n",
128 | "connectivity_radius = meta[\"connectivity_radius\"]\n",
129 | "flatten_fn = functools.partial(model_utils.flatten_features, **plan['flatten_kwargs'])\n",
130 | "haiku_model = functools.partial(learned_simulator.LearnedSimulator, connectivity_radius=connectivity_radius, flatten_features_fn=flatten_fn, **plan['model_kwargs'])\n",
131 | "\n",
132 | "# create initial landscape (obstacle) in the scene\n",
133 | "obstacle_pos = watercourse_env.make_plain_obstacles()\n",
134 | "for frame in gt_sequence:\n",
135 | " pos = frame.nodes['world_position'].copy()\n",
136 | " pos[:obstacle_pos.shape[0]] = obstacle_pos[:, None]\n",
137 | " frame.nodes['world_position'] = pos\n",
138 | "\n",
139 | "\n",
140 | "# get initial sequence of particles from dataset for initial graph\n",
141 | "obstacle_edges, inflow_stack, initial_graph = watercourse_env.build_initial_graph(gt_sequence[15:], max_edges=MAX_EDGES)\n",
142 | "\n",
143 | "# infer the landscape size from the dataset (25 x 25)\n",
144 | "# note that this is not required, it is also possible to create a smaller\n",
145 | "# or larger landscape (obstacle) as the design space\n",
146 | "num_side = int(jax.numpy.sqrt(initial_graph.nodes['obstacle_mask'].sum()))\n",
147 | "n_obs = num_side**2\n",
148 | "\n",
149 | "# rollout length definition (final state taken for reward computation)\n",
150 | "length = 50\n",
151 | "# radius within which to connect particles\n",
152 | "radius = 0.1\n",
153 | "# smoothing factor for loss\n",
154 | "smoothing_factor = 1e2\n",
155 | "\n",
156 | "@jax.jit\n",
157 | "def run(vars):\n",
158 | " # create landscape as graph from vars parameters\n",
159 | " graph, raw_obs = watercourse_env.design_fn(vars, initial_graph)\n",
160 | "\n",
161 | " # rollout\n",
162 | " final_graph, traj = watercourse_env.rollout(\n",
163 | " graph, inflow_stack[:length], network, haiku_model,\n",
164 | " obstacle_edges, radius=radius)\n",
165 | " \n",
166 | " # losses\n",
167 | " losses = {\n",
168 | " 'objective': watercourse_env.max_x_loss_fn(final_graph),\n",
169 | " 'smooth': smoothing_factor * watercourse_env.smooth_loss_fn(raw_obs, num_side=num_side),\n",
170 | " }\n",
171 | "\n",
172 | " # auxiliaries to keep track of for plotting\n",
173 | " aux = {\n",
174 | " 'design': vars,\n",
175 | " 'losses': losses,\n",
176 | " 'traj': traj\n",
177 | " }\n",
178 | " return sum(losses.values()), aux\n"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": null,
184 | "metadata": {
185 | "id": "TuuHYP3dOW45"
186 | },
187 | "outputs": [],
188 | "source": [
189 | "from IPython.display import clear_output\n",
190 | "import jax.numpy as jnp\n",
191 | "import matplotlib.pyplot as plt\n",
192 | "import optax\n",
193 | "\n",
194 | "# set learning rate and number of optimization steps\n",
195 | "LEARNING_RATE = 0.05\n",
196 | "num_opt_steps = 100\n",
197 | "\n",
198 | "# define optimizer as adam with learning rate\n",
199 | "optimizer = optax.adam(learning_rate=LEARNING_RATE)\n",
200 | "\n",
201 | "# initialize design parameters to be zeros (flat landscape)\n",
202 | "params = jnp.zeros(n_obs, dtype=jnp.float32)\n",
203 | "opt_state = optimizer.init(params)\n",
204 | "\n",
205 | "# initialize empty optimization trajectory (for tracking improvements to losses and design)\n",
206 | "opt_traj = []\n",
207 | "\n",
208 | "# optimization step with current design parameters\n",
209 | "@jax.jit\n",
210 | "def opt_step(params, opt_state):\n",
211 | " grads, aux = jax.grad(run, has_aux=True)(params)\n",
212 | " updates, opt_state = optimizer.update(grads, opt_state, params)\n",
213 | " params = optax.apply_updates(params, updates)\n",
214 | " return params, opt_state, aux\n",
215 | "\n",
216 | "# run optimization loop and track progress\n",
217 | "for i in range(num_opt_steps):\n",
218 | " params, opt_state, aux = opt_step(params, opt_state)\n",
219 | " opt_traj.append(aux)\n",
220 | " clear_output(wait=True)\n",
221 | " fig, ax = plt.subplots(1,1,figsize=(10,5))\n",
222 | " for key in aux['losses'].keys():\n",
223 | " ax.plot([t['losses'][key] for t in opt_traj])\n",
224 | " ax.plot([sum(t['losses'].values()) for t in opt_traj])\n",
225 | " ax.legend(list(aux['losses'].keys())+['total'])\n",
226 | " plt.show()"
227 | ]
228 | },
229 | {
230 | "cell_type": "code",
231 | "execution_count": null,
232 | "metadata": {
233 | "id": "88voGdYuAPus"
234 | },
235 | "outputs": [],
236 | "source": [
237 | "import numpy as np\n",
238 | "\n",
239 | "# plot design iterations (every 10 steps)\n",
240 | "n_sam = range(0, len(opt_traj), 10)\n",
241 | "fig, ax = plt.subplots(1,len(n_sam),figsize=(len(n_sam)*10, 10), squeeze=False)\n",
242 | "\n",
243 | "for fi, idx in enumerate(n_sam):\n",
244 | " design = opt_traj[idx]['design']\n",
245 | "\n",
246 | " # control function uses tanh as transformation, so mimic here to see heightfield\n",
247 | " fld = np.tanh(design.reshape((num_side, num_side)))\n",
248 | " ax[0, fi].imshow(fld, vmin=-1, vmax=1)\n",
249 | " ax[0, fi].set_axis_off()"
250 | ]
251 | },
252 | {
253 | "cell_type": "code",
254 | "execution_count": null,
255 | "metadata": {
256 | "id": "yPaWTUM0FL36"
257 | },
258 | "outputs": [],
259 | "source": [
260 | "from IPython.display import clear_output\n",
261 | "# plot video of how particles move for optimized design and initial design\n",
262 | "\n",
263 | "def _plt(ax, frame, i):\n",
264 | " pos = frame['pos'][i] \n",
265 | " p = pos[frame['mask'][i]]\n",
266 | " ax.scatter(p[:, 0], p[:, 2], p[:, 1], c='b',s=10)\n",
267 | " obs = pos[:num_side**2]\n",
268 | " ax.scatter(obs[:, 0], obs[:, 2], obs[:, 1], c='k',s=3)\n",
269 | " ax.scatter([1.5],[1.5],[0], c='g',s=20)\n",
270 | " ax.set_xlim([-0.6, 1.6])\n",
271 | " ax.set_ylim([-0.1, 1.6])\n",
272 | " ax.set_zlim([-0.1, 1.2])\n",
273 | "\n",
274 | "roll_fin0 = run(opt_traj[0]['design'])[1]['traj']\n",
275 | "roll_fin1 = run(opt_traj[-1]['design'])[1]['traj']\n",
276 | "\n",
277 | "for i in range(roll_fin0['pos'].shape[0]):\n",
278 | " clear_output(wait=True)\n",
279 | " fig = plt.figure(figsize=(20,10))\n",
280 | " ax1 = fig.add_subplot(1, 2, 1, projection='3d')\n",
281 | " ax1.set_title('Initial design, frame %d' % i)\n",
282 | " _plt(ax1, roll_fin0, i)\n",
283 | " ax2 = fig.add_subplot(1, 2, 2, projection='3d')\n",
284 | " ax2.set_title('Design at final step')\n",
285 | " _plt(ax2, roll_fin1, i)\n",
286 | " plt.show()"
287 | ]
288 | },
289 | {
290 | "cell_type": "code",
291 | "execution_count": null,
292 | "metadata": {
293 | "id": "csTLqGTUPuAn"
294 | },
295 | "outputs": [],
296 | "source": []
297 | }
298 | ],
299 | "metadata": {
300 | "colab": {
301 | "private_outputs": true,
302 | "provenance": [
303 | {
304 | "file_id": "1ZYL6nDmJCvzc70qi5rwIQ7BSj9a2sdoL",
305 | "timestamp": 1667473449566
306 | },
307 | {
308 | "file_id": "1rOmaBHyQAVa6NY3nfsM1yba8GC0a8duM",
309 | "timestamp": 1667402678487
310 | },
311 | {
312 | "file_id": "1bHb2szUEMLP2gKErlnxrigutZyhZ4reh",
313 | "timestamp": 1667386305836
314 | },
315 | {
316 | "file_id": "1jrPkT2OlRGUpImN2cVetBtQ9PgFelNgy",
317 | "timestamp": 1665137721073
318 | },
319 | {
320 | "file_id": "144LaPYCJpaCSUX6jMoampo0dTMICcXWV",
321 | "timestamp": 1640626024159
322 | },
323 | {
324 | "file_id": "1B31Dl0yzZlW22l-X7qSR5S4w1nOHPAg0",
325 | "timestamp": 1640104963946
326 | },
327 | {
328 | "file_id": "1MXdt8a59t0hLm7Q1Prem3kdHpnixMOeJ",
329 | "timestamp": 1639481266768
330 | },
331 | {
332 | "file_id": "14OG1REDMq3i9phUABzuSWg3nQhUD_ozK",
333 | "timestamp": 1639156624846
334 | }
335 | ]
336 | },
337 | "kernelspec": {
338 | "display_name": "Python 3",
339 | "name": "python3"
340 | },
341 | "language_info": {
342 | "name": "python"
343 | }
344 | },
345 | "nbformat": 4,
346 | "nbformat_minor": 0
347 | }
348 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.2.0
2 | dm-haiku==0.0.9
3 | ml_collections==0.1.1
4 | numpy>=1.21.0
5 | optax==0.1.3
6 | jraph>=0.0.5.dev0
7 | scikit-learn>=0.24.0
8 | matplotlib>=3.6.0
9 |
--------------------------------------------------------------------------------
/src/connectivity_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tools to compute the connectivity of the graph."""
17 | import functools
18 |
19 | import jax
20 | from jax.experimental import host_callback as hcb
21 | import jax.numpy as jnp
22 | import numpy as np
23 | from sklearn import neighbors
24 |
25 |
def _cb_radius_query(args):
  """Host callback function to compute connectivity.

  Runs on the host (outside of jit, via `jax.experimental.host_callback`) and
  builds fixed-radius edges for a batch of concatenated graphs, padding the
  combined edge set to a fixed total size so the result has a static shape.

  Args:
    args: Tuple of
      padded_pos: concatenated node positions for all graphs.
      n_node: per-graph node counts (iterated in order).
      radius: connectivity radius for the neighbor query.
      max_edges: fixed total number of edges to pad to; must be strictly
        greater than the number of real edges found.
      query_mask: boolean mask selecting the subset of nodes to query from.
      node_mask: boolean mask selecting the subset of nodes to query to.

  Returns:
    Tuple of:
      n_edge: int32 array of per-graph edge counts, with one extra trailing
        entry holding the padding size.
      edges: [max_edges, 2] int32 array of (receiver, sender) index pairs in
        the global (concatenated) node index space, padded with self-edges on
        the first dummy padding node.

  Raises:
    ValueError: if the real edges found reach or exceed `max_edges`.
  """
  padded_pos, n_node, radius, max_edges, query_mask, node_mask = args
  edges = []
  offset = 0  # start index of the current graph within the concatenation

  for num_nodes in n_node:
    # Slice out this graph's positions, then restrict the "to" (node) and
    # "from" (query) sides to their respective masked subsets.
    pos_nodes = padded_pos[offset:offset+num_nodes]
    pos_query = padded_pos[offset:offset+num_nodes]
    pos_nodes = pos_nodes[node_mask[offset:offset+num_nodes]]
    pos_query = pos_query[query_mask[offset:offset+num_nodes]]

    # indices: [num_edges, 2] array of receivers ([:, 0]) and senders ([:, 1])
    indices = compute_fixed_radius_connectivity_np(pos_nodes, radius, pos_query)

    # The query returned positions *within the masked subsets*; renumber the
    # receiver column back to indices within the full (unmasked) graph.
    mask = query_mask[offset:offset+num_nodes]
    renumber = np.arange(num_nodes, dtype=np.int32)[mask]
    indices[:, 0] = renumber[indices[:, 0]]

    # Same renumbering for the sender column, using the node mask.
    mask = node_mask[offset:offset+num_nodes]
    renumber = np.arange(num_nodes, dtype=np.int32)[mask]
    indices[:, 1] = renumber[indices[:, 1]]

    # remove self-edges
    mask = indices[:, 0] != indices[:, 1]
    indices = indices[mask]

    # create unique two way edges (only necessary in the masked case)
    # Canonicalize each pair as (min, max), dedupe, then re-add the reversed
    # direction so every connection appears exactly once in each direction.
    indices = np.stack([np.min(indices, axis=1),
                        np.max(indices, axis=1)],
                       axis=1)
    indices = np.unique(indices, axis=0)
    indices = np.concatenate([indices, indices[:, [1, 0]]], axis=0)

    # Shift this graph's edges into the global concatenated index space.
    edges.append(indices + offset)
    offset += num_nodes

  n_edge = [x.shape[0] for x in edges]
  total_edges = np.sum(n_edge)

  # padding
  # `>=` (not `>`) so there is always at least one padding edge.
  if total_edges >= max_edges:
    raise ValueError("%d edges found, max_edges: %d" % (total_edges, max_edges))

  # create a [n_p, 2] padding array, which connects the first dummy padding node
  # (with index `num_nodes`) to itself.
  padding_size = max_edges - total_edges
  padding = np.ones((padding_size, 2), dtype=np.int32) * offset
  edges = np.concatenate(edges + [padding], axis=0)
  n_edge = np.array(n_edge + [padding_size], dtype=np.int32)
  return n_edge, edges
76 |
77 |
@functools.partial(jax.custom_jvp, nondiff_argnums=(4, 5))
def compute_fixed_radius_connectivity_jax(positions, n_node, query_mask,
                                          node_mask, radius, max_edges):
  """Computes connectivity for batched graphs using a jax host callback.

  The radius search itself runs on the host in numpy (see `_cb_radius_query`);
  `radius` and `max_edges` are static non-differentiable arguments, and a
  custom zero JVP is registered below so the op is transparent to autodiff.

  Args:
    positions: concatenated vector (N, 2) of node positions for all graphs
    n_node: array of num_nodes for each graph
    query_mask: defines the subset of nodes to query from (None=all)
    node_mask: defines the subset of nodes to query to (None=all)
    radius: connectivity radius
    max_edges: maximum total number of edges

  Returns:
    array of num_edges, senders, receivers
  """
  callback_arg = (positions, n_node, radius, max_edges, query_mask, node_mask)
  # Result shapes must be static: one n_edge entry per graph plus one for the
  # padding graph, and a fixed-size [max_edges, 2] index array.
  out_shape = (jax.ShapeDtypeStruct((len(n_node) + 1,), jnp.int32),
               jax.ShapeDtypeStruct((max_edges, 2), jnp.int32))
  n_edge, indices = hcb.call(_cb_radius_query, callback_arg,
                             result_shape=out_shape)

  # Column 0 holds receivers and column 1 senders (see `_cb_radius_query`).
  senders = indices[:, 1]
  receivers = indices[:, 0]
  return n_edge, senders, receivers
103 |
104 |
@compute_fixed_radius_connectivity_jax.defjvp
def _compute_fixed_radius_connectivity_jax_jvp(radius, max_edges, primals,
                                               tangents):
  """Custom zero-jvp function for compute_fixed_radius_connectivity_jax.

  The connectivity outputs are integer indices from a discrete neighborhood
  search, so their tangents are defined to be zeros of matching shape/dtype.
  """
  del tangents
  primal_out = compute_fixed_radius_connectivity_jax(
      *primals, radius=radius, max_edges=max_edges)
  grad_out = tuple(jnp.zeros_like(x) for x in primal_out)
  return primal_out, grad_out
114 |
115 |
def compute_fixed_radius_connectivity_np(
    positions, radius, receiver_positions=None, remove_self_edges=False):
  """Computes connectivity between positions and receiver_positions.

  Args:
    positions: [N, D] array of node positions; the KDTree is built over these.
    radius: connectivity radius.
    receiver_positions: optional [M, D] array of query positions; defaults to
      `positions`.
    remove_self_edges: if True, drops i->i pairs (only valid when
      `receiver_positions` is None).

  Returns:
    [num_edges, 2] int32 array; column 0 indexes into `receiver_positions`
    (the queries), column 1 indexes into `positions`.
  """
  # Self-edge removal only makes sense when querying against the same set.
  assert not (remove_self_edges and receiver_positions is not None)

  query_positions = (
      positions if receiver_positions is None else receiver_positions)

  # A KDTree yields all points within `radius` of each query in one pass.
  tree = neighbors.KDTree(positions)
  neighbor_lists = tree.query_radius(query_positions, r=radius)

  # Expand the ragged result into flat (query, neighbor) index pairs.
  counts = [len(hits) for hits in neighbor_lists]
  query_idx = np.repeat(np.arange(len(query_positions)), counts)
  neighbor_idx = np.concatenate(neighbor_lists, axis=0)

  if remove_self_edges:
    keep = query_idx != neighbor_idx
    query_idx = query_idx[keep]
    neighbor_idx = neighbor_idx[keep]

  return np.stack(
      [query_idx.astype(np.int32), neighbor_idx.astype(np.int32)], axis=-1)
142 |
--------------------------------------------------------------------------------
/src/graph_network.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """JAX implementation of Encode Process Decode."""
17 |
18 | from typing import Optional
19 | import haiku as hk
20 | import jax
21 | import jax.numpy as jnp
22 | import jraph
23 |
24 |
class EncodeProcessDecode(hk.Module):
  """Encode-Process-Decode function approximator for learnable simulator."""

  def __init__(
      self,
      *,
      latent_size: int,
      mlp_hidden_size: int,
      mlp_num_hidden_layers: int,
      num_message_passing_steps: int,
      num_processor_repetitions: int = 1,
      encode_nodes: bool = True,
      encode_edges: bool = True,
      node_output_size: Optional[int] = None,
      edge_output_size: Optional[int] = None,
      include_sent_messages_in_node_update: bool = False,
      use_layer_norm: bool = True,
      name: str = "EncodeProcessDecode"):
    """Inits the model.

    Args:
      latent_size: Size of the node and edge latent representations.
      mlp_hidden_size: Hidden layer size for all MLPs.
      mlp_num_hidden_layers: Number of hidden layers in all MLPs.
      num_message_passing_steps: Number of unshared message passing steps
        in the processor steps.
      num_processor_repetitions: Number of times that the same processor is
        applied sequentially.
      encode_nodes: If False, the node encoder will be omitted.
      encode_edges: If False, the edge encoder will be omitted.
      node_output_size: Output size of the decoded node representations.
      edge_output_size: Output size of the decoded edge representations.
      include_sent_messages_in_node_update: Whether to include pooled sent
        messages from each node in the node update.
      use_layer_norm: Whether it uses layer norm or not.
      name: Name of the model.
    """

    super().__init__(name=name)

    self._latent_size = latent_size
    self._mlp_hidden_size = mlp_hidden_size
    self._mlp_num_hidden_layers = mlp_num_hidden_layers
    self._num_message_passing_steps = num_message_passing_steps
    self._num_processor_repetitions = num_processor_repetitions
    self._encode_nodes = encode_nodes
    self._encode_edges = encode_edges
    self._node_output_size = node_output_size
    self._edge_output_size = edge_output_size
    self._include_sent_messages_in_node_update = (
        include_sent_messages_in_node_update)
    self._use_layer_norm = use_layer_norm
    self._networks_builder()

  def __call__(self, input_graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Forward pass of the learnable dynamics model."""

    # Encode the input_graph.
    latent_graph_0 = self._encode(input_graph)

    # Do `m` message passing steps in the latent graphs.
    latent_graph_m = self._process(latent_graph_0)

    # Decode from the last latent graph.
    return self._decode(latent_graph_m)

  def _networks_builder(self):
    """Builds the encoder, processor, and decoder sub-networks."""

    def build_mlp(name, output_size=None):
      # Defaults to the latent size so encoder/processor MLPs share widths.
      if output_size is None:
        output_size = self._latent_size
      mlp = hk.nets.MLP(
          output_sizes=[self._mlp_hidden_size] * self._mlp_num_hidden_layers + [
              output_size], name=name + "_mlp", activation=jax.nn.relu)
      return jraph.concatenated_args(mlp)

    def build_mlp_with_maybe_layer_norm(name, output_size=None):
      network = build_mlp(name, output_size)
      if self._use_layer_norm:
        layer_norm = hk.LayerNorm(
            axis=-1, create_scale=True, create_offset=True,
            name=name + "_layer_norm")
        network = hk.Sequential([network, layer_norm])
      return jraph.concatenated_args(network)

    # The encoder graph network independently encodes edge and node features.
    encoder_kwargs = dict(
        embed_edge_fn=build_mlp_with_maybe_layer_norm("encoder_edges")
        if self._encode_edges else None,
        embed_node_fn=build_mlp_with_maybe_layer_norm("encoder_nodes")
        if self._encode_nodes else None,)
    self._encoder_network = jraph.GraphMapFeatures(**encoder_kwargs)

    # Create `num_message_passing_steps` graph networks with unshared parameters
    # that update the node and edge latent features.
    # Note that we can use `modules.InteractionNetwork` because
    # it also outputs the messages as updated edge latent features.
    self._processor_networks = []
    for step_i in range(self._num_message_passing_steps):
      self._processor_networks.append(
          jraph.InteractionNetwork(
              update_edge_fn=build_mlp_with_maybe_layer_norm(
                  f"processor_edges_{step_i}"),
              update_node_fn=build_mlp_with_maybe_layer_norm(
                  f"processor_nodes_{step_i}"),
              include_sent_messages_in_node_update=(
                  self._include_sent_messages_in_node_update)))

    # The decoder MLP decodes edge/node latent features into the output sizes.
    decoder_kwargs = dict(
        embed_edge_fn=build_mlp("decoder_edges", self._edge_output_size)
        if self._edge_output_size else None,
        embed_node_fn=build_mlp("decoder_nodes", self._node_output_size)
        if self._node_output_size else None,
    )
    self._decoder_network = jraph.GraphMapFeatures(**decoder_kwargs)

  def _encode(
      self, input_graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Encodes the input graph features into a latent graph."""

    # Copy the globals to all of the nodes, if applicable.
    if input_graph.globals is not None:
      broadcasted_globals = jnp.repeat(
          input_graph.globals, input_graph.n_node, axis=0,
          total_repeat_length=input_graph.nodes.shape[0])
      input_graph = input_graph._replace(
          nodes=jnp.concatenate(
              [input_graph.nodes, broadcasted_globals], axis=-1),
          globals=None)

    # Encode the node and edge features.
    latent_graph_0 = self._encoder_network(input_graph)
    return latent_graph_0

  def _process(
      self, latent_graph_0: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Processes the latent graph with several steps of message passing."""

    # Do `num_message_passing_steps` with each of the `self._processor_networks`
    # with unshared weights, and repeat that `self._num_processor_repetitions`
    # times.
    latent_graph = latent_graph_0
    for unused_repetition_i in range(self._num_processor_repetitions):
      for processor_network in self._processor_networks:
        latent_graph = self._process_step(processor_network, latent_graph,
                                          latent_graph_0)

    return latent_graph

  def _process_step(
      self, processor_network_k,
      latent_graph_prev_k: jraph.GraphsTuple,
      latent_graph_0: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Single step of message passing with node/edge residual connections."""
    # NOTE(review): `latent_graph_0` is accepted but not used in this step as
    # written — confirm whether an input-residual variant was intended.

    input_graph_k = latent_graph_prev_k

    # One step of message passing.
    latent_graph_k = processor_network_k(input_graph_k)

    # Add residuals.
    latent_graph_k = latent_graph_k._replace(
        nodes=latent_graph_k.nodes+latent_graph_prev_k.nodes,
        edges=latent_graph_k.edges+latent_graph_prev_k.edges)
    return latent_graph_k

  def _decode(self, latent_graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
    """Decodes from the latent graph."""
    return self._decoder_network(latent_graph)
195 |
--------------------------------------------------------------------------------
/src/learned_simulator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Graph Network Simulator implementation used in NeurIPS 2022 submission.
17 |
18 | Inverse Design for Fluid-Structure Interactions using Graph Network Simulators
19 |
20 | Kelsey R. Allen*, Tatiana Lopez-Guevera*, Kimberly Stachenfeld*,
21 | Alvaro Sanchez-Gonzalez, Peter Battaglia, Jessica Hamrick, Tobias Pfaff
22 | """
23 |
24 | from typing import Any, Dict
25 |
26 | import haiku as hk
27 | import jraph
28 |
29 | from inverse_design.src import graph_network
30 | from inverse_design.src import normalizers
31 |
32 |
class LearnedSimulator(hk.Module):
  """Graph Network Simulator."""

  def __init__(self,
               connectivity_radius,
               *,
               graph_network_kwargs: Dict[str, Any],
               flatten_features_fn=None,
               name="LearnedSimulator"):
    """Initialize the model.

    Args:
      connectivity_radius: Radius of connectivity within which to connect
        particles with edges.
      graph_network_kwargs: Keyword arguments to pass to the learned part of the
        graph network `model.EncodeProcessDecode`.
      flatten_features_fn: Function that takes the input graph and dataset
        metadata, and returns a graph where node and edge features are a single
        array of rank 2, and without global features. The function will be
        wrapped in a haiku module, which allows the flattening fn to instantiate
        its own variable normalizers.
      name: Name of the Haiku module.
    """
    super().__init__(name=name)
    self._connectivity_radius = connectivity_radius
    self._graph_network_kwargs = graph_network_kwargs
    # Built lazily on the first call, once the input dimensionality is known.
    self._graph_network = None

    # Wrap flatten function in a Haiku module, so any haiku modules created
    # by the function are reused in case of multiple calls.
    self._flatten_features_fn = hk.to_module(flatten_features_fn)(
        name="flatten_features_fn")

  def _maybe_build_modules(self, input_graph):
    """Builds submodules on first use, sized from the input graph."""
    if self._graph_network is None:
      # Output size equals the spatial dimensionality of the positions.
      num_dimensions = input_graph.nodes["world_position"].shape[-1]
      self._graph_network = graph_network.EncodeProcessDecode(
          name="encode_process_decode",
          node_output_size=num_dimensions,
          **self._graph_network_kwargs)

      self._target_normalizer = normalizers.get_accumulated_normalizer(
          name="target_normalizer")

  def __call__(self, input_graph: jraph.GraphsTuple, padded_graph=True):
    """Runs one model step.

    Args:
      input_graph: Graph whose nodes carry a "world_position" history.
      padded_graph: Whether `input_graph` includes a padding graph.

    Returns:
      Tuple of (graph with the predicted next positions under the
      "p:world_position" node key and all edges/globals cleared, empty dict).
    """
    self._maybe_build_modules(input_graph)

    flat_graphs_tuple = self._encoder_preprocessor(
        input_graph, padded_graph=padded_graph)
    normalized_prediction = self._graph_network(flat_graphs_tuple).nodes
    next_position = self._decoder_postprocessor(normalized_prediction,
                                                input_graph)
    # Edges are dropped from the output graph (senders/receivers emptied and
    # n_edge zeroed); connectivity is recomputed by the caller each step.
    return input_graph._replace(
        nodes={"p:world_position": next_position},
        edges={},
        globals={},
        senders=input_graph.senders[:0],
        receivers=input_graph.receivers[:0],
        n_edge=input_graph.n_edge * 0), {}

  def _encoder_preprocessor(self, input_graph, padded_graph):
    # Flattens the input graph into rank-2 node/edge feature arrays.
    graph_with_flat_features = self._flatten_features_fn(
        input_graph,
        connectivity_radius=self._connectivity_radius,
        is_padded_graph=padded_graph)
    return graph_with_flat_features

  def _decoder_postprocessor(self, normalized_prediction, input_graph):
    # Un-normalize and integrate
    position_sequence = input_graph.nodes["world_position"]

    # The model produces the output in normalized space so we apply inverse
    # normalization.
    prediction = self._target_normalizer.inverse(normalized_prediction)

    new_position = euler_integrate_position(position_sequence, prediction)
    return new_position
111 |
112 |
def euler_integrate_position(position_sequence, finite_diff_estimate):
  """Integrates a finite-difference (acceleration) estimate to a position.

  Uses a semi-implicit Euler step with dt=1 (matching the finite-difference
  spacing):  v_t = x_t - x_{t-1};  v_{t+1} = v_t + a;  x_{t+1} = x_t + v_{t+1}.

  Args:
    position_sequence: [num_particles, history, dim] position history; the
      last two steps along axis 1 define the current velocity.
    finite_diff_estimate: [num_particles, dim] acceleration estimate.

  Returns:
    [num_particles, dim] next positions.
  """
  last_position = position_sequence[:, -1]
  velocity = last_position - position_sequence[:, -2]
  return last_position + velocity + finite_diff_estimate
123 |
124 |
def euler_integrate_position_inverse(position_sequence, next_position):
  """Recovers the finite-difference estimate from a target next position.

  Exact inverse of `euler_integrate_position` under dt=1: the acceleration is
  the change in velocity between the last history step and the target step.

  Args:
    position_sequence: [num_particles, history, dim] position history.
    next_position: [num_particles, dim] target next positions.

  Returns:
    [num_particles, dim] acceleration (finite-difference) estimate.
  """
  last_position = position_sequence[:, -1]
  velocity_before = last_position - position_sequence[:, -2]
  velocity_after = next_position - last_position
  return velocity_after - velocity_before
132 |
--------------------------------------------------------------------------------
/src/model_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utility functions for the LearnedSimulator model."""
17 |
18 | import jax
19 | import jax.numpy as jnp
20 | import jraph
21 | import tree
22 |
23 | from inverse_design.src import normalizers
24 |
25 |
def flatten_features(input_graph,
                     connectivity_radius,
                     is_padded_graph,
                     apply_normalization=False):
  """Returns a graph with a single array of node and edge features.

  Node features are the (optionally normalized) flattened velocity sequence
  plus a 9-way one-hot of the material type; edge features are the relative
  world positions and their norms (optionally scaled by the radius).

  Args:
    input_graph: `jraph.GraphsTuple` with dict features; nodes must contain
      "world_position" and "material_type(9)".
    connectivity_radius: scale for relative positions/distances when
      `apply_normalization` is True.
    is_padded_graph: whether the graph includes a padding graph.
    apply_normalization: if True, normalizes velocities with accumulated
      statistics and scales relative distances by the radius.

  Returns:
    The graph with flat rank-2 node/edge feature arrays and no globals.
  """

  # Normalize the elements of the graph.
  if apply_normalization:
    graph_elements_normalizer = normalizers.GraphElementsNormalizer(
        template_graph=input_graph,
        is_padded_graph=is_padded_graph)

  # Computing relative distances in the model.
  if "relative_world_position" not in input_graph.edges:
    input_graph = _add_relative_distances(
        input_graph)

  # Extract important features from the position_sequence.
  position_sequence = input_graph.nodes["world_position"]
  velocity_sequence = time_diff(position_sequence)  # Finite-difference.

  # Collect node features.
  node_features = []

  # Normalized velocity sequence, flattening spatial axis.
  flat_velocity_sequence = jnp.reshape(velocity_sequence,
                                       [velocity_sequence.shape[0], -1])

  if apply_normalization:
    flat_velocity_sequence = graph_elements_normalizer.normalize_node_array(
        "velocity_sequence", flat_velocity_sequence)

  node_features.append(flat_velocity_sequence)

  # Material types (one-hot, does not need normalization).
  node_features.append(jax.nn.one_hot(input_graph.nodes["material_type(9)"], 9))

  # Collect edge features.
  edge_features = []

  # Relative distances and norms.
  relative_world_position = input_graph.edges["relative_world_position"]
  relative_world_distance = safe_edge_norm(
      input_graph.edges["relative_world_position"],
      input_graph,
      is_padded_graph,
      keepdims=True)

  if apply_normalization:
    # Scale determined by connectivity radius.
    relative_world_position = relative_world_position / connectivity_radius
    relative_world_distance = relative_world_distance / connectivity_radius

  edge_features.append(relative_world_position)
  edge_features.append(relative_world_distance)

  # Concatenate into single flat node/edge feature arrays.
  node_features = jnp.concatenate(node_features, axis=-1)
  edge_features = jnp.concatenate(edge_features, axis=-1)

  return input_graph._replace(
      nodes=node_features,
      edges=edge_features,
      globals=None,
  )
91 |
92 |
def time_diff(input_sequence):
  """Computes the finite time difference along the sequence axis (axis 1)."""
  current = input_sequence[:, 1:]
  previous = input_sequence[:, :-1]
  return current - previous
96 |
97 |
def safe_edge_norm(array, graph, is_padded_graph, keepdims=False):
  """Computes the vector norm over the last axis, nan-safe for padding edges.

  Padding edges carry all-zero vectors and the gradient of the norm at zero
  is nan, so a tiny epsilon is added to padding elements only; valid edges
  are left untouched.
  """
  if is_padded_graph:
    valid_mask = jraph.get_edge_padding_mask(graph)
    # Non-zero only on padding elements.
    perturbation = jnp.logical_not(valid_mask) * 1e-8
    broadcast_axes = range(1, len(array.shape))
    array = array + jnp.expand_dims(perturbation, broadcast_axes)
  return jnp.linalg.norm(array, axis=-1, keepdims=keepdims)
106 |
107 |
def _add_relative_distances(input_graph,
                            use_last_position_only=True):
  """Computes relative distances between particles and with walls."""

  # The presence of these keys suggests the graph was already processed.
  assert "relative_world_position" not in input_graph.edges
  assert "clipped_distance_to_walls" not in input_graph.nodes

  # Shallow-copy the structure so the caller's graph is not mutated.
  graph = tree.map_structure(lambda x: x, input_graph)

  positions = graph.nodes["world_position"]
  if use_last_position_only:
    # Keep only the most recent step of the position history.
    positions = positions[:, -1]

  graph.edges["relative_world_position"] = (
      positions[graph.receivers] - positions[graph.senders])

  return graph
126 |
--------------------------------------------------------------------------------
/src/normalizers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """JAX module for normalization with accumulated statistics."""
17 |
18 | import haiku as hk
19 | import jax.numpy as jnp
20 | import jraph
21 |
22 |
def get_accumulated_normalizer(name):
  """Returns an `AccumulatedNormalizer` with the given Haiku module name."""
  return AccumulatedNormalizer(name=name)
25 |
26 |
class AccumulatedNormalizer(hk.Module):
  """Feature normalizer driven by accumulated statistics.

  Mean and standard deviation are derived on the fly from float32 sum,
  sum-of-squares, and count values stored in Haiku state ('acc_sum',
  'acc_sum_squared', 'acc_count'). As shown here this module only *reads*
  those statistics; the code that accumulates them is defined elsewhere.
  """

  def __init__(
      self,
      *,
      std_epsilon: float = 1e-5,
      name: str = 'accumulated_normalizer',
  ):
    """Inits the module.

    Args:
      std_epsilon: minimum value of the standard deviation to use.
      name: Name of the module.
    """
    super(AccumulatedNormalizer, self).__init__(name=name)
    # Trailing feature dimension; set lazily on first connection.
    self._accumulator_shape = None
    self._std_epsilon = std_epsilon

  def __call__(self, batched_data):
    """Direct transformation of the normalizer."""
    self._set_accumulator_shape(batched_data)
    return (batched_data - self.mean) / self.std_with_epsilon

  def inverse(self, normalized_batch_data):
    """Inverse transformation of the normalizer."""
    self._set_accumulator_shape(normalized_batch_data)
    return normalized_batch_data * self.std_with_epsilon + self.mean

  def _set_accumulator_shape(self, batched_sample_data):
    # State arrays are sized to the trailing (feature) dimension only.
    self._accumulator_shape = batched_sample_data.shape[-1]

  def _verify_module_connected(self):
    if self._accumulator_shape is None:
      raise RuntimeError(
          'Trying to read the mean before connecting the module.')

  @property
  def _acc_sum(self):
    # Per-feature running sum, stored as Haiku state.
    return hk.get_state(
        'acc_sum', self._accumulator_shape, dtype=jnp.float32, init=jnp.zeros)

  @property
  def _acc_count(self):
    # Scalar element count; float32 for accelerator compatibility.
    return hk.get_state('acc_count', (), dtype=jnp.float32, init=jnp.zeros)

  @property
  def _acc_sum_squared(self):
    # Per-feature running sum of squares, stored as Haiku state.
    return hk.get_state(
        'acc_sum_squared',
        self._accumulator_shape,
        dtype=jnp.float32,
        init=jnp.zeros)

  @property
  def _safe_count(self):
    # To ensure count is at least one and avoid nan's.
    return jnp.maximum(self._acc_count, 1.)

  @property
  def mean(self):
    self._verify_module_connected()
    return self._acc_sum / self._safe_count

  @property
  def std(self):
    self._verify_module_connected()
    var = self._acc_sum_squared / self._safe_count - self.mean**2
    var = jnp.maximum(var, 0.)  # Prevent negatives due to numerical precision.
    return jnp.sqrt(var)

  @property
  def std_with_epsilon(self):
    # To use in case the std is too small.
    return jnp.maximum(self.std, self._std_epsilon)
114 |
115 |
class GraphElementsNormalizer(hk.Module):
  """Online normalization of individual graph components of a GraphsTuple.

  Can be used to normalize individual node, edge, and global arrays.
  """

  def __init__(self,
               template_graph: jraph.GraphsTuple,
               is_padded_graph: bool,
               name: str = 'graph_elements_normalizer'):
    """Inits the module.

    Args:
      template_graph: Input template graph to compute edge/node/global padding
        masks.
      is_padded_graph: Whether the graph has padding.
      name: Name of the Haiku module.
    """

    super().__init__(name=name)
    # Padding masks; remain None when the graph is unpadded.
    self._node_mask = None
    self._edge_mask = None
    self._graph_mask = None
    if is_padded_graph:
      self._node_mask = jraph.get_node_padding_mask(template_graph)
      self._edge_mask = jraph.get_edge_padding_mask(template_graph)
      self._graph_mask = jraph.get_graph_padding_mask(template_graph)

    # Tracks names so each normalizer's state is used for exactly one array.
    self._names_used = []

  def _run_normalizer(self, name, array, mask):
    """Normalizes `array` with a uniquely-named accumulated normalizer.

    NOTE(review): `mask` is currently unused — padding elements are normalized
    like any others; confirm whether masked statistics were intended.
    """
    if name in self._names_used:
      raise ValueError(
          f'Attempt to reuse name {name}. Used names: {self._names_used}')
    self._names_used.append(name)

    normalizer = get_accumulated_normalizer(name)
    return normalizer(array)

  def normalize_node_array(self, name, array):
    """Normalizes a per-node array under the unique key `name`."""
    return self._run_normalizer(name, array, self._node_mask)
--------------------------------------------------------------------------------
/src/watercourse_env.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Watercourse 3D environment utils."""
17 | import haiku as hk
18 | import jax
19 | import jax.numpy as jnp
20 | import jraph
21 | import numpy as np
22 | import tree
23 |
24 | from inverse_design.src import connectivity_utils
25 |
26 |
27 | NORMAL = 0
28 | OBSTACLE = 1
29 | INFLOW = 4
30 |
31 | # for eliminating stray particles from pipe
32 | OOB_AREA = 1.5
33 |
34 |
def _update_edges(input_graph, obstacle_edges, radius):
  """Recomputes particle edges, adds obstacle edges.

  Args:
    input_graph: padded `jraph.GraphsTuple`; the size of its current senders
      array defines the total edge budget.
    obstacle_edges: [num_obstacle_edges, 2] precomputed sender/receiver pairs.
    radius: fluid connectivity radius.

  Returns:
    The graph with freshly computed fluid edges plus the obstacle edges
    prepended; obstacle edges are counted towards the first (data) graph.
  """
  # get input graph nodes corresponding to fluid
  query_mask = ~input_graph.nodes["external_mask"]

  # get input graph nodes that are either fluid or obstacle
  valid_mask = query_mask | input_graph.nodes["obstacle_mask"]
  max_edges = input_graph.senders.shape[0]
  num_obstacle_edges = obstacle_edges.shape[0]

  # compute the sender and receiver edges for fluid-fluid and fluid-obstacle
  # interactions. Uses the latest position in the history (index -1) and
  # reserves budget for the obstacle edges added below.
  n_edge, senders, receivers = connectivity_utils.compute_fixed_radius_connectivity_jax(
      input_graph.nodes["world_position"][:, -1],
      n_node=input_graph.n_node[:-1], max_edges=max_edges - num_obstacle_edges,
      radius=radius, query_mask=query_mask, node_mask=valid_mask)

  # update edges to include obstacle edges and new fluid-fluid edges
  return input_graph._replace(
      senders=jnp.concatenate([obstacle_edges[:, 0], senders], axis=0),
      receivers=jnp.concatenate([obstacle_edges[:, 1], receivers], axis=0),
      n_edge=n_edge.at[0].set(n_edge[0] + num_obstacle_edges))
57 |
58 |
def forward(input_graph, new_particles, network, haiku_model, obstacle_edges,
            radius):
  """Runs model and post-processing steps in jax, returns position sequence.

  Args:
    input_graph: single padded `jraph.GraphsTuple` (one data graph plus one
      padding graph).
    new_particles: boolean mask of particles to activate at this step.
    network: dict with "params" and "state" for the transformed haiku model.
    haiku_model: callable returning the model to apply inside hk.transform.
    obstacle_edges: [num_obstacle_edges, 2] sender/receiver index array.
    radius: fluid connectivity radius.

  Returns:
    The input graph with updated node features: shifted position history,
    updated mask stack, and refreshed external/deleted masks.
  """
  @hk.transform_with_state
  def model(inputs):
    return haiku_model()(inputs)
  rnd_key = jax.random.PRNGKey(42)  # use a fixed random key

  # only run for a single graph (plus one padding graph), update graph with
  # obstacle edges
  assert len(input_graph.n_node) == 2, "Not a single padded graph."
  graph = tree.map_structure(lambda x: x, input_graph)
  graph = _update_edges(graph, obstacle_edges, radius)

  # build material type: obstacles, then inflow (any masked-out history step),
  # then normal fluid particles
  pattern = jnp.ones_like(graph.nodes["external_mask"], dtype=jnp.int32)
  inflow_mask = jnp.any(~graph.nodes["mask_stack"], axis=-1)
  graph.nodes["material_type(9)"] = jnp.where(
      graph.nodes["external_mask"], pattern * OBSTACLE,
      jnp.where(inflow_mask, pattern * INFLOW,
                pattern * NORMAL))
  graph.nodes["type/particles"] = None

  # run model
  prev_pos = input_graph.nodes["world_position"]
  model_out = model.apply(network["params"], network["state"], rnd_key, graph)
  # model output is ((graph, aux), new_state); predicted positions live under
  # the "p:world_position" node key
  pred_pos = model_out[0][0].nodes["p:world_position"]
  total_nodes = jnp.sum(input_graph.n_node[:-1])
  node_padding_mask = jnp.arange(prev_pos.shape[0]) < total_nodes

  # update history, reset external particles
  next_pos_seq = jnp.concatenate([prev_pos[:, 1:], pred_pos[:, None]], axis=1)
  mask = (~input_graph.nodes["external_mask"]) & node_padding_mask
  next_pos_seq = jnp.where(mask[:, None, None], next_pos_seq, prev_pos)

  # add new particles, remove old particles that go below the floor surface
  # NOTE(review): assumes position component 1 is the vertical axis — confirm.
  delete_particles = next_pos_seq[:, -1, 1] <= 0
  delete_particles &= graph.nodes["mask_stack"][:, -1]
  particle_mask = graph.nodes["mask_stack"][:, -1] & ~delete_particles
  particle_mask |= new_particles
  mask_stack = jnp.concatenate(
      [graph.nodes["mask_stack"][:, 1:], particle_mask[:, None]], axis=1)

  # create new node features and update graph
  new_node_features = {
      **input_graph.nodes,
      "world_position": next_pos_seq,
      "mask_stack": mask_stack,
      "external_mask": ~particle_mask,
      "deleted": graph.nodes["deleted"] | delete_particles,
  }
  return input_graph._replace(nodes=new_node_features)
111 |
112 |
def build_initial_graph(input_graphs, max_edges):
  """Builds initial padded graphs tuple from typed graph.

  Args:
    input_graphs: sequence of graphs tuples over time. Element 0 supplies the
      initial state and the precomputed obstacle edges; later elements are
      scanned only to find when inflow particles first become active.
    max_edges: edge capacity to pad the returned graphs tuple to.

  Returns:
    Tuple of (obstacle_edges, inflow_stack, graphs_tuple) where
    obstacle_edges is an [E, 2] sender/receiver index array, inflow_stack is
    a boolean [len(input_graphs) - 1, N + 1] array of per-step activation
    masks (trailing False entry for the padding node), and graphs_tuple is
    the initial graph padded to max_n_node/max_edges with a padding graph.
  """
  obstacle_edges = np.stack(
      [input_graphs[0].senders, input_graphs[0].receivers], axis=1)
  # Copy every leaf array so the in-place edits below don't touch the input.
  graph = tree.map_structure(lambda x: x.copy(), input_graphs[0])

  # clear graph edges
  dummy_edge = np.zeros((0,), dtype=np.int32)
  graph = graph._replace(
      senders=dummy_edge,
      receivers=dummy_edge,
      n_edge=np.array([0], dtype=np.int32))

  # build inflow stack
  inflow_stack = []
  init_pos = graph.nodes["world_position"]
  for cur_graph in input_graphs:
    mask_stack = cur_graph.nodes["mask_stack"]
    cur_pos = cur_graph.nodes["world_position"]
    # Particles that flipped from inactive to active at this frame.
    new_particles = mask_stack[:, -1] & (~mask_stack[:, -2])
    # Seed their position history from the frame where they first appear.
    init_pos[new_particles] = cur_pos[new_particles]
    # Extra False entry accounts for the padding graph's node added below.
    new_particles = np.concatenate([new_particles, [False]])
    inflow_stack.append(new_particles)
  inflow_stack = np.stack(inflow_stack[1:], axis=0)  # drop the initial frame
  graph.nodes["world_position"] = init_pos
  graph.nodes["deleted"] = np.zeros(init_pos.shape[0], dtype=bool)

  # fix stray particles: anything whose latest height exceeds OOB_AREA
  # (defined outside this view) is deactivated and flagged external.
  stray_particles = init_pos[:, -1, 1] > OOB_AREA
  graph.nodes["mask_stack"][stray_particles] = False
  graph.nodes["external_mask"][stray_particles] = True

  # pad to maximum node, edge values and add padding graph
  max_n_node = graph.n_node.sum() + 1
  graphs_tuple = jraph.pad_with_graphs(graph, n_node=max_n_node,
                                       n_edge=max_edges, n_graph=2)
  return obstacle_edges, inflow_stack, graphs_tuple
150 |
151 |
def rollout(initial_graph, inflow_stack, network, haiku_model, obstacle_edges,
            radius):
  """Unrolls the simulator over all inflow steps as a jittable scan.

  Args:
    initial_graph: padded graphs tuple to start the rollout from.
    inflow_stack: per-step boolean activation masks, scanned over axis 0.
    network: dict with model "params" and "state".
    haiku_model: zero-arg callable constructing the Haiku module.
    obstacle_edges: precomputed [E, 2] obstacle edge indices.
    radius: fluid-fluid connectivity radius.

  Returns:
    Tuple of (final graphs tuple, trajectory dict with stacked per-step
    "pos" and "mask" arrays).
  """

  # Rematerialize each step's activations during backprop to bound memory.
  @jax.checkpoint
  def _scan_body(carry_graph, new_particle_mask):
    stepped_graph = forward(carry_graph, new_particle_mask, network,
                            haiku_model, obstacle_edges, radius)
    step_record = {
        "pos": stepped_graph.nodes["world_position"][:, -1],
        "mask": stepped_graph.nodes["mask_stack"][:, -1],
    }
    return stepped_graph, step_record

  return jax.lax.scan(_scan_body, init=initial_graph, xs=inflow_stack)
166 |
167 |
def make_plain_obstacles(num_side=25):
  """Create a mesh obstacle (landscape) with num_side squared control points.

  Args:
    num_side: number of control points per grid side.

  Returns:
    Float array of shape (num_side**2, 3): a flat grid covering
    [0, 1] x {0.5} x [0, 1], with x varying fastest in row-major order.
  """
  axis = np.linspace(-0.5, 0.5, num_side)
  # Row-major flattening: x varies fastest, z slowest (meshgrid 'xy' order).
  xs = np.tile(axis, num_side)
  zs = np.repeat(axis, num_side)
  ys = np.zeros_like(xs)  # flat height map

  points = np.stack([xs, ys, zs], axis=-1)
  # Shift from the centered [-0.5, 0.5] cube into the unit cube.
  return points + np.array([0.5, 0.5, 0.5])[None]
179 |
180 |
def max_x_loss_fn(graph):
  """Example loss function for maximizing x position of particles when they hit the ground.

  Averages over particles flagged "deleted" (those that have landed):
  minimizing the negative mean x pushes them far in +x, while the z standard
  deviation term penalizes sideways spread.
  """
  landed = graph.nodes["deleted"]
  final_pos = graph.nodes["world_position"][:, -1]
  # Spread of landed particles along z.
  spread_term = jnp.std(final_pos[:, 2], where=landed)
  # Negated mean x so that minimizing the loss maximizes travel distance.
  distance_term = -jnp.mean(final_pos[:, 0], where=landed)
  return distance_term + spread_term
188 |
189 |
def smooth_loss_fn(obs_pos, num_side=25):
  """Smoothing loss function for minimizing sharp changes across obstacle.

  Args:
    obs_pos: flat array of num_side**2 obstacle heights.
    num_side: side length of the square control-point grid.

  Returns:
    Scalar: mean squared first difference between neighboring grid points,
    averaged over both grid axes.
  """
  height_grid = jnp.reshape(obs_pos, (num_side, num_side))
  # Squared first differences along rows and columns of the grid.
  sq_step_rows = jnp.diff(height_grid, axis=0) ** 2
  sq_step_cols = jnp.diff(height_grid, axis=1) ** 2
  return (jnp.mean(sq_step_rows) + jnp.mean(sq_step_cols)) * 0.5
196 |
197 |
def design_fn(params, graph, height_scale=0.15):
  """Convert parameters in params into landscape heightfield to be represented in graph.

  Args:
    params: unconstrained per-control-point design parameters; squashed with
      tanh so height offsets stay within [-height_scale, height_scale].
    graph: graphs tuple whose leading len(params) "world_position" rows are
      assumed to be the obstacle control points.
    height_scale: maximum absolute height offset.

  Returns:
    Tuple of (updated graph, raw height offsets). The input graph is not
    mutated; a structural copy is updated instead.
  """
  # Shallow structural copy so assigning into graph.nodes stays local.
  graph = tree.map_structure(lambda x: x, graph)
  positions = jnp.array(graph.nodes["world_position"])

  # Bounded height offsets via tanh squashing.
  heights = jnp.tanh(params) * height_scale

  # Repeat the static heights across the position-history axis.
  history_len = positions.shape[1]
  height_seq = jnp.tile(heights[:, None], [1, history_len])

  # Only the y (height) component is controlled; x and z offsets are zero.
  zero_seq = jnp.zeros_like(height_seq)
  offsets = jnp.stack([zero_seq, height_seq, zero_seq], axis=-1)

  # Offset the leading obstacle rows; all remaining rows pass through.
  num_obstacles = offsets.shape[0]
  updated_positions = jnp.concatenate(
      [positions[:num_obstacles] + offsets, positions[num_obstacles:]],
      axis=0)
  graph.nodes["world_position"] = updated_positions
  return graph, heights
219 |
--------------------------------------------------------------------------------