├── DiffDA
│   ├── checkpoint.py
│   ├── dataloader.py
│   ├── diffusion_common.py
│   ├── inference_data_assimilation.py
│   ├── inference_data_assimilation_gc.py
│   ├── obs_coords.csv
│   ├── train_conditional_graphcast.py
│   └── train_state.py
├── Dockerfile
├── GraphCast_src
│   ├── .DS_Store
│   ├── CONTRIBUTING.md
│   ├── LICENSE
│   ├── README.md
│   ├── __init__.py
│   ├── buckets.py
│   ├── graphcast
│   │   ├── __init__.py
│   │   ├── autoregressive.py
│   │   ├── casting.py
│   │   ├── checkpoint.py
│   │   ├── checkpoint_test.py
│   │   ├── data_utils.py
│   │   ├── data_utils_test.py
│   │   ├── deep_typed_graph_net.py
│   │   ├── graphcast.py
│   │   ├── grid_mesh_connectivity.py
│   │   ├── grid_mesh_connectivity_test.py
│   │   ├── icosahedral_mesh.py
│   │   ├── icosahedral_mesh_test.py
│   │   ├── losses.py
│   │   ├── model_utils.py
│   │   ├── normalization.py
│   │   ├── predictor_base.py
│   │   ├── rollout.py
│   │   ├── typed_graph.py
│   │   ├── typed_graph_net.py
│   │   ├── xarray_jax.py
│   │   ├── xarray_jax_test.py
│   │   ├── xarray_tree.py
│   │   └── xarray_tree_test.py
│   ├── graphcast_runner.py
│   ├── graphcast_wrapper.py
│   ├── plotting.py
│   ├── pretrained_graphcast.py
│   ├── setup.py
│   └── weatherbench2_dataloader.py
├── README.md
└── figs
    ├── ablation_experiment1
    │   ├── score_board_0_blur0.5.pdf
    │   ├── score_board_0_blur0.5_geopotential.pdf
    │   ├── score_board_0_blur0.5_sfc.pdf
    │   ├── score_board_0_blur0.5_specific_humidity.pdf
    │   ├── score_board_0_blur0.5_temperature.pdf
    │   ├── score_board_0_blur0.5_u_component_of_wind.pdf
    │   ├── score_board_0_blur0.5_v_component_of_wind.pdf
    │   ├── score_board_0_blur0.5_vertical_velocity.pdf
    │   ├── score_board_0_blur1.0.pdf
    │   ├── score_board_0_blur1.0_geopotential.pdf
    │   ├── score_board_0_blur1.0_sfc.pdf
    │   ├── score_board_0_blur1.0_specific_humidity.pdf
    │   ├── score_board_0_blur1.0_temperature.pdf
    │   ├── score_board_0_blur1.0_u_component_of_wind.pdf
    │   ├── score_board_0_blur1.0_v_component_of_wind.pdf
    │   ├── score_board_0_blur1.0_vertical_velocity.pdf
    │   ├── score_board_0_blur1.5.pdf
    │   ├── score_board_0_blur1.5_geopotential.pdf
    │   ├── score_board_0_blur1.5_sfc.pdf
    │   ├── score_board_0_blur1.5_specific_humidity.pdf
    │   ├── score_board_0_blur1.5_temperature.pdf
    │   ├── score_board_0_blur1.5_u_component_of_wind.pdf
    │   ├── score_board_0_blur1.5_v_component_of_wind.pdf
    │   ├── score_board_0_blur1.5_vertical_velocity.pdf
    │   ├── score_board_0_blur2.0.pdf
    │   ├── score_board_0_blur2.0_geopotential.pdf
    │   ├── score_board_0_blur2.0_sfc.pdf
    │   ├── score_board_0_blur2.0_specific_humidity.pdf
    │   ├── score_board_0_blur2.0_temperature.pdf
    │   ├── score_board_0_blur2.0_u_component_of_wind.pdf
    │   ├── score_board_0_blur2.0_v_component_of_wind.pdf
    │   ├── score_board_0_blur2.0_vertical_velocity.pdf
    │   ├── score_board_0_blur2.5.pdf
    │   ├── score_board_0_blur2.5_geopotential.pdf
    │   ├── score_board_0_blur2.5_sfc.pdf
    │   ├── score_board_0_blur2.5_specific_humidity.pdf
    │   ├── score_board_0_blur2.5_temperature.pdf
    │   ├── score_board_0_blur2.5_u_component_of_wind.pdf
    │   ├── score_board_0_blur2.5_v_component_of_wind.pdf
    │   └── score_board_0_blur2.5_vertical_velocity.pdf
    ├── ablation_experiment3
    │   ├── score_board_48_blur0.5.pdf
    │   ├── score_board_48_blur0.5_geopotential.pdf
    │   ├── score_board_48_blur0.5_sfc.pdf
    │   ├── score_board_48_blur0.5_specific_humidity.pdf
    │   ├── score_board_48_blur0.5_temperature.pdf
    │   ├── score_board_48_blur0.5_u_component_of_wind.pdf
    │   ├── score_board_48_blur0.5_v_component_of_wind.pdf
    │   ├── score_board_48_blur0.5_vertical_velocity.pdf
    │   ├── score_board_48_blur1.0.pdf
    │   ├── score_board_48_blur1.0_geopotential.pdf
    │   ├── score_board_48_blur1.0_sfc.pdf
    │   ├── score_board_48_blur1.0_specific_humidity.pdf
    │   ├── score_board_48_blur1.0_temperature.pdf
    │   ├── score_board_48_blur1.0_u_component_of_wind.pdf
    │   ├── score_board_48_blur1.0_v_component_of_wind.pdf
    │   ├── score_board_48_blur1.0_vertical_velocity.pdf
    │   ├── score_board_48_blur1.5.pdf
    │   ├── score_board_48_blur1.5_geopotential.pdf
    │   ├── score_board_48_blur1.5_sfc.pdf
    │   ├── score_board_48_blur1.5_specific_humidity.pdf
    │   ├── score_board_48_blur1.5_temperature.pdf
    │   ├── score_board_48_blur1.5_u_component_of_wind.pdf
    │   ├── score_board_48_blur1.5_v_component_of_wind.pdf
    │   ├── score_board_48_blur1.5_vertical_velocity.pdf
    │   ├── score_board_48_blur2.0.pdf
    │   ├── score_board_48_blur2.0_geopotential.pdf
    │   ├── score_board_48_blur2.0_sfc.pdf
    │   ├── score_board_48_blur2.0_specific_humidity.pdf
    │   ├── score_board_48_blur2.0_temperature.pdf
    │   ├── score_board_48_blur2.0_u_component_of_wind.pdf
    │   ├── score_board_48_blur2.0_v_component_of_wind.pdf
    │   ├── score_board_48_blur2.0_vertical_velocity.pdf
    │   ├── score_board_48_blur2.5.pdf
    │   ├── score_board_48_blur2.5_geopotential.pdf
    │   ├── score_board_48_blur2.5_sfc.pdf
    │   ├── score_board_48_blur2.5_specific_humidity.pdf
    │   ├── score_board_48_blur2.5_temperature.pdf
    │   ├── score_board_48_blur2.5_u_component_of_wind.pdf
    │   ├── score_board_48_blur2.5_v_component_of_wind.pdf
    │   └── score_board_48_blur2.5_vertical_velocity.pdf
    └── error_visualization
        ├── 2m_temperature_1000.pdf
        ├── 2m_temperature_10000.pdf
        ├── 2m_temperature_2000.pdf
        ├── 2m_temperature_20000.pdf
        ├── 2m_temperature_4000.pdf
        ├── 2m_temperature_40000.pdf
        ├── 2m_temperature_8000.pdf
        ├── geopotential_500_1000.pdf
        ├── geopotential_500_10000.pdf
        ├── geopotential_500_2000.pdf
        ├── geopotential_500_20000.pdf
        ├── geopotential_500_4000.pdf
        ├── geopotential_500_40000.pdf
        ├── geopotential_500_8000.pdf
        ├── temperature_850_1000.pdf
        ├── temperature_850_10000.pdf
        ├── temperature_850_2000.pdf
        ├── temperature_850_20000.pdf
        ├── temperature_850_4000.pdf
        ├── temperature_850_40000.pdf
        └── temperature_850_8000.pdf

/DiffDA/checkpoint.py:
--------------------------------------------------------------------------------
import dataclasses
import pickle
from typing import Union

import haiku as hk
import optax
from graphcast import graphcast, checkpoint
import jax
from pathlib import Path
import os
import re
import diffusers

@dataclasses.dataclass(frozen=True)
class TrainingCheckpoint:
    params: hk.Params
    opt_state: optax.OptState
    scheduler_state: diffusers.schedulers.scheduling_ddpm_flax.DDPMSchedulerState
    task_config: graphcast.TaskConfig
    model_config: graphcast.ModelConfig
    epoch: int
    rng: jax.Array
    num_train_timesteps: int = 1000

def save_checkpoint(directory: Path, ckpt: TrainingCheckpoint) -> None:
    """
    Stores a checkpoint in the given directory under the name
    directory / diff_gc_{epoch}.npz.
    If the given directory does not exist, it is created.
    If a checkpoint for the same epoch already exists in the directory,
    it is overwritten.
    """
    directory.mkdir(parents=True, exist_ok=True)
    save_file = directory / f"diff_gc_{ckpt.epoch}.npz"
    with open(save_file, mode="wb") as file:
        pickle.dump(ckpt, file)
        #checkpoint.dump(file, ckpt)

def load_checkpoint(directory: Path, epoch: int = -1) -> Union[TrainingCheckpoint, None]:
    """
    Loads a checkpoint from the given directory. If a non-negative epoch is given,
    that checkpoint is loaded. Otherwise, the latest epoch is loaded.
    If no checkpoint is found, None is returned.
    """
    if not directory.exists():
        return None
    # Checkpoint filenames look like diff_gc_{epoch}; capture the epoch number.
    pattern = r"diff_gc_(\d+)"
    # Collect (epoch, path) tuples for all matching files in the directory.
    file_tuples = []
    for filename in os.listdir(directory):
        match = re.match(pattern, filename)
        if match:
            i = int(match.group(1))
            filepath = os.path.join(directory, filename)
            file_tuples.append((i, filepath))

    if len(file_tuples) == 0:
        return None

    # Sort by epoch so the last entry is the latest checkpoint.
    file_tuples.sort()
    ckpt_path = None
    if epoch < 0:
        ckpt_path = file_tuples[-1][1]
    else:
        for (i, f) in file_tuples:
            if i == epoch:
                ckpt_path = f
                break
    if ckpt_path is None:
        # A specific epoch was requested but no matching checkpoint exists.
        return None

    with open(ckpt_path, "rb") as file:
        return pickle.load(file)
        #return checkpoint.load(file, TrainingCheckpoint)
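For orientation, here is a minimal usage sketch of the two helpers above; the checkpoint directory is illustrative and this snippet is not part of the file:

```python
from pathlib import Path

# Resume from the latest checkpoint if one exists (hypothetical directory).
ckpt = load_checkpoint(Path("checkpoints/diffda"))
if ckpt is None:
    print("No checkpoint found, starting from scratch.")
else:
    print(f"Resuming training at epoch {ckpt.epoch + 1}.")
    # save_checkpoint(Path("checkpoints/diffda"), ckpt) would write it back.
```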
45 | """ 46 | if not directory.exists(): 47 | return None 48 | # Define the pattern using a regular expression 49 | pattern = r"diff_gc_(\d+)" 50 | # Initialize an empty list to store the tuples (i, path) 51 | file_tuples = [] 52 | # Iterate through all files in the directory 53 | for filename in os.listdir(directory): 54 | # Check if the filename matches the pattern 55 | match = re.match(pattern, filename) 56 | if match: 57 | # Extract the epoch i from the match 58 | i = int(match.group(1)) 59 | # Construct the full path to the file 60 | filepath = os.path.join(directory, filename) 61 | # Append the tuple (i, path) to the list 62 | file_tuples.append((i, filepath)) 63 | 64 | if len(file_tuples) == 0: 65 | return None 66 | 67 | # Sort the list of tuples based on the integer i 68 | file_tuples.sort() 69 | ckpt_path = None 70 | if epoch < 0: 71 | ckpt_path = file_tuples[-1][1] 72 | else: 73 | for (i, f) in file_tuples: 74 | if i == epoch: 75 | ckpt_path = f 76 | break 77 | 78 | with open(ckpt_path, "rb") as file: 79 | return pickle.load(file) 80 | #return checkpoint.load(file, TrainingCheckpoint) -------------------------------------------------------------------------------- /DiffDA/diffusion_common.py: -------------------------------------------------------------------------------- 1 | import xarray 2 | import numpy as np 3 | import jax 4 | import jax.numpy as jnp 5 | import haiku as hk 6 | from tqdm import tqdm 7 | import wandb 8 | 9 | from graphcast import graphcast, normalization, casting, xarray_jax, xarray_tree 10 | import graphcast.autoregressive as autoregressive 11 | from graphcast.data_utils import get_day_progress, get_year_progress, featurize_progress 12 | 13 | def get_forcing(time: xarray.DataArray, lon: xarray.DataArray, timesteps: np.ndarray, num_timesteps: int, batch_size: int = 0, forcing_type: str = "diffusion") -> None: 14 | 15 | DAY_PROGRESS = "day_progress" 16 | YEAR_PROGRESS = "year_progress" 17 | 18 | # Compute seconds since epoch. 19 | # Note `data.coords["datetime"].astype("datetime64[s]").astype(np.int64)` 20 | # does not work as xarrays always cast dates into nanoseconds! 21 | batch_dim = ("batch",) 22 | seconds_since_epoch = time.data.astype("datetime64[s]").astype(np.int64) 23 | seconds_since_epoch = seconds_since_epoch.reshape((batch_size, 1)) 24 | 25 | # Add year progress features. 26 | year_progress = get_year_progress(seconds_since_epoch) 27 | forcing_dict = {} 28 | forcing_dict.update(featurize_progress( 29 | name=YEAR_PROGRESS, dims=batch_dim + ("time",), progress=year_progress)) 30 | # Add day progress features. 31 | day_progress = get_day_progress(seconds_since_epoch, lon.data) 32 | forcing_dict.update(featurize_progress( 33 | name=DAY_PROGRESS, 34 | dims=batch_dim + ("time",) + lon.dims, 35 | progress=day_progress)) 36 | 37 | # hijack year_progress_sin for timesteps 38 | if forcing_type == "diffusion": 39 | forcing_dict["year_progress_sin"].data = (timesteps / num_timesteps * 2 - 1).astype(np.float32).reshape(forcing_dict["year_progress_sin"].shape) 40 | 41 | ds_forcing = xarray.Dataset(forcing_dict).drop_vars(["day_progress", "year_progress"]) 42 | 43 | return ds_forcing 44 | 45 | 46 | def wrap_graphcast(model_config: graphcast.ModelConfig, 47 | task_config: graphcast.TaskConfig, 48 | stddev_by_level: xarray.Dataset, 49 | mean_by_level: xarray.Dataset, 50 | diffs_stddev_by_level: xarray.Dataset,): 51 | """ 52 | Constructs and wraps the GraphCast Predictor. 53 | Note that this MUST be called within a haiku transform function. 
54 | """ 55 | # Deeper one-step predictor. 56 | predictor = graphcast.GraphCast(model_config, task_config) 57 | 58 | # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to 59 | # from/to float32 to/from BFloat16. 60 | predictor = casting.Bfloat16Cast(predictor) 61 | 62 | # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from 63 | # BFloat16 happens after applying normalization to the inputs/targets. 64 | predictor = normalization.InputsAndResidualsForDiffusion(predictor, 65 | stddev_by_level=stddev_by_level, 66 | mean_by_level=mean_by_level, 67 | diffs_stddev_by_level=diffs_stddev_by_level) 68 | return predictor 69 | 70 | def wrap_graphcast_prediction(model_config: graphcast.ModelConfig, 71 | task_config: graphcast.TaskConfig, 72 | diffs_stddev_by_level = None, 73 | mean_by_level = None, 74 | stddev_by_level = None, 75 | wrap_autoregressive: bool = False, 76 | normalize: bool = False): 77 | """ 78 | Constructs and wraps the GraphCast Predictor. 79 | Note that this MUST be called within a haiku transform function. 80 | """ 81 | # Deeper one-step predictor. 82 | predictor = graphcast.GraphCast(model_config, task_config) 83 | 84 | # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to 85 | # from/to float32 to/from BFloat16. 86 | predictor = casting.Bfloat16Cast(predictor) 87 | 88 | if normalize: 89 | assert diffs_stddev_by_level is not None 90 | assert mean_by_level is not None 91 | assert stddev_by_level is not None 92 | # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from 93 | # BFloat16 happens after applying normalization to the inputs/targets. 94 | predictor = normalization.InputsAndResiduals( 95 | predictor, 96 | diffs_stddev_by_level=diffs_stddev_by_level, 97 | mean_by_level=mean_by_level, 98 | stddev_by_level=stddev_by_level) 99 | 100 | if wrap_autoregressive: 101 | # Wraps everything so the one-step model can produce trajectories. 

def _to_numpy_xarray(array) -> xarray.Dataset:
    # Unwrap the jax-backed variables and coordinates.
    vars = xarray_jax.unwrap_vars(array)
    coords = xarray_jax.unwrap_coords(array)
    # Ensure everything is numpy.
    vars = {n: (array[n].dims, np.asarray(v)) if len(array[n].dims) > 0 else np.asarray(v) for n, v in vars.items()}
    coords = {n: (array[n].dims, np.asarray(v)) if len(array[n].dims) > 0 else np.asarray(v) for n, v in coords.items()}
    # Create a new dataset.
    copied_dataset = xarray.Dataset(vars, coords)
    return copied_dataset

def _to_jax_xarray(dataset: xarray.Dataset, device) -> xarray_jax.Dataset:
    # Unwrap the numpy-backed variables and coordinates.
    vars = dataset.variables
    coords = dataset.coords
    # Move everything onto the target device as jax arrays.
    vars = {n: (dataset[n].dims, jax.device_put(jnp.array(v.data), device)) if len(dataset[n].dims) > 0 else jax.device_put(jnp.array(v.data), device) for n, v in vars.items()}
    coords = {n: (dataset[n].dims, jax.device_put(jnp.array(v.data), device)) if len(dataset[n].dims) > 0 else jax.device_put(jnp.array(v.data), device) for n, v in coords.items()}
    # Create a new dataset.
    copied_dataset = xarray_jax.Dataset(vars, coords)
    return copied_dataset
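The two helpers above convert between numpy-backed and jax-backed datasets. A minimal round-trip sketch (the variable name and grid shape are illustrative, not part of the file):

```python
import jax
import numpy as np
import xarray

ds = xarray.Dataset({"2m_temperature": (("lat", "lon"), np.zeros((181, 360), dtype=np.float32))})
device = jax.devices()[0]

ds_jax = _to_jax_xarray(ds, device)   # variables become jax arrays placed on `device`
ds_np = _to_numpy_xarray(ds_jax)      # back to a plain numpy-backed xarray.Dataset
```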

def validation_step(forward_fn: hk.TransformedWithState, norm_original_fn, norm_diff_fn, params, state, validate_dataset, args, device, rng, mode="ddpm"):
    pbar = enumerate(validate_dataset)
    progress_bar = False
    if args.rank == 0:
        pbar = tqdm(pbar, desc="Validation", total=len(validate_dataset))
        progress_bar = True
    for batch_idx, batch in pbar:
        rng_batch = jax.random.fold_in(rng, batch_idx)
        timesteps = np.ones((args.validation_batch_size,), dtype=np.int32)
        inputs_pred = batch['graphcast']
        inputs_ground_truth = batch['weatherbench']
        inputs_static = batch['static']
        datetime = inputs_ground_truth.datetime
        lon = inputs_ground_truth.lon
        norm_forcings = get_forcing(datetime, lon, timesteps, args.num_train_timesteps, batch_size=args.validation_batch_size)
        norm_forcings = _to_jax_xarray(norm_forcings, device)
        inputs_pred = _to_jax_xarray(inputs_pred.drop_vars("datetime"), device)
        inputs_ground_truth = _to_jax_xarray(inputs_ground_truth.drop_vars("datetime"), device)
        inputs_static = _to_jax_xarray(inputs_static, device)
        norm_soa = norm_original_fn(inputs_ground_truth["toa_incident_solar_radiation"])
        norm_forcings = xarray.merge([norm_forcings, norm_soa])
        norm_inputs_pred = norm_original_fn(inputs_pred)
        if 'total_precipitation_6hr' in norm_inputs_pred.data_vars and args.resolution == "0.25deg":
            norm_inputs_pred = norm_inputs_pred.drop_vars('total_precipitation_6hr')
        norm_static = norm_original_fn(inputs_static)
        if mode == "ddpm":
            corrected_pred, _ = forward_fn.apply(params, state, None, norm_inputs_pred, norm_forcings, norm_static, rng_batch, progress_bar=progress_bar)
        elif mode == "repaint" or mode == "dps":
            mask = batch["mask"]
            measurements_interp = batch["weatherbench_interp"].drop_vars(["datetime", "toa_incident_solar_radiation"])
            mask = _to_jax_xarray(mask, device)["mask"]  # convert from dataset to dataarray
            measurements_interp = _to_jax_xarray(measurements_interp, device)
            measurements_diff_interp = measurements_interp - inputs_pred
            norm_measurements_diff_interp = norm_diff_fn(measurements_diff_interp)

            corrected_pred, _ = forward_fn.apply(params, state, None, mask=mask,
                                                 norm_measurements_diff_interp=norm_measurements_diff_interp,
                                                 norm_inputs_pred=norm_inputs_pred,
                                                 norm_forcings=norm_forcings,
                                                 norm_static=norm_static,
                                                 rng_batch=rng_batch,
                                                 progress_bar=True,)
        else:
            raise ValueError(f"Unknown mode {mode}")

        diff = corrected_pred - inputs_ground_truth
        diff_gc = inputs_pred - inputs_ground_truth

        diff_500hPa = diff.sel(level=500)
        diff_gc_500hPa = diff_gc.sel(level=500)

        val_loss = {f"val/{mode}/diffusion_rmse_500hPa/{k}": jnp.sqrt(xarray_jax.unwrap_data((v*v).mean())).item() for k, v in diff_500hPa.data_vars.items()}
        val_loss_gc = {f"val/{mode}/graphcast_rmse500hPa/{k}": jnp.sqrt(xarray_jax.unwrap_data((v*v).mean())).item() for k, v in diff_gc_500hPa.data_vars.items()}
        #val_loss["rank"] = args.rank

        if args.rank == 0:
            pbar.set_postfix(val_loss_z500=val_loss[f"val/{mode}/diffusion_rmse_500hPa/geopotential"], loss_gc_z500=val_loss_gc[f"val/{mode}/graphcast_rmse500hPa/geopotential"])
            if args.use_wandb:
                log_dict = {**val_loss, **val_loss_gc}
                wandb.log(log_dict)
--------------------------------------------------------------------------------
/DiffDA/train_state.py:
--------------------------------------------------------------------------------
# Copyright 2022 The VDM Authors, Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Adapted from
https://github.com/google-research/vdm/blob/main/train_state.py and
https://flax.readthedocs.io/en/latest/_modules/flax/training/train_state.html#TrainState,
but with added EMA of the parameters.
"""

import copy
from typing import Any, Callable, Optional

from flax import core
from flax import struct
import jax
import optax

from diffusers.schedulers.scheduling_ddpm_flax import DDPMSchedulerState

class TrainState(struct.PyTreeNode):
    """Simple train state for the common case with a single Optax optimizer.

    Synopsis:

        state = TrainState.create(
            apply_fn=model.apply,
            params=variables['params'],
            optax_optimizer=tx_fn,
            scheduler_state=scheduler_state)
        grad_fn = jax.grad(make_loss_fn(state.apply_fn))
        for batch in data:
            grads = grad_fn(state.params, batch)
            state = state.apply_gradients(grads=grads, lr=lr)

    Note that you can easily extend this dataclass by subclassing it for storing
    additional data (e.g. additional variable collections).

    For more exotic use cases (e.g. multiple optimizers) it's probably best to
    fork the class and modify it.

    Attributes:
        step: Counter that starts at 0 and is incremented by every call to
            `.apply_gradients()`.
        apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
            convenience, to have a shorter params list for the `train_step()`
            function in your training loop.
        tx_fn: A function mapping a learning rate to an Optax gradient
            transformation.
        opt_state: The state for the optimizer returned by `tx_fn`.
        scheduler_state: State of the flax DDPM noise scheduler.
    """
    step: int
    params: core.FrozenDict[str, Any]
    opt_state: optax.OptState
    scheduler_state: DDPMSchedulerState
    tx_fn: Callable[[float], optax.GradientTransformation] = struct.field(
        pytree_node=False)
    apply_fn: Callable = struct.field(pytree_node=False)

    def apply_gradients(self, *, grads, lr, **kwargs):
        """Updates `step`, `params`, `opt_state` and `**kwargs` in the return value.

        Note that internally this function calls `.tx.update()` followed by a call
        to `optax.apply_updates()` to update `params` and `opt_state`.

        Args:
            grads: Gradients that have the same pytree structure as `.params`.
            lr: Learning rate used to instantiate the gradient transformation.
            **kwargs: Additional dataclass attributes that should be `.replace()`-ed.

        Returns:
            An updated instance of `self` with `step` incremented by one, `params`
            and `opt_state` updated by applying `grads`, and additional attributes
            replaced as specified by `kwargs`.
        """
        tx = self.tx_fn(lr)
        updates, new_opt_state = tx.update(
            grads, self.opt_state, self.params)
        new_params = optax.apply_updates(self.params, updates)

        return self.replace(
            step=self.step + 1,
            params=new_params,
            opt_state=new_opt_state,
            **kwargs,
        )

    @classmethod
    def create(_class, *, apply_fn, params, optax_optimizer, scheduler_state, **kwargs):
        """Creates a new instance with `step=0` and an initialized `opt_state`."""
        # _class is the TrainState class itself.
        opt_state = optax_optimizer(1.).init(params)
        return _class(
            step=0,
            apply_fn=apply_fn,
            params=params,
            tx_fn=optax_optimizer,
            opt_state=opt_state,
            scheduler_state=scheduler_state,
            **kwargs,
        )
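A short sketch of how this train state might be wired up with the flax DDPM scheduler from `diffusers`; the optimizer choice and learning rate are illustrative assumptions, and `forward_fn`, `params`, and `grads` are assumed to exist from the training setup:

```python
import optax
from diffusers import FlaxDDPMScheduler

scheduler = FlaxDDPMScheduler(num_train_timesteps=1000)
scheduler_state = scheduler.create_state()

# tx_fn maps a (possibly scheduled) learning rate to an optimizer instance.
tx_fn = lambda lr: optax.adamw(learning_rate=lr)

state = TrainState.create(
    apply_fn=forward_fn.apply,   # forward_fn: a hk.TransformedWithState, assumed defined
    params=params,               # initial network parameters, assumed initialized
    optax_optimizer=tx_fn,
    scheduler_state=scheduler_state)
# state = state.apply_gradients(grads=grads, lr=1e-4)
```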
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM nvcr.io/nvidia/cuda:11.8.0-devel-ubuntu22.04
ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update && apt-get install -y python3 python3-pip python3-venv git curl wget libeccodes0 tensorrt libnvinfer-dev libnvinfer-plugin-dev
RUN pip install --upgrade pip && pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
RUN pip install cartopy chex colabtools dask dm-haiku jraph matplotlib numpy \
    pandas rtree scipy trimesh typing_extensions xarray \
    ipython ipykernel jupyterlab notebook ipywidgets google-cloud-storage dm-tree \
    cfgrib zarr gcsfs jax-dataloader tensorflow tensorboard-plugin-profile nvtx \
    wandb

RUN pip install --upgrade diffusers[flax]

ENV TF_CPP_MIN_LOG_LEVEL=2

COPY ./ /workspace/
RUN pip install -e /workspace

WORKDIR /workspace

CMD [ "/bin/bash" ]
--------------------------------------------------------------------------------
/GraphCast_src/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/GraphCast_src/.DS_Store
--------------------------------------------------------------------------------
/GraphCast_src/CONTRIBUTING.md:
--------------------------------------------------------------------------------
 1 | # How to Contribute
 2 | 
 3 | ## Contributor License Agreement
 4 | 
 5 | Contributions to this project must be accompanied by a Contributor License
 6 | Agreement. You (or your employer) retain the copyright to your contribution;
 7 | this simply gives us permission to use and redistribute your contributions as
 8 | part of the project. Head over to <https://cla.developers.google.com/> to see
 9 | your current agreements on file or to sign a new one.
10 | 
11 | You generally only need to submit a CLA once, so if you've already submitted one
12 | (even if it was for a different project), you probably don't need to do it
13 | again.
14 | 
15 | ## Code reviews
16 | 
17 | All submissions, including submissions by project members, require review. We
18 | use GitHub pull requests for this purpose. Consult
19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
20 | information on using pull requests.
21 | 
22 | ## Community Guidelines
23 | 
24 | This project follows [Google's Open Source Community
25 | Guidelines](https://opensource.google/conduct/).
26 | 
-------------------------------------------------------------------------------- /GraphCast_src/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof.
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /GraphCast_src/README.md: -------------------------------------------------------------------------------- 1 | # GraphCast: Learning skillful medium-range global weather forecasting 2 | 3 | This package contains example code to run and train [GraphCast](https://arxiv.org/abs/2212.12794). 4 | It also provides three pretrained models: 5 | 6 | 1. 
`GraphCast`, the high-resolution model used in the GraphCast paper (0.25 degree
 7 | resolution, 37 pressure levels), trained on ERA5 data from 1979 to 2017,
 8 | 
 9 | 2. `GraphCast_small`, a smaller, low-resolution version of GraphCast (1 degree
10 | resolution, 13 pressure levels, and a smaller mesh), trained on ERA5 data from
11 | 1979 to 2015, useful for running a model under tighter memory and compute constraints,
12 | 
13 | 3. `GraphCast_operational`, a high-resolution model (0.25 degree resolution, 13
14 | pressure levels) pre-trained on ERA5 data from 1979 to 2017 and fine-tuned on
15 | HRES data from 2016 to 2021. This model can be initialized from HRES data (it does
16 | not require precipitation inputs).
17 | 
18 | The model weights, normalization statistics, and example inputs are available in a [Google Cloud bucket](https://console.cloud.google.com/storage/browser/dm_graphcast).
19 | 
20 | Full model training requires downloading the
21 | [ERA5](https://www.ecmwf.int/en/forecasts/datasets/reanalysis-datasets/era5)
22 | dataset, available from [ECMWF](https://www.ecmwf.int/).
23 | 
24 | ## Overview of files
25 | 
26 | The best starting point is to open `graphcast_demo.ipynb` in [Colaboratory](https://colab.research.google.com/github/deepmind/graphcast/blob/master/graphcast_demo.ipynb), which gives an
27 | example of loading data, generating random weights or loading a pre-trained
28 | snapshot, generating predictions, computing the loss, and computing gradients.
29 | The one-step implementation of the GraphCast architecture is provided in
30 | `graphcast.py`.
31 | 
32 | ### Brief description of library files:
33 | 
34 | * `autoregressive.py`: Wrapper used to run (and train) the one-step GraphCast
35 |   to produce a sequence of predictions by auto-regressively feeding the
36 |   outputs back as inputs at each step, in a JAX-differentiable way.
37 | * `casting.py`: Wrapper used around GraphCast to make it work using
38 |   BFloat16 precision.
39 | * `checkpoint.py`: Utils to serialize and deserialize trees (see the sketch after this list).
40 | * `data_utils.py`: Utils for data preprocessing.
41 | * `deep_typed_graph_net.py`: General purpose deep graph neural network (GNN)
42 |   that operates on `TypedGraph`'s where both inputs and outputs are flat
43 |   vectors of features for each of the nodes and edges. `graphcast.py` uses
44 |   three of these for the Grid2Mesh GNN, the Multi-mesh GNN and the Mesh2Grid
45 |   GNN, respectively.
46 | * `graphcast.py`: The main GraphCast model architecture for one step of
47 |   predictions.
48 | * `grid_mesh_connectivity.py`: Tools for converting between regular grids on a
49 |   sphere and triangular meshes.
50 | * `icosahedral_mesh.py`: Definition of an icosahedral multi-mesh.
51 | * `losses.py`: Loss computations, including latitude-weighting.
52 | * `model_utils.py`: Utilities to produce flat node and edge vector features
53 |   from input grid data, and to manipulate the node output vectors back
54 |   into multilevel grid data.
55 | * `normalization.py`: Wrapper for the one-step GraphCast used to normalize
56 |   inputs according to historical values, and targets according to historical
57 |   time differences.
58 | * `predictor_base.py`: Defines the interface of the predictor, which GraphCast
59 |   and all of the wrappers implement.
60 | * `rollout.py`: Similar to `autoregressive.py` but used only at inference time,
61 |   using a python loop to produce longer, but non-differentiable, trajectories.
62 | * `typed_graph.py`: Definition of `TypedGraph`'s.
63 | * `typed_graph_net.py`: Implementation of simple graph neural network
64 |   building blocks defined over `TypedGraph`'s that can be combined to build
65 |   deeper models.
66 | * `xarray_jax.py`: A wrapper to let JAX work with `xarray`s.
67 | * `xarray_tree.py`: An implementation of tree.map_structure that works with
68 |   `xarray`s.
69 | 
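The `checkpoint.py` utilities in the list above are what the demo notebook uses to read the pretrained weights; here is a hedged sketch, where the local file name is illustrative and the snippet is not part of the README:

```python
from graphcast import checkpoint, graphcast

# Assumes a params file has been downloaded from the dm_graphcast bucket.
with open("GraphCast_small.npz", "rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)

params = ckpt.params
print(ckpt.model_config)
print(ckpt.task_config)
```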
70 | 
71 | ### Dependencies.
72 | 
73 | [Chex](https://github.com/deepmind/chex),
74 | [Dask](https://github.com/dask/dask),
75 | [Haiku](https://github.com/deepmind/dm-haiku),
76 | [JAX](https://github.com/google/jax),
77 | [JAXline](https://github.com/deepmind/jaxline),
78 | [Jraph](https://github.com/deepmind/jraph),
79 | [Numpy](https://numpy.org/),
80 | [Pandas](https://pandas.pydata.org/),
81 | [Python](https://www.python.org/),
82 | [SciPy](https://scipy.org/),
83 | [Tree](https://github.com/deepmind/tree),
84 | [Trimesh](https://github.com/mikedh/trimesh) and
85 | [XArray](https://github.com/pydata/xarray).
86 | 
87 | 
88 | ### License and attribution
89 | 
90 | The Colab notebook and the associated code are licensed under the Apache
91 | License, Version 2.0. You may obtain a copy of the License at:
92 | https://www.apache.org/licenses/LICENSE-2.0.
93 | 
94 | The model weights are made available for use under the terms of the Creative
95 | Commons Attribution-NonCommercial-ShareAlike 4.0 International
96 | (CC BY-NC-SA 4.0). You may obtain a copy of the License at:
97 | https://creativecommons.org/licenses/by-nc-sa/4.0/.
98 | 
99 | The weights were trained on ECMWF's ERA5 and HRES data. The Colab includes a few
100 | examples of ERA5 and HRES data that can be used as inputs to the models.
101 | ECMWF data products are subject to the following terms:
102 | 
103 | 1. Copyright statement: Copyright "© 2023 European Centre for Medium-Range Weather Forecasts (ECMWF)".
104 | 2. Source www.ecmwf.int
105 | 3. Licence Statement: ECMWF data is published under a Creative Commons Attribution 4.0 International (CC BY 4.0). https://creativecommons.org/licenses/by/4.0/
106 | 4. Disclaimer: ECMWF does not accept any liability whatsoever for any error or omission in the data, their availability, or for any loss or damage arising from their use.
107 | 
108 | ### Disclaimer
109 | 
110 | This is not an officially supported Google product.
111 | 
112 | Copyright 2023 DeepMind Technologies Limited.
113 | 
114 | ### Citation
115 | 
116 | If you use this work, consider citing our [paper](https://arxiv.org/abs/2212.12794):
117 | 
118 | ```latex
119 | @article{lam2022graphcast,
120 |   title={GraphCast: Learning skillful medium-range global weather forecasting},
121 |   author={Remi Lam and Alvaro Sanchez-Gonzalez and Matthew Willson and Peter Wirnsberger and Meire Fortunato and Alexander Pritzel and Suman Ravuri and Timo Ewalds and Ferran Alet and Zach Eaton-Rosen and Weihua Hu and Alexander Merose and Stephan Hoyer and George Holland and Jacklynn Stott and Oriol Vinyals and Shakir Mohamed and Peter Battaglia},
122 |   year={2022},
123 |   eprint={2212.12794},
124 |   archivePrefix={arXiv},
125 |   primaryClass={cs.LG}
126 | }
127 | ```
128 | 
--------------------------------------------------------------------------------
/GraphCast_src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/GraphCast_src/__init__.py
--------------------------------------------------------------------------------
/GraphCast_src/buckets.py:
--------------------------------------------------------------------------------
 1 | from google.cloud import storage
 2 | from google.cloud.storage import Bucket
 3 | from pathlib import Path
 4 | import os
 5 | 
 6 | """
 7 | Provides utilities to load and cache gcs buckets.
 8 | """
 9 | 
10 | def parse_file_parts(file_name) -> dict:
11 |     return dict(part.split("-", 1) for part in file_name.split("_"))
12 | 
13 | def authenticate_bucket() -> Bucket:
14 |     # @title Authenticate with Google Cloud Storage
15 |     # TODO: Figure out how to access a public cloud bucket without authentication.
16 |     os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "./gcs_key.json"
17 |     gcs_client = storage.Client()
18 |     gcs_bucket = gcs_client.get_bucket("dm_graphcast")
19 |     return gcs_bucket
20 | 
21 | def save_to_dir(blob, directory: Path, name: str) -> None:
22 |     if not directory.exists():
23 |         directory.mkdir(parents=True, exist_ok=True)
24 |     with open(directory / name, "wb") as f:
25 |         f.write(blob.read())
--------------------------------------------------------------------------------
/GraphCast_src/graphcast/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/GraphCast_src/graphcast/__init__.py
--------------------------------------------------------------------------------
/GraphCast_src/graphcast/casting.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2023 DeepMind Technologies Limited.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #      http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Wrappers that take care of casting.""" 15 | 16 | import contextlib 17 | from typing import Any, Mapping, Tuple 18 | 19 | import chex 20 | from graphcast import predictor_base 21 | import haiku as hk 22 | import jax 23 | import jax.numpy as jnp 24 | import numpy as np 25 | import xarray 26 | 27 | 28 | PyTree = Any 29 | 30 | 31 | class Bfloat16Cast(predictor_base.Predictor): 32 | """Wrapper that casts all inputs to bfloat16 and outputs to targets dtype.""" 33 | 34 | def __init__(self, predictor: predictor_base.Predictor, enabled: bool = True): 35 | """Inits the wrapper. 36 | 37 | Args: 38 | predictor: predictor being wrapped. 39 | enabled: disables the wrapper if False, for simpler hyperparameter scans. 40 | 41 | """ 42 | self._enabled = enabled 43 | self._predictor = predictor 44 | 45 | def __call__(self, 46 | inputs: xarray.Dataset, 47 | targets_template: xarray.Dataset, 48 | forcings: xarray.Dataset, 49 | **kwargs 50 | ) -> xarray.Dataset: 51 | if not self._enabled: 52 | return self._predictor(inputs, targets_template, forcings, **kwargs) 53 | 54 | with bfloat16_variable_view(): 55 | predictions = self._predictor( 56 | *_all_inputs_to_bfloat16(inputs, targets_template, forcings), 57 | **kwargs,) 58 | 59 | predictions_dtype = infer_floating_dtype(predictions) 60 | if predictions_dtype != jnp.bfloat16: 61 | raise ValueError(f'Expected bfloat16 output, got {predictions_dtype}') 62 | 63 | targets_dtype = infer_floating_dtype(targets_template) 64 | return tree_map_cast( 65 | predictions, input_dtype=jnp.bfloat16, output_dtype=targets_dtype) 66 | 67 | def loss(self, 68 | inputs: xarray.Dataset, 69 | targets: xarray.Dataset, 70 | forcings: xarray.Dataset, 71 | **kwargs, 72 | ) -> predictor_base.LossAndDiagnostics: 73 | if not self._enabled: 74 | return self._predictor.loss(inputs, targets, forcings, **kwargs) 75 | 76 | with bfloat16_variable_view(): 77 | loss, scalars = self._predictor.loss( 78 | *_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs) 79 | 80 | if loss.dtype != jnp.bfloat16: 81 | raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}') 82 | 83 | targets_dtype = infer_floating_dtype(targets) 84 | 85 | # Note that casting back the loss to e.g. float32 should not affect data 86 | # types of the backwards pass, because the first thing the backwards pass 87 | # should do is to go backwards the casting op and cast back to bfloat16 88 | # (and xprofs seem to confirm this). 
89 | return tree_map_cast((loss, scalars), 90 | input_dtype=jnp.bfloat16, output_dtype=targets_dtype) 91 | 92 | def loss_and_predictions( # pytype: disable=signature-mismatch # jax-ndarray 93 | self, 94 | inputs: xarray.Dataset, 95 | targets: xarray.Dataset, 96 | forcings: xarray.Dataset, 97 | **kwargs, 98 | ) -> Tuple[predictor_base.LossAndDiagnostics, 99 | xarray.Dataset]: 100 | if not self._enabled: 101 | return self._predictor.loss_and_predictions(inputs, targets, forcings, # pytype: disable=bad-return-type # jax-ndarray 102 | **kwargs) 103 | 104 | with bfloat16_variable_view(): 105 | (loss, scalars), predictions = self._predictor.loss_and_predictions( 106 | *_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs) 107 | 108 | if loss.dtype != jnp.bfloat16: 109 | raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}') 110 | 111 | predictions_dtype = infer_floating_dtype(predictions) 112 | if predictions_dtype != jnp.bfloat16: 113 | raise ValueError(f'Expected bfloat16 output, got {predictions_dtype}') 114 | 115 | targets_dtype = infer_floating_dtype(targets) 116 | return tree_map_cast(((loss, scalars), predictions), 117 | input_dtype=jnp.bfloat16, output_dtype=targets_dtype) 118 | 119 | 120 | def infer_floating_dtype(data_vars: Mapping[str, chex.Array]) -> np.dtype: 121 | """Infers a floating dtype from an input mapping of data.""" 122 | dtypes = { 123 | v.dtype 124 | for k, v in data_vars.items() if jnp.issubdtype(v.dtype, np.floating)} 125 | if len(dtypes) != 1: 126 | dtypes_and_shapes = { 127 | k: (v.dtype, v.shape) 128 | for k, v in data_vars.items() if jnp.issubdtype(v.dtype, np.floating)} 129 | raise ValueError( 130 | f'Did not find exactly one floating dtype {dtypes} in input variables: ' 131 | f'{dtypes_and_shapes}') 132 | return list(dtypes)[0] 133 | 134 | 135 | def _all_inputs_to_bfloat16( 136 | inputs: xarray.Dataset, 137 | targets: xarray.Dataset, 138 | forcings: xarray.Dataset, 139 | ) -> Tuple[xarray.Dataset, 140 | xarray.Dataset, 141 | xarray.Dataset]: 142 | return (inputs.astype(jnp.bfloat16), 143 | jax.tree_map(lambda x: x.astype(jnp.bfloat16), targets), 144 | forcings.astype(jnp.bfloat16)) 145 | 146 | 147 | def tree_map_cast(inputs: PyTree, input_dtype: np.dtype, output_dtype: np.dtype, 148 | ) -> PyTree: 149 | def cast_fn(x): 150 | return x.astype(output_dtype) if x.dtype == input_dtype else x 151 | 152 | return jax.tree_map(cast_fn, inputs) 153 | 154 | 155 | @contextlib.contextmanager 156 | def bfloat16_variable_view(enabled: bool = True): 157 | """Context for Haiku modules with float32 params, but bfloat16 activations. 158 | 159 | It works as follows: 160 | * Every time a variable is requested to be created/set as np.bfloat16, 161 | it will create an underlying float32 variable, instead. 162 | * Every time a variable is requested as bfloat16, it will check that the 163 | variable is of float32 type, and cast the variable to bfloat16. 164 | 165 | Note the gradients are still computed and accumulated as float32, because 166 | the params returned by init are float32, so the gradient function with 167 | respect to the params will already include an implicit casting to float32. 168 | 169 | Args: 170 | enabled: Only enables bfloat16 behavior if True.
171 | 172 | Yields: 173 | None 174 | """ 175 | 176 | if enabled: 177 | with hk.custom_creator( 178 | _bfloat16_creator, state=True), hk.custom_getter( 179 | _bfloat16_getter, state=True), hk.custom_setter( 180 | _bfloat16_setter): 181 | yield 182 | else: 183 | yield 184 | 185 | 186 | def _bfloat16_creator(next_creator, shape, dtype, init, context): 187 | """Creates float32 variables when bfloat16 is requested.""" 188 | if context.original_dtype == jnp.bfloat16: 189 | dtype = jnp.float32 190 | return next_creator(shape, dtype, init) 191 | 192 | 193 | def _bfloat16_getter(next_getter, value, context): 194 | """Casts float32 to bfloat16 when bfloat16 was originally requested.""" 195 | if context.original_dtype == jnp.bfloat16: 196 | assert value.dtype == jnp.float32 197 | value = value.astype(jnp.bfloat16) 198 | return next_getter(value) 199 | 200 | 201 | def _bfloat16_setter(next_setter, value, context): 202 | """Casts bfloat16 to float32 when bfloat16 was originally set.""" 203 | if context.original_dtype == jnp.bfloat16: 204 | value = value.astype(jnp.float32) 205 | return next_setter(value) 206 | -------------------------------------------------------------------------------- /GraphCast_src/graphcast/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Serialize and deserialize trees.""" 15 | 16 | import dataclasses 17 | import io 18 | import types 19 | from typing import Any, BinaryIO, Optional, TypeVar 20 | 21 | import numpy as np 22 | 23 | _T = TypeVar("_T") 24 | 25 | 26 | def dump(dest: BinaryIO, value: Any) -> None: 27 | """Dump a tree of dicts/dataclasses to a file object. 28 | 29 | Args: 30 | dest: a file object to write to. 31 | value: A tree of dicts, lists, tuples and dataclasses of numpy arrays and 32 | other basic types. Unions are not supported, other than Optional/None 33 | which is only supported in dataclasses, not in dicts, lists or tuples. 34 | All leaves must be coercible to a numpy array, and recoverable as a single 35 | arg to a type. 36 | """ 37 | buffer = io.BytesIO() # In case the destination doesn't support seeking. 38 | np.savez(buffer, **_flatten(value)) 39 | dest.write(buffer.getvalue()) 40 | 41 | 42 | def load(source: BinaryIO, typ: type[_T]) -> _T: 43 | """Load from a file object and convert it to the specified type. 44 | 45 | Args: 46 | source: a file object to read from. 47 | typ: a type object that acts as a schema for deserialization. It must match 48 | what was serialized. If a type is Any, it will be returned however numpy 49 | serialized it, which is what you want for a tree of numpy arrays. 50 | 51 | Returns: 52 | the deserialized value as the specified type. 
53 | """ 54 | return _convert_types(typ, _unflatten(np.load(source))) 55 | 56 | 57 | _SEP = ":" 58 | 59 | 60 | def _flatten(tree: Any) -> dict[str, Any]: 61 | """Flatten a tree of dicts/dataclasses/lists/tuples to a single dict.""" 62 | if dataclasses.is_dataclass(tree): 63 | # Don't use dataclasses.asdict as it is recursive so skips dropping None. 64 | tree = {f.name: v for f in dataclasses.fields(tree) 65 | if (v := getattr(tree, f.name)) is not None} 66 | elif isinstance(tree, (list, tuple)): 67 | tree = dict(enumerate(tree)) 68 | 69 | assert isinstance(tree, dict) 70 | 71 | flat = {} 72 | for k, v in tree.items(): 73 | k = str(k) 74 | assert _SEP not in k 75 | if dataclasses.is_dataclass(v) or isinstance(v, (dict, list, tuple)): 76 | for a, b in _flatten(v).items(): 77 | flat[f"{k}{_SEP}{a}"] = b 78 | else: 79 | assert v is not None 80 | flat[k] = v 81 | return flat 82 | 83 | 84 | def _unflatten(flat: dict[str, Any]) -> dict[str, Any]: 85 | """Unflatten a dict to a tree of dicts.""" 86 | tree = {} 87 | for flat_key, v in flat.items(): 88 | node = tree 89 | keys = flat_key.split(_SEP) 90 | for k in keys[:-1]: 91 | if k not in node: 92 | node[k] = {} 93 | node = node[k] 94 | node[keys[-1]] = v 95 | return tree 96 | 97 | 98 | def _convert_types(typ: type[_T], value: Any) -> _T: 99 | """Convert some structure into the given type. The structures must match.""" 100 | if typ in (Any, ...): 101 | return value 102 | 103 | if typ in (int, float, str, bool): 104 | return typ(value) 105 | 106 | if typ is np.ndarray: 107 | assert isinstance(value, np.ndarray) 108 | return value 109 | 110 | if dataclasses.is_dataclass(typ): 111 | kwargs = {} 112 | for f in dataclasses.fields(typ): 113 | # Only support Optional for dataclasses, as numpy can't serialize it 114 | # directly (without pickle), and dataclasses are the only case where we 115 | # can know the full set of values and types and therefore know the 116 | # non-existence must mean None. 117 | if isinstance(f.type, (types.UnionType, type(Optional[int]))): 118 | constructors = [t for t in f.type.__args__ if t is not types.NoneType] 119 | if len(constructors) != 1: 120 | raise TypeError( 121 | "Optional works, Union with anything except None doesn't") 122 | if f.name not in value: 123 | kwargs[f.name] = None 124 | continue 125 | constructor = constructors[0] 126 | else: 127 | constructor = f.type 128 | 129 | if f.name in value: 130 | kwargs[f.name] = _convert_types(constructor, value[f.name]) 131 | else: 132 | raise ValueError(f"Missing value: {f.name}") 133 | return typ(**kwargs) 134 | 135 | base_type = getattr(typ, "__origin__", None) 136 | 137 | if base_type is dict: 138 | assert len(typ.__args__) == 2 139 | key_type, value_type = typ.__args__ 140 | return {_convert_types(key_type, k): _convert_types(value_type, v) 141 | for k, v in value.items()} 142 | 143 | if base_type is list: 144 | assert len(typ.__args__) == 1 145 | value_type = typ.__args__[0] 146 | return [_convert_types(value_type, v) 147 | for _, v in sorted(value.items(), key=lambda x: int(x[0]))] 148 | 149 | if base_type is tuple: 150 | if len(typ.__args__) == 2 and typ.__args__[1] == ...: 151 | # An arbitrary length tuple of a single type, eg: tuple[int, ...] 
152 | value_type = typ.__args__[0] 153 | return tuple(_convert_types(value_type, v) 154 | for _, v in sorted(value.items(), key=lambda x: int(x[0]))) 155 | else: 156 | # A fixed length tuple of arbitrary types, eg: tuple[int, str, float] 157 | assert len(typ.__args__) == len(value) 158 | return tuple( 159 | _convert_types(t, v) 160 | for t, (_, v) in zip( 161 | typ.__args__, sorted(value.items(), key=lambda x: int(x[0])))) 162 | 163 | # This is probably unreachable with reasonable serializable inputs. 164 | try: 165 | return typ(value) 166 | except TypeError as e: 167 | raise TypeError( 168 | "_convert_types expects the type argument to be a dataclass defined " 169 | "with types that are valid constructors (eg tuple is fine, Tuple " 170 | "isn't), and accept a numpy array as the sole argument.") from e 171 | -------------------------------------------------------------------------------- /GraphCast_src/graphcast/checkpoint_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Check that the checkpoint serialization is reversible.""" 15 | 16 | import dataclasses 17 | import io 18 | from typing import Any, Optional, Union 19 | 20 | from absl.testing import absltest 21 | from graphcast import checkpoint 22 | import numpy as np 23 | 24 | 25 | @dataclasses.dataclass 26 | class SubConfig: 27 | a: int 28 | b: str 29 | 30 | 31 | @dataclasses.dataclass 32 | class Config: 33 | bt: bool 34 | bf: bool 35 | i: int 36 | f: float 37 | o1: Optional[int] 38 | o2: Optional[int] 39 | o3: Union[int, None] 40 | o4: Union[int, None] 41 | o5: int | None 42 | o6: int | None 43 | li: list[int] 44 | ls: list[str] 45 | ldc: list[SubConfig] 46 | tf: tuple[float, ...] 47 | ts: tuple[str, ...] 48 | t: tuple[str, int, SubConfig] 49 | tdc: tuple[SubConfig, ...]
50 | dsi: dict[str, int] 51 | dss: dict[str, str] 52 | dis: dict[int, str] 53 | dsdis: dict[str, dict[int, str]] 54 | dc: SubConfig 55 | dco: Optional[SubConfig] 56 | ddc: dict[str, SubConfig] 57 | 58 | 59 | @dataclasses.dataclass 60 | class Checkpoint: 61 | params: dict[str, Any] 62 | config: Config 63 | 64 | 65 | class DataclassTest(absltest.TestCase): 66 | 67 | def test_serialize_dataclass(self): 68 | ckpt = Checkpoint( 69 | params={ 70 | "layer1": { 71 | "w": np.arange(10).reshape(2, 5), 72 | "b": np.array([2, 6]), 73 | }, 74 | "layer2": { 75 | "w": np.arange(8).reshape(2, 4), 76 | "b": np.array([2, 6]), 77 | }, 78 | "blah": np.array([3, 9]), 79 | }, 80 | config=Config( 81 | bt=True, 82 | bf=False, 83 | i=42, 84 | f=3.14, 85 | o1=1, 86 | o2=None, 87 | o3=2, 88 | o4=None, 89 | o5=3, 90 | o6=None, 91 | li=[12, 9, 7, 15, 16, 14, 1, 6, 11, 4, 10, 5, 13, 3, 8, 2], 92 | ls=list("qhjfdxtpzgemryoikwvblcaus"), 93 | ldc=[SubConfig(1, "hello"), SubConfig(2, "world")], 94 | tf=(1, 4, 2, 10, 5, 9, 13, 16, 15, 8, 12, 7, 11, 14, 3, 6), 95 | ts=("hello", "world"), 96 | t=("foo", 42, SubConfig(1, "bar")), 97 | tdc=(SubConfig(1, "hello"), SubConfig(2, "world")), 98 | dsi={"a": 1, "b": 2, "c": 3}, 99 | dss={"d": "e", "f": "g"}, 100 | dis={1: "a", 2: "b", 3: "c"}, 101 | dsdis={"a": {1: "hello", 2: "world"}, "b": {1: "world"}}, 102 | dc=SubConfig(1, "hello"), 103 | dco=None, 104 | ddc={"a": SubConfig(1, "hello"), "b": SubConfig(2, "world")}, 105 | )) 106 | 107 | buffer = io.BytesIO() 108 | checkpoint.dump(buffer, ckpt) 109 | buffer.seek(0) 110 | ckpt2 = checkpoint.load(buffer, Checkpoint) 111 | np.testing.assert_array_equal(ckpt.params["layer1"]["w"], 112 | ckpt2.params["layer1"]["w"]) 113 | np.testing.assert_array_equal(ckpt.params["layer1"]["b"], 114 | ckpt2.params["layer1"]["b"]) 115 | np.testing.assert_array_equal(ckpt.params["layer2"]["w"], 116 | ckpt2.params["layer2"]["w"]) 117 | np.testing.assert_array_equal(ckpt.params["layer2"]["b"], 118 | ckpt2.params["layer2"]["b"]) 119 | np.testing.assert_array_equal(ckpt.params["blah"], ckpt2.params["blah"]) 120 | self.assertEqual(ckpt.config, ckpt2.config) 121 | 122 | 123 | if __name__ == "__main__": 124 | absltest.main() 125 | -------------------------------------------------------------------------------- /GraphCast_src/graphcast/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Dataset utilities.""" 15 | 16 | from typing import Any, Mapping, Sequence, Tuple, Union 17 | 18 | import numpy as np 19 | import pandas as pd 20 | import xarray 21 | 22 | TimedeltaLike = Any # Something convertible to pd.Timedelta. 23 | TimedeltaStr = str # A string convertible to pd.Timedelta. 24 | 25 | TargetLeadTimes = Union[ 26 | TimedeltaLike, 27 | Sequence[TimedeltaLike], 28 | slice # with TimedeltaLike as its start and stop. 
29 | ] 30 | 31 | _SEC_PER_HOUR = 3600 32 | _HOUR_PER_DAY = 24 33 | SEC_PER_DAY = _SEC_PER_HOUR * _HOUR_PER_DAY 34 | _AVG_DAY_PER_YEAR = 365.24219 35 | AVG_SEC_PER_YEAR = SEC_PER_DAY * _AVG_DAY_PER_YEAR 36 | 37 | DAY_PROGRESS = "day_progress" 38 | YEAR_PROGRESS = "year_progress" 39 | 40 | 41 | def get_year_progress(seconds_since_epoch: np.ndarray) -> np.ndarray: 42 | """Computes year progress for times in seconds. 43 | 44 | Args: 45 | seconds_since_epoch: Times in seconds since the "epoch" (the point at which 46 | UNIX time starts). 47 | 48 | Returns: 49 | Year progress normalized to be in the [0, 1) interval for each time point. 50 | """ 51 | 52 | # Do the division in float64 to keep as much precision as possible before 53 | # reducing modulo 1. 54 | years_since_epoch = ( 55 | seconds_since_epoch / SEC_PER_DAY / np.float64(_AVG_DAY_PER_YEAR) 56 | ) 57 | # Note depending on how these ops are done, we may end up with a "weak_type" 58 | # which can cause issues in subtle ways that are hard to track down here. 59 | # In any case, casting to float32 should get rid of the weak type. 60 | # [0, 1.) Interval. 61 | return np.mod(years_since_epoch, 1.0).astype(np.float32) 62 | 63 | 64 | def get_day_progress( 65 | seconds_since_epoch: np.ndarray, 66 | longitude: np.ndarray, 67 | ) -> np.ndarray: 68 | """Computes day progress for times in seconds at each longitude. 69 | 70 | Args: 71 | seconds_since_epoch: 1D array of times in seconds since the 'epoch' (the 72 | point at which UNIX time starts). 73 | longitude: 1D array of longitudes at which day progress is computed. 74 | 75 | Returns: 76 | 2D array of day progress values normalized to be in the [0, 1) interval 77 | for each time point at each longitude. 78 | """ 79 | 80 | # [0.0, 1.0) Interval. 81 | day_progress_greenwich = ( 82 | np.mod(seconds_since_epoch, SEC_PER_DAY) / SEC_PER_DAY 83 | ) 84 | 85 | # Offset the day progress to the longitude of each point on Earth. 86 | longitude_offsets = np.deg2rad(longitude) / (2 * np.pi) 87 | day_progress = np.mod( 88 | day_progress_greenwich[..., np.newaxis] + longitude_offsets, 1.0 89 | ) 90 | return day_progress.astype(np.float32) 91 | 92 | 93 | def featurize_progress( 94 | name: str, dims: Sequence[str], progress: np.ndarray 95 | ) -> Mapping[str, xarray.Variable]: 96 | """Derives features used by ML models from the `progress` variable. 97 | 98 | Args: 99 | name: Base variable name from which features are derived. 100 | dims: List of the output feature dimensions, e.g. ("day", "lon"). 101 | progress: Progress variable values. 102 | 103 | Returns: 104 | Dictionary of xarray variables derived from the `progress` values. It 105 | includes the original `progress` variable along with its sin and cos 106 | transformations. 107 | 108 | Raises: 109 | ValueError if the number of feature dimensions is not equal to the number 110 | of data dimensions. 111 | """ 112 | if len(dims) != progress.ndim: 113 | raise ValueError( 114 | f"Number of feature dimensions ({len(dims)}) must be equal to the" 115 | f" number of data dimensions: {progress.ndim}." 116 | ) 117 | progress_phase = progress * (2 * np.pi) 118 | return { 119 | name: xarray.Variable(dims, progress), 120 | name + "_sin": xarray.Variable(dims, np.sin(progress_phase)), 121 | name + "_cos": xarray.Variable(dims, np.cos(progress_phase)), 122 | } 123 | 124 | 125 | def add_derived_vars(data: xarray.Dataset) -> None: 126 | """Adds year and day progress features to `data` in place.
127 | 128 | NOTE: `toa_incident_solar_radiation` needs to be computed in this function 129 | as well. 130 | 131 | Args: 132 | data: Xarray dataset to which derived features will be added. 133 | 134 | Raises: 135 | ValueError if `datetime` or `lon` are not in `data` coordinates. 136 | """ 137 | 138 | for coord in ("datetime", "lon"): 139 | if coord not in data.coords: 140 | raise ValueError(f"'{coord}' must be in `data` coordinates.") 141 | 142 | # Compute seconds since epoch. 143 | # Note `data.coords["datetime"].astype("datetime64[s]").astype(np.int64)` 144 | # does not work as xarrays always cast dates into nanoseconds! 145 | seconds_since_epoch = ( 146 | data.coords["datetime"].data.astype("datetime64[s]").astype(np.int64) 147 | ) 148 | batch_dim = ("batch",) if "batch" in data.dims else () 149 | 150 | # Add year progress features. 151 | year_progress = get_year_progress(seconds_since_epoch) 152 | data.update( 153 | featurize_progress( 154 | name=YEAR_PROGRESS, dims=batch_dim + ("time",), progress=year_progress 155 | ) 156 | ) 157 | 158 | # Add day progress features. 159 | longitude_coord = data.coords["lon"] 160 | day_progress = get_day_progress(seconds_since_epoch, longitude_coord.data) 161 | data.update( 162 | featurize_progress( 163 | name=DAY_PROGRESS, 164 | dims=batch_dim + ("time",) + longitude_coord.dims, 165 | progress=day_progress, 166 | ) 167 | ) 168 | 169 | 170 | def extract_input_target_times( 171 | dataset: xarray.Dataset, 172 | input_duration: TimedeltaLike, 173 | target_lead_times: TargetLeadTimes, 174 | ) -> Tuple[xarray.Dataset, xarray.Dataset]: 175 | """Extracts inputs and targets for prediction, from a Dataset with a time dim. 176 | 177 | The input period is assumed to be contiguous (specified by a duration), but 178 | the targets can be a list of arbitrary lead times. 179 | 180 | Examples: 181 | 182 | # Use 18 hours of data as inputs, and two specific lead times as targets: 183 | # 3 days and 5 days after the final input. 184 | extract_inputs_targets( 185 | dataset, 186 | input_duration='18h', 187 | target_lead_times=('3d', '5d') 188 | ) 189 | 190 | # Use 1 day of data as input, and all lead times between 6 hours and 191 | # 24 hours inclusive as targets. Demonstrates a friendlier supported string 192 | # syntax. 193 | extract_inputs_targets( 194 | dataset, 195 | input_duration='1 day', 196 | target_lead_times=slice('6 hours', '24 hours') 197 | ) 198 | 199 | # Just use a single target lead time of 3 days: 200 | extract_inputs_targets( 201 | dataset, 202 | input_duration='24h', 203 | target_lead_times='3d' 204 | ) 205 | 206 | Args: 207 | dataset: An xarray.Dataset with a 'time' dimension whose coordinates are 208 | timedeltas. It's assumed that the time coordinates have a fixed offset / 209 | time resolution, and that the input_duration and target_lead_times are 210 | multiples of this. 211 | input_duration: pandas.Timedelta or something convertible to it (e.g. a 212 | shorthand string like '6h' or '5d12h'). 213 | target_lead_times: Either a single lead time, a slice with start and stop 214 | (inclusive) lead times, or a sequence of lead times. Lead times should be 215 | Timedeltas (or something convertible to). They are given relative to the 216 | final input timestep, and should be positive. 
217 | 218 | Returns: 219 | inputs: 220 | targets: 221 | Two datasets with the same shape as the input dataset except that a 222 | selection has been made from the time axis, and the origin of the 223 | time coordinate will be shifted to refer to lead times relative to the 224 | final input timestep. So for inputs the times will end at lead time 0, 225 | for targets the time coordinates will refer to the lead times requested. 226 | """ 227 | 228 | (target_lead_times, target_duration 229 | ) = _process_target_lead_times_and_get_duration(target_lead_times) 230 | 231 | # Shift the coordinates for the time axis so that a timedelta of zero 232 | # corresponds to the forecast reference time. That is, the final timestep 233 | # that's available as input to the forecast, with all following timesteps 234 | # forming the target period which needs to be predicted. 235 | # This means the time coordinates are now forecast lead times. 236 | time = dataset.coords["time"] 237 | dataset = dataset.assign_coords(time=time + target_duration - time[-1]) 238 | 239 | # Slice out targets: 240 | targets = dataset.sel({"time": target_lead_times}) 241 | 242 | input_duration = pd.Timedelta(input_duration) 243 | # Both endpoints are inclusive with label-based slicing, so we offset by a 244 | # small epsilon to make one of the endpoints non-inclusive: 245 | zero = pd.Timedelta(0) 246 | epsilon = pd.Timedelta(1, "ns") 247 | inputs = dataset.sel({"time": slice(-input_duration + epsilon, zero)}) 248 | return inputs, targets 249 | 250 | 251 | def _process_target_lead_times_and_get_duration( 252 | target_lead_times: TargetLeadTimes) -> Tuple[TargetLeadTimes, TimedeltaLike]: 253 | """Returns the lead times and the duration up to the last (longest) one.""" 254 | if isinstance(target_lead_times, slice): 255 | # A slice of lead times. xarray already accepts timedelta-like values for 256 | # the begin/end/step of the slice. 257 | if target_lead_times.start is None: 258 | # If the start isn't specified, we assume it starts at the next timestep 259 | # after lead time 0 (lead time 0 is the final input timestep): 260 | target_lead_times = slice( 261 | pd.Timedelta(1, "ns"), target_lead_times.stop, target_lead_times.step 262 | ) 263 | target_duration = pd.Timedelta(target_lead_times.stop) 264 | else: 265 | if not isinstance(target_lead_times, (list, tuple, set)): 266 | # A single lead time, which we wrap as a length-1 array to ensure there 267 | # still remains a time dimension (here of length 1) for consistency. 268 | target_lead_times = [target_lead_times] 269 | 270 | # A list of multiple (not necessarily contiguous) lead times: 271 | target_lead_times = [pd.Timedelta(x) for x in target_lead_times] 272 | target_lead_times.sort() 273 | target_duration = target_lead_times[-1] 274 | return target_lead_times, target_duration 275 | 276 | 277 | def extract_inputs_targets_forcings( 278 | dataset: xarray.Dataset, 279 | *, 280 | input_variables: Tuple[str, ...], 281 | target_variables: Tuple[str, ...], 282 | forcing_variables: Tuple[str, ...], 283 | pressure_levels: Tuple[int, ...], 284 | input_duration: TimedeltaLike, 285 | target_lead_times: TargetLeadTimes, 286 | ) -> Tuple[xarray.Dataset, xarray.Dataset, xarray.Dataset]: 287 | """Extracts inputs, targets and forcings according to requirements.""" 288 | dataset = dataset.sel(level=list(pressure_levels)) 289 | 290 | # "Forcings" are derived variables and do not exist in the original ERA5 or 291 | # HRES datasets. Compute them if they are not in `dataset`.
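# (Editor's note, a hedged usage sketch of this function; the variable and
#  data names below are hypothetical, not taken from the original source:
#    inputs, targets, forcings = extract_inputs_targets_forcings(
#        dataset,
#        input_variables=("2m_temperature",),
#        target_variables=("2m_temperature",),
#        forcing_variables=("toa_incident_solar_radiation",),
#        pressure_levels=(500, 850),
#        input_duration="12h",
#        target_lead_times=slice("6h", "24h"))
#  Inputs then cover the 12 h up to lead time 0; targets and forcings cover
#  the requested lead times after it.)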
292 | if not set(forcing_variables).issubset(set(dataset.data_vars)): 293 | add_derived_vars(dataset) 294 | 295 | # `datetime` is needed by add_derived_vars but breaks autoregressive rollouts. 296 | dataset = dataset.drop_vars("datetime") 297 | 298 | inputs, targets = extract_input_target_times( 299 | dataset, 300 | input_duration=input_duration, 301 | target_lead_times=target_lead_times) 302 | 303 | if set(forcing_variables) & set(target_variables): 304 | raise ValueError( 305 | f"Forcing variables {forcing_variables} should not " 306 | f"overlap with target variables {target_variables}." 307 | ) 308 | 309 | inputs = inputs[list(input_variables)] 310 | # The forcing uses the same time coordinates as the target. 311 | forcings = targets[list(forcing_variables)] 312 | targets = targets[list(target_variables)] 313 | 314 | return inputs, targets, forcings 315 | -------------------------------------------------------------------------------- /GraphCast_src/graphcast/data_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for `data_utils.py`.""" 15 | 16 | import datetime 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from graphcast import data_utils 20 | import numpy as np 21 | import xarray 22 | 23 | 24 | class DataUtilsTest(parameterized.TestCase): 25 | 26 | def setUp(self): 27 | super().setUp() 28 | # Fix the seed for reproducibility. 29 | np.random.seed(0) 30 | 31 | def test_year_progress_is_zero_at_year_start_or_end(self): 32 | year_progress = data_utils.get_year_progress( 33 | np.array([ 34 | 0, 35 | data_utils.AVG_SEC_PER_YEAR, 36 | data_utils.AVG_SEC_PER_YEAR * 42, # 42 years. 
37 | ]) 38 | ) 39 | np.testing.assert_array_equal(year_progress, np.zeros(year_progress.shape)) 40 | 41 | def test_year_progress_is_almost_one_before_year_ends(self): 42 | year_progress = data_utils.get_year_progress( 43 | np.array([ 44 | data_utils.AVG_SEC_PER_YEAR - 1, 45 | (data_utils.AVG_SEC_PER_YEAR - 1) * 42, # ~42 years 46 | ]) 47 | ) 48 | with self.subTest("Year progress values are close to 1"): 49 | self.assertTrue(np.all(year_progress > 0.999)) 50 | with self.subTest("Year progress values != 1"): 51 | self.assertTrue(np.all(year_progress < 1.0)) 52 | 53 | def test_day_progress_computes_for_all_times_and_longitudes(self): 54 | times = np.random.randint(low=0, high=1e10, size=10) 55 | longitudes = np.arange(0, 360.0, 1.0) 56 | day_progress = data_utils.get_day_progress(times, longitudes) 57 | with self.subTest("Day progress is computed for all times and longitudes"): 58 | self.assertSequenceEqual( 59 | day_progress.shape, (len(times), len(longitudes)) 60 | ) 61 | 62 | @parameterized.named_parameters( 63 | dict( 64 | testcase_name="random_date_1", 65 | year=1988, 66 | month=11, 67 | day=7, 68 | hour=2, 69 | minute=45, 70 | second=34, 71 | ), 72 | dict( 73 | testcase_name="random_date_2", 74 | year=2022, 75 | month=3, 76 | day=12, 77 | hour=7, 78 | minute=1, 79 | second=0, 80 | ), 81 | ) 82 | def test_day_progress_is_in_between_zero_and_one( 83 | self, year, month, day, hour, minute, second 84 | ): 85 | # Datetime from a timestamp. 86 | dt = datetime.datetime(year, month, day, hour, minute, second) 87 | # Epoch time. 88 | epoch_time = datetime.datetime(1970, 1, 1) 89 | # Seconds since epoch. 90 | seconds_since_epoch = np.array([(dt - epoch_time).total_seconds()]) 91 | 92 | # Longitudes with 1 degree resolution. 93 | longitudes = np.arange(0, 360.0, 1.0) 94 | 95 | day_progress = data_utils.get_day_progress(seconds_since_epoch, longitudes) 96 | with self.subTest("Day progress >= 0"): 97 | self.assertTrue(np.all(day_progress >= 0.0)) 98 | with self.subTest("Day progress < 1"): 99 | self.assertTrue(np.all(day_progress < 1.0)) 100 | 101 | def test_day_progress_is_zero_at_day_start_or_end(self): 102 | day_progress = data_utils.get_day_progress( 103 | seconds_since_epoch=np.array([ 104 | 0, 105 | data_utils.SEC_PER_DAY, 106 | data_utils.SEC_PER_DAY * 42, # 42 days.
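# (Editor's note, worked arithmetic for the specific-value test nearby: at
#  longitude 0, day progress for t = 123 s is 123 / 86400 = 0.00142361...,
#  which is the constant asserted in test_day_progress_specific_value.)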
107 | ]), 108 | longitude=np.array([0.0]), 109 | ) 110 | np.testing.assert_array_equal(day_progress, np.zeros(day_progress.shape)) 111 | 112 | def test_day_progress_specific_value(self): 113 | day_progress = data_utils.get_day_progress( 114 | seconds_since_epoch=np.array([123]), 115 | longitude=np.array([0.0]), 116 | ) 117 | np.testing.assert_array_almost_equal( 118 | day_progress, np.array([[0.00142361]]), decimal=6 119 | ) 120 | 121 | def test_featurize_progress_valid_values_and_dimensions(self): 122 | day_progress = np.array([0.0, 0.45, 0.213]) 123 | feature_dimensions = ("time",) 124 | progress_features = data_utils.featurize_progress( 125 | name="day_progress", dims=feature_dimensions, progress=day_progress 126 | ) 127 | for feature in progress_features.values(): 128 | with self.subTest(f"Valid dimensions for {feature}"): 129 | self.assertSequenceEqual(feature.dims, feature_dimensions) 130 | 131 | with self.subTest("Valid values for day_progress"): 132 | np.testing.assert_array_equal( 133 | day_progress, progress_features["day_progress"].values 134 | ) 135 | 136 | with self.subTest("Valid values for day_progress_sin"): 137 | np.testing.assert_array_almost_equal( 138 | np.array([0.0, 0.30901699, 0.97309851]), 139 | progress_features["day_progress_sin"].values, 140 | decimal=6, 141 | ) 142 | 143 | with self.subTest("Valid values for day_progress_cos"): 144 | np.testing.assert_array_almost_equal( 145 | np.array([1.0, -0.95105652, 0.23038943]), 146 | progress_features["day_progress_cos"].values, 147 | decimal=6, 148 | ) 149 | 150 | def test_featurize_progress_invalid_dimensions(self): 151 | year_progress = np.array([0.0, 0.45, 0.213]) 152 | feature_dimensions = ("time", "longitude") 153 | with self.assertRaises(ValueError): 154 | data_utils.featurize_progress( 155 | name="year_progress", dims=feature_dimensions, progress=year_progress 156 | ) 157 | 158 | def test_add_derived_vars_variables_added(self): 159 | data = xarray.Dataset( 160 | data_vars={ 161 | "var1": (["x", "lon", "datetime"], 8 * np.random.randn(2, 2, 3)) 162 | }, 163 | coords={ 164 | "lon": np.array([0.0, 0.5]), 165 | "datetime": np.array([ 166 | datetime.datetime(2021, 1, 1), 167 | datetime.datetime(2023, 1, 1), 168 | datetime.datetime(2023, 1, 3), 169 | ]), 170 | }, 171 | ) 172 | data_utils.add_derived_vars(data) 173 | all_variables = set(data.variables) 174 | 175 | with self.subTest("Original value was not removed"): 176 | self.assertIn("var1", all_variables) 177 | with self.subTest("Year progress feature was added"): 178 | self.assertIn(data_utils.YEAR_PROGRESS, all_variables) 179 | with self.subTest("Day progress feature was added"): 180 | self.assertIn(data_utils.DAY_PROGRESS, all_variables) 181 | 182 | @parameterized.named_parameters( 183 | dict(testcase_name="missing_datetime", coord_name="lon"), 184 | dict(testcase_name="missing_lon", coord_name="datetime"), 185 | ) 186 | def test_add_derived_vars_missing_coordinate_raises_value_error( 187 | self, coord_name 188 | ): 189 | with self.subTest(f"Missing {coord_name} coordinate"): 190 | data = xarray.Dataset( 191 | data_vars={"var1": (["x", coord_name], 8 * np.random.randn(2, 2))}, 192 | coords={ 193 | coord_name: np.array([0.0, 0.5]), 194 | }, 195 | ) 196 | with self.assertRaises(ValueError): 197 | data_utils.add_derived_vars(data) 198 | 199 | 200 | if __name__ == "__main__": 201 | absltest.main() 202 | -------------------------------------------------------------------------------- /GraphCast_src/graphcast/grid_mesh_connectivity.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tools for converting from regular grids on a sphere, to triangular meshes.""" 15 | 16 | from graphcast import icosahedral_mesh 17 | import numpy as np 18 | import scipy 19 | import trimesh 20 | 21 | 22 | def _grid_lat_lon_to_coordinates( 23 | grid_latitude: np.ndarray, grid_longitude: np.ndarray) -> np.ndarray: 24 | """Lat [num_lat] lon [num_lon] to 3d coordinates [num_lat, num_lon, 3].""" 25 | # Convert to spherical coordinates phi and theta defined in the grid. 26 | # Each [num_latitude_points, num_longitude_points] 27 | phi_grid, theta_grid = np.meshgrid( 28 | np.deg2rad(grid_longitude), 29 | np.deg2rad(90 - grid_latitude)) 30 | 31 | # [num_latitude_points, num_longitude_points, 3] 32 | # Note this assumes unit radius, since for now we model the earth as a 33 | # sphere of unit radius, and keep any vertical dimension as a regular grid. 34 | return np.stack( 35 | [np.cos(phi_grid)*np.sin(theta_grid), 36 | np.sin(phi_grid)*np.sin(theta_grid), 37 | np.cos(theta_grid)], axis=-1) 38 | 39 | 40 | def radius_query_indices( 41 | *, 42 | grid_latitude: np.ndarray, 43 | grid_longitude: np.ndarray, 44 | mesh: icosahedral_mesh.TriangularMesh, 45 | radius: float) -> tuple[np.ndarray, np.ndarray]: 46 | """Returns mesh-grid edge indices for radius query. 47 | 48 | Args: 49 | grid_latitude: Latitude values for the grid [num_lat_points] 50 | grid_longitude: Longitude values for the grid [num_lon_points] 51 | mesh: Mesh object. 52 | radius: Radius of connectivity in R3. for a sphere of unit radius. 53 | 54 | Returns: 55 | tuple with `grid_indices` and `mesh_indices` indicating edges between the 56 | grid and the mesh such that the distances in a straight line (not geodesic) 57 | are smaller than or equal to `radius`. 58 | * grid_indices: Indices of shape [num_edges], that index into a 59 | [num_lat_points, num_lon_points] grid, after flattening the leading axes. 60 | * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices. 61 | """ 62 | 63 | # [num_grid_points=num_lat_points * num_lon_points, 3] 64 | grid_positions = _grid_lat_lon_to_coordinates( 65 | grid_latitude, grid_longitude).reshape([-1, 3]) 66 | 67 | # [num_mesh_points, 3] 68 | mesh_positions = mesh.vertices 69 | kd_tree = scipy.spatial.cKDTree(mesh_positions) 70 | 71 | # [num_grid_points, num_mesh_points_per_grid_point] 72 | # Note `num_mesh_points_per_grid_point` is not constant, so this is a list 73 | # of arrays, rather than a 2d array. 
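# (Editor's note, illustrative only: `query_ball_point` returns one neighbour
#  list per grid point, e.g. [[0, 4], [2], ...]; the loop below flattens this
#  into parallel edge arrays, grid [0, 0, 1, ...] and mesh [0, 4, 2, ...].)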
74 | query_indices = kd_tree.query_ball_point(x=grid_positions, r=radius) 75 | 76 | grid_edge_indices = [] 77 | mesh_edge_indices = [] 78 | for grid_index, mesh_neighbors in enumerate(query_indices): 79 | grid_edge_indices.append(np.repeat(grid_index, len(mesh_neighbors))) 80 | mesh_edge_indices.append(mesh_neighbors) 81 | 82 | # [num_edges] 83 | grid_edge_indices = np.concatenate(grid_edge_indices, axis=0).astype(int) 84 | mesh_edge_indices = np.concatenate(mesh_edge_indices, axis=0).astype(int) 85 | 86 | return grid_edge_indices, mesh_edge_indices 87 | 88 | 89 | def in_mesh_triangle_indices( 90 | *, 91 | grid_latitude: np.ndarray, 92 | grid_longitude: np.ndarray, 93 | mesh: icosahedral_mesh.TriangularMesh) -> tuple[np.ndarray, np.ndarray]: 94 | """Returns mesh-grid edge indices for grid points contained in mesh triangles. 95 | 96 | Args: 97 | grid_latitude: Latitude values for the grid [num_lat_points] 98 | grid_longitude: Longitude values for the grid [num_lon_points] 99 | mesh: Mesh object. 100 | 101 | Returns: 102 | tuple with `grid_indices` and `mesh_indices` indicating edges between the 103 | grid and the mesh vertices of the triangle that contain each grid point. 104 | The number of edges is always num_lat_points * num_lon_points * 3 105 | * grid_indices: Indices of shape [num_edges], that index into a 106 | [num_lat_points, num_lon_points] grid, after flattening the leading axes. 107 | * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices. 108 | """ 109 | 110 | # [num_grid_points=num_lat_points * num_lon_points, 3] 111 | grid_positions = _grid_lat_lon_to_coordinates( 112 | grid_latitude, grid_longitude).reshape([-1, 3]) 113 | 114 | mesh_trimesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces) 115 | 116 | # [num_grid_points] with mesh face indices for each grid point. 117 | _, _, query_face_indices = trimesh.proximity.closest_point( 118 | mesh_trimesh, grid_positions) 119 | 120 | # [num_grid_points, 3] with mesh node indices for each grid point. 121 | mesh_edge_indices = mesh.faces[query_face_indices] 122 | 123 | # [num_grid_points, 3] with grid node indices, where every row simply contains 124 | # the row (grid_point) index. 125 | grid_indices = np.arange(grid_positions.shape[0]) 126 | grid_edge_indices = np.tile(grid_indices.reshape([-1, 1]), [1, 3]) 127 | 128 | # Flatten to get a regular list. 129 | # [num_edges=num_grid_points*3] 130 | mesh_edge_indices = mesh_edge_indices.reshape([-1]) 131 | grid_edge_indices = grid_edge_indices.reshape([-1]) 132 | 133 | return grid_edge_indices, mesh_edge_indices 134 | -------------------------------------------------------------------------------- /GraphCast_src/graphcast/grid_mesh_connectivity_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
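# (Editor's note on `in_mesh_triangle_indices` above, a hedged shape check:
#  every grid point maps to exactly one triangle via
#  trimesh.proximity.closest_point and contributes one edge per triangle
#  vertex, so the 6 x 12 grid used in the smoke test below yields index
#  arrays of length 6 * 12 * 3 = 216.)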
14 | """Tests for graphcast.grid_mesh_connectivity.""" 15 | 16 | from absl.testing import absltest 17 | from graphcast import grid_mesh_connectivity 18 | from graphcast import icosahedral_mesh 19 | import numpy as np 20 | 21 | 22 | class GridMeshConnectivityTest(absltest.TestCase): 23 | 24 | def test_grid_lat_lon_to_coordinates(self): 25 | 26 | # Intervals of 30 degrees. 27 | grid_latitude = np.array([-45., 0., 45]) 28 | grid_longitude = np.array([0., 90., 180., 270.]) 29 | 30 | inv_sqrt2 = 1 / np.sqrt(2) 31 | expected_coordinates = np.array([ 32 | [[inv_sqrt2, 0., -inv_sqrt2], 33 | [0., inv_sqrt2, -inv_sqrt2], 34 | [-inv_sqrt2, 0., -inv_sqrt2], 35 | [0., -inv_sqrt2, -inv_sqrt2]], 36 | [[1., 0., 0.], 37 | [0., 1., 0.], 38 | [-1., 0., 0.], 39 | [0., -1., 0.]], 40 | [[inv_sqrt2, 0., inv_sqrt2], 41 | [0., inv_sqrt2, inv_sqrt2], 42 | [-inv_sqrt2, 0., inv_sqrt2], 43 | [0., -inv_sqrt2, inv_sqrt2]], 44 | ]) 45 | 46 | coordinates = grid_mesh_connectivity._grid_lat_lon_to_coordinates( 47 | grid_latitude, grid_longitude) 48 | np.testing.assert_allclose(expected_coordinates, coordinates, atol=1e-15) 49 | 50 | def test_radius_query_indices_smoke(self): 51 | # TODO(alvarosg): Add non-smoke test? 52 | grid_latitude = np.linspace(-75, 75, 6) 53 | grid_longitude = np.arange(12) * 30. 54 | mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere( 55 | splits=3)[-1] 56 | grid_mesh_connectivity.radius_query_indices( 57 | grid_latitude=grid_latitude, 58 | grid_longitude=grid_longitude, 59 | mesh=mesh, radius=0.2) 60 | 61 | def test_in_mesh_triangle_indices_smoke(self): 62 | # TODO(alvarosg): Add non-smoke test? 63 | grid_latitude = np.linspace(-75, 75, 6) 64 | grid_longitude = np.arange(12) * 30. 65 | mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere( 66 | splits=3)[-1] 67 | grid_mesh_connectivity.in_mesh_triangle_indices( 68 | grid_latitude=grid_latitude, 69 | grid_longitude=grid_longitude, 70 | mesh=mesh) 71 | 72 | 73 | if __name__ == "__main__": 74 | absltest.main() 75 | -------------------------------------------------------------------------------- /GraphCast_src/graphcast/icosahedral_mesh.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utils for creating icosahedral meshes.""" 15 | 16 | import itertools 17 | from typing import List, NamedTuple, Sequence, Tuple 18 | 19 | import numpy as np 20 | from scipy.spatial import transform 21 | 22 | 23 | class TriangularMesh(NamedTuple): 24 | """Data structure for triangular meshes. 25 | 26 | Attributes: 27 | vertices: spatial positions of the vertices of the mesh of shape 28 | [num_vertices, num_dims]. 29 | faces: triangular faces of the mesh of shape [num_faces, 3]. Contains 30 | integer indices into `vertices`. 
31 | 32 | """ 33 | vertices: np.ndarray 34 | faces: np.ndarray 35 | 36 | 37 | def merge_meshes( 38 | mesh_list: Sequence[TriangularMesh]) -> TriangularMesh: 39 | """Merges all meshes into one. Assumes the last mesh is the finest. 40 | 41 | Args: 42 | mesh_list: Sequence of meshes, from coarse to fine refinement levels. The 43 | vertices and faces may contain those from preceding, coarser levels. 44 | 45 | Returns: 46 | `TriangularMesh` for which the vertices correspond to the highest 47 | resolution mesh in the hierarchy, and the faces are the union of the 48 | faces at all levels of the hierarchy. 49 | """ 50 | for mesh_i, mesh_ip1 in itertools.pairwise(mesh_list): 51 | num_nodes_mesh_i = mesh_i.vertices.shape[0] 52 | assert np.allclose(mesh_i.vertices, mesh_ip1.vertices[:num_nodes_mesh_i]) 53 | 54 | return TriangularMesh( 55 | vertices=mesh_list[-1].vertices, 56 | faces=np.concatenate([mesh.faces for mesh in mesh_list], axis=0)) 57 | 58 | 59 | def get_hierarchy_of_triangular_meshes_for_sphere( 60 | splits: int) -> List[TriangularMesh]: 61 | """Returns a sequence of meshes, each a triangulation of the sphere. 62 | 63 | Starting with a regular icosahedron (12 vertices, 20 faces, 30 edges) with 64 | circumscribed unit sphere. Then, each triangular face is iteratively 65 | subdivided into 4 triangular faces `splits` times. The new vertices are then 66 | projected back onto the unit sphere. All resulting meshes are returned in a 67 | list, from lowest to highest resolution. 68 | 69 | The vertices in each face are specified in counter-clockwise order as 70 | observed from the outside of the icosahedron. 71 | 72 | Args: 73 | splits: How many times to split each triangle. 74 | Returns: 75 | Sequence of `TriangularMesh`s of length `splits + 1` each with: 76 | 77 | vertices: [num_vertices, 3] vertex positions in 3D, all with unit norm. 78 | faces: [num_faces, 3] with triangular faces joining sets of 3 vertices. 79 | Each row contains three indices into the vertices array, indicating 80 | the vertices adjacent to the face. Always with positive orientation 81 | (counter-clockwise when looking from the outside). 82 | """ 83 | current_mesh = get_icosahedron() 84 | output_meshes = [current_mesh] 85 | for _ in range(splits): 86 | current_mesh = _two_split_unit_sphere_triangle_faces(current_mesh) 87 | output_meshes.append(current_mesh) 88 | return output_meshes 89 | 90 | 91 | def get_icosahedron() -> TriangularMesh: 92 | """Returns a regular icosahedral mesh with circumscribed unit sphere. 93 | 94 | See https://en.wikipedia.org/wiki/Regular_icosahedron#Cartesian_coordinates 95 | for details on the construction of the regular icosahedron. 96 | 97 | The vertices in each face are specified in counter-clockwise order as observed 98 | from the outside of the icosahedron. 99 | 100 | Returns: 101 | TriangularMesh with: 102 | 103 | vertices: [num_vertices=12, 3] vertex positions in 3D, all with unit norm. 104 | faces: [num_faces=20, 3] with triangular faces joining sets of 3 vertices. 105 | Each row contains three indices into the vertices array, indicating 106 | the vertices adjacent to the face. Always with positive orientation ( 107 | counter-clockwise when looking from the outside).
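    Note (editor's addition, derived from the construction below): the 12 raw
    vertices are the cyclic permutations of (+-1, +-phi, 0) with
    phi = (1 + sqrt(5)) / 2, so each has norm sqrt(1 + phi**2) ~= 1.902; the
    division by np.linalg.norm([1., phi]) is what projects them onto the unit
    sphere.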
108 | 109 | """ 110 | phi = (1 + np.sqrt(5)) / 2 111 | vertices = [] 112 | for c1 in [1., -1.]: 113 | for c2 in [phi, -phi]: 114 | vertices.append((c1, c2, 0.)) 115 | vertices.append((0., c1, c2)) 116 | vertices.append((c2, 0., c1)) 117 | 118 | vertices = np.array(vertices, dtype=np.float32) 119 | vertices /= np.linalg.norm([1., phi]) 120 | 121 | # I did this manually, checking the orientation one by one. 122 | faces = [(0, 1, 2), 123 | (0, 6, 1), 124 | (8, 0, 2), 125 | (8, 4, 0), 126 | (3, 8, 2), 127 | (3, 2, 7), 128 | (7, 2, 1), 129 | (0, 4, 6), 130 | (4, 11, 6), 131 | (6, 11, 5), 132 | (1, 5, 7), 133 | (4, 10, 11), 134 | (4, 8, 10), 135 | (10, 8, 3), 136 | (10, 3, 9), 137 | (11, 10, 9), 138 | (11, 9, 5), 139 | (5, 9, 7), 140 | (9, 3, 7), 141 | (1, 6, 5), 142 | ] 143 | 144 | # By default the top is an edge parallel to the Y axis. 145 | # Need to rotate around the y axis by half the supplement of the 146 | # angle between faces to get the desired orientation. 147 | # /O\ (top edge) 148 | # / \ Z 149 | # (adjacent face)/ \ (adjacent face) ^ 150 | # / angle_between_faces \ | 151 | # / \ | 152 | # / \ YO-----> X 153 | # This results in: 154 | # (adjacent face is now top plane) 155 | # ----------------------O\ (top edge) 156 | # \ 157 | # \ 158 | # \ (adjacent face) 159 | # \ 160 | # \ 161 | # \ 162 | 163 | angle_between_faces = 2 * np.arcsin(phi / np.sqrt(3)) 164 | rotation_angle = (np.pi - angle_between_faces) / 2 165 | rotation = transform.Rotation.from_euler(seq="y", angles=rotation_angle) 166 | rotation_matrix = rotation.as_matrix() 167 | vertices = np.dot(vertices, rotation_matrix) 168 | 169 | return TriangularMesh(vertices=vertices.astype(np.float32), 170 | faces=np.array(faces, dtype=np.int32)) 171 | 172 | 173 | def _two_split_unit_sphere_triangle_faces( 174 | triangular_mesh: TriangularMesh) -> TriangularMesh: 175 | """Splits each triangular face into 4 triangles keeping the orientation.""" 176 | 177 | # Every time we split a triangle into 4 we will be adding 3 extra vertices, 178 | # located at the edge centres. 179 | # This class handles the positioning of the new vertices, and avoids creating 180 | # duplicates. 181 | new_vertices_builder = _ChildVerticesBuilder(triangular_mesh.vertices) 182 | 183 | new_faces = [] 184 | for ind1, ind2, ind3 in triangular_mesh.faces: 185 | # Transform each triangular face into 4 triangles, 186 | # preserving the orientation. 187 | # ind3 188 | # / \ 189 | # / \ 190 | # / #3 \ 191 | # / \ 192 | # ind31 -------------- ind23 193 | # / \ / \ 194 | # / \ #4 / \ 195 | # / #1 \ / #2 \ 196 | # / \ / \ 197 | # ind1 ------------ ind12 ------------ ind2 198 | ind12 = new_vertices_builder.get_new_child_vertex_index((ind1, ind2)) 199 | ind23 = new_vertices_builder.get_new_child_vertex_index((ind2, ind3)) 200 | ind31 = new_vertices_builder.get_new_child_vertex_index((ind3, ind1)) 201 | # Note how each of the 4 triangular new faces specifies the order of the 202 | # vertices to preserve the orientation of the original face. As the input 203 | # face should always be counter-clockwise as specified in the diagram, 204 | # this means child faces should also be counter-clockwise.
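# (Editor's note, a worked count consistent with _get_mesh_spec in the test
#  file: one split turns the icosahedron's 20 faces into 80 and adds one new
#  vertex per edge, i.e. 20 * 3 / 2 = 30, for 12 + 30 = 42 vertices.)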
205 | new_faces.extend([[ind1, ind12, ind31], # 1 206 | [ind12, ind2, ind23], # 2 207 | [ind31, ind23, ind3], # 3 208 | [ind12, ind23, ind31], # 4 209 | ]) 210 | return TriangularMesh(vertices=new_vertices_builder.get_all_vertices(), 211 | faces=np.array(new_faces, dtype=np.int32)) 212 | 213 | 214 | class _ChildVerticesBuilder(object): 215 | """Bookkeeping of new child vertices added to an existing set of vertices.""" 216 | 217 | def __init__(self, parent_vertices): 218 | 219 | # Because the same new vertex will be required when splitting adjacent 220 | # triangles (which share an edge) we keep them in a hash table indexed by 221 | # sorted indices of the vertices adjacent to the edge, to avoid creating 222 | # duplicated child vertices. 223 | self._child_vertices_index_mapping = {} 224 | self._parent_vertices = parent_vertices 225 | # We start with all previous vertices. 226 | self._all_vertices_list = list(parent_vertices) 227 | 228 | def _get_child_vertex_key(self, parent_vertex_indices): 229 | return tuple(sorted(parent_vertex_indices)) 230 | 231 | def _create_child_vertex(self, parent_vertex_indices): 232 | """Creates a new vertex.""" 233 | # Position for new vertex is the middle point between the parent points, 234 | # projected to the unit sphere. 235 | child_vertex_position = self._parent_vertices[ 236 | list(parent_vertex_indices)].mean(0) 237 | child_vertex_position /= np.linalg.norm(child_vertex_position) 238 | 239 | # Add the vertex to the output list. The index for this new vertex will 240 | # match the length of the list before adding it. 241 | child_vertex_key = self._get_child_vertex_key(parent_vertex_indices) 242 | self._child_vertices_index_mapping[child_vertex_key] = len( 243 | self._all_vertices_list) 244 | self._all_vertices_list.append(child_vertex_position) 245 | 246 | def get_new_child_vertex_index(self, parent_vertex_indices): 247 | """Returns index for a child vertex, creating it if necessary.""" 248 | # Get the key to see if we already have a new vertex in the middle. 249 | child_vertex_key = self._get_child_vertex_key(parent_vertex_indices) 250 | if child_vertex_key not in self._child_vertices_index_mapping: 251 | self._create_child_vertex(parent_vertex_indices) 252 | return self._child_vertices_index_mapping[child_vertex_key] 253 | 254 | def get_all_vertices(self): 255 | """Returns an array with all vertices (parents and children).""" 256 | return np.array(self._all_vertices_list) 257 | 258 | 259 | def faces_to_edges(faces: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 260 | """Transforms polygonal faces to sender and receiver indices. 261 | 262 | It does so by transforming every face into N_i edges. For example, if the 263 | triangular face has indices [0, 1, 2], three edges are added 0->1, 1->2, and 2->0. 264 | 265 | If all faces have consistent orientation, and the surface represented by the 266 | faces is closed, then every edge in a polygon with a certain orientation 267 | is also part of another polygon with the opposite orientation. In this 268 | situation, the edges returned by the method are always bidirectional. 269 | 270 | Args: 271 | faces: Integer array of shape [num_faces, 3]. Contains node indices 272 | adjacent to each face. 273 | Returns: 274 | Tuple with sender/receiver indices, each of shape [num_edges=num_faces*3].
275 | 276 | """ 277 | assert faces.ndim == 2 278 | assert faces.shape[-1] == 3 279 | senders = np.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]]) 280 | receivers = np.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]]) 281 | return senders, receivers 282 | -------------------------------------------------------------------------------- /GraphCast_src/graphcast/icosahedral_mesh_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for icosahedral_mesh.""" 15 | 16 | from absl.testing import absltest 17 | from absl.testing import parameterized 18 | import chex 19 | from graphcast import icosahedral_mesh 20 | import numpy as np 21 | 22 | 23 | def _get_mesh_spec(splits: int): 24 | """Returns size of the final icosahedral mesh resulting from the splitting.""" 25 | num_vertices = 12 26 | num_faces = 20 27 | for _ in range(splits): 28 | # Each previous face adds three new vertices, but each vertex is shared 29 | # by two faces. 30 | num_vertices += num_faces * 3 // 2 31 | num_faces *= 4 32 | return num_vertices, num_faces 33 | 34 | 35 | class IcosahedralMeshTest(parameterized.TestCase): 36 | 37 | def test_icosahedron(self): 38 | mesh = icosahedral_mesh.get_icosahedron() 39 | _assert_valid_mesh( 40 | mesh, num_expected_vertices=12, num_expected_faces=20) 41 | 42 | @parameterized.parameters(list(range(5))) 43 | def test_get_hierarchy_of_triangular_meshes_for_sphere(self, splits): 44 | meshes = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere( 45 | splits=splits) 46 | prev_vertices = None 47 | for mesh_i, mesh in enumerate(meshes): 48 | # Check that `mesh` is valid. 49 | num_expected_vertices, num_expected_faces = _get_mesh_spec(mesh_i) 50 | _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces) 51 | 52 | # Check that the first N vertices from this mesh match all of the 53 | # vertices from the previous mesh. 54 | if prev_vertices is not None: 55 | leading_mesh_vertices = mesh.vertices[:prev_vertices.shape[0]] 56 | np.testing.assert_array_equal(leading_mesh_vertices, prev_vertices) 57 | 58 | # Increase the expected/previous values for the next iteration. 59 | if mesh_i < len(meshes) - 1: 60 | prev_vertices = mesh.vertices 61 | 62 | @parameterized.parameters(list(range(4))) 63 | def test_merge_meshes(self, splits): 64 | mesh_hierarchy = ( 65 | icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere( 66 | splits=splits)) 67 | mesh = icosahedral_mesh.merge_meshes(mesh_hierarchy) 68 | 69 | expected_faces = np.concatenate([m.faces for m in mesh_hierarchy], axis=0) 70 | np.testing.assert_array_equal(mesh.vertices, mesh_hierarchy[-1].vertices) 71 | np.testing.assert_array_equal(mesh.faces, expected_faces) 72 | 73 | def test_faces_to_edges(self): 74 | 75 | faces = np.array([[0, 1, 2], 76 | [3, 4, 5]]) 77 | 78 | # This also documents the order of the edges returned by the method. 
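# (Editor's note: the interleaved order below follows from faces_to_edges
#  concatenating the columns of `faces`: all first edges, then all second,
#  then all third, so with two faces the output alternates between them.)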
79 | expected_edges = np.array( 80 | [[0, 1], 81 | [3, 4], 82 | [1, 2], 83 | [4, 5], 84 | [2, 0], 85 | [5, 3]]) 86 | expected_senders = expected_edges[:, 0] 87 | expected_receivers = expected_edges[:, 1] 88 | 89 | senders, receivers = icosahedral_mesh.faces_to_edges(faces) 90 | 91 | np.testing.assert_array_equal(senders, expected_senders) 92 | np.testing.assert_array_equal(receivers, expected_receivers) 93 | 94 | 95 | def _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces): 96 | vertices = mesh.vertices 97 | faces = mesh.faces 98 | chex.assert_shape(vertices, [num_expected_vertices, 3]) 99 | chex.assert_shape(faces, [num_expected_faces, 3]) 100 | 101 | # Vertices norm should be 1. 102 | vertices_norm = np.linalg.norm(vertices, axis=-1) 103 | np.testing.assert_allclose(vertices_norm, 1., rtol=1e-6) 104 | 105 | _assert_positive_face_orientation(vertices, faces) 106 | 107 | 108 | def _assert_positive_face_orientation(vertices, faces): 109 | 110 | # Obtain a unit vector that points, in the direction of the face. 111 | face_orientation = np.cross(vertices[faces[:, 1]] - vertices[faces[:, 0]], 112 | vertices[faces[:, 2]] - vertices[faces[:, 1]]) 113 | face_orientation /= np.linalg.norm(face_orientation, axis=-1, keepdims=True) 114 | 115 | # And a unit vector pointing from the origin to the center of the face. 116 | face_centers = vertices[faces].mean(1) 117 | face_centers /= np.linalg.norm(face_centers, axis=-1, keepdims=True) 118 | 119 | # Positive orientation means those two vectors should be parallel 120 | # (dot product, 1), and not anti-parallel (dot product, -1). 121 | dot_center_orientation = np.einsum("ik,ik->i", face_orientation, face_centers) 122 | 123 | # Check that the face normal is parallel to the vector that joins the center 124 | # of the face to the center of the sphere. Note we need a small tolerance 125 | # because some discretizations are not exactly uniform, so it will not be 126 | # exactly parallel. 127 | np.testing.assert_allclose(dot_center_orientation, 1., atol=6e-4) 128 | 129 | 130 | if __name__ == "__main__": 131 | absltest.main() 132 | -------------------------------------------------------------------------------- /GraphCast_src/graphcast/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Loss functions (and terms for use in loss functions) used for weather.""" 15 | 16 | from typing import Mapping, Callable 17 | 18 | from graphcast import xarray_tree 19 | import numpy as np 20 | from typing_extensions import Protocol 21 | import xarray 22 | 23 | 24 | LossAndDiagnostics = tuple[xarray.DataArray, xarray.Dataset] 25 | 26 | 27 | class LossFunction(Protocol): 28 | """A loss function. 29 | 30 | This is a protocol so it's fine to use a plain function which 'quacks like' 31 | this. This is just to document the interface. 
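  Example (editor's addition, a minimal conforming callable; `toy_mse_loss`
  is hypothetical and not part of the original interface):

    def toy_mse_loss(predictions, targets, **unused_kwargs):
      per_var = ((predictions - targets) ** 2).mean(
          [d for d in predictions.dims if d != 'batch'], skipna=False)
      loss = per_var.to_array('variable').sum('variable')
      return loss, per_var  # diagnostics: per-variable terms, dims ('batch',)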
32 | """ 33 | 34 | def __call__(self, 35 | predictions: xarray.Dataset, 36 | targets: xarray.Dataset, 37 | **optional_kwargs) -> LossAndDiagnostics: 38 | """Computes a loss function. 39 | 40 | Args: 41 | predictions: Dataset of predictions. 42 | targets: Dataset of targets. 43 | **optional_kwargs: Implementations may support extra optional kwargs. 44 | 45 | Returns: 46 | loss: A DataArray with dimensions ('batch',) containing losses for each 47 | element of the batch. These will be averaged to give the final 48 | loss, locally and across replicas. 49 | diagnostics: Mapping of additional quantities to log by name alongside the 50 | loss. These will will typically correspond to terms in the loss. They 51 | should also have dimensions ('batch',) and will be averaged over the 52 | batch before logging. 53 | """ 54 | 55 | 56 | def weighted_mse_per_level(predictions: xarray.Dataset, 57 | targets: xarray.Dataset, 58 | per_variable_weights: Mapping[str, float] 59 | ) -> LossAndDiagnostics: 60 | """Latitude- and pressure-level-weighted MSE loss.""" 61 | def mse_loss(prediction, target): 62 | return (prediction - target)**2 63 | 64 | return weighted_error_per_level(predictions, 65 | targets, 66 | per_variable_weights, 67 | mse_loss, 68 | normalized_level_weights) 69 | 70 | def weighted_error_per_level(predictions: xarray.Dataset, 71 | targets: xarray.Dataset, 72 | per_variable_weights: Mapping[str, float], 73 | loss_fn, 74 | per_level_weights_fn = None 75 | ) -> LossAndDiagnostics: 76 | """ 77 | Compute the loss function weighted per variable and per level. 78 | Moreover, weights the latitudes to account for unequal distribution of grid points on the sphere. 79 | """ 80 | def weighted_loss(prediction, target): 81 | loss = loss_fn(prediction, target) 82 | loss *= normalized_latitude_weights(target).astype(loss.dtype) 83 | if 'level' in target.dims: 84 | loss *= per_level_weights_fn(target).astype(loss.dtype) 85 | return _mean_preserving_batch(loss) 86 | 87 | losses = xarray_tree.map_structure(weighted_loss, predictions, targets) 88 | return sum_per_variable_losses(losses, per_variable_weights) 89 | 90 | 91 | def _mean_preserving_batch(x: xarray.DataArray) -> xarray.DataArray: 92 | return x.mean([d for d in x.dims if d != 'batch'], skipna=False) 93 | 94 | 95 | def sum_per_variable_losses( 96 | per_variable_losses: Mapping[str, xarray.DataArray], 97 | weights: Mapping[str, float], 98 | ) -> LossAndDiagnostics: 99 | """Weighted sum of per-variable losses.""" 100 | if not set(weights.keys()).issubset(set(per_variable_losses.keys())): 101 | raise ValueError( 102 | 'Passing a weight that does not correspond to any variable ' 103 | f'{set(weights.keys())-set(per_variable_losses.keys())}') 104 | 105 | weighted_per_variable_losses = { 106 | name: loss * weights.get(name, 1) 107 | for name, loss in per_variable_losses.items() 108 | } 109 | total = xarray.concat( 110 | weighted_per_variable_losses.values(), dim='variable', join='exact').sum( 111 | 'variable', skipna=False) 112 | return total, per_variable_losses 113 | 114 | 115 | def normalized_level_weights(data: xarray.DataArray) -> xarray.DataArray: 116 | """Weights proportional to pressure at each level.""" 117 | level = data.coords['level'] 118 | return level / level.mean(skipna=False) 119 | 120 | 121 | def single_level_weights(data: xarray.DataArray, level) -> xarray.DataArray: 122 | """Weights that select a single pressure level.""" 123 | l = data.coords['level'] 124 | weights = xarray.zeros_like(l) 125 | weights.loc[{'level': level}] = 1 # Replace '1' 
with the desired non-zero value 126 | return weights 127 | 128 | 129 | def normalized_latitude_weights(data: xarray.DataArray) -> xarray.DataArray: 130 | """Weights based on latitude, roughly proportional to grid cell area. 131 | 132 | This method supports two use cases only (both for equispaced values): 133 | * Latitude values such that the closest value to the pole is at latitude 134 | (90 - d_lat/2), where d_lat is the difference between contiguous latitudes. 135 | For example: [-89, -87, -85, ..., 85, 87, 89]) (d_lat = 2) 136 | In this case each point with `lat` value represents a sphere slice between 137 | `lat - d_lat/2` and `lat + d_lat/2`, and the area of this slice would be 138 | proportional to: 139 | `sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat)`, and 140 | we can simply omit the term `2 * sin(d_lat/2)` which is just a constant 141 | that cancels during normalization. 142 | * Latitude values that fall exactly at the poles. 143 | For example: [-90, -88, -86, ..., 86, 88, 90]) (d_lat = 2) 144 | In this case each point with `lat` value also represents 145 | a sphere slice between `lat - d_lat/2` and `lat + d_lat/2`, 146 | except for the points at the poles, that represent a slice between 147 | `90 - d_lat/2` and `90` or, `-90` and `-90 + d_lat/2`. 148 | The areas of the first type of point are still proportional to: 149 | * sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat) 150 | but for the points at the poles now is: 151 | * sin(90) - sin(90 - d_lat/2) = 2 * sin(d_lat/4) ^ 2 152 | and we will be using these weights, depending on whether we are looking at 153 | pole cells, or non-pole cells (omitting the common factor of 2 which will be 154 | absorbed by the normalization). 155 | 156 | It can be shown via a limit, or simple geometry, that in the small angles 157 | regime, the proportion of area per pole-point is equal to 1/8th 158 | the proportion of area covered by each of the nearest non-pole point, and we 159 | test for this in the test. 160 | 161 | Args: 162 | data: `DataArray` with latitude coordinates. 163 | Returns: 164 | Unit mean latitude weights. 165 | """ 166 | latitude = data.coords['lat'] 167 | 168 | if np.any(np.isclose(np.abs(latitude), 90.)): 169 | weights = _weight_for_latitude_vector_with_poles(latitude) 170 | else: 171 | weights = _weight_for_latitude_vector_without_poles(latitude) 172 | 173 | return weights / weights.mean(skipna=False) 174 | 175 | 176 | def _weight_for_latitude_vector_without_poles(latitude): 177 | """Weights for uniform latitudes of the form [+-90-+d/2, ..., -+90+-d/2].""" 178 | delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude)) 179 | if (not np.isclose(np.max(latitude), 90 - delta_latitude/2) or 180 | not np.isclose(np.min(latitude), -90 + delta_latitude/2)): 181 | raise ValueError( 182 | f'Latitude vector {latitude} does not start/end at ' 183 | '+- (90 - delta_latitude/2) degrees.') 184 | return np.cos(np.deg2rad(latitude)) 185 | 186 | 187 | def _weight_for_latitude_vector_with_poles(latitude): 188 | """Weights for uniform latitudes of the form [+- 90, ..., -+90].""" 189 | delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude)) 190 | if (not np.isclose(np.max(latitude), 90.) 
or 191 | not np.isclose(np.min(latitude), -90.)): 192 | raise ValueError( 193 | f'Latitude vector {latitude} does not start/end at +- 90 degrees.') 194 | weights = np.cos(np.deg2rad(latitude)) * np.sin(np.deg2rad(delta_latitude/2)) 195 | # The two checks above are enough to guarantee that the latitudes are sorted, 196 | # so the extremes are the poles. 197 | weights[[0, -1]] = np.sin(np.deg2rad(delta_latitude/4)) ** 2 198 | return weights 199 | 200 | 201 | def _check_uniform_spacing_and_get_delta(vector): 202 | diff = np.diff(vector) 203 | if not np.all(np.isclose(diff[0], diff)): 204 | raise ValueError(f'Vector {vector} is not uniformly spaced.') 205 | return diff[0] 206 | -------------------------------------------------------------------------------- /GraphCast_src/graphcast/predictor_base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Abstract base classes for an xarray-based Predictor API.""" 15 | 16 | import abc 17 | 18 | from typing import Tuple 19 | 20 | from graphcast import losses 21 | from graphcast import xarray_jax 22 | import jax.numpy as jnp 23 | import xarray 24 | 25 | LossAndDiagnostics = losses.LossAndDiagnostics 26 | 27 | 28 | class Predictor(abc.ABC): 29 | """A possibly-trainable predictor of weather, exposing an xarray-based API. 30 | 31 | Typically wraps an underlying JAX model and handles translating the xarray 32 | Dataset values to and from plain JAX arrays that are convenient for input to 33 | (and output from) the underlying model. 34 | 35 | Different subclasses may exist to wrap different kinds of underlying model, 36 | e.g. models taking stacked inputs/outputs, models taking separate 2D and 3D 37 | inputs/outputs, autoregressive models. 38 | 39 | You can also implement a specific model directly as a Predictor if you want, 40 | for example if it has quite specific/unique requirements for its input/output 41 | or loss function, or if it's convenient to implement directly using xarray. 42 | """ 43 | 44 | @abc.abstractmethod 45 | def __call__(self, 46 | inputs: xarray.Dataset, 47 | targets_template: xarray.Dataset, 48 | forcings: xarray.Dataset, 49 | **optional_kwargs 50 | ) -> xarray.Dataset: 51 | """Makes predictions. 52 | 53 | This is only used by the Experiment for inference / evaluation, with 54 | training going via the .loss method. So it should default to making 55 | predictions for evaluation, although you can also support making predictions 56 | for use in the loss via an is_training argument -- see 57 | LossFunctionPredictor which helps with that. 58 | 59 | Args: 60 | inputs: An xarray.Dataset of inputs. 61 | targets_template: An xarray.Dataset or other mapping of xarray.DataArrays, 62 | with the same shape as the targets, to demonstrate what kind of 63 | predictions are required. You can use this to determine which variables, 64 | levels and lead times must be predicted.
65 | You are free to raise an error if you don't support predicting what is 66 | requested. 67 | forcings: An xarray.Dataset of forcing terms. Forcings are variables 68 | that can be fed to the model, but do not need to be predicted. This is 69 | often because these variables can be computed analytically (e.g. the TOA 70 | radiation of the sun is mostly a function of geometry) or are considered 71 | to be controlled for the experiment (e.g., imposing a scenario of CO2 72 | emissions into the atmosphere). Unlike `inputs`, the `forcings` can 73 | include information "from the future", that is, information at target 74 | times specified in the `targets_template`. 75 | **optional_kwargs: Implementations may support extra optional kwargs, 76 | provided they set appropriate defaults for them. 77 | 78 | Returns: 79 | Predictions, as an xarray.Dataset or other mapping of DataArrays which 80 | is capable of being evaluated against targets with shape given by 81 | targets_template. 82 | For probabilistic predictors which can return multiple samples from a 83 | predictive distribution, these should (by convention) be returned along 84 | an additional 'sample' dimension. 85 | """ 86 | 87 | def loss(self, 88 | inputs: xarray.Dataset, 89 | targets: xarray.Dataset, 90 | forcings: xarray.Dataset, 91 | **optional_kwargs, 92 | ) -> LossAndDiagnostics: 93 | """Computes a training loss, for predictors that are trainable. 94 | 95 | Why make this the Predictor's responsibility, rather than letting callers 96 | compute their own loss function using predictions obtained from 97 | Predictor.__call__? 98 | 99 | Doing it this way gives Predictors more control over their training setup. 100 | For example, some predictors may wish to train using different targets to 101 | the ones they predict at evaluation time -- perhaps different lead times and 102 | variables, perhaps training to predict transformed versions of targets 103 | where the transform needs to be inverted at evaluation time, etc. 104 | 105 | It's also necessary for generative models (VAEs, GANs, ...) where the 106 | training loss is more complex and isn't expressible as a parameter-free 107 | function of predictions and targets. 108 | 109 | Args: 110 | inputs: An xarray.Dataset. 111 | targets: An xarray.Dataset or other mapping of xarray.DataArrays. See 112 | docs on __call__ for an explanation about the targets. 113 | forcings: xarray.Dataset of forcing terms. 114 | **optional_kwargs: Implementations may support extra optional kwargs, 115 | provided they set appropriate defaults for them. 116 | 117 | Returns: 118 | loss: A DataArray with dimensions ('batch',) containing losses for each 119 | element of the batch. These will be averaged to give the final 120 | loss, locally and across replicas. 121 | diagnostics: Mapping of additional quantities to log by name alongside the 122 | loss. These will typically correspond to terms in the loss. They 123 | should also have dimensions ('batch',) and will be averaged over the 124 | batch before logging. 125 | You need not include the loss itself in this dict; it will be added for 126 | you.
127 | """ 128 | del targets, forcings, optional_kwargs 129 | batch_size = inputs.sizes['batch'] 130 | dummy_loss = xarray_jax.DataArray(jnp.zeros(batch_size), dims=('batch',)) 131 | return dummy_loss, {} 132 | 133 | def loss_and_predictions( 134 | self, 135 | inputs: xarray.Dataset, 136 | targets: xarray.Dataset, 137 | forcings: xarray.Dataset, 138 | **optional_kwargs, 139 | ) -> Tuple[LossAndDiagnostics, xarray.Dataset]: 140 | """Like .loss but also returns corresponding predictions. 141 | 142 | Implementing this is optional as it's not used directly by the Experiment, 143 | but it is required by autoregressive.Predictor when applying an inner 144 | Predictor autoregressively at training time; we need a loss at each step but 145 | also predictions to feed back in for the next step. 146 | 147 | Note the loss itself may not be directly regressing the predictions towards 148 | targets, the loss may be computed in terms of transformed predictions and 149 | targets (or in some other way). For this reason we can't always cleanly 150 | separate this into step 1: get predictions, step 2: compute loss from them, 151 | hence the need for this combined method. 152 | 153 | Args: 154 | inputs: 155 | targets: 156 | forcings: 157 | **optional_kwargs: 158 | As for self.loss. 159 | 160 | Returns: 161 | (loss, diagnostics) 162 | As for self.loss 163 | predictions: 164 | The predictions which the loss relates to. These should be of the same 165 | shape as what you would get from 166 | `self.__call__(inputs, targets_template=targets)`, and should be in the 167 | same 'domain' as the inputs (i.e. they shouldn't be transformed 168 | differently to how the predictor expects its inputs). 169 | """ 170 | raise NotImplementedError 171 | -------------------------------------------------------------------------------- /GraphCast_src/graphcast/rollout.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utils for rolling out models.""" 15 | 16 | from typing import Iterator 17 | 18 | from absl import logging 19 | import chex 20 | import dask 21 | from graphcast import xarray_tree 22 | import jax 23 | import numpy as np 24 | import typing_extensions 25 | import xarray 26 | 27 | 28 | class PredictorFn(typing_extensions.Protocol): 29 | """Functional version of base.Predictor.__call__ with explicit rng.""" 30 | 31 | def __call__( 32 | self, rng: chex.PRNGKey, inputs: xarray.Dataset, 33 | targets_template: xarray.Dataset, 34 | forcings: xarray.Dataset, 35 | **optional_kwargs, 36 | ) -> xarray.Dataset: 37 | ... 
38 | 39 | 40 | def chunked_prediction( 41 | predictor_fn: PredictorFn, 42 | rng: chex.PRNGKey, 43 | inputs: xarray.Dataset, 44 | targets_template: xarray.Dataset, 45 | forcings: xarray.Dataset, 46 | num_steps_per_chunk: int = 1, 47 | verbose: bool = False, 48 | ) -> xarray.Dataset: 49 | """Outputs a long trajectory by iteratively concatenating chunked predictions. 50 | 51 | Args: 52 | predictor_fn: Function to use to make predictions for each chunk. 53 | rng: Random key. 54 | inputs: Inputs for the model. 55 | targets_template: Template for the target prediction, requires targets 56 | equispaced in time. 57 | forcings: Optional forcing for the model. 58 | num_steps_per_chunk: How many of the steps in `targets_template` to predict 59 | at each call of `predictor_fn`. It must evenly divide the number of 60 | steps in `targets_template`. 61 | verbose: Whether to log the current chunk being predicted. 62 | 63 | Returns: 64 | Predictions for the targets template. 65 | 66 | """ 67 | chunks_list = [] 68 | for prediction_chunk in chunked_prediction_generator( 69 | predictor_fn=predictor_fn, 70 | rng=rng, 71 | inputs=inputs, 72 | targets_template=targets_template, 73 | forcings=forcings, 74 | num_steps_per_chunk=num_steps_per_chunk, 75 | verbose=verbose): 76 | chunks_list.append(jax.device_get(prediction_chunk)) 77 | return xarray.concat(chunks_list, dim="time") 78 | 79 | 80 | def chunked_prediction_generator( 81 | predictor_fn: PredictorFn, 82 | rng: chex.PRNGKey, 83 | inputs: xarray.Dataset, 84 | targets_template: xarray.Dataset, 85 | forcings: xarray.Dataset, 86 | num_steps_per_chunk: int = 1, 87 | verbose: bool = False, 88 | ) -> Iterator[xarray.Dataset]: 89 | """Outputs a long trajectory by yielding chunked predictions. 90 | 91 | Args: 92 | predictor_fn: Function to use to make predictions for each chunk. 93 | rng: Random key. 94 | inputs: Inputs for the model. 95 | targets_template: Template for the target prediction, requires targets 96 | equispaced in time. 97 | forcings: Optional forcing for the model. 98 | num_steps_per_chunk: How many of the steps in `targets_template` to predict 99 | at each call of `predictor_fn`. It must evenly divide the number of 100 | steps in `targets_template`. 101 | verbose: Whether to log the current chunk being predicted. 102 | 103 | Yields: 104 | The predictions for each chunked step of the chunked rollout, such as 105 | if all predictions are concatenated in time this would match the targets 106 | template in structure. 107 | 108 | """ 109 | 110 | # Create copies to avoid mutating inputs. 
111 | inputs = xarray.Dataset(inputs) 112 | targets_template = xarray.Dataset(targets_template) 113 | forcings = xarray.Dataset(forcings) 114 | 115 | if "datetime" in inputs.coords: 116 | del inputs.coords["datetime"] 117 | 118 | if "datetime" in targets_template.coords: 119 | output_datetime = targets_template.coords["datetime"] 120 | del targets_template.coords["datetime"] 121 | else: 122 | output_datetime = None 123 | 124 | if "datetime" in forcings.coords: 125 | del forcings.coords["datetime"] 126 | 127 | num_target_steps = targets_template.dims["time"] 128 | num_chunks, remainder = divmod(num_target_steps, num_steps_per_chunk) 129 | if remainder != 0: 130 | raise ValueError( 131 | f"The number of steps per chunk {num_steps_per_chunk} must " 132 | f"evenly divide the number of target steps {num_target_steps} ") 133 | 134 | if len(np.unique(np.diff(targets_template.coords["time"].data))) > 1: 135 | raise ValueError("The targets time coordinates must be evenly spaced") 136 | 137 | # Our template targets will always have a time axis corresponding for the 138 | # timedeltas for the first chunk. 139 | targets_chunk_time = targets_template.time.isel( 140 | time=slice(0, num_steps_per_chunk)) 141 | 142 | current_inputs = inputs 143 | for chunk_index in range(num_chunks): 144 | if verbose: 145 | logging.info("Chunk %d/%d", chunk_index, num_chunks) 146 | logging.flush() 147 | 148 | # Select targets for the time period that we are predicting for this chunk. 149 | target_offset = num_steps_per_chunk * chunk_index 150 | target_slice = slice(target_offset, target_offset + num_steps_per_chunk) 151 | current_targets_template = targets_template.isel(time=target_slice) 152 | 153 | # Replace the timedelta, by the one corresponding to the first chunk, so we 154 | # don't recompile at every iteration, keeping the 155 | actual_target_time = current_targets_template.coords["time"] 156 | current_targets_template = current_targets_template.assign_coords( 157 | time=targets_chunk_time).compute() 158 | 159 | current_forcings = forcings.isel(time=target_slice) 160 | current_forcings = current_forcings.assign_coords(time=targets_chunk_time) 161 | current_forcings = current_forcings.compute() 162 | # Make predictions for the chunk. 163 | rng, this_rng = jax.random.split(rng) 164 | predictions = predictor_fn( 165 | rng=this_rng, 166 | inputs=current_inputs, 167 | targets_template=current_targets_template, 168 | forcings=current_forcings) 169 | 170 | next_frame = xarray.merge([predictions, current_forcings]) 171 | 172 | current_inputs = _get_next_inputs(current_inputs, next_frame) 173 | 174 | # At this point we can assign the actual targets time coordinates. 175 | predictions = predictions.assign_coords(time=actual_target_time) 176 | if output_datetime is not None: 177 | predictions.coords["datetime"] = output_datetime.isel( 178 | time=target_slice) 179 | yield predictions 180 | del predictions 181 | 182 | 183 | def _get_next_inputs( 184 | prev_inputs: xarray.Dataset, next_frame: xarray.Dataset, 185 | ) -> xarray.Dataset: 186 | """Computes next inputs, from previous inputs and predictions.""" 187 | 188 | # Make sure are are predicting all inputs with a time axis. 189 | non_predicted_or_forced_inputs = list( 190 | set(prev_inputs.keys()) - set(next_frame.keys())) 191 | if "time" in prev_inputs[non_predicted_or_forced_inputs].dims: 192 | raise ValueError( 193 | "Found an input with a time index that is not predicted or forced.") 194 | 195 | # Keys we need to copy from predictions to inputs. 
196 | next_inputs_keys = list( 197 | set(next_frame.keys()).intersection(set(prev_inputs.keys()))) 198 | next_inputs = next_frame[next_inputs_keys] 199 | 200 | # Apply concatenate next frame with inputs, crop what we don't need and 201 | # shift timedelta coordinates, so we don't recompile at every iteration. 202 | num_inputs = prev_inputs.dims["time"] 203 | return ( 204 | xarray.concat( 205 | [prev_inputs, next_inputs], dim="time", data_vars="different") 206 | .tail(time=num_inputs) 207 | .assign_coords(time=prev_inputs.coords["time"])) 208 | 209 | 210 | def extend_targets_template( 211 | targets_template: xarray.Dataset, 212 | required_num_steps: int) -> xarray.Dataset: 213 | """Extends `targets_template` to `required_num_steps` with lazy arrays. 214 | 215 | It uses lazy dask arrays of zeros, so it does not require instantiating the 216 | array in memory. 217 | 218 | Args: 219 | targets_template: Input template to extend. 220 | required_num_steps: Number of steps required in the returned template. 221 | 222 | Returns: 223 | `xarray.Dataset` identical in variables and timestep to `targets_template` 224 | full of `dask.array.zeros` such that the time axis has `required_num_steps`. 225 | 226 | """ 227 | 228 | # Extend the "time" and "datetime" coordinates 229 | time = targets_template.coords["time"] 230 | 231 | # Assert the first target time corresponds to the timestep. 232 | timestep = time[0].data 233 | if time.shape[0] > 1: 234 | assert np.all(timestep == time[1:] - time[:-1]) 235 | 236 | extended_time = (np.arange(required_num_steps) + 1) * timestep 237 | 238 | if "datetime" in targets_template.coords: 239 | datetime = targets_template.coords["datetime"] 240 | extended_datetime = (datetime[0].data - timestep) + extended_time 241 | else: 242 | extended_datetime = None 243 | 244 | # Replace the values with empty dask arrays extending the time coordinates. 245 | datetime = targets_template.coords["time"] 246 | 247 | def extend_time(data_array: xarray.DataArray) -> xarray.DataArray: 248 | dims = data_array.dims 249 | shape = list(data_array.shape) 250 | shape[dims.index("time")] = required_num_steps 251 | dask_data = dask.array.zeros( 252 | shape=tuple(shape), 253 | chunks=-1, # Will give chunk info directly to `ChunksToZarr``. 254 | dtype=data_array.dtype) 255 | 256 | coords = dict(data_array.coords) 257 | coords["time"] = extended_time 258 | 259 | if extended_datetime is not None: 260 | coords["datetime"] = ("time", extended_datetime) 261 | 262 | return xarray.DataArray( 263 | dims=dims, 264 | data=dask_data, 265 | coords=coords) 266 | 267 | return xarray_tree.map_structure(extend_time, targets_template) 268 | -------------------------------------------------------------------------------- /GraphCast_src/graphcast/typed_graph.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
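# A sketch (editorial addition, with hypothetical names) of how the rollout
# helpers in rollout.py above are typically driven; `run_forward_jitted`,
# `inputs`, `targets` and `forcings` are assumed to be prepared elsewhere,
# e.g. via data_utils.extract_inputs_targets_forcings as in
# graphcast_runner.py further below.
import jax
import numpy as np
from graphcast import rollout


def predict_trajectory(run_forward_jitted, inputs, targets, forcings):
  # NaN values in the template are ignored; only its variables, shape and
  # coordinates matter.
  return rollout.chunked_prediction(
      run_forward_jitted,
      rng=jax.random.PRNGKey(0),
      inputs=inputs,
      targets_template=targets * np.nan,
      forcings=forcings,
      num_steps_per_chunk=1)  # One model step per predictor_fn call.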
14 | """Data-structure for storing graphs with typed edges and nodes.""" 15 | 16 | from typing import NamedTuple, Any, Union, Tuple, Mapping, TypeVar 17 | 18 | ArrayLike = Union[Any] # np.ndarray, jnp.ndarray, tf.tensor 19 | ArrayLikeTree = Union[Any, ArrayLike] # Nest of ArrayLike 20 | 21 | _T = TypeVar('_T') 22 | 23 | 24 | # All tensors have a "flat_batch_axis", which is similar to the leading 25 | # axes of graph_tuples: 26 | # * In the case of nodes this is simply a shared node and flat batch axis, with 27 | # size corresponding to the total number of nodes in the flattened batch. 28 | # * In the case of edges this is simply a shared edge and flat batch axis, with 29 | # size corresponding to the total number of edges in the flattened batch. 30 | # * In the case of globals this is simply the number of graphs in the flattened 31 | # batch. 32 | 33 | # All shapes may also have any additional leading shape "batch_shape". 34 | # Options for building batches are: 35 | # * Use a provided "flatten" method that takes a leading `batch_shape` and 36 | # it into the flat_batch_axis (this will be useful when using `tf.Dataset` 37 | # which supports batching into RaggedTensors, with leading batch shape even 38 | # if graphs have different numbers of nodes and edges), so the RaggedBatches 39 | # can then be converted into something without ragged dimensions that jax can 40 | # use. 41 | # * Directly build a "flat batch" using a provided function for batching a list 42 | # of graphs (how it is done in `jraph`). 43 | 44 | 45 | class NodeSet(NamedTuple): 46 | """Represents a set of nodes.""" 47 | n_node: ArrayLike # [num_flat_graphs] 48 | features: ArrayLikeTree # Prev. `nodes`: [num_flat_nodes] + feature_shape 49 | 50 | 51 | class EdgesIndices(NamedTuple): 52 | """Represents indices to nodes adjacent to the edges.""" 53 | senders: ArrayLike # [num_flat_edges] 54 | receivers: ArrayLike # [num_flat_edges] 55 | 56 | 57 | class EdgeSet(NamedTuple): 58 | """Represents a set of edges.""" 59 | n_edge: ArrayLike # [num_flat_graphs] 60 | indices: EdgesIndices 61 | features: ArrayLikeTree # Prev. `edges`: [num_flat_edges] + feature_shape 62 | 63 | 64 | class Context(NamedTuple): 65 | # `n_graph` always contains ones but it is useful to query the leading shape 66 | # in case of graphs without any nodes or edges sets. 67 | n_graph: ArrayLike # [num_flat_graphs] 68 | features: ArrayLikeTree # Prev. `globals`: [num_flat_graphs] + feature_shape 69 | 70 | 71 | class EdgeSetKey(NamedTuple): 72 | name: str # Name of the EdgeSet. 73 | 74 | # Sender node set name and receiver node set name connected by the edge set. 75 | node_sets: Tuple[str, str] 76 | 77 | 78 | class TypedGraph(NamedTuple): 79 | """A graph with typed nodes and edges. 80 | 81 | A typed graph is made of a context, multiple sets of nodes and multiple 82 | sets of edges connecting those nodes (as indicated by the EdgeSetKey). 83 | """ 84 | 85 | context: Context 86 | nodes: Mapping[str, NodeSet] 87 | edges: Mapping[EdgeSetKey, EdgeSet] 88 | 89 | def edge_key_by_name(self, name: str) -> EdgeSetKey: 90 | found_key = [k for k in self.edges.keys() if k.name == name] 91 | if len(found_key) != 1: 92 | raise KeyError("invalid edge key '{}'. 
Available edges: [{}]".format( 93 | name, ', '.join(x.name for x in self.edges.keys()))) 94 | return found_key[0] 95 | 96 | def edge_by_name(self, name: str) -> EdgeSet: 97 | return self.edges[self.edge_key_by_name(name)] 98 | -------------------------------------------------------------------------------- /GraphCast_src/graphcast/typed_graph_net.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """A library of typed Graph Neural Networks.""" 15 | 16 | from typing import Callable, Mapping, Optional, Union 17 | 18 | from graphcast import typed_graph 19 | import jax.numpy as jnp 20 | import jax.tree_util as tree 21 | import jraph 22 | 23 | 24 | # All features will be an ArrayTree. 25 | NodeFeatures = EdgeFeatures = SenderFeatures = ReceiverFeatures = Globals = ( 26 | jraph.ArrayTree) 27 | 28 | # Signature: 29 | # (node features, outgoing edge features, incoming edge features, 30 | # globals) -> updated node features 31 | GNUpdateNodeFn = Callable[ 32 | [NodeFeatures, Mapping[str, SenderFeatures], Mapping[str, ReceiverFeatures], 33 | Globals], 34 | NodeFeatures] 35 | 36 | GNUpdateGlobalFn = Callable[ 37 | [Mapping[str, NodeFeatures], Mapping[str, EdgeFeatures], Globals], 38 | Globals] 39 | 40 | 41 | def GraphNetwork( # pylint: disable=invalid-name 42 | update_edge_fn: Mapping[str, jraph.GNUpdateEdgeFn], 43 | update_node_fn: Mapping[str, GNUpdateNodeFn], 44 | update_global_fn: Optional[GNUpdateGlobalFn] = None, 45 | aggregate_edges_for_nodes_fn: jraph.AggregateEdgesToNodesFn = jraph 46 | .segment_sum, 47 | aggregate_nodes_for_globals_fn: jraph.AggregateNodesToGlobalsFn = jraph 48 | .segment_sum, 49 | aggregate_edges_for_globals_fn: jraph.AggregateEdgesToGlobalsFn = jraph 50 | .segment_sum, 51 | ): 52 | """Returns a method that applies a configured GraphNetwork. 53 | 54 | This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261 55 | extended to Typed Graphs with multiple edge sets and node sets and extended to 56 | allow aggregating not only edges received by the nodes, but also edges sent by 57 | the nodes. 58 | 59 | Example usage:: 60 | 61 | gn = GraphNetwork(update_edge_function, 62 | update_node_function, **kwargs) 63 | # Conduct multiple rounds of message passing with the same parameters: 64 | for _ in range(num_message_passing_steps): 65 | graph = gn(graph) 66 | 67 | Args: 68 | update_edge_fn: mapping of functions used to update a subset of the edge 69 | types, indexed by edge type name. 70 | update_node_fn: mapping of functions used to update a subset of the node 71 | types, indexed by node type name. 72 | update_global_fn: function used to update the globals or None to deactivate 73 | globals updates. 74 | aggregate_edges_for_nodes_fn: function used to aggregate messages to each 75 | node. 76 | aggregate_nodes_for_globals_fn: function used to aggregate the nodes for the 77 | globals. 
78 | aggregate_edges_for_globals_fn: function used to aggregate the edges for the 79 | globals. 80 | 81 | Returns: 82 | A method that applies the configured GraphNetwork. 83 | """ 84 | 85 | def _apply_graph_net(graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph: 86 | """Applies a configured GraphNetwork to a graph. 87 | 88 | This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261 89 | extended to Typed Graphs with multiple edge sets and node sets and extended 90 | to allow aggregating not only edges received by the nodes, but also edges 91 | sent by the nodes. 92 | 93 | Args: 94 | graph: a `TypedGraph` containing the graph. 95 | 96 | Returns: 97 | Updated `TypedGraph`. 98 | """ 99 | 100 | updated_graph = graph 101 | 102 | # Edge update. 103 | updated_edges = dict(updated_graph.edges) 104 | for edge_set_name, edge_fn in update_edge_fn.items(): 105 | edge_set_key = graph.edge_key_by_name(edge_set_name) 106 | updated_edges[edge_set_key] = _edge_update( 107 | updated_graph, edge_fn, edge_set_key) 108 | updated_graph = updated_graph._replace(edges=updated_edges) 109 | 110 | # Node update. 111 | updated_nodes = dict(updated_graph.nodes) 112 | for node_set_key, node_fn in update_node_fn.items(): 113 | updated_nodes[node_set_key] = _node_update( 114 | updated_graph, node_fn, node_set_key, aggregate_edges_for_nodes_fn) 115 | updated_graph = updated_graph._replace(nodes=updated_nodes) 116 | 117 | # Global update. 118 | if update_global_fn: 119 | updated_context = _global_update( 120 | updated_graph, update_global_fn, 121 | aggregate_edges_for_globals_fn, 122 | aggregate_nodes_for_globals_fn) 123 | updated_graph = updated_graph._replace(context=updated_context) 124 | 125 | return updated_graph 126 | 127 | return _apply_graph_net 128 | 129 | 130 | def _edge_update(graph, edge_fn, edge_set_key): # pylint: disable=invalid-name 131 | """Updates an edge set of a given key.""" 132 | 133 | sender_nodes = graph.nodes[edge_set_key.node_sets[0]] 134 | receiver_nodes = graph.nodes[edge_set_key.node_sets[1]] 135 | edge_set = graph.edges[edge_set_key] 136 | senders = edge_set.indices.senders # pytype: disable=attribute-error 137 | receivers = edge_set.indices.receivers # pytype: disable=attribute-error 138 | 139 | sent_attributes = tree.tree_map( 140 | lambda n: n[senders], sender_nodes.features) 141 | received_attributes = tree.tree_map( 142 | lambda n: n[receivers], receiver_nodes.features) 143 | 144 | n_edge = edge_set.n_edge 145 | sum_n_edge = senders.shape[0] 146 | global_features = tree.tree_map( 147 | lambda g: jnp.repeat(g, n_edge, axis=0, total_repeat_length=sum_n_edge), 148 | graph.context.features) 149 | new_features = edge_fn( 150 | edge_set.features, sent_attributes, received_attributes, 151 | global_features) 152 | return edge_set._replace(features=new_features) 153 | 154 | 155 | def _node_update(graph, node_fn, node_set_key, aggregation_fn): # pylint: disable=invalid-name 156 | """Updates an edge set of a given key.""" 157 | node_set = graph.nodes[node_set_key] 158 | sum_n_node = tree.tree_leaves(node_set.features)[0].shape[0] 159 | 160 | sent_features = {} 161 | for edge_set_key, edge_set in graph.edges.items(): 162 | sender_node_set_key = edge_set_key.node_sets[0] 163 | if sender_node_set_key == node_set_key: 164 | assert isinstance(edge_set.indices, typed_graph.EdgesIndices) 165 | senders = edge_set.indices.senders 166 | sent_features[edge_set_key.name] = tree.tree_map( 167 | lambda e: aggregation_fn(e, senders, sum_n_node), edge_set.features) # pylint: 
disable=cell-var-from-loop 168 | 169 | received_features = {} 170 | for edge_set_key, edge_set in graph.edges.items(): 171 | receiver_node_set_key = edge_set_key.node_sets[1] 172 | if receiver_node_set_key == node_set_key: 173 | assert isinstance(edge_set.indices, typed_graph.EdgesIndices) 174 | receivers = edge_set.indices.receivers 175 | received_features[edge_set_key.name] = tree.tree_map( 176 | lambda e: aggregation_fn(e, receivers, sum_n_node), edge_set.features) # pylint: disable=cell-var-from-loop 177 | 178 | n_node = node_set.n_node 179 | global_features = tree.tree_map( 180 | lambda g: jnp.repeat(g, n_node, axis=0, total_repeat_length=sum_n_node), 181 | graph.context.features) 182 | new_features = node_fn( 183 | node_set.features, sent_features, received_features, global_features) 184 | return node_set._replace(features=new_features) 185 | 186 | 187 | def _global_update(graph, global_fn, edge_aggregation_fn, node_aggregation_fn): # pylint: disable=invalid-name 188 | """Updates an edge set of a given key.""" 189 | n_graph = graph.context.n_graph.shape[0] 190 | graph_idx = jnp.arange(n_graph) 191 | 192 | edge_features = {} 193 | for edge_set_key, edge_set in graph.edges.items(): 194 | assert isinstance(edge_set.indices, typed_graph.EdgesIndices) 195 | sum_n_edge = edge_set.indices.senders.shape[0] 196 | edge_gr_idx = jnp.repeat( 197 | graph_idx, edge_set.n_edge, axis=0, total_repeat_length=sum_n_edge) 198 | edge_features[edge_set_key.name] = tree.tree_map( 199 | lambda e: edge_aggregation_fn(e, edge_gr_idx, n_graph), # pylint: disable=cell-var-from-loop 200 | edge_set.features) 201 | 202 | node_features = {} 203 | for node_set_key, node_set in graph.nodes.items(): 204 | sum_n_node = tree.tree_leaves(node_set.features)[0].shape[0] 205 | node_gr_idx = jnp.repeat( 206 | graph_idx, node_set.n_node, axis=0, total_repeat_length=sum_n_node) 207 | node_features[node_set_key] = tree.tree_map( 208 | lambda n: node_aggregation_fn(n, node_gr_idx, n_graph), # pylint: disable=cell-var-from-loop 209 | node_set.features) 210 | 211 | new_features = global_fn(node_features, edge_features, graph.context.features) 212 | return graph.context._replace(features=new_features) 213 | 214 | 215 | InteractionUpdateNodeFn = Callable[ 216 | [jraph.NodeFeatures, 217 | Mapping[str, SenderFeatures], 218 | Mapping[str, ReceiverFeatures]], 219 | jraph.NodeFeatures] 220 | 221 | 222 | InteractionUpdateNodeFnNoSentEdges = Callable[ 223 | [jraph.NodeFeatures, 224 | Mapping[str, ReceiverFeatures]], 225 | jraph.NodeFeatures] 226 | 227 | 228 | def InteractionNetwork( # pylint: disable=invalid-name 229 | update_edge_fn: Mapping[str, jraph.InteractionUpdateEdgeFn], 230 | update_node_fn: Mapping[str, Union[InteractionUpdateNodeFn, 231 | InteractionUpdateNodeFnNoSentEdges]], 232 | aggregate_edges_for_nodes_fn: jraph.AggregateEdgesToNodesFn = jraph 233 | .segment_sum, 234 | include_sent_messages_in_node_update: bool = False): 235 | """Returns a method that applies a configured InteractionNetwork. 236 | 237 | An interaction network computes interactions on the edges based on the 238 | previous edges features, and on the features of the nodes sending into those 239 | edges. It then updates the nodes based on the incoming updated edges. 240 | See https://arxiv.org/abs/1612.00222 for more details. 241 | 242 | This implementation extends the behavior to `TypedGraphs` adding an option 243 | to include edge features for which a node is a sender in the arguments to 244 | the node update function. 
245 | 246 | Args: 247 | update_edge_fn: mapping of functions used to update a subset of the edge 248 | types, indexed by edge type name. 249 | update_node_fn: mapping of functions used to update a subset of the node 250 | types, indexed by node type name. 251 | aggregate_edges_for_nodes_fn: function used to aggregate messages to each 252 | node. 253 | include_sent_messages_in_node_update: pass edge features for which a node is 254 | a sender to the node update function. 255 | """ 256 | # An InteractionNetwork is a GraphNetwork without globals features, 257 | # so we implement the InteractionNetwork as a configured GraphNetwork. 258 | 259 | # An InteractionNetwork edge function does not have global feature inputs, 260 | # so we filter the passed global argument in the GraphNetwork. 261 | wrapped_update_edge_fn = tree.tree_map( 262 | lambda fn: lambda e, s, r, g: fn(e, s, r), update_edge_fn) 263 | 264 | # Similarly, we wrap the update_node_fn to ensure only the expected 265 | # arguments are passed to the Interaction net. 266 | if include_sent_messages_in_node_update: 267 | wrapped_update_node_fn = tree.tree_map( 268 | lambda fn: lambda n, s, r, g: fn(n, s, r), update_node_fn) 269 | else: 270 | wrapped_update_node_fn = tree.tree_map( 271 | lambda fn: lambda n, s, r, g: fn(n, r), update_node_fn) 272 | return GraphNetwork( 273 | update_edge_fn=wrapped_update_edge_fn, 274 | update_node_fn=wrapped_update_node_fn, 275 | aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn) 276 | 277 | 278 | def GraphMapFeatures( # pylint: disable=invalid-name 279 | embed_edge_fn: Optional[Mapping[str, jraph.EmbedEdgeFn]] = None, 280 | embed_node_fn: Optional[Mapping[str, jraph.EmbedNodeFn]] = None, 281 | embed_global_fn: Optional[jraph.EmbedGlobalFn] = None): 282 | """Returns function which embeds the components of a graph independently. 283 | 284 | Args: 285 | embed_edge_fn: mapping of functions used to embed each edge type, 286 | indexed by edge type name. 287 | embed_node_fn: mapping of functions used to embed each node type, 288 | indexed by node type name. 289 | embed_global_fn: function used to embed the globals. 290 | """ 291 | 292 | def _embed(graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph: 293 | 294 | updated_edges = dict(graph.edges) 295 | if embed_edge_fn: 296 | for edge_set_name, embed_fn in embed_edge_fn.items(): 297 | edge_set_key = graph.edge_key_by_name(edge_set_name) 298 | edge_set = graph.edges[edge_set_key] 299 | updated_edges[edge_set_key] = edge_set._replace( 300 | features=embed_fn(edge_set.features)) 301 | 302 | updated_nodes = dict(graph.nodes) 303 | if embed_node_fn: 304 | for node_set_key, embed_fn in embed_node_fn.items(): 305 | node_set = graph.nodes[node_set_key] 306 | updated_nodes[node_set_key] = node_set._replace( 307 | features=embed_fn(node_set.features)) 308 | 309 | updated_context = graph.context 310 | if embed_global_fn: 311 | updated_context = updated_context._replace( 312 | features=embed_global_fn(updated_context.features)) 313 | 314 | return graph._replace(edges=updated_edges, nodes=updated_nodes, 315 | context=updated_context) 316 | 317 | return _embed 318 | -------------------------------------------------------------------------------- /GraphCast_src/graphcast/xarray_tree.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utilities for working with trees of xarray.DataArray (including Datasets). 15 | 16 | Note that xarray.Dataset doesn't work out-of-the-box with the `tree` library; 17 | it won't work as a leaf node since it implements Mapping, but also won't work 18 | as an internal node since tree doesn't know how to re-create it properly. 19 | 20 | To fix this, we reimplement a subset of `map_structure`, exposing its 21 | constituent DataArrays as leaf nodes. This means it can be mapped over as a 22 | generic container of DataArrays, while still preserving the result as a Dataset 23 | where possible. 24 | 25 | This is useful because in a few places we need to handle a general 26 | Mapping[str, DataArray] (where the coordinates might not be compatible across 27 | the constituent DataArrays) but also the special case of a Dataset nicely. 28 | 29 | For the result e.g. of a tree.map_structure(fn, dataset), if fn returns None for 30 | some of the child DataArrays, they will be omitted from the returned dataset. If 31 | any values other than DataArrays or None are returned, then we don't attempt to 32 | return a Dataset and just return a plain dict of the results. Similarly if 33 | DataArrays are returned but with non-matching coordinates, it will just return a 34 | plain dict of DataArrays. 35 | 36 | Note xarray datatypes are registered with `jax.tree_util` by xarray_jax.py, 37 | but `jax.tree_util.tree_map` is distinct from the `xarray_tree.map_structure`. 38 | as the former exposes the underlying JAX/numpy arrays as leaf nodes, while the 39 | latter exposes DataArrays as leaf nodes. 40 | """ 41 | 42 | from typing import Any, Callable 43 | 44 | import xarray 45 | 46 | 47 | def map_structure(func: Callable[..., Any], *structures: Any) -> Any: 48 | """Maps func through given structures with xarrays. See tree.map_structure.""" 49 | if not callable(func): 50 | raise TypeError(f'func must be callable, got: {func}') 51 | if not structures: 52 | raise ValueError('Must provide at least one structure') 53 | 54 | first = structures[0] 55 | if isinstance(first, xarray.Dataset): 56 | data = {k: func(*[s[k] for s in structures]) for k in first.keys()} 57 | if all(isinstance(a, (type(None), xarray.DataArray)) 58 | for a in data.values()): 59 | data_arrays = [v.rename(k) for k, v in data.items() if v is not None] 60 | try: 61 | return xarray.merge(data_arrays, join='exact') 62 | except ValueError: # Exact join not possible. 63 | pass 64 | return data 65 | if isinstance(first, dict): 66 | return {k: map_structure(func, *[s[k] for s in structures]) 67 | for k in first.keys()} 68 | if isinstance(first, (list, tuple, set)): 69 | return type(first)(map_structure(func, *s) for s in zip(*structures)) 70 | return func(*structures) 71 | -------------------------------------------------------------------------------- /GraphCast_src/graphcast/xarray_tree_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for xarray_tree.""" 15 | 16 | from absl.testing import absltest 17 | from graphcast import xarray_tree 18 | import numpy as np 19 | import xarray 20 | 21 | 22 | TEST_DATASET = xarray.Dataset( 23 | data_vars={ 24 | "foo": (("x", "y"), np.zeros((2, 3))), 25 | "bar": (("x",), np.zeros((2,))), 26 | }, 27 | coords={ 28 | "x": [1, 2], 29 | "y": [10, 20, 30], 30 | } 31 | ) 32 | 33 | 34 | class XarrayTreeTest(absltest.TestCase): 35 | 36 | def test_map_structure_maps_over_leaves_but_preserves_dataset_type(self): 37 | def fn(leaf): 38 | self.assertIsInstance(leaf, xarray.DataArray) 39 | result = leaf + 1 40 | # Removing the name from the returned DataArray to test that we don't rely 41 | # on it being present to restore the correct names in the result: 42 | result = result.rename(None) 43 | return result 44 | 45 | result = xarray_tree.map_structure(fn, TEST_DATASET) 46 | self.assertIsInstance(result, xarray.Dataset) 47 | self.assertSameElements({"foo", "bar"}, result.keys()) 48 | 49 | def test_map_structure_on_data_arrays(self): 50 | data_arrays = dict(TEST_DATASET) 51 | result = xarray_tree.map_structure(lambda x: x+1, data_arrays) 52 | self.assertIsInstance(result, dict) 53 | self.assertSameElements({"foo", "bar"}, result.keys()) 54 | 55 | def test_map_structure_on_dataset_plain_dict_when_coords_incompatible(self): 56 | def fn(leaf): 57 | # Returns DataArrays that can't be exactly merged back into a Dataset 58 | # due to the coordinates not matching: 59 | if leaf.name == "foo": 60 | return xarray.DataArray( 61 | data=np.zeros(2), dims=("x",), coords={"x": [1, 2]}) 62 | else: 63 | return xarray.DataArray( 64 | data=np.zeros(2), dims=("x",), coords={"x": [3, 4]}) 65 | 66 | result = xarray_tree.map_structure(fn, TEST_DATASET) 67 | self.assertIsInstance(result, dict) 68 | self.assertSameElements({"foo", "bar"}, result.keys()) 69 | 70 | def test_map_structure_on_dataset_drops_vars_with_none_return_values(self): 71 | def fn(leaf): 72 | return leaf if leaf.name == "foo" else None 73 | 74 | result = xarray_tree.map_structure(fn, TEST_DATASET) 75 | self.assertIsInstance(result, xarray.Dataset) 76 | self.assertSameElements({"foo"}, result.keys()) 77 | 78 | def test_map_structure_on_dataset_returns_plain_dict_other_return_types(self): 79 | def fn(leaf): 80 | self.assertIsInstance(leaf, xarray.DataArray) 81 | return "not a DataArray" 82 | 83 | result = xarray_tree.map_structure(fn, TEST_DATASET) 84 | self.assertEqual({"foo": "not a DataArray", 85 | "bar": "not a DataArray"}, result) 86 | 87 | def test_map_structure_two_args_different_variable_orders(self): 88 | dataset_different_order = TEST_DATASET[["bar", "foo"]] 89 | def fn(arg1, arg2): 90 | self.assertEqual(arg1.name, arg2.name) 91 | xarray_tree.map_structure(fn, TEST_DATASET, dataset_different_order) 92 | 93 | 94 | if __name__ == "__main__": 95 | absltest.main() 96 | -------------------------------------------------------------------------------- 
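# A short usage sketch (editorial addition) for xarray_tree.map_structure from
# the module above, mirroring the behaviours exercised by its tests: mapping a
# function over a Dataset preserves the Dataset type, and returning None for a
# variable drops it from the result.
import numpy as np
import xarray
from graphcast import xarray_tree

ds = xarray.Dataset(
    data_vars={"foo": (("x",), np.arange(3.0)), "bar": (("x",), np.ones(3))},
    coords={"x": [0, 1, 2]})

# DataArray results with compatible coords are merged back into a Dataset.
squared = xarray_tree.map_structure(lambda da: da ** 2, ds)
assert isinstance(squared, xarray.Dataset)

# Returning None for a variable drops it from the resulting Dataset.
only_foo = xarray_tree.map_structure(
    lambda da: da if da.name == "foo" else None, ds)
assert set(only_foo.keys()) == {"foo"}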
/GraphCast_src/graphcast_runner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # GraphCast 5 | # 6 | # This colab lets you run several versions of GraphCast. 7 | # 8 | # The model weights, normalization statistics, and example inputs are available on [Google Cloud Bucket](https://console.cloud.google.com/storage/browser/dm_graphcast). 9 | # 10 | # A Colab runtime with TPU/GPU acceleration will substantially speed up generating predictions and computing the loss/gradients. If you're using a CPU-only runtime, you can switch using the menu "Runtime > Change runtime type". 11 | 12 | # > Copyright 2023 DeepMind Technologies Limited. 13 | # > Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. 14 | # > Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
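# A sketch of typical programmatic use of this runner (an editorial
# illustration; the path below is hypothetical, and the keyword values mirror
# the argparse defaults defined at the bottom of this file):
#
#   os.environ["DATA_PATH"] = "/path/to/model_cache"
#   main(resolution=0.25, pressure_levels=13, autoregressive_steps=1,
#        test_years=[2016], test_variable="geopotential",
#        test_pressure_level=500)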
15 | 16 | # Use most memory-conservative allocation scheme 17 | # See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html 18 | # Either set this in your environment or set this before any import of jax code 19 | import os 20 | from pathlib import Path 21 | 22 | import graphcast_wrapper 23 | from weatherbench2_dataloader import WeatherBench2Dataset 24 | 25 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 26 | os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" 27 | 28 | import argparse 29 | import dataclasses 30 | import functools 31 | import time 32 | 33 | from graphcast import losses 34 | from graphcast import data_utils 35 | from graphcast import graphcast 36 | from graphcast import rollout 37 | from graphcast import xarray_jax 38 | from graphcast import xarray_tree 39 | import haiku as hk 40 | import jax 41 | import numpy as np 42 | 43 | from buckets import authenticate_bucket 44 | from demo_data import load_example_batch, save_example_batch, load_normalization 45 | 46 | 47 | def train_model(loss_fn: hk.TransformedWithState, params, train_inputs, train_targets, train_forcings, epochs=5): 48 | """ 49 | @param loss_fn: a hk_transform wrapped loss function (encapsulates the model as well) 50 | @param params: initial model parameters (from a checkpoint) 51 | @param train_inputs 52 | @param train_targets 53 | @param train_forcings 54 | @param epochs 55 | """ 56 | assert epochs is not None 57 | 58 | def grads_fn(params, state, inputs, targets, forcings): 59 | def _aux(params, state, i, t, f): 60 | (loss, diagnostics), next_state = loss_fn.apply(params, 61 | state, 62 | jax.random.PRNGKey(0), 63 | i, 64 | t, 65 | f) 66 | return loss, (diagnostics, next_state) 67 | 68 | # TODO add reduce_axes=('batch',) 69 | (loss, (diagnostics, next_state)), grads = jax.value_and_grad(_aux, has_aux=True)(params, 70 | state, 71 | inputs, 72 | targets, 73 | forcings) 74 | return loss, diagnostics, next_state, grads 75 | 76 | grads_fn_jitted = jax.jit(grads_fn) 77 | 78 | runtimes = [] 79 | for i in range(epochs): 80 | tic = time.perf_counter() 81 | # Gradient computation (backprop through time) 82 | loss, diagnostics, next_state, grads = grads_fn_jitted( 83 | params=params, 84 | state={}, 85 | inputs=train_inputs, 86 | targets=train_targets, 87 | forcings=train_forcings) 88 | jax.block_until_ready(grads) 89 | jax.block_until_ready(loss) 90 | mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0]) 91 | print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}") 92 | toc = time.perf_counter() 93 | print(f"Step {i} took {toc-tic}s") 94 | if i > 0: 95 | runtimes.append(toc-tic) 96 | print("Training step time: ", np.mean(np.asarray(runtimes)), " +-", np.std(np.asarray(runtimes))) 97 | 98 | 99 | def evaluate_model(fwd_cost_fn, task_config: graphcast.TaskConfig, dataloader, autoregressive_steps=1): 100 | """ 101 | Perform inference using the given forward cost function. 102 | Assumes each autoregressive step has a lead time of 6 hours. 103 | @param fwd_cost_fn Cost function that takes inputs, targets, and forcings to return a scalar cost. 104 | @param task_config Corresponding task configuration (must match model underlying the fwd_cost_fn) 105 | @param dataloader must support __len__ and __getitem__. Each batch should be an xarray.Dataarray 106 | @param autoregressive_steps how many times to unroll the model. Must match with the batch provided from the dataloader. 
107 | """ 108 | costs = [] 109 | for i in range(len(dataloader)): 110 | batch = dataloader[i] 111 | 112 | inputs, targets, forcings = data_utils.extract_inputs_targets_forcings(batch, 113 | target_lead_times=f"{autoregressive_steps * 6}h", 114 | # Note that this sets input duration to 12h 115 | **dataclasses.asdict(task_config)) 116 | 117 | cost = fwd_cost_fn(inputs, targets, forcings) 118 | print(f"Batch {i}/{len(dataloader)} cost: {cost}") 119 | costs.append(cost) 120 | 121 | mean_cost = np.asarray(costs).mean() 122 | print("Mean cost:", mean_cost, " +-", np.asarray(costs).std()) 123 | return mean_cost 124 | 125 | def rmse_forward(inputs, targets, forcings, forward_fn, level: int, variable_weights: dict[str, float]) -> float: 126 | """ 127 | Compute the RMSE given a forward function 128 | @param inputs 129 | @param targets 130 | @param forcings 131 | @param foward_fn Function that computes the prediction. Will be rolled out autoregressively. 132 | @param level Pressure level to evaluate the RMSE on 133 | @param variable_weights weight to assign to each target variable 134 | """ 135 | predictions = rollout.chunked_prediction( 136 | forward_fn, 137 | rng=jax.random.PRNGKey(353), 138 | inputs=inputs, 139 | targets_template=targets * np.nan, 140 | forcings=forcings) 141 | def mse(x, y): 142 | return (x-y) ** 2 143 | 144 | mse_error, _ = losses.weighted_error_per_level(predictions, 145 | targets, 146 | variable_weights, 147 | mse, 148 | functools.partial(losses.single_level_weights, level=level)) 149 | return np.sqrt(mse_error.mean().item()) 150 | 151 | def main(resolution: float = 0.25, 152 | pressure_levels: int = 13, 153 | autoregressive_steps: int = 1, 154 | test_years=None, 155 | test_variable: str = 'geopotential', 156 | test_pressure_level: int = 500, 157 | repetitions: int = 5) -> None: 158 | """ 159 | resolution: resolution of the model in degrees 160 | pressure_levels: number of pressure levels 161 | train: If true, computes gradients of the model (currently does not actually train anything) 162 | autoregressive_steps: How many rollout steps to perform. If 1, a single-step prediction is done. 163 | repetitions: For time measurement purposes, how many repetition are done. 164 | """ 165 | # 166 | # - **Source**: era5, hres 167 | # - **Resolution**: 0.25deg, 1deg 168 | # - **Levels**: 13, 37 169 | # 170 | # Not all combinations are available. 171 | # - HRES is only available in 0.25 deg, with 13 pressure levels. 
172 | 173 | if test_years is None: 174 | test_years = [2016] 175 | data_path = os.environ.get('DATA_PATH') 176 | 177 | run_forward, checkpoint = graphcast_wrapper.retrieve_model(resolution, pressure_levels, Path(data_path)) 178 | 179 | # Always pass params so the usage below are simpler 180 | def with_params(fn): 181 | return functools.partial(fn, params=checkpoint.params) 182 | 183 | # # Run the model (Inference) 184 | dataset = WeatherBench2Dataset( 185 | year=test_years[0], 186 | steps=autoregressive_steps, 187 | steps_per_input=3) 188 | 189 | # Compile the forward function and add the configuration and params as partials 190 | run_forward_jitted = with_params(jax.jit(run_forward.apply)) 191 | 192 | # Pick the test_variable as the only one 193 | variable_weights = {var: 1 if var == test_variable else 0 for var in checkpoint.task_config.target_variables} 194 | # We create the loss function by passing the forward function and parameters as a partial to rmse_forward 195 | loss_function = functools.partial(rmse_forward, 196 | forward_fn = run_forward_jitted, 197 | level=test_pressure_level, 198 | variable_weights=variable_weights) 199 | evaluate_model(loss_function, checkpoint.task_config, dataset, autoregressive_steps=autoregressive_steps) 200 | 201 | 202 | if __name__ == "__main__": 203 | parser = argparse.ArgumentParser(description='Inference and training showcase for Graphcast.') 204 | 205 | # Add the arguments with default values 206 | parser.add_argument('--resolution', type=float, default=0.25, help='Resolution of the graph in the model.') 207 | parser.add_argument('--pressure_levels', type=int, default=13, help='Number of pressure levels in the model.') 208 | parser.add_argument('--autoregressive_steps', type=int, default=1, help='Number of time steps to predict into the future.') 209 | parser.add_argument('--test_year_start', type=int, default=2016, help='First year to use for testing (inference).') 210 | parser.add_argument('--test_year_end', type=int, default=2016, help='Last year to use for testing (inference).') 211 | parser.add_argument('--test_pressure_level', type=int, default=500, help='Pressure level to use for testing (inference).') 212 | parser.add_argument('--test_variable', type=str, default='geopotential', help='Variable to use for testing (inference).') 213 | parser.add_argument('--prediction_store_path', type=str, default=None, help='If not none, evaluate predictions and store them here.') 214 | 215 | 216 | # Parse the arguments 217 | args = parser.parse_args() 218 | 219 | # Access the arguments & call main 220 | main(args.resolution, 221 | args.pressure_levels, 222 | args.autoregressive_steps, 223 | list(range(args.test_year_start, args.test_year_end+1)), 224 | args.test_variable, 225 | args.test_pressure_level) 226 | -------------------------------------------------------------------------------- /GraphCast_src/graphcast_wrapper.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from pathlib import Path 3 | 4 | import haiku 5 | import xarray 6 | 7 | import graphcast.casting as casting 8 | import graphcast.normalization as normalization 9 | import graphcast.autoregressive as autoregressive 10 | from graphcast import graphcast 11 | import haiku as hk 12 | from graphcast.checkpoint import load 13 | 14 | from buckets import authenticate_bucket 15 | from demo_data import load_normalization 16 | from pretrained_graphcast import load_model_checkpoint, find_model_name, save_checkpoint 17 | 18 | 19 | def 
wrap_graphcast(model_config: graphcast.ModelConfig, 20 | task_config: graphcast.TaskConfig, 21 | diffs_stddev_by_level = None, 22 | mean_by_level = None, 23 | stddev_by_level = None, 24 | wrap_autoregressive: bool = False, 25 | normalize: bool = False): 26 | """ 27 | Constructs and wraps the GraphCast Predictor. 28 | Note that this MUST be called within a haiku transform function. 29 | """ 30 | # Deeper one-step predictor. 31 | predictor = graphcast.GraphCast(model_config, task_config) 32 | 33 | # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to 34 | # from/to float32 to/from BFloat16. 35 | predictor = casting.Bfloat16Cast(predictor) 36 | 37 | if normalize: 38 | assert diffs_stddev_by_level is not None 39 | assert mean_by_level is not None 40 | assert stddev_by_level is not None 41 | # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from 42 | # BFloat16 happens after applying normalization to the inputs/targets. 43 | predictor = normalization.InputsAndResiduals( 44 | predictor, 45 | diffs_stddev_by_level=diffs_stddev_by_level, 46 | mean_by_level=mean_by_level, 47 | stddev_by_level=stddev_by_level) 48 | 49 | if wrap_autoregressive: 50 | # Wraps everything so the one-step model can produce trajectories. 51 | predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True) 52 | return predictor 53 | else: 54 | return predictor 55 | 56 | 57 | def retrieve_model(resolution: float, 58 | pressure_levels: int, 59 | model_cache_path: Path = None, 60 | normalize: bool = True, 61 | autoregressive: bool = True, 62 | normalize_cache_path: Path = None) -> tuple[haiku.Transformed, graphcast.CheckPoint]: 63 | """ 64 | Returns the haiku transformed forward function and model checkpoint for the given settings. 65 | The signature of the apply function is 66 | (params, 67 | rng, 68 | inputs: xarray.Dataset, 69 | targets_template: xarray.Dataset, 70 | forcings: xarray.Dataset 71 | ) 72 | Note that there is not rng because the model is deterministic. 73 | 74 | Possible model values: 75 | - **Resolution**: 0.25deg, 1deg 76 | - **Pressure Levels**: 13, 37 77 | 78 | Not all combinations are available. 79 | - HRES is only available in 0.25 deg, with 13 pressure levels. 80 | 81 | """ 82 | gcs_bucket = authenticate_bucket() 83 | 84 | # Choose the model 85 | params_file_options = [ 86 | name for blob in gcs_bucket.list_blobs(prefix="params/") 87 | if (name := blob.name.removeprefix("params/"))] # Drop empty string. 
85 |     params_file_name = find_model_name(params_file_options, resolution, pressure_levels)
86 |     if params_file_name is None:
87 |         raise FileNotFoundError(
88 |             f"No model with given resolution ({resolution} deg) and pressure levels ({pressure_levels}) found.")
89 | 
90 |     # Use the cached model checkpoint if available; otherwise download and cache it.
91 |     if (model_cache_path / params_file_name).exists():
92 |         with open(model_cache_path / params_file_name, "rb") as f:
93 |             checkpoint = load(f, graphcast.CheckPoint)
94 |     else:
95 |         checkpoint = load_model_checkpoint(gcs_bucket, params_file_name)
96 |         save_checkpoint(gcs_bucket, model_cache_path, params_file_name)
97 |     print(checkpoint.model_config)
98 | 
99 |     diffs_stddev_by_level, mean_by_level, stddev_by_level = None, None, None
100 |     if normalize:
101 |         # Load normalization data
102 |         diffs_stddev_by_level, mean_by_level, stddev_by_level = load_normalization(gcs_bucket, normalize_cache_path)
103 | 
104 |     # Build the haiku-transformed forward function
105 |     def run_forward(inputs: xarray.Dataset,
106 |                     targets_template: xarray.Dataset,
107 |                     forcings: xarray.Dataset):
108 |         predictor = wrap_graphcast(checkpoint.model_config,
109 |                                    checkpoint.task_config,
110 |                                    diffs_stddev_by_level,
111 |                                    mean_by_level,
112 |                                    stddev_by_level,
113 |                                    wrap_autoregressive=autoregressive,
114 |                                    normalize=normalize)
115 |         return predictor(inputs, targets_template=targets_template, forcings=forcings)
116 | 
117 |     forward = haiku.transform(run_forward)
118 | 
119 |     return forward, checkpoint
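120 | 
121 | # Illustrative usage sketch (not executed here; assumes valid GCS credentials
122 | # and a local cache directory, and mirrors how graphcast_runner.py composes
123 | # the pieces; `inputs`, `targets_template` and `forcings` are placeholders):
124 | #
125 | #   import functools, jax
126 | #   forward, ckpt = retrieve_model(resolution=0.25, pressure_levels=13,
127 | #                                  model_cache_path=Path("model_cache"))
128 | #   apply_fn = functools.partial(jax.jit(forward.apply), params=ckpt.params)
129 | #   predictions = apply_fn(rng=None, inputs=inputs,
130 | #                          targets_template=targets_template, forcings=forcings)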
--------------------------------------------------------------------------------
/GraphCast_src/plotting.py:
--------------------------------------------------------------------------------
1 | 
2 | import datetime
3 | import math
4 | from typing import Optional
5 | from IPython.display import HTML
6 | import ipywidgets as widgets
7 | import matplotlib
8 | import matplotlib.pyplot as plt
9 | from matplotlib import animation
10 | import numpy as np
11 | import xarray
12 | 
13 | def select(
14 |     data: xarray.Dataset,
15 |     variable: str,
16 |     level: Optional[int] = None,
17 |     max_steps: Optional[int] = None
18 |     ) -> xarray.DataArray:
19 |   data = data[variable]
20 |   if "batch" in data.dims:
21 |     data = data.isel(batch=0)
22 |   if max_steps is not None and "time" in data.sizes and max_steps < data.sizes["time"]:
23 |     data = data.isel(time=range(0, max_steps))
24 |   if level is not None and "level" in data.coords:
25 |     data = data.sel(level=level)
26 |   return data
27 | 
28 | def scale(
29 |     data: xarray.DataArray,
30 |     center: Optional[float] = None,
31 |     robust: bool = False,
32 |     ) -> tuple[xarray.DataArray, matplotlib.colors.Normalize, str]:
33 |   vmin = np.nanpercentile(data, (2 if robust else 0))
34 |   vmax = np.nanpercentile(data, (98 if robust else 100))
35 |   if center is not None:
36 |     diff = max(vmax - center, center - vmin)
37 |     vmin = center - diff
38 |     vmax = center + diff
39 |   return (data, matplotlib.colors.Normalize(vmin, vmax),
40 |           ("RdBu_r" if center is not None else "viridis"))
41 | 
42 | def plot_data(
43 |     data: dict[str, tuple[xarray.DataArray, matplotlib.colors.Normalize, str]],
44 |     fig_title: str,
45 |     plot_size: float = 5,
46 |     robust: bool = False,
47 |     cols: int = 4
48 |     ) -> HTML:
49 | 
50 |   first_data = next(iter(data.values()))[0]
51 |   max_steps = first_data.sizes.get("time", 1)
52 |   assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())
53 | 
54 |   cols = min(cols, len(data))
55 |   rows = math.ceil(len(data) / cols)
56 |   figure = plt.figure(figsize=(plot_size * 2 * cols,
57 |                                plot_size * rows))
58 |   figure.suptitle(fig_title, fontsize=16)
59 |   figure.subplots_adjust(wspace=0, hspace=0)
60 |   figure.tight_layout()
61 | 
62 |   images = []
63 |   for i, (title, (plot_data, norm, cmap)) in enumerate(data.items()):
64 |     ax = figure.add_subplot(rows, cols, i+1)
65 |     ax.set_xticks([])
66 |     ax.set_yticks([])
67 |     ax.set_title(title)
68 |     im = ax.imshow(
69 |         plot_data.isel(time=0, missing_dims="ignore"), norm=norm,
70 |         origin="lower", cmap=cmap)
71 |     plt.colorbar(
72 |         mappable=im,
73 |         ax=ax,
74 |         orientation="vertical",
75 |         pad=0.02,
76 |         aspect=16,
77 |         shrink=0.75,
78 |         cmap=cmap,
79 |         extend=("both" if robust else "neither"))
80 |     images.append(im)
81 | 
82 |   def update(frame):
83 |     if "time" in first_data.dims:
84 |       td = datetime.timedelta(microseconds=first_data["time"][frame].item() / 1000)
85 |       figure.suptitle(f"{fig_title}, {td}", fontsize=16)
86 |     else:
87 |       figure.suptitle(fig_title, fontsize=16)
88 |     for im, (plot_data, norm, cmap) in zip(images, data.values()):
89 |       im.set_data(plot_data.isel(time=frame, missing_dims="ignore"))
90 | 
91 |   ani = animation.FuncAnimation(
92 |       fig=figure, func=update, frames=max_steps, interval=250)
93 |   plt.close(figure.number)
94 |   return HTML(ani.to_jshtml())
95 | 
96 | 
97 | def plot_predictions(predictions, eval_targets):
98 |   # @title Choose predictions to plot
99 | 
100 |   plot_pred_variable = widgets.Dropdown(
101 |       options=predictions.data_vars.keys(),
102 |       value="2m_temperature",
103 |       description="Variable")
104 |   plot_pred_level = widgets.Dropdown(
105 |       options=predictions.coords["level"].values,
106 |       value=500,
107 |       description="Level")
108 |   plot_pred_robust = widgets.Checkbox(value=True, description="Robust")
109 |   plot_pred_max_steps = widgets.IntSlider(
110 |       min=1,
111 |       max=predictions.dims["time"],
112 |       value=predictions.dims["time"],
113 |       description="Max steps")
114 | 
115 |   widgets.VBox([
116 |       plot_pred_variable,
117 |       plot_pred_level,
118 |       plot_pred_robust,
119 |       plot_pred_max_steps,
120 |       widgets.Label(value="Run the next cell to plot the predictions. Rerunning this cell clears your selection.")
121 |   ])
122 | 
123 |   # @title Plot predictions
124 | 
125 |   plot_size = 5
126 |   plot_max_steps = min(predictions.dims["time"], plot_pred_max_steps.value)
127 | 
128 |   data = {
129 |       "Targets": scale(select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps),
130 |                        robust=plot_pred_robust.value),
131 |       "Predictions": scale(select(predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps),
132 |                            robust=plot_pred_robust.value),
133 |       "Diff": scale((select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps) -
134 |                      select(predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),
135 |                     robust=plot_pred_robust.value, center=0),
136 |   }
137 |   fig_title = plot_pred_variable.value
138 |   if "level" in predictions[plot_pred_variable.value].coords:
139 |     fig_title += f" at {plot_pred_level.value} hPa"
140 | 
141 |   plot_data(data, fig_title, plot_size, plot_pred_robust.value)
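142 | 
143 | # Illustrative usage sketch (notebook context; `predictions` and `eval_targets`
144 | # are assumed to be xarray Datasets from a rollout, sharing a "time" dimension):
145 | #
146 | #   plot_predictions(predictions, eval_targets)
147 | #
148 | # or, for a single panel without widgets:
149 | #
150 | #   data = {"Targets": scale(select(eval_targets, "2m_temperature"), robust=True)}
151 | #   plot_data(data, "2m_temperature", plot_size=5, robust=True)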
--------------------------------------------------------------------------------
/GraphCast_src/pretrained_graphcast.py:
--------------------------------------------------------------------------------
1 | from google.cloud.storage import Bucket
2 | 
3 | from graphcast import checkpoint
4 | from graphcast import graphcast
5 | import buckets
6 | from pathlib import Path
7 | from typing import Sequence, Union
8 | 
9 | 
10 | # # Load the Data and initialize the model
11 | 
12 | # ## Load the model params
13 | #
14 | # Choose one of the two ways of getting model params:
15 | # - **random**: You'll get random predictions, but you can change the model architecture, which may run faster or fit on your device.
16 | # - **checkpoint**: You'll get sensible predictions, but are limited to the model architecture that it was trained with, which may not fit on your device. In particular, generating gradients uses a lot of memory, so you'll need at least 25 GB of RAM (TPUv4 or A100).
17 | #
18 | # Checkpoints vary across a few axes:
19 | # - The mesh size specifies the internal graph representation of the earth. Smaller meshes will run faster but will have worse outputs. The mesh size does not affect the number of parameters of the model.
20 | # - The resolution and number of pressure levels must match the data. Lower resolution and fewer levels will run a bit faster. Data resolution only affects the encoder/decoder.
21 | # - All our models predict precipitation. However, ERA5 includes precipitation, while HRES does not. Models marked "ERA5" take precipitation as input and expect ERA5 data as input, while models marked "ERA5-HRES" do not take precipitation as input and are specifically trained to take HRES-fc0 as input (see the data section below).
22 | #
23 | # We provide three pre-trained models:
24 | # 1. `GraphCast`, the high-resolution model used in the GraphCast paper (0.25 degree resolution, 37 pressure levels), trained on ERA5 data from 1979 to 2017,
25 | #
26 | # 2. `GraphCast_small`, a smaller, low-resolution version of GraphCast (1 degree resolution, 13 pressure levels, and a smaller mesh), trained on ERA5 data from 1979 to 2015, useful to run a model with lower memory and compute constraints,
27 | #
28 | # 3. `GraphCast_operational`, a high-resolution model (0.25 degree resolution, 13 pressure levels) pre-trained on ERA5 data from 1979 to 2017 and fine-tuned on HRES data from 2016 to 2021. This model can be initialized from HRES data (does not require precipitation inputs).
29 | #
30 | 
31 | 
32 | def load_model_checkpoint(gcs_bucket: Bucket, name: str) -> graphcast.CheckPoint:
33 |     with gcs_bucket.blob(f"params/{name}").open("rb") as f:
34 |         ckpt = checkpoint.load(f, graphcast.CheckPoint)
35 | 
36 |     print("Model description:\n", ckpt.description, "\n")
37 |     print("Model license:\n", ckpt.license, "\n")
38 | 
39 |     return ckpt
40 | 
41 | def find_model_name(options: Sequence[str], resolution: float, pressure_level: int) -> Union[str, None]:
42 |     for name in options:
43 |         if f"resolution {resolution}" in name and f"levels {pressure_level}" in name:
44 |             return name
45 |     return None
46 | 
47 | def save_checkpoint(gcs_bucket: Bucket, directory: Path, name: str) -> None:
48 |     # Copy the checkpoint from the bucket to local disk.
49 |     with gcs_bucket.blob(f"params/{name}").open("rb") as param:
50 |         buckets.save_to_dir(param, directory / 'params', name)
51 | 
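52 | # Illustrative sketch of how find_model_name matches (the file name is taken
53 | # from the README's download instructions; nothing is fetched here):
54 | #
55 | #   options = ["GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - "
56 | #              "pressure levels 13 - mesh 2to6 - precipitation output only.npz"]
57 | #   find_model_name(options, resolution=0.25, pressure_level=13)
58 | #   # -> returns the full name above, since "resolution 0.25" and "levels 13"
59 | #   #    both occur as substrings; returns None when nothing matches.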
--------------------------------------------------------------------------------
/GraphCast_src/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #      http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Module setuptools script."""
15 | 
16 | from setuptools import setup
17 | 
18 | description = (
19 |     "GraphCast: Learning skillful medium-range global weather forecasting"
20 | )
21 | 
22 | setup(
23 |     name="graphcast",
24 |     version="0.1",
25 |     description=description,
26 |     long_description=description,
27 |     author="DeepMind",
28 |     license="Apache License, Version 2.0",
29 |     keywords="GraphCast Weather Prediction",
30 |     url="https://github.com/deepmind/graphcast",
31 |     packages=["graphcast"],
32 |     install_requires=[
33 |         "cartopy",
34 |         "chex",
35 |         "colabtools",
36 |         "dask",
37 |         "dm-haiku",
38 |         "dm-tree",
39 |         "jax",
40 |         "jraph",
41 |         "matplotlib",
42 |         "numpy",
43 |         "pandas",
44 |         "rtree",
45 |         "scipy",
46 |         #"tree",
47 |         "trimesh",
48 |         "typing_extensions",
49 |         "xarray",
50 |         "google-cloud-storage",
51 |         "zarr",
52 |         "gcsfs",
53 |         "jax-dataloader",
54 |         #"tensorflow",
55 |         #"tensorboard-plugin-profile",
56 |         "nvtx",
57 |         "wandb"
58 |     ],
59 |     classifiers=[
60 |         "Development Status :: 3 - Alpha",
61 |         "Intended Audience :: Science/Research",
62 |         "License :: OSI Approved :: Apache Software License",
63 |         "Operating System :: POSIX :: Linux",
64 |         "Programming Language :: Python :: 3",
65 |         "Topic :: Scientific/Engineering :: Artificial Intelligence",
66 |         "Topic :: Scientific/Engineering :: Atmospheric Science",
67 |         "Topic :: Scientific/Engineering :: Physics",
68 |     ],
69 | )
--------------------------------------------------------------------------------
/GraphCast_src/weatherbench2_dataloader.py:
--------------------------------------------------------------------------------
1 | import xarray as xr
2 | 
3 | class WeatherBench2Dataset:
4 |     def __init__(self, year: int, steps: int, steps_per_input: int = 3):
5 |         self.vars = ["geopotential", "specific_humidity", "temperature",
6 |                      "u_component_of_wind", "v_component_of_wind",
7 |                      "vertical_velocity", "toa_incident_solar_radiation",
8 |                      "10m_u_component_of_wind", "10m_v_component_of_wind",
9 |                      "2m_temperature", "mean_sea_level_pressure",
10 |                      "total_precipitation_6hr", "geopotential_at_surface", "land_sea_mask"]
11 |         self.static_vars = ["geopotential_at_surface", "land_sea_mask"]
12 |         self.ds = xr.open_zarr('gs://weatherbench2/datasets/era5/1959-2022-6h-1440x721.zarr')
13 |         self.ds = self.ds[self.vars]
14 |         self.ds = self.ds.sel(time=slice(f'{year}-01-01', f'{year}-12-31'))
15 |         self.length = len(self.ds.time)
16 |         self.steps = steps
17 |         self.steps_per_input = steps_per_input
18 |         self.coord_name_dict = dict(latitude="lat", longitude="lon")
19 | 
20 |     def __len__(self):
21 |         # Each batch needs steps_per_input - 1 extra trailing frames as context;
22 |         # drop the last batch if they are missing.
23 |         num_batches = self.length // self.steps
24 |         if self.length % self.steps < self.steps_per_input - 1:
25 |             num_batches -= 1
26 |         return num_batches
27 | 
28 |     def __getitem__(self, item):
29 |         return self.get_data(item)
30 | 
31 |     def get_data(self, batch_idx):
32 |         batch_idx = batch_idx % len(self)
33 |         it_range = slice(batch_idx * self.steps, (batch_idx + 1) * self.steps + self.steps_per_input - 1)
34 |         static_data = self.ds[self.static_vars].rename(**self.coord_name_dict)
35 |         data = self.ds.drop_vars(self.static_vars).isel(time=it_range)
36 |         data = data.rename(**self.coord_name_dict)
37 |         # Reverse the latitude axis (WeatherBench2 stores it descending).
38 |         data = data.isel(lat=slice(None, None, -1))
39 |         data = xr.merge([static_data, data.expand_dims({'batch': 1})])
40 |         # Keep absolute datetimes as a coordinate and make `time` relative to the batch start.
41 |         data = data.assign_coords(datetime=(["batch", "time"], data.time.data.reshape(1, -1)))
42 |         data = data.assign_coords(time=("time", data.time.data - data.time.data[0]))
43 |         return data.compute()
44 | 
45 | if __name__ == "__main__":
46 |     dataset = WeatherBench2Dataset(2016, steps=4, steps_per_input=3)
47 |     print(len(dataset))
48 |     data = dataset.get_data(-1)
49 |     print(data)
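50 | 
51 | # Worked example of the batch arithmetic (illustrative): for 2016 the 6-hourly
52 | # slice has 1464 timestamps, so steps=4 gives 1464 // 4 = 366 candidate batches;
53 | # the last batch would need steps_per_input - 1 = 2 extra trailing frames, but
54 | # 1464 % 4 = 0 < 2, so __len__ drops it and returns 365.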
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Setup environment
2 | Follow `Dockerfile` to set up a container environment for NVIDIA GPUs.
3 | 
4 | # Prepare forecast data from GraphCast
5 | Set up a Google Cloud Storage key named `gcs_key.json` and put it in `GraphCast_src/`.
6 | 
7 | Execute `python3 GraphCast_src/graphcast_runner.py --resolution .25 --pressure_levels 13 --autoregressive_steps 8 --test_year_start 1977 --test_year_end 2016` and `python3 GraphCast_src/graphcast_runner.py --resolution .25 --pressure_levels 13 --autoregressive_steps 1 --test_year_start 1977 --test_year_end 2016` to generate the 48h and 6h GraphCast forecast data needed for training the diffusion model.
8 | 
9 | # Prepare GraphCast weights and parameters
10 | Visit https://console.cloud.google.com/storage/browser/dm_graphcast and download `GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz` from the `params` folder, along with the whole `stats` folder.
11 | 
12 | # Train diffusion model
13 | Execute `python3 DiffDA/train_conditional_graphcast.py` and pass in the required arguments (hyperparameters, ERA5 path, forecast data path, etc.).
14 | 
15 | # Run data assimilation
16 | - Execute `python3 DiffDA/inference_data_assimilation.py --num_autoregressive_steps=1 ...` to run single-step data assimilation
17 | - Execute `python3 DiffDA/inference_data_assimilation.py --num_autoregressive_steps=n ...` (n ≥ 2) to run autoregressive data assimilation
18 | - Execute `python3 DiffDA/inference_data_assimilation_gc.py --num_autoregressive_steps=n ...` to run an (autoregressive) GraphCast forecast on single-step assimilated data
19 | 
20 | # Implementation details
21 | - `GraphCast_src/graphcast/normalization.py`: DDPM and RePaint algorithms for inference
22 | - `DiffDA/train_conditional_graphcast.py`: training the diffusion model with GraphCast as backbone
23 | - `DiffDA/inference_data_assimilation.py`: run single-step & autoregressive data assimilation
24 | - `DiffDA/inference_data_assimilation_gc.py`: run GraphCast forecast on single-step assimilated data
--------------------------------------------------------------------------------
/figs/ (binary PDF figures, contents omitted):
--------------------------------------------------------------------------------
Every figure listed in the directory tree (ablation_experiment1/, ablation_experiment3/, error_visualization/) is available at
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/<subdirectory>/<file>.pdf
--------------------------------------------------------------------------------