├── .github ├── pull_request_template.md └── workflows │ ├── docs.yaml │ ├── pre-commit.yml │ └── tests.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── ACE-logo.png ├── LICENSE ├── Makefile ├── README.md ├── analysis-deps.txt ├── conftest.py ├── constraints.txt ├── docker └── Dockerfile ├── fme ├── LICENSE ├── README.md ├── deploy-requirements.txt ├── dev-requirements.txt ├── docs │ ├── Makefile │ ├── _static │ │ ├── Ai2_icon_pink_RGB.png │ │ ├── Ai2_icon_pink_padding_RGB.png │ │ └── custom.css │ ├── api.rst │ ├── builder.rst │ ├── conf.py │ ├── configs │ │ ├── explicit-indices.yaml │ │ ├── inference-ic-indices.yaml │ │ └── timestamp-list.yaml │ ├── evaluator-config.yaml │ ├── evaluator_config.rst │ ├── index.rst │ ├── inference-config.yaml │ ├── inference_config.rst │ ├── installation.rst │ ├── make.bat │ ├── modules.rst │ ├── quickstart.rst │ ├── requirements.txt │ ├── train-config.yaml │ └── training_config.rst ├── fme │ ├── __init__.py │ ├── ace │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── aggregator │ │ │ ├── __init__.py │ │ │ ├── inference │ │ │ │ ├── __init__.py │ │ │ │ ├── annual.py │ │ │ │ ├── enso │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── enso.py │ │ │ │ │ ├── index.py │ │ │ │ │ └── test_enso.py │ │ │ │ ├── histogram.py │ │ │ │ ├── main.py │ │ │ │ ├── reduced.py │ │ │ │ ├── seasonal.py │ │ │ │ ├── spectrum.py │ │ │ │ ├── test_annual.py │ │ │ │ ├── test_distributed.py │ │ │ │ ├── test_evaluator.py │ │ │ │ ├── test_inference.py │ │ │ │ ├── test_reduced.py │ │ │ │ ├── test_seasonal.py │ │ │ │ ├── test_spectrum.py │ │ │ │ ├── test_time_mean.py │ │ │ │ ├── test_video.py │ │ │ │ ├── test_zonal_mean.py │ │ │ │ ├── time_mean.py │ │ │ │ ├── video.py │ │ │ │ └── zonal_mean.py │ │ │ ├── null.py │ │ │ ├── one_step │ │ │ │ ├── __init__.py │ │ │ │ ├── main.py │ │ │ │ ├── map.py │ │ │ │ ├── reduced.py │ │ │ │ ├── reduced_metrics.py │ │ │ │ ├── snapshot.py │ │ │ │ ├── test_main.py │ │ │ │ └── test_reduced.py │ │ │ ├── plotting.py │ │ │ ├── test_plotting.py │ │ │ └── train.py │ │ ├── data_loading │ │ │ ├── __init__.py │ │ │ ├── batch_data.py │ │ │ ├── config.py │ │ │ ├── getters.py │ │ │ ├── inference.py │ │ │ ├── perturbation.py │ │ │ ├── test_batch_data.py │ │ │ ├── test_data_loader.py │ │ │ ├── test_data_loading_config.py │ │ │ ├── test_metadata.py │ │ │ └── test_perturbation.py │ │ ├── evaluator.py │ │ ├── inference │ │ │ ├── __init__.py │ │ │ ├── __main__.py │ │ │ ├── data_writer │ │ │ │ ├── __init__.py │ │ │ │ ├── histograms.py │ │ │ │ ├── main.py │ │ │ │ ├── monthly.py │ │ │ │ ├── raw.py │ │ │ │ ├── restart.py │ │ │ │ ├── test_data_writer.py │ │ │ │ ├── test_main.py │ │ │ │ ├── test_monthly.py │ │ │ │ ├── test_time_coarsen.py │ │ │ │ ├── time_coarsen.py │ │ │ │ ├── utils.py │ │ │ │ └── video.py │ │ │ ├── derived_variables.py │ │ │ ├── evaluator.py │ │ │ ├── inference.py │ │ │ ├── loop.py │ │ │ ├── stepper_test_data │ │ │ ├── test_derived_variables.py │ │ │ ├── test_evaluator.py │ │ │ ├── test_inference.py │ │ │ └── test_segmented.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── healpix │ │ │ │ ├── __init__.py │ │ │ │ ├── healpix_activations.py │ │ │ │ ├── healpix_blocks.py │ │ │ │ ├── healpix_decoder.py │ │ │ │ ├── healpix_encoder.py │ │ │ │ ├── healpix_layers.py │ │ │ │ └── healpix_recunet.py │ │ │ ├── makani │ │ │ │ ├── __init__.py │ │ │ │ ├── activations.py │ │ │ │ ├── contractions.py │ │ │ │ ├── factorizations.py │ │ │ │ ├── layers.py │ │ │ │ ├── sfnonet.py │ │ │ │ └── spectral_convolution.py │ │ │ └── modulus │ │ │ │ ├── __init__.py │ │ │ │ ├── activations.py │ 
│ │ │ ├── contractions.py │ │ │ │ ├── factorizations.py │ │ │ │ ├── initialization.py │ │ │ │ ├── layers.py │ │ │ │ ├── s2convolutions.py │ │ │ │ └── sfnonet.py │ │ ├── registry │ │ │ ├── __init__.py │ │ │ ├── hpx.py │ │ │ ├── prebuilt.py │ │ │ ├── registry.py │ │ │ ├── sfno.py │ │ │ ├── test_hpx.py │ │ │ └── test_sfno.py │ │ ├── requirements.py │ │ ├── run-train-and-inference.sh │ │ ├── stepper.py │ │ ├── test_stepper.py │ │ ├── test_train.py │ │ ├── testing │ │ │ ├── __init__.py │ │ │ └── fv3gfs_data.py │ │ ├── train │ │ │ ├── __init__.py │ │ │ ├── __main__.py │ │ │ ├── train.py │ │ │ └── train_config.py │ │ └── validate_config.py │ ├── core │ │ ├── __init__.py │ │ ├── climate_data.py │ │ ├── constants.py │ │ ├── coordinates.py │ │ ├── corrector │ │ │ ├── __init__.py │ │ │ ├── corrector.py │ │ │ ├── ocean.py │ │ │ ├── registry.py │ │ │ └── test_corrector.py │ │ ├── dataset │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ ├── data_typing.py │ │ │ ├── getters.py │ │ │ ├── requirements.py │ │ │ ├── test_utils.py │ │ │ ├── test_xarray.py │ │ │ ├── utils.py │ │ │ └── xarray.py │ │ ├── device.py │ │ ├── dicts.py │ │ ├── distributed.py │ │ ├── ema.py │ │ ├── generics │ │ │ ├── aggregator.py │ │ │ ├── data.py │ │ │ ├── inference.py │ │ │ ├── optimization.py │ │ │ ├── test_looper.py │ │ │ ├── test_trainer.py │ │ │ ├── train_stepper.py │ │ │ ├── trainer.py │ │ │ └── writer.py │ │ ├── gridded_ops.py │ │ ├── histogram.py │ │ ├── logging_utils.py │ │ ├── loss.py │ │ ├── masking.py │ │ ├── metrics.py │ │ ├── normalizer.py │ │ ├── ocean.py │ │ ├── optimization.py │ │ ├── packer.py │ │ ├── parameter_init.py │ │ ├── prescriber.py │ │ ├── registry │ │ │ ├── __init__.py │ │ │ ├── corrector.py │ │ │ ├── module.py │ │ │ ├── registry.py │ │ │ └── test_module_registry.py │ │ ├── regrid.py │ │ ├── scheduler.py │ │ ├── stacker.py │ │ ├── test_climate_data.py │ │ ├── test_coordinates.py │ │ ├── test_device.py │ │ ├── test_dicts.py │ │ ├── test_distributed.py │ │ ├── test_gridded_ops.py │ │ ├── test_histogram.py │ │ ├── test_loss.py │ │ ├── test_masking.py │ │ ├── test_metrics.py │ │ ├── test_normalizer.py │ │ ├── test_ocean.py │ │ ├── test_optimization.py │ │ ├── test_packer.py │ │ ├── test_parameter_init.py │ │ ├── test_prescriber.py │ │ ├── test_scheduler.py │ │ ├── test_stacker.py │ │ ├── test_timing.py │ │ ├── test_wandb.py │ │ ├── test_weight_ops.py │ │ ├── test_wildcard.py │ │ ├── test_winds.py │ │ ├── testing │ │ │ ├── __init__.py │ │ │ ├── distributed.py │ │ │ └── wandb.py │ │ ├── timing.py │ │ ├── typing_.py │ │ ├── wandb.py │ │ ├── weight_ops.py │ │ ├── wildcard.py │ │ └── winds.py │ ├── require_gpu.py │ ├── sht_fix.py │ └── test_harmonics.py ├── pyproject.toml └── requirements.txt └── scripts ├── README.md ├── data_process ├── Makefile ├── README.md ├── beakerpy.Dockerfile ├── combine_stats.py ├── compute_dataset.py ├── compute_dataset.sh ├── compute_dataset_argo_workflow.yaml ├── compute_dataset_e3smv2.py ├── compute_hpx_dataset.py ├── compute_hpx_dataset.sh ├── compute_repeating_forcing.py ├── compute_stats.sh ├── configs │ ├── e3sm-1deg-8layer.yaml │ ├── era5-1deg-16layer-1940-2022.yaml │ ├── era5-1deg-8layer-1940-2022.yaml │ ├── fv3gfs-amip-ensemble-1deg-8layer.yaml │ ├── fv3gfs-amip-ensemble-4deg-8layer.yaml │ ├── fv3gfs-c48-ensemble-1deg-8layer.yaml │ ├── fv3gfs-ensemble-1deg-8layer.yaml │ ├── fv3gfs-ensemble-4deg-8layer.yaml │ ├── healpix-1deg-8layer-1940-2022.yaml │ ├── pre-industrial-CM4-1deg-8layer-trial-run.yaml │ ├── shield-amip-ensemble-c24-4deg-8layer.yaml │ ├── 
shield-amip-ensemble-c96-1deg-8layer.yaml │ ├── shield-amip-ensemble-c96-4deg-8layer.yaml │ ├── shield-c24-ensemble-4deg-8layer.yaml │ ├── shield-c96-4deg-8layer.yaml │ ├── shield-som-abrupt-co2-increase-c96-1deg-8layer.yaml │ ├── shield-som-abrupt-co2-increase-c96-4deg-8layer.yaml │ ├── shield-som-c24-4deg-8layer.yaml │ ├── shield-som-c24-tuned-cdmbgwd-4deg-8layer.yaml │ ├── shield-som-ensemble-c96-1deg-8layer.yaml │ ├── shield-som-ensemble-c96-4deg-8layer.yaml │ ├── shield-som-increasing-co2-c96-1deg-8layer.yaml │ ├── shield-som-increasing-co2-c96-4deg-8layer.yaml │ ├── shield-som-radiation-multi-call-c96-1deg-8layer.yaml │ └── shield-som-spin-up-c96-1deg-8layer.yaml ├── convert_to_monthly_netcdf.py ├── convert_to_monthly_netcdf_fv3gfs.sh ├── earth2grid.Dockerfile ├── generate_beaker_stats_dataset_fv3gfs.sh ├── generate_datasets_e3smv2.sh ├── get_stats.py ├── test_combine_stats.py ├── test_config.py └── upload_stats.py ├── era5 ├── Dockerfile ├── Makefile ├── README.md ├── dataflow-requirements.txt ├── environment.yaml ├── ingest_ncar_data │ ├── argo_workflow.yaml │ ├── ingest_single_variable.py │ ├── main.sh │ └── variables.json ├── netcdf_to_zarr │ ├── netcdf_to_zarr_pipeline.py │ └── run-netcdf-to-zarr.sh └── pipeline │ ├── run-dataflow-025deg-data.sh │ ├── run-dataflow-16levels.sh │ ├── run-dataflow.sh │ └── xr-beam-pipeline.py ├── manual_backwards_compatibility ├── .gitignore └── ace2-era5.sh └── monthly_data ├── compute_enso_index.py ├── config.yaml ├── test_write_monthly_data.py └── write_monthly_data.py /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | Short description of why the PR is needed and how it satisfies those requirements, in sentence form. 2 | 3 | Changes: 4 | - symbol (e.g. 
`fme.core.my_function`) or script and concise description of changes or added feature 5 | - Can group multiple related symbols on a single bullet 6 | 7 | - [ ] Tests added 8 | 9 | Resolves # (delete if none) 10 | -------------------------------------------------------------------------------- /.github/workflows/docs.yaml: -------------------------------------------------------------------------------- 1 | name: docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v3 16 | - uses: actions/setup-python@v4 17 | with: 18 | python-version: "3.10" 19 | - uses: actions/cache@v4 20 | with: 21 | path: ${{ env.pythonLocation }} 22 | key: ${{ env.pythonLocation }}-${{ hashFiles('fme/requirements.txt') }}-${{ hashFiles('fme/docs/requirements.txt') }}-${{ hashFiles('fme/constraints.txt') }} 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install uv==0.2.5 26 | uv pip install --system -c constraints.txt -e fme[docs] 27 | - name: Build docs 28 | run: | 29 | cd fme/docs && make doctest html 30 | - name: Deploy to GitHub Pages 31 | uses: peaceiris/actions-gh-pages@v3 32 | if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} 33 | with: 34 | publish_branch: gh-pages 35 | github_token: ${{ secrets.GITHUB_TOKEN }} 36 | publish_dir: fme/docs/_build/ 37 | force_orphan: true 38 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: Pre-commit 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | pre-commit: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Check out code 16 | uses: actions/checkout@v2 17 | 18 | - name: Set up Python 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: "3.10" 22 | 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install pre-commit 27 | 28 | - name: Run pre-commit 29 | run: pre-commit run --all-files 30 | -------------------------------------------------------------------------------- /.github/workflows/tests.yaml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | cpu: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v3 16 | - name: Set environment variables 17 | run: | 18 | echo "CURRENT_WEEK=$(date +'%Y-%U')" >> $GITHUB_ENV 19 | - uses: actions/setup-python@v4 20 | with: 21 | python-version: "3.10" 22 | - uses: actions/cache@v4 23 | with: 24 | path: ${{ env.pythonLocation }} 25 | key: ${{ env.CURRENT_WEEK }}-${{ env.pythonLocation }}-${{ hashFiles('fme/requirements.txt') }}-${{ hashFiles('fme/dev-requirements.txt') }}-${{ hashFiles('constraints.txt') }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install uv 29 | uv pip install --system -c constraints.txt -e ./fme[dev] 30 | - name: Run pytest 31 | run: | 32 | make test 33 | cpu-very-fast: 34 | runs-on: ubuntu-latest 35 | steps: 36 | - uses: actions/checkout@v3 37 | - name: Set environment variables 38 | run: | 39 | echo "CURRENT_WEEK=$(date +'%Y-%U')" >> $GITHUB_ENV 40 | - uses: actions/setup-python@v4 41 | with: 42 | python-version: "3.10" 43 | - uses: actions/cache@v4 44 | with: 45 | path: 
${{ env.pythonLocation }} 46 | key: ${{ env.CURRENT_WEEK }}-${{ env.pythonLocation }}-${{ hashFiles('fme/requirements.txt') }}-${{ hashFiles('fme/dev-requirements.txt') }}-${{ hashFiles('constraints.txt') }} 47 | - name: Install dependencies 48 | run: | 49 | python -m pip install uv 50 | uv pip install --system -c constraints.txt -e ./fme[dev] 51 | - name: Run pytest 52 | run: | 53 | make test_very_fast 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | fme/docs/available_modules.rst 2 | fme/docs/_build 3 | 4 | .vscode 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | .DS_Store 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | # scratch directory for testing 138 | scratch/ 139 | 140 | # Some in progress data pipelines get added here 141 | scripts/data_process/.nfs* 142 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.8.1 4 | hooks: 5 | - id: ruff 6 | args: ["--fix", "--config", "fme/pyproject.toml"] 7 | - id: ruff-format 8 | - repo: https://github.com/pre-commit/pre-commit-hooks 9 | rev: v5.0.0 10 | hooks: 11 | - id: check-added-large-files 12 | args: [--maxkb=250] 13 | - id: trailing-whitespace 14 | - repo: https://github.com/pre-commit/mirrors-mypy 15 | rev: v1.2.0 16 | hooks: 17 | - id: mypy 18 | additional_dependencies: ["types-PyYaml==5.4.3"] 19 | args: ["--ignore-missing-imports", "--check-untyped-defs"] 20 | exclude: | 21 | (?x)^( 22 | .+/conf.py | 23 | .+/conftest.py 24 | )$ -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the OS, Python version and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.10" 13 | # You can also specify other tool versions: 14 | # nodejs: "19" 15 | # rust: "1.64" 16 | # golang: "1.19" 17 | 18 | # Build documentation in the "docs/" directory with Sphinx 19 | sphinx: 20 | configuration: fme/docs/conf.py 21 | 22 | # Optionally build your docs in additional formats such as PDF and ePub 23 | # formats: 24 | # - pdf 25 | # - epub 26 | 27 | # Optional but recommended, declare the Python requirements required 28 | # to build your documentation 29 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 30 | python: 31 | install: 32 | - requirements: fme/docs/requirements.txt 33 | - requirements: fme/requirements.txt 34 | -------------------------------------------------------------------------------- /ACE-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai2cm/ace/0c36153610a2bd8160d1639670011af9e03b19b2/ACE-logo.png -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | VERSION ?= $(shell git rev-parse --short HEAD) 2 | IMAGE ?= fme 3 | ENVIRONMENT_NAME ?= fme 4 | DEPLOY_TARGET ?= pypi 5 | 6 | build_docker_image: 7 | docker build --platform=linux/amd64 -f docker/Dockerfile -t $(IMAGE):$(VERSION) . 
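# Example (tag values are illustrative): `make build_docker_image IMAGE=fme VERSION=abc1234` builds
# docker/Dockerfile and tags the result as fme:abc1234; without overrides, VERSION defaults to the
# short git SHA of the current commit.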
8 | 9 | enter_docker_image: build_docker_image 10 | docker run -it --rm $(IMAGE):$(VERSION) bash 11 | 12 | # recommended to deactivate current conda environment before running this 13 | create_environment: 14 | conda create -n $(ENVIRONMENT_NAME) python=3.10 pip 15 | conda run --no-capture-output -n $(ENVIRONMENT_NAME) python -m pip install uv 16 | conda run --no-capture-output -n $(ENVIRONMENT_NAME) uv pip install -c constraints.txt -e ./fme[dev,docs] 17 | conda run --no-capture-output -n $(ENVIRONMENT_NAME) uv pip install -r analysis-deps.txt 18 | 19 | test: 20 | pytest --durations 40 . 21 | 22 | test_fast: 23 | pytest --durations 40 --fast . 24 | 25 | test_very_fast: 26 | pytest --durations 40 --very-fast . 27 | 28 | # For maintainer use only 29 | # requires fme[deploy] to be installed 30 | 31 | build_pypi: 32 | rm -rf fme/dist 33 | cd fme && python -m build 34 | 35 | deploy_pypi: build_pypi 36 | cd fme && twine upload --repository $(DEPLOY_TARGET) dist/* 37 | 38 | deploy_test_pypi: DEPLOY_TARGET = testpypi 39 | deploy_test_pypi: deploy_pypi 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Docs](https://readthedocs.org/projects/ai2-climate-emulator/badge/?version=latest)](https://ai2-climate-emulator.readthedocs.io/en/latest/) 2 | [![PyPI](https://img.shields.io/pypi/v/fme.svg)](https://pypi.org/project/fme/) 3 | 4 | Logo for the ACE Project 5 | 6 | # Ai2 Climate Emulator 7 | 8 | Ai2 Climate Emulator (ACE) is a fast machine learning model that simulates global atmospheric variability in a changing climate over time scales ranging from hours to centuries. 9 | 10 | This repo contains code accompanying four papers describing ACE models: 11 | - "ACE: A fast, skillful learned global atmospheric model for climate prediction" ([link](https://arxiv.org/abs/2310.02074)) 12 | - "Application of the Ai2 Climate Emulator to E3SMv2's global atmosphere model, with a focus on precipitation fidelity" ([link](https://agupubs.onlinelibrary.wiley.com/doi/full/10.1029/2024JH000136)) 13 | - "ACE2: Accurately learning subseasonal to decadal atmospheric variability and forced responses" ([link](https://arxiv.org/abs/2411.11268)) 14 | - "ACE2-SOM: Coupling to a slab ocean and learning the sensitivity of climate to changes in CO2" ([link](https://arxiv.org/abs/2412.04418)) 15 | 16 | ## Installation 17 | 18 | ``` 19 | pip install fme 20 | ``` 21 | 22 | ## Documentation 23 | 24 | See complete documentation [here](https://ai2-climate-emulator.readthedocs.io/en/latest/) and a quickstart guide [here](https://ai2-climate-emulator.readthedocs.io/en/latest/quickstart.html). 25 | 26 | ## Model checkpoints 27 | 28 | Pretrained model checkpoints are available in the [ACE Hugging Face](https://huggingface.co/collections/allenai/ace-67327d822f0f0d8e0e5e6ca4) collection. 29 | 30 | ## Available datasets 31 | Two versions of the complete dataset described in [arxiv:2310.02074](https://arxiv.org/abs/2310.02074) 32 | are available on a [requester pays](https://cloud.google.com/storage/docs/requester-pays) Google Cloud Storage bucket: 33 | ``` 34 | gs://ai2cm-public-requester-pays/2023-11-29-ai2-climate-emulator-v1/data/repeating-climSST-1deg-zarrs 35 | gs://ai2cm-public-requester-pays/2023-11-29-ai2-climate-emulator-v1/data/repeating-climSST-1deg-netCDFs 36 | ``` 37 | The `zarr` format is convenient for ad-hoc analysis. 
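Because the bucket is requester pays, downloads are billed to a Google Cloud project that you supply. As a sketch (replace the placeholder with your own billing project ID), listing the zarr dataset with `gsutil` could look like:
```
gsutil -u YOUR_BILLING_PROJECT ls \
    gs://ai2cm-public-requester-pays/2023-11-29-ai2-climate-emulator-v1/data/repeating-climSST-1deg-zarrs
```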
The netCDF version contains our 38 | train/validation split which was used for training and inference. 39 | 40 | The datasets used in the [ACE2 paper](https://arxiv.org/abs/2411.11268) are available at: 41 | ``` 42 | gs://ai2cm-public-requester-pays/2024-11-13-ai2-climate-emulator-v2-amip/data/c96-1deg-shield/ 43 | gs://ai2cm-public-requester-pays/2024-11-13-ai2-climate-emulator-v2-amip/data/era5-1deg-1940-2022.zarr/ 44 | ``` 45 | 46 | The dataset used in the [ACE2-SOM paper](https://arxiv.org/abs/2412.04418) is available at: 47 | ``` 48 | gs://ai2cm-public-requester-pays/2024-12-05-ai2-climate-emulator-v2-som/SHiELD-SOM-C96 49 | ``` 50 | -------------------------------------------------------------------------------- /analysis-deps.txt: -------------------------------------------------------------------------------- 1 | # these are some packages which are convenient to have installed for ad-hoc analysis 2 | # but which are not requirements of the "fme" package. We do not relist the fme 3 | # dependencies here. 4 | beaker-py 5 | Bottleneck 6 | cartopy>=0.22.0 7 | dask[distributed] 8 | ipywidgets 9 | nc-time-axis 10 | jupyterlab 11 | pyproj<3.7 12 | seaborn 13 | bokeh>=3.1.0 -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | import signal 2 | 3 | import pytest 4 | 5 | 6 | def pytest_addoption(parser): 7 | parser.addoption( 8 | "--fast", 9 | action="store_true", 10 | default=False, 11 | help="Skip slow tests", 12 | ) 13 | parser.addoption( 14 | "--very-fast", 15 | action="store_true", 16 | default=False, 17 | help="Run only very fast tests (< 5 seconds)", 18 | ) 19 | 20 | 21 | @pytest.fixture 22 | def skip_slow(request, very_fast_only): 23 | return very_fast_only or request.config.getoption("--fast") 24 | 25 | 26 | @pytest.fixture 27 | def very_fast_only(request): 28 | return request.config.getoption("--very-fast") 29 | 30 | 31 | class TimeoutException(Exception): 32 | pass 33 | 34 | 35 | def timeout_handler(signum, frame): 36 | raise TimeoutException("Test took too long") 37 | 38 | 39 | @pytest.fixture 40 | def pdb_enabled(request): 41 | return request.config.getoption("--pdb") 42 | 43 | 44 | @pytest.fixture(autouse=True, scope="function") 45 | def enforce_timeout(skip_slow, very_fast_only, pdb_enabled): 46 | if pdb_enabled: 47 | yield # Do not enforce timeout if we are debugging 48 | return 49 | if very_fast_only: 50 | timeout_seconds = 3 51 | elif skip_slow: 52 | timeout_seconds = 30 53 | else: 54 | timeout_seconds = 60 55 | signal.signal(signal.SIGALRM, timeout_handler) 56 | signal.alarm(timeout_seconds) # Set the timeout for the test 57 | try: 58 | yield 59 | finally: 60 | signal.alarm(0) # Disable the alarm after the test completes 61 | 62 | 63 | @pytest.hookimpl(tryfirst=True, hookwrapper=True) 64 | def pytest_runtest_call(item): 65 | try: 66 | yield 67 | except TimeoutException: 68 | pytest.fail("Test failed due to timeout") 69 | -------------------------------------------------------------------------------- /constraints.txt: -------------------------------------------------------------------------------- 1 | torch==2.1.2 # minor version matches torch in Docker image -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:23.08-py3 2 | 3 | ENV FME_DIR=/full-model 4 | ENV DGLBACKEND=pytorch 5 | 6 | # 
Install gcloud- used for monthly netcdf data processing script 7 | # https://cloud.google.com/sdk/docs/install#deb 8 | RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] http://packages.cloud.google.com/apt cloud-sdk main" | \ 9 | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | \ 10 | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - && apt-get update -y && apt-get install google-cloud-cli -y 11 | 12 | # install python deps 13 | COPY fme/requirements.txt /tmp/requirements.txt 14 | RUN python3 -m pip install -r /tmp/requirements.txt 15 | 16 | # copy local code and install 17 | COPY fme ${FME_DIR}/fme 18 | RUN cd $FME_DIR && pip install --no-deps -e fme 19 | 20 | # copy after install so editing scripts does not trigger reinstall 21 | COPY scripts ${FME_DIR}/scripts 22 | -------------------------------------------------------------------------------- /fme/README.md: -------------------------------------------------------------------------------- 1 | [![Docs](https://readthedocs.org/projects/ai2-climate-emulator/badge/?version=latest)](https://ai2-climate-emulator.readthedocs.io/en/latest/) 2 | [![PyPI](https://img.shields.io/pypi/v/fme.svg)](https://pypi.org/project/fme/) 3 | 4 | # FME: Weather/Climate Model Emulation 5 | This package contains code to train and evaluate weather/climate model emulators as seen in 6 | "ACE: A fast, skillful learned global atmospheric model for climate prediction" ([arxiv:2310.02074](https://arxiv.org/abs/2310.02074)) 7 | and "Application of the Ai2 Climate Emulator to E3SMv2's global atmosphere model, with a focus on precipitation fidelity" 8 | ([JGR-ML](https://agupubs.onlinelibrary.wiley.com/doi/full/10.1029/2024JH000136)). 9 | 10 | 11 | ## Installation 12 | 13 | The package can be installed via PyPI using: 14 | 15 | ``` 16 | pip install fme 17 | ``` 18 | 19 | ## Quickstart 20 | 21 | A quickstart guide may be found [here](https://ai2-climate-emulator.readthedocs.io/en/latest/quickstart.html). 22 | 23 | ## Documentation 24 | 25 | See complete documentation [here](https://ai2-climate-emulator.readthedocs.io/en/latest/). 26 | -------------------------------------------------------------------------------- /fme/deploy-requirements.txt: -------------------------------------------------------------------------------- 1 | build 2 | twine -------------------------------------------------------------------------------- /fme/dev-requirements.txt: -------------------------------------------------------------------------------- 1 | pre-commit 2 | pytest -------------------------------------------------------------------------------- /fme/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = fme 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
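# For example, the docs CI job runs `make doctest html`; this rule forwards each of those targets to
# `$(SPHINXBUILD) -M doctest` / `$(SPHINXBUILD) -M html` with SOURCEDIR=. and BUILDDIR=_build.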
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /fme/docs/_static/Ai2_icon_pink_RGB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai2cm/ace/0c36153610a2bd8160d1639670011af9e03b19b2/fme/docs/_static/Ai2_icon_pink_RGB.png -------------------------------------------------------------------------------- /fme/docs/_static/Ai2_icon_pink_padding_RGB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai2cm/ace/0c36153610a2bd8160d1639670011af9e03b19b2/fme/docs/_static/Ai2_icon_pink_padding_RGB.png -------------------------------------------------------------------------------- /fme/docs/_static/custom.css: -------------------------------------------------------------------------------- 1 | body[data-theme="dark"] { 2 | --code-block-background: #202020; 3 | } 4 | 5 | body[data-theme="light"] { 6 | --code-block-background: #f8f9fb; 7 | } 8 | 9 | body[data-theme="auto"] { 10 | --code-block-background: #f8f9fb; 11 | } 12 | 13 | @media (prefers-color-scheme: dark) { 14 | body[data-theme="auto"] { 15 | --code-block-background: #202020; 16 | } 17 | } 18 | 19 | div.highlight pre { 20 | background: var(--code-block-background); 21 | } -------------------------------------------------------------------------------- /fme/docs/api.rst: -------------------------------------------------------------------------------- 1 | .. _API Reference: 2 | 3 | ============= 4 | API Reference 5 | ============= 6 | 7 | fme 8 | === 9 | 10 | .. automodule:: fme 11 | :members: 12 | 13 | fme.ace 14 | ======= 15 | 16 | .. automodule:: fme.ace 17 | :members: 18 | -------------------------------------------------------------------------------- /fme/docs/configs/explicit-indices.yaml: -------------------------------------------------------------------------------- 1 | path: initial_conditions.nc 2 | start_indices: 3 | list: [0, 3, 7] 4 | -------------------------------------------------------------------------------- /fme/docs/configs/inference-ic-indices.yaml: -------------------------------------------------------------------------------- 1 | path: initial_conditions.nc 2 | start_indices: 3 | n_initial_conditions: 3 4 | first: 1 5 | interval: 2 6 | -------------------------------------------------------------------------------- /fme/docs/configs/timestamp-list.yaml: -------------------------------------------------------------------------------- 1 | path: initial_conditions.nc 2 | start_indices: 3 | times: 4 | - "2021-01-01T00:00:00" 5 | - "2021-02-01T00:00:00" 6 | -------------------------------------------------------------------------------- /fme/docs/evaluator-config.yaml: -------------------------------------------------------------------------------- 1 | experiment_dir: evaluator_output 2 | n_forward_steps: 400 # 100 days 3 | forward_steps_in_memory: 50 4 | checkpoint_path: ckpt.tar 5 | logging: 6 | log_to_screen: true 7 | log_to_wandb: false 8 | log_to_file: true 9 | project: ace 10 | entity: your_wandb_entity 11 | loader: 12 | dataset: 13 | data_path: validation 14 | start_indices: 15 | first: 0 16 | n_initial_conditions: 1 17 | num_data_workers: 8 18 | data_writer: 19 | save_prediction_files: false 20 | save_monthly_files: false 21 | -------------------------------------------------------------------------------- /fme/docs/evaluator_config.rst: 
-------------------------------------------------------------------------------- 1 | .. _evaluator-config: 2 | 3 | ================ 4 | Evaluator Config 5 | ================ 6 | 7 | The following is an example configuration for running inference while evaluating against target data. 8 | While you can use absolute paths in the config yamls (we encourage it!), the example uses paths relative to the directory you run the command. 9 | The example assumes you are running in a directory structure like: 10 | 11 | :: 12 | 13 | . 14 | ├── ckpt.tar 15 | └── validation 16 | ├── data1.nc # files can have any name, but must sort into time-sequential order 17 | ├── data2.nc # can have any number of netCDF files 18 | └── ... 19 | 20 | The ``.nc`` files correspond to data files like ``2021010100.nc`` in the `Zenodo repository`_, while ``ckpt.tar`` corresponds to a file like ``ace_ckpt.tar`` in that repository. 21 | 22 | .. _Zenodo repository: https://zenodo.org/doi/10.5281/zenodo.10791086 23 | 24 | .. literalinclude:: evaluator-config.yaml 25 | :language: yaml 26 | :caption: Example YAML Configuration 27 | 28 | .. testcode:: 29 | :hide: 30 | 31 | from fme.ace import InferenceEvaluatorConfig 32 | import yaml 33 | import dacite 34 | 35 | with open('evaluator-config.yaml', 'r') as f: 36 | config_dict = yaml.safe_load(f) 37 | 38 | config = dacite.from_dict( 39 | InferenceEvaluatorConfig, 40 | data=config_dict, 41 | config=dacite.Config(strict=True) 42 | ) 43 | # these paths are used in the documentation on this page 44 | # if they change then update the docs! 45 | assert config.checkpoint_path == "ckpt.tar" 46 | assert config.loader.dataset.data_path == "validation" 47 | print("Loaded successfully") 48 | 49 | .. testoutput:: 50 | :hide: 51 | 52 | Loaded successfully 53 | 54 | We use the :ref:`Builder pattern ` to load this configuration into a multi-level dataclass structure. 55 | The configuration is divided into several sub-configurations, each with its own dataclass. 56 | The top-level configuration is the :class:`fme.ace.InferenceEvaluatorConfig` class. 57 | 58 | .. autoclass:: fme.ace.InferenceEvaluatorConfig 59 | :show-inheritance: 60 | :noindex: 61 | 62 | The sub-configurations are: 63 | 64 | .. autoclass:: fme.ace.LoggingConfig 65 | :show-inheritance: 66 | :noindex: 67 | 68 | .. autoclass:: fme.ace.InferenceDataLoaderConfig 69 | :show-inheritance: 70 | :noindex: 71 | 72 | .. autoclass:: fme.ace.InferenceInitialConditionIndices 73 | :show-inheritance: 74 | :noindex: 75 | 76 | .. autoclass:: fme.ace.ExplicitIndices 77 | :show-inheritance: 78 | :noindex: 79 | 80 | .. autoclass:: fme.ace.TimestampList 81 | :show-inheritance: 82 | :noindex: 83 | 84 | .. autoclass:: fme.ace.XarrayDataConfig 85 | :show-inheritance: 86 | :noindex: 87 | 88 | .. autoclass:: fme.ace.DataWriterConfig 89 | :show-inheritance: 90 | :noindex: 91 | 92 | .. autoclass:: fme.ace.InferenceEvaluatorAggregatorConfig 93 | :show-inheritance: 94 | :noindex: 95 | 96 | .. autoclass:: fme.ace.OceanConfig 97 | :show-inheritance: 98 | :noindex: 99 | -------------------------------------------------------------------------------- /fme/docs/index.rst: -------------------------------------------------------------------------------- 1 | fme: Full Model Emulation 2 | ====================================== 3 | 4 | **fme** ("full model emulation") is a python package for training and running 5 | climate model emulators, such as the Ai2 Climate Emulator. 6 | 7 | .. 
toctree:: 8 | :maxdepth: 1 9 | :caption: Contents: 10 | 11 | installation 12 | quickstart 13 | training_config 14 | inference_config 15 | evaluator_config 16 | builder 17 | modules 18 | api 19 | 20 | Indices and tables 21 | ================== 22 | * :ref:`genindex` 23 | * :ref:`modindex` 24 | * :ref:`search` 25 | -------------------------------------------------------------------------------- /fme/docs/inference-config.yaml: -------------------------------------------------------------------------------- 1 | experiment_dir: inference_output 2 | n_forward_steps: 400 # 100 days 3 | forward_steps_in_memory: 50 4 | checkpoint_path: ace_ckpt.tar 5 | logging: 6 | log_to_screen: true 7 | log_to_wandb: false 8 | log_to_file: true 9 | project: ace 10 | initial_condition: 11 | path: climSST/ic_2021.zarr 12 | start_indices: 13 | n_initial_conditions: 2 14 | first: 0 15 | interval: 3 16 | engine: zarr 17 | forcing_loader: 18 | dataset: 19 | data_path: climSST 20 | file_pattern: forcing_2021.zarr 21 | engine: zarr 22 | n_repeats: 2 # use this to extend the 1-year of forcing data to desired length 23 | num_data_workers: 2 24 | data_writer: 25 | save_prediction_files: false 26 | -------------------------------------------------------------------------------- /fme/docs/installation.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | ============ 4 | Installation 5 | ============ 6 | 7 | All commands here are run from the top-level directory of the repository, unless otherwise stated. 8 | 9 | This is unsupported, pre-alpha software: use at your own risk! We are actively developing this software 10 | and will be making breaking changes to the API. 11 | 12 | PyPI 13 | ---- 14 | 15 | To install the latest release directly from PyPI, use: 16 | 17 | .. code-block:: shell 18 | 19 | pip install fme 20 | 21 | Conda 22 | ----- 23 | 24 | For convenience, we provide an easy way to create a conda environment with `fme` installed. 25 | First, clone the repository: 26 | 27 | .. code-block:: shell 28 | 29 | git clone git@github.com:ai2cm/ace.git 30 | 31 | A make target is available to build a conda environment: 32 | 33 | .. code-block:: shell 34 | 35 | make create_environment 36 | 37 | This will create an environment named ``fme``. If you would like a different name, set the ENVIRONMENT_NAME variable: 38 | 39 | .. code-block:: shell 40 | 41 | ENVIRONMENT_NAME= make create_environment 42 | 43 | Development 44 | ----------- 45 | 46 | To install directly from source for development, clone the repository: 47 | 48 | .. code-block:: shell 49 | 50 | git clone git@github.com:ai2cm/ace.git 51 | 52 | Once downloaded, you can install the sources in development mode (``-e`` flag) with the extra dependencies for development (``[dev]``) and versions pinned to the ones we use in development (``-c constraints.txt``) with the following command: 53 | 54 | .. code-block:: shell 55 | 56 | pip install -c constraints.txt -e fme[dev] 57 | 58 | Docker 59 | ------ 60 | 61 | A make target is available to build the Docker image: 62 | 63 | .. code-block:: shell 64 | 65 | make build_docker_image 66 | -------------------------------------------------------------------------------- /fme/docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 
11 | set BUILDDIR=_build 12 | set SPHINXPROJ=fme 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /fme/docs/modules.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | Modules 3 | ======= 4 | 5 | ACE's code uses a module registry system to allow different machine learning architectures to plug into the framework. 6 | This is managed by the :class:`fme.ace.ModuleSelector` configuration class, which is used to select a module type and version. 7 | 8 | .. autoclass:: fme.ace.ModuleSelector 9 | :members: 10 | :undoc-members: 11 | :show-inheritance: 12 | :noindex: 13 | 14 | The following module types are available: 15 | 16 | .. include:: available_modules.rst 17 | 18 | .. autofunction:: fme.core.registry.ModuleSelector.get_available_types 19 | 20 | The following module builders are available: 21 | 22 | .. autoclass:: fme.ace.SphericalFourierNeuralOperatorBuilder 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | :noindex: 27 | 28 | .. autoclass:: fme.ace.SFNO_V0_1_0 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | :noindex: 33 | 34 | .. autoclass:: fme.ace.HEALPixRecUNetBuilder 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | :noindex: -------------------------------------------------------------------------------- /fme/docs/requirements.txt: -------------------------------------------------------------------------------- 1 | furo==2024.04.27 2 | sphinx==7.0.0 3 | sphinx_autodoc_typehints -------------------------------------------------------------------------------- /fme/fme/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "2024.12.0" 2 | 3 | import torch_harmonics 4 | 5 | from . 
import ace 6 | from .core import Packer, StandardNormalizer, get_device, get_normalizer, using_gpu 7 | from .core.metrics import ( 8 | gradient_magnitude, 9 | gradient_magnitude_percent_diff, 10 | rmse_of_time_mean, 11 | root_mean_squared_error, 12 | spherical_area_weights, 13 | time_and_global_mean_bias, 14 | weighted_mean, 15 | weighted_mean_bias, 16 | weighted_mean_gradient_magnitude, 17 | weighted_std, 18 | ) 19 | 20 | APPLY_SHT_FIX = True 21 | 22 | if APPLY_SHT_FIX: 23 | from .sht_fix import InverseRealSHT, RealSHT 24 | 25 | __all__ = [ 26 | "spherical_area_weights", 27 | "weighted_mean", 28 | "weighted_mean_bias", 29 | "root_mean_squared_error", 30 | "gradient_magnitude", 31 | "weighted_mean_gradient_magnitude", 32 | "rmse_of_time_mean", 33 | "time_and_global_mean_bias", 34 | "gradient_magnitude_percent_diff", 35 | "get_device", 36 | "get_normalizer", 37 | "Packer", 38 | "StandardNormalizer", 39 | "using_gpu", 40 | ] 41 | 42 | 43 | if APPLY_SHT_FIX: 44 | torch_harmonics.RealSHT = RealSHT 45 | torch_harmonics.InverseRealSHT = InverseRealSHT 46 | -------------------------------------------------------------------------------- /fme/fme/ace/LICENSE: -------------------------------------------------------------------------------- 1 | #BSD 3-Clause License 2 | # 3 | #Copyright (c) 2022, FourCastNet authors 4 | #All rights reserved. 5 | # 6 | #Redistribution and use in source and binary forms, with or without 7 | #modification, are permitted provided that the following conditions are met: 8 | # 9 | #1. Redistributions of source code must retain the above copyright notice, this 10 | # list of conditions and the following disclaimer. 11 | # 12 | #2. Redistributions in binary form must reproduce the above copyright notice, 13 | # this list of conditions and the following disclaimer in the documentation 14 | # and/or other materials provided with the distribution. 15 | # 16 | #3. Neither the name of the copyright holder nor the names of its 17 | # contributors may be used to endorse or promote products derived from 18 | # this software without specific prior written permission. 19 | # 20 | #THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | #AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | #IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | #DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | #FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | #DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | #SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | #CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | #OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | #OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | # 31 | #The code was authored by the following people: 32 | # 33 | #Jaideep Pathak - NVIDIA Corporation 34 | #Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory 35 | #Peter Harrington - NERSC, Lawrence Berkeley National Laboratory 36 | #Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory 37 | #Ashesh Chattopadhyay - Rice University 38 | #Morteza Mardani - NVIDIA Corporation 39 | #Thorsten Kurth - NVIDIA Corporation 40 | #David Hall - NVIDIA Corporation 41 | #Zongyi Li - California Institute of Technology, NVIDIA Corporation 42 | #Kamyar Azizzadenesheli - Purdue University 43 | #Pedram Hassanzadeh - Rice University 44 | #Karthik Kashinath - NVIDIA Corporation 45 | #Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation 46 | 47 | -------------------------------------------------------------------------------- /fme/fme/ace/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from fme.ace.data_loading.inference import ( 4 | ExplicitIndices, 5 | InferenceInitialConditionIndices, 6 | TimestampList, 7 | ) 8 | from fme.ace.data_loading.perturbation import ( 9 | ConstantConfig, 10 | GreensFunctionConfig, 11 | PerturbationSelector, 12 | SSTPerturbation, 13 | ) 14 | from fme.ace.inference.data_writer.time_coarsen import TimeCoarsenConfig 15 | from fme.ace.inference.evaluator import ( 16 | DataWriterConfig, 17 | InferenceDataLoaderConfig, 18 | InferenceEvaluatorAggregatorConfig, 19 | InferenceEvaluatorConfig, 20 | OceanConfig, 21 | run_evaluator_from_config, 22 | ) 23 | from fme.ace.inference.inference import ( 24 | ForcingDataLoaderConfig, 25 | InferenceAggregatorConfig, 26 | InferenceConfig, 27 | InitialConditionConfig, 28 | run_inference_from_config, 29 | ) 30 | from fme.ace.models.healpix.healpix_activations import ( 31 | CappedGELUConfig, 32 | DownsamplingBlockConfig, 33 | ) 34 | from fme.ace.models.healpix.healpix_blocks import ConvBlockConfig, RecurrentBlockConfig 35 | from fme.ace.registry.hpx import ( 36 | HEALPixRecUNetBuilder, 37 | UNetDecoderConfig, 38 | UNetEncoderConfig, 39 | ) 40 | from fme.ace.registry.sfno import SFNO_V0_1_0, SphericalFourierNeuralOperatorBuilder 41 | from fme.core.corrector.corrector import CorrectorConfig 42 | from fme.core.corrector.ocean import OceanCorrectorConfig 43 | from fme.core.dataset.config import TimeSlice, XarrayDataConfig 44 | from fme.core.gridded_ops import GriddedOperations 45 | from fme.core.loss import WeightedMappingLossConfig 46 | from fme.core.normalizer import NormalizationConfig 47 | from fme.core.ocean import SlabOceanConfig 48 | from fme.core.optimization import SchedulerConfig 49 | from fme.core.parameter_init import FrozenParameterConfig, ParameterInitializationConfig 50 | from fme.core.registry.corrector import CorrectorSelector 51 | from fme.core.registry.module import ModuleSelector 52 | from fme.core.typing_ import Slice 53 | 54 | from .train.train import run_train 55 | from .train.train_config import ( 56 | CopyWeightsConfig, 57 | DataLoaderConfig, 58 | EMAConfig, 59 | ExistingStepperConfig, 60 | InlineInferenceConfig, 61 | LoggingConfig, 62 | OptimizationConfig, 63 | SingleModuleStepperConfig, 64 | TrainConfig, 65 | ) 66 | 67 | # Get all the names defined in the current module 68 | module = sys.modules[__name__] 69 | __all__ = [name for name in dir(module) if not name.startswith("_")] 70 | del sys, module 71 | -------------------------------------------------------------------------------- 
/fme/fme/ace/aggregator/__init__.py: -------------------------------------------------------------------------------- 1 | from .inference import InferenceEvaluatorAggregator, InferenceEvaluatorAggregatorConfig 2 | from .null import NullAggregator 3 | from .one_step import OneStepAggregator 4 | from .train import TrainAggregator 5 | -------------------------------------------------------------------------------- /fme/fme/ace/aggregator/inference/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import ( 2 | InferenceAggregator, 3 | InferenceAggregatorConfig, 4 | InferenceEvaluatorAggregator, 5 | InferenceEvaluatorAggregatorConfig, 6 | ) 7 | -------------------------------------------------------------------------------- /fme/fme/ace/aggregator/inference/enso/__init__.py: -------------------------------------------------------------------------------- 1 | from .enso import EnsoCoefficientEvaluatorAggregator 2 | -------------------------------------------------------------------------------- /fme/fme/ace/aggregator/inference/histogram.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import xarray as xr 3 | 4 | from fme.core.histogram import ComparedDynamicHistograms 5 | from fme.core.typing_ import TensorMapping 6 | 7 | 8 | class HistogramAggregator: 9 | def __init__(self): 10 | self._histograms = ComparedDynamicHistograms(n_bins=200, percentiles=[99.9999]) 11 | 12 | @torch.no_grad() 13 | def record_batch( 14 | self, 15 | target_data: TensorMapping, 16 | gen_data: TensorMapping, 17 | target_data_norm: TensorMapping, 18 | gen_data_norm: TensorMapping, 19 | i_time_start: int = 0, 20 | ): 21 | self._histograms.record_batch(target_data, gen_data) 22 | 23 | @torch.no_grad() 24 | def get_logs(self, label: str): 25 | logs = self._histograms.get_wandb() 26 | if label != "": 27 | logs = {f"{label}/{k}": v for k, v in logs.items()} 28 | return logs 29 | 30 | @torch.no_grad() 31 | def get_dataset(self) -> xr.Dataset: 32 | return self._histograms.get_dataset() 33 | -------------------------------------------------------------------------------- /fme/fme/ace/aggregator/inference/test_distributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from fme.ace.aggregator.inference.reduced import MeanAggregator 4 | from fme.ace.aggregator.inference.time_mean import TimeMeanEvaluatorAggregator 5 | from fme.core.device import get_device 6 | from fme.core.gridded_ops import LatLonOperations 7 | from fme.core.testing import mock_distributed 8 | 9 | 10 | def test_mean_metrics_call_distributed(): 11 | """ 12 | All mean metrics should be reduced across processes using Distributed. 13 | 14 | This tests that functionality by modifying the Distributed singleton. 
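    mock_distributed patches the singleton so that every reduction returns the supplied
    fill value (-1.0 here); seeing -1 in the aggregator's logged series therefore confirms
    the means were routed through Distributed.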
15 | """ 16 | with mock_distributed(-1.0): 17 | data_a = torch.ones([2, 3, 4, 4], device=get_device()) 18 | area_weights = torch.ones(1).to(get_device()) 19 | agg = MeanAggregator( 20 | LatLonOperations(area_weights), target="denorm", n_timesteps=3 21 | ) 22 | sample_data = {"a": data_a} 23 | agg.record_batch(sample_data, sample_data, sample_data, sample_data) 24 | logs = agg.get_logs(label="metrics") 25 | table = logs["metrics/series"] 26 | # assert all data past the first column in the WandB table is -1 27 | assert all([all(item == -1 for item in row[1][1:]) for row in table.iterrows()]) 28 | 29 | 30 | def test_time_mean_metrics_call_distributed(): 31 | """ 32 | All time-mean metrics should be reduced across processes using Distributed. 33 | 34 | This tests that functionality by modifying the Distributed singleton. 35 | """ 36 | torch.manual_seed(0) 37 | with mock_distributed(0.0) as mock: 38 | area_weights = torch.ones(1).to(get_device()) 39 | agg = TimeMeanEvaluatorAggregator( 40 | LatLonOperations(area_weights), horizontal_dims=["lat", "lon"] 41 | ) 42 | target_data = {"a": torch.ones([2, 3, 4, 4], device=get_device())} 43 | gen_data = {"a": torch.randn([2, 3, 4, 4], device=get_device())} 44 | agg.record_batch( 45 | target_data=target_data, 46 | gen_data=gen_data, 47 | target_data_norm=target_data, 48 | gen_data_norm=gen_data, 49 | ) 50 | logs = agg.get_logs(label="metrics") 51 | # the reduction happens on the time-means, so the gen and target data should 52 | # be filled identically and all errors will be zero, even though we gave them 53 | # different data above. 54 | assert logs["metrics/rmse/a"] == 0.0 55 | assert logs["metrics/bias/a"] == 0.0 56 | assert mock.reduce_called 57 | -------------------------------------------------------------------------------- /fme/fme/ace/aggregator/inference/test_reduced.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import fme 5 | from fme.ace.aggregator.inference.reduced import ( 6 | AreaWeightedReducedMetric, 7 | SingleTargetMeanAggregator, 8 | ) 9 | from fme.core.device import get_device 10 | from fme.core.gridded_ops import LatLonOperations 11 | 12 | 13 | def test_area_weighted_reduced_metric_includes_later_window_starts(): 14 | """ 15 | The area weighted reduced metric should assume that the start 16 | of a window is always recorded, as we clip it before calling. 17 | """ 18 | 19 | def compute_metric(truth, predicted, weights=None, dim=()): 20 | return truth.mean(dim=(2, 3)) 21 | 22 | metric = AreaWeightedReducedMetric( 23 | device=get_device(), 24 | compute_metric=compute_metric, 25 | n_timesteps=7, 26 | ) 27 | 28 | data = torch.ones([2, 3, 4, 4], device=get_device()) 29 | metric.record(data, data, 0) 30 | data[:, 0, :, :] = np.nan 31 | metric.record(data, data, 2) 32 | metric.record(data, data, 4) 33 | result = metric.get() 34 | result = result.cpu().numpy() 35 | # assert tensor is all ones 36 | assert np.sum(np.isnan(result)) == 2 37 | assert np.isnan(result[2]) 38 | assert np.isnan(result[4]) 39 | 40 | 41 | def test_single_target_mean_aggregator(): 42 | """ 43 | The area weighted reduced metric should assume that the start 44 | of a window is always recorded, as we clip it before calling. 
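    In this particular test, the check is that the per-timestep weighted means and standard
    deviations in SingleTargetMeanAggregator's output dataset match values computed directly
    from the raw data.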
45 | """ 46 | n_sample = 10 47 | n_time_per_window = 22 48 | n_window = 3 49 | nx = 2 50 | ny = 2 51 | area_weights = torch.ones(ny).to(fme.get_device()) 52 | torch.manual_seed(0) 53 | 54 | agg = SingleTargetMeanAggregator( 55 | gridded_operations=LatLonOperations(area_weights), 56 | n_timesteps=n_time_per_window * n_window, 57 | ) 58 | data_a = torch.randn(n_sample, n_time_per_window, nx, ny, device=get_device()) 59 | for i in range(n_window): 60 | data = {"a": data_a[:, i * n_time_per_window : (i + 1) * n_time_per_window]} 61 | agg.record_batch(data=data, i_time_start=i * n_time_per_window) 62 | 63 | logs = agg.get_logs(label="test") 64 | assert "test/series" in logs 65 | ds = agg.get_dataset() 66 | for i in range(1, data_a.shape[1]): 67 | raw_variable = data_a[:, i] 68 | raw_global_mean = raw_variable.mean().cpu().numpy() 69 | raw_global_std = ( 70 | raw_variable.std(dim=(1, 2), correction=0).mean().cpu().numpy() 71 | ) # metrics are mean over batch 72 | np.testing.assert_allclose( 73 | raw_global_std, 74 | ds["weighted_std_gen-a"].isel(forecast_step=i).values.item(), 75 | rtol=1e-5, 76 | ) 77 | np.testing.assert_allclose( 78 | raw_global_mean, 79 | ds["weighted_mean_gen-a"].isel(forecast_step=i).values.item(), 80 | rtol=1e-5, 81 | ) 82 | -------------------------------------------------------------------------------- /fme/fme/ace/aggregator/inference/test_seasonal.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import cftime 4 | import numpy as np 5 | import torch 6 | import xarray as xr 7 | 8 | import fme 9 | from fme.ace.aggregator.inference.seasonal import SeasonalAggregator 10 | from fme.core.device import get_device 11 | from fme.core.gridded_ops import LatLonOperations 12 | 13 | 14 | def get_zero_time(shape, dims): 15 | return xr.DataArray(np.zeros(shape, dtype="datetime64[ns]"), dims=dims) 16 | 17 | 18 | def test_seasonal_aggregator(): 19 | n_lat = 16 20 | n_lon = 32 21 | # need to have two actual full years of data for plotting to get exercised 22 | n_sample = 2 23 | n_time_step = 8 24 | n_time = int(365 / 10 * 2 / n_time_step + 1) * n_time_step 25 | area_weights = torch.ones(n_lat, n_lon).to(fme.get_device()) 26 | agg = SeasonalAggregator( 27 | LatLonOperations(area_weights), 28 | ) 29 | target_data = { 30 | "a": torch.randn(n_sample, n_time, n_lat, n_lon, device=get_device()) 31 | } 32 | gen_data = {"a": torch.randn(n_sample, n_time, n_lat, n_lon, device=get_device())} 33 | 34 | def time_select(tensor_mapping, start, stop): 35 | return { 36 | name: value[:, start:stop, ...] 
for name, value in tensor_mapping.items() 37 | } 38 | 39 | time = get_zero_time(shape=[n_sample, n_time], dims=["sample", "time"]) 40 | time_1d = [ 41 | cftime.DatetimeProlepticGregorian(2000, 1, 1) + i * datetime.timedelta(days=10) 42 | for i in range(n_time) 43 | ] 44 | time = xr.DataArray([time_1d for _ in range(n_sample)], dims=["sample", "time"]) 45 | for i in range(0, n_time, n_time_step): 46 | agg.record_batch( 47 | time.isel(time=range(i, i + n_time_step)), 48 | time_select(target_data, i, i + n_time_step), 49 | time_select(gen_data, i, i + n_time_step), 50 | ) 51 | logs = agg.get_logs(label="test") 52 | for name, value in logs.items(): 53 | if isinstance(value, (float, np.ndarray)): 54 | assert not np.isnan(value), f"{name} is nan" 55 | -------------------------------------------------------------------------------- /fme/fme/ace/aggregator/inference/test_spectrum.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import torch_harmonics 4 | 5 | import fme 6 | from fme.ace.aggregator.inference.spectrum import ( 7 | PairedSphericalPowerSpectrumAggregator, 8 | SphericalPowerSpectrumAggregator, 9 | ) 10 | from fme.core.metrics import spherical_power_spectrum 11 | 12 | 13 | def test_spherical_power_spectrum_aggregator(): 14 | nlat = 8 15 | nlon = 16 16 | grid = "legendre-gauss" 17 | agg = SphericalPowerSpectrumAggregator(nlat, nlon, grid=grid) 18 | data = {"a": torch.randn(2, 2, nlat, nlon, device=fme.get_device())} 19 | data2 = {"a": torch.randn(2, 3, nlat, nlon, device=fme.get_device())} 20 | agg.record_batch(data) 21 | agg.record_batch(data2) 22 | result = agg.get_mean() 23 | assert "a" in result 24 | assert result["a"].shape == (nlat,) 25 | 26 | sht = torch_harmonics.RealSHT(nlat, nlon, grid=grid) 27 | data_concat = torch.cat([data["a"], data2["a"]], dim=1) 28 | expected_value = torch.mean(spherical_power_spectrum(data_concat, sht), dim=(0, 1)) 29 | torch.testing.assert_close(result["a"], expected_value) 30 | 31 | 32 | def test_paired_spherical_power_spectrum_aggregator(): 33 | nlat = 8 34 | nlon = 16 35 | agg = PairedSphericalPowerSpectrumAggregator(nlat, nlon) 36 | data = {"a": torch.randn(2, 3, nlat, nlon, device=fme.get_device())} 37 | agg.record_batch(data, data, None, None) 38 | result = agg.get_logs("spectrum") 39 | assert isinstance(result["spectrum/a"], plt.Figure) 40 | -------------------------------------------------------------------------------- /fme/fme/ace/aggregator/null.py: -------------------------------------------------------------------------------- 1 | from fme.core.typing_ import TensorMapping 2 | 3 | 4 | class NullAggregator: 5 | """ 6 | An aggregator that does nothing. Null object pattern. 7 | """ 8 | 9 | def __init__(self): 10 | pass 11 | 12 | def record_batch( 13 | self, 14 | loss: float, 15 | target_data: TensorMapping, 16 | gen_data: TensorMapping, 17 | target_data_norm: TensorMapping, 18 | gen_data_norm: TensorMapping, 19 | i_time_start: int = 0, 20 | ): 21 | pass 22 | 23 | def get_logs(self, label: str): 24 | """ 25 | Returns logs as can be reported to WandB. 26 | 27 | Args: 28 | label: Label to prepend to all log keys. 
29 | """ 30 | return {} 31 | -------------------------------------------------------------------------------- /fme/fme/ace/aggregator/one_step/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import OneStepAggregator 2 | -------------------------------------------------------------------------------- /fme/fme/ace/aggregator/one_step/reduced_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains code for computing metrics of single variables on batches of data, 3 | and aggregating them into a single metric value. The functions here mainly exist 4 | to turn metric functions that may have different APIs into a common API, 5 | so that they can be iterated over and called in the same way in a loop. 6 | """ 7 | 8 | from typing import Protocol 9 | 10 | import torch 11 | 12 | 13 | class AreaWeightedFunction(Protocol): 14 | """ 15 | A function that computes a metric on the true and predicted values, 16 | weighted by area. 17 | """ 18 | 19 | def __call__( 20 | self, 21 | truth: torch.Tensor, 22 | predicted: torch.Tensor, 23 | ) -> torch.Tensor: ... 24 | 25 | 26 | class ReducedMetric(Protocol): 27 | """Used to record a metric value on batches of data (potentially out-of-memory) 28 | and then get the total metric at the end. 29 | """ 30 | 31 | def record(self, target: torch.Tensor, gen: torch.Tensor): 32 | """ 33 | Update metric for a batch of data. 34 | """ 35 | ... 36 | 37 | def get(self) -> torch.Tensor: 38 | """ 39 | Get the total metric value, not divided by number of recorded batches. 40 | """ 41 | ... 42 | 43 | 44 | class AreaWeightedReducedMetric: 45 | """ 46 | A wrapper around an area-weighted metric function. 47 | """ 48 | 49 | def __init__( 50 | self, 51 | device: torch.device, 52 | compute_metric: AreaWeightedFunction, 53 | ): 54 | self._compute_metric = compute_metric 55 | self._total = None 56 | self._device = device 57 | 58 | def record(self, target: torch.Tensor, gen: torch.Tensor): 59 | """Add a batch of data to the metric. 60 | 61 | Args: 62 | target: Target data. Should have shape [batch, time, height, width]. 63 | gen: Generated data. Should have shape [batch, time, height, width]. 64 | """ 65 | new_value = self._compute_metric(target, gen).mean(dim=0) 66 | if self._total is None: 67 | self._total = torch.zeros_like(new_value, device=self._device) 68 | self._total += new_value 69 | 70 | def get(self) -> torch.Tensor: 71 | """Returns the metric.""" 72 | return self._total 73 | -------------------------------------------------------------------------------- /fme/fme/ace/aggregator/train.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | 5 | from fme.ace.stepper import TrainOutput 6 | from fme.core.device import get_device 7 | from fme.core.distributed import Distributed 8 | from fme.core.generics.aggregator import AggregatorABC 9 | 10 | 11 | class TrainAggregator(AggregatorABC[TrainOutput]): 12 | """ 13 | Aggregates statistics for the first timestep. 14 | 15 | To use, call `record_batch` on the results of each batch, then call 16 | `get_logs` to get a dictionary of statistics when you're done. 
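    A minimal usage sketch (here ``stepper``, ``optimization`` and ``train_data``
    stand in for objects constructed elsewhere in the training loop):

        aggregator = TrainAggregator()
        for batch in train_data:
            aggregator.record_batch(stepper.train_on_batch(batch, optimization))
        logs = aggregator.get_logs(label="train")  # {"train/mean/loss": ...}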
17 |     """
18 | 
19 |     def __init__(self):
20 |         self._n_batches = 0
21 |         self._loss = torch.tensor(0.0, device=get_device())
22 | 
23 |     @torch.no_grad()
24 |     def record_batch(self, batch: TrainOutput):
25 |         self._loss += batch.metrics["loss"]
26 |         self._n_batches += 1
27 | 
28 |     @torch.no_grad()
29 |     def get_logs(self, label: str) -> Dict[str, torch.Tensor]:
30 |         """
31 |         Returns logs as can be reported to WandB.
32 | 
33 |         Args:
34 |             label: Label to prepend to all log keys.
35 |         """
36 |         logs = {f"{label}/mean/loss": self._loss / self._n_batches}
37 |         dist = Distributed.get_instance()
38 |         for key in sorted(logs.keys()):
39 |             logs[key] = float(dist.reduce_mean(logs[key].detach()).cpu().numpy())
40 |         return logs
41 | 
--------------------------------------------------------------------------------
/fme/fme/ace/data_loading/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai2cm/ace/0c36153610a2bd8160d1639670011af9e03b19b2/fme/fme/ace/data_loading/__init__.py
--------------------------------------------------------------------------------
/fme/fme/ace/data_loading/config.py:
--------------------------------------------------------------------------------
 1 | import dataclasses
 2 | from typing import Optional, Sequence
 3 | 
 4 | from fme.core.dataset.config import XarrayDataConfig
 5 | from fme.core.distributed import Distributed
 6 | 
 7 | 
 8 | @dataclasses.dataclass
 9 | class DataLoaderConfig:
10 |     """
11 |     Parameters:
12 |         dataset: A sequence of configurations each defining a dataset
13 |             to be loaded. This sequence of datasets will be concatenated.
14 |         batch_size: Number of samples per batch.
15 |         num_data_workers: Number of parallel workers to use for data loading.
16 |         prefetch_factor: how many batches a single data worker will attempt to
17 |             hold in host memory at a given time.
18 |         strict_ensemble: Whether to enforce that the datasets to be concatenated
19 |             have the same dimensions and coordinates.
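    A minimal construction sketch (the dataset entry is a placeholder; see
    XarrayDataConfig for its actual fields):

        config = DataLoaderConfig(
            dataset=[XarrayDataConfig(...)],
            batch_size=16,
            num_data_workers=4,
        )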
20 | """ 21 | 22 | dataset: Sequence[XarrayDataConfig] 23 | batch_size: int 24 | num_data_workers: int = 0 25 | prefetch_factor: Optional[int] = None 26 | strict_ensemble: bool = True 27 | 28 | def __post_init__(self): 29 | dist = Distributed.get_instance() 30 | if self.batch_size % dist.world_size != 0: 31 | raise ValueError( 32 | "batch_size must be divisible by the number of parallel " 33 | f"workers, got {self.batch_size} and {dist.world_size}" 34 | ) 35 | -------------------------------------------------------------------------------- /fme/fme/ace/data_loading/test_data_loading_config.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | 7 | from fme.core.dataset.config import RepeatedInterval 8 | 9 | 10 | def test_repeated_interval_int(): 11 | interval = RepeatedInterval(interval_length=3, block_length=6, start=0) 12 | mask = interval.get_boolean_mask(length=18) 13 | expected_mask = np.array([True, True, True, False, False, False] * 3) 14 | np.testing.assert_array_equal(mask, expected_mask) 15 | 16 | 17 | def test_repeated_interval_str(): 18 | interval = RepeatedInterval(interval_length="1d", block_length="7d", start="2d") 19 | mask = interval.get_boolean_mask(length=21, timestep=timedelta(days=1)) 20 | expected_mask = np.array([False, False, True, False, False, False, False] * 3) 21 | np.testing.assert_array_equal(mask, expected_mask) 22 | 23 | 24 | def test_repeated_interval_mixed_types(): 25 | with pytest.raises(ValueError): 26 | RepeatedInterval(interval_length=3, block_length="6d", start=0) 27 | 28 | 29 | @pytest.mark.parametrize("interval, block, start", [(4, 6, 3), ("2d", "3d", "2d")]) 30 | def test_repeated_interval_invalid_interval_start(interval, block, start): 31 | """start + interval exceeds length of block""" 32 | interval = RepeatedInterval( 33 | interval_length=interval, block_length=block, start=start 34 | ) 35 | with pytest.raises(ValueError): 36 | interval.get_boolean_mask(length=18, timestep=timedelta(days=1)) 37 | 38 | 39 | def test_repeated_interval_zero_length(): 40 | interval = RepeatedInterval(interval_length=0, block_length=6, start=0) 41 | mask = interval.get_boolean_mask(length=18) 42 | expected_mask = np.array([False] * 18) 43 | np.testing.assert_array_equal(mask, expected_mask) 44 | 45 | 46 | def test_repeated_interval_partial_block(): 47 | interval = RepeatedInterval(interval_length=3, block_length=6, start=0) 48 | mask = interval.get_boolean_mask(length=20) 49 | expected_mask = np.array([True, True, True, False, False, False] * 3 + [True, True]) 50 | np.testing.assert_array_equal(mask, expected_mask) 51 | 52 | 53 | def test_repeated_interval_no_timestep_fails_for_timedelta_lengths(): 54 | interval = RepeatedInterval(interval_length="1d", block_length="7d", start="0d") 55 | with pytest.raises(ValueError): 56 | interval.get_boolean_mask(length=20) 57 | 58 | 59 | @pytest.mark.parametrize("timestep", ["2h", "150m", "5h", "12h"]) 60 | def test_invalid_timesteps(timestep): 61 | """ 62 | Test that timesteps that don't evenly divide into some or all 63 | arguments raise a ValueError 64 | """ 65 | timestep = pd.to_timedelta(timestep) 66 | with pytest.raises(ValueError): 67 | RepeatedInterval( 68 | interval_length="5h", start="4h", block_length="10h" 69 | ).get_boolean_mask(length=18, timestep=timestep) 70 | -------------------------------------------------------------------------------- 
/fme/fme/ace/data_loading/test_perturbation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import fme 4 | from fme.ace.data_loading.perturbation import ( 5 | ConstantConfig, 6 | GreensFunctionConfig, 7 | PerturbationSelector, 8 | ) 9 | 10 | 11 | def test_constant_perturbation_config(): 12 | selector = PerturbationSelector( 13 | type="constant", 14 | config={"amplitude": 1.0}, 15 | ) 16 | perturbation = selector.build() 17 | assert isinstance(perturbation, ConstantConfig) 18 | assert perturbation.amplitude == 1.0 19 | nx, ny = 5, 5 20 | lat = torch.arange(nx, device=fme.get_device()) 21 | lon = torch.arange(ny, device=fme.get_device()) 22 | lats, lons = torch.meshgrid(lat, lon, indexing="ij") 23 | ocean_fraction = torch.ones(nx, ny, device=fme.get_device()) 24 | data = torch.ones(nx, ny, device=fme.get_device()) 25 | expected = 2.0 * torch.ones(nx, ny, device=fme.get_device()) 26 | perturbation.apply_perturbation(data, lats, lons, ocean_fraction) 27 | torch.testing.assert_close(data, expected) 28 | 29 | 30 | def test_green_function_perturbation_config(): 31 | selector = PerturbationSelector( 32 | type="greens_function", 33 | config={ 34 | "amplitude": 1.0, 35 | "lat_center": 0.0, 36 | "lon_center": 0.0, 37 | "lat_width": 10.0, 38 | "lon_width": 10.0, 39 | }, 40 | ) 41 | perturbation = selector.build() 42 | assert isinstance(perturbation, GreensFunctionConfig) 43 | assert perturbation.amplitude == 1.0 44 | assert perturbation.lat_center == 0.0 45 | assert perturbation.lon_center == 0.0 46 | assert perturbation.lat_width == 10.0 47 | assert perturbation.lon_width == 10.0 48 | nx, ny = 5, 5 49 | lat = torch.arange(nx, device=fme.get_device()) 50 | lon = torch.arange(ny, device=fme.get_device()) 51 | lats, lons = torch.meshgrid(lat, lon, indexing="ij") 52 | ocean_fraction = torch.ones(nx, ny, device=fme.get_device()) 53 | data = torch.ones(nx, ny, device=fme.get_device()) 54 | perturbation.apply_perturbation(data, lats, lons, ocean_fraction) 55 | -------------------------------------------------------------------------------- /fme/fme/ace/evaluator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from fme.ace.inference.evaluator import main 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("yaml_config", type=str) 8 | 9 | args = parser.parse_args() 10 | 11 | main( 12 | yaml_config=args.yaml_config, 13 | ) 14 | -------------------------------------------------------------------------------- /fme/fme/ace/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai2cm/ace/0c36153610a2bd8160d1639670011af9e03b19b2/fme/fme/ace/inference/__init__.py -------------------------------------------------------------------------------- /fme/fme/ace/inference/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from .inference import main 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("yaml_config", type=str) 7 | parser.add_argument( 8 | "--segments", 9 | type=int, 10 | default=None, 11 | help="If provided, number of times to repeat the inference in time, saving each " 12 | "segment in a separate folder labeled as 'segment_0000', 'segment_0001' etc. 
" 13 | "WARNING: this feature is experimental and its API is subject to change.", 14 | ) 15 | args = parser.parse_args() 16 | main(yaml_config=args.yaml_config, segments=args.segments) 17 | -------------------------------------------------------------------------------- /fme/fme/ace/inference/data_writer/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import DataWriter, DataWriterConfig, PairedDataWriter 2 | -------------------------------------------------------------------------------- /fme/fme/ace/inference/data_writer/utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Iterable, Optional, Set, TypeVar 3 | 4 | T = TypeVar("T") 5 | 6 | 7 | def get_all_names( 8 | *data_varnames: Iterable[T], allowlist: Optional[Iterable[T]] = None 9 | ) -> Set[T]: 10 | """ 11 | Returns all variable names from lists of variable names, optionally 12 | filtering by an allowlist. 13 | """ 14 | variables: Set[T] = set() 15 | for varnames in data_varnames: 16 | variables = variables.union(set(varnames)) 17 | if allowlist is None: 18 | return variables 19 | else: 20 | return variables.intersection(set(allowlist)) 21 | 22 | 23 | @dataclass 24 | class DimInfo: 25 | name: str 26 | index: int 27 | 28 | 29 | DIM_INFO_LATLON = [ 30 | DimInfo(name="lat", index=-2), 31 | DimInfo(name="lon", index=-1), 32 | ] 33 | 34 | DIM_INFO_HEALPIX = [ 35 | DimInfo(name="face", index=-3), 36 | DimInfo(name="height", index=-2), 37 | DimInfo(name="width", index=-1), 38 | ] 39 | -------------------------------------------------------------------------------- /fme/fme/ace/inference/data_writer/video.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Dict, Mapping 3 | 4 | import numpy as np 5 | import torch 6 | import xarray as xr 7 | 8 | from fme.ace.aggregator.inference.video import VideoAggregator 9 | from fme.core.dataset.data_typing import VariableMetadata 10 | 11 | 12 | class PairedVideoDataWriter: 13 | """ 14 | Write [time, lat, lon] metric data to a netCDF file. 15 | """ 16 | 17 | def __init__( 18 | self, 19 | path: str, 20 | n_timesteps: int, 21 | variable_metadata: Mapping[str, VariableMetadata], 22 | coords: Mapping[str, np.ndarray], 23 | ): 24 | """ 25 | Args: 26 | path: Directory within which to write the file. 27 | n_samples: Number of samples to write to the file. 28 | n_timesteps: Number of timesteps to write to the file. 29 | variable_metadata: Metadata for each variable to be written to the file. 30 | coords: Coordinate data to be written to the file. 31 | """ 32 | self.path = path 33 | self._metrics_filename = str( 34 | Path(path) / "reduced_autoregressive_predictions.nc" 35 | ) 36 | self.variable_metadata = variable_metadata 37 | self.coords = coords 38 | self._video = VideoAggregator( 39 | n_timesteps=n_timesteps, 40 | enable_extended_videos=True, 41 | variable_metadata=variable_metadata, 42 | ) 43 | 44 | def append_batch( 45 | self, 46 | target: Dict[str, torch.Tensor], 47 | prediction: Dict[str, torch.Tensor], 48 | start_timestep: int, 49 | batch_time: xr.DataArray, 50 | ): 51 | """ 52 | Append a batch of data to the file. 53 | 54 | Args: 55 | target: Target data. 56 | prediction: Prediction data. 57 | start_timestep: Timestep at which to start writing. 58 | batch_time: Time coordinate for each sample in the batch. Unused. 
59 | """ 60 | self._video.record_batch( 61 | target_data=target, 62 | gen_data=prediction, 63 | i_time_start=start_timestep, 64 | ) 65 | 66 | def flush(self): 67 | """ 68 | Flush the data to disk. 69 | """ 70 | metric_dataset = self._video.get_dataset() 71 | coords = {} 72 | if "lat" in self.coords: 73 | coords["lat"] = self.coords["lat"] 74 | if "lon" in self.coords: 75 | coords["lon"] = self.coords["lon"] 76 | metric_dataset = metric_dataset.assign_coords(coords) 77 | metric_dataset.to_netcdf(self._metrics_filename) 78 | -------------------------------------------------------------------------------- /fme/fme/ace/inference/stepper_test_data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai2cm/ace/0c36153610a2bd8160d1639670011af9e03b19b2/fme/fme/ace/inference/stepper_test_data -------------------------------------------------------------------------------- /fme/fme/ace/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai2cm/ace/0c36153610a2bd8160d1639670011af9e03b19b2/fme/fme/ace/models/__init__.py -------------------------------------------------------------------------------- /fme/fme/ace/models/healpix/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai2cm/ace/0c36153610a2bd8160d1639670011af9e03b19b2/fme/fme/ace/models/healpix/__init__.py -------------------------------------------------------------------------------- /fme/fme/ace/models/makani/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai2cm/ace/0c36153610a2bd8160d1639670011af9e03b19b2/fme/fme/ace/models/makani/__init__.py -------------------------------------------------------------------------------- /fme/fme/ace/models/modulus/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai2cm/ace/0c36153610a2bd8160d1639670011af9e03b19b2/fme/fme/ace/models/modulus/__init__.py -------------------------------------------------------------------------------- /fme/fme/ace/models/modulus/initialization.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # Copied from https://github.com/ai2cm/modulus/commit/22df4a9427f5f12ff6ac891083220e7f2f54d229 3 | # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
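# Example usage of the trunc_normal_ helper defined below (an illustrative
# sketch; the tensor shape and std are arbitrary choices):
#
#     weight = torch.empty(768, 768)
#     trunc_normal_(weight, std=0.02)  # fills in place, clamped to [-2.0, 2.0] by default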
16 | 17 | import math 18 | import warnings 19 | 20 | import torch 21 | 22 | 23 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 24 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 25 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 26 | def norm_cdf(x): 27 | # Computes standard normal cumulative distribution function 28 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 29 | 30 | if (mean < a - 2 * std) or (mean > b + 2 * std): 31 | warnings.warn( 32 | "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 33 | "The distribution of values may be incorrect.", 34 | stacklevel=2, 35 | ) 36 | 37 | with torch.no_grad(): 38 | # Values are generated by using a truncated uniform distribution and 39 | # then using the inverse CDF for the normal distribution. 40 | # Get upper and lower cdf values 41 | l = norm_cdf((a - mean) / std) 42 | u = norm_cdf((b - mean) / std) 43 | 44 | # Uniformly fill tensor with values from [l, u], then translate to 45 | # [2l-1, 2u-1]. 46 | tensor.uniform_(2 * l - 1, 2 * u - 1) 47 | 48 | # Use inverse cdf transform for normal distribution to get truncated 49 | # standard normal 50 | tensor.erfinv_() 51 | 52 | # Transform to proper mean, std 53 | tensor.mul_(std * math.sqrt(2.0)) 54 | tensor.add_(mean) 55 | 56 | # Clamp to ensure it's in the proper range 57 | tensor.clamp_(min=a, max=b) 58 | return tensor 59 | 60 | 61 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): 62 | r"""Fills the input Tensor with values drawn from a truncated 63 | normal distribution. The values are effectively drawn from the 64 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 65 | with values outside :math:`[a, b]` redrawn until they are within 66 | the bounds. The method used for generating the random values works 67 | best when :math:`a \leq \text{mean} \leq b`. 68 | Args: 69 | tensor: an n-dimensional `torch.Tensor` 70 | mean: the mean of the normal distribution 71 | std: the standard deviation of the normal distribution 72 | a: the minimum cutoff value 73 | b: the maximum cutoff value 74 | """ 75 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 76 | -------------------------------------------------------------------------------- /fme/fme/ace/registry/__init__.py: -------------------------------------------------------------------------------- 1 | # import modules so they are registered 2 | from . import prebuilt as _prebuilt 3 | from . import sfno as _sfno 4 | from .registry import ModuleSelector 5 | 6 | del _prebuilt, _sfno 7 | -------------------------------------------------------------------------------- /fme/fme/ace/registry/hpx.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Tuple 3 | 4 | import torch.nn as nn 5 | 6 | from fme.ace.models.healpix.healpix_decoder import UNetDecoderConfig 7 | from fme.ace.models.healpix.healpix_encoder import UNetEncoderConfig 8 | from fme.ace.models.healpix.healpix_recunet import HEALPixRecUNet 9 | from fme.ace.registry.registry import ModuleConfig, ModuleSelector 10 | 11 | 12 | @ModuleSelector.register("HEALPixRecUNet") 13 | @dataclasses.dataclass 14 | class HEALPixRecUNetBuilder(ModuleConfig): 15 | """ 16 | Configuration for the HEALPixRecUNet architecture used in DLWP. 17 | 18 | Parameters: 19 | presteps: Number of pre-steps, by default 1. 20 | input_time_size: Input time dimension, by default 0. 
21 | output_time_size: Output time dimension, by default 0. 22 | delta_time: Delta time interval, by default "6h". 23 | reset_cycle: Reset cycle interval, by default "24h". 24 | input_channels: Number of input channels, by default 8. 25 | output_channels: Number of output channels, by default 8. 26 | n_constants: Number of constant input channels, by default 2. 27 | decoder_input_channels: Number of input channels for the decoder, by default 1. 28 | enable_nhwc: Flag to enable NHWC data format, by default False. 29 | enable_healpixpad: Flag to enable HEALPix padding, by default False. 30 | """ 31 | 32 | encoder: UNetEncoderConfig 33 | decoder: UNetDecoderConfig 34 | presteps: int = 1 35 | input_time_size: int = 0 36 | output_time_size: int = 0 37 | delta_time: str = "6h" 38 | reset_cycle: str = "24h" 39 | n_constants: int = 2 40 | decoder_input_channels: int = 1 41 | prognostic_variables: int = 7 42 | enable_nhwc: bool = False 43 | enable_healpixpad: bool = False 44 | 45 | def build( 46 | self, 47 | n_in_channels: int, 48 | n_out_channels: int, 49 | img_shape: Tuple[int, int], 50 | ) -> nn.Module: 51 | """ 52 | Builds the HEALPixRecUNet model. 53 | 54 | Args: 55 | n_in_channels: Number of input channels. 56 | n_out_channels: Number of output channels. 57 | img_shape: Shape of the input image. 58 | 59 | Returns: 60 | HEALPixRecUNet model. 61 | """ 62 | # Construct the HEALPixRecUNet module here using the parameters 63 | return HEALPixRecUNet( 64 | encoder=self.encoder, 65 | decoder=self.decoder, 66 | input_channels=n_in_channels, 67 | output_channels=n_out_channels, 68 | prognostic_variables=self.prognostic_variables, 69 | n_constants=self.n_constants, 70 | decoder_input_channels=self.decoder_input_channels, 71 | input_time_size=self.input_time_size, 72 | output_time_size=self.output_time_size, 73 | delta_time=self.delta_time, 74 | reset_cycle=self.reset_cycle, 75 | presteps=self.presteps, 76 | enable_nhwc=self.enable_nhwc, 77 | enable_healpixpad=self.enable_healpixpad, 78 | ) 79 | -------------------------------------------------------------------------------- /fme/fme/ace/registry/prebuilt.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Tuple 3 | 4 | from torch import nn 5 | 6 | from fme.ace.registry.registry import ModuleConfig, ModuleSelector 7 | 8 | 9 | @ModuleSelector.register("prebuilt") 10 | @dataclasses.dataclass 11 | class PreBuiltBuilder(ModuleConfig): 12 | """ 13 | A simple module configuration which returns a pre-defined module. 14 | 15 | Used mainly for testing. 
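    A usage sketch (``my_module`` stands in for any already-constructed
    torch.nn.Module):

        selector = ModuleSelector(type="prebuilt", config={"module": my_module})
        module = selector.build(1, 1, (16, 32))
        assert module is my_module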
16 | """ 17 | 18 | module: nn.Module 19 | 20 | def build( 21 | self, 22 | n_in_channels: int, 23 | n_out_channels: int, 24 | img_shape: Tuple[int, int], 25 | ) -> nn.Module: 26 | return self.module 27 | -------------------------------------------------------------------------------- /fme/fme/ace/registry/registry.py: -------------------------------------------------------------------------------- 1 | from fme.core.registry.module import ( # noqa: F401 2 | ModuleConfig, 3 | ModuleSelector, 4 | ) 5 | -------------------------------------------------------------------------------- /fme/fme/ace/registry/test_sfno.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import datetime 3 | 4 | import numpy as np 5 | import pytest 6 | import torch 7 | 8 | from fme.ace.stepper import SingleModuleStepperConfig 9 | from fme.core.coordinates import HybridSigmaPressureCoordinate 10 | from fme.core.device import get_device 11 | from fme.core.gridded_ops import LatLonOperations 12 | from fme.core.normalizer import NormalizationConfig 13 | 14 | TIMESTEP = datetime.timedelta(hours=6) 15 | 16 | 17 | @pytest.mark.parametrize( 18 | "shape", 19 | [ 20 | pytest.param((8, 16)), 21 | ], 22 | ) 23 | def test_sfno_init(shape): 24 | num_layers = 2 25 | sfno_config_data = { 26 | "type": "SphericalFourierNeuralOperatorNet", 27 | "config": { 28 | "num_layers": num_layers, 29 | "embed_dim": 3, 30 | "scale_factor": 1, 31 | }, 32 | } 33 | stepper_config_data = { 34 | "builder": sfno_config_data, 35 | "in_names": ["x"], 36 | "out_names": ["x"], 37 | "normalization": dataclasses.asdict( 38 | NormalizationConfig( 39 | means={"x": float(np.random.randn(1).item())}, 40 | stds={"x": float(np.random.randn(1).item())}, 41 | ) 42 | ), 43 | } 44 | area = torch.ones((1, 16, 32)).to(get_device()) 45 | vertical_coordinate = HybridSigmaPressureCoordinate( 46 | ak=torch.arange(7), bk=torch.arange(7) 47 | ).to(get_device()) 48 | stepper_config = SingleModuleStepperConfig.from_state(stepper_config_data) 49 | stepper = stepper_config.get_stepper( 50 | img_shape=shape, 51 | gridded_operations=LatLonOperations(area), 52 | vertical_coordinate=vertical_coordinate, 53 | timestep=TIMESTEP, 54 | ) 55 | assert len(stepper.module.module.blocks) == num_layers 56 | -------------------------------------------------------------------------------- /fme/fme/ace/requirements.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import List 3 | 4 | 5 | @dataclasses.dataclass 6 | class PrognosticStateDataRequirements: 7 | """ 8 | The requirements for the model's prognostic state. 9 | 10 | Parameters: 11 | names: Names of prognostic variables. 12 | n_timesteps: Number of consecutive timesteps that must be stored. 
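    For example (the variable name is illustrative):

        PrognosticStateDataRequirements(names=["surface_temperature"], n_timesteps=1)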
13 | """ 14 | 15 | names: List[str] 16 | n_timesteps: int 17 | -------------------------------------------------------------------------------- /fme/fme/ace/run-train-and-inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | YAML_TRAIN_CONFIG=$1 6 | YAML_INFERENCE_CONFIG=$2 7 | NPROC_PER_NODE=$3 8 | 9 | # run training 10 | export WANDB_JOB_TYPE=training 11 | torchrun --nproc_per_node $NPROC_PER_NODE -m fme.ace.train $YAML_TRAIN_CONFIG 12 | 13 | echo =============================================================================== 14 | echo ==================== FINISHED TRAINING / STARTING INFERENCE =================== 15 | echo =============================================================================== 16 | 17 | # run inference 18 | export WANDB_JOB_TYPE=inference 19 | python -m fme.ace.inference.evaluator $YAML_INFERENCE_CONFIG 20 | -------------------------------------------------------------------------------- /fme/fme/ace/testing/__init__.py: -------------------------------------------------------------------------------- 1 | from .fv3gfs_data import ( 2 | DimSize, 3 | DimSizes, 4 | FV3GFSData, 5 | MonthlyReferenceData, 6 | StatsData, 7 | save_nd_netcdf, 8 | save_scalar_netcdf, 9 | ) 10 | -------------------------------------------------------------------------------- /fme/fme/ace/train/__init__.py: -------------------------------------------------------------------------------- 1 | from fme.ace.train.train import Trainer 2 | from fme.core.generics.trainer import count_parameters 3 | -------------------------------------------------------------------------------- /fme/fme/ace/train/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from fme.ace.train.train import main 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("yaml_config", type=str) 8 | 9 | args = parser.parse_args() 10 | 11 | main(yaml_config=args.yaml_config) 12 | -------------------------------------------------------------------------------- /fme/fme/ace/validate_config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | import dacite 5 | import dacite.exceptions 6 | import yaml 7 | 8 | from fme.ace.inference.evaluator import InferenceEvaluatorConfig 9 | from fme.ace.inference.inference import InferenceConfig 10 | from fme.ace.stepper import SingleModuleStepperConfig 11 | from fme.ace.train.train_config import TrainConfig 12 | 13 | CONFIG_CHOICES = ["train", "inference", "evaluator"] 14 | 15 | if __name__ == "__main__": 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | "path", type=str, help="Path to the train or inference config file." 19 | ) 20 | parser.add_argument( 21 | "--inference", 22 | action="store_true", 23 | help=( 24 | "Deprecated, use --config_type evaluator to validate an evaluator config." 25 | ), 26 | ) 27 | parser.add_argument( 28 | "--config_type", 29 | type=str, 30 | choices=CONFIG_CHOICES, 31 | default="train", 32 | help=("Indicates the kind of configuration being validated."), 33 | ) 34 | args = parser.parse_args() 35 | 36 | if args.inference: 37 | logging.warning( 38 | "The --inference flag is deprecated. " 39 | "Use --config_type evaluator to validate an evaluator config." 
40 | ) 41 | config_type = "evaluator" 42 | else: 43 | config_type = args.config_type 44 | 45 | with open(args.path, "r") as f: 46 | config_data = yaml.load(f, Loader=yaml.CLoader) 47 | 48 | if config_type == "evaluator": 49 | dacite.from_dict( 50 | data_class=InferenceEvaluatorConfig, 51 | data=config_data, 52 | config=dacite.Config(strict=True), 53 | ) 54 | elif config_type == "inference": 55 | dacite.from_dict( 56 | data_class=InferenceConfig, 57 | data=config_data, 58 | config=dacite.Config(strict=True), 59 | ) 60 | elif config_type == "train": 61 | try: 62 | dacite.from_dict( 63 | data_class=TrainConfig, 64 | data=config_data, 65 | config=dacite.Config(strict=True), 66 | ) 67 | except dacite.exceptions.UnionMatchError as err: 68 | if "checkpoint_path" not in config_data["stepper"]: 69 | dacite.from_dict( 70 | data_class=SingleModuleStepperConfig, 71 | data=config_data["stepper"], 72 | config=dacite.Config(strict=True), 73 | ) 74 | # if there was no issue for SingleModuleStepperConfig, raise original error 75 | raise err 76 | else: 77 | raise ValueError( 78 | f"Invalid config type: {config_type}, expected one of {CONFIG_CHOICES}" 79 | ) 80 | -------------------------------------------------------------------------------- /fme/fme/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .climate_data import ClimateData 2 | from .device import get_device, using_gpu 3 | from .gridded_ops import GriddedOperations 4 | from .metrics import ( 5 | root_mean_squared_error, 6 | spherical_area_weights, 7 | weighted_mean, 8 | weighted_mean_bias, 9 | ) 10 | from .normalizer import StandardNormalizer, get_normalizer 11 | from .packer import Packer 12 | 13 | __all__ = [ 14 | "spherical_area_weights", 15 | "weighted_mean", 16 | "weighted_mean_bias", 17 | "root_mean_squared_error", 18 | "get_device", 19 | "using_gpu", 20 | "StandardNormalizer", 21 | "get_normalizer", 22 | "Packer", 23 | "ClimateData", 24 | "GriddedOperations", 25 | ] 26 | -------------------------------------------------------------------------------- /fme/fme/core/constants.py: -------------------------------------------------------------------------------- 1 | LATENT_HEAT_OF_VAPORIZATION = 2.5e6 # J/kg 2 | GRAVITY = 9.80665 # m/s^2 3 | # following values are used by SHiELD's slab ocean model, and so we follow suit here. 
4 | SPECIFIC_HEAT_OF_WATER = 4000.0 # J/kg/K 5 | DENSITY_OF_WATER = 1000.0 # kg/m^3 6 | 7 | SPECIFIC_HEAT_OF_DRY_AIR_CONST_PRESSURE = 1004.6 # J/kg/K 8 | 9 | RVGAS = 461.5 # J/kg/K 10 | RDGAS = 287.05 # J/kg/K 11 | -------------------------------------------------------------------------------- /fme/fme/core/corrector/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai2cm/ace/0c36153610a2bd8160d1639670011af9e03b19b2/fme/fme/core/corrector/__init__.py -------------------------------------------------------------------------------- /fme/fme/core/corrector/ocean.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import datetime 3 | from types import MappingProxyType 4 | from typing import Any, List, Mapping, Optional 5 | 6 | import dacite 7 | 8 | from fme.core.coordinates import HybridSigmaPressureCoordinate 9 | from fme.core.corrector.corrector import force_positive 10 | from fme.core.corrector.registry import CorrectorABC, CorrectorConfigProtocol 11 | from fme.core.gridded_ops import GriddedOperations 12 | from fme.core.masking import MaskingConfig 13 | from fme.core.registry.corrector import CorrectorSelector 14 | from fme.core.stacker import Stacker 15 | from fme.core.typing_ import TensorMapping 16 | 17 | OCEAN_FIELD_NAME_PREFIXES = MappingProxyType( 18 | { 19 | "surface_height": ["zos"], 20 | "salinity": ["so_"], 21 | "potential_temperature": ["thetao_"], 22 | "zonal_velocity": ["uo_"], 23 | "meridional_velocity": ["vo_"], 24 | } 25 | ) 26 | 27 | 28 | @CorrectorSelector.register("ocean_corrector") 29 | @dataclasses.dataclass 30 | class OceanCorrectorConfig(CorrectorConfigProtocol): 31 | masking: Optional[MaskingConfig] = None 32 | force_positive_names: List[str] = dataclasses.field(default_factory=list) 33 | 34 | def build( 35 | self, 36 | gridded_operations: GriddedOperations, 37 | vertical_coordinate: HybridSigmaPressureCoordinate, 38 | timestep: datetime.timedelta, 39 | ): 40 | return OceanCorrector( 41 | config=self, 42 | gridded_operations=gridded_operations, 43 | vertical_coordinate=vertical_coordinate, 44 | timestep=timestep, 45 | ) 46 | 47 | @classmethod 48 | def from_state(cls, state: Mapping[str, Any]) -> "OceanCorrectorConfig": 49 | return dacite.from_dict( 50 | data_class=cls, data=state, config=dacite.Config(strict=True) 51 | ) 52 | 53 | 54 | class OceanCorrector(CorrectorABC): 55 | def __init__( 56 | self, 57 | config: OceanCorrectorConfig, 58 | gridded_operations: GriddedOperations, 59 | vertical_coordinate: HybridSigmaPressureCoordinate, 60 | timestep: datetime.timedelta, 61 | ): 62 | self._config = config 63 | self._gridded_operations = gridded_operations 64 | self._vertical_coordinates = vertical_coordinate 65 | self._timestep = timestep 66 | 67 | if config.masking is not None: 68 | self._masking = config.masking.build() 69 | else: 70 | self._masking = None 71 | self._stacker = Stacker(OCEAN_FIELD_NAME_PREFIXES) 72 | 73 | def __call__( 74 | self, 75 | input_data: TensorMapping, 76 | gen_data: TensorMapping, 77 | forcing_data: TensorMapping, 78 | ) -> TensorMapping: 79 | if self._masking is not None: 80 | gen_data = self._masking(self._stacker, gen_data, input_data) 81 | if len(self._config.force_positive_names) > 0: 82 | gen_data = force_positive(gen_data, self._config.force_positive_names) 83 | return gen_data 84 | -------------------------------------------------------------------------------- /fme/fme/core/corrector/registry.py: 
-------------------------------------------------------------------------------- 1 | import abc 2 | import datetime 3 | from typing import Any, Mapping, Protocol 4 | 5 | from fme.core.coordinates import HybridSigmaPressureCoordinate 6 | from fme.core.gridded_ops import GriddedOperations 7 | from fme.core.typing_ import TensorMapping 8 | 9 | 10 | class CorrectorConfigProtocol(Protocol): 11 | def build( 12 | self, 13 | gridded_operations: GriddedOperations, 14 | vertical_coordinate: HybridSigmaPressureCoordinate, 15 | timestep: datetime.timedelta, 16 | ) -> "CorrectorABC": ... 17 | 18 | @classmethod 19 | def from_state(cls, state: Mapping[str, Any]) -> "CorrectorConfigProtocol": 20 | """ 21 | Create a ModuleSelector from a dictionary containing all the information 22 | needed to build a ModuleConfig. 23 | """ 24 | ... 25 | 26 | 27 | class CorrectorABC(abc.ABC): 28 | @abc.abstractmethod 29 | def __call__( 30 | self, 31 | input_data: TensorMapping, 32 | gen_data: TensorMapping, 33 | forcing_data: TensorMapping, 34 | ) -> TensorMapping: ... 35 | -------------------------------------------------------------------------------- /fme/fme/core/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai2cm/ace/0c36153610a2bd8160d1639670011af9e03b19b2/fme/fme/core/dataset/__init__.py -------------------------------------------------------------------------------- /fme/fme/core/dataset/data_typing.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from collections import namedtuple 3 | from typing import Tuple 4 | 5 | import torch 6 | import xarray as xr 7 | 8 | from fme.core.typing_ import TensorDict 9 | 10 | VariableMetadata = namedtuple("VariableMetadata", ["units", "long_name"]) 11 | 12 | 13 | class Dataset(torch.utils.data.Dataset, abc.ABC): 14 | @abc.abstractmethod 15 | def get_sample_by_time_slice( 16 | self, time_slice: slice 17 | ) -> Tuple[TensorDict, xr.DataArray]: 18 | """ 19 | Returns a sample of data for the given time slice. 20 | 21 | Args: 22 | time_slice: The time slice to return data for. 23 | 24 | Returns: 25 | A tuple whose first item is a mapping from variable 26 | name to tensor of shape [n_time, n_lat, n_lon] and 27 | whose second item is a time coordinate array. 28 | """ 29 | ... 
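# Illustrative use of the interface above (a sketch; the variable name and
# window length are arbitrary):
#
#     tensors, time = dataset.get_sample_by_time_slice(slice(0, 8))
#     tensors["air_temperature"].shape  # -> [8, n_lat, n_lon]
#     time.shape                        # -> (8,)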
30 | -------------------------------------------------------------------------------- /fme/fme/core/dataset/getters.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import List, Optional, Sequence, Tuple 3 | 4 | import torch.utils.data 5 | 6 | from fme.core.dataset.config import XarrayDataConfig 7 | from fme.core.dataset.xarray import DatasetProperties, XarrayDataset, get_xarray_dataset 8 | 9 | from .requirements import DataRequirements 10 | 11 | 12 | def get_datasets( 13 | dataset_configs: Sequence[XarrayDataConfig], 14 | requirements: DataRequirements, 15 | strict: bool = True, 16 | ) -> Tuple[List[XarrayDataset], DatasetProperties]: 17 | datasets = [] 18 | properties: Optional[DatasetProperties] = None 19 | for config in dataset_configs: 20 | dataset, new_properties = get_xarray_dataset(config, requirements) 21 | datasets.append(dataset) 22 | if properties is None: 23 | properties = new_properties 24 | elif not strict: 25 | try: 26 | properties.update(new_properties) 27 | except ValueError as e: 28 | warnings.warn( 29 | f"Metadata for each ensemble member are not the same: {e}" 30 | ) 31 | else: 32 | properties.update(new_properties) 33 | if properties is None: 34 | raise ValueError("At least one dataset must be provided.") 35 | 36 | return datasets, properties 37 | 38 | 39 | def get_dataset( 40 | dataset_configs: Sequence[XarrayDataConfig], 41 | requirements: DataRequirements, 42 | strict: bool = True, 43 | ) -> Tuple[torch.utils.data.ConcatDataset[XarrayDataset], DatasetProperties]: 44 | datasets, properties = get_datasets(dataset_configs, requirements, strict=strict) 45 | ensemble = torch.utils.data.ConcatDataset(datasets) 46 | return ensemble, properties 47 | -------------------------------------------------------------------------------- /fme/fme/core/dataset/requirements.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import List 3 | 4 | 5 | @dataclasses.dataclass 6 | class DataRequirements: 7 | """ 8 | The requirements for batches (time windows) of loaded data. 9 | 10 | Parameters: 11 | names: Names of the variables to load. 12 | n_timesteps: Number of timesteps to load in each batch window. 13 | """ 14 | 15 | names: List[str] 16 | n_timesteps: int 17 | -------------------------------------------------------------------------------- /fme/fme/core/device.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from .typing_ import TensorDict, TensorMapping 6 | 7 | 8 | def using_gpu() -> bool: 9 | return get_device().type == "cuda" 10 | 11 | 12 | def get_device() -> torch.device: 13 | """If CUDA is available, return a CUDA device. Otherwise, return a CPU device 14 | unless FME_USE_MPS is set, in which case return an MPS device if available. 
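    For example (illustrative):

        os.environ["FME_USE_MPS"] = "1"  # opt in to MPS when CUDA is unavailable
        device = get_device()            # torch.device("mps", 0) if supported, else CPU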
15 | """ 16 | if torch.cuda.is_available(): 17 | return torch.device("cuda", torch.cuda.current_device()) 18 | else: 19 | mps_available = torch.backends.mps.is_available() 20 | if mps_available and os.environ.get("FME_USE_MPS", "0") == "1": 21 | return torch.device("mps", 0) 22 | else: 23 | return torch.device("cpu") 24 | 25 | 26 | def move_tensordict_to_device(data: TensorMapping) -> TensorDict: 27 | device = get_device() 28 | return {name: value.to(device) for name, value in data.items()} 29 | -------------------------------------------------------------------------------- /fme/fme/core/dicts.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Mapping 2 | 3 | 4 | def to_flat_dict(d: Mapping[str, Any]) -> Dict[str, Any]: 5 | """ 6 | Converts any nested dictionaries to a flat version with 7 | the nested keys joined with a '.', e.g., {a: {b: 1}} -> 8 | {a.b: 1}. 9 | """ 10 | new_flat = {} 11 | for k, v in d.items(): 12 | if isinstance(v, dict): 13 | sub_d = to_flat_dict(v) 14 | for sk, sv in sub_d.items(): 15 | new_flat[".".join([k, sk])] = sv 16 | else: 17 | new_flat[k] = v 18 | 19 | return new_flat 20 | 21 | 22 | def to_nested_dict(d: Mapping[str, Any]) -> Dict[str, Any]: 23 | """ 24 | Converts a flat dictionary with '.' joined keys back into 25 | a nested dictionary, e.g., {a.b: 1} -> {a: {b: 1}}. 26 | """ 27 | new_config: Dict[str, Any] = {} 28 | 29 | for k, v in d.items(): 30 | if "." in k: 31 | sub_keys = k.split(".") 32 | sub_d = new_config 33 | for sk in sub_keys[:-1]: 34 | sub_d = sub_d.setdefault(sk, {}) 35 | sub_d[sub_keys[-1]] = v 36 | else: 37 | new_config[k] = v 38 | 39 | return new_config 40 | -------------------------------------------------------------------------------- /fme/fme/core/generics/aggregator.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Any, Dict, Generic, List, TypeVar 3 | 4 | PS = TypeVar("PS", contravariant=True) # prognostic state 5 | T = TypeVar("T", contravariant=True) 6 | 7 | 8 | class AggregatorABC(abc.ABC, Generic[T]): 9 | @abc.abstractmethod 10 | def record_batch(self, batch: T) -> None: 11 | pass 12 | 13 | @abc.abstractmethod 14 | def get_logs(self, label: str) -> Dict[str, float]: 15 | pass 16 | 17 | 18 | InferenceLog = Dict[str, Any] 19 | InferenceLogs = List[InferenceLog] 20 | 21 | 22 | class InferenceAggregatorABC(abc.ABC, Generic[PS, T]): 23 | @abc.abstractmethod 24 | def record_batch( 25 | self, 26 | data: T, 27 | ) -> InferenceLogs: 28 | """ 29 | Record a batch of data. 30 | 31 | Args: 32 | data: Batch of data. 33 | 34 | Returns: 35 | Logs for the batch. 36 | """ 37 | pass 38 | 39 | @abc.abstractmethod 40 | def record_initial_condition( 41 | self, 42 | initial_condition: PS, 43 | ) -> InferenceLogs: 44 | """ 45 | Record the initial condition. 46 | 47 | May only be recorded once, before any calls to record_batch. 
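        A sketch of the intended call order (``aggregator`` is any concrete
        implementation and ``loader`` any iterable of batches):

            logs = aggregator.record_initial_condition(initial_condition)
            for batch in loader:
                logs.extend(aggregator.record_batch(batch))
            summary = aggregator.get_summary_logs()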
48 | """ 49 | pass 50 | 51 | @abc.abstractmethod 52 | def get_summary_logs(self) -> InferenceLog: 53 | pass 54 | -------------------------------------------------------------------------------- /fme/fme/core/generics/data.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Generic, Iterable, Protocol, Sized, TypeVar 3 | 4 | T = TypeVar("T", covariant=True) 5 | 6 | 7 | class DataLoader(Protocol, Generic[T], Sized, Iterable[T]): 8 | pass 9 | 10 | 11 | PS = TypeVar("PS") # prognostic state 12 | FD = TypeVar("FD", covariant=True) # forcing data 13 | 14 | 15 | class InferenceDataABC(abc.ABC, Generic[PS, FD]): 16 | @property 17 | @abc.abstractmethod 18 | def initial_condition(self) -> PS: ... 19 | 20 | @property 21 | @abc.abstractmethod 22 | def loader(self) -> DataLoader[FD]: ... 23 | 24 | 25 | class SimpleInferenceData(InferenceDataABC[PS, FD]): 26 | def __init__( 27 | self, 28 | initial_condition: PS, 29 | loader: DataLoader[FD], 30 | ): 31 | self._initial_condition = initial_condition 32 | self._loader = loader 33 | 34 | @property 35 | def initial_condition(self) -> PS: 36 | return self._initial_condition 37 | 38 | @property 39 | def loader(self) -> DataLoader[FD]: 40 | return self._loader 41 | 42 | 43 | class GriddedDataABC(abc.ABC, Generic[T]): 44 | @property 45 | @abc.abstractmethod 46 | def loader(self) -> DataLoader[T]: ... 47 | 48 | @property 49 | @abc.abstractmethod 50 | def n_samples(self) -> int: ... 51 | 52 | @property 53 | @abc.abstractmethod 54 | def n_batches(self) -> int: ... 55 | 56 | @property 57 | @abc.abstractmethod 58 | def batch_size(self) -> int: ... 59 | 60 | @abc.abstractmethod 61 | def set_epoch(self, epoch: int): ... 62 | 63 | @abc.abstractmethod 64 | def log_info(self, name: str): 65 | """ 66 | Report information about the data using logging.info. 67 | """ 68 | ... 69 | -------------------------------------------------------------------------------- /fme/fme/core/generics/optimization.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import contextlib 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class OptimizationABC(abc.ABC): 9 | @contextlib.contextmanager 10 | @abc.abstractmethod 11 | def autocast(self): ... 12 | 13 | @property 14 | @abc.abstractmethod 15 | def learning_rate(self) -> float: ... 16 | 17 | @abc.abstractmethod 18 | def set_mode(self, module: nn.Module): 19 | """ 20 | Sets the mode of the module to train. 21 | """ 22 | ... 23 | 24 | @abc.abstractmethod 25 | def step_scheduler(self, valid_loss: float): 26 | """ 27 | Step the scheduler. 28 | 29 | Args: 30 | valid_loss: The validation loss. Used in schedulers which change the 31 | learning rate based on whether the validation loss is decreasing. 32 | """ 33 | ... 34 | 35 | @abc.abstractmethod 36 | def step_weights(self, loss: torch.Tensor): ... 37 | 38 | @abc.abstractmethod 39 | def get_state(self): 40 | """ 41 | Returns state as a serializable data structure. 42 | """ 43 | ... 44 | 45 | @abc.abstractmethod 46 | def load_state(self, state): 47 | """ 48 | Loads state from a serializable data structure. 49 | """ 50 | ... 
51 | -------------------------------------------------------------------------------- /fme/fme/core/generics/train_stepper.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Any, Dict, Generic, Type, TypeVar 3 | 4 | from torch import nn 5 | 6 | from fme.core.generics.inference import PredictFunction 7 | from fme.core.generics.optimization import OptimizationABC 8 | from fme.core.typing_ import TensorDict 9 | 10 | TO = TypeVar("TO", bound="TrainOutputABC") # train output 11 | 12 | 13 | class TrainOutputABC(abc.ABC): 14 | @abc.abstractmethod 15 | def get_metrics(self) -> TensorDict: 16 | pass 17 | 18 | 19 | PS = TypeVar("PS") # prognostic state 20 | BD = TypeVar("BD") # batch data 21 | FD = TypeVar("FD") # forcing data 22 | SD = TypeVar("SD") # stepped data 23 | 24 | 25 | class TrainStepperABC(abc.ABC, Generic[PS, BD, FD, SD, TO]): 26 | SelfType = TypeVar("SelfType", bound="TrainStepperABC") 27 | 28 | @abc.abstractmethod 29 | def train_on_batch( 30 | self, 31 | data: BD, 32 | optimization: OptimizationABC, 33 | compute_derived_variables: bool = False, 34 | ) -> TO: 35 | pass 36 | 37 | @property 38 | @abc.abstractmethod 39 | def modules(self) -> nn.ModuleList: 40 | pass 41 | 42 | @abc.abstractmethod 43 | def get_state(self) -> Dict[str, Any]: 44 | pass 45 | 46 | @abc.abstractmethod 47 | def load_state(self, state: Dict[str, Any]) -> None: 48 | pass 49 | 50 | @classmethod 51 | @abc.abstractmethod 52 | def from_state(cls: Type[SelfType], state: Dict[str, Any]) -> SelfType: 53 | pass 54 | 55 | @property 56 | @abc.abstractmethod 57 | def n_ic_timesteps(self) -> int: 58 | pass 59 | 60 | @property 61 | @abc.abstractmethod 62 | def predict_paired(self) -> PredictFunction[PS, FD, SD]: 63 | pass 64 | -------------------------------------------------------------------------------- /fme/fme/core/generics/writer.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Any, Generic, TypeVar 3 | 4 | PS = TypeVar("PS", contravariant=True) # prognostic state 5 | SD = TypeVar("SD", contravariant=True) # stepped data 6 | 7 | 8 | class WriterABC(abc.ABC, Generic[PS, SD]): 9 | @abc.abstractmethod 10 | def write(self, data: PS, filename: str): 11 | """Eagerly write data to a file at filename.""" 12 | ... 13 | 14 | @abc.abstractmethod 15 | def append_batch( 16 | self, 17 | batch: SD, 18 | ): 19 | """ 20 | Append a batch of data to the output file(s). 21 | 22 | Args: 23 | batch: Data to be written. 24 | """ 25 | ... 26 | 27 | 28 | class NullDataWriter(WriterABC[Any, Any]): 29 | """ 30 | Null pattern for DataWriter, which does nothing. 31 | """ 32 | 33 | def __init__(self): 34 | pass 35 | 36 | def append_batch( 37 | self, 38 | batch: Any, 39 | ): 40 | pass 41 | 42 | def flush(self): 43 | pass 44 | 45 | def write(self, data: Any, filename: str): 46 | pass 47 | -------------------------------------------------------------------------------- /fme/fme/core/packer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.jit 5 | 6 | from fme.core.typing_ import TensorDict 7 | 8 | 9 | class DataShapesNotUniform(ValueError): 10 | """Indicates that a set of tensors do not all have the same shape.""" 11 | 12 | pass 13 | 14 | 15 | class Packer: 16 | """ 17 | Responsible for packing tensors into a single tensor. 
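    A small round-trip sketch:

        packer = Packer(names=["a", "b"])
        packed = packer.pack({"a": torch.zeros(4, 8), "b": torch.ones(4, 8)}, axis=1)
        packed.shape  # -> torch.Size([4, 2, 8])
        fields = packer.unpack(packed, axis=1)  # {"a": ..., "b": ...}, each of shape (4, 8)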
18 |     """
19 | 
20 |     def __init__(self, names: List[str]):
21 |         self.names = names
22 | 
23 |     def pack(self, tensors: TensorDict, axis=0) -> torch.Tensor:
24 |         """
25 |         Packs tensors into a single tensor, concatenated along a new axis.
26 | 
27 |         Args:
28 |             tensors: Dict from names to tensors.
29 |             axis: index for new concatenation axis.
30 | 
31 |         Raises:
32 |             DataShapesNotUniform: when packed tensors do not all have the same shape.
33 |         """
34 |         shape = next(iter(tensors.values())).shape
35 |         for name in tensors:
36 |             if tensors[name].shape != shape:
37 |                 raise DataShapesNotUniform(
38 |                     (
39 |                         f"Cannot pack tensors of different shapes. "
40 |                         f'Expected "{shape}" got "{tensors[name].shape}"'
41 |                     )
42 |                 )
43 |         return _pack(tensors, self.names, axis=axis)
44 | 
45 |     def unpack(self, tensor: torch.Tensor, axis=0) -> TensorDict:
46 |         return _unpack(tensor, self.names, axis=axis)
47 | 
48 | 
49 | @torch.jit.script
50 | def _pack(tensors: TensorDict, names: List[str], axis: int = 0) -> torch.Tensor:
51 |     return torch.cat([tensors[n].unsqueeze(axis) for n in names], dim=axis)
52 | 
53 | 
54 | @torch.jit.script
55 | def _unpack(tensor: torch.Tensor, names: List[str], axis: int = 0) -> TensorDict:
56 |     return {n: tensor.select(axis, index=i) for i, n in enumerate(names)}
57 | 
--------------------------------------------------------------------------------
/fme/fme/core/registry/__init__.py:
--------------------------------------------------------------------------------
1 | from .corrector import CorrectorSelector
2 | from .module import ModuleSelector
3 | 
--------------------------------------------------------------------------------
/fme/fme/core/registry/registry.py:
--------------------------------------------------------------------------------
 1 | from typing import Any, Callable, Dict, Generic, Mapping, Optional, Type, TypeVar
 2 | 
 3 | import dacite
 4 | 
 5 | T = TypeVar("T")
 6 | TT = TypeVar("TT", bound=Type)
 7 | 
 8 | 
 9 | class Registry(Generic[T]):
10 |     """
11 |     Used to register and initialize multiple types of a dataclass.
12 |     """
13 | 
14 |     def __init__(self, default_type: Optional[str] = None):
15 |         """
16 |         Initialize the registry.
17 | 
18 |         Args:
19 |             default_type: if given, the "type" key in the config dict is optional
20 |                 and by default this type will be used.
21 |         """
22 |         self._types: Dict[str, Type[T]] = {}
23 |         self.default_type = default_type
24 | 
25 |     def register(self, type_name: str) -> Callable[[TT], TT]:
26 |         """
27 |         Registers a configuration type with the registry.
28 | 
29 |         When registry.from_dict is called to initialize a dataclass, if the
30 |         "type" key in that dictionary is equal to the type_name you give here,
31 |         then the decorated class will be the one initialized from the data
32 |         in the "config" key.
33 | 
34 |         Args:
35 |             type_name: name used in configuration to indicate the decorated
36 |                 class as the target type to be initialized when using from_dict.
37 |         """
38 | 
39 |         def register_func(cls: TT) -> TT:
40 |             self._types[type_name] = cls
41 |             return cls
42 | 
43 |         return register_func
44 | 
45 |     def from_dict(self, config: Mapping[str, Any]) -> T:
46 |         """
47 |         Creates a registered type from the given config dict.
48 | 
49 |         Config should have at least one key, "type", which indicates the type to
50 |         initialize based on its registered type name. This can be omitted if
51 |         this instance was initialized with a default type.
52 | 
53 |         It can also have a "config" key, which is a dict used to initialize the
54 |         dataclass. By default this is an empty dict.
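        For example, with a hypothetical type name "mock" registered as in the
        module-registry tests elsewhere in this package:

            registry.from_dict({"type": "mock", "config": {"param_shapes": [(1, 2, 3)]}})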
55 | """ 56 | config = dict(config) 57 | config.setdefault("config", {}) 58 | if self.default_type is not None: 59 | type_name = config.get("type", self.default_type) 60 | else: 61 | type_name = config["type"] 62 | if type_name not in self._types: 63 | raise ValueError( 64 | f"Received unexpected type {type_name}, " 65 | f"expected one of {self._types.keys()}" 66 | ) 67 | else: 68 | instance = dacite.from_dict( 69 | data_class=self._types[type_name], 70 | data=config["config"], 71 | config=dacite.Config(strict=True), 72 | ) 73 | return instance 74 | -------------------------------------------------------------------------------- /fme/fme/core/registry/test_module_registry.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Iterable, List, Tuple 3 | 4 | import torch 5 | 6 | from .module import ModuleConfig, ModuleSelector 7 | 8 | 9 | class MockModule(torch.nn.Module): 10 | def __init__(self, param_shapes: Iterable[Tuple[int, ...]]): 11 | super().__init__() 12 | for i, shape in enumerate(param_shapes): 13 | setattr(self, f"param{i}", torch.nn.Parameter(torch.randn(shape))) 14 | 15 | 16 | @ModuleSelector.register("mock") 17 | @dataclasses.dataclass 18 | class MockModuleBuilder(ModuleConfig): 19 | param_shapes: List[Tuple[int, ...]] 20 | 21 | def build(self, n_in_channels, n_out_channels, img_shape): 22 | return MockModule(self.param_shapes) 23 | 24 | @classmethod 25 | def from_state(cls, state): 26 | return cls(state["param_shapes"]) 27 | 28 | def get_state(self): 29 | return { 30 | "param_shapes": self.param_shapes, 31 | } 32 | 33 | 34 | def test_register(): 35 | """Make sure that the registry is working as expected.""" 36 | selector = ModuleSelector(type="mock", config={"param_shapes": [(1, 2, 3)]}) 37 | module = selector.build(1, 1, (16, 32)) 38 | assert isinstance(module, MockModule) 39 | -------------------------------------------------------------------------------- /fme/fme/core/regrid.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai2cm/ace/0c36153610a2bd8160d1639670011af9e03b19b2/fme/fme/core/regrid.py -------------------------------------------------------------------------------- /fme/fme/core/scheduler.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Any, Mapping, Optional 3 | 4 | import torch.optim.lr_scheduler 5 | 6 | 7 | @dataclasses.dataclass 8 | class SchedulerConfig: 9 | """ 10 | Configuration for a scheduler to use during training. 11 | 12 | Parameters: 13 | type: Name of scheduler class from torch.optim.lr_scheduler, 14 | no scheduler is used by default. 15 | kwargs: Keyword arguments to pass to the scheduler constructor. 16 | """ 17 | 18 | type: Optional[str] = None 19 | kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict) 20 | 21 | def build( 22 | self, optimizer, max_epochs 23 | ) -> Optional[torch.optim.lr_scheduler._LRScheduler]: 24 | """ 25 | Build the scheduler. 
26 | """ 27 | if self.type is None: 28 | return None 29 | 30 | build_kwargs = {**self.kwargs} 31 | # work-around so we don't need to specify T_max 32 | # in the yaml file for this scheduler 33 | if self.type == "CosineAnnealingLR" and "T_max" not in self.kwargs: 34 | build_kwargs["T_max"] = max_epochs 35 | 36 | scheduler_class = getattr(torch.optim.lr_scheduler, self.type) 37 | return scheduler_class(optimizer=optimizer, **build_kwargs) 38 | -------------------------------------------------------------------------------- /fme/fme/core/test_device.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import fme 4 | 5 | 6 | def test_device_is_defined(): 7 | assert isinstance(fme.get_device(), torch.device) 8 | -------------------------------------------------------------------------------- /fme/fme/core/test_dicts.py: -------------------------------------------------------------------------------- 1 | from fme.core.dicts import to_flat_dict, to_nested_dict 2 | 3 | 4 | def get_cfg_and_args_dicts(): 5 | config_d = { 6 | "top": 1, 7 | "seq": [dict(a=1), dict(a=2)], 8 | "nested": {"k1": 2, "k2": 3, "double_nest": {"k1": 4, "k2": 5}}, 9 | } 10 | 11 | flat_d = { 12 | "top": 1, 13 | "seq": [dict(a=1), dict(a=2)], 14 | "nested.k1": 2, 15 | "nested.k2": 3, 16 | "nested.double_nest.k1": 4, 17 | "nested.double_nest.k2": 5, 18 | } 19 | 20 | return config_d, flat_d 21 | 22 | 23 | def test_to_flat_dict(): 24 | config_d, expected = get_cfg_and_args_dicts() 25 | result = to_flat_dict(config_d) 26 | assert result == expected 27 | 28 | 29 | def test_to_nested_dict(): 30 | expected, args_d = get_cfg_and_args_dicts() 31 | result = to_nested_dict(args_d) 32 | assert result == expected 33 | 34 | 35 | def test_flat_dict_round_trip(): 36 | config_d, _ = get_cfg_and_args_dicts() 37 | 38 | args_d = to_flat_dict(config_d) 39 | result = to_nested_dict(args_d) 40 | 41 | assert result == config_d 42 | -------------------------------------------------------------------------------- /fme/fme/core/test_distributed.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from fme import get_device 5 | 6 | from .distributed import pad_tensor_at_end, unpad_tensor_at_end 7 | 8 | 9 | @pytest.mark.parametrize( 10 | ["padding", "fill_value"], 11 | [ 12 | pytest.param([0, 0, 0], None, id="no_padding"), 13 | pytest.param([1, 1, 1], 0.0, id="padding_1"), 14 | pytest.param([1, 1, 1], 1.0, id="padding_1_fill_one"), 15 | ], 16 | ) 17 | def test_pad_tensor_at_end(padding, fill_value): 18 | tensor = torch.ones(2, 3, 4) 19 | padded_tensor = pad_tensor_at_end(tensor, padding, fill_value) 20 | assert padded_tensor.size() == (2 + padding[0], 3 + padding[1], 4 + padding[2]) 21 | for dim, pad in enumerate(padding): 22 | if pad > 0: 23 | assert torch.allclose( 24 | padded_tensor.select(dim=dim, index=padded_tensor.size(dim) - 1), 25 | torch.tensor(fill_value), 26 | ) 27 | 28 | 29 | @pytest.mark.parametrize( 30 | ["padding"], 31 | [ 32 | pytest.param([0, 0, 0], id="no_padding"), 33 | pytest.param([1, 1, 1], id="padding_1"), 34 | ], 35 | ) 36 | def test_pad_unpad_rountrip(padding): 37 | tensor = torch.ones(2, 3, 4, device=get_device()) 38 | padded_tensor = pad_tensor_at_end(tensor, padding) 39 | unpadded_tensor = unpad_tensor_at_end(padded_tensor, padding) 40 | assert unpadded_tensor.size() == tensor.size() 41 | assert torch.allclose(unpadded_tensor, tensor) 42 | 
-------------------------------------------------------------------------------- /fme/fme/core/test_gridded_ops.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Type 2 | 3 | import pytest 4 | import torch 5 | 6 | from fme.core.gridded_ops import GriddedOperations, HEALPixOperations, LatLonOperations 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "state, expected_class", 11 | [ 12 | ( 13 | { 14 | "type": "LatLonOperations", 15 | "state": {"area_weights": torch.tensor([1.0, 2.0])}, 16 | }, 17 | LatLonOperations, 18 | ), 19 | ( 20 | { 21 | "type": "HEALPixOperations", 22 | "state": {}, 23 | }, 24 | HEALPixOperations, 25 | ), 26 | ], 27 | ) 28 | def test_gridded_operations_from_state( 29 | state: Dict[str, Any], 30 | expected_class: Type[GriddedOperations], 31 | ): 32 | ops = GriddedOperations.from_state(state) 33 | assert isinstance(ops, expected_class) 34 | 35 | recovered_state = ops.to_state() 36 | assert recovered_state == state 37 | 38 | with pytest.raises(RuntimeError): 39 | expected_class.from_state(state["state"]) 40 | -------------------------------------------------------------------------------- /fme/fme/core/test_ocean.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import torch 4 | 5 | from fme.core.ocean import ( 6 | Ocean, 7 | OceanConfig, 8 | SlabOceanConfig, 9 | mixed_layer_temperature_tendency, 10 | ) 11 | 12 | TIMESTEP = datetime.timedelta(hours=6) 13 | 14 | 15 | def test_ocean_prescribed(): 16 | config = OceanConfig(surface_temperature_name="sst", ocean_fraction_name="of") 17 | ocean = Ocean(config, timestep=TIMESTEP) 18 | target_data = {"sst": torch.tensor([22.0, 25.0]), "of": torch.tensor([0.2, 0.8])} 19 | input_data = {"sst": torch.tensor([20.0, 21.0]), "foo": torch.tensor([1, 2])} 20 | gen_data = {"sst": torch.tensor([23.0, 26.0]), "foo": torch.tensor([2, 3])} 21 | output_data = ocean(input_data, gen_data, target_data) 22 | expected_output = {"sst": torch.tensor([23.0, 25.0]), "foo": torch.tensor([2, 3])} 23 | assert set(output_data) == set(expected_output) 24 | for name in output_data: 25 | torch.testing.assert_close(output_data[name], expected_output[name]) 26 | 27 | 28 | def test_ocean_slab(): 29 | config = OceanConfig( 30 | surface_temperature_name="sst", 31 | ocean_fraction_name="of", 32 | slab=SlabOceanConfig( 33 | mixed_layer_depth_name="mld", 34 | q_flux_name="qf", 35 | ), 36 | ) 37 | names_for_net_surface_energy_flux = [ 38 | "DLWRFsfc", 39 | "ULWRFsfc", 40 | "DSWRFsfc", 41 | "USWRFsfc", 42 | "LHTFLsfc", 43 | "SHTFLsfc", 44 | ] 45 | fluxes = {k: torch.tensor([2.0]) for k in names_for_net_surface_energy_flux} 46 | expected_net_surface_energy_flux = torch.tensor([-4.0]) 47 | ocean = Ocean(config, timestep=TIMESTEP) 48 | target_data = { 49 | "mld": torch.tensor([25.0]), 50 | "of": torch.tensor([0.8]), 51 | "qf": torch.tensor([40.0]), 52 | } 53 | input_data = {"sst": torch.tensor([20.0])} 54 | gen_data = {**fluxes, "sst": torch.tensor([25.0])} 55 | output_data = ocean(input_data, gen_data, target_data) 56 | expected_sst_tendency = mixed_layer_temperature_tendency( 57 | expected_net_surface_energy_flux, target_data["qf"], target_data["mld"] 58 | ) 59 | timestep_seconds = TIMESTEP / datetime.timedelta(seconds=1) 60 | expected_sst = input_data["sst"] + timestep_seconds * expected_sst_tendency 61 | expected_output = {**fluxes, "sst": expected_sst} 62 | assert set(output_data) == set(expected_output) 63 | for name in output_data: 64 | 
torch.testing.assert_close(output_data[name], expected_output[name]) 65 | 66 | 67 | def test_mixed_layer_temperature_tendency(): 68 | f_net = torch.tensor([10.0]) 69 | q_flux = torch.tensor([5.0]) 70 | depth = torch.tensor([100.0]) 71 | result = mixed_layer_temperature_tendency( 72 | f_net, q_flux, depth, density=5.0, specific_heat=3.0 73 | ) 74 | expected_result = (f_net + q_flux) / (5 * 3 * depth) 75 | torch.testing.assert_close(result, expected_result) 76 | -------------------------------------------------------------------------------- /fme/fme/core/test_packer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from fme.core.packer import Packer 4 | 5 | 6 | def test_pack_singleton_channels(): 7 | data = { 8 | "u10": torch.randn(2, 4, 8), 9 | "v10": torch.randn(2, 4, 8), 10 | "w10": torch.randn(2, 4, 8), 11 | } 12 | p = Packer(names=["u10", "v10"]) 13 | packed = p.pack(data, axis=1) 14 | assert packed.shape == (2, 2, 4, 8) 15 | assert torch.allclose(packed[:, 0, :, :], data["u10"]) 16 | assert torch.allclose(packed[:, 1, :, :], data["v10"]) 17 | 18 | 19 | def test_unpack(): 20 | tensor = torch.randn(2, 2, 4, 8) 21 | p = Packer(names=["u10", "v10"]) 22 | unpacked = p.unpack(tensor, axis=1) 23 | assert len(unpacked) == 2 24 | assert torch.allclose(unpacked["u10"], tensor[:, 0, :, :]) 25 | assert torch.allclose(unpacked["v10"], tensor[:, 1, :, :]) 26 | 27 | 28 | def test_unpack_first_axis(): 29 | tensor = torch.randn(2, 2, 4, 8) 30 | p = Packer(names=["u10", "v10"]) 31 | unpacked = p.unpack(tensor, axis=0) 32 | assert len(unpacked) == 2 33 | assert torch.allclose(unpacked["u10"], tensor[0, :, :, :]) 34 | assert torch.allclose(unpacked["v10"], tensor[1, :, :, :]) 35 | -------------------------------------------------------------------------------- /fme/fme/core/test_prescriber.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from fme.core.prescriber import Prescriber, PrescriberConfig 5 | 6 | 7 | def test_prescriber_config_build_raises_value_error(): 8 | """Test that error is raised if name not in in_names and out_names.""" 9 | config = PrescriberConfig(prescribed_name="a", mask_name="b", mask_value=1) 10 | with pytest.raises(ValueError): 11 | config.build(in_names=["a"], out_names=["c"]) 12 | 13 | 14 | def test_prescriber(): 15 | """Test that the prescriber overwrites the generated data in the masked region.""" 16 | prescriber = Prescriber(prescribed_name="a", mask_name="mask", mask_value=0) 17 | data = { 18 | "a": torch.rand(2, 4, 4), 19 | "b": torch.rand(2, 4, 4), 20 | "mask": torch.ones(2, 4, 4), 21 | } 22 | target = { 23 | "a": torch.rand(2, 4, 4), 24 | "b": torch.rand(2, 4, 4), 25 | } 26 | data["mask"][:, :, 2:] = 0 27 | gen = {k: torch.rand_like(v) for k, v in target.items()} 28 | expected_gen = {k: v.clone() for k, v in gen.items()} 29 | expected_gen["a"][:, :, 2:] = target["a"][:, :, 2:] 30 | assert not torch.allclose(gen["a"], expected_gen["a"]) 31 | prescribed_gen = prescriber(data, gen, target) 32 | for name in gen: 33 | torch.testing.assert_close(prescribed_gen[name], expected_gen[name]) 34 | # non-integer valued mask 35 | prescriber = Prescriber(prescribed_name="a", mask_name="mask", mask_value=1) 36 | data["mask"] = torch.zeros(2, 4, 4, dtype=torch.float32) + 0.1 37 | data["mask"][:, :, 2:] = 0.7 38 | prescribed_gen = prescriber(data, gen, target) 39 | for name in gen: 40 | torch.testing.assert_close(prescribed_gen[name], 
expected_gen[name]) 41 | 42 | 43 | def test_prescriber_interpolate(): 44 | prescriber = Prescriber( 45 | prescribed_name="a", mask_name="mask", mask_value=1, interpolate=True 46 | ) 47 | data = { 48 | "a": torch.zeros(2, 4, 4), 49 | "b": torch.ones(2, 4, 4) * 4.0, 50 | "mask": torch.ones(2, 4, 4) * 0.25, 51 | } 52 | target = { 53 | "a": torch.ones(2, 4, 4) * 4.0, 54 | "b": torch.zeros(2, 4, 4), 55 | } 56 | prescribed_gen = prescriber(data, data, target) 57 | torch.testing.assert_close(prescribed_gen["a"], torch.ones(2, 4, 4)) 58 | # check that the other variable is not changed 59 | torch.testing.assert_close(prescribed_gen["b"], torch.ones(2, 4, 4) * 4.0) 60 | -------------------------------------------------------------------------------- /fme/fme/core/test_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch.optim.lr_scheduler 2 | 3 | from fme.core.scheduler import SchedulerConfig 4 | 5 | 6 | def test_default_gives_none(): 7 | optimizer = torch.optim.Adam(params=[torch.nn.Parameter()]) 8 | max_epochs = 42 9 | assert SchedulerConfig().build(optimizer, max_epochs) is None 10 | 11 | 12 | def test_build(): 13 | config = SchedulerConfig(type="StepLR", kwargs={"step_size": 1}) 14 | # define dummy parameters for optimizer 15 | optimizer = torch.optim.Adam(params=[torch.nn.Parameter()]) 16 | max_epochs = 42 17 | scheduler = config.build(optimizer, max_epochs) 18 | assert isinstance(scheduler, torch.optim.lr_scheduler.StepLR) 19 | assert scheduler.step_size == 1 20 | assert scheduler.optimizer is optimizer 21 | -------------------------------------------------------------------------------- /fme/fme/core/test_wandb.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from fme.core.wandb import DirectInitializationError, Image, WandB 5 | 6 | 7 | def test_image_is_image_instance(): 8 | wandb = WandB.get_instance() 9 | img = wandb.Image(np.zeros((10, 10))) 10 | assert isinstance(img, Image) 11 | 12 | 13 | def test_wandb_direct_initialization_raises(): 14 | with pytest.raises(DirectInitializationError): 15 | Image(np.zeros((10, 10))) 16 | -------------------------------------------------------------------------------- /fme/fme/core/test_wildcard.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | from typing import List, Optional, Type 3 | 4 | import pytest 5 | import torch 6 | from torch import nn 7 | 8 | from fme.core.wildcard import apply_by_wildcard, wildcard_match 9 | 10 | 11 | @pytest.mark.parametrize( 12 | "pattern, name, expected", 13 | [ 14 | ("*", "abc", True), 15 | ("*", "abc.def", True), 16 | ("abc*", "abc", True), 17 | ("abc*", "abc.def", True), 18 | ("abc*", "def", False), 19 | ("abc.*", "abc.def", True), 20 | ("*.abc.*", "abc.def", False), 21 | ("*.def.*", "abc.def", False), 22 | ("*.def.*", "abc.def.ghi", True), 23 | ("abc.*", "abc.def.ghi", True), 24 | ("abc.*.ghi", "abc.def.ghi", True), 25 | ("abc.*.ghi", "abc.def", False), 26 | ("abc.*.ghi", "abc.def.ghi.jkl", False), 27 | ("*.abc.ghi", "def.abc.ghi", True), 28 | ("*.abc.ghi", "abc.ghi", False), 29 | ("*.abc.ghi", "def.abc.ghi.jkl", False), 30 | ], 31 | ) 32 | def test_wildcard_match(pattern, name, expected): 33 | assert wildcard_match(pattern=pattern, name=name) == expected 34 | 35 | 36 | class NestedModule2(nn.Module): 37 | def __init__(self): 38 | super().__init__() 39 | self.weight = nn.Parameter(torch.randn(3, 3)) 40 | 41 | 42 | 
class NestedModule1(nn.Module): 43 | def __init__(self): 44 | super().__init__() 45 | self.weight = nn.Parameter(torch.randn(3, 3)) 46 | self.nested = NestedModule2() 47 | 48 | 49 | @pytest.mark.parametrize( 50 | "include, exclude, expected_applied, expected_error", 51 | [ 52 | pytest.param(["*"], [], ["weight", "nested.weight"], None, id="include all"), 53 | pytest.param([], ["*"], [], None, id="exclude all"), 54 | pytest.param(["weight"], ["nested.*"], ["weight"], None, id="weight included"), 55 | pytest.param(["*"], ["nested.*"], [], ValueError, id="nested param in both"), 56 | pytest.param(["*"], ["weight"], [], ValueError, id="* include with an exclude"), 57 | pytest.param([], ["weight"], [], ValueError, id="missing weight using exclude"), 58 | pytest.param(["weight"], [], [], ValueError, id="missing weight using include"), 59 | pytest.param( 60 | ["*.weight"], [], [], ValueError, id="mising weight using wildcard include" 61 | ), 62 | ], 63 | ) 64 | def test_apply_by_wildcard( 65 | include: List[str], 66 | exclude: List[str], 67 | expected_applied: List[str], 68 | expected_error: Optional[Type[Exception]], 69 | ): 70 | model = NestedModule1() 71 | 72 | def func(module: nn.Module, name: str): 73 | module.get_parameter(name).requires_grad = False 74 | 75 | if expected_error is not None: 76 | context = pytest.raises(expected_error) 77 | else: 78 | context = contextlib.nullcontext() 79 | 80 | with context: 81 | apply_by_wildcard( 82 | model, 83 | func, 84 | include, 85 | exclude, 86 | ) 87 | 88 | if expected_error is None: 89 | for name, param in model.named_parameters(): 90 | assert param.requires_grad == (name not in expected_applied) 91 | -------------------------------------------------------------------------------- /fme/fme/core/test_winds.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from fme.core.winds import lon_lat_to_xyz, u_v_to_x_y_z_wind 5 | 6 | 7 | def test_u_v_to_x_y_z_wind_energy_conservation(): 8 | """ 9 | Test that the energy is conserved when transforming winds. 10 | """ 11 | # Define a test case 12 | u = np.random.randn(10, 10) 13 | v = np.random.randn(10, 10) 14 | lat = np.random.uniform(-90, 90, size=(10, 10)) 15 | lon = np.random.uniform(-180, 180, size=(10, 10)) 16 | 17 | # Convert to x, y, z 18 | x, y, z = u_v_to_x_y_z_wind(u, v, lat, lon) 19 | 20 | # Compute the energy conservation equation 21 | final_energy = x**2 + y**2 + z**2 22 | initial_energy = u**2 + v**2 23 | 24 | # Check that the energy is conserved 25 | assert np.allclose(final_energy, initial_energy) 26 | 27 | 28 | def test_u_v_to_x_y_z_wind_is_horizontal(): 29 | """ 30 | Test that the transformed winds are perpendicular to the vertical vector. 
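    Since u and v are tangent to the sphere, the transformed wind (wx, wy, wz)
    should have no radial component, i.e. wx*x + wy*y + wz*z == 0 where
    (x, y, z) is the position vector returned by lon_lat_to_xyz.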
31 | """ 32 | # Define a test case 33 | u = np.random.randn(10, 10) 34 | v = np.random.randn(10, 10) 35 | lat = np.random.uniform(-90, 90, size=(10, 10)) 36 | lon = np.random.uniform(-180, 180, size=(10, 10)) 37 | 38 | # Convert to x, y, z 39 | wx, wy, wz = u_v_to_x_y_z_wind(u, v, lat, lon) 40 | x, y, z = lon_lat_to_xyz(lon, lat) 41 | 42 | # Compute the dot product 43 | dot_product = wx * x + wy * y + wz * z 44 | 45 | # Check that the dot product is zero 46 | assert np.allclose(dot_product, 0) 47 | 48 | 49 | @pytest.mark.parametrize( 50 | "u, v, lat, lon, expected_x, expected_y, expected_z", 51 | [ 52 | pytest.param(0, 1, 0, 0, 0, 0, 1, id="north_wind_at_equator"), 53 | pytest.param(1, 0, 0, 0, 0, 1, 0, id="east_wind_at_equator_prime_meridian"), 54 | pytest.param( 55 | 1, 0, 0, 90, -1, 0, 0, id="east_wind_at_equator_90_degrees_east_meridian" 56 | ), 57 | pytest.param(0, -1, 0, 0, 0, 0, -1, id="south_wind_at_equator"), 58 | ], 59 | ) 60 | def test_u_v_to_x_y_z_wind_expected_values( 61 | u, v, lat, lon, expected_x, expected_y, expected_z 62 | ): 63 | """ 64 | Test that the expected values are returned. 65 | """ 66 | # Convert to x, y, z 67 | x, y, z = u_v_to_x_y_z_wind(u, v, lat, lon) 68 | 69 | # Check that the expected values are returned 70 | assert np.allclose(x, expected_x) 71 | assert np.allclose(y, expected_y) 72 | assert np.allclose(z, expected_z) 73 | -------------------------------------------------------------------------------- /fme/fme/core/testing/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributed import mock_distributed 2 | from .wandb import mock_wandb 3 | -------------------------------------------------------------------------------- /fme/fme/core/testing/distributed.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | from typing import List, Optional 3 | 4 | import torch 5 | 6 | from fme.core import distributed 7 | 8 | 9 | class MockDistributed: 10 | def __init__(self, fill_value: float, world_size: int): 11 | self.world_size = world_size 12 | self.fill_value = fill_value 13 | self.reduce_called = False 14 | 15 | def local_batch_size(self, batch_size: int) -> int: 16 | return batch_size 17 | 18 | def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor: 19 | tensor.fill_(self.fill_value) 20 | self.reduce_called = True 21 | return tensor 22 | 23 | def reduce_max(self, tensor: torch.Tensor) -> torch.Tensor: 24 | return tensor + 1 25 | 26 | def reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor: 27 | tensor.fill_(self.fill_value) 28 | self.reduce_called = True 29 | return tensor 30 | 31 | def is_root(self) -> bool: 32 | return True 33 | 34 | def is_distributed(self) -> bool: 35 | return True 36 | 37 | def gather(self, tensor: torch.Tensor) -> List[torch.Tensor]: 38 | return [tensor for i in range(self.world_size)] 39 | 40 | def gather_irregular(self, tensor: torch.Tensor) -> Optional[List[torch.Tensor]]: 41 | """ 42 | Note this uses the actual implementation but mocks the underlying 43 | distributed calls. 44 | """ 45 | return distributed.gather_irregular( 46 | tensor, self.reduce_max, self.gather, is_distributed=self.is_distributed() 47 | ) 48 | 49 | 50 | @contextlib.contextmanager 51 | def mock_distributed(fill_value: float = 0.0, world_size: int = 1): 52 | """ 53 | Mock the distributed singleton to return a MockDistributed object. 54 | 55 | This is useful for testing that metrics are reduced across processes. 
56 | 57 | It will make it so that when any tensor is reduced, it is filled with 58 | the given fill_value, which can be checked for in tests. 59 | """ 60 | original = distributed.singleton 61 | distributed.singleton = MockDistributed( 62 | fill_value=fill_value, world_size=world_size 63 | ) # type: ignore 64 | try: 65 | yield distributed.singleton 66 | finally: 67 | distributed.singleton = original 68 | -------------------------------------------------------------------------------- /fme/fme/core/testing/wandb.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import contextlib 3 | from typing import Any, Dict, List, Mapping 4 | 5 | from fme.core import wandb 6 | from fme.core.distributed import Distributed 7 | 8 | 9 | class MockWandB: 10 | def __init__(self): 11 | self._enabled = False 12 | self._configured = False 13 | self._logs: Dict[int, Dict[str, Any]] = collections.defaultdict(dict) 14 | self._last_step = 0 15 | 16 | def configure(self, log_to_wandb: bool): 17 | dist = Distributed.get_instance() 18 | self._enabled = log_to_wandb and dist.is_root() 19 | self._configured = True 20 | 21 | def init(self, **kwargs): 22 | if not self._configured: 23 | raise RuntimeError( 24 | "must call WandB.configure before WandB init can be called" 25 | ) 26 | if self._enabled: 27 | pass 28 | 29 | def watch(self, modules): 30 | if self._enabled: 31 | # wandb.watch(modules) 32 | pass 33 | 34 | def log(self, data: Mapping[str, Any], step: int, sleep=None): 35 | if step < self._last_step: 36 | raise ValueError( 37 | f"step {step} is less than last step {self._last_step}, " 38 | "steps must be logged in order" 39 | ) 40 | self._last_step = step 41 | # sleep arg is ignored since we don't want to sleep in tests 42 | if self._enabled: 43 | self._logs[step].update(data) 44 | 45 | def get_logs(self) -> List[Dict[str, Any]]: 46 | if len(self._logs) == 0: 47 | return [] 48 | n_logs = max(self._logs.keys()) 49 | return_value: List[Dict[str, Any]] = [dict() for _ in range(n_logs + 1)] 50 | for step, log in self._logs.items(): 51 | return_value[step] = log 52 | return return_value 53 | 54 | def clean_wandb_dir(self, experiment_dir: str): 55 | pass 56 | 57 | def Image(self, *args, **kwargs) -> wandb.Image: 58 | return wandb.Image(*args, direct_access=False, **kwargs) 59 | 60 | def Video(self, *args, **kwargs) -> wandb.Video: 61 | return wandb.Video(*args, direct_access=False, **kwargs) 62 | 63 | def Table(self, *args, **kwargs) -> wandb.Table: 64 | return wandb.Table(*args, direct_access=False, **kwargs) 65 | 66 | def Histogram(self, *args, **kwargs) -> wandb.Histogram: 67 | return wandb.Histogram(*args, direct_access=False, **kwargs) 68 | 69 | @property 70 | def enabled(self) -> bool: 71 | return self._enabled 72 | 73 | 74 | @contextlib.contextmanager 75 | def mock_wandb(): 76 | """ 77 | Mock the distributed singleton to return a MockDistributed object. 78 | 79 | This is useful for testing that metrics are reduced across processes. 80 | 81 | It will make it so that when any tensor is reduced, it is filled with 82 | the given fill_value, which can be checked for in tests. 
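    For the WandB mock specifically, logged metrics are recorded in memory so
    tests can inspect them afterwards. A minimal illustrative sketch (not part
    of the original docstring):

        with mock_wandb() as mock:
            mock.configure(log_to_wandb=True)
            mock.log({"loss": 0.1}, step=0)
            assert mock.get_logs()[0]["loss"] == 0.1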
83 | """ 84 | original = wandb.singleton 85 | wandb.singleton = MockWandB() # type: ignore 86 | try: 87 | yield wandb.singleton 88 | finally: 89 | wandb.singleton = original 90 | -------------------------------------------------------------------------------- /fme/fme/core/typing_.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Dict, Mapping, Optional 3 | 4 | import torch 5 | 6 | TensorMapping = Mapping[str, torch.Tensor] 7 | TensorDict = Dict[str, torch.Tensor] 8 | 9 | 10 | @dataclasses.dataclass 11 | class Slice: 12 | """ 13 | Configuration of a python `slice` built-in. 14 | 15 | Required because `slice` cannot be initialized directly by dacite. 16 | 17 | Parameters: 18 | start: Start index of the slice. 19 | stop: Stop index of the slice. 20 | step: Step of the slice. 21 | """ 22 | 23 | start: Optional[int] = None 24 | stop: Optional[int] = None 25 | step: Optional[int] = None 26 | 27 | @property 28 | def slice(self) -> slice: 29 | return slice(self.start, self.stop, self.step) 30 | -------------------------------------------------------------------------------- /fme/fme/core/wildcard.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Callable, List 3 | 4 | from torch import nn 5 | 6 | 7 | def wildcard_match(pattern: str, name: str) -> bool: 8 | """ 9 | Check if a name matches a wildcard pattern. 10 | 11 | A wildcard pattern can include "*" to match any number of characters. 12 | """ 13 | # use regex 14 | pattern = pattern.replace(".", r"\.") 15 | pattern = pattern.replace("*", ".*") 16 | pattern = f"^{pattern}$" 17 | return bool(re.match(pattern, name)) 18 | 19 | 20 | def apply_by_wildcard( 21 | model: nn.Module, 22 | func: Callable[[nn.Module, str], None], 23 | include: List[str], 24 | exclude: List[str], 25 | ): 26 | missing_parameters = [] 27 | for name in model.state_dict().keys(): 28 | if any(wildcard_match(pattern, name) for pattern in include): 29 | if any(wildcard_match(pattern, name) for pattern in exclude): 30 | raise ValueError( 31 | f"Parameter {name} is included in both include " 32 | f"{include} and exclude {exclude}" 33 | ) 34 | func(model, name) 35 | elif not any(wildcard_match(pattern, name) for pattern in exclude): 36 | missing_parameters.append(name) 37 | if len(missing_parameters) > 0: 38 | raise ValueError( 39 | f"Model has parameters {missing_parameters} which are not " 40 | f"specified in either include {include} " 41 | f"or exclude {exclude}" 42 | ) 43 | return model 44 | -------------------------------------------------------------------------------- /fme/fme/require_gpu.py: -------------------------------------------------------------------------------- 1 | import fme 2 | 3 | """ 4 | Manually triggered for CI tests on GPU so that tests do not 5 | default to CPU if driver issues prevent use of CUDA. 
6 | """ 7 | device = str(fme.get_device()) 8 | print(f"Device: {device}") 9 | assert device.startswith("cuda") 10 | -------------------------------------------------------------------------------- /fme/fme/test_harmonics.py: -------------------------------------------------------------------------------- 1 | """In this test module, we test the torch spherical harmonics module.""" 2 | 3 | import pytest 4 | import torch 5 | import torch_harmonics as harmonics 6 | 7 | 8 | @pytest.mark.parametrize("nlat, nlon", [(6, 12)]) 9 | @pytest.mark.parametrize("grid", ["equiangular", "legendre-gauss"]) 10 | @pytest.mark.parametrize("constant", [1.0, 0.42]) 11 | def test_constant_field(nlat, nlon, grid, constant): 12 | """Tests that the SHT of a constant field has a single non-zero wavenumber, 13 | the first one. 14 | """ 15 | constant_field = torch.tensor(constant).repeat(nlat, nlon) 16 | 17 | sht = harmonics.RealSHT(nlat, nlon, grid=grid) 18 | 19 | coeffs = sht(constant_field).ravel() 20 | zero = torch.zeros(1, dtype=torch.complex64) 21 | 22 | assert not torch.isclose(coeffs[0], zero) 23 | assert torch.all(torch.isclose(zero, coeffs[1:], atol=1e-6)) 24 | 25 | 26 | def _roundtrip(field, grid): 27 | nlat, nlon = field.shape 28 | sht = harmonics.RealSHT(nlat, nlon, grid=grid) 29 | isht = harmonics.InverseRealSHT(nlat, nlon, grid=grid) 30 | return isht(sht(field)) 31 | 32 | 33 | @pytest.mark.parametrize("seed", [0, 1]) 34 | @pytest.mark.parametrize("nlat, nlon", [(6, 12)]) 35 | def test_roundtrip(nlat, nlon, seed, grid="legendre-gauss"): 36 | """Tests that the SHT and ISHT are inverses of each other.""" 37 | torch.manual_seed(seed) 38 | random_field = torch.randn(nlat, nlon) 39 | proj = _roundtrip(random_field, grid) 40 | assert torch.all(torch.isclose(proj, _roundtrip(proj, grid), atol=1e-6)) 41 | -------------------------------------------------------------------------------- /fme/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "fme" 7 | description = "Train and evaluate weather/climate model emulators" 8 | readme = "README.md" 9 | requires-python = ">=3.9" 10 | license = {file = "LICENSE"} 11 | authors = [ 12 | {name = "Ai2 Climate Modeling", email = "climate-public-maintainer@allenai.org"} 13 | ] 14 | keywords = ["weather", "climate", "machine learning", "emulation"] 15 | classifiers = [ 16 | "Intended Audience :: Science/Research", 17 | "Development Status :: 4 - Beta", 18 | "License :: OSI Approved :: Apache Software License", 19 | "Programming Language :: Python :: 3", 20 | "Topic :: Scientific/Engineering :: Artificial Intelligence" 21 | ] 22 | dynamic = ["dependencies", "optional-dependencies", "version"] 23 | [project.urls] 24 | Homepage = "https://github.com/ai2cm/ace" 25 | Documentation = "https://ai2-climate-emulator.readthedocs.io/" 26 | 27 | [tool.setuptools.dynamic] 28 | version = {attr = "fme.__version__"} 29 | dependencies = { file = "requirements.txt" } 30 | optional-dependencies.dev = { file = "dev-requirements.txt" } 31 | optional-dependencies.docs = { file = "docs/requirements.txt" } 32 | optional-dependencies.deploy = { file = "deploy-requirements.txt" } 33 | 34 | [tool.setuptools.packages] 35 | find = {} 36 | 37 | [tool.uv] 38 | cache-keys = [ 39 | { file = "requirements.txt" }, 40 | { file = "dev-requirements.txt" }, 41 | { file = "docs/requirements.txt" }, 42 | ] 43 | 44 | [tool.ruff.lint] 45 | select = 
["D", "E", "F", "I", "W"] 46 | ignore = ["D1", "D200", "D205", "D212", "E203", "W293", "F541", "E402"] 47 | 48 | [tool.ruff.lint.per-file-ignores] 49 | "*/__init__.py" = ["F401"] 50 | "scripts/*" = ["D"] 51 | "test_*.py" = ["D"] 52 | 53 | [tool.ruff.lint.pydocstyle] 54 | convention = "google" 55 | -------------------------------------------------------------------------------- /fme/requirements.txt: -------------------------------------------------------------------------------- 1 | h5py 2 | imageio<=2.27.0 3 | moviepy<2.0.0 # should be able to relax this after wandb updates past 0.18.7 4 | netcdf4 5 | numpy<2 6 | wandb[media] 7 | tensorly 8 | tensorly-torch 9 | xarray 10 | dacite 11 | torch 12 | torch-harmonics==0.6.2 13 | zarr 14 | gcsfs 15 | s3fs 16 | plotly 17 | matplotlib 18 | dask 19 | astropy-healpix 20 | pandas -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Purpose 2 | 3 | This directory contains subdirectories for particular research subtasks. A project 4 | directory can hold configuration files, data preparation scripts, post-training 5 | evaluation scripts, and other files as needed. 6 | 7 | Low barrier for review/merge. 8 | -------------------------------------------------------------------------------- /scripts/data_process/README.md: -------------------------------------------------------------------------------- 1 | # Data processing for full model emulation training 2 | 3 | This directory contains scripts for generating various datasets needed for FME training, including the FV3GFS primary, baseline, and stats datasets. 4 | 5 | It also contains scripts for generating E3SM training data. 6 | 7 | The first step in the process to create intermediate datasets (e.g. `make fv3gfs_AMIP_dataset`) uses argo, and can be run on your Google VM. 8 | See the vcm-workflow-control repo for instructions on how to install and run argo. 9 | 10 | The second step, which produces monthly netCDF files locally (e.g. `make fv3gfs_AMIP_monthly_netcdfs`), can be run on cirrascale in an interactive session. 11 | To create an interactive session, run the following command from the `scripts/data_process` directory: 12 | 13 | ``` 14 | beaker session create --budget ai2/climate --image beaker://jeremym/fme-2bc0033e --gpus 0 --mount hostPath:///net/nfs/climate=/net/nfs/climate --mount hostpath://$(pwd)=/full-model --workdir /full-model/scripts/data_process 15 | ``` 16 | 17 | Doing so will require that your current working directory is a mountable path (e.g. something in /data). 18 | If you'd like to write to a different directory than /net/nfs/climate, you can mount that path instead. 19 | 20 | Once inside the image, you will need to authorize access to GCS by running `gcloud auth application-default login` and following the instructions, including to run `gcloud config set project vcm-ml` afterwards. 21 | 22 | You can then produce the monthly netCDFs in a target directory by modifying the `OUTPUT_DIR` or `OUTPUT_DIR_AMIP` variable in the make command below. 23 | 24 | ``` 25 | make fv3gfs_AMIP_monthly_netcdfs RESOLUTION=4deg OUTPUT_DIR_AMIP=/data/shared/2023-12-20-vertically-resolved-4deg-fme-amip-ensemble-dataset 26 | ``` 27 | 28 | The stats dataset creation step (e.g. 
`make fv3gfs_AMIP_stats_beaker_dataset`) must be run in the fme conda environment (created by `make create_environment` at the top level of this repo), and additionally requires the beaker client is installed ([install instructions](https://beaker-docs.apps.allenai.org/start/install.html)). 29 | -------------------------------------------------------------------------------- /scripts/data_process/beakerpy.Dockerfile: -------------------------------------------------------------------------------- 1 | from python:3.10-slim 2 | 3 | RUN pip install beaker-py==1.30.0 click dacite fsspec==2024.6.1 gcsfs==2024.6.1 4 | -------------------------------------------------------------------------------- /scripts/data_process/compute_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | COMPUTE_DATASET=true 6 | 7 | while [[ "$#" -gt 0 ]] 8 | do case $1 in 9 | --config) CONFIG="$2" 10 | shift;; 11 | --stats-only) COMPUTE_DATASET=false;; 12 | *) echo "Unknown parameter passed: $1" 13 | exit 1;; 14 | esac 15 | shift 16 | done 17 | 18 | if [[ -z "${CONFIG}" ]] 19 | then 20 | echo "Option --config missing" 21 | exit 1; 22 | fi 23 | 24 | names=($(yq -r '.runs | to_entries[].key' ${CONFIG})) 25 | run_directories=($(yq -r '.runs | to_entries[].value' ${CONFIG})) 26 | output_directory=$(yq -r '.data_output_directory' ${CONFIG}) 27 | runs_count=$(yq -r '.runs | length' ${CONFIG}) 28 | runs_count_minus_one=$(($runs_count - 1)) 29 | 30 | # Capture the output of the argo submit command 31 | output=$(argo submit compute_dataset_argo_workflow.yaml \ 32 | -p compute_dataset=${COMPUTE_DATASET} \ 33 | -p python_script="$(< compute_dataset.py)" \ 34 | -p get_stats_script="$(< get_stats.py)" \ 35 | -p combine_stats_script="$(< combine_stats.py)" \ 36 | -p upload_stats_script="$(< upload_stats.py)" \ 37 | -p config="$(< ${CONFIG})" \ 38 | -p names="${names[*]}" \ 39 | -p run_directories="${run_directories[*]}" \ 40 | -p output_directory="${output_directory}" \ 41 | -p runs_count_minus_one=${runs_count_minus_one}) 42 | 43 | # Extract the job name from the output 44 | job_name=$(echo "$output" | grep 'Name:' | awk '{print $2}') 45 | 46 | # Print the job name 47 | echo "Argo job submitted: $job_name" 48 | -------------------------------------------------------------------------------- /scripts/data_process/compute_hpx_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Function to check if directory is in Python path and add it if not 4 | function add_to_pythonpath() { 5 | local dir_to_add="$1" 6 | if [[ ":$PYTHONPATH:" != *":$dir_to_add:"* ]]; then 7 | export PYTHONPATH="$dir_to_add:$PYTHONPATH" 8 | fi 9 | } 10 | 11 | # Add your own full-model directory to Python path if not already included 12 | add_to_pythonpath "~/full-model" 13 | 14 | ARGO=false 15 | 16 | while [[ "$#" -gt 0 ]] 17 | do 18 | case $1 in 19 | --config) CONFIG="$2" 20 | shift;; 21 | --argo) ARGO="$2" 22 | shift;; 23 | *) echo "Unknown parameter passed: $1" 24 | exit 1;; 25 | esac 26 | shift 27 | done 28 | 29 | if [[ -z "${CONFIG}" ]] 30 | then 31 | echo "Option --config missing" 32 | exit 1 33 | fi 34 | 35 | run_directory=$(yq -r '.runs.run_directory' ${CONFIG}) 36 | output_directory=$(yq -r '.data_output_directory' ${CONFIG}) 37 | 38 | if [[ "$ARGO" == "true" ]] 39 | then 40 | output=$(argo submit full-model/scripts/data_process/compute_hpx_dataset_argo_workflow.yaml \ 41 | -p python_script="$(< 
full-model/scripts/data_process/compute_hpx_dataset.py)" \ 42 | -p config="$(< ${CONFIG})" \ 43 | -p run_directory="${run_directory}" \ 44 | -p output_directory="${output_directory}") 45 | 46 | job_name=$(echo "$output" | grep 'Name:' | awk '{print $2}') 47 | echo "Argo job submitted: $job_name" 48 | else 49 | python3 full-model/scripts/data_process/compute_hpx_dataset.py --config="${CONFIG}" \ 50 | --run-directory="${run_directory}" \ 51 | --output-store="${output_directory}" 52 | fi -------------------------------------------------------------------------------- /scripts/data_process/compute_stats.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ./compute_dataset.sh --stats-only "$@" 4 | -------------------------------------------------------------------------------- /scripts/data_process/configs/e3sm-1deg-8layer.yaml: -------------------------------------------------------------------------------- 1 | runs: 2 | 2024-07-10-e3smv2-1deg-testing: "" 3 | data_output_directory: /global/cfs/cdirs/m4492/fme-preprocess/zarr/ 4 | stats: 5 | output_directory: /global/cfs/cdirs/m4492/fme-preprocess/2024-07-10-e3smv2-1deg-testing 6 | start_date: "1970-01-01" 7 | end_date: "1970-12-31" 8 | data_type: E3SMV2 9 | beaker_dataset: e3sm-1deg-8layers-stats-1970 # this is not used in e3sm data processing 10 | dataset_computation: 11 | chunking: 12 | time_dim: 10 13 | latitude_dim: 180 14 | longitude_dim: 360 15 | reference_vertical_coordinate_file: None 16 | time_invariant_dir: /global/cfs/cdirs/m4331/jpduncan/e3smv2/time_invariant 17 | vertical_coarsening_indices: 18 | # computed here: https://github.com/ai2cm/explore/blob/master/jamesd/2023-06-09-e3smv2-vertical-interface-indices.ipynb 19 | - [0, 19] 20 | - [19, 30] 21 | - [30, 38] 22 | - [38, 44] 23 | - [44, 48] 24 | - [48, 53] 25 | - [53, 61] 26 | - [61, 72] 27 | roundtrip_fraction_kept: 1.0 28 | n_split: 100 29 | variable_sources: 30 | time_invariant: 31 | - PHIS 32 | 6hourly_instant/1yr: 33 | - PS 34 | - TS 35 | - T 36 | - U 37 | - V 38 | - Q 39 | - CLDLIQ 40 | - CLDICE 41 | - RAINQM 42 | - SNOWQM 43 | - TMQ 44 | - TGCLDLWP 45 | - TGCLDIWP 46 | - OCNFRAC 47 | - LANDFRAC 48 | - ICEFRAC 49 | 6hourly/1yr: 50 | - PRECT 51 | - LHFLX 52 | - SHFLX 53 | - FLNS 54 | - FLDS 55 | - FSNS 56 | - FSDS 57 | - FSNTOA 58 | - SOLIN 59 | - FLUT 60 | - PRECSC 61 | - PRECSL 62 | - QFLX 63 | standard_names: 64 | longitude_dim: lon 65 | latitude_dim: lat 66 | vertical_dim: lev 67 | vertical_interface_dim: ilev 68 | time_dim: time 69 | surface_pressure: PS 70 | latent_heat_flux: LHFLX 71 | precip_rate: PRECT 72 | precipitable_water_path: precipitable_water_path 73 | pressure_thickness: pressure_thickness_of_atmospheric_layer 74 | air_temperature: T 75 | specific_humidity: Q 76 | cloud_water_mixing_ratio: CLDLIQ 77 | cloud_ice_mixing_ratio: CLDICE 78 | graupel_mixing_ratio: None 79 | rain_mixing_ratio: RAINQM 80 | snow_mixing_ratio: SNOWQM 81 | northward_wind: V 82 | eastward_wind: U 83 | hybrid_level_coeffs: 84 | - hyai 85 | - hybi -------------------------------------------------------------------------------- /scripts/data_process/configs/era5-1deg-16layer-1940-2022.yaml: -------------------------------------------------------------------------------- 1 | runs: 2 | 2024-07-11-era5-1deg-16layer-1940-2022: "" # no real data source, config only for computing stats 3 | data_output_directory: gs://vcm-ml-intermediate 4 | stats: 5 | output_directory: gs://vcm-ml-intermediate/era5-1deg-16layer-stats-1990-2019 6 | 
beaker_dataset: era5-1deg-16layer-stats-1990-2019 7 | start_date: "1990-01-01" 8 | end_date: "2019-12-31" 9 | data_type: ERA5 10 | -------------------------------------------------------------------------------- /scripts/data_process/configs/era5-1deg-8layer-1940-2022.yaml: -------------------------------------------------------------------------------- 1 | runs: 2 | 2024-06-20-era5-1deg-8layer-1940-2022: "" # no real data source, config only for computing stats 3 | data_output_directory: gs://vcm-ml-intermediate 4 | stats: 5 | output_directory: gs://vcm-ml-intermediate/2024-06-20-era5-1deg-8layer-stats-1990-2019 6 | beaker_dataset: era5-1deg-8layer-stats-1990-2019-v2 7 | start_date: "1990-01-01" 8 | end_date: "2019-12-31" 9 | data_type: ERA5 10 | -------------------------------------------------------------------------------- /scripts/data_process/configs/fv3gfs-amip-ensemble-1deg-8layer.yaml: -------------------------------------------------------------------------------- 1 | runs: 2 | ic_0001: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_180_by_360/ic_0001 3 | ic_0002: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_180_by_360/ic_0002 4 | ic_0003: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_180_by_360/ic_0003 5 | ic_0004: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_180_by_360/ic_0004 6 | data_output_directory: gs://vcm-ml-intermediate/2023-10-27-vertically-resolved-1deg-fme-amip-ensemble-dataset 7 | stats: 8 | output_directory: gs://vcm-ml-intermediate/2023-10-27-vertically-resolved-1deg-fme-amip-ensemble-dataset-stats 9 | beaker_dataset: 2023-10-27-vertically-resolved-1deg-fme-amip-ensemble-dataset-stats 10 | start_date: "1990-01-01" 11 | end_date: "2019-12-31" 12 | data_type: FV3GFS 13 | exclude_runs: 14 | - "ic_0004" 15 | dataset_computation: 16 | reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2023-04-13-11-year-C96-FME-reference/vertical-coordinate-file/fv_core.res.nc 17 | vertical_coarsening_indices: 18 | - [0, 18] 19 | - [18, 26] 20 | - [26, 31] 21 | - [31, 36] 22 | - [36, 41] 23 | - [41, 47] 24 | - [47, 53] 25 | - [53, 63] 26 | renaming: 27 | specific_humidity_at_two_meters: Q2m 28 | variable_sources: 29 | fluxes_2d.zarr: 30 | - PRATEsfc 31 | - LHTFLsfc 32 | - SHTFLsfc 33 | - DLWRFsfc 34 | - DSWRFsfc 35 | - DSWRFtoa 36 | - ULWRFsfc 37 | - ULWRFtoa 38 | - USWRFsfc 39 | - USWRFtoa 40 | - precipitable_water_path 41 | - GRAUPELsfc 42 | - ICEsfc 43 | - SNOWsfc 44 | fourcastnet_vanilla.zarr: 45 | - PRESsfc 46 | - HGTsfc 47 | - RH500 48 | - RH850 49 | - TMP500 50 | - TMP850 51 | - UGRD500 52 | - UGRD850 53 | - UGRD1000 54 | - VGRD500 55 | - VGRD850 56 | - VGRD1000 57 | - h50 58 | - h500 59 | - h850 60 | - h1000 61 | - TMP2m 62 | - UGRD10m 63 | - VGRD10m 64 | full_state.zarr: 65 | - surface_temperature 66 | - air_temperature 67 | - specific_humidity 68 | - cloud_water_mixing_ratio 69 | - cloud_ice_mixing_ratio 70 | - graupel_mixing_ratio 71 | - rain_mixing_ratio 72 | - snow_mixing_ratio 73 | - northward_wind 74 | - eastward_wind 75 | - pressure_thickness_of_atmospheric_layer 76 | - soil_moisture 77 | - specific_humidity_at_two_meters 78 | encoded_surface_type.zarr: 79 | - land_fraction 80 | - ocean_fraction 81 | - sea_ice_fraction 82 | -------------------------------------------------------------------------------- 
/scripts/data_process/configs/fv3gfs-amip-ensemble-4deg-8layer.yaml: -------------------------------------------------------------------------------- 1 | runs: 2 | ic_0001: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0001 3 | ic_0002: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0002 4 | ic_0003: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0003 5 | ic_0004: gs://vcm-ml-raw-flexible-retention/2023-10-20-C96-FME-AMIP-ensemble-reference/regridded-zarrs/gaussian_grid_45_by_90/ic_0004 6 | data_output_directory: gs://vcm-ml-intermediate/2023-10-27-vertically-resolved-4deg-fme-amip-ensemble-dataset 7 | stats: 8 | output_directory: gs://vcm-ml-intermediate/2023-10-27-vertically-resolved-4deg-fme-amip-ensemble-dataset-stats 9 | beaker_dataset: 2023-10-27-vertically-resolved-4deg-fme-amip-ensemble-dataset-stats 10 | start_date: "1990-01-01" 11 | end_date: "2019-12-31" 12 | exclude_runs: 13 | - "ic_0004" 14 | data_type: FV3GFS 15 | dataset_computation: 16 | reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2023-04-13-11-year-C96-FME-reference/vertical-coordinate-file/fv_core.res.nc 17 | vertical_coarsening_indices: 18 | - [0, 18] 19 | - [18, 26] 20 | - [26, 31] 21 | - [31, 36] 22 | - [36, 41] 23 | - [41, 47] 24 | - [47, 53] 25 | - [53, 63] 26 | renaming: 27 | specific_humidity_at_two_meters: Q2m 28 | variable_sources: 29 | fluxes_2d.zarr: 30 | - PRATEsfc 31 | - LHTFLsfc 32 | - SHTFLsfc 33 | - DLWRFsfc 34 | - DSWRFsfc 35 | - DSWRFtoa 36 | - ULWRFsfc 37 | - ULWRFtoa 38 | - USWRFsfc 39 | - USWRFtoa 40 | - precipitable_water_path 41 | - GRAUPELsfc 42 | - ICEsfc 43 | - SNOWsfc 44 | fourcastnet_vanilla.zarr: 45 | - PRESsfc 46 | - HGTsfc 47 | - RH500 48 | - RH850 49 | - TMP500 50 | - TMP850 51 | - UGRD500 52 | - UGRD850 53 | - UGRD1000 54 | - VGRD500 55 | - VGRD850 56 | - VGRD1000 57 | - h50 58 | - h500 59 | - h850 60 | - h1000 61 | - TMP2m 62 | - UGRD10m 63 | - VGRD10m 64 | full_state.zarr: 65 | - surface_temperature 66 | - air_temperature 67 | - specific_humidity 68 | - cloud_water_mixing_ratio 69 | - cloud_ice_mixing_ratio 70 | - graupel_mixing_ratio 71 | - rain_mixing_ratio 72 | - snow_mixing_ratio 73 | - northward_wind 74 | - eastward_wind 75 | - pressure_thickness_of_atmospheric_layer 76 | - soil_moisture 77 | - specific_humidity_at_two_meters 78 | encoded_surface_type.zarr: 79 | - land_fraction 80 | - ocean_fraction 81 | - sea_ice_fraction 82 | -------------------------------------------------------------------------------- /scripts/data_process/configs/fv3gfs-c48-ensemble-1deg-8layer.yaml: -------------------------------------------------------------------------------- 1 | runs: 2 | ic_0011: gs://vcm-ml-raw-flexible-retention/2023-08-03-C48-FME-reference-ensemble/regridded-zarrs/gaussian_grid_180_by_360/ic_0011_2021010100 3 | data_output_directory: gs://vcm-ml-intermediate/2023-09-01-vertically-resolved-1deg-fme-c48-baseline-dataset 4 | stats: 5 | output_directory: gs://vcm-ml-intermediate/2023-09-01-vertically-resolved-1deg-fme-c48-baseline-dataset-stats 6 | beaker_dataset: 2023-09-01-vertically-resolved-1deg-fme-c48-baseline-dataset-stats 7 | start_date: "2021-01-01" 8 | end_date: "2030-12-31" 9 | data_type: FV3GFS 10 | dataset_computation: 11 | reference_vertical_coordinate_file: 
gs://vcm-ml-raw-flexible-retention/2023-04-13-11-year-C96-FME-reference/vertical-coordinate-file/fv_core.res.nc 12 | vertical_coarsening_indices: 13 | - [0, 18] 14 | - [18, 26] 15 | - [26, 31] 16 | - [31, 36] 17 | - [36, 41] 18 | - [41, 47] 19 | - [47, 53] 20 | - [53, 63] 21 | renaming: 22 | specific_humidity_at_two_meters: Q2m 23 | variable_sources: 24 | fluxes_2d.zarr: 25 | - PRATEsfc 26 | - LHTFLsfc 27 | - SHTFLsfc 28 | - DLWRFsfc 29 | - DSWRFsfc 30 | - DSWRFtoa 31 | - ULWRFsfc 32 | - ULWRFtoa 33 | - USWRFsfc 34 | - USWRFtoa 35 | - precipitable_water_path 36 | - GRAUPELsfc 37 | - ICEsfc 38 | - SNOWsfc 39 | fourcastnet_vanilla.zarr: 40 | - PRESsfc 41 | - HGTsfc 42 | - RH500 43 | - RH850 44 | - TMP500 45 | - TMP850 46 | - UGRD500 47 | - UGRD850 48 | - UGRD1000 49 | - VGRD500 50 | - VGRD850 51 | - VGRD1000 52 | - h50 53 | - h500 54 | - h850 55 | - h1000 56 | - TMP2m 57 | - UGRD10m 58 | - VGRD10m 59 | full_state.zarr: 60 | - surface_temperature 61 | - air_temperature 62 | - specific_humidity 63 | - cloud_water_mixing_ratio 64 | - cloud_ice_mixing_ratio 65 | - graupel_mixing_ratio 66 | - rain_mixing_ratio 67 | - snow_mixing_ratio 68 | - northward_wind 69 | - eastward_wind 70 | - pressure_thickness_of_atmospheric_layer 71 | - soil_moisture 72 | - specific_humidity_at_two_meters 73 | encoded_surface_type.zarr: 74 | - land_fraction 75 | - ocean_fraction 76 | - sea_ice_fraction 77 | roundtrip_fraction_kept: 0.65 78 | -------------------------------------------------------------------------------- /scripts/data_process/configs/healpix-1deg-8layer-1940-2022.yaml: -------------------------------------------------------------------------------- 1 | runs: 2 | run_directory: /mntdata 3 | stats: 4 | output_directory: /mntdata/2024-08-21-healpix-era5-dataset 5 | beaker_dataset: 2024-08-21-healpix-era5-dataset-stats 6 | start_date: "1990-01-01" 7 | end_date: "2019-12-31" 8 | data_type: ERA5 9 | data_output_directory: /mntdata/2024-08-21-healpix-era5-dataset 10 | dataset_computation: 11 | n_split: 400 12 | reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2023-04-13-11-year-C96-FME-reference/vertical-coordinate-file/fv_core.res.nc 13 | vertical_coarsening_indices: 14 | - [0, 18] 15 | standard_names: 16 | face_dim: "face" 17 | chunking: 18 | face_dim: -1 19 | variable_sources: 20 | 2024-06-20-era5-1deg-8layer-1940-2022.zarr: 21 | - DLWRFsfc 22 | - DPT2m 23 | - DSWRFsfc 24 | - DSWRFtoa 25 | - HGTsfc 26 | - LHTFLsfc 27 | - PRATEsfc 28 | - PRESsfc 29 | - Q200 30 | - Q2m 31 | - Q500 32 | - Q850 33 | - SHTFLsfc 34 | - TMP200 35 | - TMP2m 36 | - TMP500 37 | - TMP850 38 | - UGRD10m 39 | - UGRD200 40 | - UGRD500 41 | - UGRD850 42 | - ULWRFsfc 43 | - ULWRFtoa 44 | - USWRFsfc 45 | - USWRFtoa 46 | - VGRD10m 47 | - VGRD200 48 | - VGRD500 49 | - VGRD850 50 | - air_temperature_0 51 | - air_temperature_1 52 | - air_temperature_2 53 | - air_temperature_3 54 | - air_temperature_4 55 | - air_temperature_5 56 | - air_temperature_6 57 | - air_temperature_7 58 | - ak_0 59 | - ak_1 60 | - ak_2 61 | - ak_3 62 | - ak_4 63 | - ak_5 64 | - ak_6 65 | - ak_7 66 | - ak_8 67 | - bk_0 68 | - bk_1 69 | - bk_2 70 | - bk_3 71 | - bk_4 72 | - bk_5 73 | - bk_6 74 | - bk_7 75 | - bk_8 76 | - eastward_wind_0 77 | - eastward_wind_1 78 | - eastward_wind_2 79 | - eastward_wind_3 80 | - eastward_wind_4 81 | - eastward_wind_5 82 | - eastward_wind_6 83 | - eastward_wind_7 84 | - h1000 85 | - h200 86 | - h250 87 | - h300 88 | - h500 89 | - h700 90 | - h850 91 | - land_fraction 92 | - latitude 93 | - longitude 94 | - 
northward_wind_0 95 | - northward_wind_1 96 | - northward_wind_2 97 | - northward_wind_3 98 | - northward_wind_4 99 | - northward_wind_5 100 | - northward_wind_6 101 | - northward_wind_7 102 | - ocean_fraction 103 | - sea_ice_fraction 104 | - soil_moisture_0 105 | - soil_moisture_1 106 | - soil_moisture_2 107 | - soil_moisture_3 108 | - specific_total_water_0 109 | - specific_total_water_1 110 | - specific_total_water_2 111 | - specific_total_water_3 112 | - specific_total_water_4 113 | - specific_total_water_5 114 | - specific_total_water_6 115 | - specific_total_water_7 116 | - surface_temperature 117 | - tendency_of_total_water_path_due_to_advection 118 | - time 119 | - total_column_water_vapour -------------------------------------------------------------------------------- /scripts/data_process/configs/pre-industrial-CM4-1deg-8layer-trial-run.yaml: -------------------------------------------------------------------------------- 1 | runs: 2 | 2024-09-20-cm4-1deg-8layer-trial-run: gs://vcm-ml-raw-flexible-retention/2024-08-10-pre-industrial-CM4-simulation/regridded-zarrs/gaussian_grid_180_by_360/trial-run 3 | data_output_directory: gs://vcm-ml-intermediate 4 | stats: 5 | output_directory: gs://vcm-ml-intermediate/2024-09-20-cm4-1deg-8layer-trial-run-stats 6 | start_date: "0151-01-01" 7 | end_date: "0159-01-01" 8 | data_type: CM4 9 | beaker_dataset: not-used 10 | dataset_computation: 11 | reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-08-10-pre-industrial-CM4-simulation/vertical-coordinate-file/fv_core.res.nc 12 | vertical_coarsening_indices: 13 | # computed here: https://github.com/ai2cm/explore/blob/master/jamesd/2024-08-13-pre-industiral-CM4-eda/2024-08-28-AM4-vertical-indices.ipynb 14 | - [0, 7] 15 | - [7, 10] 16 | - [10, 13] 17 | - [13, 16] 18 | - [16, 18] 19 | - [18, 22] 20 | - [22, 25] 21 | - [25, 33] 22 | renaming: 23 | specific_humidity_at_two_meters: Q2m 24 | air_temperature_at_two_meters: TMP2m 25 | eastward_wind_at_ten_meters: UGRD10m 26 | northward_wind_at_ten_meters: VGRD10m 27 | variable_sources: 28 | fluxes_2d.zarr: 29 | - PRATEsfc 30 | - SHTFLsfc 31 | - DLWRFsfc 32 | - DSWRFsfc 33 | - DSWRFtoa 34 | - ULWRFsfc 35 | - ULWRFtoa 36 | - USWRFsfc 37 | - USWRFtoa 38 | - eastward_surface_wind_stress 39 | - northward_surface_wind_stress 40 | - surface_evaporation_rate 41 | - total_energy 42 | - total_frozen_precipitation_rate 43 | full_state.zarr: 44 | # 2D vars 45 | - HGTsfc # static 46 | - PRESsfc 47 | - surface_temperature 48 | - air_temperature_at_two_meters 49 | - specific_humidity_at_two_meters 50 | - eastward_wind_at_ten_meters 51 | - northward_wind_at_ten_meters 52 | # 3D vars: 53 | - air_temperature 54 | - specific_humidity # water species 55 | - cloud_water_mixing_ratio # water species 56 | - cloud_ice_mixing_ratio # water species 57 | - eastward_wind 58 | - northward_wind 59 | land_static.zarr: 60 | - land_fraction 61 | full_state_land.zarr: 62 | - column_soil_moisture 63 | full_state_ice.zarr: 64 | - sea_ice_fraction 65 | standard_names: 66 | longitude_dim: lon 67 | latitude_dim: lat 68 | graupel_mixing_ratio: None 69 | rain_mixing_ratio: None 70 | snow_mixing_ratio: None 71 | precipitable_water_path: None 72 | -------------------------------------------------------------------------------- /scripts/data_process/configs/shield-amip-ensemble-c24-4deg-8layer.yaml: -------------------------------------------------------------------------------- 1 | runs: 2 | ic_0001: 
gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-AMIP-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/ic_0001 3 | ic_0002: gs://vcm-ml-raw-flexible-retention/2024-11-11-C24-SHiELD-AMIP-tuned-cdmbgwd/regridded-zarrs/gaussian_grid_45_by_90/ic_0002 4 | data_output_directory: gs://vcm-ml-intermediate/2024-11-11-vertically-resolved-c24-4deg-shield-amip-tuned-cdmbgwd-ensemble-dataset 5 | stats: 6 | output_directory: gs://vcm-ml-intermediate/2024-11-11-vertically-resolved-c24-4deg-shield-amip-tuned-cdmbgwd-ensemble-dataset-stats 7 | beaker_dataset: 2024-11-11-vertically-resolved-c24-4deg-shield-amip-tuned-cdmbgwd-ensemble-dataset-stats 8 | start_date: "1940-01-01" 9 | end_date: "2021-12-31" 10 | data_type: FV3GFS 11 | exclude_runs: 12 | - "ic_0002" 13 | dataset_computation: 14 | reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc 15 | vertical_coarsening_indices: 16 | - [0, 11] 17 | - [11, 21] 18 | - [21, 30] 19 | - [30, 39] 20 | - [39, 49] 21 | - [49, 58] 22 | - [58, 67] 23 | - [67, 79] 24 | renaming: 25 | specific_humidity_at_two_meters: Q2m 26 | air_temperature_at_two_meters: TMP2m 27 | eastward_wind_at_ten_meters: UGRD10m 28 | northward_wind_at_ten_meters: VGRD10m 29 | variable_sources: 30 | fluxes_2d.zarr: 31 | - PRATEsfc 32 | - LHTFLsfc 33 | - SHTFLsfc 34 | - DLWRFsfc 35 | - DSWRFsfc 36 | - DSWRFtoa 37 | - ULWRFsfc 38 | - ULWRFtoa 39 | - USWRFsfc 40 | - USWRFtoa 41 | - precipitable_water_path 42 | - GRAUPELsfc 43 | - ICEsfc 44 | - SNOWsfc 45 | full_state.zarr: 46 | - PRESsfc 47 | - HGTsfc 48 | - RH500 49 | - RH850 50 | - TMP500 51 | - TMP850 52 | - UGRD500 53 | - UGRD850 54 | - UGRD1000 55 | - VGRD500 56 | - VGRD850 57 | - VGRD1000 58 | - h50 59 | - h500 60 | - h850 61 | - h1000 62 | - air_temperature_at_two_meters 63 | - eastward_wind_at_ten_meters 64 | - northward_wind_at_ten_meters 65 | - surface_temperature 66 | - air_temperature 67 | - specific_humidity 68 | - cloud_water_mixing_ratio 69 | - cloud_ice_mixing_ratio 70 | - graupel_mixing_ratio 71 | - rain_mixing_ratio 72 | - snow_mixing_ratio 73 | - northward_wind 74 | - eastward_wind 75 | - pressure_thickness_of_atmospheric_layer 76 | - soil_moisture_0 77 | - soil_moisture_1 78 | - soil_moisture_2 79 | - soil_moisture_3 80 | - snow_cover_fraction 81 | - specific_humidity_at_two_meters 82 | - land_fraction 83 | - ocean_fraction 84 | - sea_ice_fraction 85 | - UGRD200 86 | - VGRD200 87 | - TMP200 88 | - RH200 89 | scalar.zarr: 90 | - global_mean_co2 91 | -------------------------------------------------------------------------------- /scripts/data_process/configs/shield-amip-ensemble-c96-1deg-8layer.yaml: -------------------------------------------------------------------------------- 1 | runs: 2 | ic_0001: gs://vcm-ml-raw-flexible-retention/2024-06-29-C96-SHiELD-AMIP/regridded-zarrs/gaussian_grid_180_by_360/ic_0001 3 | ic_0002: gs://vcm-ml-raw-flexible-retention/2024-06-29-C96-SHiELD-AMIP/regridded-zarrs/gaussian_grid_180_by_360/ic_0002 4 | data_output_directory: gs://vcm-ml-intermediate/2024-07-24-vertically-resolved-c96-1deg-shield-amip-ensemble-dataset 5 | stats: 6 | output_directory: gs://vcm-ml-intermediate/2024-07-24-vertically-resolved-c96-1deg-shield-amip-ensemble-dataset-stats 7 | beaker_dataset: 2024-07-24-vertically-resolved-c96-1deg-shield-amip-ensemble-dataset-stats 8 | start_date: "1940-01-01" 9 | end_date: "2021-12-31" 10 | data_type: FV3GFS 11 | exclude_runs: 12 | - "ic_0002" 13 | dataset_computation: 14 | 
reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc 15 | vertical_coarsening_indices: 16 | - [0, 11] 17 | - [11, 21] 18 | - [21, 30] 19 | - [30, 39] 20 | - [39, 49] 21 | - [49, 58] 22 | - [58, 67] 23 | - [67, 79] 24 | renaming: 25 | specific_humidity_at_two_meters: Q2m 26 | air_temperature_at_two_meters: TMP2m 27 | eastward_wind_at_ten_meters: UGRD10m 28 | northward_wind_at_ten_meters: VGRD10m 29 | variable_sources: 30 | fluxes_2d.zarr: 31 | - PRATEsfc 32 | - LHTFLsfc 33 | - SHTFLsfc 34 | - DLWRFsfc 35 | - DSWRFsfc 36 | - DSWRFtoa 37 | - ULWRFsfc 38 | - ULWRFtoa 39 | - USWRFsfc 40 | - USWRFtoa 41 | - precipitable_water_path 42 | - GRAUPELsfc 43 | - ICEsfc 44 | - SNOWsfc 45 | full_state.zarr: 46 | - PRESsfc 47 | - HGTsfc 48 | - RH500 49 | - RH850 50 | - TMP500 51 | - TMP850 52 | - UGRD500 53 | - UGRD850 54 | - UGRD1000 55 | - VGRD500 56 | - VGRD850 57 | - VGRD1000 58 | - h50 59 | - h500 60 | - h850 61 | - h1000 62 | - air_temperature_at_two_meters 63 | - eastward_wind_at_ten_meters 64 | - northward_wind_at_ten_meters 65 | - surface_temperature 66 | - air_temperature 67 | - specific_humidity 68 | - cloud_water_mixing_ratio 69 | - cloud_ice_mixing_ratio 70 | - graupel_mixing_ratio 71 | - rain_mixing_ratio 72 | - snow_mixing_ratio 73 | - northward_wind 74 | - eastward_wind 75 | - pressure_thickness_of_atmospheric_layer 76 | - soil_moisture_0 77 | - soil_moisture_1 78 | - soil_moisture_2 79 | - soil_moisture_3 80 | - snow_cover_fraction 81 | - specific_humidity_at_two_meters 82 | - land_fraction 83 | - ocean_fraction 84 | - sea_ice_fraction 85 | UGRD200_VGRD200_TMP200_RH200.zarr: 86 | - UGRD200 87 | - VGRD200 88 | - TMP200 89 | - RH200 90 | scalar.zarr: 91 | - global_mean_co2 92 | -------------------------------------------------------------------------------- /scripts/data_process/configs/shield-amip-ensemble-c96-4deg-8layer.yaml: -------------------------------------------------------------------------------- 1 | runs: 2 | ic_0001: gs://vcm-ml-raw-flexible-retention/2024-06-29-C96-SHiELD-AMIP/regridded-zarrs/gaussian_grid_45_by_90/ic_0001 3 | ic_0002: gs://vcm-ml-raw-flexible-retention/2024-06-29-C96-SHiELD-AMIP/regridded-zarrs/gaussian_grid_45_by_90/ic_0002 4 | data_output_directory: gs://vcm-ml-intermediate/2024-07-24-vertically-resolved-c96-4deg-shield-amip-ensemble-dataset 5 | stats: 6 | output_directory: gs://vcm-ml-intermediate/2024-07-24-vertically-resolved-c96-4deg-shield-amip-ensemble-dataset-stats 7 | beaker_dataset: 2024-07-24-vertically-resolved-c96-4deg-shield-amip-ensemble-dataset-stats 8 | start_date: "1940-01-01" 9 | end_date: "2021-12-31" 10 | data_type: FV3GFS 11 | exclude_runs: 12 | - "ic_0002" 13 | dataset_computation: 14 | reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc 15 | vertical_coarsening_indices: 16 | - [0, 11] 17 | - [11, 21] 18 | - [21, 30] 19 | - [30, 39] 20 | - [39, 49] 21 | - [49, 58] 22 | - [58, 67] 23 | - [67, 79] 24 | renaming: 25 | specific_humidity_at_two_meters: Q2m 26 | air_temperature_at_two_meters: TMP2m 27 | eastward_wind_at_ten_meters: UGRD10m 28 | northward_wind_at_ten_meters: VGRD10m 29 | variable_sources: 30 | fluxes_2d.zarr: 31 | - PRATEsfc 32 | - LHTFLsfc 33 | - SHTFLsfc 34 | - DLWRFsfc 35 | - DSWRFsfc 36 | - DSWRFtoa 37 | - ULWRFsfc 38 | - ULWRFtoa 39 | - USWRFsfc 40 | - USWRFtoa 41 | - precipitable_water_path 42 | - GRAUPELsfc 43 | - ICEsfc 44 | - 
SNOWsfc 45 | full_state.zarr: 46 | - PRESsfc 47 | - HGTsfc 48 | - RH500 49 | - RH850 50 | - TMP500 51 | - TMP850 52 | - UGRD500 53 | - UGRD850 54 | - VGRD500 55 | - VGRD850 56 | - h50 57 | - h200 58 | - h500 59 | - h850 60 | - air_temperature_at_two_meters 61 | - eastward_wind_at_ten_meters 62 | - northward_wind_at_ten_meters 63 | - surface_temperature 64 | - air_temperature 65 | - specific_humidity 66 | - cloud_water_mixing_ratio 67 | - cloud_ice_mixing_ratio 68 | - graupel_mixing_ratio 69 | - rain_mixing_ratio 70 | - snow_mixing_ratio 71 | - northward_wind 72 | - eastward_wind 73 | - pressure_thickness_of_atmospheric_layer 74 | - soil_moisture_0 75 | - soil_moisture_1 76 | - soil_moisture_2 77 | - soil_moisture_3 78 | - column_soil_moisture 79 | - snow_cover_fraction 80 | - specific_humidity_at_two_meters 81 | - land_fraction 82 | - ocean_fraction 83 | - sea_ice_fraction 84 | UGRD200_VGRD200_TMP200_RH200.zarr: 85 | - UGRD200 86 | - VGRD200 87 | - TMP200 88 | - RH200 89 | scalar.zarr: 90 | - global_mean_co2 91 | -------------------------------------------------------------------------------- /scripts/data_process/configs/shield-c96-4deg-8layer.yaml: -------------------------------------------------------------------------------- 1 | runs: 2 | ic_0001: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/regridded-zarrs/gaussian_grid_45_by_90/repeating-sst 3 | data_output_directory: gs://vcm-ml-intermediate/2024-04-02-vertically-resolved-4deg-c96-shield-fme-dataset 4 | stats: 5 | output_directory: gs://vcm-ml-intermediate/2024-04-02-vertically-resolved-4deg-c96-shield-fme-dataset-stats 6 | beaker_dataset: 2024-04-02-vertically-resolved-4deg-c96-shield-fme-dataset-stats 7 | # start_date: "2035-01-01" # start of run 8 | start_date: "2036-01-01" # we exclude just the first year so we can use it as an initial condition 9 | end_date: "2060-12-31" # end of run 10 | data_type: FV3GFS 11 | dataset_computation: 12 | reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc 13 | vertical_coarsening_indices: 14 | - [0, 11] 15 | - [11, 21] 16 | - [21, 30] 17 | - [30, 39] 18 | - [39, 49] 19 | - [49, 58] 20 | - [58, 67] 21 | - [67, 79] 22 | renaming: 23 | specific_humidity_at_two_meters: Q2m 24 | air_temperature_at_two_meters: TMP2m 25 | eastward_wind_at_ten_meters: UGRD10m 26 | northward_wind_at_ten_meters: VGRD10m 27 | variable_sources: 28 | fluxes_2d.zarr: 29 | - PRATEsfc 30 | - LHTFLsfc 31 | - SHTFLsfc 32 | - DLWRFsfc 33 | - DSWRFsfc 34 | - DSWRFtoa 35 | - ULWRFsfc 36 | - ULWRFtoa 37 | - USWRFsfc 38 | - USWRFtoa 39 | - precipitable_water_path 40 | - GRAUPELsfc 41 | - ICEsfc 42 | - SNOWsfc 43 | full_state.zarr: 44 | - surface_temperature 45 | - air_temperature 46 | - specific_humidity 47 | - cloud_water_mixing_ratio 48 | - cloud_ice_mixing_ratio 49 | - graupel_mixing_ratio 50 | - rain_mixing_ratio 51 | - snow_mixing_ratio 52 | - northward_wind 53 | - eastward_wind 54 | - pressure_thickness_of_atmospheric_layer 55 | - PRESsfc 56 | - HGTsfc 57 | - column_soil_moisture 58 | - soil_moisture_0 59 | - soil_moisture_1 60 | - soil_moisture_2 61 | - soil_moisture_3 62 | - land_fraction 63 | - ocean_fraction 64 | - sea_ice_fraction 65 | - specific_humidity_at_two_meters 66 | - air_temperature_at_two_meters 67 | - northward_wind_at_ten_meters 68 | - eastward_wind_at_ten_meters -------------------------------------------------------------------------------- 
/scripts/data_process/configs/shield-som-abrupt-co2-increase-c96-1deg-8layer.yaml: -------------------------------------------------------------------------------- 1 | runs: 2 | abrupt-2xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/abrupt-2xCO2 3 | abrupt-3xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/abrupt-3xCO2 4 | abrupt-4xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/abrupt-4xCO2 5 | data_output_directory: gs://vcm-ml-intermediate/2024-08-14-vertically-resolved-1deg-c96-shield-som-abrupt-co2-increase-fme-dataset 6 | stats: 7 | output_directory: gs://vcm-ml-intermediate/2024-08-14-vertically-resolved-1deg-c96-shield-som-abrupt-co2-increase-fme-dataset-stats 8 | beaker_dataset: 2024-07-16-vertically-resolved-1deg-fme-c96-shield-som-abrupt-co2-increase-dataset-stats 9 | start_date: null 10 | end_date: null 11 | data_type: FV3GFS 12 | exclude_runs: 13 | - abrupt-2xCO2 14 | - abrupt-3xCO2 15 | - abrupt-4xCO2 16 | dataset_computation: 17 | reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc 18 | vertical_coarsening_indices: 19 | - [0, 11] 20 | - [11, 21] 21 | - [21, 30] 22 | - [30, 39] 23 | - [39, 49] 24 | - [49, 58] 25 | - [58, 67] 26 | - [67, 79] 27 | renaming: 28 | specific_humidity_at_two_meters: Q2m 29 | air_temperature_at_two_meters: TMP2m 30 | eastward_wind_at_ten_meters: UGRD10m 31 | northward_wind_at_ten_meters: VGRD10m 32 | variable_sources: 33 | fluxes_2d.zarr: 34 | - PRATEsfc 35 | - LHTFLsfc 36 | - SHTFLsfc 37 | - DLWRFsfc 38 | - DSWRFsfc 39 | - DSWRFtoa 40 | - ULWRFsfc 41 | - ULWRFtoa 42 | - USWRFsfc 43 | - USWRFtoa 44 | - precipitable_water_path 45 | - GRAUPELsfc 46 | - ICEsfc 47 | - SNOWsfc 48 | full_state.zarr: 49 | - surface_temperature 50 | - air_temperature 51 | - specific_humidity 52 | - cloud_water_mixing_ratio 53 | - cloud_ice_mixing_ratio 54 | - graupel_mixing_ratio 55 | - rain_mixing_ratio 56 | - snow_mixing_ratio 57 | - northward_wind 58 | - eastward_wind 59 | - pressure_thickness_of_atmospheric_layer 60 | - PRESsfc 61 | - HGTsfc 62 | - column_soil_moisture 63 | - soil_moisture_0 64 | - soil_moisture_1 65 | - soil_moisture_2 66 | - soil_moisture_3 67 | - land_fraction 68 | - ocean_fraction 69 | - sea_ice_fraction 70 | - specific_humidity_at_two_meters 71 | - air_temperature_at_two_meters 72 | - northward_wind_at_ten_meters 73 | - eastward_wind_at_ten_meters 74 | - RH200 75 | - RH500 76 | - RH850 77 | - TMP200 78 | - TMP500 79 | - TMP850 80 | - UGRD200 81 | - UGRD500 82 | - UGRD850 83 | - VGRD200 84 | - VGRD500 85 | - VGRD850 86 | - h50 87 | - h200 88 | - h500 89 | - h850 90 | ocean_forcing.zarr: 91 | - prescribed_mixed_layer_depth 92 | - prescribed_qflux 93 | scalar.zarr: 94 | - global_mean_co2 95 | -------------------------------------------------------------------------------- /scripts/data_process/configs/shield-som-abrupt-co2-increase-c96-4deg-8layer.yaml: -------------------------------------------------------------------------------- 1 | runs: 2 | abrupt-2xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/abrupt-2xCO2 3 | abrupt-3xCO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/abrupt-3xCO2 4 | abrupt-4xCO2: 
gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/abrupt-4xCO2 5 | data_output_directory: gs://vcm-ml-intermediate/2024-08-14-vertically-resolved-4deg-c96-shield-som-abrupt-co2-increase-fme-dataset 6 | stats: 7 | output_directory: gs://vcm-ml-intermediate/2024-08-14-vertically-resolved-4deg-c96-shield-som-abrupt-co2-increase-fme-dataset-stats 8 | beaker_dataset: 2024-08-14-vertically-resolved-4deg-fme-c96-shield-som-abrupt-co2-increase-dataset-stats 9 | start_date: null 10 | end_date: null 11 | data_type: FV3GFS 12 | exclude_runs: 13 | - abrupt-2xCO2 14 | - abrupt-3xCO2 15 | - abrupt-4xCO2 16 | dataset_computation: 17 | reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc 18 | vertical_coarsening_indices: 19 | - [0, 11] 20 | - [11, 21] 21 | - [21, 30] 22 | - [30, 39] 23 | - [39, 49] 24 | - [49, 58] 25 | - [58, 67] 26 | - [67, 79] 27 | renaming: 28 | specific_humidity_at_two_meters: Q2m 29 | air_temperature_at_two_meters: TMP2m 30 | eastward_wind_at_ten_meters: UGRD10m 31 | northward_wind_at_ten_meters: VGRD10m 32 | variable_sources: 33 | fluxes_2d.zarr: 34 | - PRATEsfc 35 | - LHTFLsfc 36 | - SHTFLsfc 37 | - DLWRFsfc 38 | - DSWRFsfc 39 | - DSWRFtoa 40 | - ULWRFsfc 41 | - ULWRFtoa 42 | - USWRFsfc 43 | - USWRFtoa 44 | - precipitable_water_path 45 | - GRAUPELsfc 46 | - ICEsfc 47 | - SNOWsfc 48 | full_state.zarr: 49 | - surface_temperature 50 | - air_temperature 51 | - specific_humidity 52 | - cloud_water_mixing_ratio 53 | - cloud_ice_mixing_ratio 54 | - graupel_mixing_ratio 55 | - rain_mixing_ratio 56 | - snow_mixing_ratio 57 | - northward_wind 58 | - eastward_wind 59 | - pressure_thickness_of_atmospheric_layer 60 | - PRESsfc 61 | - HGTsfc 62 | - column_soil_moisture 63 | - soil_moisture_0 64 | - soil_moisture_1 65 | - soil_moisture_2 66 | - soil_moisture_3 67 | - land_fraction 68 | - ocean_fraction 69 | - sea_ice_fraction 70 | - specific_humidity_at_two_meters 71 | - air_temperature_at_two_meters 72 | - northward_wind_at_ten_meters 73 | - eastward_wind_at_ten_meters 74 | - RH200 75 | - RH500 76 | - RH850 77 | - TMP200 78 | - TMP500 79 | - TMP850 80 | - UGRD200 81 | - UGRD500 82 | - UGRD850 83 | - VGRD200 84 | - VGRD500 85 | - VGRD850 86 | - h50 87 | - h200 88 | - h500 89 | - h850 90 | ocean_forcing.zarr: 91 | - prescribed_mixed_layer_depth 92 | - prescribed_qflux 93 | scalar.zarr: 94 | - global_mean_co2 95 | -------------------------------------------------------------------------------- /scripts/data_process/configs/shield-som-increasing-co2-c96-1deg-8layer.yaml: -------------------------------------------------------------------------------- 1 | runs: 2 | increasing-CO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/increasing-CO2 3 | data_output_directory: gs://vcm-ml-intermediate/2024-07-16-vertically-resolved-1deg-c96-shield-som-increasing-co2-fme-dataset 4 | stats: 5 | output_directory: gs://vcm-ml-intermediate/2024-07-16-vertically-resolved-1deg-c96-shield-som-increasing-co2-fme-dataset-stats 6 | beaker_dataset: 2024-07-16-vertically-resolved-1deg-fme-c96-shield-som-increasing-co2-dataset-stats 7 | start_date: null 8 | end_date: null 9 | data_type: FV3GFS 10 | exclude_runs: 11 | - increasing-CO2 12 | dataset_computation: 13 | reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc 14 | 
vertical_coarsening_indices: 15 | - [0, 11] 16 | - [11, 21] 17 | - [21, 30] 18 | - [30, 39] 19 | - [39, 49] 20 | - [49, 58] 21 | - [58, 67] 22 | - [67, 79] 23 | renaming: 24 | specific_humidity_at_two_meters: Q2m 25 | air_temperature_at_two_meters: TMP2m 26 | eastward_wind_at_ten_meters: UGRD10m 27 | northward_wind_at_ten_meters: VGRD10m 28 | variable_sources: 29 | fluxes_2d.zarr: 30 | - PRATEsfc 31 | - LHTFLsfc 32 | - SHTFLsfc 33 | - DLWRFsfc 34 | - DSWRFsfc 35 | - DSWRFtoa 36 | - ULWRFsfc 37 | - ULWRFtoa 38 | - USWRFsfc 39 | - USWRFtoa 40 | - precipitable_water_path 41 | - GRAUPELsfc 42 | - ICEsfc 43 | - SNOWsfc 44 | full_state.zarr: 45 | - surface_temperature 46 | - air_temperature 47 | - specific_humidity 48 | - cloud_water_mixing_ratio 49 | - cloud_ice_mixing_ratio 50 | - graupel_mixing_ratio 51 | - rain_mixing_ratio 52 | - snow_mixing_ratio 53 | - northward_wind 54 | - eastward_wind 55 | - pressure_thickness_of_atmospheric_layer 56 | - PRESsfc 57 | - HGTsfc 58 | - column_soil_moisture 59 | - soil_moisture_0 60 | - soil_moisture_1 61 | - soil_moisture_2 62 | - soil_moisture_3 63 | - land_fraction 64 | - ocean_fraction 65 | - sea_ice_fraction 66 | - specific_humidity_at_two_meters 67 | - air_temperature_at_two_meters 68 | - northward_wind_at_ten_meters 69 | - eastward_wind_at_ten_meters 70 | - RH200 71 | - RH500 72 | - RH850 73 | - TMP200 74 | - TMP500 75 | - TMP850 76 | - UGRD200 77 | - UGRD500 78 | - UGRD850 79 | - VGRD200 80 | - VGRD500 81 | - VGRD850 82 | - h50 83 | - h200 84 | - h500 85 | - h850 86 | ocean_forcing.zarr: 87 | - prescribed_mixed_layer_depth 88 | - prescribed_qflux 89 | scalar.zarr: 90 | - global_mean_co2 91 | -------------------------------------------------------------------------------- /scripts/data_process/configs/shield-som-increasing-co2-c96-4deg-8layer.yaml: -------------------------------------------------------------------------------- 1 | runs: 2 | increasing-CO2: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_45_by_90/increasing-CO2 3 | data_output_directory: gs://vcm-ml-intermediate/2024-07-16-vertically-resolved-4deg-c96-shield-som-increasing-co2-fme-dataset 4 | stats: 5 | output_directory: gs://vcm-ml-intermediate/2024-07-16-vertically-resolved-4deg-c96-shield-som-increasing-co2-fme-dataset-stats 6 | beaker_dataset: 2024-07-16-vertically-resolved-4deg-fme-c96-shield-som-increasing-co2-dataset-stats 7 | start_date: null 8 | end_date: null 9 | data_type: FV3GFS 10 | exclude_runs: 11 | - increasing-CO2 12 | dataset_computation: 13 | reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc 14 | vertical_coarsening_indices: 15 | - [0, 11] 16 | - [11, 21] 17 | - [21, 30] 18 | - [30, 39] 19 | - [39, 49] 20 | - [49, 58] 21 | - [58, 67] 22 | - [67, 79] 23 | renaming: 24 | specific_humidity_at_two_meters: Q2m 25 | air_temperature_at_two_meters: TMP2m 26 | eastward_wind_at_ten_meters: UGRD10m 27 | northward_wind_at_ten_meters: VGRD10m 28 | variable_sources: 29 | fluxes_2d.zarr: 30 | - PRATEsfc 31 | - LHTFLsfc 32 | - SHTFLsfc 33 | - DLWRFsfc 34 | - DSWRFsfc 35 | - DSWRFtoa 36 | - ULWRFsfc 37 | - ULWRFtoa 38 | - USWRFsfc 39 | - USWRFtoa 40 | - precipitable_water_path 41 | - GRAUPELsfc 42 | - ICEsfc 43 | - SNOWsfc 44 | full_state.zarr: 45 | - surface_temperature 46 | - air_temperature 47 | - specific_humidity 48 | - cloud_water_mixing_ratio 49 | - cloud_ice_mixing_ratio 50 | - graupel_mixing_ratio 51 | - rain_mixing_ratio 52 | - 
snow_mixing_ratio 53 | - northward_wind 54 | - eastward_wind 55 | - pressure_thickness_of_atmospheric_layer 56 | - PRESsfc 57 | - HGTsfc 58 | - column_soil_moisture 59 | - soil_moisture_0 60 | - soil_moisture_1 61 | - soil_moisture_2 62 | - soil_moisture_3 63 | - land_fraction 64 | - ocean_fraction 65 | - sea_ice_fraction 66 | - specific_humidity_at_two_meters 67 | - air_temperature_at_two_meters 68 | - northward_wind_at_ten_meters 69 | - eastward_wind_at_ten_meters 70 | - RH200 71 | - RH500 72 | - RH850 73 | - TMP200 74 | - TMP500 75 | - TMP850 76 | - UGRD200 77 | - UGRD500 78 | - UGRD850 79 | - VGRD200 80 | - VGRD500 81 | - VGRD850 82 | - h50 83 | - h200 84 | - h500 85 | - h850 86 | ocean_forcing.zarr: 87 | - prescribed_mixed_layer_depth 88 | - prescribed_qflux 89 | scalar.zarr: 90 | - global_mean_co2 91 | -------------------------------------------------------------------------------- /scripts/data_process/configs/shield-som-radiation-multi-call-c96-1deg-8layer.yaml: -------------------------------------------------------------------------------- 1 | runs: 2 | radiation-multi-call: gs://vcm-ml-raw-flexible-retention/2024-07-03-C96-SHiELD-SOM/regridded-zarrs/gaussian_grid_180_by_360/radiation-multi-call 3 | data_output_directory: gs://vcm-ml-intermediate/2024-10-22-vertically-resolved-1deg-c96-shield-som-radiation-multi-call-fme-dataset 4 | stats: 5 | output_directory: gs://vcm-ml-intermediate/2024-10-22-vertically-resolved-1deg-c96-shield-som-radiation-multi-call-fme-dataset-stats 6 | beaker_dataset: 2024-10-22-vertically-resolved-1deg-fme-c96-radiation-multi-call-co2-dataset-stats 7 | start_date: null 8 | end_date: null 9 | data_type: FV3GFS 10 | exclude_runs: 11 | - radiation-multi-call 12 | dataset_computation: 13 | reference_vertical_coordinate_file: gs://vcm-ml-raw-flexible-retention/2024-03-10-C96-SHiELD-FME-reference/vertical-coordinate-file/fv_core.res.nc 14 | vertical_coarsening_indices: 15 | - [0, 11] 16 | - [11, 21] 17 | - [21, 30] 18 | - [30, 39] 19 | - [39, 49] 20 | - [49, 58] 21 | - [58, 67] 22 | - [67, 79] 23 | renaming: 24 | specific_humidity_at_two_meters: Q2m 25 | air_temperature_at_two_meters: TMP2m 26 | eastward_wind_at_ten_meters: UGRD10m 27 | northward_wind_at_ten_meters: VGRD10m 28 | variable_sources: 29 | fluxes_2d.zarr: 30 | - PRATEsfc 31 | - LHTFLsfc 32 | - SHTFLsfc 33 | - DLWRFsfc 34 | - DSWRFsfc 35 | - DSWRFtoa 36 | - ULWRFsfc 37 | - ULWRFtoa 38 | - USWRFsfc 39 | - USWRFtoa 40 | - precipitable_water_path 41 | - GRAUPELsfc 42 | - ICEsfc 43 | - SNOWsfc 44 | full_state.zarr: 45 | - surface_temperature 46 | - air_temperature 47 | - specific_humidity 48 | - cloud_water_mixing_ratio 49 | - cloud_ice_mixing_ratio 50 | - graupel_mixing_ratio 51 | - rain_mixing_ratio 52 | - snow_mixing_ratio 53 | - northward_wind 54 | - eastward_wind 55 | - pressure_thickness_of_atmospheric_layer 56 | - PRESsfc 57 | - HGTsfc 58 | - column_soil_moisture 59 | - soil_moisture_0 60 | - soil_moisture_1 61 | - soil_moisture_2 62 | - soil_moisture_3 63 | - land_fraction 64 | - ocean_fraction 65 | - sea_ice_fraction 66 | - specific_humidity_at_two_meters 67 | - air_temperature_at_two_meters 68 | - northward_wind_at_ten_meters 69 | - eastward_wind_at_ten_meters 70 | - RH200 71 | - RH500 72 | - RH850 73 | - TMP200 74 | - TMP500 75 | - TMP850 76 | - UGRD200 77 | - UGRD500 78 | - UGRD850 79 | - VGRD200 80 | - VGRD500 81 | - VGRD850 82 | - h50 83 | - h200 84 | - h500 85 | - h850 86 | ocean_forcing.zarr: 87 | - prescribed_mixed_layer_depth 88 | - prescribed_qflux 89 | scalar.zarr: 90 | - 
global_mean_co2 91 | -------------------------------------------------------------------------------- /scripts/data_process/convert_to_monthly_netcdf.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import click 4 | import numpy as np 5 | import pandas as pd 6 | import xarray as xr 7 | 8 | 9 | @click.command() 10 | @click.argument("input_zarr") 11 | @click.argument("output_directory") 12 | @click.option("--start-date", help="For subsetting, e.g. '2016-01-01'") 13 | @click.option("--end-date", help="For subsetting, e.g. '2016-12-31'") 14 | @click.option("--nc-format", default="NETCDF4", help="netCDF file format") 15 | @click.option( 16 | "--prepend-nans", 17 | is_flag=True, 18 | help="Prepend NaNs to first timestep. Used for baseline " 19 | "which is missing initial condition.", 20 | ) 21 | def main(input_zarr, output_directory, start_date, end_date, nc_format, prepend_nans): 22 | """Save data at INPUT_ZARR to monthly netcdf files in OUTPUT_DIRECTORY. 23 | It is assumed that OUTPUT_DIRECTORY does not exist.""" 24 | os.makedirs(output_directory) 25 | ds = xr.open_zarr(input_zarr) 26 | if prepend_nans: 27 | # prepend NaNs to first timestep 28 | ds = prepend_nans_to_dataset(ds) 29 | ds = ds.sel(time=slice(start_date, end_date)) 30 | monthly_ds = ds.resample(time="MS") 31 | for label, data in monthly_ds: 32 | if isinstance(label, np.datetime64): 33 | # np.datetime64 times do not have a strftime method, 34 | # so convert to pd.Timestamp 35 | label = pd.Timestamp(label) 36 | print(f"Processing month {label}") 37 | filename = os.path.join(output_directory, label.strftime("%Y%m%d%H") + ".nc") 38 | # use these options to enable opening data with netCDF4.MFDataset 39 | data.to_netcdf(filename, unlimited_dims=["time"], format=nc_format) 40 | 41 | 42 | def prepend_nans_to_dataset(ds: xr.Dataset) -> xr.Dataset: 43 | """Prepend NaNs to time dimension of an xarray dataset.""" 44 | time_dependent_vars = [v for v in ds if "time" in ds[v].dims] 45 | time_dependent_ds = ds[time_dependent_vars] 46 | prepend_step = xr.full_like(time_dependent_ds.isel(time=0), np.nan) 47 | delta_t = ds["time"].values[1] - ds["time"].values[0] 48 | prepend_step["time"] = ds["time"].values[0] - delta_t 49 | return xr.concat([prepend_step, ds], dim="time").transpose(*ds.dims) 50 | 51 | 52 | def test_prepend_nans(): 53 | ds = xr.tutorial.open_dataset("air_temperature") 54 | ds_prepended = prepend_nans_to_dataset(ds) 55 | assert ds_prepended.sizes["time"] == ds.sizes["time"] + 1 56 | assert np.isnan(ds_prepended.isel(time=0)["air"].values).all() 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /scripts/data_process/convert_to_monthly_netcdf_fv3gfs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script launches N_IC jobs to convert all GCS zarr data to local monthly netCDFs 4 | 5 | while [[ "$#" -gt 0 ]] 6 | do case $1 in 7 | --input-url) BASE_INPUT_URL="$2" 8 | shift;; 9 | --n-ic) N_IC=$2 10 | shift;; 11 | --output-dir) BASE_OUTPUT_DIR="$2" 12 | shift;; 13 | --start-date) START_DATE="$2" 14 | shift;; 15 | --end-date) END_DATE="$2" 16 | shift;; 17 | *) echo "Unknown parameter passed: $1" 18 | exit 1;; 19 | esac 20 | shift 21 | done 22 | 23 | if [[ -z "${BASE_INPUT_URL}" ]] 24 | then 25 | echo "Option --input-url missing" 26 | exit 1; 27 | elif [[ -z "${N_IC}" ]] 28 | then 29 | echo "Option --n-ic missing" 30 | exit 1; 31 
| elif [[ -z "${BASE_OUTPUT_DIR}" ]] 32 | then 33 | echo "Option --output-dir missing" 34 | exit 1; 35 | elif [[ -z "${START_DATE}" ]] 36 | then 37 | echo "Option --start-date missing" 38 | exit 1; 39 | elif [[ -z "${END_DATE}" ]] 40 | then 41 | echo "Option --end-date missing" 42 | exit 1; 43 | fi 44 | 45 | 46 | 47 | for IC in $(seq 1 $(( N_IC ))); do 48 | IC_STR=$(printf "%04d" ${IC}) 49 | INPUT_URL=${BASE_INPUT_URL}/ic_${IC_STR}.zarr 50 | OUTPUT_DIR=${BASE_OUTPUT_DIR}/ic_${IC_STR} 51 | python convert_to_monthly_netcdf.py \ 52 | $INPUT_URL \ 53 | $OUTPUT_DIR \ 54 | --start-date $START_DATE \ 55 | --end-date $END_DATE & 56 | done 57 | -------------------------------------------------------------------------------- /scripts/data_process/earth2grid.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:23.08-py3 2 | 3 | # Clone and install earth2grid if not already installed 4 | RUN PACKAGE=earth2grid && \ 5 | if ! pip show "$PACKAGE" &>/dev/null; then \ 6 | git clone https://github.com/NVlabs/earth2grid.git && \ 7 | cd earth2grid && \ 8 | pip install --no-build-isolation . && \ 9 | cd .. && \ 10 | rm -rf earth2grid; \ 11 | fi -------------------------------------------------------------------------------- /scripts/data_process/generate_beaker_stats_dataset_fv3gfs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | while [[ "$#" -gt 0 ]] 6 | do case $1 in 7 | --input-url) INPUT_URL="$2" 8 | shift;; 9 | --start-date) START_DATE="$2" 10 | shift;; 11 | --end-date) END_DATE="$2" 12 | shift;; 13 | --name) DATASET_NAME="$2" 14 | shift;; 15 | --desc) DATASET_DESC="$2" 16 | shift;; 17 | --script-flags) SCRIPT_FLAGS="$2" 18 | shift;; 19 | *) echo "Unknown parameter passed: $1" 20 | exit 1;; 21 | esac 22 | shift 23 | done 24 | 25 | if [[ -z "${INPUT_URL}" ]] 26 | then 27 | echo "Option --input-url missing" 28 | exit 1; 29 | elif [[ -z "${START_DATE}" ]] 30 | then 31 | echo "Option --start-date missing" 32 | exit 1; 33 | elif [[ -z "${END_DATE}" ]] 34 | then 35 | echo "Option --end-date missing" 36 | exit 1; 37 | elif [[ -z "${DATASET_NAME}" ]] 38 | then 39 | echo "Option --name missing" 40 | exit 1; 41 | fi 42 | 43 | OUTPUT_DIR="/tmp/$(uuidgen)" 44 | 45 | python get_stats.py \ 46 | $INPUT_URL \ 47 | ${OUTPUT_DIR} \ 48 | --start-date $START_DATE \ 49 | --end-date $END_DATE ${SCRIPT_FLAGS} 50 | 51 | beaker dataset create ${OUTPUT_DIR} \ 52 | --name ${DATASET_NAME} --desc "${DATASET_DESC}" 53 | 54 | rm -rf ${OUTPUT_DIR} 55 | -------------------------------------------------------------------------------- /scripts/data_process/generate_datasets_e3smv2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | 3 | #SBATCH -A m4331 4 | #SBATCH -q regular 5 | #SBATCH -C cpu 6 | #SBATCH --nodes=1 7 | #SBATCH --ntasks-per-node=1 8 | #SBATCH -t 01:00:00 9 | #SBATCH --output=joblogs/%x-%j.out 10 | 11 | while [[ "$#" -gt 0 ]] 12 | do case $1 in 13 | -i|--input-dir) INPUT_DIR="$2" 14 | shift;; 15 | -c|--config) CONFIG="$2" 16 | shift;; 17 | -z|--zarr) ZARR="$2" 18 | shift;; 19 | -o|--output-dir) OUTPUT_DIR="$2" 20 | shift;; 21 | *) echo "Unknown parameter passed: $1" 22 | exit 1;; 23 | esac 24 | shift 25 | done 26 | 27 | if [[ -z "${INPUT_DIR}" ]] 28 | then 29 | echo "Option -i, --input-dir missing" 30 | exit 1; 31 | elif [[ -z "${CONFIG}" ]] 32 | then 33 | echo "Option -c, --config missing" 34 | exit 1; 35 | elif [[
-z "${ZARR}" ]] 36 | then 37 | echo "Option -z, --zarr missing" 38 | exit 1; 39 | elif [[ -z "${OUTPUT_DIR}" ]] 40 | then 41 | echo "Option -o, --output-dir missing" 42 | exit 1; 43 | fi 44 | 45 | # output dir should be somewhere on $SCRATCH, even if the intention is to send 46 | # the data to CFS or some external data store 47 | mkdir -p $OUTPUT_DIR 48 | 49 | # stripe_small is recommended for files of size 1-10 GB on Perlmutter's Lustre 50 | # scratch filesystem and stripes across 8 OSTs 51 | # see https://docs.nersc.gov/performance/io/lustre/#nersc-file-striping-recommendations 52 | stripe_small $OUTPUT_DIR 53 | 54 | # NOTE: assumes you've already created the fv3net conda env. See 55 | # https://github.com/ai2cm/fv3net/blob/8ed295cf0b8ca49e24ae5d6dd00f57e8b30169ac/Makefile#L310 56 | source activate fv3net 57 | 58 | set -xe 59 | 60 | # create the zarr from E3SMv2 .nc files 61 | python -u compute_dataset_e3smv2.py --n-workers=16 --config=${CONFIG} \ 62 | -i ${INPUT_DIR} -o ${ZARR} 63 | 64 | # First year (intended for training) 65 | python -u convert_to_monthly_netcdf.py \ 66 | ${ZARR} \ 67 | ${OUTPUT_DIR}/traindata \ 68 | --start-date 1970-01-01 \ 69 | --end-date 1970-12-31 \ 70 | --nc-format NETCDF4 71 | 72 | # Validation on next 5 months 73 | python -u convert_to_monthly_netcdf.py \ 74 | ${ZARR} \ 75 | ${OUTPUT_DIR}/validdata \ 76 | --start-date 1971-01-01 \ 77 | --end-date 1971-05-31 \ 78 | --nc-format NETCDF4 79 | 80 | # Final 7 months for predictiondata reference 81 | python -u convert_to_monthly_netcdf.py \ 82 | ${ZARR} \ 83 | ${OUTPUT_DIR}/predictiondata \ 84 | --start-date 1971-06-01 \ 85 | --end-date 1971-12-31 \ 86 | --nc-format NETCDF4 87 | 88 | # compute all stats on training data 89 | python -u get_stats.py ${CONFIG} 0 90 | -------------------------------------------------------------------------------- /scripts/data_process/test_combine_stats.py: -------------------------------------------------------------------------------- 1 | import xarray as xr 2 | from combine_stats import get_combined_stats 3 | 4 | 5 | def test_get_combined_stats(): 6 | import numpy as np 7 | 8 | arr1 = xr.DataArray(np.random.rand(10, 10), dims=["x", "y"]) 9 | arr2 = xr.DataArray(np.random.rand(5, 10) * 2 + 5, dims=["x", "y"]) 10 | full_field_datasets = [ 11 | xr.Dataset( 12 | {"arr": arr1.std()}, 13 | attrs={"input_samples": 100}, 14 | ), 15 | xr.Dataset( 16 | {"arr": arr2.std()}, 17 | attrs={"input_samples": 50}, 18 | ), 19 | ] 20 | centering_datasets = [ 21 | xr.Dataset( 22 | {"arr": arr1.mean()}, 23 | attrs={"input_samples": 100}, 24 | ), 25 | xr.Dataset( 26 | {"arr": arr2.mean()}, 27 | attrs={"input_samples": 50}, 28 | ), 29 | ] 30 | samples = xr.DataArray([100, 50], dims=["run"]) 31 | average = get_combined_stats(full_field_datasets, centering_datasets, samples) 32 | combined_arr = np.concatenate([arr1.values.flatten(), arr2.values.flatten()]) 33 | assert np.allclose(average["arr"], np.std(combined_arr)) 34 | -------------------------------------------------------------------------------- /scripts/data_process/test_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import dacite 4 | import pytest 5 | import yaml 6 | from combine_stats import Config as CombineStatsConfig 7 | from get_stats import Config as GetStatsConfig 8 | from upload_stats import Config as UploadStatsConfig 9 | 10 | DIRNAME = os.path.abspath(os.path.dirname(__file__)) 11 | # list files in DIRNAME/configs 12 | CONFIG_YAMLS = [ 13 | os.path.join(DIRNAME + 
"/configs", f) 14 | for f in os.listdir(DIRNAME + "/configs") 15 | if f.endswith(".yaml") 16 | ] 17 | 18 | 19 | @pytest.mark.parametrize( 20 | "filename", 21 | CONFIG_YAMLS, 22 | ) 23 | @pytest.mark.parametrize("cls", [GetStatsConfig, UploadStatsConfig, CombineStatsConfig]) 24 | def test_get_stats_valid(filename, cls): 25 | with open(filename, "r") as f: 26 | config_data = yaml.load(f, Loader=yaml.CLoader) 27 | dacite.from_dict(data_class=cls, data=config_data) 28 | -------------------------------------------------------------------------------- /scripts/data_process/upload_stats.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import shutil 3 | import tempfile 4 | from typing import Dict, List, Optional 5 | 6 | import click 7 | import dacite 8 | import fsspec 9 | import yaml 10 | 11 | 12 | def copy(source: str, destination: str): 13 | """Copy between any two 'filesystems'. Do not use for large files. 14 | 15 | Args: 16 | source: Path to source file/object. 17 | destination: Path to destination. 18 | """ 19 | with fsspec.open(source) as f_source: 20 | with fsspec.open(destination, "wb") as f_destination: 21 | shutil.copyfileobj(f_source, f_destination) 22 | 23 | 24 | @dataclasses.dataclass 25 | class StatsConfig: 26 | output_directory: str 27 | beaker_dataset: str 28 | exclude_runs: List[str] = dataclasses.field(default_factory=list) 29 | start_date: Optional[str] = None 30 | end_date: Optional[str] = None 31 | 32 | 33 | @dataclasses.dataclass 34 | class Config: 35 | runs: Dict[str, str] 36 | data_output_directory: str 37 | stats: StatsConfig 38 | 39 | 40 | @click.command() 41 | @click.argument("config_yaml", type=str) 42 | def main(config_yaml: str): 43 | """ 44 | Upload combined statistics for the data processing pipeline to Beaker. 45 | 46 | Arguments: 47 | config_yaml -- Path to the configuration file for the data processing pipeline. 48 | """ 49 | # imported here so we don't need to install beaker for the tests 50 | from beaker import Beaker 51 | 52 | with open(config_yaml, "r") as f: 53 | config_data = yaml.load(f, Loader=yaml.CLoader) 54 | config = dacite.from_dict(data_class=Config, data=config_data) 55 | 56 | stats_combined_dir = config.stats.output_directory + "/combined/" 57 | beaker = Beaker.from_env() 58 | with tempfile.TemporaryDirectory() as tmpdir: 59 | for filename in ( 60 | "centering.nc", 61 | "scaling-full-field.nc", 62 | "scaling-residual.nc", 63 | "time-mean.nc", 64 | ): 65 | copy(stats_combined_dir + filename, tmpdir + "/" + filename) 66 | runs = [run for run in config.runs if run not in config.stats.exclude_runs] 67 | run_names = ", ".join(runs) 68 | if config.stats.start_date is None: 69 | start = "start of run" 70 | else: 71 | start = config.stats.start_date 72 | if config.stats.end_date is None: 73 | end = "end of run" 74 | else: 75 | end = config.stats.end_date 76 | beaker.dataset.create( 77 | config.stats.beaker_dataset, 78 | tmpdir, 79 | workspace="ai2/ace", 80 | description=( 81 | "Coefficients for normalization for data " 82 | f"{config.data_output_directory} runs {run_names}. " 83 | f"Computed from {start} to {end}."
84 | ), 85 | ) 86 | 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /scripts/era5/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM continuumio/miniconda3:24.1.2-0 2 | 3 | RUN apt-get update && \ 4 | apt-get install -y apt-transport-https ca-certificates gnupg curl 5 | 6 | RUN conda create -y -n env python=3.9 \ 7 | && conda install -y -c conda-forge -n env metview-batch 8 | 9 | RUN conda install -y -n env -c conda-forge metview-python && conda clean -tip 10 | 11 | COPY dataflow-requirements.txt /tmp/requirements.txt 12 | RUN pip install uv \ 13 | && uv pip install --python=/opt/conda/envs/env/bin/python -r /tmp/requirements.txt 14 | 15 | COPY --from=apache/beam_python3.9_sdk:2.54.0 /opt/apache/beam /opt/apache/beam 16 | ENTRYPOINT [ "/opt/apache/beam/boot" ] 17 | 18 | ENV PATH /opt/conda/envs/env/bin:${PATH} 19 | # This is necessary so findlibs can find the eccodes library for metview 20 | # Beam workers do not trigger a default conda environment activation 21 | ENV CONDA_PREFIX /opt/conda/envs/env 22 | 23 | # Without following line, the following error occurs on import of pandas: 24 | # ImportError: /usr/lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.29' not found 25 | RUN cp $CONDA_PREFIX/lib/libstdc++.so.6 /usr/lib/x86_64-linux-gnu/ 26 | 27 | RUN python -m eccodes selfcheck && python -m cfgrib selfcheck && python -m metview selfcheck 28 | -------------------------------------------------------------------------------- /scripts/era5/Makefile: -------------------------------------------------------------------------------- 1 | VERSION ?= 2024-07-10-era5-xarray-beam-pipelines 2 | IMAGE_NAME ?= us.gcr.io/vcm-ml/era5-ingest-dataflow:$(VERSION) 3 | LOCAL_ENVIRONMENT ?= era5-ingestion 4 | 5 | ingest_ncar_variables: 6 | cd ingest_ncar_data && ./main.sh 7 | 8 | create_environment: 9 | conda env create -f environment.yaml 10 | conda run --no-capture-output -n era5-ingestion pip install -r dataflow-requirements.txt 11 | conda run --no-capture-output -n era5-ingestion pip install metview jupyterlab matplotlib cfgrib 12 | 13 | build_dataflow: 14 | docker build -t $(IMAGE_NAME) . 15 | 16 | push_dataflow: build_dataflow 17 | docker push $(IMAGE_NAME) 18 | 19 | enter: 20 | docker run --rm -v $$(pwd):/era5 -w /era5 --entrypoint "/bin/bash" -it $(IMAGE_NAME) 21 | 22 | enter_google: 23 | docker run --rm -v $$(pwd):/era5 -w /era5 --entrypoint "/bin/bash" -it gcr.io/weather-tools-prod/weather-tools:0.0.0 24 | 25 | netcdf_to_zarr_local: 26 | cd netcdf_to_zarr && conda run --no-capture-output -n $(LOCAL_ENVIRONMENT) ./run-netcdf-to-zarr.sh DirectRunner $(IMAGE_NAME) 27 | 28 | netcdf_to_zarr_dataflow: 29 | cd netcdf_to_zarr && conda run --no-capture-output -n $(LOCAL_ENVIRONMENT) ./run-netcdf-to-zarr.sh DataflowRunner $(IMAGE_NAME) 30 | 31 | era5_dataflow: 32 | cd pipeline && conda run --no-capture-output -n $(LOCAL_ENVIRONMENT) ./run-dataflow.sh 33 | -------------------------------------------------------------------------------- /scripts/era5/README.md: -------------------------------------------------------------------------------- 1 | # Scripts for generating a dataset for Full Model Emulation based on ERA5 2 | 3 | ## Downloading some 2D variables from NCAR mirror of ERA5 4 | 5 | Most of this dataset will be generated from Google's version of the ERA5 6 | dataset (see https://github.com/google-research/arco-era5). 
However, that dataset 7 | is missing some variables such as sensible and latent heat flux and some radiative 8 | fluxes. Therefore, we download these variables, as well as some auxiliary variables 9 | such as land-fraction and surface geopotential, from NCAR's hosted version of ERA5. 10 | NCAR has the 0.25° regular lat-lon data, so we download that version. Regridding and 11 | conversion to zarr is left to a future step. 12 | 13 | This download step should only have to be performed once. It can be started via 14 | ``` 15 | make ingest_ncar_variables 16 | ``` 17 | and uses an argo workflow to run on our Google cloud resources. Sometimes the download 18 | will fail for certain variables. If this happens, the workflow can be resubmitted 19 | with the same command as above, and it will pick up where it left off. 20 | 21 | ## Converting netCDF files downloaded from NCAR to zarr 22 | 23 | To facilitate further processing and alignment with the data available from 24 | Google, we use an xarray-beam pipeline to concatenate, merge, and rechunk the 25 | ERA5 data downloaded from NCAR into a set of three zarr stores: 26 | 27 | - `e5.oper.fc.sfc.meanflux` 28 | - `e5.oper.an.sfc` 29 | - `e5.oper.invariant` 30 | 31 | The scaled up version of the beam pipeline is run using Dataflow. It first 32 | requires creating a local Python environment with the needed dependencies 33 | installed: 34 | 35 | ``` 36 | make create_environment 37 | ``` 38 | 39 | To submit the full Dataflow workflow, one can use: 40 | 41 | ``` 42 | make netcdf_to_zarr_dataflow 43 | ``` 44 | 45 | This submits the jobs to create each dataset one at a time, though the process 46 | for creating each dataset is highly parallelized. 47 | 48 | If needed the Docker image required for running the workflow in the cloud can 49 | be rebuilt and pushed using: 50 | 51 | ``` 52 | make build_dataflow push_dataflow 53 | ``` 54 | 55 | ## Computing coarsened ERA5 dataset for FME 56 | 57 | Once the previous steps have been done, all the necessary data should be available 58 | in zarr format on Google Cloud Storage. Now it is possible to compute all necessary 59 | variables on the 1° horizontal resolution and with eight vertical layers. This is 60 | done using an xarray-beam pipeline similar to the previous step. 61 | 62 | First, if not already available, build a docker image using the same instructions 63 | as in previous step. Additionally, create the local "era5-ingestion" conda 64 | environment. 
65 | 66 | Once these steps are done, the workflow can be submitted with 67 | 68 | ``` 69 | make era5_dataflow 70 | ``` 71 | -------------------------------------------------------------------------------- /scripts/era5/dataflow-requirements.txt: -------------------------------------------------------------------------------- 1 | apache_beam[gcp]==2.54.0 2 | h5netcdf 3 | xarray 4 | zarr 5 | dask 6 | scipy 7 | numpy 8 | xarray_beam 9 | cftime 10 | gcsfs -------------------------------------------------------------------------------- /scripts/era5/environment.yaml: -------------------------------------------------------------------------------- 1 | name: era5-ingestion 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3.9 6 | - metview-batch 7 | - pip 8 | -------------------------------------------------------------------------------- /scripts/era5/ingest_ncar_data/argo_workflow.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: argoproj.io/v1alpha1 2 | kind: Workflow 3 | metadata: 4 | generateName: ingest-ncar-era5-data- 5 | spec: 6 | entrypoint: ingest-ncar-era5-data 7 | volumes: 8 | - name: gcp-key-secret 9 | secret: 10 | defaultMode: 420 11 | secretName: gcp-key 12 | arguments: 13 | parameters: 14 | - name: python_script 15 | - name: variables 16 | - name: script_flags 17 | value: "" 18 | templates: 19 | - name: ingest-ncar-era5-data 20 | steps: 21 | - - name: ingest-ncar-era5-single-variable 22 | template: ingest-ncar-era5-single-variable 23 | arguments: 24 | parameters: 25 | - name: python_script 26 | value: "{{workflow.parameters.python_script}}" 27 | - name: category 28 | value: "{{item.category}}" 29 | - name: variable_name 30 | value: "{{item.variable_name}}" 31 | - name: start_time 32 | value: "{{item.start_time}}" 33 | - name: n_files 34 | value: "{{item.n_files}}" 35 | - name: script_flags 36 | value: "{{workflow.parameters.script_flags}}" 37 | withParam: "{{workflow.parameters.variables}}" 38 | - name: ingest-ncar-era5-single-variable 39 | tolerations: 40 | - effect: NoSchedule 41 | key: dedicated 42 | value: med-sim-pool 43 | inputs: 44 | parameters: 45 | - name: python_script 46 | - name: category 47 | - name: variable_name 48 | - name: start_time 49 | - name: n_files 50 | - name: script_flags 51 | container: 52 | image: us.gcr.io/vcm-ml/fv3net:3d1589321e40cddc06bb88c22b44f597646473b2 53 | resources: 54 | limits: 55 | cpu: "8000m" 56 | memory: "27Gi" 57 | requests: 58 | cpu: "7500m" 59 | memory: "27Gi" 60 | command: ["bash", "-c", "-e"] 61 | args: 62 | - | 63 | cat << EOF > script.py 64 | {{inputs.parameters.python_script}} 65 | EOF 66 | 67 | python script.py \ 68 | {{inputs.parameters.category}} \ 69 | {{inputs.parameters.variable_name}} \ 70 | {{inputs.parameters.start_time}} \ 71 | {{inputs.parameters.n_files}} \ 72 | {{inputs.parameters.script_flags}} 73 | env: 74 | - name: GOOGLE_APPLICATION_CREDENTIALS 75 | value: /secret/gcp-credentials/key.json 76 | - name: CLOUDSDK_AUTH_CREDENTIAL_FILE_OVERRIDE 77 | value: /secret/gcp-credentials/key.json 78 | volumeMounts: 79 | - mountPath: /secret/gcp-credentials 80 | name: gcp-key-secret 81 | -------------------------------------------------------------------------------- /scripts/era5/ingest_ncar_data/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | argo submit argo_workflow.yaml \ 4 | -p python_script="$(< ingest_single_variable.py)" \ 5 | -p variables="$(< variables.json)" \ 6 | -p 
script_flags="--gcs-dir gs://vcm-ml-raw-flexible-retention/2024-03-11-era5-025deg-2D-variables-from-NCAR" 7 | -------------------------------------------------------------------------------- /scripts/era5/netcdf_to_zarr/run-netcdf-to-zarr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Options, 4 | # DirectRunner - local 5 | # PortableRunner - container test 6 | # DataflowRunner - cloud 7 | RUNNER=${1} 8 | IMAGE_NAME=${2} 9 | 10 | for category in e5.oper.fc.sfc.meanflux e5.oper.an.sfc e5.oper.invariant 11 | do 12 | python netcdf_to_zarr_pipeline.py \ 13 | gs://vcm-ml-intermediate/2024-05-17-era5-025deg-2D-variables-from-NCAR-as-zarr \ 14 | ${category} \ 15 | --project vcm-ml \ 16 | --region us-central1 \ 17 | --temp_location gs://vcm-ml-scratch/oliverwm/temp/ \ 18 | --experiments use_runner_v2 \ 19 | --runner ${RUNNER} \ 20 | --sdk_location container \ 21 | --sdk_container_image ${IMAGE_NAME} \ 22 | --save_main_session \ 23 | --num_workers 1 \ 24 | --disk_size_gb 35 \ 25 | --machine_type n2d-highmem-2 26 | done 27 | 28 | # save_main_session is needed so that the imported modules are available to individual functions 29 | -------------------------------------------------------------------------------- /scripts/era5/pipeline/run-dataflow-025deg-data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Options, 4 | # DirectRunner - local 5 | # DataflowRunner - cloud 6 | RUNNER=${1:-DataflowRunner} 7 | 8 | # generate ERA5 dataset at 0.25 degree resolution 9 | 10 | python3 xr-beam-pipeline.py \ 11 | gs://vcm-ml-intermediate/2024-05-29-era5-025deg-8layer-2010-2022.zarr \ 12 | 2010-01-01T00:00:00 \ 13 | 2022-12-31T18:00:00 \ 14 | --output_grid F360 \ 15 | --output_time_chunksize 2 \ 16 | --ncar_process_time_chunksize 2 \ 17 | --project vcm-ml \ 18 | --region us-central1 \ 19 | --temp_location gs://vcm-ml-scratch/oliwm/temp/ \ 20 | --experiments use_runner_v2 \ 21 | --runner $RUNNER \ 22 | --sdk_location container \ 23 | --sdk_container_image us.gcr.io/vcm-ml/era5-ingest-dataflow:2024-03-11-era5-xarray-beam-pipelines \ 24 | --save_main_session \ 25 | --num_workers 1 \ 26 | --disk_size_gb 70 \ 27 | --max_num_workers 750 \ 28 | --machine_type n2d-custom-2-24576-ext \ 29 | --worker_disk_type "compute.googleapis.com/projects/vcm-ml/zones/us-central1-c/diskTypes/pd-ssd" \ 30 | --number_of_worker_harness_threads 1 31 | -------------------------------------------------------------------------------- /scripts/era5/pipeline/run-dataflow-16levels.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Options, 4 | # DirectRunner - local 5 | # DataflowRunner - cloud 6 | RUNNER=${1:-DataflowRunner} 7 | #RUNNER=${1:-DirectRunner} 8 | 9 | # Splits each of the original 8 layers into half at linear pressure midpoint 10 | OUTPUT_LAYER_INDICES_16='0 38 48 59 67 74 79 85 90 95 100 104 109 113 119 125 137' 11 | OUTPUT_PATH='gs://vcm-ml-intermediate/2024-07-11-era5-1deg-16layer-1940-2022.zarr' 12 | 13 | python3 xr-beam-pipeline.py \ 14 | $OUTPUT_PATH \ 15 | 1940-01-01T12:00:00 \ 16 | 2022-12-31T18:00:00 \ 17 | --output-layer-indices $OUTPUT_LAYER_INDICES_16 \ 18 | --output_grid F90 \ 19 | --output_time_chunksize 20 \ 20 | --ncar_process_time_chunksize 4 \ 21 | --project vcm-ml \ 22 | --region us-central1 \ 23 | --temp_location gs://vcm-ml-scratch/annak/temp \ 24 | --experiments use_runner_v2 \ 25 | --runner $RUNNER \ 26 | --sdk_location 
container \ 27 | --sdk_container_image us.gcr.io/vcm-ml/era5-ingest-dataflow:2024-07-10-era5-xarray-beam-pipelines \ 28 | --save_main_session \ 29 | --num_workers 1 \ 30 | --disk_size_gb 70 \ 31 | --max_num_workers 750 \ 32 | --machine_type n2d-custom-2-24576-ext \ 33 | --worker_disk_type "compute.googleapis.com/projects/vcm-ml/zones/us-central1-c/diskTypes/pd-ssd" \ 34 | --number_of_worker_harness_threads 1 -------------------------------------------------------------------------------- /scripts/era5/pipeline/run-dataflow.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Options, 4 | # DirectRunner - local 5 | # DataflowRunner - cloud 6 | RUNNER=${1:-DataflowRunner} 7 | 8 | python3 xr-beam-pipeline.py \ 9 | gs://vcm-ml-intermediate/2024-06-20-era5-1deg-8layer-1940-2022.zarr \ 10 | 1940-01-01T12:00:00 \ 11 | 2022-12-31T18:00:00 \ 12 | --output_grid F90 \ 13 | --output_time_chunksize 20 \ 14 | --ncar_process_time_chunksize 4 \ 15 | --project vcm-ml \ 16 | --region us-central1 \ 17 | --temp_location gs://vcm-ml-scratch/oliwm/temp/ \ 18 | --experiments use_runner_v2 \ 19 | --runner $RUNNER \ 20 | --sdk_location container \ 21 | --sdk_container_image us.gcr.io/vcm-ml/era5-ingest-dataflow:2024-03-11-era5-xarray-beam-pipelines \ 22 | --save_main_session \ 23 | --num_workers 1 \ 24 | --disk_size_gb 70 \ 25 | --max_num_workers 750 \ 26 | --machine_type n2d-custom-2-24576-ext \ 27 | --worker_disk_type "compute.googleapis.com/projects/vcm-ml/zones/us-central1-c/diskTypes/pd-ssd" \ 28 | --number_of_worker_harness_threads 1 29 | -------------------------------------------------------------------------------- /scripts/manual_backwards_compatibility/.gitignore: -------------------------------------------------------------------------------- 1 | test_inference_ace2_era5 -------------------------------------------------------------------------------- /scripts/manual_backwards_compatibility/ace2-era5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | # This script can be used to ensure that your currently installed version of the fme 6 | # package can do inference with the published ACE2-ERA5 model. 
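# Note: the steps below assume `wget` and `yq` (v4, which provides the `yq e ... -i`
# syntax used for the config edits) are available on PATH, and that the `fme` package
# is installed in the active Python environment so `python -m fme.ace.inference` runs.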
7 | 8 | # download necessary data 9 | mkdir -p test_inference_ace2_era5 10 | cd test_inference_ace2_era5 11 | mkdir -p initial_conditions 12 | mkdir -p forcing_data 13 | wget https://huggingface.co/allenai/ACE2-ERA5/resolve/main/ace2_era5_ckpt.tar?download=true -O ace2_era5_ckpt.tar 14 | wget https://huggingface.co/allenai/ACE2-ERA5/resolve/main/inference_config.yaml?download=true -O inference_config.yaml 15 | wget https://huggingface.co/allenai/ACE2-ERA5/resolve/main/initial_conditions/ic_2020.nc?download=true -O initial_conditions/ic_2020.nc 16 | wget https://huggingface.co/allenai/ACE2-ERA5/resolve/main/forcing_data/forcing_2020.nc?download=true -O forcing_data/forcing_2020.nc 17 | 18 | # update config to use relative paths and do a short run 19 | yq e '.n_forward_steps = 50' -i inference_config.yaml 20 | yq e '.forward_steps_in_memory = 5' -i inference_config.yaml 21 | yq e '.checkpoint_path = "ace2_era5_ckpt.tar"' -i inference_config.yaml 22 | yq e '.initial_condition.path = "initial_conditions/ic_2020.nc"' -i inference_config.yaml 23 | yq e '.forcing_loader.dataset.data_path = "forcing_data/"' -i inference_config.yaml 24 | 25 | # run on CPU or CUDA if the latter is available 26 | yq e '.experiment_dir = "output_cpu"' -i inference_config.yaml 27 | python -m fme.ace.inference inference_config.yaml 28 | 29 | # run on MPS. NOTE: this requires torch==2.5 otherwise there are complaints about some of the 30 | # features used by the SFNO architecture. 31 | yq e '.experiment_dir = "output_mps"' -i inference_config.yaml 32 | export FME_USE_MPS=1 33 | python -m fme.ace.inference inference_config.yaml 34 | -------------------------------------------------------------------------------- /scripts/monthly_data/config.yaml: -------------------------------------------------------------------------------- 1 | 2 | experiment_dir: /workdir/jeremym/monthly_data 3 | data_loader: 4 | dataset: 5 | - data_path: /workdir/shared/2023-12-20-vertically-resolved-4deg-fme-amip-ensemble-dataset/train 6 | batch_size: 1 7 | num_data_workers: 1 8 | logging: 9 | log_to_screen: true 10 | log_to_wandb: true 11 | log_to_file: true 12 | project: fourcastnet 13 | entity: ai2cm 14 | forward_steps_in_memory: 73 15 | variable_names: 16 | - DSWRFtoa 17 | - land_fraction 18 | - ocean_fraction 19 | - sea_ice_fraction 20 | - PRESsfc 21 | - surface_temperature 22 | - air_temperature_0 23 | - air_temperature_1 24 | - air_temperature_2 25 | - air_temperature_3 26 | - air_temperature_4 27 | - air_temperature_5 28 | - air_temperature_6 29 | - air_temperature_7 30 | - specific_total_water_0 31 | - specific_total_water_1 32 | - specific_total_water_2 33 | - specific_total_water_3 34 | - specific_total_water_4 35 | - specific_total_water_5 36 | - specific_total_water_6 37 | - specific_total_water_7 38 | - eastward_wind_0 39 | - eastward_wind_1 40 | - eastward_wind_2 41 | - eastward_wind_3 42 | - eastward_wind_4 43 | - eastward_wind_5 44 | - eastward_wind_6 45 | - eastward_wind_7 46 | - northward_wind_0 47 | - northward_wind_1 48 | - northward_wind_2 49 | - northward_wind_3 50 | - northward_wind_4 51 | - northward_wind_5 52 | - northward_wind_6 53 | - northward_wind_7 54 | - LHTFLsfc 55 | - SHTFLsfc 56 | - PRATEsfc 57 | - ULWRFsfc 58 | - ULWRFtoa 59 | - DLWRFsfc 60 | - DSWRFsfc 61 | - USWRFsfc 62 | - USWRFtoa 63 | - tendency_of_total_water_path_due_to_advection -------------------------------------------------------------------------------- /scripts/monthly_data/test_write_monthly_data.py: 
-------------------------------------------------------------------------------- 1 | import pathlib 2 | from typing import List 3 | 4 | import pytest 5 | import xarray as xr 6 | from write_monthly_data import Config, run 7 | 8 | from fme.ace.data_loading.config import DataLoaderConfig 9 | from fme.ace.testing import DimSize, DimSizes 10 | from fme.ace.testing.fv3gfs_data import save_nd_netcdf 11 | from fme.core.dataset.config import XarrayDataConfig 12 | from fme.core.logging_utils import LoggingConfig 13 | 14 | 15 | def write_ensemble_dataset( 16 | path: pathlib.Path, n_members: int, names: List[str], dim_sizes: DimSizes 17 | ): 18 | if not path.exists(): 19 | path.mkdir(parents=True) 20 | for i in range(n_members): 21 | ensemble_dir = path / f"ic_{i:04d}" 22 | ensemble_dir.mkdir(exist_ok=True) 23 | save_nd_netcdf( 24 | ensemble_dir / "data.nc", 25 | dim_sizes, 26 | names, 27 | timestep_days=5, 28 | ) 29 | 30 | 31 | def test_write_monthly_data(very_fast_only: bool, tmp_path: pathlib.Path): 32 | if very_fast_only: 33 | pytest.skip("Skipping non-fast tests") 34 | all_names = ["a", "b"] 35 | horizontal = [DimSize("grid_yt", 8), DimSize("grid_xt", 4)] 36 | dim_sizes = DimSizes( 37 | n_time=4 * 60, 38 | horizontal=horizontal, 39 | nz_interface=2, 40 | ) 41 | n_members = 3 42 | write_ensemble_dataset(tmp_path / "data", n_members, all_names, dim_sizes) 43 | dataset = [ 44 | XarrayDataConfig(data_path=str(tmp_path / "data" / f"ic_{i:04}")) 45 | for i in range(n_members) 46 | ] 47 | config = Config( 48 | experiment_dir=str(tmp_path), 49 | data_loader=DataLoaderConfig( 50 | dataset=dataset, 51 | batch_size=1, 52 | num_data_workers=0, 53 | ), 54 | logging=LoggingConfig( 55 | log_to_screen=True, 56 | log_to_file=False, 57 | log_to_wandb=False, 58 | ), 59 | variable_names=all_names, 60 | ) 61 | run(config) 62 | xr.open_dataset(tmp_path / "monthly_mean_data.nc") 63 | --------------------------------------------------------------------------------