├── tests ├── move │ ├── __init__.py │ └── array │ │ ├── __init__.py │ │ ├── test_downstream.py │ │ └── test_upstream.py ├── catchments │ ├── __init__.py │ └── array │ │ ├── __init__.py │ │ └── test_find.py ├── distance │ ├── __init__.py │ └── array │ │ ├── __init__.py │ │ ├── test_max.py │ │ └── test_min.py ├── length │ ├── __init__.py │ └── array │ │ ├── __init__.py │ │ ├── test_max.py │ │ └── test_min.py ├── upstream │ ├── __init__.py │ └── array │ │ ├── __init__.py │ │ ├── test_max.py │ │ ├── test_min.py │ │ ├── test_mean.py │ │ └── test_sum.py ├── river_network │ ├── __init__.py │ └── test_load.py ├── utils.py ├── _test_inputs │ ├── subnetwork.py │ ├── movement.py │ ├── river_networks.py │ ├── catchment.py │ ├── readers.py │ └── distance.py └── conftest.py ├── src └── earthkit │ └── hydro │ ├── _backends │ ├── __init__.py │ ├── mlx_backend.py │ ├── find.py │ ├── jax_backend.py │ ├── torch_backend.py │ ├── array_backend.py │ ├── cupy_backend.py │ ├── numpy_backend.py │ └── tensorflow_backend.py │ ├── _utils │ ├── __init__.py │ ├── decorators │ │ ├── __init__.py │ │ ├── array_backend.py │ │ ├── masking.py │ │ └── xarray.py │ ├── coords.py │ ├── locations.py │ └── readers.py │ ├── data_structures │ ├── __init__.py │ ├── _network_storage.py │ └── _network.py │ ├── _core │ ├── __init__.py │ ├── metrics.py │ ├── flow.py │ ├── _find.py │ ├── _accumulate.py │ ├── online.py │ ├── _move.py │ ├── move.py │ └── accumulate.py │ ├── subnetwork │ ├── __init__.py │ └── _toplevel.py │ ├── move │ ├── array │ │ ├── __init__.py │ │ ├── __operations.py │ │ ├── _operations.py │ │ └── _toplevel.py │ ├── __init__.py │ └── _toplevel.py │ ├── river_network │ ├── __init__.py │ └── _cache.py │ ├── length │ ├── array │ │ ├── __init__.py │ │ ├── _operations.py │ │ └── __operations.py │ └── __init__.py │ ├── distance │ ├── array │ │ ├── __init__.py │ │ ├── __operations.py │ │ ├── _operations.py │ │ └── _toplevel.py │ └── __init__.py │ ├── downstream │ ├── array │ │ ├── __init__.py │ │ └── 
_operations.py │ └── __init__.py │ ├── upstream │ ├── array │ │ ├── __init__.py │ │ └── _operations.py │ └── __init__.py │ ├── catchments │ ├── array │ │ ├── __init__.py │ │ ├── _operations.py │ │ └── __operations.py │ ├── __init__.py │ ├── _operations.py │ └── _xarray.py │ ├── _readers │ ├── __init__.py │ ├── _grit.py │ └── group_labels.py │ └── __init__.py ├── docs ├── images │ ├── glofas.png │ ├── accuflux.gif │ ├── catchment.gif │ ├── subcatchment.gif │ ├── distance_length.png │ ├── earthkit_example.png │ ├── array_backends_with_xr.png │ └── raster_vector_networks.jpg ├── source │ ├── references.rst │ ├── tutorials │ │ ├── index.rst │ │ └── loading_river_networks.ipynb │ ├── references.bib │ ├── contributing.rst │ ├── userguide │ │ ├── index.rst │ │ ├── specifying_locations.rst │ │ ├── earthkit.rst │ │ ├── catchment_delineation.rst │ │ ├── catchment_statistics.rst │ │ ├── loading_a_river_network.rst │ │ ├── subnetwork_creation.rst │ │ ├── raster_vector_inputs.rst │ │ ├── xarray_array_backend.rst │ │ ├── flow_accumulations.rst │ │ ├── distance_length_calculations.rst │ │ └── pcraster.rst │ ├── index.rst │ └── conf.py ├── Makefile └── clean_autodocs.py ├── pytest.ini ├── .github ├── ci-hpc-config.yml ├── ci-hpc-gpu-config.yml └── workflows │ ├── label-public-pr.yml │ ├── ci.yml │ ├── nightly-hpc-gpu.yml │ ├── cd.yml │ ├── test-pypi.yml │ └── downstream-ci.yml ├── .gitignore ├── .readthedocs.yml ├── .flake8 ├── Cargo.toml ├── setup.py ├── .pre-commit-config.yaml ├── rust └── lib.rs ├── pyproject.toml └── README.md /tests/move/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/catchments/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/distance/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/length/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/move/array/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/upstream/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/distance/array/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/length/array/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/river_network/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/upstream/array/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/catchments/array/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_backends/__init__.py: -------------------------------------------------------------------------------- 1 | from .find import get_array_backend 2 | 
-------------------------------------------------------------------------------- /docs/images/glofas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecmwf/earthkit-hydro/HEAD/docs/images/glofas.png -------------------------------------------------------------------------------- /docs/source/references.rst: -------------------------------------------------------------------------------- 1 | Bibliography 2 | ============ 3 | 4 | .. bibliography:: 5 | :cited: 6 | -------------------------------------------------------------------------------- /docs/images/accuflux.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecmwf/earthkit-hydro/HEAD/docs/images/accuflux.gif -------------------------------------------------------------------------------- /docs/images/catchment.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecmwf/earthkit-hydro/HEAD/docs/images/catchment.gif -------------------------------------------------------------------------------- /docs/images/subcatchment.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecmwf/earthkit-hydro/HEAD/docs/images/subcatchment.gif -------------------------------------------------------------------------------- /docs/images/distance_length.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecmwf/earthkit-hydro/HEAD/docs/images/distance_length.png -------------------------------------------------------------------------------- /docs/images/earthkit_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecmwf/earthkit-hydro/HEAD/docs/images/earthkit_example.png 
-------------------------------------------------------------------------------- /src/earthkit/hydro/_utils/__init__.py: -------------------------------------------------------------------------------- 1 | import earthkit.hydro._utils.decorators 2 | import earthkit.hydro._utils.readers 3 | -------------------------------------------------------------------------------- /src/earthkit/hydro/data_structures/__init__.py: -------------------------------------------------------------------------------- 1 | from ._network import RiverNetwork 2 | 3 | __all__ = ["RiverNetwork"] 4 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_core/__init__.py: -------------------------------------------------------------------------------- 1 | from .accumulate import flow_downstream, flow_upstream 2 | from .flow import propagate 3 | -------------------------------------------------------------------------------- /src/earthkit/hydro/subnetwork/__init__.py: -------------------------------------------------------------------------------- 1 | from ._toplevel import crop, from_mask 2 | 3 | __all__ = ["from_mask", "crop"] 4 | -------------------------------------------------------------------------------- /docs/images/array_backends_with_xr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecmwf/earthkit-hydro/HEAD/docs/images/array_backends_with_xr.png -------------------------------------------------------------------------------- /docs/images/raster_vector_networks.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ecmwf/earthkit-hydro/HEAD/docs/images/raster_vector_networks.jpg -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | minversion = 7.0 3 | addopts = -v 
"--pdbcls=IPython.terminal.debugger:Pdb" 4 | testpaths = 5 | tests 6 | -------------------------------------------------------------------------------- /src/earthkit/hydro/move/array/__init__.py: -------------------------------------------------------------------------------- 1 | from ._toplevel import downstream, upstream 2 | 3 | __all__ = ["downstream", "upstream"] 4 | -------------------------------------------------------------------------------- /src/earthkit/hydro/river_network/__init__.py: -------------------------------------------------------------------------------- 1 | from ._river_network import available, create, load 2 | 3 | __all__ = ["available", "create", "load"] 4 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_utils/decorators/__init__.py: -------------------------------------------------------------------------------- 1 | from .array_backend import multi_backend 2 | from .masking import mask 3 | from .xarray import xarray 4 | -------------------------------------------------------------------------------- /src/earthkit/hydro/length/array/__init__.py: -------------------------------------------------------------------------------- 1 | from ._toplevel import max, min, to_sink, to_source 2 | 3 | __all__ = ["max", "min", "to_sink", "to_source"] 4 | -------------------------------------------------------------------------------- /src/earthkit/hydro/distance/array/__init__.py: -------------------------------------------------------------------------------- 1 | from ._toplevel import max, min, to_sink, to_source 2 | 3 | __all__ = ["max", "min", "to_sink", "to_source"] 4 | -------------------------------------------------------------------------------- /.github/ci-hpc-config.yml: -------------------------------------------------------------------------------- 1 | build: 2 | modules: 3 | - rust 4 | python_dependencies: 5 | - ecmwf/earthkit-data@develop 6 | - ecmwf/earthkit-utils@develop 7 | 
-------------------------------------------------------------------------------- /src/earthkit/hydro/downstream/array/__init__.py: -------------------------------------------------------------------------------- 1 | from ._toplevel import max, mean, min, std, sum, var 2 | 3 | __all__ = ["max", "mean", "min", "std", "sum", "var"] 4 | -------------------------------------------------------------------------------- /src/earthkit/hydro/upstream/array/__init__.py: -------------------------------------------------------------------------------- 1 | from ._toplevel import max, mean, min, std, sum, var 2 | 3 | __all__ = ["max", "mean", "min", "std", "sum", "var"] 4 | -------------------------------------------------------------------------------- /src/earthkit/hydro/catchments/array/__init__.py: -------------------------------------------------------------------------------- 1 | from ._toplevel import find, max, mean, min, std, sum, var 2 | 3 | __all__ = ["find", "max", "mean", "min", "std", "sum", "var"] 4 | -------------------------------------------------------------------------------- /src/earthkit/hydro/move/__init__.py: -------------------------------------------------------------------------------- 1 | import earthkit.hydro.move.array 2 | 3 | from ._toplevel import downstream, upstream 4 | 5 | __all__ = ["array", "downstream", "upstream"] 6 | -------------------------------------------------------------------------------- /src/earthkit/hydro/distance/__init__.py: -------------------------------------------------------------------------------- 1 | import earthkit.hydro.distance.array 2 | 3 | from ._toplevel import max, min, to_sink, to_source 4 | 5 | __all__ = ["array", "max", "min", "to_sink", "to_source"] 6 | -------------------------------------------------------------------------------- /src/earthkit/hydro/length/__init__.py: -------------------------------------------------------------------------------- 1 | import earthkit.hydro.length.array 2 | 3 | from 
._toplevel import max, min, to_sink, to_source 4 | 5 | __all__ = ["array", "max", "min", "to_sink", "to_source"] 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .ipynb_checkpoints/ 3 | *.egg-info 4 | _version.py 5 | target/ 6 | dist/ 7 | build/ 8 | Cargo.lock 9 | .DS_Store 10 | *.so 11 | autodocs/ 12 | 13 | *.joblib 14 | *.nc 15 | -------------------------------------------------------------------------------- /src/earthkit/hydro/upstream/__init__.py: -------------------------------------------------------------------------------- 1 | import earthkit.hydro.upstream.array 2 | 3 | from ._toplevel import max, mean, min, std, sum, var 4 | 5 | __all__ = ["array", "max", "mean", "min", "std", "sum", "var"] 6 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_readers/__init__.py: -------------------------------------------------------------------------------- 1 | from .readers import ( 2 | find_main_var, 3 | from_cama_downxy, 4 | from_cama_nextxy, 5 | from_d8, 6 | from_grit, 7 | import_earthkit_or_prompt_install, 8 | ) 9 | -------------------------------------------------------------------------------- /src/earthkit/hydro/downstream/__init__.py: -------------------------------------------------------------------------------- 1 | import earthkit.hydro.downstream.array as array 2 | 3 | from ._toplevel import max, mean, min, std, sum, var 4 | 5 | __all__ = ["array", "max", "mean", "min", "std", "sum", "var"] 6 | -------------------------------------------------------------------------------- /src/earthkit/hydro/catchments/__init__.py: -------------------------------------------------------------------------------- 1 | import earthkit.hydro.catchments.array 2 | 3 | from ._toplevel import find, max, mean, min, std, sum, var 4 | 5 | __all__ = ["array", "find", "max", "mean", 
"min", "std", "sum", "var"] 6 | -------------------------------------------------------------------------------- /tests/river_network/test_load.py: -------------------------------------------------------------------------------- 1 | import earthkit.hydro as ekh 2 | 3 | 4 | def test_load(): 5 | net = ekh.river_network.load("efas", "5", use_cache=False) 6 | assert net.n_nodes == 7446075 7 | assert net.n_edges == 7353055 8 | -------------------------------------------------------------------------------- /.github/ci-hpc-gpu-config.yml: -------------------------------------------------------------------------------- 1 | build: 2 | modules: 3 | - rust 4 | python_dependencies: 5 | - ecmwf/earthkit-data@develop 6 | - ecmwf/earthkit-utils@develop 7 | queue: ng 8 | gpus: 1 9 | toml_opt_dep_sections: all,tests 10 | -------------------------------------------------------------------------------- /.github/workflows/label-public-pr.yml: -------------------------------------------------------------------------------- 1 | name: label-public-pr 2 | 3 | on: 4 | pull_request_target: 5 | types: [opened, synchronize] 6 | 7 | jobs: 8 | label: 9 | uses: ecmwf/reusable-workflows/.github/workflows/label-pr.yml@v2 10 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.11" 7 | jobs: 8 | pre_build: 9 | - rm -rf _build 10 | - rm -rf docs/_build 11 | - cd docs && make rst 12 | 13 | python: 14 | install: 15 | - method: pip 16 | path: . 
17 | extra_requirements: 18 | - docs 19 | 20 | sphinx: 21 | configuration: docs/source/conf.py 22 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # Maximum line length (matches black); note E501 is ignored below, so it is not enforced 3 | max-line-length = 88 4 | ignore = E203, W503, E501 5 | 6 | # Per-file ignore settings 7 | per-file-ignores = 8 | # Ignore F401 (unused import) in __init__.py files 9 | __init__.py: F401 10 | 11 | # Ignore specific errors in the tests folder 12 | # - F405: Name may be undefined, or defined from star imports 13 | # - F403: Wildcard (star) import used; unable to detect undefined names 14 | tests/*: F405,F403 15 | -------------------------------------------------------------------------------- /docs/source/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ========= 3 | 4 | Basics: 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | loading_river_networks 10 | xarray_array 11 | gridded_masked 12 | array_backend 13 | 14 | Operations: 15 | 16 | ..
toctree:: 17 | :maxdepth: 1 18 | 19 | computing_accumulations 20 | finding_catchments 21 | catchment_statistics 22 | distance_length 23 | creating_subnetworks 24 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def extend_array(arr, extra_shape): 5 | current_shape = arr.shape 6 | new_shape = (*extra_shape, *current_shape) 7 | extended_array = np.broadcast_to(arr, new_shape) 8 | return extended_array 9 | 10 | 11 | def convert_to_2d(river_network, array, fill_value): 12 | field = np.full(river_network.mask.shape, fill_value=fill_value, dtype=array.dtype) 13 | field[river_network.mask] = array 14 | return field 15 | -------------------------------------------------------------------------------- /tests/_test_inputs/subnetwork.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | mask_2 = np.array( 4 | [ 5 | True, 6 | True, 7 | True, 8 | True, 9 | True, 10 | True, 11 | False, 12 | True, 13 | True, 14 | True, 15 | True, 16 | True, 17 | True, 18 | False, 19 | True, 20 | True, 21 | ] 22 | ) 23 | 24 | 25 | masked_unit_accuflux_2 = np.array([2, 1, 2, 1, 1, 2, 3, 1, 1, 3, 6, 1, 1, 2]) 26 | -------------------------------------------------------------------------------- /tests/_test_inputs/movement.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | upstream_1 = np.array( 4 | [0, 0, 0, 0, 0, 1, 2, 7, 5, 0, 6, 7, 31, 25, 0, 0, 70, 19, 20, 0], dtype=float 5 | ) 6 | 7 | 8 | upstream_2 = np.array( 9 | [13, 0, 2, 0, 0, 5, 12, 3, 0, 0, 13, 24, 0, 30, 0, 15], dtype=float 10 | ) 11 | 12 | 13 | downstream_1 = np.array( 14 | [6, 7, 8, 8, 9, 11, 12, 13, 13, 14, 17, 17, 17, 13, 14, 17, 0, 17, 18, 19], 15 | dtype=float, 16 | ) 17 | 18 | 19 | downstream_2 = np.array( 20 | [0, 3, 8, 0, 6, 11, 11, 12, 
14, 14, 14, 7, 1, 0, 16, 12], dtype=float 21 | ) 22 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "earthkit-hydro" 3 | version = "0.0.0" # placeholder, will be overwritten 4 | edition = "2021" 5 | 6 | [dependencies] 7 | pyo3 = { version = "0.26", features = ["extension-module"] } 8 | numpy = "0.26" 9 | rayon = "1.7" 10 | fixedbitset = "0.5" 11 | 12 | [lib] 13 | # See https://github.com/PyO3/pyo3 for details 14 | name = "_rust" # private module to be nested into Python package 15 | path = "rust/lib.rs" 16 | crate-type = ["cdylib"] 17 | 18 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 19 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_core/metrics.py: -------------------------------------------------------------------------------- 1 | def metrics_func_finder(metric, xp): 2 | 3 | class SumBased: 4 | func = xp.scatter_add 5 | base_val = 0 6 | 7 | class MaxBased: 8 | func = xp.scatter_max 9 | base_val = -xp.inf 10 | 11 | class MinBased: 12 | func = xp.scatter_min 13 | base_val = xp.inf 14 | 15 | metrics_dict = { 16 | "sum": SumBased, 17 | "mean": SumBased, 18 | "std": SumBased, 19 | "var": SumBased, 20 | "max": MaxBased, 21 | "min": MinBased, 22 | } 23 | return metrics_dict[metric] 24 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_core/flow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from earthkit.hydro.data_structures import RiverNetwork 4 | 5 | 6 | def propagate( 7 | river_network: RiverNetwork, 8 | groups: np.ndarray, 9 | field: np.ndarray, 10 | invert_graph: bool, 11 | operation, 12 | *args, 13 | **kwargs, 14 | ): 15 | if invert_graph: 16 | for uid, did, eid in groups[::-1]: 17 | 
field = operation(field, did, uid, eid, *args, **kwargs) 18 | else: 19 | for did, uid, eid in groups: 20 | field = operation(field, did, uid, eid, *args, **kwargs) 21 | 22 | return field 23 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: 4 | push: 5 | branches: [ "main", "develop" ] 6 | pull_request: 7 | branches: [ "main", "develop" ] 8 | 9 | jobs: 10 | qa-pre-commit: 11 | uses: ecmwf/reusable-workflows/.github/workflows/qa-precommit-run.yml@v2 12 | secrets: inherit 13 | 14 | # TODO: add back 15 | # test: 16 | # uses: ecmwf/reusable-workflows/.github/workflows/qa-pytest-pyproject.yml@v2 17 | # secrets: inherit 18 | # with: 19 | # optional-dependencies: tests,dev 20 | 21 | # TODO: add back 22 | # qa-python: 23 | # uses: ecmwf-actions/reusable-workflows/.github/workflows/ci-python.yml@v2 24 | # secrets: inherit 25 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from earthkit.hydro._readers import from_cama_downxy, from_cama_nextxy, from_d8 4 | from earthkit.hydro.data_structures import RiverNetwork 5 | 6 | 7 | @pytest.fixture 8 | def river_network(request): 9 | river_network_format, flow_directions = request.param 10 | if river_network_format == "d8_ldd": 11 | river_network = from_d8(flow_directions) 12 | elif river_network_format == "cama_downxy": 13 | river_network = from_cama_downxy(*flow_directions) 14 | elif river_network_format == "cama_nextxy": 15 | river_network = from_cama_nextxy(*flow_directions) 16 | # TODO: add ESRI 17 | else: # fail loudly instead of hitting an UnboundLocalError on the return below 18 | raise ValueError(f"Unsupported river network format: {river_network_format!r}") 19 | 20 | return RiverNetwork(river_network) 21 | 
-------------------------------------------------------------------------------- /src/earthkit/hydro/move/array/__operations.py:
-------------------------------------------------------------------------------- 1 | from earthkit.hydro._core.move import calculate_move_metric 2 | 3 | 4 | def upstream(xp, river_network, field, node_weights, edge_weights, metric): 5 | return calculate_move_metric( 6 | xp, 7 | river_network, 8 | field, 9 | metric, 10 | node_weights, 11 | edge_weights, 12 | flow_direction="up", 13 | ) 14 | 15 | 16 | def downstream(xp, river_network, field, node_weights, edge_weights, metric): 17 | return calculate_move_metric( 18 | xp, 19 | river_network, 20 | field, 21 | metric, 22 | node_weights, 23 | edge_weights, 24 | flow_direction="down", 25 | ) 26 | -------------------------------------------------------------------------------- /.github/workflows/nightly-hpc-gpu.yml: -------------------------------------------------------------------------------- 1 | name: nightly-hpc-gpu 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | # Run at 03:20 UTC every day (on default branch) 7 | schedule: 8 | - cron: "20 03 * * *" 9 | 10 | jobs: 11 | test-hpc-gpu: 12 | runs-on: [self-hosted, linux, hpc] 13 | steps: 14 | - uses: ecmwf/reusable-workflows/ci-hpc@v2 15 | with: 16 | github_user: ${{ secrets.BUILD_PACKAGE_HPC_GITHUB_USER }} 17 | github_token: ${{ secrets.GH_REPO_READ_TOKEN }} 18 | troika_user: ${{ secrets.HPC_TEST_USER }} 19 | repository: ecmwf/earthkit-hydro@${{ github.event.pull_request.head.sha || github.sha }} 20 | build_config: .github/ci-hpc-gpu-config.yml 21 | python_version: "3.10" 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from setuptools import setup 4 | from setuptools_rust import RustExtension 5 | 6 | use_rust = int(os.environ.get("USE_RUST", "-1")) 7 | 8 | if use_rust == 0: # pure python 9 | # print("Building pure Python version.") 10 | rust_extensions = [] 11 | elif use_rust == 1: # rust extension 12 | # print("Building with 
rust bindings.") 13 | rust_extensions = [RustExtension("earthkit.hydro._rust", "Cargo.toml")] 14 | else: # (default) try rust extension, if fail fallback to python 15 | # print("Building with rust bindings, and if failing reverting to pure Python.") 16 | rust_extensions = [ 17 | RustExtension("earthkit.hydro._rust", "Cargo.toml", optional=True), 18 | ] 19 | 20 | setup(rust_extensions=rust_extensions) 21 | -------------------------------------------------------------------------------- /.github/workflows/cd.yml: -------------------------------------------------------------------------------- 1 | name: cd 2 | 3 | on: 4 | push: 5 | tags: 6 | - '**' 7 | 8 | jobs: 9 | pypi_binwheels: 10 | uses: ecmwf/reusable-workflows/.github/workflows/cd-pypi-binwheel.yml@v2 11 | secrets: inherit 12 | with: 13 | platforms: "['ubuntu-latest','macos-latest','windows-latest']" 14 | pyversions: "['39','310','311','312','313','314']" 15 | env_vars: | 16 | { 17 | "USE_RUST": "1", 18 | "SETUPTOOLS_RUST_CARGO_PROFILE": "release" 19 | } 20 | 21 | pypi_purepython: 22 | needs: pypi_binwheels 23 | uses: ecmwf/reusable-workflows/.github/workflows/cd-pypi.yml@v2 24 | secrets: inherit 25 | with: 26 | env_vars: | 27 | { 28 | "USE_RUST": "0", 29 | "SETUPTOOLS_RUST_CARGO_PROFILE": "release" 30 | } 31 | -------------------------------------------------------------------------------- /.github/workflows/test-pypi.yml: -------------------------------------------------------------------------------- 1 | name: test-cd 2 | 3 | on: 4 | pull_request: 5 | branches: [ "main" ] 6 | 7 | jobs: 8 | pypi_binwheels: 9 | uses: ecmwf/reusable-workflows/.github/workflows/cd-pypi-binwheel.yml@v2 10 | secrets: inherit 11 | with: 12 | platforms: "['ubuntu-latest','macos-latest','windows-latest']" 13 | pyversions: "['39','310','311','312','313','314']" 14 | testpypi: true 15 | env_vars: | 16 | { 17 | "USE_RUST": "1", 18 | "SETUPTOOLS_RUST_CARGO_PROFILE": "release" 19 | } 20 | 21 | pypi_purepython: 22 | needs: pypi_binwheels 
23 | uses: ecmwf/reusable-workflows/.github/workflows/cd-pypi.yml@v2 24 | secrets: inherit 25 | with: 26 | testpypi: true 27 | env_vars: | 28 | { 29 | "USE_RUST": "0", 30 | "SETUPTOOLS_RUST_CARGO_PROFILE": "release" 31 | } 32 | -------------------------------------------------------------------------------- /tests/_test_inputs/river_networks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | downstream_nodes_1 = np.array( 4 | [ 5 | 5, 6 | 6, 7 | 7, 8 | 7, 9 | 8, 10 | 10, 11 | 11, 12 | 12, 13 | 12, 14 | 13, 15 | 16, 16 | 16, 17 | 16, 18 | 12, 19 | 13, 20 | 16, 21 | 20, # we set sink to len of nodes 22 | 16, 23 | 17, 24 | 18, 25 | ] 26 | ) 27 | 28 | 29 | downstream_nodes_2 = np.array( 30 | [ 31 | 16, # we set sink to len of nodes 32 | 2, 33 | 7, 34 | 16, # we set sink to len of nodes 35 | 5, 36 | 10, 37 | 10, 38 | 11, 39 | 13, 40 | 13, 41 | 13, 42 | 6, 43 | 0, 44 | 16, # we set sink to len of nodes 45 | 15, 46 | 11, 47 | ] 48 | ) 49 | -------------------------------------------------------------------------------- /tests/move/array/test_downstream.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from _test_inputs.movement import * 4 | from _test_inputs.readers import * 5 | 6 | import earthkit.hydro as ekh 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "river_network, flow_downstream", 11 | [ 12 | (("cama_nextxy", cama_nextxy_1), upstream_1), 13 | (("cama_nextxy", cama_nextxy_2), upstream_2), 14 | ], 15 | indirect=["river_network"], 16 | ) 17 | def test_calculate_upstream_metric_max(river_network, flow_downstream): 18 | output_field = ekh.move.array.downstream( 19 | river_network, 20 | np.arange(1, river_network.n_nodes + 1), 21 | node_weights=None, 22 | return_type="masked", 23 | ) 24 | print(output_field) 25 | print(flow_downstream) 26 | assert output_field.dtype == flow_downstream.dtype 27 | 
np.testing.assert_allclose(output_field, flow_downstream) 28 | -------------------------------------------------------------------------------- /tests/move/array/test_upstream.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from _test_inputs.movement import * 4 | from _test_inputs.readers import * 5 | 6 | import earthkit.hydro as ekh 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "river_network, flow_downstream", 11 | [ 12 | (("cama_nextxy", cama_nextxy_1), downstream_1), 13 | (("cama_nextxy", cama_nextxy_2), downstream_2), 14 | ], 15 | indirect=["river_network"], 16 | ) 17 | def test_calculate_upstream_metric_max(river_network, flow_downstream): 18 | output_field = ekh.move.array.upstream( 19 | river_network, 20 | np.arange(1, river_network.n_nodes + 1), 21 | node_weights=None, 22 | return_type="masked", 23 | ) 24 | print(output_field) 25 | print(flow_downstream) 26 | assert output_field.dtype == flow_downstream.dtype 27 | np.testing.assert_allclose(output_field, flow_downstream) 28 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_backends/mlx_backend.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | 3 | from .array_backend import ArrayBackend 4 | 5 | 6 | class MLXBackend(ArrayBackend): 7 | def __init__(self): 8 | super().__init__(mx) 9 | 10 | @property 11 | def name(self): 12 | return "mlx" 13 | 14 | def copy(self, x): 15 | return x 16 | 17 | def asarray(self, x, *args, **kwargs): 18 | return mx.array(x) 19 | 20 | def full(self, *args, **kwargs): 21 | kwargs.pop("device", None) # mx.full is not passed a device; tolerate callers that omit it 22 | return mx.full(*args, **kwargs) 23 | 24 | def gather(self, arr, indices, axis=-1): 25 | assert axis == -1 26 | return arr[..., indices] 27 | 28 | def scatter_assign(self, target, indices, updates): 29 | target[..., indices] = updates 30 | return target 31 | 32 | def scatter_add(self, target,
indices, updates): 33 | return target.at[..., indices].add(updates) 34 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_backends/find.py: -------------------------------------------------------------------------------- 1 | def get_array_backend(x): 2 | 3 | if type(x) is str: 4 | mod = x 5 | else: 6 | mod = type(x).__module__ 7 | 8 | if "torch" in mod: 9 | from .torch_backend import TorchBackend 10 | 11 | return TorchBackend() 12 | elif "tensorflow" in mod: 13 | from .tensorflow_backend import TFBackend 14 | 15 | return TFBackend() 16 | elif "jax" in mod: 17 | from .jax_backend import JAXBackend 18 | 19 | return JAXBackend() 20 | elif "cupy" in mod: 21 | from .cupy_backend import CuPyBackend 22 | 23 | return CuPyBackend() 24 | elif "numpy" in mod: 25 | from .numpy_backend import NumPyBackend 26 | 27 | return NumPyBackend() 28 | elif "mlx" in mod: 29 | from .mlx_backend import MLXBackend 30 | 31 | return MLXBackend() 32 | else: 33 | raise TypeError(f"Unsupported array type: {type(x)}") 34 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_readers/_grit.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_edge_indices(offsets, grouping): 5 | lengths = offsets[grouping + 1] - offsets[grouping] 6 | total_len = np.sum(lengths) 7 | result = np.empty(total_len, dtype=int) 8 | pos = 0 9 | 10 | for node, length in zip(grouping, lengths): 11 | start = offsets[node] 12 | for j in range(length): 13 | result[pos + j] = start + j 14 | pos += length 15 | return result 16 | 17 | 18 | def compute_topological_labels_bifurcations(down_ids, offsets, sources, sinks): 19 | n_nodes = offsets.size - 1 20 | labels = np.zeros(n_nodes, dtype=int) 21 | inlets = sources 22 | 23 | for n in range(1, n_nodes + 1): 24 | inlets = np.unique(down_ids[get_edge_indices(offsets, inlets)]) 25 | if inlets.size == 0: 26 | 
labels[sinks] = n - 1 27 | break 28 | labels[inlets] = n 29 | 30 | return labels 31 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_utils/decorators/array_backend.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | from earthkit.hydro._backends.find import get_array_backend 4 | 5 | 6 | def multi_backend(allow_jax_jit=True, jax_static_args=None): 7 | def decorator(func): 8 | compiled_jax_fn = None 9 | 10 | @wraps(func) 11 | def wrapper(**kwargs): 12 | xp = get_array_backend(kwargs["river_network"].groups[0]) 13 | backend_name = xp.name 14 | kwargs["xp"] = xp 15 | if backend_name == "jax" and allow_jax_jit: 16 | 17 | nonlocal compiled_jax_fn 18 | if compiled_jax_fn is None: 19 | from jax import jit 20 | 21 | compiled_jax_fn = jit(func, static_argnames=jax_static_args) 22 | return compiled_jax_fn(**kwargs) 23 | else: 24 | return func(**kwargs) 25 | 26 | return wrapper 27 | 28 | return decorator 29 | -------------------------------------------------------------------------------- /src/earthkit/hydro/__init__.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2025- ECMWF. 2 | # 3 | # This software is licensed under the terms of the Apache Licence Version 2.0 4 | # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | # In applying this licence, ECMWF does not waive the privileges and immunities 6 | # granted to it by virtue of its status as an intergovernmental organisation 7 | # nor does it submit to any jurisdiction. 
8 | 9 | import earthkit.hydro.catchments 10 | import earthkit.hydro.distance 11 | import earthkit.hydro.downstream 12 | import earthkit.hydro.length 13 | import earthkit.hydro.move 14 | import earthkit.hydro.river_network 15 | import earthkit.hydro.subnetwork 16 | import earthkit.hydro.upstream 17 | 18 | from ._version import __version__ 19 | 20 | __all__ = [ 21 | "catchments", 22 | "distance", 23 | "downstream", 24 | "length", 25 | "move", 26 | "river_network", 27 | "upstream", 28 | "subnetwork", 29 | "__version__", 30 | ] 31 | -------------------------------------------------------------------------------- /src/earthkit/hydro/data_structures/_network_storage.py: -------------------------------------------------------------------------------- 1 | class RiverNetworkStorage: 2 | def __init__( 3 | self, 4 | n_nodes, 5 | n_edges, 6 | sorted_data, # np.vstack((down_ids_upsort, up_ids_upsort, edge_ids_upsort)) 7 | sources, 8 | sinks, 9 | coords, 10 | splits, # indices of where to split sorted_data 11 | area, 12 | mask, 13 | shape, 14 | bifurcates=False, 15 | edge_weights=None, 16 | ): 17 | self.n_nodes = n_nodes 18 | self.n_edges = n_edges 19 | self.bifurcates = bifurcates 20 | self.sources = sources 21 | self.sinks = sinks 22 | self.coords = coords 23 | self.area = area 24 | self.sorted_data = sorted_data 25 | self.splits = splits 26 | self.mask = mask 27 | self.shape = shape 28 | self.edge_weights = edge_weights 29 | assert not (bifurcates and edge_weights is None) 30 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_backends/jax_backend.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | from .array_backend import ArrayBackend 5 | 6 | 7 | class JAXBackend(ArrayBackend): 8 | def __init__(self): 9 | super().__init__(jnp) 10 | 11 | @property 12 | def name(self): 13 | return "jax" 14 | 15 | def copy(self, x): 16 | return x 17 | 18 | 
def gather(self, arr, indices, axis=-1): 19 | assert axis == -1 20 | return arr[..., indices] 21 | 22 | def scatter_assign(self, target, indices, updates): 23 | return target.at[..., indices].set(updates) 24 | 25 | def scatter_add(self, target, indices, updates): 26 | return target.at[..., indices].add(updates) 27 | 28 | def asarray(self, arr, dtype=None, device=None, copy=None): 29 | for d in jax.devices(): 30 | if d.platform == device: 31 | device = d 32 | break 33 | return jnp.asarray(arr, dtype=dtype, order=None, copy=copy, device=device) 34 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_backends/torch_backend.py: -------------------------------------------------------------------------------- 1 | import array_api_compat.torch as torch 2 | 3 | from .array_backend import ArrayBackend 4 | 5 | 6 | class TorchBackend(ArrayBackend): 7 | def __init__(self): 8 | super().__init__(torch) 9 | 10 | @property 11 | def name(self): 12 | return "torch" 13 | 14 | def copy(self, x): 15 | return x.clone() 16 | 17 | def gather(self, arr, indices, axis=-1): 18 | return torch.index_select(arr, dim=axis, index=indices) 19 | 20 | def scatter_assign(self, target, indices, updates): 21 | target[..., indices] = updates 22 | return target 23 | 24 | def scatter_add(self, target, indices, updates): 25 | return target.index_add(-1, indices, updates) 26 | 27 | def scatter_max(self, target, indices, updates): 28 | return target.index_reduce(-1, indices, updates, "amax") 29 | 30 | def scatter_min(self, target, indices, updates): 31 | return target.index_reduce(-1, indices, updates, "amin") 32 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_backends/array_backend.py: -------------------------------------------------------------------------------- 1 | class ArrayBackend: 2 | def __init__(self, module): 3 | self._mod = module 4 | 5 | def __getattr__(self, name): 6 | return 
getattr(self._mod, name) # Delegate to underlying module 7 | 8 | @property 9 | def name(self): 10 | raise NotImplementedError 11 | 12 | def copy(self, x): 13 | raise NotImplementedError 14 | 15 | # extended functionality 16 | def gather(self, arr, indices, axis=-1): 17 | raise NotImplementedError 18 | 19 | def scatter_assign(self, target, indices, updates): 20 | raise NotImplementedError 21 | 22 | def scatter_add(self, target, indices, updates): 23 | raise NotImplementedError 24 | 25 | def scatter_max(self, target, indices, updates): 26 | raise NotImplementedError 27 | 28 | def scatter_min(self, target, indices, updates): 29 | raise NotImplementedError 30 | 31 | def scatter_mul(self, target, indices, updates): 32 | raise NotImplementedError 33 | -------------------------------------------------------------------------------- /tests/upstream/array/test_max.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from _test_inputs.accumulation import * 4 | from _test_inputs.readers import * 5 | 6 | import earthkit.hydro as ekh 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "river_network, input_field, flow_downstream, mv", 11 | [ 12 | ( 13 | ("cama_nextxy", cama_nextxy_1), 14 | input_field_1c, 15 | upstream_metric_max_1c, 16 | mv_1c, 17 | ), 18 | ( 19 | ("cama_nextxy", cama_nextxy_1), 20 | input_field_1e, 21 | upstream_metric_max_1e, 22 | mv_1e, 23 | ), 24 | ], 25 | indirect=["river_network"], 26 | ) 27 | def test_calculate_upstream_metric_max(river_network, input_field, flow_downstream, mv): 28 | output_field = ekh.upstream.array.max( 29 | river_network, input_field, node_weights=None, return_type="masked" 30 | ) 31 | print(output_field) 32 | print(flow_downstream) 33 | assert output_field.dtype == flow_downstream.dtype 34 | np.testing.assert_allclose(output_field, flow_downstream) 35 | -------------------------------------------------------------------------------- 
/tests/upstream/array/test_min.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from _test_inputs.accumulation import * 4 | from _test_inputs.readers import * 5 | 6 | import earthkit.hydro as ekh 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "river_network, input_field, flow_downstream, mv", 11 | [ 12 | ( 13 | ("cama_nextxy", cama_nextxy_1), 14 | input_field_1c, 15 | upstream_metric_min_1c, 16 | mv_1c, 17 | ), 18 | ( 19 | ("cama_nextxy", cama_nextxy_1), 20 | input_field_1e, 21 | upstream_metric_min_1e, 22 | mv_1e, 23 | ), 24 | ], 25 | indirect=["river_network"], 26 | ) 27 | def test_calculate_upstream_metric_min(river_network, input_field, flow_downstream, mv): 28 | output_field = ekh.upstream.array.min( 29 | river_network, input_field, node_weights=None, return_type="masked" 30 | ) 31 | print(output_field) 32 | print(flow_downstream) 33 | assert output_field.dtype == flow_downstream.dtype 34 | np.testing.assert_allclose(output_field, flow_downstream) 35 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | 22 | rst: 23 | @echo "Preparing autodocs folder..." 
24 | @mkdir -p ./$(SOURCEDIR)/autodocs 25 | @rm -rf ./$(SOURCEDIR)/autodocs/* 26 | @echo "Generating .rst files with sphinx-apidoc..." 27 | sphinx-apidoc --implicit-namespaces --separate -o ./$(SOURCEDIR)/autodocs ../src/earthkit/ 28 | @rm ./$(SOURCEDIR)/autodocs/earthkit.rst 29 | @rm ./$(SOURCEDIR)/autodocs/modules.rst 30 | python clean_autodocs.py 31 | -------------------------------------------------------------------------------- /tests/length/array/test_max.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from _test_inputs.distance import * 4 | from _test_inputs.readers import * 5 | 6 | import earthkit.hydro as ekh 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "river_network, stations_list, upstream, downstream, weights, result", 11 | [ 12 | ( 13 | ("cama_nextxy", cama_nextxy_1), 14 | stations, 15 | True, 16 | False, 17 | weights_1, 18 | length_1_max_up, 19 | ), 20 | ( 21 | ("cama_nextxy", cama_nextxy_1), 22 | stations, 23 | False, 24 | True, 25 | weights_1, 26 | length_1_max_down, 27 | ), 28 | ], 29 | indirect=["river_network"], 30 | ) 31 | def test_length_max( 32 | river_network, stations_list, upstream, downstream, weights, result 33 | ): 34 | dist = ekh.length.array.max( 35 | river_network, 36 | stations_list, 37 | upstream=upstream, 38 | downstream=downstream, 39 | field=weights, 40 | ) 41 | np.testing.assert_allclose(dist, result) 42 | -------------------------------------------------------------------------------- /tests/distance/array/test_max.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from _test_inputs.distance import * 4 | from _test_inputs.readers import * 5 | 6 | import earthkit.hydro as ekh 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "river_network, stations_list, upstream, downstream, weights, result", 11 | [ 12 | ( 13 | ("cama_nextxy", cama_nextxy_1), 14 | stations, 15 | True, 16 | 
False, 17 | weights_1, 18 | distance_1_max_up, 19 | ), 20 | ( 21 | ("cama_nextxy", cama_nextxy_1), 22 | stations, 23 | False, 24 | True, 25 | weights_1, 26 | distance_1_max_down, 27 | ), 28 | ], 29 | indirect=["river_network"], 30 | ) 31 | def test_distance_max( 32 | river_network, stations_list, upstream, downstream, weights, result 33 | ): 34 | dist = ekh.distance.array.max( 35 | river_network, 36 | stations_list, 37 | upstream=upstream, 38 | downstream=downstream, 39 | field=weights, 40 | ) 41 | np.testing.assert_allclose(dist, result) 42 | -------------------------------------------------------------------------------- /docs/source/references.bib: -------------------------------------------------------------------------------- 1 | @misc{doc_figure, 2 | author = "ECMWF", 3 | title = "{Copernicus Emergency Management Service releases GloFAS v4.0 hydrological reanalysis}", 4 | year = "2022", 5 | url = "https://www.ecmwf.int/en/about/media-centre/news/2022/copernicus-emergency-management-service-releases-glofas-v40", 6 | note = "[Online; accessed 29-July-2025]" 7 | } 8 | 9 | @misc{earthkit, 10 | author = "ECMWF", 11 | title = "{earthkit: Python tools to work with weather and climate data}", 12 | year = "2022", 13 | url = "https://github.com/ecmwf/earthkit" 14 | } 15 | 16 | @article{rastervector, 17 | title={{mizuRoute} version 1: a river network routing tool for a continental domain water resources applications}, 18 | author={Mizukami, Naoki and Clark, Martyn P and Sampson, Kevin and Nijssen, Bart and Mao, Yixin and McMillan, Hilary and Viger, Roland J and Markstrom, Steve L and Hay, Lauren E and Woods, Ross and others}, 19 | journal={Geoscientific Model Development}, 20 | volume={9}, 21 | number={6}, 22 | pages={2223--2238}, 23 | year={2016}, 24 | publisher={Copernicus Publications G{\"o}ttingen, Germany} 25 | } 26 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_backends/cupy_backend.py: 
-------------------------------------------------------------------------------- 1 | import array_api_compat.cupy as cp 2 | 3 | from .array_backend import ArrayBackend 4 | 5 | 6 | class CuPyBackend(ArrayBackend): 7 | def __init__(self): 8 | super().__init__(cp) 9 | 10 | @property 11 | def name(self): 12 | return "cupy" 13 | 14 | def copy(self, x): 15 | return x.copy() 16 | 17 | def gather(self, arr, indices, axis=-1): 18 | assert axis == -1 19 | return arr[..., indices] 20 | 21 | def scatter_assign(self, target, indices, updates): 22 | target[..., indices] = updates 23 | return target 24 | 25 | def scatter_add(self, target, indices, updates): 26 | cp.add.at(target, (*[slice(None)] * (target.ndim - 1), indices), updates) 27 | return target 28 | 29 | def scatter_max(self, target, indices, updates): 30 | cp.maximum.at(target, (*[slice(None)] * (target.ndim - 1), indices), updates) 31 | return target 32 | 33 | def scatter_min(self, target, indices, updates): 34 | cp.minimum.at(target, (*[slice(None)] * (target.ndim - 1), indices), updates) 35 | return target 36 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_backends/numpy_backend.py: -------------------------------------------------------------------------------- 1 | import array_api_compat.numpy as np 2 | 3 | from .array_backend import ArrayBackend 4 | 5 | 6 | class NumPyBackend(ArrayBackend): 7 | def __init__(self): 8 | super().__init__(np) 9 | 10 | @property 11 | def name(self): 12 | return "numpy" 13 | 14 | def copy(self, x): 15 | return x.copy() 16 | 17 | def gather(self, arr, indices, axis=-1): 18 | assert axis == -1 19 | return arr[..., indices] 20 | 21 | def scatter_assign(self, target, indices, updates): 22 | target[..., indices] = updates 23 | return target 24 | 25 | def scatter_add(self, target, indices, updates): 26 | np.add.at(target, (*[slice(None)] * (target.ndim - 1), indices), updates) 27 | return target 28 | 29 | def scatter_max(self, target, 
indices, updates): 30 | np.maximum.at(target, (*[slice(None)] * (target.ndim - 1), indices), updates) 31 | return target 32 | 33 | def scatter_min(self, target, indices, updates): 34 | np.minimum.at(target, (*[slice(None)] * (target.ndim - 1), indices), updates) 35 | return target 36 | -------------------------------------------------------------------------------- /tests/catchments/array/test_find.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from _test_inputs.catchment import * 4 | from _test_inputs.readers import * 5 | 6 | import earthkit.hydro as ekh 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "river_network, query_field, find_catchments", 11 | [ 12 | ( 13 | ("cama_nextxy", cama_nextxy_1), 14 | catchment_query_field_1, 15 | catchment_1, 16 | ), 17 | (("cama_nextxy", cama_nextxy_2), catchment_query_field_2, catchment_2), 18 | ], 19 | indirect=["river_network"], 20 | ) 21 | def test_find_catchments_2d(river_network, query_field, find_catchments): 22 | # field = np.zeros(river_network.mask.shape, dtype="int") 23 | # field[river_network.mask] = query_field 24 | network_find_catchments = ekh.catchments.array.find( 25 | river_network, locations=query_field 26 | ) 27 | print(find_catchments) 28 | print(network_find_catchments) 29 | np.testing.assert_array_equal( 30 | network_find_catchments.flat[river_network.mask], find_catchments 31 | ) 32 | # np.testing.assert_array_equal(network_find_catchments[~river_network.mask], 0) 33 | -------------------------------------------------------------------------------- /docs/source/contributing.rst: -------------------------------------------------------------------------------- 1 | Contributing 2 | ============ 3 | 4 | **earthkit-hydro** is an open-source project, and contributions are highly welcomed and appreciated. 5 | 6 | The code is hosted on `GitHub `_. 7 | 8 | Development workflow 9 | -------------------- 10 | 11 | 1. 
Fork the repository on GitHub 12 | 2. Clone the fork to your local machine 13 | 3. Create a virtual environment and install the package in development mode 14 | 4. Create a new branch for your changes 15 | 5. Make your changes and commit them with a clear message 16 | 6. Run tests to ensure everything is working correctly 17 | 7. Push your changes to your fork on GitHub 18 | 8. Open a pull request against the develop branch of the main repository 19 | 20 | Code style 21 | ---------- 22 | This project uses ruff, black, isort and flake8 for code styling and formatting. To handle these automatically, you can use pre-commit hooks. To set them up, run: 23 | 24 | .. code-block:: bash 25 | 26 | pip install pre-commit 27 | pre-commit install 28 | 29 | Testing 30 | ------- 31 | To run the tests, you can use pytest. Make sure you have all dependencies installed, then simply run: 32 | 33 | .. code-block:: bash 34 | 35 | pytest 36 | -------------------------------------------------------------------------------- /src/earthkit/hydro/distance/array/__operations.py: -------------------------------------------------------------------------------- 1 | from earthkit.hydro._core.accumulate import flow_downstream, flow_upstream 2 | from earthkit.hydro._core.metrics import metrics_func_finder 3 | 4 | 5 | def min(xp, river_network, field, locations, upstream, downstream): 6 | 7 | func_obj = metrics_func_finder("min", xp) 8 | 9 | out = xp.full(river_network.n_nodes, func_obj.base_val) 10 | 11 | out[locations] = 0 12 | 13 | func = func_obj.func 14 | 15 | if downstream: 16 | out = flow_downstream(xp, river_network, out, func, edge_additive_weight=field) 17 | 18 | if upstream: 19 | out = flow_upstream(xp, river_network, out, func, edge_additive_weight=field) 20 | 21 | return out 22 | 23 | 24 | def max(xp, river_network, field, locations, upstream, downstream): 25 | 26 | func_obj = metrics_func_finder("max", xp) 27 | 28 | out = xp.full(river_network.n_nodes, func_obj.base_val) 29 | 30 | 
out[locations] = 0 31 | 32 | func = func_obj.func 33 | 34 | if downstream: 35 | out = flow_downstream(xp, river_network, out, func, edge_additive_weight=field) 36 | if upstream: 37 | out = flow_upstream(xp, river_network, out, func, edge_additive_weight=field) 38 | 39 | return out 40 | -------------------------------------------------------------------------------- /docs/source/userguide/index.rst: -------------------------------------------------------------------------------- 1 | User Guide 2 | ========== 3 | 4 | **earthkit-hydro** is designed to simplify the process of working with hydrological data, providing a range of tools for catchment delineation, river network analysis, and more. It supports various data formats and array backends, making it versatile for different applications. 5 | 6 | At its core, **earthkit-hydro** is a library for conducting operations on river networks. A typical workflow involves: 7 | 8 | 1. Loading a river network 9 | 2. Performing operations on the network, such field propagation, catchment averages, distance calculations and more. 10 | 3. Saving or plotting the results 11 | 12 | In this user guide, we provide detailed instructions for such steps. 13 | 14 | Basics: 15 | 16 | .. toctree:: 17 | :maxdepth: 200 18 | :titlesonly: 19 | 20 | loading_a_river_network 21 | xarray_array_backend 22 | raster_vector_inputs 23 | 24 | Operations: 25 | 26 | .. toctree:: 27 | :maxdepth: 200 28 | :titlesonly: 29 | 30 | flow_accumulations 31 | specifying_locations 32 | catchment_delineation 33 | catchment_statistics 34 | distance_length_calculations 35 | subnetwork_creation 36 | 37 | Misc: 38 | 39 | .. 
toctree:: 40 | :maxdepth: 200 41 | :titlesonly: 42 | 43 | earthkit 44 | pcraster 45 | -------------------------------------------------------------------------------- /src/earthkit/hydro/move/array/_operations.py: -------------------------------------------------------------------------------- 1 | import earthkit.hydro.move.array.__operations as array 2 | from earthkit.hydro._utils.decorators import mask, multi_backend 3 | 4 | 5 | @multi_backend(jax_static_args=["xp", "river_network", "return_type", "metric"]) 6 | def upstream(xp, river_network, field, node_weights, edge_weights, metric, return_type): 7 | return_type = river_network.return_type if return_type is None else return_type 8 | if return_type not in ["gridded", "masked"]: 9 | raise ValueError("return_type must be either 'gridded' or 'masked'.") 10 | decorated_func = mask(return_type == "gridded")(array.upstream) 11 | return decorated_func(xp, river_network, field, node_weights, edge_weights, metric) 12 | 13 | 14 | @multi_backend(jax_static_args=["xp", "river_network", "return_type", "metric"]) 15 | def downstream( 16 | xp, river_network, field, node_weights, edge_weights, metric, return_type 17 | ): 18 | return_type = river_network.return_type if return_type is None else return_type 19 | if return_type not in ["gridded", "masked"]: 20 | raise ValueError("return_type must be either 'gridded' or 'masked'.") 21 | decorated_func = mask(return_type == "gridded")(array.downstream) 22 | return decorated_func(xp, river_network, field, node_weights, edge_weights, metric) 23 | -------------------------------------------------------------------------------- /docs/source/userguide/specifying_locations.rst: -------------------------------------------------------------------------------- 1 | Specifying locations 2 | ==================== 3 | 4 | Many functions are concerned with operations relating a subset of the entire river network i.e. a fixed number of locations. 
This can range from catchment averages, to distances etc. 5 | 6 | The most convenient and common way to specify a gauge location is by its coordinates, typically latitude and longitude. The specified coordinates must match the river network coordinate system. 7 | 8 | For the EFAS network (which uses regular lat/lon) we can specify via the lat/lon of points of interest. 9 | 10 | .. code-block:: python 11 | 12 | locations = { 13 | "station1": (10, 10), 14 | "station2": (10, 10), 15 | "station3": (10, 10) 16 | } 17 | 18 | labelled_field = ekh.catchments.sum(network, field, locations) 19 | 20 | However, for more performance, it is also possible to specify directly a grid index. 21 | 22 | .. code-block:: python 23 | 24 | locations = [(10,10), (50, 30), (80, 70)] 25 | 26 | labelled_field = ekh.catchments.sum(network, field, locations) 27 | 28 | Or, for maximum performance, it is possible to also specify node labels. 29 | 30 | .. code-block:: python 31 | 32 | locations = [10, 5, 6] 33 | 34 | labelled_field = ekh.catchments.sum(network, field, locations) 35 | -------------------------------------------------------------------------------- /tests/length/array/test_min.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from _test_inputs.distance import * 4 | from _test_inputs.readers import * 5 | 6 | import earthkit.hydro as ekh 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "river_network, stations_list, upstream, downstream, weights, result", 11 | [ 12 | ( 13 | ("cama_nextxy", cama_nextxy_1), 14 | stations, 15 | True, 16 | True, 17 | weights_1, 18 | length_1_min_up_down, 19 | ), 20 | ( 21 | ("cama_nextxy", cama_nextxy_1), 22 | stations, 23 | True, 24 | False, 25 | weights_1, 26 | length_1_min_up, 27 | ), 28 | ( 29 | ("cama_nextxy", cama_nextxy_1), 30 | stations, 31 | False, 32 | True, 33 | weights_1, 34 | length_1_min_down, 35 | ), 36 | ], 37 | indirect=["river_network"], 38 | ) 39 | def 
test_length_min( 40 | river_network, stations_list, upstream, downstream, weights, result 41 | ): 42 | dist = ekh.length.array.min( 43 | river_network, 44 | stations_list, 45 | upstream=upstream, 46 | downstream=downstream, 47 | field=weights, 48 | ) 49 | np.testing.assert_allclose(dist, result) 50 | -------------------------------------------------------------------------------- /tests/distance/array/test_min.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from _test_inputs.distance import * 4 | from _test_inputs.readers import * 5 | 6 | import earthkit.hydro as ekh 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "river_network, stations_list, upstream, downstream, weights, result", 11 | [ 12 | ( 13 | ("cama_nextxy", cama_nextxy_1), 14 | stations, 15 | True, 16 | True, 17 | weights_1, 18 | distance_1_min_up_down, 19 | ), 20 | ( 21 | ("cama_nextxy", cama_nextxy_1), 22 | stations, 23 | True, 24 | False, 25 | weights_1, 26 | distance_1_min_up, 27 | ), 28 | ( 29 | ("cama_nextxy", cama_nextxy_1), 30 | stations, 31 | False, 32 | True, 33 | weights_1, 34 | distance_1_min_down, 35 | ), 36 | ], 37 | indirect=["river_network"], 38 | ) 39 | def test_distance_min( 40 | river_network, stations_list, upstream, downstream, weights, result 41 | ): 42 | dist = ekh.distance.array.min( 43 | river_network, 44 | stations_list, 45 | upstream=upstream, 46 | downstream=downstream, 47 | field=weights, 48 | ) 49 | np.testing.assert_allclose(dist, result) 50 | -------------------------------------------------------------------------------- /tests/_test_inputs/catchment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # catchment_query_field_1 = np.array( 4 | # [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 1, np.nan, 5, 4, 2, 3, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], dtype="int" 5 | # ) 6 | catchment_query_field_1 = [8, 
12, 13, 11, 10] 7 | 8 | 9 | # catchment_query_field_2 = np.array( 10 | # [4, np.nan, 1, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 3, np.nan, 2, np.nan, np.nan], dtype="int" 11 | # ) 12 | catchment_query_field_2 = [2, 13, 11, 0] 13 | 14 | 15 | # subcatchment_1 = np.array([5, 4, 2, 2, 1, 5, 4, 2, 1, 3, 5, 4, 2, 3, 3, np.nan, np.nan, np.nan, np.nan, np.nan])-1 16 | 17 | 18 | # subcatchment_2 = np.array([4, 1, 1, np.nan, 2, 2, 2, 3, 2, 2, 2, 3, 4, 2, 3, 3])-1 19 | 20 | 21 | catchment_1 = ( 22 | np.array( 23 | [ 24 | 5, 25 | 4, 26 | 2, 27 | 2, 28 | 2, 29 | 5, 30 | 4, 31 | 2, 32 | 2, 33 | 2, 34 | 5, 35 | 4, 36 | 2, 37 | 2, 38 | 2, 39 | np.nan, 40 | np.nan, 41 | np.nan, 42 | np.nan, 43 | np.nan, 44 | ] 45 | ) 46 | - 1 47 | ) 48 | 49 | 50 | catchment_2 = np.array([4, 2, 2, np.nan, 2, 2, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2]) - 1 51 | -------------------------------------------------------------------------------- /docs/source/userguide/earthkit.rst: -------------------------------------------------------------------------------- 1 | Integration with the earthkit system 2 | ==================================== 3 | 4 | earthkit-hydro is the hydrological component of earthkit :cite:`earthkit`. It is designed to interplay with other earthkit components seamlessly, primarily via xarray integration. 5 | 6 | Here is a simple example of using different earthkit packages together. 7 | 8 | .. 
code-block:: python 9 | 10 | import earthkit.data as ekd 11 | import earthkit.hydro as ekh 12 | import earthkit.plots as ekp 13 | 14 | # specify some custom styles 15 | style = ekp.styles.Style( 16 | colors="Blues", 17 | levels=[0, 0.5, 1, 2, 5, 10, 50, 100, 500, 1000, 2000, 3000, 4000], 18 | extend="max", 19 | ) 20 | 21 | # load data and river network 22 | network = ekh.river_network.load("efas", "5") 23 | da = ekd.from_source( 24 | "sample", 25 | "R06a.nc", 26 | )[0].to_xarray() 27 | 28 | # compute upstream accumulation 29 | upstream_sum = ekh.upstream.sum(network, da) 30 | 31 | # plot result 32 | chart = ekp.Map() 33 | chart.quickplot(upstream_sum, style=style) 34 | chart.legend(label="{variable_name}") 35 | chart.title("Upstream precipitation at {time:%H:%M on %-d %B %Y}") 36 | chart.coastlines() 37 | chart.gridlines() 38 | chart.show() 39 | 40 | .. image:: ../../images/earthkit_example.png 41 | :width: 100% 42 | :align: center 43 | -------------------------------------------------------------------------------- /.github/workflows/downstream-ci.yml: -------------------------------------------------------------------------------- 1 | name: downstream 2 | 3 | on: 4 | # Trigger the workflow on push to master or develop, except tag creation 5 | push: 6 | branches: 7 | - 'main' 8 | - 'develop' 9 | tags-ignore: 10 | - '**' 11 | 12 | # Trigger the workflow on pull request 13 | pull_request: ~ 14 | 15 | # Trigger the workflow manually 16 | workflow_dispatch: ~ 17 | 18 | # Trigger after public PR approved for CI 19 | pull_request_target: 20 | types: [labeled] 21 | 22 | jobs: 23 | # Run CI including downstream packages on self-hosted runners 24 | downstream-ci: 25 | if: ${{ !github.event.pull_request.head.repo.fork && github.event.action != 'labeled' || github.event.label.name == 'approved-for-ci' }} 26 | uses: ecmwf/downstream-ci/.github/workflows/downstream-ci.yml@main 27 | with: 28 | earthkit-hydro: ecmwf/earthkit-hydro@${{ github.event.pull_request.head.sha || 
github.sha }} 29 | codecov_upload: true 30 | python_qa: true 31 | secrets: inherit 32 | 33 | # Build downstream packages on HPC 34 | downstream-ci-hpc: 35 | if: ${{ !github.event.pull_request.head.repo.fork && github.event.action != 'labeled' || github.event.label.name == 'approved-for-ci' }} 36 | uses: ecmwf/downstream-ci/.github/workflows/downstream-ci-hpc.yml@main 37 | with: 38 | earthkit-hydro: ecmwf/earthkit-hydro@${{ github.event.pull_request.head.sha || github.sha }} 39 | secrets: inherit 40 | -------------------------------------------------------------------------------- /docs/source/userguide/catchment_delineation.rst: -------------------------------------------------------------------------------- 1 | Catchment delineation 2 | ===================== 3 | 4 | A common task in hydrology is identifying the catchment area for a given point in a river network. 5 | This process, known as catchment delineation, involves determining the area that drains into a specific point. 6 | 7 | .. image:: ../../images/catchment.gif 8 | :width: 250px 9 | :align: right 10 | 11 | In earthkit-hydro, this is accomplished by specifying certain start locations, and labelling all nodes flowing towards those start locations. 12 | If start locations belong to the same catchment, the node furthest downstream takes priority and overwrites any upstream start locations. 13 | 14 | This can be done in earthkit-hydro using the `catchments.find` method. 15 | 16 | .. raw:: html 17 | 18 |
19 | 20 | .. code-block:: python 21 | 22 | network = ekh.river_network.load("efas", "5") 23 | 24 | labelled_field = ekh.catchments.find(network, locations) 25 | 26 | Subcatchments can also be found by making use of the `overwrite` keyword. 27 | 28 | .. image:: ../../images/subcatchment.gif 29 | :width: 250px 30 | :align: left 31 | 32 | .. code-block:: python 33 | 34 | labelled_field = ekh.catchments.find( 35 | network, 36 | locations, 37 | overwrite=False 38 | ) 39 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_utils/coords.py: -------------------------------------------------------------------------------- 1 | def get_core_grid_dims(ds): 2 | possible_names = [["lat", "lon"], ["latitude", "longitude"], ["y", "x"]] 3 | for names in possible_names: 4 | present = True 5 | for name in names: 6 | present &= name in ds.coords 7 | if present: 8 | return names 9 | 10 | 11 | def get_core_node_dims(ds): 12 | possible_names = [ 13 | ["index"], 14 | ["node_index"], 15 | ["node_id"], 16 | ["station_index"], 17 | ["station_id"], 18 | ["gauge_id"], 19 | ["id"], 20 | ["idx"], 21 | ] 22 | for names in possible_names: 23 | present = True 24 | for name in names: 25 | present &= name in ds.coords 26 | if present: 27 | return names 28 | 29 | 30 | def get_core_edge_dims(ds): 31 | possible_names = [["edge_id"]] 32 | for names in possible_names: 33 | present = True 34 | for name in names: 35 | present &= name in ds.coords 36 | if present: 37 | return names 38 | 39 | 40 | def get_core_dims(ds): 41 | dims = get_core_grid_dims(ds) 42 | if dims is None: 43 | dims = get_core_node_dims(ds) 44 | if dims is None: 45 | dims = get_core_edge_dims(ds) 46 | if dims is None: 47 | raise ValueError("Could not autodetect xarray core dims.") 48 | return dims 49 | 50 | 51 | node_default_coord = "node_index" 52 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_readers/group_labels.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | 6 | def compute_topological_labels(sources, sinks, downstream_nodes, n_nodes): 7 | 8 | use_rust = int(os.environ.get("USE_RUST", "-1")) 9 | 10 | if use_rust == 0: 11 | func = compute_topological_labels_python 12 | elif use_rust == 1: 13 | from earthkit.hydro._rust import compute_topological_labels_rust as func 14 | else: 15 | try: 16 | from earthkit.hydro._rust import compute_topological_labels_rust as func 17 | except (ModuleNotFoundError, ImportError): 18 | func = compute_topological_labels_python 19 | 20 | return func(sources, sinks, downstream_nodes, n_nodes) 21 | 22 | 23 | def compute_topological_labels_python( 24 | sources: np.ndarray, sinks: np.ndarray, downstream_nodes: np.ndarray, n_nodes: int 25 | ): 26 | n_nodes = downstream_nodes.shape[0] 27 | inlets = downstream_nodes[sources] 28 | labels = np.zeros(n_nodes, dtype=np.intp) 29 | 30 | for n in range(1, n_nodes + 1): 31 | inlets = np.unique(inlets[inlets != n_nodes]) # subset to valid nodes 32 | if inlets.shape[0] == 0: 33 | break 34 | labels[inlets] = n # update furthest distance from source 35 | inlets = downstream_nodes[inlets] 36 | 37 | if inlets.shape[0] != 0: 38 | raise ValueError("River Network contains a cycle.") 39 | labels[sinks] = n - 1 # put all sinks in last group in topological ordering 40 | 41 | return labels 42 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | default_stages: 4 | - commit 5 | - push 6 | repos: 7 | - repo: https://github.com/charliermarsh/ruff-pre-commit 8 | rev: v0.5.6 9 | hooks: 10 | - id: ruff # fix linting violations 11 | types_or: [ python, pyi, jupyter ] 12 | args: [ --fix ] 13 | # - id: ruff-format # fix formatting 14 | # types_or: [ python, pyi, 
jupyter ] 15 | - repo: https://github.com/psf/black 16 | rev: 25.1.0 17 | hooks: 18 | - id: black 19 | - repo: https://github.com/pycqa/isort 20 | rev: 5.13.2 21 | hooks: 22 | - id: isort 23 | - repo: https://github.com/pycqa/flake8 24 | rev: 7.0.0 25 | hooks: 26 | - id: flake8 27 | - repo: https://github.com/pre-commit/pre-commit-hooks 28 | rev: v4.4.0 29 | hooks: 30 | - id: detect-private-key 31 | - id: check-ast 32 | - id: end-of-file-fixer 33 | - id: mixed-line-ending 34 | args: [--fix=lf] 35 | - id: trailing-whitespace 36 | - id: check-case-conflict 37 | - repo: local 38 | hooks: 39 | - id: forbid-to-commit 40 | name: Don't commit rej files 41 | entry: | 42 | Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates. 43 | Fix the merge conflicts manually and remove the .rej files. 44 | language: fail 45 | files: '.*\.rej$' 46 | -------------------------------------------------------------------------------- /tests/upstream/array/test_mean.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from _test_inputs.accumulation import * 4 | from _test_inputs.readers import * 5 | from utils import convert_to_2d 6 | 7 | import earthkit.hydro as ekh 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "river_network, input_field, flow_downstream, mv", 12 | [ 13 | ( 14 | ("cama_nextxy", cama_nextxy_1), 15 | input_field_1c, 16 | upstream_metric_mean_1c, 17 | mv_1c, 18 | ), 19 | ( 20 | ("cama_nextxy", cama_nextxy_1), 21 | input_field_1e, 22 | upstream_metric_mean_1e, 23 | mv_1e, 24 | ), 25 | ], 26 | indirect=["river_network"], 27 | ) 28 | def test_calculate_upstream_metric_mean( 29 | river_network, input_field, flow_downstream, mv 30 | ): 31 | output_field = ekh.upstream.array.mean( 32 | river_network, input_field, node_weights=None, return_type="masked" 33 | ) 34 | assert output_field.dtype == flow_downstream.dtype 35 | np.testing.assert_allclose(output_field, 
flow_downstream) 36 | 37 | input_field = convert_to_2d(river_network, input_field, 0) 38 | flow_downstream = convert_to_2d(river_network, flow_downstream, 0) 39 | output_field = ekh.upstream.array.mean( 40 | river_network, 41 | input_field, 42 | node_weights=None, 43 | ).flatten() 44 | print(output_field) 45 | print(flow_downstream) 46 | assert output_field.dtype == flow_downstream.dtype 47 | np.testing.assert_allclose(output_field, flow_downstream) 48 | -------------------------------------------------------------------------------- /docs/source/userguide/catchment_statistics.rst: -------------------------------------------------------------------------------- 1 | Catchment statistics 2 | ==================== 3 | 4 | A very common hydrological task is computing statistics over river basins. This is very simple in earthkit-hydro. 5 | 6 | Calculation 7 | ----------- 8 | 9 | A catchment of a gauge location is defined as all nodes flowing to that location. 10 | Catchment statistics are calculated for each location in the same manner as for upstream statistics, with optional weights. 11 | The only difference is that for catchment statistics, one specifies directly the gauge locations of interest as opposed to computing for each node in the river network. 12 | The following methods are available: 13 | 14 | .. 
code-block:: python 15 | 16 | network = ekh.river_network.load("efas", "5") 17 | field = np.ones(network.n_nodes) 18 | node_weights = np.ones(network.n_nodes) # optional weights for the nodes 19 | edge_weights = np.ones(network.n_edges) # optional weights for the edges 20 | locations = { 21 | "station1": (10, 10), 22 | "station2": (20, 20) 23 | } 24 | 25 | upstream_sum = ekh.catchments.sum(network, field, locations, node_weights, edge_weights) 26 | upstream_mean = ekh.catchments.mean(network, field, locations, node_weights, edge_weights) 27 | upstream_max = ekh.catchments.max(network, field, locations, node_weights, edge_weights) 28 | upstream_min = ekh.catchments.min(network, field, locations, node_weights, edge_weights) 29 | upstream_std = ekh.catchments.std(network, field, locations, node_weights, edge_weights) 30 | upstream_var = ekh.catchments.var(network, field, locations, node_weights, edge_weights) 31 | -------------------------------------------------------------------------------- /src/earthkit/hydro/length/array/_operations.py: -------------------------------------------------------------------------------- 1 | from earthkit.hydro._utils.decorators import mask, multi_backend 2 | from earthkit.hydro._utils.locations import locations_to_1d 3 | from earthkit.hydro.length.array import __operations as _operations 4 | 5 | 6 | @multi_backend(allow_jax_jit=False) 7 | def min(xp, river_network, field, locations, upstream, downstream, return_type): 8 | if field is None: 9 | field = xp.ones(river_network.n_nodes) 10 | locations, _, _ = locations_to_1d(xp, river_network, locations) 11 | return_type = river_network.return_type if return_type is None else return_type 12 | if return_type not in ["gridded", "masked"]: 13 | raise ValueError("return_type must be either 'gridded' or 'masked'.") 14 | decorated_func = mask(return_type == "gridded")(_operations.min) 15 | return decorated_func(xp, river_network, field, locations, upstream, downstream) 16 | 17 | 18 | 
@multi_backend(allow_jax_jit=False) 19 | def max(xp, river_network, field, locations, upstream, downstream, return_type): 20 | if field is None: 21 | field = xp.ones(river_network.n_nodes) 22 | locations, _, _ = locations_to_1d(xp, river_network, locations) 23 | return_type = river_network.return_type if return_type is None else return_type 24 | if return_type not in ["gridded", "masked"]: 25 | raise ValueError("return_type must be either 'gridded' or 'masked'.") 26 | decorated_func = mask(return_type == "gridded")(_operations.max) 27 | return decorated_func(xp, river_network, field, locations, upstream, downstream) 28 | 29 | 30 | def to_source(*args, **kwargs): 31 | raise NotImplementedError 32 | 33 | 34 | def to_sink(*args, **kwargs): 35 | raise NotImplementedError 36 | -------------------------------------------------------------------------------- /src/earthkit/hydro/catchments/_operations.py: -------------------------------------------------------------------------------- 1 | import earthkit.hydro.catchments.array.__operations as array 2 | from earthkit.hydro._utils.decorators import multi_backend 3 | 4 | 5 | @multi_backend(allow_jax_jit=False) 6 | def var( 7 | xp, 8 | river_network, 9 | field, 10 | locations, 11 | node_weights, 12 | edge_weights, 13 | ): 14 | return array.var(xp, river_network, field, locations, node_weights, edge_weights) 15 | 16 | 17 | @multi_backend(allow_jax_jit=False) 18 | def std( 19 | xp, 20 | river_network, 21 | field, 22 | locations, 23 | node_weights, 24 | edge_weights, 25 | ): 26 | return array.std(xp, river_network, field, locations, node_weights, edge_weights) 27 | 28 | 29 | @multi_backend(allow_jax_jit=False) 30 | def mean( 31 | xp, 32 | river_network, 33 | field, 34 | locations, 35 | node_weights, 36 | edge_weights, 37 | ): 38 | return array.mean(xp, river_network, field, locations, node_weights, edge_weights) 39 | 40 | 41 | @multi_backend(allow_jax_jit=False) 42 | def sum( 43 | xp, 44 | river_network, 45 | field, 46 | 
locations, 47 | node_weights, 48 | edge_weights, 49 | ): 50 | return array.sum(xp, river_network, field, locations, node_weights, edge_weights) 51 | 52 | 53 | @multi_backend(allow_jax_jit=False) 54 | def min( 55 | xp, 56 | river_network, 57 | field, 58 | locations, 59 | node_weights, 60 | edge_weights, 61 | ): 62 | return array.min(xp, river_network, field, locations, node_weights, edge_weights) 63 | 64 | 65 | @multi_backend(allow_jax_jit=False) 66 | def max( 67 | xp, 68 | river_network, 69 | field, 70 | locations, 71 | node_weights, 72 | edge_weights, 73 | ): 74 | return array.max(xp, river_network, field, locations, node_weights, edge_weights) 75 | -------------------------------------------------------------------------------- /docs/source/userguide/loading_a_river_network.rst: -------------------------------------------------------------------------------- 1 | Loading a river network 2 | ======================= 3 | 4 | earthkit-hydro provides a straightforward way to load river networks from various formats. The library supports multiple river network formats, including those used by PCRaster, CaMa-Flood, HydroSHEDS, MERIT-Hydro and GRIT. 5 | 6 | Many river networks are commonly used for hydrological analysis and modelling, such as the EFAS river network. earthkit-hydro provides precomputed versions of such river networks which are available via 7 | 8 | .. code-block:: python 9 | 10 | import earthkit.hydro as ekh 11 | 12 | # Load the EFAS version 5 river network 13 | network = ekh.river_network.load("efas", "5") 14 | 15 | This is the most convenient and performant way to load a river network, and is therefore recommended for most users. For a full list of networks, view the API reference :doc:`../autodocs/earthkit.hydro.river_network`. 16 | 17 | Custom river networks 18 | --------------------- 19 | If a river network is not available via `ekh.river_network.load`, it is possible to create a custom river network from scratch. 
Many different formats and sources are supported, as detailed in the API reference :doc:`../autodocs/earthkit.hydro.river_network`. 20 | 21 | .. code-block:: python 22 | 23 | network = ekh.river_network.create(path, river_network_format, source) 24 | 25 | This operation involves topologically sorting the river network, which is computationally expensive for large networks. Therefore, it is recommended to export the river network for re-use. 26 | 27 | .. code-block:: python 28 | 29 | network.export("my_river_network.joblib") 30 | 31 | In subsequent analyses, the precomputed river network can now be loaded via 32 | 33 | .. code-block:: python 34 | 35 | network = ekh.river_network.create("my_river_network.joblib", "precomputed") 36 | -------------------------------------------------------------------------------- /src/earthkit/hydro/length/array/__operations.py: -------------------------------------------------------------------------------- 1 | from earthkit.hydro._core.accumulate import flow_downstream, flow_upstream 2 | from earthkit.hydro._core.metrics import metrics_func_finder 3 | 4 | 5 | def min(xp, river_network, field, locations, upstream, downstream): 6 | 7 | func_obj = metrics_func_finder("min", xp) 8 | 9 | out = xp.full(river_network.n_nodes, func_obj.base_val) 10 | 11 | out[locations] = field[locations] 12 | 13 | func = func_obj.func 14 | 15 | if downstream: 16 | out = flow_downstream( 17 | xp, 18 | river_network, 19 | out, 20 | func, 21 | node_additive_weight=field, 22 | node_modifier_use_upstream=False, 23 | ) 24 | if upstream: 25 | out = flow_upstream( 26 | xp, 27 | river_network, 28 | out, 29 | func, 30 | node_additive_weight=field, 31 | node_modifier_use_upstream=False, 32 | ) 33 | 34 | return out 35 | 36 | 37 | def max(xp, river_network, field, locations, upstream, downstream): 38 | 39 | func_obj = metrics_func_finder("max", xp) 40 | 41 | out = xp.full(river_network.n_nodes, func_obj.base_val) 42 | 43 | out[locations] = field[locations] 44 | 45 | 
func = func_obj.func 46 | 47 | if downstream: 48 | out = flow_downstream( 49 | xp, 50 | river_network, 51 | out, 52 | func, 53 | node_additive_weight=field, 54 | node_modifier_use_upstream=False, 55 | ) 56 | if upstream: 57 | out = flow_upstream( 58 | xp, 59 | river_network, 60 | out, 61 | func, 62 | node_additive_weight=field, 63 | node_modifier_use_upstream=False, 64 | ) 65 | 66 | return out 67 | -------------------------------------------------------------------------------- /docs/source/userguide/subnetwork_creation.rst: -------------------------------------------------------------------------------- 1 | Subnetwork creation 2 | =================== 3 | 4 | By default, earthkit-hydro conducts operations over the full river network. In many applications, one is only interested in a specific subnetwork, such as a specific catchment or area. 5 | 6 | There are two ways to create a subnetwork: masking nodes or masking edges. 7 | 8 | Masking nodes 9 | ------------- 10 | 11 | The simplest subnetwork creation mechanism is to remove nodes from a river network. This also removes any edges that are incoming or outgoing to any of the removed nodes. 12 | The mask can be specified over the grid: 13 | 14 | .. code-block:: python 15 | 16 | network = ekh.river_network.load("efas", "5") 17 | 18 | node_mask = np.ones(network.shape, dtype=bool) 19 | node_mask[10,10] = False 20 | 21 | subnetwork = ekh.subnetwork.from_mask(network, node_mask=node_mask) 22 | 23 | Or as usual it is also possible to specify directly a mask on the nodes: 24 | 25 | .. code-block:: python 26 | 27 | node_mask = np.ones(network.n_nodes, dtype=bool) 28 | node_mask[10] = False 29 | 30 | subnetwork = ekh.subnetwork.from_mask(network, node_mask=node_mask) 31 | 32 | 33 | Masking edges 34 | ------------- 35 | 36 | Masking edges is also possible. This is useful for controlling bifurcating river networks, or physically separating a subcatchment from the main catchment. 37 | 38 | .. 
code-block:: python 39 | 40 | edge_mask = np.ones(network.n_edges, dtype=bool) 41 | edge_mask[10] = False 42 | 43 | subnetwork = ekh.subnetwork.from_mask(network, edge_mask=edge_mask) 44 | 45 | Combining masks 46 | --------------- 47 | 48 | It also possible to mask both nodes and edges in a single call. 49 | 50 | .. code-block:: python 51 | 52 | node_mask = np.ones(network.n_nodes, dtype=bool) 53 | node_mask[10] = False 54 | 55 | edge_mask = np.ones(network.n_edges, dtype=bool) 56 | edge_mask[10] = False 57 | 58 | subnetwork = ekh.subnetwork.from_mask(network, node_mask=node_mask, edge_mask=edge_mask) 59 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_core/_find.py: -------------------------------------------------------------------------------- 1 | from earthkit.hydro._core.flow import propagate 2 | 3 | 4 | def _flow_find( 5 | xp, 6 | river_network, 7 | field, 8 | overwrite=True, 9 | invert_graph=True, 10 | ): 11 | op = _find_catchments 12 | 13 | def operation( 14 | field, 15 | did, 16 | uid, 17 | eid, 18 | ): 19 | return op( 20 | xp, 21 | field, 22 | did, 23 | uid, 24 | eid, 25 | overwrite=overwrite, 26 | ) 27 | 28 | field = propagate( 29 | river_network, 30 | river_network.groups, 31 | field, 32 | invert_graph, 33 | operation, 34 | ) 35 | 36 | return field 37 | 38 | 39 | def _find_catchments(xp, field, did, uid, eid, overwrite): 40 | """ 41 | Updates field in-place with the value of its downstream nodes, 42 | dealing with missing values for 2D fields. 43 | 44 | Parameters 45 | ---------- 46 | river_network : earthkit.hydro.network.RiverNetwork 47 | An earthkit-hydro river network object. 48 | field : numpy.ndarray 49 | The input field. 50 | grouping : numpy.ndarray 51 | The array of node indices. 52 | overwrite : bool 53 | If True, overwrite existing non-missing values in the field array. 
54 | 55 | Returns 56 | ------- 57 | None 58 | """ 59 | down_not_missing = ~xp.isnan(xp.gather(field, uid, axis=-1)) 60 | did = did[ 61 | down_not_missing 62 | ] # only update nodes where the downstream belongs to a catchment 63 | if not overwrite: 64 | up_is_missing = xp.isnan(xp.gather(field, did, axis=-1)) 65 | did = did[up_is_missing] 66 | else: 67 | up_is_missing = None 68 | uid = ( 69 | uid[down_not_missing][up_is_missing] 70 | if up_is_missing is not None 71 | else uid[down_not_missing] 72 | ) 73 | updates = xp.gather(field, uid, axis=-1) 74 | return xp.scatter_assign(field, did, updates) 75 | -------------------------------------------------------------------------------- /src/earthkit/hydro/distance/array/_operations.py: -------------------------------------------------------------------------------- 1 | from earthkit.hydro._utils.decorators import mask, multi_backend 2 | from earthkit.hydro._utils.locations import locations_to_1d 3 | from earthkit.hydro.distance.array import __operations as _operations 4 | 5 | 6 | @multi_backend(allow_jax_jit=False) 7 | def min(xp, river_network, field, locations, upstream, downstream, return_type): 8 | if field is None: 9 | field = xp.ones(river_network.n_edges) 10 | else: 11 | # make xp-agnostic 12 | arr_mask = xp.full(river_network.n_nodes, False) 13 | arr_mask[river_network.sinks] = True 14 | field = field[~arr_mask] 15 | locations, _, _ = locations_to_1d(xp, river_network, locations) 16 | return_type = river_network.return_type if return_type is None else return_type 17 | if return_type not in ["gridded", "masked"]: 18 | raise ValueError("return_type must be either 'gridded' or 'masked'.") 19 | decorated_func = mask(return_type == "gridded")(_operations.min) 20 | return decorated_func(xp, river_network, field, locations, upstream, downstream) 21 | 22 | 23 | @multi_backend(allow_jax_jit=False) 24 | def max(xp, river_network, field, locations, upstream, downstream, return_type): 25 | if field is None: 26 | field = 
xp.ones(river_network.n_edges) 27 | else: 28 | # make xp-agnostic 29 | arr_mask = xp.full(river_network.n_nodes, False) 30 | arr_mask[river_network.sinks] = True 31 | field = field[~arr_mask] 32 | locations, _, _ = locations_to_1d(xp, river_network, locations) 33 | return_type = river_network.return_type if return_type is None else return_type 34 | if return_type not in ["gridded", "masked"]: 35 | raise ValueError("return_type must be either 'gridded' or 'masked'.") 36 | decorated_func = mask(return_type == "gridded")(_operations.max) 37 | return decorated_func(xp, river_network, field, locations, upstream, downstream) 38 | 39 | 40 | def to_source(*args, **kwargs): 41 | raise NotImplementedError 42 | 43 | 44 | def to_sink(*args, **kwargs): 45 | raise NotImplementedError 46 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | earthkit-hydro 2 | ============== 3 | 4 | .. important:: 5 | 6 | This software is **Incubating** and subject to ECMWF's guidelines on `Software Maturity `_. 7 | 8 | **earthkit-hydro** is a Python library for common hydrological functions. It is the hydrological component of earthkit :cite:`earthkit`. 9 | 10 | Main Features 11 | ------------- 12 | 13 | .. raw:: html 14 | 15 |
16 | 17 | .. https://agupubs.onlinelibrary.wiley.com/cms/asset/e10b31b2-7a5c-498d-bb27-49966867e6a8/wrcr70124-fig-0002-m.jpg 18 | .. figure:: ../images/glofas.png 19 | :width: 300px 20 | 21 | *Adapted from:* :cite:`doc_figure` 22 | 23 | .. raw:: html 24 | 25 |
26 | 27 | - Catchment delineation 28 | - Catchment-based statistics 29 | - Directional flow-based accumulations 30 | - River network distance calculations 31 | - Upstream/downstream field propagation 32 | - Bifurcation handling 33 | - Custom weighting and decay support 34 | 35 | .. raw:: html 36 | 37 |
38 | 39 | .. image:: ../images/array_backends_with_xr.png 40 | :width: 300px 41 | :align: right 42 | 43 | - Support for PCRaster, CaMa-Flood, HydroSHEDS, MERIT-Hydro and GRIT river network formats 44 | - Compatible with major array-backends: xarray, numpy, cupy, torch, jax, mlx and tensorflow 45 | - GPU support 46 | - Differentiable operations suitable for machine learning 47 | 48 | .. raw:: html 49 | 50 |
51 | 52 | Installation 53 | ------------ 54 | 55 | Try it out! earthkit-hydro is readily available on PyPI. 56 | 57 | .. code-block:: bash 58 | 59 | pip install earthkit-hydro 60 | 61 | Support 62 | ------- 63 | Have a feature request or found a bug? Feel free to open an 64 | `issue `_. 65 | 66 | Documentation 67 | ------------- 68 | .. toctree:: 69 | :maxdepth: 2 70 | :titlesonly: 71 | 72 | userguide/index 73 | tutorials/index 74 | autodocs/earthkit.hydro 75 | contributing 76 | references 77 | -------------------------------------------------------------------------------- /tests/upstream/array/test_sum.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from _test_inputs.accumulation import * 4 | from _test_inputs.readers import * 5 | from utils import convert_to_2d 6 | 7 | import earthkit.hydro as ekh 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "river_network, input_field, flow_downstream, mv", 12 | [ 13 | ( 14 | ("cama_nextxy", cama_nextxy_1), 15 | input_field_1c, 16 | upstream_metric_sum_1c, 17 | mv_1c, 18 | ), 19 | ( 20 | ("cama_nextxy", cama_nextxy_1), 21 | input_field_1e, 22 | upstream_metric_sum_1e, 23 | mv_1e, 24 | ), 25 | ], 26 | indirect=["river_network"], 27 | ) 28 | @pytest.mark.parametrize("array_backend", ["numpy", "torch", "jax"]) 29 | def test_upstream_metric_sum( 30 | river_network, input_field, flow_downstream, mv, array_backend 31 | ): 32 | river_network = river_network.to_device("cpu", array_backend) 33 | xp = ekh._backends.find.get_array_backend(array_backend) 34 | output_field = ekh.upstream.array.sum( 35 | river_network, xp.asarray(input_field), node_weights=None, return_type="masked" 36 | ) 37 | output_field = np.asarray(output_field) 38 | flow_downstream_out = np.asarray(xp.asarray(flow_downstream)) 39 | print(output_field) 40 | print(flow_downstream_out) 41 | assert output_field.dtype == flow_downstream_out.dtype 42 | np.testing.assert_allclose(output_field, 
flow_downstream, rtol=1e-6) 43 | 44 | print(input_field) 45 | input_field = convert_to_2d(river_network, input_field, 0) 46 | flow_downstream = convert_to_2d(river_network, flow_downstream, 0) 47 | print(mv, input_field.dtype) 48 | print(input_field, flow_downstream) 49 | output_field = ekh.upstream.array.sum( 50 | river_network, xp.asarray(input_field), node_weights=None 51 | ) 52 | output_field = np.asarray(output_field).flatten() 53 | flow_downstream = np.asarray(xp.asarray(flow_downstream)) 54 | print(output_field) 55 | print(flow_downstream) 56 | assert output_field.dtype == flow_downstream.dtype 57 | np.testing.assert_allclose(output_field, flow_downstream, rtol=1e-6) 58 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_utils/locations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def locations_to_1d(xp, river_network, locations): 5 | 6 | orig_locations = locations 7 | dict_locations = isinstance(locations, dict) 8 | if dict_locations: 9 | 10 | coord1_network_vals, coord2_network_vals = river_network.coords.values() 11 | 12 | locations = [] 13 | if river_network.shape is None: # vector network 14 | for coord1_val, coord2_val in orig_locations.values(): 15 | indx = ( 16 | (coord1_val - coord1_network_vals) ** 2 17 | + (coord2_val - coord2_network_vals) ** 2 18 | ).argmin() 19 | locations.append(int(indx)) 20 | else: 21 | for coord1_val, coord2_val in orig_locations.values(): 22 | indx = np.argmin((coord1_val - coord1_network_vals) ** 2) 23 | indy = np.argmin((coord2_val - coord2_network_vals) ** 2) 24 | locations.append((int(indx), int(indy))) 25 | 26 | locations = xp.asarray(locations, device=river_network.device) 27 | stations = locations 28 | 29 | if stations.ndim == 2 and stations.shape[1] == 2: 30 | if xp.name not in ["numpy", "cupy", "torch"]: 31 | raise NotImplementedError 32 | # TODO: make this code actually xp agnostic 33 | 
rows, cols = stations[:, 0], stations[:, 1] 34 | flat_indices = rows * river_network.shape[1] + cols 35 | flat_mask = river_network.mask 36 | reverse_map = -xp.ones( 37 | river_network.shape[0] * river_network.shape[1], 38 | dtype=int, 39 | device=river_network.device, 40 | ) 41 | reverse_map[flat_mask] = xp.arange( 42 | flat_mask.shape[0], device=river_network.device 43 | ) 44 | masked_indices = reverse_map[flat_indices] 45 | if xp.any(masked_indices < 0): 46 | raise ValueError( 47 | "Some station points are not included in the masked array." 48 | ) 49 | stations = xp.asarray(masked_indices, device=river_network.device) 50 | else: 51 | assert stations.ndim == 1 52 | 53 | return stations, locations, orig_locations 54 | -------------------------------------------------------------------------------- /tests/_test_inputs/readers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # RIVER NETWORK EXAMPLE ONE 4 | 5 | d8_ldd_1 = np.array( 6 | [ 7 | [2, 2, 2, 1, 1], 8 | [2, 2, 2, 1, 1], 9 | [3, 2, 1, 4, 4], 10 | [6, 5, 4, 4, 4], 11 | ] 12 | ) 13 | 14 | cama_downxy_1 = ( 15 | np.array( 16 | [ 17 | [0, 0, 0, -1, -1], 18 | [0, 0, 0, -1, -1], 19 | [1, 0, -1, -1, -1], 20 | [1, -999, -1, -1, -1], 21 | ] 22 | ), 23 | np.array( 24 | [ 25 | [1, 1, 1, 1, 1], 26 | [1, 1, 1, 1, 1], 27 | [1, 1, 1, 0, 0], 28 | [0, -999, 0, 0, 0], 29 | ] 30 | ), 31 | ) 32 | 33 | cama_nextxy_1 = ( 34 | np.array( 35 | [ 36 | [1, 2, 3, 3, 4], 37 | [1, 2, 3, 3, 4], 38 | [2, 2, 2, 3, 4], 39 | [2, -9, 2, 3, 4], 40 | ] 41 | ), 42 | np.array( 43 | [ 44 | [2, 2, 2, 2, 2], 45 | [3, 3, 3, 3, 3], 46 | [4, 4, 4, 3, 3], 47 | [4, -9, 4, 4, 4], 48 | ] 49 | ), 50 | ) 51 | 52 | # RIVER NETWORK EXAMPLE TWO 53 | 54 | d8_ldd_2 = np.array( 55 | [ 56 | [5, 6, 3, 5], 57 | [6, 3, 2, 2], 58 | [3, 2, 1, 7], 59 | [2, 5, 6, 8], 60 | ] 61 | ) 62 | 63 | 64 | cama_downxy_2 = ( 65 | np.array( 66 | [ 67 | [-999, 1, 1, -999], 68 | [1, 1, 0, 0], 69 | [1, 0, -1, -1], 70 | [0, 
-999, 1, 0], 71 | ] 72 | ), 73 | np.array( 74 | [ 75 | [-999, 0, 1, -999], 76 | [0, 1, 1, 1], 77 | [1, 1, 1, -1], 78 | [1, -999, 0, -1], 79 | ] 80 | ), 81 | ) 82 | 83 | 84 | cama_nextxy_2 = ( 85 | np.array( 86 | [ 87 | [-9, 3, 4, -9], 88 | [2, 3, 3, 4], 89 | [2, 2, 2, 3], 90 | [1, -9, 4, 4], 91 | ] 92 | ), 93 | np.array( 94 | [ 95 | [-9, 1, 2, -9], 96 | [2, 3, 3, 3], 97 | [4, 4, 4, 2], 98 | [1, -9, 4, 3], 99 | ] 100 | ), 101 | ) 102 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_core/_accumulate.py: -------------------------------------------------------------------------------- 1 | def _ufunc_to_downstream( 2 | field, 3 | did, 4 | uid, 5 | eid, 6 | node_additive_weight, 7 | node_multiplicative_weight, 8 | node_modifier_use_upstream, 9 | edge_additive_weight, 10 | edge_multiplicative_weight, 11 | func, 12 | xp, 13 | ): 14 | """ 15 | Updates field in-place by applying a ufunc at the downstream nodes 16 | of the grouping. 17 | 18 | Parameters 19 | ---------- 20 | river_network : earthkit.hydro.network.RiverNetwork 21 | An earthkit-hydro river network object. 22 | field : numpy.ndarray 23 | The input field. 24 | grouping : numpy.ndarray 25 | An array of indices. 26 | mv : scalar 27 | A missing value indicator (not used in the function but kept for consistency). 28 | additive_weight : numpy.ndarray, optional 29 | A weight to be added to the field values. Default is None. 30 | multiplicative_weight : numpy.ndarray, optional 31 | A weight to be multiplied with the field values. Default is None. 32 | modifier_use_upstream : bool, optional 33 | If True, the modifiers are used on the upstream nodes instead of downstream. 34 | Default is True. 35 | ufunc : numpy.ufunc 36 | A universal function from the numpy library to be applied to the field data. 37 | Available ufuncs: https://numpy.org/doc/2.2/reference/ufuncs.html. 38 | Note: must allow two operands. 
39 | 40 | Returns 41 | ------- 42 | None 43 | """ 44 | modifier_group = uid if node_modifier_use_upstream else did 45 | 46 | modifier_field = xp.gather(field, uid, axis=-1) 47 | # ADD HAPPENS BEFORE MULT 48 | # TODO: add an option to switch order 49 | if node_additive_weight is not None: 50 | modifier_field += node_additive_weight[..., modifier_group] 51 | if edge_additive_weight is not None: 52 | modifier_field += edge_additive_weight[..., eid] 53 | if node_multiplicative_weight is not None: 54 | modifier_field *= node_multiplicative_weight[..., modifier_group] 55 | if edge_multiplicative_weight is not None: 56 | modifier_field *= edge_multiplicative_weight[..., eid] 57 | return func( 58 | field, 59 | did, 60 | modifier_field, 61 | ) 62 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_core/online.py: -------------------------------------------------------------------------------- 1 | from .accumulate import flow 2 | from .metrics import metrics_func_finder 3 | 4 | 5 | def calculate_online_metric( 6 | xp, 7 | river_network, 8 | field, 9 | metric, 10 | node_weights, 11 | edge_weights, 12 | flow_direction, 13 | ): 14 | if flow_direction == "up": 15 | invert_graph = True 16 | elif flow_direction == "down": 17 | invert_graph = False 18 | else: 19 | raise ValueError( 20 | f"flow_direction must be 'up' or 'down', got {flow_direction}." 
21 | ) 22 | 23 | field = xp.copy(field) 24 | 25 | if node_weights is None: 26 | if metric == "mean" or metric == "std" or metric == "var": 27 | node_weights = xp.ones(river_network.n_nodes, dtype=xp.float64) 28 | else: 29 | node_weights = xp.copy(node_weights) 30 | 31 | if edge_weights is not None: 32 | edge_weights = xp.copy(edge_weights) 33 | 34 | func = metrics_func_finder(metric, xp).func 35 | 36 | weighted_field = flow( 37 | xp, 38 | river_network, 39 | field if node_weights is None else field * node_weights, 40 | func, 41 | invert_graph, 42 | edge_multiplicative_weight=edge_weights, 43 | ) 44 | 45 | if metric == "mean" or metric == "std" or metric == "var": 46 | counts = flow( 47 | xp, 48 | river_network, 49 | xp.copy(node_weights), 50 | func, 51 | invert_graph, 52 | edge_multiplicative_weight=edge_weights, 53 | ) 54 | 55 | if metric == "mean": 56 | weighted_field /= counts 57 | return weighted_field 58 | elif metric == "var" or metric == "std": 59 | weighted_sum_of_squares = flow( 60 | xp, 61 | river_network, 62 | field**2 if node_weights is None else field**2 * node_weights, 63 | func, 64 | invert_graph, 65 | edge_multiplicative_weight=edge_weights, 66 | ) 67 | mean = weighted_field / counts 68 | weighted_sum_of_squares = weighted_sum_of_squares / counts - mean**2 69 | weighted_sum_of_squares = xp.clip(weighted_sum_of_squares, 0, xp.inf) 70 | if metric == "var": 71 | return weighted_sum_of_squares 72 | elif metric == "std": 73 | return xp.sqrt(weighted_sum_of_squares) 74 | else: 75 | return weighted_field 76 | -------------------------------------------------------------------------------- /docs/clean_autodocs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | AUTODOCS_DIR = "source/autodocs" 5 | 6 | for root, dirs, files in os.walk(AUTODOCS_DIR): 7 | for fname in files: 8 | if not fname.endswith(".rst"): 9 | continue 10 | path = os.path.join(root, fname) 11 | with open(path, "r", 
encoding="utf-8") as f: 12 | lines = f.readlines() 13 | 14 | new_lines = [] 15 | changed = False 16 | i = 0 17 | 18 | while i < len(lines): 19 | line = lines[i].strip() 20 | 21 | if line == "earthkit.hydro package" and i + 1 < len(lines): 22 | new_title = "API Reference" 23 | underline = "=" * len(new_title) 24 | new_lines.append(f"{new_title}\n") 25 | new_lines.append(f"{underline}\n") 26 | i += 2 # skip underline 27 | changed = True 28 | continue 29 | 30 | # === Match lines like "earthkit.hydro.data\_structures package" === 31 | match = re.match(r"^([\w\.\\]+)\s+(package|module)$", line) 32 | if match and i + 1 < len(lines): 33 | title = match.group(1) # Keep escaped underscores as-is 34 | underline = "=" * len(title) 35 | new_lines.append(f"{title}\n") 36 | new_lines.append(f"{underline}\n") 37 | i += 2 # skip underline 38 | changed = True 39 | continue 40 | 41 | # === Add :titlesonly: under .. toctree:: if missing === 42 | if line == ".. toctree::": 43 | new_lines.append(lines[i]) # original line 44 | i += 1 45 | # Look ahead for options 46 | has_titlesonly = False 47 | temp_lines = [] 48 | while i < len(lines) and lines[i].lstrip().startswith(":"): 49 | if ":titlesonly:" in lines[i]: 50 | has_titlesonly = True 51 | temp_lines.append(lines[i]) 52 | i += 1 53 | if not has_titlesonly: 54 | new_lines.append(" :titlesonly:\n") 55 | changed = True 56 | new_lines.extend(temp_lines) 57 | continue 58 | 59 | # default case 60 | new_lines.append(lines[i]) 61 | i += 1 62 | 63 | if changed: 64 | with open(path, "w", encoding="utf-8") as f: 65 | f.writelines(new_lines) 66 | -------------------------------------------------------------------------------- /src/earthkit/hydro/catchments/array/_operations.py: -------------------------------------------------------------------------------- 1 | from earthkit.hydro._utils.decorators import multi_backend 2 | from earthkit.hydro._utils.locations import locations_to_1d 3 | from earthkit.hydro.catchments.array import __operations as 
_operations 4 | 5 | 6 | @multi_backend(allow_jax_jit=False) 7 | def var(xp, river_network, field, locations, node_weights, edge_weights): 8 | stations_1d, _, _ = locations_to_1d(xp, river_network, locations) 9 | return _operations.var( 10 | xp, river_network, field, stations_1d, node_weights, edge_weights 11 | ) 12 | 13 | 14 | @multi_backend(allow_jax_jit=False) 15 | def std(xp, river_network, field, locations, node_weights, edge_weights): 16 | stations_1d, _, _ = locations_to_1d(xp, river_network, locations) 17 | return _operations.std( 18 | xp, river_network, field, stations_1d, node_weights, edge_weights 19 | ) 20 | 21 | 22 | @multi_backend(allow_jax_jit=False) 23 | def mean(xp, river_network, field, locations, node_weights, edge_weights): 24 | stations_1d, _, _ = locations_to_1d(xp, river_network, locations) 25 | return _operations.mean( 26 | xp, river_network, field, stations_1d, node_weights, edge_weights 27 | ) 28 | 29 | 30 | @multi_backend(allow_jax_jit=False) 31 | def sum(xp, river_network, field, locations, node_weights, edge_weights): 32 | stations_1d, _, _ = locations_to_1d(xp, river_network, locations) 33 | return _operations.sum( 34 | xp, river_network, field, stations_1d, node_weights, edge_weights 35 | ) 36 | 37 | 38 | @multi_backend(allow_jax_jit=False) 39 | def min(xp, river_network, field, locations, node_weights, edge_weights): 40 | stations_1d, _, _ = locations_to_1d(xp, river_network, locations) 41 | return _operations.min( 42 | xp, river_network, field, stations_1d, node_weights, edge_weights 43 | ) 44 | 45 | 46 | @multi_backend(allow_jax_jit=False) 47 | def max(xp, river_network, field, locations, node_weights, edge_weights): 48 | stations_1d, _, _ = locations_to_1d(xp, river_network, locations) 49 | return _operations.max( 50 | xp, river_network, field, stations_1d, node_weights, edge_weights 51 | ) 52 | 53 | 54 | @multi_backend() 55 | def find(xp, river_network, locations, overwrite, return_type): 56 | stations1d, _, _ = 
locations_to_1d(xp, river_network, locations) 57 | field = xp.full(river_network.n_nodes, xp.nan, device=river_network.device) 58 | field[stations1d] = xp.arange(stations1d.shape[0]) 59 | return _operations.find(xp, river_network, field, overwrite, return_type) 60 | -------------------------------------------------------------------------------- /rust/lib.rs: -------------------------------------------------------------------------------- 1 | // (C) Copyright 2025- ECMWF. 2 | // 3 | // This software is licensed under the terms of the Apache Licence Version 2.0 4 | // which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 5 | // In applying this licence, ECMWF does not waive the privileges and immunities 6 | // granted to it by virtue of its status as an intergovernmental organisation 7 | // nor does it submit to any jurisdiction. 8 | 9 | use pyo3::prelude::*; 10 | use rayon::prelude::*; 11 | use numpy::{PyArray1, PyReadonlyArray1}; 12 | use pyo3::exceptions::PyValueError; 13 | use std::sync::atomic::{AtomicI64, Ordering}; 14 | use fixedbitset::FixedBitSet; 15 | 16 | #[pyfunction] 17 | fn compute_topological_labels_rust<'py>( 18 | py: Python<'py>, 19 | sources: PyReadonlyArray1<'py, usize>, 20 | sinks: PyReadonlyArray1<'py, usize>, 21 | downstream_nodes: PyReadonlyArray1<'py, usize>, 22 | n_nodes: usize, 23 | ) -> PyResult<Py<PyArray1<i64>>> { 24 | 25 | let labels: Vec<AtomicI64> = (0..n_nodes) 26 | .map(|_| AtomicI64::new(0)) 27 | .collect(); 28 | 29 | let mut current = sources.as_slice()?.to_vec(); 30 | let sinks = sinks.as_slice()?; 31 | let downstream = downstream_nodes.as_slice()?; 32 | 33 | let mut next = Vec::with_capacity(current.len()); 34 | let mut visited = FixedBitSet::with_capacity(n_nodes); 35 | 36 | for &i in &current { 37 | let d = downstream[i]; 38 | if d != n_nodes { 39 | next.push(d); 40 | } 41 | } 42 | std::mem::swap(&mut current, &mut next); 43 | 44 | for n in 1..=n_nodes { 45 | if current.is_empty() { 46 | sinks.par_iter().for_each(|&i| { 47 | labels[i].store((n
as i64) - 1, Ordering::Relaxed); 48 | }); 49 | break; 50 | } 51 | 52 | current.par_iter().for_each(|&i| { 53 | labels[i].store(n as i64, Ordering::Relaxed); 54 | }); 55 | 56 | next.clear(); 57 | visited.clear(); 58 | for &i in &current { 59 | let d = downstream[i]; 60 | if d != n_nodes && !visited.contains(d) { 61 | visited.insert(d); 62 | next.push(d); 63 | } 64 | } 65 | 66 | std::mem::swap(&mut current, &mut next); 67 | } 68 | 69 | if !current.is_empty() { 70 | return Err(PyErr::new::<PyValueError, _>("River Network contains a cycle.")); 71 | } 72 | 73 | let result: Vec<i64> = labels.iter() 74 | .map(|a| a.load(Ordering::Relaxed)) 75 | .collect(); 76 | let array = PyArray1::from_vec(py, result); 77 | Ok(array.to_owned().into()) 78 | } 79 | 80 | #[pymodule] 81 | fn _rust(m: &Bound<'_, PyModule>) -> PyResult<()> { 82 | m.add_function(wrap_pyfunction!(compute_topological_labels_rust, m)?)?; 83 | Ok(()) 84 | } 85 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_utils/decorators/masking.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | 4 | def mask(unmask=True): 5 | 6 | def decorator(func): 7 | 8 | @wraps(func) 9 | def wrapper(xp, river_network, field, *args, **kwargs): 10 | 11 | if field.shape[-2:] == river_network.shape: 12 | args, kwargs = process_args_kwargs(xp, river_network, args, kwargs) 13 | field_1d = mask_last2_dims(xp, field, river_network.mask, field.shape) 14 | 15 | out_1d = func(xp, river_network, field_1d, *args, **kwargs) 16 | 17 | if unmask: 18 | out_shape = field.shape 19 | return scatter_and_reshape( 20 | xp, 21 | river_network.mask, 22 | out_1d, 23 | out_shape, 24 | device=river_network.device, 25 | ) 26 | else: 27 | return out_1d 28 | else: 29 | args, kwargs = process_args_kwargs(xp, river_network, args, kwargs) 30 | out_1d = func(xp, river_network, field, *args, **kwargs) 31 | if unmask: 32 | out_shape = field.shape[:-1] + river_network.shape 33
| return scatter_and_reshape( 34 | xp, 35 | river_network.mask, 36 | out_1d, 37 | out_shape, 38 | device=river_network.device, 39 | ) 40 | else: 41 | return out_1d 42 | 43 | return wrapper 44 | 45 | return decorator 46 | 47 | 48 | def mask_last2_dims(xp, tensor, mask, target_shape): 49 | B = target_shape[:-2] 50 | M, N = target_shape[-2], target_shape[-1] 51 | flat_shape = B + (M * N,) 52 | tensor_flat = xp.reshape(tensor, flat_shape) 53 | return xp.gather(tensor_flat, mask, axis=-1) 54 | 55 | 56 | def scatter_and_reshape(xp, mask, out_1d, target_shape, device): 57 | B = target_shape[:-2] 58 | M, N = target_shape[-2], target_shape[-1] 59 | flat_shape = B + (M * N,) 60 | out_flat = xp.full(flat_shape, xp.nan, device=device, dtype=out_1d.dtype) 61 | out_flat = xp.scatter_assign(out_flat, mask, out_1d) 62 | return xp.reshape(out_flat, target_shape) 63 | 64 | 65 | def process_args_kwargs(xp, river_network, args, kwargs): 66 | def process_arg(arg): 67 | if ( 68 | hasattr(arg, "shape") # TODO: decide if robust enough 69 | and len(arg.shape) >= 2 70 | and arg.shape[-2:] == river_network.shape 71 | ): 72 | return mask_last2_dims(xp, arg, river_network.mask, arg.shape) 73 | return arg 74 | 75 | new_args = tuple(process_arg(arg) for arg in args) 76 | new_kwargs = {k: process_arg(v) for k, v in kwargs.items()} 77 | return new_args, new_kwargs 78 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_utils/readers.py: -------------------------------------------------------------------------------- 1 | from struct import unpack 2 | 3 | import numpy as np 4 | 5 | # CSF value scales 6 | # version 2 datatypes 7 | VS_BOOLEAN = 0xE0 # boolean, always UINT1, values: 0,1 or MV_UINT1 8 | VS_NOMINAL = 0xE2 # nominal, UINT1 or INT4 9 | VS_ORDINAL = 0xF2 # ordinal, UINT1 or INT4 10 | VS_SCALAR = 0xEB # scalar, REAL4 or (maybe) REAL8 11 | VS_DIRECTION = 0xFB # directional REAL4 or (maybe) REAL8, -1 means no direction 12 | VS_LDD = 0xF0 # 
local drain direction, always UINT1, values: 1-9 or MV_UINT1 13 | # this one CANNOT be returned by NOR passed to a csf2 function 14 | VS_UNDEFINED = 100 # just some value different from the rest 15 | 16 | # CSF cell representations 17 | # preferred version 2 cell representations 18 | CR_UINT1 = 0x00 # boolean, ldd and small nominal and small ordinal 19 | CR_INT4 = 0x26 # large nominal and large ordinal 20 | CR_REAL4 = 0x5A # single scalar and single directional 21 | # other version 2 cell representations 22 | CR_REAL8 = 0xDB # double scalar or directional, no loss of precision 23 | 24 | 25 | def _replace_missing_u1(cur, new): 26 | out = np.copy(cur) 27 | out[cur == 255] = new 28 | return out 29 | 30 | 31 | def _replace_missing_i4(cur, new): 32 | out = np.copy(cur) 33 | out[cur == -2147483648] = new 34 | return out 35 | 36 | 37 | def _replace_missing_f4(cur, new): 38 | out = np.copy(cur) 39 | out[np.isnan(cur)] = new 40 | return out 41 | 42 | 43 | def _replace_missing_f8(cur, new): 44 | out = np.copy(cur) 45 | out[np.isnan(cur)] = new 46 | return out 47 | 48 | 49 | CELLREPR = { 50 | CR_UINT1: dict( 51 | dtype=np.dtype("uint8"), 52 | fillmv=_replace_missing_u1, 53 | ), 54 | CR_INT4: dict( 55 | dtype=np.dtype("int32"), 56 | fillmv=_replace_missing_i4, 57 | ), 58 | CR_REAL4: dict( 59 | dtype=np.dtype("float32"), 60 | fillmv=_replace_missing_f4, 61 | ), 62 | CR_REAL8: dict( 63 | dtype=np.dtype("float64"), 64 | fillmv=_replace_missing_f8, 65 | ), 66 | } 67 | 68 | 69 | def from_file(path, mask=False): 70 | """Load a .map file into a numpy array.""" 71 | 72 | with open(path, "rb") as f: 73 | bytes = f.read() 74 | 75 | nbytes_header = 64 + 2 + 2 + 8 + 8 + 8 + 8 + 4 + 4 + 8 + 8 + 8 76 | _, cellRepr, _, _, _, _, nrRows, nrCols, _, _, _ = unpack( 77 | "=hhddddIIddd", bytes[64:nbytes_header] 78 | ) 79 | 80 | try: 81 | celltype = CELLREPR[cellRepr] 82 | except KeyError: 83 | raise Exception( 84 | "{}: invalid cellRepr value ({}) in header".format(path, cellRepr) 85 | ) 86 | 87 | 
dtype = celltype["dtype"] 88 | 89 | size = dtype.itemsize * nrRows * nrCols 90 | data = np.frombuffer(bytes[256 : 256 + size], dtype) 91 | if mask: 92 | data = celltype["fillmv"](data.astype(np.float64), np.nan) 93 | 94 | data = data.reshape((nrRows, nrCols)) 95 | 96 | return data 97 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_core/_move.py: -------------------------------------------------------------------------------- 1 | from ._accumulate import _ufunc_to_downstream 2 | from .flow import propagate 3 | 4 | 5 | def move_downstream( 6 | xp, 7 | river_network, 8 | field, 9 | func, 10 | node_additive_weight=None, 11 | node_multiplicative_weight=None, 12 | node_modifier_use_upstream=True, 13 | edge_additive_weight=None, 14 | edge_multiplicative_weight=None, 15 | ): 16 | invert_graph = False 17 | return move_python( 18 | xp, 19 | river_network, 20 | field, 21 | func, 22 | invert_graph, 23 | node_additive_weight, 24 | node_multiplicative_weight, 25 | node_modifier_use_upstream, 26 | edge_additive_weight, 27 | edge_multiplicative_weight, 28 | ) 29 | 30 | 31 | def move_upstream( 32 | xp, 33 | river_network, 34 | field, 35 | func, 36 | node_additive_weight=None, 37 | node_multiplicative_weight=None, 38 | node_modifier_use_upstream=True, 39 | edge_additive_weight=None, 40 | edge_multiplicative_weight=None, 41 | ): 42 | invert_graph = True 43 | return move_python( 44 | xp, 45 | river_network, 46 | field, 47 | func, 48 | invert_graph, 49 | node_additive_weight, 50 | node_multiplicative_weight, 51 | node_modifier_use_upstream, 52 | edge_additive_weight, 53 | edge_multiplicative_weight, 54 | ) 55 | 56 | 57 | def move_python( 58 | xp, 59 | river_network, 60 | field, 61 | func, 62 | invert_graph=False, 63 | node_additive_weight=None, 64 | node_multiplicative_weight=None, 65 | node_modifier_use_upstream=True, 66 | edge_additive_weight=None, 67 | edge_multiplicative_weight=None, 68 | ): 69 | op = _ufunc_to_downstream 70 | 
71 | def operation( 72 | field, 73 | did, 74 | uid, 75 | eid, 76 | node_additive_weight, 77 | node_multiplicative_weight, 78 | node_modifier_use_upstream, 79 | edge_additive_weight, 80 | edge_multiplicative_weight, 81 | ): 82 | return op( 83 | field, 84 | did, 85 | uid, 86 | eid, 87 | node_additive_weight, 88 | node_multiplicative_weight, 89 | node_modifier_use_upstream, 90 | edge_additive_weight, 91 | edge_multiplicative_weight, 92 | func=func, 93 | xp=xp, 94 | ) 95 | 96 | field = propagate( 97 | river_network, 98 | river_network.data, 99 | field, 100 | invert_graph, 101 | operation, 102 | node_additive_weight, 103 | node_multiplicative_weight, 104 | node_modifier_use_upstream, 105 | edge_additive_weight, 106 | edge_multiplicative_weight, 107 | ) 108 | 109 | return field 110 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_core/move.py: -------------------------------------------------------------------------------- 1 | from ._move import move_python as flow 2 | from .metrics import metrics_func_finder 3 | 4 | 5 | def calculate_move_metric( 6 | xp, 7 | river_network, 8 | field, 9 | metric, 10 | node_weights, 11 | edge_weights, 12 | flow_direction, 13 | ): 14 | if flow_direction == "up": 15 | invert_graph = True 16 | node_modifier_use_upstream = True 17 | elif flow_direction == "down": 18 | invert_graph = False 19 | node_modifier_use_upstream = True 20 | else: 21 | raise ValueError( 22 | f"flow_direction must be 'up' or 'down', got {flow_direction}." 
23 | ) 24 | 25 | if node_weights is None: 26 | if metric == "mean" or metric == "std" or metric == "var": 27 | node_weights = xp.ones(river_network.n_nodes, dtype=xp.float64) 28 | else: 29 | node_weights = xp.copy(node_weights) 30 | 31 | if edge_weights is not None: 32 | edge_weights = xp.copy(edge_weights) 33 | 34 | func = metrics_func_finder(metric, xp).func 35 | 36 | weighted_field = flow( 37 | xp, 38 | river_network, 39 | xp.zeros(field.shape), 40 | func, 41 | invert_graph, 42 | node_additive_weight=field if node_weights is None else field * node_weights, 43 | node_modifier_use_upstream=node_modifier_use_upstream, 44 | edge_multiplicative_weight=edge_weights, 45 | ) 46 | 47 | if metric == "mean" or metric == "std" or metric == "var": 48 | counts = flow( 49 | xp, 50 | river_network, 51 | xp.zeros(field.shape), 52 | func, 53 | invert_graph, 54 | node_additive_weight=xp.copy(node_weights), 55 | node_modifier_use_upstream=node_modifier_use_upstream, 56 | edge_multiplicative_weight=edge_weights, 57 | ) 58 | 59 | if metric == "mean": 60 | weighted_field /= counts 61 | return weighted_field 62 | elif metric == "var" or metric == "std": 63 | weighted_sum_of_squares = flow( 64 | xp, 65 | river_network, 66 | xp.zeros(field.shape), 67 | func, 68 | invert_graph, 69 | node_additive_weight=( 70 | field**2 if node_weights is None else field**2 * node_weights 71 | ), 72 | node_modifier_use_upstream=node_modifier_use_upstream, 73 | edge_multiplicative_weight=edge_weights, 74 | ) 75 | mean = weighted_field / counts 76 | weighted_sum_of_squares = weighted_sum_of_squares / counts - mean**2 77 | weighted_sum_of_squares = xp.clip(weighted_sum_of_squares, 0, xp.inf) 78 | if metric == "var": 79 | return weighted_sum_of_squares 80 | elif metric == "std": 81 | return xp.sqrt(weighted_sum_of_squares) 82 | else: 83 | return weighted_field 84 | -------------------------------------------------------------------------------- /tests/_test_inputs/distance.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | stations = [(0, 0), (1, 1), (3, 3)] 4 | 5 | weights_1 = np.array([6, 1, 2, 3, 4, 7, 1, 5, 5, 0, 6, 1, 0, 9, 9, 8, 3, 0, 6, 4]) 6 | 7 | distance_1_min_up_down = np.array( 8 | [ 9 | [0.0, 1.0, 9.0, 10.0, 11.0], 10 | [6.0, 0.0, 7.0, 7.0, 11.0], 11 | [8.0, 1.0, 2.0, 11.0, 20.0], 12 | [10.0, 2.0, 2.0, 0.0, 4.0], 13 | ] 14 | ) 15 | 16 | distance_1_min_up = np.array( 17 | [ 18 | [0.0, 1.0, np.inf, np.inf, np.inf], 19 | [np.inf, 0.0, np.inf, np.inf, np.inf], 20 | [np.inf, np.inf, np.inf, np.inf, np.inf], 21 | [np.inf, np.inf, np.inf, 0.0, 4.0], 22 | ] 23 | ) 24 | 25 | distance_1_min_down = np.array( 26 | [ 27 | [0.0, np.inf, np.inf, np.inf, np.inf], 28 | [6.0, 0.0, np.inf, np.inf, np.inf], 29 | [13.0, 1.0, np.inf, np.inf, np.inf], 30 | [np.inf, 2.0, 6.0, 0.0, np.inf], 31 | ] 32 | ) 33 | 34 | length_1_min_up_down = np.array( 35 | [ 36 | [6.0, 2.0, 12.0, 13.0, 14.0], 37 | [13.0, 1.0, 10.0, 10.0, 14.0], 38 | [11.0, 2.0, 5.0, 14.0, 23.0], 39 | [13.0, 5.0, 5.0, 6.0, 10.0], 40 | ] 41 | ) 42 | 43 | length_1_min_up = np.array( 44 | [ 45 | [6.0, 2.0, np.inf, np.inf, np.inf], 46 | [np.inf, 1.0, np.inf, np.inf, np.inf], 47 | [np.inf, np.inf, np.inf, np.inf, np.inf], 48 | [np.inf, np.inf, np.inf, 6.0, 10.0], 49 | ] 50 | ) 51 | 52 | length_1_min_down = np.array( 53 | [ 54 | [6.0, np.inf, np.inf, np.inf, np.inf], 55 | [13.0, 1.0, np.inf, np.inf, np.inf], 56 | [19.0, 2.0, np.inf, np.inf, np.inf], 57 | [np.inf, 5.0, 6.0, 6.0, np.inf], 58 | ] 59 | ) 60 | 61 | # def distance_1_max_up_down(): 62 | # return 63 | 64 | distance_1_max_up = np.array( 65 | [ 66 | [0.0, 1.0, -np.inf, -np.inf, -np.inf], 67 | [-np.inf, 0.0, -np.inf, -np.inf, -np.inf], 68 | [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf], 69 | [-np.inf, -np.inf, -np.inf, 0.0, 4.0], 70 | ] 71 | ) 72 | 73 | distance_1_max_down = np.array( 74 | [ 75 | [0.0, -np.inf, -np.inf, -np.inf, -np.inf], 76 | [6.0, 0.0, -np.inf, -np.inf, 
-np.inf], 77 | [13.0, 1.0, -np.inf, -np.inf, -np.inf], 78 | [-np.inf, 19.0, 6.0, 0.0, -np.inf], 79 | ] 80 | ) 81 | 82 | # def length_1_max_up_down(): 83 | # return 84 | 85 | length_1_max_up = np.array( 86 | [ 87 | [6.0, 2.0, -np.inf, -np.inf, -np.inf], 88 | [-np.inf, 1.0, -np.inf, -np.inf, -np.inf], 89 | [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf], 90 | [-np.inf, -np.inf, -np.inf, 6.0, 10.0], 91 | ] 92 | ) 93 | 94 | 95 | length_1_max_down = np.array( 96 | [ 97 | [6.0, -np.inf, -np.inf, -np.inf, -np.inf], 98 | [13.0, 1.0, -np.inf, -np.inf, -np.inf], 99 | [19.0, 2.0, -np.inf, -np.inf, -np.inf], 100 | [-np.inf, 22.0, 6.0, 6.0, -np.inf], 101 | ] 102 | ) 103 | -------------------------------------------------------------------------------- /src/earthkit/hydro/catchments/array/__operations.py: -------------------------------------------------------------------------------- 1 | from earthkit.hydro._core._find import _flow_find 2 | from earthkit.hydro._utils.decorators import mask 3 | from earthkit.hydro.upstream.array._operations import calculate_upstream_metric 4 | 5 | 6 | def calculate_catchment_metric( 7 | xp, 8 | river_network, 9 | field, 10 | stations_1d, 11 | metric, 12 | node_weights, 13 | edge_weights, 14 | ): 15 | upstream_metric_field = calculate_upstream_metric( 16 | xp, 17 | river_network, 18 | field, 19 | metric, 20 | node_weights, 21 | edge_weights, 22 | ) 23 | return xp.gather(upstream_metric_field, stations_1d, axis=-1) 24 | 25 | 26 | @mask(unmask=False) 27 | def var(xp, river_network, field, locations, node_weights, edge_weights): 28 | return calculate_catchment_metric( 29 | xp, 30 | river_network, 31 | field, 32 | locations, 33 | "var", 34 | node_weights, 35 | edge_weights, 36 | ) 37 | 38 | 39 | @mask(unmask=False) 40 | def std(xp, river_network, field, locations, node_weights, edge_weights): 41 | return calculate_catchment_metric( 42 | xp, 43 | river_network, 44 | field, 45 | locations, 46 | "std", 47 | node_weights, 48 | edge_weights, 49 | ) 50 | 51 
52 | @mask(unmask=False) 53 | def mean(xp, river_network, field, locations, node_weights, edge_weights): 54 | return calculate_catchment_metric( 55 | xp, 56 | river_network, 57 | field, 58 | locations, 59 | "mean", 60 | node_weights, 61 | edge_weights, 62 | ) 63 | 64 | 65 | @mask(unmask=False) 66 | def sum(xp, river_network, field, locations, node_weights, edge_weights): 67 | return calculate_catchment_metric( 68 | xp, 69 | river_network, 70 | field, 71 | locations, 72 | "sum", 73 | node_weights, 74 | edge_weights, 75 | ) 76 | 77 | 78 | @mask(unmask=False) 79 | def min(xp, river_network, field, locations, node_weights, edge_weights): 80 | return calculate_catchment_metric( 81 | xp, 82 | river_network, 83 | field, 84 | locations, 85 | "min", 86 | node_weights, 87 | edge_weights, 88 | ) 89 | 90 | 91 | @mask(unmask=False) 92 | def max(xp, river_network, field, locations, node_weights, edge_weights): 93 | return calculate_catchment_metric( 94 | xp, 95 | river_network, 96 | field, 97 | locations, 98 | "max", 99 | node_weights, 100 | edge_weights, 101 | ) 102 | 103 | 104 | def find(xp, river_network, field, overwrite, return_type): 105 | return_type = river_network.return_type if return_type is None else return_type 106 | if return_type not in ["gridded", "masked"]: 107 | raise ValueError("return_type must be either 'gridded' or 'masked'.") 108 | decorated_flow_find = mask(return_type == "gridded")(_flow_find) 109 | return decorated_flow_find(xp, river_network, field, overwrite) 110 | -------------------------------------------------------------------------------- /docs/source/userguide/raster_vector_inputs.rst: -------------------------------------------------------------------------------- 1 | Raster and vector networks 2 | ========================== 3 | 4 | earthkit-hydro supports both raster (gridded) and vector river networks. These network types have slightly different characteristics which necessitate some clarification. 5 | 6 | .. raw:: html 7 | 8 | <div style="text-align: center;">
9 | 10 | .. figure:: ../../images/raster_vector_networks.jpg 11 | :width: 400px 12 | 13 | *From:* :cite:`rastervector` 14 | 15 | .. raw:: html 16 | 17 | </div>
18 | 19 | Raster networks 20 | --------------- 21 | 22 | Raster networks are the most common type of river network in earthkit-hydro. They are represented as a grid where each cell corresponds to a specific location in the river network. 23 | With these types of networks, it is often most natural to conduct river network operations directly on this grid. 24 | 25 | .. code-block:: python 26 | 27 | import earthkit.hydro as ekh 28 | 29 | # Load a raster river network 30 | network = ekh.river_network.load("efas", "5") 31 | 32 | field = np.ones(network.shape) # field on the river network grid 33 | 34 | output = ekh.upstream.sum(network, field) # output on same grid 35 | 36 | Vector networks 37 | --------------- 38 | 39 | In vector networks, each river segment is represented as a node, and the network is defined by the connections between these nodes. 40 | 41 | .. code-block:: python 42 | 43 | # vector field (1D) defined on the nodes of the river network 44 | field = np.ones(network.n_nodes) 45 | 46 | # output field is also 1D, defined on the nodes of the river network 47 | output = ekh.upstream.sum(network, field, return_grid=False) 48 | 49 | Switching between vector and raster 50 | ----------------------------------- 51 | 52 | Raster networks can also be used as if they were vector networks, since internally the raster network is represented as a vector network. 53 | To allow users to work with both types of networks seamlessly, earthkit-hydro provides the ``return_grid`` function argument, which defaults to ``True``. 54 | This allows the user to specify whether they want a gridded output (only available for raster networks) or a 1D vector output. 55 | 56 | Note that it is possible to mix gridded and vector (1D) inputs to functions, as in the example below. 57 | 58 | .. 
code-block:: python 59 | 60 | # vector field (1D) defined on the nodes of the river network 61 | field = np.ones(network.n_nodes) 62 | 63 | # output field is a grid 64 | output = ekh.upstream.sum(network, field) 65 | 66 | Multidimensional inputs 67 | ----------------------- 68 | 69 | Any leading dimensions of the data are treated as batch/vectorised dimensions, allowing for operations on multiple fields at once. 70 | This means that users can pass directly time series or other multi-dimensional data without needing to manually loop. 71 | The last dimensions are always assumed to be spatial i.e. either the grid dimensions, or the 1d vector dimension. 72 | 73 | .. code-block:: python 74 | 75 | # vector field (1D) defined on the nodes of the river network 76 | field = np.ones((3, 4, 5, network.n_nodes)) 77 | 78 | # output field is of shape (3, 4, 5, *network.shape) 79 | output = ekh.upstream.sum(network, field) 80 | -------------------------------------------------------------------------------- /src/earthkit/hydro/river_network/_cache.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from functools import wraps 4 | from hashlib import sha256 5 | 6 | import joblib 7 | 8 | from earthkit.hydro._version import __version__ as ekh_version 9 | from earthkit.hydro.data_structures._network import RiverNetwork 10 | 11 | # read in only up to second decimal point 12 | # i.e. 0.1.dev90 -> 0.1 13 | ekh_version = ".".join(ekh_version.split(".")[:2]) 14 | 15 | 16 | def cache(func): 17 | """ 18 | Decorator to allow automatic use of cache. 19 | 20 | Parameters 21 | ---------- 22 | func : callable 23 | The function to be wrapped and executed with masking applied. 24 | 25 | Returns 26 | ------- 27 | callable 28 | The wrapped function. 
29 | """ 30 | 31 | @wraps(func) 32 | def wrapper( 33 | path, 34 | river_network_format, 35 | source="file", 36 | use_cache=True, 37 | cache_dir=tempfile.mkdtemp(suffix="_earthkit_hydro"), 38 | cache_fname="{ekh_version}_{hash}.joblib", 39 | cache_compression=1, 40 | ): 41 | """ 42 | Wrapper to load river network from cache if available, otherwise 43 | create and cache it. 44 | 45 | Parameters 46 | ---------- 47 | path : str 48 | The path to the river network. 49 | river_network_format : str 50 | The format of the river network file. 51 | Supported formats are "precomputed", "cama", "pcr_d8", and "esri_d8". 52 | source : str, optional 53 | The source of the river network. 54 | For possible sources see: 55 | https://earthkit-data.readthedocs.io/en/latest/guide/sources.html 56 | use_cache : bool, optional 57 | Whether to use caching. Default is True. 58 | cache_dir : str, optional 59 | The directory to store the cache files. Default is a temporary directory. 60 | cache_fname : str, optional 61 | The filename template for the cache files. 62 | Default is "{ekh_version}_{hash}.joblib". 63 | cache_compression : int, optional 64 | The compression level for the cache files. Default is 1. 65 | 66 | Returns 67 | ------- 68 | earthkit.hydro.network_class.RiverNetwork 69 | The loaded river network. 
70 | """ 71 | if use_cache: 72 | hashed_name = sha256(path.encode("utf-8")).hexdigest() 73 | cache_dir = cache_dir.format(ekh_version=ekh_version, hash=hashed_name) 74 | cache_fname = cache_fname.format(ekh_version=ekh_version, hash=hashed_name) 75 | cache_filepath = os.path.join(cache_dir, cache_fname) 76 | 77 | if os.path.isfile(cache_filepath): 78 | print(f"Loading river network from cache ({cache_filepath}).") 79 | return RiverNetwork(joblib.load(cache_filepath)) 80 | else: 81 | print(f"River network not found in cache ({cache_filepath}).") 82 | os.makedirs(cache_dir, exist_ok=True) 83 | else: 84 | print("Cache disabled.") 85 | 86 | network = func(path, river_network_format, source) 87 | 88 | if use_cache: 89 | joblib.dump(network._storage, cache_filepath, compress=cache_compression) 90 | print(f"River network loaded, saving to cache ({cache_filepath}).") 91 | 92 | return network 93 | 94 | return wrapper 95 | -------------------------------------------------------------------------------- /docs/source/userguide/xarray_array_backend.rst: -------------------------------------------------------------------------------- 1 | Handling xarray and multiple array backends 2 | =========================================== 3 | 4 | earthkit-hydro is designed to work seamlessly with xarray and multiple array backends, including numpy, cupy, torch, jax, and tensorflow. This flexibility allows users to choose the backend that best suits their computational needs, whether for CPU or GPU operations. 5 | 6 | Changing the river network array backend 7 | ---------------------------------------- 8 | 9 | By default, a river network is loaded for the numpy backend. However, it can be easily converted to other backends via the `to_device` method. 10 | 11 | .. 
code-block:: python 12 | 13 | import earthkit.hydro as ekh 14 | 15 | network = ekh.river_network.load("efas", "5").to_device(array_backend="torch") 16 | 17 | The network can also be transferred to a specific device, such as a GPU: 18 | 19 | .. code-block:: python 20 | 21 | network = ekh.river_network.load("efas", "5").to_device("cuda", "torch") 22 | 23 | xarray and array-backend agnostic operations 24 | -------------------------------------------- 25 | 26 | earthkit-hydro is designed with array-backend agnostic operations in mind: each operation has two versions, a top-level xarray-oriented version and an array version. 27 | 28 | .. code-block:: python 29 | 30 | # xarray-oriented operation (returns xarray) 31 | ekh.submodule.operation(...) 32 | 33 | # array-oriented operation (returns arrays) 34 | ekh.submodule.array.operation(...) 35 | 36 | This design allows users to work primarily with xarray objects, while still having access to lower-level array operations when needed. 37 | 38 | The philosophy for the xarray-oriented operations is to return the same type of object as the input where possible, ensuring consistency across operations. For example: 39 | 40 | ..
code-block:: python 41 | 42 | import earthkit.data as ekd 43 | network = ekh.river_network.load("efas", "5") 44 | 45 | # xarray dataset inputted, xarray dataset returned 46 | ds = ekd.from_source("file", "data.nc").to_xarray() 47 | output = ekh.upstream.sum(network, ds) 48 | assert isinstance(output, xr.Dataset) 49 | 50 | # xarray dataarray inputted, xarray dataarray returned 51 | da = ds['main_variable'] 52 | output = ekh.upstream.sum(network, da) 53 | assert isinstance(output, xr.DataArray) 54 | 55 | # If no xarray object is provided, a dataarray is returned 56 | # array inputted, xarray dataarray returned 57 | arr = np.ones(network.shape) 58 | output = ekh.upstream.sum(network, arr) 59 | assert isinstance(output, xr.DataArray) 60 | 61 | Array backends are automatically detected from the river network, for example: 62 | 63 | .. code-block:: python 64 | 65 | network = ekh.river_network.load("efas", "5") 66 | input_array = np.ones(network.shape) 67 | output = ekh.upstream.array.sum(network, input_array) # numpy array returned 68 | assert isinstance(output, np.ndarray) 69 | 70 | network = ekh.river_network.load("efas", "5").to_device(array_backend="torch") 71 | input_array = torch.ones(network.shape) 72 | output = ekh.upstream.array.sum(network, input_array) # torch tensor returned 73 | assert isinstance(output, torch.Tensor) 74 | 75 | # Note: trying to use a numpy array with a torch-backed river network will raise an error 76 | 77 | This means that users can switch between array backends without changing their code, as long as the input arrays match the river network's array backend. It also allows seamless support for xarray objects with a cupy backend.
78 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_backends/tensorflow_backend.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from .array_backend import ArrayBackend 4 | 5 | 6 | class TFBackend(ArrayBackend): 7 | def __init__(self): 8 | super().__init__(tf) 9 | 10 | @property 11 | def name(self): 12 | return "tensorflow" 13 | 14 | def copy(self, x): 15 | return x 16 | 17 | def scatter_assign(self, target, indices, updates): 18 | target_shape = tf.shape(target) 19 | batch_dims = target_shape[:-1] 20 | num_batch = tf.reduce_prod(batch_dims) 21 | num_idx = tf.shape(indices)[0] 22 | 23 | flat_target = tf.reshape(target, (num_batch, -1)) 24 | flat_values = tf.reshape(updates, (num_batch, num_idx)) 25 | 26 | batch_range = tf.range(num_batch)[:, None] 27 | batch_ids = tf.tile(batch_range, [1, num_idx]) 28 | scatter_idx = tf.stack( 29 | [batch_ids, tf.tile(tf.expand_dims(indices, 0), [num_batch, 1])], axis=-1 30 | ) 31 | scatter_idx = tf.reshape(scatter_idx, (-1, 2)) 32 | 33 | scatter_vals = tf.reshape(flat_values, (-1,)) 34 | flat_result = tf.tensor_scatter_nd_update( 35 | flat_target, scatter_idx, scatter_vals 36 | ) 37 | 38 | return tf.reshape(flat_result, target_shape) 39 | 40 | def scatter_add(self, target, indices, updates): 41 | target_shape = tf.shape(target) 42 | batch_shape = target_shape[:-1] 43 | depth = target_shape[-1] 44 | num_indices = tf.shape(indices)[0] 45 | 46 | flat_batch_size = tf.reduce_prod(batch_shape) 47 | K = num_indices 48 | D = depth 49 | updates_flat = tf.reshape(updates, [-1]) 50 | 51 | segment_ids = tf.tile(indices, [flat_batch_size]) 52 | 53 | batch_ids = tf.repeat(tf.range(flat_batch_size, dtype=tf.int32), repeats=K) 54 | 55 | combined_segments = batch_ids * D + segment_ids 56 | 57 | scattered_flat = tf.math.unsorted_segment_sum( 58 | data=updates_flat, 59 | segment_ids=combined_segments, 60 | 
num_segments=flat_batch_size * D, 61 | ) 62 | 63 | scattered = tf.reshape(scattered_flat, [flat_batch_size, D]) 64 | 65 | result = tf.reshape(scattered, tf.concat([batch_shape, [depth]], axis=0)) 66 | return target + result 67 | 68 | def gather(self, arr, indices, axis=-1): 69 | return tf.gather(arr, indices, axis=axis) 70 | 71 | def full_like(self, arr, value, *args, **kwargs): 72 | return tf.fill(tf.shape(arr), value, *args, **kwargs) 73 | 74 | def full(self, shape, value, *args, **kwargs): 75 | kwargs.pop("device") 76 | dtype = kwargs.pop("dtype") 77 | out = tf.fill(shape, value, *args, **kwargs) 78 | if dtype: 79 | return tf.cast(out, dtype) 80 | else: 81 | return out 82 | 83 | @property 84 | def nan(self): 85 | return float("nan") 86 | 87 | @property 88 | def inf(self): 89 | return float("inf") 90 | 91 | def asarray(self, arr, dtype=None, device=None, copy=None): 92 | tensor = tf.convert_to_tensor(arr, dtype=dtype) 93 | 94 | if copy and device is None: 95 | tensor = tf.identity(tensor) 96 | 97 | if device is not None: 98 | with tf.device(device): 99 | tensor = tf.identity(tensor) 100 | return tensor 101 | 102 | def clip(self, *args, **kwargs): 103 | return tf.clip_by_value(*args, **kwargs) 104 | -------------------------------------------------------------------------------- /docs/source/userguide/flow_accumulations.rst: -------------------------------------------------------------------------------- 1 | Flow accumulations 2 | ================== 3 | 4 | Flow accumulations are a fundamental aspect of hydrological modeling, allowing for the analysis of how water flows through a river network. 5 | Fundamentally, there are two different types of flow accumulations: full flow accumulations (global aggregation) and one-step neighbor accumulations (local aggregation). 6 | 7 | Full flow accumulation (global aggregation) 8 | ------------------------------------------- 9 | 10 | A full flow accumulation is a global aggregation of flow across the entire river network. 
11 | 12 | .. image:: ../../images/accuflux.gif 13 | :width: 250px 14 | :align: right 15 | 16 | Typically, this is computed by starting from the sources and flowing downstream to the sinks. 17 | The most common aggregation function is the sum, but it is also possible to compute averages, maxima, minima, etc. over all upstream nodes. 18 | 19 | This can be done in earthkit-hydro using the `upstream` submodule, which computes a metric of a field over all upstream nodes in the river network. 20 | 21 | .. raw:: html 22 | 23 |
24 | 25 | .. code-block:: python 26 | 27 | network = ekh.river_network.load("efas", "5") 28 | field = np.ones(network.n_nodes) 29 | node_weights = np.ones(network.n_nodes) # optional weights for the nodes 30 | edge_weights = np.ones(network.n_edges) # optional weights for the edges 31 | 32 | upstream_sum = ekh.upstream.sum(network, field, node_weights, edge_weights) 33 | upstream_mean = ekh.upstream.mean(network, field, node_weights, edge_weights) 34 | upstream_max = ekh.upstream.max(network, field, node_weights, edge_weights) 35 | upstream_min = ekh.upstream.min(network, field, node_weights, edge_weights) 36 | upstream_std = ekh.upstream.std(network, field, node_weights, edge_weights) 37 | upstream_var = ekh.upstream.var(network, field, node_weights, edge_weights) 38 | 39 | While flow accumulations typically go from sources to sinks, it is also possible to compute the flow accumulation in the reverse direction, from sinks to sources. 40 | The `downstream` submodule provides this functionality, with an API analogous to that of the `upstream` submodule. 41 | 42 | .. code-block:: python 43 | 44 | downstream_sum = ekh.downstream.sum(network, field, node_weights, edge_weights) 45 | downstream_mean = ekh.downstream.mean(network, field, node_weights, edge_weights) 46 | downstream_max = ekh.downstream.max(network, field, node_weights, edge_weights) 47 | downstream_min = ekh.downstream.min(network, field, node_weights, edge_weights) 48 | downstream_std = ekh.downstream.std(network, field, node_weights, edge_weights) 49 | downstream_var = ekh.downstream.var(network, field, node_weights, edge_weights) 50 | 51 | One-step neighbor accumulation (local aggregation) 52 | -------------------------------------------------- 53 | 54 | In contrast to a global accumulation, a one-step neighbor accumulation is a local aggregation of flow across the immediate neighbors of each node in the river network.
55 | This is analagous to a message passing operation in graph networks, where each node receives a message from its neighbors and aggregates it. 56 | 57 | Again, typically this is computed downstream, but it is also possible to compute it upstream. Both of these functionalities are provided by the `move` submodule. 58 | The aggregation function is specified via the `metric` argument. 59 | 60 | .. code-block:: python 61 | 62 | ekh.move.downstream(network, field, node_weights, edge_weights, metric='sum') 63 | ekh.move.upstream(network, field, node_weights, edge_weights, metric='sum') 64 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_core/accumulate.py: -------------------------------------------------------------------------------- 1 | from earthkit.hydro.data_structures._network import RiverNetwork 2 | 3 | from ._accumulate import _ufunc_to_downstream 4 | from .flow import propagate 5 | 6 | 7 | def flow_downstream( # accumulate `field` with `func` via flow(), keeping the stored edge direction 8 | xp, 9 | river_network, 10 | field, 11 | func, 12 | node_additive_weight=None, 13 | node_multiplicative_weight=None, 14 | node_modifier_use_upstream=True, 15 | edge_additive_weight=None, 16 | edge_multiplicative_weight=None, 17 | ): 18 | invert_graph = False # propagate along the edges as stored 19 | return flow( 20 | xp, 21 | river_network, 22 | field, 23 | func, 24 | invert_graph, 25 | node_additive_weight, 26 | node_multiplicative_weight, 27 | node_modifier_use_upstream, 28 | edge_additive_weight, 29 | edge_multiplicative_weight, 30 | ) 31 | 32 | 33 | def flow_upstream( # same as flow_downstream, but with the graph inverted 34 | xp, 35 | river_network, 36 | field, 37 | func, 38 | node_additive_weight=None, 39 | node_multiplicative_weight=None, 40 | node_modifier_use_upstream=True, 41 | edge_additive_weight=None, 42 | edge_multiplicative_weight=None, 43 | ): 44 | invert_graph = True # propagate against the stored edge direction 45 | return flow( 46 | xp, 47 | river_network, 48 | field, 49 | func, 50 | invert_graph, 51 | node_additive_weight, 52 | node_multiplicative_weight, 53 | node_modifier_use_upstream, 54 |
edge_additive_weight, 55 | edge_multiplicative_weight, 56 | ) 57 | 58 | 59 | def flow( # single entry point; currently always dispatches to the pure-Python flow_python 60 | xp, 61 | river_network: RiverNetwork, 62 | field, 63 | func, 64 | invert_graph=False, 65 | node_additive_weight=None, 66 | node_multiplicative_weight=None, 67 | node_modifier_use_upstream=True, 68 | edge_additive_weight=None, 69 | edge_multiplicative_weight=None, 70 | ): 71 | 72 | return flow_python( 73 | xp, 74 | river_network, 75 | field, 76 | func, 77 | invert_graph, 78 | node_additive_weight, 79 | node_multiplicative_weight, 80 | node_modifier_use_upstream, 81 | edge_additive_weight, 82 | edge_multiplicative_weight, 83 | ) 84 | 85 | 86 | def flow_python( # propagate `field` over river_network.groups, applying _ufunc_to_downstream with `func` 87 | xp, 88 | river_network, 89 | field, 90 | func, 91 | invert_graph=False, 92 | node_additive_weight=None, 93 | node_multiplicative_weight=None, 94 | node_modifier_use_upstream=True, 95 | edge_additive_weight=None, 96 | edge_multiplicative_weight=None, 97 | ): 98 | op = _ufunc_to_downstream # update kernel handed to propagate through `operation` below 99 | 100 | def operation( # closure binding func and xp so propagate only passes field, ids and weights 101 | field, 102 | did, 103 | uid, 104 | eid, 105 | node_additive_weight, 106 | node_multiplicative_weight, 107 | node_modifier_use_upstream, 108 | edge_additive_weight, 109 | edge_multiplicative_weight, 110 | ): 111 | return op( 112 | field, 113 | did, 114 | uid, 115 | eid, 116 | node_additive_weight, 117 | node_multiplicative_weight, 118 | node_modifier_use_upstream, 119 | edge_additive_weight, 120 | edge_multiplicative_weight, 121 | func=func, 122 | xp=xp, 123 | ) 124 | 125 | field = propagate( # traversal is delegated to propagate, which also receives invert_graph 126 | river_network, 127 | river_network.groups, 128 | field, 129 | invert_graph, 130 | operation, 131 | node_additive_weight, 132 | node_multiplicative_weight, 133 | node_modifier_use_upstream, 134 | edge_additive_weight, 135 | edge_multiplicative_weight, 136 | ) 137 | 138 | return field 139 | -------------------------------------------------------------------------------- /docs/source/userguide/distance_length_calculations.rst: -------------------------------------------------------------------------------- 1 |
Distance and length calculations 2 | ================================ 3 | 4 | In earthkit-hydro, a distinction is made between distance and length calculations. This is necessary because the two operations are fundamentally not equivalent, although they are often conflated. 5 | 6 | Distinction between a length and distance 7 | ----------------------------------------- 8 | 9 | In essence, lengths and distances try to capture two different quantities based on two different inputs. 10 | A length is calculated by considering the length of a river network in each gridcell or graph node. As a result, lengths are *node properties*. 11 | There is only one length per gridcell, even if a confluence or bifurcation occurs. 12 | 13 | Distances, on the other hand, are not measured at a node, but are rather specified in terms of the distance from one gridcell to another. 14 | As such, they are *edge properties*, and distances can be different for each branch at a confluence or a bifurcation. 15 | 16 | Even in simple river networks without confluences or bifurcations, lengths and distances are still not equivalent. This distinction is clear in the following contrived example. 17 | 18 | .. image:: ../../images/distance_length.png 19 | :width: 500px 20 | :align: center 21 | 22 | The length here for the highlighted segment is 3. 23 | 24 | By contrast, the distance here for the highlighted segment is only 2. 25 | 26 | Maximum and minimum distances or lengths 27 | ---------------------------------------- 28 | 29 | In the above example, there was only a single path from the source node to the terminal node. 30 | However, in river networks there will often be many paths. It is thus important to consider whether one is interested in the shortest or longest path. 31 | 32 | In earthkit-hydro, all of these different quantities are easily computed via: 33 | 34 | ..
code-block:: python 35 | 36 | network = ekh.river_network.load("efas", "5") 37 | locations = { 38 | "station1": (10, 10), 39 | "station2": (10, 10), 40 | "station3": (10, 10) 41 | } 42 | 43 | # lengths take node-level information 44 | field = np.random.rand(network.n_nodes) 45 | max_length = ekh.length.max(network, locations, field) 46 | min_length = ekh.length.min(network, locations, field) 47 | 48 | # distances take edge-level information 49 | field = np.random.rand(network.n_edges) 50 | max_distance = ekh.distance.max(network, locations, field) 51 | min_distance = ekh.distance.min(network, locations, field) 52 | 53 | Directed and undirected distances or lengths 54 | -------------------------------------------- 55 | 56 | By default, distances and lengths are calculated downstream only. However, some use cases require upstream or undirected distances/lengths. 57 | These are selected via the `upstream` and `downstream` arguments. 58 | 59 | .. code-block:: python 60 | 61 | min_distance_upstream = ekh.distance.min(network, locations, field, upstream=True, downstream=False) 62 | min_distance_downstream = ekh.distance.min(network, locations, field, upstream=False, downstream=True) 63 | min_distance_undirected = ekh.distance.min(network, locations, field, upstream=True, downstream=True) 64 | 65 | As shorthands, earthkit-hydro also provides functions that compute these quantities with respect to the sinks or the sources: 66 | 67 | .. code-block:: python 68 | 69 | ekh.length.to_sink(network, field, path="shortest") 70 | ekh.length.to_source(network, field, path="shortest") 71 | ekh.distance.to_sink(network, field, path="shortest") 72 | ekh.distance.to_source(network, field, path="shortest") 73 | 74 | Longest path versions are also available with `path="longest"`.
75 | -------------------------------------------------------------------------------- /src/earthkit/hydro/catchments/_xarray.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | import numpy as np 4 | import xarray as xr 5 | 6 | from earthkit.hydro._backends.find import get_array_backend 7 | from earthkit.hydro._utils.coords import get_core_dims, node_default_coord 8 | from earthkit.hydro._utils.decorators.xarray import ( 9 | assert_xr_compatible_backend, 10 | get_full_signature, 11 | get_reshuffled_func, 12 | sort_xr_nonxr_args, 13 | ) 14 | from earthkit.hydro._utils.locations import locations_to_1d 15 | 16 | 17 | def get_input_core_dims(input_core_dims, xr_args): 18 | if input_core_dims is None: 19 | input_core_dims = [get_core_dims(xr_arg) for xr_arg in xr_args] 20 | elif len(input_core_dims) == 1: 21 | input_core_dims *= len(xr_args) 22 | 23 | return input_core_dims 24 | 25 | 26 | def xarray(func): 27 | 28 | @wraps(func) 29 | def wrapper(*args, **kwargs): 30 | 31 | # Inspect the function signature and bind all arguments 32 | all_args = get_full_signature(func, *args, **kwargs) 33 | 34 | input_core_dims = all_args.pop("input_core_dims", None) 35 | 36 | assert_xr_compatible_backend(all_args["river_network"]) 37 | 38 | river_network = all_args["river_network"] 39 | 40 | xp = get_array_backend(river_network.array_backend) 41 | 42 | locations = all_args["locations"] 43 | 44 | stations_1d, locations, orig_locations = locations_to_1d( 45 | xp, river_network, locations 46 | ) 47 | 48 | all_args["locations"] = stations_1d 49 | 50 | # Separate xarray and non-xarray arguments 51 | xr_args, non_xr_kwargs, arg_order = sort_xr_nonxr_args(all_args) 52 | 53 | if len(xr_args) == 0: 54 | output = func(**all_args) 55 | 56 | ndim = output.ndim 57 | dim_names = [f"axis{i + 1}" for i in range(ndim - 1)] 58 | coords = { 59 | dim: np.arange(size) for dim, size in zip(dim_names, output.shape[:-1]) 60 | } 61 | 62 | 
coords[node_default_coord] = np.arange(river_network.n_nodes)[stations_1d] 63 | dim_names.append(node_default_coord) 64 | 65 | result = xr.DataArray(output, dims=dim_names, coords=coords, name="out") 66 | 67 | else: 68 | 69 | reshuffled_func = get_reshuffled_func(func, arg_order) 70 | 71 | input_core_dims = get_input_core_dims(input_core_dims, xr_args) 72 | 73 | result = xr.apply_ufunc( 74 | reshuffled_func, 75 | *xr_args, 76 | input_core_dims=input_core_dims, 77 | output_core_dims=[[node_default_coord]], 78 | dask_gufunc_kwargs={ 79 | "output_sizes": {node_default_coord: stations_1d.shape[0]} 80 | }, 81 | output_dtypes=[float], 82 | dask="parallelized", 83 | kwargs=non_xr_kwargs, 84 | ) 85 | assign_dict = { 86 | node_default_coord: ( 87 | node_default_coord, 88 | np.arange(river_network.n_nodes)[stations_1d], 89 | ) 90 | } 91 | result = result.assign_coords(**assign_dict) 92 | 93 | coords = list(river_network.coords.values())[::-1] 94 | coords_grid = np.meshgrid(*coords)[::-1] 95 | assign_dict = { 96 | k: (node_default_coord, v.flat[river_network.mask][stations_1d]) 97 | for k, v in zip(river_network.coords.keys(), coords_grid) 98 | } 99 | if isinstance(orig_locations, dict): 100 | assign_dict["name"] = (node_default_coord, list(orig_locations.keys())) 101 | result = result.assign_coords(**assign_dict) 102 | 103 | return result 104 | 105 | return wrapper 106 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=65", "setuptools-rust", "setuptools_scm[toml]>=8"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "earthkit-hydro" 7 | requires-python = ">=3.9" 8 | authors = [ 9 | { name = "European Centre for Medium-Range Weather Forecasts (ECMWF)", email = "software.support@ecmwf.int" } 10 | ] 11 | maintainers = [ 12 | {name = "Oisín M. 
Morrison", email = "oisin.morrison@ecmwf.int"}, 13 | {name = "Corentin Carton de Wiart", email = "corentin.carton@ecmwf.int"} 14 | ] 15 | description = "A Python library for common hydrological functions" 16 | license = { text = "Apache License Version 2.0" } 17 | classifiers = [ 18 | "Development Status :: 4 - Beta", 19 | "Intended Audience :: Science/Research", 20 | "Natural Language :: English", 21 | "License :: OSI Approved :: Apache Software License", 22 | "Operating System :: OS Independent", 23 | "Programming Language :: Python :: 3 :: Only", 24 | "Programming Language :: Python :: 3.9", 25 | "Programming Language :: Python :: 3.10", 26 | "Programming Language :: Python :: 3.11", 27 | "Programming Language :: Python :: 3.12", 28 | "Programming Language :: Python :: 3.13", 29 | "Programming Language :: Python :: 3.14", 30 | "Topic :: Scientific/Engineering" 31 | ] 32 | dynamic = ["version", "readme"] 33 | 34 | dependencies = [ 35 | "numpy", 36 | "joblib", 37 | "xarray", 38 | "earthkit-utils>=0.1.0" 39 | ] 40 | 41 | [project.urls] 42 | repository = "https://github.com/ecmwf/earthkit-hydro" 43 | documentation = "https://earthkit-hydro.readthedocs.io/" 44 | issues = "https://github.com/ecmwf/earthkit-hydro/issues" 45 | 46 | [project.optional-dependencies] 47 | readers = [ 48 | "earthkit-data[geotiff]>=0.13.8" 49 | ] 50 | grit = [ 51 | "geopandas", 52 | "pandas" 53 | ] 54 | tests = [ 55 | "pytest", 56 | "torch", 57 | "jax" 58 | ] 59 | dev = [ 60 | "pytest", 61 | "pre-commit" 62 | ] 63 | docs = [ 64 | "sphinx", 65 | "furo", 66 | "sphinxcontrib-bibtex", 67 | "nbsphinx", 68 | "nbconvert", 69 | "ipykernel" 70 | ] 71 | all = [ 72 | "earthkit-data[geotiff]>=0.13.8", 73 | "pytest", 74 | "pre-commit" 75 | ] 76 | 77 | [tool.black] 78 | line-length = 88 79 | skip-string-normalization = false 80 | 81 | [tool.isort] 82 | profile = "black" # Ensures compatibility with Black's formatting. 83 | line_length = 88 # Same as Black's line length for consistency. 
84 | 85 | # Linting settings 86 | [tool.ruff] 87 | line-length = 88 88 | 89 | [tool.ruff.format] 90 | quote-style = "double" 91 | 92 | [tool.ruff.lint.per-file-ignores] 93 | "__init__.py" = [ 94 | "F401", # unused imports 95 | ] 96 | "tests/*" = [ 97 | "F405", # variable may be undefined, or defined from star imports 98 | "F403", # use of wildcard imports 99 | ] 100 | 101 | # Testing 102 | [tool.pytest] 103 | addopts = "--pdbcls=IPython.terminal.debugger:Pdb" 104 | testpaths = ["tests"] 105 | 106 | # Packaging/setuptools options 107 | [tool.setuptools] 108 | include-package-data = true 109 | 110 | [tool.setuptools.dynamic] 111 | readme = {file = ["README.md"], content-type = "text/markdown"} 112 | 113 | [tool.setuptools.packages.find] 114 | include = [ "earthkit.hydro" ] 115 | where = [ "src/" ] 116 | 117 | [tool.setuptools_scm] 118 | version_file = "src/earthkit/hydro/_version.py" 119 | version_file_template = ''' 120 | # Do not change! Do not track in version control! 121 | __version__ = "{version}" 122 | ''' 123 | parentdir_prefix_version='earthkit-hydro-' # get version from GitHub-like tarballs 124 | fallback_version='0.1.0' 125 | local_scheme = "no-local-version" 126 | 127 | # takes inspiration from https://github.com/pypa/cibuildwheel/discussions/1814 128 | [tool.cibuildwheel] 129 | skip = ["*musl*", "*-win_arm64"] # ignore some problematic platforms 130 | linux.before-all = "curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --profile minimal" 131 | linux.environment = { PATH="$HOME/.cargo/bin:$PATH" } 132 | macos.before-all = "curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --profile minimal" 133 | windows.before-all = "rustup target add aarch64-pc-windows-msvc i686-pc-windows-msvc x86_64-pc-windows-msvc" 134 | windows.archs = "all" 135 | -------------------------------------------------------------------------------- /src/earthkit/hydro/subnetwork/_toplevel.py: 
-------------------------------------------------------------------------------- 1 | import copy as cp 2 | 3 | from earthkit.hydro._backends.numpy_backend import NumPyBackend 4 | from earthkit.hydro._utils.decorators.masking import mask_last2_dims 5 | from earthkit.hydro.data_structures import RiverNetwork 6 | 7 | np = NumPyBackend() 8 | 9 | 10 | def from_mask(river_network: RiverNetwork, node_mask=None, edge_mask=None, copy=True): 11 | """ 12 | Create a subnetwork from a river network. 13 | 14 | Parameters 15 | ---------- 16 | river_network : RiverNetwork 17 | Original river network from which to create a subnetwork. 18 | node_mask : array, optional 19 | A mask of the network nodes or gridcells. Default is None (all True). 20 | edge_mask : array, optional 21 | A mask of the network edges. Default is None (all True). 22 | copy : bool, optional 23 | Whether or not to modify the original river network or return a copy. Default is True. 24 | 25 | Returns 26 | ------- 27 | RiverNetwork 28 | The river network object created from the given data. 
29 | """ 30 | if river_network.array_backend != "numpy" or copy is not True: 31 | raise NotImplementedError 32 | 33 | if node_mask is None and edge_mask is None: 34 | return cp.deepcopy(river_network) 35 | 36 | if node_mask is not None: 37 | if node_mask.shape[-2:] == river_network.shape: 38 | node_mask = mask_last2_dims( 39 | np, node_mask, river_network.mask, node_mask.shape 40 | ) 41 | 42 | node_relabel = np.empty(river_network.n_nodes, dtype=int) 43 | node_relabel[node_mask] = np.arange(node_mask.sum()) 44 | 45 | storage = cp.deepcopy(river_network._storage) 46 | if edge_mask is not None and node_mask is not None: 47 | valid_edges = edge_mask[storage.sorted_data[2]] & ( 48 | node_mask[storage.sorted_data[0]] & node_mask[storage.sorted_data[1]] 49 | ) 50 | elif edge_mask is None: 51 | valid_edges = ( 52 | node_mask[storage.sorted_data[0]] & node_mask[storage.sorted_data[1]] 53 | ) 54 | else: 55 | valid_edges = edge_mask[storage.sorted_data[2]] 56 | 57 | original_order_edge_mask = np.empty(river_network.n_edges, dtype=bool) 58 | original_order_edge_mask[storage.sorted_data[2]] = valid_edges 59 | edge_relabel = np.empty(river_network.n_edges, dtype=int) 60 | edge_relabel[original_order_edge_mask] = np.arange(original_order_edge_mask.sum()) 61 | 62 | storage.sorted_data = storage.sorted_data[..., valid_edges] 63 | storage.sorted_data[0] = node_relabel[storage.sorted_data[0]] 64 | storage.sorted_data[1] = node_relabel[storage.sorted_data[1]] 65 | storage.sorted_data[2] = edge_relabel[storage.sorted_data[2]] 66 | 67 | storage.splits = np.cumsum(valid_edges)[storage.splits - 1] 68 | storage.mask = storage.mask[node_mask] 69 | storage.n_nodes = storage.mask.shape[0] 70 | storage.n_edges = storage.sorted_data.shape[1] 71 | 72 | return RiverNetwork(storage) 73 | 74 | 75 | def crop(river_network: RiverNetwork, copy=True): 76 | """ 77 | Crop a gridded network to the minimum bounding grid. 
78 | 79 | Parameters 80 | ---------- 81 | river_network : RiverNetwork 82 | Original river network from which to create a cropped network. 83 | copy : bool, optional 84 | Whether or not to modify the original river network or return a copy. Default is True. 85 | 86 | Returns 87 | ------- 88 | RiverNetwork 89 | The river network object created from the given data. 90 | """ 91 | 92 | if river_network.array_backend != "numpy" or copy is not True: 93 | raise NotImplementedError 94 | 95 | storage = cp.deepcopy(river_network._storage) 96 | 97 | rows, cols = np.unravel_index(storage.mask, shape=(storage.shape)) 98 | 99 | row_min, row_max = rows.min(), rows.max() 100 | col_min, col_max = cols.min(), cols.max() 101 | 102 | storage.shape = (int(row_max - row_min + 1), int(col_max - col_min + 1)) 103 | 104 | storage.mask = np.ravel_multi_index( 105 | (rows - row_min, cols - col_min), dims=storage.shape 106 | ) 107 | 108 | for i, key in enumerate(storage.coords.keys()): 109 | if i == 0: 110 | storage.coords[key] = storage.coords[key][row_min : row_max + 1] 111 | elif i == 1: 112 | storage.coords[key] = storage.coords[key][col_min : col_max + 1] 113 | else: 114 | raise ValueError("coords must not have more than 2 keys.") 115 | 116 | return RiverNetwork(storage) 117 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | 4 | 5 | 6 |

7 | 8 |

9 | 10 | ECMWF Software EnginE 11 | 12 | 13 | Maturity Level 14 | 15 | 18 | 19 | Licence 20 | 21 | 22 | Latest Release 23 | 24 |

25 | 26 |

27 | 29 | Installation 30 | • 31 | Documentation 32 |

33 | 34 | > \[!IMPORTANT\] 35 | > This software is **Incubating** and subject to ECMWF's guidelines on [Software Maturity](https://github.com/ecmwf/codex/raw/refs/heads/main/Project%20Maturity). 36 | 37 | **earthkit-hydro** is a Python library for common hydrological functions. It is the hydrological component of [earthkit](https://github.com/ecmwf/earthkit). 38 | 39 | ## Main Features 40 | 41 | 42 | 43 | 44 | 45 | 46 |
Adapted from: doc_figureArray backends with xr
47 | 48 | - Catchment delineation 49 | - Catchment-based statistics 50 | - Directional flow-based accumulations 51 | - River network distance calculations 52 | - Upstream/downstream field propagation 53 | - Bifurcation handling 54 | - Custom weighting and decay support 55 | - Support for PCRaster, CaMa-Flood, HydroSHEDS, MERIT-Hydro and GRIT river network formats 56 | - Compatible with major array-backends: xarray, numpy, cupy, torch, jax, mlx and tensorflow 57 | - GPU support 58 | - Differentiable operations suitable for machine learning 59 | 60 | 61 | ## Installation 62 | For a default installation, run: 63 | 64 | ``` 65 | pip install earthkit-hydro 66 | ``` 67 | 68 | *Developer instructions:* 69 | 70 | For a developer setup (includes linting and test libraries), run: 71 | 72 | ``` 73 | conda create -n hydro python=3.12 74 | conda activate hydro 75 | conda install -c conda-forge rust 76 | git clone https://github.com/ecmwf/earthkit-hydro.git 77 | cd earthkit-hydro 78 | pip install -e ".[dev]" 79 | pre-commit install 80 | ``` 81 | Note: this is a mixed Rust-Python project with a pure Python fallback. The install behaviour is controlled by the environment variable `USE_RUST`: 82 | - Not set, or any other value (default behaviour): 83 | Attempts to build with Rust; if that fails, falls back to the pure Python implementation. 84 | - `USE_RUST=0`: 85 | Builds the pure Python implementation only. 86 | - `USE_RUST=1`: 87 | Builds with Rust and fails if something goes wrong. 88 | 89 | 90 | ## Licence 91 | 92 | ``` 93 | Copyright 2024, European Centre for Medium Range Weather Forecasts. 94 | 95 | Licensed under the Apache License, Version 2.0 (the "License"); 96 | you may not use this file except in compliance with the License.
97 | You may obtain a copy of the License at 98 | 99 | http://www.apache.org/licenses/LICENSE-2.0 100 | 101 | Unless required by applicable law or agreed to in writing, software 102 | distributed under the License is distributed on an "AS IS" BASIS, 103 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 104 | See the License for the specific language governing permissions and 105 | limitations under the License. 106 | 107 | In applying this licence, ECMWF does not waive the privileges and immunities 108 | granted to it by virtue of its status as an intergovernmental organisation 109 | nor does it submit to any jurisdiction. 110 | ``` 111 | -------------------------------------------------------------------------------- /src/earthkit/hydro/downstream/array/_operations.py: -------------------------------------------------------------------------------- 1 | from earthkit.hydro._core.online import calculate_online_metric 2 | from earthkit.hydro._utils.decorators import mask, multi_backend 3 | 4 | 5 | def calculate_downstream_metric( # shared implementation behind var/std/mean/sum/min/max below 6 | xp, 7 | river_network, 8 | field, 9 | metric, 10 | node_weights, 11 | edge_weights, 12 | ): 13 | return calculate_online_metric( 14 | xp, 15 | river_network, 16 | field, 17 | metric, 18 | node_weights, 19 | edge_weights, 20 | flow_direction="up", # downstream metrics carry values against the flow, from sinks towards sources 21 | ) 22 | 23 | 24 | @multi_backend(jax_static_args=["xp", "river_network", "return_type"]) # these arguments are declared static for the jax backend 25 | def var(xp, river_network, field, node_weights, edge_weights, return_type): 26 | return_type = river_network.return_type if return_type is None else return_type # fall back to the network's default return type 27 | if return_type not in ["gridded", "masked"]: 28 | raise ValueError("return_type must be either 'gridded' or 'masked'.") 29 | decorated_calculate_downstream_metric = mask(return_type == "gridded")( # NOTE(review): mask(True) presumably maps node results back onto the full grid; confirm in _utils.decorators 30 | calculate_downstream_metric 31 | ) 32 | return decorated_calculate_downstream_metric( 33 | xp, 34 | river_network, 35 | field, 36 | "var", 37 | node_weights, 38 | edge_weights, 39 | ) 40 | 41 | 42 |
@multi_backend(jax_static_args=["xp", "river_network", "return_type"]) 43 | def std(xp, river_network, field, node_weights, edge_weights, return_type): 44 | return_type = river_network.return_type if return_type is None else return_type 45 | if return_type not in ["gridded", "masked"]: 46 | raise ValueError("return_type must be either 'gridded' or 'masked'.") 47 | decorated_calculate_downstream_metric = mask(return_type == "gridded")( 48 | calculate_downstream_metric 49 | ) 50 | return decorated_calculate_downstream_metric( 51 | xp, 52 | river_network, 53 | field, 54 | "std", 55 | node_weights, 56 | edge_weights, 57 | ) 58 | 59 | 60 | @multi_backend(jax_static_args=["xp", "river_network", "return_type"]) 61 | def mean(xp, river_network, field, node_weights, edge_weights, return_type): 62 | return_type = river_network.return_type if return_type is None else return_type 63 | if return_type not in ["gridded", "masked"]: 64 | raise ValueError("return_type must be either 'gridded' or 'masked'.") 65 | decorated_calculate_downstream_metric = mask(return_type == "gridded")( 66 | calculate_downstream_metric 67 | ) 68 | return decorated_calculate_downstream_metric( 69 | xp, 70 | river_network, 71 | field, 72 | "mean", 73 | node_weights, 74 | edge_weights, 75 | ) 76 | 77 | 78 | @multi_backend(jax_static_args=["xp", "river_network", "return_type"]) 79 | def sum(xp, river_network, field, node_weights, edge_weights, return_type): 80 | return_type = river_network.return_type if return_type is None else return_type 81 | if return_type not in ["gridded", "masked"]: 82 | raise ValueError("return_type must be either 'gridded' or 'masked'.") 83 | decorated_calculate_downstream_metric = mask(return_type == "gridded")( 84 | calculate_downstream_metric 85 | ) 86 | return decorated_calculate_downstream_metric( 87 | xp, 88 | river_network, 89 | field, 90 | "sum", 91 | node_weights, 92 | edge_weights, 93 | ) 94 | 95 | 96 | @multi_backend(jax_static_args=["xp", "river_network",
"return_type"]) 97 | def min(xp, river_network, field, node_weights, edge_weights, return_type): 98 | return_type = river_network.return_type if return_type is None else return_type 99 | if return_type not in ["gridded", "masked"]: 100 | raise ValueError("return_type must be either 'gridded' or 'masked'.") 101 | decorated_calculate_downstream_metric = mask(return_type == "gridded")( 102 | calculate_downstream_metric 103 | ) 104 | return decorated_calculate_downstream_metric( 105 | xp, 106 | river_network, 107 | field, 108 | "min", 109 | node_weights, 110 | edge_weights, 111 | ) 112 | 113 | 114 | @multi_backend(jax_static_args=["xp", "river_network", "return_type"]) 115 | def max(xp, river_network, field, node_weights, edge_weights, return_type): 116 | return_type = river_network.return_type if return_type is None else return_type 117 | if return_type not in ["gridded", "masked"]: 118 | raise ValueError("return_type must be either 'gridded' or 'masked'.") 119 | decorated_calculate_downstream_metric = mask(return_type == "gridded")( 120 | calculate_downstream_metric 121 | ) 122 | return decorated_calculate_downstream_metric( 123 | xp, 124 | river_network, 125 | field, 126 | "max", 127 | node_weights, 128 | edge_weights, 129 | ) 130 | -------------------------------------------------------------------------------- /src/earthkit/hydro/move/array/_toplevel.py: -------------------------------------------------------------------------------- 1 | from earthkit.hydro.move.array import _operations 2 | 3 | 4 | def upstream( 5 | river_network, 6 | field, 7 | node_weights=None, 8 | edge_weights=None, 9 | metric="sum", 10 | return_type=None, 11 | ): 12 | r""" 13 | Moves a field upstream. 14 | 15 | Computes a one-step neighbor accumulation (local aggregation) moving upstream only. 16 | 17 | The accumulation is defined as: 18 | 19 | ..
math:: 20 | :nowrap: 21 | 22 | \begin{align*} 23 | x'_i &= w'_i \cdot x_i \\ 24 | n_j &= \bigoplus_{i \in \mathrm{Down}(j)} w_{ij} \cdot x'_i 25 | \end{align*} 26 | 27 | where: 28 | 29 | - :math:`x_i` is the input value at node :math:`i` (e.g., rainfall), 30 | - :math:`w'_i` is the node weight (e.g., pixel area), 31 | - :math:`w_{ij}` is the edge weight from node :math:`i` to node :math:`j` (e.g. discharge partitioning ratio), 32 | - :math:`\mathrm{Down}(j)` is the set of downstream nodes flowing out of node :math:`j`, 33 | - :math:`\bigoplus` is the aggregation function (e.g. a summation). 34 | - :math:`n_j` is the weighted aggregated value at node :math:`j`. 35 | 36 | Sinks are given a value of 0. 37 | 38 | Parameters 39 | ---------- 40 | river_network : RiverNetwork 41 | A river network object. 42 | field : array-like 43 | An array containing field values defined on river network nodes or gridcells. 44 | node_weights : array-like, optional 45 | Array of weights for each river network node or gridcell. Default is None (unweighted). 46 | edge_weights : array-like, optional 47 | Array of weights for each edge. Default is None (unweighted). 48 | metric : str, optional 49 | Aggregation function to apply. Options are 'var', 'std', 'mean', 'sum', 'min' and 'max'. Default is `'sum'`. 50 | return_type : str, optional 51 | Either "masked", "gridded" or None. If None (default), uses `river_network.return_type`. 52 | 53 | 54 | Returns 55 | ------- 56 | array-like 57 | Array of values after movement up the river network for every river network node or gridcell, depending on `return_type`. 
58 | """ 59 | return _operations.upstream( 60 | river_network=river_network, 61 | field=field, 62 | node_weights=node_weights, 63 | edge_weights=edge_weights, 64 | metric=metric, 65 | return_type=return_type, 66 | ) 67 | 68 | 69 | def downstream( 70 | river_network, 71 | field, 72 | node_weights=None, 73 | edge_weights=None, 74 | metric="sum", 75 | return_type=None, 76 | ): 77 | r""" 78 | Moves a field downstream. 79 | 80 | Computes a one-step neighbor accumulation (local aggregation) moving downstream only. 81 | 82 | The accumulation is defined as: 83 | 84 | .. math:: 85 | :nowrap: 86 | 87 | \begin{align*} 88 | x'_i &= w'_i \cdot x_i \\ 89 | n_j &= \bigoplus_{i \in \mathrm{Up}(j)} w_{ij} \cdot x'_i 90 | \end{align*} 91 | 92 | where: 93 | 94 | - :math:`x_i` is the input value at node :math:`i` (e.g., rainfall), 95 | - :math:`w'_i` is the node weight (e.g., pixel area), 96 | - :math:`w_{ij}` is the edge weight from node :math:`i` to node :math:`j` (e.g. discharge partitioning ratio), 97 | - :math:`\mathrm{Up}(j)` is the set of upstream nodes flowing into node :math:`j`, 98 | - :math:`\bigoplus` is the aggregation function (e.g. a summation). 99 | - :math:`n_j` is the weighted aggregated value at node :math:`j`. 100 | 101 | Sources are given a value of 0. 102 | 103 | Parameters 104 | ---------- 105 | river_network : RiverNetwork 106 | A river network object. 107 | field : array-like 108 | An array containing field values defined on river network nodes or gridcells. 109 | node_weights : array-like, optional 110 | Array of weights for each river network node or gridcell. Default is None (unweighted). 111 | edge_weights : array-like, optional 112 | Array of weights for each edge. Default is None (unweighted). 113 | metric : str, optional 114 | Aggregation function to apply. Options are 'var', 'std', 'mean', 'sum', 'min' and 'max'. Default is `'sum'`. 115 | return_type : str, optional 116 | Either "masked", "gridded" or None. 
If None (default), uses `river_network.return_type`. 117 | 118 | 119 | Returns 120 | ------- 121 | array-like 122 | Array of values after movement down the river network for every river network node or gridcell, depending on `return_type`. 123 | """ 124 | return _operations.downstream( 125 | river_network=river_network, 126 | field=field, 127 | node_weights=node_weights, 128 | edge_weights=edge_weights, 129 | metric=metric, 130 | return_type=return_type, 131 | ) 132 | -------------------------------------------------------------------------------- /src/earthkit/hydro/upstream/array/_operations.py: -------------------------------------------------------------------------------- 1 | from earthkit.hydro._core.online import calculate_online_metric 2 | from earthkit.hydro._utils.decorators import mask, multi_backend 3 | 4 | 5 | def calculate_upstream_metric( 6 | xp, 7 | river_network, 8 | field, 9 | metric, 10 | node_weights, 11 | edge_weights, 12 | ): 13 | return calculate_online_metric( 14 | xp, 15 | river_network, 16 | field, 17 | metric, 18 | node_weights, 19 | edge_weights, 20 | flow_direction="down", 21 | ) 22 | 23 | 24 | @multi_backend(jax_static_args=["xp", "river_network", "return_type"]) 25 | def var(xp, river_network, field, node_weights, edge_weights, return_type): 26 | return_type = river_network.return_type if return_type is None else return_type 27 | if return_type not in ["gridded", "masked"]: 28 | raise ValueError("return_type must be either 'gridded' or 'masked'.") 29 | decorated_calculate_upstream_metric = mask(return_type == "gridded")( 30 | calculate_upstream_metric 31 | ) 32 | return decorated_calculate_upstream_metric( 33 | xp, 34 | river_network, 35 | field, 36 | "var", 37 | node_weights, 38 | edge_weights, 39 | ) 40 | 41 | 42 | @multi_backend(jax_static_args=["xp", "river_network", "return_type"]) 43 | def std( 44 | xp, 45 | river_network, 46 | field, 47 | node_weights, 48 | edge_weights, 49 | return_type, 50 | ): 51 | return_type = 
river_network.return_type if return_type is None else return_type 52 | if return_type not in ["gridded", "masked"]: 53 | raise ValueError("return_type must be either 'gridded' or 'masked'.") 54 | decorated_calculate_upstream_metric = mask(return_type == "gridded")( 55 | calculate_upstream_metric 56 | ) 57 | return decorated_calculate_upstream_metric( 58 | xp, 59 | river_network, 60 | field, 61 | "std", 62 | node_weights, 63 | edge_weights, 64 | ) 65 | 66 | 67 | @multi_backend(jax_static_args=["xp", "river_network", "return_type"]) 68 | def mean(xp, river_network, field, node_weights, edge_weights, return_type): 69 | return_type = river_network.return_type if return_type is None else return_type 70 | if return_type not in ["gridded", "masked"]: 71 | raise ValueError("return_type must be either 'gridded' or 'masked'.") 72 | decorated_calculate_upstream_metric = mask(return_type == "gridded")( 73 | calculate_upstream_metric 74 | ) 75 | return decorated_calculate_upstream_metric( 76 | xp, 77 | river_network, 78 | field, 79 | "mean", 80 | node_weights, 81 | edge_weights, 82 | ) 83 | 84 | 85 | @multi_backend(jax_static_args=["xp", "river_network", "return_type"]) 86 | def sum(xp, river_network, field, node_weights, edge_weights, return_type): 87 | return_type = river_network.return_type if return_type is None else return_type 88 | if return_type not in ["gridded", "masked"]: 89 | raise ValueError("return_type must be either 'gridded' or 'masked'.") 90 | decorated_calculate_upstream_metric = mask(return_type == "gridded")( 91 | calculate_upstream_metric 92 | ) 93 | return decorated_calculate_upstream_metric( 94 | xp, 95 | river_network, 96 | field, 97 | "sum", 98 | node_weights, 99 | edge_weights, 100 | ) 101 | 102 | 103 | @multi_backend(jax_static_args=["xp", "river_network", "return_type"]) 104 | def min(xp, river_network, field, node_weights, edge_weights, return_type): 105 | return_type = river_network.return_type if return_type is None else return_type 106 | if 
return_type not in ["gridded", "masked"]: 107 | raise ValueError("return_type must be either 'gridded' or 'masked'.") 108 | decorated_calculate_upstream_metric = mask(return_type == "gridded")( 109 | calculate_upstream_metric 110 | ) 111 | return decorated_calculate_upstream_metric( 112 | xp, 113 | river_network, 114 | field, 115 | "min", 116 | node_weights, 117 | edge_weights, 118 | ) 119 | 120 | 121 | @multi_backend(jax_static_args=["xp", "river_network", "return_type"]) 122 | def max(xp, river_network, field, node_weights, edge_weights, return_type): 123 | return_type = river_network.return_type if return_type is None else return_type 124 | if return_type not in ["gridded", "masked"]: 125 | raise ValueError("return_type must be either 'gridded' or 'masked'.") 126 | decorated_calculate_upstream_metric = mask(return_type == "gridded")( 127 | calculate_upstream_metric 128 | ) 129 | return decorated_calculate_upstream_metric( 130 | xp, 131 | river_network, 132 | field, 133 | "max", 134 | node_weights, 135 | edge_weights, 136 | ) 137 | -------------------------------------------------------------------------------- /docs/source/userguide/pcraster.rst: -------------------------------------------------------------------------------- 1 | Migrating from PCRaster 2 | ======================= 3 | 4 | earthkit-hydro can be used as a drop-in replacement for PCRaster in many cases, offering a substantial speedup. As a design decision, it does not attempt to match PCRaster's API, meaning that migrated code may look slightly different. 5 | 6 | Here is a useful summary table of some common translations between earthkit-hydro and PCRaster. 
7 | 8 | +------------------+------------------------+-------------------------------------------------------------------------------------------------------------------------+ 9 | | **PCRaster** | **earthkit-hydro** | **Note** | 10 | +==================+========================+=========================================================================================================================+ 11 | | accuflux | upstream.sum | | 12 | +------------------+------------------------+-------------------------------------------------------------------------------------------------------------------------+ 13 | | catchmenttotal | upstream.sum | | 14 | +------------------+------------------------+-------------------------------------------------------------------------------------------------------------------------+ 15 | | downstream | move.upstream | | 16 | +------------------+------------------------+-------------------------------------------------------------------------------------------------------------------------+ 17 | | upstream | move.downstream | | 18 | +------------------+------------------------+-------------------------------------------------------------------------------------------------------------------------+ 19 | | catchment | catchments.find | | 20 | +------------------+------------------------+-------------------------------------------------------------------------------------------------------------------------+ 21 | | subcatchment | catchments.find | overwrite=False | 22 | +------------------+------------------------+-------------------------------------------------------------------------------------------------------------------------+ 23 | | path | upstream.max | | 24 | +------------------+------------------------+-------------------------------------------------------------------------------------------------------------------------+ 25 | | ldddist | distance.min | friction input is slightly different from weights; by default, distance 
between two nodes is one regardless of diagonal | 26 | +------------------+------------------------+-------------------------------------------------------------------------------------------------------------------------+ 27 | | downstreamdist | distance.to_sink | Same caveats as for ldddist | 28 | +------------------+------------------------+-------------------------------------------------------------------------------------------------------------------------+ 29 | | slopelength | distance.to_source | path="longest"; same caveats as for ldddist | 30 | +------------------+------------------------+-------------------------------------------------------------------------------------------------------------------------+ 31 | | lddmask | subnetwork.from_mask | | 32 | +------------------+------------------------+-------------------------------------------------------------------------------------------------------------------------+ 33 | | abs, sin, cos, | np.abs, np.sin, | Any array operations can be directly used (example shown for NumPy backend) | 34 | | tan, ... | np.cos, np.tan, ... | | 35 | +------------------+------------------------+-------------------------------------------------------------------------------------------------------------------------+ 36 | 37 | 38 | Points of difference 39 | 40 | - earthkit-hydro treats missing values as np.nans i.e. any arithmetic involving a missing value will return a missing value. PCRaster does not always handle missing values exactly the same. 41 | - earthkit-hydro can handle vector fields. 
42 | -------------------------------------------------------------------------------- /src/earthkit/hydro/move/_toplevel.py: -------------------------------------------------------------------------------- 1 | from earthkit.hydro._utils.decorators import xarray 2 | from earthkit.hydro.move import array 3 | 4 | 5 | @xarray 6 | def upstream( 7 | river_network, 8 | field, 9 | node_weights=None, 10 | edge_weights=None, 11 | metric="sum", 12 | return_type=None, 13 | input_core_dims=None, 14 | ): 15 | r""" 16 | Moves a field upstream. 17 | 18 | Computes a one-step neighbor accumulation (local aggregation) moving upstream only. 19 | 20 | The accumulation is defined as: 21 | 22 | .. math:: 23 | :nowrap: 24 | 25 | \begin{align*} 26 | x'_i &= w'_i \cdot x_i \\ 27 | n_j &= \bigoplus_{i \in \mathrm{Down}(j)} w_{ij} \cdot x'_i 28 | \end{align*} 29 | 30 | where: 31 | 32 | - :math:`x_i` is the input value at node :math:`i` (e.g., rainfall), 33 | - :math:`w'_i` is the node weight (e.g., pixel area), 34 | - :math:`w_{ij}` is the edge weight from node :math:`i` to node :math:`j` (e.g. discharge partitioning ratio), 35 | - :math:`\mathrm{Down}(j)` is the set of downstream nodes flowing out of node :math:`j`, 36 | - :math:`\bigoplus` is the aggregation function (e.g. a summation). 37 | - :math:`n_j` is the weighted aggregated value at node :math:`j`. 38 | 39 | Parameters 40 | ---------- 41 | river_network : RiverNetwork 42 | A river network object. 43 | field : array-like or xarray object 44 | An array containing field values defined on river network nodes or gridcells. 45 | node_weights : array-like or xarray object, optional 46 | Array of weights for each river network node or gridcell. Default is None (unweighted). 47 | edge_weights : array-like or xarray object, optional 48 | Array of weights for each edge. Default is None (unweighted). 49 | metric : str, optional 50 | Aggregation function to apply. Options are 'var', 'std', 'mean', 'sum', 'min' and 'max'. Default is `'sum'`. 
51 | return_type : str, optional 52 | Either "masked", "gridded" or None. If None (default), uses `river_network.return_type`. 53 | input_core_dims : sequence of sequence, optional 54 | List of core dimensions on each input xarray argument that should not be broadcast. 55 | Default is None, which attempts to autodetect input_core_dims from the xarray inputs. 56 | Ignored if no xarray inputs passed. 57 | 58 | 59 | Returns 60 | ------- 61 | xarray object 62 | Array of values after movement up the river network for every river network node or gridcell, depending on `return_type`. 63 | """ 64 | return array.upstream( 65 | river_network, field, node_weights, edge_weights, metric, return_type 66 | ) 67 | 68 | 69 | @xarray 70 | def downstream( 71 | river_network, 72 | field, 73 | node_weights=None, 74 | edge_weights=None, 75 | metric="sum", 76 | return_type=None, 77 | input_core_dims=None, 78 | ): 79 | r""" 80 | Moves a field downstream. 81 | 82 | Computes a one-step neighbor accumulation (local aggregation) moving downstream only. 83 | 84 | The accumulation is defined as: 85 | 86 | .. math:: 87 | :nowrap: 88 | 89 | \begin{align*} 90 | x'_i &= w'_i \cdot x_i \\ 91 | n_j &= \bigoplus_{i \in \mathrm{Up}(j)} w_{ij} \cdot x'_i 92 | \end{align*} 93 | 94 | where: 95 | 96 | - :math:`x_i` is the input value at node :math:`i` (e.g., rainfall), 97 | - :math:`w'_i` is the node weight (e.g., pixel area), 98 | - :math:`w_{ij}` is the edge weight from node :math:`i` to node :math:`j` (e.g. discharge partitioning ratio), 99 | - :math:`\mathrm{Up}(j)` is the set of upstream nodes flowing into node :math:`j`, 100 | - :math:`\bigoplus` is the aggregation function (e.g. a summation). 101 | - :math:`n_j` is the weighted aggregated value at node :math:`j`. 102 | 103 | Parameters 104 | ---------- 105 | river_network : RiverNetwork 106 | A river network object. 107 | field : array-like or xarray object 108 | An array containing field values defined on river network nodes or gridcells. 
109 | node_weights : array-like or xarray object, optional 110 | Array of weights for each river network node or gridcell. Default is None (unweighted). 111 | edge_weights : array-like or xarray object, optional 112 | Array of weights for each edge. Default is None (unweighted). 113 | metric : str, optional 114 | Aggregation function to apply. Options are 'var', 'std', 'mean', 'sum', 'min' and 'max'. Default is `'sum'`. 115 | return_type : str, optional 116 | Either "masked", "gridded" or None. If None (default), uses `river_network.return_type`. 117 | input_core_dims : sequence of sequence, optional 118 | List of core dimensions on each input xarray argument that should not be broadcast. 119 | Default is None, which attempts to autodetect input_core_dims from the xarray inputs. 120 | Ignored if no xarray inputs passed. 121 | 122 | 123 | Returns 124 | ------- 125 | xarray object 126 | Array of values after movement down the river network for every river network node or gridcell, depending on `return_type`. 127 | """ 128 | return array.downstream( 129 | river_network, field, node_weights, edge_weights, metric, return_type 130 | ) 131 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | import os 10 | import sys 11 | 12 | on_rtd = os.environ.get("READTHEDOCS") == "True" 13 | 14 | if on_rtd: 15 | version = os.environ.get("READTHEDOCS_VERSION", "latest") 16 | release = version 17 | else: 18 | version = "dev" 19 | release = "dev" 20 | 21 | rtd_version = version if version != "latest" else "develop" 22 | rtd_version_type = os.environ.get("READTHEDOCS_VERSION_TYPE", "branch") 23 | 24 | if rtd_version_type in ("branch", "tag"): 25 | source_branch = rtd_version 26 | else: 27 | source_branch = "main" 28 | 29 | sys.path.insert(0, os.path.abspath("../../src")) 30 | 31 | project = "earthkit-hydro" 32 | copyright = "2025, European Centre for Medium-Range Weather Forecasts (ECMWF)" 33 | author = "European Centre for Medium-Range Weather Forecasts (ECMWF)" 34 | # release = '0.0.0' 35 | 36 | # -- General configuration --------------------------------------------------- 37 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 38 | 39 | extensions = [ 40 | # Automatically extracts documentation from your Python docstrings 41 | "sphinx.ext.autodoc", 42 | # Supports Google-style and NumPy-style docstrings 43 | "sphinx.ext.napoleon", 44 | # Renders LaTeX math in HTML using MathJax 45 | "sphinx.ext.mathjax", 46 | # Option to click viewcode 47 | "sphinx.ext.viewcode", 48 | # Links to the documentation of other projects via cross-references 49 | # "sphinx.ext.intersphinx", 50 | # Generates summary tables for modules/classes/functions 51 | # "sphinx.ext.autosummary", 52 | # Allows citing BibTeX bibliographic entries in reStructuredText 53 | "sphinxcontrib.bibtex", 54 | # Tests snippets in documentation by 
running embedded Python examples 55 | # "sphinx.ext.doctest", 56 | # Checks documentation coverage of the codebase 57 | # "sphinx.ext.coverage", 58 | # Adds .nojekyll file and helps configure docs for GitHub Pages hosting 59 | # "sphinx.ext.githubpages", 60 | # Adds "Edit on GitHub" links to documentation pages 61 | # "edit_on_github", 62 | # Adds "Edit on GitHub" links to documentation pages 63 | # "sphinx_github_style", 64 | # Option to link to code 65 | # "sphinx.ext.linkcode", 66 | # Automatically includes type hints from function signatures into the documentation 67 | # "sphinx_autodoc_typehints", 68 | # Integrates Jupyter Notebooks into Sphinx 69 | "nbsphinx", 70 | ] 71 | 72 | templates_path = ["_templates"] 73 | exclude_patterns = [] 74 | 75 | # -- Options for HTML output ------------------------------------------------- 76 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 77 | 78 | 79 | html_theme = "furo" 80 | 81 | html_static_path = ["_static"] 82 | 83 | # html_context = { 84 | # "display_github": True, 85 | # "github_user": "ecmwf", # GitHub username 86 | # "github_repo": "docsample", # GitHub repository name 87 | # "github_version": "main", # Branch (e.g., 'main', 'master') 88 | # "conf_py_path": "/docs/", # Path to your docs root in the repo 89 | # } 90 | 91 | bibtex_bibfiles = ["references.bib"] 92 | 93 | html_theme_options = { 94 | "light_css_variables": { 95 | "color-sidebar-background": "#001F3F", 96 | "color-sidebar-link-text": "#ffffff", 97 | "color-sidebar-brand-text": "#ffffff", 98 | "color-sidebar-caption-text": "#ffffff", 99 | "color-brand-primary": "#00D9FF", 100 | "color-brand-content": "#5f8dd3", 101 | }, 102 | "dark_css_variables": { 103 | "color-sidebar-background": "#001F3F", 104 | "color-sidebar-link-text": "#ffffff", 105 | "color-sidebar-brand-text": "#ffffff", 106 | "color-sidebar-caption-text": "#ffffff", 107 | "color-brand-primary": "#00D9FF", 108 | "color-brand-content": "#5f8dd3", 109 | }, 
110 | "light_logo": "earthkit-hydro-dark.svg", 111 | "dark_logo": "earthkit-hydro-dark.svg", 112 | "source_repository": "https://github.com/ecmwf/earthkit-hydro/", 113 | "source_branch": source_branch, 114 | "source_directory": "docs/source", 115 | "footer_icons": [ 116 | { 117 | "name": "GitHub", 118 | "url": "https://github.com/ecmwf/earthkit-hydro", 119 | "html": """ 120 | 121 | 122 | 123 | """, 124 | "class": "", 125 | }, 126 | ], 127 | } 128 | -------------------------------------------------------------------------------- /src/earthkit/hydro/_utils/decorators/xarray.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from inspect import signature 3 | 4 | import numpy as np 5 | import xarray as xr 6 | 7 | from earthkit.hydro._utils.coords import get_core_dims, node_default_coord 8 | 9 | 10 | def get_full_signature(func, *args, **kwargs): 11 | sig = signature(func) 12 | bound_args = sig.bind(*args, **kwargs) 13 | bound_args.apply_defaults() 14 | return bound_args.arguments 15 | 16 | 17 | def assert_xr_compatible_backend(network): 18 | network_backend = network.array_backend 19 | if network_backend not in ["numpy", "cupy"]: 20 | raise NotImplementedError(f"xarray does not support {network_backend} backend") 21 | 22 | 23 | def sort_xr_nonxr_args(all_args): 24 | xr_args = [] 25 | non_xr_kwargs = {} 26 | arg_order = [] 27 | for name, value in all_args.items(): 28 | if isinstance(value, (xr.DataArray, xr.Dataset)): 29 | xr_args.append(value) 30 | arg_order.append(("xr", name)) 31 | else: 32 | non_xr_kwargs[name] = value 33 | arg_order.append(("nonxr", name)) 34 | return xr_args, non_xr_kwargs, arg_order 35 | 36 | 37 | def get_reshuffled_func(func, arg_order): 38 | def reshuffled_func(*only_xr_args, **non_xr_kwargs): 39 | full_args = {} 40 | xr_i = 0 41 | for kind, name in arg_order: 42 | if kind == "xr": 43 | full_args[name] = only_xr_args[xr_i] 44 | xr_i += 1 45 | else: 46 | full_args[name] = 
non_xr_kwargs[name] 47 | return func(**full_args) 48 | 49 | return reshuffled_func 50 | 51 | 52 | def get_input_output_core_dims( 53 | input_core_dims, output_core_dims, xr_args, river_network, return_grid 54 | ): 55 | if input_core_dims is None: 56 | input_core_dims = [get_core_dims(xr_arg) for xr_arg in xr_args] 57 | elif len(input_core_dims) == 1: 58 | input_core_dims *= len(xr_args) 59 | 60 | if output_core_dims is None: 61 | if return_grid: 62 | if len(input_core_dims[0]) == 2: # grid in and out 63 | output_core_dims = [input_core_dims[0]] 64 | else: 65 | output_core_dims = [ 66 | list(river_network.coords.keys()) 67 | ] # 1d in, grid out 68 | else: 69 | if len(input_core_dims[0]) == 1: # 1d in and out 70 | output_core_dims = [input_core_dims[0]] 71 | else: 72 | output_core_dims = [[node_default_coord]] 73 | 74 | return input_core_dims, output_core_dims 75 | 76 | 77 | def xarray(func): 78 | 79 | @wraps(func) 80 | def wrapper(*args, **kwargs): 81 | 82 | # Inspect the function signature and bind all arguments 83 | all_args = get_full_signature(func, *args, **kwargs) 84 | 85 | input_core_dims = all_args.pop("input_core_dims", None) 86 | output_core_dims = None 87 | 88 | assert_xr_compatible_backend(all_args["river_network"]) 89 | 90 | # Separate xarray and non-xarray arguments 91 | xr_args, non_xr_kwargs, arg_order = sort_xr_nonxr_args(all_args) 92 | 93 | river_network = all_args["river_network"] 94 | return_type = all_args["return_type"] 95 | return_type = river_network.return_type if return_type is None else return_type 96 | return_grid = return_type == "gridded" 97 | 98 | if len(xr_args) == 0: 99 | output = func(**all_args) 100 | 101 | offset = 2 if return_grid else 1 102 | ndim = output.ndim 103 | dim_names = [f"axis{i + 1}" for i in range(ndim - offset)] 104 | coords = { 105 | dim: np.arange(size) 106 | for dim, size in zip(dim_names, output.shape[:-offset]) 107 | } 108 | 109 | if return_grid: 110 | for k, v in river_network.coords.items(): 111 | coords[k] = 
v 112 | dim_names.append(k) 113 | else: 114 | coords[node_default_coord] = np.arange(river_network.n_nodes) 115 | dim_names.append(node_default_coord) 116 | 117 | result = xr.DataArray(output, dims=dim_names, coords=coords, name="out") 118 | 119 | if not return_grid: 120 | coords_grid = np.meshgrid(*river_network.coords.values()) 121 | assign_dict = { 122 | k: (node_default_coord, v.flat[river_network.mask]) 123 | for k, v in zip(river_network.coords.keys(), coords_grid) 124 | } 125 | result = result.assign_coords(**assign_dict) 126 | else: 127 | 128 | reshuffled_func = get_reshuffled_func(func, arg_order) 129 | 130 | input_core_dims, output_core_dims = get_input_output_core_dims( 131 | input_core_dims, output_core_dims, xr_args, river_network, return_grid 132 | ) 133 | 134 | output_sizes = ( 135 | {output_core_dims[0][0]: river_network.n_nodes} 136 | if len(output_core_dims[0]) == 1 137 | else {k: v for k, v in zip(output_core_dims[0], river_network.shape)} 138 | ) 139 | 140 | result = xr.apply_ufunc( 141 | reshuffled_func, 142 | *xr_args, 143 | input_core_dims=input_core_dims, 144 | output_core_dims=output_core_dims, 145 | dask_gufunc_kwargs={"output_sizes": output_sizes}, 146 | output_dtypes=[float], 147 | dask="parallelized", 148 | kwargs=non_xr_kwargs, 149 | ) 150 | 151 | if len(output_core_dims[0]) == 1: 152 | coords = list(river_network.coords.values())[::-1] 153 | coords_grid = np.meshgrid(*coords)[::-1] 154 | assign_dict = { 155 | k: (output_core_dims[0], v.flat[river_network.mask]) 156 | for k, v in zip(river_network.coords.keys(), coords_grid) 157 | } 158 | result = result.assign_coords(**assign_dict) 159 | 160 | return result 161 | 162 | return wrapper 163 | -------------------------------------------------------------------------------- /src/earthkit/hydro/data_structures/_network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ._network_storage import RiverNetworkStorage 4 | 5 | 
6 | class RiverNetwork: 7 | """ 8 | A class representing a river network for hydrological processing. 9 | 10 | Attributes 11 | ---------- 12 | n_nodes : int 13 | The number of nodes in the river network. 14 | n_edges : int 15 | The number of edges in the river network. 16 | sinks : array-like 17 | Nodes with no downstream connections. 18 | sources : array-like 19 | Nodes with no upstream connections. 20 | bifurcates : bool 21 | Whether the river network has bifurcations. 22 | shape : tuple 23 | The size of the river network grid. None if the network is vector-based. 24 | mask : array-like 25 | Flattened 1D indices on the raster grid corresponding to river network nodes. 26 | array_backend : str 27 | The array backend of the river network. 28 | device : str 29 | The device of the river network. 30 | return_type : str 31 | The default return type of the river network. Either "gridded" or "masked". 32 | """ 33 | 34 | def __init__(self, river_network_storage: RiverNetworkStorage): 35 | self._storage = river_network_storage 36 | self.n_nodes = self._storage.n_nodes 37 | self.n_edges = self._storage.n_edges 38 | self.sources = self._storage.sources 39 | self.sinks = self._storage.sinks 40 | 41 | self.bifurcates = self._storage.bifurcates 42 | self.edge_weights = self._storage.edge_weights 43 | 44 | self.mask = self._storage.mask 45 | self.shape = self._storage.shape 46 | self.array_backend = "numpy" 47 | self.device = "cpu" 48 | self.return_type = "gridded" 49 | 50 | self.coords = self._storage.coords 51 | 52 | self.data = [self._storage.sorted_data] 53 | self.groups = np.split(self._storage.sorted_data, self._storage.splits, axis=1) 54 | 55 | def __str__(self): 56 | return f"RiverNetwork with {self.n_nodes} nodes and {self.n_edges} edges." 57 | 58 | def __repr__(self): 59 | return self.__str__() 60 | 61 | def to_device(self, device=None, array_backend=None): 62 | """ 63 | Change the RiverNetwork's array backend and/or move it to a 64 | different device.
65 | 66 | Parameters 67 | ---------- 68 | device : str, optional 69 | The device to which to transfer. Default is None, which is `'cpu'` for all backends except cupy, which is `'gpu'`. 70 | array_backend : str, optional 71 | The array backend. 72 | One of "numpy", "np", "cupy", "cp", "pytorch", "torch", "jax", "jnp", "tensorflow", "tf", "mlx" or "mx". 73 | Default is None, which uses `self.array_backend`. 74 | 75 | Returns 76 | ------- 77 | RiverNetwork 78 | The modified RiverNetwork. 79 | """ 80 | 81 | from earthkit.utils.array.convert import convert 82 | 83 | # TODO: use xp.asarray 84 | if array_backend == "np": 85 | array_backend = "numpy" 86 | elif array_backend == "cp": 87 | array_backend = "cupy" 88 | elif array_backend == "jnp": 89 | array_backend = "jax" 90 | elif array_backend == "tf": 91 | array_backend = "tensorflow" 92 | elif array_backend == "pytorch": 93 | array_backend = "torch" 94 | elif array_backend == "mx": 95 | array_backend = "mlx" 96 | 97 | if device is None: 98 | device = "cpu" if array_backend != "cupy" else "gpu" 99 | if array_backend is None: 100 | if self.array_backend == "numpy" and device in ["gpu", "cuda"]: 101 | array_backend = "cupy" 102 | else: 103 | array_backend = self.array_backend 104 | 105 | if array_backend in ["torch", "cupy", "numpy"]: 106 | self.groups = [ 107 | convert(group, device=device, array_namespace=array_backend) 108 | for group in self.groups 109 | ] 110 | self.mask = convert(self.mask, device=device, array_namespace=array_backend) 111 | self.data = [ 112 | convert(self.data[0], device=device, array_namespace=array_backend) 113 | ] 114 | elif array_backend == "jax": 115 | assert device == "cpu" 116 | import jax.numpy as jnp 117 | 118 | self.groups = [jnp.array(x) for x in self.groups] 119 | self.mask = jnp.array(self.mask) 120 | self.data = [jnp.array(self.data[0])] 121 | elif array_backend == "tensorflow": 122 | assert device == "cpu" 123 | import tensorflow as tf 124 | 125 | self.groups = 
[tf.convert_to_tensor(x, dtype=tf.int32) for x in self.groups] 126 | self.mask = tf.convert_to_tensor(self.mask, dtype=tf.int32) 127 | self.data = [tf.convert_to_tensor(self.data[0], dtype=tf.int32)] 128 | elif array_backend == "mlx": 129 | import mlx.core as mx 130 | 131 | self.groups = [mx.array(x) for x in self.groups] 132 | self.mask = mx.array(self.mask) 133 | self.data = [mx.array(self.data[0])] 134 | else: 135 | raise NotImplementedError 136 | 137 | self.array_backend = array_backend 138 | if self.array_backend != "mlx": 139 | self.device = self.groups[0].device 140 | else: 141 | self.device = None 142 | return self 143 | 144 | def set_default_return_type(self, return_type): 145 | """ 146 | Set the default return type for the river network. 147 | 148 | Parameters 149 | ---------- 150 | return_type : str 151 | The default return_type to use. 152 | 153 | Returns 154 | ------- 155 | None 156 | """ 157 | if return_type not in ["gridded", "masked"]: 158 | raise ValueError( 159 | f'Invalid return_type {return_type}. Valid types are "gridded", "masked"' 160 | ) 161 | self.return_type = return_type 162 | 163 | def export(self, fpath="river_network.joblib", compression=1): 164 | """ 165 | Save the river network to a local file. 166 | 167 | Parameters 168 | ---------- 169 | fpath : str, optional 170 | The filepath specifying where to save the RiverNetwork. Default is `'river_network.joblib'`. 171 | compression : str, optional 172 | The compression factor used for saving. Default is 1. 
173 | 174 | Returns 175 | ------- 176 | None 177 | """ 178 | import joblib 179 | 180 | joblib.dump(self._storage, fpath, compress=compression) 181 | -------------------------------------------------------------------------------- /docs/source/tutorials/loading_river_networks.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "f1047b78", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import earthkit.hydro as ekh" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "8ea8ca7d", 16 | "metadata": {}, 17 | "source": [ 18 | "# Loading/creating river networks\n", 19 | "\n", 20 | "This notebooks goes through all of the different options for loading a river network.\n", 21 | "\n", 22 | "There are two methods\n", 23 | "- `ekh.river_network.load` (recommended)\n", 24 | "- `ekh.river_network.create` (advanced use only)" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "0751f034", 30 | "metadata": {}, 31 | "source": [ 32 | "## Loading river networks\n", 33 | "\n", 34 | "A chosen network can then be loaded easily using `river_network.load`. 
For example, loading the EFAS v5 network is done with" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 3, 40 | "id": "aee70ac8", 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "name": "stdout", 45 | "output_type": "stream", 46 | "text": [ 47 | "River network not found in cache (/var/folders/td/yqnxcqpx39dc855vwjtv5hj40000gn/T/tmp75g6_e5n_earthkit_hydro/0.2_c3b54d239de177b0e6b8abf8b506b79bc227f3d173ad1e346c9cc89471cad1e0.joblib).\n", 48 | "River network loaded, saving to cache (/var/folders/td/yqnxcqpx39dc855vwjtv5hj40000gn/T/tmp75g6_e5n_earthkit_hydro/0.2_c3b54d239de177b0e6b8abf8b506b79bc227f3d173ad1e346c9cc89471cad1e0.joblib).\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "network = ekh.river_network.load(\"efas\", \"5\")" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "id": "ed764d73", 59 | "metadata": {}, 60 | "source": [ 61 | "## Creating river networks\n", 62 | "\n", 63 | "Not every river network is available through `river_network.load`. In such (advanced) cases, users may need to create their own river network with `river_network.create`.\n", 64 | "\n", 65 | "Many river network formats are supported, as are many filetypes and data sources.
As an example, we can load a local netCDF file using the PCRaster river network format via" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 4, 71 | "id": "21bc9f8c", 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stdout", 76 | "output_type": "stream", 77 | "text": [ 78 | "River network not found in cache (/var/folders/td/yqnxcqpx39dc855vwjtv5hj40000gn/T/tmp75g6_e5n_earthkit_hydro/0.2_d49eb4193bdab4aa653eac88fc0d9dea6c7ecbe5ae64cf639230e2c2fdf612e4.joblib).\n", 79 | "River network loaded, saving to cache (/var/folders/td/yqnxcqpx39dc855vwjtv5hj40000gn/T/tmp75g6_e5n_earthkit_hydro/0.2_d49eb4193bdab4aa653eac88fc0d9dea6c7ecbe5ae64cf639230e2c2fdf612e4.joblib).\n" 80 | ] 81 | } 82 | ], 83 | "source": [ 84 | "network = ekh.river_network.create(\"local_pcraster_efas_network_file.nc\", \"pcr_d8\", \"file\")" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "id": "aff0387d", 90 | "metadata": {}, 91 | "source": [ 92 | "This creation step can be expensive, particularly for large networks, since it involves finding the topological ordering of the graph. Therefore, if the network will be reused for further analyses, it is recommended to save it explicitly to disk."
93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 5, 98 | "id": "c6af51d1", 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "network.export(\"my_river_network.joblib\")" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "id": "6b0e9dc7", 108 | "metadata": {}, 109 | "source": [ 110 | "The saved network can then be cheaply reloaded in later analyses via" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 6, 116 | "id": "a1c6b317", 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "name": "stdout", 121 | "output_type": "stream", 122 | "text": [ 123 | "River network not found in cache (/var/folders/td/yqnxcqpx39dc855vwjtv5hj40000gn/T/tmp75g6_e5n_earthkit_hydro/0.2_cd239cb18e6e841065a4fed044c6ac667853b385ef8406da2c1cf86471f003a8.joblib).\n", 124 | "River network loaded, saving to cache (/var/folders/td/yqnxcqpx39dc855vwjtv5hj40000gn/T/tmp75g6_e5n_earthkit_hydro/0.2_cd239cb18e6e841065a4fed044c6ac667853b385ef8406da2c1cf86471f003a8.joblib).\n" 125 | ] 126 | } 127 | ], 128 | "source": [ 129 | "network = ekh.river_network.create(\"my_river_network.joblib\", \"precomputed\", \"file\")" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "id": "a20b4e76", 135 | "metadata": {}, 136 | "source": [ 137 | "## Caching\n", 138 | "\n", 139 | "Networks are automatically cached to temporary storage for faster reloading within the same script. Because this storage is only temporary, the network must be exported explicitly to guarantee that the precomputed network remains available.\n", 140 | "\n", 141 | "The cache behaviour can also be configured.
To disable it, the `use_cache=False` flag can be passed" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 7, 147 | "id": "72f2b722", 148 | "metadata": {}, 149 | "outputs": [ 150 | { 151 | "name": "stdout", 152 | "output_type": "stream", 153 | "text": [ 154 | "Cache disabled.\n" 155 | ] 156 | } 157 | ], 158 | "source": [ 159 | "network = ekh.river_network.load(\"efas\", \"5\", use_cache=False)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "id": "44358e73", 165 | "metadata": {}, 166 | "source": [ 167 | "The cache can also be set to a non-temporary storage location, in which case it becomes _permanent_." 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 8, 173 | "id": "a2ee7e97", 174 | "metadata": {}, 175 | "outputs": [ 176 | { 177 | "name": "stdout", 178 | "output_type": "stream", 179 | "text": [ 180 | "River network not found in cache (./0.2_c3b54d239de177b0e6b8abf8b506b79bc227f3d173ad1e346c9cc89471cad1e0.joblib).\n", 181 | "River network loaded, saving to cache (./0.2_c3b54d239de177b0e6b8abf8b506b79bc227f3d173ad1e346c9cc89471cad1e0.joblib).\n" 182 | ] 183 | } 184 | ], 185 | "source": [ 186 | "network = ekh.river_network.load(\"efas\", \"5\", cache_dir=\".\")" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "id": "5ade93ba", 192 | "metadata": {}, 193 | "source": [ 194 | "For further customisation options, view the API reference." 
195 | ] 196 | } 197 | ], 198 | "metadata": { 199 | "kernelspec": { 200 | "display_name": "ekh", 201 | "language": "python", 202 | "name": "python3" 203 | }, 204 | "language_info": { 205 | "codemirror_mode": { 206 | "name": "ipython", 207 | "version": 3 208 | }, 209 | "file_extension": ".py", 210 | "mimetype": "text/x-python", 211 | "name": "python", 212 | "nbconvert_exporter": "python", 213 | "pygments_lexer": "ipython3", 214 | "version": "3.11.13" 215 | } 216 | }, 217 | "nbformat": 4, 218 | "nbformat_minor": 5 219 | } 220 | -------------------------------------------------------------------------------- /src/earthkit/hydro/distance/array/_toplevel.py: -------------------------------------------------------------------------------- 1 | from earthkit.hydro.distance.array import _operations 2 | 3 | 4 | def min( 5 | river_network, 6 | locations, 7 | field=None, 8 | upstream=False, 9 | downstream=True, 10 | return_type=None, 11 | ): 12 | r""" 13 | Calculates the minimum distance to all points from a set of start 14 | locations. 15 | 16 | For each node in the network, calculates the minimum distance starting from any of the start locations. 17 | 18 | The distance is defined as: 19 | 20 | .. math:: 21 | :nowrap: 22 | 23 | \begin{align*} 24 | d_j &= 0 ~\text{for start locations}\\ 25 | d_j &= \mathrm{min}(\infty,~\mathrm{min}_{i \in \mathrm{Neighbour}(j)} (d_i + w_{ij}) ) 26 | \end{align*} 27 | 28 | where: 29 | 30 | - :math:`w_{ij}` is the edge distance (e.g., downstream distance), 31 | - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. 32 | - :math:`d_j` is the total distance at node :math:`j`. 33 | 34 | Unreachable nodes are given a distance of :math:`\infty`. 35 | 36 | Parameters 37 | ---------- 38 | river_network : RiverNetwork 39 | A river network object. 40 | locations : array-like or dict 41 | A list of source nodes. 
42 | field : array-like, optional 43 | An array containing length values defined on river network edges. 44 | Default is `xp.ones(river_network.n_edges)`. 45 | upstream : bool, optional 46 | Whether or not to consider upstream distances. Default is False. 47 | downstream : bool, optional 48 | Whether or not to consider downstream distances. Default is True. 49 | return_type : str, optional 50 | Either "masked", "gridded" or None. If None (default), uses `river_network.return_type`. 51 | 52 | Returns 53 | ------- 54 | array-like 55 | Array of minimum distances for every river network node or gridcell, depending on `return_type`. 56 | """ 57 | return _operations.min( 58 | river_network=river_network, 59 | field=field, 60 | locations=locations, 61 | upstream=upstream, 62 | downstream=downstream, 63 | return_type=return_type, 64 | ) 65 | 66 | 67 | def max( 68 | river_network, 69 | locations, 70 | field=None, 71 | upstream=False, 72 | downstream=True, 73 | return_type=None, 74 | ): 75 | r""" 76 | Calculates the maximum distance to all points from a set of start 77 | locations. 78 | 79 | For each node in the network, calculates the maximum distance starting from any of the start locations. 80 | 81 | The distance is defined as: 82 | 83 | .. math:: 84 | :nowrap: 85 | 86 | \begin{align*} 87 | d_j &= 0 ~\text{for start locations}\\ 88 | d_j &= \mathrm{max}(-\infty,~\mathrm{max}_{i \in \mathrm{Neighbour}(j)} (d_i + w_{ij}) ) 89 | \end{align*} 90 | 91 | where: 92 | 93 | - :math:`w_{ij}` is the edge distance (e.g., downstream distance), 94 | - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. 95 | - :math:`d_j` is the total distance at node :math:`j`. 96 | 97 | Unreachable nodes are given a distance of :math:`-\infty`. 98 | 99 | Parameters 100 | ---------- 101 | river_network : RiverNetwork 102 | A river network object. 
103 | locations : array-like or dict 104 | A list of source nodes. 105 | field : array-like, optional 106 | An array containing length values defined on river network edges. 107 | Default is `xp.ones(river_network.n_edges)`. 108 | upstream : bool, optional 109 | Whether or not to consider upstream distances. Default is False. 110 | downstream : bool, optional 111 | Whether or not to consider downstream distances. Default is True. 112 | return_type : str, optional 113 | Either "masked", "gridded" or None. If None (default), uses `river_network.return_type`. 114 | 115 | Returns 116 | ------- 117 | array-like 118 | Array of maximum distances for every river network node or gridcell, depending on `return_type`. 119 | """ 120 | return _operations.max( 121 | river_network=river_network, 122 | field=field, 123 | locations=locations, 124 | upstream=upstream, 125 | downstream=downstream, 126 | return_type=return_type, 127 | ) 128 | 129 | 130 | def to_source( 131 | river_network, 132 | field=None, 133 | path="shortest", 134 | return_type=None, 135 | ): 136 | r""" 137 | Calculates the maximum distance to all points from the river network sources. 138 | 139 | For each node in the network, calculates the maximum distance starting from any source. 140 | 141 | The distance is defined as: 142 | 143 | .. math:: 144 | :nowrap: 145 | 146 | \begin{align*} 147 | d_j &= 0 ~\text{for sources}\\ 148 | d_j &= \bigoplus \left(-\infty,~\bigoplus_{i \in \mathrm{Neighbour}(j)} (d_i + w_{ij}) \right) 149 | \end{align*} 150 | 151 | where: 152 | 153 | - :math:`w_{ij}` is the edge distance (e.g., downstream distance), 154 | - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. 155 | - :math:`\bigoplus` is the aggregation function (max for longest path or min for shortest path). 156 | - :math:`d_j` is the total distance at node :math:`j`. 
157 | 158 | Unreachable nodes are given a distance of :math:`-\infty` if :math:`\bigoplus` is a maximum, and :math:`\infty` if :math:`\bigoplus` is a minimum. 159 | 160 | Parameters 161 | ---------- 162 | river_network : RiverNetwork 163 | A river network object. 164 | field : array-like, optional 165 | An array containing length values defined on river network edges. 166 | Default is `xp.ones(river_network.n_edges)`. 167 | path : str, optional 168 | Whether to compute the longest or shortest path. Default is "shortest". 169 | return_type : str, optional 170 | Either "masked", "gridded" or None. If None (default), uses `river_network.return_type`. 171 | 172 | Returns 173 | ------- 174 | array-like 175 | Array of maximum distances for every river network node or gridcell, depending on `return_type`. 176 | """ 177 | locations = river_network.sources 178 | if path == "longest": 179 | return max(river_network, locations, field, False, True, return_type) 180 | elif path == "shortest": 181 | return min(river_network, locations, field, False, True, return_type) 182 | else: 183 | raise ValueError("path must be 'longest' or 'shortest'") 184 | 185 | 186 | def to_sink(river_network, field=None, path="shortest", return_type=None): 187 | r""" 188 | Calculates the maximum distance to all points from the river network sinks. 189 | 190 | For each node in the network, calculates the maximum distance starting from any sink. 191 | 192 | The distance is defined as: 193 | 194 | .. math:: 195 | :nowrap: 196 | 197 | \begin{align*} 198 | d_j &= 0 ~\text{for sinks}\\ 199 | d_j &= \bigoplus \left(-\infty,~\bigoplus_{i \in \mathrm{Neighbour}(j)} (d_i + w_{ij}) \right) 200 | \end{align*} 201 | 202 | where: 203 | 204 | - :math:`w_{ij}` is the edge distance (e.g., downstream distance), 205 | - :math:`\mathrm{Neighbour}(j)` is the set of neighbouring nodes to node :math:`j`, which can include upstream and/or downstream nodes depending on passed arguments. 
206 | - :math:`\bigoplus` is the aggregation function (max for longest path or min for shortest path). 207 | - :math:`d_j` is the total distance at node :math:`j`. 208 | 209 | Unreachable nodes are given a distance of :math:`-\infty` if :math:`\bigoplus` is a maximum, and :math:`\infty` if :math:`\bigoplus` is a minimum. 210 | 211 | Parameters 212 | ---------- 213 | river_network : RiverNetwork 214 | A river network object. 215 | field : array-like, optional 216 | An array containing length values defined on river network edges. 217 | Default is `xp.ones(river_network.n_edges)`. 218 | path : str, optional 219 | Whether to compute the longest or shortest path. Default is "shortest". 220 | return_type : str, optional 221 | Either "masked", "gridded" or None. If None (default), uses `river_network.return_type`. 222 | 223 | Returns 224 | ------- 225 | array-like 226 | Array of maximum distances for every river network node or gridcell, depending on `return_type`. 227 | """ 228 | locations = river_network.sinks 229 | if path == "longest": 230 | return max(river_network, locations, field, True, False, return_type) 231 | elif path == "shortest": 232 | return min(river_network, locations, field, True, False, return_type) 233 | else: 234 | raise ValueError("path must be 'longest' or 'shortest'") 235 | --------------------------------------------------------------------------------