├── .all-contributorsrc
├── .bumpversion.cfg
├── .github
│   └── workflows
│       ├── release.yaml
│       └── workflows.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── Dockerfile
├── LICENSE
├── MANIFEST.in
├── README.md
├── environment_cpu.yml
├── environment_cuda.yml
├── graph_weather
│   ├── __init__.py
│   ├── data
│   │   ├── IFSAnalysis_dataloader.py
│   │   ├── __init__.py
│   │   ├── const.py
│   │   ├── dataloader.py
│   │   ├── gencast_dataloader.py
│   │   ├── nnja_ai.py
│   │   └── weather_station_reader.py
│   └── models
│       ├── __init__.py
│       ├── analysis.py
│       ├── aurora
│       │   ├── __init__.py
│       │   ├── decoder.py
│       │   ├── encoder.py
│       │   ├── model.py
│       │   └── processor.py
│       ├── fengwu_ghr
│       │   ├── __init__.py
│       │   └── layers.py
│       ├── forecast.py
│       ├── gencast
│       │   ├── README.md
│       │   ├── __init__.py
│       │   ├── denoiser.py
│       │   ├── graph
│       │   │   ├── __init__.py
│       │   │   ├── graph_builder.py
│       │   │   ├── grid_mesh_connectivity.py
│       │   │   ├── icosahedral_mesh.py
│       │   │   └── model_utils.py
│       │   ├── images
│       │   │   ├── animated.gif
│       │   │   ├── autoregressive.gif
│       │   │   ├── fullmodel.png
│       │   │   └── readme.md
│       │   ├── layers
│       │   │   ├── __init__.py
│       │   │   ├── decoder.py
│       │   │   ├── encoder.py
│       │   │   ├── experimental
│       │   │   │   ├── __init__.py
│       │   │   │   └── sparse_transformer.py
│       │   │   ├── modules.py
│       │   │   └── processor.py
│       │   ├── sampler.py
│       │   ├── train.py
│       │   ├── utils
│       │   │   ├── __init__.py
│       │   │   ├── batching.py
│       │   │   ├── noise.py
│       │   │   └── statistics.py
│       │   └── weighted_mse_loss.py
│       ├── layers
│       │   ├── __init__.py
│       │   ├── assimilator_decoder.py
│       │   ├── assimilator_encoder.py
│       │   ├── constraint_layer.py
│       │   ├── decoder.py
│       │   ├── encoder.py
│       │   ├── graph_net_block.py
│       │   ├── grid_to_points.py
│       │   ├── points_to_grid.py
│       │   └── processor.py
│       ├── losses.py
│       └── weathermesh
│           ├── __init__.py
│           ├── decoder.py
│           ├── encoder.py
│           ├── layers.py
│           ├── processor.py
│           └── weathermesh2.py
├── pyproject.toml
├── setup.py
├── tests
│   ├── test_aurora.py
│   ├── test_gencast.py
│   ├── test_model.py
│   ├── test_nnjai.py
│   ├── test_weather_station_reader.py
│   └── test_weathermesh.py
└── train
    ├── deepspeed_graph.py
    ├── era5.py
    ├── hf_forecasts.json
    ├── lora.py
    ├── pl_graph_weather.py
    ├── run.py
    └── run_fulll.py
/.all-contributorsrc:
--------------------------------------------------------------------------------
1 | {
2 | "files": [
3 | "README.md"
4 | ],
5 | "imageSize": 100,
6 | "commit": false,
7 | "commitConvention": "angular",
8 | "contributors": [
9 | {
10 | "login": "jacobbieker",
11 | "name": "Jacob Bieker",
12 | "avatar_url": "https://avatars.githubusercontent.com/u/7170359?v=4",
13 | "profile": "https://www.jacobbieker.com",
14 | "contributions": [
15 | "code"
16 | ]
17 | },
18 | {
19 | "login": "JackKelly",
20 | "name": "Jack Kelly",
21 | "avatar_url": "https://avatars.githubusercontent.com/u/460756?v=4",
22 | "profile": "http://jack-kelly.com",
23 | "contributions": [
24 | "ideas"
25 | ]
26 | },
27 | {
28 | "login": "byphilipp",
29 | "name": "byphilipp",
30 | "avatar_url": "https://avatars.githubusercontent.com/u/59995258?v=4",
31 | "profile": "https://github.com/byphilipp",
32 | "contributions": [
33 | "ideas"
34 | ]
35 | },
36 | {
37 | "login": "paapu88",
38 | "name": "Markus Kaukonen",
39 | "avatar_url": "https://avatars.githubusercontent.com/u/6195764?v=4",
40 | "profile": "http://iki.fi/markus.kaukonen",
41 | "contributions": [
42 | "question"
43 | ]
44 | },
45 | {
46 | "login": "MoHawastaken",
47 | "name": "MoHawastaken",
48 | "avatar_url": "https://avatars.githubusercontent.com/u/55447473?v=4",
49 | "profile": "https://github.com/MoHawastaken",
50 | "contributions": [
51 | "bug"
52 | ]
53 | },
54 | {
55 | "login": "mishooax",
56 | "name": "Mihai",
57 | "avatar_url": "https://avatars.githubusercontent.com/u/47196359?v=4",
58 | "profile": "http://www.ecmwf.int",
59 | "contributions": [
60 | "question"
61 | ]
62 | },
63 | {
64 | "login": "vitusbenson",
65 | "name": "Vitus Benson",
66 | "avatar_url": "https://avatars.githubusercontent.com/u/33334860?v=4",
67 | "profile": "https://github.com/vitusbenson",
68 | "contributions": [
69 | "bug"
70 | ]
71 | },
72 | {
73 | "login": "dongZheX",
74 | "name": "dongZheX",
75 | "avatar_url": "https://avatars.githubusercontent.com/u/36361726?v=4",
76 | "profile": "https://github.com/dongZheX",
77 | "contributions": [
78 | "question"
79 | ]
80 | },
81 | {
82 | "login": "sabbir2331",
83 | "name": "sabbir2331",
84 | "avatar_url": "https://avatars.githubusercontent.com/u/25061297?v=4",
85 | "profile": "https://github.com/sabbir2331",
86 | "contributions": [
87 | "question"
88 | ]
89 | },
90 | {
91 | "login": "rnwzd",
92 | "name": "Lorenzo Breschi",
93 | "avatar_url": "https://avatars.githubusercontent.com/u/58804597?v=4",
94 | "profile": "https://github.com/rnwzd",
95 | "contributions": [
96 | "code"
97 | ]
98 | },
99 | {
100 | "login": "gbruno16",
101 | "name": "gbruno16",
102 | "avatar_url": "https://avatars.githubusercontent.com/u/72879691?v=4",
103 | "profile": "https://github.com/gbruno16",
104 | "contributions": [
105 | "code"
106 | ]
107 | }
108 | ],
109 | "contributorsPerLine": 7,
110 | "projectName": "graph_weather",
111 | "projectOwner": "openclimatefix",
112 | "repoType": "github",
113 | "repoHost": "https://github.com",
114 | "skipCi": true,
115 | "commitType": "docs"
116 | }
117 |
--------------------------------------------------------------------------------
/.bumpversion.cfg:
--------------------------------------------------------------------------------
1 | [bumpversion]
2 | commit = True
3 | tag = False
4 | current_version = 1.0.108
5 | message = Bump version: {current_version} → {new_version} [skip ci]
6 |
7 | [bumpversion:file:setup.py]
8 | search = version="{current_version}"
9 | replace = version="{new_version}"
10 |
--------------------------------------------------------------------------------
/.github/workflows/release.yaml:
--------------------------------------------------------------------------------
1 | name: Bump version and auto-release
2 | on:
3 | push:
4 | branches:
5 | - main
6 |
7 | jobs:
8 | bump-version-python-docker-release:
9 | uses: openclimatefix/.github/.github/workflows/python-docker-release.yml@main
10 | secrets:
11 | PAT_TOKEN: ${{ secrets.PAT_TOKEN }}
12 | token: ${{ secrets.PYPI_API_TOKEN }}
13 | DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
14 | DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
15 | with:
16 | image_base_name: graph_weather
17 | docker_file: Dockerfile
18 |
--------------------------------------------------------------------------------
/.github/workflows/workflows.yaml:
--------------------------------------------------------------------------------
1 | name: Python package
2 |
3 | on:
4 | push:
5 | pull_request:
6 | types: [opened, reopened]
7 | jobs:
8 | pytest:
9 | runs-on: ${{ matrix.os }}
10 |
11 | strategy:
12 | fail-fast: true
13 | matrix:
14 | os: [ubuntu-latest, macos-latest]
15 | python-version: ["3.11", "3.12"]
16 | torch-version: [2.4.0]
17 | include:
18 | - torch-version: 2.4.0
19 | torchvision-version: 0.19.0
20 | steps:
21 | - uses: actions/checkout@v2
22 | - uses: prefix-dev/setup-pixi@v0.8.3
23 | with:
24 | pixi-version: v0.41.4
25 | cache: true
26 | - run: pixi run test
27 | - name: Set up Python ${{ matrix.python-version }}
28 | uses: actions/setup-python@v2
29 | with:
30 | python-version: ${{ matrix.python-version }}
31 |
32 | - name: Install PyTorch ${{ matrix.torch-version }}+cpu
33 | run: |
34 | pip install torch==${{ matrix.torch-version}} torchvision==${{ matrix.torchvision-version}} --index-url https://download.pytorch.org/whl/cpu
35 | - name: Install internal dependencies
36 | run: |
37 | pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${{ matrix.torch-version}}+cpu.html
38 | if [ ${{ matrix.torch-version}} == 2.4.0 ]; then pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/repo.html; fi
39 | - name: Install main package
40 | run: |
41 | pip install -e .
42 | pip install pytest-xdist
43 | - name: Setup with pytest-xdist
44 | run: |
45 | # lets get the string for how many cpus to use with pytest
46 | echo "Will be using ${{ inputs.pytest_numcpus }} cpus for pytest testing"
47 | #
48 | # make PYTESTXDIST
49 | export PYTESTXDIST="-n 2"
50 | if [ 2 -gt 0 ]; then export PYTESTXDIST="$PYTESTXDIST --dist=loadfile"; fi
51 | #
52 | # echo results and save env var for other jobs
53 | echo "pytest-xdist options that will be used are: $PYTESTXDIST"
54 | echo "PYTESTXDIST=$PYTESTXDIST" >> $GITHUB_ENV
55 | - name: Setup with pytest-cov
56 | run: |
57 | # let make pytest run with coverage
58 | echo "Will be looking at coverage of dir graph_weather"
59 | #
60 | # install pytest-cov
61 | pip install coverage==7.4.3
62 | pip install pytest-cov
63 | #
64 | # make PYTESTCOV
65 | export PYTESTCOV="--cov=graph_weather tests/ --cov-report=xml"
66 | # echo results and save env var for other jobs
67 | echo "pytest-cov options that will be used are: $PYTESTCOV"
68 | echo "PYTESTCOV=$PYTESTCOV" >> $GITHUB_ENV
69 | - name: Run pytest
70 | run: |
71 | # import dgl to initialize backend
72 | if [ ${{ matrix.torch-version}} == 2.4.0 ]; then python3 -c "import dgl"; fi
73 | export PYTEST_COMMAND="pytest $PYTESTCOV $PYTESTXDIST -s"
74 | echo "Will be running this command: $PYTEST_COMMAND"
75 | eval $PYTEST_COMMAND
76 | - name: "Upload coverage to Codecov"
77 | uses: codecov/codecov-action@v2
78 | with:
79 | fail_ci_if_error: false
80 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | *.nc
3 | *.pt
4 | *.pyc
5 | .DS_Store
6 | *.txt
7 | # pixi environments
8 | .pixi
9 | .vscode/
10 | checkpoints/
11 | lightning_logs/
12 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 |
4 | repos:
5 | - repo: https://github.com/pre-commit/pre-commit-hooks
6 | rev: v5.0.0
7 | hooks:
8 | # list of supported hooks: https://pre-commit.com/hooks.html
9 | - id: trailing-whitespace
10 | - id: end-of-file-fixer
11 | - id: debug-statements
12 | - id: detect-private-key
13 |
14 | # python code formatting/linting
15 | - repo: https://github.com/astral-sh/ruff-pre-commit
16 | # Ruff version.
17 | rev: "v0.9.2"
18 | hooks:
19 | - id: ruff
20 | args: [--fix]
21 | - repo: https://github.com/psf/black
22 | rev: 24.10.0
23 | hooks:
24 | - id: black
25 | args: [--line-length, "100"]
26 | # yaml formatting
27 | - repo: https://github.com/pre-commit/mirrors-prettier
28 | rev: v4.0.0-alpha.8
29 | hooks:
30 | - id: prettier
31 | types: [yaml]
32 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM ubuntu:latest
2 |
3 | ENV CONDA_ENV_NAME=graph
4 | ENV PYTHON_VERSION=3.12
5 |
6 | # Basic setup
7 | RUN apt update && apt install -y bash \
8 | build-essential \
9 | git \
10 | curl \
11 | ca-certificates \
12 | wget \
13 | libaio-dev \
14 | && rm -rf /var/lib/apt/lists
15 |
16 | # Install Miniconda and create main env
17 | ADD https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh miniconda3.sh
18 | RUN /bin/bash miniconda3.sh -b -p /conda \
19 | && echo export PATH=/conda/bin:$PATH >> .bashrc \
20 | && rm miniconda3.sh
21 | ENV PATH="/conda/bin:${PATH}"
22 |
23 | RUN git clone https://github.com/openclimatefix/graph_weather.git && mv graph_weather/ gw/ && cd gw/ && mv * .. && rm -rf gw/
24 |
25 | # Copy the appropriate environment file based on CUDA availability
26 | COPY environment_cpu.yml /tmp/environment_cpu.yml
27 | COPY environment_cuda.yml /tmp/environment_cuda.yml
28 |
29 | RUN conda update -n base -c defaults conda
30 |
31 | # Check if CUDA is available and accordingly choose env
32 | RUN cuda=$(command -v nvcc > /dev/null && echo "true" || echo "false") \
33 | && if [ "$cuda" = "true" ]; then conda env create -f /tmp/environment_cuda.yml; else conda env create -f /tmp/environment_cpu.yml; fi
34 |
35 | # Switch to bash shell
36 | SHELL ["/bin/bash", "-c"]
37 |
38 | # Set ${CONDA_ENV_NAME} to default virtual environment
39 | RUN echo "source activate ${CONDA_ENV_NAME}" >> ~/.bashrc
40 |
41 | # Cp in the development directory and install
42 | RUN source activate ${CONDA_ENV_NAME} && pip install -e .
43 |
44 |
45 | # Make RUN commands use the new environment:
46 | SHELL ["conda", "run", "-n", "graph", "/bin/bash", "-c"]
47 |
48 | # Example command that can be used, need to set API_KEY, API_SECRET and SAVE_DIR
49 | CMD ["conda", "run", "-n", "graph", "python", "-u", "train/pl_graph_weather.py", "--gpus", "16", "--hidden", "64", "--num-blocks", "3", "--batch", "16"]
50 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Open Climate Fix
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include *.txt
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph Weather
2 |
3 | [![All Contributors](https://img.shields.io/badge/all_contributors-11-orange.svg?style=flat-square)](#contributors-)
4 |
5 | Implementation of the Graph Weather paper (https://arxiv.org/pdf/2202.07575.pdf) in PyTorch. It also includes an implementation
6 | of a modified model that assimilates raw or processed observations into analysis files.
7 |
8 |
9 | ## Installation
10 |
11 | This library can be installed with
12 |
13 | ```bash
14 | pip install graph-weather
15 | ```
16 |
17 | Alternatively, you can install the latest version from the repository easily with `pixi`:
18 |
19 | ```bash
20 | pixi install # `-e cuda` for GPU support, `-e cpu` for CPU-only
21 | ```
22 |
23 | ## Example Usage
24 |
25 | The models generate the graphs internally, so the only thing that needs to be passed to the model is the node features
26 | in the same order as the `lat_lons`.
27 |
28 | ```python
29 | import torch
30 | from graph_weather import GraphWeatherForecaster
31 | from graph_weather.models.losses import NormalizedMSELoss
32 |
33 | lat_lons = []
34 | for lat in range(-90, 90, 1):
35 | for lon in range(0, 360, 1):
36 | lat_lons.append((lat, lon))
37 | model = GraphWeatherForecaster(lat_lons)
38 |
39 | # Generate 78 random features + 24 non-NWP features (i.e. landsea mask)
40 | features = torch.randn((2, len(lat_lons), 102))
41 |
42 | target = torch.randn((2, len(lat_lons), 78))
43 | out = model(features)
44 |
45 | criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
46 | loss = criterion(out, target)
47 | loss.backward()
48 | ```
49 |
50 | The assimilation model assumes each lat/lon point also has a height above ground, and that each observation
51 | is a single value plus its relative time. The assimilation model also assumes the desired output grid is given to it as
52 | well.
53 |
54 | ```python
55 | import torch
56 | import numpy as np
57 | from graph_weather import GraphWeatherAssimilator
58 | from graph_weather.models.losses import NormalizedMSELoss
59 |
60 | obs_lat_lons = []
61 | for lat in range(-90, 90, 7):
62 | for lon in range(0, 180, 6):
63 | obs_lat_lons.append((lat, lon, np.random.random(1)))
64 | for lon in 360 * np.random.random(100):
65 | obs_lat_lons.append((lat, lon, np.random.random(1)))
66 |
67 | output_lat_lons = []
68 | for lat in range(-90, 90, 5):
69 | for lon in range(0, 360, 5):
70 | output_lat_lons.append((lat, lon))
71 | model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24)
72 |
73 | features = torch.randn((1, len(obs_lat_lons), 2))
74 | lat_lon_heights = torch.tensor(obs_lat_lons)
75 | out = model(features, lat_lon_heights)
76 | assert not torch.isnan(out).all()
77 | assert out.size() == (1, len(output_lat_lons), 24)
78 |
79 | criterion = torch.nn.MSELoss()
80 | loss = criterion(out, torch.randn((1, len(output_lat_lons), 24)))
81 | loss.backward()
82 | ```
83 |
84 | ## Pretrained Weights
85 | Coming soon! We plan to train a model on GFS 0.25 degree operational forecasts, as well as MetOffice NWP forecasts.
86 | We also plan to try out adaptive meshes, and to predict future satellite imagery as well.
87 |
88 | ## Training Data
89 | Training data will be available through HuggingFace Datasets for the GFS forecasts. The initial set of data is available for [GFSv16 forecasts, raw observations, and FNL Analysis files from 2016 to 2022](https://huggingface.co/datasets/openclimatefix/gfs-reforecast), and for [ERA5 Reanalysis](https://huggingface.co/datasets/openclimatefix/era5-reanalysis). We cannot redistribute the MetOffice NWP forecasts,
90 | but they can be accessed through [CEDA](https://data.ceda.ac.uk/).
91 |
92 | ## Contributors ✨
93 |
94 | Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)):
95 |
96 |
97 | <!-- all-contributors table (HTML markup) omitted; see .all-contributorsrc for the contributor list -->
98 |
124 | This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome!
125 |
--------------------------------------------------------------------------------
/environment_cpu.yml:
--------------------------------------------------------------------------------
1 | name: graph
2 | channels:
3 | - pytorch
4 | - pyg
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - numcodecs
9 | - pandas
10 | - pip
11 | - pyg
12 | - python
13 | - pytorch
14 | - cpuonly
15 | - pytorch-cluster
16 | - pytorch-scatter
17 | - pytorch-sparse
18 | - pytorch-spline-conv
19 | - scikit-learn
20 | - scipy
21 | - torchvision
22 | - tqdm
23 | - xarray
24 | - zarr
25 | - h3-py
26 | - numpy
27 | - pyshtools
28 | - gcsfs
29 | - pytest
30 | - pip:
31 | - setuptools
32 | - datasets
33 | - einops
34 | - fsspec
35 | - torch-geometric-temporal
36 | - huggingface-hub
37 | - pysolar
38 | - pytorch-lightning
39 | - click
40 | - trimesh
41 | - rtree
42 | - torch-harmonics
43 |
--------------------------------------------------------------------------------
/environment_cuda.yml:
--------------------------------------------------------------------------------
1 | name: graph
2 | channels:
3 | - pytorch
4 | - pyg
5 | - nvidia
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - pytorch-cuda
10 | - numcodecs
11 | - pandas
12 | - pip
13 | - pyg
14 | - python=3.12
15 | - pytorch
16 | - pytorch-cluster
17 | - pytorch-scatter
18 | - pytorch-sparse
19 | - pytorch-spline-conv
20 | - scikit-learn
21 | - scipy
22 | - torchvision
23 | - tqdm
24 | - xarray
25 | - zarr
26 | - h3-py
27 | - numpy
28 | - pyshtools
29 | - gcsfs
30 | - pytest
31 | - pip:
32 | - setuptools
33 | - datasets
34 | - einops
35 | - fsspec
36 | - torch-geometric-temporal
37 | - huggingface-hub
38 | - pysolar
39 | - pytorch-lightning
40 | - click
41 | - trimesh
42 | - rtree
43 | - torch-harmonics
44 |
--------------------------------------------------------------------------------
/graph_weather/__init__.py:
--------------------------------------------------------------------------------
1 | """Main import for the complete models"""
2 |
3 | from .data.nnja_ai import SensorDataset, collate_fn
4 | from .data.weather_station_reader import WeatherStationReader
5 | from .models.analysis import GraphWeatherAssimilator
6 | from .models.forecast import GraphWeatherForecaster
7 |
--------------------------------------------------------------------------------
/graph_weather/data/IFSAnalysis_dataloader.py:
--------------------------------------------------------------------------------
1 | """
2 | The dataloader for IFS analysis.
3 | """
4 |
5 | import numpy as np
6 | import torchvision.transforms as transforms
7 | import xarray as xr
8 | from torch.utils.data import Dataset
9 |
10 | IFS_MEAN = {
11 | "geopotential": 78054.78,
12 | "specific_humidity": 0.0018220816,
13 | "temperature": 243.41727,
14 | "u_component_of_wind": 7.3073797,
15 | "v_component_of_wind": 0.032221083,
16 | "vertical_velocity": 0.0058287205,
17 | }
18 |
19 | IFS_STD = {
20 | "geopotential": 59538.875,
21 | "specific_humidity": 0.0035489395,
22 | "temperature": 29.211119,
23 | "u_component_of_wind": 13.777036,
24 | "v_component_of_wind": 8.867598,
25 | "vertical_velocity": 0.08577341,
26 | }
27 |
28 |
29 | class IFSAnalisysDataset(Dataset):
30 | """
31 | Dataset for IFSAnalysis.
32 |
33 | Args:
34 | filepath: path of the dataset.
35 | features: list of features.
36 | start_year: initial year. Defaults to 2016.
37 | end_year: ending year. Defaults to 2022.
38 | """
39 |
40 | def __init__(self, filepath: str, features: list, start_year: int = 2016, end_year: int = 2022):
41 | """
42 | Initialize the dataset object.
43 | """
44 |
45 | super().__init__()
46 | assert (
47 | start_year <= end_year
48 | ), f"start_year ({start_year}) cannot be greater than end_year ({end_year})."
49 | assert 2016 <= start_year <= 2022, "Time data ranges from 2016 to 2022"
50 | assert 2016 <= end_year <= 2022, "Time data ranges from 2016 to 2022"
51 | self.data = xr.open_zarr(filepath)
52 | self.data = self.data.sel(
53 | time=slice(str(start_year), str(end_year))
54 | ) # Filter data by start and end years
55 |
56 | self.NWP_features = features
57 |
58 | def __len__(self):
59 | return len(self.data["time"]) - 1  # each item pairs timestep idx with idx + 1
60 |
61 | def __getitem__(self, idx):
62 | start = self.data.isel(time=idx)
63 | end = self.data.isel(time=idx + 1)
64 |
65 | # Extract NWP features
66 | input_data = self._nwp_features_extraction(start)
67 | output_data = self._nwp_features_extraction(end)
68 |
69 | return (
70 | transforms.ToTensor()(input_data).view(-1, input_data.shape[-1]),
71 | transforms.ToTensor()(output_data).view(-1, output_data.shape[-1]),
72 | )
73 |
74 | def _nwp_features_extraction(self, data):
75 | data_cube = np.stack(
76 | [
77 | (data[f"{var}"].values - IFS_MEAN[f"{var}"]) / (IFS_STD[f"{var}"] + 1e-6)
78 | for var in self.NWP_features
79 | ],
80 | axis=-1,
81 | ).astype(np.float32)
82 |
83 | num_layers, num_lat, num_lon, num_vars = data_cube.shape
84 | data_cube = data_cube.reshape(num_lat, num_lon, num_vars * num_layers)
85 |
86 | assert not np.isnan(data_cube).any()
87 | return data_cube
88 |
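89 | # Usage sketch (hedged): the Zarr path and feature names below are illustrative
90 | # assumptions, not canonical values from this repository.
91 | # dataset = IFSAnalisysDataset(
92 | #     filepath="path/to/ifs_analysis.zarr",
93 | #     features=["temperature", "geopotential"],
94 | #     start_year=2016,
95 | #     end_year=2017,
96 | # )
97 | # x, y = dataset[0]  # consecutive timesteps, normalized and flattened to (points, channels)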
--------------------------------------------------------------------------------
/graph_weather/data/__init__.py:
--------------------------------------------------------------------------------
1 | """Dataloaders and data processing utilities"""
2 |
3 | from .nnja_ai import SensorDataset, collate_fn
4 | from .weather_station_reader import WeatherStationReader
5 |
--------------------------------------------------------------------------------
/graph_weather/data/dataloader.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | The dataloader has to do a few things for the model to work correctly
4 |
5 | 1. Load the land-sea mask and orography dataset, regridded from 0.1 degrees to the
6 | correct resolution
7 | 2. Calculate the top-of-atmosphere solar radiation for each location at the
8 | current time and at hourly offsets
9 | within +/- 12 hours
10 | 3. Add day-of-year, sin(lat), cos(lat), sin(lon), cos(lon) as well
11 | 4. Batch the data, either in geometric batches or more conventionally
12 | 5. Rescale between 0 and 1, but don't normalize
13 |
14 | """
15 |
16 | import const
17 | import numpy as np
18 | import pandas as pd
19 | import xarray as xr
20 | from pysolar.util import extraterrestrial_irrad
21 | from torch.utils.data import Dataset
22 |
23 |
24 | class AnalysisDataset(Dataset):
25 | """
26 | Dataset class for analysis data.
27 |
28 | Args:
29 | filepaths: List of file paths.
30 | invariant_path: Path to the invariant file.
31 | mean: Mean value.
32 | std: Standard deviation value.
33 | coarsen : Coarsening factor. Defaults to 8.
34 |
35 | Methods:
36 | __init__: Initialize the AnalysisDataset object.
37 | __len__: Get the length of the dataset.
38 | __getitem__: Get an item from the dataset.
39 | """
40 |
41 | def __init__(self, filepaths, invariant_path, mean, std, coarsen: int = 8):
42 | """
43 | Initialize the AnalysisDataset object.
44 | """
45 | super().__init__()
46 | self.filepaths = sorted(filepaths)
47 | self.invariant_path = invariant_path
48 | self.coarsen = coarsen
49 | self.mean = mean
50 | self.std = std
51 |
52 | def __len__(self):
53 | return len(self.filepaths) - 1
54 |
55 | def __getitem__(self, item):
56 | if self.coarsen <= 1: # Don't coarsen, so don't even call it
57 | start = xr.open_zarr(self.filepaths[item], consolidated=True)
58 | end = xr.open_zarr(self.filepaths[item + 1], consolidated=True)
59 | else:
60 | start = (
61 | xr.open_zarr(self.filepaths[item], consolidated=True)
62 | .coarsen(latitude=self.coarsen, boundary="pad")
63 | .mean()
64 | .coarsen(longitude=self.coarsen)
65 | .mean()
66 | )
67 | end = (
68 | xr.open_zarr(self.filepaths[item + 1], consolidated=True)
69 | .coarsen(latitude=self.coarsen, boundary="pad")
70 | .mean()
71 | .coarsen(longitude=self.coarsen)
72 | .mean()
73 | )
74 |
75 | # Land-sea mask data, resampled to the same as the physical variables
76 | landsea = (
77 | xr.open_zarr(self.invariant_path, consolidated=True)
78 | .interp(latitude=start.latitude.values)
79 | .interp(longitude=start.longitude.values)
80 | )
81 | # Calculate sin,cos, day of year, solar irradiance here before stacking
82 | landsea = np.stack(
83 | [
84 | (landsea[f"{var}"].values - const.LANDSEA_MEAN[var]) / const.LANDSEA_STD[var]
85 | for var in landsea.data_vars
86 | if not np.isnan(landsea[f"{var}"].values).any()
87 | ],
88 | axis=-1,
89 | )
90 | landsea = landsea.T.reshape((-1, landsea.shape[-1]))
91 | lat_lons = np.array(np.meshgrid(start.latitude.values, start.longitude.values)).T.reshape(
92 | (-1, 2)
93 | )
94 | sin_lat_lons = np.sin(lat_lons)
95 | cos_lat_lons = np.cos(lat_lons)
96 | date = start.time.values
97 | day_of_year = start.time.dt.dayofyear.values / 365.0
98 | # TODO: include sin/cos of day-of-year in the inputs (see the TODO below);
99 | # previously computed here but never used
100 | solar_times = [np.array([extraterrestrial_irrad(date, lat, lon) for lat, lon in lat_lons])]
101 | for when in pd.date_range(
102 | date - pd.Timedelta("12 hours"), date + pd.Timedelta("12 hours"), freq="1H"
103 | ):
104 | solar_times.append(
105 | np.array([extraterrestrial_irrad(when, lat, lon) for lat, lon in lat_lons])
106 | )
107 | solar_times = np.array(solar_times)
108 |
109 | # End time solar radiation too
110 | end_date = end.time.values
111 | end_solar_times = [
112 | np.array([extraterrestrial_irrad(end_date, lat, lon) for lat, lon in lat_lons])
113 | ]
114 | for when in pd.date_range(
115 | end_date - pd.Timedelta("12 hours"), end_date + pd.Timedelta("12 hours"), freq="1H"
116 | ):
117 | end_solar_times.append(
118 | np.array([extraterrestrial_irrad(when, lat, lon) for lat, lon in lat_lons])
119 | )
120 | end_solar_times = np.array(end_solar_times)
121 |
122 | # Normalize to between -1 and 1
123 | solar_times -= const.SOLAR_MEAN
124 | solar_times /= const.SOLAR_STD
125 | end_solar_times -= const.SOLAR_MEAN
126 | end_solar_times /= const.SOLAR_STD
127 |
128 | # Stack the data into a large data cube
129 | input_data = np.stack(
130 | [
131 | start[f"{var}"].values
132 | for var in start.data_vars
133 | if not np.isnan(start[f"{var}"].values).any()
134 | ],
135 | axis=-1,
136 | )
137 | # TODO Combine with above? And include sin/cos of day of year
138 | input_data = np.concatenate(
139 | [
140 | input_data.T.reshape((-1, input_data.shape[-1])),
141 | sin_lat_lons,
142 | cos_lat_lons,
143 | solar_times,
144 | landsea,
145 | ],
146 | axis=-1,
147 | )
148 | # Not want to predict non-physics variables -> Output only the data variables?
149 | # Would be simpler, and just add in the new ones each time
150 |
151 | output_data = np.stack(
152 | [
153 | end[f"{var}"].values
154 | for var in end.data_vars
155 | if not np.isnan(end[f"{var}"].values).any()
156 | ],
157 | axis=-1,
158 | )
159 |
160 | output_data = np.concatenate(
161 | [
162 | output_data.T.reshape((-1, output_data.shape[-1])),
163 | sin_lat_lons,
164 | cos_lat_lons,
165 | end_solar_times,
166 | landsea,
167 | ],
168 | axis=-1,
169 | )
170 | # Stick with Numpy, don't tensor it, as just going from 0 to 1
171 |
172 | # Normalize now
173 | return input_data, output_data
174 |
175 |
176 | if __name__ == "__main__":
177 |     # Exploratory scratch code: inspect a raw observation Zarr and an analysis Zarr
178 |     # (hard-coded local paths). Guarded so importing this module has no side effects.
179 |     obs_data = xr.open_zarr(
180 |         "/home/jacob/Development/prepbufr.gdas.20160101.t00z.nr.48h.raw.zarr",
181 |         consolidated=True,
182 |     )
183 |
184 |     # TODO Embedding? These should stay consistent across all of the inputs, so can
185 |     # just load the values, not the strings?
186 |     # Should only take in the quality markers, observations, reported observation
187 |     # time relative to the start point,
188 |     # observation errors, and background values, lat/lon/height/speed of the observing platform
189 |
190 |     print(obs_data)
191 |     print(obs_data.hdr_inst_typ.values)
192 |     print(obs_data.hdr_irpt_typ.values)
193 |     print(obs_data.obs_qty_table.values)
194 |     print(obs_data.hdr_prpt_typ.values)
195 |     print(obs_data.hdr_sid_table.values)
196 |     print(obs_data.hdr_typ_table.values)
197 |     print(obs_data.obs_desc.values)
198 |     print(obs_data.data_vars.keys())
199 |
200 |     analysis_data = xr.open_zarr(
201 |         "/home/jacob/Development/gdas1.fnl0p25.2016010100.f00.zarr", consolidated=True
202 |     )
203 |     print(analysis_data)
--------------------------------------------------------------------------------
/graph_weather/data/nnja_ai.py:
--------------------------------------------------------------------------------
1 | """
2 | A custom PyTorch Dataset implementation for various sensors like AMSU, ATMS, MHS, IASI, CrIS
3 |
4 | The dataset is loaded via the nnja library's `DataCatalog` and filtered for specific times and
5 | variables. Each data point consists of a timestamp, latitude, longitude, and associated metadata.
6 | """
7 |
8 | import numpy as np
9 | import torch
10 | from torch.utils.data import Dataset
11 |
12 | try:
13 | from nnja import DataCatalog
14 | except ImportError:
15 | print(
16 | "NNJA-AI library not installed. Please install with `pip install git+https://github.com/brightbandtech/nnja-ai.git`"
17 | )
18 |
19 |
20 | class SensorDataset(Dataset):
21 | """A custom PyTorch Dataset for handling various sensor data."""
22 |
23 | def __init__(
24 | self, dataset_name, time, primary_descriptors, additional_variables, sensor_type="AMSU"
25 | ):
26 | """Initialize the dataset loader for various sensors.
27 |
28 | Args:
29 | dataset_name: Name of the dataset to load.
30 | time: Specific timestamp to filter the data.
31 | primary_descriptors: List of primary descriptor variables to include (e.g., OBS_TIMESTAMP, LAT, LON).
32 | additional_variables: List of additional variables to include in metadata.
33 | sensor_type: Type of sensor (AMSU, ATMS, MHS, IASI, CrIS)
34 | """
35 | self.dataset_name = dataset_name
36 | self.time = time
37 | self.primary_descriptors = primary_descriptors
38 | self.additional_variables = additional_variables
39 | self.sensor_type = sensor_type # New argument for selecting sensor type
40 |
41 | # Load data catalog and dataset
42 | self.catalog = DataCatalog(skip_manifest=True)
43 | self.dataset = self.catalog[self.dataset_name]
44 | self.dataset.load_manifest()
45 |
46 | if self.sensor_type == "AMSU":
47 | self.dataset = self.dataset.sel(
48 | time=self.time,
49 | variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 16)],
50 | )
51 | elif self.sensor_type == "ATMS":
52 | self.dataset = self.dataset.sel(
53 | time=self.time,
54 | variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 23)],
55 | )
56 | elif self.sensor_type == "MHS":
57 | self.dataset = self.dataset.sel(
58 | time=self.time,
59 | variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 6)],
60 | )
61 | elif self.sensor_type == "IASI":
62 | self.dataset = self.dataset.sel(
63 | time=self.time,
64 | variables=self.primary_descriptors
65 | + ["SCRA_" + str(i).zfill(5) for i in range(1, 617)],
66 | )
67 | elif self.sensor_type == "CrIS":
68 | self.dataset = self.dataset.sel(
69 | time=self.time,
70 | variables=self.primary_descriptors
71 | + [f"SRAD01_{str(i).zfill(5)}" for i in range(1, 432)],
72 | )
73 | else:
74 | raise ValueError(f"Unsupported sensor type: {self.sensor_type}")
75 |
76 | self.dataframe = self.dataset.load_dataset(engine="pandas")
77 |
78 | for col in primary_descriptors:
79 | if col not in self.dataframe.columns:
80 | raise ValueError(f"The dataset must include a '{col}' column.")
81 |
82 | self.metadata_columns = [
83 | col for col in self.dataframe.columns if col not in self.primary_descriptors
84 | ]
85 |
86 | def __len__(self):
87 | """Return the total number of samples in the dataset."""
88 | return len(self.dataframe)
89 |
90 | def __getitem__(self, index):
91 | """Return the observation and metadata for a given index."""
92 | row = self.dataframe.iloc[index]
93 | time = row["OBS_TIMESTAMP"].timestamp()
94 | latitude = row["LAT"]
95 | longitude = row["LON"]
96 | metadata = np.array([row[col] for col in self.metadata_columns], dtype=np.float32)
97 |
98 | return {
99 | "timestamp": torch.tensor(time, dtype=torch.float32),
100 | "latitude": torch.tensor(latitude, dtype=torch.float32),
101 | "longitude": torch.tensor(longitude, dtype=torch.float32),
102 | "metadata": torch.from_numpy(metadata),
103 | }
104 |
105 |
106 | def collate_fn(batch):
107 | """Custom collate function to handle batching of dictionary data.
108 |
109 | Args:
110 | batch: List of dictionaries from __getitem__
111 |
112 | Returns:
113 | Single dictionary with batched tensors
114 | """
115 | return {key: torch.stack([item[key] for item in batch]) for key in batch[0].keys()}
116 |
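117 | # Usage sketch (hedged): the dataset name and time below are illustrative
118 | # assumptions; consult the NNJA-AI catalog for real keys.
119 | # from torch.utils.data import DataLoader
120 | #
121 | # dataset = SensorDataset(
122 | #     dataset_name="amsu-example",
123 | #     time="2021-01-01T00:00:00",
124 | #     primary_descriptors=["OBS_TIMESTAMP", "LAT", "LON"],
125 | #     additional_variables=[],
126 | #     sensor_type="AMSU",
127 | # )
128 | # loader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)
129 | # batch = next(iter(loader))  # dict of batched tensors, keyed like __getitem__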
--------------------------------------------------------------------------------
/graph_weather/models/__init__.py:
--------------------------------------------------------------------------------
1 | """Models"""
2 |
3 | from .fengwu_ghr.layers import (
4 | ImageMetaModel,
5 | LoRAModule,
6 | MetaModel,
7 | WrapperImageModel,
8 | WrapperMetaModel,
9 | )
10 | from .layers.assimilator_decoder import AssimilatorDecoder
11 | from .layers.assimilator_encoder import AssimilatorEncoder
12 | from .layers.decoder import Decoder
13 | from .layers.encoder import Encoder
14 | from .layers.processor import Processor
15 |
--------------------------------------------------------------------------------
/graph_weather/models/analysis.py:
--------------------------------------------------------------------------------
1 | """Model for forecasting weather from NWP states"""
2 |
3 | import torch
4 | from huggingface_hub import PyTorchModelHubMixin
5 |
6 | from graph_weather.models import AssimilatorDecoder, AssimilatorEncoder, Processor
7 |
8 |
9 | class GraphWeatherAssimilator(torch.nn.Module, PyTorchModelHubMixin):
10 | """Model to generate analysis file from raw observations"""
11 |
12 | def __init__(
13 | self,
14 | output_lat_lons: list,
15 | resolution: int = 2,
16 | observation_dim: int = 2,
17 | analysis_dim: int = 78,
18 | node_dim: int = 256,
19 | edge_dim: int = 256,
20 | num_blocks: int = 9,
21 | hidden_dim_processor_node: int = 256,
22 | hidden_dim_processor_edge: int = 256,
23 | hidden_layers_processor_node: int = 2,
24 | hidden_layers_processor_edge: int = 2,
25 | hidden_dim_decoder: int = 128,
26 | hidden_layers_decoder: int = 2,
27 | norm_type: str = "LayerNorm",
28 | use_checkpointing: bool = False,
29 | ):
30 | """
31 | Graph Weather Data Assimilation model
32 |
33 | Args:
34 | output_lat_lons: List of latitudes and longitudes for the output analysis
35 | (observation lat/lon/heights are passed to `forward`, not the constructor)
36 | resolution: Resolution of the H3 grid, prefer even resolutions, as
37 | odd ones have octagons and heptagons as well
38 | observation_dim: Input feature size
39 | analysis_dim: Output Analysis feature dim
40 | node_dim: Node hidden dimension
41 | edge_dim: Edge hidden dimension
42 | num_blocks: Number of message passing blocks in the Processor
43 | hidden_dim_processor_node: Hidden dimension of the node processors
44 | hidden_dim_processor_edge: Hidden dimension of the edge processors
45 | hidden_layers_processor_node: Number of hidden layers in the node processors
46 | hidden_layers_processor_edge: Number of hidden layers in the edge processors
47 | hidden_dim_decoder: Number of hidden dimensions in the decoder
48 | hidden_layers_decoder: Number of layers in the decoder
49 | norm_type: Type of norm for the MLPs
50 | one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
51 | use_checkpointing: Whether to use gradient checkpointing or not
52 | """
53 | super().__init__()
54 |
55 | self.encoder = AssimilatorEncoder(
56 | resolution=resolution,
57 | input_dim=observation_dim,
58 | output_dim=node_dim,
59 | output_edge_dim=edge_dim,
60 | hidden_dim_processor_edge=hidden_dim_processor_edge,
61 | hidden_layers_processor_node=hidden_layers_processor_node,
62 | hidden_dim_processor_node=hidden_dim_processor_node,
63 | hidden_layers_processor_edge=hidden_layers_processor_edge,
64 | mlp_norm_type=norm_type,
65 | use_checkpointing=use_checkpointing,
66 | )
67 | self.processor = Processor(
68 | input_dim=node_dim,
69 | edge_dim=edge_dim,
70 | num_blocks=num_blocks,
71 | hidden_dim_processor_edge=hidden_dim_processor_edge,
72 | hidden_layers_processor_node=hidden_layers_processor_node,
73 | hidden_dim_processor_node=hidden_dim_processor_node,
74 | hidden_layers_processor_edge=hidden_layers_processor_edge,
75 | mlp_norm_type=norm_type,
76 | )
77 | self.decoder = AssimilatorDecoder(
78 | lat_lons=output_lat_lons,
79 | resolution=resolution,
80 | input_dim=node_dim,
81 | output_dim=analysis_dim,
82 | output_edge_dim=edge_dim,
83 | hidden_dim_processor_edge=hidden_dim_processor_edge,
84 | hidden_layers_processor_node=hidden_layers_processor_node,
85 | hidden_dim_processor_node=hidden_dim_processor_node,
86 | hidden_layers_processor_edge=hidden_layers_processor_edge,
87 | mlp_norm_type=norm_type,
88 | hidden_dim_decoder=hidden_dim_decoder,
89 | hidden_layers_decoder=hidden_layers_decoder,
90 | use_checkpointing=use_checkpointing,
91 | )
92 |
93 | def forward(self, features: torch.Tensor, obs_lat_lon_heights: torch.Tensor) -> torch.Tensor:
94 | """
95 | Compute the analysis output
96 |
97 | Args:
98 | features: The input features, aligned with the order of lat_lons_heights
99 | obs_lat_lon_heights: Observation lat/lon/heights in same order as features
100 |
101 | Returns:
102 | The next state in the forecast
103 | """
104 | x, edge_idx, edge_attr = self.encoder(features, obs_lat_lon_heights)
105 | x = self.processor(x, edge_idx, edge_attr)
106 | x = self.decoder(x, features.shape[0])
107 | return x
108 |
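109 | # Usage sketch (mirrors the README example; shapes are illustrative):
110 | # model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24)
111 | # features = torch.randn((1, num_obs, 2))          # one value + relative time per observation
112 | # lat_lon_heights = torch.tensor(obs_lat_lons)     # (num_obs, 3): lat, lon, height
113 | # out = model(features, lat_lon_heights)           # -> (1, len(output_lat_lons), 24)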
--------------------------------------------------------------------------------
/graph_weather/models/aurora/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Aurora: A Foundation Model for Earth System Science
3 | - Combines 3D Swin Transformer encoding
4 | - Perceiver processing for efficient computation
5 | - 3D decoding for spatial-temporal predictions
6 | """
7 |
8 | from .decoder import Decoder3D
9 | from .encoder import Swin3DEncoder
10 | from .model import AuroraModel, EarthSystemLoss
11 | from .processor import PerceiverProcessor
12 |
13 | __version__ = "0.1.0"
14 |
15 | __all__ = [
16 | "AuroraModel",
17 | "EarthSystemLoss",
18 | "Swin3DEncoder",
19 | "Decoder3D",
20 | "PerceiverProcessor",
21 | ]
22 |
23 | # Default configurations for different model sizes
24 | MODEL_CONFIGS = {
25 | "tiny": {
26 | "in_channels": 1,
27 | "out_channels": 1,
28 | "embed_dim": 48,
29 | "latent_dim": 256,
30 | "spatial_shape": (16, 16, 16),
31 | "max_seq_len": 2048,
32 | },
33 | "base": {
34 | "in_channels": 1,
35 | "out_channels": 1,
36 | "embed_dim": 96,
37 | "latent_dim": 512,
38 | "spatial_shape": (32, 32, 32),
39 | "max_seq_len": 4096,
40 | },
41 | "large": {
42 | "in_channels": 1,
43 | "out_channels": 1,
44 | "embed_dim": 192,
45 | "latent_dim": 1024,
46 | "spatial_shape": (64, 64, 64),
47 | "max_seq_len": 8192,
48 | },
49 | }
50 |
51 |
52 | def create_model(config="base", **kwargs):
53 | """
54 | Create an Aurora model with specified configuration.
55 |
56 | Args:
57 | config (str): Model size configuration ('tiny', 'base', or 'large')
58 | **kwargs: Override default configuration parameters
59 |
60 | Returns:
61 | AuroraModel: Initialized model with specified configuration
62 | """
63 | if config not in MODEL_CONFIGS:
64 | raise ValueError(
65 | f"Unknown configuration: {config}. Choose from {list(MODEL_CONFIGS.keys())}"
66 | )
67 |
68 | # Start with default config and update with any provided kwargs
69 | model_config = MODEL_CONFIGS[config].copy()
70 | model_config.update(kwargs)
71 |
72 | return AuroraModel(**model_config)
73 |
74 |
75 | def create_loss(alpha=0.5, beta=0.3, gamma=0.2):
76 | """
77 | Create an EarthSystemLoss instance with specified weights.
78 |
79 | Args:
80 | alpha (float): Weight for MSE loss
81 | beta (float): Weight for gradient loss
82 | gamma (float): Weight for physical consistency loss
83 |
84 | Returns:
85 | EarthSystemLoss: Initialized loss function
86 | """
87 | return EarthSystemLoss(alpha=alpha, beta=beta, gamma=gamma)
88 |
--------------------------------------------------------------------------------
/graph_weather/models/aurora/decoder.py:
--------------------------------------------------------------------------------
1 | """
2 | 3D Decoder:
3 | - Takes processed latent representations and reconstructs output.
4 | - Uses transposed convolution to upscale back to spatial-temporal format.
5 | """
6 |
7 | import torch.nn as nn
8 |
9 |
10 | class Decoder3D(nn.Module):
11 | """
12 | 3D Decoder:
13 | - Takes processed latent representations and reconstructs the spatial-temporal output.
14 | - Uses transposed convolutions to upscale latent features to the original format.
15 | """
16 |
17 | def __init__(self, output_channels=1, embed_dim=96, target_shape=(32, 32, 32)):
18 | """
19 | Args:
20 | output_channels (int): Number of channels in the output tensor (e.g., 1 for grayscale).
21 | embed_dim (int): Dimension of the latent features (matches the encoder's output).
22 | target_shape (tuple): The desired shape of the reconstructed 3D tensor (D, H, W).
23 | """
24 | super().__init__()
25 | self.embed_dim = embed_dim
26 | self.target_shape = target_shape
27 | self.deconv1 = nn.ConvTranspose3d(
28 | embed_dim, output_channels, kernel_size=3, padding=1, stride=1
29 | )
30 |
31 | def forward(self, x):
32 | """
33 | Forward pass for the decoder.
34 |
35 | Args:
36 | x (torch.Tensor): Input latent representation, shape (batch, seq_len, embed_dim).
37 |
38 | Returns:
39 | torch.Tensor: Reconstructed 3D tensor, shape (batch, output_channels, *target_shape).
40 | """
41 | batch_size = x.shape[0]
42 | depth, height, width = self.target_shape
43 | # Reshape latent sequence (batch, D*H*W, embed_dim) into a channels-first 3D volume
44 | x = x.permute(0, 2, 1).reshape(batch_size, self.embed_dim, depth, height, width)
45 | # Transposed convolution to upscale to the final shape
46 | x = self.deconv1(x)
47 | return x
48 |
--------------------------------------------------------------------------------
/graph_weather/models/aurora/encoder.py:
--------------------------------------------------------------------------------
1 | """
2 | Swin 3D Transformer Encoder:
3 | - Uses a 3D convolution for initial feature extraction.
4 | - Applies layer normalization and reshapes data.
5 | - Uses a transformer-based encoder to learn spatial-temporal features.
6 | """
7 |
8 | import torch.nn as nn
9 | from einops import rearrange
10 | from einops.layers.torch import Rearrange
11 |
12 |
13 | class Swin3DEncoder(nn.Module):
14 | def __init__(self, in_channels=1, embed_dim=96):
15 | super().__init__()
16 | self.conv1 = nn.Conv3d(in_channels, embed_dim, kernel_size=3, padding=1, stride=1)
17 | self.norm = nn.LayerNorm(embed_dim)
18 | self.swin_transformer = nn.Transformer(
19 | d_model=embed_dim,
20 | nhead=8,
21 | num_encoder_layers=4,
22 | num_decoder_layers=4,
23 | dim_feedforward=embed_dim * 4,
24 | )
25 | self.embed_dim = embed_dim
26 |
27 | # Define rearrangement patterns using einops
28 | self.to_transformer_format = Rearrange("b d h w c -> (d h w) b c")
29 | # the inverse mapping is applied with functional `rearrange`, since d/h/w vary per input
30 |
31 | # forward() uses the functional rearrange directly instead of the Rearrange layer
32 | def forward(self, x):
33 | # 3D convolution with einops rearrangement
34 | x = self.conv1(x)
35 |
36 | # Rearrange for normalization using einops
37 | x = rearrange(x, "b c d h w -> b d h w c")
38 | x = self.norm(x)
39 |
40 | # Store spatial dimensions for later reconstruction
41 | d, h, w = x.shape[1:4]
42 |
43 | # Transform to sequence format for transformer
44 | x = rearrange(x, "b d h w c -> (d h w) b c")
45 | x = self.swin_transformer.encoder(x)
46 |
47 | # Restore original spatial structure
48 | x = rearrange(x, "(d h w) b c -> b (d h w) c", d=d, h=h, w=w)
49 |
50 | # x is already in the expected output format (batch, seq_len, embed_dim),
51 | # so no further reshaping is required
52 |
53 | return x
54 |
55 | def convolution(self, x):
56 | """Apply 3D convolution with clear shape transformation."""
57 | return self.conv1(x) # b c d h w -> b embed_dim d h w
58 |
59 | def normalization_layer(self, x):
60 | """Apply layer normalization with einops rearrangement."""
61 | x = rearrange(x, "b c d h w -> b d h w c")
62 | return self.norm(x)
63 |
64 | def transformer_encoder(self, x, spatial_dims):
65 | """
66 | Apply transformer encoding with proper shape handling.
67 |
68 | Args:
69 | x (torch.Tensor): Input tensor
70 | spatial_dims (tuple): Original (depth, height, width) dimensions
71 | """
72 | d, h, w = spatial_dims
73 | x = self.to_transformer_format(x)
74 | x = self.swin_transformer.encoder(x)
75 | x = rearrange(x, "(d h w) b c -> b d h w c", d=d, h=h, w=w)
76 | return x
77 |
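78 | # Shape sketch (assumed sizes): a (batch, channels, D, H, W) volume becomes a
79 | # (batch, D*H*W, embed_dim) token sequence.
80 | # enc = Swin3DEncoder(in_channels=1, embed_dim=96)
81 | # tokens = enc(torch.randn(2, 1, 8, 8, 8))  # -> (2, 512, 96)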
--------------------------------------------------------------------------------
/graph_weather/models/aurora/model.py:
--------------------------------------------------------------------------------
1 | """
2 | aurora/model.py - Core implementation of Aurora model for unstructured point data
3 | """
4 |
5 | from typing import Optional
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 |
11 | class PointEncoder(nn.Module):
12 | def __init__(self, input_features: int, embed_dim: int, max_seq_len: int = 1024):
13 | super().__init__()
14 | self.input_dim = input_features + 2 # Account for lat/lon coordinates
15 | self.max_seq_len = max_seq_len
16 |
17 | # Remove positional embeddings as they break point ordering invariance
18 |
19 | # Enhanced coordinate embedding
20 | self.coord_encoder = nn.Sequential(
21 | nn.Linear(2, embed_dim // 2),
22 | nn.LayerNorm(embed_dim // 2),
23 | nn.ReLU(),
24 | nn.Linear(embed_dim // 2, embed_dim),
25 | )
26 |
27 | # Feature embedding
28 | self.feature_encoder = nn.Sequential(
29 | nn.Linear(input_features, embed_dim),
30 | nn.LayerNorm(embed_dim),
31 | nn.ReLU(),
32 | nn.Linear(embed_dim, embed_dim),
33 | )
34 |
35 | # Final normalization
36 | self.norm = nn.LayerNorm(embed_dim)
37 |
38 | def forward(self, points: torch.Tensor, features: torch.Tensor) -> torch.Tensor:
39 | num_points = points.shape[1]
40 | if num_points > self.max_seq_len:
41 | points = points[:, : self.max_seq_len, :]
42 | features = features[:, : self.max_seq_len, :]
43 |
44 | # Normalize coordinates to [-1, 1] range
45 | normalized_points = torch.stack(
46 | [points[..., 0] / 180.0, points[..., 1] / 90.0], dim=-1  # lon / 180, lat / 90
47 | )
48 |
49 | # Separately encode coordinates and features
50 | coord_embedding = self.coord_encoder(normalized_points)
51 | feature_embedding = self.feature_encoder(features)
52 |
53 | # Combine embeddings through addition (order-invariant operation)
54 | x = coord_embedding + feature_embedding
55 |
56 | # Final normalization
57 | x = self.norm(x)
58 |
59 | return x
60 |
61 |
62 | class PointDecoder(nn.Module):
63 | """Decodes latent representations back to point features."""
64 |
65 | def __init__(self, embed_dim: int, output_features: int):
66 | super().__init__()
67 | self.decoder = nn.Sequential(
68 | nn.Linear(embed_dim, embed_dim), nn.ReLU(), nn.Linear(embed_dim, output_features)
69 | )
70 |
71 | def forward(self, x: torch.Tensor) -> torch.Tensor:
72 | """
73 | Args:
74 | x: (batch_size, num_points, embed_dim) tensor
75 | Returns:
76 | (batch_size, num_points, output_features) tensor
77 | """
78 | return self.decoder(x)
79 |
80 |
81 | class PointCloudProcessor(nn.Module):
82 | """Processes point cloud data using self-attention layers."""
83 |
84 | def __init__(self, embed_dim: int, num_layers: int = 4):
85 | super().__init__()
86 | self.layers = nn.ModuleList([SelfAttentionLayer(embed_dim) for _ in range(num_layers)])
87 |
88 | def forward(self, x: torch.Tensor) -> torch.Tensor:
89 | """
90 | Args:
91 | x: (batch_size, num_points, embed_dim) tensor
92 | Returns:
93 | (batch_size, num_points, embed_dim) tensor after processing
94 | """
95 | for layer in self.layers:
96 | x = layer(x)
97 | return x
98 |
99 |
100 | class SelfAttentionLayer(nn.Module):
101 | def __init__(self, embed_dim: int):
102 | super().__init__()
103 | self.attention = nn.MultiheadAttention(embed_dim, num_heads=8)
104 | self.norm1 = nn.LayerNorm(embed_dim)
105 | self.norm2 = nn.LayerNorm(embed_dim)
106 | self.ffn = nn.Sequential(
107 | nn.Linear(embed_dim, 4 * embed_dim), nn.ReLU(), nn.Linear(4 * embed_dim, embed_dim)
108 | )
109 |
110 | def forward(self, x: torch.Tensor) -> torch.Tensor:
111 | # First attention block with residual
112 | x_t = x.transpose(0, 1)
113 | attended, _ = self.attention(x_t, x_t, x_t)
114 | attended = attended.transpose(0, 1)
115 | x = self.norm1(x + attended)
116 |
117 | # FFN block with residual
118 | x = self.norm2(x + self.ffn(x))
119 | return x
120 |
121 |
122 | class EarthSystemLoss(nn.Module):
123 | def __init__(self, alpha: float = 0.5, beta: float = 0.3, gamma: float = 0.2):
124 | super().__init__()
125 | self.alpha = alpha
126 | self.beta = beta
127 | self.gamma = gamma
128 |
129 | def spatial_correlation_loss(
130 | self, pred: torch.Tensor, target: torch.Tensor, points: torch.Tensor
131 | ) -> torch.Tensor:
132 | batch_size, num_points, _ = points.shape
133 |
134 | # Compute pairwise distances per batch element; torch.cdist supports
135 | # batched (B, N, 2) input directly, so no flattening across the batch
136 | dists = torch.cdist(points, points)
137 | assert dists.shape == (batch_size, num_points, num_points)
138 |
139 | # Create mask for nearby points (5 degrees threshold)
140 | nearby_mask = (dists < 5.0).float().unsqueeze(-1)
141 |
142 | # Compute differences
143 | pred_diff = pred.unsqueeze(2) - pred.unsqueeze(1)
144 | target_diff = target.unsqueeze(2) - target.unsqueeze(1)
145 |
146 | # Calculate loss with proper broadcasting
147 | correlation_loss = torch.mean(nearby_mask * (pred_diff - target_diff).pow(2))
148 |
149 | return correlation_loss
150 |
151 | def physical_loss(self, pred: torch.Tensor, points: torch.Tensor) -> torch.Tensor:
152 | """Calculate physical consistency loss - ensures predictions follow basic physical laws"""
153 | # Ensure non-negative values for physical quantities (e.g., temperature in Kelvin)
154 | min_value_loss = torch.nn.functional.relu(-pred).mean()
155 |
156 | # Ensure reasonable maximum values (e.g., max temperature)
157 | max_value_loss = torch.nn.functional.relu(pred - 500).mean() # Assuming max value of 500
158 |
159 | # Add latitude-based consistency (e.g., colder at poles)
160 | latitude = points[..., 1] # Second coordinate is latitude
161 | abs_latitude = torch.abs(latitude)
162 | latitude_consistency = torch.mean(
163 | torch.nn.functional.relu(pred[..., 0] - (1.0 - abs_latitude / 90.0) * pred.mean())
164 | )
165 |
166 | # Combine physical constraints
167 | physical_loss = min_value_loss + max_value_loss + 0.1 * latitude_consistency
168 | return physical_loss
169 |
170 | def forward(self, pred: torch.Tensor, target: torch.Tensor, points: torch.Tensor) -> dict:
171 | mse_loss = torch.nn.functional.mse_loss(pred, target)
172 | spatial_loss = self.spatial_correlation_loss(pred, target, points)
173 | physical_loss = self.physical_loss(pred, points)
174 |
175 | # Combine losses with the specified weights
176 | total_loss = self.alpha * mse_loss + self.beta * spatial_loss + self.gamma * physical_loss
177 |
178 | return {
179 | "total_loss": total_loss,
180 | "mse_loss": mse_loss,
181 | "spatial_correlation_loss": spatial_loss,
182 | "physical_loss": physical_loss,
183 | }
184 |
185 |
186 | class AuroraModel(nn.Module):
187 | def __init__(
188 | self,
189 | input_features: int,
190 | output_features: int,
191 | latent_dim: int = 256,
192 | num_layers: int = 4,
193 | max_points: int = 10000,
194 | max_seq_len: int = 1024,
195 | use_checkpointing: bool = False,
196 | ):
197 | super().__init__()
198 |
199 | self.max_points = max_points
200 | self.max_seq_len = max_seq_len
201 | self.input_features = input_features
202 | self.output_features = output_features
203 |
204 | # Model components
205 | self.encoder = PointEncoder(input_features, latent_dim, max_seq_len)
206 | self.processor = PointCloudProcessor(latent_dim, num_layers)
207 | self.decoder = PointDecoder(latent_dim, output_features)
208 |
209 | # Add gradient checkpointing
210 | self.use_checkpointing = use_checkpointing
211 |
212 | # Initialize weights properly
213 | self._init_weights()
214 |
215 | def _init_weights(self):
216 | for m in self.modules():
217 | if isinstance(m, nn.Linear):
218 | nn.init.xavier_uniform_(m.weight)
219 | if m.bias is not None:
220 | nn.init.zeros_(m.bias)
221 |
222 | def forward(
223 | self, points: torch.Tensor, features: torch.Tensor, mask: Optional[torch.Tensor] = None
224 | ) -> torch.Tensor:
225 | if points.shape[1] > self.max_points:
226 | raise ValueError(
227 | f"Number of points ({points.shape[1]}) exceeds maximum ({self.max_points})"
228 | )
229 |
230 | # Handle mask properly
231 | if mask is not None:
232 | mask = mask.float().unsqueeze(-1)
233 | points = points * mask
234 | features = features * mask
235 |
236 | # Forward pass with gradient checkpointing
237 | x = self.encoder(points, features)
238 |
239 | if self.use_checkpointing and self.training:
240 | x = torch.utils.checkpoint.checkpoint(self.processor, x)
241 | else:
242 | x = self.processor(x)
243 |
244 | output = self.decoder(x)
245 |
246 | # Apply mask to output if provided
247 | if mask is not None:
248 | output = output * mask
249 |
250 | return output
251 |
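252 | # End-to-end sketch (hedged; sizes are illustrative): points are (lon, lat)
253 | # pairs, features are per-point observations.
254 | # model = AuroraModel(input_features=3, output_features=3)
255 | # criterion = EarthSystemLoss(alpha=0.5, beta=0.3, gamma=0.2)
256 | # points = torch.rand(2, 100, 2) * 180.0 - 90.0
257 | # features = torch.randn(2, 100, 3)
258 | # pred = model(points, features)               # -> (2, 100, 3)
259 | # losses = criterion(pred, features, points)   # dict with "total_loss" etc.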
--------------------------------------------------------------------------------
/graph_weather/models/aurora/processor.py:
--------------------------------------------------------------------------------
1 | """
2 | Perceiver Transformer Processor:
3 | - Takes encoded features and processes them using latent space mapping.
4 | - Uses a latent-space bottleneck to compress input dimensions.
5 | - Provides an efficient way to extract long-range dependencies.
6 | - All architectural parameters are configurable.
7 | """
8 |
9 | from dataclasses import dataclass
10 | from typing import Optional
11 |
12 | import einops
13 | import torch.nn as nn
14 |
15 |
16 | @dataclass
17 | class ProcessorConfig:
18 | input_dim: int = 256 # Match Swin3D output
19 | latent_dim: int = 512
20 | d_model: int = 256 # Match input_dim for consistency
21 | max_seq_len: int = 4096
22 | num_self_attention_layers: int = 6
23 | num_cross_attention_layers: int = 2
24 | num_attention_heads: int = 8
25 | hidden_dropout: float = 0.1
26 | attention_dropout: float = 0.1
27 | qk_head_dim: Optional[int] = 32
28 | activation_fn: str = "gelu"
29 | layer_norm_eps: float = 1e-12
30 |
31 | def __post_init__(self):
32 | # Validate parameters
33 | if self.input_dim <= 0:
34 | raise ValueError("input_dim must be positive")
35 | if self.max_seq_len <= 0:
36 | raise ValueError("max_seq_len must be positive")
37 | if self.num_attention_heads <= 0:
38 | raise ValueError("num_attention_heads must be positive")
39 | if not 0 <= self.hidden_dropout <= 1:
40 | raise ValueError("hidden_dropout must be between 0 and 1")
41 | if not 0 <= self.attention_dropout <= 1:
42 | raise ValueError("attention_dropout must be between 0 and 1")
43 |
44 |
45 | class PerceiverProcessor(nn.Module):
46 | def __init__(self, config: Optional[ProcessorConfig] = None):
47 | super().__init__()
48 | self.config = config or ProcessorConfig()
49 |
50 | # Input projection to match d_model
51 | self.input_projection = nn.Linear(self.config.input_dim, self.config.d_model)
52 |
53 | # Simplified architecture using transformer encoder
54 | self.encoder = nn.TransformerEncoder(
55 | nn.TransformerEncoderLayer(
56 | d_model=self.config.d_model,
57 | nhead=self.config.num_attention_heads,
58 | dim_feedforward=self.config.d_model * 4,
59 | dropout=self.config.hidden_dropout,
60 | activation=self.config.activation_fn,
61 | ),
62 | num_layers=self.config.num_self_attention_layers,
63 | )
64 |
65 | # Output projection
66 | self.output_projection = nn.Linear(self.config.d_model, self.config.latent_dim)
67 |
68 | def forward(self, x, attention_mask=None):
69 | # Handle 4D input using einops for clearer reshaping
70 | if len(x.shape) == 4:
71 | # Rearrange from (batch, seq, height, features) to (batch, seq*height, features)
72 | x = einops.rearrange(x, "b s h c -> b (s h) c")
73 |
74 | # Project input
75 | x = self.input_projection(x)
76 |
77 | # Apply transformer encoder with einops for transpose operations
78 | if attention_mask is not None:
79 | # Build an additive float mask: 0.0 where attended (True), -inf at padding (False)
80 | pad = ~attention_mask
81 | mask = pad.float().masked_fill(pad, float("-inf"))
82 | x = einops.rearrange(
83 | x, "b s c -> s b c"
84 | ) # (batch, seq, channels) -> (seq, batch, channels)
85 | x = self.encoder(x, src_key_padding_mask=mask)
86 | x = einops.rearrange(
87 | x, "s b c -> b s c"
88 | ) # (seq, batch, channels) -> (batch, seq, channels)
89 | else:
90 | x = einops.rearrange(x, "b s c -> s b c")
91 | x = self.encoder(x)
92 | x = einops.rearrange(x, "s b c -> b s c")
93 |
94 | # Project to latent dimension and pool
95 | x = self.output_projection(x)
96 | x = x.mean(dim=1) # Global average pooling
97 |
98 | return x
99 |
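A short sketch exercising PerceiverProcessor end to end; the shapes are assumptions consistent with the forward pass above (input is [batch, seq, input_dim], output is a pooled [batch, latent_dim]):

    import torch

    from graph_weather.models.aurora.processor import PerceiverProcessor, ProcessorConfig

    config = ProcessorConfig(input_dim=32, d_model=64, latent_dim=128,
                             num_attention_heads=4, num_self_attention_layers=2)
    processor = PerceiverProcessor(config)
    x = torch.rand(2, 10, 32)                   # [batch, seq, input_dim]
    mask = torch.ones(2, 10, dtype=torch.bool)  # True = attend to this position
    pooled = processor(x, attention_mask=mask)  # [batch, latent_dim] after mean pooling
    assert pooled.shape == (2, 128)
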
--------------------------------------------------------------------------------
/graph_weather/models/fengwu_ghr/__init__.py:
--------------------------------------------------------------------------------
1 | """Main import for FengWu-GHR"""
2 |
3 | from .layers import ImageMetaModel, LoRAModule, MetaModel, WrapperImageModel, WrapperMetaModel
4 |
--------------------------------------------------------------------------------
/graph_weather/models/forecast.py:
--------------------------------------------------------------------------------
1 | """Model for forecasting weather from NWP states"""
2 |
3 | from typing import Optional
4 |
5 | import torch
6 | from einops import rearrange, repeat
7 | from huggingface_hub import PyTorchModelHubMixin
8 |
9 | from graph_weather.models import Decoder, Encoder, Processor
10 | from graph_weather.models.layers.constraint_layer import PhysicalConstraintLayer
11 |
12 |
13 | class GraphWeatherForecaster(torch.nn.Module, PyTorchModelHubMixin):
14 | """Main weather prediction model from the paper with physical constraints"""
15 |
16 | def __init__(
17 | self,
18 | lat_lons: list,
19 | resolution: int = 2,
20 | feature_dim: int = 78,
21 | aux_dim: int = 24,
22 | output_dim: Optional[int] = None,
23 | node_dim: int = 256,
24 | edge_dim: int = 256,
25 | num_blocks: int = 9,
26 | hidden_dim_processor_node: int = 256,
27 | hidden_dim_processor_edge: int = 256,
28 | hidden_layers_processor_node: int = 2,
29 | hidden_layers_processor_edge: int = 2,
30 | hidden_dim_decoder: int = 128,
31 | hidden_layers_decoder: int = 2,
32 | norm_type: str = "LayerNorm",
33 | use_checkpointing: bool = False,
34 | constraint_type: str = "none",
35 | ):
36 | """
37 | Graph Weather Model based off https://arxiv.org/pdf/2202.07575.pdf
38 |
39 | Args:
40 | lat_lons: List of latitude and longitudes for the grid
41 | resolution: Resolution of the H3 grid, prefer even resolutions, as
42 | odd ones have octagons and heptagons as well
43 | feature_dim: Input feature size
44 | aux_dim: Number of non-NWP features (i.e. landsea mask, lat/lon, etc)
45 | output_dim: Optional, output feature size, useful if want only subset of variables in
46 | output
47 | node_dim: Node hidden dimension
48 | edge_dim: Edge hidden dimension
49 | num_blocks: Number of message passing blocks in the Processor
50 | hidden_dim_processor_node: Hidden dimension of the node processors
51 | hidden_dim_processor_edge: Hidden dimension of the edge processors
52 | hidden_layers_processor_node: Number of hidden layers in the node processors
53 | hidden_layers_processor_edge: Number of hidden layers in the edge processors
54 | hidden_dim_decoder: Number of hidden dimensions in the decoder
55 | hidden_layers_decoder: Number of layers in the decoder
56 | norm_type: Type of norm for the MLPs
57 | one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
58 | use_checkpointing: Use gradient checkpointing to reduce model memory
59 | constraint_type: Type of constraint to apply for physical constraints
60 | one of 'additive', 'multiplicative', 'softmax', or 'none'
61 | """
62 | super().__init__()
63 | self.feature_dim = feature_dim
64 | self.constraint_type = constraint_type
65 | if output_dim is None:
66 | output_dim = self.feature_dim
67 | self.output_dim = output_dim
68 |
69 | # Compute the geographical grid shape from lat_lons.
70 | unique_lats = sorted(set(lat for lat, _ in lat_lons))
71 | unique_lons = sorted(set(lon for _, lon in lat_lons))
72 | self.grid_shape = (len(unique_lats), len(unique_lons)) # (H, W)
73 |
74 | # Store original node order and create grid mapping
75 | self.original_lat_lons = lat_lons.copy()
76 | self._create_grid_mapping(unique_lats, unique_lons)
77 |
78 | self.encoder = Encoder(
79 | lat_lons=lat_lons,
80 | resolution=resolution,
81 | input_dim=feature_dim + aux_dim,
82 | output_dim=node_dim,
83 | output_edge_dim=edge_dim,
84 | hidden_dim_processor_edge=hidden_dim_processor_edge,
85 | hidden_layers_processor_node=hidden_layers_processor_node,
86 | hidden_dim_processor_node=hidden_dim_processor_node,
87 | hidden_layers_processor_edge=hidden_layers_processor_edge,
88 | mlp_norm_type=norm_type,
89 | use_checkpointing=use_checkpointing,
90 | )
91 | self.processor = Processor(
92 | input_dim=node_dim,
93 | edge_dim=edge_dim,
94 | num_blocks=num_blocks,
95 | hidden_dim_processor_edge=hidden_dim_processor_edge,
96 | hidden_layers_processor_node=hidden_layers_processor_node,
97 | hidden_dim_processor_node=hidden_dim_processor_node,
98 | hidden_layers_processor_edge=hidden_layers_processor_edge,
99 | mlp_norm_type=norm_type,
100 | )
101 | self.decoder = Decoder(
102 | lat_lons=lat_lons,
103 | resolution=resolution,
104 | input_dim=node_dim,
105 | output_dim=output_dim,
106 | output_edge_dim=edge_dim,
107 | hidden_dim_processor_edge=hidden_dim_processor_edge,
108 | hidden_layers_processor_node=hidden_layers_processor_node,
109 | hidden_dim_processor_node=hidden_dim_processor_node,
110 | hidden_layers_processor_edge=hidden_layers_processor_edge,
111 | mlp_norm_type=norm_type,
112 | hidden_dim_decoder=hidden_dim_decoder,
113 | hidden_layers_decoder=hidden_layers_decoder,
114 | use_checkpointing=use_checkpointing,
115 | )
116 |
117 | # Add physical constraint layer if constraint_type is not "none"
118 | if self.constraint_type != "none":
119 | self.constraint = PhysicalConstraintLayer(
120 | model=self,
121 | grid_shape=self.grid_shape,
122 | constraint_type=constraint_type,
123 | upsampling_factor=1,
124 | )
125 |
126 | def _create_grid_mapping(self, unique_lats, unique_lons):
127 | """Create (row,col) mapping for original node order"""
128 | self.node_to_grid = []
129 | for lat, lon in self.original_lat_lons:
130 | row = int(
131 | (lat - min(unique_lats))
132 | / (max(unique_lats) - min(unique_lats))
133 | * (len(unique_lats) - 1)
134 | )
135 | col = int(
136 | (lon - min(unique_lons))
137 | / (max(unique_lons) - min(unique_lons))
138 | * (len(unique_lons) - 1)
139 | )
140 | self.node_to_grid.append((row, col))
141 |
142 | def graph_to_grid(self, graph_tensor):
143 | """
144 |
145 | Convert graph tensor to grid using spatial mapping:
146 | [B, N, C] -> [B, C, H, W]
147 | """
148 | batch_size, num_nodes, features = graph_tensor.shape
149 | grid = torch.zeros(batch_size, features, *self.grid_shape, device=graph_tensor.device)
150 | for node_idx, (row, col) in enumerate(self.node_to_grid):
151 | grid[..., row, col] = graph_tensor[..., node_idx, :]
152 | return grid
153 |
154 | def grid_to_graph(self, grid_tensor):
155 | """Convert grid to graph tensor: [B, C, H, W] -> [B, N, C]"""
156 | batch_size, features, H, W = grid_tensor.shape
157 | graph = torch.zeros(batch_size, H * W, features, device=grid_tensor.device)
158 | for node_idx, (row, col) in enumerate(self.node_to_grid):
159 | graph[..., node_idx, :] = grid_tensor[..., row, col]
160 | return graph
161 |
162 | def forward(self, features: torch.Tensor) -> torch.Tensor:
163 | """
164 | Compute the new state of the forecast
165 |
166 | Args:
167 | features: The input features, aligned with the order of lat_lons_heights
168 |
169 | Returns:
170 | The next state in the forecast
171 | """
172 | x, edge_idx, edge_attr = self.encoder(features)
173 | x = self.processor(x, edge_idx, edge_attr)
174 | x = self.decoder(x, features[..., : self.feature_dim])
175 |
176 | # The decoder output x is in graph format [B, N, output_dim]. When physical
177 | # constraints are enabled, it is rearranged to grid format [B, output_dim, H, W]
178 | # below, before the constraint layer is applied.
179 |
180 | # Apply physical constraints to decoder output
181 | if self.constraint_type != "none":
182 | x = rearrange(x, "b (h w) c -> b c h w", h=self.grid_shape[0], w=self.grid_shape[1])
183 | # Extract the low-res reference from the input.
184 | # (Original features has shape [B, num_nodes, feature_dim])
185 | lr = features[..., : self.feature_dim] # shape: [B, num_nodes, feature_dim]
186 | # Convert from node format to grid format using the grid_shape computed in __init__
187 | # From [B, num_nodes, feature_dim] to [B, feature_dim, H, W]
188 | lr = rearrange(lr, "b (h w) c -> b c h w", h=self.grid_shape[0], w=self.grid_shape[1])
189 | if lr.size(1) != x.size(1):
190 | repeat_factor = x.size(1) // lr.size(1)
191 | lr = repeat(lr, "b c h w -> b (r c) h w", r=repeat_factor)
192 | x = self.constraint(x, lr)
193 | return x
194 |
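A minimal usage sketch for the forecaster; the top-level import and the regular 10-degree grid are assumptions for illustration. The input concatenates the NWP state (feature_dim) with auxiliary features (aux_dim) per node:

    import torch

    from graph_weather import GraphWeatherForecaster  # top-level export assumed

    lat_lons = [(lat, lon) for lat in range(-90, 90, 10) for lon in range(0, 360, 10)]
    model = GraphWeatherForecaster(lat_lons, feature_dim=42, aux_dim=6)
    features = torch.randn(1, len(lat_lons), 42 + 6)  # NWP state plus auxiliary inputs
    out = model(features)                             # next state: [1, num_nodes, 42]
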
--------------------------------------------------------------------------------
/graph_weather/models/gencast/__init__.py:
--------------------------------------------------------------------------------
1 | """Main import for GenCast"""
2 |
3 | from .denoiser import Denoiser
4 | from .graph.graph_builder import GraphBuilder
5 | from .sampler import Sampler
6 | from .utils.noise import generate_isotropic_noise
7 | from .weighted_mse_loss import WeightedMSELoss
8 |
--------------------------------------------------------------------------------
/graph_weather/models/gencast/graph/__init__.py:
--------------------------------------------------------------------------------
1 | """Utils for graph generation."""
2 |
--------------------------------------------------------------------------------
/graph_weather/models/gencast/graph/grid_mesh_connectivity.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Source: https://github.com/google-deepmind/graphcast.
15 | """Tools for converting from regular grids on a sphere, to triangular meshes."""
16 |
17 | import numpy as np
18 | import scipy
19 | import trimesh
20 |
21 | from graph_weather.models.gencast.graph import icosahedral_mesh
22 |
23 |
24 | def _grid_lat_lon_to_coordinates(
25 | grid_latitude: np.ndarray, grid_longitude: np.ndarray
26 | ) -> np.ndarray:
27 | """Lat [num_lat] lon [num_lon] to 3d coordinates [num_lat, num_lon, 3]."""
28 | # Convert to spherical coordinates phi and theta defined in the grid.
29 | # Each [num_latitude_points, num_longitude_points]
30 | phi_grid, theta_grid = np.meshgrid(np.deg2rad(grid_longitude), np.deg2rad(90 - grid_latitude))
31 |
32 | # [num_latitude_points, num_longitude_points, 3]
33 | # Note this assumes unit radius, since for now we model the earth as a
34 | # sphere of unit radius, and keep any vertical dimension as a regular grid.
35 | return np.stack(
36 | [
37 | np.cos(phi_grid) * np.sin(theta_grid),
38 | np.sin(phi_grid) * np.sin(theta_grid),
39 | np.cos(theta_grid),
40 | ],
41 | axis=-1,
42 | )
43 |
44 |
45 | def radius_query_indices(
46 | *,
47 | grid_latitude: np.ndarray,
48 | grid_longitude: np.ndarray,
49 | mesh: icosahedral_mesh.TriangularMesh,
50 | radius: float,
51 | ) -> tuple[np.ndarray, np.ndarray]:
52 | """Returns mesh-grid edge indices for radius query.
53 |
54 | Args:
55 | grid_latitude: Latitude values for the grid [num_lat_points]
56 | grid_longitude: Longitude values for the grid [num_lon_points]
57 | mesh: Mesh object.
58 | radius: Radius of connectivity in R3, for a sphere of unit radius.
59 |
60 | Returns:
61 | tuple with `grid_indices` and `mesh_indices` indicating edges between the
62 | grid and the mesh such that the distances in a straight line (not geodesic)
63 | are smaller than or equal to `radius`.
64 | * grid_indices: Indices of shape [num_edges], that index into a
65 | [num_lat_points, num_lon_points] grid, after flattening the leading axes.
66 | * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
67 | """
68 |
69 | # [num_grid_points=num_lat_points * num_lon_points, 3]
70 | grid_positions = _grid_lat_lon_to_coordinates(grid_latitude, grid_longitude).reshape([-1, 3])
71 |
72 | # [num_mesh_points, 3]
73 | mesh_positions = mesh.vertices
74 | kd_tree = scipy.spatial.cKDTree(mesh_positions)
75 |
76 | # [num_grid_points, num_mesh_points_per_grid_point]
77 | # Note `num_mesh_points_per_grid_point` is not constant, so this is a list
78 | # of arrays, rather than a 2d array.
79 | query_indices = kd_tree.query_ball_point(x=grid_positions, r=radius)
80 |
81 | grid_edge_indices = []
82 | mesh_edge_indices = []
83 | for grid_index, mesh_neighbors in enumerate(query_indices):
84 | grid_edge_indices.append(np.repeat(grid_index, len(mesh_neighbors)))
85 | mesh_edge_indices.append(mesh_neighbors)
86 |
87 | # [num_edges]
88 | grid_edge_indices = np.concatenate(grid_edge_indices, axis=0).astype(int)
89 | mesh_edge_indices = np.concatenate(mesh_edge_indices, axis=0).astype(int)
90 |
91 | return grid_edge_indices, mesh_edge_indices
92 |
93 |
94 | def in_mesh_triangle_indices(
95 | *, grid_latitude: np.ndarray, grid_longitude: np.ndarray, mesh: icosahedral_mesh.TriangularMesh
96 | ) -> tuple[np.ndarray, np.ndarray]:
97 | """Returns mesh-grid edge indices for grid points contained in mesh triangles.
98 |
99 | Args:
100 | grid_latitude: Latitude values for the grid [num_lat_points]
101 | grid_longitude: Longitude values for the grid [num_lon_points]
102 | mesh: Mesh object.
103 |
104 | Returns:
105 | tuple with `grid_indices` and `mesh_indices` indicating edges between the
106 | grid and the mesh vertices of the triangle that contain each grid point.
107 | The number of edges is always num_lat_points * num_lon_points * 3
108 | * grid_indices: Indices of shape [num_edges], that index into a
109 | [num_lat_points, num_lon_points] grid, after flattening the leading axes.
110 | * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
111 | """
112 |
113 | # [num_grid_points=num_lat_points * num_lon_points, 3]
114 | grid_positions = _grid_lat_lon_to_coordinates(grid_latitude, grid_longitude).reshape([-1, 3])
115 |
116 | mesh_trimesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)
117 |
118 | # [num_grid_points] with mesh face indices for each grid point.
119 | _, _, query_face_indices = trimesh.proximity.closest_point(mesh_trimesh, grid_positions)
120 |
121 | # [num_grid_points, 3] with mesh node indices for each grid point.
122 | mesh_edge_indices = mesh.faces[query_face_indices]
123 |
124 | # [num_grid_points, 3] with grid node indices, where every row simply contains
125 | # the row (grid_point) index.
126 | grid_indices = np.arange(grid_positions.shape[0])
127 | grid_edge_indices = np.tile(grid_indices.reshape([-1, 1]), [1, 3])
128 |
129 | # Flatten to get a regular list.
130 | # [num_edges=num_grid_points*3]
131 | mesh_edge_indices = mesh_edge_indices.reshape([-1])
132 | grid_edge_indices = grid_edge_indices.reshape([-1])
133 |
134 | return grid_edge_indices, mesh_edge_indices
135 |
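A quick sanity check of the coordinate convention used by the module-private helper above (a sketch; the underscore-prefixed function is imported here only for the check):

    import numpy as np

    from graph_weather.models.gencast.graph.grid_mesh_connectivity import (
        _grid_lat_lon_to_coordinates,
    )

    lat = np.array([0.0, 90.0])
    lon = np.array([0.0, 90.0, 180.0, 270.0])
    xyz = _grid_lat_lon_to_coordinates(lat, lon)           # [num_lat, num_lon, 3]
    assert np.allclose(np.linalg.norm(xyz, axis=-1), 1.0)  # every point on the unit sphere
    assert np.allclose(xyz[1], [0.0, 0.0, 1.0])            # lat=90 maps to the north pole
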
--------------------------------------------------------------------------------
/graph_weather/models/gencast/images/animated.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openclimatefix/graph_weather/abf13a36e95ba6b699027e1d4c6760760aae280f/graph_weather/models/gencast/images/animated.gif
--------------------------------------------------------------------------------
/graph_weather/models/gencast/images/autoregressive.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openclimatefix/graph_weather/abf13a36e95ba6b699027e1d4c6760760aae280f/graph_weather/models/gencast/images/autoregressive.gif
--------------------------------------------------------------------------------
/graph_weather/models/gencast/images/fullmodel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openclimatefix/graph_weather/abf13a36e95ba6b699027e1d4c6760760aae280f/graph_weather/models/gencast/images/fullmodel.png
--------------------------------------------------------------------------------
/graph_weather/models/gencast/images/readme.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openclimatefix/graph_weather/abf13a36e95ba6b699027e1d4c6760760aae280f/graph_weather/models/gencast/images/readme.md
--------------------------------------------------------------------------------
/graph_weather/models/gencast/layers/__init__.py:
--------------------------------------------------------------------------------
1 | """GenCast layers."""
2 |
3 | from .decoder import Decoder
4 | from .encoder import Encoder
5 | from .modules import MLP, InteractionNetwork
6 |
--------------------------------------------------------------------------------
/graph_weather/models/gencast/layers/decoder.py:
--------------------------------------------------------------------------------
1 | """Decoder layer.
2 |
3 | The decoder:
4 | - performs a single message-passing step on mesh2grid using a classical interaction network.
5 | - adds a residual connection to the grid nodes.
6 | """
7 |
8 | import torch
9 |
10 | from graph_weather.models.gencast.layers.modules import MLP, InteractionNetwork
11 |
12 |
13 | class Decoder(torch.nn.Module):
14 | """GenCast's decoder."""
15 |
16 | def __init__(
17 | self,
18 | edges_dim: int,
19 | output_dim: int,
20 | hidden_dims: list[int],
21 | activation_layer: torch.nn.Module = torch.nn.ReLU,
22 | use_layer_norm: bool = True,
23 | ):
24 | """Initialize the Decoder.
25 |
26 | Args:
27 | edges_dim (int): dimension of edges' features.
28 | output_dim (int): dimension of final output.
29 | hidden_dims (list[int]): hidden dimensions of internal MLPs.
30 | activation_layer (torch.nn.Module, optional): activation function of internal MLPs.
31 | Defaults to torch.nn.ReLU.
32 | use_layer_norm (bool, optional): if true add a LayerNorm at the end of each MLP.
33 | Defaults to True.
34 | """
35 | super().__init__()
36 |
37 | # All the MLPs in GenCast share the same hidden and output dims, so the embedding latent
38 | # dimension and the MLPs' output dimension coincide. For simplicity, we ask for the hidden
39 | # dims just once per module: we don't need to specify them individually as arguments,
40 | # even though the MLPs play different roles.
41 | self.latent_dim = hidden_dims[-1]
42 |
43 | # Embedders
44 | self.edges_mlp = MLP(
45 | input_dim=edges_dim,
46 | hidden_dims=hidden_dims,
47 | activation_layer=activation_layer,
48 | use_layer_norm=use_layer_norm,
49 | bias=True,
50 | activate_final=False,
51 | )
52 |
53 | # Message Passing
54 | self.gnn = InteractionNetwork(
55 | sender_dim=self.latent_dim,
56 | receiver_dim=self.latent_dim,
57 | edge_attr_dim=self.latent_dim,
58 | hidden_dims=hidden_dims,
59 | use_layer_norm=use_layer_norm,
60 | activation_layer=activation_layer,
61 | )
62 |
63 | # Final grid nodes update
64 | self.grid_mlp_final = MLP(
65 | input_dim=self.latent_dim,
66 | hidden_dims=hidden_dims[:-1] + [output_dim],
67 | activation_layer=activation_layer,
68 | use_layer_norm=use_layer_norm,
69 | bias=True,
70 | activate_final=False,
71 | )
72 |
73 | def forward(
74 | self,
75 | input_mesh_nodes: torch.Tensor,
76 | input_grid_nodes: torch.Tensor,
77 | input_edge_attr: torch.Tensor,
78 | edge_index: torch.Tensor,
79 | ) -> torch.Tensor:
80 | """Forward pass.
81 |
82 | Args:
83 | input_mesh_nodes (torch.Tensor): mesh nodes' features.
84 | input_grid_nodes (torch.Tensor): grid nodes' features.
85 | input_edge_attr (torch.Tensor): mesh2grid edges' features.
86 | edge_index (torch.Tensor): edge index tensor.
87 |
88 | Returns:
89 | torch.Tensor: output grid nodes.
90 | """
91 | if not (
92 | input_grid_nodes.shape[-1] == self.latent_dim
93 | and input_mesh_nodes.shape[-1] == self.latent_dim
94 | ):
95 | raise ValueError(
96 | "The dimension of grid nodes and mesh nodes' features must be "
97 | "equal to the last hidden dimension."
98 | )
99 |
100 | # Embedding
101 | edges_emb = self.edges_mlp(input_edge_attr)
102 |
103 | # Message-passing + residual connection
104 | latent_grid_nodes = input_grid_nodes + self.gnn(
105 | x=(input_mesh_nodes, input_grid_nodes),
106 | edge_index=edge_index,
107 | edge_attr=edges_emb,
108 | )
109 |
110 | # Update grid nodes
111 | latent_grid_nodes = self.grid_mlp_final(latent_grid_nodes)
112 |
113 | return latent_grid_nodes
114 |
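A sketch of how the decoder might be wired up, assuming the PyG bipartite convention implied by the forward pass (row 0 of edge_index indexes mesh senders, row 1 grid receivers). Both node sets must already be in the latent dimension hidden_dims[-1]:

    import torch

    from graph_weather.models.gencast.layers import Decoder

    decoder = Decoder(edges_dim=4, output_dim=8, hidden_dims=[64, 64])
    mesh_nodes = torch.randn(12, 64)
    grid_nodes = torch.randn(20, 64)
    edge_attr = torch.randn(30, 4)    # raw mesh2grid edge features, embedded internally
    edge_index = torch.stack([
        torch.randint(0, 12, (30,)),  # senders: mesh node indices
        torch.randint(0, 20, (30,)),  # receivers: grid node indices
    ])
    out = decoder(mesh_nodes, grid_nodes, edge_attr, edge_index)  # [20, 8]
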
--------------------------------------------------------------------------------
/graph_weather/models/gencast/layers/encoder.py:
--------------------------------------------------------------------------------
1 | """Encoder layer.
2 |
3 | The encoder:
4 | - embeds grid nodes, mesh nodes and g2m edges' features to the latent space.
5 | - performs a single message-passing step using a classical interaction network.
6 | - adds a residual connection to the mesh and grid nodes.
7 | """
8 |
9 | import torch
10 |
11 | from graph_weather.models.gencast.layers.modules import MLP, InteractionNetwork
12 |
13 |
14 | class Encoder(torch.nn.Module):
15 | """GenCast's encoder."""
16 |
17 | def __init__(
18 | self,
19 | grid_dim: int,
20 | mesh_dim: int,
21 | edge_dim: int,
22 | hidden_dims: list[int],
23 | activation_layer: torch.nn.Module = torch.nn.ReLU,
24 | use_layer_norm: bool = True,
25 | scale_factor: float = 1.0,
26 | ):
27 | """Initialize the Encoder.
28 |
29 | Args:
30 | grid_dim (int): dimension of grid nodes' features.
31 | mesh_dim (int): dimension of mesh nodes' features.
32 | edge_dim (int): dimension of g2m edges' features.
33 | hidden_dims (list[int]): hidden dimensions of internal MLPs.
34 | activation_layer (torch.nn.Module, optional): activation function of internal MLPs.
35 | Defaults to torch.nn.ReLU.
36 | use_layer_norm (bool, optional): if true add a LayerNorm at the end of each MLP.
37 | Defaults to True.
38 | scale_factor (float): the message of the interaction network between the grid and
39 | the mesh is multiplied by the scale factor. Useful when fine-tuning a pretrained
40 | model to a higher resolution. Defaults to 1.
41 | """
42 | super().__init__()
43 |
44 | # All the MLPs in GenCast share the same hidden and output dims, so the embedding latent
45 | # dimension and the MLPs' output dimension coincide. For simplicity, we ask for the hidden
46 | # dims just once per module: we don't need to specify them individually as arguments,
47 | # even though the MLPs play different roles.
48 | self.latent_dim = hidden_dims[-1]
49 |
50 | # Embedders
51 | self.grid_mlp = MLP(
52 | input_dim=grid_dim,
53 | hidden_dims=hidden_dims,
54 | activation_layer=activation_layer,
55 | use_layer_norm=use_layer_norm,
56 | bias=True,
57 | activate_final=False,
58 | )
59 |
60 | self.mesh_mlp = MLP(
61 | input_dim=mesh_dim,
62 | hidden_dims=hidden_dims,
63 | activation_layer=activation_layer,
64 | use_layer_norm=use_layer_norm,
65 | bias=True,
66 | activate_final=False,
67 | )
68 |
69 | self.edges_mlp = MLP(
70 | input_dim=edge_dim,
71 | hidden_dims=hidden_dims,
72 | activation_layer=activation_layer,
73 | use_layer_norm=use_layer_norm,
74 | bias=True,
75 | activate_final=False,
76 | )
77 |
78 | # Message Passing
79 | self.gnn = InteractionNetwork(
80 | sender_dim=self.latent_dim,
81 | receiver_dim=self.latent_dim,
82 | edge_attr_dim=self.latent_dim,
83 | hidden_dims=hidden_dims,
84 | use_layer_norm=use_layer_norm,
85 | activation_layer=activation_layer,
86 | scale_factor=scale_factor,
87 | )
88 |
89 | # Final grid nodes update
90 | self.grid_mlp_final = MLP(
91 | input_dim=self.latent_dim,
92 | hidden_dims=hidden_dims,
93 | activation_layer=activation_layer,
94 | use_layer_norm=use_layer_norm,
95 | bias=True,
96 | activate_final=False,
97 | )
98 |
99 | def forward(
100 | self,
101 | input_grid_nodes: torch.Tensor,
102 | input_mesh_nodes: torch.Tensor,
103 | input_edge_attr: torch.Tensor,
104 | edge_index: torch.Tensor,
105 | ) -> tuple[torch.Tensor, torch.Tensor]:
106 | """Forward pass.
107 |
108 | Args:
109 | input_grid_nodes (torch.Tensor): grid nodes' features.
110 | input_mesh_nodes (torch.Tensor): mesh nodes' features.
111 | input_edge_attr (torch.Tensor): grid2mesh edges' features.
112 | edge_index (torch.Tensor): edge index tensor.
113 |
114 | Returns:
115 | tuple[torch.Tensor, torch.Tensor]: output grid nodes, output mesh nodes.
116 | """
117 |
118 | # Embedding
119 | grid_emb = self.grid_mlp(input_grid_nodes)
120 | mesh_emb = self.mesh_mlp(input_mesh_nodes)
121 | edges_emb = self.edges_mlp(input_edge_attr)
122 |
123 | # Message-passing + residual connection
124 | latent_mesh_nodes = mesh_emb + self.gnn(
125 | x=(grid_emb, mesh_emb),
126 | edge_index=edge_index,
127 | edge_attr=edges_emb,
128 | )
129 |
130 | # Update grid nodes + residual connection
131 | latent_grid_nodes = grid_emb + self.grid_mlp_final(grid_emb)
132 |
133 | return latent_grid_nodes, latent_mesh_nodes
134 |
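The mirror-image sketch for the encoder, under the same assumed edge-index convention (grid senders, mesh receivers); raw grid, mesh and edge features are embedded internally:

    import torch

    from graph_weather.models.gencast.layers import Encoder

    encoder = Encoder(grid_dim=10, mesh_dim=3, edge_dim=4, hidden_dims=[64, 64])
    grid_nodes = torch.randn(20, 10)
    mesh_nodes = torch.randn(12, 3)
    edge_attr = torch.randn(30, 4)
    edge_index = torch.stack([
        torch.randint(0, 20, (30,)),  # senders: grid node indices
        torch.randint(0, 12, (30,)),  # receivers: mesh node indices
    ])
    latent_grid, latent_mesh = encoder(grid_nodes, mesh_nodes, edge_attr, edge_index)
    # latent_grid: [20, 64], latent_mesh: [12, 64]
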
--------------------------------------------------------------------------------
/graph_weather/models/gencast/layers/experimental/__init__.py:
--------------------------------------------------------------------------------
1 | """Experimental features."""
2 |
3 | from .sparse_transformer import SparseTransformer
4 |
--------------------------------------------------------------------------------
/graph_weather/models/gencast/layers/experimental/sparse_transformer.py:
--------------------------------------------------------------------------------
1 | """Experimental sparse transformer using DGL sparse."""
2 |
3 | import dgl.sparse as dglsp
4 | import torch
5 | import torch.nn as nn
6 |
7 | from graph_weather.models.gencast.layers.modules import ConditionalLayerNorm
8 |
9 |
10 | class SparseAttention(nn.Module):
11 | """Sparse Multi-head Attention Module"""
12 |
13 | def __init__(self, input_dim=512, output_dim=512, num_heads=4):
14 | """Initialize Sparse MultiHead attention module.
15 |
16 | Args:
17 | input_dim (int): input dimension. Defaults to 512.
18 | output_dim (int): output dimension. Defaults to 512.
19 | num_heads (int): number of heads. Output dimension should be divisible by num_heads.
20 | Defaults to 4.
21 | """
22 | super().__init__()
23 | if output_dim % num_heads:
24 | raise ValueError("Output dimension should be divisible by the number of heads.")
25 |
26 | self.hidden_size = output_dim
27 | self.num_heads = num_heads
28 | self.head_dim = output_dim // num_heads
29 | self.scaling = self.head_dim**-0.5
30 |
31 | self.q_proj = nn.Linear(input_dim, output_dim)
32 | self.k_proj = nn.Linear(input_dim, output_dim)
33 | self.v_proj = nn.Linear(input_dim, output_dim)
34 | self.out_proj = nn.Linear(output_dim, output_dim)
35 |
36 | def forward(self, x: torch.Tensor, adj: dglsp.SparseMatrix):
37 | """Forward pass of SparseMHA.
38 |
39 | Args:
40 | x (torch.Tensor): input tensor.
41 | adj (SparseMatrix): adjacency matrix in DGL SparseMatrix format.
42 |
43 | Returns:
44 | y (tensor): output of MultiHead attention.
45 | """
46 | N = len(x)
47 | # computing query,key and values.
48 | q = self.q_proj(x).reshape(N, self.head_dim, self.num_heads) # (dense) [N, dh, nh]
49 | k = self.k_proj(x).reshape(N, self.head_dim, self.num_heads) # (dense) [N, dh, nh]
50 | v = self.v_proj(x).reshape(N, self.head_dim, self.num_heads) # (dense) [N, dh, nh]
51 | # scaling query
52 | q *= self.scaling
53 |
54 | # sparse-dense-dense product
55 | attn = dglsp.bsddmm(adj, q, k.transpose(1, 0)) # (sparse) [N, N, nh]
56 |
57 | # sparse softmax (by default applies on the last sparse dimension).
58 | attn = attn.softmax() # (sparse) [N, N, nh]
59 |
60 | # sparse-dense multiplication
61 | out = dglsp.bspmm(attn, v) # (dense) [N, dh, nh]
62 | return self.out_proj(out.reshape(N, -1))
63 |
64 |
65 | class SparseTransformer(nn.Module):
66 | """A single transformer block for graph neural networks.
67 |
68 | This module implements a single transformer block with a sparse attention mechanism.
69 | """
70 |
71 | def __init__(
72 | self,
73 | conditioning_dim: int,
74 | input_dim: int,
75 | output_dim: int,
76 | num_heads: int,
77 | activation_layer: torch.nn.Module = nn.ReLU,
78 | norm_first: bool = True,
79 | ):
80 | """Initialize SparseTransformer module.
81 |
82 | Args:
83 | conditioning_dim (int, optional): dimension of the conditioning parameter. If None the
84 | layer normalization will not be applied.
85 | input_dim (int): dimension of the input features.
86 | output_dim (int): dimension of the output features. Should be divisible by
87 | num_heads.
88 | num_heads (int): number of heads for multi-head attention.
89 | activation_layer (torch.nn.Module): activation function applied before
90 | returning the output.
91 | norm_first (bool): if True apply layer normalization before attention. Defaults to True.
92 | """
93 | super().__init__()
94 |
95 | # initialize multihead sparse attention.
96 | self.sparse_attention = SparseAttention(
97 | input_dim=input_dim, output_dim=output_dim, num_heads=num_heads
98 | )
99 |
100 | # initialize mlp
101 | self.activation = activation_layer()
102 | self.mlp = nn.Sequential(
103 | nn.Linear(output_dim, output_dim), self.activation, nn.Linear(output_dim, output_dim)
104 | )
105 |
106 | # initialize conditional layer normalization
107 | self.cond_norm_1 = ConditionalLayerNorm(
108 | conditioning_dim=conditioning_dim, features_dim=output_dim
109 | )
110 | self.cond_norm_2 = ConditionalLayerNorm(
111 | conditioning_dim=conditioning_dim, features_dim=output_dim
112 | )
113 |
114 | self.norm_first = norm_first
115 |
116 | def forward(
117 | self,
118 | x: torch.Tensor,
119 | edge_index: torch.Tensor,
120 | cond_param: torch.Tensor,
121 | *args,
122 | **kwargs,
123 | ) -> torch.Tensor:
124 | """Apply SparseTransformer to input.
125 |
126 | Input and conditioning parameter must have same batch size.
127 |
128 | Args:
129 | x (torch.Tensor): tensor containing nodes features.
130 | edge_index (torch.Tensor): edge index tensor.
131 | cond_param (torch.Tensor): conditioning parameter.
132 | *args: ignored by the module.
133 | **kwargs: ignored by the module.
134 |
135 | """
136 | if self.norm_first:
137 | x1 = self.cond_norm_1(x, cond_param)
138 | x = x + self.sparse_attention(
139 | x=x1, adj=dglsp.spmatrix(indices=edge_index, shape=(x.shape[0], x.shape[0]))
140 | )
141 | else:
142 | x = x + self.sparse_attention(
143 | x=x, adj=dglsp.spmatrix(indices=edge_index, shape=(x.shape[0], x.shape[0]))
144 | )
145 | x = self.cond_norm_1(x, cond_param)
146 |
147 | if self.norm_first:
148 | x2 = self.cond_norm_2(x, cond_param)
149 | x = x + self.mlp(x2)
150 | else:
151 | x = x + self.mlp(x)
152 | x = self.cond_norm_2(x, cond_param)
153 | return x
154 |
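A sketch instantiating the block (requires DGL). Note that input_dim must equal output_dim for the residual connections to type-check; the per-node conditioning shape is an assumption based on the "same batch size" requirement above:

    import torch

    from graph_weather.models.gencast.layers.experimental import SparseTransformer

    block = SparseTransformer(conditioning_dim=16, input_dim=32, output_dim=32, num_heads=4)
    x = torch.randn(10, 32)
    edge_index = torch.randint(0, 10, (2, 40))  # attention restricted to these edges
    cond = torch.randn(10, 16)                  # one conditioning vector per node (assumed)
    y = block(x, edge_index, cond)              # [10, 32]
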
--------------------------------------------------------------------------------
/graph_weather/models/gencast/layers/processor.py:
--------------------------------------------------------------------------------
1 | """Processor layer.
2 |
3 | The processor:
4 | - applies a sequence of transformer blocks to the mesh nodes.
5 | - conditions on the noise level.
6 | """
7 |
8 | import torch
9 |
10 | from graph_weather.models.gencast.layers.modules import MLP, CondTransformerBlock, FourierEmbedding
11 |
12 | try:
13 | from graph_weather.models.gencast.layers.experimental import SparseTransformer
14 |
15 | has_dgl = True
16 | except ImportError:
17 | has_dgl = False
18 |
19 |
20 | class Processor(torch.nn.Module):
21 | """GenCast's Processor
22 |
23 | The Processor is a sequence of transformer blocks conditioned on noise level. If the graph has
24 | many edges, setting sparse=True may perform better in terms of memory and speed. Note that
25 | sparse=False uses PyG as the backend, while sparse=True uses DGL. The two implementations are
26 | not exactly equivalent: the former is described in the paper "Masked Label Prediction: Unified
27 | Message Passing Model for Semi-Supervised Classification" and can also handle edge features,
28 | while the latter is a classical transformer that performs multi-head attention utilizing the
29 | mask's sparsity and does not include edge features in the computations.
30 |
31 | Note: The GenCast paper does not provide specific details regarding the implementation of the
32 | transformer architecture for graphs.
33 | """
34 |
35 | def __init__(
36 | self,
37 | latent_dim: int,
38 | hidden_dims: list[int],
39 | num_blocks: int,
40 | num_heads: int,
41 | num_frequencies: int,
42 | base_period: int,
43 | noise_emb_dim: int,
44 | edges_dim: int | None = None,
45 | activation_layer: torch.nn.Module = torch.nn.ReLU,
46 | use_layer_norm: bool = True,
47 | sparse: bool = False,
48 | ):
49 | """Initialize the Processor.
50 |
51 | Args:
52 | latent_dim (int): dimension of nodes' features.
53 | hidden_dims (list[int]): hidden dimensions of internal MLPs.
54 | num_blocks (int): number of transformer blocks.
55 | num_heads (int): number of heads for multi-head attention.
56 | num_frequencies (int): number of frequencies for the noise Fourier embedding.
57 | base_period (int): base period for the noise Fourier embedding.
58 | noise_emb_dim (int): dimension of output of noise embedding.
59 | edges_dim (int, optional): dimension of edges' features. If None, edge features
60 | are not used in TransformerConv. Defaults to None.
61 | activation_layer (torch.nn.Module): activation function of internal MLPs.
62 | Defaults to torch.nn.ReLU.
63 | use_layer_norm (bool): if true add a LayerNorm at the end of the embedding MLP.
64 | Defaults to True.
65 | sparse (bool): if true use DGL as backend (experimental). Defaults to False.
66 | """
67 | super().__init__()
68 | self.latent_dim = latent_dim
69 | if latent_dim % num_heads != 0:
70 | raise ValueError("The latent dimension should be divisible by the number of heads.")
71 |
72 | # Embedders
73 | self.fourier_embedder = FourierEmbedding(
74 | output_dim=noise_emb_dim, num_frequencies=num_frequencies, base_period=base_period
75 | )
76 |
77 | self.edges_dim = edges_dim
78 | if edges_dim is not None:
79 | self.edges_mlp = MLP(
80 | input_dim=edges_dim,
81 | hidden_dims=hidden_dims,
82 | activation_layer=activation_layer,
83 | use_layer_norm=use_layer_norm,
84 | bias=True,
85 | activate_final=False,
86 | )
87 |
88 | # Transformer blocks
89 | self.cond_transformers = torch.nn.ModuleList()
90 | if not sparse:
91 | for _ in range(num_blocks - 1):
92 | # concatenating multi-head attention
93 | self.cond_transformers.append(
94 | CondTransformerBlock(
95 | conditioning_dim=noise_emb_dim,
96 | input_dim=latent_dim,
97 | output_dim=latent_dim // num_heads,
98 | edges_dim=hidden_dims[-1] if (edges_dim is not None) else None,
99 | num_heads=num_heads,
100 | concat=True,
101 | beta=True,
102 | activation_layer=activation_layer,
103 | )
104 | )
105 |
106 | # averaging multi-head attention
107 | self.cond_transformers.append(
108 | CondTransformerBlock(
109 | conditioning_dim=noise_emb_dim,
110 | input_dim=latent_dim,
111 | output_dim=latent_dim,
112 | edges_dim=hidden_dims[-1] if (edges_dim is not None) else None,
113 | num_heads=num_heads,
114 | concat=False,
115 | beta=True,
116 | activation_layer=None,
117 | )
118 | )
119 | else:
120 | if not has_dgl:
121 | raise ValueError("Please install DGL to use sparsity.")
122 |
123 | for _ in range(num_blocks):
124 | # concatenating multi-head attention
125 | self.cond_transformers.append(
126 | SparseTransformer(
127 | conditioning_dim=noise_emb_dim,
128 | input_dim=latent_dim,
129 | output_dim=latent_dim,
130 | num_heads=num_heads,
131 | activation_layer=activation_layer,
132 | )
133 | )
134 | # do we really need averaging for last block?
135 |
136 | def _check_args(self, latent_mesh_nodes, noise_levels, input_edge_attr):
137 | if not latent_mesh_nodes.shape[-1] == self.latent_dim:
138 | raise ValueError(
139 | "The dimension of the mesh nodes is different from the latent dimension provided at"
140 | " initialization."
141 | )
142 |
143 | if not latent_mesh_nodes.shape[0] == noise_levels.shape[0]:
144 | raise ValueError(
145 | "The number of noise levels and mesh nodes should be the same, but got "
146 | f"{latent_mesh_nodes.shape[0]} and {noise_levels.shape[0]}. Eventually repeat the "
147 | " noise level for each node in the same batch."
148 | )
149 |
150 | if (input_edge_attr is not None) and (self.edges_dim is None):
151 | raise ValueError("To use input_edge_attr initialize the processor with edges_dim.")
152 |
153 | def forward(
154 | self,
155 | latent_mesh_nodes: torch.Tensor,
156 | edge_index: torch.Tensor,
157 | noise_levels: torch.Tensor,
158 | input_edge_attr: torch.Tensor | None = None,
159 | ) -> torch.Tensor:
160 | """Forward pass.
161 |
162 | Args:
163 | latent_mesh_nodes (torch.Tensor): mesh nodes' features.
164 | edge_index (torch.Tensor): edge index tensor.
165 | noise_levels (torch.Tensor): log-noise levels.
166 | input_edge_attr (torch.Tensor, optional): mesh edges' features.
167 |
168 | Returns:
169 | torch.Tensor: latent mesh nodes.
170 | """
171 | self._check_args(latent_mesh_nodes, noise_levels, input_edge_attr)
172 |
173 | # embedding
174 | noise_emb = self.fourier_embedder(noise_levels)
175 |
176 | if self.edges_dim is not None:
177 | edges_emb = self.edges_mlp(input_edge_attr)
178 | else:
179 | edges_emb = None
180 |
181 | # apply transformer blocks
182 | for cond_transformer in self.cond_transformers:
183 | latent_mesh_nodes = cond_transformer(
184 | x=latent_mesh_nodes,
185 | edge_index=edge_index,
186 | cond_param=noise_emb,
187 | edge_attr=edges_emb,
188 | )
189 |
190 | return latent_mesh_nodes
191 |
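A small sketch running the processor on a toy mesh; the per-node noise-level shape follows _check_args, and the remaining values are illustrative:

    import torch

    from graph_weather.models.gencast.layers.processor import Processor

    processor = Processor(latent_dim=64, hidden_dims=[64, 64], num_blocks=2, num_heads=4,
                          num_frequencies=16, base_period=16, noise_emb_dim=32)
    nodes = torch.randn(12, 64)
    edge_index = torch.randint(0, 12, (2, 40))
    noise_levels = torch.rand(12, 1)  # one (log-)noise level per node
    out = processor(nodes, edge_index, noise_levels)  # [12, 64]
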
--------------------------------------------------------------------------------
/graph_weather/models/gencast/sampler.py:
--------------------------------------------------------------------------------
1 | """Diffusion sampler"""
2 |
3 | import math
4 |
5 | import torch
6 |
7 | from graph_weather.models.gencast import Denoiser
8 | from graph_weather.models.gencast.utils.noise import generate_isotropic_noise
9 |
10 |
11 | class Sampler:
12 | """Sampler for the denoiser.
13 |
14 | The sampler consists of the second-order DPMSolver++2S solver (Lu et al., 2022), augmented with
15 | the stochastic churn (again making use of the isotropic noise) and noise inflation techniques
16 | used in Karras et al. (2022) to inject further stochasticity into the sampling process. In
17 | conditioning on previous timesteps it follows the Conditional Denoising Estimator approach
18 | outlined and motivated by Batzolis et al. (2021).
19 | """
20 |
21 | def __init__(
22 | self,
23 | S_noise: float = 1.05,
24 | S_tmin: float = 0.75,
25 | S_tmax: float = 80.0,
26 | S_churn: float = 2.5,
27 | r: float = 0.5,
28 | sigma_max: float = 80.0,
29 | sigma_min: float = 0.03,
30 | rho: float = 7,
31 | num_steps: int = 20,
32 | ):
33 | """Initialize the sampler.
34 |
35 | Args:
36 | S_noise (float): noise inflation parameter. Defaults to 1.05.
37 | S_tmin (float): minimum noise for sampling. Defaults to 0.75.
38 | S_tmax (float): maximum noise for sampling. Defaults to 80.
39 | S_churn (float): stochastic churn rate. Defaults to 2.5.
40 | r (float): ratio controlling the intermediate timestep of the DPMSolver++2S step. Defaults to 0.5.
41 | sigma_max (float): maximum value of sigma for sigma's distribution. Defaults to 80.
42 | sigma_min (float): minimum value of sigma for sigma's distribution. Defaults to 0.03.
43 | rho (float): exponent of the sigma's distribution. Defaults to 7.
44 | num_steps (int): number of timesteps during sampling. Defaults to 20.
45 | """
46 | self.S_noise = S_noise
47 | self.S_tmin = S_tmin
48 | self.S_tmax = S_tmax
49 | self.S_churn = S_churn
50 | self.r = r
51 | self.num_steps = num_steps
52 |
53 | self.sigma_max = sigma_max
54 | self.sigma_min = sigma_min
55 | self.rho = rho
56 |
57 | def _sigmas_fn(self, u):
58 | return (
59 | self.sigma_max ** (1 / self.rho)
60 | + u * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho))
61 | ) ** self.rho
62 |
63 | @torch.no_grad()
64 | def sample(self, denoiser: Denoiser, prev_inputs: torch.Tensor):
65 | """Generate a sample from random noise for the given inputs.
66 |
67 | Args:
68 | denoiser (Denoiser): the denoiser model.
69 | prev_inputs (torch.Tensor): previous two timesteps.
70 |
71 | Returns:
72 | torch.Tensor: normalized residuals predicted.
73 | """
74 | device = prev_inputs.device
75 |
76 | time_steps = torch.arange(0, self.num_steps).to(device) / (self.num_steps - 1)
77 | sigmas = self._sigmas_fn(time_steps)
78 |
79 | batch_ones = torch.ones(1, 1).to(device)
80 |
81 | # initialize noise
82 | x = sigmas[0] * torch.tensor(
83 | generate_isotropic_noise(
84 | num_lon=denoiser.num_lon,
85 | num_lat=denoiser.num_lat,
86 | num_samples=denoiser.output_features_dim,
87 | )
88 | ).unsqueeze(0).to(device)
89 |
90 | for i in range(len(sigmas) - 1):
91 | # stochastic churn from Karras et al. (Alg. 2)
92 | gamma = (
93 | min(self.S_churn / self.num_steps, math.sqrt(2) - 1)
94 | if self.S_tmin <= sigmas[i] <= self.S_tmax
95 | else 0.0
96 | )
97 | # noise inflation from Karras et al. (Alg. 2)
98 | noise = self.S_noise * torch.tensor(
99 | generate_isotropic_noise(
100 | num_lon=denoiser.num_lon,
101 | num_lat=denoiser.num_lat,
102 | num_samples=denoiser.output_features_dim,
103 | )
104 | )
105 | noise = noise.unsqueeze(0).to(device)
106 |
107 | sigma_hat = sigmas[i] * (gamma + 1)
108 | if gamma > 0:
109 | x = x + (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 * noise
110 | denoised = denoiser(x, prev_inputs, sigma_hat * batch_ones)
111 |
112 | if i == len(sigmas) - 2:
113 | # final Euler step
114 | d = (x - denoised) / sigma_hat
115 | x = x + d * (sigmas[i + 1] - sigma_hat)
116 | else:
117 | # DPMSolver++2S step (Alg. 1 in Lu et al.) with alpha_t=1.
118 | # t_{i-1} is t_hat because of stochastic churn!
119 | lambda_hat = -torch.log(sigma_hat)
120 | lambda_next = -torch.log(sigmas[i + 1])
121 | h = lambda_next - lambda_hat
122 | lambda_mid = lambda_hat + self.r * h
123 | sigma_mid = torch.exp(-lambda_mid)
124 |
125 | u = sigma_mid / sigma_hat * x - (torch.exp(-self.r * h) - 1) * denoised
126 | denoised_2 = denoiser(u, prev_inputs, sigma_mid * batch_ones)
127 | D = (1 - 1 / (2 * self.r)) * denoised + 1 / (2 * self.r) * denoised_2
128 | x = sigmas[i + 1] / sigma_hat * x - (torch.exp(-h) - 1) * D
129 |
130 | return x
131 |
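The noise schedule interpolates between sigma_max and sigma_min in rho-space, as in Karras et al. (2022). A quick check of the schedule's endpoints:

    import torch

    from graph_weather.models.gencast import Sampler

    sampler = Sampler(num_steps=20)
    u = torch.arange(0, 20) / 19  # same normalized time steps as in sample()
    sigmas = sampler._sigmas_fn(u)
    assert torch.isclose(sigmas[0], torch.tensor(80.0))   # starts at sigma_max
    assert torch.isclose(sigmas[-1], torch.tensor(0.03))  # ends at sigma_min
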
--------------------------------------------------------------------------------
/graph_weather/models/gencast/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Utils for gencast."""
2 |
3 | from .noise import generate_isotropic_noise
4 |
--------------------------------------------------------------------------------
/graph_weather/models/gencast/utils/batching.py:
--------------------------------------------------------------------------------
1 | """Utils for batching graphs."""
2 |
3 | import torch
4 |
5 |
6 | def batch(senders, edge_index, edge_attr=None, batch_size=1):
7 | """Build big batched graph.
8 |
9 | Returns nodes and edges of a big graph with batch_size disconnected copies of the original
10 | graph, with features shape [(b n) f].
11 |
12 | Args:
13 | senders (torch.Tensor): nodes' features.
14 | edge_index (torch.Tensor): edge index tensor.
15 | edge_attr (torch.Tensor, optional): edge attributes tensor, if None returns None.
16 | Defaults to None.
17 | batch_size (int): batch size. Defaults to 1.
18 |
19 | Returns:
20 | batched_senders, batched_edge_index, batched_edge_attr
21 | """
22 | ns = senders.shape[0]
23 | batched_senders = senders
24 | batched_edge_attr = edge_attr
25 | batched_edge_index = edge_index
26 |
27 | for i in range(1, batch_size):
28 | batched_senders = torch.cat([batched_senders, senders], dim=0)
29 | batched_edge_index = torch.cat([batched_edge_index, edge_index + i * ns], dim=1)
30 |
31 | if edge_attr is not None:
32 | batched_edge_attr = torch.cat([batched_edge_attr, edge_attr], dim=0)
33 |
34 | return batched_senders, batched_edge_index, batched_edge_attr
35 |
36 |
37 | def hetero_batch(senders, receivers, edge_index, edge_attr=None, batch_size=1):
38 | """Build big batched heterogenous graph.
39 |
40 | Returns nodes and edges of a big graph with batch_size disconnected copies of the original
41 | graph, with features shape [(b n) f].
42 |
43 | Args:
44 | senders (torch.Tensor): senders' features.
45 | receivers (torch.Tensor): receivers' features.
46 | edge_index (torch.Tensor): edge index tensor.
47 | edge_attr (torch.Tensor, optional): edge attributes tensor, if None returns None.
48 | Defaults to None.
49 | batch_size (int): batch size. Defaults to 1.
50 |
51 | Returns:
52 | batched_senders, batched_edge_index, batched_edge_attr
53 | """
54 | ns = senders.shape[0]
55 | nr = receivers.shape[0]
56 | nodes_shape = torch.tensor([[ns], [nr]]).to(edge_index)
57 | batched_senders = senders
58 | batched_receivers = receivers
59 | batched_edge_attr = edge_attr
60 | batched_edge_index = edge_index
61 |
62 | for i in range(1, batch_size):
63 | batched_senders = torch.cat([batched_senders, senders], dim=0)
64 | batched_receivers = torch.cat([batched_receivers, receivers], dim=0)
65 | batched_edge_index = torch.cat([batched_edge_index, edge_index + i * nodes_shape], dim=1)
66 | if edge_attr is not None:
67 | batched_edge_attr = torch.cat([batched_edge_attr, edge_attr], dim=0)
68 |
69 | return batched_senders, batched_receivers, batched_edge_index, batched_edge_attr
70 |
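A worked example of batch on a tiny graph, showing how the second copy's edge indices are offset by the node count:

    import torch

    from graph_weather.models.gencast.utils.batching import batch

    senders = torch.randn(3, 5)                  # 3 nodes with 5 features each
    edge_index = torch.tensor([[0, 1], [1, 2]])  # two edges: 0 -> 1 and 1 -> 2
    b_nodes, b_edges, _ = batch(senders, edge_index, batch_size=2)
    print(b_nodes.shape)  # torch.Size([6, 5])
    print(b_edges)        # tensor([[0, 1, 3, 4], [1, 2, 4, 5]]) - copy offset by ns=3
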
--------------------------------------------------------------------------------
/graph_weather/models/gencast/utils/noise.py:
--------------------------------------------------------------------------------
1 | """Noise generation utils."""
2 |
3 | import einops
4 | import numpy as np
5 | import torch
6 | import torch_harmonics as th
7 |
8 |
9 | def generate_isotropic_noise(num_lon: int, num_lat: int, num_samples=1, isotropic=True):
10 | """Generate noise on the grid.
11 |
12 | When isotropic is True it samples the equivalent of white noise on a sphere and projects it
13 | onto a grid using the Driscoll and Healy (1994) algorithm. The power spectrum is normalized
14 | to have variance 1. The grid must satisfy lons = 2 * lats or lons = 2 * (lats - 1). If
15 | isotropic is False, it samples flat normal random noise.
16 |
17 | Args:
18 | num_lon (int): number of longitudes in the grid.
19 | num_lat (int): number of latitudes in the grid.
20 | num_samples (int): number of independent samples. Defaults to 1.
21 | isotropic (bool): if true generates isotropic noise, else flat noise. Defaults to True.
22 |
23 | Returns:
24 | grid: numpy array of shape (num_lon, num_lat, num_samples).
25 | """
26 | if isotropic:
27 | if 2 * num_lat == num_lon:
28 | extend = False
29 | elif 2 * (num_lat - 1) == num_lon:
30 | extend = True
31 | else:
32 | raise ValueError(
33 | "Isotropic noise requires grid's shape to be 2N x N or 2N x (N+1): "
34 | f"got {num_lon} x {num_lat}. If the shape is correct, please specify"
35 | "isotropic=False in the constructor.",
36 | )
37 |
38 | if isotropic:
39 | lmax = num_lat - 1 if extend else num_lat
40 | mmax = lmax + 1
41 | coeffs = torch.randn(num_samples, lmax, mmax, dtype=torch.complex64) / np.sqrt(
42 | (num_lat**2) // 2
43 | )
44 | isht = th.InverseRealSHT(
45 | nlat=num_lat, nlon=num_lon, lmax=lmax, mmax=mmax, grid="equiangular"
46 | )
47 | noise = isht(coeffs) * np.sqrt(2 * np.pi)
48 | noise = einops.rearrange(noise, "b lat lon -> lon lat b").numpy()
49 | else:
50 | noise = np.random.randn(num_lon, num_lat, num_samples)
51 | return noise
52 |
53 |
54 | def sample_noise_level(sigma_min=0.02, sigma_max=88, rho=7):
55 | """Generate random sample of noise level.
56 |
57 | Sample a noise level according to the distribution described in the paper.
58 | Notice that the default values are valid only for training and need to be
59 | modified for sampling.
60 |
61 | Args:
62 | sigma_min (float, optional): Defaults to 0.02.
63 | sigma_max (int, optional): Defaults to 88.
64 | rho (int, optional): Defaults to 7.
65 |
66 | Returns:
67 | noise_level: single sample of noise level.
68 | """
69 | u = np.random.random()
70 | noise_level = (
71 | sigma_max ** (1 / rho) + u * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
72 | ) ** rho
73 | return noise_level
74 |
75 |
76 | class Preconditioner(torch.nn.Module):
77 | """Collection of preconditioning functions.
78 |
79 | These functions are described in Karras (2022), table 1.
80 | """
81 |
82 | def __init__(self, sigma_data: float = 1):
83 | """Initialize the preconditioning functions.
84 |
85 | Args:
86 | sigma_data (float): Karras suggests 0.5, GenCast 1. Defaults to 1.
87 | """
88 | super().__init__()
89 | self.sigma_data = sigma_data
90 |
91 | def c_skip(self, sigma):
92 | """Scaling factor for skip connection."""
93 | return self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
94 |
95 | def c_out(self, sigma):
96 | """Scaling factor for output."""
97 | return sigma * self.sigma_data / torch.sqrt(sigma**2 + self.sigma_data**2)
98 |
99 | def c_in(self, sigma):
100 | """Scaling factor for input."""
101 | return 1 / torch.sqrt(sigma**2 + self.sigma_data**2)
102 |
103 | def c_noise(self, sigma):
104 | """Scaling factor for noise level."""
105 | return 1 / 4 * torch.log(sigma)
106 |
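A short usage sketch for the isotropic case; 64 = 2 * 32 satisfies the shape requirement, and the variance claim follows the normalization described in the docstring:

    import numpy as np

    from graph_weather.models.gencast.utils.noise import generate_isotropic_noise

    noise = generate_isotropic_noise(num_lon=64, num_lat=32, num_samples=4)
    print(noise.shape)   # (64, 32, 4)
    print(np.var(noise)) # close to 1 by construction
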
--------------------------------------------------------------------------------
/graph_weather/models/gencast/utils/statistics.py:
--------------------------------------------------------------------------------
1 | """Statistics computation utils."""
2 |
3 | import apache_beam # noqa: F401
4 | import numpy as np
5 | import weatherbench2 # noqa: F401
6 | import xarray as xr
7 |
8 |
9 | def compute_statistics(dataset, vars, num_samples=100, single=False):
10 | """Compute statistics for single timestep.
11 |
12 | Args:
13 | dataset: xarray dataset.
14 | vars: list of features.
15 | num_samples (int, optional): number of random timesteps to sample. Defaults to 100.
16 | single (bool, optional): True if the features are single-level (no pressure levels). Defaults to False.
17 |
18 | Returns:
19 | means: dict with the means.
20 | stds: dict with the stds.
21 | """
22 | means = {}
23 | stds = {}
24 | for var in vars:
25 | print(f"Computing statistics for {var}")
26 | random_indexes = np.random.randint(0, len(dataset.time), num_samples)
27 | samples = dataset.isel(time=random_indexes)[var].values
28 | samples = np.nan_to_num(samples)
29 | axis_tuple = (0, 1, 2) if single else (0, 2, 3)
30 | means[var] = samples.mean(axis=axis_tuple)
31 | stds[var] = samples.std(axis=axis_tuple)
32 | return means, stds
33 |
34 |
35 | def compute_statistics_diff(dataset, vars, num_samples=100, single=False, timestep=2):
36 | """Compute statistics for difference of two timesteps.
37 |
38 | Args:
39 | dataset: xarray dataset.
40 | vars: list of features.
41 | num_samples (int, optional): number of random timesteps to sample. Defaults to 100.
42 | single (bool, optional): True if the features are single-level (no pressure levels). Defaults to False.
43 | timestep (int, optional): number of steps to consider between start and end. Defaults to 2.
44 |
45 | Returns:
46 | means: dict with the means.
47 | stds: dict with the stds.
48 | """
49 | means = {}
50 | stds = {}
51 | for var in vars:
52 | print(f"Computing statistics for {var}")
53 | random_indexes = np.random.randint(0, len(dataset.time), num_samples)
54 | samples_start = dataset.isel(time=random_indexes)[var].values
55 | samples_start = np.nan_to_num(samples_start)
56 | samples_end = dataset.isel(time=random_indexes + timestep)[var].values
57 | samples_end = np.nan_to_num(samples_end)
58 | axis_tuple = (0, 1, 2) if single else (0, 2, 3)
59 | means[var] = (samples_end - samples_start).mean(axis=axis_tuple)
60 | stds[var] = (samples_end - samples_start).std(axis=axis_tuple)
61 | return means, stds
62 |
63 |
64 | atmospheric_features = [
65 | "geopotential",
66 | "specific_humidity",
67 | "temperature",
68 | "u_component_of_wind",
69 | "v_component_of_wind",
70 | "vertical_velocity",
71 | ]
72 |
73 | single_features = [
74 | "2m_temperature",
75 | "10m_u_component_of_wind",
76 | "10m_v_component_of_wind",
77 | "mean_sea_level_pressure",
78 | # "sea_surface_temperature",
79 | "total_precipitation_12hr",
80 | ]
81 |
82 | static_features = ["geopotential_at_surface", "land_sea_mask"]
83 |
84 | obs_path = "gs://weatherbench2/datasets/era5/1959-2022-6h-64x32_equiangular_conservative.zarr"
85 | # obs_path = 'gs://weatherbench2/datasets/era5/1959-2022-6h-1440x721.zarr'
86 | # obs_path = 'gs://weatherbench2/datasets/era5/1959-2022-6h-512x256_equiangular_conservative.zarr'
87 | data = xr.open_zarr(obs_path)
88 | num_samples = 100
89 | means, stds = compute_statistics_diff(data, single_features, num_samples=num_samples, single=True)
90 | print("Means: ", means)
91 | print("Stds: ", stds)
92 |
--------------------------------------------------------------------------------
/graph_weather/models/gencast/weighted_mse_loss.py:
--------------------------------------------------------------------------------
1 | """The weighted loss function for GenCast training."""
2 |
3 | from typing import Optional
4 |
5 | import numpy as np
6 | import torch
7 |
8 |
9 | class WeightedMSELoss(torch.nn.Module):
10 | """Module WeightedMSELoss.
11 |
12 | This module implements the loss described in GenCast's paper.
13 | """
14 |
15 | def __init__(
16 | self,
17 | grid_lat: Optional[torch.Tensor] = None,
18 | pressure_levels: Optional[torch.Tensor] = None,
19 | num_atmospheric_features: Optional[int] = None,
20 | single_features_weights: Optional[torch.Tensor] = None,
21 | ):
22 | """Initialize the WeightedMSELoss Module.
23 |
24 | More details about the features weights are reported in GraphCast's paper. In short, if the
25 | single features are "2m_temperature", "10m_u_component_of_wind", "10m_v_component_of_wind",
26 | "mean_sea_level_pressure" and "total_precipitation_12hr", then it's suggested to set
27 | corresponding weights as 1, 0.1, 0.1, 0.1 and 0.1.
28 |
29 | Args:
30 | grid_lat (torch.Tensor, optional): 1D tensor containing all the latitudes.
31 | pressure_levels (torch.Tensor, optional): 1D tensor containing all the pressure levels
32 | per variable.
33 | num_atmospheric_features (int, optional): number of atmospheric features.
34 | single_features_weights (torch.Tensor, optional): 1D tensor containing single features
35 | weights.
36 | """
37 | super().__init__()
38 |
39 | area_weights = None
40 | features_weights = None
41 |
42 | if grid_lat is not None:
43 | area_weights = torch.abs(torch.cos(grid_lat * np.pi / 180.0))
44 | area_weights = area_weights / torch.mean(area_weights)
45 | if (
46 | pressure_levels is not None
47 | and num_atmospheric_features is not None
48 | and single_features_weights is not None
49 | ):
50 | pressure_weights = pressure_levels / torch.sum(pressure_levels)
51 | features_weights = torch.cat(
52 | (pressure_weights.repeat(num_atmospheric_features), single_features_weights), dim=-1
53 | )
54 | elif (
55 | pressure_levels is not None
56 | or num_atmospheric_features is not None
57 | or single_features_weights is not None
58 | ):
59 | raise ValueError(
60 | "To use feature weights, please provide all three: pressure_levels, "
61 | "num_atmospheric_features and single_features_weights."
62 | )
63 |
64 | self.sigma_data = 1 # assuming normalized data!
65 |
66 | self.register_buffer("area_weights", area_weights, persistent=False)
67 | self.register_buffer("features_weights", features_weights, persistent=False)
68 |
69 | def _lambda_sigma(self, noise_level):
70 | noise_weights = (noise_level**2 + self.sigma_data**2) / (noise_level * self.sigma_data) ** 2
71 | return noise_weights # [batch, 1]
72 |
73 | def forward(
74 | self,
75 | pred: torch.Tensor,
76 | noise_level: torch.Tensor,
77 | target: torch.Tensor,
78 | ) -> torch.Tensor:
79 | """Compute the loss.
80 |
81 | Args:
82 | pred (torch.Tensor): prediction of the model [batch, lon, lat, var].
83 | noise_level (torch.Tensor): noise levels fed to the model for the corresponding
84 | predictions [batch, 1].
85 | target (torch.Tensor): target tensor [batch, lon, lat, var].
86 |
87 | Returns:
88 | torch.Tensor: weighted MSE loss.
89 | """
90 | # check shapes
91 | if not (pred.shape == target.shape):
92 | raise ValueError(
93 | "Predictions and targets must have same shape. The actual shapes "
94 | f"are {pred.shape} and {target.shape}."
95 | )
96 | if not (len(pred.shape) == 4):
97 | raise ValueError(
98 | "The expected shape for predictions and targets is "
99 | f"[batch, lon, lat, var], but got {pred.shape}."
100 | )
101 | if not (noise_level.shape == (pred.shape[0], 1)):
102 | raise ValueError(
103 | f"The expected shape for noise levels is [batch, 1], but got {noise_level.shape}."
104 | )
105 |
106 | # compute square residuals
107 | loss = (pred - target) ** 2 # [batch, lon, lat, var]
108 | if torch.isnan(loss).any():
109 | raise ValueError("NaN values encountered in loss calculation.")
110 |
111 | # apply area and features weights to residuals
112 | if self.area_weights is not None:
113 | if not (len(self.area_weights) == pred.shape[2]):
114 | raise ValueError(
115 | f"The size of grid_lat at initialization ({len(self.area_weights)}) "
116 | f"and the number of latitudes in predictions ({pred.shape[2]}) "
117 | "don't match."
118 | )
119 | loss *= self.area_weights[None, None, :, None]
120 |
121 | if self.features_weights is not None:
122 | if not (len(self.features_weights) == pred.shape[-1]):
123 | raise ValueError(
124 | f"The size of features weights at initialization ({len(self.features_weights)})"
125 | f" and the number of features in predictions ({pred.shape[-1]}) "
126 | "don't match."
127 | )
128 | loss *= self.features_weights[None, None, None, :]
129 |
130 | # compute means across lon, lat, var for each sample in the batch
131 | loss = loss.flatten(1).mean(-1) # [batch]
132 |
133 | # weight each sample using the corresponding noise level, then return the mean.
134 | loss *= self._lambda_sigma(noise_level).flatten()
135 | return loss.mean()
136 |
--------------------------------------------------------------------------------
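A minimal usage sketch of WeightedMSELoss (shapes mirror tests/test_gencast.py; all sizes here are illustrative). Internally, each sample's mean weighted residual is further scaled by lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2 with sigma_data = 1, the noise weighting for normalized data:

import torch

from graph_weather.models.gencast import WeightedMSELoss

# 13 pressure levels x 6 atmospheric variables + 5 single-level variables = 83 features
pressure_levels = torch.tensor(
    [50.0, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
)
single_features_weights = torch.tensor([1.0, 0.1, 0.1, 0.1, 0.1])
loss_fn = WeightedMSELoss(
    grid_lat=torch.arange(-90, 90, 1.0),  # 180 latitudes
    pressure_levels=pressure_levels,
    num_atmospheric_features=6,
    single_features_weights=single_features_weights,
)

features_dim = 13 * 6 + 5
pred = torch.rand(3, 360, 180, features_dim)  # [batch, lon, lat, var]
target = torch.rand(3, 360, 180, features_dim)
noise_level = torch.rand(3, 1)  # one noise level per batch element
print(loss_fn(pred, noise_level, target))  # scalar tensor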
/graph_weather/models/layers/__init__.py:
--------------------------------------------------------------------------------
1 | """Layers for use in models"""
2 |
--------------------------------------------------------------------------------
/graph_weather/models/layers/assimilator_decoder.py:
--------------------------------------------------------------------------------
1 | """Decoders to decode from the Processor graph to the original graph with updated values
2 |
3 | In the original paper the decoder is described as
4 |
5 | The Decoder maps back to physical data defined on a latitude/longitude grid. The underlying graph is
6 | again bipartite, this time mapping icosahedron→lat/lon.
7 | The inputs to the Decoder come from the Processor, plus a skip connection back to the original
8 | state of the 78 atmospheric variables on the latitude/longitude grid.
9 | The output of the Decoder is the predicted 6-hour change in the 78 atmospheric variables,
10 | which is then added to the initial state to produce the new state. We found 6 hours to be a good
11 | balance between shorter time steps (simpler dynamics to model but more iterations required during
12 | rollout) and longer time steps (fewer iterations required during rollout but modeling
13 | more complex dynamics)
14 |
15 | """
16 |
17 | import einops
18 | import h3
19 | import numpy as np
20 | import torch
21 | from torch_geometric.data import Data
22 |
23 | from graph_weather.models.layers.graph_net_block import MLP, GraphProcessor
24 |
25 |
26 | class AssimilatorDecoder(torch.nn.Module):
27 | """Assimilator graph module"""
28 |
29 | def __init__(
30 | self,
31 | lat_lons: list,
32 | resolution: int = 2,
33 | input_dim: int = 256,
34 | output_dim: int = 78,
35 | output_edge_dim: int = 256,
36 | hidden_dim_processor_node: int = 256,
37 | hidden_dim_processor_edge: int = 256,
38 | hidden_layers_processor_node: int = 2,
39 | hidden_layers_processor_edge: int = 2,
40 | mlp_norm_type: str = "LayerNorm",
41 | hidden_dim_decoder: int = 128,
42 | hidden_layers_decoder: int = 2,
43 | use_checkpointing: bool = False,
44 | ):
45 | """
46 | Decoder from latent graph to lat/lon graph for assimilation of observation
47 |
48 | Args:
49 | lat_lons: List of (lat,lon) points
50 | resolution: H3 resolution level
51 | input_dim: Input node dimension
52 | output_dim: Output node dimension
53 | output_edge_dim: Edge dimension
54 | hidden_dim_processor_node: Hidden dimension of the node processors
55 | hidden_dim_processor_edge: Hidden dimension of the edge processors
56 | hidden_layers_processor_node: Number of hidden layers in the node processors
57 | hidden_layers_processor_edge: Number of hidden layers in the edge processors
58 | hidden_dim_decoder: Number of hidden dimensions in the decoder
59 | hidden_layers_decoder: Number of layers in the decoder
60 | mlp_norm_type: Type of norm for the MLPs
61 | one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
62 | use_checkpointing: Whether to use gradient checkpointing to reduce model size
63 | """
64 | super().__init__()
65 | self.use_checkpointing = use_checkpointing
66 | self.num_latlons = len(lat_lons)
67 | self.base_h3_grid = sorted(list(h3.uncompact(h3.get_res0_indexes(), resolution)))
68 | self.num_h3 = len(self.base_h3_grid)
69 | self.h3_grid = [h3.geo_to_h3(lat, lon, resolution) for lat, lon in lat_lons]
70 | self.h3_to_index = {}
71 | h_index = len(self.base_h3_grid)
72 | for h in self.base_h3_grid:
73 | if h not in self.h3_to_index:
74 | h_index -= 1
75 | self.h3_to_index[h] = h_index
76 | self.h3_mapping = {}
77 | for h, value in enumerate(self.h3_grid):
78 | self.h3_mapping[h + self.num_h3] = value
79 |
80 | # Build the default graph
81 | # Extra starting ones for appending to inputs, could 'learn' good starting points
82 | self.latlon_nodes = torch.zeros((len(lat_lons), input_dim), dtype=torch.float)
83 | # Get connections between lat/lon nodes and h3 nodes. TODO: the paper implies the 3
84 | # closest mesh points map to each lat/lon point: do a k-ring of 1 around the current
85 | # h3 cell, compute the distance from each of those points to the lat/lon point, and
86 | # keep the nearest N (3). For simplicity, include them all with their distances.
87 | edge_sources = []
88 | edge_targets = []
89 | self.h3_to_lat_distances = []
90 | for node_index, h_node in enumerate(self.h3_grid):
91 | # Get h3 index
92 | h_points = h3.k_ring(self.h3_mapping[node_index + self.num_h3], 1)
93 | for h in h_points:
94 | distance = h3.point_dist(lat_lons[node_index], h3.h3_to_geo(h), unit="rads")
95 | self.h3_to_lat_distances.append([np.sin(distance), np.cos(distance)])
96 | edge_sources.append(self.h3_to_index[h])
97 | edge_targets.append(node_index + self.num_h3)
98 | edge_index = torch.tensor([edge_sources, edge_targets], dtype=torch.long)
99 | self.h3_to_lat_distances = torch.tensor(self.h3_to_lat_distances, dtype=torch.float)
100 |
101 | # Use a normal graph as it's a bit simpler
102 | self.graph = Data(edge_index=edge_index, edge_attr=self.h3_to_lat_distances)
103 |
104 | self.edge_encoder = MLP(
105 | 2, output_edge_dim, hidden_dim_processor_edge, 2, mlp_norm_type, self.use_checkpointing
106 | )
107 | self.graph_processor = GraphProcessor(
108 | mp_iterations=1,
109 | in_dim_node=input_dim,
110 | in_dim_edge=output_edge_dim,
111 | hidden_dim_node=hidden_dim_processor_node,
112 | hidden_dim_edge=hidden_dim_processor_edge,
113 | hidden_layers_node=hidden_layers_processor_node,
114 | hidden_layers_edge=hidden_layers_processor_edge,
115 | norm_type=mlp_norm_type,
116 | )
117 | self.node_decoder = MLP(
118 | input_dim,
119 | output_dim,
120 | hidden_dim_decoder,
121 | hidden_layers_decoder,
122 | None,
123 | self.use_checkpointing,
124 | )
125 |
126 | def forward(self, processor_features: torch.Tensor, batch_size: int) -> torch.Tensor:
127 | """
128 | Decodes the processed latent features back onto the lat/lon grid
129 |
130 | Args:
131 | processor_features: Processed features in shape [B*Nodes, Features]
132 | batch_size: Batch size
133 |
134 | Returns:
135 | Updated features for model
136 | """
137 | self.graph = self.graph.to(processor_features.device)
138 | edge_attr = self.edge_encoder(self.graph.edge_attr) # Update attributes based on distance
139 | edge_attr = einops.repeat(edge_attr, "e f -> (repeat e) f", repeat=batch_size)
140 |
141 | edge_index = torch.cat(
142 | [
143 | self.graph.edge_index + i * torch.max(self.graph.edge_index) + i
144 | for i in range(batch_size)
145 | ],
146 | dim=1,
147 | )
148 |
149 | # Re-add lat/lon nodes to match the graph node count
150 | self.latlon_nodes = self.latlon_nodes.to(processor_features.device)
151 | features = einops.rearrange(processor_features, "(b n) f -> b n f", b=batch_size)
152 | features = torch.cat(
153 | [features, einops.repeat(self.latlon_nodes, "n f -> b n f", b=batch_size)], dim=1
154 | )
155 | features = einops.rearrange(features, "b n f -> (b n) f")
156 |
157 | out, _ = self.graph_processor(features, edge_index, edge_attr) # Message Passing
158 | # Remove the h3 nodes now, only want the latlon ones
159 | out = self.node_decoder(out) # Decode to 78 from 256
160 | out = einops.rearrange(out, "(b n) f -> b n f", b=batch_size)
161 | _, out = torch.split(out, [self.num_h3, self.num_latlons], dim=1)
162 | return out
163 |
--------------------------------------------------------------------------------
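A shape-level sketch of driving AssimilatorDecoder on its own (sizes illustrative; resolution=1 keeps the H3 mesh small, and decoder.num_h3 is the number of latent mesh nodes the processor must supply):

import torch

from graph_weather.models.layers.assimilator_decoder import AssimilatorDecoder

# Coarse 18x36 lat/lon grid (10-degree spacing), purely illustrative
lat_lons = [(lat, lon) for lat in range(-85, 90, 10) for lon in range(0, 360, 10)]
decoder = AssimilatorDecoder(lat_lons, resolution=1, input_dim=256, output_dim=78)

batch_size = 2
# The processor emits one 256-dim feature vector per H3 mesh node
processor_features = torch.randn(batch_size * decoder.num_h3, 256)
out = decoder(processor_features, batch_size)
print(out.shape)  # torch.Size([2, 648, 78]) -> [batch, lat/lon points, variables]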
/graph_weather/models/layers/constraint_layer.py:
--------------------------------------------------------------------------------
1 | """Module for physical constraint layers used in graph weather models.
2 |
3 | This module implements several constraints on a network’s intermediate outputs,
4 | ensuring physical consistency with an input at a lower resolution.
5 |
6 | """
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 |
12 | class PhysicalConstraintLayer(nn.Module):
13 | """
14 |
15 | This module implements several constraint types on the network’s intermediate outputs ỹ,
16 | given the corresponding low-resolution input x. The following equations are implemented
17 | (with all operations acting per patch – here, a patch is the full grid of H×W pixels):
18 |
19 | Additive constraint:
20 | y = ỹ + x - avg(ỹ)
21 |
22 | Multiplicative constraint:
23 | y = ỹ * ( x / avg(ỹ) )
24 |
25 | Softmax constraint:
26 | y = exp(ỹ) * ( x / sum(exp(ỹ)) )
27 |
28 | We assume that both the intermediate outputs and the low-resolution reference are 4D
29 | tensors in grid format, with shape [B, C, H, W], where n = H*W is the number of pixels
30 | (or nodes) in a patch.
31 | """
32 |
33 | def __init__(
34 | self, model, grid_shape, upsampling_factor, constraint_type="none", exp_factor=1.0
35 | ):
36 | """Initialize the PhysicalConstraintLayer.
37 |
38 | Args:
39 | model (nn.Module): The model containing the helper methods
40 | 'graph_to_grid' and 'grid_to_graph'.
41 | grid_shape (tuple): Expected spatial dimensions (H, W) of the
42 | high-resolution grid.
43 | upsampling_factor (int): Factor by which the low-resolution grid is upsampled.
44 | constraint_type (str, optional): The constraint to apply. Options are
45 | 'additive', 'multiplicative', or 'softmax'. Defaults to "none".
46 | exp_factor (float, optional): Exponent factor for the softmax constraint.
47 | Defaults to 1.0.
48 | """
49 | super().__init__()
50 | self.model = model
51 | self.constraint_type = constraint_type
52 | self.grid_shape = grid_shape
53 | self.exp_factor = exp_factor
54 | self.upsampling_factor = upsampling_factor
55 | self.pool = nn.AvgPool2d(kernel_size=upsampling_factor)
56 |
57 | def forward(self, hr_graph, lr_graph):
58 | """Apply the selected physical constraint.
59 |
60 | Processes the high-resolution output and low-resolution input by converting
61 | between graph and grid formats as needed, and then applying the specified constraint.
62 |
63 | Args:
64 | hr_graph (torch.Tensor): High-resolution model output in either graph (3D)
65 | or grid (4D) format.
66 | lr_graph (torch.Tensor): Low-resolution input in the corresponding
67 | graph or grid format.
68 |
69 | Returns:
70 | torch.Tensor: The adjusted output in graph format.
71 | """
72 | # Check if inputs are in graph (3D) or grid (4D) formats.
73 | if hr_graph.dim() == 3:
74 | # Convert graph format to grid format
75 | hr_grid = self.model.graph_to_grid(hr_graph)
76 | lr_grid = self.model.graph_to_grid(lr_graph)
77 | elif hr_graph.dim() == 4:
78 | # Already in grid format: [B, C, H, W]
79 | _, _, H, W = hr_graph.shape
80 | if (H, W) != self.grid_shape:
81 | raise ValueError(f"Expected spatial dimensions {self.grid_shape}, got {(H, W)}")
82 | hr_grid = hr_graph
83 | lr_grid = lr_graph
84 | else:
85 | raise ValueError("Input tensor must be either 3D (graph) or 4D (grid).")
86 |
87 | # Apply constraint based on type in grid format
88 | if self.constraint_type == "additive":
89 | result = self.additive_constraint(hr_grid, lr_grid)
90 | elif self.constraint_type == "multiplicative":
91 | result = self.multiplicative_constraint(hr_grid, lr_grid)
92 | elif self.constraint_type == "softmax":
93 | result = self.softmax_constraint(hr_grid, lr_grid)
94 | else:
95 | raise ValueError(f"Unknown constraint type: {self.constraint_type}")
96 |
97 | # Convert grid back to graph format
98 | return self.model.grid_to_graph(result)
99 |
100 | def additive_constraint(self, hr, lr):
101 | """Enforces local conservation using an additive correction:
102 | y = ỹ + ( x - avg(ỹ) )
103 | where avg(ỹ) is computed per patch (via an average-pooling layer).
104 |
105 | For the additive constraint we follow the paper’s formulation using a Kronecker
106 | product to expand the discrepancy between the low-resolution field and the
107 | average of the high-resolution output.
108 |
109 | hr: high-resolution tensor [B, C, H_hr, W_hr]
110 | lr: low-resolution tensor [B, C, h_lr, w_lr]
111 | (with H_hr = upsampling_factor * h_lr & W_hr = upsampling_factor * w_lr)
112 | """
113 | # Convert grids to graph format using model's mapping
114 | hr_graph = self.model.grid_to_graph(hr)
115 | lr_graph = self.model.grid_to_graph(lr)
116 |
117 | # Apply constraint logic
118 | # Compute average over NODES
119 | avg_hr = hr_graph.mean(dim=1, keepdim=True)
120 | diff = lr_graph - avg_hr
121 |
122 | # Expand difference using spatial mapping
123 | diff_expanded = diff.repeat(1, self.upsampling_factor**2, 1)
124 |
125 | # Apply correction and convert back to GRID format
126 | adjusted_graph = hr_graph + diff_expanded
127 | return self.model.graph_to_grid(adjusted_graph)
128 |
129 | def multiplicative_constraint(self, hr, lr):
130 | """Enforce conservation using a multiplicative correction in graph space.
131 |
132 | The correction is applied by scaling the high-resolution output by a ratio computed
133 | from the low-resolution input and the average of the high-resolution output.
134 |
135 | Args:
136 | hr (torch.Tensor): High-resolution tensor in grid format [B, C, H_hr, W_hr].
137 | lr (torch.Tensor): Low-resolution tensor in grid format [B, C, h_lr, w_lr].
138 |
139 | Returns:
140 | torch.Tensor: Adjusted high-resolution tensor in grid format.
141 | """
142 | # Convert grids to graph format using model's mapping
143 | hr_graph = self.model.grid_to_graph(hr)
144 | lr_graph = self.model.grid_to_graph(lr)
145 |
146 | # Apply constraint logic
147 | # Compute average over NODES
148 | avg_hr = hr_graph.mean(dim=1, keepdim=True)
149 | lr_patch_avg = lr_graph.mean(dim=1, keepdim=True)
150 |
151 | # Compute ratio and expand to match HR graph structure
152 | ratio = lr_patch_avg / (avg_hr + 1e-8)
153 |
154 | # Apply multiplicative correction and convert back to GRID format
155 | adjusted_graph = hr_graph * ratio
156 | return self.model.graph_to_grid(adjusted_graph)
157 |
158 | def softmax_constraint(self, y, lr):
159 | """Apply a softmax-based constraint correction.
160 |
161 | The softmax correction scales the exponentiated high-resolution output so that the
162 | sum over spatial blocks matches the low-resolution reference.
163 |
164 | Args:
165 | y (torch.Tensor): High-resolution tensor in grid format [B, C, H, W].
166 | lr (torch.Tensor): Low-resolution tensor in grid format [B, C, h, w].
167 |
168 | Returns:
169 | torch.Tensor: Adjusted high-resolution tensor in grid format after applying
170 | the softmax constraint.
171 | """
172 | # Apply the exponential function
173 | y = torch.exp(self.exp_factor * y)
174 |
175 | # Pool over spatial blocks
176 | kernel_area = self.upsampling_factor**2
177 | sum_y = self.pool(y) * kernel_area
178 |
179 | # Ensure that lr * (1/sum_y) is contiguous
180 | ratio = (lr * (1 / sum_y)).contiguous()
181 |
182 | # Use device of lr for kron expansion:
183 | device = lr.device
184 | expansion = torch.ones((self.upsampling_factor, self.upsampling_factor), device=device)
185 |
186 | # Expand the low-resolution ratio and correct the y values so that the block sum matches lr.
187 | out = y * torch.kron(ratio, expansion)
188 | return out
189 |
--------------------------------------------------------------------------------
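To make the conservation property concrete, here is a self-contained sketch of the softmax constraint applied directly to grid tensors (the graph/grid conversions delegated to `model` are skipped; sizes are illustrative). After the correction, each low-resolution cell equals the sum of its 2x2 high-resolution block:

import torch
import torch.nn as nn

torch.manual_seed(0)
upsampling_factor = 2
hr = torch.rand(1, 1, 4, 4)  # raw network output y~, grid format [B, C, H, W]
lr = torch.rand(1, 1, 2, 2)  # low-resolution reference x [B, C, h, w]

# Softmax constraint: y = exp(y~) * x / sum(exp(y~)), per 2x2 block
y = torch.exp(1.0 * hr)  # exp_factor = 1.0
pool = nn.AvgPool2d(kernel_size=upsampling_factor)
sum_y = pool(y) * upsampling_factor**2  # per-block sum of exp(y~)
ratio = lr / sum_y  # per-block scaling factor
expansion = torch.ones(upsampling_factor, upsampling_factor)
out = y * torch.kron(ratio, expansion)  # broadcast each factor over its block

# Each low-resolution cell now equals the sum of its high-resolution block
print(torch.allclose(pool(out) * upsampling_factor**2, lr, atol=1e-6))  # True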
/graph_weather/models/layers/decoder.py:
--------------------------------------------------------------------------------
1 | """Decoders to decode from the Processor graph to the original graph with updated values
2 |
3 | In the original paper the decoder is described as
4 |
5 | The Decoder maps back to physical data defined on a latitude/longitude grid. The underlying graph is
6 | again bipartite, this time mapping icosahedron→lat/lon.
7 | The inputs to the Decoder come from the Processor, plus a skip connection back to the original
8 | state of the 78 atmospheric variables on the latitude/longitude grid.
9 | The output of the Decoder is the predicted 6-hour change in the 78 atmospheric variables,
10 | which is then added to the initial state to produce the new state. We found 6 hours to be a good
11 | balance between shorter time steps (simpler dynamics to model but more iterations required during
12 | rollout) and longer time steps (fewer iterations required during rollout but modeling
13 | more complex dynamics)
14 |
15 | """
16 |
17 | import torch
18 |
19 | from graph_weather.models.layers.assimilator_decoder import AssimilatorDecoder
20 |
21 |
22 | class Decoder(AssimilatorDecoder):
23 | """Decoder graph module"""
24 |
25 | def __init__(
26 | self,
27 | lat_lons,
28 | resolution: int = 2,
29 | input_dim: int = 256,
30 | output_dim: int = 78,
31 | output_edge_dim: int = 256,
32 | hidden_dim_processor_node: int = 256,
33 | hidden_dim_processor_edge: int = 256,
34 | hidden_layers_processor_node: int = 2,
35 | hidden_layers_processor_edge: int = 2,
36 | mlp_norm_type: str = "LayerNorm",
37 | hidden_dim_decoder: int = 128,
38 | hidden_layers_decoder: int = 2,
39 | use_checkpointing: bool = False,
40 | ):
41 | """
42 | Decoder from latent graph to lat/lon graph
43 |
44 | Args:
45 | lat_lons: List of (lat,lon) points
46 | resolution: H3 resolution level
47 | input_dim: Input node dimension
48 | output_dim: Output node dimension
49 | output_edge_dim: Edge dimension
50 | hidden_dim_processor_node: Hidden dimension of the node processors
51 | hidden_dim_processor_edge: Hidden dimension of the edge processors
52 | hidden_layers_processor_node: Number of hidden layers in the node processors
53 | hidden_layers_processor_edge: Number of hidden layers in the edge processors
54 | hidden_dim_decoder: Number of hidden dimensions in the decoder
55 | hidden_layers_decoder: Number of layers in the decoder
56 | mlp_norm_type: Type of norm for the MLPs
57 | one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
58 | use_checkpointing: Whether to use gradient checkpointing or not
59 | """
60 | super().__init__(
61 | lat_lons,
62 | resolution,
63 | input_dim,
64 | output_dim,
65 | output_edge_dim,
66 | hidden_dim_processor_node,
67 | hidden_dim_processor_edge,
68 | hidden_layers_processor_node,
69 | hidden_layers_processor_edge,
70 | mlp_norm_type,
71 | hidden_dim_decoder,
72 | hidden_layers_decoder,
73 | use_checkpointing,
74 | )
75 |
76 | def forward(
77 | self, processor_features: torch.Tensor, start_features: torch.Tensor
78 | ) -> torch.Tensor:
79 | """
80 | Decodes the latent graph back to the lat/lon grid, with a residual connection
81 |
82 | Args:
83 | processor_features: Processed features in shape [B*Nodes, Features]
84 | start_features: Original input features to the encoder, with shape [B, Nodes, Features]
85 |
86 | Returns:
87 | Updated features for model
88 | """
89 | out = super().forward(processor_features, start_features.shape[0])
90 | out = out + start_features # residual connection
91 | return out
92 |
--------------------------------------------------------------------------------
/graph_weather/models/layers/grid_to_points.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openclimatefix/graph_weather/abf13a36e95ba6b699027e1d4c6760760aae280f/graph_weather/models/layers/grid_to_points.py
--------------------------------------------------------------------------------
/graph_weather/models/layers/points_to_grid.py:
--------------------------------------------------------------------------------
1 | """
2 | This is designed to abstract away the lat/lon points, acting as a layer that can be put in front of models that expect a regular lat/lon grid as input.
3 | """
4 |
5 |
--------------------------------------------------------------------------------
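The module body is still a stub. As an illustration of the stated intent only, a hypothetical layer might average scattered point features into the cells of a regular grid; the class name, signature, and bucketing strategy below are assumptions for illustration, not this repository's API:

import torch
import torch.nn as nn


class PointsToGrid(nn.Module):
    """Hypothetical sketch: average scattered point features into grid cells."""

    def __init__(self, grid_lats: torch.Tensor, grid_lons: torch.Tensor):
        super().__init__()
        self.grid_lats = grid_lats  # [H], ascending cell-edge latitudes
        self.grid_lons = grid_lons  # [W], ascending cell-edge longitudes

    def forward(self, lat_lons: torch.Tensor, features: torch.Tensor) -> torch.Tensor:
        # lat_lons: [N, 2] point coordinates, features: [N, F] -> grid [H, W, F]
        H, W, F = len(self.grid_lats), len(self.grid_lons), features.shape[-1]
        # Bucket each point into the grid cell containing it
        lat_idx = torch.bucketize(lat_lons[:, 0], self.grid_lats).clamp(max=H - 1)
        lon_idx = torch.bucketize(lat_lons[:, 1], self.grid_lons).clamp(max=W - 1)
        flat_idx = lat_idx * W + lon_idx
        # Sum features per cell, then divide by the per-cell point count
        grid = torch.zeros(H * W, F).index_add_(0, flat_idx, features)
        counts = torch.zeros(H * W).index_add_(0, flat_idx, torch.ones(len(features)))
        return (grid / counts.clamp(min=1).unsqueeze(-1)).view(H, W, F)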
/graph_weather/models/layers/processor.py:
--------------------------------------------------------------------------------
1 | """Processor for the latent graph
2 |
3 | In the original paper the processor is described as
4 |
5 | The Processor iteratively processes the 256-channel latent feature data on the icosahedron grid
6 | using 9 rounds of message-passing GNNs. During each round, a node exchanges information with itself
7 | and its immediate neighbors. There are residual connections between each round of processing.
8 |
9 | """
10 |
11 | import torch
12 |
13 | from graph_weather.models.layers.graph_net_block import GraphProcessor
14 |
15 |
16 | class Processor(torch.nn.Module):
17 | """Processor for latent graphD"""
18 |
19 | def __init__(
20 | self,
21 | input_dim: int = 256,
22 | edge_dim: int = 256,
23 | num_blocks: int = 9,
24 | hidden_dim_processor_node: int = 256,
25 | hidden_dim_processor_edge: int = 256,
26 | hidden_layers_processor_node: int = 2,
27 | hidden_layers_processor_edge: int = 2,
28 | mlp_norm_type: str = "LayerNorm",
29 | ):
30 | """
31 | Latent graph processor
32 |
33 | Args:
34 | input_dim: Input dimension for the node
35 | edge_dim: Edge input dimension
36 | num_blocks: Number of message passing blocks
37 | hidden_dim_processor_node: Hidden dimension of the node processors
38 | hidden_dim_processor_edge: Hidden dimension of the edge processors
39 | hidden_layers_processor_node: Number of hidden layers in the node processors
40 | hidden_layers_processor_edge: Number of hidden layers in the edge processors
41 | mlp_norm_type: Type of norm for the MLPs
42 | one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
43 | """
44 | super().__init__()
45 | # Build the default graph
46 | # Take features from encoder and put into processor graph
47 | self.input_dim = input_dim
48 |
49 | self.graph_processor = GraphProcessor(
50 | num_blocks,
51 | input_dim,
52 | edge_dim,
53 | hidden_dim_processor_node,
54 | hidden_dim_processor_edge,
55 | hidden_layers_processor_node,
56 | hidden_layers_processor_edge,
57 | mlp_norm_type,
58 | )
59 |
60 | def forward(self, x: torch.Tensor, edge_index, edge_attr) -> torch.Tensor:
61 | """
62 | Runs message passing over the latent graph
63 |
64 | Args:
65 | x: Torch tensor containing node features
66 | edge_index: Connectivity of graph, of shape [2, Num edges] in COO format
67 | edge_attr: Edge attributes in [Num edges, Features] shape
68 |
69 | Returns:
70 | torch Tensor containing the values of the nodes of the graph
71 | """
72 | out, _ = self.graph_processor(x, edge_index, edge_attr)
73 | return out
74 |
--------------------------------------------------------------------------------
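A minimal shape-level sketch of the processor on a toy latent graph (sizes illustrative):

import torch

from graph_weather.models.layers.processor import Processor

num_nodes, num_edges = 12, 40
x = torch.randn(num_nodes, 256)  # latent node features
edge_index = torch.randint(0, num_nodes, (2, num_edges))  # COO connectivity
edge_attr = torch.randn(num_edges, 256)  # edge features

processor = Processor(input_dim=256, edge_dim=256, num_blocks=9)
out = processor(x, edge_index, edge_attr)
print(out.shape)  # torch.Size([12, 256])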
/graph_weather/models/losses.py:
--------------------------------------------------------------------------------
1 | """Weather loss functions"""
2 |
3 | import numpy as np
4 | import torch
5 |
6 |
7 | class NormalizedMSELoss(torch.nn.Module):
8 | """Loss function described in the paper"""
9 |
10 | def __init__(
11 | self, feature_variance: list, lat_lons: list, device="cpu", normalize: bool = False
12 | ):
13 | """
14 | Normalized MSE Loss as described in the paper
15 |
16 | This re-scales each physical variable such that it has unit variance in the 3 hour temporal
17 | difference. E.g. for temperature data, divide the values at all pressure levels by
18 | sigma_T,3hr, where sigma^2_T,3hr is the variance of the 3 hour change in temperature,
19 | averaged across space (lat/lon + pressure levels) and time (100 random temporal frames).
20 |
21 | Additionally weights each grid point by the cosine of its latitude
22 |
23 | (latitudes are converted to radians before the cosine is taken)
24 |
25 | Args:
26 | feature_variance: Variance for each of the physical features
27 | lat_lons: List of lat/lon pairs, used to generate weighting
28 | device: Device to run on ("cpu" or "cuda")
29 | normalize: Whether to normalize the squared error by the feature variance
30 | """
31 | # TODO Rescale by nominal static air density at each pressure level, could be 1/pressure level or something similar
32 | super().__init__()
33 | self.feature_variance = torch.tensor(feature_variance)
34 | assert not torch.isnan(self.feature_variance).any()
35 | # Compute unique latitudes from the provided lat/lon pairs.
36 | unique_lats = sorted(set(lat for lat, _ in lat_lons))
37 | # Use the cosine of each unique latitude (converted to radians) as its weight.
38 | self.weights = torch.tensor(
39 | [np.cos(lat * np.pi / 180.0) for lat in unique_lats], dtype=torch.float
40 | )
41 | self.normalize = normalize
42 | assert not torch.isnan(self.weights).any()
43 |
44 | def forward(self, pred: torch.Tensor, target: torch.Tensor):
45 | """
46 | Calculate the loss
47 |
48 | Rescales both predictions and target, so assumes neither are already normalized
49 | Additionally weights each grid point by the cosine of its latitude
50 |
51 | Args:
52 | pred: Prediction tensor
53 | target: Target tensor
54 |
55 | Returns:
56 | MSE loss on the variance-normalized values
57 | """
58 | self.feature_variance = self.feature_variance.to(pred.device)
59 | self.weights = self.weights.to(pred.device)
60 | # print(pred.shape)  # debug output disabled
61 | # print(target.shape)
62 | # print(self.weights.shape)
63 |
64 | out = (pred - target) ** 2
65 | # print(out.shape)
66 | if self.normalize:
67 | out = out / self.feature_variance
68 |
69 | assert not torch.isnan(out).any()
70 | # Mean of the physical variables
71 | out = out.mean(-1)
72 |
73 | # Flatten all dimensions except the batch dimension.
74 | B, *dims = out.shape
75 | num_nodes = np.prod(
76 | dims
77 | ) # Total number of grid nodes (e.g., if grid is HxW, then num_nodes = H*W)
78 | out = out.view(B, num_nodes)
79 |
80 | # Determine the number of unique latitude weights and infer the number of grid columns.
81 | num_unique = self.weights.shape[0] # e.g., number of unique latitudes (rows)
82 | num_lon = num_nodes // num_unique # e.g. if 2592 nodes and 36 unique lat, then num_lon=72
83 |
84 | # Tile the unique latitude weights into a full weight grid
85 | weight_grid = self.weights.unsqueeze(1).expand(num_unique, num_lon).reshape(1, num_nodes)
86 | weight_grid = weight_grid.expand(B, num_nodes) # Now weight_grid is [B, num_nodes]
87 |
88 | # Multiply the per-node error by the corresponding weight.
89 | out = out * weight_grid
90 |
91 | assert not torch.isnan(out).any()
92 | return out.mean()
93 |
--------------------------------------------------------------------------------
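A small usage sketch of NormalizedMSELoss (sizes illustrative): feature_variance has one entry per physical variable, and the cos(lat) weights are derived from the unique latitudes in lat_lons:

import torch

from graph_weather.models.losses import NormalizedMSELoss

# 6x12 lat/lon grid with 4 physical variables per node
lat_lons = [(lat, lon) for lat in range(-75, 90, 30) for lon in range(0, 360, 30)]
criterion = NormalizedMSELoss(
    feature_variance=[1.0, 0.5, 2.0, 1.0], lat_lons=lat_lons, normalize=True
)

pred = torch.rand(2, len(lat_lons), 4)  # [batch, nodes, features]
target = torch.rand(2, len(lat_lons), 4)
print(criterion(pred, target))  # scalar, cos(lat)-weighted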
/graph_weather/models/weathermesh/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openclimatefix/graph_weather/abf13a36e95ba6b699027e1d4c6760760aae280f/graph_weather/models/weathermesh/__init__.py
--------------------------------------------------------------------------------
/graph_weather/models/weathermesh/decoder.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation based off the technical report and this repo: https://github.com/Brayden-Zhang/WeatherMesh
3 | """
4 |
5 | from dataclasses import asdict, dataclass
6 |
7 | import dacite
8 | import einops
9 | import torch
10 | import torch.nn as nn
11 | from natten import NeighborhoodAttention3D
12 |
13 | from graph_weather.models.weathermesh.layers import ConvUpBlock
14 |
15 |
16 | @dataclass
17 | class WeatherMeshDecoderConfig:
18 | latent_dim: int
19 | output_channels_2d: int
20 | output_channels_3d: int
21 | n_conv_blocks: int
22 | hidden_dim: int
23 | kernel_size: tuple
24 | num_heads: int
25 | num_transformer_layers: int
26 |
27 | @staticmethod
28 | def from_json(json: dict) -> "WeatherMeshDecoderConfig":
29 | return dacite.from_dict(data_class=WeatherMeshDecoderConfig, data=json)
30 |
31 | def to_json(self) -> dict:
32 | return asdict(self)
33 |
34 |
35 | class WeatherMeshDecoder(nn.Module):
36 | def __init__(
37 | self,
38 | latent_dim,
39 | output_channels_2d,
40 | output_channels_3d,
41 | n_conv_blocks=3,
42 | hidden_dim=256,
43 | kernel_size: tuple = (5, 7, 7),
44 | num_heads: int = 8,
45 | num_transformer_layers: int = 3,
46 | ):
47 | super().__init__()
48 |
49 | # Transformer layers for initial decoding
50 | self.transformer_layers = nn.ModuleList(
51 | [
52 | NeighborhoodAttention3D(
53 | dim=latent_dim, num_heads=num_heads, kernel_size=kernel_size
54 | )
55 | for _ in range(num_transformer_layers)
56 | ]
57 | )
58 |
59 | # Split into pressure levels and surface paths
60 | self.split = nn.Conv3d(latent_dim, hidden_dim * (2**n_conv_blocks), kernel_size=1)
61 |
62 | # Pressure levels (3D) path
63 | self.pressure_path = nn.ModuleList(
64 | [
65 | ConvUpBlock(
66 | hidden_dim * (2 ** (i + 1)),
67 | hidden_dim * (2**i) if i > 0 else output_channels_3d,
68 | is_3d=True,
69 | )
70 | for i in reversed(range(n_conv_blocks))
71 | ]
72 | )
73 |
74 | # Surface (2D) path
75 | self.surface_path = nn.ModuleList(
76 | [
77 | ConvUpBlock(
78 | hidden_dim * (2 ** (i + 1)),
79 | hidden_dim * (2**i) if i > 0 else output_channels_2d,
80 | )
81 | for i in reversed(range(n_conv_blocks))
82 | ]
83 | )
84 |
85 | def forward(self, latent: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
86 | # Needs to be (B,D,H,W,C) with Batch, Depth (vertical levels), Height, Width, Channels
87 | # Apply transformer layers
88 | for transformer in self.transformer_layers:
89 | latent = transformer(latent)
90 |
91 | latent = einops.rearrange(latent, "B D H W C -> B C D H W")
92 | # Split features
93 | features = self.split(latent)
94 | pressure_features = features[:, :, :-1]
95 | surface_features = features[:, :, -1:]
96 | # Decode pressure levels
97 | for block in self.pressure_path:
98 | pressure_features = block(pressure_features)
99 | # Decode surface features
100 | surface_features = surface_features.squeeze(2)
101 | for block in self.surface_path:
102 | surface_features = block(surface_features)
103 |
104 | return surface_features, pressure_features
105 |
--------------------------------------------------------------------------------
/graph_weather/models/weathermesh/encoder.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation based off the technical report and this repo: https://github.com/Brayden-Zhang/WeatherMesh
3 | """
4 |
5 | from dataclasses import asdict, dataclass
6 |
7 | import dacite
8 | import einops
9 | import torch
10 | import torch.nn as nn
11 | from natten import NeighborhoodAttention3D
12 |
13 | from graph_weather.models.weathermesh.layers import ConvDownBlock
14 |
15 |
16 | @dataclass
17 | class WeatherMeshEncoderConfig:
18 | input_channels_2d: int
19 | input_channels_3d: int
20 | latent_dim: int
21 | n_pressure_levels: int
22 | num_conv_blocks: int
23 | hidden_dim: int
24 | kernel_size: tuple
25 | num_heads: int
26 | num_transformer_layers: int
27 |
28 | @staticmethod
29 | def from_json(json: dict) -> "WeatherMeshEncoderConfig":
30 | return dacite.from_dict(data_class=WeatherMeshEncoderConfig, data=json)
31 |
32 | def to_json(self) -> dict:
33 | return asdict(self)
34 |
35 |
36 | class WeatherMeshEncoder(nn.Module):
37 | def __init__(
38 | self,
39 | input_channels_2d: int,
40 | input_channels_3d: int,
41 | latent_dim: int,
42 | n_pressure_levels: int,
43 | num_conv_blocks: int = 3,
44 | hidden_dim: int = 256,
45 | kernel_size: tuple = (5, 7, 7),
46 | num_heads: int = 8,
47 | num_transformer_layers: int = 3,
48 | ):
49 | super().__init__()
50 |
51 | # Surface (2D) path
52 | self.surface_path = nn.ModuleList(
53 | [
54 | ConvDownBlock(
55 | input_channels_2d if i == 0 else hidden_dim * (2**i),
56 | hidden_dim * (2 ** (i + 1)),
57 | )
58 | for i in range(num_conv_blocks)
59 | ]
60 | )
61 |
62 | # Pressure levels (3D) path
63 | self.pressure_path = nn.ModuleList(
64 | [
65 | ConvDownBlock(
66 | input_channels_3d if i == 0 else hidden_dim * (2**i),
67 | hidden_dim * (2 ** (i + 1)),
68 | stride=(1, 2, 2), # Want to keep depth the same size
69 | is_3d=True,
70 | )
71 | for i in range(num_conv_blocks)
72 | ]
73 | )
74 |
75 | # Transformer layers for final encoding
76 | self.transformer_layers = nn.ModuleList(
77 | [
78 | NeighborhoodAttention3D(
79 | dim=latent_dim, kernel_size=kernel_size, num_heads=num_heads
80 | )
81 | for _ in range(num_transformer_layers)
82 | ]
83 | )
84 |
85 | # Final projection to latent space
86 | self.to_latent = nn.Conv3d(hidden_dim * (2**num_conv_blocks), latent_dim, kernel_size=1)
87 |
88 | def forward(self, surface: torch.Tensor, pressure: torch.Tensor) -> torch.Tensor:
89 | # Process surface data
90 | for block in self.surface_path:
91 | surface = block(surface)
92 |
93 | # Process pressure level data
94 | for block in self.pressure_path:
95 | pressure = block(pressure)
96 | # Combine features
97 | features = torch.cat(
98 | [pressure, surface.unsqueeze(2)], dim=2
99 | ) # B C D H W currently, want it to be B D H W C
100 |
101 | # Transform to latent space
102 | latent = self.to_latent(features)
103 |
104 | # Reshape to get the shapes
105 | latent = einops.rearrange(latent, "B C D H W -> B D H W C")
106 | # Apply transformer layers
107 | for transformer in self.transformer_layers:
108 | latent = transformer(latent)
109 | return latent
110 |
--------------------------------------------------------------------------------
/graph_weather/models/weathermesh/layers.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation based off the technical report and this repo: https://github.com/Brayden-Zhang/WeatherMesh
3 | """
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class ConvDownBlock(nn.Module):
11 | """
12 | Downsampling convolutional block with residual connection.
13 | Can handle both 2D and 3D inputs.
14 | """
15 |
16 | def __init__(
17 | self,
18 | in_channels: int,
19 | out_channels: int,
20 | is_3d: bool = False,
21 | kernel_size: int = 3,
22 | stride: int = 2,
23 | padding: int = 1,
24 | groups: int = 1,
25 | activation: nn.Module = nn.GELU(),
26 | ):
27 | super().__init__()
28 |
29 | Conv = nn.Conv3d if is_3d else nn.Conv2d
30 | Norm = nn.BatchNorm3d if is_3d else nn.BatchNorm2d
31 |
32 | self.conv1 = Conv(
33 | in_channels,
34 | out_channels,
35 | kernel_size=kernel_size,
36 | stride=1,
37 | padding=padding,
38 | groups=groups,
39 | bias=False,
40 | )
41 | self.bn1 = Norm(out_channels)
42 | self.activation1 = activation
43 |
44 | self.conv2 = Conv(
45 | out_channels,
46 | out_channels,
47 | kernel_size=kernel_size,
48 | stride=stride,
49 | padding=padding,
50 | groups=groups,
51 | bias=False,
52 | )
53 | self.bn2 = Norm(out_channels)
54 | self.activation2 = activation
55 |
56 | # Residual connection with 1x1 conv to match dimensions
57 | self.downsample = Conv(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
58 | self.bn_down = Norm(out_channels)
59 |
60 | def forward(self, x: torch.Tensor) -> torch.Tensor:
61 | identity = self.bn_down(self.downsample(x))
62 |
63 | out = self.conv1(x)
64 | out = self.bn1(out)
65 | out = self.activation1(out)
66 |
67 | out = self.conv2(out)
68 | out = self.bn2(out)
69 |
70 | out += identity
71 | out = self.activation2(out)
72 |
73 | return out
74 |
75 |
76 | class ConvUpBlock(nn.Module):
77 | """
78 | Upsampling convolutional block with residual connection. Same as ConvDownBlock but reversed.
79 | """
80 |
81 | def __init__(
82 | self,
83 | in_channels: int,
84 | out_channels: int,
85 | is_3d: bool = False,
86 | kernel_size: int = 3,
87 | scale_factor: int = 2,
88 | padding: int = 1,
89 | groups: int = 1,
90 | activation: nn.Module = nn.GELU(),
91 | ):
92 | super().__init__()
93 |
94 | Conv = nn.Conv3d if is_3d else nn.Conv2d
95 | Norm = nn.BatchNorm3d if is_3d else nn.BatchNorm2d
96 | self.is_3d = is_3d
97 | self.scale_factor = scale_factor
98 |
99 | self.conv1 = Conv(
100 | in_channels,
101 | in_channels,
102 | kernel_size=kernel_size,
103 | stride=1,
104 | padding=padding,
105 | groups=groups,
106 | bias=False,
107 | )
108 | self.bn1 = Norm(in_channels)
109 | self.activation1 = activation
110 |
111 | self.conv2 = Conv(
112 | in_channels,
113 | out_channels,
114 | kernel_size=kernel_size,
115 | stride=1,
116 | padding=padding,
117 | groups=groups,
118 | bias=False,
119 | )
120 | self.bn2 = Norm(out_channels)
121 | self.activation2 = activation
122 |
123 | # Residual connection with 1x1 conv to match dimensions
124 | self.upsample = Conv(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
125 | self.bn_up = Norm(out_channels)
126 |
127 | def forward(self, x: torch.Tensor) -> torch.Tensor:
128 | # Upsample input
129 | if self.is_3d:
130 | x = F.interpolate(
131 | x,
132 | scale_factor=(1, self.scale_factor, self.scale_factor),
133 | mode="trilinear",
134 | align_corners=False,
135 | )
136 | else:
137 | x = F.interpolate(
138 | x, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
139 | )
140 |
141 | identity = self.bn_up(self.upsample(x))
142 |
143 | out = self.conv1(x)
144 | out = self.bn1(out)
145 | out = self.activation1(out)
146 |
147 | out = self.conv2(out)
148 | out = self.bn2(out)
149 |
150 | out += identity
151 | out = self.activation2(out)
152 |
153 | return out
154 |
--------------------------------------------------------------------------------
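A quick shape check for the two blocks (sizes illustrative): ConvDownBlock halves the spatial dimensions and ConvUpBlock doubles them back:

import torch

from graph_weather.models.weathermesh.layers import ConvDownBlock, ConvUpBlock

x = torch.randn(1, 8, 32, 32)  # [B, C, H, W]
down = ConvDownBlock(in_channels=8, out_channels=16)
up = ConvUpBlock(in_channels=16, out_channels=8)

h = down(x)
print(h.shape)  # torch.Size([1, 16, 16, 16])
print(up(h).shape)  # torch.Size([1, 8, 32, 32])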
/graph_weather/models/weathermesh/processor.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation based off the technical report and this repo: https://github.com/Brayden-Zhang/WeatherMesh
3 | """
4 |
5 | from dataclasses import asdict, dataclass
6 |
7 | import dacite
8 | import torch.nn as nn
9 | from natten import NeighborhoodAttention3D
10 |
11 |
12 | @dataclass
13 | class WeatherMeshProcessorConfig:
14 | latent_dim: int
15 | n_layers: int
16 | kernel: tuple
17 | num_heads: int
18 |
19 | @staticmethod
20 | def from_json(json: dict) -> "WeatherMeshProcessorConfig":
21 | return dacite.from_dict(data_class=WeatherMeshProcessorConfig, data=json)
22 |
23 | def to_json(self) -> dict:
24 | return asdict(self)
25 |
26 |
27 | class WeatherMeshProcessor(nn.Module):
28 | def __init__(self, latent_dim, n_layers=10, kernel=(5, 7, 7), num_heads=8):
29 | super().__init__()
30 |
31 | self.layers = nn.ModuleList(
32 | [
33 | NeighborhoodAttention3D(
34 | dim=latent_dim,
35 | num_heads=num_heads,
36 | kernel_size=kernel,
37 | )
38 | for _ in range(n_layers)
39 | ]
40 | )
41 |
42 | def forward(self, x):
43 | for layer in self.layers:
44 | x = layer(x)
45 | return x
46 |
--------------------------------------------------------------------------------
/graph_weather/models/weathermesh/weathermesh2.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation based off the technical report and this repo: https://github.com/Brayden-Zhang/WeatherMesh
3 | """
4 |
5 | from dataclasses import asdict, dataclass
6 | from typing import List
7 |
8 | import dacite
9 | import torch
10 | import torch.nn as nn
11 |
12 | from graph_weather.models.weathermesh.decoder import WeatherMeshDecoder, WeatherMeshDecoderConfig
13 | from graph_weather.models.weathermesh.encoder import WeatherMeshEncoder, WeatherMeshEncoderConfig
14 | from graph_weather.models.weathermesh.processor import (
15 | WeatherMeshProcessor,
16 | WeatherMeshProcessorConfig,
17 | )
18 |
19 | """
20 | Notes on implementation
21 |
22 | To make NATTEN work on a sphere, we implement our own circular padding. At the poles, we use the bump attention behavior from NATTEN. For position encoding of tokens, we use Rotary Embeddings.
23 |
24 | In the default configuration of WeatherMesh 2, the NATTEN window is 5,7,7 in depth, width, height, corresponding to a physical size of 14 degrees longitude and latitude. WeatherMesh 2 contains two processors: a 6hr and a 1hr processor. Each is 10 NATTEN layers deep.
25 |
26 | Training: distributed shampoo: https://github.com/facebookresearch/optimizers/blob/main/distributed_shampoo/README.md
27 |
28 | A fork of the PyTorch checkpoint library, called matepoint, is used to implement offloading to RAM.
29 |
30 | TODO: Add bump attention and rotary embeddings for the circular padding and position encoding
31 |
32 | """
33 |
34 |
35 | @dataclass
36 | class WeatherMeshConfig:
37 | encoder: WeatherMeshEncoderConfig
38 | processors: List[WeatherMeshProcessorConfig]
39 | decoder: WeatherMeshDecoderConfig
40 | timesteps: List[int]
41 | surface_channels: int
42 | pressure_channels: int
43 | pressure_levels: int
44 | latent_dim: int
45 | encoder_num_conv_blocks: int
46 | encoder_num_transformer_layers: int
47 | encoder_hidden_dim: int
48 | decoder_num_conv_blocks: int
49 | decoder_num_transformer_layers: int
50 | decoder_hidden_dim: int
51 | processor_num_layers: int
52 | kernel: tuple
53 | num_heads: int
54 |
55 | @staticmethod
56 | def from_json(json: dict) -> "WeatherMeshConfig":
57 | return dacite.from_dict(data_class=WeatherMeshConfig, data=json)
58 |
59 | def to_json(self) -> dict:
60 | return asdict(self)
61 |
62 |
63 | @dataclass
64 | class WeatherMeshOutput:
65 | surface: torch.Tensor
66 | pressure: torch.Tensor
67 |
68 |
69 | class WeatherMesh(nn.Module):
70 | def __init__(
71 | self,
72 | encoder: nn.Module | None,
73 | processors: List[nn.Module] | None,
74 | decoder: nn.Module | None,
75 | timesteps: List[int],
76 | surface_channels: int | None,
77 | pressure_channels: int | None,
78 | pressure_levels: int | None,
79 | latent_dim: int | None,
80 | encoder_num_conv_blocks: int | None,
81 | encoder_num_transformer_layers: int | None,
82 | encoder_hidden_dim: int | None,
83 | decoder_num_conv_blocks: int | None,
84 | decoder_num_transformer_layers: int | None,
85 | decoder_hidden_dim: int | None,
86 | processor_num_layers: int | None,
87 | kernel: tuple | None,
88 | num_heads: int | None,
89 | ):
90 | super().__init__()
91 | if encoder is not None:
92 | self.encoder = encoder
93 | else:
94 | self.encoder = WeatherMeshEncoder(
95 | input_channels_2d=surface_channels,
96 | input_channels_3d=pressure_channels,
97 | latent_dim=latent_dim,
98 | n_pressure_levels=pressure_levels,
99 | num_conv_blocks=encoder_num_conv_blocks,
100 | hidden_dim=encoder_hidden_dim,
101 | kernel_size=kernel,
102 | num_heads=num_heads,
103 | num_transformer_layers=encoder_num_transformer_layers,
104 | )
105 | if processors is not None:
106 | assert len(processors) == len(
107 | timesteps
108 | ), "Number of processors must match number of timesteps"
109 | self.processors = nn.ModuleList(processors)
110 | else:
111 | self.processors = nn.ModuleList(
112 | WeatherMeshProcessor(
113 | latent_dim=latent_dim,
114 | n_layers=processor_num_layers,
115 | kernel=kernel,
116 | num_heads=num_heads,
117 | )
118 | for _ in range(len(timesteps))
119 | )
120 | if decoder is not None:
121 | self.decoder = decoder
122 | else:
123 | self.decoder = WeatherMeshDecoder(
124 | latent_dim=latent_dim,
125 | output_channels_2d=surface_channels,
126 | output_channels_3d=pressure_channels,
127 | n_conv_blocks=decoder_num_conv_blocks,
128 | hidden_dim=decoder_hidden_dim,
129 | kernel_size=kernel,
130 | num_heads=num_heads,
131 | num_transformer_layers=decoder_num_transformer_layers,
132 | )
133 | self.timesteps = timesteps
134 |
135 | def forward(
136 | self, surface: torch.Tensor, pressure: torch.Tensor, forecast_steps: int
137 | ) -> WeatherMeshOutput:
138 | # Encode input
139 | latent = self.encoder(surface, pressure)
140 |
141 | # Apply processors for each forecast step
142 | for _ in range(forecast_steps):
143 | for processor in self.processors:
144 | latent = processor(latent)
145 |
146 | # Decode output
147 | surface_out, pressure_out = self.decoder(latent)
148 |
149 | return WeatherMeshOutput(surface=surface_out, pressure=pressure_out)
150 |
--------------------------------------------------------------------------------
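For orientation, a minimal end-to-end sketch of WeatherMesh built from the dimension arguments (requires the optional natten dependency; all sizes are illustrative and chosen so the downsampled grid still fits the NATTEN kernel):

import torch

from graph_weather.models.weathermesh.weathermesh2 import WeatherMesh

model = WeatherMesh(
    encoder=None, processors=None, decoder=None,  # build the default modules
    timesteps=[6],
    surface_channels=8, pressure_channels=4, pressure_levels=5,
    latent_dim=32,
    encoder_num_conv_blocks=2, encoder_num_transformer_layers=1, encoder_hidden_dim=16,
    decoder_num_conv_blocks=2, decoder_num_transformer_layers=1, decoder_hidden_dim=16,
    processor_num_layers=2, kernel=(3, 5, 5), num_heads=8,
)

surface = torch.randn(1, 8, 32, 64)      # [B, surface channels, H, W]
pressure = torch.randn(1, 4, 5, 32, 64)  # [B, pressure channels, levels, H, W]
out = model(surface, pressure, forecast_steps=1)
print(out.surface.shape, out.pressure.shape)  # [1, 8, 32, 64] and [1, 4, 5, 32, 64]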
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "graph_weather"
3 | requires-python = ">=3.11"
4 | version = "1.0.89"
5 | description = "Graph-based AI Weather models"
6 | authors = [
7 | {name = "Jacob Prince-Bieker", email = "jacob@bieker.tech"},
8 | ]
9 | dependencies = ["torch-harmonics"]
10 |
11 | [build-system]
12 | build-backend = "hatchling.build"
13 | requires = ["hatchling"]
14 |
15 | [tool.pixi.project]
16 | channels = ["pyg", "conda-forge", "pytorch"]
17 | platforms = ["linux-64", "osx-arm64"]
18 |
19 | [tool.pixi.feature.cuda]
20 | channels = ["nvidia", {channel = "pytorch", priority = -1}]
21 |
22 | [tool.pixi.feature.cuda.system-requirements]
23 | cuda = "12"
24 |
25 | [tool.pixi.feature.cuda.target.linux-64.dependencies]
26 | cuda-version = "12.4"
27 | pytorch-gpu = {version = "2.4.1", channel = "conda-forge"}
28 |
29 | #[tool.pixi.feature.cuda.target.linux-64.pypi-dependencies]
30 | #natten = {url = "https://shi-labs.com/natten/wheels/cu124/torch2.4.0/natten-0.17.4%2Btorch240cu124-cp312-cp312-linux_x86_64.whl"}
31 |
32 | [tool.pixi.feature.mlx]
33 | # MLX is only available on macOS >=13.5 (>14.0 is recommended)
34 | system-requirements = {macos = "13.5"}
35 |
36 | [tool.pixi.feature.mlx.target.osx-arm64.dependencies]
37 | mlx = {version = "*", channel = "conda-forge"}
38 | pytorch-cpu = {version = "2.4.1", channel = "conda-forge"}
39 |
40 | #[tool.pixi.feature.mlx.target.osx-arm64.pypi-dependencies]
41 | #natten = "*"
42 |
43 | [tool.pixi.feature.cpu]
44 | platforms = ["linux-64", "osx-arm64"]
45 |
46 | [tool.pixi.feature.cpu.dependencies]
47 | pytorch-cpu = {version = "2.4.1", channel = "conda-forge"}
48 |
49 | #[tool.pixi.feature.cpu.pypi-dependencies]
50 | #natten = "*"
51 |
52 | [tool.pixi.dependencies]
53 | python = "3.12.*"
54 | torchvision = {version = "*", channel = "conda-forge"}
55 | pip = "*"
56 | pytest = "*"
57 | pre-commit = "*"
58 | ruff = "*"
59 | xarray = "*"
60 | pandas = "*"
61 | h3-py = "3.*"
62 | numcodecs = "*"
63 | scipy = "*"
64 | zarr = ">=3.0.0"
65 | pyg = "*"
66 | pytorch-cluster = "*"
67 | pytorch-scatter = "*"
68 | pytorch-spline-conv = "*"
69 | pytorch-sparse = "*"
70 | tqdm = "*"
71 | lightning = "*"
72 | einops = "*"
73 | fsspec = "*"
74 | datasets = "*"
75 | trimesh = "*"
76 | pysolar = "*"
77 | rtree = "*"
78 | pixi-pycharm = ">=0.0.8,<0.0.9"
79 | uv = ">=0.6.2,<0.7"
80 | healpy = ">=1.18.1,<2"
81 |
82 |
83 | [tool.pixi.environments]
84 | default = ["cpu"]
85 | cuda = ["cuda"]
86 | mlx = ["mlx"]
87 |
88 | [tool.pixi.tasks]
89 | install = "pip install --editable ."
90 | installnat = "pip install natten"
91 | installnatcuda = "pip install natten==0.17.4+torch240cu124 -f https://shi-labs.com/natten/wheels/"
92 | test = "pytest"
93 | format = "ruff format"
94 |
95 |
96 | [tool.ruff]
97 | # Exclude a variety of commonly ignored directories.
98 | exclude = [
99 | ".bzr",
100 | ".direnv",
101 | ".eggs",
102 | ".git",
103 | ".hg",
104 | ".mypy_cache",
105 | ".nox",
106 | ".nox",
107 | ".pants.d",
108 | ".pytype",
109 | ".ruff_cache",
110 | ".svn",
111 | ".tox",
112 | ".venv",
113 | "__pypackages__",
114 | "_build",
115 | "buck-out",
116 | "build",
117 | "dist",
118 | "node_modules",
119 | "venv",
120 | "tests",
121 | ]
122 | # Same as Black.
123 | line-length = 100
124 |
125 | # Assume Python 3.11.
126 | target-version = "py311"
127 | fix = false
128 | # Group violations by containing file.
129 | output-format = "github"
130 |
131 | [tool.ruff.lint]
132 | # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
133 | select = ["E", "F", "D", "I"]
134 | ignore = ["D200","D202","D210","D212","D415","D105"]
135 |
136 | # Allow autofix for all enabled rules (when `--fix` is provided).
137 | fixable = ["A", "B", "C", "D", "E", "F", "I"]
138 | unfixable = []
139 | # Allow unused variables when underscore-prefixed.
140 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
141 | mccabe.max-complexity = 10
142 | pydocstyle.convention = "google"
143 |
144 | [tool.ruff.lint.per-file-ignores]
145 | "__init__.py" = ["F401", "E402"]
146 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """Setup"""
2 |
3 | from pathlib import Path
4 |
5 | from setuptools import find_packages, setup
6 |
7 | this_directory = Path(__file__).parent
8 | long_description = (this_directory / "README.md").read_text()
9 |
10 | setup(
11 | name="graph_weather",
12 | version="1.0.108",
13 | packages=find_packages(),
14 | url="https://github.com/openclimatefix/graph_weather",
15 | license="MIT License",
16 | company="Open Climate Fix Ltd",
17 | author="Jacob Bieker",
18 | long_description=long_description,
19 | long_description_content_type="text/markdown",
20 | author_email="jacob@bieker.tech",
21 | description="Weather Forecasting with Graph Neural Networks",
22 | classifiers=[
23 | "Development Status :: 4 - Beta",
24 | "Intended Audience :: Developers",
25 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
26 | "License :: OSI Approved :: MIT License",
27 | "Programming Language :: Python :: 3.9",
28 | ],
29 | )
30 |
--------------------------------------------------------------------------------
/tests/test_gencast.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 | from packaging.version import Version
5 | from torch_geometric.transforms import TwoHop
6 |
7 | from graph_weather.models.gencast import Denoiser, GraphBuilder, Sampler, WeightedMSELoss
8 | from graph_weather.models.gencast.layers.modules import FourierEmbedding
9 | from graph_weather.models.gencast.utils.noise import generate_isotropic_noise, sample_noise_level
10 |
11 |
12 | def test_gencast_noise():
13 | num_lon = 360
14 | num_lat = 180
15 | num_samples = 5
16 | target_residuals = np.zeros((num_lon, num_lat, num_samples))
17 | noise_level = sample_noise_level()
18 | noise = generate_isotropic_noise(
19 | num_lon=num_lon, num_lat=num_lat, num_samples=target_residuals.shape[-1]
20 | )
21 | corrupted_residuals = target_residuals + noise_level * noise
22 | assert corrupted_residuals.shape == target_residuals.shape
23 | assert not np.isnan(corrupted_residuals).any()
24 |
25 | num_lon = 360
26 | num_lat = 181
27 | num_samples = 5
28 | target_residuals = np.zeros((num_lon, num_lat, num_samples))
29 | noise_level = sample_noise_level()
30 | noise = generate_isotropic_noise(
31 | num_lon=num_lon, num_lat=num_lat, num_samples=target_residuals.shape[-1]
32 | )
33 | corrupted_residuals = target_residuals + noise_level * noise
34 | assert corrupted_residuals.shape == target_residuals.shape
35 | assert not np.isnan(corrupted_residuals).any()
36 |
37 | num_lon = 100
38 | num_lat = 100
39 | num_samples = 5
40 | target_residuals = np.zeros((num_lon, num_lat, num_samples))
41 | noise_level = sample_noise_level()
42 | noise = generate_isotropic_noise(
43 | num_lon=num_lon, num_lat=num_lat, num_samples=target_residuals.shape[-1], isotropic=False
44 | )
45 | corrupted_residuals = target_residuals + noise_level * noise
46 | assert corrupted_residuals.shape == target_residuals.shape
47 | assert not np.isnan(corrupted_residuals).any()
48 |
49 |
50 | def test_gencast_graph():
51 | grid_lat = np.arange(-90, 90, 1)
52 | grid_lon = np.arange(0, 360, 1)
53 | graphs = GraphBuilder(grid_lon=grid_lon, grid_lat=grid_lat, splits=4, num_hops=8)
54 |
55 | # compare khop sparse implementation with pyg.
56 | transform = TwoHop()
57 | khop_mesh_graph_pyg = graphs.mesh_graph
58 | for _ in range(3): # applying TwoHop three times yields the 8-hop mesh
59 | khop_mesh_graph_pyg = transform(khop_mesh_graph_pyg)
60 |
61 | assert graphs.mesh_graph.x.shape[0] == 2562
62 | assert graphs.g2m_graph["grid_nodes"].x.shape[0] == 360 * 180
63 | assert graphs.m2g_graph["mesh_nodes"].x.shape[0] == 2562
64 | assert not torch.isnan(graphs.mesh_graph.edge_attr).any()
65 | assert graphs.khop_mesh_graph.x.shape[0] == 2562
66 | assert torch.allclose(graphs.khop_mesh_graph.x, khop_mesh_graph_pyg.x)
67 | assert torch.allclose(graphs.khop_mesh_graph.edge_index, khop_mesh_graph_pyg.edge_index)
68 |
69 |
70 | def test_gencast_loss():
71 | grid_lat = torch.arange(-90, 90, 1)
72 | grid_lon = torch.arange(0, 360, 1)
73 | pressure_levels = torch.tensor(
74 | [50.0, 100.0, 150.0, 200.0, 250, 300, 400, 500, 600, 700, 850, 925, 1000.0]
75 | )
76 | single_features_weights = torch.tensor([1, 0.1, 0.1, 0.1, 0.1])
77 | num_atmospheric_features = 6
78 | batch_size = 3
79 | features_dim = len(pressure_levels) * num_atmospheric_features + len(single_features_weights)
80 |
81 | loss = WeightedMSELoss(
82 | grid_lat=grid_lat,
83 | pressure_levels=pressure_levels,
84 | num_atmospheric_features=num_atmospheric_features,
85 | single_features_weights=single_features_weights,
86 | )
87 |
88 | preds = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim))
89 | noise_levels = torch.rand((batch_size, 1))
90 | targets = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim))
91 | assert loss.forward(preds, noise_levels, targets) is not None
92 |
93 |
94 | def test_gencast_denoiser():
95 | grid_lat = np.arange(-90, 90, 1)
96 | grid_lon = np.arange(0, 360, 1)
97 | input_features_dim = 10
98 | output_features_dim = 5
99 | batch_size = 3
100 |
101 | denoiser = Denoiser(
102 | grid_lon=grid_lon,
103 | grid_lat=grid_lat,
104 | input_features_dim=input_features_dim,
105 | output_features_dim=output_features_dim,
106 | hidden_dims=[16, 32],
107 | num_blocks=3,
108 | num_heads=4,
109 | splits=0,
110 | num_hops=1,
111 | device=torch.device("cpu"),
112 | ).eval()
113 |
114 | corrupted_targets = torch.randn((batch_size, len(grid_lon), len(grid_lat), output_features_dim))
115 | prev_inputs = torch.randn((batch_size, len(grid_lon), len(grid_lat), 2 * input_features_dim))
116 | noise_levels = torch.rand((batch_size, 1))
117 |
118 | with torch.no_grad():
119 | preds = denoiser(
120 | corrupted_targets=corrupted_targets, prev_inputs=prev_inputs, noise_levels=noise_levels
121 | )
122 |
123 | assert not torch.isnan(preds).any()
124 |
125 |
126 | def test_gencast_fourier():
127 | batch_size = 10
128 | output_dim = 20
129 | fourier_embedder = FourierEmbedding(output_dim=output_dim, num_frequencies=32, base_period=16)
130 | t = torch.rand((batch_size, 1))
131 | assert fourier_embedder(t).shape == (batch_size, output_dim)
132 |
133 |
134 | def test_gencast_sampler():
135 | grid_lat = np.arange(-90, 90, 1)
136 | grid_lon = np.arange(0, 360, 1)
137 | input_features_dim = 10
138 | output_features_dim = 5
139 |
140 | denoiser = Denoiser(
141 | grid_lon=grid_lon,
142 | grid_lat=grid_lat,
143 | input_features_dim=input_features_dim,
144 | output_features_dim=output_features_dim,
145 | hidden_dims=[16, 32],
146 | num_blocks=3,
147 | num_heads=4,
148 | splits=0,
149 | num_hops=1,
150 | device=torch.device("cpu"),
151 | ).eval()
152 |
153 | prev_inputs = torch.randn((1, len(grid_lon), len(grid_lat), 2 * input_features_dim))
154 |
155 | sampler = Sampler()
156 | preds = sampler.sample(denoiser, prev_inputs)
157 | assert not torch.isnan(preds).any()
158 | assert preds.shape == (1, len(grid_lon), len(grid_lat), output_features_dim)
159 |
160 |
161 | @pytest.mark.skipif(
162 | Version(torch.__version__).release != Version("2.3.0").release,
163 |     reason="dgl tests for experimental features only run with torch 2.3.0",
164 | )
165 | def test_gencast_full():
166 | # download weights from HF
167 | denoiser = Denoiser.from_pretrained(
168 | "openclimatefix/gencast-128x64",
169 | grid_lon=np.arange(0, 360, 360 / 128),
170 | grid_lat=np.arange(-90, 90, 180 / 64) + 1 / 2 * 180 / 64,
171 | )
172 |
173 | # load inputs and targets
174 | prev_inputs = torch.randn([1, 128, 64, 178])
175 | target_residuals = torch.randn([1, 128, 64, 83])
176 |
177 | # predict
178 | sampler = Sampler()
179 | preds = sampler.sample(denoiser, prev_inputs)
180 |
181 | assert not torch.isnan(preds).any()
182 | assert preds.shape == target_residuals.shape
183 |
--------------------------------------------------------------------------------
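
Note on the gencast noise tests above: the corruption step is EDM-style additive noise, x_sigma = x + sigma * eps. A minimal sketch, assuming the helpers are importable from graph_weather.models.gencast.utils.noise (the path shown in the repo tree):

    import numpy as np
    from graph_weather.models.gencast.utils.noise import (
        generate_isotropic_noise,
        sample_noise_level,
    )

    targets = np.zeros((128, 64, 3))   # (num_lon, num_lat, num_samples)
    sigma = sample_noise_level()       # scalar noise level
    eps = generate_isotropic_noise(num_lon=128, num_lat=64, num_samples=3)
    corrupted = targets + sigma * eps  # the same corruption the tests assert on
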
/tests/test_nnjai.py:
--------------------------------------------------------------------------------
1 | """
2 | Unit tests for the `SensorDataset` class, mocking the `DataCatalog` to simulate sensor data loading and validate dataset behavior.
3 | The tests ensure correct handling of data types, shapes, and batch processing for various sensor types.
4 | """
5 |
6 | from datetime import datetime
7 | from unittest.mock import MagicMock, patch
8 | import numpy as np
9 | import pytest
10 | import torch
11 | import pandas as pd
12 |
13 | from graph_weather.data.nnja_ai import SensorDataset, collate_fn
14 |
15 |
16 | def get_sensor_variables(sensor_type):
17 | """Helper function to get the correct variables for each sensor type."""
18 | if sensor_type == "AMSU":
19 | return [f"TMBR_000{i:02d}" for i in range(1, 16)] # 15 channels
20 | elif sensor_type == "ATMS":
21 | return [f"TMBR_000{i:02d}" for i in range(1, 23)] # 22 channels
22 | elif sensor_type == "MHS":
23 | return [f"TMBR_000{i:02d}" for i in range(1, 6)] # 5 channels
24 | elif sensor_type == "IASI":
25 | return [f"SCRA_{str(i).zfill(5)}" for i in range(1, 617)] # 616 channels
26 | elif sensor_type == "CrIS":
27 | return [f"SRAD01_{str(i).zfill(5)}" for i in range(1, 432)] # 431 channels
28 | return []
29 |
30 |
31 | @pytest.fixture
32 | def mock_datacatalog():
33 | """
34 | Fixture to mock the DataCatalog for unit tests to avoid actual data loading.
35 | """
36 | with patch("graph_weather.data.nnja_ai.DataCatalog") as mock:
37 | # Create a mock catalog
38 | mock_catalog = MagicMock()
39 |
40 | # Create a mock dataset with direct DataFrame return
41 | mock_dataset = MagicMock()
42 | mock_dataset.load_manifest = MagicMock()
43 | mock_dataset.sel = MagicMock(return_value=mock_dataset) # Return self to chain calls
44 |
45 | def create_mock_df(engine="pandas"):
46 | # Get the sensor type from the mock dataset
47 | sensor_vars = get_sensor_variables(mock_dataset.sensor_type)
48 |
49 | # Create DataFrame with required columns
50 | df = pd.DataFrame(
51 | {
52 | "OBS_TIMESTAMP": pd.date_range(
53 |                         start=datetime(2021, 1, 1), periods=100, freq="h"
54 | ),
55 | "LAT": np.full(100, 45.0),
56 | "LON": np.full(100, -120.0),
57 | }
58 | )
59 |
60 | # Add sensor-specific variables
61 | for var in sensor_vars:
62 | df[var] = np.full(100, 250.0)
63 |
64 | return df
65 |
66 | # Set up the mock to return our DataFrame
67 | mock_dataset.load_dataset = create_mock_df
68 |
69 | # Configure the catalog to return our mock dataset
70 | def get_mock_dataset(self, name):
71 | # Set the sensor type based on the requested dataset name
72 | mock_dataset.sensor_type = next(
73 | config["sensor_type"] for config in SENSOR_CONFIGS if config["name"] == name
74 | )
75 | return mock_dataset
76 |
77 |         mock_catalog.__getitem__ = get_mock_dataset  # magic methods receive the mock instance as `self`
78 | mock.return_value = mock_catalog
79 |
80 | yield mock
81 |
82 |
83 | # Test configurations
84 | SENSOR_CONFIGS = [
85 | {
86 | "name": "amsu-1bamua-NC021023",
87 | "sensor_type": "AMSU",
88 | "expected_metadata_size": 15, # 15 TMBR channels
89 | },
90 | {
91 | "name": "atms-atms-NC021203",
92 | "sensor_type": "ATMS",
93 | "expected_metadata_size": 22, # 22 TMBR channels
94 | },
95 | {
96 | "name": "mhs-1bmhs-NC021027",
97 | "sensor_type": "MHS",
98 | "expected_metadata_size": 5, # 5 TMBR channels
99 | },
100 | {
101 | "name": "iasi-mtiasi-NC021241",
102 | "sensor_type": "IASI",
103 | "expected_metadata_size": 616, # 616 SCRA channels
104 | },
105 | {
106 | "name": "cris-crisf4-NC021206",
107 | "sensor_type": "CrIS",
108 | "expected_metadata_size": 431, # 431 SRAD channels
109 | },
110 | ]
111 |
112 |
113 | @pytest.mark.parametrize("sensor_config", SENSOR_CONFIGS)
114 | def test_sensor_dataset(mock_datacatalog, sensor_config):
115 | """Test the SensorDataset class for different sensor types."""
116 | time = datetime(2021, 1, 1, 0, 0)
117 | primary_descriptors = ["OBS_TIMESTAMP", "LAT", "LON"]
118 |
119 | dataset = SensorDataset(
120 | dataset_name=sensor_config["name"],
121 | time=time,
122 | primary_descriptors=primary_descriptors,
123 | additional_variables=get_sensor_variables(sensor_config["sensor_type"]),
124 | sensor_type=sensor_config["sensor_type"],
125 | )
126 |
127 | # Test dataset length
128 | assert len(dataset) > 0, f"Dataset should not be empty for {sensor_config['sensor_type']}"
129 |
130 | # Test single item structure
131 | item = dataset[0]
132 | expected_keys = {"timestamp", "latitude", "longitude", "metadata"}
133 | assert (
134 | set(item.keys()) == expected_keys
135 | ), f"Dataset item keys are not as expected for {sensor_config['sensor_type']}"
136 |
137 | # Validate tensor properties
138 | assert isinstance(
139 | item["timestamp"], torch.Tensor
140 | ), f"Timestamp should be a tensor for {sensor_config['sensor_type']}"
141 | assert (
142 | item["timestamp"].dtype == torch.float32
143 | ), f"Timestamp should have dtype float32 for {sensor_config['sensor_type']}"
144 | assert (
145 | item["timestamp"].ndim == 0
146 | ), f"Timestamp should be a scalar tensor for {sensor_config['sensor_type']}"
147 |
148 | assert isinstance(
149 | item["latitude"], torch.Tensor
150 | ), f"Latitude should be a tensor for {sensor_config['sensor_type']}"
151 | assert (
152 | item["latitude"].dtype == torch.float32
153 | ), f"Latitude should have dtype float32 for {sensor_config['sensor_type']}"
154 | assert (
155 | item["latitude"].ndim == 0
156 | ), f"Latitude should be a scalar tensor for {sensor_config['sensor_type']}"
157 |
158 | assert isinstance(
159 | item["longitude"], torch.Tensor
160 | ), f"Longitude should be a tensor for {sensor_config['sensor_type']}"
161 | assert (
162 | item["longitude"].dtype == torch.float32
163 | ), f"Longitude should have dtype float32 for {sensor_config['sensor_type']}"
164 | assert (
165 | item["longitude"].ndim == 0
166 | ), f"Longitude should be a scalar tensor for {sensor_config['sensor_type']}"
167 |
168 | assert isinstance(
169 | item["metadata"], torch.Tensor
170 | ), f"Metadata should be a tensor for {sensor_config['sensor_type']}"
171 | assert item["metadata"].shape == (
172 | sensor_config["expected_metadata_size"],
173 | ), f"Metadata shape mismatch for {sensor_config['sensor_type']}. Expected ({sensor_config['expected_metadata_size']},)"
174 | assert (
175 | item["metadata"].dtype == torch.float32
176 | ), f"Metadata should have dtype float32 for {sensor_config['sensor_type']}"
177 |
178 |
179 | def test_collate_function():
180 | """Test the collate_fn function to ensure proper batching of dataset items."""
181 | batch_size = 4
182 | metadata_size = 15 # Using AMSU size for this test
183 | mock_batch = [
184 | {
185 | "timestamp": torch.tensor(datetime.now().timestamp(), dtype=torch.float32),
186 | "latitude": torch.tensor(45.0, dtype=torch.float32),
187 | "longitude": torch.tensor(-120.0, dtype=torch.float32),
188 | "metadata": torch.randn(metadata_size, dtype=torch.float32),
189 | }
190 | for _ in range(batch_size)
191 | ]
192 |
193 | batched = collate_fn(mock_batch)
194 |
195 | assert batched["timestamp"].shape == (batch_size,), "Timestamp batch shape mismatch"
196 | assert batched["latitude"].shape == (batch_size,), "Latitude batch shape mismatch"
197 | assert batched["longitude"].shape == (batch_size,), "Longitude batch shape mismatch"
198 | assert batched["metadata"].shape == (batch_size, metadata_size), "Metadata batch shape mismatch"
199 | assert batched["timestamp"].dtype == torch.float32, "Timestamp dtype mismatch"
200 | assert batched["latitude"].dtype == torch.float32, "Latitude dtype mismatch"
201 | assert batched["longitude"].dtype == torch.float32, "Longitude dtype mismatch"
202 | assert batched["metadata"].dtype == torch.float32, "Metadata dtype mismatch"
203 |
--------------------------------------------------------------------------------
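
Usage sketch for the dataset tested above (hedged: running it for real requires the nnja-ai DataCatalog and its data to be available; names mirror the test fixtures):

    from datetime import datetime
    from torch.utils.data import DataLoader
    from graph_weather.data.nnja_ai import SensorDataset, collate_fn

    dataset = SensorDataset(
        dataset_name="amsu-1bamua-NC021023",
        time=datetime(2021, 1, 1),
        primary_descriptors=["OBS_TIMESTAMP", "LAT", "LON"],
        additional_variables=[f"TMBR_000{i:02d}" for i in range(1, 16)],
        sensor_type="AMSU",
    )
    loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
    # each batch: timestamp/latitude/longitude of shape (4,), metadata of shape (4, 15)
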
/tests/test_weathermesh.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from graph_weather.models.weathermesh.decoder import WeatherMeshDecoder, WeatherMeshDecoderConfig
4 | from graph_weather.models.weathermesh.encoder import WeatherMeshEncoder, WeatherMeshEncoderConfig
5 | from graph_weather.models.weathermesh.processor import (
6 | WeatherMeshProcessor,
7 | WeatherMeshProcessorConfig,
8 | )
9 | from graph_weather.models.weathermesh.weathermesh2 import WeatherMesh, WeatherMeshConfig
10 |
11 |
12 | def test_weathermesh_encoder():
13 | encoder = WeatherMeshEncoder(
14 | input_channels_2d=2,
15 | input_channels_3d=1,
16 | latent_dim=8,
17 | n_pressure_levels=25,
18 | kernel_size=(3, 3, 3),
19 | num_heads=2,
20 | hidden_dim=16,
21 | num_conv_blocks=3,
22 | num_transformer_layers=3,
23 | )
24 | x_2d = torch.randn(1, 2, 32, 64)
25 | x_3d = torch.randn(1, 1, 25, 32, 64)
26 | out = encoder(x_2d, x_3d)
27 | assert out.shape == (1, 5, 4, 8, 8)
28 |
29 |
30 | def test_weathermesh_processor():
31 | processor = WeatherMeshProcessor(latent_dim=8, n_layers=2)
32 | x = torch.randn(1, 26, 32, 64, 8)
33 | out = processor(x)
34 | assert out.shape == (1, 26, 32, 64, 8)
35 |
36 |
37 | def test_weathermesh_decoder():
38 | decoder = WeatherMeshDecoder(
39 | latent_dim=8,
40 | output_channels_2d=8,
41 | output_channels_3d=4,
42 | kernel_size=(3, 3, 3),
43 | num_heads=2,
44 | hidden_dim=8,
45 | num_transformer_layers=1,
46 | )
47 | x = torch.randn(1, 6, 32, 64, 8)
48 | out = decoder(x)
49 | assert out[0].shape == (1, 8, 256, 512)
50 | assert out[1].shape == (1, 4, 5, 256, 512)
51 |
52 |
53 | def test_weathermesh():
54 | model = WeatherMesh(
55 | encoder=None,
56 | processors=None,
57 | decoder=None,
58 | timesteps=[1, 6],
59 | surface_channels=8,
60 | pressure_channels=4,
61 | pressure_levels=5,
62 | latent_dim=4,
63 | encoder_num_conv_blocks=1,
64 | encoder_num_transformer_layers=1,
65 | encoder_hidden_dim=4,
66 | decoder_num_conv_blocks=1,
67 | decoder_num_transformer_layers=1,
68 | decoder_hidden_dim=4,
69 | processor_num_layers=2,
70 | kernel=(3, 5, 5),
71 | num_heads=2,
72 | )
73 |
74 | x_2d = torch.randn(1, 8, 32, 64)
75 | x_3d = torch.randn(1, 4, 5, 32, 64)
76 | out = model(x_2d, x_3d, forecast_steps=1)
77 | assert out.surface.shape == (1, 8, 32, 64)
78 | assert out.pressure.shape == (1, 4, 5, 32, 64)
79 |
--------------------------------------------------------------------------------
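
The full-model test above passes encoder=None, processors=None and decoder=None, which, judging by the config arguments that follow, lets WeatherMesh construct its own sub-modules. A single-step inference sketch with the same toy settings:

    import torch
    from graph_weather.models.weathermesh.weathermesh2 import WeatherMesh

    model = WeatherMesh(
        encoder=None, processors=None, decoder=None,
        timesteps=[1, 6],
        surface_channels=8, pressure_channels=4, pressure_levels=5,
        latent_dim=4,
        encoder_num_conv_blocks=1, encoder_num_transformer_layers=1,
        encoder_hidden_dim=4,
        decoder_num_conv_blocks=1, decoder_num_transformer_layers=1,
        decoder_hidden_dim=4,
        processor_num_layers=2, kernel=(3, 5, 5), num_heads=2,
    )
    with torch.no_grad():
        out = model(torch.randn(1, 8, 32, 64), torch.randn(1, 4, 5, 32, 64),
                    forecast_steps=1)
    # out.surface: (1, 8, 32, 64); out.pressure: (1, 4, 5, 32, 64)
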
/train/deepspeed_graph.py:
--------------------------------------------------------------------------------
1 | """Module for training the Graph Weather forecaster model using PyTorch Lightning."""
2 |
3 | import pytorch_lightning as pl
4 | import torch
5 | from pytorch_lightning import Trainer
6 | from torch.utils.data import DataLoader, Dataset
7 |
8 | from graph_weather import GraphWeatherForecaster
9 |
10 | lat_lons = []
11 | for lat in range(-90, 90, 1):
12 | for lon in range(0, 360, 1):
13 | lat_lons.append((lat, lon))
14 |
15 |
16 | class LitModel(pl.LightningModule):
17 | """
18 | LightningModule for the weather forecasting model.
19 |
20 | Args:
21 | lat_lons: List of latitude and longitude coordinates.
22 | feature_dim: Dimension of the input features.
23 | aux_dim : Dimension of the auxiliary features.
24 |
25 |     Methods:
26 |         __init__, forward, training_step, configure_optimizers.
27 | """
28 |
29 | def __init__(self, lat_lons, feature_dim, aux_dim):
30 | """
31 | Initialize the LitModel object.
32 |
33 | Args:
34 | lat_lons: List of latitude and longitude coordinates.
35 | feature_dim : Dimension of the input features.
36 | aux_dim : Dimension of the auxiliary features.
37 | """
38 | super().__init__()
39 | self.model = GraphWeatherForecaster(
40 | lat_lons=lat_lons, feature_dim=feature_dim, aux_dim=aux_dim
41 | )
42 |
43 | def training_step(self, batch):
44 | """
45 | Performs a training step.
46 |
47 | Args:
48 | batch: A batch of training data.
49 |
50 | Returns:
51 | The computed loss.
52 | """
53 | x, y = batch
54 | x = x.half()
55 | y = y.half()
56 | out = self.forward(x)
57 | criterion = torch.nn.MSELoss()
58 | loss = criterion(out, y)
59 | return loss
60 |
61 | def configure_optimizers(self):
62 | """
63 | Configures the optimizer used during training.
64 |
65 | Returns:
66 | The optimizer.
67 | """
68 | return torch.optim.AdamW(self.parameters())
69 |
70 | def forward(self, x):
71 | """
72 | Forward pass.
73 |
74 | Args:
75 | x (torch.Tensor): Input data.
76 |
77 | Returns:
78 | torch.Tensor: Output of the model.
79 | """
80 | return self.model(x)
81 |
82 |
83 | class FakeDataset(Dataset):
84 | """
85 | Dataset class for generating fake data.
86 |
87 | Methods:
88 | __init__: Initialize the FakeDataset object.
89 | __len__: Return the length of the dataset.
90 | __getitem__: Get an item from the dataset.
91 | """
92 |
93 | def __init__(self):
94 | """
95 | Initialize the FakeDataset object.
96 | """
97 |         super().__init__()
98 |
99 | def __len__(self):
100 |         return 64000  # arbitrary number of fake samples
101 |
102 | def __getitem__(self, item):
103 |         return torch.randn((64800, 605 + 32)), torch.randn((64800, 605))  # (nodes, feat + aux), (nodes, feat)
104 |
105 |
106 | model = LitModel(lat_lons=lat_lons, feature_dim=605, aux_dim=32)
107 | trainer = Trainer(
108 | accelerator="gpu",
109 | devices=1,
110 | strategy="deepspeed_stage_3_offload",
111 | precision=16,
112 | max_epochs=10,
113 | limit_train_batches=2000,
114 | )
115 | dataset = FakeDataset()
116 | train_dataloader = DataLoader(
117 | dataset, batch_size=1, num_workers=1, pin_memory=True, prefetch_factor=1
118 | )
119 | trainer.fit(model=model, train_dataloaders=train_dataloader)
120 |
--------------------------------------------------------------------------------
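
A quick sanity check on the fake-data shapes used above: the 1-degree grid has 180 latitudes (-90..89) and 360 longitudes (0..359), i.e. 64,800 nodes per sample, each carrying feature_dim (605) plus aux_dim (32) input features:

    lat_lons = [(lat, lon) for lat in range(-90, 90) for lon in range(0, 360)]
    assert len(lat_lons) == 180 * 360 == 64_800  # matches the (64800, 605 + 32) fake input
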
/train/era5.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 | import pytorch_lightning as pl
5 | import torch
6 | import xarray
7 | from einops import rearrange
8 | from pytorch_lightning.callbacks import ModelCheckpoint
9 | from torch.utils.data import DataLoader, Dataset
10 |
11 | from graph_weather.models import MetaModel
12 | from graph_weather.models.losses import NormalizedMSELoss
13 |
14 |
15 | class LitFengWuGHR(pl.LightningModule):
16 | """
17 | LightningModule for graph-based weather forecasting.
18 |
19 | Attributes:
20 |         model (MetaModel): FengWu-GHR style MetaModel forecaster.
21 | criterion (NormalizedMSELoss): Loss criterion for training.
22 | lr : Learning rate for optimizer.
23 |
24 | Methods:
25 | __init__: Initialize the LitFengWuGHR object.
26 | forward: Forward pass of the model.
27 | training_step: Training step.
28 | configure_optimizers: Configure the optimizer for training.
29 | """
30 |
31 | def __init__(
32 | self,
33 | lat_lons: list,
34 | *,
35 | channels: int,
36 | image_size,
37 | patch_size=4,
38 | depth=5,
39 | heads=4,
40 | mlp_dim=5,
41 | feature_dim: int = 605, # TODO where does this come from?
42 | lr: float = 3e-4,
43 | ):
44 | """
45 | Initialize the LitFengWuGHR object with the required args.
46 |
47 | Args:
48 | lat_lons : List of latitude and longitude values.
49 |             channels : Number of input channels (weather variables).
50 |             image_size : Spatial size (height, width) of the input grid.
51 |             patch_size, depth, heads, mlp_dim : MetaModel transformer hyperparameters.
52 |             feature_dim : Dimensionality of the features seen by the loss.
53 | lr (float): Learning rate for optimizer.
54 | """
55 | super().__init__()
56 | self.model = MetaModel(
57 | lat_lons,
58 | image_size=image_size,
59 | patch_size=patch_size,
60 | depth=depth,
61 | heads=heads,
62 | mlp_dim=mlp_dim,
63 | channels=channels,
64 | )
65 | self.criterion = NormalizedMSELoss(
66 | lat_lons=lat_lons, feature_variance=np.ones((feature_dim,))
67 | )
68 | self.lr = lr
69 | self.save_hyperparameters()
70 |
71 | def forward(self, x):
72 | """
73 |         Forward pass.
74 |
75 | Args:
76 | x (torch.Tensor): Input tensor.
77 |
78 | Returns:
79 | torch.Tensor: Output tensor.
80 | """
81 | return self.model(x)
82 |
83 | def training_step(self, batch, batch_idx):
84 | """
85 | Training step.
86 |
87 | Args:
88 | batch (array): Batch of data containing input and output tensors.
89 | batch_idx (int): Index of the current batch.
90 |
91 | Returns:
92 | torch.Tensor: Loss tensor.
93 | """
94 | x, y = batch[:, 0], batch[:, 1]
95 | if torch.isnan(x).any() or torch.isnan(y).any():
96 | return None
97 | y_hat = self.forward(x)
98 | loss = self.criterion(y_hat, y)
99 | self.log("loss", loss, prog_bar=True)
100 | return loss
101 |
102 | def configure_optimizers(self):
103 | """
104 | Configure the optimizer.
105 |
106 | Returns:
107 | torch.optim.Optimizer: Optimizer instance.
108 | """
109 | return torch.optim.AdamW(self.parameters(), lr=self.lr)
110 |
111 |
112 | class Era5Dataset(Dataset):
113 | """Era5 dataset."""
114 |
115 | def __init__(self, xarr, transform=None):
116 | """
117 |         Arguments:
118 |             xarr: xarray Dataset; to_array() turns it into a (C, T, H, W) cube. transform is currently unused.
119 | """
120 | ds = np.asarray(xarr.to_array())
121 | ds = torch.from_numpy(ds)
122 |         ds -= ds.min(0, keepdim=True)[0]  # shift so the minimum over dim 0 is zero...
123 |         ds /= ds.max(0, keepdim=True)[0]  # ...then scale values into [0, 1]
124 | ds = rearrange(ds, "C T H W -> T (H W) C")
125 | self.ds = ds
126 |
127 | def __len__(self):
128 | return len(self.ds) - 1
129 |
130 | def __getitem__(self, index):
131 | return self.ds[index : index + 2]
132 |
133 |
134 | if __name__ == "__main__":
135 | ckpt_path = Path("./checkpoints")
136 | patch_size = 4
137 | grid_step = 20
138 | variables = [
139 | "2m_temperature",
140 | "surface_pressure",
141 | "10m_u_component_of_wind",
142 | "10m_v_component_of_wind",
143 | ]
144 |
145 | channels = len(variables)
146 | ckpt_path.mkdir(parents=True, exist_ok=True)
147 |
148 | reanalysis = xarray.open_zarr(
149 | "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
150 | storage_options=dict(token="anon"),
151 | )
152 |
153 | reanalysis = reanalysis.sel(time=slice("2020-01-01", "2021-01-01"))
154 | reanalysis = reanalysis.isel(
155 | time=slice(100, 107), longitude=slice(0, 1440, grid_step), latitude=slice(0, 721, grid_step)
156 | )
157 |
158 | reanalysis = reanalysis[variables]
159 | print(f"size: {reanalysis.nbytes / (1024**3)} GiB")
160 |
161 | lat_lons = np.array(
162 | np.meshgrid(
163 | np.asarray(reanalysis["latitude"]).flatten(),
164 | np.asarray(reanalysis["longitude"]).flatten(),
165 | )
166 | ).T.reshape((-1, 2))
167 |
168 | checkpoint_callback = ModelCheckpoint(dirpath=ckpt_path, save_top_k=1, monitor="loss")
169 |
170 | dset = DataLoader(Era5Dataset(reanalysis), batch_size=10, num_workers=8)
171 | model = LitFengWuGHR(
172 | lat_lons=lat_lons,
173 | channels=channels,
174 | image_size=(721 // grid_step, 1440 // grid_step),
175 | patch_size=patch_size,
176 | depth=5,
177 | heads=4,
178 | mlp_dim=5,
179 | )
180 | trainer = pl.Trainer(
181 | accelerator="gpu",
182 | devices=-1,
183 | max_epochs=100,
184 | precision="16-mixed",
185 | callbacks=[checkpoint_callback],
186 | log_every_n_steps=3,
187 | )
188 |
189 | trainer.fit(model, dset)
190 |
191 | torch.save(model.model.state_dict(), ckpt_path / "best.pt")
192 |
--------------------------------------------------------------------------------
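
The Era5Dataset above yields overlapping two-step windows, so training_step can split each batch into the input state batch[:, 0] and the next-step target batch[:, 1]. A toy illustration of the windowing:

    import torch

    ds = torch.arange(4).reshape(4, 1, 1)  # stand-in for (T, H*W, C) data
    item = ds[2 : 2 + 2]                   # __getitem__(2) -> time steps 2 and 3
    x, y = item[0], item[1]                # input state, next-step target
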
/train/lora.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 | import pytorch_lightning as pl
5 | import torch
6 | import torch.nn as nn
7 | import xarray
8 | from einops import rearrange
9 | from pytorch_lightning.callbacks import ModelCheckpoint
10 | from torch.utils.data import DataLoader, Dataset
11 |
12 | from graph_weather.models import LoRAModule, MetaModel
13 | from graph_weather.models.losses import NormalizedMSELoss
14 |
15 |
16 | class LitLoRAFengWuGHR(pl.LightningModule):
17 | def __init__(
18 | self,
19 | lat_lons: list,
20 | single_step_model_state_dict: dict,
21 | *,
22 | time_step: int,
23 | rank: int,
24 | channels: int,
25 | image_size,
26 | patch_size=4,
27 | depth=5,
28 | heads=4,
29 | mlp_dim=5,
30 | feature_dim: int = 605, # TODO where does this come from?
31 | lr: float = 3e-4,
32 | ):
33 | super().__init__()
34 | assert (
35 | time_step > 1
36 | ), "Time step must be greater than 1. Remember that 1 is the simple model time step."
37 | ssmodel = MetaModel(
38 | lat_lons,
39 | image_size=image_size,
40 | patch_size=patch_size,
41 | depth=depth,
42 | heads=heads,
43 | mlp_dim=mlp_dim,
44 | channels=channels,
45 | )
46 | ssmodel.load_state_dict(single_step_model_state_dict)
47 | self.models = nn.ModuleList(
48 | [ssmodel] + [LoRAModule(ssmodel, r=rank) for _ in range(2, time_step + 1)]
49 | )
50 | self.criterion = NormalizedMSELoss(
51 | lat_lons=lat_lons, feature_variance=np.ones((feature_dim,))
52 | )
53 | self.lr = lr
54 | self.save_hyperparameters()
55 |
56 | def forward(self, x):
57 | ys = []
58 |         for model in self.models:
59 | x = model(x)
60 | ys.append(x)
61 | return torch.stack(ys, dim=1)
62 |
63 | def training_step(self, batch, batch_idx):
64 | if torch.isnan(batch).any():
65 | return None
66 | x, ys = batch[:, 0, ...], batch[:, 1:, ...]
67 |
68 | y_hat = self.forward(x)
69 | loss = self.criterion(y_hat, ys)
70 | self.log("loss", loss, prog_bar=True)
71 | return loss
72 |
73 | def configure_optimizers(self):
74 | return torch.optim.AdamW(self.parameters(), lr=self.lr)
75 |
76 |
77 | class Era5Dataset(Dataset):
78 | def __init__(self, xarr, time_step=1, transform=None):
79 | assert time_step > 0, "Time step must be greater than 0."
80 | ds = np.asarray(xarr.to_array())
81 | ds = torch.from_numpy(ds)
82 | ds -= ds.min(0, keepdim=True)[0]
83 | ds /= ds.max(0, keepdim=True)[0]
84 | ds = rearrange(ds, "C T H W -> T (H W) C")
85 | self.ds = ds
86 | self.time_step = time_step
87 |
88 | def __len__(self):
89 | return len(self.ds) - self.time_step
90 |
91 | def __getitem__(self, index):
92 |         return self.ds[index : index + self.time_step + 1]
93 |
94 |
95 | if __name__ == "__main__":
96 | ckpt_path = Path("./checkpoints")
97 | ckpt_name = "best.pt"
98 | patch_size = 4
99 | grid_step = 20
100 | time_step = 2
101 | rank = 4
102 | variables = [
103 | "2m_temperature",
104 | "surface_pressure",
105 | "10m_u_component_of_wind",
106 | "10m_v_component_of_wind",
107 | ]
108 |
109 | ###############################################################
110 |
111 | channels = len(variables)
112 | ckpt_path.mkdir(parents=True, exist_ok=True)
113 |
114 | reanalysis = xarray.open_zarr(
115 | "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
116 | storage_options=dict(token="anon"),
117 | )
118 |
119 | reanalysis = reanalysis.sel(time=slice("2020-01-01", "2021-01-01"))
120 | reanalysis = reanalysis.isel(
121 | time=slice(100, 111), longitude=slice(0, 1440, grid_step), latitude=slice(0, 721, grid_step)
122 | )
123 |
124 | reanalysis = reanalysis[variables]
125 | print(f"size: {reanalysis.nbytes / (1024**3)} GiB")
126 |
127 | lat_lons = np.array(
128 | np.meshgrid(
129 | np.asarray(reanalysis["latitude"]).flatten(),
130 | np.asarray(reanalysis["longitude"]).flatten(),
131 | )
132 | ).T.reshape((-1, 2))
133 |
134 | checkpoint_callback = ModelCheckpoint(dirpath=ckpt_path, save_top_k=1, monitor="loss")
135 |
136 | dset = DataLoader(Era5Dataset(reanalysis, time_step=time_step), batch_size=10, num_workers=8)
137 |
138 | single_step_model_state_dict = torch.load(ckpt_path / ckpt_name)
139 |
140 | model = LitLoRAFengWuGHR(
141 | lat_lons=lat_lons,
142 | single_step_model_state_dict=single_step_model_state_dict,
143 | time_step=time_step,
144 | rank=rank,
145 | ##########
146 | channels=channels,
147 | image_size=(721 // grid_step, 1440 // grid_step),
148 | patch_size=patch_size,
149 | depth=5,
150 | heads=4,
151 | mlp_dim=5,
152 | )
153 | trainer = pl.Trainer(
154 | accelerator="gpu",
155 | devices=-1,
156 | max_epochs=100,
157 | precision="16-mixed",
158 | callbacks=[checkpoint_callback],
159 | log_every_n_steps=3,
160 | strategy="ddp_find_unused_parameters_true",
161 | )
162 |
163 | trainer.fit(model, dset)
164 |
--------------------------------------------------------------------------------
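
LitLoRAFengWuGHR.forward above rolls the state forward autoregressively: the pretrained single-step model predicts lead time 1 and each LoRAModule handles one further step, with the per-step predictions stacked along a new time dimension. A minimal stand-in sketch of that loop (nn.Identity replaces the real models):

    import torch
    import torch.nn as nn

    models = nn.ModuleList([nn.Identity() for _ in range(3)])  # [ssmodel, lora_2, lora_3]
    x = torch.randn(2, 8)                                      # toy state
    ys = []
    for model in models:
        x = model(x)   # advance one time step
        ys.append(x)
    out = torch.stack(ys, dim=1)                               # (batch, time_step, ...)
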
/train/run_fulll.py:
--------------------------------------------------------------------------------
1 | """Training script for training the weather forecasting model"""
2 |
3 | import json
4 | import os
5 | import sys
6 | import time
7 |
8 | import numpy as np
9 | import torch
10 | import torch.optim as optim
11 | import torchvision.transforms as transforms
12 | import xarray as xr
13 | from torch.utils.data import DataLoader, Dataset
14 |
15 | from graph_weather import GraphWeatherForecaster
16 | from graph_weather.data import const
17 | from graph_weather.models.losses import NormalizedMSELoss
18 |
19 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
20 | sys.path.append(BASE_DIR)
21 |
22 |
23 | class XrDataset(Dataset):
24 | """
25 | Dataset class for loading data from Hugging Face datasets.
26 |
27 | Attributes:
28 | filepaths : List of file paths to the data.
29 | data : Dataset containing the loaded data.
30 |
31 | Methods:
32 | __init__: Initialize the XrDataset object by loading data from Hugging Face datasets.
33 | __len__: Get the length of the dataset.
34 | __getitem__: Get an item from the dataset by index.
35 | """
36 |
37 | def __init__(self):
38 | """
39 | Initialize the XrDataset object by loading data from Hugging Face datasets.
40 | """
41 | super().__init__()
42 | with open("hf_forecasts.json", "r") as f:
43 | files = json.load(f)
44 | self.filepaths = [
45 | "zip:///::https://huggingface.co/datasets/openclimatefix/gfs-reforecast/resolve/main/"
46 | + f
47 | for f in files
48 | ]
49 | self.data = xr.open_mfdataset(
50 | self.filepaths, engine="zarr", concat_dim="reftime", combine="nested"
51 | ).sortby("reftime")
52 |
53 | def __len__(self):
54 | return len(self.filepaths)
55 |
56 | def __getitem__(self, item):
57 | start_idx = np.random.randint(0, 14)
58 | data = self.data.isel(reftime=item, time=slice(start_idx, start_idx + 2))
59 |
60 | start = data.isel(time=0)
61 | end = data.isel(time=1)
62 | # Stack the data into a large data cube
63 | input_data = np.stack(
64 | [
65 | (start[f"{var}"].values - const.FORECAST_MEANS[f"{var}"])
66 | / (const.FORECAST_STD[f"{var}"] + 0.0001)
67 | for var in start.data_vars
68 | if "mb" in var or "surface" in var
69 | ],
70 | axis=-1,
71 | )
72 | input_data = np.nan_to_num(input_data)
73 | assert not np.isnan(input_data).any()
74 | output_data = np.stack(
75 | [
76 | (end[f"{var}"].values - const.FORECAST_MEANS[f"{var}"])
77 | / (const.FORECAST_STD[f"{var}"] + 0.0001)
78 | for var in end.data_vars
79 | if "mb" in var or "surface" in var
80 | ],
81 | axis=-1,
82 | )
83 | output_data = np.nan_to_num(output_data)
84 | assert not np.isnan(output_data).any()
85 | transform = transforms.Compose([transforms.ToTensor()])
86 |         # To tensors; flatten the (H, W) grid into a node dimension
87 | return (
88 | transform(input_data).transpose(0, 1).reshape(-1, input_data.shape[-1]),
89 |             transform(output_data).transpose(0, 1).reshape(-1, output_data.shape[-1]),
90 | )
91 |
92 |
93 | with open("hf_forecasts.json", "r") as f:
94 | files = json.load(f)
95 | files = [
96 | "zip:///::https://huggingface.co/datasets/openclimatefix/gfs-reforecast/resolve/main/" + f
97 | for f in files
98 | ]
99 | data = (
100 | xr.open_zarr(files[0], consolidated=True).isel(time=0)
101 | # .coarsen(latitude=8, boundary="pad")
102 | # .mean()
103 | # .coarsen(longitude=8)
104 | # .mean()
105 | )
106 | print(data)
107 | # print("Done coarsening")
108 | lat_lons = np.array(np.meshgrid(data.latitude.values, data.longitude.values)).T.reshape(-1, 2)
109 |
110 | if torch.cuda.is_available():
111 | device = "cuda"
112 | elif torch.backends.mps.is_available():
113 | device = "mps"
114 | else:
115 | device = "cpu"
116 |
117 | # Get the variance of the variables
118 | feature_variances = []
119 | for var in data.data_vars:
120 | if "mb" in var or "surface" in var:
121 | feature_variances.append(const.FORECAST_DIFF_STD[var] ** 2)
122 | criterion = NormalizedMSELoss(
123 | lat_lons=lat_lons, feature_variance=feature_variances, device=device
124 | ).to(device)
125 | means = []
126 | dataset = DataLoader(XrDataset(), batch_size=1)
127 | model = GraphWeatherForecaster(lat_lons, feature_dim=597, num_blocks=6).to(device)
128 | optimizer = optim.AdamW(model.parameters(), lr=0.001)
129 | print("Done Setup")
130 |
131 |
132 | for epoch in range(100): # loop over the dataset multiple times
133 | running_loss = 0.0
134 | start = time.time()
135 | print(f"Start Epoch: {epoch}")
136 |     for i, batch in enumerate(dataset):
137 |         # get the inputs; batch is a list of [inputs, labels]
138 |         inputs, labels = batch[0].to(device), batch[1].to(device)
139 | # zero the parameter gradients
140 | optimizer.zero_grad()
141 |
142 | # forward + backward + optimize
143 | outputs = model(inputs)
144 |
145 | loss = criterion(outputs, labels)
146 | loss.backward()
147 | optimizer.step()
148 |
149 | # print statistics
150 | running_loss += loss.item()
151 | end = time.time()
152 | print(
153 | f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / (i + 1):.3f} Time: {end - start} sec"
154 | )
155 | if epoch % 5 == 0:
156 | assert not np.isnan(running_loss)
157 | model.push_to_hub(
158 | "graph-weather-forecaster-2.0deg",
159 | organization="openclimatefix",
160 | commit_message=f"Add model Epoch={epoch}",
161 | )
162 |
163 | print("Finished Training")
164 |
--------------------------------------------------------------------------------
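
For reference, the per-variable standardisation inside XrDataset.__getitem__, written out numerically (the mean/std here are made up; the script reads the real values from graph_weather.data.const):

    import numpy as np

    values = np.array([290.0, 295.0, 300.0])  # e.g. a temperature field, in K
    mean, std = 293.0, 5.0                    # hypothetical per-variable stats
    z = (values - mean) / (std + 1e-4)        # the 1e-4 guards against zero std
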