├── .gitignore ├── .readthedocs.yaml ├── CONTRIBUTING.md ├── LICENSE ├── NCEP ├── Readme.md ├── cronjob_cloud.sh ├── cronjob_hera.sh ├── docs │ ├── Makefile │ ├── README.md │ ├── make.bat │ ├── requirements.txt │ └── source │ │ ├── conf.py │ │ ├── index.rst │ │ ├── inputs.rst │ │ ├── installation.rst │ │ ├── introduction.rst │ │ ├── outputs.rst │ │ └── run.rst ├── environment.yml ├── gc_datadissm_hera.sh ├── gc_prepdata_hera.sh ├── gc_runfcst_hera.sh ├── gcjob_13pl_cloud.sh ├── gcjob_37pl_cloud.sh ├── gdas_utility.py ├── run_graphcast.py ├── upload_to_s3bucket.py └── utils │ └── nc2grib.py ├── README.md ├── docs ├── GenCast_0p25deg_accelerator_scorecard.png ├── GenCast_0p25deg_attention_implementation_scorecard.png ├── GenCast_1p0deg_Mini_ENS_scorecard.png ├── cloud_vm_setup.md ├── local_runtime_popup_1.png ├── local_runtime_popup_2.png ├── local_runtime_url.png ├── project.png ├── provision_tpu.png └── tpu_types.png ├── gencast_demo_cloud_vm.ipynb ├── gencast_mini_demo.ipynb ├── graphcast ├── autoregressive.py ├── casting.py ├── checkpoint.py ├── checkpoint_test.py ├── data_utils.py ├── data_utils_test.py ├── deep_typed_graph_net.py ├── denoiser.py ├── denoisers_base.py ├── dpm_solver_plus_plus_2s.py ├── gencast.py ├── graphcast.py ├── grid_mesh_connectivity.py ├── grid_mesh_connectivity_test.py ├── icosahedral_mesh.py ├── icosahedral_mesh_test.py ├── losses.py ├── mlp.py ├── model_utils.py ├── nan_cleaning.py ├── normalization.py ├── predictor_base.py ├── rollout.py ├── samplers_base.py ├── samplers_utils.py ├── solar_radiation.py ├── solar_radiation_test.py ├── sparse_transformer.py ├── sparse_transformer_utils.py ├── transformer.py ├── typed_graph.py ├── typed_graph_net.py ├── xarray_jax.py ├── xarray_jax_test.py ├── xarray_tree.py └── xarray_tree_test.py ├── graphcast_demo.ipynb └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg* 2 | __pycache__ 3 | *.swp 4 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | build: 8 | os: ubuntu-20.04 9 | tools: 10 | python: "3.9" 11 | 12 | # Build documentation in the docs/ directory with Sphinx 13 | sphinx: 14 | configuration: NCEP/docs/source/conf.py 15 | 16 | # Build documentation with MkDocs 17 | #mkdocs: 18 | # configuration: mkdocs.yml 19 | 20 | # Optionally build your docs in additional formats such as PDF and ePub 21 | formats: all 22 | 23 | # Optionally set the version of Python and requirements required to build your docs 24 | python: 25 | install: 26 | - requirements: NCEP/docs/requirements.txt 27 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Contributor License Agreement 4 | 5 | Contributions to this project must be accompanied by a Contributor License 6 | Agreement. You (or your employer) retain the copyright to your contribution, 7 | this simply gives us permission to use and redistribute your contributions as 8 | part of the project. Head over to to see 9 | your current agreements on file or to sign a new one. 
10 | 11 | You generally only need to submit a CLA once, so if you've already submitted one 12 | (even if it was for a different project), you probably don't need to do it 13 | again. 14 | 15 | ## Code reviews 16 | 17 | All submissions, including submissions by project members, require review. We 18 | use GitHub pull requests for this purpose. Consult 19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 20 | information on using pull requests. 21 | 22 | ## Community Guidelines 23 | 24 | This project follows [Google's Open Source Community 25 | Guidelines](https://opensource.google/conduct/). 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /NCEP/Readme.md: -------------------------------------------------------------------------------- 1 | # GraphCast model with NCEP GDAS Products as ICs 2 | 3 | This repository provides scripts to run real-time GraphCast using GDAS products as inputs. 
There are multiple scripts in the repository, including: 4 | - `gdas_utility.py`: a Python script that downloads Global Data Assimilation System (GDAS) data produced by the National Centers for Environmental Prediction (NCEP) from the NOAA S3 bucket (or NOMADS) and prepares the data in a format suitable for feeding into the GraphCast weather prediction system. 5 | - `run_graphcast.py`: a Python script that calls GraphCast, takes GDAS products as input, and produces six-hourly forecasts with an arbitrary forecast length (e.g., 40 steps --> 10 days). 6 | - `graphcast_job_[machine_name].sh`: a Bash script that automates running GraphCast in real time on the AWS cloud machine (should be submitted through a cron job). 7 | 8 | ## Table of Contents 9 | - [Overview](#overview) 10 | - [Prerequisites and Installation](#prerequisites-and-installation) 11 | - [Usage](#usage) 12 | - [GDAS Utility](#gdas-utility) 13 | - [Run GraphCast](#run-graphcast) 14 | - [Run GraphCast Through Cronjob](#run-graphcast-through-cronjob) 15 | - [Output](#output) 16 | - [Contact](#contact) 17 | 18 | ## Overview 19 | 20 | The National Centers for Environmental Prediction (NCEP) provides GDAS data that can be used for weather prediction and analysis. This repository simplifies the process of downloading GDAS data, extracting the relevant variables, and converting them into a format compatible with the GraphCast weather prediction system. In addition, it automates running GraphCast with GDAS inputs on the NOAA clusters. 21 | 22 | ## Prerequisites and Installation 23 | 24 | To install the package, run the following commands: 25 | 26 | ```bash 27 | conda create --name mlwp python=3.10 28 | ``` 29 | 30 | ```bash 31 | conda activate mlwp 32 | ``` 33 | 34 | ```bash 35 | pip install dm-tree boto3 xarray netcdf4 36 | ``` 37 | 38 | ```bash 39 | conda install --channel conda-forge cartopy 40 | ``` 41 | 42 | ```bash 43 | pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip 44 | ``` 45 | 46 | ```bash 47 | pip install pygrib requests bs4 48 | ``` 49 | 50 | If your OS is macOS, `wget` also has to be installed: 51 | 52 | ```bash 53 | brew install wget 54 | ``` 55 | 56 | 57 | If you would like to save outputs in GRIB2 format, the following packages are also needed: 58 | 59 | ```bash 60 | pip install ecmwflibs 61 | ``` 62 | ```bash 63 | pip install iris 64 | ``` 65 | 66 | ```bash 67 | pip install iris-grib 68 | ``` 69 | 70 | This will install the packages and most of their dependencies. 71 | 72 | 73 | Additionally, the utility uses the `wgrib2` tool for extracting specific variables from the GDAS data. You can download and install `wgrib2` from [here](http://www.cpc.ncep.noaa.gov/products/wesley/wgrib2/). Make sure it is included in your system's PATH.
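Before moving on, you can optionally sanity-check the environment. The short Python sketch below is not part of the repository; it simply verifies that the main Python dependencies import and that `wgrib2` is discoverable on your PATH:

```python
# Optional environment check (illustrative only, not part of the repository).
import importlib
import shutil

# Core Python packages used by gdas_utility.py and run_graphcast.py.
for pkg in ["xarray", "netCDF4", "boto3", "pygrib", "graphcast"]:
    try:
        importlib.import_module(pkg)
        print(f"{pkg}: OK")
    except ImportError as err:
        print(f"{pkg}: MISSING ({err})")

# gdas_utility.py relies on the wgrib2 executable being on the PATH.
print("wgrib2:", shutil.which("wgrib2") or "NOT FOUND on PATH")
```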
74 | 75 | ## Usage 76 | 77 | To use the utility, follow these steps: 78 | 79 | Clone the NOAA-EMC GraphCast repository: 80 | 81 | ```bash 82 | git clone https://github.com/NOAA-EMC/graphcast.git 83 | ``` 84 | 85 | ```bash 86 | cd graphcast/NCEP 87 | ``` 88 | 89 | ## GDAS Utility 90 | 91 | To download and prepare GDAS data, use the following command: 92 | 93 | ```bash 94 | python3 gdas_utility.py yyyymmddhh yyyymmddhh --level 13 --source s3 --output /directory/to/output --download /directory/to/download --keep no 95 | ``` 96 | 97 | #### Arguments (required): 98 | 99 | - `yyyymmddhh`: Start datetime 100 | - `yyyymmddhh`: End datetime 101 | 102 | #### Arguments (optional): 103 | 104 | - `-l or --level`: [13, 37], the number of pressure levels (default: 13) 105 | - `-s or --source`: [s3, nomads], the source from which to download GDAS data (default: "s3") 106 | - `-o or --output`: /directory/to/output, the directory for the output NetCDF file (default: current directory) 107 | - `-d or --download`: /directory/to/download, the download directory for GRIB2 files (default: current directory) 108 | - `-k or --keep`: [yes, no], whether to keep the downloaded data after processing (default: "no") 109 | 110 | Example usage with options: 111 | 112 | ```bash 113 | python3 gdas_utility.py 2023060600 2023060606 -o /path/to/output -d /path/to/download 114 | ``` 115 | 116 | Note: 117 | - The 37-pressure-level option is still under development. 118 | - GraphCast only needs two states for initialization; however, `gdas_utility.py` can provide longer outputs for evaluating the model (e.g., 10 days). 119 | 120 | 121 | ## Run GraphCast 122 | 123 | To run GraphCast, use the following command: 124 | 125 | ```bash 126 | python3 run_graphcast.py --input /path/to/input/file --output /path/to/output/file --weights /path/to/graphcast/weights --length forecast_length --upload yes --keep no 127 | ``` 128 | 129 | #### Arguments (required): 130 | 131 | - `-i or --input`: /path/to/input/file, the path to the input NetCDF file (including file name and extension) 132 | - `-o or --output`: /path/to/output/file, the path to the output NetCDF file (including file name and extension) 133 | - `-w or --weights`: /path/to/graphcast/weights, the path to the parent directory of the GraphCast params (weights) and stats from the pre-trained model 134 | - `-l or --length`: an integer in the range [1, 40], the number of forecast time steps (6-hourly; e.g., 40 → 10 days) 135 | 136 | #### Arguments (optional): 137 | 138 | - `-u or --upload`: [yes, no], whether to upload the input and output files to the NOAA S3 bucket [noaa-nws-graphcastgfs-pds] (default: "no") 139 | - `-k or --keep`: [yes, no], whether to keep the input and output files after uploading to the NOAA S3 bucket (default: "no") 140 | 141 | Example usage with options (1-day forecast): 142 | 143 | ```bash 144 | python3 run_graphcast.py -i /path/to/input -o /path/to/output -w /path/to/graphcast/weights -l 4 145 | ``` 146 | 147 | ## Run GraphCast Through Cronjob 148 | 149 | Submit `cronjob_[machine_name].sh` as a cron job to run GraphCast and get real-time (every 6 hours) forecasts.
150 | 151 | ```bash 152 | # Example CronJob to run GraphCast every 6 hours 153 | 0 */6 * * * /lustre/Sadegh.Tabas/graphcast/NCEP/cronjob_cloud.sh >> /lustre/Sadegh.Tabas/graphcast/NCEP/logfile.log 2>&1 154 | ``` 155 | 156 | ## Output 157 | 158 | The processed GDAS data as well as GraphCast forecasts will be saved in NetCDF format in the related directories (uploading to NOAA S3 bucket option is also provided for both input and output files). The files will be named based on the date. 159 | 160 | ## Contact 161 | 162 | For questions or issues, please contact [Sadegh.Tabas@noaa.gov](mailto:Sadegh.Tabas@noaa.gov). 163 | -------------------------------------------------------------------------------- /NCEP/cronjob_cloud.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | sbatch /contrib/Sadegh.Tabas/operational/graphcast/NCEP/gcjob_13pl_cloud.sh 3 | # sbatch /contrib/Sadegh.Tabas/operational/graphcast/NCEP/gcjob_37pl_cloud.sh 4 | -------------------------------------------------------------------------------- /NCEP/cronjob_hera.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash --login 2 | 3 | echo "Job 1 is running" 4 | sh /scratch1/NCEPDEV/nems/AIML/graphcast/NCEP/gc_prepdata_hera.sh 5 | sleep 60 # Simulating some work 6 | echo "Job 1 completed" 7 | 8 | echo "Job 2 is running" 9 | job2_id=$(sbatch /scratch1/NCEPDEV/nems/AIML/graphcast/NCEP/gc_runfcst_hera.sh | awk '{print $4}') 10 | 11 | # Wait for job 2 to complete 12 | while squeue -j $job2_id &>/dev/null; do 13 | sleep 5 # Adjust the polling interval as needed 14 | done 15 | sleep 5 # Simulating some work 16 | echo "Job 2 completed" 17 | 18 | echo "Job 3 is running" 19 | sh /scratch1/NCEPDEV/nems/AIML/graphcast/NCEP/gc_datadissm_hera.sh 20 | sleep 5 # Simulating some work 21 | echo "Job 3 completed" 22 | -------------------------------------------------------------------------------- /NCEP/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # # 3 | # 4 | # # You can set these variables from the command line, and also 5 | # # from the environment for the first two. 6 | # SPHINXOPTS ?= 7 | # SPHINXBUILD ?= sphinx-build 8 | # SOURCEDIR = source 9 | # BUILDDIR = build 10 | # 11 | # # Put it first so that "make" without argument is like "make help". 12 | # help: 13 | # @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | # 15 | # .PHONY: help Makefile 16 | # 17 | # # Catch-all target: route all unknown targets to Sphinx using the new 18 | # # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | # %: Makefile 20 | # @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /NCEP/docs/README.md: -------------------------------------------------------------------------------- 1 | ## GraphCast Global Forecast System (GraphCastGFS) 2 | -------------------------------------------------------------------------------- /NCEP/docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /NCEP/docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinxcontrib-bibtex 2 | sphinx_rtd_theme 3 | docutils==0.16 4 | -------------------------------------------------------------------------------- /NCEP/docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | import os 16 | import sys 17 | sys.path.insert(0, os.path.abspath('.')) 18 | 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = 'GraphCast GFS' 23 | copyright = '' 24 | author = ' ' 25 | 26 | # The short X.Y version 27 | version = '' 28 | # The full version, including alpha/beta/rc tags 29 | release = '' 30 | 31 | numfig = True 32 | 33 | 34 | # -- General configuration --------------------------------------------------- 35 | 36 | # If your documentation needs a minimal Sphinx version, state it here. 37 | # 38 | # needs_sphinx = '1.0' 39 | 40 | # Add any Sphinx extension module names here, as strings. They can be 41 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 42 | # ones. 
43 | extensions = [ 44 | 'sphinx_rtd_theme', 45 | 'sphinx.ext.autodoc', 46 | 'sphinx.ext.doctest', 47 | 'sphinx.ext.intersphinx', 48 | 'sphinx.ext.todo', 49 | 'sphinx.ext.coverage', 50 | 'sphinx.ext.mathjax', 51 | 'sphinx.ext.ifconfig', 52 | 'sphinx.ext.viewcode', 53 | 'sphinx.ext.githubpages', 54 | 'sphinx.ext.napoleon', 55 | ] 56 | 57 | # Add any paths that contain templates here, relative to this directory. 58 | #templates_path = ['_templates'] 59 | 60 | # The suffix(es) of source filenames. 61 | # You can specify multiple suffix as a list of string: 62 | # 63 | # source_suffix = ['.rst', '.md'] 64 | source_suffix = '.rst' 65 | 66 | # The master toctree document. 67 | master_doc = 'index' 68 | 69 | # The language for content autogenerated by Sphinx. Refer to documentation 70 | # for a list of supported languages. 71 | # 72 | # This is also used if you do content translation via gettext catalogs. 73 | # Usually you set "language" from the command line for these cases. 74 | language = 'en' 75 | 76 | # List of patterns, relative to source directory, that match files and 77 | # directories to ignore when looking for source files. 78 | # This pattern also affects html_static_path and html_extra_path. 79 | exclude_patterns = [] 80 | 81 | # The name of the Pygments (syntax highlighting) style to use. 82 | pygments_style = 'sphinx' 83 | 84 | 85 | # -- Options for HTML output ------------------------------------------------- 86 | 87 | # The theme to use for HTML and HTML Help pages. See the documentation for 88 | # a list of builtin themes. 89 | # 90 | #html_theme = 'classic' 91 | html_theme = 'sphinx_rtd_theme' 92 | html_theme_path = ["_themes", ] 93 | 94 | # Theme options are theme-specific and customize the look and feel of a theme 95 | # further. For a list of options available for each theme, see the 96 | # documentation. 97 | # 98 | # html_theme_options = {} 99 | html_theme_options = {"body_max_width": "none"} 100 | 101 | # Add any paths that contain custom static files (such as style sheets) here, 102 | # relative to this directory. They are copied after the builtin static files, 103 | # so a file named "default.css" will overwrite the builtin "default.css". 104 | #html_static_path = ['_static'] 105 | html_context = {} 106 | 107 | def setup(app): 108 | app.add_css_file('custom.css') # may also be an URL 109 | app.add_css_file('theme_overrides.css') # may also be an URL 110 | 111 | # Custom sidebar templates, must be a dictionary that maps document names 112 | # to template names. 113 | # 114 | # The default sidebars (for documents that don't match any pattern) are 115 | # defined by theme itself. Builtin themes are using these templates by 116 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 117 | # 'searchbox.html']``. 118 | # 119 | # html_sidebars = {} 120 | 121 | 122 | 123 | # -- Options for LaTeX output ------------------------------------------------ 124 | 125 | latex_engine = 'pdflatex' 126 | latex_elements = { 127 | # The paper size ('letterpaper' or 'a4paper'). 128 | # 129 | # 'papersize': 'letterpaper', 130 | 131 | # The font size ('10pt', '11pt' or '12pt'). 132 | # 133 | # 'pointsize': '10pt', 134 | 135 | # Additional stuff for the LaTeX preamble. 136 | # 137 | # 'preamble': '', 138 | 139 | # Latex figure (float) alignment 140 | # 141 | # 'figure_align': 'htbp', 142 | # 'maketitle': r'\newcommand\sphinxbackoftitlepage{For referencing this document please use: \newline \break Schramm, J., L. Bernardet, L. Carson, G. Firl, D. Heinzeller, L. Pan, and M. 
Zhang, 2020. UFS Weather Model User's Guide Release v1.0.0. Npp. Available at https://dtcenter.org.}\sphinxmaketitle' 143 | } 144 | 145 | # -- Extension configuration ------------------------------------------------- 146 | 147 | # -- Options for intersphinx extension --------------------------------------- 148 | 149 | # Example configuration for intersphinx: refer to the Python standard library. 150 | intersphinx_mapping = {'landda': ('https://land-da-workflow.readthedocs.io/en/latest/', None), 151 | } 152 | 153 | # -- Options for todo extension ---------------------------------------------- 154 | 155 | # If true, `todo` and `todoList` produce output, else they produce nothing. 156 | todo_include_todos = True 157 | -------------------------------------------------------------------------------- /NCEP/docs/source/index.rst: -------------------------------------------------------------------------------- 1 | GraphCast with GFS input 2 | ======================================================= 3 | .. toctree:: 4 | :numbered: 5 | :maxdepth: 3 6 | 7 | introduction 8 | installation 9 | inputs 10 | run 11 | outputs 12 | -------------------------------------------------------------------------------- /NCEP/docs/source/inputs.rst: -------------------------------------------------------------------------------- 1 | ############################################# 2 | Preparing inputs from GDAS products 3 | ############################################# 4 | 5 | GraphCast takes two states of the weather (the current state and the state 6 hours earlier) as its initial conditions. We will create a netCDF file containing these two states from GDAS 0.25 degree analysis data. This can be done with the script NCEP/gdas_utility.py. The script downloads the GDAS data, which are in GRIB2 format, from either the NOAA S3 bucket or the NOAA NOMADS server. It then extracts the required variables from the GRIB2 files and saves the data as netCDF files. Run the script using:: 6 | 7 | python gdas_utility.py startdate enddate --level 13 --source s3 --output /path/to/output --download /path/to/download --method wgrib2 --keep no 8 | 9 | **Arguments** 10 | 11 | Required: 12 | 13 | startdate and enddate: string, yyyymmddhh 14 | 15 | Optional: 16 | 17 | *-l* or *--level*: 13 or 37, the number of pressure levels (default: 13) 18 | 19 | *-s* or *--source*: s3 or nomads, the source from which to download GDAS data (default: s3) 20 | 21 | *-m* or *--method*: wgrib2 or pygrib, the method used to extract the required variables and create the netCDF file (default: wgrib2) 22 | 23 | *-o* or *--output*: /path/to/output, where to save the output netCDF files (default: current directory) 24 | 25 | *-d* or *--download*: /path/to/download, where to save the downloaded GRIB2 files (default: current directory) 26 | 27 | *-k* or *--keep*: yes or no, whether to keep the downloaded data after processing (default: no) 28 | -------------------------------------------------------------------------------- /NCEP/docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | ############################################ 2 | Installation 3 | ############################################ 4 | 5 | The recommended way to set up the environment for installing GraphCast is to use `conda `_.
6 | With conda, you can create an environment and install the required libraries using the `environment.yml` file provided in the NCEP folder:: 7 | 8 | conda env create -f environment.yml -n your-env-name 9 | 10 | Activate the environment:: 11 | 12 | conda activate your-env-name 13 | 14 | Get the EMC/graphcast source code:: 15 | 16 | git clone https://github.com/NOAA-EMC/graphcast.git 17 | -------------------------------------------------------------------------------- /NCEP/docs/source/introduction.rst: -------------------------------------------------------------------------------- 1 | ###################### 2 | Introduction 3 | ###################### 4 | 5 | **The GraphCast Global Forecast System** is a weather forecast model built upon Google DeepMind's pre-trained GraphCast 6 | Machine Learning Weather Prediction (MLWP) model. It is set up by the National Centers for Environmental Prediction (NCEP) 7 | to produce medium-range global forecasts. The model runs in two operational modes with different vertical resolutions: 13 and 8 | 37 pressure levels. The horizontal resolution is a 0.25 degree latitude-longitude grid (about 28 km). The model runs 4 9 | times a day at the 00Z, 06Z, 12Z, and 18Z cycles. Major surface and atmospheric fields, including temperature, wind components, 10 | geopotential height, specific humidity, and vertical velocity, are available. The products are 6-hourly forecasts up to 10 days. 11 | 12 | Google DeepMind's GraphCast model is implemented as a message-passing graph neural network (GNN) architecture with an 13 | "encoder-processor-decoder" configuration. It uses an icosahedral mesh with multiscale edges and has around 37 million parameters. 14 | The model is pre-trained with ECMWF's ERA5 reanalysis data. The GraphCastGFS model takes two model states as initial conditions 15 | (the current and 6-hr-earlier states) from NCEP 0.25 degree GDAS analysis data. 16 | -------------------------------------------------------------------------------- /NCEP/docs/source/outputs.rst: -------------------------------------------------------------------------------- 1 | ###################### 2 | Product 3 | ###################### 4 | 5 | The GraphCastGFS model runs 4 times a day at the 00Z, 06Z, 12Z, and 18Z cycles. The horizontal resolution is a 0.25 degree lat-lon grid. 6 | Forecasts are produced at two vertical resolutions: 13 and 37 pressure levels. 7 | 8 | * The 13 pressure levels include: 9 | 10 | 50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, and 1000 hPa. 11 | 12 | * The 37 pressure levels include: 13 | 14 | 1, 2, 3, 5, 7, 10, 20, 30, 50, 70, 100, 125, 150, 175, 200, 225, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750, 800, 825, 850, 875, 900, 925, 950, 975, and 1000 hPa. 15 | 16 | The model output fields are: 17 | 18 | * 3D fields on pressure levels: 19 | 20 | * temperature 21 | 22 | * U and V components of wind 23 | 24 | * geopotential height 25 | 26 | * specific humidity 27 | 28 | * vertical velocity 29 | 30 | * 2D surface fields: 31 | 32 | * 10-m U and V components of wind 33 | 34 | * 2-m temperature 35 | 36 | * mean sea-level pressure 37 | 38 | * 6-hourly total precipitation 39 | 40 | The near-real-time forecast outputs, along with the inputs, are available on `AWS `_.
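For example, the files for a given cycle can be listed with a few lines of Python using boto3 (a minimal sketch, not part of the repository: it assumes anonymous access to the public bucket ``noaa-nws-graphcastgfs-pds`` used by the upload scripts, and the cycle date below is only an illustration; the per-cycle directory layout is described next)::

    import boto3
    from botocore import UNSIGNED
    from botocore.config import Config

    # Anonymous (unsigned) access is assumed for the public bucket.
    s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))

    # Hypothetical cycle: 00Z on 1 January 2024, 13-pressure-level forecasts.
    prefix = "graphcastgfs.20240101/00/forecasts_13_levels/"
    resp = s3.list_objects_v2(Bucket="noaa-nws-graphcastgfs-pds", Prefix=prefix)
    for obj in resp.get("Contents", []):
        print(obj["Key"])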
41 | 42 | For each cycle, the dataset contains the input files fed into GraphCast, found in the directory: 43 | 44 | graphcastgfs.yyyymmdd/hh/input 45 | 46 | and the 10-day forecast results for that cycle, found in the following directories: 47 | 48 | graphcastgfs.yyyymmdd/hh/forecasts_13_levels 49 | 50 | graphcastgfs.yyyymmdd/hh/forecasts_37_levels 51 | -------------------------------------------------------------------------------- /NCEP/docs/source/run.rst: -------------------------------------------------------------------------------- 1 | ###################### 2 | Run GraphCastGFS 3 | ###################### 4 | In order to run GraphCast in inference mode, you will also need the model weights and normalization statistics, 5 | which are available on the `Google Cloud Bucket `_. 6 | Once you have the input netCDF file, model weights, and statistics data, you can run the GraphCast model with a lead time 7 | (e.g., a lead time of 10 days corresponds to a forecast_length of 40) using:: 8 | 9 | python run_graphcast.py --input /input/filename/with/path --output /path/to/output --weights /path/to/weights --length forecast_length 10 | 11 | **Arguments** 12 | 13 | Required: 14 | 15 | *-i* or *--input*: /input/filename/with/path 16 | 17 | *-o* or *--output*: /path/to/output 18 | 19 | *-w* or *--weights*: /path/to/weights/and/stats 20 | 21 | *-l* or *--length*: integer, the number of forecast time steps (6-hourly) 22 | 23 | Optional: 24 | 25 | *-p* or *--pressure*: 13 or 37, number of pressure levels (default: 13) 26 | 27 | *-u* or *--upload*: yes or no, upload input and output files to the NOAA S3 bucket (default: no) 28 | 29 | *-k* or *--keep*: yes or no, whether to keep input and output files after uploading (default: no) 30 | -------------------------------------------------------------------------------- /NCEP/environment.yml: -------------------------------------------------------------------------------- 1 | name: graphcast 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python>=3.10 6 | - dm-tree 7 | - netcdf4 8 | - xarray 9 | - cartopy 10 | - pygrib 11 | - eccodes 12 | - iris 13 | - iris-grib 14 | - jupyterlab 15 | - pip 16 | - pip: 17 | - https://github.com/deepmind/graphcast/archive/master.zip 18 | -------------------------------------------------------------------------------- /NCEP/gc_datadissm_hera.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # load necessary modules 4 | module use /scratch1/NCEPDEV/nems/role.epic/spack-stack/spack-stack-1.6.0/envs/unified-env-rocky8/install/modulefiles/Core 5 | module load stack-intel 6 | module load wgrib2 7 | module load awscli-v2 8 | module list 9 | 10 | 11 | 12 | # Get the UTC hour and calculate the time in the format yyyymmddhh 13 | current_hour=$(date -u +%H) 14 | current_hour=$((10#$current_hour)) 15 | 16 | if (( $current_hour >= 0 && $current_hour < 6 )); then 17 | datetime=$(date -u -d 'today 00:00') 18 | elif (( $current_hour >= 6 && $current_hour < 12 )); then 19 | datetime=$(date -u -d 'today 06:00') 20 | elif (( $current_hour >= 12 && $current_hour < 18 )); then 21 | datetime=$(date -u -d 'today 12:00') 22 | else 23 | datetime=$(date -u -d 'today 18:00') 24 | fi 25 | 26 | # Calculate time 6 hours before 27 | #curr_datetime=$(date -u -d "$time" +'%Y%m%d%H') 28 | curr_datetime=$( date -d "$datetime 12 hour ago" "+%Y%m%d%H" ) 29 | prev_datetime=$( date -d "$datetime 18 hour ago" "+%Y%m%d%H" ) 30 | 31 | echo "Current state: $curr_datetime" 32 | echo "6 hours earlier state: $prev_datetime" 33
| 34 | # Activate Conda environment 35 | source /scratch1/NCEPDEV/nems/AIML/miniconda3/etc/profile.d/conda.sh 36 | conda activate mlwp 37 | 38 | cd /scratch1/NCEPDEV/nems/AIML/graphcast/NCEP/ 39 | 40 | forecast_length=40 41 | num_pressure_level=13 42 | 43 | start_time=$(date +%s) 44 | echo "start uploading graphcast forecast to s3 bucket for: $curr_datetime" 45 | # Run another Python script 46 | python3 upload_to_s3bucket.py -d "$curr_datetime" -l "$num_pressure_level" 47 | 48 | end_time=$(date +%s) # Record the end time in seconds since the epoch 49 | 50 | # Calculate and print the execution time 51 | execution_time=$((end_time - start_time)) 52 | echo "Execution time for uploading to s3 bucket: $execution_time seconds" 53 | -------------------------------------------------------------------------------- /NCEP/gc_prepdata_hera.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # load necessary modules 5 | module use /scratch1/NCEPDEV/nems/role.epic/spack-stack/spack-stack-1.6.0/envs/unified-env-rocky8/install/modulefiles/Core 6 | module load stack-intel 7 | module load wgrib2 8 | module load awscli-v2 9 | module list 10 | 11 | 12 | # Get the UTC hour and calculate the time in the format yyyymmddhh 13 | current_hour=$(date -u +%H) 14 | current_hour=$((10#$current_hour)) 15 | 16 | if (( $current_hour >= 0 && $current_hour < 6 )); then 17 | datetime=$(date -u -d 'today 00:00') 18 | elif (( $current_hour >= 6 && $current_hour < 12 )); then 19 | datetime=$(date -u -d 'today 06:00') 20 | elif (( $current_hour >= 12 && $current_hour < 18 )); then 21 | datetime=$(date -u -d 'today 12:00') 22 | else 23 | datetime=$(date -u -d 'today 18:00') 24 | fi 25 | 26 | # Calculate time 6 hours before 27 | #curr_datetime=$(date -u -d "$time" +'%Y%m%d%H') 28 | curr_datetime=$( date -d "$datetime 12 hour ago" "+%Y%m%d%H" ) 29 | prev_datetime=$( date -d "$datetime 18 hour ago" "+%Y%m%d%H" ) 30 | 31 | echo "Current state: $curr_datetime" 32 | echo "6 hours earlier state: $prev_datetime" 33 | 34 | 35 | # Activate Conda environment 36 | source /scratch1/NCEPDEV/nems/AIML/miniconda3/etc/profile.d/conda.sh 37 | conda activate mlwp 38 | 39 | cd /scratch1/NCEPDEV/nems/AIML/graphcast/NCEP/ 40 | 41 | num_pressure_levels=13 42 | 43 | start_time=$(date +%s) 44 | echo "start runing gdas utility to generate graphcast inputs for: $curr_datetime" 45 | # Run the Python script gdas.py with the calculated times 46 | python3 gdas_utility.py "$prev_datetime" "$curr_datetime" -l "$num_pressure_levels" 47 | 48 | end_time=$(date +%s) # Record the end time in seconds since the epoch 49 | 50 | # Calculate and print the execution time 51 | execution_time=$((end_time - start_time)) 52 | echo "Execution time for gdas_utility.py: $execution_time seconds" 53 | -------------------------------------------------------------------------------- /NCEP/gc_runfcst_hera.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --nodes=1 3 | #SBATCH --account=nems 4 | #SBATCH --cpus-per-task=40 5 | #SBATCH --time=4:00:00 6 | #SBATCH --job-name=graphcast 7 | #SBATCH --output=gc_output.txt 8 | #SBATCH --error=gc_error.txt 9 | #SBATCH --partition=hera 10 | 11 | 12 | # load necessary modules 13 | module use /scratch1/NCEPDEV/nems/role.epic/spack-stack/spack-stack-1.6.0/envs/unified-env-rocky8/install/modulefiles/Core 14 | module load stack-intel 15 | module load wgrib2 16 | module load awscli-v2 17 | module list 18 | 19 | 20 | # Get the UTC 
hour and calculate the time in the format yyyymmddhh 21 | current_hour=$(date -u +%H) 22 | current_hour=$((10#$current_hour)) 23 | 24 | if (( $current_hour >= 0 && $current_hour < 6 )); then 25 | datetime=$(date -u -d 'today 00:00') 26 | elif (( $current_hour >= 6 && $current_hour < 12 )); then 27 | datetime=$(date -u -d 'today 06:00') 28 | elif (( $current_hour >= 12 && $current_hour < 18 )); then 29 | datetime=$(date -u -d 'today 12:00') 30 | else 31 | datetime=$(date -u -d 'today 18:00') 32 | fi 33 | 34 | # Calculate time 6 hours before 35 | #curr_datetime=$(date -u -d "$time" +'%Y%m%d%H') 36 | curr_datetime=$( date -d "$datetime 12 hour ago" "+%Y%m%d%H" ) 37 | prev_datetime=$( date -d "$datetime 18 hour ago" "+%Y%m%d%H" ) 38 | 39 | echo "Current state: $curr_datetime" 40 | echo "6 hours earlier state: $prev_datetime" 41 | 42 | forecast_length=40 43 | echo "forecast length: $forecast_length" 44 | 45 | num_pressure_levels=13 46 | echo "number of pressure levels: $num_pressure_levels" 47 | 48 | # Activate Conda environment 49 | source /scratch1/NCEPDEV/nems/AIML/miniconda3/etc/profile.d/conda.sh 50 | conda activate mlwp 51 | 52 | cd /scratch1/NCEPDEV/nems/AIML/graphcast/NCEP/ 53 | 54 | start_time=$(date +%s) 55 | echo "start runing graphcast to get real time 10-days forecasts for: $curr_datetime" 56 | # Run another Python script 57 | python3 run_graphcast.py -i source-gdas_date-"$curr_datetime"_res-0.25_levels-"$num_pressure_levels"_steps-2.nc -w /scratch1/NCEPDEV/nems/AIML/gc_weights -l "$forecast_length" -p "$num_pressure_levels" 58 | 59 | end_time=$(date +%s) # Record the end time in seconds since the epoch 60 | 61 | # Calculate and print the execution time 62 | execution_time=$((end_time - start_time)) 63 | echo "Execution time for graphcast: $execution_time seconds" 64 | -------------------------------------------------------------------------------- /NCEP/gcjob_13pl_cloud.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash --login 2 | #SBATCH --nodes=1 3 | #SBATCH --cpus-per-task=30 # Use all available CPU cores 4 | #SBATCH --time=4:00:00 # Adjust this to your estimated run time 5 | #SBATCH --job-name=graphcast 6 | #SBATCH --output=gc_13pl_output.txt 7 | #SBATCH --error=gc_13pl_error.txt 8 | #SBATCH --partition=compute 9 | 10 | # load module lib 11 | # source /etc/profile.d/modules.sh 12 | 13 | # load necessary modules 14 | module use /contrib/spack-stack/envs/ufswm/install/modulefiles/Core/ 15 | module load stack-intel 16 | module load wgrib2 17 | module list 18 | 19 | # Get the UTC hour and calculate the time in the format yyyymmddhh 20 | current_hour=$(date -u +%H) 21 | current_hour=$((10#$current_hour)) 22 | 23 | if (( $current_hour >= 0 && $current_hour < 6 )); then 24 | datetime=$(date -u -d 'today 00:00') 25 | elif (( $current_hour >= 6 && $current_hour < 12 )); then 26 | datetime=$(date -u -d 'today 06:00') 27 | elif (( $current_hour >= 12 && $current_hour < 18 )); then 28 | datetime=$(date -u -d 'today 12:00') 29 | else 30 | datetime=$(date -u -d 'today 18:00') 31 | fi 32 | 33 | # Calculate time 6 hours before 34 | #curr_datetime=$(date -u -d "$time" +'%Y%m%d%H') 35 | curr_datetime=$( date -d "$datetime 12 hour ago" "+%Y%m%d%H" ) 36 | prev_datetime=$( date -d "$datetime 18 hour ago" "+%Y%m%d%H" ) 37 | 38 | echo "Current state: $curr_datetime" 39 | echo "6 hours earlier state: $prev_datetime" 40 | 41 | forecast_length=64 42 | echo "forecast length: $forecast_length" 43 | 44 | num_pressure_levels=13 45 | echo "number of 
pressure levels: $num_pressure_levels" 46 | 47 | # Set Miniconda path 48 | #export PATH="/contrib/Sadegh.Tabas/miniconda3/bin:$PATH" 49 | 50 | # Activate Conda environment 51 | source /contrib/Sadegh.Tabas/miniconda3/etc/profile.d/conda.sh 52 | conda activate mlwp 53 | 54 | # going to the model directory 55 | cd /contrib/Sadegh.Tabas/operational/graphcast/NCEP/ 56 | 57 | start_time=$(date +%s) 58 | echo "start runing gdas utility to generate graphcast inputs for: $curr_datetime" 59 | # Run the Python script gdas.py with the calculated times 60 | python3 gdas_utility.py "$prev_datetime" "$curr_datetime" -l "$num_pressure_levels" 61 | 62 | end_time=$(date +%s) # Record the end time in seconds since the epoch 63 | 64 | # Calculate and print the execution time 65 | execution_time=$((end_time - start_time)) 66 | echo "Execution time for gdas_utility.py: $execution_time seconds" 67 | 68 | start_time=$(date +%s) 69 | echo "start runing graphcast to get real time 10-days forecasts for: $curr_datetime" 70 | # Run another Python script 71 | python3 run_graphcast.py -i source-gdas_date-"$curr_datetime"_res-0.25_levels-"$num_pressure_levels"_steps-2.nc -w /contrib/graphcast/NCEP -l "$forecast_length" -p "$num_pressure_levels" -u yes -k no 72 | 73 | end_time=$(date +%s) # Record the end time in seconds since the epoch 74 | 75 | # Calculate and print the execution time 76 | execution_time=$((end_time - start_time)) 77 | echo "Execution time for graphcast: $execution_time seconds" 78 | -------------------------------------------------------------------------------- /NCEP/gcjob_37pl_cloud.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash --login 2 | #SBATCH --nodes=1 3 | #SBATCH --cpus-per-task=30 # Use all available CPU cores 4 | #SBATCH --time=4:00:00 # Adjust this to your estimated run time 5 | #SBATCH --job-name=graphcast 6 | #SBATCH --output=gc_37pl_output.txt 7 | #SBATCH --error=gc_37pl_error.txt 8 | #SBATCH --partition=compute 9 | 10 | # load module lib 11 | # source /etc/profile.d/modules.sh 12 | 13 | # load necessary modules 14 | module use /contrib/spack-stack/envs/ufswm/install/modulefiles/Core/ 15 | module load stack-intel 16 | module load wgrib2 17 | module list 18 | 19 | # Get the UTC hour and calculate the time in the format yyyymmddhh 20 | current_hour=$(date -u +%H) 21 | current_hour=$((10#$current_hour)) 22 | 23 | if (( $current_hour >= 0 && $current_hour < 6 )); then 24 | datetime=$(date -u -d 'today 00:00') 25 | elif (( $current_hour >= 6 && $current_hour < 12 )); then 26 | datetime=$(date -u -d 'today 06:00') 27 | elif (( $current_hour >= 12 && $current_hour < 18 )); then 28 | datetime=$(date -u -d 'today 12:00') 29 | else 30 | datetime=$(date -u -d 'today 18:00') 31 | fi 32 | 33 | # Calculate time 6 hours before 34 | #curr_datetime=$(date -u -d "$time" +'%Y%m%d%H') 35 | curr_datetime=$( date -d "$datetime 12 hour ago" "+%Y%m%d%H" ) 36 | prev_datetime=$( date -d "$datetime 18 hour ago" "+%Y%m%d%H" ) 37 | 38 | echo "Current state: $curr_datetime" 39 | echo "6 hours earlier state: $prev_datetime" 40 | 41 | forecast_length=40 42 | echo "forecast length: $forecast_length" 43 | 44 | num_pressure_levels=37 45 | echo "number of pressure levels: $num_pressure_levels" 46 | 47 | # Set Miniconda path 48 | #export PATH="/contrib/Sadegh.Tabas/miniconda3/bin:$PATH" 49 | 50 | # Activate Conda environment 51 | source /contrib/Sadegh.Tabas/miniconda3/etc/profile.d/conda.sh 52 | conda activate mlwp 53 | 54 | # going to the model directory 55 | cd 
/contrib/Sadegh.Tabas/operational/graphcast/NCEP/ 56 | 57 | start_time=$(date +%s) 58 | echo "start runing gdas utility to generate graphcast inputs for: $curr_datetime" 59 | # Run the Python script gdas.py with the calculated times 60 | python3 gdas_utility.py "$prev_datetime" "$curr_datetime" -l "$num_pressure_levels" 61 | 62 | end_time=$(date +%s) # Record the end time in seconds since the epoch 63 | 64 | # Calculate and print the execution time 65 | execution_time=$((end_time - start_time)) 66 | echo "Execution time for gdas_utility.py: $execution_time seconds" 67 | 68 | start_time=$(date +%s) 69 | echo "start runing graphcast to get real time 10-days forecasts for: $curr_datetime" 70 | # Run another Python script 71 | python3 run_graphcast.py -i source-gdas_date-"$curr_datetime"_res-0.25_levels-"$num_pressure_levels"_steps-2.nc -w /contrib/graphcast/NCEP -l "$forecast_length" -p "$num_pressure_levels" -u yes -k no 72 | 73 | end_time=$(date +%s) # Record the end time in seconds since the epoch 74 | 75 | # Calculate and print the execution time 76 | execution_time=$((end_time - start_time)) 77 | echo "Execution time for graphcast: $execution_time seconds" 78 | -------------------------------------------------------------------------------- /NCEP/run_graphcast.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Description: Script to call the graphcast model using gdas products 3 | Author: Sadegh Sadeghi Tabas (sadegh.tabas@noaa.gov) 4 | Revision history: 5 | -20231218: Sadegh Tabas, initial code 6 | -20240118: Sadegh Tabas, S3 bucket module to upload data, adding forecast length, Updating batch dataset to account for forecast length 7 | -20240125: Linlin Cui, added a capability to save output as grib2 format 8 | -20240205: Sadegh Tabas, made the code clearer, added 37 pressure level option, updated upload to s3 9 | -20240731: Sadegh Tabas, added grib2 file for F000 10 | -20240815: Sadegh Tabas, update the directory of fine tuned model parameters 11 | ''' 12 | import os 13 | import argparse 14 | from datetime import timedelta 15 | import dataclasses 16 | import functools 17 | import re 18 | import haiku as hk 19 | import jax 20 | import numpy as np 21 | import xarray 22 | import boto3 23 | import pandas as pd 24 | 25 | from graphcast import autoregressive 26 | from graphcast import casting 27 | from graphcast import checkpoint 28 | from graphcast import data_utils 29 | from graphcast import graphcast 30 | from graphcast import normalization 31 | from graphcast import rollout 32 | 33 | from utils.nc2grib import Netcdf2Grib 34 | 35 | class GraphCastModel: 36 | def __init__(self, pretrained_model_path, gdas_data_path, output_dir=None, num_pressure_levels=13, forecast_length=40): 37 | self.pretrained_model_path = pretrained_model_path 38 | self.gdas_data_path = gdas_data_path 39 | self.forecast_length = forecast_length 40 | self.num_pressure_levels = num_pressure_levels 41 | 42 | if output_dir is None: 43 | self.output_dir = os.path.join(os.getcwd(), f"forecasts_{str(self.num_pressure_levels)}_levels") # Use current directory if not specified 44 | else: 45 | self.output_dir = os.path.join(output_dir, f"forecasts_{str(self.num_pressure_levels)}_levels") 46 | os.makedirs(self.output_dir, exist_ok=True) 47 | 48 | self.params = None 49 | self.state = {} 50 | self.model_config = None 51 | self.task_config = None 52 | self.diffs_stddev_by_level = None 53 | self.mean_by_level = None 54 | self.stddev_by_level = None 55 | self.current_batch = None 56 | 
self.inputs = None 57 | self.targets = None 58 | self.forcings = None 59 | self.s3_bucket_name = "noaa-nws-graphcastgfs-pds" 60 | self.dates = None 61 | 62 | 63 | def load_pretrained_model(self): 64 | """Load pre-trained GraphCast model.""" 65 | if self.num_pressure_levels==13: 66 | model_weights_path = f"{self.pretrained_model_path}/params/GCGFSv2_finetuned - GDAS - ERA5 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz" 67 | else: 68 | model_weights_path = f"{self.pretrained_model_path}/params/GraphCast - ERA5 1979-2017 - resolution 0.25 - pressure levels 37 - mesh 2to6 - precipitation input and output.npz" 69 | 70 | with open(model_weights_path, "rb") as f: 71 | ckpt = checkpoint.load(f, graphcast.CheckPoint) 72 | self.params = ckpt.params 73 | self.state = {} 74 | self.model_config = ckpt.model_config 75 | self.task_config = ckpt.task_config 76 | 77 | def load_gdas_data(self): 78 | """Load GDAS data.""" 79 | #with open(gdas_data_path, "rb") as f: 80 | # self.current_batch = xarray.load_dataset(f).compute() 81 | self.current_batch = xarray.load_dataset(self.gdas_data_path).compute() 82 | self.dates = pd.to_datetime(self.current_batch.datetime.values) 83 | 84 | if (self.forecast_length + 2) > len(self.current_batch['time']): 85 | print('Updating batch dataset to account for forecast length') 86 | 87 | diff = int(self.forecast_length + 2 - len(self.current_batch['time'])) 88 | ds = self.current_batch 89 | 90 | # time and datetime update 91 | curr_time_range = ds['time'].values.astype('timedelta64[ns]') 92 | new_time_range = (np.arange(len(curr_time_range) + diff) * np.timedelta64(6, 'h')).astype('timedelta64[ns]') 93 | ds = ds.reindex(time = new_time_range) 94 | curr_datetime_range = ds['datetime'][0].values.astype('datetime64[ns]') 95 | new_datetime_range = curr_datetime_range[0] + np.arange(len(curr_time_range) + diff) * np.timedelta64(6, 'h') 96 | ds['datetime'][0]= new_datetime_range 97 | 98 | self.current_batch = ds 99 | print('batch dataset updated') 100 | 101 | 102 | def extract_inputs_targets_forcings(self): 103 | """Extract inputs, targets, and forcings from the loaded data.""" 104 | self.inputs, self.targets, self.forcings = data_utils.extract_inputs_targets_forcings( 105 | self.current_batch, target_lead_times=slice("6h", f"{self.forecast_length*6}h"), **dataclasses.asdict(self.task_config) 106 | ) 107 | 108 | def load_normalization_stats(self): 109 | """Load normalization stats.""" 110 | 111 | diffs_stddev_path = f"{self.pretrained_model_path}/stats/diffs_stddev_by_level.nc" 112 | mean_path = f"{self.pretrained_model_path}/stats/mean_by_level.nc" 113 | stddev_path = f"{self.pretrained_model_path}/stats/stddev_by_level.nc" 114 | 115 | with open(diffs_stddev_path, "rb") as f: 116 | self.diffs_stddev_by_level = xarray.load_dataset(f).compute() 117 | with open(mean_path, "rb") as f: 118 | self.mean_by_level = xarray.load_dataset(f).compute() 119 | with open(stddev_path, "rb") as f: 120 | self.stddev_by_level = xarray.load_dataset(f).compute() 121 | 122 | # Jax doesn't seem to like passing configs as args through the jit. Passing it 123 | # in via partial (instead of capture by closure) forces jax to invalidate the 124 | # jit cache if you change configs. 
125 | def _with_configs(self, fn): 126 | return functools.partial(fn, model_config=self.model_config, task_config=self.task_config,) 127 | 128 | # Always pass params and state, so the usage below are simpler 129 | def _with_params(self, fn): 130 | return functools.partial(fn, params=self.params, state=self.state) 131 | 132 | # Deepmind models aren't stateful, so the state is always empty, so just return the 133 | # predictions. This is requiredy by the rollout code, and generally simpler. 134 | @staticmethod 135 | def _drop_state(fn): 136 | return lambda **kw: fn(**kw)[0] 137 | 138 | def load_model(self): 139 | def construct_wrapped_graphcast(model_config, task_config): 140 | """Constructs and wraps the GraphCast Predictor.""" 141 | # Deeper one-step predictor. 142 | predictor = graphcast.GraphCast(model_config, task_config) 143 | 144 | # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to 145 | # from/to float32 to/from BFloat16. 146 | predictor = casting.Bfloat16Cast(predictor) 147 | 148 | # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from 149 | # BFloat16 happens after applying normalization to the inputs/targets. 150 | predictor = normalization.InputsAndResiduals(predictor, diffs_stddev_by_level=self.diffs_stddev_by_level, mean_by_level=self.mean_by_level, stddev_by_level=self.stddev_by_level,) 151 | 152 | # Wraps everything so the one-step model can produce trajectories. 153 | predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True,) 154 | return predictor 155 | 156 | @hk.transform_with_state 157 | def run_forward(model_config, task_config, inputs, targets_template, forcings,): 158 | predictor = construct_wrapped_graphcast(model_config, task_config) 159 | return predictor(inputs, targets_template=targets_template, forcings=forcings,) 160 | 161 | jax.jit(self._with_configs(run_forward.init)) 162 | self.model = self._drop_state(self._with_params(jax.jit(self._with_configs(run_forward.apply)))) 163 | 164 | 165 | def get_predictions(self): 166 | """Run GraphCast and save forecasts to a NetCDF file.""" 167 | 168 | print (f"start running GraphCast for {self.forecast_length} steps --> {self.forecast_length*6} hours.") 169 | self.load_model() 170 | 171 | # output = self.model(self.model ,rng=jax.random.PRNGKey(0), inputs=self.inputs, targets_template=self.targets * np.nan, forcings=self.forcings,) 172 | forecasts = rollout.chunked_prediction(self.model, rng=jax.random.PRNGKey(0), inputs=self.inputs, targets_template=self.targets * np.nan, forcings=self.forcings,) 173 | 174 | # filename = f"forecasts_levels-{self.num_pressure_levels}_steps-{self.forecast_length}.nc" 175 | # output_netcdf = os.path.join(self.output_dir, filename) 176 | 177 | # save forecasts 178 | # forecasts.to_netcdf(output_netcdf) 179 | # print (f"GraphCast run completed successfully, you can find the GraphCast forecasts in the following directory:\n {output_netcdf}") 180 | 181 | self.save_grib2(forecasts) 182 | 183 | def save_grib2(self, forecasts): 184 | converter = Netcdf2Grib() 185 | 186 | # Call and save f000 in grib2 187 | ds = self.current_batch 188 | ds = ds.drop_vars(['geopotential_at_surface','land_sea_mask', 'total_precipitation_6hr']) 189 | for var in ds.data_vars: 190 | if 'long_name' in ds[var].attrs: 191 | del ds[var].attrs['long_name'] 192 | ds = ds.isel(time=slice(1, 2)) 193 | ds['time'] = ds['time'] - pd.Timedelta(hours=6) 194 | 195 | converter.save_grib2(self.dates, ds, self.output_dir) 196 | 197 | # Call and save forecasts in grib2 198 | 
converter.save_grib2(self.dates, forecasts, self.output_dir) 199 | 200 | 201 | def upload_to_s3(self, keep_data): 202 | s3 = boto3.client('s3') 203 | 204 | # Extract date and time information from the input file name 205 | input_file_name = os.path.basename(self.gdas_data_path) 206 | 207 | date_start = input_file_name.find("date-") 208 | 209 | # Check if "date-" is found in the input_file_name 210 | if date_start != -1: 211 | date_start += len("date-") # Move to the end of "date-" 212 | date = input_file_name[date_start:date_start + 8] # Extract 8 characters as the date 213 | 214 | time_start = date_start + 8 # Move to the character after the date 215 | time = input_file_name[time_start:time_start + 2] # Extract 2 characters as the time 216 | 217 | 218 | # Define S3 key paths for input and output files 219 | input_s3_key = f'graphcastgfs.{date}/{time}/input/{self.gdas_data_path}' 220 | 221 | # Upload input file to S3 222 | s3.upload_file(self.gdas_data_path, self.s3_bucket_name, input_s3_key) 223 | 224 | # Upload output files to S3 225 | # Iterate over all files in the local directory and upload each one to S3 226 | s3_prefix = f'graphcastgfs.{date}/{time}/forecasts_{self.num_pressure_levels}_levels' 227 | 228 | for root, dirs, files in os.walk(self.output_dir): 229 | 230 | for file in files: 231 | local_path = os.path.join(root, file) 232 | relative_path = os.path.relpath(local_path, self.output_dir) 233 | s3_path = os.path.join(s3_prefix, relative_path) 234 | 235 | # Upload the file 236 | s3.upload_file(local_path, self.s3_bucket_name, s3_path) 237 | 238 | print("Upload to s3 bucket completed.") 239 | 240 | # Delete local files if keep_data is False 241 | if not keep_data: 242 | # Remove forecast data from the specified directory 243 | print("Removing input and forecast data from the specified directory...") 244 | try: 245 | os.system(f"rm -rf {self.output_dir}") 246 | os.remove(self.gdas_data_path) 247 | print("Local input and output files deleted.") 248 | except Exception as e: 249 | print(f"Error removing input and forecast data: {str(e)}") 250 | 251 | 252 | 253 | if __name__ == "__main__": 254 | parser = argparse.ArgumentParser(description="Run GraphCast model.") 255 | parser.add_argument("-i", "--input", help="input file path (including file name)", required=True) 256 | parser.add_argument("-w", "--weights", help="parent directory of the graphcast params and stats", required=True) 257 | parser.add_argument("-l", "--length", help="length of forecast (6-hourly), an integer number in range [1, 40]", required=True) 258 | parser.add_argument("-o", "--output", help="output directory", default=None) 259 | parser.add_argument("-p", "--pressure", help="number of pressure levels", default=13) 260 | parser.add_argument("-u", "--upload", help="upload input data as well as forecasts to noaa s3 bucket (yes or no)", default = "no") 261 | parser.add_argument("-k", "--keep", help="keep input and output after uploading to noaa s3 bucket (yes or no)", default = "no") 262 | 263 | args = parser.parse_args() 264 | runner = GraphCastModel(args.weights, args.input, args.output, int(args.pressure), int(args.length)) 265 | 266 | runner.load_pretrained_model() 267 | runner.load_gdas_data() 268 | runner.extract_inputs_targets_forcings() 269 | runner.load_normalization_stats() 270 | runner.get_predictions() 271 | 272 | upload_data = args.upload.lower() == "yes" 273 | keep_data = args.keep.lower() == "yes" 274 | 275 | if upload_data: 276 | runner.upload_to_s3(keep_data) 277 | 
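# Example invocation (illustrative only: the date, input file name and weights
# directory below are hypothetical placeholders following the pattern used by
# the cron scripts; substitute your own paths):
#
#   python3 run_graphcast.py \
#       -i source-gdas_date-2024010100_res-0.25_levels-13_steps-2.nc \
#       -w /contrib/graphcast/NCEP \
#       -l 40 -p 13 -u no -k no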
-------------------------------------------------------------------------------- /NCEP/upload_to_s3bucket.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | 
4 | import boto3
5 | 
6 | def upload_to_s3(forecast_datetime, output_path, level, keep):
7 | 
8 |     if output_path is None:
9 |         output_dir = os.path.join(os.getcwd(), f"forecasts_{level}_levels") # Use current directory if not specified
10 |     else:
11 |         output_dir = os.path.join(output_path, f"forecasts_{level}_levels")
12 | 
13 |     s3 = boto3.client('s3')
14 |     s3_bucket_name = "noaa-nws-graphcastgfs-pds"  # same bucket used by run_graphcast.py
15 |     date = forecast_datetime[0:8]
16 |     time = forecast_datetime[8:10]
17 | 
18 |     # Upload output files to S3
19 |     # Iterate over all files in the local directory and upload each one to S3
20 |     s3_prefix = f'graphcastgfs.{date}/{time}/forecasts_{level}_levels'
21 | 
22 |     for root, dirs, files in os.walk(output_dir):
23 |         for file in files:
24 |             local_path = os.path.join(root, file)
25 |             relative_path = os.path.relpath(local_path, output_dir)
26 |             s3_path = os.path.join(s3_prefix, relative_path)
27 | 
28 |             # Upload the file
29 |             s3.upload_file(local_path, s3_bucket_name, s3_path)
30 | 
31 |     print("Upload to s3 bucket completed.")
32 | 
33 |     # Delete local files if keep is False
34 |     if not keep:
35 |         # Remove forecasts data from the specified directory
36 |         print("Removing downloaded grib2 data...")
37 |         try:
38 |             os.system(f"rm -rf {output_dir}")
39 |             print("Downloaded data removed.")
40 |         except Exception as e:
41 |             print(f"Error removing downloaded data: {str(e)}")
42 |         print("Local output files deleted.")
43 | 
44 | 
45 | if __name__ == "__main__":
46 | 
47 |     parser = argparse.ArgumentParser(description="upload forecast output to s3 bucket")
48 |     parser.add_argument("-d", "--datetime", help="forecast datetime (YYYYMMDDHH)", required=True)
49 |     parser.add_argument("-l", "--level", help="number of pressure levels", default=13)
50 |     parser.add_argument("-o", "--output", help="output directory", default=None)
51 | 
52 |     args = parser.parse_args()
53 |     keep = False
54 | 
55 |     upload_to_s3(args.datetime, args.output, str(args.level), keep)
56 | 
-------------------------------------------------------------------------------- /NCEP/utils/nc2grib.py: --------------------------------------------------------------------------------
1 | """ Utility for converting netcdf data to grib2.
2 | 3 | History: 4 | 01/26/2024: Linlin Cui (linlin.cui@noaa.gov), added function save_grib2 5 | 02/05/2024: Sadegh Tabas update the utility to a object-oriented format 6 | 04/25/2024: Sadegh Tabas, generate grib2 index files 7 | 07/03/2024: Sadegh Tabas, sorted grib2 variables 8 | """ 9 | 10 | import os 11 | from datetime import datetime, timedelta 12 | import glob 13 | import subprocess 14 | import cf_units 15 | import iris 16 | import iris_grib 17 | import eccodes 18 | 19 | class Netcdf2Grib: 20 | def __init__(self): 21 | self.ATTR_MAPS = { 22 | '10m_u_component_of_wind': [10, 'x_wind', 'm s**-1'], 23 | '10m_v_component_of_wind': [10, 'y_wind', 'm s**-1'], 24 | 'mean_sea_level_pressure': [0, 'air_pressure_at_sea_level', 'Pa'], 25 | '2m_temperature': [2, 'air_temperature', 'K'], 26 | 'total_precipitation_6hr': [0, 'precipitation_amount', 'kg m**-2'], 27 | 'total_precipitation_cumsum': [0, 'precipitation_amount', 'kg m**-2'], 28 | 'vertical_velocity': [None, 'lagrangian_tendency_of_air_pressure', 'Pa s**-1'], 29 | 'specific_humidity': [None, 'specific_humidity', 'kg kg**-1'], 30 | 'temperature': [None, 'air_temperature', 'K'], 31 | 'geopotential': [None, 'geopotential_height', 'm'], 32 | 'u_component_of_wind': [None, 'x_wind', 'm s**-1'], 33 | 'v_component_of_wind': [None, 'y_wind', 'm s**-1'], 34 | } 35 | 36 | def tweaked_messages(self, cube, time_range): 37 | """ 38 | Adjust GRIB messages based on cube properties. 39 | """ 40 | for cube, grib_message in iris_grib.save_pairs_from_cube(cube): 41 | if cube.standard_name == 'precipitation_amount': 42 | eccodes.codes_set(grib_message, 'stepType', 'accum') 43 | eccodes.codes_set(grib_message, 'stepRange', time_range) 44 | eccodes.codes_set(grib_message, 'discipline', 0) 45 | eccodes.codes_set(grib_message, 'parameterCategory', 1) 46 | eccodes.codes_set(grib_message, 'parameterNumber', 8) 47 | eccodes.codes_set(grib_message, 'typeOfFirstFixedSurface', 1) 48 | eccodes.codes_set(grib_message, 'typeOfStatisticalProcessing', 1) 49 | elif cube.standard_name == 'air_pressure_at_sea_level': 50 | eccodes.codes_set(grib_message, 'discipline', 0) 51 | eccodes.codes_set(grib_message, 'parameterCategory', 3) 52 | eccodes.codes_set(grib_message, 'parameterNumber', 1) 53 | eccodes.codes_set(grib_message, 'typeOfFirstFixedSurface', 101) 54 | yield grib_message 55 | 56 | def save_grib2(self, dates, forecasts, outdir): 57 | """ 58 | Convert netCDF file to GRIB2 format file. 
59 | Args: 60 | dates: array of datetime object, from the source file 61 | forecasts: xarray forecasts dataset 62 | outdir: output directory 63 | 64 | Returns: 65 | No return values, will save to grib2 file 66 | """ 67 | forecasts = forecasts.reindex(lat=list(reversed(forecasts.lat))) 68 | 69 | for var in forecasts.variables: 70 | if 'batch' in forecasts[var].dims: 71 | forecasts[var] = forecasts[var].squeeze(dim='batch') 72 | 73 | # Update units 74 | forecasts['level'] = forecasts['level'] * 100 75 | forecasts['level'].attrs['long_name'] = 'pressure' 76 | forecasts['level'].attrs['units'] = 'Pa' 77 | forecasts['geopotential'] = forecasts['geopotential'] / 9.80665 78 | if 'total_precipitation_6hr' in forecasts: 79 | forecasts['total_precipitation_6hr'] = forecasts['total_precipitation_6hr'] * 1000 80 | forecasts['total_precipitation_cumsum'] = forecasts['total_precipitation_6hr'].cumsum(axis=0) 81 | 82 | filename = os.path.join(outdir, "forecast_to_grib2.nc") 83 | forecasts.to_netcdf(filename) 84 | 85 | # Load cubes from netCDF file 86 | cubes = iris.load(filename) 87 | times = cubes[0].coord('time').points 88 | forecast_starttime = dates[0][1] 89 | cycle = forecast_starttime.hour 90 | print(f'Forecast start time is {forecast_starttime}') 91 | 92 | datevectors = [forecast_starttime + timedelta(hours=int(t)) for t in times] 93 | 94 | time_fmt_str = '00:00:00' 95 | time_unit_str = f"Hours since {forecast_starttime.strftime('%Y-%m-%d %H:00:00')}" 96 | time_coord = cubes[0].coord('time') 97 | new_time_unit = cf_units.Unit(time_unit_str, calendar=cf_units.CALENDAR_STANDARD) 98 | new_time_points = [new_time_unit.date2num(dt) for dt in datevectors] 99 | new_time_coord = iris.coords.DimCoord(new_time_points, standard_name='time', units=new_time_unit) 100 | 101 | for date in datevectors: 102 | print(f"Processing for time {date.strftime('%Y-%m-%d %H:00:00')}") 103 | hrs = int((date - forecast_starttime).total_seconds() // 3600) 104 | outfile = os.path.join(outdir, f'graphcastgfs.t{cycle:02d}z.pgrb2.0p25.f{hrs:03d}') 105 | print(outfile) 106 | 107 | for cube in sorted(cubes, key=lambda cube: cube.name()): 108 | var_name = cube.name() 109 | 110 | # Adjust cube for different variables 111 | time_coord_dim = cube.coord_dims('time') 112 | cube.remove_coord('time') 113 | cube.add_dim_coord(new_time_coord, time_coord_dim) 114 | 115 | hour_6 = iris.Constraint(time=iris.time.PartialDateTime(month=date.month, day=date.day, hour=date.hour)) 116 | cube_slice = cube.extract(hour_6) 117 | cube_slice.coord('latitude').coord_system = iris.coord_systems.GeogCS(4326) 118 | cube_slice.coord('longitude').coord_system = iris.coord_systems.GeogCS(4326) 119 | 120 | if len(cube_slice.data.shape) == 3: 121 | levels = cube_slice.coord('pressure').points 122 | for level in levels: 123 | cube_slice_level = cube_slice.extract(iris.Constraint(pressure=level)) 124 | cube_slice_level.add_aux_coord(iris.coords.DimCoord(hrs, standard_name='forecast_period', units='hours')) 125 | cube_slice_level.standard_name = self.ATTR_MAPS[var_name][1] 126 | cube_slice_level.units = self.ATTR_MAPS[var_name][2] 127 | iris.save(cube_slice_level, outfile, saver='grib2', append=True) 128 | else: 129 | cube_slice.add_aux_coord(iris.coords.DimCoord(hrs, standard_name='forecast_period', units='hours')) 130 | cube_slice.standard_name = self.ATTR_MAPS[var_name][1] 131 | cube_slice.units = self.ATTR_MAPS[var_name][2] 132 | 133 | if var_name not in ['mean_sea_level_pressure', 'total_precipitation_6hr', 'total_precipitation_cumsum']: 134 | 
cube_slice.add_aux_coord(iris.coords.DimCoord(self.ATTR_MAPS[var_name][0], standard_name='height', units='m')) 135 | iris.save(cube_slice, outfile, saver='grib2', append=True) 136 | elif var_name == 'total_precipitation_6hr': 137 | iris_grib.save_messages(self.tweaked_messages(cube_slice, f'{hrs-6}-{hrs}'), outfile, append=True) 138 | elif var_name == 'total_precipitation_cumsum': 139 | iris_grib.save_messages(self.tweaked_messages(cube_slice, f'0-{hrs}'), outfile, append=True) 140 | elif var_name == 'mean_sea_level_pressure': 141 | cube_slice.add_aux_coord(iris.coords.DimCoord(self.ATTR_MAPS[var_name][0], standard_name='altitude', units='m')) 142 | iris_grib.save_messages(self.tweaked_messages(cube_slice, f'{hrs-6}-{hrs}'), outfile, append=True) 143 | 144 | # Use wgrib2 to generate index files 145 | output_idx_file = f"{outfile}.idx" 146 | 147 | # Construct the wgrib2 command 148 | wgrib2_command = ['wgrib2', '-s', outfile] 149 | 150 | try: 151 | # Open the output file for writing 152 | with open(output_idx_file, "w") as f_out: 153 | # Execute the wgrib2 command and redirect stdout to the output file 154 | subprocess.run(wgrib2_command, stdout=f_out, check=True) 155 | 156 | print(f"Index file created successfully: {output_idx_file}") 157 | 158 | except subprocess.CalledProcessError as e: 159 | print(f"Error running wgrib2 command: {e}") 160 | 161 | 162 | # Remove intermediate netCDF file 163 | if os.path.isfile(filename): 164 | print(f'Deleting intermediate nc file {filename}: ') 165 | os.remove(filename) 166 | 167 | # subset grib2 files 168 | def subset_grib2(indir=None): 169 | files = glob.glob(f'{indir}/graphcastgfs.*') 170 | files.sort() 171 | 172 | outdir = os.path.join(indir, 'north_america') 173 | os.makedirs(outdir, exist_ok=True) 174 | 175 | lonMin, lonMax, latMin, latMax = 61.0, 299.0, -37.0, 37.0 176 | for grbfile in files: 177 | outfile = f"{outdir}/{grbfile.split('/')[-1]}" 178 | command = ['wgrib2', grbfile, '-small_grib', f'{lonMin}:{lonMax}', f'{latMin}:{latMax}', outfile] 179 | subprocess.run(command, check=True) 180 | 181 | 182 | # subset_grib2(outdir) 183 | 184 | -------------------------------------------------------------------------------- /docs/GenCast_0p25deg_accelerator_scorecard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NOAA-EMC/graphcast/da5477d2dfe9d6af4befe701d7c6873eb4e24eb0/docs/GenCast_0p25deg_accelerator_scorecard.png -------------------------------------------------------------------------------- /docs/GenCast_0p25deg_attention_implementation_scorecard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NOAA-EMC/graphcast/da5477d2dfe9d6af4befe701d7c6873eb4e24eb0/docs/GenCast_0p25deg_attention_implementation_scorecard.png -------------------------------------------------------------------------------- /docs/GenCast_1p0deg_Mini_ENS_scorecard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NOAA-EMC/graphcast/da5477d2dfe9d6af4befe701d7c6873eb4e24eb0/docs/GenCast_1p0deg_Mini_ENS_scorecard.png -------------------------------------------------------------------------------- /docs/local_runtime_popup_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NOAA-EMC/graphcast/da5477d2dfe9d6af4befe701d7c6873eb4e24eb0/docs/local_runtime_popup_1.png 
-------------------------------------------------------------------------------- /docs/local_runtime_popup_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NOAA-EMC/graphcast/da5477d2dfe9d6af4befe701d7c6873eb4e24eb0/docs/local_runtime_popup_2.png -------------------------------------------------------------------------------- /docs/local_runtime_url.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NOAA-EMC/graphcast/da5477d2dfe9d6af4befe701d7c6873eb4e24eb0/docs/local_runtime_url.png -------------------------------------------------------------------------------- /docs/project.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NOAA-EMC/graphcast/da5477d2dfe9d6af4befe701d7c6873eb4e24eb0/docs/project.png -------------------------------------------------------------------------------- /docs/provision_tpu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NOAA-EMC/graphcast/da5477d2dfe9d6af4befe701d7c6873eb4e24eb0/docs/provision_tpu.png -------------------------------------------------------------------------------- /docs/tpu_types.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NOAA-EMC/graphcast/da5477d2dfe9d6af4befe701d7c6873eb4e24eb0/docs/tpu_types.png -------------------------------------------------------------------------------- /graphcast/casting.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Wrappers that take care of casting.""" 15 | 16 | import contextlib 17 | from typing import Any, Mapping, Tuple 18 | 19 | import chex 20 | from graphcast import predictor_base 21 | import haiku as hk 22 | import jax 23 | import jax.numpy as jnp 24 | import numpy as np 25 | import xarray 26 | 27 | 28 | PyTree = Any 29 | 30 | 31 | class Bfloat16Cast(predictor_base.Predictor): 32 | """Wrapper that casts all inputs to bfloat16 and outputs to targets dtype.""" 33 | 34 | def __init__(self, predictor: predictor_base.Predictor, enabled: bool = True): 35 | """Inits the wrapper. 36 | 37 | Args: 38 | predictor: predictor being wrapped. 39 | enabled: disables the wrapper if False, for simpler hyperparameter scans. 
40 | 41 | """ 42 | self._enabled = enabled 43 | self._predictor = predictor 44 | 45 | def __call__(self, 46 | inputs: xarray.Dataset, 47 | targets_template: xarray.Dataset, 48 | forcings: xarray.Dataset, 49 | **kwargs 50 | ) -> xarray.Dataset: 51 | if not self._enabled: 52 | return self._predictor(inputs, targets_template, forcings, **kwargs) 53 | 54 | with bfloat16_variable_view(): 55 | predictions = self._predictor( 56 | *_all_inputs_to_bfloat16(inputs, targets_template, forcings), 57 | **kwargs,) 58 | 59 | predictions_dtype = infer_floating_dtype(predictions) # pytype: disable=wrong-arg-types 60 | if predictions_dtype != jnp.bfloat16: 61 | raise ValueError(f'Expected bfloat16 output, got {predictions_dtype}') 62 | 63 | targets_dtype = infer_floating_dtype(targets_template) # pytype: disable=wrong-arg-types 64 | return tree_map_cast( 65 | predictions, input_dtype=jnp.bfloat16, output_dtype=targets_dtype) 66 | 67 | def loss(self, 68 | inputs: xarray.Dataset, 69 | targets: xarray.Dataset, 70 | forcings: xarray.Dataset, 71 | **kwargs, 72 | ) -> predictor_base.LossAndDiagnostics: 73 | if not self._enabled: 74 | return self._predictor.loss(inputs, targets, forcings, **kwargs) 75 | 76 | with bfloat16_variable_view(): 77 | loss, scalars = self._predictor.loss( 78 | *_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs) 79 | 80 | if loss.dtype != jnp.bfloat16: 81 | raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}') 82 | 83 | targets_dtype = infer_floating_dtype(targets) # pytype: disable=wrong-arg-types 84 | 85 | # Note that casting back the loss to e.g. float32 should not affect data 86 | # types of the backwards pass, because the first thing the backwards pass 87 | # should do is to go backwards the casting op and cast back to bfloat16 88 | # (and xprofs seem to confirm this). 
89 | return tree_map_cast((loss, scalars), 90 | input_dtype=jnp.bfloat16, output_dtype=targets_dtype) 91 | 92 | def loss_and_predictions( # pytype: disable=signature-mismatch # jax-ndarray 93 | self, 94 | inputs: xarray.Dataset, 95 | targets: xarray.Dataset, 96 | forcings: xarray.Dataset, 97 | **kwargs, 98 | ) -> Tuple[predictor_base.LossAndDiagnostics, 99 | xarray.Dataset]: 100 | if not self._enabled: 101 | return self._predictor.loss_and_predictions(inputs, targets, forcings, # pytype: disable=bad-return-type # jax-ndarray 102 | **kwargs) 103 | 104 | with bfloat16_variable_view(): 105 | (loss, scalars), predictions = self._predictor.loss_and_predictions( 106 | *_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs) 107 | 108 | if loss.dtype != jnp.bfloat16: 109 | raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}') 110 | 111 | predictions_dtype = infer_floating_dtype(predictions) # pytype: disable=wrong-arg-types 112 | if predictions_dtype != jnp.bfloat16: 113 | raise ValueError(f'Expected bfloat16 output, got {predictions_dtype}') 114 | 115 | targets_dtype = infer_floating_dtype(targets) # pytype: disable=wrong-arg-types 116 | return tree_map_cast(((loss, scalars), predictions), 117 | input_dtype=jnp.bfloat16, output_dtype=targets_dtype) 118 | 119 | 120 | def infer_floating_dtype(data_vars: Mapping[str, chex.Array]) -> np.dtype: 121 | """Infers a floating dtype from an input mapping of data.""" 122 | dtypes = { 123 | v.dtype 124 | for k, v in data_vars.items() if jnp.issubdtype(v.dtype, np.floating)} 125 | if len(dtypes) != 1: 126 | dtypes_and_shapes = { 127 | k: (v.dtype, v.shape) 128 | for k, v in data_vars.items() if jnp.issubdtype(v.dtype, np.floating)} 129 | raise ValueError( 130 | f'Did not found exactly one floating dtype {dtypes} in input variables:' 131 | f'{dtypes_and_shapes}') 132 | return list(dtypes)[0] 133 | 134 | 135 | def _all_inputs_to_bfloat16( 136 | inputs: xarray.Dataset, 137 | targets: xarray.Dataset, 138 | forcings: xarray.Dataset, 139 | ) -> Tuple[xarray.Dataset, 140 | xarray.Dataset, 141 | xarray.Dataset]: 142 | return (inputs.astype(jnp.bfloat16), 143 | jax.tree.map(lambda x: x.astype(jnp.bfloat16), targets), 144 | forcings.astype(jnp.bfloat16)) 145 | 146 | 147 | def tree_map_cast(inputs: PyTree, input_dtype: np.dtype, output_dtype: np.dtype, 148 | ) -> PyTree: 149 | def cast_fn(x): 150 | if x.dtype == input_dtype: 151 | return x.astype(output_dtype) 152 | return jax.tree.map(cast_fn, inputs) 153 | 154 | 155 | @contextlib.contextmanager 156 | def bfloat16_variable_view(enabled: bool = True): 157 | """Context for Haiku modules with float32 params, but bfloat16 activations. 158 | 159 | It works as follows: 160 | * Every time a variable is requested to be created/set as np.bfloat16, 161 | it will create an underlying float32 variable, instead. 162 | * Every time a variable a variable is requested as bfloat16, it will check the 163 | variable is of float32 type, and cast the variable to bfloat16. 164 | 165 | Note the gradients are still computed and accumulated as float32, because 166 | the params returned by init are float32, so the gradient function with 167 | respect to the params will already include an implicit casting to float32. 168 | 169 | Args: 170 | enabled: Only enables bfloat16 behavior if True. 
171 | 172 | Yields: 173 | None 174 | """ 175 | 176 | if enabled: 177 | with hk.custom_creator( 178 | _bfloat16_creator, state=True), hk.custom_getter( 179 | _bfloat16_getter, state=True), hk.custom_setter( 180 | _bfloat16_setter): 181 | yield 182 | else: 183 | yield 184 | 185 | 186 | def _bfloat16_creator(next_creator, shape, dtype, init, context): 187 | """Creates float32 variables when bfloat16 is requested.""" 188 | if context.original_dtype == jnp.bfloat16: 189 | dtype = jnp.float32 190 | return next_creator(shape, dtype, init) 191 | 192 | 193 | def _bfloat16_getter(next_getter, value, context): 194 | """Casts float32 to bfloat16 when bfloat16 was originally requested.""" 195 | if context.original_dtype == jnp.bfloat16: 196 | assert value.dtype == jnp.float32 197 | value = value.astype(jnp.bfloat16) 198 | return next_getter(value) 199 | 200 | 201 | def _bfloat16_setter(next_setter, value, context): 202 | """Casts bfloat16 to float32 when bfloat16 was originally set.""" 203 | if context.original_dtype == jnp.bfloat16: 204 | value = value.astype(jnp.float32) 205 | return next_setter(value) 206 | -------------------------------------------------------------------------------- /graphcast/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Serialize and deserialize trees.""" 15 | 16 | import dataclasses 17 | import io 18 | import types 19 | from typing import Any, BinaryIO, Optional, TypeVar 20 | 21 | import numpy as np 22 | 23 | _T = TypeVar("_T") 24 | 25 | 26 | def dump(dest: BinaryIO, value: Any) -> None: 27 | """Dump a tree of dicts/dataclasses to a file object. 28 | 29 | Args: 30 | dest: a file object to write to. 31 | value: A tree of dicts, lists, tuples and dataclasses of numpy arrays and 32 | other basic types. Unions are not supported, other than Optional/None 33 | which is only supported in dataclasses, not in dicts, lists or tuples. 34 | All leaves must be coercible to a numpy array, and recoverable as a single 35 | arg to a type. 36 | """ 37 | buffer = io.BytesIO() # In case the destination doesn't support seeking. 38 | np.savez(buffer, **_flatten(value)) 39 | dest.write(buffer.getvalue()) 40 | 41 | 42 | def load(source: BinaryIO, typ: type[_T]) -> _T: 43 | """Load from a file object and convert it to the specified type. 44 | 45 | Args: 46 | source: a file object to read from. 47 | typ: a type object that acts as a schema for deserialization. It must match 48 | what was serialized. If a type is Any, it will be returned however numpy 49 | serialized it, which is what you want for a tree of numpy arrays. 50 | 51 | Returns: 52 | the deserialized value as the specified type. 
53 | """ 54 | return _convert_types(typ, _unflatten(np.load(source))) 55 | 56 | 57 | _SEP = ":" 58 | 59 | 60 | def _flatten(tree: Any) -> dict[str, Any]: 61 | """Flatten a tree of dicts/dataclasses/lists/tuples to a single dict.""" 62 | if dataclasses.is_dataclass(tree): 63 | # Don't use dataclasses.asdict as it is recursive so skips dropping None. 64 | tree = {f.name: v for f in dataclasses.fields(tree) 65 | if (v := getattr(tree, f.name)) is not None} 66 | elif isinstance(tree, (list, tuple)): 67 | tree = dict(enumerate(tree)) 68 | 69 | assert isinstance(tree, dict) 70 | 71 | flat = {} 72 | for k, v in tree.items(): 73 | k = str(k) 74 | assert _SEP not in k 75 | if dataclasses.is_dataclass(v) or isinstance(v, (dict, list, tuple)): 76 | for a, b in _flatten(v).items(): 77 | flat[f"{k}{_SEP}{a}"] = b 78 | else: 79 | assert v is not None 80 | flat[k] = v 81 | return flat 82 | 83 | 84 | def _unflatten(flat: dict[str, Any]) -> dict[str, Any]: 85 | """Unflatten a dict to a tree of dicts.""" 86 | tree = {} 87 | for flat_key, v in flat.items(): 88 | node = tree 89 | keys = flat_key.split(_SEP) 90 | for k in keys[:-1]: 91 | if k not in node: 92 | node[k] = {} 93 | node = node[k] 94 | node[keys[-1]] = v 95 | return tree 96 | 97 | 98 | def _convert_types(typ: type[_T], value: Any) -> _T: 99 | """Convert some structure into the given type. The structures must match.""" 100 | if typ in (Any, ...): 101 | return value 102 | 103 | if typ in (int, float, str, bool): 104 | return typ(value) 105 | 106 | if typ is np.ndarray: 107 | assert isinstance(value, np.ndarray) 108 | return value 109 | 110 | if dataclasses.is_dataclass(typ): 111 | kwargs = {} 112 | for f in dataclasses.fields(typ): 113 | # Only support Optional for dataclasses, as numpy can't serialize it 114 | # directly (without pickle), and dataclasses are the only case where we 115 | # can know the full set of values and types and therefore know the 116 | # non-existence must mean None. 117 | if isinstance(f.type, (types.UnionType, type(Optional[int]))): 118 | constructors = [t for t in f.type.__args__ if t is not types.NoneType] 119 | if len(constructors) != 1: 120 | raise TypeError( 121 | "Optional works, Union with anything except None doesn't") 122 | if f.name not in value: 123 | kwargs[f.name] = None 124 | continue 125 | constructor = constructors[0] 126 | else: 127 | constructor = f.type 128 | 129 | if f.name in value: 130 | kwargs[f.name] = _convert_types(constructor, value[f.name]) 131 | else: 132 | raise ValueError(f"Missing value: {f.name}") 133 | return typ(**kwargs) 134 | 135 | base_type = getattr(typ, "__origin__", None) 136 | 137 | if base_type is dict: 138 | assert len(typ.__args__) == 2 139 | key_type, value_type = typ.__args__ 140 | return {_convert_types(key_type, k): _convert_types(value_type, v) 141 | for k, v in value.items()} 142 | 143 | if base_type is list: 144 | assert len(typ.__args__) == 1 145 | value_type = typ.__args__[0] 146 | return [_convert_types(value_type, v) 147 | for _, v in sorted(value.items(), key=lambda x: int(x[0]))] 148 | 149 | if base_type is tuple: 150 | if len(typ.__args__) == 2 and typ.__args__[1] == ...: 151 | # An arbitrary length tuple of a single type, eg: tuple[int, ...] 
152 | value_type = typ.__args__[0] 153 | return tuple(_convert_types(value_type, v) 154 | for _, v in sorted(value.items(), key=lambda x: int(x[0]))) 155 | else: 156 | # A fixed length tuple of arbitrary types, eg: tuple[int, str, float] 157 | assert len(typ.__args__) == len(value) 158 | return tuple( 159 | _convert_types(t, v) 160 | for t, (_, v) in zip( 161 | typ.__args__, sorted(value.items(), key=lambda x: int(x[0])))) 162 | 163 | # This is probably unreachable with reasonable serializable inputs. 164 | try: 165 | return typ(value) 166 | except TypeError as e: 167 | raise TypeError( 168 | "_convert_types expects the type argument to be a dataclass defined " 169 | "with types that are valid constructors (eg tuple is fine, Tuple " 170 | "isn't), and accept a numpy array as the sole argument.") from e 171 | -------------------------------------------------------------------------------- /graphcast/checkpoint_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Check that the checkpoint serialization is reversable.""" 15 | 16 | import dataclasses 17 | import io 18 | from typing import Any, Optional, Union 19 | 20 | from absl.testing import absltest 21 | from graphcast import checkpoint 22 | import numpy as np 23 | 24 | 25 | @dataclasses.dataclass 26 | class SubConfig: 27 | a: int 28 | b: str 29 | 30 | 31 | @dataclasses.dataclass 32 | class Config: 33 | bt: bool 34 | bf: bool 35 | i: int 36 | f: float 37 | o1: Optional[int] 38 | o2: Optional[int] 39 | o3: Union[int, None] 40 | o4: Union[int, None] 41 | o5: int | None 42 | o6: int | None 43 | li: list[int] 44 | ls: list[str] 45 | ldc: list[SubConfig] 46 | tf: tuple[float, ...] 47 | ts: tuple[str, ...] 48 | t: tuple[str, int, SubConfig] 49 | tdc: tuple[SubConfig, ...] 
50 | dsi: dict[str, int] 51 | dss: dict[str, str] 52 | dis: dict[int, str] 53 | dsdis: dict[str, dict[int, str]] 54 | dc: SubConfig 55 | dco: Optional[SubConfig] 56 | ddc: dict[str, SubConfig] 57 | 58 | 59 | @dataclasses.dataclass 60 | class Checkpoint: 61 | params: dict[str, Any] 62 | config: Config 63 | 64 | 65 | class DataclassTest(absltest.TestCase): 66 | 67 | def test_serialize_dataclass(self): 68 | ckpt = Checkpoint( 69 | params={ 70 | "layer1": { 71 | "w": np.arange(10).reshape(2, 5), 72 | "b": np.array([2, 6]), 73 | }, 74 | "layer2": { 75 | "w": np.arange(8).reshape(2, 4), 76 | "b": np.array([2, 6]), 77 | }, 78 | "blah": np.array([3, 9]), 79 | }, 80 | config=Config( 81 | bt=True, 82 | bf=False, 83 | i=42, 84 | f=3.14, 85 | o1=1, 86 | o2=None, 87 | o3=2, 88 | o4=None, 89 | o5=3, 90 | o6=None, 91 | li=[12, 9, 7, 15, 16, 14, 1, 6, 11, 4, 10, 5, 13, 3, 8, 2], 92 | ls=list("qhjfdxtpzgemryoikwvblcaus"), 93 | ldc=[SubConfig(1, "hello"), SubConfig(2, "world")], 94 | tf=(1, 4, 2, 10, 5, 9, 13, 16, 15, 8, 12, 7, 11, 14, 3, 6), 95 | ts=("hello", "world"), 96 | t=("foo", 42, SubConfig(1, "bar")), 97 | tdc=(SubConfig(1, "hello"), SubConfig(2, "world")), 98 | dsi={"a": 1, "b": 2, "c": 3}, 99 | dss={"d": "e", "f": "g"}, 100 | dis={1: "a", 2: "b", 3: "c"}, 101 | dsdis={"a": {1: "hello", 2: "world"}, "b": {1: "world"}}, 102 | dc=SubConfig(1, "hello"), 103 | dco=None, 104 | ddc={"a": SubConfig(1, "hello"), "b": SubConfig(2, "world")}, 105 | )) 106 | 107 | buffer = io.BytesIO() 108 | checkpoint.dump(buffer, ckpt) 109 | buffer.seek(0) 110 | ckpt2 = checkpoint.load(buffer, Checkpoint) 111 | np.testing.assert_array_equal(ckpt.params["layer1"]["w"], 112 | ckpt2.params["layer1"]["w"]) 113 | np.testing.assert_array_equal(ckpt.params["layer1"]["b"], 114 | ckpt2.params["layer1"]["b"]) 115 | np.testing.assert_array_equal(ckpt.params["layer2"]["w"], 116 | ckpt2.params["layer2"]["w"]) 117 | np.testing.assert_array_equal(ckpt.params["layer2"]["b"], 118 | ckpt2.params["layer2"]["b"]) 119 | np.testing.assert_array_equal(ckpt.params["blah"], ckpt2.params["blah"]) 120 | self.assertEqual(ckpt.config, ckpt2.config) 121 | 122 | 123 | if __name__ == "__main__": 124 | absltest.main() 125 | -------------------------------------------------------------------------------- /graphcast/data_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for `data_utils.py`.""" 15 | 16 | import datetime 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from graphcast import data_utils 20 | import numpy as np 21 | import xarray as xa 22 | 23 | 24 | class DataUtilsTest(parameterized.TestCase): 25 | 26 | def setUp(self): 27 | super().setUp() 28 | # Fix the seed for reproducibility. 
29 | np.random.seed(0) 30 | 31 | def test_year_progress_is_zero_at_year_start_or_end(self): 32 | year_progress = data_utils.get_year_progress( 33 | np.array([ 34 | 0, 35 | data_utils.AVG_SEC_PER_YEAR, 36 | data_utils.AVG_SEC_PER_YEAR * 42, # 42 years. 37 | ]) 38 | ) 39 | np.testing.assert_array_equal(year_progress, np.zeros(year_progress.shape)) 40 | 41 | def test_year_progress_is_almost_one_before_year_ends(self): 42 | year_progress = data_utils.get_year_progress( 43 | np.array([ 44 | data_utils.AVG_SEC_PER_YEAR - 1, 45 | (data_utils.AVG_SEC_PER_YEAR - 1) * 42, # ~42 years 46 | ]) 47 | ) 48 | with self.subTest("Year progress values are close to 1"): 49 | self.assertTrue(np.all(year_progress > 0.999)) 50 | with self.subTest("Year progress values != 1"): 51 | self.assertTrue(np.all(year_progress < 1.0)) 52 | 53 | def test_day_progress_computes_for_all_times_and_longitudes(self): 54 | times = np.random.randint(low=0, high=1e10, size=10) 55 | longitudes = np.arange(0, 360.0, 1.0) 56 | day_progress = data_utils.get_day_progress(times, longitudes) 57 | with self.subTest("Day progress is computed for all times and longinutes"): 58 | self.assertSequenceEqual( 59 | day_progress.shape, (len(times), len(longitudes)) 60 | ) 61 | 62 | @parameterized.named_parameters( 63 | dict( 64 | testcase_name="random_date_1", 65 | year=1988, 66 | month=11, 67 | day=7, 68 | hour=2, 69 | minute=45, 70 | second=34, 71 | ), 72 | dict( 73 | testcase_name="random_date_2", 74 | year=2022, 75 | month=3, 76 | day=12, 77 | hour=7, 78 | minute=1, 79 | second=0, 80 | ), 81 | ) 82 | def test_day_progress_is_in_between_zero_and_one( 83 | self, year, month, day, hour, minute, second 84 | ): 85 | # Datetime from a timestamp. 86 | dt = datetime.datetime(year, month, day, hour, minute, second) 87 | # Epoch time. 88 | epoch_time = datetime.datetime(1970, 1, 1) 89 | # Seconds since epoch. 90 | seconds_since_epoch = np.array([(dt - epoch_time).total_seconds()]) 91 | 92 | # Longitudes with 1 degree resolution. 93 | longitudes = np.arange(0, 360.0, 1.0) 94 | 95 | day_progress = data_utils.get_day_progress(seconds_since_epoch, longitudes) 96 | with self.subTest("Day progress >= 0"): 97 | self.assertTrue(np.all(day_progress >= 0.0)) 98 | with self.subTest("Day progress < 1"): 99 | self.assertTrue(np.all(day_progress < 1.0)) 100 | 101 | def test_day_progress_is_zero_at_day_start_or_end(self): 102 | day_progress = data_utils.get_day_progress( 103 | seconds_since_epoch=np.array([ 104 | 0, 105 | data_utils.SEC_PER_DAY, 106 | data_utils.SEC_PER_DAY * 42, # 42 days. 
107 | ]), 108 | longitude=np.array([0.0]), 109 | ) 110 | np.testing.assert_array_equal(day_progress, np.zeros(day_progress.shape)) 111 | 112 | def test_day_progress_specific_value(self): 113 | day_progress = data_utils.get_day_progress( 114 | seconds_since_epoch=np.array([123]), 115 | longitude=np.array([0.0]), 116 | ) 117 | np.testing.assert_array_almost_equal( 118 | day_progress, np.array([[0.00142361]]), decimal=6 119 | ) 120 | 121 | def test_featurize_progress_valid_values_and_dimensions(self): 122 | day_progress = np.array([0.0, 0.45, 0.213]) 123 | feature_dimensions = ("time",) 124 | progress_features = data_utils.featurize_progress( 125 | name="day_progress", dims=feature_dimensions, progress=day_progress 126 | ) 127 | for feature in progress_features.values(): 128 | with self.subTest(f"Valid dimensions for {feature}"): 129 | self.assertSequenceEqual(feature.dims, feature_dimensions) 130 | 131 | with self.subTest("Valid values for day_progress"): 132 | np.testing.assert_array_equal( 133 | day_progress, progress_features["day_progress"].values 134 | ) 135 | 136 | with self.subTest("Valid values for day_progress_sin"): 137 | np.testing.assert_array_almost_equal( 138 | np.array([0.0, 0.30901699, 0.97309851]), 139 | progress_features["day_progress_sin"].values, 140 | decimal=6, 141 | ) 142 | 143 | with self.subTest("Valid values for day_progress_cos"): 144 | np.testing.assert_array_almost_equal( 145 | np.array([1.0, -0.95105652, 0.23038943]), 146 | progress_features["day_progress_cos"].values, 147 | decimal=6, 148 | ) 149 | 150 | def test_featurize_progress_invalid_dimensions(self): 151 | year_progress = np.array([0.0, 0.45, 0.213]) 152 | feature_dimensions = ("time", "longitude") 153 | with self.assertRaises(ValueError): 154 | data_utils.featurize_progress( 155 | name="year_progress", dims=feature_dimensions, progress=year_progress 156 | ) 157 | 158 | def test_add_derived_vars_variables_added(self): 159 | data = xa.Dataset( 160 | data_vars={ 161 | "var1": (["x", "lon", "datetime"], 8 * np.random.randn(2, 2, 3)) 162 | }, 163 | coords={ 164 | "lon": np.array([0.0, 0.5]), 165 | "datetime": np.array([ 166 | datetime.datetime(2021, 1, 1), 167 | datetime.datetime(2023, 1, 1), 168 | datetime.datetime(2023, 1, 3), 169 | ]), 170 | }, 171 | ) 172 | data_utils.add_derived_vars(data) 173 | all_variables = set(data.variables) 174 | 175 | with self.subTest("Original value was not removed"): 176 | self.assertIn("var1", all_variables) 177 | with self.subTest("Year progress feature was added"): 178 | self.assertIn(data_utils.YEAR_PROGRESS, all_variables) 179 | with self.subTest("Day progress feature was added"): 180 | self.assertIn(data_utils.DAY_PROGRESS, all_variables) 181 | 182 | def test_add_derived_vars_existing_vars_not_overridden(self): 183 | dims = ["x", "lon", "datetime"] 184 | data = xa.Dataset( 185 | data_vars={ 186 | "var1": (dims, 8 * np.random.randn(2, 2, 3)), 187 | data_utils.YEAR_PROGRESS: (dims, np.full((2, 2, 3), 0.111)), 188 | data_utils.DAY_PROGRESS: (dims, np.full((2, 2, 3), 0.222)), 189 | }, 190 | coords={ 191 | "lon": np.array([0.0, 0.5]), 192 | "datetime": np.array([ 193 | datetime.datetime(2021, 1, 1), 194 | datetime.datetime(2023, 1, 1), 195 | datetime.datetime(2023, 1, 3), 196 | ]), 197 | }, 198 | ) 199 | 200 | data_utils.add_derived_vars(data) 201 | 202 | with self.subTest("Year progress feature was not overridden"): 203 | np.testing.assert_allclose(data[data_utils.YEAR_PROGRESS], 0.111) 204 | with self.subTest("Day progress feature was not overridden"): 205 | 
np.testing.assert_allclose(data[data_utils.DAY_PROGRESS], 0.222) 206 | 207 | @parameterized.named_parameters( 208 | dict(testcase_name="missing_datetime", coord_name="lon"), 209 | dict(testcase_name="missing_lon", coord_name="datetime"), 210 | ) 211 | def test_add_derived_vars_missing_coordinate_raises_value_error( 212 | self, coord_name 213 | ): 214 | with self.subTest(f"Missing {coord_name} coordinate"): 215 | data = xa.Dataset( 216 | data_vars={"var1": (["x", coord_name], 8 * np.random.randn(2, 2))}, 217 | coords={ 218 | coord_name: np.array([0.0, 0.5]), 219 | }, 220 | ) 221 | with self.assertRaises(ValueError): 222 | data_utils.add_derived_vars(data) 223 | 224 | def test_add_tisr_var_variable_added(self): 225 | data = xa.Dataset( 226 | data_vars={ 227 | "var1": (["time", "lat", "lon"], np.full((2, 2, 2), 8.0)) 228 | }, 229 | coords={ 230 | "lat": np.array([2.0, 1.0]), 231 | "lon": np.array([0.0, 0.5]), 232 | "time": np.array([100, 200], dtype="timedelta64[s]"), 233 | "datetime": xa.Variable( 234 | "time", np.array([10, 20], dtype="datetime64[D]") 235 | ), 236 | }, 237 | ) 238 | 239 | data_utils.add_tisr_var(data) 240 | 241 | self.assertIn(data_utils.TISR, set(data.variables)) 242 | 243 | def test_add_tisr_var_existing_var_not_overridden(self): 244 | dims = ["time", "lat", "lon"] 245 | data = xa.Dataset( 246 | data_vars={ 247 | "var1": (dims, np.full((2, 2, 2), 8.0)), 248 | data_utils.TISR: (dims, np.full((2, 2, 2), 1200.0)), 249 | }, 250 | coords={ 251 | "lat": np.array([2.0, 1.0]), 252 | "lon": np.array([0.0, 0.5]), 253 | "time": np.array([100, 200], dtype="timedelta64[s]"), 254 | "datetime": xa.Variable( 255 | "time", np.array([10, 20], dtype="datetime64[D]") 256 | ), 257 | }, 258 | ) 259 | 260 | data_utils.add_derived_vars(data) 261 | 262 | np.testing.assert_allclose(data[data_utils.TISR], 1200.0) 263 | 264 | def test_add_tisr_var_works_with_batch_dim_size_one(self): 265 | data = xa.Dataset( 266 | data_vars={ 267 | "var1": ( 268 | ["batch", "time", "lat", "lon"], 269 | np.full((1, 2, 2, 2), 8.0), 270 | ) 271 | }, 272 | coords={ 273 | "lat": np.array([2.0, 1.0]), 274 | "lon": np.array([0.0, 0.5]), 275 | "time": np.array([100, 200], dtype="timedelta64[s]"), 276 | "datetime": xa.Variable( 277 | ("batch", "time"), np.array([[10, 20]], dtype="datetime64[D]") 278 | ), 279 | }, 280 | ) 281 | 282 | data_utils.add_tisr_var(data) 283 | 284 | self.assertIn(data_utils.TISR, set(data.variables)) 285 | 286 | def test_add_tisr_var_fails_with_batch_dim_size_greater_than_one(self): 287 | data = xa.Dataset( 288 | data_vars={ 289 | "var1": ( 290 | ["batch", "time", "lat", "lon"], 291 | np.full((2, 2, 2, 2), 8.0), 292 | ) 293 | }, 294 | coords={ 295 | "lat": np.array([2.0, 1.0]), 296 | "lon": np.array([0.0, 0.5]), 297 | "time": np.array([100, 200], dtype="timedelta64[s]"), 298 | "datetime": xa.Variable( 299 | ("batch", "time"), 300 | np.array([[10, 20], [100, 200]], dtype="datetime64[D]"), 301 | ), 302 | }, 303 | ) 304 | 305 | with self.assertRaisesRegex(ValueError, r"cannot select a dimension"): 306 | data_utils.add_tisr_var(data) 307 | 308 | 309 | if __name__ == "__main__": 310 | absltest.main() 311 | -------------------------------------------------------------------------------- /graphcast/denoisers_base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Base class for Denoisers used in diffusion Predictors. 15 | 16 | Denoisers are a bit like deterministic Predictors, except: 17 | * Their __call__ method also conditions on noisy_targets and the noise_levels 18 | of those noisy targets 19 | * They don't have an overrideable loss function (the loss is assumed to be some 20 | form of MSE and is implemented outside the Denoiser itself) 21 | """ 22 | 23 | from typing import Optional, Protocol 24 | 25 | import xarray 26 | 27 | 28 | class Denoiser(Protocol): 29 | """A denoising model that conditions on inputs as well as noise level.""" 30 | 31 | def __call__( 32 | self, 33 | inputs: xarray.Dataset, 34 | noisy_targets: xarray.Dataset, 35 | noise_levels: xarray.DataArray, 36 | forcings: Optional[xarray.Dataset] = None, 37 | **kwargs) -> xarray.Dataset: 38 | """Computes denoised targets from noisy targets. 39 | 40 | Args: 41 | inputs: Inputs to condition on, as for Predictor.__call__. 42 | noisy_targets: Targets which have had i.i.d. zero-mean Gaussian noise 43 | added to them (where the noise level used may vary along the 'batch' 44 | dimension). 45 | noise_levels: A DataArray with dimensions ('batch',) specifying the noise 46 | levels that were used for each example in the batch. 47 | forcings: Optional additional per-target-timestep forcings to condition 48 | on, as for Predictor.__call__. 49 | **kwargs: Any additional custom kwargs. 50 | 51 | Returns: 52 | Denoised predictions with the same shape as noisy_targets. 53 | """ 54 | -------------------------------------------------------------------------------- /graphcast/dpm_solver_plus_plus_2s.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """DPM-Solver++ 2S sampler from https://arxiv.org/abs/2211.01095.""" 15 | 16 | from typing import Optional 17 | 18 | from graphcast import casting 19 | from graphcast import denoisers_base 20 | from graphcast import samplers_base as base 21 | from graphcast import samplers_utils as utils 22 | from graphcast import xarray_jax 23 | import haiku as hk 24 | import jax.numpy as jnp 25 | import xarray 26 | 27 | 28 | class Sampler(base.Sampler): 29 | """Sampling using DPM-Solver++ 2S from [1]. 30 | 31 | This is combined with optional stochastic churn as described in [2]. 32 | 33 | The '2S' terminology from [1] means that this is a second-order (2), 34 | single-step (S) solver. 
Here 'single-step' here distinguishes it from 35 | 'multi-step' methods where the results of function evaluations from previous 36 | steps are reused in computing updates for subsequent steps. The solver still 37 | uses multiple steps though. 38 | 39 | [1] DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic 40 | Models, https://arxiv.org/abs/2211.01095 41 | [2] Elucidating the Design Space of Diffusion-Based Generative Models, 42 | https://arxiv.org/abs/2206.00364 43 | """ 44 | 45 | def __init__(self, 46 | denoiser: denoisers_base.Denoiser, 47 | max_noise_level: float, 48 | min_noise_level: float, 49 | num_noise_levels: int, 50 | rho: float, 51 | stochastic_churn_rate: float, 52 | churn_min_noise_level: float, 53 | churn_max_noise_level: float, 54 | noise_level_inflation_factor: float 55 | ): 56 | """Initializes the sampler. 57 | 58 | Args: 59 | denoiser: A Denoiser which predicts noise-free targets. 60 | max_noise_level: The highest noise level used at the start of the 61 | sequence of reverse diffusion steps. 62 | min_noise_level: The lowest noise level used at the end of the sequence of 63 | reverse diffusion steps. 64 | num_noise_levels: Determines the number of noise levels used and hence the 65 | number of reverse diffusion steps performed. 66 | rho: Parameter affecting the spacing of noise steps. Higher values will 67 | concentrate noise steps more around zero. 68 | stochastic_churn_rate: S_churn from the paper. This controls the rate 69 | at which noise is re-injected/'churned' during the sampling algorithm. 70 | If this is set to zero then we are performing deterministic sampling 71 | as described in Algorithm 1. 72 | churn_min_noise_level: Minimum noise level at which stochastic churn 73 | occurs. S_min from the paper. Only used if stochastic_churn_rate > 0. 74 | churn_max_noise_level: Maximum noise level at which stochastic churn 75 | occurs. S_min from the paper. Only used if stochastic_churn_rate > 0. 76 | noise_level_inflation_factor: This can be used to set the actual amount of 77 | noise injected higher than what the denoiser is told has been added. 78 | The motivation is to compensate for a tendency of L2-trained denoisers 79 | to remove slightly too much noise / blur too much. S_noise from the 80 | paper. Only used if stochastic_churn_rate > 0. 
81 | """ 82 | super().__init__(denoiser) 83 | self._noise_levels = utils.noise_schedule( 84 | max_noise_level, min_noise_level, num_noise_levels, rho) 85 | self._stochastic_churn = stochastic_churn_rate > 0 86 | self._per_step_churn_rates = utils.stochastic_churn_rate_schedule( 87 | self._noise_levels, stochastic_churn_rate, churn_min_noise_level, 88 | churn_max_noise_level) 89 | self._noise_level_inflation_factor = noise_level_inflation_factor 90 | 91 | def __call__( 92 | self, 93 | inputs: xarray.Dataset, 94 | targets_template: xarray.Dataset, 95 | forcings: Optional[xarray.Dataset] = None, 96 | **kwargs) -> xarray.Dataset: 97 | 98 | dtype = casting.infer_floating_dtype(targets_template) # pytype: disable=wrong-arg-types 99 | noise_levels = jnp.array(self._noise_levels).astype(dtype) 100 | per_step_churn_rates = jnp.array(self._per_step_churn_rates).astype(dtype) 101 | 102 | def denoiser(noise_level: jnp.ndarray, x: xarray.Dataset) -> xarray.Dataset: 103 | """Computes D(x, sigma, y).""" 104 | bcast_noise_level = xarray_jax.DataArray( 105 | jnp.tile(noise_level, x.sizes['batch']), dims=('batch',)) 106 | # Estimate the expectation of the fully-denoised target x0, conditional on 107 | # inputs/forcings, noisy targets and their noise level: 108 | return self._denoiser( 109 | inputs=inputs, 110 | noisy_targets=x, 111 | noise_levels=bcast_noise_level, 112 | forcings=forcings) 113 | 114 | def body_fn(i: jnp.ndarray, x: xarray.Dataset) -> xarray.Dataset: 115 | """One iteration of the sampling algorithm. 116 | 117 | Args: 118 | i: Sampling iteration. 119 | x: Noisy targets at iteration i, these will have noise level 120 | self._noise_levels[i]. 121 | 122 | Returns: 123 | Noisy targets at the next lowest noise level self._noise_levels[i+1]. 124 | """ 125 | def init_noise(template): 126 | return noise_levels[0] * utils.spherical_white_noise_like(template) 127 | 128 | # Initialise the inputs if i == 0. 129 | # This is done here to ensure both noise sampler calls can use the same 130 | # spherical harmonic basis functions. While there may be a small compute 131 | # cost the memory savings can be significant. 132 | # TODO(dominicmasters): Figure out if we can merge the two noise sampler 133 | # calls into one to avoid this hack. 134 | maybe_init_noise = (i == 0).astype(noise_levels[0].dtype) 135 | x = x + init_noise(x) * maybe_init_noise 136 | 137 | noise_level = noise_levels[i] 138 | 139 | if self._stochastic_churn: 140 | # We increase the noise level of x a bit before taking it down again: 141 | x, noise_level = utils.apply_stochastic_churn( 142 | x, noise_level, 143 | stochastic_churn_rate=per_step_churn_rates[i], 144 | noise_level_inflation_factor=self._noise_level_inflation_factor) 145 | 146 | # Apply one step of the ODE solver to take x down to the next lowest 147 | # noise level. 148 | 149 | # Note that the Elucidating paper's choice of sigma(t)=t and s(t)=1 150 | # (corresponding to alpha(t)=1 in the DPM paper) as well as the standard 151 | # choice of r=1/2 (corresponding to a geometric mean for the s_i 152 | # midpoints) greatly simplifies the update from the DPM-Solver++ paper. 153 | # You need to do a bit of algebraic fiddling to arrive at the below after 154 | # substituting these choices into DPMSolver++'s Algorithm 1. The simpler 155 | # update we arrive at helps with intuition too. 156 | 157 | next_noise_level = noise_levels[i + 1] 158 | # This is s_{i+1} from the paper. 
They don't explain how the s_i are 159 | # chosen, but the default choice seems to be a geometric mean, which is 160 | # equivalent to setting all the r_i = 1/2. 161 | mid_noise_level = jnp.sqrt(noise_level * next_noise_level) 162 | 163 | mid_over_current = mid_noise_level / noise_level 164 | x_denoised = denoiser(noise_level, x) 165 | # This turns out to be a convex combination of current and denoised x, 166 | # which isn't entirely apparent from the paper formulae: 167 | x_mid = mid_over_current * x + (1 - mid_over_current) * x_denoised 168 | 169 | next_over_current = next_noise_level / noise_level 170 | x_mid_denoised = denoiser(mid_noise_level, x_mid) # pytype: disable=wrong-arg-types 171 | x_next = next_over_current * x + (1 - next_over_current) * x_mid_denoised 172 | 173 | # For the final step to noise level 0, we do an Euler update which 174 | # corresponds to just returning the denoiser's prediction directly. 175 | # 176 | # In fact the behaviour above when next_noise_level == 0 is almost 177 | # equivalent, except that it runs the denoiser a second time to denoise 178 | # from noise level 0. The denoiser should just be the identity function in 179 | # this case, but it hasn't necessarily been trained at noise level 0 so 180 | # we avoid relying on this. 181 | return utils.tree_where(next_noise_level == 0, x_denoised, x_next) 182 | 183 | # Init with zeros but apply additional noise at step 0 to initialise the 184 | # state. 185 | noise_init = xarray.zeros_like(targets_template) 186 | return hk.fori_loop( 187 | 0, len(noise_levels) - 1, body_fun=body_fn, init_val=noise_init) 188 | -------------------------------------------------------------------------------- /graphcast/gencast.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Denoising diffusion models based on the framework of [1]. 15 | 16 | Throughout we will refer to notation and equations from [1]. 17 | 18 | [1] Elucidating the Design Space of Diffusion-Based Generative Models 19 | Karras, Aittala, Aila and Laine, 2022 20 | https://arxiv.org/abs/2206.00364 21 | """ 22 | 23 | from typing import Any, Optional, Tuple 24 | 25 | import chex 26 | from graphcast import casting 27 | from graphcast import denoiser 28 | from graphcast import dpm_solver_plus_plus_2s 29 | from graphcast import graphcast 30 | from graphcast import losses 31 | from graphcast import predictor_base 32 | from graphcast import samplers_utils 33 | from graphcast import xarray_jax 34 | import haiku as hk 35 | import jax 36 | import xarray 37 | 38 | 39 | TARGET_SURFACE_VARS = ( 40 | '2m_temperature', 41 | 'mean_sea_level_pressure', 42 | '10m_v_component_of_wind', 43 | '10m_u_component_of_wind', # GenCast predicts in 12hr timesteps. 
44 | 'total_precipitation_12hr', 45 | 'sea_surface_temperature', 46 | ) 47 | 48 | TARGET_SURFACE_NO_PRECIP_VARS = ( 49 | '2m_temperature', 50 | 'mean_sea_level_pressure', 51 | '10m_v_component_of_wind', 52 | '10m_u_component_of_wind', 53 | 'sea_surface_temperature', 54 | ) 55 | 56 | 57 | TASK = graphcast.TaskConfig( 58 | input_variables=( 59 | # GenCast doesn't take precipitation as input. 60 | TARGET_SURFACE_NO_PRECIP_VARS 61 | + graphcast.TARGET_ATMOSPHERIC_VARS 62 | + graphcast.GENERATED_FORCING_VARS 63 | + graphcast.STATIC_VARS 64 | ), 65 | target_variables=TARGET_SURFACE_VARS + graphcast.TARGET_ATMOSPHERIC_VARS, 66 | # GenCast doesn't take incident solar radiation as a forcing. 67 | forcing_variables=graphcast.GENERATED_FORCING_VARS, 68 | pressure_levels=graphcast.PRESSURE_LEVELS_WEATHERBENCH_13, 69 | # GenCast takes the current frame and the frame 12 hours prior. 70 | input_duration='24h', 71 | ) 72 | 73 | 74 | @chex.dataclass(frozen=True, eq=True) 75 | class SamplerConfig: 76 | """Configures the sampler used to draw samples from GenCast. 77 | 78 | max_noise_level: The highest noise level used at the start of the 79 | sequence of reverse diffusion steps. 80 | min_noise_level: The lowest noise level used at the end of the sequence of 81 | reverse diffusion steps. 82 | num_noise_levels: Determines the number of noise levels used and hence the 83 | number of reverse diffusion steps performed. 84 | rho: Parameter affecting the spacing of noise steps. Higher values will 85 | concentrate noise steps more around zero. 86 | stochastic_churn_rate: S_churn from the paper. This controls the rate 87 | at which noise is re-injected/'churned' during the sampling algorithm. 88 | If this is set to zero then we are performing deterministic sampling 89 | as described in Algorithm 1. 90 | churn_max_noise_level: Maximum noise level at which stochastic churn 91 | occurs. S_max from the paper. Only used if stochastic_churn_rate > 0. 92 | churn_min_noise_level: Minimum noise level at which stochastic churn 93 | occurs. S_min from the paper. Only used if stochastic_churn_rate > 0. 94 | noise_level_inflation_factor: This can be used to set the actual amount of 95 | noise injected higher than what the denoiser is told has been added. 96 | The motivation is to compensate for a tendency of L2-trained denoisers 97 | to remove slightly too much noise / blur too much. S_noise from the 98 | paper. Only used if stochastic_churn_rate > 0. 99 | """ 100 | max_noise_level: float = 80. 101 | min_noise_level: float = 0.03 102 | num_noise_levels: int = 20 103 | rho: float = 7. 104 | # Stochastic sampler settings. 105 | stochastic_churn_rate: float = 2.5 106 | churn_min_noise_level: float = 0.75 107 | churn_max_noise_level: float = float('inf') 108 | noise_level_inflation_factor: float = 1.05 109 | 110 | 111 | @chex.dataclass(frozen=True, eq=True) 112 | class NoiseConfig: 113 | training_noise_level_rho: float = 7.0 114 | training_max_noise_level: float = 88.0 115 | training_min_noise_level: float = 0.02 116 | 117 | 118 | @chex.dataclass(frozen=True, eq=True) 119 | class CheckPoint: 120 | description: str 121 | license: str 122 | params: dict[str, Any] 123 | task_config: graphcast.TaskConfig 124 | denoiser_architecture_config: denoiser.DenoiserArchitectureConfig 125 | sampler_config: SamplerConfig 126 | noise_config: NoiseConfig 127 | noise_encoder_config: denoiser.NoiseEncoderConfig 128 | 129 | 130 | class GenCast(predictor_base.Predictor): 131 | """Predictor for a denoising diffusion model following the framework of [1].
132 | 133 | [1] Elucidating the Design Space of Diffusion-Based Generative Models 134 | Karras, Aittala, Aila and Laine, 2022 135 | https://arxiv.org/abs/2206.00364 136 | 137 | Unlike the paper, we have a conditional model and our denoising function 138 | conditions on previous timesteps. 139 | 140 | As the paper demonstrates, the sampling algorithm can be varied independently 141 | of the denoising model and its training procedure, and it is separately 142 | configurable here. 143 | """ 144 | 145 | def __init__( 146 | self, 147 | task_config: graphcast.TaskConfig, 148 | denoiser_architecture_config: denoiser.DenoiserArchitectureConfig, 149 | sampler_config: Optional[SamplerConfig] = None, 150 | noise_config: Optional[NoiseConfig] = None, 151 | noise_encoder_config: Optional[denoiser.NoiseEncoderConfig] = None, 152 | ): 153 | """Constructs GenCast.""" 154 | # Output size depends on number of variables being predicted. 155 | num_surface_vars = len( 156 | set(task_config.target_variables) 157 | - set(graphcast.ALL_ATMOSPHERIC_VARS) 158 | ) 159 | num_atmospheric_vars = len( 160 | set(task_config.target_variables) 161 | & set(graphcast.ALL_ATMOSPHERIC_VARS) 162 | ) 163 | num_outputs = ( 164 | num_surface_vars 165 | + len(task_config.pressure_levels) * num_atmospheric_vars 166 | ) 167 | denoiser_architecture_config.node_output_size = num_outputs 168 | self._denoiser = denoiser.Denoiser( 169 | noise_encoder_config, 170 | denoiser_architecture_config, 171 | ) 172 | self._sampler_config = sampler_config 173 | # Singleton to avoid re-initializing the sampler for each inference call. 174 | self._sampler = None 175 | self._noise_config = noise_config 176 | 177 | def _c_in(self, noise_scale: xarray.DataArray) -> xarray.DataArray: 178 | """Scaling applied to the noisy targets input to the underlying network.""" 179 | return (noise_scale**2 + 1)**-0.5 180 | 181 | def _c_out(self, noise_scale: xarray.DataArray) -> xarray.DataArray: 182 | """Scaling applied to the underlying network's raw outputs.""" 183 | return noise_scale * (noise_scale**2 + 1)**-0.5 184 | 185 | def _c_skip(self, noise_scale: xarray.DataArray) -> xarray.DataArray: 186 | """Scaling applied to the skip connection.""" 187 | return 1 / (noise_scale**2 + 1) 188 | 189 | def _loss_weighting(self, noise_scale: xarray.DataArray) -> xarray.DataArray: 190 | r"""The loss weighting \lambda(\sigma) from the paper.""" 191 | return self._c_out(noise_scale) ** -2 192 | 193 | def _preconditioned_denoiser( 194 | self, 195 | inputs: xarray.Dataset, 196 | noisy_targets: xarray.Dataset, 197 | noise_levels: xarray.DataArray, 198 | forcings: Optional[xarray.Dataset] = None, 199 | **kwargs) -> xarray.Dataset: 200 | """The preconditioned denoising function D from the paper (Eqn 7).""" 201 | raw_predictions = self._denoiser( 202 | inputs=inputs, 203 | noisy_targets=noisy_targets * self._c_in(noise_levels), 204 | noise_levels=noise_levels, 205 | forcings=forcings, 206 | **kwargs) 207 | return (raw_predictions * self._c_out(noise_levels) + 208 | noisy_targets * self._c_skip(noise_levels)) 209 | 210 | def loss_and_predictions( 211 | self, 212 | inputs: xarray.Dataset, 213 | targets: xarray.Dataset, 214 | forcings: Optional[xarray.Dataset] = None, 215 | ) -> Tuple[predictor_base.LossAndDiagnostics, xarray.Dataset]: 216 | return self.loss(inputs, targets, forcings), self(inputs, targets, forcings) 217 | 218 | def loss(self, 219 | inputs: xarray.Dataset, 220 | targets: xarray.Dataset, 221 | forcings: Optional[xarray.Dataset] = None, 222 | ) -> 
predictor_base.LossAndDiagnostics: 223 | 224 | if self._noise_config is None: 225 | raise ValueError('Noise config must be specified to train GenCast.') 226 | 227 | # Sample noise levels: 228 | dtype = casting.infer_floating_dtype(targets) # pytype: disable=wrong-arg-types 229 | key = hk.next_rng_key() 230 | batch_size = inputs.sizes['batch'] 231 | noise_levels = xarray_jax.DataArray( 232 | data=samplers_utils.rho_inverse_cdf( 233 | min_value=self._noise_config.training_min_noise_level, 234 | max_value=self._noise_config.training_max_noise_level, 235 | rho=self._noise_config.training_noise_level_rho, 236 | cdf=jax.random.uniform(key, shape=(batch_size,), dtype=dtype)), 237 | dims=('batch',)) 238 | 239 | # Sample noise and apply it to targets: 240 | noise = ( 241 | samplers_utils.spherical_white_noise_like(targets) * noise_levels 242 | ) 243 | noisy_targets = targets + noise 244 | 245 | denoised_predictions = self._preconditioned_denoiser( 246 | inputs, noisy_targets, noise_levels, forcings) 247 | 248 | loss, diagnostics = losses.weighted_mse_per_level( 249 | denoised_predictions, 250 | targets, 251 | # Weights are same as we used for GraphCast. 252 | per_variable_weights={ 253 | # Any variables not specified here are weighted as 1.0. 254 | # A single-level variable, but an important headline variable 255 | # and also one which we have struggled to get good performance 256 | # on at short lead times, so leaving it weighted at 1.0, equal 257 | # to the multi-level variables: 258 | '2m_temperature': 1.0, 259 | # New single-level variables, which we don't weight too highly 260 | # to avoid hurting performance on other variables. 261 | '10m_u_component_of_wind': 0.1, 262 | '10m_v_component_of_wind': 0.1, 263 | 'mean_sea_level_pressure': 0.1, 264 | 'sea_surface_temperature': 0.1, 265 | 'total_precipitation_12hr': 0.1 266 | }, 267 | ) 268 | loss *= self._loss_weighting(noise_levels) 269 | return loss, diagnostics 270 | 271 | def __call__(self, 272 | inputs: xarray.Dataset, 273 | targets_template: xarray.Dataset, 274 | forcings: Optional[xarray.Dataset] = None, 275 | **kwargs) -> xarray.Dataset: 276 | if self._sampler_config is None: 277 | raise ValueError( 278 | 'Sampler config must be specified to run inference on GenCast.' 279 | ) 280 | if self._sampler is None: 281 | self._sampler = dpm_solver_plus_plus_2s.Sampler( 282 | self._preconditioned_denoiser, **self._sampler_config 283 | ) 284 | return self._sampler(inputs, targets_template, forcings, **kwargs) 285 | -------------------------------------------------------------------------------- /graphcast/grid_mesh_connectivity.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
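To see how the GenCast pieces excerpted above (the TASK config, SamplerConfig, NoiseConfig and the GenCast predictor itself) fit together at inference time, here is a minimal sketch. The restored checkpoint `ckpt` and the parameter/rng plumbing around the Haiku transform are assumptions for illustration, not code from this repository:

import haiku as hk
import xarray
from graphcast import gencast

def run_gencast(inputs: xarray.Dataset,
                targets_template: xarray.Dataset,
                forcings: xarray.Dataset) -> xarray.Dataset:
  # `ckpt` is assumed to be a gencast.CheckPoint restored elsewhere; its fields
  # mirror the CheckPoint dataclass shown above.
  predictor = gencast.GenCast(
      task_config=ckpt.task_config,
      denoiser_architecture_config=ckpt.denoiser_architecture_config,
      sampler_config=ckpt.sampler_config,
      noise_config=ckpt.noise_config,
      noise_encoder_config=ckpt.noise_encoder_config,
  )
  return predictor(inputs, targets_template, forcings=forcings)

# GenCast is built from Haiku modules and sampling draws random noise, so the
# function is wrapped in a Haiku transform and applied with parameters, state
# and an rng key.
run_gencast_transformed = hk.transform_with_state(run_gencast)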
14 | """Tools for converting from regular grids on a sphere, to triangular meshes.""" 15 | 16 | from graphcast import icosahedral_mesh 17 | import numpy as np 18 | import scipy 19 | import trimesh 20 | 21 | 22 | def _grid_lat_lon_to_coordinates( 23 | grid_latitude: np.ndarray, grid_longitude: np.ndarray) -> np.ndarray: 24 | """Lat [num_lat] lon [num_lon] to 3d coordinates [num_lat, num_lon, 3].""" 25 | # Convert to spherical coordinates phi and theta defined in the grid. 26 | # Each [num_latitude_points, num_longitude_points] 27 | phi_grid, theta_grid = np.meshgrid( 28 | np.deg2rad(grid_longitude), 29 | np.deg2rad(90 - grid_latitude)) 30 | 31 | # [num_latitude_points, num_longitude_points, 3] 32 | # Note this assumes unit radius, since for now we model the earth as a 33 | # sphere of unit radius, and keep any vertical dimension as a regular grid. 34 | return np.stack( 35 | [np.cos(phi_grid)*np.sin(theta_grid), 36 | np.sin(phi_grid)*np.sin(theta_grid), 37 | np.cos(theta_grid)], axis=-1) 38 | 39 | 40 | def radius_query_indices( 41 | *, 42 | grid_latitude: np.ndarray, 43 | grid_longitude: np.ndarray, 44 | mesh: icosahedral_mesh.TriangularMesh, 45 | radius: float) -> tuple[np.ndarray, np.ndarray]: 46 | """Returns mesh-grid edge indices for radius query. 47 | 48 | Args: 49 | grid_latitude: Latitude values for the grid [num_lat_points] 50 | grid_longitude: Longitude values for the grid [num_lon_points] 51 | mesh: Mesh object. 52 | radius: Radius of connectivity in R3. for a sphere of unit radius. 53 | 54 | Returns: 55 | tuple with `grid_indices` and `mesh_indices` indicating edges between the 56 | grid and the mesh such that the distances in a straight line (not geodesic) 57 | are smaller than or equal to `radius`. 58 | * grid_indices: Indices of shape [num_edges], that index into a 59 | [num_lat_points, num_lon_points] grid, after flattening the leading axes. 60 | * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices. 61 | """ 62 | 63 | # [num_grid_points=num_lat_points * num_lon_points, 3] 64 | grid_positions = _grid_lat_lon_to_coordinates( 65 | grid_latitude, grid_longitude).reshape([-1, 3]) 66 | 67 | # [num_mesh_points, 3] 68 | mesh_positions = mesh.vertices 69 | kd_tree = scipy.spatial.cKDTree(mesh_positions) 70 | 71 | # [num_grid_points, num_mesh_points_per_grid_point] 72 | # Note `num_mesh_points_per_grid_point` is not constant, so this is a list 73 | # of arrays, rather than a 2d array. 74 | query_indices = kd_tree.query_ball_point(x=grid_positions, r=radius) 75 | 76 | grid_edge_indices = [] 77 | mesh_edge_indices = [] 78 | for grid_index, mesh_neighbors in enumerate(query_indices): 79 | grid_edge_indices.append(np.repeat(grid_index, len(mesh_neighbors))) 80 | mesh_edge_indices.append(mesh_neighbors) 81 | 82 | # [num_edges] 83 | grid_edge_indices = np.concatenate(grid_edge_indices, axis=0).astype(int) 84 | mesh_edge_indices = np.concatenate(mesh_edge_indices, axis=0).astype(int) 85 | 86 | return grid_edge_indices, mesh_edge_indices 87 | 88 | 89 | def in_mesh_triangle_indices( 90 | *, 91 | grid_latitude: np.ndarray, 92 | grid_longitude: np.ndarray, 93 | mesh: icosahedral_mesh.TriangularMesh) -> tuple[np.ndarray, np.ndarray]: 94 | """Returns mesh-grid edge indices for grid points contained in mesh triangles. 95 | 96 | Args: 97 | grid_latitude: Latitude values for the grid [num_lat_points] 98 | grid_longitude: Longitude values for the grid [num_lon_points] 99 | mesh: Mesh object. 
100 | 101 | Returns: 102 | tuple with `grid_indices` and `mesh_indices` indicating edges between the 103 | grid and the mesh vertices of the triangle that contain each grid point. 104 | The number of edges is always num_lat_points * num_lon_points * 3 105 | * grid_indices: Indices of shape [num_edges], that index into a 106 | [num_lat_points, num_lon_points] grid, after flattening the leading axes. 107 | * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices. 108 | """ 109 | 110 | # [num_grid_points=num_lat_points * num_lon_points, 3] 111 | grid_positions = _grid_lat_lon_to_coordinates( 112 | grid_latitude, grid_longitude).reshape([-1, 3]) 113 | 114 | mesh_trimesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces) 115 | 116 | # [num_grid_points] with mesh face indices for each grid point. 117 | _, _, query_face_indices = trimesh.proximity.closest_point( 118 | mesh_trimesh, grid_positions) 119 | 120 | # [num_grid_points, 3] with mesh node indices for each grid point. 121 | mesh_edge_indices = mesh.faces[query_face_indices] 122 | 123 | # [num_grid_points, 3] with grid node indices, where every row simply contains 124 | # the row (grid_point) index. 125 | grid_indices = np.arange(grid_positions.shape[0]) 126 | grid_edge_indices = np.tile(grid_indices.reshape([-1, 1]), [1, 3]) 127 | 128 | # Flatten to get a regular list. 129 | # [num_edges=num_grid_points*3] 130 | mesh_edge_indices = mesh_edge_indices.reshape([-1]) 131 | grid_edge_indices = grid_edge_indices.reshape([-1]) 132 | 133 | return grid_edge_indices, mesh_edge_indices 134 | -------------------------------------------------------------------------------- /graphcast/grid_mesh_connectivity_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for graphcast.grid_mesh_connectivity.""" 15 | 16 | from absl.testing import absltest 17 | from graphcast import grid_mesh_connectivity 18 | from graphcast import icosahedral_mesh 19 | import numpy as np 20 | 21 | 22 | class GridMeshConnectivityTest(absltest.TestCase): 23 | 24 | def test_grid_lat_lon_to_coordinates(self): 25 | 26 | # Intervals of 30 degrees. 
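# Concretely: in _grid_lat_lon_to_coordinates, phi is the longitude and theta
# the colatitude (90 - latitude), so latitude 0 / longitude 90 gives
# phi = theta = 90 degrees and the unit vector
#     (cos(90)*sin(90), sin(90)*sin(90), cos(90)) = (0., 1., 0.),
# which appears as the second entry of the middle block of
# `expected_coordinates` below.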
27 | grid_latitude = np.array([-45., 0., 45]) 28 | grid_longitude = np.array([0., 90., 180., 270.]) 29 | 30 | inv_sqrt2 = 1 / np.sqrt(2) 31 | expected_coordinates = np.array([ 32 | [[inv_sqrt2, 0., -inv_sqrt2], 33 | [0., inv_sqrt2, -inv_sqrt2], 34 | [-inv_sqrt2, 0., -inv_sqrt2], 35 | [0., -inv_sqrt2, -inv_sqrt2]], 36 | [[1., 0., 0.], 37 | [0., 1., 0.], 38 | [-1., 0., 0.], 39 | [0., -1., 0.]], 40 | [[inv_sqrt2, 0., inv_sqrt2], 41 | [0., inv_sqrt2, inv_sqrt2], 42 | [-inv_sqrt2, 0., inv_sqrt2], 43 | [0., -inv_sqrt2, inv_sqrt2]], 44 | ]) 45 | 46 | coordinates = grid_mesh_connectivity._grid_lat_lon_to_coordinates( 47 | grid_latitude, grid_longitude) 48 | np.testing.assert_allclose(expected_coordinates, coordinates, atol=1e-15) 49 | 50 | def test_radius_query_indices_smoke(self): 51 | # TODO(alvarosg): Add non-smoke test? 52 | grid_latitude = np.linspace(-75, 75, 6) 53 | grid_longitude = np.arange(12) * 30. 54 | mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere( 55 | splits=3)[-1] 56 | grid_mesh_connectivity.radius_query_indices( 57 | grid_latitude=grid_latitude, 58 | grid_longitude=grid_longitude, 59 | mesh=mesh, radius=0.2) 60 | 61 | def test_in_mesh_triangle_indices_smoke(self): 62 | # TODO(alvarosg): Add non-smoke test? 63 | grid_latitude = np.linspace(-75, 75, 6) 64 | grid_longitude = np.arange(12) * 30. 65 | mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere( 66 | splits=3)[-1] 67 | grid_mesh_connectivity.in_mesh_triangle_indices( 68 | grid_latitude=grid_latitude, 69 | grid_longitude=grid_longitude, 70 | mesh=mesh) 71 | 72 | 73 | if __name__ == "__main__": 74 | absltest.main() 75 | -------------------------------------------------------------------------------- /graphcast/icosahedral_mesh.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utils for creating icosahedral meshes.""" 15 | 16 | import itertools 17 | from typing import List, NamedTuple, Sequence, Tuple 18 | 19 | import numpy as np 20 | from scipy.spatial import transform 21 | 22 | 23 | class TriangularMesh(NamedTuple): 24 | """Data structure for triangular meshes. 25 | 26 | Attributes: 27 | vertices: spatial positions of the vertices of the mesh of shape 28 | [num_vertices, num_dims]. 29 | faces: triangular faces of the mesh of shape [num_faces, 3]. Contains 30 | integer indices into `vertices`. 31 | 32 | """ 33 | vertices: np.ndarray 34 | faces: np.ndarray 35 | 36 | 37 | def merge_meshes( 38 | mesh_list: Sequence[TriangularMesh]) -> TriangularMesh: 39 | """Merges all meshes into one. Assumes the last mesh is the finest. 40 | 41 | Args: 42 | mesh_list: Sequence of meshes, from coarse to fine refinement levels. The 43 | vertices and faces may contain those from preceding, coarser levels. 
44 | 45 | Returns: 46 | `TriangularMesh` for which the vertices correspond to the highest 47 | resolution mesh in the hierarchy, and the faces are the join set of the 48 | faces at all levels of the hierarchy. 49 | """ 50 | for mesh_i, mesh_ip1 in itertools.pairwise(mesh_list): 51 | num_nodes_mesh_i = mesh_i.vertices.shape[0] 52 | assert np.allclose(mesh_i.vertices, mesh_ip1.vertices[:num_nodes_mesh_i]) 53 | 54 | return TriangularMesh( 55 | vertices=mesh_list[-1].vertices, 56 | faces=np.concatenate([mesh.faces for mesh in mesh_list], axis=0)) 57 | 58 | 59 | def get_hierarchy_of_triangular_meshes_for_sphere( 60 | splits: int) -> List[TriangularMesh]: 61 | """Returns a sequence of meshes, each with triangularization sphere. 62 | 63 | Starting with a regular icosahedron (12 vertices, 20 faces, 30 edges) with 64 | circumscribed unit sphere. Then, each triangular face is iteratively 65 | subdivided into 4 triangular faces `splits` times. The new vertices are then 66 | projected back onto the unit sphere. All resulting meshes are returned in a 67 | list, from lowest to highest resolution. 68 | 69 | The vertices in each face are specified in counter-clockwise order as 70 | observed from the outside the icosahedron. 71 | 72 | Args: 73 | splits: How many times to split each triangle. 74 | Returns: 75 | Sequence of `TriangularMesh`s of length `splits + 1` each with: 76 | 77 | vertices: [num_vertices, 3] vertex positions in 3D, all with unit norm. 78 | faces: [num_faces, 3] with triangular faces joining sets of 3 vertices. 79 | Each row contains three indices into the vertices array, indicating 80 | the vertices adjacent to the face. Always with positive orientation 81 | (counterclock-wise when looking from the outside). 82 | """ 83 | current_mesh = get_icosahedron() 84 | output_meshes = [current_mesh] 85 | for _ in range(splits): 86 | current_mesh = _two_split_unit_sphere_triangle_faces(current_mesh) 87 | output_meshes.append(current_mesh) 88 | return output_meshes 89 | 90 | 91 | def get_icosahedron() -> TriangularMesh: 92 | """Returns a regular icosahedral mesh with circumscribed unit sphere. 93 | 94 | See https://en.wikipedia.org/wiki/Regular_icosahedron#Cartesian_coordinates 95 | for details on the construction of the regular icosahedron. 96 | 97 | The vertices in each face are specified in counter-clockwise order as observed 98 | from the outside of the icosahedron. 99 | 100 | Returns: 101 | TriangularMesh with: 102 | 103 | vertices: [num_vertices=12, 3] vertex positions in 3D, all with unit norm. 104 | faces: [num_faces=20, 3] with triangular faces joining sets of 3 vertices. 105 | Each row contains three indices into the vertices array, indicating 106 | the vertices adjacent to the face. Always with positive orientation ( 107 | counterclock-wise when looking from the outside). 108 | 109 | """ 110 | phi = (1 + np.sqrt(5)) / 2 111 | vertices = [] 112 | for c1 in [1., -1.]: 113 | for c2 in [phi, -phi]: 114 | vertices.append((c1, c2, 0.)) 115 | vertices.append((0., c1, c2)) 116 | vertices.append((c2, 0., c1)) 117 | 118 | vertices = np.array(vertices, dtype=np.float32) 119 | vertices /= np.linalg.norm([1., phi]) 120 | 121 | # I did this manually, checking the orientation one by one. 
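# One way to verify the same property programmatically: for each face
# (i0, i1, i2), the normal implied by the winding order,
#     np.cross(vertices[i1] - vertices[i0], vertices[i2] - vertices[i1]),
# should have a positive dot product with the face centre, i.e. it should point
# away from the origin. icosahedral_mesh_test.py performs essentially this
# check for every face.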
122 | faces = [(0, 1, 2), 123 | (0, 6, 1), 124 | (8, 0, 2), 125 | (8, 4, 0), 126 | (3, 8, 2), 127 | (3, 2, 7), 128 | (7, 2, 1), 129 | (0, 4, 6), 130 | (4, 11, 6), 131 | (6, 11, 5), 132 | (1, 5, 7), 133 | (4, 10, 11), 134 | (4, 8, 10), 135 | (10, 8, 3), 136 | (10, 3, 9), 137 | (11, 10, 9), 138 | (11, 9, 5), 139 | (5, 9, 7), 140 | (9, 3, 7), 141 | (1, 6, 5), 142 | ] 143 | 144 | # By default the top is an aris parallel to the Y axis. 145 | # Need to rotate around the y axis by half the supplementary to the 146 | # angle between faces divided by two to get the desired orientation. 147 | # /O\ (top arist) 148 | # / \ Z 149 | # (adjacent face)/ \ (adjacent face) ^ 150 | # / angle_between_faces \ | 151 | # / \ | 152 | # / \ YO-----> X 153 | # This results in: 154 | # (adjacent faceis now top plane) 155 | # ----------------------O\ (top arist) 156 | # \ 157 | # \ 158 | # \ (adjacent face) 159 | # \ 160 | # \ 161 | # \ 162 | 163 | angle_between_faces = 2 * np.arcsin(phi / np.sqrt(3)) 164 | rotation_angle = (np.pi - angle_between_faces) / 2 165 | rotation = transform.Rotation.from_euler(seq="y", angles=rotation_angle) 166 | rotation_matrix = rotation.as_matrix() 167 | vertices = np.dot(vertices, rotation_matrix) 168 | 169 | return TriangularMesh(vertices=vertices.astype(np.float32), 170 | faces=np.array(faces, dtype=np.int32)) 171 | 172 | 173 | def _two_split_unit_sphere_triangle_faces( 174 | triangular_mesh: TriangularMesh) -> TriangularMesh: 175 | """Splits each triangular face into 4 triangles keeping the orientation.""" 176 | 177 | # Every time we split a triangle into 4 we will be adding 3 extra vertices, 178 | # located at the edge centres. 179 | # This class handles the positioning of the new vertices, and avoids creating 180 | # duplicates. 181 | new_vertices_builder = _ChildVerticesBuilder(triangular_mesh.vertices) 182 | 183 | new_faces = [] 184 | for ind1, ind2, ind3 in triangular_mesh.faces: 185 | # Transform each triangular face into 4 triangles, 186 | # preserving the orientation. 187 | # ind3 188 | # / \ 189 | # / \ 190 | # / #3 \ 191 | # / \ 192 | # ind31 -------------- ind23 193 | # / \ / \ 194 | # / \ #4 / \ 195 | # / #1 \ / #2 \ 196 | # / \ / \ 197 | # ind1 ------------ ind12 ------------ ind2 198 | ind12 = new_vertices_builder.get_new_child_vertex_index((ind1, ind2)) 199 | ind23 = new_vertices_builder.get_new_child_vertex_index((ind2, ind3)) 200 | ind31 = new_vertices_builder.get_new_child_vertex_index((ind3, ind1)) 201 | # Note how each of the 4 triangular new faces specifies the order of the 202 | # vertices to preserve the orientation of the original face. As the input 203 | # face should always be counter-clockwise as specified in the diagram, 204 | # this means child faces should also be counter-clockwise. 205 | new_faces.extend([[ind1, ind12, ind31], # 1 206 | [ind12, ind2, ind23], # 2 207 | [ind31, ind23, ind3], # 3 208 | [ind12, ind23, ind31], # 4 209 | ]) 210 | return TriangularMesh(vertices=new_vertices_builder.get_all_vertices(), 211 | faces=np.array(new_faces, dtype=np.int32)) 212 | 213 | 214 | class _ChildVerticesBuilder(object): 215 | """Bookkeeping of new child vertices added to an existing set of vertices.""" 216 | 217 | def __init__(self, parent_vertices): 218 | 219 | # Because the same new vertex will be required when splitting adjacent 220 | # triangles (which share an edge) we keep them in a hash table indexed by 221 | # sorted indices of the vertices adjacent to the edge, to avoid creating 222 | # duplicated child vertices. 
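# For example, when the two faces sharing the edge between vertices 3 and 7
# are split, both splits ask for the midpoint of that edge; keying the lookup
# with tuple(sorted((7, 3))) == (3, 7) means the child vertex is created once
# and then reused for the neighbouring face.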
223 | self._child_vertices_index_mapping = {} 224 | self._parent_vertices = parent_vertices 225 | # We start with all previous vertices. 226 | self._all_vertices_list = list(parent_vertices) 227 | 228 | def _get_child_vertex_key(self, parent_vertex_indices): 229 | return tuple(sorted(parent_vertex_indices)) 230 | 231 | def _create_child_vertex(self, parent_vertex_indices): 232 | """Creates a new vertex.""" 233 | # Position for new vertex is the middle point, between the parent points, 234 | # projected to unit sphere. 235 | child_vertex_position = self._parent_vertices[ 236 | list(parent_vertex_indices)].mean(0) 237 | child_vertex_position /= np.linalg.norm(child_vertex_position) 238 | 239 | # Add the vertex to the output list. The index for this new vertex will 240 | # match the length of the list before adding it. 241 | child_vertex_key = self._get_child_vertex_key(parent_vertex_indices) 242 | self._child_vertices_index_mapping[child_vertex_key] = len( 243 | self._all_vertices_list) 244 | self._all_vertices_list.append(child_vertex_position) 245 | 246 | def get_new_child_vertex_index(self, parent_vertex_indices): 247 | """Returns index for a child vertex, creating it if necessary.""" 248 | # Get the key to see if we already have a new vertex in the middle. 249 | child_vertex_key = self._get_child_vertex_key(parent_vertex_indices) 250 | if child_vertex_key not in self._child_vertices_index_mapping: 251 | self._create_child_vertex(parent_vertex_indices) 252 | return self._child_vertices_index_mapping[child_vertex_key] 253 | 254 | def get_all_vertices(self): 255 | """Returns an array with old vertices.""" 256 | return np.array(self._all_vertices_list) 257 | 258 | 259 | def faces_to_edges(faces: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 260 | """Transforms polygonal faces to sender and receiver indices. 261 | 262 | It does so by transforming every face into N_i edges. Such if the triangular 263 | face has indices [0, 1, 2], three edges are added 0->1, 1->2, and 2->0. 264 | 265 | If all faces have consistent orientation, and the surface represented by the 266 | faces is closed, then every edge in a polygon with a certain orientation 267 | is also part of another polygon with the opposite orientation. In this 268 | situation, the edges returned by the method are always bidirectional. 269 | 270 | Args: 271 | faces: Integer array of shape [num_faces, 3]. Contains node indices 272 | adjacent to each face. 273 | Returns: 274 | Tuple with sender/receiver indices, each of shape [num_edges=num_faces*3]. 275 | 276 | """ 277 | assert faces.ndim == 2 278 | assert faces.shape[-1] == 3 279 | senders = np.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]]) 280 | receivers = np.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]]) 281 | return senders, receivers 282 | 283 | 284 | def get_last_triangular_mesh_for_sphere(splits: int) -> TriangularMesh: 285 | return get_hierarchy_of_triangular_meshes_for_sphere(splits=splits)[-1] 286 | -------------------------------------------------------------------------------- /graphcast/icosahedral_mesh_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for icosahedral_mesh.""" 15 | 16 | from absl.testing import absltest 17 | from absl.testing import parameterized 18 | import chex 19 | from graphcast import icosahedral_mesh 20 | import numpy as np 21 | 22 | 23 | def _get_mesh_spec(splits: int): 24 | """Returns size of the final icosahedral mesh resulting from the splitting.""" 25 | num_vertices = 12 26 | num_faces = 20 27 | for _ in range(splits): 28 | # Each previous face adds three new vertices, but each vertex is shared 29 | # by two faces. 30 | num_vertices += num_faces * 3 // 2 31 | num_faces *= 4 32 | return num_vertices, num_faces 33 | 34 | 35 | class IcosahedralMeshTest(parameterized.TestCase): 36 | 37 | def test_icosahedron(self): 38 | mesh = icosahedral_mesh.get_icosahedron() 39 | _assert_valid_mesh( 40 | mesh, num_expected_vertices=12, num_expected_faces=20) 41 | 42 | @parameterized.parameters(list(range(5))) 43 | def test_get_hierarchy_of_triangular_meshes_for_sphere(self, splits): 44 | meshes = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere( 45 | splits=splits) 46 | prev_vertices = None 47 | for mesh_i, mesh in enumerate(meshes): 48 | # Check that `mesh` is valid. 49 | num_expected_vertices, num_expected_faces = _get_mesh_spec(mesh_i) 50 | _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces) 51 | 52 | # Check that the first N vertices from this mesh match all of the 53 | # vertices from the previous mesh. 54 | if prev_vertices is not None: 55 | leading_mesh_vertices = mesh.vertices[:prev_vertices.shape[0]] 56 | np.testing.assert_array_equal(leading_mesh_vertices, prev_vertices) 57 | 58 | # Increase the expected/previous values for the next iteration. 59 | if mesh_i < len(meshes) - 1: 60 | prev_vertices = mesh.vertices 61 | 62 | @parameterized.parameters(list(range(4))) 63 | def test_merge_meshes(self, splits): 64 | mesh_hierarchy = ( 65 | icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere( 66 | splits=splits)) 67 | mesh = icosahedral_mesh.merge_meshes(mesh_hierarchy) 68 | 69 | expected_faces = np.concatenate([m.faces for m in mesh_hierarchy], axis=0) 70 | np.testing.assert_array_equal(mesh.vertices, mesh_hierarchy[-1].vertices) 71 | np.testing.assert_array_equal(mesh.faces, expected_faces) 72 | 73 | def test_faces_to_edges(self): 74 | 75 | faces = np.array([[0, 1, 2], 76 | [3, 4, 5]]) 77 | 78 | # This also documents the order of the edges returned by the method. 
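# Concretely: faces_to_edges concatenates the first column of all faces, then
# the second, then the third, so for the two faces above the edges come out
# interleaved across faces as 0->1, 3->4, 1->2, 4->5, 2->0, 5->3, i.e. exactly
# the rows of `expected_edges` below.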
79 | expected_edges = np.array( 80 | [[0, 1], 81 | [3, 4], 82 | [1, 2], 83 | [4, 5], 84 | [2, 0], 85 | [5, 3]]) 86 | expected_senders = expected_edges[:, 0] 87 | expected_receivers = expected_edges[:, 1] 88 | 89 | senders, receivers = icosahedral_mesh.faces_to_edges(faces) 90 | 91 | np.testing.assert_array_equal(senders, expected_senders) 92 | np.testing.assert_array_equal(receivers, expected_receivers) 93 | 94 | 95 | def _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces): 96 | vertices = mesh.vertices 97 | faces = mesh.faces 98 | chex.assert_shape(vertices, [num_expected_vertices, 3]) 99 | chex.assert_shape(faces, [num_expected_faces, 3]) 100 | 101 | # Vertices norm should be 1. 102 | vertices_norm = np.linalg.norm(vertices, axis=-1) 103 | np.testing.assert_allclose(vertices_norm, 1., rtol=1e-6) 104 | 105 | _assert_positive_face_orientation(vertices, faces) 106 | 107 | 108 | def _assert_positive_face_orientation(vertices, faces): 109 | 110 | # Obtain a unit vector that points, in the direction of the face. 111 | face_orientation = np.cross(vertices[faces[:, 1]] - vertices[faces[:, 0]], 112 | vertices[faces[:, 2]] - vertices[faces[:, 1]]) 113 | face_orientation /= np.linalg.norm(face_orientation, axis=-1, keepdims=True) 114 | 115 | # And a unit vector pointing from the origin to the center of the face. 116 | face_centers = vertices[faces].mean(1) 117 | face_centers /= np.linalg.norm(face_centers, axis=-1, keepdims=True) 118 | 119 | # Positive orientation means those two vectors should be parallel 120 | # (dot product, 1), and not anti-parallel (dot product, -1). 121 | dot_center_orientation = np.einsum("ik,ik->i", face_orientation, face_centers) 122 | 123 | # Check that the face normal is parallel to the vector that joins the center 124 | # of the face to the center of the sphere. Note we need a small tolerance 125 | # because some discretizations are not exactly uniform, so it will not be 126 | # exactly parallel. 127 | np.testing.assert_allclose(dot_center_orientation, 1., atol=6e-4) 128 | 129 | 130 | if __name__ == "__main__": 131 | absltest.main() 132 | -------------------------------------------------------------------------------- /graphcast/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Loss functions (and terms for use in loss functions) used for weather.""" 15 | 16 | from typing import Mapping 17 | 18 | from graphcast import xarray_tree 19 | import numpy as np 20 | from typing_extensions import Protocol 21 | import xarray 22 | 23 | 24 | LossAndDiagnostics = tuple[xarray.DataArray, xarray.Dataset] 25 | 26 | 27 | class LossFunction(Protocol): 28 | """A loss function. 29 | 30 | This is a protocol so it's fine to use a plain function which 'quacks like' 31 | this. This is just to document the interface. 
32 | """ 33 | 34 | def __call__(self, 35 | predictions: xarray.Dataset, 36 | targets: xarray.Dataset, 37 | **optional_kwargs) -> LossAndDiagnostics: 38 | """Computes a loss function. 39 | 40 | Args: 41 | predictions: Dataset of predictions. 42 | targets: Dataset of targets. 43 | **optional_kwargs: Implementations may support extra optional kwargs. 44 | 45 | Returns: 46 | loss: A DataArray with dimensions ('batch',) containing losses for each 47 | element of the batch. These will be averaged to give the final 48 | loss, locally and across replicas. 49 | diagnostics: Mapping of additional quantities to log by name alongside the 50 | loss. These will will typically correspond to terms in the loss. They 51 | should also have dimensions ('batch',) and will be averaged over the 52 | batch before logging. 53 | """ 54 | 55 | 56 | def weighted_mse_per_level( 57 | predictions: xarray.Dataset, 58 | targets: xarray.Dataset, 59 | per_variable_weights: Mapping[str, float], 60 | ) -> LossAndDiagnostics: 61 | """Latitude- and pressure-level-weighted MSE loss.""" 62 | def loss(prediction, target): 63 | loss = (prediction - target)**2 64 | loss *= normalized_latitude_weights(target).astype(loss.dtype) 65 | if 'level' in target.dims: 66 | loss *= normalized_level_weights(target).astype(loss.dtype) 67 | return _mean_preserving_batch(loss) 68 | 69 | losses = xarray_tree.map_structure(loss, predictions, targets) 70 | return sum_per_variable_losses(losses, per_variable_weights) 71 | 72 | 73 | def _mean_preserving_batch(x: xarray.DataArray) -> xarray.DataArray: 74 | return x.mean([d for d in x.dims if d != 'batch'], skipna=False) 75 | 76 | 77 | def sum_per_variable_losses( 78 | per_variable_losses: Mapping[str, xarray.DataArray], 79 | weights: Mapping[str, float], 80 | ) -> LossAndDiagnostics: 81 | """Weighted sum of per-variable losses.""" 82 | if not set(weights.keys()).issubset(set(per_variable_losses.keys())): 83 | raise ValueError( 84 | 'Passing a weight that does not correspond to any variable ' 85 | f'{set(weights.keys())-set(per_variable_losses.keys())}') 86 | 87 | weighted_per_variable_losses = { 88 | name: loss * weights.get(name, 1) 89 | for name, loss in per_variable_losses.items() 90 | } 91 | total = xarray.concat( 92 | weighted_per_variable_losses.values(), dim='variable', join='exact').sum( 93 | 'variable', skipna=False) 94 | return total, per_variable_losses # pytype: disable=bad-return-type 95 | 96 | 97 | def normalized_level_weights(data: xarray.DataArray) -> xarray.DataArray: 98 | """Weights proportional to pressure at each level.""" 99 | level = data.coords['level'] 100 | return level / level.mean(skipna=False) 101 | 102 | 103 | def normalized_latitude_weights(data: xarray.DataArray) -> xarray.DataArray: 104 | """Weights based on latitude, roughly proportional to grid cell area. 105 | 106 | This method supports two use cases only (both for equispaced values): 107 | * Latitude values such that the closest value to the pole is at latitude 108 | (90 - d_lat/2), where d_lat is the difference between contiguous latitudes. 109 | For example: [-89, -87, -85, ..., 85, 87, 89]) (d_lat = 2) 110 | In this case each point with `lat` value represents a sphere slice between 111 | `lat - d_lat/2` and `lat + d_lat/2`, and the area of this slice would be 112 | proportional to: 113 | `sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat)`, and 114 | we can simply omit the term `2 * sin(d_lat/2)` which is just a constant 115 | that cancels during normalization. 
116 | * Latitude values that fall exactly at the poles. 117 | For example: [-90, -88, -86, ..., 86, 88, 90]) (d_lat = 2) 118 | In this case each point with `lat` value also represents 119 | a sphere slice between `lat - d_lat/2` and `lat + d_lat/2`, 120 | except for the points at the poles, that represent a slice between 121 | `90 - d_lat/2` and `90` or, `-90` and `-90 + d_lat/2`. 122 | The areas of the first type of point are still proportional to: 123 | * sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat) 124 | but for the points at the poles now is: 125 | * sin(90) - sin(90 - d_lat/2) = 2 * sin(d_lat/4) ^ 2 126 | and we will be using these weights, depending on whether we are looking at 127 | pole cells, or non-pole cells (omitting the common factor of 2 which will be 128 | absorbed by the normalization). 129 | 130 | It can be shown via a limit, or simple geometry, that in the small angles 131 | regime, the proportion of area per pole-point is equal to 1/8th 132 | the proportion of area covered by each of the nearest non-pole point, and we 133 | test for this in the test. 134 | 135 | Args: 136 | data: `DataArray` with latitude coordinates. 137 | Returns: 138 | Unit mean latitude weights. 139 | """ 140 | latitude = data.coords['lat'] 141 | 142 | if np.any(np.isclose(np.abs(latitude), 90.)): 143 | weights = _weight_for_latitude_vector_with_poles(latitude) 144 | else: 145 | weights = _weight_for_latitude_vector_without_poles(latitude) 146 | 147 | return weights / weights.mean(skipna=False) 148 | 149 | 150 | def _weight_for_latitude_vector_without_poles(latitude): 151 | """Weights for uniform latitudes of the form [+-90-+d/2, ..., -+90+-d/2].""" 152 | delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude)) 153 | if (not np.isclose(np.max(latitude), 90 - delta_latitude/2) or 154 | not np.isclose(np.min(latitude), -90 + delta_latitude/2)): 155 | raise ValueError( 156 | f'Latitude vector {latitude} does not start/end at ' 157 | '+- (90 - delta_latitude/2) degrees.') 158 | return np.cos(np.deg2rad(latitude)) 159 | 160 | 161 | def _weight_for_latitude_vector_with_poles(latitude): 162 | """Weights for uniform latitudes of the form [+- 90, ..., -+90].""" 163 | delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude)) 164 | if (not np.isclose(np.max(latitude), 90.) or 165 | not np.isclose(np.min(latitude), -90.)): 166 | raise ValueError( 167 | f'Latitude vector {latitude} does not start/end at +- 90 degrees.') 168 | weights = np.cos(np.deg2rad(latitude)) * np.sin(np.deg2rad(delta_latitude/2)) 169 | # The two checks above enough to guarantee that latitudes are sorted, so 170 | # the extremes are the poles 171 | weights[[0, -1]] = np.sin(np.deg2rad(delta_latitude/4)) ** 2 172 | return weights 173 | 174 | 175 | def _check_uniform_spacing_and_get_delta(vector): 176 | diff = np.diff(vector) 177 | if not np.all(np.isclose(diff[0], diff)): 178 | raise ValueError(f'Vector {diff} is not uniformly spaced.') 179 | return diff[0] 180 | -------------------------------------------------------------------------------- /graphcast/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Constructors for MLPs.""" 15 | 16 | import haiku as hk 17 | import jax 18 | import jax.numpy as jnp 19 | 20 | 21 | # TODO(aelkadi): Move the mlp factory here from `deep_typed_graph_net.py`. 22 | 23 | 24 | class LinearNormConditioning(hk.Module): 25 | """Module for norm conditioning. 26 | 27 | Conditions the normalization of "inputs" by applying a linear layer to the 28 | "norm_conditioning" which produces the scale and offset which are applied to 29 | each channel (across the last dim) of "inputs". 30 | """ 31 | 32 | def __init__(self, name="norm_conditioning"): 33 | super().__init__(name=name) 34 | 35 | def __call__(self, inputs: jax.Array, norm_conditioning: jax.Array): 36 | 37 | feature_size = inputs.shape[-1] 38 | conditional_linear_layer = hk.Linear( 39 | output_size=2 * feature_size, 40 | w_init=hk.initializers.TruncatedNormal(stddev=1e-8), 41 | ) 42 | conditional_scale_offset = conditional_linear_layer(norm_conditioning) 43 | scale_minus_one, offset = jnp.split(conditional_scale_offset, 2, axis=-1) 44 | scale = scale_minus_one + 1. 45 | return inputs * scale + offset 46 | -------------------------------------------------------------------------------- /graphcast/nan_cleaning.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Wrappers for Predictors which allow them to work with data cleaned of NaNs. 15 | 16 | The Predictor which is wrapped sees inputs and targets without NaNs, and makes 17 | NaNless predictions. 18 | """ 19 | 20 | from typing import Optional, Tuple 21 | 22 | from graphcast import predictor_base as base 23 | import numpy as np 24 | import xarray 25 | 26 | 27 | class NaNCleaner(base.Predictor): 28 | """A predictor wrapper that removes NaNs from ingested data. 29 | 30 | The Predictor which is wrapped sees inputs and targets without NaNs.
31 | """ 32 | 33 | def __init__( 34 | self, 35 | predictor: base.Predictor, 36 | var_to_clean: str, 37 | fill_value: xarray.Dataset, 38 | reintroduce_nans: bool = False, 39 | ): 40 | """Initializes the NaNCleaner.""" 41 | self._predictor = predictor 42 | self._fill_value = fill_value[var_to_clean] 43 | self._var_to_clean = var_to_clean 44 | self._reintroduce_nans = reintroduce_nans 45 | 46 | def _clean(self, dataset: xarray.Dataset) -> xarray.Dataset: 47 | """Cleans the dataset of NaNs.""" 48 | data_array = dataset[self._var_to_clean] 49 | dataset = dataset.assign( 50 | {self._var_to_clean: data_array.fillna(self._fill_value)} 51 | ) 52 | return dataset 53 | 54 | def _maybe_reintroduce_nans( 55 | self, stale_inputs: xarray.Dataset, predictions: xarray.Dataset 56 | ) -> xarray.Dataset: 57 | # NaN positions don't change between input frames, if they do then 58 | # we should be more careful about re-introducing them. 59 | if self._var_to_clean in predictions.keys(): 60 | nan_mask = np.isnan(stale_inputs[self._var_to_clean]).any(dim='time') 61 | with_nan_values = predictions[self._var_to_clean].where(~nan_mask, np.nan) 62 | predictions = predictions.assign({self._var_to_clean: with_nan_values}) 63 | return predictions 64 | 65 | def __call__( 66 | self, 67 | inputs: xarray.Dataset, 68 | targets_template: xarray.Dataset, 69 | forcings: Optional[xarray.Dataset] = None, 70 | **kwargs, 71 | ) -> xarray.Dataset: 72 | if self._reintroduce_nans: 73 | # Copy inputs before cleaning so that we can reintroduce NaNs later. 74 | original_inputs = inputs.copy() 75 | if self._var_to_clean in inputs.keys(): 76 | inputs = self._clean(inputs) 77 | if forcings and self._var_to_clean in forcings.keys(): 78 | forcings = self._clean(forcings) 79 | predictions = self._predictor( 80 | inputs, targets_template, forcings, **kwargs 81 | ) 82 | if self._reintroduce_nans: 83 | predictions = self._maybe_reintroduce_nans(original_inputs, predictions) 84 | return predictions 85 | 86 | def loss( 87 | self, 88 | inputs: xarray.Dataset, 89 | targets: xarray.Dataset, 90 | forcings: Optional[xarray.Dataset] = None, 91 | **kwargs, 92 | ) -> base.LossAndDiagnostics: 93 | if self._var_to_clean in inputs.keys(): 94 | inputs = self._clean(inputs) 95 | if self._var_to_clean in targets.keys(): 96 | targets = self._clean(targets) 97 | if forcings and self._var_to_clean in forcings.keys(): 98 | forcings = self._clean(forcings) 99 | return self._predictor.loss( 100 | inputs, targets, forcings, **kwargs 101 | ) 102 | 103 | def loss_and_predictions( 104 | self, 105 | inputs: xarray.Dataset, 106 | targets: xarray.Dataset, 107 | forcings: Optional[xarray.Dataset] = None, 108 | **kwargs, 109 | ) -> Tuple[base.LossAndDiagnostics, xarray.Dataset]: 110 | if self._reintroduce_nans: 111 | # Copy inputs before cleaning so that we can reintroduce NaNs later. 
112 | original_inputs = inputs.copy() 113 | if self._var_to_clean in inputs.keys(): 114 | inputs = self._clean(inputs) 115 | if self._var_to_clean in targets.keys(): 116 | targets = self._clean(targets) 117 | if forcings and self._var_to_clean in forcings.keys(): 118 | forcings = self._clean(forcings) 119 | 120 | loss, predictions = self._predictor.loss_and_predictions( 121 | inputs, targets, forcings, **kwargs 122 | ) 123 | if self._reintroduce_nans: 124 | predictions = self._maybe_reintroduce_nans(original_inputs, predictions) 125 | return loss, predictions 126 | -------------------------------------------------------------------------------- /graphcast/normalization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Wrappers for Predictors which allow them to work with normalized data. 15 | 16 | The Predictor which is wrapped sees normalized inputs and targets, and makes 17 | normalized predictions. The wrapper handles translating the predictions back 18 | to the original domain. 19 | """ 20 | 21 | import logging 22 | from typing import Optional, Tuple 23 | 24 | from graphcast import predictor_base 25 | from graphcast import xarray_tree 26 | import xarray 27 | 28 | 29 | def normalize(values: xarray.Dataset, 30 | scales: xarray.Dataset, 31 | locations: Optional[xarray.Dataset], 32 | ) -> xarray.Dataset: 33 | """Normalize variables using the given scales and (optionally) locations.""" 34 | def normalize_array(array): 35 | if array.name is None: 36 | raise ValueError( 37 | "Can't look up normalization constants because array has no name.") 38 | if locations is not None: 39 | if array.name in locations: 40 | array = array - locations[array.name].astype(array.dtype) 41 | else: 42 | logging.warning('No normalization location found for %s', array.name) 43 | if array.name in scales: 44 | array = array / scales[array.name].astype(array.dtype) 45 | else: 46 | logging.warning('No normalization scale found for %s', array.name) 47 | return array 48 | return xarray_tree.map_structure(normalize_array, values) 49 | 50 | 51 | def unnormalize(values: xarray.Dataset, 52 | scales: xarray.Dataset, 53 | locations: Optional[xarray.Dataset], 54 | ) -> xarray.Dataset: 55 | """Unnormalize variables using the given scales and (optionally) locations.""" 56 | def unnormalize_array(array): 57 | if array.name is None: 58 | raise ValueError( 59 | "Can't look up normalization constants because array has no name.") 60 | if array.name in scales: 61 | array = array * scales[array.name].astype(array.dtype) 62 | else: 63 | logging.warning('No normalization scale found for %s', array.name) 64 | if locations is not None: 65 | if array.name in locations: 66 | array = array + locations[array.name].astype(array.dtype) 67 | else: 68 | logging.warning('No normalization location found for %s', array.name) 69 | return array 70 | return 
xarray_tree.map_structure(unnormalize_array, values) 71 | 72 | 73 | class InputsAndResiduals(predictor_base.Predictor): 74 | """Wraps with a residual connection, normalizing inputs and target residuals. 75 | 76 | The inner predictor is given inputs that are normalized using `locations` 77 | and `scales` to roughly zero-mean unit variance. 78 | 79 | For target variables that are present in the inputs, the inner predictor is 80 | trained to predict residuals (target - last_frame_of_input) that have been 81 | normalized using `residual_scales` (and optionally `residual_locations`) to 82 | roughly unit variance / zero mean. 83 | 84 | This replaces `residual.Predictor` in the case where you want normalization 85 | that's based on the scales of the residuals. 86 | 87 | Since we return the underlying predictor's loss on the normalized residuals, 88 | if the underlying predictor is a sum of per-variable losses, the normalization 89 | will affect the relative weighting of the per-variable loss terms (hopefully 90 | in a good way). 91 | 92 | For target variables *not* present in the inputs, the inner predictor is 93 | trained to predict targets directly, that have been normalized in the same 94 | way as the inputs. 95 | 96 | The transforms applied to the targets (the residual connection and the 97 | normalization) are applied in reverse to the predictions before returning 98 | them. 99 | """ 100 | 101 | def __init__( 102 | self, 103 | predictor: predictor_base.Predictor, 104 | stddev_by_level: xarray.Dataset, 105 | mean_by_level: xarray.Dataset, 106 | diffs_stddev_by_level: xarray.Dataset): 107 | self._predictor = predictor 108 | self._scales = stddev_by_level 109 | self._locations = mean_by_level 110 | self._residual_scales = diffs_stddev_by_level 111 | self._residual_locations = None 112 | 113 | def _unnormalize_prediction_and_add_input(self, inputs, norm_prediction): 114 | if norm_prediction.sizes.get('time') != 1: 115 | raise ValueError( 116 | 'normalization.InputsAndResiduals only supports predicting a ' 117 | 'single timestep.') 118 | if norm_prediction.name in inputs: 119 | # Residuals are assumed to be predicted as normalized (unit variance), 120 | # but the scale and location they need mapping to is that of the residuals 121 | # not of the values themselves. 122 | prediction = unnormalize( 123 | norm_prediction, self._residual_scales, self._residual_locations) 124 | # A prediction for which we have a corresponding input -- we are 125 | # predicting the residual: 126 | last_input = inputs[norm_prediction.name].isel(time=-1) 127 | prediction = prediction + last_input 128 | return prediction 129 | else: 130 | # A predicted variable which is not an input variable. 
We are predicting 131 | # it directly, so unnormalize it directly to the target scale/location: 132 | return unnormalize(norm_prediction, self._scales, self._locations) 133 | 134 | def _subtract_input_and_normalize_target(self, inputs, target): 135 | if target.sizes.get('time') != 1: 136 | raise ValueError( 137 | 'normalization.InputsAndResiduals only supports wrapping predictors' 138 | 'that predict a single timestep.') 139 | if target.name in inputs: 140 | target_residual = target 141 | last_input = inputs[target.name].isel(time=-1) 142 | target_residual = target_residual - last_input 143 | return normalize( 144 | target_residual, self._residual_scales, self._residual_locations) 145 | else: 146 | return normalize(target, self._scales, self._locations) 147 | 148 | def __call__(self, 149 | inputs: xarray.Dataset, 150 | targets_template: xarray.Dataset, 151 | forcings: xarray.Dataset, 152 | **kwargs 153 | ) -> xarray.Dataset: 154 | norm_inputs = normalize(inputs, self._scales, self._locations) 155 | norm_forcings = normalize(forcings, self._scales, self._locations) 156 | norm_predictions = self._predictor( 157 | norm_inputs, targets_template, forcings=norm_forcings, **kwargs) 158 | return xarray_tree.map_structure( 159 | lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred), 160 | norm_predictions) 161 | 162 | def loss(self, 163 | inputs: xarray.Dataset, 164 | targets: xarray.Dataset, 165 | forcings: xarray.Dataset, 166 | **kwargs, 167 | ) -> predictor_base.LossAndDiagnostics: 168 | """Returns the loss computed on normalized inputs and targets.""" 169 | norm_inputs = normalize(inputs, self._scales, self._locations) 170 | norm_forcings = normalize(forcings, self._scales, self._locations) 171 | norm_target_residuals = xarray_tree.map_structure( 172 | lambda t: self._subtract_input_and_normalize_target(inputs, t), 173 | targets) 174 | return self._predictor.loss( 175 | norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs) 176 | 177 | def loss_and_predictions( # pytype: disable=signature-mismatch # jax-ndarray 178 | self, 179 | inputs: xarray.Dataset, 180 | targets: xarray.Dataset, 181 | forcings: xarray.Dataset, 182 | **kwargs, 183 | ) -> Tuple[predictor_base.LossAndDiagnostics, 184 | xarray.Dataset]: 185 | """The loss computed on normalized data, with unnormalized predictions.""" 186 | norm_inputs = normalize(inputs, self._scales, self._locations) 187 | norm_forcings = normalize(forcings, self._scales, self._locations) 188 | norm_target_residuals = xarray_tree.map_structure( 189 | lambda t: self._subtract_input_and_normalize_target(inputs, t), 190 | targets) 191 | (loss, scalars), norm_predictions = self._predictor.loss_and_predictions( 192 | norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs) 193 | predictions = xarray_tree.map_structure( 194 | lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred), 195 | norm_predictions) 196 | return (loss, scalars), predictions 197 | -------------------------------------------------------------------------------- /graphcast/predictor_base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Abstract base classes for an xarray-based Predictor API.""" 15 | 16 | import abc 17 | 18 | from typing import Tuple 19 | 20 | from graphcast import losses 21 | from graphcast import xarray_jax 22 | import jax.numpy as jnp 23 | import xarray 24 | 25 | LossAndDiagnostics = losses.LossAndDiagnostics 26 | 27 | 28 | class Predictor(abc.ABC): 29 | """A possibly-trainable predictor of weather, exposing an xarray-based API. 30 | 31 | Typically wraps an underlying JAX model and handles translating the xarray 32 | Dataset values to and from plain JAX arrays that are convenient for input to 33 | (and output from) the underlying model. 34 | 35 | Different subclasses may exist to wrap different kinds of underlying model, 36 | e.g. models taking stacked inputs/outputs, models taking separate 2D and 3D 37 | inputs/outputs, autoregressive models. 38 | 39 | You can also implement a specific model directly as a Predictor if you want, 40 | for example if it has quite specific/unique requirements for its input/output 41 | or loss function, or if it's convenient to implement directly using xarray. 42 | """ 43 | 44 | @abc.abstractmethod 45 | def __call__(self, 46 | inputs: xarray.Dataset, 47 | targets_template: xarray.Dataset, 48 | forcings: xarray.Dataset, 49 | **optional_kwargs 50 | ) -> xarray.Dataset: 51 | """Makes predictions. 52 | 53 | This is only used by the Experiment for inference / evaluation, with 54 | training going via the .loss method. So it should default to making 55 | predictions for evaluation, although you can also support making predictions 56 | for use in the loss via an is_training argument -- see 57 | LossFunctionPredictor which helps with that. 58 | 59 | Args: 60 | inputs: An xarray.Dataset of inputs. 61 | targets_template: An xarray.Dataset or other mapping of xarray.DataArrays, 62 | with the same shape as the targets, to demonstrate what kind of 63 | predictions are required. You can use this to determine which variables, 64 | levels and lead times must be predicted. 65 | You are free to raise an error if you don't support predicting what is 66 | requested. 67 | forcings: An xarray.Dataset of forcings terms. Forcings are variables 68 | that can be fed to the model, but do not need to be predicted. This is 69 | often because this variable can be computed analytically (e.g. the toa 70 | radiation of the sun is mostly a function of geometry) or are considered 71 | to be controlled for the experiment (e.g., impose a scenario of C02 72 | emission into the atmosphere). Unlike `inputs`, the `forcings` can 73 | include information "from the future", that is, information at target 74 | times specified in the `targets_template`. 75 | **optional_kwargs: Implementations may support extra optional kwargs, 76 | provided they set appropriate defaults for them. 77 | 78 | Returns: 79 | Predictions, as an xarray.Dataset or other mapping of DataArrays which 80 | is capable of being evaluated against targets with shape given by 81 | targets_template. 
82 | For probabilistic predictors which can return multiple samples from a 83 | predictive distribution, these should (by convention) be returned along 84 | an additional 'sample' dimension. 85 | """ 86 | 87 | def loss(self, 88 | inputs: xarray.Dataset, 89 | targets: xarray.Dataset, 90 | forcings: xarray.Dataset, 91 | **optional_kwargs, 92 | ) -> LossAndDiagnostics: 93 | """Computes a training loss, for predictors that are trainable. 94 | 95 | Why make this the Predictor's responsibility, rather than letting callers 96 | compute their own loss function using predictions obtained from 97 | Predictor.__call__? 98 | 99 | Doing it this way gives Predictors more control over their training setup. 100 | For example, some predictors may wish to train using different targets to 101 | the ones they predict at evaluation time -- perhaps different lead times and 102 | variables, perhaps training to predict transformed versions of targets 103 | where the transform needs to be inverted at evaluation time, etc. 104 | 105 | It's also necessary for generative models (VAEs, GANs, ...) where the 106 | training loss is more complex and isn't expressible as a parameter-free 107 | function of predictions and targets. 108 | 109 | Args: 110 | inputs: An xarray.Dataset. 111 | targets: An xarray.Dataset or other mapping of xarray.DataArrays. See 112 | docs on __call__ for an explanation about the targets. 113 | forcings: xarray.Dataset of forcing terms. 114 | **optional_kwargs: Implementations may support extra optional kwargs, 115 | provided they set appropriate defaults for them. 116 | 117 | Returns: 118 | loss: A DataArray with dimensions ('batch',) containing losses for each 119 | element of the batch. These will be averaged to give the final 120 | loss, locally and across replicas. 121 | diagnostics: Mapping of additional quantities to log by name alongside the 122 | loss. These will will typically correspond to terms in the loss. They 123 | should also have dimensions ('batch',) and will be averaged over the 124 | batch before logging. 125 | You need not include the loss itself in this dict; it will be added for 126 | you. 127 | """ 128 | del targets, forcings, optional_kwargs 129 | batch_size = inputs.sizes['batch'] 130 | dummy_loss = xarray_jax.DataArray(jnp.zeros(batch_size), dims=('batch',)) 131 | return dummy_loss, {} # pytype: disable=bad-return-type 132 | 133 | def loss_and_predictions( 134 | self, 135 | inputs: xarray.Dataset, 136 | targets: xarray.Dataset, 137 | forcings: xarray.Dataset, 138 | **optional_kwargs, 139 | ) -> Tuple[LossAndDiagnostics, xarray.Dataset]: 140 | """Like .loss but also returns corresponding predictions. 141 | 142 | Implementing this is optional as it's not used directly by the Experiment, 143 | but it is required by autoregressive.Predictor when applying an inner 144 | Predictor autoregressively at training time; we need a loss at each step but 145 | also predictions to feed back in for the next step. 146 | 147 | Note the loss itself may not be directly regressing the predictions towards 148 | targets, the loss may be computed in terms of transformed predictions and 149 | targets (or in some other way). For this reason we can't always cleanly 150 | separate this into step 1: get predictions, step 2: compute loss from them, 151 | hence the need for this combined method. 152 | 153 | Args: 154 | inputs: 155 | targets: 156 | forcings: 157 | **optional_kwargs: 158 | As for self.loss. 
159 | 160 | Returns: 161 | (loss, diagnostics) 162 | As for self.loss 163 | predictions: 164 | The predictions which the loss relates to. These should be of the same 165 | shape as what you would get from 166 | `self.__call__(inputs, targets_template=targets)`, and should be in the 167 | same 'domain' as the inputs (i.e. they shouldn't be transformed 168 | differently to how the predictor expects its inputs). 169 | """ 170 | raise NotImplementedError 171 | -------------------------------------------------------------------------------- /graphcast/samplers_base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Base class for diffusion samplers.""" 15 | 16 | import abc 17 | from typing import Optional 18 | 19 | from graphcast import denoisers_base 20 | import xarray 21 | 22 | 23 | class Sampler(abc.ABC): 24 | """A sampling algorithm for a denoising diffusion model. 25 | 26 | This is constructed with a denoising function, and uses it to draw samples. 27 | """ 28 | 29 | _denoiser: denoisers_base.Denoiser 30 | 31 | def __init__(self, denoiser: denoisers_base.Denoiser): 32 | """Constructs Sampler. 33 | 34 | Args: 35 | denoiser: A Denoiser which has been trained with an MSE loss to predict 36 | the noise-free targets. 37 | """ 38 | self._denoiser = denoiser 39 | 40 | @abc.abstractmethod 41 | def __call__( 42 | self, 43 | inputs: xarray.Dataset, 44 | targets_template: xarray.Dataset, 45 | forcings: Optional[xarray.Dataset] = None, 46 | **kwargs) -> xarray.Dataset: 47 | """Draws a sample using self._denoiser. Contract like Predictor.__call__.""" 48 | -------------------------------------------------------------------------------- /graphcast/solar_radiation_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
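# --- Editorial aside: the sketch below is NOT part of solar_radiation_test.py or any
# repository file. It only illustrates the Predictor contract documented in
# graphcast/predictor_base.py above, using a toy "persistence" baseline; the class
# name and variable names are invented for illustration.
import xarray

from graphcast import predictor_base


class PersistencePredictor(predictor_base.Predictor):
  """Toy Predictor that copies the last input frame to every target time."""

  def __call__(self,
               inputs: xarray.Dataset,
               targets_template: xarray.Dataset,
               forcings: xarray.Dataset,
               **kwargs) -> xarray.Dataset:
    del forcings, kwargs  # A persistence baseline ignores forcings.
    # Last frame of the inputs, with the scalar 'time' coordinate dropped so it
    # can be broadcast against the lead times requested by the template.
    last_frame = inputs.isel(time=-1, drop=True)
    predictions = {}
    for name, template in targets_template.items():
      # Variables absent from the inputs cannot be persisted; a real Predictor
      # would either raise here or predict such variables directly.
      predictions[name] = last_frame[name].broadcast_like(template)
    return xarray.Dataset(predictions)
# --- End of editorial aside; solar_radiation_test.py continues below.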
14 | import timeit 15 | from typing import Sequence 16 | 17 | from absl import logging 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from graphcast import solar_radiation 21 | import numpy as np 22 | import pandas as pd 23 | import xarray as xa 24 | 25 | 26 | def _get_grid_lat_lon_coords( 27 | num_lat: int, num_lon: int 28 | ) -> tuple[np.ndarray, np.ndarray]: 29 | """Generates a linear latitude-longitude grid of the given size. 30 | 31 | Args: 32 | num_lat: Size of the latitude dimension of the grid. 33 | num_lon: Size of the longitude dimension of the grid. 34 | 35 | Returns: 36 | A tuple `(lat, lon)` containing 1D arrays with the latitude and longitude 37 | coordinates in degrees of the generated grid. 38 | """ 39 | lat = np.linspace(-90.0, 90.0, num=num_lat, endpoint=True) 40 | lon = np.linspace(0.0, 360.0, num=num_lon, endpoint=False) 41 | return lat, lon 42 | 43 | 44 | class SolarRadiationTest(parameterized.TestCase): 45 | 46 | def setUp(self): 47 | super().setUp() 48 | np.random.seed(0) 49 | 50 | def test_missing_dim_raises_value_error(self): 51 | data = xa.DataArray( 52 | np.random.randn(2, 2), 53 | coords=[np.array([0.1, 0.2]), np.array([0.0, 0.5])], 54 | dims=["lon", "x"], 55 | ) 56 | with self.assertRaisesRegex( 57 | ValueError, r".* dimensions are missing in `data_array_like`." 58 | ): 59 | solar_radiation.get_toa_incident_solar_radiation_for_xarray( 60 | data, integration_period="1h", num_integration_bins=360 61 | ) 62 | 63 | def test_missing_coordinate_raises_value_error(self): 64 | data = xa.Dataset( 65 | data_vars={"var1": (["x", "lat", "lon"], np.random.randn(2, 3, 2))}, 66 | coords={ 67 | "lat": np.array([0.0, 0.1, 0.2]), 68 | "lon": np.array([0.0, 0.5]), 69 | }, 70 | ) 71 | with self.assertRaisesRegex( 72 | ValueError, r".* coordinates are missing in `data_array_like`." 
73 | ): 74 | solar_radiation.get_toa_incident_solar_radiation_for_xarray( 75 | data, integration_period="1h", num_integration_bins=360 76 | ) 77 | 78 | def test_shape_multiple_timestamps(self): 79 | data = xa.Dataset( 80 | data_vars={"var1": (["time", "lat", "lon"], np.random.randn(2, 4, 2))}, 81 | coords={ 82 | "lat": np.array([0.0, 0.1, 0.2, 0.3]), 83 | "lon": np.array([0.0, 0.5]), 84 | "time": np.array([100, 200], dtype="timedelta64[s]"), 85 | "datetime": xa.Variable( 86 | "time", np.array([10, 20], dtype="datetime64[D]") 87 | ), 88 | }, 89 | ) 90 | 91 | actual = solar_radiation.get_toa_incident_solar_radiation_for_xarray( 92 | data, integration_period="1h", num_integration_bins=2 93 | ) 94 | 95 | self.assertEqual(("time", "lat", "lon"), actual.dims) 96 | self.assertEqual((2, 4, 2), actual.shape) 97 | 98 | def test_shape_single_timestamp(self): 99 | data = xa.Dataset( 100 | data_vars={"var1": (["lat", "lon"], np.random.randn(4, 2))}, 101 | coords={ 102 | "lat": np.array([0.0, 0.1, 0.2, 0.3]), 103 | "lon": np.array([0.0, 0.5]), 104 | "datetime": np.datetime64(10, "D"), 105 | }, 106 | ) 107 | 108 | actual = solar_radiation.get_toa_incident_solar_radiation_for_xarray( 109 | data, integration_period="1h", num_integration_bins=2 110 | ) 111 | 112 | self.assertEqual(("lat", "lon"), actual.dims) 113 | self.assertEqual((4, 2), actual.shape) 114 | 115 | @parameterized.named_parameters( 116 | dict( 117 | testcase_name="one_timestamp_jitted", 118 | periods=1, 119 | repeats=3, 120 | use_jit=True, 121 | ), 122 | dict( 123 | testcase_name="one_timestamp_non_jitted", 124 | periods=1, 125 | repeats=3, 126 | use_jit=False, 127 | ), 128 | dict( 129 | testcase_name="ten_timestamps_non_jitted", 130 | periods=10, 131 | repeats=1, 132 | use_jit=False, 133 | ), 134 | ) 135 | def test_full_spatial_resolution( 136 | self, periods: int, repeats: int, use_jit: bool 137 | ): 138 | timestamps = pd.date_range(start="2023-09-25", periods=periods, freq="6h") 139 | # Generate a linear grid with 0.25 degrees resolution similar to ERA5. 
140 | lat, lon = _get_grid_lat_lon_coords(num_lat=721, num_lon=1440) 141 | 142 | def benchmark() -> None: 143 | solar_radiation.get_toa_incident_solar_radiation( 144 | timestamps, 145 | lat, 146 | lon, 147 | integration_period="1h", 148 | num_integration_bins=360, 149 | use_jit=use_jit, 150 | ).block_until_ready() 151 | 152 | results = timeit.repeat(benchmark, repeat=repeats, number=1) 153 | 154 | logging.info( 155 | "Times to compute `tisr` for input of shape `%d, %d, %d` (seconds): %s", 156 | len(timestamps), 157 | len(lat), 158 | len(lon), 159 | np.array2string(np.array(results), precision=1), 160 | ) 161 | 162 | 163 | class GetTsiTest(parameterized.TestCase): 164 | 165 | @parameterized.named_parameters( 166 | dict( 167 | testcase_name="reference_tsi_data", 168 | loader=solar_radiation.reference_tsi_data, 169 | expected_tsi=np.array([1361.0]), 170 | ), 171 | dict( 172 | testcase_name="era5_tsi_data", 173 | loader=solar_radiation.era5_tsi_data, 174 | expected_tsi=np.array([1360.9440]), # 0.9965 * 1365.7240 175 | ), 176 | ) 177 | def test_mid_2020_lookup( 178 | self, loader: solar_radiation.TsiDataLoader, expected_tsi: np.ndarray 179 | ): 180 | tsi_data = loader() 181 | 182 | tsi = solar_radiation.get_tsi( 183 | [np.datetime64("2020-07-02T00:00:00")], tsi_data 184 | ) 185 | 186 | np.testing.assert_allclose(expected_tsi, tsi) 187 | 188 | @parameterized.named_parameters( 189 | dict( 190 | testcase_name="beginning_2020_left_boundary", 191 | timestamps=[np.datetime64("2020-01-01T00:00:00")], 192 | expected_tsi=np.array([1000.0]), 193 | ), 194 | dict( 195 | testcase_name="mid_2020_exact", 196 | timestamps=[np.datetime64("2020-07-02T00:00:00")], 197 | expected_tsi=np.array([1000.0]), 198 | ), 199 | dict( 200 | testcase_name="beginning_2021_interpolated", 201 | timestamps=[np.datetime64("2021-01-01T00:00:00")], 202 | expected_tsi=np.array([1150.0]), 203 | ), 204 | dict( 205 | testcase_name="mid_2021_lookup", 206 | timestamps=[np.datetime64("2021-07-02T12:00:00")], 207 | expected_tsi=np.array([1300.0]), 208 | ), 209 | dict( 210 | testcase_name="beginning_2022_interpolated", 211 | timestamps=[np.datetime64("2022-01-01T00:00:00")], 212 | expected_tsi=np.array([1250.0]), 213 | ), 214 | dict( 215 | testcase_name="mid_2022_lookup", 216 | timestamps=[np.datetime64("2022-07-02T12:00:00")], 217 | expected_tsi=np.array([1200.0]), 218 | ), 219 | dict( 220 | testcase_name="beginning_2023_right_boundary", 221 | timestamps=[np.datetime64("2023-01-01T00:00:00")], 222 | expected_tsi=np.array([1200.0]), 223 | ), 224 | ) 225 | def test_interpolation( 226 | self, timestamps: Sequence[np.datetime64], expected_tsi: np.ndarray 227 | ): 228 | tsi_data = xa.DataArray( 229 | np.array([1000.0, 1300.0, 1200.0]), 230 | dims=["time"], 231 | coords={"time": np.array([2020.5, 2021.5, 2022.5])}, 232 | ) 233 | 234 | tsi = solar_radiation.get_tsi(timestamps, tsi_data) 235 | 236 | np.testing.assert_allclose(expected_tsi, tsi) 237 | 238 | 239 | if __name__ == "__main__": 240 | absltest.main() 241 | -------------------------------------------------------------------------------- /graphcast/sparse_transformer_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utils for training models in low precision.""" 15 | 16 | import functools 17 | from typing import Callable, Tuple, Union 18 | 19 | import jax 20 | import jax.numpy as jnp 21 | 22 | 23 | # Wrappers for jax.lax.reduce_precision which is non-differentiable. 24 | @functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2)) 25 | def reduce_precision(x, exponent_bits, mantissa_bits): 26 | return jax.tree_util.tree_map( 27 | lambda y: jax.lax.reduce_precision(y, exponent_bits, mantissa_bits), x) 28 | 29 | 30 | def reduce_precision_fwd(x, exponent_bits, mantissa_bits): 31 | return reduce_precision(x, exponent_bits, mantissa_bits), None 32 | 33 | 34 | def reduce_precision_bwd(exponent_bits, mantissa_bits, res, dout): 35 | del res # Unused. 36 | return reduce_precision(dout, exponent_bits, mantissa_bits), 37 | 38 | 39 | reduce_precision.defvjp(reduce_precision_fwd, reduce_precision_bwd) 40 | 41 | 42 | def wrap_fn_for_upcast_downcast(inputs: Union[jnp.ndarray, 43 | Tuple[jnp.ndarray, ...]], 44 | fn: Callable[[Union[jnp.ndarray, 45 | Tuple[jnp.ndarray, ...]]], 46 | Union[jnp.ndarray, 47 | Tuple[jnp.ndarray, ...]]], 48 | f32_upcast: bool = True, 49 | guard_against_excess_precision: bool = True 50 | ) -> Union[jnp.ndarray, 51 | Tuple[jnp.ndarray, ...]]: 52 | """Wraps `fn` to upcast to float32 and then downcast, for use with BF16.""" 53 | # Do not upcast if the inputs are already in float32. 54 | # This removes a no-op `jax.lax.reduce_precision` which is unsupported 55 | # in jax2tf at the moment. 56 | if isinstance(inputs, Tuple): 57 | f32_upcast = f32_upcast and inputs[0].dtype != jnp.float32 58 | orig_dtype = inputs[0].dtype 59 | else: 60 | f32_upcast = f32_upcast and inputs.dtype != jnp.float32 61 | orig_dtype = inputs.dtype 62 | 63 | if f32_upcast: 64 | inputs = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), inputs) 65 | 66 | if guard_against_excess_precision: 67 | # This is evil magic to guard against differences in precision in the QK 68 | # calculation between the forward pass and backwards pass. This is like 69 | # --xla_allow_excess_precision=false but scoped here. 70 | finfo = jnp.finfo(orig_dtype) # jnp important! 71 | inputs = reduce_precision(inputs, finfo.nexp, finfo.nmant) 72 | 73 | output = fn(inputs) 74 | if f32_upcast: 75 | output = jax.tree_util.tree_map(lambda x: x.astype(orig_dtype), output) 76 | return output 77 | -------------------------------------------------------------------------------- /graphcast/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """A Transformer model for weather predictions. 15 | 16 | This model wraps the a transformer model and swaps the leading two axes of the 17 | nodes in the input graph prior to evaluating the model to make it compatible 18 | with a [nodes, batch, ...] ordering of the inputs. 19 | """ 20 | 21 | from typing import Any, Mapping, Optional 22 | 23 | from graphcast import typed_graph 24 | import haiku as hk 25 | import jax 26 | import jax.numpy as jnp 27 | import numpy as np 28 | from scipy import sparse 29 | 30 | 31 | Kwargs = Mapping[str, Any] 32 | 33 | 34 | def _get_adj_matrix_for_edge_set( 35 | graph: typed_graph.TypedGraph, 36 | edge_set_name: str, 37 | add_self_edges: bool, 38 | ): 39 | """Returns the adjacency matrix for the given graph and edge set.""" 40 | # Get nodes and edges of the graph. 41 | edge_set_key = graph.edge_key_by_name(edge_set_name) 42 | sender_node_set, receiver_node_set = edge_set_key.node_sets 43 | 44 | # Compute number of sender and receiver nodes. 45 | sender_n_node = graph.nodes[sender_node_set].n_node[0] 46 | receiver_n_node = graph.nodes[receiver_node_set].n_node[0] 47 | 48 | # Build adjacency matrix. 49 | adj_mat = sparse.csr_matrix((sender_n_node, receiver_n_node), dtype=np.bool_) 50 | edge_set = graph.edges[edge_set_key] 51 | s, r = edge_set.indices 52 | adj_mat[s, r] = True 53 | if add_self_edges: 54 | # Should only do this if we are certain the adjacency matrix is square. 55 | assert sender_node_set == receiver_node_set 56 | adj_mat[np.arange(sender_n_node), np.arange(receiver_n_node)] = True 57 | return adj_mat 58 | 59 | 60 | class MeshTransformer(hk.Module): 61 | """A Transformer for inputs with ordering [nodes, batch, ...].""" 62 | 63 | def __init__(self, 64 | transformer_ctor, 65 | transformer_kwargs: Kwargs, 66 | name: Optional[str] = None): 67 | """Initialises the Transformer model. 68 | 69 | Args: 70 | transformer_ctor: Constructor for transformer. 71 | transformer_kwargs: Kwargs to pass to the transformer module. 72 | name: Optional name for haiku module. 73 | """ 74 | super().__init__(name=name) 75 | # We defer the transformer initialisation to the first call to __call__, 76 | # where we can build the mask senders and receivers of the TypedGraph 77 | self._batch_first_transformer = None 78 | self._transformer_ctor = transformer_ctor 79 | self._transformer_kwargs = transformer_kwargs 80 | 81 | @hk.name_like('__init__') 82 | def _maybe_init_batch_first_transformer(self, x: typed_graph.TypedGraph): 83 | if self._batch_first_transformer is not None: 84 | return 85 | self._batch_first_transformer = self._transformer_ctor( 86 | adj_mat=_get_adj_matrix_for_edge_set( 87 | graph=x, 88 | edge_set_name='mesh', 89 | add_self_edges=True, 90 | ), 91 | **self._transformer_kwargs, 92 | ) 93 | 94 | def __call__( 95 | self, x: typed_graph.TypedGraph, 96 | global_norm_conditioning: jax.Array 97 | ) -> typed_graph.TypedGraph: 98 | """Applies the model to the input graph and returns graph of same shape.""" 99 | 100 | if set(x.nodes.keys()) != {'mesh_nodes'}: 101 | raise ValueError( 102 | f'Expected x.nodes to have key `mesh_nodes`, got {x.nodes.keys()}.' 103 | ) 104 | features = x.nodes['mesh_nodes'].features 105 | if features.ndim != 3: # pytype: disable=attribute-error # jax-ndarray 106 | raise ValueError( 107 | 'Expected `x.nodes["mesh_nodes"].features` to be 3, got' 108 | f' {features.ndim}.' 
109 | ) # pytype: disable=attribute-error # jax-ndarray 110 | 111 | # Initialise transformer and mask. 112 | self._maybe_init_batch_first_transformer(x) 113 | 114 | y = jnp.transpose(features, axes=[1, 0, 2]) 115 | y = self._batch_first_transformer(y, global_norm_conditioning) 116 | y = jnp.transpose(y, axes=[1, 0, 2]) 117 | x = x._replace( 118 | nodes={ 119 | 'mesh_nodes': x.nodes['mesh_nodes']._replace( 120 | features=y.astype(features.dtype) 121 | ) 122 | } 123 | ) 124 | return x 125 | -------------------------------------------------------------------------------- /graphcast/typed_graph.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Data-structure for storing graphs with typed edges and nodes.""" 15 | 16 | from typing import NamedTuple, Any, Union, Tuple, Mapping, TypeVar 17 | 18 | ArrayLike = Union[Any] # np.ndarray, jnp.ndarray, tf.tensor 19 | ArrayLikeTree = Union[Any, ArrayLike] # Nest of ArrayLike 20 | 21 | _T = TypeVar('_T') 22 | 23 | 24 | # All tensors have a "flat_batch_axis", which is similar to the leading 25 | # axes of graph_tuples: 26 | # * In the case of nodes this is simply a shared node and flat batch axis, with 27 | # size corresponding to the total number of nodes in the flattened batch. 28 | # * In the case of edges this is simply a shared edge and flat batch axis, with 29 | # size corresponding to the total number of edges in the flattened batch. 30 | # * In the case of globals this is simply the number of graphs in the flattened 31 | # batch. 32 | 33 | # All shapes may also have any additional leading shape "batch_shape". 34 | # Options for building batches are: 35 | # * Use a provided "flatten" method that takes a leading `batch_shape` and 36 | # it into the flat_batch_axis (this will be useful when using `tf.Dataset` 37 | # which supports batching into RaggedTensors, with leading batch shape even 38 | # if graphs have different numbers of nodes and edges), so the RaggedBatches 39 | # can then be converted into something without ragged dimensions that jax can 40 | # use. 41 | # * Directly build a "flat batch" using a provided function for batching a list 42 | # of graphs (how it is done in `jraph`). 43 | 44 | 45 | class NodeSet(NamedTuple): 46 | """Represents a set of nodes.""" 47 | n_node: ArrayLike # [num_flat_graphs] 48 | features: ArrayLikeTree # Prev. `nodes`: [num_flat_nodes] + feature_shape 49 | 50 | 51 | class EdgesIndices(NamedTuple): 52 | """Represents indices to nodes adjacent to the edges.""" 53 | senders: ArrayLike # [num_flat_edges] 54 | receivers: ArrayLike # [num_flat_edges] 55 | 56 | 57 | class EdgeSet(NamedTuple): 58 | """Represents a set of edges.""" 59 | n_edge: ArrayLike # [num_flat_graphs] 60 | indices: EdgesIndices 61 | features: ArrayLikeTree # Prev. 
`edges`: [num_flat_edges] + feature_shape 62 | 63 | 64 | class Context(NamedTuple): 65 | # `n_graph` always contains ones but it is useful to query the leading shape 66 | # in case of graphs without any nodes or edges sets. 67 | n_graph: ArrayLike # [num_flat_graphs] 68 | features: ArrayLikeTree # Prev. `globals`: [num_flat_graphs] + feature_shape 69 | 70 | 71 | class EdgeSetKey(NamedTuple): 72 | name: str # Name of the EdgeSet. 73 | 74 | # Sender node set name and receiver node set name connected by the edge set. 75 | node_sets: Tuple[str, str] 76 | 77 | 78 | class TypedGraph(NamedTuple): 79 | """A graph with typed nodes and edges. 80 | 81 | A typed graph is made of a context, multiple sets of nodes and multiple 82 | sets of edges connecting those nodes (as indicated by the EdgeSetKey). 83 | """ 84 | 85 | context: Context 86 | nodes: Mapping[str, NodeSet] 87 | edges: Mapping[EdgeSetKey, EdgeSet] 88 | 89 | def edge_key_by_name(self, name: str) -> EdgeSetKey: 90 | found_key = [k for k in self.edges.keys() if k.name == name] 91 | if len(found_key) != 1: 92 | raise KeyError("invalid edge key '{}'. Available edges: [{}]".format( 93 | name, ', '.join(x.name for x in self.edges.keys()))) 94 | return found_key[0] 95 | 96 | def edge_by_name(self, name: str) -> EdgeSet: 97 | return self.edges[self.edge_key_by_name(name)] 98 | -------------------------------------------------------------------------------- /graphcast/xarray_tree.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utilities for working with trees of xarray.DataArray (including Datasets). 15 | 16 | Note that xarray.Dataset doesn't work out-of-the-box with the `tree` library; 17 | it won't work as a leaf node since it implements Mapping, but also won't work 18 | as an internal node since tree doesn't know how to re-create it properly. 19 | 20 | To fix this, we reimplement a subset of `map_structure`, exposing its 21 | constituent DataArrays as leaf nodes. This means it can be mapped over as a 22 | generic container of DataArrays, while still preserving the result as a Dataset 23 | where possible. 24 | 25 | This is useful because in a few places we need to handle a general 26 | Mapping[str, DataArray] (where the coordinates might not be compatible across 27 | the constituent DataArrays) but also the special case of a Dataset nicely. 28 | 29 | For the result e.g. of a tree.map_structure(fn, dataset), if fn returns None for 30 | some of the child DataArrays, they will be omitted from the returned dataset. If 31 | any values other than DataArrays or None are returned, then we don't attempt to 32 | return a Dataset and just return a plain dict of the results. Similarly if 33 | DataArrays are returned but with non-matching coordinates, it will just return a 34 | plain dict of DataArrays. 
35 | 36 | Note xarray datatypes are registered with `jax.tree_util` by xarray_jax.py, 37 | but `jax.tree_util.tree_map` is distinct from the `xarray_tree.map_structure`. 38 | as the former exposes the underlying JAX/numpy arrays as leaf nodes, while the 39 | latter exposes DataArrays as leaf nodes. 40 | """ 41 | 42 | from typing import Any, Callable 43 | 44 | import xarray 45 | 46 | 47 | def map_structure(func: Callable[..., Any], *structures: Any) -> Any: 48 | """Maps func through given structures with xarrays. See tree.map_structure.""" 49 | if not callable(func): 50 | raise TypeError(f'func must be callable, got: {func}') 51 | if not structures: 52 | raise ValueError('Must provide at least one structure') 53 | 54 | first = structures[0] 55 | if isinstance(first, xarray.Dataset): 56 | data = {k: func(*[s[k] for s in structures]) for k in first.keys()} 57 | if all(isinstance(a, (type(None), xarray.DataArray)) 58 | for a in data.values()): 59 | data_arrays = [v.rename(k) for k, v in data.items() if v is not None] 60 | try: 61 | return xarray.merge(data_arrays, join='exact') 62 | except ValueError: # Exact join not possible. 63 | pass 64 | return data 65 | if isinstance(first, dict): 66 | return {k: map_structure(func, *[s[k] for s in structures]) 67 | for k in first.keys()} 68 | if isinstance(first, (list, tuple, set)): 69 | return type(first)(map_structure(func, *s) for s in zip(*structures)) 70 | return func(*structures) 71 | -------------------------------------------------------------------------------- /graphcast/xarray_tree_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
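# --- Editorial aside: the snippet below is NOT part of xarray_tree_test.py or any
# repository file. It is a minimal usage sketch of xarray_tree.map_structure from
# graphcast/xarray_tree.py above; the toy variable names and values are invented.
import numpy as np
import xarray

from graphcast import xarray_tree

_demo = xarray.Dataset(
    data_vars={
        "temperature": (("lat",), np.array([280.0, 281.5])),
        "pressure": (("lat",), np.array([1.01e5, 1.00e5])),
    },
    coords={"lat": [0.0, 1.0]},
)


def _to_celsius_only(array: xarray.DataArray) -> xarray.DataArray:
  # Every return value is a DataArray with unchanged coordinates, so
  # map_structure can merge the results back into an xarray.Dataset.
  return array - 273.15 if array.name == "temperature" else array


_converted = xarray_tree.map_structure(_to_celsius_only, _demo)
assert isinstance(_converted, xarray.Dataset)
# --- End of editorial aside; xarray_tree_test.py continues below.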
14 | """Tests for xarray_tree.""" 15 | 16 | from absl.testing import absltest 17 | from graphcast import xarray_tree 18 | import numpy as np 19 | import xarray 20 | 21 | 22 | TEST_DATASET = xarray.Dataset( 23 | data_vars={ 24 | "foo": (("x", "y"), np.zeros((2, 3))), 25 | "bar": (("x",), np.zeros((2,))), 26 | }, 27 | coords={ 28 | "x": [1, 2], 29 | "y": [10, 20, 30], 30 | } 31 | ) 32 | 33 | 34 | class XarrayTreeTest(absltest.TestCase): 35 | 36 | def test_map_structure_maps_over_leaves_but_preserves_dataset_type(self): 37 | def fn(leaf): 38 | self.assertIsInstance(leaf, xarray.DataArray) 39 | result = leaf + 1 40 | # Removing the name from the returned DataArray to test that we don't rely 41 | # on it being present to restore the correct names in the result: 42 | result = result.rename(None) 43 | return result 44 | 45 | result = xarray_tree.map_structure(fn, TEST_DATASET) 46 | self.assertIsInstance(result, xarray.Dataset) 47 | self.assertSameElements({"foo", "bar"}, result.keys()) 48 | 49 | def test_map_structure_on_data_arrays(self): 50 | data_arrays = dict(TEST_DATASET) 51 | result = xarray_tree.map_structure(lambda x: x+1, data_arrays) 52 | self.assertIsInstance(result, dict) 53 | self.assertSameElements({"foo", "bar"}, result.keys()) 54 | 55 | def test_map_structure_on_dataset_plain_dict_when_coords_incompatible(self): 56 | def fn(leaf): 57 | # Returns DataArrays that can't be exactly merged back into a Dataset 58 | # due to the coordinates not matching: 59 | if leaf.name == "foo": 60 | return xarray.DataArray( 61 | data=np.zeros(2), dims=("x",), coords={"x": [1, 2]}) 62 | else: 63 | return xarray.DataArray( 64 | data=np.zeros(2), dims=("x",), coords={"x": [3, 4]}) 65 | 66 | result = xarray_tree.map_structure(fn, TEST_DATASET) 67 | self.assertIsInstance(result, dict) 68 | self.assertSameElements({"foo", "bar"}, result.keys()) 69 | 70 | def test_map_structure_on_dataset_drops_vars_with_none_return_values(self): 71 | def fn(leaf): 72 | return leaf if leaf.name == "foo" else None 73 | 74 | result = xarray_tree.map_structure(fn, TEST_DATASET) 75 | self.assertIsInstance(result, xarray.Dataset) 76 | self.assertSameElements({"foo"}, result.keys()) 77 | 78 | def test_map_structure_on_dataset_returns_plain_dict_other_return_types(self): 79 | def fn(leaf): 80 | self.assertIsInstance(leaf, xarray.DataArray) 81 | return "not a DataArray" 82 | 83 | result = xarray_tree.map_structure(fn, TEST_DATASET) 84 | self.assertEqual({"foo": "not a DataArray", 85 | "bar": "not a DataArray"}, result) 86 | 87 | def test_map_structure_two_args_different_variable_orders(self): 88 | dataset_different_order = TEST_DATASET[["bar", "foo"]] 89 | def fn(arg1, arg2): 90 | self.assertEqual(arg1.name, arg2.name) 91 | xarray_tree.map_structure(fn, TEST_DATASET, dataset_different_order) 92 | 93 | 94 | if __name__ == "__main__": 95 | absltest.main() 96 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Module setuptools script.""" 15 | 16 | from setuptools import setup 17 | 18 | description = ( 19 | "GraphCast: Learning skillful medium-range global weather forecasting" 20 | ) 21 | 22 | setup( 23 | name="graphcast", 24 | version="0.2.0.dev", 25 | description=description, 26 | long_description=description, 27 | author="DeepMind", 28 | license="Apache License, Version 2.0", 29 | keywords="GraphCast Weather Prediction", 30 | url="https://github.com/deepmind/graphcast", 31 | packages=["graphcast"], 32 | install_requires=[ 33 | "cartopy", 34 | "chex", 35 | "colabtools", 36 | "dask", 37 | "dinosaur-dycore", 38 | "dm-haiku", 39 | "dm-tree", 40 | "jax", 41 | "jraph", 42 | "matplotlib", 43 | "numpy", 44 | "pandas", 45 | "rtree", 46 | "scipy", 47 | "trimesh", 48 | "typing_extensions", 49 | "xarray", 50 | "xarray_tensorstore" 51 | ], 52 | classifiers=[ 53 | "Development Status :: 3 - Alpha", 54 | "Intended Audience :: Science/Research", 55 | "License :: OSI Approved :: Apache Software License", 56 | "Operating System :: POSIX :: Linux", 57 | "Programming Language :: Python :: 3.10", 58 | "Programming Language :: Python :: 3.11", 59 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 60 | "Topic :: Scientific/Engineering :: Atmospheric Science", 61 | "Topic :: Scientific/Engineering :: Physics", 62 | ], 63 | ) 64 | --------------------------------------------------------------------------------
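The `install_requires` list in the setup.py above pins no versions, so the exact environment is resolved at install time. As a quick, illustrative sanity check (the editable-install command and the chosen modules are assumptions based on the repository layout, not commands documented by the project), one might verify the package imports after installation:

# pip install -e .   # assumed workflow: run from the repository root
from graphcast import checkpoint, data_utils, graphcast, rollout
print(graphcast.__name__)  # expected: "graphcast.graphcast"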