├── .gitignore
├── .readthedocs.yaml
├── CONTRIBUTING.md
├── LICENSE
├── NCEP
│   ├── Readme.md
│   ├── cronjob_cloud.sh
│   ├── cronjob_hera.sh
│   ├── docs
│   │   ├── Makefile
│   │   ├── README.md
│   │   ├── make.bat
│   │   ├── requirements.txt
│   │   └── source
│   │       ├── conf.py
│   │       ├── index.rst
│   │       ├── inputs.rst
│   │       ├── installation.rst
│   │       ├── introduction.rst
│   │       ├── outputs.rst
│   │       └── run.rst
│   ├── environment.yml
│   ├── gc_datadissm_hera.sh
│   ├── gc_prepdata_hera.sh
│   ├── gc_runfcst_hera.sh
│   ├── gcjob_13pl_cloud.sh
│   ├── gcjob_37pl_cloud.sh
│   ├── gdas_utility.py
│   ├── run_graphcast.py
│   ├── upload_to_s3bucket.py
│   └── utils
│       └── nc2grib.py
├── README.md
├── docs
│   ├── GenCast_0p25deg_accelerator_scorecard.png
│   ├── GenCast_0p25deg_attention_implementation_scorecard.png
│   ├── GenCast_1p0deg_Mini_ENS_scorecard.png
│   ├── cloud_vm_setup.md
│   ├── local_runtime_popup_1.png
│   ├── local_runtime_popup_2.png
│   ├── local_runtime_url.png
│   ├── project.png
│   ├── provision_tpu.png
│   └── tpu_types.png
├── gencast_demo_cloud_vm.ipynb
├── gencast_mini_demo.ipynb
├── graphcast
│   ├── autoregressive.py
│   ├── casting.py
│   ├── checkpoint.py
│   ├── checkpoint_test.py
│   ├── data_utils.py
│   ├── data_utils_test.py
│   ├── deep_typed_graph_net.py
│   ├── denoiser.py
│   ├── denoisers_base.py
│   ├── dpm_solver_plus_plus_2s.py
│   ├── gencast.py
│   ├── graphcast.py
│   ├── grid_mesh_connectivity.py
│   ├── grid_mesh_connectivity_test.py
│   ├── icosahedral_mesh.py
│   ├── icosahedral_mesh_test.py
│   ├── losses.py
│   ├── mlp.py
│   ├── model_utils.py
│   ├── nan_cleaning.py
│   ├── normalization.py
│   ├── predictor_base.py
│   ├── rollout.py
│   ├── samplers_base.py
│   ├── samplers_utils.py
│   ├── solar_radiation.py
│   ├── solar_radiation_test.py
│   ├── sparse_transformer.py
│   ├── sparse_transformer_utils.py
│   ├── transformer.py
│   ├── typed_graph.py
│   ├── typed_graph_net.py
│   ├── xarray_jax.py
│   ├── xarray_jax_test.py
│   ├── xarray_tree.py
│   └── xarray_tree_test.py
├── graphcast_demo.ipynb
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.egg*
2 | __pycache__
3 | *.swp
4 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 | build:
8 | os: ubuntu-20.04
9 | tools:
10 | python: "3.9"
11 |
12 | # Build documentation in the docs/ directory with Sphinx
13 | sphinx:
14 | configuration: NCEP/docs/source/conf.py
15 |
16 | # Build documentation with MkDocs
17 | #mkdocs:
18 | # configuration: mkdocs.yml
19 |
20 | # Optionally build your docs in additional formats such as PDF and ePub
21 | formats: all
22 |
23 | # Optionally set the version of Python and requirements required to build your docs
24 | python:
25 | install:
26 | - requirements: NCEP/docs/requirements.txt
27 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | ## Contributor License Agreement
4 |
5 | Contributions to this project must be accompanied by a Contributor License
6 | Agreement. You (or your employer) retain the copyright to your contribution,
7 | this simply gives us permission to use and redistribute your contributions as
8 | part of the project. Head over to <https://cla.developers.google.com/> to see
9 | your current agreements on file or to sign a new one.
10 |
11 | You generally only need to submit a CLA once, so if you've already submitted one
12 | (even if it was for a different project), you probably don't need to do it
13 | again.
14 |
15 | ## Code reviews
16 |
17 | All submissions, including submissions by project members, require review. We
18 | use GitHub pull requests for this purpose. Consult
19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
20 | information on using pull requests.
21 |
22 | ## Community Guidelines
23 |
24 | This project follows [Google's Open Source Community
25 | Guidelines](https://opensource.google/conduct/).
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/NCEP/Readme.md:
--------------------------------------------------------------------------------
1 | # GraphCast model with NCEP GDAS Products as ICs
2 |
3 | This repository provides scripts to run real-time GraphCast using GDAS products as inputs. There are multiple scripts in the repository including:
4 | - `gdas_utility.py`: a Python script that downloads Global Data Assimilation System (GDAS) data produced by the National Centers for Environmental Prediction (NCEP), from either the NOAA S3 bucket or NOMADS, and prepares it in a format suitable for feeding into the GraphCast weather prediction system.
5 | - `run_graphcast.py`: a Python script that calls GraphCast, takes GDAS products as input, and produces six-hourly forecasts of arbitrary forecast length (e.g., 40 steps --> 10 days).
6 | - `graphcast_job_[machine_name].sh`: a Bash script that automates running GraphCast in real time on the AWS cloud machine (to be submitted through a cron job).
7 |
8 | ## Table of Contents
9 | - [Overview](#overview)
10 | - [Prerequisites and Installation](#prerequisites-and-installation)
11 | - [Usage](#usage)
12 | - [GDAS Utility](#gdas-utility)
13 | - [Run GraphCast](#run-graphcast)
14 | - [Run GraphCast Through Cronjob](#run-graphcast-through-cronjob)
15 | - [Output](#output)
16 | - [Contact](#contact)
17 |
18 | ## Overview
19 |
20 | The National Centers for Environmental Prediction (NCEP) provides GDAS data that can be used for weather prediction and analysis. This repository simplifies the process of downloading GDAS data, extracting relevant variables, and converting it into a format compatible with the GraphCast weather prediction system. In addition, it automates running GraphCast with GDAS inputs on the NOAA clusters.
21 |
22 | ## Prerequisites and Installation
23 |
24 | To install the package, run the following commands:
25 |
26 | ```bash
27 | conda create --name mlwp python=3.10
28 | ```
29 |
30 | ```bash
31 | conda activate mlwp
32 | ```
33 |
34 | ```bash
35 | pip install dm-tree boto3 xarray netcdf4
36 | ```
37 |
38 | ```bash
39 | conda install --channel conda-forge cartopy
40 | ```
41 |
42 | ```bash
43 | pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip
44 | ```
45 |
46 | ```bash
47 | pip install pygrib requests bs4
48 | ```
49 |
50 | If you are on macOS, `wget` must also be installed:
51 |
52 | ```bash
53 | brew install wget
54 | ```
55 |
56 |
57 | If you would like to save as grib2 format, the following packages are needed:
58 |
59 | ```bash
60 | pip install ecmwflibs
61 | ```
62 | ```bash
63 | pip install iris
64 | ```
65 |
66 | ```bash
67 | pip install iris-grib
68 | ```
69 |
70 | This will install the packages and most of their dependencies.
71 |
72 |
73 | Additionally, the utility uses the `wgrib2` library for extracting specific variables from the GDAS data. You can download and install `wgrib2` from [here](http://www.cpc.ncep.noaa.gov/products/wesley/wgrib2/). Make sure it is included in your system's PATH.
74 |
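As an optional sanity check (a minimal sketch, assuming the conda environment created above is active and `wgrib2` was installed as described), you can verify that the key tools are visible:

```bash
# Check that the Python packages installed above can be imported
python -c "import graphcast, xarray, pygrib; print('imports ok')"

# Check that the wgrib2 executable is on your PATH
which wgrib2
```
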
75 | ## Usage
76 |
77 | To use the utility, follow these steps:
78 |
79 | Clone the NOAA-EMC GraphCast repository:
80 |
81 | ```bash
82 | git clone https://github.com/NOAA-EMC/graphcast.git
83 | ```
84 |
85 | ```bash
86 | cd graphcast/NCEP
87 | ```
88 |
89 | ## GDAS Utility
90 |
91 | To download and prepare GDAS data, use the following command:
92 |
93 | ```bash
94 | python3 gdas_utility.py yyyymmddhh yyyymmddhh --level 13 --source s3 --output /directory/to/output --download /directory/to/download --keep no
95 | ```
96 |
97 | #### Arguments (required):
98 |
99 | - `yyyymmddhh`: Start datetime
100 | - `yyyymmddhh`: End datetime
101 |
102 | #### Arguments (optional):
103 |
104 | - `-l or --level`: [13, 37], represents the number of pressure levels (default: 13)
105 | - `-s or --source`: [s3, nomads], represents the source to download GDAS data (default: "s3")
106 | - `-o or --output`: /directory/to/output, represents the directory for the output netCDF file (default: "current directory")
107 | - `-d or --download`: /directory/to/download, represents the download directory for grib2 files (default: "current directory")
108 | - `-k or --keep`: [yes, no], specifies whether to keep downloaded data after processing (default: "no")
109 |
110 | Example usage with options:
111 |
112 | ```bash
113 | python3 gdas_utility.py 2023060600 2023060606 -o /path/to/output -d /path/to/download
114 | ```
115 |
116 | Note:
117 | - The 37 pressure levels option is still under development.
118 | - GraphCast only needs two states for initialization; however, gdas_utility.py can produce a longer time series for evaluating the model (e.g., 10 days). A quick way to sanity-check the generated file is shown below.
119 |
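As a minimal sketch (not part of the utility itself), such a check could open the generated netCDF with xarray; the filename pattern below is the one used by the cron scripts in this directory, the date is illustrative, and the dimension names are assumed from `run_graphcast.py`:

```python
import xarray as xr

# Filename pattern used by the cron scripts in this directory (the date is illustrative)
ds = xr.open_dataset("source-gdas_date-2023060600_res-0.25_levels-13_steps-2.nc")

# GraphCast needs two initial states, so expect at least 2 entries along the time dimension
print(ds.sizes)
print(list(ds.data_vars))
```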
120 |
121 | ## Run GraphCast
122 |
123 | To run GraphCast, use the following command:
124 |
125 | ```bash
126 | python3 run_graphcast.py --input /path/to/input/file --output /path/to/output/file --weights /path/to/graphcast/weights --length forecast_length --upload yes --keep no
127 | ```
128 |
129 | #### Arguments (required):
130 |
131 | - `-i or --input`: /path/to/input/file, represents the path to input netcdf file (including file name and extension)
132 | - `-o or --output`: /path/to/output/file, represents the path to output netcdf file (including file name and extension)
133 | - `-w or --weights`: /path/to/graphcast/weights, represents the path to the parent directory of the graphcast params (weights) and stats from the pre-trained model
134 | - `-l or --length`: An integer in the range [1, 40], represents the number of forecast time steps (6-hourly; e.g., 40 → 10 days)
135 |
136 | #### Arguments (optional):
137 |
138 | - `-u or --upload`: [yes, no], option for uploading the input and output files to NOAA S3 bucket [noaa-nws-graphcastgfs-pds] (default: "no")
139 | - `-k or --keep`: [yes, no], specifies whether to keep input and output files after uploading to NOAA S3 bucket (default: "no")
140 |
141 | Example usage with options (1-day forecast):
142 |
143 | ```bash
144 | python3 run_graphcast.py -i /path/to/input -o /path/to/output -w /path/to/graphcast/weights -l 4
145 | ```
146 |
147 | ## Run GraphCast Through Cronjob
148 |
149 | Submit `cronjob_[machine_name].sh` through a cron job to run GraphCast and produce real-time forecasts every 6 hours.
150 |
151 | ```bash
152 | # Example CronJob to run GraphCast every 6 hours
153 | 0 */6 * * * /lustre/Sadegh.Tabas/graphcast/NCEP/cronjob_cloud.sh >> /lustre/Sadegh.Tabas/graphcast/NCEP/logfile.log 2>&1
154 | ```
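
To register the job, add the schedule line above to your crontab, for example via the crontab editor; the paths shown are examples from a specific cloud setup and should be adjusted to your system.

```bash
# Open the crontab editor and paste the schedule line shown above (adjust paths first)
crontab -e
```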
155 |
156 | ## Output
157 |
158 | The processed GDAS data as well as the GraphCast forecasts are saved in NetCDF format in the corresponding directories (an option to upload both input and output files to the NOAA S3 bucket is also provided). The files are named based on the initialization date.
159 |
160 | ## Contact
161 |
162 | For questions or issues, please contact [Sadegh.Tabas@noaa.gov](mailto:Sadegh.Tabas@noaa.gov).
163 |
--------------------------------------------------------------------------------
/NCEP/cronjob_cloud.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | sbatch /contrib/Sadegh.Tabas/operational/graphcast/NCEP/gcjob_13pl_cloud.sh
3 | # sbatch /contrib/Sadegh.Tabas/operational/graphcast/NCEP/gcjob_37pl_cloud.sh
4 |
--------------------------------------------------------------------------------
/NCEP/cronjob_hera.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash --login
2 |
3 | echo "Job 1 is running"
4 | sh /scratch1/NCEPDEV/nems/AIML/graphcast/NCEP/gc_prepdata_hera.sh
5 | sleep 60 # Simulating some work
6 | echo "Job 1 completed"
7 |
8 | echo "Job 2 is running"
9 | job2_id=$(sbatch /scratch1/NCEPDEV/nems/AIML/graphcast/NCEP/gc_runfcst_hera.sh | awk '{print $4}')
10 |
11 | # Wait for job 2 to complete
12 | while squeue -j $job2_id &>/dev/null; do
13 | sleep 5 # Adjust the polling interval as needed
14 | done
15 | sleep 5 # Simulating some work
16 | echo "Job 2 completed"
17 |
18 | echo "Job 3 is running"
19 | sh /scratch1/NCEPDEV/nems/AIML/graphcast/NCEP/gc_datadissm_hera.sh
20 | sleep 5 # Simulating some work
21 | echo "Job 3 completed"
22 |
--------------------------------------------------------------------------------
/NCEP/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | # #
3 | #
4 | # # You can set these variables from the command line, and also
5 | # # from the environment for the first two.
6 | # SPHINXOPTS ?=
7 | # SPHINXBUILD ?= sphinx-build
8 | # SOURCEDIR = source
9 | # BUILDDIR = build
10 | #
11 | # # Put it first so that "make" without argument is like "make help".
12 | # help:
13 | # @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 | #
15 | # .PHONY: help Makefile
16 | #
17 | # # Catch-all target: route all unknown targets to Sphinx using the new
18 | # # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | # %: Makefile
20 | # @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/NCEP/docs/README.md:
--------------------------------------------------------------------------------
1 | ## GraphCast Global Forecast System (GraphCastGFS)
2 |
--------------------------------------------------------------------------------
/NCEP/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/NCEP/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinxcontrib-bibtex
2 | sphinx_rtd_theme
3 | docutils==0.16
4 |
--------------------------------------------------------------------------------
/NCEP/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Configuration file for the Sphinx documentation builder.
4 | #
5 | # This file does only contain a selection of the most common options. For a
6 | # full list see the documentation:
7 | # http://www.sphinx-doc.org/en/master/config
8 |
9 | # -- Path setup --------------------------------------------------------------
10 |
11 | # If extensions (or modules to document with autodoc) are in another directory,
12 | # add these directories to sys.path here. If the directory is relative to the
13 | # documentation root, use os.path.abspath to make it absolute, like shown here.
14 | #
15 | import os
16 | import sys
17 | sys.path.insert(0, os.path.abspath('.'))
18 |
19 |
20 | # -- Project information -----------------------------------------------------
21 |
22 | project = 'GraphCast GFS'
23 | copyright = ''
24 | author = ' '
25 |
26 | # The short X.Y version
27 | version = ''
28 | # The full version, including alpha/beta/rc tags
29 | release = ''
30 |
31 | numfig = True
32 |
33 |
34 | # -- General configuration ---------------------------------------------------
35 |
36 | # If your documentation needs a minimal Sphinx version, state it here.
37 | #
38 | # needs_sphinx = '1.0'
39 |
40 | # Add any Sphinx extension module names here, as strings. They can be
41 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
42 | # ones.
43 | extensions = [
44 | 'sphinx_rtd_theme',
45 | 'sphinx.ext.autodoc',
46 | 'sphinx.ext.doctest',
47 | 'sphinx.ext.intersphinx',
48 | 'sphinx.ext.todo',
49 | 'sphinx.ext.coverage',
50 | 'sphinx.ext.mathjax',
51 | 'sphinx.ext.ifconfig',
52 | 'sphinx.ext.viewcode',
53 | 'sphinx.ext.githubpages',
54 | 'sphinx.ext.napoleon',
55 | ]
56 |
57 | # Add any paths that contain templates here, relative to this directory.
58 | #templates_path = ['_templates']
59 |
60 | # The suffix(es) of source filenames.
61 | # You can specify multiple suffix as a list of string:
62 | #
63 | # source_suffix = ['.rst', '.md']
64 | source_suffix = '.rst'
65 |
66 | # The master toctree document.
67 | master_doc = 'index'
68 |
69 | # The language for content autogenerated by Sphinx. Refer to documentation
70 | # for a list of supported languages.
71 | #
72 | # This is also used if you do content translation via gettext catalogs.
73 | # Usually you set "language" from the command line for these cases.
74 | language = 'en'
75 |
76 | # List of patterns, relative to source directory, that match files and
77 | # directories to ignore when looking for source files.
78 | # This pattern also affects html_static_path and html_extra_path.
79 | exclude_patterns = []
80 |
81 | # The name of the Pygments (syntax highlighting) style to use.
82 | pygments_style = 'sphinx'
83 |
84 |
85 | # -- Options for HTML output -------------------------------------------------
86 |
87 | # The theme to use for HTML and HTML Help pages. See the documentation for
88 | # a list of builtin themes.
89 | #
90 | #html_theme = 'classic'
91 | html_theme = 'sphinx_rtd_theme'
92 | html_theme_path = ["_themes", ]
93 |
94 | # Theme options are theme-specific and customize the look and feel of a theme
95 | # further. For a list of options available for each theme, see the
96 | # documentation.
97 | #
98 | # html_theme_options = {}
99 | html_theme_options = {"body_max_width": "none"}
100 |
101 | # Add any paths that contain custom static files (such as style sheets) here,
102 | # relative to this directory. They are copied after the builtin static files,
103 | # so a file named "default.css" will overwrite the builtin "default.css".
104 | #html_static_path = ['_static']
105 | html_context = {}
106 |
107 | def setup(app):
108 | app.add_css_file('custom.css') # may also be an URL
109 | app.add_css_file('theme_overrides.css') # may also be an URL
110 |
111 | # Custom sidebar templates, must be a dictionary that maps document names
112 | # to template names.
113 | #
114 | # The default sidebars (for documents that don't match any pattern) are
115 | # defined by theme itself. Builtin themes are using these templates by
116 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
117 | # 'searchbox.html']``.
118 | #
119 | # html_sidebars = {}
120 |
121 |
122 |
123 | # -- Options for LaTeX output ------------------------------------------------
124 |
125 | latex_engine = 'pdflatex'
126 | latex_elements = {
127 | # The paper size ('letterpaper' or 'a4paper').
128 | #
129 | # 'papersize': 'letterpaper',
130 |
131 | # The font size ('10pt', '11pt' or '12pt').
132 | #
133 | # 'pointsize': '10pt',
134 |
135 | # Additional stuff for the LaTeX preamble.
136 | #
137 | # 'preamble': '',
138 |
139 | # Latex figure (float) alignment
140 | #
141 | # 'figure_align': 'htbp',
142 | # 'maketitle': r'\newcommand\sphinxbackoftitlepage{For referencing this document please use: \newline \break Schramm, J., L. Bernardet, L. Carson, G. Firl, D. Heinzeller, L. Pan, and M. Zhang, 2020. UFS Weather Model User's Guide Release v1.0.0. Npp. Available at https://dtcenter.org.}\sphinxmaketitle'
143 | }
144 |
145 | # -- Extension configuration -------------------------------------------------
146 |
147 | # -- Options for intersphinx extension ---------------------------------------
148 |
149 | # Example configuration for intersphinx: refer to the Python standard library.
150 | intersphinx_mapping = {'landda': ('https://land-da-workflow.readthedocs.io/en/latest/', None),
151 | }
152 |
153 | # -- Options for todo extension ----------------------------------------------
154 |
155 | # If true, `todo` and `todoList` produce output, else they produce nothing.
156 | todo_include_todos = True
157 |
--------------------------------------------------------------------------------
/NCEP/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | GraphCast with GFS input
2 | =======================================================
3 | .. toctree::
4 | :numbered:
5 | :maxdepth: 3
6 |
7 | introduction
8 | installation
9 | inputs
10 | run
11 | outputs
12 |
--------------------------------------------------------------------------------
/NCEP/docs/source/inputs.rst:
--------------------------------------------------------------------------------
1 | #############################################
2 | Preparing inputs from GDAS product
3 | #############################################
4 |
5 | GraphCast takes two weather states (the current state and the state 6 hours earlier) as its initial conditions. We create a netCDF file containing these two states from GDAS 0.25 degree analysis data using the script NCEP/gdas_utility.py. The script downloads the GDAS data, which is in GRIB2 format, from either the NOAA S3 bucket or the NOAA NOMADS server, extracts the required variables, and saves them as a netCDF file. Run the script using::
6 |
7 | python gdas_utility.py startdate enddate --level 13 --source s3 --output /path/to/output --download /path/to/download --method wgrib2 --keep no
8 |
9 | **Arguments**
10 |
11 | Required:
12 |
13 | startdate and enddate: string, yyyymmddhh
14 |
15 | Optional:
16 |
17 | *-l* or *--level*: 13 or 37, the number of pressure levels (default: 13)
18 |
19 | *-s* or *--source*: s3 or nomads, the source to download GDAS data (default: s3)
20 |
21 | *-m* or *--method*: wgrib2 or pygrib, the method to extract required variables and create netCDF file (default: wgrib2)
22 |
23 | *-o* or *--output*: /path/to/output, where to save the generated netCDF file (default: current directory)
24 |
25 | *-d* or *--download*: /path/to/download, where to save downloaded grib2 files (default: current directory)
26 |
27 | *-k* or *--keep*: yes or no, whether to keep downloaded data after processing (default: no)
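
For example, with the dates used in the example in NCEP/Readme.md and the default options spelled out::

    python gdas_utility.py 2023060600 2023060606 --level 13 --source s3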
28 |
--------------------------------------------------------------------------------
/NCEP/docs/source/installation.rst:
--------------------------------------------------------------------------------
1 | ############################################
2 | Installation
3 | ############################################
4 |
5 | The recommended way to set up the environment for installing GraphCast is to use `conda `_.
6 | With conda, you can create an environment and install the required libraries with the `environment.yml` file provided in the NCEP folder::
7 |
8 | conda env create -f environment.yml -n your-env-name
9 |
10 | Activate the environment::
11 |
12 | conda activate your-env-name
13 |
14 | Get EMC/graphcast source code::
15 |
16 | git clone https://github.com/NOAA-EMC/graphcast.git
17 |
--------------------------------------------------------------------------------
/NCEP/docs/source/introduction.rst:
--------------------------------------------------------------------------------
1 | ######################
2 | Introduction
3 | ######################
4 |
5 | **The GraphCast Global Forecast System** is a weather forecast model built upon the pre-trained Google DeepMind's GraphCast
6 | Machine Learning Weather Prediction (MLWP) model. It is set up by the National Centers for Environmental Prediction (NCEP)
7 | to produce medium range global forecasts. The model runs in two operation modes on different vertical resolutions: 13 and
8 | 37 pressure levels. The horizontal resolution is a 0.25 degree latitude-longitude grid (about 28 km). The model runs 4
9 | times a day at 00Z, 06Z, 12Z, and 18Z cycles. Major surface and atmospheric fields including temperature, wind components,
10 | geopotential height, specific humidity, and vertical velocity are available. The products are 6-hourly forecasts up to 10 days.
11 |
12 | The Google DeepMind's GraphCast model is implemented as a message passing graph neural network (GNN) architecture with
13 | "encoder-processor-decoder" configuration. It uses an icosahedron grid with multiscale edges and has around 37 milion parameters.
14 | The model is pre-trained with ECMWF's ERA5 reanalysis data. The GraphCastGFS model takes two model states as initial conditions
15 | (current and 6-hr previous states) from NCEP 0.25 degree GDAS analysis data.
16 |
--------------------------------------------------------------------------------
/NCEP/docs/source/outputs.rst:
--------------------------------------------------------------------------------
1 | ######################
2 | Product
3 | ######################
4 |
5 | The GraphCastGFS model runs 4 times a day at 00Z, 06Z, 12Z, and 18Z cycles. The horizontal resolution is on 0.25 degree lat-lon grid.
6 | The vertical resolutions are on both 13 and 37 pressure levels.
7 |
8 | * The 13 pressure levels include:
9 |
10 | 50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, and 1000 hPa.
11 |
12 | * The 37 pressure levels include:
13 |
14 | 1, 2, 3, 5, 7, 10, 20, 30, 50, 70, 100, 125, 150, 175, 200, 225, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750, 800, 825, 850, 875, 900, 925, 950, 975, and 1000 hPa.
15 |
16 | The model output fields are:
17 |
18 | * 3D fields on pressure levels:
19 |
20 | * temperature
21 |
22 | * U and V component of wind
23 |
24 | * geopotential height
25 |
26 | * specific humidity
27 |
28 | * vertical velocity
29 |
30 | * 2D surface fields:
31 |
32 | * 10-m U and V components of wind
33 |
34 | * 2-m temperature
35 |
36 | * mean sea-level pressure
37 |
38 | * 6-hourly total precipitation
39 |
40 | The near real-time forecast outputs along with inputs are available on `AWS `_.
41 |
42 | For each cycle, the dataset contains the input files fed into GraphCast, found in the directory:
43 |
44 | graphcastgfs.yyyymmdd/hh/input
45 |
46 | and 10-day forecast results for the current cycle found in the following directories:
47 |
48 | graphcastgfs.yyyymmdd/hh/forecasts_13_levels
49 |
50 | graphcastgfs.yyyymmdd/hh/forecasts_37_levels
51 |
--------------------------------------------------------------------------------
/NCEP/docs/source/run.rst:
--------------------------------------------------------------------------------
1 | ######################
2 | Run GraphCastGFS
3 | ######################
4 | In order to run GraphCast in inference mode you will also need the model weights and normalization statistics,
5 | which are available from the `Google Cloud Bucket `_.
6 | Once you have the input netCDF file, model weights, and statistics data, you can run the GraphCast model with a lead time
7 | (e.g., a lead time of 10 days corresponds to a forecast_length of 40) using::
8 |
9 | python run_graphcast.py --input /input/filename/with/path --output /path/to/output --weights /path/to/weights --length forecast_length
10 |
11 | **Arguments**
12 |
13 | Required:
14 |
15 | *-i* or *--input*: /input/filename/with/path
16 |
17 | *-o* or *--output*: /path/to/output
18 |
19 | *-w* or *--weights*: /path/to/weights/and/stats
20 |
21 | *-l* or *--length*: integer, the number of forecast time steps (6-hourly)
22 |
23 | Optional:
24 |
25 | *-p* or *--pressure*: 13 or 37, number of pressure levels (default: 13)
26 |
27 | *-u* or *--upload*: yes or no, upload input and output files to NOAA s3 bucket (default: no)
28 |
29 | *-k* or *--keep*: yes or no, whether to keep input and output files after uploading
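
For example, a 10-day (40-step), 13-pressure-level forecast from a GDAS-derived input file; the input filename pattern is taken from the cron scripts in the NCEP directory::

    python run_graphcast.py -i source-gdas_date-2023060600_res-0.25_levels-13_steps-2.nc -w /path/to/weights -l 40 -p 13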
30 |
--------------------------------------------------------------------------------
/NCEP/environment.yml:
--------------------------------------------------------------------------------
1 | name: graphcast
2 | channels:
3 | - conda-forge
4 | dependencies:
5 | - python>=3.10
6 | - dm-tree
7 | - netcdf4
8 | - xarray
9 | - cartopy
10 | - pygrib
11 | - eccodes
12 | - iris
13 | - iris-grib
14 | - jupyterlab
15 | - pip
16 | - pip:
17 | - https://github.com/deepmind/graphcast/archive/master.zip
18 |
--------------------------------------------------------------------------------
/NCEP/gc_datadissm_hera.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # load necessary modules
4 | module use /scratch1/NCEPDEV/nems/role.epic/spack-stack/spack-stack-1.6.0/envs/unified-env-rocky8/install/modulefiles/Core
5 | module load stack-intel
6 | module load wgrib2
7 | module load awscli-v2
8 | module list
9 |
10 |
11 |
12 | # Get the UTC hour and calculate the time in the format yyyymmddhh
13 | current_hour=$(date -u +%H)
14 | current_hour=$((10#$current_hour))
15 |
16 | if (( $current_hour >= 0 && $current_hour < 6 )); then
17 | datetime=$(date -u -d 'today 00:00')
18 | elif (( $current_hour >= 6 && $current_hour < 12 )); then
19 | datetime=$(date -u -d 'today 06:00')
20 | elif (( $current_hour >= 12 && $current_hour < 18 )); then
21 | datetime=$(date -u -d 'today 12:00')
22 | else
23 | datetime=$(date -u -d 'today 18:00')
24 | fi
25 |
26 | # Current state: 12 hours before the cycle time; previous state: 6 hours before that
27 | #curr_datetime=$(date -u -d "$time" +'%Y%m%d%H')
28 | curr_datetime=$( date -d "$datetime 12 hour ago" "+%Y%m%d%H" )
29 | prev_datetime=$( date -d "$datetime 18 hour ago" "+%Y%m%d%H" )
30 |
31 | echo "Current state: $curr_datetime"
32 | echo "6 hours earlier state: $prev_datetime"
33 |
34 | # Activate Conda environment
35 | source /scratch1/NCEPDEV/nems/AIML/miniconda3/etc/profile.d/conda.sh
36 | conda activate mlwp
37 |
38 | cd /scratch1/NCEPDEV/nems/AIML/graphcast/NCEP/
39 |
40 | forecast_length=40
41 | num_pressure_level=13
42 |
43 | start_time=$(date +%s)
44 | echo "start uploading graphcast forecast to s3 bucket for: $curr_datetime"
45 | # Run another Python script
46 | python3 upload_to_s3bucket.py -d "$curr_datetime" -l "$num_pressure_level"
47 |
48 | end_time=$(date +%s) # Record the end time in seconds since the epoch
49 |
50 | # Calculate and print the execution time
51 | execution_time=$((end_time - start_time))
52 | echo "Execution time for uploading to s3 bucket: $execution_time seconds"
53 |
--------------------------------------------------------------------------------
/NCEP/gc_prepdata_hera.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | # load necessary modules
5 | module use /scratch1/NCEPDEV/nems/role.epic/spack-stack/spack-stack-1.6.0/envs/unified-env-rocky8/install/modulefiles/Core
6 | module load stack-intel
7 | module load wgrib2
8 | module load awscli-v2
9 | module list
10 |
11 |
12 | # Get the UTC hour and calculate the time in the format yyyymmddhh
13 | current_hour=$(date -u +%H)
14 | current_hour=$((10#$current_hour))
15 |
16 | if (( $current_hour >= 0 && $current_hour < 6 )); then
17 | datetime=$(date -u -d 'today 00:00')
18 | elif (( $current_hour >= 6 && $current_hour < 12 )); then
19 | datetime=$(date -u -d 'today 06:00')
20 | elif (( $current_hour >= 12 && $current_hour < 18 )); then
21 | datetime=$(date -u -d 'today 12:00')
22 | else
23 | datetime=$(date -u -d 'today 18:00')
24 | fi
25 |
26 | # Current state: 12 hours before the cycle time; previous state: 6 hours before that
27 | #curr_datetime=$(date -u -d "$time" +'%Y%m%d%H')
28 | curr_datetime=$( date -d "$datetime 12 hour ago" "+%Y%m%d%H" )
29 | prev_datetime=$( date -d "$datetime 18 hour ago" "+%Y%m%d%H" )
30 |
31 | echo "Current state: $curr_datetime"
32 | echo "6 hours earlier state: $prev_datetime"
33 |
34 |
35 | # Activate Conda environment
36 | source /scratch1/NCEPDEV/nems/AIML/miniconda3/etc/profile.d/conda.sh
37 | conda activate mlwp
38 |
39 | cd /scratch1/NCEPDEV/nems/AIML/graphcast/NCEP/
40 |
41 | num_pressure_levels=13
42 |
43 | start_time=$(date +%s)
44 | echo "start runing gdas utility to generate graphcast inputs for: $curr_datetime"
45 | # Run the Python script gdas.py with the calculated times
46 | python3 gdas_utility.py "$prev_datetime" "$curr_datetime" -l "$num_pressure_levels"
47 |
48 | end_time=$(date +%s) # Record the end time in seconds since the epoch
49 |
50 | # Calculate and print the execution time
51 | execution_time=$((end_time - start_time))
52 | echo "Execution time for gdas_utility.py: $execution_time seconds"
53 |
--------------------------------------------------------------------------------
/NCEP/gc_runfcst_hera.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --nodes=1
3 | #SBATCH --account=nems
4 | #SBATCH --cpus-per-task=40
5 | #SBATCH --time=4:00:00
6 | #SBATCH --job-name=graphcast
7 | #SBATCH --output=gc_output.txt
8 | #SBATCH --error=gc_error.txt
9 | #SBATCH --partition=hera
10 |
11 |
12 | # load necessary modules
13 | module use /scratch1/NCEPDEV/nems/role.epic/spack-stack/spack-stack-1.6.0/envs/unified-env-rocky8/install/modulefiles/Core
14 | module load stack-intel
15 | module load wgrib2
16 | module load awscli-v2
17 | module list
18 |
19 |
20 | # Get the UTC hour and calculate the time in the format yyyymmddhh
21 | current_hour=$(date -u +%H)
22 | current_hour=$((10#$current_hour))
23 |
24 | if (( $current_hour >= 0 && $current_hour < 6 )); then
25 | datetime=$(date -u -d 'today 00:00')
26 | elif (( $current_hour >= 6 && $current_hour < 12 )); then
27 | datetime=$(date -u -d 'today 06:00')
28 | elif (( $current_hour >= 12 && $current_hour < 18 )); then
29 | datetime=$(date -u -d 'today 12:00')
30 | else
31 | datetime=$(date -u -d 'today 18:00')
32 | fi
33 |
34 | # Current state: 12 hours before the cycle time; previous state: 6 hours before that
35 | #curr_datetime=$(date -u -d "$time" +'%Y%m%d%H')
36 | curr_datetime=$( date -d "$datetime 12 hour ago" "+%Y%m%d%H" )
37 | prev_datetime=$( date -d "$datetime 18 hour ago" "+%Y%m%d%H" )
38 |
39 | echo "Current state: $curr_datetime"
40 | echo "6 hours earlier state: $prev_datetime"
41 |
42 | forecast_length=40
43 | echo "forecast length: $forecast_length"
44 |
45 | num_pressure_levels=13
46 | echo "number of pressure levels: $num_pressure_levels"
47 |
48 | # Activate Conda environment
49 | source /scratch1/NCEPDEV/nems/AIML/miniconda3/etc/profile.d/conda.sh
50 | conda activate mlwp
51 |
52 | cd /scratch1/NCEPDEV/nems/AIML/graphcast/NCEP/
53 |
54 | start_time=$(date +%s)
55 | echo "start runing graphcast to get real time 10-days forecasts for: $curr_datetime"
56 | # Run another Python script
57 | python3 run_graphcast.py -i source-gdas_date-"$curr_datetime"_res-0.25_levels-"$num_pressure_levels"_steps-2.nc -w /scratch1/NCEPDEV/nems/AIML/gc_weights -l "$forecast_length" -p "$num_pressure_levels"
58 |
59 | end_time=$(date +%s) # Record the end time in seconds since the epoch
60 |
61 | # Calculate and print the execution time
62 | execution_time=$((end_time - start_time))
63 | echo "Execution time for graphcast: $execution_time seconds"
64 |
--------------------------------------------------------------------------------
/NCEP/gcjob_13pl_cloud.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash --login
2 | #SBATCH --nodes=1
3 | #SBATCH --cpus-per-task=30 # Use all available CPU cores
4 | #SBATCH --time=4:00:00 # Adjust this to your estimated run time
5 | #SBATCH --job-name=graphcast
6 | #SBATCH --output=gc_13pl_output.txt
7 | #SBATCH --error=gc_13pl_error.txt
8 | #SBATCH --partition=compute
9 |
10 | # load module lib
11 | # source /etc/profile.d/modules.sh
12 |
13 | # load necessary modules
14 | module use /contrib/spack-stack/envs/ufswm/install/modulefiles/Core/
15 | module load stack-intel
16 | module load wgrib2
17 | module list
18 |
19 | # Get the UTC hour and calculate the time in the format yyyymmddhh
20 | current_hour=$(date -u +%H)
21 | current_hour=$((10#$current_hour))
22 |
23 | if (( $current_hour >= 0 && $current_hour < 6 )); then
24 | datetime=$(date -u -d 'today 00:00')
25 | elif (( $current_hour >= 6 && $current_hour < 12 )); then
26 | datetime=$(date -u -d 'today 06:00')
27 | elif (( $current_hour >= 12 && $current_hour < 18 )); then
28 | datetime=$(date -u -d 'today 12:00')
29 | else
30 | datetime=$(date -u -d 'today 18:00')
31 | fi
32 |
33 | # Current state: 12 hours before the cycle time; previous state: 6 hours before that
34 | #curr_datetime=$(date -u -d "$time" +'%Y%m%d%H')
35 | curr_datetime=$( date -d "$datetime 12 hour ago" "+%Y%m%d%H" )
36 | prev_datetime=$( date -d "$datetime 18 hour ago" "+%Y%m%d%H" )
37 |
38 | echo "Current state: $curr_datetime"
39 | echo "6 hours earlier state: $prev_datetime"
40 |
41 | forecast_length=64
42 | echo "forecast length: $forecast_length"
43 |
44 | num_pressure_levels=13
45 | echo "number of pressure levels: $num_pressure_levels"
46 |
47 | # Set Miniconda path
48 | #export PATH="/contrib/Sadegh.Tabas/miniconda3/bin:$PATH"
49 |
50 | # Activate Conda environment
51 | source /contrib/Sadegh.Tabas/miniconda3/etc/profile.d/conda.sh
52 | conda activate mlwp
53 |
54 | # going to the model directory
55 | cd /contrib/Sadegh.Tabas/operational/graphcast/NCEP/
56 |
57 | start_time=$(date +%s)
58 | echo "start runing gdas utility to generate graphcast inputs for: $curr_datetime"
59 | # Run the Python script gdas.py with the calculated times
60 | python3 gdas_utility.py "$prev_datetime" "$curr_datetime" -l "$num_pressure_levels"
61 |
62 | end_time=$(date +%s) # Record the end time in seconds since the epoch
63 |
64 | # Calculate and print the execution time
65 | execution_time=$((end_time - start_time))
66 | echo "Execution time for gdas_utility.py: $execution_time seconds"
67 |
68 | start_time=$(date +%s)
69 | echo "start runing graphcast to get real time 10-days forecasts for: $curr_datetime"
70 | # Run another Python script
71 | python3 run_graphcast.py -i source-gdas_date-"$curr_datetime"_res-0.25_levels-"$num_pressure_levels"_steps-2.nc -w /contrib/graphcast/NCEP -l "$forecast_length" -p "$num_pressure_levels" -u yes -k no
72 |
73 | end_time=$(date +%s) # Record the end time in seconds since the epoch
74 |
75 | # Calculate and print the execution time
76 | execution_time=$((end_time - start_time))
77 | echo "Execution time for graphcast: $execution_time seconds"
78 |
--------------------------------------------------------------------------------
/NCEP/gcjob_37pl_cloud.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash --login
2 | #SBATCH --nodes=1
3 | #SBATCH --cpus-per-task=30 # Use all available CPU cores
4 | #SBATCH --time=4:00:00 # Adjust this to your estimated run time
5 | #SBATCH --job-name=graphcast
6 | #SBATCH --output=gc_37pl_output.txt
7 | #SBATCH --error=gc_37pl_error.txt
8 | #SBATCH --partition=compute
9 |
10 | # load module lib
11 | # source /etc/profile.d/modules.sh
12 |
13 | # load necessary modules
14 | module use /contrib/spack-stack/envs/ufswm/install/modulefiles/Core/
15 | module load stack-intel
16 | module load wgrib2
17 | module list
18 |
19 | # Get the UTC hour and calculate the time in the format yyyymmddhh
20 | current_hour=$(date -u +%H)
21 | current_hour=$((10#$current_hour))
22 |
23 | if (( $current_hour >= 0 && $current_hour < 6 )); then
24 | datetime=$(date -u -d 'today 00:00')
25 | elif (( $current_hour >= 6 && $current_hour < 12 )); then
26 | datetime=$(date -u -d 'today 06:00')
27 | elif (( $current_hour >= 12 && $current_hour < 18 )); then
28 | datetime=$(date -u -d 'today 12:00')
29 | else
30 | datetime=$(date -u -d 'today 18:00')
31 | fi
32 |
33 | # Current state: 12 hours before the cycle time; previous state: 6 hours before that
34 | #curr_datetime=$(date -u -d "$time" +'%Y%m%d%H')
35 | curr_datetime=$( date -d "$datetime 12 hour ago" "+%Y%m%d%H" )
36 | prev_datetime=$( date -d "$datetime 18 hour ago" "+%Y%m%d%H" )
37 |
38 | echo "Current state: $curr_datetime"
39 | echo "6 hours earlier state: $prev_datetime"
40 |
41 | forecast_length=40
42 | echo "forecast length: $forecast_length"
43 |
44 | num_pressure_levels=37
45 | echo "number of pressure levels: $num_pressure_levels"
46 |
47 | # Set Miniconda path
48 | #export PATH="/contrib/Sadegh.Tabas/miniconda3/bin:$PATH"
49 |
50 | # Activate Conda environment
51 | source /contrib/Sadegh.Tabas/miniconda3/etc/profile.d/conda.sh
52 | conda activate mlwp
53 |
54 | # going to the model directory
55 | cd /contrib/Sadegh.Tabas/operational/graphcast/NCEP/
56 |
57 | start_time=$(date +%s)
58 | echo "start runing gdas utility to generate graphcast inputs for: $curr_datetime"
59 | # Run the Python script gdas.py with the calculated times
60 | python3 gdas_utility.py "$prev_datetime" "$curr_datetime" -l "$num_pressure_levels"
61 |
62 | end_time=$(date +%s) # Record the end time in seconds since the epoch
63 |
64 | # Calculate and print the execution time
65 | execution_time=$((end_time - start_time))
66 | echo "Execution time for gdas_utility.py: $execution_time seconds"
67 |
68 | start_time=$(date +%s)
69 | echo "start runing graphcast to get real time 10-days forecasts for: $curr_datetime"
70 | # Run another Python script
71 | python3 run_graphcast.py -i source-gdas_date-"$curr_datetime"_res-0.25_levels-"$num_pressure_levels"_steps-2.nc -w /contrib/graphcast/NCEP -l "$forecast_length" -p "$num_pressure_levels" -u yes -k no
72 |
73 | end_time=$(date +%s) # Record the end time in seconds since the epoch
74 |
75 | # Calculate and print the execution time
76 | execution_time=$((end_time - start_time))
77 | echo "Execution time for graphcast: $execution_time seconds"
78 |
--------------------------------------------------------------------------------
/NCEP/run_graphcast.py:
--------------------------------------------------------------------------------
1 | '''
2 | Description: Script to call the graphcast model using gdas products
3 | Author: Sadegh Sadeghi Tabas (sadegh.tabas@noaa.gov)
4 | Revision history:
5 | -20231218: Sadegh Tabas, initial code
6 | -20240118: Sadegh Tabas, S3 bucket module to upload data, adding forecast length, Updating batch dataset to account for forecast length
7 | -20240125: Linlin Cui, added a capability to save output as grib2 format
8 | -20240205: Sadegh Tabas, made the code clearer, added 37 pressure level option, updated upload to s3
9 | -20240731: Sadegh Tabas, added grib2 file for F000
10 | -20240815: Sadegh Tabas, update the directory of fine tuned model parameters
11 | '''
12 | import os
13 | import argparse
14 | from datetime import timedelta
15 | import dataclasses
16 | import functools
17 | import re
18 | import haiku as hk
19 | import jax
20 | import numpy as np
21 | import xarray
22 | import boto3
23 | import pandas as pd
24 |
25 | from graphcast import autoregressive
26 | from graphcast import casting
27 | from graphcast import checkpoint
28 | from graphcast import data_utils
29 | from graphcast import graphcast
30 | from graphcast import normalization
31 | from graphcast import rollout
32 |
33 | from utils.nc2grib import Netcdf2Grib
34 |
35 | class GraphCastModel:
36 | def __init__(self, pretrained_model_path, gdas_data_path, output_dir=None, num_pressure_levels=13, forecast_length=40):
37 | self.pretrained_model_path = pretrained_model_path
38 | self.gdas_data_path = gdas_data_path
39 | self.forecast_length = forecast_length
40 | self.num_pressure_levels = num_pressure_levels
41 |
42 | if output_dir is None:
43 | self.output_dir = os.path.join(os.getcwd(), f"forecasts_{str(self.num_pressure_levels)}_levels") # Use current directory if not specified
44 | else:
45 | self.output_dir = os.path.join(output_dir, f"forecasts_{str(self.num_pressure_levels)}_levels")
46 | os.makedirs(self.output_dir, exist_ok=True)
47 |
48 | self.params = None
49 | self.state = {}
50 | self.model_config = None
51 | self.task_config = None
52 | self.diffs_stddev_by_level = None
53 | self.mean_by_level = None
54 | self.stddev_by_level = None
55 | self.current_batch = None
56 | self.inputs = None
57 | self.targets = None
58 | self.forcings = None
59 | self.s3_bucket_name = "noaa-nws-graphcastgfs-pds"
60 | self.dates = None
61 |
62 |
63 | def load_pretrained_model(self):
64 | """Load pre-trained GraphCast model."""
65 | if self.num_pressure_levels==13:
66 | model_weights_path = f"{self.pretrained_model_path}/params/GCGFSv2_finetuned - GDAS - ERA5 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz"
67 | else:
68 | model_weights_path = f"{self.pretrained_model_path}/params/GraphCast - ERA5 1979-2017 - resolution 0.25 - pressure levels 37 - mesh 2to6 - precipitation input and output.npz"
69 |
70 | with open(model_weights_path, "rb") as f:
71 | ckpt = checkpoint.load(f, graphcast.CheckPoint)
72 | self.params = ckpt.params
73 | self.state = {}
74 | self.model_config = ckpt.model_config
75 | self.task_config = ckpt.task_config
76 |
77 | def load_gdas_data(self):
78 | """Load GDAS data."""
79 | #with open(gdas_data_path, "rb") as f:
80 | # self.current_batch = xarray.load_dataset(f).compute()
81 | self.current_batch = xarray.load_dataset(self.gdas_data_path).compute()
82 | self.dates = pd.to_datetime(self.current_batch.datetime.values)
83 |
84 | if (self.forecast_length + 2) > len(self.current_batch['time']):
85 | print('Updating batch dataset to account for forecast length')
86 |
87 | diff = int(self.forecast_length + 2 - len(self.current_batch['time']))
88 | ds = self.current_batch
89 |
90 | # time and datetime update
91 | curr_time_range = ds['time'].values.astype('timedelta64[ns]')
92 | new_time_range = (np.arange(len(curr_time_range) + diff) * np.timedelta64(6, 'h')).astype('timedelta64[ns]')
93 | ds = ds.reindex(time = new_time_range)
94 | curr_datetime_range = ds['datetime'][0].values.astype('datetime64[ns]')
95 | new_datetime_range = curr_datetime_range[0] + np.arange(len(curr_time_range) + diff) * np.timedelta64(6, 'h')
96 | ds['datetime'][0]= new_datetime_range
97 |
98 | self.current_batch = ds
99 | print('batch dataset updated')
100 |
101 |
102 | def extract_inputs_targets_forcings(self):
103 | """Extract inputs, targets, and forcings from the loaded data."""
104 | self.inputs, self.targets, self.forcings = data_utils.extract_inputs_targets_forcings(
105 | self.current_batch, target_lead_times=slice("6h", f"{self.forecast_length*6}h"), **dataclasses.asdict(self.task_config)
106 | )
107 |
108 | def load_normalization_stats(self):
109 | """Load normalization stats."""
110 |
111 | diffs_stddev_path = f"{self.pretrained_model_path}/stats/diffs_stddev_by_level.nc"
112 | mean_path = f"{self.pretrained_model_path}/stats/mean_by_level.nc"
113 | stddev_path = f"{self.pretrained_model_path}/stats/stddev_by_level.nc"
114 |
115 | with open(diffs_stddev_path, "rb") as f:
116 | self.diffs_stddev_by_level = xarray.load_dataset(f).compute()
117 | with open(mean_path, "rb") as f:
118 | self.mean_by_level = xarray.load_dataset(f).compute()
119 | with open(stddev_path, "rb") as f:
120 | self.stddev_by_level = xarray.load_dataset(f).compute()
121 |
122 | # Jax doesn't seem to like passing configs as args through the jit. Passing it
123 | # in via partial (instead of capture by closure) forces jax to invalidate the
124 | # jit cache if you change configs.
125 | def _with_configs(self, fn):
126 | return functools.partial(fn, model_config=self.model_config, task_config=self.task_config,)
127 |
128 |     # Always pass params and state, so the usage below is simpler
129 | def _with_params(self, fn):
130 | return functools.partial(fn, params=self.params, state=self.state)
131 |
132 |     # DeepMind models aren't stateful, so the state is always empty; just return the
133 |     # predictions. This is required by the rollout code, and generally simpler.
134 | @staticmethod
135 | def _drop_state(fn):
136 | return lambda **kw: fn(**kw)[0]
137 |
138 | def load_model(self):
139 | def construct_wrapped_graphcast(model_config, task_config):
140 | """Constructs and wraps the GraphCast Predictor."""
141 | # Deeper one-step predictor.
142 | predictor = graphcast.GraphCast(model_config, task_config)
143 |
144 |             # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion
145 |             # from/to float32 to/from BFloat16.
146 | predictor = casting.Bfloat16Cast(predictor)
147 |
148 | # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
149 | # BFloat16 happens after applying normalization to the inputs/targets.
150 | predictor = normalization.InputsAndResiduals(predictor, diffs_stddev_by_level=self.diffs_stddev_by_level, mean_by_level=self.mean_by_level, stddev_by_level=self.stddev_by_level,)
151 |
152 | # Wraps everything so the one-step model can produce trajectories.
153 | predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True,)
154 | return predictor
155 |
156 | @hk.transform_with_state
157 | def run_forward(model_config, task_config, inputs, targets_template, forcings,):
158 | predictor = construct_wrapped_graphcast(model_config, task_config)
159 | return predictor(inputs, targets_template=targets_template, forcings=forcings,)
160 |
161 | jax.jit(self._with_configs(run_forward.init))
162 | self.model = self._drop_state(self._with_params(jax.jit(self._with_configs(run_forward.apply))))
163 |
164 |
165 | def get_predictions(self):
166 | """Run GraphCast and save forecasts to a NetCDF file."""
167 |
168 |         print(f"Start running GraphCast for {self.forecast_length} steps --> {self.forecast_length*6} hours.")
169 | self.load_model()
170 |
171 | # output = self.model(self.model ,rng=jax.random.PRNGKey(0), inputs=self.inputs, targets_template=self.targets * np.nan, forcings=self.forcings,)
172 | forecasts = rollout.chunked_prediction(self.model, rng=jax.random.PRNGKey(0), inputs=self.inputs, targets_template=self.targets * np.nan, forcings=self.forcings,)
173 |
174 | # filename = f"forecasts_levels-{self.num_pressure_levels}_steps-{self.forecast_length}.nc"
175 | # output_netcdf = os.path.join(self.output_dir, filename)
176 |
177 | # save forecasts
178 | # forecasts.to_netcdf(output_netcdf)
179 | # print (f"GraphCast run completed successfully, you can find the GraphCast forecasts in the following directory:\n {output_netcdf}")
180 |
181 | self.save_grib2(forecasts)
182 |
183 | def save_grib2(self, forecasts):
184 | converter = Netcdf2Grib()
185 |
186 | # Call and save f000 in grib2
187 | ds = self.current_batch
188 | ds = ds.drop_vars(['geopotential_at_surface','land_sea_mask', 'total_precipitation_6hr'])
189 | for var in ds.data_vars:
190 | if 'long_name' in ds[var].attrs:
191 | del ds[var].attrs['long_name']
192 | ds = ds.isel(time=slice(1, 2))
193 | ds['time'] = ds['time'] - pd.Timedelta(hours=6)
194 |
195 | converter.save_grib2(self.dates, ds, self.output_dir)
196 |
197 | # Call and save forecasts in grib2
198 | converter.save_grib2(self.dates, forecasts, self.output_dir)
199 |
200 |
201 | def upload_to_s3(self, keep_data):
202 | s3 = boto3.client('s3')
203 |
204 | # Extract date and time information from the input file name
205 | input_file_name = os.path.basename(self.gdas_data_path)
206 |
207 | date_start = input_file_name.find("date-")
208 |
209 | # Check if "date-" is found in the input_file_name
210 | if date_start != -1:
211 | date_start += len("date-") # Move to the end of "date-"
212 | date = input_file_name[date_start:date_start + 8] # Extract 8 characters as the date
213 |
214 | time_start = date_start + 8 # Move to the character after the date
215 | time = input_file_name[time_start:time_start + 2] # Extract 2 characters as the time
216 |
217 |
218 | # Define S3 key paths for input and output files
219 |         input_s3_key = f'graphcastgfs.{date}/{time}/input/{input_file_name}'
220 |
221 | # Upload input file to S3
222 | s3.upload_file(self.gdas_data_path, self.s3_bucket_name, input_s3_key)
223 |
224 | # Upload output files to S3
225 | # Iterate over all files in the local directory and upload each one to S3
226 | s3_prefix = f'graphcastgfs.{date}/{time}/forecasts_{self.num_pressure_levels}_levels'
227 |
228 | for root, dirs, files in os.walk(self.output_dir):
229 |
230 | for file in files:
231 | local_path = os.path.join(root, file)
232 | relative_path = os.path.relpath(local_path, self.output_dir)
233 | s3_path = os.path.join(s3_prefix, relative_path)
234 |
235 | # Upload the file
236 | s3.upload_file(local_path, self.s3_bucket_name, s3_path)
237 |
238 | print("Upload to s3 bucket completed.")
239 |
240 | # Delete local files if keep_data is False
241 | if not keep_data:
242 | # Remove forecast data from the specified directory
243 | print("Removing input and forecast data from the specified directory...")
244 | try:
245 | os.system(f"rm -rf {self.output_dir}")
246 | os.remove(self.gdas_data_path)
247 | print("Local input and output files deleted.")
248 | except Exception as e:
249 | print(f"Error removing input and forecast data: {str(e)}")
250 |
251 |
252 |
253 | if __name__ == "__main__":
254 | parser = argparse.ArgumentParser(description="Run GraphCast model.")
255 | parser.add_argument("-i", "--input", help="input file path (including file name)", required=True)
256 | parser.add_argument("-w", "--weights", help="parent directory of the graphcast params and stats", required=True)
257 | parser.add_argument("-l", "--length", help="length of forecast (6-hourly), an integer number in range [1, 40]", required=True)
258 | parser.add_argument("-o", "--output", help="output directory", default=None)
259 | parser.add_argument("-p", "--pressure", help="number of pressure levels", default=13)
260 | parser.add_argument("-u", "--upload", help="upload input data as well as forecasts to noaa s3 bucket (yes or no)", default = "no")
261 | parser.add_argument("-k", "--keep", help="keep input and output after uploading to noaa s3 bucket (yes or no)", default = "no")
262 |
263 | args = parser.parse_args()
264 | runner = GraphCastModel(args.weights, args.input, args.output, int(args.pressure), int(args.length))
265 |
266 | runner.load_pretrained_model()
267 | runner.load_gdas_data()
268 | runner.extract_inputs_targets_forcings()
269 | runner.load_normalization_stats()
270 | runner.get_predictions()
271 |
272 | upload_data = args.upload.lower() == "yes"
273 | keep_data = args.keep.lower() == "yes"
274 |
275 | if upload_data:
276 | runner.upload_to_s3(keep_data)
277 |
--------------------------------------------------------------------------------
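For orientation, here is a hypothetical way to drive the GraphCastModel class above directly from Python rather than through the command line; it mirrors the script's __main__ block, and every path is a placeholder that must point to an existing weights directory (containing params/ and stats/) and a GDAS input netCDF such as the one produced by gdas_utility.py.

from run_graphcast import GraphCastModel

runner = GraphCastModel(
    pretrained_model_path="/path/to/graphcast_weights",  # must contain params/ and stats/
    gdas_data_path="/path/to/gdas_input.nc",             # placeholder input batch
    output_dir="/path/to/output",
    num_pressure_levels=13,
    forecast_length=40,  # 40 six-hourly steps = 10 days
)
runner.load_pretrained_model()
runner.load_gdas_data()
runner.extract_inputs_targets_forcings()
runner.load_normalization_stats()
runner.get_predictions()  # writes GRIB2 files under /path/to/output/forecasts_13_levels
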
/NCEP/upload_to_s3bucket.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import boto3
5 |
6 | def upload_to_s3(forecast_datetime, output_path, level, keep):
7 |     s3_bucket_name = "noaa-nws-graphcastgfs-pds"
8 | if output_path is None:
9 | output_dir = os.path.join(os.getcwd(), f"forecasts_{level}_levels") # Use current directory if not specified
10 | else:
11 | output_dir = os.path.join(output_path, f"forecasts_{level}_levels")
12 |
13 | s3 = boto3.client('s3')
14 |
15 | date = forecast_datetime[0:8]
16 | time = forecast_datetime[8:10]
17 |
18 | # Upload output files to S3
19 | # Iterate over all files in the local directory and upload each one to S3
20 | s3_prefix = f'graphcastgfs.{date}/{time}/forecasts_{level}_levels'
21 |
22 |     for root, dirs, files in os.walk(output_dir):
23 | for file in files:
24 | local_path = os.path.join(root, file)
25 |             relative_path = os.path.relpath(local_path, output_dir)
26 | s3_path = os.path.join(s3_prefix, relative_path)
27 |
28 | # Upload the file
29 |             s3.upload_file(local_path, s3_bucket_name, s3_path)
30 |
31 | print("Upload to s3 bucket completed.")
32 |
33 |     # Delete local files if keep is False
34 |     if not keep:
35 |         # Remove forecast data from the specified directory
36 |         print("Removing local forecast data...")
37 |         try:
38 |             os.system(f"rm -rf {output_dir}")
39 |             print("Local forecast data removed.")
40 |         except Exception as e:
41 |             print(f"Error removing forecast data: {str(e)}")
42 |
43 |
44 |
45 | if __name__ == "__main__":
46 |
47 |     parser = argparse.ArgumentParser(description="upload forecast output to s3 bucket")
48 |     parser.add_argument("-d", "--datetime", help="forecast datetime (YYYYMMDDHH)", required=True)
49 |     parser.add_argument("-l", "--level", help="number of pressure levels", default=13)
50 |     parser.add_argument("-o", "--output", help="parent directory of the forecast output", default=None)
51 |
52 | args = parser.parse_args()
53 | keep = False
54 |
55 |     upload_to_s3(args.datetime, args.output, str(args.level), keep)
56 |
--------------------------------------------------------------------------------
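A hypothetical call of upload_to_s3 from Python after a 13-level run; it assumes AWS credentials with write access to the noaa-nws-graphcastgfs-pds bucket are configured in the environment, and the date and paths are placeholders.

from upload_to_s3bucket import upload_to_s3

upload_to_s3(
    forecast_datetime="2024010100",   # YYYYMMDDHH, split into date and cycle
    output_path="/path/to/output",    # parent directory of forecasts_13_levels/
    level="13",
    keep=True,                        # keep the local files after uploading
)
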
/NCEP/utils/nc2grib.py:
--------------------------------------------------------------------------------
1 | """ Utility for converting netcdf data to grib2.
2 |
3 | History:
4 | 01/26/2024: Linlin Cui (linlin.cui@noaa.gov), added function save_grib2
5 | 02/05/2024: Sadegh Tabas, updated the utility to an object-oriented format
6 | 04/25/2024: Sadegh Tabas, generate grib2 index files
7 | 07/03/2024: Sadegh Tabas, sorted grib2 variables
8 | """
9 |
10 | import os
11 | from datetime import datetime, timedelta
12 | import glob
13 | import subprocess
14 | import cf_units
15 | import iris
16 | import iris_grib
17 | import eccodes
18 |
19 | class Netcdf2Grib:
20 | def __init__(self):
21 | self.ATTR_MAPS = {
22 | '10m_u_component_of_wind': [10, 'x_wind', 'm s**-1'],
23 | '10m_v_component_of_wind': [10, 'y_wind', 'm s**-1'],
24 | 'mean_sea_level_pressure': [0, 'air_pressure_at_sea_level', 'Pa'],
25 | '2m_temperature': [2, 'air_temperature', 'K'],
26 | 'total_precipitation_6hr': [0, 'precipitation_amount', 'kg m**-2'],
27 | 'total_precipitation_cumsum': [0, 'precipitation_amount', 'kg m**-2'],
28 | 'vertical_velocity': [None, 'lagrangian_tendency_of_air_pressure', 'Pa s**-1'],
29 | 'specific_humidity': [None, 'specific_humidity', 'kg kg**-1'],
30 | 'temperature': [None, 'air_temperature', 'K'],
31 | 'geopotential': [None, 'geopotential_height', 'm'],
32 | 'u_component_of_wind': [None, 'x_wind', 'm s**-1'],
33 | 'v_component_of_wind': [None, 'y_wind', 'm s**-1'],
34 | }
35 |
36 | def tweaked_messages(self, cube, time_range):
37 | """
38 | Adjust GRIB messages based on cube properties.
39 | """
40 | for cube, grib_message in iris_grib.save_pairs_from_cube(cube):
41 | if cube.standard_name == 'precipitation_amount':
42 | eccodes.codes_set(grib_message, 'stepType', 'accum')
43 | eccodes.codes_set(grib_message, 'stepRange', time_range)
44 | eccodes.codes_set(grib_message, 'discipline', 0)
45 | eccodes.codes_set(grib_message, 'parameterCategory', 1)
46 | eccodes.codes_set(grib_message, 'parameterNumber', 8)
47 | eccodes.codes_set(grib_message, 'typeOfFirstFixedSurface', 1)
48 | eccodes.codes_set(grib_message, 'typeOfStatisticalProcessing', 1)
49 | elif cube.standard_name == 'air_pressure_at_sea_level':
50 | eccodes.codes_set(grib_message, 'discipline', 0)
51 | eccodes.codes_set(grib_message, 'parameterCategory', 3)
52 | eccodes.codes_set(grib_message, 'parameterNumber', 1)
53 | eccodes.codes_set(grib_message, 'typeOfFirstFixedSurface', 101)
54 | yield grib_message
55 |
56 | def save_grib2(self, dates, forecasts, outdir):
57 | """
58 | Convert netCDF file to GRIB2 format file.
59 | Args:
60 | dates: array of datetime object, from the source file
61 | forecasts: xarray forecasts dataset
62 | outdir: output directory
63 |
64 | Returns:
65 | No return values, will save to grib2 file
66 | """
67 | forecasts = forecasts.reindex(lat=list(reversed(forecasts.lat)))
68 |
69 | for var in forecasts.variables:
70 | if 'batch' in forecasts[var].dims:
71 | forecasts[var] = forecasts[var].squeeze(dim='batch')
72 |
73 | # Update units
74 | forecasts['level'] = forecasts['level'] * 100
75 | forecasts['level'].attrs['long_name'] = 'pressure'
76 | forecasts['level'].attrs['units'] = 'Pa'
77 | forecasts['geopotential'] = forecasts['geopotential'] / 9.80665
78 | if 'total_precipitation_6hr' in forecasts:
79 | forecasts['total_precipitation_6hr'] = forecasts['total_precipitation_6hr'] * 1000
80 | forecasts['total_precipitation_cumsum'] = forecasts['total_precipitation_6hr'].cumsum(axis=0)
81 |
82 | filename = os.path.join(outdir, "forecast_to_grib2.nc")
83 | forecasts.to_netcdf(filename)
84 |
85 | # Load cubes from netCDF file
86 | cubes = iris.load(filename)
87 | times = cubes[0].coord('time').points
88 | forecast_starttime = dates[0][1]
89 | cycle = forecast_starttime.hour
90 | print(f'Forecast start time is {forecast_starttime}')
91 |
92 | datevectors = [forecast_starttime + timedelta(hours=int(t)) for t in times]
93 |
94 | time_fmt_str = '00:00:00'
95 | time_unit_str = f"Hours since {forecast_starttime.strftime('%Y-%m-%d %H:00:00')}"
96 | time_coord = cubes[0].coord('time')
97 | new_time_unit = cf_units.Unit(time_unit_str, calendar=cf_units.CALENDAR_STANDARD)
98 | new_time_points = [new_time_unit.date2num(dt) for dt in datevectors]
99 | new_time_coord = iris.coords.DimCoord(new_time_points, standard_name='time', units=new_time_unit)
100 |
101 | for date in datevectors:
102 | print(f"Processing for time {date.strftime('%Y-%m-%d %H:00:00')}")
103 | hrs = int((date - forecast_starttime).total_seconds() // 3600)
104 | outfile = os.path.join(outdir, f'graphcastgfs.t{cycle:02d}z.pgrb2.0p25.f{hrs:03d}')
105 | print(outfile)
106 |
107 | for cube in sorted(cubes, key=lambda cube: cube.name()):
108 | var_name = cube.name()
109 |
110 | # Adjust cube for different variables
111 | time_coord_dim = cube.coord_dims('time')
112 | cube.remove_coord('time')
113 | cube.add_dim_coord(new_time_coord, time_coord_dim)
114 |
115 | hour_6 = iris.Constraint(time=iris.time.PartialDateTime(month=date.month, day=date.day, hour=date.hour))
116 | cube_slice = cube.extract(hour_6)
117 | cube_slice.coord('latitude').coord_system = iris.coord_systems.GeogCS(4326)
118 | cube_slice.coord('longitude').coord_system = iris.coord_systems.GeogCS(4326)
119 |
120 | if len(cube_slice.data.shape) == 3:
121 | levels = cube_slice.coord('pressure').points
122 | for level in levels:
123 | cube_slice_level = cube_slice.extract(iris.Constraint(pressure=level))
124 | cube_slice_level.add_aux_coord(iris.coords.DimCoord(hrs, standard_name='forecast_period', units='hours'))
125 | cube_slice_level.standard_name = self.ATTR_MAPS[var_name][1]
126 | cube_slice_level.units = self.ATTR_MAPS[var_name][2]
127 | iris.save(cube_slice_level, outfile, saver='grib2', append=True)
128 | else:
129 | cube_slice.add_aux_coord(iris.coords.DimCoord(hrs, standard_name='forecast_period', units='hours'))
130 | cube_slice.standard_name = self.ATTR_MAPS[var_name][1]
131 | cube_slice.units = self.ATTR_MAPS[var_name][2]
132 |
133 | if var_name not in ['mean_sea_level_pressure', 'total_precipitation_6hr', 'total_precipitation_cumsum']:
134 | cube_slice.add_aux_coord(iris.coords.DimCoord(self.ATTR_MAPS[var_name][0], standard_name='height', units='m'))
135 | iris.save(cube_slice, outfile, saver='grib2', append=True)
136 | elif var_name == 'total_precipitation_6hr':
137 | iris_grib.save_messages(self.tweaked_messages(cube_slice, f'{hrs-6}-{hrs}'), outfile, append=True)
138 | elif var_name == 'total_precipitation_cumsum':
139 | iris_grib.save_messages(self.tweaked_messages(cube_slice, f'0-{hrs}'), outfile, append=True)
140 | elif var_name == 'mean_sea_level_pressure':
141 | cube_slice.add_aux_coord(iris.coords.DimCoord(self.ATTR_MAPS[var_name][0], standard_name='altitude', units='m'))
142 | iris_grib.save_messages(self.tweaked_messages(cube_slice, f'{hrs-6}-{hrs}'), outfile, append=True)
143 |
144 | # Use wgrib2 to generate index files
145 | output_idx_file = f"{outfile}.idx"
146 |
147 | # Construct the wgrib2 command
148 | wgrib2_command = ['wgrib2', '-s', outfile]
149 |
150 | try:
151 | # Open the output file for writing
152 | with open(output_idx_file, "w") as f_out:
153 | # Execute the wgrib2 command and redirect stdout to the output file
154 | subprocess.run(wgrib2_command, stdout=f_out, check=True)
155 |
156 | print(f"Index file created successfully: {output_idx_file}")
157 |
158 | except subprocess.CalledProcessError as e:
159 | print(f"Error running wgrib2 command: {e}")
160 |
161 |
162 | # Remove intermediate netCDF file
163 | if os.path.isfile(filename):
164 | print(f'Deleting intermediate nc file {filename}: ')
165 | os.remove(filename)
166 |
167 | # subset grib2 files
168 | def subset_grib2(indir=None):
169 | files = glob.glob(f'{indir}/graphcastgfs.*')
170 | files.sort()
171 |
172 | outdir = os.path.join(indir, 'north_america')
173 | os.makedirs(outdir, exist_ok=True)
174 |
175 | lonMin, lonMax, latMin, latMax = 61.0, 299.0, -37.0, 37.0
176 | for grbfile in files:
177 | outfile = f"{outdir}/{grbfile.split('/')[-1]}"
178 | command = ['wgrib2', grbfile, '-small_grib', f'{lonMin}:{lonMax}', f'{latMin}:{latMax}', outfile]
179 | subprocess.run(command, check=True)
180 |
181 |
182 | # subset_grib2(outdir)
183 |
184 |
--------------------------------------------------------------------------------
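A minimal sketch of reusing Netcdf2Grib on its own, mirroring how run_graphcast.py calls it; the netCDF paths are placeholders, the input batch supplies the datetime coordinate used for the forecast start time, the forecasts dataset is expected to have the structure GraphCast produces (a 'batch' dimension, 'level' in hPa), and wgrib2 must be on PATH for the .idx files.

import pandas as pd
import xarray

from utils.nc2grib import Netcdf2Grib

batch = xarray.load_dataset("/path/to/gdas_batch.nc")       # hypothetical input batch
forecasts = xarray.load_dataset("/path/to/forecasts.nc")    # hypothetical forecast dataset
dates = pd.to_datetime(batch.datetime.values)               # save_grib2 reads the start time from dates[0][1]

converter = Netcdf2Grib()
converter.save_grib2(dates, forecasts, "/path/to/grib2_out")
# Writes one graphcastgfs.t{cc}z.pgrb2.0p25.f{hhh} file (plus a wgrib2 .idx) per forecast hour.
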
/docs/GenCast_0p25deg_accelerator_scorecard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NOAA-EMC/graphcast/da5477d2dfe9d6af4befe701d7c6873eb4e24eb0/docs/GenCast_0p25deg_accelerator_scorecard.png
--------------------------------------------------------------------------------
/docs/GenCast_0p25deg_attention_implementation_scorecard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NOAA-EMC/graphcast/da5477d2dfe9d6af4befe701d7c6873eb4e24eb0/docs/GenCast_0p25deg_attention_implementation_scorecard.png
--------------------------------------------------------------------------------
/docs/GenCast_1p0deg_Mini_ENS_scorecard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NOAA-EMC/graphcast/da5477d2dfe9d6af4befe701d7c6873eb4e24eb0/docs/GenCast_1p0deg_Mini_ENS_scorecard.png
--------------------------------------------------------------------------------
/docs/local_runtime_popup_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NOAA-EMC/graphcast/da5477d2dfe9d6af4befe701d7c6873eb4e24eb0/docs/local_runtime_popup_1.png
--------------------------------------------------------------------------------
/docs/local_runtime_popup_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NOAA-EMC/graphcast/da5477d2dfe9d6af4befe701d7c6873eb4e24eb0/docs/local_runtime_popup_2.png
--------------------------------------------------------------------------------
/docs/local_runtime_url.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NOAA-EMC/graphcast/da5477d2dfe9d6af4befe701d7c6873eb4e24eb0/docs/local_runtime_url.png
--------------------------------------------------------------------------------
/docs/project.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NOAA-EMC/graphcast/da5477d2dfe9d6af4befe701d7c6873eb4e24eb0/docs/project.png
--------------------------------------------------------------------------------
/docs/provision_tpu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NOAA-EMC/graphcast/da5477d2dfe9d6af4befe701d7c6873eb4e24eb0/docs/provision_tpu.png
--------------------------------------------------------------------------------
/docs/tpu_types.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NOAA-EMC/graphcast/da5477d2dfe9d6af4befe701d7c6873eb4e24eb0/docs/tpu_types.png
--------------------------------------------------------------------------------
/graphcast/casting.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Wrappers that take care of casting."""
15 |
16 | import contextlib
17 | from typing import Any, Mapping, Tuple
18 |
19 | import chex
20 | from graphcast import predictor_base
21 | import haiku as hk
22 | import jax
23 | import jax.numpy as jnp
24 | import numpy as np
25 | import xarray
26 |
27 |
28 | PyTree = Any
29 |
30 |
31 | class Bfloat16Cast(predictor_base.Predictor):
32 | """Wrapper that casts all inputs to bfloat16 and outputs to targets dtype."""
33 |
34 | def __init__(self, predictor: predictor_base.Predictor, enabled: bool = True):
35 | """Inits the wrapper.
36 |
37 | Args:
38 | predictor: predictor being wrapped.
39 | enabled: disables the wrapper if False, for simpler hyperparameter scans.
40 |
41 | """
42 | self._enabled = enabled
43 | self._predictor = predictor
44 |
45 | def __call__(self,
46 | inputs: xarray.Dataset,
47 | targets_template: xarray.Dataset,
48 | forcings: xarray.Dataset,
49 | **kwargs
50 | ) -> xarray.Dataset:
51 | if not self._enabled:
52 | return self._predictor(inputs, targets_template, forcings, **kwargs)
53 |
54 | with bfloat16_variable_view():
55 | predictions = self._predictor(
56 | *_all_inputs_to_bfloat16(inputs, targets_template, forcings),
57 | **kwargs,)
58 |
59 | predictions_dtype = infer_floating_dtype(predictions) # pytype: disable=wrong-arg-types
60 | if predictions_dtype != jnp.bfloat16:
61 | raise ValueError(f'Expected bfloat16 output, got {predictions_dtype}')
62 |
63 | targets_dtype = infer_floating_dtype(targets_template) # pytype: disable=wrong-arg-types
64 | return tree_map_cast(
65 | predictions, input_dtype=jnp.bfloat16, output_dtype=targets_dtype)
66 |
67 | def loss(self,
68 | inputs: xarray.Dataset,
69 | targets: xarray.Dataset,
70 | forcings: xarray.Dataset,
71 | **kwargs,
72 | ) -> predictor_base.LossAndDiagnostics:
73 | if not self._enabled:
74 | return self._predictor.loss(inputs, targets, forcings, **kwargs)
75 |
76 | with bfloat16_variable_view():
77 | loss, scalars = self._predictor.loss(
78 | *_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs)
79 |
80 | if loss.dtype != jnp.bfloat16:
81 | raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}')
82 |
83 | targets_dtype = infer_floating_dtype(targets) # pytype: disable=wrong-arg-types
84 |
85 | # Note that casting back the loss to e.g. float32 should not affect data
86 | # types of the backwards pass, because the first thing the backwards pass
87 | # should do is to go backwards the casting op and cast back to bfloat16
88 | # (and xprofs seem to confirm this).
89 | return tree_map_cast((loss, scalars),
90 | input_dtype=jnp.bfloat16, output_dtype=targets_dtype)
91 |
92 | def loss_and_predictions( # pytype: disable=signature-mismatch # jax-ndarray
93 | self,
94 | inputs: xarray.Dataset,
95 | targets: xarray.Dataset,
96 | forcings: xarray.Dataset,
97 | **kwargs,
98 | ) -> Tuple[predictor_base.LossAndDiagnostics,
99 | xarray.Dataset]:
100 | if not self._enabled:
101 | return self._predictor.loss_and_predictions(inputs, targets, forcings, # pytype: disable=bad-return-type # jax-ndarray
102 | **kwargs)
103 |
104 | with bfloat16_variable_view():
105 | (loss, scalars), predictions = self._predictor.loss_and_predictions(
106 | *_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs)
107 |
108 | if loss.dtype != jnp.bfloat16:
109 | raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}')
110 |
111 | predictions_dtype = infer_floating_dtype(predictions) # pytype: disable=wrong-arg-types
112 | if predictions_dtype != jnp.bfloat16:
113 | raise ValueError(f'Expected bfloat16 output, got {predictions_dtype}')
114 |
115 | targets_dtype = infer_floating_dtype(targets) # pytype: disable=wrong-arg-types
116 | return tree_map_cast(((loss, scalars), predictions),
117 | input_dtype=jnp.bfloat16, output_dtype=targets_dtype)
118 |
119 |
120 | def infer_floating_dtype(data_vars: Mapping[str, chex.Array]) -> np.dtype:
121 | """Infers a floating dtype from an input mapping of data."""
122 | dtypes = {
123 | v.dtype
124 | for k, v in data_vars.items() if jnp.issubdtype(v.dtype, np.floating)}
125 | if len(dtypes) != 1:
126 | dtypes_and_shapes = {
127 | k: (v.dtype, v.shape)
128 | for k, v in data_vars.items() if jnp.issubdtype(v.dtype, np.floating)}
129 | raise ValueError(
130 |         f'Did not find exactly one floating dtype {dtypes} in input variables: '
131 | f'{dtypes_and_shapes}')
132 | return list(dtypes)[0]
133 |
134 |
135 | def _all_inputs_to_bfloat16(
136 | inputs: xarray.Dataset,
137 | targets: xarray.Dataset,
138 | forcings: xarray.Dataset,
139 | ) -> Tuple[xarray.Dataset,
140 | xarray.Dataset,
141 | xarray.Dataset]:
142 | return (inputs.astype(jnp.bfloat16),
143 | jax.tree.map(lambda x: x.astype(jnp.bfloat16), targets),
144 | forcings.astype(jnp.bfloat16))
145 |
146 |
147 | def tree_map_cast(inputs: PyTree, input_dtype: np.dtype, output_dtype: np.dtype,
148 | ) -> PyTree:
149 |   def cast_fn(x):
150 |     # Cast leaves that match input_dtype; leave all other leaves unchanged.
151 |     return x.astype(output_dtype) if x.dtype == input_dtype else x
152 |   return jax.tree.map(cast_fn, inputs)
153 |
154 |
155 | @contextlib.contextmanager
156 | def bfloat16_variable_view(enabled: bool = True):
157 | """Context for Haiku modules with float32 params, but bfloat16 activations.
158 |
159 | It works as follows:
160 | * Every time a variable is requested to be created/set as np.bfloat16,
161 | it will create an underlying float32 variable, instead.
162 |   * Every time a variable is requested as bfloat16, it will check the
163 | variable is of float32 type, and cast the variable to bfloat16.
164 |
165 | Note the gradients are still computed and accumulated as float32, because
166 | the params returned by init are float32, so the gradient function with
167 | respect to the params will already include an implicit casting to float32.
168 |
169 | Args:
170 | enabled: Only enables bfloat16 behavior if True.
171 |
172 | Yields:
173 | None
174 | """
175 |
176 | if enabled:
177 | with hk.custom_creator(
178 | _bfloat16_creator, state=True), hk.custom_getter(
179 | _bfloat16_getter, state=True), hk.custom_setter(
180 | _bfloat16_setter):
181 | yield
182 | else:
183 | yield
184 |
185 |
186 | def _bfloat16_creator(next_creator, shape, dtype, init, context):
187 | """Creates float32 variables when bfloat16 is requested."""
188 | if context.original_dtype == jnp.bfloat16:
189 | dtype = jnp.float32
190 | return next_creator(shape, dtype, init)
191 |
192 |
193 | def _bfloat16_getter(next_getter, value, context):
194 | """Casts float32 to bfloat16 when bfloat16 was originally requested."""
195 | if context.original_dtype == jnp.bfloat16:
196 | assert value.dtype == jnp.float32
197 | value = value.astype(jnp.bfloat16)
198 | return next_getter(value)
199 |
200 |
201 | def _bfloat16_setter(next_setter, value, context):
202 | """Casts bfloat16 to float32 when bfloat16 was originally set."""
203 | if context.original_dtype == jnp.bfloat16:
204 | value = value.astype(jnp.float32)
205 | return next_setter(value)
206 |
--------------------------------------------------------------------------------
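A small illustration (not taken from the library's own tests) of tree_map_cast from the file above: leaves of a pytree whose dtype matches input_dtype are cast to output_dtype, while other leaves pass through unchanged.

import jax.numpy as jnp

from graphcast import casting

tree = {
    "predictions": jnp.ones((2, 3), dtype=jnp.bfloat16),
    "step_count": jnp.zeros((2,), dtype=jnp.int32),
}
out = casting.tree_map_cast(tree, input_dtype=jnp.bfloat16, output_dtype=jnp.float32)
assert out["predictions"].dtype == jnp.float32
assert out["step_count"].dtype == jnp.int32  # non-matching leaves are untouched
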
/graphcast/checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Serialize and deserialize trees."""
15 |
16 | import dataclasses
17 | import io
18 | import types
19 | from typing import Any, BinaryIO, Optional, TypeVar
20 |
21 | import numpy as np
22 |
23 | _T = TypeVar("_T")
24 |
25 |
26 | def dump(dest: BinaryIO, value: Any) -> None:
27 | """Dump a tree of dicts/dataclasses to a file object.
28 |
29 | Args:
30 | dest: a file object to write to.
31 | value: A tree of dicts, lists, tuples and dataclasses of numpy arrays and
32 | other basic types. Unions are not supported, other than Optional/None
33 | which is only supported in dataclasses, not in dicts, lists or tuples.
34 | All leaves must be coercible to a numpy array, and recoverable as a single
35 | arg to a type.
36 | """
37 | buffer = io.BytesIO() # In case the destination doesn't support seeking.
38 | np.savez(buffer, **_flatten(value))
39 | dest.write(buffer.getvalue())
40 |
41 |
42 | def load(source: BinaryIO, typ: type[_T]) -> _T:
43 | """Load from a file object and convert it to the specified type.
44 |
45 | Args:
46 | source: a file object to read from.
47 | typ: a type object that acts as a schema for deserialization. It must match
48 | what was serialized. If a type is Any, it will be returned however numpy
49 | serialized it, which is what you want for a tree of numpy arrays.
50 |
51 | Returns:
52 | the deserialized value as the specified type.
53 | """
54 | return _convert_types(typ, _unflatten(np.load(source)))
55 |
56 |
57 | _SEP = ":"
58 |
59 |
60 | def _flatten(tree: Any) -> dict[str, Any]:
61 | """Flatten a tree of dicts/dataclasses/lists/tuples to a single dict."""
62 | if dataclasses.is_dataclass(tree):
63 | # Don't use dataclasses.asdict as it is recursive so skips dropping None.
64 | tree = {f.name: v for f in dataclasses.fields(tree)
65 | if (v := getattr(tree, f.name)) is not None}
66 | elif isinstance(tree, (list, tuple)):
67 | tree = dict(enumerate(tree))
68 |
69 | assert isinstance(tree, dict)
70 |
71 | flat = {}
72 | for k, v in tree.items():
73 | k = str(k)
74 | assert _SEP not in k
75 | if dataclasses.is_dataclass(v) or isinstance(v, (dict, list, tuple)):
76 | for a, b in _flatten(v).items():
77 | flat[f"{k}{_SEP}{a}"] = b
78 | else:
79 | assert v is not None
80 | flat[k] = v
81 | return flat
82 |
83 |
84 | def _unflatten(flat: dict[str, Any]) -> dict[str, Any]:
85 | """Unflatten a dict to a tree of dicts."""
86 | tree = {}
87 | for flat_key, v in flat.items():
88 | node = tree
89 | keys = flat_key.split(_SEP)
90 | for k in keys[:-1]:
91 | if k not in node:
92 | node[k] = {}
93 | node = node[k]
94 | node[keys[-1]] = v
95 | return tree
96 |
97 |
98 | def _convert_types(typ: type[_T], value: Any) -> _T:
99 | """Convert some structure into the given type. The structures must match."""
100 | if typ in (Any, ...):
101 | return value
102 |
103 | if typ in (int, float, str, bool):
104 | return typ(value)
105 |
106 | if typ is np.ndarray:
107 | assert isinstance(value, np.ndarray)
108 | return value
109 |
110 | if dataclasses.is_dataclass(typ):
111 | kwargs = {}
112 | for f in dataclasses.fields(typ):
113 | # Only support Optional for dataclasses, as numpy can't serialize it
114 | # directly (without pickle), and dataclasses are the only case where we
115 | # can know the full set of values and types and therefore know the
116 | # non-existence must mean None.
117 | if isinstance(f.type, (types.UnionType, type(Optional[int]))):
118 | constructors = [t for t in f.type.__args__ if t is not types.NoneType]
119 | if len(constructors) != 1:
120 | raise TypeError(
121 | "Optional works, Union with anything except None doesn't")
122 | if f.name not in value:
123 | kwargs[f.name] = None
124 | continue
125 | constructor = constructors[0]
126 | else:
127 | constructor = f.type
128 |
129 | if f.name in value:
130 | kwargs[f.name] = _convert_types(constructor, value[f.name])
131 | else:
132 | raise ValueError(f"Missing value: {f.name}")
133 | return typ(**kwargs)
134 |
135 | base_type = getattr(typ, "__origin__", None)
136 |
137 | if base_type is dict:
138 | assert len(typ.__args__) == 2
139 | key_type, value_type = typ.__args__
140 | return {_convert_types(key_type, k): _convert_types(value_type, v)
141 | for k, v in value.items()}
142 |
143 | if base_type is list:
144 | assert len(typ.__args__) == 1
145 | value_type = typ.__args__[0]
146 | return [_convert_types(value_type, v)
147 | for _, v in sorted(value.items(), key=lambda x: int(x[0]))]
148 |
149 | if base_type is tuple:
150 | if len(typ.__args__) == 2 and typ.__args__[1] == ...:
151 | # An arbitrary length tuple of a single type, eg: tuple[int, ...]
152 | value_type = typ.__args__[0]
153 | return tuple(_convert_types(value_type, v)
154 | for _, v in sorted(value.items(), key=lambda x: int(x[0])))
155 | else:
156 | # A fixed length tuple of arbitrary types, eg: tuple[int, str, float]
157 | assert len(typ.__args__) == len(value)
158 | return tuple(
159 | _convert_types(t, v)
160 | for t, (_, v) in zip(
161 | typ.__args__, sorted(value.items(), key=lambda x: int(x[0]))))
162 |
163 | # This is probably unreachable with reasonable serializable inputs.
164 | try:
165 | return typ(value)
166 | except TypeError as e:
167 | raise TypeError(
168 | "_convert_types expects the type argument to be a dataclass defined "
169 | "with types that are valid constructors (eg tuple is fine, Tuple "
170 | "isn't), and accept a numpy array as the sole argument.") from e
171 |
--------------------------------------------------------------------------------
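A minimal round trip with checkpoint.dump and checkpoint.load, in the spirit of checkpoint_test.py: the dataclass acts as the deserialization schema and an in-memory buffer stands in for a real .npz file on disk. The TinyCheckpoint name is illustrative, not part of the library.

import dataclasses
import io
from typing import Any

import numpy as np

from graphcast import checkpoint


@dataclasses.dataclass
class TinyCheckpoint:
  params: dict[str, Any]
  description: str


ckpt = TinyCheckpoint(params={"w": np.arange(6).reshape(2, 3)}, description="demo")
buffer = io.BytesIO()
checkpoint.dump(buffer, ckpt)
buffer.seek(0)
restored = checkpoint.load(buffer, TinyCheckpoint)
np.testing.assert_array_equal(ckpt.params["w"], restored.params["w"])
assert restored.description == "demo"
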
/graphcast/checkpoint_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Check that the checkpoint serialization is reversible."""
15 |
16 | import dataclasses
17 | import io
18 | from typing import Any, Optional, Union
19 |
20 | from absl.testing import absltest
21 | from graphcast import checkpoint
22 | import numpy as np
23 |
24 |
25 | @dataclasses.dataclass
26 | class SubConfig:
27 | a: int
28 | b: str
29 |
30 |
31 | @dataclasses.dataclass
32 | class Config:
33 | bt: bool
34 | bf: bool
35 | i: int
36 | f: float
37 | o1: Optional[int]
38 | o2: Optional[int]
39 | o3: Union[int, None]
40 | o4: Union[int, None]
41 | o5: int | None
42 | o6: int | None
43 | li: list[int]
44 | ls: list[str]
45 | ldc: list[SubConfig]
46 | tf: tuple[float, ...]
47 | ts: tuple[str, ...]
48 | t: tuple[str, int, SubConfig]
49 | tdc: tuple[SubConfig, ...]
50 | dsi: dict[str, int]
51 | dss: dict[str, str]
52 | dis: dict[int, str]
53 | dsdis: dict[str, dict[int, str]]
54 | dc: SubConfig
55 | dco: Optional[SubConfig]
56 | ddc: dict[str, SubConfig]
57 |
58 |
59 | @dataclasses.dataclass
60 | class Checkpoint:
61 | params: dict[str, Any]
62 | config: Config
63 |
64 |
65 | class DataclassTest(absltest.TestCase):
66 |
67 | def test_serialize_dataclass(self):
68 | ckpt = Checkpoint(
69 | params={
70 | "layer1": {
71 | "w": np.arange(10).reshape(2, 5),
72 | "b": np.array([2, 6]),
73 | },
74 | "layer2": {
75 | "w": np.arange(8).reshape(2, 4),
76 | "b": np.array([2, 6]),
77 | },
78 | "blah": np.array([3, 9]),
79 | },
80 | config=Config(
81 | bt=True,
82 | bf=False,
83 | i=42,
84 | f=3.14,
85 | o1=1,
86 | o2=None,
87 | o3=2,
88 | o4=None,
89 | o5=3,
90 | o6=None,
91 | li=[12, 9, 7, 15, 16, 14, 1, 6, 11, 4, 10, 5, 13, 3, 8, 2],
92 | ls=list("qhjfdxtpzgemryoikwvblcaus"),
93 | ldc=[SubConfig(1, "hello"), SubConfig(2, "world")],
94 | tf=(1, 4, 2, 10, 5, 9, 13, 16, 15, 8, 12, 7, 11, 14, 3, 6),
95 | ts=("hello", "world"),
96 | t=("foo", 42, SubConfig(1, "bar")),
97 | tdc=(SubConfig(1, "hello"), SubConfig(2, "world")),
98 | dsi={"a": 1, "b": 2, "c": 3},
99 | dss={"d": "e", "f": "g"},
100 | dis={1: "a", 2: "b", 3: "c"},
101 | dsdis={"a": {1: "hello", 2: "world"}, "b": {1: "world"}},
102 | dc=SubConfig(1, "hello"),
103 | dco=None,
104 | ddc={"a": SubConfig(1, "hello"), "b": SubConfig(2, "world")},
105 | ))
106 |
107 | buffer = io.BytesIO()
108 | checkpoint.dump(buffer, ckpt)
109 | buffer.seek(0)
110 | ckpt2 = checkpoint.load(buffer, Checkpoint)
111 | np.testing.assert_array_equal(ckpt.params["layer1"]["w"],
112 | ckpt2.params["layer1"]["w"])
113 | np.testing.assert_array_equal(ckpt.params["layer1"]["b"],
114 | ckpt2.params["layer1"]["b"])
115 | np.testing.assert_array_equal(ckpt.params["layer2"]["w"],
116 | ckpt2.params["layer2"]["w"])
117 | np.testing.assert_array_equal(ckpt.params["layer2"]["b"],
118 | ckpt2.params["layer2"]["b"])
119 | np.testing.assert_array_equal(ckpt.params["blah"], ckpt2.params["blah"])
120 | self.assertEqual(ckpt.config, ckpt2.config)
121 |
122 |
123 | if __name__ == "__main__":
124 | absltest.main()
125 |
--------------------------------------------------------------------------------
/graphcast/data_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for `data_utils.py`."""
15 |
16 | import datetime
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | from graphcast import data_utils
20 | import numpy as np
21 | import xarray as xa
22 |
23 |
24 | class DataUtilsTest(parameterized.TestCase):
25 |
26 | def setUp(self):
27 | super().setUp()
28 | # Fix the seed for reproducibility.
29 | np.random.seed(0)
30 |
31 | def test_year_progress_is_zero_at_year_start_or_end(self):
32 | year_progress = data_utils.get_year_progress(
33 | np.array([
34 | 0,
35 | data_utils.AVG_SEC_PER_YEAR,
36 | data_utils.AVG_SEC_PER_YEAR * 42, # 42 years.
37 | ])
38 | )
39 | np.testing.assert_array_equal(year_progress, np.zeros(year_progress.shape))
40 |
41 | def test_year_progress_is_almost_one_before_year_ends(self):
42 | year_progress = data_utils.get_year_progress(
43 | np.array([
44 | data_utils.AVG_SEC_PER_YEAR - 1,
45 | (data_utils.AVG_SEC_PER_YEAR - 1) * 42, # ~42 years
46 | ])
47 | )
48 | with self.subTest("Year progress values are close to 1"):
49 | self.assertTrue(np.all(year_progress > 0.999))
50 | with self.subTest("Year progress values != 1"):
51 | self.assertTrue(np.all(year_progress < 1.0))
52 |
53 | def test_day_progress_computes_for_all_times_and_longitudes(self):
54 | times = np.random.randint(low=0, high=1e10, size=10)
55 | longitudes = np.arange(0, 360.0, 1.0)
56 | day_progress = data_utils.get_day_progress(times, longitudes)
57 |     with self.subTest("Day progress is computed for all times and longitudes"):
58 | self.assertSequenceEqual(
59 | day_progress.shape, (len(times), len(longitudes))
60 | )
61 |
62 | @parameterized.named_parameters(
63 | dict(
64 | testcase_name="random_date_1",
65 | year=1988,
66 | month=11,
67 | day=7,
68 | hour=2,
69 | minute=45,
70 | second=34,
71 | ),
72 | dict(
73 | testcase_name="random_date_2",
74 | year=2022,
75 | month=3,
76 | day=12,
77 | hour=7,
78 | minute=1,
79 | second=0,
80 | ),
81 | )
82 | def test_day_progress_is_in_between_zero_and_one(
83 | self, year, month, day, hour, minute, second
84 | ):
85 | # Datetime from a timestamp.
86 | dt = datetime.datetime(year, month, day, hour, minute, second)
87 | # Epoch time.
88 | epoch_time = datetime.datetime(1970, 1, 1)
89 | # Seconds since epoch.
90 | seconds_since_epoch = np.array([(dt - epoch_time).total_seconds()])
91 |
92 | # Longitudes with 1 degree resolution.
93 | longitudes = np.arange(0, 360.0, 1.0)
94 |
95 | day_progress = data_utils.get_day_progress(seconds_since_epoch, longitudes)
96 | with self.subTest("Day progress >= 0"):
97 | self.assertTrue(np.all(day_progress >= 0.0))
98 | with self.subTest("Day progress < 1"):
99 | self.assertTrue(np.all(day_progress < 1.0))
100 |
101 | def test_day_progress_is_zero_at_day_start_or_end(self):
102 | day_progress = data_utils.get_day_progress(
103 | seconds_since_epoch=np.array([
104 | 0,
105 | data_utils.SEC_PER_DAY,
106 | data_utils.SEC_PER_DAY * 42, # 42 days.
107 | ]),
108 | longitude=np.array([0.0]),
109 | )
110 | np.testing.assert_array_equal(day_progress, np.zeros(day_progress.shape))
111 |
112 | def test_day_progress_specific_value(self):
113 | day_progress = data_utils.get_day_progress(
114 | seconds_since_epoch=np.array([123]),
115 | longitude=np.array([0.0]),
116 | )
117 | np.testing.assert_array_almost_equal(
118 | day_progress, np.array([[0.00142361]]), decimal=6
119 | )
120 |
121 | def test_featurize_progress_valid_values_and_dimensions(self):
122 | day_progress = np.array([0.0, 0.45, 0.213])
123 | feature_dimensions = ("time",)
124 | progress_features = data_utils.featurize_progress(
125 | name="day_progress", dims=feature_dimensions, progress=day_progress
126 | )
127 | for feature in progress_features.values():
128 | with self.subTest(f"Valid dimensions for {feature}"):
129 | self.assertSequenceEqual(feature.dims, feature_dimensions)
130 |
131 | with self.subTest("Valid values for day_progress"):
132 | np.testing.assert_array_equal(
133 | day_progress, progress_features["day_progress"].values
134 | )
135 |
136 | with self.subTest("Valid values for day_progress_sin"):
137 | np.testing.assert_array_almost_equal(
138 | np.array([0.0, 0.30901699, 0.97309851]),
139 | progress_features["day_progress_sin"].values,
140 | decimal=6,
141 | )
142 |
143 | with self.subTest("Valid values for day_progress_cos"):
144 | np.testing.assert_array_almost_equal(
145 | np.array([1.0, -0.95105652, 0.23038943]),
146 | progress_features["day_progress_cos"].values,
147 | decimal=6,
148 | )
149 |
150 | def test_featurize_progress_invalid_dimensions(self):
151 | year_progress = np.array([0.0, 0.45, 0.213])
152 | feature_dimensions = ("time", "longitude")
153 | with self.assertRaises(ValueError):
154 | data_utils.featurize_progress(
155 | name="year_progress", dims=feature_dimensions, progress=year_progress
156 | )
157 |
158 | def test_add_derived_vars_variables_added(self):
159 | data = xa.Dataset(
160 | data_vars={
161 | "var1": (["x", "lon", "datetime"], 8 * np.random.randn(2, 2, 3))
162 | },
163 | coords={
164 | "lon": np.array([0.0, 0.5]),
165 | "datetime": np.array([
166 | datetime.datetime(2021, 1, 1),
167 | datetime.datetime(2023, 1, 1),
168 | datetime.datetime(2023, 1, 3),
169 | ]),
170 | },
171 | )
172 | data_utils.add_derived_vars(data)
173 | all_variables = set(data.variables)
174 |
175 | with self.subTest("Original value was not removed"):
176 | self.assertIn("var1", all_variables)
177 | with self.subTest("Year progress feature was added"):
178 | self.assertIn(data_utils.YEAR_PROGRESS, all_variables)
179 | with self.subTest("Day progress feature was added"):
180 | self.assertIn(data_utils.DAY_PROGRESS, all_variables)
181 |
182 | def test_add_derived_vars_existing_vars_not_overridden(self):
183 | dims = ["x", "lon", "datetime"]
184 | data = xa.Dataset(
185 | data_vars={
186 | "var1": (dims, 8 * np.random.randn(2, 2, 3)),
187 | data_utils.YEAR_PROGRESS: (dims, np.full((2, 2, 3), 0.111)),
188 | data_utils.DAY_PROGRESS: (dims, np.full((2, 2, 3), 0.222)),
189 | },
190 | coords={
191 | "lon": np.array([0.0, 0.5]),
192 | "datetime": np.array([
193 | datetime.datetime(2021, 1, 1),
194 | datetime.datetime(2023, 1, 1),
195 | datetime.datetime(2023, 1, 3),
196 | ]),
197 | },
198 | )
199 |
200 | data_utils.add_derived_vars(data)
201 |
202 | with self.subTest("Year progress feature was not overridden"):
203 | np.testing.assert_allclose(data[data_utils.YEAR_PROGRESS], 0.111)
204 | with self.subTest("Day progress feature was not overridden"):
205 | np.testing.assert_allclose(data[data_utils.DAY_PROGRESS], 0.222)
206 |
207 | @parameterized.named_parameters(
208 | dict(testcase_name="missing_datetime", coord_name="lon"),
209 | dict(testcase_name="missing_lon", coord_name="datetime"),
210 | )
211 | def test_add_derived_vars_missing_coordinate_raises_value_error(
212 | self, coord_name
213 | ):
214 | with self.subTest(f"Missing {coord_name} coordinate"):
215 | data = xa.Dataset(
216 | data_vars={"var1": (["x", coord_name], 8 * np.random.randn(2, 2))},
217 | coords={
218 | coord_name: np.array([0.0, 0.5]),
219 | },
220 | )
221 | with self.assertRaises(ValueError):
222 | data_utils.add_derived_vars(data)
223 |
224 | def test_add_tisr_var_variable_added(self):
225 | data = xa.Dataset(
226 | data_vars={
227 | "var1": (["time", "lat", "lon"], np.full((2, 2, 2), 8.0))
228 | },
229 | coords={
230 | "lat": np.array([2.0, 1.0]),
231 | "lon": np.array([0.0, 0.5]),
232 | "time": np.array([100, 200], dtype="timedelta64[s]"),
233 | "datetime": xa.Variable(
234 | "time", np.array([10, 20], dtype="datetime64[D]")
235 | ),
236 | },
237 | )
238 |
239 | data_utils.add_tisr_var(data)
240 |
241 | self.assertIn(data_utils.TISR, set(data.variables))
242 |
243 | def test_add_tisr_var_existing_var_not_overridden(self):
244 | dims = ["time", "lat", "lon"]
245 | data = xa.Dataset(
246 | data_vars={
247 | "var1": (dims, np.full((2, 2, 2), 8.0)),
248 | data_utils.TISR: (dims, np.full((2, 2, 2), 1200.0)),
249 | },
250 | coords={
251 | "lat": np.array([2.0, 1.0]),
252 | "lon": np.array([0.0, 0.5]),
253 | "time": np.array([100, 200], dtype="timedelta64[s]"),
254 | "datetime": xa.Variable(
255 | "time", np.array([10, 20], dtype="datetime64[D]")
256 | ),
257 | },
258 | )
259 |
260 | data_utils.add_derived_vars(data)
261 |
262 | np.testing.assert_allclose(data[data_utils.TISR], 1200.0)
263 |
264 | def test_add_tisr_var_works_with_batch_dim_size_one(self):
265 | data = xa.Dataset(
266 | data_vars={
267 | "var1": (
268 | ["batch", "time", "lat", "lon"],
269 | np.full((1, 2, 2, 2), 8.0),
270 | )
271 | },
272 | coords={
273 | "lat": np.array([2.0, 1.0]),
274 | "lon": np.array([0.0, 0.5]),
275 | "time": np.array([100, 200], dtype="timedelta64[s]"),
276 | "datetime": xa.Variable(
277 | ("batch", "time"), np.array([[10, 20]], dtype="datetime64[D]")
278 | ),
279 | },
280 | )
281 |
282 | data_utils.add_tisr_var(data)
283 |
284 | self.assertIn(data_utils.TISR, set(data.variables))
285 |
286 | def test_add_tisr_var_fails_with_batch_dim_size_greater_than_one(self):
287 | data = xa.Dataset(
288 | data_vars={
289 | "var1": (
290 | ["batch", "time", "lat", "lon"],
291 | np.full((2, 2, 2, 2), 8.0),
292 | )
293 | },
294 | coords={
295 | "lat": np.array([2.0, 1.0]),
296 | "lon": np.array([0.0, 0.5]),
297 | "time": np.array([100, 200], dtype="timedelta64[s]"),
298 | "datetime": xa.Variable(
299 | ("batch", "time"),
300 | np.array([[10, 20], [100, 200]], dtype="datetime64[D]"),
301 | ),
302 | },
303 | )
304 |
305 | with self.assertRaisesRegex(ValueError, r"cannot select a dimension"):
306 | data_utils.add_tisr_var(data)
307 |
308 |
309 | if __name__ == "__main__":
310 | absltest.main()
311 |
--------------------------------------------------------------------------------
/graphcast/denoisers_base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Base class for Denoisers used in diffusion Predictors.
15 |
16 | Denoisers are a bit like deterministic Predictors, except:
17 | * Their __call__ method also conditions on noisy_targets and the noise_levels
18 | of those noisy targets
19 | * They don't have an overrideable loss function (the loss is assumed to be some
20 | form of MSE and is implemented outside the Denoiser itself)
21 | """
22 |
23 | from typing import Optional, Protocol
24 |
25 | import xarray
26 |
27 |
28 | class Denoiser(Protocol):
29 | """A denoising model that conditions on inputs as well as noise level."""
30 |
31 | def __call__(
32 | self,
33 | inputs: xarray.Dataset,
34 | noisy_targets: xarray.Dataset,
35 | noise_levels: xarray.DataArray,
36 | forcings: Optional[xarray.Dataset] = None,
37 | **kwargs) -> xarray.Dataset:
38 | """Computes denoised targets from noisy targets.
39 |
40 | Args:
41 | inputs: Inputs to condition on, as for Predictor.__call__.
42 | noisy_targets: Targets which have had i.i.d. zero-mean Gaussian noise
43 | added to them (where the noise level used may vary along the 'batch'
44 | dimension).
45 | noise_levels: A DataArray with dimensions ('batch',) specifying the noise
46 | levels that were used for each example in the batch.
47 | forcings: Optional additional per-target-timestep forcings to condition
48 | on, as for Predictor.__call__.
49 | **kwargs: Any additional custom kwargs.
50 |
51 | Returns:
52 | Denoised predictions with the same shape as noisy_targets.
53 | """
54 |
--------------------------------------------------------------------------------
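A minimal object satisfying the Denoiser protocol above: it returns the noisy targets unchanged, which is only useful for wiring tests, but it shows the expected call signature. The IdentityDenoiser name is illustrative, not part of the library.

from typing import Optional

import xarray

from graphcast import denoisers_base


class IdentityDenoiser:
  """Denoiser that performs no denoising at all."""

  def __call__(self,
               inputs: xarray.Dataset,
               noisy_targets: xarray.Dataset,
               noise_levels: xarray.DataArray,
               forcings: Optional[xarray.Dataset] = None,
               **kwargs) -> xarray.Dataset:
    return noisy_targets


denoiser: denoisers_base.Denoiser = IdentityDenoiser()
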
/graphcast/dpm_solver_plus_plus_2s.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """DPM-Solver++ 2S sampler from https://arxiv.org/abs/2211.01095."""
15 |
16 | from typing import Optional
17 |
18 | from graphcast import casting
19 | from graphcast import denoisers_base
20 | from graphcast import samplers_base as base
21 | from graphcast import samplers_utils as utils
22 | from graphcast import xarray_jax
23 | import haiku as hk
24 | import jax.numpy as jnp
25 | import xarray
26 |
27 |
28 | class Sampler(base.Sampler):
29 | """Sampling using DPM-Solver++ 2S from [1].
30 |
31 | This is combined with optional stochastic churn as described in [2].
32 |
33 | The '2S' terminology from [1] means that this is a second-order (2),
34 |   single-step (S) solver. Here 'single-step' distinguishes it from
35 | 'multi-step' methods where the results of function evaluations from previous
36 | steps are reused in computing updates for subsequent steps. The solver still
37 | uses multiple steps though.
38 |
39 | [1] DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic
40 | Models, https://arxiv.org/abs/2211.01095
41 | [2] Elucidating the Design Space of Diffusion-Based Generative Models,
42 | https://arxiv.org/abs/2206.00364
43 | """
44 |
45 | def __init__(self,
46 | denoiser: denoisers_base.Denoiser,
47 | max_noise_level: float,
48 | min_noise_level: float,
49 | num_noise_levels: int,
50 | rho: float,
51 | stochastic_churn_rate: float,
52 | churn_min_noise_level: float,
53 | churn_max_noise_level: float,
54 | noise_level_inflation_factor: float
55 | ):
56 | """Initializes the sampler.
57 |
58 | Args:
59 | denoiser: A Denoiser which predicts noise-free targets.
60 | max_noise_level: The highest noise level used at the start of the
61 | sequence of reverse diffusion steps.
62 | min_noise_level: The lowest noise level used at the end of the sequence of
63 | reverse diffusion steps.
64 | num_noise_levels: Determines the number of noise levels used and hence the
65 | number of reverse diffusion steps performed.
66 | rho: Parameter affecting the spacing of noise steps. Higher values will
67 | concentrate noise steps more around zero.
68 | stochastic_churn_rate: S_churn from the paper. This controls the rate
69 | at which noise is re-injected/'churned' during the sampling algorithm.
70 | If this is set to zero then we are performing deterministic sampling
71 | as described in Algorithm 1.
72 | churn_min_noise_level: Minimum noise level at which stochastic churn
73 | occurs. S_min from the paper. Only used if stochastic_churn_rate > 0.
74 | churn_max_noise_level: Maximum noise level at which stochastic churn
75 |         occurs. S_max from the paper. Only used if stochastic_churn_rate > 0.
76 | noise_level_inflation_factor: This can be used to set the actual amount of
77 | noise injected higher than what the denoiser is told has been added.
78 | The motivation is to compensate for a tendency of L2-trained denoisers
79 | to remove slightly too much noise / blur too much. S_noise from the
80 | paper. Only used if stochastic_churn_rate > 0.
81 | """
82 | super().__init__(denoiser)
83 | self._noise_levels = utils.noise_schedule(
84 | max_noise_level, min_noise_level, num_noise_levels, rho)
85 | self._stochastic_churn = stochastic_churn_rate > 0
86 | self._per_step_churn_rates = utils.stochastic_churn_rate_schedule(
87 | self._noise_levels, stochastic_churn_rate, churn_min_noise_level,
88 | churn_max_noise_level)
89 | self._noise_level_inflation_factor = noise_level_inflation_factor
90 |
91 | def __call__(
92 | self,
93 | inputs: xarray.Dataset,
94 | targets_template: xarray.Dataset,
95 | forcings: Optional[xarray.Dataset] = None,
96 | **kwargs) -> xarray.Dataset:
97 |
98 | dtype = casting.infer_floating_dtype(targets_template) # pytype: disable=wrong-arg-types
99 | noise_levels = jnp.array(self._noise_levels).astype(dtype)
100 | per_step_churn_rates = jnp.array(self._per_step_churn_rates).astype(dtype)
101 |
102 | def denoiser(noise_level: jnp.ndarray, x: xarray.Dataset) -> xarray.Dataset:
103 | """Computes D(x, sigma, y)."""
104 | bcast_noise_level = xarray_jax.DataArray(
105 | jnp.tile(noise_level, x.sizes['batch']), dims=('batch',))
106 | # Estimate the expectation of the fully-denoised target x0, conditional on
107 | # inputs/forcings, noisy targets and their noise level:
108 | return self._denoiser(
109 | inputs=inputs,
110 | noisy_targets=x,
111 | noise_levels=bcast_noise_level,
112 | forcings=forcings)
113 |
114 | def body_fn(i: jnp.ndarray, x: xarray.Dataset) -> xarray.Dataset:
115 | """One iteration of the sampling algorithm.
116 |
117 | Args:
118 | i: Sampling iteration.
119 | x: Noisy targets at iteration i, these will have noise level
120 | self._noise_levels[i].
121 |
122 | Returns:
123 | Noisy targets at the next lowest noise level self._noise_levels[i+1].
124 | """
125 | def init_noise(template):
126 | return noise_levels[0] * utils.spherical_white_noise_like(template)
127 |
128 | # Initialise the inputs if i == 0.
129 | # This is done here to ensure both noise sampler calls can use the same
130 | # spherical harmonic basis functions. While there may be a small compute
131 | # cost the memory savings can be significant.
132 | # TODO(dominicmasters): Figure out if we can merge the two noise sampler
133 | # calls into one to avoid this hack.
134 | maybe_init_noise = (i == 0).astype(noise_levels[0].dtype)
135 | x = x + init_noise(x) * maybe_init_noise
136 |
137 | noise_level = noise_levels[i]
138 |
139 | if self._stochastic_churn:
140 | # We increase the noise level of x a bit before taking it down again:
141 | x, noise_level = utils.apply_stochastic_churn(
142 | x, noise_level,
143 | stochastic_churn_rate=per_step_churn_rates[i],
144 | noise_level_inflation_factor=self._noise_level_inflation_factor)
145 |
146 | # Apply one step of the ODE solver to take x down to the next lowest
147 | # noise level.
148 |
149 | # Note that the Elucidating paper's choice of sigma(t)=t and s(t)=1
150 | # (corresponding to alpha(t)=1 in the DPM paper) as well as the standard
151 | # choice of r=1/2 (corresponding to a geometric mean for the s_i
152 | # midpoints) greatly simplifies the update from the DPM-Solver++ paper.
153 | # You need to do a bit of algebraic fiddling to arrive at the below after
154 | # substituting these choices into DPMSolver++'s Algorithm 1. The simpler
155 | # update we arrive at helps with intuition too.
156 |
157 | next_noise_level = noise_levels[i + 1]
158 | # This is s_{i+1} from the paper. They don't explain how the s_i are
159 | # chosen, but the default choice seems to be a geometric mean, which is
160 | # equivalent to setting all the r_i = 1/2.
161 | mid_noise_level = jnp.sqrt(noise_level * next_noise_level)
162 |
163 | mid_over_current = mid_noise_level / noise_level
164 | x_denoised = denoiser(noise_level, x)
165 | # This turns out to be a convex combination of current and denoised x,
166 | # which isn't entirely apparent from the paper formulae:
167 | x_mid = mid_over_current * x + (1 - mid_over_current) * x_denoised
168 |
169 | next_over_current = next_noise_level / noise_level
170 | x_mid_denoised = denoiser(mid_noise_level, x_mid) # pytype: disable=wrong-arg-types
171 | x_next = next_over_current * x + (1 - next_over_current) * x_mid_denoised
172 |
173 | # For the final step to noise level 0, we do an Euler update which
174 | # corresponds to just returning the denoiser's prediction directly.
175 | #
176 | # In fact the behaviour above when next_noise_level == 0 is almost
177 | # equivalent, except that it runs the denoiser a second time to denoise
178 | # from noise level 0. The denoiser should just be the identity function in
179 | # this case, but it hasn't necessarily been trained at noise level 0 so
180 | # we avoid relying on this.
181 | return utils.tree_where(next_noise_level == 0, x_denoised, x_next)
182 |
183 | # Init with zeros but apply additional noise at step 0 to initialise the
184 | # state.
185 | noise_init = xarray.zeros_like(targets_template)
186 | return hk.fori_loop(
187 | 0, len(noise_levels) - 1, body_fun=body_fn, init_val=noise_init)
188 |
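A minimal numpy sketch of the solver loop above, for intuition. The `noise_schedule` helper below is an assumption (the real schedule comes from `samplers_utils.noise_schedule`, which is not shown here, and is taken to follow the Karras-style rho-spacing); the per-step updates mirror `body_fn`, with a toy constant denoiser standing in for the real one.

```python
import numpy as np

def noise_schedule(max_nl, min_nl, num, rho):
  # Assumed Karras-style spacing; higher rho concentrates levels near min_nl.
  ramp = np.linspace(0.0, 1.0, num)
  return (max_nl**(1 / rho) + ramp * (min_nl**(1 / rho) - max_nl**(1 / rho)))**rho

def toy_denoiser(noise_level, x, x0=0.3):
  # Stand-in for D(x, sigma): pretends the clean target is the constant x0.
  del noise_level
  return np.full_like(x, x0)

# Append a final level of 0 so the last step exercises the x_denoised branch.
sigmas = np.append(noise_schedule(80.0, 0.03, 20, rho=7.0), 0.0)
x = sigmas[0] * np.random.randn(4)          # initialise at the highest noise level
for i in range(len(sigmas) - 1):
  sigma, sigma_next = sigmas[i], sigmas[i + 1]
  sigma_mid = np.sqrt(sigma * sigma_next)   # geometric mean, i.e. r_i = 1/2
  x_denoised = toy_denoiser(sigma, x)
  x_mid = (sigma_mid / sigma) * x + (1 - sigma_mid / sigma) * x_denoised
  x_next = ((sigma_next / sigma) * x
            + (1 - sigma_next / sigma) * toy_denoiser(sigma_mid, x_mid))
  x = x_denoised if sigma_next == 0 else x_next
```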
--------------------------------------------------------------------------------
/graphcast/gencast.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Denoising diffusion models based on the framework of [1].
15 |
16 | Throughout we will refer to notation and equations from [1].
17 |
18 | [1] Elucidating the Design Space of Diffusion-Based Generative Models
19 | Karras, Aittala, Aila and Laine, 2022
20 | https://arxiv.org/abs/2206.00364
21 | """
22 |
23 | from typing import Any, Optional, Tuple
24 |
25 | import chex
26 | from graphcast import casting
27 | from graphcast import denoiser
28 | from graphcast import dpm_solver_plus_plus_2s
29 | from graphcast import graphcast
30 | from graphcast import losses
31 | from graphcast import predictor_base
32 | from graphcast import samplers_utils
33 | from graphcast import xarray_jax
34 | import haiku as hk
35 | import jax
36 | import xarray
37 |
38 |
39 | TARGET_SURFACE_VARS = (
40 | '2m_temperature',
41 | 'mean_sea_level_pressure',
42 | '10m_v_component_of_wind',
43 |     '10m_u_component_of_wind',
44 |     'total_precipitation_12hr',  # GenCast predicts in 12hr timesteps.
45 | 'sea_surface_temperature',
46 | )
47 |
48 | TARGET_SURFACE_NO_PRECIP_VARS = (
49 | '2m_temperature',
50 | 'mean_sea_level_pressure',
51 | '10m_v_component_of_wind',
52 | '10m_u_component_of_wind',
53 | 'sea_surface_temperature',
54 | )
55 |
56 |
57 | TASK = graphcast.TaskConfig(
58 | input_variables=(
59 | # GenCast doesn't take precipitation as input.
60 | TARGET_SURFACE_NO_PRECIP_VARS
61 | + graphcast.TARGET_ATMOSPHERIC_VARS
62 | + graphcast.GENERATED_FORCING_VARS
63 | + graphcast.STATIC_VARS
64 | ),
65 | target_variables=TARGET_SURFACE_VARS + graphcast.TARGET_ATMOSPHERIC_VARS,
66 | # GenCast doesn't take incident solar radiation as a forcing.
67 | forcing_variables=graphcast.GENERATED_FORCING_VARS,
68 | pressure_levels=graphcast.PRESSURE_LEVELS_WEATHERBENCH_13,
69 | # GenCast takes the current frame and the frame 12 hours prior.
70 | input_duration='24h',
71 | )
72 |
73 |
74 | @chex.dataclass(frozen=True, eq=True)
75 | class SamplerConfig:
76 | """Configures the sampler used to draw samples from GenCast.
77 |
78 | max_noise_level: The highest noise level used at the start of the
79 | sequence of reverse diffusion steps.
80 | min_noise_level: The lowest noise level used at the end of the sequence of
81 | reverse diffusion steps.
82 | num_noise_levels: Determines the number of noise levels used and hence the
83 | number of reverse diffusion steps performed.
84 | rho: Parameter affecting the spacing of noise steps. Higher values will
85 | concentrate noise steps more around zero.
86 | stochastic_churn_rate: S_churn from the paper. This controls the rate
87 | at which noise is re-injected/'churned' during the sampling algorithm.
88 | If this is set to zero then we are performing deterministic sampling
89 | as described in Algorithm 1.
90 | churn_max_noise_level: Maximum noise level at which stochastic churn
91 |       occurs. S_max from the paper. Only used if stochastic_churn_rate > 0.
92 | churn_min_noise_level: Minimum noise level at which stochastic churn
93 | occurs. S_min from the paper. Only used if stochastic_churn_rate > 0.
94 | noise_level_inflation_factor: This can be used to set the actual amount of
95 | noise injected higher than what the denoiser is told has been added.
96 | The motivation is to compensate for a tendency of L2-trained denoisers
97 | to remove slightly too much noise / blur too much. S_noise from the
98 | paper. Only used if stochastic_churn_rate > 0.
99 | """
100 | max_noise_level: float = 80.
101 | min_noise_level: float = 0.03
102 | num_noise_levels: int = 20
103 | rho: float = 7.
104 | # Stochastic sampler settings.
105 | stochastic_churn_rate: float = 2.5
106 | churn_min_noise_level: float = 0.75
107 | churn_max_noise_level: float = float('inf')
108 | noise_level_inflation_factor: float = 1.05
109 |
110 |
111 | @chex.dataclass(frozen=True, eq=True)
112 | class NoiseConfig:
113 | training_noise_level_rho: float = 7.0
114 | training_max_noise_level: float = 88.0
115 | training_min_noise_level: float = 0.02
116 |
117 |
118 | @chex.dataclass(frozen=True, eq=True)
119 | class CheckPoint:
120 | description: str
121 | license: str
122 | params: dict[str, Any]
123 | task_config: graphcast.TaskConfig
124 | denoiser_architecture_config: denoiser.DenoiserArchitectureConfig
125 | sampler_config: SamplerConfig
126 | noise_config: NoiseConfig
127 | noise_encoder_config: denoiser.NoiseEncoderConfig
128 |
129 |
130 | class GenCast(predictor_base.Predictor):
131 | """Predictor for a denoising diffusion model following the framework of [1].
132 |
133 | [1] Elucidating the Design Space of Diffusion-Based Generative Models
134 | Karras, Aittala, Aila and Laine, 2022
135 | https://arxiv.org/abs/2206.00364
136 |
137 | Unlike the paper, we have a conditional model and our denoising function
138 | conditions on previous timesteps.
139 |
140 | As the paper demonstrates, the sampling algorithm can be varied independently
141 | of the denoising model and its training procedure, and it is separately
142 | configurable here.
143 | """
144 |
145 | def __init__(
146 | self,
147 | task_config: graphcast.TaskConfig,
148 | denoiser_architecture_config: denoiser.DenoiserArchitectureConfig,
149 | sampler_config: Optional[SamplerConfig] = None,
150 | noise_config: Optional[NoiseConfig] = None,
151 | noise_encoder_config: Optional[denoiser.NoiseEncoderConfig] = None,
152 | ):
153 | """Constructs GenCast."""
154 | # Output size depends on number of variables being predicted.
155 | num_surface_vars = len(
156 | set(task_config.target_variables)
157 | - set(graphcast.ALL_ATMOSPHERIC_VARS)
158 | )
159 | num_atmospheric_vars = len(
160 | set(task_config.target_variables)
161 | & set(graphcast.ALL_ATMOSPHERIC_VARS)
162 | )
163 | num_outputs = (
164 | num_surface_vars
165 | + len(task_config.pressure_levels) * num_atmospheric_vars
166 | )
167 | denoiser_architecture_config.node_output_size = num_outputs
168 | self._denoiser = denoiser.Denoiser(
169 | noise_encoder_config,
170 | denoiser_architecture_config,
171 | )
172 | self._sampler_config = sampler_config
173 | # Singleton to avoid re-initializing the sampler for each inference call.
174 | self._sampler = None
175 | self._noise_config = noise_config
176 |
177 | def _c_in(self, noise_scale: xarray.DataArray) -> xarray.DataArray:
178 | """Scaling applied to the noisy targets input to the underlying network."""
179 | return (noise_scale**2 + 1)**-0.5
180 |
181 | def _c_out(self, noise_scale: xarray.DataArray) -> xarray.DataArray:
182 | """Scaling applied to the underlying network's raw outputs."""
183 | return noise_scale * (noise_scale**2 + 1)**-0.5
184 |
185 | def _c_skip(self, noise_scale: xarray.DataArray) -> xarray.DataArray:
186 | """Scaling applied to the skip connection."""
187 | return 1 / (noise_scale**2 + 1)
188 |
189 | def _loss_weighting(self, noise_scale: xarray.DataArray) -> xarray.DataArray:
190 | r"""The loss weighting \lambda(\sigma) from the paper."""
191 | return self._c_out(noise_scale) ** -2
192 |
193 | def _preconditioned_denoiser(
194 | self,
195 | inputs: xarray.Dataset,
196 | noisy_targets: xarray.Dataset,
197 | noise_levels: xarray.DataArray,
198 | forcings: Optional[xarray.Dataset] = None,
199 | **kwargs) -> xarray.Dataset:
200 | """The preconditioned denoising function D from the paper (Eqn 7)."""
201 | raw_predictions = self._denoiser(
202 | inputs=inputs,
203 | noisy_targets=noisy_targets * self._c_in(noise_levels),
204 | noise_levels=noise_levels,
205 | forcings=forcings,
206 | **kwargs)
207 | return (raw_predictions * self._c_out(noise_levels) +
208 | noisy_targets * self._c_skip(noise_levels))
209 |
210 | def loss_and_predictions(
211 | self,
212 | inputs: xarray.Dataset,
213 | targets: xarray.Dataset,
214 | forcings: Optional[xarray.Dataset] = None,
215 | ) -> Tuple[predictor_base.LossAndDiagnostics, xarray.Dataset]:
216 | return self.loss(inputs, targets, forcings), self(inputs, targets, forcings)
217 |
218 | def loss(self,
219 | inputs: xarray.Dataset,
220 | targets: xarray.Dataset,
221 | forcings: Optional[xarray.Dataset] = None,
222 | ) -> predictor_base.LossAndDiagnostics:
223 |
224 | if self._noise_config is None:
225 | raise ValueError('Noise config must be specified to train GenCast.')
226 |
227 | # Sample noise levels:
228 | dtype = casting.infer_floating_dtype(targets) # pytype: disable=wrong-arg-types
229 | key = hk.next_rng_key()
230 | batch_size = inputs.sizes['batch']
231 | noise_levels = xarray_jax.DataArray(
232 | data=samplers_utils.rho_inverse_cdf(
233 | min_value=self._noise_config.training_min_noise_level,
234 | max_value=self._noise_config.training_max_noise_level,
235 | rho=self._noise_config.training_noise_level_rho,
236 | cdf=jax.random.uniform(key, shape=(batch_size,), dtype=dtype)),
237 | dims=('batch',))
238 |
239 | # Sample noise and apply it to targets:
240 | noise = (
241 | samplers_utils.spherical_white_noise_like(targets) * noise_levels
242 | )
243 | noisy_targets = targets + noise
244 |
245 | denoised_predictions = self._preconditioned_denoiser(
246 | inputs, noisy_targets, noise_levels, forcings)
247 |
248 | loss, diagnostics = losses.weighted_mse_per_level(
249 | denoised_predictions,
250 | targets,
251 |         # Weights are the same as we used for GraphCast.
252 | per_variable_weights={
253 | # Any variables not specified here are weighted as 1.0.
254 | # A single-level variable, but an important headline variable
255 | # and also one which we have struggled to get good performance
256 | # on at short lead times, so leaving it weighted at 1.0, equal
257 | # to the multi-level variables:
258 | '2m_temperature': 1.0,
259 | # New single-level variables, which we don't weight too highly
260 | # to avoid hurting performance on other variables.
261 | '10m_u_component_of_wind': 0.1,
262 | '10m_v_component_of_wind': 0.1,
263 | 'mean_sea_level_pressure': 0.1,
264 | 'sea_surface_temperature': 0.1,
265 | 'total_precipitation_12hr': 0.1
266 | },
267 | )
268 | loss *= self._loss_weighting(noise_levels)
269 | return loss, diagnostics
270 |
271 | def __call__(self,
272 | inputs: xarray.Dataset,
273 | targets_template: xarray.Dataset,
274 | forcings: Optional[xarray.Dataset] = None,
275 | **kwargs) -> xarray.Dataset:
276 | if self._sampler_config is None:
277 | raise ValueError(
278 | 'Sampler config must be specified to run inference on GenCast.'
279 | )
280 | if self._sampler is None:
281 | self._sampler = dpm_solver_plus_plus_2s.Sampler(
282 | self._preconditioned_denoiser, **self._sampler_config
283 | )
284 | return self._sampler(inputs, targets_template, forcings, **kwargs)
285 |
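The `_c_in`, `_c_out`, `_c_skip` and `_loss_weighting` methods above are the Eqn 7 preconditioning scalings from the Karras et al. paper, here with sigma_data taken to be 1 (the targets are normalized). A quick numpy check of how they behave across noise scales, computed directly from the formulas above:

```python
import numpy as np

def c_in(s):   return (s**2 + 1)**-0.5
def c_out(s):  return s * (s**2 + 1)**-0.5
def c_skip(s): return 1 / (s**2 + 1)

for sigma in [0.03, 1.0, 80.0]:
  lam = c_out(sigma)**-2  # loss weighting lambda(sigma) = (sigma^2 + 1) / sigma^2
  print(f'sigma={sigma}: c_in={c_in(sigma):.4f} c_out={c_out(sigma):.4f} '
        f'c_skip={c_skip(sigma):.4f} lambda={lam:.2f}')
# At low noise, c_skip ~ 1 and c_out ~ sigma, so D(x, sigma) ~ x (near-identity);
# at high noise, c_skip ~ 0 and the network's raw output dominates the prediction.
```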
--------------------------------------------------------------------------------
/graphcast/grid_mesh_connectivity.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tools for converting from regular grids on a sphere, to triangular meshes."""
15 |
16 | from graphcast import icosahedral_mesh
17 | import numpy as np
18 | import scipy
19 | import trimesh
20 |
21 |
22 | def _grid_lat_lon_to_coordinates(
23 | grid_latitude: np.ndarray, grid_longitude: np.ndarray) -> np.ndarray:
24 | """Lat [num_lat] lon [num_lon] to 3d coordinates [num_lat, num_lon, 3]."""
25 | # Convert to spherical coordinates phi and theta defined in the grid.
26 | # Each [num_latitude_points, num_longitude_points]
27 | phi_grid, theta_grid = np.meshgrid(
28 | np.deg2rad(grid_longitude),
29 | np.deg2rad(90 - grid_latitude))
30 |
31 | # [num_latitude_points, num_longitude_points, 3]
32 | # Note this assumes unit radius, since for now we model the earth as a
33 | # sphere of unit radius, and keep any vertical dimension as a regular grid.
34 | return np.stack(
35 | [np.cos(phi_grid)*np.sin(theta_grid),
36 | np.sin(phi_grid)*np.sin(theta_grid),
37 | np.cos(theta_grid)], axis=-1)
38 |
39 |
40 | def radius_query_indices(
41 | *,
42 | grid_latitude: np.ndarray,
43 | grid_longitude: np.ndarray,
44 | mesh: icosahedral_mesh.TriangularMesh,
45 | radius: float) -> tuple[np.ndarray, np.ndarray]:
46 | """Returns mesh-grid edge indices for radius query.
47 |
48 | Args:
49 | grid_latitude: Latitude values for the grid [num_lat_points]
50 | grid_longitude: Longitude values for the grid [num_lon_points]
51 | mesh: Mesh object.
52 |     radius: Radius of connectivity in R3, for a sphere of unit radius.
53 |
54 | Returns:
55 | tuple with `grid_indices` and `mesh_indices` indicating edges between the
56 | grid and the mesh such that the distances in a straight line (not geodesic)
57 | are smaller than or equal to `radius`.
58 | * grid_indices: Indices of shape [num_edges], that index into a
59 | [num_lat_points, num_lon_points] grid, after flattening the leading axes.
60 | * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
61 | """
62 |
63 | # [num_grid_points=num_lat_points * num_lon_points, 3]
64 | grid_positions = _grid_lat_lon_to_coordinates(
65 | grid_latitude, grid_longitude).reshape([-1, 3])
66 |
67 | # [num_mesh_points, 3]
68 | mesh_positions = mesh.vertices
69 | kd_tree = scipy.spatial.cKDTree(mesh_positions)
70 |
71 | # [num_grid_points, num_mesh_points_per_grid_point]
72 | # Note `num_mesh_points_per_grid_point` is not constant, so this is a list
73 | # of arrays, rather than a 2d array.
74 | query_indices = kd_tree.query_ball_point(x=grid_positions, r=radius)
75 |
76 | grid_edge_indices = []
77 | mesh_edge_indices = []
78 | for grid_index, mesh_neighbors in enumerate(query_indices):
79 | grid_edge_indices.append(np.repeat(grid_index, len(mesh_neighbors)))
80 | mesh_edge_indices.append(mesh_neighbors)
81 |
82 | # [num_edges]
83 | grid_edge_indices = np.concatenate(grid_edge_indices, axis=0).astype(int)
84 | mesh_edge_indices = np.concatenate(mesh_edge_indices, axis=0).astype(int)
85 |
86 | return grid_edge_indices, mesh_edge_indices
87 |
88 |
89 | def in_mesh_triangle_indices(
90 | *,
91 | grid_latitude: np.ndarray,
92 | grid_longitude: np.ndarray,
93 | mesh: icosahedral_mesh.TriangularMesh) -> tuple[np.ndarray, np.ndarray]:
94 | """Returns mesh-grid edge indices for grid points contained in mesh triangles.
95 |
96 | Args:
97 | grid_latitude: Latitude values for the grid [num_lat_points]
98 | grid_longitude: Longitude values for the grid [num_lon_points]
99 | mesh: Mesh object.
100 |
101 | Returns:
102 | tuple with `grid_indices` and `mesh_indices` indicating edges between the
103 |     grid and the mesh vertices of the triangle that contains each grid point.
104 | The number of edges is always num_lat_points * num_lon_points * 3
105 | * grid_indices: Indices of shape [num_edges], that index into a
106 | [num_lat_points, num_lon_points] grid, after flattening the leading axes.
107 | * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
108 | """
109 |
110 | # [num_grid_points=num_lat_points * num_lon_points, 3]
111 | grid_positions = _grid_lat_lon_to_coordinates(
112 | grid_latitude, grid_longitude).reshape([-1, 3])
113 |
114 | mesh_trimesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)
115 |
116 | # [num_grid_points] with mesh face indices for each grid point.
117 | _, _, query_face_indices = trimesh.proximity.closest_point(
118 | mesh_trimesh, grid_positions)
119 |
120 | # [num_grid_points, 3] with mesh node indices for each grid point.
121 | mesh_edge_indices = mesh.faces[query_face_indices]
122 |
123 | # [num_grid_points, 3] with grid node indices, where every row simply contains
124 | # the row (grid_point) index.
125 | grid_indices = np.arange(grid_positions.shape[0])
126 | grid_edge_indices = np.tile(grid_indices.reshape([-1, 1]), [1, 3])
127 |
128 | # Flatten to get a regular list.
129 | # [num_edges=num_grid_points*3]
130 | mesh_edge_indices = mesh_edge_indices.reshape([-1])
131 | grid_edge_indices = grid_edge_indices.reshape([-1])
132 |
133 | return grid_edge_indices, mesh_edge_indices
134 |
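A small sanity check of the coordinate convention used by `_grid_lat_lon_to_coordinates` above (colatitude theta = 90 - lat, longitude phi, unit radius). The helper is module-private, so this is purely illustrative; the expected values follow directly from the formula.

```python
import numpy as np
from graphcast import grid_mesh_connectivity

coords = grid_mesh_connectivity._grid_lat_lon_to_coordinates(
    grid_latitude=np.array([0., 90.]), grid_longitude=np.array([0., 90.]))
# Shape [num_lat=2, num_lon=2, 3]:
#   lat=0,  lon=0  -> (1, 0, 0)
#   lat=0,  lon=90 -> (0, 1, 0)
#   lat=90, lon=*  -> (0, 0, 1)  (north pole, independent of longitude)
print(np.round(coords, 6))
```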
--------------------------------------------------------------------------------
/graphcast/grid_mesh_connectivity_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for graphcast.grid_mesh_connectivity."""
15 |
16 | from absl.testing import absltest
17 | from graphcast import grid_mesh_connectivity
18 | from graphcast import icosahedral_mesh
19 | import numpy as np
20 |
21 |
22 | class GridMeshConnectivityTest(absltest.TestCase):
23 |
24 | def test_grid_lat_lon_to_coordinates(self):
25 |
26 | # Intervals of 30 degrees.
27 | grid_latitude = np.array([-45., 0., 45])
28 | grid_longitude = np.array([0., 90., 180., 270.])
29 |
30 | inv_sqrt2 = 1 / np.sqrt(2)
31 | expected_coordinates = np.array([
32 | [[inv_sqrt2, 0., -inv_sqrt2],
33 | [0., inv_sqrt2, -inv_sqrt2],
34 | [-inv_sqrt2, 0., -inv_sqrt2],
35 | [0., -inv_sqrt2, -inv_sqrt2]],
36 | [[1., 0., 0.],
37 | [0., 1., 0.],
38 | [-1., 0., 0.],
39 | [0., -1., 0.]],
40 | [[inv_sqrt2, 0., inv_sqrt2],
41 | [0., inv_sqrt2, inv_sqrt2],
42 | [-inv_sqrt2, 0., inv_sqrt2],
43 | [0., -inv_sqrt2, inv_sqrt2]],
44 | ])
45 |
46 | coordinates = grid_mesh_connectivity._grid_lat_lon_to_coordinates(
47 | grid_latitude, grid_longitude)
48 | np.testing.assert_allclose(expected_coordinates, coordinates, atol=1e-15)
49 |
50 | def test_radius_query_indices_smoke(self):
51 | # TODO(alvarosg): Add non-smoke test?
52 | grid_latitude = np.linspace(-75, 75, 6)
53 | grid_longitude = np.arange(12) * 30.
54 | mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
55 | splits=3)[-1]
56 | grid_mesh_connectivity.radius_query_indices(
57 | grid_latitude=grid_latitude,
58 | grid_longitude=grid_longitude,
59 | mesh=mesh, radius=0.2)
60 |
61 | def test_in_mesh_triangle_indices_smoke(self):
62 | # TODO(alvarosg): Add non-smoke test?
63 | grid_latitude = np.linspace(-75, 75, 6)
64 | grid_longitude = np.arange(12) * 30.
65 | mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
66 | splits=3)[-1]
67 | grid_mesh_connectivity.in_mesh_triangle_indices(
68 | grid_latitude=grid_latitude,
69 | grid_longitude=grid_longitude,
70 | mesh=mesh)
71 |
72 |
73 | if __name__ == "__main__":
74 | absltest.main()
75 |
--------------------------------------------------------------------------------
/graphcast/icosahedral_mesh.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utils for creating icosahedral meshes."""
15 |
16 | import itertools
17 | from typing import List, NamedTuple, Sequence, Tuple
18 |
19 | import numpy as np
20 | from scipy.spatial import transform
21 |
22 |
23 | class TriangularMesh(NamedTuple):
24 | """Data structure for triangular meshes.
25 |
26 | Attributes:
27 | vertices: spatial positions of the vertices of the mesh of shape
28 | [num_vertices, num_dims].
29 | faces: triangular faces of the mesh of shape [num_faces, 3]. Contains
30 | integer indices into `vertices`.
31 |
32 | """
33 | vertices: np.ndarray
34 | faces: np.ndarray
35 |
36 |
37 | def merge_meshes(
38 | mesh_list: Sequence[TriangularMesh]) -> TriangularMesh:
39 | """Merges all meshes into one. Assumes the last mesh is the finest.
40 |
41 | Args:
42 | mesh_list: Sequence of meshes, from coarse to fine refinement levels. The
43 | vertices and faces may contain those from preceding, coarser levels.
44 |
45 | Returns:
46 | `TriangularMesh` for which the vertices correspond to the highest
47 |     resolution mesh in the hierarchy, and the faces are the union of the
48 | faces at all levels of the hierarchy.
49 | """
50 | for mesh_i, mesh_ip1 in itertools.pairwise(mesh_list):
51 | num_nodes_mesh_i = mesh_i.vertices.shape[0]
52 | assert np.allclose(mesh_i.vertices, mesh_ip1.vertices[:num_nodes_mesh_i])
53 |
54 | return TriangularMesh(
55 | vertices=mesh_list[-1].vertices,
56 | faces=np.concatenate([mesh.faces for mesh in mesh_list], axis=0))
57 |
58 |
59 | def get_hierarchy_of_triangular_meshes_for_sphere(
60 | splits: int) -> List[TriangularMesh]:
61 |   """Returns a sequence of meshes, each a triangulation of the sphere.
62 |
63 | Starting with a regular icosahedron (12 vertices, 20 faces, 30 edges) with
64 | circumscribed unit sphere. Then, each triangular face is iteratively
65 | subdivided into 4 triangular faces `splits` times. The new vertices are then
66 | projected back onto the unit sphere. All resulting meshes are returned in a
67 | list, from lowest to highest resolution.
68 |
69 | The vertices in each face are specified in counter-clockwise order as
70 |   observed from the outside of the icosahedron.
71 |
72 | Args:
73 | splits: How many times to split each triangle.
74 | Returns:
75 | Sequence of `TriangularMesh`s of length `splits + 1` each with:
76 |
77 | vertices: [num_vertices, 3] vertex positions in 3D, all with unit norm.
78 | faces: [num_faces, 3] with triangular faces joining sets of 3 vertices.
79 | Each row contains three indices into the vertices array, indicating
80 | the vertices adjacent to the face. Always with positive orientation
81 |       (counter-clockwise when looking from the outside).
82 | """
83 | current_mesh = get_icosahedron()
84 | output_meshes = [current_mesh]
85 | for _ in range(splits):
86 | current_mesh = _two_split_unit_sphere_triangle_faces(current_mesh)
87 | output_meshes.append(current_mesh)
88 | return output_meshes
89 |
90 |
91 | def get_icosahedron() -> TriangularMesh:
92 | """Returns a regular icosahedral mesh with circumscribed unit sphere.
93 |
94 | See https://en.wikipedia.org/wiki/Regular_icosahedron#Cartesian_coordinates
95 | for details on the construction of the regular icosahedron.
96 |
97 | The vertices in each face are specified in counter-clockwise order as observed
98 | from the outside of the icosahedron.
99 |
100 | Returns:
101 | TriangularMesh with:
102 |
103 | vertices: [num_vertices=12, 3] vertex positions in 3D, all with unit norm.
104 | faces: [num_faces=20, 3] with triangular faces joining sets of 3 vertices.
105 | Each row contains three indices into the vertices array, indicating
106 |         the vertices adjacent to the face. Always with positive orientation
107 |         (counter-clockwise when looking from the outside).
108 |
109 | """
110 | phi = (1 + np.sqrt(5)) / 2
111 | vertices = []
112 | for c1 in [1., -1.]:
113 | for c2 in [phi, -phi]:
114 | vertices.append((c1, c2, 0.))
115 | vertices.append((0., c1, c2))
116 | vertices.append((c2, 0., c1))
117 |
118 | vertices = np.array(vertices, dtype=np.float32)
119 | vertices /= np.linalg.norm([1., phi])
120 |
121 | # I did this manually, checking the orientation one by one.
122 | faces = [(0, 1, 2),
123 | (0, 6, 1),
124 | (8, 0, 2),
125 | (8, 4, 0),
126 | (3, 8, 2),
127 | (3, 2, 7),
128 | (7, 2, 1),
129 | (0, 4, 6),
130 | (4, 11, 6),
131 | (6, 11, 5),
132 | (1, 5, 7),
133 | (4, 10, 11),
134 | (4, 8, 10),
135 | (10, 8, 3),
136 | (10, 3, 9),
137 | (11, 10, 9),
138 | (11, 9, 5),
139 | (5, 9, 7),
140 | (9, 3, 7),
141 | (1, 6, 5),
142 | ]
143 |
144 |   # By default the top is an edge parallel to the Y axis.
145 |   # Need to rotate around the y axis by half of the supplement of the
146 |   # angle between faces to get the desired orientation.
147 |   #                          /O\ (top edge)
148 | # / \ Z
149 | # (adjacent face)/ \ (adjacent face) ^
150 | # / angle_between_faces \ |
151 | # / \ |
152 | # / \ YO-----> X
153 | # This results in:
154 |   #  (adjacent face is now top plane)
155 |   #  ----------------------O\ (top edge)
156 | # \
157 | # \
158 | # \ (adjacent face)
159 | # \
160 | # \
161 | # \
162 |
163 | angle_between_faces = 2 * np.arcsin(phi / np.sqrt(3))
164 | rotation_angle = (np.pi - angle_between_faces) / 2
165 | rotation = transform.Rotation.from_euler(seq="y", angles=rotation_angle)
166 | rotation_matrix = rotation.as_matrix()
167 | vertices = np.dot(vertices, rotation_matrix)
168 |
169 | return TriangularMesh(vertices=vertices.astype(np.float32),
170 | faces=np.array(faces, dtype=np.int32))
171 |
172 |
173 | def _two_split_unit_sphere_triangle_faces(
174 | triangular_mesh: TriangularMesh) -> TriangularMesh:
175 | """Splits each triangular face into 4 triangles keeping the orientation."""
176 |
177 | # Every time we split a triangle into 4 we will be adding 3 extra vertices,
178 | # located at the edge centres.
179 | # This class handles the positioning of the new vertices, and avoids creating
180 | # duplicates.
181 | new_vertices_builder = _ChildVerticesBuilder(triangular_mesh.vertices)
182 |
183 | new_faces = []
184 | for ind1, ind2, ind3 in triangular_mesh.faces:
185 | # Transform each triangular face into 4 triangles,
186 | # preserving the orientation.
187 | # ind3
188 | # / \
189 | # / \
190 | # / #3 \
191 | # / \
192 | # ind31 -------------- ind23
193 | # / \ / \
194 | # / \ #4 / \
195 | # / #1 \ / #2 \
196 | # / \ / \
197 | # ind1 ------------ ind12 ------------ ind2
198 | ind12 = new_vertices_builder.get_new_child_vertex_index((ind1, ind2))
199 | ind23 = new_vertices_builder.get_new_child_vertex_index((ind2, ind3))
200 | ind31 = new_vertices_builder.get_new_child_vertex_index((ind3, ind1))
201 | # Note how each of the 4 triangular new faces specifies the order of the
202 | # vertices to preserve the orientation of the original face. As the input
203 | # face should always be counter-clockwise as specified in the diagram,
204 | # this means child faces should also be counter-clockwise.
205 | new_faces.extend([[ind1, ind12, ind31], # 1
206 | [ind12, ind2, ind23], # 2
207 | [ind31, ind23, ind3], # 3
208 | [ind12, ind23, ind31], # 4
209 | ])
210 | return TriangularMesh(vertices=new_vertices_builder.get_all_vertices(),
211 | faces=np.array(new_faces, dtype=np.int32))
212 |
213 |
214 | class _ChildVerticesBuilder(object):
215 | """Bookkeeping of new child vertices added to an existing set of vertices."""
216 |
217 | def __init__(self, parent_vertices):
218 |
219 | # Because the same new vertex will be required when splitting adjacent
220 | # triangles (which share an edge) we keep them in a hash table indexed by
221 | # sorted indices of the vertices adjacent to the edge, to avoid creating
222 | # duplicated child vertices.
223 | self._child_vertices_index_mapping = {}
224 | self._parent_vertices = parent_vertices
225 | # We start with all previous vertices.
226 | self._all_vertices_list = list(parent_vertices)
227 |
228 | def _get_child_vertex_key(self, parent_vertex_indices):
229 | return tuple(sorted(parent_vertex_indices))
230 |
231 | def _create_child_vertex(self, parent_vertex_indices):
232 | """Creates a new vertex."""
233 | # Position for new vertex is the middle point, between the parent points,
234 | # projected to unit sphere.
235 | child_vertex_position = self._parent_vertices[
236 | list(parent_vertex_indices)].mean(0)
237 | child_vertex_position /= np.linalg.norm(child_vertex_position)
238 |
239 | # Add the vertex to the output list. The index for this new vertex will
240 | # match the length of the list before adding it.
241 | child_vertex_key = self._get_child_vertex_key(parent_vertex_indices)
242 | self._child_vertices_index_mapping[child_vertex_key] = len(
243 | self._all_vertices_list)
244 | self._all_vertices_list.append(child_vertex_position)
245 |
246 | def get_new_child_vertex_index(self, parent_vertex_indices):
247 | """Returns index for a child vertex, creating it if necessary."""
248 | # Get the key to see if we already have a new vertex in the middle.
249 | child_vertex_key = self._get_child_vertex_key(parent_vertex_indices)
250 | if child_vertex_key not in self._child_vertices_index_mapping:
251 | self._create_child_vertex(parent_vertex_indices)
252 | return self._child_vertices_index_mapping[child_vertex_key]
253 |
254 | def get_all_vertices(self):
255 |     """Returns an array with all vertices."""
256 | return np.array(self._all_vertices_list)
257 |
258 |
259 | def faces_to_edges(faces: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
260 | """Transforms polygonal faces to sender and receiver indices.
261 |
262 |   It does so by transforming every face into N_i edges. For example, if a
263 |   triangular face has indices [0, 1, 2], three edges are added: 0->1, 1->2, and 2->0.
264 |
265 | If all faces have consistent orientation, and the surface represented by the
266 | faces is closed, then every edge in a polygon with a certain orientation
267 | is also part of another polygon with the opposite orientation. In this
268 | situation, the edges returned by the method are always bidirectional.
269 |
270 | Args:
271 | faces: Integer array of shape [num_faces, 3]. Contains node indices
272 | adjacent to each face.
273 | Returns:
274 | Tuple with sender/receiver indices, each of shape [num_edges=num_faces*3].
275 |
276 | """
277 | assert faces.ndim == 2
278 | assert faces.shape[-1] == 3
279 | senders = np.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]])
280 | receivers = np.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]])
281 | return senders, receivers
282 |
283 |
284 | def get_last_triangular_mesh_for_sphere(splits: int) -> TriangularMesh:
285 | return get_hierarchy_of_triangular_meshes_for_sphere(splits=splits)[-1]
286 |
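A short sketch of the bidirectionality property stated in the `faces_to_edges` docstring: for a closed, consistently oriented surface such as the icosahedron, every directed edge also appears with the opposite orientation.

```python
import numpy as np
from graphcast import icosahedral_mesh

mesh = icosahedral_mesh.get_icosahedron()
senders, receivers = icosahedral_mesh.faces_to_edges(mesh.faces)
forward = set(zip(senders.tolist(), receivers.tolist()))
backward = set(zip(receivers.tolist(), senders.tolist()))
assert forward == backward   # 20 faces * 3 = 60 directed edges = 30 undirected edges
print(len(forward))          # -> 60
```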
--------------------------------------------------------------------------------
/graphcast/icosahedral_mesh_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for icosahedral_mesh."""
15 |
16 | from absl.testing import absltest
17 | from absl.testing import parameterized
18 | import chex
19 | from graphcast import icosahedral_mesh
20 | import numpy as np
21 |
22 |
23 | def _get_mesh_spec(splits: int):
24 | """Returns size of the final icosahedral mesh resulting from the splitting."""
25 | num_vertices = 12
26 | num_faces = 20
27 | for _ in range(splits):
28 | # Each previous face adds three new vertices, but each vertex is shared
29 | # by two faces.
30 | num_vertices += num_faces * 3 // 2
31 | num_faces *= 4
32 | return num_vertices, num_faces
33 |
34 |
35 | class IcosahedralMeshTest(parameterized.TestCase):
36 |
37 | def test_icosahedron(self):
38 | mesh = icosahedral_mesh.get_icosahedron()
39 | _assert_valid_mesh(
40 | mesh, num_expected_vertices=12, num_expected_faces=20)
41 |
42 | @parameterized.parameters(list(range(5)))
43 | def test_get_hierarchy_of_triangular_meshes_for_sphere(self, splits):
44 | meshes = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
45 | splits=splits)
46 | prev_vertices = None
47 | for mesh_i, mesh in enumerate(meshes):
48 | # Check that `mesh` is valid.
49 | num_expected_vertices, num_expected_faces = _get_mesh_spec(mesh_i)
50 | _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces)
51 |
52 | # Check that the first N vertices from this mesh match all of the
53 | # vertices from the previous mesh.
54 | if prev_vertices is not None:
55 | leading_mesh_vertices = mesh.vertices[:prev_vertices.shape[0]]
56 | np.testing.assert_array_equal(leading_mesh_vertices, prev_vertices)
57 |
58 | # Increase the expected/previous values for the next iteration.
59 | if mesh_i < len(meshes) - 1:
60 | prev_vertices = mesh.vertices
61 |
62 | @parameterized.parameters(list(range(4)))
63 | def test_merge_meshes(self, splits):
64 | mesh_hierarchy = (
65 | icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
66 | splits=splits))
67 | mesh = icosahedral_mesh.merge_meshes(mesh_hierarchy)
68 |
69 | expected_faces = np.concatenate([m.faces for m in mesh_hierarchy], axis=0)
70 | np.testing.assert_array_equal(mesh.vertices, mesh_hierarchy[-1].vertices)
71 | np.testing.assert_array_equal(mesh.faces, expected_faces)
72 |
73 | def test_faces_to_edges(self):
74 |
75 | faces = np.array([[0, 1, 2],
76 | [3, 4, 5]])
77 |
78 | # This also documents the order of the edges returned by the method.
79 | expected_edges = np.array(
80 | [[0, 1],
81 | [3, 4],
82 | [1, 2],
83 | [4, 5],
84 | [2, 0],
85 | [5, 3]])
86 | expected_senders = expected_edges[:, 0]
87 | expected_receivers = expected_edges[:, 1]
88 |
89 | senders, receivers = icosahedral_mesh.faces_to_edges(faces)
90 |
91 | np.testing.assert_array_equal(senders, expected_senders)
92 | np.testing.assert_array_equal(receivers, expected_receivers)
93 |
94 |
95 | def _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces):
96 | vertices = mesh.vertices
97 | faces = mesh.faces
98 | chex.assert_shape(vertices, [num_expected_vertices, 3])
99 | chex.assert_shape(faces, [num_expected_faces, 3])
100 |
101 | # Vertices norm should be 1.
102 | vertices_norm = np.linalg.norm(vertices, axis=-1)
103 | np.testing.assert_allclose(vertices_norm, 1., rtol=1e-6)
104 |
105 | _assert_positive_face_orientation(vertices, faces)
106 |
107 |
108 | def _assert_positive_face_orientation(vertices, faces):
109 |
110 |   # Obtain a unit vector that points in the direction of the face normal.
111 | face_orientation = np.cross(vertices[faces[:, 1]] - vertices[faces[:, 0]],
112 | vertices[faces[:, 2]] - vertices[faces[:, 1]])
113 | face_orientation /= np.linalg.norm(face_orientation, axis=-1, keepdims=True)
114 |
115 | # And a unit vector pointing from the origin to the center of the face.
116 | face_centers = vertices[faces].mean(1)
117 | face_centers /= np.linalg.norm(face_centers, axis=-1, keepdims=True)
118 |
119 | # Positive orientation means those two vectors should be parallel
120 |   # (dot product 1), and not anti-parallel (dot product -1).
121 | dot_center_orientation = np.einsum("ik,ik->i", face_orientation, face_centers)
122 |
123 | # Check that the face normal is parallel to the vector that joins the center
124 | # of the face to the center of the sphere. Note we need a small tolerance
125 | # because some discretizations are not exactly uniform, so it will not be
126 | # exactly parallel.
127 | np.testing.assert_allclose(dot_center_orientation, 1., atol=6e-4)
128 |
129 |
130 | if __name__ == "__main__":
131 | absltest.main()
132 |
--------------------------------------------------------------------------------
/graphcast/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Loss functions (and terms for use in loss functions) used for weather."""
15 |
16 | from typing import Mapping
17 |
18 | from graphcast import xarray_tree
19 | import numpy as np
20 | from typing_extensions import Protocol
21 | import xarray
22 |
23 |
24 | LossAndDiagnostics = tuple[xarray.DataArray, xarray.Dataset]
25 |
26 |
27 | class LossFunction(Protocol):
28 | """A loss function.
29 |
30 | This is a protocol so it's fine to use a plain function which 'quacks like'
31 | this. This is just to document the interface.
32 | """
33 |
34 | def __call__(self,
35 | predictions: xarray.Dataset,
36 | targets: xarray.Dataset,
37 | **optional_kwargs) -> LossAndDiagnostics:
38 | """Computes a loss function.
39 |
40 | Args:
41 | predictions: Dataset of predictions.
42 | targets: Dataset of targets.
43 | **optional_kwargs: Implementations may support extra optional kwargs.
44 |
45 | Returns:
46 | loss: A DataArray with dimensions ('batch',) containing losses for each
47 | element of the batch. These will be averaged to give the final
48 | loss, locally and across replicas.
49 | diagnostics: Mapping of additional quantities to log by name alongside the
50 |         loss. These will typically correspond to terms in the loss. They
51 | should also have dimensions ('batch',) and will be averaged over the
52 | batch before logging.
53 | """
54 |
55 |
56 | def weighted_mse_per_level(
57 | predictions: xarray.Dataset,
58 | targets: xarray.Dataset,
59 | per_variable_weights: Mapping[str, float],
60 | ) -> LossAndDiagnostics:
61 | """Latitude- and pressure-level-weighted MSE loss."""
62 | def loss(prediction, target):
63 | loss = (prediction - target)**2
64 | loss *= normalized_latitude_weights(target).astype(loss.dtype)
65 | if 'level' in target.dims:
66 | loss *= normalized_level_weights(target).astype(loss.dtype)
67 | return _mean_preserving_batch(loss)
68 |
69 | losses = xarray_tree.map_structure(loss, predictions, targets)
70 | return sum_per_variable_losses(losses, per_variable_weights)
71 |
72 |
73 | def _mean_preserving_batch(x: xarray.DataArray) -> xarray.DataArray:
74 | return x.mean([d for d in x.dims if d != 'batch'], skipna=False)
75 |
76 |
77 | def sum_per_variable_losses(
78 | per_variable_losses: Mapping[str, xarray.DataArray],
79 | weights: Mapping[str, float],
80 | ) -> LossAndDiagnostics:
81 | """Weighted sum of per-variable losses."""
82 | if not set(weights.keys()).issubset(set(per_variable_losses.keys())):
83 | raise ValueError(
84 | 'Passing a weight that does not correspond to any variable '
85 | f'{set(weights.keys())-set(per_variable_losses.keys())}')
86 |
87 | weighted_per_variable_losses = {
88 | name: loss * weights.get(name, 1)
89 | for name, loss in per_variable_losses.items()
90 | }
91 | total = xarray.concat(
92 | weighted_per_variable_losses.values(), dim='variable', join='exact').sum(
93 | 'variable', skipna=False)
94 | return total, per_variable_losses # pytype: disable=bad-return-type
95 |
96 |
97 | def normalized_level_weights(data: xarray.DataArray) -> xarray.DataArray:
98 | """Weights proportional to pressure at each level."""
99 | level = data.coords['level']
100 | return level / level.mean(skipna=False)
101 |
102 |
103 | def normalized_latitude_weights(data: xarray.DataArray) -> xarray.DataArray:
104 | """Weights based on latitude, roughly proportional to grid cell area.
105 |
106 | This method supports two use cases only (both for equispaced values):
107 | * Latitude values such that the closest value to the pole is at latitude
108 | (90 - d_lat/2), where d_lat is the difference between contiguous latitudes.
109 | For example: [-89, -87, -85, ..., 85, 87, 89]) (d_lat = 2)
110 | In this case each point with `lat` value represents a sphere slice between
111 | `lat - d_lat/2` and `lat + d_lat/2`, and the area of this slice would be
112 | proportional to:
113 | `sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat)`, and
114 | we can simply omit the term `2 * sin(d_lat/2)` which is just a constant
115 | that cancels during normalization.
116 | * Latitude values that fall exactly at the poles.
117 | For example: [-90, -88, -86, ..., 86, 88, 90]) (d_lat = 2)
118 | In this case each point with `lat` value also represents
119 | a sphere slice between `lat - d_lat/2` and `lat + d_lat/2`,
120 | except for the points at the poles, that represent a slice between
121 | `90 - d_lat/2` and `90` or, `-90` and `-90 + d_lat/2`.
122 | The areas of the first type of point are still proportional to:
123 | * sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat)
124 | but for the points at the poles now is:
125 | * sin(90) - sin(90 - d_lat/2) = 2 * sin(d_lat/4) ^ 2
126 | and we will be using these weights, depending on whether we are looking at
127 | pole cells, or non-pole cells (omitting the common factor of 2 which will be
128 | absorbed by the normalization).
129 |
130 |   It can be shown via a limit, or simple geometry, that in the small-angle
131 |   regime, the proportion of area per pole point is equal to 1/8th of
132 |   the proportion of area covered by each of the nearest non-pole points, and
133 |   this is verified in the unit tests.
134 |
135 | Args:
136 | data: `DataArray` with latitude coordinates.
137 | Returns:
138 | Unit mean latitude weights.
139 | """
140 | latitude = data.coords['lat']
141 |
142 | if np.any(np.isclose(np.abs(latitude), 90.)):
143 | weights = _weight_for_latitude_vector_with_poles(latitude)
144 | else:
145 | weights = _weight_for_latitude_vector_without_poles(latitude)
146 |
147 | return weights / weights.mean(skipna=False)
148 |
149 |
150 | def _weight_for_latitude_vector_without_poles(latitude):
151 | """Weights for uniform latitudes of the form [+-90-+d/2, ..., -+90+-d/2]."""
152 | delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude))
153 | if (not np.isclose(np.max(latitude), 90 - delta_latitude/2) or
154 | not np.isclose(np.min(latitude), -90 + delta_latitude/2)):
155 | raise ValueError(
156 | f'Latitude vector {latitude} does not start/end at '
157 | '+- (90 - delta_latitude/2) degrees.')
158 | return np.cos(np.deg2rad(latitude))
159 |
160 |
161 | def _weight_for_latitude_vector_with_poles(latitude):
162 | """Weights for uniform latitudes of the form [+- 90, ..., -+90]."""
163 | delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude))
164 | if (not np.isclose(np.max(latitude), 90.) or
165 | not np.isclose(np.min(latitude), -90.)):
166 | raise ValueError(
167 | f'Latitude vector {latitude} does not start/end at +- 90 degrees.')
168 | weights = np.cos(np.deg2rad(latitude)) * np.sin(np.deg2rad(delta_latitude/2))
169 |   # The two checks above are enough to guarantee that latitudes are sorted,
170 |   # so the extremes are the poles.
171 | weights[[0, -1]] = np.sin(np.deg2rad(delta_latitude/4)) ** 2
172 | return weights
173 |
174 |
175 | def _check_uniform_spacing_and_get_delta(vector):
176 | diff = np.diff(vector)
177 | if not np.all(np.isclose(diff[0], diff)):
178 |     raise ValueError(f'Vector {vector} is not uniformly spaced.')
179 | return diff[0]
180 |
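A brief check of `normalized_latitude_weights` on a grid that includes the poles, illustrating the unit-mean normalization and the ~1/8 pole-to-neighbour ratio discussed in its docstring. The 1-degree grid is illustrative only.

```python
import numpy as np
import xarray
from graphcast import losses

lat = np.arange(-90., 90.1, 1.0)   # 1-degree latitude grid including both poles
data = xarray.DataArray(np.zeros(lat.shape), dims=['lat'], coords={'lat': lat})
weights = losses.normalized_latitude_weights(data)
print(float(weights.mean()))           # -> 1.0 (unit mean by construction)
print(float(weights[0] / weights[1]))  # -> ~0.125 in the small-angle regime
```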
--------------------------------------------------------------------------------
/graphcast/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Constructors for MLPs."""
15 |
16 | import haiku as hk
17 | import jax
18 | import jax.numpy as jnp
19 |
20 |
21 | # TODO(aelkadi): Move the mlp factory here from `deep_typed_graph_net.py`.
22 |
23 |
24 | class LinearNormConditioning(hk.Module):
25 | """Module for norm conditioning.
26 |
27 | Conditions the normalization of "inputs" by applying a linear layer to the
28 |   "norm_conditioning" which produces the scale and offset which are applied to
29 | each channel (across the last dim) of "inputs".
30 | """
31 |
32 | def __init__(self, name="norm_conditioning"):
33 | super().__init__(name=name)
34 |
35 | def __call__(self, inputs: jax.Array, norm_conditioning: jax.Array):
36 |
37 | feature_size = inputs.shape[-1]
38 | conditional_linear_layer = hk.Linear(
39 | output_size=2 * feature_size,
40 | w_init=hk.initializers.TruncatedNormal(stddev=1e-8),
41 | )
42 | conditional_scale_offset = conditional_linear_layer(norm_conditioning)
43 | scale_minus_one, offset = jnp.split(conditional_scale_offset, 2, axis=-1)
44 | scale = scale_minus_one + 1.
45 | return inputs * scale + offset
46 |
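A minimal usage sketch of `LinearNormConditioning` under `hk.transform`. The shapes are illustrative; note that with the tiny `w_init` above the module starts out close to the identity (scale ~ 1, offset ~ 0).

```python
import haiku as hk
import jax
import jax.numpy as jnp
from graphcast import mlp

def forward(inputs, norm_conditioning):
  return mlp.LinearNormConditioning()(inputs, norm_conditioning)

forward_t = hk.transform(forward)
x = jnp.ones((16, 32))     # e.g. [num_nodes, feature_size]
cond = jnp.ones((16, 8))   # per-node conditioning, e.g. a noise-level embedding
params = forward_t.init(jax.random.PRNGKey(0), x, cond)
out = forward_t.apply(params, None, x, cond)   # no rng needed at apply time
print(out.shape)           # -> (16, 32)
```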
--------------------------------------------------------------------------------
/graphcast/nan_cleaning.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Wrappers for Predictors which allow them to work with data cleaned of NaNs.
15 |
16 | The Predictor which is wrapped sees inputs and targets without NaNs, and makes
17 | NaNless predictions.
18 | """
19 |
20 | from typing import Optional, Tuple
21 |
22 | from graphcast import predictor_base as base
23 | import numpy as np
24 | import xarray
25 |
26 |
27 | class NaNCleaner(base.Predictor):
28 |   """A predictor wrapper that removes NaNs from ingested data.
29 |
30 | The Predictor which is wrapped sees inputs and targets without NaNs.
31 | """
32 |
33 | def __init__(
34 | self,
35 | predictor: base.Predictor,
36 | var_to_clean: str,
37 | fill_value: xarray.Dataset,
38 | reintroduce_nans: bool = False,
39 | ):
40 | """Initializes the NaNCleaner."""
41 | self._predictor = predictor
42 | self._fill_value = fill_value[var_to_clean]
43 | self._var_to_clean = var_to_clean
44 | self._reintroduce_nans = reintroduce_nans
45 |
46 | def _clean(self, dataset: xarray.Dataset) -> xarray.Dataset:
47 | """Cleans the dataset of NaNs."""
48 | data_array = dataset[self._var_to_clean]
49 | dataset = dataset.assign(
50 | {self._var_to_clean: data_array.fillna(self._fill_value)}
51 | )
52 | return dataset
53 |
54 | def _maybe_reintroduce_nans(
55 | self, stale_inputs: xarray.Dataset, predictions: xarray.Dataset
56 | ) -> xarray.Dataset:
57 |     # NaN positions are assumed not to change between input frames; if they
58 |     # do, we should be more careful about re-introducing them.
59 | if self._var_to_clean in predictions.keys():
60 | nan_mask = np.isnan(stale_inputs[self._var_to_clean]).any(dim='time')
61 | with_nan_values = predictions[self._var_to_clean].where(~nan_mask, np.nan)
62 | predictions = predictions.assign({self._var_to_clean: with_nan_values})
63 | return predictions
64 |
65 | def __call__(
66 | self,
67 | inputs: xarray.Dataset,
68 | targets_template: xarray.Dataset,
69 | forcings: Optional[xarray.Dataset] = None,
70 | **kwargs,
71 | ) -> xarray.Dataset:
72 | if self._reintroduce_nans:
73 | # Copy inputs before cleaning so that we can reintroduce NaNs later.
74 | original_inputs = inputs.copy()
75 | if self._var_to_clean in inputs.keys():
76 | inputs = self._clean(inputs)
77 | if forcings and self._var_to_clean in forcings.keys():
78 | forcings = self._clean(forcings)
79 | predictions = self._predictor(
80 | inputs, targets_template, forcings, **kwargs
81 | )
82 | if self._reintroduce_nans:
83 | predictions = self._maybe_reintroduce_nans(original_inputs, predictions)
84 | return predictions
85 |
86 | def loss(
87 | self,
88 | inputs: xarray.Dataset,
89 | targets: xarray.Dataset,
90 | forcings: Optional[xarray.Dataset] = None,
91 | **kwargs,
92 | ) -> base.LossAndDiagnostics:
93 | if self._var_to_clean in inputs.keys():
94 | inputs = self._clean(inputs)
95 | if self._var_to_clean in targets.keys():
96 | targets = self._clean(targets)
97 | if forcings and self._var_to_clean in forcings.keys():
98 | forcings = self._clean(forcings)
99 | return self._predictor.loss(
100 | inputs, targets, forcings, **kwargs
101 | )
102 |
103 | def loss_and_predictions(
104 | self,
105 | inputs: xarray.Dataset,
106 | targets: xarray.Dataset,
107 | forcings: Optional[xarray.Dataset] = None,
108 | **kwargs,
109 | ) -> Tuple[base.LossAndDiagnostics, xarray.Dataset]:
110 | if self._reintroduce_nans:
111 | # Copy inputs before cleaning so that we can reintroduce NaNs later.
112 | original_inputs = inputs.copy()
113 | if self._var_to_clean in inputs.keys():
114 | inputs = self._clean(inputs)
115 | if self._var_to_clean in targets.keys():
116 | targets = self._clean(targets)
117 | if forcings and self._var_to_clean in forcings.keys():
118 | forcings = self._clean(forcings)
119 |
120 | loss, predictions = self._predictor.loss_and_predictions(
121 | inputs, targets, forcings, **kwargs
122 | )
123 | if self._reintroduce_nans:
124 | predictions = self._maybe_reintroduce_nans(original_inputs, predictions)
125 | return loss, predictions
126 |
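The two xarray operations NaNCleaner relies on can be sketched in isolation as follows: filling NaNs before prediction (as in `_clean`) and masking them back in afterwards (as in `_maybe_reintroduce_nans`). The dims, values and the 280 K fill value are illustrative only.

```python
import numpy as np
import xarray

sst = xarray.DataArray(
    [[280.0, np.nan], [281.0, np.nan]], dims=('time', 'lat'))  # NaNs over "land"
inputs = xarray.Dataset({'sea_surface_temperature': sst})

# _clean: replace NaNs with a fill value so the wrapped predictor never sees them.
cleaned = inputs.assign(
    sea_surface_temperature=inputs['sea_surface_temperature'].fillna(280.0))

# _maybe_reintroduce_nans: mask predictions wherever any input frame had a NaN.
predictions = xarray.Dataset(
    {'sea_surface_temperature': xarray.full_like(sst, 285.0)})
nan_mask = np.isnan(inputs['sea_surface_temperature']).any(dim='time')
masked = predictions['sea_surface_temperature'].where(~nan_mask, np.nan)
```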
--------------------------------------------------------------------------------
/graphcast/normalization.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Wrappers for Predictors which allow them to work with normalized data.
15 |
16 | The Predictor which is wrapped sees normalized inputs and targets, and makes
17 | normalized predictions. The wrapper handles translating the predictions back
18 | to the original domain.
19 | """
20 |
21 | import logging
22 | from typing import Optional, Tuple
23 |
24 | from graphcast import predictor_base
25 | from graphcast import xarray_tree
26 | import xarray
27 |
28 |
29 | def normalize(values: xarray.Dataset,
30 | scales: xarray.Dataset,
31 | locations: Optional[xarray.Dataset],
32 | ) -> xarray.Dataset:
33 | """Normalize variables using the given scales and (optionally) locations."""
34 | def normalize_array(array):
35 | if array.name is None:
36 | raise ValueError(
37 | "Can't look up normalization constants because array has no name.")
38 | if locations is not None:
39 | if array.name in locations:
40 | array = array - locations[array.name].astype(array.dtype)
41 | else:
42 | logging.warning('No normalization location found for %s', array.name)
43 | if array.name in scales:
44 | array = array / scales[array.name].astype(array.dtype)
45 | else:
46 | logging.warning('No normalization scale found for %s', array.name)
47 | return array
48 | return xarray_tree.map_structure(normalize_array, values)
49 |
50 |
51 | def unnormalize(values: xarray.Dataset,
52 | scales: xarray.Dataset,
53 | locations: Optional[xarray.Dataset],
54 | ) -> xarray.Dataset:
55 | """Unnormalize variables using the given scales and (optionally) locations."""
56 | def unnormalize_array(array):
57 | if array.name is None:
58 | raise ValueError(
59 | "Can't look up normalization constants because array has no name.")
60 | if array.name in scales:
61 | array = array * scales[array.name].astype(array.dtype)
62 | else:
63 | logging.warning('No normalization scale found for %s', array.name)
64 | if locations is not None:
65 | if array.name in locations:
66 | array = array + locations[array.name].astype(array.dtype)
67 | else:
68 | logging.warning('No normalization location found for %s', array.name)
69 | return array
70 | return xarray_tree.map_structure(unnormalize_array, values)
71 |
72 |
73 | class InputsAndResiduals(predictor_base.Predictor):
74 | """Wraps with a residual connection, normalizing inputs and target residuals.
75 |
76 | The inner predictor is given inputs that are normalized using `locations`
77 | and `scales` to roughly zero-mean unit variance.
78 |
79 | For target variables that are present in the inputs, the inner predictor is
80 | trained to predict residuals (target - last_frame_of_input) that have been
81 | normalized using `residual_scales` (and optionally `residual_locations`) to
82 | roughly unit variance / zero mean.
83 |
84 | This replaces `residual.Predictor` in the case where you want normalization
85 | that's based on the scales of the residuals.
86 |
87 | Since we return the underlying predictor's loss on the normalized residuals,
88 |   if the underlying predictor's loss is a sum of per-variable losses, the
89 |   normalization will affect the relative weighting of the per-variable loss
90 |   terms (hopefully in a good way).
91 |
92 | For target variables *not* present in the inputs, the inner predictor is
93 |   trained to predict targets directly; these targets are normalized in the
94 |   same way as the inputs.
95 |
96 | The transforms applied to the targets (the residual connection and the
97 | normalization) are applied in reverse to the predictions before returning
98 | them.
99 | """
100 |
101 | def __init__(
102 | self,
103 | predictor: predictor_base.Predictor,
104 | stddev_by_level: xarray.Dataset,
105 | mean_by_level: xarray.Dataset,
106 | diffs_stddev_by_level: xarray.Dataset):
107 | self._predictor = predictor
108 | self._scales = stddev_by_level
109 | self._locations = mean_by_level
110 | self._residual_scales = diffs_stddev_by_level
111 | self._residual_locations = None
112 |
113 | def _unnormalize_prediction_and_add_input(self, inputs, norm_prediction):
114 | if norm_prediction.sizes.get('time') != 1:
115 | raise ValueError(
116 | 'normalization.InputsAndResiduals only supports predicting a '
117 | 'single timestep.')
118 | if norm_prediction.name in inputs:
119 | # Residuals are assumed to be predicted as normalized (unit variance),
120 |       # but the scale and location they need mapping to are those of the
121 |       # residuals, not of the values themselves.
122 | prediction = unnormalize(
123 | norm_prediction, self._residual_scales, self._residual_locations)
124 | # A prediction for which we have a corresponding input -- we are
125 | # predicting the residual:
126 | last_input = inputs[norm_prediction.name].isel(time=-1)
127 | prediction = prediction + last_input
128 | return prediction
129 | else:
130 | # A predicted variable which is not an input variable. We are predicting
131 | # it directly, so unnormalize it directly to the target scale/location:
132 | return unnormalize(norm_prediction, self._scales, self._locations)
133 |
134 | def _subtract_input_and_normalize_target(self, inputs, target):
135 | if target.sizes.get('time') != 1:
136 | raise ValueError(
137 |           'normalization.InputsAndResiduals only supports wrapping predictors '
138 | 'that predict a single timestep.')
139 | if target.name in inputs:
140 | target_residual = target
141 | last_input = inputs[target.name].isel(time=-1)
142 | target_residual = target_residual - last_input
143 | return normalize(
144 | target_residual, self._residual_scales, self._residual_locations)
145 | else:
146 | return normalize(target, self._scales, self._locations)
147 |
148 | def __call__(self,
149 | inputs: xarray.Dataset,
150 | targets_template: xarray.Dataset,
151 | forcings: xarray.Dataset,
152 | **kwargs
153 | ) -> xarray.Dataset:
154 | norm_inputs = normalize(inputs, self._scales, self._locations)
155 | norm_forcings = normalize(forcings, self._scales, self._locations)
156 | norm_predictions = self._predictor(
157 | norm_inputs, targets_template, forcings=norm_forcings, **kwargs)
158 | return xarray_tree.map_structure(
159 | lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred),
160 | norm_predictions)
161 |
162 | def loss(self,
163 | inputs: xarray.Dataset,
164 | targets: xarray.Dataset,
165 | forcings: xarray.Dataset,
166 | **kwargs,
167 | ) -> predictor_base.LossAndDiagnostics:
168 | """Returns the loss computed on normalized inputs and targets."""
169 | norm_inputs = normalize(inputs, self._scales, self._locations)
170 | norm_forcings = normalize(forcings, self._scales, self._locations)
171 | norm_target_residuals = xarray_tree.map_structure(
172 | lambda t: self._subtract_input_and_normalize_target(inputs, t),
173 | targets)
174 | return self._predictor.loss(
175 | norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs)
176 |
177 | def loss_and_predictions( # pytype: disable=signature-mismatch # jax-ndarray
178 | self,
179 | inputs: xarray.Dataset,
180 | targets: xarray.Dataset,
181 | forcings: xarray.Dataset,
182 | **kwargs,
183 | ) -> Tuple[predictor_base.LossAndDiagnostics,
184 | xarray.Dataset]:
185 | """The loss computed on normalized data, with unnormalized predictions."""
186 | norm_inputs = normalize(inputs, self._scales, self._locations)
187 | norm_forcings = normalize(forcings, self._scales, self._locations)
188 | norm_target_residuals = xarray_tree.map_structure(
189 | lambda t: self._subtract_input_and_normalize_target(inputs, t),
190 | targets)
191 | (loss, scalars), norm_predictions = self._predictor.loss_and_predictions(
192 | norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs)
193 | predictions = xarray_tree.map_structure(
194 | lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred),
195 | norm_predictions)
196 | return (loss, scalars), predictions
197 |
--------------------------------------------------------------------------------
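A short usage sketch for the normalize/unnormalize helpers in normalization.py above, assuming the graphcast package is importable. The variable name and the per-variable statistics are made up, and the statistics are stored as 0-d variables as a simplification for this sketch.

import numpy as np
import xarray
from graphcast import normalization

values = xarray.Dataset(
    {"2m_temperature": (("x",), np.array([280.0, 290.0, 300.0]))})
scales = xarray.Dataset({"2m_temperature": ((), 10.0)})
locations = xarray.Dataset({"2m_temperature": ((), 290.0)})

norm = normalization.normalize(values, scales, locations)
roundtrip = normalization.unnormalize(norm, scales, locations)
np.testing.assert_allclose(
    roundtrip["2m_temperature"].values, values["2m_temperature"].values)
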
/graphcast/predictor_base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Abstract base classes for an xarray-based Predictor API."""
15 |
16 | import abc
17 |
18 | from typing import Tuple
19 |
20 | from graphcast import losses
21 | from graphcast import xarray_jax
22 | import jax.numpy as jnp
23 | import xarray
24 |
25 | LossAndDiagnostics = losses.LossAndDiagnostics
26 |
27 |
28 | class Predictor(abc.ABC):
29 | """A possibly-trainable predictor of weather, exposing an xarray-based API.
30 |
31 | Typically wraps an underlying JAX model and handles translating the xarray
32 | Dataset values to and from plain JAX arrays that are convenient for input to
33 | (and output from) the underlying model.
34 |
35 | Different subclasses may exist to wrap different kinds of underlying model,
36 | e.g. models taking stacked inputs/outputs, models taking separate 2D and 3D
37 | inputs/outputs, autoregressive models.
38 |
39 | You can also implement a specific model directly as a Predictor if you want,
40 | for example if it has quite specific/unique requirements for its input/output
41 | or loss function, or if it's convenient to implement directly using xarray.
42 | """
43 |
44 | @abc.abstractmethod
45 | def __call__(self,
46 | inputs: xarray.Dataset,
47 | targets_template: xarray.Dataset,
48 | forcings: xarray.Dataset,
49 | **optional_kwargs
50 | ) -> xarray.Dataset:
51 | """Makes predictions.
52 |
53 | This is only used by the Experiment for inference / evaluation, with
54 | training going via the .loss method. So it should default to making
55 | predictions for evaluation, although you can also support making predictions
56 | for use in the loss via an is_training argument -- see
57 | LossFunctionPredictor which helps with that.
58 |
59 | Args:
60 | inputs: An xarray.Dataset of inputs.
61 | targets_template: An xarray.Dataset or other mapping of xarray.DataArrays,
62 | with the same shape as the targets, to demonstrate what kind of
63 | predictions are required. You can use this to determine which variables,
64 | levels and lead times must be predicted.
65 | You are free to raise an error if you don't support predicting what is
66 | requested.
67 |       forcings: An xarray.Dataset of forcing terms. Forcings are variables
68 |         that can be fed to the model but do not need to be predicted. This is
69 |         often because the variable can be computed analytically (e.g. the
70 |         top-of-atmosphere solar radiation is mostly a function of geometry) or
71 |         is considered to be controlled for the experiment (e.g. imposing a
72 |         scenario of CO2 emissions). Unlike `inputs`, the `forcings` can
73 | include information "from the future", that is, information at target
74 | times specified in the `targets_template`.
75 | **optional_kwargs: Implementations may support extra optional kwargs,
76 | provided they set appropriate defaults for them.
77 |
78 | Returns:
79 | Predictions, as an xarray.Dataset or other mapping of DataArrays which
80 | is capable of being evaluated against targets with shape given by
81 | targets_template.
82 | For probabilistic predictors which can return multiple samples from a
83 | predictive distribution, these should (by convention) be returned along
84 | an additional 'sample' dimension.
85 | """
86 |
87 | def loss(self,
88 | inputs: xarray.Dataset,
89 | targets: xarray.Dataset,
90 | forcings: xarray.Dataset,
91 | **optional_kwargs,
92 | ) -> LossAndDiagnostics:
93 | """Computes a training loss, for predictors that are trainable.
94 |
95 | Why make this the Predictor's responsibility, rather than letting callers
96 | compute their own loss function using predictions obtained from
97 | Predictor.__call__?
98 |
99 | Doing it this way gives Predictors more control over their training setup.
100 | For example, some predictors may wish to train using different targets to
101 | the ones they predict at evaluation time -- perhaps different lead times and
102 | variables, perhaps training to predict transformed versions of targets
103 | where the transform needs to be inverted at evaluation time, etc.
104 |
105 | It's also necessary for generative models (VAEs, GANs, ...) where the
106 | training loss is more complex and isn't expressible as a parameter-free
107 | function of predictions and targets.
108 |
109 | Args:
110 | inputs: An xarray.Dataset.
111 | targets: An xarray.Dataset or other mapping of xarray.DataArrays. See
112 | docs on __call__ for an explanation about the targets.
113 | forcings: xarray.Dataset of forcing terms.
114 | **optional_kwargs: Implementations may support extra optional kwargs,
115 | provided they set appropriate defaults for them.
116 |
117 | Returns:
118 | loss: A DataArray with dimensions ('batch',) containing losses for each
119 | element of the batch. These will be averaged to give the final
120 | loss, locally and across replicas.
121 | diagnostics: Mapping of additional quantities to log by name alongside the
122 |         loss. These will typically correspond to terms in the loss. They
123 | should also have dimensions ('batch',) and will be averaged over the
124 | batch before logging.
125 | You need not include the loss itself in this dict; it will be added for
126 | you.
127 | """
128 | del targets, forcings, optional_kwargs
129 | batch_size = inputs.sizes['batch']
130 | dummy_loss = xarray_jax.DataArray(jnp.zeros(batch_size), dims=('batch',))
131 | return dummy_loss, {} # pytype: disable=bad-return-type
132 |
133 | def loss_and_predictions(
134 | self,
135 | inputs: xarray.Dataset,
136 | targets: xarray.Dataset,
137 | forcings: xarray.Dataset,
138 | **optional_kwargs,
139 | ) -> Tuple[LossAndDiagnostics, xarray.Dataset]:
140 | """Like .loss but also returns corresponding predictions.
141 |
142 | Implementing this is optional as it's not used directly by the Experiment,
143 | but it is required by autoregressive.Predictor when applying an inner
144 | Predictor autoregressively at training time; we need a loss at each step but
145 | also predictions to feed back in for the next step.
146 |
147 | Note the loss itself may not be directly regressing the predictions towards
148 |     targets; the loss may be computed in terms of transformed predictions and
149 | targets (or in some other way). For this reason we can't always cleanly
150 | separate this into step 1: get predictions, step 2: compute loss from them,
151 | hence the need for this combined method.
152 |
153 | Args:
154 | inputs:
155 | targets:
156 | forcings:
157 | **optional_kwargs:
158 | As for self.loss.
159 |
160 | Returns:
161 | (loss, diagnostics)
162 | As for self.loss
163 | predictions:
164 | The predictions which the loss relates to. These should be of the same
165 | shape as what you would get from
166 | `self.__call__(inputs, targets_template=targets)`, and should be in the
167 | same 'domain' as the inputs (i.e. they shouldn't be transformed
168 | differently to how the predictor expects its inputs).
169 | """
170 | raise NotImplementedError
171 |
--------------------------------------------------------------------------------
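A minimal sketch of the Predictor contract defined in predictor_base.py above: a trivial subclass that ignores its inputs and predicts zeros in the shape requested by targets_template. This is purely illustrative and not a predictor shipped with the library; the variable name and dimensions are arbitrary.

import numpy as np
import xarray
from graphcast import predictor_base


class ZeroPredictor(predictor_base.Predictor):
  """Hypothetical predictor that predicts zeros for every requested target."""

  def __call__(self, inputs, targets_template, forcings, **optional_kwargs):
    # The targets_template tells us which variables, levels and lead times are
    # required; here we simply return zeros of the same shape.
    return xarray.zeros_like(targets_template)


template = xarray.Dataset(
    {"2m_temperature": (("time", "lat", "lon"), np.empty((1, 2, 2)))})
predictions = ZeroPredictor()(
    inputs=xarray.Dataset(), targets_template=template, forcings=xarray.Dataset())
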
/graphcast/samplers_base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Base class for diffusion samplers."""
15 |
16 | import abc
17 | from typing import Optional
18 |
19 | from graphcast import denoisers_base
20 | import xarray
21 |
22 |
23 | class Sampler(abc.ABC):
24 | """A sampling algorithm for a denoising diffusion model.
25 |
26 | This is constructed with a denoising function, and uses it to draw samples.
27 | """
28 |
29 | _denoiser: denoisers_base.Denoiser
30 |
31 | def __init__(self, denoiser: denoisers_base.Denoiser):
32 | """Constructs Sampler.
33 |
34 | Args:
35 | denoiser: A Denoiser which has been trained with an MSE loss to predict
36 | the noise-free targets.
37 | """
38 | self._denoiser = denoiser
39 |
40 | @abc.abstractmethod
41 | def __call__(
42 | self,
43 | inputs: xarray.Dataset,
44 | targets_template: xarray.Dataset,
45 | forcings: Optional[xarray.Dataset] = None,
46 | **kwargs) -> xarray.Dataset:
47 | """Draws a sample using self._denoiser. Contract like Predictor.__call__."""
48 |
--------------------------------------------------------------------------------
/graphcast/solar_radiation_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import timeit
15 | from typing import Sequence
16 |
17 | from absl import logging
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from graphcast import solar_radiation
21 | import numpy as np
22 | import pandas as pd
23 | import xarray as xa
24 |
25 |
26 | def _get_grid_lat_lon_coords(
27 | num_lat: int, num_lon: int
28 | ) -> tuple[np.ndarray, np.ndarray]:
29 | """Generates a linear latitude-longitude grid of the given size.
30 |
31 | Args:
32 | num_lat: Size of the latitude dimension of the grid.
33 | num_lon: Size of the longitude dimension of the grid.
34 |
35 | Returns:
36 | A tuple `(lat, lon)` containing 1D arrays with the latitude and longitude
37 | coordinates in degrees of the generated grid.
38 | """
39 | lat = np.linspace(-90.0, 90.0, num=num_lat, endpoint=True)
40 | lon = np.linspace(0.0, 360.0, num=num_lon, endpoint=False)
41 | return lat, lon
42 |
43 |
44 | class SolarRadiationTest(parameterized.TestCase):
45 |
46 | def setUp(self):
47 | super().setUp()
48 | np.random.seed(0)
49 |
50 | def test_missing_dim_raises_value_error(self):
51 | data = xa.DataArray(
52 | np.random.randn(2, 2),
53 | coords=[np.array([0.1, 0.2]), np.array([0.0, 0.5])],
54 | dims=["lon", "x"],
55 | )
56 | with self.assertRaisesRegex(
57 | ValueError, r".* dimensions are missing in `data_array_like`."
58 | ):
59 | solar_radiation.get_toa_incident_solar_radiation_for_xarray(
60 | data, integration_period="1h", num_integration_bins=360
61 | )
62 |
63 | def test_missing_coordinate_raises_value_error(self):
64 | data = xa.Dataset(
65 | data_vars={"var1": (["x", "lat", "lon"], np.random.randn(2, 3, 2))},
66 | coords={
67 | "lat": np.array([0.0, 0.1, 0.2]),
68 | "lon": np.array([0.0, 0.5]),
69 | },
70 | )
71 | with self.assertRaisesRegex(
72 | ValueError, r".* coordinates are missing in `data_array_like`."
73 | ):
74 | solar_radiation.get_toa_incident_solar_radiation_for_xarray(
75 | data, integration_period="1h", num_integration_bins=360
76 | )
77 |
78 | def test_shape_multiple_timestamps(self):
79 | data = xa.Dataset(
80 | data_vars={"var1": (["time", "lat", "lon"], np.random.randn(2, 4, 2))},
81 | coords={
82 | "lat": np.array([0.0, 0.1, 0.2, 0.3]),
83 | "lon": np.array([0.0, 0.5]),
84 | "time": np.array([100, 200], dtype="timedelta64[s]"),
85 | "datetime": xa.Variable(
86 | "time", np.array([10, 20], dtype="datetime64[D]")
87 | ),
88 | },
89 | )
90 |
91 | actual = solar_radiation.get_toa_incident_solar_radiation_for_xarray(
92 | data, integration_period="1h", num_integration_bins=2
93 | )
94 |
95 | self.assertEqual(("time", "lat", "lon"), actual.dims)
96 | self.assertEqual((2, 4, 2), actual.shape)
97 |
98 | def test_shape_single_timestamp(self):
99 | data = xa.Dataset(
100 | data_vars={"var1": (["lat", "lon"], np.random.randn(4, 2))},
101 | coords={
102 | "lat": np.array([0.0, 0.1, 0.2, 0.3]),
103 | "lon": np.array([0.0, 0.5]),
104 | "datetime": np.datetime64(10, "D"),
105 | },
106 | )
107 |
108 | actual = solar_radiation.get_toa_incident_solar_radiation_for_xarray(
109 | data, integration_period="1h", num_integration_bins=2
110 | )
111 |
112 | self.assertEqual(("lat", "lon"), actual.dims)
113 | self.assertEqual((4, 2), actual.shape)
114 |
115 | @parameterized.named_parameters(
116 | dict(
117 | testcase_name="one_timestamp_jitted",
118 | periods=1,
119 | repeats=3,
120 | use_jit=True,
121 | ),
122 | dict(
123 | testcase_name="one_timestamp_non_jitted",
124 | periods=1,
125 | repeats=3,
126 | use_jit=False,
127 | ),
128 | dict(
129 | testcase_name="ten_timestamps_non_jitted",
130 | periods=10,
131 | repeats=1,
132 | use_jit=False,
133 | ),
134 | )
135 | def test_full_spatial_resolution(
136 | self, periods: int, repeats: int, use_jit: bool
137 | ):
138 | timestamps = pd.date_range(start="2023-09-25", periods=periods, freq="6h")
139 | # Generate a linear grid with 0.25 degrees resolution similar to ERA5.
140 | lat, lon = _get_grid_lat_lon_coords(num_lat=721, num_lon=1440)
141 |
142 | def benchmark() -> None:
143 | solar_radiation.get_toa_incident_solar_radiation(
144 | timestamps,
145 | lat,
146 | lon,
147 | integration_period="1h",
148 | num_integration_bins=360,
149 | use_jit=use_jit,
150 | ).block_until_ready()
151 |
152 | results = timeit.repeat(benchmark, repeat=repeats, number=1)
153 |
154 | logging.info(
155 | "Times to compute `tisr` for input of shape `%d, %d, %d` (seconds): %s",
156 | len(timestamps),
157 | len(lat),
158 | len(lon),
159 | np.array2string(np.array(results), precision=1),
160 | )
161 |
162 |
163 | class GetTsiTest(parameterized.TestCase):
164 |
165 | @parameterized.named_parameters(
166 | dict(
167 | testcase_name="reference_tsi_data",
168 | loader=solar_radiation.reference_tsi_data,
169 | expected_tsi=np.array([1361.0]),
170 | ),
171 | dict(
172 | testcase_name="era5_tsi_data",
173 | loader=solar_radiation.era5_tsi_data,
174 | expected_tsi=np.array([1360.9440]), # 0.9965 * 1365.7240
175 | ),
176 | )
177 | def test_mid_2020_lookup(
178 | self, loader: solar_radiation.TsiDataLoader, expected_tsi: np.ndarray
179 | ):
180 | tsi_data = loader()
181 |
182 | tsi = solar_radiation.get_tsi(
183 | [np.datetime64("2020-07-02T00:00:00")], tsi_data
184 | )
185 |
186 | np.testing.assert_allclose(expected_tsi, tsi)
187 |
188 | @parameterized.named_parameters(
189 | dict(
190 | testcase_name="beginning_2020_left_boundary",
191 | timestamps=[np.datetime64("2020-01-01T00:00:00")],
192 | expected_tsi=np.array([1000.0]),
193 | ),
194 | dict(
195 | testcase_name="mid_2020_exact",
196 | timestamps=[np.datetime64("2020-07-02T00:00:00")],
197 | expected_tsi=np.array([1000.0]),
198 | ),
199 | dict(
200 | testcase_name="beginning_2021_interpolated",
201 | timestamps=[np.datetime64("2021-01-01T00:00:00")],
202 | expected_tsi=np.array([1150.0]),
203 | ),
204 | dict(
205 | testcase_name="mid_2021_lookup",
206 | timestamps=[np.datetime64("2021-07-02T12:00:00")],
207 | expected_tsi=np.array([1300.0]),
208 | ),
209 | dict(
210 | testcase_name="beginning_2022_interpolated",
211 | timestamps=[np.datetime64("2022-01-01T00:00:00")],
212 | expected_tsi=np.array([1250.0]),
213 | ),
214 | dict(
215 | testcase_name="mid_2022_lookup",
216 | timestamps=[np.datetime64("2022-07-02T12:00:00")],
217 | expected_tsi=np.array([1200.0]),
218 | ),
219 | dict(
220 | testcase_name="beginning_2023_right_boundary",
221 | timestamps=[np.datetime64("2023-01-01T00:00:00")],
222 | expected_tsi=np.array([1200.0]),
223 | ),
224 | )
225 | def test_interpolation(
226 | self, timestamps: Sequence[np.datetime64], expected_tsi: np.ndarray
227 | ):
228 | tsi_data = xa.DataArray(
229 | np.array([1000.0, 1300.0, 1200.0]),
230 | dims=["time"],
231 | coords={"time": np.array([2020.5, 2021.5, 2022.5])},
232 | )
233 |
234 | tsi = solar_radiation.get_tsi(timestamps, tsi_data)
235 |
236 | np.testing.assert_allclose(expected_tsi, tsi)
237 |
238 |
239 | if __name__ == "__main__":
240 | absltest.main()
241 |
--------------------------------------------------------------------------------
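A usage sketch mirroring test_shape_single_timestamp above: computing TOA incident solar radiation for a small lat/lon grid at a single timestamp. The grid values and date are arbitrary.

import numpy as np
import xarray as xa
from graphcast import solar_radiation

data = xa.Dataset(
    data_vars={"var1": (["lat", "lon"], np.zeros((4, 2)))},
    coords={
        "lat": np.array([0.0, 0.1, 0.2, 0.3]),
        "lon": np.array([0.0, 0.5]),
        "datetime": np.datetime64("2023-09-25T00:00:00"),
    },
)
tisr = solar_radiation.get_toa_incident_solar_radiation_for_xarray(
    data, integration_period="1h", num_integration_bins=2)
print(tisr.dims, tisr.shape)  # ('lat', 'lon') (4, 2)
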
/graphcast/sparse_transformer_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utils for training models in low precision."""
15 |
16 | import functools
17 | from typing import Callable, Tuple, Union
18 |
19 | import jax
20 | import jax.numpy as jnp
21 |
22 |
23 | # Wrappers for jax.lax.reduce_precision which is non-differentiable.
24 | @functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2))
25 | def reduce_precision(x, exponent_bits, mantissa_bits):
26 | return jax.tree_util.tree_map(
27 | lambda y: jax.lax.reduce_precision(y, exponent_bits, mantissa_bits), x)
28 |
29 |
30 | def reduce_precision_fwd(x, exponent_bits, mantissa_bits):
31 | return reduce_precision(x, exponent_bits, mantissa_bits), None
32 |
33 |
34 | def reduce_precision_bwd(exponent_bits, mantissa_bits, res, dout):
35 | del res # Unused.
36 | return reduce_precision(dout, exponent_bits, mantissa_bits),
37 |
38 |
39 | reduce_precision.defvjp(reduce_precision_fwd, reduce_precision_bwd)
40 |
41 |
42 | def wrap_fn_for_upcast_downcast(inputs: Union[jnp.ndarray,
43 | Tuple[jnp.ndarray, ...]],
44 | fn: Callable[[Union[jnp.ndarray,
45 | Tuple[jnp.ndarray, ...]]],
46 | Union[jnp.ndarray,
47 | Tuple[jnp.ndarray, ...]]],
48 | f32_upcast: bool = True,
49 | guard_against_excess_precision: bool = True
50 | ) -> Union[jnp.ndarray,
51 | Tuple[jnp.ndarray, ...]]:
52 | """Wraps `fn` to upcast to float32 and then downcast, for use with BF16."""
53 | # Do not upcast if the inputs are already in float32.
54 | # This removes a no-op `jax.lax.reduce_precision` which is unsupported
55 | # in jax2tf at the moment.
56 | if isinstance(inputs, Tuple):
57 | f32_upcast = f32_upcast and inputs[0].dtype != jnp.float32
58 | orig_dtype = inputs[0].dtype
59 | else:
60 | f32_upcast = f32_upcast and inputs.dtype != jnp.float32
61 | orig_dtype = inputs.dtype
62 |
63 | if f32_upcast:
64 | inputs = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), inputs)
65 |
66 | if guard_against_excess_precision:
67 | # This is evil magic to guard against differences in precision in the QK
68 | # calculation between the forward pass and backwards pass. This is like
69 | # --xla_allow_excess_precision=false but scoped here.
70 | finfo = jnp.finfo(orig_dtype) # jnp important!
71 | inputs = reduce_precision(inputs, finfo.nexp, finfo.nmant)
72 |
73 | output = fn(inputs)
74 | if f32_upcast:
75 | output = jax.tree_util.tree_map(lambda x: x.astype(orig_dtype), output)
76 | return output
77 |
--------------------------------------------------------------------------------
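A small sketch of wrap_fn_for_upcast_downcast above: a bfloat16 input is upcast to float32 for the wrapped function, and the output is cast back to the original dtype. The input shape and the wrapped function are arbitrary.

import jax.numpy as jnp
from graphcast import sparse_transformer_utils

x = jnp.ones((4, 8), dtype=jnp.bfloat16)
y = sparse_transformer_utils.wrap_fn_for_upcast_downcast(
    x, lambda a: a * 2.0, f32_upcast=True)
print(y.dtype)  # bfloat16: the output is downcast back to the input dtype
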
/graphcast/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """A Transformer model for weather predictions.
15 |
16 | This model wraps a transformer model and swaps the leading two axes of the
17 | nodes in the input graph prior to evaluating the model, to make it compatible
18 | with a [nodes, batch, ...] ordering of the inputs.
19 | """
20 |
21 | from typing import Any, Mapping, Optional
22 |
23 | from graphcast import typed_graph
24 | import haiku as hk
25 | import jax
26 | import jax.numpy as jnp
27 | import numpy as np
28 | from scipy import sparse
29 |
30 |
31 | Kwargs = Mapping[str, Any]
32 |
33 |
34 | def _get_adj_matrix_for_edge_set(
35 | graph: typed_graph.TypedGraph,
36 | edge_set_name: str,
37 | add_self_edges: bool,
38 | ):
39 | """Returns the adjacency matrix for the given graph and edge set."""
40 | # Get nodes and edges of the graph.
41 | edge_set_key = graph.edge_key_by_name(edge_set_name)
42 | sender_node_set, receiver_node_set = edge_set_key.node_sets
43 |
44 | # Compute number of sender and receiver nodes.
45 | sender_n_node = graph.nodes[sender_node_set].n_node[0]
46 | receiver_n_node = graph.nodes[receiver_node_set].n_node[0]
47 |
48 | # Build adjacency matrix.
49 | adj_mat = sparse.csr_matrix((sender_n_node, receiver_n_node), dtype=np.bool_)
50 | edge_set = graph.edges[edge_set_key]
51 | s, r = edge_set.indices
52 | adj_mat[s, r] = True
53 | if add_self_edges:
54 | # Should only do this if we are certain the adjacency matrix is square.
55 | assert sender_node_set == receiver_node_set
56 | adj_mat[np.arange(sender_n_node), np.arange(receiver_n_node)] = True
57 | return adj_mat
58 |
59 |
60 | class MeshTransformer(hk.Module):
61 | """A Transformer for inputs with ordering [nodes, batch, ...]."""
62 |
63 | def __init__(self,
64 | transformer_ctor,
65 | transformer_kwargs: Kwargs,
66 | name: Optional[str] = None):
67 | """Initialises the Transformer model.
68 |
69 | Args:
70 | transformer_ctor: Constructor for transformer.
71 | transformer_kwargs: Kwargs to pass to the transformer module.
72 | name: Optional name for haiku module.
73 | """
74 | super().__init__(name=name)
75 | # We defer the transformer initialisation to the first call to __call__,
76 |     # where we can build the mask from the senders and receivers of the TypedGraph.
77 | self._batch_first_transformer = None
78 | self._transformer_ctor = transformer_ctor
79 | self._transformer_kwargs = transformer_kwargs
80 |
81 | @hk.name_like('__init__')
82 | def _maybe_init_batch_first_transformer(self, x: typed_graph.TypedGraph):
83 | if self._batch_first_transformer is not None:
84 | return
85 | self._batch_first_transformer = self._transformer_ctor(
86 | adj_mat=_get_adj_matrix_for_edge_set(
87 | graph=x,
88 | edge_set_name='mesh',
89 | add_self_edges=True,
90 | ),
91 | **self._transformer_kwargs,
92 | )
93 |
94 | def __call__(
95 | self, x: typed_graph.TypedGraph,
96 | global_norm_conditioning: jax.Array
97 | ) -> typed_graph.TypedGraph:
98 | """Applies the model to the input graph and returns graph of same shape."""
99 |
100 | if set(x.nodes.keys()) != {'mesh_nodes'}:
101 | raise ValueError(
102 | f'Expected x.nodes to have key `mesh_nodes`, got {x.nodes.keys()}.'
103 | )
104 | features = x.nodes['mesh_nodes'].features
105 | if features.ndim != 3: # pytype: disable=attribute-error # jax-ndarray
106 | raise ValueError(
107 |           'Expected `x.nodes["mesh_nodes"].features` to have ndim 3, got'
108 | f' {features.ndim}.'
109 | ) # pytype: disable=attribute-error # jax-ndarray
110 |
111 | # Initialise transformer and mask.
112 | self._maybe_init_batch_first_transformer(x)
113 |
114 | y = jnp.transpose(features, axes=[1, 0, 2])
115 | y = self._batch_first_transformer(y, global_norm_conditioning)
116 | y = jnp.transpose(y, axes=[1, 0, 2])
117 | x = x._replace(
118 | nodes={
119 | 'mesh_nodes': x.nodes['mesh_nodes']._replace(
120 | features=y.astype(features.dtype)
121 | )
122 | }
123 | )
124 | return x
125 |
--------------------------------------------------------------------------------
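A standalone sketch (plain scipy, independent of TypedGraph) of the adjacency-matrix construction performed by _get_adj_matrix_for_edge_set above; the toy edge indices are made up. As in the library code, assigning into a CSR matrix works but may emit a SparseEfficiencyWarning.

import numpy as np
from scipy import sparse

senders = np.array([0, 1, 2])
receivers = np.array([1, 2, 0])
num_nodes = 3

adj_mat = sparse.csr_matrix((num_nodes, num_nodes), dtype=np.bool_)
adj_mat[senders, receivers] = True                           # directed edges
adj_mat[np.arange(num_nodes), np.arange(num_nodes)] = True   # self edges
print(adj_mat.toarray())
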
/graphcast/typed_graph.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Data-structure for storing graphs with typed edges and nodes."""
15 |
16 | from typing import NamedTuple, Any, Union, Tuple, Mapping, TypeVar
17 |
18 | ArrayLike = Union[Any] # np.ndarray, jnp.ndarray, tf.tensor
19 | ArrayLikeTree = Union[Any, ArrayLike] # Nest of ArrayLike
20 |
21 | _T = TypeVar('_T')
22 |
23 |
24 | # All tensors have a "flat_batch_axis", which is similar to the leading
25 | # axes of graph_tuples:
26 | # * In the case of nodes this is simply a shared node and flat batch axis, with
27 | # size corresponding to the total number of nodes in the flattened batch.
28 | # * In the case of edges this is simply a shared edge and flat batch axis, with
29 | # size corresponding to the total number of edges in the flattened batch.
30 | # * In the case of globals this is simply the number of graphs in the flattened
31 | # batch.
32 |
33 | # All shapes may also have any additional leading shape "batch_shape".
34 | # Options for building batches are:
35 | # * Use a provided "flatten" method that takes a leading `batch_shape` and
36 | #   folds it into the flat_batch_axis (this will be useful when using `tf.Dataset`
37 | # which supports batching into RaggedTensors, with leading batch shape even
38 | # if graphs have different numbers of nodes and edges), so the RaggedBatches
39 | # can then be converted into something without ragged dimensions that jax can
40 | # use.
41 | # * Directly build a "flat batch" using a provided function for batching a list
42 | # of graphs (how it is done in `jraph`).
43 |
44 |
45 | class NodeSet(NamedTuple):
46 | """Represents a set of nodes."""
47 | n_node: ArrayLike # [num_flat_graphs]
48 | features: ArrayLikeTree # Prev. `nodes`: [num_flat_nodes] + feature_shape
49 |
50 |
51 | class EdgesIndices(NamedTuple):
52 | """Represents indices to nodes adjacent to the edges."""
53 | senders: ArrayLike # [num_flat_edges]
54 | receivers: ArrayLike # [num_flat_edges]
55 |
56 |
57 | class EdgeSet(NamedTuple):
58 | """Represents a set of edges."""
59 | n_edge: ArrayLike # [num_flat_graphs]
60 | indices: EdgesIndices
61 | features: ArrayLikeTree # Prev. `edges`: [num_flat_edges] + feature_shape
62 |
63 |
64 | class Context(NamedTuple):
65 |   # `n_graph` always contains ones, but it is useful for querying the leading
66 |   # shape in case of graphs without any node or edge sets.
67 | n_graph: ArrayLike # [num_flat_graphs]
68 | features: ArrayLikeTree # Prev. `globals`: [num_flat_graphs] + feature_shape
69 |
70 |
71 | class EdgeSetKey(NamedTuple):
72 | name: str # Name of the EdgeSet.
73 |
74 | # Sender node set name and receiver node set name connected by the edge set.
75 | node_sets: Tuple[str, str]
76 |
77 |
78 | class TypedGraph(NamedTuple):
79 | """A graph with typed nodes and edges.
80 |
81 | A typed graph is made of a context, multiple sets of nodes and multiple
82 | sets of edges connecting those nodes (as indicated by the EdgeSetKey).
83 | """
84 |
85 | context: Context
86 | nodes: Mapping[str, NodeSet]
87 | edges: Mapping[EdgeSetKey, EdgeSet]
88 |
89 | def edge_key_by_name(self, name: str) -> EdgeSetKey:
90 | found_key = [k for k in self.edges.keys() if k.name == name]
91 | if len(found_key) != 1:
92 | raise KeyError("invalid edge key '{}'. Available edges: [{}]".format(
93 | name, ', '.join(x.name for x in self.edges.keys())))
94 | return found_key[0]
95 |
96 | def edge_by_name(self, name: str) -> EdgeSet:
97 | return self.edges[self.edge_key_by_name(name)]
98 |
--------------------------------------------------------------------------------
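A minimal sketch of assembling a TypedGraph from the NamedTuples defined above. The node/edge set names mirror those used by the mesh transformer elsewhere in the codebase, and the feature shapes and edge indices are arbitrary.

import numpy as np
from graphcast import typed_graph

nodes = {
    'mesh_nodes': typed_graph.NodeSet(
        n_node=np.array([3]), features=np.zeros((3, 4))),
}
edge_key = typed_graph.EdgeSetKey(
    name='mesh', node_sets=('mesh_nodes', 'mesh_nodes'))
edges = {
    edge_key: typed_graph.EdgeSet(
        n_edge=np.array([2]),
        indices=typed_graph.EdgesIndices(
            senders=np.array([0, 1]), receivers=np.array([1, 2])),
        features=np.zeros((2, 1))),
}
graph = typed_graph.TypedGraph(
    context=typed_graph.Context(n_graph=np.array([1]), features=()),
    nodes=nodes, edges=edges)
assert graph.edge_key_by_name('mesh') == edge_key
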
/graphcast/xarray_tree.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utilities for working with trees of xarray.DataArray (including Datasets).
15 |
16 | Note that xarray.Dataset doesn't work out-of-the-box with the `tree` library;
17 | it won't work as a leaf node since it implements Mapping, but also won't work
18 | as an internal node since tree doesn't know how to re-create it properly.
19 |
20 | To fix this, we reimplement a subset of `map_structure`, exposing its
21 | constituent DataArrays as leaf nodes. This means it can be mapped over as a
22 | generic container of DataArrays, while still preserving the result as a Dataset
23 | where possible.
24 |
25 | This is useful because in a few places we need to handle a general
26 | Mapping[str, DataArray] (where the coordinates might not be compatible across
27 | the constituent DataArrays) but also the special case of a Dataset nicely.
28 |
29 | For the result of e.g. map_structure(fn, dataset), if fn returns None for
30 | some of the child DataArrays, they will be omitted from the returned dataset. If
31 | any values other than DataArrays or None are returned, then we don't attempt to
32 | return a Dataset and just return a plain dict of the results. Similarly if
33 | DataArrays are returned but with non-matching coordinates, it will just return a
34 | plain dict of DataArrays.
35 |
36 | Note xarray datatypes are registered with `jax.tree_util` by xarray_jax.py,
37 | but `jax.tree_util.tree_map` is distinct from `xarray_tree.map_structure`,
38 | as the former exposes the underlying JAX/numpy arrays as leaf nodes, while the
39 | latter exposes DataArrays as leaf nodes.
40 | """
41 |
42 | from typing import Any, Callable
43 |
44 | import xarray
45 |
46 |
47 | def map_structure(func: Callable[..., Any], *structures: Any) -> Any:
48 | """Maps func through given structures with xarrays. See tree.map_structure."""
49 | if not callable(func):
50 | raise TypeError(f'func must be callable, got: {func}')
51 | if not structures:
52 | raise ValueError('Must provide at least one structure')
53 |
54 | first = structures[0]
55 | if isinstance(first, xarray.Dataset):
56 | data = {k: func(*[s[k] for s in structures]) for k in first.keys()}
57 | if all(isinstance(a, (type(None), xarray.DataArray))
58 | for a in data.values()):
59 | data_arrays = [v.rename(k) for k, v in data.items() if v is not None]
60 | try:
61 | return xarray.merge(data_arrays, join='exact')
62 | except ValueError: # Exact join not possible.
63 | pass
64 | return data
65 | if isinstance(first, dict):
66 | return {k: map_structure(func, *[s[k] for s in structures])
67 | for k in first.keys()}
68 | if isinstance(first, (list, tuple, set)):
69 | return type(first)(map_structure(func, *s) for s in zip(*structures))
70 | return func(*structures)
71 |
--------------------------------------------------------------------------------
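A short sketch of map_structure above: each DataArray in a Dataset is visited as a leaf, results are merged back into a Dataset when coordinates allow, and None results are dropped. The toy Dataset mirrors the one used in the tests that follow.

import numpy as np
import xarray
from graphcast import xarray_tree

ds = xarray.Dataset(
    data_vars={"foo": (("x", "y"), np.zeros((2, 3))),
               "bar": (("x",), np.zeros((2,)))},
    coords={"x": [1, 2], "y": [10, 20, 30]})

doubled = xarray_tree.map_structure(lambda leaf: leaf * 2, ds)
assert isinstance(doubled, xarray.Dataset)        # exact merge succeeded

dropped = xarray_tree.map_structure(
    lambda leaf: leaf if leaf.name == "foo" else None, ds)
assert set(dropped.keys()) == {"foo"}             # None results are omitted
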
/graphcast/xarray_tree_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for xarray_tree."""
15 |
16 | from absl.testing import absltest
17 | from graphcast import xarray_tree
18 | import numpy as np
19 | import xarray
20 |
21 |
22 | TEST_DATASET = xarray.Dataset(
23 | data_vars={
24 | "foo": (("x", "y"), np.zeros((2, 3))),
25 | "bar": (("x",), np.zeros((2,))),
26 | },
27 | coords={
28 | "x": [1, 2],
29 | "y": [10, 20, 30],
30 | }
31 | )
32 |
33 |
34 | class XarrayTreeTest(absltest.TestCase):
35 |
36 | def test_map_structure_maps_over_leaves_but_preserves_dataset_type(self):
37 | def fn(leaf):
38 | self.assertIsInstance(leaf, xarray.DataArray)
39 | result = leaf + 1
40 | # Removing the name from the returned DataArray to test that we don't rely
41 | # on it being present to restore the correct names in the result:
42 | result = result.rename(None)
43 | return result
44 |
45 | result = xarray_tree.map_structure(fn, TEST_DATASET)
46 | self.assertIsInstance(result, xarray.Dataset)
47 | self.assertSameElements({"foo", "bar"}, result.keys())
48 |
49 | def test_map_structure_on_data_arrays(self):
50 | data_arrays = dict(TEST_DATASET)
51 | result = xarray_tree.map_structure(lambda x: x+1, data_arrays)
52 | self.assertIsInstance(result, dict)
53 | self.assertSameElements({"foo", "bar"}, result.keys())
54 |
55 | def test_map_structure_on_dataset_plain_dict_when_coords_incompatible(self):
56 | def fn(leaf):
57 | # Returns DataArrays that can't be exactly merged back into a Dataset
58 | # due to the coordinates not matching:
59 | if leaf.name == "foo":
60 | return xarray.DataArray(
61 | data=np.zeros(2), dims=("x",), coords={"x": [1, 2]})
62 | else:
63 | return xarray.DataArray(
64 | data=np.zeros(2), dims=("x",), coords={"x": [3, 4]})
65 |
66 | result = xarray_tree.map_structure(fn, TEST_DATASET)
67 | self.assertIsInstance(result, dict)
68 | self.assertSameElements({"foo", "bar"}, result.keys())
69 |
70 | def test_map_structure_on_dataset_drops_vars_with_none_return_values(self):
71 | def fn(leaf):
72 | return leaf if leaf.name == "foo" else None
73 |
74 | result = xarray_tree.map_structure(fn, TEST_DATASET)
75 | self.assertIsInstance(result, xarray.Dataset)
76 | self.assertSameElements({"foo"}, result.keys())
77 |
78 | def test_map_structure_on_dataset_returns_plain_dict_other_return_types(self):
79 | def fn(leaf):
80 | self.assertIsInstance(leaf, xarray.DataArray)
81 | return "not a DataArray"
82 |
83 | result = xarray_tree.map_structure(fn, TEST_DATASET)
84 | self.assertEqual({"foo": "not a DataArray",
85 | "bar": "not a DataArray"}, result)
86 |
87 | def test_map_structure_two_args_different_variable_orders(self):
88 | dataset_different_order = TEST_DATASET[["bar", "foo"]]
89 | def fn(arg1, arg2):
90 | self.assertEqual(arg1.name, arg2.name)
91 | xarray_tree.map_structure(fn, TEST_DATASET, dataset_different_order)
92 |
93 |
94 | if __name__ == "__main__":
95 | absltest.main()
96 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Module setuptools script."""
15 |
16 | from setuptools import setup
17 |
18 | description = (
19 | "GraphCast: Learning skillful medium-range global weather forecasting"
20 | )
21 |
22 | setup(
23 | name="graphcast",
24 | version="0.2.0.dev",
25 | description=description,
26 | long_description=description,
27 | author="DeepMind",
28 | license="Apache License, Version 2.0",
29 | keywords="GraphCast Weather Prediction",
30 | url="https://github.com/deepmind/graphcast",
31 | packages=["graphcast"],
32 | install_requires=[
33 | "cartopy",
34 | "chex",
35 | "colabtools",
36 | "dask",
37 | "dinosaur-dycore",
38 | "dm-haiku",
39 | "dm-tree",
40 | "jax",
41 | "jraph",
42 | "matplotlib",
43 | "numpy",
44 | "pandas",
45 | "rtree",
46 | "scipy",
47 | "trimesh",
48 | "typing_extensions",
49 | "xarray",
50 | "xarray_tensorstore"
51 | ],
52 | classifiers=[
53 | "Development Status :: 3 - Alpha",
54 | "Intended Audience :: Science/Research",
55 | "License :: OSI Approved :: Apache Software License",
56 | "Operating System :: POSIX :: Linux",
57 | "Programming Language :: Python :: 3.10",
58 | "Programming Language :: Python :: 3.11",
59 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
60 | "Topic :: Scientific/Engineering :: Atmospheric Science",
61 | "Topic :: Scientific/Engineering :: Physics",
62 | ],
63 | )
64 |
--------------------------------------------------------------------------------