├── DiffDA
│   ├── checkpoint.py
│   ├── dataloader.py
│   ├── diffusion_common.py
│   ├── inference_data_assimilation.py
│   ├── inference_data_assimilation_gc.py
│   ├── obs_coords.csv
│   ├── train_conditional_graphcast.py
│   └── train_state.py
├── Dockerfile
├── GraphCast_src
│   ├── .DS_Store
│   ├── CONTRIBUTING.md
│   ├── LICENSE
│   ├── README.md
│   ├── __init__.py
│   ├── buckets.py
│   ├── graphcast
│   │   ├── __init__.py
│   │   ├── autoregressive.py
│   │   ├── casting.py
│   │   ├── checkpoint.py
│   │   ├── checkpoint_test.py
│   │   ├── data_utils.py
│   │   ├── data_utils_test.py
│   │   ├── deep_typed_graph_net.py
│   │   ├── graphcast.py
│   │   ├── grid_mesh_connectivity.py
│   │   ├── grid_mesh_connectivity_test.py
│   │   ├── icosahedral_mesh.py
│   │   ├── icosahedral_mesh_test.py
│   │   ├── losses.py
│   │   ├── model_utils.py
│   │   ├── normalization.py
│   │   ├── predictor_base.py
│   │   ├── rollout.py
│   │   ├── typed_graph.py
│   │   ├── typed_graph_net.py
│   │   ├── xarray_jax.py
│   │   ├── xarray_jax_test.py
│   │   ├── xarray_tree.py
│   │   └── xarray_tree_test.py
│   ├── graphcast_runner.py
│   ├── graphcast_wrapper.py
│   ├── plotting.py
│   ├── pretrained_graphcast.py
│   ├── setup.py
│   └── weatherbench2_dataloader.py
├── README.md
└── figs
    ├── ablation_experiment1
    │   ├── score_board_0_blur0.5.pdf
    │   ├── score_board_0_blur0.5_geopotential.pdf
    │   ├── score_board_0_blur0.5_sfc.pdf
    │   ├── score_board_0_blur0.5_specific_humidity.pdf
    │   ├── score_board_0_blur0.5_temperature.pdf
    │   ├── score_board_0_blur0.5_u_component_of_wind.pdf
    │   ├── score_board_0_blur0.5_v_component_of_wind.pdf
    │   ├── score_board_0_blur0.5_vertical_velocity.pdf
    │   ├── score_board_0_blur1.0.pdf
    │   ├── score_board_0_blur1.0_geopotential.pdf
    │   ├── score_board_0_blur1.0_sfc.pdf
    │   ├── score_board_0_blur1.0_specific_humidity.pdf
    │   ├── score_board_0_blur1.0_temperature.pdf
    │   ├── score_board_0_blur1.0_u_component_of_wind.pdf
    │   ├── score_board_0_blur1.0_v_component_of_wind.pdf
    │   ├── score_board_0_blur1.0_vertical_velocity.pdf
    │   ├── score_board_0_blur1.5.pdf
    │   ├── score_board_0_blur1.5_geopotential.pdf
    │   ├── score_board_0_blur1.5_sfc.pdf
    │   ├── score_board_0_blur1.5_specific_humidity.pdf
    │   ├── score_board_0_blur1.5_temperature.pdf
    │   ├── score_board_0_blur1.5_u_component_of_wind.pdf
    │   ├── score_board_0_blur1.5_v_component_of_wind.pdf
    │   ├── score_board_0_blur1.5_vertical_velocity.pdf
    │   ├── score_board_0_blur2.0.pdf
    │   ├── score_board_0_blur2.0_geopotential.pdf
    │   ├── score_board_0_blur2.0_sfc.pdf
    │   ├── score_board_0_blur2.0_specific_humidity.pdf
    │   ├── score_board_0_blur2.0_temperature.pdf
    │   ├── score_board_0_blur2.0_u_component_of_wind.pdf
    │   ├── score_board_0_blur2.0_v_component_of_wind.pdf
    │   ├── score_board_0_blur2.0_vertical_velocity.pdf
    │   ├── score_board_0_blur2.5.pdf
    │   ├── score_board_0_blur2.5_geopotential.pdf
    │   ├── score_board_0_blur2.5_sfc.pdf
    │   ├── score_board_0_blur2.5_specific_humidity.pdf
    │   ├── score_board_0_blur2.5_temperature.pdf
    │   ├── score_board_0_blur2.5_u_component_of_wind.pdf
    │   ├── score_board_0_blur2.5_v_component_of_wind.pdf
    │   └── score_board_0_blur2.5_vertical_velocity.pdf
    ├── ablation_experiment3
    │   ├── score_board_48_blur0.5.pdf
    │   ├── score_board_48_blur0.5_geopotential.pdf
    │   ├── score_board_48_blur0.5_sfc.pdf
    │   ├── score_board_48_blur0.5_specific_humidity.pdf
    │   ├── score_board_48_blur0.5_temperature.pdf
    │   ├── score_board_48_blur0.5_u_component_of_wind.pdf
    │   ├── score_board_48_blur0.5_v_component_of_wind.pdf
    │   ├── score_board_48_blur0.5_vertical_velocity.pdf
    │   ├── score_board_48_blur1.0.pdf
    │   ├── score_board_48_blur1.0_geopotential.pdf
    │   ├── score_board_48_blur1.0_sfc.pdf
    │   ├── score_board_48_blur1.0_specific_humidity.pdf
    │   ├── score_board_48_blur1.0_temperature.pdf
    │   ├── score_board_48_blur1.0_u_component_of_wind.pdf
    │   ├── score_board_48_blur1.0_v_component_of_wind.pdf
    │   ├── score_board_48_blur1.0_vertical_velocity.pdf
    │   ├── score_board_48_blur1.5.pdf
    │   ├── score_board_48_blur1.5_geopotential.pdf
    │   ├── score_board_48_blur1.5_sfc.pdf
    │   ├── score_board_48_blur1.5_specific_humidity.pdf
    │   ├── score_board_48_blur1.5_temperature.pdf
    │   ├── score_board_48_blur1.5_u_component_of_wind.pdf
    │   ├── score_board_48_blur1.5_v_component_of_wind.pdf
    │   ├── score_board_48_blur1.5_vertical_velocity.pdf
    │   ├── score_board_48_blur2.0.pdf
    │   ├── score_board_48_blur2.0_geopotential.pdf
    │   ├── score_board_48_blur2.0_sfc.pdf
    │   ├── score_board_48_blur2.0_specific_humidity.pdf
    │   ├── score_board_48_blur2.0_temperature.pdf
    │   ├── score_board_48_blur2.0_u_component_of_wind.pdf
    │   ├── score_board_48_blur2.0_v_component_of_wind.pdf
    │   ├── score_board_48_blur2.0_vertical_velocity.pdf
    │   ├── score_board_48_blur2.5.pdf
    │   ├── score_board_48_blur2.5_geopotential.pdf
    │   ├── score_board_48_blur2.5_sfc.pdf
    │   ├── score_board_48_blur2.5_specific_humidity.pdf
    │   ├── score_board_48_blur2.5_temperature.pdf
    │   ├── score_board_48_blur2.5_u_component_of_wind.pdf
    │   ├── score_board_48_blur2.5_v_component_of_wind.pdf
    │   └── score_board_48_blur2.5_vertical_velocity.pdf
    └── error_visualization
        ├── 2m_temperature_1000.pdf
        ├── 2m_temperature_10000.pdf
        ├── 2m_temperature_2000.pdf
        ├── 2m_temperature_20000.pdf
        ├── 2m_temperature_4000.pdf
        ├── 2m_temperature_40000.pdf
        ├── 2m_temperature_8000.pdf
        ├── geopotential_500_1000.pdf
        ├── geopotential_500_10000.pdf
        ├── geopotential_500_2000.pdf
        ├── geopotential_500_20000.pdf
        ├── geopotential_500_4000.pdf
        ├── geopotential_500_40000.pdf
        ├── geopotential_500_8000.pdf
        ├── temperature_850_1000.pdf
        ├── temperature_850_10000.pdf
        ├── temperature_850_2000.pdf
        ├── temperature_850_20000.pdf
        ├── temperature_850_4000.pdf
        ├── temperature_850_40000.pdf
        └── temperature_850_8000.pdf
/DiffDA/checkpoint.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import pickle
3 | from typing import Union
4 |
5 | import haiku as hk
6 | import optax
7 | from graphcast import graphcast, checkpoint
8 | import jax
9 | from pathlib import Path
10 | import os
11 | import re
12 | import diffusers
13 |
14 | @dataclasses.dataclass(frozen=True)
15 | class TrainingCheckpoint:
16 | params: hk.Params
17 | opt_state: optax.OptState
18 | scheduler_state: diffusers.schedulers.scheduling_ddpm_flax.DDPMSchedulerState
19 | task_config: graphcast.TaskConfig
20 | model_config: graphcast.ModelConfig
21 | epoch: int
22 | rng: jax.Array
23 | num_train_timesteps: int = 1000
 24 |     # Number of DDPM diffusion timesteps used by the noise scheduler.
25 |
26 | def save_checkpoint(directory: Path, ckpt: TrainingCheckpoint) -> None:
27 | """
28 | Stores a checkpoint at the given directory with the name
29 | directory / diff_gc_{epoch}.npz
30 | If the given directory does not exist, it will be created.
31 | If there already exists a checkpoint in the same directory with the same epoch,
32 | it will be overwritten.
33 | """
34 | directory.mkdir(parents=True, exist_ok=True)
35 | save_file = directory / f"diff_gc_{ckpt.epoch}.npz"
36 | with open(save_file, mode="wb") as file:
37 | pickle.dump(ckpt, file)
38 | #checkpoint.dump(file, ckpt)
39 |
40 | def load_checkpoint(directory: Path, epoch: int = -1) -> Union[TrainingCheckpoint, None]:
41 | """
42 | Loads a checkpoint from the given directory. If a non-negative epoch is given,
43 | that checkpoint is loaded. Otherwise, the latest epoch is loaded.
44 | If no checkpoint is found, None is returned.
45 | """
46 | if not directory.exists():
47 | return None
48 | # Define the pattern using a regular expression
49 | pattern = r"diff_gc_(\d+)"
50 | # Initialize an empty list to store the tuples (i, path)
51 | file_tuples = []
52 | # Iterate through all files in the directory
53 | for filename in os.listdir(directory):
54 | # Check if the filename matches the pattern
55 | match = re.match(pattern, filename)
56 | if match:
57 | # Extract the epoch i from the match
58 | i = int(match.group(1))
59 | # Construct the full path to the file
60 | filepath = os.path.join(directory, filename)
61 | # Append the tuple (i, path) to the list
62 | file_tuples.append((i, filepath))
63 |
64 | if len(file_tuples) == 0:
65 | return None
66 |
67 | # Sort the list of tuples based on the integer i
68 | file_tuples.sort()
69 | ckpt_path = None
70 | if epoch < 0:
71 | ckpt_path = file_tuples[-1][1]
72 | else:
73 | for (i, f) in file_tuples:
74 | if i == epoch:
75 | ckpt_path = f
76 | break
 77 |     if ckpt_path is None:
 78 |         return None
 79 |     with open(ckpt_path, "rb") as file:
 80 |         return pickle.load(file)
--------------------------------------------------------------------------------
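A minimal resume-or-start sketch built on the two helpers above. `make_initial_checkpoint` and `num_epochs` are hypothetical stand-ins for however the initial `TrainingCheckpoint` and the training length are configured:

```python
import dataclasses
from pathlib import Path

ckpt_dir = Path("checkpoints/diffda")

# Resume from the newest diff_gc_{epoch}.npz, or start from scratch.
ckpt = load_checkpoint(ckpt_dir)
if ckpt is None:
    ckpt = make_initial_checkpoint()  # hypothetical factory for the initial state

for epoch in range(ckpt.epoch + 1, num_epochs):  # num_epochs: assumed config value
    # ... run one epoch of training, producing updated params/opt_state/rng ...
    ckpt = dataclasses.replace(ckpt, epoch=epoch)  # frozen dataclass -> replace
    save_checkpoint(ckpt_dir, ckpt)
```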
/DiffDA/diffusion_common.py:
--------------------------------------------------------------------------------
1 | import xarray
2 | import numpy as np
3 | import jax
4 | import jax.numpy as jnp
5 | import haiku as hk
6 | from tqdm import tqdm
7 | import wandb
8 |
9 | from graphcast import graphcast, normalization, casting, xarray_jax, xarray_tree
10 | import graphcast.autoregressive as autoregressive
11 | from graphcast.data_utils import get_day_progress, get_year_progress, featurize_progress
12 |
 13 | def get_forcing(time: xarray.DataArray, lon: xarray.DataArray, timesteps: np.ndarray, num_timesteps: int, batch_size: int = 0, forcing_type: str = "diffusion") -> xarray.Dataset:
14 |
15 | DAY_PROGRESS = "day_progress"
16 | YEAR_PROGRESS = "year_progress"
17 |
18 | # Compute seconds since epoch.
19 | # Note `data.coords["datetime"].astype("datetime64[s]").astype(np.int64)`
20 | # does not work as xarrays always cast dates into nanoseconds!
21 | batch_dim = ("batch",)
22 | seconds_since_epoch = time.data.astype("datetime64[s]").astype(np.int64)
23 | seconds_since_epoch = seconds_since_epoch.reshape((batch_size, 1))
24 |
25 | # Add year progress features.
26 | year_progress = get_year_progress(seconds_since_epoch)
27 | forcing_dict = {}
28 | forcing_dict.update(featurize_progress(
29 | name=YEAR_PROGRESS, dims=batch_dim + ("time",), progress=year_progress))
30 | # Add day progress features.
31 | day_progress = get_day_progress(seconds_since_epoch, lon.data)
32 | forcing_dict.update(featurize_progress(
33 | name=DAY_PROGRESS,
34 | dims=batch_dim + ("time",) + lon.dims,
35 | progress=day_progress))
36 |
 37 |     # Hijack year_progress_sin to carry the scaled diffusion timestep.
38 | if forcing_type == "diffusion":
39 | forcing_dict["year_progress_sin"].data = (timesteps / num_timesteps * 2 - 1).astype(np.float32).reshape(forcing_dict["year_progress_sin"].shape)
40 |
41 | ds_forcing = xarray.Dataset(forcing_dict).drop_vars(["day_progress", "year_progress"])
42 |
43 | return ds_forcing
44 |
45 |
46 | def wrap_graphcast(model_config: graphcast.ModelConfig,
47 | task_config: graphcast.TaskConfig,
48 | stddev_by_level: xarray.Dataset,
49 | mean_by_level: xarray.Dataset,
50 | diffs_stddev_by_level: xarray.Dataset,):
51 | """
52 | Constructs and wraps the GraphCast Predictor.
53 | Note that this MUST be called within a haiku transform function.
54 | """
55 | # Deeper one-step predictor.
56 | predictor = graphcast.GraphCast(model_config, task_config)
57 |
 58 |     # Modify inputs/outputs to `graphcast.GraphCast` to handle the conversion
 59 |     # between float32 and BFloat16.
60 | predictor = casting.Bfloat16Cast(predictor)
61 |
62 | # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
63 | # BFloat16 happens after applying normalization to the inputs/targets.
64 | predictor = normalization.InputsAndResidualsForDiffusion(predictor,
65 | stddev_by_level=stddev_by_level,
66 | mean_by_level=mean_by_level,
67 | diffs_stddev_by_level=diffs_stddev_by_level)
68 | return predictor
69 |
70 | def wrap_graphcast_prediction(model_config: graphcast.ModelConfig,
71 | task_config: graphcast.TaskConfig,
72 | diffs_stddev_by_level = None,
73 | mean_by_level = None,
74 | stddev_by_level = None,
75 | wrap_autoregressive: bool = False,
76 | normalize: bool = False):
77 | """
78 | Constructs and wraps the GraphCast Predictor.
79 | Note that this MUST be called within a haiku transform function.
80 | """
81 | # Deeper one-step predictor.
82 | predictor = graphcast.GraphCast(model_config, task_config)
83 |
 84 |     # Modify inputs/outputs to `graphcast.GraphCast` to handle the conversion
 85 |     # between float32 and BFloat16.
86 | predictor = casting.Bfloat16Cast(predictor)
87 |
88 | if normalize:
89 | assert diffs_stddev_by_level is not None
90 | assert mean_by_level is not None
91 | assert stddev_by_level is not None
92 | # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
93 | # BFloat16 happens after applying normalization to the inputs/targets.
94 | predictor = normalization.InputsAndResiduals(
95 | predictor,
96 | diffs_stddev_by_level=diffs_stddev_by_level,
97 | mean_by_level=mean_by_level,
98 | stddev_by_level=stddev_by_level)
99 |
100 | if wrap_autoregressive:
101 | # Wraps everything so the one-step model can produce trajectories.
102 | predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
103 |
104 | return predictor
105 |
106 | def _to_numpy_xarray(array) -> xarray.Dataset:
107 | # Unwrap
108 | vars = xarray_jax.unwrap_vars(array)
109 | coords = xarray_jax.unwrap_coords(array)
110 | # Ensure it's numpy
111 |     vars = {n: (array[n].dims, np.asarray(v)) if len(array[n].dims) > 0 else np.asarray(v) for n, v in vars.items()}
112 |     coords = {n: (array[n].dims, np.asarray(v)) if len(array[n].dims) > 0 else np.asarray(v) for n, v in coords.items()}
113 | # Create new dataset
114 | copied_dataset = xarray.Dataset(vars, coords)
115 | return copied_dataset
116 |
117 | def _to_jax_xarray(dataset: xarray.Dataset, device) -> xarray_jax.Dataset:
118 | # Unwrap
119 | vars = dataset.variables
120 | coords = dataset.coords
121 | # Ensure it's numpy
122 | vars = {n: (dataset[n].dims, jax.device_put(jnp.array(v.data), device)) if len(dataset[n].dims) > 0 else jax.device_put(jnp.array(v.data), device) for n,v in vars.items()}
123 | coords = {n: (dataset[n].dims, jax.device_put(jnp.array(v.data), device)) if len(dataset[n].dims) > 0 else jax.device_put(jnp.array(v.data), device) for n,v in coords.items()}
124 | # Create new dataset
125 | copied_dataset = xarray_jax.Dataset(vars, coords)
126 | return copied_dataset
127 |
128 | def validation_step(forward_fn: hk.TransformedWithState, norm_original_fn, norm_diff_fn, params, state, validate_dataset, args, device, rng, mode="ddpm"):
129 | pbar = enumerate(validate_dataset)
130 | progress_bar = False
131 | if args.rank == 0:
132 | pbar = tqdm(pbar, desc="Validation", total=len(validate_dataset))
133 | progress_bar = True
134 | for batch_idx, batch in pbar:
135 | rng_batch = jax.random.fold_in(rng, batch_idx)
136 | timesteps = np.ones((args.validation_batch_size,), dtype=np.int32)
137 | inputs_pred = batch['graphcast']
138 | inputs_ground_truth = batch['weatherbench']
139 | inputs_static = batch['static']
140 | datetime = inputs_ground_truth.datetime
141 | lon = inputs_ground_truth.lon
142 | norm_forcings = get_forcing(datetime, lon, timesteps, args.num_train_timesteps, batch_size=args.validation_batch_size)
143 | norm_forcings = _to_jax_xarray(norm_forcings, device)
144 | inputs_pred = _to_jax_xarray(inputs_pred.drop_vars("datetime"), device)
145 | inputs_ground_truth = _to_jax_xarray(inputs_ground_truth.drop_vars("datetime"), device)
146 | inputs_static = _to_jax_xarray(inputs_static, device)
147 | norm_soa = norm_original_fn(inputs_ground_truth["toa_incident_solar_radiation"])
148 | norm_forcings = xarray.merge([norm_forcings, norm_soa])
149 | norm_inputs_pred = norm_original_fn(inputs_pred)
150 | if 'total_precipitation_6hr' in norm_inputs_pred.data_vars and args.resolution == "0.25deg":
151 | norm_inputs_pred = norm_inputs_pred.drop_vars('total_precipitation_6hr')
152 | norm_static = norm_original_fn(inputs_static)
153 | if mode == "ddpm":
154 | corrected_pred, _ = forward_fn.apply(params, state, None, norm_inputs_pred, norm_forcings, norm_static, rng_batch, progress_bar=progress_bar)
155 | elif mode == "repaint" or mode == "dps":
156 | mask = batch["mask"]
157 | measurements_interp = batch["weatherbench_interp"].drop_vars(["datetime", "toa_incident_solar_radiation"])
158 | mask = _to_jax_xarray(mask, device)["mask"] # convert from dataset to dataarray
159 | measurements_interp = _to_jax_xarray(measurements_interp, device)
160 | measurements_diff_interp = measurements_interp - inputs_pred
161 | norm_measurements_diff_interp = norm_diff_fn(measurements_diff_interp)
162 |             corrected_pred, _ = forward_fn.apply(params, state, None, mask=mask,
163 |                                                  norm_measurements_diff_interp=norm_measurements_diff_interp,
164 |                                                  norm_inputs_pred=norm_inputs_pred,
165 |                                                  norm_forcings=norm_forcings,
166 |                                                  norm_static=norm_static,
167 |                                                  rng_batch=rng_batch,
168 |                                                  progress_bar=True)
169 | else:
170 | raise ValueError(f"Unknown mode {mode}")
171 |
172 | diff = corrected_pred - inputs_ground_truth
173 | diff_gc = inputs_pred - inputs_ground_truth
174 |
175 | diff_500hPa = diff.sel(level=500)
176 | diff_gc_500hPa = diff_gc.sel(level=500)
177 |
178 | val_loss = {f"val/{mode}/diffusion_rmse_500hPa/{k}": jnp.sqrt(xarray_jax.unwrap_data((v*v).mean())).item() for k,v in diff_500hPa.data_vars.items()}
179 | val_loss_gc = {f"val/{mode}/graphcast_rmse500hPa/{k}": jnp.sqrt(xarray_jax.unwrap_data((v*v).mean())).item() for k,v in diff_gc_500hPa.data_vars.items()}
180 | #val_loss["rank"] = args.rank
181 |
182 | if args.rank == 0:
183 | pbar.set_postfix(val_loss_z500=val_loss[f"val/{mode}/diffusion_rmse_500hPa/geopotential"], loss_gc_z500=val_loss_gc[f"val/{mode}/graphcast_rmse500hPa/geopotential"])
184 | if args.use_wandb:
185 | log_dict = {**val_loss, **val_loss_gc}
186 | wandb.log(log_dict)
--------------------------------------------------------------------------------
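A short sketch of how `get_forcing` is meant to be driven (stand-in coordinates; in training, `time` and `lon` come from the batch as in `validation_step`). The point is the timestep conditioning: with `forcing_type="diffusion"`, the scaled diffusion step is written into the `year_progress_sin` channel, so the conditional GraphCast receives it as an ordinary forcing variable:

```python
import numpy as np
import pandas as pd
import xarray

batch_size, num_train_timesteps = 2, 1000

# Stand-ins for the batch coordinates used in validation_step.
time = xarray.DataArray(pd.to_datetime(["2020-01-01T00", "2020-01-01T06"]),
                        dims=("batch",))
lon = xarray.DataArray(np.arange(0.0, 360.0, 1.0), dims=("lon",))

timesteps = np.array([17, 998], dtype=np.int32)  # one diffusion step per sample

forcings = get_forcing(time, lon, timesteps, num_train_timesteps,
                       batch_size=batch_size, forcing_type="diffusion")
# forcings.year_progress_sin now holds (t / T) * 2 - 1 in place of the seasonal
# signal; the day-progress channels keep their usual meaning.
```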
/DiffDA/train_state.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The VDM Authors, Flax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Adapted from
17 | https://github.com/google-research/vdm/blob/main/train_state.py
18 | https://flax.readthedocs.io/en/latest/_modules/flax/training/train_state.html#TrainState.
19 |
20 | But with added EMA of the parameters.
21 | """
22 |
23 | import copy
24 | from typing import Any, Callable, Optional
25 |
26 | from flax import core
27 | from flax import struct
28 | import jax
29 | import optax
30 |
31 | from diffusers.schedulers.scheduling_ddpm_flax import DDPMSchedulerState
32 |
33 | class TrainState(struct.PyTreeNode):
34 | """Simple train state for the common case with a single Optax optimizer.
35 |
36 | Synopsis:
37 |
38 | state = TrainState.create(
39 | apply_fn=model.apply,
40 | params=variables['params'],
41 | tx=tx)
42 | grad_fn = jax.grad(make_loss_fn(state.apply_fn))
43 | for batch in data:
44 | grads = grad_fn(state.params, batch)
45 | state = state.apply_gradients(grads=grads)
46 |
47 | Note that you can easily extend this dataclass by subclassing it for storing
48 | additional data (e.g. additional variable collections).
49 |
50 | For more exotic usecases (e.g. multiple optimizers) it's probably best to
51 | fork the class and modify it.
52 |
53 | Attributes:
54 | step: Counter starts at 0 and is incremented by every call to
55 | `.apply_gradients()`.
56 | apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
57 | convenience to have a shorter params list for the `train_step()` function
58 | in your training loop.
59 | tx: An Optax gradient transformation.
60 | opt_state: The state for `tx`.
61 | """
62 | step: int
63 | params: core.FrozenDict[str, Any]
64 | opt_state: optax.OptState
65 | scheduler_state: DDPMSchedulerState
66 | tx_fn: Callable[[float], optax.GradientTransformation] = struct.field(
67 | pytree_node=False)
68 | apply_fn: Callable = struct.field(pytree_node=False)
69 |
70 | def apply_gradients(self, *, grads, lr, **kwargs):
71 | """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.
72 |
73 | Note that internally this function calls `.tx.update()` followed by a call
74 | to `optax.apply_updates()` to update `params` and `opt_state`.
75 |
76 | Args:
77 | grads: Gradients that have the same pytree structure as `.params`.
78 | **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
79 |
80 | Returns:
81 | An updated instance of `self` with `step` incremented by one, `params`
82 | and `opt_state` updated by applying `grads`, and additional attributes
83 | replaced as specified by `kwargs`.
84 | """
85 | tx = self.tx_fn(lr)
86 | updates, new_opt_state = tx.update(
87 | grads, self.opt_state, self.params)
88 | new_params = optax.apply_updates(self.params, updates)
89 |
90 | return self.replace(
91 | step=self.step + 1,
92 | params=new_params,
93 | opt_state=new_opt_state,
94 | **kwargs,
95 | )
96 |
97 | @classmethod
98 | def create(_class, *, apply_fn, params, optax_optimizer, scheduler_state, **kwargs):
99 | """Creates a new instance with `step=0` and initialized `opt_state`."""
100 | # _class is the TrainState class
101 | params = params
102 | opt_state = optax_optimizer(1.).init(params)
103 | return _class(
104 | step=0,
105 | apply_fn=apply_fn,
106 | params=params,
107 | tx_fn=optax_optimizer,
108 | opt_state=opt_state,
109 | scheduler_state=scheduler_state,
110 | **kwargs,
111 | )
--------------------------------------------------------------------------------
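A sketch of one update step with this `TrainState`. The assumed pieces (`loss_fn`, initialized `params`, a Haiku-transformed `forward_fn`, a `DDPMSchedulerState` from `diffusers`) are not defined here; the detail worth noting is that the learning rate is passed per step, which is why the state stores `tx_fn` rather than a fixed optimizer:

```python
import jax
import optax

tx_fn = lambda lr: optax.adamw(lr)  # assumption: AdamW; any optax optimizer works

state = TrainState.create(
    apply_fn=forward_fn.apply,        # assumption: a haiku-transformed model
    params=params,                    # assumption: initialized elsewhere
    optax_optimizer=tx_fn,
    scheduler_state=scheduler_state)  # assumption: DDPMSchedulerState

@jax.jit
def train_step(state, batch, lr):
    # loss_fn is assumed: params, batch -> scalar loss.
    loss, grads = jax.value_and_grad(loss_fn)(state.params, batch)
    return loss, state.apply_gradients(grads=grads, lr=lr)
```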
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/cuda:11.8.0-devel-ubuntu22.04
2 | ENV DEBIAN_FRONTEND=noninteractive
3 |
4 | RUN apt-get update && apt-get install -y python3 python3-pip python3-venv git curl wget libeccodes0 tensorrt libnvinfer-dev libnvinfer-plugin-dev
5 | RUN pip install --upgrade pip && pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
6 | RUN pip install cartopy chex colabtools dask dm-haiku jraph matplotlib numpy \
7 | pandas rtree scipy trimesh typing_extensions xarray \
8 | ipython ipykernel jupyterlab notebook ipywidgets google-cloud-storage dm-tree \
  9 |     cfgrib zarr gcsfs jax-dataloader tensorflow tensorboard-plugin-profile nvtx \
10 | wandb
11 |
 12 | RUN pip install --upgrade "diffusers[flax]"
13 |
14 |
15 | ENV TF_CPP_MIN_LOG_LEVEL=2
16 |
17 | COPY ./ /workspace/
18 | RUN pip install -e /workspace
19 |
 20 | WORKDIR /workspace
 21 |
 22 | CMD [ "/bin/bash" ]
--------------------------------------------------------------------------------
/GraphCast_src/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/GraphCast_src/.DS_Store
--------------------------------------------------------------------------------
/GraphCast_src/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | ## Contributor License Agreement
4 |
5 | Contributions to this project must be accompanied by a Contributor License
6 | Agreement. You (or your employer) retain the copyright to your contribution,
7 | this simply gives us permission to use and redistribute your contributions as
  8 | part of the project. Head over to <https://cla.developers.google.com/> to see
9 | your current agreements on file or to sign a new one.
10 |
11 | You generally only need to submit a CLA once, so if you've already submitted one
12 | (even if it was for a different project), you probably don't need to do it
13 | again.
14 |
15 | ## Code reviews
16 |
17 | All submissions, including submissions by project members, require review. We
18 | use GitHub pull requests for this purpose. Consult
19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
20 | information on using pull requests.
21 |
22 | ## Community Guidelines
23 |
24 | This project follows [Google's Open Source Community
25 | Guidelines](https://opensource.google/conduct/).
26 |
--------------------------------------------------------------------------------
/GraphCast_src/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/GraphCast_src/README.md:
--------------------------------------------------------------------------------
1 | # GraphCast: Learning skillful medium-range global weather forecasting
2 |
3 | This package contains example code to run and train [GraphCast](https://arxiv.org/abs/2212.12794).
4 | It also provides three pretrained models:
5 |
6 | 1. `GraphCast`, the high-resolution model used in the GraphCast paper (0.25 degree
7 | resolution, 37 pressure levels), trained on ERA5 data from 1979 to 2017,
8 |
9 | 2. `GraphCast_small`, a smaller, low-resolution version of GraphCast (1 degree
10 | resolution, 13 pressure levels, and a smaller mesh), trained on ERA5 data from
11 | 1979 to 2015, useful to run a model with lower memory and compute constraints,
12 |
13 | 3. `GraphCast_operational`, a high-resolution model (0.25 degree resolution, 13
14 | pressure levels) pre-trained on ERA5 data from 1979 to 2017 and fine-tuned on
15 | HRES data from 2016 to 2021. This model can be initialized from HRES data (does
16 | not require precipitation inputs).
17 |
18 | The model weights, normalization statistics, and example inputs are available on [Google Cloud Bucket](https://console.cloud.google.com/storage/browser/dm_graphcast).
19 |
20 | Full model training requires downloading the
21 | [ERA5](https://www.ecmwf.int/en/forecasts/datasets/reanalysis-datasets/era5)
22 | dataset, available from [ECMWF](https://www.ecmwf.int/).
23 |
24 | ## Overview of files
25 |
 26 | The best starting point is to open `graphcast_demo.ipynb` in [Colaboratory](https://colab.research.google.com/github/deepmind/graphcast/blob/master/graphcast_demo.ipynb), which gives an
 27 | example of loading data, generating random weights or loading a pre-trained
 28 | snapshot, generating predictions, and computing the loss and its gradients.
 29 | The one-step implementation of the GraphCast architecture is provided in
 30 | `graphcast.py`.
31 |
32 | ### Brief description of library files:
33 |
34 | * `autoregressive.py`: Wrapper used to run (and train) the one-step GraphCast
35 | to produce a sequence of predictions by auto-regressively feeding the
 36 |   outputs back as inputs at each step, in a differentiable way.
37 | * `casting.py`: Wrapper used around GraphCast to make it work using
38 | BFloat16 precision.
39 | * `checkpoint.py`: Utils to serialize and deserialize trees.
40 | * `data_utils.py`: Utils for data preprocessing.
41 | * `deep_typed_graph_net.py`: General purpose deep graph neural network (GNN)
42 | that operates on `TypedGraph`'s where both inputs and outputs are flat
43 | vectors of features for each of the nodes and edges. `graphcast.py` uses
44 | three of these for the Grid2Mesh GNN, the Multi-mesh GNN and the Mesh2Grid
45 | GNN, respectively.
46 | * `graphcast.py`: The main GraphCast model architecture for one-step of
47 | predictions.
48 | * `grid_mesh_connectivity.py`: Tools for converting between regular grids on a
49 | sphere and triangular meshes.
50 | * `icosahedral_mesh.py`: Definition of an icosahedral multi-mesh.
51 | * `losses.py`: Loss computations, including latitude-weighting.
52 | * `model_utils.py`: Utilities to produce flat node and edge vector features
53 | from input grid data, and to manipulate the node output vectors back
 54 |   into multilevel grid data.
55 | * `normalization.py`: Wrapper for the one-step GraphCast used to normalize
56 | inputs according to historical values, and targets according to historical
57 | time differences.
58 | * `predictor_base.py`: Defines the interface of the predictor, which GraphCast
59 | and all of the wrappers implement.
60 | * `rollout.py`: Similar to `autoregressive.py` but used only at inference time
61 | using a python loop to produce longer, but non-differentiable trajectories.
62 | * `typed_graph.py`: Definition of `TypedGraph`'s.
63 | * `typed_graph_net.py`: Implementation of simple graph neural network
64 | building blocks defined over `TypedGraph`'s that can be combined to build
65 | deeper models.
66 | * `xarray_jax.py`: A wrapper to let JAX work with `xarray`s.
67 | * `xarray_tree.py`: An implementation of tree.map_structure that works with
68 | `xarray`s.
69 |
70 |
71 | ### Dependencies.
72 |
73 | [Chex](https://github.com/deepmind/chex),
74 | [Dask](https://github.com/dask/dask),
75 | [Haiku](https://github.com/deepmind/dm-haiku),
76 | [JAX](https://github.com/google/jax),
77 | [JAXline](https://github.com/deepmind/jaxline),
78 | [Jraph](https://github.com/deepmind/jraph),
79 | [Numpy](https://numpy.org/),
80 | [Pandas](https://pandas.pydata.org/),
81 | [Python](https://www.python.org/),
82 | [SciPy](https://scipy.org/),
83 | [Tree](https://github.com/deepmind/tree),
84 | [Trimesh](https://github.com/mikedh/trimesh) and
85 | [XArray](https://github.com/pydata/xarray).
86 |
87 |
88 | ### License and attribution
89 |
90 | The Colab notebook and the associated code are licensed under the Apache
91 | License, Version 2.0. You may obtain a copy of the License at:
92 | https://www.apache.org/licenses/LICENSE-2.0.
93 |
94 | The model weights are made available for use under the terms of the Creative
95 | Commons Attribution-NonCommercial-ShareAlike 4.0 International
96 | (CC BY-NC-SA 4.0). You may obtain a copy of the License at:
97 | https://creativecommons.org/licenses/by-nc-sa/4.0/.
98 |
99 | The weights were trained on ECMWF's ERA5 and HRES data. The colab includes a few
100 | examples of ERA5 and HRES data that can be used as inputs to the models.
101 | ECMWF data products are subject to the following terms:
102 |
103 | 1. Copyright statement: Copyright "© 2023 European Centre for Medium-Range Weather Forecasts (ECMWF)".
104 | 2. Source www.ecmwf.int
105 | 3. Licence Statement: ECMWF data is published under a Creative Commons Attribution 4.0 International (CC BY 4.0). https://creativecommons.org/licenses/by/4.0/
106 | 4. Disclaimer: ECMWF does not accept any liability whatsoever for any error or omission in the data, their availability, or for any loss or damage arising from their use.
107 |
108 | ### Disclaimer
109 |
110 | This is not an officially supported Google product.
111 |
112 | Copyright 2023 DeepMind Technologies Limited.
113 |
114 | ### Citation
115 |
116 | If you use this work, consider citing our [paper](https://arxiv.org/abs/2212.12794):
117 |
118 | ```latex
119 | @article{lam2022graphcast,
120 | title={GraphCast: Learning skillful medium-range global weather forecasting},
121 | author={Remi Lam and Alvaro Sanchez-Gonzalez and Matthew Willson and Peter Wirnsberger and Meire Fortunato and Alexander Pritzel and Suman Ravuri and Timo Ewalds and Ferran Alet and Zach Eaton-Rosen and Weihua Hu and Alexander Merose and Stephan Hoyer and George Holland and Jacklynn Stott and Oriol Vinyals and Shakir Mohamed and Peter Battaglia},
122 | year={2022},
123 | eprint={2212.12794},
124 | archivePrefix={arXiv},
125 | primaryClass={cs.LG}
126 | }
127 | ```
128 |
--------------------------------------------------------------------------------
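The wrapper stack described in the file overview composes roughly as below (a sketch mirroring `wrap_graphcast_prediction` in `DiffDA/diffusion_common.py`; it must be called inside a Haiku transform, and the normalization statistics datasets are assumed to be loaded from the published stats files):

```python
from graphcast import graphcast, casting, normalization, autoregressive

def build_predictor(model_config, task_config,
                    diffs_stddev_by_level, mean_by_level, stddev_by_level):
    # One-step core model.
    predictor = graphcast.GraphCast(model_config, task_config)
    # Run activations in BFloat16 while keeping float32 parameters.
    predictor = casting.Bfloat16Cast(predictor)
    # Normalize inputs by climatology; predict normalized residuals.
    predictor = normalization.InputsAndResiduals(
        predictor,
        diffs_stddev_by_level=diffs_stddev_by_level,
        mean_by_level=mean_by_level,
        stddev_by_level=stddev_by_level)
    # Feed outputs back as inputs to produce multi-step trajectories.
    return autoregressive.Predictor(predictor, gradient_checkpointing=True)
```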
/GraphCast_src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/GraphCast_src/__init__.py
--------------------------------------------------------------------------------
/GraphCast_src/buckets.py:
--------------------------------------------------------------------------------
1 | from google.cloud import storage
2 | from google.cloud.storage import Bucket
3 | from pathlib import Path
4 | import os
5 |
6 | """
7 | Provides utilities to load and cache gcs buckets.
8 | """
9 |
10 | def parse_file_parts(file_name) -> dict:
11 | return dict(part.split("-", 1) for part in file_name.split("_"))
12 |
13 | def authenticate_bucket() -> Bucket:
14 | # @title Authenticate with Google Cloud Storage
15 | # TODO: Figure out how to access a public cloud bucket without authentication.
16 | os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "./gcs_key.json"
17 | gcs_client = storage.Client()
18 | gcs_bucket = gcs_client.get_bucket("dm_graphcast")
19 | return gcs_bucket
20 |
21 | def save_to_dir(blob, directory: Path, name: str) -> None:
22 | if not directory.exists():
23 | directory.mkdir(parents=True, exist_ok=True)
24 | with open(directory / name, "wb") as f:
25 | f.write(blob.read())
--------------------------------------------------------------------------------
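`parse_file_parts` expects names built from `key-value` parts joined by underscores, as used for the example datasets in the `dm_graphcast` bucket; a quick illustration (the file name here is made up):

```python
# Each "_"-separated part is split once on "-", so values may contain dashes.
name = "source-era5_date-2022-01-01_res-1.0_levels-13_steps-04"
print(parse_file_parts(name))
# {'source': 'era5', 'date': '2022-01-01', 'res': '1.0',
#  'levels': '13', 'steps': '04'}
```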
/GraphCast_src/graphcast/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/GraphCast_src/graphcast/__init__.py
--------------------------------------------------------------------------------
/GraphCast_src/graphcast/casting.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Wrappers that take care of casting."""
15 |
16 | import contextlib
17 | from typing import Any, Mapping, Tuple
18 |
19 | import chex
20 | from graphcast import predictor_base
21 | import haiku as hk
22 | import jax
23 | import jax.numpy as jnp
24 | import numpy as np
25 | import xarray
26 |
27 |
28 | PyTree = Any
29 |
30 |
31 | class Bfloat16Cast(predictor_base.Predictor):
32 | """Wrapper that casts all inputs to bfloat16 and outputs to targets dtype."""
33 |
34 | def __init__(self, predictor: predictor_base.Predictor, enabled: bool = True):
35 | """Inits the wrapper.
36 |
37 | Args:
38 | predictor: predictor being wrapped.
39 | enabled: disables the wrapper if False, for simpler hyperparameter scans.
40 |
41 | """
42 | self._enabled = enabled
43 | self._predictor = predictor
44 |
45 | def __call__(self,
46 | inputs: xarray.Dataset,
47 | targets_template: xarray.Dataset,
48 | forcings: xarray.Dataset,
49 | **kwargs
50 | ) -> xarray.Dataset:
51 | if not self._enabled:
52 | return self._predictor(inputs, targets_template, forcings, **kwargs)
53 |
54 | with bfloat16_variable_view():
55 | predictions = self._predictor(
56 | *_all_inputs_to_bfloat16(inputs, targets_template, forcings),
57 | **kwargs,)
58 |
59 | predictions_dtype = infer_floating_dtype(predictions)
60 | if predictions_dtype != jnp.bfloat16:
61 | raise ValueError(f'Expected bfloat16 output, got {predictions_dtype}')
62 |
63 | targets_dtype = infer_floating_dtype(targets_template)
64 | return tree_map_cast(
65 | predictions, input_dtype=jnp.bfloat16, output_dtype=targets_dtype)
66 |
67 | def loss(self,
68 | inputs: xarray.Dataset,
69 | targets: xarray.Dataset,
70 | forcings: xarray.Dataset,
71 | **kwargs,
72 | ) -> predictor_base.LossAndDiagnostics:
73 | if not self._enabled:
74 | return self._predictor.loss(inputs, targets, forcings, **kwargs)
75 |
76 | with bfloat16_variable_view():
77 | loss, scalars = self._predictor.loss(
78 | *_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs)
79 |
80 | if loss.dtype != jnp.bfloat16:
81 | raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}')
82 |
83 | targets_dtype = infer_floating_dtype(targets)
84 |
85 | # Note that casting back the loss to e.g. float32 should not affect data
86 | # types of the backwards pass, because the first thing the backwards pass
87 | # should do is to go backwards the casting op and cast back to bfloat16
88 | # (and xprofs seem to confirm this).
89 | return tree_map_cast((loss, scalars),
90 | input_dtype=jnp.bfloat16, output_dtype=targets_dtype)
91 |
92 | def loss_and_predictions( # pytype: disable=signature-mismatch # jax-ndarray
93 | self,
94 | inputs: xarray.Dataset,
95 | targets: xarray.Dataset,
96 | forcings: xarray.Dataset,
97 | **kwargs,
98 | ) -> Tuple[predictor_base.LossAndDiagnostics,
99 | xarray.Dataset]:
100 | if not self._enabled:
101 | return self._predictor.loss_and_predictions(inputs, targets, forcings, # pytype: disable=bad-return-type # jax-ndarray
102 | **kwargs)
103 |
104 | with bfloat16_variable_view():
105 | (loss, scalars), predictions = self._predictor.loss_and_predictions(
106 | *_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs)
107 |
108 | if loss.dtype != jnp.bfloat16:
109 | raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}')
110 |
111 | predictions_dtype = infer_floating_dtype(predictions)
112 | if predictions_dtype != jnp.bfloat16:
113 | raise ValueError(f'Expected bfloat16 output, got {predictions_dtype}')
114 |
115 | targets_dtype = infer_floating_dtype(targets)
116 | return tree_map_cast(((loss, scalars), predictions),
117 | input_dtype=jnp.bfloat16, output_dtype=targets_dtype)
118 |
119 |
120 | def infer_floating_dtype(data_vars: Mapping[str, chex.Array]) -> np.dtype:
121 | """Infers a floating dtype from an input mapping of data."""
122 | dtypes = {
123 | v.dtype
124 | for k, v in data_vars.items() if jnp.issubdtype(v.dtype, np.floating)}
125 | if len(dtypes) != 1:
126 | dtypes_and_shapes = {
127 | k: (v.dtype, v.shape)
128 | for k, v in data_vars.items() if jnp.issubdtype(v.dtype, np.floating)}
129 | raise ValueError(
130 |         f'Did not find exactly one floating dtype {dtypes} in input variables: '
131 | f'{dtypes_and_shapes}')
132 | return list(dtypes)[0]
133 |
134 |
135 | def _all_inputs_to_bfloat16(
136 | inputs: xarray.Dataset,
137 | targets: xarray.Dataset,
138 | forcings: xarray.Dataset,
139 | ) -> Tuple[xarray.Dataset,
140 | xarray.Dataset,
141 | xarray.Dataset]:
142 | return (inputs.astype(jnp.bfloat16),
143 | jax.tree_map(lambda x: x.astype(jnp.bfloat16), targets),
144 | forcings.astype(jnp.bfloat16))
145 |
146 |
147 | def tree_map_cast(inputs: PyTree, input_dtype: np.dtype, output_dtype: np.dtype,
148 | ) -> PyTree:
149 | def cast_fn(x):
150 | if x.dtype == input_dtype:
151 | return x.astype(output_dtype)
152 | return jax.tree_map(cast_fn, inputs)
153 |
154 |
155 | @contextlib.contextmanager
156 | def bfloat16_variable_view(enabled: bool = True):
157 | """Context for Haiku modules with float32 params, but bfloat16 activations.
158 |
159 | It works as follows:
160 | * Every time a variable is requested to be created/set as np.bfloat16,
161 | it will create an underlying float32 variable, instead.
162 |   * Every time a variable is requested as bfloat16, it will check that the
163 | variable is of float32 type, and cast the variable to bfloat16.
164 |
165 | Note the gradients are still computed and accumulated as float32, because
166 | the params returned by init are float32, so the gradient function with
167 | respect to the params will already include an implicit casting to float32.
168 |
169 | Args:
170 | enabled: Only enables bfloat16 behavior if True.
171 |
172 | Yields:
173 | None
174 | """
175 |
176 | if enabled:
177 | with hk.custom_creator(
178 | _bfloat16_creator, state=True), hk.custom_getter(
179 | _bfloat16_getter, state=True), hk.custom_setter(
180 | _bfloat16_setter):
181 | yield
182 | else:
183 | yield
184 |
185 |
186 | def _bfloat16_creator(next_creator, shape, dtype, init, context):
187 | """Creates float32 variables when bfloat16 is requested."""
188 | if context.original_dtype == jnp.bfloat16:
189 | dtype = jnp.float32
190 | return next_creator(shape, dtype, init)
191 |
192 |
193 | def _bfloat16_getter(next_getter, value, context):
194 | """Casts float32 to bfloat16 when bfloat16 was originally requested."""
195 | if context.original_dtype == jnp.bfloat16:
196 | assert value.dtype == jnp.float32
197 | value = value.astype(jnp.bfloat16)
198 | return next_getter(value)
199 |
200 |
201 | def _bfloat16_setter(next_setter, value, context):
202 | """Casts bfloat16 to float32 when bfloat16 was originally set."""
203 | if context.original_dtype == jnp.bfloat16:
204 | value = value.astype(jnp.float32)
205 | return next_setter(value)
206 |
--------------------------------------------------------------------------------
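A toy check of the float32-params/bfloat16-activations behavior that `bfloat16_variable_view` documents (a sketch using a plain `hk.Linear`, not the GraphCast model):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def f(x):
    with bfloat16_variable_view():
        return hk.Linear(8)(x)

fwd = hk.transform(f)
params = fwd.init(jax.random.PRNGKey(0), jnp.ones((1, 4), jnp.bfloat16))
# Although the Linear layer requested bfloat16 parameters (it follows the input
# dtype), the stored parameters are float32; they are cast to bfloat16 only
# when read inside the context.
print(jax.tree_util.tree_map(lambda p: p.dtype, params))
```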
/GraphCast_src/graphcast/checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Serialize and deserialize trees."""
15 |
16 | import dataclasses
17 | import io
18 | import types
19 | from typing import Any, BinaryIO, Optional, TypeVar
20 |
21 | import numpy as np
22 |
23 | _T = TypeVar("_T")
24 |
25 |
26 | def dump(dest: BinaryIO, value: Any) -> None:
27 | """Dump a tree of dicts/dataclasses to a file object.
28 |
29 | Args:
30 | dest: a file object to write to.
31 | value: A tree of dicts, lists, tuples and dataclasses of numpy arrays and
32 | other basic types. Unions are not supported, other than Optional/None
33 | which is only supported in dataclasses, not in dicts, lists or tuples.
34 | All leaves must be coercible to a numpy array, and recoverable as a single
35 | arg to a type.
36 | """
37 | buffer = io.BytesIO() # In case the destination doesn't support seeking.
38 | np.savez(buffer, **_flatten(value))
39 | dest.write(buffer.getvalue())
40 |
41 |
42 | def load(source: BinaryIO, typ: type[_T]) -> _T:
43 | """Load from a file object and convert it to the specified type.
44 |
45 | Args:
46 | source: a file object to read from.
47 | typ: a type object that acts as a schema for deserialization. It must match
48 | what was serialized. If a type is Any, it will be returned however numpy
49 | serialized it, which is what you want for a tree of numpy arrays.
50 |
51 | Returns:
52 | the deserialized value as the specified type.
53 | """
54 | return _convert_types(typ, _unflatten(np.load(source)))
55 |
56 |
57 | _SEP = ":"
58 |
59 |
60 | def _flatten(tree: Any) -> dict[str, Any]:
61 | """Flatten a tree of dicts/dataclasses/lists/tuples to a single dict."""
62 | if dataclasses.is_dataclass(tree):
63 | # Don't use dataclasses.asdict as it is recursive so skips dropping None.
64 | tree = {f.name: v for f in dataclasses.fields(tree)
65 | if (v := getattr(tree, f.name)) is not None}
66 | elif isinstance(tree, (list, tuple)):
67 | tree = dict(enumerate(tree))
68 |
69 | assert isinstance(tree, dict)
70 |
71 | flat = {}
72 | for k, v in tree.items():
73 | k = str(k)
74 | assert _SEP not in k
75 | if dataclasses.is_dataclass(v) or isinstance(v, (dict, list, tuple)):
76 | for a, b in _flatten(v).items():
77 | flat[f"{k}{_SEP}{a}"] = b
78 | else:
79 | assert v is not None
80 | flat[k] = v
81 | return flat
82 |
83 |
84 | def _unflatten(flat: dict[str, Any]) -> dict[str, Any]:
85 | """Unflatten a dict to a tree of dicts."""
86 | tree = {}
87 | for flat_key, v in flat.items():
88 | node = tree
89 | keys = flat_key.split(_SEP)
90 | for k in keys[:-1]:
91 | if k not in node:
92 | node[k] = {}
93 | node = node[k]
94 | node[keys[-1]] = v
95 | return tree
96 |
97 |
98 | def _convert_types(typ: type[_T], value: Any) -> _T:
99 | """Convert some structure into the given type. The structures must match."""
100 | if typ in (Any, ...):
101 | return value
102 |
103 | if typ in (int, float, str, bool):
104 | return typ(value)
105 |
106 | if typ is np.ndarray:
107 | assert isinstance(value, np.ndarray)
108 | return value
109 |
110 | if dataclasses.is_dataclass(typ):
111 | kwargs = {}
112 | for f in dataclasses.fields(typ):
113 | # Only support Optional for dataclasses, as numpy can't serialize it
114 | # directly (without pickle), and dataclasses are the only case where we
115 | # can know the full set of values and types and therefore know the
116 | # non-existence must mean None.
117 | if isinstance(f.type, (types.UnionType, type(Optional[int]))):
118 | constructors = [t for t in f.type.__args__ if t is not types.NoneType]
119 | if len(constructors) != 1:
120 | raise TypeError(
121 | "Optional works, Union with anything except None doesn't")
122 | if f.name not in value:
123 | kwargs[f.name] = None
124 | continue
125 | constructor = constructors[0]
126 | else:
127 | constructor = f.type
128 |
129 | if f.name in value:
130 | kwargs[f.name] = _convert_types(constructor, value[f.name])
131 | else:
132 | raise ValueError(f"Missing value: {f.name}")
133 | return typ(**kwargs)
134 |
135 | base_type = getattr(typ, "__origin__", None)
136 |
137 | if base_type is dict:
138 | assert len(typ.__args__) == 2
139 | key_type, value_type = typ.__args__
140 | return {_convert_types(key_type, k): _convert_types(value_type, v)
141 | for k, v in value.items()}
142 |
143 | if base_type is list:
144 | assert len(typ.__args__) == 1
145 | value_type = typ.__args__[0]
146 | return [_convert_types(value_type, v)
147 | for _, v in sorted(value.items(), key=lambda x: int(x[0]))]
148 |
149 | if base_type is tuple:
150 | if len(typ.__args__) == 2 and typ.__args__[1] == ...:
151 | # An arbitrary length tuple of a single type, eg: tuple[int, ...]
152 | value_type = typ.__args__[0]
153 | return tuple(_convert_types(value_type, v)
154 | for _, v in sorted(value.items(), key=lambda x: int(x[0])))
155 | else:
156 | # A fixed length tuple of arbitrary types, eg: tuple[int, str, float]
157 | assert len(typ.__args__) == len(value)
158 | return tuple(
159 | _convert_types(t, v)
160 | for t, (_, v) in zip(
161 | typ.__args__, sorted(value.items(), key=lambda x: int(x[0]))))
162 |
163 | # This is probably unreachable with reasonable serializable inputs.
164 | try:
165 | return typ(value)
166 | except TypeError as e:
167 | raise TypeError(
168 | "_convert_types expects the type argument to be a dataclass defined "
169 | "with types that are valid constructors (eg tuple is fine, Tuple "
170 | "isn't), and accept a numpy array as the sole argument.") from e
171 |
--------------------------------------------------------------------------------
/GraphCast_src/graphcast/checkpoint_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Check that the checkpoint serialization is reversable."""
15 |
16 | import dataclasses
17 | import io
18 | from typing import Any, Optional, Union
19 |
20 | from absl.testing import absltest
21 | from graphcast import checkpoint
22 | import numpy as np
23 |
24 |
25 | @dataclasses.dataclass
26 | class SubConfig:
27 | a: int
28 | b: str
29 |
30 |
31 | @dataclasses.dataclass
32 | class Config:
33 | bt: bool
34 | bf: bool
35 | i: int
36 | f: float
37 | o1: Optional[int]
38 | o2: Optional[int]
39 | o3: Union[int, None]
40 | o4: Union[int, None]
41 | o5: int | None
42 | o6: int | None
43 | li: list[int]
44 | ls: list[str]
45 | ldc: list[SubConfig]
46 | tf: tuple[float, ...]
47 | ts: tuple[str, ...]
48 | t: tuple[str, int, SubConfig]
49 | tdc: tuple[SubConfig, ...]
50 | dsi: dict[str, int]
51 | dss: dict[str, str]
52 | dis: dict[int, str]
53 | dsdis: dict[str, dict[int, str]]
54 | dc: SubConfig
55 | dco: Optional[SubConfig]
56 | ddc: dict[str, SubConfig]
57 |
58 |
59 | @dataclasses.dataclass
60 | class Checkpoint:
61 | params: dict[str, Any]
62 | config: Config
63 |
64 |
65 | class DataclassTest(absltest.TestCase):
66 |
67 | def test_serialize_dataclass(self):
68 | ckpt = Checkpoint(
69 | params={
70 | "layer1": {
71 | "w": np.arange(10).reshape(2, 5),
72 | "b": np.array([2, 6]),
73 | },
74 | "layer2": {
75 | "w": np.arange(8).reshape(2, 4),
76 | "b": np.array([2, 6]),
77 | },
78 | "blah": np.array([3, 9]),
79 | },
80 | config=Config(
81 | bt=True,
82 | bf=False,
83 | i=42,
84 | f=3.14,
85 | o1=1,
86 | o2=None,
87 | o3=2,
88 | o4=None,
89 | o5=3,
90 | o6=None,
91 | li=[12, 9, 7, 15, 16, 14, 1, 6, 11, 4, 10, 5, 13, 3, 8, 2],
92 | ls=list("qhjfdxtpzgemryoikwvblcaus"),
93 | ldc=[SubConfig(1, "hello"), SubConfig(2, "world")],
94 | tf=(1, 4, 2, 10, 5, 9, 13, 16, 15, 8, 12, 7, 11, 14, 3, 6),
95 | ts=("hello", "world"),
96 | t=("foo", 42, SubConfig(1, "bar")),
97 | tdc=(SubConfig(1, "hello"), SubConfig(2, "world")),
98 | dsi={"a": 1, "b": 2, "c": 3},
99 | dss={"d": "e", "f": "g"},
100 | dis={1: "a", 2: "b", 3: "c"},
101 | dsdis={"a": {1: "hello", 2: "world"}, "b": {1: "world"}},
102 | dc=SubConfig(1, "hello"),
103 | dco=None,
104 | ddc={"a": SubConfig(1, "hello"), "b": SubConfig(2, "world")},
105 | ))
106 |
107 | buffer = io.BytesIO()
108 | checkpoint.dump(buffer, ckpt)
109 | buffer.seek(0)
110 | ckpt2 = checkpoint.load(buffer, Checkpoint)
111 | np.testing.assert_array_equal(ckpt.params["layer1"]["w"],
112 | ckpt2.params["layer1"]["w"])
113 | np.testing.assert_array_equal(ckpt.params["layer1"]["b"],
114 | ckpt2.params["layer1"]["b"])
115 | np.testing.assert_array_equal(ckpt.params["layer2"]["w"],
116 | ckpt2.params["layer2"]["w"])
117 | np.testing.assert_array_equal(ckpt.params["layer2"]["b"],
118 | ckpt2.params["layer2"]["b"])
119 | np.testing.assert_array_equal(ckpt.params["blah"], ckpt2.params["blah"])
120 | self.assertEqual(ckpt.config, ckpt2.config)
121 |
122 |
123 | if __name__ == "__main__":
124 | absltest.main()
125 |
--------------------------------------------------------------------------------
/GraphCast_src/graphcast/data_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Dataset utilities."""
15 |
16 | from typing import Any, Mapping, Sequence, Tuple, Union
17 |
18 | import numpy as np
19 | import pandas as pd
20 | import xarray
21 |
22 | TimedeltaLike = Any # Something convertible to pd.Timedelta.
23 | TimedeltaStr = str # A string convertible to pd.Timedelta.
24 |
25 | TargetLeadTimes = Union[
26 | TimedeltaLike,
27 | Sequence[TimedeltaLike],
28 | slice # with TimedeltaLike as its start and stop.
29 | ]
30 |
31 | _SEC_PER_HOUR = 3600
32 | _HOUR_PER_DAY = 24
33 | SEC_PER_DAY = _SEC_PER_HOUR * _HOUR_PER_DAY
34 | _AVG_DAY_PER_YEAR = 365.24219
35 | AVG_SEC_PER_YEAR = SEC_PER_DAY * _AVG_DAY_PER_YEAR
36 |
37 | DAY_PROGRESS = "day_progress"
38 | YEAR_PROGRESS = "year_progress"
39 |
40 |
41 | def get_year_progress(seconds_since_epoch: np.ndarray) -> np.ndarray:
42 | """Computes year progress for times in seconds.
43 |
44 | Args:
45 | seconds_since_epoch: Times in seconds since the "epoch" (the point at which
46 | UNIX time starts).
47 |
48 | Returns:
49 | Year progress normalized to be in the [0, 1) interval for each time point.
50 | """
51 |
52 | # Divide by the exact number of seconds per day first, and only then by
53 | # the float average days per year, to keep as much precision as possible.
54 | years_since_epoch = (
55 | seconds_since_epoch / SEC_PER_DAY / np.float64(_AVG_DAY_PER_YEAR)
56 | )
57 | # Note depending on how these ops are done, we may end up with a "weak_type"
58 | # which can cause issues in subtle ways that are hard to track down.
59 | # In any case, casting to float32 should get rid of the weak type.
60 | # [0, 1.) Interval.
61 | return np.mod(years_since_epoch, 1.0).astype(np.float32)
62 |
63 |
64 | def get_day_progress(
65 | seconds_since_epoch: np.ndarray,
66 | longitude: np.ndarray,
67 | ) -> np.ndarray:
68 | """Computes day progress for times in seconds at each longitude.
69 |
70 | Args:
71 | seconds_since_epoch: 1D array of times in seconds since the 'epoch' (the
72 | point at which UNIX time starts).
73 | longitude: 1D array of longitudes at which day progress is computed.
74 |
75 | Returns:
76 | 2D array of day progress values normalized to be in the [0, 1) interval
77 | for each time point at each longitude.
78 | """
79 |
80 | # [0.0, 1.0) Interval.
81 | day_progress_greenwich = (
82 | np.mod(seconds_since_epoch, SEC_PER_DAY) / SEC_PER_DAY
83 | )
84 |
85 | # Offset the day progress to the longitude of each point on Earth.
86 | longitude_offsets = np.deg2rad(longitude) / (2 * np.pi)
87 | day_progress = np.mod(
88 | day_progress_greenwich[..., np.newaxis] + longitude_offsets, 1.0
89 | )
90 | return day_progress.astype(np.float32)
91 |
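# Worked example (consistent with data_utils_test.py): 123 seconds after
# midnight UTC at longitude 0 gives np.mod(123, 86400) / 86400 ~= 0.00142361;
# at longitude 180 the offset adds 180 / 360 = 0.5, giving ~0.50142361.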
92 |
93 | def featurize_progress(
94 | name: str, dims: Sequence[str], progress: np.ndarray
95 | ) -> Mapping[str, xarray.Variable]:
96 | """Derives features used by ML models from the `progress` variable.
97 |
98 | Args:
99 | name: Base variable name from which features are derived.
100 | dims: List of the output feature dimensions, e.g. ("day", "lon").
101 | progress: Progress variable values.
102 |
103 | Returns:
104 | Dictionary of xarray variables derived from the `progress` values. It
105 | includes the original `progress` variable along with its sin and cos
106 | transformations.
107 |
108 | Raises:
109 | ValueError if the number of feature dimensions is not equal to the number
110 | of data dimensions.
111 | """
112 | if len(dims) != progress.ndim:
113 | raise ValueError(
114 | f"Number of feature dimensions ({len(dims)}) must be equal to the"
115 | f" number of data dimensions: {progress.ndim}."
116 | )
117 | progress_phase = progress * (2 * np.pi)
118 | return {
119 | name: xarray.Variable(dims, progress),
120 | name + "_sin": xarray.Variable(dims, np.sin(progress_phase)),
121 | name + "_cos": xarray.Variable(dims, np.cos(progress_phase)),
122 | }
123 |
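# e.g. featurize_progress(name="day_progress", dims=("time",),
#                         progress=np.array([0.0, 0.25]))
# yields day_progress = [0.0, 0.25], day_progress_sin = [0.0, 1.0] and
# day_progress_cos = [1.0, 0.0], since the phase is progress * 2 * pi.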
124 |
125 | def add_derived_vars(data: xarray.Dataset) -> None:
126 | """Adds year and day progress features to `data` in place.
127 |
128 | NOTE: `toa_incident_solar_radiation` needs to be computed in this function
129 | as well.
130 |
131 | Args:
132 | data: Xarray dataset to which derived features will be added.
133 |
134 | Raises:
135 | ValueError if `datetime` or `lon` are not in `data` coordinates.
136 | """
137 |
138 | for coord in ("datetime", "lon"):
139 | if coord not in data.coords:
140 | raise ValueError(f"'{coord}' must be in `data` coordinates.")
141 |
142 | # Compute seconds since epoch.
143 | # Note `data.coords["datetime"].astype("datetime64[s]").astype(np.int64)`
144 | # does not work, as xarray always casts dates into nanoseconds!
145 | seconds_since_epoch = (
146 | data.coords["datetime"].data.astype("datetime64[s]").astype(np.int64)
147 | )
148 | batch_dim = ("batch",) if "batch" in data.dims else ()
149 |
150 | # Add year progress features.
151 | year_progress = get_year_progress(seconds_since_epoch)
152 | data.update(
153 | featurize_progress(
154 | name=YEAR_PROGRESS, dims=batch_dim + ("time",), progress=year_progress
155 | )
156 | )
157 |
158 | # Add day progress features.
159 | longitude_coord = data.coords["lon"]
160 | day_progress = get_day_progress(seconds_since_epoch, longitude_coord.data)
161 | data.update(
162 | featurize_progress(
163 | name=DAY_PROGRESS,
164 | dims=batch_dim + ("time",) + longitude_coord.dims,
165 | progress=day_progress,
166 | )
167 | )
168 |
169 |
170 | def extract_input_target_times(
171 | dataset: xarray.Dataset,
172 | input_duration: TimedeltaLike,
173 | target_lead_times: TargetLeadTimes,
174 | ) -> Tuple[xarray.Dataset, xarray.Dataset]:
175 | """Extracts inputs and targets for prediction, from a Dataset with a time dim.
176 |
177 | The input period is assumed to be contiguous (specified by a duration), but
178 | the targets can be a list of arbitrary lead times.
179 |
180 | Examples:
181 |
182 | # Use 18 hours of data as inputs, and two specific lead times as targets:
183 | # 3 days and 5 days after the final input.
184 | extract_inputs_targets(
185 | dataset,
186 | input_duration='18h',
187 | target_lead_times=('3d', '5d')
188 | )
189 |
190 | # Use 1 day of data as input, and all lead times between 6 hours and
191 | # 24 hours inclusive as targets. Demonstrates a friendlier supported string
192 | # syntax.
193 | extract_inputs_targets(
194 | dataset,
195 | input_duration='1 day',
196 | target_lead_times=slice('6 hours', '24 hours')
197 | )
198 |
199 | # Just use a single target lead time of 3 days:
200 | extract_inputs_targets(
201 | dataset,
202 | input_duration='24h',
203 | target_lead_times='3d'
204 | )
205 |
206 | Args:
207 | dataset: An xarray.Dataset with a 'time' dimension whose coordinates are
208 | timedeltas. It's assumed that the time coordinates have a fixed offset /
209 | time resolution, and that the input_duration and target_lead_times are
210 | multiples of this.
211 | input_duration: pandas.Timedelta or something convertible to it (e.g. a
212 | shorthand string like '6h' or '5d12h').
213 | target_lead_times: Either a single lead time, a slice with start and stop
214 | (inclusive) lead times, or a sequence of lead times. Lead times should be
215 | Timedeltas (or something convertible to). They are given relative to the
216 | final input timestep, and should be positive.
217 |
218 | Returns:
219 | inputs:
220 | targets:
221 | Two datasets with the same shape as the input dataset except that a
222 | selection has been made from the time axis, and the origin of the
223 | time coordinate will be shifted to refer to lead times relative to the
224 | final input timestep. So for inputs the times will end at lead time 0,
225 | for targets the time coordinates will refer to the lead times requested.
226 | """
227 |
228 | (target_lead_times, target_duration
229 | ) = _process_target_lead_times_and_get_duration(target_lead_times)
230 |
231 | # Shift the coordinates for the time axis so that a timedelta of zero
232 | # corresponds to the forecast reference time. That is, the final timestep
233 | # that's available as input to the forecast, with all following timesteps
234 | # forming the target period which needs to be predicted.
235 | # This means the time coordinates are now forecast lead times.
236 | time = dataset.coords["time"]
237 | dataset = dataset.assign_coords(time=time + target_duration - time[-1])
238 |
239 | # Slice out targets:
240 | targets = dataset.sel({"time": target_lead_times})
241 |
242 | input_duration = pd.Timedelta(input_duration)
243 | # Both endpoints are inclusive with label-based slicing, so we offset by a
244 | # small epsilon to make one of the endpoints non-inclusive:
245 | zero = pd.Timedelta(0)
246 | epsilon = pd.Timedelta(1, "ns")
247 | inputs = dataset.sel({"time": slice(-input_duration + epsilon, zero)})
248 | return inputs, targets
249 |
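# Worked example (illustrative 6h-resolution data): with time coordinates
# [0h, 6h, 12h, 18h], input_duration='12h' and target_lead_times='6h':
#   target_duration = 6h, so time becomes time + 6h - 18h = [-12h, -6h, 0h, 6h];
#   targets selects the 6h step, and inputs selects (-12h, 0h] = [-6h, 0h],
#   i.e. the last two steps, covering 12 hours of input data.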
250 |
251 | def _process_target_lead_times_and_get_duration(
252 | target_lead_times: TargetLeadTimes) -> Tuple[TargetLeadTimes, pd.Timedelta]:
253 | """Returns the processed lead times and the duration needed to cover them."""
254 | if isinstance(target_lead_times, slice):
255 | # A slice of lead times. xarray already accepts timedelta-like values for
256 | # the begin/end/step of the slice.
257 | if target_lead_times.start is None:
258 | # If the start isn't specified, we assume it starts at the next timestep
259 | # after lead time 0 (lead time 0 is the final input timestep):
260 | target_lead_times = slice(
261 | pd.Timedelta(1, "ns"), target_lead_times.stop, target_lead_times.step
262 | )
263 | target_duration = pd.Timedelta(target_lead_times.stop)
264 | else:
265 | if not isinstance(target_lead_times, (list, tuple, set)):
266 | # A single lead time, which we wrap as a length-1 array to ensure there
267 | # still remains a time dimension (here of length 1) for consistency.
268 | target_lead_times = [target_lead_times]
269 |
270 | # A list of multiple (not necessarily contiguous) lead times:
271 | target_lead_times = [pd.Timedelta(x) for x in target_lead_times]
272 | target_lead_times.sort()
273 | target_duration = target_lead_times[-1]
274 | return target_lead_times, target_duration
275 |
276 |
277 | def extract_inputs_targets_forcings(
278 | dataset: xarray.Dataset,
279 | *,
280 | input_variables: Tuple[str, ...],
281 | target_variables: Tuple[str, ...],
282 | forcing_variables: Tuple[str, ...],
283 | pressure_levels: Tuple[int, ...],
284 | input_duration: TimedeltaLike,
285 | target_lead_times: TargetLeadTimes,
286 | ) -> Tuple[xarray.Dataset, xarray.Dataset, xarray.Dataset]:
287 | """Extracts inputs, targets and forcings according to requirements."""
288 | dataset = dataset.sel(level=list(pressure_levels))
289 |
290 | # "Forcings" are derived variables and do not exist in the original ERA5 or
291 | # HRES datasets. Compute them if they are not in `dataset`.
292 | if not set(forcing_variables).issubset(set(dataset.data_vars)):
293 | add_derived_vars(dataset)
294 |
295 | # `datetime` is needed by add_derived_vars but breaks autoregressive rollouts.
296 | dataset = dataset.drop_vars("datetime")
297 |
298 | inputs, targets = extract_input_target_times(
299 | dataset,
300 | input_duration=input_duration,
301 | target_lead_times=target_lead_times)
302 |
303 | if set(forcing_variables) & set(target_variables):
304 | raise ValueError(
305 | f"Forcing variables {forcing_variables} should not "
306 | f"overlap with target variables {target_variables}."
307 | )
308 |
309 | inputs = inputs[list(input_variables)]
310 | # The forcing uses the same time coordinates as the target.
311 | forcings = targets[list(forcing_variables)]
312 | targets = targets[list(target_variables)]
313 |
314 | return inputs, targets, forcings
315 |
--------------------------------------------------------------------------------
/GraphCast_src/graphcast/data_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for `data_utils.py`."""
15 |
16 | import datetime
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | from graphcast import data_utils
20 | import numpy as np
21 | import xarray
22 |
23 |
24 | class DataUtilsTest(parameterized.TestCase):
25 |
26 | def setUp(self):
27 | super().setUp()
28 | # Fix the seed for reproducibility.
29 | np.random.seed(0)
30 |
31 | def test_year_progress_is_zero_at_year_start_or_end(self):
32 | year_progress = data_utils.get_year_progress(
33 | np.array([
34 | 0,
35 | data_utils.AVG_SEC_PER_YEAR,
36 | data_utils.AVG_SEC_PER_YEAR * 42, # 42 years.
37 | ])
38 | )
39 | np.testing.assert_array_equal(year_progress, np.zeros(year_progress.shape))
40 |
41 | def test_year_progress_is_almost_one_before_year_ends(self):
42 | year_progress = data_utils.get_year_progress(
43 | np.array([
44 | data_utils.AVG_SEC_PER_YEAR - 1,
45 | (data_utils.AVG_SEC_PER_YEAR - 1) * 42, # ~42 years
46 | ])
47 | )
48 | with self.subTest("Year progress values are close to 1"):
49 | self.assertTrue(np.all(year_progress > 0.999))
50 | with self.subTest("Year progress values != 1"):
51 | self.assertTrue(np.all(year_progress < 1.0))
52 |
53 | def test_day_progress_computes_for_all_times_and_longitudes(self):
54 | times = np.random.randint(low=0, high=1e10, size=10)
55 | longitudes = np.arange(0, 360.0, 1.0)
56 | day_progress = data_utils.get_day_progress(times, longitudes)
57 | with self.subTest("Day progress is computed for all times and longitudes"):
58 | self.assertSequenceEqual(
59 | day_progress.shape, (len(times), len(longitudes))
60 | )
61 |
62 | @parameterized.named_parameters(
63 | dict(
64 | testcase_name="random_date_1",
65 | year=1988,
66 | month=11,
67 | day=7,
68 | hour=2,
69 | minute=45,
70 | second=34,
71 | ),
72 | dict(
73 | testcase_name="random_date_2",
74 | year=2022,
75 | month=3,
76 | day=12,
77 | hour=7,
78 | minute=1,
79 | second=0,
80 | ),
81 | )
82 | def test_day_progress_is_in_between_zero_and_one(
83 | self, year, month, day, hour, minute, second
84 | ):
85 | # Datetime from a timestamp.
86 | dt = datetime.datetime(year, month, day, hour, minute, second)
87 | # Epoch time.
88 | epoch_time = datetime.datetime(1970, 1, 1)
89 | # Seconds since epoch.
90 | seconds_since_epoch = np.array([(dt - epoch_time).total_seconds()])
91 |
92 | # Longitudes with 1 degree resolution.
93 | longitudes = np.arange(0, 360.0, 1.0)
94 |
95 | day_progress = data_utils.get_day_progress(seconds_since_epoch, longitudes)
96 | with self.subTest("Day progress >= 0"):
97 | self.assertTrue(np.all(day_progress >= 0.0))
98 | with self.subTest("Day progress < 1"):
99 | self.assertTrue(np.all(day_progress < 1.0))
100 |
101 | def test_day_progress_is_zero_at_day_start_or_end(self):
102 | day_progress = data_utils.get_day_progress(
103 | seconds_since_epoch=np.array([
104 | 0,
105 | data_utils.SEC_PER_DAY,
106 | data_utils.SEC_PER_DAY * 42, # 42 days.
107 | ]),
108 | longitude=np.array([0.0]),
109 | )
110 | np.testing.assert_array_equal(day_progress, np.zeros(day_progress.shape))
111 |
112 | def test_day_progress_specific_value(self):
113 | day_progress = data_utils.get_day_progress(
114 | seconds_since_epoch=np.array([123]),
115 | longitude=np.array([0.0]),
116 | )
117 | np.testing.assert_array_almost_equal(
118 | day_progress, np.array([[0.00142361]]), decimal=6
119 | )
120 |
121 | def test_featurize_progress_valid_values_and_dimensions(self):
122 | day_progress = np.array([0.0, 0.45, 0.213])
123 | feature_dimensions = ("time",)
124 | progress_features = data_utils.featurize_progress(
125 | name="day_progress", dims=feature_dimensions, progress=day_progress
126 | )
127 | for feature in progress_features.values():
128 | with self.subTest(f"Valid dimensions for {feature}"):
129 | self.assertSequenceEqual(feature.dims, feature_dimensions)
130 |
131 | with self.subTest("Valid values for day_progress"):
132 | np.testing.assert_array_equal(
133 | day_progress, progress_features["day_progress"].values
134 | )
135 |
136 | with self.subTest("Valid values for day_progress_sin"):
137 | np.testing.assert_array_almost_equal(
138 | np.array([0.0, 0.30901699, 0.97309851]),
139 | progress_features["day_progress_sin"].values,
140 | decimal=6,
141 | )
142 |
143 | with self.subTest("Valid values for day_progress_cos"):
144 | np.testing.assert_array_almost_equal(
145 | np.array([1.0, -0.95105652, 0.23038943]),
146 | progress_features["day_progress_cos"].values,
147 | decimal=6,
148 | )
149 |
150 | def test_featurize_progress_invalid_dimensions(self):
151 | year_progress = np.array([0.0, 0.45, 0.213])
152 | feature_dimensions = ("time", "longitude")
153 | with self.assertRaises(ValueError):
154 | data_utils.featurize_progress(
155 | name="year_progress", dims=feature_dimensions, progress=year_progress
156 | )
157 |
158 | def test_add_derived_vars_variables_added(self):
159 | data = xarray.Dataset(
160 | data_vars={
161 | "var1": (["x", "lon", "datetime"], 8 * np.random.randn(2, 2, 3))
162 | },
163 | coords={
164 | "lon": np.array([0.0, 0.5]),
165 | "datetime": np.array([
166 | datetime.datetime(2021, 1, 1),
167 | datetime.datetime(2023, 1, 1),
168 | datetime.datetime(2023, 1, 3),
169 | ]),
170 | },
171 | )
172 | data_utils.add_derived_vars(data)
173 | all_variables = set(data.variables)
174 |
175 | with self.subTest("Original value was not removed"):
176 | self.assertIn("var1", all_variables)
177 | with self.subTest("Year progress feature was added"):
178 | self.assertIn(data_utils.YEAR_PROGRESS, all_variables)
179 | with self.subTest("Day progress feature was added"):
180 | self.assertIn(data_utils.DAY_PROGRESS, all_variables)
181 |
182 | @parameterized.named_parameters(
183 | dict(testcase_name="missing_datetime", coord_name="lon"),
184 | dict(testcase_name="missing_lon", coord_name="datetime"),
185 | )
186 | def test_add_derived_vars_missing_coordinate_raises_value_error(
187 | self, coord_name
188 | ):
189 | with self.subTest(f"Missing {coord_name} coordinate"):
190 | data = xarray.Dataset(
191 | data_vars={"var1": (["x", coord_name], 8 * np.random.randn(2, 2))},
192 | coords={
193 | coord_name: np.array([0.0, 0.5]),
194 | },
195 | )
196 | with self.assertRaises(ValueError):
197 | data_utils.add_derived_vars(data)
198 |
199 |
200 | if __name__ == "__main__":
201 | absltest.main()
202 |
--------------------------------------------------------------------------------
/GraphCast_src/graphcast/grid_mesh_connectivity.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tools for converting from regular grids on a sphere, to triangular meshes."""
15 |
16 | from graphcast import icosahedral_mesh
17 | import numpy as np
18 | import scipy
19 | import trimesh
20 |
21 |
22 | def _grid_lat_lon_to_coordinates(
23 | grid_latitude: np.ndarray, grid_longitude: np.ndarray) -> np.ndarray:
24 | """Lat [num_lat] lon [num_lon] to 3d coordinates [num_lat, num_lon, 3]."""
25 | # Convert to spherical coordinates phi and theta defined in the grid.
26 | # Each [num_latitude_points, num_longitude_points]
27 | phi_grid, theta_grid = np.meshgrid(
28 | np.deg2rad(grid_longitude),
29 | np.deg2rad(90 - grid_latitude))
30 |
31 | # [num_latitude_points, num_longitude_points, 3]
32 | # Note this assumes unit radius, since for now we model the earth as a
33 | # sphere of unit radius, and keep any vertical dimension as a regular grid.
34 | return np.stack(
35 | [np.cos(phi_grid)*np.sin(theta_grid),
36 | np.sin(phi_grid)*np.sin(theta_grid),
37 | np.cos(theta_grid)], axis=-1)
38 |
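# e.g. lat=0, lon=0  -> (1, 0, 0); lat=0, lon=90 -> (0, 1, 0);
#      lat=90 (north pole) -> (0, 0, 1), since theta = 0 there.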
39 |
40 | def radius_query_indices(
41 | *,
42 | grid_latitude: np.ndarray,
43 | grid_longitude: np.ndarray,
44 | mesh: icosahedral_mesh.TriangularMesh,
45 | radius: float) -> tuple[np.ndarray, np.ndarray]:
46 | """Returns mesh-grid edge indices for radius query.
47 |
48 | Args:
49 | grid_latitude: Latitude values for the grid [num_lat_points]
50 | grid_longitude: Longitude values for the grid [num_lon_points]
51 | mesh: Mesh object.
52 | radius: Radius of connectivity in R3, for a sphere of unit radius.
53 |
54 | Returns:
55 | tuple with `grid_indices` and `mesh_indices` indicating edges between the
56 | grid and the mesh such that the distances in a straight line (not geodesic)
57 | are smaller than or equal to `radius`.
58 | * grid_indices: Indices of shape [num_edges], that index into a
59 | [num_lat_points, num_lon_points] grid, after flattening the leading axes.
60 | * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
61 | """
62 |
63 | # [num_grid_points=num_lat_points * num_lon_points, 3]
64 | grid_positions = _grid_lat_lon_to_coordinates(
65 | grid_latitude, grid_longitude).reshape([-1, 3])
66 |
67 | # [num_mesh_points, 3]
68 | mesh_positions = mesh.vertices
69 | kd_tree = scipy.spatial.cKDTree(mesh_positions)
70 |
71 | # [num_grid_points, num_mesh_points_per_grid_point]
72 | # Note `num_mesh_points_per_grid_point` is not constant, so this is a list
73 | # of arrays, rather than a 2d array.
74 | query_indices = kd_tree.query_ball_point(x=grid_positions, r=radius)
75 |
76 | grid_edge_indices = []
77 | mesh_edge_indices = []
78 | for grid_index, mesh_neighbors in enumerate(query_indices):
79 | grid_edge_indices.append(np.repeat(grid_index, len(mesh_neighbors)))
80 | mesh_edge_indices.append(mesh_neighbors)
81 |
82 | # [num_edges]
83 | grid_edge_indices = np.concatenate(grid_edge_indices, axis=0).astype(int)
84 | mesh_edge_indices = np.concatenate(mesh_edge_indices, axis=0).astype(int)
85 |
86 | return grid_edge_indices, mesh_edge_indices
87 |
88 |
89 | def in_mesh_triangle_indices(
90 | *,
91 | grid_latitude: np.ndarray,
92 | grid_longitude: np.ndarray,
93 | mesh: icosahedral_mesh.TriangularMesh) -> tuple[np.ndarray, np.ndarray]:
94 | """Returns mesh-grid edge indices for grid points contained in mesh triangles.
95 |
96 | Args:
97 | grid_latitude: Latitude values for the grid [num_lat_points]
98 | grid_longitude: Longitude values for the grid [num_lon_points]
99 | mesh: Mesh object.
100 |
101 | Returns:
102 | tuple with `grid_indices` and `mesh_indices` indicating edges between the
103 | grid and the mesh vertices of the triangle that contain each grid point.
104 | The number of edges is always num_lat_points * num_lon_points * 3
105 | * grid_indices: Indices of shape [num_edges], that index into a
106 | [num_lat_points, num_lon_points] grid, after flattening the leading axes.
107 | * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
108 | """
109 |
110 | # [num_grid_points=num_lat_points * num_lon_points, 3]
111 | grid_positions = _grid_lat_lon_to_coordinates(
112 | grid_latitude, grid_longitude).reshape([-1, 3])
113 |
114 | mesh_trimesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)
115 |
116 | # [num_grid_points] with mesh face indices for each grid point.
117 | _, _, query_face_indices = trimesh.proximity.closest_point(
118 | mesh_trimesh, grid_positions)
119 |
120 | # [num_grid_points, 3] with mesh node indices for each grid point.
121 | mesh_edge_indices = mesh.faces[query_face_indices]
122 |
123 | # [num_grid_points, 3] with grid node indices, where every row simply contains
124 | # the row (grid_point) index.
125 | grid_indices = np.arange(grid_positions.shape[0])
126 | grid_edge_indices = np.tile(grid_indices.reshape([-1, 1]), [1, 3])
127 |
128 | # Flatten to get a regular list.
129 | # [num_edges=num_grid_points*3]
130 | mesh_edge_indices = mesh_edge_indices.reshape([-1])
131 | grid_edge_indices = grid_edge_indices.reshape([-1])
132 |
133 | return grid_edge_indices, mesh_edge_indices
134 |
--------------------------------------------------------------------------------
/GraphCast_src/graphcast/grid_mesh_connectivity_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for graphcast.grid_mesh_connectivity."""
15 |
16 | from absl.testing import absltest
17 | from graphcast import grid_mesh_connectivity
18 | from graphcast import icosahedral_mesh
19 | import numpy as np
20 |
21 |
22 | class GridMeshConnectivityTest(absltest.TestCase):
23 |
24 | def test_grid_lat_lon_to_coordinates(self):
25 |
26 | # Intervals of 30 degrees.
27 | grid_latitude = np.array([-45., 0., 45])
28 | grid_longitude = np.array([0., 90., 180., 270.])
29 |
30 | inv_sqrt2 = 1 / np.sqrt(2)
31 | expected_coordinates = np.array([
32 | [[inv_sqrt2, 0., -inv_sqrt2],
33 | [0., inv_sqrt2, -inv_sqrt2],
34 | [-inv_sqrt2, 0., -inv_sqrt2],
35 | [0., -inv_sqrt2, -inv_sqrt2]],
36 | [[1., 0., 0.],
37 | [0., 1., 0.],
38 | [-1., 0., 0.],
39 | [0., -1., 0.]],
40 | [[inv_sqrt2, 0., inv_sqrt2],
41 | [0., inv_sqrt2, inv_sqrt2],
42 | [-inv_sqrt2, 0., inv_sqrt2],
43 | [0., -inv_sqrt2, inv_sqrt2]],
44 | ])
45 |
46 | coordinates = grid_mesh_connectivity._grid_lat_lon_to_coordinates(
47 | grid_latitude, grid_longitude)
48 | np.testing.assert_allclose(expected_coordinates, coordinates, atol=1e-15)
49 |
50 | def test_radius_query_indices_smoke(self):
51 | # TODO(alvarosg): Add non-smoke test?
52 | grid_latitude = np.linspace(-75, 75, 6)
53 | grid_longitude = np.arange(12) * 30.
54 | mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
55 | splits=3)[-1]
56 | grid_mesh_connectivity.radius_query_indices(
57 | grid_latitude=grid_latitude,
58 | grid_longitude=grid_longitude,
59 | mesh=mesh, radius=0.2)
60 |
61 | def test_in_mesh_triangle_indices_smoke(self):
62 | # TODO(alvarosg): Add non-smoke test?
63 | grid_latitude = np.linspace(-75, 75, 6)
64 | grid_longitude = np.arange(12) * 30.
65 | mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
66 | splits=3)[-1]
67 | grid_mesh_connectivity.in_mesh_triangle_indices(
68 | grid_latitude=grid_latitude,
69 | grid_longitude=grid_longitude,
70 | mesh=mesh)
71 |
72 |
73 | if __name__ == "__main__":
74 | absltest.main()
75 |
--------------------------------------------------------------------------------
/GraphCast_src/graphcast/icosahedral_mesh.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utils for creating icosahedral meshes."""
15 |
16 | import itertools
17 | from typing import List, NamedTuple, Sequence, Tuple
18 |
19 | import numpy as np
20 | from scipy.spatial import transform
21 |
22 |
23 | class TriangularMesh(NamedTuple):
24 | """Data structure for triangular meshes.
25 |
26 | Attributes:
27 | vertices: spatial positions of the vertices of the mesh of shape
28 | [num_vertices, num_dims].
29 | faces: triangular faces of the mesh of shape [num_faces, 3]. Contains
30 | integer indices into `vertices`.
31 |
32 | """
33 | vertices: np.ndarray
34 | faces: np.ndarray
35 |
36 |
37 | def merge_meshes(
38 | mesh_list: Sequence[TriangularMesh]) -> TriangularMesh:
39 | """Merges all meshes into one. Assumes the last mesh is the finest.
40 |
41 | Args:
42 | mesh_list: Sequence of meshes, from coarse to fine refinement levels. The
43 | vertices and faces may contain those from preceding, coarser levels.
44 |
45 | Returns:
46 | `TriangularMesh` for which the vertices correspond to the highest
47 | resolution mesh in the hierarchy, and the faces are the join set of the
48 | faces at all levels of the hierarchy.
49 | """
50 | for mesh_i, mesh_ip1 in itertools.pairwise(mesh_list):
51 | num_nodes_mesh_i = mesh_i.vertices.shape[0]
52 | assert np.allclose(mesh_i.vertices, mesh_ip1.vertices[:num_nodes_mesh_i])
53 |
54 | return TriangularMesh(
55 | vertices=mesh_list[-1].vertices,
56 | faces=np.concatenate([mesh.faces for mesh in mesh_list], axis=0))
57 |
58 |
59 | def get_hierarchy_of_triangular_meshes_for_sphere(
60 | splits: int) -> List[TriangularMesh]:
61 | """Returns a sequence of meshes, each with triangularization sphere.
62 |
63 | Starting with a regular icosahedron (12 vertices, 20 faces, 30 edges) with
64 | circumscribed unit sphere. Then, each triangular face is iteratively
65 | subdivided into 4 triangular faces `splits` times. The new vertices are then
66 | projected back onto the unit sphere. All resulting meshes are returned in a
67 | list, from lowest to highest resolution.
68 |
69 | The vertices in each face are specified in counter-clockwise order as
70 | observed from the outside of the icosahedron.
71 |
72 | Args:
73 | splits: How many times to split each triangle.
74 | Returns:
75 | Sequence of `TriangularMesh`s of length `splits + 1` each with:
76 |
77 | vertices: [num_vertices, 3] vertex positions in 3D, all with unit norm.
78 | faces: [num_faces, 3] with triangular faces joining sets of 3 vertices.
79 | Each row contains three indices into the vertices array, indicating
80 | the vertices adjacent to the face. Always with positive orientation
81 | (counterclock-wise when looking from the outside).
82 | """
83 | current_mesh = get_icosahedron()
84 | output_meshes = [current_mesh]
85 | for _ in range(splits):
86 | current_mesh = _two_split_unit_sphere_triangle_faces(current_mesh)
87 | output_meshes.append(current_mesh)
88 | return output_meshes
89 |
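# Resulting mesh sizes: splits=0 -> 12 vertices / 20 faces, splits=1 -> 42 / 80,
# splits=2 -> 162 / 320, splits=3 -> 642 / 1280. Each split multiplies faces
# by 4 and adds num_faces * 3 / 2 vertices (each edge midpoint is shared by
# two faces); see _get_mesh_spec in icosahedral_mesh_test.py.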
90 |
91 | def get_icosahedron() -> TriangularMesh:
92 | """Returns a regular icosahedral mesh with circumscribed unit sphere.
93 |
94 | See https://en.wikipedia.org/wiki/Regular_icosahedron#Cartesian_coordinates
95 | for details on the construction of the regular icosahedron.
96 |
97 | The vertices in each face are specified in counter-clockwise order as observed
98 | from the outside of the icosahedron.
99 |
100 | Returns:
101 | TriangularMesh with:
102 |
103 | vertices: [num_vertices=12, 3] vertex positions in 3D, all with unit norm.
104 | faces: [num_faces=20, 3] with triangular faces joining sets of 3 vertices.
105 | Each row contains three indices into the vertices array, indicating
106 | the vertices adjacent to the face. Always with positive orientation (
107 | counterclock-wise when looking from the outside).
108 |
109 | """
110 | phi = (1 + np.sqrt(5)) / 2
111 | vertices = []
112 | for c1 in [1., -1.]:
113 | for c2 in [phi, -phi]:
114 | vertices.append((c1, c2, 0.))
115 | vertices.append((0., c1, c2))
116 | vertices.append((c2, 0., c1))
117 |
118 | vertices = np.array(vertices, dtype=np.float32)
119 | vertices /= np.linalg.norm([1., phi])
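# (Each vertex is a permutation of (0, +-1, +-phi), whose norm is
# sqrt(1 + phi**2) = np.linalg.norm([1., phi]), so the division above
# places all vertices exactly on the unit sphere.)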
120 |
121 | # I did this manually, checking the orientation one by one.
122 | faces = [(0, 1, 2),
123 | (0, 6, 1),
124 | (8, 0, 2),
125 | (8, 4, 0),
126 | (3, 8, 2),
127 | (3, 2, 7),
128 | (7, 2, 1),
129 | (0, 4, 6),
130 | (4, 11, 6),
131 | (6, 11, 5),
132 | (1, 5, 7),
133 | (4, 10, 11),
134 | (4, 8, 10),
135 | (10, 8, 3),
136 | (10, 3, 9),
137 | (11, 10, 9),
138 | (11, 9, 5),
139 | (5, 9, 7),
140 | (9, 3, 7),
141 | (1, 6, 5),
142 | ]
143 |
144 | # By default the top is an edge parallel to the Y axis.
145 | # Need to rotate around the y axis by half the supplementary to the
146 | # angle between faces divided by two to get the desired orientation.
147 | # /O\ (top edge)
148 | # / \ Z
149 | # (adjacent face)/ \ (adjacent face) ^
150 | # / angle_between_faces \ |
151 | # / \ |
152 | # / \ YO-----> X
153 | # This results in:
154 | # (adjacent face is now top plane)
155 | # ----------------------O\ (top edge)
156 | # \
157 | # \
158 | # \ (adjacent face)
159 | # \
160 | # \
161 | # \
162 |
163 | angle_between_faces = 2 * np.arcsin(phi / np.sqrt(3))
164 | rotation_angle = (np.pi - angle_between_faces) / 2
165 | rotation = transform.Rotation.from_euler(seq="y", angles=rotation_angle)
166 | rotation_matrix = rotation.as_matrix()
167 | vertices = np.dot(vertices, rotation_matrix)
168 |
169 | return TriangularMesh(vertices=vertices.astype(np.float32),
170 | faces=np.array(faces, dtype=np.int32))
171 |
172 |
173 | def _two_split_unit_sphere_triangle_faces(
174 | triangular_mesh: TriangularMesh) -> TriangularMesh:
175 | """Splits each triangular face into 4 triangles keeping the orientation."""
176 |
177 | # Every time we split a triangle into 4 we will be adding 3 extra vertices,
178 | # located at the edge centres.
179 | # This class handles the positioning of the new vertices, and avoids creating
180 | # duplicates.
181 | new_vertices_builder = _ChildVerticesBuilder(triangular_mesh.vertices)
182 |
183 | new_faces = []
184 | for ind1, ind2, ind3 in triangular_mesh.faces:
185 | # Transform each triangular face into 4 triangles,
186 | # preserving the orientation.
187 | # ind3
188 | # / \
189 | # / \
190 | # / #3 \
191 | # / \
192 | # ind31 -------------- ind23
193 | # / \ / \
194 | # / \ #4 / \
195 | # / #1 \ / #2 \
196 | # / \ / \
197 | # ind1 ------------ ind12 ------------ ind2
198 | ind12 = new_vertices_builder.get_new_child_vertex_index((ind1, ind2))
199 | ind23 = new_vertices_builder.get_new_child_vertex_index((ind2, ind3))
200 | ind31 = new_vertices_builder.get_new_child_vertex_index((ind3, ind1))
201 | # Note how each of the 4 triangular new faces specifies the order of the
202 | # vertices to preserve the orientation of the original face. As the input
203 | # face should always be counter-clockwise as specified in the diagram,
204 | # this means child faces should also be counter-clockwise.
205 | new_faces.extend([[ind1, ind12, ind31], # 1
206 | [ind12, ind2, ind23], # 2
207 | [ind31, ind23, ind3], # 3
208 | [ind12, ind23, ind31], # 4
209 | ])
210 | return TriangularMesh(vertices=new_vertices_builder.get_all_vertices(),
211 | faces=np.array(new_faces, dtype=np.int32))
212 |
213 |
214 | class _ChildVerticesBuilder(object):
215 | """Bookkeeping of new child vertices added to an existing set of vertices."""
216 |
217 | def __init__(self, parent_vertices):
218 |
219 | # Because the same new vertex will be required when splitting adjacent
220 | # triangles (which share an edge) we keep them in a hash table indexed by
221 | # sorted indices of the vertices adjacent to the edge, to avoid creating
222 | # duplicated child vertices.
223 | self._child_vertices_index_mapping = {}
224 | self._parent_vertices = parent_vertices
225 | # We start with all previous vertices.
226 | self._all_vertices_list = list(parent_vertices)
227 |
228 | def _get_child_vertex_key(self, parent_vertex_indices):
229 | return tuple(sorted(parent_vertex_indices))
230 |
231 | def _create_child_vertex(self, parent_vertex_indices):
232 | """Creates a new vertex."""
233 | # Position for new vertex is the middle point, between the parent points,
234 | # projected to unit sphere.
235 | child_vertex_position = self._parent_vertices[
236 | list(parent_vertex_indices)].mean(0)
237 | child_vertex_position /= np.linalg.norm(child_vertex_position)
238 |
239 | # Add the vertex to the output list. The index for this new vertex will
240 | # match the length of the list before adding it.
241 | child_vertex_key = self._get_child_vertex_key(parent_vertex_indices)
242 | self._child_vertices_index_mapping[child_vertex_key] = len(
243 | self._all_vertices_list)
244 | self._all_vertices_list.append(child_vertex_position)
245 |
246 | def get_new_child_vertex_index(self, parent_vertex_indices):
247 | """Returns index for a child vertex, creating it if necessary."""
248 | # Get the key to see if we already have a new vertex in the middle.
249 | child_vertex_key = self._get_child_vertex_key(parent_vertex_indices)
250 | if child_vertex_key not in self._child_vertices_index_mapping:
251 | self._create_child_vertex(parent_vertex_indices)
252 | return self._child_vertices_index_mapping[child_vertex_key]
253 |
254 | def get_all_vertices(self):
255 | """Returns an array with old vertices."""
256 | return np.array(self._all_vertices_list)
257 |
258 |
259 | def faces_to_edges(faces: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
260 | """Transforms polygonal faces to sender and receiver indices.
261 |
262 | It does so by transforming every face into N_i edges. For example, if a
263 | triangular face has indices [0, 1, 2], three edges are added: 0->1, 1->2, and 2->0.
264 |
265 | If all faces have consistent orientation, and the surface represented by the
266 | faces is closed, then every edge in a polygon with a certain orientation
267 | is also part of another polygon with the opposite orientation. In this
268 | situation, the edges returned by the method are always bidirectional.
269 |
270 | Args:
271 | faces: Integer array of shape [num_faces, 3]. Contains node indices
272 | adjacent to each face.
273 | Returns:
274 | Tuple with sender/receiver indices, each of shape [num_edges=num_faces*3].
275 |
276 | """
277 | assert faces.ndim == 2
278 | assert faces.shape[-1] == 3
279 | senders = np.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]])
280 | receivers = np.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]])
281 | return senders, receivers
282 |
--------------------------------------------------------------------------------
/GraphCast_src/graphcast/icosahedral_mesh_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for icosahedral_mesh."""
15 |
16 | from absl.testing import absltest
17 | from absl.testing import parameterized
18 | import chex
19 | from graphcast import icosahedral_mesh
20 | import numpy as np
21 |
22 |
23 | def _get_mesh_spec(splits: int):
24 | """Returns size of the final icosahedral mesh resulting from the splitting."""
25 | num_vertices = 12
26 | num_faces = 20
27 | for _ in range(splits):
28 | # Each previous face adds three new vertices, but each vertex is shared
29 | # by two faces.
30 | num_vertices += num_faces * 3 // 2
31 | num_faces *= 4
32 | return num_vertices, num_faces
33 |
34 |
35 | class IcosahedralMeshTest(parameterized.TestCase):
36 |
37 | def test_icosahedron(self):
38 | mesh = icosahedral_mesh.get_icosahedron()
39 | _assert_valid_mesh(
40 | mesh, num_expected_vertices=12, num_expected_faces=20)
41 |
42 | @parameterized.parameters(list(range(5)))
43 | def test_get_hierarchy_of_triangular_meshes_for_sphere(self, splits):
44 | meshes = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
45 | splits=splits)
46 | prev_vertices = None
47 | for mesh_i, mesh in enumerate(meshes):
48 | # Check that `mesh` is valid.
49 | num_expected_vertices, num_expected_faces = _get_mesh_spec(mesh_i)
50 | _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces)
51 |
52 | # Check that the first N vertices from this mesh match all of the
53 | # vertices from the previous mesh.
54 | if prev_vertices is not None:
55 | leading_mesh_vertices = mesh.vertices[:prev_vertices.shape[0]]
56 | np.testing.assert_array_equal(leading_mesh_vertices, prev_vertices)
57 |
58 | # Increase the expected/previous values for the next iteration.
59 | if mesh_i < len(meshes) - 1:
60 | prev_vertices = mesh.vertices
61 |
62 | @parameterized.parameters(list(range(4)))
63 | def test_merge_meshes(self, splits):
64 | mesh_hierarchy = (
65 | icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
66 | splits=splits))
67 | mesh = icosahedral_mesh.merge_meshes(mesh_hierarchy)
68 |
69 | expected_faces = np.concatenate([m.faces for m in mesh_hierarchy], axis=0)
70 | np.testing.assert_array_equal(mesh.vertices, mesh_hierarchy[-1].vertices)
71 | np.testing.assert_array_equal(mesh.faces, expected_faces)
72 |
73 | def test_faces_to_edges(self):
74 |
75 | faces = np.array([[0, 1, 2],
76 | [3, 4, 5]])
77 |
78 | # This also documents the order of the edges returned by the method.
79 | expected_edges = np.array(
80 | [[0, 1],
81 | [3, 4],
82 | [1, 2],
83 | [4, 5],
84 | [2, 0],
85 | [5, 3]])
86 | expected_senders = expected_edges[:, 0]
87 | expected_receivers = expected_edges[:, 1]
88 |
89 | senders, receivers = icosahedral_mesh.faces_to_edges(faces)
90 |
91 | np.testing.assert_array_equal(senders, expected_senders)
92 | np.testing.assert_array_equal(receivers, expected_receivers)
93 |
94 |
95 | def _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces):
96 | vertices = mesh.vertices
97 | faces = mesh.faces
98 | chex.assert_shape(vertices, [num_expected_vertices, 3])
99 | chex.assert_shape(faces, [num_expected_faces, 3])
100 |
101 | # Vertices norm should be 1.
102 | vertices_norm = np.linalg.norm(vertices, axis=-1)
103 | np.testing.assert_allclose(vertices_norm, 1., rtol=1e-6)
104 |
105 | _assert_positive_face_orientation(vertices, faces)
106 |
107 |
108 | def _assert_positive_face_orientation(vertices, faces):
109 |
110 | # Obtain a unit vector that points in the direction of the face.
111 | face_orientation = np.cross(vertices[faces[:, 1]] - vertices[faces[:, 0]],
112 | vertices[faces[:, 2]] - vertices[faces[:, 1]])
113 | face_orientation /= np.linalg.norm(face_orientation, axis=-1, keepdims=True)
114 |
115 | # And a unit vector pointing from the origin to the center of the face.
116 | face_centers = vertices[faces].mean(1)
117 | face_centers /= np.linalg.norm(face_centers, axis=-1, keepdims=True)
118 |
119 | # Positive orientation means those two vectors should be parallel
120 | # (dot product 1), and not anti-parallel (dot product -1).
121 | dot_center_orientation = np.einsum("ik,ik->i", face_orientation, face_centers)
122 |
123 | # Check that the face normal is parallel to the vector that joins the center
124 | # of the face to the center of the sphere. Note we need a small tolerance
125 | # because some discretizations are not exactly uniform, so it will not be
126 | # exactly parallel.
127 | np.testing.assert_allclose(dot_center_orientation, 1., atol=6e-4)
128 |
129 |
130 | if __name__ == "__main__":
131 | absltest.main()
132 |
--------------------------------------------------------------------------------
/GraphCast_src/graphcast/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Loss functions (and terms for use in loss functions) used for weather."""
15 |
16 | from typing import Mapping, Callable
17 |
18 | from graphcast import xarray_tree
19 | import numpy as np
20 | from typing_extensions import Protocol
21 | import xarray
22 |
23 |
24 | LossAndDiagnostics = tuple[xarray.DataArray, xarray.Dataset]
25 |
26 |
27 | class LossFunction(Protocol):
28 | """A loss function.
29 |
30 | This is a protocol so it's fine to use a plain function which 'quacks like'
31 | this. This is just to document the interface.
32 | """
33 |
34 | def __call__(self,
35 | predictions: xarray.Dataset,
36 | targets: xarray.Dataset,
37 | **optional_kwargs) -> LossAndDiagnostics:
38 | """Computes a loss function.
39 |
40 | Args:
41 | predictions: Dataset of predictions.
42 | targets: Dataset of targets.
43 | **optional_kwargs: Implementations may support extra optional kwargs.
44 |
45 | Returns:
46 | loss: A DataArray with dimensions ('batch',) containing losses for each
47 | element of the batch. These will be averaged to give the final
48 | loss, locally and across replicas.
49 | diagnostics: Mapping of additional quantities to log by name alongside the
50 | loss. These will typically correspond to terms in the loss. They
51 | should also have dimensions ('batch',) and will be averaged over the
52 | batch before logging.
53 | """
54 |
55 |
56 | def weighted_mse_per_level(predictions: xarray.Dataset,
57 | targets: xarray.Dataset,
58 | per_variable_weights: Mapping[str, float]
59 | ) -> LossAndDiagnostics:
60 | """Latitude- and pressure-level-weighted MSE loss."""
61 | def mse_loss(prediction, target):
62 | return (prediction - target)**2
63 |
64 | return weighted_error_per_level(predictions,
65 | targets,
66 | per_variable_weights,
67 | mse_loss,
68 | normalized_level_weights)
69 |
70 | def weighted_error_per_level(predictions: xarray.Dataset,
71 | targets: xarray.Dataset,
72 | per_variable_weights: Mapping[str, float],
73 | loss_fn,
74 | per_level_weights_fn = None
75 | ) -> LossAndDiagnostics:
76 | """
77 | Compute the loss function weighted per variable and per level.
78 | Moreover, weights the latitudes to account for unequal distribution of grid points on the sphere.
79 | """
80 | def weighted_loss(prediction, target):
81 | loss = loss_fn(prediction, target)
82 | loss *= normalized_latitude_weights(target).astype(loss.dtype)
83 | if 'level' in target.dims and per_level_weights_fn is not None:
84 | loss *= per_level_weights_fn(target).astype(loss.dtype)
85 | return _mean_preserving_batch(loss)
86 |
87 | losses = xarray_tree.map_structure(weighted_loss, predictions, targets)
88 | return sum_per_variable_losses(losses, per_variable_weights)
89 |
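# Usage sketch (hypothetical datasets and weights, not from this repo):
#   total, diagnostics = weighted_mse_per_level(
#       predictions, targets, per_variable_weights={"temperature": 1.0})
# `total` has dimensions ('batch',); `diagnostics` maps variable names to
# their individual (latitude- and level-weighted) losses.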
90 |
91 | def _mean_preserving_batch(x: xarray.DataArray) -> xarray.DataArray:
92 | return x.mean([d for d in x.dims if d != 'batch'], skipna=False)
93 |
94 |
95 | def sum_per_variable_losses(
96 | per_variable_losses: Mapping[str, xarray.DataArray],
97 | weights: Mapping[str, float],
98 | ) -> LossAndDiagnostics:
99 | """Weighted sum of per-variable losses."""
100 | if not set(weights.keys()).issubset(set(per_variable_losses.keys())):
101 | raise ValueError(
102 | 'Passing a weight that does not correspond to any variable '
103 | f'{set(weights.keys())-set(per_variable_losses.keys())}')
104 |
105 | weighted_per_variable_losses = {
106 | name: loss * weights.get(name, 1)
107 | for name, loss in per_variable_losses.items()
108 | }
109 | total = xarray.concat(
110 | weighted_per_variable_losses.values(), dim='variable', join='exact').sum(
111 | 'variable', skipna=False)
112 | return total, per_variable_losses
113 |
114 |
115 | def normalized_level_weights(data: xarray.DataArray) -> xarray.DataArray:
116 | """Weights proportional to pressure at each level."""
117 | level = data.coords['level']
118 | return level / level.mean(skipna=False)
119 |
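# e.g. levels [50, 500, 1000] hPa have mean ~516.7, giving weights
# ~[0.097, 0.968, 1.935]: higher-pressure (near-surface) levels get
# proportionally more weight in the loss.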
120 |
121 | def single_level_weights(data: xarray.DataArray, level) -> xarray.DataArray:
122 | """Weights that select a single pressure level."""
123 | l = data.coords['level']
124 | weights = xarray.zeros_like(l)
125 | weights.loc[{'level': level}] = 1  # Unit weight at the selected level.
126 | return weights
127 |
128 |
129 | def normalized_latitude_weights(data: xarray.DataArray) -> xarray.DataArray:
130 | """Weights based on latitude, roughly proportional to grid cell area.
131 |
132 | This method supports two use cases only (both for equispaced values):
133 | * Latitude values such that the closest value to the pole is at latitude
134 | (90 - d_lat/2), where d_lat is the difference between contiguous latitudes.
135 | For example: [-89, -87, -85, ..., 85, 87, 89]) (d_lat = 2)
136 | In this case each point with `lat` value represents a sphere slice between
137 | `lat - d_lat/2` and `lat + d_lat/2`, and the area of this slice would be
138 | proportional to:
139 | `sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat)`, and
140 | we can simply omit the term `2 * sin(d_lat/2)` which is just a constant
141 | that cancels during normalization.
142 | * Latitude values that fall exactly at the poles.
143 | For example: [-90, -88, -86, ..., 86, 88, 90]) (d_lat = 2)
144 | In this case each point with `lat` value also represents
145 | a sphere slice between `lat - d_lat/2` and `lat + d_lat/2`,
146 | except for the points at the poles, that represent a slice between
147 | `90 - d_lat/2` and `90` or, `-90` and `-90 + d_lat/2`.
148 | The areas of the first type of point are still proportional to:
149 | * sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat)
150 | but for the points at the poles now is:
151 | * sin(90) - sin(90 - d_lat/2) = 2 * sin(d_lat/4) ^ 2
152 | and we will be using these weights, depending on whether we are looking at
153 | pole cells, or non-pole cells (omitting the common factor of 2 which will be
154 | absorbed by the normalization).
155 |
156 |   It can be shown, via a limit or simple geometry, that in the small-angle
157 |   regime the proportion of area per pole point equals 1/8th of the
158 |   proportion of area covered by each of the nearest non-pole points; this
159 |   is verified in the unit tests.
160 |
161 | Args:
162 | data: `DataArray` with latitude coordinates.
163 | Returns:
164 | Unit mean latitude weights.
165 | """
166 | latitude = data.coords['lat']
167 |
168 | if np.any(np.isclose(np.abs(latitude), 90.)):
169 | weights = _weight_for_latitude_vector_with_poles(latitude)
170 | else:
171 | weights = _weight_for_latitude_vector_without_poles(latitude)
172 |
173 | return weights / weights.mean(skipna=False)
174 |
175 |
176 | def _weight_for_latitude_vector_without_poles(latitude):
177 | """Weights for uniform latitudes of the form [+-90-+d/2, ..., -+90+-d/2]."""
178 | delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude))
179 | if (not np.isclose(np.max(latitude), 90 - delta_latitude/2) or
180 | not np.isclose(np.min(latitude), -90 + delta_latitude/2)):
181 | raise ValueError(
182 | f'Latitude vector {latitude} does not start/end at '
183 | '+- (90 - delta_latitude/2) degrees.')
184 | return np.cos(np.deg2rad(latitude))
185 |
186 |
187 | def _weight_for_latitude_vector_with_poles(latitude):
188 | """Weights for uniform latitudes of the form [+- 90, ..., -+90]."""
189 | delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude))
190 | if (not np.isclose(np.max(latitude), 90.) or
191 | not np.isclose(np.min(latitude), -90.)):
192 | raise ValueError(
193 | f'Latitude vector {latitude} does not start/end at +- 90 degrees.')
194 | weights = np.cos(np.deg2rad(latitude)) * np.sin(np.deg2rad(delta_latitude/2))
195 |   # The two checks above are enough to guarantee that latitudes are sorted,
196 |   # so the extremes are the poles.
197 | weights[[0, -1]] = np.sin(np.deg2rad(delta_latitude/4)) ** 2
198 | return weights
199 |
200 |
201 | def _check_uniform_spacing_and_get_delta(vector):
202 | diff = np.diff(vector)
203 | if not np.all(np.isclose(diff[0], diff)):
204 |     raise ValueError(f'Vector {vector} is not uniformly spaced.')
205 | return diff[0]
206 |
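A minimal sketch (not part of the source) of the 1/8 pole-weight claim above, mirroring `_weight_for_latitude_vector_with_poles`; the grid spacing `d_lat` is illustrative:

import numpy as np

d_lat = 1.0
lat = np.arange(-90.0, 90.0 + d_lat, d_lat)  # equispaced grid including both poles

# Non-pole slices: proportional to sin(lat + d/2) - sin(lat - d/2) = 2*sin(d/2)*cos(lat).
weights = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(d_lat / 2))
# Pole slices: proportional to sin(90) - sin(90 - d/2) = 2*sin(d/4)^2.
weights[[0, -1]] = np.sin(np.deg2rad(d_lat / 4)) ** 2

print(weights[0] / weights[1])  # ~0.125: a pole cell covers ~1/8th of its neighbour's area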
--------------------------------------------------------------------------------
/GraphCast_src/graphcast/predictor_base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Abstract base classes for an xarray-based Predictor API."""
15 |
16 | import abc
17 |
18 | from typing import Tuple
19 |
20 | from graphcast import losses
21 | from graphcast import xarray_jax
22 | import jax.numpy as jnp
23 | import xarray
24 |
25 | LossAndDiagnostics = losses.LossAndDiagnostics
26 |
27 |
28 | class Predictor(abc.ABC):
29 | """A possibly-trainable predictor of weather, exposing an xarray-based API.
30 |
31 | Typically wraps an underlying JAX model and handles translating the xarray
32 | Dataset values to and from plain JAX arrays that are convenient for input to
33 | (and output from) the underlying model.
34 |
35 | Different subclasses may exist to wrap different kinds of underlying model,
36 | e.g. models taking stacked inputs/outputs, models taking separate 2D and 3D
37 | inputs/outputs, autoregressive models.
38 |
39 | You can also implement a specific model directly as a Predictor if you want,
40 | for example if it has quite specific/unique requirements for its input/output
41 | or loss function, or if it's convenient to implement directly using xarray.
42 | """
43 |
44 | @abc.abstractmethod
45 | def __call__(self,
46 | inputs: xarray.Dataset,
47 | targets_template: xarray.Dataset,
48 | forcings: xarray.Dataset,
49 | **optional_kwargs
50 | ) -> xarray.Dataset:
51 | """Makes predictions.
52 |
53 | This is only used by the Experiment for inference / evaluation, with
54 | training going via the .loss method. So it should default to making
55 | predictions for evaluation, although you can also support making predictions
56 | for use in the loss via an is_training argument -- see
57 | LossFunctionPredictor which helps with that.
58 |
59 | Args:
60 | inputs: An xarray.Dataset of inputs.
61 | targets_template: An xarray.Dataset or other mapping of xarray.DataArrays,
62 | with the same shape as the targets, to demonstrate what kind of
63 | predictions are required. You can use this to determine which variables,
64 | levels and lead times must be predicted.
65 | You are free to raise an error if you don't support predicting what is
66 | requested.
67 |       forcings: An xarray.Dataset of forcing terms. Forcings are variables
68 |         that can be fed to the model but do not need to be predicted, often
69 |         because they can be computed analytically (e.g. the top-of-atmosphere
70 |         solar radiation is mostly a function of geometry) or are considered
71 |         controlled for the experiment (e.g. an imposed scenario of CO2
72 |         emissions into the atmosphere). Unlike `inputs`, the `forcings` can
73 | include information "from the future", that is, information at target
74 | times specified in the `targets_template`.
75 | **optional_kwargs: Implementations may support extra optional kwargs,
76 | provided they set appropriate defaults for them.
77 |
78 | Returns:
79 | Predictions, as an xarray.Dataset or other mapping of DataArrays which
80 | is capable of being evaluated against targets with shape given by
81 | targets_template.
82 | For probabilistic predictors which can return multiple samples from a
83 | predictive distribution, these should (by convention) be returned along
84 | an additional 'sample' dimension.
85 | """
86 |
87 | def loss(self,
88 | inputs: xarray.Dataset,
89 | targets: xarray.Dataset,
90 | forcings: xarray.Dataset,
91 | **optional_kwargs,
92 | ) -> LossAndDiagnostics:
93 | """Computes a training loss, for predictors that are trainable.
94 |
95 | Why make this the Predictor's responsibility, rather than letting callers
96 | compute their own loss function using predictions obtained from
97 | Predictor.__call__?
98 |
99 | Doing it this way gives Predictors more control over their training setup.
100 | For example, some predictors may wish to train using different targets to
101 | the ones they predict at evaluation time -- perhaps different lead times and
102 | variables, perhaps training to predict transformed versions of targets
103 | where the transform needs to be inverted at evaluation time, etc.
104 |
105 | It's also necessary for generative models (VAEs, GANs, ...) where the
106 | training loss is more complex and isn't expressible as a parameter-free
107 | function of predictions and targets.
108 |
109 | Args:
110 | inputs: An xarray.Dataset.
111 | targets: An xarray.Dataset or other mapping of xarray.DataArrays. See
112 | docs on __call__ for an explanation about the targets.
113 | forcings: xarray.Dataset of forcing terms.
114 | **optional_kwargs: Implementations may support extra optional kwargs,
115 | provided they set appropriate defaults for them.
116 |
117 | Returns:
118 | loss: A DataArray with dimensions ('batch',) containing losses for each
119 | element of the batch. These will be averaged to give the final
120 | loss, locally and across replicas.
121 | diagnostics: Mapping of additional quantities to log by name alongside the
122 |         loss. These will typically correspond to terms in the loss. They
123 | should also have dimensions ('batch',) and will be averaged over the
124 | batch before logging.
125 | You need not include the loss itself in this dict; it will be added for
126 | you.
127 | """
128 | del targets, forcings, optional_kwargs
129 | batch_size = inputs.sizes['batch']
130 | dummy_loss = xarray_jax.DataArray(jnp.zeros(batch_size), dims=('batch',))
131 | return dummy_loss, {}
132 |
133 | def loss_and_predictions(
134 | self,
135 | inputs: xarray.Dataset,
136 | targets: xarray.Dataset,
137 | forcings: xarray.Dataset,
138 | **optional_kwargs,
139 | ) -> Tuple[LossAndDiagnostics, xarray.Dataset]:
140 | """Like .loss but also returns corresponding predictions.
141 |
142 | Implementing this is optional as it's not used directly by the Experiment,
143 | but it is required by autoregressive.Predictor when applying an inner
144 | Predictor autoregressively at training time; we need a loss at each step but
145 | also predictions to feed back in for the next step.
146 |
147 | Note the loss itself may not be directly regressing the predictions towards
148 | targets, the loss may be computed in terms of transformed predictions and
149 | targets (or in some other way). For this reason we can't always cleanly
150 | separate this into step 1: get predictions, step 2: compute loss from them,
151 | hence the need for this combined method.
152 |
153 | Args:
154 | inputs:
155 | targets:
156 | forcings:
157 | **optional_kwargs:
158 | As for self.loss.
159 |
160 | Returns:
161 | (loss, diagnostics)
162 | As for self.loss
163 | predictions:
164 | The predictions which the loss relates to. These should be of the same
165 | shape as what you would get from
166 | `self.__call__(inputs, targets_template=targets)`, and should be in the
167 | same 'domain' as the inputs (i.e. they shouldn't be transformed
168 | differently to how the predictor expects its inputs).
169 | """
170 | raise NotImplementedError
171 |
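To make the API above concrete, here is a hypothetical (not from the source) `Predictor` subclass implementing a persistence baseline; it assumes `inputs` contains every variable named in `targets_template`:

import xarray

class PersistencePredictor(Predictor):
  """Repeats the most recent input state at every requested target time."""

  def __call__(self,
               inputs: xarray.Dataset,
               targets_template: xarray.Dataset,
               forcings: xarray.Dataset,
               **optional_kwargs) -> xarray.Dataset:
    del forcings, optional_kwargs  # unused by this baseline
    # Last input time, with the time dimension dropped.
    last = inputs.isel(time=-1, drop=True)
    # Broadcast over the requested target times and restore their coordinates.
    preds = last.expand_dims(time=targets_template.sizes['time'])
    preds = preds.assign_coords(time=targets_template.coords['time'])
    return preds[list(targets_template.data_vars)]

Since it defines no custom loss, it inherits the trivial zero loss above and is effectively untrainable, which is fine for a baseline.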
--------------------------------------------------------------------------------
/GraphCast_src/graphcast/rollout.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utils for rolling out models."""
15 |
16 | from typing import Iterator
17 |
18 | from absl import logging
19 | import chex
20 | import dask
21 | from graphcast import xarray_tree
22 | import jax
23 | import numpy as np
24 | import typing_extensions
25 | import xarray
26 |
27 |
28 | class PredictorFn(typing_extensions.Protocol):
29 | """Functional version of base.Predictor.__call__ with explicit rng."""
30 |
31 | def __call__(
32 | self, rng: chex.PRNGKey, inputs: xarray.Dataset,
33 | targets_template: xarray.Dataset,
34 | forcings: xarray.Dataset,
35 | **optional_kwargs,
36 | ) -> xarray.Dataset:
37 | ...
38 |
39 |
40 | def chunked_prediction(
41 | predictor_fn: PredictorFn,
42 | rng: chex.PRNGKey,
43 | inputs: xarray.Dataset,
44 | targets_template: xarray.Dataset,
45 | forcings: xarray.Dataset,
46 | num_steps_per_chunk: int = 1,
47 | verbose: bool = False,
48 | ) -> xarray.Dataset:
49 | """Outputs a long trajectory by iteratively concatenating chunked predictions.
50 |
51 | Args:
52 | predictor_fn: Function to use to make predictions for each chunk.
53 | rng: Random key.
54 | inputs: Inputs for the model.
55 | targets_template: Template for the target prediction, requires targets
56 | equispaced in time.
57 | forcings: Optional forcing for the model.
58 | num_steps_per_chunk: How many of the steps in `targets_template` to predict
59 | at each call of `predictor_fn`. It must evenly divide the number of
60 | steps in `targets_template`.
61 | verbose: Whether to log the current chunk being predicted.
62 |
63 | Returns:
64 | Predictions for the targets template.
65 |
66 | """
67 | chunks_list = []
68 | for prediction_chunk in chunked_prediction_generator(
69 | predictor_fn=predictor_fn,
70 | rng=rng,
71 | inputs=inputs,
72 | targets_template=targets_template,
73 | forcings=forcings,
74 | num_steps_per_chunk=num_steps_per_chunk,
75 | verbose=verbose):
76 | chunks_list.append(jax.device_get(prediction_chunk))
77 | return xarray.concat(chunks_list, dim="time")
78 |
79 |
80 | def chunked_prediction_generator(
81 | predictor_fn: PredictorFn,
82 | rng: chex.PRNGKey,
83 | inputs: xarray.Dataset,
84 | targets_template: xarray.Dataset,
85 | forcings: xarray.Dataset,
86 | num_steps_per_chunk: int = 1,
87 | verbose: bool = False,
88 | ) -> Iterator[xarray.Dataset]:
89 | """Outputs a long trajectory by yielding chunked predictions.
90 |
91 | Args:
92 | predictor_fn: Function to use to make predictions for each chunk.
93 | rng: Random key.
94 | inputs: Inputs for the model.
95 | targets_template: Template for the target prediction, requires targets
96 | equispaced in time.
97 | forcings: Optional forcing for the model.
98 | num_steps_per_chunk: How many of the steps in `targets_template` to predict
99 | at each call of `predictor_fn`. It must evenly divide the number of
100 | steps in `targets_template`.
101 | verbose: Whether to log the current chunk being predicted.
102 |
103 | Yields:
104 |     The predictions for each chunk of the rollout, such that if all
105 |     predictions were concatenated in time they would match the targets
106 |     template in structure.
107 |
108 | """
109 |
110 | # Create copies to avoid mutating inputs.
111 | inputs = xarray.Dataset(inputs)
112 | targets_template = xarray.Dataset(targets_template)
113 | forcings = xarray.Dataset(forcings)
114 |
115 | if "datetime" in inputs.coords:
116 | del inputs.coords["datetime"]
117 |
118 | if "datetime" in targets_template.coords:
119 | output_datetime = targets_template.coords["datetime"]
120 | del targets_template.coords["datetime"]
121 | else:
122 | output_datetime = None
123 |
124 | if "datetime" in forcings.coords:
125 | del forcings.coords["datetime"]
126 |
127 | num_target_steps = targets_template.dims["time"]
128 | num_chunks, remainder = divmod(num_target_steps, num_steps_per_chunk)
129 | if remainder != 0:
130 | raise ValueError(
131 | f"The number of steps per chunk {num_steps_per_chunk} must "
132 | f"evenly divide the number of target steps {num_target_steps} ")
133 |
134 | if len(np.unique(np.diff(targets_template.coords["time"].data))) > 1:
135 | raise ValueError("The targets time coordinates must be evenly spaced")
136 |
137 |   # Our template targets will always have a time axis corresponding to the
138 |   # timedeltas of the first chunk.
139 | targets_chunk_time = targets_template.time.isel(
140 | time=slice(0, num_steps_per_chunk))
141 |
142 | current_inputs = inputs
143 | for chunk_index in range(num_chunks):
144 | if verbose:
145 | logging.info("Chunk %d/%d", chunk_index, num_chunks)
146 | logging.flush()
147 |
148 | # Select targets for the time period that we are predicting for this chunk.
149 | target_offset = num_steps_per_chunk * chunk_index
150 | target_slice = slice(target_offset, target_offset + num_steps_per_chunk)
151 | current_targets_template = targets_template.isel(time=target_slice)
152 |
153 |     # Replace the timedelta by the one corresponding to the first chunk, so we
154 |     # don't recompile at every iteration, keeping the time coordinates static.
155 | actual_target_time = current_targets_template.coords["time"]
156 | current_targets_template = current_targets_template.assign_coords(
157 | time=targets_chunk_time).compute()
158 |
159 | current_forcings = forcings.isel(time=target_slice)
160 | current_forcings = current_forcings.assign_coords(time=targets_chunk_time)
161 | current_forcings = current_forcings.compute()
162 | # Make predictions for the chunk.
163 | rng, this_rng = jax.random.split(rng)
164 | predictions = predictor_fn(
165 | rng=this_rng,
166 | inputs=current_inputs,
167 | targets_template=current_targets_template,
168 | forcings=current_forcings)
169 |
170 | next_frame = xarray.merge([predictions, current_forcings])
171 |
172 | current_inputs = _get_next_inputs(current_inputs, next_frame)
173 |
174 | # At this point we can assign the actual targets time coordinates.
175 | predictions = predictions.assign_coords(time=actual_target_time)
176 | if output_datetime is not None:
177 | predictions.coords["datetime"] = output_datetime.isel(
178 | time=target_slice)
179 | yield predictions
180 | del predictions
181 |
182 |
183 | def _get_next_inputs(
184 | prev_inputs: xarray.Dataset, next_frame: xarray.Dataset,
185 | ) -> xarray.Dataset:
186 | """Computes next inputs, from previous inputs and predictions."""
187 |
188 |   # Make sure we are predicting all inputs that have a time axis.
189 | non_predicted_or_forced_inputs = list(
190 | set(prev_inputs.keys()) - set(next_frame.keys()))
191 | if "time" in prev_inputs[non_predicted_or_forced_inputs].dims:
192 | raise ValueError(
193 | "Found an input with a time index that is not predicted or forced.")
194 |
195 | # Keys we need to copy from predictions to inputs.
196 | next_inputs_keys = list(
197 | set(next_frame.keys()).intersection(set(prev_inputs.keys())))
198 | next_inputs = next_frame[next_inputs_keys]
199 |
200 |   # Concatenate the next frame with the inputs, crop what we don't need and
201 |   # shift timedelta coordinates, so we don't recompile at every iteration.
202 | num_inputs = prev_inputs.dims["time"]
203 | return (
204 | xarray.concat(
205 | [prev_inputs, next_inputs], dim="time", data_vars="different")
206 | .tail(time=num_inputs)
207 | .assign_coords(time=prev_inputs.coords["time"]))
208 |
209 |
210 | def extend_targets_template(
211 | targets_template: xarray.Dataset,
212 | required_num_steps: int) -> xarray.Dataset:
213 | """Extends `targets_template` to `required_num_steps` with lazy arrays.
214 |
215 | It uses lazy dask arrays of zeros, so it does not require instantiating the
216 | array in memory.
217 |
218 | Args:
219 | targets_template: Input template to extend.
220 | required_num_steps: Number of steps required in the returned template.
221 |
222 | Returns:
223 | `xarray.Dataset` identical in variables and timestep to `targets_template`
224 | full of `dask.array.zeros` such that the time axis has `required_num_steps`.
225 |
226 | """
227 |
228 | # Extend the "time" and "datetime" coordinates
229 | time = targets_template.coords["time"]
230 |
231 | # Assert the first target time corresponds to the timestep.
232 | timestep = time[0].data
233 | if time.shape[0] > 1:
234 | assert np.all(timestep == time[1:] - time[:-1])
235 |
236 | extended_time = (np.arange(required_num_steps) + 1) * timestep
237 |
238 | if "datetime" in targets_template.coords:
239 | datetime = targets_template.coords["datetime"]
240 | extended_datetime = (datetime[0].data - timestep) + extended_time
241 | else:
242 | extended_datetime = None
243 |
244 | # Replace the values with empty dask arrays extending the time coordinates.
245 |
246 |
247 | def extend_time(data_array: xarray.DataArray) -> xarray.DataArray:
248 | dims = data_array.dims
249 | shape = list(data_array.shape)
250 | shape[dims.index("time")] = required_num_steps
251 | dask_data = dask.array.zeros(
252 | shape=tuple(shape),
253 |         chunks=-1,  # Will give chunk info directly to `ChunksToZarr`.
254 | dtype=data_array.dtype)
255 |
256 | coords = dict(data_array.coords)
257 | coords["time"] = extended_time
258 |
259 | if extended_datetime is not None:
260 | coords["datetime"] = ("time", extended_datetime)
261 |
262 | return xarray.DataArray(
263 | dims=dims,
264 | data=dask_data,
265 | coords=coords)
266 |
267 | return xarray_tree.map_structure(extend_time, targets_template)
268 |
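A minimal sketch of driving `chunked_prediction` above with a stand-in model; the variable name, shapes, and the 6h step are illustrative:

import numpy as np
import xarray
import jax

def dummy_predictor_fn(rng, inputs, targets_template, forcings, **kwargs):
  # Stands in for a real model: returns zeros shaped like the template.
  return xarray.zeros_like(targets_template)

step = np.timedelta64(6, 'h')
target_time = (np.arange(4) + 1) * step
inputs = xarray.Dataset(
    {'t2m': (('batch', 'time'), np.zeros((1, 2)))},
    coords={'time': np.array([-step, 0 * step])})
template = xarray.Dataset(
    {'t2m': (('batch', 'time'), np.full((1, 4), np.nan))},
    coords={'time': target_time})
forcings = xarray.Dataset(coords={'time': target_time})

# Two chunks of two steps each; each chunk is fed back in as the next input.
predictions = chunked_prediction(
    dummy_predictor_fn, rng=jax.random.PRNGKey(0), inputs=inputs,
    targets_template=template, forcings=forcings, num_steps_per_chunk=2)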
--------------------------------------------------------------------------------
/GraphCast_src/graphcast/typed_graph.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Data-structure for storing graphs with typed edges and nodes."""
15 |
16 | from typing import NamedTuple, Any, Union, Tuple, Mapping, TypeVar
17 |
18 | ArrayLike = Union[Any] # np.ndarray, jnp.ndarray, tf.tensor
19 | ArrayLikeTree = Union[Any, ArrayLike] # Nest of ArrayLike
20 |
21 | _T = TypeVar('_T')
22 |
23 |
24 | # All tensors have a "flat_batch_axis", which is similar to the leading
25 | # axes of graph_tuples:
26 | # * In the case of nodes this is simply a shared node and flat batch axis, with
27 | # size corresponding to the total number of nodes in the flattened batch.
28 | # * In the case of edges this is simply a shared edge and flat batch axis, with
29 | # size corresponding to the total number of edges in the flattened batch.
30 | # * In the case of globals this is simply the number of graphs in the flattened
31 | # batch.
32 |
33 | # All shapes may also have any additional leading shape "batch_shape".
34 | # Options for building batches are:
35 | # * Use a provided "flatten" method that takes a leading `batch_shape` and
36 | #   folds it into the flat_batch_axis (this will be useful when using `tf.Dataset`
37 | # which supports batching into RaggedTensors, with leading batch shape even
38 | # if graphs have different numbers of nodes and edges), so the RaggedBatches
39 | # can then be converted into something without ragged dimensions that jax can
40 | # use.
41 | # * Directly build a "flat batch" using a provided function for batching a list
42 | # of graphs (how it is done in `jraph`).
43 |
44 |
45 | class NodeSet(NamedTuple):
46 | """Represents a set of nodes."""
47 | n_node: ArrayLike # [num_flat_graphs]
48 | features: ArrayLikeTree # Prev. `nodes`: [num_flat_nodes] + feature_shape
49 |
50 |
51 | class EdgesIndices(NamedTuple):
52 | """Represents indices to nodes adjacent to the edges."""
53 | senders: ArrayLike # [num_flat_edges]
54 | receivers: ArrayLike # [num_flat_edges]
55 |
56 |
57 | class EdgeSet(NamedTuple):
58 | """Represents a set of edges."""
59 | n_edge: ArrayLike # [num_flat_graphs]
60 | indices: EdgesIndices
61 | features: ArrayLikeTree # Prev. `edges`: [num_flat_edges] + feature_shape
62 |
63 |
64 | class Context(NamedTuple):
65 | # `n_graph` always contains ones but it is useful to query the leading shape
66 | # in case of graphs without any nodes or edges sets.
67 | n_graph: ArrayLike # [num_flat_graphs]
68 | features: ArrayLikeTree # Prev. `globals`: [num_flat_graphs] + feature_shape
69 |
70 |
71 | class EdgeSetKey(NamedTuple):
72 | name: str # Name of the EdgeSet.
73 |
74 | # Sender node set name and receiver node set name connected by the edge set.
75 | node_sets: Tuple[str, str]
76 |
77 |
78 | class TypedGraph(NamedTuple):
79 | """A graph with typed nodes and edges.
80 |
81 | A typed graph is made of a context, multiple sets of nodes and multiple
82 | sets of edges connecting those nodes (as indicated by the EdgeSetKey).
83 | """
84 |
85 | context: Context
86 | nodes: Mapping[str, NodeSet]
87 | edges: Mapping[EdgeSetKey, EdgeSet]
88 |
89 | def edge_key_by_name(self, name: str) -> EdgeSetKey:
90 | found_key = [k for k in self.edges.keys() if k.name == name]
91 | if len(found_key) != 1:
92 | raise KeyError("invalid edge key '{}'. Available edges: [{}]".format(
93 | name, ', '.join(x.name for x in self.edges.keys())))
94 | return found_key[0]
95 |
96 | def edge_by_name(self, name: str) -> EdgeSet:
97 | return self.edges[self.edge_key_by_name(name)]
98 |
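A sketch (not from the source) constructing a tiny `TypedGraph` by hand: two "grid" nodes, one "mesh" node, and a single grid-to-mesh edge set; the set names and feature sizes are arbitrary:

import numpy as np

tiny_graph = TypedGraph(
    context=Context(n_graph=np.array([1]), features=()),
    nodes={
        'grid': NodeSet(n_node=np.array([2]), features=np.zeros((2, 4))),
        'mesh': NodeSet(n_node=np.array([1]), features=np.zeros((1, 4))),
    },
    edges={
        EdgeSetKey('grid2mesh', ('grid', 'mesh')): EdgeSet(
            n_edge=np.array([2]),
            indices=EdgesIndices(senders=np.array([0, 1]),
                                 receivers=np.array([0, 0])),
            features=np.zeros((2, 3))),
    })

grid2mesh = tiny_graph.edge_by_name('grid2mesh')  # lookup via EdgeSetKey.name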
--------------------------------------------------------------------------------
/GraphCast_src/graphcast/typed_graph_net.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """A library of typed Graph Neural Networks."""
15 |
16 | from typing import Callable, Mapping, Optional, Union
17 |
18 | from graphcast import typed_graph
19 | import jax.numpy as jnp
20 | import jax.tree_util as tree
21 | import jraph
22 |
23 |
24 | # All features will be an ArrayTree.
25 | NodeFeatures = EdgeFeatures = SenderFeatures = ReceiverFeatures = Globals = (
26 | jraph.ArrayTree)
27 |
28 | # Signature:
29 | # (node features, outgoing edge features, incoming edge features,
30 | # globals) -> updated node features
31 | GNUpdateNodeFn = Callable[
32 | [NodeFeatures, Mapping[str, SenderFeatures], Mapping[str, ReceiverFeatures],
33 | Globals],
34 | NodeFeatures]
35 |
36 | GNUpdateGlobalFn = Callable[
37 | [Mapping[str, NodeFeatures], Mapping[str, EdgeFeatures], Globals],
38 | Globals]
39 |
40 |
41 | def GraphNetwork( # pylint: disable=invalid-name
42 | update_edge_fn: Mapping[str, jraph.GNUpdateEdgeFn],
43 | update_node_fn: Mapping[str, GNUpdateNodeFn],
44 | update_global_fn: Optional[GNUpdateGlobalFn] = None,
45 | aggregate_edges_for_nodes_fn: jraph.AggregateEdgesToNodesFn = jraph
46 | .segment_sum,
47 | aggregate_nodes_for_globals_fn: jraph.AggregateNodesToGlobalsFn = jraph
48 | .segment_sum,
49 | aggregate_edges_for_globals_fn: jraph.AggregateEdgesToGlobalsFn = jraph
50 | .segment_sum,
51 | ):
52 | """Returns a method that applies a configured GraphNetwork.
53 |
54 | This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261
55 | extended to Typed Graphs with multiple edge sets and node sets and extended to
56 | allow aggregating not only edges received by the nodes, but also edges sent by
57 | the nodes.
58 |
59 | Example usage::
60 |
61 | gn = GraphNetwork(update_edge_function,
62 | update_node_function, **kwargs)
63 | # Conduct multiple rounds of message passing with the same parameters:
64 | for _ in range(num_message_passing_steps):
65 | graph = gn(graph)
66 |
67 | Args:
68 | update_edge_fn: mapping of functions used to update a subset of the edge
69 | types, indexed by edge type name.
70 | update_node_fn: mapping of functions used to update a subset of the node
71 | types, indexed by node type name.
72 | update_global_fn: function used to update the globals or None to deactivate
73 | globals updates.
74 | aggregate_edges_for_nodes_fn: function used to aggregate messages to each
75 | node.
76 | aggregate_nodes_for_globals_fn: function used to aggregate the nodes for the
77 | globals.
78 | aggregate_edges_for_globals_fn: function used to aggregate the edges for the
79 | globals.
80 |
81 | Returns:
82 | A method that applies the configured GraphNetwork.
83 | """
84 |
85 | def _apply_graph_net(graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
86 | """Applies a configured GraphNetwork to a graph.
87 |
88 | This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261
89 | extended to Typed Graphs with multiple edge sets and node sets and extended
90 | to allow aggregating not only edges received by the nodes, but also edges
91 | sent by the nodes.
92 |
93 | Args:
94 | graph: a `TypedGraph` containing the graph.
95 |
96 | Returns:
97 | Updated `TypedGraph`.
98 | """
99 |
100 | updated_graph = graph
101 |
102 | # Edge update.
103 | updated_edges = dict(updated_graph.edges)
104 | for edge_set_name, edge_fn in update_edge_fn.items():
105 | edge_set_key = graph.edge_key_by_name(edge_set_name)
106 | updated_edges[edge_set_key] = _edge_update(
107 | updated_graph, edge_fn, edge_set_key)
108 | updated_graph = updated_graph._replace(edges=updated_edges)
109 |
110 | # Node update.
111 | updated_nodes = dict(updated_graph.nodes)
112 | for node_set_key, node_fn in update_node_fn.items():
113 | updated_nodes[node_set_key] = _node_update(
114 | updated_graph, node_fn, node_set_key, aggregate_edges_for_nodes_fn)
115 | updated_graph = updated_graph._replace(nodes=updated_nodes)
116 |
117 | # Global update.
118 | if update_global_fn:
119 | updated_context = _global_update(
120 | updated_graph, update_global_fn,
121 | aggregate_edges_for_globals_fn,
122 | aggregate_nodes_for_globals_fn)
123 | updated_graph = updated_graph._replace(context=updated_context)
124 |
125 | return updated_graph
126 |
127 | return _apply_graph_net
128 |
129 |
130 | def _edge_update(graph, edge_fn, edge_set_key): # pylint: disable=invalid-name
131 | """Updates an edge set of a given key."""
132 |
133 | sender_nodes = graph.nodes[edge_set_key.node_sets[0]]
134 | receiver_nodes = graph.nodes[edge_set_key.node_sets[1]]
135 | edge_set = graph.edges[edge_set_key]
136 | senders = edge_set.indices.senders # pytype: disable=attribute-error
137 | receivers = edge_set.indices.receivers # pytype: disable=attribute-error
138 |
139 | sent_attributes = tree.tree_map(
140 | lambda n: n[senders], sender_nodes.features)
141 | received_attributes = tree.tree_map(
142 | lambda n: n[receivers], receiver_nodes.features)
143 |
144 | n_edge = edge_set.n_edge
145 | sum_n_edge = senders.shape[0]
146 | global_features = tree.tree_map(
147 | lambda g: jnp.repeat(g, n_edge, axis=0, total_repeat_length=sum_n_edge),
148 | graph.context.features)
149 | new_features = edge_fn(
150 | edge_set.features, sent_attributes, received_attributes,
151 | global_features)
152 | return edge_set._replace(features=new_features)
153 |
154 |
155 | def _node_update(graph, node_fn, node_set_key, aggregation_fn): # pylint: disable=invalid-name
156 | """Updates an edge set of a given key."""
157 | node_set = graph.nodes[node_set_key]
158 | sum_n_node = tree.tree_leaves(node_set.features)[0].shape[0]
159 |
160 | sent_features = {}
161 | for edge_set_key, edge_set in graph.edges.items():
162 | sender_node_set_key = edge_set_key.node_sets[0]
163 | if sender_node_set_key == node_set_key:
164 | assert isinstance(edge_set.indices, typed_graph.EdgesIndices)
165 | senders = edge_set.indices.senders
166 | sent_features[edge_set_key.name] = tree.tree_map(
167 | lambda e: aggregation_fn(e, senders, sum_n_node), edge_set.features) # pylint: disable=cell-var-from-loop
168 |
169 | received_features = {}
170 | for edge_set_key, edge_set in graph.edges.items():
171 | receiver_node_set_key = edge_set_key.node_sets[1]
172 | if receiver_node_set_key == node_set_key:
173 | assert isinstance(edge_set.indices, typed_graph.EdgesIndices)
174 | receivers = edge_set.indices.receivers
175 | received_features[edge_set_key.name] = tree.tree_map(
176 | lambda e: aggregation_fn(e, receivers, sum_n_node), edge_set.features) # pylint: disable=cell-var-from-loop
177 |
178 | n_node = node_set.n_node
179 | global_features = tree.tree_map(
180 | lambda g: jnp.repeat(g, n_node, axis=0, total_repeat_length=sum_n_node),
181 | graph.context.features)
182 | new_features = node_fn(
183 | node_set.features, sent_features, received_features, global_features)
184 | return node_set._replace(features=new_features)
185 |
186 |
187 | def _global_update(graph, global_fn, edge_aggregation_fn, node_aggregation_fn): # pylint: disable=invalid-name
188 | """Updates an edge set of a given key."""
189 | n_graph = graph.context.n_graph.shape[0]
190 | graph_idx = jnp.arange(n_graph)
191 |
192 | edge_features = {}
193 | for edge_set_key, edge_set in graph.edges.items():
194 | assert isinstance(edge_set.indices, typed_graph.EdgesIndices)
195 | sum_n_edge = edge_set.indices.senders.shape[0]
196 | edge_gr_idx = jnp.repeat(
197 | graph_idx, edge_set.n_edge, axis=0, total_repeat_length=sum_n_edge)
198 | edge_features[edge_set_key.name] = tree.tree_map(
199 | lambda e: edge_aggregation_fn(e, edge_gr_idx, n_graph), # pylint: disable=cell-var-from-loop
200 | edge_set.features)
201 |
202 | node_features = {}
203 | for node_set_key, node_set in graph.nodes.items():
204 | sum_n_node = tree.tree_leaves(node_set.features)[0].shape[0]
205 | node_gr_idx = jnp.repeat(
206 | graph_idx, node_set.n_node, axis=0, total_repeat_length=sum_n_node)
207 | node_features[node_set_key] = tree.tree_map(
208 | lambda n: node_aggregation_fn(n, node_gr_idx, n_graph), # pylint: disable=cell-var-from-loop
209 | node_set.features)
210 |
211 | new_features = global_fn(node_features, edge_features, graph.context.features)
212 | return graph.context._replace(features=new_features)
213 |
214 |
215 | InteractionUpdateNodeFn = Callable[
216 | [jraph.NodeFeatures,
217 | Mapping[str, SenderFeatures],
218 | Mapping[str, ReceiverFeatures]],
219 | jraph.NodeFeatures]
220 |
221 |
222 | InteractionUpdateNodeFnNoSentEdges = Callable[
223 | [jraph.NodeFeatures,
224 | Mapping[str, ReceiverFeatures]],
225 | jraph.NodeFeatures]
226 |
227 |
228 | def InteractionNetwork( # pylint: disable=invalid-name
229 | update_edge_fn: Mapping[str, jraph.InteractionUpdateEdgeFn],
230 | update_node_fn: Mapping[str, Union[InteractionUpdateNodeFn,
231 | InteractionUpdateNodeFnNoSentEdges]],
232 | aggregate_edges_for_nodes_fn: jraph.AggregateEdgesToNodesFn = jraph
233 | .segment_sum,
234 | include_sent_messages_in_node_update: bool = False):
235 | """Returns a method that applies a configured InteractionNetwork.
236 |
237 | An interaction network computes interactions on the edges based on the
238 | previous edges features, and on the features of the nodes sending into those
239 | edges. It then updates the nodes based on the incoming updated edges.
240 | See https://arxiv.org/abs/1612.00222 for more details.
241 |
242 | This implementation extends the behavior to `TypedGraphs` adding an option
243 | to include edge features for which a node is a sender in the arguments to
244 | the node update function.
245 |
246 | Args:
247 | update_edge_fn: mapping of functions used to update a subset of the edge
248 | types, indexed by edge type name.
249 | update_node_fn: mapping of functions used to update a subset of the node
250 | types, indexed by node type name.
251 | aggregate_edges_for_nodes_fn: function used to aggregate messages to each
252 | node.
253 | include_sent_messages_in_node_update: pass edge features for which a node is
254 | a sender to the node update function.
255 | """
256 | # An InteractionNetwork is a GraphNetwork without globals features,
257 | # so we implement the InteractionNetwork as a configured GraphNetwork.
258 |
259 | # An InteractionNetwork edge function does not have global feature inputs,
260 | # so we filter the passed global argument in the GraphNetwork.
261 | wrapped_update_edge_fn = tree.tree_map(
262 | lambda fn: lambda e, s, r, g: fn(e, s, r), update_edge_fn)
263 |
264 | # Similarly, we wrap the update_node_fn to ensure only the expected
265 | # arguments are passed to the Interaction net.
266 | if include_sent_messages_in_node_update:
267 | wrapped_update_node_fn = tree.tree_map(
268 | lambda fn: lambda n, s, r, g: fn(n, s, r), update_node_fn)
269 | else:
270 | wrapped_update_node_fn = tree.tree_map(
271 | lambda fn: lambda n, s, r, g: fn(n, r), update_node_fn)
272 | return GraphNetwork(
273 | update_edge_fn=wrapped_update_edge_fn,
274 | update_node_fn=wrapped_update_node_fn,
275 | aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn)
276 |
277 |
278 | def GraphMapFeatures( # pylint: disable=invalid-name
279 | embed_edge_fn: Optional[Mapping[str, jraph.EmbedEdgeFn]] = None,
280 | embed_node_fn: Optional[Mapping[str, jraph.EmbedNodeFn]] = None,
281 | embed_global_fn: Optional[jraph.EmbedGlobalFn] = None):
282 | """Returns function which embeds the components of a graph independently.
283 |
284 | Args:
285 | embed_edge_fn: mapping of functions used to embed each edge type,
286 | indexed by edge type name.
287 | embed_node_fn: mapping of functions used to embed each node type,
288 | indexed by node type name.
289 | embed_global_fn: function used to embed the globals.
290 | """
291 |
292 | def _embed(graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
293 |
294 | updated_edges = dict(graph.edges)
295 | if embed_edge_fn:
296 | for edge_set_name, embed_fn in embed_edge_fn.items():
297 | edge_set_key = graph.edge_key_by_name(edge_set_name)
298 | edge_set = graph.edges[edge_set_key]
299 | updated_edges[edge_set_key] = edge_set._replace(
300 | features=embed_fn(edge_set.features))
301 |
302 | updated_nodes = dict(graph.nodes)
303 | if embed_node_fn:
304 | for node_set_key, embed_fn in embed_node_fn.items():
305 | node_set = graph.nodes[node_set_key]
306 | updated_nodes[node_set_key] = node_set._replace(
307 | features=embed_fn(node_set.features))
308 |
309 | updated_context = graph.context
310 | if embed_global_fn:
311 | updated_context = updated_context._replace(
312 | features=embed_global_fn(updated_context.features))
313 |
314 | return graph._replace(edges=updated_edges, nodes=updated_nodes,
315 | context=updated_context)
316 |
317 | return _embed
318 |
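A sketch wiring `GraphNetwork` above over a graph like the `tiny_graph` example after typed_graph.py, with parameter-free update functions (plain concatenations) purely to show the plumbing:

import jax.numpy as jnp

def update_edge(edges, sent, received, globals_):
  # Edge features concatenated with both endpoint node features.
  return jnp.concatenate([edges, sent, received], axis=-1)

def update_node(nodes, sent, received, globals_):
  # `received` maps edge-set name -> messages aggregated per receiver node.
  return jnp.concatenate([nodes] + list(received.values()), axis=-1)

gn = GraphNetwork(update_edge_fn={'grid2mesh': update_edge},
                  update_node_fn={'mesh': update_node})
updated_graph = gn(tiny_graph)  # edges are updated first, then the 'mesh' nodes

In a real model these update functions would be learned MLPs, as in deep_typed_graph_net.py.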
--------------------------------------------------------------------------------
/GraphCast_src/graphcast/xarray_tree.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utilities for working with trees of xarray.DataArray (including Datasets).
15 |
16 | Note that xarray.Dataset doesn't work out-of-the-box with the `tree` library;
17 | it won't work as a leaf node since it implements Mapping, but also won't work
18 | as an internal node since tree doesn't know how to re-create it properly.
19 |
20 | To fix this, we reimplement a subset of `map_structure`, exposing its
21 | constituent DataArrays as leaf nodes. This means it can be mapped over as a
22 | generic container of DataArrays, while still preserving the result as a Dataset
23 | where possible.
24 |
25 | This is useful because in a few places we need to handle a general
26 | Mapping[str, DataArray] (where the coordinates might not be compatible across
27 | the constituent DataArrays) but also the special case of a Dataset nicely.
28 |
29 | For the result e.g. of a tree.map_structure(fn, dataset), if fn returns None for
30 | some of the child DataArrays, they will be omitted from the returned dataset. If
31 | any values other than DataArrays or None are returned, then we don't attempt to
32 | return a Dataset and just return a plain dict of the results. Similarly if
33 | DataArrays are returned but with non-matching coordinates, it will just return a
34 | plain dict of DataArrays.
35 |
36 | Note xarray datatypes are registered with `jax.tree_util` by xarray_jax.py,
37 | but `jax.tree_util.tree_map` is distinct from `xarray_tree.map_structure`,
38 | as the former exposes the underlying JAX/numpy arrays as leaf nodes, while the
39 | latter exposes DataArrays as leaf nodes.
40 | """
41 |
42 | from typing import Any, Callable
43 |
44 | import xarray
45 |
46 |
47 | def map_structure(func: Callable[..., Any], *structures: Any) -> Any:
48 | """Maps func through given structures with xarrays. See tree.map_structure."""
49 | if not callable(func):
50 | raise TypeError(f'func must be callable, got: {func}')
51 | if not structures:
52 | raise ValueError('Must provide at least one structure')
53 |
54 | first = structures[0]
55 | if isinstance(first, xarray.Dataset):
56 | data = {k: func(*[s[k] for s in structures]) for k in first.keys()}
57 | if all(isinstance(a, (type(None), xarray.DataArray))
58 | for a in data.values()):
59 | data_arrays = [v.rename(k) for k, v in data.items() if v is not None]
60 | try:
61 | return xarray.merge(data_arrays, join='exact')
62 | except ValueError: # Exact join not possible.
63 | pass
64 | return data
65 | if isinstance(first, dict):
66 | return {k: map_structure(func, *[s[k] for s in structures])
67 | for k in first.keys()}
68 | if isinstance(first, (list, tuple, set)):
69 | return type(first)(map_structure(func, *s) for s in zip(*structures))
70 | return func(*structures)
71 |
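A small usage sketch of `map_structure`: per-variable standardisation of a Dataset, which stays a Dataset because the coordinates of the results still line up:

import numpy as np
import xarray

ds = xarray.Dataset(
    {'u': ('x', np.array([1.0, 2.0, 3.0])),
     'v': ('x', np.array([10.0, 20.0, 30.0]))},
    coords={'x': [0, 1, 2]})

standardised = map_structure(lambda da: (da - da.mean()) / da.std(), ds)
assert isinstance(standardised, xarray.Dataset)  # the exact merge succeeded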
--------------------------------------------------------------------------------
/GraphCast_src/graphcast/xarray_tree_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for xarray_tree."""
15 |
16 | from absl.testing import absltest
17 | from graphcast import xarray_tree
18 | import numpy as np
19 | import xarray
20 |
21 |
22 | TEST_DATASET = xarray.Dataset(
23 | data_vars={
24 | "foo": (("x", "y"), np.zeros((2, 3))),
25 | "bar": (("x",), np.zeros((2,))),
26 | },
27 | coords={
28 | "x": [1, 2],
29 | "y": [10, 20, 30],
30 | }
31 | )
32 |
33 |
34 | class XarrayTreeTest(absltest.TestCase):
35 |
36 | def test_map_structure_maps_over_leaves_but_preserves_dataset_type(self):
37 | def fn(leaf):
38 | self.assertIsInstance(leaf, xarray.DataArray)
39 | result = leaf + 1
40 | # Removing the name from the returned DataArray to test that we don't rely
41 | # on it being present to restore the correct names in the result:
42 | result = result.rename(None)
43 | return result
44 |
45 | result = xarray_tree.map_structure(fn, TEST_DATASET)
46 | self.assertIsInstance(result, xarray.Dataset)
47 | self.assertSameElements({"foo", "bar"}, result.keys())
48 |
49 | def test_map_structure_on_data_arrays(self):
50 | data_arrays = dict(TEST_DATASET)
51 | result = xarray_tree.map_structure(lambda x: x+1, data_arrays)
52 | self.assertIsInstance(result, dict)
53 | self.assertSameElements({"foo", "bar"}, result.keys())
54 |
55 | def test_map_structure_on_dataset_plain_dict_when_coords_incompatible(self):
56 | def fn(leaf):
57 | # Returns DataArrays that can't be exactly merged back into a Dataset
58 | # due to the coordinates not matching:
59 | if leaf.name == "foo":
60 | return xarray.DataArray(
61 | data=np.zeros(2), dims=("x",), coords={"x": [1, 2]})
62 | else:
63 | return xarray.DataArray(
64 | data=np.zeros(2), dims=("x",), coords={"x": [3, 4]})
65 |
66 | result = xarray_tree.map_structure(fn, TEST_DATASET)
67 | self.assertIsInstance(result, dict)
68 | self.assertSameElements({"foo", "bar"}, result.keys())
69 |
70 | def test_map_structure_on_dataset_drops_vars_with_none_return_values(self):
71 | def fn(leaf):
72 | return leaf if leaf.name == "foo" else None
73 |
74 | result = xarray_tree.map_structure(fn, TEST_DATASET)
75 | self.assertIsInstance(result, xarray.Dataset)
76 | self.assertSameElements({"foo"}, result.keys())
77 |
78 | def test_map_structure_on_dataset_returns_plain_dict_other_return_types(self):
79 | def fn(leaf):
80 | self.assertIsInstance(leaf, xarray.DataArray)
81 | return "not a DataArray"
82 |
83 | result = xarray_tree.map_structure(fn, TEST_DATASET)
84 | self.assertEqual({"foo": "not a DataArray",
85 | "bar": "not a DataArray"}, result)
86 |
87 | def test_map_structure_two_args_different_variable_orders(self):
88 | dataset_different_order = TEST_DATASET[["bar", "foo"]]
89 | def fn(arg1, arg2):
90 | self.assertEqual(arg1.name, arg2.name)
91 | xarray_tree.map_structure(fn, TEST_DATASET, dataset_different_order)
92 |
93 |
94 | if __name__ == "__main__":
95 | absltest.main()
96 |
--------------------------------------------------------------------------------
/GraphCast_src/graphcast_runner.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # # GraphCast
5 | #
6 | # This colab lets you run several versions of GraphCast.
7 | #
8 | # The model weights, normalization statistics, and example inputs are available on [Google Cloud Bucket](https://console.cloud.google.com/storage/browser/dm_graphcast).
9 | #
10 | # A Colab runtime with TPU/GPU acceleration will substantially speed up generating predictions and computing the loss/gradients. If you're using a CPU-only runtime, you can switch using the menu "Runtime > Change runtime type".
11 |
12 | # >
Copyright 2023 DeepMind Technologies Limited.
13 | # > Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.
14 | # > Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
15 |
16 | # Use most memory-conservative allocation scheme
17 | # See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
18 | # Either set this in your environment or set this before any import of jax code
19 | import os
20 | from pathlib import Path
21 |
22 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
23 | os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
24 |
25 | import graphcast_wrapper
26 | from weatherbench2_dataloader import WeatherBench2Dataset
27 |
28 | import argparse
29 | import dataclasses
30 | import functools
31 | import time
32 |
33 | from graphcast import losses
34 | from graphcast import data_utils
35 | from graphcast import graphcast
36 | from graphcast import rollout
37 | from graphcast import xarray_jax
38 | from graphcast import xarray_tree
39 | import haiku as hk
40 | import jax
41 | import numpy as np
42 |
43 | from buckets import authenticate_bucket
44 | from demo_data import load_example_batch, save_example_batch, load_normalization
45 |
46 |
47 | def train_model(loss_fn: hk.TransformedWithState, params, train_inputs, train_targets, train_forcings, epochs=5):
48 | """
49 |     @param loss_fn: a hk.transform-wrapped loss function (encapsulates the model as well)
50 | @param params: initial model parameters (from a checkpoint)
51 | @param train_inputs
52 | @param train_targets
53 | @param train_forcings
54 | @param epochs
55 | """
56 | assert epochs is not None
57 |
58 | def grads_fn(params, state, inputs, targets, forcings):
59 | def _aux(params, state, i, t, f):
60 | (loss, diagnostics), next_state = loss_fn.apply(params,
61 | state,
62 | jax.random.PRNGKey(0),
63 | i,
64 | t,
65 | f)
66 | return loss, (diagnostics, next_state)
67 |
68 | # TODO add reduce_axes=('batch',)
69 | (loss, (diagnostics, next_state)), grads = jax.value_and_grad(_aux, has_aux=True)(params,
70 | state,
71 | inputs,
72 | targets,
73 | forcings)
74 | return loss, diagnostics, next_state, grads
75 |
76 | grads_fn_jitted = jax.jit(grads_fn)
77 |
78 | runtimes = []
79 | for i in range(epochs):
80 | tic = time.perf_counter()
81 | # Gradient computation (backprop through time)
82 | loss, diagnostics, next_state, grads = grads_fn_jitted(
83 | params=params,
84 | state={},
85 | inputs=train_inputs,
86 | targets=train_targets,
87 | forcings=train_forcings)
88 | jax.block_until_ready(grads)
89 | jax.block_until_ready(loss)
90 | mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
91 | print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")
92 | toc = time.perf_counter()
93 | print(f"Step {i} took {toc-tic}s")
94 | if i > 0:
95 | runtimes.append(toc-tic)
96 | print("Training step time: ", np.mean(np.asarray(runtimes)), " +-", np.std(np.asarray(runtimes)))
97 |
98 |
99 | def evaluate_model(fwd_cost_fn, task_config: graphcast.TaskConfig, dataloader, autoregressive_steps=1):
100 | """
101 | Perform inference using the given forward cost function.
102 | Assumes each autoregressive step has a lead time of 6 hours.
103 | @param fwd_cost_fn Cost function that takes inputs, targets, and forcings to return a scalar cost.
104 | @param task_config Corresponding task configuration (must match model underlying the fwd_cost_fn)
105 |     @param dataloader must support __len__ and __getitem__. Each batch should be an xarray.Dataset.
106 |     @param autoregressive_steps how many times to unroll the model. Must match the batches provided by the dataloader.
107 | """
108 | costs = []
109 | for i in range(len(dataloader)):
110 | batch = dataloader[i]
111 |
112 | inputs, targets, forcings = data_utils.extract_inputs_targets_forcings(batch,
113 | target_lead_times=f"{autoregressive_steps * 6}h",
114 | # Note that this sets input duration to 12h
115 | **dataclasses.asdict(task_config))
116 |
117 | cost = fwd_cost_fn(inputs, targets, forcings)
118 | print(f"Batch {i}/{len(dataloader)} cost: {cost}")
119 | costs.append(cost)
120 |
121 | mean_cost = np.asarray(costs).mean()
122 | print("Mean cost:", mean_cost, " +-", np.asarray(costs).std())
123 | return mean_cost
124 |
125 | def rmse_forward(inputs, targets, forcings, forward_fn, level: int, variable_weights: dict[str, float]) -> float:
126 | """
127 | Compute the RMSE given a forward function
128 | @param inputs
129 | @param targets
130 | @param forcings
131 |     @param forward_fn Function that computes the prediction. Will be rolled out autoregressively.
132 | @param level Pressure level to evaluate the RMSE on
133 | @param variable_weights weight to assign to each target variable
134 | """
135 | predictions = rollout.chunked_prediction(
136 | forward_fn,
137 | rng=jax.random.PRNGKey(353),
138 | inputs=inputs,
139 | targets_template=targets * np.nan,
140 | forcings=forcings)
141 | def mse(x, y):
142 | return (x-y) ** 2
143 |
144 | mse_error, _ = losses.weighted_error_per_level(predictions,
145 | targets,
146 | variable_weights,
147 | mse,
148 | functools.partial(losses.single_level_weights, level=level))
149 | return np.sqrt(mse_error.mean().item())
150 |
151 | def main(resolution: float = 0.25,
152 | pressure_levels: int = 13,
153 | autoregressive_steps: int = 1,
154 | test_years=None,
155 | test_variable: str = 'geopotential',
156 | test_pressure_level: int = 500,
157 | repetitions: int = 5) -> None:
158 | """
159 | resolution: resolution of the model in degrees
160 | pressure_levels: number of pressure levels
161 | train: If true, computes gradients of the model (currently does not actually train anything)
162 | autoregressive_steps: How many rollout steps to perform. If 1, a single-step prediction is done.
163 | repetitions: For time measurement purposes, how many repetition are done.
164 | """
165 | #
166 | # - **Source**: era5, hres
167 | # - **Resolution**: 0.25deg, 1deg
168 | # - **Levels**: 13, 37
169 | #
170 | # Not all combinations are available.
171 | # - HRES is only available in 0.25 deg, with 13 pressure levels.
172 |
173 | if test_years is None:
174 | test_years = [2016]
175 | data_path = os.environ.get('DATA_PATH')
176 |
177 | run_forward, checkpoint = graphcast_wrapper.retrieve_model(resolution, pressure_levels, Path(data_path))
178 |
179 |     # Always pass params so the usage below is simpler
180 | def with_params(fn):
181 | return functools.partial(fn, params=checkpoint.params)
182 |
183 | # # Run the model (Inference)
184 | dataset = WeatherBench2Dataset(
185 | year=test_years[0],
186 | steps=autoregressive_steps,
187 | steps_per_input=3)
188 |
189 | # Compile the forward function and add the configuration and params as partials
190 | run_forward_jitted = with_params(jax.jit(run_forward.apply))
191 |
192 | # Pick the test_variable as the only one
193 | variable_weights = {var: 1 if var == test_variable else 0 for var in checkpoint.task_config.target_variables}
194 | # We create the loss function by passing the forward function and parameters as a partial to rmse_forward
195 | loss_function = functools.partial(rmse_forward,
196 | forward_fn = run_forward_jitted,
197 | level=test_pressure_level,
198 | variable_weights=variable_weights)
199 | evaluate_model(loss_function, checkpoint.task_config, dataset, autoregressive_steps=autoregressive_steps)
200 |
201 |
202 | if __name__ == "__main__":
203 | parser = argparse.ArgumentParser(description='Inference and training showcase for Graphcast.')
204 |
205 | # Add the arguments with default values
206 | parser.add_argument('--resolution', type=float, default=0.25, help='Resolution of the graph in the model.')
207 | parser.add_argument('--pressure_levels', type=int, default=13, help='Number of pressure levels in the model.')
208 | parser.add_argument('--autoregressive_steps', type=int, default=1, help='Number of time steps to predict into the future.')
209 | parser.add_argument('--test_year_start', type=int, default=2016, help='First year to use for testing (inference).')
210 | parser.add_argument('--test_year_end', type=int, default=2016, help='Last year to use for testing (inference).')
211 | parser.add_argument('--test_pressure_level', type=int, default=500, help='Pressure level to use for testing (inference).')
212 | parser.add_argument('--test_variable', type=str, default='geopotential', help='Variable to use for testing (inference).')
213 | parser.add_argument('--prediction_store_path', type=str, default=None, help='If not none, evaluate predictions and store them here.')
214 |
215 |
216 | # Parse the arguments
217 | args = parser.parse_args()
218 |
219 | # Access the arguments & call main
220 | main(args.resolution,
221 | args.pressure_levels,
222 | args.autoregressive_steps,
223 | list(range(args.test_year_start, args.test_year_end+1)),
224 | args.test_variable,
225 | args.test_pressure_level)
226 |
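The gradient computation in `train_model` above, reduced to a self-contained sketch with a stand-in quadratic loss (the real `loss_fn` is the hk.transform-ed model):

import jax
import jax.numpy as jnp

def toy_loss(params, x, y):
  pred = params['w'] * x + params['b']
  return jnp.mean((pred - y) ** 2)

params = {'w': jnp.array(0.0), 'b': jnp.array(0.0)}
x = jnp.arange(4.0)
y = 2.0 * x + 1.0

grads_fn_jitted = jax.jit(jax.value_and_grad(toy_loss))
loss, grads = grads_fn_jitted(params, x, y)

# Same diagnostic as in train_model: mean absolute gradient across the tree.
leaf_means = jax.tree_util.tree_map(lambda g: jnp.abs(g).mean(), grads)
mean_grad = jnp.mean(jnp.asarray(jax.tree_util.tree_leaves(leaf_means)))
print(f"Loss: {float(loss):.4f}, Mean |grad|: {float(mean_grad):.6f}")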
--------------------------------------------------------------------------------
/GraphCast_src/graphcast_wrapper.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from pathlib import Path
3 |
4 | import haiku
5 | import xarray
6 |
7 | import graphcast.casting as casting
8 | import graphcast.normalization as normalization
9 | import graphcast.autoregressive as autoregressive
10 | from graphcast import graphcast
11 | import haiku as hk
12 | from graphcast.checkpoint import load
13 |
14 | from buckets import authenticate_bucket
15 | from demo_data import load_normalization
16 | from pretrained_graphcast import load_model_checkpoint, find_model_name, save_checkpoint
17 |
18 |
19 | def wrap_graphcast(model_config: graphcast.ModelConfig,
20 | task_config: graphcast.TaskConfig,
21 | diffs_stddev_by_level = None,
22 | mean_by_level = None,
23 | stddev_by_level = None,
24 | wrap_autoregressive: bool = False,
25 | normalize: bool = False):
26 | """
27 | Constructs and wraps the GraphCast Predictor.
28 | Note that this MUST be called within a haiku transform function.
29 | """
30 | # Deeper one-step predictor.
31 | predictor = graphcast.GraphCast(model_config, task_config)
32 |
33 | # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to
34 | # from/to float32 to/from BFloat16.
35 | predictor = casting.Bfloat16Cast(predictor)
36 |
37 | if normalize:
38 | assert diffs_stddev_by_level is not None
39 | assert mean_by_level is not None
40 | assert stddev_by_level is not None
41 | # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
42 | # BFloat16 happens after applying normalization to the inputs/targets.
43 | predictor = normalization.InputsAndResiduals(
44 | predictor,
45 | diffs_stddev_by_level=diffs_stddev_by_level,
46 | mean_by_level=mean_by_level,
47 | stddev_by_level=stddev_by_level)
48 |
49 |     if wrap_autoregressive:
50 |         # Wraps everything so the one-step model can produce trajectories.
51 |         predictor = autoregressive.Predictor(predictor,
52 |                                              gradient_checkpointing=True)
53 |
54 |     return predictor
55 |
56 |
57 | def retrieve_model(resolution: float,
58 | pressure_levels: int,
59 | model_cache_path: Path = None,
60 | normalize: bool = True,
61 | autoregressive: bool = True,
62 | normalize_cache_path: Path = None) -> tuple[haiku.Transformed, graphcast.CheckPoint]:
63 | """
64 | Returns the haiku transformed forward function and model checkpoint for the given settings.
65 | The signature of the apply function is
66 | (params,
67 | rng,
68 | inputs: xarray.Dataset,
69 | targets_template: xarray.Dataset,
70 | forcings: xarray.Dataset
71 | )
72 |     Note that the rng argument is unused because the model is deterministic.
73 |
74 | Possible model values:
75 | - **Resolution**: 0.25deg, 1deg
76 | - **Pressure Levels**: 13, 37
77 |
78 | Not all combinations are available.
79 | - HRES is only available in 0.25 deg, with 13 pressure levels.
80 |
81 | """
82 | gcs_bucket = authenticate_bucket()
83 |
84 | # Choose the model
85 | params_file_options = [
86 | name for blob in gcs_bucket.list_blobs(prefix="params/")
87 | if (name := blob.name.removeprefix("params/"))] # Drop empty string.
88 |
89 | params_file_name = find_model_name(params_file_options, resolution, pressure_levels)
90 | if params_file_name is None:
91 | raise FileNotFoundError(
92 | f"No model with given resolution ({resolution} deg) and pressure levels ({pressure_levels}) found.")
93 |
94 |     # Load the model from the local cache if present, otherwise download it.
95 |     cached = (model_cache_path / 'params' / params_file_name) if model_cache_path else None
96 |     if cached is not None and cached.exists():
97 |         with open(cached, "rb") as f:
98 |             checkpoint = load(f, graphcast.CheckPoint)
99 |     else:
100 |         checkpoint = load_model_checkpoint(gcs_bucket, params_file_name)
101 |         save_checkpoint(gcs_bucket, model_cache_path, params_file_name)
102 | print(checkpoint.model_config)
103 |
104 | diffs_stddev_by_level, mean_by_level, stddev_by_level = None, None, None
105 | if normalize:
106 | # Load normalization data
107 | diffs_stddev_by_level, mean_by_level, stddev_by_level = load_normalization(gcs_bucket, normalize_cache_path)
108 |
109 | # Build haiku transformed function
110 | def run_forward(inputs: xarray.Dataset,
111 | targets_template: xarray.Dataset,
112 | forcings: xarray.Dataset):
113 | predictor = wrap_graphcast(checkpoint.model_config,
114 | checkpoint.task_config,
115 | diffs_stddev_by_level,
116 | mean_by_level,
117 | stddev_by_level,
118 | wrap_autoregressive=autoregressive,
119 | normalize=normalize)
120 | return predictor(inputs, targets_template=targets_template, forcings=forcings)
121 |
122 | forward = haiku.transform(run_forward)
123 |
124 | return forward, checkpoint
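125 |
126 |
127 | if __name__ == "__main__":
128 |     # Minimal usage sketch, not part of the original pipeline: it assumes the
129 |     # GCS key is configured (see the top-level README) and that ./cache exists.
130 |     forward, ckpt = retrieve_model(resolution=0.25,
131 |                                    pressure_levels=13,
132 |                                    model_cache_path=Path("./cache"),
133 |                                    normalize_cache_path=Path("./cache"))
134 |     print(ckpt.model_config)
135 |     # Predictions are then produced with the transformed apply function
136 |     # (rng may be None because the model is deterministic):
137 |     # predictions = forward.apply(ckpt.params, None, inputs, targets_template, forcings)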
--------------------------------------------------------------------------------
/GraphCast_src/plotting.py:
--------------------------------------------------------------------------------
1 |
2 | import datetime
3 | import math
4 | from typing import Optional
5 | from IPython.display import HTML
6 | import ipywidgets as widgets
7 | import matplotlib
8 | import matplotlib.pyplot as plt
9 | from matplotlib import animation
10 | import numpy as np
11 | import xarray
12 |
13 | def select(
14 |     data: xarray.Dataset,
15 |     variable: str,
16 |     level: Optional[int] = None,
17 |     max_steps: Optional[int] = None
18 |     ) -> xarray.DataArray:
19 | data = data[variable]
20 | if "batch" in data.dims:
21 | data = data.isel(batch=0)
22 | if max_steps is not None and "time" in data.sizes and max_steps < data.sizes["time"]:
23 | data = data.isel(time=range(0, max_steps))
24 | if level is not None and "level" in data.coords:
25 | data = data.sel(level=level)
26 | return data
27 |
28 | def scale(
29 |     data: xarray.DataArray,
30 |     center: Optional[float] = None,
31 |     robust: bool = False,
32 |     ) -> tuple[xarray.DataArray, matplotlib.colors.Normalize, str]:
33 | vmin = np.nanpercentile(data, (2 if robust else 0))
34 | vmax = np.nanpercentile(data, (98 if robust else 100))
35 | if center is not None:
36 | diff = max(vmax - center, center - vmin)
37 | vmin = center - diff
38 | vmax = center + diff
39 | return (data, matplotlib.colors.Normalize(vmin, vmax),
40 | ("RdBu_r" if center is not None else "viridis"))
41 |
42 | def plot_data(
43 | data: dict[str, xarray.Dataset],
44 | fig_title: str,
45 | plot_size: float = 5,
46 | robust: bool = False,
47 | cols: int = 4
48 |     ) -> HTML:
49 |
50 | first_data = next(iter(data.values()))[0]
51 | max_steps = first_data.sizes.get("time", 1)
52 | assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())
53 |
54 | cols = min(cols, len(data))
55 | rows = math.ceil(len(data) / cols)
56 | figure = plt.figure(figsize=(plot_size * 2 * cols,
57 | plot_size * rows))
58 | figure.suptitle(fig_title, fontsize=16)
59 | figure.subplots_adjust(wspace=0, hspace=0)
60 | figure.tight_layout()
61 |
62 | images = []
63 |     for i, (title, (pdata, norm, cmap)) in enumerate(data.items()):
64 | ax = figure.add_subplot(rows, cols, i+1)
65 | ax.set_xticks([])
66 | ax.set_yticks([])
67 | ax.set_title(title)
68 |         im = ax.imshow(
69 |             pdata.isel(time=0, missing_dims="ignore"), norm=norm,
70 |             origin="lower", cmap=cmap)
71 | plt.colorbar(
72 | mappable=im,
73 | ax=ax,
74 | orientation="vertical",
75 | pad=0.02,
76 | aspect=16,
77 | shrink=0.75,
78 | cmap=cmap,
79 | extend=("both" if robust else "neither"))
80 | images.append(im)
81 |
82 | def update(frame):
83 | if "time" in first_data.dims:
84 |             td = datetime.timedelta(microseconds=first_data["time"][frame].item() / 1000)  # "time" is timedelta64[ns]; convert ns to µs
85 | figure.suptitle(f"{fig_title}, {td}", fontsize=16)
86 | else:
87 | figure.suptitle(fig_title, fontsize=16)
88 |         for im, (pdata, norm, cmap) in zip(images, data.values()):
89 |             im.set_data(pdata.isel(time=frame, missing_dims="ignore"))
90 |
91 | ani = animation.FuncAnimation(
92 | fig=figure, func=update, frames=max_steps, interval=250)
93 | plt.close(figure.number)
94 | return HTML(ani.to_jshtml())
95 |
96 |
97 | def plot_predictions(predictions, eval_targets):
98 |     # Choose predictions to plot.
99 |
100 | plot_pred_variable = widgets.Dropdown(
101 | options=predictions.data_vars.keys(),
102 | value="2m_temperature",
103 | description="Variable")
104 | plot_pred_level = widgets.Dropdown(
105 | options=predictions.coords["level"].values,
106 | value=500,
107 | description="Level")
108 | plot_pred_robust = widgets.Checkbox(value=True, description="Robust")
109 |     plot_pred_max_steps = widgets.IntSlider(
110 |         min=1,
111 |         max=predictions.sizes["time"],
112 |         value=predictions.sizes["time"],
113 |         description="Max steps")
114 |
115 | widgets.VBox([
116 | plot_pred_variable,
117 | plot_pred_level,
118 | plot_pred_robust,
119 | plot_pred_max_steps,
120 | widgets.Label(value="Run the next cell to plot the predictions. Rerunning this cell clears your selection.")
121 | ])
122 |
123 |     # Plot the selected predictions.
124 |
125 |     plot_size = 5
126 |     plot_max_steps = min(predictions.sizes["time"], plot_pred_max_steps.value)
127 |
128 | data = {
129 | "Targets": scale(select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps),
130 | robust=plot_pred_robust.value),
131 | "Predictions": scale(select(predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps),
132 | robust=plot_pred_robust.value),
133 | "Diff": scale((select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps) -
134 | select(predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),
135 | robust=plot_pred_robust.value, center=0),
136 | }
137 | fig_title = plot_pred_variable.value
138 | if "level" in predictions[plot_pred_variable.value].coords:
139 | fig_title += f" at {plot_pred_level.value} hPa"
140 |
141 |     return plot_data(data, fig_title, plot_size, plot_pred_robust.value)
142 |
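143 | # Widget-free usage sketch (assumes `predictions` and `eval_targets` are
144 | # xarray Datasets sharing dims; the variable name is illustrative):
145 | #   html = plot_data(
146 | #       {"Predictions": scale(select(predictions, "2m_temperature", max_steps=4), robust=True)},
147 | #       fig_title="2m_temperature", plot_size=5, robust=True)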
--------------------------------------------------------------------------------
/GraphCast_src/pretrained_graphcast.py:
--------------------------------------------------------------------------------
1 | from google.cloud.storage import Bucket
2 |
3 |
4 | from graphcast import checkpoint
5 | from graphcast import graphcast
6 | import buckets
7 | from pathlib import Path
8 | from typing import Sequence, Union
9 |
10 |
11 | # # Load the Data and initialize the model
12 |
13 | # ## Load the model params
14 | #
15 | # Choose one of the two ways of getting model params:
16 | # - **random**: You'll get random predictions, but you can change the model architecture, which may run faster or fit on your device.
17 | # - **checkpoint**: You'll get sensible predictions, but are limited to the model architecture that it was trained with, which may not fit on your device. In particular, generating gradients uses a lot of memory, so you'll need at least 25 GB of RAM (e.g. TPUv4 or A100).
18 | #
19 | # Checkpoints vary across a few axes:
20 | # - The mesh size specifies the internal graph representation of the earth. Smaller meshes will run faster but will have worse outputs. The mesh size does not affect the number of parameters of the model.
21 | # - The resolution and number of pressure levels must match the data. Lower resolution and fewer levels will run a bit faster. Data resolution only affects the encoder/decoder.
22 | # - All our models predict precipitation. However, ERA5 includes precipitation, while HRES does not. Our models marked as "ERA5" take precipitation as input and expect ERA5 data as input, while models marked "ERA5-HRES" do not take precipitation as input and are specifically trained to take HRES-fc0 as input (see the data section below).
23 | #
24 | # We provide three pre-trained models.
25 | # 1. `GraphCast`, the high-resolution model used in the GraphCast paper (0.25 degree resolution, 37 pressure levels), trained on ERA5 data from 1979 to 2017,
26 | #
27 | # 2. `GraphCast_small`, a smaller, low-resolution version of GraphCast (1 degree resolution, 13 pressure levels, and a smaller mesh), trained on ERA5 data from 1979 to 2015, useful to run a model with lower memory and compute constraints,
28 | #
29 | # 3. `GraphCast_operational`, a high-resolution model (0.25 degree resolution, 13 pressure levels) pre-trained on ERA5 data from 1979 to 2017 and fine-tuned on HRES data from 2016 to 2021. This model can be initialized from HRES data (does not require precipitation inputs).
30 | #
31 |
32 |
33 | def load_model_checkpoint(gcs_bucket: Bucket, name: str) -> graphcast.CheckPoint:
34 | with gcs_bucket.blob(f"params/{name}").open("rb") as f:
35 | ckpt = checkpoint.load(f, graphcast.CheckPoint)
36 |
37 | print("Model description:\n", ckpt.description, "\n")
38 | print("Model license:\n", ckpt.license, "\n")
39 |
40 | return ckpt
41 |
42 | def find_model_name(options: Sequence[str], resolution: float, pressure_level: int) -> Union[str, None]:
43 | for name in options:
44 | if f"resolution {resolution}" in name and f"levels {pressure_level}" in name:
45 | return name
46 | return None
47 |
48 | def save_checkpoint(gcs_bucket: Bucket, directory: Path, name: str) -> None:
49 | # copy checkpoint to local disk
50 | with gcs_bucket.blob(f"params/{name}").open("rb") as param:
51 | buckets.save_to_dir(param, directory / 'params', name)
52 |
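53 | # Naming-convention note: find_model_name matches on substrings. For example,
54 | # the checkpoint listed in the top-level README,
55 | #   "GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz",
56 | # is returned by find_model_name(options, 0.25, 13) because both
57 | # "resolution 0.25" and "levels 13" occur in the name.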
--------------------------------------------------------------------------------
/GraphCast_src/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Module setuptools script."""
15 |
16 | from setuptools import setup
17 |
18 | description = (
19 | "GraphCast: Learning skillful medium-range global weather forecasting"
20 | )
21 |
22 | setup(
23 | name="graphcast",
24 | version="0.1",
25 | description=description,
26 | long_description=description,
27 | author="DeepMind",
28 | license="Apache License, Version 2.0",
29 | keywords="GraphCast Weather Prediction",
30 | url="https://github.com/deepmind/graphcast",
31 | packages=["graphcast"],
32 | install_requires=[
33 | "cartopy",
34 | "chex",
35 | "colabtools",
36 | "dask",
37 | "dm-haiku",
38 | "dm-tree",
39 | "jax",
40 | "jraph",
41 | "matplotlib",
42 | "numpy",
43 | "pandas",
44 | "rtree",
45 | "scipy",
46 | #"tree",
47 | "trimesh",
48 | "typing_extensions",
49 | "xarray",
50 | "google-cloud-storage",
51 | "zarr",
52 | "gcsfs",
54 | "jax-dataloader",
55 | #"tensorflow",
56 | #"tensorboard-plugin-profile",
57 | "nvtx",
58 | "wandb"
59 | ],
60 | classifiers=[
61 | "Development Status :: 3 - Alpha",
62 | "Intended Audience :: Science/Research",
63 | "License :: OSI Approved :: Apache Software License",
64 | "Operating System :: POSIX :: Linux",
65 | "Programming Language :: Python :: 3",
66 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
67 | "Topic :: Scientific/Engineering :: Atmospheric Science",
68 | "Topic :: Scientific/Engineering :: Physics",
69 | ],
70 | )
71 |
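72 | # Editable-install sketch (run from within GraphCast_src/): pip install -e .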
--------------------------------------------------------------------------------
/GraphCast_src/weatherbench2_dataloader.py:
--------------------------------------------------------------------------------
1 | import xarray as xr
2 | from copy import deepcopy
3 |
4 | class WeatherBench2Dataset():
5 |     def __init__(self, year: int, steps: int, steps_per_input: int = 3):
6 | self.vars = ["geopotential", "specific_humidity", "temperature",
7 | "u_component_of_wind", "v_component_of_wind",
8 | "vertical_velocity", "toa_incident_solar_radiation",
9 | "10m_u_component_of_wind", "10m_v_component_of_wind",
10 | "2m_temperature", "mean_sea_level_pressure",
11 | "total_precipitation_6hr", "geopotential_at_surface", "land_sea_mask"]
12 | self.static_vars = ["geopotential_at_surface", "land_sea_mask"]
13 | self.ds = xr.open_zarr('gs://weatherbench2/datasets/era5/1959-2022-6h-1440x721.zarr')
14 | self.ds = self.ds[self.vars]
15 | self.ds = self.ds.sel(time=slice(f'{year}-01-01', f'{year}-12-31'))
16 | self.length = len(self.ds.time)
17 | self.steps = steps
18 | self.steps_per_input = steps_per_input
19 | self.coord_name_dict = dict(latitude="lat", longitude="lon")
20 |
21 | def __len__(self):
22 | num_batches = self.length // self.steps
23 |         if self.length % self.steps < self.steps_per_input - 1:  # the last batch needs (steps_per_input - 1) extra timesteps
24 | num_batches -= 1
25 | return num_batches
26 |
27 | def __getitem__(self, item):
28 | return self.get_data(item)
29 |
30 | def get_data(self, batch_idx):
31 | batch_idx = batch_idx % len(self)
32 | it_range = slice(batch_idx*self.steps, (batch_idx+1)*self.steps + self.steps_per_input - 1)
33 | static_data = self.ds[self.static_vars].rename(**self.coord_name_dict)
34 |         data = self.ds.drop_vars(self.static_vars).isel(time=it_range)
35 |         data = data.rename(**self.coord_name_dict)
36 |         data = data.isel(lat=slice(None, None, -1))  # flip latitude to ascending order
37 |         data = xr.merge([static_data, data.expand_dims({'batch': 1})])  # add a size-1 batch dimension
38 |         data = data.assign_coords(datetime=(["batch", "time"], data.time.data.reshape(1, -1)))  # absolute timestamps
39 |         data = data.assign_coords(time=("time", data.time.data - data.time.data[0]))  # time relative to the first step
40 | return data.compute()
41 |
42 | if __name__ == "__main__":
43 | dataset = WeatherBench2Dataset(2016, steps=4, steps_per_input=3)
44 | print(len(dataset))
45 | data = dataset.get_data(-1)
46 | print(data)
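47 |
48 | # Each item is an xarray.Dataset with a size-1 "batch" dimension and
49 | # (steps + steps_per_input - 1) timesteps; latitude/longitude are renamed to
50 | # the "lat"/"lon" names GraphCast expects, and static fields carry no time dim.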
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Set up the environment
2 | Follow the `Dockerfile` to set up a container environment for NVIDIA GPUs.
3 |
4 | # Prepare forecast data from GraphCast
5 | Set up a Google Cloud Storage key named `gcs_key.json` and place it in `GraphCast_src/`.
6 |
7 | Execute `python3 GraphCast_src/graphcast_runner.py --resolution .25 --pressure_levels 13 --autoregressive_steps 8 --test_year_start 1977 --test_year_end 2016` and `python3 GraphCast_src/graphcast_runner.py --resolution .25 --pressure_levels 13 --autoregressive_steps 1 --test_year_start 1977 --test_year_end 2016` to generate the 48h and 6h GraphCast forecast data needed for training the diffusion model.
8 |
9 | # Prepare GraphCast weights and parameters
10 | Visit https://console.cloud.google.com/storage/browser/dm_graphcast and download `GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz` from the `params` folder, along with the whole `stats` folder.
11 |
12 | # Train diffusion model
13 | Execute `python3 DiffDA/train_conditional_graphcast.py` and pass the required arguments (hyperparameters, ERA5 and forecast data paths, etc.).
14 |
15 | # Run data assimilation
16 | - Execute `python3 DiffDA/inference_data_assimilation.py --num_autoregressive_steps=1 ...` to run single-step data assimilation
17 | - Execute `python3 DiffDA/inference_data_assimilation.py --num_autoregressive_steps=n ...` (n > 2) to run autoregressive data assimilation
18 | - Execute `python3 DiffDA/inference_data_assimilation_gc.py --num_autoregressive_steps=n ...` to run an (autoregressive) GraphCast forecast on single-step assimilated data
19 |
20 | # Implementation detail
21 | - `GraphCast_src/graphcast/normalization.py`: DDPM and RePaint algorithms for inference
22 | - `DiffDA/train_conditional_graphcast.py`: trains the diffusion model with GraphCast as the backbone
23 | - `DiffDA/inference_data_assimilation.py`: runs single-step & autoregressive data assimilation
24 | - `DiffDA/inference_data_assimilation_gc.py`: runs the GraphCast forecast on single-step assimilated data
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur0.5.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur0.5.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur0.5_geopotential.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur0.5_geopotential.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur0.5_sfc.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur0.5_sfc.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur0.5_specific_humidity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur0.5_specific_humidity.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur0.5_temperature.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur0.5_temperature.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur0.5_u_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur0.5_u_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur0.5_v_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur0.5_v_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur0.5_vertical_velocity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur0.5_vertical_velocity.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur1.0.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur1.0.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur1.0_geopotential.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur1.0_geopotential.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur1.0_sfc.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur1.0_sfc.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur1.0_specific_humidity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur1.0_specific_humidity.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur1.0_temperature.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur1.0_temperature.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur1.0_u_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur1.0_u_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur1.0_v_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur1.0_v_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur1.0_vertical_velocity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur1.0_vertical_velocity.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur1.5.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur1.5.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur1.5_geopotential.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur1.5_geopotential.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur1.5_sfc.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur1.5_sfc.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur1.5_specific_humidity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur1.5_specific_humidity.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur1.5_temperature.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur1.5_temperature.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur1.5_u_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur1.5_u_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur1.5_v_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur1.5_v_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur1.5_vertical_velocity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur1.5_vertical_velocity.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur2.0.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur2.0.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur2.0_geopotential.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur2.0_geopotential.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur2.0_sfc.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur2.0_sfc.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur2.0_specific_humidity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur2.0_specific_humidity.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur2.0_temperature.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur2.0_temperature.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur2.0_u_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur2.0_u_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur2.0_v_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur2.0_v_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur2.0_vertical_velocity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur2.0_vertical_velocity.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur2.5.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur2.5.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur2.5_geopotential.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur2.5_geopotential.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur2.5_sfc.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur2.5_sfc.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur2.5_specific_humidity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur2.5_specific_humidity.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur2.5_temperature.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur2.5_temperature.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur2.5_u_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur2.5_u_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur2.5_v_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur2.5_v_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment1/score_board_0_blur2.5_vertical_velocity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment1/score_board_0_blur2.5_vertical_velocity.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur0.5.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur0.5.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur0.5_geopotential.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur0.5_geopotential.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur0.5_sfc.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur0.5_sfc.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur0.5_specific_humidity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur0.5_specific_humidity.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur0.5_temperature.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur0.5_temperature.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur0.5_u_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur0.5_u_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur0.5_v_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur0.5_v_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur0.5_vertical_velocity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur0.5_vertical_velocity.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur1.0.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur1.0.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur1.0_geopotential.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur1.0_geopotential.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur1.0_sfc.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur1.0_sfc.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur1.0_specific_humidity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur1.0_specific_humidity.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur1.0_temperature.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur1.0_temperature.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur1.0_u_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur1.0_u_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur1.0_v_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur1.0_v_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur1.0_vertical_velocity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur1.0_vertical_velocity.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur1.5.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur1.5.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur1.5_geopotential.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur1.5_geopotential.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur1.5_sfc.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur1.5_sfc.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur1.5_specific_humidity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur1.5_specific_humidity.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur1.5_temperature.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur1.5_temperature.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur1.5_u_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur1.5_u_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur1.5_v_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur1.5_v_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur1.5_vertical_velocity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur1.5_vertical_velocity.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur2.0.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur2.0.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur2.0_geopotential.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur2.0_geopotential.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur2.0_sfc.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur2.0_sfc.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur2.0_specific_humidity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur2.0_specific_humidity.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur2.0_temperature.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur2.0_temperature.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur2.0_u_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur2.0_u_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur2.0_v_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur2.0_v_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur2.0_vertical_velocity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur2.0_vertical_velocity.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur2.5.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur2.5.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur2.5_geopotential.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur2.5_geopotential.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur2.5_sfc.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur2.5_sfc.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur2.5_specific_humidity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur2.5_specific_humidity.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur2.5_temperature.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur2.5_temperature.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur2.5_u_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur2.5_u_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur2.5_v_component_of_wind.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur2.5_v_component_of_wind.pdf
--------------------------------------------------------------------------------
/figs/ablation_experiment3/score_board_48_blur2.5_vertical_velocity.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/ablation_experiment3/score_board_48_blur2.5_vertical_velocity.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/2m_temperature_1000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/2m_temperature_1000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/2m_temperature_10000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/2m_temperature_10000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/2m_temperature_2000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/2m_temperature_2000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/2m_temperature_20000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/2m_temperature_20000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/2m_temperature_4000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/2m_temperature_4000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/2m_temperature_40000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/2m_temperature_40000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/2m_temperature_8000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/2m_temperature_8000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/geopotential_500_1000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/geopotential_500_1000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/geopotential_500_10000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/geopotential_500_10000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/geopotential_500_2000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/geopotential_500_2000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/geopotential_500_20000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/geopotential_500_20000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/geopotential_500_4000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/geopotential_500_4000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/geopotential_500_40000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/geopotential_500_40000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/geopotential_500_8000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/geopotential_500_8000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/temperature_850_1000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/temperature_850_1000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/temperature_850_10000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/temperature_850_10000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/temperature_850_2000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/temperature_850_2000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/temperature_850_20000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/temperature_850_20000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/temperature_850_4000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/temperature_850_4000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/temperature_850_40000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/temperature_850_40000.pdf
--------------------------------------------------------------------------------
/figs/error_visualization/temperature_850_8000.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spcl/DiffDA/367c3e887f947f139baef410a7f514c3bca90c30/figs/error_visualization/temperature_850_8000.pdf
--------------------------------------------------------------------------------