├── .github
│   └── workflows
│       ├── ci.yaml
│       └── lint.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── Jenkinsfile
├── LICENSE
├── README.md
├── assets
│   └── worldcereal_logo.jpg
├── environment.yml
├── jenkins_pre_install_script.sh
├── mypy.ini
├── notebooks
│   ├── notebook_utils
│   │   ├── classifier.py
│   │   ├── croptypepicker.py
│   │   ├── dateslider.py
│   │   ├── extractions.py
│   │   ├── seasons.py
│   │   └── visualization.py
│   ├── patch_to_point
│   │   ├── job_df.parquet
│   │   ├── patch_to_point_development.ipynb
│   │   └── rdm_query_example
│   ├── resources
│   │   ├── Cropland_inference_choose_end_date.png
│   │   ├── Custom_cropland_map.png
│   │   ├── Custom_croptype_map.png
│   │   ├── Default_cropland_map.png
│   │   ├── Landcover_training_data_density_PhI.png
│   │   ├── MOOC_refdata_RDM_exploration.png
│   │   ├── WorldCereal_private_extractions.png
│   │   ├── cropland_data_Ph1.png
│   │   ├── croptype_classes.json
│   │   ├── croptype_data_Ph1.png
│   │   ├── eurocrops_map_wcr_edition.csv
│   │   └── wc2eurocrops_map.csv
│   ├── worldcereal_RDM_demo.ipynb
│   ├── worldcereal_custom_croptype.ipynb
│   ├── worldcereal_default_cropland.ipynb
│   └── worldcereal_private_extractions.ipynb
├── pyproject.toml
├── pytest.ini
├── scripts
│   ├── extractions
│   │   └── extract.py
│   ├── inference
│   │   ├── collect_inputs.py
│   │   ├── cropland_mapping.py
│   │   ├── cropland_mapping_local.py
│   │   ├── cropland_mapping_udp.py
│   │   └── croptype_mapping_udp.py
│   ├── misc
│   │   └── legend.py
│   ├── spark
│   │   ├── compute_presto_features.py
│   │   └── compute_presto_features.sh
│   └── stac
│       ├── build_paths.py
│       ├── catalogue_builder.py
│       └── split_catalogue.py
├── src
│   └── worldcereal
│       ├── __init__.py
│       ├── _version.py
│       ├── data
│       │   ├── cropcalendars
│       │   │   ├── ANNUAL_EOS_WGS84.tif
│       │   │   ├── ANNUAL_SOS_WGS84.tif
│       │   │   ├── S1_EOS_WGS84.tif
│       │   │   ├── S1_SOS_WGS84.tif
│       │   │   ├── S2_EOS_WGS84.tif
│       │   │   ├── S2_SOS_WGS84.tif
│       │   │   └── __init__.py
│       │   └── croptype_mappings
│       │       ├── __init__.py
│       │       ├── croptype_classes.json
│       │       ├── eurocrops_map_wcr_edition.csv
│       │       └── wc2eurocrops_map.csv
│       ├── extract
│       │   ├── __init__.py
│       │   ├── common.py
│       │   ├── patch_meteo.py
│       │   ├── patch_s1.py
│       │   ├── patch_s2.py
│       │   ├── patch_to_point_worldcereal.py
│       │   ├── patch_worldcereal.py
│       │   ├── point_worldcereal.py
│       │   └── utils.py
│       ├── job.py
│       ├── openeo
│       │   ├── __init__.py
│       │   ├── feature_extractor.py
│       │   ├── inference.py
│       │   ├── mapping.py
│       │   ├── masking.py
│       │   ├── postprocess.py
│       │   ├── preprocessing.py
│       │   └── udf_distance_to_cloud.py
│       ├── parameters.py
│       ├── rdm_api
│       │   ├── __init__.py
│       │   ├── rdm_collection.py
│       │   └── rdm_interaction.py
│       ├── seasons.py
│       ├── stac
│       │   ├── __init__.py
│       │   ├── constants.py
│       │   └── stac_api_interaction.py
│       ├── train
│       │   └── data.py
│       ├── udp
│       │   ├── worldcereal_crop_extent.json
│       │   └── worldcereal_crop_type.json
│       └── utils
│           ├── geoloader.py
│           ├── legend.py
│           ├── map.py
│           ├── models.py
│           ├── refdata.py
│           ├── retry.py
│           ├── spark.py
│           ├── timeseries.py
│           └── upload.py
└── tests
    ├── pre_test_script.sh
    └── worldcerealtests
        ├── __init__.py
        ├── conftest.py
        ├── test_feature_extractor.py
        ├── test_inference.py
        ├── test_pipelines.py
        ├── test_postprocessing.py
        ├── test_preprocessing.py
        ├── test_rdm_interaction.py
        ├── test_refdata.py
        ├── test_seasons.py
        ├── test_stac_api_interaction.py
        ├── test_timeseries.py
        ├── test_train.py
        ├── test_utils.py
        └── testresources
            ├── preprocess_from_patches_graph.json
            ├── preprocess_graph.json
            ├── preprocess_graphwithslope.json
            ├── spatial_extent.json
            ├── test_public_extractions.parquet
            ├── worldcereal_cropland_classification.nc
            ├── worldcereal_croptype_classification.nc
            ├── worldcereal_preprocessed_inputs.nc
            ├── worldcereal_private_extractions_dummy.parquet
            └── ref_id=2021_BEL_LPIS-Flanders_POLY_110
                └── 
data_0.parquet /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | tests: 11 | name: tests 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Clone repo 15 | uses: actions/checkout@v2 16 | - name: Set up python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: 3.9 20 | cache: 'pip' 21 | # - name: Install openEO-GFMAP from source 22 | # run: | 23 | # git clone https://github.com/Open-EO/openeo-gfmap.git 24 | # cd openeo-gfmap 25 | # pip install . 26 | # cd .. 27 | - name: Install Presto from source 28 | run: | 29 | git clone -b croptype https://github.com/WorldCereal/presto-worldcereal.git 30 | cd presto-worldcereal 31 | pip install . 32 | cd .. 33 | - name: Install WorldCereal dependencies 34 | run: pip install ".[dev,train,notebooks]" 35 | - name: Run WorldCereal tests 36 | run: python -m pytest -s --log-cli-level=INFO tests 37 | env: 38 | OPENEO_AUTH_METHOD: client_credentials 39 | OPENEO_OIDC_DEVICE_CODE_MAX_POLL_TIME: 5 40 | OPENEO_AUTH_PROVIDER_ID_CDSE: CDSE 41 | OPENEO_AUTH_PROVIDER_ID_VITO: terrascope 42 | OPENEO_AUTH_CLIENT_ID_CDSE: openeo-worldcereal-service-account 43 | OPENEO_AUTH_CLIENT_ID_VITO: openeo-worldcereal-service-account 44 | OPENEO_AUTH_CLIENT_SECRET_CDSE: ${{ secrets.OPENEO_AUTH_CLIENT_SECRET_CDSE }} 45 | OPENEO_AUTH_CLIENT_SECRET_VITO: ${{ secrets.OPENEO_AUTH_CLIENT_SECRET_TERRASCOPE }} -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | lint: 11 | name: "Lint: code quality and formatting checks" 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Clone repo 15 | uses: actions/checkout@v2 16 | - name: Set up python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: 3.9 20 | cache: 'pip' 21 | - name: Install Python dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | python -m pip install isort black==24.3.0 ruff==0.5.0 mypy==1.11.2 types-requests~=2.32.0 25 | 26 | - name: isort 27 | run: python -m isort . --check --diff 28 | - name: black 29 | run: python -m black --check --diff . 30 | - name: ruff 31 | run: ruff check . 32 | - name: mypy 33 | run: python -m mypy . 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | .vscode 3 | __pycache__ 4 | .ftpconfig 5 | *.pyc 6 | *.py[cod] 7 | *$py.class 8 | tmp 9 | *.vrt 10 | *.pkl 11 | *.mep2 12 | *.tig 13 | *.tif 14 | *.jp2 15 | *.npy 16 | *.kml 17 | *.csv 18 | *.tsv 19 | *.tfevents 20 | *.nc 21 | *.gif 22 | config.json 23 | # C extensions 24 | *.so 25 | *.gz 26 | # Distribution / packaging 27 | .Python 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | MANIFEST 44 | .mypy_cache 45 | .vscode 46 | .history 47 | .noseids 48 | experimental/** 49 | 50 | # don't track catboost training info 51 | src/worldcereal/train/catboost_info 52 | 53 | # PyInstaller 54 | # Usually these files are written by a python script from a template 55 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 56 | *.manifest 57 | *.spec 58 | 59 | # Installer logs 60 | pip-log.txt 61 | pip-delete-this-directory.txt 62 | 63 | # Allow the test resources to be tracked 64 | !/tests/worldcerealtests/testresources/** 65 | *.aux.xml 66 | 67 | # Allow crop calendar resources to be tracked 68 | !/src/worldcereal/data/cropcalendars/*.tif 69 | 70 | # Allow resource CSV files to be tracked 71 | # !/src/worldcereal/resources/*.csv 72 | 73 | # Allow resource JSON files to be tracked 74 | # !/src/worldcereal/resources/**/*.json 75 | 76 | # Allow biomes/realms to be tracked 77 | # !/src/worldcereal/resources/biomes/*.tif 78 | # !/src/worldcereal/resources/realms/*.tif 79 | 80 | # Dont track zarr data 81 | features.zarr/ 82 | 83 | # Unit test / coverage reports 84 | htmlcov/ 85 | .tox/ 86 | .nox/ 87 | .coverage 88 | .coverage.* 89 | .cache 90 | nosetests.xml 91 | coverage.xml 92 | *.cover 93 | .hypothesis/ 94 | .pytest_cache/ 95 | 96 | # Translations 97 | *.mo 98 | *.pot 99 | 100 | # Django stuff: 101 | *.log 102 | local_settings.py 103 | db.sqlite3 104 | 105 | # Flask stuff: 106 | instance/ 107 | .webassets-cache 108 | 109 | # Scrapy stuff: 110 | .scrapy 111 | 112 | # Sphinx documentation 113 | docs/_build/ 114 | 115 | # PyBuilder 116 | target/ 117 | 118 | # Jupyter Notebook 119 | .ipynb_checkpoints 120 | 121 | # IPython 122 | profile_default/ 123 | ipython_config.py 124 | 125 | # pyenv 126 | .python-version 127 | 128 | # celery beat schedule file 129 | celerybeat-schedule 130 | 131 | # SageMath parsed files 132 | *.sage.py 133 | 134 | # Environments 135 | .env 136 | .venv 137 | env/ 138 | venv/ 139 | ENV/ 140 | env.bak/ 141 | venv.bak/ 142 | 143 | # Spyder project settings 144 | .spyderproject 145 | .spyproject 146 | 147 | # Rope project settings 148 | .ropeproject 149 | 150 | # mkdocs documentation 151 | /site 152 | 153 | # mypy 154 | .mypy_cache/ 155 | .dmypy.json 156 | dmypy.json 157 | 158 | # .idea stuff 159 | .idea 160 | 161 | !src/worldcereal/data/**/*.csv 162 | 163 | # Pyre type checker 164 | .pyre/ 165 | download.zip 166 | catboost_info/* 167 | notebooks/catboost_info/* 168 | 169 | *.cbm 170 | *.pt 171 | *.onnx 172 | *.nc 173 | *.7z 174 | *.dmg 175 | *.gz 176 | *.iso 177 | *.jar 178 | *.rar 179 | *.tar 180 | *.zip 181 | 182 | .notebook-tests/ 183 | .local-presto-test/ -------------------------------------------------------------------------------- /.pre-commit-config.yaml: 
-------------------------------------------------------------------------------- 1 | # https://pre-commit.com 2 | default_language_version: 3 | python: python3 4 | default_stages: [pre-commit, manual] 5 | fail_fast: true 6 | exclude: "(received/|.*_depr)" 7 | repos: 8 | - repo: https://github.com/pre-commit/pre-commit-hooks 9 | rev: v4.5.0 10 | hooks: 11 | # - id: check-added-large-files 12 | # args: ['--maxkb=65536'] 13 | - id: check-ast 14 | - id: check-builtin-literals 15 | - id: check-byte-order-marker 16 | - id: check-case-conflict 17 | - id: check-docstring-first 18 | - id: check-json 19 | - id: check-merge-conflict 20 | - id: check-symlinks 21 | - id: check-toml 22 | - id: check-vcs-permalinks 23 | - id: check-xml 24 | - id: check-yaml 25 | args: [--allow-multiple-documents] 26 | - id: debug-statements 27 | - id: detect-private-key 28 | - id: mixed-line-ending 29 | - id: trailing-whitespace 30 | types: [python] 31 | - id: end-of-file-fixer 32 | types: [python] 33 | - repo: local 34 | hooks: 35 | - id: shellcheck 36 | name: shellcheck 37 | entry: shellcheck --check-sourced --shell=bash --exclude=SC1087 38 | language: system 39 | types: [shell] 40 | # - id: pydocstyle 41 | # name: pydocstyle 42 | # entry: pydocstyle 43 | # language: system 44 | # types: [python] 45 | # exclude: "(^experiments/|.*_depr)" 46 | # - id: flake8 47 | # name: flake8 48 | # entry: flake8 49 | # language: system 50 | # types: [python] 51 | # exclude: "(^tasks/|.*_depr)" 52 | - repo: https://github.com/pre-commit/pre-commit-hooks 53 | rev: v4.5.0 54 | hooks: 55 | - id: no-commit-to-branch 56 | args: [ '--branch', 'main' ] 57 | - repo: https://github.com/radix-ai/auto-smart-commit 58 | rev: v1.0.3 59 | hooks: 60 | - id: auto-smart-commit 61 | - repo: https://github.com/psf/black 62 | rev: 24.3.0 63 | hooks: 64 | - id: black 65 | language_version: python3 66 | - repo: https://github.com/pycqa/isort 67 | rev: 5.13.2 68 | hooks: 69 | - id: isort 70 | name: isort (python) 71 | args: ["--profile", "black"] 72 | - repo: https://github.com/astral-sh/ruff-pre-commit 73 | rev: v0.1.15 74 | hooks: 75 | - id: ruff 76 | args: ["--fix"] 77 | - repo: https://github.com/pre-commit/mirrors-mypy 78 | rev: v1.11.2 79 | hooks: 80 | - id: mypy 81 | additional_dependencies: ["types-requests"] 82 | -------------------------------------------------------------------------------- /Jenkinsfile: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env groovy 2 | 3 | /* Jenkinsfile for snapshot building with VITO CI system. 
*/ 4 | 5 | @Library('lib')_ 6 | 7 | pythonPipeline { 8 | package_name = 'worldcereal-classification' 9 | test_module_name = 'worldcereal' 10 | wipeout_workspace = true 11 | python_version = ["3.10"] 12 | extras_require = "dev,train,notebooks" 13 | upload_dev_wheels = false 14 | pipeline_triggers = [cron('H H(0-6) * * *')] 15 | pep440 = true 16 | pre_test_script = 'pre_test_script.sh' 17 | pre_install_script = 'jenkins_pre_install_script.sh' 18 | extra_env_variables = [ 19 | "OPENEO_AUTH_METHOD=client_credentials", 20 | "OPENEO_OIDC_DEVICE_CODE_MAX_POLL_TIME=5", 21 | "OPENEO_AUTH_PROVIDER_ID_VITO=terrascope", 22 | "OPENEO_AUTH_CLIENT_ID_VITO=openeo-worldcereal-service-account", 23 | "OPENEO_AUTH_PROVIDER_ID_CDSE=CDSE", 24 | "OPENEO_AUTH_CLIENT_ID_CDSE=openeo-worldcereal-service-account", 25 | ] 26 | extra_env_secrets = [ 27 | 'OPENEO_AUTH_CLIENT_SECRET_VITO': 'TAP/big_data_services/devops/terraform/keycloak_mgmt/oidc_clients_prod openeo-worldcereal-service-account', 28 | 'OPENEO_AUTH_CLIENT_SECRET_CDSE': 'TAP/big_data_services/openeo/cdse-service-accounts/openeo-worldcereal-service-account client_secret', 29 | ] 30 | } 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 ESA WorldCereal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ESA WorldCereal classification module 2 | [![Tests](https://github.com/WorldCereal/worldcereal-classification/actions/workflows/ci.yaml/badge.svg)](https://github.com/WorldCereal/worldcereal-classification/actions/workflows/ci.yaml) [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit)](https://github.com/pre-commit/pre-commit) [![License](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) [![DOI](https://img.shields.io/badge/DOI-10.5194/essd--15--5491--2023-blue)](https://doi.org/10.5194/essd-15-5491-2023) [![Documentation](https://img.shields.io/badge/docs-WorldCereal%20Documentation-blue)](https://worldcereal.github.io/worldcereal-documentation/) [![Discuss Forum](https://img.shields.io/badge/discuss-forum-ED1965?logo=discourse&logoColor=white)](https://forum.esa-worldcereal.org/) 3 | 4 | 5 |

6 | 7 | ![logo](./assets/worldcereal_logo.jpg) 8 | 9 | 

10 | 11 | ## Overview 12 | 13 | **WorldCereal** is a Python package designed for generating cropland and crop type maps at a wide range of spatial scales, leveraging satellite and auxiliary data, and state-of-the-art classification workflows. It uses [**openEO**](https://openeo.org/) to run classification tasks in the cloud, by default on the [**Copernicus Data Space Ecosystem (CDSE)**](https://dataspace.copernicus.eu/). 14 | 15 | Users can leverage the system in a notebook environment through [Terrascope](https://terrascope.be/en) or set up the environment locally using the provided installation options. 16 | 17 | To run classification jobs on CDSE, users can get started with [monthly free openEO processing credits](https://documentation.dataspace.copernicus.eu/Quotas.html) by registering on the CDSE platform. Additional credits can be purchased, or users may soon be able to request them through the **ESA Network of Resources**. 18 | 19 | ## Features 20 | 21 | - **Scalable**: Generate maps at a wide range of spatial scales. 22 | - **Cloud-based Processing**: Leverages openEO to run classifications in the cloud. 23 | - **Powerful classification pipeline**: WorldCereal builds upon [Presto](https://arxiv.org/abs/2304.14065), a pretrained transformer-based model that leverages global self-supervised learning on multimodal input time series, leading to better accuracies and higher generalizability in space and time for downstream crop classification models. The Presto backbone of the WorldCereal classification pipelines is developed [here](https://github.com/WorldCereal/presto-worldcereal). 24 | - **Customizable**: Users can pick any region and temporal range and either apply the default models or train their own to produce custom maps, interacting with publicly available training data. 25 | - **Easy to Use**: Integrates into Jupyter notebooks and other Python environments. 26 | 27 | ## Quick Start 28 | 29 | Our [demo notebooks](notebooks) offer a quick introduction to mapping crops with the WorldCereal system. There are two options to run these notebooks: 30 | 31 | #### Option 1: Run on Terrascope 32 | 33 | You can use a preconfigured environment on [**Terrascope**](https://terrascope.be/en) to run the workflows in a Jupyter notebook environment. Just register as a new user on Terrascope or use one of the supported EGI eduGAIN login methods to get started. 34 | 35 | | :point_up: | Once you are prompted with "Server Options", make sure to select the "Worldcereal" image. Did you choose "Terrascope" by accident? Then go to File > Hub Control Panel > Stop my server, and click the link below once again. | 36 | |---------------|:------------------------| 37 | 38 | - For a cropland map generation demo without any model training: Run cropland demo 39 | 40 | - For a crop type map generation demo with model training: Run croptype demo 41 | 42 | - For a demo on how to interact with the WorldCereal Reference Data Module (RDM): Run RDM demo 43 | 44 | #### Option 2: Install Locally 45 | 46 | If you prefer to install the package locally, you can create the environment using **Conda** or **pip**. 
47 | 48 | First clone the repository: 49 | ```bash 50 | git clone https://github.com/WorldCereal/worldcereal-classification.git 51 | cd worldcereal-classification 52 | ``` 53 | Next, install the package locally: 54 | - for Conda: `conda env create -f environment.yml` 55 | - for pip: `pip install ".[train,notebooks]"` 56 | 57 | ## Usage Example 58 | In its simplest form, a cropland mask can be generated with just a few lines of code, triggering an openEO job on CDSE and downloading the result locally: 59 | 60 | ```python 61 | from openeo_gfmap import BoundingBoxExtent, TemporalContext 62 | from worldcereal.job import generate_map 63 | 64 | # Specify the spatial extent 65 | spatial_extent = BoundingBoxExtent( 66 | west=44.432274, 67 | south=51.317362, 68 | east=44.698802, 69 | north=51.428224, 70 | epsg=4326 71 | ) 72 | 73 | # Specify the temporal extent (this has to be one year) 74 | temporal_extent = TemporalContext('2022-11-01', '2023-10-31') 75 | 76 | # Launch processing job (result will automatically be downloaded) 77 | results = generate_map(spatial_extent, temporal_extent, output_dir='.') 78 | ``` 79 | 80 | ## Documentation 81 | 82 | Comprehensive documentation is available at the following link: https://worldcereal.github.io/worldcereal-documentation/ 83 | 84 | ## Support 85 | Questions, suggestions, feedback? Use [our forum](https://forum.esa-worldcereal.org/) to get in touch with us and the rest of the community! 86 | 87 | ## License 88 | 89 | This project is licensed under the terms of the MIT License. See the [LICENSE](LICENSE) file for details. 90 | 91 | ## Acknowledgments 92 | 93 | The WorldCereal project is funded by the [European Space Agency (ESA)](https://www.esa.int/) under grant no. 4000130569/20/I-NB. 94 | 95 | WorldCereal's classification backbone makes use of the Presto model, originally implemented [here](https://github.com/nasaharvest/presto/). Without the groundbreaking work done by Gabriel Tseng and the rest of the [NASA Harvest](https://www.nasaharvest.org/) team, both in the original Presto implementation as well as its adaptation for WorldCereal, this package would simply not exist in its present form 🙏. 
96 | 97 | The pre-configured Jupyter notebook environment in which users can train custom models and launch WorldCereal jobs is provided by [Terrascope](https://terrascope.be/en), the Belgian Earth observation data space, managed by [VITO Remote Sensing](https://remotesensing.vito.be/) on behalf of the [Belgian Science Policy Office](https://www.belspo.be/belspo/index_en.stm) 98 | 99 | ## How to cite 100 | 101 | If you use WorldCereal resources in your work, please cite it as follows: 102 | 103 | ```bibtex 104 | 105 | @article{van_tricht_worldcereal_2023, 106 | title = {{WorldCereal}: a dynamic open-source system for global-scale, seasonal, and reproducible crop and irrigation mapping}, 107 | volume = {15}, 108 | issn = {1866-3516}, 109 | shorttitle = {{WorldCereal}}, 110 | url = {https://essd.copernicus.org/articles/15/5491/2023/}, 111 | doi = {10.5194/essd-15-5491-2023}, 112 | number = {12}, 113 | urldate = {2024-03-01}, 114 | journal = {Earth System Science Data}, 115 | author = {Van Tricht, Kristof and Degerickx, Jeroen and Gilliams, Sven and Zanaga, Daniele and Battude, Marjorie and Grosu, Alex and Brombacher, Joost and Lesiv, Myroslava and Bayas, Juan Carlos Laso and Karanam, Santosh and Fritz, Steffen and Becker-Reshef, Inbal and Franch, Belén and Mollà-Bononad, Bertran and Boogaard, Hendrik and Pratihast, Arun Kumar and Koetz, Benjamin and Szantoi, Zoltan}, 116 | month = dec, 117 | year = {2023}, 118 | pages = {5491--5515}, 119 | } 120 | ``` 121 | -------------------------------------------------------------------------------- /assets/worldcereal_logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/assets/worldcereal_logo.jpg -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: worldcereal 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - boto3=1.35.30 8 | - cftime=1.6.4 9 | - catboost=1.2.5 10 | - cpuonly 11 | - einops=0.8.0 12 | - fastparquet=2024.2.0 13 | - geojson=3.1.0 14 | - h5netcdf=1.3.0 15 | - geopandas=0.14.4 16 | - h5py=3.11.0 17 | - leafmap=0.35.1 18 | - loguru=0.7.2 19 | - netcdf4 20 | - numpy<2.0.0 21 | - openeo=0.35.0 22 | - pip 23 | - pyarrow=16.1.0 24 | - pydantic=2.8.0 25 | - python=3.10.0 26 | - pytorch=2.3.1 27 | - rasterio=1.3.10 28 | - rioxarray=0.15.5 29 | - scikit-image=0.22.0 30 | - scikit-learn=1.5.0 31 | - scipy 32 | - shapely=2.0.4 33 | - tqdm 34 | - pip: 35 | - duckdb==1.1.3 36 | - h3==4.1.0 37 | - openeo-gfmap==0.4.6 38 | - git+https://github.com/worldcereal/worldcereal-classification 39 | - git+https://github.com/WorldCereal/presto-worldcereal.git@croptype 40 | 41 | -------------------------------------------------------------------------------- /jenkins_pre_install_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Install git 4 | dnf install git -y 5 | 6 | # # Install openeo-gfmap and presto-worldcereal 7 | # dir=$(pwd) 8 | # GFMAP_URL="https://github.com/Open-EO/openeo-gfmap.git" 9 | # PRESTO_URL="https://github.com/WorldCereal/presto-worldcereal.git" 10 | 11 | # su - jenkins -c "cd $dir && \ 12 | # source venv310/bin/activate && \ 13 | # git clone $GFMAP_URL && \ 14 | # cd openeo-gfmap || { echo 'Directory not found! 
Exiting...'; exit 1; } && \ 15 | # pip install . && \ 16 | # cd .. 17 | # git clone -b croptype $PRESTO_URL && \ 18 | # cd presto-worldcereal || { echo 'Directory not found! Exiting...'; exit 1; } && \ 19 | # pip install . 20 | # " 21 | 22 | # For now only presto-worldcereal as gfmap is up to date on pypi 23 | dir=$(pwd) 24 | 25 | su - jenkins -c "cd $dir && \ 26 | source venv310/bin/activate && \ 27 | pip install git+https://github.com/WorldCereal/presto-worldcereal.git@croptype && \ 28 | pip install git+https://github.com/WorldCereal/prometheo.git 29 | " -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | exclude = venv|notebooks 3 | ignore_missing_imports = True 4 | 5 | [mypy-yaml.*] 6 | ignore_missing_imports = True 7 | -------------------------------------------------------------------------------- /notebooks/notebook_utils/classifier.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from catboost import CatBoostClassifier, Pool 6 | from loguru import logger 7 | from presto.utils import DEFAULT_SEED 8 | from sklearn.metrics import classification_report, confusion_matrix 9 | from sklearn.model_selection import train_test_split 10 | from sklearn.utils.class_weight import compute_class_weight 11 | 12 | from worldcereal.parameters import CropLandParameters, CropTypeParameters 13 | 14 | 15 | def get_input(label): 16 | while True: 17 | modelname = input(f"Enter a short name for your {label} (don't use spaces): ") 18 | if " " not in modelname: 19 | return modelname 20 | print("Invalid input. Please enter a name without spaces.") 21 | 22 | 23 | def prepare_training_dataframe( 24 | df: pd.DataFrame, 25 | batch_size: int = 256, 26 | task_type: str = "croptype", 27 | augment: bool = True, 28 | mask_ratio: float = 0.30, 29 | repeats: int = 1, 30 | ) -> pd.DataFrame: 31 | """Method to generate a training dataframe with Presto embeddings for downstream Catboost training. 
32 | 33 | Parameters 34 | ---------- 35 | df : pd.DataFrame 36 | input dataframe with required input features for Presto 37 | batch_size : int, optional 38 | batch size used when computing Presto embeddings, by default 256 39 | task_type : str, optional 40 | cropland or croptype task, by default "croptype" 41 | augment : bool, optional 42 | if True, temporal jittering is enabled, by default True 43 | mask_ratio : float, optional 44 | if > 0, inputs are randomly masked before computing Presto embeddings, by default 0.30 45 | repeats: int, optional 46 | number of times to repeat each sample, by default 1 47 | 48 | Returns 49 | ------- 50 | pd.DataFrame 51 | output training dataframe for downstream training 52 | 53 | Raises 54 | ------ 55 | ValueError 56 | if an unknown task type is specified 57 | ValueError 58 | if repeats > 1 while augment is False and mask_ratio is 0 59 | """ 60 | from presto.presto import Presto 61 | 62 | from worldcereal.train.data import WorldCerealTrainingDataset, get_training_df 63 | 64 | if task_type == "croptype": 65 | presto_model_url = CropTypeParameters().feature_parameters.presto_model_url 66 | use_valid_date_token = ( 67 | CropTypeParameters().feature_parameters.use_valid_date_token 68 | ) 69 | elif task_type == "cropland": 70 | presto_model_url = CropLandParameters().feature_parameters.presto_model_url 71 | use_valid_date_token = ( 72 | CropLandParameters().feature_parameters.use_valid_date_token 73 | ) 74 | else: 75 | raise ValueError(f"Unknown task type: {task_type}") 76 | 77 | if repeats > 1 and not augment and mask_ratio == 0: 78 | raise ValueError("Repeats > 1 requires augment=True or mask_ratio > 0.") 79 | 80 | # Load pretrained Presto model 81 | logger.info(f"Presto URL: {presto_model_url}") 82 | presto_model = Presto.load_pretrained( 83 | presto_model_url, 84 | from_url=True, 85 | strict=False, 86 | valid_month_as_token=use_valid_date_token, 87 | ) 88 | 89 | # Initialize dataset, honoring the user's augmentation setting 90 | df = df.reset_index() 91 | ds = WorldCerealTrainingDataset( 92 | df, 93 | task_type=task_type, 94 | augment=augment, 95 | mask_ratio=mask_ratio, 96 | repeats=repeats, 97 | ) 98 | logger.info("Computing Presto embeddings ...") 99 | df = get_training_df( 100 | ds, 101 | presto_model, 102 | batch_size=batch_size, 103 | valid_date_as_token=use_valid_date_token, 104 | ) 105 | 106 | logger.info("Done.") 107 | 108 | return df 109 | 110 | 111 | def train_classifier( 112 | training_dataframe: pd.DataFrame, 113 | class_names: Optional[List[str]] = None, 114 | balance_classes: bool = False, 115 | ) -> Tuple[CatBoostClassifier, Union[str, dict], np.ndarray]: 116 | """Method to train a custom CatBoostClassifier on a training dataframe. 
117 | 118 | Parameters 119 | ---------- 120 | training_dataframe : pd.DataFrame 121 | training dataframe containing inputs and targets 122 | class_names : Optional[List[str]], optional 123 | class names to use, by default None 124 | balance_classes : bool, optional 125 | if True, class weights are used during training to balance the classes, by default False 126 | 127 | Returns 128 | ------- 129 | Tuple[CatBoostClassifier, Union[str, dict], np.ndarray] 130 | The trained CatBoost model, the classification report, and the confusion matrix 131 | 132 | Raises 133 | ------ 134 | ValueError 135 | When not enough classes are present in the training dataframe to train a model 136 | """ 137 | 138 | logger.info("Split train/test ...") 139 | samples_train, samples_test = train_test_split( 140 | training_dataframe, 141 | test_size=0.2, 142 | random_state=DEFAULT_SEED, 143 | stratify=training_dataframe["downstream_class"], 144 | ) 145 | 146 | # Define loss function and eval metric 147 | if np.unique(samples_train["downstream_class"]).shape[0] < 2: 148 | raise ValueError("Not enough classes to train a classifier.") 149 | elif np.unique(samples_train["downstream_class"]).shape[0] > 2: 150 | eval_metric = "MultiClass" 151 | loss_function = "MultiClass" 152 | else: 153 | eval_metric = "Logloss" 154 | loss_function = "Logloss" 155 | 156 | # Compute sample weights 157 | if balance_classes: 158 | logger.info("Computing class weights ...") 159 | class_weights = np.round( 160 | compute_class_weight( 161 | class_weight="balanced", 162 | classes=np.unique(samples_train["downstream_class"]), 163 | y=samples_train["downstream_class"], 164 | ), 165 | 3, 166 | ) 167 | class_weights = { 168 | k: v 169 | for k, v in zip(np.unique(samples_train["downstream_class"]), class_weights) 170 | } 171 | logger.info(f"Class weights: {class_weights}") 172 | 173 | sample_weights = np.ones((len(samples_train["downstream_class"]),)) 174 | sample_weights_val = np.ones((len(samples_test["downstream_class"]),)) 175 | for k, v in class_weights.items(): 176 | sample_weights[samples_train["downstream_class"] == k] = v 177 | sample_weights_val[samples_test["downstream_class"] == k] = v 178 | samples_train["weight"] = sample_weights 179 | samples_test["weight"] = sample_weights_val 180 | else: 181 | samples_train["weight"] = 1 182 | samples_test["weight"] = 1 183 | 184 | # Define classifier 185 | custom_downstream_model = CatBoostClassifier( 186 | iterations=2000,  # kept moderate to limit model size 187 | depth=8, 188 | early_stopping_rounds=20, 189 | loss_function=loss_function, 190 | eval_metric=eval_metric, 191 | random_state=DEFAULT_SEED, 192 | verbose=25, 193 | class_names=( 194 | class_names 195 | if class_names is not None 196 | else np.unique(samples_train["downstream_class"]) 197 | ), 198 | ) 199 | 200 | # Set up dataset Pools 201 | bands = [f"presto_ft_{i}" for i in range(128)] 202 | calibration_data = Pool( 203 | data=samples_train[bands], 204 | label=samples_train["downstream_class"], 205 | weight=samples_train["weight"], 206 | ) 207 | eval_data = Pool( 208 | data=samples_test[bands], 209 | label=samples_test["downstream_class"], 210 | weight=samples_test["weight"], 211 | ) 212 | 213 | # Train classifier 214 | logger.info("Training CatBoost classifier ...") 215 | custom_downstream_model.fit( 216 | calibration_data, 217 | eval_set=eval_data, 218 | ) 219 | 220 | # Make predictions 221 | pred = custom_downstream_model.predict(samples_test[bands]).flatten() 222 | 223 | report = 
classification_report(samples_test["downstream_class"], pred) 224 | conf_matrix = confusion_matrix(samples_test["downstream_class"], pred) 225 | 226 | return custom_downstream_model, report, conf_matrix 227 | -------------------------------------------------------------------------------- /notebooks/notebook_utils/dateslider.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | 3 | import ipywidgets as widgets 4 | import pandas as pd 5 | from IPython.core.display import HTML as core_HTML 6 | from IPython.display import display 7 | from loguru import logger 8 | from openeo_gfmap import TemporalContext 9 | 10 | 11 | class date_slider: 12 | """Class that provides a slider for selecting a processing period. 13 | The processing period is fixed in length, amounting to one year. 14 | The processing period will always start the first day of a month and end the last day of a month. 15 | """ 16 | 17 | def __init__(self, start_date=datetime(2018, 1, 1), end_date=datetime(2024, 12, 1)): 18 | 19 | # Define the slider 20 | dates = pd.date_range(start_date, end_date, freq="MS") 21 | options = [(date.strftime("%b %Y"), date) for date in dates] 22 | self.interval_slider = widgets.SelectionRangeSlider( 23 | options=options, 24 | index=(0, 11), # Default to an 11-month interval 25 | orientation="horizontal", 26 | description="", 27 | continuous_update=False, 28 | behaviour="drag", 29 | style={ 30 | "handle_color": "dodgerblue", 31 | }, 32 | layout=widgets.Layout( 33 | width="600px", 34 | margin="0 0 0 10px", 35 | ), 36 | readout=False, 37 | ) 38 | 39 | # Define the HTML text widget for the selected range and focus time 40 | initial_range = [ 41 | (pd.to_datetime(start_date)).strftime("%d %b %Y"), 42 | ( 43 | pd.to_datetime(start_date) 44 | + pd.DateOffset(months=12) 45 | - timedelta(days=1) 46 | ).strftime("%d %b %Y"), 47 | ] 48 | initial_focus_time = ( 49 | pd.to_datetime(start_date) + pd.DateOffset(months=6) 50 | ).strftime("%b %Y") 51 | self.html_text = widgets.HTML( 52 | value=f"Selected range: {initial_range[0]} - {initial_range[1]}<br>
Season center: {initial_focus_time}", 53 | placeholder="HTML placeholder", 54 | description="", 55 | layout=widgets.Layout(justify_content="center", display="flex"), 56 | ) 57 | 58 | # Attach slider observer 59 | self.interval_slider.observe(self.on_slider_change, names="value") 60 | 61 | # Add custom CSS for the ticks 62 | custom_css = """ 63 | 97 | """ 98 | 99 | # # Generate ticks 100 | tick_dates = pd.date_range( 101 | start_date, pd.to_datetime(end_date) + pd.DateOffset(months=1), freq="4MS" 102 | ) 103 | tick_labels = [date.strftime("%b %Y") for date in tick_dates] 104 | n_labels = len(tick_labels) 105 | ticks_html = "" 106 | for i, label in enumerate(tick_labels): 107 | position = (i / (n_labels - 1)) * 100 # Position as a percentage 108 | ticks_html += f""" 109 |
<div style="position: absolute; left: {position}%;">|</div>
110 | <div style="position: absolute; left: {position}%;">{label.split()[0]}<br>{label.split()[1]}</div>
111 | """ 112 | 113 | # HTML container for tick marks and labels 114 | tick_marks_and_labels = widgets.HTML( 115 | value=f""" 116 |
117 | <div style="position: relative; width: 600px; margin: 0 0 0 10px;">
118 | <div style="position: relative; height: 40px;">
119 | {ticks_html}
120 | </div>
121 | </div>
122 | 
123 | """ 124 | ) 125 | 126 | # Combine slider and ticks using VBox 127 | slider_with_ticks = widgets.VBox( 128 | [self.interval_slider, tick_marks_and_labels], 129 | layout=widgets.Layout( 130 | width="640px", align_items="center", justify_content="center" 131 | ), 132 | ) 133 | 134 | # Add description widget 135 | descr_widget = widgets.HTML( 136 | value=""" 137 |
138 | <div style="text-align: center;">
139 | Position the slider to select your processing period:
140 | </div>
141 | 
142 | """ 143 | ) 144 | 145 | # Arrange the description widget, interval slider, ticks and text widget in a VBox 146 | vbox = widgets.VBox( 147 | [ 148 | descr_widget, 149 | slider_with_ticks, 150 | self.html_text, 151 | ], 152 | layout=widgets.Layout( 153 | align_items="center", justify_content="center", width="650px" 154 | ), 155 | ) 156 | 157 | display(core_HTML(custom_css)) 158 | display(vbox) 159 | 160 | def on_slider_change(self, change): 161 | 162 | start, end = change["new"] 163 | 164 | # keep the interval fixed 165 | expected_end = start + pd.DateOffset(months=11) 166 | if end != expected_end: 167 | end = start + pd.DateOffset(months=11) 168 | self.interval_slider.value = (start, end) 169 | 170 | # update the HTML text underneath the slider 171 | range = [ 172 | (pd.to_datetime(start)).strftime("%d %b %Y"), 173 | ( 174 | pd.to_datetime(start) + pd.DateOffset(months=12) - timedelta(days=1) 175 | ).strftime("%d %b %Y"), 176 | ] 177 | focus_time = (start + pd.DateOffset(months=6)).strftime("%b %Y") 178 | self.html_text.value = f"Selected range: {range[0]} - {range[1]}
Season center: {focus_time}" 179 | 180 | def get_processing_period(self): 181 | 182 | start = pd.to_datetime(self.interval_slider.value[0]) 183 | end = start + pd.DateOffset(months=12) - timedelta(days=1) 184 | 185 | start = start.strftime("%Y-%m-%d") 186 | end = end.strftime("%Y-%m-%d") 187 | logger.info(f"Selected processing period: {start} to {end}") 188 | 189 | return TemporalContext(start, end) 190 | -------------------------------------------------------------------------------- /notebooks/notebook_utils/seasons.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from calendar import monthrange 3 | from typing import List 4 | 5 | import matplotlib.pyplot as plt 6 | import pandas as pd 7 | from matplotlib.patches import Rectangle 8 | from openeo_gfmap import BoundingBoxExtent 9 | from pyproj import Transformer 10 | 11 | from worldcereal.seasons import get_season_dates_for_extent 12 | 13 | logging.getLogger("rasterio").setLevel(logging.ERROR) 14 | 15 | 16 | def get_month_decimal(date): 17 | 18 | return date.timetuple().tm_mon + ( 19 | date.timetuple().tm_mday / monthrange(2021, date.timetuple().tm_mon)[1] 20 | ) 21 | 22 | 23 | def plot_worldcereal_seasons( 24 | seasons: dict, 25 | extent: BoundingBoxExtent, 26 | tag: str = "", 27 | ): 28 | """Method to plot WorldCereal seasons in a matplotlib plot. 29 | Parameters 30 | ---------- 31 | seasons : dict 32 | dictionary with season names as keys and start and end dates as values 33 | dates need to be in datetime format 34 | extent : BoundingBoxExtent 35 | extent for which to plot the seasons 36 | tag : str, optional 37 | tag to add to the title of the plot, by default empty string 38 | 39 | Returns 40 | ------- 41 | None 42 | """ 43 | 44 | # get lat, lon centroid of extent 45 | transformer = Transformer.from_crs( 46 | f"EPSG:{extent.epsg}", "EPSG:4326", always_xy=True 47 | ) 48 | minx, miny = transformer.transform(extent.west, extent.south) 49 | maxx, maxy = transformer.transform(extent.east, extent.north) 50 | lat = (maxy + miny) / 2 51 | lon = (maxx + minx) / 2 52 | location = f"lat={lat:.2f}, lon={lon:.2f}" 53 | 54 | # prepare figure 55 | fig, ax = plt.subplots() 56 | plt.title(f"WorldCereal seasons {tag} ({location})") 57 | ax.set_ylim((0.4, len(seasons) + 0.5)) 58 | ax.set_xlim((0, 13)) 59 | ax.set_yticks(range(1, len(seasons) + 1)) 60 | ax.set_yticklabels(list(seasons.keys())) 61 | ax.set_xticks(range(1, 13)) 62 | ax.set_xticklabels( 63 | [ 64 | "Jan", 65 | "Feb", 66 | "Mar", 67 | "Apr", 68 | "May", 69 | "Jun", 70 | "Jul", 71 | "Aug", 72 | "Sep", 73 | "Oct", 74 | "Nov", 75 | "Dec", 76 | ] 77 | ) 78 | facecolor = "darkgoldenrod" 79 | 80 | # Get the start and end date for each season 81 | idx = 0 82 | for name, dates in seasons.items(): 83 | sos, eos = dates 84 | 85 | # get start and end month (decimals) for plotting 86 | start = get_month_decimal(sos) 87 | end = get_month_decimal(eos) 88 | 89 | # add rectangle to plot 90 | if start < end: 91 | ax.add_patch( 92 | Rectangle((start, idx + 0.75), end - start, 0.5, color=facecolor) 93 | ) 94 | else: 95 | ax.add_patch( 96 | Rectangle((start, idx + 0.75), 12 - start, 0.5, color=facecolor) 97 | ) 98 | ax.add_patch(Rectangle((1, idx + 0.75), end - 1, 0.5, color=facecolor)) 99 | 100 | # add labels to plot 101 | label_start = sos.strftime("%B %d") 102 | label_end = eos.strftime("%B %d") 103 | plt.text( 104 | start - 0.2, 105 | idx + 0.65, 106 | label_start, 107 | fontsize=8, 108 | color="darkgreen", 109 | ha="left", 110 | va="center", 111 | 
) 112 | plt.text( 113 | end + 0.2, 114 | idx + 0.65, 115 | label_end, 116 | fontsize=8, 117 | color="darkred", 118 | ha="right", 119 | va="center", 120 | ) 121 | 122 | idx += 1 123 | 124 | # display plot 125 | plt.show() 126 | 127 | 128 | def retrieve_worldcereal_seasons( 129 | extent: BoundingBoxExtent, 130 | seasons: List[str] = ["s1", "s2"], 131 | plot: bool = True, 132 | ): 133 | """Method to retrieve default WorldCereal seasons from global crop calendars. 134 | These will be logged to the screen for informative purposes. 135 | 136 | Parameters 137 | ---------- 138 | extent : BoundingBoxExtent 139 | extent for which to load seasonality 140 | seasons : List[str], optional 141 | seasons to load, by default s1 and s2 142 | plot : bool, optional 143 | whether to plot the seasons, by default True 144 | 145 | Returns 146 | ------- 147 | dict 148 | dictionary with season names as keys and start and end dates as values 149 | """ 150 | results = {} 151 | 152 | # Get the start and end date for each season 153 | for idx, season in enumerate(seasons): 154 | seasonal_extent = get_season_dates_for_extent(extent, 2021, f"tc-{season}") 155 | sos = pd.to_datetime(seasonal_extent.start_date) 156 | eos = pd.to_datetime(seasonal_extent.end_date) 157 | results[season] = (sos, eos) 158 | 159 | # Plot the seasons if requested 160 | if plot: 161 | plot_worldcereal_seasons(results, extent) 162 | 163 | return results 164 | -------------------------------------------------------------------------------- /notebooks/patch_to_point/job_df.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/patch_to_point/job_df.parquet -------------------------------------------------------------------------------- /notebooks/patch_to_point/rdm_query_example: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/patch_to_point/rdm_query_example -------------------------------------------------------------------------------- /notebooks/resources/Cropland_inference_choose_end_date.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/resources/Cropland_inference_choose_end_date.png -------------------------------------------------------------------------------- /notebooks/resources/Custom_cropland_map.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/resources/Custom_cropland_map.png -------------------------------------------------------------------------------- /notebooks/resources/Custom_croptype_map.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/resources/Custom_croptype_map.png -------------------------------------------------------------------------------- /notebooks/resources/Default_cropland_map.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/resources/Default_cropland_map.png -------------------------------------------------------------------------------- /notebooks/resources/Landcover_training_data_density_PhI.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/resources/Landcover_training_data_density_PhI.png -------------------------------------------------------------------------------- /notebooks/resources/MOOC_refdata_RDM_exploration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/resources/MOOC_refdata_RDM_exploration.png -------------------------------------------------------------------------------- /notebooks/resources/WorldCereal_private_extractions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/resources/WorldCereal_private_extractions.png -------------------------------------------------------------------------------- /notebooks/resources/cropland_data_Ph1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/resources/cropland_data_Ph1.png -------------------------------------------------------------------------------- /notebooks/resources/croptype_data_Ph1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/resources/croptype_data_Ph1.png -------------------------------------------------------------------------------- /notebooks/resources/wc2eurocrops_map.csv: -------------------------------------------------------------------------------- 1 | name,croptype,ewoc_code,landcover 2 | Unknown,0,00-00-00-000-0,0 3 | Cereals,1000,11-01-00-000-0,11 4 | Wheat,1100,11-01-01-000-0,11 5 | Winter wheat,1110,11-01-01-000-1,11 6 | Spring wheat,1120,11-01-01-000-2,11 7 | Maize,1200,11-01-06-000-0,11 8 | Rice,1300,11-01-08-000-0,11 9 | Sorghum,1400,11-01-07-003-0,11 10 | Barley,1500,11-01-02-000-0,11 11 | Winter barley,1510,11-01-02-000-1,11 12 | Spring barley,1520,11-01-02-000-2,11 13 | Rye,1600,11-01-03-000-0,11 14 | Winter rye,1610,11-01-03-000-1,11 15 | Spring rye,1620,11-01-03-000-2,11 16 | Oats,1700,11-01-04-000-0,11 17 | Millets,1800,11-01-07-001-0,11 18 | Other cereals,1900,11-01-00-000-0,11 19 | Winter cereal,1910,11-01-00-000-1,11 20 | Spring cereal,1920,11-01-00-000-2,11 21 | Vegetables and melons,2000,11-03-00-000-0,11 22 | Leafy or stem vegetables,2100,11-03-04-000-0,11 23 | Artichokes,2110,11-03-04-002-0,11 24 | Asparagus,2120,11-03-04-003-0,11 25 | Cabages,2130,11-03-06-000-0,11 26 | Cauliflowers & brocoli,2140,11-03-06-013-0,11 27 | Lettuce,2150,11-03-08-011-0,11 28 | Spinach,2160,11-03-08-008-0,11 29 | Chicory,2170,11-03-08-002-0,11 30 | Other leafy/stem vegetables,2190,11-03-04-000-0,11 31 | Fruit-bearing vegetables,2200,11-03-01-000-0,11 32 | Cucumbers,2210,11-03-02-001-0,11 33 | 
Eggplants,2220,11-03-01-002-0,11 34 | Tomatoes,2230,11-03-01-001-0,11 35 | Watermelons,2240,11-03-02-005-0,11 36 | Cantaloupes and other melons,2250,11-03-02-003-0,11 37 | "Pumpkin, squash and gourds",2260,11-03-02-004-0,11 38 | Other fruit-bearing vegetables,2290,11-03-01-000-0,11 39 | "Root, bulb or tuberous vegetables",2300,11-03-09-000-0,11 40 | Carrots,2310,11-03-09-004-0,11 41 | Turnips,2320,11-03-09-007-0,11 42 | Garlic,2330,11-03-11-002-0,11 43 | Onions and shallots,2340,11-03-11-007-0,11 44 | Leeks & other alliaceous vegetables,2350,11-03-11-000-0,11 45 | "Other root, bulb or tuberous vegetables",2390,11-03-09-000-0,11 46 | Mushrooms and truffles,2400,11-04-00-000-0,11 47 | Other vegetables,2900,11-03-00-000-0,11 48 | Fruit and nuts,3000,12-01-00-000-0,12 49 | Tropical and subtropical fruits,3100,12-01-02-000-0,12 50 | Avocados,3110,12-01-02-001-0,12 51 | Bananas & plantains,3120,12-01-02-002-0,12 52 | Dates,3130,12-01-02-003-0,12 53 | Figs,3140,12-01-01-006-0,12 54 | Mangoes,3150,12-01-02-004-0,12 55 | Papayas,3160,12-01-02-005-0,12 56 | Pineapples,3170,12-01-02-006-0,12 57 | Other tropical and subtropical fruits,3190,12-01-02-000-0,12 58 | Citrus fruits,3200,12-01-03-000-0,12 59 | Grapefruit & pomelo,3210,12-01-03-000-0,12 60 | Lemons & Limes,3220,12-01-03-003-0,12 61 | Oranges,3230,12-01-03-004-0,12 62 | "Tangerines, mandarins, clementines",3240,12-01-03-005-0,12 63 | Other citrus fruit,3290,12-01-03-000-0,12 64 | Grapes,3300,12-01-00-001-0,12 65 | Berries,3400,12-01-05-000-0,12 66 | Currants,3410,12-01-05-006-0,12 67 | Gooseberries,3420,12-01-05-007-0,12 68 | Kiwi fruit,3430,12-01-05-015-0,12 69 | Raspberries,3440,12-01-05-010-0,12 70 | Strawberries,3450,11-03-12-001-0,11 71 | Blueberries,3460,12-01-05-004-0,12 72 | Other berries,3490,12-01-05-000-0,12 73 | Pome fruits and stone fruits,3500,12-01-01-000-0,12 74 | Apples,3510,12-01-01-002-0,12 75 | Apricots,3520,12-01-01-003-0,12 76 | Cherries & sour cherries,3530,12-01-01-004-0,12 77 | Peaches & nectarines,3540,12-01-01-017-0,12 78 | Pears & quinces,3550,12-01-01-019-0,12 79 | Plums and sloes,3560,12-01-01-012-0,12 80 | Other pome fruits and stone fruits,3590,12-01-01-000-0,12 81 | Nuts,3600,12-01-04-000-0,12 82 | Almonds,3610,12-01-04-001-0,12 83 | Cashew nuts,3620,12-01-04-007-0,12 84 | Chestnuts,3630,12-01-04-005-0,12 85 | Hazelnuts,3640,12-01-04-002-0,12 86 | Pistachios,3650,12-01-04-004-0,12 87 | Walnuts,3660,12-01-04-006-0,12 88 | Other nuts,3690,12-01-04-000-0,12 89 | Other fruit,3900,12-01-00-000-0,12 90 | Oilseed crops,4000,11-06-00-000-0,11 91 | Soya beans,4100,11-06-00-002-0,11 92 | Groundnuts,4200,11-06-00-005-0,11 93 | Temporary oilseed crops,4300,11-06-00-000-0,11 94 | Castor bean,4310,11-06-00-006-0,11 95 | Linseed,4320,11-06-00-007-0,11 96 | Mustard,4330,11-06-00-008-0,11 97 | Niger seed,4340,11-06-00-004-0,11 98 | Rapeseed,4350,11-06-00-003-0,11 99 | Winter rapeseed,4351,11-06-00-003-1,11 100 | Spring rapeseed,4352,11-06-00-003-2,11 101 | Safflower,4360,11-06-00-009-0,11 102 | Sesame,4370,11-06-00-010-0,11 103 | Sunflower,4380,11-06-00-001-0,11 104 | Other temporary oilseed crops,4390,11-06-00-000-0,11 105 | Permanent oilseed crops,4400,12-03-00-000-0,12 106 | Coconuts,4410,12-03-00-002-0,12 107 | Olives,4420,12-03-00-001-0,12 108 | Oil palms,4430,12-03-00-003-0,12 109 | Other oleaginous fruits,4490,12-03-00-000-0,12 110 | Root/tuber crops,5000,11-07-00-000-0,11 111 | Potatoes,5100,11-07-00-001-0,11 112 | Sweet potatoes,5200,11-07-00-002-0,11 113 | Cassava,5300,11-07-00-004-0,11 114 | 
Yams,5400,11-07-00-005-0,11 115 | Other roots and tubers,5900,11-07-00-000-0,11 116 | Beverage and spice crops,6000,10-00-00-000-0,10 117 | Beverage crops,6100,12-02-00-000-0,12 118 | Coffee,6110,12-02-00-001-0,12 119 | Tea,6120,12-02-00-002-0,12 120 | Maté,6130,12-02-00-003-0,12 121 | Cocoa,6140,12-02-00-004-0,12 122 | Other beverage crops,6190,10-00-00-000-0,10 123 | Spice crops,6200,12-02-00-000-0,12 124 | Chilies & peppers,6211,11-03-03-000-0,11 125 | "Anise, badian, fennel",6212,11-00-00-000-0,11 126 | Other temporary spice crops,6219,11-09-00-000-0,11 127 | Pepper,6221,11-09-00-029-0,11 128 | "Nutmeg, mace, cardamoms",6222,12-02-00-007-0,12 129 | Cinnamon,6223,12-02-00-008-0,12 130 | Cloves,6224,12-02-00-009-0,12 131 | Ginger,6225,11-09-00-048-0,11 132 | Vanilla,6226,11-09-00-049-0,10 133 | Other permanent spice crops,6229,12-02-00-000-0,12 134 | Leguminous crops,7000,11-05-00-000-0,11 135 | Beans,7100,11-05-01-001-0,11 136 | Broad beans,7200,11-05-01-003-0,11 137 | Chick peas,7300,11-05-01-004-0,11 138 | Cow peas,7400,11-05-01-005-0,11 139 | Lentils,7500,11-05-00-003-0,11 140 | Lupins,7600,11-05-00-004-0,11 141 | Peas,7700,11-05-01-002-0,11 142 | Pigeon peas,7800,11-05-01-006-0,11 143 | Other Leguminous crops,7900,10-00-00-000-0,10 144 | Other Leguminous crops - Temporary,7910,11-05-00-000-0,11 145 | Other Leguminous crops - Permanent,7920,12-04-00-000-0,12 146 | Sugar crops,8000,10-00-00-000-0,10 147 | Sugar beet,8100,11-07-00-003-1,11 148 | Sugar cane,8200,11-11-01-010-0,11 149 | Sweet sorghum,8300,11-01-07-004-0,11 150 | Other sugar crops,8900,10-00-00-000-0,10 151 | Other crops,9000,10-00-00-000-0,10 152 | Grasses and other fodder crops,9100,11-11-00-000-0,11 153 | Temporary grass crops,9110,11-11-00-001-0,11 154 | Permanent grass crops,9120,13-00-00-000-0,13 155 | Fibre crops,9200,10-00-00-000-0,10 156 | Temporary fibre crops,9210,11-08-00-000-0,11 157 | Cotton,9211,11-08-00-001-0,11 158 | "Jute, kenaf and similar",9212,11-08-00-002-0,11 159 | "Flax, hemp and similar",9213,11-08-02-000-0,11 160 | Other temporary fibre crops,9219,11-08-00-000-0,11 161 | Permanent fibre crops,9220,12-05-00-000-0,12 162 | "Medicinal, aromatic, pesticidal crops",9300,10-00-00-000-0,10 163 | Temporary medicinal etc crops,9310,11-09-00-000-0,11 164 | Permanent medicinal etc crops,9320,12-02-00-000-0,12 165 | Rubber,9400,12-06-00-008-0,12 166 | Flower crops,9500,11-10-00-000-0,11 167 | Temporary flower crops,9510,11-10-00-000-0,11 168 | Permanent flower crops,9520,12-06-00-000-0,12 169 | Tobacco,9600,11-09-00-050-0,11 170 | Other other crops,9900,10-00-00-000-0,10 171 | Other crops - temporary,9910,11-00-00-000-0,11 172 | Other crops - permanent,9920,12-00-00-000-0,12 173 | mixed cropping,9998,10-00-00-000-0,10 174 | Cropland,10,10-00-00-000-0,10 175 | Annual cropland,11,11-00-00-000-0,11 176 | Perennial cropland,12,12-00-00-000-0,12 177 | Grassland *,13,13-00-00-000-0,13 178 | Herbaceous vegetation,20,20-00-00-000-0,20 179 | Shrubland,30,30-00-00-000-0,30 180 | Deciduous forest,40,43-02-00-000-0,40 181 | Evergreen forest,41,43-01-00-000-0,41 182 | Mixed/unknown forest,42,43-00-00-000-0,42 183 | Bare / sparse vegetation,50,50-00-00-000-0,50 184 | Built up / urban,60,60-00-00-000-0,60 185 | Water,70,70-00-00-000-0,70 186 | Snow / ice,80,50-05-00-000-0,80 187 | No cropland (including perennials),98,17-00-00-000-0,98 188 | No cropland,99,16-00-00-000-0,99 189 | Unknown,991,00-00-00-000-0,0 190 | Unknown,9700,00-00-00-000-0,0 -------------------------------------------------------------------------------- 
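The mapping table above links WorldCereal crop-type codes to hierarchical EWOC codes and coarse land-cover classes; the package ships the same table under `src/worldcereal/data/croptype_mappings/`. A minimal sketch of how such a table can be queried (illustrative only, assuming pandas; column names are taken from the CSV header above):

```python
import pandas as pd

# Load the WorldCereal -> EWOC crop-type mapping shown above
mapping = pd.read_csv("notebooks/resources/wc2eurocrops_map.csv")

# Translate a WorldCereal croptype code (e.g. 1110) into its
# hierarchical EWOC code and coarse land-cover class
row = mapping.loc[mapping["croptype"] == 1110].iloc[0]
print(row["name"], row["ewoc_code"], row["landcover"])
# -> Winter wheat 11-01-01-000-1 11
```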
/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling", "hatch-vcs"] 3 | build-backend = "hatchling.build" 4 | 5 | [tool.hatch.build] 6 | exclude = [ 7 | "/dist", 8 | "/notebooks", 9 | "/scripts", 10 | "/bin", 11 | "/tests", 12 | ] 13 | 14 | [tool.hatch.version] 15 | path = "src/worldcereal/_version.py" 16 | pattern = "^__version__ = ['\"](?P[^'\"]+)['\"]$" 17 | 18 | [tool.hatch.metadata] 19 | allow-direct-references = true 20 | 21 | [project] 22 | name = "worldcereal" 23 | authors = [ 24 | { name="Kristof Van Tricht" }, 25 | { name="Jeroen Degerickx" }, 26 | { name="Darius Couchard" }, 27 | { name="Christina Butsko" }, 28 | ] 29 | description = "WorldCereal classification module" 30 | readme = "README.md" 31 | requires-python = ">=3.8" 32 | dynamic = ["version"] 33 | classifiers = [ 34 | "Programming Language :: Python :: 3", 35 | "Operating System :: OS Independent", 36 | ] 37 | dependencies = [ 38 | "boto3==1.35.30", 39 | "cftime", 40 | "geojson", 41 | "geopandas", 42 | "h3==4.1.0", 43 | "h5netcdf>=1.1.0", 44 | "loguru>=0.7.2", 45 | "netcdf4<=1.6.4", 46 | "numpy<2.0.0", 47 | "openeo==0.35.0", 48 | "openeo-gfmap==0.4.6", 49 | "pyarrow", 50 | "pydantic==2.8.0", 51 | "rioxarray>=0.13.0", 52 | "scipy", 53 | "duckdb==1.1.3", 54 | "tqdm", 55 | "xarray>=2022.3.0" 56 | ] 57 | 58 | [project.urls] 59 | "Homepage" = "https://github.com/WorldCereal/worldcereal-classification" 60 | "Bug Tracker" = "https://github.com/WorldCereal/worldcereal-classification/issues" 61 | 62 | [project.optional-dependencies] 63 | dev = [ 64 | "pytest>=7.4.0", 65 | "pytest-depends", 66 | "matplotlib>=3.3.0" 67 | ] 68 | train = [ 69 | "catboost==1.2.5", 70 | "presto-worldcereal==0.1.6", 71 | "scikit-learn==1.5.0", 72 | "torch==2.3.1", 73 | "pystac==1.10.1", 74 | "pystac-client==0.8.3" 75 | ] 76 | notebooks = [ 77 | "ipywidgets==8.1.3", 78 | "leafmap==0.35.1" 79 | ] 80 | 81 | [tool.pytest.ini_options] 82 | testpaths = [ 83 | "tests", 84 | ] 85 | addopts = [ 86 | "--import-mode=prepend", 87 | ] 88 | 89 | [tool.isort] 90 | profile = "black" 91 | 92 | [tool.ruff] 93 | # line-length = 88 94 | 95 | [tool.ruff.lint] 96 | select = ["E", "F"] 97 | ignore = [ 98 | "E501", # Ignore "line-too-long" issues, let black handle that. 99 | ] -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = tests 3 | addopts = --strict-markers 4 | markers = 5 | slow: mark for slow tests (skip these tests by adding option '-m "not slow"') 6 | -------------------------------------------------------------------------------- /scripts/extractions/extract.py: -------------------------------------------------------------------------------- 1 | """Main script to perform extractions. 
Each collection has its specificities and 2 | own functions, but the setup and main thread execution is done here.""" 3 | 4 | import argparse 5 | from pathlib import Path 6 | from typing import Dict, Optional, Union 7 | 8 | from openeo_gfmap import Backend 9 | 10 | from worldcereal.extract.common import run_extractions 11 | from worldcereal.stac.constants import ExtractionCollection 12 | 13 | 14 | def main( 15 | collection: ExtractionCollection, 16 | output_folder: Path, 17 | samples_df_path: Path, 18 | ref_id: str, 19 | max_locations_per_job: int = 500, 20 | memory: Optional[str] = None, 21 | python_memory: Optional[str] = None, 22 | max_executors: Optional[int] = None, 23 | parallel_jobs: int = 2, 24 | restart_failed: bool = False, 25 | extract_value: int = 1, 26 | backend=Backend.CDSE, 27 | write_stac_api: bool = False, 28 | image_name: Optional[str] = None, 29 | organization_id: Optional[int] = None, 30 | ) -> None: 31 | """Main function responsible for launching point and patch extractions. 32 | 33 | Parameters 34 | ---------- 35 | collection : ExtractionCollection 36 | The collection to extract. Most popular: PATCH_WORLDCEREAL, POINT_WORLDCEREAL 37 | output_folder : Path 38 | The folder in which to store the extracted data 39 | samples_df_path : Path 40 | Path to the input dataframe containing the geometries 41 | for which extractions need to be done 42 | ref_id : str 43 | Official ref_id of the source dataset 44 | max_locations_per_job : int, optional 45 | The maximum number of locations to extract per job, by default 500 46 | memory : str, optional 47 | Memory to allocate for the executor. 48 | If not specified, the default value is used, depending on type of collection. 49 | python_memory : str, optional 50 | Memory to allocate for the python processes as well as OrfeoToolbox in the executors. 51 | If not specified, the default value is used, depending on type of collection. 52 | max_executors : int, optional 53 | Number of executors to run. 54 | If not specified, the default value is used, depending on type of collection.
55 | parallel_jobs : int, optional 56 | The maximum number of parallel jobs to run at the same time, by default 2 57 | restart_failed : bool, optional 58 | Restart the jobs that previously failed, by default False 59 | extract_value : int, optional 60 | All samples with an "extract" value equal to or larger than this one will be extracted, by default 1 61 | backend : openeo_gfmap.Backend, optional 62 | Cloud backend on which to run the extractions, by default Backend.CDSE 63 | write_stac_api : bool, optional 64 | Save metadata of extractions to STAC API (requires authentication), by default False 65 | image_name : str, optional 66 | Specific openEO image name to use for the jobs, by default None 67 | organization_id : int, optional 68 | ID of the organization to use for the job, by default None in which case 69 | the active organization for the user is used 70 | 71 | Returns 72 | ------- 73 | None 74 | """ 75 | 76 | # Compile custom job options 77 | job_options: Optional[Dict[str, Union[str, int]]] = { 78 | key: value 79 | for key, value in { 80 | "executor-memory": memory, 81 | "python-memory": python_memory, 82 | "max-executors": max_executors, 83 | "image-name": image_name, 84 | "etl_organization_id": organization_id, 85 | }.items() 86 | if value is not None 87 | } or None 88 | 89 | # We need to be sure that the 90 | # output_folder points to the correct ref_id 91 | if output_folder.name != ref_id: 92 | raise ValueError( 93 | f"`output_folder` should point to ref_id `{ref_id}`, instead got: {output_folder}" 94 | ) 95 | 96 | # Fire up extractions 97 | run_extractions( 98 | collection, 99 | output_folder, 100 | samples_df_path, 101 | ref_id, 102 | max_locations_per_job=max_locations_per_job, 103 | job_options=job_options, 104 | parallel_jobs=parallel_jobs, 105 | restart_failed=restart_failed, 106 | extract_value=extract_value, 107 | backend=backend, 108 | write_stac_api=write_stac_api, 109 | ) 110 | 111 | return 112 | 113 | 114 | if __name__ == "__main__": 115 | parser = argparse.ArgumentParser(description="Extract data from a collection") 116 | parser.add_argument( 117 | "collection", 118 | type=ExtractionCollection, 119 | choices=list(ExtractionCollection), 120 | help="The collection to extract", 121 | ) 122 | parser.add_argument( 123 | "output_folder", type=Path, help="The folder in which to store the extracted data" 124 | ) 125 | parser.add_argument( 126 | "samples_df_path", 127 | type=Path, 128 | help="Path to the samples dataframe with the data to extract", 129 | ) 130 | parser.add_argument( 131 | "--ref_id", 132 | type=str, 133 | required=True, 134 | help="Official `ref_id` of the source dataset", 135 | ) 136 | parser.add_argument( 137 | "--max_locations", 138 | type=int, 139 | default=500, 140 | help="The maximum number of locations to extract per job", 141 | ) 142 | parser.add_argument( 143 | "--memory", 144 | type=str, 145 | default="1800m", 146 | help="Memory to allocate for the executor.", 147 | ) 148 | parser.add_argument( 149 | "--python_memory", 150 | type=str, 151 | default="1900m", 152 | help="Memory to allocate for the python processes as well as OrfeoToolbox in the executors.", 153 | ) 154 | parser.add_argument( 155 | "--max_executors", type=int, default=22, help="Number of executors to run."
156 | ) 157 | parser.add_argument( 158 | "--parallel_jobs", 159 | type=int, 160 | default=2, 161 | help="The maximum number of parallel jobs to run at the same time.", 162 | ) 163 | parser.add_argument( 164 | "--restart_failed", 165 | action="store_true", 166 | help="Restart the jobs that previously failed.", 167 | ) 168 | parser.add_argument( 169 | "--extract_value", 170 | type=int, 171 | default=1, 172 | help="The value of the `extract` flag to use in the dataframe.", 173 | ) 174 | parser.add_argument( 175 | "--write_stac_api", 176 | action="store_true", 177 | default=False, 178 | help="If set, write S1 and S2 patch extraction results to the STAC API.", 179 | ) 180 | parser.add_argument( 181 | "--image_name", 182 | type=str, 183 | default=None, 184 | help="Specific openEO image name to use for the jobs.", 185 | ) 186 | parser.add_argument( 187 | "--organization_id", 188 | type=int, 189 | default=None, 190 | help="ID of the organization to use for the job.", 191 | ) 192 | 193 | args = parser.parse_args() 194 | 195 | main( 196 | collection=args.collection, 197 | output_folder=args.output_folder, 198 | samples_df_path=args.samples_df_path, 199 | ref_id=args.ref_id, 200 | max_locations_per_job=args.max_locations, 201 | memory=args.memory, 202 | python_memory=args.python_memory, 203 | max_executors=args.max_executors, 204 | parallel_jobs=args.parallel_jobs, 205 | restart_failed=args.restart_failed, 206 | extract_value=args.extract_value, 207 | backend=Backend.CDSE, 208 | write_stac_api=args.write_stac_api, 209 | image_name=args.image_name, 210 | organization_id=args.organization_id, 211 | ) 212 | -------------------------------------------------------------------------------- /scripts/inference/collect_inputs.py: -------------------------------------------------------------------------------- 1 | from openeo_gfmap import BoundingBoxExtent, TemporalContext 2 | 3 | from worldcereal.job import collect_inputs 4 | 5 | 6 | def main(): 7 | 8 | # Set the spatial extent 9 | # bbox_utm = (664000.0, 5611120.0, 665000.0, 5612120.0) 10 | bbox_utm = (664000, 5611134, 684000, 5631134) # Large test 11 | epsg = 32631 12 | spatial_extent = BoundingBoxExtent(*bbox_utm, epsg) 13 | 14 | # Set temporal range 15 | temporal_extent = TemporalContext( 16 | start_date="2020-11-01", 17 | end_date="2021-10-31", 18 | ) 19 | 20 | outfile = "local_presto_inputs_large.nc" 21 | collect_inputs(spatial_extent, temporal_extent, output_path=outfile) 22 | 23 | 24 | if __name__ == "__main__": 25 | main() 26 | -------------------------------------------------------------------------------- /scripts/inference/cropland_mapping.py: -------------------------------------------------------------------------------- 1 | """Cropland mapping inference script, demonstrating the use of the GFMAP, Presto and WorldCereal classifiers in a first inference pipeline.""" 2 | 3 | import argparse 4 | from pathlib import Path 5 | 6 | from loguru import logger 7 | from openeo_gfmap import BoundingBoxExtent, TemporalContext 8 | from openeo_gfmap.backend import Backend, BackendContext 9 | 10 | from worldcereal.job import WorldCerealProductType, generate_map 11 | from worldcereal.parameters import PostprocessParameters 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser( 15 | prog="WC - Crop Mapping Inference", 16 | description="Crop Mapping inference using GFMAP, Presto and WorldCereal classifiers", 17 | ) 18 | 19 | parser.add_argument("minx", type=float, help="Minimum X coordinate (west)") 20 | parser.add_argument("miny", type=float,
help="Minimum Y coordinate (south)") 21 | parser.add_argument("maxx", type=float, help="Maximum X coordinate (east)") 22 | parser.add_argument("maxy", type=float, help="Maximum Y coordinate (north)") 23 | parser.add_argument( 24 | "start_date", type=str, help="Starting date for data extraction." 25 | ) 26 | parser.add_argument("end_date", type=str, help="Ending date for data extraction.") 27 | parser.add_argument( 28 | "product", 29 | type=str, 30 | help="Product to generate. One of ['cropland', 'croptype']", 31 | ) 32 | parser.add_argument( 33 | "output_path", 34 | type=Path, 35 | help="Path to folder where to save the resulting GeoTiff.", 36 | ) 37 | parser.add_argument( 38 | "--epsg", 39 | type=int, 40 | default=4326, 41 | help="EPSG code of the input `minx`, `miny`, `maxx`, `maxy` parameters.", 42 | ) 43 | parser.add_argument( 44 | "--postprocess", 45 | action="store_true", 46 | help="Run postprocessing on the croptype and/or the cropland product after inference.", 47 | ) 48 | parser.add_argument( 49 | "--class-probabilities", 50 | action="store_true", 51 | help="Output per-class probabilities in the resulting product", 52 | ) 53 | 54 | args = parser.parse_args() 55 | 56 | minx = args.minx 57 | miny = args.miny 58 | maxx = args.maxx 59 | maxy = args.maxy 60 | epsg = args.epsg 61 | 62 | start_date = args.start_date 63 | end_date = args.end_date 64 | 65 | product = args.product 66 | 67 | # minx, miny, maxx, maxy = (664000, 5611134, 665000, 5612134) # Small test 68 | # minx, miny, maxx, maxy = (664000, 5611134, 684000, 5631134) # Large test 69 | # epsg = 32631 70 | # start_date = "2020-11-01" 71 | # end_date = "2021-10-31" 72 | # product = "croptype" 73 | 74 | spatial_extent = BoundingBoxExtent(minx, miny, maxx, maxy, epsg) 75 | temporal_extent = TemporalContext(start_date, end_date) 76 | 77 | backend_context = BackendContext(Backend.CDSE) 78 | 79 | job_results = generate_map( 80 | spatial_extent, 81 | temporal_extent, 82 | args.output_path, 83 | product_type=WorldCerealProductType(product), 84 | postprocess_parameters=PostprocessParameters( 85 | enable=args.postprocess, keep_class_probs=args.class_probabilities 86 | ), 87 | out_format="GTiff", 88 | backend_context=backend_context, 89 | ) 90 | logger.success(f"Job finished:\n\t{job_results}") 91 | -------------------------------------------------------------------------------- /scripts/inference/cropland_mapping_local.py: -------------------------------------------------------------------------------- 1 | """Perform cropland mapping inference using a local execution of presto. 2 | 3 | Make sure you test this script on the Python version 3.9+, and have worldcereal 4 | dependencies installed with the presto wheel file installed with it's dependencies. 
5 | """ 6 | 7 | from pathlib import Path 8 | 9 | import requests 10 | import xarray as xr 11 | from openeo_gfmap.features.feature_extractor import ( 12 | EPSG_HARMONIZED_NAME, 13 | apply_feature_extractor_local, 14 | ) 15 | from openeo_gfmap.inference.model_inference import apply_model_inference_local 16 | 17 | from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor 18 | from worldcereal.openeo.inference import CropClassifier 19 | 20 | TEST_FILE_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/presto/localtestdata/local_presto_inputs.nc" 21 | TEST_FILE_PATH = Path.cwd() / "presto_test_inputs.nc" 22 | PRESTO_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt" 23 | CATBOOST_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx" 24 | 25 | if __name__ == "__main__": 26 | if not TEST_FILE_PATH.exists(): 27 | print("Downloading test input data...") 28 | # Download the test input data 29 | with requests.get(TEST_FILE_URL, stream=True, timeout=180) as response: 30 | response.raise_for_status() 31 | with open(TEST_FILE_PATH, "wb") as f: 32 | for chunk in response.iter_content(chunk_size=8192): 33 | f.write(chunk) 34 | 35 | print("Loading array in-memory...") 36 | arr = ( 37 | xr.open_dataset(TEST_FILE_PATH) 38 | .to_array(dim="bands") 39 | .drop_sel(bands="crs") 40 | .astype("uint16") 41 | ) 42 | 43 | print("Running presto UDF locally") 44 | features = apply_feature_extractor_local( 45 | PrestoFeatureExtractor, 46 | arr, 47 | parameters={ 48 | EPSG_HARMONIZED_NAME: 32631, 49 | "ignore_dependencies": True, 50 | "presto_model_url": PRESTO_URL, 51 | }, 52 | ) 53 | 54 | features.to_netcdf(Path.cwd() / "presto_test_features_cropland.nc") 55 | 56 | print("Running classification inference UDF locally") 57 | 58 | classification = apply_model_inference_local( 59 | CropClassifier, 60 | features, 61 | parameters={ 62 | EPSG_HARMONIZED_NAME: 32631, 63 | "ignore_dependencies": True, 64 | "classifier_url": CATBOOST_URL, 65 | }, 66 | ) 67 | 68 | classification.to_netcdf(Path.cwd() / "test_classification_cropland.nc") 69 | -------------------------------------------------------------------------------- /scripts/inference/cropland_mapping_udp.py: -------------------------------------------------------------------------------- 1 | """This short script demonstrates how to run WorldCereal cropland extent inference through 2 | an OpenEO User-Defined Process (UDP) on CDSE. 3 | 4 | The WorldCereal default cropland model is used and default post-processing is applied 5 | (using smooth_probabilities method). 6 | 7 | The user needs to manually download the resulting map through the OpenEO Web UI: 8 | https://openeo.dataspace.copernicus.eu/ 9 | Or use the openeo API to fetch the resulting map. 10 | """ 11 | 12 | import openeo 13 | 14 | ###### USER INPUTS ###### 15 | # Define the spatial and temporal extent 16 | spatial_extent = { 17 | "west": 3.809252, 18 | "south": 51.232365, 19 | "east": 3.833542, 20 | "north": 51.245477, 21 | "crs": "EPSG:4326", 22 | "srs": "EPSG:4326", 23 | } 24 | 25 | # Temporal extent needs to consist of exactly twelve months, 26 | # always starting first day of the month and ending last day of the month. 
27 | temporal_extent = ["2020-01-01", "2020-12-31"] 28 | 29 | # This argument is optional and determines the projection of the output products 30 | target_projection = "EPSG:4326" 31 | 32 | # Here you can set custom job options, such as driver-memory, executor-memory, etc. 33 | # In case you would like to use the default options, set job_options to None. 34 | 35 | # job_options = { 36 | # "driver-memory": "4g", 37 | # "executor-memory": "2g", 38 | # "python-memory": "3g", 39 | # } 40 | job_options = None 41 | 42 | ###### END OF USER INPUTS ###### 43 | 44 | # Connect to openeo backend 45 | c = openeo.connect("openeo.dataspace.copernicus.eu").authenticate_oidc() 46 | 47 | # Define the operations 48 | classes = c.datacube_from_process( 49 | process_id="worldcereal_crop_extent", 50 | namespace="https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/refs/tags/worldcereal_crop_extent_v1.0.2/src/worldcereal/udp/worldcereal_crop_extent.json", 51 | temporal_extent=temporal_extent, 52 | spatial_extent=spatial_extent, 53 | projection=target_projection, 54 | ) 55 | 56 | # Run the job 57 | job = classes.execute_batch( 58 | title="WorldCereal Crop Extent UDP test", 59 | out_format="GTiff", 60 | job_options=job_options, 61 | ) 62 | 63 | ###### EXPLANATION OF RESULTS ###### 64 | # Two GeoTiff images will be generated: 65 | # 1. "worldcereal-cropland-extent.tif": The original results of the WorldCereal cropland model 66 | # 2. "worldcereal-cropland-extent-postprocessed.tif": The results after applying the default post-processing, 67 | # which includes a smoothing of the probabilities and relabelling of pixels according to the smoothed probabilities. 68 | 69 | # Each tif file contains 4 bands, all at 10m resolution: 70 | # - Band 1: Predicted cropland labels (0: non-cropland, 1: cropland) 71 | # - Band 2: Probability of the winning class (50 - 100) 72 | # - Band 3: Probability of the non-cropland class (0 - 100) 73 | # - Band 4: Probability of the cropland class (0 - 100) 74 | -------------------------------------------------------------------------------- /scripts/inference/croptype_mapping_udp.py: -------------------------------------------------------------------------------- 1 | """This short script demonstrates how to run WorldCereal crop type inference through 2 | an OpenEO User-Defined Process (UDP) on CDSE. 3 | 4 | No default model is currently available, meaning that a user first needs to 5 | train a custom model before using this script. Training a custom model can be done 6 | through the WorldCereal Custom Crop Type Training Notebook available here: 7 | https://github.com/WorldCereal/worldcereal-classification/blob/main/notebooks/worldcereal_custom_croptype.ipynb 8 | 9 | A user is required to specify a spatial and temporal extent, as well as the link to 10 | the crop type model to be used. The model should be hosted on a publicly accessible server. 11 | 12 | The user needs to manually download the resulting map through the OpenEO Web UI: 13 | https://openeo.dataspace.copernicus.eu/ 14 | Or use the openeo API to fetch the resulting maps. 
15 | """ 16 | 17 | import openeo 18 | 19 | ###### USER INPUTS ###### 20 | # Define the spatial and temporal extent 21 | spatial_extent = { 22 | "west": 622694.5968575787, 23 | "south": 5672232.857114074, 24 | "east": 623079.000934101, 25 | "north": 5672519.995940826, 26 | "crs": "EPSG:32631", 27 | "srs": "EPSG:32631", 28 | } 29 | 30 | # Temporal extent needs to consist of exactly twelve months, 31 | # always starting first day of the month and ending last day of the month. 32 | # Ideally, the season of interest should be nicely centered within the selected twelve months period. 33 | temporal_extent = ["2018-11-01", "2019-10-30"] 34 | 35 | # Provide the link to your custom model 36 | # The model should be in ONNX format and publicly accessible. 37 | model_url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/downstream/tests/be_multiclass-test_custommodel.onnx" 38 | 39 | # Sentinel-1 orbit state, either "ASCENDING" or "DESCENDING". 40 | # In the future, this setting will be automatically determined by the system based on the regional extent. 41 | # For now, the user needs to explicitly specify the orbit state based on the most dominant orbit state in the region of interest. 42 | # The user can be guided to the global map available here: https://docs.sentinel-hub.com/api/latest/data/sentinel-1-grd/ 43 | # to make this decision. 44 | orbit_state = "ASCENDING" # system's default is "DESCENDING" 45 | 46 | ## OPTIONAL PARAMETERS 47 | # Post-processing method, available choices are "majority_vote" and "smooth_probabilities". 48 | # System's default is "smooth_probabilities", but we prefer "majority_vote" for this example. 49 | postprocess_method = "majority_vote" 50 | # If the postprocess_method is set to "majority_vote", the user can specify the kernel size. 51 | # The system's default is 5 52 | postprocess_kernel_size = 5 53 | 54 | ###### END OF USER INPUTS ###### 55 | 56 | # Connect to openeo backend 57 | c = openeo.connect("openeo.dataspace.copernicus.eu").authenticate_oidc() 58 | 59 | # Define the operations 60 | cube = c.datacube_from_process( 61 | process_id="worldcereal_crop_type", 62 | namespace="https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/refs/heads/327-lut-crop-type/src/worldcereal/udp/worldcereal_crop_type.json", 63 | spatial_extent=spatial_extent, 64 | temporal_extent=temporal_extent, 65 | model_url=model_url, 66 | orbit_state=orbit_state, 67 | postprocess_method=postprocess_method, 68 | postprocess_kernel_size=postprocess_kernel_size, 69 | ) 70 | 71 | # Run the job 72 | job = cube.execute_batch( 73 | title="Test worldcereal_crop_type UDP", 74 | ) 75 | 76 | ###### EXPLANATION OF RESULTS ###### 77 | # Four GeoTiff images will be generated. 78 | 79 | # Each raster contains at least three bands: 80 | # - Band 1: "classification": The classification label of the pixel. 81 | # - Band 2: "confidence": The class-specific probablity of the winning class. 82 | # - Band 3 and beyond: "probability_xxx": Class-specific probablities. The "xxx" indicates the associated class. 83 | 84 | # The following files will be generated: 85 | # 1. "cropland-raw.tif": The original temporary crops mask, generated by the default global WorldCereal temporary crops model. 86 | # This model differentiates temporary crops from all other land cover types. 87 | # (0: other land cover, 1: temporary crops, 255: no data) 88 | 89 | # 2. "cropland-postprocessed.tif": The final temporary crops mask after spatial cleaning of the raw result. 
90 | # (0: other land cover, 1: temporary crops, 255: no data) 91 | 92 | # 3. "croptype-raw.tif": The original crop type map, generated by the user's custom crop type model. 93 | # The crop type model is only applied to the temporary crops pixels. 94 | # (labels according to the user's custom model, 254: no temporary crops, 65535: no data) 95 | 96 | # 4. "croptype-postprocessed.tif": The final crop type map after spatial cleaning of the raw result. 97 | # (labels according to the user's custom model, 254: no temporary crops, 65535: no data) 98 | -------------------------------------------------------------------------------- /scripts/misc/legend.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example script showing how to upload, download and delete the WorldCereal crop type legend from Artifactory. 3 | """ 4 | 5 | from pathlib import Path 6 | 7 | from worldcereal.utils.legend import ( 8 | delete_legend_file, 9 | download_legend, 10 | get_legend, 11 | upload_legend, 12 | ) 13 | 14 | if __name__ == "__main__": 15 | 16 | # Example usage 17 | srcpath = Path("./WorldCereal_LC_CT_legend_20241231.csv") 18 | date = "20241231" 19 | 20 | # Upload the legend to Artifactory 21 | link = upload_legend(srcpath, date) 22 | 23 | # Get the latest legend from Artifactory (as pandas DataFrame) 24 | legend = get_legend() 25 | 26 | legend_irr = get_legend(topic="irrigation") 27 | 28 | # Download the latest legend from Artifactory 29 | legend_path = download_legend(Path(".")) 30 | 31 | irr_legend_path = download_legend(Path("."), topic="irrigation") 32 | 33 | # Delete the uploaded legend from Artifactory 34 | delete_legend_file(link) 35 | -------------------------------------------------------------------------------- /scripts/spark/compute_presto_features.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from pathlib import Path 3 | from typing import Optional, Union 4 | 5 | import pandas as pd 6 | from loguru import logger 7 | from presto.presto import Presto 8 | from presto.utils import process_parquet 9 | 10 | from worldcereal.train.data import WorldCerealTrainingDataset, get_training_df 11 | 12 | 13 | def embeddings_from_parquet_file( 14 | parquet_file: Union[str, Path], 15 | pretrained_model: Presto, 16 | sample_repeats: int = 1, 17 | mask_ratio: float = 0.0, 18 | valid_date_as_token: bool = False, 19 | exclude_meteo: bool = False, 20 | ) -> Optional[pd.DataFrame]: 21 | """Method to compute Presto embeddings from a parquet file of preprocessed inputs 22 | 23 | Parameters 24 | ---------- 25 | parquet_file : Union[str, Path] 26 | parquet file to read data from 27 | pretrained_model : Presto 28 | Presto model to use for computing embeddings 29 | sample_repeats : int, optional 30 | number of augmented sample repeats, by default 1 31 | mask_ratio : float, optional 32 | mask ratio to apply, by default 0.0 33 | valid_date_as_token : bool, optional 34 | feed valid date as a token to Presto, by default False 35 | exclude_meteo : bool, optional 36 | if True, meteo will be masked during embedding computation, by default False 37 | 38 | Returns 39 | ------- 40 | Optional[pd.DataFrame] 41 | DataFrame containing Presto embeddings and other attributes, or None if no valid samples remain 42 | """ 43 | 44 | # Load parquet file 45 | logger.info(f"Processing {parquet_file}") 46 | df = pd.read_parquet(parquet_file) 47 | 48 | # Original file is partitioned by ref_id so we have to add 49 | # ref_id as a column manually 50 | df["ref_id"] = Path(parquet_file).parent.stem.split("=")[1] 51 |
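# e.g. a file stored under ".../ref_id=2021_XX_MYDATASET/part-0.parquet"
# (hypothetical partition name) yields ref_id "2021_XX_MYDATASET"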
52 | # Put meteo to nodata if needed 53 | if exclude_meteo: 54 | logger.warning("Excluding meteo data ...") 55 | df.loc[:, df.columns.str.contains("AGERA5")] = 65535 56 | 57 | # Convert to wide format 58 | logger.info("From long to wide format ...") 59 | df = process_parquet(df).reset_index() 60 | 61 | # Check if samples remain 62 | if df.empty: 63 | logger.warning("Empty dataframe: returning None") 64 | return None 65 | 66 | # Create dataset and dataloader 67 | logger.info("Making data loader ...") 68 | ds = WorldCerealTrainingDataset( 69 | df, 70 | task_type="cropland", 71 | augment=True, 72 | mask_ratio=mask_ratio, 73 | repeats=sample_repeats, 74 | ) 75 | 76 | if len(ds) == 0: 77 | logger.warning("No valid samples in dataset: returning None") 78 | return None 79 | 80 | return get_training_df( 81 | ds, 82 | pretrained_model, 83 | batch_size=2048, 84 | valid_date_as_token=valid_date_as_token, 85 | num_workers=4, 86 | ) 87 | 88 | 89 | def main( 90 | infile, 91 | outfile, 92 | presto_model, 93 | sc=None, 94 | debug=False, 95 | sample_repeats: int = 1, 96 | mask_ratio: float = 0.0, 97 | valid_date_as_token: bool = False, 98 | exclude_meteo: bool = False, 99 | ): 100 | 101 | logger.info( 102 | f"Starting embedding computation with augmentation (sample_repeats: {sample_repeats}, mask_ratio: {mask_ratio})" 103 | ) 104 | logger.info(f"Valid date as token: {valid_date_as_token}") 105 | 106 | # List parquet files 107 | parquet_files = glob.glob(str(infile) + "/**/*.parquet", recursive=True) 108 | 109 | if debug: 110 | parquet_files = parquet_files[:2] 111 | 112 | # Load model 113 | logger.info("Loading model ...") 114 | pretrained_model = Presto.load_pretrained( 115 | model_path=presto_model, strict=False, valid_month_as_token=valid_date_as_token 116 | ) 117 | 118 | if sc is not None: 119 | logger.info(f"Parallelizing {len(parquet_files)} files ...") 120 | dfs = ( 121 | sc.parallelize(parquet_files, len(parquet_files)) 122 | .map( 123 | lambda x: embeddings_from_parquet_file( 124 | x, 125 | pretrained_model, 126 | sample_repeats, 127 | mask_ratio, 128 | valid_date_as_token, 129 | exclude_meteo, 130 | ) 131 | ) 132 | .filter(lambda x: x is not None) 133 | .collect() 134 | ) 135 | else: 136 | dfs = [] 137 | for parquet_file in parquet_files: 138 | dfs.append( 139 | embeddings_from_parquet_file( 140 | parquet_file, 141 | pretrained_model, 142 | sample_repeats, 143 | mask_ratio, 144 | valid_date_as_token, 145 | exclude_meteo, 146 | ) 147 | ) 148 | 149 | if isinstance(dfs, list): 150 | logger.info(f"Done processing: concatenating {len(dfs)} results") 151 | dfs = pd.concat(dfs, ignore_index=True) 152 | 153 | logger.info(f"Final dataframe shape: {dfs.shape}") 154 | 155 | # Write to parquet 156 | logger.info(f"Writing to parquet: {outfile}") 157 | dfs.to_parquet(outfile, index=False) 158 | 159 | logger.success("All done!") 160 | 161 | 162 | if __name__ == "__main__": 163 | # Output feature basedir 164 | baseoutdir = Path( 165 | "/vitodata/worldcereal/features/preprocessedinputs-monthly-nointerp" 166 | ) 167 | 168 | spark = True 169 | localspark = False 170 | debug = False 171 | exclude_meteo = False 172 | sample_repeats = 1 173 | valid_date_as_token = False 174 | presto_dir = Path("/vitodata/worldcereal/presto/finetuning") 175 | presto_model = ( 176 | presto_dir 177 | / "presto-ss-wc-ft-ct_cropland_CROPLAND2_30D_random_time-token=none_balance=True_augment=True.pt" 178 | ) 179 | identifier = "" 180 | 181 | if spark: 182 | from worldcereal.utils.spark import get_spark_context 183 | 184 | 
logger.info("Setting up spark ...") 185 | sc = get_spark_context(localspark=localspark) 186 | else: 187 | sc = None 188 | 189 | infile = baseoutdir / "worldcereal_training_data.parquet" 190 | 191 | if debug: 192 | outfile = baseoutdir / ( 193 | f"training_df_{presto_model.stem}_presto-worldcereal{identifier}_DEBUG.parquet" 194 | ) 195 | else: 196 | outfile = baseoutdir / ( 197 | f"training_df_{presto_model.stem}_presto-worldcereal{identifier}.parquet" 198 | ) 199 | 200 | if not infile.exists(): 201 | raise FileNotFoundError(infile) 202 | 203 | main( 204 | infile, 205 | outfile, 206 | presto_model, 207 | sc=sc, 208 | debug=debug, 209 | sample_repeats=sample_repeats, 210 | valid_date_as_token=valid_date_as_token, 211 | exclude_meteo=exclude_meteo, 212 | ) 213 | -------------------------------------------------------------------------------- /scripts/spark/compute_presto_features.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # shellcheck disable=SC2140 3 | 4 | export SPARK_HOME=/opt/spark3_2_0/ 5 | export PATH="$SPARK_HOME/bin:$PATH" 6 | export PYTHONPATH=wczip 7 | 8 | cd src || exit 9 | zip -r ../dist/worldcereal.zip worldcereal 10 | cd .. 11 | 12 | EX_JAVAMEM='2g' 13 | EX_PYTHONMEM='14g' 14 | DR_JAVAMEM='8g' 15 | DR_PYTHONMEM='16g' 16 | 17 | PYSPARK_PYTHON=./ewocenv/bin/python \ 18 | ${SPARK_HOME}/bin/spark-submit \ 19 | --conf spark.yarn.appMasterEnv.PYSPARK_PYTHON="./ewocenv/bin/python" \ 20 | --conf spark.yarn.appMasterEnv.PYSPARK_DRIVER_PYTHON="./ewocenv/bin/python" \ 21 | --conf spark.executorEnv.LD_LIBRARY_PATH="./ewocenv/lib" \ 22 | --conf spark.yarn.appMasterEnv.LD_LIBRARY_PATH="./ewocenv/lib" \ 23 | --conf spark.executorEnv.PYSPARK_PYTHON="./ewocenv/bin/python" \ 24 | --conf spark.yarn.appMasterEnv.PYTHONPATH=$PYTHONPATH \ 25 | --conf spark.executorEnv.PYTHONPATH=$PYTHONPATH \ 26 | --executor-memory ${EX_JAVAMEM} --driver-memory ${DR_JAVAMEM} \ 27 | --conf spark.yarn.appMasterEnv.PYTHON_EGG_CACHE=./ \ 28 | --conf spark.executorEnv.GDAL_CACHEMAX=512 \ 29 | --conf spark.speculation=true \ 30 | --conf spark.sql.broadcastTimeout=36000 \ 31 | --conf spark.shuffle.registration.timeout=36000 \ 32 | --conf spark.sql.execution.arrow.pyspark.enabled=true \ 33 | --conf spark.shuffle.memoryFraction=0.6 \ 34 | --conf spark.hadoop.mapreduce.output.fileoutputformat.compress=true \ 35 | --conf spark.hadoop.mapreduce.output.fileoutputformat.compress.codec=org.apache.hadoop.io.compress.GzipCodec \ 36 | --conf spark.driver.memoryOverhead=${DR_PYTHONMEM} --conf spark.executor.memoryOverhead=${EX_PYTHONMEM} \ 37 | --conf spark.memory.fraction=0.2 \ 38 | --conf spark.executor.cores=4 \ 39 | --conf spark.task.cpus=4 \ 40 | --conf spark.driver.cores=4 \ 41 | --conf spark.dynamicAllocation.maxExecutors=1000 \ 42 | --conf spark.shuffle.service.enabled=true --conf spark.dynamicAllocation.enabled=true \ 43 | --conf spark.driver.maxResultSize=0 \ 44 | --master yarn --deploy-mode cluster --queue default \ 45 | --py-files "/vitodata/worldcereal/software/wheels/presto_worldcereal-0.1.6-py3-none-any.whl" \ 46 | --archives "dist/worldcereal.zip#wczip","hdfs:///tapdata/worldcereal/worldcereal.tar.gz#ewocenv" \ 47 | --conf spark.app.name="worldcereal-presto_features" \ 48 | scripts/spark/compute_presto_features.py \ 49 | -------------------------------------------------------------------------------- /scripts/stac/build_paths.py: -------------------------------------------------------------------------------- 1 | """From a folder of extracted patches, 
generates a list of paths to the patches 2 | to be later parsed by the Spark cluster. 3 | """ 4 | 5 | import argparse 6 | import os 7 | import pickle 8 | import re 9 | from concurrent.futures import ThreadPoolExecutor 10 | from pathlib import Path 11 | 12 | from worldcereal.stac.constants import COLLECTION_REGEXES, ExtractionCollection 13 | 14 | 15 | def iglob_files(path, notification_folders, pattern=None): 16 | """ 17 | Generator that recursively yields all files under `path` whose names match the regex `pattern` 18 | """ 19 | root_dir, folders, files = next(os.walk(path)) 20 | 21 | for f in files: 22 | 23 | if (pattern is None) or len(re.findall(pattern, f)): 24 | file_path = os.path.join(root_dir, f) 25 | yield file_path 26 | 27 | for d in folders: 28 | # If the folder is in the notification folders list, print a message 29 | if os.path.join(path, d) in notification_folders: 30 | print(f"Searching in {d}") 31 | new_path = os.path.join(root_dir, d) 32 | yield from iglob_files(new_path, notification_folders, pattern) 33 | 34 | 35 | def glob_files(path, notification_folders, pattern=None, threads=50): 36 | """ 37 | Return all files within `path` and its subdirectories whose names match the regex `pattern` 38 | """ 39 | with ThreadPoolExecutor(max_workers=threads) as ex: 40 | files = list( 41 | ex.map(lambda x: x, iglob_files(path, notification_folders, pattern)) 42 | ) 43 | 44 | return files 45 | 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser( 49 | description="Generates a list of paths of patches to parse." 50 | ) 51 | 52 | parser.add_argument( 53 | "collection", 54 | type=ExtractionCollection, 55 | choices=list(ExtractionCollection), 56 | help="The collection to extract", 57 | ) 58 | parser.add_argument( 59 | "input_folder", 60 | type=Path, 61 | help="The path to the folder containing the extracted patches.", 62 | ) 63 | parser.add_argument( 64 | "output_path", 65 | type=Path, 66 | help="The path to the pickle file where the paths will be saved.", 67 | ) 68 | 69 | args = parser.parse_args() 70 | 71 | # Pattern to filter the files belonging to the selected collection 72 | pattern = COLLECTION_REGEXES[args.collection] 73 | 74 | root_subfolders = [ 75 | str(path) for path in args.input_folder.iterdir() if path.is_dir() 76 | ] 77 | 78 | files = glob_files( 79 | path=args.input_folder, 80 | notification_folders=root_subfolders, 81 | pattern=pattern, 82 | ) 83 | 84 | with open(args.output_path, "wb") as f: 85 | pickle.dump(files, f) 86 | -------------------------------------------------------------------------------- /scripts/stac/split_catalogue.py: -------------------------------------------------------------------------------- 1 | """Split catalogue by the local UTM projection of the products so that it can 2 | be used by OpenEO processes.""" 3 | 4 | import argparse 5 | import logging 6 | import pickle 7 | from pathlib import Path 8 | 9 | from openeo_gfmap.utils.split_stac import split_collection_by_epsg 10 | from tqdm import tqdm 11 | 12 | # Logger used for the pipeline 13 | builder_log = logging.getLogger("catalogue_splitter") 14 | 15 | builder_log.setLevel(level=logging.INFO) 16 | 17 | stream_handler = logging.StreamHandler() 18 | builder_log.addHandler(stream_handler) 19 | 20 | formatter = logging.Formatter("%(asctime)s|%(name)s|%(levelname)s: %(message)s") 21 | stream_handler.setFormatter(formatter) 22 | 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser( 26 | description="Splits the catalogue by the local UTM projection of the products."
27 | ) 28 | parser.add_argument( 29 | "input_folder", 30 | type=Path, 31 | help="The path to the folder containing the collection files.", 32 | ) 33 | parser.add_argument( 34 | "output_folder", 35 | type=Path, 36 | help="The path where to save the split STAC collections.", 37 | ) 38 | 39 | args = parser.parse_args() 40 | 41 | if not args.input_folder.exists(): 42 | raise FileNotFoundError(f"The input folder {args.input_folder} does not exist.") 43 | 44 | if not args.output_folder.exists(): 45 | raise FileNotFoundError( 46 | f"The output folder {args.output_folder} does not exist." 47 | ) 48 | 49 | builder_log.info("Loading the catalogues from the directory %s", args.input_folder) 50 | # List the catalogues in the input folder 51 | catalogues = [] 52 | for catalogue_path in tqdm(args.input_folder.glob("*.pkl")): 53 | with open(catalogue_path, "rb") as file: 54 | catalogue = pickle.load(file) 55 | try: 56 | catalogue.strategy 57 | except AttributeError: 58 | setattr(catalogue, "strategy", None) 59 | catalogues.append(catalogue) 60 | 61 | builder_log.info("Loaded %s catalogues. Merging them...", len(catalogues)) 62 | 63 | merged_catalogue = None 64 | for catalogue in tqdm(catalogues): 65 | if merged_catalogue is None: 66 | merged_catalogue = catalogue 67 | else: 68 | merged_catalogue.add_items(catalogue.get_all_items()) 69 | 70 | if merged_catalogue is None: 71 | raise ValueError("No catalogues found in the input folder.") 72 | 73 | builder_log.info("Merged catalogues into one. Updating the extent...") 74 | merged_catalogue.update_extent_from_items() 75 | 76 | with open("temp_merged_catalogue.pkl", "wb") as file: 77 | pickle.dump(merged_catalogue, file) 78 | 79 | builder_log.info("Splitting the catalogue by the local UTM projection...") 80 | split_collection_by_epsg(merged_catalogue, args.output_folder) 81 | -------------------------------------------------------------------------------- /src/worldcereal/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from ._version import __version__ 4 | 5 | __all__ = ["__version__"] 6 | 7 | SUPPORTED_SEASONS = [ 8 | "tc-s1", 9 | "tc-s2", 10 | "tc-annual", 11 | "custom", 12 | ] 13 | 14 | SEASONAL_MAPPING = { 15 | "tc-s1": "S1", 16 | "tc-s2": "S2", 17 | "tc-annual": "ANNUAL", 18 | "custom": "custom", 19 | } 20 | 21 | 22 | # Default buffer (days) prior to 23 | # season start 24 | SEASON_PRIOR_BUFFER = { 25 | "tc-s1": 0, 26 | "tc-s2": 0, 27 | "tc-annual": 0, 28 | "custom": 0, 29 | } 30 | 31 | 32 | # Default buffer (days) after 33 | # season end 34 | SEASON_POST_BUFFER = { 35 | "tc-s1": 0, 36 | "tc-s2": 0, 37 | "tc-annual": 0, 38 | "custom": 0, 39 | } 40 | -------------------------------------------------------------------------------- /src/worldcereal/_version.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | __version__ = "2.2.3" 4 | -------------------------------------------------------------------------------- /src/worldcereal/data/cropcalendars/ANNUAL_EOS_WGS84.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/data/cropcalendars/ANNUAL_EOS_WGS84.tif -------------------------------------------------------------------------------- /src/worldcereal/data/cropcalendars/ANNUAL_SOS_WGS84.tif:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/data/cropcalendars/ANNUAL_SOS_WGS84.tif -------------------------------------------------------------------------------- /src/worldcereal/data/cropcalendars/S1_EOS_WGS84.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/data/cropcalendars/S1_EOS_WGS84.tif -------------------------------------------------------------------------------- /src/worldcereal/data/cropcalendars/S1_SOS_WGS84.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/data/cropcalendars/S1_SOS_WGS84.tif -------------------------------------------------------------------------------- /src/worldcereal/data/cropcalendars/S2_EOS_WGS84.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/data/cropcalendars/S2_EOS_WGS84.tif -------------------------------------------------------------------------------- /src/worldcereal/data/cropcalendars/S2_SOS_WGS84.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/data/cropcalendars/S2_SOS_WGS84.tif -------------------------------------------------------------------------------- /src/worldcereal/data/cropcalendars/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/data/cropcalendars/__init__.py -------------------------------------------------------------------------------- /src/worldcereal/data/croptype_mappings/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/data/croptype_mappings/__init__.py -------------------------------------------------------------------------------- /src/worldcereal/data/croptype_mappings/wc2eurocrops_map.csv: -------------------------------------------------------------------------------- 1 | name,croptype,ewoc_code,landcover 2 | Unknown,0,00-00-00-000-0,0 3 | Cereals,1000,11-01-00-000-0,11 4 | Wheat,1100,11-01-01-000-0,11 5 | Winter wheat,1110,11-01-01-000-1,11 6 | Spring wheat,1120,11-01-01-000-2,11 7 | Maize,1200,11-01-06-000-0,11 8 | Rice,1300,11-01-08-000-0,11 9 | Sorghum,1400,11-01-07-003-0,11 10 | Barley,1500,11-01-02-000-0,11 11 | Winter barley,1510,11-01-02-000-1,11 12 | Spring barley,1520,11-01-02-000-2,11 13 | Rye,1600,11-01-03-000-0,11 14 | Winter rye,1610,11-01-03-000-1,11 15 | Spring rye,1620,11-01-03-000-2,11 16 | Oats,1700,11-01-04-000-0,11 17 | Millets,1800,11-01-07-001-0,11 18 | Other cereals,1900,11-01-00-000-0,11 19 | Winter cereal,1910,11-01-00-000-1,11 20 | Spring cereal,1920,11-01-00-000-2,11 21 | Vegetables and melons,2000,11-03-00-000-0,11 22 | 
Leafy or stem vegetables,2100,11-03-04-000-0,11 23 | Artichokes,2110,11-03-04-002-0,11 24 | Asparagus,2120,11-03-04-003-0,11 25 | Cabages,2130,11-03-06-000-0,11 26 | Cauliflowers & brocoli,2140,11-03-06-013-0,11 27 | Lettuce,2150,11-03-08-011-0,11 28 | Spinach,2160,11-03-08-008-0,11 29 | Chicory,2170,11-03-08-002-0,11 30 | Other leafy/stem vegetables,2190,11-03-04-000-0,11 31 | Fruit-bearing vegetables,2200,11-03-01-000-0,11 32 | Cucumbers,2210,11-03-02-001-0,11 33 | Eggplants,2220,11-03-01-002-0,11 34 | Tomatoes,2230,11-03-01-001-0,11 35 | Watermelons,2240,11-03-02-005-0,11 36 | Cantaloupes and other melons,2250,11-03-02-003-0,11 37 | "Pumpkin, squash and gourds",2260,11-03-02-004-0,11 38 | Other fruit-bearing vegetables,2290,11-03-01-000-0,11 39 | "Root, bulb or tuberous vegetables",2300,11-03-09-000-0,11 40 | Carrots,2310,11-03-09-004-0,11 41 | Turnips,2320,11-03-09-007-0,11 42 | Garlic,2330,11-03-11-002-0,11 43 | Onions and shallots,2340,11-03-11-007-0,11 44 | Leeks & other alliaceous vegetables,2350,11-03-11-000-0,11 45 | "Other root, bulb or tuberous vegetables",2390,11-03-09-000-0,11 46 | Mushrooms and truffles,2400,11-04-00-000-0,11 47 | Other vegetables,2900,11-03-00-000-0,11 48 | Fruit and nuts,3000,12-01-00-000-0,12 49 | Tropical and subtropical fruits,3100,12-01-02-000-0,12 50 | Avocados,3110,12-01-02-001-0,12 51 | Bananas & plantains,3120,12-01-02-002-0,12 52 | Dates,3130,12-01-02-003-0,12 53 | Figs,3140,12-01-01-006-0,12 54 | Mangoes,3150,12-01-02-004-0,12 55 | Papayas,3160,12-01-02-005-0,12 56 | Pineapples,3170,12-01-02-006-0,12 57 | Other tropical and subtropical fruits,3190,12-01-02-000-0,12 58 | Citrus fruits,3200,12-01-03-000-0,12 59 | Grapefruit & pomelo,3210,12-01-03-000-0,12 60 | Lemons & Limes,3220,12-01-03-003-0,12 61 | Oranges,3230,12-01-03-004-0,12 62 | "Tangerines, mandarins, clementines",3240,12-01-03-005-0,12 63 | Other citrus fruit,3290,12-01-03-000-0,12 64 | Grapes,3300,12-01-00-001-0,12 65 | Berries,3400,12-01-05-000-0,12 66 | Currants,3410,12-01-05-006-0,12 67 | Gooseberries,3420,12-01-05-007-0,12 68 | Kiwi fruit,3430,12-01-05-015-0,12 69 | Raspberries,3440,12-01-05-010-0,12 70 | Strawberries,3450,11-03-12-001-0,11 71 | Blueberries,3460,12-01-05-004-0,12 72 | Other berries,3490,12-01-05-000-0,12 73 | Pome fruits and stone fruits,3500,12-01-01-000-0,12 74 | Apples,3510,12-01-01-002-0,12 75 | Apricots,3520,12-01-01-003-0,12 76 | Cherries & sour cherries,3530,12-01-01-004-0,12 77 | Peaches & nectarines,3540,12-01-01-017-0,12 78 | Pears & quinces,3550,12-01-01-019-0,12 79 | Plums and sloes,3560,12-01-01-012-0,12 80 | Other pome fruits and stone fruits,3590,12-01-01-000-0,12 81 | Nuts,3600,12-01-04-000-0,12 82 | Almonds,3610,12-01-04-001-0,12 83 | Cashew nuts,3620,12-01-04-007-0,12 84 | Chestnuts,3630,12-01-04-005-0,12 85 | Hazelnuts,3640,12-01-04-002-0,12 86 | Pistachios,3650,12-01-04-004-0,12 87 | Walnuts,3660,12-01-04-006-0,12 88 | Other nuts,3690,12-01-04-000-0,12 89 | Other fruit,3900,12-01-00-000-0,12 90 | Oilseed crops,4000,11-06-00-000-0,11 91 | Soya beans,4100,11-06-00-002-0,11 92 | Groundnuts,4200,11-06-00-005-0,11 93 | Temporary oilseed crops,4300,11-06-00-000-0,11 94 | Castor bean,4310,11-06-00-006-0,11 95 | Linseed,4320,11-06-00-007-0,11 96 | Mustard,4330,11-06-00-008-0,11 97 | Niger seed,4340,11-06-00-004-0,11 98 | Rapeseed,4350,11-06-00-003-0,11 99 | Winter rapeseed,4351,11-06-00-003-1,11 100 | Spring rapeseed,4352,11-06-00-003-2,11 101 | Safflower,4360,11-06-00-009-0,11 102 | Sesame,4370,11-06-00-010-0,11 103 | Sunflower,4380,11-06-00-001-0,11 
104 | Other temporary oilseed crops,4390,11-06-00-000-0,11 105 | Permanent oilseed crops,4400,12-03-00-000-0,12 106 | Coconuts,4410,12-03-00-002-0,12 107 | Olives,4420,12-03-00-001-0,12 108 | Oil palms,4430,12-03-00-003-0,12 109 | Other oleaginous fruits,4490,12-03-00-000-0,12 110 | Root/tuber crops,5000,11-07-00-000-0,11 111 | Potatoes,5100,11-07-00-001-0,11 112 | Sweet potatoes,5200,11-07-00-002-0,11 113 | Cassava,5300,11-07-00-004-0,11 114 | Yams,5400,11-07-00-005-0,11 115 | Other roots and tubers,5900,11-07-00-000-0,11 116 | Beverage and spice crops,6000,10-00-00-000-0,10 117 | Beverage crops,6100,12-02-00-000-0,12 118 | Coffee,6110,12-02-00-001-0,12 119 | Tea,6120,12-02-00-002-0,12 120 | Maté,6130,12-02-00-003-0,12 121 | Cocoa,6140,12-02-00-004-0,12 122 | Other beverage crops,6190,10-00-00-000-0,10 123 | Spice crops,6200,12-02-00-000-0,12 124 | Chilies & peppers,6211,11-03-03-000-0,11 125 | "Anise, badian, fennel",6212,11-00-00-000-0,11 126 | Other temporary spice crops,6219,11-09-00-000-0,11 127 | Pepper,6221,11-09-00-029-0,11 128 | "Nutmeg, mace, cardamoms",6222,12-02-00-007-0,12 129 | Cinnamon,6223,12-02-00-008-0,12 130 | Cloves,6224,12-02-00-009-0,12 131 | Ginger,6225,11-09-00-048-0,11 132 | Vanilla,6226,11-09-00-049-0,10 133 | Other permanent spice crops,6229,12-02-00-000-0,12 134 | Leguminous crops,7000,11-05-00-000-0,11 135 | Beans,7100,11-05-01-001-0,11 136 | Broad beans,7200,11-05-01-003-0,11 137 | Chick peas,7300,11-05-01-004-0,11 138 | Cow peas,7400,11-05-01-005-0,11 139 | Lentils,7500,11-05-00-003-0,11 140 | Lupins,7600,11-05-00-004-0,11 141 | Peas,7700,11-05-01-002-0,11 142 | Pigeon peas,7800,11-05-01-006-0,11 143 | Other Leguminous crops,7900,10-00-00-000-0,10 144 | Other Leguminous crops - Temporary,7910,11-05-00-000-0,11 145 | Other Leguminous crops - Permanent,7920,12-04-00-000-0,12 146 | Sugar crops,8000,10-00-00-000-0,10 147 | Sugar beet,8100,11-07-00-003-1,11 148 | Sugar cane,8200,11-11-01-010-0,11 149 | Sweet sorghum,8300,11-01-07-004-0,11 150 | Other sugar crops,8900,10-00-00-000-0,10 151 | Other crops,9000,10-00-00-000-0,10 152 | Grasses and other fodder crops,9100,11-11-00-000-0,11 153 | Temporary grass crops,9110,11-11-00-001-0,11 154 | Permanent grass crops,9120,13-00-00-000-0,13 155 | Fibre crops,9200,10-00-00-000-0,10 156 | Temporary fibre crops,9210,11-08-00-000-0,11 157 | Cotton,9211,11-08-00-001-0,11 158 | "Jute, kenaf and similar",9212,11-08-00-002-0,11 159 | "Flax, hemp and similar",9213,11-08-02-000-0,11 160 | Other temporary fibre crops,9219,11-08-00-000-0,11 161 | Permanent fibre crops,9220,12-05-00-000-0,12 162 | "Medicinal, aromatic, pesticidal crops",9300,10-00-00-000-0,10 163 | Temporary medicinal etc crops,9310,11-09-00-000-0,11 164 | Permanent medicinal etc crops,9320,12-02-00-000-0,12 165 | Rubber,9400,12-06-00-008-0,12 166 | Flower crops,9500,11-10-00-000-0,11 167 | Temporary flower crops,9510,11-10-00-000-0,11 168 | Permanent flower crops,9520,12-06-00-000-0,12 169 | Tobacco,9600,11-09-00-050-0,11 170 | Other other crops,9900,10-00-00-000-0,10 171 | Other crops - temporary,9910,11-00-00-000-0,11 172 | Other crops - permanent,9920,12-00-00-000-0,12 173 | mixed cropping,9998,10-00-00-000-0,10 174 | Cropland,10,10-00-00-000-0,10 175 | Annual cropland,11,11-00-00-000-0,11 176 | Perennial cropland,12,12-00-00-000-0,12 177 | Grassland *,13,13-00-00-000-0,13 178 | Herbaceous vegetation,20,20-00-00-000-0,20 179 | Shrubland,30,30-00-00-000-0,30 180 | Deciduous forest,40,43-02-00-000-0,40 181 | Evergreen forest,41,43-01-00-000-0,41 182 | 
Mixed/unknown forest,42,43-00-00-000-0,42 183 | Bare / sparse vegetation,50,50-00-00-000-0,50 184 | Built up / urban,60,60-00-00-000-0,60 185 | Water,70,70-00-00-000-0,70 186 | Snow / ice,80,50-05-00-000-0,80 187 | No cropland (including perennials),98,17-00-00-000-0,98 188 | No cropland,99,16-00-00-000-0,99 189 | Unknown,991,00-00-00-000-0,0 190 | Unknown,9700,00-00-00-000-0,0 -------------------------------------------------------------------------------- /src/worldcereal/extract/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/extract/__init__.py -------------------------------------------------------------------------------- /src/worldcereal/extract/patch_meteo.py: -------------------------------------------------------------------------------- 1 | """Extract AGERA5 (Meteo) data using OpenEO-GFMAP package.""" 2 | 3 | import copy 4 | from typing import Dict, List, Optional, Union 5 | 6 | import geojson 7 | import geopandas as gpd 8 | import openeo 9 | import pandas as pd 10 | from openeo_gfmap import Backend, TemporalContext 11 | 12 | from worldcereal.extract.utils import ( # isort: skip 13 | buffer_geometry, # isort: skip 14 | filter_extract_true, # isort: skip 15 | upload_geoparquet_s3, # isort: skip 16 | ) # isort: skip 17 | 18 | 19 | DEFAULT_JOB_OPTIONS_PATCH_METEO = { 20 | "executor-memory": "1800m", 21 | "python-memory": "1000m", 22 | "max-executors": 22, 23 | } 24 | 25 | 26 | def create_job_dataframe_patch_meteo( 27 | backend: Backend, split_jobs: List[gpd.GeoDataFrame] 28 | ) -> pd.DataFrame: 29 | raise NotImplementedError("This function is not implemented yet.") 30 | 31 | 32 | def create_job_patch_meteo( 33 | row: pd.Series, 34 | connection: openeo.DataCube, 35 | provider, 36 | connection_provider, 37 | job_options: Optional[Dict[str, Union[str, int]]] = None, 38 | ) -> openeo.BatchJob: 39 | start_date = row.start_date 40 | end_date = row.end_date 41 | temporal_context = TemporalContext(start_date, end_date) 42 | 43 | # Get the feature collection containing the geometry for the job 44 | geometry = geojson.loads(row.geometry) 45 | assert isinstance(geometry, geojson.FeatureCollection) 46 | 47 | # Filter the geometry to the rows with the extract only flag 48 | geometry = filter_extract_true(geometry) 49 | assert len(geometry.features) > 0, "No geometries with the extract flag found" 50 | 51 | # Apply a small buffer (5 m) around the geometry 52 | geometry_df = buffer_geometry(geometry, distance_m=5) 53 | spatial_extent_url = upload_geoparquet_s3(provider, geometry_df, row.name, "METEO") 54 | 55 | bands_to_download = ["temperature-mean", "precipitation-flux"] 56 | 57 | cube = connection.load_collection( 58 | "AGERA5", 59 | temporal_extent=[temporal_context.start_date, temporal_context.end_date], 60 | bands=bands_to_download, 61 | ) 62 | filter_geometry = connection.load_url(spatial_extent_url, format="parquet") 63 | cube = cube.filter_spatial(filter_geometry) 64 | cube = cube.rename_labels( 65 | dimension="bands", 66 | target=["AGERA5-temperature-mean", "AGERA5-precipitation-flux"], 67 | source=["temperature-mean", "precipitation-flux"], 68 | ) 69 | 70 | # Rescale to uint16, multiplying by 100 first 71 | cube = cube * 100 72 | cube = cube.linear_scale_range(0, 65534, 0, 65534) 73 | 74 | h3index = geometry.features[0].properties["h3index"] 75 | valid_time = geometry.features[0].properties["valid_time"] 76 | 77 | # Set job options
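# (collection defaults are copied first, then overridden by any user-supplied job_options)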
78 | final_job_options = copy.deepcopy(DEFAULT_JOB_OPTIONS_PATCH_METEO) 79 | if job_options: 80 | final_job_options.update(job_options) 81 | 82 | return cube.create_job( 83 | out_format="NetCDF", 84 | title=f"WorldCereal_Patch-AGERA5_Extraction_{h3index}_{valid_time}", 85 | sample_by_feature=True, 86 | job_options=final_job_options, 87 | ) 88 | -------------------------------------------------------------------------------- /src/worldcereal/extract/patch_s1.py: -------------------------------------------------------------------------------- 1 | """Extract S1 data using OpenEO-GFMAP package.""" 2 | 3 | import copy 4 | from datetime import datetime 5 | from typing import Dict, List, Optional, Union 6 | 7 | import geojson 8 | import geopandas as gpd 9 | import openeo 10 | import pandas as pd 11 | from openeo_gfmap import ( 12 | Backend, 13 | BackendContext, 14 | BoundingBoxExtent, 15 | FetchType, 16 | TemporalContext, 17 | ) 18 | from openeo_gfmap.preprocessing.sar import compress_backscatter_uint16 19 | from openeo_gfmap.utils.catalogue import s1_area_per_orbitstate_vvvh 20 | from tqdm import tqdm 21 | 22 | from worldcereal.openeo.preprocessing import raw_datacube_S1, spatially_filter_cube 23 | from worldcereal.rdm_api.rdm_interaction import RDM_DEFAULT_COLUMNS 24 | 25 | from worldcereal.extract.utils import ( # isort: skip 26 | buffer_geometry, # isort: skip 27 | get_job_nb_polygons, # isort: skip 28 | pipeline_log, # isort: skip 29 | # upload_geoparquet_s3, # isort: skip 30 | upload_geoparquet_artifactory, # isort: skip 31 | ) 32 | 33 | S1_GRD_CATALOGUE_BEGIN_DATE = datetime(2014, 10, 1) 34 | 35 | DEFAULT_JOB_OPTIONS_PATCH_S1 = { 36 | "executor-memory": "1800m", 37 | "python-memory": "1900m", 38 | "max-executors": 22, 39 | "soft-errors": 0.1, 40 | } 41 | 42 | 43 | def create_job_dataframe_patch_s1( 44 | backend: Backend, 45 | split_jobs: List[gpd.GeoDataFrame], 46 | ) -> pd.DataFrame: 47 | """Create a dataframe from the split jobs, containing all the necessary information to run the job.""" 48 | rows = [] 49 | for job in tqdm(split_jobs): 50 | # Determine the temporal range of the samples in this job 51 | min_time = job.valid_time.min() 52 | max_time = job.valid_time.max() 53 | 54 | # Buffer the valid_time range by 9 months (275 days) before and after, 55 | # i.e. a total window of roughly 1.5 years 56 | start_date = (min_time - pd.Timedelta(days=275)).to_pydatetime() 57 | end_date = (max_time + pd.Timedelta(days=275)).to_pydatetime() 58 | 59 | # Impose limits due to the data availability 60 | start_date = max(start_date, S1_GRD_CATALOGUE_BEGIN_DATE) 61 | end_date = min(end_date, datetime.now()) 62 | 63 | # Convert dates to string format 64 | start_date, end_date = ( 65 | start_date.strftime("%Y-%m-%d"), 66 | end_date.strftime("%Y-%m-%d"), 67 | ) 68 | 69 | s2_tile = job.tile.iloc[0] # Job dataframes are split per S2 tile 70 | h3_l3_cell = job.h3_l3_cell.iloc[0] 71 | 72 | # Compute the lat/lon bounding box of the job geometries 73 | geometry_bbox = job.to_crs(epsg=4326).total_bounds 74 | # Buffer if the geometry is a point 75 | if geometry_bbox[0] == geometry_bbox[2]: 76 | geometry_bbox = ( 77 | geometry_bbox[0] - 0.0001, 78 | geometry_bbox[1], 79 | geometry_bbox[2] + 0.0001, 80 | geometry_bbox[3], 81 | ) 82 | if geometry_bbox[1] == geometry_bbox[3]: 83 | geometry_bbox = ( 84 | geometry_bbox[0], 85 | geometry_bbox[1] - 0.0001, 86 | geometry_bbox[2], 87 | geometry_bbox[3] + 0.0001, 88 | ) 89 | 90 | area_per_orbit = s1_area_per_orbitstate_vvvh(
134 | def create_job_patch_s1( 135 | row: pd.Series, 136 | connection: openeo.Connection, 137 | provider, 138 | connection_provider, 139 | job_options: Optional[Dict[str, Union[str, int]]] = None, 140 | ) -> openeo.BatchJob: 141 | """Creates an OpenEO BatchJob from the given row information. This job is an 142 | S1 patch of 32x32 pixels at 20m spatial resolution.""" 143 | 144 | # Load the temporal extent 145 | start_date = row.start_date 146 | end_date = row.end_date 147 | temporal_context = TemporalContext(start_date, end_date) 148 | 149 | # Get the feature collection containing the geometry of the job 150 | geometry = geojson.loads(row.geometry) 151 | assert isinstance(geometry, geojson.FeatureCollection) 152 | 153 | # Jobs are created per orbit direction 154 | orbit_state = row.orbit_state 155 | 156 | # Perform a 320 m square buffer around the geometry (32 px at 20 m resolution) 157 | geometry_df = buffer_geometry(geometry, distance_m=320) 158 | # spatial_extent_url = upload_geoparquet_s3( 159 | #     provider, geometry_df, f"{row.s2_tile}_{row.name}", "SENTINEL1" 160 | # ) 161 | spatial_extent_url = upload_geoparquet_artifactory( 162 | geometry_df, f"{row.s2_tile}_{row.name}", collection="SENTINEL1" 163 | ) 164 | 165 | # Backend name and fetching type 166 | backend = Backend(row.backend_name) 167 | backend_context = BackendContext(backend) 168 | 169 | # Create the job to extract S1 170 | cube = raw_datacube_S1( 171 | connection=connection, 172 | backend_context=backend_context, 173 | spatial_extent=spatial_extent_url, 174 | temporal_extent=temporal_context, 175 | bands=["S1-SIGMA0-VV", "S1-SIGMA0-VH"], 176 | fetch_type=FetchType.POLYGON, 177 | target_resolution=20, 178 | orbit_direction=orbit_state, 179 | ) 180 | cube = compress_backscatter_uint16(backend_context, cube) 181 | 182 | # Apply spatial filtering 183 | cube = spatially_filter_cube(connection, cube, spatial_extent_url) 184 | 185 | # Additional values to generate the BatchJob name 186 | s2_tile = row.s2_tile 187 | valid_time = geometry.features[0].properties["valid_time"] 188 | 189 | # Log the number of polygons to extract 190 | number_polygons = get_job_nb_polygons(row) 191 | pipeline_log.debug("Number of polygons to extract %s", number_polygons) 192 | 193 |
# Set job options 194 | final_job_options = copy.deepcopy(DEFAULT_JOB_OPTIONS_PATCH_S1) 195 | if job_options: 196 | final_job_options.update(job_options) 197 | 198 | return cube.create_job( 199 | out_format="NetCDF", 200 | title=f"Worldcereal_Patch-S1_Extraction_{s2_tile}_{orbit_state}_{valid_time}", 201 | sample_by_feature=True, 202 | job_options=final_job_options, 203 | feature_id_property="sample_id", 204 | ) 205 | -------------------------------------------------------------------------------- /src/worldcereal/extract/patch_s2.py: -------------------------------------------------------------------------------- 1 | """Extract S2 data using OpenEO-GFMAP package.""" 2 | 3 | import copy 4 | from datetime import datetime 5 | from typing import Dict, List, Optional, Union 6 | 7 | import geojson 8 | import geopandas as gpd 9 | import openeo 10 | import pandas as pd 11 | from openeo_gfmap import Backend, BackendContext, FetchType, TemporalContext 12 | from openeo_gfmap.manager import _log 13 | from tqdm import tqdm 14 | 15 | from worldcereal.openeo.preprocessing import raw_datacube_S2, spatially_filter_cube 16 | from worldcereal.rdm_api.rdm_interaction import RDM_DEFAULT_COLUMNS 17 | 18 | from worldcereal.extract.utils import ( # isort: skip 19 | buffer_geometry, # isort: skip 20 | get_job_nb_polygons, # isort: skip 21 | upload_geoparquet_artifactory, # isort: skip 22 | # upload_geoparquet_s3, # isort: skip 23 | ) # isort: skip 24 | 25 | 26 | S2_L2A_CATALOGUE_BEGIN_DATE = datetime(2017, 1, 1) 27 | 28 | 29 | DEFAULT_JOB_OPTIONS_PATCH_S2 = { 30 | "driver-memory": "2G", 31 | "driver-memoryOverhead": "2G", 32 | "driver-cores": "1", 33 | "executor-memory": "1800m", 34 | "python-memory": "1900m", 35 | "executor-cores": "1", 36 | "max-executors": 22, 37 | "soft-errors": 0.1, 38 | "gdal-dataset-cache-size": 2, 39 | "gdal-cachemax": 120, 40 | "executor-threads-jvm": 1, 41 | } 42 | 43 | 44 | def create_job_dataframe_patch_s2( 45 | backend: Backend, 46 | split_jobs: List[gpd.GeoDataFrame], 47 | ) -> pd.DataFrame: 48 | """Create a dataframe from the split jobs, containing all the necessary information to run the job.""" 49 | rows = [] 50 | for job in tqdm(split_jobs): 51 | # Determine the temporal extent from the valid dates of the job: 52 | min_time = job.valid_time.min() 53 | max_time = job.valid_time.max() 54 | # 9 months before and after the valid time 55 | start_date = (min_time - pd.Timedelta(days=275)).to_pydatetime() 56 | end_date = (max_time + pd.Timedelta(days=275)).to_pydatetime() 57 | 58 | # Impose limits due to the data availability 59 | # start_date = max(start_date, S2_L2A_CATALOGUE_BEGIN_DATE) 60 | # end_date = min(end_date, datetime.now()) 61 | 62 | s2_tile = job.tile.iloc[0] 63 | h3_l3_cell = job.h3_l3_cell.iloc[0] 64 | 65 | # Convert dates to string format 66 | start_date, end_date = ( 67 | start_date.strftime("%Y-%m-%d"), 68 | end_date.strftime("%Y-%m-%d"), 69 | ) 70 | 71 | # Set back the valid_time in the geometry as string 72 | job["valid_time"] = job.valid_time.dt.strftime("%Y-%m-%d") 73 | 74 | # Subset on required attributes 75 | job = job[RDM_DEFAULT_COLUMNS] 76 | 77 | variables = { 78 | "backend_name": backend.value, 79 | "out_prefix": "S2-L2A-10m", 80 | "out_extension": ".nc", 81 | "start_date": start_date, 82 | "end_date": end_date, 83 | "s2_tile": s2_tile, 84 | "h3_l3_cell": h3_l3_cell, 85 | "geometry": job.to_json(), 86 | } 87 | rows.append(pd.Series(variables)) 88 | 89 | return pd.DataFrame(rows) 90 | 91 | 92 | def create_job_patch_s2( 93 | row: pd.Series, 94 |
connection: openeo.Connection, 95 | provider, 96 | connection_provider, 97 | job_options: Optional[Dict[str, Union[str, int]]] = None, 98 | ) -> openeo.BatchJob: 99 | start_date = row.start_date 100 | end_date = row.end_date 101 | temporal_context = TemporalContext(start_date, end_date) 102 | 103 | # Get the feature collection containing the geometry of the job 104 | geometry = geojson.loads(row.geometry) 105 | assert isinstance(geometry, geojson.FeatureCollection) 106 | 107 | # # Filter the geometry to the rows with the extract only flag 108 | # geometry = filter_extract_true(geometry) 109 | # assert len(geometry.features) > 0, "No geometries with the extract flag found" 110 | 111 | # Perform a 320 m square buffer around the geometry (64 px at 10 m resolution) 112 | geometry_df = buffer_geometry(geometry, distance_m=320) 113 | # spatial_extent_url = upload_geoparquet_s3( 114 | #     provider, geometry_df, f"{row.s2_tile}_{row.name}", "SENTINEL2" 115 | # ) 116 | spatial_extent_url = upload_geoparquet_artifactory( 117 | geometry_df, f"{row.s2_tile}_{row.name}", collection="SENTINEL2" 118 | ) 119 | 120 | # Backend name and fetching type 121 | backend = Backend(row.backend_name) 122 | backend_context = BackendContext(backend) 123 | 124 | # Get the S2 tile and valid_time to use in the job title 125 | s2_tile = row.s2_tile 126 | valid_time = geometry.features[0].properties["valid_time"] 127 | 128 | bands_to_download = [ 129 | "S2-L2A-B01", 130 | "S2-L2A-B02", 131 | "S2-L2A-B03", 132 | "S2-L2A-B04", 133 | "S2-L2A-B05", 134 | "S2-L2A-B06", 135 | "S2-L2A-B07", 136 | "S2-L2A-B08", 137 | "S2-L2A-B8A", 138 | "S2-L2A-B09", 139 | "S2-L2A-B11", 140 | "S2-L2A-B12", 141 | "S2-L2A-SCL", 142 | ] 143 | 144 | cube = raw_datacube_S2( 145 | connection, 146 | backend_context, 147 | temporal_context, 148 | bands_to_download, 149 | FetchType.POLYGON, 150 | spatial_extent=spatial_extent_url, 151 | filter_tile=s2_tile, 152 | apply_mask_flag=False, 153 | additional_masks_flag=True, 154 | ) 155 | 156 | # Apply spatial filtering 157 | cube = spatially_filter_cube(connection, cube, spatial_extent_url) 158 | 159 | # Log the number of polygons to extract 160 | number_polygons = get_job_nb_polygons(row) 161 | _log.debug("Number of polygons to extract %s", number_polygons) 162 | 163 | # Set job options 164 | final_job_options = copy.deepcopy(DEFAULT_JOB_OPTIONS_PATCH_S2) 165 | if job_options: 166 | final_job_options.update(job_options) 167 | 168 | return cube.create_job( 169 | out_format="NetCDF", 170 | title=f"Worldcereal_Patch-S2_Extraction_{s2_tile}_{valid_time}", 171 | sample_by_feature=True, 172 | job_options=final_job_options, 173 | feature_id_property="sample_id", 174 | ) 175 | -------------------------------------------------------------------------------- /src/worldcereal/extract/utils.py: -------------------------------------------------------------------------------- 1 | """Common utilities used by extraction scripts.""" 2 | 3 | import logging 4 | import os 5 | from tempfile import NamedTemporaryFile 6 | 7 | import geojson 8 | import geopandas as gpd 9 | import pandas as pd 10 | import requests 11 | from openeo_gfmap.manager.job_splitters import load_s2_grid 12 | from shapely import Point 13 | 14 | from worldcereal.utils.upload import OpenEOArtifactHelper 15 | 16 | # Logger used for the pipeline 17 | pipeline_log = logging.getLogger("extraction_pipeline") 18 | 19 | pipeline_log.setLevel(level=logging.INFO) 20 | 21 | stream_handler = logging.StreamHandler() 22 | pipeline_log.addHandler(stream_handler) 23 | 24 | formatter = logging.Formatter("%(asctime)s|%(name)s|%(levelname)s:  %(message)s") 25 | stream_handler.setFormatter(formatter) 26 | 27 | 28 | # Exclude the other loggers from other libraries 29 | class ManagerLoggerFilter(logging.Filter): 30 | """Filter to only accept the OpenEO-GFMAP manager logs.""" 31 | 32 | def filter(self, record): 33 | return record.name in [pipeline_log.name] 34 | 35 | 36 | stream_handler.addFilter(ManagerLoggerFilter()) 37 | 38 | 39 | S2_GRID = load_s2_grid() 40 | 41 | 42 | def buffer_geometry( 43 | geometries: geojson.FeatureCollection, distance_m: int = 320 44 | ) -> gpd.GeoDataFrame: 45 | """For each geometry of the collection, perform a square buffer of `distance_m` 46 | meters on the centroid and return the GeoDataFrame. Before buffering, 47 | the centroid is snapped to the closest multiple of 20 m in order to stay 48 | aligned with the Sentinel pixel grid. 49 | """ 50 | gdf = gpd.GeoDataFrame.from_features(geometries).set_crs(epsg=4326) 51 | utm = gdf.estimate_utm_crs() 52 | gdf = gdf.to_crs(utm) 53 | 54 | # Perform the buffering operation 55 | gdf["geometry"] = gdf.centroid.apply( 56 | lambda point: Point(round(point.x / 20.0) * 20.0, round(point.y / 20.0) * 20.0) 57 | ).buffer( 58 | distance=distance_m, cap_style=3 59 | )  # Square buffer 60 | 61 | return gdf 62 | 63 |
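# --- Illustrative sketch (hypothetical coordinates, added example) ---
# buffer_geometry() snaps each centroid to the 20 m pixel grid before applying
# the square buffer. The snapping itself is a simple round-to-multiple:
if __name__ == "__main__":
    x, y = 634137.3, 5683459.8  # hypothetical UTM coordinates
    snapped = (round(x / 20.0) * 20.0, round(y / 20.0) * 20.0)
    print(snapped)  # -> (634140.0, 5683460.0)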
64 | def filter_extract_true( 65 | geometries: geojson.FeatureCollection, extract_value: int = 1 66 | ) -> geojson.FeatureCollection: 67 | """Keep only the features whose `extract` property equals `extract_value`.""" 68 | return geojson.FeatureCollection( 69 | [ 70 | f 71 | for f in geometries.features 72 | if f.properties.get("extract", 0) == extract_value 73 | ] 74 | ) 75 | 76 | 77 | def upload_geoparquet_s3( 78 | backend: str, gdf: gpd.GeoDataFrame, name: str, collection: str = "" 79 | ) -> str: 80 | """Upload the given GeoDataFrame to S3 and return the URL of the 81 | uploaded file. Necessary as a workaround for Polygon sampling in OpenEO 82 | using custom CRS. 83 | """ 84 | # Save the dataframe as geoparquet to upload it to S3 85 | temporary_file = NamedTemporaryFile() 86 | gdf.to_parquet(temporary_file.name) 87 | 88 | targetpath = f"openeogfmap_dataframe_{collection}_{name}.parquet" 89 | 90 | artifact_helper = OpenEOArtifactHelper.from_openeo_backend(backend) 91 | normal_s3_uri = artifact_helper.upload_file(targetpath, temporary_file.name) 92 | presigned_uri = artifact_helper.get_presigned_url(normal_s3_uri) 93 | 94 | return presigned_uri 95 | 96 | 97 | def upload_geoparquet_artifactory( 98 | gdf: gpd.GeoDataFrame, name: str, collection: str = "" 99 | ) -> str: 100 | """Upload the given GeoDataFrame to Artifactory and return the URL of the 101 | uploaded file. Necessary as a workaround for Polygon sampling in OpenEO 102 | using custom CRS. 103 | """ 104 | # Save the dataframe as geoparquet to upload it to Artifactory 105 | temporary_file = NamedTemporaryFile() 106 | gdf.to_parquet(temporary_file.name) 107 | 108 | artifactory_username = os.getenv("ARTIFACTORY_USERNAME") 109 | artifactory_password = os.getenv("ARTIFACTORY_PASSWORD") 110 | 111 | if not artifactory_username or not artifactory_password: 112 | raise ValueError( 113 | "Artifactory credentials not found. Please set ARTIFACTORY_USERNAME and ARTIFACTORY_PASSWORD."
114 | ) 115 | 116 | headers = {"Content-Type": "application/octet-stream"} 117 | 118 | upload_url = f"https://artifactory.vgt.vito.be/artifactory/auxdata-public/gfmap-temp/openeogfmap_dataframe_{collection}_{name}.parquet" 119 | 120 | with open(temporary_file.name, "rb") as f: 121 | response = requests.put( 122 | upload_url, 123 | headers=headers, 124 | data=f, 125 | auth=(artifactory_username, artifactory_password), 126 | timeout=180, 127 | ) 128 | 129 | response.raise_for_status() 130 | 131 | return upload_url 132 | 133 | 134 | def get_job_nb_polygons(row: pd.Series) -> int: 135 | """Get the number of polygons in the geometry.""" 136 | return len( 137 | list( 138 | filter( 139 | lambda feat: feat.properties.get("extract"), 140 | geojson.loads(row.geometry)["features"], 141 | ) 142 | ) 143 | ) 144 | -------------------------------------------------------------------------------- /src/worldcereal/openeo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/openeo/__init__.py -------------------------------------------------------------------------------- /src/worldcereal/openeo/inference.py: -------------------------------------------------------------------------------- 1 | """Model inference on Presto features for binary classification""" 2 | 3 | import functools 4 | 5 | import onnxruntime as ort 6 | import requests 7 | import xarray as xr 8 | from openeo_gfmap.inference.model_inference import ModelInference 9 | 10 | 11 | class CropClassifier(ModelInference): 12 | """Binary or multi-class crop classifier using ONNX to load a CatBoost model. 13 | 14 | The classifier uses the embeddings computed from the Presto Feature 15 | Extractor. 16 | 17 | Relevant UDF parameters: 18 | - classifier_url: A public URL to the ONNX classification model. Default is 19 | the public Presto model. 20 | - lookup_table: A dictionary mapping class names to class labels, ordered by 21 | model probability output. This is required for the model to map the output 22 | probabilities to class names. 23 | """ 24 | 25 | import numpy as np 26 | 27 | def __init__(self): 28 | super().__init__() 29 | 30 | self.onnx_session = None 31 | 32 | def dependencies(self) -> list: 33 | return [] 34 | 35 | @classmethod 36 | @functools.lru_cache(maxsize=6) 37 | def load_and_prepare_model(cls, model_url: str): 38 | """Method to be used instead of the default GFMap load_ort_model function. 39 | Loads the model, validates it and extracts the LUT from the model metadata. 40 | 41 | 42 | Parameters 43 | ---------- 44 | model_url : str 45 | Public URL to the ONNX classification model.
46 | """ 47 | # Load the model 48 | response = requests.get(model_url, timeout=120) 49 | model = ort.InferenceSession(response.content) 50 | 51 | # Validate the model 52 | metadata = model.get_modelmeta().custom_metadata_map 53 | 54 | if "class_params" not in metadata: 55 | raise ValueError("Could not find class names in the model metadata.") 56 | 57 | class_params = eval(metadata["class_params"], {"__builtins__": None}, {}) 58 | 59 | if "class_names" not in class_params: 60 | raise ValueError("Could not find class names in the model metadata.") 61 | 62 | if "class_to_label" not in class_params: 63 | raise ValueError("Could not find class to labels in the model metadata.") 64 | 65 | # Load model LUT 66 | lut = dict(zip(class_params["class_names"], class_params["class_to_label"])) 67 | sorted_lut = {k: v for k, v in sorted(lut.items(), key=lambda item: item[1])} 68 | 69 | return model, sorted_lut 70 | 71 | def output_labels(self) -> list: 72 | """ 73 | Returns the output labels for the classification. 74 | 75 | LUT needs to be explicitly sorted here as openEO does 76 | not guarantee the order of a json object being preserved when decoding 77 | a process graph in the backend. 78 | """ 79 | _, lut = self.load_and_prepare_model(self._parameters["classifier_url"]) 80 | lut_sorted = {k: v for k, v in sorted(lut.items(), key=lambda item: item[1])} 81 | class_names = lut_sorted.keys() 82 | 83 | return ["classification", "probability"] + [ 84 | f"probability_{name}" for name in class_names 85 | ] 86 | 87 | def predict(self, features: np.ndarray) -> np.ndarray: 88 | """ 89 | Predicts labels using the provided features array. 90 | 91 | LUT needs to be explicitly sorted here as openEO does 92 | not guarantee the order of a json object being preserved when decoding 93 | a process graph in the backend. 94 | """ 95 | import numpy as np 96 | 97 | # Classes names to codes 98 | _, lut_sorted = self.load_and_prepare_model(self._parameters["classifier_url"]) 99 | 100 | if lut_sorted is None: 101 | raise ValueError( 102 | "Lookup table is not defined. Please provide lookup_table in the model metadata." 103 | ) 104 | 105 | if self.onnx_session is None: 106 | raise ValueError("Model has not been loaded. 
Please load a model first.") 107 | 108 | # Prepare input data for ONNX model 109 | outputs = self.onnx_session.run(None, {"features": features}) 110 | 111 | # Extract classes as INTs and probability of winning class values 112 | labels = np.zeros((len(outputs[0]),), dtype=np.uint16) 113 | probabilities = np.zeros((len(outputs[0]),), dtype=np.uint8) 114 | for i, (label, prob) in enumerate(zip(outputs[0], outputs[1])): 115 | labels[i] = lut_sorted[label] 116 | probabilities[i] = int(round(prob[label] * 100)) 117 | 118 | # Extract per class probabilities 119 | output_probabilities = [] 120 | for output_px in outputs[1]: 121 | output_probabilities.append( 122 | [output_px[label] for label in lut_sorted.keys()] 123 | ) 124 | 125 | output_probabilities = ( 126 | (np.array(output_probabilities) * 100).round().astype(np.uint8) 127 | ) 128 | 129 | return np.hstack( 130 | [labels[:, np.newaxis], probabilities[:, np.newaxis], output_probabilities] 131 | ).transpose() 132 | 133 | def execute(self, inarr: xr.DataArray) -> xr.DataArray: 134 | 135 | if "classifier_url" not in self._parameters: 136 | raise ValueError('Missing required parameter "classifier_url"') 137 | classifier_url = self._parameters.get("classifier_url") 138 | self.logger.info(f'Loading classifier model from "{classifier_url}"') 139 | 140 | # shape and indices for output ("xy", "bands") 141 | x_coords, y_coords = inarr.x.values, inarr.y.values 142 | inarr = inarr.transpose("bands", "x", "y").stack(xy=["x", "y"]).transpose() 143 | 144 | self.onnx_session, self._parameters["lookup_table"] = ( 145 | self.load_and_prepare_model(classifier_url) 146 | ) 147 | 148 | # Run CatBoost classification 149 | self.logger.info("CatBoost classification with input shape: %s", inarr.shape) 150 | classification = self.predict(inarr.values) 151 | self.logger.info("Classification done with shape: %s", classification.shape) 152 | 153 | output_labels = self.output_labels() 154 | 155 | classification_da = xr.DataArray( 156 | classification.reshape((len(output_labels), len(x_coords), len(y_coords))), 157 | dims=["bands", "x", "y"], 158 | coords={ 159 | "bands": output_labels, 160 | "x": x_coords, 161 | "y": y_coords, 162 | }, 163 | ) 164 | 165 | return classification_da 166 |
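# --- Illustrative sketch (hypothetical classes, added example) ---
# CropClassifier.predict() turns the raw ONNX outputs (winning class name plus
# a per-class probability dict, per pixel) into integer labels and uint8
# percentages through the sorted lookup table, e.g.:
if __name__ == "__main__":
    import numpy as np

    lut_sorted = {"other": 0, "cropland": 1}  # hypothetical LUT
    winners = ["cropland", "other"]
    probs = [{"other": 0.1, "cropland": 0.9}, {"other": 0.6, "cropland": 0.4}]
    labels = np.array([lut_sorted[w] for w in winners], dtype=np.uint16)
    confidence = np.array(
        [int(round(p[w] * 100)) for w, p in zip(winners, probs)], dtype=np.uint8
    )
    print(labels, confidence)  # -> [1 0] [90 60]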
-------------------------------------------------------------------------------- /src/worldcereal/openeo/mapping.py: -------------------------------------------------------------------------------- 1 | """Private methods for cropland/croptype mapping. The public functions that 2 | are interacting with the methods here are defined in the `worldcereal.job` 3 | sub-module. 4 | """ 5 | 6 | from typing import Optional 7 | 8 | from openeo import DataCube 9 | from openeo_gfmap import TemporalContext 10 | from openeo_gfmap.features.feature_extractor import apply_feature_extractor 11 | from openeo_gfmap.inference.model_inference import apply_model_inference 12 | from openeo_gfmap.preprocessing.scaling import compress_uint8, compress_uint16 13 | 14 | from worldcereal.parameters import ( 15 | CropLandParameters, 16 | CropTypeParameters, 17 | PostprocessParameters, 18 | WorldCerealProductType, 19 | ) 20 | 21 | 22 | def _cropland_map( 23 | inputs: DataCube, 24 | temporal_extent: TemporalContext, 25 | cropland_parameters: CropLandParameters, 26 | postprocess_parameters: PostprocessParameters, 27 | ) -> DataCube: 28 | """Method to produce a cropland map from preprocessed inputs, using 29 | a Presto feature extractor and a CatBoost classifier. 30 | 31 | Parameters 32 | ---------- 33 | inputs : DataCube 34 | preprocessed input cube 35 | temporal_extent : TemporalContext 36 | temporal extent of the input cube 37 | cropland_parameters: CropLandParameters 38 | Parameters for the cropland product inference pipeline 39 | postprocess_parameters: PostprocessParameters 40 | Parameters for the postprocessing 41 | Returns 42 | ------- 43 | DataCube 44 | binary labels and probability 45 | """ 46 | 47 | # Run feature computer 48 | features = apply_feature_extractor( 49 | feature_extractor_class=cropland_parameters.feature_extractor, 50 | cube=inputs, 51 | parameters=cropland_parameters.feature_parameters.model_dump(), 52 | size=[ 53 | {"dimension": "x", "unit": "px", "value": 128}, 54 | {"dimension": "y", "unit": "px", "value": 128}, 55 | ], 56 | overlap=[ 57 | {"dimension": "x", "unit": "px", "value": 0}, 58 | {"dimension": "y", "unit": "px", "value": 0}, 59 | ], 60 | ) 61 | 62 | # Run model inference on features 63 | parameters = cropland_parameters.classifier_parameters.model_dump( 64 | exclude=["classifier"] 65 | ) 66 | 67 | classes = apply_model_inference( 68 | model_inference_class=cropland_parameters.classifier, 69 | cube=features, 70 | parameters=parameters, 71 | size=[ 72 | {"dimension": "x", "unit": "px", "value": 128}, 73 | {"dimension": "y", "unit": "px", "value": 128}, 74 | {"dimension": "t", "value": "P1D"}, 75 | ], 76 | overlap=[ 77 | {"dimension": "x", "unit": "px", "value": 0}, 78 | {"dimension": "y", "unit": "px", "value": 0}, 79 | ], 80 | ) 81 | 82 | # Get rid of temporal dimension 83 | classes = classes.reduce_dimension(dimension="t", reducer="mean") 84 | 85 | # Postprocess 86 | if postprocess_parameters.enable: 87 | if postprocess_parameters.save_intermediate: 88 | classes = classes.save_result( 89 | format="GTiff", 90 | options=dict( 91 | filename_prefix=f"{WorldCerealProductType.CROPLAND.value}-raw_{temporal_extent.start_date}_{temporal_extent.end_date}" 92 | ), 93 | ) 94 | classes = _postprocess( 95 | classes, 96 | postprocess_parameters, 97 | cropland_parameters.classifier_parameters.classifier_url, 98 | ) 99 | 100 | # Cast to uint8 101 | classes = compress_uint8(classes) 102 | 103 | return classes 104 | 105 | 106 | def _croptype_map( 107 | inputs: DataCube, 108 | temporal_extent: TemporalContext, 109 | croptype_parameters: "CropTypeParameters", 110 | postprocess_parameters: "PostprocessParameters", 111 | cropland_mask: Optional[DataCube] = None, 112 | ) -> DataCube: 113 | """Method to produce a croptype map from preprocessed inputs, using 114 | a Presto feature extractor and a CatBoost classifier. 115 | 116 | Parameters 117 | ---------- 118 | inputs : DataCube 119 | preprocessed input cube 120 | temporal_extent : TemporalContext 121 | temporal extent of the input cube 122 | croptype_parameters : CropTypeParameters 123 | Parameters for the croptype product inference pipeline 124 | cropland_mask : DataCube, optional 125 | optional cropland mask, by default None
126 | Returns 127 | ------- 128 | DataCube 129 | croptype labels and probability 130 | """ 131 | 132 | # Run feature computer 133 | features = apply_feature_extractor( 134 | feature_extractor_class=croptype_parameters.feature_extractor, 135 | cube=inputs, 136 | parameters=croptype_parameters.feature_parameters.model_dump(), 137 | size=[ 138 | {"dimension": "x", "unit": "px", "value": 128}, 139 | {"dimension": "y", "unit": "px", "value": 128}, 140 | ], 141 | overlap=[ 142 | {"dimension": "x", "unit": "px", "value": 0}, 143 | {"dimension": "y", "unit": "px", "value": 0}, 144 | ], 145 | ) 146 | 147 | # Run model inference on features 148 | parameters = croptype_parameters.classifier_parameters.model_dump( 149 | exclude=["classifier"] 150 | ) 151 | 152 | classes = apply_model_inference( 153 | model_inference_class=croptype_parameters.classifier, 154 | cube=features, 155 | parameters=parameters, 156 | size=[ 157 | {"dimension": "x", "unit": "px", "value": 128}, 158 | {"dimension": "y", "unit": "px", "value": 128}, 159 | {"dimension": "t", "value": "P1D"}, 160 | ], 161 | overlap=[ 162 | {"dimension": "x", "unit": "px", "value": 0}, 163 | {"dimension": "y", "unit": "px", "value": 0}, 164 | ], 165 | ) 166 | 167 | # Get rid of temporal dimension 168 | classes = classes.reduce_dimension(dimension="t", reducer="mean") 169 | 170 | # Mask cropland 171 | if cropland_mask is not None: 172 | classes = classes.mask(cropland_mask == 0, replacement=254) 173 | 174 | # Postprocess 175 | if postprocess_parameters.enable: 176 | if postprocess_parameters.save_intermediate: 177 | classes = classes.save_result( 178 | format="GTiff", 179 | options=dict( 180 | filename_prefix=f"{WorldCerealProductType.CROPTYPE.value}-raw_{temporal_extent.start_date}_{temporal_extent.end_date}" 181 | ), 182 | ) 183 | classes = _postprocess( 184 | classes, 185 | postprocess_parameters, 186 | classifier_url=croptype_parameters.classifier_parameters.classifier_url, 187 | ) 188 | 189 | # Cast to uint16 190 | classes = compress_uint16(classes) 191 | 192 | return classes 193 | 194 | 195 | def _postprocess( 196 | classes: DataCube, 197 | postprocess_parameters: "PostprocessParameters", 198 | classifier_url: Optional[str] = None, 199 | ) -> DataCube: 200 | """Method to postprocess the classes. 201 | 202 | Parameters 203 | ---------- 204 | classes : DataCube 205 | classes to postprocess 206 | postprocess_parameters : PostprocessParameters 207 | parameter class for postprocessing 208 | classifier_url : str, optional 209 | URL to the classifier model, passed on to the postprocessor
210 | Returns 211 | ------- 212 | DataCube 213 | postprocessed classes 214 | """ 215 | 216 | # Run postprocessing on the raw classification output 217 | # Note that this uses the `apply_model_inference` method even though 218 | # it is not truly model inference 219 | parameters = postprocess_parameters.model_dump(exclude=["postprocessor"]) 220 | parameters.update({"classifier_url": classifier_url}) 221 | 222 | postprocessed_classes = apply_model_inference( 223 | model_inference_class=postprocess_parameters.postprocessor, 224 | cube=classes, 225 | parameters=parameters, 226 | size=[ 227 | {"dimension": "x", "unit": "px", "value": 128}, 228 | {"dimension": "y", "unit": "px", "value": 128}, 229 | ], 230 | overlap=[ 231 | {"dimension": "x", "unit": "px", "value": 0}, 232 | {"dimension": "y", "unit": "px", "value": 0}, 233 | ], 234 | ) 235 | 236 | return postprocessed_classes 237 | -------------------------------------------------------------------------------- /src/worldcereal/openeo/masking.py: -------------------------------------------------------------------------------- 1 | from skimage.morphology import footprints 2 | 3 | 4 | def convolve(img, radius): 5 | """OpenEO method to apply convolution 6 | with a circular kernel of `radius` pixels. 7 | NOTE: make sure the resolution of the image 8 | matches the expected radius in pixels! 9 | """ 10 | kernel = footprints.disk(radius) 11 | img = img.apply_kernel(kernel) 12 | return img 13 | 14 | 15 | def scl_mask_erode_dilate( 16 | session, 17 | bbox, 18 | scl_layer_band="TERRASCOPE_S2_TOC_V2:SCL", 19 | erode_r=3, 20 | dilate_r=21, 21 | target_crs=None, 22 | ): 23 | """OpenEO method to construct a Sentinel-2 mask based on SCL. 24 | It involves an erosion step followed by a dilation step. 25 | 26 | Args: 27 | session (openeo.Connection): the openEO connection session 28 | scl_layer_band (str, optional): Which SCL band to use. 29 | Defaults to "TERRASCOPE_S2_TOC_V2:SCL". 30 | erode_r (int, optional): Erosion radius (pixels). Defaults to 3. 31 | dilate_r (int, optional): Dilation radius (pixels). Defaults to 21. 32 | 33 | Returns: 34 | DataCube: DataCube containing the resulting mask 35 | """ 36 | 37 | layer_band = scl_layer_band.split(":") 38 | s2_sceneclassification = session.load_collection( 39 | layer_band[0], bands=[layer_band[1]], spatial_extent=bbox 40 | ) 41 | 42 | classification = s2_sceneclassification.band(layer_band[1]) 43 | 44 | # Force to go to 10m resolution for controlled erosion/dilation 45 | classification = classification.resample_spatial( 46 | projection=target_crs, resolution=10.0 47 | ) 48 | 49 | first_mask = classification == 0 50 | for mask_value in [1, 3, 8, 9, 10, 11]: 51 | first_mask = (first_mask == 1) | (classification == mask_value) 52 | 53 | # Invert mask for erosion 54 | first_mask = first_mask.apply(lambda x: (x == 1).not_()) 55 | 56 | # Apply erosion by dilating the inverted mask 57 | erode_cube = convolve(first_mask, erode_r) 58 | 59 | # Invert again 60 | erode_cube = erode_cube > 0.1 61 | erode_cube = erode_cube.apply(lambda x: (x == 1).not_()) 62 | 63 | # Now dilate the mask 64 | dilate_cube = convolve(erode_cube, dilate_r) 65 | 66 | # Get binary mask. NOTE: >0.1 is a fix to avoid being triggered 67 | # by small non-zero oscillations after applying convolution 68 | dilate_cube = dilate_cube > 0.1 69 | 70 | return dilate_cube 71 |
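# --- Illustrative sketch (plain numpy/scikit-image, added example) ---
# scl_mask_erode_dilate() implements erosion with openEO primitives by
# inverting the mask, convolving (i.e. dilating the inverted mask) and
# inverting again. With scikit-image directly, the equivalence is:
if __name__ == "__main__":
    import numpy as np
    from skimage.morphology import binary_dilation, binary_erosion

    mask = np.zeros((9, 9), dtype=bool)
    mask[2:7, 2:7] = True
    selem = footprints.disk(1)
    eroded = binary_erosion(mask, selem)
    eroded_via_dilation = ~binary_dilation(~mask, selem)
    assert np.array_equal(eroded, eroded_via_dilation)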
-------------------------------------------------------------------------------- /src/worldcereal/openeo/udf_distance_to_cloud.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import xarray as xr 3 | from openeo.udf import XarrayDataCube 4 | from scipy.ndimage import distance_transform_cdt 5 | from skimage.morphology import binary_erosion, footprints 6 | 7 | 8 | def apply_datacube(cube: XarrayDataCube, context: dict) -> XarrayDataCube: 9 | cube_array: xr.DataArray = cube.get_array() 10 | cube_array = cube_array.transpose("bands", "y", "x") 11 | 12 | clouds: xr.DataArray = np.logical_or( 13 | np.logical_and(cube_array < 11, cube_array >= 8), cube_array == 3 14 | ).isel( 15 | bands=0 16 | )  # type: ignore 17 | 18 | # Calculate the Distance To Cloud score 19 | # Erode 20 | er = footprints.disk(3) 21 | 22 | # Define a function to apply binary erosion 23 | def erode(image, selem): 24 | return ~binary_erosion(image, selem) 25 | 26 | # Use apply_ufunc to apply the erosion operation 27 | eroded = xr.apply_ufunc( 28 | erode,  # function to apply 29 | clouds,  # input DataArray 30 | input_core_dims=[["y", "x"]],  # dimensions over which to apply function 31 | output_core_dims=[["y", "x"]],  # dimensions of the output 32 | vectorize=True,  # vectorize the function over non-core dimensions 33 | dask="parallelized",  # enable dask parallelization 34 | output_dtypes=[np.int32],  # data type of the output 35 | kwargs={"selem": er},  # additional keyword arguments to pass to erode 36 | ) 37 | 38 | # Distance to cloud in manhattan distance measure 39 | distance = xr.apply_ufunc( 40 | distance_transform_cdt, 41 | eroded, 42 | input_core_dims=[["y", "x"]], 43 | output_core_dims=[["y", "x"]], 44 | vectorize=True, 45 | dask="parallelized", 46 | output_dtypes=[np.int32], 47 | ) 48 | 49 | distance_da = xr.DataArray( 50 | distance, 51 | coords={ 52 | "y": cube_array.coords["y"], 53 | "x": cube_array.coords["x"], 54 | }, 55 | dims=["y", "x"], 56 | ) 57 | 58 | distance_da = distance_da.expand_dims( 59 | dim={ 60 | "bands": cube_array.coords["bands"], 61 | }, 62 | ) 63 | 64 | distance_da = distance_da.transpose("bands", "y", "x") 65 | 66 | return XarrayDataCube(distance_da) 67 | -------------------------------------------------------------------------------- /src/worldcereal/rdm_api/__init__.py: -------------------------------------------------------------------------------- 1 | """This sub-module contains utility functions and tools for worldcereal-classification""" 2 | 3 | from .rdm_collection import RdmCollection 4 | from .rdm_interaction import RdmInteraction 5 | 6 | __all__ = [ 7 | "RdmInteraction", 8 | "RdmCollection", 9 | ] 10 | -------------------------------------------------------------------------------- /src/worldcereal/rdm_api/rdm_collection.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | class RdmCollection: 5 | """Data class to host collections queried from the RDM API.""" 6 | 7 | def __init__(self, **metadata): 8 | """Initializes the RdmCollection object with metadata from the RDM API. 9 | Collection metadata is passed as keyword arguments and stored as attributes. 10 | All default collection metadata items are expected.
11 | 12 | RdmCollection can be initialized directly from the RDM API response resulting from 13 | a query to the collections endpoint (see get_collections function in rdm_interaction.py). 14 | """ 15 | # check presence of mandatory metadata items 16 | if not metadata.get("collectionId"): 17 | raise ValueError( 18 | "Collection ID is missing, cannot create RdmCollection object." 19 | ) 20 | if not metadata.get("accessType"): 21 | raise ValueError( 22 | "Access type is missing, cannot create RdmCollection object." 23 | ) 24 | 25 | # Get all metadata items 26 | self.id = metadata.get("collectionId") 27 | self.title = metadata.get("title") 28 | self.feature_count = metadata.get("featureCount") 29 | self.data_type = metadata.get("type") 30 | self.access_type = metadata.get("accessType") 31 | self.observation_method = metadata.get("typeOfObservationMethod") 32 | self.confidence_lc = metadata.get("confidenceLandCover") 33 | self.confidence_ct = metadata.get("confidenceCropType") 34 | self.confidence_irr = metadata.get("confidenceIrrigationType") 35 | self.ewoc_codes = metadata.get("ewocCodes") 36 | self.irr_codes = metadata.get("irrTypes") 37 | self.extent = metadata.get("extent") 38 | if self.extent: 39 | self.spatial_extent = self.extent["spatial"] 40 | self.temporal_extent = self.extent["temporal"]["interval"][0] 41 | else: 42 | self.spatial_extent = None 43 | self.temporal_extent = None 44 | self.additional_data = metadata.get("additionalData") 45 | self.crs = metadata.get("crs") 46 | self.last_modified = metadata.get("lastModificationTime") 47 | self.last_modified_by = metadata.get("lastModifierId") 48 | self.creation_time = metadata.get("creationTime") 49 | self.created_by = metadata.get("creatorId") 50 | self.fid = metadata.get("id") 51 | 52 | def print_metadata(self): 53 | 54 | print("#######################") 55 | print("Collection Metadata:") 56 | print(f"ID: {self.id}") 57 | print(f"Title: {self.title}") 58 | print(f"Number of samples: {self.feature_count}") 59 | print(f"Data type: {self.data_type}") 60 | print(f"Access type: {self.access_type}") 61 | print(f"Observation method: {self.observation_method}") 62 | print(f"Confidence score for land cover: {self.confidence_lc}") 63 | print(f"Confidence score for crop type: {self.confidence_ct}") 64 | print(f"Confidence score for irrigation label: {self.confidence_irr}") 65 | print(f"List of available crop types: {self.ewoc_codes}") 66 | print(f"List of available irrigation labels: {self.irr_codes}") 67 | print(f"Spatial extent: {self.spatial_extent}") 68 | print(f"Coordinate reference system (CRS): {self.crs}") 69 | print(f"Temporal extent: {self.temporal_extent}") 70 | print(f"Additional data: {self.additional_data}") 71 | print(f"Last modified: {self.last_modified}") 72 | print(f"Last modified by: {self.last_modified_by}") 73 | print(f"Creation time: {self.creation_time}") 74 | print(f"Created by: {self.created_by}") 75 | print(f"fid: {self.fid}") 76 | 77 | 78 | def visualize_spatial_extents(collections: List[RdmCollection]): 79 | """Visualizes the spatial extent of multiple collections on a map.""" 80 | 81 | from ipyleaflet import Map, Rectangle, basemaps 82 | 83 | if len(collections) == 1: 84 | zoom = 5 85 | colbbox = collections[0].spatial_extent.get("bbox", None) 86 | if colbbox is None: 87 | raise ValueError( 88 | f"No bounding box found for collection {collections[0].id}."
89 | ) 90 | colbbox = colbbox[0] 91 | # compute the center of the bounding box 92 | center = [(colbbox[1] + colbbox[3]) / 2, (colbbox[0] + colbbox[2]) / 2] 93 | else: 94 | zoom = 1 95 | center = [0, 0] 96 | 97 | # Create the basemap 98 | m = Map( 99 | basemap=basemaps.CartoDB.Positron, 100 | zoom=zoom, 101 | center=center, 102 | scroll_wheel_zoom=True, 103 | ) 104 | 105 | # Get the extent of each collection 106 | for col in collections: 107 | colbbox = col.spatial_extent.get("bbox", None) 108 | if colbbox is None: 109 | raise ValueError(f"No bounding box found for collection {col.id}.") 110 | colbbox = colbbox[0] 111 | bbox = [[colbbox[1], colbbox[0]], [colbbox[3], colbbox[2]]] 112 | 113 | # create a rectangle from the bounding box 114 | rectangle = Rectangle(bounds=bbox, color="green", weight=2, fill_opacity=0.1) 115 | 116 | # Add the rectangle to the map 117 | m.add_layer(rectangle) 118 | 119 | return m 120 | -------------------------------------------------------------------------------- /src/worldcereal/stac/__init__.py: -------------------------------------------------------------------------------- 1 | """STAC constants and utilities relative to WorldCereal intermediary and deliverable products""" 2 | -------------------------------------------------------------------------------- /src/worldcereal/stac/stac_api_interaction.py: -------------------------------------------------------------------------------- 1 | import concurrent 2 | from concurrent.futures import ThreadPoolExecutor 3 | from typing import Iterable 4 | 5 | import pystac 6 | import pystac_client 7 | import requests 8 | from openeo.rest.auth.oidc import ( 9 | OidcClientInfo, 10 | OidcProviderInfo, 11 | OidcResourceOwnerPasswordAuthenticator, 12 | ) 13 | from requests.auth import AuthBase 14 | 15 | 16 | class VitoStacApiAuthentication(AuthBase): 17 | """Class that handles authentication for the VITO STAC API. https://stac.openeo.vito.be/""" 18 | 19 | def __init__(self, **kwargs): 20 | self.username = kwargs.get("username") 21 | self.password = kwargs.get("password") 22 | 23 | def __call__(self, request): 24 | request.headers["Authorization"] = self.get_access_token() 25 | return request 26 | 27 | def get_access_token(self) -> str: 28 | """Get API bearer access token via password flow. 29 | 30 | Returns 31 | ------- 32 | str 33 | A string containing the bearer access token. 34 | """ 35 | provider_info = OidcProviderInfo( 36 | issuer="https://sso.terrascope.be/auth/realms/terrascope" 37 | ) 38 | 39 | client_info = OidcClientInfo( 40 | client_id="terracatalogueclient", 41 | provider=provider_info, 42 | ) 43 | 44 | if self.username and self.password: 45 | authenticator = OidcResourceOwnerPasswordAuthenticator( 46 | client_info=client_info, username=self.username, password=self.password 47 | ) 48 | else: 49 | raise ValueError( 50 | "Credentials are required to obtain an access token. Please set STAC_API_USERNAME and STAC_API_PASSWORD environment variables." 
51 | ) 52 | 53 | tokens = authenticator.get_tokens() 54 | 55 | return f"Bearer {tokens.access_token}" 56 | 57 | 58 | class StacApiInteraction: 59 | """Class that handles the interaction with a STAC API.""" 60 | 61 | _SENSOR_COLLECTION_CATALOG = { 62 | "Sentinel1": "worldcereal_sentinel_1_patch_extractions", 63 | "Sentinel2": "worldcereal_sentinel_2_patch_extractions", 64 | } 65 | 66 | def __init__( 67 | self, sensor: str, base_url: str, auth: AuthBase, bulk_size: int = 500 68 | ): 69 | if sensor not in self.catalog.keys(): 70 | raise ValueError( 71 | f"Invalid sensor '{sensor}'. Allowed values are: {', '.join(self.catalog.keys())}." 72 | ) 73 | self.sensor = sensor 74 | self.base_url = base_url 75 | self.collection_id = self.catalog[self.sensor] 76 | 77 | self.auth = auth 78 | 79 | self.bulk_size = bulk_size 80 | 81 | @property 82 | def catalog(self): 83 | return self._SENSOR_COLLECTION_CATALOG.copy() 84 | 85 | def exists(self) -> bool: 86 | client = pystac_client.Client.open(self.base_url) 87 | return ( 88 | len([c.id for c in client.get_collections() if c.id == self.collection_id]) 89 | > 0 90 | ) 91 | 92 | def _join_url(self, url_path: str) -> str: 93 | return str(self.base_url + "/" + url_path) 94 | 95 | def create_collection(self): 96 | spatial_extent = pystac.SpatialExtent([[-180, -90, 180, 90]]) 97 | temporal_extent = pystac.TemporalExtent([[None, None]]) 98 | extent = pystac.Extent(spatial=spatial_extent, temporal=temporal_extent) 99 | 100 | collection = pystac.Collection( 101 | id=self.collection_id, 102 | description=f"WorldCereal Patch Extractions for {self.sensor}", 103 | extent=extent, 104 | ) 105 | 106 | collection.validate() 107 | coll_dict = collection.to_dict() 108 | 109 | default_auth = { 110 | "_auth": { 111 | "read": ["anonymous"], 112 | "write": ["stac-openeo-admin", "stac-openeo-editor"], 113 | } 114 | } 115 | 116 | coll_dict.update(default_auth) 117 | 118 | response = requests.post( 119 | self._join_url("collections"), auth=self.auth, json=coll_dict 120 | ) 121 | 122 | expected_status = [ 123 | requests.status_codes.codes.ok, 124 | requests.status_codes.codes.created, 125 | requests.status_codes.codes.accepted, 126 | ] 127 | 128 | self._check_response_status(response, expected_status) 129 | 130 | return response 131 | 132 | def add_item(self, item: pystac.Item): 133 | if not self.exists(): 134 | self.create_collection() 135 | 136 | self._prepare_item(item) 137 | 138 | url_path = f"collections/{self.collection_id}/items" 139 | response = requests.post( 140 | self._join_url(url_path), auth=self.auth, json=item.to_dict() 141 | ) 142 | 143 | expected_status = [ 144 | requests.status_codes.codes.ok, 145 | requests.status_codes.codes.created, 146 | requests.status_codes.codes.accepted, 147 | ] 148 | 149 | self._check_response_status(response, expected_status) 150 | 151 | return response 152 | 153 | def _prepare_item(self, item: pystac.Item): 154 | item.collection_id = self.collection_id 155 | if not item.get_links(pystac.RelType.COLLECTION): 156 | item.add_link( 157 | pystac.Link(rel=pystac.RelType.COLLECTION, target=item.collection_id) 158 | ) 159 | 160 | def _ingest_bulk(self, items: Iterable[pystac.Item]) -> dict: 161 | if not all(i.collection_id == self.collection_id for i in items): 162 | raise Exception("All collection IDs should be identical for bulk ingests") 163 | 164 | url_path = f"collections/{self.collection_id}/bulk_items" 165 | data = { 166 | "method": "upsert", 167 | "items": {item.id: item.to_dict() for item in items}, 168 | } 169 | response = 
requests.post( 170 | url=self._join_url(url_path), auth=self.auth, json=data 171 | ) 172 | 173 | expected_status = [ 174 | requests.status_codes.codes.ok, 175 | requests.status_codes.codes.created, 176 | requests.status_codes.codes.accepted, 177 | ] 178 | 179 | self._check_response_status(response, expected_status) 180 | return response.json() 181 | 182 | def upload_items_bulk(self, items: Iterable[pystac.Item]) -> None: 183 | if not self.exists(): 184 | self.create_collection() 185 | 186 | chunk = [] 187 | futures = [] 188 | 189 | with ThreadPoolExecutor(max_workers=4) as executor: 190 | for item in items: 191 | self._prepare_item(item) 192 | chunk.append(item) 193 | 194 | if len(chunk) == self.bulk_size: 195 | futures.append(executor.submit(self._ingest_bulk, chunk.copy())) 196 | chunk = [] 197 | 198 | if chunk: 199 | self._ingest_bulk(chunk) 200 | 201 | for _ in concurrent.futures.as_completed(futures): 202 | continue 203 | 204 | def _check_response_status( 205 | self, response: requests.Response, expected_status_codes: list[int] 206 | ): 207 | if response.status_code not in expected_status_codes: 208 | message = ( 209 | f"Expecting HTTP status to be any of {expected_status_codes} " 210 | + f"but received {response.status_code} - {response.reason}, request method={response.request.method}\n" 211 | + f"response body:\n{response.text}" 212 | ) 213 | 214 | raise Exception(message) 215 | 216 | def get_collection_id(self) -> str: 217 | return self.collection_id 218 | -------------------------------------------------------------------------------- /src/worldcereal/utils/geoloader.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import geopandas as gpd 4 | import numpy as np 5 | import rasterio 6 | from loguru import logger 7 | from rasterio.crs import CRS 8 | from rasterio.enums import Resampling 9 | from rasterio.warp import reproject 10 | from shapely.geometry import Polygon 11 | 12 | 13 | def _load_array_bounds_latlon( 14 | fname, 15 | bounds, 16 | rio_gdal_options=None, 17 | boundless=True, 18 | fill_value=np.nan, 19 | nodata_value=None, 20 | ): 21 | rio_gdal_options = rio_gdal_options or {} 22 | 23 | with rasterio.Env(**rio_gdal_options): 24 | with rasterio.open(fname) as src: 25 | window = rasterio.windows.from_bounds(*bounds, src.transform) 26 | 27 | vals = np.array(window.flatten()) 28 | if (vals % 1 > 0).any(): 29 | col_off, row_off, width, height = ( 30 | window.col_off, 31 | window.row_off, 32 | window.width, 33 | window.height, 34 | ) 35 | 36 | width, height = int(math.ceil(width)), int(math.ceil(height)) 37 | col_off, row_off = int(math.floor(col_off + 0.5)), int( 38 | math.floor(row_off + 0.5) 39 | ) 40 | 41 | window = rasterio.windows.Window(col_off, row_off, width, height) 42 | 43 | if nodata_value is None: 44 | nodata_value = src.nodata if src.nodata is not None else np.nan 45 | 46 | if nodata_value is None and boundless is True: 47 | logger.warning( 48 | "Raster has no data value, defaulting boundless" 49 | " to False. Specify a nodata_value to read " 50 | "boundless." 
51 | ) 52 | boundless = False 53 | 54 | arr = src.read(window=window, boundless=boundless, fill_value=nodata_value) 55 | arr = arr.astype( 56 | np.float32 57 | )  # needed for reprojecting with bilinear resampling  # noqa:e501 58 | 59 | if nodata_value is not None: 60 | arr[arr == nodata_value] = fill_value 61 | 62 | arr = arr[0] 63 | return arr 64 | 65 | 66 | def load_reproject( 67 | filename, 68 | bounds, 69 | epsg, 70 | resolution=10, 71 | border_buff=0, 72 | fill_value=0, 73 | nodata_value=None, 74 | rio_gdal_options=None, 75 | resampling=Resampling.nearest, 76 | ): 77 | """ 78 | Read from latlon layer and reproject to UTM 79 | """ 80 | bbox = gpd.GeoSeries(Polygon.from_bounds(*bounds), crs=CRS.from_epsg(epsg)) 81 | 82 | bounds = bbox.buffer(border_buff * resolution).to_crs(epsg=4326).bounds.values[0] 83 | utm_bounds = bbox.buffer(border_buff * resolution).bounds.values[0].tolist() 84 | 85 | width = max(1, int((utm_bounds[2] - utm_bounds[0]) / resolution)) 86 | height = max(1, int((utm_bounds[3] - utm_bounds[1]) / resolution)) 87 | 88 | gim = _load_array_bounds_latlon( 89 | filename, 90 | bounds, 91 | rio_gdal_options=rio_gdal_options, 92 | fill_value=fill_value, 93 | nodata_value=nodata_value, 94 | ) 95 | 96 | src_crs = CRS.from_epsg(4326) 97 | dst_crs = CRS.from_epsg(bbox.crs.to_epsg()) 98 | 99 | src_transform = rasterio.transform.from_bounds(*bounds, gim.shape[1], gim.shape[0]) 100 | dst_transform = rasterio.transform.from_bounds(*utm_bounds, width, height) 101 | 102 | dst = np.zeros((height, width), dtype=np.float32) 103 | 104 | reproject( 105 | gim.astype(np.float32), 106 | dst, 107 | src_transform=src_transform, 108 | dst_transform=dst_transform, 109 | src_crs=src_crs, 110 | dst_crs=dst_crs, 111 | resampling=resampling, 112 | ) 113 | 114 | if border_buff > 0: 115 | dst = dst[border_buff:-border_buff, border_buff:-border_buff] 116 | 117 | return dst 118 | -------------------------------------------------------------------------------- /src/worldcereal/utils/models.py: -------------------------------------------------------------------------------- 1 | """Utilities around models for the WorldCereal package.""" 2 | 3 | import json 4 | from functools import lru_cache 5 | 6 | import onnxruntime as ort 7 | import requests 8 | 9 | 10 | @lru_cache(maxsize=2) 11 | def load_model_onnx(model_url) -> ort.InferenceSession: 12 | """Load an ONNX model from a URL. 13 | 14 | Parameters 15 | ---------- 16 | model_url: str 17 | URL to the ONNX model. 18 | 19 | Returns 20 | ------- 21 | ort.InferenceSession 22 | ONNX model loaded with ONNX runtime. 23 | """ 24 | # Two minutes timeout to download the model 25 | response = requests.get(model_url, timeout=120) 26 | model = response.content 27 | 28 | return ort.InferenceSession(model) 29 | 30 | 31 | def validate_cb_model(model_url: str) -> ort.InferenceSession: 32 | """Validate a CatBoost model by loading it and checking if the required 33 | metadata is present. Checks that the `class_names` and `class_to_label` 34 | fields are present in the `class_params` field of the custom metadata of 35 | the model. By default, the CatBoost module should include those fields 36 | when exporting a model to ONNX. 37 | 38 | Raises an exception if the model is not valid. 39 | 40 | Parameters 41 | ---------- 42 | model_url : str 43 | URL to the ONNX model. 44 | 45 | Returns 46 | ------- 47 | ort.InferenceSession 48 | ONNX model loaded with ONNX runtime. 49 | """ 50 | model = load_model_onnx(model_url=model_url) 51 | 52 | metadata = model.get_modelmeta().custom_metadata_map 53 | 54 | if "class_params" not in metadata: 55 | raise ValueError("Could not find class names in the model metadata.") 56 | 57 | class_params = json.loads(metadata["class_params"]) 58 | 59 | if "class_names" not in class_params: 60 | raise ValueError("Could not find class names in the model metadata.") 61 | 62 | if "class_to_label" not in class_params: 63 | raise ValueError("Could not find class to labels in the model metadata.") 64 | 65 | return model 66 | 67 | 68 | def load_model_lut(model_url: str) -> dict: 69 | """Load the class names to labels mapping from a CatBoost model. 70 | 71 | Parameters 72 | ---------- 73 | model_url : str 74 | URL to the ONNX model. 75 | 76 | Returns 77 | ------- 78 | dict 79 | Look-up table with class names and labels. 80 | """ 81 | model = validate_cb_model(model_url=model_url) 82 | metadata = model.get_modelmeta().custom_metadata_map 83 | class_params = json.loads(metadata["class_params"]) 84 | 85 | lut = dict(zip(class_params["class_names"], class_params["class_to_label"])) 86 | sorted_lut = {k: v for k, v in sorted(lut.items(), key=lambda item: item[1])} 87 | return sorted_lut 88 |
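# --- Illustrative sketch (hypothetical metadata, added example) ---
# The sorting step in load_model_lut() guarantees that class names line up
# with ascending label values, mirrored here on hypothetical class_params:
if __name__ == "__main__":
    class_params = {
        "class_names": ["maize", "other", "wheat"],
        "class_to_label": [2, 0, 1],
    }
    lut = dict(zip(class_params["class_names"], class_params["class_to_label"]))
    sorted_lut = {k: v for k, v in sorted(lut.items(), key=lambda item: item[1])}
    print(sorted_lut)  # -> {'other': 0, 'wheat': 1, 'maize': 2}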
49 | """ 50 | model = load_model_onnx(model_url=model_url) 51 | 52 | metadata = model.get_modelmeta().custom_metadata_map 53 | 54 | if "class_params" not in metadata: 55 | raise ValueError("Could not find class names in the model metadata.") 56 | 57 | class_params = json.loads(metadata["class_params"]) 58 | 59 | if "class_names" not in class_params: 60 | raise ValueError("Could not find class names in the model metadata.") 61 | 62 | if "class_to_label" not in class_params: 63 | raise ValueError("Could not find class to labels in the model metadata.") 64 | 65 | return model 66 | 67 | 68 | def load_model_lut(model_url: str) -> dict: 69 | """Load the class names to labels mapping from a CatBoost model. 70 | 71 | Parameters 72 | ---------- 73 | model_url : str 74 | URL to the ONNX model. 75 | 76 | Returns 77 | ------- 78 | dict 79 | Look-up table with class names and labels. 80 | """ 81 | model = validate_cb_model(model_url=model_url) 82 | metadata = model.get_modelmeta().custom_metadata_map 83 | class_params = json.loads(metadata["class_params"]) 84 | 85 | lut = dict(zip(class_params["class_names"], class_params["class_to_label"])) 86 | sorted_lut = {k: v for k, v in sorted(lut.items(), key=lambda item: item[1])} 87 | return sorted_lut 88 | -------------------------------------------------------------------------------- /src/worldcereal/utils/retry.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import random 3 | import time 4 | from functools import partial 5 | 6 | from loguru import logger 7 | 8 | 9 | def decorator(caller): 10 | """Turns caller into a decorator. 11 | Unlike decorator module, function signature is not preserved. 12 | :param caller: caller(f, *args, **kwargs) 13 | """ 14 | 15 | def decor(f): 16 | @functools.wraps(f) 17 | def wrapper(*args, **kwargs): 18 | return caller(f, *args, **kwargs) 19 | 20 | return wrapper 21 | 22 | return decor 23 | 24 | 25 | def __retry_internal( 26 | f, 27 | exceptions=Exception, 28 | tries=-1, 29 | delay=0, 30 | max_delay=None, 31 | backoff=1, 32 | jitter=0, 33 | logger=logger, 34 | ): 35 | """ 36 | Executes a function and retries it if it failed. 37 | :param f: the function to execute. 38 | :param exceptions: an exception or a tuple of exceptions to catch. default: Exception. 39 | :param tries: the maximum number of attempts. default: -1 (infinite). 40 | :param delay: initial delay between attempts. default: 0. 41 | :param max_delay: the maximum value of delay. default: None (no limit). 42 | :param backoff: multiplier applied to delay between attempts. default: 1 (no backoff). 43 | :param jitter: extra seconds added to delay between attempts. default: 0. 44 | fixed if a number, random if a range tuple (min, max) 45 | :param logger: logger.warning(fmt, error, delay) will be called on failed attempts. 46 | default: retry.logger. if None, logging is disabled. 47 | :returns: the result of the f function. 
48 | """ 49 | _tries, _delay = tries, delay 50 | while _tries: 51 | try: 52 | return f() 53 | except exceptions as e: 54 | _tries -= 1 55 | if not _tries: 56 | raise 57 | 58 | if logger is not None: 59 | logger.warning(f'Error "{e}", retrying in {_delay} seconds...') 60 | 61 | time.sleep(_delay) 62 | _delay *= backoff 63 | 64 | if isinstance(jitter, tuple): 65 | _delay += random.uniform(*jitter) 66 | else: 67 | _delay += jitter 68 | 69 | if max_delay is not None: 70 | _delay = min(_delay, max_delay) 71 | 72 | 73 | def retry( 74 | exceptions=Exception, 75 | tries=-1, 76 | delay=0, 77 | max_delay=None, 78 | backoff=1, 79 | jitter=0, 80 | logger=logger, 81 | ): 82 | """Returns a retry decorator. 83 | :param exceptions: an exception or a tuple of exceptions to catch. default: Exception. 84 | :param tries: the maximum number of attempts. default: -1 (infinite). 85 | :param delay: initial delay between attempts. default: 0. 86 | :param max_delay: the maximum value of delay. default: None (no limit). 87 | :param backoff: multiplier applied to delay between attempts. default: 1 (no backoff). 88 | :param jitter: extra seconds added to delay between attempts. default: 0. 89 | fixed if a number, random if a range tuple (min, max) 90 | :param logger: logger.warning(fmt, error, delay) will be called on failed attempts. 91 | default: retry.logger. if None, logging is disabled. 92 | :returns: a retry decorator. 93 | """ 94 | 95 | @decorator 96 | def retry_decorator(f, *fargs, **fkwargs): 97 | args = fargs if fargs else [] 98 | kwargs = fkwargs if fkwargs else {} 99 | return __retry_internal( 100 | partial(f, *args, **kwargs), 101 | exceptions, 102 | tries, 103 | delay, 104 | max_delay, 105 | backoff, 106 | jitter, 107 | logger, 108 | ) 109 | 110 | return retry_decorator 111 | 112 | 113 | def retry_call( 114 | f, 115 | fargs=None, 116 | fkwargs=None, 117 | exceptions=Exception, 118 | tries=-1, 119 | delay=0, 120 | max_delay=None, 121 | backoff=1, 122 | jitter=0, 123 | logger=logger, 124 | ): 125 | """ 126 | Calls a function and re-executes it if it failed. 127 | :param f: the function to execute. 128 | :param fargs: the positional arguments of the function to execute. 129 | :param fkwargs: the named arguments of the function to execute. 130 | :param exceptions: an exception or a tuple of exceptions to catch. default: Exception. 131 | :param tries: the maximum number of attempts. default: -1 (infinite). 132 | :param delay: initial delay between attempts. default: 0. 133 | :param max_delay: the maximum value of delay. default: None (no limit). 134 | :param backoff: multiplier applied to delay between attempts. default: 1 (no backoff). 135 | :param jitter: extra seconds added to delay between attempts. default: 0. 136 | fixed if a number, random if a range tuple (min, max) 137 | :param logger: logger.warning(fmt, error, delay) will be called on failed attempts. 138 | default: retry.logger. if None, logging is disabled. 139 | :returns: the result of the f function. 
140 | """ 141 | args = fargs if fargs else [] 142 | kwargs = fkwargs if fkwargs else {} 143 | return __retry_internal( 144 | partial(f, *args, **kwargs), 145 | exceptions, 146 | tries, 147 | delay, 148 | max_delay, 149 | backoff, 150 | jitter, 151 | logger, 152 | ) 153 | -------------------------------------------------------------------------------- /src/worldcereal/utils/spark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from loguru import logger 5 | 6 | 7 | def get_spark_context( 8 | name="WORLDCEREAL", localspark=False, threads="*", spark_version="3_2_0" 9 | ): 10 | """ 11 | Returns SparkContext for local run. 12 | if local is True, conf is ignored. 13 | 14 | Customized for VITO MEP 15 | """ 16 | if localspark: 17 | SPARK_HOME_2_0_0 = "/usr/hdp/current/spark2-client" 18 | SPARK_HOME_3_0_0 = "/opt/spark3_0_0" 19 | SPARK_HOME_3_2_0 = "/opt/spark3_2_0" 20 | 21 | SPARK_HOME = { 22 | "2_0_0": SPARK_HOME_2_0_0, 23 | "3_0_0": SPARK_HOME_3_0_0, 24 | "3_2_0": SPARK_HOME_3_2_0, 25 | } 26 | 27 | PY4J = { 28 | "2_0_0": "py4j-0.10.7", 29 | "3_0_0": "py4j-0.10.8.1", 30 | "3_2_0": "py4j-0.10.9.2", 31 | } 32 | 33 | SPARK_MAJOR_VERSION = {"2_0_0": "2", "3_0_0": "3", "3_2_0": "3"} 34 | 35 | spark_home = SPARK_HOME[spark_version] 36 | py4j_version = PY4J[spark_version] 37 | spark_major_version = SPARK_MAJOR_VERSION[spark_version] 38 | 39 | spark_py_path = [ 40 | f"{spark_home}/python", 41 | f"{spark_home}/python/lib/{py4j_version}-src.zip", 42 | ] 43 | 44 | env_vars = { 45 | "SPARK_MAJOR_VERSION": spark_major_version, 46 | "SPARK_HOME": spark_home, 47 | } 48 | for k, v in env_vars.items(): 49 | logger.info(f"Setting env var: {k}={v}") 50 | os.environ[k] = v 51 | 52 | logger.info(f"Prepending {spark_py_path} to PYTHONPATH") 53 | sys.path = spark_py_path + sys.path 54 | 55 | import py4j 56 | 57 | logger.info(f"py4j: {py4j.__file__}") 58 | 59 | import pyspark 60 | 61 | logger.info(f"pyspark: {pyspark.__file__}") 62 | 63 | import cloudpickle 64 | import pyspark.serializers 65 | from pyspark import SparkConf, SparkContext 66 | 67 | pyspark.serializers.cloudpickle = cloudpickle 68 | 69 | logger.info(f"Setting env var: PYSPARK_PYTHON={sys.executable}") 70 | os.environ["PYSPARK_PYTHON"] = sys.executable 71 | 72 | conf = SparkConf() 73 | conf.setMaster(f"local[{threads}]") 74 | conf.set("spark.driver.bindAddress", "127.0.0.1") 75 | 76 | sc = SparkContext(conf=conf) 77 | 78 | else: 79 | import cloudpickle 80 | import pyspark.serializers 81 | from pyspark.sql import SparkSession 82 | 83 | pyspark.serializers.cloudpickle = cloudpickle 84 | 85 | spark = SparkSession.builder.appName(name).getOrCreate() 86 | sc = spark.sparkContext 87 | 88 | return sc 89 | -------------------------------------------------------------------------------- /tests/pre_test_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Install git 4 | dnf install git -y 5 | 6 | # # Install openeo-gfmap and presto-worldcereal 7 | # dir=$(pwd) 8 | # GFMAP_URL="https://github.com/Open-EO/openeo-gfmap.git" 9 | # PRESTO_URL="https://github.com/WorldCereal/presto-worldcereal.git" 10 | 11 | # su - jenkins -c "cd $dir && \ 12 | # source venv310/bin/activate && \ 13 | # git clone $GFMAP_URL && \ 14 | # cd openeo-gfmap || { echo 'Directory not found! Exiting...'; exit 1; } && \ 15 | # pip install . && \ 16 | # cd .. 17 | # git clone -b croptype $PRESTO_URL && \ 18 | # cd presto-worldcereal || { echo 'Directory not found! 
Exiting...'; exit 1; } && \ 19 | # pip install . 20 | # " 21 | 22 | # For now only presto-worldcereal as gfmap is up to date on pypi 23 | dir=$(pwd) 24 | PRESTO_URL="https://github.com/WorldCereal/presto-worldcereal.git" 25 | 26 | su - jenkins -c "cd $dir && \ 27 | source venv310/bin/activate && \ 28 | git clone -b croptype $PRESTO_URL && \ 29 | cd presto-worldcereal || { echo 'Directory not found! Exiting...'; exit 1; } && \ 30 | pip install . 31 | " -------------------------------------------------------------------------------- /tests/worldcerealtests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/tests/worldcerealtests/__init__.py -------------------------------------------------------------------------------- /tests/worldcerealtests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import geojson 5 | import pandas as pd 6 | import pytest 7 | import xarray as xr 8 | 9 | 10 | def get_test_resource(relative_path): 11 | dir = Path(os.path.dirname(os.path.realpath(__file__))) 12 | return dir / "testresources" / relative_path 13 | 14 | 15 | @pytest.fixture 16 | def WorldCerealPreprocessedInputs(): 17 | filepath = get_test_resource("worldcereal_preprocessed_inputs.nc") 18 | arr = ( 19 | xr.open_dataset(filepath) 20 | .to_array(dim="bands") 21 | .drop_sel(bands="crs") 22 | .astype("uint16") 23 | ) 24 | return arr 25 | 26 | 27 | @pytest.fixture 28 | def SpatialExtent(): 29 | filepath = get_test_resource("spatial_extent.json") 30 | with open(filepath, "r") as f: 31 | return geojson.load(f) 32 | 33 | 34 | @pytest.fixture 35 | def WorldCerealCroplandClassification(): 36 | filepath = get_test_resource("worldcereal_cropland_classification.nc") 37 | arr = xr.open_dataarray(filepath).astype("uint16") 38 | return arr 39 | 40 | 41 | @pytest.fixture 42 | def WorldCerealCroptypeClassification(): 43 | filepath = get_test_resource("worldcereal_croptype_classification.nc") 44 | arr = xr.open_dataarray(filepath).astype("uint16") 45 | return arr 46 | 47 | 48 | @pytest.fixture 49 | def WorldCerealExtractionsDF(): 50 | filepath = get_test_resource("test_public_extractions.parquet") 51 | return pd.read_parquet(filepath) 52 | 53 | 54 | @pytest.fixture 55 | def WorldCerealPrivateExtractionsPath(): 56 | filepath = get_test_resource("worldcereal_private_extractions_dummy.parquet") 57 | return filepath 58 | -------------------------------------------------------------------------------- /tests/worldcerealtests/test_feature_extractor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import xarray as xr 3 | 4 | from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor 5 | 6 | 7 | def test_slope_computation(): 8 | test_elevation = np.array( 9 | [[10, 20, 30, 30], [10, 20, 20, 20], [65535, 20, 20, 20]], dtype=np.uint16 10 | ) 11 | 12 | array = xr.DataArray( 13 | test_elevation[None, :, :], 14 | dims=["bands", "y", "x"], 15 | coords={ 16 | "bands": ["elevation"], 17 | "x": [ 18 | 71.2302216215233, 19 | 71.23031145305171, 20 | 71.23040128458014, 21 | 71.23040128458014, 22 | ], 23 | "y": [25.084450211061935, 25.08436885206669, 25.084287493017356], 24 | }, 25 | ) 26 | 27 | extractor = PrestoFeatureExtractor() 28 | extractor._epsg = 4326 # pylint: disable=protected-access 29 | 30 | # In the UDF 
no_data is set to 65535 31 | resolution = extractor.evaluate_resolution(array) 32 | slope = extractor.compute_slope( 33 | array, resolution 34 | ).values # pylint: disable=protected-access 35 | 36 | assert slope[0, -1, 0] == 65535 37 | assert resolution == 10 38 | -------------------------------------------------------------------------------- /tests/worldcerealtests/test_inference.py: -------------------------------------------------------------------------------- 1 | from openeo_gfmap.features.feature_extractor import ( 2 | EPSG_HARMONIZED_NAME, 3 | apply_feature_extractor_local, 4 | ) 5 | from openeo_gfmap.inference.model_inference import apply_model_inference_local 6 | 7 | from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor 8 | from worldcereal.openeo.inference import CropClassifier 9 | from worldcereal.parameters import CropLandParameters, CropTypeParameters 10 | from worldcereal.utils.models import load_model_lut 11 | 12 | 13 | def test_cropland_inference(WorldCerealPreprocessedInputs): 14 | """Test the local generation of a cropland product""" 15 | 16 | print("Get Presto cropland features") 17 | cropland_features = apply_feature_extractor_local( 18 | PrestoFeatureExtractor, 19 | WorldCerealPreprocessedInputs, 20 | parameters={ 21 | EPSG_HARMONIZED_NAME: 32631, 22 | "ignore_dependencies": True, 23 | "compile_presto": False, 24 | "use_valid_date_token": False, 25 | "presto_model_url": CropLandParameters().feature_parameters.presto_model_url, 26 | }, 27 | ) 28 | 29 | print("Running cropland classification inference UDF locally") 30 | 31 | lookup_table = load_model_lut( 32 | CropLandParameters().classifier_parameters.classifier_url 33 | ) 34 | 35 | cropland_classification = apply_model_inference_local( 36 | CropClassifier, 37 | cropland_features, 38 | parameters={ 39 | EPSG_HARMONIZED_NAME: 32631, 40 | "ignore_dependencies": True, 41 | "lookup_table": lookup_table, 42 | "classifier_url": CropLandParameters().classifier_parameters.classifier_url, 43 | }, 44 | ) 45 | 46 | assert list(cropland_classification.bands.values) == [ 47 | "classification", 48 | "probability", 49 | "probability_other", 50 | "probability_cropland", 51 | ] 52 | assert cropland_classification.sel(bands="classification").values.max() <= 1 53 | assert cropland_classification.sel(bands="classification").values.min() >= 0 54 | assert cropland_classification.sel(bands="probability").values.max() <= 100 55 | assert cropland_classification.sel(bands="probability").values.min() >= 0 56 | assert cropland_classification.shape == (4, 100, 100) 57 | 58 | 59 | def test_croptype_inference(WorldCerealPreprocessedInputs): 60 | """Test the local generation of a croptype product""" 61 | 62 | print("Get Presto croptype features") 63 | croptype_features = apply_feature_extractor_local( 64 | PrestoFeatureExtractor, 65 | WorldCerealPreprocessedInputs, 66 | parameters={ 67 | EPSG_HARMONIZED_NAME: 32631, 68 | "ignore_dependencies": True, 69 | "compile_presto": False, 70 | "use_valid_date_token": True, 71 | "presto_model_url": CropTypeParameters().feature_parameters.presto_model_url, 72 | }, 73 | ) 74 | 75 | print("Running croptype classification inference UDF locally") 76 | 77 | lookup_table = load_model_lut( 78 | CropTypeParameters().classifier_parameters.classifier_url 79 | ) 80 | 81 | croptype_classification = apply_model_inference_local( 82 | CropClassifier, 83 | croptype_features, 84 | parameters={ 85 | EPSG_HARMONIZED_NAME: 32631, 86 | "ignore_dependencies": True, 87 | "lookup_table": lookup_table, 88 | 
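# Note: the lookup_table above and the classifier_url below are derived from the same CropTypeParameters instance, so class indices and class names stay consistent.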
"classifier_url": CropTypeParameters().classifier_parameters.classifier_url, 89 | }, 90 | ) 91 | 92 | assert list(croptype_classification.bands.values) == [ 93 | "classification", 94 | "probability", 95 | "probability_barley", 96 | "probability_maize", 97 | "probability_millet_sorghum", 98 | "probability_other_crop", 99 | "probability_rapeseed_rape", 100 | "probability_soy_soybeans", 101 | "probability_sunflower", 102 | "probability_wheat", 103 | ] 104 | 105 | # First assert below depends on the amount of classes in the model 106 | assert croptype_classification.sel(bands="classification").values.max() <= 7 107 | assert croptype_classification.sel(bands="classification").values.min() >= 0 108 | assert croptype_classification.sel(bands="probability").values.max() <= 100 109 | assert croptype_classification.sel(bands="probability").values.min() >= 0 110 | assert croptype_classification.shape == (10, 100, 100) 111 | -------------------------------------------------------------------------------- /tests/worldcerealtests/test_pipelines.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from catboost import CatBoostClassifier, Pool 4 | from loguru import logger 5 | from openeo_gfmap import BoundingBoxExtent, TemporalContext 6 | from presto.presto import Presto 7 | from sklearn.model_selection import train_test_split 8 | from sklearn.utils.class_weight import compute_class_weight 9 | 10 | from worldcereal.parameters import CropTypeParameters 11 | from worldcereal.train.data import WorldCerealTrainingDataset, get_training_df 12 | from worldcereal.utils.refdata import ( 13 | process_extractions_df, 14 | query_private_extractions, 15 | query_public_extractions, 16 | ) 17 | 18 | SPATIAL_EXTENT = BoundingBoxExtent( 19 | west=4.63761, south=51.11649, east=4.73761, north=51.21649, epsg=4326 20 | ).to_geometry() 21 | 22 | 23 | def test_custom_croptype_demo(WorldCerealPrivateExtractionsPath): 24 | """Test for a full custom croptype pipeline up and till the point of 25 | training a custom catboost classifier. This test uses both public and private 26 | extractions. 27 | """ 28 | 29 | # Query private and public extractions 30 | 31 | private_df = query_private_extractions( 32 | WorldCerealPrivateExtractionsPath, 33 | bbox_poly=SPATIAL_EXTENT, 34 | filter_cropland=True, 35 | buffer=250000, # Meters 36 | ) 37 | 38 | assert not private_df.empty, "Should have found private extractions" 39 | logger.info( 40 | ( 41 | f"Found {private_df['sample_id'].nunique()} unique samples in the " 42 | f"private data, spread across {private_df['ref_id'].nunique()} " 43 | "unique reference datasets." 44 | ) 45 | ) 46 | 47 | public_df = query_public_extractions( 48 | SPATIAL_EXTENT, 49 | buffer=1000, # Meters 50 | filter_cropland=True, 51 | ) 52 | 53 | assert not public_df.empty, "Should have found public extractions" 54 | logger.info( 55 | ( 56 | f"Found {public_df['sample_id'].nunique()} unique samples in the " 57 | f"public data, spread across {public_df['ref_id'].nunique()} " 58 | "unique reference datasets." 
59 | ) 60 | ) 61 | 62 | # Concatenate extractions 63 | extractions_df = pd.concat([private_df, public_df]) 64 | 65 | assert len(extractions_df) == len(public_df) + len(private_df) 66 | 67 | # Process the merged data 68 | processing_period = TemporalContext("2020-01-01", "2020-12-31") 69 | print(f"Shape of extractions_df: {extractions_df.shape}") 70 | training_df = process_extractions_df(extractions_df, processing_period) 71 | logger.info(f"training_df shape: {training_df.shape}") 72 | 73 | # Drop labels that occur infrequently for this test 74 | value_counts = training_df["ewoc_code"].value_counts() 75 | single_labels = value_counts[value_counts < 3].index.to_list() 76 | training_df = training_df[~training_df["ewoc_code"].isin(single_labels)] 77 | 78 | print("*" * 40) 79 | for c in training_df.columns: 80 | print(c) 81 | print("*" * 40) 82 | 83 | # Direct shape assert: if process_extractions_df changes, this may have to be updated 84 | assert training_df.shape == (238, 246) 85 | 86 | # We keep original ewoc_code for this test 87 | training_df["downstream_class"] = training_df["ewoc_code"] 88 | 89 | # Compute presto embeddings 90 | presto_model_url = CropTypeParameters().feature_parameters.presto_model_url 91 | 92 | # Load pretrained Presto model 93 | logger.info(f"Presto URL: {presto_model_url}") 94 | presto_model = Presto.load_pretrained( 95 | presto_model_url, 96 | from_url=True, 97 | strict=False, 98 | valid_month_as_token=True, 99 | ) 100 | 101 | # Initialize dataset 102 | df = training_df.reset_index() 103 | ds = WorldCerealTrainingDataset( 104 | df, 105 | task_type="croptype", 106 | augment=True, 107 | mask_ratio=0.30, 108 | repeats=1, 109 | ) 110 | logger.info("Computing Presto embeddings ...") 111 | df = get_training_df( 112 | ds, 113 | presto_model, 114 | batch_size=256, 115 | valid_date_as_token=True, 116 | ) 117 | logger.info("Presto embeddings computed.") 118 | 119 | # Train classifier 120 | logger.info("Split train/test ...") 121 | samples_train, samples_test = train_test_split( 122 | df, 123 | test_size=0.2, 124 | random_state=3, 125 | stratify=df["downstream_class"], 126 | ) 127 | 128 | eval_metric = "MultiClass" 129 | loss_function = "MultiClass" 130 | 131 | logger.info("Computing class weights ...") 132 | class_weights = np.round( 133 | compute_class_weight( 134 | class_weight="balanced", 135 | classes=np.unique(samples_train["downstream_class"]), 136 | y=samples_train["downstream_class"], 137 | ), 138 | 3, 139 | ) 140 | class_weights = { 141 | k: v 142 | for k, v in zip(np.unique(samples_train["downstream_class"]), class_weights) 143 | } 144 | logger.info(f"Class weights: {class_weights}") 145 | 146 | sample_weights = np.ones((len(samples_train["downstream_class"]),)) 147 | sample_weights_val = np.ones((len(samples_test["downstream_class"]),)) 148 | for k, v in class_weights.items(): 149 | sample_weights[samples_train["downstream_class"] == k] = v 150 | sample_weights_val[samples_test["downstream_class"] == k] = v 151 | samples_train["weight"] = sample_weights 152 | samples_test["weight"] = sample_weights_val 153 | 154 | # Define classifier 155 | custom_downstream_model = CatBoostClassifier( 156 | iterations=2000, # Not too high to avoid too large model size 157 | depth=8, 158 | early_stopping_rounds=20, 159 | loss_function=loss_function, 160 | eval_metric=eval_metric, 161 | random_state=3, 162 | verbose=25, 163 | class_names=np.unique(samples_train["downstream_class"]), 164 | ) 165 | 166 | # Setup dataset Pool 167 | bands = [f"presto_ft_{i}" for i in range(128)] 168 | 
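# The CatBoost Pool objects below bundle features, labels and per-sample weights; eval_data is passed as eval_set so early stopping can monitor validation loss.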
calibration_data = Pool( 169 | data=samples_train[bands], 170 | label=samples_train["downstream_class"], 171 | weight=samples_train["weight"], 172 | ) 173 | eval_data = Pool( 174 | data=samples_test[bands], 175 | label=samples_test["downstream_class"], 176 | weight=samples_test["weight"], 177 | ) 178 | 179 | # Train classifier 180 | logger.info("Training CatBoost classifier ...") 181 | custom_downstream_model.fit( 182 | calibration_data, 183 | eval_set=eval_data, 184 | ) 185 | 186 | # Make predictions 187 | _ = custom_downstream_model.predict(samples_test[bands]).flatten() 188 | -------------------------------------------------------------------------------- /tests/worldcerealtests/test_postprocessing.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from openeo_gfmap.inference.model_inference import ( 3 | EPSG_HARMONIZED_NAME, 4 | apply_model_inference_local, 5 | ) 6 | 7 | from worldcereal.openeo.postprocess import PostProcessor 8 | from worldcereal.parameters import CropLandParameters, PostprocessParameters 9 | 10 | 11 | def test_cropland_postprocessing(WorldCerealCroplandClassification): 12 | """Test the local postprocessing of a cropland product""" 13 | 14 | print("Postprocessing cropland product ...") 15 | _ = apply_model_inference_local( 16 | PostProcessor, 17 | WorldCerealCroplandClassification, 18 | parameters={ 19 | "ignore_dependencies": True, 20 | EPSG_HARMONIZED_NAME: None, 21 | "classifier_url": CropLandParameters().classifier_parameters.classifier_url, 22 | "method": "smooth_probabilities", 23 | }, 24 | ) 25 | 26 | 27 | def test_cropland_postprocessing_majority_vote(WorldCerealCroplandClassification): 28 | """Test the local postprocessing of a cropland product""" 29 | 30 | print("Postprocessing cropland product ...") 31 | _ = apply_model_inference_local( 32 | PostProcessor, 33 | WorldCerealCroplandClassification, 34 | parameters={ 35 | "ignore_dependencies": True, 36 | EPSG_HARMONIZED_NAME: None, 37 | "classifier_url": CropLandParameters().classifier_parameters.classifier_url, 38 | "method": "majority_vote", 39 | "kernel_size": 7, 40 | }, 41 | ) 42 | 43 | 44 | def test_croptype_postprocessing(WorldCerealCroptypeClassification): 45 | """Test the local postprocessing of a croptype product""" 46 | 47 | print("Postprocessing croptype product ...") 48 | _ = apply_model_inference_local( 49 | PostProcessor, 50 | WorldCerealCroptypeClassification, 51 | parameters={ 52 | "ignore_dependencies": True, 53 | EPSG_HARMONIZED_NAME: None, 54 | "classifier_url": CropLandParameters().classifier_parameters.classifier_url, 55 | "method": "smooth_probabilities", 56 | }, 57 | ) 58 | 59 | 60 | def test_croptype_postprocessing_majority_vote(WorldCerealCroptypeClassification): 61 | """Test the local postprocessing of a croptype product""" 62 | 63 | print("Postprocessing croptype product ...") 64 | _ = apply_model_inference_local( 65 | PostProcessor, 66 | WorldCerealCroptypeClassification, 67 | parameters={ 68 | "ignore_dependencies": True, 69 | EPSG_HARMONIZED_NAME: None, 70 | "classifier_url": CropLandParameters().classifier_parameters.classifier_url, 71 | "method": "majority_vote", 72 | "kernel_size": 7, 73 | }, 74 | ) 75 | 76 | 77 | def test_postprocessing_parameters(): 78 | """Test the postprocessing parameters.""" 79 | 80 | # This set should work 81 | params = { 82 | "enable": True, 83 | "method": "smooth_probabilities", 84 | "kernel_size": 5, 85 | "save_intermediate": False, 86 | "keep_class_probs": False, 87 | } 88 | 
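# Instantiating PostprocessParameters runs its validation; the invalid combinations further below are expected to raise ValueError.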
PostprocessParameters(**params) 89 | 90 | # This one as well 91 | params["method"] = "majority_vote" 92 | PostprocessParameters(**params) 93 | 94 | # This one should fail with invalid kernel size 95 | params["kernel_size"] = 30 96 | with pytest.raises(ValueError): 97 | PostprocessParameters(**params) 98 | 99 | # This one should fail with invalid method 100 | params["method"] = "test" 101 | with pytest.raises(ValueError): 102 | PostprocessParameters(**params) 103 | 104 | # This one should fail with invalid save_intermediate 105 | params["enable"] = False 106 | params["save_intermediate"] = True 107 | params["method"] = "smooth_probabilities" 108 | with pytest.raises(ValueError): 109 | PostprocessParameters(**params) 110 | -------------------------------------------------------------------------------- /tests/worldcerealtests/test_preprocessing.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | 5 | import pytest 6 | from openeo_gfmap import BoundingBoxExtent, FetchType 7 | from openeo_gfmap.backend import ( 8 | Backend, 9 | BackendContext, 10 | cdse_connection, 11 | vito_connection, 12 | ) 13 | from openeo_gfmap.temporal import TemporalContext 14 | 15 | from worldcereal.extract.patch_to_point_worldcereal import ( 16 | worldcereal_preprocessed_inputs_from_patches, 17 | ) 18 | from worldcereal.openeo.preprocessing import ( 19 | InvalidTemporalContextError, 20 | _validate_temporal_context, 21 | correct_temporal_context, 22 | worldcereal_preprocessed_inputs, 23 | ) 24 | 25 | basedir = Path(os.path.dirname(os.path.realpath(__file__))) 26 | 27 | 28 | def test_temporal_context_validation(): 29 | """Test the validation of temporal context.""" 30 | 31 | temporal_context = TemporalContext("2020-01-01", "2020-12-31") 32 | _validate_temporal_context(temporal_context) 33 | 34 | incorrect_temporal_context = TemporalContext("2020-01-05", "2020-03-15") 35 | 36 | with pytest.raises(InvalidTemporalContextError): 37 | _validate_temporal_context(incorrect_temporal_context) 38 | 39 | more_than_one_year = TemporalContext("2019-01-05", "2021-03-15") 40 | 41 | with pytest.raises(InvalidTemporalContextError): 42 | _validate_temporal_context(more_than_one_year) 43 | 44 | 45 | def test_temporal_context_correction(): 46 | """Test the automatic correction of invalid temporal context.""" 47 | 48 | incorrect_temporal_context = TemporalContext("2022-01-05", "2020-03-15") 49 | corrected_temporal_context = correct_temporal_context(incorrect_temporal_context) 50 | 51 | # Should no longer raise an exception 52 | _validate_temporal_context(corrected_temporal_context) 53 | 54 | 55 | def test_worldcereal_preprocessed_inputs_graph(SpatialExtent): 56 | """Test the worldcereal_preprocessed_inputs function. 
57 | This is based on constructing the openEO graph for the job 58 | that would run, without actually running it.""" 59 | 60 | temporal_extent = TemporalContext("2020-06-01", "2021-05-31") 61 | 62 | cube = worldcereal_preprocessed_inputs( 63 | connection=cdse_connection(), 64 | backend_context=BackendContext(Backend.CDSE), 65 | spatial_extent=SpatialExtent, 66 | temporal_extent=temporal_extent, 67 | fetch_type=FetchType.POLYGON, 68 | ) 69 | 70 | # Ref file with processing graph 71 | ref_graph = basedir / "testresources" / "preprocess_graph.json" 72 | 73 | # # uncomment to save current graph to the ref file 74 | # with open(ref_graph, "w") as f: 75 | # f.write(json.dumps(cube.flat_graph(), indent=4)) 76 | 77 | with open(ref_graph, "r") as f: 78 | expected = json.load(f) 79 | assert expected == cube.flat_graph() 80 | 81 | 82 | def test_worldcereal_preprocessed_inputs_graph_withslope(): 83 | """This version has fetchtype.TILE and should include slope.""" 84 | 85 | temporal_extent = TemporalContext("2018-03-01", "2019-02-28") 86 | 87 | cube = worldcereal_preprocessed_inputs( 88 | connection=cdse_connection(), 89 | backend_context=BackendContext(Backend.CDSE), 90 | spatial_extent=BoundingBoxExtent( 91 | west=44.432274, south=51.317362, east=44.698802, north=51.428224, epsg=4326 92 | ), 93 | temporal_extent=temporal_extent, 94 | ) 95 | 96 | # Ref file with processing graph 97 | ref_graph = basedir / "testresources" / "preprocess_graphwithslope.json" 98 | 99 | # # uncomment to save current graph to the ref file 100 | # with open(ref_graph, "w") as f: 101 | # f.write(json.dumps(cube.flat_graph(), indent=4)) 102 | 103 | with open(ref_graph, "r") as f: 104 | expected = json.load(f) 105 | assert expected == cube.flat_graph() 106 | 107 | 108 | def test_worldcereal_preprocessed_inputs_from_patches_graph(): 109 | """This version gets a preprocessed cube from extracted patches.""" 110 | 111 | temporal_extent = TemporalContext("2020-01-01", "2020-12-31") 112 | 113 | cube = worldcereal_preprocessed_inputs_from_patches( 114 | connection=vito_connection(), 115 | temporal_extent=temporal_extent, 116 | ref_id="test_ref_id", 117 | epsg=32631, 118 | ) 119 | 120 | # Ref file with processing graph 121 | ref_graph = basedir / "testresources" / "preprocess_from_patches_graph.json" 122 | 123 | # # uncomment to save current graph to the ref file 124 | # with open(ref_graph, "w") as f: 125 | # f.write(json.dumps(cube.flat_graph(), indent=4)) 126 | 127 | with open(ref_graph, "r") as f: 128 | expected = json.load(f) 129 | assert expected == cube.flat_graph() 130 | -------------------------------------------------------------------------------- /tests/worldcerealtests/test_rdm_interaction.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | 3 | import geopandas as gpd 4 | import pytest 5 | from openeo_gfmap import BoundingBoxExtent, TemporalContext 6 | from shapely import Point 7 | 8 | from worldcereal.rdm_api.rdm_interaction import RdmCollection, RdmInteraction 9 | 10 | 11 | @pytest.fixture 12 | def sample_bbox(): 13 | return BoundingBoxExtent(west=0, east=1, north=1, south=0) 14 | 15 | 16 | @pytest.fixture 17 | def sample_temporal_extent(): 18 | return TemporalContext(start_date="2021-01-01", end_date="2021-12-31") 19 | 20 | 21 | class TestRdmInteraction: 22 | @patch("requests.Session.get") 23 | def test_collections_from_rdm( 24 | self, mock_requests_get, sample_bbox, sample_temporal_extent 25 | ): 26 | 27 | 
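# Mock the RDM collections search response with two public collections, so the assertions below can verify both ID extraction and the exact query URL.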
mock_requests_get.return_value.status_code = 200 28 | mock_requests_get.return_value.json.return_value = [ 29 | { 30 | "collectionId": "Foo", 31 | "title": "Foo_title", 32 | "accessType": "Public", 33 | }, 34 | { 35 | "collectionId": "Bar", 36 | "title": "Bar_title", 37 | "accessType": "Public", 38 | }, 39 | ] 40 | interaction = RdmInteraction() 41 | collections = interaction.get_collections( 42 | spatial_extent=sample_bbox, temporal_extent=sample_temporal_extent 43 | ) 44 | ref_ids = [collection.id for collection in collections] 45 | 46 | assert ref_ids == ["Foo", "Bar"] 47 | 48 | bbox_str = f"Bbox={sample_bbox.west}&Bbox={sample_bbox.south}&Bbox={sample_bbox.east}&Bbox={sample_bbox.north}" 49 | temporal = f"&ValidityTime.Start={sample_temporal_extent.start_date}T00%3A00%3A00Z&ValidityTime.End={sample_temporal_extent.end_date}T00%3A00%3A00Z" 50 | expected_url = ( 51 | f"{interaction.RDM_ENDPOINT}/collections/search?{bbox_str}{temporal}" 52 | ) 53 | mock_requests_get.assert_called_with( 54 | url=expected_url, headers={"accept": "*/*"}, timeout=10 55 | ) 56 | 57 | @patch("worldcereal.rdm_api.rdm_interaction.RdmInteraction.get_collections") 58 | @patch("worldcereal.rdm_api.rdm_interaction.RdmInteraction._get_download_urls") 59 | def test_download_samples( 60 | self, 61 | mock_get_download_urls, 62 | mock_collections_from_rdm, 63 | sample_bbox, 64 | sample_temporal_extent, 65 | tmp_path, 66 | ): 67 | 68 | data = { 69 | "col1": ["must", "include", "this", "column", "definitely", "check"], 70 | "col2": ["and", "this", "One", "Too", "please", "check"], 71 | "col3": ["but", "not", "This", "One", "please", "check"], 72 | "valid_time": [ 73 | "2021-01-01", 74 | "2021-12-31", 75 | "2021-06-01", 76 | "2025-05-22", 77 | "2021-06-01", 78 | "2021-06-01", 79 | ], # Fourth date not within sample_temporal_extent 80 | "ewoc_code": ["1", "2", "3", "4", "5", "1"], 81 | # Fifth crop code not within list of ewoc_codes 82 | "extract": [1, 1, 1, 1, 2, 0], 83 | "geometry": [ 84 | Point(0.5, 0.5), 85 | Point(0.25, 0.25), 86 | Point(2, 3), 87 | Point(0.75, 0.75), 88 | Point(0.75, 0.78), 89 | Point(0.78, 0.75), 90 | ], # Third point not within sample_bbox 91 | } 92 | gdf = gpd.GeoDataFrame(data, crs="EPSG:4326") 93 | file_path = tmp_path / "sample.parquet" 94 | gdf.to_parquet(file_path) 95 | 96 | mock_collections_from_rdm.return_value = [ 97 | RdmCollection( 98 | **{ 99 | "collectionId": "Foo", 100 | "title": "Foo_title", 101 | "accessType": "Public", 102 | } 103 | ), 104 | ] 105 | mock_get_download_urls.return_value = [str(file_path)] 106 | 107 | interaction = RdmInteraction() 108 | result_gdf = interaction.get_samples( 109 | spatial_extent=sample_bbox, 110 | temporal_extent=sample_temporal_extent, 111 | columns=["col1", "col2", "ref_id", "geometry"], 112 | ewoc_codes=["1", "2", "3", "4"], 113 | subset=True, 114 | ) 115 | 116 | # Check that col3 and valid_time are indeed not included 117 | assert result_gdf.columns.tolist() == [ 118 | "col1", 119 | "col2", 120 | "ref_id", 121 | "geometry", 122 | ] 123 | 124 | # Check that the third through last geometries are not included: 125 | # the third and fourth are outside the spatiotemporal extent, 126 | # the fifth has a crop type not in the list of ewoc_codes, 127 | # and the last sample is not marked for extraction 128 | assert len(result_gdf) == 2 129 | -------------------------------------------------------------------------------- /tests/worldcerealtests/test_refdata.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from 
shapely.geometry import Polygon 3 | 4 | from worldcereal.utils.refdata import ( 5 | get_best_valid_time, 6 | month_diff, 7 | query_public_extractions, 8 | ) 9 | 10 | 11 | def test_query_public_extractions(): 12 | """Unittest for querying public extractions.""" 13 | 14 | # Define small polygon 15 | poly = Polygon.from_bounds(*(4.535, 51.050719, 4.600936, 51.098176)) 16 | 17 | # Query extractions 18 | df = query_public_extractions(poly, buffer=100) 19 | 20 | # Check if dataframe has samples 21 | assert not df.empty 22 | 23 | 24 | def test_get_best_valid_time(): 25 | def process_test_case(test_case: pd.Series) -> pd.DataFrame: 26 | test_case_res = [] 27 | for processing_period_middle_month in range(1, 13): 28 | test_case["true_valid_time_month"] = test_case["valid_time"].month 29 | test_case["proposed_valid_time_month"] = processing_period_middle_month 30 | test_case["valid_month_shift_backward"] = month_diff( 31 | test_case["proposed_valid_time_month"], 32 | test_case["true_valid_time_month"], 33 | ) 34 | test_case["valid_month_shift_forward"] = month_diff( 35 | test_case["true_valid_time_month"], 36 | test_case["proposed_valid_time_month"], 37 | ) 38 | proposed_valid_time = get_best_valid_time(test_case) 39 | test_case_res.append([processing_period_middle_month, proposed_valid_time]) 40 | return pd.DataFrame( 41 | test_case_res, columns=["proposed_valid_month", "resulting_valid_time"] 42 | ) 43 | 44 | test_case1 = pd.Series( 45 | { 46 | "start_date": pd.to_datetime("2019-01-01"), 47 | "end_date": pd.to_datetime("2019-12-01"), 48 | "valid_time": pd.to_datetime("2019-06-01"), 49 | } 50 | ) 51 | test_case2 = pd.Series( 52 | { 53 | "start_date": pd.to_datetime("2019-01-01"), 54 | "end_date": pd.to_datetime("2019-12-01"), 55 | "valid_time": pd.to_datetime("2019-10-01"), 56 | } 57 | ) 58 | test_case3 = pd.Series( 59 | { 60 | "start_date": pd.to_datetime("2019-01-01"), 61 | "end_date": pd.to_datetime("2019-12-01"), 62 | "valid_time": pd.to_datetime("2019-03-01"), 63 | } 64 | ) 65 | 66 | # Process test cases 67 | test_case1_res = process_test_case(test_case1) 68 | test_case2_res = process_test_case(test_case2) 69 | test_case3_res = process_test_case(test_case3) 70 | 71 | # Asserts are valid for default MIN_EDGE_BUFFER and NUM_TIMESTEPS values 72 | # Assertions for test case 1 73 | assert ( 74 | test_case1_res[test_case1_res["proposed_valid_month"].isin([1, 2, 11, 12])][ 75 | "resulting_valid_time" 76 | ] 77 | .isna() 78 | .all() 79 | ) 80 | assert ( 81 | test_case1_res[test_case1_res["proposed_valid_month"].isin(range(3, 11))][ 82 | "resulting_valid_time" 83 | ] 84 | .notna() 85 | .all() 86 | ) 87 | 88 | # Assertions for test case 2 89 | assert ( 90 | test_case2_res[test_case2_res["proposed_valid_month"].isin([1, 2, 3, 11, 12])][ 91 | "resulting_valid_time" 92 | ] 93 | .isna() 94 | .all() 95 | ) 96 | assert ( 97 | test_case2_res[test_case2_res["proposed_valid_month"].isin(range(4, 11))][ 98 | "resulting_valid_time" 99 | ] 100 | .notna() 101 | .all() 102 | ) 103 | 104 | # Assertions for test case 3 105 | assert ( 106 | test_case3_res[ 107 | test_case3_res["proposed_valid_month"].isin([1, 2, 9, 10, 11, 12]) 108 | ]["resulting_valid_time"] 109 | .isna() 110 | .all() 111 | ) 112 | assert ( 113 | test_case3_res[test_case3_res["proposed_valid_month"].isin(range(3, 9))][ 114 | "resulting_valid_time" 115 | ] 116 | .notna() 117 | .all() 118 | ) 119 | -------------------------------------------------------------------------------- /tests/worldcerealtests/test_seasons.py: 
-------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from openeo_gfmap import BoundingBoxExtent 6 | 7 | from worldcereal.seasons import ( 8 | circular_median_day_of_year, 9 | doy_from_tiff, 10 | doy_to_date_after, 11 | get_processing_dates_for_extent, 12 | ) 13 | 14 | 15 | def test_doy_from_tiff(): 16 | bounds = (574680, 5621800, 575320, 5622440) 17 | epsg = 32631 18 | 19 | doy_data = doy_from_tiff("tc-s1", "SOS", bounds, epsg, resolution=10000) 20 | 21 | assert doy_data.size == 1 22 | 23 | doy_data = int(doy_data) 24 | 25 | assert doy_data != 0 26 | 27 | 28 | def test_doy_to_date_after(): 29 | bounds = (574680, 5621800, 575320, 5622440) 30 | epsg = 32631 31 | 32 | doy_data = doy_from_tiff("tc-s2", "SOS", bounds, epsg, resolution=10000) 33 | 34 | after_date = datetime.datetime(2019, 1, 1) 35 | doy_date = doy_to_date_after(int(doy_data), after_date) 36 | 37 | assert pd.to_datetime(doy_date) >= after_date 38 | 39 | after_date = datetime.datetime(2019, 8, 1) 40 | doy_date = doy_to_date_after(int(doy_data), after_date) 41 | 42 | assert pd.to_datetime(doy_date) >= after_date 43 | 44 | 45 | def test_get_processing_dates_for_extent(): 46 | # Test to check if we can infer processing dates for default season 47 | # tc-annual 48 | bounds = (574680, 5621800, 575320, 5622440) 49 | epsg = 32631 50 | year = 2021 51 | extent = BoundingBoxExtent(*bounds, epsg) 52 | 53 | temporal_context = get_processing_dates_for_extent(extent, year) 54 | start_date = temporal_context.start_date 55 | end_date = temporal_context.end_date 56 | 57 | assert pd.to_datetime(end_date).year == year 58 | assert pd.to_datetime(end_date) - pd.to_datetime(start_date) == pd.Timedelta( 59 | days=364 60 | ) 61 | 62 | 63 | def test_compute_median_doy(): 64 | # Test to check if we can compute median day of year 65 | # if array crosses the calendar year 66 | doy_array = np.array([360, 362, 365, 1, 3, 5, 7]) 67 | assert circular_median_day_of_year(doy_array) == 1 68 | 69 | # Other tests 70 | assert circular_median_day_of_year([1, 2, 3, 4, 5]) == 3 71 | assert circular_median_day_of_year([320, 330, 340, 360]) == 330 72 | assert circular_median_day_of_year([320, 330, 340, 360, 10]) == 340 73 | -------------------------------------------------------------------------------- /tests/worldcerealtests/test_stac_api_interaction.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock, patch 2 | 3 | import pystac 4 | import pytest 5 | from requests.auth import AuthBase 6 | 7 | from worldcereal.stac.stac_api_interaction import StacApiInteraction 8 | 9 | 10 | @pytest.fixture 11 | def mock_auth(): 12 | return MagicMock(spec=AuthBase) 13 | 14 | 15 | def mock_stac_item(item_id): 16 | item = MagicMock(spec=pystac.Item) 17 | item.id = item_id 18 | item.to_dict.return_value = {"id": item_id, "some_property": "value"} 19 | return item 20 | 21 | 22 | class TestStacApiInteraction: 23 | @patch("requests.post") 24 | @patch("worldcereal.stac.stac_api_interaction.StacApiInteraction.exists") 25 | def test_upload_items_single_chunk( 26 | self, mock_exists, mock_requests_post, mock_auth 27 | ): 28 | """Test bulk upload of STAC items in one single chunk.""" 29 | 30 | mock_requests_post.return_value.status_code = 200 31 | mock_requests_post.return_value.json.return_value = {"status": "success"} 32 | mock_requests_post.reason = "OK" 33 | 34 | mock_exists.return_value = True 35 | 36 | items = 
[mock_stac_item(f"item-{i}") for i in range(10)] 37 | 38 | interaction = StacApiInteraction( 39 | sensor="Sentinel1", 40 | base_url="http://fake-stac-api", 41 | auth=mock_auth, 42 | bulk_size=10, # To ensure all 10 items are uploaded in one bulk 43 | ) 44 | interaction.upload_items_bulk(items) 45 | 46 | mock_requests_post.assert_called_with( 47 | url=f"http://fake-stac-api/collections/{interaction.collection_id}/bulk_items", 48 | auth=mock_auth, 49 | json={ 50 | "method": "upsert", 51 | "items": {item.id: item.to_dict() for item in items}, 52 | }, 53 | ) 54 | assert mock_requests_post.call_count == 1 55 | 56 | @patch("requests.post") 57 | @patch("worldcereal.stac.stac_api_interaction.StacApiInteraction.exists") 58 | def test_upload_items_multiple_chunk( 59 | self, mock_exists, mock_requests_post, mock_auth 60 | ): 61 | """Test bulk upload of STAC items in mulitiple chunks.""" 62 | 63 | mock_requests_post.return_value.status_code = 200 64 | mock_requests_post.return_value.json.return_value = {"status": "success"} 65 | mock_requests_post.reason = "OK" 66 | 67 | mock_exists.return_value = True 68 | 69 | items = [mock_stac_item(f"item-{i}") for i in range(10)] 70 | 71 | interaction = StacApiInteraction( 72 | sensor="Sentinel1", 73 | base_url="http://fake-stac-api", 74 | auth=mock_auth, 75 | bulk_size=3, # This would require 4 chunk for 10 items 76 | ) 77 | interaction.upload_items_bulk(items) 78 | 79 | assert mock_requests_post.call_count == 4 80 | 81 | expected_calls = [ 82 | { 83 | "url": f"http://fake-stac-api/collections/{interaction.collection_id}/bulk_items", 84 | "auth": mock_auth, 85 | "json": { 86 | "method": "upsert", 87 | "items": { 88 | "item-0": {"id": "item-0", "some_property": "value"}, 89 | "item-1": {"id": "item-1", "some_property": "value"}, 90 | "item-2": {"id": "item-2", "some_property": "value"}, 91 | }, 92 | }, 93 | }, 94 | { 95 | "url": f"http://fake-stac-api/collections/{interaction.collection_id}/bulk_items", 96 | "auth": mock_auth, 97 | "json": { 98 | "method": "upsert", 99 | "items": { 100 | "item-3": {"id": "item-3", "some_property": "value"}, 101 | "item-4": {"id": "item-4", "some_property": "value"}, 102 | "item-5": {"id": "item-5", "some_property": "value"}, 103 | }, 104 | }, 105 | }, 106 | { 107 | "url": f"http://fake-stac-api/collections/{interaction.collection_id}/bulk_items", 108 | "auth": mock_auth, 109 | "json": { 110 | "method": "upsert", 111 | "items": { 112 | "item-6": {"id": "item-6", "some_property": "value"}, 113 | "item-7": {"id": "item-7", "some_property": "value"}, 114 | "item-8": {"id": "item-8", "some_property": "value"}, 115 | }, 116 | }, 117 | }, 118 | { 119 | "url": f"http://fake-stac-api/collections/{interaction.collection_id}/bulk_items", 120 | "auth": mock_auth, 121 | "json": { 122 | "method": "upsert", 123 | "items": { 124 | "item-9": {"id": "item-9", "some_property": "value"}, 125 | }, 126 | }, 127 | }, 128 | ] 129 | 130 | for i, call in enumerate(mock_requests_post.call_args_list): 131 | assert call[1] == expected_calls[i] 132 | -------------------------------------------------------------------------------- /tests/worldcerealtests/test_train.py: -------------------------------------------------------------------------------- 1 | from presto.presto import Presto 2 | from torch.utils.data import DataLoader 3 | 4 | from worldcereal.parameters import CropLandParameters 5 | from worldcereal.train.data import WorldCerealTrainingDataset, get_training_df 6 | 7 | 8 | def test_worldcerealtraindataset(WorldCerealExtractionsDF): 9 | """Test 
creation of WorldCerealTrainingDataset and data loading""" 10 | 11 | df = WorldCerealExtractionsDF.reset_index() 12 | 13 | ds = WorldCerealTrainingDataset( 14 | df, 15 | task_type="cropland", 16 | augment=True, 17 | mask_ratio=0.25, 18 | repeats=2, 19 | ) 20 | 21 | # Check if number of samples matches repeats 22 | assert len(ds) == 2 * len(df) 23 | 24 | # Check if data loading works 25 | dl = DataLoader(ds, batch_size=2, shuffle=True) 26 | 27 | for x, y, dw, latlons, month, valid_month, variable_mask, attrs in dl: 28 | assert x.shape == (2, 12, 17) 29 | assert y.shape == (2,) 30 | assert dw.shape == (2, 12) 31 | assert dw.unique().numpy()[0] == 9 32 | assert latlons.shape == (2, 2) 33 | assert month.shape == (2,) 34 | assert valid_month.shape == (2,) 35 | assert variable_mask.shape == x.shape 36 | assert isinstance(attrs, dict) 37 | break 38 | 39 | 40 | def test_get_trainingdf(WorldCerealExtractionsDF): 41 | """Test the function that computes embeddings and targets into 42 | a training dataframe using a presto model 43 | """ 44 | 45 | df = WorldCerealExtractionsDF.reset_index() 46 | ds = WorldCerealTrainingDataset(df) 47 | 48 | presto_url = CropLandParameters().feature_parameters.presto_model_url 49 | presto_model = Presto.load_pretrained(presto_url, from_url=True, strict=False) 50 | 51 | training_df = get_training_df(ds, presto_model, batch_size=256) 52 | 53 | for ft in range(128): 54 | assert f"presto_ft_{ft}" in training_df.columns 55 | -------------------------------------------------------------------------------- /tests/worldcerealtests/test_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from catboost import CatBoostClassifier 3 | from openeo_gfmap.backend import cdse_connection 4 | 5 | from worldcereal.utils.models import load_model_onnx 6 | from worldcereal.utils.upload import deploy_model 7 | 8 | 9 | def test_deploy_model(): 10 | """Simple test to deploy a CatBoost model and load it back.""" 11 | model = CatBoostClassifier(iterations=10).fit(X=[[1, 2], [3, 4]], y=[0, 1]) 12 | presigned_uri = deploy_model(cdse_connection(), model) 13 | model = load_model_onnx(presigned_uri) 14 | 15 | # Compare model predictions with the original targets 16 | np.testing.assert_array_equal( 17 | model.run(None, {"features": [[1, 2], [3, 4]]})[0], [0, 1] 18 | ) 19 | -------------------------------------------------------------------------------- /tests/worldcerealtests/testresources/spatial_extent.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "FeatureCollection", 3 | "features": [ 4 | { 5 | "type": "Feature", 6 | "geometry": { 7 | "type": "Polygon", 8 | "coordinates": [ 9 | [ 10 | [ 11 | 44.433631, 12 | 51.317362 13 | ], 14 | [ 15 | 44.432274, 16 | 51.427238 17 | ], 18 | [ 19 | 44.69808, 20 | 51.428224 21 | ], 22 | [ 23 | 44.698802, 24 | 51.318344 25 | ], 26 | [ 27 | 44.433631, 28 | 51.317362 29 | ] 30 | ] 31 | ] 32 | }, 33 | "properties": {} 34 | } 35 | ] 36 | } -------------------------------------------------------------------------------- /tests/worldcerealtests/testresources/test_public_extractions.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/tests/worldcerealtests/testresources/test_public_extractions.parquet -------------------------------------------------------------------------------- 
/tests/worldcerealtests/testresources/worldcereal_cropland_classification.nc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/tests/worldcerealtests/testresources/worldcereal_cropland_classification.nc -------------------------------------------------------------------------------- /tests/worldcerealtests/testresources/worldcereal_croptype_classification.nc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/tests/worldcerealtests/testresources/worldcereal_croptype_classification.nc -------------------------------------------------------------------------------- /tests/worldcerealtests/testresources/worldcereal_preprocessed_inputs.nc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/tests/worldcerealtests/testresources/worldcereal_preprocessed_inputs.nc -------------------------------------------------------------------------------- /tests/worldcerealtests/testresources/worldcereal_private_extractions_dummy.parquet/ref_id=2021_BEL_LPIS-Flanders_POLY_110/data_0.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/tests/worldcerealtests/testresources/worldcereal_private_extractions_dummy.parquet/ref_id=2021_BEL_LPIS-Flanders_POLY_110/data_0.parquet --------------------------------------------------------------------------------