├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── dm_hamiltonian_dynamics_suite ├── .pylintrc ├── __init__.py ├── datasets.py ├── generate_dataset.py ├── generate_locally.sh ├── hamiltonian_systems │ ├── __init__.py │ ├── hamiltonian.py │ ├── ideal_double_pendulum.py │ ├── ideal_mass_spring.py │ ├── ideal_pendulum.py │ ├── n_body.py │ ├── phase_space.py │ ├── simple_analytic.py │ └── utils.py ├── load_datasets.py ├── molecular_dynamics │ ├── __init__.py │ ├── generate_dataset.py │ └── lj_16.lmp ├── multiagent_dynamics │ ├── __init__.py │ └── game_dynamics.py └── tests │ └── test_datasets.py ├── requirements.txt ├── setup.py └── visualize_datasets.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.egg-info 3 | *.csv 4 | *.npz 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to <https://cla.developers.google.com/> to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose.
Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows [Google's Open Source Community 28 | Guidelines](https://opensource.google.com/conduct/). 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepMind Hamiltonian Dynamics Suite 2 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/dm_hamiltonian_dynamics_suite/blob/master/visualize_datasets.ipynb) 3 | 4 | The repository contains the code used for generating the 5 | [DM Hamiltonian Dynamics Suite]. 6 | 7 | The code for the models and experiments in our paper can be found 8 | [here](https://github.com/deepmind/deepmind-research/tree/master/physics_inspired_models), 9 | together with the code used for our concurrent publication on how to measure the 10 | quality of the learnt dynamics in models using the Hamiltonian prior when 11 | learning from pixels. 
12 | 13 | 14 | ## Datasets 15 | 16 | The suite contains 17 datasets ranging from simple physics problems 17 | (Toy Physics) to more realistic datasets of molecular dynamics 18 | (Molecular Dynamics), learning dynamics in non-transitive zero-sum games 19 | (Multi Agent), and motion in 3D simulated environments (Mujoco Room). The 20 | datasets vary in terms of the complexity of simulated dynamics and visual 21 | richness. 22 | 23 | For each dataset we created 50k training trajectories and 20k test trajectories, 24 | with each trajectory including image observations, ground truth phase state 25 | used to generate the data, the first time derivative of the ground truth state, 26 | and any hyper-parameters of individual trajectories. For a few of the datasets 27 | we generate a small number of long trajectories which are used purely for 28 | evaluation. 29 | 30 | 31 | ### Toy Physics 32 | 33 | For all simulated systems we sample the trajectories at 34 | `Δt = 0.05` intervals. For any of the non-conservative variants of each 35 | dataset we set the friction coefficient to `0.05`. All hyper-parameters which 36 | can be randomized are always sampled and kept fixed throughout each trajectory. 37 | The colors of the particles, when being randomized, are always sampled uniformly 38 | from some fixed number of options. The exact configurations used for generating 39 | the datasets in the suite can be found in the [datasets.py] file. All systems 40 | were simulated using `scipy.integrate.solve_ivp`. 41 | 42 | #### Mass Spring 43 | 44 | This dataset describes the simple harmonic motion of a particle attached to a 45 | spring. The system has two hyper-parameters - the spring force coefficient `k` 46 | and the mass of the particle `m`. The initial positions and momenta are sampled 47 | jointly from an annulus, where the radius is in the interval 48 | [rlow, rhigh].
One can choose whether the distribution to 49 | sample from is uniform in the annulus, or otherwise to sample uniformly the 50 | length of the radius. To render the system on an image we visualize just the 51 | particle as a circle, with a radius proportional to the square root of its mass. 52 | When rendering, there are also three additional hyper-parameters - whether to 53 | randomize the horizontal position of the particle, since its motion is only 54 | vertical, whether to also shift around in any direction the anchor point of the 55 | spring, and how many possible colors the circle representing the particle 56 | can take. Finally, one can additionally simulate a non-conservative system by 57 | setting the friction coefficient to non-zero. 58 | 59 | #### Pendulum 60 | 61 | This dataset describes the evolution of a particle attached to a pivot, such 62 | that it can move freely. The system is simulated in angle space, such that it is 63 | one-dimensional. It has three hyper-parameters - the mass of the particle `m`, 64 | the gravitational constant `g` and the pivot length `l`. The initial positions 65 | and momenta are sampled jointly from an annulus, where the radius is in the 66 | interval [rlow, rhigh]. One can choose whether the 67 | distribution to sample from is uniform in the annulus, or otherwise to sample 68 | uniformly the length of the radius. To render the system on an image we 69 | visualize just the particle as a circle, with a radius proportional to the square 70 | root of its mass. When rendering, there are also two additional 71 | hyper-parameters - whether to also shift around in any direction the anchor 72 | point of the pivot and how many possible colors the circle representing the 73 | particle can take. Finally, one can additionally simulate a non-conservative 74 | system by setting the friction coefficient to non-zero.
75 | 76 | #### Double Pendulum 77 | 78 | This dataset describes the evolution of two coupled pendulums, where the second 79 | one's anchor point of its pivot is the center of the particle of the first one. 80 | This leads to significantly more complicated dynamics [2]. All the 81 | hyper-parameters are equivalent to those in the Pendulum dataset and follow the 82 | exact same protocol. 83 | 84 | #### Two Body 85 | 86 | This dataset describes the gravitational motion of two particles in the plane. 87 | The system has three hyper-parameters - the masses of the two particles `m_1` 88 | and `m_2` and the gravitational constant `g`. The positions and momenta of each 89 | particle are sampled jointly from an annulus, where the radius is in the 90 | interval [rlow, rhigh]. To render the system on an image 91 | we visualize just each particle as a circle, with a radius proportional to the 92 | square root of its mass. When rendering, there are also two additional 93 | hyper-parameters - whether to also shift around in any direction the center of 94 | mass of the system and how many possible colors the circles representing the 95 | particles can take. 96 | 97 | 98 | ### Multi Agent 99 | 100 | These datasets describe the dynamics of non-transitive zero-sum games. Here we 101 | consider two prominent examples of such games: **matching pennies** and 102 | **rock-paper-scissors**. We use the well-known continuous-time multi-population 103 | replicator dynamics to drive the learning process. The ground-truth trajectories 104 | are generated by integrating the coupled set of ODEs using an improved Euler 105 | scheme or RK45. In both cases the ground-truth state, i.e., joint strategy 106 | profile (joint policy), and its first order time derivative, is recorded at 107 | regular time intervals `Δt = 0.1`. Trajectories start from uniformly sampled 108 | points on the product of the policy simplexes. No noise is added to the 109 | trajectories.
110 | 111 | As all other datasets use images as inputs, we define the observation as the 112 | outer product of the strategy profiles of the two players. The resulting matrix 113 | captures the probability mass that falls on each pure joint strategy profile 114 | (joint action). In this dataset, the observations are a loss-less representation 115 | of the ground-truth state and are upsampled to `32 x 32 x 3` images through 116 | tiling. 117 | 118 | 119 | ### Mujoco Room 120 | 121 | These datasets are composed of multiple scenes each consisting of a camera 122 | moving around a room with 5 randomly placed objects. The objects were sampled 123 | from four shape types: a sphere, a capsule, a cylinder and a box. Each room was 124 | different due to the randomly sampled colors of the wall, floor and objects. The 125 | dynamics were created by motion and rotation of the camera. The **circle** 126 | dataset is generated by rotating the camera around a single randomly sampled 127 | parallel of the unit hemisphere centered around the middle of the room. The 128 | **spiral** dataset is generated by rotating the camera on a spiral moving down 129 | the unit hemisphere. For each trajectory an initial radius and angle are sampled 130 | and then converted into the Cartesian coordinates of the camera. The dynamics 131 | are discretised by moving the camera using a step size of `0.1` degrees in a way 132 | that keeps the camera on the unit hemisphere while facing the center of the 133 | room. For the **spiral** dataset, the camera path traces out a golden spiral 134 | starting at the height corresponding to the originally sampled radius on the 135 | unit hemisphere. The rendered scenes are used as observations, and the Cartesian 136 | coordinates of the camera and its velocities estimated through finite 137 | differences as the state. Each trajectory was generated using [MuJoCo].
138 | 139 | ### Molecular Dynamics 140 | 141 | These datasets are based on a type of interaction potential commonly studied 142 | using computer simulation techniques, such as molecular dynamics or Monte Carlo 143 | simulations. In particular, we generated two datasets of increasing complexity 144 | employing a Lennard-Jones potential: one comprising only 4 particles at a very 145 | low density and another one for a 16-particle liquid at a higher density. For 146 | rendering these datasets we used the same scheme as for the Toy Physics datasets. 147 | All masses are set to unity and we represent particles by circles of equal size 148 | with a radius value adjusted to fit the canvas well. The illustrations are 149 | therefore not representative of the density of the system. In addition, we 150 | assigned different colors to the particles to facilitate tracking their 151 | trajectories. 152 | 153 | We created the datasets in two steps: we first generated the raw molecular 154 | dynamics data using the simulation software [LAMMPS], and then converted the 155 | resulting trajectories into a trainable format. For the final datasets available 156 | for download, we combined simulation data from 100 different molecular dynamics 157 | trajectories, each corresponding to a different random initialization 158 | (see Appendix 1.3 of the paper for details). Here we provide a LAMMPS input 159 | script [lj_16.lmp] to generate data for a single seed and a script 160 | [generate_dataset.py] to turn the text-based simulation output into 161 | a trainable format. By default, the simulation is set up for the 16-particle 162 | system, but we provide inline comments on which lines need changing for the 163 | 4-particle dataset. 164 | 165 | 166 | 167 | ## Installation 168 | 169 | All package requirements are listed in `requirements.txt`.
To install the code 170 | run the following commands in your shell: 171 | 172 | ```shell 173 | git clone https://github.com/deepmind/dm_hamiltonian_dynamics_suite 174 | pip install -r ./dm_hamiltonian_dynamics_suite/requirements.txt 175 | pip install ./dm_hamiltonian_dynamics_suite 176 | pip install --upgrade "jax[XXX]" 177 | ``` 178 | 179 | where `XXX` is the correct type of accelerator that you have on your machine. 180 | Note that if you are using a GPU you might need `XXX` to also include the 181 | correct version of CUDA and cuDNN installed on your machine. 182 | For more details please read [here](https://github.com/google/jax#installation). 183 | 184 | 185 | ## Usage 186 | 187 | You can find an example of how to generate a dataset and then load and visualize 188 | it in the [Colab notebook provided]. 189 | 190 | 191 | ## References 192 | 193 | **Which priors matter? Benchmarking models for learning latent dynamics** 194 | 195 | Aleksandar Botev, Drew Jaegle, Peter Wirnsberger, Daniel Hennes and Irina 196 | Higgins 197 | 198 | URL: https://openreview.net/forum?id=qBl8hnwR0px 199 | 200 | **SyMetric: Measuring the Quality of Learnt Hamiltonian Dynamics Inferred from Vision** 201 | 202 | Irina Higgins, Peter Wirnsberger, Andrew Jaegle, Aleksandar Botev 203 | 204 | URL: https://openreview.net/forum?id=9Qu0U9Fj7IP 205 | 206 | ## Disclaimer 207 | 208 | This is not an official Google product.
209 | 210 | [DM Hamiltonian Dynamics Suite]: https://console.cloud.google.com/storage/browser/dm-hamiltonian-dynamics-suite 211 | [Colab notebook provided]: https://colab.research.google.com/github/deepmind/dm_hamiltonian_dynamics_suite/blob/master/visualize_datasets.ipynb 212 | [datasets.py]: https://github.com/deepmind/dm_hamiltonian_dynamics_suite/blob/master/dm_hamiltonian_dynamics_suite/datasets.py 213 | [lj_16.lmp]: https://github.com/deepmind/dm_hamiltonian_dynamics_suite/blob/master/dm_hamiltonian_dynamics_suite/molecular_dynamics/lj_16.lmp 214 | [generate_dataset.py]: https://github.com/deepmind/dm_hamiltonian_dynamics_suite/blob/master/dm_hamiltonian_dynamics_suite/molecular_dynamics/generate_dataset.py 215 | [LAMMPS]: https://lammps.sandia.gov/ 216 | [MuJoCo]: http://www.mujoco.org/ 217 | 218 | -------------------------------------------------------------------------------- /dm_hamiltonian_dynamics_suite/.pylintrc: -------------------------------------------------------------------------------- 1 | # This Pylint rcfile contains a best-effort configuration to uphold the 2 | # best-practices and style described in the Google Python style guide: 3 | # https://google.github.io/styleguide/pyguide.html 4 | # 5 | # Its canonical open-source location is: 6 | # https://google.github.io/styleguide/pylintrc 7 | 8 | [MASTER] 9 | 10 | # Files or directories to be skipped. They should be base names, not paths. 11 | ignore=third_party 12 | 13 | # Files or directories matching the regex patterns are skipped. The regex 14 | # matches against base names, not paths. 15 | ignore-patterns= 16 | 17 | # Pickle collected data for later comparisons. 18 | persistent=no 19 | 20 | # List of plugins (as comma separated values of python modules names) to load, 21 | # usually to register additional checkers. 22 | load-plugins= 23 | 24 | # Use multiple processes to speed up Pylint. 25 | jobs=4 26 | 27 | # Allow loading of arbitrary C extensions.
Extensions are imported into the 28 | # active Python interpreter and may run arbitrary code. 29 | unsafe-load-any-extension=no 30 | 31 | 32 | [MESSAGES CONTROL] 33 | 34 | # Only show warnings with the listed confidence levels. Leave empty to show 35 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 36 | confidence= 37 | 38 | # Enable the message, report, category or checker with the given id(s). You can 39 | # either give multiple identifiers separated by comma (,) or put this option 40 | # multiple times (only on the command line, not in the configuration file where 41 | # it should appear only once). See also the "--disable" option for examples. 42 | #enable= 43 | 44 | # Disable the message, report, category or checker with the given id(s). You 45 | # can either give multiple identifiers separated by comma (,) or put this 46 | # option multiple times (only on the command line, not in the configuration 47 | # file where it should appear only once). You can also use "--disable=all" to 48 | # disable everything first and then re-enable specific checks. For example, if 49 | # you want to run only the similarities checker, you can use "--disable=all 50 | # --enable=similarities".
If you want to run only the classes checker, but have 51 | # no Warning level messages displayed, use"--disable=all --enable=classes 52 | # --disable=W" 53 | disable=abstract-method, 54 | apply-builtin, 55 | arguments-differ, 56 | attribute-defined-outside-init, 57 | backtick, 58 | bad-option-value, 59 | basestring-builtin, 60 | buffer-builtin, 61 | c-extension-no-member, 62 | consider-using-enumerate, 63 | cmp-builtin, 64 | cmp-method, 65 | coerce-builtin, 66 | coerce-method, 67 | delslice-method, 68 | div-method, 69 | duplicate-code, 70 | eq-without-hash, 71 | execfile-builtin, 72 | file-builtin, 73 | filter-builtin-not-iterating, 74 | fixme, 75 | getslice-method, 76 | global-statement, 77 | hex-method, 78 | idiv-method, 79 | implicit-str-concat-in-sequence, 80 | import-error, 81 | import-self, 82 | import-star-module-level, 83 | inconsistent-return-statements, 84 | input-builtin, 85 | intern-builtin, 86 | invalid-str-codec, 87 | locally-disabled, 88 | long-builtin, 89 | long-suffix, 90 | map-builtin-not-iterating, 91 | misplaced-comparison-constant, 92 | missing-function-docstring, 93 | metaclass-assignment, 94 | next-method-called, 95 | next-method-defined, 96 | no-absolute-import, 97 | no-else-break, 98 | no-else-continue, 99 | no-else-raise, 100 | no-else-return, 101 | no-init, # added 102 | no-member, 103 | no-name-in-module, 104 | no-self-use, 105 | nonzero-method, 106 | oct-method, 107 | old-division, 108 | old-ne-operator, 109 | old-octal-literal, 110 | old-raise-syntax, 111 | parameter-unpacking, 112 | print-statement, 113 | raising-string, 114 | range-builtin-not-iterating, 115 | raw_input-builtin, 116 | rdiv-method, 117 | reduce-builtin, 118 | relative-import, 119 | reload-builtin, 120 | round-builtin, 121 | setslice-method, 122 | signature-differs, 123 | standarderror-builtin, 124 | suppressed-message, 125 | sys-max-int, 126 | too-few-public-methods, 127 | too-many-ancestors, 128 | too-many-arguments, 129 | too-many-boolean-expressions, 130 | 
too-many-branches, 131 | too-many-instance-attributes, 132 | too-many-locals, 133 | too-many-nested-blocks, 134 | too-many-public-methods, 135 | too-many-return-statements, 136 | too-many-statements, 137 | trailing-newlines, 138 | unichr-builtin, 139 | unicode-builtin, 140 | unnecessary-pass, 141 | unpacking-in-except, 142 | useless-else-on-loop, 143 | useless-object-inheritance, 144 | useless-suppression, 145 | using-cmp-argument, 146 | wrong-import-order, 147 | xrange-builtin, 148 | zip-builtin-not-iterating, 149 | 150 | 151 | [REPORTS] 152 | 153 | # Set the output format. Available formats are text, parseable, colorized, msvs 154 | # (visual studio) and html. You can also give a reporter class, eg 155 | # mypackage.mymodule.MyReporterClass. 156 | output-format=text 157 | 158 | # Put messages in a separate file for each module / package specified on the 159 | # command line instead of printing them on stdout. Reports (if any) will be 160 | # written in a file name "pylint_global.[txt|html]". This option is deprecated 161 | # and it will be removed in Pylint 2.0. 162 | files-output=no 163 | 164 | # Tells whether to display a full report or only the messages 165 | reports=no 166 | 167 | # Python expression which should return a note less than 10 (10 is the highest 168 | # note). You have access to the variables errors warning, statement which 169 | # respectively contain the number of errors / warnings messages and the total 170 | # number of statements analyzed. This is used by the global evaluation report 171 | # (RP0004). 172 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 173 | 174 | # Template used to display messages. This is a python new-style format string 175 | # used to format the message information. 
See doc for all details 176 | #msg-template= 177 | 178 | 179 | [BASIC] 180 | 181 | # Good variable names which should always be accepted, separated by a comma 182 | good-names=main,_ 183 | 184 | # Bad variable names which should always be refused, separated by a comma 185 | bad-names= 186 | 187 | # Colon-delimited sets of names that determine each other's naming style when 188 | # the name regexes allow several styles. 189 | name-group= 190 | 191 | # Include a hint for the correct naming format with invalid-name 192 | include-naming-hint=no 193 | 194 | # List of decorators that produce properties, such as abc.abstractproperty. Add 195 | # to this list to register other decorators that produce valid properties. 196 | property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl 197 | 198 | # Regular expression matching correct function names 199 | function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$ 200 | 201 | # Regular expression matching correct variable names 202 | variable-rgx=^[a-z][a-z0-9_]*$ 203 | 204 | # Regular expression matching correct constant names 205 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 206 | 207 | # Regular expression matching correct attribute names 208 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ 209 | 210 | # Regular expression matching correct argument names 211 | argument-rgx=^[a-z][a-z0-9_]*$ 212 | 213 | # Regular expression matching correct class attribute names 214 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 215 | 216 | # Regular expression matching correct inline iteration names 217 | inlinevar-rgx=^[a-z][a-z0-9_]*$ 218 | 219 | # Regular expression matching correct class names 220 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$ 221 | 222 | # Regular expression matching correct module names 223 |
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ 224 | 225 | # Regular expression matching correct method names 226 | method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$ 227 | 228 | # Regular expression which should only match function or class names that do 229 | # not require a docstring. 230 | no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ 231 | 232 | # Minimum line length for functions/classes that require docstrings, shorter 233 | # ones are exempt. 234 | docstring-min-length=10 235 | 236 | 237 | [TYPECHECK] 238 | 239 | # List of decorators that produce context managers, such as 240 | # contextlib.contextmanager. Add to this list to register other decorators that 241 | # produce valid context managers. 242 | contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager 243 | 244 | # Tells whether missing members accessed in mixin class should be ignored. A 245 | # mixin class is detected if its name ends with "mixin" (case insensitive). 246 | ignore-mixin-members=yes 247 | 248 | # List of module names for which member attributes should not be checked 249 | # (useful for modules/projects where namespaces are manipulated during runtime 250 | # and thus existing member attributes cannot be deduced by static analysis). It 251 | # supports qualified module names, as well as Unix pattern matching. 252 | ignored-modules= 253 | 254 | # List of class names for which member attributes should not be checked (useful 255 | # for classes with dynamically set attributes). This supports the use of 256 | # qualified names. 257 | ignored-classes=optparse.Values,thread._local,_thread._local 258 | 259 | # List of members which are set dynamically and missed by pylint inference 260 | # system, and so shouldn't trigger E1101 when accessed. Python regular 261 | # expressions are accepted.
262 | generated-members= 263 | 264 | 265 | [FORMAT] 266 | 267 | # Maximum number of characters on a single line. 268 | max-line-length=80 269 | 270 | # TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt 271 | # lines made too long by directives to pytype. 272 | 273 | # Regexp for a line that is allowed to be longer than the limit. 274 | ignore-long-lines=(?x)( 275 | ^\s*(\#\ )?<?https?://\S+>?$| 276 | ^\s*(from\s+\S+\s+)?import\s+.+$) 277 | 278 | # Allow the body of an if to be on the same line as the test if there is no 279 | # else. 280 | single-line-if-stmt=yes 281 | 282 | # List of optional constructs for which whitespace checking is disabled. `dict- 283 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. 284 | # `trailing-comma` allows a space between comma and closing bracket: (a, ). 285 | # `empty-line` allows space-only lines. 286 | no-space-check= 287 | 288 | # Maximum number of lines in a module 289 | max-module-lines=99999 290 | 291 | # String used as indentation unit. The internal Google style guide mandates 2 292 | # spaces. Google's externally-published style guide says 4, consistent with 293 | # PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google 294 | # projects (like TensorFlow). 295 | indent-string=' ' 296 | 297 | # Number of spaces of indent required inside a hanging or continued line. 298 | indent-after-paren=4 299 | 300 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 301 | expected-line-ending-format= 302 | 303 | 304 | [MISCELLANEOUS] 305 | 306 | # List of note tags to take into consideration, separated by a comma. 307 | notes=TODO 308 | 309 | 310 | [STRING] 311 | 312 | # This flag controls whether inconsistent-quotes generates a warning when the 313 | # character used as a quote delimiter is used inconsistently within a module.
314 | check-quote-consistency=yes 315 | 316 | 317 | [VARIABLES] 318 | 319 | # Tells whether we should check for unused import in __init__ files. 320 | init-import=no 321 | 322 | # A regular expression matching the name of dummy variables (i.e. expectedly 323 | # not used). 324 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) 325 | 326 | # List of additional names supposed to be defined in builtins. Remember that 327 | # you should avoid to define new builtins when possible. 328 | additional-builtins= 329 | 330 | # List of strings which can identify a callback function by name. A callback 331 | # name must start or end with one of those strings. 332 | callbacks=cb_,_cb 333 | 334 | # List of qualified module names which can have objects that can redefine 335 | # builtins. 336 | redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools 337 | 338 | 339 | [LOGGING] 340 | 341 | # Logging modules to check that the string format arguments are in logging 342 | # function parameter format 343 | logging-modules=logging,absl.logging,tensorflow.io.logging 344 | 345 | 346 | [SIMILARITIES] 347 | 348 | # Minimum lines number of a similarity. 349 | min-similarity-lines=4 350 | 351 | # Ignore comments when computing similarities. 352 | ignore-comments=yes 353 | 354 | # Ignore docstrings when computing similarities. 355 | ignore-docstrings=yes 356 | 357 | # Ignore imports when computing similarities. 358 | ignore-imports=no 359 | 360 | 361 | [SPELLING] 362 | 363 | # Spelling dictionary name. Available dictionaries: none. To make it working 364 | # install python-enchant package. 365 | spelling-dict= 366 | 367 | # List of comma separated words that should not be checked. 368 | spelling-ignore-words= 369 | 370 | # A path to a file that contains private dictionary; one word per line. 
371 | spelling-private-dict-file= 372 | 373 | # Tells whether to store unknown words to indicated private dictionary in 374 | # --spelling-private-dict-file option instead of raising a message. 375 | spelling-store-unknown-words=no 376 | 377 | 378 | [IMPORTS] 379 | 380 | # Deprecated modules which should not be used, separated by a comma 381 | deprecated-modules=regsub, 382 | TERMIOS, 383 | Bastion, 384 | rexec, 385 | sets 386 | 387 | # Create a graph of every (i.e. internal and external) dependencies in the 388 | # given file (report RP0402 must not be disabled) 389 | import-graph= 390 | 391 | # Create a graph of external dependencies in the given file (report RP0402 must 392 | # not be disabled) 393 | ext-import-graph= 394 | 395 | # Create a graph of internal dependencies in the given file (report RP0402 must 396 | # not be disabled) 397 | int-import-graph= 398 | 399 | # Force import order to recognize a module as part of the standard 400 | # compatibility libraries. 401 | known-standard-library= 402 | 403 | # Force import order to recognize a module as part of a third party library. 404 | known-third-party=enchant, absl 405 | 406 | # Analyse import fallback blocks. This can be used to support both Python 2 and 407 | # 3 compatible code, which means that the block might have code that exists 408 | # only in one or another interpreter, leading to false positives when analysed. 409 | analyse-fallback-blocks=no 410 | 411 | 412 | [CLASSES] 413 | 414 | # List of method names used to declare (i.e. assign) instance attributes. 415 | defining-attr-methods=__init__, 416 | __new__, 417 | setUp 418 | 419 | # List of member names, which should be excluded from the protected access 420 | # warning. 421 | exclude-protected=_asdict, 422 | _fields, 423 | _replace, 424 | _source, 425 | _make 426 | 427 | # List of valid names for the first argument in a class method. 
428 | valid-classmethod-first-arg=cls, 429 | class_ 430 | 431 | # List of valid names for the first argument in a metaclass class method. 432 | valid-metaclass-classmethod-first-arg=mcs 433 | 434 | 435 | [EXCEPTIONS] 436 | 437 | # Exceptions that will emit a warning when being caught. Defaults to 438 | # "Exception" 439 | overgeneral-exceptions=StandardError, 440 | Exception, 441 | BaseException 442 | -------------------------------------------------------------------------------- /dm_hamiltonian_dynamics_suite/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /dm_hamiltonian_dynamics_suite/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Module containing dataset configurations used for generation and some utility functions.""" 16 | import functools 17 | import os 18 | import shutil 19 | from typing import Callable, Mapping, Any, TextIO, Generator, Tuple, Optional 20 | 21 | from absl import logging 22 | 23 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import ideal_double_pendulum 24 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import ideal_mass_spring 25 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import ideal_pendulum 26 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import n_body 27 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils 28 | import jax 29 | import jax.numpy as jnp 30 | import numpy as np 31 | import tensorflow as tf 32 | 33 | PipelineOutput = Optional[Tuple[ 34 | Tuple[Mapping[str, jnp.ndarray], ...], 35 | Tuple[Mapping[str, jnp.ndarray], ...], 36 | ]] 37 | 38 | try: 39 | from dm_hamiltonian_dynamics_suite.multiagent_dynamics import game_dynamics # pylint: disable=g-import-not-at-top 40 | _OPEN_SPIEL_INSTALLED = True 41 | except ModuleNotFoundError: 42 | _OPEN_SPIEL_INSTALLED = False 43 | 44 | 45 | def open_spiel_available() -> bool: 46 | return _OPEN_SPIEL_INSTALLED 47 | 48 | 49 | def set_up_folder(folder: str, overwrite: bool) -> None: 50 | """Sets up the folder needed for the dataset (optionally clearing it).""" 51 | if os.path.exists(folder): 52 | if overwrite: 53 | shutil.rmtree(folder) 54 | os.makedirs(folder) 55 | else: 56 | os.makedirs(folder) 57 | 58 | 59 | def 
save_features( 60 | file: TextIO, 61 | example_dict: Mapping[str, Any], 62 | prefix: str = "" 63 | ) -> None: 64 | """Saves the features file used for loading.""" 65 | for k, v in example_dict.items(): 66 | if isinstance(v, dict): 67 | save_features(file, v, prefix=f"{prefix}{k}/") 68 | else: 69 | if isinstance(v, tf.Tensor): 70 | v = v.numpy() 71 | if isinstance(v, (np.ndarray, jnp.ndarray)): 72 | # int32 are promoted to int64 73 | if v.dtype == np.int32: 74 | file.write(f"{prefix}{k}, {v.shape}, {np.int64}\n") 75 | else: 76 | file.write(f"{prefix}{k}, {v.shape}, {v.dtype}\n") 77 | else: 78 | raise NotImplementedError(f"Currently the only supported feature types " 79 | f"are tf.Tensor, np.ndarray and jnp.ndarray. " 80 | f"Encountered value of type {type(v)}.") 81 | 82 | 83 | def encode_example(example_dict: Mapping[str, Any]) -> Mapping[str, Any]: 84 | """Encodes a single trajectory into a TFRecord example.""" 85 | result_dict = dict() 86 | for k, v in example_dict.items(): 87 | if isinstance(v, tf.Tensor): 88 | v = v.numpy() 89 | if isinstance(v, dict): 90 | for ki, vi in encode_example(v).items(): 91 | result_dict[f"{k}/{ki}"] = vi 92 | elif isinstance(v, (np.ndarray, jnp.ndarray)): 93 | if v.dtype == np.uint8: 94 | # We encode images to png 95 | if v.ndim == 4: 96 | # Since encode_png accepts only a single image for a batch of images 97 | # we just stack them over their first axis. 
98 | v = v.reshape((-1,) + v.shape[-2:]) 99 | image_string = tf.image.encode_png(v).numpy() 100 | result_dict[k] = tf.train.Feature( 101 | bytes_list=tf.train.BytesList(value=[image_string])) 102 | elif v.dtype == np.int32: 103 | # int32 are promoted to int64 104 | value = v.reshape([-1]).astype(np.int64) 105 | result_dict[k] = tf.train.Feature( 106 | int64_list=tf.train.Int64List(value=value)) 107 | else: 108 | # Since tf.Records do not support reading float64, here for any values 109 | # we interpret them as int64 and store them in this format, in order 110 | # when reading to be able to recover the float64 values. 111 | value = v.reshape([-1]).view(np.int64) 112 | result_dict[k] = tf.train.Feature( 113 | int64_list=tf.train.Int64List(value=value)) 114 | else: 115 | raise NotImplementedError(f"Currently the only supported feature types " 116 | f"are tf.Tensor, np.ndarray and jnp.ndarray. " 117 | f"Encountered value of type {type(v)}.") 118 | return result_dict 119 | 120 | 121 | def transform_dataset( 122 | generator: Generator[Mapping[str, Any], None, None], 123 | destination_folder: str, 124 | prefix: str, 125 | overwrite: bool 126 | ) -> None: 127 | """Copies the dataset from the source folder to the destination as a TFRecord dataset.""" 128 | set_up_folder(destination_folder, overwrite) 129 | features_path = os.path.join(destination_folder, "features.txt") 130 | features_saved = False 131 | 132 | file_path = os.path.join(destination_folder, f"{prefix}.tfrecord") 133 | if os.path.exists(file_path): 134 | if not overwrite: 135 | logging.info("The file with prefix %s already exist. Skipping.", prefix) 136 | # We assume that the features file must be present in this case. 137 | return 138 | else: 139 | logging.info("The file with prefix %s already exist and overwrite=True." 
140 | " Deleting.", prefix) 141 | os.remove(file_path) 142 | with tf.io.TFRecordWriter(file_path) as writer: 143 | for element in generator: 144 | if not features_saved: 145 | with open(features_path, "w") as f: 146 | save_features(f, element) 147 | features_saved = True 148 | example = tf.train.Example(features=tf.train.Features( 149 | feature=encode_example(element))) 150 | writer.write(example.SerializeToString()) 151 | 152 | 153 | def generate_sample( 154 | index: int, 155 | system: n_body.hamiltonian.HamiltonianSystem, 156 | dt: float, 157 | num_steps: int, 158 | steps_per_dt: int 159 | ) -> Mapping[str, jnp.ndarray]: 160 | """Simulates a single trajectory of the system.""" 161 | seed = np.random.randint(0, 2 * 32 -1) 162 | prng_key = jax.random.fold_in(jax.random.PRNGKey(seed), index) 163 | total_steps = num_steps * steps_per_dt 164 | total_dt = dt / steps_per_dt 165 | result = system.generate_and_render_dt( 166 | num_trajectories=1, 167 | rng_key=prng_key, 168 | t0=0.0, 169 | dt=total_dt, 170 | num_steps=total_steps) 171 | sub_sample_index = np.linspace(0.0, total_steps, num_steps + 1) 172 | sub_sample_index = sub_sample_index.astype("int64") 173 | def sub_sample(x): 174 | if x.ndim > 1 and x.shape[1] == total_steps + 1: 175 | return x[0, sub_sample_index] 176 | else: 177 | return x 178 | result = jax.tree_map(sub_sample, result) 179 | for k in result.keys(): 180 | if "image" in k: 181 | result[k] = (result[k] * 255.0).astype("uint8") 182 | return result 183 | 184 | 185 | def create_pipeline( 186 | generate: Callable[[int], Mapping[str, jnp.ndarray]], 187 | output_path: str, 188 | num_train: int, 189 | num_test: int, 190 | return_generated_examples: bool = False, 191 | ) -> Callable[[], PipelineOutput]: 192 | """Runs the generation pipeline for the HML datasets.""" 193 | def pipeline() -> PipelineOutput: 194 | train_examples = list() 195 | test_examples = list() 196 | with open(f"{output_path}/features.txt", "w") as f: 197 | save_features(f, generate(0)) 198 
| with tf.io.TFRecordWriter(f"{output_path}/train.tfrecord") as writer: 199 | for i in range(num_train): 200 | example = generate(i) 201 | if return_generated_examples: 202 | train_examples.append(example) 203 | example = tf.train.Example(features=tf.train.Features( 204 | feature=encode_example(example))) 205 | writer.write(example.SerializeToString()) 206 | with tf.io.TFRecordWriter(f"{output_path}/test.tfrecord") as writer: 207 | for i in range(num_test): 208 | example = generate(num_train + i) 209 | if return_generated_examples: 210 | test_examples.append(example) 211 | example = tf.train.Example(features=tf.train.Features( 212 | feature=encode_example(example))) 213 | writer.write(example.SerializeToString()) 214 | if return_generated_examples: 215 | return tuple(train_examples), tuple(test_examples) 216 | return pipeline 217 | 218 | 219 | def generate_full_dataset( 220 | folder: str, 221 | dataset: str, 222 | dt: float, 223 | num_steps: int, 224 | steps_per_dt: int, 225 | num_train: int, 226 | num_test: int, 227 | overwrite: bool, 228 | return_generated_examples: bool = False, 229 | ) -> PipelineOutput: 230 | """Runs the data generation.""" 231 | dt_str = str(dt).replace(".", "_") 232 | folder = os.path.join(folder, dataset.lower() + f"_dt_{dt_str}") 233 | set_up_folder(folder, overwrite) 234 | 235 | cls, config = globals().get(dataset.upper()) 236 | system = cls(**config()) 237 | generate = functools.partial( 238 | generate_sample, 239 | system=system, 240 | dt=dt, 241 | num_steps=num_steps, 242 | steps_per_dt=steps_per_dt) 243 | pipeline = create_pipeline( 244 | generate, folder, num_train, num_test, return_generated_examples) 245 | return pipeline() 246 | 247 | 248 | MASS_SPRING = ( 249 | ideal_mass_spring.IdealMassSpring, 250 | lambda: dict( # pylint:disable=g-long-lambda 251 | k_range=utils.BoxRegion(2.0, 2.0), 252 | m_range=utils.BoxRegion(0.5, 0.5), 253 | radius_range=utils.BoxRegion(0.1, 1.0), 254 | uniform_annulus=False, 255 | 
randomize_canvas_location=False, 256 | randomize_x=False, 257 | num_colors=1, 258 | ) 259 | ) 260 | 261 | 262 | MASS_SPRING_COLORS = ( 263 | ideal_mass_spring.IdealMassSpring, 264 | lambda: dict( # pylint:disable=g-long-lambda 265 | k_range=utils.BoxRegion(2.0, 2.0), 266 | m_range=utils.BoxRegion(0.2, 1.0), 267 | radius_range=utils.BoxRegion(0.1, 1.0), 268 | num_colors=6, 269 | ) 270 | ) 271 | 272 | 273 | MASS_SPRING_COLORS_FRICTION = ( 274 | ideal_mass_spring.IdealMassSpring, 275 | lambda: dict( # pylint:disable=g-long-lambda 276 | k_range=utils.BoxRegion(2.0, 2.0), 277 | m_range=utils.BoxRegion(0.2, 1.0), 278 | radius_range=utils.BoxRegion(0.1, 1.0), 279 | num_colors=6, 280 | friction=0.05, 281 | ), 282 | ) 283 | 284 | 285 | PENDULUM = ( 286 | ideal_pendulum.IdealPendulum, 287 | lambda: dict( # pylint:disable=g-long-lambda 288 | m_range=utils.BoxRegion(0.5, 0.5), 289 | g_range=utils.BoxRegion(3.0, 3.0), 290 | l_range=utils.BoxRegion(1.0, 1.0), 291 | radius_range=utils.BoxRegion(1.3, 2.3), 292 | uniform_annulus=False, 293 | randomize_canvas_location=False, 294 | num_colors=1, 295 | ) 296 | ) 297 | 298 | 299 | PENDULUM_COLORS = ( 300 | ideal_pendulum.IdealPendulum, 301 | lambda: dict( # pylint:disable=g-long-lambda 302 | m_range=utils.BoxRegion(0.5, 1.5), 303 | g_range=utils.BoxRegion(3.0, 4.0), 304 | l_range=utils.BoxRegion(.5, 1.0), 305 | radius_range=utils.BoxRegion(1.3, 2.3), 306 | num_colors=6, 307 | ) 308 | ) 309 | 310 | 311 | PENDULUM_COLORS_FRICTION = ( 312 | ideal_pendulum.IdealPendulum, 313 | lambda: dict( # pylint:disable=g-long-lambda 314 | m_range=utils.BoxRegion(0.5, 1.5), 315 | g_range=utils.BoxRegion(3.0, 4.0), 316 | l_range=utils.BoxRegion(.5, 1.0), 317 | radius_range=utils.BoxRegion(1.3, 2.3), 318 | num_colors=6, 319 | friction=0.05, 320 | ) 321 | ) 322 | 323 | 324 | DOUBLE_PENDULUM = ( 325 | ideal_double_pendulum.IdealDoublePendulum, 326 | lambda: dict( # pylint:disable=g-long-lambda 327 | m_range=utils.BoxRegion(0.5, 0.5), 328 | 
g_range=utils.BoxRegion(3.0, 3.0), 329 | l_range=utils.BoxRegion(1.0, 1.0), 330 | radius_range=utils.BoxRegion(1.3, 2.3), 331 | uniform_annulus=False, 332 | randomize_canvas_location=False, 333 | num_colors=2, 334 | ) 335 | ) 336 | 337 | 338 | DOUBLE_PENDULUM_COLORS = ( 339 | ideal_double_pendulum.IdealDoublePendulum, 340 | lambda: dict( # pylint:disable=g-long-lambda 341 | m_range=utils.BoxRegion(0.4, 0.6), 342 | g_range=utils.BoxRegion(2.5, 4.0), 343 | l_range=utils.BoxRegion(0.75, 1.0), 344 | radius_range=utils.BoxRegion(1.0, 2.5), 345 | num_colors=6, 346 | ) 347 | ) 348 | 349 | 350 | DOUBLE_PENDULUM_COLORS_FRICTION = ( 351 | ideal_double_pendulum.IdealDoublePendulum, 352 | lambda: dict( # pylint:disable=g-long-lambda 353 | m_range=utils.BoxRegion(0.4, 0.6), 354 | g_range=utils.BoxRegion(2.5, 4.0), 355 | l_range=utils.BoxRegion(0.75, 1.0), 356 | radius_range=utils.BoxRegion(1.0, 2.5), 357 | num_colors=6, 358 | friction=0.05 359 | ), 360 | ) 361 | 362 | 363 | TWO_BODY = ( 364 | n_body.TwoBodySystem, 365 | lambda: dict( # pylint:disable=g-long-lambda 366 | m_range=utils.BoxRegion(1.0, 1.0), 367 | g_range=utils.BoxRegion(1.0, 1.0), 368 | radius_range=utils.BoxRegion(0.5, 1.5), 369 | provided_canvas_bounds=utils.BoxRegion(-2.75, 2.75), 370 | randomize_canvas_location=False, 371 | num_colors=2, 372 | ) 373 | ) 374 | 375 | 376 | TWO_BODY_COLORS = ( 377 | n_body.TwoBodySystem, 378 | lambda: dict( # pylint:disable=g-long-lambda 379 | m_range=utils.BoxRegion(0.5, 1.5), 380 | g_range=utils.BoxRegion(0.5, 1.5), 381 | radius_range=utils.BoxRegion(0.5, 1.5), 382 | provided_canvas_bounds=utils.BoxRegion(-5.0, 5.0), 383 | randomize_canvas_location=False, 384 | num_colors=6, 385 | ) 386 | ) 387 | 388 | 389 | def no_open_spiel_func(*_, **__): 390 | raise ValueError("You must download and install `open_spiel` first in " 391 | "order to use the game_dynamics datasets. 
See " 392 | "https://github.com/deepmind/open_spiel for instructions" 393 | " how to do this.") 394 | 395 | if not open_spiel_available(): 396 | MATCHING_PENNIES = (no_open_spiel_func, dict) 397 | ROCK_PAPER_SCISSORS = (no_open_spiel_func, dict) 398 | else: 399 | MATCHING_PENNIES = (game_dynamics.ZeroSumGame, 400 | lambda: dict(game_name="matrix_mp")) 401 | ROCK_PAPER_SCISSORS = (game_dynamics.ZeroSumGame, 402 | lambda: dict(game_name="matrix_rps")) 403 | -------------------------------------------------------------------------------- /dm_hamiltonian_dynamics_suite/generate_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Script to generate datasets from the configs in datasets.py.""" 16 | from absl import app 17 | from absl import flags 18 | 19 | from dm_hamiltonian_dynamics_suite import datasets 20 | from dm_hamiltonian_dynamics_suite.molecular_dynamics import generate_dataset 21 | 22 | import jax 23 | 24 | flags.DEFINE_string("folder", None, 25 | "The folder where to store the datasets.") 26 | flags.DEFINE_string("dataset", None, 27 | "The dataset from datasets.py to use or " 28 | "'molecular_dynamics' respectively.") 29 | flags.DEFINE_string("lammps_file", None, 30 | "For dataset='molecular_dynamics' this should be the " 31 | "LAMMPS trajectory file containing a sequence of timesteps " 32 | "obtained from a MD simulation.") 33 | flags.DEFINE_float("dt", None, "The delta time between two observations.") 34 | flags.DEFINE_integer("num_steps", None, "The number of steps to simulate.") 35 | flags.DEFINE_integer("steps_per_dt", 10, 36 | "How many internal steps to do per a single observation " 37 | "step.") 38 | flags.DEFINE_integer("num_train", None, 39 | "The number of training examples to generate.") 40 | flags.DEFINE_integer("num_test", None, 41 | "The number of test examples to generate.") 42 | flags.DEFINE_boolean("overwrite", False, "Overwrites previous data.") 43 | 44 | flags.mark_flag_as_required("folder") 45 | flags.mark_flag_as_required("dataset") 46 | flags.mark_flag_as_required("dt") 47 | flags.mark_flag_as_required("num_steps") 48 | flags.mark_flag_as_required("num_train") 49 | flags.mark_flag_as_required("num_test") 50 | FLAGS = flags.FLAGS 51 | 52 | 53 | def main(argv): 54 | if len(argv) > 1: 55 | raise ValueError(f"Unexpected args: {argv[1:]}") 56 | if FLAGS.dataset == "molecular_dynamics": 57 | generate_dataset.generate_lammps_dataset( 58 | lammps_file=FLAGS.lammps_file, 59 | folder=FLAGS.output_path, 60 | dt=FLAGS.dt, 61 | num_steps=FLAGS.num_steps, 62 | num_train=FLAGS.num_train, 63 | num_test=FLAGS.num_test, 64 | shuffle=FLAGS.shuffle, 65 | 
seed=FLAGS.seed, 66 | overwrite=FLAGS.overwrite, 67 | ) 68 | else: 69 | datasets.generate_full_dataset( 70 | folder=FLAGS.folder, 71 | dataset=FLAGS.dataset, 72 | dt=FLAGS.dt, 73 | num_steps=FLAGS.num_steps, 74 | steps_per_dt=FLAGS.steps_per_dt, 75 | num_train=FLAGS.num_train, 76 | num_test=FLAGS.num_test, 77 | overwrite=FLAGS.overwrite 78 | ) 79 | 80 | 81 | if __name__ == "__main__": 82 | jax.config.update("jax_enable_x64", True) 83 | app.run(main) 84 | -------------------------------------------------------------------------------- /dm_hamiltonian_dynamics_suite/generate_locally.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2020 DeepMind Technologies Limited. 3 | # 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | if [[ "$#" -eq 1 ]]; then 18 | readonly DATASET="${1}" 19 | else 20 | echo "Illegal number of parameters. Expected only the name of the dataset." 
21 | exit 2 22 | fi 23 | 24 | readonly FOLDER="/tmp/hamiltonian_ml/datasets" 25 | readonly DTS=(0.05 0.1) 26 | readonly NUM_STEPS=255 27 | readonly NUM_TRAIN=500 28 | readonly NUM_TEST=200 29 | 30 | for DT in "${DTS[@]}"; do 31 | python3 generate_dataset.py \ 32 | --folder=${FOLDER} \ 33 | --dataset="${DATASET}" \ 34 | --dt="${DT}" \ 35 | --num_steps=${NUM_STEPS} \ 36 | --num_train=${NUM_TRAIN} \ 37 | --num_test=${NUM_TEST} \ 38 | --overwrite=true 39 | done 40 | -------------------------------------------------------------------------------- /dm_hamiltonian_dynamics_suite/hamiltonian_systems/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /dm_hamiltonian_dynamics_suite/hamiltonian_systems/hamiltonian.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A module with the abstract class for Hamiltonian systems."""
import abc
from typing import Any, Callable, Mapping, Optional, Tuple, Union

from dm_hamiltonian_dynamics_suite.hamiltonian_systems import phase_space
from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils

import jax
import jax.numpy as jnp
import jax.random as jnr
from scipy import integrate

# Signature of a custom integrator that can be given as `method`. It is called
# as `method(dy_dt, y0, t0, dt, num_steps, steps_per_dt)` and the simulated
# trajectory is taken from the second element of the returned tuple.
Integrator = Callable[
    [
        Union[phase_space.HamiltonianFunction,
              phase_space.SymplecticTangentFunction],  # dy_dt
        Union[jnp.ndarray, phase_space.PhaseSpace],  # y0
        Union[float, jnp.ndarray],  # t0
        Union[float, jnp.ndarray],  # dt
        int,  # num_steps
        int  # steps_per_dt
    ],
    Tuple[jnp.ndarray, phase_space.PhaseSpace]
]


class HamiltonianSystem(abc.ABC):
  """General class to represent Hamiltonian Systems and simulate them."""

  def __init__(
      self,
      system_dims: int,
      randomize_canvas_location: bool = True,
      random_canvas_extra_ratio: float = 0.4,
      try_analytic_solution: bool = True,
      friction: float = 0.0,
      method: Union[str, Integrator] = "scipy",
      num_colors: int = 6,
      image_resolution: int = 32,
      dtype: str = "float32",
      steps_per_dt: int = 1,
      stiff: bool = False,
      extra_ivp_kwargs: Optional[Mapping[str, Union[str, int, float]]] = None):
    """Initializes some global properties.

    Args:
      system_dims: Dimensionality of the positions (not joint state).
      randomize_canvas_location: Whether to add random offset to images.
      random_canvas_extra_ratio: Total size of the extra random offset range,
        as a fraction of the canvas size (see `random_offset_bounds`).
      try_analytic_solution: If True, first tries to solve the system
        analytically if possible. If False, always integrates systems
        numerically.
      friction: This changes the dynamics to non-conservative. The new
        dynamics are formulated as follows:
          dq/dt = dH/dp
          dp/dt = - dH/dq - friction * dH/dp
        This implies that the Hamiltonian energy decreases over time:
          dH/dt = dH/dq^T dq/dt + dH/dp^T dp/dt = - friction * ||dH/dp||_2^2
      method: "scipy" or a callable of type `Integrator`.
      num_colors: The number of possible colors to use for rendering.
      image_resolution: For generated images their resolution.
      dtype: What dtype to use for the generated data. It also selects the
        scipy solver tolerances.
      steps_per_dt: Number of inner steps to use per a single observation dt.
      stiff: Whether the problem represents a stiff system. If True the scipy
        solver uses the "Radau" method.
      extra_ivp_kwargs: Extra arguments to the scipy solver. They override the
        defaults chosen from `dtype` and `stiff`.

    Raises:
      ValueError: if `dtype` is not 'float32' or 'float64'.
    """
    self._system_dims = system_dims
    self._randomize_canvas_location = randomize_canvas_location
    self._random_canvas_extra_ratio = random_canvas_extra_ratio
    self._try_analytic_solution = try_analytic_solution
    self._friction = friction
    self._method = method
    self._num_colors = num_colors
    self._resolution = image_resolution
    self._dtype = dtype
    self._stiff = stiff
    self._steps_per_dt = steps_per_dt
    if dtype == "float64":
      self._scipy_ivp_kwargs = dict(rtol=1e-12, atol=1e-12)
    elif dtype == "float32":
      self._scipy_ivp_kwargs = dict(rtol=1e-9, atol=1e-9)
    else:
      raise ValueError("Currently we only support float64 and float32 dtypes.")
    if stiff:
      self._scipy_ivp_kwargs["method"] = "Radau"
    if extra_ivp_kwargs is not None:
      self._scipy_ivp_kwargs.update(extra_ivp_kwargs)

  @property
  def system_dims(self) -> int:
    return self._system_dims

  @property
  def randomize_canvas_location(self) -> bool:
    return self._randomize_canvas_location

  @property
  def random_canvas_extra_ratio(self) -> float:
    return self._random_canvas_extra_ratio

  @property
  def try_analytic_solution(self) -> bool:
    return self._try_analytic_solution

  @property
  def friction(self) -> float:
    return self._friction

  @property
  def method(self):
    return self._method

  @property
  def num_colors(self) -> int:
    return self._num_colors

  @property
  def resolution(self) -> int:
    return self._resolution

  @property
  def dtype(self) -> str:
    return self._dtype

  @property
  def stiff(self) -> bool:
    return self._stiff

  @property
  def steps_per_dt(self) -> int:
    return self._steps_per_dt

  @property
  def scipy_ivp_kwargs(self) -> Mapping[str, Union[str, int, float]]:
    return self._scipy_ivp_kwargs

  @abc.abstractmethod
  def parametrized_hamiltonian(
      self,
      t: jnp.ndarray,
      y: phase_space.PhaseSpace,
      params: utils.Params,
      **kwargs: Any
  ) -> jnp.ndarray:
    """Calculates the Hamiltonian."""

  def hamiltonian_from_params(
      self,
      params: utils.Params,
      **kwargs: Any
  ) -> phase_space.HamiltonianFunction:
    """Returns `H(t, y)` with `params` and `kwargs` bound."""
    def hamiltonian(t: jnp.ndarray, y: phase_space.PhaseSpace) -> jnp.ndarray:
      return self.parametrized_hamiltonian(t, y, params, **kwargs)
    return hamiltonian

  @abc.abstractmethod
  def sample_y(
      self,
      num_samples: int,
      params: utils.Params,
      rng_key: jnp.ndarray,
      **kwargs: Any
  ) -> phase_space.PhaseSpace:
    """Samples randomly initial states."""

  @abc.abstractmethod
  def sample_params(
      self,
      num_samples: int,
      rng_key: jnp.ndarray,
      **kwargs: Any
  ) -> utils.Params:
    """Samples randomly parameters."""

  @abc.abstractmethod
  def simulate_analytically(
      self,
      y0: phase_space.PhaseSpace,
      t0: utils.FloatArray,
      t_eval: jnp.ndarray,
      params: utils.Params,
      **kwargs: Any
  ) -> Optional[phase_space.PhaseSpace]:
    """If analytic solution exist returns it, else returns None."""

  def simulate_analytically_dt(
      self,
      y0: phase_space.PhaseSpace,
      t0: utils.FloatArray,
      dt: utils.FloatArray,
      num_steps: int,
      params: utils.Params,
      **kwargs: Any
  ) -> Optional[phase_space.PhaseSpace]:
    """Same as `simulate_analytically` but uses `dt` and `num_steps`."""
    t_eval = utils.dt_to_t_eval(t0, dt, num_steps)
    return self.simulate_analytically(y0, t0, t_eval, params, **kwargs)

  @abc.abstractmethod
  def canvas_bounds(self) -> utils.BoxRegion:
    """Returns the limits of the canvas for rendering."""

  def random_offset_bounds(self) -> utils.BoxRegion:
    """Returns any extra randomized offset that can be given to the canvas.

    The region is symmetric around zero and spans
    `random_canvas_extra_ratio * canvas_bounds().size` in total.
    """
    extra_size = self.random_canvas_extra_ratio * self.canvas_bounds().size / 2
    return utils.BoxRegion(
        minimum=-extra_size,
        maximum=extra_size
    )

  def full_canvas_bounds(self) -> utils.BoxRegion:
    """Returns the canvas bounds enlarged by the random offset, if enabled."""
    canvas = self.canvas_bounds()
    if not self.randomize_canvas_location:
      return canvas
    offset = self.random_offset_bounds()
    return utils.BoxRegion(canvas.min + offset.min, canvas.max + offset.max)

  @abc.abstractmethod
  def canvas_position(
      self,
      position: jnp.ndarray,
      params: utils.Params
  ) -> jnp.ndarray:
    """Returns the canvas position given the position vectors and the parameters."""

  @abc.abstractmethod
  def render_trajectories(
      self,
      position: jnp.ndarray,
      params: utils.Params,
      rng_key: jnp.ndarray,
      **kwargs: Any
  ) -> Tuple[jnp.ndarray, utils.Params]:
    """Renders the positions q into an image."""

  def _apply_friction(
      self,
      dy: phase_space.TangentPhaseSpace
  ) -> phase_space.TangentPhaseSpace:
    """Adds `-friction * dq/dt` (= `-friction * dH/dp`) to the momentum rate."""
    return dy + phase_space.TangentPhaseSpace(
        position=jnp.zeros_like(dy.position),
        momentum=-self.friction * dy.position)

  def simulate_scipy(
      self,
      y0: phase_space.PhaseSpace,
      t0: utils.FloatArray,
      t_eval: jnp.ndarray,
      params: utils.Params,
      ivp_kwargs=None,
      **kwargs: Any
  ) -> phase_space.PhaseSpace:
    """Simulates the system using scipy.integrate.solve_ivp.

    Args:
      y0: Initial state.
      t0: The time instance of the initial state y0.
      t_eval: Times at which to return the solution; the last entry is the end
        of the integration interval.
      params: Parameters of the Hamiltonian.
      ivp_kwargs: Per-call overrides of `self.scipy_ivp_kwargs`.
      **kwargs: Any extra things that go into the hamiltonian.

    Returns:
      The states at `t_eval`, with time as the leading axis.
    """
    t_span = (t0, float(t_eval[-1]))
    # solve_ivp integrates a flat vector, so join [q, p], flatten any batch
    # dimensions and remember the shape to undo it afterwards.
    y0 = jnp.concatenate([y0.q, y0.p], axis=-1)
    y_shape = y0.shape
    y0 = y0.reshape([-1])
    hamiltonian = self.hamiltonian_from_params(params, **kwargs)

    @jax.jit
    def fun(t, y):
      f = phase_space.poisson_bracket_with_q_and_p(hamiltonian)
      dy = f(t, phase_space.PhaseSpace.from_state(y.reshape(y_shape)))
      if self.friction != 0.0:
        dy = self._apply_friction(dy)
      return dy.single_state.reshape([-1])

    solver_kwargs = dict(**self.scipy_ivp_kwargs)
    solver_kwargs.update(ivp_kwargs or dict())
    solution = integrate.solve_ivp(
        fun=fun,
        t_span=t_span,
        y0=y0,
        t_eval=t_eval,
        **solver_kwargs)
    # solution.y has the evaluation times as its last axis; move it to front.
    y_final = solution.y.reshape(y_shape + (t_eval.size,))
    return phase_space.PhaseSpace.from_state(jnp.moveaxis(y_final, -1, 0))

  def simulate_scipy_dt(
      self,
      y0: phase_space.PhaseSpace,
      t0: utils.FloatArray,
      dt: utils.FloatArray,
      num_steps: int,
      params: utils.Params,
      ivp_kwargs=None,
      **kwargs: Any
  ) -> phase_space.PhaseSpace:
    """Same as `simulate_scipy` but uses `dt` and `num_steps`."""
    t_eval = utils.dt_to_t_eval(t0, dt, num_steps)
    return self.simulate_scipy(y0, t0, t_eval, params, ivp_kwargs, **kwargs)

  def simulate_integrator(
      self,
      y0: phase_space.PhaseSpace,
      t0: utils.FloatArray,
      t_eval: jnp.ndarray,
      params: utils.Params,
      method: Union[str, Integrator],
      **kwargs: Any
  ) -> phase_space.PhaseSpace:
    """Simulates the system with the integrator `method` at times `t_eval`.

    Unless `num_steps` is given explicitly in `kwargs`, one step is taken per
    entry of `t_eval`.
    """
    kwargs.setdefault("num_steps", len(t_eval))
    return self.simulate_integrator_dt(
        y0=y0,
        t0=t0,
        dt=utils.t_eval_to_dt(t0, t_eval),
        params=params,
        method=method,
        **kwargs
    )

  def simulate_integrator_dt(
      self,
      y0: phase_space.PhaseSpace,
      t0: Union[float, jnp.ndarray],
      dt: Union[float, jnp.ndarray],
      params: utils.Params,
      method: Integrator,
      num_steps: Optional[int] = None,
      **kwargs: Any
  ) -> phase_space.PhaseSpace:
    """Same as `simulate_integrator` but uses `dt` and `num_steps`."""
    hamiltonian = self.hamiltonian_from_params(params, **kwargs)
    if self.friction == 0.0:
      # Conservative dynamics: the integrator receives the Hamiltonian itself.
      dynamics = hamiltonian
    else:
      def dynamics(t: jnp.ndarray, y: phase_space.PhaseSpace):
        f = phase_space.poisson_bracket_with_q_and_p(hamiltonian)
        return self._apply_friction(f(t, y))
    return method(
        dynamics,
        y0,
        t0,
        dt,
        num_steps,
        self.steps_per_dt,
    )[1]

  def generate_trajectories(
      self,
      y0: phase_space.PhaseSpace,
      t0: utils.FloatArray,
      t_eval: jnp.ndarray,
      params: utils.Params,
      **kwargs: Any
  ) -> phase_space.PhaseSpace:
    """Generates trajectories of the system in phase space.

    Args:
      y0: Initial state.
      t0: The time instance of the initial state y0.
      t_eval: Times at which to return the computed solution.
      params: Any parameters of the Hamiltonian.
      **kwargs: Any extra things that go into the hamiltonian. Unless
        `num_steps_forward` is given here, one forward step is taken per entry
        of `t_eval`.

    Returns:
      A phase_space.PhaseSpace instance of size TxNxD (time first).
    """
    kwargs.setdefault("num_steps_forward", len(t_eval))
    return self.generate_trajectories_dt(
        y0=y0,
        t0=t0,
        dt=utils.t_eval_to_dt(t0, t_eval),
        params=params,
        **kwargs
    )

  def generate_trajectories_dt(
      self,
      y0: phase_space.PhaseSpace,
      t0: utils.FloatArray,
      dt: utils.FloatArray,
      params: utils.Params,
      num_steps_forward: int,
      include_t0: bool = False,
      num_steps_backward: int = 0,
      **kwargs: Any
  ) -> phase_space.PhaseSpace:
    """Same as `generate_trajectories` but uses `dt` and `num_steps`.

    The result is time-major and ordered by increasing time: backward steps
    (if any), then `y0` (if `include_t0`), then forward steps.

    Raises:
      ValueError: on negative step counts, when both counts are zero, or when
        both are positive but `include_t0` is False.
      NotImplementedError: for backward steps with the "scipy" method.
    """
    if num_steps_forward < 0 or num_steps_backward < 0:
      raise ValueError("num_steps_forward and num_steps_backward can not be "
                       "negative.")
    if num_steps_forward == 0 and num_steps_backward == 0:
      raise ValueError("You need one of num_steps_forward or "
                       "num_of_steps_backward to be positive.")
    if num_steps_forward > 0 and num_steps_backward > 0 and not include_t0:
      raise ValueError("When both num_steps_forward and num_steps_backward are "
                       "positive include_t0 should be True.")

    if num_steps_backward == 0:
      y = None
      if self.try_analytic_solution:
        y = self.simulate_analytically_dt(y0, t0, dt, num_steps_forward, params,
                                          **kwargs)
      if y is None and self.method == "scipy":
        y = self.simulate_scipy_dt(y0, t0, dt, num_steps_forward, params,
                                   **kwargs)
      if y is not None:
        if include_t0:
          # Honour `include_t0` on these paths too, like the integrator path.
          y = jax.tree_util.tree_map(
              lambda a, b: jnp.concatenate([a[None], b], axis=0), y0, y)
        return y
    elif self.method == "scipy":
      raise NotImplementedError()

    yts = []
    if num_steps_backward > 0:
      yt = self.simulate_integrator_dt(
          y0=y0,
          t0=t0,
          dt=-dt,
          params=params,
          method=self.method,
          num_steps=num_steps_backward,
          **kwargs)
      # Integrating with -dt yields states moving away from t0; flip them so
      # the concatenated trajectory is in increasing time.
      yt = jax.tree_util.tree_map(lambda x: jnp.flip(x, axis=0), yt)
      yts.append(yt)
    if include_t0:
      yts.append(jax.tree_util.tree_map(lambda x: x[None], y0))
    if num_steps_forward > 0:
      yt = self.simulate_integrator_dt(
          y0=y0,
          t0=t0,
          dt=dt,
          params=params,
          method=self.method,
          num_steps=num_steps_forward,
          **kwargs)
      yts.append(yt)
    if len(yts) > 1:
      return jax.tree_util.tree_map(
          lambda *a: jnp.concatenate(a, axis=0), *yts)
    return yts[0]

  def generate_and_render(
      self,
      num_trajectories: int,
      rng_key: jnp.ndarray,
      t0: utils.FloatArray,
      t_eval: utils.FloatArray,
      y0: Optional[phase_space.PhaseSpace] = None,
      params: Optional[utils.Params] = None,
      within_canvas_bounds: bool = True,
      **kwargs: Any
  ) -> Mapping[str, Any]:
    """Generates trajectories and renders them.

    Args:
      num_trajectories: The number of trajectories to generate.
      rng_key: PRNG key for sampling any random numbers.
      t0: The time instance of the initial state y0.
      t_eval: Times at which to return the computed solution.
      y0: Initial state. If None will be sampled with `self.sample_y`
      params: Parameters of the Hamiltonian. If None will be sampled with
        `self.sample_params`
      within_canvas_bounds: Re-samples y0 and params until every trajectory
        stays within the canvas bounds.
      **kwargs: Any extra things that go into the hamiltonian.

    Returns:
      A dictionary containing the following elements:
        "x": A numpy array representation of the PhaseSpace vector.
        "dx_dt": The time derivative of "x".
        "image": An image representation of the state.
        "other": A dict of other parameters of the system that are not part of
          the state.
    """
    return self.generate_and_render_dt(
        num_trajectories=num_trajectories,
        rng_key=rng_key,
        t0=t0,
        dt=utils.t_eval_to_dt(t0, t_eval),
        y0=y0,
        params=params,
        within_canvas_bounds=within_canvas_bounds,
        **kwargs
    )

  def _batch_major_trajectories(
      self,
      y0: phase_space.PhaseSpace,
      t0: utils.FloatArray,
      dt: utils.FloatArray,
      params: utils.Params,
      num_steps: int,
      **kwargs: Any
  ) -> phase_space.PhaseSpace:
    """Simulates `num_steps` steps; returns batch-major states starting at y0."""
    x = self.generate_trajectories_dt(y0, t0, dt, params, num_steps, **kwargs)
    # Make batch the leading dimension.
    x = jax.tree_util.tree_map(lambda x_: jnp.swapaxes(x_, 0, 1), x)
    # Prepend the initial state along the time axis.
    return jax.tree_util.tree_map(
        lambda i, j: jnp.concatenate([i[:, None], j], axis=1), y0, x)

  def generate_and_render_dt(
      self,
      num_trajectories: int,
      rng_key: jnp.ndarray,
      t0: utils.FloatArray,
      dt: utils.FloatArray,
      num_steps: Optional[int] = None,
      y0: Optional[phase_space.PhaseSpace] = None,
      params: Optional[utils.Params] = None,
      within_canvas_bounds: bool = True,
      **kwargs: Any
  ) -> Mapping[str, Any]:
    """Same as `generate_and_render` but uses `dt` and `num_steps`.

    If `num_steps` is None, `dt` must be an array with one entry per step and
    `num_steps` is taken from its length.

    Raises:
      ValueError: if `within_canvas_bounds` is set together with `y0` or
        `params`, or if `num_steps` is None while `dt` is a scalar.
    """
    if within_canvas_bounds and (y0 is not None or params is not None):
      raise ValueError("Within canvas bounds is valid only when y0 and params "
                       "are None.")
    if num_steps is None:
      if jnp.ndim(dt) == 0:
        raise ValueError("`num_steps` must be provided when `dt` is a scalar.")
      num_steps = jnp.shape(dt)[0]
    # Always hand the freshly split `key` to consumers and keep `rng_key` only
    # for further splitting, so no PRNG key is used twice.
    if params is None:
      rng_key, key = jnr.split(rng_key)
      params = self.sample_params(num_trajectories, key, **kwargs)
    if y0 is None:
      rng_key, key = jnr.split(rng_key)
      y0 = self.sample_y(num_trajectories, params, key, **kwargs)

    # Generate the phase-space trajectories (batch-major, y0 first).
    x = self._batch_major_trajectories(y0, t0, dt, params, num_steps, **kwargs)
    if within_canvas_bounds:
      # Keep trajectories whose canvas positions stay inside the canvas and
      # re-sample replacements until enough have been collected.
      bounds = self.canvas_bounds()
      valid = []
      while True:
        for idx in range(x.q.shape[0]):
          x_idx, params_idx = jax.tree_util.tree_map(
              lambda a, i=idx: a[i], (x, params))
          position = self.canvas_position(x_idx.q, params_idx)
          if (jnp.all(position >= bounds.min) and
              jnp.all(position <= bounds.max)):
            valid.append((x_idx, params_idx))
            if len(valid) == num_trajectories:
              break
        if len(valid) == num_trajectories:
          break
        new_trajectories = num_trajectories - len(valid)
        print(f"Generating {new_trajectories} new trajectories.")
        rng_key, key = jnr.split(rng_key)
        params = self.sample_params(new_trajectories, key, **kwargs)
        rng_key, key = jnr.split(rng_key)
        y0 = self.sample_y(new_trajectories, params, key, **kwargs)
        x = self._batch_major_trajectories(y0, t0, dt, params, num_steps,
                                           **kwargs)
      x, params = jax.tree_util.tree_map(
          lambda *args: jnp.stack(args, axis=0), *valid)

    hamiltonian = self.hamiltonian_from_params(params, **kwargs)
    # vmap over time: axis 0 of `t` and axis 1 of the batch-major `x`.
    # NOTE(review): this is the conservative vector field; when friction != 0
    # it omits the friction term — confirm this is the intended "dx_dt".
    df_dt = jax.vmap(phase_space.poisson_bracket_with_q_and_p(hamiltonian),
                     in_axes=[0, 1], out_axes=1)
    if jnp.ndim(dt) == 0:
      dt = jnp.full([num_steps], dt, dtype=x.q.dtype)
    t0 = jnp.asarray(t0).astype(dt.dtype)
    t = jnp.cumsum(jnp.concatenate([t0[None], dt], axis=0), axis=0)
    dx_dt = df_dt(t, x)
    rng_key, key = jnr.split(rng_key)
    image, extra = self.render_trajectories(x.q, params, key, **kwargs)
    # Build a new dict so a caller-supplied `params` is not mutated.
    other = {**params, **extra}
    return dict(x=x.single_state, dx_dt=dx_dt.single_state,
                image=image, other=other)


class TimeIndependentHamiltonianSystem(HamiltonianSystem):
  """A Hamiltonian system where the energy does not depend on time."""

  @abc.abstractmethod
  def _hamiltonian(
      self,
      y: phase_space.PhaseSpace,
      params: utils.Params,
      **kwargs: Any
  ) -> jnp.ndarray:
    """Computes the time independent Hamiltonian."""

  def parametrized_hamiltonian(
      self,
      t: jnp.ndarray,
      y: phase_space.PhaseSpace,
      params: utils.Params,
      **kwargs: Any
  ) -> jnp.ndarray:
    """Ignores `t` and delegates to `_hamiltonian`."""
    return self._hamiltonian(y, params, **kwargs)
# -----------------------------------------------------------------------------
# /dm_hamiltonian_dynamics_suite/hamiltonian_systems/ideal_double_pendulum.py
# -----------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ideal double pendulum."""
import functools
from typing import Any, Optional, Tuple

from dm_hamiltonian_dynamics_suite.hamiltonian_systems import hamiltonian
from dm_hamiltonian_dynamics_suite.hamiltonian_systems import phase_space
from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils

import jax
import jax.numpy as jnp


class IdealDoublePendulum(hamiltonian.TimeIndependentHamiltonianSystem):
  """An idealized double pendulum system.

  Parameters:
    m_range - possible range of the particle mass
    g_range - possible range of the gravitational force
    l_range - possible range of the length of the pendulum

  The Hamiltonian is:
    H = [m_2 * l_2^2 * p_1^2 + (m_1 + m_2) * l_1^2 * p_2^2 - 2 * m_2 * l_1 *
        l_2 * p_1 * p_2 * cos(q_1 - q_2)] /
        [2 * m_2 * l_1^2 * l_2^2 * (m_1 + m_2 * sin(q_1 - q_2)^2)]
        - (m_1 + m_2) * g * l_1 * cos(q_1) - m_2 * g * l_2 * cos(q_2)

  See https://iopscience.iop.org/article/10.1088/1742-6596/739/1/012066/meta

  Initial state parameters:
    radius_range - The initial state is sampled from a disk in phase space with
      radius in this range.
    uniform_annulus - Whether to sample uniformly on the disk or uniformly the
      radius.
    randomize_canvas_location - Whether to randomize the 2-D position of the
      pendulum when rendering.
  """

  def __init__(
      self,
      m_range: utils.BoxRegion,
      g_range: utils.BoxRegion,
      l_range: utils.BoxRegion,
      radius_range: utils.BoxRegion,
      uniform_annulus: bool = True,
      **kwargs):
    super().__init__(system_dims=2, **kwargs)
    self.m_range = m_range
    self.g_range = g_range
    self.l_range = l_range
    self.radius_range = radius_range
    self.uniform_annulus = uniform_annulus
    render = functools.partial(
        utils.render_particles_trajectory,
        canvas_limits=self.full_canvas_bounds(),
        resolution=self.resolution,
        num_colors=self.num_colors)
    self._batch_render = jax.vmap(render)

  def _hamiltonian(
      self,
      y: phase_space.PhaseSpace,
      params: utils.Params,
      **kwargs: Any
  ) -> jnp.ndarray:
    """Evaluates the Hamiltonian given in the class docstring."""
    assert len(params) == 5
    m_1 = params["m_1"]
    l_1 = params["l_1"]
    m_2 = params["m_2"]
    l_2 = params["l_2"]
    g = params["g"]

    q_1, q_2 = y.q[..., 0], y.q[..., 1]
    p_1, p_2 = y.p[..., 0], y.p[..., 1]

    # Kinetic term: numerator (a_1 + a_2 - a_3) over denominator b_1 * b_2.
    a_1 = m_2 * l_2 ** 2 * p_1 ** 2
    a_2 = (m_1 + m_2) * l_1 ** 2 * p_2 ** 2
    a_3 = 2 * m_2 * l_1 * l_2 * p_1 * p_2 * jnp.cos(q_1 - q_2)
    b_1 = 2 * m_2 * l_1 ** 2 * l_2 ** 2
    b_2 = (m_1 + m_2 * jnp.sin(q_1 - q_2) ** 2)
    # Potential terms.
    c_1 = (m_1 + m_2) * g * l_1 * jnp.cos(q_1)
    c_2 = m_2 * g * l_2 * jnp.cos(q_2)
    return (a_1 + a_2 - a_3) / (b_1 * b_2) - c_1 - c_2

  def sample_y(
      self,
      num_samples: int,
      params: utils.Params,
      rng_key: jnp.ndarray,
      **kwargs: Any
  ) -> phase_space.PhaseSpace:
    """Samples (q_i, p_i) of each bob independently from an annulus."""
    key1, key2 = jax.random.split(rng_key)
    state_1 = utils.uniform_annulus(
        key1, num_samples, 2, self.radius_range, self.uniform_annulus)
    state_2 = utils.uniform_annulus(
        key2, num_samples, 2, self.radius_range, self.uniform_annulus)
    # Reorder to the [q_1, q_2, p_1, p_2] state layout.
    state = jnp.stack([state_1[..., 0], state_2[..., 0],
                       state_1[..., 1], state_2[..., 1]], axis=-1)
    return phase_space.PhaseSpace.from_state(state.astype(self.dtype))

  def sample_params(
      self,
      num_samples: int,
      rng_key: jnp.ndarray,
      **kwargs: Any
  ) -> utils.Params:
    """Samples masses, lengths and gravity uniformly from their ranges."""
    keys = jax.random.split(rng_key, 5)
    m_1 = jax.random.uniform(keys[0], [num_samples], minval=self.m_range.min,
                             maxval=self.m_range.max)
    m_2 = jax.random.uniform(keys[1], [num_samples], minval=self.m_range.min,
                             maxval=self.m_range.max)
    l_1 = jax.random.uniform(keys[2], [num_samples], minval=self.l_range.min,
                             maxval=self.l_range.max)
    l_2 = jax.random.uniform(keys[3], [num_samples], minval=self.l_range.min,
                             maxval=self.l_range.max)
    g = jax.random.uniform(keys[4], [num_samples], minval=self.g_range.min,
                           maxval=self.g_range.max)
    return dict(m_1=m_1, m_2=m_2, l_1=l_1, l_2=l_2, g=g)

  def simulate_analytically(
      self,
      y0: phase_space.PhaseSpace,
      t0: utils.FloatArray,
      t_eval: jnp.ndarray,
      params: utils.Params,
      **kwargs: Any
  ) -> Optional[phase_space.PhaseSpace]:
    """No closed-form solution is provided; always returns None."""
    return None

  def canvas_bounds(self) -> utils.BoxRegion:
    """Bounds covering two fully extended arms plus the largest bob radius."""
    max_d = 2 * self.l_range.max + jnp.sqrt(self.m_range.max / jnp.pi)
    return utils.BoxRegion(-max_d, max_d)

  def canvas_position(
      self,
      position: jnp.ndarray,
      params: utils.Params
  ) -> jnp.ndarray:
    """Returns each bob's (x, y) offset relative to its own pivot.

    An angle of 0 maps to (0, -l), i.e. the bob hangs straight down.
    """
    l_1 = utils.expand_to_rank_right(params["l_1"], 2)
    l_2 = utils.expand_to_rank_right(params["l_2"], 2)
    y_1 = jnp.sin(position[..., 0] - jnp.pi / 2.0) * l_1
    x_1 = jnp.cos(position[..., 0] - jnp.pi / 2.0) * l_1
    position_1 = jnp.stack([x_1, y_1], axis=-1)
    y_2 = jnp.sin(position[..., 1] - jnp.pi / 2.0) * l_2
    x_2 = jnp.cos(position[..., 1] - jnp.pi / 2.0) * l_2
    position_2 = jnp.stack([x_2, y_2], axis=-1)
    return jnp.stack([position_1, position_2], axis=-2)

  def render_trajectories(
      self,
      position: jnp.ndarray,
      params: utils.Params,
      rng_key: jnp.ndarray,
      **kwargs: Any
  ) -> Tuple[jnp.ndarray, utils.Params]:
    """Renders both bobs; returns images and the offset/color parameters."""
    n, _, d = position.shape
    assert d == self.system_dims
    assert len(params) == 5
    key1, key2 = jax.random.split(rng_key, 2)
    m_1 = params["m_1"]
    m_2 = params["m_2"]
    position = self.canvas_position(position, params)
    position_1, position_2 = position[..., 0, :], position[..., 1, :]
    if self.randomize_canvas_location:
      offset = jax.random.uniform(key1, shape=[n, 2])
      offset = self.random_offset_bounds().convert_from_unit_interval(offset)
    else:
      offset = jnp.zeros([n, 2])
    position_1 = position_1 + offset[:, None, :]
    # The second bob hangs from the first one.
    position_2 = position_1 + position_2
    particles = jnp.stack([position_1, position_2], axis=-2)
    radius_1 = jnp.sqrt(m_1 / jnp.pi)
    radius_2 = jnp.sqrt(m_2 / jnp.pi)
    particles_radius = jnp.stack([radius_1, radius_2], axis=-1)
    if self.num_colors == 1:
      color_index = jnp.zeros([n, 2]).astype("int64")
    else:
      color_index = utils.random_int_k_from_n(
          key2,
          num_samples=n,
          n=self.num_colors,
          k=2
      )
    images = self._batch_render(particles, particles_radius, color_index)
    return images, dict(offset=offset, color_index=color_index)

# -----------------------------------------------------------------------------
# /dm_hamiltonian_dynamics_suite/hamiltonian_systems/ideal_mass_spring.py
# -----------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ideal mass spring."""
import functools
from typing import Any, Optional, Tuple

from dm_hamiltonian_dynamics_suite.hamiltonian_systems import hamiltonian
from dm_hamiltonian_dynamics_suite.hamiltonian_systems import phase_space
from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils

import jax
import jax.numpy as jnp


class IdealMassSpring(hamiltonian.TimeIndependentHamiltonianSystem):
  """An idealized mass-spring system (also known as a harmonic oscillator).

  The system is represented in 2 dimensions, but the spring moves only in the
  vertical orientation.

  Parameters:
    k_range - possible range of the spring's force coefficient
    m_range - possible range of the particle mass

  The Hamiltonian is:
    k * q^2 / 2.0 + p^2 / (2 * m)

  Initial state parameters:
    radius_range - The initial state is sampled from a disk in phase space with
      radius in this range.
    uniform_annulus - Whether to sample uniformly on the disk or uniformly the
      radius.
    randomize_x - Whether to randomize the horizontal position of the particle
      when rendering.
  """

  def __init__(
      self,
      k_range: utils.BoxRegion,
      m_range: utils.BoxRegion,
      radius_range: utils.BoxRegion,
      uniform_annulus: bool = True,
      randomize_x: bool = True,
      **kwargs):
    super().__init__(system_dims=1, **kwargs)
    self.k_range = k_range
    self.m_range = m_range
    self.radius_range = radius_range
    self.uniform_annulus = uniform_annulus
    self.randomize_x = randomize_x
    render = functools.partial(utils.render_particles_trajectory,
                               canvas_limits=self.full_canvas_bounds(),
                               resolution=self.resolution,
                               num_colors=self.num_colors)
    self._batch_render = jax.vmap(render)

  def _hamiltonian(
      self,
      y: phase_space.PhaseSpace,
      params: utils.Params,
      **kwargs: Any
  ) -> jnp.ndarray:
    """Evaluates k * q^2 / 2 + p^2 / (2 * m)."""
    assert len(params) == 2
    k = params["k"]
    m = params["m"]
    potential = k * y.q[..., 0] ** 2 / 2
    kinetic = y.p[..., 0] ** 2 / (2 * m)
    return potential + kinetic

  def sample_y(
      self,
      num_samples: int,
      params: utils.Params,
      rng_key: jnp.ndarray,
      **kwargs: Any
  ) -> phase_space.PhaseSpace:
    """Samples (q, p / sqrt(k * m)) from an annulus in phase space."""
    assert len(params) == 2
    k = params["k"]
    m = params["m"]
    state = utils.uniform_annulus(
        rng_key, num_samples, 2, self.radius_range, self.uniform_annulus)
    q = state[..., :1]
    # Scaling by sqrt(k * m) puts the sample on an energy-consistent ellipse.
    p = state[..., 1:] * jnp.sqrt(k * m)
    # Cast like the other systems do so the data honours `self.dtype`.
    return phase_space.PhaseSpace(position=q.astype(self.dtype),
                                  momentum=p.astype(self.dtype))

  def sample_params(
      self,
      num_samples: int,
      rng_key: jnp.ndarray,
      **kwargs: Any
  ) -> utils.Params:
    """Samples the spring constant and mass uniformly from their ranges."""
    key1, key2 = jax.random.split(rng_key)
    k = jax.random.uniform(key1, [num_samples], minval=self.k_range.min,
                           maxval=self.k_range.max)
    m = jax.random.uniform(key2, [num_samples], minval=self.m_range.min,
                           maxval=self.m_range.max)
    return dict(k=k, m=m)

  def simulate_analytically(
      self,
      y0: phase_space.PhaseSpace,
      t0: utils.FloatArray,
      t_eval: jnp.ndarray,
      params: utils.Params,
      **kwargs: Any
  ) -> Optional[phase_space.PhaseSpace]:
    """Closed-form solution q = a cos(w t + b); None when friction is used."""
    if self.friction != 0.0:
      return None
    assert len(params) == 2
    k = params["k"]
    m = params["m"]
    t = t_eval - t0
    w = jnp.sqrt(k / m).astype(self.dtype)
    # Amplitude and phase fitted to the initial condition (q0, p0).
    a = jnp.sqrt(y0.q[..., 0] ** 2 + y0.p[..., 0] ** 2 / (k * m))
    b = jnp.arctan2(- y0.p[..., 0], y0.q[..., 0] * m * w)
    w, a, b, m = w[..., None], a[..., None], b[..., None], m[..., None]
    t = utils.expand_to_rank_right(t, y0.q.ndim + 1)

    q = a * jnp.cos(w * t + b)
    p = - a * m * w * jnp.sin(w * t + b)
    return phase_space.PhaseSpace(position=q, momentum=p)

  def canvas_bounds(self) -> utils.BoxRegion:
    """Bounds covering the largest sampled radius plus the largest particle."""
    max_x = self.radius_range.max
    max_r = jnp.sqrt(self.m_range.max / jnp.pi)
    return utils.BoxRegion(- max_x - max_r, max_x + max_r)

  def canvas_position(
      self,
      position: jnp.ndarray,
      params: utils.Params
  ) -> jnp.ndarray:
    """Places the particle at (0, q): the spring moves vertically."""
    return jnp.stack([jnp.zeros_like(position), position], axis=-1)

  def render_trajectories(
      self,
      position: jnp.ndarray,
      params: utils.Params,
      rng_key: jnp.ndarray,
      **kwargs: Any
  ) -> Tuple[jnp.ndarray, utils.Params]:
    """Renders the particle; returns images and the offset/color parameters."""
    n, _, d = position.shape
    assert d == self.system_dims
    assert len(params) == 2
    # One independent key per random draw; `rng_key` itself is not reused.
    key1, key2, key3 = jax.random.split(rng_key, 3)
    m = utils.expand_to_rank_right(params["m"], 2)
    particles = self.canvas_position(position, params)
    # The canvas is 2-D (x, y) even though the system is 1-D, so offsets are
    # 2-D vectors.
    if self.randomize_x:
      x_offset = jax.random.uniform(key1, shape=[n])
      y_offset = jnp.zeros_like(x_offset)
      offset = jnp.stack([x_offset, y_offset], axis=-1)
    else:
      offset = jnp.zeros([n, 2])
    if self.randomize_canvas_location:
      offset_ = jax.random.uniform(key2, shape=[n, 2])
      offset_ = self.random_offset_bounds().convert_from_unit_interval(offset_)
      offset = offset + offset_
    particles = particles + offset[:, None, None, :]
    particles_radius = jnp.sqrt(m / jnp.pi)
    if self.num_colors == 1:
      color_index = jnp.zeros([n, 1]).astype("int64")
    else:
      color_index = jax.random.randint(
          key=key3, shape=[n, 1], minval=0, maxval=self.num_colors)
    images = self._batch_render(particles, particles_radius, color_index)
    return images, dict(offset=offset, color_index=color_index)

# -----------------------------------------------------------------------------
# /dm_hamiltonian_dynamics_suite/hamiltonian_systems/ideal_pendulum.py
# -----------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ideal pendulum."""
import functools
from typing import Any, Optional, Tuple

from dm_hamiltonian_dynamics_suite.hamiltonian_systems import hamiltonian
from dm_hamiltonian_dynamics_suite.hamiltonian_systems import phase_space
from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils

import jax
import jax.numpy as jnp


class IdealPendulum(hamiltonian.TimeIndependentHamiltonianSystem):
  """An idealized pendulum system.

  Parameters:
    m_range - possible range of the particle mass
    g_range - possible range of the gravitational force
    l_range - possible range of the length of the pendulum

  The Hamiltonian is:
    m * l * g * (1 - cos(q)) + p^2 / (2 * m * l^2)

  Initial state parameters:
    radius_range - The initial state is sampled from a disk in phase space with
      radius in this range.
    uniform_annulus - Whether to sample uniformly on the disk or uniformly the
      radius.
    randomize_canvas_location - Whether to add a random offset to the position
      of the pendulum on the canvas when rendering.
  """

  def __init__(
      self,
      m_range: utils.BoxRegion,
      g_range: utils.BoxRegion,
      l_range: utils.BoxRegion,
      radius_range: utils.BoxRegion,
      uniform_annulus: bool = True,
      **kwargs):
    super().__init__(system_dims=1, **kwargs)
    self.m_range = m_range
    self.g_range = g_range
    self.l_range = l_range
    self.radius_range = radius_range
    self.uniform_annulus = uniform_annulus
    # Renders a single trajectory; vmapped to render a batch of them.
    render = functools.partial(utils.render_particles_trajectory,
                               canvas_limits=self.full_canvas_bounds(),
                               resolution=self.resolution,
                               num_colors=self.num_colors)
    self._batch_render = jax.vmap(render)

  def _hamiltonian(
      self,
      y: phase_space.PhaseSpace,
      params: utils.Params,
      **kwargs: Any
  ) -> jnp.ndarray:
    """Returns m*g*l*(1 - cos(q)) + p^2 / (2*m*l^2) for angle q, momentum p."""
    assert len(params) == 3
    m = params["m"]
    l = params["l"]
    g = params["g"]
    potential = m * g * l * (1 - jnp.cos(y.q[..., 0]))
    kinetic = y.p[..., 0] ** 2 / (2 * m * l ** 2)
    return potential + kinetic

  def sample_y(
      self,
      num_samples: int,
      params: utils.Params,
      rng_key: jnp.ndarray,
      **kwargs: Any
  ) -> phase_space.PhaseSpace:
    """Samples initial (q, p) pairs from the annulus given by `radius_range`."""
    state = utils.uniform_annulus(
        rng_key, num_samples, 2, self.radius_range, self.uniform_annulus)
    return phase_space.PhaseSpace.from_state(state.astype(self.dtype))

  def sample_params(
      self,
      num_samples: int,
      rng_key: jnp.ndarray,
      **kwargs: Any
  ) -> utils.Params:
    """Samples mass `m`, length `l` and gravity `g` uniformly per trajectory."""
    key1, key2, key3 = jax.random.split(rng_key, 3)
    m = jax.random.uniform(key1, [num_samples], minval=self.m_range.min,
                           maxval=self.m_range.max)
    l = jax.random.uniform(key2, [num_samples], minval=self.l_range.min,
                           maxval=self.l_range.max)
    g = jax.random.uniform(key3, [num_samples], minval=self.g_range.min,
                           maxval=self.g_range.max)
    return dict(m=m, l=l, g=g)

  def simulate_analytically(
      self,
      y0: phase_space.PhaseSpace,
      t0: utils.FloatArray,
      t_eval: jnp.ndarray,
      params: utils.Params,
      **kwargs: Any
  ) -> Optional[phase_space.PhaseSpace]:
    """No analytic solution is implemented; always returns None."""
    return None

  def canvas_bounds(self) -> utils.BoxRegion:
    """Symmetric bounds covering the longest pendulum plus the largest bob."""
    max_d = self.l_range.max + jnp.sqrt(self.m_range.max / jnp.pi)
    return utils.BoxRegion(-max_d, max_d)

  def canvas_position(
      self,
      position: jnp.ndarray,
      params: utils.Params
  ) -> jnp.ndarray:
    """Converts the pendulum angle into (x, y) canvas coordinates of the bob.

    The angle is measured from the downward vertical, so the bob is drawn at
    x = l * sin(q), y = -l * cos(q) relative to the pivot.
    """
    l = utils.expand_to_rank_right(params["l"], 2)
    # The system is one-dimensional (system_dims == 1): the angle is the only
    # coordinate, so both x and y are computed from index 0.
    angle = position[..., 0]
    y = jnp.sin(angle - jnp.pi / 2.0) * l
    x = jnp.cos(angle - jnp.pi / 2.0) * l
    return jnp.stack([x, y], axis=-1)

  def render_trajectories(
      self,
      position: jnp.ndarray,
      params: utils.Params,
      rng_key: jnp.ndarray,
      **kwargs: Any
  ) -> Tuple[jnp.ndarray, utils.Params]:
    """Renders a batch of pendulum trajectories to images.

    Args:
      position: array of shape [n, T, 1] holding the angle (T presumably the
        number of time steps).
      params: dict with `m`, `l` and `g`.
      rng_key: PRNG key for the canvas offsets and colors.
      **kwargs: unused.

    Returns:
      The rendered images and a dict with the applied `offset` and the
      `color_index` of each trajectory.
    """
    n, _, d = position.shape
    assert d == self.system_dims
    assert len(params) == 3
    m = utils.expand_to_rank_right(params["m"], 2)
    key1, key2 = jax.random.split(rng_key, 2)
    particles = self.canvas_position(position, params)
    if self.randomize_canvas_location:
      offset = jax.random.uniform(key1, shape=[n, 2])
      offset = self.random_offset_bounds().convert_from_unit_interval(offset)
    else:
      offset = jnp.zeros(shape=[n, 2])
    particles = particles + offset[:, None, :]
    # Add a "number of particles" axis of size 1 expected by the renderer.
    particles = particles[..., None, :]
    particles_radius = jnp.sqrt(m / jnp.pi)
    if self.num_colors == 1:
      color_index = jnp.zeros([n, 1]).astype("int64")
    else:
      color_index = jax.random.randint(
          key=key2, shape=[n, 1], minval=0, maxval=self.num_colors)
    images = self._batch_render(particles, particles_radius, color_index)
    return images, dict(offset=offset, color_index=color_index)
# ------------------------------------------------------------------------------
# /dm_hamiltonian_dynamics_suite/hamiltonian_systems/n_body.py:
# NOTE(review): n_body.py section. It defines NBodySystem, whose Hamiltonian is
# the pairwise potential -sum_{i<j} g * m_i * m_j / |q_i - q_j| plus
# sum_i |p_i|^2 / (2 * m_i), and the concrete TwoBodySystem and
# ThreeBody2DSystem initial-state samplers.
# NOTE(review): original lines 38-76 are missing from this dump (the text
# jumps from "37 | - sum_i" to "jnp.ndarray: 77 |"). The lost lines hold the
# end of the NBodySystem docstring, __init__, and presumably the
# `_hamiltonian` signature, so this section cannot be reconstructed from the
# dump alone.
# NOTE(review): TwoBodySystem.sample_y draws positions with shape
# [num_samples, self.n] and then treats the last axis as the spatial axis.
# This only works because n == space_dims == 2 there; self.space_dims would
# state the intent.
# NOTE(review): ThreeBody2DSystem.sample_y passes the same `rng_key` to both
# jax.random.uniform calls (for `p1` and for `r`). This violates JAX's rule
# that a key is consumed only once, so direction and radius may be
# correlated. Splitting the key first would fix it, but doing so changes the
# generated data for a given seed.
-------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """N body.""" 16 | import abc 17 | import functools 18 | from typing import Any, Optional, Tuple 19 | 20 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import hamiltonian 21 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import phase_space 22 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils 23 | 24 | import jax 25 | import jax.numpy as jnp 26 | 27 | 28 | class NBodySystem(hamiltonian.TimeIndependentHamiltonianSystem): 29 | """An N-body system abstract class.
30 | 31 | Parameters: 32 | m_range - possible range of the particle mass 33 | g_range - possible range of the gravitational force 34 | provided_canvas_bounds - The canvas bounds for the given ranges 35 | 36 | The Hamiltonian is: 37 | - sum_i jnp.ndarray: 77 | assert len(params) == 2 78 | m = params["m"] 79 | g = params["g"] 80 | q = y.q.reshape([-1, self.n, self.space_dims]) 81 | p = y.p.reshape([-1, self.n, self.space_dims]) 82 | 83 | q_ij = jnp.matmul(q, jnp.swapaxes(q, axis1=-1, axis2=-2)) 84 | q_ii = jnp.diagonal(q_ij, axis1=-1, axis2=-2) 85 | q_ij_norms_2 = q_ii[:, None, :] + q_ii[:, :, None] - 2.0 * q_ij 86 | # Adding identity so that on the diagonal the norms are not 0 87 | q_ij_norms = jnp.sqrt(q_ij_norms_2 + jnp.identity(self.n)) 88 | masses_ij = m[:, None, :] * m[:, :, None] 89 | # Remove masses in the diagonal so that those potentials are 0 90 | masses_ij = masses_ij - masses_ij * jnp.identity(self.n)[None] 91 | # Compute pairwise interactions 92 | products = g[:, None, None] * masses_ij / q_ij_norms 93 | # Note that here we are summing both i->j and j->i hence the division by 2 94 | potential = - products.sum(axis=(-2, -1)) / 2 95 | kinetic = jnp.sum(p ** 2, axis=-1) / (2.0 * m) 96 | kinetic = kinetic.sum(axis=-1) 97 | return potential + kinetic 98 | 99 | @abc.abstractmethod 100 | def sample_y( 101 | self, 102 | num_samples: int, 103 | params: utils.Params, 104 | rng_key: jnp.ndarray, 105 | **kwargs: Any 106 | ) -> phase_space.PhaseSpace: 107 | pass 108 | 109 | def sample_params( 110 | self, 111 | num_samples: int, 112 | rng_key: jnp.ndarray, 113 | **kwargs: Any 114 | ) -> utils.Params: 115 | key1, key2 = jax.random.split(rng_key) 116 | m = jax.random.uniform(key1, [num_samples, self.n], minval=self.m_range.min, 117 | maxval=self.m_range.max) 118 | g = jax.random.uniform(key2, [num_samples], minval=self.g_range.min, 119 | maxval=self.g_range.max) 120 | return dict(m=m, g=g) 121 | 122 | def simulate_analytically( 123 | self, 124 | y0:
phase_space.PhaseSpace, 125 | t0: utils.FloatArray, 126 | t_eval: jnp.ndarray, 127 | params: utils.Params, 128 | **kwargs: Any 129 | ) -> Optional[phase_space.PhaseSpace]: 130 | return None 131 | 132 | def canvas_bounds(self) -> utils.BoxRegion: 133 | return self.provided_canvas_bounds 134 | 135 | def canvas_position( 136 | self, 137 | position: jnp.ndarray, 138 | params: utils.Params 139 | ) -> jnp.ndarray: 140 | return position 141 | 142 | def render_trajectories( 143 | self, 144 | position: jnp.ndarray, 145 | params: utils.Params, 146 | rng_key: jnp.ndarray, 147 | **kwargs: Any 148 | ) -> Tuple[jnp.ndarray, utils.Params]: 149 | n, _, d = position.shape 150 | assert d == self.system_dims 151 | assert len(params) == 2 152 | key1, key2 = jax.random.split(rng_key, 2) 153 | m = utils.expand_to_rank_right(params["m"], 2) 154 | particles = position.reshape([n, -1, self.n, self.space_dims]) 155 | if self.randomize_canvas_location: 156 | offset = jax.random.uniform(key1, shape=[n, self.space_dims]) 157 | offset = self.random_offset_bounds().convert_from_unit_interval(offset) 158 | else: 159 | offset = jnp.zeros(shape=[n, self.space_dims]) 160 | particles = particles + offset[:, None, None, :] 161 | particles_radius = jnp.sqrt(m / jnp.pi) 162 | if self.num_colors == 1: 163 | color_index = jnp.zeros([n, self.n]).astype("int64") 164 | else: 165 | if self.num_colors < self.n: 166 | raise ValueError("The number of colors must be at least the number of " 167 | "objects or 1.") 168 | color_index = utils.random_int_k_from_n( 169 | key2, 170 | num_samples=n, 171 | n=self.num_colors, 172 | k=self.n 173 | ) 174 | images = self._batch_render(particles, particles_radius, color_index) 175 | return images, dict(offset=offset, color_index=color_index) 176 | 177 | 178 | class TwoBodySystem(NBodySystem): 179 | """N-body system with N = 2.""" 180 | 181 | def __init__( 182 | self, 183 | m_range: utils.BoxRegion, 184 | g_range: utils.BoxRegion, 185 | radius_range: utils.BoxRegion, 186 |
provided_canvas_bounds: utils.BoxRegion, 187 | **kwargs): 188 | self.radius_range = radius_range 189 | super().__init__(n=2, space_dims=2, 190 | m_range=m_range, 191 | g_range=g_range, 192 | provided_canvas_bounds=provided_canvas_bounds, 193 | **kwargs) 194 | 195 | def sample_y( 196 | self, 197 | num_samples: int, 198 | params: utils.Params, 199 | rng_key: jnp.ndarray, 200 | **kwargs: Any 201 | ) -> phase_space.PhaseSpace: 202 | pos = jax.random.uniform(rng_key, [num_samples, self.n]) 203 | pos = self.radius_range.convert_from_unit_interval(pos) 204 | r = jnp.sqrt(jnp.sum(pos ** 2, axis=-1)) 205 | 206 | vel = jnp.flip(pos, axis=-1) / (2 * r[..., None] ** 1.5) 207 | vel = vel * jnp.asarray([1.0, -1.0]).reshape([1, 2]) 208 | 209 | pos = jnp.repeat(pos.reshape([num_samples, 1, -1]), repeats=self.n, axis=1) 210 | vel = jnp.repeat(vel.reshape([num_samples, 1, -1]), repeats=self.n, axis=1) 211 | 212 | pos = pos * jnp.asarray([1.0, -1.0]).reshape([1, 2, 1]) 213 | vel = vel * jnp.asarray([1.0, -1.0]).reshape([1, 2, 1]) 214 | pos = pos.reshape([num_samples, -1]) 215 | vel = vel.reshape([num_samples, -1]) 216 | return phase_space.PhaseSpace(position=pos, momentum=vel) 217 | 218 | 219 | class ThreeBody2DSystem(NBodySystem): 220 | """N-body system with N = 3 in two dimensions.""" 221 | 222 | def __init__( 223 | self, 224 | m_range: utils.BoxRegion, 225 | g_range: utils.BoxRegion, 226 | radius_range: utils.BoxRegion, 227 | provided_canvas_bounds: utils.BoxRegion, 228 | **kwargs): 229 | self.radius_range = radius_range 230 | super().__init__(n=3, space_dims=2, 231 | m_range=m_range, 232 | g_range=g_range, 233 | provided_canvas_bounds=provided_canvas_bounds, 234 | **kwargs) 235 | 236 | def sample_y( 237 | self, 238 | num_samples: int, 239 | params: utils.Params, 240 | rng_key: jnp.ndarray, 241 | **kwargs: Any 242 | ) -> phase_space.PhaseSpace: 243 | theta = 2 * jnp.pi / 3 244 | rot = jnp.asarray([[jnp.cos(theta), - jnp.sin(theta)], 245 | [jnp.sin(theta), jnp.cos(theta)]]) 246 |
p1 = 2 * jax.random.uniform(rng_key, [num_samples, 2]) - 1.0 247 | r = jax.random.uniform(rng_key, [num_samples]) 248 | r = self.radius_range.convert_from_unit_interval(r) 249 | 250 | p1 *= (r / jnp.linalg.norm(p1, axis=-1))[:, None] 251 | p2 = jnp.matmul(p1, rot.T) 252 | p3 = jnp.matmul(p2, rot.T) 253 | p = jnp.concatenate([p1, p2, p3], axis=-1) 254 | 255 | # scale factor to get circular trajectories 256 | factor = jnp.sqrt(jnp.sin(jnp.pi / 3)/(2 * jnp.cos(jnp.pi / 6) **2)) 257 | # velocity that yields a circular orbit 258 | v1 = jnp.flip(p1, axis=-1) * factor / r[:, None]**1.5 259 | v2 = jnp.matmul(v1, rot.T) 260 | v3 = jnp.matmul(v2, rot.T) 261 | v = jnp.concatenate([v1, v2, v3], axis=-1) 262 | return phase_space.PhaseSpace(position=p, momentum=v) 263 | -------------------------------------------------------------------------------- /dm_hamiltonian_dynamics_suite/hamiltonian_systems/phase_space.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
"""Module for the PhaseSpace class."""
import functools
from typing import Callable, Type, Union

import jax
from jax import numpy as jnp
from jax import tree_util


class PhaseSpace:
  """A (position, momentum) pair describing a state of a Hamiltonian system.

  `q` and `p` are short aliases for `position` and `momentum`. Both this class
  and `TangentPhaseSpace` are registered as JAX pytrees at the end of this
  module, so instances can flow through `jax.grad`, `jax.vmap`, etc.
  """

  def __init__(self, position: jnp.ndarray, momentum: jnp.ndarray):
    self._position = position
    self._momentum = momentum

  @property
  def position(self) -> jnp.ndarray:
    """The position component of the state."""
    return self._position

  @property
  def momentum(self) -> jnp.ndarray:
    """The momentum component of the state."""
    return self._momentum

  @property
  def q(self) -> jnp.ndarray:
    """Alias of `position`."""
    return self._position

  @property
  def p(self) -> jnp.ndarray:
    """Alias of `momentum`."""
    return self._momentum

  @property
  def single_state(self) -> jnp.ndarray:
    """Position and momentum joined along the last axis, i.e. [q, p]."""
    return jnp.concatenate((self.q, self.p), axis=-1)

  @property
  def ndim(self) -> int:
    """Rank of the position array."""
    return self.q.ndim

  @classmethod
  def from_state(cls: Type["PhaseSpace"], state: jnp.ndarray) -> "PhaseSpace":
    """Inverse of `single_state`: splits `state` in half on its last axis."""
    position, momentum = jnp.split(state, 2, axis=-1)
    return cls(position=position, momentum=momentum)

  def __str__(self) -> str:
    return f"{type(self).__name__}(q={self.position}, p={self.momentum})"

  def __repr__(self) -> str:
    return str(self)


class TangentPhaseSpace(PhaseSpace):
  """A vector in the tangent space of PhaseSpace (e.g. a time derivative).

  Supports addition with tangent vectors and with phase-space points, and
  scaling by scalars or arrays.
  """

  def __add__(
      self,
      other: Union[PhaseSpace, "TangentPhaseSpace"],
  ) -> Union[PhaseSpace, "TangentPhaseSpace"]:
    # tangent + tangent stays a tangent; tangent + point moves the point.
    if isinstance(other, TangentPhaseSpace):
      result_type = TangentPhaseSpace
    elif isinstance(other, PhaseSpace):
      result_type = PhaseSpace
    else:
      raise ValueError(f"Can not add TangentPhaseSpace and {type(other)}.")
    return result_type(position=self.q + other.q, momentum=self.p + other.p)

  def __radd__(
      self,
      other: Union[PhaseSpace, "TangentPhaseSpace"]
  ) -> Union[PhaseSpace, "TangentPhaseSpace"]:
    # Addition is commutative, so reuse the left-hand implementation.
    return self.__add__(other)

  def __mul__(self, other: jnp.ndarray) -> "TangentPhaseSpace":
    scaled_q = self.q * other
    scaled_p = self.p * other
    return TangentPhaseSpace(position=scaled_q, momentum=scaled_p)

  def __rmul__(self, other):
    return self.__mul__(other)

  @classmethod
  def zero(cls: Type["TangentPhaseSpace"]) -> "TangentPhaseSpace":
    """A scalar zero tangent vector that broadcasts against any state."""
    zero_value = jnp.asarray(0.0)
    return cls(position=zero_value, momentum=zero_value)


# (t, y) -> H(t, y).
HamiltonianFunction = Callable[[jnp.ndarray, PhaseSpace], jnp.ndarray]

# (t, (q, p)) -> (dH_dp, - dH_dq).
SymplecticTangentFunction = Callable[[jnp.ndarray, PhaseSpace],
                                     TangentPhaseSpace]

# Same as SymplecticTangentFunction, but on flat [q, p] arrays.
SymplecticTangentFunctionArray = Callable[[jnp.ndarray, jnp.ndarray],
                                          jnp.ndarray]


def poisson_bracket_with_q_and_p(
    f: HamiltonianFunction
) -> SymplecticTangentFunction:
  """Returns a function that computes the Poisson brackets {q,f} and {p,f}.

  {q, f} = df/dp and {p, f} = -df/dq, returned together as a
  TangentPhaseSpace.
  """
  def summed_f(*args):
    # Summing the (possibly batched) output gives a scalar so `jax.grad`
    # applies; this assumes each output entry depends only on its own state.
    return jnp.sum(f(*args))

  # Argument 0 is the time, argument 1 the phase-space state.
  grad_wrt_state = jax.grad(summed_f, argnums=1)

  def bracket(t: jnp.ndarray, y: PhaseSpace) -> TangentPhaseSpace:
    grad = grad_wrt_state(t, y)
    return TangentPhaseSpace(position=grad.p, momentum=-grad.q)
  return bracket


def transform_symplectic_tangent_function_using_array(
    func: SymplecticTangentFunction
) -> SymplecticTangentFunctionArray:
  """Adapts `func` to consume and produce flat [q, p] state arrays."""
  @functools.wraps(func)
  def wrapped(t: jnp.ndarray, state: jnp.ndarray) -> jnp.ndarray:
    tangent = func(t, PhaseSpace.from_state(state))
    return tangent.single_state
  return wrapped


def _register_phase_space_pytree(node_type):
  """Registers `node_type` as a pytree whose children are (q, p)."""
  tree_util.register_pytree_node(
      nodetype=node_type,
      flatten_func=lambda y: ((y.q, y.p), None),
      unflatten_func=lambda _, q_and_p: node_type(*q_and_p),
  )


_register_phase_space_pytree(PhaseSpace)
_register_phase_space_pytree(TangentPhaseSpace)
# ------------------------------------------------------------------------------
# /dm_hamiltonian_dynamics_suite/hamiltonian_systems/simple_analytic.py:
# ------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A module with all Hamiltonian systems that have analytic solutions."""
from typing import Any, Optional, Tuple

from dm_hamiltonian_dynamics_suite.hamiltonian_systems import hamiltonian
from dm_hamiltonian_dynamics_suite.hamiltonian_systems import phase_space
from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils

import jax.numpy as jnp
import jax.random as jnr


class PotentialFreeSystem(hamiltonian.TimeIndependentHamiltonianSystem):
  """A system where the potential energy is 0 and the kinetic is quadratic.

  Parameters:
    matrix - a positive semi-definite matrix used for the kinetic quadratic.
      It is sampled as (Q diag(e)) (Q diag(e))^T = Q diag(e^2) Q^T, where Q is
      a random orthogonal matrix and e is drawn from `eigen_values_range`, so
      the eigenvalues of the matrix are the squares of the sampled values.

  The Hamiltonian is:
    p^T M p / 2

  Initial state parameters:
    init_vector_range - the range from which every coordinate of the initial
      position and momentum is sampled uniformly. It is either scalar or has
      `system_dims` entries, applied to the position and the momentum alike.
  """

  def __init__(
      self,
      system_dims: int,
      eigen_values_range: utils.BoxRegion,
      init_vector_range: utils.BoxRegion,
      **kwargs):
    """Validates and stores the sampling ranges.

    Args:
      system_dims: dimensionality of the position (and of the momentum).
      eigen_values_range: scalar range or range with `system_dims` entries used
        to sample the matrix spectrum (see the class docstring).
      init_vector_range: scalar range or range with `system_dims` entries used
        to sample the initial position and momentum.
      **kwargs: forwarded to the base class.

    Raises:
      ValueError: if a range is neither scalar nor `system_dims`-dimensional.
    """
    super().__init__(system_dims=system_dims, **kwargs)
    if eigen_values_range.dims != 0 and eigen_values_range.dims != system_dims:
      raise ValueError(f"The eigen_values_range must be of the same dimensions "
                       f"as the system dimensions, but is "
                       f"{eigen_values_range.dims}.")
    if init_vector_range.dims != 0 and init_vector_range.dims != system_dims:
      raise ValueError(f"The init_vector_range must be of the same dimensions "
                       f"as the system dimensions, but is "
                       f"{init_vector_range.dims}.")
    self.eigen_values_range = eigen_values_range
    self.init_vector_range = init_vector_range

  def _hamiltonian(
      self,
      y: phase_space.PhaseSpace,
      params: utils.Params,
      **kwargs: Any
  ) -> jnp.ndarray:
    """Returns p^T M p / 2 for the momentum p and the sampled matrix M."""
    assert len(params) == 1
    matrix = params["matrix"]
    potential = 0
    # p^T M p == p . (M p). `utils.vecmul` pairs each matrix with its own
    # vector, so a batch of matrices is matched against a batch of momenta
    # instead of being broadcast against every momentum (which is what
    # `jnp.matmul(y.p, matrix)` does for 2-D `y.p` and 3-D `matrix`).
    kinetic = jnp.sum(utils.vecmul(matrix, y.p) * y.p, axis=-1) / 2
    return potential + kinetic

  def sample_y(
      self,
      num_samples: int,
      params: utils.Params,
      rng_key: jnp.ndarray,
      **kwargs: Any
  ) -> phase_space.PhaseSpace:
    """Samples initial states uniformly from `init_vector_range`."""
    # Sample random state
    y = jnr.uniform(rng_key, [num_samples, 2 * self.system_dims],
                    dtype=self.dtype)
    # View the state as [num_samples, (position, momentum), system_dims] so a
    # per-dimension range broadcasts over both halves of the state.
    y = y.reshape([num_samples, 2, self.system_dims])
    y = self.init_vector_range.convert_from_unit_interval(y)
    y = y.reshape([num_samples, 2 * self.system_dims])
    return phase_space.PhaseSpace.from_state(y)

  def sample_params(
      self,
      num_samples: int,
      rng_key: jnp.ndarray,
      **kwargs: Any
  ) -> utils.Params:
    """Samples a symmetric PSD matrix per trajectory (see class docstring)."""
    key1, key2 = jnr.split(rng_key)
    matrix_shape = [num_samples, self.system_dims, self.system_dims]
    gaussian = jnr.normal(key1, matrix_shape)
    # The Q factor of a Gaussian matrix is a random orthogonal matrix.
    q, _ = jnp.linalg.qr(gaussian)
    eigs = jnr.uniform(key2, [num_samples, self.system_dims])
    eigs = self.eigen_values_range.convert_from_unit_interval(eigs)
    # Scale the *columns* of Q, giving Q diag(e). Scaling the rows instead
    # gives diag(e) Q, whose Gram matrix diag(e) Q Q^T diag(e) = diag(e^2)
    # discards the random rotation entirely.
    q_eigs = q * eigs[..., None, :]
    matrix = jnp.matmul(q_eigs, jnp.swapaxes(q_eigs, -2, -1))
    return dict(matrix=matrix)

  def simulate_analytically(
      self,
      y0: phase_space.PhaseSpace,
      t0: utils.FloatArray,
      t_eval: jnp.ndarray,
      params: utils.Params,
      **kwargs: Any
  ) -> Optional[phase_space.PhaseSpace]:
    """Exact solution: p stays constant and q(t) = q0 + M p0 (t - t0).

    Relies on M being symmetric (true for the Gram matrices sampled above).
    Returns None when friction is non-zero.
    """
    if self.friction != 0.0:
      return None
    assert len(params) == 1
    matrix = params["matrix"]
    t = utils.expand_to_rank_right(t_eval - t0, y0.q.ndim + 1)
    q = y0.q[None] + utils.vecmul(matrix, y0.p)[None] * t
    p = y0.p[None] * jnp.ones_like(t)
    return phase_space.PhaseSpace(position=q, momentum=p)

  def canvas_bounds(self) -> utils.BoxRegion:
    """Rendering is not supported for this system."""
    raise NotImplementedError()

  def canvas_position(
      self,
      position: jnp.ndarray,
      params: utils.Params
  ) -> jnp.ndarray:
    """Rendering is not supported for this system."""
    raise NotImplementedError()

  def render_trajectories(
      self,
      position: jnp.ndarray,
      params: utils.Params,
      rng_key: jnp.ndarray,
      **kwargs: Any
  ) -> Tuple[jnp.ndarray, utils.Params]:
    """Rendering is not supported for this system."""
    raise NotImplementedError()


class KineticFreeSystem(PotentialFreeSystem):
  """A system where the kinetic energy is 0 and the potential is quadratic.

  Parameters:
    matrix - a positive semi-definite matrix used for the potential quadratic,
      sampled exactly as in `PotentialFreeSystem`.

  The Hamiltonian is:
    q^T M q / 2

  Initial state parameters:
    init_vector_range - as in `PotentialFreeSystem`.
  """

  def _hamiltonian(
      self,
      y: phase_space.PhaseSpace,
      params: utils.Params,
      **kwargs: Any
  ) -> jnp.ndarray:
    """Returns q^T M q / 2 for the position q and the sampled matrix M."""
    assert len(params) == 1
    matrix = params["matrix"]
    # q^T M q == q . (M q), computed per sample (see PotentialFreeSystem).
    potential = jnp.sum(utils.vecmul(matrix, y.q) * y.q, axis=-1) / 2
    kinetic = 0
    return potential + kinetic

  def simulate_analytically(
      self,
      y0: phase_space.PhaseSpace,
      t0: utils.FloatArray,
      t_eval: jnp.ndarray,
      params: utils.Params,
      **kwargs: Any
  ) -> Optional[phase_space.PhaseSpace]:
    """Exact solution: q stays constant and p(t) = p0 - M q0 (t - t0).

    Returns None when friction is non-zero.
    """
    if self.friction != 0.0:
      return None
    assert len(params) == 1
    matrix = params["matrix"]
    t = utils.expand_to_rank_right(t_eval - t0, y0.q.ndim + 1)
    q = y0.q[None] * jnp.ones_like(t)
    p = y0.p[None] - utils.vecmul(matrix, y0.q)[None] * t
    return phase_space.PhaseSpace(position=q, momentum=p)

  def canvas_bounds(self) -> utils.BoxRegion:
    """Rendering is not supported for this system."""
    raise NotImplementedError()

  def canvas_position(
      self,
      position: jnp.ndarray,
      params: utils.Params
  ) -> jnp.ndarray:
    """Rendering is not supported for this system."""
    raise NotImplementedError()

  def render_trajectories(
      self,
      position: jnp.ndarray,
      params: utils.Params,
      rng_key: jnp.ndarray,
      **kwargs: Any
  ) -> Tuple[jnp.ndarray, utils.Params]:
    """Rendering is not supported for this system."""
    raise NotImplementedError()
# ------------------------------------------------------------------------------
# /dm_hamiltonian_dynamics_suite/hamiltonian_systems/utils.py:
# ------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module with various utilities not central to the code."""
from typing import Dict, Optional, Tuple, Union

import jax
from jax import lax
import jax.numpy as jnp
import jax.random as jnr
import numpy as np


FloatArray = Union[float, jnp.ndarray]
Params = Dict[str, jnp.ndarray]


class BoxRegion:
  """A class for bounds, especially used for sampling."""

  def __init__(self, minimum: FloatArray, maximum: FloatArray):
    minimum = jnp.asarray(minimum)
    maximum = jnp.asarray(maximum)
    if minimum.shape != maximum.shape:
      raise ValueError(f"The passed values for minimum and maximum should have "
                       f"the same shape, but had shapes: {minimum.shape} "
                       f"and {maximum.shape}.")
    self._minimum = minimum
    self._maximum = maximum

  @property
  def min(self) -> jnp.ndarray:
    """The lower bound (scalar or per-dimension)."""
    return self._minimum

  @property
  def max(self) -> jnp.ndarray:
    """The upper bound (same shape as `min`)."""
    return self._maximum

  @property
  def size(self) -> jnp.ndarray:
    """The extent `max - min`."""
    return self.max - self.min

  @property
  def dims(self) -> int:
    """Size of the last axis of the bounds, or 0 for scalar bounds."""
    if self.min.ndim != 0:
      return self.min.shape[-1]
    return 0

  def convert_to_unit_interval(self, value: jnp.ndarray) -> jnp.ndarray:
    """Maps `value` from [min, max] to [0, 1] (affinely, no clipping)."""
    return (value - self.min) / self.size

  def convert_from_unit_interval(self, value: jnp.ndarray) -> jnp.ndarray:
    """Maps `value` from [0, 1] to [min, max] (affinely, no clipping)."""
    return value * self.size + self.min

  def __str__(self) -> str:
    return f"{type(self).__name__}(min={self.min}, max={self.max})"

  def __repr__(self) -> str:
    return self.__str__()


def expand_to_rank_right(x: jnp.ndarray, rank: int) -> jnp.ndarray:
  """Appends trailing singleton axes to `x` until it has rank `rank`."""
  if x.ndim == rank:
    return x
  assert x.ndim < rank
  new_shape = x.shape + (1,) * (rank - x.ndim)
  return x.reshape(new_shape)


def expand_to_rank_left(x: jnp.ndarray, rank: int) -> jnp.ndarray:
  """Prepends leading singleton axes to `x` until it has rank `rank`."""
  if x.ndim == rank:
    return x
  assert x.ndim < rank
  new_shape = (1,) * (rank - x.ndim) + x.shape
  return x.reshape(new_shape)


def vecmul(matrix: jnp.ndarray, vector: jnp.ndarray) -> jnp.ndarray:
  """Batched matrix-vector product `matrix @ vector` over leading axes."""
  return jnp.matmul(matrix, vector[..., None])[..., 0]


def dt_to_t_eval(t0: FloatArray, dt: FloatArray, num_steps: int) -> jnp.ndarray:
  """Returns the times t0 + dt, t0 + 2 * dt, ..., t0 + num_steps * dt.

  If both `t0` and `dt` are Python floats or numpy arrays, the result is a
  numpy array with the time axis first; otherwise a jax array is returned.
  """
  if (isinstance(t0, (float, np.ndarray)) and
      isinstance(dt, (float, np.ndarray))):
    dt = np.asarray(dt)[None]
    shape = [num_steps] + [1] * (dt.ndim - 1)
    return t0 + dt * np.arange(1, num_steps + 1).reshape(shape)
  else:
    return t0 + dt * jnp.arange(1, num_steps + 1)


def t_eval_to_dt(t0: FloatArray, t_eval: FloatArray) -> jnp.ndarray:
  """Returns the step sizes between consecutive times, starting from `t0`."""
  t = jnp.ones_like(t_eval[:1]) * t0
  t = jnp.concatenate([t, t_eval], axis=0)
  return t[1:] - t[:-1]


def simple_loop(
    f,
    x0: jnp.ndarray,
    t_args: Optional[jnp.ndarray] = None,
    num_steps: Optional[int] = None,
    use_scan: bool = True
) -> jnp.ndarray:
  """Runs a simple loop that outputs the evolved variable at every time step.

  Args:
    f: step function, called as `f(x)` when `t_args` is None and as
      `f(x, t_arg)` otherwise.
    x0: initial value (an array or a pytree of arrays).
    t_args: per-step arguments stacked along axis 0; fixes the number of steps.
    num_steps: number of steps when `t_args` is not provided.
    use_scan: whether to use `lax.scan` or an unrolled Python loop.

  Returns:
    The value after each step (excluding `x0`), stacked along a new leading
    axis; a pytree with the structure of `x0`.

  Raises:
    ValueError: unless exactly one of `t_args` and `num_steps` is provided.
  """
  if (t_args is None) == (num_steps is None):
    raise ValueError("Exactly one of `t_args` and `num_steps` should be "
                     "provided.")

  def step(x_t, t_arg):
    x_next = f(x_t) if t_arg is None else f(x_t, t_arg)
    return x_next, x_next
  if use_scan:
    return lax.scan(step, init=x0, xs=t_args, length=num_steps)[1]

  y = []
  x = x0
  num_steps = t_args.shape[0] if t_args is not None else num_steps
  t_args = [None] * num_steps if t_args is None else t_args
  for i in range(num_steps):
    x, _ = step(x, t_args[i])
    y.append(x)
  # `jax.tree_multimap` was removed from JAX; `tree_map` accepts several trees.
  return jax.tree_util.tree_map(lambda *args: jnp.stack(args, axis=0), *y)


def hsv2rgb(array: jnp.ndarray) -> jnp.ndarray:
  """Converts an HSV float array image to RGB."""
  hi = jnp.floor(array[..., 0] * 6)
  f = array[..., 0] * 6 - hi
  p = array[..., 2] * (1 - array[..., 1])
  q = array[..., 2] * (1 - f * array[..., 1])
  t = array[..., 2] * (1 - (1 - f) * array[..., 1])
  v = array[..., 2]

  # Hue sector in 0..5 (hue == 1.0 wraps around to sector 0).
  hi = jnp.stack([hi, hi, hi], axis=-1).astype(jnp.uint8) % 6
  hi_is_0 = hi == 0
  hi_is_1 = hi == 1
  hi_is_2 = hi == 2
  hi_is_3 = hi == 3
  hi_is_4 = hi == 4
  hi_is_5 = hi == 5
  out = (hi_is_0 * jnp.stack((v, t, p), axis=-1) +
         hi_is_1 * jnp.stack((q, v, p), axis=-1) +
         hi_is_2 * jnp.stack((p, v, t), axis=-1) +
         hi_is_3 * jnp.stack((p, q, v), axis=-1) +
         hi_is_4 * jnp.stack((t, p, v), axis=-1) +
         hi_is_5 * jnp.stack((v, p, q), axis=-1))
  return out


def render_particles_trajectory(
    particles: jnp.ndarray,
    particles_radius: FloatArray,
    color_indices: FloatArray,
    canvas_limits: BoxRegion,
    resolution: int,
    num_colors: int,
    background_color: Tuple[float, float, float] = (0.321, 0.349, 0.368),
    temperature: FloatArray = 80.0):
  """Renders n particles in different colors for a full trajectory.

  NB: The default background color is not black as we have experienced issues
  when training models with black background.

  Args:
    particles: Array of size (t, n, 2)
      The last 2 dimensions define the x, y coordinates of each particle.
    particles_radius: Array of size (n,) or a single value.
      Defines the radius of each particle.
    color_indices: Array of size (n,) or a single value.
      Integer index of each particle's color among `num_colors` evenly spaced
      hues.
    canvas_limits: BoxRegion
      The region of the plane mapped onto the canvas.
    resolution: int
      The resolution of the produced images.
    num_colors: int
      The number of possible colors to use.
    background_color: List or Array of size (3)
      The RGB color of the background. Defaults to a dark gray.
    temperature: float
      The temperature of the sigmoid distance metric used to the center of the
      particles.

  Returns:
    An array of size (t, resolution, resolution, 3) with the produced images.
  """
  particles = jnp.asarray(particles)
  assert particles.ndim == 3
  assert particles.shape[-1] == 2
  t, n = particles.shape[:2]
  particles_radius = jnp.asarray(particles_radius)
  if particles_radius.ndim == 0:
    particles_radius = jnp.full([n], particles_radius)
  assert particles_radius.shape == (n,)
  color_indices = jnp.asarray(color_indices)
  if color_indices.ndim == 0:
    color_indices = jnp.full([n], color_indices)
  assert color_indices.shape == (n,), f"Colors shape: {color_indices.shape}"
  background_color = jnp.asarray(background_color)
  assert background_color.shape == (3,)

  particles = canvas_limits.convert_to_unit_interval(particles)
  canvas_size = canvas_limits.max - canvas_limits.min
  # Radii are normalized by the first (x) extent of the canvas.
  canvas_size = canvas_size[0] if canvas_size.ndim == 1 else canvas_size
  particles_radius = particles_radius / canvas_size
  images = jnp.ones([t, resolution, resolution, 3]) * background_color

  hues = jnp.linspace(0, 1, num=num_colors, endpoint=False)
  colors = hues[color_indices][None, :, None, None]
  s_channel = jnp.ones((t, n, resolution, resolution))
  v_channel = jnp.ones((t, n, resolution, resolution))
  h_channel = jnp.ones((t, n, resolution, resolution)) * colors
  hsv_imgs = jnp.stack((h_channel, s_channel, v_channel), axis=-1)
  rgb_imgs = hsv2rgb(hsv_imgs)
  images = [img[:, 0] for img in jnp.split(rgb_imgs, n, axis=1)] + [images]

  grid = jnp.linspace(0.0, 1.0, resolution)
  dx, dy = jnp.meshgrid(grid, grid)
  dx, dy = dx[None, None], dy[None, None]
  x, y = particles[..., 0][..., None, None], particles[..., 1][..., None, None]
  d = jnp.sqrt((x - dx) ** 2 + (y - dy) ** 2)
  particles_radius = particles_radius[..., None, None]
  # Soft disk mask: a sigmoid of the signed distance to the particle's edge.
  mask = 1.0 / (1.0 + jnp.exp((d - particles_radius) * temperature))
  masks = ([m[:, 0, ..., None] for m in jnp.split(mask, n, axis=1)] +
           [jnp.ones_like(images[0])])

  # Front-to-back compositing: earlier particles occlude later ones and the
  # background comes last; `c` tracks the remaining uncovered coverage.
  final_image = jnp.zeros([t, resolution, resolution, 3])
  c = jnp.ones_like(images[0])
  for img, m in zip(images, masks):
    final_image = final_image + c * m * img
    c = c * (1 - m)
  return final_image


def uniform_annulus(
    key: jnp.ndarray,
    num_samples: int,
    dim_samples: int,
    radius_range: BoxRegion,
    uniform: bool
) -> jnp.ndarray:
  """Samples points in the annulus defined by radius range.

  With `uniform=True` the squared radius is sampled uniformly (uniform in area
  for `dim_samples == 2`); otherwise the radius itself is uniform.
  """
  key1, key2 = jnr.split(key)
  direction = jnr.normal(key1, [num_samples, dim_samples])
  norms = jnp.linalg.norm(direction, axis=-1, keepdims=True)
  # Normalized Gaussian vectors are uniformly distributed directions.
  direction = direction / norms
  r = jnr.uniform(key2, [num_samples])
  if uniform:
    radius_range = BoxRegion(radius_range.min ** 2, radius_range.max ** 2)
    r = jnp.sqrt(radius_range.convert_from_unit_interval(r))
  else:
    r = radius_range.convert_from_unit_interval(r)
  return direction * r[:, None]


# For each row of x (with its own key): the first k entries of a permutation.
multi_shuffle = jax.vmap(lambda x, key, k: jnr.permutation(key, x)[:k],
                         in_axes=(0, 0, None), out_axes=0)


def random_int_k_from_n(
    rng: jnp.ndarray,
    num_samples: int,
    n: int,
    k: int
) -> jnp.ndarray:
  """Samples, for each of `num_samples` rows, k distinct integers in [0, n).

  Raises:
    ValueError: if k > n.
  """
  if k > n:
    raise ValueError(f"k should be less than or equal to n, but got k={k} and "
                     f"n={n}.")
  x = jnp.repeat(jnp.arange(n).reshape([1, n]), num_samples, axis=0)
  return multi_shuffle(x, jnr.split(rng, num_samples), k)
# ------------------------------------------------------------------------------
# /dm_hamiltonian_dynamics_suite/load_datasets.py:
# ------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited.
2 | # 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """A module for loading the Hamiltonian datasets.""" 16 | import functools 17 | import os 18 | from typing import Optional, Mapping, Any, Tuple, Sequence, Callable, Union, TypeVar 19 | 20 | import jax 21 | import tensorflow as tf 22 | import tensorflow_datasets as tfds 23 | 24 | 25 | T = TypeVar("T") 26 | 27 | 28 | def filter_based_on_keys( 29 | example: Mapping[str, T], 30 | keys_to_preserve: Sequence[str], 31 | single_key_return_array: bool = False 32 | ) -> Union[T, Mapping[str, T]]: 33 | """Filters the contents of the mapping, to return only the keys given in `keys_to_preserve`.""" 34 | if not keys_to_preserve: 35 | raise ValueError("You must provide at least one key to preserve.") 36 | if len(keys_to_preserve) == 1 and single_key_return_array: 37 | return example[keys_to_preserve[0]] 38 | elif single_key_return_array: 39 | raise ValueError(f"You have provided {len(keys_to_preserve)}>1 keys to " 40 | f"preserve and have also set " 41 | f"single_key_return_array=True.") 42 | return {k: example[k] for k in keys_to_preserve} 43 | 44 | 45 | def preprocess_batch( 46 | batch: Mapping[str, Any], 47 | num_local_devices: int, 48 | multi_device: bool, 49 | sub_sample_length: Optional[int], 50 | dtype: str = "float32" 51 | ) -> Mapping[str, Any]: 52 | """Function to preprocess the data for a batch. 
53 | 54 | This performs two functions: 55 | 1.If 'sub_sample_length' is not None, it randomly subsamples every example 56 | along the second (time) axis to return an array of the requested length. 57 | Note that this assumes that all arrays have the same length. 58 | 2. Converts all arrays to the provided data type using 59 | `tf.image.convert_image_dtype`. 60 | 3. Reshapes the array based on the number of local devices. 61 | 62 | Args: 63 | batch: Dictionary with the full batch. 64 | num_local_devices: The number of local devices. 65 | multi_device: Whether to prepare the batch for multi device training. 66 | sub_sample_length: The sub-sampling length requested. 67 | dtype: String for the dtype, must be a floating-point type. 68 | 69 | Returns: 70 | The preprocessed batch. 71 | """ 72 | if dtype not in ("float32", "float64"): 73 | raise ValueError("The provided dtype must be a floating point dtype.") 74 | tensor = batch.get("image", batch.get("x", None)) 75 | if tensor is None: 76 | raise ValueError("We need either the key 'image' or 'x' to be present in " 77 | "the batch provided.") 78 | if not isinstance(tensor, tf.Tensor): 79 | raise ValueError(f"Expecting the value for key 'image' or 'x' to be a " 80 | f"tf.Tensor, instead got {type(tensor)}.") 81 | # `n` here represents the batch size 82 | n = tensor.shape[0] or tf.shape(tensor)[0] 83 | # `t` here represents number of time steps in the batch 84 | t = tensor.shape[1] 85 | if sub_sample_length is not None: 86 | # Sample random index for each example in the batch 87 | limit = t - sub_sample_length + 1 88 | start = tf.random.uniform(shape=[n], maxval=limit, dtype="int32") 89 | indices = tf.range(sub_sample_length)[None, :, None] + start[:, None, None] 90 | def index(x): 91 | """Indexes every array in the batch according to the sampled indices and length, if its second dimensions is equal to `t`.""" 92 | if x.shape.rank > 1 and x.shape[1] == t: 93 | if isinstance(n, tf.Tensor): 94 | shape = [None, 
sub_sample_length] + list(x.shape[2:]) 95 | else: 96 | shape = [n, sub_sample_length] + list(x.shape[2:]) 97 | x = tf.gather_nd(x, indices, batch_dims=1) 98 | x.set_shape(shape) 99 | return x 100 | 101 | batch = jax.tree_map(index, batch) 102 | 103 | def convert_fn(x): 104 | """Converts the value of `x` to the provided precision dtype. 105 | 106 | Integer valued arrays, with data type different from int32 and int64 are 107 | assumed to represent compressed images and are converted via 108 | `tf.image.convert_image_dtype`. For any other data types (float or int) 109 | their type is preserved, but their precision is changed based on the 110 | target `dtype`. For instance, if `dtype=float32` the float64 variables are 111 | converted to float32 and int64 values are converted to int32. 112 | 113 | Args: 114 | x: The input array. 115 | 116 | Returns: 117 | The converted output array. 118 | """ 119 | if x.dtype == tf.int64: 120 | return tf.cast(x, "int32") if dtype == "float32" else x 121 | elif x.dtype == tf.int32: 122 | return tf.cast(x, "int64") if dtype == "float64" else x 123 | elif x.dtype == tf.float64 or x.dtype == tf.float32: 124 | return tf.cast(x, dtype=dtype) 125 | else: 126 | return tf.image.convert_image_dtype(x, dtype=dtype) 127 | 128 | batch = jax.tree_map(convert_fn, batch) 129 | if not multi_device: 130 | return batch 131 | def reshape_for_jax_pmap(x): 132 | """Reshapes values such that their leading dimension is the number of local devices.""" 133 | return tf.reshape(x, [num_local_devices, -1] + x.shape[1:].as_list()) 134 | return jax.tree_map(reshape_for_jax_pmap, batch) 135 | 136 | 137 | def load_filenames_and_parse_fn( 138 | path: str, 139 | tfrecord_prefix: str 140 | ) -> Tuple[Tuple[str], Callable[[str], Mapping[str, Any]]]: 141 | """Returns the file names and read_fn based on the number of shards.""" 142 | file_name = os.path.join(path, f"{tfrecord_prefix}.tfrecord") 143 | if not os.path.exists(file_name): 144 | raise ValueError(f"The dataset file 
{file_name} does not exist.") 145 | features_file = os.path.join(path, "features.txt") 146 | if not os.path.exists(features_file): 147 | raise ValueError(f"The dataset features file {features_file} does not " 148 | f"exist.") 149 | with open(features_file, "r") as f: 150 | dtype_dict = dict() 151 | shapes_dict = dict() 152 | parsing_description = dict() 153 | for line in f: 154 | key = line.split(", ")[0] 155 | shape_string = line.split("(")[1].split(")")[0] 156 | shapes_dict[key] = tuple(int(s) for s in shape_string.split(",") if s) 157 | dtype_dict[key] = line.split(", ")[-1][:-1] 158 | if dtype_dict[key] == "uint8": 159 | parsing_description[key] = tf.io.FixedLenFeature([], tf.string) 160 | elif dtype_dict[key] in ("float32", "float64"): 161 | parsing_description[key] = tf.io.VarLenFeature(tf.int64) 162 | else: 163 | parsing_description[key] = tf.io.VarLenFeature(dtype_dict[key]) 164 | 165 | def parse_fn(example_proto: str) -> Mapping[str, Any]: 166 | raw = tf.io.parse_single_example(example_proto, parsing_description) 167 | parsed = dict() 168 | for name, dtype in dtype_dict.items(): 169 | value = raw[name] 170 | if dtype == "uint8": 171 | value = tf.image.decode_png(value) 172 | else: 173 | value = tf.sparse.to_dense(value) 174 | if dtype in ("float32", "float64"): 175 | value = tf.bitcast(value, type=dtype) 176 | value = tf.reshape(value, shapes_dict[name]) 177 | if "/" in name: 178 | k1, k2 = name.split("/") 179 | if k1 not in parsed: 180 | parsed[k1] = dict() 181 | parsed[k1][k2] = value 182 | else: 183 | parsed[name] = value 184 | return parsed 185 | 186 | return (file_name,), parse_fn 187 | 188 | 189 | def load_parsed_dataset( 190 | path: str, 191 | tfrecord_prefix: str, 192 | num_shards: int, 193 | shard_index: Optional[int] = None, 194 | keys_to_preserve: Optional[Sequence[str]] = None 195 | ) -> tf.data.Dataset: 196 | """Loads a dataset and shards it based on jax devices.""" 197 | shard_index = shard_index or jax.process_index() 198 | file_names, 
parse_fn = load_filenames_and_parse_fn( 199 | path=path, 200 | tfrecord_prefix=tfrecord_prefix, 201 | ) 202 | 203 | ds = tf.data.TFRecordDataset(file_names) 204 | 205 | threads = max(1, os.cpu_count() - 4) 206 | options = tf.data.Options() 207 | options.threading.private_threadpool_size = threads 208 | options.threading.max_intra_op_parallelism = 1 209 | ds = ds.with_options(options) 210 | 211 | # Shard if we don't shard by files 212 | if num_shards != 1: 213 | ds = ds.shard(num_shards, shard_index) 214 | 215 | # Parse the examples one by one 216 | if keys_to_preserve is not None: 217 | # Optionally also filter them based on the keys provided 218 | def parse_filter(example_proto): 219 | example = parse_fn(example_proto) 220 | return filter_based_on_keys(example, keys_to_preserve=keys_to_preserve) 221 | ds = ds.map(parse_filter, num_parallel_calls=tf.data.experimental.AUTOTUNE) 222 | else: 223 | ds = ds.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) 224 | 225 | return ds 226 | 227 | 228 | def load_dataset( 229 | path: str, 230 | tfrecord_prefix: str, 231 | sub_sample_length: Optional[int], 232 | per_device_batch_size: int, 233 | num_epochs: Optional[int], 234 | drop_remainder: bool, 235 | multi_device: bool = False, 236 | num_shards: int = 1, 237 | shard_index: Optional[int] = None, 238 | keys_to_preserve: Optional[Sequence[str]] = None, 239 | shuffle: bool = False, 240 | cache: bool = True, 241 | shuffle_buffer: Optional[int] = 10000, 242 | dtype: str = "float32", 243 | seed: Optional[int] = None 244 | ) -> tf.data.Dataset: 245 | """Creates a tensorflow.Dataset pipeline from an TFRecord dataset. 246 | 247 | Args: 248 | path: The path to the dataset. 249 | tfrecord_prefix: The dataset prefix. 250 | sub_sample_length: The length of the sequences that will be returned. 251 | If this is `None` the dataset will return full length sequences. 
252 | If this is an `int` it will subsample each sequence uniformly at random 253 | for a sequence of the provided size. Note that all examples in the dataset 254 | must be at least this long, otherwise the tensorflow code might crash. 255 | per_device_batch_size: The batch size to use on a single device. The actual 256 | batch size is this multiplied by the number of devices. 257 | num_epochs: The number of times to repeat the full dataset. 258 | drop_remainder: If the number of examples in the dataset are not divisible 259 | evenly by the batch size, whether each epoch to drop the remaining 260 | examples, or to construct a batch with batch size smaller than usual. 261 | multi_device: Whether to load the dataset prepared for multi-device use 262 | (e.g. pmap) with leading dimension equal to the number of local devices. 263 | num_shards: If you want to shard the dataset, you must specify how many 264 | shards you want to use. 265 | shard_index: The shard index for this host. If `None` will use 266 | `jax.process_index()`. 267 | keys_to_preserve: Explicit specification which keys to keep from the dataset 268 | shuffle: Whether to shuffle examples in the dataset. 269 | cache: Whether to use cache in the tf.Dataset. 270 | shuffle_buffer: Size of the shuffling buffer. 271 | dtype: What data type to convert the data to. 272 | seed: Seed to pass to the loader. 273 | Returns: 274 | A tensorflow dataset object. 
275 | """ 276 | per_host_batch_size = per_device_batch_size * jax.local_device_count() 277 | # Preprocessing function 278 | batch_fn = functools.partial( 279 | preprocess_batch, 280 | num_local_devices=jax.local_device_count(), 281 | multi_device=multi_device, 282 | sub_sample_length=sub_sample_length, 283 | dtype=dtype) 284 | 285 | with tf.name_scope("dataset"): 286 | ds = load_parsed_dataset( 287 | path=path, 288 | tfrecord_prefix=tfrecord_prefix, 289 | num_shards=num_shards, 290 | shard_index=shard_index, 291 | keys_to_preserve=keys_to_preserve, 292 | ) 293 | if cache: 294 | ds = ds.cache() 295 | if shuffle: 296 | ds = ds.shuffle(shuffle_buffer, seed=seed) 297 | ds = ds.repeat(num_epochs) 298 | ds = ds.batch(per_host_batch_size, drop_remainder=drop_remainder) 299 | ds = ds.map(batch_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) 300 | return ds.prefetch(tf.data.experimental.AUTOTUNE) 301 | 302 | 303 | def dataset_as_iter(dataset_func, *args, **kwargs): 304 | def iterable_func(): 305 | yield from tfds.as_numpy(dataset_func(*args, **kwargs)) 306 | return iterable_func 307 | -------------------------------------------------------------------------------- /dm_hamiltonian_dynamics_suite/molecular_dynamics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | -------------------------------------------------------------------------------- /dm_hamiltonian_dynamics_suite/molecular_dynamics/generate_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Creates a small molecular dynamics (MD) dataset. 16 | 17 | This binary creates a small MD dataset given a text-based trajectory file 18 | generated with the simulation package LAMMPS (https://lammps.sandia.gov/). 19 | The trajectory file can be generated by running LAMMPS with the input script 20 | provided. 21 | 22 | We note that this binary is intended as a demonstration only and is therefore 23 | not optimised for memory efficiency or performance. 24 | """ 25 | 26 | from typing import Mapping, Tuple 27 | 28 | from dm_hamiltonian_dynamics_suite import datasets 29 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils 30 | 31 | import numpy as np 32 | 33 | Array = np.ndarray 34 | 35 | 36 | def read_trajectory(filename: str) -> Tuple[Array, float]: 37 | """Reads the trajectory data from file and returns it as an array. 38 | 39 | Each timestep contains a header and the atom data. The header is 9 lines long 40 | and contains the timestep, the number of atoms, the box dimensions and the 41 | list of atom properties. 
The header is assumed to be structured as in the 42 | example below: 43 | 44 | ITEM: TIMESTEP 45 | <> 46 | ITEM: NUMBER OF ATOMS 47 | <> 48 | ITEM: BOX BOUNDS pp pp pp 49 | <> <> 50 | <> <> 51 | <> <> 52 | ITEM: ATOMS id type x y vx vy fx fy 53 | .... <> lines with properties of <> atoms.... 54 | 55 | Args: 56 | filename: name of the input file. 57 | 58 | Returns: 59 | A pair where the first element corresponds to an array of shape 60 | [num_timesteps, num_atoms, 6] containing the atom data and the second 61 | element corresponds to the edge length of the simulation box. 62 | """ 63 | with open(filename, 'r') as f: 64 | dat = f.read() 65 | lines = dat.split('\n') 66 | 67 | # Extract the number of particles and the edge length of the simulation box. 68 | num_particles = int(lines[3]) 69 | box = np.fromstring(lines[5], dtype=np.float32, sep=' ') 70 | box_length = box[1] - box[0] 71 | 72 | # Iterate over all timesteps and extract the relevant data columns. 73 | header_size = 9 74 | record_size = header_size + num_particles 75 | num_records = len(lines) // record_size 76 | records = [] 77 | for i in range(num_records): 78 | record = lines[header_size + i * record_size:(i + 1) * record_size] 79 | record = np.array([l.split(' ')[:-1] for l in record], dtype=np.float32) 80 | records.append(record) 81 | records = np.array(records)[..., 2:] 82 | return records, box_length 83 | 84 | 85 | def flatten_record(x: Array) -> Array: 86 | """Reshapes input from [num_particles, 2*dim] to [2*num_particles*dim].""" 87 | if x.shape[-1] % 2 != 0: 88 | raise ValueError(f'Expected last dimension to be even, got {x.shape[-1]}.') 89 | dim = x.shape[-1] // 2 90 | q = x[..., 0:dim] 91 | p = x[..., dim:] 92 | q_flat = q.reshape(list(q.shape[:-2]) + [-1]) 93 | p_flat = p.reshape(list(p.shape[:-2]) + [-1]) 94 | x_flat = np.concatenate((q_flat, p_flat), axis=-1) 95 | return x_flat 96 | 97 | 98 | def render_images( 99 | x: Array, 100 | box_length: float, 101 | resolution: int = 32, 102 | 
particle_radius: float = 0.3 103 | ) -> Array: 104 | """Renders a sequence with shape [num_steps, num_particles*dim] as images.""" 105 | dim = 2 106 | sequence_length, num_coordinates = x.shape 107 | if num_coordinates % (2 * dim) != 0: 108 | raise ValueError('Expected the number of coordinates to be divisible by 4, ' 109 | f'got {num_coordinates}.') 110 | # The 4 coordinates are positions and velocities in 2d. 111 | num_particles = num_coordinates // (2 * dim) 112 | # `x` is formatted as [x_1, y_1,... x_N, y_N, dx_1, dy_1,..., dx_N, dy_N], 113 | # where `N=num_particles`. For the image generation, we only require x and y 114 | # coordinates. 115 | particles = x[..., :num_particles * dim] 116 | particles = particles.reshape((sequence_length, num_particles, dim)) 117 | colors = np.arange(num_particles, dtype=np.int32) 118 | box_region = utils.BoxRegion(-box_length / 2., box_length / 2.) 119 | images = utils.render_particles_trajectory( 120 | particles=particles, 121 | particles_radius=particle_radius, 122 | color_indices=colors, 123 | canvas_limits=box_region, 124 | resolution=resolution, 125 | num_colors=num_particles) 126 | return images 127 | 128 | 129 | def convert_sequence(sequence: Array, box_length: float) -> Mapping[str, Array]: 130 | """Converts a sequence of timesteps to a data point.""" 131 | num_steps, num_particles, num_fields = sequence.shape 132 | # A LAMMPS record should contain positions, velocities and forces. 133 | if num_fields != 6: 134 | raise ValueError('Expected input sequence to be of shape ' 135 | f'[num_steps, num_particles, 6], got {sequence.shape}.') 136 | x = np.empty((num_steps, num_particles * 4)) 137 | dx_dt = np.empty((num_steps, num_particles * 4)) 138 | for step in range(num_steps): 139 | # Assign positions and momenta to `x` and momenta and forces to `dx_dt`. 
140 | x[step] = flatten_record(sequence[step, :, (0, 1, 2, 3)]) 141 | dx_dt[step] = flatten_record(sequence[step, :, (2, 3, 4, 5)]) 142 | 143 | image = render_images(x, box_length) 144 | image = np.array(image * 255.0, dtype=np.uint8) 145 | return dict(x=x, dx_dt=dx_dt, image=image) 146 | 147 | 148 | def write_to_file( 149 | data: Array, 150 | box_length: float, 151 | output_path: str, 152 | split: str, 153 | overwrite: bool, 154 | ) -> None: 155 | """Writes the data to file.""" 156 | 157 | def generator(): 158 | for sequence in data: 159 | yield convert_sequence(sequence, box_length) 160 | 161 | datasets.transform_dataset(generator(), output_path, split, overwrite) 162 | 163 | 164 | def generate_lammps_dataset( 165 | lammps_file: str, 166 | folder: str, 167 | num_steps: int, 168 | num_train: int, 169 | num_test: int, 170 | dt: int, 171 | shuffle: bool, 172 | seed: int, 173 | overwrite: bool, 174 | ) -> None: 175 | """Creates the train and test datasets.""" 176 | if num_steps < 1: 177 | raise ValueError(f'Expected `num_steps` to be >= 1, got {num_steps}.') 178 | if dt < 1: 179 | raise ValueError(f'Expected `dt` to be >= 1, got {dt}.') 180 | 181 | records, box_length = read_trajectory(lammps_file) 182 | # Consider only every dt-th timestep in the input file. 183 | records = records[::dt] 184 | num_records, num_particles, num_fields = records.shape 185 | if num_records < (num_test + num_train) * num_steps: 186 | raise ValueError( 187 | f'Trajectory contains only {num_records} records which is insufficient' 188 | f'for the requested train/test split of {num_train}/{num_test} with ' 189 | f'sequence length {num_steps}.') 190 | 191 | # Reshape and shuffle the data. 192 | num_points = num_records // num_steps 193 | records = records[:num_points * num_steps] 194 | records = records.reshape((num_points, num_steps, num_particles, num_fields)) 195 | if shuffle: 196 | np.random.RandomState(seed).shuffle(records) 197 | 198 | # Create train/test splits and write them to file. 
199 | train_records = records[:num_train] 200 | test_records = records[num_train:num_train + num_test] 201 | print('Writing the train dataset to file.') 202 | write_to_file(train_records, box_length, folder, 'train', overwrite) 203 | print('Writing the test dataset to file.') 204 | write_to_file(test_records, box_length, folder, 'test', overwrite) 205 | -------------------------------------------------------------------------------- /dm_hamiltonian_dynamics_suite/molecular_dynamics/lj_16.lmp: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | # 17 | # 18 | # Simulate a two-dimensional Lennard-Jones fluid at the thermodynamic state 19 | # T*=0.5 and rho*=0.77. See Smit and Frenkel (1991), J. Chem. Phys. for the 20 | # vapour-liquid phase diagram of that system (http://doi.org/10.1063/1.460477). 21 | # Further information on the individual LAMMPS commands in this script can be 22 | # found on the LAMMPS homepage https://lammps.sandia.gov/doc/Manual.html. 23 | # 24 | # Simulation protocol: 25 | # 1. Equilibrate the system in an NVT run using a Langevin thermostat and 26 | # compute the mean energy. 27 | # 2. Perform a subsequent NVE simulation and measure the mean energy. 28 | # 3. Scale all particle velocities such that the energy agrees with the mean 29 | # value of the NVT run. 
30 | # 4. Perform an NVE production run. The mean temperature sampled during that 31 | # run should now be close to the target value. 32 | # 33 | # For the 4-particle dataset, please make changes to lines 20, 23 and 26, as 34 | # suggested in the comments below. 35 | 36 | # number of particles to simulate (change to 4 for the 4-particle dataset) 37 | variable num_particles equal 16 38 | 39 | # reference temperature (change to 1.0 for the 4-particle dataset) 40 | variable temperature_ref equal 0.5 41 | 42 | # reference density (change to 0.04 for the 4-particle dataset) 43 | variable density_ref equal 0.77 44 | 45 | # damping coefficient for the Langevin thermostat 46 | variable temperature_damp equal 0.2 47 | 48 | # edge length of the square simulation box 49 | variable box_length equal exp(ln(${num_particles}/${density_ref})/2.) 50 | 51 | # Log to file every timestep. 52 | variable log_every equal 1 53 | 54 | # simulation timestep 55 | variable timestep equal 0.002 56 | 57 | # length of equilibration run 58 | variable runtime_equilibration equal 10000 59 | 60 | # length of production run 61 | variable runtime_production equal 1000 62 | 63 | # number of timesteps for both runs 64 | variable steps_equilibration equal floor(${runtime_equilibration}/${timestep}) 65 | variable steps_production equal floor(${runtime_production}/${timestep}) 66 | 67 | # random seed for coordinates, velocities and Langevin thermostat 68 | variable seed equal 1234 69 | 70 | # Initialise coordinates and velocities randomly. 
71 | units lj 72 | atom_style atomic 73 | dimension 2 74 | boundary p p p 75 | variable box_length_half equal ${box_length}/2.0 76 | region domain block -${box_length_half} ${box_length_half} & 77 | -${box_length_half} ${box_length_half} & 78 | -0.1 0.1 units box 79 | create_box 1 domain 80 | create_atoms 1 random ${num_particles} ${seed} domain 81 | set type 1 z 0.0 82 | mass 1 1.0 83 | velocity all create ${temperature_ref} ${seed} 84 | 85 | # Keep the system 2D: z-coordinates are set to 0.0 here, and fix enforce2d zeroes the z-components of forces and velocities at every step. 86 | set type 1 z 0.0 87 | fix f2d all enforce2d 88 | 89 | # Define the interaction potential (pair-style): Lennard-Jones with epsilon = sigma = 1.0, cut off at half the box length. 90 | pair_style lj/cut ${box_length_half} 91 | pair_coeff 1 1 1.0 1.0 92 | neighbor 0.3 bin 93 | neigh_modify delay 0 94 | 95 | # Perform energy minimisation. 96 | minimize 1.0e-6 1.0e-8 100 1000 97 | 98 | # Perform the equilibration run: first define the temperature and energy computes used for monitoring and averaging. 99 | reset_timestep 0 100 | compute temperature all temp 101 | compute kinetic_energy all ke 102 | compute potential_energy all pe 103 | variable total_energy equal c_kinetic_energy+c_potential_energy 104 | 105 | # Compute the centre-of-mass velocity of the entire simulation box. 106 | variable vcmx equal "vcm(all,x)" 107 | variable vcmy equal "vcm(all,y)" 108 | variable vcmz equal "vcm(all,z)" 109 | variable vcm2 equal v_vcmx*v_vcmx+v_vcmy*v_vcmy+v_vcmz*v_vcmz 110 | 111 | # Apply a Langevin thermostat to all particles. The temperature it controls is computed from the x and y velocity components only (temp/partial 1 1 0). 112 | compute langevin_temp all temp/partial 1 1 0 113 | fix flangevin all langevin ${temperature_ref} ${temperature_ref} & 114 | ${temperature_damp} ${seed} zero yes 115 | fix_modify flangevin temp langevin_temp 116 | 117 | # Specify the terminal output and frequency. 118 | thermo_style custom step c_temperature c_langevin_temp c_potential_energy & 119 | c_kinetic_energy v_total_energy press v_vcm2 120 | thermo_modify norm no 121 | 122 | thermo 100000 123 | timestep ${timestep} 124 | 125 | # Integrate with fix nve (together with the Langevin thermostat above this gives NVT dynamics) and time-average temperature and total energy over the equilibration run.
126 | fix fnve all nve 127 | 128 | variable sample_size_equilibration equal ${steps_equilibration}/10 129 | fix fSampleEquilibration all ave/time 10 ${sample_size_equilibration} & 130 | ${steps_equilibration} c_temperature v_total_energy 131 | run ${steps_equilibration} 132 | # Store the time averages (index 1 = temperature, index 2 = total energy). NOTE(review): the printed _NVT labels below do not say which quantity is which. 133 | variable temperature_nvt_equi equal $(f_fSampleEquilibration[1]) 134 | variable energy_nvt_equi equal $(f_fSampleEquilibration[2]) 135 | unfix fSampleEquilibration 136 | print "Averages for NVT equilibration run:" 137 | print "Reference temperature = ${temperature_ref}" 138 | print "_NVT = ${temperature_nvt_equi}" 139 | print "_NVT = ${energy_nvt_equi}" 140 | 141 | # Perform a short NVE simulation (thermostat removed) to estimate the mean energy. 142 | unfix flangevin 143 | reset_timestep 0 144 | fix fNVEPreAverage all ave/time 10 ${sample_size_equilibration} & 145 | ${steps_equilibration} c_temperature v_total_energy 146 | run ${steps_equilibration} 147 | 148 | variable temperature_nve_pre equal $(f_fNVEPreAverage[1]) 149 | variable energy_nve_pre equal $(f_fNVEPreAverage[2]) 150 | unfix fNVEPreAverage 151 | print "Averages for short NVE equilibration run:" 152 | print "INFO: _NVE = ${temperature_nve_pre}" 153 | print "INFO: _NVE = ${energy_nve_pre}" 154 | 155 | # Rescale velocities with fix heat, adding the energy difference dQ between the NVT and 156 | # short-NVE averages over 100 steps, so that the total energy matches the NVT average. This keeps the system 157 | # close to the reference temperature during the production run, which runs without a thermostat. 158 | reset_timestep 0 159 | variable dQ equal ${energy_nvt_equi}-${energy_nve_pre} 160 | variable dQdt1 equal ${dQ}/100/${timestep} 161 | fix fHeat all heat 1 ${dQdt1} region domain 162 | run 100 163 | unfix fHeat 164 | 165 | # Perform a second NVE equilibration run. 166 | reset_timestep 0 167 | run ${steps_equilibration} 168 | 169 | # Perform the production run.
170 | reset_timestep 0 171 | # Time-average temperature and total energy over the production run. 172 | variable sample_size_production equal ${steps_production}/10 173 | fix fSampleProduction all ave/time 10 ${sample_size_production} & 174 | ${steps_production} c_temperature v_total_energy 175 | # Every log_every steps, write per-particle id, type, positions, velocities and forces, plus the temperature and energies. 176 | dump 1 all custom ${log_every} /tmp/trajectory_file.dat id type x y vx vy fx fy 177 | dump_modify 1 sort id 178 | fix 2 all ave/time ${log_every} 1 ${log_every} c_temperature c_kinetic_energy & 179 | c_potential_energy v_total_energy file /tmp/observable_file.dat 180 | run ${steps_production} 181 | 182 | variable temperature_prod equal $(f_fSampleProduction[1]) 183 | variable energy_prod equal $(f_fSampleProduction[2]) 184 | unfix fSampleProduction 185 | # NOTE(review): the printed labels below do not say which value is the temperature (index 1) and which is the total energy (index 2). 186 | print "Averages for NVE production run:" 187 | print "_NVE = = ${temperature_prod}" 188 | print "_NVE = = ${energy_prod}" 189 | -------------------------------------------------------------------------------- /dm_hamiltonian_dynamics_suite/multiagent_dynamics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /dm_hamiltonian_dynamics_suite/multiagent_dynamics/game_dynamics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited.
2 | # 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Continuous-time two player zero-sum game dynamics.""" 16 | from typing import Any, Mapping 17 | 18 | from dm_hamiltonian_dynamics_suite.hamiltonian_systems import utils 19 | import jax.numpy as jnp 20 | import jax.random as jnr 21 | import numpy as np 22 | from numpy.lib.stride_tricks import as_strided 23 | import scipy 24 | from open_spiel.python.egt import dynamics as egt_dynamics 25 | from open_spiel.python.egt import utils as egt_utils 26 | from open_spiel.python.pybind11 import pyspiel 27 | 28 | 29 | def sample_from_simplex( 30 | num_samples: int, 31 | rng: jnp.ndarray, 32 | dim: int = 3, 33 | vmin: float = 0. 34 | ) -> jnp.ndarray: 35 | """Samples random points from a k-simplex. See D. B. Rubin (1981), p131.""" 36 | # This is a jax version of open_spiel.python.egt.utils.sample_from_simplex. 37 | assert vmin >= 0. 
38 | p = jnr.uniform(rng, shape=(num_samples, dim - 1)) 39 | p = jnp.sort(p, axis=1) 40 | p = jnp.hstack((jnp.zeros((num_samples, 1)), p, jnp.ones((num_samples, 1)))) 41 | return (p[:, 1:] - p[:, 0:-1]) * (1 - 2 * vmin) + vmin 42 | 43 | 44 | def get_payoff_tensor(game_name: str) -> jnp.ndarray: 45 | """Returns the payoff tensor of a game.""" 46 | game = pyspiel.extensive_to_tensor_game(pyspiel.load_game(game_name)) 47 | assert game.get_type().utility == pyspiel.GameType.Utility.ZERO_SUM 48 | payoff_tensor = egt_utils.game_payoffs_array(game) 49 | return payoff_tensor 50 | 51 | 52 | def tile_array(a: jnp.ndarray, b0: int, b1: int) -> jnp.ndarray: 53 | r, c = a.shape # number of rows/columns 54 | rs, cs = a.strides # row/column strides 55 | x = as_strided(a, (r, b0, c, b1), (rs, 0, cs, 0)) # view a as larger 4D array 56 | return x.reshape(r * b0, c * b1) # create new 2D array 57 | 58 | 59 | class ZeroSumGame: 60 | """Generate trajectories from zero-sum game dynamics.""" 61 | 62 | def __init__( 63 | self, 64 | game_name: str, 65 | dynamics: str = 'replicator', 66 | method: str = 'scipy'): 67 | self.payoff_tensor = get_payoff_tensor(game_name) 68 | assert self.payoff_tensor.shape[0] == 2, 'Only supports two-player games.' 
69 | dyn_fun = getattr(egt_dynamics, dynamics) 70 | self.dynamics = egt_dynamics.MultiPopulationDynamics( 71 | self.payoff_tensor, dyn_fun) 72 | self.method = method 73 | self.scipy_ivp_kwargs = dict(rtol=1e-12, atol=1e-12) 74 | 75 | def sample_x0(self, num_samples: int, rng_key: jnp.ndarray) -> jnp.ndarray: 76 | """Samples initial states.""" 77 | nrows, ncols = self.payoff_tensor.shape[1:] 78 | key1, key2 = jnr.split(rng_key) 79 | x0_1 = sample_from_simplex(num_samples, key1, dim=nrows) 80 | x0_2 = sample_from_simplex(num_samples, key2, dim=ncols) 81 | x0 = jnp.hstack((x0_1, x0_2)) 82 | return x0 83 | 84 | def generate_trajectories( 85 | self, 86 | x0: jnp.ndarray, 87 | t0: utils.FloatArray, 88 | t_eval: jnp.ndarray 89 | ) -> jnp.ndarray: 90 | """Generates trajectories of the system in phase space. 91 | 92 | Args: 93 | x0: Initial state. 94 | t0: The time instance of the initial state y0. 95 | t_eval: Times at which to return the computed solution. 96 | 97 | Returns: 98 | Trajectories of size BxTxD (batch, time, phase-space-dim). 99 | """ 100 | if self.method == 'scipy': 101 | x0_shape = x0.shape 102 | 103 | def fun(_, y): 104 | y = y.reshape(x0_shape) 105 | y_next = np.apply_along_axis(self.dynamics, -1, y) 106 | return y_next.reshape([-1]) 107 | 108 | t_span = (t0, float(t_eval[-1])) 109 | solution = scipy.integrate.solve_ivp( 110 | fun=fun, 111 | t_span=t_span, 112 | y0=x0.reshape([-1]), # Scipy requires flat input. 113 | t_eval=t_eval, 114 | **self.scipy_ivp_kwargs) 115 | x = solution.y.reshape(x0_shape + (t_eval.size,)) 116 | x = np.moveaxis(x, -1, 1) # Make time 2nd dimension. 
117 | else: 118 | raise ValueError(f'Method={self.method} not supported.') 119 | return x 120 | 121 | def render_trajectories(self, x: jnp.ndarray) -> jnp.ndarray: 122 | """Maps from policies to joint-policy space.""" 123 | nrows, ncols = self.payoff_tensor.shape[1:] 124 | x_1 = x[..., :nrows] 125 | x_2 = x[..., nrows:] 126 | x_1 = x_1.repeat(ncols, axis=-1).reshape(x.shape[:-1] + (nrows, ncols,)) 127 | x_2 = x_2.repeat(nrows, axis=-1).reshape(x.shape[:-1] + (nrows, ncols,)) 128 | x_2 = x_2.swapaxes(-2, -1) 129 | image = x_1 * x_2 130 | 131 | # Rescale to 32 x 32 from the original 2x2 or 3x3 data by expanding the 132 | # matrix to the nearest to 32 multiple of 2 or 3, evenly tiling it with the 133 | # original values, and then taking a 32x32 top left slice of it 134 | temp_image = [ 135 | tile_array(x, np.ceil(32 / x.shape[0]).astype('int'), 136 | np.ceil(32 / x.shape[1]).astype('int'))[:32, :32] 137 | for x in np.squeeze(image) 138 | ] 139 | image = np.stack(temp_image) 140 | image = np.repeat(np.expand_dims(image, -1), 3, axis=-1) 141 | 142 | return image[None, ...] 143 | 144 | def generate_and_render( 145 | self, 146 | num_trajectories: int, 147 | rng_key: jnp.ndarray, 148 | t0: utils.FloatArray, 149 | t_eval: utils.FloatArray 150 | ) -> Mapping[str, Any]: 151 | """Generates trajectories and renders them. 152 | 153 | Args: 154 | num_trajectories: The number of trajectories to generate. 155 | rng_key: PRNG key for sampling any random numbers. 156 | t0: The time instance of the initial state y0. 157 | t_eval: Times at which to return the computed solution. 158 | 159 | Returns: 160 | A dictionary containing the following elements: 161 | 'x': A numpy array representation of the phase space vector. 162 | 'dx_dt': The time derivative of 'x'. 163 | 'image': An image representation of the state. 
164 | """ 165 | rng_key, key = jnr.split(rng_key) 166 | x0 = self.sample_x0(num_trajectories, key) 167 | x = self.generate_trajectories(x0, t0, t_eval) 168 | x = np.concatenate([x0[:, None], x], axis=1) # Add initial state. 169 | dx_dt = np.apply_along_axis(self.dynamics, -1, x) 170 | image = self.render_trajectories(x) 171 | 172 | return dict(x=x, dx_dt=dx_dt, image=image) 173 | 174 | def generate_and_render_dt( 175 | self, 176 | num_trajectories: int, 177 | rng_key: jnp.ndarray, 178 | t0: utils.FloatArray, 179 | dt: utils.FloatArray, 180 | num_steps: int 181 | ) -> Mapping[str, Any]: 182 | """Same as `generate_and_render` but uses `dt` and `num_steps`.""" 183 | t_eval = utils.dt_to_t_eval(t0, dt, num_steps) 184 | return self.generate_and_render(num_trajectories, rng_key, t0, t_eval) 185 | -------------------------------------------------------------------------------- /dm_hamiltonian_dynamics_suite/tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Module for testing the generation and loading of datasets.""" 16 | import os 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | 21 | from dm_hamiltonian_dynamics_suite import datasets 22 | from dm_hamiltonian_dynamics_suite import load_datasets 23 | 24 | import jax 25 | from jax import numpy as jnp 26 | import tensorflow as tf 27 | 28 | 29 | DATASETS_TO_TEST = [ 30 | "mass_spring", 31 | "pendulum_colors", 32 | "double_pendulum_colors", 33 | "two_body_colors", 34 | ] 35 | 36 | if datasets.open_spiel_available(): 37 | DATASETS_TO_TEST += [ 38 | "matching_pennies", 39 | "rock_paper_scissors", 40 | ] 41 | 42 | 43 | class TestToyDataset(parameterized.TestCase): 44 | """Test class for the functions in `tag_graph_matcher.py`.""" 45 | 46 | def compare_structures_all_the_same(self, example, batched_example): 47 | """Compares that the two examples are identical in structure and value.""" 48 | self.assertEqual( 49 | jax.tree_structure(example), 50 | jax.tree_structure(batched_example), 51 | "Structures should be the same." 
52 | ) 53 | # The real example image is not converted however 54 | example["image"] = tf.image.convert_image_dtype( 55 | example["image"], dtype=batched_example["image"].dtype).numpy() 56 | for v1, v2 in zip(jax.tree_leaves(example), 57 | jax.tree_leaves(batched_example)): 58 | self.assertEqual(v1.dtype, v2.dtype, "Dtypes should be the same.") 59 | self.assertEqual((1,) + v1.shape, v2.shape, "Shapes should be the same.") 60 | self.assertTrue(jnp.allclose(v1, v2[0]), "Values should be the same.") 61 | 62 | @parameterized.parameters(DATASETS_TO_TEST) 63 | def test_dataset( 64 | self, 65 | dataset, 66 | folder: str = "/tmp/dm_hamiltonian_dynamics_suite/tests/", 67 | dt: float = 0.1, 68 | num_steps: int = 100, 69 | steps_per_dt: int = 10, 70 | num_train: int = 10, 71 | num_test: int = 10, 72 | ): 73 | """Checks that the dataset generation and loading are working correctly.""" 74 | 75 | # Generate the dataset 76 | train_examples, test_examples = datasets.generate_full_dataset( 77 | folder=folder, 78 | dataset=dataset, 79 | dt=dt, 80 | num_steps=num_steps, 81 | steps_per_dt=steps_per_dt, 82 | num_train=num_train, 83 | num_test=num_test, 84 | overwrite=True, 85 | return_generated_examples=True, 86 | ) 87 | 88 | # Load train dataset 89 | dataset_path = dataset.lower() + "_dt_" + str(dt).replace(".", "_") 90 | ds = load_datasets.dataset_as_iter( 91 | load_datasets.load_dataset, 92 | path=os.path.join(folder, dataset_path), 93 | tfrecord_prefix="train", 94 | sub_sample_length=None, 95 | per_device_batch_size=1, 96 | num_epochs=1, 97 | drop_remainder=False, 98 | dtype="float64" 99 | ) 100 | examples = tuple(x for x in ds()) 101 | self.assertEqual( 102 | len(train_examples), len(examples), 103 | "Number of training examples not the same." 
104 | ) 105 | # Compare individual examples 106 | for example_1, example_2 in zip(train_examples, examples): 107 | self.compare_structures_all_the_same(example_1, example_2) 108 | 109 | # Load test dataset 110 | ds = load_datasets.dataset_as_iter( 111 | load_datasets.load_dataset, 112 | path=os.path.join(folder, dataset_path), 113 | tfrecord_prefix="test", 114 | sub_sample_length=None, 115 | per_device_batch_size=1, 116 | num_epochs=1, 117 | drop_remainder=False, 118 | dtype="float64" 119 | ) 120 | 121 | examples = tuple(x for x in ds()) 122 | self.assertEqual( 123 | len(test_examples), len(examples), 124 | "Number of test examples not the same." 125 | ) 126 | # Compare individual examples 127 | for example_1, example_2 in zip(test_examples, examples): 128 | self.compare_structures_all_the_same(example_1, example_2) 129 | 130 | 131 | if __name__ == "__main__": 132 | jax.config.update("jax_enable_x64", True) 133 | absltest.main() 134 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py>=0.12.0 2 | numpy>=1.16.4 3 | typing>=3.7.4.3 4 | scipy>=1.7.1 5 | open-spiel>=1.0.1 6 | tensorflow>=2.6.0 7 | tensorflow-datasets>=4.4.0 8 | jax==0.2.20 9 | jaxlib==0.1.71 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. 2 | # 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Setup for pip package.""" 16 | from setuptools import setup 17 | 18 | REQUIRED_PACKAGES = ( 19 | "absl-py>=0.12.0", 20 | "numpy>=1.16.4", 21 | "typing>=3.7.4.3", 22 | "scipy>=1.7.1", 23 | "open-spiel>=1.0.1", 24 | "tensorflow>=2.6.0", 25 | "tensorflow-datasets>=4.4.0", 26 | "jax==0.2.20", 27 | "jaxlib==0.1.71" 28 | ) 29 | 30 | LONG_DESCRIPTION = "\n".join([( 31 | "A suite of 17 datasets with phase space, high dimensional (visual) " 32 | "observations and other measurement where appropriate that are based on " 33 | "physical systems, exhibiting a Hamiltonian dynamics" 34 | )]) 35 | 36 | 37 | setup( 38 | name="dm_hamiltonian_dynamics_suite", 39 | version="0.0.1", 40 | description="A collection of 17 datasets based on Hamiltonian physical " 41 | "systems.", 42 | long_description=LONG_DESCRIPTION, 43 | url="https://github.com/deepmind/dm_hamiltonian_dynamics_suite", 44 | author="DeepMind", 45 | packages=[ 46 | "dm_hamiltonian_dynamics_suite", 47 | "dm_hamiltonian_dynamics_suite.hamiltonian_systems", 48 | "dm_hamiltonian_dynamics_suite.molecular_dynamics", 49 | "dm_hamiltonian_dynamics_suite.multiagent_dynamics", 50 | ], 51 | install_requires=REQUIRED_PACKAGES, 52 | platforms=["any"], 53 | license="Apache License, Version 2.0", 54 | ) 55 | -------------------------------------------------------------------------------- /visualize_datasets.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "d68zSuu0TX2_" 
7 | }, 8 | "source": [ 9 | "**Copyright 2020 DeepMind Technologies Limited.**\n", 10 | "\n", 11 | "\n", 12 | "Licensed under the Apache License, Version 2.0 (the \"License\");\n", 13 | "you may not use this file except in compliance with the License.\n", 14 | "You may obtain a copy of the License at\n", 15 | "\n", 16 | "https://www.apache.org/licenses/LICENSE-2.0\n", 17 | "\n", 18 | "Unless required by applicable law or agreed to in writing, software\n", 19 | "distributed under the License is distributed on an \"AS IS\" BASIS,\n", 20 | "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", 21 | "See the License for the specific language governing permissions and\n", 22 | "limitations under the License." 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": { 29 | "cellView": "form", 30 | "id": "d5gKoACsMxPq" 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "#@title Install the package for loading the datasets\n", 35 | "!git clone https://github.com/deepmind/dm_hamiltonian_dynamics_suite.git\n", 36 | "!pip install ./dm_hamiltonian_dynamics_suite/" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": { 43 | "cellView": "form", 44 | "id": "tpOOQRg_mJcU" 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "#@title Imports\n", 49 | "import functools\n", 50 | "import os\n", 51 | "import requests\n", 52 | "from subprocess import getstatusoutput\n", 53 | "from matplotlib import pyplot as plt\n", 54 | "from matplotlib import animation as plt_animation\n", 55 | "from matplotlib import rc\n", 56 | "import numpy as np\n", 57 | "from jax import config as jax_config\n", 58 | "import tensorflow as tf\n", 59 | "\n", 60 | "rc('animation', html='jshtml')\n", 61 | "jax_config.update(\"jax_enable_x64\", True)\n", 62 | "\n", 63 | "from dm_hamiltonian_dynamics_suite import load_datasets\n", 64 | "from dm_hamiltonian_dynamics_suite import datasets" 65 | ] 66 | }, 67 | { 68 | 
"cell_type": "code", 69 | "execution_count": null, 70 | "metadata": { 71 | "cellView": "form", 72 | "id": "tQjT27Ymspqr" 73 | }, 74 | "outputs": [], 75 | "source": [ 76 | "#@title Helper functions\n", 77 | "DATASETS_URL = \"gs://dm-hamiltonian-dynamics-suite\"\n", 78 | "DATASETS_FOLDER = \"./datasets\" #@param {type: \"string\"}\n", 79 | "os.makedirs(DATASETS_FOLDER, exist_ok=True)\n", 80 | "\n", 81 | "def download_file(file_url, destination_file):\n", 82 | " print(\"Downloading\", file_url, \"to\", destination_file)\n", 83 | " command = f\"gsutil cp {file_url} {destination_file}\"\n", 84 | " status_code, output = getstatusoutput(command)\n", 85 | " if status_code != 0:\n", 86 | " raise ValueError(output)\n", 87 | "\n", 88 | "def download_dataset(dataset_name: str):\n", 89 | " \"\"\"Downloads the provided dataset from the DM Hamiltonian Dataset Suite\"\"\"\n", 90 | " destination_folder = os.path.join(DATASETS_FOLDER, dataset_name)\n", 91 | " dataset_url = os.path.join(DATASETS_URL, dataset_name)\n", 92 | " os.makedirs(destination_folder, exist_ok=True)\n", 93 | " if \"long_trajectory\" in dataset_name:\n", 94 | " files = (\"features.txt\", \"test.tfrecord\")\n", 95 | " else:\n", 96 | " files = (\"features.txt\", \"train.tfrecord\", \"test.tfrecord\")\n", 97 | " for file_name in files:\n", 98 | " file_url = os.path.join(dataset_url, file_name)\n", 99 | " destination_file = os.path.join(destination_folder, file_name)\n", 100 | " if os.path.exists(destination_file):\n", 101 | " print(\"File\", file_url, \"already present.\")\n", 102 | " continue\n", 103 | " download_file(file_url, destination_file)\n", 104 | "\n", 105 | "\n", 106 | "def unstack(value: np.ndarray, axis: int = 0):\n", 107 | " \"\"\"Unstacks an array along an axis into a list\"\"\"\n", 108 | " split = np.split(value, value.shape[axis], axis=axis)\n", 109 | " return [np.squeeze(v, axis=axis) for v in split]\n", 110 | "\n", 111 | "\n", 112 | "def make_batch_grid(\n", 113 | " batch: np.ndarray, \n", 114 | " 
grid_height: int,\n", 115 | " grid_width: int, \n", 116 | " with_padding: bool = True):\n", 117 | " \"\"\"Makes a single grid image from a batch of multiple images.\"\"\"\n", 118 | " assert batch.ndim == 5\n", 119 | " assert grid_height * grid_width \u003e= batch.shape[0]\n", 120 | " batch = batch[:grid_height * grid_width]\n", 121 | " batch = batch.reshape((grid_height, grid_width) + batch.shape[1:])\n", 122 | " if with_padding:\n", 123 | " batch = np.pad(batch, pad_width=[[0, 0], [0, 0], [0, 0],\n", 124 | " [1, 0], [1, 0], [0, 0]],\n", 125 | " mode=\"constant\", constant_values=1.0)\n", 126 | " batch = np.concatenate(unstack(batch), axis=-3)\n", 127 | " batch = np.concatenate(unstack(batch), axis=-2)\n", 128 | " if with_padding:\n", 129 | " batch = batch[:, 1:, 1:]\n", 130 | " return batch\n", 131 | "\n", 132 | "\n", 133 | "def plot_animattion_from_batch(\n", 134 | " batch: np.ndarray, \n", 135 | " grid_height, \n", 136 | " grid_width, \n", 137 | " with_padding=True, \n", 138 | " figsize=None):\n", 139 | " \"\"\"Plots an animation of the batch of sequences.\"\"\"\n", 140 | " if figsize is None:\n", 141 | " figsize = (grid_width, grid_height)\n", 142 | " batch = make_batch_grid(batch, grid_height, grid_width, with_padding)\n", 143 | " batch = batch[:, ::-1]\n", 144 | " fig = plt.figure(figsize=figsize)\n", 145 | " plt.close()\n", 146 | " ax = fig.add_subplot(1, 1, 1)\n", 147 | " ax.axis('off')\n", 148 | " img = ax.imshow(batch[0]) \n", 149 | " def frame_update(i):\n", 150 | " i = int(np.floor(i).astype(\"int64\"))\n", 151 | " img.set_data(batch[i])\n", 152 | " return [img]\n", 153 | " anim = plt_animation.FuncAnimation(\n", 154 | " fig=fig, \n", 155 | " func=frame_update,\n", 156 | " frames=np.linspace(0.0, len(batch), len(batch) * 5 + 1)[:-1],\n", 157 | " save_count=len(batch),\n", 158 | " interval=10, \n", 159 | " blit=True\n", 160 | " )\n", 161 | " return anim\n", 162 | "\n", 163 | "\n", 164 | "def plot_sequence_from_batch(\n", 165 | " batch: np.ndarray,\n", 
166 | " t_start: int = 0,\n", 167 | " with_padding: bool = True, \n", 168 | " fontsize: int = 20):\n", 169 | " \"\"\"Plots all of the sequences in the batch.\"\"\"\n", 170 | " n, t, dx, dy = batch.shape[:-1]\n", 171 | " xticks = np.linspace(dx // 2, t * (dx + 1) - 1 - dx // 2, t)\n", 172 | " xtick_labels = np.arange(t) + t_start\n", 173 | " yticks = np.linspace(dy // 2, n * (dy + 1) - 1 - dy // 2, n)\n", 174 | " ytick_labels = np.arange(n)\n", 175 | " batch = batch.reshape((n * t, 1) + batch.shape[2:])\n", 176 | " batch = make_batch_grid(batch, n, t, with_padding)[0]\n", 177 | " plt.imshow(batch.squeeze())\n", 178 | " plt.xticks(ticks=xticks, labels=xtick_labels, fontsize=fontsize)\n", 179 | " plt.yticks(ticks=yticks, labels=ytick_labels, fontsize=fontsize)\n", 180 | "\n", 181 | "\n", 182 | "def visalize_dataset(\n", 183 | " dataset_path: str,\n", 184 | " sequence_lengths: int = 60,\n", 185 | " grid_height: int = 2, \n", 186 | " grid_width: int = 5):\n", 187 | " \"\"\"Visualizes a dataset loaded from the path provided.\"\"\"\n", 188 | " split = \"test\"\n", 189 | " batch_size = grid_height * grid_width\n", 190 | " dataset = load_datasets.load_dataset(\n", 191 | " path=dataset_path,\n", 192 | " tfrecord_prefix=split,\n", 193 | " sub_sample_length=sequence_lengths,\n", 194 | " per_device_batch_size=batch_size,\n", 195 | " num_epochs=None,\n", 196 | " drop_remainder=True,\n", 197 | " shuffle=False,\n", 198 | " shuffle_buffer=100\n", 199 | " )\n", 200 | " sample = next(iter(dataset))\n", 201 | " batch_x = sample['x'].numpy()\n", 202 | " batch_image = sample['image'].numpy()\n", 203 | " # Plot real system dimensions\n", 204 | " plt.figure(figsize=(24, 8))\n", 205 | " for i in range(batch_x.shape[-1]):\n", 206 | " plt.subplot(1, batch_x.shape[-1], i + 1)\n", 207 | " plt.title(f\"Samples from dimension {i+1}\")\n", 208 | " plt.plot(batch_x[:, :, i].T)\n", 209 | " plt.show()\n", 210 | " # Plot a sequence of 50 images\n", 211 | " plt.figure(figsize=(30, 10))\n", 212 | " 
plt.title(\"Samples from 50 steps sub sequences.\")\n", 213 | " plot_sequence_from_batch(batch_image[:, :50])\n", 214 | " plt.show()\n", 215 | " # Plot animation\n", 216 | " return plot_animattion_from_batch(batch_image, grid_height, grid_width)" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": { 223 | "cellView": "form", 224 | "id": "Wr0jMkPmLzdA" 225 | }, 226 | "outputs": [], 227 | "source": [ 228 | "#@title Generate a small dataset and visualize it\n", 229 | "folder_to_store = \"./generated_datasets\" #@param {type:\"string\"}\n", 230 | "dataset = \"pendulum_colors\" #@param {type: \"string\"}\n", 231 | "dt = 0.1 #@param {type: \"number\"}\n", 232 | "num_steps = 100 #@param {type: \"integer\"}\n", 233 | "steps_per_dt = 1 #@param {type: \"integer\"}\n", 234 | "num_train = 100 #@param {type: \"integer\"}\n", 235 | "num_test = 10 #@param {type: \"integer\"}\n", 236 | "overwrite = True #@param {type: \"boolean\"}\n", 237 | "datasets.generate_full_dataset(\n", 238 | " folder=folder_to_store,\n", 239 | " dataset=dataset,\n", 240 | " dt=dt,\n", 241 | " num_steps=num_steps,\n", 242 | " steps_per_dt=steps_per_dt,\n", 243 | " num_train=num_train,\n", 244 | " num_test=num_test,\n", 245 | " overwrite=overwrite,\n", 246 | ")\n", 247 | "dataset_full_name = dataset + \"_dt_\" + str(dt).replace(\".\", \"_\")\n", 248 | "dataset_path = os.path.join(folder_to_store, dataset_full_name)\n", 249 | "visalize_dataset(dataset_path)" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "metadata": { 256 | "cellView": "form", 257 | "id": "d6Fn8eaBentd" 258 | }, 259 | "outputs": [], 260 | "source": [ 261 | "#@title Download and visualise a dataset\n", 262 | "dataset_name = \"toy_physics/mass_spring\" #@param [\"toy_physics/mass_spring\", \"toy_physics/mass_spring_colors\", \"toy_physics/mass_spring_colors_friction\", \"toy_physics/mass_spring_long_trajectory\", 
\"toy_physics/mass_spring_colors_long_trajectory\", \"toy_physics/pendulum\", \"toy_physics/pendulum_colors\", \"toy_physics/pendulum_colors_friction\", \"toy_physics/pendulum_long_trajectory\", \"toy_physics/pendulum_colors_long_trajectory\", \"toy_physics/double_pendulum\", \"toy_physics/double_pendulum_colors\", \"toy_physics/double_pendulum_colors_friction\", \"toy_physics/two_body\", \"toy_physics/two_body_colors\", \"molecular_dynamics/lj_4\", \"molecular_dynamics/lj_16\", \"multi_agent/matching_pennies\", \"multi_agent/matching_pennies_long_trajectory\", \"multi_agent/rock_paper_scissors\", \"multi_agent/rock_paper_scissors_long_trajectory\", \"mujoco_room/circle\", \"mujoco_room/spiral\"]\n", 263 | "download_dataset(dataset_name)\n", 264 | "visalize_dataset(os.path.join(DATASETS_FOLDER, dataset_name))" 265 | ] 266 | } 267 | ], 268 | "metadata": { 269 | "colab": { 270 | "collapsed_sections": [], 271 | "last_runtime": {}, 272 | "name": "visualize_datasets.ipynb", 273 | "provenance": [] 274 | }, 275 | "kernelspec": { 276 | "display_name": "Python 3", 277 | "name": "python3" 278 | } 279 | }, 280 | "nbformat": 4, 281 | "nbformat_minor": 0 282 | } 283 | --------------------------------------------------------------------------------