├── .codecov.yml
├── .github
│   └── workflows
│       ├── publish.yml
│       ├── ruff.yml
│       └── tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── LICENSE
├── README.md
├── configs
│   ├── WaterDrop_2d
│   │   └── gns.yaml
│   ├── dam_2d
│   │   ├── base.yaml
│   │   ├── gns.yaml
│   │   └── segnn.yaml
│   ├── ldc_2d
│   │   ├── base.yaml
│   │   ├── gns.yaml
│   │   └── segnn.yaml
│   ├── ldc_3d
│   │   ├── base.yaml
│   │   ├── gns.yaml
│   │   └── segnn.yaml
│   ├── rpf_2d
│   │   ├── base.yaml
│   │   ├── egnn.yaml
│   │   ├── gns.yaml
│   │   ├── painn.yaml
│   │   └── segnn.yaml
│   ├── rpf_3d
│   │   ├── base.yaml
│   │   ├── egnn.yaml
│   │   ├── gns.yaml
│   │   ├── painn.yaml
│   │   └── segnn.yaml
│   ├── tgv_2d
│   │   ├── base.yaml
│   │   ├── gns.yaml
│   │   └── segnn.yaml
│   └── tgv_3d
│       ├── base.yaml
│       ├── gns.yaml
│       └── segnn.yaml
├── data_gen
│   ├── gns_data
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── download_dataset.sh
│   │   ├── reading_utils.py
│   │   └── tfrecord_to_h5.py
│   └── lagrangebench_data
│       ├── README.md
│       ├── dataset_db.sh
│       ├── dataset_ldc.sh
│       ├── dataset_rpf.sh
│       ├── dataset_tgv.sh
│       ├── gen_dataset.py
│       └── plot_frame.py
├── docs
│   ├── Makefile
│   ├── conf.py
│   ├── index.rst
│   ├── lagrangebench_logo.svg
│   ├── make.bat
│   ├── pages
│   │   ├── baselines.rst
│   │   ├── case_setup.rst
│   │   ├── data.rst
│   │   ├── defaults.rst
│   │   ├── evaluate.rst
│   │   ├── models.rst
│   │   ├── train.rst
│   │   ├── tutorial.rst
│   │   └── utils.rst
│   └── requirements.txt
├── download_data.sh
├── lagrangebench
│   ├── __init__.py
│   ├── case_setup
│   │   ├── __init__.py
│   │   ├── case.py
│   │   └── features.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── data.py
│   │   └── utils.py
│   ├── defaults.py
│   ├── evaluate
│   │   ├── __init__.py
│   │   ├── metrics.py
│   │   ├── rollout.py
│   │   └── utils.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── egnn.py
│   │   ├── gns.py
│   │   ├── linear.py
│   │   ├── painn.py
│   │   ├── segnn.py
│   │   └── utils.py
│   ├── runner.py
│   ├── train
│   │   ├── __init__.py
│   │   ├── strats.py
│   │   └── trainer.py
│   └── utils.py
├── main.py
├── notebooks
│   ├── data_gen.ipynb
│   ├── datasets.ipynb
│   ├── gns_data.ipynb
│   ├── media
│   │   └── scatter.gif
│   └── tutorial.ipynb
├── poetry.lock
├── pyproject.toml
├── requirements_cuda.txt
└── tests
    ├── 3D_LJ_3_1214every1
    │   ├── metadata.json
    │   ├── test.h5
    │   ├── train.h5
    │   └── valid.h5
    ├── case_test.py
    ├── models_test.py
    ├── pushforward_test.py
    ├── rollout_test.py
    └── runner_test.py
/.codecov.yml:
--------------------------------------------------------------------------------
1 | coverage:
2 | range: 50..70 # red color under 50%, yellow at 50%..70%, green over 70%
3 | precision: 1
4 | status:
5 | project:
6 | default:
7 | target: 60% # coverage success only above X%
8 | threshold: 5% # allow the coverage to drop by X% and still be a success
9 | patch:
10 | default:
11 | target: 50%
12 | threshold: 5%
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to PyPI
2 |
3 | on:
4 | release:
5 | types: [created]
6 |
7 | jobs:
8 | build-n-publish:
9 | name: Build and publish to PyPI
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - uses: actions/checkout@v4
14 | - name: Set up Python
15 | uses: actions/setup-python@v4
16 | with:
17 | python-version: '3.10'
18 | - name: Install Poetry
19 | run: |
20 | python -m pip install --upgrade pip
21 | pip install poetry
22 | poetry config virtualenvs.in-project true
23 | - name: Install dependencies
24 | run: |
25 | poetry install
26 | - name: Build and publish
27 | env:
28 | POETRY_PYPI_TOKEN_PYPI: ${{ secrets.POETRY_PYPI_TOKEN_PYPI }}
29 | run: |
30 | poetry version $(git describe --tags --abbrev=0)
31 | poetry build
32 | poetry publish
33 |
--------------------------------------------------------------------------------
/.github/workflows/ruff.yml:
--------------------------------------------------------------------------------
1 | name: Ruff
2 | on: [pull_request]
3 | jobs:
4 | ruff:
5 | runs-on: ubuntu-latest
6 | steps:
7 | - uses: actions/checkout@v4
8 | - name: Install Ruff from pyproject.toml
9 | uses: astral-sh/ruff-action@v3
10 | with:
11 | version-file: "pyproject.toml"
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | push:
5 | branches: [ "main" ]
6 | pull_request:
7 | branches: [ "main" ]
8 |
9 | permissions:
10 | contents: read
11 |
12 | jobs:
13 | tests:
14 |
15 | runs-on: ubuntu-latest
16 | strategy:
17 | matrix:
18 | python-version: ["3.9", "3.10", "3.11"]
19 |
20 | steps:
21 | - uses: actions/checkout@v4
22 | - name: Set up Python ${{ matrix.python-version }}
23 | uses: actions/setup-python@v4
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 | - name: Install Poetry
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install poetry
30 | poetry config virtualenvs.in-project true
31 | - name: Install dependencies
32 | run: |
33 | poetry install
34 | - name: Run pytest and generate coverage report
35 | run: |
36 | .venv/bin/pytest --cov-report=xml
37 | - name: Upload coverage report to Codecov
38 | uses: codecov/codecov-action@v3
39 | with:
40 | token: ${{ secrets.CODECOV_TOKEN }}
41 | file: ./coverage.xml
42 | flags: unittests
43 | verbose: true
44 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # experiments
2 | ckp/
3 | rollout/
4 | rollouts/
5 | wandb/
6 | *.out
7 | datasets
8 | baselines
9 | partition_*
10 | *.pkl
11 | datasets/
12 | # dev
13 | .vscode
14 | __pycache__
15 | *.pyc
16 | venv*/
17 | *.egg-info
18 | .ruff-cache
19 | rollouts
20 | profile
21 | dist
22 | .coverage
23 |
24 | # Sphinx documentation
25 | docs/_build/
26 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 | exclude: |
4 | (?x)^(
5 | venv/|
6 | ckp/|
7 | rollout/|
8 | docs/|
9 | )$
10 | repos:
11 | - repo: https://github.com/pre-commit/pre-commit-hooks
12 | rev: v4.4.0
13 | hooks:
14 | - id: check-merge-conflict
15 | - id: check-added-large-files
16 | - id: check-docstring-first
17 | - id: check-json
18 | - id: check-toml
19 | - id: check-yaml
20 | - id: requirements-txt-fixer
21 | - repo: https://github.com/astral-sh/ruff-pre-commit
22 | rev: 'v0.2.2'
23 | hooks:
24 | - id: ruff
25 | args: [ --fix ]
26 | - id: ruff-format
27 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: "2"
2 |
3 | build:
4 | os: "ubuntu-22.04"
5 | tools:
6 | python: "3.9"
7 |
8 | sphinx:
9 | configuration: docs/conf.py
10 |
11 | python:
12 | install:
13 | - requirements: docs/requirements.txt
14 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Chair of Aerodynamics and Fluid Mechanics @ TUM
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/configs/WaterDrop_2d/gns.yaml:
--------------------------------------------------------------------------------
1 | extends: LAGRANGEBENCH_DEFAULTS
2 |
3 | dataset:
4 | src: /tmp/datasets/WaterDrop
5 |
6 | model:
7 | name: gns
8 | num_mp_steps: 10
9 | latent_dim: 128
10 |
11 | train:
12 | optimizer:
13 | lr_start: 5.e-4
14 |
15 | logging:
16 | wandb_project: waterdrop_2d
17 |
18 | neighbors:
19 | backend: matscipy
20 |
--------------------------------------------------------------------------------
/configs/dam_2d/base.yaml:
--------------------------------------------------------------------------------
1 | extends: LAGRANGEBENCH_DEFAULTS
2 |
3 | dataset:
4 | src: datasets/2D_DAM_5740_20kevery100
5 |
6 | logging:
7 | wandb_project: dam_2d
8 |
9 | neighbors:
10 | multiplier: 2.0
11 |
--------------------------------------------------------------------------------
/configs/dam_2d/gns.yaml:
--------------------------------------------------------------------------------
1 | extends: configs/dam_2d/base.yaml
2 |
3 | model:
4 | name: gns
5 | num_mp_steps: 10
6 | latent_dim: 128
7 |
8 | train:
9 | noise_std: 0.001
10 | optimizer:
11 | lr_start: 5.e-4
12 |
--------------------------------------------------------------------------------
/configs/dam_2d/segnn.yaml:
--------------------------------------------------------------------------------
1 | extends: configs/dam_2d/base.yaml
2 |
3 | model:
4 | name: segnn
5 | num_mp_steps: 10
6 | latent_dim: 64
7 | isotropic_norm: True
8 |
9 | train:
10 | noise_std: 0.001
11 | optimizer:
12 | lr_start: 5.e-4
13 |
--------------------------------------------------------------------------------
/configs/ldc_2d/base.yaml:
--------------------------------------------------------------------------------
1 | extends: LAGRANGEBENCH_DEFAULTS
2 |
3 | dataset:
4 | src: datasets/2D_LDC_2708_10kevery100
5 |
6 | logging:
7 | wandb_project: ldc_2d
8 |
9 | neighbors:
10 | multiplier: 2.0
--------------------------------------------------------------------------------
/configs/ldc_2d/gns.yaml:
--------------------------------------------------------------------------------
1 | extends: configs/ldc_2d/base.yaml
2 |
3 | model:
4 | name: gns
5 | num_mp_steps: 10
6 | latent_dim: 128
7 |
8 | train:
9 | noise_std: 0.001
10 | optimizer:
11 | lr_start: 5.e-4
12 |
--------------------------------------------------------------------------------
/configs/ldc_2d/segnn.yaml:
--------------------------------------------------------------------------------
1 | extends: configs/ldc_2d/base.yaml
2 |
3 | model:
4 | name: segnn
5 | num_mp_steps: 10
6 | latent_dim: 64
7 | isotropic_norm: True
8 |
9 | train:
10 | noise_std: 0.001
11 | optimizer:
12 | lr_start: 5.e-4
13 |
--------------------------------------------------------------------------------
/configs/ldc_3d/base.yaml:
--------------------------------------------------------------------------------
1 | extends: LAGRANGEBENCH_DEFAULTS
2 |
3 | dataset:
4 | src: datasets/3D_LDC_8160_10kevery100
5 |
6 | logging:
7 | wandb_project: ldc_3d
8 |
9 | neighbors:
10 | multiplier: 2.0
--------------------------------------------------------------------------------
/configs/ldc_3d/gns.yaml:
--------------------------------------------------------------------------------
1 | extends: configs/ldc_3d/base.yaml
2 |
3 | model:
4 | name: gns
5 | num_mp_steps: 10
6 | latent_dim: 128
7 |
8 | train:
9 | optimizer:
10 | lr_start: 5.e-4
11 |
--------------------------------------------------------------------------------
/configs/ldc_3d/segnn.yaml:
--------------------------------------------------------------------------------
1 | extends: configs/ldc_3d/base.yaml
2 |
3 | model:
4 | name: segnn
5 | num_mp_steps: 10
6 | latent_dim: 64
7 | isotropic_norm: True
8 |
9 | train:
10 | optimizer:
11 | lr_start: 5.e-4
12 |
--------------------------------------------------------------------------------
/configs/rpf_2d/base.yaml:
--------------------------------------------------------------------------------
1 | extends: LAGRANGEBENCH_DEFAULTS
2 |
3 | dataset:
4 | src: datasets/2D_RPF_3200_20kevery100
5 |
6 | logging:
7 | wandb_project: rpf_2d
--------------------------------------------------------------------------------
/configs/rpf_2d/egnn.yaml:
--------------------------------------------------------------------------------
1 | extends: configs/rpf_2d/base.yaml
2 |
3 | model:
4 | name: egnn
5 | num_mp_steps: 5
6 | latent_dim: 128
7 | isotropic_norm: True
8 | magnitude_features: True
9 |
10 | train:
11 | optimizer:
12 | lr_start: 5.e-4
13 | loss_weight:
14 | pos: 1.0
15 | vel: 0.0
16 | acc: 0.0
17 |
--------------------------------------------------------------------------------
/configs/rpf_2d/gns.yaml:
--------------------------------------------------------------------------------
1 | extends: configs/rpf_2d/base.yaml
2 |
3 | model:
4 | name: gns
5 | num_mp_steps: 10
6 | latent_dim: 128
7 |
8 | train:
9 | optimizer:
10 | lr_start: 5.e-4
11 |
--------------------------------------------------------------------------------
/configs/rpf_2d/painn.yaml:
--------------------------------------------------------------------------------
1 | extends: configs/rpf_2d/base.yaml
2 |
3 | model:
4 | name: painn
5 | num_mp_steps: 5
6 | latent_dim: 128
7 | isotropic_norm: True
8 | magnitude_features: True
9 |
10 | train:
11 | optimizer:
12 | lr_start: 1.e-4
13 |
--------------------------------------------------------------------------------
/configs/rpf_2d/segnn.yaml:
--------------------------------------------------------------------------------
1 | extends: configs/rpf_2d/base.yaml
2 |
3 | model:
4 | name: segnn
5 | num_mp_steps: 10
6 | latent_dim: 64
7 | isotropic_norm: True
8 |
9 | train:
10 | optimizer:
11 | lr_start: 1.e-3
12 |
--------------------------------------------------------------------------------
/configs/rpf_3d/base.yaml:
--------------------------------------------------------------------------------
1 | extends: LAGRANGEBENCH_DEFAULTS
2 |
3 | dataset:
4 | src: datasets/3D_RPF_8000_10kevery100
5 |
6 | logging:
7 | wandb_project: rpf_3d
--------------------------------------------------------------------------------
/configs/rpf_3d/egnn.yaml:
--------------------------------------------------------------------------------
1 | extends: configs/rpf_3d/base.yaml
2 |
3 | model:
4 | name: egnn
5 | num_mp_steps: 5
6 | latent_dim: 128
7 | isotropic_norm: True
8 | magnitude_features: True
9 |
10 | train:
11 | optimizer:
12 | lr_start: 1.e-4
13 | loss_weight:
14 | pos: 1.0
15 | vel: 0.0
16 | acc: 0.0
17 |
--------------------------------------------------------------------------------
/configs/rpf_3d/gns.yaml:
--------------------------------------------------------------------------------
1 | extends: configs/rpf_3d/base.yaml
2 |
3 | model:
4 | name: gns
5 | num_mp_steps: 10
6 | latent_dim: 128
7 |
8 | train:
9 | optimizer:
10 | lr_start: 5.e-4
11 |
--------------------------------------------------------------------------------
/configs/rpf_3d/painn.yaml:
--------------------------------------------------------------------------------
1 | extends: configs/rpf_3d/base.yaml
2 |
3 | model:
4 | name: painn
5 | num_mp_steps: 5
6 | latent_dim: 128
7 | isotropic_norm: True
8 | magnitude_features: True
9 |
10 | train:
11 | optimizer:
12 | lr_start: 5.e-4
13 |
--------------------------------------------------------------------------------
/configs/rpf_3d/segnn.yaml:
--------------------------------------------------------------------------------
1 | extends: configs/rpf_3d/base.yaml
2 |
3 | model:
4 | name: segnn
5 | num_mp_steps: 10
6 | latent_dim: 64
7 | isotropic_norm: True
8 |
9 | train:
10 | optimizer:
11 | lr_start: 1.e-3
12 |
--------------------------------------------------------------------------------
/configs/tgv_2d/base.yaml:
--------------------------------------------------------------------------------
1 | extends: LAGRANGEBENCH_DEFAULTS
2 |
3 | dataset:
4 | src: datasets/2D_TGV_2500_10kevery100
5 |
6 | logging:
7 | wandb_project: tgv_2d
8 |
--------------------------------------------------------------------------------
/configs/tgv_2d/gns.yaml:
--------------------------------------------------------------------------------
1 | extends: configs/tgv_2d/base.yaml
2 |
3 | model:
4 | name: gns
5 | num_mp_steps: 10
6 | latent_dim: 128
7 |
8 | train:
9 | optimizer:
10 | lr_start: 5.e-4
11 |
--------------------------------------------------------------------------------
/configs/tgv_2d/segnn.yaml:
--------------------------------------------------------------------------------
1 | extends: configs/tgv_2d/base.yaml
2 |
3 | model:
4 | name: segnn
5 | num_mp_steps: 10
6 | latent_dim: 64
7 | isotropic_norm: True
8 |
9 | train:
10 | optimizer:
11 | lr_start: 5.e-4
12 |
--------------------------------------------------------------------------------
/configs/tgv_3d/base.yaml:
--------------------------------------------------------------------------------
1 | extends: LAGRANGEBENCH_DEFAULTS
2 |
3 | dataset:
4 | src: datasets/3D_TGV_8000_10kevery100
5 |
6 | logging:
7 | wandb_project: tgv_3d
8 |
--------------------------------------------------------------------------------
/configs/tgv_3d/gns.yaml:
--------------------------------------------------------------------------------
1 | extends: configs/tgv_3d/base.yaml
2 |
3 | model:
4 | name: gns
5 | num_mp_steps: 10
6 | latent_dim: 128
7 |
8 | train:
9 | optimizer:
10 | lr_start: 5.e-4
11 |
--------------------------------------------------------------------------------
/configs/tgv_3d/segnn.yaml:
--------------------------------------------------------------------------------
1 | extends: configs/tgv_3d/base.yaml
2 |
3 | model:
4 | name: segnn
5 | num_mp_steps: 10
6 | latent_dim: 64
7 | isotropic_norm: True
8 |
9 | train:
10 | optimizer:
11 | lr_start: 5.e-4
12 |
--------------------------------------------------------------------------------
/data_gen/gns_data/README.md:
--------------------------------------------------------------------------------
1 | # Demonstration of how to train the GNS model on one of its original 2D datasets
2 |
3 | > Check out the full notebook under [`notebooks/gns_data.ipynb`](../notebooks/gns_data.ipynb).
4 |
5 | ## Download data
6 |
7 | ```bash
8 | mkdir -p /tmp/datasets
9 | bash data_gen/gns_data/download_dataset.sh WaterDrop /tmp/datasets
10 | ```
11 |
12 | ## Transform data from .tfrecord to .h5
13 |
14 | First, you need the `tensorflow` and `tensorflow-datasets` libraries. We recommend installing these in a separate virtual environment to avoid CUDA version conflicts.
15 |
16 | ```bash
17 | python3 -m venv venv_tf
18 | venv_tf/bin/pip install tensorflow tensorflow-datasets
19 | ```
20 |
21 | Then, transform the data via
22 |
23 | ```bash
24 | ./venv_tf/bin/python data_gen/gns_data/tfrecord_to_h5.py --dataset-path=/tmp/datasets/WaterDrop
25 | ```
26 |
27 | and train the usual way.
28 |
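For a quick sanity check of the converted files: each trajectory is stored as a zero-padded group holding `particle_type` and `position` datasets (see `tfrecord_to_h5.py`). A minimal sketch, assuming the default download location from above:

```python
import h5py

# train.h5 is written next to the original train.tfrecord
with h5py.File("/tmp/datasets/WaterDrop/train.h5", "r") as hf:
    traj = hf["00000"]  # first trajectory, zero-padded group name
    print("particle_type:", traj["particle_type"].shape)  # (num particles,)
    print("position:", traj["position"].shape)  # (time steps, particles, dim)
```

"Training the usual way" then means pointing a config at the converted dataset, e.g. the provided `configs/WaterDrop_2d/gns.yaml`, which already sets `dataset.src: /tmp/datasets/WaterDrop`.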
--------------------------------------------------------------------------------
/data_gen/gns_data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tumaer/lagrangebench/b880a6c84a93792d2499d2a9b8ba3a077ddf44e2/data_gen/gns_data/__init__.py
--------------------------------------------------------------------------------
/data_gen/gns_data/download_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2020 Deepmind Technologies Limited.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | # Usage:
17 | # bash download_dataset.sh ${DATASET_NAME} ${OUTPUT_DIR}
18 | # Example:
19 | # bash download_dataset.sh WaterDrop /tmp/
20 |
21 | # Source: https://github.com/deepmind/deepmind-research/tree/master/learning_to_simulate
22 |
23 | set -e
24 |
25 | DATASET_NAME="${1}"
26 | OUTPUT_DIR="${2}/${DATASET_NAME}"
27 |
28 | BASE_URL="https://storage.googleapis.com/learning-to-simulate-complex-physics/Datasets/${DATASET_NAME}/"
29 |
30 | mkdir -p ${OUTPUT_DIR}
31 | for file in metadata.json train.tfrecord valid.tfrecord test.tfrecord
32 | do
33 | wget -O "${OUTPUT_DIR}/${file}" "${BASE_URL}${file}"
34 | done
35 |
--------------------------------------------------------------------------------
/data_gen/gns_data/reading_utils.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=g-bad-file-header
2 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ============================================================================
16 | """Utilities for reading open sourced Learning Complex Physics data."""
17 |
18 | # Source: https://github.com/deepmind/deepmind-research/tree/master/learning_to_simulate
19 |
20 |
21 | import functools
22 |
23 | import numpy as np
24 | import tensorflow.compat.v1 as tf
25 |
26 | # Create a description of the features.
27 | _FEATURE_DESCRIPTION = {
28 | "position": tf.io.VarLenFeature(tf.string),
29 | }
30 |
31 | _FEATURE_DESCRIPTION_WITH_GLOBAL_CONTEXT = _FEATURE_DESCRIPTION.copy()
32 | _FEATURE_DESCRIPTION_WITH_GLOBAL_CONTEXT["step_context"] = tf.io.VarLenFeature(
33 | tf.string
34 | )
35 |
36 | _FEATURE_DTYPES = {
37 | "position": {"in": np.float32, "out": tf.float32},
38 | "step_context": {"in": np.float32, "out": tf.float32},
39 | }
40 |
41 | _CONTEXT_FEATURES = {
42 | "key": tf.io.FixedLenFeature([], tf.int64, default_value=0),
43 | "particle_type": tf.io.VarLenFeature(tf.string),
44 | }
45 |
46 |
47 | def convert_to_tensor(x, encoded_dtype):
48 | if len(x) == 1:
49 | out = np.frombuffer(x[0].numpy(), dtype=encoded_dtype)
50 | else:
51 | out = []
52 | for el in x:
53 | out.append(np.frombuffer(el.numpy(), dtype=encoded_dtype))
54 | out = tf.convert_to_tensor(np.array(out))
55 | return out
56 |
57 |
58 | def parse_serialized_simulation_example(example_proto, metadata):
59 | """Parses a serialized simulation tf.SequenceExample.
60 |
61 | Args:
62 | example_proto: A string encoding of the tf.SequenceExample proto.
63 | metadata: A dict of metadata for the dataset.
64 |
65 | Returns:
66 | context: A dict, with features that do not vary over the trajectory.
67 | parsed_features: A dict of tf.Tensors representing the parsed examples
68 | across time, where axis zero is the time axis.
69 |
70 | """
71 | if "context_mean" in metadata:
72 | feature_description = _FEATURE_DESCRIPTION_WITH_GLOBAL_CONTEXT
73 | else:
74 | feature_description = _FEATURE_DESCRIPTION
75 | context, parsed_features = tf.io.parse_single_sequence_example(
76 | example_proto,
77 | context_features=_CONTEXT_FEATURES,
78 | sequence_features=feature_description,
79 | )
80 | for feature_key, item in parsed_features.items():
81 | convert_fn = functools.partial(
82 | convert_to_tensor, encoded_dtype=_FEATURE_DTYPES[feature_key]["in"]
83 | )
84 | parsed_features[feature_key] = tf.py_function(
85 | convert_fn, inp=[item.values], Tout=_FEATURE_DTYPES[feature_key]["out"]
86 | )
87 |
88 | # There is an extra frame at the beginning so we can calculate pos change
89 | # for all frames used in the paper.
90 | position_shape = [metadata["sequence_length"] + 1, -1, metadata["dim"]]
91 |
92 | # Reshape positions to correct dim:
93 | parsed_features["position"] = tf.reshape(
94 | parsed_features["position"], position_shape
95 | )
96 | # Set correct shapes of the remaining tensors.
97 | sequence_length = metadata["sequence_length"] + 1
98 | if "context_mean" in metadata:
99 | context_feat_len = len(metadata["context_mean"])
100 | parsed_features["step_context"] = tf.reshape(
101 | parsed_features["step_context"], [sequence_length, context_feat_len]
102 | )
103 | # Decode particle type explicitly
104 | context["particle_type"] = tf.py_function(
105 | functools.partial(convert_fn, encoded_dtype=np.int64),
106 | inp=[context["particle_type"].values],
107 | Tout=[tf.int64],
108 | )
109 | context["particle_type"] = tf.reshape(context["particle_type"], [-1])
110 | return context, parsed_features
111 |
112 |
113 | def split_trajectory(context, features, window_length=7):
114 | """Splits trajectory into sliding windows."""
115 | # Our strategy is to make sure all the leading dimensions are the same size,
116 | # then we can use from_tensor_slices.
117 |
118 | trajectory_length = features["position"].get_shape().as_list()[0]
119 |
120 | # We then stack window_length position changes so the final
121 | # trajectory length will be - window_length +1 (the 1 to make sure we get
122 | # the last split).
123 | input_trajectory_length = trajectory_length - window_length + 1
124 |
125 | model_input_features = {}
126 | # Prepare the context features per step.
127 | model_input_features["particle_type"] = tf.tile(
128 | tf.expand_dims(context["particle_type"], axis=0), [input_trajectory_length, 1]
129 | )
130 |
131 | if "step_context" in features:
132 | global_stack = []
133 | for idx in range(input_trajectory_length):
134 | global_stack.append(features["step_context"][idx : idx + window_length])
135 | model_input_features["step_context"] = tf.stack(global_stack)
136 |
137 | pos_stack = []
138 | for idx in range(input_trajectory_length):
139 | pos_stack.append(features["position"][idx : idx + window_length])
140 | # Get the corresponding positions
141 | model_input_features["position"] = tf.stack(pos_stack)
142 |
143 | return tf.data.Dataset.from_tensor_slices(model_input_features)
144 |
--------------------------------------------------------------------------------
/data_gen/gns_data/tfrecord_to_h5.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import functools
3 | import json
4 | import os
5 |
6 | import h5py
7 | import numpy as np
8 | import reading_utils
9 | import tensorflow.compat.v1 as tf
10 | import tensorflow_datasets as tfds
11 |
12 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
13 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
14 | os.environ["CUDA_VISIBLE_DEVICES"] = ""
15 |
16 |
17 | def convert_tfrecord_to_h5(args):
18 | """Read .tfrecord file and convert it to its closest .h5 equivalent"""
19 |
20 | file_path = os.path.join(args.dataset_path, args.file_name)
21 | print(f"Start conversion of {file_path} to .h5")
22 |
23 | with open(os.path.join(args.dataset_path, "metadata.json"), "r") as fp:
24 | metadata = json.loads(fp.read())
25 |
26 | # get the TFRecordDataset with its proper preprocessing
27 | ds = tf.data.TFRecordDataset([file_path])
28 | ds = ds.map(
29 | functools.partial(
30 | reading_utils.parse_serialized_simulation_example, metadata=metadata
31 | )
32 | )
33 | ds = tfds.as_numpy(ds)
34 |
35 | h5_file_path = f"{file_path[:-9]}.h5"
36 | hf = h5py.File(h5_file_path, "w")
37 |
38 | for i, elem in enumerate(ds):
39 | traj_str = str(i).zfill(5)
40 |
41 | particle_type = elem[0]["particle_type"]
42 | key = elem[0]["key"]
43 | position = elem[1]["position"]
44 |
45 | hf.create_dataset(f"{traj_str}/particle_type", data=particle_type)
46 | assert key == i, "Something went wrong here"
47 | hf.create_dataset(
48 | f"{traj_str}/position",
49 | data=position,
50 | dtype=np.float32,
51 | compression="gzip",
52 | )
53 |
54 | hf.close()
55 | print(f"Finish conversion to {h5_file_path}")
56 |
57 |
58 | def main(args):
59 | files = os.listdir(args.dataset_path)
60 | files = [f for f in files if f.endswith(".tfrecord")]
61 | for file_name in files:
62 | args.file_name = file_name
63 | convert_tfrecord_to_h5(args)
64 |
65 | # add the maximum number of particles to the metadata
66 | # Crucial for the matscipy neighbors search
67 |
68 | # first find the maximum number of particles
69 | files = os.listdir(args.dataset_path)
70 | files = [f for f in files if f.endswith(".h5")]
71 | max_particles = 0
72 | for file_name in files:
73 | h5_file_path = os.path.join(args.dataset_path, file_name)
74 | hf = h5py.File(h5_file_path, "r")
75 | for k, v in hf.items():
76 | max_particles = max(v["particle_type"].shape[0], max_particles)
77 | print(f"Max number of particles in {file_name}: {max_particles}")
78 | hf.close()
79 |
80 | metadata_path = os.path.join(args.dataset_path, "metadata.json")
81 | with open(metadata_path, "r") as fp:
82 | metadata = json.loads(fp.read())
83 |
84 | metadata["num_particles_max"] = max_particles
85 | # all DeepMind datasets are non-periodic
86 | metadata["periodic_boundary_conditions"] = [False, False, False]
87 |
88 | with open(metadata_path, "w") as f:
89 | json.dump(metadata, f)
90 |
91 |
92 | if __name__ == "__main__":
93 | parser = argparse.ArgumentParser()
94 | parser.add_argument("--dataset-path", type=str)
95 | args = parser.parse_args()
96 |
97 | main(args)
98 |
--------------------------------------------------------------------------------
/data_gen/lagrangebench_data/README.md:
--------------------------------------------------------------------------------
1 | # LagrangeBench dataset generation
2 |
3 | To generate the [LagrangeBench datasets](https://zenodo.org/doi/10.5281/zenodo.10021925), we extend the case files provided at https://github.com/tumaer/jax-sph. We first copy the case files and the `main.py` file from JAX-SPH.
4 |
5 | ```bash
6 | cd data_gen/lagrangebench_data
7 | git clone https://github.com/tumaer/jax-sph.git
8 | cd jax-sph
9 | # We use this specific tag
10 | git checkout v0.0.2
11 |
12 | cd ..
13 | cp -r jax-sph/cases/ .
14 | cp jax-sph/main.py .
15 | ```
16 |
17 | Then we install JAX-SPH
18 | ```bash
19 | pip install jax-sph/
20 | # or
21 | # pip install jax-sph==0.0.2
22 |
23 | # make sure to have the desired JAX version, e.g.
24 | pip install jax[cuda12]==0.4.29
25 | ```
26 |
27 | Finally, run the corresponding bash script. E.g., for Taylor-Green:
28 | ```bash
29 | bash dataset_tgv.sh
30 |
31 | # cleanup
32 | rm -rf jax-sph/ cases/ main.py
33 | ```
34 |
35 | To inspect whether the simulated trajectories are of the desired length, we can use
36 | ```bash
37 | count_and_check() { find "$1" -type f | sed 's%/[^/]*$%%' | uniq -c | awk -v target="$2" '{if ($1 != target) $1=$1" - Failed!"; n=split($2, a, /[/]/); print a[n]" - "$1}'; }
38 |
39 | # and apply it e.g. to TGV with the arguments: 1. dataset directory and 2. target count
40 | count_and_check "/tmp/lagrangebench_data/raw/2D_TGV_2500_10kevery100/" 127
41 | ```
42 |
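Once `gen_dataset.py` has finished, the destination directory contains `train.h5`, `valid.h5`, `test.h5`, and a `metadata.json` with the dataset description and the normalization statistics from `compute_statistics_h5`. A short inspection sketch (the path below is just an example taken from `dataset_tgv.sh`):

```python
import json

import h5py

dst_dir = "/tmp/lagrangebench_data/datasets/2D_TGV_2500_10kevery100"

# dataset description and velocity/acceleration statistics written by gen_dataset.py
with open(f"{dst_dir}/metadata.json", "r") as f:
    metadata = json.load(f)
print(metadata["sequence_length_train"], metadata["vel_std"], metadata["acc_std"])

# trajectories are stored as groups "00000", "00001", ... with position and particle_type
with h5py.File(f"{dst_dir}/train.h5", "r") as hf:
    print(len(hf), hf["00000/position"].shape)  # (time steps, particles, dim)
```
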
43 | ## Errata
44 | 1. For dam break, one would need to replace lines 182-184 in `jax-sph/case_setup.py` with:
45 | ```python
46 | mask, _mask = state["tag"]==Tag.FLUID, _state["tag"]==Tag.FLUID
47 | assert state[k][mask].shape == _state[k][_mask].shape, ValueError(
48 | f"Shape mismatch for key {k} in state0 file."
49 | )
50 | state[k][mask] = _state[k][_mask]
51 | ```
--------------------------------------------------------------------------------
/data_gen/lagrangebench_data/dataset_db.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # CUDA_VISIBLE_DEVICES=0 nohup ./scripts/dataset_db.sh >> db_dataset.out 2>&1 &
3 |
4 | DATA_ROOT=/tmp/lagrangebench_data
5 |
6 | ##### 2D dataset
7 | for seed in {0..114} # 15 trajectories blew up and were discarded
8 | do
9 | echo "Run with seed = $seed"
10 | python main.py config=cases/db.yaml seed=$seed case.mode=rlx solver.tvf=1.0 case.r0_noise_factor=0.25 io.data_path=$DATA_ROOT/relaxed/
11 | python main.py config=cases/db.yaml seed=$seed case.state0_path=$DATA_ROOT/relaxed/db_2_0.02_$seed.h5 io.data_path=$DATA_ROOT/raw/2D_DB_5740_20kevery100/
12 | done
13 | # 15 runs that blew up were removed from the dataset, and 100 were kept
14 | # use `count_files_db.py` to detect defective runs
15 | python gen_dataset.py --src_dir=$DATA_ROOT/raw/2D_DB_5740_20kevery100/ --dst_dir=$DATA_ROOT/datasets/2D_DB_5740_20kevery100/ --split=2_1_1
16 |
17 |
18 | ### Number of particles
19 | # 100x50=5000 water particles
20 | # 106x274 outer box, i.e. 106x274 - 100x268 = 2244 wall particles
21 | # with only one wall layer, wall particles are 2*(100+270) = 740
22 | # => 7244 particles with SPH and 5740 for dataset
23 |
24 | ### Number of seeds for a given number of training samples
25 | # t_end = 12
26 | # dt_coarse = 0.0003*100 = 0.03
27 | # num_samples = 12/0.03 = 400 (+1 for initial state)
28 | # => for 40k train+valid+test samples, we need 40k/400 = 100 seeds
29 |
--------------------------------------------------------------------------------
/data_gen/lagrangebench_data/dataset_ldc.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # CUDA_VISIBLE_DEVICES=0 nohup ./scripts/dataset_ldc.sh 2>&1 &
3 |
4 | DATA_ROOT=/tmp/lagrangebench_data
5 |
6 | ##### 2D dataset
7 | python main.py config=cases/ldc.yaml case.dim=2 case.dx=0.02 case.mode=rlx solver.tvf=1.0 case.r0_noise_factor=0.25 io.data_path=$DATA_ROOT/relaxed/ io.p_bg_factor=0.0
8 | python main.py config=cases/ldc.yaml case.dim=2 case.dx=0.02 solver.dt=0.0004 solver.t_end=85 case.state0_path=$DATA_ROOT/relaxed/ldc_2_0.02_123.h5 io.data_path=$DATA_ROOT/raw/2D_LDC_2500_10kevery100/
9 | python gen_dataset.py --src_dir=$DATA_ROOT/raw/2D_LDC_2500_10kevery100/ --dst_dir=$DATA_ROOT/datasets/2D_LDC_2500_10kevery100/ --split=2_1_1 --skip_first_n_frames=1248
10 |
11 | # dt_coarse = 0.0004 * 100 = 0.04
12 | # to get 20k samples, we simulate for t = 20k * dt_coarse = 800 (+50 for equilibration)
13 |
14 | ##### 3D dataset
15 | python main.py config=cases/ldc.yaml case.dim=3 case.dx=0.041666667 case.mode=rlx solver.tvf=1.0 case.r0_noise_factor=0.25 io.data_path=$DATA_ROOT/relaxed/ io.p_bg_factor=0.0
16 | python main.py config=cases/ldc.yaml case.dim=3 case.dx=0.041666667 solver.dt=0.0009 solver.t_end=1850 case.state0_path=$DATA_ROOT/relaxed/ldc_3_0.041666667_123.h5 io.data_path=$DATA_ROOT/raw/3D_LDC_8160_10kevery100/
17 | python gen_dataset.py --src_dir=$DATA_ROOT/raw/3D_LDC_8160_10kevery100/ --dst_dir=$DATA_ROOT/datasets/3D_LDC_8160_10kevery100/ --split=2_1_1 --skip_first_n_frames=555
18 |
19 | # dt_coarse = 0.0009 * 100 = 0.09
20 | # to get 20k samples, we simulate for t = 20k * dt_coarse = 1800 (+50 for equilibration)
21 |
--------------------------------------------------------------------------------
/data_gen/lagrangebench_data/dataset_rpf.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # run this script with:
3 | # CUDA_VISIBLE_DEVICES=0 nohup ./scripts/dataset_rpf.sh 2>&1 &
4 |
5 | DATA_ROOT=/tmp/lagrangebench_data
6 |
7 | ##### 2D RPF
8 | python main.py config=cases/rpf.yaml case.dim=2 case.dx=0.025 case.mode=rlx solver.tvf=1.0 case.r0_noise_factor=0.25 io.data_path=$DATA_ROOT/relaxed/
9 | python main.py config=cases/rpf.yaml case.dim=2 case.dx=0.025 solver.dt=0.0005 solver.t_end=2050 case.state0_path=$DATA_ROOT/relaxed/rpf_2_0.025_123.h5 io.data_path=$DATA_ROOT/raw/2D_RPF_3200_20kevery100/
10 | python gen_dataset.py --src_dir=$DATA_ROOT/raw/2D_RPF_3200_20kevery100/ --dst_dir=$DATA_ROOT/datasets/2D_RPF_3200_20kevery100/ --split=2_1_1 --skip_first_n_frames=998
11 |
12 | ###### 3D RPF
13 | python main.py config=cases/rpf.yaml case.dim=3 case.dx=0.05 case.mode=rlx solver.tvf=1.0 case.r0_noise_factor=0.25 io.data_path=$DATA_ROOT/relaxed/
14 | python main.py config=cases/rpf.yaml case.dim=3 case.dx=0.05 solver.dt=0.001 solver.t_end=2050 case.state0_path=$DATA_ROOT/relaxed/rpf_3_0.05_123.h5 io.data_path=$DATA_ROOT/raw/3D_RPF_8000_10kevery100/ eos.p_bg_factor=0.02
15 | python gen_dataset.py --src_dir=$DATA_ROOT/raw/3D_RPF_8000_10kevery100/ --dst_dir=$DATA_ROOT/datasets/3D_RPF_8000_10kevery100/ --split=2_1_1 --skip_first_n_frames=498
16 |
17 | # dt_coarse = 0.001 * 100 = 0.1
18 | # to get 20k samples, we simulate for t = 20k * dt_coarse = 2000 (+50 for equilibration)
19 |
20 | # Also, for a fast particle (vel=1) to cross the box once (length=1) it takes t=1.
21 | # => it takes 10 steps to cross the box with dt_coarse = 0.1
22 | # with 20k samples for train+valid+test, the box is crossed 2000 times
--------------------------------------------------------------------------------
/data_gen/lagrangebench_data/dataset_tgv.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # CUDA_VISIBLE_DEVICES=0 nohup ./scripts/dataset_tgv.sh 2>&1 &
3 |
4 | DATA_ROOT=/tmp/lagrangebench_data
5 |
6 | ###### 2D TGV
7 | for seed in {0..199}
8 | do
9 | echo "Run with seed = $seed"
10 | python main.py config=cases/tgv.yaml seed=$seed case.dim=2 case.dx=0.02 case.mode=rlx solver.tvf=1.0 case.r0_noise_factor=0.25 io.data_path=$DATA_ROOT/relaxed/
11 | python main.py config=cases/tgv.yaml seed=$seed case.dim=2 case.dx=0.02 solver.dt=0.0004 solver.t_end=5 case.state0_path=$DATA_ROOT/relaxed/tgv_2_0.02_$seed.h5 io.data_path=$DATA_ROOT/raw/2D_TGV_2500_10kevery100/
12 | done
13 | python gen_dataset.py --src_dir=$DATA_ROOT/raw/2D_TGV_2500_10kevery100/ --dst_dir=$DATA_ROOT/datasets/2D_TGV_2500_10kevery100/ --split=2_1_1
14 |
15 | ###### 3D TGV
16 | for seed in {0..399}
17 | do
18 | echo "Run with seed = $seed"
19 | python main.py config=cases/tgv.yaml seed=$seed case.dim=3 case.dx=0.314159265 case.mode=rlx solver.tvf=1.0 case.r0_noise_factor=0.25 io.data_path=$DATA_ROOT/relaxed/ eos.p_bg_factor=0.01
20 | python main.py config=cases/tgv.yaml seed=$seed case.dim=3 case.dx=0.314159265 solver.dt=0.005 solver.t_end=30 case.state0_path=$DATA_ROOT/relaxed/tgv_3_0.314159265_$seed.h5 io.data_path=$DATA_ROOT/raw/3D_TGV_8000_10kevery100/ case.viscosity=0.02
21 | done
22 | python gen_dataset.py --src_dir=$DATA_ROOT/raw/3D_TGV_8000_10kevery100/ --dst_dir=$DATA_ROOT/datasets/3D_TGV_8000_10kevery100/ --split=2_1_1
23 |
--------------------------------------------------------------------------------
/data_gen/lagrangebench_data/gen_dataset.py:
--------------------------------------------------------------------------------
1 | """Script for generating ML datasets from h5 simulation frames"""
2 |
3 | import argparse
4 | import json
5 | import os
6 |
7 | import h5py
8 | import numpy as np
9 | from jax import vmap
10 | from jax_sph.io_state import read_h5, write_h5
11 | from jax_sph.jax_md import space
12 | from omegaconf import OmegaConf
13 |
14 |
15 | def write_h5_frame_for_visualization(state_dict, file_path_h5):
16 | path_file_vis = file_path_h5[:-3] + "_vis.h5"
17 | print("writing to", path_file_vis)
18 | write_h5(state_dict, path_file_vis)
19 | print("done")
20 |
21 |
22 | def single_h5_files_to_h5_dataset(args):
23 | """Transform a set of .h5 files to a single .h5 dataset file
24 |
25 | Args:
26 | src_dir: source directory containing other directories, each with .h5 files
27 | corresponding to a trajectory
28 | dst_dir: destination directory where three files will be written: train.h5,
29 | valid.h5, and test.h5
30 | split: string of three integers separated by underscores, e.g. "80_10_10"
31 | """
32 |
33 | os.makedirs(args.dst_dir, exist_ok=True)
34 |
35 | # list only directories in a root with files and directories
36 | dirs = os.listdir(args.src_dir)
37 | dirs = [d for d in dirs if os.path.isdir(os.path.join(args.src_dir, d))]
38 | # order by seed value
39 | dirs = sorted(dirs, key=lambda x: int(x.split("_")[3]))
40 |
41 | splits_array = np.array([int(s) for s in args.split.split("_")])
42 | splits_sum = splits_array.sum()
43 |
44 | if len(dirs) == 1: # split one long trajectory into train, valid, and test
45 | files = os.listdir(os.path.join(args.src_dir, dirs[0]))
46 | files = [f for f in files if (".h5" in f)]
47 | files = sorted(files, key=lambda x: int(x.split("_")[1][:-3]))
48 | files = files[args.skip_first_n_frames :: args.slice_every_nth_frame]
49 |
50 | num_eval = np.ceil(splits_array[1] / splits_sum * len(files)).astype(int)
51 | # at least one validation and one testing trajectory
52 | splits_trajs = np.cumsum([0, len(files) - 2 * num_eval, num_eval, num_eval])
53 |
54 | num_trajs_train = num_trajs_test = 1
55 |
56 | sequence_length_train, sequence_length_test = splits_trajs[1] - 1, num_eval - 1
57 | else: # multiple trajectories
58 | num_eval = np.ceil(splits_array[1] / splits_sum * len(dirs)).astype(int)
59 | # at least one validation and one testing trajectory
60 | splits_trajs = np.cumsum([0, len(dirs) - 2 * num_eval, num_eval, num_eval])
61 |
62 | num_trajs_train, num_trajs_test = len(dirs) - 2 * num_eval, num_eval
63 |
64 | # sequence_length should be after subsampling every nth trajectory
65 | # and "-1" because of the last target position (see GNS dataset format)
66 | files_per_traj = len(os.listdir(os.path.join(args.src_dir, dirs[0])))
67 | sequence_length_train = sequence_length_test = files_per_traj - 1
68 |
69 | for i, split in enumerate(["train", "valid", "test"]):
70 | hf = h5py.File(os.path.join(args.dst_dir, f"{split}.h5"), "w")
71 |
72 | if len(dirs) == 1: # one long trajectory
73 | position = []
74 | traj_path = os.path.join(args.src_dir, dirs[0])
75 |
76 | for j, filename in enumerate(files[splits_trajs[i] : splits_trajs[i + 1]]):
77 | file_path_h5 = os.path.join(traj_path, filename)
78 | state = read_h5(file_path_h5, array_type="numpy")
79 | r = state["r"]
80 | tag = state["tag"]
81 |
82 | if "ldc" in args.src_dir.lower(): # remove outer walls in lid-driven
83 | L, H = 1.0, 1.0
84 | cfg = OmegaConf.load(os.path.join(traj_path, "config.yaml"))
85 | mask_bottom = np.where(r[:, 1] < 2 * cfg.case.dx, False, True)
86 | mask_lid = np.where(r[:, 1] > H + 4 * cfg.case.dx, False, True)
87 | mask_left = np.where(
88 | ((r[:, 0] < 2 * cfg.case.dx) * (tag == 1)), False, True
89 | )
90 | mask_right = np.where(
91 | (r[:, 0] > L + 4 * cfg.case.dx) * (tag == 1), False, True
92 | )
93 | mask = mask_bottom * mask_lid * mask_left * mask_right
94 |
95 | r = r[mask]
96 | tag = tag[mask]
97 |
98 | if args.is_visualize:
99 | write_h5_frame_for_visualization({"r": r, "tag": tag}, file_path_h5)
100 | position.append(r)
101 |
102 | position = np.stack(position) # (time steps, particles, dim)
103 | particle_type = tag # (particles,)
104 |
105 | traj_str = "00000"
106 | hf.create_dataset(f"{traj_str}/particle_type", data=particle_type)
107 | hf.create_dataset(
108 | f"{traj_str}/position",
109 | data=position,
110 | dtype=np.float32,
111 | compression="gzip",
112 | )
113 |
114 | else: # multiple trajectories
115 | for j, dir in enumerate(dirs[splits_trajs[i] : splits_trajs[i + 1]]):
116 | traj_path = os.path.join(args.src_dir, dir)
117 | files = os.listdir(traj_path)
118 | files = [f for f in files if (".h5" in f)]
119 | files = sorted(files, key=lambda x: int(x.split("_")[1][:-3]))
120 | files = files[args.skip_first_n_frames :: args.slice_every_nth_frame]
121 |
122 | position = []
123 | for k, filename in enumerate(files):
124 | file_path_h5 = os.path.join(traj_path, filename)
125 | state = read_h5(file_path_h5, array_type="numpy")
126 | r = state["r"]
127 | tag = state["tag"]
128 |
129 | if "db" in args.src_dir.lower(): # remove outer walls in dam break
130 | L, H = 5.366, 2.0
131 | cfg = OmegaConf.load(os.path.join(traj_path, "config.yaml"))
132 | mask_bottom = np.where(r[:, 1] < 2 * cfg.case.dx, False, True)
133 | mask_lid = np.where(r[:, 1] > H + 4 * cfg.case.dx, False, True)
134 | mask_left = np.where(
135 | ((r[:, 0] < 2 * cfg.case.dx) * (tag == 1)), False, True
136 | )
137 | mask_right = np.where(
138 | (r[:, 0] > L + 4 * cfg.case.dx) * (tag == 1), False, True
139 | )
140 | mask = mask_bottom * mask_lid * mask_left * mask_right
141 |
142 | r = r[mask]
143 | tag = tag[mask]
144 |
145 | if args.is_visualize:
146 | write_h5_frame_for_visualization(
147 | {"r": r, "tag": tag}, file_path_h5
148 | )
149 | position.append(r)
150 | position = np.stack(position) # (time steps, particles, dim)
151 | particle_type = tag # (particles,)
152 |
153 | traj_str = str(j).zfill(5)
154 | hf.create_dataset(f"{traj_str}/particle_type", data=particle_type)
155 | hf.create_dataset(
156 | f"{traj_str}/position",
157 | data=position,
158 | dtype=np.float32,
159 | compression="gzip",
160 | )
161 |
162 | hf.close()
163 | print(f"Finished {args.src_dir} {split} with {j+1} entries!")
164 | print(f"Sample positions shape {position.shape}")
165 |
166 | # metadata
167 | # Compatible with the lagrangebench metadata.json files
168 | cfg = OmegaConf.load(os.path.join(traj_path, "config.yaml"))
169 |
170 | metadata = {
171 | "case": cfg.case.name.upper(),
172 | "solver": cfg.solver.name,
173 | "density_evolution": cfg.solver.density_evolution,
174 | "dim": cfg.case.dim,
175 | "dx": cfg.case.dx,
176 | "dt": cfg.solver.dt,
177 | "t_end": cfg.solver.t_end,
178 | "viscosity": cfg.case.viscosity,
179 | "p_bg_factor": cfg.eos.p_bg_factor,
180 | "g_ext_magnitude": cfg.case.g_ext_magnitude,
181 | "artificial_alpha": cfg.solver.artificial_alpha,
182 | "free_slip": cfg.solver.free_slip,
183 | "write_every": cfg.io.write_every,
184 | "is_bc_trick": cfg.solver.is_bc_trick,
185 | "sequence_length_train": int(sequence_length_train),
186 | "num_trajs_train": int(num_trajs_train),
187 | "sequence_length_test": int(sequence_length_test),
188 | "num_trajs_test": int(num_trajs_test),
189 | "num_particles_max": cfg.case.num_particles_max,
190 | "periodic_boundary_conditions": list(cfg.case.pbc),
191 | "bounds": np.array(cfg.case.bounds).tolist(),
192 | }
193 | x = 1.45 * cfg.case["dx"] # around 1.5 dx
194 | x = np.format_float_positional(
195 | x, precision=2, unique=False, fractional=False, trim="k"
196 | )
197 | metadata["default_connectivity_radius"] = float(x)
198 |
199 | with open(os.path.join(args.dst_dir, "metadata.json"), "w") as f:
200 | json.dump(metadata, f)
201 |
202 |
203 | def compute_statistics_h5(args):
204 | """Compute the mean and std of a h5 dataset files"""
205 |
206 | # metadata
207 | with open(os.path.join(args.dst_dir, "metadata.json"), "r") as f:
208 | metadata = json.load(f)
209 |
210 | # apply PBC in all directions or not at all
211 | if np.array(metadata["periodic_boundary_conditions"]).any():
212 | box = np.array(metadata["bounds"])
213 | box = box[:, 1] - box[:, 0]
214 | displacement_fn, _ = space.periodic(side=box)
215 | else:
216 | displacement_fn, _ = space.free()
217 |
218 | displacement_fn_sets = vmap(vmap(displacement_fn, in_axes=(0, 0)))
219 |
220 | vels, accs = [], []
221 | vels_sq, accs_sq = [], []
222 | vel_mean = acc_mean = 0.0 # to fix "F821 Undefined name ..." ruff error
223 | for loop in ["mean", "std"]:
224 | for split in ["train", "valid", "test"]:
225 | hf = h5py.File(os.path.join(args.dst_dir, f"{split}.h5"), "r")
226 |
227 | for _, v in hf.items():
228 | tag = v.get("particle_type")[:]
229 | r = v.get("position")[:][:, tag == 0] # only fluid ("0") particles
230 |
231 | # The velocity and acceleration computation is based on an inversion of
232 | # Semi-Implicit Euler: vel_t = r_t - r_(t-1), acc_t = vel_(t+1) - vel_t (per coarse step)
233 | vel = displacement_fn_sets(r[1:], r[:-1])
234 | if loop == "mean":
235 | vels.append(vel.mean((0, 1)))
236 | accs.append((vel[1:] - vel[:-1]).mean((0, 1)))
237 | elif loop == "std":
238 | centered_vel = vel - vel_mean
239 | vels_sq.append(np.square(centered_vel).mean((0, 1)))
240 | centered_acc = vel[1:] - vel[:-1] - acc_mean
241 | accs_sq.append(np.square(centered_acc).mean((0, 1)))
242 |
243 | hf.close()
244 |
245 | if loop == "mean":
246 | vel_mean = np.stack(vels).mean(0)
247 | acc_mean = np.stack(accs).mean(0)
248 | print(f"vel_mean={vel_mean}, acc_mean={acc_mean}")
249 | elif loop == "std":
250 | vel_std = np.stack(vels_sq).mean(0) ** 0.5
251 | acc_std = np.stack(accs_sq).mean(0) ** 0.5
252 | print(f"vel_std={vel_std}, acc_std={acc_std}")
253 |
254 | # stds should not be 0. If they are, set them to 1.
255 | vel_std = np.where(vel_std < 1e-7, 1, vel_std)
256 | acc_std = np.where(acc_std < 1e-7, 1, acc_std)
257 |
258 | metadata["vel_mean"] = vel_mean.tolist()
259 | metadata["vel_std"] = vel_std.tolist()
260 |
261 | metadata["acc_mean"] = acc_mean.tolist()
262 | metadata["acc_std"] = acc_std.tolist()
263 |
264 | with open(os.path.join(args.dst_dir, "metadata.json"), "w") as f:
265 | json.dump(metadata, f)
266 |
267 |
268 | if __name__ == "__main__":
269 | parser = argparse.ArgumentParser()
270 | parser.add_argument("--src_dir", type=str)
271 | parser.add_argument("--dst_dir", type=str)
272 | parser.add_argument("--split", type=str, help="E.g. 3_1_1")
273 | parser.add_argument("--skip_first_n_frames", type=int, default=0)
274 | parser.add_argument("--slice_every_nth_frame", type=int, default=1)
275 | parser.add_argument("--is_visualize", action="store_true")
276 | args = parser.parse_args()
277 |
278 | single_h5_files_to_h5_dataset(args)
279 | compute_statistics_h5(args)
280 |
--------------------------------------------------------------------------------
/data_gen/lagrangebench_data/plot_frame.py:
--------------------------------------------------------------------------------
1 | """Print a frame for visual inspection of the data."""
2 |
3 | import argparse
4 |
5 | import h5py
6 | import matplotlib.pyplot as plt
7 |
8 |
9 | def plot_frame(src_dir, frame):
10 | with h5py.File(src_dir, "r") as f:
11 | tag = f["00000/particle_type"][:]
12 | r = f["00000/position"][frame]
13 |
14 | plt.scatter(r[:, 0], r[:, 1], c=tag)
15 | plt.savefig(f"frame_{frame}.png")
16 |
17 |
18 | if __name__ == "__main__":
19 | parser = argparse.ArgumentParser(
20 | description="Plot a frame for visual inspection of the data."
21 | )
22 | parser.add_argument("--src_dir", type=str, help="Source directory.")
23 | parser.add_argument("--frame", type=int, help="Which frame to plot.")
24 | args = parser.parse_args()
25 |
26 | plot_frame(args.src_dir, args.frame)
27 |
28 | # Example:
29 | # python plot_frame.py --src_dir=datasets/2D_TGV_2500_10kevery100/train.h5 --frame=0
30 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | # -- Project information -----------------------------------------------------
7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8 |
9 | project = "LagrangeBench"
10 | copyright = "2023, Chair of Aerodynamics and Fluid Mechanics, TUM"
11 | author = "Artur Toshev, Gianluca Galletti"
12 |
13 | # read the version from pyproject.toml
14 | import toml
15 |
16 | pyproject = toml.load("../pyproject.toml")
17 | version = pyproject["tool"]["poetry"]["version"]
18 |
19 | # -- Path setup --------------------------------------------------------------
20 |
21 | # If extensions (or modules to document with autodoc) are in another directory,
22 | # add these directories to sys.path here. If the directory is relative to the
23 | # documentation root, use os.path.abspath to make it absolute, like shown here.
24 | #
25 | import os
26 | import sys
27 |
28 | sys.path.insert(0, os.path.abspath(".."))
29 |
30 | import collections
31 |
32 | # -- General configuration ---------------------------------------------------
33 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
34 |
35 | extensions = [
36 | "sphinx.ext.autodoc",
37 | "sphinx.ext.viewcode",
38 | "sphinx.ext.napoleon",
39 | "sphinx.ext.intersphinx",
40 | "sphinx.ext.mathjax",
41 | # to get defaults.py in the documentation
42 | "sphinx_exec_code",
43 | ]
44 |
45 | numfig = True
46 |
47 | templates_path = ["_templates"]
48 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
49 |
50 |
51 | # -- Options for HTML output -------------------------------------------------
52 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
53 |
54 | html_theme = "sphinx_rtd_theme"
55 | html_static_path = ["_static"]
56 |
57 |
58 | # -- Options for autodoc -----------------------------------------------------
59 |
60 | autodoc_default_options = {
61 | "member-order": "bysource",
62 | "special-members": True,
63 | "exclude-members": "__repr__, __str__, __weakref__",
64 | }
65 |
66 |
67 | # -- Options for sphinx-exec-code ---------------------------------------------
68 |
69 | exec_code_working_dir = ".."
70 |
71 |
72 | # drop the docstrings of the undocumented namedtuple attributes
73 | def remove_namedtuple_attrib_docstring(app, what, name, obj, skip, options):
74 | if type(obj) is collections._tuplegetter:
75 | return True
76 | return skip
77 |
78 |
79 | def setup(app):
80 | app.connect("autodoc-skip-member", remove_namedtuple_attrib_docstring)
81 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. LagrangeBench documentation master file, created by
2 | sphinx-quickstart on Fri Aug 18 19:44:58 2023.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | LagrangeBench
7 | =============
8 |
9 | .. image:: https://drive.google.com/thumbnail?id=1rP0pf1KL8iGbly0tA0qthUE_tMDv_9Jp&sz=w1000
10 | :alt: rpf2d.gif
11 |
12 | .. image:: https://drive.google.com/thumbnail?id=1BMGkHj9EYMGUOdsE5QwiJWCTvDNqveHc&sz=w1000
13 | :alt: rpf3d.gif
14 |
15 |
16 | What is ``LagrangeBench``?
17 | --------------------------
18 |
19 | LagrangeBench is a machine learning benchmarking suite for **Lagrangian particle
20 | problems** based on the `JAX <https://github.com/google/jax>`_ library. It provides:
21 |
22 | - **Data loading and preprocessing** utilities for particle data.
23 | - Three different **neighbors search routines**: (a) the original JAX-MD implementation, (b) a
24 | memory-efficient version of the JAX-MD implementation, and (c) a wrapper around the
25 | matscipy implementation that handles a variable number of particles.
26 | - JAX reimplementation of established **graph neural networks**: GNS, SEGNN, EGNN, PaiNN.
27 | - **Training strategies** including random-walk additive noise and the pushforward trick.
28 | - Evaluation tools consisting of **rollout generation** and different **error metrics**:
29 | position MSE, kinetic energy MSE, and Sinkhorn distance for the particle distribution.
30 |
31 |
32 | .. note::
33 |
34 | For more details on LagrangeBench usage check out our `tutorials `_.
35 |
36 |
37 |
38 | Data loading and preprocessing
39 | ------------------------------
40 |
41 | First, we create a dataset class based on ``torch.utils.data.Dataset``.
42 | We then initialize a ``CaseSetupFn`` object taking care of the neighbors search,
43 | preprocessing, and time integration.
44 |
45 | .. code-block:: python
46 |
47 | import lagrangebench
48 |
49 | # Load data
50 | data_train = lagrangebench.RPF2D("train")
51 | data_valid = lagrangebench.RPF2D("valid", extra_seq_length=20)
52 | data_test = lagrangebench.RPF2D("test", extra_seq_length=20)
53 |
54 | # Case setup (preprocessing and graph building)
55 | bounds = np.array(data_train.metadata["bounds"])
56 | box = bounds[:, 1] - bounds[:, 0]
57 | case = lagrangebench.case_builder(
58 | box=box,
59 | metadata=data_train.metadata,
60 | input_seq_length=6,
61 | )
62 |
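63 | Under the hood, ``case`` bundles the functions of the ``CaseSetupFn`` dataclass
64 | (see ``lagrangebench.case_setup.case``). The ``Trainer`` and ``infer`` use them
65 | internally, so they rarely need to be called directly. A rough sketch of what is
66 | available:
67 | 
68 | .. code-block:: python
69 | 
70 |     case.allocate       # build features/targets and allocate the neighbor list
71 |     case.preprocess     # same, but reusing and updating an existing neighbor list
72 |     case.integrate      # semi-implicit Euler step respecting boundary conditions
73 |     case.displacement   # displacement function aware of the boundary conditions
74 | 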
63 |
64 | Models
65 | ------
66 |
67 | Initialize a GNS model and wrap it in a Haiku transform.
68 |
69 | .. code-block:: python
70 |
71 | import haiku as hk
72 |
73 | def gns(x):
74 | return lagrangebench.models.GNS(
75 | particle_dimension=data_train.metadata["dim"],
76 | latent_size=16,
77 | blocks_per_step=2,
78 | num_mp_steps=4,
79 | particle_type_embedding_size=8,
80 | )(x)
81 |
82 | gns = hk.without_apply_rng(hk.transform_with_state(gns))
83 |
84 |
85 | Training
86 | --------
87 |
88 | The ``Trainer`` provides a convenient way to train a model.
89 |
90 | .. code-block:: python
91 |
92 | trainer = lagrangebench.Trainer(
93 | model=gns,
94 | case=case,
95 | data_train=data_train,
96 | data_valid=data_valid,
97 | cfg_eval={"n_rollout_steps": 20, "train": {"metrics": ["mse"]}},
98 | input_seq_length=6
99 | )
100 |
101 | # Train for 25000 steps
102 | params, state, _ = trainer.train(step_max=25000)
103 |
104 |
105 | Evaluation
106 | ----------
107 |
108 | When training is done, we can evaluate the model on the test set.
109 |
110 | .. code-block:: python
111 |
112 | metrics = lagrangebench.infer(
113 | gns,
114 | case,
115 | data_test,
116 | params,
117 | state,
118 | cfg_eval_infer={"metrics": ["mse", "sinkhorn", "e_kin"]},
119 | n_rollout_steps=20,
120 | )
121 |
122 |
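123 | If ``metrics`` follows the per-rollout layout produced by ``eval_rollout`` (one
124 | entry per evaluated trajectory), it can be reduced to scalar summaries with
125 | ``averaged_metrics`` from ``lagrangebench.evaluate``. A small sketch, assuming
126 | that layout:
127 | 
128 | .. code-block:: python
129 | 
130 |     from lagrangebench.evaluate import averaged_metrics
131 | 
132 |     # mean and standard deviation of each metric across the evaluated rollouts
133 |     print(averaged_metrics(metrics))
134 | 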
123 | .. toctree::
124 | :maxdepth: 2
125 | :caption: Getting Started
126 |
127 | pages/tutorial
128 | pages/defaults
129 | pages/baselines
130 |
131 | .. toctree::
132 | :maxdepth: 2
133 | :caption: API
134 |
135 | pages/data
136 | pages/case_setup
137 | pages/models
138 | pages/train
139 | pages/evaluate
140 | pages/utils
141 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/pages/baselines.rst:
--------------------------------------------------------------------------------
1 | Baselines
2 | ===================================
3 |
4 | The table below provides inference performance and baseline results on all LagrangeBench datasets.
5 | Runtimes are evaluated on an Nvidia A6000 48GB GPU.
6 |
7 | .. note::
8 |
9 |     A discussion of the results and the hyperparameters can be found in the full paper `"LagrangeBench: A Lagrangian Fluid Mechanics Benchmarking Suite" `_.
10 |
11 |
12 | .. csv-table:: Runtime performance and baseline results.
13 |    :header: "Dataset", "Model", "#Params", "Forward [ms]", "MSE5", "MSE20"
14 | 
15 |    "TGV 2D (2.5K)", "GNS-5-64", "161K", "1.4", "6.4e-7", "9.6e-6"
16 |    "", "GNS-10-128", "1.2M", "5.3", "3.9e-7", "6.6e-6"
17 |    "", "SEGNN-5-64", "183K", "9.8", "3.8e-7", "6.5e-6"
18 |    "", "SEGNN-10-64", "360K", "20.2", "2.4e-7", "4.4e-6"
19 |    "RPF 2D (3.2K)", "GNS-5-64", "161K", "2.1", "4.0e-7", "9.8e-6"
20 |    "", "GNS-10-128", "1.2M", "6.7", "1.1e-7", "3.3e-6"
21 |    "", "SEGNN-5-64", "183K", "15.1", "1.3e-7", "4.0e-6"
22 |    "", "SEGNN-10-64", "360K", "29.7", "1.3e-7", "4.0e-6"
23 |    "", "EGNN-5-128", "663K", "60.8", "unstable", "unstable"
24 |    "", "PaiNN-5-128", "1.0M", "9.1", "3.0e-6", "7.2e-5"
25 |    "LDC 2D (2.7K)", "GNS-5-64", "161K", "1.5", "2.0e-6", "1.7e-5"
26 |    "", "GNS-10-128", "1.2M", "5.7", "6.4e-7", "1.4e-5"
27 |    "", "SEGNN-5-64", "183K", "10.0", "9.9e-7", "1.7e-5"
28 |    "", "SEGNN-10-64", "360K", "21.1", "1.4e-6", "2.5e-5"
29 |    "DAM 2D (5.7K)", "GNS-5-64", "161K", "3.8", "2.1e-6", "6.3e-5"
30 |    "", "GNS-10-128", "1.2M", "11.9", "1.3e-6", "3.3e-5"
31 |    "", "SEGNN-5-64", "183K", "28.8", "2.6e-6", "1.4e-4"
32 |    "", "SEGNN-10-64", "360K", "59.2", "1.9e-6", "1.1e-4"
33 |    "TGV 3D (8.0K)", "GNS-5-64", "161K", "8.4", "3.8e-4", "8.3e-3"
34 |    "", "GNS-10-128", "1.2M", "30.5", "2.1e-4", "5.8e-3"
35 |    "", "SEGNN-5-64", "183K", "79.4", "3.1e-4", "7.7e-3"
36 |    "", "SEGNN-10-64", "360K", "154.3", "1.7e-4", "5.2e-3"
37 |    "RPF 3D (8.0K)", "GNS-5-64", "161K", "8.4", "1.3e-6", "5.2e-5"
38 |    "", "GNS-10-128", "1.2M", "30.5", "3.3e-7", "1.9e-5"
39 |    "", "SEGNN-5-64", "183K", "79.4", "6.6e-7", "3.1e-5"
40 |    "", "SEGNN-10-64", "360K", "154.3", "3.0e-7", "1.8e-5"
41 |    "", "EGNN-5-128", "663K", "250.7", "unstable", "unstable"
42 |    "", "PaiNN-5-128", "1.0M", "43.0", "1.8e-5", "3.6e-4"
43 |    "LDC 3D (8.2K)", "GNS-5-64", "161K", "8.6", "1.7e-6", "5.7e-5"
44 |    "", "GNS-10-128", "1.2M", "32.0", "7.4e-7", "4.0e-5"
45 |    "", "SEGNN-5-64", "183K", "81.2", "1.2e-6", "4.8e-5"
46 |    "", "SEGNN-10-64", "360K", "161.2", "9.4e-7", "4.4e-5"
47 | 
--------------------------------------------------------------------------------
/docs/pages/case_setup.rst:
--------------------------------------------------------------------------------
1 | Case Setup
2 | ===================================
3 |
4 | Case
5 | ----
6 | .. automodule:: lagrangebench.case_setup.case
7 | :members:
8 |
9 | Featurizer
10 | ----------
11 | .. automodule:: lagrangebench.case_setup.features
12 | :members:
13 |
--------------------------------------------------------------------------------
/docs/pages/data.rst:
--------------------------------------------------------------------------------
1 | Data
2 | ===================================
3 |
4 | Data
5 | ----
6 | .. automodule:: lagrangebench.data.data
7 | :members:
8 |
9 | Utils
10 | -----
11 | .. automodule:: lagrangebench.data.utils
12 | :members:
13 |
--------------------------------------------------------------------------------
/docs/pages/defaults.rst:
--------------------------------------------------------------------------------
1 | Defaults
2 | ===================================
3 |
4 |
5 |
6 | .. exec_code::
7 | :hide_code:
8 | :linenos_output:
9 | :language_output: python
10 | :caption: LagrangeBench default values
11 |
12 |
13 | with open("lagrangebench/defaults.py", "r") as file:
14 | defaults_full = file.read()
15 |
16 | # parse defaults: remove imports, only keep the set_defaults function
17 |
18 | defaults_full = defaults_full.split("\n")
19 |
20 | # remove imports
21 | defaults_full = [line for line in defaults_full if not line.startswith("import")]
22 | defaults_full = [line for line in defaults_full if len(line.replace(" ", "")) > 0]
23 |
24 | # remove other functions
25 | keep = False
26 | defaults = []
27 | for i, line in enumerate(defaults_full):
28 | if line.startswith("def"):
29 | if "set_defaults" in line:
30 | keep = True
31 | else:
32 | keep = False
33 |
34 | if keep:
35 | defaults.append(line)
36 |
37 | # remove function declaration and return
38 | defaults = defaults[2:-2]
39 |
40 | # remove indent
41 | defaults = [line[4:] for line in defaults]
42 |
43 |
44 | print("\n".join(defaults))
45 |
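46 | A subset of these defaults can be overridden before being passed to the library,
47 | which merges partial configs with the values above. A minimal sketch using
48 | ``OmegaConf`` (the keys follow the listing above):
49 | 
50 | .. code-block:: python
51 | 
52 |     from omegaconf import OmegaConf
53 | 
54 |     from lagrangebench.defaults import defaults
55 | 
56 |     # override a subset of the defaults; unspecified keys keep their values
57 |     overrides = OmegaConf.create(
58 |         {"train": {"batch_size": 2}, "eval": {"n_rollout_steps": 50}}
59 |     )
60 |     cfg = OmegaConf.merge(defaults, overrides)
61 |     print(cfg.train.batch_size, cfg.eval.n_rollout_steps)  # 2 50
62 | 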
--------------------------------------------------------------------------------
/docs/pages/evaluate.rst:
--------------------------------------------------------------------------------
1 | Evaluate
2 | ===================================
3 |
4 | Rollout
5 | --------
6 | .. automodule:: lagrangebench.evaluate.rollout
7 | :members:
8 |
9 | Metrics
10 | -------
11 | .. automodule:: lagrangebench.evaluate.metrics
12 | :members:
13 |
14 | Utils
15 | -----
16 | .. automodule:: lagrangebench.evaluate.utils
17 | :members:
18 |
--------------------------------------------------------------------------------
/docs/pages/models.rst:
--------------------------------------------------------------------------------
1 | Models
2 | ===================================
3 |
4 | Base Class
5 | ----------
6 | .. automodule:: lagrangebench.models.base
7 | :members:
8 |
9 | GNS
10 | ---
11 | .. automodule:: lagrangebench.models.gns
12 | :members:
13 |
14 | SEGNN
15 | -----
16 | .. automodule:: lagrangebench.models.segnn
17 | :members:
18 |
19 | EGNN
20 | ----
21 | .. automodule:: lagrangebench.models.egnn
22 | :members:
23 |
24 | PaiNN
25 | -----
26 | .. automodule:: lagrangebench.models.painn
27 | :members:
28 |
29 | Linear
30 | ------
31 | .. automodule:: lagrangebench.models.linear
32 | :members:
33 |
34 | Utils
35 | -----
36 | .. automodule:: lagrangebench.models.utils
37 | :members:
38 | :exclude-members: __getnewargs__, __new__, __repr__
39 |
--------------------------------------------------------------------------------
/docs/pages/train.rst:
--------------------------------------------------------------------------------
1 | Train
2 | ===================================
3 |
4 | Trainer
5 | -------
6 | .. automodule:: lagrangebench.train.trainer
7 | :members:
8 | :exclude-members: __init__, __delattr__, __setattr__, __hash__, __eq__, __repr__, __weakref__
9 |
10 | Strategies
11 | ----------
12 | .. automodule:: lagrangebench.train.strats
13 | :members:
14 |
--------------------------------------------------------------------------------
/docs/pages/tutorial.rst:
--------------------------------------------------------------------------------
1 | Tutorials
2 | ===================================
3 |
4 | - Basics of LagrangeBench: `Training GNS on the 2D Taylor Green vortex `_ |tutorial|
5 | - Stats and specifics of our data: `Datasets overview `_ |datasets|
6 | - How to include datasets from other works: `Working with other datasets `_ |gns|
7 |
8 | .. |tutorial| raw:: html
9 |
10 |
11 |
12 | .. |datasets| raw:: html
13 |
14 |
15 |
16 | .. |gns| raw:: html
17 |
18 |
--------------------------------------------------------------------------------
/docs/pages/utils.rst:
--------------------------------------------------------------------------------
1 | Utils and Defaults
2 | ===================================
3 |
4 | Utils
5 | -----
6 | .. automodule:: lagrangebench.utils
7 | :members:
8 | :exclude-members: __init__, __delattr__, __setattr__, __hash__, __eq__, __repr__, __weakref__
9 |
10 |
11 | Defaults
12 | --------
13 | .. automodule:: lagrangebench.defaults
14 | :members:
15 | :exclude-members: __init__, __delattr__, __setattr__, __hash__, __eq__, __repr__, __weakref__
16 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cpu
2 |
3 | cloudpickle
4 | dm_haiku>=0.0.10
5 | e3nn_jax==0.20.3
6 | h5py
7 | jax-sph>=0.0.3
8 | jax[cpu]==0.4.29
9 | jmp>=0.0.4
10 | jraph>=0.0.6.dev0
11 | matscipy>=0.8.0
12 | omegaconf>=2.3.0
13 | optax>=0.1.7
14 | ott-jax>=0.4.2
15 | pyvista
16 | PyYAML
17 | sphinx==7.2.6
18 | sphinx-exec-code
19 | sphinx-rtd-theme==1.3.0
20 | toml>=0.10.2
21 | torch==2.3.1+cpu
22 | wandb
23 | wget
24 |
--------------------------------------------------------------------------------
/download_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Download datasets from Zenodo
3 | # Usage:
4 | #   bash download_data.sh <dataset_name> <output_dir>
5 | # Usage:
6 | # bash download_data.sh all datasets/
7 |
8 | declare -A datasets
9 | datasets["tgv_2d"]="2D_TGV_2500_10kevery100.zip"
10 | datasets["rpf_2d"]="2D_RPF_3200_20kevery100.zip"
11 | datasets["ldc_2d"]="2D_LDC_2708_10kevery100.zip"
12 | datasets["dam_2d"]="2D_DAM_5740_20kevery100.zip"
13 | datasets["tgv_3d"]="3D_TGV_8000_10kevery100.zip"
14 | datasets["rpf_3d"]="3D_RPF_8000_10kevery100.zip"
15 | datasets["ldc_3d"]="3D_LDC_8160_10kevery100.zip"
16 |
17 | if [ $# -ne 2 ]; then
18 |     echo "Usage: bash download_data.sh <dataset_name> <output_dir>"
19 | exit 1
20 | fi
21 |
22 | DATASET_NAME="$1"
23 | OUTPUT_DIR="$2"
24 | ZENODO_PREFIX="https://zenodo.org/records/10491868/files/"
25 |
26 | # Check if there is a trailing slash in $OUTPUT_DIR and remove it
27 | if [[ $OUTPUT_DIR == */ ]]; then
28 | OUTPUT_DIR="${OUTPUT_DIR%/}"
29 | echo "Output directory: ${OUTPUT_DIR}"
30 | fi
31 |
32 | # Create output directory if it doesn't exist
33 | if [ ! -d "${OUTPUT_DIR}" ]; then
34 | mkdir -p "${OUTPUT_DIR}"
35 | fi
36 |
37 | # Download the data
38 | if [ "${DATASET_NAME}" == "all" ]; then
39 | echo "Downloading all datasets"
40 | for key in ${!datasets[@]}; do
41 | echo "Downloading ${key}"
42 | wget ${ZENODO_PREFIX}${datasets[${key}]} -P "${OUTPUT_DIR}/"
43 | ZIP_PATH=${OUTPUT_DIR}/${datasets[${key}]}
44 | python3 -c "import zipfile; zipfile.ZipFile('$ZIP_PATH', 'r').extractall('$OUTPUT_DIR')"
45 | rm ${ZIP_PATH}
46 | done
47 | else
48 | echo "Downloading ${DATASET_NAME}"
49 | wget ${ZENODO_PREFIX}${datasets[${DATASET_NAME}]} -P "${OUTPUT_DIR}/"
50 | ZIP_PATH=${OUTPUT_DIR}/${datasets[${DATASET_NAME}]}
51 | python3 -c "import zipfile; zipfile.ZipFile('$ZIP_PATH', 'r').extractall('$OUTPUT_DIR')"
52 | rm ${ZIP_PATH}
53 | fi
54 |
--------------------------------------------------------------------------------
/lagrangebench/__init__.py:
--------------------------------------------------------------------------------
1 | from .case_setup.case import case_builder
2 | from .data import DAM2D, LDC2D, LDC3D, RPF2D, RPF3D, TGV2D, TGV3D, H5Dataset
3 | from .evaluate import infer
4 | from .models import EGNN, GNS, SEGNN, PaiNN
5 | from .train.trainer import Trainer
6 |
7 | __all__ = [
8 | "Trainer",
9 | "infer",
10 | "case_builder",
11 | "models",
12 | "GNS",
13 | "EGNN",
14 | "SEGNN",
15 | "PaiNN",
16 | "data",
17 | "H5Dataset",
18 | "TGV2D",
19 | "TGV3D",
20 | "RPF2D",
21 | "RPF3D",
22 | "LDC2D",
23 | "LDC3D",
24 | "DAM2D",
25 | ]
26 |
27 | __version__ = "0.2.0"
28 |
--------------------------------------------------------------------------------
/lagrangebench/case_setup/__init__.py:
--------------------------------------------------------------------------------
1 | """Case setup manager."""
2 |
3 | from .case import CaseSetupFn, case_builder
4 |
5 | __all__ = [
6 | "CaseSetupFn",
7 | "case_builder",
8 | ]
9 |
--------------------------------------------------------------------------------
/lagrangebench/case_setup/case.py:
--------------------------------------------------------------------------------
1 | """Case setup functions."""
2 |
3 | import warnings
4 | from typing import Callable, Dict, Optional, Tuple, Union
5 |
6 | import jax.numpy as jnp
7 | from jax import Array, jit, lax, vmap
8 | from jax_sph.jax_md import space
9 | from jax_sph.jax_md.dataclasses import dataclass, static_field
10 | from jax_sph.jax_md.partition import NeighborList, NeighborListFormat, neighbor_list
11 | from omegaconf import DictConfig, OmegaConf
12 |
13 | from lagrangebench.data.utils import get_dataset_stats
14 | from lagrangebench.defaults import defaults
15 | from lagrangebench.train.strats import add_gns_noise
16 |
17 | from .features import FeatureDict, TargetDict, physical_feature_builder
18 |
19 | TrainCaseOut = Tuple[Array, FeatureDict, TargetDict, NeighborList]
20 | EvalCaseOut = Tuple[FeatureDict, NeighborList]
21 | SampleIn = Tuple[jnp.ndarray, jnp.ndarray]
22 |
23 | AllocateFn = Callable[[Array, SampleIn, float, int], TrainCaseOut]
24 | AllocateEvalFn = Callable[[SampleIn], EvalCaseOut]
25 |
26 | PreprocessFn = Callable[[Array, SampleIn, float, NeighborList, int], TrainCaseOut]
27 | PreprocessEvalFn = Callable[[SampleIn, NeighborList], EvalCaseOut]
28 |
29 | IntegrateFn = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
30 |
31 |
32 | @dataclass
33 | class CaseSetupFn:
34 |     """Dataclass containing all functions required to set up the case and simulate.
35 |
36 | Attributes:
37 | allocate: AllocateFn, runs the preprocessing without having a NeighborList as
38 | input.
39 |         preprocess: PreprocessFn, takes positions from the dataloader, computes
40 |             velocities, adds random-walk noise if needed, then updates the neighbor
41 |             list, and returns the inputs to the neural network as well as the targets.
42 | allocate_eval: AllocateEvalFn, same as allocate, but without noise addition
43 | and without targets.
44 | preprocess_eval: PreprocessEvalFn, same as allocate_eval, but jit-able.
45 |         integrate: IntegrateFn, semi-implicit Euler integration step respecting
46 |             all boundary conditions.
47 |         displacement: space.DisplacementFn, displacement function aware of boundary
48 |             conditions (periodic or non-periodic).
49 |         normalization_stats: Dict, normalization statistics for input velocities and
50 | output acceleration.
51 | """
52 |
53 | allocate: AllocateFn = static_field()
54 | preprocess: PreprocessFn = static_field()
55 | allocate_eval: AllocateEvalFn = static_field()
56 | preprocess_eval: PreprocessEvalFn = static_field()
57 | integrate: IntegrateFn = static_field()
58 | displacement: space.DisplacementFn = static_field()
59 | normalization_stats: Dict = static_field()
60 |
61 |
62 | def case_builder(
63 | box: Tuple[float, float, float],
64 | metadata: Dict,
65 | input_seq_length: int,
66 | cfg_neighbors: Union[Dict, DictConfig] = defaults.neighbors,
67 | cfg_model: Union[Dict, DictConfig] = defaults.model,
68 | noise_std: float = defaults.train.noise_std,
69 | external_force_fn: Optional[Callable] = None,
70 | dtype: jnp.dtype = defaults.dtype,
71 | ):
72 | """Set up a CaseSetupFn that contains every required function besides the model.
73 |
74 | Inspired by the `partition.neighbor_list` function in JAX-MD.
75 |
76 | The core functions are:
77 | * allocate, allocate memory for the neighbors list.
78 | * preprocess, update the neighbors list.
79 | * integrate, semi-implicit Euler respecting periodic boundary conditions.
80 |
81 | Args:
82 | box: Box xyz sizes of the system.
83 | metadata: Dataset metadata dictionary.
84 | input_seq_length: Length of the input sequence.
85 | cfg_neighbors: Configuration dictionary for the neighbor list.
86 | cfg_model: Configuration dictionary for the model / feature builder.
87 | noise_std: Noise standard deviation.
88 | external_force_fn: External force function.
89 | dtype: Data type.
90 | """
91 | if isinstance(cfg_neighbors, Dict):
92 | cfg_neighbors = OmegaConf.create(cfg_neighbors)
93 | if isinstance(cfg_model, Dict):
94 | cfg_model = OmegaConf.create(cfg_model)
95 |
96 | # if one of the cfg_* arguments has a subset of the default configs, merge them
97 | cfg_neighbors = OmegaConf.merge(defaults.neighbors, cfg_neighbors)
98 | cfg_model = OmegaConf.merge(defaults.model, cfg_model)
99 |
100 | normalization_stats = get_dataset_stats(
101 | metadata, cfg_model.isotropic_norm, noise_std
102 | )
103 |
104 | # apply PBC in all directions or not at all
105 | if jnp.array(metadata["periodic_boundary_conditions"]).any():
106 | displacement_fn, shift_fn = space.periodic(side=jnp.array(box))
107 | else:
108 | displacement_fn, shift_fn = space.free()
109 |
110 | displacement_fn_set = vmap(displacement_fn, in_axes=(0, 0))
111 |
112 | if cfg_neighbors.multiplier < 1.25:
113 | warnings.warn(
114 | f"cfg_neighbors.multiplier={cfg_neighbors.multiplier} < 1.25 is very low. "
115 | "Be especially cautious if you batch training and/or inference as "
116 | "reallocation might be necessary based on different overflow conditions. "
117 | "See https://github.com/tumaer/lagrangebench/pull/20#discussion_r1443811262"
118 | )
119 |
120 | neighbor_fn = neighbor_list(
121 | displacement_fn,
122 | jnp.array(box),
123 | backend=cfg_neighbors.backend,
124 | r_cutoff=metadata["default_connectivity_radius"],
125 | capacity_multiplier=cfg_neighbors.multiplier,
126 | mask_self=False,
127 | format=NeighborListFormat.Sparse,
128 | num_particles_max=metadata["num_particles_max"],
129 | pbc=metadata["periodic_boundary_conditions"],
130 | )
131 |
132 | feature_transform = physical_feature_builder(
133 | bounds=metadata["bounds"],
134 | normalization_stats=normalization_stats,
135 | connectivity_radius=metadata["default_connectivity_radius"],
136 | displacement_fn=displacement_fn,
137 | pbc=metadata["periodic_boundary_conditions"],
138 | magnitude_features=cfg_model.magnitude_features,
139 | external_force_fn=external_force_fn,
140 | )
141 |
142 | def _compute_target(pos_input: jnp.ndarray) -> TargetDict:
143 | # displacement(r1, r2) = r1-r2 # without PBC
144 |
145 | current_velocity = displacement_fn_set(pos_input[:, 1], pos_input[:, 0])
146 | next_velocity = displacement_fn_set(pos_input[:, 2], pos_input[:, 1])
147 | current_acceleration = next_velocity - current_velocity
148 |
149 | acc_stats = normalization_stats["acceleration"]
150 | normalized_acceleration = (
151 | current_acceleration - acc_stats["mean"]
152 | ) / acc_stats["std"]
153 |
154 | vel_stats = normalization_stats["velocity"]
155 | normalized_velocity = (next_velocity - vel_stats["mean"]) / vel_stats["std"]
156 | return {
157 | "acc": normalized_acceleration,
158 | "vel": normalized_velocity,
159 | "pos": pos_input[:, -1],
160 | }
161 |
162 | def _preprocess(
163 | sample: Tuple[jnp.ndarray, jnp.ndarray],
164 | neighbors: Optional[NeighborList] = None,
165 | is_allocate: bool = False,
166 | mode: str = "train",
167 | **kwargs, # key, noise_std, unroll_steps
168 | ) -> Union[TrainCaseOut, EvalCaseOut]:
169 | pos_input = jnp.asarray(sample[0], dtype=dtype)
170 | particle_type = jnp.asarray(sample[1])
171 |
172 | if mode == "train":
173 | key, noise_std = kwargs["key"], kwargs["noise_std"]
174 | unroll_steps = kwargs["unroll_steps"]
175 | if pos_input.shape[1] > 1:
176 | key, pos_input = add_gns_noise(
177 | key, pos_input, particle_type, input_seq_length, noise_std, shift_fn
178 | )
179 |
180 | # allocate the neighbor list
181 | most_recent_position = pos_input[:, input_seq_length - 1]
182 | num_particles = (particle_type != -1).sum()
183 | if is_allocate:
184 | neighbors = neighbor_fn.allocate(
185 | most_recent_position, num_particles=num_particles
186 | )
187 | else:
188 | neighbors = neighbors.update(
189 | most_recent_position, num_particles=num_particles
190 | )
191 |
192 | # selected features
193 | features = feature_transform(pos_input[:, :input_seq_length], neighbors)
194 |
195 | if mode == "train":
196 | # compute target acceleration. Inverse of postprocessing step.
197 | # the "-2" is needed because we need the most recent position and one before
198 | slice_begin = (0, input_seq_length - 2 + unroll_steps, 0)
199 | slice_size = (pos_input.shape[0], 3, pos_input.shape[2])
200 |
201 | target_dict = _compute_target(
202 | lax.dynamic_slice(pos_input, slice_begin, slice_size)
203 | )
204 | return key, features, target_dict, neighbors
205 | if mode == "eval":
206 | return features, neighbors
207 |
208 | def allocate_fn(key, sample, noise_std=0.0, unroll_steps=0):
209 | return _preprocess(
210 | sample,
211 | key=key,
212 | noise_std=noise_std,
213 | unroll_steps=unroll_steps,
214 | is_allocate=True,
215 | )
216 |
217 | @jit
218 | def preprocess_fn(key, sample, noise_std, neighbors, unroll_steps=0):
219 | return _preprocess(
220 | sample, neighbors, key=key, noise_std=noise_std, unroll_steps=unroll_steps
221 | )
222 |
223 | def allocate_eval_fn(sample):
224 | return _preprocess(sample, is_allocate=True, mode="eval")
225 |
226 | @jit
227 | def preprocess_eval_fn(sample, neighbors):
228 | return _preprocess(sample, neighbors, mode="eval")
229 |
230 | @jit
231 | def integrate_fn(normalized_in, position_sequence):
232 | """Euler integrator to get position shift."""
233 | assert any([key in normalized_in for key in ["pos", "vel", "acc"]])
234 |
235 | if "pos" in normalized_in:
236 | # Zeroth euler step
237 | return normalized_in["pos"]
238 | else:
239 | most_recent_position = position_sequence[:, -1]
240 | if "vel" in normalized_in:
241 | # invert normalization
242 | velocity_stats = normalization_stats["velocity"]
243 | new_velocity = velocity_stats["mean"] + (
244 | normalized_in["vel"] * velocity_stats["std"]
245 | )
246 | elif "acc" in normalized_in:
247 | # invert normalization.
248 | acceleration_stats = normalization_stats["acceleration"]
249 | acceleration = acceleration_stats["mean"] + (
250 | normalized_in["acc"] * acceleration_stats["std"]
251 | )
252 | # Second Euler step
253 | most_recent_velocity = displacement_fn_set(
254 | most_recent_position, position_sequence[:, -2]
255 | )
256 | new_velocity = most_recent_velocity + acceleration # * dt = 1
257 |
258 | # First Euler step
259 | return shift_fn(most_recent_position, new_velocity)
260 |
261 | return CaseSetupFn(
262 | allocate_fn,
263 | preprocess_fn,
264 | allocate_eval_fn,
265 | preprocess_eval_fn,
266 | integrate_fn,
267 | displacement_fn,
268 | normalization_stats,
269 | )
270 |
--------------------------------------------------------------------------------
/lagrangebench/case_setup/features.py:
--------------------------------------------------------------------------------
1 | """Feature extraction utilities."""
2 |
3 | from typing import Callable, Dict, List, Optional
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | from jax import lax, vmap
8 | from jax_sph.jax_md import partition, space
9 |
10 | FeatureDict = Dict[str, jnp.ndarray]
11 | TargetDict = Dict[str, jnp.ndarray]
12 |
13 |
14 | def physical_feature_builder(
15 | bounds: list,
16 | normalization_stats: dict,
17 | connectivity_radius: float,
18 | displacement_fn: Callable,
19 | pbc: List[bool],
20 | magnitude_features: bool = False,
21 | external_force_fn: Optional[Callable] = None,
22 | ) -> Callable:
23 | """Build a physical feature transform function.
24 |
25 | Transform raw coordinates to
26 | - Absolute positions
27 | - Historical velocity sequence
28 | - Velocity magnitudes
29 | - Distance to boundaries
30 | - External force field
31 | - Relative displacement vectors and distances
32 |
33 | Args:
34 | bounds: Each sublist contains the lower and upper bound of a dimension.
35 | normalization_stats: Dict containing mean and std of velocities and targets
36 | connectivity_radius: Radius of the connectivity graph.
37 | displacement_fn: Displacement function.
38 |         pbc: Whether to use periodic boundary conditions.
39 | magnitude_features: Whether to include the magnitude of the velocity.
40 | external_force_fn: Function that returns the external force field (optional).
41 | """
42 | displacement_fn_vmap = vmap(displacement_fn, in_axes=(0, 0))
43 | displacement_fn_dvmap = vmap(displacement_fn_vmap, in_axes=(0, 0))
44 |
45 | velocity_stats = normalization_stats["velocity"]
46 |
47 | def feature_transform(
48 | pos_input: jnp.ndarray,
49 | nbrs: partition.NeighborList,
50 | ) -> FeatureDict:
51 | """Feature engineering.
52 |
53 | Returns:
54 | Dict of features, with possible keys
55 | - "abs_pos", absolute positions
56 | - "vel_hist", historical velocity sequence
57 | - "vel_mag", velocity magnitudes
58 | - "bound", distance to boundaries
59 | - "force", external force field
60 | - "rel_disp", relative displacement vectors
61 |             - "rel_dist", relative distances
62 | """
63 | features = {}
64 |
65 | n_total_points = pos_input.shape[0]
66 | most_recent_position = pos_input[:, -1] # (n_nodes, dim)
67 | # pos_input.shape = (n_nodes, n_timesteps, dim)
68 | velocity_sequence = displacement_fn_dvmap(pos_input[:, 1:], pos_input[:, :-1])
69 |         # Normalized velocity sequence, merging spatial and time axes.
70 | normalized_velocity_sequence = (
71 | velocity_sequence - velocity_stats["mean"]
72 | ) / velocity_stats["std"]
73 | flat_velocity_sequence = normalized_velocity_sequence.reshape(
74 | n_total_points, -1
75 | )
76 |
77 | features["abs_pos"] = pos_input
78 | features["vel_hist"] = flat_velocity_sequence
79 |
80 | if magnitude_features:
81 | # append the magnitude of the velocity of each particle to the node features
82 | velocity_magnitude_sequence = jnp.linalg.norm(
83 | normalized_velocity_sequence, axis=-1
84 | )
85 | features["vel_mag"] = velocity_magnitude_sequence
86 |
87 | if not any(pbc):
88 |             # Normalized clipped distances to lower and upper boundaries.
89 |             # boundaries is an array of shape [num_dimensions, 2], where the
90 |             # second axis provides the lower/upper boundaries.
91 | boundaries = lax.stop_gradient(jnp.array(bounds))
92 |
93 | distance_to_lower_boundary = most_recent_position - boundaries[:, 0][None]
94 | distance_to_upper_boundary = boundaries[:, 1][None] - most_recent_position
95 |
96 |             # concatenate the distances to the lower and upper boundaries
97 | distance_to_boundaries = jnp.concatenate(
98 | [distance_to_lower_boundary, distance_to_upper_boundary], axis=1
99 | )
100 | normalized_clipped_distance_to_boundaries = jnp.clip(
101 | distance_to_boundaries / connectivity_radius, -1.0, 1.0
102 | )
103 | features["bound"] = normalized_clipped_distance_to_boundaries
104 |
105 | if external_force_fn is not None:
106 | external_force_field = vmap(external_force_fn)(most_recent_position)
107 | features["force"] = external_force_field
108 |
109 | # senders and receivers are integers of shape (E,)
110 | receivers, senders = nbrs.idx
111 | features["senders"] = senders
112 | features["receivers"] = receivers
113 |
114 | # Relative displacement and distances normalized to radius (E, dim)
115 | displacement = vmap(displacement_fn)(
116 | most_recent_position[receivers], most_recent_position[senders]
117 | )
118 | normalized_relative_displacements = displacement / connectivity_radius
119 | features["rel_disp"] = normalized_relative_displacements
120 |
121 | normalized_relative_distances = space.distance(
122 | normalized_relative_displacements
123 | )
124 | features["rel_dist"] = normalized_relative_distances[:, None]
125 |
126 | return jax.tree_map(lambda f: f, features)
127 |
128 | return feature_transform
129 |
--------------------------------------------------------------------------------
/lagrangebench/data/__init__.py:
--------------------------------------------------------------------------------
1 | """Datasets and dataloading utils."""
2 |
3 | from .data import DAM2D, LDC2D, LDC3D, RPF2D, RPF3D, TGV2D, TGV3D, H5Dataset
4 |
5 | __all__ = ["H5Dataset", "TGV2D", "TGV3D", "RPF2D", "RPF3D", "LDC2D", "LDC3D", "DAM2D"]
6 |
--------------------------------------------------------------------------------
/lagrangebench/data/utils.py:
--------------------------------------------------------------------------------
1 | """Data utils."""
2 |
3 | from typing import Dict, List
4 |
5 | import jax.numpy as jnp
6 | import numpy as np
7 |
8 |
9 | def get_dataset_stats(
10 | metadata: Dict[str, List[float]],
11 | is_isotropic_norm: bool,
12 | noise_std: float,
13 | ) -> Dict[str, Dict[str, jnp.ndarray]]:
14 | """Return the dataset statistics based on the metadata dictionary.
15 |
16 | Args:
17 | metadata: Dataset metadata dictionary.
18 | is_isotropic_norm:
19 | Whether to shift/scale dimensions equally instead of dimension-wise.
20 | noise_std: Standard deviation of the GNS-style noise.
21 |
22 | Returns:
23 | Dictionary with the dataset statistics.
24 | """
25 | acc_mean = jnp.array(metadata["acc_mean"])
26 | acc_std = jnp.array(metadata["acc_std"])
27 | vel_mean = jnp.array(metadata["vel_mean"])
28 | vel_std = jnp.array(metadata["vel_std"])
29 |
30 | if is_isotropic_norm:
31 | acc_mean = jnp.mean(acc_mean) * jnp.ones_like(acc_mean)
32 | acc_std = jnp.sqrt(jnp.mean(acc_std**2)) * jnp.ones_like(acc_std)
33 | vel_mean = jnp.mean(vel_mean) * jnp.ones_like(vel_mean)
34 | vel_std = jnp.sqrt(jnp.mean(vel_std**2)) * jnp.ones_like(vel_std)
35 |
36 | return {
37 | "acceleration": {
38 | "mean": acc_mean,
39 | "std": jnp.sqrt(acc_std**2 + noise_std**2),
40 | },
41 | "velocity": {
42 | "mean": vel_mean,
43 | "std": jnp.sqrt(vel_std**2 + noise_std**2),
44 | },
45 | }
46 |
47 |
48 | def numpy_collate(batch) -> np.ndarray:
49 | """Collate helper for torch dataloaders."""
50 | # NOTE: to numpy to avoid copying twice (dataloader timeout).
51 | if isinstance(batch[0], np.ndarray):
52 | return np.stack(batch)
53 | if isinstance(batch[0], (tuple, list)):
54 | return type(batch[0])(numpy_collate(samples) for samples in zip(*batch))
55 | else:
56 | return np.asarray(batch)
57 |
--------------------------------------------------------------------------------
/lagrangebench/defaults.py:
--------------------------------------------------------------------------------
1 | """Default lagrangebench configs."""
2 |
3 |
4 | from omegaconf import DictConfig, OmegaConf
5 |
6 |
7 | def set_defaults(cfg: DictConfig = OmegaConf.create({})) -> DictConfig:
8 | """Set default lagrangebench configs."""
9 |
10 | ### global and hardware-related configs
11 |
12 | # configuration file. Either "config" or "load_ckp" must be specified.
13 | # If "config" is specified, "load_ckp" is ignored.
14 | cfg.config = None
15 | # Load checkpointed model from this directory
16 | cfg.load_ckp = None
17 | # One of "train", "infer" or "all" (= both)
18 | cfg.mode = "all"
19 | # random seed
20 | cfg.seed = 0
21 | # data type for preprocessing. One of "float32" or "float64"
22 | cfg.dtype = "float64"
23 | # gpu device. -1 for CPU. Should be specified before importing the library.
24 | cfg.gpu = None
25 | # XLA memory fraction to be preallocated. The JAX default is 0.75.
26 | # Should be specified before importing the library.
27 | cfg.xla_mem_fraction = None
28 |
29 | ### dataset
30 | cfg.dataset = OmegaConf.create({})
31 |
32 | # path to data directory
33 | cfg.dataset.src = None
34 | # dataset name
35 | cfg.dataset.name = None
36 |
37 | ### model
38 | cfg.model = OmegaConf.create({})
39 |
40 |     # model architecture name. gns, segnn, egnn, painn
41 | cfg.model.name = None
42 | # Length of the position input sequence
43 | cfg.model.input_seq_length = 6
44 | # Number of message passing steps
45 | cfg.model.num_mp_steps = 10
46 | # Number of MLP layers
47 | cfg.model.num_mlp_layers = 2
48 | # Hidden dimension
49 | cfg.model.latent_dim = 128
50 | # whether to include velocity magnitude features
51 | cfg.model.magnitude_features = False
52 | # whether to normalize dimensions equally
53 | cfg.model.isotropic_norm = False
54 |
55 | # SEGNN only parameters
56 | # steerable attributes level
57 | cfg.model.lmax_attributes = 1
58 | # Level of the hidden layer
59 | cfg.model.lmax_hidden = 1
60 | # SEGNN normalization. instance, batch, none
61 | cfg.model.segnn_norm = "none"
62 | # SEGNN velocity aggregation. avg or last
63 | cfg.model.velocity_aggregate = "avg"
64 |
65 | ### training
66 | cfg.train = OmegaConf.create({})
67 |
68 | # batch size
69 | cfg.train.batch_size = 1
70 | # max number of training steps
71 | cfg.train.step_max = 500_000
72 | # number of workers for data loading
73 | cfg.train.num_workers = 4
74 | # standard deviation of the GNS-style noise
75 | cfg.train.noise_std = 3.0e-4
76 |
77 | # optimizer
78 | cfg.train.optimizer = OmegaConf.create({})
79 |
80 | # initial learning rate
81 | cfg.train.optimizer.lr_start = 1.0e-4
82 | # final learning rate (after exponential decay)
83 | cfg.train.optimizer.lr_final = 1.0e-6
84 | # learning rate decay rate
85 | cfg.train.optimizer.lr_decay_rate = 0.1
86 | # number of steps to decay learning rate
87 | cfg.train.optimizer.lr_decay_steps = 1.0e5
88 |
89 | # pushforward
90 | cfg.train.pushforward = OmegaConf.create({})
91 |
92 | # At which training step to introduce next unroll stage
93 | cfg.train.pushforward.steps = [-1, 20000, 300000, 400000]
94 | # For how many steps to unroll
95 | cfg.train.pushforward.unrolls = [0, 1, 2, 3]
96 | # Which probability ratio to keep between the unrolls
97 | cfg.train.pushforward.probs = [18, 2, 1, 1]
98 |
99 | # loss weights
100 | cfg.train.loss_weight = OmegaConf.create({})
101 |
102 | # weight for acceleration error
103 | cfg.train.loss_weight.acc = 1.0
104 | # weight for velocity error
105 | cfg.train.loss_weight.vel = 0.0
106 | # weight for position error
107 | cfg.train.loss_weight.pos = 0.0
108 |
109 | ### evaluation
110 | cfg.eval = OmegaConf.create({})
111 |
112 | # number of eval rollout steps. -1 is full rollout
113 | cfg.eval.n_rollout_steps = 20
114 | # whether to use the test or valid split
115 | cfg.eval.test = False
116 | # rollouts directory
117 | cfg.eval.rollout_dir = None
118 |
119 | # configs for validation during training
120 | cfg.eval.train = OmegaConf.create({})
121 |
122 | # number of trajectories to evaluate
123 | cfg.eval.train.n_trajs = 50
124 | # stride for e_kin and sinkhorn
125 | cfg.eval.train.metrics_stride = 10
126 | # batch size
127 | cfg.eval.train.batch_size = 1
128 | # metrics to evaluate
129 | cfg.eval.train.metrics = ["mse"]
130 | # write validation rollouts. One of "none", "vtk", or "pkl"
131 | cfg.eval.train.out_type = "none"
132 |
133 | # configs for inference/testing
134 | cfg.eval.infer = OmegaConf.create({})
135 |
136 | # number of trajectories to evaluate during inference
137 | cfg.eval.infer.n_trajs = -1
138 | # stride for e_kin and sinkhorn
139 | cfg.eval.infer.metrics_stride = 1
140 | # batch size
141 | cfg.eval.infer.batch_size = 2
142 | # metrics for inference
143 | cfg.eval.infer.metrics = ["mse", "e_kin", "sinkhorn"]
144 | # write inference rollouts. One of "none", "vtk", or "pkl"
145 | cfg.eval.infer.out_type = "pkl"
146 |
147 | # number of extrapolation steps during inference
148 | cfg.eval.infer.n_extrap_steps = 0
149 |
150 | ### logging
151 | cfg.logging = OmegaConf.create({})
152 |
153 | # number of steps between loggings
154 | cfg.logging.log_steps = 1000
155 | # number of steps between evaluations and checkpoints
156 | cfg.logging.eval_steps = 10000
157 | # wandb enable
158 | cfg.logging.wandb = False
159 | # wandb project name
160 | cfg.logging.wandb_project = None
161 | # wandb entity name
162 | cfg.logging.wandb_entity = "lagrangebench"
163 | # checkpoint directory
164 | cfg.logging.ckp_dir = "ckp"
165 | # name of training run
166 | cfg.logging.run_name = None
167 |
168 | ### neighbor list
169 | cfg.neighbors = OmegaConf.create({})
170 |
171 | # backend for neighbor list computation
172 | cfg.neighbors.backend = "jaxmd_vmap"
173 | # multiplier for neighbor list capacity
174 | cfg.neighbors.multiplier = 1.25
175 |
176 | return cfg
177 |
178 |
179 | defaults = set_defaults()
180 |
181 |
182 | def check_cfg(cfg: DictConfig):
183 | """Check if the configs are valid."""
184 |
185 | assert cfg.mode in ["train", "infer", "all"]
186 | assert cfg.dtype in ["float32", "float64"]
187 | assert cfg.dataset.src is not None, "dataset.src must be specified."
188 |
189 | assert cfg.model.input_seq_length >= 2, "At least two positions for one past vel."
190 |
191 | pf = cfg.train.pushforward
192 | assert len(pf.steps) == len(pf.unrolls) == len(pf.probs)
193 | assert all([s >= 0 for s in pf.unrolls]), "All unrolls must be non-negative."
194 | assert all([s >= 0 for s in pf.probs]), "All probabilities must be non-negative."
195 | lwv = cfg.train.loss_weight.values()
196 | assert all([w >= 0 for w in lwv]), "All loss weights must be non-negative."
197 | assert sum(lwv) > 0, "At least one loss weight must be non-zero."
198 |
199 | assert cfg.eval.train.n_trajs >= -1
200 | assert cfg.eval.infer.n_trajs >= -1
201 | assert set(cfg.eval.train.metrics).issubset(["mse", "e_kin", "sinkhorn"])
202 | assert set(cfg.eval.infer.metrics).issubset(["mse", "e_kin", "sinkhorn"])
203 | assert cfg.eval.train.out_type in ["none", "vtk", "pkl"]
204 | assert cfg.eval.infer.out_type in ["none", "vtk", "pkl"]
205 |
--------------------------------------------------------------------------------
/lagrangebench/evaluate/__init__.py:
--------------------------------------------------------------------------------
1 | """Evaluation and rollout generation tools."""
2 |
3 | from .metrics import MetricsComputer, MetricsDict, averaged_metrics
4 | from .rollout import eval_rollout, infer
5 |
6 | __all__ = [
7 | "MetricsComputer",
8 | "MetricsDict",
9 | "averaged_metrics",
10 | "infer",
11 | "eval_rollout",
12 | ]
13 |
--------------------------------------------------------------------------------
/lagrangebench/evaluate/metrics.py:
--------------------------------------------------------------------------------
1 | """Metrics for evaluation and testing."""
2 |
3 | import warnings
4 | from collections import defaultdict
5 | from functools import partial
6 | from typing import Callable, Dict, List, Optional
7 |
8 | import jax
9 | import jax.numpy as jnp
10 | import numpy as np
11 | from ott.geometry.geometry import Geometry
12 | from ott.tools.sinkhorn_divergence import sinkhorn_divergence
13 |
14 | MetricsDict = Dict[str, Dict[str, jnp.ndarray]]
15 |
16 |
17 | class MetricsComputer:
18 | """
19 | Metrics between predicted and target rollouts.
20 |
21 | Currently implemented:
22 | * MSE, mean squared error
23 | * MAE, mean absolute error
24 | * Sinkhorn distance, measures the similarity of two particle distributions
25 | * Kinetic energy, physical quantity of interest
26 | """
27 |
28 | METRICS = ["mse", "mae", "sinkhorn", "e_kin"]
29 |
30 | def __init__(
31 | self,
32 | active_metrics: List,
33 | dist_fn: Callable,
34 | metadata: Dict,
35 | input_seq_length: int,
36 | stride: int = 10,
37 | loss_ranges: Optional[List] = None,
38 | ot_backend: str = "ott",
39 | ):
40 | """Init the metric computer.
41 |
42 | Args:
43 | active_metrics: List of metrics to compute.
44 | dist_fn: Distance function.
45 | metadata: Metadata of the dataset.
46 | loss_ranges: List of horizon lengths to compute the loss for.
47 | input_seq_length: Length of the input sequence.
48 | stride: Rollout subsample frequency for e_kin and sinkhorn.
49 | ot_backend: Backend for sinkhorn computation. "ott" or "pot".
50 | """
51 | if active_metrics is None:
52 | active_metrics = []
53 | assert all([hasattr(self, metric) for metric in active_metrics])
54 | assert ot_backend in ["ott", "pot"]
55 |
56 | self._active_metrics = active_metrics
57 | self._dist_fn = dist_fn
58 | self._dist_vmap = jax.vmap(dist_fn, in_axes=(0, 0))
59 | self._dist_dvmap = jax.vmap(self._dist_vmap, in_axes=(0, 0))
60 |
61 | if loss_ranges is None:
62 | loss_ranges = [1, 5, 10, 20, 50, 100]
63 | self._loss_ranges = loss_ranges
64 | self._input_seq_length = input_seq_length
65 | self._stride = stride
66 | self._metadata = metadata
67 | self.ot_backend = ot_backend
68 |
69 | def __call__(
70 | self, pred_rollout: jnp.ndarray, target_rollout: jnp.ndarray
71 | ) -> MetricsDict:
72 | """Compute the metrics between two rollouts.
73 |
74 | Args:
75 | pred_rollout: Predicted rollout.
76 | target_rollout: Target rollout.
77 |
78 | Returns:
79 | Dictionary of metrics.
80 | """
81 | # both rollouts of shape (traj_len - t_window, n_nodes, dim)
82 | target_rollout = jnp.asarray(target_rollout, dtype=pred_rollout.dtype)
83 | metrics = {}
84 | with warnings.catch_warnings():
85 | warnings.simplefilter("ignore")
86 | for metric_name in self._active_metrics:
87 | metric_fn = getattr(self, metric_name)
88 | if metric_name in ["mse", "mae"]:
89 | # full rollout loss
90 | metrics[metric_name] = jax.vmap(metric_fn)(
91 | pred_rollout, target_rollout
92 | )
93 | # shorter horizon losses
94 | for i in self._loss_ranges:
95 | if i < metrics[metric_name].shape[0]:
96 | metrics[f"{metric_name}{i}"] = metrics[metric_name][:i]
97 |
98 | elif metric_name in ["e_kin"]:
99 | dt = self._metadata["dt"] * self._metadata["write_every"]
100 | dx = self._metadata["dx"]
101 | dim = self._metadata["dim"]
102 |
103 | metric_dvmap = jax.vmap(jax.vmap(metric_fn))
104 |
105 | # Ekin of predicted rollout
106 | velocity_rollout = self._dist_dvmap(
107 | pred_rollout[1 :: self._stride],
108 | pred_rollout[0 : -1 : self._stride],
109 | )
110 | e_kin_pred = metric_dvmap(velocity_rollout / dt).sum(1)
111 | e_kin_pred = e_kin_pred * dx**dim
112 |
113 | # Ekin of target rollout
114 | velocity_rollout = self._dist_dvmap(
115 | target_rollout[1 :: self._stride],
116 | target_rollout[0 : -1 : self._stride],
117 | )
118 | e_kin_target = metric_dvmap(velocity_rollout / dt).sum(1)
119 | e_kin_target = e_kin_target * dx**dim
120 |
121 | metrics[metric_name] = {
122 | "predicted": e_kin_pred,
123 | "target": e_kin_target,
124 | "mse": ((e_kin_pred - e_kin_target) ** 2).mean(),
125 | }
126 |
127 | elif metric_name == "sinkhorn":
128 | # vmapping over distance matrix blows up
129 | metrics[metric_name] = jax.lax.scan(
130 | lambda _, x: (None, self.sinkhorn(*x)),
131 | None,
132 | (
133 | pred_rollout[0 :: self._stride],
134 | target_rollout[0 :: self._stride],
135 | ),
136 | )[1]
137 | return metrics
138 |
139 | @partial(jax.jit, static_argnums=(0,))
140 | def mse(self, pred: jnp.ndarray, target: jnp.ndarray) -> float:
141 | """Compute the mean squared error between two rollouts."""
142 | return (self._dist_vmap(pred, target) ** 2).mean()
143 |
144 | @partial(jax.jit, static_argnums=(0,))
145 | def mae(self, pred: jnp.ndarray, target: jnp.ndarray) -> float:
146 | """Compute the mean absolute error between two rollouts."""
147 | return (jnp.abs(self._dist_vmap(pred, target))).mean()
148 |
149 | @partial(jax.jit, static_argnums=(0,))
150 | def sinkhorn(self, pred: jnp.ndarray, target: jnp.ndarray) -> float:
151 | """Compute the sinkhorn distance between two rollouts."""
152 | if self.ot_backend == "ott":
153 | return self._sinkhorn_ott(pred, target)
154 | else:
155 | return self._sinkhorn_pot(pred, target)
156 |
157 | @partial(jax.jit, static_argnums=(0,))
158 | def e_kin(self, frame: jnp.ndarray) -> float:
159 | """Compute the kinetic energy of a frame."""
160 | return jnp.sum(frame**2) # * dx ** 3
161 |
162 | def _sinkhorn_ott(self, pred: jnp.ndarray, target: jnp.ndarray) -> float:
163 | # pairwise distances as cost
164 | loss_matrix_xy = self._distance_matrix(pred, target)
165 | loss_matrix_xx = self._distance_matrix(pred, pred)
166 | loss_matrix_yy = self._distance_matrix(target, target)
167 | return sinkhorn_divergence(
168 | Geometry,
169 | loss_matrix_xy,
170 | loss_matrix_xx,
171 | loss_matrix_yy,
172 | # uniform weights
173 | a=jnp.ones((pred.shape[0],)) / pred.shape[0],
174 | b=jnp.ones((target.shape[0],)) / target.shape[0],
175 | sinkhorn_kwargs={"threshold": 1e-4},
176 | ).divergence
177 |
178 | def _sinkhorn_pot(self, pred: jnp.ndarray, target: jnp.ndarray):
179 |         """Jax-compatible POT implementation of Sinkhorn."""
180 | # equivalent to empirical_sinkhorn_divergence with custom distance computation
181 | sinkhorn_ab = self._custom_empirical_sinkorn_pot(pred, target)
182 | sinkhorn_a = self._custom_empirical_sinkorn_pot(pred, pred)
183 | sinkhorn_b = self._custom_empirical_sinkorn_pot(target, target)
184 | return jnp.asarray(
185 | jnp.clip(sinkhorn_ab - 0.5 * (sinkhorn_a + sinkhorn_b), 0),
186 | dtype=jnp.float32,
187 | )
188 |
189 | def _custom_empirical_sinkorn_pot(self, pred: jnp.ndarray, target: jnp.ndarray):
190 | from ot.bregman import sinkhorn2
191 |
192 | # weights are uniform
193 | a, b = (
194 | jnp.ones((pred.shape[0],)) / pred.shape[0],
195 | jnp.ones((target.shape[0],)) / target.shape[0],
196 | )
197 | loss_matrix = self._distance_matrix(pred, target)
198 | shape = jax.ShapeDtypeStruct((), dtype=jnp.float32)
199 |
200 | # hack to avoid CpuCallback attribute error
201 | def sinkhorn2_(a, b, loss_matrix):
202 | return jnp.array(
203 | sinkhorn2(a, b, loss_matrix, reg=0.1, numItermax=500, stopThr=1e-05),
204 | dtype=jnp.float32,
205 | )
206 |
207 | return jax.pure_callback(
208 | sinkhorn2_,
209 | shape,
210 | a,
211 | b,
212 | loss_matrix,
213 | )
214 |
215 | def _distance_matrix(
216 | self, x: jnp.ndarray, y: jnp.ndarray, squared=True
217 | ) -> jnp.ndarray:
218 | """Euclidean distance matrix (pairwise)."""
219 |
220 | def dist(a, b):
221 | return jnp.sum(self._dist_fn(a, b) ** 2)
222 |
223 |         if not squared:
224 |             dist_sq = dist
225 | 
226 |             def dist(a, b):
227 |                 return jnp.sqrt(dist_sq(a, b))
227 |
228 | return jnp.array(
229 | jax.vmap(lambda a: jax.vmap(lambda b: dist(a, b))(y))(x), dtype=jnp.float32
230 | )
231 |
232 |
233 | def averaged_metrics(eval_metrics: MetricsDict) -> Dict[str, float]:
234 | """Averages the metrics over the rollouts."""
235 | # create a dictionary with the same keys as the metrics, but empty list as values
236 | trajectory_averages = defaultdict(list)
237 | for rollout in eval_metrics.values():
238 | for k, v in rollout.items():
239 | if k == "e_kin":
240 | v = v["mse"]
241 | if k in ["mse", "mae"]:
242 | k = "loss"
243 | trajectory_averages[k].append(jnp.mean(v).item())
244 |
245 |     # mean and std values across rollouts
246 | small_metrics = {}
247 | for k, v in trajectory_averages.items():
248 | small_metrics[f"val/{k}"] = float(np.mean(v))
249 | for k, v in trajectory_averages.items():
250 | small_metrics[f"val/std{k}"] = float(np.std(v))
251 |
252 | return small_metrics
253 |
--------------------------------------------------------------------------------
/lagrangebench/evaluate/rollout.py:
--------------------------------------------------------------------------------
1 | """Evaluation and inference functions for generating rollouts."""
2 |
3 | import os
4 | import pickle
5 | import time
6 | from functools import partial
7 | from typing import Callable, Dict, Iterable, Optional, Tuple, Union
8 |
9 | import haiku as hk
10 | import jax
11 | import jax.numpy as jnp
12 | import jax_sph.jax_md.partition as partition
13 | from jax import jit, vmap
14 | from omegaconf import DictConfig, OmegaConf
15 | from torch.utils.data import DataLoader
16 |
17 | from lagrangebench.data import H5Dataset
18 | from lagrangebench.data.utils import numpy_collate
19 | from lagrangebench.defaults import defaults
20 | from lagrangebench.evaluate.metrics import MetricsComputer, MetricsDict
21 | from lagrangebench.evaluate.utils import write_vtk
22 | from lagrangebench.utils import (
23 | broadcast_from_batch,
24 | broadcast_to_batch,
25 | get_kinematic_mask,
26 | load_haiku,
27 | set_seed,
28 | )
29 |
30 |
31 | @partial(jit, static_argnames=["model_apply", "case_integrate"])
32 | def _forward_eval(
33 | params: hk.Params,
34 | state: hk.State,
35 | sample: Tuple[jnp.ndarray, jnp.ndarray],
36 | current_positions: jnp.ndarray,
37 | target_positions: jnp.ndarray,
38 | model_apply: Callable,
39 | case_integrate: Callable,
40 | ) -> jnp.ndarray:
41 |     """Run one update of the current positions using the trained model.
42 |
43 | Args:
44 | params: Haiku model parameters
45 | state: Haiku model state
46 |         current_positions: Set of historic positions of shape (n_nodes, t_window, dim)
47 |         target_positions: used to get the next state of kinematic particles, i.e. those
48 |             that are not updated by the ML model, e.g. boundary particles
49 | model_apply: model function
50 | case_integrate: integration function from case.integrate
51 |
52 |     Returns:
53 |         current_positions: the historic position sequence shifted by one, i.e. with
54 |             the newly computed most recent position appended
55 | """
56 | _, particle_type = sample
57 |
58 | # predict acceleration and integrate
59 | pred, state = model_apply(params, state, sample)
60 |
61 | next_position = case_integrate(pred, current_positions)
62 |
63 | # update only the positions of non-boundary particles
64 | kinematic_mask = get_kinematic_mask(particle_type)
65 | next_position = jnp.where(
66 | kinematic_mask[:, None],
67 | target_positions,
68 | next_position,
69 | )
70 |
71 | current_positions = jnp.concatenate(
72 | [current_positions[:, 1:], next_position[:, None, :]], axis=1
73 | ) # as next model input
74 |
75 | return current_positions, state
76 |
77 |
78 | def _eval_batched_rollout(
79 | forward_eval_vmap: Callable,
80 | preprocess_eval_vmap: Callable,
81 | case,
82 | params: hk.Params,
83 | state: hk.State,
84 | traj_batch_i: Tuple[jnp.ndarray, jnp.ndarray],
85 | neighbors: partition.NeighborList,
86 | metrics_computer_vmap: Callable,
87 | n_rollout_steps: int,
88 | t_window: int,
89 | n_extrap_steps: int = 0,
90 | ) -> Tuple[jnp.ndarray, MetricsDict, jnp.ndarray]:
91 |     """Compute the rollout on a batch of trajectories.
92 |
93 | Args:
94 | forward_eval_vmap: Model function.
95 | case: CaseSetupFn class.
96 | params: Haiku params.
97 | state: Haiku state.
98 | traj_batch_i: Trajectory to evaluate.
99 | neighbors: Neighbor list.
100 |         metrics_computer_vmap: Vectorized MetricsComputer with the desired metrics.
101 | n_rollout_steps: Number of rollout steps.
102 | t_window: Length of the input sequence.
103 | n_extrap_steps: Number of extrapolation steps (beyond the ground truth rollout).
104 |
105 | Returns:
106 | A tuple with (predicted rollout, metrics, neighbor list).
107 | """
108 | # particle type is treated as a static property defined by state at t=0
109 | pos_input_batch, particle_type_batch = traj_batch_i
110 | # current_batch_size might be < eval_batch_size if the last batch is not full
111 | current_batch_size, n_nodes_max, _, dim = pos_input_batch.shape
112 |
113 | # if n_rollout_steps set to -1, use the whole trajectory
114 | if n_rollout_steps == -1:
115 | n_rollout_steps = pos_input_batch.shape[2] - t_window
116 |
117 | current_positions_batch = pos_input_batch[:, :, 0:t_window]
118 | # (batch, n_nodes, t_window, dim)
119 | traj_len = n_rollout_steps + n_extrap_steps
120 | target_positions_batch = pos_input_batch[:, :, t_window : t_window + traj_len]
121 |
122 | predictions_batch = jnp.zeros((current_batch_size, traj_len, n_nodes_max, dim))
123 | neighbors_batch = broadcast_to_batch(neighbors, current_batch_size)
124 |
125 | step = 0
126 | while step < n_rollout_steps + n_extrap_steps:
127 | sample_batch = (current_positions_batch, particle_type_batch)
128 |
129 | # 1. preprocess features
130 | features_batch, neighbors_batch = preprocess_eval_vmap(
131 | sample_batch, neighbors_batch
132 | )
133 |
134 | # 2. check whether list overflowed and fix it if so
135 | if neighbors_batch.did_buffer_overflow.sum() > 0:
136 | # check if the neighbor list is too small for any of the samples
137 | # if so, reallocate the neighbor list
138 |
139 | print(f"(eval) Reallocate neighbors list at step {step}")
140 | ind = jnp.argmax(neighbors_batch.did_buffer_overflow)
141 | sample = broadcast_from_batch(sample_batch, index=ind)
142 |
143 | _, nbrs_temp = case.allocate_eval(sample)
144 | print(
145 | f"(eval) From {neighbors_batch.idx[ind].shape} to {nbrs_temp.idx.shape}"
146 | )
147 | neighbors_batch = broadcast_to_batch(nbrs_temp, current_batch_size)
148 |
149 |             # Retry this step with the reallocated neighbor list: `continue` jumps
150 |             # back to the start of the loop body without incrementing `step`.
151 | continue
152 |
153 | # 3. run forward model
154 | current_positions_batch, state_batch = forward_eval_vmap(
155 | params,
156 | state,
157 | (features_batch, particle_type_batch),
158 | current_positions_batch,
159 | target_positions_batch[:, :, step],
160 | )
161 |         # the state is not passed out of this loop, so it is not really relevant here
162 | state = broadcast_from_batch(state_batch, 0)
163 |
164 | # 4. write predicted next position to output array
165 | predictions_batch = predictions_batch.at[:, step].set(
166 | current_positions_batch[:, :, -1] # most recently predicted positions
167 | )
168 |
169 | step += 1
170 |
171 | # (batch, n_nodes, time, dim) -> (batch, time, n_nodes, dim)
172 | target_positions_batch = target_positions_batch.transpose(0, 2, 1, 3)
173 | # slice out extrapolation steps
174 | metrics_batch = metrics_computer_vmap(
175 | predictions_batch[:, :n_rollout_steps, :, :], target_positions_batch
176 | )
177 |
178 | return (predictions_batch, metrics_batch, broadcast_from_batch(neighbors_batch, 0))
179 |
180 |
181 | def eval_rollout(
182 | model_apply: Callable,
183 | case,
184 | params: hk.Params,
185 | state: hk.State,
186 | loader_eval: Iterable,
187 | neighbors: partition.NeighborList,
188 | metrics_computer: MetricsComputer,
189 | n_rollout_steps: int,
190 | n_trajs: int,
191 | rollout_dir: str,
192 | out_type: str = "none",
193 | n_extrap_steps: int = 0,
194 | ) -> MetricsDict:
195 | """Compute the rollout and evaluate the metrics.
196 |
197 | Args:
198 | model_apply: Model function.
199 | case: CaseSetupFn class.
200 | params: Haiku params.
201 | state: Haiku state.
202 | loader_eval: Evaluation data loader.
203 | neighbors: Neighbor list.
204 | metrics_computer: MetricsComputer with the desired metrics.
205 | n_rollout_steps: Number of rollout steps.
206 | n_trajs: Number of ground truth trajectories to evaluate.
207 |         rollout_dir: Directory in which to store the rollouts and the metrics dict.
208 | out_type: Output type. Either "none", "vtk" or "pkl".
209 | n_extrap_steps: Number of extrapolation steps (beyond the ground truth rollout).
210 |
211 | Returns:
212 | Metrics per trajectory.
213 | """
214 | batch_size = loader_eval.batch_size
215 | t_window = loader_eval.dataset.input_seq_length
216 | eval_metrics = {}
217 |
218 | if rollout_dir is not None:
219 | os.makedirs(rollout_dir, exist_ok=True)
220 |
221 | forward_eval = partial(
222 | _forward_eval,
223 | model_apply=model_apply,
224 | case_integrate=case.integrate,
225 | )
226 | forward_eval_vmap = vmap(forward_eval, in_axes=(None, None, 0, 0, 0))
227 | preprocess_eval_vmap = vmap(case.preprocess_eval, in_axes=(0, 0))
228 | metrics_computer_vmap = vmap(metrics_computer, in_axes=(0, 0))
229 |
230 | for i, traj_batch_i in enumerate(loader_eval):
231 | # if n_trajs is not a multiple of batch_size, we slice from the last batch
232 | n_traj_left = n_trajs - i * batch_size
233 | if n_traj_left < batch_size:
234 | traj_batch_i = jax.tree_map(lambda x: x[:n_traj_left], traj_batch_i)
235 |
236 | # numpy to jax
237 | traj_batch_i = jax.tree_map(lambda x: jnp.array(x), traj_batch_i)
238 | # (pos_input_batch, particle_type_batch) = traj_batch_i
239 | # pos_input_batch.shape = (batch, num_particles, seq_length, dim)
240 |
241 | example_rollout_batch, metrics_batch, neighbors = _eval_batched_rollout(
242 | forward_eval_vmap=forward_eval_vmap,
243 | preprocess_eval_vmap=preprocess_eval_vmap,
244 | case=case,
245 | params=params,
246 | state=state,
247 | traj_batch_i=traj_batch_i, # (batch, nodes, t, dim)
248 | neighbors=neighbors,
249 | metrics_computer_vmap=metrics_computer_vmap,
250 | n_rollout_steps=n_rollout_steps,
251 | t_window=t_window,
252 | n_extrap_steps=n_extrap_steps,
253 | )
254 |
255 | current_batch_size = traj_batch_i[0].shape[0]
256 | for j in range(current_batch_size):
257 | # write metrics to output dictionary
258 | ind = i * batch_size + j
259 | eval_metrics[f"rollout_{ind}"] = broadcast_from_batch(metrics_batch, j)
260 |
261 | if rollout_dir is not None:
262 | # (batch, nodes, t, dim) -> (batch, t, nodes, dim)
263 | pos_input_batch = traj_batch_i[0].transpose(0, 2, 1, 3)
264 |
265 | for j in range(current_batch_size): # write every trajectory to file
266 | pos_input = pos_input_batch[j]
267 | example_rollout = example_rollout_batch[j]
268 |
269 | initial_positions = pos_input[:t_window]
270 | example_full = jnp.concatenate([initial_positions, example_rollout])
271 | example_rollout = {
272 | "predicted_rollout": example_full, # (t + extrap, nodes, dim)
273 | "ground_truth_rollout": pos_input, # (t, nodes, dim),
274 | "particle_type": traj_batch_i[1][j], # (nodes,)
275 | }
276 |
277 | file_prefix = os.path.join(rollout_dir, f"rollout_{i*batch_size+j}")
278 | if out_type == "vtk": # write vtk files for each time step
279 | for k in range(example_full.shape[0]):
280 | # predictions
281 | state_vtk = {
282 | "r": example_rollout["predicted_rollout"][k],
283 | "tag": example_rollout["particle_type"],
284 | }
285 | write_vtk(state_vtk, f"{file_prefix}_{k}.vtk")
286 | for k in range(pos_input.shape[0]):
287 | # ground truth reference
288 | ref_state_vtk = {
289 | "r": example_rollout["ground_truth_rollout"][k],
290 | "tag": example_rollout["particle_type"],
291 | }
292 | write_vtk(ref_state_vtk, f"{file_prefix}_ref_{k}.vtk")
293 | elif out_type == "pkl":
294 | filename = f"{file_prefix}.pkl"
295 |
296 | with open(filename, "wb") as f:
297 | pickle.dump(example_rollout, f)
298 |
299 | if (i * batch_size + j + 1) >= n_trajs:
300 | break
301 |
302 | if rollout_dir is not None:
303 | # save metrics
304 | t = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
305 | with open(f"{rollout_dir}/metrics{t}.pkl", "wb") as f:
306 | pickle.dump(eval_metrics, f)
307 |
308 | return eval_metrics
309 |
310 |
311 | def infer(
312 | model: hk.TransformedWithState,
313 | case,
314 | data_test: H5Dataset,
315 | params: Optional[hk.Params] = None,
316 | state: Optional[hk.State] = None,
317 | load_ckp: Optional[str] = None,
318 | cfg_eval_infer: Union[Dict, DictConfig] = defaults.eval.infer,
319 | rollout_dir: Optional[str] = defaults.eval.rollout_dir,
320 | n_rollout_steps: int = defaults.eval.n_rollout_steps,
321 | seed: int = defaults.seed,
322 | ):
323 | """
324 | Infer on a dataset, compute metrics and optionally save rollout in out_type format.
325 |
326 | Args:
327 | model: (Transformed) Haiku model.
328 | case: Case setup class.
329 | data_test: Test dataset.
330 | params: Haiku params.
331 | state: Haiku state.
332 | load_ckp: Path to checkpoint directory.
333 |         cfg_eval_infer: Evaluation configuration for inference mode.
334 |         rollout_dir: Path to rollout directory.
335 | n_rollout_steps: Number of rollout steps.
336 | seed: Seed.
337 |
338 | Returns:
339 | eval_metrics: Metrics per trajectory.
340 | """
341 | assert (
342 | params is not None or load_ckp is not None
343 | ), "Either params or a load_ckp directory must be provided for inference."
344 |
345 | if isinstance(cfg_eval_infer, Dict):
346 | cfg_eval_infer = OmegaConf.create(cfg_eval_infer)
347 |
348 | # if one of the cfg_* arguments has a subset of the default configs, merge them
349 | cfg_eval_infer = OmegaConf.merge(defaults.eval.infer, cfg_eval_infer)
350 |
351 | n_trajs = cfg_eval_infer.n_trajs
352 | if n_trajs == -1:
353 | n_trajs = data_test.num_samples
354 |
355 | if params is not None:
356 | if state is None:
357 | state = {}
358 | else:
359 | params, state, _, _ = load_haiku(load_ckp)
360 |
361 | key, seed_worker, generator = set_seed(seed)
362 |
363 | loader_test = DataLoader(
364 | dataset=data_test,
365 | batch_size=cfg_eval_infer.batch_size,
366 | collate_fn=numpy_collate,
367 | worker_init_fn=seed_worker,
368 | generator=generator,
369 | )
370 | metrics_computer = MetricsComputer(
371 | cfg_eval_infer.metrics,
372 | dist_fn=case.displacement,
373 | metadata=data_test.metadata,
374 | input_seq_length=data_test.input_seq_length,
375 | stride=cfg_eval_infer.metrics_stride,
376 | )
377 |     # jit the model apply function (compiled on first call)
378 | model_apply = jit(model.apply)
379 |
380 | # init values
381 | pos_input_and_target, particle_type = next(iter(loader_test))
382 | sample = (pos_input_and_target[0], particle_type[0])
383 | key, _, _, neighbors = case.allocate(key, sample)
384 |
385 | eval_metrics = eval_rollout(
386 | model_apply=model_apply,
387 | case=case,
388 | metrics_computer=metrics_computer,
389 | params=params,
390 | state=state,
391 | neighbors=neighbors,
392 | loader_eval=loader_test,
393 | n_rollout_steps=n_rollout_steps,
394 | n_trajs=n_trajs,
395 | rollout_dir=rollout_dir,
396 | out_type=cfg_eval_infer.out_type,
397 | n_extrap_steps=cfg_eval_infer.n_extrap_steps,
398 | )
399 | return eval_metrics
400 |
--------------------------------------------------------------------------------
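A minimal sketch of calling infer directly, outside of runner.py. The dataset path, checkpoint
directory, and model hyperparameters below are hypothetical, and the omitted H5Dataset/case_builder
arguments are assumed to have usable defaults.

    import haiku as hk
    import numpy as np

    from lagrangebench import infer, models
    from lagrangebench.case_setup import case_builder
    from lagrangebench.data import H5Dataset

    data_test = H5Dataset(
        "test",
        dataset_path="datasets/2D_RPF_3200_20kevery100",  # hypothetical path
        input_seq_length=6,
        extra_seq_length=20,
    )
    bounds = np.array(data_test.metadata["bounds"])
    case = case_builder(
        box=bounds[:, 1] - bounds[:, 0],
        metadata=data_test.metadata,
        input_seq_length=6,
    )

    def gns_fn(x):
        return models.GNS(
            particle_dimension=2,
            latent_size=128,
            blocks_per_step=2,
            num_mp_steps=10,
            particle_type_embedding_size=16,
        )(x)

    model = hk.without_apply_rng(hk.transform_with_state(gns_fn))

    metrics = infer(
        model,
        case,
        data_test,
        load_ckp="ckp/gns_rpf2d/best",  # hypothetical checkpoint directory
        cfg_eval_infer={"n_trajs": 5, "out_type": "pkl"},
        rollout_dir="rollouts/gns_rpf2d",
        n_rollout_steps=20,
    )
    # metrics is a dict keyed by "rollout_0", "rollout_1", ...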
/lagrangebench/evaluate/utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for evaluation."""
2 |
3 | import os
4 | import pickle
5 |
6 | import numpy as np
7 |
8 |
9 | def write_vtk(data_dict, path):
10 | """Store a .vtk file for ParaView."""
11 |
12 | try:
13 | import pyvista
14 | except ImportError:
15 | raise ImportError("Please install pyvista to write VTK files.")
16 |
17 | r = np.asarray(data_dict["r"])
18 | N, dim = r.shape
19 |
20 | # PyVista treats the position information differently than the rest
21 | if dim == 2:
22 | r = np.hstack([r, np.zeros((N, 1))])
23 | data_pv = pyvista.PolyData(r)
24 |
25 | # copy all the other information also to pyvista, using plain numpy arrays
26 | for k, v in data_dict.items():
27 | # skip r because we already considered it above
28 | if k == "r":
29 | continue
30 |
31 |         # only 2D vector features need zero-padding; 3D and scalar features are fine
32 | if dim == 2 and v.ndim == 2:
33 | v = np.hstack([v, np.zeros((N, 1))])
34 |
35 | data_pv[k] = np.asarray(v)
36 |
37 | data_pv.save(path)
38 |
39 |
40 | def pkl2vtk(src_path, dst_path=None):
41 | """Convert a rollout pickle file to a set of vtk files.
42 |
43 | Args:
44 | src_path (str): Source path to .pkl file.
45 |         dst_path (str, optional): Destination directory path. Defaults to None.
46 | If None, then the vtk files are saved in the same directory as the pkl file.
47 |
48 | Example:
49 | pkl2vtk("rollout/test/rollout_0.pkl", "rollout/test_vtk")
50 | will create files rollout_0_0.vtk, rollout_0_1.vtk, etc. in the directory
51 | "rollout/test_vtk"
52 | """
53 |
54 | # set up destination directory
55 | if dst_path is None:
56 | dst_path = os.path.dirname(src_path)
57 | os.makedirs(dst_path, exist_ok=True)
58 |
59 | # load rollout
60 | with open(src_path, "rb") as f:
61 | rollout = pickle.load(f)
62 |
63 | file_prefix = os.path.join(dst_path, os.path.basename(src_path).split(".")[0])
64 | for k in range(rollout["predicted_rollout"].shape[0]):
65 | # predictions
66 | state_vtk = {
67 | "r": rollout["predicted_rollout"][k],
68 | "tag": rollout["particle_type"],
69 | }
70 | write_vtk(state_vtk, f"{file_prefix}_{k}.vtk")
71 | # ground truth reference
72 | state_vtk = {
73 | "r": rollout["ground_truth_rollout"][k],
74 | "tag": rollout["particle_type"],
75 | }
76 | write_vtk(state_vtk, f"{file_prefix}_ref_{k}.vtk")
77 |
--------------------------------------------------------------------------------
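A quick usage sketch for the two helpers above; array shapes and paths are arbitrary, and
write_vtk requires pyvista to be installed.

    import numpy as np

    from lagrangebench.evaluate.utils import pkl2vtk, write_vtk

    # 100 particles in 2D; write_vtk pads the missing z component with zeros
    state_vtk = {
        "r": np.random.rand(100, 2),      # positions
        "tag": np.zeros(100, dtype=int),  # particle types
    }
    write_vtk(state_vtk, "frame_0.vtk")

    # convert a previously stored rollout pickle into per-frame .vtk files
    pkl2vtk("rollout/test/rollout_0.pkl", "rollout/test_vtk")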
/lagrangebench/models/__init__.py:
--------------------------------------------------------------------------------
1 | """Baseline models."""
2 |
3 | from .egnn import EGNN
4 | from .gns import GNS
5 | from .linear import Linear
6 | from .painn import PaiNN
7 | from .segnn import SEGNN
8 |
9 | __all__ = ["GNS", "SEGNN", "EGNN", "PaiNN", "Linear"]
10 |
--------------------------------------------------------------------------------
/lagrangebench/models/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Dict, Tuple
3 |
4 | import haiku as hk
5 | import jax.numpy as jnp
6 |
7 |
8 | class BaseModel(hk.Module, ABC):
9 | """Base model class. All models must inherit from this class."""
10 |
11 | @abstractmethod
12 | def __call__(
13 | self, sample: Tuple[Dict[str, jnp.ndarray], jnp.ndarray]
14 | ) -> Dict[str, jnp.ndarray]:
15 | """Forward pass.
16 |
17 | We specify the dimensions of the inputs and outputs using the number of nodes N,
18 | the number of edges E, number of historic velocities K (=input_seq_length - 1),
19 | and the dimensionality of the feature vectors dim.
20 |
21 | Args:
22 | sample: Tuple with feature dictionary and particle type. Possible features
23 |
24 | - "abs_pos" (N, K+1, dim), absolute positions
25 | - "vel_hist" (N, K*dim), historical velocity sequence
26 | - "vel_mag" (N,), velocity magnitudes
27 | - "bound" (N, 2*dim), distance to boundaries
28 | - "force" (N, dim), external force field
29 | - "rel_disp" (E, dim), relative displacement vectors
30 | - "rel_dist" (E, 1), relative distances, i.e. magnitude of displacements
31 | - "senders" (E), sender indices
32 | - "receivers" (E), receiver indices
33 | Returns:
34 | Dict with model output.
35 | The keys must be at least one of the following:
36 |
37 | - "acc" (N, dim), (normalized) acceleration
38 | - "vel" (N, dim), (normalized) velocity
39 | - "pos" (N, dim), (absolute) next position
40 | """
41 | raise NotImplementedError
42 |
--------------------------------------------------------------------------------
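To make the interface above concrete, here is a minimal, hypothetical subclass: it reads only
the flattened velocity history from the feature dict and returns a (normalized) acceleration,
which satisfies the output contract. Like the other baselines, it must be wrapped with
hk.transform_with_state before use.

    import haiku as hk

    from lagrangebench.models.base import BaseModel


    class VelocityMLP(BaseModel):
        """Toy model: predict acceleration from the flattened velocity history."""

        def __init__(self, dim: int, hidden_size: int = 64):
            super().__init__()
            self._mlp = hk.nets.MLP([hidden_size, dim])

        def __call__(self, sample):
            features, _particle_type = sample
            # features["vel_hist"] has shape (N, K*dim)
            return {"acc": self._mlp(features["vel_hist"])}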
/lagrangebench/models/egnn.py:
--------------------------------------------------------------------------------
1 | """
2 | E(n) equivariant GNN from `Garcia Satorras et al. <https://arxiv.org/abs/2102.09844>`_.
3 | EGNN model, layers and feature transform.
4 |
5 | Original implementation: https://github.com/vgsatorras/egnn
6 |
7 | Standalone implementation + validation: https://github.com/gerkone/egnn-jax
8 | """
9 |
10 | from typing import Any, Callable, Dict, Optional, Tuple
11 |
12 | import haiku as hk
13 | import jax
14 | import jax.numpy as jnp
15 | import jraph
16 | from jax.tree_util import Partial
17 | from jax_sph.jax_md import space
18 |
19 | from lagrangebench.utils import NodeType
20 |
21 | from .base import BaseModel
22 | from .utils import LinearXav, MLPXav
23 |
24 |
25 | class EGNNLayer(hk.Module):
26 | r"""E(n)-equivariant EGNN layer.
27 |
28 | Applies a message passing step where the positions are corrected with the velocities
29 |     and a learnable correction term :math:`\psi_x(\mathbf{h}_i^{(t+1)})`.
30 | """
31 |
32 | def __init__(
33 | self,
34 | layer_num: int,
35 | hidden_size: int,
36 | output_size: int,
37 | displacement_fn: space.DisplacementFn,
38 | shift_fn: space.ShiftFn,
39 | blocks: int = 1,
40 | act_fn: Callable = jax.nn.silu,
41 | pos_aggregate_fn: Optional[Callable] = jraph.segment_sum,
42 | msg_aggregate_fn: Optional[Callable] = jraph.segment_sum,
43 | residual: bool = True,
44 | attention: bool = False,
45 | normalize: bool = False,
46 | tanh: bool = False,
47 | dt: float = 0.001,
48 | eps: float = 1e-8,
49 | ):
50 | """Initialize the layer.
51 |
52 | Args:
53 | layer_num: layer number
54 | hidden_size: hidden size
55 | output_size: output size
56 | displacement_fn: Displacement function for the acceleration computation.
57 | shift_fn: Shift function for updating positions
58 | blocks: number of blocks in the node and edge MLPs
59 | act_fn: activation function
60 | pos_aggregate_fn: position aggregation function
61 | msg_aggregate_fn: message aggregation function
62 | residual: whether to use residual connections
63 | attention: whether to use attention
64 | normalize: whether to normalize the coordinates
65 | tanh: whether to use tanh in the position update
66 | dt: position update step size
67 | eps: small number to avoid division by zero
68 | """
69 | super().__init__(f"layer_{layer_num}")
70 |
71 | self._displacement_fn = displacement_fn
72 | self._shift_fn = shift_fn
73 | self.pos_aggregate_fn = pos_aggregate_fn
74 | self.msg_aggregate_fn = msg_aggregate_fn
75 | self._residual = residual
76 | self._normalize = normalize
77 | self._eps = eps
78 |
79 | # message network
80 | self._edge_mlp = MLPXav(
81 | [hidden_size] * blocks + [hidden_size],
82 | activation=act_fn,
83 | activate_final=True,
84 | )
85 |
86 | # update network
87 | self._node_mlp = MLPXav(
88 | [hidden_size] * blocks + [output_size],
89 | activation=act_fn,
90 | activate_final=False,
91 | )
92 |
93 | # position update network
94 | net = [LinearXav(hidden_size)] * blocks
95 | # NOTE: from https://github.com/vgsatorras/egnn/blob/main/models/gcl.py#L254
96 | net += [
97 | act_fn,
98 | LinearXav(1, with_bias=False, w_init=hk.initializers.UniformScaling(dt)),
99 | ]
100 | if tanh:
101 | net.append(jax.nn.tanh)
102 | self._pos_correction_mlp = hk.Sequential(net)
103 |
104 | # velocity integrator network
105 | net = [LinearXav(hidden_size)] * blocks
106 | net += [
107 | act_fn,
108 | LinearXav(1, with_bias=False, w_init=hk.initializers.UniformScaling(dt)),
109 | ]
110 | self._vel_correction_mlp = hk.Sequential(net)
111 |
112 | # attention
113 | self._attention_mlp = None
114 | if attention:
115 | self._attention_mlp = hk.Sequential(
116 | [LinearXav(hidden_size), jax.nn.sigmoid]
117 | )
118 |
119 | def _pos_update(
120 | self,
121 | pos: jnp.ndarray,
122 | graph: jraph.GraphsTuple,
123 | coord_diff: jnp.ndarray,
124 | ) -> jnp.ndarray:
125 | trans = coord_diff * self._pos_correction_mlp(graph.edges)
126 | return self.pos_aggregate_fn(trans, graph.senders, num_segments=pos.shape[0])
127 |
128 | def _message(
129 | self,
130 | radial: jnp.ndarray,
131 | edge_attribute: jnp.ndarray,
132 | edge_features: Any,
133 | incoming: jnp.ndarray,
134 | outgoing: jnp.ndarray,
135 | globals_: Any,
136 | ) -> jnp.ndarray:
137 | _ = edge_features
138 | _ = globals_
139 | msg = jnp.concatenate([incoming, outgoing, radial], axis=-1)
140 | if edge_attribute is not None:
141 | msg = jnp.concatenate([msg, edge_attribute], axis=-1)
142 | msg = self._edge_mlp(msg)
143 | if self._attention_mlp:
144 | att = self._attention_mlp(msg)
145 | msg = msg * att
146 | return msg
147 |
148 | def _update(
149 | self,
150 | node_attribute: jnp.ndarray,
151 | nodes: jnp.ndarray,
152 | senders: Any,
153 | msg: jnp.ndarray,
154 | globals_: Any,
155 | ) -> jnp.ndarray:
156 | _ = senders
157 | _ = globals_
158 | x = jnp.concatenate([nodes, msg], axis=-1)
159 | if node_attribute is not None:
160 | x = jnp.concatenate([x, node_attribute], axis=-1)
161 | x = self._node_mlp(x)
162 | if self._residual:
163 | x = nodes + x
164 | return x
165 |
166 | def _coord2radial(
167 | self, graph: jraph.GraphsTuple, coord: jnp.array
168 | ) -> Tuple[jnp.array, jnp.array]:
169 | coord_diff = self._displacement_fn(coord[graph.senders], coord[graph.receivers])
170 | radial = jnp.sum(coord_diff**2, 1)[:, jnp.newaxis]
171 | if self._normalize:
172 | norm = jnp.sqrt(radial)
173 | coord_diff = coord_diff / (norm + self._eps)
174 | return radial, coord_diff
175 |
176 | def __call__(
177 | self,
178 | graph: jraph.GraphsTuple,
179 | pos: jnp.ndarray,
180 | vel: jnp.ndarray,
181 | edge_attribute: Optional[jnp.ndarray] = None,
182 | node_attribute: Optional[jnp.ndarray] = None,
183 | ) -> Tuple[jraph.GraphsTuple, jnp.ndarray]:
184 | """
185 | Apply EGNN layer.
186 |
187 | Args:
188 | graph: Graph from previous step
189 | pos: Node position, updated separately
190 | vel: Node velocity
191 | edge_attribute: Edge attribute (optional)
192 | node_attribute: Node attribute (optional)
193 | Returns:
194 | Updated graph, node position
195 | """
196 | radial, coord_diff = self._coord2radial(graph, pos)
197 | graph = jraph.GraphNetwork(
198 | update_edge_fn=Partial(self._message, radial, edge_attribute),
199 | update_node_fn=Partial(self._update, node_attribute),
200 | aggregate_edges_for_nodes_fn=self.msg_aggregate_fn,
201 | )(graph)
202 | # update position
203 | pos = self._shift_fn(pos, self._pos_update(pos, graph, coord_diff))
204 | # integrate velocity
205 | pos = self._shift_fn(pos, self._vel_correction_mlp(graph.nodes) * vel)
206 | return graph, pos
207 |
208 |
209 | class EGNN(BaseModel):
210 | r"""
211 | E(n) Graph Neural Network by
212 | `Garcia Satorras et al. `_.
213 |
214 | EGNN doesn't require expensive higher-order representations in intermediate layers;
215 | instead it relies on separate scalar and vector channels, which are treated
216 | differently by EGNN layers. In this setup, EGNN is similar to a learnable numerical
217 | integrator:
218 |
219 | .. math::
220 | \begin{align}
221 | \mathbf{m}_{ij}^{(t+1)} &= \phi_e \left(
222 | \mathbf{m}_{ij}^{(t)}, \mathbf{h}_i^{(t)},
223 | \mathbf{h}_j^{(t)}, ||\mathbf{x}_i^{(t)} - \mathbf{x}_j^{(t)}||^2
224 | \right) \\
225 | \mathbf{\hat{m}}_{ij}^{(t+1)} &=
226 | (\mathbf{x}_i^{(t)} - \mathbf{x}_j^{(t)}) \phi_x(\mathbf{m}_{ij}^{(t+1)})
227 | \end{align}
228 |
229 |     And the node and position updates with the integrator
230 |
231 | .. math::
232 | \begin{align}
233 | \mathbf{h}_i^{(t+1)} &= \psi_h \left(
234 | \mathbf{h}_i^{(t)}, \sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij}^{(t+1)}
235 | \right) \\
236 | \mathbf{x}_i^{(t+1)} &= \mathbf{x}_i^{(t)}
237 | + \mathbf{\hat{m}}_{ij}^{(t+1)} \psi_x(\mathbf{h}_i^{(t+1)})
238 | \end{align}
239 |
240 | where :math:`\mathbf{m}_{ij}` and :math:`\mathbf{\hat{m}}_{ij}` are the scalar and
241 | vector messages respectively, and :math:`\mathbf{x}_{i}` are the positions.
242 |
243 | This implementation differs from the original in two places:
244 |
245 |     - because our datasets can have periodic boundary conditions, we use shift and
246 |       displacement functions that handle them in all operations on positions.
247 |     - after the last layer, the acceleration is obtained by finite differencing positions.
248 | """
249 |
250 | def __init__(
251 | self,
252 | hidden_size: int,
253 | output_size: int,
254 | dt: float,
255 | n_vels: int,
256 | displacement_fn: space.DisplacementFn,
257 | shift_fn: space.ShiftFn,
258 | normalization_stats: Optional[Dict[str, jnp.ndarray]] = None,
259 | act_fn: Callable = jax.nn.silu,
260 | num_mp_steps: int = 4,
261 | homogeneous_particles: bool = True,
262 | residual: bool = True,
263 | attention: bool = False,
264 | normalize: bool = False,
265 | tanh: bool = False,
266 | ):
267 | r"""
268 | Initialize the network.
269 |
270 | Args:
271 | hidden_size: Number of hidden features.
272 | output_size: Number of features for 'h' at the output.
273 | dt: Time step for position and velocity integration. Used to rescale the
274 | initialization of the correction MLP.
275 | n_vels: Number of velocities in the history.
276 | displacement_fn: Displacement function for the acceleration computation.
277 | shift_fn: Shift function for updating positions.
278 | normalization_stats: Normalization statistics for the input data.
279 | act_fn: Non-linearity.
280 |             num_mp_steps: Number of layers (message passing steps) of the EGNN.
281 | homogeneous_particles: If all particles are of homogeneous type.
282 | residual: Whether to use residual connections.
283 | attention: Whether to use attention or not.
284 |             normalize: Normalizes the coordinate messages such that
285 |                 ``x^{l+1}_i = x^{l}_i + \sum(x_i - x_j)\phi_x(m_{ij}) / \|x_i - x_j\|``.
286 |                 It may help with stability or generalization. Not used in the paper.
287 |             tanh: Sets a tanh activation function at the output of ``\phi_x(m_{ij})``.
288 |                 It bounds the output of ``\phi_x(m_{ij})``, which improves stability
289 |                 but may decrease accuracy. Not used in the paper.
290 | """
291 | super().__init__()
292 | # network
293 | self._hidden_size = hidden_size
294 | self._output_size = output_size
295 | self._act_fn = act_fn
296 | self._num_mp_steps = num_mp_steps
297 | self._residual = residual
298 | self._attention = attention
299 | self._normalize = normalize
300 | self._tanh = tanh
301 |
302 | # integrator
303 | self._dt = dt / self._num_mp_steps
304 | self._displacement_fn = displacement_fn
305 | self._shift_fn = shift_fn
306 | if normalization_stats is None:
307 | normalization_stats = {
308 | "velocity": {"mean": 0.0, "std": 1.0},
309 | "acceleration": {"mean": 0.0, "std": 1.0},
310 | }
311 | self._vel_stats = normalization_stats["velocity"]
312 | self._acc_stats = normalization_stats["acceleration"]
313 |
314 | # transform
315 | self._n_vels = n_vels
316 | self._homogeneous_particles = homogeneous_particles
317 |
318 | def _transform(
319 | self, features: Dict[str, jnp.ndarray], particle_type: jnp.ndarray
320 | ) -> Tuple[jraph.GraphsTuple, Dict[str, jnp.ndarray]]:
321 | props = {}
322 | n_nodes = features["vel_hist"].shape[0]
323 |
324 | props["vel"] = jnp.reshape(features["vel_hist"], (n_nodes, self._n_vels, -1))
325 |
326 | # most recent position
327 | props["pos"] = features["abs_pos"][:, -1]
328 | # relative distances between particles
329 | props["edge_attr"] = features["rel_dist"]
330 | # force magnitude as node attributes
331 | props["node_attr"] = None
332 | if "force" in features:
333 | props["node_attr"] = jnp.sqrt(
334 | jnp.sum(features["force"] ** 2, axis=-1, keepdims=True)
335 | )
336 |
337 | # velocity magnitudes as node features
338 | node_features = jnp.concatenate(
339 | [
340 | jnp.sqrt(jnp.sum(props["vel"][:, i, :] ** 2, axis=-1, keepdims=True))
341 | for i in range(self._n_vels)
342 | ],
343 | axis=-1,
344 | )
345 | if not self._homogeneous_particles:
346 | particles = jax.nn.one_hot(particle_type, NodeType.SIZE)
347 | node_features = jnp.concatenate([node_features, particles], axis=-1)
348 |
349 | graph = jraph.GraphsTuple(
350 | nodes=node_features,
351 | edges=None,
352 | senders=features["senders"],
353 | receivers=features["receivers"],
354 | n_node=jnp.array([n_nodes]),
355 | n_edge=jnp.array([len(features["senders"])]),
356 | globals=None,
357 | )
358 |
359 | return graph, props
360 |
361 | def _postprocess(
362 | self, next_pos: jnp.ndarray, props: Dict[str, jnp.ndarray]
363 | ) -> Dict[str, jnp.ndarray]:
364 | prev_vel = props["vel"][:, -1, :]
365 | prev_pos = props["pos"]
366 | # first order finite difference
367 | next_vel = self._displacement_fn(next_pos, prev_pos)
368 | acc = next_vel - prev_vel
369 | return {"pos": next_pos, "vel": next_vel, "acc": acc}
370 |
371 | def __call__(
372 | self, sample: Tuple[Dict[str, jnp.ndarray], jnp.ndarray]
373 | ) -> Dict[str, jnp.ndarray]:
374 | graph, props = self._transform(*sample)
375 | # input node embedding
376 | h = LinearXav(self._hidden_size, name="scalar_emb")(graph.nodes)
377 | graph = graph._replace(nodes=h)
378 | prev_vel = props["vel"][:, -1, :]
379 | # egnn works with unnormalized velocities
380 | prev_vel = prev_vel * self._vel_stats["std"] + self._vel_stats["mean"]
381 | # message passing
382 | next_pos = props["pos"].copy()
383 | for n in range(self._num_mp_steps):
384 | graph, next_pos = EGNNLayer(
385 | layer_num=n,
386 | hidden_size=self._hidden_size,
387 | output_size=self._hidden_size,
388 | displacement_fn=self._displacement_fn,
389 | shift_fn=self._shift_fn,
390 | act_fn=self._act_fn,
391 | residual=self._residual,
392 | attention=self._attention,
393 | normalize=self._normalize,
394 | dt=self._dt,
395 | tanh=self._tanh,
396 | )(graph, next_pos, prev_vel, props["edge_attr"], props["node_attr"])
397 |
398 | # position finite differencing to get acceleration
399 | out = self._postprocess(next_pos, props)
400 | return out
401 |
--------------------------------------------------------------------------------
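The _postprocess step above recovers the velocity and acceleration by first-order finite
differencing of the predicted positions. A tiny sketch in free (non-periodic) space, where
the displacement function reduces to a plain subtraction:

    import jax.numpy as jnp
    from jax import vmap
    from jax_sph.jax_md import space

    displacement_fn, _ = space.free()
    displacement_fn = vmap(displacement_fn, in_axes=(0, 0))

    prev_pos = jnp.array([[0.0, 0.0]])
    next_pos = jnp.array([[0.1, 0.0]])
    prev_vel = jnp.array([[0.05, 0.0]])

    next_vel = displacement_fn(next_pos, prev_pos)  # [[0.1, 0.0]]
    acc = next_vel - prev_vel                       # [[0.05, 0.0]]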
/lagrangebench/models/gns.py:
--------------------------------------------------------------------------------
1 | """
2 | Graph Network-based Simulator.
3 | GNS model and feature transform.
4 | """
5 |
6 | from typing import Dict, Tuple
7 |
8 | import haiku as hk
9 | import jax.numpy as jnp
10 | import jraph
11 |
12 | from lagrangebench.utils import NodeType
13 |
14 | from .base import BaseModel
15 | from .utils import build_mlp
16 |
17 |
18 | class GNS(BaseModel):
19 | r"""Graph Network-based Simulator by
20 |     `Sanchez-Gonzalez et al. <https://arxiv.org/abs/2002.09405>`_.
21 |
22 |     GNS is the simplest graph neural network applied to particle dynamics. It is built on
23 | the usual Graph Network architecture, with an encoder, a processor, and a decoder.
24 |
25 | .. math::
26 | \begin{align}
27 | \mathbf{m}_{ij}^{(t+1)} &= \phi \left(
28 | \mathbf{m}_{ij}^{(t)}, \mathbf{h}_i^{(t)}, \mathbf{h}_j^{(t)} \right) \\
29 | \mathbf{h}_i^{(t+1)} &= \psi \left(
30 | \mathbf{h}_i^{(t)}, \sum_{j \in \mathcal{N}(i)} \mathbf{m}_{ij}^{(t+1)}
31 | \right) \\
32 | \end{align}
33 | """
34 |
35 | def __init__(
36 | self,
37 | particle_dimension: int,
38 | latent_size: int,
39 | blocks_per_step: int,
40 | num_mp_steps: int,
41 | particle_type_embedding_size: int,
42 | num_particle_types: int = NodeType.SIZE,
43 | ):
44 | """Initialize the model.
45 |
46 | Args:
47 | particle_dimension: Space dimensionality (e.g. 2 or 3).
48 | latent_size: Size of the latent representations.
49 | blocks_per_step: Number of MLP layers per block.
50 | num_mp_steps: Number of message passing steps.
51 | particle_type_embedding_size: Size of the particle type embedding.
52 | num_particle_types: Max number of particle types.
53 | """
54 | super().__init__()
55 | self._output_size = particle_dimension
56 | self._latent_size = latent_size
57 | self._blocks_per_step = blocks_per_step
58 | self._mp_steps = num_mp_steps
59 | self._num_particle_types = num_particle_types
60 |
61 | self._embedding = hk.Embed(
62 | num_particle_types, particle_type_embedding_size
63 | ) # (9, 16)
64 |
65 | def _encoder(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
66 | """MLP graph encoder."""
67 | node_latents = build_mlp(
68 | self._latent_size, self._latent_size, self._blocks_per_step
69 | )(graph.nodes)
70 | edge_latents = build_mlp(
71 | self._latent_size, self._latent_size, self._blocks_per_step
72 | )(graph.edges)
73 | return jraph.GraphsTuple(
74 | nodes=node_latents,
75 | edges=edge_latents,
76 | globals=graph.globals,
77 | receivers=graph.receivers,
78 | senders=graph.senders,
79 | n_node=jnp.asarray([node_latents.shape[0]]),
80 | n_edge=jnp.asarray([edge_latents.shape[0]]),
81 | )
82 |
83 | def _processor(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
84 | """Sequence of Graph Network blocks."""
85 |
86 | def update_edge_features(
87 | edge_features,
88 | sender_node_features,
89 | receiver_node_features,
90 | _, # globals_
91 | ):
92 | update_fn = build_mlp(
93 | self._latent_size, self._latent_size, self._blocks_per_step
94 | )
95 |             # update edge features from sender/receiver node features and edge features
96 | return update_fn(
97 | jnp.concatenate(
98 | [sender_node_features, receiver_node_features, edge_features],
99 | axis=-1,
100 | )
101 | )
102 |
103 | def update_node_features(
104 | node_features,
105 | _, # aggr_sender_edge_features,
106 | aggr_receiver_edge_features,
107 | __, # globals_,
108 | ):
109 | update_fn = build_mlp(
110 | self._latent_size, self._latent_size, self._blocks_per_step
111 | )
112 | features = [node_features, aggr_receiver_edge_features]
113 | return update_fn(jnp.concatenate(features, axis=-1))
114 |
115 | # Perform iterative message passing by stacking Graph Network blocks
116 | for _ in range(self._mp_steps):
117 | _graph = jraph.GraphNetwork(
118 | update_edge_fn=update_edge_features, update_node_fn=update_node_features
119 | )(graph)
120 | graph = graph._replace(
121 | nodes=_graph.nodes + graph.nodes, edges=_graph.edges + graph.edges
122 | )
123 |
124 | return graph
125 |
126 | def _decoder(self, graph: jraph.GraphsTuple):
127 | """MLP graph node decoder."""
128 | return build_mlp(
129 | self._latent_size,
130 | self._output_size,
131 | self._blocks_per_step,
132 | is_layer_norm=False,
133 | )(graph.nodes)
134 |
135 | def _transform(
136 | self, features: Dict[str, jnp.ndarray], particle_type: jnp.ndarray
137 | ) -> jraph.GraphsTuple:
138 | """Convert physical features to jraph.GraphsTuple for gns."""
139 | n_total_points = features["vel_hist"].shape[0]
140 | node_features = [
141 | features[k]
142 | for k in ["vel_hist", "vel_mag", "bound", "force"]
143 | if k in features
144 | ]
145 | edge_features = [features[k] for k in ["rel_disp", "rel_dist"] if k in features]
146 |
147 | graph = jraph.GraphsTuple(
148 | nodes=jnp.concatenate(node_features, axis=-1),
149 | edges=jnp.concatenate(edge_features, axis=-1),
150 | receivers=features["receivers"],
151 | senders=features["senders"],
152 | n_node=jnp.array([n_total_points]),
153 | n_edge=jnp.array([len(features["senders"])]),
154 | globals=None,
155 | )
156 |
157 | return graph, particle_type
158 |
159 | def __call__(
160 | self, sample: Tuple[Dict[str, jnp.ndarray], jnp.ndarray]
161 | ) -> Dict[str, jnp.ndarray]:
162 | graph, particle_type = self._transform(*sample)
163 |
164 | if self._num_particle_types > 1:
165 | particle_type_embeddings = self._embedding(particle_type)
166 | new_node_features = jnp.concatenate(
167 | [graph.nodes, particle_type_embeddings], axis=-1
168 | )
169 | graph = graph._replace(nodes=new_node_features)
170 | acc = self._decoder(self._processor(self._encoder(graph)))
171 | return {"acc": acc}
172 |
--------------------------------------------------------------------------------
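Like all baselines here, GNS is a Haiku module and is wrapped with hk.transform_with_state
before use, exactly as done in lagrangebench/runner.py further below; a compact sketch with
hypothetical hyperparameters:

    import haiku as hk

    from lagrangebench import models

    def gns_fn(sample):
        return models.GNS(
            particle_dimension=2,  # 2D dataset
            latent_size=128,
            blocks_per_step=2,
            num_mp_steps=10,
            particle_type_embedding_size=16,
        )(sample)

    model = hk.without_apply_rng(hk.transform_with_state(gns_fn))
    # params, state = model.init(key, (features, particle_type))
    # out, state = model.apply(params, state, (features, particle_type))  # out["acc"]: (N, 2)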
/lagrangebench/models/linear.py:
--------------------------------------------------------------------------------
1 | """Simple baseline linear model."""
2 |
3 | from typing import Dict, Tuple
4 |
5 | import haiku as hk
6 | import jax.numpy as jnp
7 | import numpy as np
8 | from jax import vmap
9 |
10 | from .base import BaseModel
11 |
12 |
13 | class Linear(BaseModel):
14 |     r"""Model defining a linear relation between input node features and targets.
15 |
16 | :math:`\mathbf{a}_i = \mathbf{W} \mathbf{x}_i` where :math:`\mathbf{a}_i` are the
17 | output accelerations, :math:`\mathbf{W}` is a learnable weight matrix and
18 | :math:`\mathbf{x}_i` are input features.
19 | """
20 |
21 | def __init__(self, dim_out):
22 | """Initialize the model.
23 |
24 | Args:
25 | dim_out: Output dimensionality.
26 | """
27 | super().__init__()
28 | self.mlp = hk.Linear(dim_out)
29 |
30 | def __call__(
31 | self, sample: Tuple[Dict[str, jnp.ndarray], np.ndarray]
32 | ) -> Dict[str, jnp.ndarray]:
33 | # transform
34 | features, particle_type = sample
35 | x = [
36 | features[k]
37 | for k in ["vel_hist", "vel_mag", "bound", "force"]
38 | if k in features
39 | ] + [particle_type[:, None]]
40 | # call
41 | acc = vmap(self.mlp)(jnp.concatenate(x, axis=-1))
42 | return {"acc": acc}
43 |
--------------------------------------------------------------------------------
/lagrangebench/models/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Dict, Iterable, NamedTuple, Optional
2 |
3 | import e3nn_jax as e3nn
4 | import haiku as hk
5 | import jax
6 | import jax.numpy as jnp
7 | import jraph
8 |
9 | from lagrangebench.utils import NodeType
10 |
11 |
12 | class LinearXav(hk.Linear):
13 | """Linear layer with Xavier init. Avoid distracting 'w_init' everywhere."""
14 |
15 | def __init__(
16 | self,
17 | output_size: int,
18 | with_bias: bool = True,
19 | w_init: Optional[hk.initializers.Initializer] = None,
20 | b_init: Optional[hk.initializers.Initializer] = None,
21 | name: Optional[str] = None,
22 | ):
23 | if w_init is None:
24 | w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform")
25 | super().__init__(output_size, with_bias, w_init, b_init, name)
26 |
27 |
28 | class MLPXav(hk.nets.MLP):
29 | """MLP layer with Xavier init. Avoid distracting 'w_init' everywhere."""
30 |
31 | def __init__(
32 | self,
33 | output_sizes: Iterable[int],
34 | with_bias: bool = True,
35 | w_init: Optional[hk.initializers.Initializer] = None,
36 | b_init: Optional[hk.initializers.Initializer] = None,
37 | activation: Callable = jax.nn.silu,
38 | activate_final: bool = False,
39 | name: Optional[str] = None,
40 | ):
41 | if w_init is None:
42 | w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform")
43 | if not with_bias:
44 | b_init = None
45 | super().__init__(
46 | output_sizes,
47 | w_init,
48 | b_init,
49 | with_bias,
50 | activation,
51 | activate_final,
52 | name,
53 | )
54 |
55 |
56 | class SteerableGraphsTuple(NamedTuple):
57 | r"""
58 | Pack (steerable) node and edge attributes with jraph.GraphsTuple.
59 |
60 | Attributes:
61 | graph: jraph.GraphsTuple, graph structure
62 | node_attributes: (N, irreps.dim), node attributes :math:`\mathbf{\hat{a}}_i`
63 | edge_attributes: (E, irreps.dim), edge attributes :math:`\mathbf{\hat{a}}_{ij}`
64 | additional_message_features: (E, edge_dim), optional message features
65 | """
66 |
67 | graph: jraph.GraphsTuple
68 | node_attributes: Optional[e3nn.IrrepsArray] = None
69 | edge_attributes: Optional[e3nn.IrrepsArray] = None
70 | # NOTE: additional_message_features is in a separate field otherwise it would get
71 | # updated by jraph.GraphNetwork. Actual graph edges are used only for the messages.
72 | additional_message_features: Optional[e3nn.IrrepsArray] = None
73 |
74 |
75 | def node_irreps(
76 | metadata: Dict,
77 | input_seq_length: int,
78 | has_external_force: bool,
79 | has_magnitudes: bool,
80 | has_homogeneous_particles: bool,
81 | ) -> str:
82 | """Compute input node irreps based on which features are available."""
83 | irreps = []
84 | irreps.append(f"{input_seq_length - 1}x1o")
85 | if not any(metadata["periodic_boundary_conditions"]):
86 | irreps.append("2x1o")
87 |
88 | if has_external_force:
89 | irreps.append("1x1o")
90 |
91 | if has_magnitudes:
92 | irreps.append(f"{input_seq_length - 1}x0e")
93 |
94 | if not has_homogeneous_particles:
95 | irreps.append(f"{NodeType.SIZE}x0e")
96 |
97 | return e3nn.Irreps("+".join(irreps))
98 |
99 |
100 | def build_mlp(
101 | latent_size, output_size, num_hidden_layers, is_layer_norm=True, **kwds: Dict
102 | ):
103 | """MLP generation helper using Haiku."""
104 | assert num_hidden_layers >= 1
105 | network = hk.nets.MLP(
106 | [latent_size] * (num_hidden_layers - 1) + [output_size],
107 | **kwds,
108 | activate_final=False,
109 | name="MLP",
110 | )
111 | if is_layer_norm:
112 | l_norm = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
113 | return hk.Sequential([network, l_norm])
114 | else:
115 | return network
116 |
117 |
118 | def features_2d_to_3d(features):
119 | """Add zeros in the z component of 2D features."""
120 | n_nodes = features["vel_hist"].shape[0]
121 | n_edges = features["rel_disp"].shape[0]
122 | n_vels = features["vel_hist"].shape[1]
123 | features["vel_hist"] = jnp.concatenate(
124 | [features["vel_hist"], jnp.zeros((n_nodes, n_vels, 1))], -1
125 | )
126 | features["rel_disp"] = jnp.concatenate(
127 | [features["rel_disp"], jnp.zeros((n_edges, 1))], -1
128 | )
129 | if "bound" in features:
130 | features["bound"] = jnp.concatenate(
131 | [features["bound"], jnp.zeros((n_nodes, 1))], -1
132 | )
133 | if "force" in features:
134 | features["force"] = jnp.concatenate(
135 | [features["force"], jnp.zeros((n_nodes, 1))], -1
136 | )
137 |
138 | return features
139 |
--------------------------------------------------------------------------------
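A worked example of node_irreps above: with input_seq_length=6 (i.e. 5 historical velocities),
no periodic boundaries, an external force, velocity magnitudes, and heterogeneous particles,
the node irreps collect 5 + 2 + 1 vectors and 5 + 9 scalars.

    from lagrangebench.models.utils import node_irreps

    metadata = {"periodic_boundary_conditions": [False, False]}
    irreps = node_irreps(
        metadata,
        input_seq_length=6,
        has_external_force=True,
        has_magnitudes=True,
        has_homogeneous_particles=False,
    )
    # -> 5x1o+2x1o+1x1o+5x0e+9x0e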
/lagrangebench/runner.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | from datetime import datetime
4 | from typing import Callable, Dict, Optional, Tuple, Type, Union
5 |
6 | import haiku as hk
7 | import jax
8 | import jax.numpy as jnp
9 | import jmp
10 | import numpy as np
11 | from e3nn_jax import Irreps
12 | from jax import config
13 | from jax_sph.jax_md import space
14 | from omegaconf import DictConfig, OmegaConf
15 |
16 | from lagrangebench import Trainer, infer, models
17 | from lagrangebench.case_setup import case_builder
18 | from lagrangebench.data import H5Dataset
19 | from lagrangebench.defaults import check_cfg
20 | from lagrangebench.evaluate import averaged_metrics
21 | from lagrangebench.models.utils import node_irreps
22 | from lagrangebench.utils import NodeType
23 |
24 |
25 | def train_or_infer(cfg: Union[Dict, DictConfig]):
26 | if isinstance(cfg, Dict):
27 | cfg = OmegaConf.create(cfg)
28 | # sanity check on the passed configs
29 | check_cfg(cfg)
30 |
31 | mode = cfg.mode
32 | load_ckp = cfg.load_ckp
33 | is_test = cfg.eval.test
34 |
35 | if cfg.dtype == "float64":
36 | config.update("jax_enable_x64", True)
37 |
38 | data_train, data_valid, data_test = setup_data(cfg)
39 |
40 | metadata = data_train.metadata
41 | # neighbors search
42 | bounds = np.array(metadata["bounds"])
43 | box = bounds[:, 1] - bounds[:, 0]
44 |
45 | # setup core functions
46 | case = case_builder(
47 | box=box,
48 | metadata=metadata,
49 | input_seq_length=cfg.model.input_seq_length,
50 | cfg_neighbors=cfg.neighbors,
51 | cfg_model=cfg.model,
52 | noise_std=cfg.train.noise_std,
53 | external_force_fn=data_train.external_force_fn,
54 | dtype=cfg.dtype,
55 | )
56 |
57 | _, particle_type = data_train[0]
58 |
59 | # setup model from configs
60 | model, MODEL = setup_model(
61 | cfg,
62 | metadata=metadata,
63 | homogeneous_particles=particle_type.max() == particle_type.min(),
64 | has_external_force=data_train.external_force_fn is not None,
65 | normalization_stats=case.normalization_stats,
66 | )
67 | model = hk.without_apply_rng(hk.transform_with_state(model))
68 |
69 | # mixed precision training based on this reference:
70 | # https://github.com/deepmind/dm-haiku/blob/main/examples/imagenet/train.py
71 | policy = jmp.get_policy("params=float32,compute=float32,output=float32")
72 | hk.mixed_precision.set_policy(MODEL, policy)
73 |
74 | if mode == "train" or mode == "all":
75 | print("Start training...")
76 |
77 | if cfg.logging.run_name is None:
78 | run_prefix = f"{cfg.model.name}_{data_train.name}"
79 | data_and_time = datetime.today().strftime("%Y%m%d-%H%M%S")
80 | cfg.logging.run_name = f"{run_prefix}_{data_and_time}"
81 |
82 | store_ckp = os.path.join(cfg.logging.ckp_dir, cfg.logging.run_name)
83 | os.makedirs(store_ckp, exist_ok=True)
84 | os.makedirs(os.path.join(store_ckp, "best"), exist_ok=True)
85 | with open(os.path.join(store_ckp, "config.yaml"), "w") as f:
86 | OmegaConf.save(config=cfg, f=f.name)
87 | with open(os.path.join(store_ckp, "best", "config.yaml"), "w") as f:
88 | OmegaConf.save(config=cfg, f=f.name)
89 |
90 | # dictionary of configs which will be stored on W&B
91 | wandb_config = OmegaConf.to_container(cfg)
92 |
93 | trainer = Trainer(
94 | model,
95 | case,
96 | data_train,
97 | data_valid,
98 | cfg.train,
99 | cfg.eval,
100 | cfg.logging,
101 | input_seq_length=cfg.model.input_seq_length,
102 | seed=cfg.seed,
103 | )
104 |
105 | _, _, _ = trainer.train(
106 | step_max=cfg.train.step_max,
107 | load_ckp=load_ckp,
108 | store_ckp=store_ckp,
109 | wandb_config=wandb_config,
110 | )
111 |
112 | if mode == "infer" or mode == "all":
113 | print("Start inference...")
114 |
115 | if mode == "infer":
116 | model_dir = load_ckp
117 | if mode == "all":
118 | model_dir = os.path.join(store_ckp, "best")
119 | assert osp.isfile(os.path.join(model_dir, "params_tree.pkl"))
120 |
121 | cfg.eval.rollout_dir = model_dir.replace("ckp", "rollout")
122 | os.makedirs(cfg.eval.rollout_dir, exist_ok=True)
123 |
124 | if cfg.eval.infer.n_trajs is None:
125 | cfg.eval.infer.n_trajs = cfg.eval.train.n_trajs
126 |
127 | assert model_dir, "model_dir must be specified for inference."
128 | metrics = infer(
129 | model,
130 | case,
131 | data_test if is_test else data_valid,
132 | load_ckp=model_dir,
133 | cfg_eval_infer=cfg.eval.infer,
134 | rollout_dir=cfg.eval.rollout_dir,
135 | n_rollout_steps=cfg.eval.n_rollout_steps,
136 | seed=cfg.seed,
137 | )
138 |
139 | split = "test" if is_test else "valid"
140 | print(f"Metrics of {model_dir} on {split} split:")
141 | print(averaged_metrics(metrics))
142 |
143 | return 0
144 |
145 |
146 | def setup_data(cfg) -> Tuple[H5Dataset, H5Dataset, H5Dataset]:
147 | dataset_path = cfg.dataset.src
148 | dataset_name = cfg.dataset.name
149 | ckp_dir = cfg.logging.ckp_dir
150 | rollout_dir = cfg.eval.rollout_dir
151 | input_seq_length = cfg.model.input_seq_length
152 | n_rollout_steps = cfg.eval.n_rollout_steps
153 | nl_backend = cfg.neighbors.backend
154 |
155 | if not osp.isabs(dataset_path):
156 | dataset_path = osp.join(os.getcwd(), dataset_path)
157 |
158 | if ckp_dir is not None:
159 | os.makedirs(ckp_dir, exist_ok=True)
160 | if rollout_dir is not None:
161 | os.makedirs(rollout_dir, exist_ok=True)
162 |
163 | # dataloader
164 | data_train = H5Dataset(
165 | "train",
166 | dataset_path=dataset_path,
167 | name=dataset_name,
168 | input_seq_length=input_seq_length,
169 | extra_seq_length=cfg.train.pushforward.unrolls[-1],
170 | nl_backend=nl_backend,
171 | )
172 | data_valid = H5Dataset(
173 | "valid",
174 | dataset_path=dataset_path,
175 | name=dataset_name,
176 | input_seq_length=input_seq_length,
177 | extra_seq_length=n_rollout_steps,
178 | nl_backend=nl_backend,
179 | )
180 | data_test = H5Dataset(
181 | "test",
182 | dataset_path=dataset_path,
183 | name=dataset_name,
184 | input_seq_length=input_seq_length,
185 | extra_seq_length=n_rollout_steps,
186 | nl_backend=nl_backend,
187 | )
188 |
189 | return data_train, data_valid, data_test
190 |
191 |
192 | def setup_model(
193 | cfg,
194 | metadata: Dict,
195 | homogeneous_particles: bool = False,
196 | has_external_force: bool = False,
197 | normalization_stats: Optional[Dict] = None,
198 | ) -> Tuple[Callable, Type]:
199 | """Setup model based on cfg."""
200 | model_name = cfg.model.name.lower()
201 | input_seq_length = cfg.model.input_seq_length
202 | magnitude_features = cfg.model.magnitude_features
203 |
204 | if model_name == "gns":
205 |
206 | def model_fn(x):
207 | return models.GNS(
208 | particle_dimension=metadata["dim"],
209 | latent_size=cfg.model.latent_dim,
210 | blocks_per_step=cfg.model.num_mlp_layers,
211 | num_mp_steps=cfg.model.num_mp_steps,
212 | num_particle_types=NodeType.SIZE,
213 | particle_type_embedding_size=16,
214 | )(x)
215 |
216 | MODEL = models.GNS
217 | elif model_name == "segnn":
218 | # Hx1o vel, Hx0e vel, 2x1o boundary, 9x0e type
219 | node_feature_irreps = node_irreps(
220 | metadata,
221 | input_seq_length,
222 | has_external_force,
223 | magnitude_features,
224 | homogeneous_particles,
225 | )
226 | # 1o displacement, 0e distance
227 | edge_feature_irreps = Irreps("1x1o + 1x0e")
228 |
229 | def model_fn(x):
230 | return models.SEGNN(
231 | node_features_irreps=node_feature_irreps,
232 | edge_features_irreps=edge_feature_irreps,
233 | scalar_units=cfg.model.latent_dim,
234 | lmax_hidden=cfg.model.lmax_hidden,
235 | lmax_attributes=cfg.model.lmax_attributes,
236 | output_irreps=Irreps("1x1o"),
237 | num_mp_steps=cfg.model.num_mp_steps,
238 | n_vels=cfg.model.input_seq_length - 1,
239 | velocity_aggregate=cfg.model.velocity_aggregate,
240 | homogeneous_particles=homogeneous_particles,
241 | blocks_per_step=cfg.model.num_mlp_layers,
242 | norm=cfg.model.segnn_norm,
243 | )(x)
244 |
245 | MODEL = models.SEGNN
246 | elif model_name == "egnn":
247 | box = cfg.box
248 | if jnp.array(metadata["periodic_boundary_conditions"]).any():
249 | displacement_fn, shift_fn = space.periodic(jnp.array(box))
250 | else:
251 | displacement_fn, shift_fn = space.free()
252 |
253 | displacement_fn = jax.vmap(displacement_fn, in_axes=(0, 0))
254 | shift_fn = jax.vmap(shift_fn, in_axes=(0, 0))
255 |
256 | def model_fn(x):
257 | return models.EGNN(
258 | hidden_size=cfg.model.latent_dim,
259 | output_size=1,
260 | dt=metadata["dt"] * metadata["write_every"],
261 | displacement_fn=displacement_fn,
262 | shift_fn=shift_fn,
263 | normalization_stats=normalization_stats,
264 | num_mp_steps=cfg.model.num_mp_steps,
265 | n_vels=input_seq_length - 1,
266 | residual=True,
267 | )(x)
268 |
269 | MODEL = models.EGNN
270 | elif model_name == "painn":
271 | assert magnitude_features, "PaiNN requires magnitudes"
272 | radius = metadata["default_connectivity_radius"] * 1.5
273 |
274 | def model_fn(x):
275 | return models.PaiNN(
276 | hidden_size=cfg.model.latent_dim,
277 | output_size=1,
278 | n_vels=input_seq_length - 1,
279 | radial_basis_fn=models.painn.gaussian_rbf(20, radius, trainable=True),
280 | cutoff_fn=models.painn.cosine_cutoff(radius),
281 | num_mp_steps=cfg.model.num_mp_steps,
282 | )(x)
283 |
284 | MODEL = models.PaiNN
285 | elif model_name == "linear":
286 |
287 | def model_fn(x):
288 | return models.Linear(dim_out=metadata["dim"])(x)
289 |
290 | MODEL = models.Linear
291 |
292 | return model_fn, MODEL
293 |
--------------------------------------------------------------------------------
/lagrangebench/train/__init__.py:
--------------------------------------------------------------------------------
1 | """Trainer method and training tricks."""
2 |
3 | from .trainer import Trainer
4 |
5 | __all__ = ["Trainer"]
6 |
--------------------------------------------------------------------------------
/lagrangebench/train/strats.py:
--------------------------------------------------------------------------------
1 | """Training tricks and strategies, currently: random-walk noise and push forward."""
2 |
3 | from typing import Tuple
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | from jax_sph.jax_md.partition import space
8 |
9 | from lagrangebench.utils import get_kinematic_mask
10 |
11 |
12 | def add_gns_noise(
13 | key: jax.Array,
14 | pos_input: jnp.ndarray,
15 | particle_type: jnp.ndarray,
16 | input_seq_length: int,
17 | noise_std: float,
18 | shift_fn: space.ShiftFn,
19 | ) -> Tuple[jnp.ndarray, jnp.ndarray]:
20 | r"""GNS-like random walk noise injection as described by
21 |     `Sanchez-Gonzalez et al. <https://arxiv.org/abs/2002.09405>`_.
22 |
23 |     Applies random-walk noise to the input positions and adjusts the targets accordingly
24 |     to keep the trajectory consistent. It works by drawing independent samples from
25 |     :math:`\mathcal{N}(0, \sigma_v^{(t)})` for each input state. The noise is
26 |     accumulated as a random walk and added to the velocity sequence.
27 |     Each :math:`\sigma_v^{(t)}` is set so that the last step of the random walk has
28 |     :math:`\sigma_v^{(input\_seq\_length)}=noise\_std`. Based on the noised velocities,
29 |     positions are adjusted such that :math:`\dot{p}^{t_k} = p^{t_k} - p^{t_{k-1}}`.
30 |
31 | Args:
32 | key: Random key.
33 | pos_input: Clean input positions. Shape:
34 | (num_particles_max, input_seq_length + pushforward["unrolls"][-1] + 1, dim)
35 | particle_type: Particle type vector. Shape: (num_particles_max,)
36 | input_seq_length: Input sequence length, as in the configs.
37 | noise_std: Noise standard deviation at the last sequence step.
38 | shift_fn: Shift function.
39 | """
40 | isl = input_seq_length
41 | # random-walk noise in the velocity applied to the first input_seq_length positions
42 | key, pos_input_noise = _get_random_walk_noise_for_pos_sequence(
43 | key, pos_input[:, :input_seq_length], noise_std_last_step=noise_std
44 | )
45 |
46 | kinematic_mask = get_kinematic_mask(particle_type)
47 | pos_input_noise = jnp.where(kinematic_mask[:, None, None], 0.0, pos_input_noise)
48 | # adjust targets based on the noise from the last input position
49 | n_potential_targets = pos_input[:, isl:].shape[1]
50 | pos_target_noise = pos_input_noise[:, -1][:, None, :]
51 | pos_target_noise = jnp.tile(pos_target_noise, (1, n_potential_targets, 1))
52 | pos_input_noise = jnp.concatenate([pos_input_noise, pos_target_noise], axis=1)
53 |
54 | shift_vmap = jax.vmap(shift_fn, in_axes=(0, 0))
55 | shift_dvmap = jax.vmap(shift_vmap, in_axes=(0, 0))
56 | pos_input_noisy = shift_dvmap(pos_input, pos_input_noise)
57 |
58 | return key, pos_input_noisy
59 |
60 |
61 | def _get_random_walk_noise_for_pos_sequence(
62 | key, position_sequence, noise_std_last_step
63 | ):
64 | key, subkey = jax.random.split(key)
65 | velocity_sequence_shape = list(position_sequence.shape)
66 | velocity_sequence_shape[1] -= 1
67 | n_velocities = velocity_sequence_shape[1]
68 |
69 | velocity_sequence_noise = jax.random.normal(
70 | subkey, shape=tuple(velocity_sequence_shape)
71 | )
72 | velocity_sequence_noise *= noise_std_last_step / (n_velocities**0.5)
73 | velocity_sequence_noise = jnp.cumsum(velocity_sequence_noise, axis=1)
74 |
75 | position_sequence_noise = jnp.concatenate(
76 | [
77 | jnp.zeros_like(velocity_sequence_noise[:, 0:1]),
78 | jnp.cumsum(velocity_sequence_noise, axis=1),
79 | ],
80 | axis=1,
81 | )
82 |
83 | return key, position_sequence_noise
84 |
85 |
86 | def push_forward_sample_steps(key, step, pushforward):
87 | """Sample the number of unroll steps based on the current training step and the
88 | specified pushforward configuration.
89 |
90 | Args:
91 | key: Random key
92 | step: Current training step
93 | pushforward: Pushforward configuration
94 | """
95 | key, key_unroll = jax.random.split(key, 2)
96 |
97 | # steps needs to be an ordered list
98 | steps = jnp.array(pushforward.steps)
99 | assert all(steps[i] <= steps[i + 1] for i in range(len(steps) - 1))
100 |
101 | # until which index to sample from
102 | idx = (step > steps).sum()
103 |
104 | unroll_steps = jax.random.choice(
105 | key_unroll,
106 | a=jnp.array(pushforward.unrolls[:idx]),
107 | p=jnp.array(pushforward.probs[:idx]),
108 | )
109 | return key, unroll_steps
110 |
111 |
112 | def push_forward_build(model_apply, case):
113 | r"""Build the push forward function, introduced by
114 |     `Brandstetter et al. <https://arxiv.org/abs/2202.03376>`_.
115 |
116 | Pushforward works by adding a stability "pushforward" loss term, in the form of an
117 | adversarial style loss.
118 |
119 | .. math::
120 | L_{pf} = \mathbb{E}_k \mathbb{E}_{u^{k+1} | u^k}
121 |         \mathbb{E}_{\epsilon} \left[ \mathcal{L}(f(u^k + \epsilon), u^{k+1}) \right]
122 | 
123 |     where :math:`\epsilon` is given by :math:`u^k + \epsilon = f(u^{k-1})`, i.e. the
124 |     2-step unroll of the solver :math:`f` (from step :math:`k-1` to :math:`k+1`).
125 |     The total loss is then :math:`L_{total}=\mathcal{L}(f(u^k), u^{k+1}) + L_{pf}`.
126 |     Similarly, for :math:`S > 2` pushforward steps, :math:`L_{pf}` is extended to
127 |     :math:`u^{k-S} \dots u^{k-1}` with accumulated :math:`\epsilon` perturbations.
128 |
129 | In practice, this is implemented by unrolling the solver for two steps, but only
130 | running gradients through the last unroll step.
131 |
132 | Args:
133 | model_apply: Model apply function
134 | case: Case setup function
135 | """
136 |
137 | @jax.jit
138 | def push_forward_fn(features, current_pos, particle_type, neighbors, params, state):
139 | """Push forward function.
140 |
141 | Args:
142 | features: Input features
143 | current_pos: Current position
144 | particle_type: Particle type vector
145 | neighbors: Neighbor list
146 | params: Model parameters
147 | state: Model state
148 | """
149 | # no buffer overflow check here, since push forward acts on later epochs
150 | pred, _ = model_apply(params, state, (features, particle_type))
151 | next_pos = case.integrate(pred, current_pos)
152 | current_pos = jnp.concatenate(
153 | [current_pos[:, 1:], next_pos[:, None, :]], axis=1
154 | )
155 |
156 | features, neighbors = case.preprocess_eval(
157 | (current_pos, particle_type), neighbors
158 | )
159 | return current_pos, neighbors, features
160 |
161 | return push_forward_fn
162 |
--------------------------------------------------------------------------------
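A small numerical check of the noise scaling in _get_random_walk_noise_for_pos_sequence above:
each velocity-noise step has standard deviation noise_std / sqrt(n_velocities), so after the
cumulative sum the last step of the random walk has standard deviation noise_std, as stated in
the add_gns_noise docstring.

    import jax
    import jax.numpy as jnp

    noise_std, n_vels = 3e-4, 5
    key = jax.random.PRNGKey(0)

    # (n_samples, n_vels) independent noise steps, scaled as in the implementation above
    steps = jax.random.normal(key, (100_000, n_vels)) * noise_std / (n_vels**0.5)
    walk = jnp.cumsum(steps, axis=1)

    print(jnp.std(walk[:, -1]))  # ~3e-4 == noise_std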
/lagrangebench/utils.py:
--------------------------------------------------------------------------------
1 | """General utils and config structures."""
2 |
3 | import enum
4 | import json
5 | import os
6 | import pickle
7 | import random
8 | from typing import Callable, Tuple
9 |
10 | import cloudpickle
11 | import jax
12 | import jax.numpy as jnp
13 | import numpy as np
14 | import torch
15 |
16 |
17 | class NodeType(enum.IntEnum):
18 | """Particle types."""
19 |
20 | PAD_VALUE = -1
21 | FLUID = 0
22 | SOLID_WALL = 1
23 | MOVING_WALL = 2
24 | RIGID_BODY = 3
25 | SIZE = 9
26 |
27 |
28 | def get_kinematic_mask(particle_type):
29 | """Return a boolean mask, set to true for all kinematic (obstacle) particles."""
30 | res = jnp.logical_or(
31 | particle_type == NodeType.SOLID_WALL, particle_type == NodeType.MOVING_WALL
32 | )
33 | # In datasets with variable number of particles we treat padding as kinematic nodes
34 | res = jnp.logical_or(res, particle_type == NodeType.PAD_VALUE)
35 | return res
36 |
37 |
38 | def broadcast_to_batch(sample, batch_size: int):
39 | """Broadcast a pytree to a batched one with first dimension batch_size."""
40 | assert batch_size > 0
41 | return jax.tree_map(lambda x: jnp.repeat(x[None, ...], batch_size, axis=0), sample)
42 |
43 |
44 | def broadcast_from_batch(batch, index: int):
45 |     """Extract the sample at `index` from a batched pytree."""
46 | assert index >= 0
47 | return jax.tree_map(lambda x: x[index], batch)
48 |
49 |
50 | def save_pytree(ckp_dir: str, pytree_obj, name) -> None:
51 | """Save a pytree to a directory."""
52 | with open(os.path.join(ckp_dir, f"{name}_array.npy"), "wb") as f:
53 | for x in jax.tree_leaves(pytree_obj):
54 | np.save(f, x, allow_pickle=False)
55 |
56 | tree_struct = jax.tree_map(lambda t: 0, pytree_obj)
57 | with open(os.path.join(ckp_dir, f"{name}_tree.pkl"), "wb") as f:
58 | pickle.dump(tree_struct, f)
59 |
60 |
61 | def save_haiku(ckp_dir: str, params, state, opt_state, metadata_ckp) -> None:
62 | """Save params, state and optimizer state to ckp_dir.
63 |
64 | Additionally it tracks and saves the best model to ckp_dir/best.
65 |
66 | See: https://github.com/deepmind/dm-haiku/issues/18
67 | """
68 | save_pytree(ckp_dir, params, "params")
69 | save_pytree(ckp_dir, state, "state")
70 |
71 | with open(os.path.join(ckp_dir, "opt_state.pkl"), "wb") as f:
72 | cloudpickle.dump(opt_state, f)
73 | with open(os.path.join(ckp_dir, "metadata_ckp.json"), "w") as f:
74 | json.dump(metadata_ckp, f)
75 |
76 | # only run for the main checkpoint directory (not best)
77 | if "best" not in ckp_dir:
78 | ckp_dir_best = os.path.join(ckp_dir, "best")
79 | metadata_best_path = os.path.join(ckp_dir, "best", "metadata_ckp.json")
80 | tag = ""
81 |
82 | if os.path.exists(metadata_best_path): # all except first step
83 | with open(metadata_best_path, "r") as fp:
84 | metadata_ckp_best = json.loads(fp.read())
85 |
86 | # if loss is better than best previous loss, save to best model directory
87 | if metadata_ckp["loss"] < metadata_ckp_best["loss"]:
88 | save_haiku(ckp_dir_best, params, state, opt_state, metadata_ckp)
89 | tag = " (best so far)"
90 | else: # first step
91 | save_haiku(ckp_dir_best, params, state, opt_state, metadata_ckp)
92 |
93 | print(
94 | f"saved model to {ckp_dir} at step {metadata_ckp['step']}"
95 | f" with loss {metadata_ckp['loss']}{tag}"
96 | )
97 |
98 |
99 | def load_pytree(model_dir: str, name):
100 | """Load a pytree from a directory."""
101 | with open(os.path.join(model_dir, f"{name}_tree.pkl"), "rb") as f:
102 | tree_struct = pickle.load(f)
103 |
104 | leaves, treedef = jax.tree_flatten(tree_struct)
105 |
106 | with open(os.path.join(model_dir, f"{name}_array.npy"), "rb") as f:
107 | flat_state = [np.load(f) for _ in leaves]
108 |
109 | return jax.tree_unflatten(treedef, flat_state)
110 |
111 |
112 | def load_haiku(model_dir: str):
113 | """Load params, state, optimizer state and last training step from model_dir.
114 |
115 | See: https://github.com/deepmind/dm-haiku/issues/18
116 | """
117 | params = load_pytree(model_dir, "params")
118 | state = load_pytree(model_dir, "state")
119 |
120 | with open(os.path.join(model_dir, "opt_state.pkl"), "rb") as f:
121 | opt_state = cloudpickle.load(f)
122 |
123 | with open(os.path.join(model_dir, "metadata_ckp.json"), "r") as fp:
124 | metadata_ckp = json.loads(fp.read())
125 |
126 | print(f"Loaded model from {model_dir} at step {metadata_ckp['step']}")
127 |
128 | return params, state, opt_state, metadata_ckp["step"]
129 |
130 |
131 | def get_num_params(params):
132 | """Get the number of parameters in a Haiku model."""
133 | return sum(np.prod(p.shape) for p in jax.tree_leaves(params))
134 |
135 |
136 | def print_params_shapes(params, prefix=""):
137 | if not isinstance(params, dict):
138 | print(f"{prefix: <40}, shape = {params.shape}")
139 | else:
140 | for k, v in params.items():
141 | print_params_shapes(v, prefix=prefix + k)
142 |
143 |
144 | def set_seed(seed: int) -> Tuple[jax.Array, Callable, torch.Generator]:
145 | """Set seeds for jax, random and torch."""
146 | # first PRNG key
147 | key = jax.random.PRNGKey(seed)
148 | np.random.seed(seed)
149 | random.seed(seed)
150 | torch.manual_seed(seed)
151 |
152 | # dataloader-related seeds
153 | def seed_worker(_):
154 | worker_seed = torch.initial_seed() % 2**32
155 | np.random.seed(worker_seed)
156 | random.seed(worker_seed)
157 |
158 | generator = torch.Generator()
159 | generator.manual_seed(seed)
160 |
161 | return key, seed_worker, generator
162 |
--------------------------------------------------------------------------------
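A minimal round-trip sketch of the checkpointing helpers above. The toy pytrees and the `/tmp/ckp_demo` directory are illustrative assumptions; in an actual run, `params`/`state` would come from a Haiku model and `opt_state` from an optax optimizer.

import os
import numpy as np
from lagrangebench.utils import save_haiku, load_haiku

# Toy pytrees standing in for Haiku params/state and an optax optimizer state (assumed shapes).
params = {"linear": {"w": np.zeros((3, 3)), "b": np.zeros(3)}}
state = {"counter": np.array(0)}
opt_state = (np.zeros(1),)

ckp_dir = "/tmp/ckp_demo"  # hypothetical checkpoint directory
os.makedirs(os.path.join(ckp_dir, "best"), exist_ok=True)

# The first call also populates ckp_dir/best, since no previous best loss exists yet.
save_haiku(ckp_dir, params, state, opt_state, {"step": 100, "loss": 0.42})

params, state, opt_state, step = load_haiku(ckp_dir)  # step == 100
--------------------------------------------------------------------------------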
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from omegaconf import DictConfig, OmegaConf
4 |
5 |
6 | def check_subset(superset, subset, full_key=""):
7 | """Check that the keys of 'subset' are a subset of 'superset'."""
8 | for k, v in subset.items():
9 | key = full_key + k
10 | if isinstance(v, dict):
11 | check_subset(superset[k], v, key + ".")
12 | else:
13 | msg = f"cli_args must be a subset of the defaults. Wrong cli key: '{key}'"
14 | assert k in superset, msg
15 |
16 |
17 | def load_embedded_configs(config_path: str, cli_args: DictConfig) -> DictConfig:
18 | """Loads all 'extends' embedded configs and merge them with the cli overwrites."""
19 |
20 | cfgs = [OmegaConf.load(config_path)]
21 | while "extends" in cfgs[0]:
22 | extends_path = cfgs[0]["extends"]
23 | del cfgs[0]["extends"]
24 |
25 | # walk up the parent configs until the defaults are reached
26 | if extends_path != "LAGRANGEBENCH_DEFAULTS":
27 | cfgs = [OmegaConf.load(extends_path)] + cfgs
28 | else:
29 | from lagrangebench.defaults import defaults
30 |
31 | cfgs = [defaults] + cfgs
32 |
33 | # assert that the cli_args are a subset of the defaults if inheritance from
34 | # defaults is used.
35 | check_subset(cfgs[0], cli_args)
36 |
37 | break
38 |
39 | # merge all embedded configs and give highest priority to cli_args
40 | cfg = OmegaConf.merge(*cfgs, cli_args)
41 | return cfg
42 |
43 |
44 | if __name__ == "__main__":
45 | cli_args = OmegaConf.from_cli()
46 | assert ("config" in cli_args) != (
47 | "load_ckp" in cli_args
48 | ), "You must specify one of 'config' or 'load_ckp'."
49 |
50 | if "config" in cli_args: # start from config.yaml
51 | config_path = cli_args.config
52 | elif "load_ckp" in cli_args: # start from a checkpoint
53 | config_path = os.path.join(cli_args.load_ckp, "config.yaml")
54 |
55 | # values that need to be specified before importing jax
56 | cli_args.gpu = cli_args.get("gpu", -1)
57 | cli_args.xla_mem_fraction = cli_args.get("xla_mem_fraction", 0.75)
58 |
59 | # specify cuda device
60 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 from TensorFlow
61 | os.environ["CUDA_VISIBLE_DEVICES"] = str(cli_args.gpu)
62 | if cli_args.gpu == -1:
63 | os.environ["JAX_PLATFORMS"] = "cpu"
64 | os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(cli_args.xla_mem_fraction)
65 |
66 | # The following line makes the code deterministic on GPUs, but also extremely slow.
67 | # os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
68 |
69 | cfg = load_embedded_configs(config_path, cli_args)
70 |
71 | print("#" * 79, "\nStarting a LagrangeBench run with the following configs:")
72 | print(OmegaConf.to_yaml(cfg))
73 | print("#" * 79)
74 |
75 | from lagrangebench.runner import train_or_infer
76 |
77 | train_or_infer(cfg)
78 |
--------------------------------------------------------------------------------
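A minimal sketch of how the 'extends' chain in `load_embedded_configs` resolves. The temporary files and the keys used below are illustrative assumptions; only the merge order (parent config first, then child, then CLI overrides) is taken from the code above.

import os
import tempfile
from omegaconf import OmegaConf
from main import load_embedded_configs  # assumes the repository root is the working directory

tmp = tempfile.mkdtemp()

# Hypothetical two-level chain: child.yaml extends base.yaml.
with open(os.path.join(tmp, "base.yaml"), "w") as f:
    f.write("train:\n  step_max: 100\n  noise_std: 0.001\n")
with open(os.path.join(tmp, "child.yaml"), "w") as f:
    f.write(f"extends: {os.path.join(tmp, 'base.yaml')}\ntrain:\n  noise_std: 0.0\n")

# CLI arguments get the highest priority in the final merge.
cli_args = OmegaConf.create({"train": {"step_max": 10}})
cfg = load_embedded_configs(os.path.join(tmp, "child.yaml"), cli_args)
print(cfg.train.step_max, cfg.train.noise_std)  # -> 10 0.0
--------------------------------------------------------------------------------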
/notebooks/media/scatter.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tumaer/lagrangebench/b880a6c84a93792d2499d2a9b8ba3a077ddf44e2/notebooks/media/scatter.gif
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "lagrangebench"
3 | version = "0.2.0"
4 | description = "LagrangeBench: A Lagrangian Fluid Mechanics Benchmarking Suite"
5 | authors = [
6 | "Artur Toshev, Gianluca Galletti "
7 | ]
8 | license = "MIT"
9 | readme = "README.md"
10 | homepage = "https://lagrangebench.readthedocs.io/"
11 | documentation = "https://lagrangebench.readthedocs.io/"
12 | repository = "https://github.com/tumaer/lagrangebench"
13 | keywords = [
14 | "smoothed-particle-hydrodynamics",
15 | "benchmark-suite",
16 | "lagrangian-dynamics",
17 | "graph-neural-networks",
18 | "lagrangian-particles",
19 | ]
20 | classifiers = [
21 | "Development Status :: 3 - Alpha",
22 | "Intended Audience :: Science/Research",
23 | "License :: OSI Approved :: MIT License",
24 | "Operating System :: MacOS",
25 | "Operating System :: POSIX :: Linux",
26 | "Programming Language :: Python :: 3",
27 | "Programming Language :: Python :: 3.9",
28 | "Programming Language :: Python :: 3.10",
29 | "Programming Language :: Python :: 3.11",
30 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
31 | "Topic :: Scientific/Engineering :: Physics",
32 | "Topic :: Scientific/Engineering :: Hydrology",
33 | "Topic :: Software Development :: Libraries :: Python Modules",
34 | "Typing :: Typed",
35 | ]
36 |
37 | [tool.poetry.dependencies]
38 | python = ">=3.9,<=3.11"
39 | cloudpickle = ">=2.2.1"
40 | h5py = ">=3.9.0"
41 | PyYAML = ">=6.0"
42 | numpy = ">=1.24.4"
43 | wandb = ">=0.15.11"
44 | pyvista = ">=0.42.2"
45 | jax = {version = "0.4.29", extras = ["cpu"]}
46 | jaxlib = "0.4.29"
47 | dm-haiku = ">=0.0.10"
48 | e3nn-jax = "0.20.3"
49 | jmp = ">=0.0.4"
50 | jraph = "0.0.6.dev0"
51 | optax = "0.1.7"
52 | ott-jax = ">=0.4.2"
53 | matscipy = ">=0.8.0"
54 | torch = {version = "2.3.1+cpu", source = "torchcpu"}
55 | wget = ">=3.2"
56 | omegaconf = ">=2.3.0"
57 | jax-sph = ">=0.0.3"
58 |
59 | [tool.poetry.group.dev.dependencies]
60 | # mypy = ">=1.8.0" - consider in the future
61 | pre-commit = ">=3.3.1"
62 | pytest = ">=7.3.1"
63 | pytest-cov = ">=4.1.0"
64 | ruff = "0.2.2"
65 | ipykernel = ">=6.25.1"
66 |
67 | [tool.poetry.group.docs.dependencies]
68 | sphinx = "7.2.6"
69 | sphinx-rtd-theme = "1.3.0"
70 | toml = ">=0.10.2"
71 |
72 | [[tool.poetry.source]]
73 | name = "torchcpu"
74 | url = "https://download.pytorch.org/whl/cpu"
75 | priority = "explicit"
76 |
77 | [tool.ruff]
78 | exclude = [
79 | ".git",
80 | ".venv",
81 | "venv",
82 | "docs/_build",
83 | "dist",
84 | "notebooks",
85 | ]
86 | show-fixes = true
87 | line-length = 88
88 |
89 | [tool.ruff.lint]
90 | ignore = ["F811", "E402"]
91 | select = [
92 | "E", # pycodestyle
93 | "F", # Pyflakes
94 | "SIM", # flake8-simplify
95 | "I", # isort
96 | # "D", # pydocstyle - consider in the future
97 | ]
98 |
99 | [tool.ruff.lint.isort]
100 | known-third-party = ["wandb"]
101 |
102 | [tool.pytest.ini_options]
103 | testpaths = "tests/"
104 | addopts = "--cov=lagrangebench --cov-fail-under=50"
105 | filterwarnings = [
106 | # ignore all deprecation warnings except from lagrangebench
107 | "ignore::DeprecationWarning:^(?!.*lagrangebench).*"
108 | ]
109 |
110 | # Install bumpversion with: pip install -U poetry-bumpversion
111 | # Use: poetry version {major|minor|patch}
112 | [tool.poetry_bumpversion.file."lagrangebench/__init__.py"]
113 | [build-system]
114 | requires = ["poetry-core"]
115 | build-backend = "poetry.core.masonry.api"
116 |
--------------------------------------------------------------------------------
/requirements_cuda.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cpu
2 |
3 | cloudpickle
4 | dm_haiku>=0.0.10
5 | e3nn_jax==0.20.3
6 | h5py
7 | jax-sph>=0.0.3
8 | jax[cuda12]==0.4.29
9 | jmp>=0.0.4
10 | jraph>=0.0.6.dev0
11 | matscipy>=0.8.0
12 | omegaconf>=2.3.0
13 | optax>=0.1.7
14 | ott-jax>=0.4.2
15 | pyvista
16 | PyYAML
17 | torch==2.3.1+cpu
18 | wandb
19 | wget
20 |
--------------------------------------------------------------------------------
/tests/3D_LJ_3_1214every1/metadata.json:
--------------------------------------------------------------------------------
1 | {
2 | "solver": "JAXMD",
3 | "dim": 3,
4 | "dx": 1.4,
5 | "dt": 0.005,
6 | "t_end": 10.0,
7 | "sequence_length_train": 1214,
8 | "num_trajs_train": 1,
9 | "sequence_length_test": 405,
10 | "num_trajs_test": 1,
11 | "num_particles_max": 3,
12 | "periodic_boundary_conditions": [
13 | true,
14 | true,
15 | true
16 | ],
17 | "bounds": [
18 | [
19 | 0.0,
20 | 5.0
21 | ],
22 | [
23 | 0.0,
24 | 5.0
25 | ],
26 | [
27 | 0.0,
28 | 5.0
29 | ]
30 | ],
31 | "default_connectivity_radius": 3.0,
32 | "vel_mean": [
33 | -5.573862482677328e-10,
34 | 4.917874996124283e-10,
35 | -1.3441651125489784e-09
36 | ],
37 | "vel_std": [
38 | 0.006350979674607515,
39 | 0.005811989773064852,
40 | 0.003586509730666876
41 | ],
42 | "acc_mean": [
43 | -3.2785833076198756e-11,
44 | -6.557166615239751e-11,
45 | 0.0
46 | ],
47 | "acc_std": [
48 | 0.0011505373986437917,
49 | 0.0005201193853281438,
50 | 0.00039340186049230397
51 | ],
52 | "description": "System of 3 Lennard-Jones particles in a periodic 3D box simulated with JAX-MD. Can be used to test the preprocessing and rollout utilities."
53 | }
54 |
--------------------------------------------------------------------------------
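A small sketch of how such a metadata file is typically consumed (the same `bounds` arithmetic appears in the test cases below); the hard-coded path assumes the repository root as working directory.

import json
import numpy as np

with open("tests/3D_LJ_3_1214every1/metadata.json", "r") as f:
    metadata = json.load(f)

bounds = np.array(metadata["bounds"])  # (dim, 2) lower/upper bound per dimension
box = bounds[:, 1] - bounds[:, 0]      # box extent -> [5.0, 5.0, 5.0]
print(metadata["dim"], box, metadata["default_connectivity_radius"])
--------------------------------------------------------------------------------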
/tests/3D_LJ_3_1214every1/test.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tumaer/lagrangebench/b880a6c84a93792d2499d2a9b8ba3a077ddf44e2/tests/3D_LJ_3_1214every1/test.h5
--------------------------------------------------------------------------------
/tests/3D_LJ_3_1214every1/train.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tumaer/lagrangebench/b880a6c84a93792d2499d2a9b8ba3a077ddf44e2/tests/3D_LJ_3_1214every1/train.h5
--------------------------------------------------------------------------------
/tests/3D_LJ_3_1214every1/valid.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tumaer/lagrangebench/b880a6c84a93792d2499d2a9b8ba3a077ddf44e2/tests/3D_LJ_3_1214every1/valid.h5
--------------------------------------------------------------------------------
/tests/case_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import numpy as np
6 |
7 | from lagrangebench.case_setup import case_builder
8 |
9 |
10 | class TestCaseBuilder(unittest.TestCase):
11 | """Class for unit testing the case builder functions."""
12 |
13 | def setUp(self):
14 | self.metadata = {
15 | "num_particles_max": 3,
16 | "periodic_boundary_conditions": [True, True, True],
17 | "default_connectivity_radius": 0.3,
18 | "bounds": [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],
19 | "acc_mean": [0.0, 0.0, 0.0],
20 | "acc_std": [1.0, 1.0, 1.0],
21 | "vel_mean": [0.0, 0.0, 0.0],
22 | "vel_std": [1.0, 1.0, 1.0],
23 | }
24 |
25 | bounds = np.array(self.metadata["bounds"])
26 | box = bounds[:, 1] - bounds[:, 0]
27 |
28 | self.case = case_builder(
29 | box,
30 | self.metadata,
31 | input_seq_length=3, # two past velocities
32 | cfg_neighbors={"backend": "jaxmd_vmap", "multiplier": 1.25},
33 | cfg_model={"isotropic_norm": False, "magnitude_features": False},
34 | noise_std=0.0,
35 | external_force_fn=None,
36 | )
37 | self.key = jax.random.PRNGKey(0)
38 |
39 | # position input shape (num_particles, sequence_len, dim) = (3, 5, 3)
40 | self.position_data = np.array(
41 | [
42 | [
43 | [0.5, 0.5, 0.5],
44 | [0.5, 0.5, 0.5],
45 | [0.5, 0.5, 0.5],
46 | [0.5, 0.5, 0.5],
47 | [0.5, 0.5, 0.5],
48 | ],
49 | [
50 | [0.7, 0.5, 0.5],
51 | [0.9, 0.5, 0.5],
52 | [0.1, 0.5, 0.5],
53 | [0.3, 0.5, 0.5],
54 | [0.5, 0.5, 0.5],
55 | ],
56 | [
57 | [0.8, 0.6, 0.5],
58 | [0.8, 0.6, 0.5],
59 | [0.9, 0.6, 0.5],
60 | [0.2, 0.6, 0.5],
61 | [0.6, 0.6, 0.5],
62 | ],
63 | ]
64 | )
65 | self.particle_types = np.array([0, 0, 0])
66 |
67 | _, _, _, neighbors = self.case.allocate(
68 | self.key, (self.position_data, self.particle_types)
69 | )
70 | self.neighbors = neighbors
71 |
72 | def test_allocate(self):
73 | # test PBC and velocity and acceleration computation without noise
74 | key, features, target_dict, neighbors = self.case.allocate(
75 | self.key, (self.position_data, self.particle_types)
76 | )
77 | self.assertTrue(
78 | (
79 | neighbors.idx == jnp.array([[0, 1, 2, 2, 1, 3], [0, 1, 1, 2, 2, 3]])
80 | ).all(),
81 | "Wrong edge list after allocate",
82 | )
83 |
84 | self.assertTrue((key != self.key).all(), "Key not updated at allocate")
85 |
86 | self.assertTrue(
87 | jnp.isclose(
88 | target_dict["vel"],
89 | jnp.array([[0.0, 0.0, 0.0], [0.2, 0.0, 0.0], [0.3, 0.0, 0.0]]),
90 | ).all(),
91 | "Wrong target velocity at allocate",
92 | )
93 |
94 | self.assertTrue(
95 | jnp.isclose(
96 | target_dict["acc"],
97 | jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.2, 0.0, 0.0]]),
98 | ).all(),
99 | "Wrong target acceleration at allocate",
100 | )
101 |
102 | self.assertTrue(
103 | jnp.isclose(
104 | features["vel_hist"],
105 | jnp.array(
106 | [
107 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], # particle 1, two past vels.
108 | [0.2, 0.0, 0.0, 0.2, 0.0, 0.0],
109 | [0.0, 0.0, 0.0, 0.1, 0.0, 0.0],
110 | ]
111 | ),
112 | ).all(),
113 | "Wrong historic velocities at allocate",
114 | )
115 |
116 | most_recent_displacement = jnp.array(
117 | [
118 | [0.0, 0.0, 0.0], # edge 0-0
119 | [0.0, 0.0, 0.0], # edge 1-1
120 | [-0.2, 0.1, 0.0], # edge 2-1
121 | [0.0, 0.0, 0.0], # edge 2-2
122 | [0.2, -0.1, 0.0], # edge 1-2
123 | [0.0, 0.0, 0.0], # edge 3-3
124 | ]
125 | )
126 | r0 = self.metadata["default_connectivity_radius"]
127 | normalized_displ = most_recent_displacement / r0
128 | normalized_dist = ((normalized_displ**2).sum(-1, keepdims=True)) ** 0.5
129 |
130 | self.assertTrue(
131 | jnp.isclose(features["rel_disp"], normalized_displ).all(),
132 | "Wrong relative displacement at allocate",
133 | )
134 | self.assertTrue(
135 | jnp.isclose(features["rel_dist"], normalized_dist).all(),
136 | "Wrong relative distance at allocate",
137 | )
138 |
139 | def test_preprocess_base(self):
140 | # preprocess matches allocate one-to-one, except for the neighbor computation
141 | _, _, _, neighbors_new = self.case.preprocess(
142 | self.key, (self.position_data, self.particle_types), 0.0, self.neighbors, 0
143 | )
144 |
145 | self.assertTrue(
146 | (self.neighbors.idx == neighbors_new.idx).all(),
147 | "Wrong edge list after preprocess",
148 | )
149 |
150 | def test_preprocess_unroll(self):
151 | # test getting the second available target acceleration
152 | _, _, target_dict, _ = self.case.preprocess(
153 | self.key, (self.position_data, self.particle_types), 0.0, self.neighbors, 1
154 | )
155 |
156 | self.assertTrue(
157 | jnp.isclose(
158 | target_dict["acc"],
159 | jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.1, 0.0, 0.0]]),
160 | atol=1e-07,
161 | ).all(),
162 | "Wrong target acceleration at preprocess",
163 | )
164 |
165 | def test_preprocess_noise(self):
166 | # test that both potential targets are corrected with the proper noise
167 | # we choose noise_std=0.01 so that no particle jumps across the periodic boundary
168 | _, features, target_dict, _ = self.case.preprocess(
169 | self.key, (self.position_data, self.particle_types), 0.01, self.neighbors, 0
170 | )
171 | vel_next1 = jnp.array([[0.0, 0.0, 0.0], [0.2, 0.0, 0.0], [0.3, 0.0, 0.0]])
172 | correct_target_acc = vel_next1 - features["vel_hist"][:, 3:6]
173 | self.assertTrue(
174 | jnp.isclose(correct_target_acc, target_dict["acc"], atol=1e-7).all(),
175 | "Wrong target acceleration at preprocess",
176 | )
177 |
178 | # with one push-forward step on top
179 | _, features, target_dict, _ = self.case.preprocess(
180 | self.key, (self.position_data, self.particle_types), 0.01, self.neighbors, 1
181 | )
182 | vel_next2 = jnp.array([[0.0, 0.0, 0.0], [0.2, 0.0, 0.0], [0.4, 0.0, 0.0]])
183 | correct_target_acc = vel_next2 - vel_next1
184 | self.assertTrue(
185 | jnp.isclose(correct_target_acc, target_dict["acc"], atol=1e-7).all(),
186 | "Wrong target acceleration at preprocess with 1 pushforward step",
187 | )
188 |
189 | def test_allocate_eval(self):
190 | pass
191 |
192 | def test_preprocess_eval(self):
193 | pass
194 |
195 | def test_integrate(self):
196 | # given the reference acceleration, compute the next position
197 | correct_acceleration = {
198 | "acc": jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.2, 0.0, 0.0]])
199 | }
200 |
201 | new_pos = self.case.integrate(correct_acceleration, self.position_data[:, :3])
202 |
203 | self.assertTrue(
204 | jnp.isclose(new_pos, self.position_data[:, 3]).all(),
205 | "Wrong new position at integration",
206 | )
207 |
208 |
209 | if __name__ == "__main__":
210 | unittest.main()
211 |
--------------------------------------------------------------------------------
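The expected targets in `test_allocate` can be reproduced by hand from the third particle's x-coordinates using minimum-image (periodic) differences. A quick numeric sanity sketch of the numbers used above (the actual computation happens inside `case.allocate`):

import numpy as np

x = np.array([0.8, 0.8, 0.9, 0.2, 0.6])  # x-coordinate of the third particle over 5 steps
vel = x[1:] - x[:-1]
vel -= np.round(vel)   # minimum image in the periodic unit box -> [0.0, 0.1, 0.3, 0.4]

# With input_seq_length=3, the target velocity is the step 2 -> 3 difference.
target_vel = vel[2]            # 0.3
target_acc = vel[2] - vel[1]   # 0.2
print(target_vel, target_acc)
--------------------------------------------------------------------------------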
/tests/models_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import e3nn_jax as e3nn
4 | import haiku as hk
5 | import jax
6 | import jax.numpy as jnp
7 | import numpy as np
8 |
9 | from lagrangebench import models
10 | from lagrangebench.utils import NodeType
11 |
12 |
13 | class ModelTest(unittest.TestCase):
14 | def dummy_sample(self, vel=None, pos=None):
15 | key = self.key()
16 |
17 | if vel is None:
18 | vel = jax.random.uniform(key, (100, 5 * 3))
19 | if pos is None:
20 | pos = jax.random.uniform(key, (100, 1, 3))
21 |
22 | senders = jax.random.randint(key, (200,), 0, 100)
23 | receivers = jax.random.randint(key, (200,), 0, 100)
24 | rel_disp = (pos[receivers] - pos[senders]).squeeze()
25 | x = {
26 | "vel_hist": vel,
27 | "vel_mag": jnp.sum(vel.reshape(100, -1, 3) ** 2, -1) ** 0.5,
28 | "rel_disp": rel_disp,
29 | "rel_dist": jnp.sum(rel_disp**2, -1, keepdims=True) ** 0.5,
30 | "abs_pos": pos,
31 | "senders": senders,
32 | "receivers": receivers,
33 | }
34 | particle_type = jnp.ones((100, 1), dtype=jnp.int32) * NodeType.FLUID
35 | return x, particle_type
36 |
37 | def key(self):
38 | return jax.random.PRNGKey(0)
39 |
40 | def assert_equivariant(self, f, params, state):
41 | key = self.key()
42 |
43 | vel = e3nn.normal("5x1o", key, (100,))
44 | pos = e3nn.normal("1x1o", key, (100,))
45 |
46 | def wrapper(v, p):
47 | sample, particle_type = self.dummy_sample()
48 | sample.update(
49 | {
50 | "vel_hist": v.array.reshape((100, 5 * 3)),
51 | "abs_pos": p.array.reshape((100, 1, 3)),
52 | }
53 | )
54 | y, _ = f.apply(params, state, (sample, particle_type))
55 | return e3nn.IrrepsArray("1x1o", y["acc"])
56 |
57 | # random rotation matrix
58 | R = -e3nn.rand_matrix(key, ())
59 |
60 | out1 = wrapper(vel.transform_by_matrix(R), pos.transform_by_matrix(R))
61 | out2 = wrapper(vel, pos).transform_by_matrix(R)
62 |
63 | def assert_(x, y):
64 | self.assertTrue(
65 | np.isclose(x, y, atol=1e-5, rtol=1e-5).all(), "Not equivariant!"
66 | )
67 |
68 | jax.tree_util.tree_map(assert_, out1, out2)
69 |
70 | def test_segnn(self):
71 | def segnn(x):
72 | return models.SEGNN(
73 | node_features_irreps="5x1o + 5x0e",
74 | edge_features_irreps="1x1o + 1x0e",
75 | scalar_units=8,
76 | lmax_hidden=1,
77 | lmax_attributes=1,
78 | n_vels=5,
79 | num_mp_steps=1,
80 | output_irreps="1x1o",
81 | )(x)
82 |
83 | segnn = hk.without_apply_rng(hk.transform_with_state(segnn))
84 | x, particle_type = self.dummy_sample()
85 | params, segnn_state = segnn.init(self.key(), (x, particle_type))
86 |
87 | self.assert_equivariant(segnn, params, segnn_state)
88 |
89 | def test_egnn(self):
90 | def egnn(x):
91 | return models.EGNN(
92 | hidden_size=8,
93 | output_size=1,
94 | num_mp_steps=1,
95 | dt=0.01,
96 | n_vels=5,
97 | displacement_fn=lambda x, y: x - y,
98 | shift_fn=lambda x, y: x + y,
99 | )(x)
100 |
101 | egnn = hk.without_apply_rng(hk.transform_with_state(egnn))
102 | x, particle_type = self.dummy_sample()
103 | params, egnn_state = egnn.init(self.key(), (x, particle_type))
104 |
105 | self.assert_equivariant(egnn, params, egnn_state)
106 |
107 | def test_painn(self):
108 | def painn(x):
109 | return models.PaiNN(
110 | hidden_size=8,
111 | output_size=1,
112 | num_mp_steps=1,
113 | radial_basis_fn=models.painn.gaussian_rbf(20, 10, trainable=True),
114 | cutoff_fn=models.painn.cosine_cutoff(10),
115 | n_vels=5,
116 | )(x)
117 |
118 | painn = hk.without_apply_rng(hk.transform_with_state(painn))
119 | x, particle_type = self.dummy_sample()
120 | params, painn_state = painn.init(self.key(), (x, particle_type))
121 |
122 | self.assert_equivariant(painn, params, painn_state)
123 |
124 |
125 | if __name__ == "__main__":
126 | unittest.main()
127 |
--------------------------------------------------------------------------------
/tests/pushforward_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import jax
4 | import numpy as np
5 | from omegaconf import OmegaConf
6 |
7 | from lagrangebench.train.strats import push_forward_sample_steps
8 |
9 |
10 | class TestPushForward(unittest.TestCase):
11 | """Class for unit testing the push-forward functions."""
12 |
13 | def setUp(self):
14 | self.pf = OmegaConf.create(
15 | {
16 | "steps": [-1, 20000, 50000, 100000],
17 | "unrolls": [0, 1, 3, 20],
18 | "probs": [4.05, 4.05, 1.0, 1.0],
19 | }
20 | )
21 |
22 | self.key = jax.random.PRNGKey(42)
23 |
24 | def body_steps(self, step, unrolls, probs):
25 | dump = []
26 | for _ in range(1000):
27 | self.key, unroll_steps = push_forward_sample_steps(self.key, step, self.pf)
28 | dump.append(unroll_steps)
29 |
30 | # Note: np.unique returns a sorted array
31 | unique, counts = np.unique(dump, return_counts=True)
32 | self.assertTrue((unique == unrolls).all(), "Wrong unroll steps")
33 | self.assertTrue(
34 | np.allclose(counts / 1000, probs, atol=0.05),
35 | "Wrong probabilities of unroll steps",
36 | )
37 |
38 | def test_pf_step_1(self):
39 | self.body_steps(1, np.array([0]), np.array([1.0]))
40 |
41 | def test_pf_step_60000(self):
42 | self.body_steps(60000, np.array([0, 1, 3]), np.array([0.45, 0.45, 0.1]))
43 |
44 |
45 | if __name__ == "__main__":
46 | unittest.main()
47 |
--------------------------------------------------------------------------------
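The expected probabilities in `test_pf_step_60000` follow from normalizing the first three weights in `setUp`, since at step 60000 only the unroll steps [0, 1, 3] have been reached. A quick check:

import numpy as np

weights = np.array([4.05, 4.05, 1.0])  # weights of the unroll steps active at step 60000
probs = weights / weights.sum()
print(probs.round(3))  # -> [0.445 0.445 0.11], within atol=0.05 of [0.45, 0.45, 0.1]
--------------------------------------------------------------------------------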
/tests/rollout_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from functools import partial
3 |
4 | import haiku as hk
5 | import jax
6 | import jax.numpy as jnp
7 | import numpy as np
8 | from jax import config as jax_config
9 | from jax import jit, vmap
10 | from jax_sph.jax_md import space
11 | from omegaconf import OmegaConf
12 | from torch.utils.data import DataLoader
13 |
14 | jax_config.update("jax_enable_x64", True)
15 |
16 | from lagrangebench.case_setup import case_builder
17 | from lagrangebench.data import H5Dataset
18 | from lagrangebench.data.utils import get_dataset_stats, numpy_collate
19 | from lagrangebench.evaluate import MetricsComputer
20 | from lagrangebench.evaluate.rollout import _eval_batched_rollout, _forward_eval
21 | from lagrangebench.utils import broadcast_from_batch
22 |
23 |
24 | class TestInferBuilder(unittest.TestCase):
25 | """Class for unit testing the evaluate_single_rollout function."""
26 |
27 | def setUp(self):
28 | self.cfg = OmegaConf.create(
29 | {
30 | "dataset": {
31 | "src": "tests/3D_LJ_3_1214every1", # Lennard-Jones dataset
32 | },
33 | "model": {
34 | "input_seq_length": 3, # two past velocities
35 | "isotropic_norm": False,
36 | },
37 | "eval": {
38 | "train": {"metrics": ["mse"]},
39 | "n_rollout_steps": 100,
40 | },
41 | "train": {"noise_std": 0.0},
42 | }
43 | )
44 |
45 | data_valid = H5Dataset(
46 | split="valid",
47 | dataset_path=self.cfg.dataset.src,
48 | name="lj3d",
49 | input_seq_length=self.cfg.model.input_seq_length,
50 | extra_seq_length=self.cfg.eval.n_rollout_steps,
51 | )
52 | self.loader_valid = DataLoader(
53 | dataset=data_valid, batch_size=1, collate_fn=numpy_collate
54 | )
55 |
56 | self.metadata = data_valid.metadata
57 | self.normalization_stats = get_dataset_stats(
58 | self.metadata, self.cfg.model.isotropic_norm, self.cfg.train.noise_std
59 | )
60 |
61 | bounds = np.array(self.metadata["bounds"])
62 | box = bounds[:, 1] - bounds[:, 0]
63 | self.displacement_fn, self.shift_fn = space.periodic(side=box)
64 |
65 | self.case = case_builder(
66 | box,
67 | self.metadata,
68 | self.cfg.model.input_seq_length,
69 | noise_std=self.cfg.train.noise_std,
70 | )
71 |
72 | self.key = jax.random.PRNGKey(0)
73 |
74 | def test_rollout(self):
75 | isl = self.loader_valid.dataset.input_seq_length
76 |
77 | # get one validation trajectory from the debug dataset
78 | traj_batch_i = next(iter(self.loader_valid))
79 | traj_batch_i = jax.tree_map(lambda x: jnp.array(x), traj_batch_i)
80 | # remove batch dimension
81 | self.assertTrue(traj_batch_i[0].shape[0] == 1, "We test only batch size 1")
82 | traj_i = broadcast_from_batch(traj_batch_i, index=0)
83 | positions = traj_i[0] # (nodes, t, dim) = (3, 405, 3)
84 |
85 | displ_vmap = vmap(self.displacement_fn, (0, 0))
86 | displ_dvmap = vmap(displ_vmap, (0, 0))
87 | vels = displ_dvmap(positions[:, 1:], positions[:, :-1]) # (3, 404, 3)
88 | accs = vels[:, 1:] - vels[:, :-1] # (3, 403, 3)
89 | stats = self.normalization_stats["acceleration"]
90 | accs = (accs - stats["mean"]) / stats["std"]
91 |
92 | class CheatingModel(hk.Module):
93 | def __init__(self, target, start):
94 | super().__init__()
95 | self.target = target
96 | self.start = start
97 |
98 | def __call__(self, x):
99 | i = hk.get_state(
100 | "counter",
101 | shape=[],
102 | dtype=jnp.int32,
103 | init=hk.initializers.Constant(self.start),
104 | )
105 | hk.set_state("counter", i + 1)
106 | return {"acc": self.target[:, i]}
107 |
108 | def setup_model(target, start):
109 | def model(x):
110 | return CheatingModel(target, start)(x)
111 |
112 | model = hk.without_apply_rng(hk.transform_with_state(model))
113 | params, state = model.init(None, None)
114 | model_apply = model.apply
115 | model_apply = jit(model_apply)
116 | return params, state, model_apply
117 |
118 | params, state, model_apply = setup_model(accs, 0)
119 |
120 | # proof that the above "model" works
121 | out, state = model_apply(params, state, None)
122 | pred_acc = stats["mean"] + out["acc"] * stats["std"]
123 | pred_pos = self.shift_fn(positions[:, isl - 1], vels[:, isl - 2] + pred_acc)
124 | pred_pos = jnp.asarray(pred_pos, dtype=jnp.float32)
125 | target_pos = positions[:, isl]
126 |
127 | assert jnp.isclose(pred_pos, target_pos, atol=1e-7).all(), "Wrong setup"
128 |
129 | params, state, model_apply = setup_model(accs, isl - 2)
130 | _, neighbors = self.case.allocate_eval((positions[:, :isl], traj_i[1]))
131 |
132 | metrics_computer = MetricsComputer(
133 | ["mse"],
134 | self.case.displacement,
135 | self.metadata,
136 | isl,
137 | )
138 |
139 | forward_eval = partial(
140 | _forward_eval,
141 | model_apply=model_apply,
142 | case_integrate=self.case.integrate,
143 | )
144 | forward_eval_vmap = vmap(forward_eval, in_axes=(None, None, 0, 0, 0))
145 | preprocess_eval_vmap = vmap(self.case.preprocess_eval, in_axes=(0, 0))
146 | metrics_computer_vmap = vmap(metrics_computer, in_axes=(0, 0))
147 |
148 | for n_extrap_steps in [0, 5, 10]:
149 | with self.subTest(n_extrap_steps):
150 | example_rollout_batch, metrics_batch, neighbors = _eval_batched_rollout(
151 | forward_eval_vmap=forward_eval_vmap,
152 | preprocess_eval_vmap=preprocess_eval_vmap,
153 | case=self.case,
154 | params=params,
155 | state=state,
156 | traj_batch_i=traj_batch_i,
157 | neighbors=neighbors,
158 | metrics_computer_vmap=metrics_computer_vmap,
159 | n_rollout_steps=self.cfg.eval.n_rollout_steps,
160 | n_extrap_steps=n_extrap_steps,
161 | t_window=isl,
162 | )
163 | example_rollout = broadcast_from_batch(example_rollout_batch, index=0)
164 | metrics = broadcast_from_batch(metrics_batch, index=0)
165 |
166 | self.assertTrue(
167 | jnp.isclose(
168 | metrics["mse"].mean(),
169 | jnp.array(0.0),
170 | atol=1e-6,
171 | ).all(),
172 | "Wrong rollout mse",
173 | )
174 |
175 | pos_input = traj_i[0].transpose(1, 0, 2) # (t, nodes, dim)
176 | initial_positions = pos_input[:isl]
177 | example_full = np.concatenate(
178 | [initial_positions, example_rollout], axis=0
179 | )
180 | rollout_dict = {
181 | "predicted_rollout": example_full, # (t, nodes, dim)
182 | "ground_truth_rollout": pos_input, # (t, nodes, dim)
183 | }
184 |
185 | self.assertTrue(
186 | jnp.isclose(
187 | rollout_dict["predicted_rollout"][100, 0],
188 | rollout_dict["ground_truth_rollout"][100, 0],
189 | atol=1e-6,
190 | ).all(),
191 | "Wrong rollout prediction",
192 | )
193 |
194 | total_steps = self.cfg.eval.n_rollout_steps + n_extrap_steps
195 | assert example_rollout_batch.shape[1] == total_steps
196 |
197 |
198 | if __name__ == "__main__":
199 | unittest.main()
200 |
--------------------------------------------------------------------------------
/tests/runner_test.py:
--------------------------------------------------------------------------------
1 | """Runner test with a linear model and LJ dataset."""
2 |
3 | import unittest
4 |
5 | from omegaconf import OmegaConf
6 |
7 | from lagrangebench.defaults import defaults
8 | from lagrangebench.runner import train_or_infer
9 |
10 |
11 | class TestRunner(unittest.TestCase):
12 | """Test whether train_or_infer runs through."""
13 |
14 | def setUp(self):
15 | self.cfg = OmegaConf.create(
16 | {
17 | "mode": "all",
18 | "dataset": {
19 | "src": "tests/3D_LJ_3_1214every1",
20 | },
21 | "model": {
22 | "name": "linear",
23 | "input_seq_length": 3,
24 | },
25 | "train": {
26 | "step_max": 10,
27 | "noise_std": 0.0,
28 | },
29 | "eval": {
30 | "n_rollout_steps": 5,
31 | "train": {
32 | "n_trajs": 2,
33 | "metrics_stride": 5,
34 | "metrics": ["mse"],
35 | "out_type": "none",
36 | },
37 | "infer": {
38 | "n_trajs": 2,
39 | "metrics_stride": 1,
40 | "metrics": ["mse"],
41 | "out_type": "none",
42 | },
43 | },
44 | "logging": {
45 | "log_steps": 1,
46 | "eval_steps": 5,
47 | "wandb": False,
48 | "ckp_dir": "/tmp/ckp",
49 | },
50 | }
51 | )
52 | # overwrite defaults with user-defined config
53 | self.cfg = OmegaConf.merge(defaults, self.cfg)
54 |
55 | def test_runner(self):
56 | out = train_or_infer(self.cfg)
57 | self.assertEqual(out, 0)
58 |
59 |
60 | if __name__ == "__main__":
61 | unittest.main()
62 |
--------------------------------------------------------------------------------