├── .codecov.yml ├── .github └── workflows │ ├── publish.yml │ ├── ruff.yml │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── configs ├── WaterDrop_2d │ └── gns.yaml ├── dam_2d │ ├── base.yaml │ ├── gns.yaml │ └── segnn.yaml ├── ldc_2d │ ├── base.yaml │ ├── gns.yaml │ └── segnn.yaml ├── ldc_3d │ ├── base.yaml │ ├── gns.yaml │ └── segnn.yaml ├── rpf_2d │ ├── base.yaml │ ├── egnn.yaml │ ├── gns.yaml │ ├── painn.yaml │ └── segnn.yaml ├── rpf_3d │ ├── base.yaml │ ├── egnn.yaml │ ├── gns.yaml │ ├── painn.yaml │ └── segnn.yaml ├── tgv_2d │ ├── base.yaml │ ├── gns.yaml │ └── segnn.yaml └── tgv_3d │ ├── base.yaml │ ├── gns.yaml │ └── segnn.yaml ├── data_gen ├── gns_data │ ├── README.md │ ├── __init__.py │ ├── download_dataset.sh │ ├── reading_utils.py │ └── tfrecord_to_h5.py └── lagrangebench_data │ ├── README.md │ ├── dataset_db.sh │ ├── dataset_ldc.sh │ ├── dataset_rpf.sh │ ├── dataset_tgv.sh │ ├── gen_dataset.py │ └── plot_frame.py ├── docs ├── Makefile ├── conf.py ├── index.rst ├── lagrangebench_logo.svg ├── make.bat ├── pages │ ├── baselines.rst │ ├── case_setup.rst │ ├── data.rst │ ├── defaults.rst │ ├── evaluate.rst │ ├── models.rst │ ├── train.rst │ ├── tutorial.rst │ └── utils.rst └── requirements.txt ├── download_data.sh ├── lagrangebench ├── __init__.py ├── case_setup │ ├── __init__.py │ ├── case.py │ └── features.py ├── data │ ├── __init__.py │ ├── data.py │ └── utils.py ├── defaults.py ├── evaluate │ ├── __init__.py │ ├── metrics.py │ ├── rollout.py │ └── utils.py ├── models │ ├── __init__.py │ ├── base.py │ ├── egnn.py │ ├── gns.py │ ├── linear.py │ ├── painn.py │ ├── segnn.py │ └── utils.py ├── runner.py ├── train │ ├── __init__.py │ ├── strats.py │ └── trainer.py └── utils.py ├── main.py ├── notebooks ├── data_gen.ipynb ├── datasets.ipynb ├── gns_data.ipynb ├── media │ └── scatter.gif └── tutorial.ipynb ├── poetry.lock ├── pyproject.toml ├── requirements_cuda.txt └── tests ├── 3D_LJ_3_1214every1 ├── metadata.json ├── test.h5 ├── train.h5 └── valid.h5 ├── case_test.py ├── models_test.py ├── pushforward_test.py ├── rollout_test.py └── runner_test.py /.codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | range: 50..70 # red color under 50%, yellow at 50%..70%, green over 70% 3 | precision: 1 4 | status: 5 | project: 6 | default: 7 | target: 60% # coverage success only above X% 8 | threshold: 5% # allow the coverage to drop by X% and being a success 9 | patch: 10 | default: 11 | target: 50% 12 | threshold: 5% -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | build-n-publish: 9 | name: Build and publish to PyPI 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Set up Python 15 | uses: actions/setup-python@v4 16 | with: 17 | python-version: '3.10' 18 | - name: Install Poetry 19 | run: | 20 | python -m pip install --upgrade pip 21 | pip install poetry 22 | poetry config virtualenvs.in-project true 23 | - name: Install dependencies 24 | run: | 25 | poetry install 26 | - name: Build and publish 27 | env: 28 | POETRY_PYPI_TOKEN_PYPI: ${{ secrets.POETRY_PYPI_TOKEN_PYPI }} 29 | run: | 30 | poetry version $(git describe --tags --abbrev=0) 31 | poetry build 32 | poetry publish 33 | 
-------------------------------------------------------------------------------- /.github/workflows/ruff.yml: -------------------------------------------------------------------------------- 1 | name: Ruff 2 | on: [pull_request] 3 | jobs: 4 | ruff: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v4 8 | - name: Install Ruff from pyproject.toml 9 | uses: astral-sh/ruff-action@v3 10 | with: 11 | version-file: "pyproject.toml" -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | tests: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: ["3.9", "3.10", "3.11"] 19 | 20 | steps: 21 | - uses: actions/checkout@v4 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v4 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install Poetry 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install poetry 30 | poetry config virtualenvs.in-project true 31 | - name: Install dependencies 32 | run: | 33 | poetry install 34 | - name: Run pytest and generate coverage report 35 | run: | 36 | .venv/bin/pytest --cov-report=xml 37 | - name: Upload coverage report to Codecov 38 | uses: codecov/codecov-action@v3 39 | with: 40 | token: ${{ secrets.CODECOV_TOKEN }} 41 | file: ./coverage.xml 42 | flags: unittests 43 | verbose: true 44 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # experiments 2 | ckp/ 3 | rollout/ 4 | rollouts/ 5 | wandb/ 6 | *.out 7 | datasets 8 | baselines 9 | partition_* 10 | *.pkl 11 | datasets/ 12 | # dev 13 | .vscode 14 | __pycache__ 15 | *.pyc 16 | venv*/ 17 | *.egg-info 18 | .ruff-cache 19 | rollouts 20 | profile 21 | dist 22 | .coverage 23 | 24 | # Sphinx documentation 25 | docs/_build/ 26 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | exclude: | 4 | (?x)^( 5 | venv/| 6 | ckp/| 7 | rollout/| 8 | docs/| 9 | )$ 10 | repos: 11 | - repo: https://github.com/pre-commit/pre-commit-hooks 12 | rev: v4.4.0 13 | hooks: 14 | - id: check-merge-conflict 15 | - id: check-added-large-files 16 | - id: check-docstring-first 17 | - id: check-json 18 | - id: check-toml 19 | - id: check-yaml 20 | - id: requirements-txt-fixer 21 | - repo: https://github.com/astral-sh/ruff-pre-commit 22 | rev: 'v0.2.2' 23 | hooks: 24 | - id: ruff 25 | args: [ --fix ] 26 | - id: ruff-format 27 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | 3 | build: 4 | os: "ubuntu-22.04" 5 | tools: 6 | python: "3.9" 7 | 8 | sphinx: 9 | configuration: docs/conf.py 10 | 11 | python: 12 | install: 13 | - requirements: docs/requirements.txt 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 
2023 Chair of Aerodynamics and Fluid Mechanics @ TUM 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/WaterDrop_2d/gns.yaml: -------------------------------------------------------------------------------- 1 | extends: LAGRANGEBENCH_DEFAULTS 2 | 3 | dataset: 4 | src: /tmp/datasets/WaterDrop 5 | 6 | model: 7 | name: gns 8 | num_mp_steps: 10 9 | latent_dim: 128 10 | 11 | train: 12 | optimizer: 13 | lr_start: 5.e-4 14 | 15 | logging: 16 | wandb_project: waterdrop_2d 17 | 18 | neighbors: 19 | backend: matscipy 20 | -------------------------------------------------------------------------------- /configs/dam_2d/base.yaml: -------------------------------------------------------------------------------- 1 | extends: LAGRANGEBENCH_DEFAULTS 2 | 3 | dataset: 4 | src: datasets/2D_DAM_5740_20kevery100 5 | 6 | logging: 7 | wandb_project: dam_2d 8 | 9 | neighbors: 10 | multiplier: 2.0 11 | -------------------------------------------------------------------------------- /configs/dam_2d/gns.yaml: -------------------------------------------------------------------------------- 1 | extends: configs/dam_2d/base.yaml 2 | 3 | model: 4 | name: gns 5 | num_mp_steps: 10 6 | latent_dim: 128 7 | 8 | train: 9 | noise_std: 0.001 10 | optimizer: 11 | lr_start: 5.e-4 12 | -------------------------------------------------------------------------------- /configs/dam_2d/segnn.yaml: -------------------------------------------------------------------------------- 1 | extends: configs/dam_2d/base.yaml 2 | 3 | model: 4 | name: segnn 5 | num_mp_steps: 10 6 | latent_dim: 64 7 | isotropic_norm: True 8 | 9 | train: 10 | noise_std: 0.001 11 | optimizer: 12 | lr_start: 5.e-4 13 | -------------------------------------------------------------------------------- /configs/ldc_2d/base.yaml: -------------------------------------------------------------------------------- 1 | extends: LAGRANGEBENCH_DEFAULTS 2 | 3 | dataset: 4 | src: datasets/2D_LDC_2708_10kevery100 5 | 6 | logging: 7 | wandb_project: ldc_2d 8 | 9 | neighbors: 10 | multiplier: 2.0 -------------------------------------------------------------------------------- /configs/ldc_2d/gns.yaml: -------------------------------------------------------------------------------- 1 | extends: configs/ldc_2d/base.yaml 2 | 3 | model: 4 | name: gns 5 | num_mp_steps: 10 6 | latent_dim: 128 7 | 8 | train: 9 | noise_std: 0.001 10 | optimizer: 11 | lr_start: 5.e-4 
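# Editorial sketch (not a file from the repository): configs compose through the
# `extends` key. This file extends configs/ldc_2d/base.yaml, which itself extends
# LAGRANGEBENCH_DEFAULTS (presumably resolved from lagrangebench/defaults.py), so only
# the keys written here override the package defaults. A hypothetical user config could
# chain onto any of these files in the same way, e.g.:
#
#   extends: configs/ldc_2d/gns.yaml
#   model:
#     latent_dim: 64      # hypothetical override of the value defined above
#   train:
#     noise_std: 0.0005   # hypothetical value; key name as used in configs/dam_2d/gns.yaml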
12 | -------------------------------------------------------------------------------- /configs/ldc_2d/segnn.yaml: -------------------------------------------------------------------------------- 1 | extends: configs/ldc_2d/base.yaml 2 | 3 | model: 4 | name: segnn 5 | num_mp_steps: 10 6 | latent_dim: 64 7 | isotropic_norm: True 8 | 9 | train: 10 | noise_std: 0.001 11 | optimizer: 12 | lr_start: 5.e-4 13 | -------------------------------------------------------------------------------- /configs/ldc_3d/base.yaml: -------------------------------------------------------------------------------- 1 | extends: LAGRANGEBENCH_DEFAULTS 2 | 3 | dataset: 4 | src: datasets/3D_LDC_8160_10kevery100 5 | 6 | logging: 7 | wandb_project: ldc_3d 8 | 9 | neighbors: 10 | multiplier: 2.0 -------------------------------------------------------------------------------- /configs/ldc_3d/gns.yaml: -------------------------------------------------------------------------------- 1 | extends: configs/ldc_3d/base.yaml 2 | 3 | model: 4 | name: gns 5 | num_mp_steps: 10 6 | latent_dim: 128 7 | 8 | train: 9 | optimizer: 10 | lr_start: 5.e-4 11 | -------------------------------------------------------------------------------- /configs/ldc_3d/segnn.yaml: -------------------------------------------------------------------------------- 1 | extends: configs/ldc_3d/base.yaml 2 | 3 | model: 4 | name: segnn 5 | num_mp_steps: 10 6 | latent_dim: 64 7 | isotropic_norm: True 8 | 9 | train: 10 | optimizer: 11 | lr_start: 5.e-4 12 | -------------------------------------------------------------------------------- /configs/rpf_2d/base.yaml: -------------------------------------------------------------------------------- 1 | extends: LAGRANGEBENCH_DEFAULTS 2 | 3 | dataset: 4 | src: datasets/2D_RPF_3200_20kevery100 5 | 6 | logging: 7 | wandb_project: rpf_2d -------------------------------------------------------------------------------- /configs/rpf_2d/egnn.yaml: -------------------------------------------------------------------------------- 1 | extends: configs/rpf_2d/base.yaml 2 | 3 | model: 4 | name: egnn 5 | num_mp_steps: 5 6 | latent_dim: 128 7 | isotropic_norm: True 8 | magnitude_features: True 9 | 10 | train: 11 | optimizer: 12 | lr_start: 5.e-4 13 | loss_weight: 14 | pos: 1.0 15 | vel: 0.0 16 | acc: 0.0 17 | -------------------------------------------------------------------------------- /configs/rpf_2d/gns.yaml: -------------------------------------------------------------------------------- 1 | extends: configs/rpf_2d/base.yaml 2 | 3 | model: 4 | name: gns 5 | num_mp_steps: 10 6 | latent_dim: 128 7 | 8 | train: 9 | optimizer: 10 | lr_start: 5.e-4 11 | -------------------------------------------------------------------------------- /configs/rpf_2d/painn.yaml: -------------------------------------------------------------------------------- 1 | extends: configs/rpf_2d/base.yaml 2 | 3 | model: 4 | name: painn 5 | num_mp_steps: 5 6 | latent_dim: 128 7 | isotropic_norm: True 8 | magnitude_features: True 9 | 10 | train: 11 | optimizer: 12 | lr_start: 1.e-4 13 | -------------------------------------------------------------------------------- /configs/rpf_2d/segnn.yaml: -------------------------------------------------------------------------------- 1 | extends: configs/rpf_2d/base.yaml 2 | 3 | model: 4 | name: segnn 5 | num_mp_steps: 10 6 | latent_dim: 64 7 | isotropic_norm: True 8 | 9 | train: 10 | optimizer: 11 | lr_start: 1.e-3 12 | -------------------------------------------------------------------------------- 
/configs/rpf_3d/base.yaml: -------------------------------------------------------------------------------- 1 | extends: LAGRANGEBENCH_DEFAULTS 2 | 3 | dataset: 4 | src: datasets/3D_RPF_8000_10kevery100 5 | 6 | logging: 7 | wandb_project: rpf_3d -------------------------------------------------------------------------------- /configs/rpf_3d/egnn.yaml: -------------------------------------------------------------------------------- 1 | extends: configs/rpf_3d/base.yaml 2 | 3 | model: 4 | name: egnn 5 | num_mp_steps: 5 6 | latent_dim: 128 7 | isotropic_norm: True 8 | magnitude_features: True 9 | 10 | train: 11 | optimizer: 12 | lr_start: 1.e-4 13 | loss_weight: 14 | pos: 1.0 15 | vel: 0.0 16 | acc: 0.0 17 | -------------------------------------------------------------------------------- /configs/rpf_3d/gns.yaml: -------------------------------------------------------------------------------- 1 | extends: configs/rpf_3d/base.yaml 2 | 3 | model: 4 | name: gns 5 | num_mp_steps: 10 6 | latent_dim: 128 7 | 8 | train: 9 | optimizer: 10 | lr_start: 5.e-4 11 | -------------------------------------------------------------------------------- /configs/rpf_3d/painn.yaml: -------------------------------------------------------------------------------- 1 | extends: configs/rpf_3d/base.yaml 2 | 3 | model: 4 | name: painn 5 | num_mp_steps: 5 6 | latent_dim: 128 7 | isotropic_norm: True 8 | magnitude_features: True 9 | 10 | train: 11 | optimizer: 12 | lr_start: 5.e-4 13 | -------------------------------------------------------------------------------- /configs/rpf_3d/segnn.yaml: -------------------------------------------------------------------------------- 1 | extends: configs/rpf_3d/base.yaml 2 | 3 | model: 4 | name: segnn 5 | num_mp_steps: 10 6 | latent_dim: 64 7 | isotropic_norm: True 8 | 9 | train: 10 | optimizer: 11 | lr_start: 1.e-3 12 | -------------------------------------------------------------------------------- /configs/tgv_2d/base.yaml: -------------------------------------------------------------------------------- 1 | extends: LAGRANGEBENCH_DEFAULTS 2 | 3 | dataset: 4 | src: datasets/2D_TGV_2500_10kevery100 5 | 6 | logging: 7 | wandb_project: tgv_2d 8 | -------------------------------------------------------------------------------- /configs/tgv_2d/gns.yaml: -------------------------------------------------------------------------------- 1 | extends: configs/tgv_2d/base.yaml 2 | 3 | model: 4 | name: gns 5 | num_mp_steps: 10 6 | latent_dim: 128 7 | 8 | train: 9 | optimizer: 10 | lr_start: 5.e-4 11 | -------------------------------------------------------------------------------- /configs/tgv_2d/segnn.yaml: -------------------------------------------------------------------------------- 1 | extends: configs/tgv_2d/base.yaml 2 | 3 | model: 4 | name: segnn 5 | num_mp_steps: 10 6 | latent_dim: 64 7 | isotropic_norm: True 8 | 9 | train: 10 | optimizer: 11 | lr_start: 5.e-4 12 | -------------------------------------------------------------------------------- /configs/tgv_3d/base.yaml: -------------------------------------------------------------------------------- 1 | extends: LAGRANGEBENCH_DEFAULTS 2 | 3 | dataset: 4 | src: datasets/3D_TGV_8000_10kevery100 5 | 6 | logging: 7 | wandb_project: tgv_3d 8 | -------------------------------------------------------------------------------- /configs/tgv_3d/gns.yaml: -------------------------------------------------------------------------------- 1 | extends: configs/tgv_3d/base.yaml 2 | 3 | model: 4 | name: gns 5 | num_mp_steps: 10 6 | latent_dim: 128 7 | 
8 | train: 9 | optimizer: 10 | lr_start: 5.e-4 11 | -------------------------------------------------------------------------------- /configs/tgv_3d/segnn.yaml: -------------------------------------------------------------------------------- 1 | extends: configs/tgv_3d/base.yaml 2 | 3 | model: 4 | name: segnn 5 | num_mp_steps: 10 6 | latent_dim: 64 7 | isotropic_norm: True 8 | 9 | train: 10 | optimizer: 11 | lr_start: 5.e-4 12 | -------------------------------------------------------------------------------- /data_gen/gns_data/README.md: -------------------------------------------------------------------------------- 1 | # Demonstration on how to train the GNS model on one of its original 2D datasets 2 | 3 | > Check out the full notebook under [`notebooks/gns_data.ipynb`](../notebooks/gns_data.ipynb). 4 | 5 | ## Download data 6 | 7 | ```bash 8 | mkdir -p /tmp/datasets 9 | bash data_gen/gns_data/download_dataset.sh WaterDrop /tmp/datasets 10 | ``` 11 | 12 | ## Transform data from .tfrecord to .h5 13 | 14 | First, you need the `tensorflow` and `tensorflow-datasets` libraries. We recommend installing these in a separate virtual environment to avoid CUDA version conflicts. 15 | 16 | ```bash 17 | python3 -m venv venv_tf 18 | venv_tf/bin/pip install tensorflow tensorflow-datasets 19 | ``` 20 | 21 | Then, transform the data via 22 | 23 | ```bash 24 | ./venv_tf/bin/python data_gen/gns_data/tfrecord_to_h5.py --dataset-path=/tmp/datasets/WaterDrop 25 | ``` 26 | 27 | and train the usual way. 28 | -------------------------------------------------------------------------------- /data_gen/gns_data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tumaer/lagrangebench/b880a6c84a93792d2499d2a9b8ba3a077ddf44e2/data_gen/gns_data/__init__.py -------------------------------------------------------------------------------- /data_gen/gns_data/download_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2020 Deepmind Technologies Limited. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | # Usage: 17 | # bash download_dataset.sh ${DATASET_NAME} ${OUTPUT_DIR} 18 | # Example: 19 | # bash download_dataset.sh WaterDrop /tmp/ 20 | 21 | # Source: https://github.com/deepmind/deepmind-research/tree/master/learning_to_simulate 22 | 23 | set -e 24 | 25 | DATASET_NAME="${1}" 26 | OUTPUT_DIR="${2}/${DATASET_NAME}" 27 | 28 | BASE_URL="https://storage.googleapis.com/learning-to-simulate-complex-physics/Datasets/${DATASET_NAME}/" 29 | 30 | mkdir -p ${OUTPUT_DIR} 31 | for file in metadata.json train.tfrecord valid.tfrecord test.tfrecord 32 | do 33 | wget -O "${OUTPUT_DIR}/${file}" "${BASE_URL}${file}" 34 | done 35 | -------------------------------------------------------------------------------- /data_gen/gns_data/reading_utils.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Utilities for reading open sourced Learning Complex Physics data.""" 17 | 18 | # Source: https://github.com/deepmind/deepmind-research/tree/master/learning_to_simulate 19 | 20 | 21 | import functools 22 | 23 | import numpy as np 24 | import tensorflow.compat.v1 as tf 25 | 26 | # Create a description of the features. 27 | _FEATURE_DESCRIPTION = { 28 | "position": tf.io.VarLenFeature(tf.string), 29 | } 30 | 31 | _FEATURE_DESCRIPTION_WITH_GLOBAL_CONTEXT = _FEATURE_DESCRIPTION.copy() 32 | _FEATURE_DESCRIPTION_WITH_GLOBAL_CONTEXT["step_context"] = tf.io.VarLenFeature( 33 | tf.string 34 | ) 35 | 36 | _FEATURE_DTYPES = { 37 | "position": {"in": np.float32, "out": tf.float32}, 38 | "step_context": {"in": np.float32, "out": tf.float32}, 39 | } 40 | 41 | _CONTEXT_FEATURES = { 42 | "key": tf.io.FixedLenFeature([], tf.int64, default_value=0), 43 | "particle_type": tf.io.VarLenFeature(tf.string), 44 | } 45 | 46 | 47 | def convert_to_tensor(x, encoded_dtype): 48 | if len(x) == 1: 49 | out = np.frombuffer(x[0].numpy(), dtype=encoded_dtype) 50 | else: 51 | out = [] 52 | for el in x: 53 | out.append(np.frombuffer(el.numpy(), dtype=encoded_dtype)) 54 | out = tf.convert_to_tensor(np.array(out)) 55 | return out 56 | 57 | 58 | def parse_serialized_simulation_example(example_proto, metadata): 59 | """Parses a serialized simulation tf.SequenceExample. 60 | 61 | Args: 62 | example_proto: A string encoding of the tf.SequenceExample proto. 63 | metadata: A dict of metadata for the dataset. 64 | 65 | Returns: 66 | context: A dict, with features that do not vary over the trajectory. 67 | parsed_features: A dict of tf.Tensors representing the parsed examples 68 | across time, where axis zero is the time axis. 
69 | 70 | """ 71 | if "context_mean" in metadata: 72 | feature_description = _FEATURE_DESCRIPTION_WITH_GLOBAL_CONTEXT 73 | else: 74 | feature_description = _FEATURE_DESCRIPTION 75 | context, parsed_features = tf.io.parse_single_sequence_example( 76 | example_proto, 77 | context_features=_CONTEXT_FEATURES, 78 | sequence_features=feature_description, 79 | ) 80 | for feature_key, item in parsed_features.items(): 81 | convert_fn = functools.partial( 82 | convert_to_tensor, encoded_dtype=_FEATURE_DTYPES[feature_key]["in"] 83 | ) 84 | parsed_features[feature_key] = tf.py_function( 85 | convert_fn, inp=[item.values], Tout=_FEATURE_DTYPES[feature_key]["out"] 86 | ) 87 | 88 | # There is an extra frame at the beginning so we can calculate pos change 89 | # for all frames used in the paper. 90 | position_shape = [metadata["sequence_length"] + 1, -1, metadata["dim"]] 91 | 92 | # Reshape positions to correct dim: 93 | parsed_features["position"] = tf.reshape( 94 | parsed_features["position"], position_shape 95 | ) 96 | # Set correct shapes of the remaining tensors. 97 | sequence_length = metadata["sequence_length"] + 1 98 | if "context_mean" in metadata: 99 | context_feat_len = len(metadata["context_mean"]) 100 | parsed_features["step_context"] = tf.reshape( 101 | parsed_features["step_context"], [sequence_length, context_feat_len] 102 | ) 103 | # Decode particle type explicitly 104 | context["particle_type"] = tf.py_function( 105 | functools.partial(convert_fn, encoded_dtype=np.int64), 106 | inp=[context["particle_type"].values], 107 | Tout=[tf.int64], 108 | ) 109 | context["particle_type"] = tf.reshape(context["particle_type"], [-1]) 110 | return context, parsed_features 111 | 112 | 113 | def split_trajectory(context, features, window_length=7): 114 | """Splits trajectory into sliding windows.""" 115 | # Our strategy is to make sure all the leading dimensions are the same size, 116 | # then we can use from_tensor_slices. 117 | 118 | trajectory_length = features["position"].get_shape().as_list()[0] 119 | 120 | # We then stack window_length position changes so the final 121 | # trajectory length will be - window_length +1 (the 1 to make sure we get 122 | # the last split). 123 | input_trajectory_length = trajectory_length - window_length + 1 124 | 125 | model_input_features = {} 126 | # Prepare the context features per step. 
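# particle_type is constant over a trajectory, so it is broadcast to every sliding
# window: tf.expand_dims gives shape [1, num_particles] and tf.tile repeats it to
# [input_trajectory_length, num_particles], matching the leading dimension of the
# stacked position windows built below.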
127 | model_input_features["particle_type"] = tf.tile( 128 | tf.expand_dims(context["particle_type"], axis=0), [input_trajectory_length, 1] 129 | ) 130 | 131 | if "step_context" in features: 132 | global_stack = [] 133 | for idx in range(input_trajectory_length): 134 | global_stack.append(features["step_context"][idx : idx + window_length]) 135 | model_input_features["step_context"] = tf.stack(global_stack) 136 | 137 | pos_stack = [] 138 | for idx in range(input_trajectory_length): 139 | pos_stack.append(features["position"][idx : idx + window_length]) 140 | # Get the corresponding positions 141 | model_input_features["position"] = tf.stack(pos_stack) 142 | 143 | return tf.data.Dataset.from_tensor_slices(model_input_features) 144 | -------------------------------------------------------------------------------- /data_gen/gns_data/tfrecord_to_h5.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import functools 3 | import json 4 | import os 5 | 6 | import h5py 7 | import numpy as np 8 | import reading_utils 9 | import tensorflow.compat.v1 as tf 10 | import tensorflow_datasets as tfds 11 | 12 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 13 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 14 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 15 | 16 | 17 | def convert_tfrecord_to_h5(args): 18 | """Read .tfrecord file and convert it to its closest .h5 equivalent""" 19 | 20 | file_path = os.path.join(args.dataset_path, args.file_name) 21 | print(f"Start conversion of {file_path} to .h5") 22 | 23 | with open(os.path.join(args.dataset_path, "metadata.json"), "r") as fp: 24 | metadata = json.loads(fp.read()) 25 | 26 | # get the TFRecordDataset with its proper preprocessing 27 | ds = tf.data.TFRecordDataset([file_path]) 28 | ds = ds.map( 29 | functools.partial( 30 | reading_utils.parse_serialized_simulation_example, metadata=metadata 31 | ) 32 | ) 33 | ds = tfds.as_numpy(ds) 34 | 35 | h5_file_path = f"{file_path[:-9]}.h5" 36 | hf = h5py.File(h5_file_path, "w") 37 | 38 | for i, elem in enumerate(ds): 39 | traj_str = str(i).zfill(5) 40 | 41 | particle_type = elem[0]["particle_type"] 42 | key = elem[0]["key"] 43 | position = elem[1]["position"] 44 | 45 | hf.create_dataset(f"{traj_str}/particle_type", data=particle_type) 46 | assert key == i, "Something went wrong here" 47 | hf.create_dataset( 48 | f"{traj_str}/position", 49 | data=position, 50 | dtype=np.float32, 51 | compression="gzip", 52 | ) 53 | 54 | hf.close() 55 | print(f"Finish conversion to {h5_file_path}") 56 | 57 | 58 | def main(args): 59 | files = os.listdir(args.dataset_path) 60 | files = [f for f in files if f.endswith(".tfrecord")] 61 | for file_name in files: 62 | args.file_name = file_name 63 | convert_tfrecord_to_h5(args) 64 | 65 | # add the maximum number of particles to the metadata 66 | # Crucial for the matscipy neighbors search 67 | 68 | # first find the maximum number of particles 69 | files = os.listdir(args.dataset_path) 70 | files = [f for f in files if f.endswith(".h5")] 71 | max_particles = 0 72 | for file_name in files: 73 | h5_file_path = os.path.join(args.dataset_path, file_name) 74 | hf = h5py.File(h5_file_path, "r") 75 | for k, v in hf.items(): 76 | max_particles = max(v["particle_type"].shape[0], max_particles) 77 | print(f"Max number of particles in {file_name}: {max_particles}") 78 | hf.close() 79 | 80 | metadata_path = os.path.join(args.dataset_path, "metadata.json") 81 | with open(metadata_path, "r") as fp: 82 | metadata = json.loads(fp.read()) 83 | 84 | 
metadata["num_particles_max"] = max_particles 85 | # all DeepMind datasets are non-periodic 86 | metadata["periodic_boundary_conditions"] = [False, False, False] 87 | 88 | with open(metadata_path, "w") as f: 89 | json.dump(metadata, f) 90 | 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("--dataset-path", type=str) 95 | args = parser.parse_args() 96 | 97 | main(args) 98 | -------------------------------------------------------------------------------- /data_gen/lagrangebench_data/README.md: -------------------------------------------------------------------------------- 1 | # LagrangeBench dataset generation 2 | 3 | To generate the [LagrangeBench datasets](https://zenodo.org/doi/10.5281/zenodo.10021925), we extend the case files provided at https://github.com/tumaer/jax-sph. We first copy the case files and the `main.py` file from JAX-SPH. 4 | 5 | ```bash 6 | cd data_gen/lagrangebench_data 7 | git clone https://github.com/tumaer/jax-sph.git 8 | cd jax-sph 9 | # We use this specific tag 10 | git checkout v0.0.2 11 | 12 | cd .. 13 | cp -r jax-sph/cases/ . 14 | cp jax-sph/main.py . 15 | ``` 16 | 17 | Then we install JAX-SPH: 18 | ```bash 19 | pip install jax-sph/ 20 | # or 21 | # pip install jax-sph==0.0.2 22 | 23 | # make sure to have the desired JAX version, e.g. 24 | pip install jax[cuda12]==0.4.29 25 | ``` 26 | 27 | The only thing left is to run the bash scripts, e.g. for Taylor-Green: 28 | ```bash 29 | bash dataset_tgv.sh 30 | 31 | # cleanup 32 | rm -rf jax-sph/ cases/ main.py 33 | ``` 34 | 35 | To inspect whether the simulated trajectories are of the desired length, we can use 36 | ```bash 37 | count_and_check() { find "$1" -type f | sed 's%/[^/]*$%%' | uniq -c | awk -v target="$2" '{if ($1 != target) $1=$1" - Failed!"; n=split($2, a, /[/]/); print a[n]" - "$1}'; } 38 | 39 | # and apply it e.g. to TGV with the arguments: 1. dataset directory and 2. target count 40 | count_and_check "/tmp/lagrangebench_data/raw/2D_TGV_2500_10kevery100/" 127 41 | ``` 42 | 43 | ## Errata 44 | 1. For dam break, one would need to replace lines 182-184 in `jax-sph/case_setup.py` with: 45 | ```python 46 | mask, _mask = state["tag"]==Tag.FLUID, _state["tag"]==Tag.FLUID 47 | assert state[k][mask].shape == _state[k][_mask].shape, ValueError( 48 | f"Shape mismatch for key {k} in state0 file."
49 | ) 50 | state[k][mask] = _state[k][_mask] 51 | ``` -------------------------------------------------------------------------------- /data_gen/lagrangebench_data/dataset_db.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # CUDA_VISIBLE_DEVICES=0 nohup ./scripts/dataset_db.sh >> db_dataset.out 2>&1 & 3 | 4 | DATA_ROOT=/tmp/lagrangebench_data 5 | 6 | ##### 2D dataset 7 | for seed in {0..114} # 15 trajectories blew up and were discarded 8 | do 9 | echo "Run with seed = $seed" 10 | python main.py config=cases/db.yaml seed=$seed case.mode=rlx solver.tvf=1.0 case.r0_noise_factor=0.25 io.data_path=$DATA_ROOT/relaxed/ 11 | python main.py config=cases/db.yaml seed=$seed case.state0_path=$DATA_ROOT/relaxed/db_2_0.02_$seed.h5 io.data_path=$DATA_ROOT/raw/2D_DB_5740_20kevery100/ 12 | done 13 | # 15 blowing up runs were removed from the dataset, and 100 were kept 14 | # use `count_files_db.py` to detect defect runs 15 | python gen_dataset.py --src_dir=$DATA_ROOT/raw/2D_DB_5740_20kevery100/ --dst_dir=$DATA_ROOT/datasets/2D_DB_5740_20kevery100/ --split=2_1_1 16 | 17 | 18 | ### Number of particles 19 | # 100x50=5000 water particles 20 | # 106x274 outer box, i.e. 106x274 - 100x268 = 2244 wall particles 21 | # with only one wall layer, wall particles are 2*(100+270) = 740 22 | # => 7244 particles with SPH and 5740 for dataset 23 | 24 | ### Number of seeds for a given number of training samples 25 | # t_end = 12 26 | # dt_coarse = 0.0003*100 = 0.03 27 | # num_samples = 12/0.03 = 400 (+1 for initial state) 28 | # => for 40k train+valid+test samples, we need 40k/400 = 100 seeds 29 | -------------------------------------------------------------------------------- /data_gen/lagrangebench_data/dataset_ldc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # CUDA_VISIBLE_DEVICES=0 nohup ./scripts/dataset_ldc.sh 2>&1 & 3 | 4 | DATA_ROOT=/tmp/lagrangebench_data 5 | 6 | ##### 2D dataset 7 | python main.py config=cases/ldc.yaml case.dim=2 case.dx=0.02 case.mode=rlx solver.tvf=1.0 case.r0_noise_factor=0.25 io.data_path=$DATA_ROOT/relaxed/ io.p_bg_factor=0.0 8 | python main.py config=cases/ldc.yaml case.dim=2 case.dx=0.02 solver.dt=0.0004 solver.t_end=85 case.state0_path=$DATA_ROOT/relaxed/ldc_2_0.02_123.h5 io.data_path=$DATA_ROOT/raw/2D_LDC_2500_10kevery100/ 9 | python gen_dataset.py --src_dir=$DATA_ROOT/raw/2D_LDC_2500_10kevery100/ --dst_dir=$DATA_ROOT/datasets/2D_LDC_2500_10kevery100/ --split=2_1_1 --skip_first_n_frames=1248 10 | 11 | # dt_coarse = 0.0004 * 100 = 0.04 12 | # to get 20k samples, we simulate for t = 20k * dt_coarse = 800 (+50 for equilibriation) 13 | 14 | ##### 3D dataset 15 | python main.py config=cases/ldc.yaml case.dim=3 case.dx=0.041666667 case.mode=rlx solver.tvf=1.0 case.r0_noise_factor=0.25 io.data_path=$DATA_ROOT/relaxed/ io.p_bg_factor=0.0 16 | python main.py config=cases/ldc.yaml case.dim=3 case.dx=0.041666667 solver.dt=0.0009 solver.t_end=1850 case.state0_path=$DATA_ROOT/relaxed/ldc_3_0.041666667_123.h5 io.data_path=$DATA_ROOT/raw/3D_LDC_8160_10kevery100/ 17 | python gen_dataset.py --src_dir=$DATA_ROOT/raw/3D_LDC_8160_10kevery100/ --dst_dir=$DATA_ROOT/datasets/3D_LDC_8160_10kevery100/ --split=2_1_1 --skip_first_n_frames=555 18 | 19 | # dt_coarse = 0.0009 * 100 = 0.09 20 | # to get 20k samples, we simulate for t = 20k * dt_coarse = 1800 (+50 for equilibriation) 21 | -------------------------------------------------------------------------------- 
/data_gen/lagrangebench_data/dataset_rpf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # run this script with: 3 | # CUDA_VISIBLE_DEVICES=0 nohup ./scripts/dataset_rpf.sh 2>&1 & 4 | 5 | DATA_ROOT=/tmp/lagrangebench_data 6 | 7 | ##### 2D RPF 8 | python main.py config=cases/rpf.yaml case.dim=2 case.dx=0.025 case.mode=rlx solver.tvf=1.0 case.r0_noise_factor=0.25 io.data_path=$DATA_ROOT/relaxed/ 9 | python main.py config=cases/rpf.yaml case.dim=2 case.dx=0.025 solver.dt=0.0005 solver.t_end=2050 case.state0_path=$DATA_ROOT/relaxed/rpf_2_0.025_123.h5 io.data_path=$DATA_ROOT/raw/2D_RPF_3200_20kevery100/ 10 | python gen_dataset.py --src_dir=$DATA_ROOT/raw/2D_RPF_3200_20kevery100/ --dst_dir=$DATA_ROOT/datasets/2D_RPF_3200_20kevery100/ --split=2_1_1 --skip_first_n_frames=998 11 | 12 | ###### 3D RPF 13 | python main.py config=cases/rpf.yaml case.dim=3 case.dx=0.05 case.mode=rlx solver.tvf=1.0 case.r0_noise_factor=0.25 io.data_path=$DATA_ROOT/relaxed/ 14 | python main.py config=cases/rpf.yaml case.dim=3 case.dx=0.05 solver.dt=0.001 solver.t_end=2050 case.state0_path=$DATA_ROOT/relaxed/rpf_3_0.05_123.h5 io.data_path=$DATA_ROOT/raw/3D_RPF_8000_10kevery100/ eos.p_bg_factor=0.02 15 | python gen_dataset.py --src_dir=$DATA_ROOT/raw/3D_RPF_8000_10kevery100/ --dst_dir=$DATA_ROOT/datasets/3D_RPF_8000_10kevery100/ --split=2_1_1 --skip_first_n_frames=498 16 | 17 | # dt_coarse = 0.001 * 100 = 0.1 18 | # to get 20k samples, we simulate for t = 20k * dt_coarse = 2000 (+50 for equilibriation) 19 | 20 | # Also, for a fast particle (vel=1) to cross the box once (length=1) it takes t=1. 21 | # => it takes 10 steps to cross the box with dt_coarse = 0.1 22 | # with 20k samples for train+valid+test, the box is crossed 2000 times -------------------------------------------------------------------------------- /data_gen/lagrangebench_data/dataset_tgv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # CUDA_VISIBLE_DEVICES=0 nohup ./scripts/dataset_tgv.sh 2>&1 & 3 | 4 | DATA_ROOT=/tmp/lagrangebench_data 5 | 6 | ###### 2D TGV 7 | for seed in {0..199} 8 | do 9 | echo "Run with seed = $seed" 10 | python main.py config=cases/tgv.yaml seed=$seed case.dim=2 case.dx=0.02 case.mode=rlx solver.tvf=1.0 case.r0_noise_factor=0.25 io.data_path=$DATA_ROOT/relaxed/ 11 | python main.py config=cases/tgv.yaml seed=$seed case.dim=2 case.dx=0.02 solver.dt=0.0004 solver.t_end=5 case.state0_path=$DATA_ROOT/relaxed/tgv_2_0.02_$seed.h5 io.data_path=$DATA_ROOT/raw/2D_TGV_2500_10kevery100/ 12 | done 13 | python gen_dataset.py --src_dir=$DATA_ROOT/raw/2D_TGV_2500_10kevery100/ --dst_dir=$DATA_ROOT/datasets/2D_TGV_2500_10kevery100/ --split=2_1_1 14 | 15 | ###### 3D TGV 16 | for seed in {0..399} 17 | do 18 | echo "Run with seed = $seed" 19 | python main.py config=cases/tgv.yaml seed=$seed case.dim=3 case.dx=0.314159265 case.mode=rlx solver.tvf=1.0 case.r0_noise_factor=0.25 io.data_path=$DATA_ROOT/relaxed/ eos.p_bg_factor=0.01 20 | python main.py config=cases/tgv.yaml seed=$seed case.dim=3 case.dx=0.314159265 solver.dt=0.005 solver.t_end=30 case.state0_path=$DATA_ROOT/relaxed/tgv_3_0.314159265_$seed.h5 io.data_path=$DATA_ROOT/raw/3D_TGV_8000_10kevery100/ case.viscosity=0.02 21 | done 22 | python gen_dataset.py --src_dir=$DATA_ROOT/raw/3D_TGV_8000_10kevery100/ --dst_dir=$DATA_ROOT/datasets/3D_TGV_8000_10kevery100/ --split=2_1_1 23 | -------------------------------------------------------------------------------- 
/data_gen/lagrangebench_data/gen_dataset.py: -------------------------------------------------------------------------------- 1 | """Script for generating ML datasets from h5 simulation frames""" 2 | 3 | import argparse 4 | import json 5 | import os 6 | 7 | import h5py 8 | import numpy as np 9 | from jax import vmap 10 | from jax_sph.io_state import read_h5, write_h5 11 | from jax_sph.jax_md import space 12 | from omegaconf import OmegaConf 13 | 14 | 15 | def write_h5_frame_for_visualization(state_dict, file_path_h5): 16 | path_file_vis = os.path.join(file_path_h5[:-3] + "_vis.h5") 17 | print("writing to", path_file_vis) 18 | write_h5(state_dict, path_file_vis) 19 | print("done") 20 | 21 | 22 | def single_h5_files_to_h5_dataset(args): 23 | """Transform a set of .h5 files to a single .h5 dataset file 24 | 25 | Args: 26 | src_dir: source directory containing other directories, each with .h5 files 27 | corresponding to a trajectory 28 | dst_dir: destination directory where three files will be written: train.h5, 29 | valid.h5, and test.h5 30 | split: string of three integers separated by underscores, e.g. "80_10_10" 31 | """ 32 | 33 | os.makedirs(args.dst_dir, exist_ok=True) 34 | 35 | # list only directories in a root with files and directories 36 | dirs = os.listdir(args.src_dir) 37 | dirs = [d for d in dirs if os.path.isdir(os.path.join(args.src_dir, d))] 38 | # order by seed value 39 | dirs = sorted(dirs, key=lambda x: int(x.split("_")[3])) 40 | 41 | splits_array = np.array([int(s) for s in args.split.split("_")]) 42 | splits_sum = splits_array.sum() 43 | 44 | if len(dirs) == 1: # split one long trajectory into train, valid, and test 45 | files = os.listdir(os.path.join(args.src_dir, dirs[0])) 46 | files = [f for f in files if (".h5" in f)] 47 | files = sorted(files, key=lambda x: int(x.split("_")[1][:-3])) 48 | files = files[args.skip_first_n_frames :: args.slice_every_nth_frame] 49 | 50 | num_eval = np.ceil(splits_array[1] / splits_sum * len(files)).astype(int) 51 | # at least one validation and one testing trajectory 52 | splits_trajs = np.cumsum([0, len(files) - 2 * num_eval, num_eval, num_eval]) 53 | 54 | num_trajs_train = num_trajs_test = 1 55 | 56 | sequence_length_train, sequence_length_test = splits_trajs[1] - 1, num_eval - 1 57 | else: # multiple trajectories 58 | num_eval = np.ceil(splits_array[1] / splits_sum * len(dirs)).astype(int) 59 | # at least one validation and one testing trajectory 60 | splits_trajs = np.cumsum([0, len(dirs) - 2 * num_eval, num_eval, num_eval]) 61 | 62 | num_trajs_train, num_trajs_test = len(dirs) - 2 * num_eval, num_eval 63 | 64 | # sequence_length should be after subsampling every nth trajectory 65 | # and "-1" because of the last target position (see GNS dataset format) 66 | files_per_traj = len(os.listdir(os.path.join(args.src_dir, dirs[0]))) 67 | sequence_length_train = sequence_length_test = files_per_traj - 1 68 | 69 | for i, split in enumerate(["train", "valid", "test"]): 70 | hf = h5py.File(os.path.join(args.dst_dir, f"{split}.h5"), "w") 71 | 72 | if len(dirs) == 1: # one long trajectory 73 | position = [] 74 | traj_path = os.path.join(args.src_dir, dirs[0]) 75 | 76 | for j, filename in enumerate(files[splits_trajs[i] : splits_trajs[i + 1]]): 77 | file_path_h5 = os.path.join(traj_path, filename) 78 | state = read_h5(file_path_h5, array_type="numpy") 79 | r = state["r"] 80 | tag = state["tag"] 81 | 82 | if "ldc" in args.src_dir.lower(): # remove outer walls in lid-driven 83 | L, H = 1.0, 1.0 84 | cfg = OmegaConf.load(os.path.join(traj_path,
"config.yaml")) 85 | mask_bottom = np.where(r[:, 1] < 2 * cfg.case.dx, False, True) 86 | mask_lid = np.where(r[:, 1] > H + 4 * cfg.case.dx, False, True) 87 | mask_left = np.where( 88 | ((r[:, 0] < 2 * cfg.case.dx) * (tag == 1)), False, True 89 | ) 90 | mask_right = np.where( 91 | (r[:, 0] > L + 4 * cfg.case.dx) * (tag == 1), False, True 92 | ) 93 | mask = mask_bottom * mask_lid * mask_left * mask_right 94 | 95 | r = r[mask] 96 | tag = tag[mask] 97 | 98 | if args.is_visualize: 99 | write_h5_frame_for_visualization({"r": r, "tag": tag}, file_path_h5) 100 | position.append(r) 101 | 102 | position = np.stack(position) # (time steps, particles, dim) 103 | particle_type = tag # (particles,) 104 | 105 | traj_str = "00000" 106 | hf.create_dataset(f"{traj_str}/particle_type", data=particle_type) 107 | hf.create_dataset( 108 | f"{traj_str}/position", 109 | data=position, 110 | dtype=np.float32, 111 | compression="gzip", 112 | ) 113 | 114 | else: # multiple trajectories 115 | for j, dir in enumerate(dirs[splits_trajs[i] : splits_trajs[i + 1]]): 116 | traj_path = os.path.join(args.src_dir, dir) 117 | files = os.listdir(traj_path) 118 | files = [f for f in files if (".h5" in f)] 119 | files = sorted(files, key=lambda x: int(x.split("_")[1][:-3])) 120 | files = files[args.skip_first_n_frames :: args.slice_every_nth_frame] 121 | 122 | position = [] 123 | for k, filename in enumerate(files): 124 | file_path_h5 = os.path.join(traj_path, filename) 125 | state = read_h5(file_path_h5, array_type="numpy") 126 | r = state["r"] 127 | tag = state["tag"] 128 | 129 | if "db" in args.src_dir.lower(): # remove outer walls in dam break 130 | L, H = 5.366, 2.0 131 | cfg = OmegaConf.load(os.path.join(traj_path, "config.yaml")) 132 | mask_bottom = np.where(r[:, 1] < 2 * cfg.case.dx, False, True) 133 | mask_lid = np.where(r[:, 1] > H + 4 * cfg.case.dx, False, True) 134 | mask_left = np.where( 135 | ((r[:, 0] < 2 * cfg.case.dx) * (tag == 1)), False, True 136 | ) 137 | mask_right = np.where( 138 | (r[:, 0] > L + 4 * cfg.case.dx) * (tag == 1), False, True 139 | ) 140 | mask = mask_bottom * mask_lid * mask_left * mask_right 141 | 142 | r = r[mask] 143 | tag = tag[mask] 144 | 145 | if args.is_visualize: 146 | write_h5_frame_for_visualization( 147 | {"r": r, "tag": tag}, file_path_h5 148 | ) 149 | position.append(r) 150 | position = np.stack(position) # (time steps, particles, dim) 151 | particle_type = tag # (particles,) 152 | 153 | traj_str = str(j).zfill(5) 154 | hf.create_dataset(f"{traj_str}/particle_type", data=particle_type) 155 | hf.create_dataset( 156 | f"{traj_str}/position", 157 | data=position, 158 | dtype=np.float32, 159 | compression="gzip", 160 | ) 161 | 162 | hf.close() 163 | print(f"Finished {args.src_dir} {split} with {j+1} entries!") 164 | print(f"Sample positions shape {position.shape}") 165 | 166 | # metadata 167 | # Compatible with the lagrangebench metadata.json files 168 | cfg = OmegaConf.load(os.path.join(traj_path, "config.yaml")) 169 | 170 | metadata = { 171 | "case": cfg.case.name.upper(), 172 | "solver": cfg.solver.name, 173 | "density_evolution": cfg.solver.density_evolution, 174 | "dim": cfg.case.dim, 175 | "dx": cfg.case.dx, 176 | "dt": cfg.solver.dt, 177 | "t_end": cfg.solver.t_end, 178 | "viscosity": cfg.case.viscosity, 179 | "p_bg_factor": cfg.eos.p_bg_factor, 180 | "g_ext_magnitude": cfg.case.g_ext_magnitude, 181 | "artificial_alpha": cfg.solver.artificial_alpha, 182 | "free_slip": cfg.solver.free_slip, 183 | "write_every": cfg.io.write_every, 184 | "is_bc_trick": cfg.solver.is_bc_trick, 185
| "sequence_length_train": int(sequence_length_train), 186 | "num_trajs_train": int(num_trajs_train), 187 | "sequence_length_test": int(sequence_length_test), 188 | "num_trajs_test": int(num_trajs_test), 189 | "num_particles_max": cfg.case.num_particles_max, 190 | "periodic_boundary_conditions": list(cfg.case.pbc), 191 | "bounds": np.array(cfg.case.bounds).tolist(), 192 | } 193 | x = 1.45 * cfg.case["dx"] # around 1.5 dx 194 | x = np.format_float_positional( 195 | x, precision=2, unique=False, fractional=False, trim="k" 196 | ) 197 | metadata["default_connectivity_radius"] = float(x) 198 | 199 | with open(os.path.join(args.dst_dir, "metadata.json"), "w") as f: 200 | json.dump(metadata, f) 201 | 202 | 203 | def compute_statistics_h5(args): 204 | """Compute the mean and std of a h5 dataset files""" 205 | 206 | # metadata 207 | with open(os.path.join(args.dst_dir, "metadata.json"), "r") as f: 208 | metadata = json.load(f) 209 | 210 | # apply PBC in all directions or not at all 211 | if np.array(metadata["periodic_boundary_conditions"]).any(): 212 | box = np.array(metadata["bounds"]) 213 | box = box[:, 1] - box[:, 0] 214 | displacement_fn, _ = space.periodic(side=box) 215 | else: 216 | displacement_fn, _ = space.free() 217 | 218 | displacement_fn_sets = vmap(vmap(displacement_fn, in_axes=(0, 0))) 219 | 220 | vels, accs = [], [] 221 | vels_sq, accs_sq = [], [] 222 | vel_mean = acc_mean = 0.0 # to fix "F821 Undefined name ..." ruff error 223 | for loop in ["mean", "std"]: 224 | for split in ["train", "valid", "test"]: 225 | hf = h5py.File(os.path.join(args.dst_dir, f"{split}.h5"), "r") 226 | 227 | for _, v in hf.items(): 228 | tag = v.get("particle_type")[:] 229 | r = v.get("position")[:][:, tag == 0] # only fluid ("0") particles 230 | 231 | # The velocity and acceleration computation is based on an 232 | # inversion of Semi-Implicit Euler 233 | vel = displacement_fn_sets(r[1:], r[:-1]) 234 | if loop == "mean": 235 | vels.append(vel.mean((0, 1))) 236 | accs.append((vel[1:] - vel[:-1]).mean((0, 1))) 237 | elif loop == "std": 238 | centered_vel = vel - vel_mean 239 | vels_sq.append(np.square(centered_vel).mean((0, 1))) 240 | centered_acc = vel[1:] - vel[:-1] - acc_mean 241 | accs_sq.append(np.square(centered_acc).mean((0, 1))) 242 | 243 | hf.close() 244 | 245 | if loop == "mean": 246 | vel_mean = np.stack(vels).mean(0) 247 | acc_mean = np.stack(accs).mean(0) 248 | print(f"vel_mean={vel_mean}, acc_mean={acc_mean}") 249 | elif loop == "std": 250 | vel_std = np.stack(vels_sq).mean(0) ** 0.5 251 | acc_std = np.stack(accs_sq).mean(0) ** 0.5 252 | print(f"vel_std={vel_std}, acc_std={acc_std}") 253 | 254 | # stds should not be 0. If they are, set them to 1. 255 | vel_std = np.where(vel_std < 1e-7, 1, vel_std) 256 | acc_std = np.where(acc_std < 1e-7, 1, acc_std) 257 | 258 | metadata["vel_mean"] = vel_mean.tolist() 259 | metadata["vel_std"] = vel_std.tolist() 260 | 261 | metadata["acc_mean"] = acc_mean.tolist() 262 | metadata["acc_std"] = acc_std.tolist() 263 | 264 | with open(os.path.join(args.dst_dir, "metadata.json"), "w") as f: 265 | json.dump(metadata, f) 266 | 267 | 268 | if __name__ == "__main__": 269 | parser = argparse.ArgumentParser() 270 | parser.add_argument("--src_dir", type=str) 271 | parser.add_argument("--dst_dir", type=str) 272 | parser.add_argument("--split", type=str, help="E.g. 
3_1_1") 273 | parser.add_argument("--skip_first_n_frames", type=int, default=0) 274 | parser.add_argument("--slice_every_nth_frame", type=int, default=1) 275 | parser.add_argument("--is_visualize", action="store_true") 276 | args = parser.parse_args() 277 | 278 | single_h5_files_to_h5_dataset(args) 279 | compute_statistics_h5(args) 280 | -------------------------------------------------------------------------------- /data_gen/lagrangebench_data/plot_frame.py: -------------------------------------------------------------------------------- 1 | """Print a frame for visual inspection of the data.""" 2 | 3 | import argparse 4 | 5 | import h5py 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | def plot_frame(src_dir, frame): 10 | with h5py.File(src_dir, "r") as f: 11 | tag = f["00000/particle_type"][:] 12 | r = f["00000/position"][frame] 13 | 14 | plt.scatter(r[:, 0], r[:, 1], c=tag) 15 | plt.savefig(f"frame_{frame}.png") 16 | 17 | 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser( 20 | description="Print a frame for visual inspection of the data." 21 | ) 22 | parser.add_argument("--src_dir", type=str, help="Source directory.") 23 | parser.add_argument("--frame", type=int, help="Which frame to plot.") 24 | args = parser.parse_args() 25 | 26 | plot_frame(args.src_dir, args.frame) 27 | 28 | # Example: 29 | # python plot_frame.py --src_dir=datasets/2D_TGV_2500_10kevery100/train.h5 --frame=0 30 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | project = "LagrangeBench" 10 | copyright = "2023, Chair of Aerodynamics and Fluid Mechanics, TUM" 11 | author = "Artur Toshev, Gianluca Galletti" 12 | 13 | # read the version from pyproject.toml 14 | import toml 15 | 16 | pyproject = toml.load("../pyproject.toml") 17 | version = pyproject["tool"]["poetry"]["version"] 18 | 19 | # -- Path setup -------------------------------------------------------------- 20 | 21 | # If extensions (or modules to document with autodoc) are in another directory, 22 | # add these directories to sys.path here. 
If the directory is relative to the 23 | # documentation root, use os.path.abspath to make it absolute, like shown here. 24 | # 25 | import os 26 | import sys 27 | 28 | sys.path.insert(0, os.path.abspath("..")) 29 | 30 | import collections 31 | 32 | # -- General configuration --------------------------------------------------- 33 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 34 | 35 | extensions = [ 36 | "sphinx.ext.autodoc", 37 | "sphinx.ext.viewcode", 38 | "sphinx.ext.napoleon", 39 | "sphinx.ext.intersphinx", 40 | "sphinx.ext.mathjax", 41 | # to get defaults.py in the documentation 42 | "sphinx_exec_code", 43 | ] 44 | 45 | numfig = True 46 | 47 | templates_path = ["_templates"] 48 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 49 | 50 | 51 | # -- Options for HTML output ------------------------------------------------- 52 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 53 | 54 | html_theme = "sphinx_rtd_theme" 55 | html_static_path = ["_static"] 56 | 57 | 58 | # -- Options for autodoc ----------------------------------------------------- 59 | 60 | autodoc_default_options = { 61 | "member-order": "bysource", 62 | "special-members": True, 63 | "exclude-members": "__repr__, __str__, __weakref__", 64 | } 65 | 66 | 67 | # -- Options for sphinx-exec-code --------------------------------------------- 68 | 69 | exec_code_working_dir = ".." 70 | 71 | 72 | # drop the docstrings of the undocumented namedtuple attributes 73 | def remove_namedtuple_attrib_docstring(app, what, name, obj, skip, options): 74 | if type(obj) is collections._tuplegetter: 75 | return True 76 | return skip 77 | 78 | 79 | def setup(app): 80 | app.connect("autodoc-skip-member", remove_namedtuple_attrib_docstring) 81 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. LagrangeBench documentation master file, created by 2 | sphinx-quickstart on Fri Aug 18 19:44:58 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | LagrangeBench 7 | ============= 8 | 9 | .. image:: https://drive.google.com/thumbnail?id=1rP0pf1KL8iGbly0tA0qthUE_tMDv_9Jp&sz=w1000 10 | :alt: rpf2d.gif 11 | 12 | .. image:: https://drive.google.com/thumbnail?id=1BMGkHj9EYMGUOdsE5QwiJWCTvDNqveHc&sz=w1000 13 | :alt: rpf3d.gif 14 | 15 | 16 | What is ``LagrangeBench``? 17 | -------------------------- 18 | 19 | LagrangeBench is a machine learning benchmarking suite for **Lagrangian particle 20 | problems** based on the `JAX `_ library. It provides: 21 | 22 | - **Data loading and preprocessing** utilities for particle data. 23 | - Three different **neighbor search routines**: (a) the original JAX-MD implementation, (b) 24 | a memory-efficient version of the JAX-MD implementation, and (c) a wrapper around the 25 | matscipy implementation that can handle a variable number of particles. 26 | - JAX reimplementations of established **graph neural networks**: GNS, SEGNN, EGNN, PaiNN. 27 | - **Training strategies** including random-walk additive noise and the pushforward trick. 28 | - Evaluation tools consisting of **rollout generation** and different **error metrics**: 29 | position MSE, kinetic energy MSE, and Sinkhorn distance for the particle distribution. 30 | 31 | 32 | .. note:: 33 | 34 | For more details on LagrangeBench usage, check out our `tutorials `_.
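The neighbor search backend can be selected through the ``neighbors`` section of a run
config, as done in ``configs/WaterDrop_2d/gns.yaml``. A minimal sketch, with the
``multiplier`` value mirroring ``configs/dam_2d/base.yaml`` (treat both values as
illustrative rather than recommended settings):

.. code-block:: yaml

    neighbors:
      backend: matscipy   # wrapper around the matscipy neighbor search
      multiplier: 2.0     # presumably a capacity margin for the neighbor list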
35 | 36 | 37 | 38 | Data loading and preprocessing 39 | ------------------------------ 40 | 41 | First, we create a dataset class based on ``torch.utils.data.Dataset``. 42 | We then initialize a ``CaseSetupFn`` object taking care of the neighbors search, 43 | preprocessing, and time integration. 44 | 45 | .. code-block:: python 46 | 47 | import lagrangebench 48 | 49 | # Load data 50 | data_train = lagrangebench.RPF2D("train") 51 | data_valid = lagrangebench.RPF2D("valid", extra_seq_length=20) 52 | data_test = lagrangebench.RPF2D("test", extra_seq_length=20) 53 | 54 | # Case setup (preprocessing and graph building) 55 | bounds = np.array(data_train.metadata["bounds"]) 56 | box = bounds[:, 1] - bounds[:, 0] 57 | case = lagrangebench.case_builder( 58 | box=box, 59 | metadata=data_train.metadata, 60 | input_seq_length=6, 61 | ) 62 | 63 | 64 | Models 65 | ------ 66 | 67 | Initialize a GNS model. 68 | 69 | .. code-block:: python 70 | 71 | import haiku as hk 72 | 73 | def gns(x): 74 | return lagrangebench.models.GNS( 75 | particle_dimension=data_train.metadata["dim"], 76 | latent_size=16, 77 | blocks_per_step=2, 78 | num_mp_steps=4, 79 | particle_type_embedding_size=8, 80 | )(x) 81 | 82 | gns = hk.without_apply_rng(hk.transform_with_state(gns)) 83 | 84 | 85 | Training 86 | -------- 87 | 88 | The ``Trainer`` provides a convenient way to train a model. 89 | 90 | .. code-block:: python 91 | 92 | trainer = lagrangebench.Trainer( 93 | model=gns, 94 | case=case, 95 | data_train=data_train, 96 | data_valid=data_valid, 97 | cfg_eval={"n_rollout_steps": 20, "train": {"metrics": ["mse"]}}, 98 | input_seq_length=6 99 | ) 100 | 101 | # Train for 25000 steps 102 | params, state, _ = trainer.train(step_max=25000) 103 | 104 | 105 | Evaluation 106 | ---------- 107 | 108 | When training is done, we can evaluate the model on the test set. 109 | 110 | .. code-block:: python 111 | 112 | metrics = lagrangebench.infer( 113 | gns, 114 | case, 115 | data_test, 116 | params, 117 | state, 118 | cfg_eval_infer={"metrics": ["mse", "sinkhorn", "e_kin"]}, 119 | n_rollout_steps=20, 120 | ) 121 | 122 | 123 | .. toctree:: 124 | :maxdepth: 2 125 | :caption: Getting Started 126 | 127 | pages/tutorial 128 | pages/defaults 129 | pages/baselines 130 | 131 | .. toctree:: 132 | :maxdepth: 2 133 | :caption: API 134 | 135 | pages/data 136 | pages/case_setup 137 | pages/models 138 | pages/train 139 | pages/evaluate 140 | pages/utils 141 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/pages/baselines.rst: -------------------------------------------------------------------------------- 1 | Baselines 2 | =================================== 3 | 4 | The table below provides inference performance and baseline results on all LagrangeBench datasets. 5 | Runtimes are evaluated on an Nvidia A6000 48GB GPU. 6 | 7 | .. note:: 8 | 9 | The discussion of the results and the hyperparameters can be found in the full paper `"LagrangeBench: A Lagrangian Fluid Mechanics Benchmarking Suite" `_. 10 | 11 |
12 | .. csv-table:: Runtime performance and baseline results.
13 |    :header: "Dataset", "Model", "#Params", "Forward [ms]", "MSE5", "MSE20"
14 | 
15 |    "TGV 2D (2.5K)", "GNS-5-64", "161K", "1.4", "6.4e-7", "9.6e-6"
16 |    "", "GNS-10-128", "1.2M", "5.3", "3.9e-7", "6.6e-6"
17 |    "", "SEGNN-5-64", "183K", "9.8", "3.8e-7", "6.5e-6"
18 |    "", "SEGNN-10-64", "360K", "20.2", "2.4e-7", "4.4e-6"
19 |    "RPF 2D (3.2K)", "GNS-5-64", "161K", "2.1", "4.0e-7", "9.8e-6"
20 |    "", "GNS-10-128", "1.2M", "6.7", "1.1e-7", "3.3e-6"
21 |    "", "SEGNN-5-64", "183K", "15.1", "1.3e-7", "4.0e-6"
22 |    "", "SEGNN-10-64", "360K", "29.7", "1.3e-7", "4.0e-6"
23 |    "", "EGNN-5-128", "663K", "60.8", "unstable", "unstable"
24 |    "", "PaiNN-5-128", "1.0M", "9.1", "3.0e-6", "7.2e-5"
25 |    "LDC 2D (2.7K)", "GNS-5-64", "161K", "1.5", "2.0e-6", "1.7e-5"
26 |    "", "GNS-10-128", "1.2M", "5.7", "6.4e-7", "1.4e-5"
27 |    "", "SEGNN-5-64", "183K", "10.0", "9.9e-7", "1.7e-5"
28 |    "", "SEGNN-10-64", "360K", "21.1", "1.4e-6", "2.5e-5"
29 |    "DAM 2D (5.7K)", "GNS-5-64", "161K", "3.8", "2.1e-6", "6.3e-5"
30 |    "", "GNS-10-128", "1.2M", "11.9", "1.3e-6", "3.3e-5"
31 |    "", "SEGNN-5-64", "183K", "28.8", "2.6e-6", "1.4e-4"
32 |    "", "SEGNN-10-64", "360K", "59.2", "1.9e-6", "1.1e-4"
33 |    "TGV 3D (8.0K)", "GNS-5-64", "161K", "8.4", "3.8e-4", "8.3e-3"
34 |    "", "GNS-10-128", "1.2M", "30.5", "2.1e-4", "5.8e-3"
35 |    "", "SEGNN-5-64", "183K", "79.4", "3.1e-4", "7.7e-3"
36 |    "", "SEGNN-10-64", "360K", "154.3", "1.7e-4", "5.2e-3"
37 |    "RPF 3D (8.0K)", "GNS-5-64", "161K", "8.4", "1.3e-6", "5.2e-5"
38 |    "", "GNS-10-128", "1.2M", "30.5", "3.3e-7", "1.9e-5"
39 |    "", "SEGNN-5-64", "183K", "79.4", "6.6e-7", "3.1e-5"
40 |    "", "SEGNN-10-64", "360K", "154.3", "3.0e-7", "1.8e-5"
41 |    "", "EGNN-5-128", "663K", "250.7", "unstable", "unstable"
42 |    "", "PaiNN-5-128", "1.0M", "43.0", "1.8e-5", "3.6e-4"
43 |    "LDC 3D (8.2K)", "GNS-5-64", "161K", "8.6", "1.7e-6", "5.7e-5"
44 |    "", "GNS-10-128", "1.2M", "32.0", "7.4e-7", "4.0e-5"
45 |    "", "SEGNN-5-64", "183K", "81.2", "1.2e-6", "4.8e-5"
46 |    "", "SEGNN-10-64", "360K", "161.2", "9.4e-7", "4.4e-5"
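The model names encode message passing depth and latent width, e.g. ``GNS-10-128`` denotes a GNS with 10 message passing steps and a latent size of 128. As a rough sketch of how such a baseline maps onto the Python API (the remaining keyword values are illustrative placeholders, not the exact settings behind the table above, see the paper for those):

.. code-block:: python

    import haiku as hk
    import lagrangebench

    def gns_10_128(x):
        # "GNS-10-128": 10 message passing steps, latent size 128
        return lagrangebench.models.GNS(
            particle_dimension=2,  # 2 for the 2D cases, 3 for the 3D cases
            latent_size=128,
            blocks_per_step=2,  # placeholder value
            num_mp_steps=10,
            particle_type_embedding_size=16,  # placeholder value
        )(x)

    model = hk.without_apply_rng(hk.transform_with_state(gns_10_128))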
-------------------------------------------------------------------------------- /docs/pages/case_setup.rst: -------------------------------------------------------------------------------- 1 | Case Setup 2 | =================================== 3 | 4 | Case 5 | ---- 6 | .. automodule:: lagrangebench.case_setup.case 7 | :members: 8 | 9 | Featurizer 10 | ---------- 11 | .. automodule:: lagrangebench.case_setup.features 12 | :members: 13 | -------------------------------------------------------------------------------- /docs/pages/data.rst: -------------------------------------------------------------------------------- 1 | Data 2 | =================================== 3 | 4 | Data 5 | ---- 6 | .. automodule:: lagrangebench.data.data 7 | :members: 8 | 9 | Utils 10 | ----- 11 | .. automodule:: lagrangebench.data.utils 12 | :members: 13 | -------------------------------------------------------------------------------- /docs/pages/defaults.rst: -------------------------------------------------------------------------------- 1 | Defaults 2 | =================================== 3 | 4 | 5 | 6 | .. exec_code:: 7 | :hide_code: 8 | :linenos_output: 9 | :language_output: python 10 | :caption: LagrangeBench default values 11 | 12 | 13 | with open("lagrangebench/defaults.py", "r") as file: 14 | defaults_full = file.read() 15 | 16 | # parse defaults: remove imports, only keep the set_defaults function 17 | 18 | defaults_full = defaults_full.split("\n") 19 | 20 | # remove imports 21 | defaults_full = [line for line in defaults_full if not line.startswith("import")] 22 | defaults_full = [line for line in defaults_full if len(line.replace(" ", "")) > 0] 23 | 24 | # remove other functions 25 | keep = False 26 | defaults = [] 27 | for i, line in enumerate(defaults_full): 28 | if line.startswith("def"): 29 | if "set_defaults" in line: 30 | keep = True 31 | else: 32 | keep = False 33 | 34 | if keep: 35 | defaults.append(line) 36 | 37 | # remove function declaration and return 38 | defaults = defaults[2:-2] 39 | 40 | # remove indent 41 | defaults = [line[4:] for line in defaults] 42 | 43 | 44 | print("\n".join(defaults)) 45 | -------------------------------------------------------------------------------- /docs/pages/evaluate.rst: -------------------------------------------------------------------------------- 1 | Evaluate 2 | =================================== 3 | 4 | Rollout 5 | -------- 6 | .. automodule:: lagrangebench.evaluate.rollout 7 | :members: 8 | 9 | Metrics 10 | ------- 11 | .. automodule:: lagrangebench.evaluate.metrics 12 | :members: 13 | 14 | Utils 15 | ----- 16 | .. automodule:: lagrangebench.evaluate.utils 17 | :members: 18 | -------------------------------------------------------------------------------- /docs/pages/models.rst: -------------------------------------------------------------------------------- 1 | Models 2 | =================================== 3 | 4 | Base Class 5 | ---------- 6 | .. automodule:: lagrangebench.models.base 7 | :members: 8 | 9 | GNS 10 | --- 11 | .. automodule:: lagrangebench.models.gns 12 | :members: 13 | 14 | SEGNN 15 | ----- 16 | .. automodule:: lagrangebench.models.segnn 17 | :members: 18 | 19 | EGNN 20 | ---- 21 | .. automodule:: lagrangebench.models.egnn 22 | :members: 23 | 24 | PaiNN 25 | ----- 26 | .. automodule:: lagrangebench.models.painn 27 | :members: 28 | 29 | Linear 30 | ------ 31 | .. automodule:: lagrangebench.models.linear 32 | :members: 33 | 34 | Utils 35 | ----- 36 | .. 
automodule:: lagrangebench.models.utils 37 | :members: 38 | :exclude-members: __getnewargs__, __new__, __repr__ 39 | -------------------------------------------------------------------------------- /docs/pages/train.rst: -------------------------------------------------------------------------------- 1 | Train 2 | =================================== 3 | 4 | Trainer 5 | ------- 6 | .. automodule:: lagrangebench.train.trainer 7 | :members: 8 | :exclude-members: __init__, __delattr__, __setattr__, __hash__, __eq__, __repr__, __weakref__ 9 | 10 | Strategies 11 | ---------- 12 | .. automodule:: lagrangebench.train.strats 13 | :members: 14 | -------------------------------------------------------------------------------- /docs/pages/tutorial.rst: -------------------------------------------------------------------------------- 1 | Tutorials 2 | =================================== 3 | 4 | - Basics of LagrangeBench: `Training GNS on the 2D Taylor Green vortex `_ |tutorial| 5 | - Stats and specifics of our data: `Datasets overview `_ |datasets| 6 | - How to include datasets from other works: `Working with other datasets `_ |gns| 7 | 8 | .. |tutorial| raw:: html 9 | 10 | 11 | 12 | .. |datasets| raw:: html 13 | 14 | 15 | 16 | .. |gns| raw:: html 17 | 18 | -------------------------------------------------------------------------------- /docs/pages/utils.rst: -------------------------------------------------------------------------------- 1 | Utils and Defaults 2 | =================================== 3 | 4 | Utils 5 | ----- 6 | .. automodule:: lagrangebench.utils 7 | :members: 8 | :exclude-members: __init__, __delattr__, __setattr__, __hash__, __eq__, __repr__, __weakref__ 9 | 10 | 11 | Defaults 12 | -------- 13 | .. automodule:: lagrangebench.defaults 14 | :members: 15 | :exclude-members: __init__, __delattr__, __setattr__, __hash__, __eq__, __repr__, __weakref__ 16 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cpu 2 | 3 | cloudpickle 4 | dm_haiku>=0.0.10 5 | e3nn_jax==0.20.3 6 | h5py 7 | jax-sph>=0.0.3 8 | jax[cpu]==0.4.29 9 | jmp>=0.0.4 10 | jraph>=0.0.6.dev0 11 | matscipy>=0.8.0 12 | omegaconf>=2.3.0 13 | optax>=0.1.7 14 | ott-jax>=0.4.2 15 | pyvista 16 | PyYAML 17 | sphinx==7.2.6 18 | sphinx-exec-code 19 | sphinx-rtd-theme==1.3.0 20 | toml>=0.10.2 21 | torch==2.3.1+cpu 22 | wandb 23 | wget 24 | -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Download datasets from Zenodo 3 | # Usage: 4 | # bash download_data.sh 5 | # Usage: 6 | # bash download_data.sh all datasets/ 7 | 8 | declare -A datasets 9 | datasets["tgv_2d"]="2D_TGV_2500_10kevery100.zip" 10 | datasets["rpf_2d"]="2D_RPF_3200_20kevery100.zip" 11 | datasets["ldc_2d"]="2D_LDC_2708_10kevery100.zip" 12 | datasets["dam_2d"]="2D_DAM_5740_20kevery100.zip" 13 | datasets["tgv_3d"]="3D_TGV_8000_10kevery100.zip" 14 | datasets["rpf_3d"]="3D_RPF_8000_10kevery100.zip" 15 | datasets["ldc_3d"]="3D_LDC_8160_10kevery100.zip" 16 | 17 | if [ $# -ne 2 ]; then 18 | echo "Usage: bash download_data.sh " 19 | exit 1 20 | fi 21 | 22 | DATASET_NAME="$1" 23 | OUTPUT_DIR="$2" 24 | ZENODO_PREFIX="https://zenodo.org/records/10491868/files/" 25 | 26 | # Check if there is a trailing slash in $OUTPUT_DIR and remove it 27 | if [[ 
$OUTPUT_DIR == */ ]]; then 28 | OUTPUT_DIR="${OUTPUT_DIR%/}" 29 | echo "Output directory: ${OUTPUT_DIR}" 30 | fi 31 | 32 | # Create output directory if it doesn't exist 33 | if [ ! -d "${OUTPUT_DIR}" ]; then 34 | mkdir -p "${OUTPUT_DIR}" 35 | fi 36 | 37 | # Download the data 38 | if [ "${DATASET_NAME}" == "all" ]; then 39 | echo "Downloading all datasets" 40 | for key in ${!datasets[@]}; do 41 | echo "Downloading ${key}" 42 | wget ${ZENODO_PREFIX}${datasets[${key}]} -P "${OUTPUT_DIR}/" 43 | ZIP_PATH=${OUTPUT_DIR}/${datasets[${key}]} 44 | python3 -c "import zipfile; zipfile.ZipFile('$ZIP_PATH', 'r').extractall('$OUTPUT_DIR')" 45 | rm ${ZIP_PATH} 46 | done 47 | else 48 | echo "Downloading ${DATASET_NAME}" 49 | wget ${ZENODO_PREFIX}${datasets[${DATASET_NAME}]} -P "${OUTPUT_DIR}/" 50 | ZIP_PATH=${OUTPUT_DIR}/${datasets[${DATASET_NAME}]} 51 | python3 -c "import zipfile; zipfile.ZipFile('$ZIP_PATH', 'r').extractall('$OUTPUT_DIR')" 52 | rm ${ZIP_PATH} 53 | fi 54 | -------------------------------------------------------------------------------- /lagrangebench/__init__.py: -------------------------------------------------------------------------------- 1 | from .case_setup.case import case_builder 2 | from .data import DAM2D, LDC2D, LDC3D, RPF2D, RPF3D, TGV2D, TGV3D, H5Dataset 3 | from .evaluate import infer 4 | from .models import EGNN, GNS, SEGNN, PaiNN 5 | from .train.trainer import Trainer 6 | 7 | __all__ = [ 8 | "Trainer", 9 | "infer", 10 | "case_builder", 11 | "models", 12 | "GNS", 13 | "EGNN", 14 | "SEGNN", 15 | "PaiNN", 16 | "data", 17 | "H5Dataset", 18 | "TGV2D", 19 | "TGV3D", 20 | "RPF2D", 21 | "RPF3D", 22 | "LDC2D", 23 | "LDC3D", 24 | "DAM2D", 25 | ] 26 | 27 | __version__ = "0.2.0" 28 | -------------------------------------------------------------------------------- /lagrangebench/case_setup/__init__.py: -------------------------------------------------------------------------------- 1 | """Case setup manager.""" 2 | 3 | from .case import CaseSetupFn, case_builder 4 | 5 | __all__ = [ 6 | "CaseSetupFn", 7 | "case_builder", 8 | ] 9 | -------------------------------------------------------------------------------- /lagrangebench/case_setup/case.py: -------------------------------------------------------------------------------- 1 | """Case setup functions.""" 2 | 3 | import warnings 4 | from typing import Callable, Dict, Optional, Tuple, Union 5 | 6 | import jax.numpy as jnp 7 | from jax import Array, jit, lax, vmap 8 | from jax_sph.jax_md import space 9 | from jax_sph.jax_md.dataclasses import dataclass, static_field 10 | from jax_sph.jax_md.partition import NeighborList, NeighborListFormat, neighbor_list 11 | from omegaconf import DictConfig, OmegaConf 12 | 13 | from lagrangebench.data.utils import get_dataset_stats 14 | from lagrangebench.defaults import defaults 15 | from lagrangebench.train.strats import add_gns_noise 16 | 17 | from .features import FeatureDict, TargetDict, physical_feature_builder 18 | 19 | TrainCaseOut = Tuple[Array, FeatureDict, TargetDict, NeighborList] 20 | EvalCaseOut = Tuple[FeatureDict, NeighborList] 21 | SampleIn = Tuple[jnp.ndarray, jnp.ndarray] 22 | 23 | AllocateFn = Callable[[Array, SampleIn, float, int], TrainCaseOut] 24 | AllocateEvalFn = Callable[[SampleIn], EvalCaseOut] 25 | 26 | PreprocessFn = Callable[[Array, SampleIn, float, NeighborList, int], TrainCaseOut] 27 | PreprocessEvalFn = Callable[[SampleIn, NeighborList], EvalCaseOut] 28 | 29 | IntegrateFn = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray] 30 | 31 | 32 | @dataclass 33 | class CaseSetupFn: 
34 | """Dataclass that contains all functions required to setup the case and simulate. 35 | 36 | Attributes: 37 | allocate: AllocateFn, runs the preprocessing without having a NeighborList as 38 | input. 39 | preprocess: PreprocessFn, takes positions from the dataloader, computes 40 | velocities, adds random-walk noise if needed, then updates the neighbor 41 | list, and return the inputs to the neural network as well as the targets. 42 | allocate_eval: AllocateEvalFn, same as allocate, but without noise addition 43 | and without targets. 44 | preprocess_eval: PreprocessEvalFn, same as allocate_eval, but jit-able. 45 | integrate: IntegrateFn, semi-implicit Euler integrations step respecting 46 | all boundary conditions. 47 | displacement: space.DisplacementFn, displacement function aware of boundary 48 | conditions (periodic on non-periodic). 49 | normalization_stats: Dict, normalization statisticss for input velocities and 50 | output acceleration. 51 | """ 52 | 53 | allocate: AllocateFn = static_field() 54 | preprocess: PreprocessFn = static_field() 55 | allocate_eval: AllocateEvalFn = static_field() 56 | preprocess_eval: PreprocessEvalFn = static_field() 57 | integrate: IntegrateFn = static_field() 58 | displacement: space.DisplacementFn = static_field() 59 | normalization_stats: Dict = static_field() 60 | 61 | 62 | def case_builder( 63 | box: Tuple[float, float, float], 64 | metadata: Dict, 65 | input_seq_length: int, 66 | cfg_neighbors: Union[Dict, DictConfig] = defaults.neighbors, 67 | cfg_model: Union[Dict, DictConfig] = defaults.model, 68 | noise_std: float = defaults.train.noise_std, 69 | external_force_fn: Optional[Callable] = None, 70 | dtype: jnp.dtype = defaults.dtype, 71 | ): 72 | """Set up a CaseSetupFn that contains every required function besides the model. 73 | 74 | Inspired by the `partition.neighbor_list` function in JAX-MD. 75 | 76 | The core functions are: 77 | * allocate, allocate memory for the neighbors list. 78 | * preprocess, update the neighbors list. 79 | * integrate, semi-implicit Euler respecting periodic boundary conditions. 80 | 81 | Args: 82 | box: Box xyz sizes of the system. 83 | metadata: Dataset metadata dictionary. 84 | input_seq_length: Length of the input sequence. 85 | cfg_neighbors: Configuration dictionary for the neighbor list. 86 | cfg_model: Configuration dictionary for the model / feature builder. 87 | noise_std: Noise standard deviation. 88 | external_force_fn: External force function. 89 | dtype: Data type. 90 | """ 91 | if isinstance(cfg_neighbors, Dict): 92 | cfg_neighbors = OmegaConf.create(cfg_neighbors) 93 | if isinstance(cfg_model, Dict): 94 | cfg_model = OmegaConf.create(cfg_model) 95 | 96 | # if one of the cfg_* arguments has a subset of the default configs, merge them 97 | cfg_neighbors = OmegaConf.merge(defaults.neighbors, cfg_neighbors) 98 | cfg_model = OmegaConf.merge(defaults.model, cfg_model) 99 | 100 | normalization_stats = get_dataset_stats( 101 | metadata, cfg_model.isotropic_norm, noise_std 102 | ) 103 | 104 | # apply PBC in all directions or not at all 105 | if jnp.array(metadata["periodic_boundary_conditions"]).any(): 106 | displacement_fn, shift_fn = space.periodic(side=jnp.array(box)) 107 | else: 108 | displacement_fn, shift_fn = space.free() 109 | 110 | displacement_fn_set = vmap(displacement_fn, in_axes=(0, 0)) 111 | 112 | if cfg_neighbors.multiplier < 1.25: 113 | warnings.warn( 114 | f"cfg_neighbors.multiplier={cfg_neighbors.multiplier} < 1.25 is very low. 
" 115 | "Be especially cautious if you batch training and/or inference as " 116 | "reallocation might be necessary based on different overflow conditions. " 117 | "See https://github.com/tumaer/lagrangebench/pull/20#discussion_r1443811262" 118 | ) 119 | 120 | neighbor_fn = neighbor_list( 121 | displacement_fn, 122 | jnp.array(box), 123 | backend=cfg_neighbors.backend, 124 | r_cutoff=metadata["default_connectivity_radius"], 125 | capacity_multiplier=cfg_neighbors.multiplier, 126 | mask_self=False, 127 | format=NeighborListFormat.Sparse, 128 | num_particles_max=metadata["num_particles_max"], 129 | pbc=metadata["periodic_boundary_conditions"], 130 | ) 131 | 132 | feature_transform = physical_feature_builder( 133 | bounds=metadata["bounds"], 134 | normalization_stats=normalization_stats, 135 | connectivity_radius=metadata["default_connectivity_radius"], 136 | displacement_fn=displacement_fn, 137 | pbc=metadata["periodic_boundary_conditions"], 138 | magnitude_features=cfg_model.magnitude_features, 139 | external_force_fn=external_force_fn, 140 | ) 141 | 142 | def _compute_target(pos_input: jnp.ndarray) -> TargetDict: 143 | # displacement(r1, r2) = r1-r2 # without PBC 144 | 145 | current_velocity = displacement_fn_set(pos_input[:, 1], pos_input[:, 0]) 146 | next_velocity = displacement_fn_set(pos_input[:, 2], pos_input[:, 1]) 147 | current_acceleration = next_velocity - current_velocity 148 | 149 | acc_stats = normalization_stats["acceleration"] 150 | normalized_acceleration = ( 151 | current_acceleration - acc_stats["mean"] 152 | ) / acc_stats["std"] 153 | 154 | vel_stats = normalization_stats["velocity"] 155 | normalized_velocity = (next_velocity - vel_stats["mean"]) / vel_stats["std"] 156 | return { 157 | "acc": normalized_acceleration, 158 | "vel": normalized_velocity, 159 | "pos": pos_input[:, -1], 160 | } 161 | 162 | def _preprocess( 163 | sample: Tuple[jnp.ndarray, jnp.ndarray], 164 | neighbors: Optional[NeighborList] = None, 165 | is_allocate: bool = False, 166 | mode: str = "train", 167 | **kwargs, # key, noise_std, unroll_steps 168 | ) -> Union[TrainCaseOut, EvalCaseOut]: 169 | pos_input = jnp.asarray(sample[0], dtype=dtype) 170 | particle_type = jnp.asarray(sample[1]) 171 | 172 | if mode == "train": 173 | key, noise_std = kwargs["key"], kwargs["noise_std"] 174 | unroll_steps = kwargs["unroll_steps"] 175 | if pos_input.shape[1] > 1: 176 | key, pos_input = add_gns_noise( 177 | key, pos_input, particle_type, input_seq_length, noise_std, shift_fn 178 | ) 179 | 180 | # allocate the neighbor list 181 | most_recent_position = pos_input[:, input_seq_length - 1] 182 | num_particles = (particle_type != -1).sum() 183 | if is_allocate: 184 | neighbors = neighbor_fn.allocate( 185 | most_recent_position, num_particles=num_particles 186 | ) 187 | else: 188 | neighbors = neighbors.update( 189 | most_recent_position, num_particles=num_particles 190 | ) 191 | 192 | # selected features 193 | features = feature_transform(pos_input[:, :input_seq_length], neighbors) 194 | 195 | if mode == "train": 196 | # compute target acceleration. Inverse of postprocessing step. 
197 | # the "-2" is needed because we need the most recent position and one before 198 | slice_begin = (0, input_seq_length - 2 + unroll_steps, 0) 199 | slice_size = (pos_input.shape[0], 3, pos_input.shape[2]) 200 | 201 | target_dict = _compute_target( 202 | lax.dynamic_slice(pos_input, slice_begin, slice_size) 203 | ) 204 | return key, features, target_dict, neighbors 205 | if mode == "eval": 206 | return features, neighbors 207 | 208 | def allocate_fn(key, sample, noise_std=0.0, unroll_steps=0): 209 | return _preprocess( 210 | sample, 211 | key=key, 212 | noise_std=noise_std, 213 | unroll_steps=unroll_steps, 214 | is_allocate=True, 215 | ) 216 | 217 | @jit 218 | def preprocess_fn(key, sample, noise_std, neighbors, unroll_steps=0): 219 | return _preprocess( 220 | sample, neighbors, key=key, noise_std=noise_std, unroll_steps=unroll_steps 221 | ) 222 | 223 | def allocate_eval_fn(sample): 224 | return _preprocess(sample, is_allocate=True, mode="eval") 225 | 226 | @jit 227 | def preprocess_eval_fn(sample, neighbors): 228 | return _preprocess(sample, neighbors, mode="eval") 229 | 230 | @jit 231 | def integrate_fn(normalized_in, position_sequence): 232 | """Euler integrator to get position shift.""" 233 | assert any([key in normalized_in for key in ["pos", "vel", "acc"]]) 234 | 235 | if "pos" in normalized_in: 236 | # Zeroth euler step 237 | return normalized_in["pos"] 238 | else: 239 | most_recent_position = position_sequence[:, -1] 240 | if "vel" in normalized_in: 241 | # invert normalization 242 | velocity_stats = normalization_stats["velocity"] 243 | new_velocity = velocity_stats["mean"] + ( 244 | normalized_in["vel"] * velocity_stats["std"] 245 | ) 246 | elif "acc" in normalized_in: 247 | # invert normalization. 248 | acceleration_stats = normalization_stats["acceleration"] 249 | acceleration = acceleration_stats["mean"] + ( 250 | normalized_in["acc"] * acceleration_stats["std"] 251 | ) 252 | # Second Euler step 253 | most_recent_velocity = displacement_fn_set( 254 | most_recent_position, position_sequence[:, -2] 255 | ) 256 | new_velocity = most_recent_velocity + acceleration # * dt = 1 257 | 258 | # First Euler step 259 | return shift_fn(most_recent_position, new_velocity) 260 | 261 | return CaseSetupFn( 262 | allocate_fn, 263 | preprocess_fn, 264 | allocate_eval_fn, 265 | preprocess_eval_fn, 266 | integrate_fn, 267 | displacement_fn, 268 | normalization_stats, 269 | ) 270 | -------------------------------------------------------------------------------- /lagrangebench/case_setup/features.py: -------------------------------------------------------------------------------- 1 | """Feature extraction utilities.""" 2 | 3 | from typing import Callable, Dict, List, Optional 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | from jax import lax, vmap 8 | from jax_sph.jax_md import partition, space 9 | 10 | FeatureDict = Dict[str, jnp.ndarray] 11 | TargetDict = Dict[str, jnp.ndarray] 12 | 13 | 14 | def physical_feature_builder( 15 | bounds: list, 16 | normalization_stats: dict, 17 | connectivity_radius: float, 18 | displacement_fn: Callable, 19 | pbc: List[bool], 20 | magnitude_features: bool = False, 21 | external_force_fn: Optional[Callable] = None, 22 | ) -> Callable: 23 | """Build a physical feature transform function. 
24 | 25 | Transform raw coordinates to 26 | - Absolute positions 27 | - Historical velocity sequence 28 | - Velocity magnitudes 29 | - Distance to boundaries 30 | - External force field 31 | - Relative displacement vectors and distances 32 | 33 | Args: 34 | bounds: Each sublist contains the lower and upper bound of a dimension. 35 | normalization_stats: Dict containing mean and std of velocities and targets 36 | connectivity_radius: Radius of the connectivity graph. 37 | displacement_fn: Displacement function. 38 | pbc: Wether to use periodic boundary conditions. 39 | magnitude_features: Whether to include the magnitude of the velocity. 40 | external_force_fn: Function that returns the external force field (optional). 41 | """ 42 | displacement_fn_vmap = vmap(displacement_fn, in_axes=(0, 0)) 43 | displacement_fn_dvmap = vmap(displacement_fn_vmap, in_axes=(0, 0)) 44 | 45 | velocity_stats = normalization_stats["velocity"] 46 | 47 | def feature_transform( 48 | pos_input: jnp.ndarray, 49 | nbrs: partition.NeighborList, 50 | ) -> FeatureDict: 51 | """Feature engineering. 52 | 53 | Returns: 54 | Dict of features, with possible keys 55 | - "abs_pos", absolute positions 56 | - "vel_hist", historical velocity sequence 57 | - "vel_mag", velocity magnitudes 58 | - "bound", distance to boundaries 59 | - "force", external force field 60 | - "rel_disp", relative displacement vectors 61 | - "rel_dist", relative distance vectors 62 | """ 63 | features = {} 64 | 65 | n_total_points = pos_input.shape[0] 66 | most_recent_position = pos_input[:, -1] # (n_nodes, dim) 67 | # pos_input.shape = (n_nodes, n_timesteps, dim) 68 | velocity_sequence = displacement_fn_dvmap(pos_input[:, 1:], pos_input[:, :-1]) 69 | # Normalized velocity sequence, merging spatial an time axis. 70 | normalized_velocity_sequence = ( 71 | velocity_sequence - velocity_stats["mean"] 72 | ) / velocity_stats["std"] 73 | flat_velocity_sequence = normalized_velocity_sequence.reshape( 74 | n_total_points, -1 75 | ) 76 | 77 | features["abs_pos"] = pos_input 78 | features["vel_hist"] = flat_velocity_sequence 79 | 80 | if magnitude_features: 81 | # append the magnitude of the velocity of each particle to the node features 82 | velocity_magnitude_sequence = jnp.linalg.norm( 83 | normalized_velocity_sequence, axis=-1 84 | ) 85 | features["vel_mag"] = velocity_magnitude_sequence 86 | 87 | if not any(pbc): 88 | # Normalized clipped distances to lower and upper boundaries. 89 | # boundaries are an array of shape [num_dimensions, dim], where the 90 | # second axis, provides the lower/upper boundaries. 
91 | boundaries = lax.stop_gradient(jnp.array(bounds)) 92 | 93 | distance_to_lower_boundary = most_recent_position - boundaries[:, 0][None] 94 | distance_to_upper_boundary = boundaries[:, 1][None] - most_recent_position 95 | 96 | # rewritten the code above in jax 97 | distance_to_boundaries = jnp.concatenate( 98 | [distance_to_lower_boundary, distance_to_upper_boundary], axis=1 99 | ) 100 | normalized_clipped_distance_to_boundaries = jnp.clip( 101 | distance_to_boundaries / connectivity_radius, -1.0, 1.0 102 | ) 103 | features["bound"] = normalized_clipped_distance_to_boundaries 104 | 105 | if external_force_fn is not None: 106 | external_force_field = vmap(external_force_fn)(most_recent_position) 107 | features["force"] = external_force_field 108 | 109 | # senders and receivers are integers of shape (E,) 110 | receivers, senders = nbrs.idx 111 | features["senders"] = senders 112 | features["receivers"] = receivers 113 | 114 | # Relative displacement and distances normalized to radius (E, dim) 115 | displacement = vmap(displacement_fn)( 116 | most_recent_position[receivers], most_recent_position[senders] 117 | ) 118 | normalized_relative_displacements = displacement / connectivity_radius 119 | features["rel_disp"] = normalized_relative_displacements 120 | 121 | normalized_relative_distances = space.distance( 122 | normalized_relative_displacements 123 | ) 124 | features["rel_dist"] = normalized_relative_distances[:, None] 125 | 126 | return jax.tree_map(lambda f: f, features) 127 | 128 | return feature_transform 129 | -------------------------------------------------------------------------------- /lagrangebench/data/__init__.py: -------------------------------------------------------------------------------- 1 | """Datasets and dataloading utils.""" 2 | 3 | from .data import DAM2D, LDC2D, LDC3D, RPF2D, RPF3D, TGV2D, TGV3D, H5Dataset 4 | 5 | __all__ = ["H5Dataset", "TGV2D", "TGV3D", "RPF2D", "RPF3D", "LDC2D", "LDC3D", "DAM2D"] 6 | -------------------------------------------------------------------------------- /lagrangebench/data/utils.py: -------------------------------------------------------------------------------- 1 | """Data utils.""" 2 | 3 | from typing import Dict, List 4 | 5 | import jax.numpy as jnp 6 | import numpy as np 7 | 8 | 9 | def get_dataset_stats( 10 | metadata: Dict[str, List[float]], 11 | is_isotropic_norm: bool, 12 | noise_std: float, 13 | ) -> Dict[str, Dict[str, jnp.ndarray]]: 14 | """Return the dataset statistics based on the metadata dictionary. 15 | 16 | Args: 17 | metadata: Dataset metadata dictionary. 18 | is_isotropic_norm: 19 | Whether to shift/scale dimensions equally instead of dimension-wise. 20 | noise_std: Standard deviation of the GNS-style noise. 21 | 22 | Returns: 23 | Dictionary with the dataset statistics. 
24 | """ 25 | acc_mean = jnp.array(metadata["acc_mean"]) 26 | acc_std = jnp.array(metadata["acc_std"]) 27 | vel_mean = jnp.array(metadata["vel_mean"]) 28 | vel_std = jnp.array(metadata["vel_std"]) 29 | 30 | if is_isotropic_norm: 31 | acc_mean = jnp.mean(acc_mean) * jnp.ones_like(acc_mean) 32 | acc_std = jnp.sqrt(jnp.mean(acc_std**2)) * jnp.ones_like(acc_std) 33 | vel_mean = jnp.mean(vel_mean) * jnp.ones_like(vel_mean) 34 | vel_std = jnp.sqrt(jnp.mean(vel_std**2)) * jnp.ones_like(vel_std) 35 | 36 | return { 37 | "acceleration": { 38 | "mean": acc_mean, 39 | "std": jnp.sqrt(acc_std**2 + noise_std**2), 40 | }, 41 | "velocity": { 42 | "mean": vel_mean, 43 | "std": jnp.sqrt(vel_std**2 + noise_std**2), 44 | }, 45 | } 46 | 47 | 48 | def numpy_collate(batch) -> np.ndarray: 49 | """Collate helper for torch dataloaders.""" 50 | # NOTE: to numpy to avoid copying twice (dataloader timeout). 51 | if isinstance(batch[0], np.ndarray): 52 | return np.stack(batch) 53 | if isinstance(batch[0], (tuple, list)): 54 | return type(batch[0])(numpy_collate(samples) for samples in zip(*batch)) 55 | else: 56 | return np.asarray(batch) 57 | -------------------------------------------------------------------------------- /lagrangebench/defaults.py: -------------------------------------------------------------------------------- 1 | """Default lagrangebench configs.""" 2 | 3 | 4 | from omegaconf import DictConfig, OmegaConf 5 | 6 | 7 | def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig: 8 | """Set default lagrangebench configs.""" 9 | 10 | ### global and hardware-related configs 11 | 12 | # configuration file. Either "config" or "load_ckp" must be specified. 13 | # If "config" is specified, "load_ckp" is ignored. 14 | cfg.config = None 15 | # Load checkpointed model from this directory 16 | cfg.load_ckp = None 17 | # One of "train", "infer" or "all" (= both) 18 | cfg.mode = "all" 19 | # random seed 20 | cfg.seed = 0 21 | # data type for preprocessing. One of "float32" or "float64" 22 | cfg.dtype = "float64" 23 | # gpu device. -1 for CPU. Should be specified before importing the library. 24 | cfg.gpu = None 25 | # XLA memory fraction to be preallocated. The JAX default is 0.75. 26 | # Should be specified before importing the library. 27 | cfg.xla_mem_fraction = None 28 | 29 | ### dataset 30 | cfg.dataset = OmegaConf.create({}) 31 | 32 | # path to data directory 33 | cfg.dataset.src = None 34 | # dataset name 35 | cfg.dataset.name = None 36 | 37 | ### model 38 | cfg.model = OmegaConf.create({}) 39 | 40 | # model architecture name. gns, segnn, egnn 41 | cfg.model.name = None 42 | # Length of the position input sequence 43 | cfg.model.input_seq_length = 6 44 | # Number of message passing steps 45 | cfg.model.num_mp_steps = 10 46 | # Number of MLP layers 47 | cfg.model.num_mlp_layers = 2 48 | # Hidden dimension 49 | cfg.model.latent_dim = 128 50 | # whether to include velocity magnitude features 51 | cfg.model.magnitude_features = False 52 | # whether to normalize dimensions equally 53 | cfg.model.isotropic_norm = False 54 | 55 | # SEGNN only parameters 56 | # steerable attributes level 57 | cfg.model.lmax_attributes = 1 58 | # Level of the hidden layer 59 | cfg.model.lmax_hidden = 1 60 | # SEGNN normalization. instance, batch, none 61 | cfg.model.segnn_norm = "none" 62 | # SEGNN velocity aggregation. 
avg or last 63 | cfg.model.velocity_aggregate = "avg" 64 | 65 | ### training 66 | cfg.train = OmegaConf.create({}) 67 | 68 | # batch size 69 | cfg.train.batch_size = 1 70 | # max number of training steps 71 | cfg.train.step_max = 500_000 72 | # number of workers for data loading 73 | cfg.train.num_workers = 4 74 | # standard deviation of the GNS-style noise 75 | cfg.train.noise_std = 3.0e-4 76 | 77 | # optimizer 78 | cfg.train.optimizer = OmegaConf.create({}) 79 | 80 | # initial learning rate 81 | cfg.train.optimizer.lr_start = 1.0e-4 82 | # final learning rate (after exponential decay) 83 | cfg.train.optimizer.lr_final = 1.0e-6 84 | # learning rate decay rate 85 | cfg.train.optimizer.lr_decay_rate = 0.1 86 | # number of steps to decay learning rate 87 | cfg.train.optimizer.lr_decay_steps = 1.0e5 88 | 89 | # pushforward 90 | cfg.train.pushforward = OmegaConf.create({}) 91 | 92 | # At which training step to introduce next unroll stage 93 | cfg.train.pushforward.steps = [-1, 20000, 300000, 400000] 94 | # For how many steps to unroll 95 | cfg.train.pushforward.unrolls = [0, 1, 2, 3] 96 | # Which probability ratio to keep between the unrolls 97 | cfg.train.pushforward.probs = [18, 2, 1, 1] 98 | 99 | # loss weights 100 | cfg.train.loss_weight = OmegaConf.create({}) 101 | 102 | # weight for acceleration error 103 | cfg.train.loss_weight.acc = 1.0 104 | # weight for velocity error 105 | cfg.train.loss_weight.vel = 0.0 106 | # weight for position error 107 | cfg.train.loss_weight.pos = 0.0 108 | 109 | ### evaluation 110 | cfg.eval = OmegaConf.create({}) 111 | 112 | # number of eval rollout steps. -1 is full rollout 113 | cfg.eval.n_rollout_steps = 20 114 | # whether to use the test or valid split 115 | cfg.eval.test = False 116 | # rollouts directory 117 | cfg.eval.rollout_dir = None 118 | 119 | # configs for validation during training 120 | cfg.eval.train = OmegaConf.create({}) 121 | 122 | # number of trajectories to evaluate 123 | cfg.eval.train.n_trajs = 50 124 | # stride for e_kin and sinkhorn 125 | cfg.eval.train.metrics_stride = 10 126 | # batch size 127 | cfg.eval.train.batch_size = 1 128 | # metrics to evaluate 129 | cfg.eval.train.metrics = ["mse"] 130 | # write validation rollouts. One of "none", "vtk", or "pkl" 131 | cfg.eval.train.out_type = "none" 132 | 133 | # configs for inference/testing 134 | cfg.eval.infer = OmegaConf.create({}) 135 | 136 | # number of trajectories to evaluate during inference 137 | cfg.eval.infer.n_trajs = -1 138 | # stride for e_kin and sinkhorn 139 | cfg.eval.infer.metrics_stride = 1 140 | # batch size 141 | cfg.eval.infer.batch_size = 2 142 | # metrics for inference 143 | cfg.eval.infer.metrics = ["mse", "e_kin", "sinkhorn"] 144 | # write inference rollouts. 
One of "none", "vtk", or "pkl" 145 | cfg.eval.infer.out_type = "pkl" 146 | 147 | # number of extrapolation steps during inference 148 | cfg.eval.infer.n_extrap_steps = 0 149 | 150 | ### logging 151 | cfg.logging = OmegaConf.create({}) 152 | 153 | # number of steps between loggings 154 | cfg.logging.log_steps = 1000 155 | # number of steps between evaluations and checkpoints 156 | cfg.logging.eval_steps = 10000 157 | # wandb enable 158 | cfg.logging.wandb = False 159 | # wandb project name 160 | cfg.logging.wandb_project = None 161 | # wandb entity name 162 | cfg.logging.wandb_entity = "lagrangebench" 163 | # checkpoint directory 164 | cfg.logging.ckp_dir = "ckp" 165 | # name of training run 166 | cfg.logging.run_name = None 167 | 168 | ### neighbor list 169 | cfg.neighbors = OmegaConf.create({}) 170 | 171 | # backend for neighbor list computation 172 | cfg.neighbors.backend = "jaxmd_vmap" 173 | # multiplier for neighbor list capacity 174 | cfg.neighbors.multiplier = 1.25 175 | 176 | return cfg 177 | 178 | 179 | defaults = set_defaults() 180 | 181 | 182 | def check_cfg(cfg: DictConfig): 183 | """Check if the configs are valid.""" 184 | 185 | assert cfg.mode in ["train", "infer", "all"] 186 | assert cfg.dtype in ["float32", "float64"] 187 | assert cfg.dataset.src is not None, "dataset.src must be specified." 188 | 189 | assert cfg.model.input_seq_length >= 2, "At least two positions for one past vel." 190 | 191 | pf = cfg.train.pushforward 192 | assert len(pf.steps) == len(pf.unrolls) == len(pf.probs) 193 | assert all([s >= 0 for s in pf.unrolls]), "All unrolls must be non-negative." 194 | assert all([s >= 0 for s in pf.probs]), "All probabilities must be non-negative." 195 | lwv = cfg.train.loss_weight.values() 196 | assert all([w >= 0 for w in lwv]), "All loss weights must be non-negative." 197 | assert sum(lwv) > 0, "At least one loss weight must be non-zero." 198 | 199 | assert cfg.eval.train.n_trajs >= -1 200 | assert cfg.eval.infer.n_trajs >= -1 201 | assert set(cfg.eval.train.metrics).issubset(["mse", "e_kin", "sinkhorn"]) 202 | assert set(cfg.eval.infer.metrics).issubset(["mse", "e_kin", "sinkhorn"]) 203 | assert cfg.eval.train.out_type in ["none", "vtk", "pkl"] 204 | assert cfg.eval.infer.out_type in ["none", "vtk", "pkl"] 205 | -------------------------------------------------------------------------------- /lagrangebench/evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | """Evaluation and rollout generation tools.""" 2 | 3 | from .metrics import MetricsComputer, MetricsDict, averaged_metrics 4 | from .rollout import eval_rollout, infer 5 | 6 | __all__ = [ 7 | "MetricsComputer", 8 | "MetricsDict", 9 | "averaged_metrics", 10 | "infer", 11 | "eval_rollout", 12 | ] 13 | -------------------------------------------------------------------------------- /lagrangebench/evaluate/metrics.py: -------------------------------------------------------------------------------- 1 | """Metrics for evaluation end testing.""" 2 | 3 | import warnings 4 | from collections import defaultdict 5 | from functools import partial 6 | from typing import Callable, Dict, List, Optional 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | import numpy as np 11 | from ott.geometry.geometry import Geometry 12 | from ott.tools.sinkhorn_divergence import sinkhorn_divergence 13 | 14 | MetricsDict = Dict[str, Dict[str, jnp.ndarray]] 15 | 16 | 17 | class MetricsComputer: 18 | """ 19 | Metrics between predicted and target rollouts. 
20 | 21 | Currently implemented: 22 | * MSE, mean squared error 23 | * MAE, mean absolute error 24 | * Sinkhorn distance, measures the similarity of two particle distributions 25 | * Kinetic energy, physical quantity of interest 26 | """ 27 | 28 | METRICS = ["mse", "mae", "sinkhorn", "e_kin"] 29 | 30 | def __init__( 31 | self, 32 | active_metrics: List, 33 | dist_fn: Callable, 34 | metadata: Dict, 35 | input_seq_length: int, 36 | stride: int = 10, 37 | loss_ranges: Optional[List] = None, 38 | ot_backend: str = "ott", 39 | ): 40 | """Init the metric computer. 41 | 42 | Args: 43 | active_metrics: List of metrics to compute. 44 | dist_fn: Distance function. 45 | metadata: Metadata of the dataset. 46 | loss_ranges: List of horizon lengths to compute the loss for. 47 | input_seq_length: Length of the input sequence. 48 | stride: Rollout subsample frequency for e_kin and sinkhorn. 49 | ot_backend: Backend for sinkhorn computation. "ott" or "pot". 50 | """ 51 | if active_metrics is None: 52 | active_metrics = [] 53 | assert all([hasattr(self, metric) for metric in active_metrics]) 54 | assert ot_backend in ["ott", "pot"] 55 | 56 | self._active_metrics = active_metrics 57 | self._dist_fn = dist_fn 58 | self._dist_vmap = jax.vmap(dist_fn, in_axes=(0, 0)) 59 | self._dist_dvmap = jax.vmap(self._dist_vmap, in_axes=(0, 0)) 60 | 61 | if loss_ranges is None: 62 | loss_ranges = [1, 5, 10, 20, 50, 100] 63 | self._loss_ranges = loss_ranges 64 | self._input_seq_length = input_seq_length 65 | self._stride = stride 66 | self._metadata = metadata 67 | self.ot_backend = ot_backend 68 | 69 | def __call__( 70 | self, pred_rollout: jnp.ndarray, target_rollout: jnp.ndarray 71 | ) -> MetricsDict: 72 | """Compute the metrics between two rollouts. 73 | 74 | Args: 75 | pred_rollout: Predicted rollout. 76 | target_rollout: Target rollout. 77 | 78 | Returns: 79 | Dictionary of metrics. 
80 | """ 81 | # both rollouts of shape (traj_len - t_window, n_nodes, dim) 82 | target_rollout = jnp.asarray(target_rollout, dtype=pred_rollout.dtype) 83 | metrics = {} 84 | with warnings.catch_warnings(): 85 | warnings.simplefilter("ignore") 86 | for metric_name in self._active_metrics: 87 | metric_fn = getattr(self, metric_name) 88 | if metric_name in ["mse", "mae"]: 89 | # full rollout loss 90 | metrics[metric_name] = jax.vmap(metric_fn)( 91 | pred_rollout, target_rollout 92 | ) 93 | # shorter horizon losses 94 | for i in self._loss_ranges: 95 | if i < metrics[metric_name].shape[0]: 96 | metrics[f"{metric_name}{i}"] = metrics[metric_name][:i] 97 | 98 | elif metric_name in ["e_kin"]: 99 | dt = self._metadata["dt"] * self._metadata["write_every"] 100 | dx = self._metadata["dx"] 101 | dim = self._metadata["dim"] 102 | 103 | metric_dvmap = jax.vmap(jax.vmap(metric_fn)) 104 | 105 | # Ekin of predicted rollout 106 | velocity_rollout = self._dist_dvmap( 107 | pred_rollout[1 :: self._stride], 108 | pred_rollout[0 : -1 : self._stride], 109 | ) 110 | e_kin_pred = metric_dvmap(velocity_rollout / dt).sum(1) 111 | e_kin_pred = e_kin_pred * dx**dim 112 | 113 | # Ekin of target rollout 114 | velocity_rollout = self._dist_dvmap( 115 | target_rollout[1 :: self._stride], 116 | target_rollout[0 : -1 : self._stride], 117 | ) 118 | e_kin_target = metric_dvmap(velocity_rollout / dt).sum(1) 119 | e_kin_target = e_kin_target * dx**dim 120 | 121 | metrics[metric_name] = { 122 | "predicted": e_kin_pred, 123 | "target": e_kin_target, 124 | "mse": ((e_kin_pred - e_kin_target) ** 2).mean(), 125 | } 126 | 127 | elif metric_name == "sinkhorn": 128 | # vmapping over distance matrix blows up 129 | metrics[metric_name] = jax.lax.scan( 130 | lambda _, x: (None, self.sinkhorn(*x)), 131 | None, 132 | ( 133 | pred_rollout[0 :: self._stride], 134 | target_rollout[0 :: self._stride], 135 | ), 136 | )[1] 137 | return metrics 138 | 139 | @partial(jax.jit, static_argnums=(0,)) 140 | def mse(self, pred: jnp.ndarray, target: jnp.ndarray) -> float: 141 | """Compute the mean squared error between two rollouts.""" 142 | return (self._dist_vmap(pred, target) ** 2).mean() 143 | 144 | @partial(jax.jit, static_argnums=(0,)) 145 | def mae(self, pred: jnp.ndarray, target: jnp.ndarray) -> float: 146 | """Compute the mean absolute error between two rollouts.""" 147 | return (jnp.abs(self._dist_vmap(pred, target))).mean() 148 | 149 | @partial(jax.jit, static_argnums=(0,)) 150 | def sinkhorn(self, pred: jnp.ndarray, target: jnp.ndarray) -> float: 151 | """Compute the sinkhorn distance between two rollouts.""" 152 | if self.ot_backend == "ott": 153 | return self._sinkhorn_ott(pred, target) 154 | else: 155 | return self._sinkhorn_pot(pred, target) 156 | 157 | @partial(jax.jit, static_argnums=(0,)) 158 | def e_kin(self, frame: jnp.ndarray) -> float: 159 | """Compute the kinetic energy of a frame.""" 160 | return jnp.sum(frame**2) # * dx ** 3 161 | 162 | def _sinkhorn_ott(self, pred: jnp.ndarray, target: jnp.ndarray) -> float: 163 | # pairwise distances as cost 164 | loss_matrix_xy = self._distance_matrix(pred, target) 165 | loss_matrix_xx = self._distance_matrix(pred, pred) 166 | loss_matrix_yy = self._distance_matrix(target, target) 167 | return sinkhorn_divergence( 168 | Geometry, 169 | loss_matrix_xy, 170 | loss_matrix_xx, 171 | loss_matrix_yy, 172 | # uniform weights 173 | a=jnp.ones((pred.shape[0],)) / pred.shape[0], 174 | b=jnp.ones((target.shape[0],)) / target.shape[0], 175 | sinkhorn_kwargs={"threshold": 1e-4}, 176 | ).divergence 177 | 
178 | def _sinkhorn_pot(self, pred: jnp.ndarray, target: jnp.ndarray): 179 | """Jax-compatible POT implementation of Sinkorn.""" 180 | # equivalent to empirical_sinkhorn_divergence with custom distance computation 181 | sinkhorn_ab = self._custom_empirical_sinkorn_pot(pred, target) 182 | sinkhorn_a = self._custom_empirical_sinkorn_pot(pred, pred) 183 | sinkhorn_b = self._custom_empirical_sinkorn_pot(target, target) 184 | return jnp.asarray( 185 | jnp.clip(sinkhorn_ab - 0.5 * (sinkhorn_a + sinkhorn_b), 0), 186 | dtype=jnp.float32, 187 | ) 188 | 189 | def _custom_empirical_sinkorn_pot(self, pred: jnp.ndarray, target: jnp.ndarray): 190 | from ot.bregman import sinkhorn2 191 | 192 | # weights are uniform 193 | a, b = ( 194 | jnp.ones((pred.shape[0],)) / pred.shape[0], 195 | jnp.ones((target.shape[0],)) / target.shape[0], 196 | ) 197 | loss_matrix = self._distance_matrix(pred, target) 198 | shape = jax.ShapeDtypeStruct((), dtype=jnp.float32) 199 | 200 | # hack to avoid CpuCallback attribute error 201 | def sinkhorn2_(a, b, loss_matrix): 202 | return jnp.array( 203 | sinkhorn2(a, b, loss_matrix, reg=0.1, numItermax=500, stopThr=1e-05), 204 | dtype=jnp.float32, 205 | ) 206 | 207 | return jax.pure_callback( 208 | sinkhorn2_, 209 | shape, 210 | a, 211 | b, 212 | loss_matrix, 213 | ) 214 | 215 | def _distance_matrix( 216 | self, x: jnp.ndarray, y: jnp.ndarray, squared=True 217 | ) -> jnp.ndarray: 218 | """Euclidean distance matrix (pairwise).""" 219 | 220 | def dist(a, b): 221 | return jnp.sum(self._dist_fn(a, b) ** 2) 222 | 223 | if not squared: 224 | 225 | def dist(a, b): 226 | return jnp.sqrt(dist(a, b)) 227 | 228 | return jnp.array( 229 | jax.vmap(lambda a: jax.vmap(lambda b: dist(a, b))(y))(x), dtype=jnp.float32 230 | ) 231 | 232 | 233 | def averaged_metrics(eval_metrics: MetricsDict) -> Dict[str, float]: 234 | """Averages the metrics over the rollouts.""" 235 | # create a dictionary with the same keys as the metrics, but empty list as values 236 | trajectory_averages = defaultdict(list) 237 | for rollout in eval_metrics.values(): 238 | for k, v in rollout.items(): 239 | if k == "e_kin": 240 | v = v["mse"] 241 | if k in ["mse", "mae"]: 242 | k = "loss" 243 | trajectory_averages[k].append(jnp.mean(v).item()) 244 | 245 | # mean and std values accross rollouts 246 | small_metrics = {} 247 | for k, v in trajectory_averages.items(): 248 | small_metrics[f"val/{k}"] = float(np.mean(v)) 249 | for k, v in trajectory_averages.items(): 250 | small_metrics[f"val/std{k}"] = float(np.std(v)) 251 | 252 | return small_metrics 253 | -------------------------------------------------------------------------------- /lagrangebench/evaluate/rollout.py: -------------------------------------------------------------------------------- 1 | """Evaluation and inference functions for generating rollouts.""" 2 | 3 | import os 4 | import pickle 5 | import time 6 | from functools import partial 7 | from typing import Callable, Dict, Iterable, Optional, Tuple, Union 8 | 9 | import haiku as hk 10 | import jax 11 | import jax.numpy as jnp 12 | import jax_sph.jax_md.partition as partition 13 | from jax import jit, vmap 14 | from omegaconf import DictConfig, OmegaConf 15 | from torch.utils.data import DataLoader 16 | 17 | from lagrangebench.data import H5Dataset 18 | from lagrangebench.data.utils import numpy_collate 19 | from lagrangebench.defaults import defaults 20 | from lagrangebench.evaluate.metrics import MetricsComputer, MetricsDict 21 | from lagrangebench.evaluate.utils import write_vtk 22 | from lagrangebench.utils 
import ( 23 | broadcast_from_batch, 24 | broadcast_to_batch, 25 | get_kinematic_mask, 26 | load_haiku, 27 | set_seed, 28 | ) 29 | 30 | 31 | @partial(jit, static_argnames=["model_apply", "case_integrate"]) 32 | def _forward_eval( 33 | params: hk.Params, 34 | state: hk.State, 35 | sample: Tuple[jnp.ndarray, jnp.ndarray], 36 | current_positions: jnp.ndarray, 37 | target_positions: jnp.ndarray, 38 | model_apply: Callable, 39 | case_integrate: Callable, 40 | ) -> jnp.ndarray: 41 | """Run one update of the 'current_state' using the trained model. 42 | 43 | Args: 44 | params: Haiku model parameters. 45 | state: Haiku model state. 46 | current_positions: Set of historic positions of shape (n_nodes, t_window, dim). 47 | target_positions: Used to get the next state of kinematic particles, i.e. those 48 | that are not updated by the ML model, e.g. boundary particles. 49 | model_apply: Model function. 50 | case_integrate: Integration function from case.integrate. 51 | 52 | Returns: 53 | current_positions: The historic position sequence shifted by one step, i.e. 54 | extended by the newly computed most recent position. 55 | """ 56 | _, particle_type = sample 57 | 58 | # predict acceleration and integrate 59 | pred, state = model_apply(params, state, sample) 60 | 61 | next_position = case_integrate(pred, current_positions) 62 | 63 | # update only the positions of non-boundary particles 64 | kinematic_mask = get_kinematic_mask(particle_type) 65 | next_position = jnp.where( 66 | kinematic_mask[:, None], 67 | target_positions, 68 | next_position, 69 | ) 70 | 71 | current_positions = jnp.concatenate( 72 | [current_positions[:, 1:], next_position[:, None, :]], axis=1 73 | ) # as next model input 74 | 75 | return current_positions, state 76 | 77 | 78 | def _eval_batched_rollout( 79 | forward_eval_vmap: Callable, 80 | preprocess_eval_vmap: Callable, 81 | case, 82 | params: hk.Params, 83 | state: hk.State, 84 | traj_batch_i: Tuple[jnp.ndarray, jnp.ndarray], 85 | neighbors: partition.NeighborList, 86 | metrics_computer_vmap: Callable, 87 | n_rollout_steps: int, 88 | t_window: int, 89 | n_extrap_steps: int = 0, 90 | ) -> Tuple[jnp.ndarray, MetricsDict, jnp.ndarray]: 91 | """Compute the rollout on a batch of trajectories. 92 | 93 | Args: 94 | forward_eval_vmap: Vectorized model update function. 95 | case: CaseSetupFn class. 96 | params: Haiku params. 97 | state: Haiku state. 98 | traj_batch_i: Batch of trajectories to evaluate. 99 | neighbors: Neighbor list. 100 | metrics_computer_vmap: Vectorized MetricsComputer with the desired metrics. 101 | n_rollout_steps: Number of rollout steps. 102 | t_window: Length of the input sequence. 103 | n_extrap_steps: Number of extrapolation steps (beyond the ground truth rollout). 104 | 105 | Returns: 106 | A tuple with (predicted rollout, metrics, neighbor list). 
107 | """ 108 | # particle type is treated as a static property defined by state at t=0 109 | pos_input_batch, particle_type_batch = traj_batch_i 110 | # current_batch_size might be < eval_batch_size if the last batch is not full 111 | current_batch_size, n_nodes_max, _, dim = pos_input_batch.shape 112 | 113 | # if n_rollout_steps set to -1, use the whole trajectory 114 | if n_rollout_steps == -1: 115 | n_rollout_steps = pos_input_batch.shape[2] - t_window 116 | 117 | current_positions_batch = pos_input_batch[:, :, 0:t_window] 118 | # (batch, n_nodes, t_window, dim) 119 | traj_len = n_rollout_steps + n_extrap_steps 120 | target_positions_batch = pos_input_batch[:, :, t_window : t_window + traj_len] 121 | 122 | predictions_batch = jnp.zeros((current_batch_size, traj_len, n_nodes_max, dim)) 123 | neighbors_batch = broadcast_to_batch(neighbors, current_batch_size) 124 | 125 | step = 0 126 | while step < n_rollout_steps + n_extrap_steps: 127 | sample_batch = (current_positions_batch, particle_type_batch) 128 | 129 | # 1. preprocess features 130 | features_batch, neighbors_batch = preprocess_eval_vmap( 131 | sample_batch, neighbors_batch 132 | ) 133 | 134 | # 2. check whether list overflowed and fix it if so 135 | if neighbors_batch.did_buffer_overflow.sum() > 0: 136 | # check if the neighbor list is too small for any of the samples 137 | # if so, reallocate the neighbor list 138 | 139 | print(f"(eval) Reallocate neighbors list at step {step}") 140 | ind = jnp.argmax(neighbors_batch.did_buffer_overflow) 141 | sample = broadcast_from_batch(sample_batch, index=ind) 142 | 143 | _, nbrs_temp = case.allocate_eval(sample) 144 | print( 145 | f"(eval) From {neighbors_batch.idx[ind].shape} to {nbrs_temp.idx.shape}" 146 | ) 147 | neighbors_batch = broadcast_to_batch(nbrs_temp, current_batch_size) 148 | 149 | # To run the loop N times even if sometimes 150 | # did_buffer_overflow > 0 we directly return to the beginning 151 | continue 152 | 153 | # 3. run forward model 154 | current_positions_batch, state_batch = forward_eval_vmap( 155 | params, 156 | state, 157 | (features_batch, particle_type_batch), 158 | current_positions_batch, 159 | target_positions_batch[:, :, step], 160 | ) 161 | # the state is not passed out of this loop, so no not really relevant 162 | state = broadcast_from_batch(state_batch, 0) 163 | 164 | # 4. write predicted next position to output array 165 | predictions_batch = predictions_batch.at[:, step].set( 166 | current_positions_batch[:, :, -1] # most recently predicted positions 167 | ) 168 | 169 | step += 1 170 | 171 | # (batch, n_nodes, time, dim) -> (batch, time, n_nodes, dim) 172 | target_positions_batch = target_positions_batch.transpose(0, 2, 1, 3) 173 | # slice out extrapolation steps 174 | metrics_batch = metrics_computer_vmap( 175 | predictions_batch[:, :n_rollout_steps, :, :], target_positions_batch 176 | ) 177 | 178 | return (predictions_batch, metrics_batch, broadcast_from_batch(neighbors_batch, 0)) 179 | 180 | 181 | def eval_rollout( 182 | model_apply: Callable, 183 | case, 184 | params: hk.Params, 185 | state: hk.State, 186 | loader_eval: Iterable, 187 | neighbors: partition.NeighborList, 188 | metrics_computer: MetricsComputer, 189 | n_rollout_steps: int, 190 | n_trajs: int, 191 | rollout_dir: str, 192 | out_type: str = "none", 193 | n_extrap_steps: int = 0, 194 | ) -> MetricsDict: 195 | """Compute the rollout and evaluate the metrics. 196 | 197 | Args: 198 | model_apply: Model function. 199 | case: CaseSetupFn class. 200 | params: Haiku params. 
201 | state: Haiku state. 202 | loader_eval: Evaluation data loader. 203 | neighbors: Neighbor list. 204 | metrics_computer: MetricsComputer with the desired metrics. 205 | n_rollout_steps: Number of rollout steps. 206 | n_trajs: Number of ground truth trajectories to evaluate. 207 | rollout_dir: Parent directory path where to store the rollout and metrics dict. 208 | out_type: Output type. Either "none", "vtk" or "pkl". 209 | n_extrap_steps: Number of extrapolation steps (beyond the ground truth rollout). 210 | 211 | Returns: 212 | Metrics per trajectory. 213 | """ 214 | batch_size = loader_eval.batch_size 215 | t_window = loader_eval.dataset.input_seq_length 216 | eval_metrics = {} 217 | 218 | if rollout_dir is not None: 219 | os.makedirs(rollout_dir, exist_ok=True) 220 | 221 | forward_eval = partial( 222 | _forward_eval, 223 | model_apply=model_apply, 224 | case_integrate=case.integrate, 225 | ) 226 | forward_eval_vmap = vmap(forward_eval, in_axes=(None, None, 0, 0, 0)) 227 | preprocess_eval_vmap = vmap(case.preprocess_eval, in_axes=(0, 0)) 228 | metrics_computer_vmap = vmap(metrics_computer, in_axes=(0, 0)) 229 | 230 | for i, traj_batch_i in enumerate(loader_eval): 231 | # if n_trajs is not a multiple of batch_size, we slice from the last batch 232 | n_traj_left = n_trajs - i * batch_size 233 | if n_traj_left < batch_size: 234 | traj_batch_i = jax.tree_map(lambda x: x[:n_traj_left], traj_batch_i) 235 | 236 | # numpy to jax 237 | traj_batch_i = jax.tree_map(lambda x: jnp.array(x), traj_batch_i) 238 | # (pos_input_batch, particle_type_batch) = traj_batch_i 239 | # pos_input_batch.shape = (batch, num_particles, seq_length, dim) 240 | 241 | example_rollout_batch, metrics_batch, neighbors = _eval_batched_rollout( 242 | forward_eval_vmap=forward_eval_vmap, 243 | preprocess_eval_vmap=preprocess_eval_vmap, 244 | case=case, 245 | params=params, 246 | state=state, 247 | traj_batch_i=traj_batch_i, # (batch, nodes, t, dim) 248 | neighbors=neighbors, 249 | metrics_computer_vmap=metrics_computer_vmap, 250 | n_rollout_steps=n_rollout_steps, 251 | t_window=t_window, 252 | n_extrap_steps=n_extrap_steps, 253 | ) 254 | 255 | current_batch_size = traj_batch_i[0].shape[0] 256 | for j in range(current_batch_size): 257 | # write metrics to output dictionary 258 | ind = i * batch_size + j 259 | eval_metrics[f"rollout_{ind}"] = broadcast_from_batch(metrics_batch, j) 260 | 261 | if rollout_dir is not None: 262 | # (batch, nodes, t, dim) -> (batch, t, nodes, dim) 263 | pos_input_batch = traj_batch_i[0].transpose(0, 2, 1, 3) 264 | 265 | for j in range(current_batch_size): # write every trajectory to file 266 | pos_input = pos_input_batch[j] 267 | example_rollout = example_rollout_batch[j] 268 | 269 | initial_positions = pos_input[:t_window] 270 | example_full = jnp.concatenate([initial_positions, example_rollout]) 271 | example_rollout = { 272 | "predicted_rollout": example_full, # (t + extrap, nodes, dim) 273 | "ground_truth_rollout": pos_input, # (t, nodes, dim), 274 | "particle_type": traj_batch_i[1][j], # (nodes,) 275 | } 276 | 277 | file_prefix = os.path.join(rollout_dir, f"rollout_{i*batch_size+j}") 278 | if out_type == "vtk": # write vtk files for each time step 279 | for k in range(example_full.shape[0]): 280 | # predictions 281 | state_vtk = { 282 | "r": example_rollout["predicted_rollout"][k], 283 | "tag": example_rollout["particle_type"], 284 | } 285 | write_vtk(state_vtk, f"{file_prefix}_{k}.vtk") 286 | for k in range(pos_input.shape[0]): 287 | # ground truth reference 288 | ref_state_vtk = { 289 | 
"r": example_rollout["ground_truth_rollout"][k], 290 | "tag": example_rollout["particle_type"], 291 | } 292 | write_vtk(ref_state_vtk, f"{file_prefix}_ref_{k}.vtk") 293 | elif out_type == "pkl": 294 | filename = f"{file_prefix}.pkl" 295 | 296 | with open(filename, "wb") as f: 297 | pickle.dump(example_rollout, f) 298 | 299 | if (i * batch_size + j + 1) >= n_trajs: 300 | break 301 | 302 | if rollout_dir is not None: 303 | # save metrics 304 | t = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) 305 | with open(f"{rollout_dir}/metrics{t}.pkl", "wb") as f: 306 | pickle.dump(eval_metrics, f) 307 | 308 | return eval_metrics 309 | 310 | 311 | def infer( 312 | model: hk.TransformedWithState, 313 | case, 314 | data_test: H5Dataset, 315 | params: Optional[hk.Params] = None, 316 | state: Optional[hk.State] = None, 317 | load_ckp: Optional[str] = None, 318 | cfg_eval_infer: Union[Dict, DictConfig] = defaults.eval.infer, 319 | rollout_dir: Optional[str] = defaults.eval.rollout_dir, 320 | n_rollout_steps: int = defaults.eval.n_rollout_steps, 321 | seed: int = defaults.seed, 322 | ): 323 | """ 324 | Infer on a dataset, compute metrics and optionally save rollout in out_type format. 325 | 326 | Args: 327 | model: (Transformed) Haiku model. 328 | case: Case setup class. 329 | data_test: Test dataset. 330 | params: Haiku params. 331 | state: Haiku state. 332 | load_ckp: Path to checkpoint directory. 333 | rollout_dir: Path to rollout directory. 334 | cfg_eval_infer: Evaluation configuration for inference mode. 335 | n_rollout_steps: Number of rollout steps. 336 | seed: Seed. 337 | 338 | Returns: 339 | eval_metrics: Metrics per trajectory. 340 | """ 341 | assert ( 342 | params is not None or load_ckp is not None 343 | ), "Either params or a load_ckp directory must be provided for inference." 
344 | 345 | if isinstance(cfg_eval_infer, Dict): 346 | cfg_eval_infer = OmegaConf.create(cfg_eval_infer) 347 | 348 | # if one of the cfg_* arguments has a subset of the default configs, merge them 349 | cfg_eval_infer = OmegaConf.merge(defaults.eval.infer, cfg_eval_infer) 350 | 351 | n_trajs = cfg_eval_infer.n_trajs 352 | if n_trajs == -1: 353 | n_trajs = data_test.num_samples 354 | 355 | if params is not None: 356 | if state is None: 357 | state = {} 358 | else: 359 | params, state, _, _ = load_haiku(load_ckp) 360 | 361 | key, seed_worker, generator = set_seed(seed) 362 | 363 | loader_test = DataLoader( 364 | dataset=data_test, 365 | batch_size=cfg_eval_infer.batch_size, 366 | collate_fn=numpy_collate, 367 | worker_init_fn=seed_worker, 368 | generator=generator, 369 | ) 370 | metrics_computer = MetricsComputer( 371 | cfg_eval_infer.metrics, 372 | dist_fn=case.displacement, 373 | metadata=data_test.metadata, 374 | input_seq_length=data_test.input_seq_length, 375 | stride=cfg_eval_infer.metrics_stride, 376 | ) 377 | # Precompile model 378 | model_apply = jit(model.apply) 379 | 380 | # init values 381 | pos_input_and_target, particle_type = next(iter(loader_test)) 382 | sample = (pos_input_and_target[0], particle_type[0]) 383 | key, _, _, neighbors = case.allocate(key, sample) 384 | 385 | eval_metrics = eval_rollout( 386 | model_apply=model_apply, 387 | case=case, 388 | metrics_computer=metrics_computer, 389 | params=params, 390 | state=state, 391 | neighbors=neighbors, 392 | loader_eval=loader_test, 393 | n_rollout_steps=n_rollout_steps, 394 | n_trajs=n_trajs, 395 | rollout_dir=rollout_dir, 396 | out_type=cfg_eval_infer.out_type, 397 | n_extrap_steps=cfg_eval_infer.n_extrap_steps, 398 | ) 399 | return eval_metrics 400 | -------------------------------------------------------------------------------- /lagrangebench/evaluate/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for evaluation.""" 2 | 3 | import os 4 | import pickle 5 | 6 | import numpy as np 7 | 8 | 9 | def write_vtk(data_dict, path): 10 | """Store a .vtk file for ParaView.""" 11 | 12 | try: 13 | import pyvista 14 | except ImportError: 15 | raise ImportError("Please install pyvista to write VTK files.") 16 | 17 | r = np.asarray(data_dict["r"]) 18 | N, dim = r.shape 19 | 20 | # PyVista treats the position information differently than the rest 21 | if dim == 2: 22 | r = np.hstack([r, np.zeros((N, 1))]) 23 | data_pv = pyvista.PolyData(r) 24 | 25 | # copy all the other information also to pyvista, using plain numpy arrays 26 | for k, v in data_dict.items(): 27 | # skip r because we already considered it above 28 | if k == "r": 29 | continue 30 | 31 | # working in 3D or scalar features do not require special care 32 | if dim == 2 and v.ndim == 2: 33 | v = np.hstack([v, np.zeros((N, 1))]) 34 | 35 | data_pv[k] = np.asarray(v) 36 | 37 | data_pv.save(path) 38 | 39 | 40 | def pkl2vtk(src_path, dst_path=None): 41 | """Convert a rollout pickle file to a set of vtk files. 42 | 43 | Args: 44 | src_path (str): Source path to .pkl file. 45 | dst_path (str, optoinal): Destination directory path. Defaults to None. 46 | If None, then the vtk files are saved in the same directory as the pkl file. 47 | 48 | Example: 49 | pkl2vtk("rollout/test/rollout_0.pkl", "rollout/test_vtk") 50 | will create files rollout_0_0.vtk, rollout_0_1.vtk, etc. 
in the directory 51 | "rollout/test_vtk" 52 | """ 53 | 54 | # set up destination directory 55 | if dst_path is None: 56 | dst_path = os.path.dirname(src_path) 57 | os.makedirs(dst_path, exist_ok=True) 58 | 59 | # load rollout 60 | with open(src_path, "rb") as f: 61 | rollout = pickle.load(f) 62 | 63 | file_prefix = os.path.join(dst_path, os.path.basename(src_path).split(".")[0]) 64 | for k in range(rollout["predicted_rollout"].shape[0]): 65 | # predictions 66 | state_vtk = { 67 | "r": rollout["predicted_rollout"][k], 68 | "tag": rollout["particle_type"], 69 | } 70 | write_vtk(state_vtk, f"{file_prefix}_{k}.vtk") 71 | # ground truth reference 72 | state_vtk = { 73 | "r": rollout["ground_truth_rollout"][k], 74 | "tag": rollout["particle_type"], 75 | } 76 | write_vtk(state_vtk, f"{file_prefix}_ref_{k}.vtk") 77 | -------------------------------------------------------------------------------- /lagrangebench/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Baseline models.""" 2 | 3 | from .egnn import EGNN 4 | from .gns import GNS 5 | from .linear import Linear 6 | from .painn import PaiNN 7 | from .segnn import SEGNN 8 | 9 | __all__ = ["GNS", "SEGNN", "EGNN", "PaiNN", "Linear"] 10 | -------------------------------------------------------------------------------- /lagrangebench/models/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, Tuple 3 | 4 | import haiku as hk 5 | import jax.numpy as jnp 6 | 7 | 8 | class BaseModel(hk.Module, ABC): 9 | """Base model class. All models must inherit from this class.""" 10 | 11 | @abstractmethod 12 | def __call__( 13 | self, sample: Tuple[Dict[str, jnp.ndarray], jnp.ndarray] 14 | ) -> Dict[str, jnp.ndarray]: 15 | """Forward pass. 16 | 17 | We specify the dimensions of the inputs and outputs using the number of nodes N, 18 | the number of edges E, number of historic velocities K (=input_seq_length - 1), 19 | and the dimensionality of the feature vectors dim. 20 | 21 | Args: 22 | sample: Tuple with feature dictionary and particle type. Possible features 23 | 24 | - "abs_pos" (N, K+1, dim), absolute positions 25 | - "vel_hist" (N, K*dim), historical velocity sequence 26 | - "vel_mag" (N,), velocity magnitudes 27 | - "bound" (N, 2*dim), distance to boundaries 28 | - "force" (N, dim), external force field 29 | - "rel_disp" (E, dim), relative displacement vectors 30 | - "rel_dist" (E, 1), relative distances, i.e. magnitude of displacements 31 | - "senders" (E), sender indices 32 | - "receivers" (E), receiver indices 33 | Returns: 34 | Dict with model output. 35 | The keys must be at least one of the following: 36 | 37 | - "acc" (N, dim), (normalized) acceleration 38 | - "vel" (N, dim), (normalized) velocity 39 | - "pos" (N, dim), (absolute) next position 40 | """ 41 | raise NotImplementedError 42 | -------------------------------------------------------------------------------- /lagrangebench/models/egnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | E(n) equivariant GNN from `Garcia Satorras et al. `_. 3 | EGNN model, layers and feature transform. 
4 | 5 | Original implementation: https://github.com/vgsatorras/egnn 6 | 7 | Standalone implementation + validation: https://github.com/gerkone/egnn-jax 8 | """ 9 | 10 | from typing import Any, Callable, Dict, Optional, Tuple 11 | 12 | import haiku as hk 13 | import jax 14 | import jax.numpy as jnp 15 | import jraph 16 | from jax.tree_util import Partial 17 | from jax_sph.jax_md import space 18 | 19 | from lagrangebench.utils import NodeType 20 | 21 | from .base import BaseModel 22 | from .utils import LinearXav, MLPXav 23 | 24 | 25 | class EGNNLayer(hk.Module): 26 | r"""E(n)-equivariant EGNN layer. 27 | 28 | Applies a message passing step where the positions are corrected with the velocities 29 | and a learnable correction term :math:`\psi_x(\mathbf{h}_i^{(t+1)})`: 30 | """ 31 | 32 | def __init__( 33 | self, 34 | layer_num: int, 35 | hidden_size: int, 36 | output_size: int, 37 | displacement_fn: space.DisplacementFn, 38 | shift_fn: space.ShiftFn, 39 | blocks: int = 1, 40 | act_fn: Callable = jax.nn.silu, 41 | pos_aggregate_fn: Optional[Callable] = jraph.segment_sum, 42 | msg_aggregate_fn: Optional[Callable] = jraph.segment_sum, 43 | residual: bool = True, 44 | attention: bool = False, 45 | normalize: bool = False, 46 | tanh: bool = False, 47 | dt: float = 0.001, 48 | eps: float = 1e-8, 49 | ): 50 | """Initialize the layer. 51 | 52 | Args: 53 | layer_num: layer number 54 | hidden_size: hidden size 55 | output_size: output size 56 | displacement_fn: Displacement function for the acceleration computation. 57 | shift_fn: Shift function for updating positions 58 | blocks: number of blocks in the node and edge MLPs 59 | act_fn: activation function 60 | pos_aggregate_fn: position aggregation function 61 | msg_aggregate_fn: message aggregation function 62 | residual: whether to use residual connections 63 | attention: whether to use attention 64 | normalize: whether to normalize the coordinates 65 | tanh: whether to use tanh in the position update 66 | dt: position update step size 67 | eps: small number to avoid division by zero 68 | """ 69 | super().__init__(f"layer_{layer_num}") 70 | 71 | self._displacement_fn = displacement_fn 72 | self._shift_fn = shift_fn 73 | self.pos_aggregate_fn = pos_aggregate_fn 74 | self.msg_aggregate_fn = msg_aggregate_fn 75 | self._residual = residual 76 | self._normalize = normalize 77 | self._eps = eps 78 | 79 | # message network 80 | self._edge_mlp = MLPXav( 81 | [hidden_size] * blocks + [hidden_size], 82 | activation=act_fn, 83 | activate_final=True, 84 | ) 85 | 86 | # update network 87 | self._node_mlp = MLPXav( 88 | [hidden_size] * blocks + [output_size], 89 | activation=act_fn, 90 | activate_final=False, 91 | ) 92 | 93 | # position update network 94 | net = [LinearXav(hidden_size)] * blocks 95 | # NOTE: from https://github.com/vgsatorras/egnn/blob/main/models/gcl.py#L254 96 | net += [ 97 | act_fn, 98 | LinearXav(1, with_bias=False, w_init=hk.initializers.UniformScaling(dt)), 99 | ] 100 | if tanh: 101 | net.append(jax.nn.tanh) 102 | self._pos_correction_mlp = hk.Sequential(net) 103 | 104 | # velocity integrator network 105 | net = [LinearXav(hidden_size)] * blocks 106 | net += [ 107 | act_fn, 108 | LinearXav(1, with_bias=False, w_init=hk.initializers.UniformScaling(dt)), 109 | ] 110 | self._vel_correction_mlp = hk.Sequential(net) 111 | 112 | # attention 113 | self._attention_mlp = None 114 | if attention: 115 | self._attention_mlp = hk.Sequential( 116 | [LinearXav(hidden_size), jax.nn.sigmoid] 117 | ) 118 | 119 | def _pos_update( 120 | self, 121 | pos: 
jnp.ndarray, 122 | graph: jraph.GraphsTuple, 123 | coord_diff: jnp.ndarray, 124 | ) -> jnp.ndarray: 125 | trans = coord_diff * self._pos_correction_mlp(graph.edges) 126 | return self.pos_aggregate_fn(trans, graph.senders, num_segments=pos.shape[0]) 127 | 128 | def _message( 129 | self, 130 | radial: jnp.ndarray, 131 | edge_attribute: jnp.ndarray, 132 | edge_features: Any, 133 | incoming: jnp.ndarray, 134 | outgoing: jnp.ndarray, 135 | globals_: Any, 136 | ) -> jnp.ndarray: 137 | _ = edge_features 138 | _ = globals_ 139 | msg = jnp.concatenate([incoming, outgoing, radial], axis=-1) 140 | if edge_attribute is not None: 141 | msg = jnp.concatenate([msg, edge_attribute], axis=-1) 142 | msg = self._edge_mlp(msg) 143 | if self._attention_mlp: 144 | att = self._attention_mlp(msg) 145 | msg = msg * att 146 | return msg 147 | 148 | def _update( 149 | self, 150 | node_attribute: jnp.ndarray, 151 | nodes: jnp.ndarray, 152 | senders: Any, 153 | msg: jnp.ndarray, 154 | globals_: Any, 155 | ) -> jnp.ndarray: 156 | _ = senders 157 | _ = globals_ 158 | x = jnp.concatenate([nodes, msg], axis=-1) 159 | if node_attribute is not None: 160 | x = jnp.concatenate([x, node_attribute], axis=-1) 161 | x = self._node_mlp(x) 162 | if self._residual: 163 | x = nodes + x 164 | return x 165 | 166 | def _coord2radial( 167 | self, graph: jraph.GraphsTuple, coord: jnp.array 168 | ) -> Tuple[jnp.array, jnp.array]: 169 | coord_diff = self._displacement_fn(coord[graph.senders], coord[graph.receivers]) 170 | radial = jnp.sum(coord_diff**2, 1)[:, jnp.newaxis] 171 | if self._normalize: 172 | norm = jnp.sqrt(radial) 173 | coord_diff = coord_diff / (norm + self._eps) 174 | return radial, coord_diff 175 | 176 | def __call__( 177 | self, 178 | graph: jraph.GraphsTuple, 179 | pos: jnp.ndarray, 180 | vel: jnp.ndarray, 181 | edge_attribute: Optional[jnp.ndarray] = None, 182 | node_attribute: Optional[jnp.ndarray] = None, 183 | ) -> Tuple[jraph.GraphsTuple, jnp.ndarray]: 184 | """ 185 | Apply EGNN layer. 186 | 187 | Args: 188 | graph: Graph from previous step 189 | pos: Node position, updated separately 190 | vel: Node velocity 191 | edge_attribute: Edge attribute (optional) 192 | node_attribute: Node attribute (optional) 193 | Returns: 194 | Updated graph, node position 195 | """ 196 | radial, coord_diff = self._coord2radial(graph, pos) 197 | graph = jraph.GraphNetwork( 198 | update_edge_fn=Partial(self._message, radial, edge_attribute), 199 | update_node_fn=Partial(self._update, node_attribute), 200 | aggregate_edges_for_nodes_fn=self.msg_aggregate_fn, 201 | )(graph) 202 | # update position 203 | pos = self._shift_fn(pos, self._pos_update(pos, graph, coord_diff)) 204 | # integrate velocity 205 | pos = self._shift_fn(pos, self._vel_correction_mlp(graph.nodes) * vel) 206 | return graph, pos 207 | 208 | 209 | class EGNN(BaseModel): 210 | r""" 211 | E(n) Graph Neural Network by 212 | `Garcia Satorras et al. `_. 213 | 214 | EGNN doesn't require expensive higher-order representations in intermediate layers; 215 | instead it relies on separate scalar and vector channels, which are treated 216 | differently by EGNN layers. In this setup, EGNN is similar to a learnable numerical 217 | integrator: 218 | 219 | .. 
math:: 220 | \begin{align} 221 | \mathbf{m}_{ij}^{(t+1)} &= \phi_e \left( 222 | \mathbf{m}_{ij}^{(t)}, \mathbf{h}_i^{(t)}, 223 | \mathbf{h}_j^{(t)}, ||\mathbf{x}_i^{(t)} - \mathbf{x}_j^{(t)}||^2 224 | \right) \\ 225 | \mathbf{\hat{m}}_{ij}^{(t+1)} &= 226 | (\mathbf{x}_i^{(t)} - \mathbf{x}_j^{(t)}) \phi_x(\mathbf{m}_{ij}^{(t+1)}) 227 | \end{align} 228 | 229 | And the node update with the integrator 230 | 231 | .. math:: 232 | \begin{align} 233 | \mathbf{h}_i^{(t+1)} &= \psi_h \left( 234 | \mathbf{h}_i^{(t)}, \sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij}^{(t+1)} 235 | \right) \\ 236 | \mathbf{x}_i^{(t+1)} &= \mathbf{x}_i^{(t)} 237 | + \mathbf{\hat{m}}_{ij}^{(t+1)} \psi_x(\mathbf{h}_i^{(t+1)}) 238 | \end{align} 239 | 240 | where :math:`\mathbf{m}_{ij}` and :math:`\mathbf{\hat{m}}_{ij}` are the scalar and 241 | vector messages respectively, and :math:`\mathbf{x}_{i}` are the positions. 242 | 243 | This implementation differs from the original in two places: 244 | 245 | - because our datasets can have periodic boundary conditions, we use shift and 246 | displacement functions that take care of it when operations on positions are done. 247 | - we apply a simple integrator after the last layer to get the acceleration. 248 | """ 249 | 250 | def __init__( 251 | self, 252 | hidden_size: int, 253 | output_size: int, 254 | dt: float, 255 | n_vels: int, 256 | displacement_fn: space.DisplacementFn, 257 | shift_fn: space.ShiftFn, 258 | normalization_stats: Optional[Dict[str, jnp.ndarray]] = None, 259 | act_fn: Callable = jax.nn.silu, 260 | num_mp_steps: int = 4, 261 | homogeneous_particles: bool = True, 262 | residual: bool = True, 263 | attention: bool = False, 264 | normalize: bool = False, 265 | tanh: bool = False, 266 | ): 267 | r""" 268 | Initialize the network. 269 | 270 | Args: 271 | hidden_size: Number of hidden features. 272 | output_size: Number of features for 'h' at the output. 273 | dt: Time step for position and velocity integration. Used to rescale the 274 | initialization of the correction MLP. 275 | n_vels: Number of velocities in the history. 276 | displacement_fn: Displacement function for the acceleration computation. 277 | shift_fn: Shift function for updating positions. 278 | normalization_stats: Normalization statistics for the input data. 279 | act_fn: Non-linearity. 280 | num_mp_steps: Number of layer for the EGNN 281 | homogeneous_particles: If all particles are of homogeneous type. 282 | residual: Whether to use residual connections. 283 | attention: Whether to use attention or not. 284 | normalize: Normalizes the coordinates messages such that: 285 | ``x^{l+1}_i = x^{l}_i + \sum(x_i - x_j)\phi_x(m_{ij})\|x_i - x_j\|`` 286 | It may help in the stability or generalization. Not used in the paper. 287 | tanh: Sets a tanh activation function at the output of ``\phi_x(m_{ij})``. 288 | It bounds the output of ``\phi_x(m_{ij})`` which definitely improves in 289 | stability but it may decrease in accuracy. Not used in the paper. 
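
        Example:
            Hedged sketch of setting up an EGNN on a periodic unit box (the box,
            ``dt`` and all hyperparameter values below are illustrative assumptions,
            not library defaults)::

                import haiku as hk
                import jax
                import jax.numpy as jnp
                from jax_sph.jax_md import space

                displacement_fn, shift_fn = space.periodic(jnp.ones(2))
                # EGNN expects per-particle (vmapped) displacement/shift functions
                displacement_fn = jax.vmap(displacement_fn, in_axes=(0, 0))
                shift_fn = jax.vmap(shift_fn, in_axes=(0, 0))

                def egnn_fn(x):
                    return EGNN(
                        hidden_size=64,
                        output_size=1,
                        dt=0.01,
                        n_vels=5,
                        displacement_fn=displacement_fn,
                        shift_fn=shift_fn,
                    )(x)

                model = hk.without_apply_rng(hk.transform_with_state(egnn_fn))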
290 | """ 291 | super().__init__() 292 | # network 293 | self._hidden_size = hidden_size 294 | self._output_size = output_size 295 | self._act_fn = act_fn 296 | self._num_mp_steps = num_mp_steps 297 | self._residual = residual 298 | self._attention = attention 299 | self._normalize = normalize 300 | self._tanh = tanh 301 | 302 | # integrator 303 | self._dt = dt / self._num_mp_steps 304 | self._displacement_fn = displacement_fn 305 | self._shift_fn = shift_fn 306 | if normalization_stats is None: 307 | normalization_stats = { 308 | "velocity": {"mean": 0.0, "std": 1.0}, 309 | "acceleration": {"mean": 0.0, "std": 1.0}, 310 | } 311 | self._vel_stats = normalization_stats["velocity"] 312 | self._acc_stats = normalization_stats["acceleration"] 313 | 314 | # transform 315 | self._n_vels = n_vels 316 | self._homogeneous_particles = homogeneous_particles 317 | 318 | def _transform( 319 | self, features: Dict[str, jnp.ndarray], particle_type: jnp.ndarray 320 | ) -> Tuple[jraph.GraphsTuple, Dict[str, jnp.ndarray]]: 321 | props = {} 322 | n_nodes = features["vel_hist"].shape[0] 323 | 324 | props["vel"] = jnp.reshape(features["vel_hist"], (n_nodes, self._n_vels, -1)) 325 | 326 | # most recent position 327 | props["pos"] = features["abs_pos"][:, -1] 328 | # relative distances between particles 329 | props["edge_attr"] = features["rel_dist"] 330 | # force magnitude as node attributes 331 | props["node_attr"] = None 332 | if "force" in features: 333 | props["node_attr"] = jnp.sqrt( 334 | jnp.sum(features["force"] ** 2, axis=-1, keepdims=True) 335 | ) 336 | 337 | # velocity magnitudes as node features 338 | node_features = jnp.concatenate( 339 | [ 340 | jnp.sqrt(jnp.sum(props["vel"][:, i, :] ** 2, axis=-1, keepdims=True)) 341 | for i in range(self._n_vels) 342 | ], 343 | axis=-1, 344 | ) 345 | if not self._homogeneous_particles: 346 | particles = jax.nn.one_hot(particle_type, NodeType.SIZE) 347 | node_features = jnp.concatenate([node_features, particles], axis=-1) 348 | 349 | graph = jraph.GraphsTuple( 350 | nodes=node_features, 351 | edges=None, 352 | senders=features["senders"], 353 | receivers=features["receivers"], 354 | n_node=jnp.array([n_nodes]), 355 | n_edge=jnp.array([len(features["senders"])]), 356 | globals=None, 357 | ) 358 | 359 | return graph, props 360 | 361 | def _postprocess( 362 | self, next_pos: jnp.ndarray, props: Dict[str, jnp.ndarray] 363 | ) -> Dict[str, jnp.ndarray]: 364 | prev_vel = props["vel"][:, -1, :] 365 | prev_pos = props["pos"] 366 | # first order finite difference 367 | next_vel = self._displacement_fn(next_pos, prev_pos) 368 | acc = next_vel - prev_vel 369 | return {"pos": next_pos, "vel": next_vel, "acc": acc} 370 | 371 | def __call__( 372 | self, sample: Tuple[Dict[str, jnp.ndarray], jnp.ndarray] 373 | ) -> Dict[str, jnp.ndarray]: 374 | graph, props = self._transform(*sample) 375 | # input node embedding 376 | h = LinearXav(self._hidden_size, name="scalar_emb")(graph.nodes) 377 | graph = graph._replace(nodes=h) 378 | prev_vel = props["vel"][:, -1, :] 379 | # egnn works with unnormalized velocities 380 | prev_vel = prev_vel * self._vel_stats["std"] + self._vel_stats["mean"] 381 | # message passing 382 | next_pos = props["pos"].copy() 383 | for n in range(self._num_mp_steps): 384 | graph, next_pos = EGNNLayer( 385 | layer_num=n, 386 | hidden_size=self._hidden_size, 387 | output_size=self._hidden_size, 388 | displacement_fn=self._displacement_fn, 389 | shift_fn=self._shift_fn, 390 | act_fn=self._act_fn, 391 | residual=self._residual, 392 | attention=self._attention, 393 | 
normalize=self._normalize, 394 | dt=self._dt, 395 | tanh=self._tanh, 396 | )(graph, next_pos, prev_vel, props["edge_attr"], props["node_attr"]) 397 | 398 | # position finite differencing to get acceleration 399 | out = self._postprocess(next_pos, props) 400 | return out 401 | -------------------------------------------------------------------------------- /lagrangebench/models/gns.py: -------------------------------------------------------------------------------- 1 | """ 2 | Graph Network-based Simulator. 3 | GNS model and feature transform. 4 | """ 5 | 6 | from typing import Dict, Tuple 7 | 8 | import haiku as hk 9 | import jax.numpy as jnp 10 | import jraph 11 | 12 | from lagrangebench.utils import NodeType 13 | 14 | from .base import BaseModel 15 | from .utils import build_mlp 16 | 17 | 18 | class GNS(BaseModel): 19 | r"""Graph Network-based Simulator by 20 | `Sanchez-Gonzalez et al. <https://arxiv.org/abs/2002.09405>`_. 21 | 22 | GNS is the simplest graph neural network applied to particle dynamics. It is built on 23 | the usual Graph Network architecture, with an encoder, a processor, and a decoder. 24 | 25 | .. math:: 26 | \begin{align} 27 | \mathbf{m}_{ij}^{(t+1)} &= \phi \left( 28 | \mathbf{m}_{ij}^{(t)}, \mathbf{h}_i^{(t)}, \mathbf{h}_j^{(t)} \right) \\ 29 | \mathbf{h}_i^{(t+1)} &= \psi \left( 30 | \mathbf{h}_i^{(t)}, \sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij}^{(t+1)} 31 | \right) \\ 32 | \end{align} 33 | """ 34 | 35 | def __init__( 36 | self, 37 | particle_dimension: int, 38 | latent_size: int, 39 | blocks_per_step: int, 40 | num_mp_steps: int, 41 | particle_type_embedding_size: int, 42 | num_particle_types: int = NodeType.SIZE, 43 | ): 44 | """Initialize the model. 45 | 46 | Args: 47 | particle_dimension: Space dimensionality (e.g. 2 or 3). 48 | latent_size: Size of the latent representations. 49 | blocks_per_step: Number of MLP layers per block. 50 | num_mp_steps: Number of message passing steps. 51 | particle_type_embedding_size: Size of the particle type embedding. 52 | num_particle_types: Max number of particle types.
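
        Example:
            Hedged sketch of how a GNS instance is typically wrapped in Haiku
            (the hyperparameter values are illustrative assumptions)::

                import haiku as hk

                def gns_fn(x):
                    return GNS(
                        particle_dimension=2,
                        latent_size=128,
                        blocks_per_step=2,
                        num_mp_steps=10,
                        particle_type_embedding_size=16,
                    )(x)

                model = hk.without_apply_rng(hk.transform_with_state(gns_fn))
                # model.init / model.apply then act on a (features, particle_type) tuple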
53 | """ 54 | super().__init__() 55 | self._output_size = particle_dimension 56 | self._latent_size = latent_size 57 | self._blocks_per_step = blocks_per_step 58 | self._mp_steps = num_mp_steps 59 | self._num_particle_types = num_particle_types 60 | 61 | self._embedding = hk.Embed( 62 | num_particle_types, particle_type_embedding_size 63 | ) # (9, 16) 64 | 65 | def _encoder(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple: 66 | """MLP graph encoder.""" 67 | node_latents = build_mlp( 68 | self._latent_size, self._latent_size, self._blocks_per_step 69 | )(graph.nodes) 70 | edge_latents = build_mlp( 71 | self._latent_size, self._latent_size, self._blocks_per_step 72 | )(graph.edges) 73 | return jraph.GraphsTuple( 74 | nodes=node_latents, 75 | edges=edge_latents, 76 | globals=graph.globals, 77 | receivers=graph.receivers, 78 | senders=graph.senders, 79 | n_node=jnp.asarray([node_latents.shape[0]]), 80 | n_edge=jnp.asarray([edge_latents.shape[0]]), 81 | ) 82 | 83 | def _processor(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple: 84 | """Sequence of Graph Network blocks.""" 85 | 86 | def update_edge_features( 87 | edge_features, 88 | sender_node_features, 89 | receiver_node_features, 90 | _, # globals_ 91 | ): 92 | update_fn = build_mlp( 93 | self._latent_size, self._latent_size, self._blocks_per_step 94 | ) 95 | # Calculate sender node features from edge features 96 | return update_fn( 97 | jnp.concatenate( 98 | [sender_node_features, receiver_node_features, edge_features], 99 | axis=-1, 100 | ) 101 | ) 102 | 103 | def update_node_features( 104 | node_features, 105 | _, # aggr_sender_edge_features, 106 | aggr_receiver_edge_features, 107 | __, # globals_, 108 | ): 109 | update_fn = build_mlp( 110 | self._latent_size, self._latent_size, self._blocks_per_step 111 | ) 112 | features = [node_features, aggr_receiver_edge_features] 113 | return update_fn(jnp.concatenate(features, axis=-1)) 114 | 115 | # Perform iterative message passing by stacking Graph Network blocks 116 | for _ in range(self._mp_steps): 117 | _graph = jraph.GraphNetwork( 118 | update_edge_fn=update_edge_features, update_node_fn=update_node_features 119 | )(graph) 120 | graph = graph._replace( 121 | nodes=_graph.nodes + graph.nodes, edges=_graph.edges + graph.edges 122 | ) 123 | 124 | return graph 125 | 126 | def _decoder(self, graph: jraph.GraphsTuple): 127 | """MLP graph node decoder.""" 128 | return build_mlp( 129 | self._latent_size, 130 | self._output_size, 131 | self._blocks_per_step, 132 | is_layer_norm=False, 133 | )(graph.nodes) 134 | 135 | def _transform( 136 | self, features: Dict[str, jnp.ndarray], particle_type: jnp.ndarray 137 | ) -> jraph.GraphsTuple: 138 | """Convert physical features to jraph.GraphsTuple for gns.""" 139 | n_total_points = features["vel_hist"].shape[0] 140 | node_features = [ 141 | features[k] 142 | for k in ["vel_hist", "vel_mag", "bound", "force"] 143 | if k in features 144 | ] 145 | edge_features = [features[k] for k in ["rel_disp", "rel_dist"] if k in features] 146 | 147 | graph = jraph.GraphsTuple( 148 | nodes=jnp.concatenate(node_features, axis=-1), 149 | edges=jnp.concatenate(edge_features, axis=-1), 150 | receivers=features["receivers"], 151 | senders=features["senders"], 152 | n_node=jnp.array([n_total_points]), 153 | n_edge=jnp.array([len(features["senders"])]), 154 | globals=None, 155 | ) 156 | 157 | return graph, particle_type 158 | 159 | def __call__( 160 | self, sample: Tuple[Dict[str, jnp.ndarray], jnp.ndarray] 161 | ) -> Dict[str, jnp.ndarray]: 162 | graph, particle_type = 
self._transform(*sample) 163 | 164 | if self._num_particle_types > 1: 165 | particle_type_embeddings = self._embedding(particle_type) 166 | new_node_features = jnp.concatenate( 167 | [graph.nodes, particle_type_embeddings], axis=-1 168 | ) 169 | graph = graph._replace(nodes=new_node_features) 170 | acc = self._decoder(self._processor(self._encoder(graph))) 171 | return {"acc": acc} 172 | -------------------------------------------------------------------------------- /lagrangebench/models/linear.py: -------------------------------------------------------------------------------- 1 | """Simple baseline linear model.""" 2 | 3 | from typing import Dict, Tuple 4 | 5 | import haiku as hk 6 | import jax.numpy as jnp 7 | import numpy as np 8 | from jax import vmap 9 | 10 | from .base import BaseModel 11 | 12 | 13 | class Linear(BaseModel): 14 | r"""Model defining linear relation between input nodes and targets. 15 | 16 | :math:`\mathbf{a}_i = \mathbf{W} \mathbf{x}_i` where :math:`\mathbf{a}_i` are the 17 | output accelerations, :math:`\mathbf{W}` is a learnable weight matrix and 18 | :math:`\mathbf{x}_i` are input features. 19 | """ 20 | 21 | def __init__(self, dim_out): 22 | """Initialize the model. 23 | 24 | Args: 25 | dim_out: Output dimensionality. 26 | """ 27 | super().__init__() 28 | self.mlp = hk.Linear(dim_out) 29 | 30 | def __call__( 31 | self, sample: Tuple[Dict[str, jnp.ndarray], np.ndarray] 32 | ) -> Dict[str, jnp.ndarray]: 33 | # transform 34 | features, particle_type = sample 35 | x = [ 36 | features[k] 37 | for k in ["vel_hist", "vel_mag", "bound", "force"] 38 | if k in features 39 | ] + [particle_type[:, None]] 40 | # call 41 | acc = vmap(self.mlp)(jnp.concatenate(x, axis=-1)) 42 | return {"acc": acc} 43 | -------------------------------------------------------------------------------- /lagrangebench/models/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, Iterable, NamedTuple, Optional 2 | 3 | import e3nn_jax as e3nn 4 | import haiku as hk 5 | import jax 6 | import jax.numpy as jnp 7 | import jraph 8 | 9 | from lagrangebench.utils import NodeType 10 | 11 | 12 | class LinearXav(hk.Linear): 13 | """Linear layer with Xavier init. Avoid distracting 'w_init' everywhere.""" 14 | 15 | def __init__( 16 | self, 17 | output_size: int, 18 | with_bias: bool = True, 19 | w_init: Optional[hk.initializers.Initializer] = None, 20 | b_init: Optional[hk.initializers.Initializer] = None, 21 | name: Optional[str] = None, 22 | ): 23 | if w_init is None: 24 | w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform") 25 | super().__init__(output_size, with_bias, w_init, b_init, name) 26 | 27 | 28 | class MLPXav(hk.nets.MLP): 29 | """MLP layer with Xavier init. 
Avoid distracting 'w_init' everywhere.""" 30 | 31 | def __init__( 32 | self, 33 | output_sizes: Iterable[int], 34 | with_bias: bool = True, 35 | w_init: Optional[hk.initializers.Initializer] = None, 36 | b_init: Optional[hk.initializers.Initializer] = None, 37 | activation: Callable = jax.nn.silu, 38 | activate_final: bool = False, 39 | name: Optional[str] = None, 40 | ): 41 | if w_init is None: 42 | w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform") 43 | if not with_bias: 44 | b_init = None 45 | super().__init__( 46 | output_sizes, 47 | w_init, 48 | b_init, 49 | with_bias, 50 | activation, 51 | activate_final, 52 | name, 53 | ) 54 | 55 | 56 | class SteerableGraphsTuple(NamedTuple): 57 | r""" 58 | Pack (steerable) node and edge attributes with jraph.GraphsTuple. 59 | 60 | Attributes: 61 | graph: jraph.GraphsTuple, graph structure 62 | node_attributes: (N, irreps.dim), node attributes :math:`\mathbf{\hat{a}}_i` 63 | edge_attributes: (E, irreps.dim), edge attributes :math:`\mathbf{\hat{a}}_{ij}` 64 | additional_message_features: (E, edge_dim), optional message features 65 | """ 66 | 67 | graph: jraph.GraphsTuple 68 | node_attributes: Optional[e3nn.IrrepsArray] = None 69 | edge_attributes: Optional[e3nn.IrrepsArray] = None 70 | # NOTE: additional_message_features is in a separate field otherwise it would get 71 | # updated by jraph.GraphNetwork. Actual graph edges are used only for the messages. 72 | additional_message_features: Optional[e3nn.IrrepsArray] = None 73 | 74 | 75 | def node_irreps( 76 | metadata: Dict, 77 | input_seq_length: int, 78 | has_external_force: bool, 79 | has_magnitudes: bool, 80 | has_homogeneous_particles: bool, 81 | ) -> str: 82 | """Compute input node irreps based on which features are available.""" 83 | irreps = [] 84 | irreps.append(f"{input_seq_length - 1}x1o") 85 | if not any(metadata["periodic_boundary_conditions"]): 86 | irreps.append("2x1o") 87 | 88 | if has_external_force: 89 | irreps.append("1x1o") 90 | 91 | if has_magnitudes: 92 | irreps.append(f"{input_seq_length - 1}x0e") 93 | 94 | if not has_homogeneous_particles: 95 | irreps.append(f"{NodeType.SIZE}x0e") 96 | 97 | return e3nn.Irreps("+".join(irreps)) 98 | 99 | 100 | def build_mlp( 101 | latent_size, output_size, num_hidden_layers, is_layer_norm=True, **kwds: Dict 102 | ): 103 | """MLP generation helper using Haiku.""" 104 | assert num_hidden_layers >= 1 105 | network = hk.nets.MLP( 106 | [latent_size] * (num_hidden_layers - 1) + [output_size], 107 | **kwds, 108 | activate_final=False, 109 | name="MLP", 110 | ) 111 | if is_layer_norm: 112 | l_norm = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True) 113 | return hk.Sequential([network, l_norm]) 114 | else: 115 | return network 116 | 117 | 118 | def features_2d_to_3d(features): 119 | """Add zeros in the z component of 2D features.""" 120 | n_nodes = features["vel_hist"].shape[0] 121 | n_edges = features["rel_disp"].shape[0] 122 | n_vels = features["vel_hist"].shape[1] 123 | features["vel_hist"] = jnp.concatenate( 124 | [features["vel_hist"], jnp.zeros((n_nodes, n_vels, 1))], -1 125 | ) 126 | features["rel_disp"] = jnp.concatenate( 127 | [features["rel_disp"], jnp.zeros((n_edges, 1))], -1 128 | ) 129 | if "bound" in features: 130 | features["bound"] = jnp.concatenate( 131 | [features["bound"], jnp.zeros((n_nodes, 1))], -1 132 | ) 133 | if "force" in features: 134 | features["force"] = jnp.concatenate( 135 | [features["force"], jnp.zeros((n_nodes, 1))], -1 136 | ) 137 | 138 | return features 139 | 
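

# Hedged, self-contained usage sketch of the helpers above; the shapes and sizes
# are arbitrary example values, and the Haiku transform is only needed because
# `build_mlp` creates Haiku modules.
if __name__ == "__main__":

    def _mlp_fn(x):
        # latent_size=32, output_size=3, 2 hidden layers, LayerNorm on the output
        return build_mlp(32, 3, 2)(x)

    mlp = hk.without_apply_rng(hk.transform(_mlp_fn))
    x = jnp.ones((5, 16))  # 5 nodes with 16 input features each
    params = mlp.init(jax.random.PRNGKey(0), x)
    print(mlp.apply(params, x).shape)  # -> (5, 3)

    # zero-pad 2D features with a z component, e.g. for models that need 3D vectors
    feats_2d = {
        "vel_hist": jnp.ones((5, 4, 2)),  # (N, K, dim)
        "rel_disp": jnp.ones((8, 2)),  # (E, dim)
    }
    feats_3d = features_2d_to_3d(feats_2d)
    print(feats_3d["vel_hist"].shape, feats_3d["rel_disp"].shape)  # (5, 4, 3) (8, 3)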
-------------------------------------------------------------------------------- /lagrangebench/runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from datetime import datetime 4 | from typing import Callable, Dict, Optional, Tuple, Type, Union 5 | 6 | import haiku as hk 7 | import jax 8 | import jax.numpy as jnp 9 | import jmp 10 | import numpy as np 11 | from e3nn_jax import Irreps 12 | from jax import config 13 | from jax_sph.jax_md import space 14 | from omegaconf import DictConfig, OmegaConf 15 | 16 | from lagrangebench import Trainer, infer, models 17 | from lagrangebench.case_setup import case_builder 18 | from lagrangebench.data import H5Dataset 19 | from lagrangebench.defaults import check_cfg 20 | from lagrangebench.evaluate import averaged_metrics 21 | from lagrangebench.models.utils import node_irreps 22 | from lagrangebench.utils import NodeType 23 | 24 | 25 | def train_or_infer(cfg: Union[Dict, DictConfig]): 26 | if isinstance(cfg, Dict): 27 | cfg = OmegaConf.create(cfg) 28 | # sanity check on the passed configs 29 | check_cfg(cfg) 30 | 31 | mode = cfg.mode 32 | load_ckp = cfg.load_ckp 33 | is_test = cfg.eval.test 34 | 35 | if cfg.dtype == "float64": 36 | config.update("jax_enable_x64", True) 37 | 38 | data_train, data_valid, data_test = setup_data(cfg) 39 | 40 | metadata = data_train.metadata 41 | # neighbors search 42 | bounds = np.array(metadata["bounds"]) 43 | box = bounds[:, 1] - bounds[:, 0] 44 | 45 | # setup core functions 46 | case = case_builder( 47 | box=box, 48 | metadata=metadata, 49 | input_seq_length=cfg.model.input_seq_length, 50 | cfg_neighbors=cfg.neighbors, 51 | cfg_model=cfg.model, 52 | noise_std=cfg.train.noise_std, 53 | external_force_fn=data_train.external_force_fn, 54 | dtype=cfg.dtype, 55 | ) 56 | 57 | _, particle_type = data_train[0] 58 | 59 | # setup model from configs 60 | model, MODEL = setup_model( 61 | cfg, 62 | metadata=metadata, 63 | homogeneous_particles=particle_type.max() == particle_type.min(), 64 | has_external_force=data_train.external_force_fn is not None, 65 | normalization_stats=case.normalization_stats, 66 | ) 67 | model = hk.without_apply_rng(hk.transform_with_state(model)) 68 | 69 | # mixed precision training based on this reference: 70 | # https://github.com/deepmind/dm-haiku/blob/main/examples/imagenet/train.py 71 | policy = jmp.get_policy("params=float32,compute=float32,output=float32") 72 | hk.mixed_precision.set_policy(MODEL, policy) 73 | 74 | if mode == "train" or mode == "all": 75 | print("Start training...") 76 | 77 | if cfg.logging.run_name is None: 78 | run_prefix = f"{cfg.model.name}_{data_train.name}" 79 | data_and_time = datetime.today().strftime("%Y%m%d-%H%M%S") 80 | cfg.logging.run_name = f"{run_prefix}_{data_and_time}" 81 | 82 | store_ckp = os.path.join(cfg.logging.ckp_dir, cfg.logging.run_name) 83 | os.makedirs(store_ckp, exist_ok=True) 84 | os.makedirs(os.path.join(store_ckp, "best"), exist_ok=True) 85 | with open(os.path.join(store_ckp, "config.yaml"), "w") as f: 86 | OmegaConf.save(config=cfg, f=f.name) 87 | with open(os.path.join(store_ckp, "best", "config.yaml"), "w") as f: 88 | OmegaConf.save(config=cfg, f=f.name) 89 | 90 | # dictionary of configs which will be stored on W&B 91 | wandb_config = OmegaConf.to_container(cfg) 92 | 93 | trainer = Trainer( 94 | model, 95 | case, 96 | data_train, 97 | data_valid, 98 | cfg.train, 99 | cfg.eval, 100 | cfg.logging, 101 | input_seq_length=cfg.model.input_seq_length, 102 | seed=cfg.seed, 
103 | ) 104 | 105 | _, _, _ = trainer.train( 106 | step_max=cfg.train.step_max, 107 | load_ckp=load_ckp, 108 | store_ckp=store_ckp, 109 | wandb_config=wandb_config, 110 | ) 111 | 112 | if mode == "infer" or mode == "all": 113 | print("Start inference...") 114 | 115 | if mode == "infer": 116 | model_dir = load_ckp 117 | if mode == "all": 118 | model_dir = os.path.join(store_ckp, "best") 119 | assert osp.isfile(os.path.join(model_dir, "params_tree.pkl")) 120 | 121 | cfg.eval.rollout_dir = model_dir.replace("ckp", "rollout") 122 | os.makedirs(cfg.eval.rollout_dir, exist_ok=True) 123 | 124 | if cfg.eval.infer.n_trajs is None: 125 | cfg.eval.infer.n_trajs = cfg.eval.train.n_trajs 126 | 127 | assert model_dir, "model_dir must be specified for inference." 128 | metrics = infer( 129 | model, 130 | case, 131 | data_test if is_test else data_valid, 132 | load_ckp=model_dir, 133 | cfg_eval_infer=cfg.eval.infer, 134 | rollout_dir=cfg.eval.rollout_dir, 135 | n_rollout_steps=cfg.eval.n_rollout_steps, 136 | seed=cfg.seed, 137 | ) 138 | 139 | split = "test" if is_test else "valid" 140 | print(f"Metrics of {model_dir} on {split} split:") 141 | print(averaged_metrics(metrics)) 142 | 143 | return 0 144 | 145 | 146 | def setup_data(cfg) -> Tuple[H5Dataset, H5Dataset, H5Dataset]: 147 | dataset_path = cfg.dataset.src 148 | dataset_name = cfg.dataset.name 149 | ckp_dir = cfg.logging.ckp_dir 150 | rollout_dir = cfg.eval.rollout_dir 151 | input_seq_length = cfg.model.input_seq_length 152 | n_rollout_steps = cfg.eval.n_rollout_steps 153 | nl_backend = cfg.neighbors.backend 154 | 155 | if not osp.isabs(dataset_path): 156 | dataset_path = osp.join(os.getcwd(), dataset_path) 157 | 158 | if ckp_dir is not None: 159 | os.makedirs(ckp_dir, exist_ok=True) 160 | if rollout_dir is not None: 161 | os.makedirs(rollout_dir, exist_ok=True) 162 | 163 | # dataloader 164 | data_train = H5Dataset( 165 | "train", 166 | dataset_path=dataset_path, 167 | name=dataset_name, 168 | input_seq_length=input_seq_length, 169 | extra_seq_length=cfg.train.pushforward.unrolls[-1], 170 | nl_backend=nl_backend, 171 | ) 172 | data_valid = H5Dataset( 173 | "valid", 174 | dataset_path=dataset_path, 175 | name=dataset_name, 176 | input_seq_length=input_seq_length, 177 | extra_seq_length=n_rollout_steps, 178 | nl_backend=nl_backend, 179 | ) 180 | data_test = H5Dataset( 181 | "test", 182 | dataset_path=dataset_path, 183 | name=dataset_name, 184 | input_seq_length=input_seq_length, 185 | extra_seq_length=n_rollout_steps, 186 | nl_backend=nl_backend, 187 | ) 188 | 189 | return data_train, data_valid, data_test 190 | 191 | 192 | def setup_model( 193 | cfg, 194 | metadata: Dict, 195 | homogeneous_particles: bool = False, 196 | has_external_force: bool = False, 197 | normalization_stats: Optional[Dict] = None, 198 | ) -> Tuple[Callable, Type]: 199 | """Setup model based on cfg.""" 200 | model_name = cfg.model.name.lower() 201 | input_seq_length = cfg.model.input_seq_length 202 | magnitude_features = cfg.model.magnitude_features 203 | 204 | if model_name == "gns": 205 | 206 | def model_fn(x): 207 | return models.GNS( 208 | particle_dimension=metadata["dim"], 209 | latent_size=cfg.model.latent_dim, 210 | blocks_per_step=cfg.model.num_mlp_layers, 211 | num_mp_steps=cfg.model.num_mp_steps, 212 | num_particle_types=NodeType.SIZE, 213 | particle_type_embedding_size=16, 214 | )(x) 215 | 216 | MODEL = models.GNS 217 | elif model_name == "segnn": 218 | # Hx1o vel, Hx0e vel, 2x1o boundary, 9x0e type 219 | node_feature_irreps = node_irreps( 220 | metadata, 221 | 
input_seq_length, 222 | has_external_force, 223 | magnitude_features, 224 | homogeneous_particles, 225 | ) 226 | # 1o displacement, 0e distance 227 | edge_feature_irreps = Irreps("1x1o + 1x0e") 228 | 229 | def model_fn(x): 230 | return models.SEGNN( 231 | node_features_irreps=node_feature_irreps, 232 | edge_features_irreps=edge_feature_irreps, 233 | scalar_units=cfg.model.latent_dim, 234 | lmax_hidden=cfg.model.lmax_hidden, 235 | lmax_attributes=cfg.model.lmax_attributes, 236 | output_irreps=Irreps("1x1o"), 237 | num_mp_steps=cfg.model.num_mp_steps, 238 | n_vels=cfg.model.input_seq_length - 1, 239 | velocity_aggregate=cfg.model.velocity_aggregate, 240 | homogeneous_particles=homogeneous_particles, 241 | blocks_per_step=cfg.model.num_mlp_layers, 242 | norm=cfg.model.segnn_norm, 243 | )(x) 244 | 245 | MODEL = models.SEGNN 246 | elif model_name == "egnn": 247 | box = cfg.box 248 | if jnp.array(metadata["periodic_boundary_conditions"]).any(): 249 | displacement_fn, shift_fn = space.periodic(jnp.array(box)) 250 | else: 251 | displacement_fn, shift_fn = space.free() 252 | 253 | displacement_fn = jax.vmap(displacement_fn, in_axes=(0, 0)) 254 | shift_fn = jax.vmap(shift_fn, in_axes=(0, 0)) 255 | 256 | def model_fn(x): 257 | return models.EGNN( 258 | hidden_size=cfg.model.latent_dim, 259 | output_size=1, 260 | dt=metadata["dt"] * metadata["write_every"], 261 | displacement_fn=displacement_fn, 262 | shift_fn=shift_fn, 263 | normalization_stats=normalization_stats, 264 | num_mp_steps=cfg.model.num_mp_steps, 265 | n_vels=input_seq_length - 1, 266 | residual=True, 267 | )(x) 268 | 269 | MODEL = models.EGNN 270 | elif model_name == "painn": 271 | assert magnitude_features, "PaiNN requires magnitudes" 272 | radius = metadata["default_connectivity_radius"] * 1.5 273 | 274 | def model_fn(x): 275 | return models.PaiNN( 276 | hidden_size=cfg.model.latent_dim, 277 | output_size=1, 278 | n_vels=input_seq_length - 1, 279 | radial_basis_fn=models.painn.gaussian_rbf(20, radius, trainable=True), 280 | cutoff_fn=models.painn.cosine_cutoff(radius), 281 | num_mp_steps=cfg.model.num_mp_steps, 282 | )(x) 283 | 284 | MODEL = models.PaiNN 285 | elif model_name == "linear": 286 | 287 | def model_fn(x): 288 | return models.Linear(dim_out=metadata["dim"])(x) 289 | 290 | MODEL = models.Linear 291 | 292 | return model_fn, MODEL 293 | -------------------------------------------------------------------------------- /lagrangebench/train/__init__.py: -------------------------------------------------------------------------------- 1 | """Trainer method and training tricks.""" 2 | 3 | from .trainer import Trainer 4 | 5 | __all__ = ["Trainer"] 6 | -------------------------------------------------------------------------------- /lagrangebench/train/strats.py: -------------------------------------------------------------------------------- 1 | """Training tricks and strategies, currently: random-walk noise and push forward.""" 2 | 3 | from typing import Tuple 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | from jax_sph.jax_md.partition import space 8 | 9 | from lagrangebench.utils import get_kinematic_mask 10 | 11 | 12 | def add_gns_noise( 13 | key: jax.Array, 14 | pos_input: jnp.ndarray, 15 | particle_type: jnp.ndarray, 16 | input_seq_length: int, 17 | noise_std: float, 18 | shift_fn: space.ShiftFn, 19 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 20 | r"""GNS-like random walk noise injection as described by 21 | `Sanchez-Gonzalez et al. `_. 
22 | 23 | Applies random-walk noise to the input positions and adjusts the targets accordingly 24 | to keep the trajectory consistent. It works by drawing independent samples from 25 | :math:`\mathcal{N^{(t)}}(0, \sigma_v^{(t)})` for each input state. Noise is 26 | accumulated as a random walk and added to the velocity sequence. 27 | Each :math:`\sigma_v^{(t)}` is set so that the last step of the random walk has 28 | :math:`\sigma_v^{(input\_seq\_length)}=noise\_std`. Based on the noised velocities, 29 | positions are adjusted such that :math:`\dot{p}^{t_k} = p^{t_k} - p^{t_{k-1}}`. 30 | 31 | Args: 32 | key: Random key. 33 | pos_input: Clean input positions. Shape: 34 | (num_particles_max, input_seq_length + pushforward["unrolls"][-1] + 1, dim) 35 | particle_type: Particle type vector. Shape: (num_particles_max,) 36 | input_seq_length: Input sequence length, as in the configs. 37 | noise_std: Noise standard deviation at the last sequence step. 38 | shift_fn: Shift function. 39 | """ 40 | isl = input_seq_length 41 | # random-walk noise in the velocity applied to the first input_seq_length positions 42 | key, pos_input_noise = _get_random_walk_noise_for_pos_sequence( 43 | key, pos_input[:, :input_seq_length], noise_std_last_step=noise_std 44 | ) 45 | 46 | kinematic_mask = get_kinematic_mask(particle_type) 47 | pos_input_noise = jnp.where(kinematic_mask[:, None, None], 0.0, pos_input_noise) 48 | # adjust targets based on the noise from the last input position 49 | n_potential_targets = pos_input[:, isl:].shape[1] 50 | pos_target_noise = pos_input_noise[:, -1][:, None, :] 51 | pos_target_noise = jnp.tile(pos_target_noise, (1, n_potential_targets, 1)) 52 | pos_input_noise = jnp.concatenate([pos_input_noise, pos_target_noise], axis=1) 53 | 54 | shift_vmap = jax.vmap(shift_fn, in_axes=(0, 0)) 55 | shift_dvmap = jax.vmap(shift_vmap, in_axes=(0, 0)) 56 | pos_input_noisy = shift_dvmap(pos_input, pos_input_noise) 57 | 58 | return key, pos_input_noisy 59 | 60 | 61 | def _get_random_walk_noise_for_pos_sequence( 62 | key, position_sequence, noise_std_last_step 63 | ): 64 | key, subkey = jax.random.split(key) 65 | velocity_sequence_shape = list(position_sequence.shape) 66 | velocity_sequence_shape[1] -= 1 67 | n_velocities = velocity_sequence_shape[1] 68 | 69 | velocity_sequence_noise = jax.random.normal( 70 | subkey, shape=tuple(velocity_sequence_shape) 71 | ) 72 | velocity_sequence_noise *= noise_std_last_step / (n_velocities**0.5) 73 | velocity_sequence_noise = jnp.cumsum(velocity_sequence_noise, axis=1) 74 | 75 | position_sequence_noise = jnp.concatenate( 76 | [ 77 | jnp.zeros_like(velocity_sequence_noise[:, 0:1]), 78 | jnp.cumsum(velocity_sequence_noise, axis=1), 79 | ], 80 | axis=1, 81 | ) 82 | 83 | return key, position_sequence_noise 84 | 85 | 86 | def push_forward_sample_steps(key, step, pushforward): 87 | """Sample the number of unroll steps based on the current training step and the 88 | specified pushforward configuration.
89 | 90 | Args: 91 | key: Random key 92 | step: Current training step 93 | pushforward: Pushforward configuration 94 | """ 95 | key, key_unroll = jax.random.split(key, 2) 96 | 97 | # steps needs to be an ordered list 98 | steps = jnp.array(pushforward.steps) 99 | assert all(steps[i] <= steps[i + 1] for i in range(len(steps) - 1)) 100 | 101 | # until which index to sample from 102 | idx = (step > steps).sum() 103 | 104 | unroll_steps = jax.random.choice( 105 | key_unroll, 106 | a=jnp.array(pushforward.unrolls[:idx]), 107 | p=jnp.array(pushforward.probs[:idx]), 108 | ) 109 | return key, unroll_steps 110 | 111 | 112 | def push_forward_build(model_apply, case): 113 | r"""Build the push forward function, introduced by 114 | `Brandstetter et al. <https://arxiv.org/abs/2202.03376>`_. 115 | 116 | Pushforward works by adding a stability "pushforward" loss term, in the form of an 117 | adversarial-style loss. 118 | 119 | .. math:: 120 | L_{pf} = \mathbb{E}_k \mathbb{E}_{u^{k+1} | u^k} 121 | \mathbb{E}_{\epsilon} \left[ \mathcal{L}(f(u^k + \epsilon), u^{k+1}) \right] 122 | 123 | where :math:`\epsilon` is given by :math:`u^k + \epsilon = f(u^{k-1})`, i.e. by the 124 | first step of a 2-step unroll of the solver :math:`f` (from step :math:`k-1` to :math:`k`). 125 | The total loss is then :math:`L_{total}=\mathcal{L}(f(u^k), u^{k+1}) + L_{pf}`. 126 | Similarly, for :math:`S > 2` pushforward steps, :math:`L_{pf}` is extended to 127 | :math:`u^{k-S} \dots u^{k-1}` with accumulated :math:`\epsilon` perturbations. 128 | 129 | In practice, this is implemented by unrolling the solver for two steps, but only 130 | running gradients through the last unroll step. 131 | 132 | Args: 133 | model_apply: Model apply function 134 | case: Case setup function 135 | """ 136 | 137 | @jax.jit 138 | def push_forward_fn(features, current_pos, particle_type, neighbors, params, state): 139 | """Push forward function.
140 | 141 | Args: 142 | features: Input features 143 | current_pos: Current position 144 | particle_type: Particle type vector 145 | neighbors: Neighbor list 146 | params: Model parameters 147 | state: Model state 148 | """ 149 | # no buffer overflow check here, since push forward acts on later epochs 150 | pred, _ = model_apply(params, state, (features, particle_type)) 151 | next_pos = case.integrate(pred, current_pos) 152 | current_pos = jnp.concatenate( 153 | [current_pos[:, 1:], next_pos[:, None, :]], axis=1 154 | ) 155 | 156 | features, neighbors = case.preprocess_eval( 157 | (current_pos, particle_type), neighbors 158 | ) 159 | return current_pos, neighbors, features 160 | 161 | return push_forward_fn 162 | -------------------------------------------------------------------------------- /lagrangebench/utils.py: -------------------------------------------------------------------------------- 1 | """General utils and config structures.""" 2 | 3 | import enum 4 | import json 5 | import os 6 | import pickle 7 | import random 8 | from typing import Callable, Tuple 9 | 10 | import cloudpickle 11 | import jax 12 | import jax.numpy as jnp 13 | import numpy as np 14 | import torch 15 | 16 | 17 | class NodeType(enum.IntEnum): 18 | """Particle types.""" 19 | 20 | PAD_VALUE = -1 21 | FLUID = 0 22 | SOLID_WALL = 1 23 | MOVING_WALL = 2 24 | RIGID_BODY = 3 25 | SIZE = 9 26 | 27 | 28 | def get_kinematic_mask(particle_type): 29 | """Return a boolean mask, set to true for all kinematic (obstacle) particles.""" 30 | res = jnp.logical_or( 31 | particle_type == NodeType.SOLID_WALL, particle_type == NodeType.MOVING_WALL 32 | ) 33 | # In datasets with variable number of particles we treat padding as kinematic nodes 34 | res = jnp.logical_or(res, particle_type == NodeType.PAD_VALUE) 35 | return res 36 | 37 | 38 | def broadcast_to_batch(sample, batch_size: int): 39 | """Broadcast a pytree to a batched one with first dimension batch_size.""" 40 | assert batch_size > 0 41 | return jax.tree_map(lambda x: jnp.repeat(x[None, ...], batch_size, axis=0), sample) 42 | 43 | 44 | def broadcast_from_batch(batch, index: int): 45 | """Broadcast a batched pytree to the sample `index` out of the batch.""" 46 | assert index >= 0 47 | return jax.tree_map(lambda x: x[index], batch) 48 | 49 | 50 | def save_pytree(ckp_dir: str, pytree_obj, name) -> None: 51 | """Save a pytree to a directory.""" 52 | with open(os.path.join(ckp_dir, f"{name}_array.npy"), "wb") as f: 53 | for x in jax.tree_leaves(pytree_obj): 54 | np.save(f, x, allow_pickle=False) 55 | 56 | tree_struct = jax.tree_map(lambda t: 0, pytree_obj) 57 | with open(os.path.join(ckp_dir, f"{name}_tree.pkl"), "wb") as f: 58 | pickle.dump(tree_struct, f) 59 | 60 | 61 | def save_haiku(ckp_dir: str, params, state, opt_state, metadata_ckp) -> None: 62 | """Save params, state and optimizer state to ckp_dir. 63 | 64 | Additionally it tracks and saves the best model to ckp_dir/best. 
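
    Example (illustrative call; ``metadata_ckp`` must at least contain the
    "step" and "loss" entries used for checkpoint tracking)::

        save_haiku("ckp/my_run", params, state, opt_state, {"step": 1000, "loss": 0.01})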
65 | 66 | See: https://github.com/deepmind/dm-haiku/issues/18 67 | """ 68 | save_pytree(ckp_dir, params, "params") 69 | save_pytree(ckp_dir, state, "state") 70 | 71 | with open(os.path.join(ckp_dir, "opt_state.pkl"), "wb") as f: 72 | cloudpickle.dump(opt_state, f) 73 | with open(os.path.join(ckp_dir, "metadata_ckp.json"), "w") as f: 74 | json.dump(metadata_ckp, f) 75 | 76 | # only run for the main checkpoint directory (not best) 77 | if "best" not in ckp_dir: 78 | ckp_dir_best = os.path.join(ckp_dir, "best") 79 | metadata_best_path = os.path.join(ckp_dir, "best", "metadata_ckp.json") 80 | tag = "" 81 | 82 | if os.path.exists(metadata_best_path): # all except first step 83 | with open(metadata_best_path, "r") as fp: 84 | metadata_ckp_best = json.loads(fp.read()) 85 | 86 | # if loss is better than best previous loss, save to best model directory 87 | if metadata_ckp["loss"] < metadata_ckp_best["loss"]: 88 | save_haiku(ckp_dir_best, params, state, opt_state, metadata_ckp) 89 | tag = " (best so far)" 90 | else: # first step 91 | save_haiku(ckp_dir_best, params, state, opt_state, metadata_ckp) 92 | 93 | print( 94 | f"saved model to {ckp_dir} at step {metadata_ckp['step']}" 95 | f" with loss {metadata_ckp['loss']}{tag}" 96 | ) 97 | 98 | 99 | def load_pytree(model_dir: str, name): 100 | """Load a pytree from a directory.""" 101 | with open(os.path.join(model_dir, f"{name}_tree.pkl"), "rb") as f: 102 | tree_struct = pickle.load(f) 103 | 104 | leaves, treedef = jax.tree_flatten(tree_struct) 105 | 106 | with open(os.path.join(model_dir, f"{name}_array.npy"), "rb") as f: 107 | flat_state = [np.load(f) for _ in leaves] 108 | 109 | return jax.tree_unflatten(treedef, flat_state) 110 | 111 | 112 | def load_haiku(model_dir: str): 113 | """Load params, state, optimizer state and last training step from model_dir. 
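
    Example (the checkpoint directory is an illustrative placeholder)::

        params, state, opt_state, step = load_haiku("ckp/gns_run/best")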
114 | 115 | See: https://github.com/deepmind/dm-haiku/issues/18 116 | """ 117 | params = load_pytree(model_dir, "params") 118 | state = load_pytree(model_dir, "state") 119 | 120 | with open(os.path.join(model_dir, "opt_state.pkl"), "rb") as f: 121 | opt_state = cloudpickle.load(f) 122 | 123 | with open(os.path.join(model_dir, "metadata_ckp.json"), "r") as fp: 124 | metadata_ckp = json.loads(fp.read()) 125 | 126 | print(f"Loaded model from {model_dir} at step {metadata_ckp['step']}") 127 | 128 | return params, state, opt_state, metadata_ckp["step"] 129 | 130 | 131 | def get_num_params(params): 132 | """Get the number of parameters in a Haiku model.""" 133 | return sum(np.prod(p.shape) for p in jax.tree_leaves(params)) 134 | 135 | 136 | def print_params_shapes(params, prefix=""): 137 | if not isinstance(params, dict): 138 | print(f"{prefix: <40}, shape = {params.shape}") 139 | else: 140 | for k, v in params.items(): 141 | print_params_shapes(v, prefix=prefix + k) 142 | 143 | 144 | def set_seed(seed: int) -> Tuple[jax.Array, Callable, torch.Generator]: 145 | """Set seeds for jax, random and torch.""" 146 | # first PRNG key 147 | key = jax.random.PRNGKey(seed) 148 | np.random.seed(seed) 149 | random.seed(seed) 150 | torch.manual_seed(seed) 151 | 152 | # dataloader-related seeds 153 | def seed_worker(_): 154 | worker_seed = torch.initial_seed() % 2**32 155 | np.random.seed(worker_seed) 156 | random.seed(worker_seed) 157 | 158 | generator = torch.Generator() 159 | generator.manual_seed(seed) 160 | 161 | return key, seed_worker, generator 162 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from omegaconf import DictConfig, OmegaConf 4 | 5 | 6 | def check_subset(superset, subset, full_key=""): 7 | """Check that the keys of 'subset' are a subset of 'superset'.""" 8 | for k, v in subset.items(): 9 | key = full_key + k 10 | if isinstance(v, dict): 11 | check_subset(superset[k], v, key + ".") 12 | else: 13 | msg = f"cli_args must be a subset of the defaults. Wrong cli key: '{key}'" 14 | assert k in superset, msg 15 | 16 | 17 | def load_embedded_configs(config_path: str, cli_args: DictConfig) -> DictConfig: 18 | """Loads all 'extends' embedded configs and merge them with the cli overwrites.""" 19 | 20 | cfgs = [OmegaConf.load(config_path)] 21 | while "extends" in cfgs[0]: 22 | extends_path = cfgs[0]["extends"] 23 | del cfgs[0]["extends"] 24 | 25 | # go to parents configs until the defaults are reached 26 | if extends_path != "LAGRANGEBENCH_DEFAULTS": 27 | cfgs = [OmegaConf.load(extends_path)] + cfgs 28 | else: 29 | from lagrangebench.defaults import defaults 30 | 31 | cfgs = [defaults] + cfgs 32 | 33 | # assert that the cli_args are a subset of the defaults if inheritance from 34 | # defaults is used. 35 | check_subset(cfgs[0], cli_args) 36 | 37 | break 38 | 39 | # merge all embedded configs and give highest priority to cli_args 40 | cfg = OmegaConf.merge(*cfgs, cli_args) 41 | return cfg 42 | 43 | 44 | if __name__ == "__main__": 45 | cli_args = OmegaConf.from_cli() 46 | assert ("config" in cli_args) != ( 47 | "load_ckp" in cli_args 48 | ), "You must specify one of 'config' or 'load_ckp'." 
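
    # Hedged usage sketches (the config path and checkpoint directory are
    # illustrative placeholders):
    #   python main.py config=configs/rpf_2d/gns.yaml mode=train gpu=0
    #   python main.py load_ckp=ckp/<run_name>/best mode=infer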
49 | 50 | if "config" in cli_args: # start from config.yaml 51 | config_path = cli_args.config 52 | elif "load_ckp" in cli_args: # start from a checkpoint 53 | config_path = os.path.join(cli_args.load_ckp, "config.yaml") 54 | 55 | # values that need to be specified before importing jax 56 | cli_args.gpu = cli_args.get("gpu", -1) 57 | cli_args.xla_mem_fraction = cli_args.get("xla_mem_fraction", 0.75) 58 | 59 | # specify cuda device 60 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 from TensorFlow 61 | os.environ["CUDA_VISIBLE_DEVICES"] = str(cli_args.gpu) 62 | if cli_args.gpu == -1: 63 | os.environ["JAX_PLATFORMS"] = "cpu" 64 | os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(cli_args.xla_mem_fraction) 65 | 66 | # The following line makes the code deterministic on GPUs, but also extremely slow. 67 | # os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true" 68 | 69 | cfg = load_embedded_configs(config_path, cli_args) 70 | 71 | print("#" * 79, "\nStarting a LagrangeBench run with the following configs:") 72 | print(OmegaConf.to_yaml(cfg)) 73 | print("#" * 79) 74 | 75 | from lagrangebench.runner import train_or_infer 76 | 77 | train_or_infer(cfg) 78 | -------------------------------------------------------------------------------- /notebooks/media/scatter.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tumaer/lagrangebench/b880a6c84a93792d2499d2a9b8ba3a077ddf44e2/notebooks/media/scatter.gif -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "lagrangebench" 3 | version = "0.2.0" 4 | description = "LagrangeBench: A Lagrangian Fluid Mechanics Benchmarking Suite" 5 | authors = [ 6 | "Artur Toshev, Gianluca Galletti " 7 | ] 8 | license = "MIT" 9 | readme = "README.md" 10 | homepage = "https://lagrangebench.readthedocs.io/" 11 | documentation = "https://lagrangebench.readthedocs.io/" 12 | repository = "https://github.com/tumaer/lagrangebench" 13 | keywords = [ 14 | "smoothed-particle-hydrodynamics", 15 | "benchmark-suite", 16 | "lagrangian-dynamics", 17 | "graph-neural-networks", 18 | "lagrangian-particles", 19 | ] 20 | classifiers = [ 21 | "Development Status :: 3 - Alpha", 22 | "Intended Audience :: Science/Research", 23 | "License :: OSI Approved :: MIT License", 24 | "Operating System :: MacOS", 25 | "Operating System :: POSIX :: Linux", 26 | "Programming Language :: Python :: 3", 27 | "Programming Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3.11", 30 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 31 | "Topic :: Scientific/Engineering :: Physics", 32 | "Topic :: Scientific/Engineering :: Hydrology", 33 | "Topic :: Software Development :: Libraries :: Python Modules", 34 | "Typing :: Typed", 35 | ] 36 | 37 | [tool.poetry.dependencies] 38 | python = ">=3.9,<=3.11" 39 | cloudpickle = ">=2.2.1" 40 | h5py = ">=3.9.0" 41 | PyYAML = ">=6.0" 42 | numpy = ">=1.24.4" 43 | wandb = ">=0.15.11" 44 | pyvista = ">=0.42.2" 45 | jax = {version = "0.4.29", extras = ["cpu"]} 46 | jaxlib = "0.4.29" 47 | dm-haiku = ">=0.0.10" 48 | e3nn-jax = "0.20.3" 49 | jmp = ">=0.0.4" 50 | jraph = "0.0.6.dev0" 51 | optax = "0.1.7" 52 | ott-jax = ">=0.4.2" 53 | matscipy = ">=0.8.0" 54 | torch = {version = "2.3.1+cpu", source = "torchcpu"} 55 | wget = ">=3.2" 56 | omegaconf = ">=2.3.0" 57 | 
jax-sph = ">=0.0.3" 58 | 59 | [tool.poetry.group.dev.dependencies] 60 | # mypy = ">=1.8.0" - consider in the future 61 | pre-commit = ">=3.3.1" 62 | pytest = ">=7.3.1" 63 | pytest-cov = ">=4.1.0" 64 | ruff = "0.2.2" 65 | ipykernel = ">=6.25.1" 66 | 67 | [tool.poetry.group.docs.dependencies] 68 | sphinx = "7.2.6" 69 | sphinx-rtd-theme = "1.3.0" 70 | toml = ">=0.10.2" 71 | 72 | [[tool.poetry.source]] 73 | name = "torchcpu" 74 | url = "https://download.pytorch.org/whl/cpu" 75 | priority = "explicit" 76 | 77 | [tool.ruff] 78 | exclude = [ 79 | ".git", 80 | ".venv", 81 | "venv", 82 | "docs/_build", 83 | "dist", 84 | "notebooks", 85 | ] 86 | show-fixes = true 87 | line-length = 88 88 | 89 | [tool.ruff.lint] 90 | ignore = ["F811", "E402"] 91 | select = [ 92 | "E", # pycodestyle 93 | "F", # Pyflakes 94 | "SIM", # flake8-simplify 95 | "I", # isort 96 | # "D", # pydocstyle - consider in the future 97 | ] 98 | 99 | [tool.ruff.lint.isort] 100 | known-third-party = ["wandb"] 101 | 102 | [tool.pytest.ini_options] 103 | testpaths = "tests/" 104 | addopts = "--cov=lagrangebench --cov-fail-under=50" 105 | filterwarnings = [ 106 | # ignore all deprecation warnings except from lagrangebench 107 | "ignore::DeprecationWarning:^(?!.*lagrangebench).*" 108 | ] 109 | 110 | # Install bumpversion with: pip install -U poetry-bumpversion 111 | # Use: poetry version {major|minor|patch} 112 | [tool.poetry_bumpversion.file."lagrangebench/__init__.py"] 113 | [build-system] 114 | requires = ["poetry-core"] 115 | build-backend = "poetry.core.masonry.api" 116 | -------------------------------------------------------------------------------- /requirements_cuda.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cpu 2 | 3 | cloudpickle 4 | dm_haiku>=0.0.10 5 | e3nn_jax==0.20.3 6 | h5py 7 | jax-sph>=0.0.3 8 | jax[cuda12]==0.4.29 9 | jmp>=0.0.4 10 | jraph>=0.0.6.dev0 11 | matscipy>=0.8.0 12 | omegaconf>=2.3.0 13 | optax>=0.1.7 14 | ott-jax>=0.4.2 15 | pyvista 16 | PyYAML 17 | torch==2.3.1+cpu 18 | wandb 19 | wget 20 | -------------------------------------------------------------------------------- /tests/3D_LJ_3_1214every1/metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "solver": "JAXMD", 3 | "dim": 3, 4 | "dx": 1.4, 5 | "dt": 0.005, 6 | "t_end": 10.0, 7 | "sequence_length_train": 1214, 8 | "num_trajs_train": 1, 9 | "sequence_length_test": 405, 10 | "num_trajs_test": 1, 11 | "num_particles_max": 3, 12 | "periodic_boundary_conditions": [ 13 | true, 14 | true, 15 | true 16 | ], 17 | "bounds": [ 18 | [ 19 | 0.0, 20 | 5.0 21 | ], 22 | [ 23 | 0.0, 24 | 5.0 25 | ], 26 | [ 27 | 0.0, 28 | 5.0 29 | ] 30 | ], 31 | "default_connectivity_radius": 3.0, 32 | "vel_mean": [ 33 | -5.573862482677328e-10, 34 | 4.917874996124283e-10, 35 | -1.3441651125489784e-09 36 | ], 37 | "vel_std": [ 38 | 0.006350979674607515, 39 | 0.005811989773064852, 40 | 0.003586509730666876 41 | ], 42 | "acc_mean": [ 43 | -3.2785833076198756e-11, 44 | -6.557166615239751e-11, 45 | 0.0 46 | ], 47 | "acc_std": [ 48 | 0.0011505373986437917, 49 | 0.0005201193853281438, 50 | 0.00039340186049230397 51 | ], 52 | "description": "System of 3 Lennard-Jones particles in a periodic 3D box simulated with JAX-MD. Can be used to test the preprocessing and rollout utilities." 
53 | } 54 | -------------------------------------------------------------------------------- /tests/3D_LJ_3_1214every1/test.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tumaer/lagrangebench/b880a6c84a93792d2499d2a9b8ba3a077ddf44e2/tests/3D_LJ_3_1214every1/test.h5 -------------------------------------------------------------------------------- /tests/3D_LJ_3_1214every1/train.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tumaer/lagrangebench/b880a6c84a93792d2499d2a9b8ba3a077ddf44e2/tests/3D_LJ_3_1214every1/train.h5 -------------------------------------------------------------------------------- /tests/3D_LJ_3_1214every1/valid.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tumaer/lagrangebench/b880a6c84a93792d2499d2a9b8ba3a077ddf44e2/tests/3D_LJ_3_1214every1/valid.h5 -------------------------------------------------------------------------------- /tests/case_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | 7 | from lagrangebench.case_setup import case_builder 8 | 9 | 10 | class TestCaseBuilder(unittest.TestCase): 11 | """Class for unit testing the case builder functions.""" 12 | 13 | def setUp(self): 14 | self.metadata = { 15 | "num_particles_max": 3, 16 | "periodic_boundary_conditions": [True, True, True], 17 | "default_connectivity_radius": 0.3, 18 | "bounds": [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]], 19 | "acc_mean": [0.0, 0.0, 0.0], 20 | "acc_std": [1.0, 1.0, 1.0], 21 | "vel_mean": [0.0, 0.0, 0.0], 22 | "vel_std": [1.0, 1.0, 1.0], 23 | } 24 | 25 | bounds = np.array(self.metadata["bounds"]) 26 | box = bounds[:, 1] - bounds[:, 0] 27 | 28 | self.case = case_builder( 29 | box, 30 | self.metadata, 31 | input_seq_length=3, # two past velocities 32 | cfg_neighbors={"backend": "jaxmd_vmap", "multiplier": 1.25}, 33 | cfg_model={"isotropic_norm": False, "magnitude_features": False}, 34 | noise_std=0.0, 35 | external_force_fn=None, 36 | ) 37 | self.key = jax.random.PRNGKey(0) 38 | 39 | # position input shape (num_particles, sequence_len, dim) = (3, 5, 3) 40 | self.position_data = np.array( 41 | [ 42 | [ 43 | [0.5, 0.5, 0.5], 44 | [0.5, 0.5, 0.5], 45 | [0.5, 0.5, 0.5], 46 | [0.5, 0.5, 0.5], 47 | [0.5, 0.5, 0.5], 48 | ], 49 | [ 50 | [0.7, 0.5, 0.5], 51 | [0.9, 0.5, 0.5], 52 | [0.1, 0.5, 0.5], 53 | [0.3, 0.5, 0.5], 54 | [0.5, 0.5, 0.5], 55 | ], 56 | [ 57 | [0.8, 0.6, 0.5], 58 | [0.8, 0.6, 0.5], 59 | [0.9, 0.6, 0.5], 60 | [0.2, 0.6, 0.5], 61 | [0.6, 0.6, 0.5], 62 | ], 63 | ] 64 | ) 65 | self.particle_types = np.array([0, 0, 0]) 66 | 67 | _, _, _, neighbors = self.case.allocate( 68 | self.key, (self.position_data, self.particle_types) 69 | ) 70 | self.neighbors = neighbors 71 | 72 | def test_allocate(self): 73 | # test PBC and velocity and acceleration computation without noise 74 | key, features, target_dict, neighbors = self.case.allocate( 75 | self.key, (self.position_data, self.particle_types) 76 | ) 77 | self.assertTrue( 78 | ( 79 | neighbors.idx == jnp.array([[0, 1, 2, 2, 1, 3], [0, 1, 1, 2, 2, 3]]) 80 | ).all(), 81 | "Wrong edge list after allocate", 82 | ) 83 | 84 | self.assertTrue((key != self.key).all(), "Key not updated at allocate") 85 | 86 | self.assertTrue( 87 | jnp.isclose( 88 | target_dict["vel"], 89 | jnp.array([[0.0, 0.0, 0.0], [0.2, 0.0, 
0.0], [0.3, 0.0, 0.0]]), 90 | ).all(), 91 | "Wrong target velocity at allocate", 92 | ) 93 | 94 | self.assertTrue( 95 | jnp.isclose( 96 | target_dict["acc"], 97 | jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.2, 0.0, 0.0]]), 98 | ).all(), 99 | "Wrong target acceleration at allocate", 100 | ) 101 | 102 | self.assertTrue( 103 | jnp.isclose( 104 | features["vel_hist"], 105 | jnp.array( 106 | [ 107 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # particle 1, two past vels. 108 | [0.2, 0.0, 0.0, 0.2, 0.0, 0.0], 109 | [0.0, 0.0, 0.0, 0.1, 0.0, 0.0], 110 | ] 111 | ), 112 | ).all(), 113 | "Wrong historic velocities at allocate", 114 | ) 115 | 116 | most_recent_displacement = jnp.array( 117 | [ 118 | [0.0, 0.0, 0.0], # edge 0-0 119 | [0.0, 0.0, 0.0], # edge 1-1 120 | [-0.2, 0.1, 0.0], # edge 2-1 121 | [0.0, 0.0, 0.0], # edge 2-2 122 | [0.2, -0.1, 0.0], # edge 1-2 123 | [0.0, 0.0, 0.0], # edge 3-3 124 | ] 125 | ) 126 | r0 = self.metadata["default_connectivity_radius"] 127 | normalized_displ = most_recent_displacement / r0 128 | normalized_dist = ((normalized_displ**2).sum(-1, keepdims=True)) ** 0.5 129 | 130 | self.assertTrue( 131 | jnp.isclose(features["rel_disp"], normalized_displ).all(), 132 | "Wrong relative displacement at allocate", 133 | ) 134 | self.assertTrue( 135 | jnp.isclose(features["rel_dist"], normalized_dist).all(), 136 | "Wrong relative distance at allocate", 137 | ) 138 | 139 | def test_preprocess_base(self): 140 | # preprocess is 1-to-1 the same as allocate, up to the neighbors' computation 141 | _, _, _, neighbors_new = self.case.preprocess( 142 | self.key, (self.position_data, self.particle_types), 0.0, self.neighbors, 0 143 | ) 144 | 145 | self.assertTrue( 146 | (self.neighbors.idx == neighbors_new.idx).all(), 147 | "Wrong edge list after preprocess", 148 | ) 149 | 150 | def test_preprocess_unroll(self): 151 | # test getting the second available target acceleration 152 | _, _, target_dict, _ = self.case.preprocess( 153 | self.key, (self.position_data, self.particle_types), 0.0, self.neighbors, 1 154 | ) 155 | 156 | self.assertTrue( 157 | jnp.isclose( 158 | target_dict["acc"], 159 | jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.1, 0.0, 0.0]]), 160 | atol=1e-07, 161 | ).all(), 162 | "Wrong target acceleration at preprocess", 163 | ) 164 | 165 | def test_preprocess_noise(self): 166 | # test that both potential targets are corrected with the proper noise 167 | # we choose noise_std=0.01 to guarantee that no particle will jump periodically 168 | _, features, target_dict, _ = self.case.preprocess( 169 | self.key, (self.position_data, self.particle_types), 0.01, self.neighbors, 0 170 | ) 171 | vel_next1 = jnp.array([[0.0, 0.0, 0.0], [0.2, 0.0, 0.0], [0.3, 0.0, 0.0]]) 172 | correct_target_acc = vel_next1 - features["vel_hist"][:, 3:6] 173 | self.assertTrue( 174 | jnp.isclose(correct_target_acc, target_dict["acc"], atol=1e-7).all(), 175 | "Wrong target acceleration at preprocess", 176 | ) 177 | 178 | # with one push-forward step on top 179 | _, features, target_dict, _ = self.case.preprocess( 180 | self.key, (self.position_data, self.particle_types), 0.01, self.neighbors, 1 181 | ) 182 | vel_next2 = jnp.array([[0.0, 0.0, 0.0], [0.2, 0.0, 0.0], [0.4, 0.0, 0.0]]) 183 | correct_target_acc = vel_next2 - vel_next1 184 | self.assertTrue( 185 | jnp.isclose(correct_target_acc, target_dict["acc"], atol=1e-7).all(), 186 | "Wrong target acceleration at preprocess with 1 pushforward step", 187 | ) 188 | 189 | def test_allocate_eval(self): 190 | pass 191 | 192 | def test_preprocess_eval(self): 193 | pass 194 
| 195 | def test_integrate(self): 196 | # given the reference acceleration, compute the next position 197 | correct_acceletation = { 198 | "acc": jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.2, 0.0, 0.0]]) 199 | } 200 | 201 | new_pos = self.case.integrate(correct_acceletation, self.position_data[:, :3]) 202 | 203 | self.assertTrue( 204 | jnp.isclose(new_pos, self.position_data[:, 3]).all(), 205 | "Wrong new position at integration", 206 | ) 207 | 208 | 209 | if __name__ == "__main__": 210 | unittest.main() 211 | -------------------------------------------------------------------------------- /tests/models_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import e3nn_jax as e3nn 4 | import haiku as hk 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | 9 | from lagrangebench import models 10 | from lagrangebench.utils import NodeType 11 | 12 | 13 | class ModelTest(unittest.TestCase): 14 | def dummy_sample(self, vel=None, pos=None): 15 | key = self.key() 16 | 17 | if vel is None: 18 | vel = jax.random.uniform(key, (100, 5 * 3)) 19 | if pos is None: 20 | pos = jax.random.uniform(key, (100, 1, 3)) 21 | 22 | senders = jax.random.randint(key, (200,), 0, 100) 23 | receivers = jax.random.randint(key, (200,), 0, 100) 24 | rel_disp = (pos[receivers] - pos[senders]).squeeze() 25 | x = { 26 | "vel_hist": vel, 27 | "vel_mag": jnp.sum(vel.reshape(100, -1, 3) ** 2, -1) ** 0.5, 28 | "rel_disp": rel_disp, 29 | "rel_dist": jnp.sum(rel_disp**2, -1, keepdims=True) ** 0.5, 30 | "abs_pos": pos, 31 | "senders": senders, 32 | "receivers": receivers, 33 | } 34 | particle_type = jnp.ones((100, 1), dtype=jnp.int32) * NodeType.FLUID 35 | return x, particle_type 36 | 37 | def key(self): 38 | return jax.random.PRNGKey(0) 39 | 40 | def assert_equivariant(self, f, params, state): 41 | key = self.key() 42 | 43 | vel = e3nn.normal("5x1o", key, (100,)) 44 | pos = e3nn.normal("1x1o", key, (100,)) 45 | 46 | def wrapper(v, p): 47 | sample, particle_type = self.dummy_sample() 48 | sample.update( 49 | { 50 | "vel_hist": v.array.reshape((100, 5 * 3)), 51 | "abs_pos": p.array.reshape((100, 1, 3)), 52 | } 53 | ) 54 | y, _ = f.apply(params, state, (sample, particle_type)) 55 | return e3nn.IrrepsArray("1x1o", y["acc"]) 56 | 57 | # random rotation matrix 58 | R = -e3nn.rand_matrix(key, ()) 59 | 60 | out1 = wrapper(vel.transform_by_matrix(R), pos.transform_by_matrix(R)) 61 | out2 = wrapper(vel, pos).transform_by_matrix(R) 62 | 63 | def assert_(x, y): 64 | self.assertTrue( 65 | np.isclose(x, y, atol=1e-5, rtol=1e-5).all(), "Not equivariant!" 
66 | ) 67 | 68 | jax.tree_util.tree_map(assert_, out1, out2) 69 | 70 | def test_segnn(self): 71 | def segnn(x): 72 | return models.SEGNN( 73 | node_features_irreps="5x1o + 5x0e", 74 | edge_features_irreps="1x1o + 1x0e", 75 | scalar_units=8, 76 | lmax_hidden=1, 77 | lmax_attributes=1, 78 | n_vels=5, 79 | num_mp_steps=1, 80 | output_irreps="1x1o", 81 | )(x) 82 | 83 | segnn = hk.without_apply_rng(hk.transform_with_state(segnn)) 84 | x, particle_type = self.dummy_sample() 85 | params, segnn_state = segnn.init(self.key(), (x, particle_type)) 86 | 87 | self.assert_equivariant(segnn, params, segnn_state) 88 | 89 | def test_egnn(self): 90 | def egnn(x): 91 | return models.EGNN( 92 | hidden_size=8, 93 | output_size=1, 94 | num_mp_steps=1, 95 | dt=0.01, 96 | n_vels=5, 97 | displacement_fn=lambda x, y: x - y, 98 | shift_fn=lambda x, y: x + y, 99 | )(x) 100 | 101 | egnn = hk.without_apply_rng(hk.transform_with_state(egnn)) 102 | x, particle_type = self.dummy_sample() 103 | params, egnn_state = egnn.init(self.key(), (x, particle_type)) 104 | 105 | self.assert_equivariant(egnn, params, egnn_state) 106 | 107 | def test_painn(self): 108 | def painn(x): 109 | return models.PaiNN( 110 | hidden_size=8, 111 | output_size=1, 112 | num_mp_steps=1, 113 | radial_basis_fn=models.painn.gaussian_rbf(20, 10, trainable=True), 114 | cutoff_fn=models.painn.cosine_cutoff(10), 115 | n_vels=5, 116 | )(x) 117 | 118 | painn = hk.without_apply_rng(hk.transform_with_state(painn)) 119 | x, particle_type = self.dummy_sample() 120 | params, painn_state = painn.init(self.key(), (x, particle_type)) 121 | 122 | self.assert_equivariant(painn, params, painn_state) 123 | 124 | 125 | if __name__ == "__main__": 126 | unittest.main() 127 | -------------------------------------------------------------------------------- /tests/pushforward_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax 4 | import numpy as np 5 | from omegaconf import OmegaConf 6 | 7 | from lagrangebench.train.strats import push_forward_sample_steps 8 | 9 | 10 | class TestPushForward(unittest.TestCase): 11 | """Class for unit testing the push-forward functions.""" 12 | 13 | def setUp(self): 14 | self.pf = OmegaConf.create( 15 | { 16 | "steps": [-1, 20000, 50000, 100000], 17 | "unrolls": [0, 1, 3, 20], 18 | "probs": [4.05, 4.05, 1.0, 1.0], 19 | } 20 | ) 21 | 22 | self.key = jax.random.PRNGKey(42) 23 | 24 | def body_steps(self, step, unrolls, probs): 25 | dump = [] 26 | for _ in range(1000): 27 | self.key, unroll_steps = push_forward_sample_steps(self.key, step, self.pf) 28 | dump.append(unroll_steps) 29 | 30 | # Note: np.unique returns sorted array 31 | unique, counts = np.unique(dump, return_counts=True) 32 | self.assertTrue((unique == unrolls).all(), "Wrong unroll steps") 33 | self.assertTrue( 34 | np.allclose(counts / 1000, probs, atol=0.05), 35 | "Wrong probabilities of unroll steps", 36 | ) 37 | 38 | def test_pf_step_1(self): 39 | self.body_steps(1, np.array([0]), np.array([1.0])) 40 | 41 | def test_pf_step_60000(self): 42 | self.body_steps(60000, np.array([0, 1, 3]), np.array([0.45, 0.45, 0.1])) 43 | 44 | 45 | if __name__ == "__main__": 46 | unittest.main() 47 | -------------------------------------------------------------------------------- /tests/rollout_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from functools import partial 3 | 4 | import haiku as hk 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 
8 | from jax import config as jax_config 9 | from jax import jit, vmap 10 | from jax_sph.jax_md import space 11 | from omegaconf import OmegaConf 12 | from torch.utils.data import DataLoader 13 | 14 | jax_config.update("jax_enable_x64", True) 15 | 16 | from lagrangebench.case_setup import case_builder 17 | from lagrangebench.data import H5Dataset 18 | from lagrangebench.data.utils import get_dataset_stats, numpy_collate 19 | from lagrangebench.evaluate import MetricsComputer 20 | from lagrangebench.evaluate.rollout import _eval_batched_rollout, _forward_eval 21 | from lagrangebench.utils import broadcast_from_batch 22 | 23 | 24 | class TestInferBuilder(unittest.TestCase): 25 | """Class for unit testing the evaluate_single_rollout function.""" 26 | 27 | def setUp(self): 28 | self.cfg = OmegaConf.create( 29 | { 30 | "dataset": { 31 | "src": "tests/3D_LJ_3_1214every1", # Lennard-Jones dataset 32 | }, 33 | "model": { 34 | "input_seq_length": 3, # two past velocities 35 | "isotropic_norm": False, 36 | }, 37 | "eval": { 38 | "train": {"metrics": ["mse"]}, 39 | "n_rollout_steps": 100, 40 | }, 41 | "train": {"noise_std": 0.0}, 42 | } 43 | ) 44 | 45 | data_valid = H5Dataset( 46 | split="valid", 47 | dataset_path=self.cfg.dataset.src, 48 | name="lj3d", 49 | input_seq_length=self.cfg.model.input_seq_length, 50 | extra_seq_length=self.cfg.eval.n_rollout_steps, 51 | ) 52 | self.loader_valid = DataLoader( 53 | dataset=data_valid, batch_size=1, collate_fn=numpy_collate 54 | ) 55 | 56 | self.metadata = data_valid.metadata 57 | self.normalization_stats = get_dataset_stats( 58 | self.metadata, self.cfg.model.isotropic_norm, self.cfg.train.noise_std 59 | ) 60 | 61 | bounds = np.array(self.metadata["bounds"]) 62 | box = bounds[:, 1] - bounds[:, 0] 63 | self.displacement_fn, self.shift_fn = space.periodic(side=box) 64 | 65 | self.case = case_builder( 66 | box, 67 | self.metadata, 68 | self.cfg.model.input_seq_length, 69 | noise_std=self.cfg.train.noise_std, 70 | ) 71 | 72 | self.key = jax.random.PRNGKey(0) 73 | 74 | def test_rollout(self): 75 | isl = self.loader_valid.dataset.input_seq_length 76 | 77 | # get one validation trajectory from the debug dataset 78 | traj_batch_i = next(iter(self.loader_valid)) 79 | traj_batch_i = jax.tree_map(lambda x: jnp.array(x), traj_batch_i) 80 | # remove batch dimension 81 | self.assertTrue(traj_batch_i[0].shape[0] == 1, "We test only batch size 1") 82 | traj_i = broadcast_from_batch(traj_batch_i, index=0) 83 | positions = traj_i[0] # (nodes, t, dim) = (3, 405, 3) 84 | 85 | displ_vmap = vmap(self.displacement_fn, (0, 0)) 86 | displ_dvmap = vmap(displ_vmap, (0, 0)) 87 | vels = displ_dvmap(positions[:, 1:], positions[:, :-1]) # (3, 404, 3) 88 | accs = vels[:, 1:] - vels[:, :-1] # (3, 403, 3) 89 | stats = self.normalization_stats["acceleration"] 90 | accs = (accs - stats["mean"]) / stats["std"] 91 | 92 | class CheatingModel(hk.Module): 93 | def __init__(self, target, start): 94 | super().__init__() 95 | self.target = target 96 | self.start = start 97 | 98 | def __call__(self, x): 99 | i = hk.get_state( 100 | "counter", 101 | shape=[], 102 | dtype=jnp.int32, 103 | init=hk.initializers.Constant(self.start), 104 | ) 105 | hk.set_state("counter", i + 1) 106 | return {"acc": self.target[:, i]} 107 | 108 | def setup_model(target, start): 109 | def model(x): 110 | return CheatingModel(target, start)(x) 111 | 112 | model = hk.without_apply_rng(hk.transform_with_state(model)) 113 | params, state = model.init(None, None) 114 | model_apply = model.apply 115 | model_apply = jit(model_apply) 
116 | return params, state, model_apply 117 | 118 | params, state, model_apply = setup_model(accs, 0) 119 | 120 | # proof that the above "model" works 121 | out, state = model_apply(params, state, None) 122 | pred_acc = stats["mean"] + out["acc"] * stats["std"] 123 | pred_pos = self.shift_fn(positions[:, isl - 1], vels[:, isl - 2] + pred_acc) 124 | pred_pos = jnp.asarray(pred_pos, dtype=jnp.float32) 125 | target_pos = positions[:, isl] 126 | 127 | assert jnp.isclose(pred_pos, target_pos, atol=1e-7).all(), "Wrong setup" 128 | 129 | params, state, model_apply = setup_model(accs, isl - 2) 130 | _, neighbors = self.case.allocate_eval((positions[:, :isl], traj_i[1])) 131 | 132 | metrics_computer = MetricsComputer( 133 | ["mse"], 134 | self.case.displacement, 135 | self.metadata, 136 | isl, 137 | ) 138 | 139 | forward_eval = partial( 140 | _forward_eval, 141 | model_apply=model_apply, 142 | case_integrate=self.case.integrate, 143 | ) 144 | forward_eval_vmap = vmap(forward_eval, in_axes=(None, None, 0, 0, 0)) 145 | preprocess_eval_vmap = vmap(self.case.preprocess_eval, in_axes=(0, 0)) 146 | metrics_computer_vmap = vmap(metrics_computer, in_axes=(0, 0)) 147 | 148 | for n_extrap_steps in [0, 5, 10]: 149 | with self.subTest(n_extrap_steps): 150 | example_rollout_batch, metrics_batch, neighbors = _eval_batched_rollout( 151 | forward_eval_vmap=forward_eval_vmap, 152 | preprocess_eval_vmap=preprocess_eval_vmap, 153 | case=self.case, 154 | params=params, 155 | state=state, 156 | traj_batch_i=traj_batch_i, 157 | neighbors=neighbors, 158 | metrics_computer_vmap=metrics_computer_vmap, 159 | n_rollout_steps=self.cfg.eval.n_rollout_steps, 160 | n_extrap_steps=n_extrap_steps, 161 | t_window=isl, 162 | ) 163 | example_rollout = broadcast_from_batch(example_rollout_batch, index=0) 164 | metrics = broadcast_from_batch(metrics_batch, index=0) 165 | 166 | self.assertTrue( 167 | jnp.isclose( 168 | metrics["mse"].mean(), 169 | jnp.array(0.0), 170 | atol=1e-6, 171 | ).all(), 172 | "Wrong rollout mse", 173 | ) 174 | 175 | pos_input = traj_i[0].transpose(1, 0, 2) # (t, nodes, dim) 176 | initial_positions = pos_input[:isl] 177 | example_full = np.concatenate( 178 | [initial_positions, example_rollout], axis=0 179 | ) 180 | rollout_dict = { 181 | "predicted_rollout": example_full, # (t, nodes, dim) 182 | "ground_truth_rollout": pos_input, # (t, nodes, dim) 183 | } 184 | 185 | self.assertTrue( 186 | jnp.isclose( 187 | rollout_dict["predicted_rollout"][100, 0], 188 | rollout_dict["ground_truth_rollout"][100, 0], 189 | atol=1e-6, 190 | ).all(), 191 | "Wrong rollout prediction", 192 | ) 193 | 194 | total_steps = self.cfg.eval.n_rollout_steps + n_extrap_steps 195 | assert example_rollout_batch.shape[1] == total_steps 196 | 197 | 198 | if __name__ == "__main__": 199 | unittest.main() 200 | -------------------------------------------------------------------------------- /tests/runner_test.py: -------------------------------------------------------------------------------- 1 | """Runner test with a linear model and LJ dataset.""" 2 | 3 | import unittest 4 | 5 | from omegaconf import OmegaConf 6 | 7 | from lagrangebench.defaults import defaults 8 | from lagrangebench.runner import train_or_infer 9 | 10 | 11 | class TestRunner(unittest.TestCase): 12 | """Test whether train_or_infer runs through.""" 13 | 14 | def setUp(self): 15 | self.cfg = OmegaConf.create( 16 | { 17 | "mode": "all", 18 | "dataset": { 19 | "src": "tests/3D_LJ_3_1214every1", 20 | }, 21 | "model": { 22 | "name": "linear", 23 | "input_seq_length": 3, 24 | }, 25 | 
"train": { 26 | "step_max": 10, 27 | "noise_std": 0.0, 28 | }, 29 | "eval": { 30 | "n_rollout_steps": 5, 31 | "train": { 32 | "n_trajs": 2, 33 | "metrics_stride": 5, 34 | "metrics": ["mse"], 35 | "out_type": "none", 36 | }, 37 | "infer": { 38 | "n_trajs": 2, 39 | "metrics_stride": 1, 40 | "metrics": ["mse"], 41 | "out_type": "none", 42 | }, 43 | }, 44 | "logging": { 45 | "log_steps": 1, 46 | "eval_steps": 5, 47 | "wandb": False, 48 | "ckp_dir": "/tmp/ckp", 49 | }, 50 | } 51 | ) 52 | # overwrite defaults with user-defined config 53 | self.cfg = OmegaConf.merge(defaults, self.cfg) 54 | 55 | def test_runner(self): 56 | out = train_or_infer(self.cfg) 57 | self.assertEqual(out, 0) 58 | 59 | 60 | if __name__ == "__main__": 61 | unittest.main() 62 | --------------------------------------------------------------------------------