├── docs ├── _templates │ ├── .gitkeep │ └── apidoc │ │ └── package.rst.jinja ├── inference │ ├── configs │ │ ├── yaml │ │ │ ├── inputs_1.yaml │ │ │ ├── inputs_6.yaml │ │ │ ├── introduction_3.yaml │ │ │ ├── outputs_1.yaml │ │ │ ├── forcings_2.yaml │ │ │ ├── forcings_4.yaml │ │ │ ├── forcings_5.yaml │ │ │ ├── forcings_1.yaml │ │ │ ├── inputs_7.yaml │ │ │ ├── inputs_4.yaml │ │ │ ├── inputs_8.yaml │ │ │ ├── outputs_5.yaml │ │ │ ├── outputs_10.yaml │ │ │ ├── outputs_2.yaml │ │ │ ├── outputs_3.yaml │ │ │ ├── inputs_2.yaml │ │ │ ├── introduction_4.yaml │ │ │ ├── forcings_3.yaml │ │ │ ├── introduction_2.yaml │ │ │ ├── introduction_1.yaml │ │ │ ├── introduction_5.yaml │ │ │ ├── outputs_7.yaml │ │ │ ├── outputs_8a.yaml │ │ │ ├── outputs_6.yaml │ │ │ ├── inputs_11.yaml │ │ │ ├── outputs_assign.yaml │ │ │ ├── outputs_4.yaml │ │ │ ├── inputs_3.yaml │ │ │ ├── outputs_9.yaml │ │ │ ├── inputs_10.yaml │ │ │ ├── grib-output_1.yaml │ │ │ ├── outputs_8b.yaml │ │ │ ├── grib-input_1.yaml │ │ │ └── inputs_9.yaml │ │ ├── grib-input.rst │ │ ├── forcings.rst │ │ └── introduction.rst │ ├── apis │ │ ├── code │ │ │ ├── level3_1.sh │ │ │ ├── level3_1.yaml │ │ │ ├── level3_2.sh │ │ │ ├── level3_3.sh │ │ │ ├── level1_2_.py │ │ │ ├── level3_2.yaml │ │ │ ├── level3_4.sh │ │ │ ├── level3_3.yaml │ │ │ ├── level3_4.yaml │ │ │ ├── level2.py │ │ │ └── level1_1.py │ │ ├── introduction.rst │ │ ├── level2.rst │ │ ├── level1.rst │ │ └── level3.rst │ └── input-types.rst ├── _static │ ├── logo.png │ ├── schemas │ │ ├── overview.png │ │ └── run-config.rst │ └── style.css ├── pptx │ └── images.pptx ├── usage │ ├── yaml │ │ ├── external-graph1.yaml │ │ ├── external-graph5.yaml │ │ ├── external-graph3.yaml │ │ ├── external-graph2.yaml │ │ └── external-graph4.yaml │ └── getting-started.rst ├── modules │ ├── forcings.rst │ ├── metadata.rst │ ├── checkpoint.rst │ ├── inputs.rst │ ├── runner.rst │ ├── outputs.rst │ └── processor.rst ├── cli │ ├── run.rst │ ├── patch.rst │ ├── inspect.rst │ ├── metadata.rst │ ├── 
requests.rst │ ├── validate.rst │ └── introduction.rst ├── scripts │ └── api_build.sh ├── dev │ ├── integration-example.yaml │ └── contributing.rst ├── installing.rst ├── overview.rst └── Makefile ├── .gitattributes ├── tests ├── unit │ ├── .gitignore │ ├── configs │ │ ├── .gitignore │ │ ├── simple.yaml │ │ ├── interpolation.yaml │ │ ├── atmos.yaml │ │ ├── ocean.yaml │ │ ├── mwd.yaml │ │ └── coupled.yaml │ ├── checkpoints │ │ ├── .gitignore │ │ └── simple.yaml │ ├── __init__.py │ ├── test_decorators.py │ ├── fake_metadata.py │ ├── test_checkpoint.py │ ├── inputs │ │ ├── test_inference.py │ │ └── test_request_patching.py │ ├── test_metadata.py │ ├── test_coupling.py │ └── test_config.py ├── integration │ ├── conftest.py │ ├── rmi-lam │ │ └── config.yaml │ └── meteoswiss-sgm-cosmo │ │ └── config.yaml └── conftest.py ├── .release-please-manifest.json ├── .github ├── CODEOWNERS ├── dependabot.yml ├── workflows │ ├── pr-label-public.yml │ ├── python-publish.yml │ ├── pr-label-file-based.yml │ ├── python-integration.yml │ ├── pr-conventional-commit.yml │ ├── readthedocs-pr-update.yml │ ├── python-pull-request.yml │ ├── release-please.yml │ ├── pr-label-conventional-commits.yml │ ├── downstream-ci-hpc.yml │ └── pr-label-ats.yml ├── ci-hpc-config.yml ├── pull_request_template.md └── labeler.yml ├── .readthedocs.yaml ├── src └── anemoi │ └── inference │ ├── grib │ ├── __init__.py │ └── templates │ │ ├── samples.py │ │ ├── builtin.py │ │ ├── input.py │ │ └── file.py │ ├── commands │ ├── metadata.py │ ├── __init__.py │ ├── sanitise.py │ ├── run.py │ ├── validate.py │ ├── requests.py │ └── couple.py │ ├── protocol.py │ ├── __init__.py │ ├── device.py │ ├── outputs │ ├── none.py │ ├── __init__.py │ ├── truth.py │ ├── raw.py │ └── gribmemory.py │ ├── testing │ └── variables.py │ ├── __main__.py │ ├── tasks │ └── __init__.py │ ├── runners │ ├── __init__.py │ ├── testing.py │ └── plugin.py │ ├── pre_processors │ ├── __init__.py │ ├── no_missing_values.py │ └── 
forward_transform_filter.py │ ├── post_processors │ ├── __init__.py │ ├── backward_transform_filter.py │ ├── accumulate.py │ └── assign.py │ ├── task.py │ ├── inputs │ ├── __init__.py │ ├── grib.py │ ├── empty.py │ ├── fdb.py │ └── repeated_dates.py │ ├── transports │ ├── __init__.py │ └── mpi.py │ ├── clusters │ ├── distributed.py │ ├── __init__.py │ ├── spawner.py │ └── mpi.py │ ├── lazy.py │ ├── types.py │ ├── config │ └── couple.py │ ├── utils │ └── templating.py │ ├── processor.py │ ├── precisions.py │ ├── patch.py │ ├── profiler.py │ └── checks.py ├── CONTRIBUTORS.md ├── .release-please-config.json ├── README.md ├── .gitignore ├── .pre-commit-config.yaml └── pyproject.toml /docs/_templates/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | CHANGELOG.md merge=union 2 | -------------------------------------------------------------------------------- /tests/unit/.gitignore: -------------------------------------------------------------------------------- 1 | !checkpoints/ 2 | -------------------------------------------------------------------------------- /tests/unit/configs/.gitignore: -------------------------------------------------------------------------------- 1 | !*.yaml 2 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/inputs_1.yaml: -------------------------------------------------------------------------------- 1 | input: test 2 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/inputs_6.yaml: -------------------------------------------------------------------------------- 1 | input: mars 2 | -------------------------------------------------------------------------------- /tests/unit/checkpoints/.gitignore: 
-------------------------------------------------------------------------------- 1 | !*.yaml 2 | !*.json 3 | -------------------------------------------------------------------------------- /.release-please-manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | ".": "0.8.3" 3 | } 4 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/introduction_3.yaml: -------------------------------------------------------------------------------- 1 | input: mars 2 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/outputs_1.yaml: -------------------------------------------------------------------------------- 1 | output: printer 2 | -------------------------------------------------------------------------------- /docs/inference/apis/code/level3_1.sh: -------------------------------------------------------------------------------- 1 | anemoi-inference run aifs.yaml 2 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/forcings_2.yaml: -------------------------------------------------------------------------------- 1 | forcing: 2 | input: mars 3 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/forcings_4.yaml: -------------------------------------------------------------------------------- 1 | forcing: 2 | dynamic: mars 3 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/forcings_5.yaml: -------------------------------------------------------------------------------- 1 | forcing: 2 | boundary: mars 3 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/forcings_1.yaml: -------------------------------------------------------------------------------- 1 | input: test 2 | 3 | 
forcings: mars 4 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/inputs_7.yaml: -------------------------------------------------------------------------------- 1 | input: 2 | mars: 3 | class: ea 4 | -------------------------------------------------------------------------------- /docs/inference/apis/code/level3_1.yaml: -------------------------------------------------------------------------------- 1 | checkpoint: inference-aifs-0.2.1-anemoi.ckpt 2 | -------------------------------------------------------------------------------- /docs/inference/apis/code/level3_2.sh: -------------------------------------------------------------------------------- 1 | anemoi-inference run aifs.yaml date=2020-01-01 2 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/inputs_4.yaml: -------------------------------------------------------------------------------- 1 | input: 2 | grib: /path/to/grib/file.grib 3 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/inputs_8.yaml: -------------------------------------------------------------------------------- 1 | input: 2 | cds: 3 | dataset: ??? 
4 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/outputs_5.yaml: -------------------------------------------------------------------------------- 1 | output: 2 | raw: /path/to/directory 3 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/outputs_10.yaml: -------------------------------------------------------------------------------- 1 | output: 2 | zarr: /path/to/zarr/store.zarr 3 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/outputs_2.yaml: -------------------------------------------------------------------------------- 1 | output: 2 | grib: /path/to/grib/file.grib 3 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/outputs_3.yaml: -------------------------------------------------------------------------------- 1 | output: 2 | netcdf: /path/to/netcdf/file.nc 3 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/inputs_2.yaml: -------------------------------------------------------------------------------- 1 | input: 2 | test: 3 | use_original_paths: false 4 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/introduction_4.yaml: -------------------------------------------------------------------------------- 1 | input: 2 | grib: /path/to/grib/file.grib 3 | -------------------------------------------------------------------------------- /docs/_static/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecmwf/anemoi-inference/HEAD/docs/_static/logo.png -------------------------------------------------------------------------------- /docs/inference/configs/yaml/forcings_3.yaml: 
-------------------------------------------------------------------------------- 1 | forcing: 2 | constant: 3 | grib: constant.grib 4 | -------------------------------------------------------------------------------- /docs/pptx/images.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecmwf/anemoi-inference/HEAD/docs/pptx/images.pptx -------------------------------------------------------------------------------- /docs/usage/yaml/external-graph1.yaml: -------------------------------------------------------------------------------- 1 | runner: 2 | external_graph: 3 | graph: path/to/graph.pt 4 | -------------------------------------------------------------------------------- /docs/inference/apis/code/level3_3.sh: -------------------------------------------------------------------------------- 1 | anemoi-inference run checkpoint=mycheckpoint.ckpt date=2020-01-01 2 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/introduction_2.yaml: -------------------------------------------------------------------------------- 1 | input: 2 | grib: 3 | path: /path/to/grib/file.grib 4 | -------------------------------------------------------------------------------- /tests/unit/configs/simple.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | checkpoint: unit/checkpoints/simple.ckpt 3 | date: 2020-01-01 4 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/introduction_1.yaml: -------------------------------------------------------------------------------- 1 | input: 2 | mars: 3 | class: ea 4 | stream: oper 5 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/introduction_5.yaml: -------------------------------------------------------------------------------- 1 | input: 2 | grib: 
/path/to/grib/file.grib 3 | 4 | output: printer 5 | -------------------------------------------------------------------------------- /docs/_static/schemas/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecmwf/anemoi-inference/HEAD/docs/_static/schemas/overview.png -------------------------------------------------------------------------------- /docs/inference/apis/code/level1_2_.py: -------------------------------------------------------------------------------- 1 | latitudes = runner.checkpoint.latitudes 2 | longitudes = runner.checkpoint.longitudes 3 | -------------------------------------------------------------------------------- /docs/inference/apis/code/level3_2.yaml: -------------------------------------------------------------------------------- 1 | checkpoint: inference-aifs-0.2.1-anemoi.ckpt 2 | dataset: true 3 | date: 2021-09-01 4 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/outputs_7.yaml: -------------------------------------------------------------------------------- 1 | output: 2 | apply_mask: 3 | mask: thinning 4 | output: /path/to/output.nc 5 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/outputs_8a.yaml: -------------------------------------------------------------------------------- 1 | output: 2 | extract_lam: 3 | output: 4 | netcdf: 5 | path: lam.nc 6 | -------------------------------------------------------------------------------- /docs/usage/yaml/external-graph5.yaml: -------------------------------------------------------------------------------- 1 | runner: 2 | external_graph: 3 | graph: path/to/graph.pt 4 | check_state_dict: False 5 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/outputs_6.yaml: 
-------------------------------------------------------------------------------- 1 | output: 2 | tee: 3 | - netcdf: /path/to/netcdf/file.nc 4 | - grib: /path/to/grib/file.grib 5 | -------------------------------------------------------------------------------- /tests/integration/conftest.py: -------------------------------------------------------------------------------- 1 | def pytest_configure(config): 2 | config.option.log_cli = False 3 | config.option.log_cli_level = "INFO" 4 | -------------------------------------------------------------------------------- /tests/unit/configs/interpolation.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | checkpoint: unit/checkpoints/simple.ckpt 3 | date: 2020-01-01 4 | 5 | 6 | name: ${oc.env:TEST} 7 | description: ${name} 8 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/inputs_11.yaml: -------------------------------------------------------------------------------- 1 | input: 2 | cutout: 3 | - lam_0: 4 | grib: 5 | - global: 6 | grib: 7 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/outputs_assign.yaml: -------------------------------------------------------------------------------- 1 | output: 2 | extract_lam: 3 | output: 4 | assign_mask: 5 | mask: "source0/trimedge_mask" 6 | output: printer 7 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/outputs_4.yaml: -------------------------------------------------------------------------------- 1 | output: 2 | plot: 3 | path: /path/to/directory 4 | format: png 5 | domain: Europe 6 | variables: 7 | - 2t 8 | - msl 9 | -------------------------------------------------------------------------------- /docs/modules/forcings.rst: -------------------------------------------------------------------------------- 1 | ########## 2 | forcings 3 | ########## 4 
| 5 | .. automodule:: anemoi.inference.forcings 6 | :members: 7 | :no-undoc-members: 8 | :show-inheritance: 9 | -------------------------------------------------------------------------------- /docs/modules/metadata.rst: -------------------------------------------------------------------------------- 1 | ########## 2 | metadata 3 | ########## 4 | 5 | .. automodule:: anemoi.inference.metadata 6 | :members: 7 | :no-undoc-members: 8 | :show-inheritance: 9 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/inputs_3.yaml: -------------------------------------------------------------------------------- 1 | input: 2 | dataset: 3 | join: 4 | - dataset: dataset-1 5 | select: [ 2t, tp ] 6 | - dataset: dataset-2 7 | drop: [ 2t, tp ] 8 | -------------------------------------------------------------------------------- /docs/modules/checkpoint.rst: -------------------------------------------------------------------------------- 1 | ############ 2 | checkpoint 3 | ############ 4 | 5 | .. 
automodule:: anemoi.inference.checkpoint 6 | :members: 7 | :no-undoc-members: 8 | :show-inheritance: 9 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/outputs_9.yaml: -------------------------------------------------------------------------------- 1 | output: 2 | tee: 3 | - truth: 4 | output: 5 | grib: 6 | path: truth.grib 7 | - grib: 8 | path: prediction.grib 9 | -------------------------------------------------------------------------------- /docs/inference/apis/code/level3_4.sh: -------------------------------------------------------------------------------- 1 | anemoi-inference run lam.yaml \ 2 | "input.dataset.cutout.0.dataset=./analysis_20240131_00.zarr" \ 3 | "input.dataset.cutout.1.dataset=./lbc_20240131_00.zarr" 4 | -------------------------------------------------------------------------------- /docs/inference/apis/code/level3_3.yaml: -------------------------------------------------------------------------------- 1 | checkpoint: inference-aifs-0.2.1-anemoi.ckpt 2 | 3 | date: 2024-01-31T00:00:00 4 | 5 | input: 6 | dataset: 7 | cutout: 8 | - dataset: ./analysis_20240101_00.zarr 9 |
-------------------------------------------------------------------------------- /docs/cli/run.rst: -------------------------------------------------------------------------------- 1 | .. _run-command: 2 | 3 | Run Command 4 | =========== 5 | 6 | .. argparse:: 7 | :module: anemoi.inference.__main__ 8 | :func: create_parser 9 | :prog: anemoi-inference 10 | :path: run 11 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/inputs_10.yaml: -------------------------------------------------------------------------------- 1 | input: 2 | cds: 3 | dataset: 4 | levtype: 5 | pl: reanalysis-era5-pressure-levels 6 | sfc: reanalysis-era5-single-levels 7 | product_type: 'reanalysis' 8 | -------------------------------------------------------------------------------- /docs/cli/patch.rst: -------------------------------------------------------------------------------- 1 | .. _patch-command: 2 | 3 | Patch Command 4 | ============= 5 | 6 | .. argparse:: 7 | :module: anemoi.inference.__main__ 8 | :func: create_parser 9 | :prog: anemoi-inference 10 | :path: patch 11 | -------------------------------------------------------------------------------- /docs/modules/inputs.rst: -------------------------------------------------------------------------------- 1 | ######## 2 | inputs 3 | ######## 4 | 5 | .. automodule:: anemoi.inference.input 6 | :members: 7 | :no-undoc-members: 8 | :show-inheritance: 9 | 10 | .. include:: ../_api/inference.inputs.rst 11 | -------------------------------------------------------------------------------- /docs/modules/runner.rst: -------------------------------------------------------------------------------- 1 | ######## 2 | runner 3 | ######## 4 | 5 | .. automodule:: anemoi.inference.runner 6 | :members: 7 | :no-undoc-members: 8 | :show-inheritance: 9 | 10 | .. 
include:: ../_api/inference.runners.rst 11 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Workflows 2 | /.github/ @ecmwf/AnemoiSecurity 3 | 4 | # Project configs 5 | /pyproject.toml @ecmwf/AnemoiSecurity 6 | /.pre-commit-config.yaml @ecmwf/AnemoiSecurity 7 | /.release-please-config.json @ecmwf/AnemoiSecurity 8 | -------------------------------------------------------------------------------- /docs/modules/outputs.rst: -------------------------------------------------------------------------------- 1 | ######### 2 | outputs 3 | ######### 4 | 5 | .. automodule:: anemoi.inference.output 6 | :members: 7 | :no-undoc-members: 8 | :show-inheritance: 9 | 10 | .. include:: ../_api/inference.outputs.rst 11 | -------------------------------------------------------------------------------- /docs/cli/inspect.rst: -------------------------------------------------------------------------------- 1 | .. _inspect-command: 2 | 3 | Inspect Command 4 | =============== 5 | 6 | .. argparse:: 7 | :module: anemoi.inference.__main__ 8 | :func: create_parser 9 | :prog: anemoi-inference 10 | :path: inspect 11 | -------------------------------------------------------------------------------- /docs/scripts/api_build.sh: -------------------------------------------------------------------------------- 1 | 2 | script_dir=$(dirname "${BASH_SOURCE[0]}") 3 | docs_dir="$script_dir/.." 4 | source_dir="$script_dir/../../src/" 5 | 6 | sphinx-apidoc -M -f -o "$docs_dir/_api" "$source_dir/anemoi" -t "$docs_dir/_templates/apidoc" 7 | -------------------------------------------------------------------------------- /docs/cli/metadata.rst: -------------------------------------------------------------------------------- 1 | .. _metadata-command: 2 | 3 | Metadata Command 4 | ================ 5 | 6 | .. 
argparse:: 7 | :module: anemoi.inference.__main__ 8 | :func: create_parser 9 | :prog: anemoi-inference 10 | :path: metadata 11 | -------------------------------------------------------------------------------- /docs/cli/requests.rst: -------------------------------------------------------------------------------- 1 | .. _requests-command: 2 | 3 | Requests Command 4 | ================ 5 | 6 | .. argparse:: 7 | :module: anemoi.inference.__main__ 8 | :func: create_parser 9 | :prog: anemoi-inference 10 | :path: requests 11 | -------------------------------------------------------------------------------- /docs/_static/schemas/run-config.rst: -------------------------------------------------------------------------------- 1 | .. https://autodoc-pydantic.readthedocs.io 2 | 3 | .. _run-config: 4 | 5 | ################## 6 | Run configuration 7 | ################## 8 | 9 | .. autopydantic_model:: anemoi.inference.config.Configuration 10 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/grib-output_1.yaml: -------------------------------------------------------------------------------- 1 | output: 2 | grib: 3 | path: /path/to/grib/file.grib 4 | encoding: 5 | expver: xxxx 6 | class: rd 7 | check_encoding: true 8 | templates: 9 | - input 10 | - builtin 11 | -------------------------------------------------------------------------------- /docs/usage/yaml/external-graph3.yaml: -------------------------------------------------------------------------------- 1 | runner: 2 | external_graph: 3 | graph: path/to/graph.pt 4 | graph_dataset: path/to/graph_dataset.zarr 5 | # the above can be an anemoi-datasets.open_dataset argument as well, 6 | # rather than simply a path 7 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | 
directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | # - package-ecosystem: "github-actions" 8 | # directory: "/" 9 | # schedule: 10 | # interval: "monthly" 11 | -------------------------------------------------------------------------------- /tests/unit/configs/ocean.yaml: -------------------------------------------------------------------------------- 1 | checkpoint: unit/checkpoints/ocean.ckpt 2 | runner: testing 3 | device: cpu 4 | 5 | allow_nans: true 6 | use_grib_paramid: true 7 | 8 | lead_time: 480 9 | date: 2022-01-02 10 | 11 | input: dummy 12 | 13 | write_initial_state: false 14 | -------------------------------------------------------------------------------- /docs/usage/yaml/external-graph2.yaml: -------------------------------------------------------------------------------- 1 | runner: 2 | external_graph: 3 | graph: path/to/graph.pt 4 | output_mask: 5 | nodes_name: data # name of the output nodes of the graph 6 | attribute_name: cutout_mask # mask specifying the limited area among the output nodes 7 | -------------------------------------------------------------------------------- /.github/workflows/pr-label-public.yml: -------------------------------------------------------------------------------- 1 | # Manage labels of pull requests that originate from forks 2 | name: "[PR] Label Forks" 3 | 4 | on: 5 | pull_request_target: 6 | types: [opened, synchronize] 7 | 8 | jobs: 9 | label: 10 | uses: ecmwf/reusable-workflows/.github/workflows/label-pr.yml@v2 11 | -------------------------------------------------------------------------------- /docs/modules/processor.rst: -------------------------------------------------------------------------------- 1 | ########### 2 | processor 3 | ########### 4 | 5 | .. automodule:: anemoi.inference.processor 6 | :members: 7 | :no-undoc-members: 8 | :show-inheritance: 9 | 10 | .. include:: ../_api/inference.pre_processors.rst 11 | 12 | .. 
include:: ../_api/inference.post_processors.rst 13 | -------------------------------------------------------------------------------- /tests/unit/configs/mwd.yaml: -------------------------------------------------------------------------------- 1 | checkpoint: unit/checkpoints/mwd.ckpt 2 | 3 | date: 2022-09-01 4 | lead_time: 48 5 | 6 | pre_processors: 7 | - forward_transform_filter: cos_sin_mean_wave_direction 8 | - no_missing_values 9 | 10 | post_processors: 11 | - backward_transform_filter: cos_sin_mean_wave_direction 12 | 13 | 14 | use_grib_paramid: true 15 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.11" 7 | jobs: 8 | pre_build: 9 | - bash docs/scripts/api_build.sh 10 | 11 | sphinx: 12 | configuration: docs/conf.py 13 | 14 | python: 15 | install: 16 | - method: pip 17 | path: . 
18 | extra_requirements: 19 | - docs 20 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/outputs_8b.yaml: -------------------------------------------------------------------------------- 1 | output: 2 | extract_lam: 3 | lam: '0/lam_0' #selects the cutout mask of the first dataset used in join 4 | output: 5 | tee: 6 | - netcdf: 7 | path: lam.nc 8 | - plot: 9 | path: lam-plots 10 | variables: 11 | - 2t 12 | - z_500 13 | - t_850 14 | -------------------------------------------------------------------------------- /docs/usage/yaml/external-graph4.yaml: -------------------------------------------------------------------------------- 1 | runner: 2 | external_graph: 3 | graph: path/to/graph.pt 4 | update_supporting_arrays: 5 | graph: # Get data from the graph 6 | 'global/cutout_mask': 'cutout_mask' # Change `global/cutout_mask` in the arrays to `cutout_mask` 7 | file: 8 | 'global/cutout_mask': 'cutout_mask.npy' # Change `global/cutout_mask` in the arrays to `cutout_mask.npy` 9 | -------------------------------------------------------------------------------- /docs/usage/getting-started.rst: -------------------------------------------------------------------------------- 1 | .. _usage-getting-started: 2 | 3 | ################################ 4 | Generating your first forecast 5 | ################################ 6 | 7 | For an example of how to quickly use anemoi to make predictions with a 8 | trained model, please visit the `hugging-face example for AIFS Single 1 9 | `_. 10 | -------------------------------------------------------------------------------- /docs/inference/configs/grib-input.rst: -------------------------------------------------------------------------------- 1 | .. _grib-input: 2 | 3 | ############ 4 | GRIB input 5 | ############ 6 | 7 | .. note:: 8 | 9 | This is placeholder documentation. This will explain how the
11 | 12 | Use the ``namer`` parameter to specify how to rename fields to match the 13 | names in the checkpoints. 14 | 15 | .. literalinclude:: yaml/grib-input_1.yaml 16 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025- Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | -------------------------------------------------------------------------------- /tests/unit/configs/coupled.yaml: -------------------------------------------------------------------------------- 1 | transport: processes 2 | 3 | couplings: 4 | - atmos -> ocean: 5 | - lsm 6 | - 10u 7 | - 10v 8 | - 2t 9 | - 2d 10 | - ssrd 11 | - strd 12 | - tp 13 | - msl 14 | 15 | - ocean -> atmos: 16 | - avg_tos 17 | - avg_siconc 18 | 19 | tasks: 20 | atmos: 21 | runner: 22 | config: configs/atmos.yaml 23 | 24 | ocean: 25 | runner: 26 | config: configs/ocean.yaml 27 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/grib-input_1.yaml: -------------------------------------------------------------------------------- 1 | grib: 2 | path: /path/to/grib/file.grib 3 | namer: 4 | rules: 5 | - - shortName: T_SO 6 | - '{shortName}_{level}' 7 | - - shortName: SMI 8 | - '{shortName}_{level}' 9 | - - { shortName: PS, dataType: fc } 10 | - ignore 11 | - - { shortName: Z0, dataType: fc } 12 | - ignore 13 | - - { shortName: FR_LAND, dataType: fc } 14 | - ignore 15 | -------------------------------------------------------------------------------- 
/src/anemoi/inference/grib/__init__.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | -------------------------------------------------------------------------------- /.github/ci-hpc-config.yml: -------------------------------------------------------------------------------- 1 | build: 2 | modules: 3 | - ninja 4 | dependencies: 5 | - ecmwf/ecbuild@develop 6 | - ecmwf/eccodes@develop 7 | - ecmwf/eckit@develop 8 | - ecmwf/odc@develop 9 | python_dependencies: 10 | - ecmwf/anemoi-utils@develop 11 | - ecmwf/earthkit-data@develop 12 | parallel: 64 13 | 14 | pytest_cmd: | 15 | python -m pytest -vv -m 'not notebook and not no_cache_init' --cov=. --cov-report=xml 16 | -------------------------------------------------------------------------------- /docs/inference/apis/introduction.rst: -------------------------------------------------------------------------------- 1 | .. _api_introduction: 2 | 3 | ###### 4 | APIs 5 | ###### 6 | 7 | anemoi-inference supports three different APIs to generate a forecast: 8 | 9 | - NumPy to NumPy API (see :ref:`api_level1`) 10 | - Object oriented API (see :ref:`api_level2`) 11 | - Command line API (see :ref:`api_level3`) 12 | 13 | ..
toctree:: 14 | :maxdepth: 1 15 | :hidden: 16 | :caption: APIs 17 | 18 | level1 19 | level2 20 | level3 21 | -------------------------------------------------------------------------------- /docs/dev/integration-example.yaml: -------------------------------------------------------------------------------- 1 | - name: grib-in-netcdf-out 2 | input: input.grib # this file will be downloaded from S3 3 | output: output.nc 4 | checks: 5 | - check_with_xarray: 6 | check_accum: tp 7 | check_nans: true 8 | inference_config: 9 | post_processors: 10 | - accumulate_from_start_of_forecast 11 | write_initial_state: false 12 | checkpoint: ${checkpoint:} 13 | input: 14 | grib: ${input:} 15 | output: 16 | netcdf: ${output:} 17 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package to PyPI 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | uses: ecmwf/reusable-workflows/.github/workflows/cd-pypi.yml@v2 13 | secrets: inherit 14 | -------------------------------------------------------------------------------- /.github/workflows/pr-label-file-based.yml: -------------------------------------------------------------------------------- 1 | # This workflow assigns labels to a pull request based on the files changed in the PR. 2 | # The labels are defined in the `.github/labeler.yml` file.
3 | name: "[PR] Label File-based" 4 | on: 5 | pull_request_target: 6 | types: [opened, synchronize] 7 | 8 | permissions: 9 | contents: read 10 | pull-requests: write 11 | 12 | jobs: 13 | labeler: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Assign labels from file changes 17 | uses: actions/labeler@v5 18 | -------------------------------------------------------------------------------- /src/anemoi/inference/commands/metadata.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | from anemoi.utils.commands.metadata import Metadata 11 | 12 | command = Metadata 13 | -------------------------------------------------------------------------------- /docs/installing.rst: -------------------------------------------------------------------------------- 1 | .. _installing: 2 | 3 | ############ 4 | Installing 5 | ############ 6 | 7 | To install the package, you can use the following command: 8 | 9 | .. code:: bash 10 | 11 | pip install anemoi-inference 12 | 13 | ************** 14 | Contributing 15 | ************** 16 | 17 | .. code:: bash 18 | 19 | git clone ... 20 | cd anemoi-inference 21 | pip install .[dev] 22 | pip install -r docs/requirements.txt 23 | 24 | You may also have to install pandoc on MacOS: 25 | 26 | .. 
code:: bash 27 | 28 | brew install pandoc 29 | -------------------------------------------------------------------------------- /.github/workflows/python-integration.yml: -------------------------------------------------------------------------------- 1 | name: Integration test isolation 2 | 3 | on: 4 | pull_request: 5 | types: [opened, synchronize, reopened] 6 | push: 7 | branches: 8 | - main 9 | 10 | jobs: 11 | cosmo: 12 | strategy: 13 | matrix: 14 | python-version: ["3.10", "3.11", "3.12"] 15 | uses: ecmwf/reusable-workflows/.github/workflows/qa-pytest-pyproject.yml@v2 16 | with: 17 | optional-dependencies: "all,tests,cosmo" 18 | custom-pytest: pytest tests/integration --cosmo 19 | python-version: ${{ matrix.python-version }} 20 | -------------------------------------------------------------------------------- /docs/cli/validate.rst: -------------------------------------------------------------------------------- 1 | .. _validate-command: 2 | 3 | Validate Command 4 | ================ 5 | 6 | It is possible to investigate a checkpoint file and determine if the environment matches, 7 | or if anemoi packages differ in version between the inference and the training environment. 8 | 9 | This can be very useful to resolve issues when running an older or shared checkpoint. 10 | 11 | ********* 12 | Usage 13 | ********* 14 | 15 | 16 | .. argparse:: 17 | :module: anemoi.inference.__main__ 18 | :func: create_parser 19 | :prog: anemoi-inference 20 | :path: validate 21 | -------------------------------------------------------------------------------- /.github/workflows/pr-conventional-commit.yml: -------------------------------------------------------------------------------- 1 | # This workflow ensures that the PR title follows the Conventional Commit format. 
2 | name: "[PR] Ensure Conventional Commit Title" 3 | 4 | on: 5 | pull_request_target: 6 | types: 7 | - opened 8 | - edited 9 | - synchronize 10 | - reopened 11 | 12 | permissions: 13 | pull-requests: read 14 | 15 | jobs: 16 | main: 17 | name: Validate PR title 18 | runs-on: ubuntu-latest 19 | steps: 20 | - uses: amannn/action-semantic-pull-request@v5 21 | env: 22 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 23 | -------------------------------------------------------------------------------- /docs/inference/apis/level2.rst: -------------------------------------------------------------------------------- 1 | .. _api_level2: 2 | 3 | ##################### 4 | Object oriented API 5 | ##################### 6 | 7 | The object oriented API is more flexible than the NumPy API, as it 8 | allows users to use classes that can create the initial state 9 | (:class:`Input `), and classes that can 10 | process the output state (:class:`Output 11 | `). Several classes are provided as 12 | part of the package, and users can create their own classes to support 13 | various sources of data and output formats. 14 | 15 | .. literalinclude:: code/level2.py 16 | :language: python 17 | -------------------------------------------------------------------------------- /docs/cli/introduction.rst: -------------------------------------------------------------------------------- 1 | Introduction 2 | ============ 3 | 4 | When you install the `anemoi-inference` package, this will also install a command line tool 5 | called ``anemoi-inference``, which can be used to run and manage checkpoints. 6 | 7 | The tool provides help with the ``--help`` option: 8 | 9 | ..
code-block:: bash 10 | 11 | % anemoi-inference --help 12 | 13 | 14 | The commands are: 15 | 16 | - :ref:`Run Command ` 17 | - :ref:`Metadata Command ` 18 | - :ref:`Inspect Command ` 19 | - :ref:`Validate Command ` 20 | - :ref:`Patch Command ` 21 | - :ref:`Requests Command ` 22 | -------------------------------------------------------------------------------- /tests/unit/test_decorators.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from anemoi.inference.decorators import main_argument 4 | 5 | 6 | def test_main_argument(): 7 | @main_argument("path") 8 | class _Cls: 9 | def __init__(self, context, flag=True, path=None): 10 | self.context = context 11 | self.flag = flag 12 | self.path = path 13 | 14 | cls = _Cls("context", "path") 15 | 16 | assert cls.context == "context" 17 | assert cls.flag is True 18 | assert cls.path == "path" 19 | 20 | assert isinstance(main_argument("path")(_Cls), type) 21 | 22 | with pytest.raises(TypeError): 23 | main_argument("path")(lambda x: x) 24 | -------------------------------------------------------------------------------- /docs/inference/configs/yaml/inputs_9.yaml: -------------------------------------------------------------------------------- 1 | input: 2 | cds: 3 | # Dataset examples 4 | ## As a string 5 | dataset: 6 | 'reanalysis-era5-pressure-levels' 7 | 8 | ## As a simple dictionary 9 | dataset: 10 | levtype: 11 | pl: reanalysis-era5-pressure-levels 12 | sfc: reanalysis-era5-single-levels 13 | 14 | ## As a complex dictionary 15 | dataset: 16 | stream: 17 | oper: 18 | levtype: 19 | pl: reanalysis-era5-pressure-levels 20 | sfc: reanalysis-era5-single-levels 21 | an: 22 | # ... Other datasets 23 | '*': # Any other stream 24 | # ... Other datasets 25 | -------------------------------------------------------------------------------- /docs/overview.rst: -------------------------------------------------------------------------------- 1 | .. 
_overview: 2 | 3 | ########## 4 | Overview 5 | ########## 6 | 7 | .. _data-flow: 8 | 9 | .. figure:: _static/schemas/overview.png 10 | :alt: Overview 11 | :align: center 12 | :target: javascript:void(window.open('_images/overview.png')) 13 | 14 | Flow of data in inference. 15 | 16 | The schema above is a high-level overview of the data flow in the 17 | inference process (click on the image for a larger version). 18 | 19 | Several key concepts are introduced: 20 | 21 | - Input: 22 | - Prognostic variables: 23 | - Diagnostic variables: 24 | - Output: 25 | - Constant forcings: 26 | - Dynamic forcings: 27 | - Computed constants: 28 | - Computed forcings: 29 | -------------------------------------------------------------------------------- /.github/workflows/readthedocs-pr-update.yml: -------------------------------------------------------------------------------- 1 | # This workflow adds a link to the experimental documentation build to the PR. 2 | # This does NOT trigger a build of the documentation, this is handled through webhooks. 3 | name: "[PR] Read the Docs Preview" 4 | on: 5 | pull_request_target: 6 | types: 7 | - opened 8 | - synchronize 9 | - reopened 10 | # Execute this action only on PRs that touch 11 | # documentation files. 12 | paths: 13 | - "docs/**" 14 | 15 | permissions: 16 | pull-requests: write 17 | 18 | jobs: 19 | documentation-links: 20 | runs-on: ubuntu-latest 21 | steps: 22 | - uses: readthedocs/actions/preview@v1 23 | with: 24 | project-slug: "anemoi-inference" 25 | -------------------------------------------------------------------------------- /src/anemoi/inference/protocol.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | 11 | from typing import Any 12 | from typing import Protocol 13 | 14 | 15 | class MetadataProtocol(Protocol): 16 | """Protocol for metadata objects. This will keep `mypy` happy.""" 17 | 18 | _metadata: dict[str, Any] 19 | _supporting_arrays: dict[str, Any] 20 | grid: str 21 | variables: dict[str, Any] 22 | -------------------------------------------------------------------------------- /docs/dev/contributing.rst: -------------------------------------------------------------------------------- 1 | .. _dev-contributing: 2 | 3 | #################### 4 | General guidelines 5 | #################### 6 | 7 | Thank you for your interest in Anemoi Inference! Please follow the 8 | :ref:`general Anemoi contributing guidelines 9 | `. 10 | 11 | These include general guidelines for contributions to Anemoi, 12 | instructions on setting up a development environment, and guidelines on 13 | collaboration on GitHub, writing documentation, testing, and code style. 14 | 15 | ************ 16 | Unit tests 17 | ************ 18 | 19 | Anemoi-inference includes unit tests that can be executed locally using 20 | pytest. For more information on testing, please refer to the 21 | :ref:`general Anemoi testing guidelines 22 | `. 23 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env make -f 2 | 3 | # Minimal makefile for Sphinx documentation 4 | # 5 | 6 | # You can set these variables from the command line, and also 7 | # from the environment for the first two. 8 | SPHINXOPTS ?= 9 | SPHINXBUILD ?= sphinx-build 10 | SOURCEDIR = . 11 | BUILDDIR = _build 12 | 13 | # Put it first so that "make" without argument is like "make help". 
14 | help: 15 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 16 | 17 | .PHONY: help Makefile 18 | 19 | # Catch-all target: route all unknown targets to Sphinx using the new 20 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 21 | %: Makefile 22 | bash $(SOURCEDIR)/scripts/api_build.sh 23 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 24 | -------------------------------------------------------------------------------- /.github/workflows/python-pull-request.yml: -------------------------------------------------------------------------------- 1 | # This workflow runs pre-commit checks and pytest tests against multiple platforms and Python versions. 2 | name: Code Quality and Testing 3 | 4 | on: 5 | pull_request: 6 | types: [opened, synchronize, reopened] 7 | push: 8 | branches: 9 | - main 10 | schedule: 11 | - cron: "9 2 * * 0" # at 02:09 (UTC) on Sunday 12 | 13 | jobs: 14 | quality: 15 | uses: ecmwf/reusable-workflows/.github/workflows/qa-precommit-run.yml@v2 16 | with: 17 | skip-hooks: "no-commit-to-branch" 18 | 19 | checks: 20 | strategy: 21 | matrix: 22 | python-version: ["3.10", "3.11", "3.12"] 23 | uses: ecmwf/reusable-workflows/.github/workflows/qa-pytest-pyproject.yml@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | -------------------------------------------------------------------------------- /src/anemoi/inference/__init__.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction.
9 | 10 | 11 | try: 12 | # NOTE: the `_version.py` file must not be present in the git repository 13 | # as it is generated by setuptools at install time 14 | from ._version import __version__ 15 | except ImportError: # pragma: no cover 16 | # Local copy or not installed with setuptools 17 | __version__ = "999" 18 | 19 | __all__ = ["__version__"] 20 | -------------------------------------------------------------------------------- /src/anemoi/inference/commands/__init__.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | import os 11 | 12 | from anemoi.utils.cli import Command 13 | from anemoi.utils.cli import Failed 14 | from anemoi.utils.cli import register_commands 15 | 16 | __all__ = ["Command"] 17 | 18 | COMMANDS = register_commands( 19 | os.path.dirname(__file__), 20 | __name__, 21 | lambda x: x.command(), 22 | lambda name, error: Failed(name, error), 23 | ) 24 | -------------------------------------------------------------------------------- /CONTRIBUTORS.md: -------------------------------------------------------------------------------- 1 | ## How to Contribute 2 | 3 | Please see the [read the docs](https://anemoi.readthedocs.io/en/latest/contributing/contributing.html). 4 | 5 | 6 | ## Contributors 7 | 8 | Thank you to all the wonderful people who have contributed to Anemoi. Contributions can come in many forms, including code, documentation, bug reports, feature suggestions, design, and more. 
A list of code-based contributors can be found [here](https://github.com/ecmwf/anemoi-inference/graphs/contributors). 9 | 10 | 11 | ## Contributing Organisations 12 | 13 | Significant contributions have been made by the following organisations: [DMI](https://www.dmi.dk/), [DWD](https://www.dwd.de/), [FMI](https://www.ilmatieteenlaitos.fi/), [KNMI](https://www.knmi.nl), [MET Norway](https://www.met.no/), [MeteoSwiss](https://www.meteoswiss.admin.ch/), [RMI](https://www.meteo.be/), [Met Office](https://www.metoffice.gov.uk/), [Météo-France](https://meteofrance.com/) & [ECMWF](https://www.ecmwf.int/) 14 | -------------------------------------------------------------------------------- /docs/inference/apis/code/level2.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | from anemoi.inference.config.run import RunConfiguration 4 | from anemoi.inference.inputs.gribfile import GribFileInput 5 | from anemoi.inference.outputs.gribfile import GribFileOutput 6 | from anemoi.inference.runners.default import DefaultRunner 7 | 8 | # Create a runner with the checkpoint file 9 | runner = DefaultRunner(RunConfiguration(checkpoint="checkpoint.ckpt")) 10 | 11 | # Select a starting date 12 | date = datetime.datetime(2024, 10, 25) 13 | 14 | input = GribFileInput(runner, "input.grib") 15 | output = GribFileOutput(runner, "output.grib") 16 | 17 | input_state = input.create_input_state(date=date) 18 | 19 | # Write the initial state to the output file 20 | output.write_initial_state(input_state) 21 | 22 | # Run the model and write the output to the file 23 | 24 | for state in runner.run(input_state=input_state, lead_time=240): 25 | output.write_state(state) 26 | 27 | # Close the output file 28 | output.close() 29 | -------------------------------------------------------------------------------- /src/anemoi/inference/device.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025 
ECMWF. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # In applying this licence, ECMWF does not waive the privileges and immunities 6 | # granted to it by virtue of its status as an intergovernmental organisation 7 | # nor does it submit to any jurisdiction. 8 | # 9 | 10 | from typing import TYPE_CHECKING 11 | 12 | if TYPE_CHECKING: 13 | import torch 14 | 15 | 16 | def get_available_device() -> "torch.device": 17 | """Get the available device for PyTorch. 18 | 19 | Returns 20 | ------- 21 | torch.device 22 | The available device, either 'cuda', 'mps', or 'cpu'. 23 | """ 24 | import torch 25 | 26 | if torch.cuda.is_available(): 27 | return torch.device("cuda") 28 | elif torch.backends.mps.is_available(): 29 | return torch.device("mps") 30 | return torch.device("cpu") 31 | -------------------------------------------------------------------------------- /src/anemoi/inference/outputs/none.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | import logging 11 | 12 | from anemoi.inference.types import State 13 | 14 | from ..output import Output 15 | from . import output_registry 16 | 17 | LOG = logging.getLogger(__name__) 18 | 19 | 20 | @output_registry.register("none") 21 | class NoneOutput(Output): 22 | """None output class.""" 23 | 24 | def write_step(self, state: State) -> None: 25 | """Write a step of the state. 
26 | 27 | Parameters 28 | ---------- 29 | state : State 30 | The state dictionary. 31 | """ 32 | pass 33 | -------------------------------------------------------------------------------- /docs/_static/style.css: -------------------------------------------------------------------------------- 1 | .wy-side-nav-search { 2 | background-color: #f7f7f7; 3 | } 4 | 5 | /*There is a clash between xarray notebook styles and readthedoc*/ 6 | 7 | .rst-content dl.xr-attrs dt { 8 | all: revert; 9 | font-size: 95%; 10 | white-space: nowrap; 11 | } 12 | 13 | .rst-content dl.xr-attrs dd { 14 | font-size: 95%; 15 | } 16 | 17 | .xr-wrap { 18 | font-size: 85%; 19 | } 20 | 21 | .wy-table-responsive table td, .wy-table-responsive table th { 22 | white-space: inherit; 23 | } 24 | 25 | /* 26 | .wy-table-responsive table td, 27 | .wy-table-responsive table th { 28 | white-space: normal !important; 29 | vertical-align: top !important; 30 | } 31 | 32 | .wy-table-responsive { 33 | margin-bottom: 24px; 34 | max-width: 100%; 35 | overflow: visible; 36 | } */ 37 | 38 | /* Hide notebooks warnings */ 39 | .nboutput .stderr { 40 | display: none; 41 | } 42 | 43 | /* 44 | Set logo size 45 | */ 46 | .wy-side-nav-search .wy-dropdown > a img.logo, .wy-side-nav-search > a img.logo { 47 | width: 200px; 48 | } 49 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | 4 | ## What problem does this change solve? 5 | 6 | 7 | ## What issue or task does this change relate to? 8 | 9 | 10 | ## Additional notes ## 11 | 12 | 13 | ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). 
As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** 14 | 15 | By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md) 16 | -------------------------------------------------------------------------------- /.github/workflows/release-please.yml: -------------------------------------------------------------------------------- 1 | # This workflow uses an action to run Release Please to create a release PR. 2 | # It is governed by the config and manifest in the root of the repo. 3 | # For more information see: https://github.com/googleapis/release-please 4 | name: Run Release Please 5 | on: 6 | push: 7 | branches: 8 | - main 9 | - hotfix/* 10 | 11 | permissions: 12 | contents: write 13 | pull-requests: write 14 | 15 | jobs: 16 | release-please: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: googleapis/release-please-action@v4 20 | with: 21 | # this assumes that you have created a personal access token 22 | # (PAT) and configured it as a GitHub action secret named 23 | # `RELEASE_PLEASE_TOKEN` (the name must match the secret referenced below). 24 | token: ${{ secrets.RELEASE_PLEASE_TOKEN }} 25 | # optional. customize path to .release-please-config.json 26 | config-file: .release-please-config.json 27 | # Releases target the branch that triggered the workflow (main or hotfix/*) 28 | target-branch: ${{ github.ref_name }} 29 | -------------------------------------------------------------------------------- /tests/unit/fake_metadata.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | 11 | from typing import Any 12 | 13 | 14 | class FakeMetadata: 15 | """A class to simulate metadata for testing purposes. 16 | 17 | This class returns None for any attribute accessed, simulating the absence 18 | of metadata. 19 | """ 20 | 21 | def __getattr__(self, name: str) -> Any: 22 | """Simulate the absence of metadata by returning None for any attribute. 23 | 24 | Parameters 25 | ---------- 26 | name : str 27 | The name of the attribute being accessed. 28 | 29 | Returns 30 | ------- 31 | Any 32 | Always returns None. 33 | """ 34 | return None 35 | -------------------------------------------------------------------------------- /.github/labeler.yml: -------------------------------------------------------------------------------- 1 | # This is the configuration file for the labeler action. 2 | # It assigns labels to pull requests based on the files changed in the PR. 
3 | # See more here: https://github.com/actions/labeler 4 | dependencies: 5 | - changed-files: 6 | - any-glob-to-any-file: 7 | - "**/requirements.txt" 8 | - "**/setup.py" 9 | - "**/pyproject.toml" 10 | - "**/Pipfile" 11 | - "**/Pipfile.lock" 12 | - "**/requirements/*.txt" 13 | - "**/requirements/*.in" 14 | 15 | documentation: 16 | - changed-files: 17 | - any-glob-to-any-file: 18 | - "**/docs/**/*" 19 | - "*.md" 20 | - "*.rst" 21 | 22 | config: 23 | - changed-files: 24 | - any-glob-to-any-file: 25 | - "**/src/**/config/**/*" 26 | - "**/src/anemoi/inference/config.py" 27 | 28 | CI/CD: 29 | - changed-files: 30 | - any-glob-to-any-file: 31 | - "**/.pre-commit-config.yaml" 32 | - ".github/**/*" 33 | - "tox.ini" 34 | - ".coveragerc" 35 | 36 | tests: 37 | - changed-files: 38 | - any-glob-to-any-file: 39 | - "**/tests/**/*" 40 | - "**/test/**/*" 41 | - "**/test_*.py" 42 | - "**/test.py" 43 | - "**/conftest.py" 44 | -------------------------------------------------------------------------------- /tests/unit/test_checkpoint.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | 11 | from anemoi.inference.testing import fake_checkpoints 12 | 13 | 14 | @fake_checkpoints 15 | def test_checkpoint() -> None: 16 | """Test the Checkpoint class. 
17 | 18 | Returns 19 | ------- 20 | None 21 | """ 22 | from anemoi.inference.checkpoint import Checkpoint 23 | 24 | c = Checkpoint("unit/checkpoints/simple.ckpt") 25 | c.select_variables( 26 | include=["prognostic"], 27 | exclude=["forcing", "computed", "diagnostic"], 28 | ) 29 | 30 | 31 | if __name__ == "__main__": 32 | for name, obj in list(globals().items()): 33 | if name.startswith("test_") and callable(obj): 34 | print(f"Running {name}...") 35 | obj() 36 | -------------------------------------------------------------------------------- /src/anemoi/inference/outputs/__init__.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 ECMWF. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # In applying this licence, ECMWF does not waive the privileges and immunities 6 | # granted to it by virtue of its status as an intergovernmental organisation 7 | # nor does it submit to any jurisdiction. 8 | # 9 | 10 | from anemoi.utils.registry import Registry 11 | 12 | from anemoi.inference.config import Configuration 13 | from anemoi.inference.context import Context 14 | from anemoi.inference.output import Output 15 | 16 | output_registry = Registry(__name__) 17 | 18 | 19 | def create_output(context: Context, config: Configuration) -> Output: 20 | """Create an output. 21 | 22 | Parameters 23 | ---------- 24 | context : Context 25 | The context for the output. 26 | config : Configuration 27 | The configuration for the output. 28 | 29 | Returns 30 | ------- 31 | object 32 | The created output. 
33 | """ 34 | return output_registry.from_config(config, context) 35 | -------------------------------------------------------------------------------- /src/anemoi/inference/testing/variables.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | from anemoi.transform.variables import Variable 11 | 12 | tp = Variable.from_dict( 13 | "tp", 14 | { 15 | "mars": { 16 | "param": "tp", 17 | "levtype": "sfc", 18 | }, 19 | "process": "accumulation", 20 | "period": [0, 6], 21 | }, 22 | ) 23 | 24 | z = Variable.from_dict( 25 | "z", 26 | { 27 | "mars": { 28 | "param": "z", 29 | "levtype": "sfc", 30 | } 31 | }, 32 | ) 33 | 34 | w_100 = Variable.from_dict( 35 | "w_100", 36 | { 37 | "mars": { 38 | "param": "w", 39 | "levtype": "pl", 40 | "levelist": 100, 41 | } 42 | }, 43 | ) 44 | -------------------------------------------------------------------------------- /src/anemoi/inference/__main__.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 
9 | 10 | from typing import Any 11 | 12 | from anemoi.utils.cli import cli_main 13 | from anemoi.utils.cli import make_parser 14 | 15 | from . import __version__ 16 | from .commands import COMMANDS 17 | 18 | 19 | # For read-the-docs. NOTE(review): this module has no docstring, so __doc__ is None below; confirm that is intended 20 | def create_parser() -> Any: 21 | """Create a command-line argument parser. 22 | 23 | Returns 24 | ------- 25 | Any 26 | The command-line argument parser. 27 | """ 28 | return make_parser(__doc__, COMMANDS) 29 | 30 | 31 | def main() -> None: 32 | """Execute the main command-line interface. 33 | 34 | Returns 35 | ------- 36 | None 37 | """ 38 | cli_main(__version__, __doc__, COMMANDS) 39 | 40 | 41 | if __name__ == "__main__": 42 | main() 43 | -------------------------------------------------------------------------------- /src/anemoi/inference/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 ECMWF. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # In applying this licence, ECMWF does not waive the privileges and immunities 6 | # granted to it by virtue of its status as an intergovernmental organisation 7 | # nor does it submit to any jurisdiction. 8 | # 9 | 10 | from typing import Any 11 | 12 | from anemoi.utils.registry import Registry 13 | 14 | task_registry = Registry(__name__) 15 | 16 | 17 | def create_task(name: str, config: dict[str, Any], global_config: dict[str, Any]) -> Any: 18 | """Create a task instance based on the given configuration. 19 | 20 | Parameters 21 | ---------- 22 | name : str 23 | The name of the task. 24 | config : dict[str, Any] 25 | The configuration for the task. 26 | global_config : dict[str, Any] 27 | The global configuration, passed to the task as the ``global_config`` keyword. 28 | 29 | Returns 30 | ------- 31 | Any 32 | The created task instance.
33 | """ 34 | return task_registry.from_config(config, name, global_config=global_config) 35 | -------------------------------------------------------------------------------- /src/anemoi/inference/runners/__init__.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. 2 | # This software is licensed under the terms of the Apache Licence Version 2.0 3 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 4 | # In applying this licence, ECMWF does not waive the privileges and immunities 5 | # granted to it by virtue of its status as an intergovernmental organisation 6 | # nor does it submit to any jurisdiction. 7 | 8 | 9 | from typing import Any 10 | 11 | from anemoi.utils.registry import Registry 12 | 13 | from anemoi.inference.config.run import RunConfiguration 14 | 15 | runner_registry = Registry(__name__) 16 | 17 | 18 | def create_runner(config: RunConfiguration, **kwargs: Any) -> Any: 19 | """Create a runner instance based on the given configuration. 20 | 21 | Parameters 22 | ---------- 23 | config : RunConfiguration 24 | The run configuration; its ``runner`` attribute selects the registered runner. 25 | **kwargs : Any 26 | Additional keyword arguments forwarded to ``runner_registry.from_config``. 27 | 28 | Returns 29 | ------- 30 | Any 31 | The created runner instance. 32 | """ 33 | return runner_registry.from_config(config.runner, config, **kwargs) 34 | -------------------------------------------------------------------------------- /src/anemoi/inference/pre_processors/__init__.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025 ECMWF. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5 | # In applying this licence, ECMWF does not waive the privileges and immunities 6 | # granted to it by virtue of its status as an intergovernmental organisation 7 | # nor does it submit to any jurisdiction. 8 | # 9 | 10 | from anemoi.utils.registry import Registry 11 | 12 | from anemoi.inference.context import Context 13 | from anemoi.inference.processor import Processor 14 | from anemoi.inference.types import ProcessorConfig 15 | 16 | pre_processor_registry = Registry(__name__) 17 | 18 | 19 | def create_pre_processor(context: Context, config: ProcessorConfig) -> Processor: 20 | """Create a pre-processor. 21 | 22 | Parameters 23 | ---------- 24 | context : Context 25 | The context for the pre-processor. 26 | config : ProcessorConfig 27 | The configuration for the pre-processor, either a string or a dict. 28 | 29 | Returns 30 | ------- 31 | Processor 32 | The created pre-processor. 33 | """ 34 | return pre_processor_registry.from_config(config, context) 35 | -------------------------------------------------------------------------------- /src/anemoi/inference/post_processors/__init__.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025 ECMWF. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # In applying this licence, ECMWF does not waive the privileges and immunities 6 | # granted to it by virtue of its status as an intergovernmental organisation 7 | # nor does it submit to any jurisdiction. 8 | # 9 | 10 | 11 | from anemoi.utils.registry import Registry 12 | 13 | from anemoi.inference.context import Context 14 | from anemoi.inference.processor import Processor 15 | from anemoi.inference.types import ProcessorConfig 16 | 17 | post_processor_registry = Registry(__name__) 18 | 19 | 20 | def create_post_processor(context: Context, config: ProcessorConfig) -> Processor: 21 | """Create a post-processor.
22 | 23 | Parameters 24 | ---------- 25 | context : Context 26 | The context for the post-processor. 27 | config : ProcessorConfig 28 | The configuration for the post-processor, either a string or a dict. 29 | 30 | Returns 31 | ------- 32 | Processor 33 | The created post-processor. 34 | """ 35 | return post_processor_registry.from_config(config, context) 36 | -------------------------------------------------------------------------------- /src/anemoi/inference/task.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 ECMWF. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # In applying this licence, ECMWF does not waive the privileges and immunities 6 | # granted to it by virtue of its status as an intergovernmental organisation 7 | # nor does it submit to any jurisdiction. 8 | # 9 | import logging 10 | from abc import ABC 11 | 12 | LOG = logging.getLogger(__name__) 13 | 14 | 15 | class Task(ABC): 16 | """Base class for tasks (derives from ``ABC``; no abstract methods are declared here). 17 | 18 | Parameters 19 | ---------- 20 | name : str 21 | The name of the task. 22 | """ 23 | 24 | def __init__(self, name: str) -> None: 25 | """Initialize the Task. 26 | 27 | Parameters 28 | ---------- 29 | name : str 30 | The name of the task, stored as ``self.name``. 31 | """ 32 | self.name = name 33 | 34 | def __repr__(self) -> str: 35 | """Return a string representation of the Task. 36 | 37 | Returns 38 | ------- 39 | str 40 | The class name followed by the task name in parentheses. 41 | """ 42 | return f"{self.__class__.__name__}({self.name})" 43 | -------------------------------------------------------------------------------- /src/anemoi/inference/inputs/__init__.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 ECMWF.
2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # In applying this licence, ECMWF does not waive the privileges and immunities 6 | # granted to it by virtue of its status as an intergovernmental organisation 7 | # nor does it submit to any jurisdiction. 8 | # 9 | 10 | from typing import Any 11 | 12 | from anemoi.utils.registry import Registry 13 | 14 | from anemoi.inference.config import Configuration 15 | from anemoi.inference.context import Context 16 | 17 | input_registry = Registry(__name__) 18 | 19 | 20 | def create_input(context: Context, config: Configuration, **kwargs: Any) -> Any: 21 | """Create an input instance from the given context and configuration. 22 | 23 | Parameters 24 | ---------- 25 | context : Context 26 | The context in which the input is created. 27 | config : Configuration 28 | The configuration for the input. 29 | **kwargs : Any 30 | Additional keyword arguments forwarded to ``input_registry.from_config``. 31 | 32 | Returns 33 | ------- 34 | Any 35 | The created input instance. 36 | """ 37 | return input_registry.from_config(config, context, **kwargs) 38 | -------------------------------------------------------------------------------- /src/anemoi/inference/transports/__init__.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 ECMWF. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # In applying this licence, ECMWF does not waive the privileges and immunities 6 | # granted to it by virtue of its status as an intergovernmental organisation 7 | # nor does it submit to any jurisdiction.
8 | # 9 | 10 | 11 | from anemoi.utils.registry import Registry 12 | 13 | from anemoi.inference.config import Configuration 14 | from anemoi.inference.task import Task 15 | from anemoi.inference.transport import Transport 16 | 17 | transport_registry = Registry(__name__) 18 | 19 | 20 | def create_transport(config: Configuration, couplings: Configuration, tasks: dict[str, Task]) -> Transport: 21 | """Create a transport instance based on the given configuration. 22 | 23 | Parameters 24 | ---------- 25 | config : Configuration 26 | The configuration for the transport. 27 | couplings : Configuration 28 | The couplings for the transport. 29 | tasks : dict[str, Task] 30 | The tasks to be executed. 31 | 32 | Returns 33 | ------- 34 | Transport 35 | The created transport instance. 36 | """ 37 | return transport_registry.from_config(config, couplings, tasks) 38 | -------------------------------------------------------------------------------- /src/anemoi/inference/clusters/distributed.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025- ECMWF. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # In applying this licence, ECMWF does not waive the privileges and immunities 6 | # granted to it by virtue of its status as an intergovernmental organisation 7 | # nor does it submit to any jurisdiction.
8 | # 9 | 10 | 11 | from anemoi.inference.clusters import cluster_registry 12 | from anemoi.inference.clusters.mapping import EnvMapping 13 | from anemoi.inference.clusters.mapping import MappingCluster 14 | 15 | DISTRIBUTED_MAPPING = EnvMapping( 16 | local_rank="LOCAL_RANK", 17 | global_rank="RANK", 18 | world_size="WORLD_SIZE", 19 | master_addr="MASTER_ADDR", 20 | master_port="MASTER_PORT", 21 | init_method="env://", 22 | ) 23 | 24 | 25 | @cluster_registry.register("distributed") 26 | class DistributedCluster(MappingCluster): 27 | """Cluster configured from the LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT environment variables.""" 28 | 29 | def __init__(self) -> None: 30 | super().__init__(mapping=DISTRIBUTED_MAPPING) 31 | # used() is true only when the values for world_size (WORLD_SIZE) and global_rank (RANK) are both truthy; presumably get_env returns the raw value or None - confirm 32 | @classmethod 33 | def used(cls) -> bool: 34 | return bool(DISTRIBUTED_MAPPING.get_env("world_size")) and bool(DISTRIBUTED_MAPPING.get_env("global_rank")) 35 | -------------------------------------------------------------------------------- /docs/inference/configs/forcings.rst: -------------------------------------------------------------------------------- 1 | .. _forcings: 2 | 3 | ########## 4 | Forcings 5 | ########## 6 | 7 | :ref:`inputs` refers to the input methods used to fetch the initial 8 | conditions. If the model needs data during the run (dynamic forcings), 9 | the forcings will be fetched by default from the same source as the 10 | input. However, you can specify a different source. 11 | 12 | The example below shows a configuration where the forcings are fetched from 13 | ``mars`` while the initial conditions are fetched from ``test``. 14 | 15 | .. literalinclude:: yaml/forcings_1.yaml 16 | :language: yaml 17 | 18 | The example above is a shortcut for: 19 | 20 | .. literalinclude:: yaml/forcings_2.yaml 21 | 22 | You can also specify different sources for each forcing. 23 | 24 | To get the initial constant forcings from a file: 25 | 26 | ..
literalinclude:: yaml/forcings_3.yaml 27 | :language: yaml 28 | 29 | Get the dynamic forcings from mars: 30 | 31 | .. literalinclude:: yaml/forcings_4.yaml 32 | :language: yaml 33 | 34 | And the LAM boundary conditions from a mars as well: 35 | 36 | .. literalinclude:: yaml/forcings_5.yaml 37 | :language: yaml 38 | 39 | In the case of the last three type of forcings, it they are not 40 | specified, the value of `forcings.input` will be used. It this value is 41 | not specified, the value of `input` will then be used. 42 | -------------------------------------------------------------------------------- /docs/inference/apis/code/level1_1.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import numpy as np 4 | 5 | from anemoi.inference.runners.simple import SimpleRunner 6 | 7 | # Create a runner with the checkpoint file 8 | runner = SimpleRunner("checkpoint.ckpt") 9 | 10 | # Select a starting date 11 | date = datetime.datetime(2024, 10, 25) 12 | 13 | # Assuming that the initial conditions requires two 14 | # dates, e.g. 
T0 and T-6 15 | 16 | multi_step_input = 2 17 | 18 | # Define the grid 19 | 20 | latitudes = np.linspace(90, -90, 181) # 1 degree resolution 21 | longitudes = np.linspace(0, 359, 360) 22 | 23 | number_of_points = len(latitudes) * len(longitudes) 24 | latitudes, longitudes = np.meshgrid(latitudes, longitudes) 25 | 26 | # Create the initial state 27 | 28 | input_state = { 29 | "date": date, 30 | "latitudes": latitudes, 31 | "longitudes": longitudes, 32 | "fields": { 33 | "2t": np.random.rand(multi_step_input, number_of_points), 34 | "msl": np.random.rand(multi_step_input, number_of_points), 35 | "z_500": np.random.rand(multi_step_input, number_of_points), 36 | ...: ..., 37 | }, 38 | } 39 | 40 | # Run the model 41 | 42 | for state in runner.run(input_state=input_state, lead_time=240): 43 | # This is the date of the new state 44 | print("New state:", state["date"]) 45 | 46 | # This is value of a field for that date 47 | print("Forecasted 2t:", state["fields"]["2t"]) 48 | -------------------------------------------------------------------------------- /docs/_templates/apidoc/package.rst.jinja: -------------------------------------------------------------------------------- 1 | {%- macro automodule(modname, options) -%} 2 | .. automodule:: {{ modname }} 3 | {%- for option in options %} 4 | :{{ option }}: 5 | {%- endfor %} 6 | {%- endmacro %} 7 | 8 | {%- macro toctree(docnames) -%} 9 | .. toctree:: 10 | :maxdepth: {{ maxdepth }} 11 | {% for docname in docnames %} 12 | {{ docname }} 13 | {%- endfor %} 14 | {%- endmacro %} 15 | 16 | {%- if is_namespace %} 17 | {{- pkgname.split(".")[1:] | join(".") | e | heading }} 18 | {% else %} 19 | {{- pkgname.split(".")[1:] | join(" ") | e | heading }} 20 | {% endif %} 21 | 22 | {%- if is_namespace %} 23 | .. 
py:module:: {{ pkgname }} 24 | {% endif %} 25 | 26 | {%- if modulefirst and not is_namespace %} 27 | {{ automodule(["anemoi", pkgname] | join("."), [""]) }} 28 | {% endif %} 29 | 30 | {%- if subpackages %} 31 | Subpackages 32 | ----------- 33 | 34 | {{ toctree(subpackages) }} 35 | {% endif %} 36 | 37 | {%- if submodules %} 38 | {% if separatemodules %} 39 | {{ toctree(submodules) }} 40 | {% else %} 41 | {%- for submodule in submodules %} 42 | {% if show_headings %} 43 | {{- submodule.split(".")[2:] | join(".") | e | heading(2) }} 44 | {% endif %} 45 | {{ automodule(["anemoi", submodule] | join("."), automodule_options) }} 46 | {% endfor %} 47 | {%- endif %} 48 | {%- endif %} 49 | 50 | {%- if not modulefirst and not is_namespace %} 51 | Module contents 52 | --------------- 53 | 54 | {{ automodule(pkgname, automodule_options) }} 55 | {% endif %} 56 | -------------------------------------------------------------------------------- /src/anemoi/inference/lazy.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 
9 | 10 | import logging 11 | from typing import TYPE_CHECKING 12 | 13 | LOG = logging.getLogger(__name__) 14 | 15 | 16 | class LazyModule: 17 | """Proxy that imports ``module_name`` on first attribute access and forwards attribute lookups to it.""" 18 | 19 | def __init__(self, module_name: str) -> None: 20 | self.__module_name = module_name 21 | self.__module = None 22 | # __getattr__ only runs when normal lookup fails, so the proxy's own (name-mangled) attributes never reach it 23 | def __getattr__(self, name: str): 24 | if self.__module is None: 25 | import importlib 26 | 27 | LOG.debug("Importing '%s'", self.__module_name) 28 | 29 | self.__module = importlib.import_module(self.__module_name) 30 | return getattr(self.__module, name) 31 | 32 | 33 | # add heavy imports here, then they can be used in the rest of the codebase as regular imports 34 | # with type checking and autocompletion: `from anemoi.inference.lazy import torch` 35 | # note: when used in a type hint, use quotes, e.g. "torch.Tensor" instead of torch.Tensor to avoid triggering the import 36 | if TYPE_CHECKING: 37 | import torch 38 | else: 39 | torch = LazyModule("torch") 40 | -------------------------------------------------------------------------------- /src/anemoi/inference/types.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025 ECMWF. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # In applying this licence, ECMWF does not waive the privileges and immunities 6 | # granted to it by virtue of its status as an intergovernmental organisation 7 | # nor does it submit to any jurisdiction. 8 | # 9 | 10 | import datetime 11 | from typing import Any 12 | from typing import Union 13 | 14 | from numpy.typing import NDArray 15 | 16 | """A collection of types used in the inference module. 17 | Some of these types could be moved to anemoi.utils.types or anemoi.transform.types.
18 | """ 19 | 20 | State = dict[str, Any] 21 | """A dictionary that represents the state of a model.""" 22 | 23 | DataRequest = dict[str, Any] 24 | """A dictionary that represents a data request, like MARS, CDS, OpenData, ...""" 25 | 26 | Date = Union[str, datetime.datetime, int] 27 | """A date can be a string, a datetime object or an integer. It will always be converted to a datetime object.""" 28 | 29 | IntArray = NDArray[Any] 30 | """A numpy array of integers.""" 31 | 32 | FloatArray = NDArray[Any] 33 | """A numpy array of floats.""" 34 | 35 | BoolArray = NDArray[Any] 36 | """A numpy array of booleans.""" 37 | 38 | Shape = tuple[int, ...] 39 | """A tuple of integers representing the shape of an array.""" 40 | 41 | ProcessorConfig = Union[str, dict[str, Any]] 42 | """A str, or a dict with str keys, representing a pre- or post-processor configuration.""" 43 | -------------------------------------------------------------------------------- /.release-please-config.json: -------------------------------------------------------------------------------- 1 | { 2 | "release-type": "python", 3 | "bump-minor-pre-major": true, 4 | "bump-patch-for-minor-pre-major": true, 5 | "separate-pull-requests": true, 6 | "always-update": true, 7 | "changelog-type": "default", 8 | "include-component-in-tag": false, 9 | "include-v-in-tag": false, 10 | "draft-pull-request": true, 11 | "pull-request-title-pattern": "chore${scope}: Release${component} ${version}", 12 | "pull-request-header": ":robot: Automated Release PR\n\nThis PR was created by `release-please` to prepare the next release. Once merged:\n\n1. A new version tag will be created\n2. A GitHub release will be published\n3. The changelog will be updated\n\nChanges to be included in the next release:", 13 | "pull-request-footer": "> [!IMPORTANT]\n> Please do not change the PR title, manifest file, or any other automatically generated content in this PR unless you understand the implications.
Changes here can break the release process.\n> :warning: Merging this PR will:\n> - Create a new release\n> - Trigger deployment pipelines\n> - Update package versions\n\n **Before merging:**\n - Ensure all tests pass\n - Review the changelog carefully\n - Get required approvals\n\n [Release-please documentation](https://github.com/googleapis/release-please)", 14 | "packages": { 15 | ".": { 16 | "package-name": "anemoi-inference" 17 | } 18 | }, 19 | "plugins": [ 20 | { 21 | "type": "sentence-case" 22 | } 23 | ], 24 | "$schema": "https://raw.githubusercontent.com/googleapis/release-please/main/schemas/config.json" 25 | } 26 | -------------------------------------------------------------------------------- /src/anemoi/inference/inputs/grib.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | 11 | import logging 12 | from typing import Any 13 | 14 | import earthkit.data as ekd 15 | 16 | from .ekd import EkdInput 17 | 18 | LOG = logging.getLogger(__name__) 19 | 20 | 21 | class GribInput(EkdInput): 22 | """Handles GRIB input fields.""" 23 | 24 | def set_private_attributes(self, state: Any, fields: ekd.FieldList) -> None: 25 | """Set private attributes for the state, including ``_grib_templates_for_output``. 26 | 27 | Parameters 28 | ---------- 29 | state : Any 30 | The state to set private attributes for. 31 | fields : ekd.FieldList 32 | The input fields. 33 | """ 34 | # For now we just pass all the fields 35 | # Later, we can select a relevant subset (e.g.
only one 36 | # level), to save memory 37 | 38 | # Ordering by valid_datetime puts the most recent field last, so for a repeated name the dict below keeps the most recent field 39 | # now we can also use that list to write step 0 40 | super().set_private_attributes(state, fields) 41 | input_fields = fields.order_by("valid_datetime") 42 | 43 | state["_grib_templates_for_output"] = {field.metadata("name"): field for field in input_fields} 44 | -------------------------------------------------------------------------------- /src/anemoi/inference/config/couple.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | from __future__ import annotations 11 | 12 | import datetime 13 | import logging 14 | from typing import Any 15 | 16 | from anemoi.inference.config import Configuration 17 | 18 | LOG = logging.getLogger(__name__) 19 | 20 | 21 | class CoupleConfiguration(Configuration): 22 | """Configuration class for the couple runner.""" 23 | 24 | description: str | None = None 25 | 26 | lead_time: str | int | datetime.timedelta | None = None 27 | """The lead time for the forecast. This can be a string, an integer or a timedelta object. 28 | If an integer, it represents a number of hours. Otherwise, it is parsed by :func:`anemoi.utils.dates.as_timedelta`. 29 | """ 30 | 31 | name: str | None = None 32 | """Used by prepml.""" 33 | 34 | transport: str 35 | couplings: list[dict[str, list[str]]] 36 | tasks: dict[str, Any] 37 | 38 | env: dict[str, Any] = {} 39 | """Environment variables to set before running the model. This may be useful to control some packages 40 | such as `eccodes`.
In certain cases, the variables may be set too late, if the package for which they are intended 41 | is already loaded when the runner is configured. 42 | """ 43 | -------------------------------------------------------------------------------- /src/anemoi/inference/grib/templates/samples.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | import logging 11 | import os 12 | from typing import Any 13 | 14 | import earthkit.data as ekd 15 | 16 | from . import IndexTemplateProvider 17 | from . import template_provider_registry 18 | from .manager import TemplateManager 19 | 20 | LOG = logging.getLogger(__name__) 21 | 22 | 23 | @template_provider_registry.register("samples") 24 | class SamplesTemplates(IndexTemplateProvider): 25 | """Class to provide GRIB templates from sample files.""" 26 | 27 | def __init__(self, manager: TemplateManager, *args, index_path: str | None = None) -> None: 28 | if index_path is not None: 29 | return super().__init__(manager, index_path) 30 | # A single string positional argument is used as the index path; otherwise all positional arguments are passed on as a list 31 | if isinstance(args[0], str): # NOTE(review): raises IndexError when neither index_path nor a positional argument is given 32 | return super().__init__(manager, args[0]) 33 | 34 | return super().__init__(manager, [*args]) 35 | 36 | def load_template(self, grib: str, lookup: dict[str, Any]) -> ekd.Field | None: 37 | template = grib.format(**lookup) 38 | if not os.path.exists(template): 39 | LOG.warning(f"Template not found: {template}") 40 | return None 41 | 42 | LOG.debug(f"Loading sample file: {template}") 43 | return ekd.from_source("file", template)[0] 44 |
-------------------------------------------------------------------------------- /tests/integration/rmi-lam/config.yaml: -------------------------------------------------------------------------------- 1 | - name: dataset-in-netcdf-out 2 | input: null # anemoi-datasets will download the zarr.zip at runtime 3 | output: 4 | - lam_output.nc 5 | - full_output.nc 6 | checks: 7 | - check_cutout_with_xarray: 8 | file: ${output:0} 9 | mask: 'lam_0' 10 | reference_date: 2020-01-02T00:00:00 11 | reference_dataset: 12 | dataset: ${s3:}/aifs-ea-an-oper-0001-mars-o48-2020-6h.zarr.zip 13 | area: [70, -55, 10, 70] 14 | - check_boundary_forcings_with_xarray: 15 | file: ${output:1} 16 | reference_dataset: 17 | dataset: 18 | cutout: 19 | - dataset: ${s3:}/aifs-ea-an-oper-0001-mars-o48-2020-6h.zarr.zip 20 | area: [70, -55, 10, 70] 21 | - dataset: ${s3:}/aifs-ea-an-oper-0001-mars-o32-2020-6h.zarr.zip 22 | - check_with_xarray: 23 | file: ${output:0} 24 | check_accum: tp 25 | check_nans: true 26 | inference_config: 27 | checkpoint: ${checkpoint:} 28 | date: 2020-01-02T00:00:00 29 | write_initial_state: false 30 | input: 31 | dataset: 32 | cutout: 33 | - dataset: ${s3:}/aifs-ea-an-oper-0001-mars-o48-2020-6h.zarr.zip 34 | area: [70, -55, 10, 70] 35 | - dataset: ${s3:}/aifs-ea-an-oper-0001-mars-o32-2020-6h.zarr.zip 36 | output: 37 | tee: 38 | - netcdf: 39 | post_processors: 40 | - extract_mask: 41 | mask: lam_0/cutout_mask 42 | as_slice: true 43 | path: ${output:0} 44 | - netcdf: ${output:1} 45 | post_processors: 46 | - accumulate_from_start_of_forecast 47 | -------------------------------------------------------------------------------- /src/anemoi/inference/utils/templating.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025- Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | 11 | import re 12 | 13 | _TEMPLATE_EXPRESSION_PATTERN = re.compile(r"\{(.*?)\}") 14 | 15 | 16 | def render_template(template: str, handle: dict) -> str: 17 | """Render a template string with the given keyword arguments. 18 | 19 | Given a template string such as '{dateTime}_{step:03}.grib' and 20 | the GRIB handle, this function will replace the expressions in the 21 | template with the corresponding values from the handle, formatted 22 | according to the optional format specifier. 23 | 24 | For example, the template '{dateTime}_{step:03}.grib' with a handle 25 | containing 'dateTime' as '202501011200' and 'step' as 6 will 26 | produce '202501011200_006.grib'. 27 | 28 | Parameters 29 | ---------- 30 | template : str 31 | The template string to render. 32 | handle : dict 33 | The dictionary to use for rendering the template. 34 | 35 | Returns 36 | ------- 37 | str 38 | The rendered template string. 39 | """ 40 | expressions = _TEMPLATE_EXPRESSION_PATTERN.findall(str(template)) 41 | expr_format = [el.split(":") if ":" in el else [el, ""] for el in expressions] 42 | keys = {k[0]: format(handle.get(k[0]), k[1]) for k in expr_format} 43 | path = str(template).format(**keys) 44 | return path 45 | -------------------------------------------------------------------------------- /src/anemoi/inference/grib/templates/builtin.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | import base64 11 | import logging 12 | import os 13 | import zlib 14 | from typing import Any 15 | 16 | import earthkit.data as ekd 17 | 18 | from . import IndexTemplateProvider 19 | from . import template_provider_registry 20 | from .manager import TemplateManager 21 | 22 | LOG = logging.getLogger(__name__) 23 | # Recipe for producing the base64/zlib-encoded GRIB strings that load_template decodes: 24 | # 1 - Get a GRIB with mars: retrieve,param=tp,levtype=sfc,type=fc,step=6,target=data.grib,grid=0.25/0.25 25 | # 2 - grib_set -s edition=2,packingType=grid_ccsds data.grib data.grib2 26 | # 3 - grib_set -d 0 data.grib2 out.grib 27 | # 4 - python -c 'import base64, sys, zlib;print(base64.b64encode(zlib.compress(open(sys.argv[1], "rb").read())))' out.grib 28 | 29 | # (the example above yields a tp field in GRIB2 on a 0.25/0.25 grid) 30 | 31 | 32 | @template_provider_registry.register("builtin") 33 | class BuiltinTemplates(IndexTemplateProvider): 34 | """Template provider whose index defaults to ``builtin.yaml`` next to this module; templates are base64/zlib-encoded GRIB.""" 35 | 36 | def __init__(self, manager: TemplateManager, index_path: str | None = None) -> None: 37 | if index_path is None: 38 | index_path = os.path.join(os.path.dirname(__file__), "builtin.yaml") 39 | 40 | super().__init__(manager, index_path) 41 | # lookup is not used by this provider: the template is decoded as-is 42 | def load_template(self, grib: str, lookup: dict[str, Any]) -> ekd.Field | None: 43 | template = zlib.decompress(base64.b64decode(grib)) 44 | return ekd.from_source("memory", template)[0] 45 | -------------------------------------------------------------------------------- /src/anemoi/inference/clusters/__init__.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025- ECMWF. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5 | # In applying this licence, ECMWF does not waive the privileges and immunities 6 | # granted to it by virtue of its status as an intergovernmental organisation 7 | # nor does it submit to any jurisdiction. 8 | # 9 | 10 | from typing import Any 11 | 12 | from anemoi.utils.registry import Registry 13 | 14 | from .client import ComputeClientFactory 15 | from .spawner import ComputeSpawner 16 | 17 | cluster_registry: Registry[ComputeClientFactory | ComputeSpawner] = Registry(__name__) 18 | 19 | 20 | def create_cluster(config: dict[str, Any] | str, *args, **kwargs) -> ComputeClientFactory | ComputeSpawner: 21 | """Find and return the appropriate cluster for the current environment. 22 | 23 | Parameters 24 | ---------- 25 | config : dict 26 | Configuration for the cluster. 27 | Can be string or dict. 28 | args : Any 29 | Additional positional arguments. 30 | kwargs : Any 31 | Additional keyword arguments. 32 | 33 | Returns 34 | ------- 35 | Cluster 36 | The created cluster instance. 37 | """ 38 | if config: 39 | return cluster_registry.from_config(config, *args, **kwargs) 40 | 41 | for cluster in cluster_registry.factories: 42 | cluster_cls = cluster_registry.lookup(cluster) 43 | assert cluster_cls is not None 44 | 45 | if cluster_cls.used(): 46 | return cluster_cls(*args, **kwargs) 47 | 48 | raise RuntimeError( 49 | f"No suitable cluster found for the current environment,\nDiscovered implementations were {cluster_registry.registered}." 50 | ) 51 | -------------------------------------------------------------------------------- /docs/inference/configs/introduction.rst: -------------------------------------------------------------------------------- 1 | .. _config_introduction: 2 | 3 | ######### 4 | Configs 5 | ######### 6 | 7 | This document provides an overview of the configuration to provide to 8 | the :ref:`anemoi-inference run ` command line tool. 9 | 10 | The configuration file is a YAML file that specifies various options. 
It 11 | is extended by `OmegaConf `_ such 12 | that `interpolations 13 | `_ 14 | can be used. It is composed of :ref:`top level ` options 15 | which are usually simple values such as strings, numbers or booleans. The 16 | configuration also provides ways to specify which internal classes to use 17 | for the :ref:`inputs ` and :ref:`outputs `, and how to 18 | configure them. 19 | 20 | In that case, the general format is shown below. The first entry 21 | (``mars`` or ``grib`` in the examples below) corresponds to the 22 | underlying Python class that will be used to process the input, output 23 | or any other polymorphic behaviour, followed by arguments specific to 24 | that class. 25 | 26 | .. literalinclude:: yaml/introduction_1.yaml 27 | :language: yaml 28 | 29 | or: 30 | 31 | .. literalinclude:: yaml/introduction_2.yaml 32 | :language: yaml 33 | 34 | If the underlying class does not require any arguments, or you wish to 35 | use the default parameters, then the configuration can be simplified as: 36 | 37 | .. literalinclude:: yaml/introduction_3.yaml 38 | :language: yaml 39 | 40 | or if it expects a single argument, it can be simplified as: 41 | 42 | .. literalinclude:: yaml/introduction_4.yaml 43 | :language: yaml 44 | 45 | .. toctree:: 46 | :maxdepth: 1 47 | :caption: Configurations 48 | 49 | top-level 50 | inputs 51 | outputs 52 | processors 53 | forcings 54 | grib-input 55 | grib-output 56 | -------------------------------------------------------------------------------- /.github/workflows/pr-label-conventional-commits.yml: -------------------------------------------------------------------------------- 1 | # This workflow assigns labels to a pull request based on the Conventional Commits format. 2 | # This is necessary for release-please to work properly.
3 | name: "[PR] Label Conventional Commits" 4 | 5 | on: 6 | pull_request: 7 | branches: [main] 8 | types: 9 | [opened, reopened, labeled, unlabeled] 10 | 11 | permissions: 12 | pull-requests: write 13 | 14 | jobs: 15 | assign-labels: 16 | runs-on: ubuntu-latest 17 | name: Assign labels in pull request 18 | if: github.event.pull_request.merged == false 19 | steps: 20 | - uses: actions/checkout@v3 21 | - name: Assign labels from Conventional Commits 22 | id: action-assign-labels 23 | uses: mauroalderete/action-assign-labels@v1 24 | with: 25 | pull-request-number: ${{ github.event.pull_request.number }} 26 | github-token: ${{ secrets.GITHUB_TOKEN }} 27 | conventional-commits: | 28 | conventional-commits: 29 | - type: 'fix' 30 | nouns: ['FIX', 'Fix', 'fix', 'FIXED', 'Fixed', 'fixed'] 31 | labels: ['bug'] 32 | - type: 'feature' 33 | nouns: ['FEATURE', 'Feature', 'feature', 'FEAT', 'Feat', 'feat'] 34 | labels: ['enhancement'] 35 | - type: 'breaking_change' 36 | nouns: ['BREAKING CHANGE', 'BREAKING', 'MAJOR'] 37 | labels: ['breaking change'] 38 | - type: 'documentation' 39 | nouns: ['doc','docs','docu','document','documentation'] 40 | labels: ['documentation'] 41 | - type: 'build' 42 | nouns: ['build','rebuild','ci'] 43 | labels: ['CI/CD'] 44 | - type: 'config' 45 | nouns: ['config', 'conf', 'configuration'] 46 | labels: ['config'] 47 | maintain-labels-not-matched: true 48 | apply-changes: true 49 | -------------------------------------------------------------------------------- /src/anemoi/inference/clusters/spawner.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025- ECMWF. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
5 | # In applying this licence, ECMWF does not waive the privileges and immunities 6 | # granted to it by virtue of its status as an intergovernmental organisation 7 | # nor does it submit to any jurisdiction. 8 | # 9 | 10 | from abc import ABC 11 | from abc import abstractmethod 12 | from typing import TYPE_CHECKING 13 | from typing import Callable 14 | 15 | if TYPE_CHECKING: 16 | from anemoi.inference.clusters.client import ComputeClientFactory 17 | from anemoi.inference.config import Configuration 18 | 19 | SPAWN_FUNCTION = Callable[["Configuration", "ComputeClientFactory"], None] 20 | 21 | 22 | class ComputeSpawner(ABC): 23 | """Abstract base class for cluster operations for parallel execution.""" 24 | 25 | @classmethod 26 | @abstractmethod 27 | def used(cls) -> bool: 28 | """Check if this client is valid in the current environment.""" 29 | raise NotImplementedError 30 | 31 | @abstractmethod 32 | def spawn(self, fn: SPAWN_FUNCTION, config: "Configuration") -> None: 33 | """Spawn processes for parallel execution. 34 | 35 | Parameters 36 | ---------- 37 | fn : SPAWN_FUNCTION 38 | The function to run in each process. 39 | Expects to receive the configuration and compute client factory as arguments. 40 | config : Configuration 41 | The configuration object for the runner. 42 | """ 43 | raise NotImplementedError 44 | 45 | @abstractmethod 46 | def teardown(self) -> None: 47 | """Tear down the cluster environment.""" 48 | raise NotImplementedError 49 | 50 | def __enter__(self): 51 | return self 52 | 53 | def __exit__(self, exc_type, exc_value, traceback): 54 | self.teardown() 55 | -------------------------------------------------------------------------------- /src/anemoi/inference/commands/sanitise.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 
2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | 11 | import logging 12 | from argparse import ArgumentParser 13 | from argparse import Namespace 14 | from copy import deepcopy 15 | 16 | from . import Command 17 | 18 | LOG = logging.getLogger(__name__) 19 | 20 | 21 | class SanitiseCmd(Command): 22 | """Sanitise a checkpoint file.""" 23 | 24 | def add_arguments(self, command_parser: ArgumentParser) -> None: 25 | """Add arguments to the command parser. 26 | 27 | Parameters 28 | ---------- 29 | command_parser : ArgumentParser 30 | The argument parser to which the arguments will be added. 31 | """ 32 | command_parser.add_argument("path", help="Path to the checkpoint.") 33 | 34 | def run(self, args: Namespace) -> None: 35 | """Run the sanitise command. 36 | 37 | Parameters 38 | ---------- 39 | args : Namespace 40 | The arguments passed to the command. 
41 | """ 42 | from anemoi.utils.checkpoints import load_metadata 43 | from anemoi.utils.checkpoints import replace_metadata 44 | from anemoi.utils.sanitise import sanitise 45 | 46 | original_metadata, supporting_arrays = load_metadata(args.path, supporting_arrays=True) 47 | metadata = deepcopy(original_metadata) 48 | metadata = sanitise(metadata) 49 | 50 | if metadata != original_metadata: 51 | LOG.info("Patching metadata") 52 | assert "sources" in metadata["dataset"] 53 | replace_metadata(args.path, metadata, supporting_arrays) 54 | else: 55 | LOG.info("Metadata is already sanitised") 56 | 57 | 58 | command = SanitiseCmd 59 | -------------------------------------------------------------------------------- /.github/workflows/downstream-ci-hpc.yml: -------------------------------------------------------------------------------- 1 | # This workflow triggers tests on dependent packages. 2 | # The dependency tree itself is defined in ecmwf/downstream-ci/ 3 | name: Test downstream dependent packages 4 | 5 | on: 6 | # Trigger the workflow on push to main or develop, except tag creation 7 | push: 8 | branches: 9 | - 'main' 10 | - 'develop' 11 | tags-ignore: 12 | - '**' 13 | paths-ignore: 14 | - "docs/**" 15 | - "CHANGELOG.md" 16 | - "README.md" 17 | 18 | # Trigger the workflow on pull request 19 | pull_request: 20 | paths-ignore: 21 | - "docs/**" 22 | - "CHANGELOG.md" 23 | - "README.md" 24 | 25 | # Trigger the workflow manually 26 | workflow_dispatch: ~ 27 | 28 | # Trigger after public PR approved for CI 29 | pull_request_target: 30 | types: [labeled] 31 | paths-ignore: 32 | - "docs/**" 33 | - "CHANGELOG.md" 34 | - "README.md" 35 | 36 | jobs: 37 | # Run CI including downstream packages on self-hosted runners 38 | downstream-ci: 39 | name: downstream-ci 40 | if: ${{ !github.event.pull_request.head.repo.fork && github.event.action != 'labeled' || github.event.label.name == 'approved-for-ci' }} 41 | uses: ecmwf/downstream-ci/.github/workflows/downstream-ci.yml@main 42 | 
with: 43 | anemoi-inference: ecmwf/anemoi-inference@${{ github.event.pull_request.head.sha || github.sha }} 44 | codecov_upload: true 45 | # Only run on fedora 46 | skip_matrix_jobs: | 47 | gnu@debian-11 48 | gnu@rocky-8.6 49 | clang@rocky-8.6 50 | gnu@ubuntu-22.04 51 | secrets: inherit 52 | 53 | # # Build downstream packages on HPC 54 | # downstream-ci-hpc: 55 | # name: downstream-ci-hpc 56 | # if: ${{ !github.event.pull_request.head.repo.fork && github.event.action != 'labeled' || github.event.label.name == 'approved-for-ci' }} 57 | # uses: ecmwf/downstream-ci/.github/workflows/downstream-ci-hpc.yml@main 58 | # with: 59 | # anemoi-inference: ecmwf/anemoi-inference@${{ github.event.pull_request.head.sha || github.sha }} 60 | # secrets: inherit 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # anemoi-inference 2 | 3 |

4 | 5 | Maturity Level 6 | 7 | 8 | Licence 9 | 10 | 11 | Latest Release 12 | 13 |

14 | 15 | > \[!IMPORTANT\] 16 | > This software is **Incubating** and subject to ECMWF's guidelines on [Software Maturity](https://github.com/ecmwf/codex/raw/refs/heads/main/Project%20Maturity). 17 | 18 | 19 | 20 | ## Documentation 21 | 22 | The documentation can be found at https://anemoi-inference.readthedocs.io/. 23 | 24 | ## Contributing 25 | 26 | You can find information about contributing to Anemoi at our [Contribution page](https://anemoi.readthedocs.io/en/latest/contributing/contributing.html). 27 | 28 | ## Install 29 | 30 | Install via `pip` with: 31 | 32 | ``` 33 | $ pip install anemoi-inference 34 | ``` 35 | 36 | ## License 37 | 38 | ``` 39 | Copyright 2024-2025, Anemoi Contributors. 40 | 41 | Licensed under the Apache License, Version 2.0 (the "License"); 42 | you may not use this file except in compliance with the License. 43 | You may obtain a copy of the License at 44 | 45 | http://www.apache.org/licenses/LICENSE-2.0 46 | 47 | Unless required by applicable law or agreed to in writing, software 48 | distributed under the License is distributed on an "AS IS" BASIS, 49 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 50 | See the License for the specific language governing permissions and 51 | limitations under the License. 52 | 53 | In applying this licence, ECMWF does not waive the privileges and immunities 54 | granted to it by virtue of its status as an intergovernmental organisation 55 | nor does it submit to any jurisdiction. 
56 | ``` 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | __pypackages__/ 8 | 9 | # Distribution / packaging 10 | build/ 11 | develop-eggs/ 12 | dist/ 13 | downloads/ 14 | eggs/ 15 | .eggs/ 16 | lib/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | share/python-wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # Testing and coverage 29 | htmlcov/ 30 | .tox/ 31 | .nox/ 32 | .coverage 33 | .coverage.* 34 | .cache 35 | nosetests.xml 36 | coverage.xml 37 | *.cover 38 | *.py,cover 39 | .hypothesis/ 40 | .pytest_cache/ 41 | cover/ 42 | 43 | # Documentation 44 | docs/_build/ 45 | docs/_api/ 46 | /site 47 | *.mo 48 | *.pot 49 | 50 | # Environments 51 | .env 52 | .envrc 53 | .venv 54 | env/ 55 | venv/ 56 | ENV/ 57 | env.bak/ 58 | venv.bak/ 59 | 60 | # IDEs 61 | .idea/ 62 | .spyderproject 63 | .spyproject 64 | .ropeproject 65 | .vscode/ 66 | *.code-workspace 67 | *.sublime-project 68 | *.sublime-workspace 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # Type checking 74 | .mypy_cache/ 75 | .dmypy.json 76 | dmypy.json 77 | .pyre/ 78 | .pytype/ 79 | 80 | # Data files 81 | *.grib 82 | *.grib1 83 | *.grib2 84 | *.onnx 85 | *.ckpt 86 | *.npy 87 | *.npz 88 | *.zarr/ 89 | *.nc 90 | *.h5 91 | *.hdf5 92 | *.pkl 93 | *.parquet 94 | *.csv 95 | *.xlsx 96 | *.xls 97 | *.json 98 | *.txt 99 | *.zip 100 | *.db 101 | *.tgz 102 | 103 | # ML artifacts 104 | wandb/ 105 | mlruns/ 106 | lightning_logs/ 107 | *.ckpt 108 | *.pt 109 | *.pth 110 | runs/ 111 | checkpoints/ 112 | 113 | # Temporary and system files 114 | *.swp 115 | *.download 116 | *.out 117 | *.sync 118 | *.dot 119 | *.tmp 120 | *.log 121 | *.log.* 122 | .DS_Store 123 | ~* 124 | tmp/ 125 | temp/ 126 | logs/ 127 | _dev/ 128 | _api/ 129 | ./outputs 130 | *tmp_data/ 
131 | 132 | # Project specific 133 | ? 134 | ?.* 135 | foo 136 | bar 137 | ~$images.pptx 138 | test.py 139 | test.ipynb 140 | _version.py 141 | *.to_upload 142 | tempCodeRunnerFile.python 143 | Untitled-*.py 144 | trace.txt 145 | ?/ 146 | *.prof 147 | prof/ 148 | *.gz 149 | -------------------------------------------------------------------------------- /tests/unit/inputs/test_inference.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | import logging 11 | 12 | from anemoi.inference.config.run import RunConfiguration 13 | from anemoi.inference.runners import create_runner 14 | from anemoi.inference.testing import fake_checkpoints 15 | from anemoi.inference.testing import files_for_tests 16 | 17 | 18 | @fake_checkpoints 19 | def test_inference_simple() -> None: 20 | """Test the inference process using a fake checkpoint. 21 | 22 | This function loads a configuration, creates a runner, and runs the inference 23 | process to ensure that the system works as expected with the provided configuration. 24 | """ 25 | config = RunConfiguration.load( 26 | files_for_tests("unit/configs/simple.yaml"), 27 | overrides=dict(runner="testing", device="cpu", input="dummy", trace_path="trace.log"), 28 | ) 29 | runner = create_runner(config) 30 | runner.execute() 31 | 32 | 33 | @fake_checkpoints 34 | def test_inference_mwd() -> None: 35 | """Test the inference process using a fake checkpoint. 
36 | 37 | This function loads a configuration, creates a runner, and runs the inference 38 | process to ensure that the system works as expected with the provided configuration. 39 | """ 40 | config = RunConfiguration.load( 41 | files_for_tests("unit/configs/mwd.yaml"), 42 | overrides=dict(runner="testing", device="cpu", input="dummy"), 43 | ) 44 | runner = create_runner(config) 45 | runner.execute() 46 | 47 | 48 | if __name__ == "__main__": 49 | logging.basicConfig(level=logging.INFO) 50 | for name, obj in list(globals().items()): 51 | if name.startswith("test_") and callable(obj): 52 | print(f"Running {name}...") 53 | obj() 54 | -------------------------------------------------------------------------------- /tests/unit/inputs/test_request_patching.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 
9 | 10 | import pytest 11 | 12 | from anemoi.inference.config.run import RunConfiguration 13 | from anemoi.inference.inputs.dummy import DummyInput 14 | from anemoi.inference.processor import Processor 15 | from anemoi.inference.runners import create_runner 16 | from anemoi.inference.testing import fake_checkpoints 17 | from anemoi.inference.testing import files_for_tests 18 | 19 | 20 | class DummyProcessor(Processor): 21 | def __init__(self, context, mark: str): 22 | super().__init__(context) 23 | self.mark = mark 24 | 25 | def process(self, data: dict) -> dict: # type: ignore 26 | """A simple processor that returns the input data unchanged.""" 27 | return data 28 | 29 | def patch_data_request(self, data: dict) -> dict: # type: ignore 30 | """A simple patch method that returns the input data unchanged.""" 31 | data[self.mark] = True 32 | return data 33 | 34 | 35 | @pytest.fixture 36 | @fake_checkpoints 37 | def runner() -> None: 38 | config = RunConfiguration.load( 39 | files_for_tests("unit/configs/simple.yaml"), 40 | overrides=dict(runner="testing", device="cpu", input="dummy", trace_path="trace.log"), 41 | ) 42 | return create_runner(config) 43 | 44 | 45 | @fake_checkpoints 46 | def test_patched_by_input_and_context(runner): 47 | runner.pre_processors.append(DummyProcessor(runner, "context")) 48 | 49 | input = DummyInput(runner) 50 | input.pre_processors.append(DummyProcessor(runner, "input")) 51 | 52 | empty_request = {} 53 | patched_request = input.patch_data_request(empty_request) 54 | 55 | assert patched_request["context"] is True 56 | assert patched_request["input"] is True 57 | -------------------------------------------------------------------------------- /tests/unit/test_metadata.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025 Anemoi contributors. 
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from copy import deepcopy

import pytest

from anemoi.inference.metadata import Metadata
from anemoi.inference.testing.mock_checkpoint import mock_load_metadata


@pytest.mark.parametrize(
    "initial, patch, expected",
    [
        (
            {"config": {"dataloader": {"dataset": "abc"}}},
            {"config": {"dataloader": {"something": {"else": "123"}}}},
            {"dataset": "abc", "something": {"else": "123"}},
        ),
        (
            {"config": {"dataloader": [{"dataset": "abc"}, {"dataset": "xyz"}]}},
            {"config": {"dataloader": {"cutout": [{"dataset": "123"}, {"dataset": "456"}]}}},
            {"cutout": [{"dataset": "123"}, {"dataset": "456"}]},
        ),
        (
            {"config": {"dataloader": "abc"}},
            {"config": {"dataloader": "xyz"}},
            "xyz",
        ),
    ],
)
def test_patch(initial, patch, expected):
    """Patching ``config.dataloader`` yields the expected merged/replaced value."""
    metadata = Metadata(initial)
    assert metadata._metadata == initial

    metadata.patch(patch)
    assert metadata._metadata["config"]["dataloader"] == expected


def test_constant_fields_patch():
    """Patching ``dataset.constant_fields`` sets that key and leaves the rest alone."""
    model_metadata = mock_load_metadata("unit/checkpoints/atmos.json", supporting_arrays=False)
    metadata = Metadata(model_metadata)
    # Snapshot taken before patching. Comparing against ``model_metadata``
    # itself would be vacuous if Metadata keeps a reference to that dict and
    # patch() updates it in place.
    expected = deepcopy(metadata._metadata)

    fields = ["z", "sdor", "slor", "lsm"]
    metadata.patch({"dataset": {"constant_fields": fields}})
    assert metadata._metadata["dataset"]["constant_fields"] == fields

    # check that the rest of the metadata is still the same after patching
    metadata._metadata["dataset"].pop("constant_fields")
    expected["dataset"].pop("constant_fields", None)
    assert expected == metadata._metadata

-------------------------------------------------------------------------------- /src/anemoi/inference/commands/run.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | from __future__ import annotations 11 | 12 | import logging 13 | from argparse import ArgumentParser 14 | from argparse import Namespace 15 | 16 | from ..config.run import RunConfiguration 17 | from ..runners import create_runner 18 | from . import Command 19 | 20 | LOG = logging.getLogger(__name__) 21 | 22 | 23 | class RunCmd(Command): 24 | """Run inference from a config yaml file.""" 25 | 26 | def add_arguments(self, command_parser: ArgumentParser) -> None: 27 | """Add arguments to the command parser. 28 | 29 | Parameters 30 | ---------- 31 | command_parser : ArgumentParser 32 | The argument parser to which the arguments will be added. 33 | """ 34 | command_parser.add_argument("--defaults", action="append", help="Sources of default values.") 35 | command_parser.add_argument( 36 | "config", 37 | help="Path to config file. Can be omitted to pass config with overrides and defaults.", 38 | ) 39 | command_parser.add_argument("overrides", nargs="*", help="Overrides as key=value") 40 | 41 | def run(self, args: Namespace) -> None: 42 | """Run the inference command. 43 | 44 | Parameters 45 | ---------- 46 | args : Namespace 47 | The arguments passed to the command. 
48 | """ 49 | if "=" in args.config: 50 | args.overrides.append(args.config) 51 | args.config = {} 52 | 53 | config = RunConfiguration.load( 54 | args.config, 55 | args.overrides, 56 | defaults=args.defaults, 57 | ) 58 | 59 | runner = create_runner(config) 60 | runner.execute() 61 | 62 | 63 | command = RunCmd 64 | -------------------------------------------------------------------------------- /src/anemoi/inference/processor.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025 ECMWF. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # In applying this licence, ECMWF does not waive the privileges and immunities 6 | # granted to it by virtue of its status as an intergovernmental organisation 7 | # nor does it submit to any jurisdiction. 8 | # 9 | from abc import ABC 10 | from abc import abstractmethod 11 | from typing import TYPE_CHECKING 12 | 13 | from anemoi.inference.types import DataRequest 14 | from anemoi.inference.types import State 15 | 16 | if TYPE_CHECKING: 17 | from anemoi.inference.context import Context 18 | 19 | 20 | class Processor(ABC): 21 | """Abstract base class for processors. 22 | 23 | Parameters 24 | ---------- 25 | context : Context 26 | The context in which the processor operates. 27 | """ 28 | 29 | def __init__(self, context: "Context") -> None: 30 | self.context = context 31 | self.checkpoint = context.checkpoint 32 | 33 | def __repr__(self) -> str: 34 | """Return a string representation of the processor. 35 | 36 | Returns 37 | ------- 38 | str 39 | The class name of the processor. 40 | """ 41 | return f"{self.__class__.__name__}()" 42 | 43 | @abstractmethod 44 | def process(self, state: State) -> State: 45 | """Process the given state. 46 | 47 | Parameters 48 | ---------- 49 | state : State 50 | The state to be processed. 
51 | 52 | Returns 53 | ------- 54 | State 55 | The processed state. 56 | """ 57 | pass 58 | 59 | def patch_data_request(self, data_request: DataRequest) -> DataRequest: 60 | """Override if a processor needs to patch the data request (e.g. mars or cds). 61 | 62 | Parameters 63 | ---------- 64 | data_request : DataRequest 65 | The data request to be patched. 66 | 67 | Returns 68 | ------- 69 | DataRequest 70 | The patched data request. 71 | """ 72 | return data_request 73 | -------------------------------------------------------------------------------- /src/anemoi/inference/pre_processors/no_missing_values.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024-2025 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | 11 | import logging 12 | from typing import Any 13 | 14 | import tqdm 15 | from anemoi.transform.fields import new_field_from_numpy 16 | from anemoi.transform.fields import new_fieldlist_from_list 17 | 18 | from anemoi.inference.context import Context 19 | from anemoi.inference.types import State 20 | 21 | from ..processor import Processor 22 | from . import pre_processor_registry 23 | 24 | LOG = logging.getLogger(__name__) 25 | 26 | 27 | @pre_processor_registry.register("no_missing_values") 28 | class NoMissingValues(Processor): 29 | """Replace NaNs with mean.""" 30 | 31 | def __init__(self, context: Context, **kwargs: Any) -> None: 32 | """Initialize the NoMissingValues processor. 33 | 34 | Parameters 35 | ---------- 36 | context : Context 37 | The context in which the processor operates. 38 | kwargs : Any 39 | Additional keyword arguments. 
40 | """ 41 | super().__init__(context) 42 | 43 | def process(self, state: State) -> State: 44 | """Process the fields to replace NaNs with the mean value. 45 | 46 | Parameters 47 | ---------- 48 | state : State 49 | The state to process. 50 | 51 | Returns 52 | ------- 53 | list 54 | List of processed state with NaNs replaced by the mean value. 55 | """ 56 | result = [] 57 | fields = state["fields"] 58 | for field in tqdm.tqdm(fields): 59 | import numpy as np 60 | 61 | data = field.to_numpy() 62 | 63 | mean_value = np.nanmean(data) 64 | 65 | data = np.where(np.isnan(data), mean_value, data) 66 | result.append(new_field_from_numpy(data, template=field)) 67 | 68 | state["fields"] = new_fieldlist_from_list(result) 69 | return state 70 | -------------------------------------------------------------------------------- /src/anemoi/inference/precisions.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | 11 | """List of precisions supported by the inference runner.""" 12 | 13 | 14 | from functools import cached_property 15 | 16 | 17 | class LazyDict: 18 | """A dictionary that lazily loads its values. 19 | So we don't import torch at the top level, which can be slow. 
20 | """ 21 | 22 | def get(self, key, default=None): 23 | return self._mapping.get(key, default) 24 | 25 | def keys(self): 26 | return self._mapping.keys() 27 | 28 | def values(self): 29 | return self._mapping.values() 30 | 31 | def items(self): 32 | return self._mapping.items() 33 | 34 | def __getitem__(self, key): 35 | return self._mapping[key] 36 | 37 | @cached_property 38 | def _mapping(self): 39 | import torch 40 | 41 | return { 42 | "16-mixed": torch.float16, 43 | "16": torch.float16, 44 | "32": torch.float32, 45 | "b16-mixed": torch.bfloat16, 46 | "b16": torch.bfloat16, 47 | "bf16-mixed": torch.bfloat16, 48 | "bf16": torch.bfloat16, 49 | "bfloat16": torch.bfloat16, 50 | "f16": torch.float16, 51 | "f32": torch.float32, 52 | "float16": torch.float16, 53 | "float32": torch.float32, 54 | } 55 | 56 | 57 | PRECISIONS = LazyDict() 58 | 59 | if __name__ == "__main__": 60 | # This is just to make sure that the module can be imported without errors. 61 | # It will not be executed when the module is imported, only when run as a script. 
62 | 63 | print("Available precisions:", list(PRECISIONS.keys())) 64 | print("Available precisions:", list(PRECISIONS.values())) 65 | print("Available precisions:", list(PRECISIONS.items())) 66 | 67 | print("Torch float16:", PRECISIONS["16"]) 68 | print("Torch bfloat16:", PRECISIONS["b16"]) 69 | print("Torch float32:", PRECISIONS["32"]) 70 | -------------------------------------------------------------------------------- /tests/integration/meteoswiss-sgm-cosmo/config.yaml: -------------------------------------------------------------------------------- 1 | - name: grib-in-grib-out 2 | markers: cosmo 3 | input: 4 | - input-local-8km.grib 5 | - input-global-o48.grib 6 | - templates.yaml 7 | - templates/co2-coarse-typeOfLevel-isobaricInhPa.grib 8 | - templates/co2-coarse-typeOfLevel-heightAboveGround.grib 9 | - templates/co2-coarse-typeOfLevel-surface.grib 10 | - templates/co2-coarse-shortName-TOT_PREC.grib 11 | output: output.grib 12 | checks: 13 | - check_grib: 14 | check_nans: true 15 | reference_date: 20251014T0000 16 | - check_grib_cutout: 17 | reference_grib: ${input:0} 18 | inference_config: 19 | checkpoint: ${checkpoint:} 20 | env: 21 | ECCODES_DEFINITION_PATH: ${sys.prefix:}/share/eccodes-cosmo-resources/definitions 22 | post_processors: 23 | - extract_from_state: lam_0 24 | write_initial_state: false 25 | input: 26 | cutout: 27 | - lam_0: 28 | grib: 29 | path: ${input:0} 30 | namer: 31 | rules: 32 | - - shortName: T 33 | - t_{level} 34 | - - shortName: U 35 | - u_{level} 36 | - - shortName: V 37 | - v_{level} 38 | - - shortName: W 39 | - w_{level} 40 | - - shortName: QV 41 | - q_{level} 42 | - - shortName: FI 43 | - z_{level} 44 | - - shortName: PMSL 45 | - msl 46 | - - shortName: FIS 47 | - z 48 | - - shortName: PS 49 | - sp 50 | - - shortName: T_2M 51 | - 2t 52 | - - shortName: TD_2M 53 | - 2d 54 | - - shortName: T_G 55 | - skt 56 | - - shortName: U_10M 57 | - 10u 58 | - - shortName: V_10M 59 | - 10v 60 | - - shortName: FR_LAND 61 | - lsm 62 | - - 
shortName: TOT_PREC 63 | - tp 64 | - global: 65 | grib: ${input:1} 66 | output: 67 | grib: 68 | path: ${output:} 69 | templates: 70 | - samples: 71 | index_path: ${input:2} 72 | -------------------------------------------------------------------------------- /src/anemoi/inference/commands/validate.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | from __future__ import annotations 11 | 12 | import logging 13 | from argparse import ArgumentParser 14 | from argparse import Namespace 15 | 16 | from ..checkpoint import Checkpoint 17 | from . import Command 18 | 19 | LOG = logging.getLogger(__name__) 20 | 21 | 22 | class ValidateCmd(Command): 23 | """Validate the virtual environment against a checkpoint file.""" 24 | 25 | def add_arguments(self, command_parser: ArgumentParser) -> None: 26 | """Add arguments to the command parser. 27 | 28 | Parameters 29 | ---------- 30 | command_parser : ArgumentParser 31 | The argument parser to which the arguments will be added. 32 | """ 33 | command_parser.add_argument( 34 | "--all-packages", action="store_true", help="Check all packages in the environment." 35 | ) 36 | 37 | command_parser.add_argument( 38 | "--on-difference", choices=["warn", "error", "ignore"], default="warn", help="What to do on difference." 
39 | ) 40 | 41 | command_parser.add_argument("--exempt-packages", nargs="*", help="List of packages to exempt from the check.") 42 | 43 | command_parser.add_argument("checkpoint", help="Path to checkpoint file.") 44 | 45 | def run(self, args: Namespace) -> bool: 46 | """Run the validation command. 47 | 48 | Parameters 49 | ---------- 50 | args : Namespace 51 | The arguments passed to the command. 52 | 53 | Returns 54 | ------- 55 | bool 56 | True if the environment is valid, False otherwise. 57 | """ 58 | checkpoint = Checkpoint(args.checkpoint) 59 | return checkpoint.validate_environment( 60 | all_packages=args.all_packages, 61 | on_difference=args.on_difference, 62 | exempt_packages=args.exempt_packages, 63 | ) 64 | 65 | 66 | command = ValidateCmd 67 | -------------------------------------------------------------------------------- /.github/workflows/pr-label-ats.yml: -------------------------------------------------------------------------------- 1 | # This workflow checks that the appropriate ATS labels are applied to a pull request: 2 | # - "ATS Approval Not Needed": pass 3 | # - "ATS Approved": pass 4 | # - "ATS Approval Needed": fail 5 | # - Missing ATS label: fail 6 | 7 | name: "[PR] ATS labels" 8 | 9 | on: 10 | pull_request: 11 | types: 12 | - opened 13 | - edited 14 | - reopened 15 | - labeled 16 | - unlabeled 17 | - synchronize 18 | 19 | permissions: 20 | pull-requests: read 21 | 22 | jobs: 23 | check: 24 | runs-on: ubuntu-latest 25 | steps: 26 | - name: Get PR details 27 | uses: actions/github-script@v7 28 | id: check-pr 29 | with: 30 | script: | 31 | const pr = await github.rest.pulls.get({ 32 | owner: context.repo.owner, 33 | repo: context.repo.repo, 34 | pull_number: context.payload.pull_request.number, 35 | }); 36 | const labels = pr.data.labels.map(label => label.name); 37 | core.setOutput('labels', JSON.stringify(labels)); 38 | core.setOutput('author', context.payload.pull_request.user.login); 39 | 40 | - name: Evaluate ATS labels 41 | run: | 42 | 
AUTHOR='${{ steps.check-pr.outputs.author }}' 43 | if [ "$AUTHOR" == "DeployDuck" ] || [ "$AUTHOR" == "pre-commit-ci[bot]" ]; then 44 | echo "Bot PR, skipping." 45 | exit 0 46 | fi 47 | LABELS=$(echo '${{ steps.check-pr.outputs.labels }}' | jq -r '.[]') 48 | echo "Labels found:" 49 | echo -e "$LABELS\n" 50 | echo "Result:" 51 | if echo "$LABELS" | grep -qi "ATS approval not needed"; then 52 | echo "ATS approval not needed. Passing." 53 | EXIT_CODE=0 54 | elif echo "$LABELS" | grep -qi "ATS approved"; then 55 | echo "ATS Approved. Passing." 56 | EXIT_CODE=0 57 | elif echo "$LABELS" | grep -qi "ATS approval needed"; then 58 | echo "ATS Approval Needed. Failing." 59 | EXIT_CODE=1 60 | else 61 | echo "No ATS approval labels found. Please assign the appropriate ATS label. Failing." 62 | EXIT_CODE=1 63 | fi 64 | echo -e "\nFor more information on ATS labels, see:" 65 | echo "https://anemoi.readthedocs.io/en/latest/contributing/guidelines.html#labelling-guidelines" 66 | exit $EXIT_CODE 67 | -------------------------------------------------------------------------------- /src/anemoi/inference/grib/templates/input.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | import logging 11 | from typing import Any 12 | 13 | import earthkit.data as ekd 14 | 15 | from anemoi.inference.types import State 16 | 17 | from . import TemplateProvider 18 | from . 
import template_provider_registry 19 | from .manager import TemplateManager 20 | 21 | LOG = logging.getLogger(__name__) 22 | 23 | 24 | @template_provider_registry.register("input") 25 | class InputTemplates(TemplateProvider): 26 | """Use input fields as the output GRIB template.""" 27 | 28 | def __init__(self, manager: TemplateManager, **fallback: dict[str, str]) -> None: 29 | """Initialize the template provider. 30 | 31 | Parameters 32 | ---------- 33 | manager : TemplateManager 34 | The manager for the template provider. 35 | **fallback : dict[str, str] 36 | A mapping of output to input variable names to use as templates from the input, 37 | used as fallback when the output variable is not present in the input state (e.g., for diagnostic variables). 38 | """ 39 | super().__init__(manager) 40 | 41 | self.fallback = fallback 42 | 43 | def __repr__(self): 44 | info = f"{self.__class__.__name__}{{fallback}}" 45 | if fallback := ", ".join(f"{k}:{v}" for k, v in self.fallback.items()): 46 | fallback = f"(fallback {fallback})" 47 | return info.format(fallback=fallback) 48 | 49 | def template( 50 | self, 51 | variable: str, 52 | lookup: dict[str, Any], 53 | *, 54 | state: State, 55 | **kwargs, 56 | ) -> ekd.Field | None: 57 | if template := state.get("_grib_templates_for_output", {}).get(variable): 58 | return template 59 | 60 | if fallback_variable := self.fallback.get(variable): 61 | if template := state.get("_grib_templates_for_output", {}).get(fallback_variable): 62 | return template 63 | LOG.warning(f"Fallback variable '{fallback_variable}' for output '{variable}' not found in input state.") 64 | 65 | return None 66 | -------------------------------------------------------------------------------- /docs/inference/apis/level1.rst: -------------------------------------------------------------------------------- 1 | .. 
_api_level1: 2 | 3 | #################### 4 | NumPy to NumPy API 5 | #################### 6 | 7 | The simplest way to run an inference from a checkpoint is to provide the 8 | initial state as a dictionary containing NumPy arrays for each input 9 | variable. 10 | 11 | You then create a Runner object and call the `run` method, which will 12 | yield the state at each time step. Below is a simple code example to 13 | illustrate this: 14 | 15 | .. literalinclude:: code/level1_1.py 16 | :language: python 17 | 18 | The field names are the ones that were provided when running the 19 | :ref:`training `, which were the names given 20 | to fields when creating the training :ref:`dataset 21 | `. 22 | 23 | ******** 24 | States 25 | ******** 26 | 27 | A `state` is a Python :py:class:`dictionary` with the following keys: 28 | 29 | - ``date``: a :py:class:`datetime.datetime` object that represents the 30 | date at which the state is valid. 31 | - ``latitudes``: a NumPy array with the list of latitudes that matches 32 | the data values of the fields. 33 | - ``longitudes``: a NumPy array with the corresponding list of 34 | longitudes. It must have the same size as the latitudes array. 35 | - ``fields``: a :py:class:`dictionary` that maps field names to 36 | their data. 37 | 38 | Each field is given as a NumPy array. If the model is 39 | :py:attr:`multi-step 40 | `, it will 41 | need to be initialised with fields from two or more dates, and the values 42 | must be two-dimensional arrays, with the shape ``(number-of-dates, 43 | number-of-grid-points)``; otherwise the values can be a one-dimensional 44 | array. The first dimension is expected to represent each date in 45 | ascending order, and the ``date`` entry of the state must be the last 46 | one. 47 | 48 | As it iterates, the model will produce new states with the same format. 49 | The ``date`` will represent the forecasted date, and the fields will 50 | have the forecasted values as NumPy arrays.
These arrays will have one 51 | dimension (the number of grid points), even if the model is multi-step. 52 | 53 | ************* 54 | Checkpoints 55 | ************* 56 | 57 | Newer versions of :ref:`anemoi-training 58 | ` store the `latitudes` and 59 | `longitudes` used during training in the checkpoint. In that case, the example code 60 | above can be simplified as follows: 61 | 62 | .. literalinclude:: code/level1_2_.py 63 | :language: python 64 | -------------------------------------------------------------------------------- /docs/inference/apis/level3.rst: -------------------------------------------------------------------------------- 1 | .. _api_level3: 2 | 3 | ################## 4 | Command line API 5 | ################## 6 | 7 | You can run the inference from the command line using the 8 | :ref:`anemoi-inference run ` command. 9 | 10 | You must first create a configuration file in YAML format. The simplest 11 | configuration must contain the path to the checkpoint: 12 | 13 | .. literalinclude:: code/level3_1.yaml 14 | :language: yaml 15 | 16 | Then you can run the inference with the following command: 17 | 18 | .. literalinclude:: code/level3_1.sh 19 | :language: bash 20 | 21 | The other entries in the configuration file are optional and will 22 | take their default values if not provided. 23 | 24 | You can also override values by providing them on the command line: 25 | 26 | .. literalinclude:: code/level3_2.sh 27 | :language: bash 28 | 29 | Overrides are parsed as an `OmegaConf 30 | `_ 31 | dotlist, so list items can be accessed with ``list.index`` or 32 | ``list[index]``. 33 | 34 | You can also run entirely from the command line without a config file, 35 | by passing all required options as overrides: 36 | 37 | ..
literalinclude:: code/level3_3.sh 38 | :language: bash 39 | 40 | The configuration below shows how to run the inference from the data 41 | that was used to train the model, by setting the ``dataset`` entry to 42 | ``true``: 43 | 44 | .. literalinclude:: code/level3_2.yaml 45 | :language: yaml 46 | 47 | Below is an example of how to override list entries and append to lists 48 | on the command line by using the dotlist notation. Running inference 49 | with the following command: 50 | 51 | .. literalinclude:: code/level3_4.sh 52 | :language: bash 53 | 54 | together with the configuration file: 55 | 56 | .. literalinclude:: code/level3_3.yaml 57 | :language: yaml 58 | 59 | will override the first entry in the ``input.dataset.cutout`` list with 60 | the dictionary ``{"dataset": "./analysis_20240131_00.zarr"}`` and will 61 | append the dictionary ``{"dataset": "./lbc_20240131_00.zarr"}`` to it. 62 | 63 | The configuration below shows how to run the inference for a 64 | checkpoint that was trained on the ICON grid: 65 | 66 | .. literalinclude:: code/level3_4.yaml 67 | :language: yaml 68 | 69 | See :ref:`run-command` for more details on the configuration file. 70 | 71 | .. warning:: 72 | 73 | This is still a work in progress: the content of the YAML configuration 74 | files may change, and the examples above may not work in the future. 75 | -------------------------------------------------------------------------------- /src/anemoi/inference/post_processors/backward_transform_filter.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024-2025 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | 11 | import logging 12 | from typing import Any 13 | 14 | from anemoi.transform.filters import filter_registry 15 | 16 | from anemoi.inference.context import Context 17 | from anemoi.inference.decorators import main_argument 18 | from anemoi.inference.types import State 19 | 20 | from ..processor import Processor 21 | from . import post_processor_registry 22 | from .earthkit_state import unwrap_state 23 | from .earthkit_state import wrap_state 24 | 25 | LOG = logging.getLogger(__name__) 26 | 27 | 28 | @post_processor_registry.register("backward_transform_filter") 29 | @main_argument("filter") 30 | class BackwardTransformFilter(Processor): 31 | """A processor that applies a backward transform filter to a given state. 32 | 33 | This class uses a specified filter from the filter registry to process 34 | the state by applying a backward transformation. 35 | 36 | Attributes 37 | ---------- 38 | filter : Any 39 | The filter instance used for processing the state. 40 | """ 41 | 42 | def __init__(self, context: Context, filter: str, **kwargs: Any) -> None: 43 | """Initialize the BackwardTransformFilter. 44 | 45 | Parameters 46 | ---------- 47 | context : Context 48 | The context for the filter. 49 | filter : str 50 | The name of the filter to be used. 51 | **kwargs : Any 52 | Additional keyword arguments for the filter. 53 | """ 54 | super().__init__(context) 55 | self.filter: Any = filter_registry.create(filter, **kwargs) 56 | 57 | def process(self, state: State) -> State: 58 | """Process the given state using the backward transform filter. 59 | 60 | Parameters 61 | ---------- 62 | state : State 63 | The state to be processed. 64 | 65 | Returns 66 | ------- 67 | State 68 | The processed state. 
69 | """ 70 | 71 | fields = self.filter.backward(wrap_state(state)) 72 | 73 | return unwrap_state(fields, state, namer=self.context.checkpoint.default_namer()) 74 | -------------------------------------------------------------------------------- /tests/unit/checkpoints/simple.yaml: -------------------------------------------------------------------------------- 1 | config: 2 | data: 3 | diagnostic: 4 | - tcc 5 | - tp 6 | forcing: 7 | - lsm 8 | - z 9 | timestep: 6h 10 | training: 11 | multistep_input: 2 12 | precision: 16-mixed 13 | data_indices: 14 | data: 15 | input: 16 | prognostic: 17 | - 0 18 | - 1 19 | - 2 20 | - 3 21 | diagnostic: 22 | - 6 23 | - 7 24 | forcing: 25 | - 4 26 | - 5 27 | - 8 28 | - 9 29 | full: 30 | - 0 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | - 5 36 | - 8 37 | - 9 38 | output: 39 | full: 40 | - 0 41 | - 1 42 | - 2 43 | - 3 44 | - 6 45 | - 7 46 | model: 47 | input: 48 | forcing: 49 | - 4 50 | - 5 51 | - 6 52 | - 7 53 | full: 54 | - 0 55 | - 1 56 | - 2 57 | - 3 58 | - 4 59 | - 5 60 | - 6 61 | - 7 62 | prognostic: 63 | - 0 64 | - 1 65 | - 2 66 | - 3 67 | output: 68 | full: 69 | - 0 70 | - 1 71 | - 2 72 | - 3 73 | - 4 74 | - 5 75 | prognostic: 76 | - 0 77 | - 1 78 | - 2 79 | - 3 80 | dataset: 81 | data_request: 82 | grid: O96 83 | shape: 84 | - 365 85 | - 10 86 | - 1 87 | - 40320 88 | variables: 89 | - 2t 90 | - 10u 91 | - 10v 92 | - msl 93 | - lsm 94 | - z 95 | - tcc 96 | - tp 97 | - cos_latitude 98 | - insolation 99 | variables_metadata: 100 | 10u: 101 | mars: 102 | levtype: sfc 103 | param: 10u 104 | 10v: 105 | mars: 106 | levtype: sfc 107 | param: 10v 108 | 2t: 109 | mars: 110 | levtype: sfc 111 | param: 2t 112 | cos_latitude: 113 | computed_forcing: true 114 | constant_in_time: true 115 | insolation: 116 | computed_forcing: true 117 | constant_in_time: false 118 | lsm: 119 | constant_in_time: true 120 | mars: 121 | levtype: sfc 122 | param: lsm 123 | msl: 124 | mars: 125 | levtype: sfc 126 | param: msl 127 | tcc: 128 | mars: 129 | levtype: sfc 
130 | param: tcc 131 | tp: 132 | accumulated: true 133 | mars: 134 | levtype: sfc 135 | param: tp 136 | z: 137 | constant_in_time: true 138 | mars: 139 | levtype: sfc 140 | param: z 141 | -------------------------------------------------------------------------------- /src/anemoi/inference/outputs/truth.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025- Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | import logging 11 | from typing import Any 12 | 13 | from anemoi.inference.runners.default import DefaultRunner 14 | from anemoi.inference.state import reduce_state 15 | from anemoi.inference.types import State 16 | 17 | from ..decorators import main_argument 18 | from ..output import ForwardOutput 19 | from . import output_registry 20 | 21 | LOG = logging.getLogger(__name__) 22 | 23 | 24 | @output_registry.register("truth") 25 | @main_argument("output") 26 | class TruthOutput(ForwardOutput): 27 | """Write the input state at the same time for each output state. 28 | 29 | Can only be used for inputs with that have access to the time of 30 | the forecasts, effectively only for times in the past. 31 | """ 32 | 33 | def __init__(self, context: DefaultRunner, output, **kwargs: Any) -> None: 34 | """Initialise the TruthOutput. 35 | 36 | Parameters 37 | ---------- 38 | context : Context 39 | The context for the output. 40 | output : 41 | The output configuration. 42 | kwargs : dict 43 | Additional keyword arguments. 
44 | """ 45 | if not isinstance(context, DefaultRunner): 46 | raise ValueError("TruthOutput can only be used with `DefaultRunner`") 47 | 48 | super().__init__(context, output, None, **kwargs) 49 | self._input = context.create_prognostics_input() 50 | 51 | def write_step(self, state: State) -> None: 52 | """Write a step of the state. 53 | 54 | Parameters 55 | ---------- 56 | state : State 57 | The state to write. 58 | """ 59 | truth_state = self._input.create_input_state(date=state["date"]) 60 | truth_state["step"] = state["step"] 61 | 62 | reduced_state = reduce_state(truth_state) 63 | self.output.write_state(reduced_state) 64 | 65 | def __repr__(self) -> str: 66 | """Return a string representation of the TruthOutput. 67 | 68 | Returns 69 | ------- 70 | str 71 | String representation of the TruthOutput. 72 | """ 73 | return f"TruthOutput({self.output})" 74 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 
9 | 10 | 11 | from datetime import datetime 12 | from datetime import timedelta 13 | 14 | import numpy as np 15 | import pytest 16 | 17 | pytest_plugins = "anemoi.utils.testing" 18 | 19 | 20 | STATE_NPOINTS = 50 21 | 22 | 23 | def pytest_addoption(parser): 24 | parser.addoption("--cosmo", action="store_true", default=False, help="only run cosmo tests") 25 | 26 | 27 | def pytest_configure(config): 28 | config.addinivalue_line("markers", "cosmo: mark test as requiring cosmo eccodes definitions and isolation") 29 | 30 | 31 | def pytest_collection_modifyitems(config, items): 32 | skip_cosmo = pytest.mark.skip(reason="skipping cosmo tests, use --cosmo to run") 33 | skip_non_cosmo = pytest.mark.skip(reason="skipping non-cosmo tests") 34 | 35 | for item in items: 36 | if config.getoption("--cosmo"): 37 | if "cosmo" not in item.keywords: 38 | item.add_marker(skip_non_cosmo) 39 | else: 40 | if "cosmo" in item.keywords: 41 | item.add_marker(skip_cosmo) 42 | 43 | 44 | @pytest.fixture 45 | def state(): 46 | """Fixture to create a mock state for testing.""" 47 | 48 | return { 49 | "latitudes": np.random.uniform(-90, 90, size=STATE_NPOINTS), 50 | "longitudes": np.random.uniform(-180, 180, size=STATE_NPOINTS), 51 | "fields": { 52 | "2t": np.random.uniform(250, 310, size=STATE_NPOINTS), 53 | "z_850": np.random.uniform(500, 1500, size=STATE_NPOINTS), 54 | }, 55 | "date": datetime(2020, 1, 1, 0, 0), 56 | "step": timedelta(hours=6), 57 | } 58 | 59 | 60 | @pytest.fixture 61 | def extract_mask_npy(tmp_path): 62 | """Fixture to create a mock mask in .npy format.""" 63 | mask = np.random.choice([True, False], size=STATE_NPOINTS) 64 | mask_path = tmp_path / "extract_mask.npy" 65 | np.save(mask_path, mask) 66 | return str(mask_path) 67 | 68 | 69 | @pytest.fixture 70 | def assign_mask_npy(tmp_path): 71 | """Fixture to create a mock mask in .npy format.""" 72 | mask = np.zeros(STATE_NPOINTS + 10, dtype=bool) 73 | mask[:STATE_NPOINTS] = True 74 | mask_path = tmp_path / "assign_mask.npy" 75 | 
np.save(mask_path, mask) 76 | return str(mask_path) 77 | -------------------------------------------------------------------------------- /src/anemoi/inference/inputs/empty.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | """Dummy input used for testing. 11 | 12 | It will generate fields with constant values for each variable and date. 13 | These values are then tested in the mock model. 14 | """ 15 | 16 | import logging 17 | from typing import Any 18 | 19 | from anemoi.inference.context import Context 20 | from anemoi.inference.types import Date 21 | from anemoi.inference.types import State 22 | 23 | from ..input import Input 24 | from . import input_registry 25 | 26 | LOG = logging.getLogger(__name__) 27 | SKIP_KEYS = ["date", "time", "step"] 28 | 29 | 30 | @input_registry.register("empty") 31 | class EmptyInput(Input): 32 | """An Input that is always empty.""" 33 | 34 | trace_name = "empty" 35 | 36 | def __init__(self, context: Context, **kwargs: Any) -> None: 37 | """Initialise the EmptyInput. 38 | 39 | Parameters 40 | ---------- 41 | context : Context 42 | The context object for the input. 43 | **kwargs : object 44 | Additional keyword arguments. 45 | """ 46 | super().__init__(context, **kwargs) 47 | assert self.variables in (None, []), "EmptyInput should not have variables" 48 | 49 | def create_input_state(self, *, date: Date | None, **kwargs) -> State: 50 | """Create an empty input state. 51 | 52 | Parameters 53 | ---------- 54 | date : Date or None 55 | The date for the input state. 
56 | **kwargs : Any 57 | Additional keyword arguments. 58 | 59 | Returns 60 | ------- 61 | State 62 | The created empty input state. 63 | """ 64 | return dict(fields=dict(), _input=self) 65 | 66 | def load_forcings_state(self, *, dates: list[Date], current_state: State) -> State: 67 | """Load an empty forcings state. 68 | 69 | Parameters 70 | ---------- 71 | dates : list of Date 72 | The list of dates for the forcings state. 73 | current_state : State 74 | The current state (unused). 75 | 76 | Returns 77 | ------- 78 | State 79 | The loaded empty forcings state. 80 | """ 81 | return dict(date=dates[-1], fields=dict(), _input=self) 82 | -------------------------------------------------------------------------------- /tests/unit/test_coupling.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | import logging 11 | 12 | from anemoi.inference.config.run import RunConfiguration 13 | from anemoi.inference.runners import create_runner 14 | from anemoi.inference.testing import fake_checkpoints 15 | from anemoi.inference.testing import files_for_tests 16 | 17 | 18 | @fake_checkpoints 19 | def test_atmos() -> None: 20 | """Test the inference process using a fake checkpoint. 21 | 22 | This function loads a configuration, creates a runner, and runs the inference 23 | process to ensure that the system works as expected with the provided configuration. 
24 | """ 25 | config = RunConfiguration.load( 26 | files_for_tests("unit/configs/atmos.yaml"), 27 | overrides=dict(device="cpu", input="dummy"), 28 | ) 29 | runner = create_runner(config) 30 | runner.execute() 31 | 32 | 33 | @fake_checkpoints 34 | def test_ocean() -> None: 35 | """Test the inference process using a fake checkpoint. 36 | 37 | This function loads a configuration, creates a runner, and runs the inference 38 | process to ensure that the system works as expected with the provided configuration. 39 | """ 40 | config = RunConfiguration.load( 41 | files_for_tests("unit/configs/ocean.yaml"), 42 | overrides=dict(device="cpu", input="dummy"), 43 | ) 44 | runner = create_runner(config) 45 | runner.execute() 46 | 47 | 48 | # @fake_checkpoints 49 | # def test_coupled() -> None: 50 | # """Test the inference process using a fake checkpoint. 51 | 52 | # This function loads a configuration, creates a runner, and runs the inference 53 | # process to ensure that the system works as expected with the provided configuration. 54 | # """ 55 | # config = CoupleConfiguration.load(files_for_tests("configs/coupled.yaml")) 56 | 57 | # global_config: Dict[str, Any] = {} 58 | 59 | # tasks = {name: create_task(name, action, global_config=global_config) for name, action in config.tasks.items()} 60 | 61 | # transport = create_transport(config.transport, config.couplings, tasks) 62 | 63 | # transport.start() 64 | # transport.wait() 65 | 66 | 67 | if __name__ == "__main__": 68 | logging.basicConfig(level=logging.INFO) 69 | for name, obj in list(globals().items()): 70 | if name.startswith("test_") and callable(obj): 71 | print(f"Running {name}...") 72 | obj() 73 | -------------------------------------------------------------------------------- /src/anemoi/inference/clusters/mpi.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025- ECMWF. 
2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # In applying this licence, ECMWF does not waive the privileges and immunities 6 | # granted to it by virtue of its status as an intergovernmental organisation 7 | # nor does it submit to any jurisdiction. 8 | # 9 | 10 | import logging 11 | 12 | from anemoi.inference.clusters import cluster_registry 13 | from anemoi.inference.clusters.mapping import EnvMapping 14 | from anemoi.inference.clusters.mapping import MappingCluster 15 | from anemoi.inference.lazy import torch 16 | 17 | LOG = logging.getLogger(__name__) 18 | 19 | MPI_MAPPING = EnvMapping( 20 | local_rank=["OMPI_COMM_WORLD_LOCAL_RANK", "PMI_RANK"], 21 | global_rank=["OMPI_COMM_WORLD_RANK", "PMI_RANK"], 22 | world_size=["OMPI_COMM_WORLD_SIZE", "PMI_SIZE"], 23 | master_addr="MASTER_ADDR", 24 | master_port="MASTER_PORT", 25 | init_method="tcp://{master_addr}:{master_port}", 26 | ) 27 | 28 | 29 | @cluster_registry.register("mpi") 30 | class MPICluster(MappingCluster): 31 | """MPI cluster that uses MPI environment variables for distributed setup.""" 32 | 33 | def __init__(self, use_mpi_backend: bool = False, **kwargs) -> None: 34 | """Initialise the MPICluster. 
35 | 36 | Parameters 37 | ---------- 38 | use_mpi_backend : bool, optional 39 | Use the `mpi` backend in torch, by default False 40 | """ 41 | super().__init__(mapping=MPI_MAPPING, **kwargs) 42 | self._use_mpi_backend = use_mpi_backend 43 | 44 | @classmethod 45 | def used(cls) -> bool: 46 | return bool(MPI_MAPPING.get_env("world_size")) 47 | 48 | @property 49 | def backend(self) -> str: 50 | """Return the backend string for distributed computing.""" 51 | if self._use_mpi_backend: 52 | return "mpi" 53 | return super().backend 54 | 55 | def create_model_comm_group(self) -> "torch.distributed.ProcessGroup | None": 56 | """Create the communication group for model parallelism.""" 57 | if not self._use_mpi_backend: 58 | return super().create_model_comm_group() 59 | 60 | if self.world_size <= 1: 61 | return None 62 | 63 | LOG.debug("Creating model communication group for parallel inference") 64 | group = torch.distributed.init_process_group( 65 | backend=self.backend, 66 | ) 67 | 68 | # Create a new process group for model communication 69 | group = torch.distributed.new_group( 70 | ranks=list(range(self.world_size)), 71 | ) 72 | LOG.info("Model communication group created") 73 | 74 | return group 75 | -------------------------------------------------------------------------------- /src/anemoi/inference/commands/requests.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 
9 | 10 | 11 | from argparse import ArgumentParser 12 | from argparse import Namespace 13 | 14 | from anemoi.utils.grib import shortname_to_paramid 15 | 16 | from ..checkpoint import Checkpoint 17 | from . import Command 18 | from .retrieve import checkpoint_to_requests 19 | 20 | 21 | class RequestCmd(Command): 22 | """MARS request utility.""" 23 | 24 | def add_arguments(self, command_parser: ArgumentParser) -> None: 25 | """Add arguments to the command parser. 26 | 27 | Parameters 28 | ---------- 29 | command_parser : ArgumentParser 30 | The argument parser to which the arguments will be added. 31 | """ 32 | command_parser.description = self.__doc__ 33 | command_parser.add_argument("--mars", action="store_true", help="Print the MARS request.") 34 | command_parser.add_argument("--use-grib-paramid", action="store_true", help="Use paramId instead of param.") 35 | command_parser.add_argument( 36 | "--dont-fail-for-missing-paramid", 37 | action="store_true", 38 | help="Do not fail if a parameter ID is missing.", 39 | ) 40 | command_parser.add_argument("path", help="Path to the checkpoint.") 41 | 42 | def run(self, args: Namespace) -> None: 43 | """Print the data requests needed by the checkpoint, as MARS requests when ``--mars`` is given. 44 | 45 | Parameters 46 | ---------- 47 | args : Namespace 48 | The parsed command-line arguments of this command.
49 | """ 50 | c = Checkpoint(args.path) 51 | for r in checkpoint_to_requests( 52 | c, 53 | date=-1, 54 | use_grib_paramid=args.use_grib_paramid, 55 | dont_fail_for_missing_paramid=args.dont_fail_for_missing_paramid, 56 | ): 57 | if args.mars: 58 | req = ["retrieve,target=data"] 59 | for k, v in r.items(): 60 | # Translate short names to paramIds. NOTE(review): checkpoint_to_requests also receives use_grib_paramid; confirm this is not a double conversion. 61 | if args.use_grib_paramid and k == "param": 62 | if not isinstance(v, (list, tuple)): 63 | v = [v] 64 | v = [shortname_to_paramid(x) for x in v] 65 | # Lists are rendered MARS-style as slash-separated values. 66 | if isinstance(v, (list, tuple)): 67 | v = "/".join([str(x) for x in v]) 68 | req.append(f"{k}={v}") 69 | r = ",".join(req) 70 | print(r) 71 | 72 | 73 | command = RequestCmd 74 | -------------------------------------------------------------------------------- /src/anemoi/inference/patch.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | 11 | import logging 12 | from collections.abc import Generator 13 | from contextlib import contextmanager 14 | from functools import cached_property 15 | from typing import Any 16 | 17 | from .protocol import MetadataProtocol 18 | 19 | LOG = logging.getLogger(__name__) 20 | 21 | 22 | @contextmanager 23 | def patch_function(target: Any, attribute: str, replacement: Any) -> Generator[None, None, None]: 24 | """Context manager to temporarily replace an attribute of a target object. 25 | 26 | Parameters 27 | ---------- 28 | target : object 29 | The target object whose attribute will be replaced. 30 | attribute : str 31 | The name of the attribute to replace.
32 | replacement : any 33 | The replacement value for the attribute. 34 | 35 | Returns 36 | ------- 37 | Generator[None, None, None] 38 | The context manager; the original attribute value is restored on exit, even if the body raises. 39 | """ 40 | original = getattr(target, attribute) 41 | setattr(target, attribute, replacement) 42 | try: 43 | yield 44 | finally: 45 | setattr(target, attribute, original) 46 | 47 | 48 | class PatchMixin(MetadataProtocol): 49 | 50 | # `self` is a `Metadata` object 51 | 52 | def patch_metadata( 53 | self, 54 | supporting_arrays: dict[str, Any], 55 | root: str, 56 | ) -> tuple[dict[str, Any], dict[str, Any]]: 57 | """Replace the dataset metadata and supporting arrays with those read from the dataset. 58 | 59 | Parameters 60 | ---------- 61 | supporting_arrays : dict 62 | Currently unused; the supporting arrays are re-read from the dataset instead. 63 | root : str 64 | Currently unused. 65 | 66 | Returns 67 | ------- 68 | tuple 69 | The patched metadata (``self._metadata``) and the supporting arrays read from the dataset. 70 | """ 71 | 72 | metadata, supporting_arrays = self._from_zarr 73 | self._metadata["dataset"] = metadata 74 | self._supporting_arrays = supporting_arrays 75 | return self._metadata, self._supporting_arrays 76 | 77 | @cached_property 78 | def _from_zarr(self) -> tuple[dict[str, Any], dict[str, Any]]: 79 | """Open the dataset once (the result is cached) and return its metadata and supporting arrays.""" 80 | # We assume that the datasets are reachable via the content of 81 | # ~/.config/anemoi/settings.toml. 82 | 83 | ds = self.open_dataset() 84 | return ds.metadata(), ds.supporting_arrays() 85 | -------------------------------------------------------------------------------- /src/anemoi/inference/pre_processors/forward_transform_filter.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024-2025 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | 11 | import logging 12 | from typing import Any 13 | 14 | from anemoi.transform.filters import filter_registry 15 | 16 | from anemoi.inference.decorators import main_argument 17 | from anemoi.inference.types import State 18 | 19 | from ..processor import Processor 20 | from . import pre_processor_registry 21 | 22 | LOG = logging.getLogger(__name__) 23 | 24 | 25 | @pre_processor_registry.register("forward_transform_filter") 26 | @main_argument("filter") 27 | class ForwardTransformFilter(Processor): 28 | """A processor that applies a forward transform filter to the given fields. 29 | 30 | This class uses a specified filter from the filter registry to process 31 | fields and patch data requests. 32 | 33 | Attributes 34 | ---------- 35 | filter : object 36 | The filter instance (created via ``filter_registry``) used for processing fields and patching data requests. 37 | """ 38 | 39 | def __init__(self, context: Any, filter: str, **kwargs: Any) -> None: 40 | """Initialize the ForwardTransformFilter. 41 | 42 | Parameters 43 | ---------- 44 | context : Any 45 | The context in which the filter is being used. 46 | filter : str 47 | The name of the filter to be used. 48 | **kwargs : Any 49 | Additional keyword arguments passed to ``filter_registry.create``. 50 | """ 51 | super().__init__(context) 52 | self.filter = filter_registry.create(filter, **kwargs) 53 | 54 | def process(self, state: State) -> State: 55 | """Apply the filter's ``forward`` transform to the state's fields. 56 | 57 | Parameters 58 | ---------- 59 | state : State 60 | The state containing the fields to be processed. 61 | 62 | Returns 63 | ------- 64 | State 65 | The same state object, modified in place: its ``fields`` are replaced by the filter output.
66 | """ 67 | state["fields"] = self.filter.forward(state["fields"]) 68 | return state 69 | 70 | def patch_data_request(self, data_request: Any) -> Any: 71 | """Patch the data request by delegating to the filter's ``patch_data_request``. 72 | 73 | Parameters 74 | ---------- 75 | data_request : object 76 | The data request to be patched. 77 | 78 | Returns 79 | ------- 80 | object 81 | The patched data request. 82 | """ 83 | return self.filter.patch_data_request(data_request) 84 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # Clear the outputs of Jupyter notebooks 3 | - repo: local 4 | hooks: 5 | - id: clear-notebooks-output 6 | name: clear-notebooks-output 7 | files: tools/.*\.ipynb$ 8 | stages: [pre-commit] 9 | language: python 10 | entry: jupyter nbconvert --ClearOutputPreprocessor.enabled=True --inplace 11 | additional_dependencies: [jupyter] 12 | - repo: https://github.com/pre-commit/pre-commit-hooks 13 | rev: v6.0.0 14 | hooks: 15 | - id: check-yaml # Check YAML files for syntax errors only 16 | args: [--unsafe, --allow-multiple-documents] 17 | - id: debug-statements # Check for debugger imports and py37+ breakpoint() 18 | - id: end-of-file-fixer # Ensure files end in a newline 19 | - id: trailing-whitespace # Trailing whitespace checker 20 | - id: no-commit-to-branch # Prevent committing to main / master 21 | - id: check-added-large-files # Check for large files added to git 22 | - id: check-merge-conflict # Check for files that contain merge conflict markers 23 | - repo: https://github.com/pre-commit/pygrep-hooks 24 | rev: v1.10.0 # Use the ref you want to point at 25 | hooks: 26 | - id: python-use-type-annotations # Enforce type annotations instead of type comments 27 | - id: python-check-blanket-noqa # Require specific codes in # noqa comments 28 | - id: python-no-log-warn # Check for the deprecated log.warn() 29 | - repo: https://github.com/psf/black-pre-commit-mirror 30 | rev: 25.11.0 31 | hooks: 32 | - id: black
33 | args: [--line-length=120] 34 | - repo: https://github.com/pycqa/isort 35 | rev: 7.0.0 36 | hooks: 37 | - id: isort 38 | args: 39 | - -l 120 40 | - --force-single-line-imports 41 | - --profile black 42 | - --project anemoi 43 | - repo: https://github.com/astral-sh/ruff-pre-commit 44 | rev: v0.14.7 45 | hooks: 46 | - id: ruff 47 | args: 48 | - --line-length=120 49 | - --fix 50 | - --exit-non-zero-on-fix 51 | - --exclude=docs/**/*_.py 52 | - repo: https://github.com/sphinx-contrib/sphinx-lint 53 | rev: v1.0.2 54 | hooks: 55 | - id: sphinx-lint 56 | - repo: https://github.com/b8raoult/pre-commit-docconvert 57 | rev: "0.1.5" 58 | hooks: 59 | - id: docconvert 60 | args: ["numpy"] 61 | - repo: https://github.com/tox-dev/pyproject-fmt 62 | rev: "v2.11.1" 63 | hooks: 64 | - id: pyproject-fmt 65 | args: ["--max-supported-python", "3.12"] 66 | - repo: https://github.com/jshwi/docsig # Check docstrings against function sig 67 | rev: v0.71.0 68 | hooks: 69 | - id: docsig 70 | args: 71 | - --ignore-no-params # Allow docstrings without parameters 72 | - --check-dunders # Check dunder methods 73 | - --check-overridden # Check overridden methods 74 | - --check-protected # Check protected methods 75 | - --check-class # Check class docstrings 76 | - --disable=SIG101,SIG102 # Disable empty docstrings 77 | ci: 78 | autoupdate_schedule: monthly 79 | autoupdate_commit_msg: "chore(deps): pre-commit.ci autoupdate" 80 | -------------------------------------------------------------------------------- /src/anemoi/inference/post_processors/accumulate.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024-2025 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | 11 | import logging 12 | from datetime import timedelta 13 | 14 | import numpy as np 15 | 16 | from anemoi.inference.context import Context 17 | from anemoi.inference.types import FloatArray 18 | from anemoi.inference.types import State 19 | 20 | from ..processor import Processor 21 | from . import post_processor_registry 22 | 23 | LOG = logging.getLogger(__name__) 24 | 25 | 26 | @post_processor_registry.register("accumulate_from_start_of_forecast") 27 | class Accumulate(Processor): 28 | """Accumulate fields from zero and return the accumulated fields. 29 | 30 | Parameters 31 | ---------- 32 | context : Any 33 | The context in which the processor is running. 34 | accumulations : Optional[List[str]], optional 35 | List of fields to accumulate, by default None. 36 | If None, the fields are taken from the context's checkpoint. 37 | """ 38 | 39 | def __init__(self, context: Context, accumulations: list[str] | None = None) -> None: 40 | super().__init__(context) 41 | if accumulations is None: 42 | accumulations = context.checkpoint.accumulations 43 | 44 | self.accumulations = accumulations 45 | LOG.info("Accumulating fields %s", self.accumulations) 46 | 47 | self.accumulators: dict[str, FloatArray] = {} 48 | self.step_zero = timedelta(0) 49 | 50 | def process(self, state: State) -> State: 51 | """Process the state to accumulate specified fields. 52 | 53 | Parameters 54 | ---------- 55 | state : State 56 | The state containing fields to be accumulated. 57 | 58 | Returns 59 | ------- 60 | State 61 | The updated state with accumulated fields. 
62 | """ 63 | state = state.copy() 64 | state.setdefault("start_steps", {}) 65 | for accumulation in self.accumulations: 66 | if accumulation in state["fields"]: 67 | if accumulation not in self.accumulators: 68 | self.accumulators[accumulation] = np.zeros_like(state["fields"][accumulation]) 69 | self.accumulators[accumulation] += np.maximum(0, state["fields"][accumulation]) 70 | state["fields"][accumulation] = self.accumulators[accumulation] 71 | state["start_steps"][accumulation] = self.step_zero 72 | 73 | return state 74 | 75 | def __repr__(self) -> str: 76 | """Return a string representation of the Accumulate object. 77 | 78 | Returns 79 | ------- 80 | str 81 | String representation of the object. 82 | """ 83 | return f"Accumulate({self.accumulations})" 84 | -------------------------------------------------------------------------------- /src/anemoi/inference/commands/couple.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | from __future__ import annotations 11 | 12 | import datetime 13 | import logging 14 | from argparse import ArgumentParser 15 | from argparse import Namespace 16 | 17 | from ..config.couple import CoupleConfiguration 18 | from ..tasks import create_task 19 | from ..transports import create_transport 20 | from . 
import Command 21 | 22 | LOG = logging.getLogger(__name__) 23 | # Top-level settings copied from the coupling configuration into every task's configuration. 24 | COPY_ATTRIBUTES = ( 25 | "date", 26 | "lead_time", 27 | "verbosity", 28 | ) 29 | 30 | 31 | class CoupleCmd(Command): 32 | """Couple tasks based on a configuration file.""" 33 | 34 | def add_arguments(self, command_parser: ArgumentParser) -> None: 35 | """Add arguments to the command parser. 36 | 37 | Parameters 38 | ---------- 39 | command_parser : ArgumentParser 40 | The argument parser to which the arguments will be added. 41 | """ 42 | command_parser.add_argument("--defaults", action="append", help="Sources of default values.") 43 | command_parser.add_argument( 44 | "config", 45 | help="Path to config file. Can be omitted to pass config with overrides and defaults.", 46 | ) 47 | command_parser.add_argument("overrides", nargs="*", help="Overrides as key=value") 48 | 49 | def run(self, args: Namespace) -> None: 50 | """Load the coupling configuration, create the tasks and the transport, then start the transport and wait for it. 51 | 52 | Parameters 53 | ---------- 54 | args : Namespace 55 | The arguments passed to the command.
56 | """ 57 | if "=" in args.config: 58 | args.overrides.append(args.config) 59 | args.config = {} 60 | # A "key=value" given in place of the config path is treated as an override (the config is built from overrides and defaults). 61 | config = CoupleConfiguration.load( 62 | args.config, 63 | args.overrides, 64 | defaults=args.defaults, 65 | ) 66 | 67 | if config.description is not None: 68 | LOG.info("%s", config.description) 69 | 70 | global_config = {} 71 | for copy in COPY_ATTRIBUTES: 72 | value = getattr(config, copy, None) 73 | if value is not None: 74 | LOG.info("Copy setting to all tasks: %s=%s", copy, value) 75 | if isinstance(value, datetime.datetime): 76 | value = value.isoformat() 77 | global_config[copy] = value 78 | 79 | tasks = {name: create_task(name, action, global_config=global_config) for name, action in config.tasks.items()} 80 | for task in tasks.values(): 81 | LOG.info("Task: %s", task) 82 | 83 | transport = create_transport(config.transport, config.couplings, tasks) 84 | LOG.info("Transport: %s", transport) 85 | 86 | transport.start() 87 | transport.wait() 88 | LOG.info("Run complete") 89 | 90 | 91 | command = CoupleCmd 92 | -------------------------------------------------------------------------------- /src/anemoi/inference/inputs/fdb.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | import datetime 11 | import logging 12 | from typing import Any 13 | 14 | import earthkit.data as ekd 15 | import numpy as np 16 | 17 | from ..types import Date 18 | from ..types import State 19 | from .
import input_registry 20 | from .grib import GribInput 21 | 22 | LOG = logging.getLogger(__name__) 23 | 24 | 25 | @input_registry.register("fdb") 26 | class FDBInput(GribInput): 27 | """Get input fields from FDB.""" 28 | 29 | trace_name = "fdb" 30 | 31 | def __init__( 32 | self, 33 | context, 34 | *, 35 | fdb_config: dict | None = None, 36 | fdb_userconfig: dict | None = None, 37 | **kwargs: dict[str, Any], 38 | ): 39 | """Initialise the FDB input. 40 | 41 | Parameters 42 | ---------- 43 | context : dict 44 | The context runner.pytest 45 | fdb_config : dict, optional 46 | The FDB config to use. 47 | fdb_userconfig : dict, optional 48 | The FDB userconfig to use. 49 | kwargs : dict, optional 50 | Additional keyword arguments for the request to FDB. 51 | """ 52 | super().__init__(context, **kwargs) 53 | self.kwargs = kwargs 54 | self.configs = {"config": fdb_config, "userconfig": fdb_userconfig} 55 | # NOTE: this is a temporary workaround for #191 thus not documented 56 | self.param_id_map = kwargs.pop("param_id_map", {}) 57 | 58 | def create_input_state(self, *, date: Date | None, **kwargs) -> State: 59 | date = np.datetime64(date).astype(datetime.datetime) 60 | dates = [date + h for h in self.checkpoint.lagged] 61 | ds = self.retrieve(variables=self.variables, dates=dates) 62 | res = self._create_input_state(ds, variables=None, date=date, **kwargs) 63 | return res 64 | 65 | def load_forcings_state(self, *, dates: list[Date], current_state: State) -> State: 66 | ds = self.retrieve(variables=self.variables, dates=dates) 67 | return self._load_forcings_state(ds, dates=dates, current_state=current_state) 68 | 69 | def retrieve(self, variables: list[str], dates: list[Date]) -> Any: 70 | requests = self.checkpoint.mars_requests( 71 | variables=variables, 72 | dates=dates, 73 | use_grib_paramid=self.context.use_grib_paramid, 74 | patch_request=self.patch_data_request, 75 | ) 76 | requests = [self.kwargs | r for r in requests] 77 | # NOTE: this is a temporary workaround 
for #191 78 | for request in requests: 79 | request["param"] = [self.param_id_map.get(p, p) for p in request["param"]] 80 | sources = [ekd.from_source("fdb", request, stream=False, **self.configs) for request in requests] 81 | ds = ekd.from_source("multi", sources) 82 | return ds 83 | -------------------------------------------------------------------------------- /src/anemoi/inference/outputs/raw.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | import logging 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | 15 | from anemoi.inference.context import Context 16 | from anemoi.inference.types import State 17 | from anemoi.inference.utils.templating import render_template 18 | 19 | from ..decorators import ensure_dir 20 | from ..decorators import main_argument 21 | from ..output import Output 22 | from . import output_registry 23 | 24 | LOG = logging.getLogger(__name__) 25 | 26 | 27 | @output_registry.register("raw") 28 | @main_argument("path") 29 | @ensure_dir("dir") 30 | class RawOutput(Output): 31 | """Raw output class.""" 32 | 33 | def __init__( 34 | self, 35 | context: Context, 36 | dir: Path, 37 | template: str = "{date}.npz", 38 | strftime: str = "%Y%m%d%H%M%S", 39 | variables: list[str] | None = None, 40 | **kwargs, 41 | ) -> None: 42 | """Initialise the RawOutput class. 43 | 44 | Parameters 45 | ---------- 46 | context : dict 47 | The context. 48 | dir : Path 49 | The directory to save the raw output. 50 | If the parent directory does not exist, it will be created. 
51 | template : str, optional 52 | The template for filenames, by default "{date}.npz". 53 | Variables available are `date`, `basetime` `step`. 54 | strftime : str, optional 55 | The date format string, by default "%Y%m%d%H%M%S". 56 | """ 57 | super().__init__(context, variables=variables, **kwargs) 58 | self.dir = dir 59 | self.template = template 60 | self.strftime = strftime 61 | 62 | def __repr__(self) -> str: 63 | """Return a string representation of the RawOutput object. 64 | 65 | Returns 66 | ------- 67 | str 68 | String representation of the RawOutput object. 69 | """ 70 | return f"RawOutput({self.dir})" 71 | 72 | def write_step(self, state: State) -> None: 73 | """Write the state to a compressed .npz file. 74 | 75 | Parameters 76 | ---------- 77 | state : State 78 | The state to be written. 79 | """ 80 | date = state["date"] 81 | basetime = date - state["step"] 82 | 83 | format_info = { 84 | "date": date.strftime(self.strftime), 85 | "step": state["step"], 86 | "basetime": basetime.strftime(self.strftime), 87 | } 88 | 89 | fn_state = f"{self.dir}/{render_template(self.template, format_info)}" 90 | restate = {f"field_{key}": val for key, val in state["fields"].items() if not self.skip_variable(key)} 91 | 92 | for key in ["date"]: 93 | restate[key] = np.array(state[key], dtype=str) 94 | 95 | for key in ["latitudes", "longitudes"]: 96 | restate[key] = np.array(state[key]) 97 | 98 | np.savez_compressed(fn_state, **restate) 99 | -------------------------------------------------------------------------------- /src/anemoi/inference/transports/mpi.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | 11 | import logging 12 | from typing import Any 13 | 14 | from anemoi.utils.logs import set_logging_name 15 | 16 | from anemoi.inference.config import Configuration 17 | from anemoi.inference.task import Task 18 | from anemoi.inference.types import State 19 | 20 | from ..transport import Transport 21 | from . import transport_registry 22 | 23 | LOG = logging.getLogger(__name__) 24 | 25 | 26 | @transport_registry.register("mpi") 27 | class MPITransport(Transport): 28 | """Transport implementation using MPI.""" 29 | 30 | def __init__(self, couplings: Configuration, tasks: dict[str, Task], *args: Any, **kwargs: Any) -> None: 31 | """Initialize the MPITransport. 32 | 33 | Parameters 34 | ---------- 35 | couplings : Configuration 36 | The couplings for the transport. 37 | tasks : Dict[str, Any] 38 | The tasks to be executed. 39 | """ 40 | from mpi4py import MPI 41 | 42 | super().__init__(couplings, tasks) 43 | self.comm: MPI.Comm = MPI.COMM_WORLD 44 | self.rank: int = self.comm.Get_rank() 45 | self.size: int = self.comm.Get_size() 46 | 47 | assert ( 48 | len(tasks) == self.size 49 | ), f"Number of tasks ({len(tasks)}) must match number of MPI processes ({self.size})" 50 | 51 | def start(self) -> None: 52 | """Start the transport by initializing MPI tasks.""" 53 | 54 | tasks = list(self.tasks.values()) 55 | self.ranks = {task.name: i for i, task in enumerate(tasks)} 56 | 57 | # Pick only one task per rank 58 | task = tasks[self.rank] 59 | set_logging_name(task.name) 60 | 61 | task.run(self) 62 | 63 | def wait(self) -> None: 64 | """Wait for all MPI tasks to complete.""" 65 | self.comm.barrier() 66 | 67 | def send(self, sender: Task, target: Task, state: State, tag: int) -> None: 68 | """Send a state from the sender to the target. 
69 | 70 | Parameters 71 | ---------- 72 | sender : Task 73 | The task sending the state. 74 | target : Task 75 | The task receiving the state. 76 | state : State 77 | The state to be sent. 78 | tag : int 79 | The tag associated with the state. 80 | """ 81 | self.comm.send(state, dest=self.ranks[target.name], tag=tag) 82 | 83 | def receive(self, receiver: Task, source: Task, tag: int) -> State: 84 | """Receive a state from the source to the receiver. 85 | 86 | Parameters 87 | ---------- 88 | receiver : Any 89 | The task receiving the state. 90 | source : Any 91 | The task sending the state. 92 | tag : int 93 | The tag associated with the state. 94 | 95 | Returns 96 | ------- 97 | Any 98 | The received state. 99 | """ 100 | return self.comm.recv(source=self.ranks[source.name], tag=tag) 101 | -------------------------------------------------------------------------------- /src/anemoi/inference/outputs/gribmemory.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 
9 | 10 | 11 | import logging 12 | from io import IOBase 13 | from typing import Any 14 | 15 | from anemoi.inference.context import Context 16 | from anemoi.inference.types import ProcessorConfig 17 | 18 | from .gribfile import GribIoOutput 19 | 20 | LOG = logging.getLogger(__name__) 21 | 22 | 23 | class GribMemoryOutput(GribIoOutput): 24 | """Handles grib files in memory.""" 25 | 26 | def __init__( 27 | self, 28 | context: Context, 29 | *, 30 | out: IOBase, 31 | post_processors: list[ProcessorConfig] | None = None, 32 | encoding: dict[str, Any] | None = None, 33 | archive_requests: dict[str, Any] | None = None, 34 | check_encoding: bool = True, 35 | templates: list[str] | str | None = None, 36 | grib1_keys: dict[str, Any] | None = None, 37 | grib2_keys: dict[str, Any] | None = None, 38 | modifiers: list[str] | None = None, 39 | variables: list[str] | None = None, 40 | output_frequency: int | None = None, 41 | write_initial_state: bool | None = None, 42 | ) -> None: 43 | """Initialize the GribFileOutput. 44 | 45 | Parameters 46 | ---------- 47 | context : Context 48 | The context. 49 | out : IOBase 50 | Output stream or file-like object for writing GRIB data. 51 | post_processors : Optional[List[ProcessorConfig]], default None 52 | Post-processors to apply to the input 53 | encoding : dict, optional 54 | The encoding dictionary, by default None. 55 | archive_requests : dict, optional 56 | The archive requests dictionary, by default None. 57 | check_encoding : bool, optional 58 | Whether to check encoding, by default True. 59 | templates : list or str, optional 60 | The templates list or string, by default None. 61 | grib1_keys : dict, optional 62 | The grib1 keys dictionary, by default None. 63 | grib2_keys : dict, optional 64 | The grib2 keys dictionary, by default None. 65 | modifiers : list, optional 66 | The list of modifiers, by default None. 67 | output_frequency : int, optional 68 | The frequency of output, by default None. 
69 | write_initial_state : bool, optional 70 | Whether to write the initial state, by default None. 71 | variables : list, optional 72 | The list of variables, by default None. 73 | """ 74 | super().__init__( 75 | context, 76 | out=out, 77 | post_processors=post_processors, 78 | encoding=encoding, 79 | archive_requests=archive_requests, 80 | check_encoding=check_encoding, 81 | templates=templates, 82 | grib1_keys=grib1_keys, 83 | grib2_keys=grib2_keys, 84 | modifiers=modifiers, 85 | output_frequency=output_frequency, 86 | write_initial_state=write_initial_state, 87 | variables=variables, 88 | split_output=False, 89 | ) 90 | -------------------------------------------------------------------------------- /src/anemoi/inference/runners/testing.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | import logging 11 | from functools import cached_property 12 | from typing import Any 13 | 14 | from anemoi.inference.lazy import torch 15 | 16 | from . import runner_registry 17 | from .default import DefaultRunner 18 | 19 | LOG = logging.getLogger(__name__) 20 | 21 | 22 | class TestingMixing: 23 | # Used with dummy checkpoint (see pytests) 24 | def predict_step(self, model: Any, input_tensor_torch: Any, **kwargs: Any) -> Any: 25 | """Perform a prediction step using the model. 26 | 27 | Parameters 28 | ---------- 29 | model : Any 30 | The model to use for prediction. 31 | input_tensor_torch : torch.Tensor 32 | The input tensor for the model. 33 | **kwargs : Any 34 | Additional keyword arguments. 
35 | 36 | Returns 37 | ------- 38 | Any 39 | The prediction result. 40 | """ 41 | return model.predict_step(input_tensor_torch, **kwargs) 42 | 43 | 44 | class NoModelMixing: 45 | # Use with a real checkpoint (see prepml/regression) 46 | def predict_step(self, model: Any, input_tensor_torch: Any, **kwargs: Any) -> Any: 47 | """Perform a prediction step using the model. 48 | 49 | Parameters 50 | ---------- 51 | model : Any 52 | The model to use for prediction. 53 | input_tensor_torch : torch.Tensor 54 | The input tensor for the model. 55 | **kwargs : Any 56 | Additional keyword arguments. 57 | 58 | Returns 59 | ------- 60 | Any 61 | The prediction result. 62 | """ 63 | return model.predict_step(input_tensor_torch, **kwargs) 64 | 65 | @cached_property 66 | def model(self) -> "torch.nn.Module": 67 | 68 | checkpoint = self.checkpoint 69 | number_of_output_variables = len(checkpoint.output_tensor_index_to_variable) 70 | 71 | class NoModel(torch.nn.Module): 72 | """Dummy model class for testing purposes.""" 73 | 74 | def __init__(self): 75 | super().__init__() 76 | 77 | def predict_step(self, input_tensor: Any, **kwargs: Any) -> Any: 78 | input_shape = input_tensor.shape 79 | output_shape = ( 80 | input_shape[0], # batch 81 | 1, # time 82 | input_shape[2], # gridpoints 83 | number_of_output_variables, # variables 84 | ) 85 | 86 | return torch.ones(*output_shape, dtype=input_tensor.dtype, device=input_tensor.device) 87 | 88 | return NoModel() 89 | 90 | 91 | @runner_registry.register("testing") 92 | class TestingRunner(TestingMixing, DefaultRunner): 93 | """Runner for running tests. 94 | 95 | Inherits from DefaultRunner. 96 | """ 97 | 98 | 99 | @runner_registry.register("no-model") 100 | class NoModelRunner(NoModelMixing, DefaultRunner): 101 | """Runner for running tests. 102 | 103 | Inherits from DefaultRunner. 
104 | """ 105 | -------------------------------------------------------------------------------- /src/anemoi/inference/grib/templates/file.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | import logging 11 | from functools import cached_property 12 | from pathlib import Path 13 | from typing import Any 14 | from typing import Literal 15 | 16 | import earthkit.data as ekd 17 | 18 | from anemoi.inference.decorators import main_argument 19 | from anemoi.inference.inputs.ekd import find_variable 20 | from anemoi.inference.types import State 21 | 22 | from . import TemplateProvider 23 | from . import template_provider_registry 24 | from .manager import TemplateManager 25 | 26 | LOG = logging.getLogger(__name__) 27 | 28 | 29 | @template_provider_registry.register("file") 30 | @main_argument("path") 31 | class FileTemplates(TemplateProvider): 32 | """Template provider using a single GRIB file.""" 33 | 34 | def __init__( 35 | self, 36 | manager: TemplateManager, 37 | *, 38 | path: str, 39 | mode: Literal["auto", "first", "last"] = "first", 40 | variables: str | list | None = None, 41 | ) -> None: 42 | """Initialize the FileTemplates instance. 43 | 44 | Parameters 45 | ---------- 46 | manager : TemplateManager 47 | The manager instance. 48 | path : str 49 | The path to the GRIB file. 
50 | mode : Literal["auto", "first", "last"], optional 51 | The method with which to select a message from the grib file to use as template, by default "first": 52 | - "first": use the first message in the grib file 53 | - "last": use the last message in the grib file 54 | - "auto": select variable from the grib file matching the output variable name 55 | variables : str | list, optional 56 | The output variable name(s) for which to use this template file. If empty, applies to all variables. 57 | """ 58 | self.manager = manager 59 | self.path = Path(path) 60 | if not self.path.exists(): 61 | raise FileNotFoundError(f"GRIB template file not found: {self.path}") 62 | self.mode = mode 63 | self.variables = variables if isinstance(variables, list) else [variables] if variables else None 64 | 65 | def __repr__(self): 66 | info = f"{self.__class__.__name__}({self.path.name},mode={self.mode}{{variables}})" 67 | return info.format(variables=f",variables={self.variables}" if self.variables else "") 68 | 69 | @cached_property 70 | def _data(self): 71 | return ekd.from_source("file", self.path) 72 | 73 | def template(self, variable: str, lookup: dict[str, Any], state: State, **kwargs) -> ekd.Field | None: 74 | if self.variables and variable not in self.variables: 75 | return None 76 | 77 | match self.mode: 78 | case "first": 79 | return self._data[0] 80 | case "last": 81 | return self._data[-1] 82 | case "auto": 83 | namer = getattr(state.get("_input"), "_namer", self.manager.owner.context.checkpoint.default_namer()) 84 | field = find_variable(self._data, variable, namer) 85 | if len(field) > 0: 86 | return field[0] 87 | 88 | return None 89 | -------------------------------------------------------------------------------- /src/anemoi/inference/runners/plugin.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 
2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | 11 | import datetime 12 | import logging 13 | 14 | from anemoi.inference.types import IntArray 15 | 16 | from ..forcings import ComputedForcings 17 | from ..forcings import Forcings 18 | from ..runner import Runner 19 | from . import runner_registry 20 | 21 | LOG = logging.getLogger(__name__) 22 | 23 | 24 | @runner_registry.register("plugin") 25 | class PluginRunner(Runner): 26 | """A runner implementing the ai-models plugin API.""" 27 | 28 | def __init__(self, checkpoint: str, *, device: str): 29 | """Initialize the PluginRunner. 30 | 31 | Parameters 32 | ---------- 33 | checkpoint : str 34 | The checkpoint for the runner. 35 | device : str 36 | The device to run the model on. 
37 | """ 38 | super().__init__(checkpoint, device=device) 39 | 40 | # Compatibility with the ai_models API 41 | 42 | @property 43 | def param_sfc(self) -> list[str]: 44 | """Get surface parameters.""" 45 | params, _ = self.checkpoint.mars_by_levtype("sfc") 46 | return sorted(params) 47 | 48 | @property 49 | def param_level_pl(self) -> tuple[list[str], list[int]]: 50 | """Get pressure level parameters and levels.""" 51 | params, levels = self.checkpoint.mars_by_levtype("pl") 52 | return sorted(params), sorted(levels) 53 | 54 | @property 55 | def param_level_ml(self) -> tuple[list[str], list[int]]: 56 | """Get model level parameters and levels.""" 57 | params, levels = self.checkpoint.mars_by_levtype("ml") 58 | return sorted(params), sorted(levels) 59 | 60 | @property 61 | def lagged(self) -> list[datetime.timedelta]: 62 | """Get lagged times in hours.""" 63 | return self.checkpoint.lagged 64 | 65 | def create_constant_computed_forcings(self, variables: list[str], mask: IntArray) -> list[Forcings]: 66 | """Create constant computed forcings. 67 | 68 | Parameters 69 | ---------- 70 | variables : List[str] 71 | The variables for the computed forcings. 72 | mask : IntArray 73 | The mask for the computed forcings. 74 | 75 | Returns 76 | ------- 77 | List[Forcings] 78 | The constant computed forcings. 79 | """ 80 | result = ComputedForcings(self, variables, mask) 81 | LOG.info("Constant computed forcing: %s", result) 82 | return [result] 83 | 84 | def create_dynamic_computed_forcings(self, variables: list[str], mask: IntArray) -> list[Forcings]: 85 | """Create dynamic computed forcings. 86 | 87 | Parameters 88 | ---------- 89 | variables : list 90 | The variables for the computed forcings. 91 | mask : IntArray 92 | The mask for the computed forcings. 93 | 94 | Returns 95 | ------- 96 | List[Forcings] 97 | The dynamic computed forcings. 
98 | """ 99 | result = ComputedForcings(self, variables, mask) 100 | LOG.info("Dynamic computed forcing: %s", result) 101 | return [result] 102 | -------------------------------------------------------------------------------- /src/anemoi/inference/profiler.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | import logging 11 | import socket 12 | import time 13 | from collections.abc import Generator 14 | from contextlib import contextmanager 15 | 16 | from anemoi.inference.lazy import torch 17 | 18 | LOG = logging.getLogger(__name__) 19 | 20 | 21 | @contextmanager 22 | def ProfilingLabel(label: str, use_profiler: bool) -> Generator[None, None, None]: 23 | """Add label to function so that the profiler can recognize it, only if the use_profiler option is True. 24 | 25 | Parameters 26 | ---------- 27 | label : str 28 | Name or description to identify the function. 29 | use_profiler : bool 30 | Wrap the function with the label if True, otherwise just execute the function as it is. 31 | 32 | Returns 33 | ------- 34 | Generator[None, None, None] 35 | Yields to the caller. 36 | """ 37 | 38 | if use_profiler: 39 | with torch.autograd.profiler.record_function(label): 40 | torch.cuda.nvtx.range_push(label) 41 | yield 42 | torch.cuda.nvtx.range_pop() 43 | else: 44 | yield 45 | 46 | 47 | @contextmanager 48 | def ProfilingRunner(use_profiler: bool) -> Generator[None, None, None]: 49 | """Perform time and memory usage profiles of the wrapped code. 
50 | 51 | Parameters 52 | ---------- 53 | use_profiler : bool 54 | Whether to profile the wrapped code (True) or not (False). 55 | 56 | Returns 57 | ------- 58 | Generator[None, None, None] 59 | Yields to the caller. 60 | """ 61 | 62 | dirname = f"profiling-output/{socket.gethostname()}-{int(time.time())}" 63 | if use_profiler: 64 | torch.cuda.memory._record_memory_history(max_entries=100000) 65 | activities = [torch.profiler.ProfilerActivity.CPU] 66 | if torch.cuda.is_available(): 67 | activities.append(torch.profiler.ProfilerActivity.CUDA) 68 | with torch.profiler.profile( 69 | profile_memory=True, 70 | record_shapes=True, 71 | activities=activities, 72 | with_flops=True, 73 | on_trace_ready=torch.profiler.tensorboard_trace_handler(dirname), 74 | ) as prof: 75 | yield 76 | try: 77 | torch.cuda.memory._dump_snapshot(f"{dirname}/memory_snapshot.pickle") 78 | except Exception as e: 79 | LOG.error(f"Failed to capture memory snapshot {e}") 80 | torch.cuda.memory._record_memory_history(enabled=None) 81 | row_limit = 10 82 | LOG.info( 83 | f"Top {row_limit} kernels by runtime on CPU:\n {prof.key_averages().table(sort_by='self_cpu_time_total', row_limit=row_limit)}" 84 | ) 85 | LOG.info( 86 | f"Top {row_limit} kernels by runtime on CUDA:\n {prof.key_averages().table(sort_by='self_cuda_time_total', row_limit=row_limit)}" 87 | ) 88 | LOG.info("Memory summary \n%s", torch.cuda.memory_summary()) 89 | LOG.info( 90 | f"Memory snapshot and trace file stored to '{dirname}'. To view the memory snapshot, upload the pickle file to 'https://pytorch.org/memory_viz'. 
To view the trace file, see 'https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html#use-tensorboard-to-view-results-and-analyze-model-performance'" 91 | ) 92 | else: 93 | yield 94 | -------------------------------------------------------------------------------- /src/anemoi/inference/inputs/repeated_dates.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | 11 | import logging 12 | from typing import Any 13 | 14 | from anemoi.utils.dates import as_datetime 15 | 16 | from anemoi.inference.context import Context 17 | from anemoi.inference.types import Date 18 | from anemoi.inference.types import State 19 | 20 | from ..input import Input 21 | from . import create_input 22 | from . import input_registry 23 | 24 | LOG = logging.getLogger(__name__) 25 | 26 | 27 | @input_registry.register("repeated-dates") 28 | class RepeatedDatesInput(Input): 29 | """This class is identical to the one used to in anemoi-datasets/create 30 | It uses a source of constants (e.g. 
a source containing the bathymetry) 31 | available only for a given date and returns its content whever date 32 | is requested by the runner 33 | """ 34 | 35 | trace_name = "repeated dates" 36 | 37 | def __init__(self, context: Context, *, source: str, mode: str = "constant", **kwargs: Any) -> None: 38 | 39 | self.date = kwargs.pop("date", None) 40 | assert self.date is not None, "date must be provided for repeated-dates input" 41 | 42 | self.date = as_datetime(self.date) 43 | 44 | self.mode = mode 45 | 46 | assert self.mode in ["constant"], f"Unknown mode {self.mode}" 47 | 48 | super().__init__(context, **kwargs) 49 | self.source = create_input(context, source, variables=self.variables, purpose=self.purpose) 50 | 51 | def create_input_state(self, *, date: Date | None, **kwargs) -> State: 52 | """Create the input state for the repeated-dates input. 53 | 54 | Parameters 55 | ---------- 56 | date : Date or None 57 | The date for the input state. 58 | **kwargs : Any 59 | Additional keyword arguments. 60 | 61 | Returns 62 | ------- 63 | State 64 | The created input state. 65 | """ 66 | 67 | # TODO: Consider caching the result 68 | state = self.source.create_input_state(date=self.date, **kwargs) 69 | state["_input"] = self 70 | state["date"] = date 71 | return state 72 | 73 | def load_forcings_state(self, *, dates: list[Date], current_state: State) -> State: 74 | """Load the forcings state for repeated dates input. 75 | 76 | Parameters 77 | ---------- 78 | dates : list of Date 79 | The list of dates for which to repeat the fields. 80 | current_state : State 81 | The current state to use for loading. 82 | 83 | Returns 84 | ------- 85 | State 86 | The loaded and repeated forcings state. 
87 | """ 88 | assert len(dates) > 0, "dates must not be empty for repeated dates input" 89 | 90 | state = self.source.load_forcings_state( 91 | dates=[self.date], 92 | current_state=current_state, 93 | ) 94 | 95 | fields = state["fields"] 96 | 97 | for name, data in fields.items(): 98 | assert len(data.shape) == 2, data.shape 99 | assert data.shape[0] == 1, data.shape 100 | fields[name] = data.repeat(len(dates), axis=0) 101 | 102 | state["date"] = dates[-1] 103 | 104 | state["_input"] = self 105 | 106 | return state 107 | -------------------------------------------------------------------------------- /docs/inference/input-types.rst: -------------------------------------------------------------------------------- 1 | .. _input-types: 2 | 3 | ############# 4 | Input types 5 | ############# 6 | 7 | Anemoi-inference allows you to specify different inputs for different 8 | types of variables. Variables are classified using these categories: 9 | 10 | - ``computed``: Variables that are calculated during the model run. 11 | - ``forcing``: Variables that are imposed on the model from external 12 | sources. 13 | - ``prognostic``: Variables that are both input (initial conditions) 14 | and output. 15 | - ``diagnostic``: Variables that are only output, derived from other 16 | variables. 17 | - ``constant``: Variables that remain unchanged throughout the 18 | simulation, such as static fields or parameters. 19 | - ``accumulation``: Variables that represent accumulated quantities 20 | over time, such as total precipitation. 21 | 22 | To find out which category a variable belongs to, you can use the 23 | :ref:`anemoi-inference inspect ` command: 24 | 25 | .. code:: console 26 | 27 | % anemoi-inference inspect --variables checkpoint.ckpt 28 | 29 | The output will show something like: 30 | 31 | ..
code:: console 32 | 33 | 100u => diagnostic 34 | 100v => diagnostic 35 | 10u => prognostic 36 | 10v => prognostic 37 | 2d => prognostic 38 | 2t => prognostic 39 | cl => constant, forcing 40 | cos_julian_day => computed, forcing 41 | cos_latitude => computed, constant, forcing 42 | cos_local_time => computed, forcing 43 | cos_longitude => computed, constant, forcing 44 | cp => accumulation, diagnostic 45 | cvh => constant, forcing 46 | cvl => constant, forcing 47 | hcc => diagnostic 48 | ... 49 | ro => accumulation, diagnostic 50 | sdor => constant, forcing 51 | sf => accumulation, diagnostic 52 | sin_julian_day => computed, forcing 53 | ... 54 | tcw => prognostic 55 | tp => accumulation, diagnostic 56 | tvh => constant, forcing 57 | tvl => constant, forcing 58 | ... 59 | w_925 => prognostic 60 | z => constant, forcing 61 | z_100 => prognostic 62 | z_1000 => prognostic 63 | ... 64 | 65 | As shown above, some variables can belong to multiple categories. 66 | 67 | The runner now has three :ref:`inputs `, managing different 68 | categories of variables: 69 | 70 | - ``input``: used to fetch the ``prognostics`` for the initial 71 | conditions (e.g. 2t in an atmospheric model). 72 | 73 | - ``constant_forcings``: used to fetch the constants for the initial 74 | conditions (e.g. lsm or orography). These are the variables that have 75 | ``constant`` **and** ``forcing`` in their category, and are not 76 | ``computed``, ``prognostic`` or ``diagnostic``. 77 | 78 | 79 | - ``dynamic_forcings``: used to fetch the forcings needed by some 80 | models throughout the length of the forecast (e.g. atmospheric fields 81 | used as forcing to an ocean model). These are the variables that have 82 | ``forcing`` in their category, and are not ``computed`` or 83 | ``constant``.
84 | 85 | To ensure backward compatibility, unless given explicitly in the config, 86 | ``constant_forcings`` and ``dynamic_forcings`` both fall back to the 87 | ``input`` entry. 88 | 89 | An ``initial_state_categories`` configuration option lets the user select 90 | a list of categories of variables to be written to the output if 91 | ``write_initial_conditions`` is ``true``. For backward compatibility, it 92 | defaults to ``prognostics`` and ``constant_forcings``. In that case, the 93 | ``input`` will also fetch the constants and the forcing fields. 94 | -------------------------------------------------------------------------------- /tests/unit/test_config.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pytest import MonkeyPatch 3 | 4 | from anemoi.inference.config.run import RunConfiguration 5 | from anemoi.inference.testing import files_for_tests 6 | 7 | 8 | def test_load_config() -> None: 9 | """Test loading the configuration""" 10 | RunConfiguration.load(files_for_tests("unit/configs/simple.yaml")) 11 | 12 | 13 | def test_interpolation(monkeypatch: MonkeyPatch) -> None: 14 | """Test loading the configuration with some OmegaConf interpolations""" 15 | monkeypatch.setenv("TEST", "foo") 16 | config = RunConfiguration.load(files_for_tests("unit/configs/interpolation.yaml")) 17 | assert config.name == "foo" 18 | assert config.name == config.description 19 | 20 | 21 | def test_config_dotlist_override() -> None: 22 | """Test overriding the configuration with some dotlist parameters""" 23 | config = RunConfiguration.load( 24 | files_for_tests("unit/configs/mwd.yaml"), 25 | overrides=[ 26 | "runner=testing", 27 | "device=cpu", 28 | "input=dummy", 29 | "post_processors.0.backward_transform_filter=test", 30 | ], 31 | ) 32 | assert config.post_processors is not None and len(config.post_processors) == 1 33 | assert isinstance(config.post_processors[0], dict) 34 | assert
config.post_processors[0]["backward_transform_filter"] == "test" 35 | 36 | 37 | def test_config_dotlist_override_append() -> None: 38 | """Test overriding with an additional parameter""" 39 | config = RunConfiguration.load( 40 | files_for_tests("unit/configs/mwd.yaml"), 41 | overrides=[ 42 | "runner=testing", 43 | "device=cpu", 44 | "input=dummy", 45 | "post_processors.1.backward_transform_filter=test", 46 | ], 47 | ) 48 | assert config.post_processors is not None and len(config.post_processors) == 2 49 | assert isinstance(config.post_processors[1], dict) 50 | assert config.post_processors[1]["backward_transform_filter"] == "test" 51 | 52 | 53 | def test_config_dotlist_override_append_end() -> None: 54 | """Test overriding with an additional parameter""" 55 | config = RunConfiguration.load( 56 | files_for_tests("unit/configs/simple.yaml"), 57 | overrides=[ 58 | "pre_processors=[]", 59 | "pre_processors.0=test", 60 | ], 61 | ) 62 | assert isinstance(config.pre_processors, list) and len(config.pre_processors) == 1 63 | assert config.pre_processors[0] == "test" 64 | 65 | 66 | def test_config_dotlist_override_add_new_non_empty_dict() -> None: 67 | """Test overriding with an additional parameter""" 68 | config = RunConfiguration.load( 69 | files_for_tests("unit/configs/simple.yaml"), 70 | overrides=["input={grib: test}"], 71 | ) 72 | assert isinstance(config.input, dict) and "grib" in config.input 73 | assert config.input["grib"] == "test" 74 | 75 | 76 | def test_config_dotlist_override_add_new_non_empty_list() -> None: 77 | """Test overriding with an additional parameter""" 78 | config = RunConfiguration.load( 79 | files_for_tests("unit/configs/simple.yaml"), 80 | overrides=["pre_processors=[test]"], 81 | ) 82 | assert isinstance(config.pre_processors, list) and len(config.pre_processors) == 1 83 | assert config.pre_processors[0] == "test" 84 | 85 | 86 | def test_config_dotlist_override_index_error() -> None: 87 | """Test failing with index error in dotlist argument""" 
88 | with pytest.raises(IndexError): 89 | RunConfiguration.load( 90 | files_for_tests("unit/configs/mwd.yaml"), 91 | overrides=[ 92 | "runner=testing", 93 | "device=cpu", 94 | "input=dummy", 95 | "post_processors.2.backward_transform_filter=test", 96 | ], 97 | ) 98 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | 10 | [build-system] 11 | build-backend = "setuptools.build_meta" 12 | 13 | requires = [ "setuptools>=60", "setuptools-scm>=8" ] 14 | 15 | [project] 16 | name = "anemoi-inference" 17 | 18 | description = "A package to run inference from data-driven forecasts weather models." 
19 | readme = "README.md" 20 | 21 | keywords = [ "ai", "inference", "tools" ] 22 | 23 | license = { file = "LICENSE" } 24 | authors = [ 25 | { name = "European Centre for Medium-Range Weather Forecasts (ECMWF)", email = "software.support@ecmwf.int" }, 26 | ] 27 | 28 | requires-python = ">=3.10" 29 | 30 | classifiers = [ 31 | "Development Status :: 4 - Beta", 32 | "Intended Audience :: Developers", 33 | "License :: OSI Approved :: Apache Software License", 34 | "Operating System :: OS Independent", 35 | "Programming Language :: Python :: 3 :: Only", 36 | "Programming Language :: Python :: 3.10", 37 | "Programming Language :: Python :: 3.11", 38 | "Programming Language :: Python :: 3.12", 39 | "Programming Language :: Python :: Implementation :: CPython", 40 | "Programming Language :: Python :: Implementation :: PyPy", 41 | ] 42 | 43 | dynamic = [ "version" ] 44 | dependencies = [ 45 | "anemoi-transform>0.1.13", 46 | "anemoi-utils[provenance,text]>=0.4.32", 47 | "aniso8601", 48 | "anytree", 49 | "deprecation", 50 | "earthkit-data>=0.12.4", 51 | "eccodes>=2.38.3", 52 | "numpy", 53 | "omegaconf>=2.2,<2.4", 54 | "packaging", 55 | "pydantic", 56 | "pyyaml", 57 | "rich", 58 | "semantic-version", 59 | "torch", 60 | ] 61 | 62 | optional-dependencies.all = [ 63 | "anemoi-datasets", 64 | "anemoi-inference[huggingface,plot,tests,zarr]", 65 | "anemoi-utils[all]>=0.4.32", 66 | ] 67 | optional-dependencies.cosmo = [ 68 | "eccodes<=2.39.1", 69 | "eccodes-cosmo-resources-python", 70 | ] 71 | 72 | optional-dependencies.dev = [ "anemoi-inference[all,docs,plugin,tests]" ] 73 | 74 | optional-dependencies.docs = [ 75 | "autodoc-pydantic", 76 | "nbsphinx", 77 | "pandoc", 78 | "rstfmt", 79 | "sphinx<8.2", 80 | "sphinx-argparse<0.5", 81 | "sphinx-rtd-theme", 82 | ] 83 | 84 | optional-dependencies.huggingface = [ "huggingface-hub" ] 85 | optional-dependencies.plot = [ "earthkit-plots" ] 86 | 87 | optional-dependencies.plugin = [ "ai-models>=0.7", "tqdm" ] 88 | optional-dependencies.tests = 
[ 89 | "anemoi-datasets[all]", 90 | "anemoi-inference[all]", 91 | "hypothesis", 92 | "pytest", 93 | "pytest-mock", 94 | ] 95 | optional-dependencies.zarr = [ "zarr" ] 96 | 97 | urls.Documentation = "https://anemoi-inference.readthedocs.io/" 98 | urls.Homepage = "https://github.com/ecmwf/anemoi-inference/" 99 | urls.Issues = "https://github.com/ecmwf/anemoi-inference/issues" 100 | urls.Repository = "https://github.com/ecmwf/anemoi-inference/" 101 | scripts.anemoi-inference = "anemoi.inference.__main__:main" 102 | 103 | entry-points."ai_models.model".anemoi = "anemoi.inference.plugin:AIModelPlugin" 104 | 105 | [tool.setuptools.package-data] 106 | "anemoi.inference.grib.templates" = [ "*.yaml" ] 107 | 108 | [tool.setuptools.packages.find] 109 | where = [ "src" ] 110 | 111 | [tool.setuptools_scm] 112 | version_file = "src/anemoi/inference/_version.py" 113 | 114 | [tool.pytest.ini_options] 115 | markers = [ 116 | "skip_on_hpc: mark a test that should not be run on HPC", 117 | ] 118 | testpaths = "tests" 119 | 120 | [tool.mypy] 121 | exclude = [ "docs/" ] 122 | strict = false 123 | ignore_missing_imports = true 124 | allow_redefinition = true 125 | -------------------------------------------------------------------------------- /src/anemoi/inference/checks.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2024 Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 
9 | 10 | 11 | import datetime 12 | import logging 13 | import sys 14 | from collections import defaultdict 15 | from typing import Any 16 | 17 | from anemoi.utils.humanize import plural 18 | from earthkit.data.utils.dates import to_datetime 19 | 20 | from anemoi.inference.checkpoint import Checkpoint 21 | 22 | LOG = logging.getLogger(__name__) 23 | 24 | 25 | def check_data( 26 | title: str, data: Any, variables: list[str], dates: list[datetime.datetime], checkpoint: Checkpoint 27 | ) -> None: 28 | """Check if the data matches the expected number of fields based on variables and dates. 29 | 30 | Parameters 31 | ---------- 32 | title : str 33 | The title for the data check. 34 | data : Any 35 | The data to be checked. 36 | variables : List[str] 37 | The list of variable names. 38 | dates : List[datetime.datetime] 39 | The list of dates. 40 | checkpoint : Checkpoint 41 | The checkpoint 42 | 43 | Raises 44 | ------ 45 | ValueError 46 | If the data does not match the expected number of fields. 
47 | """ 48 | expected = len(variables) * len(dates) 49 | 50 | if len(data) != expected: 51 | 52 | from rich.console import Console 53 | from rich.table import Table 54 | 55 | table = Table(title=title) 56 | console = Console(file=sys.stderr) 57 | 58 | LOG.error("Data check failed for %s", title) 59 | 60 | nvars = plural(len(variables), "variable") 61 | ndates = plural(len(dates), "date") 62 | nfields = plural(expected, "field") 63 | msg = f"Expected ({nvars}) x ({ndates}) = {nfields}, got {len(data)}" 64 | LOG.error("%s", msg) 65 | 66 | table.add_column("Name", justify="left") 67 | 68 | dates = sorted(dates) 69 | 70 | for d in dates: 71 | table.add_column(d.isoformat(), justify="center") 72 | 73 | table.add_column("Categories") 74 | 75 | avail = defaultdict(set) 76 | duplicates = defaultdict(set) 77 | for field in data: 78 | name, date = field.metadata("name"), to_datetime(field.metadata("valid_datetime")) 79 | if date in avail[name]: 80 | duplicates[name].add(date) 81 | LOG.warning( 82 | "Duplicate field for variable '%s' at date %s in %s", 83 | name, 84 | date.isoformat(), 85 | title, 86 | ) 87 | avail[name].add(date) 88 | 89 | variable_categories = checkpoint.variable_categories() 90 | for name in variables: 91 | row = [name] 92 | for d in dates: 93 | if d not in avail[name]: 94 | row.append("❌") 95 | else: 96 | if d in duplicates[name]: 97 | row.append("⚠️") 98 | else: 99 | row.append("✅") 100 | 101 | if name in variable_categories: 102 | cats = ", ".join(sorted(variable_categories[name])) 103 | row.append(cats) 104 | else: 105 | row.append("N/A") 106 | 107 | table.add_row(*row) 108 | 109 | console.print() 110 | console.print(table) 111 | console.print() 112 | 113 | raise ValueError(msg) 114 | 115 | assert len(data) == len(variables) * len(dates) 116 | -------------------------------------------------------------------------------- /src/anemoi/inference/post_processors/assign.py: -------------------------------------------------------------------------------- 
1 | # (C) Copyright 2025- Anemoi contributors. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # 6 | # In applying this licence, ECMWF does not waive the privileges and immunities 7 | # granted to it by virtue of its status as an intergovernmental organisation 8 | # nor does it submit to any jurisdiction. 9 | from pathlib import Path 10 | 11 | import numpy as np 12 | 13 | from anemoi.inference.context import Context 14 | from anemoi.inference.decorators import main_argument 15 | from anemoi.inference.types import State 16 | 17 | from ..processor import Processor 18 | from . import post_processor_registry 19 | 20 | 21 | @post_processor_registry.register("assign_mask") 22 | @main_argument("mask") 23 | class AssignMask(Processor): 24 | """Assign a mask to the state. 25 | 26 | This processor assigns the state to a larger array using a mask to 27 | determine the region of assignment. The mask can be provided as a 28 | boolean numpy array or as a path to a file containing the mask. 29 | This processor can be seen as the opposite of "extract_mask". 30 | Instead of extracting a smaller area from a larger one, 31 | it assigns a state to a larger area using a mask. 32 | 33 | Parameters 34 | ---------- 35 | context : Context 36 | The context containing the checkpoint and supporting arrays. 37 | mask : str 38 | The name of the mask supporting array or a path to a file containing the mask. 39 | fill_value : float, optional 40 | The value to fill the non-assigned area, by default NaN. 
41 | """ 42 | 43 | def __init__(self, context: Context, mask: str, fill_value: float = float("NaN")) -> None: 44 | super().__init__(context) 45 | 46 | self._maskname = mask 47 | 48 | if Path(mask).is_file(): 49 | mask = np.load(mask) 50 | else: 51 | mask = context.checkpoint.load_supporting_array(mask) 52 | 53 | if not isinstance(mask, np.ndarray) or mask.dtype != bool: 54 | raise ValueError( 55 | "Expected the mask to be a boolean numpy array. " f"Got {type(mask)} with dtype {mask.dtype}." 56 | ) 57 | 58 | self.indexer = mask 59 | self.npoints = np.sum(mask) 60 | self.fill_value = fill_value 61 | 62 | def process(self, state: State) -> State: 63 | """Assign the state to the mask. 64 | 65 | Parameters 66 | ---------- 67 | state : State 68 | The state dictionary. 69 | 70 | Returns 71 | ------- 72 | State 73 | The masked state dictionary. 74 | """ 75 | state = state.copy() 76 | state["fields"] = state["fields"].copy() 77 | 78 | state["latitudes"] = self._assign_mask(state["latitudes"]) 79 | state["longitudes"] = self._assign_mask(state["longitudes"]) 80 | for field in state["fields"]: 81 | state["fields"][field] = self._assign_mask(state["fields"][field]) 82 | 83 | return state 84 | 85 | def _assign_mask(self, array: np.ndarray): 86 | """Logic to assign the array to the mask.""" 87 | shape = array.shape[:-1] + self.indexer.shape 88 | res = np.full(shape, self.fill_value, dtype=array.dtype) 89 | if array.ndim == 1: 90 | res[self.indexer] = array 91 | else: 92 | res[..., self.indexer] = array 93 | return res 94 | 95 | def __repr__(self) -> str: 96 | """Return a string representation of the AssignMask object. 97 | 98 | Returns 99 | ------- 100 | str 101 | A string representation of the AssignMask object. 102 | """ 103 | return f"AssignMask(mask={self._maskname}, points={self.npoints}/{self.indexer.size}, fill_value={self.fill_value})" 104 | --------------------------------------------------------------------------------