├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── dm_hamiltonian_dynamics_suite
│   ├── .pylintrc
│   ├── __init__.py
│   ├── datasets.py
│   ├── generate_dataset.py
│   ├── generate_locally.sh
│   ├── hamiltonian_systems
│   │   ├── __init__.py
│   │   ├── hamiltonian.py
│   │   ├── ideal_double_pendulum.py
│   │   ├── ideal_mass_spring.py
│   │   ├── ideal_pendulum.py
│   │   ├── n_body.py
│   │   ├── phase_space.py
│   │   ├── simple_analytic.py
│   │   └── utils.py
│   ├── load_datasets.py
│   ├── molecular_dynamics
│   │   ├── __init__.py
│   │   ├── generate_dataset.py
│   │   └── lj_16.lmp
│   ├── multiagent_dynamics
│   │   ├── __init__.py
│   │   └── game_dynamics.py
│   └── tests
│       └── test_datasets.py
├── requirements.txt
├── setup.py
└── visualize_datasets.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.egg-info
3 | *.csv
4 | *.npz
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
204 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepMind Hamiltonian Dynamics Suite
2 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/dm_hamiltonian_dynamics_suite/blob/master/visualize_datasets.ipynb)
3 |
4 | The repository contains the code used for generating the
5 | [DM Hamiltonian Dynamics Suite].
6 |
7 | The code for the models and experiments in our paper can be found
8 | [here](https://github.com/deepmind/deepmind-research/tree/master/physics_inspired_models),
9 | together with the code used for our concurrent publication on how to measure the
10 | quality of the learnt dynamics in models using the Hamiltonian prior when
11 | learning from pixels.
12 |
13 |
14 | ## Datasets
15 |
16 | The suite contains 17 datasets, ranging from simple physics problems
17 | (Toy Physics) to more realistic datasets of molecular dynamics
18 | (Molecular Dynamics), learning dynamics in non-transitive zero-sum games
19 | (Multi Agent), and motion in 3D simulated environments (Mujoco Room). The
20 | datasets vary in terms of the complexity of the simulated dynamics and their
21 | visual richness.
22 |
23 | For each dataset we created 50k training trajectories and 20k test
24 | trajectories, with each trajectory including image observations, the ground
25 | truth phase-space state used to generate the data, the first time derivative
26 | of the ground truth state, and any hyper-parameters of the individual
27 | trajectory. For a few of the datasets we also generate a small number of long
28 | trajectories which are used purely for evaluation.
29 |
30 |
31 | ### Toy Physics
32 |
33 | For all simulated systems we sample the trajectories at intervals of
34 | `Δt = 0.05`. For the non-conservative variants of each
35 | dataset we set the friction coefficient to `0.05`. All hyper-parameters which
36 | can be randomized are always sampled and kept fixed throughout each trajectory.
37 | The colors of the particles, when randomized, are always sampled uniformly
38 | from some fixed number of options. The exact configurations used for generating
39 | the datasets in the suite can be found in the [datasets.py] file. All systems
40 | were simulated using `scipy.integrate.solve_ivp`.
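
To make the simulation setup concrete, here is a minimal sketch of how a damped
mass-spring trajectory could be produced with `scipy.integrate.solve_ivp` and
recorded every `Δt = 0.05`. It is an illustration only, not the suite's own
code, and the parameter values (`k`, `m`, the initial state and the exact form
of the friction term) are assumptions made for the example.

```python
import numpy as np
from scipy.integrate import solve_ivp

k, m, friction = 2.0, 0.5, 0.05  # illustrative values, not the suite's configs

def dy_dt(t, y):
    """Mass-spring dynamics: q' = p / m, p' = -k q - friction * p / m."""
    q, p = y
    return [p / m, -k * q - friction * p / m]

dt, num_steps = 0.05, 256
t_eval = np.arange(num_steps + 1) * dt
solution = solve_ivp(dy_dt, t_span=(0.0, num_steps * dt), y0=[1.0, 0.0],
                     t_eval=t_eval, rtol=1e-9)
trajectory = solution.y.T  # shape [num_steps + 1, 2] with columns (q, p)
```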
41 |
42 | #### Mass Spring
43 |
44 | This dataset describes the simple harmonic motion of a particle attached to a
45 | spring. The system has two hyper-parameters - the spring force coefficient `k`
46 | and the mass of the particle `m`. The initial positions and momenta are sampled
47 | jointly from an annulus, where the radius lies in the interval
48 | [r_low, r_high]. One can choose whether the distribution is
49 | uniform over the annulus, or whether the length of the radius is instead
50 | sampled uniformly. To render the system on an image we visualize just the
51 | particle as a circle, with a radius proportional to the square root of its mass.
52 | When rendering, there are three additional hyper-parameters - whether to
53 | randomize the horizontal position of the particle, since its motion is only
54 | vertical; whether to also shift the anchor point of the
55 | spring in any direction; and how many possible colors the circle representing
56 | the particle can take. Finally, one can additionally simulate a non-conservative
57 | system by setting the friction coefficient to a non-zero value.
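
The annulus sampling of initial conditions described above can be sketched as
follows. This is a hypothetical illustration rather than the code in
[datasets.py]; the function name `sample_annulus` and the `uniform_annulus`
switch simply mirror the options mentioned in the text.

```python
import numpy as np

def sample_annulus(rng, r_low, r_high, uniform_annulus=True):
    """Samples a 2D point (q, p) from the annulus r_low <= radius <= r_high."""
    if uniform_annulus:
        # Uniform over the area of the annulus.
        radius = np.sqrt(rng.uniform(r_low ** 2, r_high ** 2))
    else:
        # Uniform over the length of the radius.
        radius = rng.uniform(r_low, r_high)
    angle = rng.uniform(0.0, 2.0 * np.pi)
    return radius * np.cos(angle), radius * np.sin(angle)

q0, p0 = sample_annulus(np.random.default_rng(0), r_low=0.1, r_high=1.0)
```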
58 |
59 | #### Pendulum
60 |
61 | This dataset describes the evolution of a particle attached to a pivot, around
62 | which it can move freely. The system is simulated in angle space, so that it is
63 | one dimensional. It has three hyper-parameters - the mass of the particle `m`,
64 | the gravitational constant `g` and the pivot length `l`. The initial positions
65 | and momenta are sampled jointly from an annulus, where the radius lies in the
66 | interval [r_low, r_high]. One can choose whether the
67 | distribution is uniform over the annulus, or whether the length of the radius
68 | is instead sampled uniformly. To render the system on an image we
69 | visualize just the particle as a circle, with radius proportional to the square
70 | root of its mass. When rendering, there are two additional
71 | hyper-parameters - whether to also shift the anchor point of the pivot in any
72 | direction and how many possible colors the circle representing the
73 | particle can take. Finally, one can additionally simulate a non-conservative
74 | system by setting the friction coefficient to a non-zero value.
75 |
76 | #### Double Pendulum
77 |
78 | This dataset describes the evolution of two coupled pendulums, where the anchor
79 | point of the second pendulum's pivot is the center of the first pendulum's
80 | particle. This leads to significantly more complicated dynamics [2]. All the
81 | hyper-parameters are equivalent to those in the Pendulum dataset and follow the
82 | exact same protocol.
83 |
84 | #### Two Body
85 |
86 | This dataset describes the gravitational motion of two particles in the plane.
87 | The system has three hyper-parameters - the masses of the two particles `m_1`
88 | and `m_2` and the gravitational constant `g`. The positions and momenta of each
89 | particle are sampled jointly from an annulus, where the radius lies in the
90 | interval [r_low, r_high]. To render the system on an image
91 | we visualize each particle as a circle, with radius proportional to the
92 | square root of its mass. When rendering, there are two additional
93 | hyper-parameters - whether to also shift the center of mass of the system in
94 | any direction and how many possible colors the circles representing the
95 | particles can take.
96 |
97 |
98 | ### Multi Agent
99 |
100 | These datasets describe the dynamics of non-transitive zero-sum games. Here we
101 | consider two prominent examples of such games: **matching pennies** and
102 | **rock-paper-scissors**. We use the well-known continuous-time multi-population
103 | replicator dynamics to drive the learning process. The ground-truth trajectories
104 | are generated by integrating the coupled set of ODEs using an improved Euler
105 | scheme or RK45. In both cases the ground-truth state, i.e. the joint strategy
106 | profile (joint policy), and its first-order time derivative are recorded at
107 | regular time intervals `Δt = 0.1`. Trajectories start from uniformly sampled
108 | points on the product of the policy simplexes. No noise is added to the
109 | trajectories.
110 |
111 | As all other datasets use images as inputs, we define the observation as the
112 | outer product of the strategy profiles of the two players. The resulting matrix
113 | captures the probability mass that falls on each pure joint strategy profile
114 | (joint action). In this dataset, the observations are a loss-less representation
115 | of the ground-truth state and are upsampled to `32 x 32 x 3` images through
116 | tiling.
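
As a rough illustration of the setup described above, the sketch below
integrates two-population replicator dynamics for rock-paper-scissors and
builds the outer-product observation. The payoff matrix, integration horizon
and tiling factor are illustrative assumptions and do not reproduce the exact
configuration in `game_dynamics.py`.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Standard rock-paper-scissors payoff matrix for player 1 (zero-sum game).
A = np.array([[0.0, -1.0, 1.0],
              [1.0, 0.0, -1.0],
              [-1.0, 1.0, 0.0]])

def replicator(t, state):
    """Two-population replicator dynamics for a zero-sum matrix game."""
    x, y = state[:3], state[3:]            # strategies of players 1 and 2
    dx = x * (A @ y - x @ A @ y)           # payoff of each action vs. average
    dy = y * (-A.T @ x - y @ (-A.T) @ x)   # player 2 sees the negated payoffs
    return np.concatenate([dx, dy])

x0 = np.random.dirichlet(np.ones(3))
y0 = np.random.dirichlet(np.ones(3))
ts = np.arange(0.0, 10.0, 0.1)  # record the state every dt = 0.1
sol = solve_ivp(replicator, (0.0, 10.0), np.concatenate([x0, y0]), t_eval=ts)

# Observation: outer product of the two strategy profiles, tiled to an image.
x_t, y_t = sol.y[:3, -1], sol.y[3:, -1]
obs = np.kron(np.outer(x_t, y_t), np.ones((11, 11)))  # 33 x 33 here; the suite
                                                      # tiles to 32 x 32 x 3.
```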
117 |
118 |
119 | ### Mujoco Room
120 |
121 | These datasets are composed of multiple scenes, each consisting of a camera
122 | moving around a room with 5 randomly placed objects. The objects were sampled
123 | from four shape types: a sphere, a capsule, a cylinder and a box. Each room was
124 | different due to the randomly sampled colors of the wall, floor and objects. The
125 | dynamics were created by motion and rotation of the camera. The **circle**
126 | dataset is generated by rotating the camera around a single randomly sampled
127 | parallel of the unit hemisphere centered around the middle of the room. The
128 | **spiral** dataset is generated by rotating the camera on a spiral moving down
129 | the unit hemisphere. For each trajectory an initial radius and angle are sampled
130 | and then converted into the Cartesian coordinates of the camera. The dynamics
131 | are discretised by moving the camera using a step size of `0.1` degrees in a way
132 | that keeps the camera on the unit hemisphere while facing the center of the
133 | room. For the **spiral** dataset, the camera path traces out a golden spiral
134 | starting at the height corresponding to the originally sampled radius on the
135 | unit hemisphere. The rendered scenes are used as observations, while the
136 | Cartesian coordinates of the camera and its velocities, estimated through
137 | finite differences, serve as the state. Each trajectory was generated using [MuJoCo].
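
A minimal sketch of the **circle** camera path is shown below, assuming the
sampled radius fixes a parallel of the unit hemisphere and the height follows
from staying on that hemisphere; the function is illustrative and is not the
rendering code itself.

```python
import numpy as np

def circle_camera_positions(radius, initial_angle, num_steps, step_deg=0.1):
    """Camera positions on a fixed parallel of the unit hemisphere."""
    angles = initial_angle + np.deg2rad(step_deg) * np.arange(num_steps)
    height = np.sqrt(max(1.0 - radius ** 2, 0.0))  # stay on the unit hemisphere
    return np.stack([radius * np.cos(angles),
                     radius * np.sin(angles),
                     np.full_like(angles, height)], axis=-1)

positions = circle_camera_positions(radius=0.7, initial_angle=0.0, num_steps=100)
velocities = np.gradient(positions, axis=0)  # finite-difference estimate
```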
138 |
139 | ### Molecular Dynamics
140 |
141 | These datasets are based on a type of interaction potential commonly studied
142 | using computer simulation techniques, such as molecular dynamics or Monte Carlo
143 | simulations. In particular, we generated two datasets employing a Lennard-Jones
144 | potential of increasing complexity: one comprising only 4 particles at a very
145 | low density and another one for a 16-particle liquid at a higher density. For
146 | rendering these datasets we used the same scheme as for the Toy Physics datasets.
147 | All masses are set to unity and we represent particles by circles of equal size
148 | with a radius value adjusted to fit the canvas well. The illustrations are
149 | therefore not representative of the density of the system. In addition, we
150 | assigned different colors to the particles to facilitate tracking their
151 | trajectories.
152 |
153 | We created the datasets in two steps: we first generated the raw molecular
154 | dynamics data using the simulation software [LAMMPS], and then converted the
155 | resulting trajectories into a trainable format. For the final datasets available
156 | for download, we combined simulation data from 100 different molecular dynamics
157 | trajectories, each corresponding to a different random initialization
158 | (see Appendix 1.3 of the paper for details). Here we provide a LAMMPS input
159 | script [lj_16.lmp] to generate data for a single seed and a script
160 | [generate_dataset.py] to turn the text-based simulation output into
161 | a trainable format. By default, the simulation is set up for the 16-particle
162 | system, but we provide inline comments on which lines need changing for the
163 | 4-particle dataset.
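
For reference, a generic way to read the per-atom blocks of a text-based LAMMPS
dump file is sketched below. The exact columns depend on the `dump` command in
[lj_16.lmp], so the parser takes them from the `ITEM: ATOMS` header; this is an
illustrative helper, not the converter script itself.

```python
import numpy as np

def read_lammps_dump(path):
    """Reads the per-atom sections of a text-based LAMMPS dump file.

    Returns a list with one (column_names, array) pair per timestep. The
    column layout is read from the "ITEM: ATOMS ..." header, so it adapts to
    whatever the dump command in the LAMMPS input script requested.
    """
    frames = []
    with open(path) as f:
        lines = f.readlines()
    i = 0
    while i < len(lines):
        if lines[i].startswith("ITEM: NUMBER OF ATOMS"):
            num_atoms = int(lines[i + 1])
        elif lines[i].startswith("ITEM: ATOMS"):
            columns = lines[i].split()[2:]
            rows = [lines[i + 1 + j].split() for j in range(num_atoms)]
            frames.append((columns, np.array(rows, dtype=np.float64)))
            i += num_atoms
        i += 1
    return frames
```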
164 |
165 |
166 |
167 | ## Installation
168 |
169 | All package requirements are listed in `requirements.txt`. To install the code,
170 | run the following commands in your shell:
171 |
172 | ```shell
173 | git clone https://github.com/deepmind/dm_hamiltonian_dynamics_suite
174 | pip install -r ./dm_hamiltonian_dynamics_suite/requirements.txt
175 | pip install ./dm_hamiltonian_dynamics_suite
176 | pip install --upgrade "jax[XXX]"
177 | ```
178 |
179 | where `XXX` is the type of accelerator available on your machine (for example
180 | `cpu`). Note that if you are using a GPU, `XXX` might also need to include the
181 | correct versions of CUDA and cuDNN installed on your machine.
182 | For more details please read [here](https://github.com/google/jax#installation).
183 |
184 |
185 | ## Usage
186 |
187 | You can find an example of how to generate a dataset, and how to load and
188 | visualize it, in the [Colab notebook provided].
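
If you prefer to drive the generation directly from Python instead of the
notebook or `generate_dataset.py`, a minimal sketch using
`datasets.generate_full_dataset` (whose signature is shown in `datasets.py`)
could look as follows; the output folder, dataset name and sizes are arbitrary
example values.

```python
import jax
from dm_hamiltonian_dynamics_suite import datasets

# The generation code expects double precision, as in generate_dataset.py.
jax.config.update("jax_enable_x64", True)

train_examples, test_examples = datasets.generate_full_dataset(
    folder="/tmp/dm_hamiltonian_dynamics_suite",  # arbitrary output location
    dataset="mass_spring",  # any configuration name defined in datasets.py
    dt=0.05,
    num_steps=256,
    steps_per_dt=10,
    num_train=10,
    num_test=5,
    overwrite=True,
    return_generated_examples=True,
)
```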
189 |
190 |
191 | ## References
192 |
193 | **Which priors matter? Benchmarking models for learning latent dynamics**
194 |
195 | Aleksandar Botev, Drew Jaegle, Peter Wirnsberger, Daniel Hennes and Irina
196 | Higgins
197 |
198 | URL: https://openreview.net/forum?id=qBl8hnwR0px
199 |
200 | **SyMetric: Measuring the Quality of Learnt Hamiltonian Dynamics Inferred from Vision**
201 |
202 | Irina Higgins, Peter Wirnsberger, Andrew Jaegle, Aleksandar Botev
203 |
204 | URL: https://openreview.net/forum?id=9Qu0U9Fj7IP
205 |
206 | ## Disclaimer
207 |
208 | This is not an official Google product.
209 |
210 | [DM Hamiltonian Dynamics Suite]: https://console.cloud.google.com/storage/browser/dm-hamiltonian-dynamics-suite
211 | [Colab notebook provided]: https://colab.research.google.com/github/deepmind/dm_hamiltonian_dynamics_suite/blob/master/visualize_datasets.ipynb
212 | [datasets.py]: https://github.com/deepmind/dm_hamiltonian_dynamics_suite/blob/master/dm_hamiltonian_dynamics_suite/datasets.py
213 | [lj_16.lmp]: https://github.com/deepmind/dm_hamiltonian_dynamics_suite/blob/master/dm_hamiltonian_dynamics_suite/molecular_dynamics/lj_16.lmp
214 | [generate_dataset.py]: https://github.com/deepmind/dm_hamiltonian_dynamics_suite/blob/master/dm_hamiltonian_dynamics_suite/generate_dataset.py
215 | [LAMMPS]: https://lammps.sandia.gov/
216 | [MuJoCo]: http://www.mujoco.org/
217 |
218 |
--------------------------------------------------------------------------------
/dm_hamiltonian_dynamics_suite/.pylintrc:
--------------------------------------------------------------------------------
1 | # This Pylint rcfile contains a best-effort configuration to uphold the
2 | # best-practices and style described in the Google Python style guide:
3 | # https://google.github.io/styleguide/pyguide.html
4 | #
5 | # Its canonical open-source location is:
6 | # https://google.github.io/styleguide/pylintrc
7 |
8 | [MASTER]
9 |
10 | # Files or directories to be skipped. They should be base names, not paths.
11 | ignore=third_party
12 |
13 | # Files or directories matching the regex patterns are skipped. The regex
14 | # matches against base names, not paths.
15 | ignore-patterns=
16 |
17 | # Pickle collected data for later comparisons.
18 | persistent=no
19 |
20 | # List of plugins (as comma separated values of python modules names) to load,
21 | # usually to register additional checkers.
22 | load-plugins=
23 |
24 | # Use multiple processes to speed up Pylint.
25 | jobs=4
26 |
27 | # Allow loading of arbitrary C extensions. Extensions are imported into the
28 | # active Python interpreter and may run arbitrary code.
29 | unsafe-load-any-extension=no
30 |
31 |
32 | [MESSAGES CONTROL]
33 |
34 | # Only show warnings with the listed confidence levels. Leave empty to show
35 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
36 | confidence=
37 |
38 | # Enable the message, report, category or checker with the given id(s). You can
39 | # either give multiple identifier separated by comma (,) or put this option
40 | # multiple time (only on the command line, not in the configuration file where
41 | # it should appear only once). See also the "--disable" option for examples.
42 | #enable=
43 |
44 | # Disable the message, report, category or checker with the given id(s). You
45 | # can either give multiple identifiers separated by comma (,) or put this
46 | # option multiple times (only on the command line, not in the configuration
47 | # file where it should appear only once).You can also use "--disable=all" to
48 | # disable everything first and then reenable specific checks. For example, if
49 | # you want to run only the similarities checker, you can use "--disable=all
50 | # --enable=similarities". If you want to run only the classes checker, but have
51 | # no Warning level messages displayed, use"--disable=all --enable=classes
52 | # --disable=W"
53 | disable=abstract-method,
54 | apply-builtin,
55 | arguments-differ,
56 | attribute-defined-outside-init,
57 | backtick,
58 | bad-option-value,
59 | basestring-builtin,
60 | buffer-builtin,
61 | c-extension-no-member,
62 | consider-using-enumerate,
63 | cmp-builtin,
64 | cmp-method,
65 | coerce-builtin,
66 | coerce-method,
67 | delslice-method,
68 | div-method,
69 | duplicate-code,
70 | eq-without-hash,
71 | execfile-builtin,
72 | file-builtin,
73 | filter-builtin-not-iterating,
74 | fixme,
75 | getslice-method,
76 | global-statement,
77 | hex-method,
78 | idiv-method,
79 | implicit-str-concat-in-sequence,
80 | import-error,
81 | import-self,
82 | import-star-module-level,
83 | inconsistent-return-statements,
84 | input-builtin,
85 | intern-builtin,
86 | invalid-str-codec,
87 | locally-disabled,
88 | long-builtin,
89 | long-suffix,
90 | map-builtin-not-iterating,
91 | misplaced-comparison-constant,
92 | missing-function-docstring,
93 | metaclass-assignment,
94 | next-method-called,
95 | next-method-defined,
96 | no-absolute-import,
97 | no-else-break,
98 | no-else-continue,
99 | no-else-raise,
100 | no-else-return,
101 | no-init, # added
102 | no-member,
103 | no-name-in-module,
104 | no-self-use,
105 | nonzero-method,
106 | oct-method,
107 | old-division,
108 | old-ne-operator,
109 | old-octal-literal,
110 | old-raise-syntax,
111 | parameter-unpacking,
112 | print-statement,
113 | raising-string,
114 | range-builtin-not-iterating,
115 | raw_input-builtin,
116 | rdiv-method,
117 | reduce-builtin,
118 | relative-import,
119 | reload-builtin,
120 | round-builtin,
121 | setslice-method,
122 | signature-differs,
123 | standarderror-builtin,
124 | suppressed-message,
125 | sys-max-int,
126 | too-few-public-methods,
127 | too-many-ancestors,
128 | too-many-arguments,
129 | too-many-boolean-expressions,
130 | too-many-branches,
131 | too-many-instance-attributes,
132 | too-many-locals,
133 | too-many-nested-blocks,
134 | too-many-public-methods,
135 | too-many-return-statements,
136 | too-many-statements,
137 | trailing-newlines,
138 | unichr-builtin,
139 | unicode-builtin,
140 | unnecessary-pass,
141 | unpacking-in-except,
142 | useless-else-on-loop,
143 | useless-object-inheritance,
144 | useless-suppression,
145 | using-cmp-argument,
146 | wrong-import-order,
147 | xrange-builtin,
148 | zip-builtin-not-iterating,
149 |
150 |
151 | [REPORTS]
152 |
153 | # Set the output format. Available formats are text, parseable, colorized, msvs
154 | # (visual studio) and html. You can also give a reporter class, eg
155 | # mypackage.mymodule.MyReporterClass.
156 | output-format=text
157 |
158 | # Put messages in a separate file for each module / package specified on the
159 | # command line instead of printing them on stdout. Reports (if any) will be
160 | # written in a file name "pylint_global.[txt|html]". This option is deprecated
161 | # and it will be removed in Pylint 2.0.
162 | files-output=no
163 |
164 | # Tells whether to display a full report or only the messages
165 | reports=no
166 |
167 | # Python expression which should return a note less than 10 (10 is the highest
168 | # note). You have access to the variables errors warning, statement which
169 | # respectively contain the number of errors / warnings messages and the total
170 | # number of statements analyzed. This is used by the global evaluation report
171 | # (RP0004).
172 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
173 |
174 | # Template used to display messages. This is a python new-style format string
175 | # used to format the message information. See doc for all details
176 | #msg-template=
177 |
178 |
179 | [BASIC]
180 |
181 | # Good variable names which should always be accepted, separated by a comma
182 | good-names=main,_
183 |
184 | # Bad variable names which should always be refused, separated by a comma
185 | bad-names=
186 |
187 | # Colon-delimited sets of names that determine each other's naming style when
188 | # the name regexes allow several styles.
189 | name-group=
190 |
191 | # Include a hint for the correct naming format with invalid-name
192 | include-naming-hint=no
193 |
194 | # List of decorators that produce properties, such as abc.abstractproperty. Add
195 | # to this list to register other decorators that produce valid properties.
196 | property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
197 |
198 | # Regular expression matching correct function names
199 | function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
200 |
201 | # Regular expression matching correct variable names
202 | variable-rgx=^[a-z][a-z0-9_]*$
203 |
204 | # Regular expression matching correct constant names
205 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
206 |
207 | # Regular expression matching correct attribute names
208 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
209 |
210 | # Regular expression matching correct argument names
211 | argument-rgx=^[a-z][a-z0-9_]*$
212 |
213 | # Regular expression matching correct class attribute names
214 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
215 |
216 | # Regular expression matching correct inline iteration names
217 | inlinevar-rgx=^[a-z][a-z0-9_]*$
218 |
219 | # Regular expression matching correct class names
220 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$
221 |
222 | # Regular expression matching correct module names
223 | module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
224 |
225 | # Regular expression matching correct method names
226 | method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
227 |
228 | # Regular expression which should only match function or class names that do
229 | # not require a docstring.
230 | no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
231 |
232 | # Minimum line length for functions/classes that require docstrings, shorter
233 | # ones are exempt.
234 | docstring-min-length=10
235 |
236 |
237 | [TYPECHECK]
238 |
239 | # List of decorators that produce context managers, such as
240 | # contextlib.contextmanager. Add to this list to register other decorators that
241 | # produce valid context managers.
242 | contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
243 |
244 | # Tells whether missing members accessed in mixin class should be ignored. A
245 | # mixin class is detected if its name ends with "mixin" (case insensitive).
246 | ignore-mixin-members=yes
247 |
248 | # List of module names for which member attributes should not be checked
249 | # (useful for modules/projects where namespaces are manipulated during runtime
250 | # and thus existing member attributes cannot be deduced by static analysis. It
251 | # supports qualified module names, as well as Unix pattern matching.
252 | ignored-modules=
253 |
254 | # List of class names for which member attributes should not be checked (useful
255 | # for classes with dynamically set attributes). This supports the use of
256 | # qualified names.
257 | ignored-classes=optparse.Values,thread._local,_thread._local
258 |
259 | # List of members which are set dynamically and missed by pylint inference
260 | # system, and so shouldn't trigger E1101 when accessed. Python regular
261 | # expressions are accepted.
262 | generated-members=
263 |
264 |
265 | [FORMAT]
266 |
267 | # Maximum number of characters on a single line.
268 | max-line-length=80
269 |
270 | # TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
271 | # lines made too long by directives to pytype.
272 |
273 | # Regexp for a line that is allowed to be longer than the limit.
274 | ignore-long-lines=(?x)(
275 |   ^\s*(\#\ )?<?https?://\S+>?$|
276 | ^\s*(from\s+\S+\s+)?import\s+.+$)
277 |
278 | # Allow the body of an if to be on the same line as the test if there is no
279 | # else.
280 | single-line-if-stmt=yes
281 |
282 | # List of optional constructs for which whitespace checking is disabled. `dict-
283 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
284 | # `trailing-comma` allows a space between comma and closing bracket: (a, ).
285 | # `empty-line` allows space-only lines.
286 | no-space-check=
287 |
288 | # Maximum number of lines in a module
289 | max-module-lines=99999
290 |
291 | # String used as indentation unit. The internal Google style guide mandates 2
292 | # spaces. Google's externally-published style guide says 4, consistent with
293 | # PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
294 | # projects (like TensorFlow).
295 | indent-string=' '
296 |
297 | # Number of spaces of indent required inside a hanging or continued line.
298 | indent-after-paren=4
299 |
300 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
301 | expected-line-ending-format=
302 |
303 |
304 | [MISCELLANEOUS]
305 |
306 | # List of note tags to take in consideration, separated by a comma.
307 | notes=TODO
308 |
309 |
310 | [STRING]
311 |
312 | # This flag controls whether inconsistent-quotes generates a warning when the
313 | # character used as a quote delimiter is used inconsistently within a module.
314 | check-quote-consistency=yes
315 |
316 |
317 | [VARIABLES]
318 |
319 | # Tells whether we should check for unused import in __init__ files.
320 | init-import=no
321 |
322 | # A regular expression matching the name of dummy variables (i.e. expectedly
323 | # not used).
324 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
325 |
326 | # List of additional names supposed to be defined in builtins. Remember that
327 | # you should avoid to define new builtins when possible.
328 | additional-builtins=
329 |
330 | # List of strings which can identify a callback function by name. A callback
331 | # name must start or end with one of those strings.
332 | callbacks=cb_,_cb
333 |
334 | # List of qualified module names which can have objects that can redefine
335 | # builtins.
336 | redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
337 |
338 |
339 | [LOGGING]
340 |
341 | # Logging modules to check that the string format arguments are in logging
342 | # function parameter format
343 | logging-modules=logging,absl.logging,tensorflow.io.logging
344 |
345 |
346 | [SIMILARITIES]
347 |
348 | # Minimum lines number of a similarity.
349 | min-similarity-lines=4
350 |
351 | # Ignore comments when computing similarities.
352 | ignore-comments=yes
353 |
354 | # Ignore docstrings when computing similarities.
355 | ignore-docstrings=yes
356 |
357 | # Ignore imports when computing similarities.
358 | ignore-imports=no
359 |
360 |
361 | [SPELLING]
362 |
363 | # Spelling dictionary name. Available dictionaries: none. To make it work,
364 | # install the python-enchant package.
365 | spelling-dict=
366 |
367 | # List of comma separated words that should not be checked.
368 | spelling-ignore-words=
369 |
370 | # A path to a file that contains private dictionary; one word per line.
371 | spelling-private-dict-file=
372 |
373 | # Tells whether to store unknown words to indicated private dictionary in
374 | # --spelling-private-dict-file option instead of raising a message.
375 | spelling-store-unknown-words=no
376 |
377 |
378 | [IMPORTS]
379 |
380 | # Deprecated modules which should not be used, separated by a comma
381 | deprecated-modules=regsub,
382 | TERMIOS,
383 | Bastion,
384 | rexec,
385 | sets
386 |
387 | # Create a graph of every (i.e. internal and external) dependencies in the
388 | # given file (report RP0402 must not be disabled)
389 | import-graph=
390 |
391 | # Create a graph of external dependencies in the given file (report RP0402 must
392 | # not be disabled)
393 | ext-import-graph=
394 |
395 | # Create a graph of internal dependencies in the given file (report RP0402 must
396 | # not be disabled)
397 | int-import-graph=
398 |
399 | # Force import order to recognize a module as part of the standard
400 | # compatibility libraries.
401 | known-standard-library=
402 |
403 | # Force import order to recognize a module as part of a third party library.
404 | known-third-party=enchant, absl
405 |
406 | # Analyse import fallback blocks. This can be used to support both Python 2 and
407 | # 3 compatible code, which means that the block might have code that exists
408 | # only in one or another interpreter, leading to false positives when analysed.
409 | analyse-fallback-blocks=no
410 |
411 |
412 | [CLASSES]
413 |
414 | # List of method names used to declare (i.e. assign) instance attributes.
415 | defining-attr-methods=__init__,
416 | __new__,
417 | setUp
418 |
419 | # List of member names, which should be excluded from the protected access
420 | # warning.
421 | exclude-protected=_asdict,
422 | _fields,
423 | _replace,
424 | _source,
425 | _make
426 |
427 | # List of valid names for the first argument in a class method.
428 | valid-classmethod-first-arg=cls,
429 | class_
430 |
431 | # List of valid names for the first argument in a metaclass class method.
432 | valid-metaclass-classmethod-first-arg=mcs
433 |
434 |
435 | [EXCEPTIONS]
436 |
437 | # Exceptions that will emit a warning when being caught. Defaults to
438 | # "Exception"
439 | overgeneral-exceptions=StandardError,
440 | Exception,
441 | BaseException
442 |
--------------------------------------------------------------------------------
/dm_hamiltonian_dynamics_suite/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/dm_hamiltonian_dynamics_suite/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Module containing dataset configurations used for generation and some utility functions."""
16 | import functools
17 | import os
18 | import shutil
19 | from typing import Callable, Mapping, Any, TextIO, Generator, Tuple, Optional
20 |
21 | from absl import logging
22 |
23 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import ideal_double_pendulum
24 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import ideal_mass_spring
25 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import ideal_pendulum
26 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import n_body
27 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils
28 | import jax
29 | import jax.numpy as jnp
30 | import numpy as np
31 | import tensorflow as tf
32 |
33 | PipelineOutput = Optional[Tuple[
34 | Tuple[Mapping[str, jnp.ndarray], ...],
35 | Tuple[Mapping[str, jnp.ndarray], ...],
36 | ]]
37 |
38 | try:
39 | from dm_hamiltonian_dynamics_suite.multiagent_dynamics import game_dynamics # pylint: disable=g-import-not-at-top
40 | _OPEN_SPIEL_INSTALLED = True
41 | except ModuleNotFoundError:
42 | _OPEN_SPIEL_INSTALLED = False
43 |
44 |
45 | def open_spiel_available() -> bool:
46 | return _OPEN_SPIEL_INSTALLED
47 |
48 |
49 | def set_up_folder(folder: str, overwrite: bool) -> None:
50 | """Sets up the folder needed for the dataset (optionally clearing it)."""
51 | if os.path.exists(folder):
52 | if overwrite:
53 | shutil.rmtree(folder)
54 | os.makedirs(folder)
55 | else:
56 | os.makedirs(folder)
57 |
58 |
59 | def save_features(
60 | file: TextIO,
61 | example_dict: Mapping[str, Any],
62 | prefix: str = ""
63 | ) -> None:
64 | """Saves the features file used for loading."""
65 | for k, v in example_dict.items():
66 | if isinstance(v, dict):
67 | save_features(file, v, prefix=f"{prefix}{k}/")
68 | else:
69 | if isinstance(v, tf.Tensor):
70 | v = v.numpy()
71 | if isinstance(v, (np.ndarray, jnp.ndarray)):
72 | # int32 are promoted to int64
73 | if v.dtype == np.int32:
74 | file.write(f"{prefix}{k}, {v.shape}, {np.int64}\n")
75 | else:
76 | file.write(f"{prefix}{k}, {v.shape}, {v.dtype}\n")
77 | else:
78 | raise NotImplementedError(f"Currently the only supported feature types "
79 | f"are tf.Tensor, np.ndarray and jnp.ndarray. "
80 | f"Encountered value of type {type(v)}.")
81 |
82 |
83 | def encode_example(example_dict: Mapping[str, Any]) -> Mapping[str, Any]:
84 | """Encodes a single trajectory into a TFRecord example."""
85 | result_dict = dict()
86 | for k, v in example_dict.items():
87 | if isinstance(v, tf.Tensor):
88 | v = v.numpy()
89 | if isinstance(v, dict):
90 | for ki, vi in encode_example(v).items():
91 | result_dict[f"{k}/{ki}"] = vi
92 | elif isinstance(v, (np.ndarray, jnp.ndarray)):
93 | if v.dtype == np.uint8:
94 | # We encode images to png
95 | if v.ndim == 4:
96 | # Since encode_png accepts only a single image for a batch of images
97 | # we just stack them over their first axis.
98 | v = v.reshape((-1,) + v.shape[-2:])
99 | image_string = tf.image.encode_png(v).numpy()
100 | result_dict[k] = tf.train.Feature(
101 | bytes_list=tf.train.BytesList(value=[image_string]))
102 | elif v.dtype == np.int32:
103 | # int32 are promoted to int64
104 | value = v.reshape([-1]).astype(np.int64)
105 | result_dict[k] = tf.train.Feature(
106 | int64_list=tf.train.Int64List(value=value))
107 | else:
108 | # Since tf.Records do not support reading float64, here for any values
109 | # we interpret them as int64 and store them in this format, in order
110 | # when reading to be able to recover the float64 values.
111 | value = v.reshape([-1]).view(np.int64)
112 | result_dict[k] = tf.train.Feature(
113 | int64_list=tf.train.Int64List(value=value))
114 | else:
115 | raise NotImplementedError(f"Currently the only supported feature types "
116 | f"are tf.Tensor, np.ndarray and jnp.ndarray. "
117 | f"Encountered value of type {type(v)}.")
118 | return result_dict
119 |
120 |
121 | def transform_dataset(
122 | generator: Generator[Mapping[str, Any], None, None],
123 | destination_folder: str,
124 | prefix: str,
125 | overwrite: bool
126 | ) -> None:
127 | """Copies the dataset from the source folder to the destination as a TFRecord dataset."""
128 | set_up_folder(destination_folder, overwrite)
129 | features_path = os.path.join(destination_folder, "features.txt")
130 | features_saved = False
131 |
132 | file_path = os.path.join(destination_folder, f"{prefix}.tfrecord")
133 | if os.path.exists(file_path):
134 | if not overwrite:
135 |       logging.info("The file with prefix %s already exists. Skipping.", prefix)
136 | # We assume that the features file must be present in this case.
137 | return
138 | else:
139 |       logging.info("The file with prefix %s already exists and overwrite=True."
140 | " Deleting.", prefix)
141 | os.remove(file_path)
142 | with tf.io.TFRecordWriter(file_path) as writer:
143 | for element in generator:
144 | if not features_saved:
145 | with open(features_path, "w") as f:
146 | save_features(f, element)
147 | features_saved = True
148 | example = tf.train.Example(features=tf.train.Features(
149 | feature=encode_example(element)))
150 | writer.write(example.SerializeToString())
151 |
152 |
153 | def generate_sample(
154 | index: int,
155 | system: n_body.hamiltonian.HamiltonianSystem,
156 | dt: float,
157 | num_steps: int,
158 | steps_per_dt: int
159 | ) -> Mapping[str, jnp.ndarray]:
160 | """Simulates a single trajectory of the system."""
161 |   seed = np.random.randint(0, 2 ** 32 - 1)
162 | prng_key = jax.random.fold_in(jax.random.PRNGKey(seed), index)
163 | total_steps = num_steps * steps_per_dt
164 | total_dt = dt / steps_per_dt
165 | result = system.generate_and_render_dt(
166 | num_trajectories=1,
167 | rng_key=prng_key,
168 | t0=0.0,
169 | dt=total_dt,
170 | num_steps=total_steps)
171 | sub_sample_index = np.linspace(0.0, total_steps, num_steps + 1)
172 | sub_sample_index = sub_sample_index.astype("int64")
173 | def sub_sample(x):
174 | if x.ndim > 1 and x.shape[1] == total_steps + 1:
175 | return x[0, sub_sample_index]
176 | else:
177 | return x
178 | result = jax.tree_map(sub_sample, result)
179 | for k in result.keys():
180 | if "image" in k:
181 | result[k] = (result[k] * 255.0).astype("uint8")
182 | return result
183 |
184 |
185 | def create_pipeline(
186 | generate: Callable[[int], Mapping[str, jnp.ndarray]],
187 | output_path: str,
188 | num_train: int,
189 | num_test: int,
190 | return_generated_examples: bool = False,
191 | ) -> Callable[[], PipelineOutput]:
192 | """Runs the generation pipeline for the HML datasets."""
193 | def pipeline() -> PipelineOutput:
194 | train_examples = list()
195 | test_examples = list()
196 | with open(f"{output_path}/features.txt", "w") as f:
197 | save_features(f, generate(0))
198 | with tf.io.TFRecordWriter(f"{output_path}/train.tfrecord") as writer:
199 | for i in range(num_train):
200 | example = generate(i)
201 | if return_generated_examples:
202 | train_examples.append(example)
203 | example = tf.train.Example(features=tf.train.Features(
204 | feature=encode_example(example)))
205 | writer.write(example.SerializeToString())
206 | with tf.io.TFRecordWriter(f"{output_path}/test.tfrecord") as writer:
207 | for i in range(num_test):
208 | example = generate(num_train + i)
209 | if return_generated_examples:
210 | test_examples.append(example)
211 | example = tf.train.Example(features=tf.train.Features(
212 | feature=encode_example(example)))
213 | writer.write(example.SerializeToString())
214 | if return_generated_examples:
215 | return tuple(train_examples), tuple(test_examples)
216 | return pipeline
217 |
218 |
219 | def generate_full_dataset(
220 | folder: str,
221 | dataset: str,
222 | dt: float,
223 | num_steps: int,
224 | steps_per_dt: int,
225 | num_train: int,
226 | num_test: int,
227 | overwrite: bool,
228 | return_generated_examples: bool = False,
229 | ) -> PipelineOutput:
230 | """Runs the data generation."""
231 | dt_str = str(dt).replace(".", "_")
232 | folder = os.path.join(folder, dataset.lower() + f"_dt_{dt_str}")
233 | set_up_folder(folder, overwrite)
234 |
235 | cls, config = globals().get(dataset.upper())
236 | system = cls(**config())
237 | generate = functools.partial(
238 | generate_sample,
239 | system=system,
240 | dt=dt,
241 | num_steps=num_steps,
242 | steps_per_dt=steps_per_dt)
243 | pipeline = create_pipeline(
244 | generate, folder, num_train, num_test, return_generated_examples)
245 | return pipeline()
246 |
247 |
248 | MASS_SPRING = (
249 | ideal_mass_spring.IdealMassSpring,
250 | lambda: dict( # pylint:disable=g-long-lambda
251 | k_range=utils.BoxRegion(2.0, 2.0),
252 | m_range=utils.BoxRegion(0.5, 0.5),
253 | radius_range=utils.BoxRegion(0.1, 1.0),
254 | uniform_annulus=False,
255 | randomize_canvas_location=False,
256 | randomize_x=False,
257 | num_colors=1,
258 | )
259 | )
260 |
261 |
262 | MASS_SPRING_COLORS = (
263 | ideal_mass_spring.IdealMassSpring,
264 | lambda: dict( # pylint:disable=g-long-lambda
265 | k_range=utils.BoxRegion(2.0, 2.0),
266 | m_range=utils.BoxRegion(0.2, 1.0),
267 | radius_range=utils.BoxRegion(0.1, 1.0),
268 | num_colors=6,
269 | )
270 | )
271 |
272 |
273 | MASS_SPRING_COLORS_FRICTION = (
274 | ideal_mass_spring.IdealMassSpring,
275 | lambda: dict( # pylint:disable=g-long-lambda
276 | k_range=utils.BoxRegion(2.0, 2.0),
277 | m_range=utils.BoxRegion(0.2, 1.0),
278 | radius_range=utils.BoxRegion(0.1, 1.0),
279 | num_colors=6,
280 | friction=0.05,
281 | ),
282 | )
283 |
284 |
285 | PENDULUM = (
286 | ideal_pendulum.IdealPendulum,
287 | lambda: dict( # pylint:disable=g-long-lambda
288 | m_range=utils.BoxRegion(0.5, 0.5),
289 | g_range=utils.BoxRegion(3.0, 3.0),
290 | l_range=utils.BoxRegion(1.0, 1.0),
291 | radius_range=utils.BoxRegion(1.3, 2.3),
292 | uniform_annulus=False,
293 | randomize_canvas_location=False,
294 | num_colors=1,
295 | )
296 | )
297 |
298 |
299 | PENDULUM_COLORS = (
300 | ideal_pendulum.IdealPendulum,
301 | lambda: dict( # pylint:disable=g-long-lambda
302 | m_range=utils.BoxRegion(0.5, 1.5),
303 | g_range=utils.BoxRegion(3.0, 4.0),
304 | l_range=utils.BoxRegion(.5, 1.0),
305 | radius_range=utils.BoxRegion(1.3, 2.3),
306 | num_colors=6,
307 | )
308 | )
309 |
310 |
311 | PENDULUM_COLORS_FRICTION = (
312 | ideal_pendulum.IdealPendulum,
313 | lambda: dict( # pylint:disable=g-long-lambda
314 | m_range=utils.BoxRegion(0.5, 1.5),
315 | g_range=utils.BoxRegion(3.0, 4.0),
316 | l_range=utils.BoxRegion(.5, 1.0),
317 | radius_range=utils.BoxRegion(1.3, 2.3),
318 | num_colors=6,
319 | friction=0.05,
320 | )
321 | )
322 |
323 |
324 | DOUBLE_PENDULUM = (
325 | ideal_double_pendulum.IdealDoublePendulum,
326 | lambda: dict( # pylint:disable=g-long-lambda
327 | m_range=utils.BoxRegion(0.5, 0.5),
328 | g_range=utils.BoxRegion(3.0, 3.0),
329 | l_range=utils.BoxRegion(1.0, 1.0),
330 | radius_range=utils.BoxRegion(1.3, 2.3),
331 | uniform_annulus=False,
332 | randomize_canvas_location=False,
333 | num_colors=2,
334 | )
335 | )
336 |
337 |
338 | DOUBLE_PENDULUM_COLORS = (
339 | ideal_double_pendulum.IdealDoublePendulum,
340 | lambda: dict( # pylint:disable=g-long-lambda
341 | m_range=utils.BoxRegion(0.4, 0.6),
342 | g_range=utils.BoxRegion(2.5, 4.0),
343 | l_range=utils.BoxRegion(0.75, 1.0),
344 | radius_range=utils.BoxRegion(1.0, 2.5),
345 | num_colors=6,
346 | )
347 | )
348 |
349 |
350 | DOUBLE_PENDULUM_COLORS_FRICTION = (
351 | ideal_double_pendulum.IdealDoublePendulum,
352 | lambda: dict( # pylint:disable=g-long-lambda
353 | m_range=utils.BoxRegion(0.4, 0.6),
354 | g_range=utils.BoxRegion(2.5, 4.0),
355 | l_range=utils.BoxRegion(0.75, 1.0),
356 | radius_range=utils.BoxRegion(1.0, 2.5),
357 | num_colors=6,
358 | friction=0.05
359 | ),
360 | )
361 |
362 |
363 | TWO_BODY = (
364 | n_body.TwoBodySystem,
365 | lambda: dict( # pylint:disable=g-long-lambda
366 | m_range=utils.BoxRegion(1.0, 1.0),
367 | g_range=utils.BoxRegion(1.0, 1.0),
368 | radius_range=utils.BoxRegion(0.5, 1.5),
369 | provided_canvas_bounds=utils.BoxRegion(-2.75, 2.75),
370 | randomize_canvas_location=False,
371 | num_colors=2,
372 | )
373 | )
374 |
375 |
376 | TWO_BODY_COLORS = (
377 | n_body.TwoBodySystem,
378 | lambda: dict( # pylint:disable=g-long-lambda
379 | m_range=utils.BoxRegion(0.5, 1.5),
380 | g_range=utils.BoxRegion(0.5, 1.5),
381 | radius_range=utils.BoxRegion(0.5, 1.5),
382 | provided_canvas_bounds=utils.BoxRegion(-5.0, 5.0),
383 | randomize_canvas_location=False,
384 | num_colors=6,
385 | )
386 | )
387 |
388 |
389 | def no_open_spiel_func(*_, **__):
390 | raise ValueError("You must download and install `open_spiel` first in "
391 | "order to use the game_dynamics datasets. See "
392 | "https://github.com/deepmind/open_spiel for instructions"
393 | " how to do this.")
394 |
395 | if not open_spiel_available():
396 | MATCHING_PENNIES = (no_open_spiel_func, dict)
397 | ROCK_PAPER_SCISSORS = (no_open_spiel_func, dict)
398 | else:
399 | MATCHING_PENNIES = (game_dynamics.ZeroSumGame,
400 | lambda: dict(game_name="matrix_mp"))
401 | ROCK_PAPER_SCISSORS = (game_dynamics.ZeroSumGame,
402 | lambda: dict(game_name="matrix_rps"))
403 |
--------------------------------------------------------------------------------
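
Each entry in datasets.py above pairs a system class with a zero-argument factory that returns its constructor keyword arguments. Below is a minimal sketch, not part of the library, of how such a pair can be used directly, assuming the suite is installed; `generate_and_render_dt` is defined in hamiltonian.py further down, and the printed shapes depend on the chosen config.

    # Illustrative only; uses the PENDULUM entry defined above.
    import jax
    from dm_hamiltonian_dynamics_suite import datasets

    system_cls, config_factory = datasets.PENDULUM
    system = system_cls(**config_factory())
    batch = system.generate_and_render_dt(
        num_trajectories=4,
        rng_key=jax.random.PRNGKey(0),
        t0=0.0,
        dt=0.1,
        num_steps=50,
    )
    print(batch["x"].shape, batch["image"].shape)
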
/dm_hamiltonian_dynamics_suite/generate_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Script to generate datasets from the configs in datasets.py."""
16 | from absl import app
17 | from absl import flags
18 |
19 | from dm_hamiltonian_dynamics_suite import datasets
20 | from dm_hamiltonian_dynamics_suite.molecular_dynamics import generate_dataset
21 |
22 | import jax
23 |
24 | flags.DEFINE_string("folder", None,
25 | "The folder where to store the datasets.")
26 | flags.DEFINE_string("dataset", None,
27 |                     "The dataset from datasets.py to use, or "
28 |                     "'molecular_dynamics' for the LAMMPS-based data.")
29 | flags.DEFINE_string("lammps_file", None,
30 | "For dataset='molecular_dynamics' this should be the "
31 | "LAMMPS trajectory file containing a sequence of timesteps "
32 | "obtained from a MD simulation.")
33 | flags.DEFINE_float("dt", None, "The delta time between two observations.")
34 | flags.DEFINE_integer("num_steps", None, "The number of steps to simulate.")
35 | flags.DEFINE_integer("steps_per_dt", 10,
36 |                      "How many internal integration steps to take per single "
37 |                      "observation step.")
38 | flags.DEFINE_integer("num_train", None,
39 | "The number of training examples to generate.")
40 | flags.DEFINE_integer("num_test", None,
41 | "The number of test examples to generate.")
42 | flags.DEFINE_boolean("overwrite", False, "Overwrites previous data.")
43 |
44 | flags.mark_flag_as_required("folder")
45 | flags.mark_flag_as_required("dataset")
46 | flags.mark_flag_as_required("dt")
47 | flags.mark_flag_as_required("num_steps")
48 | flags.mark_flag_as_required("num_train")
49 | flags.mark_flag_as_required("num_test")
50 | FLAGS = flags.FLAGS
51 |
52 |
53 | def main(argv):
54 | if len(argv) > 1:
55 | raise ValueError(f"Unexpected args: {argv[1:]}")
56 | if FLAGS.dataset == "molecular_dynamics":
57 | generate_dataset.generate_lammps_dataset(
58 | lammps_file=FLAGS.lammps_file,
59 |         folder=FLAGS.folder,
60 | dt=FLAGS.dt,
61 | num_steps=FLAGS.num_steps,
62 | num_train=FLAGS.num_train,
63 | num_test=FLAGS.num_test,
64 | shuffle=FLAGS.shuffle,
65 | seed=FLAGS.seed,
66 | overwrite=FLAGS.overwrite,
67 | )
68 | else:
69 | datasets.generate_full_dataset(
70 | folder=FLAGS.folder,
71 | dataset=FLAGS.dataset,
72 | dt=FLAGS.dt,
73 | num_steps=FLAGS.num_steps,
74 | steps_per_dt=FLAGS.steps_per_dt,
75 | num_train=FLAGS.num_train,
76 | num_test=FLAGS.num_test,
77 | overwrite=FLAGS.overwrite
78 | )
79 |
80 |
81 | if __name__ == "__main__":
82 | jax.config.update("jax_enable_x64", True)
83 | app.run(main)
84 |
--------------------------------------------------------------------------------
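
The script above is a thin absl wrapper around datasets.generate_full_dataset. A sketch of the equivalent direct call follows; the argument names mirror the invocation in main(), while the dataset name "pendulum" is an assumption and must match a config that datasets.py recognises.

    import jax
    from dm_hamiltonian_dynamics_suite import datasets

    jax.config.update("jax_enable_x64", True)  # same as the __main__ block above
    datasets.generate_full_dataset(
        folder="/tmp/hamiltonian_ml/datasets",
        dataset="pendulum",  # assumption: a dataset name accepted by datasets.py
        dt=0.1,
        num_steps=256,
        steps_per_dt=10,
        num_train=500,
        num_test=200,
        overwrite=True,
    )
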
/dm_hamiltonian_dynamics_suite/generate_locally.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2020 DeepMind Technologies Limited.
3 | #
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | if [[ "$#" -eq 1 ]]; then
18 | readonly DATASET="${1}"
19 | else
20 | echo "Illegal number of parameters. Expected only the name of the dataset."
21 | exit 2
22 | fi
23 |
24 | readonly FOLDER="/tmp/hamiltonian_ml/datasets"
25 | readonly DTS=(0.05 0.1)
26 | readonly NUM_STEPS=255
27 | readonly NUM_TRAIN=500
28 | readonly NUM_TEST=200
29 |
30 | for DT in "${DTS[@]}"; do
31 | python3 generate_dataset.py \
32 | --folder=${FOLDER} \
33 | --dataset="${DATASET}" \
34 | --dt="${DT}" \
35 | --num_steps=${NUM_STEPS} \
36 | --num_train=${NUM_TRAIN} \
37 | --num_test=${NUM_TEST} \
38 | --overwrite=true
39 | done
40 |
--------------------------------------------------------------------------------
/dm_hamiltonian_dynamics_suite/hamiltonian_systems/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/dm_hamiltonian_dynamics_suite/hamiltonian_systems/hamiltonian.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """A module with the abstract class for Hamiltonian systems."""
16 | import abc
17 | from typing import Any, Callable, Mapping, Optional, Tuple, Union
18 |
19 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import phase_space
20 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils
21 |
22 | import jax
23 | import jax.numpy as jnp
24 | import jax.random as jnr
25 | from scipy import integrate
26 |
27 | Integrator = Callable[
28 | [
29 | Union[phase_space.HamiltonianFunction,
30 | phase_space.SymplecticTangentFunction], # dy_dt
31 | Union[jnp.ndarray, phase_space.PhaseSpace], # y0
32 | Union[float, jnp.ndarray], # t0
33 | Union[float, jnp.ndarray], # dt
34 | int, # num_steps
35 | int # steps_per_dt
36 | ],
37 | Tuple[jnp.ndarray, phase_space.PhaseSpace]
38 | ]
39 |
40 |
41 | class HamiltonianSystem(abc.ABC):
42 | """General class to represent Hamiltonian Systems and simulate them."""
43 |
44 | def __init__(
45 | self,
46 | system_dims: int,
47 | randomize_canvas_location: bool = True,
48 | random_canvas_extra_ratio: float = 0.4,
49 | try_analytic_solution: bool = True,
50 | friction: float = 0.0,
51 | method: Union[str, Integrator] = "scipy",
52 | num_colors: int = 6,
53 | image_resolution: int = 32,
54 | dtype: str = "float32",
55 | steps_per_dt: int = 1,
56 | stiff: bool = False,
57 | extra_ivp_kwargs: Optional[Mapping[str, Union[str, int, float]]] = None):
58 | """Initializes some global properties.
59 |
60 | Args:
61 | system_dims: Dimensionality of the positions (not joint state).
62 |       randomize_canvas_location: Whether to add a random offset to images.
63 |       random_canvas_extra_ratio: Ratio of the extra random offset to canvas size.
64 |       try_analytic_solution: If True, first tries to solve the system
65 |         analytically when possible. If False, always integrates numerically.
66 | friction: This changes the dynamics to non-conservative. The new
67 | dynamics are formulated as follows:
68 | dq/dt = dH/dp
69 | dp/dt = - dH/dq - friction * dH/dp
70 | This implies that the Hamiltonian energy decreases over time:
71 | dH/dt = dH/dq^T dq/dt + dH/dp^T dp/dt = - friction * ||dH/dp||_2^2
72 | method: "scipy" or a callable of type `Integrator`.
73 | num_colors: The number of possible colors to use for rendering.
74 |       image_resolution: The resolution of the generated images.
75 |       dtype: What dtype to use for the generated data.
76 |       steps_per_dt: Number of inner steps to use per single observation dt.
77 | stiff: Whether the problem represents a stiff system.
78 | extra_ivp_kwargs: Extra arguments to the scipy solver.
79 | Raises:
80 | ValueError: if `dtype` is not 'float32' or 'float64'.
81 | """
82 | self._system_dims = system_dims
83 | self._randomize_canvas_location = randomize_canvas_location
84 | self._random_canvas_extra_ratio = random_canvas_extra_ratio
85 | self._try_analytic_solution = try_analytic_solution
86 | self._friction = friction
87 | self._method = method
88 | self._num_colors = num_colors
89 | self._resolution = image_resolution
90 | self._dtype = dtype
91 | self._stiff = stiff
92 | self._steps_per_dt = steps_per_dt
93 | if dtype == "float64":
94 | self._scipy_ivp_kwargs = dict(rtol=1e-12, atol=1e-12)
95 | elif dtype == "float32":
96 | self._scipy_ivp_kwargs = dict(rtol=1e-9, atol=1e-9)
97 | else:
98 | raise ValueError("Currently we only support float64 and float32 dtypes.")
99 | if stiff:
100 | self._scipy_ivp_kwargs["method"] = "Radau"
101 | if extra_ivp_kwargs is not None:
102 | self._scipy_ivp_kwargs.update(extra_ivp_kwargs)
103 |
104 | @property
105 | def system_dims(self) -> int:
106 | return self._system_dims
107 |
108 | @property
109 | def randomize_canvas_location(self) -> bool:
110 | return self._randomize_canvas_location
111 |
112 | @property
113 | def random_canvas_extra_ratio(self) -> float:
114 | return self._random_canvas_extra_ratio
115 |
116 | @property
117 | def try_analytic_solution(self) -> bool:
118 | return self._try_analytic_solution
119 |
120 | @property
121 | def friction(self) -> float:
122 | return self._friction
123 |
124 | @property
125 | def method(self):
126 | return self._method
127 |
128 | @property
129 | def num_colors(self) -> int:
130 | return self._num_colors
131 |
132 | @property
133 | def resolution(self) -> int:
134 | return self._resolution
135 |
136 | @property
137 | def dtype(self) -> str:
138 | return self._dtype
139 |
140 | @property
141 | def stiff(self) -> bool:
142 | return self._stiff
143 |
144 | @property
145 | def steps_per_dt(self) -> int:
146 | return self._steps_per_dt
147 |
148 | @property
149 | def scipy_ivp_kwargs(self) -> Mapping[str, Union[str, int, float]]:
150 | return self._scipy_ivp_kwargs
151 |
152 | @abc.abstractmethod
153 | def parametrized_hamiltonian(
154 | self,
155 | t: jnp.ndarray,
156 | y: phase_space.PhaseSpace,
157 | params: utils.Params,
158 | **kwargs: Any
159 | ) -> jnp.ndarray:
160 | """Calculates the Hamiltonian."""
161 |
162 | def hamiltonian_from_params(
163 | self,
164 | params: utils.Params,
165 | **kwargs: Any
166 | ) -> phase_space.HamiltonianFunction:
167 | def hamiltonian(t: jnp.ndarray, y: phase_space.PhaseSpace) -> jnp.ndarray:
168 | return self.parametrized_hamiltonian(t, y, params, **kwargs)
169 | return hamiltonian
170 |
171 | @abc.abstractmethod
172 | def sample_y(
173 | self,
174 | num_samples: int,
175 | params: utils.Params,
176 | rng_key: jnp.ndarray,
177 | **kwargs: Any
178 | ) -> phase_space.PhaseSpace:
179 |     """Randomly samples initial states."""
180 |
181 | @abc.abstractmethod
182 | def sample_params(
183 | self,
184 | num_samples: int,
185 | rng_key: jnp.ndarray,
186 | **kwargs: Any
187 | ) -> utils.Params:
188 |     """Randomly samples parameters."""
189 |
190 | @abc.abstractmethod
191 | def simulate_analytically(
192 | self,
193 | y0: phase_space.PhaseSpace,
194 | t0: utils.FloatArray,
195 | t_eval: jnp.ndarray,
196 | params: utils.Params,
197 | **kwargs: Any
198 | ) -> Optional[phase_space.PhaseSpace]:
199 |     """If an analytic solution exists, returns it; otherwise returns None."""
200 |
201 | def simulate_analytically_dt(
202 | self,
203 | y0: phase_space.PhaseSpace,
204 | t0: utils.FloatArray,
205 | dt: utils.FloatArray,
206 | num_steps: int,
207 | params: utils.Params,
208 | **kwargs: Any
209 | ) -> Optional[phase_space.PhaseSpace]:
210 | """Same as `simulate_analytically` but uses `dt` and `num_steps`."""
211 | t_eval = utils.dt_to_t_eval(t0, dt, num_steps)
212 | return self.simulate_analytically(y0, t0, t_eval, params, **kwargs)
213 |
214 | @abc.abstractmethod
215 | def canvas_bounds(self) -> utils.BoxRegion:
216 | """Returns the limits of the canvas for rendering."""
217 |
218 | def random_offset_bounds(self) -> utils.BoxRegion:
219 | """Returns any extra randomized offset that can be given to the canvas."""
220 | extra_size = self.random_canvas_extra_ratio * self.canvas_bounds().size / 2
221 | return utils.BoxRegion(
222 | minimum=-extra_size,
223 | maximum=extra_size
224 | )
225 |
226 | def full_canvas_bounds(self) -> utils.BoxRegion:
227 | if self.randomize_canvas_location:
228 | return utils.BoxRegion(
229 | self.canvas_bounds().min + self.random_offset_bounds().min,
230 | self.canvas_bounds().max + self.random_offset_bounds().max,
231 | )
232 | else:
233 | return self.canvas_bounds()
234 |
235 | @abc.abstractmethod
236 | def canvas_position(
237 | self,
238 | position: jnp.ndarray,
239 | params: utils.Params
240 | ) -> jnp.ndarray:
241 | """Returns the canvas position given the position vectors and the parameters."""
242 |
243 | @abc.abstractmethod
244 | def render_trajectories(
245 | self,
246 | position: jnp.ndarray,
247 | params: utils.Params,
248 | rng_key: jnp.ndarray,
249 | **kwargs: Any
250 | ) -> Tuple[jnp.ndarray, utils.Params]:
251 | """Renders the positions q into an image."""
252 |
253 | def simulate_scipy(
254 | self,
255 | y0: phase_space.PhaseSpace,
256 | t0: utils.FloatArray,
257 | t_eval: jnp.ndarray,
258 | params: utils.Params,
259 | ivp_kwargs=None,
260 | **kwargs: Any
261 | ) -> phase_space.PhaseSpace:
262 | """Simulates the system using scipy.integrate.solve_ivp."""
263 | t_span = (t0, float(t_eval[-1]))
264 | y0 = jnp.concatenate([y0.q, y0.p], axis=-1)
265 | y_shape = y0.shape
266 | y0 = y0.reshape([-1])
267 | hamiltonian = self.hamiltonian_from_params(params, **kwargs)
268 | @jax.jit
269 | def fun(t, y):
270 | f = phase_space.poisson_bracket_with_q_and_p(hamiltonian)
271 | dy = f(t, phase_space.PhaseSpace.from_state(y.reshape(y_shape)))
272 | if self.friction != 0.0:
273 | friction_term = phase_space.TangentPhaseSpace(
274 | position=jnp.zeros_like(dy.position),
275 | momentum=-self.friction * dy.position)
276 | dy = dy + friction_term
277 | return dy.single_state.reshape([-1])
278 | kwargs = dict(**self.scipy_ivp_kwargs)
279 | kwargs.update(ivp_kwargs or dict())
280 | solution = integrate.solve_ivp(
281 | fun=fun,
282 | t_span=t_span,
283 | y0=y0,
284 | t_eval=t_eval,
285 | **kwargs)
286 | y_final = solution.y.reshape(y_shape + (t_eval.size,))
287 | return phase_space.PhaseSpace.from_state(jnp.moveaxis(y_final, -1, 0))
288 |
289 | def simulate_scipy_dt(
290 | self,
291 | y0: phase_space.PhaseSpace,
292 | t0: utils.FloatArray,
293 | dt: utils.FloatArray,
294 | num_steps: int,
295 | params: utils.Params,
296 | ivp_kwargs=None,
297 | **kwargs: Any
298 | ) -> phase_space.PhaseSpace:
299 | """Same as `simulate_scipy` but uses `dt` and `num_steps`."""
300 | t_eval = utils.dt_to_t_eval(t0, dt, num_steps)
301 | return self.simulate_scipy(y0, t0, t_eval, params, ivp_kwargs, **kwargs)
302 |
303 | def simulate_integrator(
304 | self,
305 | y0: phase_space.PhaseSpace,
306 | t0: utils.FloatArray,
307 | t_eval: jnp.ndarray,
308 | params: utils.Params,
309 | method: Union[str, Integrator],
310 | **kwargs: Any
311 | ) -> phase_space.PhaseSpace:
312 | """Simulates the system using an integrator from integrators.py module."""
313 | return self.simulate_integrator_dt(
314 | y0=y0,
315 | t0=t0,
316 | dt=utils.t_eval_to_dt(t0, t_eval),
317 | params=params,
318 | method=method,
319 | **kwargs
320 | )
321 |
322 | def simulate_integrator_dt(
323 | self,
324 | y0: phase_space.PhaseSpace,
325 | t0: Union[float, jnp.ndarray],
326 | dt: Union[float, jnp.ndarray],
327 | params: utils.Params,
328 | method: Integrator,
329 | num_steps: Optional[int] = None,
330 | **kwargs: Any
331 | ) -> phase_space.PhaseSpace:
332 | """Same as `simulate_integrator` but uses `dt` and `num_steps`."""
333 | hamiltonian = self.hamiltonian_from_params(params, **kwargs)
334 | if self.friction == 0.0:
335 | return method(
336 | hamiltonian,
337 | y0,
338 | t0,
339 | dt,
340 | num_steps,
341 | self.steps_per_dt,
342 | )[1]
343 | else:
344 | def dy_dt(t: jnp.ndarray, y: phase_space.PhaseSpace):
345 | f = phase_space.poisson_bracket_with_q_and_p(hamiltonian)
346 | dy = f(t, y)
347 | friction_term = phase_space.TangentPhaseSpace(
348 | position=jnp.zeros_like(dy.position),
349 | momentum=-self.friction * dy.position)
350 | return dy + friction_term
351 |
352 | return method(
353 | dy_dt,
354 | y0,
355 | t0,
356 | dt,
357 | num_steps,
358 | self.steps_per_dt,
359 | )[1]
360 |
361 | def generate_trajectories(
362 | self,
363 | y0: phase_space.PhaseSpace,
364 | t0: utils.FloatArray,
365 | t_eval: jnp.ndarray,
366 | params: utils.Params,
367 | **kwargs: Any
368 | ) -> phase_space.PhaseSpace:
369 | """Generates trajectories of the system in phase space.
370 |
371 | Args:
372 | y0: Initial state.
373 | t0: The time instance of the initial state y0.
374 | t_eval: Times at which to return the computed solution.
375 | params: Any parameters of the Hamiltonian.
376 | **kwargs: Any extra things that go into the hamiltonian.
377 |
378 | Returns:
379 | A phase_space.PhaseSpace instance of size NxTxD.
380 | """
381 | return self.generate_trajectories_dt(
382 | y0=y0,
383 | t0=t0,
384 | dt=utils.t_eval_to_dt(t0, t_eval),
385 | params=params,
386 | **kwargs
387 | )
388 |
389 | def generate_trajectories_dt(
390 | self,
391 | y0: phase_space.PhaseSpace,
392 | t0: utils.FloatArray,
393 | dt: utils.FloatArray,
394 | params: utils.Params,
395 | num_steps_forward: int,
396 | include_t0: bool = False,
397 | num_steps_backward: int = 0,
398 | **kwargs: Any
399 | ) -> phase_space.PhaseSpace:
400 | """Same as `generate_trajectories` but uses `dt` and `num_steps`."""
401 | if num_steps_forward < 0 or num_steps_backward < 0:
402 |       raise ValueError("num_steps_forward and num_steps_backward cannot be "
403 |                        "negative.")
404 |     if num_steps_forward == 0 and num_steps_backward == 0:
405 |       raise ValueError("You need one of num_steps_forward or "
406 |                        "num_steps_backward to be positive.")
407 | if num_steps_forward > 0 and num_steps_backward > 0 and not include_t0:
408 | raise ValueError("When both num_steps_forward and num_steps_backward are "
409 | "positive include_t0 should be True.")
410 |
411 | if self.try_analytic_solution and num_steps_backward == 0:
412 | # Try to use analytical solution
413 | y = self.simulate_analytically_dt(y0, t0, dt, num_steps_forward, params,
414 | **kwargs)
415 | if y is not None:
416 | return y
417 | if self.method == "scipy":
418 | if num_steps_backward > 0:
419 | raise NotImplementedError()
420 | return self.simulate_scipy_dt(y0, t0, dt, num_steps_forward, params,
421 | **kwargs)
422 | yts = []
423 | if num_steps_backward > 0:
424 | yt = self.simulate_integrator_dt(
425 | y0=y0,
426 | t0=t0,
427 | dt=-dt,
428 | params=params,
429 | method=self.method,
430 | num_steps=num_steps_backward,
431 | **kwargs)
432 | yt = jax.tree_map(lambda x: jnp.flip(x, axis=0), yt)
433 | yts.append(yt)
434 | if include_t0:
435 | yts.append(jax.tree_map(lambda x: x[None], y0))
436 | if num_steps_forward > 0:
437 | yt = self.simulate_integrator_dt(
438 | y0=y0,
439 | t0=t0,
440 | dt=dt,
441 | params=params,
442 | method=self.method,
443 | num_steps=num_steps_forward,
444 | **kwargs)
445 | yts.append(yt)
446 | if len(yts) > 1:
447 | return jax.tree_multimap(lambda *a: jnp.concatenate(a, axis=0), *yts)
448 | else:
449 | return yts[0]
450 |
451 | def generate_and_render(
452 | self,
453 | num_trajectories: int,
454 | rng_key: jnp.ndarray,
455 | t0: utils.FloatArray,
456 | t_eval: utils.FloatArray,
457 | y0: Optional[phase_space.PhaseSpace] = None,
458 | params: Optional[utils.Params] = None,
459 | within_canvas_bounds: bool = True,
460 | **kwargs: Any
461 | ) -> Mapping[str, Any]:
462 | """Generates trajectories and renders them.
463 |
464 | Args:
465 | num_trajectories: The number of trajectories to generate.
466 | rng_key: PRNG key for sampling any random numbers.
467 | t0: The time instance of the initial state y0.
468 | t_eval: Times at which to return the computed solution.
469 |       y0: Initial state. If None, it will be sampled with `self.sample_y`.
470 |       params: Parameters of the Hamiltonian. If None, they will be sampled
471 |         with `self.sample_params`.
472 |       within_canvas_bounds: Re-samples y0 until the trajectories are within
473 |         the canvas bounds.
474 | **kwargs: Any extra things that go into the hamiltonian.
475 |
476 | Returns:
477 | A dictionary containing the following elements:
478 | "x": A numpy array representation of the PhaseSpace vector.
479 | "dx_dt": The time derivative of "x".
480 | "image": An image representation of the state.
481 | "other": A dict of other parameters of the system that are not part of
482 | the state.
483 | """
484 | return self.generate_and_render_dt(
485 | num_trajectories=num_trajectories,
486 | rng_key=rng_key,
487 | t0=t0,
488 | dt=utils.t_eval_to_dt(t0, t_eval),
489 | y0=y0,
490 | params=params,
491 | within_canvas_bounds=within_canvas_bounds,
492 | **kwargs
493 | )
494 |
495 | def generate_and_render_dt(
496 | self,
497 | num_trajectories: int,
498 | rng_key: jnp.ndarray,
499 | t0: utils.FloatArray,
500 | dt: utils.FloatArray,
501 | num_steps: Optional[int] = None,
502 | y0: Optional[phase_space.PhaseSpace] = None,
503 | params: Optional[utils.Params] = None,
504 | within_canvas_bounds: bool = True,
505 | **kwargs: Any
506 | ) -> Mapping[str, Any]:
507 | """Same as `generate_and_render` but uses `dt` and `num_steps`."""
508 | if within_canvas_bounds and (y0 is not None or params is not None):
509 | raise ValueError("Within canvas bounds is valid only when y0 and params "
510 | "are None.")
511 | if params is None:
512 | rng_key, key = jnr.split(rng_key)
513 | params = self.sample_params(num_trajectories, rng_key, **kwargs)
514 | if y0 is None:
515 | rng_key, key = jnr.split(rng_key)
516 | y0 = self.sample_y(num_trajectories, params, key, **kwargs)
517 |
518 | # Generate the phase-space trajectories
519 | x = self.generate_trajectories_dt(y0, t0, dt, params, num_steps, **kwargs)
520 | # Make batch leading dimension
521 | x = jax.tree_map(lambda x_: jnp.swapaxes(x_, 0, 1), x)
522 | x = jax.tree_multimap(lambda i, j: jnp.concatenate([i[:, None], j], axis=1),
523 | y0, x)
524 | if within_canvas_bounds:
525 | # Check for valid trajectories
526 | valid = []
527 | while len(valid) < num_trajectories:
528 | for idx in range(x.q.shape[0]):
529 | x_idx, params_idx = jax.tree_map(lambda a, i=idx: a[i], (x, params))
530 | position = self.canvas_position(x_idx.q, params_idx)
531 | if (jnp.all(position >= self.canvas_bounds().min) and
532 | jnp.all(position <= self.canvas_bounds().max)):
533 | valid.append((x_idx, params_idx))
534 | if len(valid) == num_trajectories:
535 | break
536 | new_trajectories = num_trajectories - len(valid)
537 | print(f"Generating {new_trajectories} new trajectories.")
538 | rng_key, key = jnr.split(rng_key)
539 | params = self.sample_params(new_trajectories, rng_key, **kwargs)
540 | rng_key, key = jnr.split(rng_key)
541 | y0 = self.sample_y(new_trajectories, params, key, **kwargs)
542 | x = self.generate_trajectories_dt(y0, t0, dt, params, num_steps,
543 | **kwargs)
544 | x = jax.tree_map(lambda x_: jnp.swapaxes(x_, 0, 1), x)
545 | x = jax.tree_multimap(lambda i, j: # pylint:disable=g-long-lambda
546 | jnp.concatenate([i[:, None], j], axis=1), y0, x)
547 | x, params = jax.tree_multimap(lambda *args: jnp.stack(args, axis=0),
548 | *valid)
549 |
550 | hamiltonian = self.hamiltonian_from_params(params, **kwargs)
551 | df_dt = jax.vmap(phase_space.poisson_bracket_with_q_and_p(hamiltonian),
552 | in_axes=[0, 1], out_axes=1)
553 | if isinstance(dt, float):
554 | dt = jnp.asarray([dt] * num_steps, dtype=x.q.dtype)
555 | t0 = jnp.asarray(t0).astype(dt.dtype)
556 | t = jnp.cumsum(jnp.concatenate([t0[None], dt], axis=0), axis=0)
557 | dx_dt = df_dt(t, x)
558 | rng_key, key = jnr.split(rng_key)
559 | image, extra = self.render_trajectories(x.q, params, rng_key, **kwargs)
560 | params.update(extra)
561 | return dict(x=x.single_state, dx_dt=dx_dt.single_state,
562 | image=image, other=params)
563 |
564 |
565 | class TimeIndependentHamiltonianSystem(HamiltonianSystem):
566 | """A Hamiltonian system where the energy does not depend on time."""
567 |
568 | @abc.abstractmethod
569 | def _hamiltonian(
570 | self,
571 | y: phase_space.PhaseSpace,
572 | params: utils.Params,
573 | **kwargs: Any
574 | ) -> jnp.ndarray:
575 | """Computes the time independent Hamiltonian."""
576 |
577 | def parametrized_hamiltonian(
578 | self,
579 | t: jnp.ndarray,
580 | y: phase_space.PhaseSpace,
581 | params: utils.Params,
582 | **kwargs: Any
583 | ) -> jnp.ndarray:
584 | return self._hamiltonian(y, params, **kwargs)
585 |
--------------------------------------------------------------------------------
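
The friction formulation in the HamiltonianSystem docstring implies dH/dt = -friction * ||dH/dp||^2. Below is a self-contained sketch, using a toy quadratic Hamiltonian rather than the suite itself, that checks this identity numerically with JAX.

    import jax
    import jax.numpy as jnp

    friction = 0.05

    def hamiltonian(q, p):
      return 0.5 * jnp.sum(q ** 2) + 0.5 * jnp.sum(p ** 2)

    dh_dq = jax.grad(hamiltonian, argnums=0)
    dh_dp = jax.grad(hamiltonian, argnums=1)

    q, p = jnp.array([1.0]), jnp.array([0.5])
    dq_dt = dh_dp(q, p)                             # dq/dt = dH/dp
    dp_dt = -dh_dq(q, p) - friction * dh_dp(q, p)   # dp/dt = -dH/dq - friction * dH/dp
    # Chain rule: dH/dt = dH/dq . dq/dt + dH/dp . dp/dt
    dh_dt = jnp.vdot(dh_dq(q, p), dq_dt) + jnp.vdot(dh_dp(q, p), dp_dt)
    print(dh_dt, -friction * jnp.sum(dh_dp(q, p) ** 2))  # both evaluate to -0.0125
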
/dm_hamiltonian_dynamics_suite/hamiltonian_systems/ideal_double_pendulum.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Ideal double pendulum."""
16 | import functools
17 | from typing import Any, Optional, Tuple
18 |
19 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import hamiltonian
20 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import phase_space
21 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils
22 |
23 | import jax
24 | import jax.numpy as jnp
25 |
26 |
27 | class IdealDoublePendulum(hamiltonian.TimeIndependentHamiltonianSystem):
28 | """An idealized double pendulum system.
29 |
30 | Parameters:
31 | m_range - possible range of the particle mass
32 | g_range - possible range of the gravitational force
33 | l_range - possible range of the length of the pendulum
34 |
35 | The Hamiltonian is:
36 | H = [m_2 * l_2^2 * p_1^2 + (m_1 + m_2) * l_1^2 * p_2^2 - 2 * m_2 * l_1 *
37 | l_2 * p_1 * p_2 * cos(q_1 - q_2)] /
38 |       [2 * m_2 * l_1^2 * l_2^2 * (m_1 + m_2 * sin(q_1 - q_2)^2)]
39 | - (m_1 + m_2) * g * l_1 * cos(q_1) - m_2 * g * l_2 * cos(q_2)
40 |
41 | See https://iopscience.iop.org/article/10.1088/1742-6596/739/1/012066/meta
42 |
43 | Initial state parameters:
44 | radius_range - The initial state is sampled from a disk in phase space with
45 | radius in this range.
46 |     uniform_annulus - Whether to sample uniformly on the disk or uniformly over
47 |       the radius.
48 |     randomize_canvas_location - Whether to randomize the canvas position of the
49 |       particles when rendering.
50 | """
51 |
52 | def __init__(
53 | self,
54 | m_range: utils.BoxRegion,
55 | g_range: utils.BoxRegion,
56 | l_range: utils.BoxRegion,
57 | radius_range: utils.BoxRegion,
58 | uniform_annulus: bool = True,
59 | **kwargs):
60 | super().__init__(system_dims=2, **kwargs)
61 | self.m_range = m_range
62 | self.g_range = g_range
63 | self.l_range = l_range
64 | self.radius_range = radius_range
65 | self.uniform_annulus = uniform_annulus
66 | render = functools.partial(
67 | utils.render_particles_trajectory,
68 | canvas_limits=self.full_canvas_bounds(),
69 | resolution=self.resolution,
70 | num_colors=self.num_colors)
71 | self._batch_render = jax.vmap(render)
72 |
73 | def _hamiltonian(
74 | self,
75 | y: phase_space.PhaseSpace,
76 | params: utils.Params,
77 | **kwargs: Any
78 | ) -> jnp.ndarray:
79 | assert len(params) == 5
80 | m_1 = params["m_1"]
81 | l_1 = params["l_1"]
82 | m_2 = params["m_2"]
83 | l_2 = params["l_2"]
84 | g = params["g"]
85 |
86 | q_1, q_2 = y.q[..., 0], y.q[..., 1]
87 | p_1, p_2 = y.p[..., 0], y.p[..., 1]
88 |
89 | a_1 = m_2 * l_2 ** 2 * p_1 ** 2
90 | a_2 = (m_1 + m_2) * l_1 ** 2 * p_2 ** 2
91 | a_3 = 2 * m_2 * l_1 * l_2 * p_1 * p_2 * jnp.cos(q_1 - q_2)
92 | b_1 = 2 * m_2 * l_1 ** 2 * l_2 ** 2
93 | b_2 = (m_1 + m_2 * jnp.sin(q_1 - q_2) ** 2)
94 | c_1 = (m_1 + m_2) * g * l_1 * jnp.cos(q_1)
95 | c_2 = m_2 * g * l_2 * jnp.cos(q_2)
96 | return (a_1 + a_2 - a_3) / (b_1 * b_2) - c_1 - c_2
97 |
98 | def sample_y(
99 | self,
100 | num_samples: int,
101 | params: utils.Params,
102 | rng_key: jnp.ndarray,
103 | **kwargs: Any
104 | ) -> phase_space.PhaseSpace:
105 | key1, key2 = jax.random.split(rng_key)
106 | state_1 = utils.uniform_annulus(
107 | key1, num_samples, 2, self.radius_range, self.uniform_annulus)
108 | state_2 = utils.uniform_annulus(
109 | key2, num_samples, 2, self.radius_range, self.uniform_annulus)
110 | state = jnp.stack([state_1[..., 0], state_2[..., 0],
111 | state_1[..., 1], state_2[..., 1]], axis=-1)
112 | return phase_space.PhaseSpace.from_state(state.astype(self.dtype))
113 |
114 | def sample_params(
115 | self,
116 | num_samples: int,
117 | rng_key: jnp.ndarray,
118 | **kwargs: Any
119 | ) -> utils.Params:
120 | keys = jax.random.split(rng_key, 5)
121 | m_1 = jax.random.uniform(keys[0], [num_samples], minval=self.m_range.min,
122 | maxval=self.m_range.max)
123 | m_2 = jax.random.uniform(keys[1], [num_samples], minval=self.m_range.min,
124 | maxval=self.m_range.max)
125 | l_1 = jax.random.uniform(keys[2], [num_samples], minval=self.l_range.min,
126 | maxval=self.l_range.max)
127 | l_2 = jax.random.uniform(keys[3], [num_samples], minval=self.l_range.min,
128 | maxval=self.l_range.max)
129 | g = jax.random.uniform(keys[4], [num_samples], minval=self.g_range.min,
130 | maxval=self.g_range.max)
131 | return dict(m_1=m_1, m_2=m_2, l_1=l_1, l_2=l_2, g=g)
132 |
133 | def simulate_analytically(
134 | self,
135 | y0: phase_space.PhaseSpace,
136 | t0: utils.FloatArray,
137 | t_eval: jnp.ndarray,
138 | params: utils.Params,
139 | **kwargs: Any
140 | ) -> Optional[phase_space.PhaseSpace]:
141 | return None
142 |
143 | def canvas_bounds(self) -> utils.BoxRegion:
144 | max_d = 2 * self.l_range.max + jnp.sqrt(self.m_range.max / jnp.pi)
145 | return utils.BoxRegion(-max_d, max_d)
146 |
147 | def canvas_position(
148 | self,
149 | position: jnp.ndarray,
150 | params: utils.Params
151 | ) -> jnp.ndarray:
152 | l_1 = utils.expand_to_rank_right(params["l_1"], 2)
153 | l_2 = utils.expand_to_rank_right(params["l_2"], 2)
154 | y_1 = jnp.sin(position[..., 0] - jnp.pi / 2.0) * l_1
155 | x_1 = jnp.cos(position[..., 0] - jnp.pi / 2.0) * l_1
156 | position_1 = jnp.stack([x_1, y_1], axis=-1)
157 | y_2 = jnp.sin(position[..., 1] - jnp.pi / 2.0) * l_2
158 | x_2 = jnp.cos(position[..., 1] - jnp.pi / 2.0) * l_2
159 | position_2 = jnp.stack([x_2, y_2], axis=-1)
160 | return jnp.stack([position_1, position_2], axis=-2)
161 |
162 | def render_trajectories(
163 | self,
164 | position: jnp.ndarray,
165 | params: utils.Params,
166 | rng_key: jnp.ndarray,
167 | **kwargs: Any
168 | ) -> Tuple[jnp.ndarray, utils.Params]:
169 | n, _, d = position.shape
170 | assert d == self.system_dims
171 | assert len(params) == 5
172 | key1, key2 = jax.random.split(rng_key, 2)
173 | m_1 = params["m_1"]
174 | m_2 = params["m_2"]
175 | position = self.canvas_position(position, params)
176 | position_1, position_2 = position[..., 0, :], position[..., 1, :]
177 | if self.randomize_canvas_location:
178 | offset = jax.random.uniform(key1, shape=[n, 2])
179 | offset = self.random_offset_bounds().convert_from_unit_interval(offset)
180 | else:
181 | offset = jnp.zeros([n, 2])
182 | position_1 = position_1 + offset[:, None, :]
183 | position_2 = position_1 + position_2
184 | particles = jnp.stack([position_1, position_2], axis=-2)
185 | radius_1 = jnp.sqrt(m_1 / jnp.pi)
186 | radius_2 = jnp.sqrt(m_2 / jnp.pi)
187 | particles_radius = jnp.stack([radius_1, radius_2], axis=-1)
188 | if self.num_colors == 1:
189 | color_index = jnp.zeros([n, 2]).astype("int64")
190 | else:
191 | color_index = utils.random_int_k_from_n(
192 | key2,
193 | num_samples=n,
194 | n=self.num_colors,
195 | k=2
196 | )
197 | images = self._batch_render(particles, particles_radius, color_index)
198 | return images, dict(offset=offset, color_index=color_index)
199 |
--------------------------------------------------------------------------------
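
A short sketch, assuming the package is importable, of sampling parameters and initial states for IdealDoublePendulum and evaluating its Hamiltonian through hamiltonian_from_params; the ranges are taken from the DOUBLE_PENDULUM config in datasets.py above.

    import jax
    import jax.numpy as jnp
    from dm_hamiltonian_dynamics_suite.hamiltonian_systems import ideal_double_pendulum
    from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils

    system = ideal_double_pendulum.IdealDoublePendulum(
        m_range=utils.BoxRegion(0.5, 0.5),
        g_range=utils.BoxRegion(3.0, 3.0),
        l_range=utils.BoxRegion(1.0, 1.0),
        radius_range=utils.BoxRegion(1.3, 2.3),
    )
    key_params, key_y = jax.random.split(jax.random.PRNGKey(0))
    params = system.sample_params(4, key_params)
    y0 = system.sample_y(4, params, key_y)
    energy = system.hamiltonian_from_params(params)(jnp.zeros([]), y0)
    print(energy.shape)  # one energy value per sampled trajectory: (4,)
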
/dm_hamiltonian_dynamics_suite/hamiltonian_systems/ideal_mass_spring.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Ideal mass spring."""
16 | import functools
17 | from typing import Any, Optional, Tuple
18 |
19 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import hamiltonian
20 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import phase_space
21 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils
22 |
23 | import jax
24 | import jax.numpy as jnp
25 |
26 |
27 | class IdealMassSpring(hamiltonian.TimeIndependentHamiltonianSystem):
28 |   """An idealized mass-spring system (also known as a harmonic oscillator).
29 | 
30 |   The system is represented in 2 dimensions, but the mass moves only along the
31 |   vertical axis.
32 |
33 | Parameters:
34 | k_range - possible range of the spring's force coefficient
35 | m_range - possible range of the particle mass
36 |
37 | The Hamiltonian is:
38 | k * q^2 / 2.0 + p^2 / (2 * m)
39 |
40 | Initial state parameters:
41 | radius_range - The initial state is sampled from a disk in phase space with
42 | radius in this range.
43 |     uniform_annulus - Whether to sample uniformly on the disk or uniformly over
44 |       the radius.
45 | randomize_x - Whether to randomize the horizontal position of the particle
46 | when rendering.
47 | """
48 |
49 | def __init__(
50 | self,
51 | k_range: utils.BoxRegion,
52 | m_range: utils.BoxRegion,
53 | radius_range: utils.BoxRegion,
54 | uniform_annulus: bool = True,
55 | randomize_x: bool = True,
56 | **kwargs):
57 | super().__init__(system_dims=1, **kwargs)
58 | self.k_range = k_range
59 | self.m_range = m_range
60 | self.radius_range = radius_range
61 | self.uniform_annulus = uniform_annulus
62 | self.randomize_x = randomize_x
63 | render = functools.partial(utils.render_particles_trajectory,
64 | canvas_limits=self.full_canvas_bounds(),
65 | resolution=self.resolution,
66 | num_colors=self.num_colors)
67 | self._batch_render = jax.vmap(render)
68 |
69 | def _hamiltonian(
70 | self,
71 | y: phase_space.PhaseSpace,
72 | params: utils.Params,
73 | **kwargs: Any
74 | ) -> jnp.ndarray:
75 | assert len(params) == 2
76 | k = params["k"]
77 | m = params["m"]
78 | potential = k * y.q[..., 0] ** 2 / 2
79 | kinetic = y.p[..., 0] ** 2 / (2 * m)
80 | return potential + kinetic
81 |
82 | def sample_y(
83 | self,
84 | num_samples: int,
85 | params: utils.Params,
86 | rng_key: jnp.ndarray,
87 | **kwargs: Any
88 | ) -> phase_space.PhaseSpace:
89 | assert len(params) == 2
90 | k = params["k"]
91 | m = params["m"]
92 | state = utils.uniform_annulus(
93 | rng_key, num_samples, 2, self.radius_range, self.uniform_annulus)
94 | q = state[..., :1]
95 | p = state[..., 1:] * jnp.sqrt(k * m)
96 | return phase_space.PhaseSpace(position=q, momentum=p)
97 |
98 | def sample_params(
99 | self,
100 | num_samples: int,
101 | rng_key: jnp.ndarray,
102 | **kwargs: Any
103 | ) -> utils.Params:
104 | key1, key2 = jax.random.split(rng_key)
105 | k = jax.random.uniform(key1, [num_samples], minval=self.k_range.min,
106 | maxval=self.k_range.max)
107 | m = jax.random.uniform(key2, [num_samples], minval=self.m_range.min,
108 | maxval=self.m_range.max)
109 | return dict(k=k, m=m)
110 |
111 | def simulate_analytically(
112 | self,
113 | y0: phase_space.PhaseSpace,
114 | t0: utils.FloatArray,
115 | t_eval: jnp.ndarray,
116 | params: utils.Params,
117 | **kwargs: Any
118 | ) -> Optional[phase_space.PhaseSpace]:
119 | if self.friction != 0.0:
120 | return None
121 | assert len(params) == 2
122 | k = params["k"]
123 | m = params["m"]
124 | t = t_eval - t0
125 | w = jnp.sqrt(k / m).astype(self.dtype)
126 | a = jnp.sqrt(y0.q[..., 0] ** 2 + y0.p[..., 0] ** 2 / (k * m))
127 | b = jnp.arctan2(- y0.p[..., 0], y0.q[..., 0] * m * w)
128 | w, a, b, m = w[..., None], a[..., None], b[..., None], m[..., None]
129 | t = utils.expand_to_rank_right(t, y0.q.ndim + 1)
130 |
131 | q = a * jnp.cos(w * t + b)
132 | p = - a * m * w * jnp.sin(w * t + b)
133 | return phase_space.PhaseSpace(position=q, momentum=p)
134 |
135 | def canvas_bounds(self) -> utils.BoxRegion:
136 | max_x = self.radius_range.max
137 | max_r = jnp.sqrt(self.m_range.max / jnp.pi)
138 | return utils.BoxRegion(- max_x - max_r, max_x + max_r)
139 |
140 | def canvas_position(
141 | self,
142 | position: jnp.ndarray,
143 | params: utils.Params
144 | ) -> jnp.ndarray:
145 | return jnp.stack([jnp.zeros_like(position), position], axis=-1)
146 |
147 | def render_trajectories(
148 | self,
149 | position: jnp.ndarray,
150 | params: utils.Params,
151 | rng_key: jnp.ndarray,
152 | **kwargs: Any
153 | ) -> Tuple[jnp.ndarray, utils.Params]:
154 | n, _, d = position.shape
155 | assert d == self.system_dims
156 | assert len(params) == 2
157 | key1, key2 = jax.random.split(rng_key)
158 | m = utils.expand_to_rank_right(params["m"], 2)
159 | particles = self.canvas_position(position, params)
160 | if self.randomize_x:
161 | x_offset = jax.random.uniform(key1, shape=[n])
162 | y_offset = jnp.zeros_like(x_offset)
163 | offset = jnp.stack([x_offset, y_offset], axis=-1)
164 | else:
165 | offset = jnp.zeros([n, d])
166 | if self.randomize_canvas_location:
167 | offset_ = jax.random.uniform(key2, shape=[n, d])
168 | offset_ = self.random_offset_bounds().convert_from_unit_interval(offset_)
169 | offset = offset + offset_
170 | particles = particles + offset[:, None, None, :]
171 | particles_radius = jnp.sqrt(m / jnp.pi)
172 | if self.num_colors == 1:
173 | color_index = jnp.zeros([n, 1]).astype("int64")
174 | else:
175 | color_index = jax.random.randint(
176 | key=rng_key, shape=[n, 1], minval=0, maxval=self.num_colors)
177 | images = self._batch_render(particles, particles_radius, color_index)
178 | return images, dict(offset=offset, color_index=color_index)
179 |
--------------------------------------------------------------------------------
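
The closed-form solution in simulate_analytically above should conserve the Hamiltonian k * q^2 / 2 + p^2 / (2 * m). A standalone sketch with fixed toy values, using plain JAX numpy, confirms this.

    import jax.numpy as jnp

    k, m = 2.0, 0.5
    q0, p0 = 1.0, 0.3
    w = jnp.sqrt(k / m)
    a = jnp.sqrt(q0 ** 2 + p0 ** 2 / (k * m))
    b = jnp.arctan2(-p0, q0 * m * w)

    def energy(t):
      q = a * jnp.cos(w * t + b)            # analytic position, as in the code above
      p = -a * m * w * jnp.sin(w * t + b)   # analytic momentum
      return k * q ** 2 / 2 + p ** 2 / (2 * m)

    print(energy(0.0), energy(1.7))  # both equal k * a**2 / 2: the energy is conserved
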
/dm_hamiltonian_dynamics_suite/hamiltonian_systems/ideal_pendulum.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Ideal pendulum."""
16 | import functools
17 | from typing import Any, Optional, Tuple
18 |
19 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import hamiltonian
20 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import phase_space
21 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils
22 |
23 | import jax
24 | import jax.numpy as jnp
25 |
26 |
27 | class IdealPendulum(hamiltonian.TimeIndependentHamiltonianSystem):
28 | """An idealized pendulum system.
29 |
30 | Parameters:
31 | m_range - possible range of the particle mass
32 | g_range - possible range of the gravitational force
33 | l_range - possible range of the length of the pendulum
34 |
35 | The Hamiltonian is:
36 | m * l * g * (1 - cos(q)) + p^2 / (2 * m * l^2)
37 |
38 | Initial state parameters:
39 | radius_range - The initial state is sampled from a disk in phase space with
40 | radius in this range.
41 |     uniform_annulus - Whether to sample uniformly on the disk or uniformly over
42 |       the radius.
43 |     randomize_canvas_location - Whether to randomize the canvas position of the
44 |       particle when rendering.
45 | """
46 |
47 | def __init__(
48 | self,
49 | m_range: utils.BoxRegion,
50 | g_range: utils.BoxRegion,
51 | l_range: utils.BoxRegion,
52 | radius_range: utils.BoxRegion,
53 | uniform_annulus: bool = True,
54 | **kwargs):
55 | super().__init__(system_dims=1, **kwargs)
56 | self.m_range = m_range
57 | self.g_range = g_range
58 | self.l_range = l_range
59 | self.radius_range = radius_range
60 | self.uniform_annulus = uniform_annulus
61 | render = functools.partial(utils.render_particles_trajectory,
62 | canvas_limits=self.full_canvas_bounds(),
63 | resolution=self.resolution,
64 | num_colors=self.num_colors)
65 | self._batch_render = jax.vmap(render)
66 |
67 | def _hamiltonian(
68 | self,
69 | y: phase_space.PhaseSpace,
70 | params: utils.Params,
71 | **kwargs: Any
72 | ) -> jnp.ndarray:
73 | assert len(params) == 3
74 | m = params["m"]
75 | l = params["l"]
76 | g = params["g"]
77 | potential = m * g * l * (1 - jnp.cos(y.q[..., 0]))
78 | kinetic = y.p[..., 0] ** 2 / (2 * m * l ** 2)
79 | return potential + kinetic
80 |
81 | def sample_y(
82 | self,
83 | num_samples: int,
84 | params: utils.Params,
85 | rng_key: jnp.ndarray,
86 | **kwargs: Any
87 | ) -> phase_space.PhaseSpace:
88 | state = utils.uniform_annulus(
89 | rng_key, num_samples, 2, self.radius_range, self.uniform_annulus)
90 | return phase_space.PhaseSpace.from_state(state.astype(self.dtype))
91 |
92 | def sample_params(
93 | self,
94 | num_samples: int,
95 | rng_key: jnp.ndarray,
96 | **kwargs: Any
97 | ) -> utils.Params:
98 | key1, key2, key3 = jax.random.split(rng_key, 3)
99 | m = jax.random.uniform(key1, [num_samples], minval=self.m_range.min,
100 | maxval=self.m_range.max)
101 | l = jax.random.uniform(key2, [num_samples], minval=self.l_range.min,
102 | maxval=self.l_range.max)
103 | g = jax.random.uniform(key3, [num_samples], minval=self.g_range.min,
104 | maxval=self.g_range.max)
105 | return dict(m=m, l=l, g=g)
106 |
107 | def simulate_analytically(
108 | self,
109 | y0: phase_space.PhaseSpace,
110 | t0: utils.FloatArray,
111 | t_eval: jnp.ndarray,
112 | params: utils.Params,
113 | **kwargs: Any
114 | ) -> Optional[phase_space.PhaseSpace]:
115 | return None
116 |
117 | def canvas_bounds(self) -> utils.BoxRegion:
118 | max_d = self.l_range.max + jnp.sqrt(self.m_range.max / jnp.pi)
119 | return utils.BoxRegion(-max_d, max_d)
120 |
121 | def canvas_position(
122 | self,
123 | position: jnp.ndarray,
124 | params: utils.Params
125 | ) -> jnp.ndarray:
126 | l = utils.expand_to_rank_right(params["l"], 2)
127 | y = jnp.sin(position[..., 0] - jnp.pi / 2.0) * l
128 |     x = jnp.cos(position[..., 0] - jnp.pi / 2.0) * l
129 | return jnp.stack([x, y], axis=-1)
130 |
131 | def render_trajectories(
132 | self,
133 | position: jnp.ndarray,
134 | params: utils.Params,
135 | rng_key: jnp.ndarray,
136 | **kwargs: Any
137 | ) -> Tuple[jnp.ndarray, utils.Params]:
138 | n, _, d = position.shape
139 | assert d == self.system_dims
140 | assert len(params) == 3
141 | m = utils.expand_to_rank_right(params["m"], 2)
142 | key1, key2 = jax.random.split(rng_key, 2)
143 | particles = self.canvas_position(position, params)
144 | if self.randomize_canvas_location:
145 | offset = jax.random.uniform(key1, shape=[n, 2])
146 | offset = self.random_offset_bounds().convert_from_unit_interval(offset)
147 | else:
148 | offset = jnp.zeros(shape=[n, 2])
149 | particles = particles + offset[:, None, :]
150 | particles = particles[..., None, :]
151 | particles_radius = jnp.sqrt(m / jnp.pi)
152 | if self.num_colors == 1:
153 | color_index = jnp.zeros([n, 1]).astype("int64")
154 | else:
155 | color_index = jax.random.randint(
156 | key=key2, shape=[n, 1], minval=0, maxval=self.num_colors)
157 | images = self._batch_render(particles, particles_radius, color_index)
158 | return images, dict(offset=offset, color_index=color_index)
159 |
--------------------------------------------------------------------------------
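
canvas_position above maps the pendulum angle to Cartesian coordinates with the angle measured from the downward vertical, hence the -pi/2 shift. A tiny standalone sketch of that mapping for a unit-length pendulum:

    import jax.numpy as jnp

    l = 1.0

    def canvas_xy(theta):
      # Mirrors IdealPendulum.canvas_position for a single angle.
      y = jnp.sin(theta - jnp.pi / 2.0) * l
      x = jnp.cos(theta - jnp.pi / 2.0) * l
      return x, y

    print(canvas_xy(jnp.array(0.0)))     # (~0.0, -1.0): the bob hangs straight down
    print(canvas_xy(jnp.array(jnp.pi)))  # (~0.0, 1.0): the bob points straight up
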
/dm_hamiltonian_dynamics_suite/hamiltonian_systems/n_body.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """N body."""
16 | import abc
17 | import functools
18 | from typing import Any, Optional, Tuple
19 |
20 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import hamiltonian
21 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import phase_space
22 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils
23 |
24 | import jax
25 | import jax.numpy as jnp
26 |
27 |
28 | class NBodySystem(hamiltonian.TimeIndependentHamiltonianSystem):
29 | """An N-body system abstract class.
30 |
31 | Parameters:
32 | m_range - possible range of the particle mass
33 | g_range - possible range of the gravitational force
34 | provided_canvas_bounds - The canvas bounds for the given ranges
35 |
36 | The Hamiltonian is:
37 |       - sum_i<j g * m_i * m_j / ||q_i - q_j|| + sum_i ||p_i||^2 / (2 * m_i)
38 | 
39 |   where q_i, p_i and m_i are the position, momentum and mass of the i-th body
40 |   and g is the gravitational constant.
41 | 
42 |   Initial state parameters:
43 |     radius_range - The initial distance of the bodies from the origin is
44 |       sampled from this range (see the concrete subclasses below for the
45 |       exact sampling).
46 |     randomize_canvas_location - Whether to randomize the canvas position of
47 |       the particles when rendering.
48 |   """
49 | 
50 |   def __init__(
51 |       self,
52 |       n: int,
53 |       space_dims: int,
54 |       m_range: utils.BoxRegion,
55 |       g_range: utils.BoxRegion,
56 |       provided_canvas_bounds: utils.BoxRegion,
57 |       **kwargs):
58 |     super().__init__(system_dims=n * space_dims, **kwargs)
59 |     self.n = n
60 |     self.space_dims = space_dims
61 |     self.m_range = m_range
62 |     self.g_range = g_range
63 |     self.provided_canvas_bounds = provided_canvas_bounds
64 |     render = functools.partial(
65 |         utils.render_particles_trajectory,
66 |         canvas_limits=self.full_canvas_bounds(),
67 |         resolution=self.resolution,
68 |         num_colors=self.num_colors)
69 |     self._batch_render = jax.vmap(render)
70 | 
71 |   def _hamiltonian(
72 |       self,
73 |       y: phase_space.PhaseSpace,
74 |       params: utils.Params,
75 |       **kwargs: Any
76 |   ) -> jnp.ndarray:
77 | assert len(params) == 2
78 | m = params["m"]
79 | g = params["g"]
80 | q = y.q.reshape([-1, self.n, self.space_dims])
81 | p = y.p.reshape([-1, self.n, self.space_dims])
82 |
83 | q_ij = jnp.matmul(q, jnp.swapaxes(q, axis1=-1, axis2=-2))
84 | q_ii = jnp.diagonal(q_ij, axis1=-1, axis2=-2)
85 | q_ij_norms_2 = q_ii[:, None, :] + q_ii[:, :, None] - 2.0 * q_ij
86 | # Adding identity so that on the diagonal the norms are not 0
87 | q_ij_norms = jnp.sqrt(q_ij_norms_2 + jnp.identity(self.n))
88 | masses_ij = m[:, None, :] * m[:, :, None]
89 | # Remove masses in the diagonal so that those potentials are 0
90 | masses_ij = masses_ij - masses_ij * jnp.identity(self.n)[None]
91 | # Compute pairwise interactions
92 | products = g[:, None, None] * masses_ij / q_ij_norms
93 | # Note that here we are summing both i->j and j->i hence the division by 2
94 | potential = - products.sum(axis=(-2, -1)) / 2
95 | kinetic = jnp.sum(p ** 2, axis=-1) / (2.0 * m)
96 | kinetic = kinetic.sum(axis=-1)
97 | return potential + kinetic
98 |
99 | @abc.abstractmethod
100 | def sample_y(
101 | self,
102 | num_samples: int,
103 | params: utils.Params,
104 | rng_key: jnp.ndarray,
105 | **kwargs: Any
106 | ) -> phase_space.PhaseSpace:
107 | pass
108 |
109 | def sample_params(
110 | self,
111 | num_samples: int,
112 | rng_key: jnp.ndarray,
113 | **kwargs: Any
114 | ) -> utils.Params:
115 | key1, key2 = jax.random.split(rng_key)
116 | m = jax.random.uniform(key1, [num_samples, self.n], minval=self.m_range.min,
117 | maxval=self.m_range.max)
118 | g = jax.random.uniform(key2, [num_samples], minval=self.g_range.min,
119 | maxval=self.g_range.max)
120 | return dict(m=m, g=g)
121 |
122 | def simulate_analytically(
123 | self,
124 | y0: phase_space.PhaseSpace,
125 | t0: utils.FloatArray,
126 | t_eval: jnp.ndarray,
127 | params: utils.Params,
128 | **kwargs: Any
129 | ) -> Optional[phase_space.PhaseSpace]:
130 | return None
131 |
132 | def canvas_bounds(self) -> utils.BoxRegion:
133 | return self.provided_canvas_bounds
134 |
135 | def canvas_position(
136 | self,
137 | position: jnp.ndarray,
138 | params: utils.Params
139 | ) -> jnp.ndarray:
140 | return position
141 |
142 | def render_trajectories(
143 | self,
144 | position: jnp.ndarray,
145 | params: utils.Params,
146 | rng_key: jnp.ndarray,
147 | **kwargs: Any
148 | ) -> Tuple[jnp.ndarray, utils.Params]:
149 | n, _, d = position.shape
150 | assert d == self.system_dims
151 | assert len(params) == 2
152 | key1, key2 = jax.random.split(rng_key, 2)
153 | m = utils.expand_to_rank_right(params["m"], 2)
154 | particles = position.reshape([n, -1, self.n, self.space_dims])
155 | if self.randomize_canvas_location:
156 | offset = jax.random.uniform(key1, shape=[n, self.space_dims])
157 | offset = self.random_offset_bounds().convert_from_unit_interval(offset)
158 | else:
159 | offset = jnp.zeros(shape=[n, self.space_dims])
160 | particles = particles + offset[:, None, None, :]
161 | particles_radius = jnp.sqrt(m / jnp.pi)
162 | if self.num_colors == 1:
163 | color_index = jnp.zeros([n, self.n]).astype("int64")
164 | else:
165 | if self.num_colors < self.n:
166 | raise ValueError("The number of colors must be at least the number of "
167 | "objects or 1.")
168 | color_index = utils.random_int_k_from_n(
169 | key2,
170 | num_samples=n,
171 | n=self.num_colors,
172 | k=self.n
173 | )
174 | images = self._batch_render(particles, particles_radius, color_index)
175 | return images, dict(offset=offset, color_index=color_index)
176 |
177 |
178 | class TwoBodySystem(NBodySystem):
179 | """N-body system with N = 2."""
180 |
181 | def __init__(
182 | self,
183 | m_range: utils.BoxRegion,
184 | g_range: utils.BoxRegion,
185 | radius_range: utils.BoxRegion,
186 | provided_canvas_bounds: utils.BoxRegion,
187 | **kwargs):
188 | self.radius_range = radius_range
189 | super().__init__(n=2, space_dims=2,
190 | m_range=m_range,
191 | g_range=g_range,
192 | provided_canvas_bounds=provided_canvas_bounds,
193 | **kwargs)
194 |
195 | def sample_y(
196 | self,
197 | num_samples: int,
198 | params: utils.Params,
199 | rng_key: jnp.ndarray,
200 | **kwargs: Any
201 | ) -> phase_space.PhaseSpace:
202 | pos = jax.random.uniform(rng_key, [num_samples, self.n])
203 | pos = self.radius_range.convert_from_unit_interval(pos)
204 | r = jnp.sqrt(jnp.sum(pos ** 2, axis=-1))
205 |
206 | vel = jnp.flip(pos, axis=-1) / (2 * r[..., None] ** 1.5)
207 | vel = vel * jnp.asarray([1.0, -1.0]).reshape([1, 2])
208 |
209 | pos = jnp.repeat(pos.reshape([num_samples, 1, -1]), repeats=self.n, axis=1)
210 | vel = jnp.repeat(vel.reshape([num_samples, 1, -1]), repeats=self.n, axis=1)
211 |
212 | pos = pos * jnp.asarray([1.0, -1.0]).reshape([1, 2, 1])
213 | vel = vel * jnp.asarray([1.0, -1.0]).reshape([1, 2, 1])
214 | pos = pos.reshape([num_samples, -1])
215 | vel = vel.reshape([num_samples, -1])
216 | return phase_space.PhaseSpace(position=pos, momentum=vel)
217 |
218 |
219 | class ThreeBody2DSystem(NBodySystem):
220 | """N-body system with N = 3 in two dimensions."""
221 |
222 | def __init__(
223 | self,
224 | m_range: utils.BoxRegion,
225 | g_range: utils.BoxRegion,
226 | radius_range: utils.BoxRegion,
227 | provided_canvas_bounds: utils.BoxRegion,
228 | **kwargs):
229 | self.radius_range = radius_range
230 | super().__init__(n=3, space_dims=2,
231 | m_range=m_range,
232 | g_range=g_range,
233 | provided_canvas_bounds=provided_canvas_bounds,
234 | **kwargs)
235 |
236 | def sample_y(
237 | self,
238 | num_samples: int,
239 | params: utils.Params,
240 | rng_key: jnp.ndarray,
241 | **kwargs: Any
242 | ) -> phase_space.PhaseSpace:
243 | theta = 2 * jnp.pi / 3
244 | rot = jnp.asarray([[jnp.cos(theta), - jnp.sin(theta)],
245 | [jnp.sin(theta), jnp.cos(theta)]])
246 | p1 = 2 * jax.random.uniform(rng_key, [num_samples, 2]) - 1.0
247 | r = jax.random.uniform(rng_key, [num_samples])
248 | r = self.radius_range.convert_from_unit_interval(r)
249 |
250 | p1 *= (r / jnp.linalg.norm(p1, axis=-1))[:, None]
251 | p2 = jnp.matmul(p1, rot.T)
252 | p3 = jnp.matmul(p2, rot.T)
253 | p = jnp.concatenate([p1, p2, p3], axis=-1)
254 |
255 | # scale factor to get circular trajectories
256 | factor = jnp.sqrt(jnp.sin(jnp.pi / 3)/(2 * jnp.cos(jnp.pi / 6) **2))
257 | # velocity that yields a circular orbit
258 | v1 = jnp.flip(p1, axis=-1) * factor / r[:, None]**1.5
259 | v2 = jnp.matmul(v1, rot.T)
260 | v3 = jnp.matmul(v2, rot.T)
261 | v = jnp.concatenate([v1, v2, v3], axis=-1)
262 | return phase_space.PhaseSpace(position=p, momentum=v)
263 |
--------------------------------------------------------------------------------
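
The N-body Hamiltonian above obtains all pairwise distances from a Gram matrix using ||q_i - q_j||^2 = ||q_i||^2 + ||q_j||^2 - 2 * q_i . q_j, adding the identity so the diagonal is not zero, as in the code above. A small standalone check of that identity against direct differences:

    import jax
    import jax.numpy as jnp

    q = jax.random.normal(jax.random.PRNGKey(0), [5, 2])  # 5 bodies in 2-D

    gram = q @ q.T
    sq_norms = jnp.diagonal(gram)
    dists_gram = jnp.sqrt(sq_norms[None, :] + sq_norms[:, None] - 2.0 * gram
                          + jnp.identity(5))
    dists_direct = jnp.sqrt(jnp.sum((q[:, None, :] - q[None, :, :]) ** 2, axis=-1)
                            + jnp.identity(5))
    print(jnp.max(jnp.abs(dists_gram - dists_direct)))  # ~0 up to floating point error
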
/dm_hamiltonian_dynamics_suite/hamiltonian_systems/phase_space.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Module for the PhaseSpace class."""
16 | import functools
17 | from typing import Callable, Type, Union
18 |
19 | import jax
20 | from jax import numpy as jnp
21 | from jax import tree_util
22 |
23 |
24 | class PhaseSpace(object):
25 | """Holds a pair of position and momentum for a Hamiltonian System."""
26 |
27 | def __init__(self, position: jnp.ndarray, momentum: jnp.ndarray):
28 | self._position = position
29 | self._momentum = momentum
30 |
31 | @property
32 | def position(self) -> jnp.ndarray:
33 | """The position element of the phase space."""
34 | return self._position
35 |
36 | @property
37 | def momentum(self) -> jnp.ndarray:
38 | """The momentum element of the phase space."""
39 | return self._momentum
40 |
41 | @property
42 | def q(self) -> jnp.ndarray:
43 | """A shorthand for the position element of the phase space."""
44 | return self._position
45 |
46 | @property
47 | def p(self) -> jnp.ndarray:
48 | """A shorthand for the momentum element of the phase space."""
49 | return self._momentum
50 |
51 | @property
52 | def single_state(self) -> jnp.ndarray:
53 | """Returns the concatenation of position and momentum."""
54 | return jnp.concatenate([self.q, self.p], axis=-1)
55 |
56 | @property
57 | def ndim(self) -> int:
58 | """Returns the number of dimensions of the position array."""
59 | return self.q.ndim
60 |
61 | @classmethod
62 | def from_state(cls: Type["PhaseSpace"], state: jnp.ndarray) -> "PhaseSpace":
63 | q, p = jnp.split(state, 2, axis=-1)
64 | return cls(position=q, momentum=p)
65 |
66 | def __str__(self) -> str:
67 | return f"{type(self).__name__}(q={self.position}, p={self.momentum})"
68 |
69 | def __repr__(self) -> str:
70 | return self.__str__()
71 |
72 |
73 | class TangentPhaseSpace(PhaseSpace):
74 | """Represents the tangent space to PhaseSpace."""
75 |
76 | def __add__(
77 | self,
78 | other: Union[PhaseSpace, "TangentPhaseSpace"],
79 | ) -> Union[PhaseSpace, "TangentPhaseSpace"]:
80 | if isinstance(other, TangentPhaseSpace):
81 | return TangentPhaseSpace(position=self.q + other.q,
82 | momentum=self.p + other.p)
83 | elif isinstance(other, PhaseSpace):
84 | return PhaseSpace(position=self.q + other.q,
85 | momentum=self.p + other.p)
86 | else:
87 | raise ValueError(f"Can not add TangentPhaseSpace and {type(other)}.")
88 |
89 | def __radd__(
90 | self,
91 | other: Union[PhaseSpace, "TangentPhaseSpace"]
92 | ) -> Union[PhaseSpace, "TangentPhaseSpace"]:
93 | return self.__add__(other)
94 |
95 | def __mul__(self, other: jnp.ndarray) -> "TangentPhaseSpace":
96 | return TangentPhaseSpace(position=self.q * other,
97 | momentum=self.p * other)
98 |
99 | def __rmul__(self, other):
100 | return self.__mul__(other)
101 |
102 | @classmethod
103 | def zero(cls: Type["TangentPhaseSpace"]) -> "TangentPhaseSpace":
104 | return cls(position=jnp.asarray(0.0), momentum=jnp.asarray(0.0))
105 |
106 |
107 | HamiltonianFunction = Callable[
108 | [
109 | jnp.ndarray, # t
110 | PhaseSpace, # y
111 | ],
112 | jnp.ndarray # H(t, y)
113 | ]
114 |
115 | SymplecticTangentFunction = Callable[
116 | [
117 | jnp.ndarray, # t
118 | PhaseSpace # (q, p)
119 | ],
120 | TangentPhaseSpace # (dH_dp, - dH_dq)
121 | ]
122 |
123 | SymplecticTangentFunctionArray = Callable[
124 | [
125 | jnp.ndarray, # t
126 | jnp.ndarray # (q, p)
127 | ],
128 | jnp.ndarray # (dH_dp, - dH_dq)
129 | ]
130 |
131 |
132 | def poisson_bracket_with_q_and_p(
133 | f: HamiltonianFunction
134 | ) -> SymplecticTangentFunction:
135 | """Returns a function that computes the Poisson brackets {q,f} and {p,f}."""
136 | def bracket(t: jnp.ndarray, y: PhaseSpace) -> TangentPhaseSpace:
137 |     # Use the summation trick to get per-example gradients of the Hamiltonian.
138 |     # Note that the first argument to the Hamiltonian is t.
139 | grad = jax.grad(lambda *args: jnp.sum(f(*args)), argnums=1)(t, y)
140 | return TangentPhaseSpace(position=grad.p, momentum=-grad.q)
141 | return bracket
142 |
143 |
144 | def transform_symplectic_tangent_function_using_array(
145 | func: SymplecticTangentFunction
146 | ) -> SymplecticTangentFunctionArray:
147 | @functools.wraps(func)
148 | def wrapped(t: jnp.ndarray, state: jnp.ndarray) -> jnp.ndarray:
149 | return func(t, PhaseSpace.from_state(state)).single_state
150 | return wrapped
151 |
152 |
153 | tree_util.register_pytree_node(
154 | nodetype=PhaseSpace,
155 | flatten_func=lambda y: ((y.q, y.p), None),
156 | unflatten_func=lambda _, q_and_p: PhaseSpace(*q_and_p)
157 | )
158 |
159 | tree_util.register_pytree_node(
160 | nodetype=TangentPhaseSpace,
161 | flatten_func=lambda y: ((y.q, y.p), None),
162 | unflatten_func=lambda _, q_and_p: TangentPhaseSpace(*q_and_p)
163 | )
164 |
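A minimal usage sketch (not part of the repository) showing how `poisson_bracket_with_q_and_p` turns a Hamiltonian into the symplectic vector field (dq/dt, dp/dt) = (dH/dp, -dH/dq); the harmonic-oscillator Hamiltonian is purely illustrative.

# Illustrative example, not repository code.
import jax.numpy as jnp
from dm_hamiltonian_dynamics_suite.hamiltonian_systems import phase_space

def harmonic_oscillator(t, y):
  del t  # Time-independent Hamiltonian H(q, p) = (q^2 + p^2) / 2.
  return jnp.sum(y.q ** 2 + y.p ** 2, axis=-1) / 2

vector_field = phase_space.poisson_bracket_with_q_and_p(harmonic_oscillator)
y0 = phase_space.PhaseSpace(position=jnp.array([1.0]), momentum=jnp.array([0.0]))
dy_dt = vector_field(jnp.asarray(0.0), y0)
# dy_dt.q == p == 0.0 and dy_dt.p == -q == -1.0, as expected for the oscillator.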
--------------------------------------------------------------------------------
/dm_hamiltonian_dynamics_suite/hamiltonian_systems/simple_analytic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """A module with all Hamiltonian systems that have analytic solutions."""
16 | from typing import Any, Optional, Tuple
17 |
18 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import hamiltonian
19 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import phase_space
20 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils
21 |
22 | import jax.numpy as jnp
23 | import jax.random as jnr
24 |
25 |
26 | class PotentialFreeSystem(hamiltonian.TimeIndependentHamiltonianSystem):
27 | """A system where the potential energy is 0 and the kinetic is quadratic.
28 |
29 | Parameters:
30 | matrix - a positive semi-definite matrix used for the kinetic quadratic.
31 |
32 | The Hamiltonian is:
33 | p^T M p / 2
34 |
35 |   Initial state parameters:
36 |     init_vector_range - the box region from which the initial state (q, p)
37 |       is sampled uniformly.
38 |   """
39 |
40 | def __init__(
41 | self,
42 | system_dims: int,
43 | eigen_values_range: utils.BoxRegion,
44 | init_vector_range: utils.BoxRegion,
45 | **kwargs):
46 | super().__init__(system_dims=system_dims, **kwargs)
47 | if eigen_values_range.dims != 0 and eigen_values_range.dims != system_dims:
48 | raise ValueError(f"The eigen_values_range must be of the same dimensions "
49 | f"as the system dimensions, but is "
50 | f"{eigen_values_range.dims}.")
51 | if init_vector_range.dims != 0 and init_vector_range.dims != system_dims:
52 | raise ValueError(f"The init_vector_range must be of the same dimensions "
53 | f"as the system dimensions, but is "
54 | f"{init_vector_range.dims}.")
55 | self.eigen_values_range = eigen_values_range
56 | self.init_vector_range = init_vector_range
57 |
58 | def _hamiltonian(
59 | self,
60 | y: phase_space.PhaseSpace,
61 | params: utils.Params,
62 | **kwargs: Any
63 | ) -> jnp.ndarray:
64 | assert len(params) == 1
65 | matrix = params["matrix"]
66 | potential = 0
67 | kinetic = jnp.sum(jnp.matmul(y.p, matrix) * y.p, axis=-1) / 2
68 | return potential + kinetic
69 |
70 | def sample_y(
71 | self,
72 | num_samples: int,
73 | params: utils.Params,
74 | rng_key: jnp.ndarray,
75 | **kwargs: Any
76 | ) -> phase_space.PhaseSpace:
77 | # Sample random state
78 | y = jnr.uniform(rng_key, [num_samples, 2 * self.system_dims],
79 | dtype=self.dtype)
80 | y = self.init_vector_range.convert_from_unit_interval(y)
81 | return phase_space.PhaseSpace.from_state(y)
82 |
83 | def sample_params(
84 | self,
85 | num_samples: int,
86 | rng_key: jnp.ndarray,
87 | **kwargs: Any
88 | ) -> utils.Params:
89 | key1, key2 = jnr.split(rng_key)
90 | matrix_shape = [num_samples, self.system_dims, self.system_dims]
91 | gaussian = jnr.normal(key1, matrix_shape)
92 | q, _ = jnp.linalg.qr(gaussian)
93 | eigs = jnr.uniform(key2, [num_samples, self.system_dims])
94 | eigs = self.eigen_values_range.convert_from_unit_interval(eigs)
95 | q_eigs = q * eigs[..., None]
96 | matrix = jnp.matmul(q_eigs, jnp.swapaxes(q_eigs, -2, -1))
97 | return dict(matrix=matrix)
98 |
99 | def simulate_analytically(
100 | self,
101 | y0: phase_space.PhaseSpace,
102 | t0: utils.FloatArray,
103 | t_eval: jnp.ndarray,
104 | params: utils.Params,
105 | **kwargs: Any
106 | ) -> Optional[phase_space.PhaseSpace]:
107 | if self.friction != 0.0:
108 | return None
109 | assert len(params) == 1
110 | matrix = params["matrix"]
111 | t = utils.expand_to_rank_right(t_eval - t0, y0.q.ndim + 1)
112 | q = y0.q[None] + utils.vecmul(matrix, y0.p)[None] * t
113 | p = y0.p[None] * jnp.ones_like(t)
114 | return phase_space.PhaseSpace(position=q, momentum=p)
115 |
116 | def canvas_bounds(self) -> utils.BoxRegion:
117 | raise NotImplementedError()
118 |
119 | def canvas_position(
120 | self,
121 | position: jnp.ndarray,
122 | params: utils.Params
123 | ) -> jnp.ndarray:
124 | raise NotImplementedError()
125 |
126 | def render_trajectories(
127 | self,
128 | position: jnp.ndarray,
129 | params: utils.Params,
130 | rng_key: jnp.ndarray,
131 | **kwargs: Any
132 | ) -> Tuple[jnp.ndarray, utils.Params]:
133 | raise NotImplementedError()
134 |
135 |
136 | class KineticFreeSystem(PotentialFreeSystem):
137 | """A system where the kinetic energy is 0 and the potential is quadratic.
138 |
139 | Parameters:
140 | matrix - a positive semi-definite matrix used for the potential quadratic.
141 |
142 | The Hamiltonian is:
143 | q^T M q / 2
144 |
145 |   Initial state parameters:
146 |     init_vector_range - the box region from which the initial state (q, p)
147 |       is sampled uniformly (inherited from PotentialFreeSystem).
148 |   """
149 |
150 | def _hamiltonian(
151 | self,
152 | y: phase_space.PhaseSpace,
153 | params: utils.Params,
154 | **kwargs: Any
155 | ) -> jnp.ndarray:
156 | assert len(params) == 1
157 | matrix = params["matrix"]
158 | potential = jnp.sum(jnp.matmul(y.q, matrix) * y.q, axis=-1) / 2
159 | kinetic = 0
160 | return potential + kinetic
161 |
162 | def simulate_analytically(
163 | self,
164 | y0: phase_space.PhaseSpace,
165 | t0: utils.FloatArray,
166 | t_eval: jnp.ndarray,
167 | params: utils.Params,
168 | **kwargs: Any
169 | ) -> Optional[phase_space.PhaseSpace]:
170 | if self.friction != 0.0:
171 | return None
172 | assert len(params) == 1
173 | matrix = params["matrix"]
174 | t = utils.expand_to_rank_right(t_eval - t0, y0.q.ndim + 1)
175 | q = y0.q[None] * jnp.ones_like(t)
176 | p = y0.p[None] - utils.vecmul(matrix, y0.q)[None] * t
177 | return phase_space.PhaseSpace(position=q, momentum=p)
178 |
179 | def canvas_bounds(self) -> utils.BoxRegion:
180 | raise NotImplementedError()
181 |
182 | def canvas_position(
183 | self,
184 | position: jnp.ndarray,
185 | params: utils.Params
186 | ) -> jnp.ndarray:
187 | raise NotImplementedError()
188 |
189 | def render_trajectories(
190 | self,
191 | position: jnp.ndarray,
192 | params: utils.Params,
193 | rng_key: jnp.ndarray,
194 | **kwargs: Any
195 | ) -> Tuple[jnp.ndarray, utils.Params]:
196 | raise NotImplementedError()
197 |
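As a hedged illustration of the closed-form solution used by `PotentialFreeSystem.simulate_analytically` (a standalone numpy sketch, not repository code): with H = p^T M p / 2 and no friction, Hamilton's equations give dq/dt = M p and dp/dt = 0, so the position evolves linearly in time while the momentum stays constant.

# Illustrative sketch, not repository code.
import numpy as np

matrix = np.array([[2.0, 0.0], [0.0, 1.0]])      # A positive semi-definite M.
q0, p0 = np.array([0.0, 0.0]), np.array([1.0, -1.0])
t = np.linspace(0.0, 1.0, 5)[:, None]            # Evaluation times.
q_t = q0 + (matrix @ p0) * t                     # Position: linear in time.
p_t = p0 * np.ones_like(t)                       # Momentum: constant.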
--------------------------------------------------------------------------------
/dm_hamiltonian_dynamics_suite/hamiltonian_systems/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Module with various utilities not central to the code."""
16 | from typing import Dict, Optional, Tuple, Union
17 |
18 | import jax
19 | from jax import lax
20 | import jax.numpy as jnp
21 | import jax.random as jnr
22 | import numpy as np
23 |
24 |
25 | FloatArray = Union[float, jnp.ndarray]
26 | Params = Dict[str, jnp.ndarray]
27 |
28 |
29 | class BoxRegion:
30 | """A class for bounds, especially used for sampling."""
31 |
32 | def __init__(self, minimum: FloatArray, maximum: FloatArray):
33 | minimum = jnp.asarray(minimum)
34 | maximum = jnp.asarray(maximum)
35 | if minimum.shape != maximum.shape:
36 | raise ValueError(f"The passed values for minimum and maximum should have "
37 | f"the same shape, but had shapes: {minimum.shape} "
38 | f"and {maximum.shape}.")
39 | self._minimum = minimum
40 | self._maximum = maximum
41 |
42 | @property
43 | def min(self) -> jnp.ndarray:
44 | return self._minimum
45 |
46 | @property
47 | def max(self) -> jnp.ndarray:
48 | return self._maximum
49 |
50 | @property
51 | def size(self) -> jnp.ndarray:
52 | return self.max - self.min
53 |
54 | @property
55 | def dims(self) -> int:
56 | if self.min.ndim != 0:
57 | return self.min.shape[-1]
58 | return 0
59 |
60 | def convert_to_unit_interval(self, value: jnp.ndarray) -> jnp.ndarray:
61 | return (value - self.min) / self.size
62 |
63 | def convert_from_unit_interval(self, value: jnp.ndarray) -> jnp.ndarray:
64 | return value * self.size + self.min
65 |
66 | def __str__(self) -> str:
67 | return f"{type(self).__name__}(min={self.min}, max={self.max})"
68 |
69 | def __repr__(self) -> str:
70 | return self.__str__()
71 |
72 |
73 | def expand_to_rank_right(x: jnp.ndarray, rank: int) -> jnp.ndarray:
74 | if x.ndim == rank:
75 | return x
76 | assert x.ndim < rank
77 | new_shape = x.shape + (1,) * (rank - x.ndim)
78 | return x.reshape(new_shape)
79 |
80 |
81 | def expand_to_rank_left(x: jnp.ndarray, rank: int) -> jnp.ndarray:
82 | if x.ndim == rank:
83 | return x
84 | assert x.ndim < rank
85 | new_shape = (1,) * (rank - x.ndim) + x.shape
86 | return x.reshape(new_shape)
87 |
88 |
89 | def vecmul(matrix: jnp.ndarray, vector: jnp.ndarray) -> jnp.ndarray:
90 | return jnp.matmul(matrix, vector[..., None])[..., 0]
91 |
92 |
93 | def dt_to_t_eval(t0: FloatArray, dt: FloatArray, num_steps: int) -> jnp.ndarray:
94 | if (isinstance(t0, (float, np.ndarray)) and
95 | isinstance(dt, (float, np.ndarray))):
96 | dt = np.asarray(dt)[None]
97 | shape = [num_steps] + [1] * (dt.ndim - 1)
98 | return t0 + dt * np.arange(1, num_steps + 1).reshape(shape)
99 | else:
100 | return t0 + dt * jnp.arange(1, num_steps + 1)
101 |
102 |
103 | def t_eval_to_dt(t0: FloatArray, t_eval: FloatArray) -> jnp.ndarray:
104 | t = jnp.ones_like(t_eval[:1]) * t0
105 | t = jnp.concatenate([t, t_eval], axis=0)
106 | return t[1:] - t[:-1]
107 |
108 |
109 | def simple_loop(
110 | f,
111 | x0: jnp.ndarray,
112 | t_args: Optional[jnp.ndarray] = None,
113 | num_steps: Optional[int] = None,
114 | use_scan: bool = True
115 | ) -> Tuple[jnp.ndarray, jnp.ndarray]:
116 | """Runs a simple loop that outputs the evolved variable at every time step."""
117 | if t_args is None and num_steps is None:
118 | raise ValueError("Exactly one of `t_args` and `num_steps` should be "
119 | "provided.")
120 | if t_args is not None and num_steps is not None:
121 | raise ValueError("Exactly one of `t_args` and `num_steps` should be "
122 | "provided.")
123 |
124 | def step(x_t, t_arg):
125 | x_next = f(x_t) if t_arg is None else f(x_t, t_arg)
126 | return x_next, x_next
127 | if use_scan:
128 | return lax.scan(step, init=x0, xs=t_args, length=num_steps)[1]
129 |
130 | y = []
131 | x = x0
132 | num_steps = t_args.shape[0] if t_args is not None else num_steps
133 | t_args = [None] * num_steps if t_args is None else t_args
134 | for i in range(num_steps):
135 | x, _ = step(x, t_args[i])
136 | y.append(x)
137 | return jax.tree_multimap(lambda *args: jnp.stack(args, axis=0), *y)
138 |
139 |
140 | def hsv2rgb(array: jnp.ndarray) -> jnp.ndarray:
141 | """Converts an HSV float array image to RGB."""
142 | hi = jnp.floor(array[..., 0] * 6)
143 | f = array[..., 0] * 6 - hi
144 | p = array[..., 2] * (1 - array[..., 1])
145 | q = array[..., 2] * (1 - f * array[..., 1])
146 | t = array[..., 2] * (1 - (1 - f) * array[..., 1])
147 | v = array[..., 2]
148 |
149 | hi = jnp.stack([hi, hi, hi], axis=-1).astype(jnp.uint8) % 6
150 | hi_is_0 = hi == 0
151 | hi_is_1 = hi == 1
152 | hi_is_2 = hi == 2
153 | hi_is_3 = hi == 3
154 | hi_is_4 = hi == 4
155 | hi_is_5 = hi == 5
156 | out = (hi_is_0 * jnp.stack((v, t, p), axis=-1) +
157 | hi_is_1 * jnp.stack((q, v, p), axis=-1) +
158 | hi_is_2 * jnp.stack((p, v, t), axis=-1) +
159 | hi_is_3 * jnp.stack((p, q, v), axis=-1) +
160 | hi_is_4 * jnp.stack((t, p, v), axis=-1) +
161 | hi_is_5 * jnp.stack((v, p, q), axis=-1))
162 | return out
163 |
164 |
165 | def render_particles_trajectory(
166 | particles: jnp.ndarray,
167 | particles_radius: FloatArray,
168 | color_indices: FloatArray,
169 | canvas_limits: BoxRegion,
170 | resolution: int,
171 | num_colors: int,
172 | background_color: Tuple[float, float, float] = (0.321, 0.349, 0.368),
173 | temperature: FloatArray = 80.0):
174 | """Renders n particles in different colors for a full trajectory.
175 |
176 | NB: The default background color is not black as we have experienced issues
177 | when training models with black background.
178 |
179 | Args:
180 | particles: Array of size (t, n, 2)
181 | The last 2 dimensions define the x, y coordinates of each particle.
182 | particles_radius: Array of size (n,) or a single value.
183 | Defines the radius of each particle.
184 | color_indices: Array of size (n,) or a single value.
185 | Defines the color of each particle.
186 |     canvas_limits: BoxRegion
187 |       Defines the limits of the canvas over x and y.
188 | resolution: int
189 | The resolution of the produced images.
190 | num_colors: int
191 | The number of possible colors to use.
192 |     background_color: List or Array of size (3)
193 |       The RGB color of the background. Defaults to a dark grey (see note above).
194 |     temperature: float
195 |       The temperature of the sigmoid applied to the distance of each pixel
196 |       from the particle centers.
197 |
198 | Returns:
199 | An array of size (t, resolution, resolution, 3) with the produced images.
200 | """
201 | particles = jnp.asarray(particles)
202 | assert particles.ndim == 3
203 | assert particles.shape[-1] == 2
204 | t, n = particles.shape[:2]
205 | particles_radius = jnp.asarray(particles_radius)
206 | if particles_radius.ndim == 0:
207 | particles_radius = jnp.full([n], particles_radius)
208 | assert particles_radius.shape == (n,)
209 | color_indices = jnp.asarray(color_indices)
210 | if color_indices.ndim == 0:
211 | color_indices = jnp.full([n], color_indices)
212 | assert color_indices.shape == (n,), f"Colors shape: {color_indices.shape}"
213 | background_color = jnp.asarray(background_color)
214 | assert background_color.shape == (3,)
215 |
216 | particles = canvas_limits.convert_to_unit_interval(particles)
217 | canvas_size = canvas_limits.max - canvas_limits.min
218 | canvas_size = canvas_size[0] if canvas_size.ndim == 1 else canvas_size
219 | particles_radius = particles_radius / canvas_size
220 | images = jnp.ones([t, resolution, resolution, 3]) * background_color
221 |
222 | hues = jnp.linspace(0, 1, num=num_colors, endpoint=False)
223 | colors = hues[color_indices][None, :, None, None]
224 | s_channel = jnp.ones((t, n, resolution, resolution))
225 | v_channel = jnp.ones((t, n, resolution, resolution))
226 | h_channel = jnp.ones((t, n, resolution, resolution)) * colors
227 | hsv_imgs = jnp.stack((h_channel, s_channel, v_channel), axis=-1)
228 | rgb_imgs = hsv2rgb(hsv_imgs)
229 | images = [img[:, 0] for img in jnp.split(rgb_imgs, n, axis=1)] + [images]
230 |
231 | grid = jnp.linspace(0.0, 1.0, resolution)
232 | dx, dy = jnp.meshgrid(grid, grid)
233 | dx, dy = dx[None, None], dy[None, None]
234 | x, y = particles[..., 0][..., None, None], particles[..., 1][..., None, None]
235 | d = jnp.sqrt((x - dx) ** 2 + (y - dy) ** 2)
236 | particles_radius = particles_radius[..., None, None]
237 | mask = 1.0 / (1.0 + jnp.exp((d - particles_radius) * temperature))
238 | masks = ([m[:, 0, ..., None] for m in jnp.split(mask, n, axis=1)] +
239 | [jnp.ones_like(images[0])])
240 |
241 | final_image = jnp.zeros([t, resolution, resolution, 3])
242 | c = jnp.ones_like(images[0])
243 | for img, m in zip(images, masks):
244 | final_image = final_image + c * m * img
245 | c = c * (1 - m)
246 | return final_image
247 |
248 |
249 | def uniform_annulus(
250 | key: jnp.ndarray,
251 | num_samples: int,
252 | dim_samples: int,
253 | radius_range: BoxRegion,
254 | uniform: bool
255 | ) -> jnp.ndarray:
256 | """Samples points uniformly in the annulus defined by radius range."""
257 | key1, key2 = jnr.split(key)
258 | direction = jnr.normal(key1, [num_samples, dim_samples])
259 | norms = jnp.linalg.norm(direction, axis=-1, keepdims=True)
260 | direction = direction / norms
261 |   # Sample a radius in [min_radius, max_radius] (area-weighted if uniform=True).
262 | r = jnr.uniform(key2, [num_samples])
263 | if uniform:
264 | radius_range = BoxRegion(radius_range.min ** 2, radius_range.max ** 2)
265 | r = jnp.sqrt(radius_range.convert_from_unit_interval(r))
266 | else:
267 | r = radius_range.convert_from_unit_interval(r)
268 | return direction * r[:, None]
269 |
270 |
271 | multi_shuffle = jax.vmap(lambda x, key, k: jnr.permutation(key, x)[:k],
272 | in_axes=(0, 0, None), out_axes=0)
273 |
274 |
275 | def random_int_k_from_n(
276 | rng: jnp.ndarray,
277 | num_samples: int,
278 | n: int,
279 | k: int
280 | ) -> jnp.ndarray:
281 | """Samples randomly k integers from 1 to n."""
282 | if k > n:
283 | raise ValueError(f"k should be less than or equal to n, but got k={k} and "
284 | f"n={n}.")
285 | x = jnp.repeat(jnp.arange(n).reshape([1, n]), num_samples, axis=0)
286 | return multi_shuffle(x, jnr.split(rng, num_samples), k)
287 |
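A short illustrative sketch (not repository code) of the two helpers the systems above rely on most: `BoxRegion` maps unit-interval samples into a box, and `dt_to_t_eval` builds the grid of evaluation times.

# Illustrative example, not repository code.
import jax.random as jnr
from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils

box = utils.BoxRegion(minimum=-2.0, maximum=2.0)
unit_samples = jnr.uniform(jnr.PRNGKey(0), [5])            # Values in [0, 1).
samples = box.convert_from_unit_interval(unit_samples)     # Values in [-2, 2).
t_eval = utils.dt_to_t_eval(t0=0.0, dt=0.1, num_steps=3)   # [0.1, 0.2, 0.3]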
--------------------------------------------------------------------------------
/dm_hamiltonian_dynamics_suite/load_datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """A module for loading the Hamiltonian datasets."""
16 | import functools
17 | import os
18 | from typing import Optional, Mapping, Any, Tuple, Sequence, Callable, Union, TypeVar
19 |
20 | import jax
21 | import tensorflow as tf
22 | import tensorflow_datasets as tfds
23 |
24 |
25 | T = TypeVar("T")
26 |
27 |
28 | def filter_based_on_keys(
29 | example: Mapping[str, T],
30 | keys_to_preserve: Sequence[str],
31 | single_key_return_array: bool = False
32 | ) -> Union[T, Mapping[str, T]]:
33 | """Filters the contents of the mapping, to return only the keys given in `keys_to_preserve`."""
34 | if not keys_to_preserve:
35 | raise ValueError("You must provide at least one key to preserve.")
36 | if len(keys_to_preserve) == 1 and single_key_return_array:
37 | return example[keys_to_preserve[0]]
38 | elif single_key_return_array:
39 | raise ValueError(f"You have provided {len(keys_to_preserve)}>1 keys to "
40 | f"preserve and have also set "
41 | f"single_key_return_array=True.")
42 | return {k: example[k] for k in keys_to_preserve}
43 |
44 |
45 | def preprocess_batch(
46 | batch: Mapping[str, Any],
47 | num_local_devices: int,
48 | multi_device: bool,
49 | sub_sample_length: Optional[int],
50 | dtype: str = "float32"
51 | ) -> Mapping[str, Any]:
52 | """Function to preprocess the data for a batch.
53 |
54 |   This performs three functions:
55 |   1. If 'sub_sample_length' is not None, it randomly subsamples every example
56 |   along the second (time) axis to return an array of the requested length.
57 |   Note that this assumes that all arrays have the same length.
58 |   2. Converts all arrays to the provided precision (compressed images are
59 |   converted via `tf.image.convert_image_dtype`).
60 |   3. If `multi_device` is True, reshapes arrays for `jax.pmap` consumption.
61 |
62 | Args:
63 | batch: Dictionary with the full batch.
64 | num_local_devices: The number of local devices.
65 | multi_device: Whether to prepare the batch for multi device training.
66 | sub_sample_length: The sub-sampling length requested.
67 | dtype: String for the dtype, must be a floating-point type.
68 |
69 | Returns:
70 | The preprocessed batch.
71 | """
72 | if dtype not in ("float32", "float64"):
73 | raise ValueError("The provided dtype must be a floating point dtype.")
74 | tensor = batch.get("image", batch.get("x", None))
75 | if tensor is None:
76 | raise ValueError("We need either the key 'image' or 'x' to be present in "
77 | "the batch provided.")
78 | if not isinstance(tensor, tf.Tensor):
79 | raise ValueError(f"Expecting the value for key 'image' or 'x' to be a "
80 | f"tf.Tensor, instead got {type(tensor)}.")
81 | # `n` here represents the batch size
82 | n = tensor.shape[0] or tf.shape(tensor)[0]
83 | # `t` here represents number of time steps in the batch
84 | t = tensor.shape[1]
85 | if sub_sample_length is not None:
86 | # Sample random index for each example in the batch
87 | limit = t - sub_sample_length + 1
88 | start = tf.random.uniform(shape=[n], maxval=limit, dtype="int32")
89 | indices = tf.range(sub_sample_length)[None, :, None] + start[:, None, None]
90 | def index(x):
91 | """Indexes every array in the batch according to the sampled indices and length, if its second dimensions is equal to `t`."""
92 | if x.shape.rank > 1 and x.shape[1] == t:
93 | if isinstance(n, tf.Tensor):
94 | shape = [None, sub_sample_length] + list(x.shape[2:])
95 | else:
96 | shape = [n, sub_sample_length] + list(x.shape[2:])
97 | x = tf.gather_nd(x, indices, batch_dims=1)
98 | x.set_shape(shape)
99 | return x
100 |
101 | batch = jax.tree_map(index, batch)
102 |
103 | def convert_fn(x):
104 | """Converts the value of `x` to the provided precision dtype.
105 |
106 |     Integer-valued arrays with a data type other than int32 or int64 are
107 |     assumed to represent compressed images and are converted via
108 |     `tf.image.convert_image_dtype`. For any other data type (float or int)
109 | their type is preserved, but their precision is changed based on the
110 | target `dtype`. For instance, if `dtype=float32` the float64 variables are
111 | converted to float32 and int64 values are converted to int32.
112 |
113 | Args:
114 | x: The input array.
115 |
116 | Returns:
117 | The converted output array.
118 | """
119 | if x.dtype == tf.int64:
120 | return tf.cast(x, "int32") if dtype == "float32" else x
121 | elif x.dtype == tf.int32:
122 | return tf.cast(x, "int64") if dtype == "float64" else x
123 | elif x.dtype == tf.float64 or x.dtype == tf.float32:
124 | return tf.cast(x, dtype=dtype)
125 | else:
126 | return tf.image.convert_image_dtype(x, dtype=dtype)
127 |
128 | batch = jax.tree_map(convert_fn, batch)
129 | if not multi_device:
130 | return batch
131 | def reshape_for_jax_pmap(x):
132 | """Reshapes values such that their leading dimension is the number of local devices."""
133 | return tf.reshape(x, [num_local_devices, -1] + x.shape[1:].as_list())
134 | return jax.tree_map(reshape_for_jax_pmap, batch)
135 |
136 |
137 | def load_filenames_and_parse_fn(
138 | path: str,
139 | tfrecord_prefix: str
140 | ) -> Tuple[Tuple[str], Callable[[str], Mapping[str, Any]]]:
141 | """Returns the file names and read_fn based on the number of shards."""
142 | file_name = os.path.join(path, f"{tfrecord_prefix}.tfrecord")
143 | if not os.path.exists(file_name):
144 | raise ValueError(f"The dataset file {file_name} does not exist.")
145 | features_file = os.path.join(path, "features.txt")
146 | if not os.path.exists(features_file):
147 | raise ValueError(f"The dataset features file {features_file} does not "
148 | f"exist.")
149 | with open(features_file, "r") as f:
150 | dtype_dict = dict()
151 | shapes_dict = dict()
152 | parsing_description = dict()
153 | for line in f:
154 | key = line.split(", ")[0]
155 | shape_string = line.split("(")[1].split(")")[0]
156 | shapes_dict[key] = tuple(int(s) for s in shape_string.split(",") if s)
157 | dtype_dict[key] = line.split(", ")[-1][:-1]
158 | if dtype_dict[key] == "uint8":
159 | parsing_description[key] = tf.io.FixedLenFeature([], tf.string)
160 | elif dtype_dict[key] in ("float32", "float64"):
161 | parsing_description[key] = tf.io.VarLenFeature(tf.int64)
162 | else:
163 | parsing_description[key] = tf.io.VarLenFeature(dtype_dict[key])
164 |
165 | def parse_fn(example_proto: str) -> Mapping[str, Any]:
166 | raw = tf.io.parse_single_example(example_proto, parsing_description)
167 | parsed = dict()
168 | for name, dtype in dtype_dict.items():
169 | value = raw[name]
170 | if dtype == "uint8":
171 | value = tf.image.decode_png(value)
172 | else:
173 | value = tf.sparse.to_dense(value)
174 | if dtype in ("float32", "float64"):
175 | value = tf.bitcast(value, type=dtype)
176 | value = tf.reshape(value, shapes_dict[name])
177 | if "/" in name:
178 | k1, k2 = name.split("/")
179 | if k1 not in parsed:
180 | parsed[k1] = dict()
181 | parsed[k1][k2] = value
182 | else:
183 | parsed[name] = value
184 | return parsed
185 |
186 | return (file_name,), parse_fn
187 |
188 |
189 | def load_parsed_dataset(
190 | path: str,
191 | tfrecord_prefix: str,
192 | num_shards: int,
193 | shard_index: Optional[int] = None,
194 | keys_to_preserve: Optional[Sequence[str]] = None
195 | ) -> tf.data.Dataset:
196 | """Loads a dataset and shards it based on jax devices."""
197 |   shard_index = shard_index if shard_index is not None else jax.process_index()
198 | file_names, parse_fn = load_filenames_and_parse_fn(
199 | path=path,
200 | tfrecord_prefix=tfrecord_prefix,
201 | )
202 |
203 | ds = tf.data.TFRecordDataset(file_names)
204 |
205 | threads = max(1, os.cpu_count() - 4)
206 | options = tf.data.Options()
207 | options.threading.private_threadpool_size = threads
208 | options.threading.max_intra_op_parallelism = 1
209 | ds = ds.with_options(options)
210 |
211 | # Shard if we don't shard by files
212 | if num_shards != 1:
213 | ds = ds.shard(num_shards, shard_index)
214 |
215 | # Parse the examples one by one
216 | if keys_to_preserve is not None:
217 | # Optionally also filter them based on the keys provided
218 | def parse_filter(example_proto):
219 | example = parse_fn(example_proto)
220 | return filter_based_on_keys(example, keys_to_preserve=keys_to_preserve)
221 | ds = ds.map(parse_filter, num_parallel_calls=tf.data.experimental.AUTOTUNE)
222 | else:
223 | ds = ds.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
224 |
225 | return ds
226 |
227 |
228 | def load_dataset(
229 | path: str,
230 | tfrecord_prefix: str,
231 | sub_sample_length: Optional[int],
232 | per_device_batch_size: int,
233 | num_epochs: Optional[int],
234 | drop_remainder: bool,
235 | multi_device: bool = False,
236 | num_shards: int = 1,
237 | shard_index: Optional[int] = None,
238 | keys_to_preserve: Optional[Sequence[str]] = None,
239 | shuffle: bool = False,
240 | cache: bool = True,
241 | shuffle_buffer: Optional[int] = 10000,
242 | dtype: str = "float32",
243 | seed: Optional[int] = None
244 | ) -> tf.data.Dataset:
245 | """Creates a tensorflow.Dataset pipeline from an TFRecord dataset.
246 |
247 | Args:
248 | path: The path to the dataset.
249 | tfrecord_prefix: The dataset prefix.
250 | sub_sample_length: The length of the sequences that will be returned.
251 |       If this is `None` the dataset will return full-length sequences.
252 |       If this is an `int`, a contiguous window of the provided length is
253 |       sampled uniformly at random from each sequence. Note that all examples
254 |       in the dataset must be at least this long, otherwise the code may crash.
255 | per_device_batch_size: The batch size to use on a single device. The actual
256 | batch size is this multiplied by the number of devices.
257 | num_epochs: The number of times to repeat the full dataset.
258 |     drop_remainder: If the number of examples in the dataset is not evenly
259 |       divisible by the batch size, whether to drop the remaining examples in
260 |       each epoch or to construct a final batch smaller than usual.
261 | multi_device: Whether to load the dataset prepared for multi-device use
262 | (e.g. pmap) with leading dimension equal to the number of local devices.
263 | num_shards: If you want to shard the dataset, you must specify how many
264 | shards you want to use.
265 | shard_index: The shard index for this host. If `None` will use
266 | `jax.process_index()`.
267 |     keys_to_preserve: Explicit specification of which keys to keep from the dataset.
268 | shuffle: Whether to shuffle examples in the dataset.
269 | cache: Whether to use cache in the tf.Dataset.
270 | shuffle_buffer: Size of the shuffling buffer.
271 | dtype: What data type to convert the data to.
272 | seed: Seed to pass to the loader.
273 | Returns:
274 | A tensorflow dataset object.
275 | """
276 | per_host_batch_size = per_device_batch_size * jax.local_device_count()
277 | # Preprocessing function
278 | batch_fn = functools.partial(
279 | preprocess_batch,
280 | num_local_devices=jax.local_device_count(),
281 | multi_device=multi_device,
282 | sub_sample_length=sub_sample_length,
283 | dtype=dtype)
284 |
285 | with tf.name_scope("dataset"):
286 | ds = load_parsed_dataset(
287 | path=path,
288 | tfrecord_prefix=tfrecord_prefix,
289 | num_shards=num_shards,
290 | shard_index=shard_index,
291 | keys_to_preserve=keys_to_preserve,
292 | )
293 | if cache:
294 | ds = ds.cache()
295 | if shuffle:
296 | ds = ds.shuffle(shuffle_buffer, seed=seed)
297 | ds = ds.repeat(num_epochs)
298 | ds = ds.batch(per_host_batch_size, drop_remainder=drop_remainder)
299 | ds = ds.map(batch_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
300 | return ds.prefetch(tf.data.experimental.AUTOTUNE)
301 |
302 |
303 | def dataset_as_iter(dataset_func, *args, **kwargs):
304 | def iterable_func():
305 | yield from tfds.as_numpy(dataset_func(*args, **kwargs))
306 | return iterable_func
307 |
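An illustrative pipeline sketch; the dataset path and tfrecord prefix below are hypothetical placeholders and are not shipped with the repository.

# Illustrative example; the path and prefix are hypothetical.
from dm_hamiltonian_dynamics_suite import load_datasets

ds_iter = load_datasets.dataset_as_iter(
    load_datasets.load_dataset,
    path="/tmp/datasets/mass_spring",   # Hypothetical dataset location.
    tfrecord_prefix="train",            # Assumed prefix, i.e. train.tfrecord.
    sub_sample_length=20,               # Random windows of 20 steps per sequence.
    per_device_batch_size=8,
    num_epochs=1,
    drop_remainder=True,
    shuffle=True,
)
for batch in ds_iter():       # Batches come out as numpy arrays via tfds.as_numpy.
  images = batch["image"]     # Shape [batch, time, height, width, channels].
  break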
--------------------------------------------------------------------------------
/dm_hamiltonian_dynamics_suite/molecular_dynamics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/dm_hamiltonian_dynamics_suite/molecular_dynamics/generate_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Creates a small molecular dynamics (MD) dataset.
16 |
17 | This binary creates a small MD dataset given a text-based trajectory file
18 | generated with the simulation package LAMMPS (https://lammps.sandia.gov/).
19 | The trajectory file can be generated by running LAMMPS with the input script
20 | provided.
21 |
22 | We note that this binary is intended as a demonstration only and is therefore
23 | not optimised for memory efficiency or performance.
24 | """
25 |
26 | from typing import Mapping, Tuple
27 |
28 | from dm_hamiltonian_dynamics_suite import datasets
29 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils
30 |
31 | import numpy as np
32 |
33 | Array = np.ndarray
34 |
35 |
36 | def read_trajectory(filename: str) -> Tuple[Array, float]:
37 | """Reads the trajectory data from file and returns it as an array.
38 |
39 | Each timestep contains a header and the atom data. The header is 9 lines long
40 | and contains the timestep, the number of atoms, the box dimensions and the
41 | list of atom properties. The header is assumed to be structured as in the
42 | example below:
43 |
44 |   ITEM: TIMESTEP
45 |   <timestep>
46 |   ITEM: NUMBER OF ATOMS
47 |   <num_atoms>
48 |   ITEM: BOX BOUNDS pp pp pp
49 |   <xlo> <xhi>
50 |   <ylo> <yhi>
51 |   <zlo> <zhi>
52 |   ITEM: ATOMS id type x y vx vy fx fy
53 |   .... <num_atoms> lines with properties of <num_atoms> atoms....
54 |
55 | Args:
56 | filename: name of the input file.
57 |
58 | Returns:
59 | A pair where the first element corresponds to an array of shape
60 | [num_timesteps, num_atoms, 6] containing the atom data and the second
61 | element corresponds to the edge length of the simulation box.
62 | """
63 | with open(filename, 'r') as f:
64 | dat = f.read()
65 | lines = dat.split('\n')
66 |
67 | # Extract the number of particles and the edge length of the simulation box.
68 | num_particles = int(lines[3])
69 | box = np.fromstring(lines[5], dtype=np.float32, sep=' ')
70 | box_length = box[1] - box[0]
71 |
72 | # Iterate over all timesteps and extract the relevant data columns.
73 | header_size = 9
74 | record_size = header_size + num_particles
75 | num_records = len(lines) // record_size
76 | records = []
77 | for i in range(num_records):
78 | record = lines[header_size + i * record_size:(i + 1) * record_size]
79 | record = np.array([l.split(' ')[:-1] for l in record], dtype=np.float32)
80 | records.append(record)
81 | records = np.array(records)[..., 2:]
82 | return records, box_length
83 |
84 |
85 | def flatten_record(x: Array) -> Array:
86 | """Reshapes input from [num_particles, 2*dim] to [2*num_particles*dim]."""
87 | if x.shape[-1] % 2 != 0:
88 | raise ValueError(f'Expected last dimension to be even, got {x.shape[-1]}.')
89 | dim = x.shape[-1] // 2
90 | q = x[..., 0:dim]
91 | p = x[..., dim:]
92 | q_flat = q.reshape(list(q.shape[:-2]) + [-1])
93 | p_flat = p.reshape(list(p.shape[:-2]) + [-1])
94 | x_flat = np.concatenate((q_flat, p_flat), axis=-1)
95 | return x_flat
96 |
97 |
98 | def render_images(
99 | x: Array,
100 | box_length: float,
101 | resolution: int = 32,
102 | particle_radius: float = 0.3
103 | ) -> Array:
104 | """Renders a sequence with shape [num_steps, num_particles*dim] as images."""
105 | dim = 2
106 | sequence_length, num_coordinates = x.shape
107 | if num_coordinates % (2 * dim) != 0:
108 | raise ValueError('Expected the number of coordinates to be divisible by 4, '
109 | f'got {num_coordinates}.')
110 | # The 4 coordinates are positions and velocities in 2d.
111 | num_particles = num_coordinates // (2 * dim)
112 | # `x` is formatted as [x_1, y_1,... x_N, y_N, dx_1, dy_1,..., dx_N, dy_N],
113 | # where `N=num_particles`. For the image generation, we only require x and y
114 | # coordinates.
115 | particles = x[..., :num_particles * dim]
116 | particles = particles.reshape((sequence_length, num_particles, dim))
117 | colors = np.arange(num_particles, dtype=np.int32)
118 | box_region = utils.BoxRegion(-box_length / 2., box_length / 2.)
119 | images = utils.render_particles_trajectory(
120 | particles=particles,
121 | particles_radius=particle_radius,
122 | color_indices=colors,
123 | canvas_limits=box_region,
124 | resolution=resolution,
125 | num_colors=num_particles)
126 | return images
127 |
128 |
129 | def convert_sequence(sequence: Array, box_length: float) -> Mapping[str, Array]:
130 | """Converts a sequence of timesteps to a data point."""
131 | num_steps, num_particles, num_fields = sequence.shape
132 | # A LAMMPS record should contain positions, velocities and forces.
133 | if num_fields != 6:
134 | raise ValueError('Expected input sequence to be of shape '
135 | f'[num_steps, num_particles, 6], got {sequence.shape}.')
136 | x = np.empty((num_steps, num_particles * 4))
137 | dx_dt = np.empty((num_steps, num_particles * 4))
138 | for step in range(num_steps):
139 | # Assign positions and momenta to `x` and momenta and forces to `dx_dt`.
140 | x[step] = flatten_record(sequence[step, :, (0, 1, 2, 3)])
141 | dx_dt[step] = flatten_record(sequence[step, :, (2, 3, 4, 5)])
142 |
143 | image = render_images(x, box_length)
144 | image = np.array(image * 255.0, dtype=np.uint8)
145 | return dict(x=x, dx_dt=dx_dt, image=image)
146 |
147 |
148 | def write_to_file(
149 | data: Array,
150 | box_length: float,
151 | output_path: str,
152 | split: str,
153 | overwrite: bool,
154 | ) -> None:
155 | """Writes the data to file."""
156 |
157 | def generator():
158 | for sequence in data:
159 | yield convert_sequence(sequence, box_length)
160 |
161 | datasets.transform_dataset(generator(), output_path, split, overwrite)
162 |
163 |
164 | def generate_lammps_dataset(
165 | lammps_file: str,
166 | folder: str,
167 | num_steps: int,
168 | num_train: int,
169 | num_test: int,
170 | dt: int,
171 | shuffle: bool,
172 | seed: int,
173 | overwrite: bool,
174 | ) -> None:
175 | """Creates the train and test datasets."""
176 | if num_steps < 1:
177 | raise ValueError(f'Expected `num_steps` to be >= 1, got {num_steps}.')
178 | if dt < 1:
179 | raise ValueError(f'Expected `dt` to be >= 1, got {dt}.')
180 |
181 | records, box_length = read_trajectory(lammps_file)
182 | # Consider only every dt-th timestep in the input file.
183 | records = records[::dt]
184 | num_records, num_particles, num_fields = records.shape
185 | if num_records < (num_test + num_train) * num_steps:
186 | raise ValueError(
187 |         f'Trajectory contains only {num_records} records, which is insufficient '
188 | f'for the requested train/test split of {num_train}/{num_test} with '
189 | f'sequence length {num_steps}.')
190 |
191 | # Reshape and shuffle the data.
192 | num_points = num_records // num_steps
193 | records = records[:num_points * num_steps]
194 | records = records.reshape((num_points, num_steps, num_particles, num_fields))
195 | if shuffle:
196 | np.random.RandomState(seed).shuffle(records)
197 |
198 | # Create train/test splits and write them to file.
199 | train_records = records[:num_train]
200 | test_records = records[num_train:num_train + num_test]
201 | print('Writing the train dataset to file.')
202 | write_to_file(train_records, box_length, folder, 'train', overwrite)
203 | print('Writing the test dataset to file.')
204 | write_to_file(test_records, box_length, folder, 'test', overwrite)
205 |
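A hypothetical invocation sketch: the output folder below is a placeholder, while the trajectory file name matches the `dump` command in lj_16.lmp.

# Illustrative example; the output folder is hypothetical.
from dm_hamiltonian_dynamics_suite.molecular_dynamics import generate_dataset

generate_dataset.generate_lammps_dataset(
    lammps_file="/tmp/trajectory_file.dat",  # Written by the dump command in lj_16.lmp.
    folder="/tmp/datasets/lj_16",            # Hypothetical output folder.
    num_steps=100,    # Sequence length of each data point.
    num_train=50,
    num_test=10,
    dt=10,            # Keep every 10th recorded timestep.
    shuffle=True,
    seed=0,
    overwrite=True,
)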
--------------------------------------------------------------------------------
/dm_hamiltonian_dynamics_suite/molecular_dynamics/lj_16.lmp:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | #
17 | #
18 | # Simulate a two-dimensional Lennard-Jones fluid at the thermodynamic state
19 | # T*=0.5 and rho*=0.77. See Smit and Frenkel (1991), J. Chem. Phys. for the
20 | # vapour-liquid phase diagram of that system (http://doi.org/10.1063/1.460477).
21 | # Further information on the individual LAMMPS commands in this script can be
22 | # found on the LAMMPS homepage https://lammps.sandia.gov/doc/Manual.html.
23 | #
24 | # Simulation protocol:
25 | # 1. Equilibrate the system in an NVT run using a Langevin thermostat and
26 | # compute the mean energy.
27 | # 2. Perform a subsequent NVE simulation and measure the mean energy.
28 | # 3. Scale all particle velocities such that the energy agrees with the mean
29 | # value of the NVT run.
30 | # 4. Perform an NVE production run. The mean temperature sampled during that
31 | # run should now be close to the target value.
32 | #
33 | # For the 4-particle dataset, please adjust the variables num_particles,
34 | # temperature_ref and density_ref, as suggested in the comments below.
35 |
36 | # number of particles to simulate (change to 4 for the 4-particle dataset)
37 | variable num_particles equal 16
38 |
39 | # reference temperature (change to 1.0 for the 4-particle dataset)
40 | variable temperature_ref equal 0.5
41 |
42 | # reference density (change to 0.04 for the 4-particle dataset)
43 | variable density_ref equal 0.77
44 |
45 | # damping coefficient for the Langevin thermostat
46 | variable temperature_damp equal 0.2
47 |
48 | # edge length of the square simulation box
49 | variable box_length equal exp(ln(${num_particles}/${density_ref})/2.)
50 |
51 | # Log to file every timestep.
52 | variable log_every equal 1
53 |
54 | # simulation timestep
55 | variable timestep equal 0.002
56 |
57 | # length of equilibration run
58 | variable runtime_equilibration equal 10000
59 |
60 | # length of production run
61 | variable runtime_production equal 1000
62 |
63 | # number of timesteps for both runs
64 | variable steps_equilibration equal floor(${runtime_equilibration}/${timestep})
65 | variable steps_production equal floor(${runtime_production}/${timestep})
66 |
67 | # random seed for coordinates, velocities and Langevin thermostat
68 | variable seed equal 1234
69 |
70 | # Initialise coordinates and velocities randomly.
71 | units lj
72 | atom_style atomic
73 | dimension 2
74 | boundary p p p
75 | variable box_length_half equal ${box_length}/2.0
76 | region domain block -${box_length_half} ${box_length_half} &
77 | -${box_length_half} ${box_length_half} &
78 | -0.1 0.1 units box
79 | create_box 1 domain
80 | create_atoms 1 random ${num_particles} ${seed} domain
81 | set type 1 z 0.0
82 | mass 1 1.0
83 | velocity all create ${temperature_ref} ${seed}
84 |
85 | # Zero out forces, velocities and z-coordinates of particles at every step.
86 | set type 1 z 0.0
87 | fix f2d all enforce2d
88 |
89 | # Define the interaction potential (pair-style).
90 | pair_style lj/cut ${box_length_half}
91 | pair_coeff 1 1 1.0 1.0
92 | neighbor 0.3 bin
93 | neigh_modify delay 0
94 |
95 | # Perform energy minimisation.
96 | minimize 1.0e-6 1.0e-8 100 1000
97 |
98 | # Perform the equilibration run.
99 | reset_timestep 0
100 | compute temperature all temp
101 | compute kinetic_energy all ke
102 | compute potential_energy all pe
103 | variable total_energy equal c_kinetic_energy+c_potential_energy
104 |
105 | # Compute the centre-of-mass velocity of the entire simulation box.
106 | variable vcmx equal "vcm(all,x)"
107 | variable vcmy equal "vcm(all,y)"
108 | variable vcmz equal "vcm(all,z)"
109 | variable vcm2 equal v_vcmx*v_vcmx+v_vcmy*v_vcmy+v_vcmz*v_vcmz
110 |
111 | # Apply a thermostat to all particles.
112 | compute langevin_temp all temp/partial 1 1 0
113 | fix flangevin all langevin ${temperature_ref} ${temperature_ref} &
114 | ${temperature_damp} ${seed} zero yes
115 | fix_modify flangevin temp langevin_temp
116 |
117 | # Specify the terminal output and frequency.
118 | thermo_style custom step c_temperature c_langevin_temp c_potential_energy &
119 | c_kinetic_energy v_total_energy press v_vcm2
120 | thermo_modify norm no
121 |
122 | thermo 100000
123 | timestep ${timestep}
124 |
125 | # Compute averages for temperature and energy.
126 | fix fnve all nve
127 |
128 | variable sample_size_equilibration equal ${steps_equilibration}/10
129 | fix fSampleEquilibration all ave/time 10 ${sample_size_equilibration} &
130 | ${steps_equilibration} c_temperature v_total_energy
131 | run ${steps_equilibration}
132 |
133 | variable temperature_nvt_equi equal $(f_fSampleEquilibration[1])
134 | variable energy_nvt_equi equal $(f_fSampleEquilibration[2])
135 | unfix fSampleEquilibration
136 | print "Averages for NVT equilibration run:"
137 | print "Reference temperature = ${temperature_ref}"
138 | print "_NVT = ${temperature_nvt_equi}"
139 | print "_NVT = ${energy_nvt_equi}"
140 |
141 | # Perform a short NVE simulation to estimate the mean energy.
142 | unfix flangevin
143 | reset_timestep 0
144 | fix fNVEPreAverage all ave/time 10 ${sample_size_equilibration} &
145 | ${steps_equilibration} c_temperature v_total_energy
146 | run ${steps_equilibration}
147 |
148 | variable temperature_nve_pre equal $(f_fNVEPreAverage[1])
149 | variable energy_nve_pre equal $(f_fNVEPreAverage[2])
150 | unfix fNVEPreAverage
151 | print "Averages for short NVE equilibration run:"
152 | print "INFO: _NVE = ${temperature_nve_pre}"
153 | print "INFO: _NVE = ${energy_nve_pre}"
154 |
155 | # Rescale velocities so that the energy is adjusted to the NVT average. This
156 | # makes sure that the system will be close to the reference temperature during
157 | # the production run even in the absence of a thermostat.
158 | reset_timestep 0
159 | variable dQ equal ${energy_nvt_equi}-${energy_nve_pre}
160 | variable dQdt1 equal ${dQ}/100/${timestep}
161 | fix fHeat all heat 1 ${dQdt1} region domain
162 | run 100
163 | unfix fHeat
164 |
165 | # Perform a second NVE equilibration run.
166 | reset_timestep 0
167 | run ${steps_equilibration}
168 |
169 | # Perform the production run.
170 | reset_timestep 0
171 |
172 | variable sample_size_production equal ${steps_production}/10
173 | fix fSampleProduction all ave/time 10 ${sample_size_production} &
174 | ${steps_production} c_temperature v_total_energy
175 |
176 | dump 1 all custom ${log_every} /tmp/trajectory_file.dat id type x y vx vy fx fy
177 | dump_modify 1 sort id
178 | fix 2 all ave/time ${log_every} 1 ${log_every} c_temperature c_kinetic_energy &
179 | c_potential_energy v_total_energy file /tmp/observable_file.dat
180 | run ${steps_production}
181 |
182 | variable temperature_prod equal $(f_fSampleProduction[1])
183 | variable energy_prod equal $(f_fSampleProduction[2])
184 | unfix fSampleProduction
185 |
186 | print "Averages for NVE production run:"
187 | print "_NVE = = ${temperature_prod}"
188 | print "_NVE = = ${energy_prod}"
189 |
--------------------------------------------------------------------------------
/dm_hamiltonian_dynamics_suite/multiagent_dynamics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/dm_hamiltonian_dynamics_suite/multiagent_dynamics/game_dynamics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Continuous-time two player zero-sum game dynamics."""
16 | from typing import Any, Mapping
17 |
18 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils
19 | import jax.numpy as jnp
20 | import jax.random as jnr
21 | import numpy as np
22 | from numpy.lib.stride_tricks import as_strided
23 | import scipy.integrate
24 | from open_spiel.python.egt import dynamics as egt_dynamics
25 | from open_spiel.python.egt import utils as egt_utils
26 | from open_spiel.python.pybind11 import pyspiel
27 |
28 |
29 | def sample_from_simplex(
30 | num_samples: int,
31 | rng: jnp.ndarray,
32 | dim: int = 3,
33 | vmin: float = 0.
34 | ) -> jnp.ndarray:
35 | """Samples random points from a k-simplex. See D. B. Rubin (1981), p131."""
36 | # This is a jax version of open_spiel.python.egt.utils.sample_from_simplex.
37 | assert vmin >= 0.
38 | p = jnr.uniform(rng, shape=(num_samples, dim - 1))
39 | p = jnp.sort(p, axis=1)
40 | p = jnp.hstack((jnp.zeros((num_samples, 1)), p, jnp.ones((num_samples, 1))))
41 | return (p[:, 1:] - p[:, 0:-1]) * (1 - 2 * vmin) + vmin
42 |
43 |
44 | def get_payoff_tensor(game_name: str) -> jnp.ndarray:
45 | """Returns the payoff tensor of a game."""
46 | game = pyspiel.extensive_to_tensor_game(pyspiel.load_game(game_name))
47 | assert game.get_type().utility == pyspiel.GameType.Utility.ZERO_SUM
48 | payoff_tensor = egt_utils.game_payoffs_array(game)
49 | return payoff_tensor
50 |
51 |
52 | def tile_array(a: jnp.ndarray, b0: int, b1: int) -> jnp.ndarray:
53 | r, c = a.shape # number of rows/columns
54 | rs, cs = a.strides # row/column strides
55 | x = as_strided(a, (r, b0, c, b1), (rs, 0, cs, 0)) # view a as larger 4D array
56 | return x.reshape(r * b0, c * b1) # create new 2D array
57 |
58 |
59 | class ZeroSumGame:
60 | """Generate trajectories from zero-sum game dynamics."""
61 |
62 | def __init__(
63 | self,
64 | game_name: str,
65 | dynamics: str = 'replicator',
66 | method: str = 'scipy'):
67 | self.payoff_tensor = get_payoff_tensor(game_name)
68 | assert self.payoff_tensor.shape[0] == 2, 'Only supports two-player games.'
69 | dyn_fun = getattr(egt_dynamics, dynamics)
70 | self.dynamics = egt_dynamics.MultiPopulationDynamics(
71 | self.payoff_tensor, dyn_fun)
72 | self.method = method
73 | self.scipy_ivp_kwargs = dict(rtol=1e-12, atol=1e-12)
74 |
75 | def sample_x0(self, num_samples: int, rng_key: jnp.ndarray) -> jnp.ndarray:
76 | """Samples initial states."""
77 | nrows, ncols = self.payoff_tensor.shape[1:]
78 | key1, key2 = jnr.split(rng_key)
79 | x0_1 = sample_from_simplex(num_samples, key1, dim=nrows)
80 | x0_2 = sample_from_simplex(num_samples, key2, dim=ncols)
81 | x0 = jnp.hstack((x0_1, x0_2))
82 | return x0
83 |
84 | def generate_trajectories(
85 | self,
86 | x0: jnp.ndarray,
87 | t0: utils.FloatArray,
88 | t_eval: jnp.ndarray
89 | ) -> jnp.ndarray:
90 | """Generates trajectories of the system in phase space.
91 |
92 | Args:
93 | x0: Initial state.
94 | t0: The time instance of the initial state y0.
95 | t_eval: Times at which to return the computed solution.
96 |
97 | Returns:
98 | Trajectories of size BxTxD (batch, time, phase-space-dim).
99 | """
100 | if self.method == 'scipy':
101 | x0_shape = x0.shape
102 |
103 | def fun(_, y):
104 | y = y.reshape(x0_shape)
105 | y_next = np.apply_along_axis(self.dynamics, -1, y)
106 | return y_next.reshape([-1])
107 |
108 | t_span = (t0, float(t_eval[-1]))
109 | solution = scipy.integrate.solve_ivp(
110 | fun=fun,
111 | t_span=t_span,
112 | y0=x0.reshape([-1]), # Scipy requires flat input.
113 | t_eval=t_eval,
114 | **self.scipy_ivp_kwargs)
115 | x = solution.y.reshape(x0_shape + (t_eval.size,))
116 | x = np.moveaxis(x, -1, 1) # Make time 2nd dimension.
117 | else:
118 | raise ValueError(f'Method={self.method} not supported.')
119 | return x
120 |
121 | def render_trajectories(self, x: jnp.ndarray) -> jnp.ndarray:
122 | """Maps from policies to joint-policy space."""
123 | nrows, ncols = self.payoff_tensor.shape[1:]
124 | x_1 = x[..., :nrows]
125 | x_2 = x[..., nrows:]
126 | x_1 = x_1.repeat(ncols, axis=-1).reshape(x.shape[:-1] + (nrows, ncols,))
127 | x_2 = x_2.repeat(nrows, axis=-1).reshape(x.shape[:-1] + (nrows, ncols,))
128 | x_2 = x_2.swapaxes(-2, -1)
129 | image = x_1 * x_2
130 |
131 |     # Rescale to 32 x 32 from the original 2x2 or 3x3 data by expanding the
132 |     # matrix to the smallest multiple of 2 or 3 that is at least 32, evenly
133 |     # tiling it with the original values, and then taking its top-left 32x32 slice.
134 | temp_image = [
135 | tile_array(x, np.ceil(32 / x.shape[0]).astype('int'),
136 | np.ceil(32 / x.shape[1]).astype('int'))[:32, :32]
137 | for x in np.squeeze(image)
138 | ]
139 | image = np.stack(temp_image)
140 | image = np.repeat(np.expand_dims(image, -1), 3, axis=-1)
141 |
142 | return image[None, ...]
143 |
144 | def generate_and_render(
145 | self,
146 | num_trajectories: int,
147 | rng_key: jnp.ndarray,
148 | t0: utils.FloatArray,
149 | t_eval: utils.FloatArray
150 | ) -> Mapping[str, Any]:
151 | """Generates trajectories and renders them.
152 |
153 | Args:
154 | num_trajectories: The number of trajectories to generate.
155 | rng_key: PRNG key for sampling any random numbers.
156 | t0: The time instance of the initial state y0.
157 | t_eval: Times at which to return the computed solution.
158 |
159 | Returns:
160 | A dictionary containing the following elements:
161 | 'x': A numpy array representation of the phase space vector.
162 | 'dx_dt': The time derivative of 'x'.
163 | 'image': An image representation of the state.
164 | """
165 | rng_key, key = jnr.split(rng_key)
166 | x0 = self.sample_x0(num_trajectories, key)
167 | x = self.generate_trajectories(x0, t0, t_eval)
168 | x = np.concatenate([x0[:, None], x], axis=1) # Add initial state.
169 | dx_dt = np.apply_along_axis(self.dynamics, -1, x)
170 | image = self.render_trajectories(x)
171 |
172 | return dict(x=x, dx_dt=dx_dt, image=image)
173 |
174 | def generate_and_render_dt(
175 | self,
176 | num_trajectories: int,
177 | rng_key: jnp.ndarray,
178 | t0: utils.FloatArray,
179 | dt: utils.FloatArray,
180 | num_steps: int
181 | ) -> Mapping[str, Any]:
182 | """Same as `generate_and_render` but uses `dt` and `num_steps`."""
183 | t_eval = utils.dt_to_t_eval(t0, dt, num_steps)
184 | return self.generate_and_render(num_trajectories, rng_key, t0, t_eval)
185 |
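186 | # Editor's illustrative sketch (not exercised by the library): for the square
187 | # payoff matrices used here, the joint-policy image built in
188 | # `render_trajectories` above is just the outer product of the two players'
189 | # mixed strategies, before it is tiled up to 32x32 and repeated over 3 colour
190 | # channels:
191 | #   >>> import numpy as np
192 | #   >>> x1, x2 = np.array([0.7, 0.3]), np.array([0.4, 0.6])
193 | #   >>> np.outer(x1, x2)
194 | #   array([[0.28, 0.42],
195 | #          [0.12, 0.18]])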
--------------------------------------------------------------------------------
/dm_hamiltonian_dynamics_suite/tests/test_datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Module for testing the generation and loading of datasets."""
16 | import os
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 |
21 | from dm_hamiltonian_dynamics_suite import datasets
22 | from dm_hamiltonian_dynamics_suite import load_datasets
23 |
24 | import jax
25 | from jax import numpy as jnp
26 | import tensorflow as tf
27 |
28 |
29 | DATASETS_TO_TEST = [
30 | "mass_spring",
31 | "pendulum_colors",
32 | "double_pendulum_colors",
33 | "two_body_colors",
34 | ]
35 |
36 | if datasets.open_spiel_available():
37 | DATASETS_TO_TEST += [
38 | "matching_pennies",
39 | "rock_paper_scissors",
40 | ]
41 |
42 |
43 | class TestToyDataset(parameterized.TestCase):
44 | """Tests for dataset generation and loading."""
45 |
46 | def compare_structures_all_the_same(self, example, batched_example):
47 | """Checks that the two examples are identical in structure and value."""
48 | self.assertEqual(
49 | jax.tree_structure(example),
50 | jax.tree_structure(batched_example),
51 | "Structures should be the same."
52 | )
53 | # The generated example's image is not dtype-converted, so convert it here to match the loaded one.
54 | example["image"] = tf.image.convert_image_dtype(
55 | example["image"], dtype=batched_example["image"].dtype).numpy()
56 | for v1, v2 in zip(jax.tree_leaves(example),
57 | jax.tree_leaves(batched_example)):
58 | self.assertEqual(v1.dtype, v2.dtype, "Dtypes should be the same.")
59 | self.assertEqual((1,) + v1.shape, v2.shape, "Shapes should be the same.")
60 | self.assertTrue(jnp.allclose(v1, v2[0]), "Values should be the same.")
61 |
62 | @parameterized.parameters(DATASETS_TO_TEST)
63 | def test_dataset(
64 | self,
65 | dataset,
66 | folder: str = "/tmp/dm_hamiltonian_dynamics_suite/tests/",
67 | dt: float = 0.1,
68 | num_steps: int = 100,
69 | steps_per_dt: int = 10,
70 | num_train: int = 10,
71 | num_test: int = 10,
72 | ):
73 | """Checks that the dataset generation and loading are working correctly."""
74 |
75 | # Generate the dataset
76 | train_examples, test_examples = datasets.generate_full_dataset(
77 | folder=folder,
78 | dataset=dataset,
79 | dt=dt,
80 | num_steps=num_steps,
81 | steps_per_dt=steps_per_dt,
82 | num_train=num_train,
83 | num_test=num_test,
84 | overwrite=True,
85 | return_generated_examples=True,
86 | )
87 |
88 | # Load train dataset
89 | dataset_path = dataset.lower() + "_dt_" + str(dt).replace(".", "_")
90 | ds = load_datasets.dataset_as_iter(
91 | load_datasets.load_dataset,
92 | path=os.path.join(folder, dataset_path),
93 | tfrecord_prefix="train",
94 | sub_sample_length=None,
95 | per_device_batch_size=1,
96 | num_epochs=1,
97 | drop_remainder=False,
98 | dtype="float64"
99 | )
100 | examples = tuple(x for x in ds())
101 | self.assertEqual(
102 | len(train_examples), len(examples),
103 | "Number of training examples not the same."
104 | )
105 | # Compare individual examples
106 | for example_1, example_2 in zip(train_examples, examples):
107 | self.compare_structures_all_the_same(example_1, example_2)
108 |
109 | # Load test dataset
110 | ds = load_datasets.dataset_as_iter(
111 | load_datasets.load_dataset,
112 | path=os.path.join(folder, dataset_path),
113 | tfrecord_prefix="test",
114 | sub_sample_length=None,
115 | per_device_batch_size=1,
116 | num_epochs=1,
117 | drop_remainder=False,
118 | dtype="float64"
119 | )
120 |
121 | examples = tuple(x for x in ds())
122 | self.assertEqual(
123 | len(test_examples), len(examples),
124 | "Number of test examples not the same."
125 | )
126 | # Compare individual examples
127 | for example_1, example_2 in zip(test_examples, examples):
128 | self.compare_structures_all_the_same(example_1, example_2)
129 |
130 |
131 | if __name__ == "__main__":
132 | jax.config.update("jax_enable_x64", True)  # Tests load and compare data in float64.
133 | absltest.main()
134 |
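135 | # A sketch of how to run these tests directly (assumes the package and its
136 | # requirements are installed): python dm_hamiltonian_dynamics_suite/tests/test_datasets.py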
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py>=0.12.0
2 | numpy>=1.16.4
3 | typing>=3.7.4.3
4 | scipy>=1.7.1
5 | open-spiel>=1.0.1
6 | tensorflow>=2.6.0
7 | tensorflow-datasets>=4.4.0
8 | jax==0.2.20
9 | jaxlib==0.1.71
10 |
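11 | # To install these dependencies (a sketch): pip install -r requirements.txt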
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 DeepMind Technologies Limited.
2 | #
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # https://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Setup for pip package."""
16 | from setuptools import setup
17 |
18 | REQUIRED_PACKAGES = (
19 | "absl-py>=0.12.0",
20 | "numpy>=1.16.4",
21 | "typing>=3.7.4.3",
22 | "scipy>=1.7.1",
23 | "open-spiel>=1.0.1",
24 | "tensorflow>=2.6.0",
25 | "tensorflow-datasets>=4.4.0",
26 | "jax==0.2.20",
27 | "jaxlib==0.1.71"
28 | )
29 |
30 | LONG_DESCRIPTION = "\n".join([(
31 | "A suite of 17 datasets with phase-space and high-dimensional (visual) "
32 | "observations, plus other measurements where appropriate, based on "
33 | "physical systems exhibiting Hamiltonian dynamics."
34 | )])
35 |
36 |
37 | setup(
38 | name="dm_hamiltonian_dynamics_suite",
39 | version="0.0.1",
40 | description="A collection of 17 datasets based on Hamiltonian physical "
41 | "systems.",
42 | long_description=LONG_DESCRIPTION,
43 | url="https://github.com/deepmind/dm_hamiltonian_dynamics_suite",
44 | author="DeepMind",
45 | packages=[
46 | "dm_hamiltonian_dynamics_suite",
47 | "dm_hamiltonian_dynamics_suite.hamiltonian_systems",
48 | "dm_hamiltonian_dynamics_suite.molecular_dynamics",
49 | "dm_hamiltonian_dynamics_suite.multiagent_dynamics",
50 | ],
51 | install_requires=REQUIRED_PACKAGES,
52 | platforms=["any"],
53 | license="Apache License, Version 2.0",
54 | )
55 |
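56 | # Example local installs (a sketch; run from the repository root):
57 | #   pip install .      # regular install
58 | #   pip install -e .   # editable install for development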
--------------------------------------------------------------------------------
/visualize_datasets.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "d68zSuu0TX2_"
7 | },
8 | "source": [
9 | "**Copyright 2020 DeepMind Technologies Limited.**\n",
10 | "\n",
11 | "\n",
12 | "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
13 | "you may not use this file except in compliance with the License.\n",
14 | "You may obtain a copy of the License at\n",
15 | "\n",
16 | "https://www.apache.org/licenses/LICENSE-2.0\n",
17 | "\n",
18 | "Unless required by applicable law or agreed to in writing, software\n",
19 | "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
20 | "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
21 | "See the License for the specific language governing permissions and\n",
22 | "limitations under the License."
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "metadata": {
29 | "cellView": "form",
30 | "id": "d5gKoACsMxPq"
31 | },
32 | "outputs": [],
33 | "source": [
34 | "#@title Install the package for loading the datasets\n",
35 | "!git clone https://github.com/deepmind/dm_hamiltonian_dynamics_suite.git\n",
36 | "!pip install ./dm_hamiltonian_dynamics_suite/"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {
43 | "cellView": "form",
44 | "id": "tpOOQRg_mJcU"
45 | },
46 | "outputs": [],
47 | "source": [
48 | "#@title Imports\n",
49 | "import functools\n",
50 | "import os\n",
51 | "import requests\n",
52 | "from subprocess import getstatusoutput\n",
53 | "from matplotlib import pyplot as plt\n",
54 | "from matplotlib import animation as plt_animation\n",
55 | "from matplotlib import rc\n",
56 | "import numpy as np\n",
57 | "from jax import config as jax_config\n",
58 | "import tensorflow as tf\n",
59 | "\n",
60 | "rc('animation', html='jshtml')\n",
61 | "jax_config.update(\"jax_enable_x64\", True)\n",
62 | "\n",
63 | "from dm_hamiltonian_dynamics_suite import load_datasets\n",
64 | "from dm_hamiltonian_dynamics_suite import datasets"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "metadata": {
71 | "cellView": "form",
72 | "id": "tQjT27Ymspqr"
73 | },
74 | "outputs": [],
75 | "source": [
76 | "#@title Helper functions\n",
77 | "DATASETS_URL = \"gs://dm-hamiltonian-dynamics-suite\"\n",
78 | "DATASETS_FOLDER = \"./datasets\" #@param {type: \"string\"}\n",
79 | "os.makedirs(DATASETS_FOLDER, exist_ok=True)\n",
80 | "\n",
81 | "def download_file(file_url, destination_file):\n",
82 | " print(\"Downloading\", file_url, \"to\", destination_file)\n",
83 | " command = f\"gsutil cp {file_url} {destination_file}\"\n",
84 | " status_code, output = getstatusoutput(command)\n",
85 | " if status_code != 0:\n",
86 | " raise ValueError(output)\n",
87 | "\n",
88 | "def download_dataset(dataset_name: str):\n",
89 | " \"\"\"Downloads the provided dataset from the DM Hamiltonian Dataset Suite\"\"\"\n",
90 | " destination_folder = os.path.join(DATASETS_FOLDER, dataset_name)\n",
91 | " dataset_url = os.path.join(DATASETS_URL, dataset_name)\n",
92 | " os.makedirs(destination_folder, exist_ok=True)\n",
93 | " if \"long_trajectory\" in dataset_name:\n",
94 | " files = (\"features.txt\", \"test.tfrecord\")\n",
95 | " else:\n",
96 | " files = (\"features.txt\", \"train.tfrecord\", \"test.tfrecord\")\n",
97 | " for file_name in files:\n",
98 | " file_url = os.path.join(dataset_url, file_name)\n",
99 | " destination_file = os.path.join(destination_folder, file_name)\n",
100 | " if os.path.exists(destination_file):\n",
101 | " print(\"File\", file_url, \"already present.\")\n",
102 | " continue\n",
103 | " download_file(file_url, destination_file)\n",
104 | "\n",
105 | "\n",
106 | "def unstack(value: np.ndarray, axis: int = 0):\n",
107 | " \"\"\"Unstacks an array along an axis into a list\"\"\"\n",
108 | " split = np.split(value, value.shape[axis], axis=axis)\n",
109 | " return [np.squeeze(v, axis=axis) for v in split]\n",
110 | "\n",
111 | "\n",
112 | "def make_batch_grid(\n",
113 | " batch: np.ndarray, \n",
114 | " grid_height: int,\n",
115 | " grid_width: int, \n",
116 | " with_padding: bool = True):\n",
117 | " \"\"\"Makes a single grid image from a batch of multiple images.\"\"\"\n",
118 | " assert batch.ndim == 5\n",
119 | " assert grid_height * grid_width \u003e= batch.shape[0]\n",
120 | " batch = batch[:grid_height * grid_width]\n",
121 | " batch = batch.reshape((grid_height, grid_width) + batch.shape[1:])\n",
122 | " if with_padding:\n",
123 | " batch = np.pad(batch, pad_width=[[0, 0], [0, 0], [0, 0],\n",
124 | " [1, 0], [1, 0], [0, 0]],\n",
125 | " mode=\"constant\", constant_values=1.0)\n",
126 | " batch = np.concatenate(unstack(batch), axis=-3)\n",
127 | " batch = np.concatenate(unstack(batch), axis=-2)\n",
128 | " if with_padding:\n",
129 | " batch = batch[:, 1:, 1:]\n",
130 | " return batch\n",
131 | "\n",
132 | "\n",
133 | "def plot_animattion_from_batch(\n",
134 | " batch: np.ndarray, \n",
135 | " grid_height, \n",
136 | " grid_width, \n",
137 | " with_padding=True, \n",
138 | " figsize=None):\n",
139 | " \"\"\"Plots an animation of the batch of sequences.\"\"\"\n",
140 | " if figsize is None:\n",
141 | " figsize = (grid_width, grid_height)\n",
142 | " batch = make_batch_grid(batch, grid_height, grid_width, with_padding)\n",
143 | " batch = batch[:, ::-1]\n",
144 | " fig = plt.figure(figsize=figsize)\n",
145 | " plt.close()\n",
146 | " ax = fig.add_subplot(1, 1, 1)\n",
147 | " ax.axis('off')\n",
148 | " img = ax.imshow(batch[0]) \n",
149 | " def frame_update(i):\n",
150 | " i = int(np.floor(i).astype(\"int64\"))\n",
151 | " img.set_data(batch[i])\n",
152 | " return [img]\n",
153 | " anim = plt_animation.FuncAnimation(\n",
154 | " fig=fig, \n",
155 | " func=frame_update,\n",
156 | " frames=np.linspace(0.0, len(batch), len(batch) * 5 + 1)[:-1],\n",
157 | " save_count=len(batch),\n",
158 | " interval=10, \n",
159 | " blit=True\n",
160 | " )\n",
161 | " return anim\n",
162 | "\n",
163 | "\n",
164 | "def plot_sequence_from_batch(\n",
165 | " batch: np.ndarray,\n",
166 | " t_start: int = 0,\n",
167 | " with_padding: bool = True, \n",
168 | " fontsize: int = 20):\n",
169 | " \"\"\"Plots all of the sequences in the batch.\"\"\"\n",
170 | " n, t, dx, dy = batch.shape[:-1]\n",
171 | " xticks = np.linspace(dx // 2, t * (dx + 1) - 1 - dx // 2, t)\n",
172 | " xtick_labels = np.arange(t) + t_start\n",
173 | " yticks = np.linspace(dy // 2, n * (dy + 1) - 1 - dy // 2, n)\n",
174 | " ytick_labels = np.arange(n)\n",
175 | " batch = batch.reshape((n * t, 1) + batch.shape[2:])\n",
176 | " batch = make_batch_grid(batch, n, t, with_padding)[0]\n",
177 | " plt.imshow(batch.squeeze())\n",
178 | " plt.xticks(ticks=xticks, labels=xtick_labels, fontsize=fontsize)\n",
179 | " plt.yticks(ticks=yticks, labels=ytick_labels, fontsize=fontsize)\n",
180 | "\n",
181 | "\n",
182 | "def visualize_dataset(\n",
183 | " dataset_path: str,\n",
184 | " sequence_lengths: int = 60,\n",
185 | " grid_height: int = 2, \n",
186 | " grid_width: int = 5):\n",
187 | " \"\"\"Visualizes a dataset loaded from the path provided.\"\"\"\n",
188 | " split = \"test\"\n",
189 | " batch_size = grid_height * grid_width\n",
190 | " dataset = load_datasets.load_dataset(\n",
191 | " path=dataset_path,\n",
192 | " tfrecord_prefix=split,\n",
193 | " sub_sample_length=sequence_lengths,\n",
194 | " per_device_batch_size=batch_size,\n",
195 | " num_epochs=None,\n",
196 | " drop_remainder=True,\n",
197 | " shuffle=False,\n",
198 | " shuffle_buffer=100\n",
199 | " )\n",
200 | " sample = next(iter(dataset))\n",
201 | " batch_x = sample['x'].numpy()\n",
202 | " batch_image = sample['image'].numpy()\n",
203 | " # Plot real system dimensions\n",
204 | " plt.figure(figsize=(24, 8))\n",
205 | " for i in range(batch_x.shape[-1]):\n",
206 | " plt.subplot(1, batch_x.shape[-1], i + 1)\n",
207 | " plt.title(f\"Samples from dimension {i+1}\")\n",
208 | " plt.plot(batch_x[:, :, i].T)\n",
209 | " plt.show()\n",
210 | " # Plot a sequence of 50 images\n",
211 | " plt.figure(figsize=(30, 10))\n",
212 | " plt.title(\"Samples from 50 steps sub sequences.\")\n",
213 | " plot_sequence_from_batch(batch_image[:, :50])\n",
214 | " plt.show()\n",
215 | " # Plot animation\n",
216 | " return plot_animattion_from_batch(batch_image, grid_height, grid_width)"
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "execution_count": null,
222 | "metadata": {
223 | "cellView": "form",
224 | "id": "Wr0jMkPmLzdA"
225 | },
226 | "outputs": [],
227 | "source": [
228 | "#@title Generate a small dataset and visualize it\n",
229 | "folder_to_store = \"./generated_datasets\" #@param {type:\"string\"}\n",
230 | "dataset = \"pendulum_colors\" #@param {type: \"string\"}\n",
231 | "dt = 0.1 #@param {type: \"number\"}\n",
232 | "num_steps = 100 #@param {type: \"integer\"}\n",
233 | "steps_per_dt = 1 #@param {type: \"integer\"}\n",
234 | "num_train = 100 #@param {type: \"integer\"}\n",
235 | "num_test = 10 #@param {type: \"integer\"}\n",
236 | "overwrite = True #@param {type: \"boolean\"}\n",
237 | "datasets.generate_full_dataset(\n",
238 | " folder=folder_to_store,\n",
239 | " dataset=dataset,\n",
240 | " dt=dt,\n",
241 | " num_steps=num_steps,\n",
242 | " steps_per_dt=steps_per_dt,\n",
243 | " num_train=num_train,\n",
244 | " num_test=num_test,\n",
245 | " overwrite=overwrite,\n",
246 | ")\n",
247 | "dataset_full_name = dataset + \"_dt_\" + str(dt).replace(\".\", \"_\")\n",
248 | "dataset_path = os.path.join(folder_to_store, dataset_full_name)\n",
249 | "visualize_dataset(dataset_path)"
250 | ]
251 | },
252 | {
253 | "cell_type": "code",
254 | "execution_count": null,
255 | "metadata": {
256 | "cellView": "form",
257 | "id": "d6Fn8eaBentd"
258 | },
259 | "outputs": [],
260 | "source": [
261 | "#@title Download and visualise a dataset\n",
262 | "dataset_name = \"toy_physics/mass_spring\" #@param [\"toy_physics/mass_spring\", \"toy_physics/mass_spring_colors\", \"toy_physics/mass_spring_colors_friction\", \"toy_physics/mass_spring_long_trajectory\", \"toy_physics/mass_spring_colors_long_trajectory\", \"toy_physics/pendulum\", \"toy_physics/pendulum_colors\", \"toy_physics/pendulum_colors_friction\", \"toy_physics/pendulum_long_trajectory\", \"toy_physics/pendulum_colors_long_trajectory\", \"toy_physics/double_pendulum\", \"toy_physics/double_pendulum_colors\", \"toy_physics/double_pendulum_colors_friction\", \"toy_physics/two_body\", \"toy_physics/two_body_colors\", \"molecular_dynamics/lj_4\", \"molecular_dynamics/lj_16\", \"multi_agent/matching_pennies\", \"multi_agent/matching_pennies_long_trajectory\", \"multi_agent/rock_paper_scissors\", \"multi_agent/rock_paper_scissors_long_trajectory\", \"mujoco_room/circle\", \"mujoco_room/spiral\"]\n",
263 | "download_dataset(dataset_name)\n",
264 | "visualize_dataset(os.path.join(DATASETS_FOLDER, dataset_name))"
265 | ]
266 | }
267 | ],
268 | "metadata": {
269 | "colab": {
270 | "collapsed_sections": [],
271 | "last_runtime": {},
272 | "name": "visualize_datasets.ipynb",
273 | "provenance": []
274 | },
275 | "kernelspec": {
276 | "display_name": "Python 3",
277 | "name": "python3"
278 | }
279 | },
280 | "nbformat": 4,
281 | "nbformat_minor": 0
282 | }
283 |
--------------------------------------------------------------------------------