├── .env ├── .gitignore ├── README.md ├── ai-models-modal ├── __init__.py ├── ai_models_shim.py ├── app.py ├── config.py ├── gcs.py ├── gfs.py └── main.py └── requirements.txt /.env: -------------------------------------------------------------------------------- 1 | # Over-ride these secrets! 2 | 3 | # Copernicus Data Store API url / keys 4 | CDSAPI_URL=https://cds.climate.copernicus.eu/api/v2 5 | CDSAPI_KEY=YOUR_KEY_HERE 6 | 7 | # Google Cloud Storage credentials and output bucket 8 | GCS_SERVICE_ACCOUNT_INFO=YOUR_SERVICE_ACCOUNT_INFO 9 | # Set the bucket on GCS that you own where you want to upload finished model 10 | # forecasts. We recommend that you create a new bucket for this purpose, with 11 | # a name like, "-ai-models-for-all/" 12 | GCS_BUCKET_NAME=YOUR_BUCKET_NAME 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,macos 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode,macos 3 | 4 | ### macOS ### 5 | # General 6 | .DS_Store 7 | .AppleDouble 8 | .LSOverride 9 | 10 | # Icon must end with two \r 11 | Icon 12 | 13 | 14 | # Thumbnails 15 | ._* 16 | 17 | # Files that might appear in the root of a volume 18 | .DocumentRevisions-V100 19 | .fseventsd 20 | .Spotlight-V100 21 | .TemporaryItems 22 | .Trashes 23 | .VolumeIcon.icns 24 | .com.apple.timemachine.donotpresent 25 | 26 | # Directories potentially created on remote AFP share 27 | .AppleDB 28 | .AppleDesktop 29 | Network Trash Folder 30 | Temporary Items 31 | .apdisk 32 | 33 | ### macOS Patch ### 34 | # iCloud generated files 35 | *.icloud 36 | 37 | ### Python ### 38 | # Byte-compiled / optimized / DLL files 39 | __pycache__/ 40 | *.py[cod] 41 | *$py.class 42 | 43 | # C extensions 44 | *.so 45 | 46 | # Distribution / packaging 47 | .Python 48 | build/ 49 | develop-eggs/ 50 | dist/ 51 | downloads/ 52 | eggs/ 53 | .eggs/ 54 | lib/ 55 | lib64/ 56 | parts/ 57 | sdist/ 58 | var/ 59 | wheels/ 60 | share/python-wheels/ 61 | *.egg-info/ 62 | .installed.cfg 63 | *.egg 64 | MANIFEST 65 | 66 | # PyInstaller 67 | # Usually these files are written by a python script from a template 68 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
69 | *.manifest 70 | *.spec 71 | 72 | # Installer logs 73 | pip-log.txt 74 | pip-delete-this-directory.txt 75 | 76 | # Unit test / coverage reports 77 | htmlcov/ 78 | .tox/ 79 | .nox/ 80 | .coverage 81 | .coverage.* 82 | .cache 83 | nosetests.xml 84 | coverage.xml 85 | *.cover 86 | *.py,cover 87 | .hypothesis/ 88 | .pytest_cache/ 89 | cover/ 90 | 91 | # Translations 92 | *.mo 93 | *.pot 94 | 95 | # Django stuff: 96 | *.log 97 | local_settings.py 98 | db.sqlite3 99 | db.sqlite3-journal 100 | 101 | # Flask stuff: 102 | instance/ 103 | .webassets-cache 104 | 105 | # Scrapy stuff: 106 | .scrapy 107 | 108 | # Sphinx documentation 109 | docs/_build/ 110 | 111 | # PyBuilder 112 | .pybuilder/ 113 | target/ 114 | 115 | # Jupyter Notebook 116 | .ipynb_checkpoints 117 | 118 | # IPython 119 | profile_default/ 120 | ipython_config.py 121 | 122 | # pyenv 123 | # For a library or package, you might want to ignore these files since the code is 124 | # intended to run in multiple environments; otherwise, check them in: 125 | # .python-version 126 | 127 | # pipenv 128 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 129 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 130 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 131 | # install all needed dependencies. 132 | #Pipfile.lock 133 | 134 | # poetry 135 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 136 | # This is especially recommended for binary packages to ensure reproducibility, and is more 137 | # commonly ignored for libraries. 138 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 139 | #poetry.lock 140 | 141 | # pdm 142 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 143 | #pdm.lock 144 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 145 | # in version control. 146 | # https://pdm.fming.dev/#use-with-ide 147 | .pdm.toml 148 | 149 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 150 | __pypackages__/ 151 | 152 | # Celery stuff 153 | celerybeat-schedule 154 | celerybeat.pid 155 | 156 | # SageMath parsed files 157 | *.sage.py 158 | 159 | # Environments 160 | .env 161 | .venv 162 | env/ 163 | venv/ 164 | ENV/ 165 | env.bak/ 166 | venv.bak/ 167 | 168 | # Spyder project settings 169 | .spyderproject 170 | .spyproject 171 | 172 | # Rope project settings 173 | .ropeproject 174 | 175 | # mkdocs documentation 176 | /site 177 | 178 | # mypy 179 | .mypy_cache/ 180 | .dmypy.json 181 | dmypy.json 182 | 183 | # Pyre type checker 184 | .pyre/ 185 | 186 | # pytype static type analyzer 187 | .pytype/ 188 | 189 | # Cython debug symbols 190 | cython_debug/ 191 | 192 | # PyCharm 193 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 194 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 195 | # and can be added to the global gitignore or merged into this file. For a more nuclear 196 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
197 | #.idea/ 198 | 199 | ### Python Patch ### 200 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 201 | poetry.toml 202 | 203 | # ruff 204 | .ruff_cache/ 205 | 206 | # LSP config files 207 | pyrightconfig.json 208 | 209 | ### VisualStudioCode ### 210 | .vscode/* 211 | !.vscode/settings.json 212 | !.vscode/tasks.json 213 | !.vscode/launch.json 214 | !.vscode/extensions.json 215 | !.vscode/*.code-snippets 216 | 217 | # Local History for Visual Studio Code 218 | .history/ 219 | 220 | # Built Visual Studio Code Extensions 221 | *.vsix 222 | 223 | ### VisualStudioCode Patch ### 224 | # Ignore all local history of files 225 | .history 226 | .ionide 227 | 228 | # End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,macos -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `ai-models` For All 2 | 3 | This package bootstraps on top of the fantastic [`ai-models`](https://github.com/ecmwf-lab/ai-models) library to build a serverless application to generate "pure AI 4 | NWP" weather forecasts on [Modal](https://www.modal.com). Users can run their own 5 | historical re-forecasts using either [PanguWeather](https://www.nature.com/articles/s41586-023-06185-3), 6 | [FourCastNet](https://arxiv.org/abs/2202.11214), or [GraphCast](https://www.science.org/doi/10.1126/science.adi2336), 7 | and save the outputs to their own cloud storage provider for further use. 8 | 9 | The initial release of this application is fully-featured, with some limitations: 10 | 11 | - We only provide one storage adapter, for Google Cloud Storage. This can be generalized 12 | to support S3, Azure, or any other provider in the future. 13 | - By default, users may initialize a forecast from the CDS-based ERA-5 archive; we also 14 | provide the option to initialize from a GFS forecast, retrieved from NOAA's archive of these 15 | products on Google Cloud Storage. We do not provide a mechanism to initialize with IFS 16 | operational forecasts from MARS. 17 | - The current application only runs on [Modal](https://www.modal.com); in the future, it 18 | would be great to port this to other serverless platforms, re-using as much of the 19 | core implementation as possible. 20 | 21 | This application relies on the fantastic [`ecmwf-labs/ai-models`](https://github.com/ecmwf-lab/ai-models) 22 | package to automate a lot of the traditional MLOps that are necessary to run this type 23 | of product in a semi-production context. `ai-models` handles acquiring data to use as 24 | inputs for model inference (read: generate a forecast from initial conditions) by 25 | providing an as-needed interface with the Copernicus Data Store and MARS API; it 26 | provides pre-trained model weights; it implements a simple ONNX-based interface 27 | for performing model inference; and it outputs a well-formed (albeit in GRIB) output 28 | file that can be fed into downstream workflows (e.g. model visualization). We don't 29 | anticipate replacing this package, but we may contribute improvements and features 30 | upstream (e.g. a high priority is writing a NetCDF output adapter that writes timesliced 31 | files per model step, with metadata following CF conventions) as they mature here.
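To make that last point concrete, here is a minimal sketch of what such a per-step NetCDF conversion could look like. It is *not* part of this package; it assumes `cfgrib`, `xarray`, and `numpy` are available locally, and the forecast filename is hypothetical:

```python
import cfgrib  # needs an eccodes install; returns xarray Datasets
import numpy as np

# Hypothetical forecast file produced by this application.
GRIB_PATH = "panguweather.era5.202307010000.grib"

# The GRIB output mixes surface and pressure-level fields, so cfgrib yields a
# list of homogeneous datasets rather than one combined dataset.
for i, ds in enumerate(cfgrib.open_datasets(GRIB_PATH)):
    if "step" not in ds.dims:  # guard against field groups with a single step
        ds = ds.expand_dims("step")
    for step, ds_step in ds.groupby("step"):
        lead_hours = int(step / np.timedelta64(1, "h"))
        # cfgrib attaches CF-style attributes, which carry straight through to
        # NetCDF; we write one file per (field group, forecast step).
        ds_step.to_netcdf(f"forecast.f{lead_hours:03d}.part{i}.nc")
```

A production adapter would need to handle chunking and metadata harmonization more carefully; the sketch above is only meant to convey the intended shape of the outputs.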
32 | 33 | **Your feedback to or [@danrothenberg](https://twitter.com/danrothenberg) would be greatly appreciated!** 34 | 35 | ## Usage / Restrictions 36 | 37 | If you use this application, please give credit to [Daniel Rothenberg](https://github.com/darothen) 38 | ( or [@danrothenberg](https://twitter.com/danrothenberg)), 39 | as well as the incredible team at [ECMWF Lab](https://github.com/ecmwf-lab) and the 40 | publishers of any forecast model you use. 41 | 42 | **NOTE THAT EACH FORECAST MODEL PROVIDED BY AI-MODELS HAS ITS OWN LICENSE AND RESTRICTIONS**. 43 | 44 | This package may *only* be used in a manner compliant with the licenses and terms of all 45 | the libraries, model weights, and application platforms/services upon which it is built. 46 | The forecasts generated by the AI models and the software which power them are *experimental in nature* 47 | and may break or fail unexpectedly during normal use. 48 | 49 | ## Quick Start 50 | 51 | 1. Set up accounts (if you don't already have them) for: 52 | 1. [Google Cloud](https://cloud.google.com) 53 | 2. [Modal](https://www.modal.com) 54 | 3. [Copernicus Data Store](https://cds.climate.copernicus.eu/) 55 | 2. Complete the `.env` file with your CDS API credentials, GCS service account keys, and 56 | a bucket name where model outputs will be uploaded. **You should create this bucket 57 | before running the application!** 58 | 3. From a terminal, login with the `modal-client` 59 | 4. Navigate to the repository on-disk and execute the command, 60 | 61 | ```shell 62 | $ modal run ai-models-modal.main \ 63 | --model-name {panguweather,fourcastnetv2-small,graphcast} \ 64 | --model-init 2023-07-01T00:00:00 \ 65 | --lead-time 12 \ 66 | [--use-gfs] 67 | ``` 68 | 69 | The first time you run this, it will take a few minutes to build an image and set up 70 | assets on Modal. Then, the model will run remotely on Modal infrastructure, and you 71 | can monitor its progress via the logs streamed to your terminal. The bracketed CLI 72 | args are the defaults that will be used if you don't provide any. 73 | 5. Download the model output from Google Cloud Storage at **gs://{GCS_BUCKET_NAME}** as 74 | provided via the `.env` file. 75 | 6. Install required dependencies onto your machine using the requirements.txt file (pip install -r requirements.txt) 76 | 77 | ## Using GFS/GDAS Initial Conditions 78 | 79 | We've implemented the ability for users to fetch initial conditions from an 80 | archived GFS forecast cycle. In the current implementation, we make some assumptions 81 | about how to process and map the GFS data to the ERA-5 data that the `ai-models` 82 | package typically tries to fetch: 83 | 84 | 1. Some models require surface geopotential or orography fields as an input; we use the 85 | GFS/GDAS version of this data instead of copying over from ERA-5. Similarly, when 86 | needed we use the land-sea mask from GFS/GDAS instead of copying over ERA-5's. 87 | 2. GraphCast is initialized with accumulated precipitation data that is not readily 88 | available in the GFS/GDAS outputs; we currently approximate this very crudely by 89 | looking at the 6-hr lagged precipitation rate fields from subsequent GFS/GDAS 90 | analyses. 91 | 3. The AI models are not fine-tuned (yet) on GFS data, so underlying differences in the 92 | core distribution of atmospheric data between ERA-5 and GFS/GDAS could degrade 93 | forecast quality in unexpected ways. 
Additionally, we apply the ERA-5 derived 94 | Z-score or uniform distribution scaling from the parent AI models instead of providing 95 | new ones for GFS/GDAS data. 96 | 97 | We use the `gfs.tHHz.pgrb2.0p25.f000` output files to pull the initial conditions. These 98 | are available in near-real-time (unlike the final GDAS analyses, which are lagged by 99 | about one model cycle). We may provide the option to use the `.anl` analysis files 100 | or hot-start initial conditions, too, based on feedback from the community/users. Switching 101 | to these files simply requires building a new set of mappers from the corresponding 102 | ERA-5 fields. 103 | 104 | ### Running a Forecast from GFS/GDAS 105 | 106 | To tell `ai-models-for-all` to use GFS initial conditions, simply pass the command line 107 | flag "`--use-gfs`" and initialize a model run as usual. 108 | 109 | ```shell 110 | $ modal run ai-models-modal.main \ 111 | --model-name {panguweather,fourcastnetv2-small,graphcast} \ 112 | --model-init 2023-07-01T00:00:00 \ 113 | --lead-time 12 \ 114 | --use-gfs \ 115 | ``` 116 | 117 | The package will automatically download and process the GFS data to use for you, as well 118 | as archive it for future reference. Please note that the model setup process (where 119 | assets are downloaded and cached for future runs) may take much longer than usual as we 120 | also take the liberty of generating independent copies of the ERA-5 template files used 121 | to process the GFS data. Given the current quota restrictions on the CDS-API, this may 122 | take a very long time (luckily, the stub functions which perform this process are super 123 | cheap to run and will cost pennies even if they get stuck for several hours). 124 | 125 | For your convenience, we've saved pre-computed data templates for you to use; for a 126 | typical `.env` setup described below, you can locally run the following Google Cloud 127 | SDK command to copy over the input templates so that `ai-models-for-all` will 128 | automatically discover them: 129 | 130 | ```shell 131 | $ gcloud storage cp \ 132 | gs://ai-models-for-all-input-templates/era5-to-gfs-f000 \ 133 | gs://${GCS_BUCKET_NAME} 134 | ``` 135 | 136 | ## More Detailed Setup Instructions 137 | 138 | To use this demo, you'll need accounts set up on [Google Cloud](https://cloud.google.com), 139 | [Modal](https://www.modal.com), and the [Copernicus Data Store](https://cds.climate.copernicus.eu/). 140 | Don't worry - even though you do need to supply them with credit card information, this 141 | demo should cost virtually nothing to run; we'll use very limited storage on Google 142 | Cloud Storage for forecast model outputs that we generate (a few cents per month if you 143 | become a prolific user), and Modal offers new individual users a [startup package](https://modal.com/signup) 144 | which includes $30/month of free compute - so you could run this application for about 8 145 | hours straight before you'd incur any fees (A100s cost about $3.73/hr on Modal at the time 146 | of writing). 147 | 148 | If you're very new to cloud computing, the following sections will help walk you through 149 | the handful of steps necessary to get started with this demo. Experienced users can 150 | quickly skim through to see how they need to modify the provided `.env` to set up the 151 | necessary credentials for the application to work. 152 | 153 | ### Setting up Google Cloud 154 | 155 | The current version of this library ships with a single storage handler - a tool 156 | to upload to Google Cloud Storage.
Without an external storage mechanism, there 157 | isn't a simple way to access the forecasts you generate using this tool. 158 | Thankfully, a basic cloud storage setup is easy to create, extremely cheap, and 159 | will serve you many purposes! 160 | 161 | There are two main steps that you need to take here. We assume you already have 162 | an account on Google Cloud Platform (it should be trivial to set up from 163 | http://console.cloud.google.com). 164 | 165 | #### 1) Create a bucket on Google Cloud Storage 166 | 167 | Navigate to your project's [Cloud Storage](https://console.cloud.google.com/storage/browser) 168 | control panel. From here, you should see a listing of buckets that you own. 169 | 170 | Find the button labeled **Create** near the top of the page. On the form that 171 | loads, you should only need to provide a name for your bucket. Your name needs 172 | to be globally unique - bucket names share a single namespace across all of Google Cloud, 173 | so no two buckets anywhere can have the same name. We recommend the simple naming scheme `-ai-models-for-all`. 174 | 175 | Keep all the default settings after inputting your bucket name and submit the 176 | form with the blue **Create** button at the bottom of the page. 177 | 178 | Finally, navigate to the `.env` file in this repo; set the **GCS_BUCKET_NAME** 179 | variable to the bucket name you chose previously. You do not need quotes around 180 | the bucket name. 181 | 182 | #### 2) Create a Service Account 183 | 184 | We need a mechanism so that your app running on Modal can authenticate with 185 | Google Cloud in order to use its resources and APIs. To do this, we're going to 186 | create a *service account* and set the limited permissions needed to run this 187 | application. 188 | 189 | From your [Google Cloud Console](http://console.cloud.google.com), navigate to 190 | the **IAM & Admin** panel and then select **Service Accounts** from the menu. On 191 | the resulting page, click the button near the top that says 192 | **Create Service Account**. 193 | 194 | On the form that pops up, use the following information: 195 | 196 | - *Service account name*: modal-ai-models-for-all 197 | - *Service account description*: Access from Modal for ai-models-for-all application 198 | 199 | The *Service account ID* field should automatically fill out for you. Click 200 | **Create and Continue** and you should move to Step 2, "Grant this service 201 | account access to project (optional)". Here, we will give permissions for the 202 | service account to access Cloud Storage resources (so that it can be used to 203 | upload and list objects in the bucket you previously created). 204 | 205 | From the drop-down labeled "Select a role", search for "Storage Object Admin" 206 | (you may want to use the filter). Add this role, then click **Continue**. You 207 | shouldn't need to grant any specific user permissions, assuming you're the owner 208 | of the Google Cloud account in which you're setting this up. Click **Done**. 209 | 210 | Finally, we need to access the credentials for this new account. Navigate back 211 | to the page **IAM & Admin** > **Service Accounts**, and click the name in the 212 | table with the "modal-ai-models-for-all" service account you just created. At 213 | the top of the page, navigate to the **Keys** tab and click the **ADD KEY** 214 | button in the middle of the page. This will generate and download a new private 215 | key that you'll use. Select the "JSON" option from the pop-up and download the 216 | file by clicking **Create**.
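Before moving on: you will eventually need to collapse this key file into a single-line string so that it can live in your `.env` file (the next few paragraphs explain where it goes). A minimal sketch in Python, where `service-account-key.json` is a placeholder for whatever filename your browser gave the download:

```python
import json

# Load the service account key you just downloaded...
with open("service-account-key.json") as f:
    service_account_info = json.load(f)

# ...and re-serialize it as a single-line string. Paste the printed value into
# your .env file as GCS_SERVICE_ACCOUNT_INFO=<printed string>.
print(json.dumps(service_account_info))
```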
217 | 218 | The credentials you created will be downloaded to disk automatically. Open that 219 | JSON file in your favorite text editor; you'll see a mapping of many different 220 | keys to values, including "private_key". We need to pass this entire JSON object 221 | whenever we want to authenticate with Google Cloud Storage. To do this, we'll 222 | save it in the same `.env` file under the name **GCS_SERVICE_ACCOUNT_INFO**. 223 | 224 | Unfortunately we can't just copy/paste - we need to "stringify" the data. You 225 | should probably do this in Python or your preferred programming language by 226 | reading in the JSON file you saved, serializing it to a string, and printing the result - see the Python snippet above. 227 | 228 | ### Configuring `cdsapi` 229 | 230 | We need access to the [Copernicus Data Store](https://cds.climate.copernicus.eu/) 231 | to retrieve historical ERA-5 data to use when initializing our forecasts. The 232 | easiest way to set this up would be to have the user retrieve their credentials 233 | from [here](https://cds.climate.copernicus.eu/api-how-to) and save them to a 234 | local file, `~/.cdsapirc`. But that's a tad inconvenient to build into our 235 | application image. Instead, we can just set the environment variables 236 | **CDSAPI_URL** and **CDSAPI_KEY**. Note that we still create a stub RC file during 237 | image generation, but this is a shortcut so that users only need to modify a single 238 | file with their credentials. 239 | 240 | ## Other Notes 241 | 242 | - The code here has liberal comments detailing development notes, caveats, gotchas, 243 | and opportunities. 244 | - You may see some diagnostic outputs indicating that libraries including libexpat, 245 | libglib, and libtpu are missing. These should not impact the functionality of the 246 | current application (we've tested that all three AI models do in fact run 247 | and produce expected outputs). 248 | - You still need to install some required libraries locally. These are provided for 249 | you in the `requirements.txt` file. Use `pip install -r requirements.txt` to install them. 250 | - It should be *very* cheap to run this application; even accounting for the time it 251 | takes to download model assets the first time a given AI model is run, most of the 252 | models can produce a 10-day forecast in about 10-15 minutes. So end-to-end, for a 253 | long forecast, the GPU container should really only be running for < 20 minutes, 254 | which means that at today's (11-25-2023) market rates of $3.73/hr per A100 GPU, it 255 | should cost a bit more than a dollar to generate a forecast, all-in. 256 | 257 | ## Roadmap 258 | 259 | The following major projects are slated for Q1'24: 260 | 261 | **Operational AI Model Forecasts** - We will begin running pseudo-operational 262 | forecasts for all included models in early Q1 using GFS initial conditions in 263 | near-real-time, and disseminating the outputs in a publicly available Google 264 | Cloud Storage bucket (per model licensing restrictions). 265 | 266 | **Post-processing / Visualization** - We will implement some simple (optional) 267 | routines to post-process the default GRIB outputs into more standard ARCO formats, 268 | and generate a complete set of visualizations that users can review as a stand-alone 269 | gallery. Pending collaborations, we will try to make these available on popular 270 | model visualization websites (contact @darothen if you're interested in hosting).
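Until those routines land, here is a rough, illustrative sketch of the kind of quick-look visualization that the GRIB outputs already support. It is not part of the current package; it assumes `pygrib` and `matplotlib` are installed locally and uses a hypothetical forecast file downloaded from your output bucket:

```python
import matplotlib.pyplot as plt
import pygrib

# Hypothetical forecast file copied down from gs://<GCS_BUCKET_NAME>.
GRIB_PATH = "panguweather.era5.202307010000.grib"

with pygrib.open(GRIB_PATH) as grbs:
    # Grab the first 2 m temperature message and plot it on its lat/lon grid.
    grb = grbs.select(shortName="2t")[0]
    lats, lons = grb.latlons()
    plt.pcolormesh(lons, lats, grb.values)
    plt.colorbar(label=f"{grb.name} [{grb.units}]")
    plt.title(f"{grb.name}, valid {grb.validDate}")
    plt.savefig("t2m_first_step.png", dpi=150)
```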
271 | 272 | **Porting to `earth2mip`** - Although we've used ecmwf-labs/ai-models for the initial 273 | development, this package's extremely tight coupling with ECMWF infrastructure and 274 | the climetlab library poses considerable development challenges. Therefore, we aim 275 | to re-write this library using the NVIDIA/earth2mip framework. This is a far more 276 | comprehensive and extensible framework for preparing a variety of modeling and learning 277 | tasks related to AI-NWP, and provides access to the very large library of AI models 278 | NVIDIA is collecting for their model zoo. This will likely be built in a stand-alone 279 | package, e.g. `earth2mip-for-all`, but the goal is to provide the same accessibility 280 | and ease-of-use for users who simply want to run these models to create their own 281 | forecasts with limited engineering/infrastructure investment. 282 | -------------------------------------------------------------------------------- /ai-models-modal/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/darothen/ai-models-for-all/bec4a5bcacf100625d823128806f2911c524e988/ai-models-modal/__init__.py -------------------------------------------------------------------------------- /ai-models-modal/ai_models_shim.py: -------------------------------------------------------------------------------- 1 | """Shim for interfacing with ai-models package and related plugins.""" 2 | 3 | import dataclasses 4 | from importlib.metadata import EntryPoint 5 | from typing import Type 6 | 7 | import ai_models 8 | from ai_models import model # noqa: F401 - needed for type annotations 9 | 10 | AIModelType = Type[ai_models.model.Model] 11 | 12 | 13 | @dataclasses.dataclass(frozen=True) 14 | class AIModelPluginConfig: 15 | """Configuration information for ai-models plugins. 16 | 17 | Although the ai-models package provides a simple interface (ai_models.model.Model) 18 | for preparing and running models via plugins, the exact mechanics can differ 19 | a bit (e.g. the FourCastNet plugin inserts an auxiliary method to bypass directly 20 | exposing the Model sub-class), so we directly map each model to its plugin 21 | package name and to the entry point for its implementation class. 22 | 23 | Attributes: 24 | model_name: the name of an AI NWP model as it is exposed through the ai-models 25 | plugin architecture (e.g. "fourcastnetv2-small" for FourCastNet). 26 | plugin_package_name: the name of an AI NWP model as it is encoded in the name 27 | of the plugin which provides it (e.g. "fourcastnetv2" for FourCastNet). 28 | entry_point: an EntryPoint which maps directly to the class implementing the 29 | ai-models.model.Model interface for the given model.
30 | """ 31 | 32 | model_name: str 33 | plugin_package_name: str 34 | entry_point: EntryPoint 35 | 36 | 37 | AI_MODELS_CONFIGS: dict[str, AIModelPluginConfig] = { 38 | # PanguWeather 39 | "panguweather": AIModelPluginConfig( 40 | "panguweather", 41 | "panguweather", 42 | EntryPoint( 43 | name="panguweather", 44 | group="ai_models.model", 45 | value="ai_models_panguweather.model:PanguWeather", 46 | ), 47 | ), 48 | # FourCastNet 49 | "fourcastnetv2-small": AIModelPluginConfig( 50 | "fourcastnetv2-small", 51 | "fourcastnetv2", 52 | EntryPoint( 53 | name="fourcastnetv2", 54 | group="ai_models.model", 55 | value="ai_models_fourcastnetv2.model:FourCastNetv2", 56 | ), 57 | ), 58 | # GraphCast 59 | "graphcast": AIModelPluginConfig( 60 | "graphcast", 61 | "graphcast", 62 | EntryPoint( 63 | name="graphcast", 64 | group="ai_models.model", 65 | value="ai_models_graphcast.model:GraphcastModel", 66 | ), 67 | ), 68 | } 69 | 70 | SUPPORTED_AI_MODELS = [ 71 | plugin_config.model_name for plugin_config in AI_MODELS_CONFIGS.values() 72 | ] 73 | 74 | 75 | def get_model_class(model_name: str) -> AIModelType: 76 | """Get the class initializer for an ai-models plugin.""" 77 | return AI_MODELS_CONFIGS[model_name].entry_point.load() 78 | -------------------------------------------------------------------------------- /ai-models-modal/app.py: -------------------------------------------------------------------------------- 1 | """Modal object definitions for reference by other application components.""" 2 | import os 3 | 4 | import modal 5 | 6 | from . import ai_models_shim, config 7 | 8 | logger = config.get_logger(__name__) 9 | 10 | 11 | def download_model_assets(): 12 | """Download and cache the model weights necessary to run the model.""" 13 | raise Exception( 14 | "This function is deprecated; assets will be downloaded on the first run of a model" 15 | " and saved to an NFS running within the application." 16 | ) 17 | 18 | from ai_models import model 19 | from multiurl import download 20 | 21 | # For each model, retrieve the pretrained model weights and cache them to 22 | # our volume. We are generally replicating the code from 23 | # ai_models.model.Model.download_assets(), but with some hard-coded options; 24 | # that method is also originally written as an instance method, and we don't 25 | # want to run the actual initializer for a model type to access it since 26 | # that would require us to provide input/output options and otherwise 27 | # prepare more generally for a model inference run - something we're not 28 | # ready to do at this stage of setup. 29 | n_models = len(config.SUPPORTED_AI_MODELS) 30 | for i, model_name in enumerate(config.SUPPORTED_AI_MODELS, 1): 31 | logger.info(f"({i}/{n_models}) downloading assets for model {model_name}...") 32 | model_initializer = model.available_models()[model_name].load() 33 | for file in model_initializer.download_files: 34 | asset = os.path.realpath(os.path.join(config.AI_MODEL_ASSETS_DIR, file)) 35 | if not os.path.exists(asset): 36 | os.makedirs(os.path.dirname(asset), exist_ok=True) 37 | logger.info("downloading %s", asset) 38 | download( 39 | model_initializer.download_url.format(file=file), 40 | asset + ".download", 41 | ) 42 | os.rename(asset + ".download", asset) 43 | 44 | 45 | # Set up the image that we'll use for performing model inference. 46 | # NOTE: We use a somewhat convoluted build procedure here, but after much trial 47 | # and error, this seems to reliably build a working application. 
The biggest 48 | # issue we ran into was getting onnx to detect our GPU and NVIDIA libraries 49 | # correctly. To achieve this, we manually install via mamba a known, working 50 | # combination of CUDA and cuDNN. We also have to be careful when we install the 51 | # library for model-specific plugins to ai-models; these tended to re-install 52 | # the CPU-only onnxruntime library, so we manually uninstall that and purposely 53 | # install the onnxrtuntime-gpu library instead. 54 | # TODO: Explore whether we can consolidate the outputs from all of these pip 55 | # installation steps into a single, master requirements.txt. Several packages seem 56 | # to produce redundant requirements that lead to some version ping-ponging during 57 | # setup. A deterministic process to produce a single requirements set that we could 58 | # keep up-to-date would improve maintainability. 59 | inference_image = ( 60 | modal.Image 61 | # Micromamba will be much faster than conda, but we need to pin to 62 | # Python=3.10 to ensure ai-models' dependencies work correctly. 63 | .micromamba(python_version="3.10") 64 | .apt_install( 65 | [ 66 | "git", 67 | ] 68 | ) 69 | .micromamba_install( 70 | "cudatoolkit=11.8", 71 | "cudnn<=8.7.0", 72 | "eccodes", 73 | "pygrib", 74 | channels=[ 75 | "conda-forge", 76 | ], 77 | ) 78 | # Run several successive pip installs; this makes it a little bit easier to 79 | # handle the dependencies and final tweaks across different plugins. 80 | # (1) Install ai-models and its dependencies. 81 | .pip_install( 82 | [ 83 | "ai-models", 84 | "google-cloud-storage", 85 | "onnx==1.15.0", 86 | "ujson", 87 | ] 88 | ) 89 | # (2) GraphCast has some additional requirements - mostly around building a 90 | # properly configured version of JAX that can run on GPUs - so we take care 91 | # of those here. 92 | .pip_install( 93 | ["jax[cuda11_pip]==0.4.20", "git+https://github.com/deepmind/graphcast.git"], 94 | find_links="https://storage.googleapis.com/jax-releases/jax_cuda_releases.html", 95 | ) 96 | # (3) Install the ai-models plugins enabled for this package. 97 | .pip_install( 98 | [ 99 | "ai-models-" + plugin_config.plugin_package_name 100 | for plugin_config in ai_models_shim.AI_MODELS_CONFIGS.values() 101 | ] 102 | ) 103 | .run_commands("pip uninstall -y onnxruntime") 104 | # (4) Ensure that we're using the ONNX GPU-enabled runtime. 105 | .pip_install("onnxruntime-gpu==1.16.3") 106 | # Generate a blank .cdsapirc file so that we can override credentials with 107 | # environment variables later on. This is necessary because the ai-models 108 | # package input handler ultimately uses climetlab.sources.CDSAPIKeyPrompt to 109 | # create a client to the CDS API, and it has a hard-coded prompt check 110 | # which requires user interaction if this file doesn't exist. 111 | # TODO: Patch climetlab to allow env var overrides for CDS API credentials. 112 | .run_commands("touch /root/.cdsapirc") 113 | ) 114 | 115 | # Set up a storage volume for sharing model outputs between processes. 116 | # TODO: Explore adding a modal.Volume to cache model weights since it should be 117 | # much faster for loading them at runtime. 
118 | volume = modal.NetworkFileSystem.persisted("ai-models-cache") 119 | 120 | stub = modal.Stub(name="ai-models-for-all", image=inference_image) 121 | -------------------------------------------------------------------------------- /ai-models-modal/config.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import os 4 | import pathlib 5 | 6 | import modal 7 | 8 | MAX_FCST_LEAD_TIME = 24 * 10 # 10 days 9 | 10 | 11 | # Set up a cache for assets leveraged during model runtime. 12 | CACHE_DIR = pathlib.Path("/cache") 13 | # Root dir in cache for writing completed model outputs. 14 | OUTPUT_ROOT_DIR = CACHE_DIR / "output" 15 | 16 | # Set up paths that can be mapped to our Volume in order to persist model 17 | # assets after they've been downloaded once. 18 | # TODO: Should we have a separate Volume instance for the model assets? 19 | AI_MODEL_ASSETS_DIR = CACHE_DIR / "assets" 20 | 21 | # Set up paths to archive initial conditions that are prepared for our model runs; 22 | # for now, this is just the processed GFS/GDAS initial conditions that we produce. 23 | INIT_CONDITIONS_DIR = CACHE_DIR / "initial_conditions" 24 | 25 | # Set a default GPU that's large enough to work with any of the published models 26 | # available to the ai-models package. 27 | DEFAULT_GPU_CONFIG = modal.gpu.A100(memory=40) 28 | 29 | # Set a default date to use when fetching sample data from ERA-5 to create templates 30 | # for processing GFS/GDAS data; we need this because we have to sort GRIB messages by 31 | # time when we prepare GraphCast inputs. 32 | DEFAULT_GFS_TEMPLATE_MODEL_EPOCH = datetime.datetime(2024, 1, 1, 0, 0) 33 | 34 | # Read secrets locally from a ".env" file; this avoids the need to have users 35 | # manually set them up in Modal, with the one downside that we do have to put 36 | # all secrets into the same file (but we don't plan to have many). 37 | # NOTE: Modal will try to read ".env" from the working directory, not from our 38 | # module directory. So keep the ".env" in the repo root. 39 | ENV_SECRETS = modal.Secret.from_dotenv() 40 | 41 | # Manually set all "forced" actions to run (e.g. 
re-processing initial conditions) 42 | FORCE_OVERRIDE = True 43 | 44 | 45 | def validate_env(): 46 | """Validate that expected env vars from .env are imported correctly.""" 47 | assert os.environ.get("CDSAPI_KEY", "") != "YOUR_KEY_HERE" 48 | assert os.environ.get("GCS_SERVICE_ACCOUNT_INFO", "") != "YOUR_SERVICE_ACCOUNT_INFO" 49 | assert os.environ.get("GCS_BUCKET_NAME", "") != "YOUR_BUCKET_NAME" 50 | 51 | 52 | def make_output_path( 53 | model_name: str, init_datetime: datetime.datetime, use_gfs: bool 54 | ) -> pathlib.Path: 55 | """Create a full path for writing a model output GRIB file.""" 56 | src = "gfs" if use_gfs else "era5" 57 | filename = f"{model_name}.{src}.{init_datetime:%Y%m%d%H%M}.grib" 58 | return OUTPUT_ROOT_DIR / filename 59 | 60 | 61 | def make_gfs_template_path(model_name: str) -> pathlib.Path: 62 | """Create an expected path where the GFS/GDAS -> ERA-5 template should exist.""" 63 | return AI_MODEL_ASSETS_DIR / f"{model_name}.input-template.grib2" 64 | 65 | 66 | def get_logger( 67 | name: str, level: int = logging.INFO, add_handler=False 68 | ) -> logging.Logger: 69 | """Set up a default logger with configs for working within a modal app.""" 70 | logger = logging.getLogger(name) 71 | logger.setLevel(level) 72 | 73 | if add_handler: 74 | handler = logging.StreamHandler() 75 | handler.setFormatter( 76 | logging.Formatter("%(levelname)s: %(asctime)s: %(name)s %(message)s") 77 | ) 78 | logger.addHandler(handler) 79 | 80 | # logger.propagate = False 81 | return logger 82 | 83 | 84 | def set_logger_basic_config(level: int = logging.INFO): 85 | handler = logging.StreamHandler() 86 | handler.setFormatter( 87 | logging.Formatter("%(levelname)s: %(asctime)s: %(name)s %(message)s") 88 | ) 89 | logging.basicConfig(level=level, handlers=[handler]) 90 | -------------------------------------------------------------------------------- /ai-models-modal/gcs.py: -------------------------------------------------------------------------------- 1 | """Utilities for working with Google Cloud Storage. 2 | 3 | This library is lifted as-is from darothen@'s `plotflow` library, which is 4 | originally part of a serverless application for processing NWP outputs on the 5 | cloud and generating visualizations from them. See 6 | Rothenberg, Daniel: "Enabling Scalable, Serverless Weather Model Analyses by 7 | "Kerchunking" Data in the Cloud." AMS Annual Meeting, Baltimore, MD. 2024 8 | for more details. 9 | """ 10 | 11 | import os 12 | from pathlib import Path 13 | from typing import Any 14 | 15 | import ujson 16 | from google.cloud import storage 17 | 18 | from . import config 19 | 20 | logger = config.get_logger(__name__) 21 | 22 | 23 | def get_service_account_json(env_var: str = "GCS_SERVICE_ACCOUNT_INFO") -> dict: 24 | """Try to generate service account JSON from an env var. 25 | 26 | Parameters 27 | ---------- 28 | env_var : str 29 | Name of an environment variable containing stringified JSON service account credentials.
30 | """ 31 | service_account_info = os.environ.get(env_var, "") 32 | if not service_account_info: 33 | return {} 34 | return ujson.loads(service_account_info) 35 | 36 | 37 | class GoogleCloudStorageHandler(object): 38 | def __init__(self, client: storage.Client = None): 39 | if client is None: 40 | self._client = storage.Client() 41 | else: 42 | self._client = client 43 | 44 | @property 45 | def client(self): 46 | return self._client 47 | 48 | @staticmethod 49 | def with_anonymous_client() -> "GoogleCloudStorageHandler": 50 | return GoogleCloudStorageHandler( 51 | client=storage.Client.create_anonymous_client() 52 | ) 53 | 54 | @staticmethod 55 | def with_service_account_info( 56 | service_account_info: Any, 57 | ) -> "GoogleCloudStorageHandler": 58 | return GoogleCloudStorageHandler( 59 | client=storage.Client.from_service_account_info(service_account_info) 60 | ) 61 | 62 | # TODO: Add retry and timeout logic to all of these functions, following the docs at 63 | # https://cloud.google.com/storage/docs/retry-strategy#customize-retries 64 | def download_blob( 65 | self, 66 | bucket_name: str, 67 | source_blob_name: str, 68 | destination_pth: Path, 69 | ) -> None: 70 | """Download a blob from GCS to a local path. 71 | 72 | Parameters 73 | ---------- 74 | bucket_name : str 75 | Bucket on GCS containing the blob to download. 76 | source_blob_name : str 77 | Name of the blob to download. 78 | destination_pth : Path 79 | Local path to download the blob to. 80 | """ 81 | 82 | bucket = self.client.bucket(bucket_name) 83 | 84 | # NOTE: `Bucket.blob` differs from `Bucket.get_blob` as it doesn't retrieve 85 | # any content from Google Cloud Storage. As we don't need additional data, 86 | # using `Bucket.blob` is preferred here. 87 | blob = bucket.blob(source_blob_name) 88 | logger.info( 89 | f"Downloading gs://{bucket_name}/{source_blob_name} to {destination_pth}" 90 | ) 91 | blob.download_to_filename(destination_pth) 92 | 93 | def upload_blob( 94 | self, bucket_name: str, source_file_pth: Path, destination_blob_name: str 95 | ): 96 | """Uploads a blob to GCS from a local path. 97 | 98 | Parameters 99 | ---------- 100 | bucket_name : str 101 | Bucket on GCS where the blob should be uploaded. 102 | source_file_pth : Path 103 | Local path to the file to upload. 104 | destination_blob_name : str 105 | Blob name to use when writing to `bucket_name` on GCS. 106 | """ 107 | 108 | bucket = self.client.bucket(bucket_name) 109 | blob = bucket.blob(destination_blob_name) 110 | logger.info( 111 | f"Uploading {source_file_pth} to gs://{bucket_name}/{destination_blob_name}." 112 | ) 113 | blob.upload_from_filename(source_file_pth) 114 | 115 | def upload_json_to_blob( 116 | self, bucket_name: str, json_str: str, destination_blob_name: str 117 | ): 118 | """Uploads JSON string to a GCS blob. 119 | 120 | Parameters 121 | ---------- 122 | bucket_name : str 123 | Bucket on GCS where the blob should be uploaded. 124 | json_str : str 125 | Encoded JSON data string to upload. 126 | destination_blob_name : str 127 | Blob name to use when writing to `bucket_name` on GCS. 
128 | """ 129 | 130 | bucket = self.client.bucket(bucket_name) 131 | blob = bucket.blob(destination_blob_name) 132 | logger.info(f"Uploading JSON to gs://{bucket_name}/{destination_blob_name}.") 133 | blob.upload_from_string(data=json_str, content_type="application/json") 134 | -------------------------------------------------------------------------------- /ai-models-modal/gfs.py: -------------------------------------------------------------------------------- 1 | """Utilities for acquiring, fetching, and working with GFS/GDAS data for 2 | use in the AI models application.""" 3 | 4 | import datetime 5 | import pathlib 6 | from collections import namedtuple 7 | from typing import Any, Sequence, Type 8 | 9 | import pygrib 10 | from tqdm import tqdm 11 | from tqdm.contrib.logging import logging_redirect_tqdm 12 | 13 | from . import config 14 | 15 | logger = config.get_logger(__name__) 16 | 17 | 18 | def identity(x: Any) -> Any: 19 | """Identity pass-through function.""" 20 | return x 21 | 22 | 23 | # Density of water 24 | RHO_WATER = 1000.0 # kg m^-3 25 | 26 | # A `grib_mapper` is a simple wrapper for information we use to succintly identify 27 | # and coerce GRIB messages from one source to another. 28 | grib_mapper = namedtuple( 29 | "grib_mapper", ["source_field", "target_field", "fn", "source_matcher_override"] 30 | ) 31 | 32 | # ERA5 field name -> Mapper from GDAS to ERA5 33 | # We break these down hierarchically by the type_of_level in order to help disambiguate 34 | # certain fields (such as "z") which be present in both single- and multi-level field 35 | # sets. Note that these are the canonical level types, not the "pl"(->isobaricInhPa) and 36 | # "sfc"(->surface) level types that we use for querying the CDS API. 37 | mappers_by_type_of_level = { 38 | "isobaricInhPa": { 39 | "z": grib_mapper("gh", "z", lambda x: x * 9.81, {}), # Geopotential height 40 | }, 41 | "surface": { 42 | # NOTE: In GraphCast, we also consume surface geopotential height, which according 43 | # to the param_db (https://codes.ecmwf.int/grib/param-db/129) should just be the 44 | # surface orography. 45 | "z": grib_mapper("orog", "z", lambda x: x * 9.81, {}), # Geopotential height 46 | # NOTE: We might want to copy the _original_ ERA-5 lsm field instead of using 47 | # the GDAS one. 48 | "lsm": grib_mapper("lsm", "lsm", identity, {}), # Land-sea binary mask, 49 | # NOTE: This is a gross approximation to estimating 1-hr precip accumulation from 50 | # the available instantaneous precip rate. We should develop a more complex 51 | # way involving reading the hourly precip accumulations from the GFS forecasts. 
52 | "tp": grib_mapper( 53 | "prate", "tp", lambda x: (x / RHO_WATER) * 3600 * 1, {} 54 | ), # Total precipitation 55 | "msl": grib_mapper( 56 | "prmsl", "msl", identity, {"typeOfLevel": "meanSea"} 57 | ), # Mean sea level pressure 58 | "10u": grib_mapper( 59 | "10u", "10u", identity, {"typeOfLevel": "heightAboveGround", "level": 10} 60 | ), # 10 meter U wind component 61 | "10v": grib_mapper( 62 | "10v", "10v", identity, {"typeOfLevel": "heightAboveGround", "level": 10} 63 | ), # 10 meter V wind component 64 | "100u": grib_mapper( 65 | "100u", "100u", identity, {"typeOfLevel": "heightAboveGround", "level": 100} 66 | ), # 100 meter U wind component 67 | "100v": grib_mapper( 68 | "100v", "100v", identity, {"typeOfLevel": "heightAboveGround", "level": 100} 69 | ), # 100 meter V wind component 70 | "2t": grib_mapper( 71 | "2t", "2t", identity, {"typeOfLevel": "heightAboveGround", "level": 2} 72 | ), # 2 meter temperature 73 | "tcwv": grib_mapper( 74 | "pwat", 75 | "tcwv", 76 | identity, 77 | {"typeOfLevel": "atmosphereSingleLayer", "level": 0}, 78 | ), # Total column water vapor, taken from GFS precipitable water 79 | }, 80 | } 81 | 82 | 83 | # NOTE: Would prefer this to be a TypeAlias (https://peps.python.org/pep-0613/) 84 | # but it's not available until Python 3.12. 85 | PyGribHandle = Type[pygrib._pygrib.open] 86 | PyGribMessage = Type[pygrib._pygrib.gribmessage] 87 | 88 | 89 | GFS_BUCKET = "global-forecast-system" 90 | 91 | 92 | def make_gfs_ics_blob_name(model_epoch: datetime.datetime) -> str: 93 | """Generate the blob name for a GFS initial conditions file. 94 | 95 | We specifically target the GFS 0-hour forecast; in practice this shouldn't 96 | very much from the GDAS analysis (or GFS ANL file), but it has field names 97 | highly consistent with the ERA-5 metadata, for the most part. 98 | 99 | Parameters 100 | ---------- 101 | model_epoch : datetime.datetime 102 | The model initialization time. 103 | """ 104 | return "/".join( 105 | [ 106 | f"gfs.{model_epoch:%Y%m%d}", 107 | f"{model_epoch:%H}", 108 | "atmos", 109 | f"gfs.t{model_epoch:%H}z.pgrb2.0p25.f000", 110 | ] 111 | ) 112 | 113 | 114 | def make_gfs_base_pth(model_epoch: datetime.datetime) -> pathlib.Path: 115 | """Generate the local path for a GFS initial conditions file. 116 | 117 | Parameters 118 | ---------- 119 | model_epoch : datetime.datetime 120 | The model initialization time. 121 | """ 122 | return config.INIT_CONDITIONS_DIR / f"{model_epoch:%Y%m%d%H%M}" 123 | 124 | 125 | def select_grb(grbs: PyGribHandle, **matchers) -> PyGribMessage: 126 | """ 127 | Select a single GRIB message from a PyGribHandle using the supplied matchers. 128 | """ 129 | matching_grbs = grbs.select(**matchers) 130 | if not matching_grbs: 131 | raise ValueError(f"Could not match GRIB message with {matchers}") 132 | elif len(matching_grbs) > 1: 133 | raise ValueError(f"Multiple matches for {matchers}") 134 | return matching_grbs[0] 135 | 136 | 137 | def select_grb_from_list(grbs: Sequence[PyGribMessage], **matchers) -> PyGribMessage: 138 | """ 139 | Select a single GRIB message from a list of PyGribMessages using the supplied matchers. 
140 | """ 141 | matching_grbs = [grb for grb in grbs if grb_matches(grb, **matchers)] 142 | if not matching_grbs: 143 | for i, grb in enumerate(grbs): 144 | # if grb.shortName == matchers["shortName"]: 145 | print(i, *[(k, v, grb[k], grb[k] == v) for k, v in matchers.items()]) 146 | raise ValueError(f"Could not match GRIB message with {matchers}") 147 | elif len(matching_grbs) > 1: 148 | raise ValueError(f"Multiple matches for {matchers}") 149 | return matching_grbs[0] 150 | 151 | 152 | def grb_matches(grb: PyGribMessage, **matchers) -> PyGribMessage: 153 | """ 154 | Return "true" if a GRIB message matches all the specified key-value attributes. 155 | """ 156 | return all(grb[k] == v for k, v in matchers.items()) 157 | 158 | 159 | def process_gdas_grib( 160 | template_pth: pathlib.Path, 161 | gdas_pth: pathlib.Path, 162 | model_init: datetime.datetime = config.DEFAULT_GFS_TEMPLATE_MODEL_EPOCH, 163 | extra_template_matchers: dict = {}, 164 | ) -> Sequence[PyGribMessage]: 165 | """Process a GDAS GRIB file to prepare an input for an AI NWP forecast. 166 | 167 | Parameters 168 | ---------- 169 | template_pth : pathlib.Path 170 | The local path to the ERA-5 template GRIB file for a given model. 171 | gdas_pth : pathlib.Path 172 | The local path to the GDAS GRIB file, most likely downloaded from GCS. 173 | model_init : datetime.datetime 174 | The model analysis / initialization time, to be used in overwriting the 175 | timestamp data borrowed from the ERA-5 template GRIB. 176 | extra_template_matchers : dict, optional 177 | Additional key-value pairs to hard-code when selecting GRB messages from 178 | the template; this is useful when we need to downselect some of the 179 | template messages. 180 | 181 | Returns 182 | ------- 183 | Sequence[GrbMessage] 184 | A sequence of GRIB messages which can be written to a binary output file. 185 | """ 186 | logger.info("Reading template GRIB file %s...", template_pth) 187 | template_grbs = [] 188 | with pygrib.open(str(template_pth)) as grbs: 189 | if extra_template_matchers: 190 | grbs = grbs.select(**extra_template_matchers) 191 | for grb in grbs: 192 | template_grbs.append(grb) 193 | logger.info("... found %d GRIB messages", len(template_grbs)) 194 | 195 | time_kwargs = dict( 196 | dataDate=int(model_init.strftime("%Y%m%d")), 197 | dataTime=int(model_init.strftime("%H%M")), 198 | ) 199 | 200 | logger.info("Copying and processing GRIB messages from %s...", gdas_pth) 201 | with pygrib.open(str(gdas_pth)) as source_grbs, logging_redirect_tqdm( 202 | loggers=[ 203 | logger, 204 | ] 205 | ): 206 | # Pre-emptively subset all the source_grbs by matching against short names in 207 | # the template collection we previously opened. This greatly reduces the time it 208 | # takes to seek through the source GRIB file, which involves repeatedly reading 209 | # through the entire file from start to finish (~30x improvement when reading from 210 | # an SSD, so much faster on a cloud VM). 211 | all_short_names = [grb.shortName for grb in template_grbs] 212 | for mappers in mappers_by_type_of_level.values(): 213 | all_short_names.extend(m.source_field for m in mappers.values()) 214 | all_short_names = set(all_short_names) 215 | source_grb_list = source_grbs.select(shortName=all_short_names) 216 | 217 | for grb in tqdm( 218 | template_grbs, 219 | unit="msg", 220 | total=len(template_grbs), 221 | desc="GRIB messages", 222 | ): 223 | # Get the type of level so that we can match to the right mapper set. 
224 | mappers = mappers_by_type_of_level[grb.typeOfLevel] 225 | if grb.shortName in mappers: 226 | mapper = mappers[grb.shortName] 227 | source_matchers = mapper.source_matcher_override 228 | source_grb = select_grb_from_list( 229 | source_grb_list, 230 | shortName=mapper.source_field, 231 | typeOfLevel=source_matchers.get("typeOfLevel", grb.typeOfLevel), 232 | level=source_matchers.get("level", grb.level), 233 | ) 234 | old_mean = grb.values.mean() 235 | grb.values = mapper.fn(source_grb.values) 236 | new_mean = grb.values.mean() 237 | grb.shortName = mapper.target_field 238 | logger.debug( 239 | "mapped: [x] | %10s | Old: %g | New: %g | Copied: %g", 240 | grb.shortName, 241 | old_mean, 242 | mapper.fn(source_grb.values).mean(), 243 | new_mean, 244 | ) 245 | else: 246 | source_grb = select_grb_from_list( 247 | source_grb_list, 248 | shortName=grb.shortName, 249 | typeOfLevel=grb.typeOfLevel, 250 | level=grb.level, 251 | ) 252 | old_mean = grb.values.mean() 253 | grb.values = source_grb.values 254 | new_mean = grb.values.mean() 255 | logger.debug( 256 | "mapped: [ ] | %10s | Old: %g | Copied: %g", 257 | grb.shortName, 258 | old_mean, 259 | new_mean, 260 | ) 261 | 262 | # Overwrite the GRIB metadata with the model initialization time. 263 | for key, val in time_kwargs.items(): 264 | grb[key] = val 265 | 266 | return template_grbs 267 | -------------------------------------------------------------------------------- /ai-models-modal/main.py: -------------------------------------------------------------------------------- 1 | """A Modal application for running `ai-models` weather forecasts.""" 2 | 3 | import datetime 4 | import os 5 | import pathlib 6 | import shutil 7 | 8 | import modal 9 | from ai_models import model 10 | from tqdm import tqdm 11 | from tqdm.contrib.logging import logging_redirect_tqdm 12 | 13 | from . import ai_models_shim, config, gcs 14 | from .app import stub, volume 15 | 16 | config.set_logger_basic_config() 17 | logger = config.get_logger(__name__, add_handler=False) 18 | 19 | 20 | @stub.function( 21 | image=stub.image, 22 | secrets=[config.ENV_SECRETS], 23 | network_file_systems={str(config.CACHE_DIR): volume}, 24 | timeout=300, 25 | ) 26 | def prepare_gfs_analysis( 27 | model_name: str = "panguweather", 28 | model_init: datetime.datetime = datetime.datetime(2023, 7, 1, 0, 0), 29 | force: bool = config.FORCE_OVERRIDE, 30 | ): 31 | """Retrieve and prepare initial conditions from the GFS/GDAS to run with an AI model. 32 | 33 | Parameters 34 | ---------- 35 | model_name : str 36 | Short name for the model to run; must be one of ['panguweather', 'fourcastnet_v2', 37 | 'graphcast']. Defaults to 'panguweather'. 38 | model_init : datetime.datetime 39 | Target initialization time or model epoch to fetch. 40 | force : bool 41 | Force re-download and processing, even if the target file already exists. 42 | 43 | """ 44 | from . import gfs 45 | 46 | logger.info(f"Preparing GFS/GDAS initial conditions for {model_name} model run...") 47 | 48 | template_pth = config.make_gfs_template_path(model_name) 49 | if not template_pth.exists(): 50 | raise ValueError( 51 | f"Expected to find GFS/GDAS -> ERA-5 template at {template_pth}, but file does not exist." 52 | ) 53 | 54 | gdas_base_pth = gfs.make_gfs_base_pth(model_init) 55 | gdas_base_pth.mkdir(parents=True, exist_ok=True) 56 | 57 | proc_gdas_fn = f"gdas.proc-{model_name}.grib" 58 | final_proc_gdas_pth = gdas_base_pth / proc_gdas_fn 59 | 60 | # Short-circuit - don't waste our time if file already exists. 
61 | if final_proc_gdas_pth.exists() and not force: 62 | logger.info( 63 | f"Found existing processed GFS/GDAS file {gdas_base_pth / proc_gdas_fn};" 64 | " skipping download and processing." 65 | ) 66 | return 67 | 68 | service_account_info = gcs.get_service_account_json("GCS_SERVICE_ACCOUNT_INFO") 69 | gcs_handler = gcs.GoogleCloudStorageHandler.with_service_account_info( 70 | service_account_info 71 | ) 72 | 73 | # Set up the files to download with useful metadata (e.g. time lags) 74 | match model_name: 75 | case "panguweather" | "fourcastnetv2-small": 76 | model_init_tds = [ 77 | datetime.timedelta(hours=0), 78 | ] 79 | source_blob_names = [ 80 | gfs.make_gfs_ics_blob_name(model_init + td) for td in model_init_tds 81 | ] 82 | case "graphcast": 83 | # By convention, the first element is the init time, and the second element 84 | # is the time-lagged input. 85 | model_init_tds = [ 86 | datetime.timedelta(hours=0), 87 | datetime.timedelta(hours=-6), 88 | ] 89 | source_blob_names = [ 90 | gfs.make_gfs_ics_blob_name(model_init + td) for td in model_init_tds 91 | ] 92 | case _: 93 | raise ValueError(f"Encountered unknown model {model_name}") 94 | 95 | source_fns = [blob_name.split("/")[-1] for blob_name in source_blob_names] 96 | for source_blob_name, source_fn in zip(source_blob_names, source_fns): 97 | logger.info( 98 | f"Attempting to download GFS/GDAS blob gs://{gfs.GFS_BUCKET}/{source_blob_name}..." 99 | ) 100 | gcs_handler.download_blob(gfs.GFS_BUCKET, source_blob_name, source_fn) 101 | 102 | # Sanity check to make sure we were able to download the GDAS file. 103 | if not pathlib.Path(source_fn).exists(): 104 | raise RuntimeError("Failed to download GFS/GDAS blob.") 105 | 106 | # Run subsetting 107 | logger.info("Subsetting GFS/GDAS data...") 108 | match model_name: 109 | case "panguweather" | "fourcastnetv2-small": 110 | # There should only be one file that we downloaded, so we can just directly 111 | # use it. 112 | source_fn = source_fns[0] 113 | logger.info("Processing Set 1 -> %s", source_fn) 114 | subset_grbs = gfs.process_gdas_grib(template_pth, source_fn, model_init) 115 | case "graphcast": 116 | # Use our slightly custom logic. 117 | # TODO: Re-factor this to its own stand-alone function for cleanliness. 118 | 119 | # Timedeltas for Set 1 - the 0- and 6-hr lagged messages 120 | # NOTE: these should match the deltas in model_init_tds above; ideally we should 121 | # just re-use those directly. 122 | template_tds = [datetime.timedelta(hours=0), datetime.timedelta(hours=-6)] 123 | # Timedeltas for Set 2 (precipitation) - due to some quirkiness in the ai-models package, 124 | # we use 6- and 18-hr offsets for the 0- and 6-hr lagged messages, respectively. 125 | tp_template_tds = [ 126 | datetime.timedelta(hours=-6), 127 | datetime.timedelta(hours=-18), 128 | ] 129 | subset_grbs = [] 130 | # Set 1 - Core fields (everything but precipitation) 131 | for source_fn, template_td in zip(source_fns, template_tds): 132 | logger.info("Processing Set 1 (core fields) -> %s", source_fn) 133 | template_dt = config.DEFAULT_GFS_TEMPLATE_MODEL_EPOCH + template_td 134 | extra_template_matchers = { 135 | "dataDate": int(template_dt.strftime("%Y%m%d")), 136 | "dataTime": int(template_dt.strftime("%H%M")), 137 | "shortName": lambda x: x != "tp", 138 | } 139 | output_msgs = gfs.process_gdas_grib( 140 | template_pth, 141 | pathlib.Path(source_fn), 142 | # Offset the model_init time by the expected timedelta so that we 143 | # appropriately encode the GRIB message timestamps. 
                    model_init + template_td,
                    extra_template_matchers=extra_template_matchers,
                )
                subset_grbs.extend(output_msgs)
            # Set 2 - Precipitation; use the alternate time deltas and hardcode the precipitation
            # field.
            for source_fn, template_td in zip(source_fns, tp_template_tds):
                logger.info("Processing Set 2 (precipitation) -> %s", source_fn)
                template_dt = config.DEFAULT_GFS_TEMPLATE_MODEL_EPOCH + template_td
                extra_template_matchers = {
                    "dataDate": int(template_dt.strftime("%Y%m%d")),
                    "dataTime": int(template_dt.strftime("%H%M")),
                    "shortName": "tp",
                }
                output_msgs = gfs.process_gdas_grib(
                    template_pth,
                    pathlib.Path(source_fn),
                    model_init + template_td,
                    extra_template_matchers=extra_template_matchers,
                )
                subset_grbs.extend(output_msgs)
        case _:
            raise ValueError(f"Encountered unknown model {model_name}")

    with (
        open(proc_gdas_fn, "wb") as f,
        logging_redirect_tqdm(
            loggers=[
                logger,
            ]
        ),
    ):
        for grb in tqdm(
            subset_grbs,
            unit="msg",
            total=len(subset_grbs),
            desc="GRIB messages",
        ):
            msg = grb.tostring()
            f.write(msg)
    logger.info(
        "Copying processed GFS/GDAS file to cache at %s...",
        final_proc_gdas_pth,
    )
    shutil.copy(proc_gdas_fn, final_proc_gdas_pth)
    logger.info("... done.")

    # Sanity check to make sure that we wrote out the processed GDAS file.
    if not (gdas_base_pth / proc_gdas_fn).exists():
        raise RuntimeError("Failed to produce subset GFS/GDAS GRIB.")


@stub.function(
    image=stub.image,
    secrets=[config.ENV_SECRETS],
    network_file_systems={str(config.CACHE_DIR): volume},
    # gpu="T4",
    timeout=60,
    allow_cross_region_volumes=True,
)
def check_assets(skip_validate_env: bool = False):
    """This is a placeholder function for testing that the application and credentials
    are all set up correctly and working as expected."""
    import cdsapi

    if not skip_validate_env:
        config.validate_env()

    logger.info(f"Running locally -> {modal.is_local()}")

    assets = list(config.AI_MODEL_ASSETS_DIR.glob("**/*"))
    logger.info(f"Found {len(assets)} assets:")
    for i, asset in enumerate(assets, 1):
        logger.info(f"({i}) {asset}")
    logger.info(f"CDS API URL: {os.environ['CDSAPI_URL']}")
    logger.info(f"CDS API Key: {os.environ['CDSAPI_KEY']}")

    client = cdsapi.Client()
    logger.info(client)

    test_cdsapirc = pathlib.Path("~/.cdsapirc").expanduser()
    logger.info(f"Test .cdsapirc: {test_cdsapirc} exists = {test_cdsapirc.exists()}")

    logger.info("Trying to import eccodes...")
    # NOTE: Right now, this will throw a UserWarning: "libexpat.so.1: cannot
    # open shared object file: No such file or directory." This is likely due to
    # something not being built correctly by mamba in the application image, but
    # it doesn't impact functionality at the moment.
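    # The import below is purely a smoke test that the eccodes bindings can load
    # their shared libraries inside the image; nothing from the package is used here.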
    import eccodes  # noqa: F401

    logger.info("Getting GPU information...")
    import onnxruntime as ort

    logger.info(
        f"ort avail providers: {ort.get_available_providers()}"
    )  # output: ['CUDAExecutionProvider', 'CPUExecutionProvider']
    logger.info(f"onnxruntime device: {ort.get_device()}")  # output: GPU

    logger.info(f"Checking contents on network file system at {config.CACHE_DIR}...")
    for i, asset in enumerate(config.CACHE_DIR.glob("**/*"), 1):
        logger.info(f"({i}) {asset}")

    logger.info("Checking for access to GCS...")

    service_account_info = gcs.get_service_account_json("GCS_SERVICE_ACCOUNT_INFO")
    gcs_handler = gcs.GoogleCloudStorageHandler.with_service_account_info(
        service_account_info
    )
    bucket_name = os.environ["GCS_BUCKET_NAME"]
    logger.info(f"Listing blobs in GCS bucket gs://{bucket_name}")
    blobs = list(gcs_handler.client.list_blobs(bucket_name))
    logger.info(f"Found {len(blobs)} blobs:")
    for i, blob in enumerate(blobs, 1):
        logger.info(f"({i}) {blob.name}")


@stub.cls(
    secrets=[config.ENV_SECRETS],
    gpu=config.DEFAULT_GPU_CONFIG,
    network_file_systems={str(config.CACHE_DIR): volume},
    concurrency_limit=1,
    timeout=1_800,
)
class AIModel:
    def __init__(
        self,
        # TODO: Re-factor arguments into a well-structured dataclass.
        model_name: str = "panguweather",
        model_init: datetime.datetime = datetime.datetime(2023, 7, 1, 0, 0),
        lead_time: int = 12,
        use_gfs: bool = False,
    ) -> None:
        self.model_name = model_name
        self.model_init = model_init

        # Cap forecast lead time to 10 days; the models may or may not work longer than
        # this, but this is an unnecessary foot-gun. A savvy user can disable this check
        # in-code.
        if lead_time > config.MAX_FCST_LEAD_TIME:
            logger.warning(
                f"Requested forecast lead time ({lead_time}) exceeds max; setting"
                f" to {config.MAX_FCST_LEAD_TIME}. You can manually set a higher limit in"
                " ai-models-modal/config.py::MAX_FCST_LEAD_TIME."
            )
            self.lead_time = config.MAX_FCST_LEAD_TIME
        else:
            self.lead_time = lead_time

        self.out_pth = config.make_output_path(model_name, model_init, use_gfs)
        self.out_pth.parent.mkdir(parents=True, exist_ok=True)

        self.use_gfs = use_gfs

    def __enter__(self):
        logger.info(f" Model: {self.model_name}")
        logger.info(f" Run initialization datetime: {self.model_init}")
        logger.info(f" Forecast lead time: {self.lead_time}")
        logger.info(f" Model output path: {str(self.out_pth)}")
        logger.info(
            f" Initial conditions source: {'gfs' if self.use_gfs else 'era5'}"
        )
        logger.info("Running model initialization / staging...")
        if self.use_gfs:
            self.init_model = self._init_model_for_gfs()
        else:
            self.init_model = self._init_model_for_era5()
        logger.info("... done! Model is initialized and ready to run.")

    def _init_model_for_era5(self):
        """Set up the model for running with ERA-5 initial conditions."""
        model_class = ai_models_shim.get_model_class(self.model_name)
        return model_class(
            # Necessary arguments to instantiate a Model object
            input="cds",
            output="file",
            download_assets=False,
            # Additional arguments. These are generally set as object attributes
            # which are then referred to by various Model methods; unfortunately,
            # they're not clearly declared in the class documentation so there is
            # a bit of trial and error involved in figuring out what's needed.
            assets=config.AI_MODEL_ASSETS_DIR,
            date=int(self.model_init.strftime("%Y%m%d")),
            time=self.model_init.hour,
            lead_time=self.lead_time,
            path=str(self.out_pth),
            metadata={},  # Read by the output data handler
            # Unused arguments that are required by Model class methods to work.
            model_args={},
            assets_sub_directory=None,
            staging_dates=None,
            # TODO: Figure out if we can set up caching of model initial conditions
            # using the default interface.
            archive_requests=False,
            only_gpu=True,
            # Assumed set by GraphcastModel; produces additional auxiliary
            # output NetCDF files.
            debug=False,
        )

    def _init_model_for_gfs(self):
        """Set up the model for running with GFS/GDAS initial conditions."""
        from . import gfs

        model_class = ai_models_shim.get_model_class(self.model_name)

        # Create expected path for processed initial conditions, and check that it's
        # available for us to consume.
        gdas_base_pth = gfs.make_gfs_base_pth(self.model_init)
        gdas_proc_fn = f"gdas.proc-{self.model_name}.grib"
        gdas_proc_pth = gdas_base_pth / gdas_proc_fn
        if not gdas_proc_pth.exists():
            raise RuntimeError(
                f"Expected processed GFS/GDAS initial conditions file not found at"
                f" {gdas_proc_pth}."
            )
        logger.info("Copying processed GFS/GDAS file from cache to local...")
        shutil.copy(gdas_proc_pth, gdas_proc_fn)
        logger.info("... done.")
        logger.info(f"Reading GFS/GDAS initial conditions from {gdas_proc_fn}.")

        return model_class(
            output="file",
            download_assets=False,
            assets=config.AI_MODEL_ASSETS_DIR,
            date=int(self.model_init.strftime("%Y%m%d")),
            time=self.model_init.hour,
            lead_time=self.lead_time,
            path=str(self.out_pth),
            metadata={},
            model_args={},
            assets_sub_directory=None,
            staging_dates=None,
            archive_requests=False,
            only_gpu=True,
            debug=False,
            # The only changes we need to make are how we specify the model input
            # data. We'll use the GFS/GDAS data that we've already prepared - although
            # here we assume the data is available. We can add a sanity check above.
            input="file",
            file=str(gdas_proc_fn),
        )

    @modal.method()
    def run_model(self) -> None:
        logger.info("Invoking AIModel.run_model()...")
        self.init_model.run()


# This routine is made available as a stand-alone function, and it's up to the user
# to ensure that the path config.AI_MODEL_ASSETS_DIR exists and is mapped to the storage
# volume where assets should be cached. We provide this as a stand-alone function so
# that it can be called on a cheaper, non-GPU instance, avoiding wasted cycles outside
# of model inference on a more expensive machine.
def _maybe_download_assets(model_name: str) -> None:
    from multiurl import download

    logger.info(f"Maybe retrieving assets for model {model_name}...")

    # For the requested model, retrieve the pretrained model weights and cache them to
    # our storage volume. We are generally replicating the code from
    # ai_models.model.Model.download_assets(), but with some hard-coded options;
    # that method is also originally written as an instance method, and we don't
    # want to run the actual initializer for a model type to access it since
    # that would require us to provide input/output options and otherwise
    # prepare more generally for a model inference run - something we're not
    # ready to do at this stage of setup.
    model_class = ai_models_shim.get_model_class(model_name)
    n_files = len(model_class.download_files)
    n_downloaded = 0
    for i, file in enumerate(model_class.download_files, 1):
        asset = os.path.realpath(os.path.join(config.AI_MODEL_ASSETS_DIR, file))
        if not os.path.exists(asset):
            os.makedirs(os.path.dirname(asset), exist_ok=True)
            logger.info(f"({i}/{n_files}) downloading {asset}")
            download(
                model_class.download_url.format(file=file),
                asset + ".download",
            )
            os.rename(asset + ".download", asset)
            n_downloaded += 1
    if not n_downloaded:
        logger.info(" No assets need to be downloaded.")
    logger.info("... done retrieving assets.")

    template_pth = config.make_gfs_template_path(model_name)
    logger.info("Checking for GFS/GDAS -> ERA-5 template at %s", template_pth)
    if not template_pth.exists():
        logger.info("%s did not exist.", template_pth)
        # Two options: we've saved it to a bucket (so just download it), or we need
        # to generate it from scratch.
        bucket_name = os.environ.get("GCS_BUCKET_NAME", "")
        service_account_info = gcs.get_service_account_json("GCS_SERVICE_ACCOUNT_INFO")
        gcs_handler = gcs.GoogleCloudStorageHandler.with_service_account_info(
            service_account_info
        )
        template_fn = template_pth.name
        target_blob = gcs_handler.client.bucket(bucket_name).blob(template_fn)

        # If the template doesn't exist, call our helper routine that forcibly
        # generates one for us. Regardless, download from GCS to our local cache
        # afterwards.
        logger.info(
            "Checking for template in GCS bucket gs://%s/%s", bucket_name, template_fn
        )
        if not target_blob.exists():
            logger.info(" Template not found; generating from scratch.")
            make_model_era5_template.local(model_name)

        logger.info(
            "Downloading pre-computed template from gs://%s/%s",
            bucket_name,
            template_fn,
        )
        gcs_handler.download_blob(bucket_name, template_fn, template_pth)


@stub.function(
    image=stub.image,
    secrets=[config.ENV_SECRETS],
    network_file_systems={str(config.CACHE_DIR): volume},
    allow_cross_region_volumes=True,
    timeout=1_800,
)
def generate_forecast(
    model_name: str = "panguweather",
    model_init: datetime.datetime = datetime.datetime(2023, 7, 1, 0, 0),
    lead_time: int = 12,
    use_gfs: bool = False,
    skip_validate_env: bool = False,
):
    """Generate a forecast using the specified model."""

    if not skip_validate_env:
        config.validate_env()

    logger.info(f"Setting up model {model_name} conditions...")
    # Pre-emptively try to download assets from our cheaper CPU-only function, so that
    # we don't waste time on the GPU machine.
    _maybe_download_assets(model_name)
    # If necessary, download and prepare GFS initial conditions. Again, don't waste time
    # with a GPU process for this.
    if use_gfs:
        prepare_gfs_analysis.remote(model_name, model_init)
    ai_model = AIModel(model_name, model_init, lead_time, use_gfs)

    logger.info("Generating forecast...")
    ai_model.run_model.remote()
    logger.info("... forecast complete!")

    # Double check that we successfully produced a model output file.
    logger.info(f"Checking output file {str(ai_model.out_pth)}...")
    if ai_model.out_pth.exists():
        logger.info(" Success!")
    else:
        logger.info(" Did not find expected output file.")

    # Try to upload to Google Cloud Storage
    bucket_name = os.environ.get("GCS_BUCKET_NAME", "")
    service_account_info = gcs.get_service_account_json("GCS_SERVICE_ACCOUNT_INFO")

    if (not bucket_name) or (not service_account_info):
        logger.warning("Not able to access Google Cloud Storage; skipping upload.")
        return

    logger.info(f"Attempting to upload to GCS bucket gs://{bucket_name}...")
    gcs_handler = gcs.GoogleCloudStorageHandler.with_service_account_info(
        service_account_info
    )
    dest_blob_name = ai_model.out_pth.name
    logger.info(f"Uploading to gs://{bucket_name}/{dest_blob_name}")
    gcs_handler.upload_blob(
        bucket_name,
        ai_model.out_pth,
        dest_blob_name,
    )
    logger.info("Checking that upload was successful...")
    target_blob = gcs_handler.client.bucket(bucket_name).blob(dest_blob_name)
    if target_blob.exists():
        logger.info(" Success!")
    else:
        logger.info(
            f" Did not find expected blob ({dest_blob_name}) in GCS bucket"
            f" ({bucket_name})."
        )


@stub.function(
    image=stub.image,
    secrets=[config.ENV_SECRETS],
    network_file_systems={str(config.CACHE_DIR): volume},
    timeout=7_200,
    allow_cross_region_volumes=True,
)
def make_model_era5_template(model_name: str):
    """Generate a template GRIB file corresponding to the ERA-5 inputs for a given
    AI model."""
    import climetlab as cml
    import numpy as np

    bucket_name = os.environ.get("GCS_BUCKET_NAME", "")
    service_account_info = gcs.get_service_account_json("GCS_SERVICE_ACCOUNT_INFO")
    gcs_handler = gcs.GoogleCloudStorageHandler.with_service_account_info(
        service_account_info
    )

    model_class = ai_models_shim.get_model_class(model_name)
    model = model_class(  # noqa: F811
        # Necessary arguments to instantiate a Model object
        input="cds",
        output="file",
        download_assets=False,
        assets=config.AI_MODEL_ASSETS_DIR,
        date=int(config.DEFAULT_GFS_TEMPLATE_MODEL_EPOCH.strftime("%Y%m%d")),
        time=int(config.DEFAULT_GFS_TEMPLATE_MODEL_EPOCH.strftime("%H")),
        lead_time=6,
        path="_stub.grib2",
        metadata={},
        model_args={},
        assets_sub_directory=None,
        staging_dates=None,
        archive_requests=False,
        only_gpu=False,
        debug=True,
    )

    out_fn = f"{model_name}.input-template.grib2"
    with cml.new_grib_output(out_fn) as f:
        for template in model.input.all_fields:
            f.write(np.zeros(template.shape), template=template)

    logger.info("Uploading to gs://%s/%s", bucket_name, out_fn)
    gcs_handler.upload_blob(
        bucket_name,
        out_fn,
        out_fn,
    )
    logger.info("Checking that upload was successful...")
    target_blob = gcs_handler.client.bucket(bucket_name).blob(out_fn)
    if target_blob.exists():
        logger.info(" Success!")
    else:
        logger.info(
            " Did not find expected blob %s in GCS bucket gs://%s.",
            out_fn,
            bucket_name,
        )


@stub.local_entrypoint()
def main(
    model_name: str = "panguweather",
    lead_time: int = 12,
    model_init: datetime.datetime = datetime.datetime(2023, 7, 1, 0, 0),
    use_gfs: bool = False,
    make_template: bool = False,
    run_checks: bool = False,
    run_forecast: bool = False,
):
    """Entrypoint for triggering a remote ai-models weather forecast run.

    Parameters:
        model_name: short name for the model to run; must be one of ['panguweather',
            'fourcastnetv2-small', 'graphcast']. Defaults to 'panguweather'.
        lead_time: number of hours to forecast into the future. Defaults to 12.
        model_init: datetime to use when initializing the model. Defaults to
            2023-07-01T00:00.
        use_gfs: use GFS/GDAS initial conditions instead of the default ERA-5
        make_template: generate a template GRIB file corresponding to the ERA-5 inputs
            for a given model.
        run_checks: enable call to remote check_assets() for triaging the application
            runtime environment.
        run_forecast: enable call to remote generate_forecast() for running the actual
            forecast model.
    """
    # Quick sanity checks on model arguments; if we don't need to call out to our
    # remote apps, then we shouldn't!
    if model_name not in ai_models_shim.SUPPORTED_AI_MODELS:
        raise ValueError(
            f"User provided model_name '{model_name}' is not supported; must be one of"
            f" {ai_models_shim.SUPPORTED_AI_MODELS}."
        )

    if make_template:
        make_model_era5_template.remote(model_name)
    if run_checks:
        check_assets.remote()
    if run_forecast:
        generate_forecast.remote(
            model_name=model_name,
            model_init=model_init,
            lead_time=lead_time,
            use_gfs=use_gfs,
        )
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
ai-models
google-cloud-storage
modal
python-dotenv
ujson
--------------------------------------------------------------------------------