├── MANIFEST.in ├── graph_weather ├── models │ ├── fgn │ │ ├── layers │ │ │ ├── __init__.py │ │ │ └── processor.py │ │ ├── __init__.py │ │ └── README.md │ ├── gencast │ │ ├── images │ │ │ ├── readme.md │ │ │ ├── animated.gif │ │ │ ├── fullmodel.png │ │ │ └── autoregressive.gif │ │ ├── graph │ │ │ ├── __init__.py │ │ │ └── grid_mesh_connectivity.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── batching.py │ │ │ ├── statistics.py │ │ │ └── noise.py │ │ ├── layers │ │ │ ├── experimental │ │ │ │ ├── __init__.py │ │ │ │ └── sparse_transformer.py │ │ │ ├── __init__.py │ │ │ ├── decoder.py │ │ │ └── encoder.py │ │ ├── __init__.py │ │ ├── sampler.py │ │ └── weighted_mse_loss.py │ ├── layers │ │ ├── grid_to_points.py │ │ ├── __init__.py │ │ ├── points_to_grid.py │ │ ├── film.py │ │ ├── processor.py │ │ ├── decoder.py │ │ └── assimilator_decoder.py │ ├── weathermesh │ │ ├── __init__.py │ │ ├── processor.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ ├── layers.py │ │ └── weathermesh2.py │ ├── fengwu_ghr │ │ └── __init__.py │ ├── __init__.py │ ├── cafa │ │ ├── __init__.py │ │ ├── encoder.py │ │ ├── decoder.py │ │ ├── processor.py │ │ ├── model.py │ │ └── factorize.py │ ├── aurora │ │ ├── decoder.py │ │ ├── __init__.py │ │ ├── encoder.py │ │ └── processor.py │ └── analysis.py ├── data │ ├── __init__.py │ ├── IFSAnalysis_dataloader.py │ ├── nnja_ai.py │ ├── dataloader.py │ └── anemoi_dataloader.py └── __init__.py ├── .gitignore ├── .bumpversion.cfg ├── .github └── workflows │ ├── release.yaml │ └── workflows.yaml ├── tests ├── test_film.py ├── test_fgn.py ├── test_gencast_with_thermalizer.py ├── test_weathermesh.py ├── test_cafa.py ├── test_thermalizer.py ├── test_asme_loss.py ├── test_anemoi.py ├── test_nnjai.py └── test_gencast.py ├── environment_cpu.yml ├── environment_cuda.yml ├── .pre-commit-config.yaml ├── setup.py ├── LICENSE ├── Dockerfile ├── .all-contributorsrc ├── train ├── deepspeed_graph.py ├── lora.py ├── run_fulll.py └── era5.py └── pyproject.toml /MANIFEST.in: 
-------------------------------------------------------------------------------- 1 | include *.txt 2 | -------------------------------------------------------------------------------- /graph_weather/models/fgn/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /graph_weather/models/gencast/images/readme.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /graph_weather/models/layers/grid_to_points.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /graph_weather/models/weathermesh/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /graph_weather/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | """Layers for use in models""" 2 | -------------------------------------------------------------------------------- /graph_weather/models/gencast/graph/__init__.py: -------------------------------------------------------------------------------- 1 | """Utils for graph generation.""" 2 | -------------------------------------------------------------------------------- /graph_weather/models/fgn/__init__.py: -------------------------------------------------------------------------------- 1 | from graph_weather.models.fgn.model import FunctionalGenerativeNetwork 2 | -------------------------------------------------------------------------------- /graph_weather/models/gencast/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utils for gencast.""" 2 | 3 | from .noise import 
generate_isotropic_noise 4 | -------------------------------------------------------------------------------- /graph_weather/models/gencast/layers/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | """Experimental features.""" 2 | 3 | from .sparse_transformer import SparseTransformer 4 | -------------------------------------------------------------------------------- /graph_weather/models/gencast/images/animated.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/graph_weather/HEAD/graph_weather/models/gencast/images/animated.gif -------------------------------------------------------------------------------- /graph_weather/models/gencast/images/fullmodel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/graph_weather/HEAD/graph_weather/models/gencast/images/fullmodel.png -------------------------------------------------------------------------------- /graph_weather/models/gencast/images/autoregressive.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openclimatefix/graph_weather/HEAD/graph_weather/models/gencast/images/autoregressive.gif -------------------------------------------------------------------------------- /graph_weather/models/fengwu_ghr/__init__.py: -------------------------------------------------------------------------------- 1 | """Main import for FengWu-GHR""" 2 | 3 | from .layers import ImageMetaModel, LoRAModule, MetaModel, WrapperImageModel, WrapperMetaModel 4 | -------------------------------------------------------------------------------- /graph_weather/models/gencast/layers/__init__.py: -------------------------------------------------------------------------------- 1 | """GenCast layers.""" 2 | 3 | from .decoder import Decoder 4 | from .encoder 
import Encoder 5 | from .modules import MLP, InteractionNetwork 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.nc 3 | *.pt 4 | *.pyc 5 | .DS_Store 6 | *.txt 7 | # pixi environments 8 | .pixi 9 | .vscode/ 10 | checkpoints/ 11 | lightning_logs/ 12 | pixi.lock 13 | .python-version 14 | -------------------------------------------------------------------------------- /graph_weather/models/layers/points_to_grid.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is designed to abstract away the lat/lon points and act as a layer that can be put in front of models that expect a regular lat/lon grid as input. 3 | """ 4 | -------------------------------------------------------------------------------- /graph_weather/data/__init__.py: -------------------------------------------------------------------------------- 1 | """Dataloaders and data processing utilities""" 2 | 3 | from .anemoi_dataloader import AnemoiDataset 4 | from .nnja_ai import SensorDataset 5 | from .weather_station_reader import WeatherStationReader 6 | -------------------------------------------------------------------------------- /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | commit = True 3 | tag = False 4 | current_version = 1.0.124 5 | message = Bump version: {current_version} → {new_version} [skip ci] 6 | 7 | [bumpversion:file:setup.py] 8 | search = version="{current_version}" 9 | replace = version="{new_version}" 10 | -------------------------------------------------------------------------------- /graph_weather/__init__.py: -------------------------------------------------------------------------------- 1 | """Main import for the complete models""" 2 | 3 | from .data.nnja_ai import SensorDataset 4 | from .data.weather_station_reader 
import WeatherStationReader 5 | from .models.analysis import GraphWeatherAssimilator 6 | from .models.forecast import GraphWeatherForecaster 7 | -------------------------------------------------------------------------------- /graph_weather/models/gencast/__init__.py: -------------------------------------------------------------------------------- 1 | """Main import for GenCast""" 2 | 3 | from .denoiser import Denoiser 4 | from .graph.graph_builder import GraphBuilder 5 | from .sampler import Sampler 6 | from .utils.noise import generate_isotropic_noise 7 | from .weighted_mse_loss import WeightedMSELoss 8 | -------------------------------------------------------------------------------- /graph_weather/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Models""" 2 | 3 | from .fengwu_ghr.layers import ( 4 | ImageMetaModel, 5 | LoRAModule, 6 | MetaModel, 7 | WrapperImageModel, 8 | WrapperMetaModel, 9 | ) 10 | from .layers.assimilator_decoder import AssimilatorDecoder 11 | from .layers.assimilator_encoder import AssimilatorEncoder 12 | from .layers.decoder import Decoder 13 | from .layers.encoder import Encoder 14 | from .layers.processor import Processor 15 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Bump version and auto-release 2 | on: 3 | push: 4 | branches: 5 | - main 6 | 7 | jobs: 8 | bump-version-python-docker-release: 9 | uses: openclimatefix/.github/.github/workflows/python-docker-release.yml@main 10 | secrets: 11 | PAT_TOKEN: ${{ secrets.PAT_TOKEN }} 12 | token: ${{ secrets.PYPI_API_TOKEN }} 13 | DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} 14 | DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }} 15 | with: 16 | image_base_name: graph_weather 17 | docker_file: Dockerfile 18 | 
-------------------------------------------------------------------------------- /graph_weather/models/cafa/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | CaFA (Climate-Aware Factorized Attention)'s Architectural Design: 3 | - Transformer-based weather forecast for computational efficiency 4 | - Uses Factorized Attention to reduce the cost of the attention mechanism 5 | - A Three-Part System for Efficient Forecasting: Encoder, Factorized Transformer, Decoder 6 | """ 7 | 8 | from .decoder import CaFADecoder 9 | from .encoder import CaFAEncoder 10 | from .factorize import AxialAttention, FactorizedAttention, FactorizedTransformerBlock 11 | from .model import CaFAForecaster 12 | from .processor import CaFAProcessor 13 | -------------------------------------------------------------------------------- /graph_weather/models/fgn/README.md: -------------------------------------------------------------------------------- 1 | # Functional Generative Network (FGN) 2 | 3 | ## Overview 4 | 5 | This is an unofficial implementation of the Functional Generative Network 6 | outlined in [Skillful joint probabilistic weather forecasting from marginals](https://arxiv.org/abs/2506.10772). 7 | 8 | This model is heavily based on GenCast, and is designed to make ensemble weather forecasts through a combination of 9 | multiple trained models, and noise injected into the model parameters during inference. 10 | 11 | As it does not use diffusion, it is significantly faster to run than GenCast, while outperforming it on nearly all metrics. 
12 | -------------------------------------------------------------------------------- /tests/test_film.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from graph_weather.models.layers.film import FiLMGenerator, FiLMApplier 3 | 4 | 5 | def test_film_shapes(): 6 | batch = 4 7 | feature_dim = 16 8 | num_steps = 10 9 | hidden_dim = 8 10 | lead_time = 3 11 | 12 | gen = FiLMGenerator(num_steps, hidden_dim, feature_dim) 13 | apply = FiLMApplier() 14 | 15 | gamma, beta = gen(batch, lead_time, device="cpu") 16 | 17 | assert gamma.shape == (batch, feature_dim) 18 | assert beta.shape == (batch, feature_dim) 19 | 20 | x = torch.randn(batch, feature_dim, 8, 8) 21 | out = apply(x, gamma, beta) 22 | assert out.shape == x.shape 23 | -------------------------------------------------------------------------------- /environment_cpu.yml: -------------------------------------------------------------------------------- 1 | name: graph 2 | channels: 3 | - pytorch 4 | - pyg 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - numcodecs 9 | - pandas 10 | - pip 11 | - pyg 12 | - python 13 | - pytorch 14 | - cpuonly 15 | - pytorch-cluster 16 | - pytorch-scatter 17 | - pytorch-sparse 18 | - pytorch-spline-conv 19 | - scikit-learn 20 | - scipy 21 | - torchvision 22 | - tqdm 23 | - xarray 24 | - zarr 25 | - h3-py 26 | - numpy 27 | - pyshtools 28 | - gcsfs 29 | - pytest 30 | - pip: 31 | - setuptools 32 | - datasets 33 | - einops 34 | - fsspec 35 | - torch-geometric-temporal 36 | - huggingface-hub 37 | - pysolar 38 | - pytorch-lightning 39 | - click 40 | - trimesh 41 | - rtree 42 | - torch-harmonics 43 | -------------------------------------------------------------------------------- /environment_cuda.yml: -------------------------------------------------------------------------------- 1 | name: graph 2 | channels: 3 | - pytorch 4 | - pyg 5 | - nvidia 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - pytorch-cuda 10 | - 
numcodecs 11 | - pandas 12 | - pip 13 | - pyg 14 | - python=3.12 15 | - pytorch 16 | - pytorch-cluster 17 | - pytorch-scatter 18 | - pytorch-sparse 19 | - pytorch-spline-conv 20 | - scikit-learn 21 | - scipy 22 | - torchvision 23 | - tqdm 24 | - xarray 25 | - zarr 26 | - h3-py 27 | - numpy 28 | - pyshtools 29 | - gcsfs 30 | - pytest 31 | - pip: 32 | - setuptools 33 | - datasets 34 | - einops 35 | - fsspec 36 | - torch-geometric-temporal 37 | - huggingface-hub 38 | - pysolar 39 | - pytorch-lightning 40 | - click 41 | - trimesh 42 | - rtree 43 | - torch-harmonics 44 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v5.0.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: debug-statements 12 | - id: detect-private-key 13 | 14 | # python code formatting/linting 15 | - repo: https://github.com/astral-sh/ruff-pre-commit 16 | # Ruff version. 
17 | rev: "v0.12.7" 18 | hooks: 19 | - id: ruff 20 | args: [--fix] 21 | - repo: https://github.com/psf/black 22 | rev: 25.1.0 23 | hooks: 24 | - id: black 25 | args: [--line-length, "100"] 26 | # yaml formatting 27 | - repo: https://github.com/pre-commit/mirrors-prettier 28 | rev: v4.0.0-alpha.8 29 | hooks: 30 | - id: prettier 31 | types: [yaml] 32 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Setup""" 2 | 3 | from pathlib import Path 4 | 5 | from setuptools import find_packages, setup 6 | 7 | this_directory = Path(__file__).parent 8 | long_description = (this_directory / "README.md").read_text() 9 | 10 | setup( 11 | name="graph_weather", 12 | version="1.0.124", 13 | packages=find_packages(), 14 | url="https://github.com/openclimatefix/graph_weather", 15 | license="MIT License", 16 | company="Open Climate Fix Ltd", 17 | author="Jacob Bieker", 18 | long_description=long_description, 19 | long_description_content_type="text/markdown", 20 | author_email="jacob@bieker.tech", 21 | description="Weather Forecasting with Graph Neural Networks", 22 | classifiers=[ 23 | "Development Status :: 4 - Beta", 24 | "Intended Audience :: Developers", 25 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 26 | "License :: OSI Approved :: MIT License", 27 | "Programming Language :: Python :: 3.9", 28 | ], 29 | ) 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Open Climate Fix 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, 
distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/test_fgn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from packaging.version import Version 5 | from torch_geometric.transforms import TwoHop 6 | 7 | from graph_weather.models.fgn import FunctionalGenerativeNetwork 8 | 9 | 10 | def test_fgn_forward(): 11 | grid_lat = np.arange(-90, 90, 1) 12 | grid_lon = np.arange(0, 360, 1) 13 | input_features_dim = 10 14 | output_features_dim = 5 15 | batch_size = 3 16 | 17 | model = FunctionalGenerativeNetwork( 18 | grid_lon=grid_lon, 19 | grid_lat=grid_lat, 20 | input_features_dim=input_features_dim, 21 | output_features_dim=output_features_dim, 22 | noise_dimension=32, 23 | hidden_dims=[14, 32], 24 | num_blocks=3, 25 | num_heads=2, 26 | splits=0, 27 | num_hops=1, 28 | device=torch.device("cpu"), 29 | ).eval() 30 | 31 | prev_inputs = torch.randn((batch_size, len(grid_lon), len(grid_lat), input_features_dim)) 32 | 33 | with torch.no_grad(): 34 | preds = model(previous_weather_state=prev_inputs) 35 | 36 | assert not 
torch.isnan(preds).any() 37 | assert preds.shape == (3, 2, len(grid_lon), len(grid_lat), output_features_dim) 38 | -------------------------------------------------------------------------------- /graph_weather/models/weathermesh/processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation based off the technical report and this repo: https://github.com/Brayden-Zhang/WeatherMesh 3 | """ 4 | 5 | from dataclasses import dataclass 6 | 7 | import dacite 8 | import torch.nn as nn 9 | from natten import NeighborhoodAttention3D 10 | 11 | 12 | @dataclass 13 | class WeatherMeshProcessorConfig: 14 | latent_dim: int 15 | n_layers: int 16 | kernel: tuple 17 | num_heads: int 18 | 19 | @staticmethod 20 | def from_json(json: dict) -> "WeatherMeshProcessor": 21 | return dacite.from_dict(data_class=WeatherMeshProcessorConfig, data=json) 22 | 23 | def to_json(self) -> dict: 24 | return dacite.asdict(self) 25 | 26 | 27 | class WeatherMeshProcessor(nn.Module): 28 | def __init__(self, latent_dim, n_layers=10, kernel=(5, 7, 7), num_heads=8): 29 | super().__init__() 30 | 31 | self.layers = nn.ModuleList( 32 | [ 33 | NeighborhoodAttention3D( 34 | embed_dim=latent_dim, 35 | num_heads=num_heads, 36 | kernel_size=kernel, 37 | ) 38 | for _ in range(n_layers) 39 | ] 40 | ) 41 | 42 | def forward(self, x): 43 | for layer in self.layers: 44 | x = layer(x) 45 | return x 46 | -------------------------------------------------------------------------------- /graph_weather/models/cafa/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class CaFAEncoder(nn.Module): 6 | """ 7 | Encoder for CaFA 8 | This projects complex, high-resolution input weather state 9 | and transform it into a lower-resolution, high-dimensional 10 | latent representation that the processor can work with 11 | """ 12 | 13 | def __init__(self, input_channels: int, model_dim: 
int, downsampling_factor: int = 1): 14 | """ 15 | Args: 16 | input_channel: No. of channels/features in raw input data 17 | model_dim: Dimensions of the model's hidden layers (output channels) 18 | downsampling_factor: Factor to downsample the spatial dimensions by (i.e., 2 means H/2, W/2) 19 | """ 20 | super().__init__() 21 | self.encoder = nn.Conv2d( 22 | in_channels=input_channels, 23 | out_channels=model_dim, 24 | kernel_size=downsampling_factor, 25 | stride=downsampling_factor, 26 | ) 27 | 28 | def forward(self, x: torch.Tensor) -> torch.Tensor: 29 | """ 30 | Args: 31 | x: Input tensor of shape (batch, channels, height, width) 32 | 33 | Returns: 34 | Encoded tensor of shape (batch, model_dim, height/downsampling_factor, width/downsampling_factor) 35 | """ 36 | x = self.encoder(x) 37 | return x 38 | -------------------------------------------------------------------------------- /graph_weather/models/cafa/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class CaFADecoder(nn.Module): 6 | """ 7 | Decoder for for CaFA 8 | After the Processor and FactorizedTransformer generated a prediction 9 | in the latent space, the decoder's role is to translate this abstract 10 | representation back into a physical prediction 11 | """ 12 | 13 | def __init__(self, model_dim: int, output_channels: int, upsampling_factor: int = 1): 14 | """ 15 | Args: 16 | output_channels: No. of channels/features in output prediction 17 | model_dim: Dimensions of the model's hidden layers (output channels) 18 | upsampling_factor: Factor to upsample the spatial dimensions. 19 | Must match the downsampling factor in encoder. 
20 | """ 21 | super().__init__() 22 | self.decoder = nn.ConvTranspose2d( 23 | in_channels=model_dim, 24 | out_channels=output_channels, 25 | kernel_size=upsampling_factor, 26 | stride=upsampling_factor, 27 | ) 28 | 29 | def forward(self, x: torch.Tensor) -> torch.Tensor: 30 | """ 31 | Args: 32 | x: Input tensor of shape (batch, model_dim, height, width). 33 | 34 | Returns: 35 | Output tensor of shape (batch, output_channels, height*factor, width*factor) 36 | """ 37 | x = self.decoder(x) 38 | return x 39 | -------------------------------------------------------------------------------- /.github/workflows/workflows.yaml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: 4 | push: 5 | pull_request: 6 | types: [opened, reopened] 7 | jobs: 8 | pytest: 9 | runs-on: ${{ matrix.os }} 10 | 11 | strategy: 12 | fail-fast: false 13 | matrix: 14 | os: [ubuntu-latest, macos-latest] 15 | python-version: ["3.12"] 16 | torch-version: [2.7.0] 17 | environment: ["default"] 18 | include: 19 | - torch-version: 2.7.0 20 | torchvision-version: 0.21.0 21 | steps: 22 | - uses: actions/checkout@v4 23 | - uses: prefix-dev/setup-pixi@v0.8.10 24 | with: 25 | environments: ${{ matrix.environment }} 26 | - name: Install dependencies 27 | run: | 28 | pixi run -e ${{ matrix.environment }} installpyg 29 | pixi run -e ${{ matrix.environment }} pip install coverage==7.4.3 pytest-cov 30 | pixi run -e ${{ matrix.environment }} installnat 31 | pixi run -e ${{ matrix.environment }} install 32 | - name: Setup with pytest-cov 33 | run: | 34 | # make PYTESTCOV 35 | export PYTESTCOV="--cov=graph_weather tests/ --cov-report=xml" 36 | # echo results and save env var for other jobs 37 | echo "pytest-cov options that will be used are: $PYTESTCOV" 38 | echo "PYTESTCOV=$PYTESTCOV" >> $GITHUB_ENV 39 | - name: Run tests 40 | run: | 41 | export PYTEST_COMMAND="pytest $PYTESTCOV $PYTESTXDIST -s" 42 | echo "Will be running this command: 
$PYTEST_COMMAND" 43 | pixi run -e ${{ matrix.environment }} $PYTEST_COMMAND 44 | - name: "Upload coverage to Codecov" 45 | uses: codecov/codecov-action@v2 46 | with: 47 | fail_ci_if_error: false 48 | -------------------------------------------------------------------------------- /graph_weather/models/cafa/processor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | 5 | from .factorize import FactorizedTransformerBlock 6 | 7 | 8 | class CaFAProcessor(nn.Module): 9 | """ 10 | Processor module for CaFA 11 | Handles latent feature map through multiple layers of self-attention, 12 | allowing information to propagate across the entire global grid. 13 | """ 14 | 15 | def __init__( 16 | self, 17 | dim: int, 18 | depth: int, 19 | heads: int, 20 | dim_head: int = 64, 21 | feedforward_multiplier: int = 4, 22 | dropout: float = 0.0, 23 | ): 24 | """ 25 | Args: 26 | dim: No. of input channels/ features 27 | depth: No. of FactorizedTransformerBlocks to stack 28 | heads: No. 
of attention heads in each block 29 | dim_head: Dimension of each attention head 30 | feedforward_multiplier: Multiplier for the feedforward network dimension 31 | dropout: Dropout rate 32 | """ 33 | super().__init__() 34 | self.blocks = nn.ModuleList( 35 | [ 36 | FactorizedTransformerBlock(dim, heads, dim_head, feedforward_multiplier, dropout) 37 | for _ in range(depth) 38 | ] 39 | ) 40 | 41 | def forward(self, x: torch.Tensor) -> torch.Tensor: 42 | """ 43 | Args: 44 | x: Input tensor of shape (batch, height, width, channels) 45 | 46 | Returns: 47 | Refined tensor of same shape 48 | """ 49 | x = rearrange(x, "b c h w -> b h w c") 50 | for block in self.blocks: 51 | x = block(x) 52 | x = rearrange(x, "b h w c -> b c h w") 53 | return x 54 | -------------------------------------------------------------------------------- /graph_weather/models/aurora/decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3D Decoder: 3 | - Takes processed latent representations and reconstructs output. 4 | - Uses transposed convolution to upscale back to spatial-temporal format. 5 | """ 6 | 7 | import torch.nn as nn 8 | 9 | 10 | class Decoder3D(nn.Module): 11 | """ 12 | 3D Decoder: 13 | - Takes processed latent representations and reconstructs the spatial-temporal output. 14 | - Uses transposed convolutions to upscale latent features to the original format. 15 | """ 16 | 17 | def __init__(self, output_channels=1, embed_dim=96, target_shape=(32, 32, 32)): 18 | """ 19 | Args: 20 | output_channels (int): Number of channels in the output tensor (e.g., 1 for grayscale). 21 | embed_dim (int): Dimension of the latent features (matches the encoder's output). 22 | target_shape (tuple): The desired shape of the reconstructed 3D tensor (D, H, W). 
23 | """ 24 | super().__init__() 25 | self.embed_dim = embed_dim 26 | self.target_shape = target_shape 27 | self.deconv1 = nn.ConvTranspose3d( 28 | embed_dim, output_channels, kernel_size=3, padding=1, stride=1 29 | ) 30 | 31 | def forward(self, x): 32 | """ 33 | Forward pass for the decoder. 34 | 35 | Args: 36 | x (torch.Tensor): Input latent representation, shape (batch, seq_len, embed_dim). 37 | 38 | Returns: 39 | torch.Tensor: Reconstructed 3D tensor, shape (batch, output_channels, *target_shape). 40 | """ 41 | batch_size = x.shape[0] 42 | depth, height, width = self.target_shape 43 | # Reshape latent features into 3D tensor 44 | x = x.view(batch_size, self.embed_dim, depth, height, width) 45 | # Transposed convolution to upscale to the final shape 46 | x = self.deconv1(x) 47 | return x 48 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:latest 2 | 3 | ENV CONDA_ENV_NAME=graph 4 | ENV PYTHON_VERSION=3.12 5 | 6 | # Basic setup 7 | RUN apt update && apt install -y bash \ 8 | build-essential \ 9 | git \ 10 | curl \ 11 | ca-certificates \ 12 | wget \ 13 | libaio-dev \ 14 | && rm -rf /var/lib/apt/lists 15 | 16 | # Install Miniconda and create main env 17 | ADD https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh miniconda3.sh 18 | RUN /bin/bash miniconda3.sh -b -p /conda \ 19 | && echo export PATH=/conda/bin:$PATH >> .bashrc \ 20 | && rm miniconda3.sh 21 | ENV PATH="/conda/bin:${PATH}" 22 | 23 | RUN git clone https://github.com/openclimatefix/graph_weather.git && mv graph_weather/ gw/ && cd gw/ && mv * .. 
&& rm -rf gw/ 24 | 25 | # Copy the appropriate environment file based on CUDA availability 26 | COPY environment_cpu.yml /tmp/environment_cpu.yml 27 | COPY environment_cuda.yml /tmp/environment_cuda.yml 28 | 29 | RUN conda update -n base -c defaults conda 30 | 31 | # Check if CUDA is available and accordingly choose env 32 | RUN cuda=$(command -v nvcc > /dev/null && echo "true" || echo "false") \ 33 | && if [ "$cuda" == "true" ]; then conda env create -f /tmp/environment_cuda.yml; else conda env create -f /tmp/environment_cpu.yml; fi 34 | 35 | # Switch to bash shell 36 | SHELL ["/bin/bash", "-c"] 37 | 38 | # Set ${CONDA_ENV_NAME} to default virtual environment 39 | RUN echo "source activate ${CONDA_ENV_NAME}" >> ~/.bashrc 40 | 41 | # Cp in the development directory and install 42 | RUN source activate ${CONDA_ENV_NAME} && pip install -e . 43 | 44 | 45 | # Make RUN commands use the new environment: 46 | SHELL ["conda", "run", "-n", "graph", "/bin/bash", "-c"] 47 | 48 | # Example command that can be used, need to set API_KEY, API_SECRET and SAVE_DIR 49 | CMD ["conda", "run", "-n", "graph", "python", "-u", "train/pl_graph_weather.py", "--gpus", "16", "--hidden", "64", "--num-blocks", "3", "--batch", "16"] 50 | -------------------------------------------------------------------------------- /tests/test_gencast_with_thermalizer.py: -------------------------------------------------------------------------------- 1 | """Integration tests for GraphWeatherForecaster using the ThermalizerLayer.""" 2 | 3 | import torch 4 | 5 | from graph_weather.models.forecast import GraphWeatherForecaster 6 | from graph_weather.models.layers.thermalizer import ThermalizerLayer 7 | 8 | 9 | def test_gencast_thermal_integration(): 10 | """End-to-end test: GraphWeatherForecaster with ThermalizerLayer on a 3x3 grid.""" 11 | lat_lons = [(i // 3, i % 3) for i in range(9)] # 3x3 grid 12 | 13 | model = GraphWeatherForecaster( 14 | lat_lons, 15 | use_thermalizer=True, 16 | feature_dim=3, 17 | 
aux_dim=0, 18 | node_dim=256, 19 | num_blocks=1, 20 | ) 21 | 22 | features = torch.randn(1, len(lat_lons), 3) 23 | t = torch.randint(0, 1000, (1,)).item() 24 | pred = model(features, t=t) 25 | 26 | assert not torch.isnan(pred).any() 27 | assert torch.isfinite(pred).all() 28 | 29 | if pred.dim() == 4: 30 | assert pred.shape[0] == features.shape[0] 31 | assert pred.shape[2] * pred.shape[3] == features.shape[1] 32 | else: 33 | assert pred.shape == features.shape 34 | 35 | 36 | def test_thermalizer_small_grids(): 37 | """Test ThermalizerLayer on various small grid sizes.""" 38 | layer = ThermalizerLayer(input_dim=256) 39 | t = torch.randint(0, 1000, (1,)).item() 40 | 41 | for nodes in [4, 9, 64]: # 2x2, 3x3, 8x8 42 | x = torch.randn(nodes, 256) 43 | out = layer(x, t) 44 | assert out.shape == x.shape 45 | assert not torch.isnan(out).any() 46 | 47 | 48 | def test_small_grid_integration(): 49 | """Test GraphWeatherForecaster + ThermalizerLayer on a 2x2 grid.""" 50 | lat_lons = [(i // 2, i % 2) for i in range(4)] # 2x2 grid 51 | 52 | model = GraphWeatherForecaster( 53 | lat_lons, 54 | use_thermalizer=True, 55 | feature_dim=3, 56 | aux_dim=0, 57 | node_dim=256, 58 | num_blocks=1, 59 | ) 60 | 61 | features = torch.randn(1, len(lat_lons), 3) 62 | pred = model(features, t=50) 63 | 64 | assert not torch.isnan(pred).any() 65 | assert torch.isfinite(pred).all() 66 | 67 | 68 | def test_additional_thermalizer(): 69 | """Basic sanity test for ThermalizerLayer with small input.""" 70 | layer = ThermalizerLayer(input_dim=256) 71 | x = torch.randn(4, 256) 72 | t = torch.randint(0, 1000, (1,)).item() 73 | out = layer(x, t) 74 | 75 | assert out.shape == x.shape 76 | assert not torch.isnan(out).any() 77 | -------------------------------------------------------------------------------- /tests/test_weathermesh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from graph_weather.models.weathermesh.decoder import WeatherMeshDecoder, 
WeatherMeshDecoderConfig 4 | from graph_weather.models.weathermesh.encoder import WeatherMeshEncoder, WeatherMeshEncoderConfig 5 | from graph_weather.models.weathermesh.processor import ( 6 | WeatherMeshProcessor, 7 | WeatherMeshProcessorConfig, 8 | ) 9 | from graph_weather.models.weathermesh.weathermesh2 import WeatherMesh, WeatherMeshConfig 10 | 11 | 12 | def test_weathermesh_encoder(): 13 | encoder = WeatherMeshEncoder( 14 | input_channels_2d=2, 15 | input_channels_3d=1, 16 | latent_dim=8, 17 | n_pressure_levels=25, 18 | kernel_size=(3, 3, 3), 19 | num_heads=1, 20 | hidden_dim=16, 21 | num_conv_blocks=3, 22 | num_transformer_layers=3, 23 | ) 24 | x_2d = torch.randn(1, 2, 32, 64) 25 | x_3d = torch.randn(1, 1, 25, 32, 64) 26 | out = encoder(x_2d, x_3d) 27 | assert out.shape == (1, 5, 4, 8, 8) 28 | 29 | 30 | def test_weathermesh_processor(): 31 | processor = WeatherMeshProcessor(latent_dim=8, n_layers=2, num_heads=1) 32 | x = torch.randn(1, 26, 32, 64, 8) 33 | out = processor(x) 34 | assert out.shape == (1, 26, 32, 64, 8) 35 | 36 | 37 | def test_weathermesh_decoder(): 38 | decoder = WeatherMeshDecoder( 39 | latent_dim=8, 40 | output_channels_2d=8, 41 | output_channels_3d=4, 42 | kernel_size=(3, 3, 3), 43 | num_heads=2, 44 | hidden_dim=8, 45 | num_transformer_layers=1, 46 | ) 47 | x = torch.randn(1, 6, 32, 64, 8) 48 | out = decoder(x) 49 | assert out[0].shape == (1, 8, 256, 512) 50 | assert out[1].shape == (1, 4, 5, 256, 512) 51 | 52 | 53 | def test_weathermesh(): 54 | model = WeatherMesh( 55 | encoder=None, 56 | processors=None, 57 | decoder=None, 58 | timesteps=[1, 6], 59 | surface_channels=8, 60 | pressure_channels=4, 61 | pressure_levels=5, 62 | latent_dim=4, 63 | encoder_num_conv_blocks=1, 64 | encoder_num_transformer_layers=1, 65 | encoder_hidden_dim=4, 66 | decoder_num_conv_blocks=1, 67 | decoder_num_transformer_layers=1, 68 | decoder_hidden_dim=4, 69 | processor_num_layers=2, 70 | kernel=(3, 5, 5), 71 | num_heads=1, 72 | ) 73 | 74 | x_2d = torch.randn(1, 8, 
32, 64) 75 | x_3d = torch.randn(1, 4, 5, 32, 64) 76 | out = model(x_2d, x_3d, forecast_steps=1) 77 | assert out.surface.shape == (1, 8, 32, 64) 78 | assert out.pressure.shape == (1, 4, 5, 32, 64) 79 | -------------------------------------------------------------------------------- /graph_weather/models/aurora/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Aurora: A Foundation Model for Earth System Science 3 | - Combines 3D Swin Transformer encoding 4 | - Perceiver processing for efficient computation 5 | - 3D decoding for spatial-temporal predictions 6 | """ 7 | 8 | from .decoder import Decoder3D 9 | from .encoder import Swin3DEncoder 10 | from .model import AuroraModel, EarthSystemLoss 11 | from .processor import PerceiverProcessor 12 | 13 | __version__ = "0.1.0" 14 | 15 | __all__ = [ 16 | "AuroraModel", 17 | "EarthSystemLoss", 18 | "Swin3DEncoder", 19 | "Decoder3D", 20 | "PerceiverProcessor", 21 | ] 22 | 23 | # Default configurations for different model sizes 24 | MODEL_CONFIGS = { 25 | "tiny": { 26 | "in_channels": 1, 27 | "out_channels": 1, 28 | "embed_dim": 48, 29 | "latent_dim": 256, 30 | "spatial_shape": (16, 16, 16), 31 | "max_seq_len": 2048, 32 | }, 33 | "base": { 34 | "in_channels": 1, 35 | "out_channels": 1, 36 | "embed_dim": 96, 37 | "latent_dim": 512, 38 | "spatial_shape": (32, 32, 32), 39 | "max_seq_len": 4096, 40 | }, 41 | "large": { 42 | "in_channels": 1, 43 | "out_channels": 1, 44 | "embed_dim": 192, 45 | "latent_dim": 1024, 46 | "spatial_shape": (64, 64, 64), 47 | "max_seq_len": 8192, 48 | }, 49 | } 50 | 51 | 52 | def create_model(config="base", **kwargs): 53 | """ 54 | Create an Aurora model with specified configuration. 
def batch(senders, edge_index, edge_attr=None, batch_size=1):
    """Build big batched graph.

    Returns nodes and edges of a big graph with batch_size disconnected copies of the original
    graph, with features shape [(b n) f].

    Args:
        senders (torch.Tensor): nodes' features.
        edge_index (torch.Tensor): edge index tensor.
        edge_attr (torch.Tensor, optional): edge attributes tensor, if None returns None.
            Defaults to None.
        batch_size (int): batch size. Values below 1 behave like 1. Defaults to 1.

    Returns:
        batched_senders, batched_edge_index, batched_edge_attr
    """
    if batch_size <= 1:
        # A single copy is the input graph itself; hand the tensors back untouched.
        return senders, edge_index, edge_attr

    ns = senders.shape[0]

    # Concatenate once per tensor. Growing the result inside a loop re-copied
    # everything accumulated so far on every iteration.
    batched_senders = torch.cat([senders] * batch_size, dim=0)

    # Copy number i owns its own block of node ids, so its indices are shifted by i * ns.
    batched_edge_index = torch.cat(
        [edge_index + i * ns for i in range(batch_size)], dim=1
    )

    batched_edge_attr = None
    if edge_attr is not None:
        batched_edge_attr = torch.cat([edge_attr] * batch_size, dim=0)

    return batched_senders, batched_edge_index, batched_edge_attr
class FiLMGenerator(nn.Module):
    """
    Generates FiLM parameters (gamma and beta) from a lead-time index.

    A one-hot vector for the given lead time is expanded to the batch size
    and passed through a small MLP to produce FiLM modulation parameters.

    Args:
        num_lead_times (int): Number of possible lead-time categories.
        hidden_dim (int): Hidden size for the internal MLP.
        feature_dim (int): Output dimensionality of gamma and beta.
    """

    def __init__(self, num_lead_times: int, hidden_dim: int, feature_dim: int):
        super().__init__()
        self.num_lead_times = num_lead_times
        self.feature_dim = feature_dim
        self.network = nn.Sequential(
            nn.Linear(num_lead_times, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * feature_dim),
        )

    def forward(self, batch_size: int, lead_time: int, device=None):
        """
        Compute FiLM gamma and beta parameters.

        Args:
            batch_size (int): Number of samples to generate parameters for.
            lead_time (int): Lead-time index used to construct the one-hot input.
            device (optional): Device to place the one-hot input on. Defaults to
                the device of this module's own weights.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                gamma: Tensor of shape (batch_size, feature_dim).
                beta: Tensor of shape (batch_size, feature_dim).

        Raises:
            IndexError: If an integer `lead_time` does not index a valid category.
        """
        # Only plain integers are range-checked here; any other index type is left
        # to the tensor indexing below, exactly as before.
        if isinstance(lead_time, int) and not (
            -self.num_lead_times <= lead_time < self.num_lead_times
        ):
            raise IndexError(
                f"lead_time {lead_time} is out of range for {self.num_lead_times} lead times"
            )

        # Build the input where the first linear layer lives and in its dtype, so a
        # generator moved with .to()/.cuda()/.double() works without extra arguments.
        reference = self.network[0].weight
        if device is None:
            device = reference.device

        one_hot = torch.zeros(
            batch_size, self.num_lead_times, device=device, dtype=reference.dtype
        )
        one_hot[:, lead_time] = 1.0

        gamma_beta = self.network(one_hot)
        # The MLP emits 2 * feature_dim values per sample: gamma first, then beta.
        gamma = gamma_beta[:, : self.feature_dim]
        beta = gamma_beta[:, self.feature_dim :]
        return gamma, beta
70 | """ 71 | 72 | while gamma.ndim < x.ndim: 73 | gamma = gamma.unsqueeze(-1) 74 | beta = beta.unsqueeze(-1) 75 | return x * gamma + beta 76 | -------------------------------------------------------------------------------- /tests/test_cafa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from graph_weather.models.cafa.encoder import CaFAEncoder 5 | from graph_weather.models.cafa.processor import CaFAProcessor 6 | from graph_weather.models.cafa.decoder import CaFADecoder 7 | from graph_weather.models.cafa.model import CaFAForecaster 8 | 9 | # common params for test 10 | BATCH_SIZE = 2 11 | HEIGHT = 32 12 | WIDTH = 64 13 | MODEL_DIM = 128 14 | INPUT_CHANNELS = 3 15 | OUTPUT_CHANNELS = 3 16 | HEADS = 4 17 | DEPTH = 2 18 | DOWNSAMPLING = 2 19 | 20 | 21 | def test_encoder(): 22 | """Tests the CaFAEncoder for correct shape transformation.""" 23 | x = torch.randn(BATCH_SIZE, INPUT_CHANNELS, HEIGHT, WIDTH) 24 | encoder = CaFAEncoder( 25 | input_channels=INPUT_CHANNELS, model_dim=MODEL_DIM, downsampling_factor=DOWNSAMPLING 26 | ) 27 | output = encoder(x) 28 | 29 | assert output.shape == (BATCH_SIZE, MODEL_DIM, HEIGHT // DOWNSAMPLING, WIDTH // DOWNSAMPLING) 30 | 31 | 32 | def test_decoder(): 33 | """Tests the CaFADecoder for correct shape transformation.""" 34 | x = torch.randn(BATCH_SIZE, MODEL_DIM, HEIGHT // DOWNSAMPLING, WIDTH // DOWNSAMPLING) 35 | decoder = CaFADecoder( 36 | model_dim=MODEL_DIM, output_channels=OUTPUT_CHANNELS, upsampling_factor=DOWNSAMPLING 37 | ) 38 | output = decoder(x) 39 | 40 | assert output.shape == (BATCH_SIZE, OUTPUT_CHANNELS, HEIGHT, WIDTH) 41 | 42 | 43 | def test_processor(): 44 | """Tests the CaFAProcessor to ensure it preserves shape.""" 45 | x = torch.randn(BATCH_SIZE, MODEL_DIM, HEIGHT, WIDTH) 46 | processor = CaFAProcessor(dim=MODEL_DIM, depth=DEPTH, heads=HEADS) 47 | output = processor(x) 48 | 49 | assert output.shape == x.shape 50 | 51 | 52 | def 
test_cafa_forecaster_end_to_end(): 53 | """Tests the full CaFAForecaster model to ensure it works end-to-end.""" 54 | x = torch.randn(BATCH_SIZE, INPUT_CHANNELS, HEIGHT, WIDTH) 55 | model = CaFAForecaster( 56 | input_channels=INPUT_CHANNELS, 57 | output_channels=OUTPUT_CHANNELS, 58 | model_dim=MODEL_DIM, 59 | downsampling_factor=DOWNSAMPLING, 60 | processor_depth=DEPTH, 61 | num_heads=HEADS, 62 | ) 63 | output = model(x) 64 | 65 | assert output.shape == (BATCH_SIZE, OUTPUT_CHANNELS, HEIGHT, WIDTH) 66 | 67 | 68 | def test_cafa_forecaster_odd_dimensions(): 69 | """Tests that the model's internal padding handles odd-sized inputs correctly.""" 70 | 71 | # Use odd dimensions for height and width 72 | x = torch.randn(BATCH_SIZE, INPUT_CHANNELS, HEIGHT + 1, WIDTH + 1) 73 | model = CaFAForecaster( 74 | input_channels=INPUT_CHANNELS, 75 | output_channels=OUTPUT_CHANNELS, 76 | model_dim=MODEL_DIM, 77 | downsampling_factor=DOWNSAMPLING, 78 | processor_depth=DEPTH, 79 | num_heads=HEADS, 80 | ) 81 | output = model(x) 82 | 83 | # The model should return a tensor with the original odd-sized dimensions 84 | assert output.shape == x.shape 85 | -------------------------------------------------------------------------------- /tests/test_thermalizer.py: -------------------------------------------------------------------------------- 1 | """Unit tests for the ThermalizerLayer module.""" 2 | 3 | import torch 4 | 5 | from graph_weather.models.layers.thermalizer import ThermalizerLayer 6 | 7 | 8 | def test_thermalizer_forward_shape(): 9 | """Test forward pass shape with explicit height, width, and batch.""" 10 | batch_size = 2 11 | height, width = 12, 12 12 | nodes = height * width 13 | features = 256 14 | 15 | x = torch.randn(batch_size * nodes, features) 16 | layer = ThermalizerLayer(input_dim=features) 17 | t = torch.randint(0, 1000, (1,)).item() 18 | 19 | out = layer(x, t, height=height, width=width, batch=batch_size) 20 | 21 | assert out.shape == x.shape 22 | assert not 
torch.isnan(out).any() 23 | assert torch.isfinite(out).all() 24 | 25 | 26 | def test_thermalizer_auto_inference(): 27 | """Test forward pass with auto-inferred dimensions (no height, width, batch).""" 28 | batch_size = 1 29 | height, width = 12, 12 30 | nodes = height * width 31 | features = 256 32 | 33 | x = torch.randn(batch_size * nodes, features) 34 | layer = ThermalizerLayer(input_dim=features) 35 | t = torch.randint(0, 1000, (1,)).item() 36 | 37 | out = layer(x, t) 38 | 39 | assert out.shape == x.shape 40 | assert not torch.isnan(out).any() 41 | assert torch.isfinite(out).all() 42 | 43 | 44 | def test_thermalizer_different_sizes(): 45 | """Test multiple input grid sizes and batch counts.""" 46 | test_cases = [ 47 | (1, 4, 2, 2), 48 | (1, 9, 3, 3), 49 | (1, 16, 4, 4), 50 | (2, 25, 5, 5), 51 | (3, 64, 8, 8), 52 | ] 53 | 54 | features = 256 55 | layer = ThermalizerLayer(input_dim=features) 56 | 57 | for batch_size, nodes, height, width in test_cases: 58 | x = torch.randn(batch_size * nodes, features) 59 | t = torch.randint(0, 1000, (1,)).item() 60 | 61 | out = layer(x, t, height=height, width=width, batch=batch_size) 62 | assert out.shape == x.shape 63 | assert not torch.isnan(out).any() 64 | 65 | if batch_size == 1: 66 | out_auto = layer(x, t) 67 | assert out_auto.shape == x.shape 68 | assert not torch.isnan(out_auto).any() 69 | 70 | 71 | def test_grid_reconstruction(): 72 | """Test reshaping from flat to grid format after forward pass.""" 73 | batch_size = 1 74 | height, width = 6, 8 75 | nodes = height * width 76 | features = 256 77 | 78 | x_grid = torch.randn(batch_size, features, height, width) 79 | x_flat = x_grid.permute(0, 2, 3, 1).reshape(batch_size * nodes, features) 80 | 81 | layer = ThermalizerLayer(input_dim=features) 82 | t = 100 83 | 84 | out_flat = layer(x_flat, t, height=height, width=width, batch=batch_size) 85 | out_grid = out_flat.reshape(batch_size, height, width, features).permute(0, 3, 1, 2) 86 | 87 | assert out_flat.shape == x_flat.shape 
88 | assert out_grid.shape == x_grid.shape 89 | assert not torch.isnan(out_grid).any() 90 | -------------------------------------------------------------------------------- /graph_weather/models/aurora/encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Swin 3D Transformer Encoder: 3 | - Uses a 3D convolution for initial feature extraction. 4 | - Applies layer normalization and reshapes data. 5 | - Uses a transformer-based encoder to learn spatial-temporal features. 6 | """ 7 | 8 | import torch.nn as nn 9 | from einops import rearrange 10 | from einops.layers.torch import Rearrange 11 | 12 | 13 | class Swin3DEncoder(nn.Module): 14 | def __init__(self, in_channels=1, embed_dim=96): 15 | super().__init__() 16 | self.conv1 = nn.Conv3d(in_channels, embed_dim, kernel_size=3, padding=1, stride=1) 17 | self.norm = nn.LayerNorm(embed_dim) 18 | self.swin_transformer = nn.Transformer( 19 | d_model=embed_dim, 20 | nhead=8, 21 | num_encoder_layers=4, 22 | num_decoder_layers=4, 23 | dim_feedforward=embed_dim * 4, 24 | ) 25 | self.embed_dim = embed_dim 26 | 27 | # Define rearrangement patterns using einops 28 | self.to_transformer_format = Rearrange("b d h w c -> (d h w) b c") 29 | self.from_transformer_format = Rearrange("(d h w) b c -> b d h w c", d=None, h=None, w=None) 30 | 31 | # To use rearrange function directly instead of the Rearrange layer 32 | def forward(self, x): 33 | # 3D convolution with einops rearrangement 34 | x = self.conv1(x) 35 | 36 | # Rearrange for normalization using einops 37 | x = rearrange(x, "b c d h w -> b d h w c") 38 | x = self.norm(x) 39 | 40 | # Store spatial dimensions for later reconstruction 41 | d, h, w = x.shape[1:4] 42 | 43 | # Transform to sequence format for transformer 44 | x = rearrange(x, "b d h w c -> (d h w) b c") 45 | x = self.swin_transformer.encoder(x) 46 | 47 | # Restore original spatial structure 48 | x = rearrange(x, "(d h w) b c -> b (d h w) c", d=d, h=h, w=w) 49 | 50 | # 
class IFSAnalisysDataset(Dataset):
    """
    Dataset for IFSAnalysis.

    Each item is a pair built from two consecutive time steps of the opened
    dataset: the normalised features at step ``idx`` and at step ``idx + 1``.

    Args:
        filepath: path of the dataset.
        features: list of features.
        start_year: initial year. Defaults to 2016.
        end_year: ending year. Defaults to 2022.
    """

    def __init__(self, filepath: str, features: list, start_year: int = 2016, end_year: int = 2022):
        """
        Initialize the dataset object.
        """

        super().__init__()
        assert (
            start_year <= end_year
        ), f"start_year ({start_year}) cannot be greater than end_year ({end_year})."
        assert start_year >= 2016 and start_year <= 2022, "Time data range from 2016 to 2022"
        assert end_year >= 2016 and end_year <= 2022, "Time data range from 2016 to 2022"
        self.data = xr.open_zarr(filepath)
        self.data = self.data.sel(
            time=slice(str(start_year), str(end_year))
        )  # Filter data by start and end years

        self.NWP_features = features

    def __len__(self):
        """Return the number of (step, next step) pairs available."""
        # Item idx reads time steps idx and idx + 1, so the last time step can only
        # ever be a target, never the start of a pair.
        return max(len(self.data["time"]) - 1, 0)

    def __getitem__(self, idx):
        """
        Return the normalised features at time step ``idx`` and at ``idx + 1``.

        Raises:
            IndexError: If ``idx`` does not address a valid pair.
        """
        num_pairs = len(self)
        if idx < 0:
            # Standard negative indexing, counted from the last valid pair.
            idx += num_pairs
        if not 0 <= idx < num_pairs:
            raise IndexError(f"index {idx} is out of range for {num_pairs} samples")

        start = self.data.isel(time=idx)
        end = self.data.isel(time=idx + 1)

        # Extract NWP features
        input_data = self._nwp_features_extraction(start)
        output_data = self._nwp_features_extraction(end)

        return (
            transforms.ToTensor()(input_data).view(-1, input_data.shape[-1]),
            transforms.ToTensor()(output_data).view(-1, output_data.shape[-1]),
        )

    def _nwp_features_extraction(self, data):
        """Stack the configured variables of one time step into a normalised float32 cube."""
        data_cube = np.stack(
            [
                (data[f"{var}"].values - IFS_MEAN[f"{var}"]) / (IFS_STD[f"{var}"] + 1e-6)
                for var in self.NWP_features
            ],
            axis=-1,
        ).astype(np.float32)

        num_layers, num_lat, num_lon, num_vars = data_cube.shape
        # NOTE(review): this reshapes (layers, lat, lon, vars) straight to
        # (lat, lon, vars * layers) without moving the layer axis next to the
        # variable axis first, so values may not stay attached to their grid
        # point. Confirm the intended layout before relying on it.
        data_cube = data_cube.reshape(num_lat, num_lon, num_vars * num_layers)

        assert not np.isnan(data_cube).any()
        return data_cube
class CaFAForecaster(nn.Module):
    """
    End-to-end CaFA (Climate-Aware Factorized Attention) forecaster.

    Wires a CaFAEncoder, a CaFAProcessor and a CaFADecoder into one module and
    takes care of padding inputs whose size is not a multiple of the
    downsampling factor.
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        model_dim: int = 256,
        downsampling_factor: int = 2,
        processor_depth: int = 6,
        num_heads: int = 8,
        dim_head: int = 64,
        feedforward_multiplier: int = 4,
        dropout: float = 0.0,
    ):
        """
        Args:
            input_channels: Number of channels/features in the input
            output_channels: Number of channels to predict
            model_dim: Internal dimension shared by encoder, processor and decoder
            downsampling_factor: Factor used to downsample in the encoder and
                upsample again in the decoder
            processor_depth: Number of transformer blocks in the processor
            num_heads: Number of attention heads per block
            dim_head: Dimension of each attention head
            feedforward_multiplier: Multiplier for the feedforward inner dimension
            dropout: Dropout rate
        """
        super().__init__()

        # Kept on the module because forward() needs it for padding.
        self.downsampling_factor = downsampling_factor

        encoder_kwargs = dict(
            input_channels=input_channels,
            model_dim=model_dim,
            downsampling_factor=downsampling_factor,
        )
        self.encoder = CaFAEncoder(**encoder_kwargs)

        processor_kwargs = dict(
            dim=model_dim,
            depth=processor_depth,
            heads=num_heads,
            dim_head=dim_head,
            feedforward_multiplier=feedforward_multiplier,
            dropout=dropout,
        )
        self.processor = CaFAProcessor(**processor_kwargs)

        decoder_kwargs = dict(
            model_dim=model_dim,
            output_channels=output_channels,
            upsampling_factor=downsampling_factor,
        )
        self.decoder = CaFADecoder(**decoder_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape (batch, input_channels, height, width)

        Returns:
            Output tensor of shape (batch, output_channels, height, width)
        """
        factor = self.downsampling_factor
        _batch, _channels, height, width = x.shape

        # Rows/columns needed to reach the next multiple of the factor
        # (zero when the size already divides evenly).
        extra_rows = -height % factor
        extra_cols = -width % factor
        padded = extra_rows > 0 or extra_cols > 0

        if padded:
            # F.pad takes (left, right, top, bottom) for the last two dimensions.
            x = F.pad(x, (0, extra_cols, 0, extra_rows))

        x = self.decoder(self.processor(self.encoder(x)))

        if not padded:
            return x
        # Drop the rows/columns that only exist because of the padding above.
        return x[:, :, :height, :width]
256-channel latent feature data on the icosahedron grid 6 | using 9 rounds of message-passing GNNs. During each round, a node exchanges information with itself 7 | and its immediate neighbors. There are residual connections between each round of processing. 8 | 9 | """ 10 | 11 | import torch 12 | 13 | from graph_weather.models.layers.graph_net_block import GraphProcessor 14 | from graph_weather.models.layers.thermalizer import ThermalizerLayer 15 | 16 | 17 | class Processor(torch.nn.Module): 18 | """Processor for latent graphD""" 19 | 20 | def __init__( 21 | self, 22 | input_dim: int = 256, 23 | edge_dim: int = 256, 24 | num_blocks: int = 9, 25 | hidden_dim_processor_node: int = 256, 26 | hidden_dim_processor_edge: int = 256, 27 | hidden_layers_processor_node: int = 2, 28 | hidden_layers_processor_edge: int = 2, 29 | mlp_norm_type: str = "LayerNorm", 30 | use_thermalizer: bool = False, 31 | ): 32 | """ 33 | Latent graph processor 34 | 35 | Args: 36 | input_dim: Input dimension for the node 37 | edge_dim: Edge input dimension 38 | num_blocks: Number of message passing blocks 39 | hidden_dim_processor_node: Hidden dimension of the node processors 40 | hidden_dim_processor_edge: Hidden dimension of the edge processors 41 | hidden_layers_processor_node: Number of hidden layers in the node processors 42 | hidden_layers_processor_edge: Number of hidden layers in the edge processors 43 | mlp_norm_type: Type of norm for the MLPs 44 | one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None 45 | use_thermalizer: Whether to use the thermalizer layer 46 | """ 47 | super().__init__() 48 | # Build the default graph 49 | # Take features from encoder and put into processor graph 50 | self.input_dim = input_dim 51 | self.use_thermalizer = use_thermalizer 52 | 53 | self.graph_processor = GraphProcessor( 54 | num_blocks, 55 | input_dim, 56 | edge_dim, 57 | hidden_dim_processor_node, 58 | hidden_dim_processor_edge, 59 | hidden_layers_processor_node, 60 | 
def compute_statistics(dataset, vars, num_samples=100, single=False):
    """Compute statistics for single timestep.

    Args:
        dataset: xarray dataset.
        vars: list of features.
        num_samples (int, optional): number of random time indices drawn (with
            replacement) for each feature. Defaults to 100.
        single (bool, optional): if True, reduce over axes (0, 1, 2) of the sampled
            values; otherwise reduce over axes (0, 2, 3) and keep axis 1.
            Defaults to False.

    Returns:
        means: dict with the means.
        stds: dict with the stds.
    """
    means = {}
    stds = {}
    axis_tuple = (0, 1, 2) if single else (0, 2, 3)
    for var in vars:
        print(f"Computing statistics for {var}")
        random_indexes = np.random.randint(0, len(dataset.time), num_samples)
        # Sample from the dataset that was passed in, not from a module-level global.
        samples = dataset.isel(time=random_indexes)[var].values
        samples = np.nan_to_num(samples)
        means[var] = samples.mean(axis=axis_tuple)
        stds[var] = samples.std(axis=axis_tuple)
    return means, stds


def compute_statistics_diff(dataset, vars, num_samples=100, single=False, timestep=2):
    """Compute statistics for difference of two timesteps.

    Args:
        dataset: xarray dataset.
        vars: list of features.
        num_samples (int, optional): number of random start indices drawn (with
            replacement) for each feature. Defaults to 100.
        single (bool, optional): if True, reduce over axes (0, 1, 2) of the sampled
            differences; otherwise reduce over axes (0, 2, 3) and keep axis 1.
            Defaults to False.
        timestep (int, optional): number of steps to consider between start and end.
            Defaults to 2.

    Returns:
        means: dict with the means.
        stds: dict with the stds.

    Raises:
        ValueError: if the dataset has too few time steps for the requested timestep.
    """
    num_times = len(dataset.time)
    # Restrict the start indices so that start + timestep is always a valid index.
    low = max(0, -timestep)
    high = num_times - max(0, timestep)
    if high <= low:
        raise ValueError(
            f"timestep {timestep} does not fit in a dataset with {num_times} time steps"
        )

    means = {}
    stds = {}
    axis_tuple = (0, 1, 2) if single else (0, 2, 3)
    for var in vars:
        print(f"Computing statistics for {var}")
        random_indexes = np.random.randint(low, high, num_samples)
        samples_start = dataset.isel(time=random_indexes)[var].values
        samples_start = np.nan_to_num(samples_start)
        samples_end = dataset.isel(time=random_indexes + timestep)[var].values
        samples_end = np.nan_to_num(samples_end)
        difference = samples_end - samples_start
        means[var] = difference.mean(axis=axis_tuple)
        stds[var] = difference.std(axis=axis_tuple)
    return means, stds
11 | "name": "Jacob Bieker", 12 | "avatar_url": "https://avatars.githubusercontent.com/u/7170359?v=4", 13 | "profile": "https://www.jacobbieker.com", 14 | "contributions": [ 15 | "code" 16 | ] 17 | }, 18 | { 19 | "login": "JackKelly", 20 | "name": "Jack Kelly", 21 | "avatar_url": "https://avatars.githubusercontent.com/u/460756?v=4", 22 | "profile": "http://jack-kelly.com", 23 | "contributions": [ 24 | "ideas" 25 | ] 26 | }, 27 | { 28 | "login": "byphilipp", 29 | "name": "byphilipp", 30 | "avatar_url": "https://avatars.githubusercontent.com/u/59995258?v=4", 31 | "profile": "https://github.com/byphilipp", 32 | "contributions": [ 33 | "ideas" 34 | ] 35 | }, 36 | { 37 | "login": "paapu88", 38 | "name": "Markus Kaukonen", 39 | "avatar_url": "https://avatars.githubusercontent.com/u/6195764?v=4", 40 | "profile": "http://iki.fi/markus.kaukonen", 41 | "contributions": [ 42 | "question" 43 | ] 44 | }, 45 | { 46 | "login": "MoHawastaken", 47 | "name": "MoHawastaken", 48 | "avatar_url": "https://avatars.githubusercontent.com/u/55447473?v=4", 49 | "profile": "https://github.com/MoHawastaken", 50 | "contributions": [ 51 | "bug" 52 | ] 53 | }, 54 | { 55 | "login": "mishooax", 56 | "name": "Mihai", 57 | "avatar_url": "https://avatars.githubusercontent.com/u/47196359?v=4", 58 | "profile": "http://www.ecmwf.int", 59 | "contributions": [ 60 | "question" 61 | ] 62 | }, 63 | { 64 | "login": "vitusbenson", 65 | "name": "Vitus Benson", 66 | "avatar_url": "https://avatars.githubusercontent.com/u/33334860?v=4", 67 | "profile": "https://github.com/vitusbenson", 68 | "contributions": [ 69 | "bug" 70 | ] 71 | }, 72 | { 73 | "login": "dongZheX", 74 | "name": "dongZheX", 75 | "avatar_url": "https://avatars.githubusercontent.com/u/36361726?v=4", 76 | "profile": "https://github.com/dongZheX", 77 | "contributions": [ 78 | "question" 79 | ] 80 | }, 81 | { 82 | "login": "sabbir2331", 83 | "name": "sabbir2331", 84 | "avatar_url": "https://avatars.githubusercontent.com/u/25061297?v=4", 85 | "profile": 
class LitModel(pl.LightningModule):
    """
    Lightning wrapper around ``GraphWeatherForecaster`` for this training script.

    Args:
        lat_lons: List of latitude and longitude coordinates.
        feature_dim: Dimension of the input features.
        aux_dim: Dimension of the auxiliary features.
    """

    def __init__(self, lat_lons, feature_dim, aux_dim):
        """
        Build the wrapped forecaster.

        Args:
            lat_lons: List of latitude and longitude coordinates.
            feature_dim: Dimension of the input features.
            aux_dim: Dimension of the auxiliary features.
        """
        super().__init__()
        forecaster_kwargs = {
            "lat_lons": lat_lons,
            "feature_dim": feature_dim,
            "aux_dim": aux_dim,
        }
        self.model = GraphWeatherForecaster(**forecaster_kwargs)

    def training_step(self, batch):
        """
        Run one optimisation step on a batch.

        Args:
            batch: An ``(inputs, targets)`` pair.

        Returns:
            The mean-squared-error loss between the prediction and the targets.
        """
        inputs, targets = batch
        # Both tensors are cast to half precision before they are used.
        prediction = self.forward(inputs.half())
        return torch.nn.MSELoss()(prediction, targets.half())

    def configure_optimizers(self):
        """
        Create the optimizer used during training.

        Returns:
            An AdamW optimizer over all parameters, with default hyperparameters.
        """
        trainable = self.parameters()
        return torch.optim.AdamW(trainable)

    def forward(self, x):
        """
        Delegate to the wrapped forecaster.

        Args:
            x (torch.Tensor): Input data.

        Returns:
            torch.Tensor: Output of the model.
        """
        prediction = self.model(x)
        return prediction
@dataclass
class WeatherMeshDecoderConfig:
    """Constructor arguments of ``WeatherMeshDecoder`` as a plain, serialisable record."""

    latent_dim: int
    output_channels_2d: int
    output_channels_3d: int
    n_conv_blocks: int
    hidden_dim: int
    kernel_size: tuple
    num_heads: int
    num_transformer_layers: int

    @staticmethod
    def from_json(json: dict) -> "WeatherMeshDecoderConfig":
        """Build a config from a dictionary, e.g. one produced by ``json.load``.

        Args:
            json: Mapping with one entry per dataclass field.

        Returns:
            The populated ``WeatherMeshDecoderConfig``.
        """
        # JSON has no tuple type, so ``kernel_size`` arrives as a list after a
        # round trip through a file; let dacite cast it back to ``tuple``.
        return dacite.from_dict(
            data_class=WeatherMeshDecoderConfig,
            data=json,
            config=dacite.Config(cast=[tuple]),
        )

    def to_json(self) -> dict:
        """Return the config as a plain dictionary keyed by field name."""
        # dacite only converts dict -> dataclass; the reverse direction lives
        # in the standard library.
        from dataclasses import asdict

        return asdict(self)


class WeatherMeshDecoder(nn.Module):
    """Decode a latent volume into surface (2D) and pressure-level (3D) fields.

    The latent is refined by neighbourhood-attention layers, projected with a
    1x1x1 convolution, split along the depth axis and sent through two stacks
    of ``ConvUpBlock`` modules (one 3D, one 2D).
    """

    def __init__(
        self,
        latent_dim,
        output_channels_2d,
        output_channels_3d,
        n_conv_blocks=3,
        hidden_dim=256,
        kernel_size: tuple = (5, 7, 7),
        num_heads: int = 8,
        num_transformer_layers: int = 3,
    ):
        """Create the decoder.

        Args:
            latent_dim: Channel size of the incoming latent.
            output_channels_2d: Channels produced by the surface path.
            output_channels_3d: Channels produced by the pressure path.
            n_conv_blocks: Number of ``ConvUpBlock`` modules per path.
            hidden_dim: Base channel width; block ``i`` works at
                ``hidden_dim * 2**i`` channels.
            kernel_size: Neighbourhood size of each attention layer.
            num_heads: Attention heads per attention layer.
            num_transformer_layers: Number of attention layers.
        """
        super().__init__()

        # Transformer layers for initial decoding
        self.transformer_layers = nn.ModuleList(
            [
                NeighborhoodAttention3D(
                    embed_dim=latent_dim, num_heads=num_heads, kernel_size=kernel_size
                )
                for _ in range(num_transformer_layers)
            ]
        )

        # Split into pressure levels and surface paths
        self.split = nn.Conv3d(latent_dim, hidden_dim * (2**n_conv_blocks), kernel_size=1)

        # Pressure levels (3D) path; the last block maps to the output channels.
        self.pressure_path = nn.ModuleList(
            [
                ConvUpBlock(
                    hidden_dim * (2 ** (i + 1)),
                    hidden_dim * (2**i) if i > 0 else output_channels_3d,
                    is_3d=True,
                )
                for i in reversed(range(n_conv_blocks))
            ]
        )

        # Surface (2D) path; the last block maps to the output channels.
        self.surface_path = nn.ModuleList(
            [
                ConvUpBlock(
                    hidden_dim * (2 ** (i + 1)),
                    hidden_dim * (2**i) if i > 0 else output_channels_2d,
                )
                for i in reversed(range(n_conv_blocks))
            ]
        )

    def forward(self, latent: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Decode a latent volume.

        Args:
            latent: Tensor laid out as (B, D, H, W, C): batch, depth (vertical
                levels), height, width, channels.

        Returns:
            ``(surface_features, pressure_features)``: the output of the 2D
            path and the output of the 3D path, in that order.
        """
        # Needs to be (B,D,H,W,C) with Batch, Depth (vertical levels), Height, Width, Channels
        # Apply transformer layers
        for transformer in self.transformer_layers:
            latent = transformer(latent)

        latent = einops.rearrange(latent, "B D H W C -> B C D H W")
        # Split features: the last depth slice is the surface, the rest are
        # the pressure levels.
        features = self.split(latent)
        pressure_features = features[:, :, :-1]
        surface_features = features[:, :, -1:]
        # Decode pressure levels
        for block in self.pressure_path:
            pressure_features = block(pressure_features)
        # Decode surface features (drop the singleton depth axis first)
        surface_features = surface_features.squeeze(2)
        for block in self.surface_path:
            surface_features = block(surface_features)

        return surface_features, pressure_features
@dataclass
class WeatherMeshEncoderConfig:
    """Constructor arguments of ``WeatherMeshEncoder`` as a plain, serialisable record."""

    input_channels_2d: int
    input_channels_3d: int
    latent_dim: int
    n_pressure_levels: int
    num_conv_blocks: int
    hidden_dim: int
    kernel_size: tuple
    num_heads: int
    num_transformer_layers: int

    @staticmethod
    def from_json(json: dict) -> "WeatherMeshEncoderConfig":
        """Build a config from a dictionary, e.g. one produced by ``json.load``.

        Args:
            json: Mapping with one entry per dataclass field.

        Returns:
            The populated ``WeatherMeshEncoderConfig``.
        """
        # JSON has no tuple type, so ``kernel_size`` arrives as a list after a
        # round trip through a file; let dacite cast it back to ``tuple``.
        return dacite.from_dict(
            data_class=WeatherMeshEncoderConfig,
            data=json,
            config=dacite.Config(cast=[tuple]),
        )

    def to_json(self) -> dict:
        """Return the config as a plain dictionary keyed by field name."""
        # dacite only converts dict -> dataclass; the reverse direction lives
        # in the standard library.
        from dataclasses import asdict

        return asdict(self)


class WeatherMeshEncoder(nn.Module):
    """Encode surface (2D) and pressure-level (3D) fields into one latent volume.

    Each input goes through its own stack of ``ConvDownBlock`` modules; the
    results are stacked along the depth axis, projected to ``latent_dim`` with
    a 1x1x1 convolution and refined by neighbourhood-attention layers.
    """

    def __init__(
        self,
        input_channels_2d: int,
        input_channels_3d: int,
        latent_dim: int,
        n_pressure_levels: int,
        num_conv_blocks: int = 3,
        hidden_dim: int = 256,
        kernel_size: tuple = (5, 7, 7),
        num_heads: int = 8,
        num_transformer_layers: int = 3,
    ):
        """Create the encoder.

        Args:
            input_channels_2d: Channels of the surface input.
            input_channels_3d: Channels of the pressure-level input.
            latent_dim: Channel size of the produced latent.
            n_pressure_levels: Number of pressure levels. Accepted for API
                symmetry; this constructor does not read it.
            num_conv_blocks: Number of ``ConvDownBlock`` modules per path.
            hidden_dim: Base channel width; block ``i`` outputs
                ``hidden_dim * 2**(i + 1)`` channels.
            kernel_size: Neighbourhood size of each attention layer.
            num_heads: Attention heads per attention layer.
            num_transformer_layers: Number of attention layers.
        """
        super().__init__()

        # Surface (2D) path
        self.surface_path = nn.ModuleList(
            [
                ConvDownBlock(
                    input_channels_2d if i == 0 else hidden_dim * (2**i),
                    hidden_dim * (2 ** (i + 1)),
                )
                for i in range(num_conv_blocks)
            ]
        )

        # Pressure levels (3D) path
        self.pressure_path = nn.ModuleList(
            [
                ConvDownBlock(
                    input_channels_3d if i == 0 else hidden_dim * (2**i),
                    hidden_dim * (2 ** (i + 1)),
                    stride=(1, 2, 2),  # Want to keep depth the same size
                    is_3d=True,
                )
                for i in range(num_conv_blocks)
            ]
        )

        # Transformer layers for final encoding
        self.transformer_layers = nn.ModuleList(
            [
                NeighborhoodAttention3D(
                    embed_dim=latent_dim, kernel_size=kernel_size, num_heads=num_heads
                )
                for _ in range(num_transformer_layers)
            ]
        )

        # Final projection to latent space
        self.to_latent = nn.Conv3d(hidden_dim * (2**num_conv_blocks), latent_dim, kernel_size=1)

    def forward(self, surface: torch.Tensor, pressure: torch.Tensor) -> torch.Tensor:
        """Encode one surface/pressure pair.

        Args:
            surface: Surface fields fed to the 2D path.
            pressure: Pressure-level fields fed to the 3D path.

        Returns:
            Latent tensor laid out as (B, D, H, W, C), where the last depth
            slice holds the surface features.
        """
        # Process surface data
        for block in self.surface_path:
            surface = block(surface)

        # Process pressure level data
        for block in self.pressure_path:
            pressure = block(pressure)
        # Combine features: the surface becomes one extra slice appended
        # after the pressure levels along the depth axis.
        features = torch.cat(
            [pressure, surface.unsqueeze(2)], dim=2
        )  # B C D H W currently, want it to be B D H W C

        # Transform to latent space
        latent = self.to_latent(features)

        # Move channels last, the layout the attention layers are called with
        latent = einops.rearrange(latent, "B C D H W -> B D H W C")
        # Apply transformer layers
        for transformer in self.transformer_layers:
            latent = transformer(latent)
        return latent
class Decoder(AssimilatorDecoder):
    """Decoder graph module.

    Thin specialisation of ``AssimilatorDecoder`` that adds the original input
    state back onto the decoded output (residual connection).
    """

    def __init__(
        self,
        lat_lons,
        resolution: int = 2,
        input_dim: int = 256,
        output_dim: int = 78,
        output_edge_dim: int = 256,
        hidden_dim_processor_node: int = 256,
        hidden_dim_processor_edge: int = 256,
        hidden_layers_processor_node: int = 2,
        hidden_layers_processor_edge: int = 2,
        mlp_norm_type: str = "LayerNorm",
        hidden_dim_decoder: int = 128,
        hidden_layers_decoder: int = 2,
        use_checkpointing: bool = False,
    ):
        """
        Decoder from latent graph to lat/lon graph

        Args:
            lat_lons: List of (lat,lon) points
            resolution: H3 resolution level
            input_dim: Input node dimension
            output_dim: Output node dimension
            output_edge_dim: Edge dimension
            hidden_dim_processor_node: Hidden dimension of the node processors
            hidden_dim_processor_edge: Hidden dimension of the edge processors
            hidden_layers_processor_node: Number of hidden layers in the node processors
            hidden_layers_processor_edge: Number of hidden layers in the edge processors
            mlp_norm_type: Type of norm for the MLPs
                one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
            hidden_dim_decoder: Number of hidden dimensions in the decoder
            hidden_layers_decoder: Number of layers in the decoder
            use_checkpointing: Whether to use gradient checkpointing or not
        """
        # Arguments are forwarded positionally, so this order must match the
        # parameter order of AssimilatorDecoder.__init__.
        super().__init__(
            lat_lons,
            resolution,
            input_dim,
            output_dim,
            output_edge_dim,
            hidden_dim_processor_node,
            hidden_dim_processor_edge,
            hidden_layers_processor_node,
            hidden_layers_processor_edge,
            mlp_norm_type,
            hidden_dim_decoder,
            hidden_layers_decoder,
            use_checkpointing,
        )

    def forward(
        self, processor_features: torch.Tensor, start_features: torch.Tensor, t: int = 0
    ) -> torch.Tensor:
        """
        Decode the processor output and add the original input state to it

        Args:
            processor_features: Processed features in shape [B*Nodes, Features]
            start_features: Original input features to the encoder, with shape [B, Nodes, Features]
            t: Timestep for the thermalizer, not used, but there for API consistency
        Returns:
            Updated features for model
        """
        # The parent decoder is given the batch size, read from start_features.
        out = super().forward(processor_features, start_features.shape[0])
        out = out + start_features  # residual connection
        return out
@dataclass
class ProcessorConfig:
    """Hyperparameters of ``PerceiverProcessor``.

    NOTE(review): ``PerceiverProcessor`` below only reads ``input_dim``,
    ``d_model``, ``latent_dim``, ``num_attention_heads``, ``hidden_dropout``,
    ``activation_fn`` and ``num_self_attention_layers``; the remaining fields
    are validated or stored but not consumed by it.
    """

    input_dim: int = 256  # Match Swin3D output
    latent_dim: int = 512
    d_model: int = 256  # Match input_dim for consistency
    max_seq_len: int = 4096
    num_self_attention_layers: int = 6
    num_cross_attention_layers: int = 2
    num_attention_heads: int = 8
    hidden_dropout: float = 0.1
    attention_dropout: float = 0.1
    qk_head_dim: Optional[int] = 32
    activation_fn: str = "gelu"
    layer_norm_eps: float = 1e-12

    def __post_init__(self):
        """Reject non-positive sizes and dropout rates outside [0, 1].

        Raises:
            ValueError: If any checked field is out of range.
        """
        # Validate parameters
        if self.input_dim <= 0:
            raise ValueError("input_dim must be positive")
        if self.max_seq_len <= 0:
            raise ValueError("max_seq_len must be positive")
        if self.num_attention_heads <= 0:
            raise ValueError("num_attention_heads must be positive")
        if not 0 <= self.hidden_dropout <= 1:
            raise ValueError("hidden_dropout must be between 0 and 1")
        if not 0 <= self.attention_dropout <= 1:
            raise ValueError("attention_dropout must be between 0 and 1")


class PerceiverProcessor(nn.Module):
    """Sequence processor: linear projection, transformer encoder, mean pooling.

    Despite the name, the implementation is a plain ``nn.TransformerEncoder``
    followed by a projection to ``latent_dim`` and an average over the
    sequence axis.
    """

    def __init__(self, config: Optional[ProcessorConfig] = None):
        """Build the processor.

        Args:
            config: Hyperparameters; a default ``ProcessorConfig`` is used
                when omitted.
        """
        super().__init__()
        self.config = config or ProcessorConfig()

        # Input projection to match d_model
        self.input_projection = nn.Linear(self.config.input_dim, self.config.d_model)

        # Simplified architecture using transformer encoder
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.config.d_model,
                nhead=self.config.num_attention_heads,
                dim_feedforward=self.config.d_model * 4,
                dropout=self.config.hidden_dropout,
                activation=self.config.activation_fn,
            ),
            num_layers=self.config.num_self_attention_layers,
        )

        # Output projection
        self.output_projection = nn.Linear(self.config.d_model, self.config.latent_dim)

    def forward(self, x, attention_mask=None):
        """Process a sequence and pool it to one vector per batch element.

        Args:
            x: Either (batch, seq, channels) or a 4D tensor
                (batch, seq, height, width). A 4D tensor is flattened to
                (batch, seq*height*width, 1), so it is only usable when
                ``config.input_dim == 1``.
            attention_mask: Optional boolean mask where True marks positions
                to keep and False marks positions to ignore. It is passed to
                the encoder as ``src_key_padding_mask`` (presumably shaped
                (batch, seq) - confirm against callers).

        Returns:
            Tensor of shape (batch, latent_dim).
        """
        if x.ndim == 4:
            # A 4D input has no feature axis, so give it a unit one. The
            # previous pattern ended in a new axis name ("c") that does not
            # exist on the input, which einops rejects for every input.
            x = einops.rearrange(x, "b s h w -> b (s h w) 1")

        # Project input
        x = self.input_projection(x)

        padding_mask = None
        if attention_mask is not None:
            # Convert boolean mask to float mask where True -> 0, False -> -inf
            ignored = ~attention_mask
            padding_mask = ignored.float().masked_fill(ignored, float("-inf"))

        # The encoder layers are built with the default (seq, batch, channels)
        # layout, so transpose around the call.
        x = einops.rearrange(x, "b s c -> s b c")
        x = self.encoder(x, src_key_padding_mask=padding_mask)
        x = einops.rearrange(x, "s b c -> b s c")

        # Project to latent dimension and pool
        x = self.output_projection(x)
        x = x.mean(dim=1)  # Global average pooling

        return x
def sample_noise_level(sigma_min=0.02, sigma_max=88, rho=7):
    """Draw one random noise level.

    A uniform sample ``u`` in [0, 1) interpolates between
    ``sigma_max ** (1 / rho)`` and ``sigma_min ** (1 / rho)``; the result is
    raised back to the power ``rho``. The default values are the training
    ones and need to be changed for sampling.

    Args:
        sigma_min (float, optional): Defaults to 0.02.
        sigma_max (int, optional): Defaults to 88.
        rho (int, optional): Defaults to 7.

    Returns:
        noise_level: single sample of noise level.
    """
    inverse_rho = 1 / rho
    upper_root = sigma_max**inverse_rho
    lower_root = sigma_min**inverse_rho
    u = np.random.random()
    return (upper_root + u * (lower_root - upper_root)) ** rho


class Preconditioner(torch.nn.Module):
    """Preconditioning coefficients from Karras (2022), table 1."""

    def __init__(self, sigma_data: float = 1):
        """Store the data standard deviation used by every coefficient.

        Args:
            sigma_data (float): Karras suggests 0.5, GenCast 1. Defaults to 1.
        """
        super().__init__()
        self.sigma_data = sigma_data

    def c_skip(self, sigma):
        """Scaling factor for skip connection."""
        data_variance = self.sigma_data**2
        return data_variance / (sigma**2 + data_variance)

    def c_out(self, sigma):
        """Scaling factor for output."""
        numerator = sigma * self.sigma_data
        total_std = torch.sqrt(sigma**2 + self.sigma_data**2)
        return numerator / total_std

    def c_in(self, sigma):
        """Scaling factor for input."""
        total_std = torch.sqrt(sigma**2 + self.sigma_data**2)
        return 1 / total_std

    def c_noise(self, sigma):
        """Scaling factor for noise level."""
        log_sigma = torch.log(sigma)
        return 1 / 4 * log_sigma
class Decoder(torch.nn.Module):
    """GenCast's decoder: one message-passing step onto the grid plus a final MLP."""

    def __init__(
        self,
        edges_dim: int,
        output_dim: int,
        hidden_dims: list[int],
        activation_layer: torch.nn.Module = torch.nn.ReLU,
        use_layer_norm: bool = True,
    ):
        """Initialize the Decoder.

        Args:
            edges_dim (int): dimension of edges' features.
            output_dim (int): dimension of final output.
            hidden_dims (list[int]): hidden dimensions of internal MLPs.
            activation_layer (torch.nn.Module, optional): activation function of internal MLPs.
                Defaults to torch.nn.ReLU.
            use_layer_norm (bool, optional): if true add a LayerNorm at the end of each MLP.
                Defaults to True.
        """
        super().__init__()

        # Every MLP shares the same hidden sizes, so the latent width of the
        # module is simply the last hidden size; callers give the list once.
        latent_dim = hidden_dims[-1]
        self.latent_dim = latent_dim

        # Settings common to the two plain MLPs below.
        mlp_settings = {
            "activation_layer": activation_layer,
            "use_layer_norm": use_layer_norm,
            "bias": True,
            "activate_final": False,
        }

        # Edge embedder
        self.edges_mlp = MLP(input_dim=edges_dim, hidden_dims=hidden_dims, **mlp_settings)

        # Message passing between latent-sized senders, receivers and edges
        self.gnn = InteractionNetwork(
            sender_dim=latent_dim,
            receiver_dim=latent_dim,
            edge_attr_dim=latent_dim,
            hidden_dims=hidden_dims,
            use_layer_norm=use_layer_norm,
            activation_layer=activation_layer,
        )

        # Final grid-node update: same hidden sizes, last layer resized to output_dim
        final_dims = hidden_dims[:-1] + [output_dim]
        self.grid_mlp_final = MLP(input_dim=latent_dim, hidden_dims=final_dims, **mlp_settings)

    def forward(
        self,
        input_mesh_nodes: torch.Tensor,
        input_grid_nodes: torch.Tensor,
        input_edge_attr: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            input_mesh_nodes (torch.Tensor): mesh nodes' features.
            input_grid_nodes (torch.Tensor): grid nodes' features.
            input_edge_attr (torch.Tensor): mesh2grid edges' features.
            edge_index (torch.Tensor): edge index tensor.

        Returns:
            torch.Tensor: output grid nodes.

        Raises:
            ValueError: if the feature size of the grid or mesh nodes differs
                from the last hidden dimension.
        """
        grid_width = input_grid_nodes.shape[-1]
        mesh_width = input_mesh_nodes.shape[-1]
        if grid_width != self.latent_dim or mesh_width != self.latent_dim:
            raise ValueError(
                "The dimension of grid nodes and mesh nodes' features must be "
                "equal to the last hidden dimension."
            )

        # Embed the edge features
        embedded_edges = self.edges_mlp(input_edge_attr)

        # Message passing from mesh to grid, with a residual on the grid nodes
        messages = self.gnn(
            x=(input_mesh_nodes, input_grid_nodes),
            edge_index=edge_index,
            edge_attr=embedded_edges,
        )
        updated_grid = input_grid_nodes + messages

        # Map the updated grid nodes to the output dimension
        return self.grid_mlp_final(updated_grid)
def test_known_value_simple_case(feature_variance: torch.Tensor):
    """
    Validate loss against a known spectral case.

    Target and prediction are synthesised from spherical-harmonic coefficients
    that are zero everywhere except the (l=1, m=0) entry, where the prediction
    carries half the target's amplitude. The loss is then compared with the
    value worked out by hand from those amplitudes.
    """
    nlat, nlon = 16, 32
    batch_size = 1
    num_channels = feature_variance.shape[0]
    spatial_shape = (batch_size, num_channels, nlat, nlon)

    # The forward transform is only built to read off the coefficient sizes.
    forward_sht = th.RealSHT(nlat, nlon, grid="equiangular")
    coeffs_shape = (batch_size * num_channels, forward_sht.lmax, forward_sht.mmax)

    # Place known energy in (l=1, m=0) band
    target_coeffs = torch.zeros(coeffs_shape, dtype=torch.complex64)
    target_coeffs[:, 1, 0] = 1.0 + 0.0j
    pred_coeffs = target_coeffs * 0.5

    # Inverse SHT to get spatial-domain data
    inverse_sht = th.InverseRealSHT(nlat, nlon, grid="equiangular")
    target = inverse_sht(target_coeffs).view(*spatial_shape)
    pred = inverse_sht(pred_coeffs).view(*spatial_shape)

    # Hand-computed loss: squared difference of the two amplitudes,
    # normalised per channel and averaged.
    target_amplitude = torch.sqrt(torch.tensor(1.0**2))
    pred_amplitude = torch.sqrt(torch.tensor(0.5**2))
    amplitude_error = (pred_amplitude - target_amplitude) ** 2
    expected_normalized_loss = (amplitude_error / feature_variance).mean()

    # Compare to actual loss
    criterion = AMSENormalizedLoss(feature_variance=feature_variance)
    actual_loss = criterion(pred, target)

    assert torch.allclose(actual_loss, expected_normalized_loss, atol=1e-5)
class AxialAttention(nn.Module):
    """
    Multi-head self-attention restricted to one axis of a 2D feature map.
    Building block of Factorized Attention.
    """

    def __init__(self, dim, heads, dim_head=64, dropout=0.0):
        super().__init__()
        projected_width = dim_head * heads
        self.heads = heads
        self.scale = dim_head**-0.5

        self.to_qkv = nn.Linear(dim, projected_width * 3, bias=False)
        self.to_out = nn.Linear(projected_width, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, axis):
        """
        Attend along a single spatial axis.

        Args:
            x: Input tensor of shape (batch, height, width, channels)
            axis: Axis to perform attention on (1 for height, 2 for width)

        Raises:
            ValueError: if ``axis`` is neither 1 nor 2.
        """
        _, height, width, _ = x.shape

        # Pick the patterns that fold the other spatial axis into the batch
        # and restore it afterwards.
        if axis == 1:  # attention along height
            fold, unfold = "b h w d -> (b w) h d", "(b w) h d -> b h w d"
            restored = {"w": width}
        elif axis == 2:  # attention along width
            fold, unfold = "b h w d -> (b h) w d", "(b h) w d -> b h w d"
            restored = {"h": height}
        else:
            raise ValueError("Axis must be 1 (height) or 2 (width)")

        sequences = rearrange(x, fold)

        # One projection yields queries, keys and values; split them per head.
        q, k, v = (
            rearrange(part, "b n (h d) -> b h n d", h=self.heads)
            for part in self.to_qkv(sequences).chunk(3, dim=-1)
        )

        # Scaled dot-product scores, normalised over the key axis
        scores = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
        weights = self.dropout(scores.softmax(dim=-1))

        # Weighted sum of values, heads merged back into the channel axis
        context = einsum("b h i j, b h j d -> b h i d", weights, v)
        merged = rearrange(context, "b h n d -> b n (h d)")
        projected = self.to_out(merged)

        # Back to the original 2D grid layout
        return rearrange(projected, unfold, **restored)
["torch-harmonics", "anemoi-datasets>=0.4.5,<0.5"] 10 | 11 | [build-system] 12 | build-backend = "hatchling.build" 13 | requires = ["hatchling"] 14 | 15 | [tool.pixi.project] 16 | channels = ["conda-forge"] 17 | platforms = ["linux-64", "osx-arm64"] 18 | 19 | [tool.pixi.feature.cuda] 20 | channels = ["nvidia",] 21 | 22 | [tool.pixi.feature.cuda.system-requirements] 23 | cuda = "12" 24 | 25 | [tool.pixi.feature.cuda.target.linux-64.dependencies] 26 | cuda-version = "==12.8" 27 | 28 | #[tool.pixi.feature.cuda.target.linux-64.pypi-dependencies] 29 | #natten = {url = "https://shi-labs.com/natten/wheels/cu124/torch2.4.0/natten-0.17.4%2Btorch240cu124-cp312-cp312-linux_x86_64.whl"} 30 | 31 | [tool.pixi.feature.mlx] 32 | # MLX is only available on macOS >=13.5 (>14.0 is recommended) 33 | system-requirements = {macos = "13.5"} 34 | 35 | [tool.pixi.feature.mlx.target.osx-arm64.dependencies] 36 | mlx = {version = "*", channel = "conda-forge"} 37 | 38 | #[tool.pixi.feature.mlx.target.osx-arm64.pypi-dependencies] 39 | #natten = "*" 40 | [tool.pixi.feature.cpu] 41 | platforms = ["linux-64", "osx-arm64"] 42 | 43 | #[tool.pixi.feature.cpu.pypi-dependencies] 44 | #natten = "*" 45 | 46 | [tool.pixi.dependencies] 47 | python = "3.12.*" 48 | pip = "*" 49 | pytest = "*" 50 | pre-commit = "*" 51 | ruff = "*" 52 | xarray = "*" 53 | pandas = "*" 54 | numcodecs = "*" 55 | scipy = "*" 56 | zarr = ">=3.0.0" 57 | tqdm = "*" 58 | lightning = "*" 59 | einops = "*" 60 | fsspec = "*" 61 | datasets = "*" 62 | trimesh = "*" 63 | pysolar = "*" 64 | rtree = "*" 65 | pixi-pycharm = "*" 66 | uv = ">=0.6.2,<0.7" 67 | healpy = "*" 68 | dacite = "*" 69 | 70 | [tool.pixi.pypi-dependencies] 71 | torch_geometric = "*" 72 | pytest = "*" 73 | pytest-xdist = "*" 74 | h3 = "==4.3.1" 75 | 76 | [tool.pixi.feature.cuda.pypi-dependencies] 77 | torch = { version = ">=2.7.0", index = "https://download.pytorch.org/whl/cu128" } 78 | torchvision = {version = "*", index = "https://download.pytorch.org/whl/cu128"} 79 | 80 
| [tool.pixi.feature.mlx.pypi-dependencies] 81 | torch = { version = ">=2.7.0", index = "https://download.pytorch.org/whl/cpu" } 82 | torchvision = {version = "*", index = "https://download.pytorch.org/whl/cpu"} 83 | 84 | [tool.pixi.feature.cpu.pypi-dependencies] 85 | torch = { version = ">=2.7.0", index = "https://download.pytorch.org/whl/cpu" } 86 | torchvision = {version = "*", index = "https://download.pytorch.org/whl/cpu"} 87 | 88 | [tool.pixi.environments] 89 | default = ["cpu"] 90 | cuda = ["cuda"] 91 | mlx = ["mlx"] 92 | 93 | [tool.pixi.tasks] 94 | install = "pip install --editable ." 95 | installnnja = "pip install git+https://github.com/brightbandtech/nnja-ai.git" 96 | installnat = "pip install natten" 97 | installnatcuda = "pip install natten==0.20.1+torch270cu128 -f https://whl.natten.org" 98 | installpyg = "pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.7.0+cpu.html" 99 | installpygcuda = "pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.7.0+cuda128.html" 100 | test = "pytest" 101 | format = "ruff format" 102 | 103 | 104 | [tool.ruff] 105 | # Exclude a variety of commonly ignored directories. 106 | exclude = [ 107 | ".bzr", 108 | ".direnv", 109 | ".eggs", 110 | ".git", 111 | ".hg", 112 | ".mypy_cache", 113 | ".nox", 114 | ".nox", 115 | ".pants.d", 116 | ".pytype", 117 | ".ruff_cache", 118 | ".svn", 119 | ".tox", 120 | ".venv", 121 | "__pypackages__", 122 | "_build", 123 | "buck-out", 124 | "build", 125 | "dist", 126 | "node_modules", 127 | "venv", 128 | "tests", 129 | ] 130 | # Same as Black. 131 | line-length = 100 132 | 133 | # Assume Python 3.10. 134 | target-version = "py311" 135 | fix=false 136 | # Group violations by containing file. 137 | output-format = "github" 138 | 139 | [tool.ruff.lint] 140 | # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default. 
141 | select = ["E", "F", "D", "I"] 142 | ignore = ["D200","D202","D210","D212","D415","D105"] 143 | 144 | # Allow autofix for all enabled rules (when `--fix`) is provided. 145 | fixable = ["A", "B", "C", "D", "E", "F", "I"] 146 | unfixable = [] 147 | # Allow unused variables when underscore-prefixed. 148 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 149 | mccabe.max-complexity = 10 150 | pydocstyle.convention = "google" 151 | 152 | [tool.ruff.lint.per-file-ignores] 153 | "__init__.py" = ["F401", "E402"] 154 | -------------------------------------------------------------------------------- /graph_weather/models/weathermesh/layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation based off the technical report and this repo: https://github.com/Brayden-Zhang/WeatherMesh 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ConvDownBlock(nn.Module): 11 | """ 12 | Downsampling convolutional block with residual connection. 13 | Can handle both 2D and 3D inputs. 
14 | """ 15 | 16 | def __init__( 17 | self, 18 | in_channels: int, 19 | out_channels: int, 20 | is_3d: bool = False, 21 | kernel_size: int = 3, 22 | stride: int = 2, 23 | padding: int = 1, 24 | groups: int = 1, 25 | activation: nn.Module = nn.GELU(), 26 | ): 27 | super().__init__() 28 | 29 | Conv = nn.Conv3d if is_3d else nn.Conv2d 30 | Norm = nn.BatchNorm3d if is_3d else nn.BatchNorm2d 31 | 32 | self.conv1 = Conv( 33 | in_channels, 34 | out_channels, 35 | kernel_size=kernel_size, 36 | stride=1, 37 | padding=padding, 38 | groups=groups, 39 | bias=False, 40 | ) 41 | self.bn1 = Norm(out_channels) 42 | self.activation1 = activation 43 | 44 | self.conv2 = Conv( 45 | out_channels, 46 | out_channels, 47 | kernel_size=kernel_size, 48 | stride=stride, 49 | padding=padding, 50 | groups=groups, 51 | bias=False, 52 | ) 53 | self.bn2 = Norm(out_channels) 54 | self.activation2 = activation 55 | 56 | # Residual connection with 1x1 conv to match dimensions 57 | self.downsample = Conv(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) 58 | self.bn_down = Norm(out_channels) 59 | 60 | def forward(self, x: torch.Tensor) -> torch.Tensor: 61 | identity = self.bn_down(self.downsample(x)) 62 | 63 | out = self.conv1(x) 64 | out = self.bn1(out) 65 | out = self.activation1(out) 66 | 67 | out = self.conv2(out) 68 | out = self.bn2(out) 69 | 70 | out += identity 71 | out = self.activation2(out) 72 | 73 | return out 74 | 75 | 76 | class ConvUpBlock(nn.Module): 77 | """ 78 | Upsampling convolutional block with residual connection. same as downBlock but reversed. 
79 | """ 80 | 81 | def __init__( 82 | self, 83 | in_channels: int, 84 | out_channels: int, 85 | is_3d: bool = False, 86 | kernel_size: int = 3, 87 | scale_factor: int = 2, 88 | padding: int = 1, 89 | groups: int = 1, 90 | activation: nn.Module = nn.GELU(), 91 | ): 92 | super().__init__() 93 | 94 | Conv = nn.Conv3d if is_3d else nn.Conv2d 95 | Norm = nn.BatchNorm3d if is_3d else nn.BatchNorm2d 96 | self.is_3d = is_3d 97 | self.scale_factor = scale_factor 98 | 99 | self.conv1 = Conv( 100 | in_channels, 101 | in_channels, 102 | kernel_size=kernel_size, 103 | stride=1, 104 | padding=padding, 105 | groups=groups, 106 | bias=False, 107 | ) 108 | self.bn1 = Norm(in_channels) 109 | self.activation1 = activation 110 | 111 | self.conv2 = Conv( 112 | in_channels, 113 | out_channels, 114 | kernel_size=kernel_size, 115 | stride=1, 116 | padding=padding, 117 | groups=groups, 118 | bias=False, 119 | ) 120 | self.bn2 = Norm(out_channels) 121 | self.activation2 = activation 122 | 123 | # Residual connection with 1x1 conv to match dimensions 124 | self.upsample = Conv(in_channels, out_channels, kernel_size=1, stride=1, bias=False) 125 | self.bn_up = Norm(out_channels) 126 | 127 | def forward(self, x: torch.Tensor) -> torch.Tensor: 128 | # Upsample input 129 | if self.is_3d: 130 | x = F.interpolate( 131 | x, 132 | scale_factor=(1, self.scale_factor, self.scale_factor), 133 | mode="trilinear", 134 | align_corners=False, 135 | ) 136 | else: 137 | x = F.interpolate( 138 | x, scale_factor=self.scale_factor, mode="bilinear", align_corners=False 139 | ) 140 | 141 | identity = self.bn_up(self.upsample(x)) 142 | 143 | out = self.conv1(x) 144 | out = self.bn1(out) 145 | out = self.activation1(out) 146 | 147 | out = self.conv2(out) 148 | out = self.bn2(out) 149 | 150 | out += identity 151 | out = self.activation2(out) 152 | 153 | return out 154 | -------------------------------------------------------------------------------- /graph_weather/models/gencast/layers/encoder.py: 
class Encoder(torch.nn.Module):
    """GenCast's encoder."""

    def __init__(
        self,
        grid_dim: int,
        mesh_dim: int,
        edge_dim: int,
        hidden_dims: list[int],
        activation_layer: torch.nn.Module = torch.nn.ReLU,
        use_layer_norm: bool = True,
        scale_factor: float = 1.0,
    ):
        """Initialize the Encoder.

        Args:
            grid_dim (int): dimension of grid nodes' features.
            mesh_dim (int): dimension of mesh nodes' features
            edge_dim (int): dimension of g2m edges' features
            hidden_dims (list[int]): hidden dimensions of internal MLPs.
            activation_layer (torch.nn.Module, optional): activation function of internal MLPs.
                Defaults to torch.nn.ReLU.
            use_layer_norm (bool, optional): if true add a LayerNorm at the end of each MLP.
                Defaults to True.
            scale_factor (float): the message of the interaction network between the grid and
                the mesh is multiplied by the scale factor. Useful when fine-tuning a pretrained
                model to a higher resolution. Defaults to 1.
        """
        super().__init__()

        # Every MLP shares the same hidden/output sizes, so the latent width is
        # simply the last entry of hidden_dims.
        self.latent_dim = hidden_dims[-1]

        def build_mlp(input_dim: int):
            # All embedders and the final grid update use the same MLP settings;
            # only the input width differs.
            return MLP(
                input_dim=input_dim,
                hidden_dims=hidden_dims,
                activation_layer=activation_layer,
                use_layer_norm=use_layer_norm,
                bias=True,
                activate_final=False,
            )

        # Embedders for grid nodes, mesh nodes and grid-to-mesh edges.
        self.grid_mlp = build_mlp(grid_dim)
        self.mesh_mlp = build_mlp(mesh_dim)
        self.edges_mlp = build_mlp(edge_dim)

        # Single message-passing step operating entirely in the latent space.
        latent = self.latent_dim
        self.gnn = InteractionNetwork(
            sender_dim=latent,
            receiver_dim=latent,
            edge_attr_dim=latent,
            hidden_dims=hidden_dims,
            use_layer_norm=use_layer_norm,
            activation_layer=activation_layer,
            scale_factor=scale_factor,
        )

        # Final grid nodes update.
        self.grid_mlp_final = build_mlp(latent)

    def forward(
        self,
        input_grid_nodes: torch.Tensor,
        input_mesh_nodes: torch.Tensor,
        input_edge_attr: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass.

        Args:
            input_grid_nodes (torch.Tensor): grid nodes' features.
            input_mesh_nodes (torch.Tensor): mesh nodes' features.
            input_edge_attr (torch.Tensor): grid2mesh edges' features.
            edge_index (torch.Tensor): edge index tensor.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: output grid nodes, output mesh nodes.
        """
        # Project all inputs to the latent space.
        grid_latent = self.grid_mlp(input_grid_nodes)
        mesh_latent = self.mesh_mlp(input_mesh_nodes)
        edge_latent = self.edges_mlp(input_edge_attr)

        # Grid -> mesh message passing, added residually to the mesh embedding.
        mesh_message = self.gnn(
            x=(grid_latent, mesh_latent),
            edge_index=edge_index,
            edge_attr=edge_latent,
        )
        mesh_out = mesh_latent + mesh_message

        # Grid nodes are refined by an MLP, again with a residual connection.
        grid_out = grid_latent + self.grid_mlp_final(grid_latent)

        return grid_out, mesh_out
"v_component_of_wind": 0.5, 40 | }, 41 | "stds": { 42 | "temperature": 0.2, 43 | "geopotential": 0.2, 44 | "u_component_of_wind": 0.2, 45 | "v_component_of_wind": 0.2, 46 | }, 47 | } 48 | with patch("graph_weather.data.anemoi_dataloader.open_dataset", new=fake_open_dataset): 49 | dataset = AnemoiDataset(**dataset_config) 50 | assert len(dataset) > 0 51 | assert dataset.num_lat > 0 and dataset.num_lon > 0 52 | assert len(dataset.features) > 0 53 | 54 | # Test getting a single sample 55 | input_data, target_data = dataset[0] 56 | assert input_data.shape == target_data.shape 57 | assert input_data.dtype == np.float32 58 | assert not ( 59 | np.isnan(input_data).any() or np.isnan(target_data).any() 60 | ), "Found NaN values in data!" 61 | 62 | # Test with DataLoader 63 | dataloader = DataLoader(dataset, batch_size=2, shuffle=True) 64 | batch_input, batch_target = next(iter(dataloader)) 65 | assert batch_input.shape[0] == 2 66 | assert batch_target.shape[0] == 2 67 | 68 | 69 | def test_normalization(): 70 | """Test that normalization is working correctly""" 71 | with patch("graph_weather.data.anemoi_dataloader.open_dataset", new=fake_open_dataset): 72 | dataset = AnemoiDataset( 73 | dataset_name="synthetic", 74 | features=["temperature"], 75 | max_samples=3, 76 | means={"temperature": 0.5}, 77 | stds={"temperature": 0.2}, 78 | ) 79 | samples = [] 80 | for i in range(min(3, len(dataset))): 81 | input_data, _ = dataset[i] 82 | samples.append(input_data[:, 0]) # First feature (temperature) 83 | all_values = np.concatenate(samples) 84 | mean_val = np.mean(all_values) 85 | std_val = np.std(all_values) 86 | assert ( 87 | abs(mean_val) < 0.5 and abs(std_val - 1.0) < 0.5 88 | ), "Normalization might need adjustment" 89 | 90 | 91 | def test_time_features(): 92 | """Test that time features are being added correctly""" 93 | with patch("graph_weather.data.anemoi_dataloader.open_dataset", new=fake_open_dataset): 94 | dataset = AnemoiDataset( 95 | dataset_name="synthetic", 96 | 
class Sampler:
    """Sampler for the denoiser.

    The sampler consists in the second-order DPMSolver++2S solver (Lu et al., 2022), augmented with
    the stochastic churn (again making use of the isotropic noise) and noise inflation techniques
    used in Karras et al. (2022) to inject further stochasticity into the sampling process. In
    conditioning on previous timesteps it follows the Conditional Denoising Estimator approach
    outlined and motivated by Batzolis et al. (2021).
    """

    def __init__(
        self,
        S_noise: float = 1.05,
        S_tmin: float = 0.75,
        S_tmax: float = 80.0,
        S_churn: float = 2.5,
        r: float = 0.5,
        sigma_max: float = 80.0,
        sigma_min: float = 0.03,
        rho: float = 7,
        num_steps: int = 20,
    ):
        """Initialize the sampler.

        Args:
            S_noise (float): noise inflation parameter. Defaults to 1.05.
            S_tmin (float): minimum noise for sampling. Defaults to 0.75.
            S_tmax (float): maximum noise for sampling. Defaults to 80.
            S_churn (float): stochastic churn rate. Defaults to 2.5.
            r (float): relative position, in log-sigma space, of the intermediate evaluation
                inside each DPMSolver++2S step (lambda_mid = lambda_hat + r * h).
                Defaults to 0.5.
            sigma_max (float): maximum value of sigma for sigma's distribution. Defaults to 80.
            sigma_min (float): minimum value of sigma for sigma's distribution. Defaults to 0.03.
            rho (float): exponent of the sigma's distribution. Defaults to 7.
            num_steps (int): number of timesteps during sampling. Defaults to 20.

        Raises:
            ValueError: if ``num_steps`` is smaller than 2.
        """
        # The schedule divides by (num_steps - 1) and the loop needs at least one
        # (sigma_i, sigma_{i+1}) pair; fewer than 2 steps yields NaNs or an IndexError.
        if num_steps < 2:
            raise ValueError(f"num_steps must be at least 2, got {num_steps}.")

        self.S_noise = S_noise
        self.S_tmin = S_tmin
        self.S_tmax = S_tmax
        self.S_churn = S_churn
        self.r = r
        self.num_steps = num_steps

        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.rho = rho

    def _sigmas_fn(self, u):
        """Map ``u`` in [0, 1] to a noise level, from sigma_max (u=0) to sigma_min (u=1)."""
        return (
            self.sigma_max ** (1 / self.rho)
            + u * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho))
        ) ** self.rho

    @staticmethod
    def _isotropic_noise(denoiser, device):
        """Draw one isotropic noise sample sized for ``denoiser`` with a leading batch dim of 1."""
        noise = torch.tensor(
            generate_isotropic_noise(
                num_lon=denoiser.num_lon,
                num_lat=denoiser.num_lat,
                num_samples=denoiser.output_features_dim,
            )
        )
        return noise.unsqueeze(0).to(device)

    @torch.no_grad()
    def sample(self, denoiser: "Denoiser", prev_inputs: torch.Tensor):
        """Generate a sample from random noise for the given inputs.

        Args:
            denoiser (Denoiser): the denoiser model.
            prev_inputs (torch.Tensor): previous two timesteps.

        Returns:
            torch.Tensor: normalized residuals predicted.
        """
        device = prev_inputs.device

        time_steps = torch.arange(0, self.num_steps).to(device) / (self.num_steps - 1)
        sigmas = self._sigmas_fn(time_steps)

        batch_ones = torch.ones(1, 1).to(device)

        # initialize noise
        x = sigmas[0] * self._isotropic_noise(denoiser, device)

        last_step = len(sigmas) - 2
        for i in range(len(sigmas) - 1):
            # stochastic churn from Karras et al. (Alg. 2)
            gamma = (
                min(self.S_churn / self.num_steps, math.sqrt(2) - 1)
                if self.S_tmin <= sigmas[i] <= self.S_tmax
                else 0.0
            )
            # noise inflation from Karras et al. (Alg. 2)
            # Drawn on every step, even when gamma == 0, so the random stream
            # consumed by a run does not depend on the churn window.
            noise = self.S_noise * self._isotropic_noise(denoiser, device)

            sigma_hat = sigmas[i] * (gamma + 1)
            if gamma > 0:
                x = x + (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 * noise
            denoised = denoiser(x, prev_inputs, sigma_hat * batch_ones)

            if i == last_step:
                # final Euler step
                d = (x - denoised) / sigma_hat
                x = x + d * (sigmas[i + 1] - sigma_hat)
            else:
                # DPMSolver++2S step (Alg. 1 in Lu et al.) with alpha_t=1.
                # t_{i-1} is t_hat because of stochastic churn!
                lambda_hat = -torch.log(sigma_hat)
                lambda_next = -torch.log(sigmas[i + 1])
                h = lambda_next - lambda_hat
                lambda_mid = lambda_hat + self.r * h
                sigma_mid = torch.exp(-lambda_mid)

                u = sigma_mid / sigma_hat * x - (torch.exp(-self.r * h) - 1) * denoised
                denoised_2 = denoiser(u, prev_inputs, sigma_mid * batch_ones)
                D = (1 - 1 / (2 * self.r)) * denoised + 1 / (2 * self.r) * denoised_2
                x = sigmas[i + 1] / sigma_hat * x - (torch.exp(-h) - 1) * D

        return x
class Era5Dataset(Dataset):
    """Sliding-window dataset over a stack of gridded variables.

    Each item is a window of ``time_step + 1`` consecutive time slices: the
    first slice is the model input and the remaining ones are the targets.

    Args:
        xarr: Object exposing ``to_array()`` whose result converts to a
            ``(C, T, H, W)`` array (variables, time, height, width).
        time_step: Number of future steps included after the input slice.
            Must be greater than 0.
        transform: Optional callable applied to every window returned by
            ``__getitem__``. ``None`` leaves windows untouched.
    """

    def __init__(self, xarr, time_step=1, transform=None):
        assert time_step > 0, "Time step must be greater than 0."
        ds = np.asarray(xarr.to_array())
        ds = torch.from_numpy(ds)
        # NOTE(review): dim 0 is the variable axis here, so this rescales across
        # variables at each (T, H, W) location rather than per variable over
        # time/space - confirm that this is the intended normalisation.
        ds -= ds.min(0, keepdim=True)[0]
        ds /= ds.max(0, keepdim=True)[0]
        # "C T H W -> T (H W) C": time first, grid flattened, variables last.
        n_vars, n_times, height, width = ds.shape
        ds = ds.permute(1, 2, 3, 0).reshape(n_times, height * width, n_vars)
        self.ds = ds
        self.time_step = time_step
        self.transform = transform

    def __len__(self):
        """Return the number of full windows that fit in the time axis."""
        return len(self.ds) - self.time_step

    def __getitem__(self, index):
        """Return the window of ``time_step + 1`` slices starting at ``index``."""
        # Use the instance attribute: the previous code read a module-level
        # ``time_step`` that only exists when this file runs as a script.
        window = self.ds[index : index + self.time_step + 1]
        if self.transform is not None:
            window = self.transform(window)
        return window
143 | time_step=time_step, 144 | rank=rank, 145 | ########## 146 | channels=channels, 147 | image_size=(721 // grid_step, 1440 // grid_step), 148 | patch_size=patch_size, 149 | depth=5, 150 | heads=4, 151 | mlp_dim=5, 152 | ) 153 | trainer = pl.Trainer( 154 | accelerator="gpu", 155 | devices=-1, 156 | max_epochs=100, 157 | precision="16-mixed", 158 | callbacks=[checkpoint_callback], 159 | log_every_n_steps=3, 160 | strategy="ddp_find_unused_parameters_true", 161 | ) 162 | 163 | trainer.fit(model, dset) 164 | -------------------------------------------------------------------------------- /tests/test_nnjai.py: -------------------------------------------------------------------------------- 1 | """Unit tests for NNJA-AI data loading components. 2 | 3 | Tests cover variable classification, dataset loading, and PyTorch integration. 4 | """ 5 | 6 | from datetime import datetime 7 | from unittest.mock import MagicMock, patch 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import pytest 12 | import xarray as xr 13 | 14 | from graph_weather.data.nnja_ai import ( 15 | NNJATorchDataset, 16 | SensorDataset, 17 | _classify_variable, 18 | load_nnja_dataset, 19 | ) 20 | 21 | 22 | @pytest.fixture 23 | def mock_datacatalog(): 24 | """Fixture to mock the DataCatalog with properly configured variables.""" 25 | with patch("graph_weather.data.nnja_ai.DataCatalog") as mock: 26 | mock_catalog = MagicMock() 27 | mock_dataset = MagicMock() 28 | mock_dataset.load_manifest = MagicMock() 29 | 30 | # Setup valid variables with proper attributes 31 | def create_mock_var(var_type): 32 | var = MagicMock() 33 | var.category = var_type 34 | return var 35 | 36 | valid_variables = { 37 | "time": create_mock_var("primary_descriptor"), 38 | "latitude": create_mock_var("primary_descriptor"), 39 | "longitude": create_mock_var("primary_descriptor"), 40 | "TMBR_00001": create_mock_var("primary_data"), 41 | "TMBR_00002": create_mock_var("primary_data"), 42 | "OBS_TIMESTAMP": 
create_mock_var("primary_descriptor"), 43 | "LAT": create_mock_var("primary_descriptor"), 44 | "LON": create_mock_var("primary_descriptor"), 45 | } 46 | 47 | def mock_sel(time=None, variables=None): 48 | # Always include time plus requested variables 49 | vars_to_load = ["time"] 50 | if variables: 51 | for v in variables: 52 | if v == "LAT": 53 | vars_to_load.append("latitude") 54 | elif v == "LON": 55 | vars_to_load.append("longitude") 56 | else: 57 | vars_to_load.append(v) 58 | 59 | # Validate variables exist 60 | invalid_vars = [v for v in vars_to_load if v not in valid_variables] 61 | if invalid_vars: 62 | raise ValueError(f"Invalid variables requested: {invalid_vars}") 63 | return mock_dataset 64 | 65 | mock_dataset.sel = mock_sel 66 | mock_dataset.variables = valid_variables 67 | 68 | def mock_load_dataset(backend="pandas", engine="pyarrow"): 69 | time_points = pd.date_range(start=datetime(2021, 1, 1), periods=100, freq="h") 70 | 71 | data = { 72 | "time": time_points, 73 | "latitude": np.full(100, 45.0), 74 | "longitude": np.full(100, -120.0), 75 | "TMBR_00001": np.full(100, 250.0), 76 | "TMBR_00002": np.full(100, 250.0), 77 | } 78 | return pd.DataFrame(data) 79 | 80 | mock_dataset.load_dataset = mock_load_dataset 81 | mock_catalog.__getitem__.side_effect = lambda name: mock_dataset 82 | mock.return_value = mock_catalog 83 | 84 | yield mock 85 | 86 | 87 | def test_variable_classification(): 88 | """Test the variable classification logic.""" 89 | # Create a mock variable with category attribute 90 | mock_var = MagicMock() 91 | mock_var.category = "primary_data" 92 | assert _classify_variable(mock_var) == "primary_data" 93 | 94 | 95 | def test_load_nnja_dataset(mock_datacatalog): 96 | """Test the core dataset loading function.""" 97 | ds = load_nnja_dataset("test-dataset", time=datetime(2021, 1, 1)) 98 | 99 | assert isinstance(ds, xr.Dataset) 100 | assert "time" in ds.dims 101 | assert len(ds.data_vars) >= 3 102 | assert np.issubdtype(ds.time.dtype, 
np.datetime64) 103 | 104 | 105 | def test_sensor_dataset(mock_datacatalog): 106 | """Test the SensorDataset class.""" 107 | ds = SensorDataset( 108 | "test-dataset", 109 | time=datetime(2021, 1, 1), 110 | variables=["LAT", "LON", "TMBR_00001"], # Used original names here 111 | ) 112 | 113 | assert len(ds) == 100 114 | sample = ds[0] 115 | assert isinstance(sample, dict) 116 | assert "latitude" in sample or "LAT" in sample 117 | assert "longitude" in sample or "LON" in sample 118 | assert "TMBR_00001" in sample 119 | 120 | 121 | def test_nnja_xarray_torch_dataset(mock_datacatalog): 122 | """Test the xarray to torch Dataset adapter.""" 123 | xrds = load_nnja_dataset("test-dataset") 124 | torch_ds = NNJATorchDataset(xrds) 125 | 126 | assert len(torch_ds) == len(xrds.time) 127 | sample = torch_ds[0] 128 | assert isinstance(sample, dict) 129 | 130 | 131 | def test_custom_variable_selection(mock_datacatalog): 132 | """Verify loading specific variables works with coordinate renaming.""" 133 | # Request the original coordinate names 134 | # Request original coordinate names 135 | custom_vars = ["LAT", "LON"] 136 | ds = load_nnja_dataset("test-dataset", variables=custom_vars) 137 | 138 | # Check renamed coordinates exist 139 | assert "latitude" in ds.data_vars 140 | assert "longitude" in ds.data_vars 141 | 142 | 143 | def test_load_all_variables(mock_datacatalog): 144 | """Test loading all variables.""" 145 | ds = load_nnja_dataset("test-dataset", load_all=True) 146 | assert len(ds.data_vars) >= 4 # time + lat + lon + at least one variable 147 | -------------------------------------------------------------------------------- /graph_weather/models/weathermesh/weathermesh2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation based off the technical report and this repo: https://github.com/Brayden-Zhang/WeatherMesh 3 | """ 4 | 5 | from dataclasses import dataclass 6 | from typing import List 7 | 8 | import dacite 9 | 
@dataclass
class WeatherMeshConfig:
    """Configuration bundle for a :class:`WeatherMesh` model.

    Groups the encoder / processor / decoder sub-configurations with the
    scalar hyper-parameters; the scalar field names mirror the keyword
    arguments of ``WeatherMesh.__init__``.
    """

    encoder: WeatherMeshEncoderConfig
    processors: List[WeatherMeshProcessorConfig]
    decoder: WeatherMeshDecoderConfig
    timesteps: List[int]
    surface_channels: int
    pressure_channels: int
    pressure_levels: int
    latent_dim: int
    encoder_num_conv_blocks: int
    encoder_num_transformer_layers: int
    encoder_hidden_dim: int
    decoder_num_conv_blocks: int
    decoder_num_transformer_layers: int
    decoder_hidden_dim: int
    processor_num_layers: int
    kernel: tuple
    num_heads: int

    @staticmethod
    def from_json(json: dict) -> "WeatherMeshConfig":
        """Build a configuration from a plain (e.g. JSON-decoded) dictionary.

        Args:
            json: Mapping whose keys match the dataclass field names.

        Returns:
            The populated ``WeatherMeshConfig``.
        """
        # JSON has no tuple type, so ``kernel`` arrives as a list; let dacite
        # cast it instead of rejecting the value on its strict type check.
        return dacite.from_dict(
            data_class=WeatherMeshConfig,
            data=json,
            config=dacite.Config(cast=[tuple]),
        )

    def to_json(self) -> dict:
        """Return the configuration (including nested dataclasses) as a dict."""
        # dacite only converts dict -> dataclass; the inverse conversion is
        # provided by the standard library.
        from dataclasses import asdict

        return asdict(self)


@dataclass
class WeatherMeshOutput:
    """Decoded model output: the two tensors returned by the decoder."""

    surface: torch.Tensor
    pressure: torch.Tensor


class WeatherMesh(nn.Module):
    """Encoder -> processor(s) -> decoder weather model.

    Each sub-module can be supplied ready-made; any that is ``None`` is built
    from the scalar hyper-parameters.
    """

    def __init__(
        self,
        encoder: nn.Module | None,
        processors: List[nn.Module] | None,
        decoder: nn.Module | None,
        timesteps: List[int],
        surface_channels: int | None,
        pressure_channels: int | None,
        pressure_levels: int | None,
        latent_dim: int | None,
        encoder_num_conv_blocks: int | None,
        encoder_num_transformer_layers: int | None,
        encoder_hidden_dim: int | None,
        decoder_num_conv_blocks: int | None,
        decoder_num_transformer_layers: int | None,
        decoder_hidden_dim: int | None,
        processor_num_layers: int | None,
        kernel: tuple | None,
        num_heads: int | None,
    ):
        """Assemble the model from given sub-modules or from hyper-parameters.

        Args:
            encoder: Encoder module; if ``None`` a ``WeatherMeshEncoder`` is built.
            processors: One processor module per entry of ``timesteps``; if
                ``None``, ``len(timesteps)`` ``WeatherMeshProcessor`` instances
                are built.
            decoder: Decoder module; if ``None`` a ``WeatherMeshDecoder`` is built.
            timesteps: Stored on the instance; its length fixes the number of
                processors.
            surface_channels: Passed as ``input_channels_2d`` (encoder) and
                ``output_channels_2d`` (decoder).
            pressure_channels: Passed as ``input_channels_3d`` (encoder) and
                ``output_channels_3d`` (decoder).
            pressure_levels: Passed as ``n_pressure_levels`` to the encoder.
            latent_dim: Latent size shared by encoder, processors and decoder.
            encoder_num_conv_blocks: ``num_conv_blocks`` of the default encoder.
            encoder_num_transformer_layers: ``num_transformer_layers`` of the
                default encoder.
            encoder_hidden_dim: ``hidden_dim`` of the default encoder.
            decoder_num_conv_blocks: ``n_conv_blocks`` of the default decoder.
            decoder_num_transformer_layers: ``num_transformer_layers`` of the
                default decoder.
            decoder_hidden_dim: ``hidden_dim`` of the default decoder.
            processor_num_layers: ``n_layers`` of each default processor.
            kernel: Kernel / window size handed to every default sub-module.
            num_heads: Number of attention heads handed to every default sub-module.

        Raises:
            AssertionError: If ``processors`` is given and its length differs
                from ``len(timesteps)``.
        """
        super().__init__()
        if encoder is not None:
            self.encoder = encoder
        else:
            self.encoder = WeatherMeshEncoder(
                input_channels_2d=surface_channels,
                input_channels_3d=pressure_channels,
                latent_dim=latent_dim,
                n_pressure_levels=pressure_levels,
                num_conv_blocks=encoder_num_conv_blocks,
                hidden_dim=encoder_hidden_dim,
                kernel_size=kernel,
                num_heads=num_heads,
                num_transformer_layers=encoder_num_transformer_layers,
            )

        if processors is not None:
            assert len(processors) == len(
                timesteps
            ), "Number of processors must match number of timesteps"
        else:
            processors = [
                WeatherMeshProcessor(
                    latent_dim=latent_dim,
                    n_layers=processor_num_layers,
                    kernel=kernel,
                    num_heads=num_heads,
                )
                for _ in timesteps
            ]
        # A plain Python list hides the processors from nn.Module: their
        # parameters would be missing from parameters()/state_dict() and would
        # not follow .to()/.cuda(). ModuleList registers them properly while
        # still supporting iteration and indexing.
        self.processors = nn.ModuleList(processors)

        if decoder is not None:
            self.decoder = decoder
        else:
            self.decoder = WeatherMeshDecoder(
                latent_dim=latent_dim,
                output_channels_2d=surface_channels,
                output_channels_3d=pressure_channels,
                n_conv_blocks=decoder_num_conv_blocks,
                hidden_dim=decoder_hidden_dim,
                kernel_size=kernel,
                num_heads=num_heads,
                num_transformer_layers=decoder_num_transformer_layers,
            )
        self.timesteps = timesteps

    def forward(
        self, surface: torch.Tensor, pressure: torch.Tensor, forecast_steps: int
    ) -> WeatherMeshOutput:
        """Encode the inputs, roll the latent forward, and decode it.

        Args:
            surface: Surface-level input passed to the encoder.
            pressure: Pressure-level input passed to the encoder.
            forecast_steps: Number of times the full chain of processors is
                applied to the latent state.

        Returns:
            ``WeatherMeshOutput`` holding the two tensors produced by the decoder.
        """
        # Encode input
        latent = self.encoder(surface, pressure)

        # Every forecast step runs all processors, in registration order.
        for _ in range(forecast_steps):
            for processor in self.processors:
                latent = processor(latent)

        # Decode output
        surface_out, pressure_out = self.decoder(latent)

        return WeatherMeshOutput(surface=surface_out, pressure=pressure_out)
34 | __getitem__: Get an item from the dataset by index. 35 | """ 36 | 37 | def __init__(self): 38 | """ 39 | Initialize the XrDataset object by loading data from Hugging Face datasets. 40 | """ 41 | super().__init__() 42 | with open("hf_forecasts.json", "r") as f: 43 | files = json.load(f) 44 | self.filepaths = [ 45 | "zip:///::https://huggingface.co/datasets/openclimatefix/gfs-reforecast/resolve/main/" 46 | + f 47 | for f in files 48 | ] 49 | self.data = xr.open_mfdataset( 50 | self.filepaths, engine="zarr", concat_dim="reftime", combine="nested" 51 | ).sortby("reftime") 52 | 53 | def __len__(self): 54 | return len(self.filepaths) 55 | 56 | def __getitem__(self, item): 57 | start_idx = np.random.randint(0, 14) 58 | data = self.data.isel(reftime=item, time=slice(start_idx, start_idx + 2)) 59 | 60 | start = data.isel(time=0) 61 | end = data.isel(time=1) 62 | # Stack the data into a large data cube 63 | input_data = np.stack( 64 | [ 65 | (start[f"{var}"].values - const.FORECAST_MEANS[f"{var}"]) 66 | / (const.FORECAST_STD[f"{var}"] + 0.0001) 67 | for var in start.data_vars 68 | if "mb" in var or "surface" in var 69 | ], 70 | axis=-1, 71 | ) 72 | input_data = np.nan_to_num(input_data) 73 | assert not np.isnan(input_data).any() 74 | output_data = np.stack( 75 | [ 76 | (end[f"{var}"].values - const.FORECAST_MEANS[f"{var}"]) 77 | / (const.FORECAST_STD[f"{var}"] + 0.0001) 78 | for var in end.data_vars 79 | if "mb" in var or "surface" in var 80 | ], 81 | axis=-1, 82 | ) 83 | output_data = np.nan_to_num(output_data) 84 | assert not np.isnan(output_data).any() 85 | transform = transforms.Compose([transforms.ToTensor()]) 86 | # Normalize now 87 | return ( 88 | transform(input_data).transpose(0, 1).reshape(-1, input_data.shape[-1]), 89 | transform(output_data).transpose(0, 1).reshape(-1, input_data.shape[-1]), 90 | ) 91 | 92 | 93 | with open("hf_forecasts.json", "r") as f: 94 | files = json.load(f) 95 | files = [ 96 | 
"zip:///::https://huggingface.co/datasets/openclimatefix/gfs-reforecast/resolve/main/" + f 97 | for f in files 98 | ] 99 | data = ( 100 | xr.open_zarr(files[0], consolidated=True).isel(time=0) 101 | # .coarsen(latitude=8, boundary="pad") 102 | # .mean() 103 | # .coarsen(longitude=8) 104 | # .mean() 105 | ) 106 | print(data) 107 | # print("Done coarsening") 108 | lat_lons = np.array(np.meshgrid(data.latitude.values, data.longitude.values)).T.reshape(-1, 2) 109 | 110 | if torch.cuda.is_available(): 111 | device = "cuda" 112 | elif torch.backends.mps.is_available(): 113 | device = "mps" 114 | else: 115 | device = "cpu" 116 | 117 | # Get the variance of the variables 118 | feature_variances = [] 119 | for var in data.data_vars: 120 | if "mb" in var or "surface" in var: 121 | feature_variances.append(const.FORECAST_DIFF_STD[var] ** 2) 122 | criterion = NormalizedMSELoss( 123 | lat_lons=lat_lons, feature_variance=feature_variances, device=device 124 | ).to(device) 125 | means = [] 126 | dataset = DataLoader(XrDataset(), batch_size=1) 127 | model = GraphWeatherForecaster(lat_lons, feature_dim=597, num_blocks=6).to(device) 128 | optimizer = optim.AdamW(model.parameters(), lr=0.001) 129 | print("Done Setup") 130 | 131 | 132 | for epoch in range(100): # loop over the dataset multiple times 133 | running_loss = 0.0 134 | start = time.time() 135 | print(f"Start Epoch: {epoch}") 136 | for i, data in enumerate(dataset): 137 | # get the inputs; data is a list of [inputs, labels] 138 | inputs, labels = data[0].to(device), data[1].to(device) 139 | # zero the parameter gradients 140 | optimizer.zero_grad() 141 | 142 | # forward + backward + optimize 143 | outputs = model(inputs) 144 | 145 | loss = criterion(outputs, labels) 146 | loss.backward() 147 | optimizer.step() 148 | 149 | # print statistics 150 | running_loss += loss.item() 151 | end = time.time() 152 | print( 153 | f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / (i + 1):.3f} Time: {end - start} sec" 154 | ) 155 | if epoch 
class WeightedMSELoss(torch.nn.Module):
    """Module WeightedMSELoss.

    This module implements the loss described in GenCast's paper: squared
    residuals weighted per latitude, per feature and per noise level.
    """

    def __init__(
        self,
        grid_lat: Optional[torch.Tensor] = None,
        pressure_levels: Optional[torch.Tensor] = None,
        num_atmospheric_features: Optional[int] = None,
        single_features_weights: Optional[torch.Tensor] = None,
    ):
        """Initialize the WeightedMSELoss Module.

        More details about the features weights are reported in GraphCast's paper. In short, if the
        single features are "2m_temperature", "10m_u_component_of_wind", "10m_v_component_of_wind",
        "mean_sea_level_pressure" and "total_precipitation_12hr", then it's suggested to set
        corresponding weights as 1, 0.1, 0.1, 0.1 and 0.1.

        Args:
            grid_lat (torch.Tensor, optional): 1D tensor containing all the latitudes (degrees).
            pressure_levels (torch.Tensor, optional): 1D tensor containing all the pressure levels
                per variable.
            num_atmospheric_features (int, optional): number of atmospheric features.
            single_features_weights (torch.Tensor, optional): 1D tensor containing single features
                weights.

        Raises:
            ValueError: if only some of ``pressure_levels``, ``num_atmospheric_features`` and
                ``single_features_weights`` are provided.
        """
        super().__init__()

        area_weights = None
        features_weights = None

        if grid_lat is not None:
            # |cos(lat)| weights, rescaled so that they average to one.
            area_weights = torch.abs(torch.cos(grid_lat * np.pi / 180.0))
            area_weights = area_weights / torch.mean(area_weights)

        feature_args = (pressure_levels, num_atmospheric_features, single_features_weights)
        if all(arg is not None for arg in feature_args):
            # Pressure weights are proportional to the level and sum to one; the
            # same weights are repeated for every atmospheric feature and
            # followed by the single-feature weights.
            pressure_weights = pressure_levels / torch.sum(pressure_levels)
            features_weights = torch.cat(
                (pressure_weights.repeat(num_atmospheric_features), single_features_weights), dim=-1
            )
        elif any(arg is not None for arg in feature_args):
            raise ValueError(
                "To use features weights, please provide all three: pressure_levels, "
                "num_atmospheric_features and single_features_weights."
            )

        self.sigma_data = 1  # assuming normalized data!

        # Non-persistent buffers: they follow the module across devices but are
        # not written to the state dict. Either may be None.
        self.register_buffer("area_weights", area_weights, persistent=False)
        self.register_buffer("features_weights", features_weights, persistent=False)

    def _lambda_sigma(self, noise_level):
        """Return the per-sample noise weight (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
        noise_weights = (noise_level**2 + self.sigma_data**2) / (noise_level * self.sigma_data) ** 2
        return noise_weights  # [batch, 1]

    def forward(
        self,
        pred: torch.Tensor,
        noise_level: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the loss.

        Args:
            pred (torch.Tensor): prediction of the model [batch, lon, lat, var].
            noise_level (torch.Tensor): noise levels fed to the model for the corresponding
                predictions [batch, 1].
            target (torch.Tensor): target tensor [batch, lon, lat, var].

        Returns:
            torch.Tensor: weighted MSE loss (scalar).

        Raises:
            ValueError: on mismatching / unexpected shapes, on NaN residuals, or when the
                weights given at initialization do not match the prediction sizes.
        """
        # check shapes
        if pred.shape != target.shape:
            raise ValueError(
                "Predictions and targets must have same shape. The actual shapes "
                f"are {pred.shape} and {target.shape}."
            )
        if len(pred.shape) != 4:
            raise ValueError(
                "The expected shape for predictions and targets is "
                f"[batch, lon, lat, var], but got {pred.shape}."
            )
        if noise_level.shape != (pred.shape[0], 1):
            raise ValueError(
                f"The expected shape for noise levels is [batch, 1], but got {noise_level.shape}."
            )

        # compute square residuals
        loss = (pred - target) ** 2  # [batch, lon, lat, var]
        if torch.isnan(loss).any():
            raise ValueError("NaN values encountered in loss calculation.")

        # apply area and features weights to residuals
        if self.area_weights is not None:
            if len(self.area_weights) != pred.shape[2]:
                raise ValueError(
                    f"The size of grid_lat at initialization ({len(self.area_weights)}) "
                    f"and the number of latitudes in predictions ({pred.shape[2]}) "
                    "don't match."
                )
            # broadcast along the latitude axis (dim 2)
            loss *= self.area_weights[None, None, :, None]

        if self.features_weights is not None:
            if len(self.features_weights) != pred.shape[-1]:
                raise ValueError(
                    f"The size of features weights at initialization ({len(self.features_weights)})"
                    f" and the number of features in predictions ({pred.shape[-1]}) "
                    "don't match."
                )
            # broadcast along the variable axis (last dim)
            loss *= self.features_weights[None, None, None, :]

        # compute means across lon, lat, var for each sample in the batch
        loss = loss.flatten(1).mean(-1)  # [batch]

        # weight each sample using the corresponding noise level, then return the mean.
        loss *= self._lambda_sigma(noise_level).flatten()
        return loss.mean()
def _grid_lat_lon_to_coordinates(
    grid_latitude: np.ndarray, grid_longitude: np.ndarray
) -> np.ndarray:
    """Lat [num_lat] lon [num_lon] to 3d coordinates [num_lat, num_lon, 3]."""
    # Spherical angles of the grid: the azimuth (phi) varies along columns and
    # the polar angle (theta, measured from the north pole) along rows, so the
    # products below broadcast to [num_latitude_points, num_longitude_points].
    azimuth = np.asanyarray(np.deg2rad(grid_longitude)).reshape(1, -1)
    polar = np.asanyarray(np.deg2rad(90 - grid_latitude)).reshape(-1, 1)
    sin_polar = np.sin(polar)

    # Note this assumes unit radius, since for now we model the earth as a
    # sphere of unit radius, and keep any vertical dimension as a regular grid.
    x = np.cos(azimuth) * sin_polar
    y = np.sin(azimuth) * sin_polar
    z = np.broadcast_to(np.cos(polar), x.shape)

    # [num_latitude_points, num_longitude_points, 3]
    return np.stack([x, y, z], axis=-1)


def radius_query_indices(
    *,
    grid_latitude: np.ndarray,
    grid_longitude: np.ndarray,
    mesh: icosahedral_mesh.TriangularMesh,
    radius: float,
) -> tuple[np.ndarray, np.ndarray]:
    """Returns mesh-grid edge indices for radius query.

    Args:
      grid_latitude: Latitude values for the grid [num_lat_points]
      grid_longitude: Longitude values for the grid [num_lon_points]
      mesh: Mesh object.
      radius: Radius of connectivity in R3. for a sphere of unit radius.

    Returns:
      tuple with `grid_indices` and `mesh_indices` indicating edges between the
      grid and the mesh such that the distances in a straight line (not geodesic)
      are smaller than or equal to `radius`.
      * grid_indices: Indices of shape [num_edges], that index into a
        [num_lat_points, num_lon_points] grid, after flattening the leading axes.
      * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
    """
    # [num_grid_points=num_lat_points * num_lon_points, 3]
    grid_positions = _grid_lat_lon_to_coordinates(grid_latitude, grid_longitude).reshape([-1, 3])

    # KD-tree over the mesh vertices, [num_mesh_points, 3]
    kd_tree = scipy.spatial.cKDTree(mesh.vertices)

    # One list of neighbouring mesh vertices per grid point; the lists have
    # varying length, so this is not a rectangular array.
    neighbours_per_point = kd_tree.query_ball_point(x=grid_positions, r=radius)

    # [num_edges] -- every grid index is repeated once per neighbour it has.
    neighbour_counts = [len(neighbours) for neighbours in neighbours_per_point]
    mesh_edge_indices = np.concatenate(list(neighbours_per_point), axis=0).astype(int)
    grid_edge_indices = np.repeat(np.arange(len(neighbour_counts)), neighbour_counts).astype(int)

    return grid_edge_indices, mesh_edge_indices


def in_mesh_triangle_indices(
    *, grid_latitude: np.ndarray, grid_longitude: np.ndarray, mesh: icosahedral_mesh.TriangularMesh
) -> tuple[np.ndarray, np.ndarray]:
    """Returns mesh-grid edge indices for grid points contained in mesh triangles.

    Args:
      grid_latitude: Latitude values for the grid [num_lat_points]
      grid_longitude: Longitude values for the grid [num_lon_points]
      mesh: Mesh object.

    Returns:
      tuple with `grid_indices` and `mesh_indices` indicating edges between the
      grid and the mesh vertices of the triangle that contain each grid point.
      The number of edges is always num_lat_points * num_lon_points * 3
      * grid_indices: Indices of shape [num_edges], that index into a
        [num_lat_points, num_lon_points] grid, after flattening the leading axes.
      * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
    """
    # [num_grid_points=num_lat_points * num_lon_points, 3]
    grid_positions = _grid_lat_lon_to_coordinates(grid_latitude, grid_longitude).reshape([-1, 3])
    num_grid_points = grid_positions.shape[0]

    surface = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)

    # Third return value: [num_grid_points] face index closest to each grid point.
    _, _, face_of_point = trimesh.proximity.closest_point(surface, grid_positions)

    # [num_grid_points, 3] vertex indices of that face, flattened to
    # [num_edges=num_grid_points*3].
    mesh_edge_indices = mesh.faces[face_of_point].reshape([-1])

    # Each grid index appears three times in a row, once per triangle vertex.
    grid_edge_indices = np.repeat(np.arange(num_grid_points), 3)

    return grid_edge_indices, mesh_edge_indices
class SparseTransformer(nn.Module):
    """Transformer block for graphs: sparse attention followed by an MLP.

    Both sub-layers are residual and each is paired with its own
    ``ConditionalLayerNorm``, applied either before (pre-norm) or after
    (post-norm) the sub-layer.
    """

    def __init__(
        self,
        conditioning_dim: int,
        input_dim: int,
        output_dim: int,
        num_heads: int,
        activation_layer: torch.nn.Module = nn.ReLU,
        norm_first: bool = True,
    ):
        """Initialize SparseTransformer module.

        Args:
            conditioning_dim (int): dimension of the conditioning parameter, forwarded to both
                ``ConditionalLayerNorm`` layers.
            input_dim (int): dimension of the input features.
            output_dim (int): dimension of the output features.
            num_heads (int): number of heads for multi-head attention.
            activation_layer (torch.nn.Module): activation class; it is instantiated without
                arguments and placed between the two linear layers of the MLP.
            norm_first (bool): if True apply layer normalization before attention and before the
                MLP, otherwise after each of them. Defaults to True.
        """
        super().__init__()

        # multi-head sparse attention
        self.sparse_attention = SparseAttention(
            input_dim=input_dim, output_dim=output_dim, num_heads=num_heads
        )

        # two-layer MLP sharing the activation instance stored on the module
        self.activation = activation_layer()
        self.mlp = nn.Sequential(
            nn.Linear(output_dim, output_dim), self.activation, nn.Linear(output_dim, output_dim)
        )

        # one conditional layer normalization per sub-layer
        self.cond_norm_1 = ConditionalLayerNorm(
            conditioning_dim=conditioning_dim, features_dim=output_dim
        )
        self.cond_norm_2 = ConditionalLayerNorm(
            conditioning_dim=conditioning_dim, features_dim=output_dim
        )

        self.norm_first = norm_first

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        cond_param: torch.Tensor,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Apply SparseTransformer to input.

        Input and conditioning parameter must have same batch size.

        Args:
            x (torch.Tensor): tensor containing nodes features.
            edge_index (torch.Tensor): edge index tensor.
            cond_param (torch.Tensor): conditioning parameter.
            *args: ignored by the module.
            **kwargs: ignored by the module.

        Returns:
            torch.Tensor: the transformed node features.
        """
        # Square [num_nodes, num_nodes] sparse adjacency built from the edges.
        num_nodes = x.shape[0]
        adjacency = dglsp.spmatrix(indices=edge_index, shape=(num_nodes, num_nodes))

        if self.norm_first:
            # pre-norm: normalize the input of each residual branch
            x = x + self.sparse_attention(x=self.cond_norm_1(x, cond_param), adj=adjacency)
            return x + self.mlp(self.cond_norm_2(x, cond_param))

        # post-norm: normalize the sum of each residual branch
        x = self.cond_norm_1(x + self.sparse_attention(x=x, adj=adjacency), cond_param)
        return self.cond_norm_2(x + self.mlp(x), cond_param)
29 | """ 30 | 31 | def __init__( 32 | self, 33 | lat_lons: list, 34 | *, 35 | channels: int, 36 | image_size, 37 | patch_size=4, 38 | depth=5, 39 | heads=4, 40 | mlp_dim=5, 41 | feature_dim: int = 605, # TODO where does this come from? 42 | lr: float = 3e-4, 43 | ): 44 | """ 45 | Initialize the LitFengWuGHR object with the required args. 46 | 47 | Args: 48 | lat_lons : List of latitude and longitude values. 49 | feature_dim : Dimensionality of the input features. 50 | aux_dim : Dimensionality of auxiliary features. 51 | hidden_dim : Dimensionality of hidden layers in the model. 52 | num_blocks : Number of graph convolutional blocks in the model. 53 | lr (float): Learning rate for optimizer. 54 | """ 55 | super().__init__() 56 | self.model = MetaModel( 57 | lat_lons, 58 | image_size=image_size, 59 | patch_size=patch_size, 60 | depth=depth, 61 | heads=heads, 62 | mlp_dim=mlp_dim, 63 | channels=channels, 64 | ) 65 | self.criterion = NormalizedMSELoss( 66 | lat_lons=lat_lons, feature_variance=np.ones((feature_dim,)) 67 | ) 68 | self.lr = lr 69 | self.save_hyperparameters() 70 | 71 | def forward(self, x): 72 | """ 73 | Forward pass . 74 | 75 | Args: 76 | x (torch.Tensor): Input tensor. 77 | 78 | Returns: 79 | torch.Tensor: Output tensor. 80 | """ 81 | return self.model(x) 82 | 83 | def training_step(self, batch, batch_idx): 84 | """ 85 | Training step. 86 | 87 | Args: 88 | batch (array): Batch of data containing input and output tensors. 89 | batch_idx (int): Index of the current batch. 90 | 91 | Returns: 92 | torch.Tensor: Loss tensor. 93 | """ 94 | x, y = batch[:, 0], batch[:, 1] 95 | if torch.isnan(x).any() or torch.isnan(y).any(): 96 | return None 97 | y_hat = self.forward(x) 98 | loss = self.criterion(y_hat, y) 99 | self.log("loss", loss, prog_bar=True) 100 | return loss 101 | 102 | def configure_optimizers(self): 103 | """ 104 | Configure the optimizer. 105 | 106 | Returns: 107 | torch.optim.Optimizer: Optimizer instance. 
class Era5Dataset(Dataset):
    """Era5 dataset yielding pairs of consecutive time steps."""

    def __init__(self, xarr, transform=None):
        """Load, rescale and reshape the data into ``[T, H*W, C]``.

        Arguments:
            xarr: object whose ``to_array()`` result converts to a 4-D numpy
                array laid out as ``[C, T, H, W]``.
            transform: accepted for API compatibility; not used.
        """
        values = torch.from_numpy(np.asarray(xarr.to_array()))

        # In-place min/max rescaling along the leading axis.
        # NOTE(review): the leading axis is the one the layout calls C
        # (presumably the variable axis), so values are scaled across
        # variables at each (T, H, W) position -- confirm this is intended.
        values.sub_(values.amin(dim=0, keepdim=True))
        values.div_(values.amax(dim=0, keepdim=True))

        # "C T H W -> T (H W) C"
        channels, steps, height, width = values.shape
        self.ds = values.permute(1, 2, 3, 0).reshape(steps, height * width, channels)

    def __len__(self):
        """Number of consecutive (t, t + 1) pairs."""
        return len(self.ds) - 1

    def __getitem__(self, index):
        """Return time steps ``index`` and ``index + 1`` stacked on the first axis."""
        return self.ds[index : index + 2]
patch_size=patch_size, 176 | depth=5, 177 | heads=4, 178 | mlp_dim=5, 179 | ) 180 | trainer = pl.Trainer( 181 | accelerator="gpu", 182 | devices=-1, 183 | max_epochs=100, 184 | precision="16-mixed", 185 | callbacks=[checkpoint_callback], 186 | log_every_n_steps=3, 187 | ) 188 | 189 | trainer.fit(model, dset) 190 | 191 | torch.save(model.model.state_dict(), ckpt_path / "best.pt") 192 | -------------------------------------------------------------------------------- /graph_weather/models/analysis.py: -------------------------------------------------------------------------------- 1 | """Model for forecasting weather from NWP states""" 2 | 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | from huggingface_hub import PyTorchModelHubMixin 7 | 8 | from graph_weather.models import AssimilatorDecoder, AssimilatorEncoder, Processor 9 | 10 | 11 | @dataclass 12 | class GraphWeatherAssimilatorConfig: 13 | """Configuration for GraphWeatherAssimilator model.""" 14 | 15 | output_lat_lons: list 16 | resolution: int = 2 17 | observation_dim: int = 2 18 | analysis_dim: int = 78 19 | node_dim: int = 256 20 | edge_dim: int = 256 21 | num_blocks: int = 9 22 | hidden_dim_processor_node: int = 256 23 | hidden_dim_processor_edge: int = 256 24 | hidden_layers_processor_node: int = 2 25 | hidden_layers_processor_edge: int = 2 26 | hidden_dim_decoder: int = 128 27 | hidden_layers_decoder: int = 2 28 | norm_type: str = "LayerNorm" 29 | use_checkpointing: bool = False 30 | 31 | def build(self) -> "GraphWeatherAssimilator": 32 | """Build GraphWeatherAssimilator from this configuration.""" 33 | return GraphWeatherAssimilator( 34 | output_lat_lons=self.output_lat_lons, 35 | resolution=self.resolution, 36 | observation_dim=self.observation_dim, 37 | analysis_dim=self.analysis_dim, 38 | node_dim=self.node_dim, 39 | edge_dim=self.edge_dim, 40 | num_blocks=self.num_blocks, 41 | hidden_dim_processor_node=self.hidden_dim_processor_node, 42 | 
class GraphWeatherAssimilator(torch.nn.Module, PyTorchModelHubMixin):
    """Encoder / processor / decoder model producing an analysis from raw observations"""

    def __init__(
        self,
        output_lat_lons: list,
        resolution: int = 2,
        observation_dim: int = 2,
        analysis_dim: int = 78,
        node_dim: int = 256,
        edge_dim: int = 256,
        num_blocks: int = 9,
        hidden_dim_processor_node: int = 256,
        hidden_dim_processor_edge: int = 256,
        hidden_layers_processor_node: int = 2,
        hidden_layers_processor_edge: int = 2,
        hidden_dim_decoder: int = 128,
        hidden_layers_decoder: int = 2,
        norm_type: str = "LayerNorm",
        use_checkpointing: bool = False,
    ):
        """
        Graph Weather Data Assimilation model

        Args:
            output_lat_lons: List of latitude and longitudes for the output analysis
            resolution: Resolution of the H3 grid (passed to encoder and decoder);
                even resolutions are preferred
            observation_dim: Input feature size of the observations
            analysis_dim: Output analysis feature dimension
            node_dim: Node hidden dimension
            edge_dim: Edge hidden dimension
            num_blocks: Number of message passing blocks in the Processor
            hidden_dim_processor_node: Hidden dimension of the node processors
            hidden_dim_processor_edge: Hidden dimension of the edge processors
            hidden_layers_processor_node: Number of hidden layers in the node processors
            hidden_layers_processor_edge: Number of hidden layers in the edge processors
            hidden_dim_decoder: Hidden dimension of the decoder
            hidden_layers_decoder: Number of layers in the decoder
            norm_type: Type of norm for the MLPs, one of 'LayerNorm', 'GraphNorm',
                'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
            use_checkpointing: Whether to use gradient checkpointing or not
                (forwarded to the encoder and decoder only)
        """
        super().__init__()

        # Keyword arguments passed unchanged to all three sub-modules.
        mlp_settings = dict(
            hidden_dim_processor_edge=hidden_dim_processor_edge,
            hidden_layers_processor_node=hidden_layers_processor_node,
            hidden_dim_processor_node=hidden_dim_processor_node,
            hidden_layers_processor_edge=hidden_layers_processor_edge,
            mlp_norm_type=norm_type,
        )

        self.encoder = AssimilatorEncoder(
            resolution=resolution,
            input_dim=observation_dim,
            output_dim=node_dim,
            output_edge_dim=edge_dim,
            use_checkpointing=use_checkpointing,
            **mlp_settings,
        )
        self.processor = Processor(
            input_dim=node_dim,
            edge_dim=edge_dim,
            num_blocks=num_blocks,
            **mlp_settings,
        )
        self.decoder = AssimilatorDecoder(
            lat_lons=output_lat_lons,
            resolution=resolution,
            input_dim=node_dim,
            output_dim=analysis_dim,
            output_edge_dim=edge_dim,
            hidden_dim_decoder=hidden_dim_decoder,
            hidden_layers_decoder=hidden_layers_decoder,
            use_checkpointing=use_checkpointing,
            **mlp_settings,
        )

    def forward(self, features: torch.Tensor, obs_lat_lon_heights: torch.Tensor) -> torch.Tensor:
        """
        Compute the analysis output

        Args:
            features: The input features, aligned with the order of obs_lat_lon_heights
            obs_lat_lon_heights: Observation lat/lon/heights in same order as features

        Returns:
            The decoder output for the encoded and processed observations
        """
        latent, edge_idx, edge_attr = self.encoder(features, obs_lat_lon_heights)
        latent = self.processor(latent, edge_idx, edge_attr)
        # The decoder also receives the size of the leading dimension of ``features``.
        return self.decoder(latent, features.shape[0])
def test_gencast_graph():
    """Check GraphBuilder node counts and compare its k-hop mesh with PyG's TwoHop."""
    graphs = GraphBuilder(
        grid_lon=np.arange(0, 360, 1),
        grid_lat=np.arange(-90, 90, 1),
        splits=4,
        num_hops=8,
    )

    # Reference k-hop mesh from PyG: three TwoHop passes give the 8-hop mesh.
    two_hop = TwoHop()
    reference = graphs.mesh_graph
    for _ in range(3):
        reference = two_hop(reference)

    mesh_node_count = 2562
    assert graphs.mesh_graph.x.shape[0] == mesh_node_count
    assert graphs.g2m_graph["grid_nodes"].x.shape[0] == 360 * 180
    assert graphs.m2g_graph["mesh_nodes"].x.shape[0] == mesh_node_count
    assert not torch.isnan(graphs.mesh_graph.edge_attr).any()
    assert graphs.khop_mesh_graph.x.shape[0] == mesh_node_count
    assert torch.allclose(graphs.khop_mesh_graph.x, reference.x)
    assert torch.allclose(graphs.khop_mesh_graph.edge_index, reference.edge_index)
def test_gencast_denoiser():
    """Run one no-grad Denoiser forward pass on random inputs and check for NaNs."""
    grid_lat = np.arange(-90, 90, 1)
    grid_lon = np.arange(0, 360, 1)
    in_dim, out_dim, batch = 10, 5, 3

    model = Denoiser(
        grid_lon=grid_lon,
        grid_lat=grid_lat,
        input_features_dim=in_dim,
        output_features_dim=out_dim,
        hidden_dims=[16, 32],
        num_blocks=3,
        num_heads=4,
        splits=0,
        num_hops=1,
        device=torch.device("cpu"),
    )
    model = model.eval()

    # Leading (batch, lon, lat) dimensions shared by both grid-shaped inputs.
    leading = (batch, len(grid_lon), len(grid_lat))
    corrupted_targets = torch.randn((*leading, out_dim))
    prev_inputs = torch.randn((*leading, 2 * in_dim))
    noise_levels = torch.rand((batch, 1))

    with torch.no_grad():
        preds = model(
            corrupted_targets=corrupted_targets,
            prev_inputs=prev_inputs,
            noise_levels=noise_levels,
        )

    assert not torch.isnan(preds).any()
@pytest.mark.skipif(
    Version(torch.__version__).release != Version("2.3.0").release,
    reason="dgl tests for experimental features only runs with torch 2.3.0",
)
def test_gencast_full():
    """Sample from the pretrained 128x64 Denoiser and check the output shape and NaNs."""
    # Grid definition passed to ``from_pretrained`` together with the HF repo id.
    grid_lon = np.arange(0, 360, 360 / 128)
    grid_lat = np.arange(-90, 90, 180 / 64) + 1 / 2 * 180 / 64

    # Weights are downloaded from the Hugging Face hub.
    pretrained = Denoiser.from_pretrained(
        "openclimatefix/gencast-128x64",
        grid_lon=grid_lon,
        grid_lat=grid_lat,
    )

    # Random stand-ins for the inputs and the target residuals.
    prev_inputs = torch.randn([1, 128, 64, 178])
    target_residuals = torch.randn([1, 128, 64, 83])

    preds = Sampler().sample(pretrained, prev_inputs)

    assert not torch.isnan(preds).any()
    assert preds.shape == target_residuals.shape
def load_nnja_dataset(
    dataset_name: str,
    time=None,
    variables: list[str] | None = None,
    load_all: bool = False,
) -> xr.Dataset:
    """
    Load a NNJA dataset as an xarray.Dataset with time as the only coordinate.

    Args:
        dataset_name: Name of NNJA dataset to load
        time: Time selection (single timestamp, slice, or None)
        variables: Specific variables to load (overrides default). The list is
            copied; the caller's object is never modified.
        load_all: Load all available variables in the dataset (takes precedence
            over ``variables``)

    Returns:
        xarray.Dataset with only 'time' dimension/coordinate

    Raises:
        ValueError: If the dataset is not in the catalog or an unknown variable
            is requested.
        RuntimeError: If loading fails or no 'time' dimension can be established.
    """
    try:
        cat = DataCatalog()
        ds_meta = cat[dataset_name]
        ds_meta.load_manifest()
    except KeyError as e:
        raise ValueError(f"Dataset '{dataset_name}' not found in catalog") from e

    vars_dict = ds_meta.variables
    if load_all:
        vars_to_load = list(vars_dict.keys())
    elif variables:
        # Validate requested variables
        invalid_vars = [v for v in variables if v not in vars_dict]
        if invalid_vars:
            raise ValueError(f"Invalid variables requested: {invalid_vars}")
        # Copy so that nothing done here can alias or mutate the caller's list.
        vars_to_load = list(variables)
    else:
        # Default: primary descriptors + primary data
        primary_categories = (
            "primary_descriptor",
            "primary_data",
            "primary descriptor",
            "primary data",
        )
        vars_to_load = [
            name for name, v in vars_dict.items() if _classify_variable(v) in primary_categories
        ]

    try:
        df = ds_meta.sel(time=time, variables=vars_to_load).load_dataset(
            backend="pandas", engine="pyarrow"
        )
    except Exception as e:
        raise RuntimeError(f"Error loading dataset '{dataset_name}': {str(e)}") from e

    xrds = df.to_xarray()

    # Standardize coordinate names.
    # NOTE(review): only columns present in the loaded data are renamed; the coordinate
    # columns are not added to the request. Confirm whether OBS_TIMESTAMP / LAT / LON
    # should be requested implicitly before loading.
    rename_map = {"OBS_TIMESTAMP": "time", "LAT": "latitude", "LON": "longitude"}
    xrds = xrds.rename({k: v for k, v in rename_map.items() if k in xrds})

    # Ensure 'time' coordinate exists
    if "time" not in xrds and "OBS_DATE" in xrds:
        xrds = xrds.rename({"OBS_DATE": "time"})

    # Handle time conversion if needed
    if "time" in xrds and not np.issubdtype(xrds.time.dtype, np.datetime64):
        xrds["time"] = xrds.time.astype("datetime64[ns]")

    # If time is not a dimension but 'obs' is, swap
    if "time" in xrds and "obs" in xrds.dims and "time" not in xrds.dims:
        xrds = xrds.swap_dims({"obs": "time"})
        if "obs" in xrds.coords:
            xrds = xrds.reset_coords("obs", drop=True)

    if "time" in xrds and "time" not in xrds.coords:
        xrds = xrds.set_coords("time")

    # Flatten extra dimensions into time as may encounter an extra "index" dimension
    # Ensures output is always 1D along "time"
    extra_dims = [d for d in xrds.dims if d != "time"]
    if extra_dims:
        # Remember the time values before stacking so they can be restored afterwards.
        time_values = xrds.time.values if "time" in xrds else None
        xrds = xrds.stack(sample=tuple(extra_dims))
        xrds = xrds.reset_index("sample")

        # Rename to time and restore original time values
        if "sample" in xrds.dims:
            xrds = xrds.swap_dims({"sample": "time"})
        if "sample" in xrds.coords:
            xrds = xrds.reset_coords("sample", drop=True)
        if time_values is not None:
            xrds["time"] = ("time", time_values)

    if "time" not in xrds.dims:
        raise RuntimeError("Failed to establish 'time' dimension in output dataset")

    return xrds
class NNJATorchDataset(Dataset):
    """Adapter for torch Dataset directly from xarray."""

    def __init__(self, xrds):
        """Initialize adapter.

        Args:
            xrds: xarray Dataset to convert; must expose a ``time`` dimension
        """
        self.ds = xrds
        self.vars = list(xrds.data_vars.keys())
        # Kept as a public attribute for callers that inspect the time axis.
        self.time_index = xrds.time.values

    def __len__(self):
        """Return the number of entries along the ``time`` dimension."""
        return self.ds.sizes["time"]

    def __getitem__(self, idx):
        """Return ``{variable: scalar}`` for the ``idx``-th entry along ``time``.

        Selection is positional: a label lookup on the timestamp matches every
        row sharing that timestamp, so ``.item()`` would fail on repeated times.
        """
        return {var: self.ds[var].isel(time=idx).item() for var in self.vars}
class AnalysisDataset(Dataset):
    """
    Dataset class for analysis data.

    Each item pairs the state stored at ``filepaths[i]`` with the state at
    ``filepaths[i + 1]`` (after sorting), each augmented with sin/cos of the
    lat/lon grid, solar irradiance features and the invariant fields.

    Args:
        filepaths: List of file paths.
        invariant_path: Path to the invariant file.
        mean: Mean value.
        std: Standard deviation value.
        coarsen: Coarsening factor. Defaults to 8.

    Methods:
        __init__: Initialize the AnalysisDataset object.
        __len__: Get the length of the dataset.
        __getitem__: Get an item from the dataset.
    """

    def __init__(self, filepaths, invariant_path, mean, std, coarsen: int = 8):
        """
        Initialize the AnalysisDataset object.
        """
        super().__init__()
        self.filepaths = sorted(filepaths)
        self.invariant_path = invariant_path
        self.coarsen = coarsen
        # NOTE: mean / std are stored but not applied anywhere in this class.
        self.mean = mean
        self.std = std

    def __len__(self):
        """Return the number of consecutive (start, end) file pairs."""
        return len(self.filepaths) - 1

    def _open_state(self, path):
        """Open one zarr store, coarsening latitude then longitude when ``coarsen > 1``."""
        state = xr.open_zarr(path, consolidated=True)
        if self.coarsen <= 1:  # Don't coarsen, so don't even call it
            return state
        return (
            state.coarsen(latitude=self.coarsen, boundary="pad")
            .mean()
            .coarsen(longitude=self.coarsen)
            .mean()
        )

    @staticmethod
    def _solar_features(date, lat_lons):
        """Return standardised irradiance at ``date`` and hourly from -12 h to +12 h."""
        irradiance = [np.array([extraterrestrial_irrad(date, lat, lon) for lat, lon in lat_lons])]
        for when in pd.date_range(
            date - pd.Timedelta("12 hours"),
            date + pd.Timedelta("12 hours"),
            freq=pd.Timedelta("1 hour"),
        ):
            irradiance.append(
                np.array([extraterrestrial_irrad(when, lat, lon) for lat, lon in lat_lons])
            )
        features = np.array(irradiance)
        # Standardise with the solar mean / std constants.
        features -= const.SOLAR_MEAN
        features /= const.SOLAR_STD
        return features

    @staticmethod
    def _stack_data_vars(dataset):
        """Stack every NaN-free data variable of ``dataset`` along a new last axis."""
        return np.stack(
            [
                dataset[f"{var}"].values
                for var in dataset.data_vars
                if not np.isnan(dataset[f"{var}"].values).any()
            ],
            axis=-1,
        )

    def __getitem__(self, item):
        """Return ``(input_data, output_data)`` NumPy arrays for file pair ``item``."""
        start = self._open_state(self.filepaths[item])
        end = self._open_state(self.filepaths[item + 1])

        # Land-sea mask data, resampled to the same as the physical variables
        landsea = (
            xr.open_zarr(self.invariant_path, consolidated=True)
            .interp(latitude=start.latitude.values)
            .interp(longitude=start.longitude.values)
        )
        landsea = np.stack(
            [
                (landsea[f"{var}"].values - const.LANDSEA_MEAN[var]) / const.LANDSEA_STD[var]
                for var in landsea.data_vars
                if not np.isnan(landsea[f"{var}"].values).any()
            ],
            axis=-1,
        )
        landsea = landsea.T.reshape((-1, landsea.shape[-1]))
        lat_lons = np.array(np.meshgrid(start.latitude.values, start.longitude.values)).T.reshape(
            (-1, 2)
        )
        # NOTE(review): np.sin / np.cos take radians; confirm the units of the coordinates.
        sin_lat_lons = np.sin(lat_lons)
        cos_lat_lons = np.cos(lat_lons)

        # Solar irradiance for the start state and, separately, for the end state.
        solar_times = self._solar_features(start.time.dt.values, lat_lons)
        end_solar_times = self._solar_features(end.time.dt.values, lat_lons)

        # Stack the data into a large data cube
        input_data = self._stack_data_vars(start)
        # TODO Combine with above? And include sin/cos of day of year
        # NOTE(review): the solar arrays are built as (n_times, n_points); confirm their
        # layout matches the other per-point arrays before concatenating on the last axis.
        input_data = np.concatenate(
            [
                input_data.T.reshape((-1, input_data.shape[-1])),
                sin_lat_lons,
                cos_lat_lons,
                solar_times,
                landsea,
            ],
            axis=-1,
        )
        # Not want to predict non-physics variables -> Output only the data variables?
        # Would be simpler, and just add in the new ones each time

        output_data = self._stack_data_vars(end)
        output_data = np.concatenate(
            [
                output_data.T.reshape((-1, output_data.shape[-1])),
                sin_lat_lons,
                cos_lat_lons,
                end_solar_times,
                landsea,
            ],
            axis=-1,
        )
        # Stick with Numpy, don't tensor it
        return input_data, output_data
class AnemoiDataset(Dataset):
    """
    Dataset class for Anemoi datasets integration with graph_weather.

    Args:
        dataset_name: Name of the Anemoi dataset (e.g., "era5-o48-2020-2021-6h-v1")
        features: List of atmospheric variables to use
        means: Dict of means for each feature (required)
        stds: Dict of stddevs for each feature (required)
        time_range: Optional tuple of (start_date, end_date)
        time_step: Time step between input and target (default: 1)
        max_samples: Maximum number of samples to use (for testing)
        **kwargs: Extra entries merged into the configuration passed to ``open_dataset``
    """

    def __init__(
        self,
        dataset_name: str,
        features: list[str],
        means: dict,
        stds: dict,
        time_range: tuple = None,
        time_step: int = 1,
        max_samples: int = None,
        **kwargs,
    ):
        """Validate the statistics, open the dataset and locate its lat/lon coordinates."""
        super().__init__()

        # Stored so that get_dataset_info() can report the real name.
        self.dataset_name = dataset_name
        self.features = features
        self.time_step = time_step
        self.max_samples = max_samples
        self.means = means
        self.stds = stds

        # Validate that normalization stats are provided for all features
        missing_means = [f for f in self.features if f not in self.means]
        missing_stds = [f for f in self.features if f not in self.stds]
        if missing_means or missing_stds:
            raise ValueError(
                f"Normalization statistics missing for features: "
                f"means missing: {missing_means}, stds missing: {missing_stds}"
            )

        # Build Anemoi dataset configuration
        config = {"dataset": dataset_name}
        if time_range:
            config["start"] = time_range[0]
            config["end"] = time_range[1]
        config.update(kwargs)

        # Load the dataset, preferring an xarray-style view when one is offered
        try:
            self.dataset = open_dataset(config)
            if hasattr(self.dataset, "to_xarray"):
                self.data = self.dataset.to_xarray()
            elif hasattr(self.dataset, "to_dataset"):
                self.data = self.dataset.to_dataset()
            else:
                self.data = self.dataset
            logging.info(f"Successfully loaded Anemoi dataset: {dataset_name}")
        except Exception as e:
            # Chain the original error so its traceback is preserved.
            raise RuntimeError(
                f"Failed to load Anemoi dataset '{dataset_name}': {e}. "
                "Please ensure the dataset is available and properly configured."
            ) from e

        # Validate that we have the required features
        missing_features = [f for f in self.features if f not in self.data.data_vars]
        if missing_features:
            available_features = list(self.data.data_vars.keys())
            raise ValueError(
                f"Features {missing_features} not found in dataset. "
                f"Available features: {available_features}"
            )

        # Get grid information - try multiple coordinate name variations
        self.grid_lat = self._first_coord(("latitude", "lat", "y"))
        self.grid_lon = self._first_coord(("longitude", "lon", "x"))

        if self.grid_lat is None or self.grid_lon is None:
            available_coords = list(self.data.coords.keys())
            raise ValueError(
                f"Could not find latitude/longitude coordinates in dataset. "
                f"Available coordinates: {available_coords}"
            )

        self.num_lat = len(self.grid_lat)
        self.num_lon = len(self.grid_lon)

    def _first_coord(self, names):
        """Return the first coordinate of ``self.data`` whose name is in ``names``, else None."""
        for name in names:
            if name in self.data.coords:
                return self.data.coords[name]
        return None

    def _normalize(self, data, feature):
        """Normalize data using feature-specific statistics"""
        if feature not in self.means or feature not in self.stds:
            raise ValueError(f"Normalization stats for feature '{feature}' not provided.")
        # The small epsilon guards against a zero standard deviation.
        return (data - self.means[feature]) / (self.stds[feature] + 1e-6)

    def _generate_clock_features(self, data_time, num_locations=None):
        """Generate time features following GenCast pattern

        Args:
            data_time: Timestamp-like object, or an object whose ``.values`` can be
                converted with ``pd.Timestamp``.
            num_locations: Number of rows to produce. Defaults to
                ``num_lat * num_lon``.

        Returns:
            Array of shape ``(num_locations, 4)`` holding sin/cos of the day of
            year and sin/cos of the hour of day, repeated for every row.
        """
        if hasattr(data_time, "values"):
            timestamp = pd.Timestamp(data_time.values)
        else:
            timestamp = data_time

        # Leap year aware normalization
        year = timestamp.year
        is_leap = year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
        days_in_year = 366.0 if is_leap else 365.0
        day_angle = 2 * np.pi * (timestamp.dayofyear / days_in_year)
        hour_angle = 2 * np.pi * (timestamp.hour / 24.0)

        if num_locations is None:
            num_locations = self.num_lat * self.num_lon

        clock = (np.sin(day_angle), np.cos(day_angle), np.sin(hour_angle), np.cos(hour_angle))
        return np.column_stack([np.full(num_locations, value) for value in clock])

    def __len__(self):
        """Return the number of (input, target) pairs, capped by ``max_samples`` if set."""
        total_length = len(self.data.time) - self.time_step
        if self.max_samples:
            total_length = min(total_length, self.max_samples)
        # len() must never be negative, e.g. when time_step exceeds the time axis.
        return max(total_length, 0)

    def __getitem__(self, idx):
        """Return normalized ``(input, target)`` float32 arrays for time index ``idx``."""
        input_data_slice = self.data.isel(time=idx)
        target_data_slice = self.data.isel(time=idx + self.time_step)

        input_columns = []
        target_columns = []
        for feature in self.features:
            input_columns.append(
                self._normalize(input_data_slice[feature].values.reshape(-1), feature)
            )
            target_columns.append(
                self._normalize(target_data_slice[feature].values.reshape(-1), feature)
            )

        # One column per feature, one row per flattened grid point.
        input_data = np.stack(input_columns, axis=1)
        target_data = np.stack(target_columns, axis=1)

        # Size the clock features to the actual number of rows so the concatenation
        # also works when the rows are not a num_lat x num_lon product.
        time_features = self._generate_clock_features(
            input_data_slice.time, num_locations=input_data.shape[0]
        )
        input_data = np.concatenate([input_data, time_features], axis=1)
        target_data = np.concatenate([target_data, time_features], axis=1)

        return input_data.astype(np.float32), target_data.astype(np.float32)

    def get_dataset_info(self):
        """Return information about the loaded dataset"""
        return {
            "dataset_name": getattr(self, "dataset_name", "unknown"),
            "features": self.features,
            "grid_shape": (self.num_lat, self.num_lon),
            "time_steps": len(self.data.time),
            "dataset_length": len(self),
            "normalization_stats": {"means": self.means, "stds": self.stds},
        }
6 | """ 7 | 8 | import torch 9 | 10 | from graph_weather.models.gencast.layers.modules import MLP, CondTransformerBlock 11 | 12 | try: 13 | from graph_weather.models.gencast.layers.experimental import SparseTransformer 14 | 15 | has_dgl = True 16 | except ImportError: 17 | has_dgl = False 18 | 19 | 20 | class Processor(torch.nn.Module): 21 | """GenCast's Processor 22 | 23 | The Processor is a sequence of transformer blocks conditioned on noise level. If the graph has 24 | many edges, setting sparse=True may perform better in terms of memory and speed. Note that 25 | sparse=False uses PyG as the backend, while sparse=True uses DGL. The two implementations are 26 | not exactly equivalent: the former is described in the paper "Masked Label Prediction: Unified 27 | Message Passing Model for Semi-Supervised Classification" and can also handle edge features, 28 | while the latter is a classical transformer that performs multi-head attention utilizing the 29 | mask's sparsity and does not include edge features in the computations. 30 | 31 | Note: The GenCast paper does not provide specific details regarding the implementation of the 32 | transformer architecture for graphs. 33 | """ 34 | 35 | def __init__( 36 | self, 37 | latent_dim: int, 38 | hidden_dims: list[int], 39 | num_blocks: int, 40 | num_heads: int, 41 | noise_emb_dim: int, 42 | edges_dim: int | None = None, 43 | activation_layer: torch.nn.Module = torch.nn.ReLU, 44 | use_layer_norm: bool = True, 45 | sparse: bool = False, 46 | ): 47 | """Initialize the Processor. 48 | 49 | Args: 50 | latent_dim (int): dimension of nodes' features. 51 | hidden_dims (list[int]): hidden dimensions of internal MLPs. 52 | num_blocks (int): number of transformer blocks. 53 | num_heads (int): number of heads for multi-head attention. 54 | num_frequencies (int): number of frequencies for the noise Fourier embedding. 55 | base_period (int): base period for the noise Fourier embedding. 
56 | noise_emb_dim (int): dimension of output of noise embedding. 57 | edges_dim (int, optional): dimension of edges' features. If None does not uses edges 58 | features in TransformerConv. Defaults to None. 59 | activation_layer (torch.nn.Module): activation function of internal MLPs. 60 | Defaults to torch.nn.ReLU. 61 | use_layer_norm (bool): if true add a LayerNorm at the end of the embedding MLP. 62 | Defaults to True. 63 | sparse (bool): if true use DGL as backend (experimental). Defaults to False. 64 | """ 65 | super().__init__() 66 | self.latent_dim = latent_dim 67 | if latent_dim % num_heads != 0: 68 | raise ValueError("The latent dimension should be divisible by the number of heads.") 69 | 70 | self.edges_dim = edges_dim 71 | if edges_dim is not None: 72 | self.edges_mlp = MLP( 73 | input_dim=edges_dim, 74 | hidden_dims=hidden_dims, 75 | activation_layer=activation_layer, 76 | use_layer_norm=use_layer_norm, 77 | bias=True, 78 | activate_final=False, 79 | ) 80 | 81 | # Tranformers Blocks 82 | self.cond_transformers = torch.nn.ModuleList() 83 | if not sparse: 84 | for _ in range(num_blocks - 1): 85 | # concatenating multi-head attention 86 | self.cond_transformers.append( 87 | CondTransformerBlock( 88 | conditioning_dim=noise_emb_dim, 89 | input_dim=latent_dim, 90 | output_dim=latent_dim // num_heads, 91 | edges_dim=hidden_dims[-1] if (edges_dim is not None) else None, 92 | num_heads=num_heads, 93 | concat=True, 94 | beta=True, 95 | activation_layer=activation_layer, 96 | ) 97 | ) 98 | 99 | # averaging multi-head attention 100 | self.cond_transformers.append( 101 | CondTransformerBlock( 102 | conditioning_dim=noise_emb_dim, 103 | input_dim=latent_dim, 104 | output_dim=latent_dim, 105 | edges_dim=hidden_dims[-1] if (edges_dim is not None) else None, 106 | num_heads=num_heads, 107 | concat=False, 108 | beta=True, 109 | activation_layer=None, 110 | ) 111 | ) 112 | else: 113 | if not has_dgl: 114 | raise ValueError("Please install DGL to use sparsity.") 115 | 
116 | for _ in range(num_blocks): 117 | # concatenating multi-head attention 118 | self.cond_transformers.append( 119 | SparseTransformer( 120 | conditioning_dim=noise_emb_dim, 121 | input_dim=latent_dim, 122 | output_dim=latent_dim, 123 | num_heads=num_heads, 124 | activation_layer=activation_layer, 125 | ) 126 | ) 127 | # do we really need averaging for last block? 128 | 129 | def _check_args(self, latent_mesh_nodes, noise_levels, input_edge_attr): 130 | if not latent_mesh_nodes.shape[-1] == self.latent_dim: 131 | raise ValueError( 132 | "The dimension of the mesh nodes is different from the latent dimension provided at" 133 | " initialization." 134 | ) 135 | 136 | if not latent_mesh_nodes.shape[0] == noise_levels.shape[0]: 137 | raise ValueError( 138 | "The number of noise levels and mesh nodes should be the same, but got " 139 | f"{latent_mesh_nodes.shape[0]} and {noise_levels.shape[0]}. Eventually repeat the " 140 | " noise level for each node in the same batch." 141 | ) 142 | 143 | if (input_edge_attr is not None) and (self.edges_dim is None): 144 | raise ValueError("To use input_edge_attr initialize the processor with edges_dim.") 145 | 146 | def forward( 147 | self, 148 | latent_mesh_nodes: torch.Tensor, 149 | edge_index: torch.Tensor, 150 | noise_vector: torch.Tensor, 151 | input_edge_attr: torch.Tensor | None = None, 152 | ) -> torch.Tensor: 153 | """Forward pass. 154 | 155 | Args: 156 | latent_mesh_nodes (torch.Tensor): mesh nodes' features. 157 | edge_index (torch.Tensor): edge index tensor. 158 | noise_vector (torch.Tensor): noise vector for conditioning 159 | input_edge_attr (torch.Tensor, optional): mesh edges' features. 160 | 161 | Returns: 162 | torch.Tensor: latent mesh nodes. 
163 | """ 164 | self._check_args(latent_mesh_nodes, noise_vector, input_edge_attr) 165 | 166 | if self.edges_dim is not None: 167 | edges_emb = self.edges_mlp(input_edge_attr) 168 | else: 169 | edges_emb = None 170 | 171 | # apply transformer blocks 172 | for cond_transformer in self.cond_transformers: 173 | latent_mesh_nodes = cond_transformer( 174 | x=latent_mesh_nodes, 175 | edge_index=edge_index, 176 | cond_param=noise_vector, 177 | edge_attr=edges_emb, 178 | ) 179 | 180 | return latent_mesh_nodes 181 | -------------------------------------------------------------------------------- /graph_weather/models/layers/assimilator_decoder.py: -------------------------------------------------------------------------------- 1 | """Decoders to decode from the Processor graph to the original graph with updated values 2 | 3 | In the original paper the decoder is described as 4 | 5 | The Decoder maps back to physical data defined on a latitude/longitude grid. The underlying graph is 6 | again bipartite, this time mapping icosahedron→lat/lon. 7 | The inputs to the Decoder come from the Processor, plus a skip connection back to the original 8 | state of the 78 atmospheric variables onthe latitude/longitude grid. 9 | The output of the Decoder is the predicted 6-hour change in the 78 atmospheric variables, 10 | which is then added to the initial state to produce the new state. 
class AssimilatorDecoder(torch.nn.Module):
    """Assimilator graph module

    Builds a bipartite graph from H3 cells to the given lat/lon points and uses it to decode
    processor features on the H3 nodes into values on the lat/lon points.
    """

    def __init__(
        self,
        lat_lons: list,
        resolution: int = 2,
        input_dim: int = 256,
        output_dim: int = 78,
        output_edge_dim: int = 256,
        hidden_dim_processor_node: int = 256,
        hidden_dim_processor_edge: int = 256,
        hidden_layers_processor_node: int = 2,
        hidden_layers_processor_edge: int = 2,
        mlp_norm_type: str = "LayerNorm",
        hidden_dim_decoder: int = 128,
        hidden_layers_decoder: int = 2,
        use_checkpointing: bool = False,
    ):
        """
        Decoder from latent graph to lat/lon graph for assimilation of observation

        Args:
            lat_lons: List of (lat,lon) points
            resolution: H3 resolution level
            input_dim: Input node dimension
            output_dim: Output node dimension
            output_edge_dim: Edge dimension
            hidden_dim_processor_node: Hidden dimension of the node processors
            hidden_dim_processor_edge: Hidden dimension of the edge processors
            hidden_layers_processor_node: Number of hidden layers in the node processors
            hidden_layers_processor_edge: Number of hidden layers in the edge processors
            hidden_dim_decoder: Number of hidden dimensions in the decoder
            hidden_layers_decoder: Number of layers in the decoder
            mlp_norm_type: Type of norm for the MLPs
                one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
            use_checkpointing: Whether to use gradient checkpointing; forwarded to the edge
                encoder and node decoder MLPs
        """
        super().__init__()
        self.use_checkpointing = use_checkpointing
        self.num_latlons = len(lat_lons)
        self.base_h3_grid = sorted(h3.uncompact_cells(h3.get_res0_cells(), resolution))
        self.num_h3 = len(self.base_h3_grid)
        self.h3_grid = [h3.latlng_to_cell(lat, lon, resolution) for lat, lon in lat_lons]

        # H3 cells get node indices in descending order of their sorted position. This ordering
        # fixes which node index each H3 cell is wired to below, so it is kept exactly as is.
        self.h3_to_index = {}
        h_index = len(self.base_h3_grid)
        for h in self.base_h3_grid:
            if h not in self.h3_to_index:
                h_index -= 1
                self.h3_to_index[h] = h_index

        # Lat/lon nodes are numbered after the H3 nodes: node (num_h3 + i) is lat_lons[i].
        self.h3_mapping = {
            position + self.num_h3: cell for position, cell in enumerate(self.h3_grid)
        }

        # Build the default graph
        # Extra starting zeros for appending to inputs, could 'learn' good starting points
        self.latlon_nodes = torch.zeros((len(lat_lons), input_dim), dtype=torch.float)

        # Get connections between lat nodes and h3 nodes.
        # TODO Paper makes it seem like the 3 closest iso points map to the lat/lon point.
        # Do kring 1 around current h3 cell, and calculate distance between all those points and
        # the lat/lon one, choosing the nearest N (3).
        # For a bit simpler, just include them all with their distances.
        edge_sources = []
        edge_targets = []
        distances = []
        for node_index, h_node in enumerate(self.h3_grid):
            # The cell containing the lat/lon point plus its ring-1 neighbours
            for h in h3.grid_disk(h_node, 1):
                distance = h3.great_circle_distance(
                    lat_lons[node_index], h3.cell_to_latlng(h), unit="rads"
                )
                distances.append([np.sin(distance), np.cos(distance)])
                edge_sources.append(self.h3_to_index[h])
                edge_targets.append(node_index + self.num_h3)
        edge_index = torch.tensor([edge_sources, edge_targets], dtype=torch.long)
        self.h3_to_lat_distances = torch.tensor(distances, dtype=torch.float)

        # Use normal graph as its a bit simpler
        self.graph = Data(edge_index=edge_index, edge_attr=self.h3_to_lat_distances)

        self.edge_encoder = MLP(
            2, output_edge_dim, hidden_dim_processor_edge, 2, mlp_norm_type, self.use_checkpointing
        )
        self.graph_processor = GraphProcessor(
            mp_iterations=1,
            in_dim_node=input_dim,
            in_dim_edge=output_edge_dim,
            hidden_dim_node=hidden_dim_processor_node,
            hidden_dim_edge=hidden_dim_processor_edge,
            hidden_layers_node=hidden_layers_processor_node,
            hidden_layers_edge=hidden_layers_processor_edge,
            norm_type=mlp_norm_type,
        )
        self.node_decoder = MLP(
            input_dim,
            output_dim,
            hidden_dim_decoder,
            hidden_layers_decoder,
            None,
            self.use_checkpointing,
        )

    def forward(self, processor_features: torch.Tensor, batch_size: int) -> torch.Tensor:
        """
        Decode processor features on the H3 nodes onto the lat/lon nodes

        Args:
            processor_features: Processed features in shape [B*Nodes, Features]
            batch_size: Batch size

        Returns:
            Decoded features of the lat/lon nodes in shape [B, num_latlons, Features]
        """
        self.graph = self.graph.to(processor_features.device)
        edge_attr = self.edge_encoder(self.graph.edge_attr)  # Update attributes based on distance
        edge_attr = einops.repeat(edge_attr, "e f -> (repeat e) f", repeat=batch_size)

        # Shift the node indices of every batch element so the copies of the graph do not
        # overlap. The offset is loop invariant, so it is computed once.
        nodes_per_graph = torch.max(self.graph.edge_index) + 1
        edge_index = torch.cat(
            [self.graph.edge_index + i * nodes_per_graph for i in range(batch_size)],
            dim=1,
        )

        # Readd nodes to match graph node number
        self.latlon_nodes = self.latlon_nodes.to(processor_features.device)
        features = einops.rearrange(processor_features, "(b n) f -> b n f", b=batch_size)
        features = torch.cat(
            [features, einops.repeat(self.latlon_nodes, "n f -> b n f", b=batch_size)], dim=1
        )
        features = einops.rearrange(features, "b n f -> (b n) f")

        out, _ = self.graph_processor(features, edge_index, edge_attr)  # Message Passing
        out = self.node_decoder(out)  # Decode node features to the output variables
        out = einops.rearrange(out, "(b n) f -> b n f", b=batch_size)
        # Remove the h3 nodes now, only want the latlon ones
        _, out = torch.split(out, [self.num_h3, self.num_latlons], dim=1)
        return out