├── .all-contributorsrc ├── .bumpversion.cfg ├── .github └── workflows │ ├── release.yaml │ └── workflows.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── Dockerfile ├── LICENSE ├── MANIFEST.in ├── README.md ├── environment_cpu.yml ├── environment_cuda.yml ├── graph_weather ├── __init__.py ├── data │ ├── IFSAnalysis_dataloader.py │ ├── __init__.py │ ├── const.py │ ├── dataloader.py │ ├── gencast_dataloader.py │ ├── nnja_ai.py │ └── weather_station_reader.py └── models │ ├── __init__.py │ ├── analysis.py │ ├── aurora │ ├── __init__.py │ ├── decoder.py │ ├── encoder.py │ ├── model.py │ └── processor.py │ ├── fengwu_ghr │ ├── __init__.py │ └── layers.py │ ├── forecast.py │ ├── gencast │ ├── README.md │ ├── __init__.py │ ├── denoiser.py │ ├── graph │ │ ├── __init__.py │ │ ├── graph_builder.py │ │ ├── grid_mesh_connectivity.py │ │ ├── icosahedral_mesh.py │ │ └── model_utils.py │ ├── images │ │ ├── animated.gif │ │ ├── autoregressive.gif │ │ ├── fullmodel.png │ │ └── readme.md │ ├── layers │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ ├── experimental │ │ │ ├── __init__.py │ │ │ └── sparse_transformer.py │ │ ├── modules.py │ │ └── processor.py │ ├── sampler.py │ ├── train.py │ ├── utils │ │ ├── __init__.py │ │ ├── batching.py │ │ ├── noise.py │ │ └── statistics.py │ └── weighted_mse_loss.py │ ├── layers │ ├── __init__.py │ ├── assimilator_decoder.py │ ├── assimilator_encoder.py │ ├── constraint_layer.py │ ├── decoder.py │ ├── encoder.py │ ├── graph_net_block.py │ ├── grid_to_points.py │ ├── points_to_grid.py │ └── processor.py │ ├── losses.py │ └── weathermesh │ ├── __init__.py │ ├── decoder.py │ ├── encoder.py │ ├── layers.py │ ├── processor.py │ └── weathermesh2.py ├── pyproject.toml ├── setup.py ├── tests ├── test_aurora.py ├── test_gencast.py ├── test_model.py ├── test_nnjai.py ├── test_weather_station_reader.py └── test_weathermesh.py └── train ├── deepspeed_graph.py ├── era5.py ├── hf_forecasts.json ├── lora.py ├── pl_graph_weather.py ├── run.py └── run_fulll.py /.all-contributorsrc: -------------------------------------------------------------------------------- 1 | { 2 | "files": [ 3 | "README.md" 4 | ], 5 | "imageSize": 100, 6 | "commit": false, 7 | "commitConvention": "angular", 8 | "contributors": [ 9 | { 10 | "login": "jacobbieker", 11 | "name": "Jacob Bieker", 12 | "avatar_url": "https://avatars.githubusercontent.com/u/7170359?v=4", 13 | "profile": "https://www.jacobbieker.com", 14 | "contributions": [ 15 | "code" 16 | ] 17 | }, 18 | { 19 | "login": "JackKelly", 20 | "name": "Jack Kelly", 21 | "avatar_url": "https://avatars.githubusercontent.com/u/460756?v=4", 22 | "profile": "http://jack-kelly.com", 23 | "contributions": [ 24 | "ideas" 25 | ] 26 | }, 27 | { 28 | "login": "byphilipp", 29 | "name": "byphilipp", 30 | "avatar_url": "https://avatars.githubusercontent.com/u/59995258?v=4", 31 | "profile": "https://github.com/byphilipp", 32 | "contributions": [ 33 | "ideas" 34 | ] 35 | }, 36 | { 37 | "login": "paapu88", 38 | "name": "Markus Kaukonen", 39 | "avatar_url": "https://avatars.githubusercontent.com/u/6195764?v=4", 40 | "profile": "http://iki.fi/markus.kaukonen", 41 | "contributions": [ 42 | "question" 43 | ] 44 | }, 45 | { 46 | "login": "MoHawastaken", 47 | "name": "MoHawastaken", 48 | "avatar_url": "https://avatars.githubusercontent.com/u/55447473?v=4", 49 | "profile": "https://github.com/MoHawastaken", 50 | "contributions": [ 51 | "bug" 52 | ] 53 | }, 54 | { 55 | "login": "mishooax", 56 | "name": "Mihai", 57 | "avatar_url": 
"https://avatars.githubusercontent.com/u/47196359?v=4", 58 | "profile": "http://www.ecmwf.int", 59 | "contributions": [ 60 | "question" 61 | ] 62 | }, 63 | { 64 | "login": "vitusbenson", 65 | "name": "Vitus Benson", 66 | "avatar_url": "https://avatars.githubusercontent.com/u/33334860?v=4", 67 | "profile": "https://github.com/vitusbenson", 68 | "contributions": [ 69 | "bug" 70 | ] 71 | }, 72 | { 73 | "login": "dongZheX", 74 | "name": "dongZheX", 75 | "avatar_url": "https://avatars.githubusercontent.com/u/36361726?v=4", 76 | "profile": "https://github.com/dongZheX", 77 | "contributions": [ 78 | "question" 79 | ] 80 | }, 81 | { 82 | "login": "sabbir2331", 83 | "name": "sabbir2331", 84 | "avatar_url": "https://avatars.githubusercontent.com/u/25061297?v=4", 85 | "profile": "https://github.com/sabbir2331", 86 | "contributions": [ 87 | "question" 88 | ] 89 | }, 90 | { 91 | "login": "rnwzd", 92 | "name": "Lorenzo Breschi", 93 | "avatar_url": "https://avatars.githubusercontent.com/u/58804597?v=4", 94 | "profile": "https://github.com/rnwzd", 95 | "contributions": [ 96 | "code" 97 | ] 98 | }, 99 | { 100 | "login": "gbruno16", 101 | "name": "gbruno16", 102 | "avatar_url": "https://avatars.githubusercontent.com/u/72879691?v=4", 103 | "profile": "https://github.com/gbruno16", 104 | "contributions": [ 105 | "code" 106 | ] 107 | } 108 | ], 109 | "contributorsPerLine": 7, 110 | "projectName": "graph_weather", 111 | "projectOwner": "openclimatefix", 112 | "repoType": "github", 113 | "repoHost": "https://github.com", 114 | "skipCi": true, 115 | "commitType": "docs" 116 | } 117 | -------------------------------------------------------------------------------- /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | commit = True 3 | tag = False 4 | current_version = 1.0.108 5 | message = Bump version: {current_version} → {new_version} [skip ci] 6 | 7 | [bumpversion:file:setup.py] 8 | search = version="{current_version}" 9 | replace = version="{new_version}" 10 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Bump version and auto-release 2 | on: 3 | push: 4 | branches: 5 | - main 6 | 7 | jobs: 8 | bump-version-python-docker-release: 9 | uses: openclimatefix/.github/.github/workflows/python-docker-release.yml@main 10 | secrets: 11 | PAT_TOKEN: ${{ secrets.PAT_TOKEN }} 12 | token: ${{ secrets.PYPI_API_TOKEN }} 13 | DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} 14 | DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }} 15 | with: 16 | image_base_name: graph_weather 17 | docker_file: Dockerfile 18 | -------------------------------------------------------------------------------- /.github/workflows/workflows.yaml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: 4 | push: 5 | pull_request: 6 | types: [opened, reopened] 7 | jobs: 8 | pytest: 9 | runs-on: ${{ matrix.os }} 10 | 11 | strategy: 12 | fail-fast: true 13 | matrix: 14 | os: [ubuntu-latest, macos-latest] 15 | python-version: ["3.11", "3.12"] 16 | torch-version: [2.4.0] 17 | include: 18 | - torch-version: 2.4.0 19 | torchvision-version: 0.19.0 20 | steps: 21 | - uses: prefix-dev/setup-pixi@v0.8.3 22 | with: 23 | pixi-version: v0.41.4 24 | cache: true 25 | - run: pixi run test 26 | - uses: actions/checkout@v2 27 | - name: Set up Python ${{ matrix.python-version }} 28 | 
        uses: actions/setup-python@v2
29 |         with:
30 |           python-version: ${{ matrix.python-version }}
31 | 
32 |       - name: Install PyTorch ${{ matrix.torch-version }}+cpu
33 |         run: |
34 |           pip install torch==${{ matrix.torch-version}} torchvision==${{ matrix.torchvision-version}} --index-url https://download.pytorch.org/whl/cpu
35 |       - name: Install internal dependencies
36 |         run: |
37 |           pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${{ matrix.torch-version}}+cpu.html
38 |           if [ ${{ matrix.torch-version}} == 2.4.0 ]; then pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/repo.html; fi
39 |       - name: Install main package
40 |         run: |
41 |           pip install -e .
42 |           pip install pytest-xdist
43 |       - name: Setup with pytest-xdist
44 |         run: |
45 |           # let's get the string for how many CPUs to use with pytest
46 |           echo "Will be using ${{ inputs.pytest_numcpus }} cpus for pytest testing"
47 |           #
48 |           # make PYTESTXDIST
49 |           export PYTESTXDIST="-n 2"
50 |           if [ 2 -gt 0 ]; then export PYTESTXDIST="$PYTESTXDIST --dist=loadfile"; fi
51 |           #
52 |           # echo results and save env var for other jobs
53 |           echo "pytest-xdist options that will be used are: $PYTESTXDIST"
54 |           echo "PYTESTXDIST=$PYTESTXDIST" >> $GITHUB_ENV
55 |       - name: Setup with pytest-cov
56 |         run: |
57 |           # let's make pytest run with coverage
58 |           echo "Will be looking at coverage of dir graph_weather"
59 |           #
60 |           # install pytest-cov
61 |           pip install coverage==7.4.3
62 |           pip install pytest-cov
63 |           #
64 |           # make PYTESTCOV
65 |           export PYTESTCOV="--cov=graph_weather tests/ --cov-report=xml"
66 |           # echo results and save env var for other jobs
67 |           echo "pytest-cov options that will be used are: $PYTESTCOV"
68 |           echo "PYTESTCOV=$PYTESTCOV" >> $GITHUB_ENV
69 |       - name: Run pytest
70 |         run: |
71 |           # import dgl to initialize backend
72 |           if [ ${{ matrix.torch-version}} == 2.4.0 ]; then python3 -c "import dgl"; fi
73 |           export PYTEST_COMMAND="pytest $PYTESTCOV $PYTESTXDIST -s"
74 |           echo "Will be running this command: $PYTEST_COMMAND"
75 |           eval $PYTEST_COMMAND
76 |       - name: "Upload coverage to Codecov"
77 |         uses: codecov/codecov-action@v2
78 |         with:
79 |           fail_ci_if_error: false
80 | 
-------------------------------------------------------------------------------- /.gitignore: --------------------------------------------------------------------------------
 1 | .idea/
 2 | *.nc
 3 | *.pt
 4 | *.pyc
 5 | .DS_Store
 6 | *.txt
 7 | # pixi environments
 8 | .pixi
 9 | .vscode/
10 | checkpoints/
11 | lightning_logs/
12 | 
-------------------------------------------------------------------------------- /.pre-commit-config.yaml: --------------------------------------------------------------------------------
 1 | default_language_version:
 2 |   python: python3
 3 | 
 4 | repos:
 5 |   - repo: https://github.com/pre-commit/pre-commit-hooks
 6 |     rev: v5.0.0
 7 |     hooks:
 8 |       # list of supported hooks: https://pre-commit.com/hooks.html
 9 |       - id: trailing-whitespace
10 |       - id: end-of-file-fixer
11 |       - id: debug-statements
12 |       - id: detect-private-key
13 | 
14 |   # python code formatting/linting
15 |   - repo: https://github.com/astral-sh/ruff-pre-commit
16 |     # Ruff version.
17 | rev: "v0.9.2" 18 | hooks: 19 | - id: ruff 20 | args: [--fix] 21 | - repo: https://github.com/psf/black 22 | rev: 24.10.0 23 | hooks: 24 | - id: black 25 | args: [--line-length, "100"] 26 | # yaml formatting 27 | - repo: https://github.com/pre-commit/mirrors-prettier 28 | rev: v4.0.0-alpha.8 29 | hooks: 30 | - id: prettier 31 | types: [yaml] 32 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:latest 2 | 3 | ENV CONDA_ENV_NAME=graph 4 | ENV PYTHON_VERSION=3.12 5 | 6 | # Basic setup 7 | RUN apt update && apt install -y bash \ 8 | build-essential \ 9 | git \ 10 | curl \ 11 | ca-certificates \ 12 | wget \ 13 | libaio-dev \ 14 | && rm -rf /var/lib/apt/lists 15 | 16 | # Install Miniconda and create main env 17 | ADD https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh miniconda3.sh 18 | RUN /bin/bash miniconda3.sh -b -p /conda \ 19 | && echo export PATH=/conda/bin:$PATH >> .bashrc \ 20 | && rm miniconda3.sh 21 | ENV PATH="/conda/bin:${PATH}" 22 | 23 | RUN git clone https://github.com/openclimatefix/graph_weather.git && mv graph_weather/ gw/ && cd gw/ && mv * .. && rm -rf gw/ 24 | 25 | # Copy the appropriate environment file based on CUDA availability 26 | COPY environment_cpu.yml /tmp/environment_cpu.yml 27 | COPY environment_cuda.yml /tmp/environment_cuda.yml 28 | 29 | RUN conda update -n base -c defaults conda 30 | 31 | # Check if CUDA is available and accordingly choose env 32 | RUN cuda=$(command -v nvcc > /dev/null && echo "true" || echo "false") \ 33 | && if [ "$cuda" == "true" ]; then conda env create -f /tmp/environment_cuda.yml; else conda env create -f /tmp/environment_cpu.yml; fi 34 | 35 | # Switch to bash shell 36 | SHELL ["/bin/bash", "-c"] 37 | 38 | # Set ${CONDA_ENV_NAME} to default virutal environment 39 | RUN echo "source activate ${CONDA_ENV_NAME}" >> ~/.bashrc 40 | 41 | # Cp in the development directory and install 42 | RUN source activate ${CONDA_ENV_NAME} && pip install -e . 43 | 44 | 45 | # Make RUN commands use the new environment: 46 | SHELL ["conda", "run", "-n", "graph", "/bin/bash", "-c"] 47 | 48 | # Example command that can be used, need to set API_KEY, API_SECRET and SAVE_DIR 49 | CMD ["conda", "run", "-n", "graph", "python", "-u", "train/pl_graph_weather.py", "--gpus", "16", "--hidden", "64", "--num-blocks", "3", "--batch", "16"] 50 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Open Climate Fix 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
-------------------------------------------------------------------------------- /MANIFEST.in: --------------------------------------------------------------------------------
1 | include *.txt
2 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
  1 | # Graph Weather
  2 | 
  3 | [![All Contributors](https://img.shields.io/badge/all_contributors-11-orange.svg?style=flat-square)](#contributors-)
  4 | 
  5 | Implementation of the Graph Weather paper (https://arxiv.org/pdf/2202.07575.pdf) in PyTorch. It additionally includes an implementation
  6 | of a modified model that assimilates raw or processed observations into analysis files.
  7 | 
  8 | 
  9 | ## Installation
 10 | 
 11 | This library can be installed through
 12 | 
 13 | ```bash
 14 | pip install graph-weather
 15 | ```
 16 | 
 17 | Alternatively, you can install the latest version from the repository easily with `pixi`:
 18 | 
 19 | ```bash
 20 | pixi install # `-e cuda` for GPU support, `-e cpu` for CPU-only
 21 | ```
 22 | 
 23 | ## Example Usage
 24 | 
 25 | The models generate the graphs internally, so the only thing that needs to be passed to the model is the node features
 26 | in the same order as the ```lat_lons```.
 27 | 
 28 | ```python
 29 | import torch
 30 | from graph_weather import GraphWeatherForecaster
 31 | from graph_weather.models.losses import NormalizedMSELoss
 32 | 
 33 | lat_lons = []
 34 | for lat in range(-90, 90, 1):
 35 |     for lon in range(0, 360, 1):
 36 |         lat_lons.append((lat, lon))
 37 | model = GraphWeatherForecaster(lat_lons)
 38 | 
 39 | # Generate 78 random features + 24 non-NWP features (e.g. land-sea mask)
 40 | features = torch.randn((2, len(lat_lons), 102))
 41 | 
 42 | target = torch.randn((2, len(lat_lons), 78))
 43 | out = model(features)
 44 | 
 45 | criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
 46 | loss = criterion(out, target)
 47 | loss.backward()
 48 | ```
 49 | 
 50 | The assimilation model assumes each lat/lon point also has a height above ground, and that each observation
 51 | is a single value plus the relative time. The assimilation model also assumes the desired output grid is given to it as
 52 | well.
 53 | 
 54 | ```python
 55 | import torch
 56 | import numpy as np
 57 | from graph_weather import GraphWeatherAssimilator
 58 | from graph_weather.models.losses import NormalizedMSELoss
 59 | 
 60 | obs_lat_lons = []
 61 | for lat in range(-90, 90, 7):
 62 |     for lon in range(0, 180, 6):
 63 |         obs_lat_lons.append((lat, lon, np.random.random(1)))
 64 |     for lon in 360 * np.random.random(100):
 65 |         obs_lat_lons.append((lat, lon, np.random.random(1)))
 66 | 
 67 | output_lat_lons = []
 68 | for lat in range(-90, 90, 5):
 69 |     for lon in range(0, 360, 5):
 70 |         output_lat_lons.append((lat, lon))
 71 | model = GraphWeatherAssimilator(output_lat_lons=output_lat_lons, analysis_dim=24)
 72 | 
 73 | features = torch.randn((1, len(obs_lat_lons), 2))
 74 | lat_lon_heights = torch.tensor(obs_lat_lons)
 75 | out = model(features, lat_lon_heights)
 76 | assert not torch.isnan(out).all()
 77 | assert out.size() == (1, len(output_lat_lons), 24)
 78 | 
 79 | criterion = torch.nn.MSELoss()
 80 | loss = criterion(out, torch.randn((1, len(output_lat_lons), 24)))
 81 | loss.backward()
 82 | ```
 83 | 
 84 | ## Pretrained Weights
 85 | Coming soon! We plan to train a model on GFS 0.25 degree operational forecasts, as well as MetOffice NWP forecasts.
 86 | We also plan to try out adaptive meshes, and to predict future satellite imagery as well.
 87 | 
 88 | ## Training Data
 89 | Training data will be available through HuggingFace Datasets for the GFS forecasts. The initial set of data is available for [GFSv16 forecasts, raw observations, and FNL Analysis files from 2016 to 2022](https://huggingface.co/datasets/openclimatefix/gfs-reforecast), and for [ERA5 Reanalysis](https://huggingface.co/datasets/openclimatefix/era5-reanalysis). We cannot redistribute
 90 | MetOffice NWP forecasts, but they can be accessed through [CEDA](https://data.ceda.ac.uk/).
 91 | 
 92 | ## Contributors ✨
 93 | 
 94 | Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)):
 95 | 
 96 | 97 | 98 | 99 | 100 | 101 | 102 | 
Jacob Bieker 💻 · Jack Kelly 🤔 · byphilipp 🤔 · Markus Kaukonen 💬 · MoHawastaken 🐛 · Mihai 💬 · Vitus Benson 🐛 · dongZheX 💬 · sabbir2331 💬 · Lorenzo Breschi 💻 · gbruno16 💻
118 | 119 | 120 | 121 | 122 | 123 | 124 | This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome! 125 | -------------------------------------------------------------------------------- /environment_cpu.yml: -------------------------------------------------------------------------------- 1 | name: graph 2 | channels: 3 | - pytorch 4 | - pyg 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - numcodecs 9 | - pandas 10 | - pip 11 | - pyg 12 | - python 13 | - pytorch 14 | - cpuonly 15 | - pytorch-cluster 16 | - pytorch-scatter 17 | - pytorch-sparse 18 | - pytorch-spline-conv 19 | - scikit-learn 20 | - scipy 21 | - torchvision 22 | - tqdm 23 | - xarray 24 | - zarr 25 | - h3-py 26 | - numpy 27 | - pyshtools 28 | - gcsfs 29 | - pytest 30 | - pip: 31 | - setuptools 32 | - datasets 33 | - einops 34 | - fsspec 35 | - torch-geometric-temporal 36 | - huggingface-hub 37 | - pysolar 38 | - pytorch-lightning 39 | - click 40 | - trimesh 41 | - rtree 42 | - torch-harmonics 43 | -------------------------------------------------------------------------------- /environment_cuda.yml: -------------------------------------------------------------------------------- 1 | name: graph 2 | channels: 3 | - pytorch 4 | - pyg 5 | - nvidia 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - pytorch-cuda 10 | - numcodecs 11 | - pandas 12 | - pip 13 | - pyg 14 | - python=3.12 15 | - pytorch 16 | - pytorch-cluster 17 | - pytorch-scatter 18 | - pytorch-sparse 19 | - pytorch-spline-conv 20 | - scikit-learn 21 | - scipy 22 | - torchvision 23 | - tqdm 24 | - xarray 25 | - zarr 26 | - h3-py 27 | - numpy 28 | - pyshtools 29 | - gcsfs 30 | - pytest 31 | - pip: 32 | - setuptools 33 | - datasets 34 | - einops 35 | - fsspec 36 | - torch-geometric-temporal 37 | - huggingface-hub 38 | - pysolar 39 | - pytorch-lightning 40 | - click 41 | - trimesh 42 | - rtree 43 | - torch-harmonics 44 | -------------------------------------------------------------------------------- /graph_weather/__init__.py: -------------------------------------------------------------------------------- 1 | """Main import for the complete models""" 2 | 3 | from .data.nnja_ai import SensorDataset, collate_fn 4 | from .data.weather_station_reader import WeatherStationReader 5 | from .models.analysis import GraphWeatherAssimilator 6 | from .models.forecast import GraphWeatherForecaster 7 | -------------------------------------------------------------------------------- /graph_weather/data/IFSAnalysis_dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | The dataloader for IFS analysis. 3 | """ 4 | 5 | import numpy as np 6 | import torchvision.transforms as transforms 7 | import xarray as xr 8 | from torch.utils.data import Dataset 9 | 10 | IFS_MEAN = { 11 | "geopotential": 78054.78, 12 | "specific_humidity": 0.0018220816, 13 | "temperature": 243.41727, 14 | "u_component_of_wind": 7.3073797, 15 | "v_component_of_wind": 0.032221083, 16 | "vertical_velocity": 0.0058287205, 17 | } 18 | 19 | IFS_STD = { 20 | "geopotential": 59538.875, 21 | "specific_humidity": 0.0035489395, 22 | "temperature": 29.211119, 23 | "u_component_of_wind": 13.777036, 24 | "v_component_of_wind": 8.867598, 25 | "vertical_velocity": 0.08577341, 26 | } 27 | 28 | 29 | class IFSAnalisysDataset(Dataset): 30 | """ 31 | Dataset for IFSAnalysis. 32 | 33 | Args: 34 | filepath: path of the dataset. 35 | features: list of features. 
36 |         start_year: initial year. Defaults to 2016.
37 |         end_year: ending year. Defaults to 2022.
38 |     """
39 | 
40 |     def __init__(self, filepath: str, features: list, start_year: int = 2016, end_year: int = 2022):
41 |         """
42 |         Initialize the dataset object.
43 |         """
44 | 
45 |         super().__init__()
46 |         assert (
47 |             start_year <= end_year
48 |         ), f"start_year ({start_year}) cannot be greater than end_year ({end_year})."
49 |         assert start_year >= 2016 and start_year <= 2022, "Time data ranges from 2016 to 2022"
50 |         assert end_year >= 2016 and end_year <= 2022, "Time data ranges from 2016 to 2022"
51 |         self.data = xr.open_zarr(filepath)
52 |         self.data = self.data.sel(
53 |             time=slice(str(start_year), str(end_year))
54 |         )  # Filter data by start and end years
55 | 
56 |         self.NWP_features = features
57 | 
58 |     def __len__(self):
59 |         return len(self.data["time"]) - 1  # the last timestep has no successor to act as a target
60 | 
61 |     def __getitem__(self, idx):
62 |         start = self.data.isel(time=idx)
63 |         end = self.data.isel(time=idx + 1)
64 | 
65 |         # Extract NWP features
66 |         input_data = self._nwp_features_extraction(start)
67 |         output_data = self._nwp_features_extraction(end)
68 | 
69 |         return (
70 |             transforms.ToTensor()(input_data).view(-1, input_data.shape[-1]),
71 |             transforms.ToTensor()(output_data).view(-1, output_data.shape[-1]),
72 |         )
73 | 
74 |     def _nwp_features_extraction(self, data):
75 |         data_cube = np.stack(
76 |             [
77 |                 (data[f"{var}"].values - IFS_MEAN[f"{var}"]) / (IFS_STD[f"{var}"] + 1e-6)
78 |                 for var in self.NWP_features
79 |             ],
80 |             axis=-1,
81 |         ).astype(np.float32)
82 | 
83 |         num_layers, num_lat, num_lon, num_vars = data_cube.shape
84 |         data_cube = data_cube.reshape(num_lat, num_lon, num_vars * num_layers)
85 | 
86 |         assert not np.isnan(data_cube).any()
87 |         return data_cube
88 | 
-------------------------------------------------------------------------------- /graph_weather/data/__init__.py: --------------------------------------------------------------------------------
1 | """Dataloaders and data processing utilities"""
2 | 
3 | from .nnja_ai import SensorDataset, collate_fn
4 | from .weather_station_reader import WeatherStationReader
5 | 
-------------------------------------------------------------------------------- /graph_weather/data/dataloader.py: --------------------------------------------------------------------------------
  1 | """
  2 | 
  3 | The dataloader has to do a few things for the model to work correctly
  4 | 
  5 | 1. Load the land-sea mask, orography dataset, regridded from 0.1 to the
  6 | correct resolution
  7 | 2. Calculate the top-of-atmosphere solar radiation for each location at
  8 | the current time and 10 other
  9 | times +- 12 hours
 10 | 3. Add day-of-year, sin(lat), cos(lat), sin(lon), cos(lon) as well
 11 | 4. Batch data as either in geometric batches, or more normally
 12 | 5. Rescale between 0 and 1, but don't normalize
 13 | 
 14 | """
 15 | 
 16 | import const
 17 | import numpy as np
 18 | import pandas as pd
 19 | import xarray as xr
 20 | from pysolar.util import extraterrestrial_irrad
 21 | from torch.utils.data import Dataset
 22 | 
 23 | 
 24 | class AnalysisDataset(Dataset):
 25 |     """
 26 |     Dataset class for analysis data.
 27 | 
 28 |     Args:
 29 |         filepaths: List of file paths.
 30 |         invariant_path: Path to the invariant file.
 31 |         mean: Mean value.
 32 |         std: Standard deviation value.
 33 |         coarsen: Coarsening factor. Defaults to 8.
 34 | 
 35 |     Methods:
 36 |         __init__: Initialize the AnalysisDataset object.
 37 |         __len__: Get the length of the dataset.
 38 |         __getitem__: Get an item from the dataset.
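    Example (a minimal usage sketch; the zarr paths and normalization statistics
    below are placeholder values):
        import glob
        paths = sorted(glob.glob("/data/analysis/*.zarr"))
        dataset = AnalysisDataset(paths, "/data/invariant.zarr", mean=0.0, std=1.0, coarsen=8)
        input_data, output_data = dataset[0]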
39 | """ 40 | 41 | def __init__(self, filepaths, invariant_path, mean, std, coarsen: int = 8): 42 | """ 43 | Initialize the AnalysisDataset object. 44 | """ 45 | super().__init__() 46 | self.filepaths = sorted(filepaths) 47 | self.invariant_path = invariant_path 48 | self.coarsen = coarsen 49 | self.mean = mean 50 | self.std = std 51 | 52 | def __len__(self): 53 | return len(self.filepaths) - 1 54 | 55 | def __getitem__(self, item): 56 | if self.coarsen <= 1: # Don't coarsen, so don't even call it 57 | start = xr.open_zarr(self.filepaths[item], consolidated=True) 58 | end = xr.open_zarr(self.filepaths[item + 1], consolidated=True) 59 | else: 60 | start = ( 61 | xr.open_zarr(self.filepaths[item], consolidated=True) 62 | .coarsen(latitude=self.coarsen, boundary="pad") 63 | .mean() 64 | .coarsen(longitude=self.coarsen) 65 | .mean() 66 | ) 67 | end = ( 68 | xr.open_zarr(self.filepaths[item + 1], consolidated=True) 69 | .coarsen(latitude=self.coarsen, boundary="pad") 70 | .mean() 71 | .coarsen(longitude=self.coarsen) 72 | .mean() 73 | ) 74 | 75 | # Land-sea mask data, resampled to the same as the physical variables 76 | landsea = ( 77 | xr.open_zarr(self.invariant_path, consolidated=True) 78 | .interp(latitude=start.latitude.values) 79 | .interp(longitude=start.longitude.values) 80 | ) 81 | # Calculate sin,cos, day of year, solar irradiance here before stacking 82 | landsea = np.stack( 83 | [ 84 | (landsea[f"{var}"].values - const.LANDSEA_MEAN[var]) / const.LANDSEA_STD[var] 85 | for var in landsea.data_vars 86 | if not np.isnan(landsea[f"{var}"].values).any() 87 | ], 88 | axis=-1, 89 | ) 90 | landsea = landsea.T.reshape((-1, landsea.shape[-1])) 91 | lat_lons = np.array(np.meshgrid(start.latitude.values, start.longitude.values)).T.reshape( 92 | (-1, 2) 93 | ) 94 | sin_lat_lons = np.sin(lat_lons) 95 | cos_lat_lons = np.cos(lat_lons) 96 | date = start.time.dt.values 97 | day_of_year = start.time.dayofyear.values / 365.0 98 | np.sin(day_of_year) 99 | np.cos(day_of_year) 100 | solar_times = [np.array([extraterrestrial_irrad(date, lat, lon) for lat, lon in lat_lons])] 101 | for when in pd.date_range( 102 | date - pd.Timedelta("12 hours"), date + pd.Timedelta("12 hours"), freq="1H" 103 | ): 104 | solar_times.append( 105 | np.array([extraterrestrial_irrad(when, lat, lon) for lat, lon in lat_lons]) 106 | ) 107 | solar_times = np.array(solar_times) 108 | 109 | # End time solar radiation too 110 | end_date = end.time.dt.values 111 | end_solar_times = [ 112 | np.array([extraterrestrial_irrad(end_date, lat, lon) for lat, lon in lat_lons]) 113 | ] 114 | for when in pd.date_range( 115 | end_date - pd.Timedelta("12 hours"), end_date + pd.Timedelta("12 hours"), freq="1H" 116 | ): 117 | end_solar_times.append( 118 | np.array([extraterrestrial_irrad(when, lat, lon) for lat, lon in lat_lons]) 119 | ) 120 | end_solar_times = np.array(solar_times) 121 | 122 | # Normalize to between -1 and 1 123 | solar_times -= const.SOLAR_MEAN 124 | solar_times /= const.SOLAR_STD 125 | end_solar_times -= const.SOLAR_MEAN 126 | end_solar_times /= const.SOLAR_STD 127 | 128 | # Stack the data into a large data cube 129 | input_data = np.stack( 130 | [ 131 | start[f"{var}"].values 132 | for var in start.data_vars 133 | if not np.isnan(start[f"{var}"].values).any() 134 | ], 135 | axis=-1, 136 | ) 137 | # TODO Combine with above? 
And include sin/cos of day of year 138 | input_data = np.concatenate( 139 | [ 140 | input_data.T.reshape((-1, input_data.shape[-1])), 141 | sin_lat_lons, 142 | cos_lat_lons, 143 | solar_times, 144 | landsea, 145 | ], 146 | axis=-1, 147 | ) 148 | # Not want to predict non-physics variables -> Output only the data variables? 149 | # Would be simpler, and just add in the new ones each time 150 | 151 | output_data = np.stack( 152 | [ 153 | end[f"{var}"].values 154 | for var in end.data_vars 155 | if not np.isnan(end[f"{var}"].values).any() 156 | ], 157 | axis=-1, 158 | ) 159 | 160 | output_data = np.concatenate( 161 | [ 162 | output_data.T.reshape((-1, output_data.shape[-1])), 163 | sin_lat_lons, 164 | cos_lat_lons, 165 | end_solar_times, 166 | landsea, 167 | ], 168 | axis=-1, 169 | ) 170 | # Stick with Numpy, don't tensor it, as just going from 0 to 1 171 | 172 | # Normalize now 173 | return input_data, output_data 174 | 175 | 176 | obs_data = xr.open_zarr( 177 | "/home/jacob/Development/prepbufr.gdas.20160101.t00z.nr.48h.raw.zarr", consolidated=True 178 | ) 179 | 180 | # TODO Embedding? These should stay consistent across all of the inputs, so can just load the values 181 | # not the strings? 182 | # Should only take in the quality markers, observations, reported observation time relative to start 183 | # point 184 | # Observation errors, and background values, lat/lon/height/speed of observing thing 185 | 186 | print(obs_data) 187 | print(obs_data.hdr_inst_typ.values) 188 | print(obs_data.hdr_irpt_typ.values) 189 | print(obs_data.obs_qty_table.values) 190 | print(obs_data.hdr_prpt_typ.values) 191 | print(obs_data.hdr_sid_table.values) 192 | print(obs_data.hdr_typ_table.values) 193 | print(obs_data.obs_desc.values) 194 | print(obs_data.data_vars.keys()) 195 | exit() 196 | analysis_data = xr.open_zarr( 197 | "/home/jacob/Development/gdas1.fnl0p25.2016010100.f00.zarr", consolidated=True 198 | ) 199 | print(analysis_data) 200 | -------------------------------------------------------------------------------- /graph_weather/data/nnja_ai.py: -------------------------------------------------------------------------------- 1 | """ 2 | A custom PyTorch Dataset implementation for various sensors like AMSU, ATMS, MHS, IASI, CrIS 3 | 4 | The dataset is loaded via the nnja library's `DataCatalog` and filtered for specific times and 5 | variables. Each data point consists of a timestamp, latitude, longitude, and associated metadata. 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | from torch.utils.data import Dataset 11 | 12 | try: 13 | from nnja import DataCatalog 14 | except ImportError: 15 | print( 16 | "NNJA-AI library not installed. Please install with `pip install git+https://github.com/brightbandtech/nnja-ai.git`" 17 | ) 18 | 19 | 20 | class SensorDataset(Dataset): 21 | """A custom PyTorch Dataset for handling various sensor data.""" 22 | 23 | def __init__( 24 | self, dataset_name, time, primary_descriptors, additional_variables, sensor_type="AMSU" 25 | ): 26 | """Initialize the dataset loader for various sensors. 27 | 28 | Args: 29 | dataset_name: Name of the dataset to load. 30 | time: Specific timestamp to filter the data. 31 | primary_descriptors: List of primary descriptor variables to include (e.g., OBS_TIMESTAMP, LAT, LON). 32 | additional_variables: List of additional variables to include in metadata. 
33 | sensor_type: Type of sensor (AMSU, ATMS, MHS, IASI, CrIS) 34 | """ 35 | self.dataset_name = dataset_name 36 | self.time = time 37 | self.primary_descriptors = primary_descriptors 38 | self.additional_variables = additional_variables 39 | self.sensor_type = sensor_type # New argument for selecting sensor type 40 | 41 | # Load data catalog and dataset 42 | self.catalog = DataCatalog(skip_manifest=True) 43 | self.dataset = self.catalog[self.dataset_name] 44 | self.dataset.load_manifest() 45 | 46 | if self.sensor_type == "AMSU": 47 | self.dataset = self.dataset.sel( 48 | time=self.time, 49 | variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 16)], 50 | ) 51 | elif self.sensor_type == "ATMS": 52 | self.dataset = self.dataset.sel( 53 | time=self.time, 54 | variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 23)], 55 | ) 56 | elif self.sensor_type == "MHS": 57 | self.dataset = self.dataset.sel( 58 | time=self.time, 59 | variables=self.primary_descriptors + [f"TMBR_000{i:02d}" for i in range(1, 6)], 60 | ) 61 | elif self.sensor_type == "IASI": 62 | self.dataset = self.dataset.sel( 63 | time=self.time, 64 | variables=self.primary_descriptors 65 | + ["SCRA_" + str(i).zfill(5) for i in range(1, 617)], 66 | ) 67 | elif self.sensor_type == "CrIS": 68 | self.dataset = self.dataset.sel( 69 | time=self.time, 70 | variables=self.primary_descriptors 71 | + [f"SRAD01_{str(i).zfill(5)}" for i in range(1, 432)], 72 | ) 73 | else: 74 | raise ValueError(f"Unsupported sensor type: {self.sensor_type}") 75 | 76 | self.dataframe = self.dataset.load_dataset(engine="pandas") 77 | 78 | for col in primary_descriptors: 79 | if col not in self.dataframe.columns: 80 | raise ValueError(f"The dataset must include a '{col}' column.") 81 | 82 | self.metadata_columns = [ 83 | col for col in self.dataframe.columns if col not in self.primary_descriptors 84 | ] 85 | 86 | def __len__(self): 87 | """Return the total number of samples in the dataset.""" 88 | return len(self.dataframe) 89 | 90 | def __getitem__(self, index): 91 | """Return the observation and metadata for a given index.""" 92 | row = self.dataframe.iloc[index] 93 | time = row["OBS_TIMESTAMP"].timestamp() 94 | latitude = row["LAT"] 95 | longitude = row["LON"] 96 | metadata = np.array([row[col] for col in self.metadata_columns], dtype=np.float32) 97 | 98 | return { 99 | "timestamp": torch.tensor(time, dtype=torch.float32), 100 | "latitude": torch.tensor(latitude, dtype=torch.float32), 101 | "longitude": torch.tensor(longitude, dtype=torch.float32), 102 | "metadata": torch.from_numpy(metadata), 103 | } 104 | 105 | 106 | def collate_fn(batch): 107 | """Custom collate function to handle batching of dictionary data. 
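Every key present in the sample dictionaries is stacked across the batch with torch.stack.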
108 | 
109 |     Args:
110 |         batch: List of dictionaries from __getitem__
111 | 
112 |     Returns:
113 |         Single dictionary with batched tensors
114 |     """
115 |     return {key: torch.stack([item[key] for item in batch]) for key in batch[0].keys()}
116 | 
-------------------------------------------------------------------------------- /graph_weather/models/__init__.py: --------------------------------------------------------------------------------
 1 | """Models"""
 2 | 
 3 | from .fengwu_ghr.layers import (
 4 |     ImageMetaModel,
 5 |     LoRAModule,
 6 |     MetaModel,
 7 |     WrapperImageModel,
 8 |     WrapperMetaModel,
 9 | )
10 | from .layers.assimilator_decoder import AssimilatorDecoder
11 | from .layers.assimilator_encoder import AssimilatorEncoder
12 | from .layers.decoder import Decoder
13 | from .layers.encoder import Encoder
14 | from .layers.processor import Processor
15 | 
-------------------------------------------------------------------------------- /graph_weather/models/analysis.py: --------------------------------------------------------------------------------
  1 | """Model for generating analysis files from raw or processed observations"""
  2 | 
  3 | import torch
  4 | from huggingface_hub import PyTorchModelHubMixin
  5 | 
  6 | from graph_weather.models import AssimilatorDecoder, AssimilatorEncoder, Processor
  7 | 
  8 | 
  9 | class GraphWeatherAssimilator(torch.nn.Module, PyTorchModelHubMixin):
 10 |     """Model to generate analysis file from raw observations"""
 11 | 
 12 |     def __init__(
 13 |         self,
 14 |         output_lat_lons: list,
 15 |         resolution: int = 2,
 16 |         observation_dim: int = 2,
 17 |         analysis_dim: int = 78,
 18 |         node_dim: int = 256,
 19 |         edge_dim: int = 256,
 20 |         num_blocks: int = 9,
 21 |         hidden_dim_processor_node: int = 256,
 22 |         hidden_dim_processor_edge: int = 256,
 23 |         hidden_layers_processor_node: int = 2,
 24 |         hidden_layers_processor_edge: int = 2,
 25 |         hidden_dim_decoder: int = 128,
 26 |         hidden_layers_decoder: int = 2,
 27 |         norm_type: str = "LayerNorm",
 28 |         use_checkpointing: bool = False,
 29 |     ):
 30 |         """
 31 |         Graph Weather Data Assimilation model
 32 | 
 33 |         Args:
 35 |             output_lat_lons: List of latitudes and longitudes for the output analysis
 36 |             resolution: Resolution of the H3 grid, prefer even resolutions, as
 37 |                 odd ones have octagons and heptagons as well
 38 |             observation_dim: Input feature size
 39 |             analysis_dim: Output Analysis feature dim
 40 |             node_dim: Node hidden dimension
 41 |             edge_dim: Edge hidden dimension
 42 |             num_blocks: Number of message passing blocks in the Processor
 43 |             hidden_dim_processor_node: Hidden dimension of the node processors
 44 |             hidden_dim_processor_edge: Hidden dimension of the edge processors
 45 |             hidden_layers_processor_node: Number of hidden layers in the node processors
 46 |             hidden_layers_processor_edge: Number of hidden layers in the edge processors
 47 |             hidden_dim_decoder: Number of hidden dimensions in the decoder
 48 |             hidden_layers_decoder: Number of layers in the decoder
 49 |             norm_type: Type of norm for the MLPs
 50 |                 one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
 51 |             use_checkpointing: Whether to use gradient checkpointing or not
 52 |         """
 53 |         super().__init__()
 54 | 
 55 |         self.encoder = AssimilatorEncoder(
 56 |             resolution=resolution,
 57 |             input_dim=observation_dim,
 58 |             output_dim=node_dim,
 59 |             output_edge_dim=edge_dim,
 60 |             hidden_dim_processor_edge=hidden_dim_processor_edge,
 61 |             hidden_layers_processor_node=hidden_layers_processor_node,
 62 |             hidden_dim_processor_node=hidden_dim_processor_node,
 63 | 
            hidden_layers_processor_edge=hidden_layers_processor_edge,
 64 |             mlp_norm_type=norm_type,
 65 |             use_checkpointing=use_checkpointing,
 66 |         )
 67 |         self.processor = Processor(
 68 |             input_dim=node_dim,
 69 |             edge_dim=edge_dim,
 70 |             num_blocks=num_blocks,
 71 |             hidden_dim_processor_edge=hidden_dim_processor_edge,
 72 |             hidden_layers_processor_node=hidden_layers_processor_node,
 73 |             hidden_dim_processor_node=hidden_dim_processor_node,
 74 |             hidden_layers_processor_edge=hidden_layers_processor_edge,
 75 |             mlp_norm_type=norm_type,
 76 |         )
 77 |         self.decoder = AssimilatorDecoder(
 78 |             lat_lons=output_lat_lons,
 79 |             resolution=resolution,
 80 |             input_dim=node_dim,
 81 |             output_dim=analysis_dim,
 82 |             output_edge_dim=edge_dim,
 83 |             hidden_dim_processor_edge=hidden_dim_processor_edge,
 84 |             hidden_layers_processor_node=hidden_layers_processor_node,
 85 |             hidden_dim_processor_node=hidden_dim_processor_node,
 86 |             hidden_layers_processor_edge=hidden_layers_processor_edge,
 87 |             mlp_norm_type=norm_type,
 88 |             hidden_dim_decoder=hidden_dim_decoder,
 89 |             hidden_layers_decoder=hidden_layers_decoder,
 90 |             use_checkpointing=use_checkpointing,
 91 |         )
 92 | 
 93 |     def forward(self, features: torch.Tensor, obs_lat_lon_heights: torch.Tensor) -> torch.Tensor:
 94 |         """
 95 |         Compute the analysis output
 96 | 
 97 |         Args:
 98 |             features: The input features, aligned with the order of obs_lat_lon_heights
 99 |             obs_lat_lon_heights: Observation lat/lon/heights in same order as features
100 | 
101 |         Returns:
102 |             The assimilated analysis state
103 |         """
104 |         x, edge_idx, edge_attr = self.encoder(features, obs_lat_lon_heights)
105 |         x = self.processor(x, edge_idx, edge_attr)
106 |         x = self.decoder(x, features.shape[0])
107 |         return x
108 | 
-------------------------------------------------------------------------------- /graph_weather/models/aurora/__init__.py: --------------------------------------------------------------------------------
 1 | """
 2 | Aurora: A Foundation Model for Earth System Science
 3 | - Combines 3D Swin Transformer encoding
 4 | - Perceiver processing for efficient computation
 5 | - 3D decoding for spatial-temporal predictions
 6 | """
 7 | 
 8 | from .decoder import Decoder3D
 9 | from .encoder import Swin3DEncoder
10 | from .model import AuroraModel, EarthSystemLoss
11 | from .processor import PerceiverProcessor
12 | 
13 | __version__ = "0.1.0"
14 | 
15 | __all__ = [
16 |     "AuroraModel",
17 |     "EarthSystemLoss",
18 |     "Swin3DEncoder",
19 |     "Decoder3D",
20 |     "PerceiverProcessor",
21 | ]
22 | 
23 | # Default configurations for different model sizes
24 | MODEL_CONFIGS = {
25 |     "tiny": {
26 |         "in_channels": 1,
27 |         "out_channels": 1,
28 |         "embed_dim": 48,
29 |         "latent_dim": 256,
30 |         "spatial_shape": (16, 16, 16),
31 |         "max_seq_len": 2048,
32 |     },
33 |     "base": {
34 |         "in_channels": 1,
35 |         "out_channels": 1,
36 |         "embed_dim": 96,
37 |         "latent_dim": 512,
38 |         "spatial_shape": (32, 32, 32),
39 |         "max_seq_len": 4096,
40 |     },
41 |     "large": {
42 |         "in_channels": 1,
43 |         "out_channels": 1,
44 |         "embed_dim": 192,
45 |         "latent_dim": 1024,
46 |         "spatial_shape": (64, 64, 64),
47 |         "max_seq_len": 8192,
48 |     },
49 | }
50 | 
51 | 
52 | def create_model(config="base", **kwargs):
53 |     """
54 |     Create an Aurora model with specified configuration.
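    Any keyword argument overrides the matching field of the chosen entry in
    MODEL_CONFIGS; for example, create_model("tiny", latent_dim=128) builds the
    tiny configuration with a smaller latent dimension.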
55 | 56 | Args: 57 | config (str): Model size configuration ('tiny', 'base', or 'large') 58 | **kwargs: Override default configuration parameters 59 | 60 | Returns: 61 | AuroraModel: Initialized model with specified configuration 62 | """ 63 | if config not in MODEL_CONFIGS: 64 | raise ValueError( 65 | f"Unknown configuration: {config}. Choose from {list(MODEL_CONFIGS.keys())}" 66 | ) 67 | 68 | # Start with default config and update with any provided kwargs 69 | model_config = MODEL_CONFIGS[config].copy() 70 | model_config.update(kwargs) 71 | 72 | return AuroraModel(**model_config) 73 | 74 | 75 | def create_loss(alpha=0.5, beta=0.3, gamma=0.2): 76 | """ 77 | Create an EarthSystemLoss instance with specified weights. 78 | 79 | Args: 80 | alpha (float): Weight for MSE loss 81 | beta (float): Weight for gradient loss 82 | gamma (float): Weight for physical consistency loss 83 | 84 | Returns: 85 | EarthSystemLoss: Initialized loss function 86 | """ 87 | return EarthSystemLoss(alpha=alpha, beta=beta, gamma=gamma) 88 | -------------------------------------------------------------------------------- /graph_weather/models/aurora/decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3D Decoder: 3 | - Takes processed latent representations and reconstructs output. 4 | - Uses transposed convolution to upscale back to spatial-temporal format. 5 | """ 6 | 7 | import torch.nn as nn 8 | 9 | 10 | class Decoder3D(nn.Module): 11 | """ 12 | 3D Decoder: 13 | - Takes processed latent representations and reconstructs the spatial-temporal output. 14 | - Uses transposed convolutions to upscale latent features to the original format. 15 | """ 16 | 17 | def __init__(self, output_channels=1, embed_dim=96, target_shape=(32, 32, 32)): 18 | """ 19 | Args: 20 | output_channels (int): Number of channels in the output tensor (e.g., 1 for grayscale). 21 | embed_dim (int): Dimension of the latent features (matches the encoder's output). 22 | target_shape (tuple): The desired shape of the reconstructed 3D tensor (D, H, W). 23 | """ 24 | super().__init__() 25 | self.embed_dim = embed_dim 26 | self.target_shape = target_shape 27 | self.deconv1 = nn.ConvTranspose3d( 28 | embed_dim, output_channels, kernel_size=3, padding=1, stride=1 29 | ) 30 | 31 | def forward(self, x): 32 | """ 33 | Forward pass for the decoder. 34 | 35 | Args: 36 | x (torch.Tensor): Input latent representation, shape (batch, seq_len, embed_dim). 37 | 38 | Returns: 39 | torch.Tensor: Reconstructed 3D tensor, shape (batch, output_channels, *target_shape). 40 | """ 41 | batch_size = x.shape[0] 42 | depth, height, width = self.target_shape 43 | # Reshape latent features into 3D tensor 44 | x = x.view(batch_size, self.embed_dim, depth, height, width) 45 | # Transposed convolution to upscale to the final shape 46 | x = self.deconv1(x) 47 | return x 48 | -------------------------------------------------------------------------------- /graph_weather/models/aurora/encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Swin 3D Transformer Encoder: 3 | - Uses a 3D convolution for initial feature extraction. 4 | - Applies layer normalization and reshapes data. 5 | - Uses a transformer-based encoder to learn spatial-temporal features. 
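- Expected shapes, per the forward pass below: input (batch, in_channels, depth, height, width) -> output (batch, depth*height*width, embed_dim).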
6 | """ 7 | 8 | import torch.nn as nn 9 | from einops import rearrange 10 | from einops.layers.torch import Rearrange 11 | 12 | 13 | class Swin3DEncoder(nn.Module): 14 | def __init__(self, in_channels=1, embed_dim=96): 15 | super().__init__() 16 | self.conv1 = nn.Conv3d(in_channels, embed_dim, kernel_size=3, padding=1, stride=1) 17 | self.norm = nn.LayerNorm(embed_dim) 18 | self.swin_transformer = nn.Transformer( 19 | d_model=embed_dim, 20 | nhead=8, 21 | num_encoder_layers=4, 22 | num_decoder_layers=4, 23 | dim_feedforward=embed_dim * 4, 24 | ) 25 | self.embed_dim = embed_dim 26 | 27 | # Define rearrangement patterns using einops 28 | self.to_transformer_format = Rearrange("b d h w c -> (d h w) b c") 29 | self.from_transformer_format = Rearrange("(d h w) b c -> b d h w c", d=None, h=None, w=None) 30 | 31 | # To use rearrange function directly instead of the Rearrange layer 32 | def forward(self, x): 33 | # 3D convolution with einops rearrangement 34 | x = self.conv1(x) 35 | 36 | # Rearrange for normalization using einops 37 | x = rearrange(x, "b c d h w -> b d h w c") 38 | x = self.norm(x) 39 | 40 | # Store spatial dimensions for later reconstruction 41 | d, h, w = x.shape[1:4] 42 | 43 | # Transform to sequence format for transformer 44 | x = rearrange(x, "b d h w c -> (d h w) b c") 45 | x = self.swin_transformer.encoder(x) 46 | 47 | # Restore original spatial structure 48 | x = rearrange(x, "(d h w) b c -> b (d h w) c", d=d, h=h, w=w) 49 | 50 | # Reshape to the expected output format (batch, seq_len, embed_dim) 51 | x = rearrange(x, "b (d h w) c -> b (d h w) c", d=d, h=h, w=w) 52 | 53 | return x 54 | 55 | def convolution(self, x): 56 | """Apply 3D convolution with clear shape transformation.""" 57 | return self.conv1(x) # b c d h w -> b embed_dim d h w 58 | 59 | def normalization_layer(self, x): 60 | """Apply layer normalization with einops rearrangement.""" 61 | x = rearrange(x, "b c d h w -> b d h w c") 62 | return self.norm(x) 63 | 64 | def transformer_encoder(self, x, spatial_dims): 65 | """ 66 | Apply transformer encoding with proper shape handling. 
67 | 68 | Args: 69 | x (torch.Tensor): Input tensor 70 | spatial_dims (tuple): Original (depth, height, width) dimensions 71 | """ 72 | d, h, w = spatial_dims 73 | x = self.to_transformer_format(x) 74 | x = self.swin_transformer.encoder(x) 75 | x = self.from_transformer_format(x, d=d, h=h, w=w) 76 | return x 77 | -------------------------------------------------------------------------------- /graph_weather/models/aurora/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | aurora/model.py - Core implementation of Aurora model for unstructured point data 3 | """ 4 | 5 | from typing import Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class PointEncoder(nn.Module): 12 | def __init__(self, input_features: int, embed_dim: int, max_seq_len: int = 1024): 13 | super().__init__() 14 | self.input_dim = input_features + 2 # Account for lat/lon coordinates 15 | self.max_seq_len = max_seq_len 16 | 17 | # Remove positional embeddings as they break point ordering invariance 18 | 19 | # Enhanced coordinate embedding 20 | self.coord_encoder = nn.Sequential( 21 | nn.Linear(2, embed_dim // 2), 22 | nn.LayerNorm(embed_dim // 2), 23 | nn.ReLU(), 24 | nn.Linear(embed_dim // 2, embed_dim), 25 | ) 26 | 27 | # Feature embedding 28 | self.feature_encoder = nn.Sequential( 29 | nn.Linear(input_features, embed_dim), 30 | nn.LayerNorm(embed_dim), 31 | nn.ReLU(), 32 | nn.Linear(embed_dim, embed_dim), 33 | ) 34 | 35 | # Final normalization 36 | self.norm = nn.LayerNorm(embed_dim) 37 | 38 | def forward(self, points: torch.Tensor, features: torch.Tensor) -> torch.Tensor: 39 | num_points = points.shape[1] 40 | if num_points > self.max_seq_len: 41 | points = points[:, : self.max_seq_len, :] 42 | features = features[:, : self.max_seq_len, :] 43 | 44 | # Normalize coordinates to [-1, 1] range 45 | normalized_points = torch.stack( 46 | [points[..., 0] / 180.0, points[..., 1] / 90.0], dim=-1 # longitude # latitude 47 | ) 48 | 49 | # Separately encode coordinates and features 50 | coord_embedding = self.coord_encoder(normalized_points) 51 | feature_embedding = self.feature_encoder(features) 52 | 53 | # Combine embeddings through addition (order-invariant operation) 54 | x = coord_embedding + feature_embedding 55 | 56 | # Final normalization 57 | x = self.norm(x) 58 | 59 | return x 60 | 61 | 62 | class PointDecoder(nn.Module): 63 | """Decodes latent representations back to point features.""" 64 | 65 | def __init__(self, embed_dim: int, output_features: int): 66 | super().__init__() 67 | self.decoder = nn.Sequential( 68 | nn.Linear(embed_dim, embed_dim), nn.ReLU(), nn.Linear(embed_dim, output_features) 69 | ) 70 | 71 | def forward(self, x: torch.Tensor) -> torch.Tensor: 72 | """ 73 | Args: 74 | x: (batch_size, num_points, embed_dim) tensor 75 | Returns: 76 | (batch_size, num_points, output_features) tensor 77 | """ 78 | return self.decoder(x) 79 | 80 | 81 | class PointCloudProcessor(nn.Module): 82 | """Processes point cloud data using self-attention layers.""" 83 | 84 | def __init__(self, embed_dim: int, num_layers: int = 4): 85 | super().__init__() 86 | self.layers = nn.ModuleList([SelfAttentionLayer(embed_dim) for _ in range(num_layers)]) 87 | 88 | def forward(self, x: torch.Tensor) -> torch.Tensor: 89 | """ 90 | Args: 91 | x: (batch_size, num_points, embed_dim) tensor 92 | Returns: 93 | (batch_size, num_points, embed_dim) tensor after processing 94 | """ 95 | for layer in self.layers: 96 | x = layer(x) 97 | return x 98 | 99 | 100 | class 
SelfAttentionLayer(nn.Module): 101 | def __init__(self, embed_dim: int): 102 | super().__init__() 103 | self.attention = nn.MultiheadAttention(embed_dim, num_heads=8) 104 | self.norm1 = nn.LayerNorm(embed_dim) 105 | self.norm2 = nn.LayerNorm(embed_dim) 106 | self.ffn = nn.Sequential( 107 | nn.Linear(embed_dim, 4 * embed_dim), nn.ReLU(), nn.Linear(4 * embed_dim, embed_dim) 108 | ) 109 | 110 | def forward(self, x: torch.Tensor) -> torch.Tensor: 111 | # First attention block with residual 112 | x_t = x.transpose(0, 1) 113 | attended, _ = self.attention(x_t, x_t, x_t) 114 | attended = attended.transpose(0, 1) 115 | x = self.norm1(x + attended) 116 | 117 | # FFN block with residual 118 | x = self.norm2(x + self.ffn(x)) 119 | return x 120 | 121 | 122 | class EarthSystemLoss(nn.Module): 123 | def __init__(self, alpha: float = 0.5, beta: float = 0.3, gamma: float = 0.2): 124 | super().__init__() 125 | self.alpha = alpha 126 | self.beta = beta 127 | self.gamma = gamma 128 | 129 | def spatial_correlation_loss( 130 | self, pred: torch.Tensor, target: torch.Tensor, points: torch.Tensor 131 | ) -> torch.Tensor: 132 | batch_size, num_points, _ = points.shape 133 | points_flat = points.view(-1, 2) 134 | 135 | # Compute pairwise distances 136 | dists = torch.cdist(points_flat, points_flat) 137 | dists = dists.view(batch_size, num_points, num_points) 138 | 139 | # Create mask for nearby points (5 degrees threshold) 140 | nearby_mask = (dists < 5.0).float().unsqueeze(-1) 141 | 142 | # Compute differences 143 | pred_diff = pred.unsqueeze(2) - pred.unsqueeze(1) 144 | target_diff = target.unsqueeze(2) - target.unsqueeze(1) 145 | 146 | # Calculate loss with proper broadcasting 147 | correlation_loss = torch.mean(nearby_mask * (pred_diff - target_diff).pow(2)) 148 | 149 | return correlation_loss 150 | 151 | def physical_loss(self, pred: torch.Tensor, points: torch.Tensor) -> torch.Tensor: 152 | """Calculate physical consistency loss - ensures predictions follow basic physical laws""" 153 | # Ensure non-negative values for physical quantities (e.g., temperature in Kelvin) 154 | min_value_loss = torch.nn.functional.relu(-pred).mean() 155 | 156 | # Ensure reasonable maximum values (e.g., max temperature) 157 | max_value_loss = torch.nn.functional.relu(pred - 500).mean() # Assuming max value of 500 158 | 159 | # Add latitude-based consistency (e.g., colder at poles) 160 | latitude = points[..., 1] # Second coordinate is latitude 161 | abs_latitude = torch.abs(latitude) 162 | latitude_consistency = torch.mean( 163 | torch.nn.functional.relu(pred[..., 0] - (1.0 - abs_latitude / 90.0) * pred.mean()) 164 | ) 165 | 166 | # Combine physical constraints 167 | physical_loss = min_value_loss + max_value_loss + 0.1 * latitude_consistency 168 | return physical_loss 169 | 170 | def forward(self, pred: torch.Tensor, target: torch.Tensor, points: torch.Tensor) -> dict: 171 | mse_loss = torch.nn.functional.mse_loss(pred, target) 172 | spatial_loss = self.spatial_correlation_loss(pred, target, points) 173 | physical_loss = self.physical_loss(pred, points) 174 | 175 | # Combine losses with the specified weights 176 | total_loss = self.alpha * mse_loss + self.beta * spatial_loss + self.gamma * physical_loss 177 | 178 | return { 179 | "total_loss": total_loss, 180 | "mse_loss": mse_loss, 181 | "spatial_correlation_loss": spatial_loss, 182 | "physical_loss": physical_loss, 183 | } 184 | 185 | 186 | class AuroraModel(nn.Module): 187 | def __init__( 188 | self, 189 | input_features: int, 190 | output_features: int, 191 | latent_dim: int 
= 256, 192 | num_layers: int = 4, 193 | max_points: int = 10000, 194 | max_seq_len: int = 1024, 195 | use_checkpointing: bool = False, 196 | ): 197 | super().__init__() 198 | 199 | self.max_points = max_points 200 | self.max_seq_len = max_seq_len 201 | self.input_features = input_features 202 | self.output_features = output_features 203 | 204 | # Model components 205 | self.encoder = PointEncoder(input_features, latent_dim, max_seq_len) 206 | self.processor = PointCloudProcessor(latent_dim, num_layers) 207 | self.decoder = PointDecoder(latent_dim, output_features) 208 | 209 | # Add gradient checkpointing 210 | self.use_checkpointing = use_checkpointing 211 | 212 | # Initialize weights properly 213 | self._init_weights() 214 | 215 | def _init_weights(self): 216 | for m in self.modules(): 217 | if isinstance(m, nn.Linear): 218 | nn.init.xavier_uniform_(m.weight) 219 | if m.bias is not None: 220 | nn.init.zeros_(m.bias) 221 | 222 | def forward( 223 | self, points: torch.Tensor, features: torch.Tensor, mask: Optional[torch.Tensor] = None 224 | ) -> torch.Tensor: 225 | if points.shape[1] > self.max_points: 226 | raise ValueError( 227 | f"Number of points ({points.shape[1]}) exceeds maximum ({self.max_points})" 228 | ) 229 | 230 | # Handle mask properly 231 | if mask is not None: 232 | mask = mask.float().unsqueeze(-1) 233 | points = points * mask 234 | features = features * mask 235 | 236 | # Forward pass with gradient checkpointing 237 | x = self.encoder(points, features) 238 | 239 | if self.use_checkpointing and self.training: 240 | x = torch.utils.checkpoint.checkpoint(self.processor, x) 241 | else: 242 | x = self.processor(x) 243 | 244 | output = self.decoder(x) 245 | 246 | # Apply mask to output if provided 247 | if mask is not None: 248 | output = output * mask 249 | 250 | return output 251 | -------------------------------------------------------------------------------- /graph_weather/models/aurora/processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Perceiver Transformer Processor: 3 | - Takes encoded features and processes them using latent space mapping. 4 | - Uses a latent-space bottleneck to compress input dimensions. 5 | - Provides an efficient way to extract long-range dependencies. 6 | - All architectural parameters are configurable. 
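- Expected flow, per the forward pass below: (batch, seq_len, input_dim) is projected to d_model, encoded, projected to latent_dim, and mean-pooled to (batch, latent_dim).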
7 | """ 8 | 9 | from dataclasses import dataclass 10 | from typing import Optional 11 | 12 | import einops 13 | import torch.nn as nn 14 | 15 | 16 | @dataclass 17 | class ProcessorConfig: 18 | input_dim: int = 256 # Match Swin3D output 19 | latent_dim: int = 512 20 | d_model: int = 256 # Match input_dim for consistency 21 | max_seq_len: int = 4096 22 | num_self_attention_layers: int = 6 23 | num_cross_attention_layers: int = 2 24 | num_attention_heads: int = 8 25 | hidden_dropout: float = 0.1 26 | attention_dropout: float = 0.1 27 | qk_head_dim: Optional[int] = 32 28 | activation_fn: str = "gelu" 29 | layer_norm_eps: float = 1e-12 30 | 31 | def __post_init__(self): 32 | # Validate parameters 33 | if self.input_dim <= 0: 34 | raise ValueError("input_dim must be positive") 35 | if self.max_seq_len <= 0: 36 | raise ValueError("max_seq_len must be positive") 37 | if self.num_attention_heads <= 0: 38 | raise ValueError("num_attention_heads must be positive") 39 | if not 0 <= self.hidden_dropout <= 1: 40 | raise ValueError("hidden_dropout must be between 0 and 1") 41 | if not 0 <= self.attention_dropout <= 1: 42 | raise ValueError("attention_dropout must be between 0 and 1") 43 | 44 | 45 | class PerceiverProcessor(nn.Module): 46 | def __init__(self, config: Optional[ProcessorConfig] = None): 47 | super().__init__() 48 | self.config = config or ProcessorConfig() 49 | 50 | # Input projection to match d_model 51 | self.input_projection = nn.Linear(self.config.input_dim, self.config.d_model) 52 | 53 | # Simplified architecture using transformer encoder 54 | self.encoder = nn.TransformerEncoder( 55 | nn.TransformerEncoderLayer( 56 | d_model=self.config.d_model, 57 | nhead=self.config.num_attention_heads, 58 | dim_feedforward=self.config.d_model * 4, 59 | dropout=self.config.hidden_dropout, 60 | activation=self.config.activation_fn, 61 | ), 62 | num_layers=self.config.num_self_attention_layers, 63 | ) 64 | 65 | # Output projection 66 | self.output_projection = nn.Linear(self.config.d_model, self.config.latent_dim) 67 | 68 | def forward(self, x, attention_mask=None): 69 | # Handle 4D input using einops for clearer reshaping 70 | if len(x.shape) == 4: 71 | # Rearrange from (batch, seq, height, width) to (batch, seq*height*width, features) 72 | x = einops.rearrange(x, "b s h w -> b (s h w) c") 73 | 74 | # Project input 75 | x = self.input_projection(x) 76 | 77 | # Apply transformer encoder with einops for transpose operations 78 | if attention_mask is not None: 79 | # Convert boolean mask to float mask where True -> 0, False -> -inf 80 | mask = ~attention_mask 81 | mask = mask.float().masked_fill(mask, float("-inf")) 82 | x = einops.rearrange( 83 | x, "b s c -> s b c" 84 | ) # (batch, seq, channels) -> (seq, batch, channels) 85 | x = self.encoder(x, src_key_padding_mask=mask) 86 | x = einops.rearrange( 87 | x, "s b c -> b s c" 88 | ) # (seq, batch, channels) -> (batch, seq, channels) 89 | else: 90 | x = einops.rearrange(x, "b s c -> s b c") 91 | x = self.encoder(x) 92 | x = einops.rearrange(x, "s b c -> b s c") 93 | 94 | # Project to latent dimension and pool 95 | x = self.output_projection(x) 96 | x = x.mean(dim=1) # Global average pooling 97 | 98 | return x 99 | -------------------------------------------------------------------------------- /graph_weather/models/fengwu_ghr/__init__.py: -------------------------------------------------------------------------------- 1 | """Main import for FengWu-GHR""" 2 | 3 | from .layers import ImageMetaModel, LoRAModule, MetaModel, WrapperImageModel, 
WrapperMetaModel)
4 | 
--------------------------------------------------------------------------------
/graph_weather/models/forecast.py:
--------------------------------------------------------------------------------
1 | """Model for forecasting weather from NWP states"""
2 | 
3 | from typing import Optional
4 | 
5 | import torch
6 | from einops import rearrange, repeat
7 | from huggingface_hub import PyTorchModelHubMixin
8 | 
9 | from graph_weather.models import Decoder, Encoder, Processor
10 | from graph_weather.models.layers.constraint_layer import PhysicalConstraintLayer
11 | 
12 | 
13 | class GraphWeatherForecaster(torch.nn.Module, PyTorchModelHubMixin):
14 |     """Main weather prediction model from the paper with physical constraints"""
15 | 
16 |     def __init__(
17 |         self,
18 |         lat_lons: list,
19 |         resolution: int = 2,
20 |         feature_dim: int = 78,
21 |         aux_dim: int = 24,
22 |         output_dim: Optional[int] = None,
23 |         node_dim: int = 256,
24 |         edge_dim: int = 256,
25 |         num_blocks: int = 9,
26 |         hidden_dim_processor_node: int = 256,
27 |         hidden_dim_processor_edge: int = 256,
28 |         hidden_layers_processor_node: int = 2,
29 |         hidden_layers_processor_edge: int = 2,
30 |         hidden_dim_decoder: int = 128,
31 |         hidden_layers_decoder: int = 2,
32 |         norm_type: str = "LayerNorm",
33 |         use_checkpointing: bool = False,
34 |         constraint_type: str = "none",
35 |     ):
36 |         """
37 |         Graph Weather Model based off https://arxiv.org/pdf/2202.07575.pdf
38 | 
39 |         Args:
40 |             lat_lons: List of latitudes and longitudes for the grid
41 |             resolution: Resolution of the H3 grid, prefer even resolutions, as
42 |                 odd ones have octagons and heptagons as well
43 |             feature_dim: Input feature size
44 |             aux_dim: Number of non-NWP features (e.g. land-sea mask, lat/lon, etc.)
45 |             output_dim: Optional, output feature size, useful if you want only a subset of
46 |                 variables in the output
47 |             node_dim: Node hidden dimension
48 |             edge_dim: Edge hidden dimension
49 |             num_blocks: Number of message passing blocks in the Processor
50 |             hidden_dim_processor_node: Hidden dimension of the node processors
51 |             hidden_dim_processor_edge: Hidden dimension of the edge processors
52 |             hidden_layers_processor_node: Number of hidden layers in the node processors
53 |             hidden_layers_processor_edge: Number of hidden layers in the edge processors
54 |             hidden_dim_decoder: Number of hidden dimensions in the decoder
55 |             hidden_layers_decoder: Number of layers in the decoder
56 |             norm_type: Type of norm for the MLPs,
57 |                 one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
58 |             use_checkpointing: Use gradient checkpointing to reduce model memory
59 |             constraint_type: Type of constraint to apply for physical constraints,
60 |                 one of 'additive', 'multiplicative', 'softmax', or 'none'
61 |         """
62 |         super().__init__()
63 |         self.feature_dim = feature_dim
64 |         self.constraint_type = constraint_type
65 |         if output_dim is None:
66 |             output_dim = self.feature_dim
67 |         self.output_dim = output_dim
68 | 
69 |         # Compute the geographical grid shape from lat_lons.
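        # A tiny illustrative example (the values here are hypothetical, not
        # part of the API): for lat_lons = [(0, 0), (0, 90), (45, 0), (45, 90)]
        # the unique latitudes are [0, 45] and the unique longitudes [0, 90],
        # so grid_shape below becomes (2, 2).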
70 |         unique_lats = sorted(set(lat for lat, _ in lat_lons))
71 |         unique_lons = sorted(set(lon for _, lon in lat_lons))
72 |         self.grid_shape = (len(unique_lats), len(unique_lons))  # (H, W)
73 | 
74 |         # Store original node order and create grid mapping
75 |         self.original_lat_lons = lat_lons.copy()
76 |         self._create_grid_mapping(unique_lats, unique_lons)
77 | 
78 |         self.encoder = Encoder(
79 |             lat_lons=lat_lons,
80 |             resolution=resolution,
81 |             input_dim=feature_dim + aux_dim,
82 |             output_dim=node_dim,
83 |             output_edge_dim=edge_dim,
84 |             hidden_dim_processor_edge=hidden_dim_processor_edge,
85 |             hidden_layers_processor_node=hidden_layers_processor_node,
86 |             hidden_dim_processor_node=hidden_dim_processor_node,
87 |             hidden_layers_processor_edge=hidden_layers_processor_edge,
88 |             mlp_norm_type=norm_type,
89 |             use_checkpointing=use_checkpointing,
90 |         )
91 |         self.processor = Processor(
92 |             input_dim=node_dim,
93 |             edge_dim=edge_dim,
94 |             num_blocks=num_blocks,
95 |             hidden_dim_processor_edge=hidden_dim_processor_edge,
96 |             hidden_layers_processor_node=hidden_layers_processor_node,
97 |             hidden_dim_processor_node=hidden_dim_processor_node,
98 |             hidden_layers_processor_edge=hidden_layers_processor_edge,
99 |             mlp_norm_type=norm_type,
100 |         )
101 |         self.decoder = Decoder(
102 |             lat_lons=lat_lons,
103 |             resolution=resolution,
104 |             input_dim=node_dim,
105 |             output_dim=output_dim,
106 |             output_edge_dim=edge_dim,
107 |             hidden_dim_processor_edge=hidden_dim_processor_edge,
108 |             hidden_layers_processor_node=hidden_layers_processor_node,
109 |             hidden_dim_processor_node=hidden_dim_processor_node,
110 |             hidden_layers_processor_edge=hidden_layers_processor_edge,
111 |             mlp_norm_type=norm_type,
112 |             hidden_dim_decoder=hidden_dim_decoder,
113 |             hidden_layers_decoder=hidden_layers_decoder,
114 |             use_checkpointing=use_checkpointing,
115 |         )
116 | 
117 |         # Add physical constraint layer if constraint_type is not "none"
118 |         if self.constraint_type != "none":
119 |             self.constraint = PhysicalConstraintLayer(
120 |                 model=self,
121 |                 grid_shape=self.grid_shape,
122 |                 constraint_type=constraint_type,
123 |                 upsampling_factor=1,
124 |             )
125 | 
126 |     def _create_grid_mapping(self, unique_lats, unique_lons):
127 |         """Create a (row, col) mapping for the original node order"""
128 |         self.node_to_grid = []
129 |         for lat, lon in self.original_lat_lons:
130 |             row = int(
131 |                 (lat - min(unique_lats))
132 |                 / (max(unique_lats) - min(unique_lats))
133 |                 * (len(unique_lats) - 1)
134 |             )
135 |             col = int(
136 |                 (lon - min(unique_lons))
137 |                 / (max(unique_lons) - min(unique_lons))
138 |                 * (len(unique_lons) - 1)
139 |             )
140 |             self.node_to_grid.append((row, col))
141 | 
142 |     def graph_to_grid(self, graph_tensor):
143 |         """
144 |         Convert graph tensor to grid using the spatial mapping:
145 | 
146 |         [B, N, C] -> [B, C, H, W]
147 |         """
148 |         batch_size, num_nodes, features = graph_tensor.shape
149 |         grid = torch.zeros(batch_size, features, *self.grid_shape, device=graph_tensor.device, dtype=graph_tensor.dtype)
150 |         for node_idx, (row, col) in enumerate(self.node_to_grid):
151 |             grid[..., row, col] = graph_tensor[..., node_idx, :]
152 |         return grid
153 | 
154 |     def grid_to_graph(self, grid_tensor):
155 |         """Convert grid to graph tensor: [B, C, H, W] -> [B, N, C]"""
156 |         batch_size, features, H, W = grid_tensor.shape
157 |         graph = torch.zeros(batch_size, H * W, features, device=grid_tensor.device, dtype=grid_tensor.dtype)
158 |         for node_idx, (row, col) in enumerate(self.node_to_grid):
159 |             graph[..., node_idx, :] = grid_tensor[..., row, col]
160 |         return graph
161 | 
162 |     def forward(self, features: torch.Tensor) -> torch.Tensor:
163 |         """
164 |         Compute the
new state of the forecast
165 | 
166 |         Args:
167 |             features: The input features, aligned with the order of lat_lons
168 | 
169 |         Returns:
170 |             The next state in the forecast
171 |         """
172 |         x, edge_idx, edge_attr = self.encoder(features)
173 |         x = self.processor(x, edge_idx, edge_attr)
174 |         x = self.decoder(x, features[..., : self.feature_dim])
175 | 
176 |         # The decoder returns node-format output [B, N, output_dim]; when
177 |         # constraints are enabled, it is converted to grid format
178 |         # [B, C, H, W] below before applying them.
179 | 
180 |         # Apply physical constraints to decoder output
181 |         if self.constraint_type != "none":
182 |             x = rearrange(x, "b (h w) c -> b c h w", h=self.grid_shape[0], w=self.grid_shape[1])
183 |             # Extract the low-res reference from the input.
184 |             # (The original features have shape [B, num_nodes, feature_dim])
185 |             lr = features[..., : self.feature_dim]  # shape: [B, num_nodes, feature_dim]
186 |             # Convert from node format to grid format using the grid_shape computed in __init__
187 |             # From [B, num_nodes, feature_dim] to [B, feature_dim, H, W]
188 |             lr = rearrange(lr, "b (h w) c -> b c h w", h=self.grid_shape[0], w=self.grid_shape[1])
189 |             if lr.size(1) != x.size(1):
190 |                 repeat_factor = x.size(1) // lr.size(1)
191 |                 lr = repeat(lr, "b c h w -> b (r c) h w", r=repeat_factor)
192 |             x = self.constraint(x, lr)
193 |         return x
--------------------------------------------------------------------------------
/graph_weather/models/gencast/__init__.py:
--------------------------------------------------------------------------------
1 | """Main import for GenCast"""
2 | 
3 | from .denoiser import Denoiser
4 | from .graph.graph_builder import GraphBuilder
5 | from .sampler import Sampler
6 | from .utils.noise import generate_isotropic_noise
7 | from .weighted_mse_loss import WeightedMSELoss
--------------------------------------------------------------------------------
/graph_weather/models/gencast/graph/__init__.py:
--------------------------------------------------------------------------------
1 | """Utils for graph generation."""
--------------------------------------------------------------------------------
/graph_weather/models/gencast/graph/grid_mesh_connectivity.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #      http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Source: https://github.com/google-deepmind/graphcast.
15 | """Tools for converting from regular grids on a sphere, to triangular meshes."""
16 | 
17 | import numpy as np
18 | import scipy
19 | import trimesh
20 | 
21 | from graph_weather.models.gencast.graph import icosahedral_mesh
22 | 
23 | 
24 | def _grid_lat_lon_to_coordinates(
25 |     grid_latitude: np.ndarray, grid_longitude: np.ndarray
26 | ) -> np.ndarray:
27 |     """Lat [num_lat] lon [num_lon] to 3d coordinates [num_lat, num_lon, 3]."""
28 |     # Convert to spherical coordinates phi and theta defined in the grid.
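    # Concretely: phi = deg2rad(lon) and theta = deg2rad(90 - lat) (colatitude),
    # so the stack below is the standard spherical-to-Cartesian conversion on
    # the unit sphere:
    #   x = cos(phi) sin(theta), y = sin(phi) sin(theta), z = cos(theta)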
29 | # Each [num_latitude_points, num_longitude_points] 30 | phi_grid, theta_grid = np.meshgrid(np.deg2rad(grid_longitude), np.deg2rad(90 - grid_latitude)) 31 | 32 | # [num_latitude_points, num_longitude_points, 3] 33 | # Note this assumes unit radius, since for now we model the earth as a 34 | # sphere of unit radius, and keep any vertical dimension as a regular grid. 35 | return np.stack( 36 | [ 37 | np.cos(phi_grid) * np.sin(theta_grid), 38 | np.sin(phi_grid) * np.sin(theta_grid), 39 | np.cos(theta_grid), 40 | ], 41 | axis=-1, 42 | ) 43 | 44 | 45 | def radius_query_indices( 46 | *, 47 | grid_latitude: np.ndarray, 48 | grid_longitude: np.ndarray, 49 | mesh: icosahedral_mesh.TriangularMesh, 50 | radius: float, 51 | ) -> tuple[np.ndarray, np.ndarray]: 52 | """Returns mesh-grid edge indices for radius query. 53 | 54 | Args: 55 | grid_latitude: Latitude values for the grid [num_lat_points] 56 | grid_longitude: Longitude values for the grid [num_lon_points] 57 | mesh: Mesh object. 58 | radius: Radius of connectivity in R3. for a sphere of unit radius. 59 | 60 | Returns: 61 | tuple with `grid_indices` and `mesh_indices` indicating edges between the 62 | grid and the mesh such that the distances in a straight line (not geodesic) 63 | are smaller than or equal to `radius`. 64 | * grid_indices: Indices of shape [num_edges], that index into a 65 | [num_lat_points, num_lon_points] grid, after flattening the leading axes. 66 | * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices. 67 | """ 68 | 69 | # [num_grid_points=num_lat_points * num_lon_points, 3] 70 | grid_positions = _grid_lat_lon_to_coordinates(grid_latitude, grid_longitude).reshape([-1, 3]) 71 | 72 | # [num_mesh_points, 3] 73 | mesh_positions = mesh.vertices 74 | kd_tree = scipy.spatial.cKDTree(mesh_positions) 75 | 76 | # [num_grid_points, num_mesh_points_per_grid_point] 77 | # Note `num_mesh_points_per_grid_point` is not constant, so this is a list 78 | # of arrays, rather than a 2d array. 79 | query_indices = kd_tree.query_ball_point(x=grid_positions, r=radius) 80 | 81 | grid_edge_indices = [] 82 | mesh_edge_indices = [] 83 | for grid_index, mesh_neighbors in enumerate(query_indices): 84 | grid_edge_indices.append(np.repeat(grid_index, len(mesh_neighbors))) 85 | mesh_edge_indices.append(mesh_neighbors) 86 | 87 | # [num_edges] 88 | grid_edge_indices = np.concatenate(grid_edge_indices, axis=0).astype(int) 89 | mesh_edge_indices = np.concatenate(mesh_edge_indices, axis=0).astype(int) 90 | 91 | return grid_edge_indices, mesh_edge_indices 92 | 93 | 94 | def in_mesh_triangle_indices( 95 | *, grid_latitude: np.ndarray, grid_longitude: np.ndarray, mesh: icosahedral_mesh.TriangularMesh 96 | ) -> tuple[np.ndarray, np.ndarray]: 97 | """Returns mesh-grid edge indices for grid points contained in mesh triangles. 98 | 99 | Args: 100 | grid_latitude: Latitude values for the grid [num_lat_points] 101 | grid_longitude: Longitude values for the grid [num_lon_points] 102 | mesh: Mesh object. 103 | 104 | Returns: 105 | tuple with `grid_indices` and `mesh_indices` indicating edges between the 106 | grid and the mesh vertices of the triangle that contain each grid point. 107 | The number of edges is always num_lat_points * num_lon_points * 3 108 | * grid_indices: Indices of shape [num_edges], that index into a 109 | [num_lat_points, num_lon_points] grid, after flattening the leading axes. 110 | * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices. 
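
    Example (illustrative sketch; assumes a mesh built with the helpers in
    icosahedral_mesh and 1D numpy arrays lats/lons):

        grid_idx, mesh_idx = in_mesh_triangle_indices(
            grid_latitude=lats, grid_longitude=lons, mesh=mesh
        )
        # grid_idx.shape == mesh_idx.shape == (len(lats) * len(lons) * 3,)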
111 | """ 112 | 113 | # [num_grid_points=num_lat_points * num_lon_points, 3] 114 | grid_positions = _grid_lat_lon_to_coordinates(grid_latitude, grid_longitude).reshape([-1, 3]) 115 | 116 | mesh_trimesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces) 117 | 118 | # [num_grid_points] with mesh face indices for each grid point. 119 | _, _, query_face_indices = trimesh.proximity.closest_point(mesh_trimesh, grid_positions) 120 | 121 | # [num_grid_points, 3] with mesh node indices for each grid point. 122 | mesh_edge_indices = mesh.faces[query_face_indices] 123 | 124 | # [num_grid_points, 3] with grid node indices, where every row simply contains 125 | # the row (grid_point) index. 126 | grid_indices = np.arange(grid_positions.shape[0]) 127 | grid_edge_indices = np.tile(grid_indices.reshape([-1, 1]), [1, 3]) 128 | 129 | # Flatten to get a regular list. 130 | # [num_edges=num_grid_points*3] 131 | mesh_edge_indices = mesh_edge_indices.reshape([-1]) 132 | grid_edge_indices = grid_edge_indices.reshape([-1]) 133 | 134 | return grid_edge_indices, mesh_edge_indices 135 | -------------------------------------------------------------------------------- /graph_weather/models/gencast/images/animated.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/graph_weather/abf13a36e95ba6b699027e1d4c6760760aae280f/graph_weather/models/gencast/images/animated.gif -------------------------------------------------------------------------------- /graph_weather/models/gencast/images/autoregressive.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/graph_weather/abf13a36e95ba6b699027e1d4c6760760aae280f/graph_weather/models/gencast/images/autoregressive.gif -------------------------------------------------------------------------------- /graph_weather/models/gencast/images/fullmodel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/graph_weather/abf13a36e95ba6b699027e1d4c6760760aae280f/graph_weather/models/gencast/images/fullmodel.png -------------------------------------------------------------------------------- /graph_weather/models/gencast/images/readme.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/graph_weather/abf13a36e95ba6b699027e1d4c6760760aae280f/graph_weather/models/gencast/images/readme.md -------------------------------------------------------------------------------- /graph_weather/models/gencast/layers/__init__.py: -------------------------------------------------------------------------------- 1 | """GenCast layers.""" 2 | 3 | from .decoder import Decoder 4 | from .encoder import Encoder 5 | from .modules import MLP, InteractionNetwork 6 | -------------------------------------------------------------------------------- /graph_weather/models/gencast/layers/decoder.py: -------------------------------------------------------------------------------- 1 | """Decoder layer. 2 | 3 | The decoder: 4 | - perform a single message-passing step on mesh2grid using a classical interaction network. 5 | - add a residual connection to the grid nodes. 
6 | """ 7 | 8 | import torch 9 | 10 | from graph_weather.models.gencast.layers.modules import MLP, InteractionNetwork 11 | 12 | 13 | class Decoder(torch.nn.Module): 14 | """GenCast's decoder.""" 15 | 16 | def __init__( 17 | self, 18 | edges_dim: int, 19 | output_dim: int, 20 | hidden_dims: list[int], 21 | activation_layer: torch.nn.Module = torch.nn.ReLU, 22 | use_layer_norm: bool = True, 23 | ): 24 | """Initialize the Decoder. 25 | 26 | Args: 27 | edges_dim (int): dimension of edges' features. 28 | output_dim (int): dimension of final output. 29 | hidden_dims (list[int]): hidden dimensions of internal MLPs. 30 | activation_layer (torch.nn.Module, optional): activation function of internal MLPs. 31 | Defaults to torch.nn.ReLU. 32 | use_layer_norm (bool, optional): if true add a LayerNorm at the end of each MLP. 33 | Defaults to True. 34 | """ 35 | super().__init__() 36 | 37 | # All the MLPs in GenCast have same hidden and output dims. Hence, the embedding latent 38 | # dimension and the MLPs' output dimension are the same. Moreover, for simplicity, we will 39 | # ask the hidden dims just once for each MLP in a module: we don't need to specify them 40 | # individually as arguments, even if the MLPs could have different roles. 41 | self.latent_dim = hidden_dims[-1] 42 | 43 | # Embedders 44 | self.edges_mlp = MLP( 45 | input_dim=edges_dim, 46 | hidden_dims=hidden_dims, 47 | activation_layer=activation_layer, 48 | use_layer_norm=use_layer_norm, 49 | bias=True, 50 | activate_final=False, 51 | ) 52 | 53 | # Message Passing 54 | self.gnn = InteractionNetwork( 55 | sender_dim=self.latent_dim, 56 | receiver_dim=self.latent_dim, 57 | edge_attr_dim=self.latent_dim, 58 | hidden_dims=hidden_dims, 59 | use_layer_norm=use_layer_norm, 60 | activation_layer=activation_layer, 61 | ) 62 | 63 | # Final grid nodes update 64 | self.grid_mlp_final = MLP( 65 | input_dim=self.latent_dim, 66 | hidden_dims=hidden_dims[:-1] + [output_dim], 67 | activation_layer=activation_layer, 68 | use_layer_norm=use_layer_norm, 69 | bias=True, 70 | activate_final=False, 71 | ) 72 | 73 | def forward( 74 | self, 75 | input_mesh_nodes: torch.Tensor, 76 | input_grid_nodes: torch.Tensor, 77 | input_edge_attr: torch.Tensor, 78 | edge_index: torch.Tensor, 79 | ) -> torch.Tensor: 80 | """Forward pass. 81 | 82 | Args: 83 | input_mesh_nodes (torch.Tensor): mesh nodes' features. 84 | input_grid_nodes (torch.Tensor): grid nodes' features. 85 | input_edge_attr (torch.Tensor): grid2mesh edges' features. 86 | edge_index (torch.Tensor): edge index tensor. 87 | 88 | Returns: 89 | torch.Tensor: output grid nodes. 90 | """ 91 | if not ( 92 | input_grid_nodes.shape[-1] == self.latent_dim 93 | and input_mesh_nodes.shape[-1] == self.latent_dim 94 | ): 95 | raise ValueError( 96 | "The dimension of grid nodes and mesh nodes' features must be " 97 | "equal to the last hidden dimension." 98 | ) 99 | 100 | # Embedding 101 | edges_emb = self.edges_mlp(input_edge_attr) 102 | 103 | # Message-passing + residual connection 104 | latent_grid_nodes = input_grid_nodes + self.gnn( 105 | x=(input_mesh_nodes, input_grid_nodes), 106 | edge_index=edge_index, 107 | edge_attr=edges_emb, 108 | ) 109 | 110 | # Update grid nodes 111 | latent_grid_nodes = self.grid_mlp_final(latent_grid_nodes) 112 | 113 | return latent_grid_nodes 114 | -------------------------------------------------------------------------------- /graph_weather/models/gencast/layers/encoder.py: -------------------------------------------------------------------------------- 1 | """Encoder layer. 
2 | 
3 | The encoder:
4 | - embeds grid nodes, mesh nodes and g2m edges' features into the latent space.
5 | - performs a single message-passing step using a classical interaction network.
6 | - adds a residual connection to the mesh and grid nodes.
7 | """
8 | 
9 | import torch
10 | 
11 | from graph_weather.models.gencast.layers.modules import MLP, InteractionNetwork
12 | 
13 | 
14 | class Encoder(torch.nn.Module):
15 |     """GenCast's encoder."""
16 | 
17 |     def __init__(
18 |         self,
19 |         grid_dim: int,
20 |         mesh_dim: int,
21 |         edge_dim: int,
22 |         hidden_dims: list[int],
23 |         activation_layer: torch.nn.Module = torch.nn.ReLU,
24 |         use_layer_norm: bool = True,
25 |         scale_factor: float = 1.0,
26 |     ):
27 |         """Initialize the Encoder.
28 | 
29 |         Args:
30 |             grid_dim (int): dimension of grid nodes' features.
31 |             mesh_dim (int): dimension of mesh nodes' features.
32 |             edge_dim (int): dimension of g2m edges' features.
33 |             hidden_dims (list[int]): hidden dimensions of internal MLPs.
34 |             activation_layer (torch.nn.Module, optional): activation function of internal MLPs.
35 |                 Defaults to torch.nn.ReLU.
36 |             use_layer_norm (bool, optional): if true add a LayerNorm at the end of each MLP.
37 |                 Defaults to True.
38 |             scale_factor (float): the message of the interaction network between the grid and
39 |                 the mesh is multiplied by the scale factor. Useful when fine-tuning a pretrained
40 |                 model to a higher resolution. Defaults to 1.
41 |         """
42 |         super().__init__()
43 | 
44 |         # All the MLPs in GenCast have the same hidden and output dims. Hence, the embedding
45 |         # latent dimension and the MLPs' output dimension are the same. Moreover, for simplicity,
46 |         # we ask for the hidden dims just once per module: we don't need to specify them
47 |         # individually as arguments, even if the MLPs could have different roles.
48 |         self.latent_dim = hidden_dims[-1]
49 | 
50 |         # Embedders
51 |         self.grid_mlp = MLP(
52 |             input_dim=grid_dim,
53 |             hidden_dims=hidden_dims,
54 |             activation_layer=activation_layer,
55 |             use_layer_norm=use_layer_norm,
56 |             bias=True,
57 |             activate_final=False,
58 |         )
59 | 
60 |         self.mesh_mlp = MLP(
61 |             input_dim=mesh_dim,
62 |             hidden_dims=hidden_dims,
63 |             activation_layer=activation_layer,
64 |             use_layer_norm=use_layer_norm,
65 |             bias=True,
66 |             activate_final=False,
67 |         )
68 | 
69 |         self.edges_mlp = MLP(
70 |             input_dim=edge_dim,
71 |             hidden_dims=hidden_dims,
72 |             activation_layer=activation_layer,
73 |             use_layer_norm=use_layer_norm,
74 |             bias=True,
75 |             activate_final=False,
76 |         )
77 | 
78 |         # Message Passing
79 |         self.gnn = InteractionNetwork(
80 |             sender_dim=self.latent_dim,
81 |             receiver_dim=self.latent_dim,
82 |             edge_attr_dim=self.latent_dim,
83 |             hidden_dims=hidden_dims,
84 |             use_layer_norm=use_layer_norm,
85 |             activation_layer=activation_layer,
86 |             scale_factor=scale_factor,
87 |         )
88 | 
89 |         # Final grid nodes update
90 |         self.grid_mlp_final = MLP(
91 |             input_dim=self.latent_dim,
92 |             hidden_dims=hidden_dims,
93 |             activation_layer=activation_layer,
94 |             use_layer_norm=use_layer_norm,
95 |             bias=True,
96 |             activate_final=False,
97 |         )
98 | 
99 |     def forward(
100 |         self,
101 |         input_grid_nodes: torch.Tensor,
102 |         input_mesh_nodes: torch.Tensor,
103 |         input_edge_attr: torch.Tensor,
104 |         edge_index: torch.Tensor,
105 |     ) -> tuple[torch.Tensor, torch.Tensor]:
106 |         """Forward pass.
107 | 
108 |         Args:
109 |             input_grid_nodes (torch.Tensor): grid nodes' features.
110 |             input_mesh_nodes (torch.Tensor): mesh nodes' features.
111 |             input_edge_attr (torch.Tensor): grid2mesh edges' features.
112 | edge_index (torch.Tensor): edge index tensor. 113 | 114 | Returns: 115 | tuple[torch.Tensor, torch.Tensor]: output grid nodes, output mesh nodes. 116 | """ 117 | 118 | # Embedding 119 | grid_emb = self.grid_mlp(input_grid_nodes) 120 | mesh_emb = self.mesh_mlp(input_mesh_nodes) 121 | edges_emb = self.edges_mlp(input_edge_attr) 122 | 123 | # Message-passing + residual connection 124 | latent_mesh_nodes = mesh_emb + self.gnn( 125 | x=(grid_emb, mesh_emb), 126 | edge_index=edge_index, 127 | edge_attr=edges_emb, 128 | ) 129 | 130 | # Update grid nodes + residual connection 131 | latent_grid_nodes = grid_emb + self.grid_mlp_final(grid_emb) 132 | 133 | return latent_grid_nodes, latent_mesh_nodes 134 | -------------------------------------------------------------------------------- /graph_weather/models/gencast/layers/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | """Experimental features.""" 2 | 3 | from .sparse_transformer import SparseTransformer 4 | -------------------------------------------------------------------------------- /graph_weather/models/gencast/layers/experimental/sparse_transformer.py: -------------------------------------------------------------------------------- 1 | """Experimental sparse transformer using DGL sparse.""" 2 | 3 | import dgl.sparse as dglsp 4 | import torch 5 | import torch.nn as nn 6 | 7 | from graph_weather.models.gencast.layers.modules import ConditionalLayerNorm 8 | 9 | 10 | class SparseAttention(nn.Module): 11 | """Sparse Multi-head Attention Module""" 12 | 13 | def __init__(self, input_dim=512, output_dim=512, num_heads=4): 14 | """Initialize Sparse MultiHead attention module. 15 | 16 | Args: 17 | input_dim (int): input dimension. Defaults to 512. 18 | output_dim (int): output dimension. Defaults to 512. 19 | num_heads (int): number of heads. Output dimension should be divisible by num_heads. 20 | Defaults to 4. 21 | """ 22 | super().__init__() 23 | if output_dim % num_heads: 24 | raise ValueError("Output dimension should be divisible by the number of heads.") 25 | 26 | self.hidden_size = output_dim 27 | self.num_heads = num_heads 28 | self.head_dim = output_dim // num_heads 29 | self.scaling = self.head_dim**-0.5 30 | 31 | self.q_proj = nn.Linear(input_dim, output_dim) 32 | self.k_proj = nn.Linear(input_dim, output_dim) 33 | self.v_proj = nn.Linear(input_dim, output_dim) 34 | self.out_proj = nn.Linear(output_dim, output_dim) 35 | 36 | def forward(self, x: torch.Tensor, adj: dglsp.SparseMatrix): 37 | """Forward pass of SparseMHA. 38 | 39 | Args: 40 | x (torch.Tensor): input tensor. 41 | adj (SparseMatrix): adjacency matrix in DGL SparseMatrix format. 42 | 43 | Returns: 44 | y (tensor): output of MultiHead attention. 45 | """ 46 | N = len(x) 47 | # computing query,key and values. 48 | q = self.q_proj(x).reshape(N, self.head_dim, self.num_heads) # (dense) [N, dh, nh] 49 | k = self.k_proj(x).reshape(N, self.head_dim, self.num_heads) # (dense) [N, dh, nh] 50 | v = self.v_proj(x).reshape(N, self.head_dim, self.num_heads) # (dense) [N, dh, nh] 51 | # scaling query 52 | q *= self.scaling 53 | 54 | # sparse-dense-dense product 55 | attn = dglsp.bsddmm(adj, q, k.transpose(1, 0)) # (sparse) [N, N, nh] 56 | 57 | # sparse softmax (by default applies on the last sparse dimension). 
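        # Note (an interpretation of the DGL sparse semantics, not from the
        # original source): normalizing over the last sparse dimension means
        # each node's attention scores are normalized only over its incident
        # (nonzero) edges, not over all N nodes.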
58 |         attn = attn.softmax()  # (sparse) [N, N, nh]
59 | 
60 |         # sparse-dense multiplication
61 |         out = dglsp.bspmm(attn, v)  # (dense) [N, dh, nh]
62 |         return self.out_proj(out.reshape(N, -1))
63 | 
64 | 
65 | class SparseTransformer(nn.Module):
66 |     """A single transformer block for graph neural networks.
67 | 
68 |     This module implements a single transformer block with a sparse attention mechanism.
69 |     """
70 | 
71 |     def __init__(
72 |         self,
73 |         conditioning_dim: int,
74 |         input_dim: int,
75 |         output_dim: int,
76 |         num_heads: int,
77 |         activation_layer: torch.nn.Module = nn.ReLU,
78 |         norm_first: bool = True,
79 |     ):
80 |         """Initialize SparseTransformer module.
81 | 
82 |         Args:
83 |             conditioning_dim (int, optional): dimension of the conditioning parameter. If None the
84 |                 layer normalization will not be applied.
85 |             input_dim (int): dimension of the input features.
86 |             output_dim (int): dimension of the output features.
87 |             num_heads (int): number of heads for multi-head attention.
88 |             activation_layer (torch.nn.Module): activation function applied before
89 |                 returning the output.
90 |             norm_first (bool): if True apply layer normalization before attention. Defaults to
91 |                 True.
92 |         """
93 |         super().__init__()
94 | 
95 |         # initialize multihead sparse attention.
96 |         self.sparse_attention = SparseAttention(
97 |             input_dim=input_dim, output_dim=output_dim, num_heads=num_heads
98 |         )
99 | 
100 |         # initialize mlp
101 |         self.activation = activation_layer()
102 |         self.mlp = nn.Sequential(
103 |             nn.Linear(output_dim, output_dim), self.activation, nn.Linear(output_dim, output_dim)
104 |         )
105 | 
106 |         # initialize conditional layer normalization
107 |         self.cond_norm_1 = ConditionalLayerNorm(
108 |             conditioning_dim=conditioning_dim, features_dim=output_dim
109 |         )
110 |         self.cond_norm_2 = ConditionalLayerNorm(
111 |             conditioning_dim=conditioning_dim, features_dim=output_dim
112 |         )
113 | 
114 |         self.norm_first = norm_first
115 | 
116 |     def forward(
117 |         self,
118 |         x: torch.Tensor,
119 |         edge_index: torch.Tensor,
120 |         cond_param: torch.Tensor,
121 |         *args,
122 |         **kwargs,
123 |     ) -> torch.Tensor:
124 |         """Apply SparseTransformer to input.
125 | 
126 |         Input and conditioning parameter must have same batch size.
127 | 
128 |         Args:
129 |             x (torch.Tensor): tensor containing nodes features.
130 |             edge_index (torch.Tensor): edge index tensor.
131 |             cond_param (torch.Tensor): conditioning parameter.
132 |             *args: ignored by the module.
133 |             **kwargs: ignored by the module.
134 | 
135 |         Returns:
136 |             torch.Tensor: the updated node features.
137 |         """
138 |         if self.norm_first:
139 |             x1 = self.cond_norm_1(x, cond_param)
140 |             x = x + self.sparse_attention(
141 |                 x=x1, adj=dglsp.spmatrix(indices=edge_index, shape=(x.shape[0], x.shape[0]))
142 |             )
143 |         else:
144 |             x = x + self.sparse_attention(
145 |                 x=x, adj=dglsp.spmatrix(indices=edge_index, shape=(x.shape[0], x.shape[0]))
146 |             )
147 |             x = self.cond_norm_1(x, cond_param)
148 | 
149 |         if self.norm_first:
150 |             x2 = self.cond_norm_2(x, cond_param)
151 |             x = x + self.mlp(x2)
152 |         else:
153 |             x = x + self.mlp(x)
154 |             x = self.cond_norm_2(x, cond_param)
155 |         return x
--------------------------------------------------------------------------------
/graph_weather/models/gencast/layers/processor.py:
--------------------------------------------------------------------------------
1 | """Processor layer.
2 | 
3 | The processor:
4 | - computes a sequence of transformer blocks applied to the mesh.
5 | - conditions on the noise level.
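
A rough call sketch (illustrative only, mirroring Processor.forward below):

    noise_emb = fourier_embedder(noise_levels)
    for block in cond_transformers:
        x = block(x=x, edge_index=edge_index, cond_param=noise_emb)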
6 | """ 7 | 8 | import torch 9 | 10 | from graph_weather.models.gencast.layers.modules import MLP, CondTransformerBlock, FourierEmbedding 11 | 12 | try: 13 | from graph_weather.models.gencast.layers.experimental import SparseTransformer 14 | 15 | has_dgl = True 16 | except ImportError: 17 | has_dgl = False 18 | 19 | 20 | class Processor(torch.nn.Module): 21 | """GenCast's Processor 22 | 23 | The Processor is a sequence of transformer blocks conditioned on noise level. If the graph has 24 | many edges, setting sparse=True may perform better in terms of memory and speed. Note that 25 | sparse=False uses PyG as the backend, while sparse=True uses DGL. The two implementations are 26 | not exactly equivalent: the former is described in the paper "Masked Label Prediction: Unified 27 | Message Passing Model for Semi-Supervised Classification" and can also handle edge features, 28 | while the latter is a classical transformer that performs multi-head attention utilizing the 29 | mask's sparsity and does not include edge features in the computations. 30 | 31 | Note: The GenCast paper does not provide specific details regarding the implementation of the 32 | transformer architecture for graphs. 33 | """ 34 | 35 | def __init__( 36 | self, 37 | latent_dim: int, 38 | hidden_dims: list[int], 39 | num_blocks: int, 40 | num_heads: int, 41 | num_frequencies: int, 42 | base_period: int, 43 | noise_emb_dim: int, 44 | edges_dim: int | None = None, 45 | activation_layer: torch.nn.Module = torch.nn.ReLU, 46 | use_layer_norm: bool = True, 47 | sparse: bool = False, 48 | ): 49 | """Initialize the Processor. 50 | 51 | Args: 52 | latent_dim (int): dimension of nodes' features. 53 | hidden_dims (list[int]): hidden dimensions of internal MLPs. 54 | num_blocks (int): number of transformer blocks. 55 | num_heads (int): number of heads for multi-head attention. 56 | num_frequencies (int): number of frequencies for the noise Fourier embedding. 57 | base_period (int): base period for the noise Fourier embedding. 58 | noise_emb_dim (int): dimension of output of noise embedding. 59 | edges_dim (int, optional): dimension of edges' features. If None does not uses edges 60 | features in TransformerConv. Defaults to None. 61 | activation_layer (torch.nn.Module): activation function of internal MLPs. 62 | Defaults to torch.nn.ReLU. 63 | use_layer_norm (bool): if true add a LayerNorm at the end of the embedding MLP. 64 | Defaults to True. 65 | sparse (bool): if true use DGL as backend (experimental). Defaults to False. 
66 | """ 67 | super().__init__() 68 | self.latent_dim = latent_dim 69 | if latent_dim % num_heads != 0: 70 | raise ValueError("The latent dimension should be divisible by the number of heads.") 71 | 72 | # Embedders 73 | self.fourier_embedder = FourierEmbedding( 74 | output_dim=noise_emb_dim, num_frequencies=num_frequencies, base_period=base_period 75 | ) 76 | 77 | self.edges_dim = edges_dim 78 | if edges_dim is not None: 79 | self.edges_mlp = MLP( 80 | input_dim=edges_dim, 81 | hidden_dims=hidden_dims, 82 | activation_layer=activation_layer, 83 | use_layer_norm=use_layer_norm, 84 | bias=True, 85 | activate_final=False, 86 | ) 87 | 88 | # Tranformers Blocks 89 | self.cond_transformers = torch.nn.ModuleList() 90 | if not sparse: 91 | for _ in range(num_blocks - 1): 92 | # concatenating multi-head attention 93 | self.cond_transformers.append( 94 | CondTransformerBlock( 95 | conditioning_dim=noise_emb_dim, 96 | input_dim=latent_dim, 97 | output_dim=latent_dim // num_heads, 98 | edges_dim=hidden_dims[-1] if (edges_dim is not None) else None, 99 | num_heads=num_heads, 100 | concat=True, 101 | beta=True, 102 | activation_layer=activation_layer, 103 | ) 104 | ) 105 | 106 | # averaging multi-head attention 107 | self.cond_transformers.append( 108 | CondTransformerBlock( 109 | conditioning_dim=noise_emb_dim, 110 | input_dim=latent_dim, 111 | output_dim=latent_dim, 112 | edges_dim=hidden_dims[-1] if (edges_dim is not None) else None, 113 | num_heads=num_heads, 114 | concat=False, 115 | beta=True, 116 | activation_layer=None, 117 | ) 118 | ) 119 | else: 120 | if not has_dgl: 121 | raise ValueError("Please install DGL to use sparsity.") 122 | 123 | for _ in range(num_blocks): 124 | # concatenating multi-head attention 125 | self.cond_transformers.append( 126 | SparseTransformer( 127 | conditioning_dim=noise_emb_dim, 128 | input_dim=latent_dim, 129 | output_dim=latent_dim, 130 | num_heads=num_heads, 131 | activation_layer=activation_layer, 132 | ) 133 | ) 134 | # do we really need averaging for last block? 135 | 136 | def _check_args(self, latent_mesh_nodes, noise_levels, input_edge_attr): 137 | if not latent_mesh_nodes.shape[-1] == self.latent_dim: 138 | raise ValueError( 139 | "The dimension of the mesh nodes is different from the latent dimension provided at" 140 | " initialization." 141 | ) 142 | 143 | if not latent_mesh_nodes.shape[0] == noise_levels.shape[0]: 144 | raise ValueError( 145 | "The number of noise levels and mesh nodes should be the same, but got " 146 | f"{latent_mesh_nodes.shape[0]} and {noise_levels.shape[0]}. Eventually repeat the " 147 | " noise level for each node in the same batch." 148 | ) 149 | 150 | if (input_edge_attr is not None) and (self.edges_dim is None): 151 | raise ValueError("To use input_edge_attr initialize the processor with edges_dim.") 152 | 153 | def forward( 154 | self, 155 | latent_mesh_nodes: torch.Tensor, 156 | edge_index: torch.Tensor, 157 | noise_levels: torch.Tensor, 158 | input_edge_attr: torch.Tensor | None = None, 159 | ) -> torch.Tensor: 160 | """Forward pass. 161 | 162 | Args: 163 | latent_mesh_nodes (torch.Tensor): mesh nodes' features. 164 | edge_index (torch.Tensor): edge index tensor. 165 | noise_levels (torch.Tensor): log-noise levels. 166 | input_edge_attr (torch.Tensor, optional): mesh edges' features. 167 | 168 | Returns: 169 | torch.Tensor: latent mesh nodes. 
170 | """ 171 | self._check_args(latent_mesh_nodes, noise_levels, input_edge_attr) 172 | 173 | # embedding 174 | noise_emb = self.fourier_embedder(noise_levels) 175 | 176 | if self.edges_dim is not None: 177 | edges_emb = self.edges_mlp(input_edge_attr) 178 | else: 179 | edges_emb = None 180 | 181 | # apply transformer blocks 182 | for cond_transformer in self.cond_transformers: 183 | latent_mesh_nodes = cond_transformer( 184 | x=latent_mesh_nodes, 185 | edge_index=edge_index, 186 | cond_param=noise_emb, 187 | edge_attr=edges_emb, 188 | ) 189 | 190 | return latent_mesh_nodes 191 | -------------------------------------------------------------------------------- /graph_weather/models/gencast/sampler.py: -------------------------------------------------------------------------------- 1 | """Diffusion sampler""" 2 | 3 | import math 4 | 5 | import torch 6 | 7 | from graph_weather.models.gencast import Denoiser 8 | from graph_weather.models.gencast.utils.noise import generate_isotropic_noise 9 | 10 | 11 | class Sampler: 12 | """Sampler for the denoiser. 13 | 14 | The sampler consists in the second-order DPMSolver++2S solver (Lu et al., 2022), augmented with 15 | the stochastic churn (again making use of the isotropic noise) and noise inflation techniques 16 | used in Karras et al. (2022) to inject further stochasticity into the sampling process. In 17 | conditioning on previous timesteps it follows the Conditional Denoising Estimator approach 18 | outlined and motivated by Batzolis et al. (2021). 19 | """ 20 | 21 | def __init__( 22 | self, 23 | S_noise: float = 1.05, 24 | S_tmin: float = 0.75, 25 | S_tmax: float = 80.0, 26 | S_churn: float = 2.5, 27 | r: float = 0.5, 28 | sigma_max: float = 80.0, 29 | sigma_min: float = 0.03, 30 | rho: float = 7, 31 | num_steps: int = 20, 32 | ): 33 | """Initialize the sampler. 34 | 35 | Args: 36 | S_noise (float): noise inflation parameter. Defaults to 1.05. 37 | S_tmin (float): minimum noise for sampling. Defaults to 0.75. 38 | S_tmax (float): maximum noise for sampling. Defaults to 80. 39 | S_churn (float): stochastic churn rate. Defaults to 2.5. 40 | r (float): _description_. Defaults to 0.5. 41 | sigma_max (float): maximum value of sigma for sigma's distribution. Defaults to 80. 42 | sigma_min (float): minimum value of sigma for sigma's distribution. Defaults to 0.03. 43 | rho (float): exponent of the sigma's distribution. Defaults to 7. 44 | num_steps (int): number of timesteps during sampling. Defaults to 20. 45 | """ 46 | self.S_noise = S_noise 47 | self.S_tmin = S_tmin 48 | self.S_tmax = S_tmax 49 | self.S_churn = S_churn 50 | self.r = r 51 | self.num_steps = num_steps 52 | 53 | self.sigma_max = sigma_max 54 | self.sigma_min = sigma_min 55 | self.rho = rho 56 | 57 | def _sigmas_fn(self, u): 58 | return ( 59 | self.sigma_max ** (1 / self.rho) 60 | + u * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)) 61 | ) ** self.rho 62 | 63 | @torch.no_grad() 64 | def sample(self, denoiser: Denoiser, prev_inputs: torch.Tensor): 65 | """Generate a sample from random noise for the given inputs. 66 | 67 | Args: 68 | denoiser (Denoiser): the denoiser model. 69 | prev_inputs (torch.Tensor): previous two timesteps. 70 | 71 | Returns: 72 | torch.Tensor: normalized residuals predicted. 
73 | """ 74 | device = prev_inputs.device 75 | 76 | time_steps = torch.arange(0, self.num_steps).to(device) / (self.num_steps - 1) 77 | sigmas = self._sigmas_fn(time_steps) 78 | 79 | batch_ones = torch.ones(1, 1).to(device) 80 | 81 | # initialize noise 82 | x = sigmas[0] * torch.tensor( 83 | generate_isotropic_noise( 84 | num_lon=denoiser.num_lon, 85 | num_lat=denoiser.num_lat, 86 | num_samples=denoiser.output_features_dim, 87 | ) 88 | ).unsqueeze(0).to(device) 89 | 90 | for i in range(len(sigmas) - 1): 91 | # stochastic churn from Karras et al. (Alg. 2) 92 | gamma = ( 93 | min(self.S_churn / self.num_steps, math.sqrt(2) - 1) 94 | if self.S_tmin <= sigmas[i] <= self.S_tmax 95 | else 0.0 96 | ) 97 | # noise inflation from Karras et al. (Alg. 2) 98 | noise = self.S_noise * torch.tensor( 99 | generate_isotropic_noise( 100 | num_lon=denoiser.num_lon, 101 | num_lat=denoiser.num_lat, 102 | num_samples=denoiser.output_features_dim, 103 | ) 104 | ) 105 | noise = noise.unsqueeze(0).to(device) 106 | 107 | sigma_hat = sigmas[i] * (gamma + 1) 108 | if gamma > 0: 109 | x = x + (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 * noise 110 | denoised = denoiser(x, prev_inputs, sigma_hat * batch_ones) 111 | 112 | if i == len(sigmas) - 2: 113 | # final Euler step 114 | d = (x - denoised) / sigma_hat 115 | x = x + d * (sigmas[i + 1] - sigma_hat) 116 | else: 117 | # DPMSolver++2S step (Alg. 1 in Lu et al.) with alpha_t=1. 118 | # t_{i-1} is t_hat because of stochastic churn! 119 | lambda_hat = -torch.log(sigma_hat) 120 | lambda_next = -torch.log(sigmas[i + 1]) 121 | h = lambda_next - lambda_hat 122 | lambda_mid = lambda_hat + self.r * h 123 | sigma_mid = torch.exp(-lambda_mid) 124 | 125 | u = sigma_mid / sigma_hat * x - (torch.exp(-self.r * h) - 1) * denoised 126 | denoised_2 = denoiser(u, prev_inputs, sigma_mid * batch_ones) 127 | D = (1 - 1 / (2 * self.r)) * denoised + 1 / (2 * self.r) * denoised_2 128 | x = sigmas[i + 1] / sigma_hat * x - (torch.exp(-h) - 1) * D 129 | 130 | return x 131 | -------------------------------------------------------------------------------- /graph_weather/models/gencast/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utils for gencast.""" 2 | 3 | from .noise import generate_isotropic_noise 4 | -------------------------------------------------------------------------------- /graph_weather/models/gencast/utils/batching.py: -------------------------------------------------------------------------------- 1 | """Utils for batching graphs.""" 2 | 3 | import torch 4 | 5 | 6 | def batch(senders, edge_index, edge_attr=None, batch_size=1): 7 | """Build big batched graph. 8 | 9 | Returns nodes and edges of a big graph with batch_size disconnected copies of the original 10 | graph, with features shape [(b n) f]. 11 | 12 | Args: 13 | senders (torch.Tensor): nodes' features. 14 | edge_index (torch.Tensor): edge index tensor. 15 | edge_attr (torch.Tensor, optional): edge attributes tensor, if None returns None. 16 | Defaults to None. 17 | batch_size (int): batch size. Defaults to 1. 
18 | 
19 |     Returns:
20 |         batched_senders, batched_edge_index, batched_edge_attr
21 |     """
22 |     ns = senders.shape[0]
23 |     batched_senders = senders
24 |     batched_edge_attr = edge_attr
25 |     batched_edge_index = edge_index
26 | 
27 |     for i in range(1, batch_size):
28 |         batched_senders = torch.cat([batched_senders, senders], dim=0)
29 |         batched_edge_index = torch.cat([batched_edge_index, edge_index + i * ns], dim=1)
30 | 
31 |         if edge_attr is not None:
32 |             batched_edge_attr = torch.cat([batched_edge_attr, edge_attr], dim=0)
33 | 
34 |     return batched_senders, batched_edge_index, batched_edge_attr
35 | 
36 | 
37 | def hetero_batch(senders, receivers, edge_index, edge_attr=None, batch_size=1):
38 |     """Build big batched heterogeneous graph.
39 | 
40 |     Returns nodes and edges of a big graph with batch_size disconnected copies of the original
41 |     graph, with features shape [(b n) f].
42 | 
43 |     Args:
44 |         senders (torch.Tensor): senders' features.
45 |         receivers (torch.Tensor): receivers' features.
46 |         edge_index (torch.Tensor): edge index tensor.
47 |         edge_attr (torch.Tensor, optional): edge attributes tensor, if None returns None.
48 |             Defaults to None.
49 |         batch_size (int): batch size. Defaults to 1.
50 | 
51 |     Returns:
52 |         batched_senders, batched_edge_index, batched_edge_attr
53 |     """
54 |     ns = senders.shape[0]
55 |     nr = receivers.shape[0]
56 |     nodes_shape = torch.tensor([[ns], [nr]]).to(edge_index)
57 |     batched_senders = senders
58 |     batched_receivers = receivers
59 |     batched_edge_attr = edge_attr
60 |     batched_edge_index = edge_index
61 | 
62 |     for i in range(1, batch_size):
63 |         batched_senders = torch.cat([batched_senders, senders], dim=0)
64 |         batched_receivers = torch.cat([batched_receivers, receivers], dim=0)
65 |         batched_edge_index = torch.cat([batched_edge_index, edge_index + i * nodes_shape], dim=1)
66 |         if edge_attr is not None:
67 |             batched_edge_attr = torch.cat([batched_edge_attr, edge_attr], dim=0)
68 | 
69 |     return batched_senders, batched_receivers, batched_edge_index, batched_edge_attr
--------------------------------------------------------------------------------
/graph_weather/models/gencast/utils/noise.py:
--------------------------------------------------------------------------------
1 | """Noise generation utils."""
2 | 
3 | import einops
4 | import numpy as np
5 | import torch
6 | import torch_harmonics as th
7 | 
8 | 
9 | def generate_isotropic_noise(num_lon: int, num_lat: int, num_samples=1, isotropic=True):
10 |     """Generate noise on the grid.
11 | 
12 |     When isotropic is True it samples the equivalent of white noise on a sphere and projects it
13 |     onto a grid using the Driscoll and Healy (1994) algorithm. The power spectrum is normalized
14 |     to have variance 1. We need to assume lons = 2 * lats or lons = 2 * (lats - 1). If isotropic
15 |     is False, it samples flat normal random noise.
16 | 
17 |     Args:
18 |         num_lon (int): number of longitudes in the grid.
19 |         num_lat (int): number of latitudes in the grid.
20 |         num_samples (int): number of independent samples. Defaults to 1.
21 |         isotropic (bool): if true generates isotropic noise, else flat noise. Defaults to True.
22 | 
23 |     Returns:
24 |         grid: Numpy array with shape (num_lon, num_lat, num_samples).
25 |     """
26 |     if isotropic:
27 |         if 2 * num_lat == num_lon:
28 |             extend = False
29 |         elif 2 * (num_lat - 1) == num_lon:
30 |             extend = True
31 |         else:
32 |             raise ValueError(
33 |                 "Isotropic noise requires grid's shape to be 2N x N or 2N x (N+1): "
34 |                 f"got {num_lon} x {num_lat}. 
If the shape is correct, please specify "
35 |                 "isotropic=False in the constructor.",
36 |             )
37 | 
38 |     if isotropic:
39 |         lmax = num_lat - 1 if extend else num_lat
40 |         mmax = lmax + 1
41 |         coeffs = torch.randn(num_samples, lmax, mmax, dtype=torch.complex64) / np.sqrt(
42 |             (num_lat**2) // 2
43 |         )
44 |         isht = th.InverseRealSHT(
45 |             nlat=num_lat, nlon=num_lon, lmax=lmax, mmax=mmax, grid="equiangular"
46 |         )
47 |         noise = isht(coeffs) * np.sqrt(2 * np.pi)
48 |         noise = einops.rearrange(noise, "b lat lon -> lon lat b").numpy()
49 |     else:
50 |         noise = np.random.randn(num_lon, num_lat, num_samples)
51 |     return noise
52 | 
53 | 
54 | def sample_noise_level(sigma_min=0.02, sigma_max=88, rho=7):
55 |     """Generate random sample of noise level.
56 | 
57 |     Sample a noise level according to the distribution described in the paper.
58 |     Notice that the default values are valid only for training and need to be
59 |     modified for sampling.
60 | 
61 |     Args:
62 |         sigma_min (float, optional): Defaults to 0.02.
63 |         sigma_max (int, optional): Defaults to 88.
64 |         rho (int, optional): Defaults to 7.
65 | 
66 |     Returns:
67 |         noise_level: single sample of noise level.
68 |     """
69 |     u = np.random.random()
70 |     noise_level = (
71 |         sigma_max ** (1 / rho) + u * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
72 |     ) ** rho
73 |     return noise_level
74 | 
75 | 
76 | class Preconditioner(torch.nn.Module):
77 |     """Collection of preconditioning functions.
78 | 
79 |     These functions are described in Karras (2022), table 1.
80 |     """
81 | 
82 |     def __init__(self, sigma_data: float = 1):
83 |         """Initialize the preconditioning functions.
84 | 
85 |         Args:
86 |             sigma_data (float): Karras suggests 0.5, GenCast 1. Defaults to 1.
87 |         """
88 |         super().__init__()
89 |         self.sigma_data = sigma_data
90 | 
91 |     def c_skip(self, sigma):
92 |         """Scaling factor for skip connection."""
93 |         return self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
94 | 
95 |     def c_out(self, sigma):
96 |         """Scaling factor for output."""
97 |         return sigma * self.sigma_data / torch.sqrt(sigma**2 + self.sigma_data**2)
98 | 
99 |     def c_in(self, sigma):
100 |         """Scaling factor for input."""
101 |         return 1 / torch.sqrt(sigma**2 + self.sigma_data**2)
102 | 
103 |     def c_noise(self, sigma):
104 |         """Scaling factor for noise level."""
105 |         return 1 / 4 * torch.log(sigma)
--------------------------------------------------------------------------------
/graph_weather/models/gencast/utils/statistics.py:
--------------------------------------------------------------------------------
1 | """Statistics computation utils."""
2 | 
3 | import apache_beam  # noqa: F401
4 | import numpy as np
5 | import weatherbench2  # noqa: F401
6 | import xarray as xr
7 | 
8 | 
9 | def compute_statistics(dataset, vars, num_samples=100, single=False):
10 |     """Compute statistics for single timestep.
11 | 
12 |     Args:
13 |         dataset: xarray dataset.
14 |         vars: list of features.
15 |         num_samples (int, optional): number of random time samples used. Defaults to 100.
16 |         single (bool, optional): True if the features are single-level (no pressure levels). Defaults to False.
17 | 
18 |     Returns:
19 |         means: dict with the means.
20 |         stds: dict with the stds.
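
    Example (illustrative sketch; assumes an ERA5-like xarray dataset with a
    "2m_temperature" variable):

        means, stds = compute_statistics(
            dataset, ["2m_temperature"], num_samples=10, single=True
        )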
21 | """ 22 | means = {} 23 | stds = {} 24 | for var in vars: 25 | print(f"Computing statistics for {var}") 26 | random_indexes = np.random.randint(0, len(dataset.time), num_samples) 27 | samples = data.isel(time=random_indexes)[var].values 28 | samples = np.nan_to_num(samples) 29 | axis_tuple = (0, 1, 2) if single else (0, 2, 3) 30 | means[var] = samples.mean(axis=axis_tuple) 31 | stds[var] = samples.std(axis=axis_tuple) 32 | return means, stds 33 | 34 | 35 | def compute_statistics_diff(dataset, vars, num_samples=100, single=False, timestep=2): 36 | """Compute statistics for difference of two timesteps. 37 | 38 | Args: 39 | dataset: xarray dataset. 40 | vars: list of features. 41 | num_samples (int, optional): _description_. Defaults to 100. 42 | single (bool, optional): if the features have multiple pressure levels. Defaults to False. 43 | timestep (int, optional): number of steps to consider between start and end. Defaults to 2. 44 | 45 | Returns: 46 | means: dict with the means. 47 | stds: dict with the stds. 48 | """ 49 | means = {} 50 | stds = {} 51 | for var in vars: 52 | print(f"Computing statistics for {var}") 53 | random_indexes = np.random.randint(0, len(dataset.time), num_samples) 54 | samples_start = data.isel(time=random_indexes)[var].values 55 | samples_start = np.nan_to_num(samples_start) 56 | samples_end = data.isel(time=random_indexes + timestep)[var].values 57 | samples_end = np.nan_to_num(samples_end) 58 | axis_tuple = (0, 1, 2) if single else (0, 2, 3) 59 | means[var] = (samples_end - samples_start).mean(axis=axis_tuple) 60 | stds[var] = (samples_end - samples_start).std(axis=axis_tuple) 61 | return means, stds 62 | 63 | 64 | atmospheric_features = [ 65 | "geopotential", 66 | "specific_humidity", 67 | "temperature", 68 | "u_component_of_wind", 69 | "v_component_of_wind", 70 | "vertical_velocity", 71 | ] 72 | 73 | single_features = [ 74 | "2m_temperature", 75 | "10m_u_component_of_wind", 76 | "10m_v_component_of_wind", 77 | "mean_sea_level_pressure", 78 | # "sea_surface_temperature", 79 | "total_precipitation_12hr", 80 | ] 81 | 82 | static_features = ["geopotential_at_surface", "land_sea_mask"] 83 | 84 | obs_path = "gs://weatherbench2/datasets/era5/1959-2022-6h-64x32_equiangular_conservative.zarr" 85 | # obs_path = 'gs://weatherbench2/datasets/era5/1959-2022-6h-1440x721.zarr' 86 | # obs_path = 'gs://weatherbench2/datasets/era5/1959-2022-6h-512x256_equiangular_conservative.zarr' 87 | data = xr.open_zarr(obs_path) 88 | num_samples = 100 89 | means, stds = compute_statistics_diff(data, single_features, num_samples=num_samples, single=True) 90 | print("Means: ", means) 91 | print("Stds: ", stds) 92 | -------------------------------------------------------------------------------- /graph_weather/models/gencast/weighted_mse_loss.py: -------------------------------------------------------------------------------- 1 | """The weighted loss function for GenCast training.""" 2 | 3 | from typing import Optional 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | class WeightedMSELoss(torch.nn.Module): 10 | """Module WeightedMSELoss. 11 | 12 | This module implement the loss described in GenCast's paper. 13 | """ 14 | 15 | def __init__( 16 | self, 17 | grid_lat: Optional[torch.Tensor] = None, 18 | pressure_levels: Optional[torch.Tensor] = None, 19 | num_atmospheric_features: Optional[int] = None, 20 | single_features_weights: Optional[torch.Tensor] = None, 21 | ): 22 | """Initialize the WeightedMSELoss Module. 
23 | 
24 |         More details about the feature weights are reported in GraphCast's paper. In short, if
25 |         the single features are "2m_temperature", "10m_u_component_of_wind",
26 |         "10m_v_component_of_wind", "mean_sea_level_pressure" and "total_precipitation_12hr",
27 |         then it is suggested to set the corresponding weights to 1, 0.1, 0.1, 0.1 and 0.1.
28 | 
29 |         Args:
30 |             grid_lat (torch.Tensor, optional): 1D tensor containing all the latitudes.
31 |             pressure_levels (torch.Tensor, optional): 1D tensor containing all the pressure levels
32 |                 per variable.
33 |             num_atmospheric_features (int, optional): number of atmospheric features.
34 |             single_features_weights (torch.Tensor, optional): 1D tensor containing single features
35 |                 weights.
36 |         """
37 |         super().__init__()
38 | 
39 |         area_weights = None
40 |         features_weights = None
41 | 
42 |         if grid_lat is not None:
43 |             area_weights = torch.abs(torch.cos(grid_lat * np.pi / 180.0))
44 |             area_weights = area_weights / torch.mean(area_weights)
45 |         if (
46 |             pressure_levels is not None
47 |             and num_atmospheric_features is not None
48 |             and single_features_weights is not None
49 |         ):
50 |             pressure_weights = pressure_levels / torch.sum(pressure_levels)
51 |             features_weights = torch.cat(
52 |                 (pressure_weights.repeat(num_atmospheric_features), single_features_weights), dim=-1
53 |             )
54 |         elif (
55 |             pressure_levels is not None
56 |             or num_atmospheric_features is not None
57 |             or single_features_weights is not None
58 |         ):
59 |             raise ValueError(
60 |                 "To use feature weights, please provide all three: pressure_levels, "
61 |                 "num_atmospheric_features and single_features_weights."
62 |             )
63 | 
64 |         self.sigma_data = 1  # assuming normalized data!
65 | 
66 |         self.register_buffer("area_weights", area_weights, persistent=False)
67 |         self.register_buffer("features_weights", features_weights, persistent=False)
68 | 
69 |     def _lambda_sigma(self, noise_level):
70 |         noise_weights = (noise_level**2 + self.sigma_data**2) / (noise_level * self.sigma_data) ** 2
71 |         return noise_weights  # [batch, 1]
72 | 
73 |     def forward(
74 |         self,
75 |         pred: torch.Tensor,
76 |         noise_level: torch.Tensor,
77 |         target: torch.Tensor,
78 |     ) -> torch.Tensor:
79 |         """Compute the loss.
80 | 
81 |         Args:
82 |             pred (torch.Tensor): prediction of the model [batch, lon, lat, var].
83 |             noise_level (torch.Tensor): noise levels fed to the model for the corresponding
84 |                 predictions [batch, 1].
85 |             target (torch.Tensor): target tensor [batch, lon, lat, var].
86 | 
87 |         Returns:
88 |             torch.Tensor: weighted MSE loss.
89 |         """
90 |         # check shapes
91 |         if not (pred.shape == target.shape):
92 |             raise ValueError(
93 |                 "Predictions and targets must have same shape. The actual shapes "
94 |                 f"are {pred.shape} and {target.shape}."
95 |             )
96 |         if not (len(pred.shape) == 4):
97 |             raise ValueError(
98 |                 "The expected shape for predictions and targets is "
99 |                 f"[batch, lon, lat, var], but got {pred.shape}."
100 |             )
101 |         if not (noise_level.shape == (pred.shape[0], 1)):
102 |             raise ValueError(
103 |                 f"The expected shape for noise levels is [batch, 1], but got {noise_level.shape}."
104 |             )
105 | 
106 |         # compute square residuals
107 |         loss = (pred - target) ** 2  # [batch, lon, lat, var]
108 |         if torch.isnan(loss).any():
109 |             raise ValueError("NaN values encountered in loss calculation.")
110 | 
111 |         # apply area and features weights to residuals
112 |         if self.area_weights is not None:
113 |             if not (len(self.area_weights) == pred.shape[2]):
114 |                 raise ValueError(
115 |                     f"The size of grid_lat at initialization ({len(self.area_weights)}) "
116 |                     f"and the number of latitudes in predictions ({pred.shape[2]}) "
117 |                     "don't match."
118 |                 )
119 |             loss *= self.area_weights[None, None, :, None]
120 | 
121 |         if self.features_weights is not None:
122 |             if not (len(self.features_weights) == pred.shape[-1]):
123 |                 raise ValueError(
124 |                     f"The size of features weights at initialization ({len(self.features_weights)})"
125 |                     f" and the number of features in predictions ({pred.shape[-1]}) "
126 |                     "don't match."
127 |                 )
128 |             loss *= self.features_weights[None, None, None, :]
129 | 
130 |         # compute means across lon, lat, var for each sample in the batch
131 |         loss = loss.flatten(1).mean(-1)  # [batch]
132 | 
133 |         # weight each sample using the corresponding noise level, then return the mean.
134 |         loss *= self._lambda_sigma(noise_level).flatten()
135 |         return loss.mean()
--------------------------------------------------------------------------------
/graph_weather/models/layers/__init__.py:
--------------------------------------------------------------------------------
1 | """Layers for use in models"""
--------------------------------------------------------------------------------
/graph_weather/models/layers/assimilator_decoder.py:
--------------------------------------------------------------------------------
1 | """Decoders to decode from the Processor graph to the original graph with updated values
2 | 
3 | In the original paper the decoder is described as
4 | 
5 | The Decoder maps back to physical data defined on a latitude/longitude grid. The underlying graph is
6 | again bipartite, this time mapping icosahedron→lat/lon.
7 | The inputs to the Decoder come from the Processor, plus a skip connection back to the original
8 | state of the 78 atmospheric variables on the latitude/longitude grid.
9 | The output of the Decoder is the predicted 6-hour change in the 78 atmospheric variables,
10 | which is then added to the initial state to produce the new state. 
11 | balance between shorter time steps (simpler dynamics to model but more iterations required during
12 | rollout) and longer time steps (fewer iterations required during rollout but modeling
13 | more complex dynamics)
14 | 
15 | """
16 | 
17 | import einops
18 | import h3
19 | import numpy as np
20 | import torch
21 | from torch_geometric.data import Data
22 | 
23 | from graph_weather.models.layers.graph_net_block import MLP, GraphProcessor
24 | 
25 | 
26 | class AssimilatorDecoder(torch.nn.Module):
27 |     """Assimilator graph module"""
28 | 
29 |     def __init__(
30 |         self,
31 |         lat_lons: list,
32 |         resolution: int = 2,
33 |         input_dim: int = 256,
34 |         output_dim: int = 78,
35 |         output_edge_dim: int = 256,
36 |         hidden_dim_processor_node: int = 256,
37 |         hidden_dim_processor_edge: int = 256,
38 |         hidden_layers_processor_node: int = 2,
39 |         hidden_layers_processor_edge: int = 2,
40 |         mlp_norm_type: str = "LayerNorm",
41 |         hidden_dim_decoder: int = 128,
42 |         hidden_layers_decoder: int = 2,
43 |         use_checkpointing: bool = False,
44 |     ):
45 |         """
46 |         Decoder from latent graph to lat/lon graph for assimilation of observations
47 | 
48 |         Args:
49 |             lat_lons: List of (lat,lon) points
50 |             resolution: H3 resolution level
51 |             input_dim: Input node dimension
52 |             output_dim: Output node dimension
53 |             output_edge_dim: Edge dimension
54 |             hidden_dim_processor_node: Hidden dimension of the node processors
55 |             hidden_dim_processor_edge: Hidden dimension of the edge processors
56 |             hidden_layers_processor_node: Number of hidden layers in the node processors
57 |             hidden_layers_processor_edge: Number of hidden layers in the edge processors
58 |             hidden_dim_decoder: Number of hidden dimensions in the decoder
59 |             hidden_layers_decoder: Number of layers in the decoder
60 |             mlp_norm_type: Type of norm for the MLPs
61 |                 one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
62 |             use_checkpointing: Whether to use gradient checkpointing to reduce model size
63 |         """
64 |         super().__init__()
65 |         self.use_checkpointing = use_checkpointing
66 |         self.num_latlons = len(lat_lons)
67 |         self.base_h3_grid = sorted(list(h3.uncompact(h3.get_res0_indexes(), resolution)))
68 |         self.num_h3 = len(self.base_h3_grid)
69 |         self.h3_grid = [h3.geo_to_h3(lat, lon, resolution) for lat, lon in lat_lons]
70 |         self.h3_to_index = {}
71 |         h_index = len(self.base_h3_grid)
72 |         for h in self.base_h3_grid:
73 |             if h not in self.h3_to_index:
74 |                 h_index -= 1
75 |                 self.h3_to_index[h] = h_index
76 |         self.h3_mapping = {}
77 |         for h, value in enumerate(self.h3_grid):
78 |             self.h3_mapping[h + self.num_h3] = value
79 | 
80 |         # Build the default graph
81 |         # Extra starting ones for appending to inputs, could 'learn' good starting points
82 |         self.latlon_nodes = torch.zeros((len(lat_lons), input_dim), dtype=torch.float)
83 |         # Get connections between lat nodes and h3 nodes. TODO: The paper makes it seem like the 3
84 |         # closest iso points map to the lat/lon point. Do a k_ring of 1 around the current h3 cell,
85 |         # and calculate the distance between all those points and the lat/lon one, choosing the
86 |         # nearest N (3). To keep it a bit simpler, just include them all with their distances
87 |         edge_sources = []
88 |         edge_targets = []
89 |         self.h3_to_lat_distances = []
90 |         for node_index, h_node in enumerate(self.h3_grid):
91 |             # Get h3 index
92 |             h_points = h3.k_ring(self.h3_mapping[node_index + self.num_h3], 1)
93 |             for h in h_points:
94 |                 distance = h3.point_dist(lat_lons[node_index], h3.h3_to_geo(h), unit="rads")
95 |                 self.h3_to_lat_distances.append([np.sin(distance), np.cos(distance)])
96 |                 edge_sources.append(self.h3_to_index[h])
97 |                 edge_targets.append(node_index + self.num_h3)
98 |         edge_index = torch.tensor([edge_sources, edge_targets], dtype=torch.long)
99 |         self.h3_to_lat_distances = torch.tensor(self.h3_to_lat_distances, dtype=torch.float)
100 | 
101 |         # Use a normal graph as it's a bit simpler
102 |         self.graph = Data(edge_index=edge_index, edge_attr=self.h3_to_lat_distances)
103 | 
104 |         self.edge_encoder = MLP(
105 |             2, output_edge_dim, hidden_dim_processor_edge, 2, mlp_norm_type, self.use_checkpointing
106 |         )
107 |         self.graph_processor = GraphProcessor(
108 |             mp_iterations=1,
109 |             in_dim_node=input_dim,
110 |             in_dim_edge=output_edge_dim,
111 |             hidden_dim_node=hidden_dim_processor_node,
112 |             hidden_dim_edge=hidden_dim_processor_edge,
113 |             hidden_layers_node=hidden_layers_processor_node,
114 |             hidden_layers_edge=hidden_layers_processor_edge,
115 |             norm_type=mlp_norm_type,
116 |         )
117 |         self.node_decoder = MLP(
118 |             input_dim,
119 |             output_dim,
120 |             hidden_dim_decoder,
121 |             hidden_layers_decoder,
122 |             None,
123 |             self.use_checkpointing,
124 |         )
125 | 
126 |     def forward(self, processor_features: torch.Tensor, batch_size: int) -> torch.Tensor:
127 |         """
128 |         Adds features to the encoding graph
129 | 
130 |         Args:
131 |             processor_features: Processed features in shape [B*Nodes, Features]
132 |             batch_size: Batch size
133 | 
134 |         Returns:
135 |             Updated features for model
136 |         """
137 |         self.graph = self.graph.to(processor_features.device)
138 |         edge_attr = self.edge_encoder(self.graph.edge_attr)  # Update attributes based on distance
139 |         edge_attr = einops.repeat(edge_attr, "e f -> (repeat e) f", repeat=batch_size)
140 | 
141 |         edge_index = torch.cat(
142 |             [
143 |                 self.graph.edge_index + i * torch.max(self.graph.edge_index) + i
144 |                 for i in range(batch_size)
145 |             ],
146 |             dim=1,
147 |         )
148 | 
149 |         # Re-add nodes to match the graph node number
150 |         self.latlon_nodes = self.latlon_nodes.to(processor_features.device)
151 |         features = einops.rearrange(processor_features, "(b n) f -> b n f", b=batch_size)
152 |         features = torch.cat(
153 |             [features, einops.repeat(self.latlon_nodes, "n f -> b n f", b=batch_size)], dim=1
154 |         )
155 |         features = einops.rearrange(features, "b n f -> (b n) f")
156 | 
157 |         out, _ = self.graph_processor(features, edge_index, edge_attr)  # Message Passing
158 |         # Remove the h3 nodes now, only want the latlon ones
159 |         out = self.node_decoder(out)  # Decode to 78 from 256
160 |         out = einops.rearrange(out, "(b n) f -> b n f", b=batch_size)
161 |         _, out = torch.split(out, [self.num_h3, self.num_latlons], dim=1)
162 |         return out
163 | 
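The H3 wiring in `__init__` above can be reproduced in isolation with a few library calls (h3 v3 API, the same one this file uses; the sample point is arbitrary):

    import h3

    lat, lon, resolution = 51.5, -0.1, 2  # arbitrary example point
    cell = h3.geo_to_h3(lat, lon, resolution)  # H3 cell containing the point
    neighbours = h3.k_ring(cell, 1)  # the cell plus its 1-ring, 7 cells in total
    for h in neighbours:
        d = h3.point_dist((lat, lon), h3.h3_to_geo(h), unit="rads")
        # each (sin(d), cos(d)) pair becomes one edge attribute, as in __init__ above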
-------------------------------------------------------------------------------- /graph_weather/models/layers/constraint_layer.py: --------------------------------------------------------------------------------
1 | """Module for physical constraint layers used in graph weather models.
2 | 
3 | This module implements several constraints on a network’s intermediate outputs,
4 | ensuring physical consistency with an input at a lower resolution.
5 | 
6 | """
7 | 
8 | import torch
9 | import torch.nn as nn
10 | 
11 | 
12 | class PhysicalConstraintLayer(nn.Module):
13 |     """
14 | 
15 |     This module implements several constraint types on the network’s intermediate outputs ỹ,
16 |     given the corresponding low-resolution input x. The following equations are implemented
17 |     (with all operations acting per patch – here, a patch is the full grid of H×W pixels):
18 | 
19 |     Additive constraint:
20 |         y = ỹ + x - avg(ỹ)
21 | 
22 |     Multiplicative constraint:
23 |         y = ỹ * ( x / avg(ỹ) )
24 | 
25 |     Softmax constraint:
26 |         y = exp(ỹ) * ( x / sum(exp(ỹ)) )
27 | 
28 |     We assume that both the intermediate outputs and the low-resolution reference are 4D
29 |     tensors in grid format, with shape [B, C, H, W], where n = H*W is the number of pixels
30 |     (or nodes) in a patch.
31 |     """
32 | 
33 |     def __init__(
34 |         self, model, grid_shape, upsampling_factor, constraint_type="none", exp_factor=1.0
35 |     ):
36 |         """Initialize the PhysicalConstraintLayer.
37 | 
38 |         Args:
39 |             model (nn.Module): The model containing the helper methods
40 |                 'graph_to_grid' and 'grid_to_graph'.
41 |             grid_shape (tuple): Expected spatial dimensions (H, W) of the
42 |                 high-resolution grid.
43 |             upsampling_factor (int): Factor by which the low-resolution grid is upsampled.
44 |             constraint_type (str, optional): The constraint to apply. Options are
45 |                 'additive', 'multiplicative', or 'softmax'. Defaults to "none".
46 |             exp_factor (float, optional): Exponent factor for the softmax constraint.
47 |                 Defaults to 1.0.
48 |         """
49 |         super().__init__()
50 |         self.model = model
51 |         self.constraint_type = constraint_type
52 |         self.grid_shape = grid_shape
53 |         self.exp_factor = exp_factor
54 |         self.upsampling_factor = upsampling_factor
55 |         self.pool = nn.AvgPool2d(kernel_size=upsampling_factor)
56 | 
57 |     def forward(self, hr_graph, lr_graph):
58 |         """Apply the selected physical constraint.
59 | 
60 |         Processes the high-resolution output and low-resolution input by converting
61 |         between graph and grid formats as needed, and then applying the specified constraint.
62 | 
63 |         Args:
64 |             hr_graph (torch.Tensor): High-resolution model output in either graph (3D)
65 |                 or grid (4D) format.
66 |             lr_graph (torch.Tensor): Low-resolution input in the corresponding
67 |                 graph or grid format.
68 | 
69 |         Returns:
70 |             torch.Tensor: The adjusted output in graph format.
71 |         """
72 |         # Check if inputs are in graph (3D) or grid (4D) formats.
73 |         if hr_graph.dim() == 3:
74 |             # Convert graph format to grid format
75 |             hr_grid = self.model.graph_to_grid(hr_graph)
76 |             lr_grid = self.model.graph_to_grid(lr_graph)
77 |         elif hr_graph.dim() == 4:
78 |             # Already in grid format: [B, C, H, W]
79 |             _, _, H, W = hr_graph.shape
80 |             if (H, W) != self.grid_shape:
81 |                 raise ValueError(f"Expected spatial dimensions {self.grid_shape}, got {(H, W)}")
82 |             hr_grid = hr_graph
83 |             lr_grid = lr_graph
84 |         else:
85 |             raise ValueError("Input tensor must be either 3D (graph) or 4D (grid).")
86 | 
87 |         # Apply constraint based on type in grid format
88 |         if self.constraint_type == "additive":
89 |             result = self.additive_constraint(hr_grid, lr_grid)
90 |         elif self.constraint_type == "multiplicative":
91 |             result = self.multiplicative_constraint(hr_grid, lr_grid)
92 |         elif self.constraint_type == "softmax":
93 |             result = self.softmax_constraint(hr_grid, lr_grid)
94 |         else:
95 |             raise ValueError(f"Unknown constraint type: {self.constraint_type}")
96 | 
97 |         # Convert grid back to graph format
98 |         return self.model.grid_to_graph(result)
99 | 
100 |     def additive_constraint(self, hr, lr):
101 |         """Enforces local conservation using an additive correction:
102 |             y = ỹ + ( x - avg(ỹ) )
103 |         where avg(ỹ) is computed per patch (via an average-pooling layer).
104 | 
105 |         For the additive constraint we follow the paper’s formulation using a Kronecker
106 |         product to expand the discrepancy between the low-resolution field and the
107 |         average of the high-resolution output.
108 | 
109 |         hr: high-resolution tensor [B, C, H_hr, W_hr]
110 |         lr: low-resolution tensor [B, C, h_lr, w_lr]
111 |         (with H_hr = upsampling_factor * h_lr and W_hr = upsampling_factor * w_lr)
112 |         """
113 |         # Convert grids to graph format using model's mapping
114 |         hr_graph = self.model.grid_to_graph(hr)
115 |         lr_graph = self.model.grid_to_graph(lr)
116 | 
117 |         # Apply constraint logic
118 |         # Compute average over NODES
119 |         avg_hr = hr_graph.mean(dim=1, keepdim=True)
120 |         diff = lr_graph - avg_hr
121 | 
122 |         # Expand difference using spatial mapping
123 |         diff_expanded = diff.repeat(1, self.upsampling_factor**2, 1)
124 | 
125 |         # Apply correction and convert back to GRID format
126 |         adjusted_graph = hr_graph + diff_expanded
127 |         return self.model.graph_to_grid(adjusted_graph)
128 | 
129 |     def multiplicative_constraint(self, hr, lr):
130 |         """Enforce conservation using a multiplicative correction in graph space.
131 | 
132 |         The correction is applied by scaling the high-resolution output by a ratio computed
133 |         from the low-resolution input and the average of the high-resolution output.
134 | 
135 |         Args:
136 |             hr (torch.Tensor): High-resolution tensor in grid format [B, C, H_hr, W_hr].
137 |             lr (torch.Tensor): Low-resolution tensor in grid format [B, C, h_lr, w_lr].
138 | 
139 |         Returns:
140 |             torch.Tensor: Adjusted high-resolution tensor in grid format.
141 |         """
142 |         # Convert grids to graph format using model's mapping
143 |         hr_graph = self.model.grid_to_graph(hr)
144 |         lr_graph = self.model.grid_to_graph(lr)
145 | 
146 |         # Apply constraint logic
147 |         # Compute average over NODES
148 |         avg_hr = hr_graph.mean(dim=1, keepdim=True)
149 |         lr_patch_avg = lr_graph.mean(dim=1, keepdim=True)
150 | 
151 |         # Compute ratio and expand to match HR graph structure
152 |         ratio = lr_patch_avg / (avg_hr + 1e-8)
153 | 
154 |         # Apply multiplicative correction and convert back to GRID format
155 |         adjusted_graph = hr_graph * ratio
156 |         return self.model.graph_to_grid(adjusted_graph)
157 | 
158 |     def softmax_constraint(self, y, lr):
159 |         """Apply a softmax-based constraint correction.
160 | 
161 |         The softmax correction scales the exponentiated high-resolution output so that the
162 |         sum over spatial blocks matches the low-resolution reference.
163 | 
164 |         Args:
165 |             y (torch.Tensor): High-resolution tensor in grid format [B, C, H, W].
166 |             lr (torch.Tensor): Low-resolution tensor in grid format [B, C, h, w].
167 | 
168 |         Returns:
169 |             torch.Tensor: Adjusted high-resolution tensor in grid format after applying
170 |                 the softmax constraint.
171 |         """
172 |         # Apply the exponential function
173 |         y = torch.exp(self.exp_factor * y)
174 | 
175 |         # Pool over spatial blocks
176 |         kernel_area = self.upsampling_factor**2
177 |         sum_y = self.pool(y) * kernel_area
178 | 
179 |         # Ensure that lr * (1/sum_y) is contiguous
180 |         ratio = (lr * (1 / sum_y)).contiguous()
181 | 
182 |         # Use device of lr for kron expansion:
183 |         device = lr.device
184 |         expansion = torch.ones((self.upsampling_factor, self.upsampling_factor), device=device)
185 | 
186 |         # Expand the low-resolution ratio and correct the y values so that the block sum matches lr.
187 |         out = y * torch.kron(ratio, expansion)
188 |         return out
189 | 
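To see that the softmax constraint really makes each block of the output sum to the low-resolution reference, here is a quick self-contained numerical check (a sketch with `exp_factor=1` that bypasses the class, since the layer itself also needs a `model` exposing `graph_to_grid`/`grid_to_graph`):

    import torch
    import torch.nn.functional as F

    k = 2  # upsampling factor
    y = torch.randn(1, 1, 4, 4)  # intermediate output ỹ
    lr = torch.rand(1, 1, 2, 2)  # low-resolution reference x

    exp_y = torch.exp(y)
    block_sum = F.avg_pool2d(exp_y, k) * k**2  # sum of exp(ỹ) over each k x k block
    out = exp_y * torch.kron(lr / block_sum, torch.ones(k, k))

    # every k x k block of `out` now sums to the corresponding entry of lr
    assert torch.allclose(F.avg_pool2d(out, k) * k**2, lr, atol=1e-5)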
-------------------------------------------------------------------------------- /graph_weather/models/layers/decoder.py: --------------------------------------------------------------------------------
1 | """Decoders to decode from the Processor graph to the original graph with updated values
2 | 
3 | In the original paper the decoder is described as
4 | 
5 | The Decoder maps back to physical data defined on a latitude/longitude grid. The underlying graph is
6 | again bipartite, this time mapping icosahedron→lat/lon.
7 | The inputs to the Decoder come from the Processor, plus a skip connection back to the original
8 | state of the 78 atmospheric variables on the latitude/longitude grid.
9 | The output of the Decoder is the predicted 6-hour change in the 78 atmospheric variables,
10 | which is then added to the initial state to produce the new state. We found 6 hours to be a good
11 | balance between shorter time steps (simpler dynamics to model but more iterations required during
12 | rollout) and longer time steps (fewer iterations required during rollout but modeling
13 | more complex dynamics)
14 | 
15 | """
16 | 
17 | import torch
18 | 
19 | from graph_weather.models.layers.assimilator_decoder import AssimilatorDecoder
20 | 
21 | 
22 | class Decoder(AssimilatorDecoder):
23 |     """Decoder graph module"""
24 | 
25 |     def __init__(
26 |         self,
27 |         lat_lons,
28 |         resolution: int = 2,
29 |         input_dim: int = 256,
30 |         output_dim: int = 78,
31 |         output_edge_dim: int = 256,
32 |         hidden_dim_processor_node: int = 256,
33 |         hidden_dim_processor_edge: int = 256,
34 |         hidden_layers_processor_node: int = 2,
35 |         hidden_layers_processor_edge: int = 2,
36 |         mlp_norm_type: str = "LayerNorm",
37 |         hidden_dim_decoder: int = 128,
38 |         hidden_layers_decoder: int = 2,
39 |         use_checkpointing: bool = False,
40 |     ):
41 |         """
42 |         Decoder from latent graph to lat/lon graph
43 | 
44 |         Args:
45 |             lat_lons: List of (lat,lon) points
46 |             resolution: H3 resolution level
47 |             input_dim: Input node dimension
48 |             output_dim: Output node dimension
49 |             output_edge_dim: Edge dimension
50 |             hidden_dim_processor_node: Hidden dimension of the node processors
51 |             hidden_dim_processor_edge: Hidden dimension of the edge processors
52 |             hidden_layers_processor_node: Number of hidden layers in the node processors
53 |             hidden_layers_processor_edge: Number of hidden layers in the edge processors
54 |             hidden_dim_decoder: Number of hidden dimensions in the decoder
55 |             hidden_layers_decoder: Number of layers in the decoder
56 |             mlp_norm_type: Type of norm for the MLPs
57 |                 one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
58 |             use_checkpointing: Whether to use gradient checkpointing or not
59 |         """
60 |         super().__init__(
61 |             lat_lons,
62 |             resolution,
63 |             input_dim,
64 |             output_dim,
65 |             output_edge_dim,
66 |             hidden_dim_processor_node,
67 |             hidden_dim_processor_edge,
68 |             hidden_layers_processor_node,
69 |             hidden_layers_processor_edge,
70 |             mlp_norm_type,
71 |             hidden_dim_decoder,
72 |             hidden_layers_decoder,
73 |             use_checkpointing,
74 |         )
75 | 
76 |     def forward(
77 |         self, processor_features: torch.Tensor, start_features: torch.Tensor
78 |     ) -> torch.Tensor:
79 |         """
80 |         Adds features to the encoding graph
81 | 
82 |         Args:
83 |             processor_features: Processed features in shape [B*Nodes, Features]
84 |             start_features: Original input features to the encoder, with shape [B, Nodes, Features]
85 | 
86 |         Returns:
87 |             Updated features for model
""" 89 | out = super().forward(processor_features, start_features.shape[0]) 90 | out = out + start_features # residual connection 91 | return out 92 | -------------------------------------------------------------------------------- /graph_weather/models/layers/grid_to_points.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/graph_weather/abf13a36e95ba6b699027e1d4c6760760aae280f/graph_weather/models/layers/grid_to_points.py -------------------------------------------------------------------------------- /graph_weather/models/layers/points_to_grid.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is designed to abstract away the lat/lon points and as a layer that can be put in front of models that more expecta regular lat/lon grid as input. 3 | """ 4 | 5 | -------------------------------------------------------------------------------- /graph_weather/models/layers/processor.py: -------------------------------------------------------------------------------- 1 | """Processor for the latent graph 2 | 3 | In the original paper the processor is described as 4 | 5 | The Processor iteratively processes the 256-channel latent feature data on the icosahedron grid 6 | using 9 rounds of message-passing GNNs. During each round, a node exchanges information with itself 7 | and its immediate neighbors. There are residual connections between each round of processing. 8 | 9 | """ 10 | 11 | import torch 12 | 13 | from graph_weather.models.layers.graph_net_block import GraphProcessor 14 | 15 | 16 | class Processor(torch.nn.Module): 17 | """Processor for latent graphD""" 18 | 19 | def __init__( 20 | self, 21 | input_dim: int = 256, 22 | edge_dim: int = 256, 23 | num_blocks: int = 9, 24 | hidden_dim_processor_node: int = 256, 25 | hidden_dim_processor_edge: int = 256, 26 | hidden_layers_processor_node: int = 2, 27 | hidden_layers_processor_edge: int = 2, 28 | mlp_norm_type: str = "LayerNorm", 29 | ): 30 | """ 31 | Latent graph processor 32 | 33 | Args: 34 | input_dim: Input dimension for the node 35 | edge_dim: Edge input dimension 36 | num_blocks: Number of message passing blocks 37 | hidden_dim_processor_node: Hidden dimension of the node processors 38 | hidden_dim_processor_edge: Hidden dimension of the edge processors 39 | hidden_layers_processor_node: Number of hidden layers in the node processors 40 | hidden_layers_processor_edge: Number of hidden layers in the edge processors 41 | mlp_norm_type: Type of norm for the MLPs 42 | one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None 43 | """ 44 | super().__init__() 45 | # Build the default graph 46 | # Take features from encoder and put into processor graph 47 | self.input_dim = input_dim 48 | 49 | self.graph_processor = GraphProcessor( 50 | num_blocks, 51 | input_dim, 52 | edge_dim, 53 | hidden_dim_processor_node, 54 | hidden_dim_processor_edge, 55 | hidden_layers_processor_node, 56 | hidden_layers_processor_edge, 57 | mlp_norm_type, 58 | ) 59 | 60 | def forward(self, x: torch.Tensor, edge_index, edge_attr) -> torch.Tensor: 61 | """ 62 | Adds features to the encoding graph 63 | 64 | Args: 65 | x: Torch tensor containing node features 66 | edge_index: Connectivity of graph, of shape [2, Num edges] in COO format 67 | edge_attr: Edge attribues in [Num edges, Features] shape 68 | 69 | Returns: 70 | torch Tensor containing the values of the nodes of the graph 71 | """ 72 | out, _ = 
-------------------------------------------------------------------------------- /graph_weather/models/losses.py: --------------------------------------------------------------------------------
1 | """Weather loss functions"""
2 | 
3 | import numpy as np
4 | import torch
5 | 
6 | 
7 | class NormalizedMSELoss(torch.nn.Module):
8 |     """Loss function described in the paper"""
9 | 
10 |     def __init__(
11 |         self, feature_variance: list, lat_lons: list, device="cpu", normalize: bool = False
12 |     ):
13 |         """
14 |         Normalized MSE Loss as described in the paper
15 | 
16 |         This re-scales each physical variable such that it has unit-variance in the 3 hour temporal
17 |         difference. E.g. for temperature data, divide every one at all pressure levels by
18 |         sigma_t_3hr, where sigma^2_T,3hr is the variance of the 3 hour change in temperature,
19 |         averaged across space (lat/lon + pressure levels) and time (100 random temporal frames).
20 | 
21 |         Additionally weights by the cos(lat) of the feature
22 | 
23 |         Latitudes are converted to radians before the cosine is taken
24 | 
25 |         Args:
26 |             feature_variance: Variance for each of the physical features
27 |             lat_lons: List of lat/lon pairs, used to generate weighting
28 |             device: device to run on, whether CPU or GPU
29 |             normalize: whether to normalize the squared error by the feature variance
30 |         """
31 |         # TODO Rescale by nominal static air density at each pressure level, could be 1/pressure level or something similar
32 |         super().__init__()
33 |         self.feature_variance = torch.tensor(feature_variance)
34 |         assert not torch.isnan(self.feature_variance).any()
35 |         # Compute unique latitudes from the provided lat/lon pairs.
36 |         unique_lats = sorted(set(lat for lat, _ in lat_lons))
37 |         # Use the cosine of each unique latitude (converted to radians) as its weight.
38 |         self.weights = torch.tensor(
39 |             [np.cos(lat * np.pi / 180.0) for lat in unique_lats], dtype=torch.float
40 |         )
41 |         self.normalize = normalize
42 |         assert not torch.isnan(self.weights).any()
43 | 
44 |     def forward(self, pred: torch.Tensor, target: torch.Tensor):
45 |         """
46 |         Calculate the loss
47 | 
48 |         Rescales both predictions and target, so assumes neither are already normalized
49 |         Additionally weights by the cos(lat) of the set of features
50 | 
51 |         Args:
52 |             pred: Prediction tensor
53 |             target: Target tensor
54 | 
55 |         Returns:
56 |             MSE loss on the variance-normalized values
57 |         """
58 |         self.feature_variance = self.feature_variance.to(pred.device)
59 |         self.weights = self.weights.to(pred.device)
60 | 
61 |         out = (pred - target) ** 2
62 |         if self.normalize:
63 |             out = out / self.feature_variance
64 | 
65 |         assert not torch.isnan(out).any()
66 |         # Mean of the physical variables
67 |         out = out.mean(-1)
68 | 
69 |         # Flatten all dimensions except the batch dimension.
70 |         B, *dims = out.shape
71 |         num_nodes = np.prod(
72 |             dims
73 |         )  # Total number of grid nodes (e.g., if grid is HxW, then num_nodes = H*W)
74 |         out = out.view(B, num_nodes)
75 | 
76 |         # Determine the number of unique latitude weights and infer the number of grid columns.
77 |         num_unique = self.weights.shape[0]  # e.g., number of unique latitudes (rows)
78 |         num_lon = num_nodes // num_unique  # e.g. if 2592 nodes and 36 unique lats, then num_lon=72
79 | 
80 |         # Tile the unique latitude weights into a full weight grid
81 |         weight_grid = self.weights.unsqueeze(1).expand(num_unique, num_lon).reshape(1, num_nodes)
82 |         weight_grid = weight_grid.expand(B, num_nodes)  # Now weight_grid is [B, num_nodes]
83 | 
84 |         # Multiply the per-node error by the corresponding weight.
85 |         out = out * weight_grid
86 | 
87 |         assert not torch.isnan(out).any()
88 |         return out.mean()
89 | 
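A short usage sketch (hypothetical grid sizes, chosen to match the 2592-node example in the comment above; the weighting assumes nodes are ordered latitude-major, i.e. all longitudes of one latitude row before the next):

    import torch

    from graph_weather.models.losses import NormalizedMSELoss

    # 36 latitudes x 72 longitudes = 2592 nodes, 4 features per node
    lat_lons = [(lat, lon) for lat in range(-90, 90, 5) for lon in range(0, 360, 5)]
    loss_fn = NormalizedMSELoss(
        feature_variance=[1.0, 1.0, 1.0, 1.0], lat_lons=lat_lons, normalize=True
    )

    pred = torch.randn(2, len(lat_lons), 4)  # [batch, nodes, features]
    target = torch.randn(2, len(lat_lons), 4)
    loss = loss_fn(pred, target)  # scalar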
-------------------------------------------------------------------------------- /graph_weather/models/weathermesh/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/openclimatefix/graph_weather/abf13a36e95ba6b699027e1d4c6760760aae280f/graph_weather/models/weathermesh/__init__.py
-------------------------------------------------------------------------------- /graph_weather/models/weathermesh/decoder.py: --------------------------------------------------------------------------------
1 | """
2 | Implementation based off the technical report and this repo: https://github.com/Brayden-Zhang/WeatherMesh
3 | """
4 | 
5 | from dataclasses import asdict, dataclass
6 | 
7 | import dacite
8 | import einops
9 | import torch
10 | import torch.nn as nn
11 | from natten import NeighborhoodAttention3D
12 | 
13 | from graph_weather.models.weathermesh.layers import ConvUpBlock
14 | 
15 | 
16 | @dataclass
17 | class WeatherMeshDecoderConfig:
18 |     latent_dim: int
19 |     output_channels_2d: int
20 |     output_channels_3d: int
21 |     n_conv_blocks: int
22 |     hidden_dim: int
23 |     kernel_size: tuple
24 |     num_heads: int
25 |     num_transformer_layers: int
26 | 
27 |     @staticmethod
28 |     def from_json(json: dict) -> "WeatherMeshDecoderConfig":
29 |         return dacite.from_dict(data_class=WeatherMeshDecoderConfig, data=json)
30 | 
31 |     def to_json(self) -> dict:
32 |         return asdict(self)  # dataclasses.asdict serializes back to a plain dict
33 | 
34 | 
35 | class WeatherMeshDecoder(nn.Module):
36 |     def __init__(
37 |         self,
38 |         latent_dim,
39 |         output_channels_2d,
40 |         output_channels_3d,
41 |         n_conv_blocks=3,
42 |         hidden_dim=256,
43 |         kernel_size: tuple = (5, 7, 7),
44 |         num_heads: int = 8,
45 |         num_transformer_layers: int = 3,
46 |     ):
47 |         super().__init__()
48 | 
49 |         # Transformer layers for initial decoding
50 |         self.transformer_layers = nn.ModuleList(
51 |             [
52 |                 NeighborhoodAttention3D(
53 |                     dim=latent_dim, num_heads=num_heads, kernel_size=kernel_size
54 |                 )
55 |                 for _ in range(num_transformer_layers)
56 |             ]
57 |         )
58 | 
59 |         # Split into pressure levels and surface paths
60 |         self.split = nn.Conv3d(latent_dim, hidden_dim * (2**n_conv_blocks), kernel_size=1)
61 | 
62 |         # Pressure levels (3D) path
63 |         self.pressure_path = nn.ModuleList(
64 |             [
65 |                 ConvUpBlock(
66 |                     hidden_dim * (2 ** (i + 1)),
67 |                     hidden_dim * (2**i) if i > 0 else output_channels_3d,
68 |                     is_3d=True,
69 |                 )
70 |                 for i in reversed(range(n_conv_blocks))
71 |             ]
72 |         )
73 | 
74 |         # Surface (2D) path
75 |         self.surface_path = nn.ModuleList(
76 |             [
77 |                 ConvUpBlock(
78 |                     hidden_dim * (2 ** (i + 1)),
79 |                     hidden_dim * (2**i) if i > 0 else output_channels_2d,
80 |                 )
81 |                 for i in reversed(range(n_conv_blocks))
82 |             ]
83 |         )
84 | 
85 |     def forward(self, latent: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
86 |         # Needs to be (B,D,H,W,C) with Batch, Depth (vertical levels), Height, Width, Channels
87 |         # Apply transformer layers
88 |         for transformer in self.transformer_layers:
89 |             latent = transformer(latent)
90 | 
91 |         latent = einops.rearrange(latent, "B D H W C -> B C D H W")
92 |         # Split features
93 |         features = self.split(latent)
94 |         pressure_features = features[:, :, :-1]
95 |         surface_features = features[:, :, -1:]
96 |         # Decode pressure levels
97 |         for block in self.pressure_path:
98 |             pressure_features = block(pressure_features)
99 |         # Decode surface features
100 |         surface_features = surface_features.squeeze(2)
101 |         for block in self.surface_path:
102 |             surface_features = block(surface_features)
103 | 
104 |         return surface_features, pressure_features
105 | 
-------------------------------------------------------------------------------- /graph_weather/models/weathermesh/encoder.py: --------------------------------------------------------------------------------
1 | """
2 | Implementation based off the technical report and this repo: https://github.com/Brayden-Zhang/WeatherMesh
3 | """
4 | 
5 | from dataclasses import asdict, dataclass
6 | 
7 | import dacite
8 | import einops
9 | import torch
10 | import torch.nn as nn
11 | from natten import NeighborhoodAttention3D
12 | 
13 | from graph_weather.models.weathermesh.layers import ConvDownBlock
14 | 
15 | 
16 | @dataclass
17 | class WeatherMeshEncoderConfig:
18 |     input_channels_2d: int
19 |     input_channels_3d: int
20 |     latent_dim: int
21 |     n_pressure_levels: int
22 |     num_conv_blocks: int
23 |     hidden_dim: int
24 |     kernel_size: tuple
25 |     num_heads: int
26 |     num_transformer_layers: int
27 | 
28 |     @staticmethod
29 |     def from_json(json: dict) -> "WeatherMeshEncoderConfig":
30 |         return dacite.from_dict(data_class=WeatherMeshEncoderConfig, data=json)
31 | 
32 |     def to_json(self) -> dict:
33 |         return asdict(self)
34 | 
35 | 
36 | class WeatherMeshEncoder(nn.Module):
37 |     def __init__(
38 |         self,
39 |         input_channels_2d: int,
40 |         input_channels_3d: int,
41 |         latent_dim: int,
42 |         n_pressure_levels: int,
43 |         num_conv_blocks: int = 3,
44 |         hidden_dim: int = 256,
45 |         kernel_size: tuple = (5, 7, 7),
46 |         num_heads: int = 8,
47 |         num_transformer_layers: int = 3,
48 |     ):
49 |         super().__init__()
50 | 
51 |         # Surface (2D) path
52 |         self.surface_path = nn.ModuleList(
53 |             [
54 |                 ConvDownBlock(
55 |                     input_channels_2d if i == 0 else hidden_dim * (2**i),
56 |                     hidden_dim * (2 ** (i + 1)),
57 |                 )
58 |                 for i in range(num_conv_blocks)
59 |             ]
60 |         )
61 | 
62 |         # Pressure levels (3D) path
63 |         self.pressure_path = nn.ModuleList(
64 |             [
65 |                 ConvDownBlock(
66 |                     input_channels_3d if i == 0 else hidden_dim * (2**i),
67 |                     hidden_dim * (2 ** (i + 1)),
68 |                     stride=(1, 2, 2),  # Want to keep depth the same size
69 |                     is_3d=True,
70 |                 )
71 |                 for i in range(num_conv_blocks)
72 |             ]
73 |         )
74 | 
75 |         # Transformer layers for final encoding
76 |         self.transformer_layers = nn.ModuleList(
77 |             [
78 |                 NeighborhoodAttention3D(
79 |                     dim=latent_dim, kernel_size=kernel_size, num_heads=num_heads
80 |                 )
81 |                 for _ in range(num_transformer_layers)
82 |             ]
83 |         )
84 | 
85 |         # Final projection to latent space
86 |         self.to_latent = nn.Conv3d(hidden_dim * (2**num_conv_blocks), latent_dim, kernel_size=1)
87 | 
88 |     def forward(self, surface: torch.Tensor, pressure: torch.Tensor) -> torch.Tensor:
89 |         # Process surface data
90 |         for block in self.surface_path:
91 |             surface = block(surface)
92 | 
93 |         # Process pressure level data
94 |         for block in self.pressure_path:
95 |             pressure = block(pressure)
96 |         # Combine features
97 |         features = torch.cat(
98 |             [pressure, surface.unsqueeze(2)], dim=2
99 |         )  # B C D H W currently, want it to be B D H W C
100 | 
101 |         # Transform to latent space
102 |         latent = self.to_latent(features)
103 | 
104 |         # Reshape to get the shapes
105 |         latent = einops.rearrange(latent, "B C D H W -> B D H W C")
106 |         # Apply transformer layers
107 |         for transformer in self.transformer_layers:
108 |             latent = transformer(latent)
109 |         return latent
110 | 
-------------------------------------------------------------------------------- /graph_weather/models/weathermesh/layers.py: --------------------------------------------------------------------------------
1 | """
2 | Implementation based off the technical report and this repo: https://github.com/Brayden-Zhang/WeatherMesh
3 | """
4 | 
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | 
9 | 
10 | class ConvDownBlock(nn.Module):
11 |     """
12 |     Downsampling convolutional block with residual connection.
13 |     Can handle both 2D and 3D inputs.
14 |     """
15 | 
16 |     def __init__(
17 |         self,
18 |         in_channels: int,
19 |         out_channels: int,
20 |         is_3d: bool = False,
21 |         kernel_size: int = 3,
22 |         stride: int = 2,
23 |         padding: int = 1,
24 |         groups: int = 1,
25 |         activation: nn.Module = nn.GELU(),
26 |     ):
27 |         super().__init__()
28 | 
29 |         Conv = nn.Conv3d if is_3d else nn.Conv2d
30 |         Norm = nn.BatchNorm3d if is_3d else nn.BatchNorm2d
31 | 
32 |         self.conv1 = Conv(
33 |             in_channels,
34 |             out_channels,
35 |             kernel_size=kernel_size,
36 |             stride=1,
37 |             padding=padding,
38 |             groups=groups,
39 |             bias=False,
40 |         )
41 |         self.bn1 = Norm(out_channels)
42 |         self.activation1 = activation
43 | 
44 |         self.conv2 = Conv(
45 |             out_channels,
46 |             out_channels,
47 |             kernel_size=kernel_size,
48 |             stride=stride,
49 |             padding=padding,
50 |             groups=groups,
51 |             bias=False,
52 |         )
53 |         self.bn2 = Norm(out_channels)
54 |         self.activation2 = activation
55 | 
56 |         # Residual connection with 1x1 conv to match dimensions
57 |         self.downsample = Conv(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
58 |         self.bn_down = Norm(out_channels)
59 | 
60 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
61 |         identity = self.bn_down(self.downsample(x))
62 | 
63 |         out = self.conv1(x)
64 |         out = self.bn1(out)
65 |         out = self.activation1(out)
66 | 
67 |         out = self.conv2(out)
68 |         out = self.bn2(out)
69 | 
70 |         out += identity
71 |         out = self.activation2(out)
72 | 
73 |         return out
74 | 
75 | 
76 | class ConvUpBlock(nn.Module):
77 |     """
78 |     Upsampling convolutional block with residual connection. Same as ConvDownBlock, but reversed.
79 |     """
80 | 
81 |     def __init__(
82 |         self,
83 |         in_channels: int,
84 |         out_channels: int,
85 |         is_3d: bool = False,
86 |         kernel_size: int = 3,
87 |         scale_factor: int = 2,
88 |         padding: int = 1,
89 |         groups: int = 1,
90 |         activation: nn.Module = nn.GELU(),
91 |     ):
92 |         super().__init__()
93 | 
94 |         Conv = nn.Conv3d if is_3d else nn.Conv2d
95 |         Norm = nn.BatchNorm3d if is_3d else nn.BatchNorm2d
96 |         self.is_3d = is_3d
97 |         self.scale_factor = scale_factor
98 | 
99 |         self.conv1 = Conv(
100 |             in_channels,
101 |             in_channels,
102 |             kernel_size=kernel_size,
103 |             stride=1,
104 |             padding=padding,
105 |             groups=groups,
106 |             bias=False,
107 |         )
108 |         self.bn1 = Norm(in_channels)
109 |         self.activation1 = activation
110 | 
111 |         self.conv2 = Conv(
112 |             in_channels,
113 |             out_channels,
114 |             kernel_size=kernel_size,
115 |             stride=1,
116 |             padding=padding,
117 |             groups=groups,
118 |             bias=False,
119 |         )
120 |         self.bn2 = Norm(out_channels)
121 |         self.activation2 = activation
122 | 
123 |         # Residual connection with 1x1 conv to match dimensions
124 |         self.upsample = Conv(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
125 |         self.bn_up = Norm(out_channels)
126 | 
127 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
128 |         # Upsample input
129 |         if self.is_3d:
130 |             x = F.interpolate(
131 |                 x,
132 |                 scale_factor=(1, self.scale_factor, self.scale_factor),
133 |                 mode="trilinear",
134 |                 align_corners=False,
135 |             )
136 |         else:
137 |             x = F.interpolate(
138 |                 x, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
139 |             )
140 | 
141 |         identity = self.bn_up(self.upsample(x))
142 | 
143 |         out = self.conv1(x)
144 |         out = self.bn1(out)
145 |         out = self.activation1(out)
146 | 
147 |         out = self.conv2(out)
148 |         out = self.bn2(out)
149 | 
150 |         out += identity
151 |         out = self.activation2(out)
152 | 
153 |         return out
154 | 
-------------------------------------------------------------------------------- /graph_weather/models/weathermesh/processor.py: --------------------------------------------------------------------------------
1 | """
2 | Implementation based off the technical report and this repo: https://github.com/Brayden-Zhang/WeatherMesh
3 | """
4 | 
5 | from dataclasses import asdict, dataclass
6 | 
7 | import dacite
8 | import torch.nn as nn
9 | from natten import NeighborhoodAttention3D
10 | 
11 | 
12 | @dataclass
13 | class WeatherMeshProcessorConfig:
14 |     latent_dim: int
15 |     n_layers: int
16 |     kernel: tuple
17 |     num_heads: int
18 | 
19 |     @staticmethod
20 |     def from_json(json: dict) -> "WeatherMeshProcessorConfig":
21 |         return dacite.from_dict(data_class=WeatherMeshProcessorConfig, data=json)
22 | 
23 |     def to_json(self) -> dict:
24 |         return asdict(self)
25 | 
26 | 
27 | class WeatherMeshProcessor(nn.Module):
28 |     def __init__(self, latent_dim, n_layers=10, kernel=(5, 7, 7), num_heads=8):
29 |         super().__init__()
30 | 
31 |         self.layers = nn.ModuleList(
32 |             [
33 |                 NeighborhoodAttention3D(
34 |                     dim=latent_dim,
35 |                     num_heads=num_heads,
36 |                     kernel_size=kernel,
37 |                 )
38 |                 for _ in range(n_layers)
39 |             ]
40 |         )
41 | 
42 |     def forward(self, x):
43 |         for layer in self.layers:
44 |             x = layer(x)
45 |         return x
46 | 
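The `*Config` dataclasses in this package all follow the same pattern: `from_json` inflates a plain dict via dacite, and `to_json` dumps the dataclass back to a dict. A small round-trip sketch (field values are illustrative):

    from graph_weather.models.weathermesh.processor import WeatherMeshProcessorConfig

    cfg = WeatherMeshProcessorConfig.from_json(
        {"latent_dim": 256, "n_layers": 10, "kernel": (5, 7, 7), "num_heads": 8}
    )
    assert cfg.to_json()["latent_dim"] == 256  # back to a plain dict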
-------------------------------------------------------------------------------- /graph_weather/models/weathermesh/weathermesh2.py: --------------------------------------------------------------------------------
1 | """
2 | Implementation based off the technical report and this repo: https://github.com/Brayden-Zhang/WeatherMesh
3 | """
4 | 
5 | from dataclasses import asdict, dataclass
6 | from typing import List
7 | 
8 | import dacite
9 | import torch
10 | import torch.nn as nn
11 | 
12 | from graph_weather.models.weathermesh.decoder import WeatherMeshDecoder, WeatherMeshDecoderConfig
13 | from graph_weather.models.weathermesh.encoder import WeatherMeshEncoder, WeatherMeshEncoderConfig
14 | from graph_weather.models.weathermesh.processor import (
15 |     WeatherMeshProcessor,
16 |     WeatherMeshProcessorConfig,
17 | )
18 | 
19 | """
20 | Notes on implementation
21 | 
22 | To make NATTEN work on a sphere, we implement our own circular padding. At the poles, we use the bump attention behavior from NATTEN. For position encoding of tokens, we use Rotary Embeddings.
23 | 
24 | In the default configuration of WeatherMesh 2, the NATTEN window is 5,7,7 in depth, width, height, corresponding to a physical size of 14 degrees longitude and latitude. WeatherMesh 2 contains two processors: a 6hr and a 1hr processor. Each is 10 NATTEN layers deep.
25 | 
26 | Training: distributed shampoo: https://github.com/facebookresearch/optimizers/blob/main/distributed_shampoo/README.md
27 | 
28 | A fork of the pytorch checkpoint library, called matepoint, is used to implement offloading to RAM.
29 | 
30 | TODO: Add bump attention and rotary embeddings for the circular padding and position encoding
31 | 
32 | """
33 | 
34 | 
35 | @dataclass
36 | class WeatherMeshConfig:
37 |     encoder: WeatherMeshEncoderConfig
38 |     processors: List[WeatherMeshProcessorConfig]
39 |     decoder: WeatherMeshDecoderConfig
40 |     timesteps: List[int]
41 |     surface_channels: int
42 |     pressure_channels: int
43 |     pressure_levels: int
44 |     latent_dim: int
45 |     encoder_num_conv_blocks: int
46 |     encoder_num_transformer_layers: int
47 |     encoder_hidden_dim: int
48 |     decoder_num_conv_blocks: int
49 |     decoder_num_transformer_layers: int
50 |     decoder_hidden_dim: int
51 |     processor_num_layers: int
52 |     kernel: tuple
53 |     num_heads: int
54 | 
55 |     @staticmethod
56 |     def from_json(json: dict) -> "WeatherMeshConfig":
57 |         return dacite.from_dict(data_class=WeatherMeshConfig, data=json)
58 | 
59 |     def to_json(self) -> dict:
60 |         return asdict(self)
61 | 
62 | 
63 | @dataclass
64 | class WeatherMeshOutput:
65 |     surface: torch.Tensor
66 |     pressure: torch.Tensor
67 | 
68 | 
69 | class WeatherMesh(nn.Module):
70 |     def __init__(
71 |         self,
72 |         encoder: nn.Module | None,
73 |         processors: List[nn.Module] | None,
74 |         decoder: nn.Module | None,
75 |         timesteps: List[int],
76 |         surface_channels: int | None,
77 |         pressure_channels: int | None,
78 |         pressure_levels: int | None,
79 |         latent_dim: int | None,
80 |         encoder_num_conv_blocks: int | None,
81 |         encoder_num_transformer_layers: int | None,
82 |         encoder_hidden_dim: int | None,
83 |         decoder_num_conv_blocks: int | None,
84 |         decoder_num_transformer_layers: int | None,
85 |         decoder_hidden_dim: int | None,
86 |         processor_num_layers: int | None,
87 |         kernel: tuple | None,
88 |         num_heads: int | None,
89 |     ):
90 |         super().__init__()
91 |         if encoder is not None:
92 |             self.encoder = encoder
93 |         else:
94 |             self.encoder = WeatherMeshEncoder(
95 |                 input_channels_2d=surface_channels,
96 |                 input_channels_3d=pressure_channels,
97 |                 latent_dim=latent_dim,
98 |                 n_pressure_levels=pressure_levels,
99 |                 num_conv_blocks=encoder_num_conv_blocks,
100 |                 hidden_dim=encoder_hidden_dim,
101 |                 kernel_size=kernel,
102 |                 num_heads=num_heads,
103 |                 num_transformer_layers=encoder_num_transformer_layers,
104 |             )
105 |         if processors is not None:
106 |             assert len(processors) == len(
107 |                 timesteps
108 |             ), "Number of processors must match number of timesteps"
109 |             # Wrap in a ModuleList so the processors register as submodules
110 |             self.processors = nn.ModuleList(processors)
111 |         else:
112 |             self.processors = nn.ModuleList(
113 |                 [
114 |                     WeatherMeshProcessor(
115 |                         latent_dim=latent_dim,
116 |                         n_layers=processor_num_layers,
117 |                         kernel=kernel,
118 |                         num_heads=num_heads,
119 |                     )
120 |                     for _ in range(len(timesteps))
121 |                 ]
122 |             )
123 |         if decoder is not None:
124 |             self.decoder = decoder
125 |         else:
126 |             self.decoder = WeatherMeshDecoder(
127 |                 latent_dim=latent_dim,
128 |                 output_channels_2d=surface_channels,
129 |                 output_channels_3d=pressure_channels,
130 |                 n_conv_blocks=decoder_num_conv_blocks,
131 |                 hidden_dim=decoder_hidden_dim,
132 |                 kernel_size=kernel,
133 |                 num_heads=num_heads,
134 |                 num_transformer_layers=decoder_num_transformer_layers,
135 |             )
136 |         self.timesteps = timesteps
137 | 
138 |     def forward(
139 |         self, surface: torch.Tensor, pressure: torch.Tensor, forecast_steps: int
140 |     ) -> WeatherMeshOutput:
141 |         # Encode input
142 |         latent = self.encoder(surface, pressure)
143 | 
144 |         # Apply processors for each forecast step
145 |         for _ in range(forecast_steps):
146 |             for processor in self.processors:
147 |                 latent = processor(latent)
148 | 
149 |         # Decode output
150 |         surface_out, pressure_out = self.decoder(latent)
151 | 
152 |         return WeatherMeshOutput(surface=surface_out, pressure=pressure_out)
153 | 
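A minimal end-to-end sketch, with toy sizes mirroring those used in tests/test_weathermesh.py (passing None for encoder, processors and decoder builds the defaults from the size arguments):

    import torch

    from graph_weather.models.weathermesh.weathermesh2 import WeatherMesh

    model = WeatherMesh(
        encoder=None, processors=None, decoder=None,  # None -> build defaults below
        timesteps=[1, 6],
        surface_channels=8, pressure_channels=4, pressure_levels=5, latent_dim=4,
        encoder_num_conv_blocks=1, encoder_num_transformer_layers=1, encoder_hidden_dim=4,
        decoder_num_conv_blocks=1, decoder_num_transformer_layers=1, decoder_hidden_dim=4,
        processor_num_layers=2, kernel=(3, 5, 5), num_heads=2,
    )

    surface = torch.randn(1, 8, 32, 64)  # [B, C_surface, H, W]
    pressure = torch.randn(1, 4, 5, 32, 64)  # [B, C_pressure, levels, H, W]
    out = model(surface, pressure, forecast_steps=1)
    assert out.surface.shape == (1, 8, 32, 64)
    assert out.pressure.shape == (1, 4, 5, 32, 64)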
-------------------------------------------------------------------------------- /pyproject.toml: --------------------------------------------------------------------------------
1 | [project]
2 | name = "graph_weather"
3 | requires-python = ">=3.11"
4 | version = "1.0.89"
5 | description = "Graph-based AI Weather models"
6 | authors = [
7 |     {name = "Jacob Prince-Bieker", email = "jacob@bieker.tech"},
8 | ]
9 | dependencies = ["torch-harmonics"]
10 | 
11 | [build-system]
12 | build-backend = "hatchling.build"
13 | requires = ["hatchling"]
14 | 
15 | [tool.pixi.project]
16 | channels = ["pyg", "conda-forge", "pytorch"]
17 | platforms = ["linux-64", "osx-arm64"]
18 | 
19 | [tool.pixi.feature.cuda]
20 | channels = ["nvidia", {channel = "pytorch", priority = -1}]
21 | 
22 | [tool.pixi.feature.cuda.system-requirements]
23 | cuda = "12"
24 | 
25 | [tool.pixi.feature.cuda.target.linux-64.dependencies]
26 | cuda-version = "12.4"
27 | pytorch-gpu = {version = "2.4.1", channel = "conda-forge"}
28 | 
29 | #[tool.pixi.feature.cuda.target.linux-64.pypi-dependencies]
30 | #natten = {url = "https://shi-labs.com/natten/wheels/cu124/torch2.4.0/natten-0.17.4%2Btorch240cu124-cp312-cp312-linux_x86_64.whl"}
31 | 
32 | [tool.pixi.feature.mlx]
33 | # MLX is only available on macOS >=13.5 (>14.0 is recommended)
34 | system-requirements = {macos = "13.5"}
35 | 
36 | [tool.pixi.feature.mlx.target.osx-arm64.dependencies]
37 | mlx = {version = "*", channel = "conda-forge"}
38 | pytorch-cpu = {version = "2.4.1", channel = "conda-forge"}
39 | 
40 | #[tool.pixi.feature.mlx.target.osx-arm64.pypi-dependencies]
41 | #natten = "*"
42 | 
43 | [tool.pixi.feature.cpu]
44 | platforms = ["linux-64", "osx-arm64"]
45 | 
46 | [tool.pixi.feature.cpu.dependencies]
47 | pytorch-cpu = {version = "2.4.1", channel = "conda-forge"}
48 | 
49 | #[tool.pixi.feature.cpu.pypi-dependencies]
50 | #natten = "*"
51 | 
52 | [tool.pixi.dependencies]
53 | python = "3.12.*"
54 | torchvision = {version = "*", channel = "conda-forge"}
55 | pip = "*"
56 | pytest = "*"
57 | pre-commit = "*"
58 | ruff = "*"
59 | xarray = "*"
60 | pandas = "*"
61 | h3-py = "3.*"
62 | numcodecs = "*"
63 | scipy = "*"
64 | zarr = ">=3.0.0"
65 | pyg = "*"
66 | pytorch-cluster = "*"
67 | pytorch-scatter = "*"
68 | pytorch-spline-conv = "*"
69 | pytorch-sparse = "*"
70 | tqdm = "*"
71 | lightning = "*"
72 | einops = "*"
73 | fsspec = "*"
74 | datasets = "*"
75 | trimesh = "*"
76 | pysolar = "*"
77 | rtree = "*"
78 | pixi-pycharm = ">=0.0.8,<0.0.9"
79 | uv = ">=0.6.2,<0.7"
80 | healpy = ">=1.18.1,<2"
81 | 
82 | 
83 | [tool.pixi.environments]
84 | default = ["cpu"]
85 | cuda = ["cuda"]
86 | mlx = ["mlx"]
87 | 
88 | [tool.pixi.tasks]
89 | install = "pip install --editable ."
90 | installnat = "pip install natten"
91 | installnatcuda = "pip install natten==0.17.4+torch240cu124 -f https://shi-labs.com/natten/wheels/"
92 | test = "pytest"
93 | format = "ruff format"
94 | 
95 | 
96 | [tool.ruff]
97 | # Exclude a variety of commonly ignored directories.
98 | exclude = [
99 |     ".bzr",
100 |     ".direnv",
101 |     ".eggs",
102 |     ".git",
103 |     ".hg",
104 |     ".mypy_cache",
105 |     ".nox",
106 |     ".pants.d",
107 |     ".pytype",
108 |     ".ruff_cache",
109 |     ".svn",
110 |     ".tox",
111 |     ".venv",
112 |     "__pypackages__",
113 |     "_build",
114 |     "buck-out",
115 |     "build",
116 |     "dist",
117 |     "node_modules",
118 |     "venv",
119 |     "tests",
120 | ]
121 | # Same as Black.
122 | line-length = 100
123 | 
124 | # Assume Python 3.11.
125 | target-version = "py311"
126 | fix = false
127 | # Group violations by containing file.
128 | output-format = "github"
129 | 
130 | [tool.ruff.lint]
131 | # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
132 | select = ["E", "F", "D", "I"]
133 | ignore = ["D200","D202","D210","D212","D415","D105"]
134 | 
135 | # Allow autofix for all enabled rules (when `--fix` is provided).
136 | fixable = ["A", "B", "C", "D", "E", "F", "I"]
137 | unfixable = []
138 | # Allow unused variables when underscore-prefixed.
139 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
140 | mccabe.max-complexity = 10
141 | pydocstyle.convention = "google"
142 | 
143 | [tool.ruff.lint.per-file-ignores]
144 | "__init__.py" = ["F401", "E402"]
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | """Setup"""
2 | 
3 | from pathlib import Path
4 | 
5 | from setuptools import find_packages, setup
6 | 
7 | this_directory = Path(__file__).parent
8 | long_description = (this_directory / "README.md").read_text()
9 | 
10 | setup(
11 |     name="graph_weather",
12 |     version="1.0.108",
13 |     packages=find_packages(),
14 |     url="https://github.com/openclimatefix/graph_weather",
15 |     license="MIT License",
16 |     company="Open Climate Fix Ltd",
17 |     author="Jacob Bieker",
18 |     long_description=long_description,
19 |     long_description_content_type="text/markdown",
20 |     author_email="jacob@bieker.tech",
21 |     description="Weather Forecasting with Graph Neural Networks",
22 |     classifiers=[
23 |         "Development Status :: 4 - Beta",
24 |         "Intended Audience :: Developers",
25 |         "Topic :: Scientific/Engineering :: Artificial Intelligence",
26 |         "License :: OSI Approved :: MIT License",
27 |         "Programming Language :: Python :: 3.11",
28 |     ],
29 | )
-------------------------------------------------------------------------------- /tests/test_gencast.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 | from packaging.version import Version
5 | from torch_geometric.transforms import TwoHop
6 | 
7 | from graph_weather.models.gencast import Denoiser, GraphBuilder, Sampler, WeightedMSELoss
8 | from graph_weather.models.gencast.layers.modules import FourierEmbedding
9
| from graph_weather.models.gencast.utils.noise import generate_isotropic_noise, sample_noise_level 10 | 11 | 12 | def test_gencast_noise(): 13 | num_lon = 360 14 | num_lat = 180 15 | num_samples = 5 16 | target_residuals = np.zeros((num_lon, num_lat, num_samples)) 17 | noise_level = sample_noise_level() 18 | noise = generate_isotropic_noise( 19 | num_lon=num_lon, num_lat=num_lat, num_samples=target_residuals.shape[-1] 20 | ) 21 | corrupted_residuals = target_residuals + noise_level * noise 22 | assert corrupted_residuals.shape == target_residuals.shape 23 | assert not np.isnan(corrupted_residuals).any() 24 | 25 | num_lon = 360 26 | num_lat = 181 27 | num_samples = 5 28 | target_residuals = np.zeros((num_lon, num_lat, num_samples)) 29 | noise_level = sample_noise_level() 30 | noise = generate_isotropic_noise( 31 | num_lon=num_lon, num_lat=num_lat, num_samples=target_residuals.shape[-1] 32 | ) 33 | corrupted_residuals = target_residuals + noise_level * noise 34 | assert corrupted_residuals.shape == target_residuals.shape 35 | assert not np.isnan(corrupted_residuals).any() 36 | 37 | num_lon = 100 38 | num_lat = 100 39 | num_samples = 5 40 | target_residuals = np.zeros((num_lon, num_lat, num_samples)) 41 | noise_level = sample_noise_level() 42 | noise = generate_isotropic_noise( 43 | num_lon=num_lon, num_lat=num_lat, num_samples=target_residuals.shape[-1], isotropic=False 44 | ) 45 | corrupted_residuals = target_residuals + noise_level * noise 46 | assert corrupted_residuals.shape == target_residuals.shape 47 | assert not np.isnan(corrupted_residuals).any() 48 | 49 | 50 | def test_gencast_graph(): 51 | grid_lat = np.arange(-90, 90, 1) 52 | grid_lon = np.arange(0, 360, 1) 53 | graphs = GraphBuilder(grid_lon=grid_lon, grid_lat=grid_lat, splits=4, num_hops=8) 54 | 55 | # compare khop sparse implementation with pyg. 
56 | transform = TwoHop() 57 | khop_mesh_graph_pyg = graphs.mesh_graph 58 | for i in range(3): # 8-hop mesh 59 | khop_mesh_graph_pyg = transform(khop_mesh_graph_pyg) 60 | 61 | assert graphs.mesh_graph.x.shape[0] == 2562 62 | assert graphs.g2m_graph["grid_nodes"].x.shape[0] == 360 * 180 63 | assert graphs.m2g_graph["mesh_nodes"].x.shape[0] == 2562 64 | assert not torch.isnan(graphs.mesh_graph.edge_attr).any() 65 | assert graphs.khop_mesh_graph.x.shape[0] == 2562 66 | assert torch.allclose(graphs.khop_mesh_graph.x, khop_mesh_graph_pyg.x) 67 | assert torch.allclose(graphs.khop_mesh_graph.edge_index, khop_mesh_graph_pyg.edge_index) 68 | 69 | 70 | def test_gencast_loss(): 71 | grid_lat = torch.arange(-90, 90, 1) 72 | grid_lon = torch.arange(0, 360, 1) 73 | pressure_levels = torch.tensor( 74 | [50.0, 100.0, 150.0, 200.0, 250, 300, 400, 500, 600, 700, 850, 925, 1000.0] 75 | ) 76 | single_features_weights = torch.tensor([1, 0.1, 0.1, 0.1, 0.1]) 77 | num_atmospheric_features = 6 78 | batch_size = 3 79 | features_dim = len(pressure_levels) * num_atmospheric_features + len(single_features_weights) 80 | 81 | loss = WeightedMSELoss( 82 | grid_lat=grid_lat, 83 | pressure_levels=pressure_levels, 84 | num_atmospheric_features=num_atmospheric_features, 85 | single_features_weights=single_features_weights, 86 | ) 87 | 88 | preds = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim)) 89 | noise_levels = torch.rand((batch_size, 1)) 90 | targets = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim)) 91 | assert loss.forward(preds, noise_levels, targets) is not None 92 | 93 | 94 | def test_gencast_denoiser(): 95 | grid_lat = np.arange(-90, 90, 1) 96 | grid_lon = np.arange(0, 360, 1) 97 | input_features_dim = 10 98 | output_features_dim = 5 99 | batch_size = 3 100 | 101 | denoiser = Denoiser( 102 | grid_lon=grid_lon, 103 | grid_lat=grid_lat, 104 | input_features_dim=input_features_dim, 105 | output_features_dim=output_features_dim, 106 | hidden_dims=[16, 32], 107 | num_blocks=3, 108 | num_heads=4, 109 | splits=0, 110 | num_hops=1, 111 | device=torch.device("cpu"), 112 | ).eval() 113 | 114 | corrupted_targets = torch.randn((batch_size, len(grid_lon), len(grid_lat), output_features_dim)) 115 | prev_inputs = torch.randn((batch_size, len(grid_lon), len(grid_lat), 2 * input_features_dim)) 116 | noise_levels = torch.rand((batch_size, 1)) 117 | 118 | with torch.no_grad(): 119 | preds = denoiser( 120 | corrupted_targets=corrupted_targets, prev_inputs=prev_inputs, noise_levels=noise_levels 121 | ) 122 | 123 | assert not torch.isnan(preds).any() 124 | 125 | 126 | def test_gencast_fourier(): 127 | batch_size = 10 128 | output_dim = 20 129 | fourier_embedder = FourierEmbedding(output_dim=output_dim, num_frequencies=32, base_period=16) 130 | t = torch.rand((batch_size, 1)) 131 | assert fourier_embedder(t).shape == (batch_size, output_dim) 132 | 133 | 134 | def test_gencast_sampler(): 135 | grid_lat = np.arange(-90, 90, 1) 136 | grid_lon = np.arange(0, 360, 1) 137 | input_features_dim = 10 138 | output_features_dim = 5 139 | 140 | denoiser = Denoiser( 141 | grid_lon=grid_lon, 142 | grid_lat=grid_lat, 143 | input_features_dim=input_features_dim, 144 | output_features_dim=output_features_dim, 145 | hidden_dims=[16, 32], 146 | num_blocks=3, 147 | num_heads=4, 148 | splits=0, 149 | num_hops=1, 150 | device=torch.device("cpu"), 151 | ).eval() 152 | 153 | prev_inputs = torch.randn((1, len(grid_lon), len(grid_lat), 2 * input_features_dim)) 154 | 155 | sampler = Sampler() 156 | preds = 
sampler.sample(denoiser, prev_inputs) 157 | assert not torch.isnan(preds).any() 158 | assert preds.shape == (1, len(grid_lon), len(grid_lat), output_features_dim) 159 | 160 | 161 | @pytest.mark.skipif( 162 | Version(torch.__version__).release != Version("2.3.0").release, 163 | reason="dgl tests for experimental features only runs with torch 2.3.0", 164 | ) 165 | def test_gencast_full(): 166 | # download weights from HF 167 | denoiser = Denoiser.from_pretrained( 168 | "openclimatefix/gencast-128x64", 169 | grid_lon=np.arange(0, 360, 360 / 128), 170 | grid_lat=np.arange(-90, 90, 180 / 64) + 1 / 2 * 180 / 64, 171 | ) 172 | 173 | # load inputs and targets 174 | prev_inputs = torch.randn([1, 128, 64, 178]) 175 | target_residuals = torch.randn([1, 128, 64, 83]) 176 | 177 | # predict 178 | sampler = Sampler() 179 | preds = sampler.sample(denoiser, prev_inputs) 180 | 181 | assert not torch.isnan(preds).any() 182 | assert preds.shape == target_residuals.shape 183 | -------------------------------------------------------------------------------- /tests/test_nnjai.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for the `SensorDataset` class, mocking the `DataCatalog` to simulate sensor data loading and validate dataset behavior. 3 | The tests ensure correct handling of data types, shapes, and batch processing for various sensor types. 4 | """ 5 | 6 | from datetime import datetime 7 | from unittest.mock import MagicMock, patch 8 | import numpy as np 9 | import pytest 10 | import torch 11 | import pandas as pd 12 | 13 | from graph_weather.data.nnja_ai import SensorDataset, collate_fn 14 | 15 | 16 | def get_sensor_variables(sensor_type): 17 | """Helper function to get the correct variables for each sensor type.""" 18 | if sensor_type == "AMSU": 19 | return [f"TMBR_000{i:02d}" for i in range(1, 16)] # 15 channels 20 | elif sensor_type == "ATMS": 21 | return [f"TMBR_000{i:02d}" for i in range(1, 23)] # 22 channels 22 | elif sensor_type == "MHS": 23 | return [f"TMBR_000{i:02d}" for i in range(1, 6)] # 5 channels 24 | elif sensor_type == "IASI": 25 | return [f"SCRA_{str(i).zfill(5)}" for i in range(1, 617)] # 616 channels 26 | elif sensor_type == "CrIS": 27 | return [f"SRAD01_{str(i).zfill(5)}" for i in range(1, 432)] # 431 channels 28 | return [] 29 | 30 | 31 | @pytest.fixture 32 | def mock_datacatalog(): 33 | """ 34 | Fixture to mock the DataCatalog for unit tests to avoid actual data loading. 
35 | """ 36 | with patch("graph_weather.data.nnja_ai.DataCatalog") as mock: 37 | # Create a mock catalog 38 | mock_catalog = MagicMock() 39 | 40 | # Create a mock dataset with direct DataFrame return 41 | mock_dataset = MagicMock() 42 | mock_dataset.load_manifest = MagicMock() 43 | mock_dataset.sel = MagicMock(return_value=mock_dataset) # Return self to chain calls 44 | 45 | def create_mock_df(engine="pandas"): 46 | # Get the sensor type from the mock dataset 47 | sensor_vars = get_sensor_variables(mock_dataset.sensor_type) 48 | 49 | # Create DataFrame with required columns 50 | df = pd.DataFrame( 51 | { 52 | "OBS_TIMESTAMP": pd.date_range( 53 | start=datetime(2021, 1, 1), periods=100, freq="H" 54 | ), 55 | "LAT": np.full(100, 45.0), 56 | "LON": np.full(100, -120.0), 57 | } 58 | ) 59 | 60 | # Add sensor-specific variables 61 | for var in sensor_vars: 62 | df[var] = np.full(100, 250.0) 63 | 64 | return df 65 | 66 | # Set up the mock to return our DataFrame 67 | mock_dataset.load_dataset = create_mock_df 68 | 69 | # Configure the catalog to return our mock dataset 70 | def get_mock_dataset(self, name): 71 | # Set the sensor type based on the requested dataset name 72 | mock_dataset.sensor_type = next( 73 | config["sensor_type"] for config in SENSOR_CONFIGS if config["name"] == name 74 | ) 75 | return mock_dataset 76 | 77 | mock_catalog.__getitem__ = get_mock_dataset # Fix: Explicitly define the method with `self` 78 | mock.return_value = mock_catalog 79 | 80 | yield mock 81 | 82 | 83 | # Test configurations 84 | SENSOR_CONFIGS = [ 85 | { 86 | "name": "amsu-1bamua-NC021023", 87 | "sensor_type": "AMSU", 88 | "expected_metadata_size": 15, # 15 TMBR channels 89 | }, 90 | { 91 | "name": "atms-atms-NC021203", 92 | "sensor_type": "ATMS", 93 | "expected_metadata_size": 22, # 22 TMBR channels 94 | }, 95 | { 96 | "name": "mhs-1bmhs-NC021027", 97 | "sensor_type": "MHS", 98 | "expected_metadata_size": 5, # 5 TMBR channels 99 | }, 100 | { 101 | "name": "iasi-mtiasi-NC021241", 102 | "sensor_type": "IASI", 103 | "expected_metadata_size": 616, # 616 SCRA channels 104 | }, 105 | { 106 | "name": "cris-crisf4-NC021206", 107 | "sensor_type": "CrIS", 108 | "expected_metadata_size": 431, # 431 SRAD channels 109 | }, 110 | ] 111 | 112 | 113 | @pytest.mark.parametrize("sensor_config", SENSOR_CONFIGS) 114 | def test_sensor_dataset(mock_datacatalog, sensor_config): 115 | """Test the SensorDataset class for different sensor types.""" 116 | time = datetime(2021, 1, 1, 0, 0) 117 | primary_descriptors = ["OBS_TIMESTAMP", "LAT", "LON"] 118 | 119 | dataset = SensorDataset( 120 | dataset_name=sensor_config["name"], 121 | time=time, 122 | primary_descriptors=primary_descriptors, 123 | additional_variables=get_sensor_variables(sensor_config["sensor_type"]), 124 | sensor_type=sensor_config["sensor_type"], 125 | ) 126 | 127 | # Test dataset length 128 | assert len(dataset) > 0, f"Dataset should not be empty for {sensor_config['sensor_type']}" 129 | 130 | # Test single item structure 131 | item = dataset[0] 132 | expected_keys = {"timestamp", "latitude", "longitude", "metadata"} 133 | assert ( 134 | set(item.keys()) == expected_keys 135 | ), f"Dataset item keys are not as expected for {sensor_config['sensor_type']}" 136 | 137 | # Validate tensor properties 138 | assert isinstance( 139 | item["timestamp"], torch.Tensor 140 | ), f"Timestamp should be a tensor for {sensor_config['sensor_type']}" 141 | assert ( 142 | item["timestamp"].dtype == torch.float32 143 | ), f"Timestamp should have dtype float32 for 
{sensor_config['sensor_type']}" 144 | assert ( 145 | item["timestamp"].ndim == 0 146 | ), f"Timestamp should be a scalar tensor for {sensor_config['sensor_type']}" 147 | 148 | assert isinstance( 149 | item["latitude"], torch.Tensor 150 | ), f"Latitude should be a tensor for {sensor_config['sensor_type']}" 151 | assert ( 152 | item["latitude"].dtype == torch.float32 153 | ), f"Latitude should have dtype float32 for {sensor_config['sensor_type']}" 154 | assert ( 155 | item["latitude"].ndim == 0 156 | ), f"Latitude should be a scalar tensor for {sensor_config['sensor_type']}" 157 | 158 | assert isinstance( 159 | item["longitude"], torch.Tensor 160 | ), f"Longitude should be a tensor for {sensor_config['sensor_type']}" 161 | assert ( 162 | item["longitude"].dtype == torch.float32 163 | ), f"Longitude should have dtype float32 for {sensor_config['sensor_type']}" 164 | assert ( 165 | item["longitude"].ndim == 0 166 | ), f"Longitude should be a scalar tensor for {sensor_config['sensor_type']}" 167 | 168 | assert isinstance( 169 | item["metadata"], torch.Tensor 170 | ), f"Metadata should be a tensor for {sensor_config['sensor_type']}" 171 | assert item["metadata"].shape == ( 172 | sensor_config["expected_metadata_size"], 173 | ), f"Metadata shape mismatch for {sensor_config['sensor_type']}. Expected ({sensor_config['expected_metadata_size']},)" 174 | assert ( 175 | item["metadata"].dtype == torch.float32 176 | ), f"Metadata should have dtype float32 for {sensor_config['sensor_type']}" 177 | 178 | 179 | def test_collate_function(): 180 | """Test the collate_fn function to ensure proper batching of dataset items.""" 181 | batch_size = 4 182 | metadata_size = 15 # Using AMSU size for this test 183 | mock_batch = [ 184 | { 185 | "timestamp": torch.tensor(datetime.now().timestamp(), dtype=torch.float32), 186 | "latitude": torch.tensor(45.0, dtype=torch.float32), 187 | "longitude": torch.tensor(-120.0, dtype=torch.float32), 188 | "metadata": torch.randn(metadata_size, dtype=torch.float32), 189 | } 190 | for _ in range(batch_size) 191 | ] 192 | 193 | batched = collate_fn(mock_batch) 194 | 195 | assert batched["timestamp"].shape == (batch_size,), "Timestamp batch shape mismatch" 196 | assert batched["latitude"].shape == (batch_size,), "Latitude batch shape mismatch" 197 | assert batched["longitude"].shape == (batch_size,), "Longitude batch shape mismatch" 198 | assert batched["metadata"].shape == (batch_size, metadata_size), "Metadata batch shape mismatch" 199 | assert batched["timestamp"].dtype == torch.float32, "Timestamp dtype mismatch" 200 | assert batched["latitude"].dtype == torch.float32, "Latitude dtype mismatch" 201 | assert batched["longitude"].dtype == torch.float32, "Longitude dtype mismatch" 202 | assert batched["metadata"].dtype == torch.float32, "Metadata dtype mismatch" 203 | -------------------------------------------------------------------------------- /tests/test_weathermesh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from graph_weather.models.weathermesh.decoder import WeatherMeshDecoder, WeatherMeshDecoderConfig 4 | from graph_weather.models.weathermesh.encoder import WeatherMeshEncoder, WeatherMeshEncoderConfig 5 | from graph_weather.models.weathermesh.processor import ( 6 | WeatherMeshProcessor, 7 | WeatherMeshProcessorConfig, 8 | ) 9 | from graph_weather.models.weathermesh.weathermesh2 import WeatherMesh, WeatherMeshConfig 10 | 11 | 12 | def test_weathermesh_encoder(): 13 | encoder = WeatherMeshEncoder( 14 | 
input_channels_2d=2, 15 | input_channels_3d=1, 16 | latent_dim=8, 17 | n_pressure_levels=25, 18 | kernel_size=(3, 3, 3), 19 | num_heads=2, 20 | hidden_dim=16, 21 | num_conv_blocks=3, 22 | num_transformer_layers=3, 23 | ) 24 | x_2d = torch.randn(1, 2, 32, 64) 25 | x_3d = torch.randn(1, 1, 25, 32, 64) 26 | out = encoder(x_2d, x_3d) 27 | assert out.shape == (1, 5, 4, 8, 8) 28 | 29 | 30 | def test_weathermesh_processor(): 31 | processor = WeatherMeshProcessor(latent_dim=8, n_layers=2) 32 | x = torch.randn(1, 26, 32, 64, 8) 33 | out = processor(x) 34 | assert out.shape == (1, 26, 32, 64, 8) 35 | 36 | 37 | def test_weathermesh_decoder(): 38 | decoder = WeatherMeshDecoder( 39 | latent_dim=8, 40 | output_channels_2d=8, 41 | output_channels_3d=4, 42 | kernel_size=(3, 3, 3), 43 | num_heads=2, 44 | hidden_dim=8, 45 | num_transformer_layers=1, 46 | ) 47 | x = torch.randn(1, 6, 32, 64, 8) 48 | out = decoder(x) 49 | assert out[0].shape == (1, 8, 256, 512) 50 | assert out[1].shape == (1, 4, 5, 256, 512) 51 | 52 | 53 | def test_weathermesh(): 54 | model = WeatherMesh( 55 | encoder=None, 56 | processors=None, 57 | decoder=None, 58 | timesteps=[1, 6], 59 | surface_channels=8, 60 | pressure_channels=4, 61 | pressure_levels=5, 62 | latent_dim=4, 63 | encoder_num_conv_blocks=1, 64 | encoder_num_transformer_layers=1, 65 | encoder_hidden_dim=4, 66 | decoder_num_conv_blocks=1, 67 | decoder_num_transformer_layers=1, 68 | decoder_hidden_dim=4, 69 | processor_num_layers=2, 70 | kernel=(3, 5, 5), 71 | num_heads=2, 72 | ) 73 | 74 | x_2d = torch.randn(1, 8, 32, 64) 75 | x_3d = torch.randn(1, 4, 5, 32, 64) 76 | out = model(x_2d, x_3d, forecast_steps=1) 77 | assert out.surface.shape == (1, 8, 32, 64) 78 | assert out.pressure.shape == (1, 4, 5, 32, 64) 79 | -------------------------------------------------------------------------------- /train/deepspeed_graph.py: -------------------------------------------------------------------------------- 1 | """Module for training the Graph Weather forecaster model using PyTorch Lightning.""" 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | from pytorch_lightning import Trainer 6 | from torch.utils.data import DataLoader, Dataset 7 | 8 | from graph_weather import GraphWeatherForecaster 9 | 10 | lat_lons = [] 11 | for lat in range(-90, 90, 1): 12 | for lon in range(0, 360, 1): 13 | lat_lons.append((lat, lon)) 14 | 15 | 16 | class LitModel(pl.LightningModule): 17 | """ 18 | LightningModule for the weather forecasting model. 19 | 20 | Args: 21 | lat_lons: List of latitude and longitude coordinates. 22 | feature_dim: Dimension of the input features. 23 | aux_dim : Dimension of the auxiliary features. 24 | 25 | Methods: 26 | __init__: Initialize the LitModel object. 27 | """ 28 | 29 | def __init__(self, lat_lons, feature_dim, aux_dim): 30 | """ 31 | Initialize the LitModel object. 32 | 33 | Args: 34 | lat_lons: List of latitude and longitude coordinates. 35 | feature_dim : Dimension of the input features. 36 | aux_dim : Dimension of the auxiliary features. 37 | """ 38 | super().__init__() 39 | self.model = GraphWeatherForecaster( 40 | lat_lons=lat_lons, feature_dim=feature_dim, aux_dim=aux_dim 41 | ) 42 | 43 | def training_step(self, batch): 44 | """ 45 | Performs a training step. 46 | 47 | Args: 48 | batch: A batch of training data. 49 | 50 | Returns: 51 | The computed loss. 
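        Note: inputs and targets are cast to half precision, matching the
        16-bit precision configured for the DeepSpeed trainer below.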
52 | """ 53 | x, y = batch 54 | x = x.half() 55 | y = y.half() 56 | out = self.forward(x) 57 | criterion = torch.nn.MSELoss() 58 | loss = criterion(out, y) 59 | return loss 60 | 61 | def configure_optimizers(self): 62 | """ 63 | Configures the optimizer used during training. 64 | 65 | Returns: 66 | The optimizer. 67 | """ 68 | return torch.optim.AdamW(self.parameters()) 69 | 70 | def forward(self, x): 71 | """ 72 | Forward pass. 73 | 74 | Args: 75 | x (torch.Tensor): Input data. 76 | 77 | Returns: 78 | torch.Tensor: Output of the model. 79 | """ 80 | return self.model(x) 81 | 82 | 83 | class FakeDataset(Dataset): 84 | """ 85 | Dataset class for generating fake data. 86 | 87 | Methods: 88 | __init__: Initialize the FakeDataset object. 89 | __len__: Return the length of the dataset. 90 | __getitem__: Get an item from the dataset. 91 | """ 92 | 93 | def __init__(self): 94 | """ 95 | Initialize the FakeDataset object. 96 | """ 97 | super(FakeDataset, self).__init__() 98 | 99 | def __len__(self): 100 | return 64000 101 | 102 | def __getitem__(self, item): 103 | return torch.randn((64800, 605 + 32)), torch.randn((64800, 605)) 104 | 105 | 106 | model = LitModel(lat_lons=lat_lons, feature_dim=605, aux_dim=32) 107 | trainer = Trainer( 108 | accelerator="gpu", 109 | devices=1, 110 | strategy="deepspeed_stage_3_offload", 111 | precision=16, 112 | max_epochs=10, 113 | limit_train_batches=2000, 114 | ) 115 | dataset = FakeDataset() 116 | train_dataloader = DataLoader( 117 | dataset, batch_size=1, num_workers=1, pin_memory=True, prefetch_factor=1 118 | ) 119 | trainer.fit(model=model, train_dataloaders=train_dataloader) 120 | -------------------------------------------------------------------------------- /train/era5.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | import torch 6 | import xarray 7 | from einops import rearrange 8 | from pytorch_lightning.callbacks import ModelCheckpoint 9 | from torch.utils.data import DataLoader, Dataset 10 | 11 | from graph_weather.models import MetaModel 12 | from graph_weather.models.losses import NormalizedMSELoss 13 | 14 | 15 | class LitFengWuGHR(pl.LightningModule): 16 | """ 17 | LightningModule for graph-based weather forecasting. 18 | 19 | Attributes: 20 | model (GraphWeatherForecaster): Graph weather forecaster model. 21 | criterion (NormalizedMSELoss): Loss criterion for training. 22 | lr : Learning rate for optimizer. 23 | 24 | Methods: 25 | __init__: Initialize the LitFengWuGHR object. 26 | forward: Forward pass of the model. 27 | training_step: Training step. 28 | configure_optimizers: Configure the optimizer for training. 29 | """ 30 | 31 | def __init__( 32 | self, 33 | lat_lons: list, 34 | *, 35 | channels: int, 36 | image_size, 37 | patch_size=4, 38 | depth=5, 39 | heads=4, 40 | mlp_dim=5, 41 | feature_dim: int = 605, # TODO where does this come from? 42 | lr: float = 3e-4, 43 | ): 44 | """ 45 | Initialize the LitFengWuGHR object with the required args. 46 | 47 | Args: 48 | lat_lons : List of latitude and longitude values. 49 | feature_dim : Dimensionality of the input features. 50 | aux_dim : Dimensionality of auxiliary features. 51 | hidden_dim : Dimensionality of hidden layers in the model. 52 | num_blocks : Number of graph convolutional blocks in the model. 53 | lr (float): Learning rate for optimizer. 
54 | """ 55 | super().__init__() 56 | self.model = MetaModel( 57 | lat_lons, 58 | image_size=image_size, 59 | patch_size=patch_size, 60 | depth=depth, 61 | heads=heads, 62 | mlp_dim=mlp_dim, 63 | channels=channels, 64 | ) 65 | self.criterion = NormalizedMSELoss( 66 | lat_lons=lat_lons, feature_variance=np.ones((feature_dim,)) 67 | ) 68 | self.lr = lr 69 | self.save_hyperparameters() 70 | 71 | def forward(self, x): 72 | """ 73 | Forward pass . 74 | 75 | Args: 76 | x (torch.Tensor): Input tensor. 77 | 78 | Returns: 79 | torch.Tensor: Output tensor. 80 | """ 81 | return self.model(x) 82 | 83 | def training_step(self, batch, batch_idx): 84 | """ 85 | Training step. 86 | 87 | Args: 88 | batch (array): Batch of data containing input and output tensors. 89 | batch_idx (int): Index of the current batch. 90 | 91 | Returns: 92 | torch.Tensor: Loss tensor. 93 | """ 94 | x, y = batch[:, 0], batch[:, 1] 95 | if torch.isnan(x).any() or torch.isnan(y).any(): 96 | return None 97 | y_hat = self.forward(x) 98 | loss = self.criterion(y_hat, y) 99 | self.log("loss", loss, prog_bar=True) 100 | return loss 101 | 102 | def configure_optimizers(self): 103 | """ 104 | Configure the optimizer. 105 | 106 | Returns: 107 | torch.optim.Optimizer: Optimizer instance. 108 | """ 109 | return torch.optim.AdamW(self.parameters(), lr=self.lr) 110 | 111 | 112 | class Era5Dataset(Dataset): 113 | """Era5 dataset.""" 114 | 115 | def __init__(self, xarr, transform=None): 116 | """ 117 | Arguments: 118 | #TODO 119 | """ 120 | ds = np.asarray(xarr.to_array()) 121 | ds = torch.from_numpy(ds) 122 | ds -= ds.min(0, keepdim=True)[0] 123 | ds /= ds.max(0, keepdim=True)[0] 124 | ds = rearrange(ds, "C T H W -> T (H W) C") 125 | self.ds = ds 126 | 127 | def __len__(self): 128 | return len(self.ds) - 1 129 | 130 | def __getitem__(self, index): 131 | return self.ds[index : index + 2] 132 | 133 | 134 | if __name__ == "__main__": 135 | ckpt_path = Path("./checkpoints") 136 | patch_size = 4 137 | grid_step = 20 138 | variables = [ 139 | "2m_temperature", 140 | "surface_pressure", 141 | "10m_u_component_of_wind", 142 | "10m_v_component_of_wind", 143 | ] 144 | 145 | channels = len(variables) 146 | ckpt_path.mkdir(parents=True, exist_ok=True) 147 | 148 | reanalysis = xarray.open_zarr( 149 | "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3", 150 | storage_options=dict(token="anon"), 151 | ) 152 | 153 | reanalysis = reanalysis.sel(time=slice("2020-01-01", "2021-01-01")) 154 | reanalysis = reanalysis.isel( 155 | time=slice(100, 107), longitude=slice(0, 1440, grid_step), latitude=slice(0, 721, grid_step) 156 | ) 157 | 158 | reanalysis = reanalysis[variables] 159 | print(f"size: {reanalysis.nbytes / (1024**3)} GiB") 160 | 161 | lat_lons = np.array( 162 | np.meshgrid( 163 | np.asarray(reanalysis["latitude"]).flatten(), 164 | np.asarray(reanalysis["longitude"]).flatten(), 165 | ) 166 | ).T.reshape((-1, 2)) 167 | 168 | checkpoint_callback = ModelCheckpoint(dirpath=ckpt_path, save_top_k=1, monitor="loss") 169 | 170 | dset = DataLoader(Era5Dataset(reanalysis), batch_size=10, num_workers=8) 171 | model = LitFengWuGHR( 172 | lat_lons=lat_lons, 173 | channels=channels, 174 | image_size=(721 // grid_step, 1440 // grid_step), 175 | patch_size=patch_size, 176 | depth=5, 177 | heads=4, 178 | mlp_dim=5, 179 | ) 180 | trainer = pl.Trainer( 181 | accelerator="gpu", 182 | devices=-1, 183 | max_epochs=100, 184 | precision="16-mixed", 185 | callbacks=[checkpoint_callback], 186 | log_every_n_steps=3, 187 | ) 188 | 189 | trainer.fit(model, dset) 190 | 
191 | torch.save(model.model.state_dict(), ckpt_path / "best.pt") 192 | -------------------------------------------------------------------------------- /train/lora.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | import torch 6 | import torch.nn as nn 7 | import xarray 8 | from einops import rearrange 9 | from pytorch_lightning.callbacks import ModelCheckpoint 10 | from torch.utils.data import DataLoader, Dataset 11 | 12 | from graph_weather.models import LoRAModule, MetaModel 13 | from graph_weather.models.losses import NormalizedMSELoss 14 | 15 | 16 | class LitLoRAFengWuGHR(pl.LightningModule): 17 | def __init__( 18 | self, 19 | lat_lons: list, 20 | single_step_model_state_dict: dict, 21 | *, 22 | time_step: int, 23 | rank: int, 24 | channels: int, 25 | image_size, 26 | patch_size=4, 27 | depth=5, 28 | heads=4, 29 | mlp_dim=5, 30 | feature_dim: int = 605, # TODO where does this come from? 31 | lr: float = 3e-4, 32 | ): 33 | super().__init__() 34 | assert ( 35 | time_step > 1 36 | ), "Time step must be greater than 1. Remember that 1 is the simple model time step." 37 | ssmodel = MetaModel( 38 | lat_lons, 39 | image_size=image_size, 40 | patch_size=patch_size, 41 | depth=depth, 42 | heads=heads, 43 | mlp_dim=mlp_dim, 44 | channels=channels, 45 | ) 46 | ssmodel.load_state_dict(single_step_model_state_dict) 47 | self.models = nn.ModuleList( 48 | [ssmodel] + [LoRAModule(ssmodel, r=rank) for _ in range(2, time_step + 1)] 49 | ) 50 | self.criterion = NormalizedMSELoss( 51 | lat_lons=lat_lons, feature_variance=np.ones((feature_dim,)) 52 | ) 53 | self.lr = lr 54 | self.save_hyperparameters() 55 | 56 | def forward(self, x): 57 | ys = [] 58 | for t, model in enumerate(self.models): 59 | x = model(x) 60 | ys.append(x) 61 | return torch.stack(ys, dim=1) 62 | 63 | def training_step(self, batch, batch_idx): 64 | if torch.isnan(batch).any(): 65 | return None 66 | x, ys = batch[:, 0, ...], batch[:, 1:, ...] 67 | 68 | y_hat = self.forward(x) 69 | loss = self.criterion(y_hat, ys) 70 | self.log("loss", loss, prog_bar=True) 71 | return loss 72 | 73 | def configure_optimizers(self): 74 | return torch.optim.AdamW(self.parameters(), lr=self.lr) 75 | 76 | 77 | class Era5Dataset(Dataset): 78 | def __init__(self, xarr, time_step=1, transform=None): 79 | assert time_step > 0, "Time step must be greater than 0." 
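        # The block below min-max scales the (C, T, H, W) cube into [0, 1]
        # (min/max taken along dim 0) and flattens the spatial grid so each
        # sample is (time, H*W, channels), mirroring Era5Dataset in train/era5.py.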
80 |         ds = np.asarray(xarr.to_array())
81 |         ds = torch.from_numpy(ds)
82 |         ds -= ds.min(0, keepdim=True)[0]
83 |         ds /= ds.max(0, keepdim=True)[0]
84 |         ds = rearrange(ds, "C T H W -> T (H W) C")
85 |         self.ds = ds
86 |         self.time_step = time_step
87 | 
88 |     def __len__(self):
89 |         return len(self.ds) - self.time_step
90 | 
91 |     def __getitem__(self, index):
92 |         return self.ds[index : index + self.time_step + 1]
93 | 
94 | 
95 | if __name__ == "__main__":
96 |     ckpt_path = Path("./checkpoints")
97 |     ckpt_name = "best.pt"
98 |     patch_size = 4
99 |     grid_step = 20
100 |     time_step = 2
101 |     rank = 4
102 |     variables = [
103 |         "2m_temperature",
104 |         "surface_pressure",
105 |         "10m_u_component_of_wind",
106 |         "10m_v_component_of_wind",
107 |     ]
108 | 
109 |     ###############################################################
110 | 
111 |     channels = len(variables)
112 |     ckpt_path.mkdir(parents=True, exist_ok=True)
113 | 
114 |     reanalysis = xarray.open_zarr(
115 |         "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
116 |         storage_options=dict(token="anon"),
117 |     )
118 | 
119 |     reanalysis = reanalysis.sel(time=slice("2020-01-01", "2021-01-01"))
120 |     reanalysis = reanalysis.isel(
121 |         time=slice(100, 111), longitude=slice(0, 1440, grid_step), latitude=slice(0, 721, grid_step)
122 |     )
123 | 
124 |     reanalysis = reanalysis[variables]
125 |     print(f"size: {reanalysis.nbytes / (1024**3)} GiB")
126 | 
127 |     lat_lons = np.array(
128 |         np.meshgrid(
129 |             np.asarray(reanalysis["latitude"]).flatten(),
130 |             np.asarray(reanalysis["longitude"]).flatten(),
131 |         )
132 |     ).T.reshape((-1, 2))
133 | 
134 |     checkpoint_callback = ModelCheckpoint(dirpath=ckpt_path, save_top_k=1, monitor="loss")
135 | 
136 |     dset = DataLoader(Era5Dataset(reanalysis, time_step=time_step), batch_size=10, num_workers=8)
137 | 
138 |     single_step_model_state_dict = torch.load(ckpt_path / ckpt_name)
139 | 
140 |     model = LitLoRAFengWuGHR(
141 |         lat_lons=lat_lons,
142 |         single_step_model_state_dict=single_step_model_state_dict,
143 |         time_step=time_step,
144 |         rank=rank,
145 |         ##########
146 |         channels=channels,
147 |         image_size=(721 // grid_step, 1440 // grid_step),
148 |         patch_size=patch_size,
149 |         depth=5,
150 |         heads=4,
151 |         mlp_dim=5,
152 |     )
153 |     trainer = pl.Trainer(
154 |         accelerator="gpu",
155 |         devices=-1,
156 |         max_epochs=100,
157 |         precision="16-mixed",
158 |         callbacks=[checkpoint_callback],
159 |         log_every_n_steps=3,
160 |         strategy="ddp_find_unused_parameters_true",
161 |     )
162 | 
163 |     trainer.fit(model, dset)
164 | 
--------------------------------------------------------------------------------
/train/run_fulll.py:
--------------------------------------------------------------------------------
1 | """Training script for the weather forecasting model."""
2 | 
3 | import json
4 | import os
5 | import sys
6 | import time
7 | 
8 | import numpy as np
9 | import torch
10 | import torch.optim as optim
11 | import torchvision.transforms as transforms
12 | import xarray as xr
13 | from torch.utils.data import DataLoader, Dataset
14 | 
15 | from graph_weather import GraphWeatherForecaster
16 | from graph_weather.data import const
17 | from graph_weather.models.losses import NormalizedMSELoss
18 | 
19 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
20 | sys.path.append(BASE_DIR)
21 | 
22 | 
23 | class XrDataset(Dataset):
24 |     """
25 |     Dataset class for loading data from Hugging Face datasets.
26 | 
27 |     Attributes:
28 |         filepaths : List of file paths to the data.
29 |         data : Dataset containing the loaded data.
30 | 
31 |     Methods:
32 |         __init__: Initialize the XrDataset object by loading data from Hugging Face datasets.
33 |         __len__: Get the length of the dataset.
34 |         __getitem__: Get an item from the dataset by index.
35 |     """
36 | 
37 |     def __init__(self):
38 |         """
39 |         Initialize the XrDataset object by loading data from Hugging Face datasets.
40 |         """
41 |         super().__init__()
42 |         with open("hf_forecasts.json", "r") as f:
43 |             files = json.load(f)
44 |         self.filepaths = [
45 |             "zip:///::https://huggingface.co/datasets/openclimatefix/gfs-reforecast/resolve/main/"
46 |             + f
47 |             for f in files
48 |         ]
49 |         self.data = xr.open_mfdataset(
50 |             self.filepaths, engine="zarr", concat_dim="reftime", combine="nested"
51 |         ).sortby("reftime")
52 | 
53 |     def __len__(self):
54 |         return len(self.filepaths)
55 | 
56 |     def __getitem__(self, item):
57 |         start_idx = np.random.randint(0, 14)
58 |         data = self.data.isel(reftime=item, time=slice(start_idx, start_idx + 2))
59 | 
60 |         start = data.isel(time=0)
61 |         end = data.isel(time=1)
62 |         # Stack the normalized variables into a large data cube
63 |         input_data = np.stack(
64 |             [
65 |                 (start[f"{var}"].values - const.FORECAST_MEANS[f"{var}"])
66 |                 / (const.FORECAST_STD[f"{var}"] + 0.0001)
67 |                 for var in start.data_vars
68 |                 if "mb" in var or "surface" in var
69 |             ],
70 |             axis=-1,
71 |         )
72 |         input_data = np.nan_to_num(input_data)
73 |         assert not np.isnan(input_data).any()
74 |         output_data = np.stack(
75 |             [
76 |                 (end[f"{var}"].values - const.FORECAST_MEANS[f"{var}"])
77 |                 / (const.FORECAST_STD[f"{var}"] + 0.0001)
78 |                 for var in end.data_vars
79 |                 if "mb" in var or "surface" in var
80 |             ],
81 |             axis=-1,
82 |         )
83 |         output_data = np.nan_to_num(output_data)
84 |         assert not np.isnan(output_data).any()
85 |         transform = transforms.Compose([transforms.ToTensor()])
86 |         # Flatten the spatial grid so each sample is (num_nodes, num_features)
87 |         return (
88 |             transform(input_data).transpose(0, 1).reshape(-1, input_data.shape[-1]),
89 |             transform(output_data).transpose(0, 1).reshape(-1, output_data.shape[-1]),
90 |         )
91 | 
92 | 
93 | with open("hf_forecasts.json", "r") as f:
94 |     files = json.load(f)
95 | files = [
96 |     "zip:///::https://huggingface.co/datasets/openclimatefix/gfs-reforecast/resolve/main/" + f
97 |     for f in files
98 | ]
99 | data = (
100 |     xr.open_zarr(files[0], consolidated=True).isel(time=0)
101 |     # .coarsen(latitude=8, boundary="pad")
102 |     # .mean()
103 |     # .coarsen(longitude=8)
104 |     # .mean()
105 | )
106 | print(data)
107 | # print("Done coarsening")
108 | lat_lons = np.array(np.meshgrid(data.latitude.values, data.longitude.values)).T.reshape(-1, 2)
109 | 
110 | if torch.cuda.is_available():
111 |     device = "cuda"
112 | elif torch.backends.mps.is_available():
113 |     device = "mps"
114 | else:
115 |     device = "cpu"
116 | 
117 | # Get the variance of the variables
118 | feature_variances = []
119 | for var in data.data_vars:
120 |     if "mb" in var or "surface" in var:
121 |         feature_variances.append(const.FORECAST_DIFF_STD[var] ** 2)
122 | criterion = NormalizedMSELoss(
123 |     lat_lons=lat_lons, feature_variance=feature_variances, device=device
124 | ).to(device)
125 | means = []
126 | dataset = DataLoader(XrDataset(), batch_size=1)
127 | model = GraphWeatherForecaster(lat_lons, feature_dim=597, num_blocks=6).to(device)
128 | optimizer = optim.AdamW(model.parameters(), lr=0.001)
129 | print("Done Setup")
130 | 
131 | 
132 | for epoch in range(100):  # loop over the dataset multiple times
133 |     running_loss = 0.0
134 |     start = time.time()
135 |     print(f"Start Epoch: {epoch}")
136 |     for i, data in enumerate(dataset):
137 |         # get the inputs; data is a list of [inputs, labels]
138 |         inputs, labels = data[0].to(device), data[1].to(device)
139 |         # zero the parameter gradients
140 |         optimizer.zero_grad()
141 | 
142 |         # forward + backward + optimize
143 |         outputs = model(inputs)
144 | 
145 |         loss = criterion(outputs, labels)
146 |         loss.backward()
147 |         optimizer.step()
148 | 
149 |         # print statistics
150 |         running_loss += loss.item()
151 |     end = time.time()
152 |     print(
153 |         f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / (i + 1):.3f} Time: {end - start} sec"
154 |     )
155 |     if epoch % 5 == 0:
156 |         assert not np.isnan(running_loss)
157 |         model.push_to_hub(
158 |             "graph-weather-forecaster-2.0deg",
159 |             organization="openclimatefix",
160 |             commit_message=f"Add model Epoch={epoch}",
161 |         )
162 | 
163 | print("Finished Training")
164 | 
--------------------------------------------------------------------------------
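A closing note: train/run_fulll.py trains but never evaluates. A minimal, hypothetical
inference sketch under the same setup (it reuses `model`, `dataset`, `device`, and
`criterion` exactly as defined in that script; the evaluation step itself is an
illustrative assumption, not part of the repository):

    # Evaluate one normalised sample; shapes follow XrDataset above.
    model.eval()
    with torch.no_grad():
        inputs, labels = next(iter(dataset))   # each of shape (1, num_nodes, 597)
        preds = model(inputs.to(device))       # output matches the input feature layout
        print("eval loss:", criterion(preds, labels.to(device)).item())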