├── .github
│   └── workflows
│       ├── ci.yaml
│       └── lint.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── Jenkinsfile
├── LICENSE
├── README.md
├── assets
│   └── worldcereal_logo.jpg
├── environment.yml
├── jenkins_pre_install_script.sh
├── mypy.ini
├── notebooks
│   ├── notebook_utils
│   │   ├── classifier.py
│   │   ├── croptypepicker.py
│   │   ├── dateslider.py
│   │   ├── extractions.py
│   │   ├── seasons.py
│   │   └── visualization.py
│   ├── patch_to_point
│   │   ├── job_df.parquet
│   │   ├── patch_to_point_development.ipynb
│   │   └── rdm_query_example
│   ├── resources
│   │   ├── Cropland_inference_choose_end_date.png
│   │   ├── Custom_cropland_map.png
│   │   ├── Custom_croptype_map.png
│   │   ├── Default_cropland_map.png
│   │   ├── Landcover_training_data_density_PhI.png
│   │   ├── MOOC_refdata_RDM_exploration.png
│   │   ├── WorldCereal_private_extractions.png
│   │   ├── cropland_data_Ph1.png
│   │   ├── croptype_classes.json
│   │   ├── croptype_data_Ph1.png
│   │   ├── eurocrops_map_wcr_edition.csv
│   │   └── wc2eurocrops_map.csv
│   ├── worldcereal_RDM_demo.ipynb
│   ├── worldcereal_custom_croptype.ipynb
│   ├── worldcereal_default_cropland.ipynb
│   └── worldcereal_private_extractions.ipynb
├── pyproject.toml
├── pytest.ini
├── scripts
│   ├── extractions
│   │   └── extract.py
│   ├── inference
│   │   ├── collect_inputs.py
│   │   ├── cropland_mapping.py
│   │   ├── cropland_mapping_local.py
│   │   ├── cropland_mapping_udp.py
│   │   └── croptype_mapping_udp.py
│   ├── misc
│   │   └── legend.py
│   ├── spark
│   │   ├── compute_presto_features.py
│   │   └── compute_presto_features.sh
│   └── stac
│       ├── build_paths.py
│       ├── catalogue_builder.py
│       └── split_catalogue.py
├── src
│   └── worldcereal
│       ├── __init__.py
│       ├── _version.py
│       ├── data
│       │   ├── cropcalendars
│       │   │   ├── ANNUAL_EOS_WGS84.tif
│       │   │   ├── ANNUAL_SOS_WGS84.tif
│       │   │   ├── S1_EOS_WGS84.tif
│       │   │   ├── S1_SOS_WGS84.tif
│       │   │   ├── S2_EOS_WGS84.tif
│       │   │   ├── S2_SOS_WGS84.tif
│       │   │   └── __init__.py
│       │   └── croptype_mappings
│       │       ├── __init__.py
│       │       ├── croptype_classes.json
│       │       ├── eurocrops_map_wcr_edition.csv
│       │       └── wc2eurocrops_map.csv
│       ├── extract
│       │   ├── __init__.py
│       │   ├── common.py
│       │   ├── patch_meteo.py
│       │   ├── patch_s1.py
│       │   ├── patch_s2.py
│       │   ├── patch_to_point_worldcereal.py
│       │   ├── patch_worldcereal.py
│       │   ├── point_worldcereal.py
│       │   └── utils.py
│       ├── job.py
│       ├── openeo
│       │   ├── __init__.py
│       │   ├── feature_extractor.py
│       │   ├── inference.py
│       │   ├── mapping.py
│       │   ├── masking.py
│       │   ├── postprocess.py
│       │   ├── preprocessing.py
│       │   └── udf_distance_to_cloud.py
│       ├── parameters.py
│       ├── rdm_api
│       │   ├── __init__.py
│       │   ├── rdm_collection.py
│       │   └── rdm_interaction.py
│       ├── seasons.py
│       ├── stac
│       │   ├── __init__.py
│       │   ├── constants.py
│       │   └── stac_api_interaction.py
│       ├── train
│       │   └── data.py
│       ├── udp
│       │   ├── worldcereal_crop_extent.json
│       │   └── worldcereal_crop_type.json
│       └── utils
│           ├── geoloader.py
│           ├── legend.py
│           ├── map.py
│           ├── models.py
│           ├── refdata.py
│           ├── retry.py
│           ├── spark.py
│           ├── timeseries.py
│           └── upload.py
└── tests
    ├── pre_test_script.sh
    └── worldcerealtests
        ├── __init__.py
        ├── conftest.py
        ├── test_feature_extractor.py
        ├── test_inference.py
        ├── test_pipelines.py
        ├── test_postprocessing.py
        ├── test_preprocessing.py
        ├── test_rdm_interaction.py
        ├── test_refdata.py
        ├── test_seasons.py
        ├── test_stac_api_interaction.py
        ├── test_timeseries.py
        ├── test_train.py
        ├── test_utils.py
        └── testresources
            ├── preprocess_from_patches_graph.json
            ├── preprocess_graph.json
            ├── preprocess_graphwithslope.json
            ├── spatial_extent.json
            ├── test_public_extractions.parquet
            ├── worldcereal_cropland_classification.nc
            ├── worldcereal_croptype_classification.nc
            ├── worldcereal_preprocessed_inputs.nc
            └── worldcereal_private_extractions_dummy.parquet
                └── ref_id=2021_BEL_LPIS-Flanders_POLY_110
                    └── data_0.parquet
/.github/workflows/ci.yaml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | jobs:
10 | tests:
11 | name: tests
12 | runs-on: ubuntu-latest
13 | steps:
14 | - name: Clone repo
15 | uses: actions/checkout@v2
16 | - name: Set up python
17 | uses: actions/setup-python@v4
18 | with:
19 | python-version: 3.9
20 | cache: 'pip'
21 | # - name: Install openEO-GFMAP from source
22 | # run: |
23 | # git clone https://github.com/Open-EO/openeo-gfmap.git
24 | # cd openeo-gfmap
25 | # pip install .
26 | # cd ..
27 | - name: Install Presto from source
28 | run: |
29 | git clone -b croptype https://github.com/WorldCereal/presto-worldcereal.git
30 | cd presto-worldcereal
31 | pip install .
32 | cd ..
33 | - name: Install WorldCereal dependencies
34 | run: pip install ".[dev,train,notebooks]"
35 | - name: Run WorldCereal tests
36 | run: python -m pytest -s --log-cli-level=INFO tests
37 | env:
38 | OPENEO_AUTH_METHOD: client_credentials
39 | OPENEO_OIDC_DEVICE_CODE_MAX_POLL_TIME: 5
40 | OPENEO_AUTH_PROVIDER_ID_CDSE: CDSE
41 | OPENEO_AUTH_PROVIDER_ID_VITO: terrascope
42 | OPENEO_AUTH_CLIENT_ID_CDSE: openeo-worldcereal-service-account
43 | OPENEO_AUTH_CLIENT_ID_VITO: openeo-worldcereal-service-account
44 | OPENEO_AUTH_CLIENT_SECRET_CDSE: ${{ secrets.OPENEO_AUTH_CLIENT_SECRET_CDSE }}
45 | OPENEO_AUTH_CLIENT_SECRET_VITO: ${{ secrets.OPENEO_AUTH_CLIENT_SECRET_TERRASCOPE }}
--------------------------------------------------------------------------------
/.github/workflows/lint.yaml:
--------------------------------------------------------------------------------
1 | name: Lint
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | jobs:
10 | lint:
11 | name: "Lint: code quality and formatting checks"
12 | runs-on: ubuntu-latest
13 | steps:
14 | - name: Clone repo
15 | uses: actions/checkout@v2
16 | - name: Set up python
17 | uses: actions/setup-python@v4
18 | with:
19 | python-version: 3.9
20 | cache: 'pip'
21 | - name: Install Python dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | python -m pip install isort black==24.3.0 ruff==0.5.0 mypy==1.11.2 types-requests~=2.32.0
25 |
26 | - name: isort
27 | run: python -m isort . --check --diff
28 | - name: black
29 | run: python -m black --check --diff .
30 | - name: ruff
31 | run: ruff check .
32 | - name: mypy
33 | run: python -m mypy .
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | .vscode
3 | __pycache__
4 | .ftpconfig
5 | *.pyc
6 | *.py[cod]
7 | *$py.class
8 | tmp
9 | *.vrt
10 | *.pkl
11 | *.mep2
12 | *.tig
13 | *.tif
14 | *.jp2
15 | *.npy
16 | *.kml
17 | *.csv
18 | *.tsv
19 | *.tfevents
20 | *.nc
21 | *.gif
22 | config.json
23 | # C extensions
24 | *.so
25 | *.gz
26 | # Distribution / packaging
27 | .Python
28 | build/
29 | develop-eggs/
30 | dist/
31 | downloads/
32 | eggs/
33 | .eggs/
34 | lib/
35 | lib64/
36 | parts/
37 | sdist/
38 | var/
39 | wheels/
40 | *.egg-info/
41 | .installed.cfg
42 | *.egg
43 | MANIFEST
44 | .mypy_cache
45 | .vscode
46 | .history
47 | .noseids
48 | experimental/**
49 |
50 | # don't track catboost training info
51 | src/worldcereal/train/catboost_info
52 |
53 | # PyInstaller
54 | # Usually these files are written by a python script from a template
55 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
56 | *.manifest
57 | *.spec
58 |
59 | # Installer logs
60 | pip-log.txt
61 | pip-delete-this-directory.txt
62 |
63 | # Allow the test resources to be tracked
64 | !/tests/worldcerealtests/testresources/**
65 | *.aux.xml
66 |
67 | # Allow crop calendar resources to be tracked
68 | !/src/worldcereal/data/cropcalendars/*.tif
69 |
70 | # Allow resource CSV files to be tracked
71 | # !/src/worldcereal/resources/*.csv
72 |
73 | # Allow resource JSON files to be tracked
74 | # !/src/worldcereal/resources/**/*.json
75 |
76 | # Allow biomes/realms to be tracked
77 | # !/src/worldcereal/resources/biomes/*.tif
78 | # !/src/worldcereal/resources/realms/*.tif
79 |
 80 | # Don't track zarr data
81 | features.zarr/
82 |
83 | # Unit test / coverage reports
84 | htmlcov/
85 | .tox/
86 | .nox/
87 | .coverage
88 | .coverage.*
89 | .cache
90 | nosetests.xml
91 | coverage.xml
92 | *.cover
93 | .hypothesis/
94 | .pytest_cache/
95 |
96 | # Translations
97 | *.mo
98 | *.pot
99 |
100 | # Django stuff:
101 | *.log
102 | local_settings.py
103 | db.sqlite3
104 |
105 | # Flask stuff:
106 | instance/
107 | .webassets-cache
108 |
109 | # Scrapy stuff:
110 | .scrapy
111 |
112 | # Sphinx documentation
113 | docs/_build/
114 |
115 | # PyBuilder
116 | target/
117 |
118 | # Jupyter Notebook
119 | .ipynb_checkpoints
120 |
121 | # IPython
122 | profile_default/
123 | ipython_config.py
124 |
125 | # pyenv
126 | .python-version
127 |
128 | # celery beat schedule file
129 | celerybeat-schedule
130 |
131 | # SageMath parsed files
132 | *.sage.py
133 |
134 | # Environments
135 | .env
136 | .venv
137 | env/
138 | venv/
139 | ENV/
140 | env.bak/
141 | venv.bak/
142 |
143 | # Spyder project settings
144 | .spyderproject
145 | .spyproject
146 |
147 | # Rope project settings
148 | .ropeproject
149 |
150 | # mkdocs documentation
151 | /site
152 |
153 | # mypy
154 | .mypy_cache/
155 | .dmypy.json
156 | dmypy.json
157 |
158 | # .idea stuff
159 | .idea
160 |
161 | !src/worldcereal/data/**/*.csv
162 |
163 | # Pyre type checker
164 | .pyre/
165 | download.zip
166 | catboost_info/*
167 | notebooks/catboost_info/*
168 |
169 | *.cbm
170 | *.pt
171 | *.onnx
172 | *.nc
173 | *.7z
174 | *.dmg
175 | *.gz
176 | *.iso
177 | *.jar
178 | *.rar
179 | *.tar
180 | *.zip
181 |
182 | .notebook-tests/
183 | .local-presto-test/
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # https://pre-commit.com
2 | default_language_version:
3 | python: python3
4 | default_stages: [pre-commit, manual]
5 | fail_fast: true
6 | exclude: "(received/|.*_depr)"
7 | repos:
8 | - repo: https://github.com/pre-commit/pre-commit-hooks
9 | rev: v4.5.0
10 | hooks:
11 | # - id: check-added-large-files
12 | # args: ['--maxkb=65536']
13 | - id: check-ast
14 | - id: check-builtin-literals
15 | - id: check-byte-order-marker
16 | - id: check-case-conflict
17 | - id: check-docstring-first
18 | - id: check-json
19 | - id: check-merge-conflict
20 | - id: check-symlinks
21 | - id: check-toml
22 | - id: check-vcs-permalinks
23 | - id: check-xml
24 | - id: check-yaml
25 | args: [--allow-multiple-documents]
26 | - id: debug-statements
27 | - id: detect-private-key
28 | - id: mixed-line-ending
29 | - id: trailing-whitespace
30 | types: [python]
31 | - id: end-of-file-fixer
32 | types: [python]
33 | - repo: local
34 | hooks:
35 | - id: shellcheck
36 | name: shellcheck
37 | entry: shellcheck --check-sourced --shell=bash --exclude=SC1087
38 | language: system
39 | types: [shell]
40 | # - id: pydocstyle
41 | # name: pydocstyle
42 | # entry: pydocstyle
43 | # language: system
44 | # types: [python]
45 | # exclude: "(^experiments/|.*_depr)"
46 | # - id: flake8
47 | # name: flake8
48 | # entry: flake8
49 | # language: system
50 | # types: [python]
51 | # exclude: "(^tasks/|.*_depr)"
52 | - repo: https://github.com/pre-commit/pre-commit-hooks
53 | rev: v4.5.0
54 | hooks:
55 | - id: no-commit-to-branch
56 | args: [ '--branch', 'main' ]
57 | - repo: https://github.com/radix-ai/auto-smart-commit
58 | rev: v1.0.3
59 | hooks:
60 | - id: auto-smart-commit
61 | - repo: https://github.com/psf/black
62 | rev: 24.3.0
63 | hooks:
64 | - id: black
65 | language_version: python3
66 | - repo: https://github.com/pycqa/isort
67 | rev: 5.13.2
68 | hooks:
69 | - id: isort
70 | name: isort (python)
71 | args: ["--profile", "black"]
72 | - repo: https://github.com/astral-sh/ruff-pre-commit
73 | rev: v0.1.15
74 | hooks:
75 | - id: ruff
76 | args: ["--fix"]
77 | - repo: https://github.com/pre-commit/mirrors-mypy
78 | rev: v1.11.2
79 | hooks:
80 | - id: mypy
81 | additional_dependencies: ["types-requests"]
82 |
--------------------------------------------------------------------------------
/Jenkinsfile:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env groovy
2 |
3 | /* Jenkinsfile for snapshot building with VITO CI system. */
4 |
5 | @Library('lib')_
6 |
7 | pythonPipeline {
8 | package_name = 'worldcereal-classification'
9 | test_module_name = 'worldcereal'
10 | wipeout_workspace = true
11 | python_version = ["3.10"]
12 | extras_require = "dev,train,notebooks"
13 | upload_dev_wheels = false
14 | pipeline_triggers = [cron('H H(0-6) * * *')]
15 | pep440 = true
16 | pre_test_script = 'pre_test_script.sh'
17 | pre_install_script = 'jenkins_pre_install_script.sh'
18 | extra_env_variables = [
19 | "OPENEO_AUTH_METHOD=client_credentials",
20 | "OPENEO_OIDC_DEVICE_CODE_MAX_POLL_TIME=5",
21 | "OPENEO_AUTH_PROVIDER_ID_VITO=terrascope",
22 | "OPENEO_AUTH_CLIENT_ID_VITO=openeo-worldcereal-service-account",
23 | "OPENEO_AUTH_PROVIDER_ID_CDSE=CDSE",
24 | "OPENEO_AUTH_CLIENT_ID_CDSE=openeo-worldcereal-service-account",
25 | ]
26 | extra_env_secrets = [
27 | 'OPENEO_AUTH_CLIENT_SECRET_VITO': 'TAP/big_data_services/devops/terraform/keycloak_mgmt/oidc_clients_prod openeo-worldcereal-service-account',
28 | 'OPENEO_AUTH_CLIENT_SECRET_CDSE': 'TAP/big_data_services/openeo/cdse-service-accounts/openeo-worldcereal-service-account client_secret',
29 | ]
30 | }
31 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 ESA WorldCereal
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ESA WorldCereal classification module
2 | [](https://github.com/WorldCereal/worldcereal-classification/actions/workflows/ci.yaml) [](https://github.com/pre-commit/pre-commit) [](https://opensource.org/licenses/MIT) [](https://doi.org/10.5194/essd-15-5491-2023) [](https://worldcereal.github.io/worldcereal-documentation/) [](https://forum.esa-worldcereal.org/)
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | ## Overview
12 |
13 | **WorldCereal** is a Python package designed for generating cropland and crop type maps at a wide range of spatial scales, leveraging satellite and auxiliary data, and state-of-the-art classification workflows. It uses [**openEO**](https://openeo.org/) to run classification tasks in the cloud, by default the [**Copernicus Data Space Ecosystem (CDSE)**](https://dataspace.copernicus.eu/).
14 |
15 | Users can leverage the system in a notebook environment through [Terrascope](https://terrascope.be/en) or set up the environment locally using the provided installation options.
16 |
17 | In order to run classification jobs on CDSE, users can get started with [monthly free openEO processing credits](https://documentation.dataspace.copernicus.eu/Quotas.html) by registering on the CDSE platform. Additional credits can be purchased, or users may soon be able to request them through the **ESA Network of Resources**.
18 |
19 | ## Features
20 |
21 | - **Scalable**: Generate maps at a wide range of spatial scales.
22 | - **Cloud-based Processing**: Leverages openEO to run classifications in the cloud.
23 | - **Powerful classification pipeline**: WorldCereal builds upon [Presto](https://arxiv.org/abs/2304.14065), a pretrained transformer-based model, leveraging global self-supervised learning of multimodal input timeseries, leading to better accuracies and higher generalizability in space and time of downstream crop classification models. The Presto backbone of WorldCereal classification pipelines is developed [here](https://github.com/WorldCereal/presto-worldcereal).
24 | - **Customizable**: Users can pick any region or temporal range and apply either default models or train their own and produce custom maps, interacting with publicly available training data.
25 | - **Easy to Use**: Integrates into Jupyter notebooks and other Python environments.
26 |
27 | ## Quick Start
28 |
 29 | One of our [demo notebooks](notebooks) will quickly show you how to map crops with the WorldCereal system. There are two options for running these notebooks:
30 |
31 | #### Option 1: Run on Terrascope
32 |
33 | You can use a preconfigured environment on [**Terrascope**](https://terrascope.be/en) to run the workflows in a Jupyter notebook environment. Just register as a new user on Terrascope or use one of the supported EGI eduGAIN login methods to get started.
34 |
35 | | :point_up: | Once you are prompted with "Server Options", make sure to select the "Worldcereal" image. Did you choose "Terrascope" by accident? Then go to File > Hub Control Panel > Stop my server, and click the link below once again. |
36 | |---------------|:------------------------|
37 |
38 | - For a cropland map generation demo without any model training:
39 |
40 | - For a crop type map generation demo with model training:
41 |
42 | - For a demo on how to interact with the WorldCereal Reference Data Module (RDM):
43 |
44 | #### Option 2: Install Locally
45 |
46 | If you prefer to install the package locally, you can create the environment using **Conda** or **pip**.
47 |
48 | First clone the repository:
49 | ```bash
50 | git clone https://github.com/WorldCereal/worldcereal-classification.git
51 | cd worldcereal-classification
52 | ```
53 | Next, install the package locally:
54 | - for Conda: `conda env create -f environment.yml`
55 | - for Pip: `pip install .[train,notebooks]`
56 |
57 | ## Usage Example
 58 | In its simplest form, a cropland mask can be generated with just a few lines of code, triggering an openEO job on CDSE and downloading the result locally:
59 |
60 | ```python
61 | from openeo_gfmap import BoundingBoxExtent, TemporalContext
62 | from worldcereal.job import generate_map
63 |
64 | # Specify the spatial extent
65 | spatial_extent = BoundingBoxExtent(
66 | west=44.432274,
67 | south=51.317362,
68 | east=44.698802,
69 | north=51.428224,
70 | epsg=4326
71 | )
72 |
73 | # Specify the temporal extent (this has to be one year)
74 | temporal_extent = TemporalContext('2022-11-01', '2023-10-31')
75 |
76 | # Launch processing job (result will automatically be downloaded)
77 | results = generate_map(spatial_extent, temporal_extent, output_dir='.')
78 | ```
79 |
80 | ## Documentation
81 |
82 | Comprehensive documentation is available at the following link: https://worldcereal.github.io/worldcereal-documentation/
83 |
84 | ## Support
85 | Questions, suggestions, feedback? Use [our forum](https://forum.esa-worldcereal.org/) to get in touch with us and the rest of the community!
86 |
87 | ## License
88 |
89 | This project is licensed under the terms of the MIT License. See the [LICENSE](LICENSE) file for details.
90 |
91 | ## Acknowledgments
92 |
93 | The WorldCereal project is funded by the [European Space Agency (ESA)](https://www.esa.int/) under grant no. 4000130569/20/I-NB.
94 |
95 | WorldCereal's classification backbone makes use of the Presto model, originally implemented [here](https://github.com/nasaharvest/presto/). Without the groundbreaking work being done by Gabriel Tseng and the rest of the [NASA Harvest](https://www.nasaharvest.org/) team, both in the original Presto implementation as well as its adaptation for WorldCereal, this package would simply not exist in its present form 🙏.
96 |
 97 | The pre-configured Jupyter notebook environment in which users can train custom models and launch WorldCereal jobs is provided by [Terrascope](https://terrascope.be/en), the Belgian Earth observation data space, managed by [VITO Remote Sensing](https://remotesensing.vito.be/) on behalf of the [Belgian Science Policy Office](https://www.belspo.be/belspo/index_en.stm).
98 |
99 | ## How to cite
100 |
101 | If you use WorldCereal resources in your work, please cite it as follows:
102 |
103 | ```bibtex
104 |
105 | @article{van_tricht_worldcereal_2023,
106 | title = {{WorldCereal}: a dynamic open-source system for global-scale, seasonal, and reproducible crop and irrigation mapping},
107 | volume = {15},
108 | issn = {1866-3516},
109 | shorttitle = {{WorldCereal}},
110 | url = {https://essd.copernicus.org/articles/15/5491/2023/},
111 | doi = {10.5194/essd-15-5491-2023},
112 | number = {12},
113 | urldate = {2024-03-01},
114 | journal = {Earth System Science Data},
115 | author = {Van Tricht, Kristof and Degerickx, Jeroen and Gilliams, Sven and Zanaga, Daniele and Battude, Marjorie and Grosu, Alex and Brombacher, Joost and Lesiv, Myroslava and Bayas, Juan Carlos Laso and Karanam, Santosh and Fritz, Steffen and Becker-Reshef, Inbal and Franch, Belén and Mollà-Bononad, Bertran and Boogaard, Hendrik and Pratihast, Arun Kumar and Koetz, Benjamin and Szantoi, Zoltan},
116 | month = dec,
117 | year = {2023},
118 | pages = {5491--5515},
119 | }
120 | ```
121 |
--------------------------------------------------------------------------------
/assets/worldcereal_logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/assets/worldcereal_logo.jpg
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: worldcereal
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - boto3=1.35.30
8 | - cftime=1.6.4
9 | - catboost=1.2.5
10 | - cpuonly
11 | - einops=0.8.0
12 | - fastparquet=2024.2.0
13 | - geojson=3.1.0
14 | - h5netcdf=1.3.0
15 | - geopandas=0.14.4
16 | - h5py=3.11.0
17 | - leafmap=0.35.1
18 | - loguru=0.7.2
19 | - netcdf4
20 | - numpy<2.0.0
21 | - openeo=0.35.0
22 | - pip
23 | - pyarrow=16.1.0
24 | - pydantic=2.8.0
25 | - python=3.10.0
26 | - pytorch=2.3.1
27 | - rasterio=1.3.10
28 | - rioxarray=0.15.5
29 | - scikit-image=0.22.0
30 | - scikit-learn=1.5.0
31 | - scipy
32 | - shapely=2.0.4
33 | - tqdm
34 | - pip:
35 | - duckdb==1.1.3
36 | - h3==4.1.0
37 | - openeo-gfmap==0.4.6
38 | - git+https://github.com/worldcereal/worldcereal-classification
39 | - git+https://github.com/WorldCereal/presto-worldcereal.git@croptype
40 |
41 |
--------------------------------------------------------------------------------
/jenkins_pre_install_script.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Install git
4 | dnf install git -y
5 |
6 | # # Install openeo-gfmap and presto-worldcereal
7 | # dir=$(pwd)
8 | # GFMAP_URL="https://github.com/Open-EO/openeo-gfmap.git"
9 | # PRESTO_URL="https://github.com/WorldCereal/presto-worldcereal.git"
10 |
11 | # su - jenkins -c "cd $dir && \
12 | # source venv310/bin/activate && \
13 | # git clone $GFMAP_URL && \
14 | # cd openeo-gfmap || { echo 'Directory not found! Exiting...'; exit 1; } && \
15 | # pip install . && \
16 | # cd ..
17 | # git clone -b croptype $PRESTO_URL && \
18 | # cd presto-worldcereal || { echo 'Directory not found! Exiting...'; exit 1; } && \
19 | # pip install .
20 | # "
21 |
22 | # For now only presto-worldcereal as gfmap is up to date on pypi
23 | dir=$(pwd)
24 |
25 | su - jenkins -c "cd $dir && \
26 | source venv310/bin/activate && \
27 | pip install git+https://github.com/WorldCereal/presto-worldcereal.git@croptype && \
28 | pip install git+https://github.com/WorldCereal/prometheo.git
29 | "
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | exclude = venv|notebooks
3 | ignore_missing_imports = True
4 |
5 | [mypy-yaml.*]
6 | ignore_missing_imports = True
7 |
--------------------------------------------------------------------------------
/notebooks/notebook_utils/classifier.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple, Union
2 |
3 | import numpy as np
4 | import pandas as pd
5 | from catboost import CatBoostClassifier, Pool
6 | from loguru import logger
7 | from presto.utils import DEFAULT_SEED
8 | from sklearn.metrics import classification_report, confusion_matrix
9 | from sklearn.model_selection import train_test_split
10 | from sklearn.utils.class_weight import compute_class_weight
11 |
12 | from worldcereal.parameters import CropLandParameters, CropTypeParameters
13 |
14 |
15 | def get_input(label):
16 | while True:
17 | modelname = input(f"Enter a short name for your {label} (don't use spaces): ")
18 | if " " not in modelname:
19 | return modelname
20 | print("Invalid input. Please enter a name without spaces.")
21 |
22 |
23 | def prepare_training_dataframe(
24 | df: pd.DataFrame,
25 | batch_size: int = 256,
26 | task_type: str = "croptype",
27 | augment: bool = True,
28 | mask_ratio: float = 0.30,
29 | repeats: int = 1,
30 | ) -> pd.DataFrame:
31 | """Method to generate a training dataframe with Presto embeddings for downstream Catboost training.
32 |
33 | Parameters
34 | ----------
35 | df : pd.DataFrame
36 | input dataframe with required input features for Presto
37 | batch_size : int, optional
 38 |         batch size used when computing Presto embeddings, by default 256
39 | task_type : str, optional
40 | cropland or croptype task, by default "croptype"
41 | augment : bool, optional
42 | if True, temporal jittering is enabled, by default True
43 | mask_ratio : float, optional
44 | if > 0, inputs are randomly masked before computing Presto embeddings, by default 0.30
45 | repeats: int, optional
 46 |         number of times to repeat each sample, by default 1
47 |
48 | Returns
49 | -------
50 | pd.DataFrame
51 | output training dataframe for downstream training
52 |
53 | Raises
54 | ------
55 | ValueError
56 | if an unknown tasktype is specified
57 | ValueError
58 | if repeats > 1 and augment=False and mask_ratio=0
59 | """
60 | from presto.presto import Presto
61 |
62 | from worldcereal.train.data import WorldCerealTrainingDataset, get_training_df
63 |
64 | if task_type == "croptype":
65 | presto_model_url = CropTypeParameters().feature_parameters.presto_model_url
66 | use_valid_date_token = (
67 | CropTypeParameters().feature_parameters.use_valid_date_token
68 | )
69 | elif task_type == "cropland":
70 | presto_model_url = CropLandParameters().feature_parameters.presto_model_url
71 | use_valid_date_token = (
72 | CropLandParameters().feature_parameters.use_valid_date_token
73 | )
74 | else:
75 | raise ValueError(f"Unknown task type: {task_type}")
76 |
77 | if repeats > 1 and not augment and mask_ratio == 0:
78 | raise ValueError("Repeats > 1 requires augment=True or mask_ratio > 0.")
79 |
80 | # Load pretrained Presto model
81 | logger.info(f"Presto URL: {presto_model_url}")
82 | presto_model = Presto.load_pretrained(
83 | presto_model_url,
84 | from_url=True,
85 | strict=False,
86 | valid_month_as_token=use_valid_date_token,
87 | )
88 |
89 | # Initialize dataset
90 | df = df.reset_index()
91 | ds = WorldCerealTrainingDataset(
92 | df,
93 | task_type=task_type,
 94 |         augment=augment,
95 | mask_ratio=mask_ratio,
96 | repeats=repeats,
97 | )
98 | logger.info("Computing Presto embeddings ...")
99 | df = get_training_df(
100 | ds,
101 | presto_model,
102 | batch_size=batch_size,
103 | valid_date_as_token=use_valid_date_token,
104 | )
105 |
106 | logger.info("Done.")
107 |
108 | return df
109 |
110 |
111 | def train_classifier(
112 | training_dataframe: pd.DataFrame,
113 | class_names: Optional[List[str]] = None,
114 | balance_classes: bool = False,
115 | ) -> Tuple[CatBoostClassifier, Union[str, dict], np.ndarray]:
116 | """Method to train a custom CatBoostClassifier on a training dataframe.
117 |
118 | Parameters
119 | ----------
120 | training_dataframe : pd.DataFrame
121 | training dataframe containing inputs and targets
122 | class_names : Optional[List[str]], optional
123 | class names to use, by default None
124 | balance_classes : bool, optional
125 | if True, class weights are used during training to balance the classes, by default False
126 |
127 | Returns
128 | -------
129 |     Tuple[CatBoostClassifier, Union[str, dict], np.ndarray]
130 | The trained CatBoost model, the classification report, and the confusion matrix
131 |
132 | Raises
133 | ------
134 | ValueError
135 | When not enough classes are present in the training dataframe to train a model
136 | """
137 |
138 | logger.info("Split train/test ...")
139 | samples_train, samples_test = train_test_split(
140 | training_dataframe,
141 | test_size=0.2,
142 | random_state=DEFAULT_SEED,
143 | stratify=training_dataframe["downstream_class"],
144 | )
145 |
146 | # Define loss function and eval metric
147 | if np.unique(samples_train["downstream_class"]).shape[0] < 2:
148 | raise ValueError("Not enough classes to train a classifier.")
149 | elif np.unique(samples_train["downstream_class"]).shape[0] > 2:
150 | eval_metric = "MultiClass"
151 | loss_function = "MultiClass"
152 | else:
153 | eval_metric = "Logloss"
154 | loss_function = "Logloss"
155 |
156 | # Compute sample weights
157 | if balance_classes:
158 | logger.info("Computing class weights ...")
159 | class_weights = np.round(
160 | compute_class_weight(
161 | class_weight="balanced",
162 | classes=np.unique(samples_train["downstream_class"]),
163 | y=samples_train["downstream_class"],
164 | ),
165 | 3,
166 | )
167 | class_weights = {
168 | k: v
169 | for k, v in zip(np.unique(samples_train["downstream_class"]), class_weights)
170 | }
171 | logger.info(f"Class weights: {class_weights}")
172 |
173 | sample_weights = np.ones((len(samples_train["downstream_class"]),))
174 | sample_weights_val = np.ones((len(samples_test["downstream_class"]),))
175 | for k, v in class_weights.items():
176 | sample_weights[samples_train["downstream_class"] == k] = v
177 | sample_weights_val[samples_test["downstream_class"] == k] = v
178 | samples_train["weight"] = sample_weights
179 | samples_test["weight"] = sample_weights_val
180 | else:
181 | samples_train["weight"] = 1
182 | samples_test["weight"] = 1
183 |
184 | # Define classifier
185 | custom_downstream_model = CatBoostClassifier(
186 | iterations=2000, # Not too high to avoid too large model size
187 | depth=8,
188 | early_stopping_rounds=20,
189 | loss_function=loss_function,
190 | eval_metric=eval_metric,
191 | random_state=DEFAULT_SEED,
192 | verbose=25,
193 | class_names=(
194 | class_names
195 | if class_names is not None
196 | else np.unique(samples_train["downstream_class"])
197 | ),
198 | )
199 |
200 | # Setup dataset Pool
201 | bands = [f"presto_ft_{i}" for i in range(128)]
202 | calibration_data = Pool(
203 | data=samples_train[bands],
204 | label=samples_train["downstream_class"],
205 | weight=samples_train["weight"],
206 | )
207 | eval_data = Pool(
208 | data=samples_test[bands],
209 | label=samples_test["downstream_class"],
210 | weight=samples_test["weight"],
211 | )
212 |
213 | # Train classifier
214 | logger.info("Training CatBoost classifier ...")
215 | custom_downstream_model.fit(
216 | calibration_data,
217 | eval_set=eval_data,
218 | )
219 |
220 | # Make predictions
221 | pred = custom_downstream_model.predict(samples_test[bands]).flatten()
222 |
223 | report = classification_report(samples_test["downstream_class"], pred)
224 |     conf_matrix = confusion_matrix(samples_test["downstream_class"], pred)
225 |
226 |     return custom_downstream_model, report, conf_matrix
227 |
--------------------------------------------------------------------------------
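
A minimal usage sketch of the two helpers above (the input file name is hypothetical; it assumes a dataframe with the Presto input features and a "downstream_class" target column, and that `notebook_utils` is importable, e.g. when running from the notebooks folder):

```python
import pandas as pd

from notebook_utils.classifier import prepare_training_dataframe, train_classifier

# Hypothetical input: one labelled sample per row, with Presto input features
df = pd.read_parquet("training_samples.parquet")

# Compute Presto embeddings (presto_ft_0 ... presto_ft_127) for downstream training
training_df = prepare_training_dataframe(df, task_type="cropland")

# Train a CatBoost model and inspect its performance on the held-out split
model, report, conf_matrix = train_classifier(training_df, balance_classes=True)
print(report)
```
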
/notebooks/notebook_utils/dateslider.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, timedelta
2 |
3 | import ipywidgets as widgets
4 | import pandas as pd
5 | from IPython.core.display import HTML as core_HTML
6 | from IPython.display import display
7 | from loguru import logger
8 | from openeo_gfmap import TemporalContext
9 |
10 |
11 | class date_slider:
12 | """Class that provides a slider for selecting a processing period.
13 | The processing period is fixed in length, amounting to one year.
14 | The processing period will always start the first day of a month and end the last day of a month.
15 | """
16 |
17 | def __init__(self, start_date=datetime(2018, 1, 1), end_date=datetime(2024, 12, 1)):
18 |
19 | # Define the slider
20 | dates = pd.date_range(start_date, end_date, freq="MS")
21 | options = [(date.strftime("%b %Y"), date) for date in dates]
22 | self.interval_slider = widgets.SelectionRangeSlider(
23 | options=options,
 24 |             index=(0, 11),  # Default to an 11-month handle spread (a one-year period)
25 | orientation="horizontal",
26 | description="",
27 | continuous_update=False,
28 | behaviour="drag",
29 | style={
30 | "handle_color": "dodgerblue",
31 | },
32 | layout=widgets.Layout(
33 | width="600px",
34 | margin="0 0 0 10px",
35 | ),
36 | readout=False,
37 | )
38 |
39 | # Define the HTML text widget for the selected range and focus time
40 | initial_range = [
41 | (pd.to_datetime(start_date)).strftime("%d %b %Y"),
42 | (
43 | pd.to_datetime(start_date)
44 | + pd.DateOffset(months=12)
45 | - timedelta(days=1)
46 | ).strftime("%d %b %Y"),
47 | ]
48 | initial_focus_time = (
49 | pd.to_datetime(start_date) + pd.DateOffset(months=6)
50 | ).strftime("%b %Y")
51 | self.html_text = widgets.HTML(
52 | value=f"Selected range: {initial_range[0]} - {initial_range[1]}
Season center: {initial_focus_time}",
53 | placeholder="HTML placeholder",
54 | description="",
55 | layout=widgets.Layout(justify_content="center", display="flex"),
56 | )
57 |
58 | # Attach slider observer
59 | self.interval_slider.observe(self.on_slider_change, names="value")
60 |
61 | # Add custom CSS for the ticks
62 | custom_css = """
63 |
97 | """
98 |
99 | # # Generate ticks
100 | tick_dates = pd.date_range(
101 | start_date, pd.to_datetime(end_date) + pd.DateOffset(months=1), freq="4MS"
102 | )
103 | tick_labels = [date.strftime("%b %Y") for date in tick_dates]
104 | n_labels = len(tick_labels)
105 | ticks_html = ""
106 | for i, label in enumerate(tick_labels):
107 | position = (i / (n_labels - 1)) * 100 # Position as a percentage
108 | ticks_html += f"""
109 | |
110 | {label.split()[0]}
{label.split()[1]}
111 | """
112 |
113 | # HTML container for tick marks and labels
114 | tick_marks_and_labels = widgets.HTML(
115 | value=f"""
116 |
123 | """
124 | )
125 |
126 | # Combine slider and ticks using VBox
127 | slider_with_ticks = widgets.VBox(
128 | [self.interval_slider, tick_marks_and_labels],
129 | layout=widgets.Layout(
130 | width="640px", align_items="center", justify_content="center"
131 | ),
132 | )
133 |
134 | # Add description widget
135 | descr_widget = widgets.HTML(
136 | value="""
137 |
138 |
139 | Position the slider to select your processing period:
140 |
141 |
142 | """
143 | )
144 |
145 | # Arrange the description widget, interval slider, ticks and text widget in a VBox
146 | vbox = widgets.VBox(
147 | [
148 | descr_widget,
149 | slider_with_ticks,
150 | self.html_text,
151 | ],
152 | layout=widgets.Layout(
153 | align_items="center", justify_content="center", width="650px"
154 | ),
155 | )
156 |
157 | display(core_HTML(custom_css))
158 | display(vbox)
159 |
160 | def on_slider_change(self, change):
161 |
162 | start, end = change["new"]
163 |
164 | # keep the interval fixed
165 | expected_end = start + pd.DateOffset(months=11)
166 | if end != expected_end:
167 | end = start + pd.DateOffset(months=11)
168 | self.interval_slider.value = (start, end)
169 |
170 | # update the HTML text underneath the slider
171 |         selected_range = [
172 | (pd.to_datetime(start)).strftime("%d %b %Y"),
173 | (
174 | pd.to_datetime(start) + pd.DateOffset(months=12) - timedelta(days=1)
175 | ).strftime("%d %b %Y"),
176 | ]
177 | focus_time = (start + pd.DateOffset(months=6)).strftime("%b %Y")
178 | self.html_text.value = f"Selected range: {range[0]} - {range[1]}
Season center: {focus_time}"
179 |
180 | def get_processing_period(self):
181 |
182 | start = pd.to_datetime(self.interval_slider.value[0])
183 | end = start + pd.DateOffset(months=12) - timedelta(days=1)
184 |
185 | start = start.strftime("%Y-%m-%d")
186 | end = end.strftime("%Y-%m-%d")
187 | logger.info(f"Selected processing period: {start} to {end}")
188 |
189 | return TemporalContext(start, end)
190 |
--------------------------------------------------------------------------------
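
A minimal sketch of how this widget is meant to be used in a notebook (the slider renders itself on construction; the selection is read out in a later cell):

```python
from notebook_utils.dateslider import date_slider

# Displays the slider below the current notebook cell
slider = date_slider()

# After positioning the slider, retrieve the fixed one-year processing period
temporal_extent = slider.get_processing_period()  # returns a TemporalContext
```
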
/notebooks/notebook_utils/seasons.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from calendar import monthrange
3 | from typing import List
4 |
5 | import matplotlib.pyplot as plt
6 | import pandas as pd
7 | from matplotlib.patches import Rectangle
8 | from openeo_gfmap import BoundingBoxExtent
9 | from pyproj import Transformer
10 |
11 | from worldcereal.seasons import get_season_dates_for_extent
12 |
13 | logging.getLogger("rasterio").setLevel(logging.ERROR)
14 |
15 |
16 | def get_month_decimal(date):
17 |
18 | return date.timetuple().tm_mon + (
19 | date.timetuple().tm_mday / monthrange(2021, date.timetuple().tm_mon)[1]
20 | )
21 |
22 |
23 | def plot_worldcereal_seasons(
24 | seasons: dict,
25 | extent: BoundingBoxExtent,
26 | tag: str = "",
27 | ):
28 | """Method to plot WorldCereal seasons in a matplotlib plot.
29 | Parameters
30 | ----------
31 | seasons : dict
32 | dictionary with season names as keys and start and end dates as values
33 | dates need to be in datetime format
34 | extent : BoundingBoxExtent
35 | extent for which to plot the seasons
36 | tag : str, optional
37 | tag to add to the title of the plot, by default empty string
38 |
39 | Returns
40 | -------
41 | None
42 | """
43 |
44 | # get lat, lon centroid of extent
45 | transformer = Transformer.from_crs(
46 | f"EPSG:{extent.epsg}", "EPSG:4326", always_xy=True
47 | )
48 | minx, miny = transformer.transform(extent.west, extent.south)
49 | maxx, maxy = transformer.transform(extent.east, extent.north)
50 | lat = (maxy + miny) / 2
51 | lon = (maxx + minx) / 2
52 | location = f"lat={lat:.2f}, lon={lon:.2f}"
53 |
54 | # prepare figure
55 | fig, ax = plt.subplots()
56 | plt.title(f"WorldCereal seasons {tag} ({location})")
57 | ax.set_ylim((0.4, len(seasons) + 0.5))
58 | ax.set_xlim((0, 13))
59 | ax.set_yticks(range(1, len(seasons) + 1))
60 | ax.set_yticklabels(list(seasons.keys()))
61 | ax.set_xticks(range(1, 13))
62 | ax.set_xticklabels(
63 | [
64 | "Jan",
65 | "Feb",
66 | "Mar",
67 | "Apr",
68 | "May",
69 | "Jun",
70 | "Jul",
71 | "Aug",
72 | "Sep",
73 | "Oct",
74 | "Nov",
75 | "Dec",
76 | ]
77 | )
78 | facecolor = "darkgoldenrod"
79 |
80 | # Get the start and end date for each season
81 | idx = 0
82 | for name, dates in seasons.items():
83 | sos, eos = dates
84 |
85 | # get start and end month (decimals) for plotting
86 | start = get_month_decimal(sos)
87 | end = get_month_decimal(eos)
88 |
89 | # add rectangle to plot
90 | if start < end:
91 | ax.add_patch(
92 | Rectangle((start, idx + 0.75), end - start, 0.5, color=facecolor)
93 | )
94 | else:
95 | ax.add_patch(
96 | Rectangle((start, idx + 0.75), 12 - start, 0.5, color=facecolor)
97 | )
98 | ax.add_patch(Rectangle((1, idx + 0.75), end - 1, 0.5, color=facecolor))
99 |
100 | # add labels to plot
101 | label_start = sos.strftime("%B %d")
102 | label_end = eos.strftime("%B %d")
103 | plt.text(
104 | start - 0.2,
105 | idx + 0.65,
106 | label_start,
107 | fontsize=8,
108 | color="darkgreen",
109 | ha="left",
110 | va="center",
111 | )
112 | plt.text(
113 | end + 0.2,
114 | idx + 0.65,
115 | label_end,
116 | fontsize=8,
117 | color="darkred",
118 | ha="right",
119 | va="center",
120 | )
121 |
122 | idx += 1
123 |
124 | # display plot
125 | plt.show()
126 |
127 |
128 | def retrieve_worldcereal_seasons(
129 | extent: BoundingBoxExtent,
130 | seasons: List[str] = ["s1", "s2"],
131 | plot: bool = True,
132 | ):
133 | """Method to retrieve default WorldCereal seasons from global crop calendars.
134 | These will be logged to the screen for informative purposes.
135 |
136 | Parameters
137 | ----------
138 | extent : BoundingBoxExtent
139 | extent for which to load seasonality
140 | seasons : List[str], optional
141 | seasons to load, by default s1 and s2
142 | plot : bool, optional
143 | whether to plot the seasons, by default True
144 |
145 | Returns
146 | -------
147 | dict
148 | dictionary with season names as keys and start and end dates as values
149 | """
150 | results = {}
151 |
152 | # Get the start and end date for each season
153 |     for season in seasons:
154 | seasonal_extent = get_season_dates_for_extent(extent, 2021, f"tc-{season}")
155 | sos = pd.to_datetime(seasonal_extent.start_date)
156 | eos = pd.to_datetime(seasonal_extent.end_date)
157 | results[season] = (sos, eos)
158 |
159 | # Plot the seasons if requested
160 | if plot:
161 | plot_worldcereal_seasons(results, extent)
162 |
163 | return results
164 |
--------------------------------------------------------------------------------
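
A minimal usage sketch of the helper above (the bounding box is a hypothetical example; any BoundingBoxExtent works):

```python
from openeo_gfmap import BoundingBoxExtent

from notebook_utils.seasons import retrieve_worldcereal_seasons

# Hypothetical extent in lat/lon (EPSG:4326)
extent = BoundingBoxExtent(west=4.5, south=50.9, east=4.9, north=51.1, epsg=4326)

# Logs and plots the default tc-s1 and tc-s2 crop calendar seasons for this location
seasons = retrieve_worldcereal_seasons(extent, seasons=["s1", "s2"], plot=True)
```
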
/notebooks/patch_to_point/job_df.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/patch_to_point/job_df.parquet
--------------------------------------------------------------------------------
/notebooks/patch_to_point/rdm_query_example:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/patch_to_point/rdm_query_example
--------------------------------------------------------------------------------
/notebooks/resources/Cropland_inference_choose_end_date.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/resources/Cropland_inference_choose_end_date.png
--------------------------------------------------------------------------------
/notebooks/resources/Custom_cropland_map.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/resources/Custom_cropland_map.png
--------------------------------------------------------------------------------
/notebooks/resources/Custom_croptype_map.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/resources/Custom_croptype_map.png
--------------------------------------------------------------------------------
/notebooks/resources/Default_cropland_map.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/resources/Default_cropland_map.png
--------------------------------------------------------------------------------
/notebooks/resources/Landcover_training_data_density_PhI.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/resources/Landcover_training_data_density_PhI.png
--------------------------------------------------------------------------------
/notebooks/resources/MOOC_refdata_RDM_exploration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/resources/MOOC_refdata_RDM_exploration.png
--------------------------------------------------------------------------------
/notebooks/resources/WorldCereal_private_extractions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/resources/WorldCereal_private_extractions.png
--------------------------------------------------------------------------------
/notebooks/resources/cropland_data_Ph1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/resources/cropland_data_Ph1.png
--------------------------------------------------------------------------------
/notebooks/resources/croptype_data_Ph1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/notebooks/resources/croptype_data_Ph1.png
--------------------------------------------------------------------------------
/notebooks/resources/wc2eurocrops_map.csv:
--------------------------------------------------------------------------------
1 | name,croptype,ewoc_code,landcover
2 | Unknown,0,00-00-00-000-0,0
3 | Cereals,1000,11-01-00-000-0,11
4 | Wheat,1100,11-01-01-000-0,11
5 | Winter wheat,1110,11-01-01-000-1,11
6 | Spring wheat,1120,11-01-01-000-2,11
7 | Maize,1200,11-01-06-000-0,11
8 | Rice,1300,11-01-08-000-0,11
9 | Sorghum,1400,11-01-07-003-0,11
10 | Barley,1500,11-01-02-000-0,11
11 | Winter barley,1510,11-01-02-000-1,11
12 | Spring barley,1520,11-01-02-000-2,11
13 | Rye,1600,11-01-03-000-0,11
14 | Winter rye,1610,11-01-03-000-1,11
15 | Spring rye,1620,11-01-03-000-2,11
16 | Oats,1700,11-01-04-000-0,11
17 | Millets,1800,11-01-07-001-0,11
18 | Other cereals,1900,11-01-00-000-0,11
19 | Winter cereal,1910,11-01-00-000-1,11
20 | Spring cereal,1920,11-01-00-000-2,11
21 | Vegetables and melons,2000,11-03-00-000-0,11
22 | Leafy or stem vegetables,2100,11-03-04-000-0,11
23 | Artichokes,2110,11-03-04-002-0,11
24 | Asparagus,2120,11-03-04-003-0,11
25 | Cabages,2130,11-03-06-000-0,11
26 | Cauliflowers & brocoli,2140,11-03-06-013-0,11
27 | Lettuce,2150,11-03-08-011-0,11
28 | Spinach,2160,11-03-08-008-0,11
29 | Chicory,2170,11-03-08-002-0,11
30 | Other leafy/stem vegetables,2190,11-03-04-000-0,11
31 | Fruit-bearing vegetables,2200,11-03-01-000-0,11
32 | Cucumbers,2210,11-03-02-001-0,11
33 | Eggplants,2220,11-03-01-002-0,11
34 | Tomatoes,2230,11-03-01-001-0,11
35 | Watermelons,2240,11-03-02-005-0,11
36 | Cantaloupes and other melons,2250,11-03-02-003-0,11
37 | "Pumpkin, squash and gourds",2260,11-03-02-004-0,11
38 | Other fruit-bearing vegetables,2290,11-03-01-000-0,11
39 | "Root, bulb or tuberous vegetables",2300,11-03-09-000-0,11
40 | Carrots,2310,11-03-09-004-0,11
41 | Turnips,2320,11-03-09-007-0,11
42 | Garlic,2330,11-03-11-002-0,11
43 | Onions and shallots,2340,11-03-11-007-0,11
44 | Leeks & other alliaceous vegetables,2350,11-03-11-000-0,11
45 | "Other root, bulb or tuberous vegetables",2390,11-03-09-000-0,11
46 | Mushrooms and truffles,2400,11-04-00-000-0,11
47 | Other vegetables,2900,11-03-00-000-0,11
48 | Fruit and nuts,3000,12-01-00-000-0,12
49 | Tropical and subtropical fruits,3100,12-01-02-000-0,12
50 | Avocados,3110,12-01-02-001-0,12
51 | Bananas & plantains,3120,12-01-02-002-0,12
52 | Dates,3130,12-01-02-003-0,12
53 | Figs,3140,12-01-01-006-0,12
54 | Mangoes,3150,12-01-02-004-0,12
55 | Papayas,3160,12-01-02-005-0,12
56 | Pineapples,3170,12-01-02-006-0,12
57 | Other tropical and subtropical fruits,3190,12-01-02-000-0,12
58 | Citrus fruits,3200,12-01-03-000-0,12
59 | Grapefruit & pomelo,3210,12-01-03-000-0,12
60 | Lemons & Limes,3220,12-01-03-003-0,12
61 | Oranges,3230,12-01-03-004-0,12
62 | "Tangerines, mandarins, clementines",3240,12-01-03-005-0,12
63 | Other citrus fruit,3290,12-01-03-000-0,12
64 | Grapes,3300,12-01-00-001-0,12
65 | Berries,3400,12-01-05-000-0,12
66 | Currants,3410,12-01-05-006-0,12
67 | Gooseberries,3420,12-01-05-007-0,12
68 | Kiwi fruit,3430,12-01-05-015-0,12
69 | Raspberries,3440,12-01-05-010-0,12
70 | Strawberries,3450,11-03-12-001-0,11
71 | Blueberries,3460,12-01-05-004-0,12
72 | Other berries,3490,12-01-05-000-0,12
73 | Pome fruits and stone fruits,3500,12-01-01-000-0,12
74 | Apples,3510,12-01-01-002-0,12
75 | Apricots,3520,12-01-01-003-0,12
76 | Cherries & sour cherries,3530,12-01-01-004-0,12
77 | Peaches & nectarines,3540,12-01-01-017-0,12
78 | Pears & quinces,3550,12-01-01-019-0,12
79 | Plums and sloes,3560,12-01-01-012-0,12
80 | Other pome fruits and stone fruits,3590,12-01-01-000-0,12
81 | Nuts,3600,12-01-04-000-0,12
82 | Almonds,3610,12-01-04-001-0,12
83 | Cashew nuts,3620,12-01-04-007-0,12
84 | Chestnuts,3630,12-01-04-005-0,12
85 | Hazelnuts,3640,12-01-04-002-0,12
86 | Pistachios,3650,12-01-04-004-0,12
87 | Walnuts,3660,12-01-04-006-0,12
88 | Other nuts,3690,12-01-04-000-0,12
89 | Other fruit,3900,12-01-00-000-0,12
90 | Oilseed crops,4000,11-06-00-000-0,11
91 | Soya beans,4100,11-06-00-002-0,11
92 | Groundnuts,4200,11-06-00-005-0,11
93 | Temporary oilseed crops,4300,11-06-00-000-0,11
94 | Castor bean,4310,11-06-00-006-0,11
95 | Linseed,4320,11-06-00-007-0,11
96 | Mustard,4330,11-06-00-008-0,11
97 | Niger seed,4340,11-06-00-004-0,11
98 | Rapeseed,4350,11-06-00-003-0,11
99 | Winter rapeseed,4351,11-06-00-003-1,11
100 | Spring rapeseed,4352,11-06-00-003-2,11
101 | Safflower,4360,11-06-00-009-0,11
102 | Sesame,4370,11-06-00-010-0,11
103 | Sunflower,4380,11-06-00-001-0,11
104 | Other temporary oilseed crops,4390,11-06-00-000-0,11
105 | Permanent oilseed crops,4400,12-03-00-000-0,12
106 | Coconuts,4410,12-03-00-002-0,12
107 | Olives,4420,12-03-00-001-0,12
108 | Oil palms,4430,12-03-00-003-0,12
109 | Other oleaginous fruits,4490,12-03-00-000-0,12
110 | Root/tuber crops,5000,11-07-00-000-0,11
111 | Potatoes,5100,11-07-00-001-0,11
112 | Sweet potatoes,5200,11-07-00-002-0,11
113 | Cassava,5300,11-07-00-004-0,11
114 | Yams,5400,11-07-00-005-0,11
115 | Other roots and tubers,5900,11-07-00-000-0,11
116 | Beverage and spice crops,6000,10-00-00-000-0,10
117 | Beverage crops,6100,12-02-00-000-0,12
118 | Coffee,6110,12-02-00-001-0,12
119 | Tea,6120,12-02-00-002-0,12
120 | Maté,6130,12-02-00-003-0,12
121 | Cocoa,6140,12-02-00-004-0,12
122 | Other beverage crops,6190,10-00-00-000-0,10
123 | Spice crops,6200,12-02-00-000-0,12
124 | Chilies & peppers,6211,11-03-03-000-0,11
125 | "Anise, badian, fennel",6212,11-00-00-000-0,11
126 | Other temporary spice crops,6219,11-09-00-000-0,11
127 | Pepper,6221,11-09-00-029-0,11
128 | "Nutmeg, mace, cardamoms",6222,12-02-00-007-0,12
129 | Cinnamon,6223,12-02-00-008-0,12
130 | Cloves,6224,12-02-00-009-0,12
131 | Ginger,6225,11-09-00-048-0,11
132 | Vanilla,6226,11-09-00-049-0,10
133 | Other permanent spice crops,6229,12-02-00-000-0,12
134 | Leguminous crops,7000,11-05-00-000-0,11
135 | Beans,7100,11-05-01-001-0,11
136 | Broad beans,7200,11-05-01-003-0,11
137 | Chick peas,7300,11-05-01-004-0,11
138 | Cow peas,7400,11-05-01-005-0,11
139 | Lentils,7500,11-05-00-003-0,11
140 | Lupins,7600,11-05-00-004-0,11
141 | Peas,7700,11-05-01-002-0,11
142 | Pigeon peas,7800,11-05-01-006-0,11
143 | Other Leguminous crops,7900,10-00-00-000-0,10
144 | Other Leguminous crops - Temporary,7910,11-05-00-000-0,11
145 | Other Leguminous crops - Permanent,7920,12-04-00-000-0,12
146 | Sugar crops,8000,10-00-00-000-0,10
147 | Sugar beet,8100,11-07-00-003-1,11
148 | Sugar cane,8200,11-11-01-010-0,11
149 | Sweet sorghum,8300,11-01-07-004-0,11
150 | Other sugar crops,8900,10-00-00-000-0,10
151 | Other crops,9000,10-00-00-000-0,10
152 | Grasses and other fodder crops,9100,11-11-00-000-0,11
153 | Temporary grass crops,9110,11-11-00-001-0,11
154 | Permanent grass crops,9120,13-00-00-000-0,13
155 | Fibre crops,9200,10-00-00-000-0,10
156 | Temporary fibre crops,9210,11-08-00-000-0,11
157 | Cotton,9211,11-08-00-001-0,11
158 | "Jute, kenaf and similar",9212,11-08-00-002-0,11
159 | "Flax, hemp and similar",9213,11-08-02-000-0,11
160 | Other temporary fibre crops,9219,11-08-00-000-0,11
161 | Permanent fibre crops,9220,12-05-00-000-0,12
162 | "Medicinal, aromatic, pesticidal crops",9300,10-00-00-000-0,10
163 | Temporary medicinal etc crops,9310,11-09-00-000-0,11
164 | Permanent medicinal etc crops,9320,12-02-00-000-0,12
165 | Rubber,9400,12-06-00-008-0,12
166 | Flower crops,9500,11-10-00-000-0,11
167 | Temporary flower crops,9510,11-10-00-000-0,11
168 | Permanent flower crops,9520,12-06-00-000-0,12
169 | Tobacco,9600,11-09-00-050-0,11
170 | Other other crops,9900,10-00-00-000-0,10
171 | Other crops - temporary,9910,11-00-00-000-0,11
172 | Other crops - permanent,9920,12-00-00-000-0,12
173 | mixed cropping,9998,10-00-00-000-0,10
174 | Cropland,10,10-00-00-000-0,10
175 | Annual cropland,11,11-00-00-000-0,11
176 | Perennial cropland,12,12-00-00-000-0,12
177 | Grassland *,13,13-00-00-000-0,13
178 | Herbaceous vegetation,20,20-00-00-000-0,20
179 | Shrubland,30,30-00-00-000-0,30
180 | Deciduous forest,40,43-02-00-000-0,40
181 | Evergreen forest,41,43-01-00-000-0,41
182 | Mixed/unknown forest,42,43-00-00-000-0,42
183 | Bare / sparse vegetation,50,50-00-00-000-0,50
184 | Built up / urban,60,60-00-00-000-0,60
185 | Water,70,70-00-00-000-0,70
186 | Snow / ice,80,50-05-00-000-0,80
187 | No cropland (including perennials),98,17-00-00-000-0,98
188 | No cropland,99,16-00-00-000-0,99
189 | Unknown,991,00-00-00-000-0,0
190 | Unknown,9700,00-00-00-000-0,0
--------------------------------------------------------------------------------
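
A short sketch of how this mapping table can be queried with pandas (a hypothetical lookup; the column names come from the header row above):

```python
import pandas as pd

# Load the WorldCereal -> EWOC crop type mapping
mapping = pd.read_csv("notebooks/resources/wc2eurocrops_map.csv")

# e.g. look up the hierarchical EWOC code for winter wheat (croptype 1110)
ewoc_code = mapping.loc[mapping["croptype"] == 1110, "ewoc_code"].item()
print(ewoc_code)  # 11-01-01-000-1
```
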
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling", "hatch-vcs"]
3 | build-backend = "hatchling.build"
4 |
5 | [tool.hatch.build]
6 | exclude = [
7 | "/dist",
8 | "/notebooks",
9 | "/scripts",
10 | "/bin",
11 | "/tests",
12 | ]
13 |
14 | [tool.hatch.version]
15 | path = "src/worldcereal/_version.py"
16 | pattern = "^__version__ = ['\"](?P[^'\"]+)['\"]$"
17 |
18 | [tool.hatch.metadata]
19 | allow-direct-references = true
20 |
21 | [project]
22 | name = "worldcereal"
23 | authors = [
24 | { name="Kristof Van Tricht" },
25 | { name="Jeroen Degerickx" },
26 | { name="Darius Couchard" },
27 | { name="Christina Butsko" },
28 | ]
29 | description = "WorldCereal classification module"
30 | readme = "README.md"
31 | requires-python = ">=3.8"
32 | dynamic = ["version"]
33 | classifiers = [
34 | "Programming Language :: Python :: 3",
35 | "Operating System :: OS Independent",
36 | ]
37 | dependencies = [
38 | "boto3==1.35.30",
39 | "cftime",
40 | "geojson",
41 | "geopandas",
42 | "h3==4.1.0",
43 | "h5netcdf>=1.1.0",
44 | "loguru>=0.7.2",
45 | "netcdf4<=1.6.4",
46 | "numpy<2.0.0",
47 | "openeo==0.35.0",
48 | "openeo-gfmap==0.4.6",
49 | "pyarrow",
50 | "pydantic==2.8.0",
51 | "rioxarray>=0.13.0",
52 | "scipy",
53 | "duckdb==1.1.3",
54 | "tqdm",
55 | "xarray>=2022.3.0"
56 | ]
57 |
58 | [project.urls]
59 | "Homepage" = "https://github.com/WorldCereal/worldcereal-classification"
60 | "Bug Tracker" = "https://github.com/WorldCereal/worldcereal-classification/issues"
61 |
62 | [project.optional-dependencies]
63 | dev = [
64 | "pytest>=7.4.0",
65 | "pytest-depends",
66 | "matplotlib>=3.3.0"
67 | ]
68 | train = [
69 | "catboost==1.2.5",
70 | "presto-worldcereal==0.1.6",
71 | "scikit-learn==1.5.0",
72 | "torch==2.3.1",
73 | "pystac==1.10.1",
74 | "pystac-client==0.8.3"
75 | ]
76 | notebooks = [
77 | "ipywidgets==8.1.3",
78 | "leafmap==0.35.1"
79 | ]
80 |
81 | [tool.pytest.ini_options]
82 | testpaths = [
83 | "tests",
84 | ]
85 | addopts = [
86 | "--import-mode=prepend",
87 | ]
88 |
89 | [tool.isort]
90 | profile = "black"
91 |
92 | [tool.ruff]
93 | # line-length = 88
94 |
95 | [tool.ruff.lint]
96 | select = ["E", "F"]
97 | ignore = [
98 | "E501", # Ignore "line-too-long" issues, let black handle that.
99 | ]
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | testpaths = tests
3 | addopts = --strict-markers
4 | markers =
5 |     slow: marker for slow tests (skip them by adding the option '-m "not slow"')
6 |
--------------------------------------------------------------------------------
/scripts/extractions/extract.py:
--------------------------------------------------------------------------------
1 | """Main script to perform extractions. Each collection has it's specifities and
2 | own functions, but the setup and main thread execution is done here."""
3 |
4 | import argparse
5 | from pathlib import Path
6 | from typing import Dict, Optional, Union
7 |
8 | from openeo_gfmap import Backend
9 |
10 | from worldcereal.extract.common import run_extractions
11 | from worldcereal.stac.constants import ExtractionCollection
12 |
13 |
14 | def main(
15 | collection: ExtractionCollection,
16 | output_folder: Path,
17 | samples_df_path: Path,
18 | ref_id: str,
19 | max_locations_per_job: int = 500,
20 | memory: Optional[str] = None,
21 | python_memory: Optional[str] = None,
22 | max_executors: Optional[int] = None,
23 | parallel_jobs: int = 2,
24 | restart_failed: bool = False,
25 | extract_value: int = 1,
26 | backend=Backend.CDSE,
27 | write_stac_api: bool = False,
28 | image_name: Optional[str] = None,
29 | organization_id: Optional[int] = None,
30 | ) -> None:
31 | """Main function responsible for launching point and patch extractions.
32 |
33 | Parameters
34 | ----------
35 | collection : ExtractionCollection
36 | The collection to extract. Most popular: PATCH_WORLDCEREAL, POINT_WORLDCEREAL
37 | output_folder : Path
38 | The folder where to store the extracted data
39 | samples_df_path : Path
40 | Path to the input dataframe containing the geometries
41 | for which extractions need to be done
42 | ref_id : str
43 | Official ref_id of the source dataset
44 | max_locations_per_job : int, optional
45 | The maximum number of locations to extract per job, by default 500
46 | memory : str, optional
47 | Memory to allocate for the executor.
48 | If not specified, the default value is used, depending on type of collection.
49 | python_memory : str, optional
50 |         Memory to allocate for the python processes as well as OrfeoToolbox in the executors.
51 | If not specified, the default value is used, depending on type of collection.
52 | max_executors : int, optional
53 | Number of executors to run.
54 | If not specified, the default value is used, depending on type of collection.
55 | parallel_jobs : int, optional
56 |         The maximum number of parallel jobs to run at the same time, by default 2
57 | restart_failed : bool, optional
58 | Restart the jobs that previously failed, by default False
59 | extract_value : int, optional
60 |         All samples with an "extract" value equal to or larger than this one will be extracted, by default 1
61 | backend : openeo_gfmap.Backend, optional
62 | cloud backend where to run the extractions, by default Backend.CDSE
63 | write_stac_api : bool, optional
64 | Save metadata of extractions to STAC API (requires authentication), by default False
65 | image_name : str, optional
66 | Specific openEO image name to use for the jobs, by default None
67 | organization_id : int, optional
68 | ID of the organization to use for the job, by default None in which case
69 | the active organization for the user is used
70 |
71 | Returns
72 | -------
73 | None
74 | """
75 |
76 | # Compile custom job options
77 | job_options: Optional[Dict[str, Union[str, int]]] = {
78 | key: value
79 | for key, value in {
80 | "executor-memory": memory,
81 | "python-memory": python_memory,
82 | "max-executors": max_executors,
83 | "image-name": image_name,
84 | "etl_organization_id": organization_id,
85 | }.items()
86 | if value is not None
87 | } or None
88 |
89 | # We need to be sure that the
90 | # output_folder points to the correct ref_id
91 | if output_folder.name != ref_id:
92 | raise ValueError(
93 | f"`root_folder` should point to ref_id `{ref_id}`, instead got: {output_folder}"
94 | )
95 |
96 | # Fire up extractions
97 | run_extractions(
98 | collection,
99 | output_folder,
100 | samples_df_path,
101 | ref_id,
102 | max_locations_per_job=max_locations_per_job,
103 | job_options=job_options,
104 | parallel_jobs=parallel_jobs,
105 | restart_failed=restart_failed,
106 | extract_value=extract_value,
107 | backend=backend,
108 | write_stac_api=write_stac_api,
109 | )
110 |
111 | return
112 |
113 |
114 | if __name__ == "__main__":
115 | parser = argparse.ArgumentParser(description="Extract data from a collection")
116 | parser.add_argument(
117 | "collection",
118 | type=ExtractionCollection,
119 | choices=list(ExtractionCollection),
120 | help="The collection to extract",
121 | )
122 | parser.add_argument(
123 | "output_folder", type=Path, help="The folder where to store the extracted data"
124 | )
125 | parser.add_argument(
126 | "samples_df_path",
127 | type=Path,
128 | help="Path to the samples dataframe with the data to extract",
129 | )
130 | parser.add_argument(
131 | "--ref_id",
132 | type=str,
133 | required=True,
134 | help="Official `ref_id` of the source dataset",
135 | )
136 | parser.add_argument(
137 | "--max_locations",
138 | type=int,
139 | default=500,
140 | help="The maximum number of locations to extract per job",
141 | )
142 | parser.add_argument(
143 | "--memory",
144 | type=str,
145 | default="1800m",
146 | help="Memory to allocate for the executor.",
147 | )
148 | parser.add_argument(
149 | "--python_memory",
150 | type=str,
151 | default="1900m",
152 | help="Memory to allocate for the python processes as well as OrfeoToolbox in the executors.",
153 | )
154 | parser.add_argument(
155 | "--max_executors", type=int, default=22, help="Number of executors to run."
156 | )
157 | parser.add_argument(
158 | "--parallel_jobs",
159 | type=int,
160 | default=2,
161 | help="The maximum number of parallel jobs to run at the same time.",
162 | )
163 | parser.add_argument(
164 | "--restart_failed",
165 | action="store_true",
166 | help="Restart the jobs that previously failed.",
167 | )
168 | parser.add_argument(
169 | "--extract_value",
170 | type=int,
171 | default=1,
172 | help="The value of the `extract` flag to use in the dataframe.",
173 | )
174 | parser.add_argument(
175 | "--write_stac_api",
176 |         action="store_true",
177 |         default=False,
178 |         help="If set, write S1 and S2 patch extraction results to the STAC API.",
179 | )
180 | parser.add_argument(
181 | "--image_name",
182 | type=str,
183 | default=None,
184 | help="Specific openEO image name to use for the jobs.",
185 | )
186 | parser.add_argument(
187 | "--organization_id",
188 | type=int,
189 | default=None,
190 | help="ID of the organization to use for the job.",
191 | )
192 |
193 | args = parser.parse_args()
194 |
195 | main(
196 | collection=args.collection,
197 | output_folder=args.output_folder,
198 | samples_df_path=args.samples_df_path,
199 | ref_id=args.ref_id,
200 | max_locations_per_job=args.max_locations,
201 | memory=args.memory,
202 | python_memory=args.python_memory,
203 | max_executors=args.max_executors,
204 | parallel_jobs=args.parallel_jobs,
205 | restart_failed=args.restart_failed,
206 | extract_value=args.extract_value,
207 | backend=Backend.CDSE,
208 | write_stac_api=args.write_stac_api,
209 | image_name=args.image_name,
210 | organization_id=args.organization_id,
211 | )
212 |
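213 | # Example invocation (paths are hypothetical; note that the output folder name
214 | # must match the ref_id):
215 | #   python extract.py POINT_WORLDCEREAL /data/extractions/2021_BEL_LPIS-Flanders_POLY_110 \
216 | #       samples.parquet --ref_id 2021_BEL_LPIS-Flanders_POLY_110 --parallel_jobs 2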
--------------------------------------------------------------------------------
/scripts/inference/collect_inputs.py:
--------------------------------------------------------------------------------
1 | from openeo_gfmap import BoundingBoxExtent, TemporalContext
2 |
3 | from worldcereal.job import collect_inputs
4 |
5 |
6 | def main():
7 |
8 | # Set the spatial extent
9 | # bbox_utm = (664000.0, 5611120.0, 665000.0, 5612120.0)
10 | bbox_utm = (664000, 5611134, 684000, 5631134) # Large test
11 | epsg = 32631
12 | spatial_extent = BoundingBoxExtent(*bbox_utm, epsg)
13 |
14 | # Set temporal range
15 | temporal_extent = TemporalContext(
16 | start_date="2020-11-01",
17 | end_date="2021-10-31",
18 | )
19 |
20 | outfile = "local_presto_inputs_large.nc"
21 | collect_inputs(spatial_extent, temporal_extent, output_path=outfile)
22 |
23 |
24 | if __name__ == "__main__":
25 | main()
26 |
--------------------------------------------------------------------------------
/scripts/inference/cropland_mapping.py:
--------------------------------------------------------------------------------
1 | """Cropland mapping inference script, demonstrating the use of the GFMAP, Presto and WorldCereal classifiers in a first inference pipeline."""
2 |
3 | import argparse
4 | from pathlib import Path
5 |
6 | from loguru import logger
7 | from openeo_gfmap import BoundingBoxExtent, TemporalContext
8 | from openeo_gfmap.backend import Backend, BackendContext
9 |
10 | from worldcereal.job import WorldCerealProductType, generate_map
11 | from worldcereal.parameters import PostprocessParameters
12 |
13 | if __name__ == "__main__":
14 | parser = argparse.ArgumentParser(
15 | prog="WC - Crop Mapping Inference",
16 | description="Crop Mapping inference using GFMAP, Presto and WorldCereal classifiers",
17 | )
18 |
19 | parser.add_argument("minx", type=float, help="Minimum X coordinate (west)")
20 | parser.add_argument("miny", type=float, help="Minimum Y coordinate (south)")
21 | parser.add_argument("maxx", type=float, help="Maximum X coordinate (east)")
22 | parser.add_argument("maxy", type=float, help="Maximum Y coordinate (north)")
23 | parser.add_argument(
24 | "start_date", type=str, help="Starting date for data extraction."
25 | )
26 | parser.add_argument("end_date", type=str, help="Ending date for data extraction.")
27 | parser.add_argument(
28 | "product",
29 | type=str,
30 | help="Product to generate. One of ['cropland', 'croptype']",
31 | )
32 | parser.add_argument(
33 | "output_path",
34 | type=Path,
35 | help="Path to folder where to save the resulting GeoTiff.",
36 | )
37 | parser.add_argument(
38 | "--epsg",
39 | type=int,
40 | default=4326,
41 | help="EPSG code of the input `minx`, `miny`, `maxx`, `maxy` parameters.",
42 | )
43 | parser.add_argument(
44 | "--postprocess",
45 | action="store_true",
46 | help="Run postprocessing on the croptype and/or the cropland product after inference.",
47 | )
48 | parser.add_argument(
49 | "--class-probabilities",
50 | action="store_true",
51 | help="Output per-class probabilities in the resulting product",
52 | )
53 |
54 | args = parser.parse_args()
55 |
56 | minx = args.minx
57 | miny = args.miny
58 | maxx = args.maxx
59 | maxy = args.maxy
60 | epsg = args.epsg
61 |
62 | start_date = args.start_date
63 | end_date = args.end_date
64 |
65 | product = args.product
66 |
67 | # minx, miny, maxx, maxy = (664000, 5611134, 665000, 5612134) # Small test
68 | # minx, miny, maxx, maxy = (664000, 5611134, 684000, 5631134) # Large test
69 | # epsg = 32631
70 | # start_date = "2020-11-01"
71 | # end_date = "2021-10-31"
72 | # product = "croptype"
73 |
74 | spatial_extent = BoundingBoxExtent(minx, miny, maxx, maxy, epsg)
75 | temporal_extent = TemporalContext(start_date, end_date)
76 |
77 | backend_context = BackendContext(Backend.CDSE)
78 |
79 | job_results = generate_map(
80 | spatial_extent,
81 | temporal_extent,
82 | args.output_path,
83 | product_type=WorldCerealProductType(product),
84 | postprocess_parameters=PostprocessParameters(
85 | enable=args.postprocess, keep_class_probs=args.class_probabilities
86 | ),
87 | out_format="GTiff",
88 | backend_context=backend_context,
89 | )
90 | logger.success(f"Job finished:\n\t{job_results}")
91 |
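92 | # Example invocation (coordinates and dates are illustrative):
93 | #   python cropland_mapping.py 3.80 51.23 3.83 51.25 2020-11-01 2021-10-31 cropland ./output --epsg 4326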
--------------------------------------------------------------------------------
/scripts/inference/cropland_mapping_local.py:
--------------------------------------------------------------------------------
1 | """Perform cropland mapping inference using a local execution of presto.
2 |
3 | Make sure you test this script on the Python version 3.9+, and have worldcereal
4 | dependencies installed with the presto wheel file installed with it's dependencies.
5 | """
6 |
7 | from pathlib import Path
8 |
9 | import requests
10 | import xarray as xr
11 | from openeo_gfmap.features.feature_extractor import (
12 | EPSG_HARMONIZED_NAME,
13 | apply_feature_extractor_local,
14 | )
15 | from openeo_gfmap.inference.model_inference import apply_model_inference_local
16 |
17 | from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
18 | from worldcereal.openeo.inference import CropClassifier
19 |
20 | TEST_FILE_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/presto/localtestdata/local_presto_inputs.nc"
21 | TEST_FILE_PATH = Path.cwd() / "presto_test_inputs.nc"
22 | PRESTO_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt"
23 | CATBOOST_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx"
24 |
25 | if __name__ == "__main__":
26 | if not TEST_FILE_PATH.exists():
27 | print("Downloading test input data...")
28 | # Download the test input data
29 | with requests.get(TEST_FILE_URL, stream=True, timeout=180) as response:
30 | response.raise_for_status()
31 | with open(TEST_FILE_PATH, "wb") as f:
32 | for chunk in response.iter_content(chunk_size=8192):
33 | f.write(chunk)
34 |
35 | print("Loading array in-memory...")
36 | arr = (
37 | xr.open_dataset(TEST_FILE_PATH)
38 | .to_array(dim="bands")
39 | .drop_sel(bands="crs")
40 | .astype("uint16")
41 | )
42 |
43 | print("Running presto UDF locally")
44 | features = apply_feature_extractor_local(
45 | PrestoFeatureExtractor,
46 | arr,
47 | parameters={
48 | EPSG_HARMONIZED_NAME: 32631,
49 | "ignore_dependencies": True,
50 | "presto_model_url": PRESTO_URL,
51 | },
52 | )
53 |
54 | features.to_netcdf(Path.cwd() / "presto_test_features_cropland.nc")
55 |
56 | print("Running classification inference UDF locally")
57 |
58 | classification = apply_model_inference_local(
59 | CropClassifier,
60 | features,
61 | parameters={
62 | EPSG_HARMONIZED_NAME: 32631,
63 | "ignore_dependencies": True,
64 | "classifier_url": CATBOOST_URL,
65 | },
66 | )
67 |
68 | classification.to_netcdf(Path.cwd() / "test_classification_cropland.nc")
69 |
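70 | # The generated NetCDF outputs can be inspected locally, e.g. (sketch):
71 | # xr.open_dataset(Path.cwd() / "test_classification_cropland.nc")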
--------------------------------------------------------------------------------
/scripts/inference/cropland_mapping_udp.py:
--------------------------------------------------------------------------------
1 | """This short script demonstrates how to run WorldCereal cropland extent inference through
2 | an OpenEO User-Defined Process (UDP) on CDSE.
3 |
4 | The WorldCereal default cropland model is used and default post-processing is applied
5 | (using smooth_probabilities method).
6 |
7 | The user needs to manually download the resulting map through the OpenEO Web UI:
8 | https://openeo.dataspace.copernicus.eu/
9 | Or use the openeo API to fetch the resulting map.
10 | """
11 |
12 | import openeo
13 |
14 | ###### USER INPUTS ######
15 | # Define the spatial and temporal extent
16 | spatial_extent = {
17 | "west": 3.809252,
18 | "south": 51.232365,
19 | "east": 3.833542,
20 | "north": 51.245477,
21 | "crs": "EPSG:4326",
22 | "srs": "EPSG:4326",
23 | }
24 |
25 | # Temporal extent needs to consist of exactly twelve months,
26 | # always starting first day of the month and ending last day of the month.
27 | temporal_extent = ["2020-01-01", "2020-12-31"]
28 |
29 | # This argument is optional and determines the projection of the output products
30 | target_projection = "EPSG:4326"
31 |
32 | # Here you can set custom job options, such as driver-memory, executor-memory, etc.
33 | # In case you would like to use the default options, set job_options to None.
34 |
35 | # job_options = {
36 | # "driver-memory": "4g",
37 | # "executor-memory": "2g",
38 | # "python-memory": "3g",
39 | # }
40 | job_options = None
41 |
42 | ###### END OF USER INPUTS ######
43 |
44 | # Connect to openeo backend
45 | c = openeo.connect("openeo.dataspace.copernicus.eu").authenticate_oidc()
46 |
47 | # Define the operations
48 | classes = c.datacube_from_process(
49 | process_id="worldcereal_crop_extent",
50 | namespace="https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/refs/tags/worldcereal_crop_extent_v1.0.2/src/worldcereal/udp/worldcereal_crop_extent.json",
51 | temporal_extent=temporal_extent,
52 | spatial_extent=spatial_extent,
53 | projection=target_projection,
54 | )
55 |
56 | # Run the job
57 | job = classes.execute_batch(
58 | title="WorldCereal Crop Extent UDP test",
59 | out_format="GTiff",
60 | job_options=job_options,
61 | )
62 |
63 | ###### EXPLANATION OF RESULTS ######
64 | # Two GeoTiff images will be generated:
65 | # 1. "worldcereal-cropland-extent.tif": The original results of the WorldCereal cropland model
66 | # 2. "worldcereal-cropland-extent-postprocessed.tif": The results after applying the default post-processing,
67 | # which includes a smoothing of the probabilities and relabelling of pixels according to the smoothed probabilities.
68 |
69 | # Each tif file contains 4 bands, all at 10m resolution:
70 | # - Band 1: Predicted cropland labels (0: non-cropland, 1: cropland)
71 | # - Band 2: Probability of the winning class (50 - 100)
72 | # - Band 3: Probability of the non-cropland class (0 - 100)
73 | # - Band 4: Probability of the cropland class (0 - 100)
74 |
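75 | # Instead of using the Web UI, the results can also be fetched programmatically
76 | # once the batch job has finished (sketch; the target folder name is arbitrary):
77 | # job.get_results().download_files("cropland_udp_results")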
--------------------------------------------------------------------------------
/scripts/inference/croptype_mapping_udp.py:
--------------------------------------------------------------------------------
1 | """This short script demonstrates how to run WorldCereal crop type inference through
2 | an OpenEO User-Defined Process (UDP) on CDSE.
3 |
4 | No default model is currently available, meaning that a user first needs to
5 | train a custom model before using this script. Training a custom model can be done
6 | through the WorldCereal Custom Crop Type Training Notebook available here:
7 | https://github.com/WorldCereal/worldcereal-classification/blob/main/notebooks/worldcereal_custom_croptype.ipynb
8 |
9 | A user is required to specify a spatial and temporal extent, as well as the link to
10 | the crop type model to be used. The model should be hosted on a publicly accessible server.
11 |
12 | The user needs to manually download the resulting map through the OpenEO Web UI:
13 | https://openeo.dataspace.copernicus.eu/
14 | Or use the openeo API to fetch the resulting maps.
15 | """
16 |
17 | import openeo
18 |
19 | ###### USER INPUTS ######
20 | # Define the spatial and temporal extent
21 | spatial_extent = {
22 | "west": 622694.5968575787,
23 | "south": 5672232.857114074,
24 | "east": 623079.000934101,
25 | "north": 5672519.995940826,
26 | "crs": "EPSG:32631",
27 | "srs": "EPSG:32631",
28 | }
29 |
30 | # Temporal extent needs to consist of exactly twelve months,
31 | # always starting first day of the month and ending last day of the month.
32 | # Ideally, the season of interest should be nicely centered within the selected twelve months period.
33 | temporal_extent = ["2018-11-01", "2019-10-31"]
34 |
35 | # Provide the link to your custom model
36 | # The model should be in ONNX format and publicly accessible.
37 | model_url = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/downstream/tests/be_multiclass-test_custommodel.onnx"
38 |
39 | # Sentinel-1 orbit state, either "ASCENDING" or "DESCENDING".
40 | # In the future, this setting will be automatically determined by the system based on the regional extent.
41 | # For now, the user needs to explicitly specify the orbit state based on the most dominant orbit state in the region of interest.
42 | # To make this decision, the user can consult the global map available here:
43 | # https://docs.sentinel-hub.com/api/latest/data/sentinel-1-grd/
44 | orbit_state = "ASCENDING" # system's default is "DESCENDING"
45 |
46 | ## OPTIONAL PARAMETERS
47 | # Post-processing method, available choices are "majority_vote" and "smooth_probabilities".
48 | # System's default is "smooth_probabilities", but we prefer "majority_vote" for this example.
49 | postprocess_method = "majority_vote"
50 | # If the postprocess_method is set to "majority_vote", the user can specify the kernel size.
51 | # The system's default is 5
52 | postprocess_kernel_size = 5
53 |
54 | ###### END OF USER INPUTS ######
55 |
56 | # Connect to openeo backend
57 | c = openeo.connect("openeo.dataspace.copernicus.eu").authenticate_oidc()
58 |
59 | # Define the operations
60 | cube = c.datacube_from_process(
61 | process_id="worldcereal_crop_type",
62 | namespace="https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/refs/heads/327-lut-crop-type/src/worldcereal/udp/worldcereal_crop_type.json",
63 | spatial_extent=spatial_extent,
64 | temporal_extent=temporal_extent,
65 | model_url=model_url,
66 | orbit_state=orbit_state,
67 | postprocess_method=postprocess_method,
68 | postprocess_kernel_size=postprocess_kernel_size,
69 | )
70 |
71 | # Run the job
72 | job = cube.execute_batch(
73 | title="Test worldcereal_crop_type UDP",
74 | )
75 |
76 | ###### EXPLANATION OF RESULTS ######
77 | # Four GeoTiff images will be generated.
78 |
79 | # Each raster contains at least three bands:
80 | # - Band 1: "classification": The classification label of the pixel.
81 | # - Band 2: "confidence": The class-specific probability of the winning class.
82 | # - Band 3 and beyond: "probability_xxx": Class-specific probabilities. The "xxx" indicates the associated class.
83 |
84 | # The following files will be generated:
85 | # 1. "cropland-raw.tif": The original temporary crops mask, generated by the default global WorldCereal temporary crops model.
86 | # This model differentiates temporary crops from all other land cover types.
87 | # (0: other land cover, 1: temporary crops, 255: no data)
88 |
89 | # 2. "cropland-postprocessed.tif": The final temporary crops mask after spatial cleaning of the raw result.
90 | # (0: other land cover, 1: temporary crops, 255: no data)
91 |
92 | # 3. "croptype-raw.tif": The original crop type map, generated by the user's custom crop type model.
93 | # The crop type model is only applied to the temporary crops pixels.
94 | # (labels according to the user's custom model, 254: no temporary crops, 65535: no data)
95 |
96 | # 4. "croptype-postprocessed.tif": The final crop type map after spatial cleaning of the raw result.
97 | # (labels according to the user's custom model, 254: no temporary crops, 65535: no data)
98 |
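99 | # As with the cropland UDP, the resulting files can be fetched programmatically
100 | # once the batch job has finished (sketch; the target folder name is arbitrary):
101 | # job.get_results().download_files("croptype_udp_results")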
--------------------------------------------------------------------------------
/scripts/misc/legend.py:
--------------------------------------------------------------------------------
1 | """
2 | Example script showing how to upload, download and delete the WorldCereal crop type legend from Artifactory.
3 | """
4 |
5 | from pathlib import Path
6 |
7 | from worldcereal.utils.legend import (
8 | delete_legend_file,
9 | download_legend,
10 | get_legend,
11 | upload_legend,
12 | )
13 |
14 | if __name__ == "__main__":
15 |
16 | # Example usage
17 | srcpath = Path("./WorldCereal_LC_CT_legend_20241231.csv")
18 | date = "20241231"
19 |
20 | # Upload the legend to Artifactory
21 | link = upload_legend(srcpath, date)
22 |
23 | # Get the latest legend from Artifactory (as pandas DataFrame)
24 | legend = get_legend()
25 |
26 | legend_irr = get_legend(topic="irrigation")
27 |
28 | # Download the latest legend from Artifactory
29 | legend_path = download_legend(Path("."))
30 |
31 | irr_legend_path = download_legend(Path("."), topic="irrigation")
32 |
33 | # Delete the uploaded legend from Artifactory
34 | delete_legend_file(link)
35 |
--------------------------------------------------------------------------------
/scripts/spark/compute_presto_features.py:
--------------------------------------------------------------------------------
1 | import glob
2 | from pathlib import Path
3 | from typing import Optional, Union
4 |
5 | import pandas as pd
6 | from loguru import logger
7 | from presto.presto import Presto
8 | from presto.utils import process_parquet
9 |
10 | from worldcereal.train.data import WorldCerealTrainingDataset, get_training_df
11 |
12 |
13 | def embeddings_from_parquet_file(
14 | parquet_file: Union[str, Path],
15 | pretrained_model: Presto,
16 | sample_repeats: int = 1,
17 | mask_ratio: float = 0.0,
18 | valid_date_as_token: bool = False,
19 | exclude_meteo: bool = False,
20 | ) -> Optional[pd.DataFrame]:
21 | """Method to compute Presto embeddings from a parquet file of preprocessed inputs
22 |
23 | Parameters
24 | ----------
25 | parquet_file : Union[str, Path]
26 | parquet file to read data from
27 | pretrained_model : Presto
28 | Presto model to use for computing embeddings
29 | sample_repeats : int, optional
30 | number of augmented sample repeats, by default 1
31 | mask_ratio : float, optional
32 | mask ratio to apply, by default 0.0
33 | valid_date_as_token : bool, optional
34 | feed valid date as a token to Presto, by default False
35 | exclude_meteo : bool, optional
36 | if True, meteo will be masked during embedding computation, by default False
37 |
38 | Returns
39 | -------
40 |     pd.DataFrame or None
41 |         DataFrame containing Presto embeddings and other attributes, or None if no valid samples remain
42 | """
43 |
44 | # Load parquet file
45 | logger.info(f"Processing {parquet_file}")
46 | df = pd.read_parquet(parquet_file)
47 |
48 | # Original file is partitioned by ref_id so we have to add
49 | # ref_id as a column manually
50 | df["ref_id"] = Path(parquet_file).parent.stem.split("=")[1]
51 |
52 | # Put meteo to nodata if needed
53 | if exclude_meteo:
54 | logger.warning("Excluding meteo data ...")
55 | df.loc[:, df.columns.str.contains("AGERA5")] = 65535
56 |
57 | # Convert to wide format
58 | logger.info("From long to wide format ...")
59 | df = process_parquet(df).reset_index()
60 |
61 | # Check if samples remain
62 | if df.empty:
63 | logger.warning("Empty dataframe: returning None")
64 | return None
65 |
66 | # Create dataset and dataloader
67 | logger.info("Making data loader ...")
68 | ds = WorldCerealTrainingDataset(
69 | df,
70 | task_type="cropland",
71 | augment=True,
72 | mask_ratio=mask_ratio,
73 | repeats=sample_repeats,
74 | )
75 |
76 | if len(ds) == 0:
77 | logger.warning("No valid samples in dataset: returning None")
78 | return None
79 |
80 | return get_training_df(
81 | ds,
82 | pretrained_model,
83 | batch_size=2048,
84 | valid_date_as_token=valid_date_as_token,
85 | num_workers=4,
86 | )
87 |
88 |
89 | def main(
90 | infile,
91 | outfile,
92 | presto_model,
93 | sc=None,
94 | debug=False,
95 | sample_repeats: int = 1,
96 | mask_ratio: float = 0.0,
97 | valid_date_as_token: bool = False,
98 | exclude_meteo: bool = False,
99 | ):
100 |
101 | logger.info(
102 | f"Starting embedding computation with augmentation (sample_repeats: {sample_repeats}, mask_ratio: {mask_ratio})"
103 | )
104 | logger.info(f"Valid date as token: {valid_date_as_token}")
105 |
106 | # List parquet files
107 | parquet_files = glob.glob(str(infile) + "/**/*.parquet", recursive=True)
108 |
109 | if debug:
110 | parquet_files = parquet_files[:2]
111 |
112 | # Load model
113 | logger.info("Loading model ...")
114 | pretrained_model = Presto.load_pretrained(
115 | model_path=presto_model, strict=False, valid_month_as_token=valid_date_as_token
116 | )
117 |
118 | if sc is not None:
119 | logger.info(f"Parallelizing {len(parquet_files)} files ...")
120 | dfs = (
121 | sc.parallelize(parquet_files, len(parquet_files))
122 | .map(
123 | lambda x: embeddings_from_parquet_file(
124 | x,
125 | pretrained_model,
126 | sample_repeats,
127 | mask_ratio,
128 | valid_date_as_token,
129 | exclude_meteo,
130 | )
131 | )
132 | .filter(lambda x: x is not None)
133 | .collect()
134 | )
135 |     else:
136 |         dfs = []
137 |         for parquet_file in parquet_files:
138 |             df = embeddings_from_parquet_file(
139 |                 parquet_file,
140 |                 pretrained_model,
141 |                 sample_repeats,
142 |                 mask_ratio,
143 |                 valid_date_as_token,
144 |                 exclude_meteo,
145 |             )
146 |             if df is not None:
147 |                 dfs.append(df)
148 |
149 | if isinstance(dfs, list):
150 | logger.info(f"Done processing: concatenating {len(dfs)} results")
151 | dfs = pd.concat(dfs, ignore_index=True)
152 |
153 | logger.info(f"Final dataframe shape: {dfs.shape}")
154 |
155 | # Write to parquet
156 | logger.info(f"Writing to parquet: {outfile}")
157 | dfs.to_parquet(outfile, index=False)
158 |
159 | logger.success("All done!")
160 |
161 |
162 | if __name__ == "__main__":
163 | # Output feature basedir
164 | baseoutdir = Path(
165 | "/vitodata/worldcereal/features/preprocessedinputs-monthly-nointerp"
166 | )
167 |
168 | spark = True
169 | localspark = False
170 | debug = False
171 | exclude_meteo = False
172 | sample_repeats = 1
173 | valid_date_as_token = False
174 | presto_dir = Path("/vitodata/worldcereal/presto/finetuning")
175 | presto_model = (
176 | presto_dir
177 | / "presto-ss-wc-ft-ct_cropland_CROPLAND2_30D_random_time-token=none_balance=True_augment=True.pt"
178 | )
179 | identifier = ""
180 |
181 | if spark:
182 | from worldcereal.utils.spark import get_spark_context
183 |
184 | logger.info("Setting up spark ...")
185 | sc = get_spark_context(localspark=localspark)
186 | else:
187 | sc = None
188 |
189 | infile = baseoutdir / "worldcereal_training_data.parquet"
190 |
191 | if debug:
192 | outfile = baseoutdir / (
193 | f"training_df_{presto_model.stem}_presto-worldcereal{identifier}_DEBUG.parquet"
194 | )
195 | else:
196 | outfile = baseoutdir / (
197 | f"training_df_{presto_model.stem}_presto-worldcereal{identifier}.parquet"
198 | )
199 |
200 | if not infile.exists():
201 | raise FileNotFoundError(infile)
202 |
203 | main(
204 | infile,
205 | outfile,
206 | presto_model,
207 | sc=sc,
208 | debug=debug,
209 | sample_repeats=sample_repeats,
210 | valid_date_as_token=valid_date_as_token,
211 | exclude_meteo=exclude_meteo,
212 | )
213 |
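214 | # To run this script on the cluster, submit it through
215 | # scripts/spark/compute_presto_features.sh, which ships the zipped worldcereal
216 | # package and the presto wheel to the YARN executors.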
--------------------------------------------------------------------------------
/scripts/spark/compute_presto_features.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # shellcheck disable=SC2140
3 |
4 | export SPARK_HOME=/opt/spark3_2_0/
5 | export PATH="$SPARK_HOME/bin:$PATH"
6 | export PYTHONPATH=wczip
7 |
8 | cd src || exit
9 | zip -r ../dist/worldcereal.zip worldcereal
10 | cd ..
11 |
12 | EX_JAVAMEM='2g'
13 | EX_PYTHONMEM='14g'
14 | DR_JAVAMEM='8g'
15 | DR_PYTHONMEM='16g'
16 |
17 | PYSPARK_PYTHON=./ewocenv/bin/python \
18 | ${SPARK_HOME}/bin/spark-submit \
19 | --conf spark.yarn.appMasterEnv.PYSPARK_PYTHON="./ewocenv/bin/python" \
20 | --conf spark.yarn.appMasterEnv.PYSPARK_DRIVER_PYTHON="./ewocenv/bin/python" \
21 | --conf spark.executorEnv.LD_LIBRARY_PATH="./ewocenv/lib" \
22 | --conf spark.yarn.appMasterEnv.LD_LIBRARY_PATH="./ewocenv/lib" \
23 | --conf spark.executorEnv.PYSPARK_PYTHON="./ewocenv/bin/python" \
24 | --conf spark.yarn.appMasterEnv.PYTHONPATH=$PYTHONPATH \
25 | --conf spark.executorEnv.PYTHONPATH=$PYTHONPATH \
26 | --executor-memory ${EX_JAVAMEM} --driver-memory ${DR_JAVAMEM} \
27 | --conf spark.yarn.appMasterEnv.PYTHON_EGG_CACHE=./ \
28 | --conf spark.executorEnv.GDAL_CACHEMAX=512 \
29 | --conf spark.speculation=true \
30 | --conf spark.sql.broadcastTimeout=36000 \
31 | --conf spark.shuffle.registration.timeout=36000 \
32 | --conf spark.sql.execution.arrow.pyspark.enabled=true \
33 | --conf spark.shuffle.memoryFraction=0.6 \
34 | --conf spark.hadoop.mapreduce.output.fileoutputformat.compress=true \
35 | --conf spark.hadoop.mapreduce.output.fileoutputformat.compress.codec=org.apache.hadoop.io.compress.GzipCodec \
36 | --conf spark.driver.memoryOverhead=${DR_PYTHONMEM} --conf spark.executor.memoryOverhead=${EX_PYTHONMEM} \
37 | --conf spark.memory.fraction=0.2 \
38 | --conf spark.executor.cores=4 \
39 | --conf spark.task.cpus=4 \
40 | --conf spark.driver.cores=4 \
41 | --conf spark.dynamicAllocation.maxExecutors=1000 \
42 | --conf spark.shuffle.service.enabled=true --conf spark.dynamicAllocation.enabled=true \
43 | --conf spark.driver.maxResultSize=0 \
44 | --master yarn --deploy-mode cluster --queue default \
45 | --py-files "/vitodata/worldcereal/software/wheels/presto_worldcereal-0.1.6-py3-none-any.whl" \
46 | --archives "dist/worldcereal.zip#wczip","hdfs:///tapdata/worldcereal/worldcereal.tar.gz#ewocenv" \
47 | --conf spark.app.name="worldcereal-presto_features" \
48 | scripts/spark/compute_presto_features.py
49 |
--------------------------------------------------------------------------------
/scripts/stac/build_paths.py:
--------------------------------------------------------------------------------
1 | """From a folder of extracted patches, generates a list of paths to the patches
2 | to be later parsed by the Spark cluster.
3 | """
4 |
5 | import argparse
6 | import os
7 | import pickle
8 | import re
9 | from concurrent.futures import ThreadPoolExecutor
10 | from pathlib import Path
11 |
12 | from worldcereal.stac.constants import COLLECTION_REGEXES, ExtractionCollection
13 |
14 |
15 | def iglob_files(path, notification_folders, pattern=None):
16 | """
17 |     Generator that yields all files under `path` whose names match the regex `pattern`
18 | """
19 | root_dir, folders, files = next(os.walk(path))
20 |
21 | for f in files:
22 |
23 | if (pattern is None) or len(re.findall(pattern, f)):
24 | file_path = os.path.join(root_dir, f)
25 | yield file_path
26 |
27 | for d in folders:
28 | # If the folder is in the notification folders list, print a message
29 | if os.path.join(path, d) in notification_folders:
30 | print(f"Searching in {d}")
31 | new_path = os.path.join(root_dir, d)
32 | yield from iglob_files(new_path, notification_folders, pattern)
33 |
34 |
35 | def glob_files(path, notification_folders, pattern=None, threads=50):
36 | """
37 |     Return all files within `path` and its subdirectories whose names match the regex `pattern`
38 | """
39 | with ThreadPoolExecutor(max_workers=threads) as ex:
40 | files = list(
41 | ex.map(lambda x: x, iglob_files(path, notification_folders, pattern))
42 | )
43 |
44 | return files
45 |
46 |
47 | if __name__ == "__main__":
48 | parser = argparse.ArgumentParser(
49 | description="Generates a list of paths of patches to parse."
50 | )
51 |
52 | parser.add_argument(
53 | "collection",
54 | type=ExtractionCollection,
55 | choices=list(ExtractionCollection),
56 | help="The collection to extract",
57 | )
58 | parser.add_argument(
59 | "input_folder",
60 | type=Path,
61 | help="The path to the folder containing the extracted patches.",
62 | )
63 | parser.add_argument(
64 | "output_path",
65 | type=Path,
66 | help="The path to the pickle file where the paths will be saved.",
67 | )
68 |
69 | args = parser.parse_args()
70 |
71 |     # Pattern to filter the files belonging to the selected collection
72 | pattern = COLLECTION_REGEXES[args.collection]
73 |
74 | root_subfolders = [
75 | str(path) for path in args.input_folder.iterdir() if path.is_dir()
76 | ]
77 |
78 | files = glob_files(
79 | path=args.input_folder,
80 | notification_folders=root_subfolders,
81 | pattern=pattern,
82 | )
83 |
84 | with open(args.output_path, "wb") as f:
85 | pickle.dump(files, f)
86 |
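87 | # Example invocation (paths and collection name are hypothetical):
88 | #   python build_paths.py PATCH_WORLDCEREAL /data/extractions /tmp/patch_paths.pkl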
--------------------------------------------------------------------------------
/scripts/stac/split_catalogue.py:
--------------------------------------------------------------------------------
1 | """Split catalogue by the local UTM projection of the products to be utilisable
2 | by OpenEO processes."""
3 |
4 | import argparse
5 | import logging
6 | import pickle
7 | from pathlib import Path
8 |
9 | from openeo_gfmap.utils.split_stac import split_collection_by_epsg
10 | from tqdm import tqdm
11 |
12 | # Logger used for the pipeline
13 | builder_log = logging.getLogger("catalogue_splitter")
14 |
15 | builder_log.setLevel(level=logging.INFO)
16 |
17 | stream_handler = logging.StreamHandler()
18 | builder_log.addHandler(stream_handler)
19 |
20 | formatter = logging.Formatter("%(asctime)s|%(name)s|%(levelname)s: %(message)s")
21 | stream_handler.setFormatter(formatter)
22 |
23 |
24 | if __name__ == "__main__":
25 | parser = argparse.ArgumentParser(
26 | description="Splits the catalogue by the local UTM projection of the products."
27 | )
28 | parser.add_argument(
29 | "input_folder",
30 | type=Path,
31 | help="The path to the folder containing the collection files.",
32 | )
33 | parser.add_argument(
34 | "output_folder",
35 | type=Path,
36 | help="The path where to save the splitted STAC collections to.",
37 | )
38 |
39 | args = parser.parse_args()
40 |
41 | if not args.input_folder.exists():
42 | raise FileNotFoundError(f"The input folder {args.input_folder} does not exist.")
43 |
44 | if not args.output_folder.exists():
45 | raise FileNotFoundError(
46 | f"The output folder {args.output_folder} does not exist."
47 | )
48 |
49 | builder_log.info("Loading the catalogues from the directory %s", args.input_folder)
50 | # List the catalogues in the input folder
51 | catalogues = []
52 | for catalogue_path in tqdm(args.input_folder.glob("*.pkl")):
53 | with open(catalogue_path, "rb") as file:
54 | catalogue = pickle.load(file)
55 | try:
56 | catalogue.strategy
57 | except AttributeError:
58 | setattr(catalogue, "strategy", None)
59 | catalogues.append(catalogue)
60 |
61 | builder_log.info("Loaded %s catalogues. Merging them...", len(catalogues))
62 |
63 | merged_catalogue = None
64 |     for catalogue in tqdm(catalogues):
65 |         if merged_catalogue is None:
66 |             merged_catalogue = catalogue
67 |         else:
68 |             merged_catalogue.add_items(catalogue.get_all_items())
69 |
70 | if merged_catalogue is None:
71 | raise ValueError("No catalogues found in the input folder.")
72 |
73 | builder_log.info("Merged catalogues into one. Updating the extent...")
74 | merged_catalogue.update_extent_from_items()
75 |
76 | with open("temp_merged_catalogue.pkl", "wb") as file:
77 | pickle.dump(merged_catalogue, file)
78 |
79 | builder_log.info("Splitting the catalogue by the local UTM projection...")
80 | split_collection_by_epsg(merged_catalogue, args.output_folder)
81 |
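82 | # Example invocation (paths are hypothetical; the output folder must already exist):
83 | #   python split_catalogue.py /data/stac/catalogues /data/stac/split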
--------------------------------------------------------------------------------
/src/worldcereal/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from ._version import __version__
4 |
5 | __all__ = ["__version__"]
6 |
7 | SUPPORTED_SEASONS = [
8 | "tc-s1",
9 | "tc-s2",
10 | "tc-annual",
11 | "custom",
12 | ]
13 |
14 | SEASONAL_MAPPING = {
15 | "tc-s1": "S1",
16 | "tc-s2": "S2",
17 | "tc-annual": "ANNUAL",
18 | "custom": "custom",
19 | }
20 |
21 |
22 | # Default buffer (days) prior to
23 | # season start
24 | SEASON_PRIOR_BUFFER = {
25 | "tc-s1": 0,
26 | "tc-s2": 0,
27 | "tc-annual": 0,
28 | "custom": 0,
29 | }
30 |
31 |
32 | # Default buffer (days) after
33 | # season end
34 | SEASON_POST_BUFFER = {
35 | "tc-s1": 0,
36 | "tc-s2": 0,
37 | "tc-annual": 0,
38 | "custom": 0,
39 | }
40 |
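41 | # Illustrative lookup: SEASONAL_MAPPING translates a supported season identifier
42 | # into the prefix used by the crop calendar rasters in data/cropcalendars
43 | # (e.g. ANNUAL_SOS_WGS84.tif / ANNUAL_EOS_WGS84.tif); this correspondence is an
44 | # assumption based on the file naming:
45 | # season_prefix = SEASONAL_MAPPING["tc-annual"]  # -> "ANNUAL"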
--------------------------------------------------------------------------------
/src/worldcereal/_version.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | __version__ = "2.2.3"
4 |
--------------------------------------------------------------------------------
/src/worldcereal/data/cropcalendars/ANNUAL_EOS_WGS84.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/data/cropcalendars/ANNUAL_EOS_WGS84.tif
--------------------------------------------------------------------------------
/src/worldcereal/data/cropcalendars/ANNUAL_SOS_WGS84.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/data/cropcalendars/ANNUAL_SOS_WGS84.tif
--------------------------------------------------------------------------------
/src/worldcereal/data/cropcalendars/S1_EOS_WGS84.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/data/cropcalendars/S1_EOS_WGS84.tif
--------------------------------------------------------------------------------
/src/worldcereal/data/cropcalendars/S1_SOS_WGS84.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/data/cropcalendars/S1_SOS_WGS84.tif
--------------------------------------------------------------------------------
/src/worldcereal/data/cropcalendars/S2_EOS_WGS84.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/data/cropcalendars/S2_EOS_WGS84.tif
--------------------------------------------------------------------------------
/src/worldcereal/data/cropcalendars/S2_SOS_WGS84.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/data/cropcalendars/S2_SOS_WGS84.tif
--------------------------------------------------------------------------------
/src/worldcereal/data/cropcalendars/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/data/cropcalendars/__init__.py
--------------------------------------------------------------------------------
/src/worldcereal/data/croptype_mappings/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/data/croptype_mappings/__init__.py
--------------------------------------------------------------------------------
/src/worldcereal/data/croptype_mappings/wc2eurocrops_map.csv:
--------------------------------------------------------------------------------
1 | name,croptype,ewoc_code,landcover
2 | Unknown,0,00-00-00-000-0,0
3 | Cereals,1000,11-01-00-000-0,11
4 | Wheat,1100,11-01-01-000-0,11
5 | Winter wheat,1110,11-01-01-000-1,11
6 | Spring wheat,1120,11-01-01-000-2,11
7 | Maize,1200,11-01-06-000-0,11
8 | Rice,1300,11-01-08-000-0,11
9 | Sorghum,1400,11-01-07-003-0,11
10 | Barley,1500,11-01-02-000-0,11
11 | Winter barley,1510,11-01-02-000-1,11
12 | Spring barley,1520,11-01-02-000-2,11
13 | Rye,1600,11-01-03-000-0,11
14 | Winter rye,1610,11-01-03-000-1,11
15 | Spring rye,1620,11-01-03-000-2,11
16 | Oats,1700,11-01-04-000-0,11
17 | Millets,1800,11-01-07-001-0,11
18 | Other cereals,1900,11-01-00-000-0,11
19 | Winter cereal,1910,11-01-00-000-1,11
20 | Spring cereal,1920,11-01-00-000-2,11
21 | Vegetables and melons,2000,11-03-00-000-0,11
22 | Leafy or stem vegetables,2100,11-03-04-000-0,11
23 | Artichokes,2110,11-03-04-002-0,11
24 | Asparagus,2120,11-03-04-003-0,11
25 | Cabages,2130,11-03-06-000-0,11
26 | Cauliflowers & brocoli,2140,11-03-06-013-0,11
27 | Lettuce,2150,11-03-08-011-0,11
28 | Spinach,2160,11-03-08-008-0,11
29 | Chicory,2170,11-03-08-002-0,11
30 | Other leafy/stem vegetables,2190,11-03-04-000-0,11
31 | Fruit-bearing vegetables,2200,11-03-01-000-0,11
32 | Cucumbers,2210,11-03-02-001-0,11
33 | Eggplants,2220,11-03-01-002-0,11
34 | Tomatoes,2230,11-03-01-001-0,11
35 | Watermelons,2240,11-03-02-005-0,11
36 | Cantaloupes and other melons,2250,11-03-02-003-0,11
37 | "Pumpkin, squash and gourds",2260,11-03-02-004-0,11
38 | Other fruit-bearing vegetables,2290,11-03-01-000-0,11
39 | "Root, bulb or tuberous vegetables",2300,11-03-09-000-0,11
40 | Carrots,2310,11-03-09-004-0,11
41 | Turnips,2320,11-03-09-007-0,11
42 | Garlic,2330,11-03-11-002-0,11
43 | Onions and shallots,2340,11-03-11-007-0,11
44 | Leeks & other alliaceous vegetables,2350,11-03-11-000-0,11
45 | "Other root, bulb or tuberous vegetables",2390,11-03-09-000-0,11
46 | Mushrooms and truffles,2400,11-04-00-000-0,11
47 | Other vegetables,2900,11-03-00-000-0,11
48 | Fruit and nuts,3000,12-01-00-000-0,12
49 | Tropical and subtropical fruits,3100,12-01-02-000-0,12
50 | Avocados,3110,12-01-02-001-0,12
51 | Bananas & plantains,3120,12-01-02-002-0,12
52 | Dates,3130,12-01-02-003-0,12
53 | Figs,3140,12-01-01-006-0,12
54 | Mangoes,3150,12-01-02-004-0,12
55 | Papayas,3160,12-01-02-005-0,12
56 | Pineapples,3170,12-01-02-006-0,12
57 | Other tropical and subtropical fruits,3190,12-01-02-000-0,12
58 | Citrus fruits,3200,12-01-03-000-0,12
59 | Grapefruit & pomelo,3210,12-01-03-000-0,12
60 | Lemons & Limes,3220,12-01-03-003-0,12
61 | Oranges,3230,12-01-03-004-0,12
62 | "Tangerines, mandarins, clementines",3240,12-01-03-005-0,12
63 | Other citrus fruit,3290,12-01-03-000-0,12
64 | Grapes,3300,12-01-00-001-0,12
65 | Berries,3400,12-01-05-000-0,12
66 | Currants,3410,12-01-05-006-0,12
67 | Gooseberries,3420,12-01-05-007-0,12
68 | Kiwi fruit,3430,12-01-05-015-0,12
69 | Raspberries,3440,12-01-05-010-0,12
70 | Strawberries,3450,11-03-12-001-0,11
71 | Blueberries,3460,12-01-05-004-0,12
72 | Other berries,3490,12-01-05-000-0,12
73 | Pome fruits and stone fruits,3500,12-01-01-000-0,12
74 | Apples,3510,12-01-01-002-0,12
75 | Apricots,3520,12-01-01-003-0,12
76 | Cherries & sour cherries,3530,12-01-01-004-0,12
77 | Peaches & nectarines,3540,12-01-01-017-0,12
78 | Pears & quinces,3550,12-01-01-019-0,12
79 | Plums and sloes,3560,12-01-01-012-0,12
80 | Other pome fruits and stone fruits,3590,12-01-01-000-0,12
81 | Nuts,3600,12-01-04-000-0,12
82 | Almonds,3610,12-01-04-001-0,12
83 | Cashew nuts,3620,12-01-04-007-0,12
84 | Chestnuts,3630,12-01-04-005-0,12
85 | Hazelnuts,3640,12-01-04-002-0,12
86 | Pistachios,3650,12-01-04-004-0,12
87 | Walnuts,3660,12-01-04-006-0,12
88 | Other nuts,3690,12-01-04-000-0,12
89 | Other fruit,3900,12-01-00-000-0,12
90 | Oilseed crops,4000,11-06-00-000-0,11
91 | Soya beans,4100,11-06-00-002-0,11
92 | Groundnuts,4200,11-06-00-005-0,11
93 | Temporary oilseed crops,4300,11-06-00-000-0,11
94 | Castor bean,4310,11-06-00-006-0,11
95 | Linseed,4320,11-06-00-007-0,11
96 | Mustard,4330,11-06-00-008-0,11
97 | Niger seed,4340,11-06-00-004-0,11
98 | Rapeseed,4350,11-06-00-003-0,11
99 | Winter rapeseed,4351,11-06-00-003-1,11
100 | Spring rapeseed,4352,11-06-00-003-2,11
101 | Safflower,4360,11-06-00-009-0,11
102 | Sesame,4370,11-06-00-010-0,11
103 | Sunflower,4380,11-06-00-001-0,11
104 | Other temporary oilseed crops,4390,11-06-00-000-0,11
105 | Permanent oilseed crops,4400,12-03-00-000-0,12
106 | Coconuts,4410,12-03-00-002-0,12
107 | Olives,4420,12-03-00-001-0,12
108 | Oil palms,4430,12-03-00-003-0,12
109 | Other oleaginous fruits,4490,12-03-00-000-0,12
110 | Root/tuber crops,5000,11-07-00-000-0,11
111 | Potatoes,5100,11-07-00-001-0,11
112 | Sweet potatoes,5200,11-07-00-002-0,11
113 | Cassava,5300,11-07-00-004-0,11
114 | Yams,5400,11-07-00-005-0,11
115 | Other roots and tubers,5900,11-07-00-000-0,11
116 | Beverage and spice crops,6000,10-00-00-000-0,10
117 | Beverage crops,6100,12-02-00-000-0,12
118 | Coffee,6110,12-02-00-001-0,12
119 | Tea,6120,12-02-00-002-0,12
120 | Maté,6130,12-02-00-003-0,12
121 | Cocoa,6140,12-02-00-004-0,12
122 | Other beverage crops,6190,10-00-00-000-0,10
123 | Spice crops,6200,12-02-00-000-0,12
124 | Chilies & peppers,6211,11-03-03-000-0,11
125 | "Anise, badian, fennel",6212,11-00-00-000-0,11
126 | Other temporary spice crops,6219,11-09-00-000-0,11
127 | Pepper,6221,11-09-00-029-0,11
128 | "Nutmeg, mace, cardamoms",6222,12-02-00-007-0,12
129 | Cinnamon,6223,12-02-00-008-0,12
130 | Cloves,6224,12-02-00-009-0,12
131 | Ginger,6225,11-09-00-048-0,11
132 | Vanilla,6226,11-09-00-049-0,10
133 | Other permanent spice crops,6229,12-02-00-000-0,12
134 | Leguminous crops,7000,11-05-00-000-0,11
135 | Beans,7100,11-05-01-001-0,11
136 | Broad beans,7200,11-05-01-003-0,11
137 | Chick peas,7300,11-05-01-004-0,11
138 | Cow peas,7400,11-05-01-005-0,11
139 | Lentils,7500,11-05-00-003-0,11
140 | Lupins,7600,11-05-00-004-0,11
141 | Peas,7700,11-05-01-002-0,11
142 | Pigeon peas,7800,11-05-01-006-0,11
143 | Other Leguminous crops,7900,10-00-00-000-0,10
144 | Other Leguminous crops - Temporary,7910,11-05-00-000-0,11
145 | Other Leguminous crops - Permanent,7920,12-04-00-000-0,12
146 | Sugar crops,8000,10-00-00-000-0,10
147 | Sugar beet,8100,11-07-00-003-1,11
148 | Sugar cane,8200,11-11-01-010-0,11
149 | Sweet sorghum,8300,11-01-07-004-0,11
150 | Other sugar crops,8900,10-00-00-000-0,10
151 | Other crops,9000,10-00-00-000-0,10
152 | Grasses and other fodder crops,9100,11-11-00-000-0,11
153 | Temporary grass crops,9110,11-11-00-001-0,11
154 | Permanent grass crops,9120,13-00-00-000-0,13
155 | Fibre crops,9200,10-00-00-000-0,10
156 | Temporary fibre crops,9210,11-08-00-000-0,11
157 | Cotton,9211,11-08-00-001-0,11
158 | "Jute, kenaf and similar",9212,11-08-00-002-0,11
159 | "Flax, hemp and similar",9213,11-08-02-000-0,11
160 | Other temporary fibre crops,9219,11-08-00-000-0,11
161 | Permanent fibre crops,9220,12-05-00-000-0,12
162 | "Medicinal, aromatic, pesticidal crops",9300,10-00-00-000-0,10
163 | Temporary medicinal etc crops,9310,11-09-00-000-0,11
164 | Permanent medicinal etc crops,9320,12-02-00-000-0,12
165 | Rubber,9400,12-06-00-008-0,12
166 | Flower crops,9500,11-10-00-000-0,11
167 | Temporary flower crops,9510,11-10-00-000-0,11
168 | Permanent flower crops,9520,12-06-00-000-0,12
169 | Tobacco,9600,11-09-00-050-0,11
170 | Other other crops,9900,10-00-00-000-0,10
171 | Other crops - temporary,9910,11-00-00-000-0,11
172 | Other crops - permanent,9920,12-00-00-000-0,12
173 | mixed cropping,9998,10-00-00-000-0,10
174 | Cropland,10,10-00-00-000-0,10
175 | Annual cropland,11,11-00-00-000-0,11
176 | Perennial cropland,12,12-00-00-000-0,12
177 | Grassland *,13,13-00-00-000-0,13
178 | Herbaceous vegetation,20,20-00-00-000-0,20
179 | Shrubland,30,30-00-00-000-0,30
180 | Deciduous forest,40,43-02-00-000-0,40
181 | Evergreen forest,41,43-01-00-000-0,41
182 | Mixed/unknown forest,42,43-00-00-000-0,42
183 | Bare / sparse vegetation,50,50-00-00-000-0,50
184 | Built up / urban,60,60-00-00-000-0,60
185 | Water,70,70-00-00-000-0,70
186 | Snow / ice,80,50-05-00-000-0,80
187 | No cropland (including perennials),98,17-00-00-000-0,98
188 | No cropland,99,16-00-00-000-0,99
189 | Unknown,991,00-00-00-000-0,0
190 | Unknown,9700,00-00-00-000-0,0
--------------------------------------------------------------------------------
/src/worldcereal/extract/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/extract/__init__.py
--------------------------------------------------------------------------------
/src/worldcereal/extract/patch_meteo.py:
--------------------------------------------------------------------------------
1 | """Extract AGERA5 (Meteo) data using OpenEO-GFMAP package."""
2 |
3 | import copy
4 | from typing import Dict, List, Optional, Union
5 |
6 | import geojson
7 | import geopandas as gpd
8 | import openeo
9 | import pandas as pd
10 | from openeo_gfmap import Backend, TemporalContext
11 |
12 | from worldcereal.extract.utils import ( # isort: skip
13 | buffer_geometry, # isort: skip
14 | filter_extract_true, # isort: skip
15 | upload_geoparquet_s3, # isort: skip
16 | ) # isort: skip
17 |
18 |
19 | DEFAULT_JOB_OPTIONS_PATCH_METEO = {
20 | "executor-memory": "1800m",
21 | "python-memory": "1000m",
22 | "max-executors": 22,
23 | }
24 |
25 |
26 | def create_job_dataframe_patch_meteo(
27 | backend: Backend, split_jobs: List[gpd.GeoDataFrame]
28 | ) -> pd.DataFrame:
29 | raise NotImplementedError("This function is not implemented yet.")
30 |
31 |
32 | def create_job_patch_meteo(
33 | row: pd.Series,
34 | connection: openeo.DataCube,
35 | provider,
36 | connection_provider,
37 | job_options: Optional[Dict[str, Union[str, int]]] = None,
38 | ) -> gpd.GeoDataFrame:
39 | start_date = row.start_date
40 | end_date = row.end_date
41 | temporal_context = TemporalContext(start_date, end_date)
42 |
43 | # Get the feature collection containing the geometry to the job
44 | geometry = geojson.loads(row.geometry)
45 | assert isinstance(geometry, geojson.FeatureCollection)
46 |
47 | # Filter the geometry to the rows with the extract only flag
48 | geometry = filter_extract_true(geometry)
49 | assert len(geometry.features) > 0, "No geometries with the extract flag found"
50 |
51 |     # Buffer the geometry by 5 m
52 | geometry_df = buffer_geometry(geometry, distance_m=5)
53 | spatial_extent_url = upload_geoparquet_s3(provider, geometry_df, row.name, "METEO")
54 |
55 | bands_to_download = ["temperature-mean", "precipitation-flux"]
56 |
57 | cube = connection.load_collection(
58 | "AGERA5",
59 | temporal_extent=[temporal_context.start_date, temporal_context.end_date],
60 | bands=bands_to_download,
61 | )
62 | filter_geometry = connection.load_url(spatial_extent_url, format="parquet")
63 | cube = cube.filter_spatial(filter_geometry)
64 |     cube = cube.rename_labels(
65 | dimension="bands",
66 | target=["AGERA5-temperature-mean", "AGERA5-precipitation-flux"],
67 | source=["temperature-mean", "precipitation-flux"],
68 | )
69 |
70 | # Rescale to uint16, multiplying by 100 first
71 | cube = cube * 100
72 | cube = cube.linear_scale_range(0, 65534, 0, 65534)
73 |
74 | h3index = geometry.features[0].properties["h3index"]
75 | valid_time = geometry.features[0].properties["valid_time"]
76 |
77 | # Set job options
78 | final_job_options = copy.deepcopy(DEFAULT_JOB_OPTIONS_PATCH_METEO)
79 | if job_options:
80 | final_job_options.update(job_options)
81 |
82 | return cube.create_job(
83 | out_format="NetCDF",
84 | title=f"WorldCereal_Patch-AGERA5_Extraction_{h3index}_{valid_time}",
85 | sample_by_feature=True,
86 | job_options=final_job_options,
87 | )
88 |
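89 | # Note: band values are stored as uint16 after multiplication by 100; to recover
90 | # the physical AGERA5 values, divide by 100 again (sketch, assuming xarray):
91 | # ds = xr.open_dataset("extraction.nc")
92 | # temperature = ds["AGERA5-temperature-mean"] / 100.0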
--------------------------------------------------------------------------------
/src/worldcereal/extract/patch_s1.py:
--------------------------------------------------------------------------------
1 | """Extract S1 data using OpenEO-GFMAP package."""
2 |
3 | import copy
4 | from datetime import datetime
5 | from typing import Dict, List, Optional, Union
6 |
7 | import geojson
8 | import geopandas as gpd
9 | import openeo
10 | import pandas as pd
11 | from openeo_gfmap import (
12 | Backend,
13 | BackendContext,
14 | BoundingBoxExtent,
15 | FetchType,
16 | TemporalContext,
17 | )
18 | from openeo_gfmap.preprocessing.sar import compress_backscatter_uint16
19 | from openeo_gfmap.utils.catalogue import s1_area_per_orbitstate_vvvh
20 | from tqdm import tqdm
21 |
22 | from worldcereal.openeo.preprocessing import raw_datacube_S1, spatially_filter_cube
23 | from worldcereal.rdm_api.rdm_interaction import RDM_DEFAULT_COLUMNS
24 |
25 | from worldcereal.extract.utils import ( # isort: skip
26 | buffer_geometry, # isort: skip
27 | get_job_nb_polygons, # isort: skip
28 | pipeline_log, # isort: skip
29 | # upload_geoparquet_s3, # isort: skip
30 | upload_geoparquet_artifactory, # isort: skip
31 | )
32 |
33 | S1_GRD_CATALOGUE_BEGIN_DATE = datetime(2014, 10, 1)
34 |
35 | DEFAULT_JOB_OPTIONS_PATCH_S1 = {
36 | "executor-memory": "1800m",
37 | "python-memory": "1900m",
38 | "max-executors": 22,
39 | "soft-errors": 0.1,
40 | }
41 |
42 |
43 | def create_job_dataframe_patch_s1(
44 | backend: Backend,
45 | split_jobs: List[gpd.GeoDataFrame],
46 | ) -> pd.DataFrame:
47 | """Create a dataframe from the split jobs, containing all the necessary information to run the job."""
48 | rows = []
49 | for job in tqdm(split_jobs):
50 |         # Determine the range of valid dates for this job
51 | min_time = job.valid_time.min()
52 | max_time = job.valid_time.max()
53 |
54 |         # Buffer the temporal extent by 275 days (~9 months)
55 |         # before and after the valid dates
56 | start_date = (min_time - pd.Timedelta(days=275)).to_pydatetime()
57 | end_date = (max_time + pd.Timedelta(days=275)).to_pydatetime()
58 |
59 | # Impose limits due to the data availability
60 | start_date = max(start_date, S1_GRD_CATALOGUE_BEGIN_DATE)
61 | end_date = min(end_date, datetime.now())
62 |
63 | # Convert dates to string format
64 | start_date, end_date = (
65 | start_date.strftime("%Y-%m-%d"),
66 | end_date.strftime("%Y-%m-%d"),
67 | )
68 |
69 |         s2_tile = job.tile.iloc[0]  # Job dataframes are split per S2 tile
70 | h3_l3_cell = job.h3_l3_cell.iloc[0]
71 |
72 | # Compute the WGS84 bounding box of the job geometry
73 | geometry_bbox = job.to_crs(epsg=4326).total_bounds
74 | # Buffer if the geometry is a point
75 | if geometry_bbox[0] == geometry_bbox[2]:
76 | geometry_bbox = (
77 | geometry_bbox[0] - 0.0001,
78 | geometry_bbox[1],
79 | geometry_bbox[2] + 0.0001,
80 | geometry_bbox[3],
81 | )
82 | if geometry_bbox[1] == geometry_bbox[3]:
83 | geometry_bbox = (
84 | geometry_bbox[0],
85 | geometry_bbox[1] - 0.0001,
86 | geometry_bbox[2],
87 | geometry_bbox[3] + 0.0001,
88 | )
89 |
90 | area_per_orbit = s1_area_per_orbitstate_vvvh(
91 | backend=BackendContext(backend),
92 | spatial_extent=BoundingBoxExtent(*geometry_bbox),
93 | temporal_extent=TemporalContext(start_date, end_date),
94 | )
95 | descending_area = area_per_orbit["DESCENDING"]["area"]
96 | ascending_area = area_per_orbit["ASCENDING"]["area"]
97 |
98 | # Set back the valid_time in the geometry as string
99 | job["valid_time"] = job.valid_time.dt.strftime("%Y-%m-%d")
100 |
101 | # Subset on required attributes
102 | job = job[RDM_DEFAULT_COLUMNS]
103 |
104 | variables = {
105 | "backend_name": backend.value,
106 | "out_prefix": "S1-SIGMA0-10m",
107 | "out_extension": ".nc",
108 | "start_date": start_date,
109 | "end_date": end_date,
110 | "s2_tile": s2_tile,
111 | "h3_l3_cell": h3_l3_cell,
112 | "geometry": job.to_json(),
113 | }
114 |
115 | if descending_area > 0:
116 | variables.update({"orbit_state": "DESCENDING"})
117 | rows.append(pd.Series(variables))
118 |
119 | if ascending_area > 0:
120 | variables.update({"orbit_state": "ASCENDING"})
121 | rows.append(pd.Series(variables))
122 |
123 | if descending_area + ascending_area == 0:
124 | pipeline_log.warning(
125 | "No S1 data available for the tile %s in the period %s - %s.",
126 | s2_tile,
127 | start_date,
128 | end_date,
129 | )
130 |
131 | return pd.DataFrame(rows)
132 |
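Usage sketch (an assumption, not from the source): `split_jobs` would come from a job splitter, and each GeoDataFrame must carry the `valid_time`, `tile` and `h3_l3_cell` columns used above; `Backend.CDSE` is assumed available in the installed openeo_gfmap.

    import geopandas as gpd
    from openeo_gfmap import Backend

    split_jobs = [...]  # hypothetical pre-split sample GeoDataFrames
    job_df = create_job_dataframe_patch_s1(Backend.CDSE, split_jobs)
    # One row per (job, orbit direction) that actually has S1 coverage
    print(job_df[["s2_tile", "orbit_state", "start_date", "end_date"]])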
133 |
134 | def create_job_patch_s1(
135 | row: pd.Series,
136 | connection: openeo.Connection,
137 | provider,
138 | connection_provider,
139 | job_options: Optional[Dict[str, Union[str, int]]] = None,
140 | ) -> openeo.BatchJob:
141 | """Creates an OpenEO BatchJob from the given row information. This job is a
142 | S1 patch of 32x32 pixels at 20m spatial resolution."""
143 |
144 | # Load the temporal extent
145 | start_date = row.start_date
146 | end_date = row.end_date
147 | temporal_context = TemporalContext(start_date, end_date)
148 |
149 | # Get the feature collection containing the geometry of the job
150 | geometry = geojson.loads(row.geometry)
151 | assert isinstance(geometry, geojson.FeatureCollection)
152 |
153 | # Jobs are run separately per orbit direction
154 | orbit_state = row.orbit_state
155 |
156 | # Buffer the centroids to 640 m square patches (32x32 px at 20 m)
157 | geometry_df = buffer_geometry(geometry, distance_m=320)
158 | # spatial_extent_url = upload_geoparquet_s3(
159 | # provider, geometry_df, f"{row.s2_tile}_{row.name}", "SENTINEL1"
160 | # )
161 | spatial_extent_url = upload_geoparquet_artifactory(
162 | geometry_df, f"{row.s2_tile}_{row.name}", collection="SENTINEL1"
163 | )
164 |
165 | # Backend name and fetching type
166 | backend = Backend(row.backend_name)
167 | backend_context = BackendContext(backend)
168 |
169 | # Create the job to extract S1
170 | cube = raw_datacube_S1(
171 | connection=connection,
172 | backend_context=backend_context,
173 | spatial_extent=spatial_extent_url,
174 | temporal_extent=temporal_context,
175 | bands=["S1-SIGMA0-VV", "S1-SIGMA0-VH"],
176 | fetch_type=FetchType.POLYGON,
177 | target_resolution=20,
178 | orbit_direction=orbit_state,
179 | )
180 | cube = compress_backscatter_uint16(backend_context, cube)
181 |
182 | # Apply spatial filtering
183 | cube = spatially_filter_cube(connection, cube, spatial_extent_url)
184 |
185 | # Additional values to generate the BatchJob name
186 | s2_tile = row.s2_tile
187 | valid_time = geometry.features[0].properties["valid_time"]
188 |
189 | # Increase the memory of the jobs depending on the number of polygons to extract
190 | number_polygons = get_job_nb_polygons(row)
191 | pipeline_log.debug("Number of polygons to extract %s", number_polygons)
192 |
193 | # Set job options
194 | final_job_options = copy.deepcopy(DEFAULT_JOB_OPTIONS_PATCH_S1)
195 | if job_options:
196 | final_job_options.update(job_options)
197 |
198 | return cube.create_job(
199 | out_format="NetCDF",
200 | title=f"Worldcereal_Patch-S1_Extraction_{s2_tile}_{orbit_state}_{valid_time}",
201 | sample_by_feature=True,
202 | job_options=final_job_options,
203 | feature_id_property="sample_id",
204 | )
205 |
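A hedged sketch of driving the job factory above, assuming `job_df` was built with create_job_dataframe_patch_s1 and an authenticated openeo connection (the backend URL below is an assumption); `provider` and `connection_provider` are only consumed by the commented-out S3 upload path, so placeholders are passed.

    import openeo

    connection = openeo.connect("openeo.vito.be").authenticate_oidc()
    for _, row in job_df.iterrows():
        job = create_job_patch_s1(row, connection, provider=None, connection_provider=None)
        job.start_and_wait()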
--------------------------------------------------------------------------------
/src/worldcereal/extract/patch_s2.py:
--------------------------------------------------------------------------------
1 | """Extract S2 data using OpenEO-GFMAP package."""
2 |
3 | import copy
4 | from datetime import datetime
5 | from typing import Dict, List, Optional, Union
6 |
7 | import geojson
8 | import geopandas as gpd
9 | import openeo
10 | import pandas as pd
11 | from openeo_gfmap import Backend, BackendContext, FetchType, TemporalContext
12 | from openeo_gfmap.manager import _log
13 | from tqdm import tqdm
14 |
15 | from worldcereal.openeo.preprocessing import raw_datacube_S2, spatially_filter_cube
16 | from worldcereal.rdm_api.rdm_interaction import RDM_DEFAULT_COLUMNS
17 |
18 | from worldcereal.extract.utils import ( # isort: skip
19 | buffer_geometry, # isort: skip
20 | get_job_nb_polygons, # isort: skip
21 | upload_geoparquet_artifactory, # isort: skip
22 | # upload_geoparquet_s3, # isort: skip
23 | ) # isort: skip
24 |
25 |
26 | S2_L2A_CATALOGUE_BEGIN_DATE = datetime(2017, 1, 1)
27 |
28 |
29 | DEFAULT_JOB_OPTIONS_PATCH_S2 = {
30 | "driver-memory": "2G",
31 | "driver-memoryOverhead": "2G",
32 | "driver-cores": "1",
33 | "executor-memory": "1800m",
34 | "python-memory": "1900m",
35 | "executor-cores": "1",
36 | "max-executors": 22,
37 | "soft-errors": 0.1,
38 | "gdal-dataset-cache-size": 2,
39 | "gdal-cachemax": 120,
40 | "executor-threads-jvm": 1,
41 | }
42 |
43 |
44 | def create_job_dataframe_patch_s2(
45 | backend: Backend,
46 | split_jobs: List[gpd.GeoDataFrame],
47 | ) -> pd.DataFrame:
48 | """Create a dataframe from the split jobs, containing all the necessary information to run the job."""
49 | rows = []
50 | for job in tqdm(split_jobs):
51 | # Determine the temporal extent of the job
52 | min_time = job.valid_time.min()
53 | max_time = job.valid_time.max()
54 | # Buffer by 275 days (~9 months) before and after the valid times
55 | start_date = (min_time - pd.Timedelta(days=275)).to_pydatetime()
56 | end_date = (max_time + pd.Timedelta(days=275)).to_pydatetime()
57 |
58 | # NOTE: imposing catalogue availability limits is currently disabled
59 | # start_date = max(start_date, S2_L2A_CATALOGUE_BEGIN_DATE)
60 | # end_date = min(end_date, datetime.now())
61 |
62 | s2_tile = job.tile.iloc[0]
63 | h3_l3_cell = job.h3_l3_cell.iloc[0]
64 |
65 | # Convert dates to string format
66 | start_date, end_date = (
67 | start_date.strftime("%Y-%m-%d"),
68 | end_date.strftime("%Y-%m-%d"),
69 | )
70 |
71 | # Set back the valid_time in the geometry as string
72 | job["valid_time"] = job.valid_time.dt.strftime("%Y-%m-%d")
73 |
74 | # Subset on required attributes
75 | job = job[RDM_DEFAULT_COLUMNS]
76 |
77 | variables = {
78 | "backend_name": backend.value,
79 | "out_prefix": "S2-L2A-10m",
80 | "out_extension": ".nc",
81 | "start_date": start_date,
82 | "end_date": end_date,
83 | "s2_tile": s2_tile,
84 | "h3_l3_cell": h3_l3_cell,
85 | "geometry": job.to_json(),
86 | }
87 | rows.append(pd.Series(variables))
88 |
89 | return pd.DataFrame(rows)
90 |
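A worked example (not from the source) of the 275-day buffer used by both job-dataframe builders:

    import pandas as pd

    valid = pd.Timestamp("2021-06-01")
    start = (valid - pd.Timedelta(days=275)).strftime("%Y-%m-%d")  # '2020-08-30'
    end = (valid + pd.Timedelta(days=275)).strftime("%Y-%m-%d")    # '2022-03-03'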
91 |
92 | def create_job_patch_s2(
93 | row: pd.Series,
94 | connection: openeo.Connection,
95 | provider,
96 | connection_provider,
97 | job_options: Optional[Dict[str, Union[str, int]]] = None,
98 | ) -> openeo.BatchJob:
99 | start_date = row.start_date
100 | end_date = row.end_date
101 | temporal_context = TemporalContext(start_date, end_date)
102 |
103 | # Get the feature collection containing the geometry of the job
104 | geometry = geojson.loads(row.geometry)
105 | assert isinstance(geometry, geojson.FeatureCollection)
106 |
107 | # # Filter the geometry to the rows with the extract only flag
108 | # geometry = filter_extract_true(geometry)
109 | # assert len(geometry.features) > 0, "No geometries with the extract flag found"
110 |
111 | # Buffer the centroids to 640 m square patches (64x64 px at 10 m)
112 | geometry_df = buffer_geometry(geometry, distance_m=320)
113 | # spatial_extent_url = upload_geoparquet_s3(
114 | # provider, geometry_df, f"{row.s2_tile}_{row.name}", "SENTINEL2"
115 | # )
116 | spatial_extent_url = upload_geoparquet_artifactory(
117 | geometry_df, f"{row.s2_tile}_{row.name}", collection="SENTINEL2"
118 | )
119 |
120 | # Backend name and fetching type
121 | backend = Backend(row.backend_name)
122 | backend_context = BackendContext(backend)
123 |
124 | # Values used to build the BatchJob title
125 | s2_tile = row.s2_tile
126 | valid_time = geometry.features[0].properties["valid_time"]
127 |
128 | bands_to_download = [
129 | "S2-L2A-B01",
130 | "S2-L2A-B02",
131 | "S2-L2A-B03",
132 | "S2-L2A-B04",
133 | "S2-L2A-B05",
134 | "S2-L2A-B06",
135 | "S2-L2A-B07",
136 | "S2-L2A-B08",
137 | "S2-L2A-B8A",
138 | "S2-L2A-B09",
139 | "S2-L2A-B11",
140 | "S2-L2A-B12",
141 | "S2-L2A-SCL",
142 | ]
143 |
144 | cube = raw_datacube_S2(
145 | connection,
146 | backend_context,
147 | temporal_context,
148 | bands_to_download,
149 | FetchType.POLYGON,
150 | spatial_extent=spatial_extent_url,
151 | filter_tile=s2_tile,
152 | apply_mask_flag=False,
153 | additional_masks_flag=True,
154 | )
155 |
156 | # Apply spatial filtering
157 | cube = spatially_filter_cube(connection, cube, spatial_extent_url)
158 |
159 | # Increase the memory of the jobs depending on the number of polygons to extract
160 | number_polygons = get_job_nb_polygons(row)
161 | _log.debug("Number of polygons to extract %s", number_polygons)
162 |
163 | # Set job options
164 | final_job_options = copy.deepcopy(DEFAULT_JOB_OPTIONS_PATCH_S2)
165 | if job_options:
166 | final_job_options.update(job_options)
167 |
168 | return cube.create_job(
169 | out_format="NetCDF",
170 | title=f"Worldcereal_Patch-S2_Extraction_{s2_tile}_{valid_time}",
171 | sample_by_feature=True,
172 | job_options=final_job_options,
173 | feature_id_property="sample_id",
174 | )
175 |
--------------------------------------------------------------------------------
/src/worldcereal/extract/utils.py:
--------------------------------------------------------------------------------
1 | """Common utilities used by extraction scripts."""
2 |
3 | import logging
4 | import os
5 | from tempfile import NamedTemporaryFile
6 |
7 | import geojson
8 | import geopandas as gpd
9 | import pandas as pd
10 | import requests
11 | from openeo_gfmap.manager.job_splitters import load_s2_grid
12 | from shapely import Point
13 |
14 | from worldcereal.utils.upload import OpenEOArtifactHelper
15 |
16 | # Logger used for the pipeline
17 | pipeline_log = logging.getLogger("extraction_pipeline")
18 |
19 | pipeline_log.setLevel(level=logging.INFO)
20 |
21 | stream_handler = logging.StreamHandler()
22 | pipeline_log.addHandler(stream_handler)
23 |
24 | formatter = logging.Formatter("%(asctime)s|%(name)s|%(levelname)s: %(message)s")
25 | stream_handler.setFormatter(formatter)
26 |
27 |
28 | # Exclude the other loggers from other libraries
29 | class ManagerLoggerFilter(logging.Filter):
30 | """Filter to only accept the OpenEO-GFMAP manager logs."""
31 |
32 | def filter(self, record):
33 | return record.name in [pipeline_log.name]
34 |
35 |
36 | stream_handler.addFilter(ManagerLoggerFilter())
37 |
38 |
39 | S2_GRID = load_s2_grid()
40 |
41 |
42 | def buffer_geometry(
43 | geometries: geojson.FeatureCollection, distance_m: int = 320
44 | ) -> gpd.GeoDataFrame:
45 | """For each geometry of the collection, perform a square buffer of 320
46 | meters on the centroid and return the GeoDataFrame. Before buffering,
47 | the centroid is clipped to the closest 20m multiplier in order to stay
48 | aligned with the Sentinel-1 pixel grid.
49 | """
50 | gdf = gpd.GeoDataFrame.from_features(geometries).set_crs(epsg=4326)
51 | utm = gdf.estimate_utm_crs()
52 | gdf = gdf.to_crs(utm)
53 |
54 | # Perform the buffering operation
55 | gdf["geometry"] = gdf.centroid.apply(
56 | lambda point: Point(round(point.x / 20.0) * 20.0, round(point.y / 20.0) * 20.0)
57 | ).buffer(
58 | distance=distance_m, cap_style=3
59 | ) # Square buffer
60 |
61 | return gdf
62 |
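A small sketch of the buffering behaviour on a single hypothetical point:

    import geojson

    fc = geojson.FeatureCollection(
        [geojson.Feature(geometry=geojson.Point((5.0, 51.0)), properties={})]
    )
    gdf = buffer_geometry(fc, distance_m=320)
    # A 640 m square in the estimated UTM CRS, corners on the 20 m grid
    print(gdf.geometry.iloc[0].bounds)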
63 |
64 | def filter_extract_true(
65 | geometries: geojson.FeatureCollection, extract_value: int = 1
66 | ) -> geojson.FeatureCollection:
67 | """Keep only the features whose `extract` property equals `extract_value` (default 1)."""
68 | return geojson.FeatureCollection(
69 | [
70 | f
71 | for f in geometries.features
72 | if f.properties.get("extract", 0) == extract_value
73 | ]
74 | )
75 |
76 |
77 | def upload_geoparquet_s3(
78 | backend: str, gdf: gpd.GeoDataFrame, name: str, collection: str = ""
79 | ) -> str:
80 | """Upload the given GeoDataFrame to s3 and return the URL of the
81 | uploaded file. Necessary as a workaround for Polygon sampling in OpenEO
82 | using custom CRS.
83 | """
84 | # Save the dataframe as geoparquet to upload it to S3
85 | temporary_file = NamedTemporaryFile()
86 | gdf.to_parquet(temporary_file.name)
87 |
88 | targetpath = f"openeogfmap_dataframe_{collection}_{name}.parquet"
89 |
90 | artifact_helper = OpenEOArtifactHelper.from_openeo_backend(backend)
91 | normal_s3_uri = artifact_helper.upload_file(targetpath, temporary_file.name)
92 | presigned_uri = artifact_helper.get_presigned_url(normal_s3_uri)
93 |
94 | return presigned_uri
95 |
96 |
97 | def upload_geoparquet_artifactory(
98 | gdf: gpd.GeoDataFrame, name: str, collection: str = ""
99 | ) -> str:
100 | """Upload the given GeoDataFrame to artifactory and return the URL of the
101 | uploaded file. Necessary as a workaround for Polygon sampling in OpenEO
102 | using custom CRS.
103 | """
104 | # Save the dataframe as geoparquet to upload it to artifactory
105 | temporary_file = NamedTemporaryFile()
106 | gdf.to_parquet(temporary_file.name)
107 |
108 | artifactory_username = os.getenv("ARTIFACTORY_USERNAME")
109 | artifactory_password = os.getenv("ARTIFACTORY_PASSWORD")
110 |
111 | if not artifactory_username or not artifactory_password:
112 | raise ValueError(
113 | "Artifactory credentials not found. Please set ARTIFACTORY_USERNAME and ARTIFACTORY_PASSWORD."
114 | )
115 |
116 | headers = {"Content-Type": "application/octet-stream"}
117 |
118 | upload_url = f"https://artifactory.vgt.vito.be/artifactory/auxdata-public/gfmap-temp/openeogfmap_dataframe_{collection}_{name}.parquet"
119 |
120 | with open(temporary_file.name, "rb") as f:
121 | response = requests.put(
122 | upload_url,
123 | headers=headers,
124 | data=f,
125 | auth=(artifactory_username, artifactory_password),
126 | timeout=180,
127 | )
128 |
129 | response.raise_for_status()
130 |
131 | return upload_url
132 |
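Usage sketch with placeholder credentials and a hypothetical file name; the function requires the two environment variables named in the error message above.

    import os

    os.environ["ARTIFACTORY_USERNAME"] = "me"      # placeholder
    os.environ["ARTIFACTORY_PASSWORD"] = "secret"  # placeholder
    url = upload_geoparquet_artifactory(gdf, "31UFS_0", collection="SENTINEL2")
    print(url)  # .../gfmap-temp/openeogfmap_dataframe_SENTINEL2_31UFS_0.parquet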
133 |
134 | def get_job_nb_polygons(row: pd.Series) -> int:
135 | """Get the number of polygons in the geometry."""
136 | return len(
137 | list(
138 | filter(
139 | lambda feat: feat.properties.get("extract"),
140 | geojson.loads(row.geometry)["features"],
141 | )
142 | )
143 | )
144 |
--------------------------------------------------------------------------------
/src/worldcereal/openeo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/src/worldcereal/openeo/__init__.py
--------------------------------------------------------------------------------
/src/worldcereal/openeo/inference.py:
--------------------------------------------------------------------------------
1 | """Model inference on Presto feature for binary classication"""
2 |
3 | import functools
4 |
5 | import onnxruntime as ort
6 | import requests
7 | import xarray as xr
8 | from openeo_gfmap.inference.model_inference import ModelInference
9 |
10 |
11 | class CropClassifier(ModelInference):
12 | """Binary or multi-class crop classifier using ONNX to load a catboost model.
13 |
14 | The classifier uses the embeddings computed from the Presto Feature
15 | Extractor.
16 |
17 | Interesting UDF parameters:
18 | - classifier_url: A public URL to the ONNX classification model. Default is
19 | the public Presto model.
20 | - lookup_table: A dictionary mapping class names to class labels, ordered by
21 | model probability output. This is required for the model to map the output
22 | probabilities to class names.
23 | """
24 |
25 | import numpy as np  # class-level import so it ships with the serialized UDF
26 |
27 | def __init__(self):
28 | super().__init__()
29 |
30 | self.onnx_session = None
31 |
32 | def dependencies(self) -> list:
33 | return []
34 |
35 | @classmethod
36 | @functools.lru_cache(maxsize=6)
37 | def load_and_prepare_model(cls, model_url: str):
38 | """Method to be used instead the default GFMap load_ort_model function.
39 | Loads the model, validates it and extracts LUT from the model metadata.
40 |
41 |
42 | Parameters
43 | ----------
44 | model_url : str
45 | Public URL to the ONNX classification model.
46 | """
47 | # Load the model
48 | response = requests.get(model_url, timeout=120)
49 | model = ort.InferenceSession(response.content)
50 |
51 | # Validate the model
52 | metadata = model.get_modelmeta().custom_metadata_map
53 |
54 | if "class_params" not in metadata:
55 | raise ValueError("Could not find class names in the model metadata.")
56 |
57 | class_params = eval(metadata["class_params"], {"__builtins__": None}, {})
58 |
59 | if "class_names" not in class_params:
60 | raise ValueError("Could not find class names in the model metadata.")
61 |
62 | if "class_to_label" not in class_params:
63 | raise ValueError("Could not find class to labels in the model metadata.")
64 |
65 | # Load model LUT
66 | lut = dict(zip(class_params["class_names"], class_params["class_to_label"]))
67 | sorted_lut = {k: v for k, v in sorted(lut.items(), key=lambda item: item[1])}
68 |
69 | return model, sorted_lut
70 |
71 | def output_labels(self) -> list:
72 | """
73 | Returns the output labels for the classification.
74 |
75 | LUT needs to be explicitly sorted here as openEO does
76 | not guarantee the order of a json object being preserved when decoding
77 | a process graph in the backend.
78 | """
79 | _, lut = self.load_and_prepare_model(self._parameters["classifier_url"])
80 | lut_sorted = {k: v for k, v in sorted(lut.items(), key=lambda item: item[1])}
81 | class_names = lut_sorted.keys()
82 |
83 | return ["classification", "probability"] + [
84 | f"probability_{name}" for name in class_names
85 | ]
86 |
87 | def predict(self, features: np.ndarray) -> np.ndarray:
88 | """
89 | Predicts labels using the provided features array.
90 |
91 | LUT needs to be explicitly sorted here as openEO does
92 | not guarantee the order of a json object being preserved when decoding
93 | a process graph in the backend.
94 | """
95 | import numpy as np
96 |
97 | # Classes names to codes
98 | _, lut_sorted = self.load_and_prepare_model(self._parameters["classifier_url"])
99 |
100 | if lut_sorted is None:
101 | raise ValueError(
102 | "Lookup table is not defined. Please provide lookup_table in the model metadata."
103 | )
104 |
105 | if self.onnx_session is None:
106 | raise ValueError("Model has not been loaded. Please load a model first.")
107 |
108 | # Prepare input data for ONNX model
109 | outputs = self.onnx_session.run(None, {"features": features})
110 |
111 | # Extract classes as INTs and probability of winning class values
112 | labels = np.zeros((len(outputs[0]),), dtype=np.uint16)
113 | probabilities = np.zeros((len(outputs[0]),), dtype=np.uint8)
114 | for i, (label, prob) in enumerate(zip(outputs[0], outputs[1])):
115 | labels[i] = lut_sorted[label]
116 | probabilities[i] = int(round(prob[label] * 100))
117 |
118 | # Extract per class probabilities
119 | output_probabilities = []
120 | for output_px in outputs[1]:
121 | output_probabilities.append(
122 | [output_px[label] for label in lut_sorted.keys()]
123 | )
124 |
125 | output_probabilities = (
126 | (np.array(output_probabilities) * 100).round().astype(np.uint8)
127 | )
128 |
129 | return np.hstack(
130 | [labels[:, np.newaxis], probabilities[:, np.newaxis], output_probabilities]
131 | ).transpose()
132 |
133 | def execute(self, inarr: xr.DataArray) -> xr.DataArray:
134 |
135 | if "classifier_url" not in self._parameters:
136 | raise ValueError('Missing required parameter "classifier_url"')
137 | classifier_url = self._parameters.get("classifier_url")
138 | self.logger.info(f'Loading classifier model from "{classifier_url}"')
139 |
140 | # shape and indices for output ("xy", "bands")
141 | x_coords, y_coords = inarr.x.values, inarr.y.values
142 | inarr = inarr.transpose("bands", "x", "y").stack(xy=["x", "y"]).transpose()
143 |
144 | self.onnx_session, self._parameters["lookup_table"] = (
145 | self.load_and_prepare_model(classifier_url)
146 | )
147 |
148 | # Run catboost classification
149 | self.logger.info("Catboost classification with input shape: %s", inarr.shape)
150 | classification = self.predict(inarr.values)
151 | self.logger.info("Classification done with shape: %s", inarr.shape)
152 |
153 | output_labels = self.output_labels()
154 |
155 | classification_da = xr.DataArray(
156 | classification.reshape((len(output_labels), len(x_coords), len(y_coords))),
157 | dims=["bands", "x", "y"],
158 | coords={
159 | "bands": output_labels,
160 | "x": x_coords,
161 | "y": y_coords,
162 | },
163 | )
164 |
165 | return classification_da
166 |
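A minimal illustration (not from the source) of the LUT sorting that load_and_prepare_model and output_labels rely on, with hypothetical class names:

    lut = {"maize": 2, "barley": 1, "other": 0}  # as stored in class_params
    sorted_lut = {k: v for k, v in sorted(lut.items(), key=lambda item: item[1])}
    print(sorted_lut)  # {'other': 0, 'barley': 1, 'maize': 2}
    print(["classification", "probability"] + [f"probability_{n}" for n in sorted_lut])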
--------------------------------------------------------------------------------
/src/worldcereal/openeo/mapping.py:
--------------------------------------------------------------------------------
1 | """Private methods for cropland/croptype mapping. The public functions that
2 | interact with the methods here are defined in the `worldcereal.job`
3 | sub-module.
4 | """
5 |
6 | from typing import Optional
7 |
8 | from openeo import DataCube
9 | from openeo_gfmap import TemporalContext
10 | from openeo_gfmap.features.feature_extractor import apply_feature_extractor
11 | from openeo_gfmap.inference.model_inference import apply_model_inference
12 | from openeo_gfmap.preprocessing.scaling import compress_uint8, compress_uint16
13 |
14 | from worldcereal.parameters import (
15 | CropLandParameters,
16 | CropTypeParameters,
17 | PostprocessParameters,
18 | WorldCerealProductType,
19 | )
20 |
21 |
22 | def _cropland_map(
23 | inputs: DataCube,
24 | temporal_extent: TemporalContext,
25 | cropland_parameters: CropLandParameters,
26 | postprocess_parameters: PostprocessParameters,
27 | ) -> DataCube:
28 | """Method to produce cropland map from preprocessed inputs, using
29 | a Presto feature extractor and a CatBoost classifier.
30 |
31 | Parameters
32 | ----------
33 | inputs : DataCube
34 | preprocessed input cube
35 | temporal_extent : TemporalContext
36 | temporal extent of the input cube
37 | cropland_parameters: CropLandParameters
38 | Parameters for the cropland product inference pipeline
39 | postprocess_parameters: PostprocessParameters
40 | Parameters for the postprocessing
41 | Returns
42 | -------
43 | DataCube
44 | binary labels and probability
45 | """
46 |
47 | # Run feature computer
48 | features = apply_feature_extractor(
49 | feature_extractor_class=cropland_parameters.feature_extractor,
50 | cube=inputs,
51 | parameters=cropland_parameters.feature_parameters.model_dump(),
52 | size=[
53 | {"dimension": "x", "unit": "px", "value": 128},
54 | {"dimension": "y", "unit": "px", "value": 128},
55 | ],
56 | overlap=[
57 | {"dimension": "x", "unit": "px", "value": 0},
58 | {"dimension": "y", "unit": "px", "value": 0},
59 | ],
60 | )
61 |
62 | # Run model inference on features
63 | parameters = cropland_parameters.classifier_parameters.model_dump(
64 | exclude=["classifier"]
65 | )
66 |
67 | classes = apply_model_inference(
68 | model_inference_class=cropland_parameters.classifier,
69 | cube=features,
70 | parameters=parameters,
71 | size=[
72 | {"dimension": "x", "unit": "px", "value": 128},
73 | {"dimension": "y", "unit": "px", "value": 128},
74 | {"dimension": "t", "value": "P1D"},
75 | ],
76 | overlap=[
77 | {"dimension": "x", "unit": "px", "value": 0},
78 | {"dimension": "y", "unit": "px", "value": 0},
79 | ],
80 | )
81 |
82 | # Get rid of temporal dimension
83 | classes = classes.reduce_dimension(dimension="t", reducer="mean")
84 |
85 | # Postprocess
86 | if postprocess_parameters.enable:
87 | if postprocess_parameters.save_intermediate:
88 | classes = classes.save_result(
89 | format="GTiff",
90 | options=dict(
91 | filename_prefix=f"{WorldCerealProductType.CROPLAND.value}-raw_{temporal_extent.start_date}_{temporal_extent.end_date}"
92 | ),
93 | )
94 | classes = _postprocess(
95 | classes,
96 | postprocess_parameters,
97 | cropland_parameters.classifier_parameters.classifier_url,
98 | )
99 |
100 | # Cast to uint8
101 | classes = compress_uint8(classes)
102 |
103 | return classes
104 |
105 |
106 | def _croptype_map(
107 | inputs: DataCube,
108 | temporal_extent: TemporalContext,
109 | croptype_parameters: "CropTypeParameters",
110 | postprocess_parameters: "PostprocessParameters",
111 | cropland_mask: Optional[DataCube] = None,
112 | ) -> DataCube:
113 | """Method to produce croptype map from preprocessed inputs, using
114 | a Presto feature extractor and a CatBoost classifier.
115 |
116 | Parameters
117 | ----------
118 | inputs : DataCube
119 | preprocessed input cube
120 | temporal_extent : TemporalContext
121 | temporal extent of the input cube
122 | croptype_parameters : CropTypeParameters
123 | Parameters for the croptype product inference pipeline
124 | postprocess_parameters : PostprocessParameters
125 | Parameters for the postprocessing; `cropland_mask` is an optional cropland mask (default None).
126 | Returns
127 | -------
128 | DataCube
129 | croptype labels and probability
130 | """
131 |
132 | # Run feature computer
133 | features = apply_feature_extractor(
134 | feature_extractor_class=croptype_parameters.feature_extractor,
135 | cube=inputs,
136 | parameters=croptype_parameters.feature_parameters.model_dump(),
137 | size=[
138 | {"dimension": "x", "unit": "px", "value": 128},
139 | {"dimension": "y", "unit": "px", "value": 128},
140 | ],
141 | overlap=[
142 | {"dimension": "x", "unit": "px", "value": 0},
143 | {"dimension": "y", "unit": "px", "value": 0},
144 | ],
145 | )
146 |
147 | # Run model inference on features
148 | parameters = croptype_parameters.classifier_parameters.model_dump(
149 | exclude=["classifier"]
150 | )
151 |
152 | classes = apply_model_inference(
153 | model_inference_class=croptype_parameters.classifier,
154 | cube=features,
155 | parameters=parameters,
156 | size=[
157 | {"dimension": "x", "unit": "px", "value": 128},
158 | {"dimension": "y", "unit": "px", "value": 128},
159 | {"dimension": "t", "value": "P1D"},
160 | ],
161 | overlap=[
162 | {"dimension": "x", "unit": "px", "value": 0},
163 | {"dimension": "y", "unit": "px", "value": 0},
164 | ],
165 | )
166 |
167 | # Get rid of temporal dimension
168 | classes = classes.reduce_dimension(dimension="t", reducer="mean")
169 |
170 | # Mask cropland
171 | if cropland_mask is not None:
172 | classes = classes.mask(cropland_mask == 0, replacement=254)
173 |
174 | # Postprocess
175 | if postprocess_parameters.enable:
176 | if postprocess_parameters.save_intermediate:
177 | classes = classes.save_result(
178 | format="GTiff",
179 | options=dict(
180 | filename_prefix=f"{WorldCerealProductType.CROPTYPE.value}-raw_{temporal_extent.start_date}_{temporal_extent.end_date}"
181 | ),
182 | )
183 | classes = _postprocess(
184 | classes,
185 | postprocess_parameters,
186 | classifier_url=croptype_parameters.classifier_parameters.classifier_url,
187 | )
188 |
189 | # Cast to uint16
190 | classes = compress_uint16(classes)
191 |
192 | return classes
193 |
194 |
195 | def _postprocess(
196 | classes: DataCube,
197 | postprocess_parameters: "PostprocessParameters",
198 | classifier_url: Optional[str] = None,
199 | ) -> DataCube:
200 | """Method to postprocess the classes.
201 |
202 | Parameters
203 | ----------
204 | classes : DataCube
205 | classes to postprocess
206 | postprocess_parameters : PostprocessParameters
207 | parameter class for postprocessing
208 | classifier_url : str, optional
209 | URL of the classifier model, forwarded to the postprocessor parameters.
210 | Returns
211 | -------
212 | DataCube
213 | postprocessed classes
214 | """
215 |
216 | # Run postprocessing on the raw classification output
217 | # Note that this uses the `apply_model_inference` method even though
218 | # it is not truly model inference
219 | parameters = postprocess_parameters.model_dump(exclude=["postprocessor"])
220 | parameters.update({"classifier_url": classifier_url})
221 |
222 | postprocessed_classes = apply_model_inference(
223 | model_inference_class=postprocess_parameters.postprocessor,
224 | cube=classes,
225 | parameters=parameters,
226 | size=[
227 | {"dimension": "x", "unit": "px", "value": 128},
228 | {"dimension": "y", "unit": "px", "value": 128},
229 | ],
230 | overlap=[
231 | {"dimension": "x", "unit": "px", "value": 0},
232 | {"dimension": "y", "unit": "px", "value": 0},
233 | ],
234 | )
235 |
236 | return postprocessed_classes
237 |
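A hedged sketch of invoking the private pipeline, assuming `inputs` is a preprocessed DataCube and that the parameter classes from worldcereal.parameters can be default-constructed:

    from openeo_gfmap import TemporalContext
    from worldcereal.parameters import CropLandParameters, PostprocessParameters

    temporal_extent = TemporalContext("2021-01-01", "2021-12-31")
    classes = _cropland_map(
        inputs,
        temporal_extent,
        cropland_parameters=CropLandParameters(),
        postprocess_parameters=PostprocessParameters(),
    )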
--------------------------------------------------------------------------------
/src/worldcereal/openeo/masking.py:
--------------------------------------------------------------------------------
1 | from skimage.morphology import footprints
2 |
3 |
4 | def convolve(img, radius):
5 | """OpenEO method to apply convolution
6 | with a circular kernel of `radius` pixels.
7 | NOTE: make sure the resolution of the image
8 | matches the expected radius in pixels!
9 | """
10 | kernel = footprints.disk(radius)
11 | img = img.apply_kernel(kernel)
12 | return img
13 |
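For reference, a radius-r disk footprint is a (2r+1) x (2r+1) array, so the kernel size depends directly on the pixel resolution of the input:

    from skimage.morphology import footprints

    kernel = footprints.disk(3)
    print(kernel.shape)  # (7, 7)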
14 |
15 | def scl_mask_erode_dilate(
16 | session,
17 | bbox,
18 | scl_layer_band="TERRASCOPE_S2_TOC_V2:SCL",
19 | erode_r=3,
20 | dilate_r=21,
21 | target_crs=None,
22 | ):
23 | """OpenEO method to construct a Sentinel-2 mask based on SCL.
24 | It involves an erosion step followed by a dilation step.
25 |
26 | Args:
27 | session (openeo.Connection): the openeo connection to use
28 | scl_layer_band (str, optional): Which SCL band to use.
29 | Defaults to "TERRASCOPE_S2_TOC_V2:SCL".
30 | erode_r (int, optional): Erosion radius (pixels). Defaults to 3.
31 | dilate_r (int, optional): Dilation radius (pixels). Defaults to 21.
32 |
33 | Returns:
34 | DataCube: DataCube containing the resulting mask
35 | """
36 |
37 | layer_band = scl_layer_band.split(":")
38 | s2_sceneclassification = session.load_collection(
39 | layer_band[0], bands=[layer_band[1]], spatial_extent=bbox
40 | )
41 |
42 | classification = s2_sceneclassification.band(layer_band[1])
43 |
44 | # Force to go to 10m resolution for controlled erosion/dilation
45 | classification = classification.resample_spatial(
46 | projection=target_crs, resolution=10.0
47 | )
48 |
49 | first_mask = classification == 0
50 | for mask_value in [1, 3, 8, 9, 10, 11]:
51 | first_mask = (first_mask == 1) | (classification == mask_value)
52 |
53 | # Invert mask for erosion
54 | first_mask = first_mask.apply(lambda x: (x == 1).not_())
55 |
56 | # Apply erosion by dilating the inverted mask
57 | erode_cube = convolve(first_mask, erode_r)
58 |
59 | # Invert again
60 | erode_cube = erode_cube > 0.1
61 | erode_cube = erode_cube.apply(lambda x: (x == 1).not_())
62 |
63 | # Now dilate the mask
64 | dilate_cube = convolve(erode_cube, dilate_r)
65 |
66 | # Get binary mask. NOTE: >0.1 is a fix to avoid being triggered
67 | # by small non-zero oscillations after applying convolution
68 | dilate_cube = dilate_cube > 0.1
69 |
70 | return dilate_cube
71 |
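A hedged usage sketch; the backend URL and the extent are assumptions:

    import openeo

    session = openeo.connect("openeo.vito.be").authenticate_oidc()
    bbox = {"west": 5.0, "south": 51.0, "east": 5.1, "north": 51.1, "crs": "EPSG:4326"}
    mask = scl_mask_erode_dilate(session, bbox, target_crs=32631)
    # `mask` is a binary DataCube: 1 where pixels remain flagged as
    # cloud/shadow after the erode/dilate sequence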
--------------------------------------------------------------------------------
/src/worldcereal/openeo/udf_distance_to_cloud.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import xarray as xr
3 | from openeo.udf import XarrayDataCube
4 | from scipy.ndimage import distance_transform_cdt
5 | from skimage.morphology import binary_erosion, footprints
6 |
7 |
8 | def apply_datacube(cube: XarrayDataCube, context: dict) -> XarrayDataCube:
9 | cube_array: xr.DataArray = cube.get_array()
10 | cube_array = cube_array.transpose("bands", "y", "x")
11 |
12 | clouds: xr.DataArray = np.logical_or(
13 | np.logical_and(cube_array < 11, cube_array >= 8), cube_array == 3  # SCL 8-10 (clouds/cirrus) or 3 (cloud shadow)
14 | ).isel(
15 | bands=0
16 | ) # type: ignore
17 |
18 | # Calculate the Distance To Cloud score
19 | # Erode
20 | er = footprints.disk(3)
21 |
22 | # Define a function to apply binary erosion
23 | def erode(image, selem):
24 | return ~binary_erosion(image, selem)
25 |
26 | # Use apply_ufunc to apply the erosion operation
27 | eroded = xr.apply_ufunc(
28 | erode, # function to apply
29 | clouds, # input DataArray
30 | input_core_dims=[["y", "x"]], # dimensions over which to apply function
31 | output_core_dims=[["y", "x"]], # dimensions of the output
32 | vectorize=True, # vectorize the function over non-core dimensions
33 | dask="parallelized", # enable dask parallelization
34 | output_dtypes=[np.int32], # data type of the output
35 | kwargs={"selem": er}, # additional keyword arguments to pass to erode
36 | )
37 |
38 | # Distance to cloud in manhattan distance measure
39 | distance = xr.apply_ufunc(
40 | distance_transform_cdt,
41 | eroded,
42 | input_core_dims=[["y", "x"]],
43 | output_core_dims=[["y", "x"]],
44 | vectorize=True,
45 | dask="parallelized",
46 | output_dtypes=[np.int32],
47 | )
48 |
49 | distance_da = xr.DataArray(
50 | distance,
51 | coords={
52 | "y": cube_array.coords["y"],
53 | "x": cube_array.coords["x"],
54 | },
55 | dims=["y", "x"],
56 | )
57 |
58 | distance_da = distance_da.expand_dims(
59 | dim={
60 | "bands": cube_array.coords["bands"],
61 | },
62 | )
63 |
64 | distance_da = distance_da.transpose("bands", "y", "x")
65 |
66 | return XarrayDataCube(distance_da)
67 |
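The UDF can be exercised locally on a tiny synthetic SCL patch (a sketch, not from the source):

    import numpy as np
    import xarray as xr
    from openeo.udf import XarrayDataCube

    scl = xr.DataArray(
        np.zeros((1, 8, 8), dtype=np.uint8),
        dims=["bands", "y", "x"],
        coords={"bands": ["SCL"], "y": np.arange(8), "x": np.arange(8)},
    )
    scl[0, 3:5, 3:5] = 9  # SCL 9: cloud, high probability
    out = apply_datacube(XarrayDataCube(scl), context={})
    print(out.get_array().dims)  # ('bands', 'y', 'x')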
--------------------------------------------------------------------------------
/src/worldcereal/rdm_api/__init__.py:
--------------------------------------------------------------------------------
1 | """This sub-module contains utilitary function and tools for worldcereal-classification"""
2 |
3 | from .rdm_collection import RdmCollection
4 | from .rdm_interaction import RdmInteraction
5 |
6 | __all__ = [
7 | "RdmInteraction",
8 | "RdmCollection",
9 | ]
10 |
--------------------------------------------------------------------------------
/src/worldcereal/rdm_api/rdm_collection.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 |
4 | class RdmCollection:
5 | """Data class to host collections queried from the RDM API."""
6 |
7 | def __init__(self, **metadata):
8 | """Initializes the RdmCollection object with metadata from the RDM API.
9 | Collection metadata is passed as keyword arguments and stored as attributes.
10 | All default collection metadata items are expected.
11 |
12 | RdmCollection can be initialized directly from the RDM API response resulting from
13 | a query to the collections endpoint (see get_collections function in rdm_interaction.py).
14 | """
15 | # check presence of mandatory metadata items
16 | if not metadata.get("collectionId"):
17 | raise ValueError(
18 | "Collection ID is missing, cannot create RdmCollection object."
19 | )
20 | if not metadata.get("accessType"):
21 | raise ValueError(
22 | "Access type is missing, cannot create RdmCollection object."
23 | )
24 |
25 | # Get all metadata items
26 | self.id = metadata.get("collectionId")
27 | self.title = metadata.get("title")
28 | self.feature_count = metadata.get("featureCount")
29 | self.data_type = metadata.get("type")
30 | self.access_type = metadata.get("accessType")
31 | self.observation_method = metadata.get("typeOfObservationMethod")
32 | self.confidence_lc = metadata.get("confidenceLandCover")
33 | self.confidence_ct = metadata.get("confidenceCropType")
34 | self.confidence_irr = metadata.get("confidenceIrrigationType")
35 | self.ewoc_codes = metadata.get("ewocCodes")
36 | self.irr_codes = metadata.get("irrTypes")
37 | self.extent = metadata.get("extent")
38 | if self.extent:
39 | self.spatial_extent = self.extent["spatial"]
40 | self.temporal_extent = self.extent["temporal"]["interval"][0]
41 | else:
42 | self.spatial_extent = None
43 | self.temporal_extent = None
44 | self.additional_data = metadata.get("additionalData")
45 | self.crs = metadata.get("crs")
46 | self.last_modified = metadata.get("lastModificationTime")
47 | self.last_modified_by = metadata.get("lastModifierId")
48 | self.creation_time = metadata.get("creationTime")
49 | self.created_by = metadata.get("creatorId")
50 | self.fid = metadata.get("id")
51 |
52 | def print_metadata(self):
53 |
54 | print("#######################")
55 | print("Collection Metadata:")
56 | print(f"ID: {self.id}")
57 | print(f"Title: {self.title}")
58 | print(f"Number of samples: {self.feature_count}")
59 | print(f"Data type: {self.data_type}")
60 | print(f"Access type: {self.access_type}")
61 | print(f"Observation method: {self.observation_method}")
62 | print(f"Confidence score for land cover: {self.confidence_lc}")
63 | print(f"Confidence score for crop type: {self.confidence_ct}")
64 | print(f"Confidence score for irrigation label: {self.confidence_irr}")
65 | print(f"List of available crop types: {self.ewoc_codes}")
66 | print(f"List of available irrigation labels: {self.irr_codes}")
67 | print(f"Spatial extent: {self.spatial_extent}")
68 | print(f"Coordinate reference system (CRS): {self.crs}")
69 | print(f"Temporal extent: {self.temporal_extent}")
70 | print(f"Additional data: {self.additional_data}")
71 | print(f"Last modified: {self.last_modified}")
72 | print(f"Last modified by: {self.last_modified_by}")
73 | print(f"Creation time: {self.creation_time}")
74 | print(f"Created by: {self.created_by}")
75 | print(f"fid: {self.fid}")
76 |
77 |
78 | def visualize_spatial_extents(collections: List[RdmCollection]):
79 | """Visualizes the spatial extent of multiple collections on a map."""
80 |
81 | from ipyleaflet import Map, Rectangle, basemaps
82 |
83 | if len(collections) == 1:
84 | zoom = 5
85 | colbbox = collections[0].spatial_extent.get("bbox", None)
86 | if colbbox is None:
87 | raise ValueError(
88 | f"No bounding box found for collection {collections[0].id}."
89 | )
90 | colbbox = colbbox[0]
91 | # compute the center of the bounding box
92 | center = [(colbbox[1] + colbbox[3]) / 2, (colbbox[0] + colbbox[2]) / 2]
93 | else:
94 | zoom = 1
95 | center = [0, 0]
96 |
97 | # Create the basemap
98 | m = Map(
99 | basemap=basemaps.CartoDB.Positron,
100 | zoom=zoom,
101 | center=center,
102 | scroll_wheel_zoom=True,
103 | )
104 |
105 | # Get the extent of each collection
106 | for col in collections:
107 | colbbox = col.spatial_extent.get("bbox", None)
108 | if colbbox is None:
109 | raise ValueError(f"No bounding box found for collection {col.id}.")
110 | colbbox = colbbox[0]
111 | bbox = [[colbbox[1], colbbox[0]], [colbbox[3], colbbox[2]]]
112 |
113 | # create a rectangle from the bounding box
114 | rectangle = Rectangle(bounds=bbox, color="green", weight=2, fill_opacity=0.1)
115 |
116 | # Add the rectangle to the map
117 | m.add_layer(rectangle)
118 |
119 | return m
120 |
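Usage sketch with hypothetical metadata; only collectionId and accessType are mandatory:

    col = RdmCollection(
        collectionId="demo_collection",  # hypothetical
        accessType="Public",
        title="Demo collection",
        extent={
            "spatial": {"bbox": [[4.0, 50.5, 6.0, 51.5]]},
            "temporal": {"interval": [["2021-01-01", "2021-12-31"]]},
        },
    )
    col.print_metadata()
    m = visualize_spatial_extents([col])  # ipyleaflet Map; display in a notebook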
--------------------------------------------------------------------------------
/src/worldcereal/stac/__init__.py:
--------------------------------------------------------------------------------
1 | """STAC constants and utilities relative to WorldCereal intermediary and deliverable products"""
2 |
--------------------------------------------------------------------------------
/src/worldcereal/stac/stac_api_interaction.py:
--------------------------------------------------------------------------------
1 | import concurrent
2 | from concurrent.futures import ThreadPoolExecutor
3 | from typing import Iterable
4 |
5 | import pystac
6 | import pystac_client
7 | import requests
8 | from openeo.rest.auth.oidc import (
9 | OidcClientInfo,
10 | OidcProviderInfo,
11 | OidcResourceOwnerPasswordAuthenticator,
12 | )
13 | from requests.auth import AuthBase
14 |
15 |
16 | class VitoStacApiAuthentication(AuthBase):
17 | """Class that handles authentication for the VITO STAC API. https://stac.openeo.vito.be/"""
18 |
19 | def __init__(self, **kwargs):
20 | self.username = kwargs.get("username")
21 | self.password = kwargs.get("password")
22 |
23 | def __call__(self, request):
24 | request.headers["Authorization"] = self.get_access_token()
25 | return request
26 |
27 | def get_access_token(self) -> str:
28 | """Get API bearer access token via password flow.
29 |
30 | Returns
31 | -------
32 | str
33 | A string containing the bearer access token.
34 | """
35 | provider_info = OidcProviderInfo(
36 | issuer="https://sso.terrascope.be/auth/realms/terrascope"
37 | )
38 |
39 | client_info = OidcClientInfo(
40 | client_id="terracatalogueclient",
41 | provider=provider_info,
42 | )
43 |
44 | if self.username and self.password:
45 | authenticator = OidcResourceOwnerPasswordAuthenticator(
46 | client_info=client_info, username=self.username, password=self.password
47 | )
48 | else:
49 | raise ValueError(
50 | "Credentials are required to obtain an access token. Please set STAC_API_USERNAME and STAC_API_PASSWORD environment variables."
51 | )
52 |
53 | tokens = authenticator.get_tokens()
54 |
55 | return f"Bearer {tokens.access_token}"
56 |
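A sketch of constructing the authentication object from the environment variables named in get_access_token above:

    import os

    auth = VitoStacApiAuthentication(
        username=os.getenv("STAC_API_USERNAME"),
        password=os.getenv("STAC_API_PASSWORD"),
    )
    # Passing `auth` to requests adds the Bearer token to each request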
57 |
58 | class StacApiInteraction:
59 | """Class that handles the interaction with a STAC API."""
60 |
61 | _SENSOR_COLLECTION_CATALOG = {
62 | "Sentinel1": "worldcereal_sentinel_1_patch_extractions",
63 | "Sentinel2": "worldcereal_sentinel_2_patch_extractions",
64 | }
65 |
66 | def __init__(
67 | self, sensor: str, base_url: str, auth: AuthBase, bulk_size: int = 500
68 | ):
69 | if sensor not in self.catalog.keys():
70 | raise ValueError(
71 | f"Invalid sensor '{sensor}'. Allowed values are: {', '.join(self.catalog.keys())}."
72 | )
73 | self.sensor = sensor
74 | self.base_url = base_url
75 | self.collection_id = self.catalog[self.sensor]
76 |
77 | self.auth = auth
78 |
79 | self.bulk_size = bulk_size
80 |
81 | @property
82 | def catalog(self):
83 | return self._SENSOR_COLLECTION_CATALOG.copy()
84 |
85 | def exists(self) -> bool:
86 | client = pystac_client.Client.open(self.base_url)
87 | return (
88 | len([c.id for c in client.get_collections() if c.id == self.collection_id])
89 | > 0
90 | )
91 |
92 | def _join_url(self, url_path: str) -> str:
93 | return str(self.base_url + "/" + url_path)
94 |
95 | def create_collection(self):
96 | spatial_extent = pystac.SpatialExtent([[-180, -90, 180, 90]])
97 | temporal_extent = pystac.TemporalExtent([[None, None]])
98 | extent = pystac.Extent(spatial=spatial_extent, temporal=temporal_extent)
99 |
100 | collection = pystac.Collection(
101 | id=self.collection_id,
102 | description=f"WorldCereal Patch Extractions for {self.sensor}",
103 | extent=extent,
104 | )
105 |
106 | collection.validate()
107 | coll_dict = collection.to_dict()
108 |
109 | default_auth = {
110 | "_auth": {
111 | "read": ["anonymous"],
112 | "write": ["stac-openeo-admin", "stac-openeo-editor"],
113 | }
114 | }
115 |
116 | coll_dict.update(default_auth)
117 |
118 | response = requests.post(
119 | self._join_url("collections"), auth=self.auth, json=coll_dict
120 | )
121 |
122 | expected_status = [
123 | requests.status_codes.codes.ok,
124 | requests.status_codes.codes.created,
125 | requests.status_codes.codes.accepted,
126 | ]
127 |
128 | self._check_response_status(response, expected_status)
129 |
130 | return response
131 |
132 | def add_item(self, item: pystac.Item):
133 | if not self.exists():
134 | self.create_collection()
135 |
136 | self._prepare_item(item)
137 |
138 | url_path = f"collections/{self.collection_id}/items"
139 | response = requests.post(
140 | self._join_url(url_path), auth=self.auth, json=item.to_dict()
141 | )
142 |
143 | expected_status = [
144 | requests.status_codes.codes.ok,
145 | requests.status_codes.codes.created,
146 | requests.status_codes.codes.accepted,
147 | ]
148 |
149 | self._check_response_status(response, expected_status)
150 |
151 | return response
152 |
153 | def _prepare_item(self, item: pystac.Item):
154 | item.collection_id = self.collection_id
155 | if not item.get_links(pystac.RelType.COLLECTION):
156 | item.add_link(
157 | pystac.Link(rel=pystac.RelType.COLLECTION, target=item.collection_id)
158 | )
159 |
160 | def _ingest_bulk(self, items: Iterable[pystac.Item]) -> dict:
161 | if not all(i.collection_id == self.collection_id for i in items):
162 | raise Exception("All collection IDs should be identical for bulk ingests")
163 |
164 | url_path = f"collections/{self.collection_id}/bulk_items"
165 | data = {
166 | "method": "upsert",
167 | "items": {item.id: item.to_dict() for item in items},
168 | }
169 | response = requests.post(
170 | url=self._join_url(url_path), auth=self.auth, json=data
171 | )
172 |
173 | expected_status = [
174 | requests.status_codes.codes.ok,
175 | requests.status_codes.codes.created,
176 | requests.status_codes.codes.accepted,
177 | ]
178 |
179 | self._check_response_status(response, expected_status)
180 | return response.json()
181 |
182 | def upload_items_bulk(self, items: Iterable[pystac.Item]) -> None:
183 | if not self.exists():
184 | self.create_collection()
185 |
186 | chunk = []
187 | futures = []
188 |
189 | with ThreadPoolExecutor(max_workers=4) as executor:
190 | for item in items:
191 | self._prepare_item(item)
192 | chunk.append(item)
193 |
194 | if len(chunk) == self.bulk_size:
195 | futures.append(executor.submit(self._ingest_bulk, chunk.copy()))
196 | chunk = []
197 |
198 | if chunk:
199 | self._ingest_bulk(chunk)
200 |
201 | for future in concurrent.futures.as_completed(futures):
202 | future.result()  # surface any ingestion error
203 |
204 | def _check_response_status(
205 | self, response: requests.Response, expected_status_codes: list[int]
206 | ):
207 | if response.status_code not in expected_status_codes:
208 | message = (
209 | f"Expecting HTTP status to be any of {expected_status_codes} "
210 | + f"but received {response.status_code} - {response.reason}, request method={response.request.method}\n"
211 | + f"response body:\n{response.text}"
212 | )
213 |
214 | raise Exception(message)
215 |
216 | def get_collection_id(self) -> str:
217 | return self.collection_id
218 |
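Usage sketch; the base URL comes from the VitoStacApiAuthentication docstring above, and `items` is assumed to be an iterable of pystac Items:

    interaction = StacApiInteraction(
        sensor="Sentinel1",
        base_url="https://stac.openeo.vito.be",
        auth=auth,
    )
    interaction.upload_items_bulk(items)  # creates the collection if needed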
--------------------------------------------------------------------------------
/src/worldcereal/utils/geoloader.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import geopandas as gpd
4 | import numpy as np
5 | import rasterio
6 | from loguru import logger
7 | from rasterio.crs import CRS
8 | from rasterio.enums import Resampling
9 | from rasterio.warp import reproject
10 | from shapely.geometry import Polygon
11 |
12 |
13 | def _load_array_bounds_latlon(
14 | fname,
15 | bounds,
16 | rio_gdal_options=None,
17 | boundless=True,
18 | fill_value=np.nan,
19 | nodata_value=None,
20 | ):
21 | rio_gdal_options = rio_gdal_options or {}
22 |
23 | with rasterio.Env(**rio_gdal_options):
24 | with rasterio.open(fname) as src:
25 | window = rasterio.windows.from_bounds(*bounds, src.transform)
26 |
27 | vals = np.array(window.flatten())
28 | if (vals % 1 > 0).any():
29 | col_off, row_off, width, height = (
30 | window.col_off,
31 | window.row_off,
32 | window.width,
33 | window.height,
34 | )
35 |
36 | width, height = int(math.ceil(width)), int(math.ceil(height))
37 | col_off, row_off = int(math.floor(col_off + 0.5)), int(
38 | math.floor(row_off + 0.5)
39 | )
40 |
41 | window = rasterio.windows.Window(col_off, row_off, width, height)
42 |
43 | if nodata_value is None:
44 | nodata_value = src.nodata  # may remain None when the raster has no nodata
45 |
46 | if nodata_value is None and boundless is True:
47 | logger.warning(
48 | "Raster has no data value, defaulting boundless"
49 | " to False. Specify a nodata_value to read "
50 | "boundless."
51 | )
52 | boundless = False
53 |
54 | arr = src.read(window=window, boundless=boundless, fill_value=nodata_value)
55 | arr = arr.astype(
56 | np.float32
57 | )  # needed for reprojecting with bilinear resampling  # noqa:e501
58 |
59 | if nodata_value is not None:
60 | arr[arr == nodata_value] = fill_value
61 |
62 | arr = arr[0]
63 | return arr
64 |
65 |
66 | def load_reproject(
67 | filename,
68 | bounds,
69 | epsg,
70 | resolution=10,
71 | border_buff=0,
72 | fill_value=0,
73 | nodata_value=None,
74 | rio_gdal_options=None,
75 | resampling=Resampling.nearest,
76 | ):
77 | """
78 | Read from latlon layer and reproject to UTM
79 | """
80 | bbox = gpd.GeoSeries(Polygon.from_bounds(*bounds), crs=CRS.from_epsg(epsg))
81 |
82 | bounds = bbox.buffer(border_buff * resolution).to_crs(epsg=4326).bounds.values[0]
83 | utm_bounds = bbox.buffer(border_buff * resolution).bounds.values[0].tolist()
84 |
85 | width = max(1, int((utm_bounds[2] - utm_bounds[0]) / resolution))
86 | height = max(1, int((utm_bounds[3] - utm_bounds[1]) / resolution))
87 |
88 | gim = _load_array_bounds_latlon(
89 | filename,
90 | bounds,
91 | rio_gdal_options=rio_gdal_options,
92 | fill_value=fill_value,
93 | nodata_value=nodata_value,
94 | )
95 |
96 | src_crs = CRS.from_epsg(4326)
97 | dst_crs = CRS.from_epsg(bbox.crs.to_epsg())
98 |
99 | src_transform = rasterio.transform.from_bounds(*bounds, gim.shape[1], gim.shape[0])
100 | dst_transform = rasterio.transform.from_bounds(*utm_bounds, width, height)
101 |
102 | dst = np.zeros((height, width), dtype=np.float32)
103 |
104 | reproject(
105 | gim.astype(np.float32),
106 | dst,
107 | src_transform=src_transform,
108 | dst_transform=dst_transform,
109 | src_crs=src_crs,
110 | dst_crs=dst_crs,
111 | resampling=resampling,
112 | )
113 |
114 | if border_buff > 0:
115 | dst = dst[border_buff:-border_buff, border_buff:-border_buff]
116 |
117 | return dst
118 |
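A hedged usage sketch; the file path and extent are placeholders:

    from rasterio.enums import Resampling

    arr = load_reproject(
        "/data/crop_calendar_wgs84.tif",            # hypothetical lat/lon raster
        bounds=(600000, 5650000, 610000, 5660000),  # UTM bounds
        epsg=32631,
        resolution=10,
        resampling=Resampling.bilinear,
    )
    print(arr.shape)  # (1000, 1000): a 10 km x 10 km window at 10 m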
--------------------------------------------------------------------------------
/src/worldcereal/utils/models.py:
--------------------------------------------------------------------------------
1 | """Utilities around models for the WorldCereal package."""
2 |
3 | import json
4 | from functools import lru_cache
5 |
6 | import onnxruntime as ort
7 | import requests
8 |
9 |
10 | @lru_cache(maxsize=2)
11 | def load_model_onnx(model_url) -> ort.InferenceSession:
12 | """Load an ONNX model from a URL.
13 |
14 | Parameters
15 | ----------
16 | model_url: str
17 | URL to the ONNX model.
18 |
19 | Returns
20 | -------
21 | ort.InferenceSession
22 | ONNX model loaded with ONNX runtime.
23 | """
24 | # Two minutes timeout to download the model
25 | response = requests.get(model_url, timeout=120)
26 | model = response.content
27 |
28 | return ort.InferenceSession(model)
29 |
30 |
31 | def validate_cb_model(model_url: str) -> ort.InferenceSession:
32 | """Validate a catboost model by loading it and checking if the required
33 | metadata is present. Checks that the `class_names` and `class_to_label`
34 | fields are present in the `class_params` field of the custom metadata of
35 | the model. By default, the CatBoost module should include those fields
36 | when exporting a model to ONNX.
37 |
38 | Raises an exception if the model is not valid.
39 |
40 | Parameters
41 | ----------
42 | model_url : str
43 | URL to the ONNX model.
44 |
45 | Returns
46 | -------
47 | ort.InferenceSession
48 | ONNX model loaded with ONNX runtime.
49 | """
50 | model = load_model_onnx(model_url=model_url)
51 |
52 | metadata = model.get_modelmeta().custom_metadata_map
53 |
54 | if "class_params" not in metadata:
55 | raise ValueError("Could not find class names in the model metadata.")
56 |
57 | class_params = json.loads(metadata["class_params"])
58 |
59 | if "class_names" not in class_params:
60 | raise ValueError("Could not find class names in the model metadata.")
61 |
62 | if "class_to_label" not in class_params:
63 | raise ValueError("Could not find class to labels in the model metadata.")
64 |
65 | return model
66 |
67 |
68 | def load_model_lut(model_url: str) -> dict:
69 | """Load the class names to labels mapping from a CatBoost model.
70 |
71 | Parameters
72 | ----------
73 | model_url : str
74 | URL to the ONNX model.
75 |
76 | Returns
77 | -------
78 | dict
79 | Look-up table with class names and labels.
80 | """
81 | model = validate_cb_model(model_url=model_url)
82 | metadata = model.get_modelmeta().custom_metadata_map
83 | class_params = json.loads(metadata["class_params"])
84 |
85 | lut = dict(zip(class_params["class_names"], class_params["class_to_label"]))
86 | sorted_lut = {k: v for k, v in sorted(lut.items(), key=lambda item: item[1])}
87 | return sorted_lut
88 |
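Usage sketch with a placeholder model URL:

    lut = load_model_lut("https://example.com/model.onnx")  # hypothetical URL
    print(lut)  # e.g. {'other': 0, 'barley': 1, 'maize': 2}, sorted by label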
--------------------------------------------------------------------------------
/src/worldcereal/utils/retry.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import random
3 | import time
4 | from functools import partial
5 |
6 | from loguru import logger
7 |
8 |
9 | def decorator(caller):
10 | """Turns caller into a decorator.
11 | Unlike decorator module, function signature is not preserved.
12 | :param caller: caller(f, *args, **kwargs)
13 | """
14 |
15 | def decor(f):
16 | @functools.wraps(f)
17 | def wrapper(*args, **kwargs):
18 | return caller(f, *args, **kwargs)
19 |
20 | return wrapper
21 |
22 | return decor
23 |
24 |
25 | def __retry_internal(
26 | f,
27 | exceptions=Exception,
28 | tries=-1,
29 | delay=0,
30 | max_delay=None,
31 | backoff=1,
32 | jitter=0,
33 | logger=logger,
34 | ):
35 | """
36 | Executes a function and retries it if it failed.
37 | :param f: the function to execute.
38 | :param exceptions: an exception or a tuple of exceptions to catch. default: Exception.
39 | :param tries: the maximum number of attempts. default: -1 (infinite).
40 | :param delay: initial delay between attempts. default: 0.
41 | :param max_delay: the maximum value of delay. default: None (no limit).
42 | :param backoff: multiplier applied to delay between attempts. default: 1 (no backoff).
43 | :param jitter: extra seconds added to delay between attempts. default: 0.
44 | fixed if a number, random if a range tuple (min, max)
45 | :param logger: logger.warning(fmt, error, delay) will be called on failed attempts.
46 | default: retry.logger. if None, logging is disabled.
47 | :returns: the result of the f function.
48 | """
49 | _tries, _delay = tries, delay
50 | while _tries:
51 | try:
52 | return f()
53 | except exceptions as e:
54 | _tries -= 1
55 | if not _tries:
56 | raise
57 |
58 | if logger is not None:
59 | logger.warning(f'Error "{e}", retrying in {_delay} seconds...')
60 |
61 | time.sleep(_delay)
62 | _delay *= backoff
63 |
64 | if isinstance(jitter, tuple):
65 | _delay += random.uniform(*jitter)
66 | else:
67 | _delay += jitter
68 |
69 | if max_delay is not None:
70 | _delay = min(_delay, max_delay)
71 |
72 |
73 | def retry(
74 | exceptions=Exception,
75 | tries=-1,
76 | delay=0,
77 | max_delay=None,
78 | backoff=1,
79 | jitter=0,
80 | logger=logger,
81 | ):
82 | """Returns a retry decorator.
83 | :param exceptions: an exception or a tuple of exceptions to catch. default: Exception.
84 | :param tries: the maximum number of attempts. default: -1 (infinite).
85 | :param delay: initial delay between attempts. default: 0.
86 | :param max_delay: the maximum value of delay. default: None (no limit).
87 | :param backoff: multiplier applied to delay between attempts. default: 1 (no backoff).
88 | :param jitter: extra seconds added to delay between attempts. default: 0.
89 | fixed if a number, random if a range tuple (min, max)
90 | :param logger: logger.warning(fmt, error, delay) will be called on failed attempts.
91 | default: retry.logger. if None, logging is disabled.
92 | :returns: a retry decorator.
93 | """
94 |
95 | @decorator
96 | def retry_decorator(f, *fargs, **fkwargs):
97 | args = fargs if fargs else []
98 | kwargs = fkwargs if fkwargs else {}
99 | return __retry_internal(
100 | partial(f, *args, **kwargs),
101 | exceptions,
102 | tries,
103 | delay,
104 | max_delay,
105 | backoff,
106 | jitter,
107 | logger,
108 | )
109 |
110 | return retry_decorator
111 |
112 |
113 | def retry_call(
114 | f,
115 | fargs=None,
116 | fkwargs=None,
117 | exceptions=Exception,
118 | tries=-1,
119 | delay=0,
120 | max_delay=None,
121 | backoff=1,
122 | jitter=0,
123 | logger=logger,
124 | ):
125 | """
126 | Calls a function and re-executes it if it failed.
127 | :param f: the function to execute.
128 | :param fargs: the positional arguments of the function to execute.
129 | :param fkwargs: the named arguments of the function to execute.
130 | :param exceptions: an exception or a tuple of exceptions to catch. default: Exception.
131 | :param tries: the maximum number of attempts. default: -1 (infinite).
132 | :param delay: initial delay between attempts. default: 0.
133 | :param max_delay: the maximum value of delay. default: None (no limit).
134 | :param backoff: multiplier applied to delay between attempts. default: 1 (no backoff).
135 | :param jitter: extra seconds added to delay between attempts. default: 0.
136 | fixed if a number, random if a range tuple (min, max)
137 | :param logger: logger.warning is called with the error and the current delay on each failed attempt.
138 | default: the module-level loguru logger. if None, logging is disabled.
139 | :returns: the result of the f function.
140 | """
141 | args = fargs if fargs else []
142 | kwargs = fkwargs if fkwargs else {}
143 | return __retry_internal(
144 | partial(f, *args, **kwargs),
145 | exceptions,
146 | tries,
147 | delay,
148 | max_delay,
149 | backoff,
150 | jitter,
151 | logger,
152 | )
153 |
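For illustration, a minimal usage sketch of the two entry points above; fetch and fetch_raw are hypothetical functions, and the delay progression in the comments assumes delay=1, backoff=2, max_delay=4.

from worldcereal.utils.retry import retry, retry_call

@retry(exceptions=ConnectionError, tries=4, delay=1, backoff=2, max_delay=4)
def fetch():
    # With these settings the waits between attempts are 1s, 2s and 4s
    # (the third wait is capped by max_delay); after the fourth failed
    # attempt the ConnectionError is re-raised to the caller.
    raise ConnectionError("transient failure")

# Equivalent one-off invocation without decorating the function:
# retry_call(fetch_raw, fargs=["https://example.com"],
#            exceptions=ConnectionError, tries=4, delay=1, backoff=2)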
--------------------------------------------------------------------------------
/src/worldcereal/utils/spark.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | from loguru import logger
5 |
6 |
7 | def get_spark_context(
8 | name="WORLDCEREAL", localspark=False, threads="*", spark_version="3_2_0"
9 | ):
10 | """
11 | Returns a SparkContext for a local or cluster run.
12 | If localspark is True, a local SparkContext is configured for the requested spark_version.
13 |
14 | Customized for VITO MEP.
15 | """
16 | if localspark:
17 | SPARK_HOME_2_0_0 = "/usr/hdp/current/spark2-client"
18 | SPARK_HOME_3_0_0 = "/opt/spark3_0_0"
19 | SPARK_HOME_3_2_0 = "/opt/spark3_2_0"
20 |
21 | SPARK_HOME = {
22 | "2_0_0": SPARK_HOME_2_0_0,
23 | "3_0_0": SPARK_HOME_3_0_0,
24 | "3_2_0": SPARK_HOME_3_2_0,
25 | }
26 |
27 | PY4J = {
28 | "2_0_0": "py4j-0.10.7",
29 | "3_0_0": "py4j-0.10.8.1",
30 | "3_2_0": "py4j-0.10.9.2",
31 | }
32 |
33 | SPARK_MAJOR_VERSION = {"2_0_0": "2", "3_0_0": "3", "3_2_0": "3"}
34 |
35 | spark_home = SPARK_HOME[spark_version]
36 | py4j_version = PY4J[spark_version]
37 | spark_major_version = SPARK_MAJOR_VERSION[spark_version]
38 |
39 | spark_py_path = [
40 | f"{spark_home}/python",
41 | f"{spark_home}/python/lib/{py4j_version}-src.zip",
42 | ]
43 |
44 | env_vars = {
45 | "SPARK_MAJOR_VERSION": spark_major_version,
46 | "SPARK_HOME": spark_home,
47 | }
48 | for k, v in env_vars.items():
49 | logger.info(f"Setting env var: {k}={v}")
50 | os.environ[k] = v
51 |
52 | logger.info(f"Prepending {spark_py_path} to PYTHONPATH")
53 | sys.path = spark_py_path + sys.path
54 |
55 | import py4j
56 |
57 | logger.info(f"py4j: {py4j.__file__}")
58 |
59 | import pyspark
60 |
61 | logger.info(f"pyspark: {pyspark.__file__}")
62 |
63 | import cloudpickle
64 | import pyspark.serializers
65 | from pyspark import SparkConf, SparkContext
66 |
67 | pyspark.serializers.cloudpickle = cloudpickle
68 |
69 | logger.info(f"Setting env var: PYSPARK_PYTHON={sys.executable}")
70 | os.environ["PYSPARK_PYTHON"] = sys.executable
71 |
72 | conf = SparkConf()
73 | conf.setMaster(f"local[{threads}]")
74 | conf.set("spark.driver.bindAddress", "127.0.0.1")
75 |
76 | sc = SparkContext(conf=conf)
77 |
78 | else:
79 | import cloudpickle
80 | import pyspark.serializers
81 | from pyspark.sql import SparkSession
82 |
83 | pyspark.serializers.cloudpickle = cloudpickle
84 |
85 | spark = SparkSession.builder.appName(name).getOrCreate()
86 | sc = spark.sparkContext
87 |
88 | return sc
89 |
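A short usage sketch of the helper above, assuming a machine where the bundled Spark installs exist; the app name and thread count are arbitrary examples.

from worldcereal.utils.spark import get_spark_context

# Local SparkContext on all available cores, using the Spark 3.2.0 install
sc = get_spark_context(name="WORLDCEREAL", localspark=True, threads="*")

# Run a trivial distributed job to confirm the context works
print(sc.parallelize(range(100)).map(lambda x: x * x).sum())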
--------------------------------------------------------------------------------
/tests/pre_test_script.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Install git
4 | dnf install git -y
5 |
6 | # # Install openeo-gfmap and presto-worldcereal
7 | # dir=$(pwd)
8 | # GFMAP_URL="https://github.com/Open-EO/openeo-gfmap.git"
9 | # PRESTO_URL="https://github.com/WorldCereal/presto-worldcereal.git"
10 |
11 | # su - jenkins -c "cd $dir && \
12 | # source venv310/bin/activate && \
13 | # git clone $GFMAP_URL && \
14 | # cd openeo-gfmap || { echo 'Directory not found! Exiting...'; exit 1; } && \
15 | # pip install . && \
16 | # cd ..
17 | # git clone -b croptype $PRESTO_URL && \
18 | # cd presto-worldcereal || { echo 'Directory not found! Exiting...'; exit 1; } && \
19 | # pip install .
20 | # "
21 |
22 | # For now only install presto-worldcereal, as gfmap is up to date on PyPI
23 | dir=$(pwd)
24 | PRESTO_URL="https://github.com/WorldCereal/presto-worldcereal.git"
25 |
26 | su - jenkins -c "cd $dir && \
27 | source venv310/bin/activate && \
28 | git clone -b croptype $PRESTO_URL && \
29 | cd presto-worldcereal || { echo 'Directory not found! Exiting...'; exit 1; } && \
30 | pip install .
31 | "
--------------------------------------------------------------------------------
/tests/worldcerealtests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/tests/worldcerealtests/__init__.py
--------------------------------------------------------------------------------
/tests/worldcerealtests/conftest.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | import geojson
5 | import pandas as pd
6 | import pytest
7 | import xarray as xr
8 |
9 |
10 | def get_test_resource(relative_path):
11 | resource_dir = Path(os.path.dirname(os.path.realpath(__file__)))
12 | return resource_dir / "testresources" / relative_path
13 |
14 |
15 | @pytest.fixture
16 | def WorldCerealPreprocessedInputs():
17 | filepath = get_test_resource("worldcereal_preprocessed_inputs.nc")
18 | arr = (
19 | xr.open_dataset(filepath)
20 | .to_array(dim="bands")
21 | .drop_sel(bands="crs")
22 | .astype("uint16")
23 | )
24 | return arr
25 |
26 |
27 | @pytest.fixture
28 | def SpatialExtent():
29 | filepath = get_test_resource("spatial_extent.json")
30 | with open(filepath, "r") as f:
31 | return geojson.load(f)
32 |
33 |
34 | @pytest.fixture
35 | def WorldCerealCroplandClassification():
36 | filepath = get_test_resource("worldcereal_cropland_classification.nc")
37 | arr = xr.open_dataarray(filepath).astype("uint16")
38 | return arr
39 |
40 |
41 | @pytest.fixture
42 | def WorldCerealCroptypeClassification():
43 | filepath = get_test_resource("worldcereal_croptype_classification.nc")
44 | arr = xr.open_dataarray(filepath).astype("uint16")
45 | return arr
46 |
47 |
48 | @pytest.fixture
49 | def WorldCerealExtractionsDF():
50 | filepath = get_test_resource("test_public_extractions.parquet")
51 | return pd.read_parquet(filepath)
52 |
53 |
54 | @pytest.fixture
55 | def WorldCerealPrivateExtractionsPath():
56 | filepath = get_test_resource("worldcereal_private_extractions_dummy.parquet")
57 | return filepath
58 |
--------------------------------------------------------------------------------
/tests/worldcerealtests/test_feature_extractor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import xarray as xr
3 |
4 | from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
5 |
6 |
7 | def test_slope_computation():
8 | test_elevation = np.array(
9 | [[10, 20, 30, 30], [10, 20, 20, 20], [65535, 20, 20, 20]], dtype=np.uint16
10 | )
11 |
12 | array = xr.DataArray(
13 | test_elevation[None, :, :],
14 | dims=["bands", "y", "x"],
15 | coords={
16 | "bands": ["elevation"],
17 | "x": [
18 | 71.2302216215233,
19 | 71.23031145305171,
20 | 71.23040128458014,
21 | 71.23040128458014,
22 | ],
23 | "y": [25.084450211061935, 25.08436885206669, 25.084287493017356],
24 | },
25 | )
26 |
27 | extractor = PrestoFeatureExtractor()
28 | extractor._epsg = 4326 # pylint: disable=protected-access
29 |
30 | # In the UDF no_data is set to 65535
31 | resolution = extractor.evaluate_resolution(array)
32 | slope = extractor.compute_slope(
33 | array, resolution
34 | ).values # pylint: disable=protected-access
35 |
36 | assert slope[0, -1, 0] == 65535
37 | assert resolution == 10
38 |
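For intuition, a standalone sketch of how a slope band in degrees can be derived from an elevation grid while propagating the 65535 nodata marker. This illustrates the concept with numpy.gradient only; it is not the actual PrestoFeatureExtractor.compute_slope implementation.

import numpy as np

NODATA = 65535

def slope_degrees_sketch(elevation: np.ndarray, resolution: float) -> np.ndarray:
    # Nodata cells become NaN so their (and their neighbours') gradients
    # are undefined, then get mapped back to the nodata marker
    dem = elevation.astype("float32")
    dem[dem == NODATA] = np.nan
    dzdy, dzdx = np.gradient(dem, resolution)  # derivatives along y and x
    slope = np.degrees(np.arctan(np.hypot(dzdx, dzdy)))
    slope[np.isnan(slope)] = NODATA
    return slope.astype("uint16")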
--------------------------------------------------------------------------------
/tests/worldcerealtests/test_inference.py:
--------------------------------------------------------------------------------
1 | from openeo_gfmap.features.feature_extractor import (
2 | EPSG_HARMONIZED_NAME,
3 | apply_feature_extractor_local,
4 | )
5 | from openeo_gfmap.inference.model_inference import apply_model_inference_local
6 |
7 | from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
8 | from worldcereal.openeo.inference import CropClassifier
9 | from worldcereal.parameters import CropLandParameters, CropTypeParameters
10 | from worldcereal.utils.models import load_model_lut
11 |
12 |
13 | def test_cropland_inference(WorldCerealPreprocessedInputs):
14 | """Test the local generation of a cropland product"""
15 |
16 | print("Get Presto cropland features")
17 | cropland_features = apply_feature_extractor_local(
18 | PrestoFeatureExtractor,
19 | WorldCerealPreprocessedInputs,
20 | parameters={
21 | EPSG_HARMONIZED_NAME: 32631,
22 | "ignore_dependencies": True,
23 | "compile_presto": False,
24 | "use_valid_date_token": False,
25 | "presto_model_url": CropLandParameters().feature_parameters.presto_model_url,
26 | },
27 | )
28 |
29 | print("Running cropland classification inference UDF locally")
30 |
31 | lookup_table = load_model_lut(
32 | CropLandParameters().classifier_parameters.classifier_url
33 | )
34 |
35 | cropland_classification = apply_model_inference_local(
36 | CropClassifier,
37 | cropland_features,
38 | parameters={
39 | EPSG_HARMONIZED_NAME: 32631,
40 | "ignore_dependencies": True,
41 | "lookup_table": lookup_table,
42 | "classifier_url": CropLandParameters().classifier_parameters.classifier_url,
43 | },
44 | )
45 |
46 | assert list(cropland_classification.bands.values) == [
47 | "classification",
48 | "probability",
49 | "probability_other",
50 | "probability_cropland",
51 | ]
52 | assert cropland_classification.sel(bands="classification").values.max() <= 1
53 | assert cropland_classification.sel(bands="classification").values.min() >= 0
54 | assert cropland_classification.sel(bands="probability").values.max() <= 100
55 | assert cropland_classification.sel(bands="probability").values.min() >= 0
56 | assert cropland_classification.shape == (4, 100, 100)
57 |
58 |
59 | def test_croptype_inference(WorldCerealPreprocessedInputs):
60 | """Test the local generation of a croptype product"""
61 |
62 | print("Get Presto croptype features")
63 | croptype_features = apply_feature_extractor_local(
64 | PrestoFeatureExtractor,
65 | WorldCerealPreprocessedInputs,
66 | parameters={
67 | EPSG_HARMONIZED_NAME: 32631,
68 | "ignore_dependencies": True,
69 | "compile_presto": False,
70 | "use_valid_date_token": True,
71 | "presto_model_url": CropTypeParameters().feature_parameters.presto_model_url,
72 | },
73 | )
74 |
75 | print("Running croptype classification inference UDF locally")
76 |
77 | lookup_table = load_model_lut(
78 | CropTypeParameters().classifier_parameters.classifier_url
79 | )
80 |
81 | croptype_classification = apply_model_inference_local(
82 | CropClassifier,
83 | croptype_features,
84 | parameters={
85 | EPSG_HARMONIZED_NAME: 32631,
86 | "ignore_dependencies": True,
87 | "lookup_table": lookup_table,
88 | "classifier_url": CropTypeParameters().classifier_parameters.classifier_url,
89 | },
90 | )
91 |
92 | assert list(croptype_classification.bands.values) == [
93 | "classification",
94 | "probability",
95 | "probability_barley",
96 | "probability_maize",
97 | "probability_millet_sorghum",
98 | "probability_other_crop",
99 | "probability_rapeseed_rape",
100 | "probability_soy_soybeans",
101 | "probability_sunflower",
102 | "probability_wheat",
103 | ]
104 |
105 | # The first assert below depends on the number of classes in the model
106 | assert croptype_classification.sel(bands="classification").values.max() <= 7
107 | assert croptype_classification.sel(bands="classification").values.min() >= 0
108 | assert croptype_classification.sel(bands="probability").values.max() <= 100
109 | assert croptype_classification.sel(bands="probability").values.min() >= 0
110 | assert croptype_classification.shape == (10, 100, 100)
111 |
--------------------------------------------------------------------------------
/tests/worldcerealtests/test_pipelines.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from catboost import CatBoostClassifier, Pool
4 | from loguru import logger
5 | from openeo_gfmap import BoundingBoxExtent, TemporalContext
6 | from presto.presto import Presto
7 | from sklearn.model_selection import train_test_split
8 | from sklearn.utils.class_weight import compute_class_weight
9 |
10 | from worldcereal.parameters import CropTypeParameters
11 | from worldcereal.train.data import WorldCerealTrainingDataset, get_training_df
12 | from worldcereal.utils.refdata import (
13 | process_extractions_df,
14 | query_private_extractions,
15 | query_public_extractions,
16 | )
17 |
18 | SPATIAL_EXTENT = BoundingBoxExtent(
19 | west=4.63761, south=51.11649, east=4.73761, north=51.21649, epsg=4326
20 | ).to_geometry()
21 |
22 |
23 | def test_custom_croptype_demo(WorldCerealPrivateExtractionsPath):
24 | """Test for a full custom croptype pipeline up and till the point of
25 | training a custom catboost classifier. This test uses both public and private
26 | extractions.
27 | """
28 |
29 | # Query private and public extractions
30 |
31 | private_df = query_private_extractions(
32 | WorldCerealPrivateExtractionsPath,
33 | bbox_poly=SPATIAL_EXTENT,
34 | filter_cropland=True,
35 | buffer=250000, # Meters
36 | )
37 |
38 | assert not private_df.empty, "Should have found private extractions"
39 | logger.info(
40 | (
41 | f"Found {private_df['sample_id'].nunique()} unique samples in the "
42 | f"private data, spread across {private_df['ref_id'].nunique()} "
43 | "unique reference datasets."
44 | )
45 | )
46 |
47 | public_df = query_public_extractions(
48 | SPATIAL_EXTENT,
49 | buffer=1000, # Meters
50 | filter_cropland=True,
51 | )
52 |
53 | assert not public_df.empty, "Should have found public extractions"
54 | logger.info(
55 | (
56 | f"Found {public_df['sample_id'].nunique()} unique samples in the "
57 | f"public data, spread across {public_df['ref_id'].nunique()} "
58 | "unique reference datasets."
59 | )
60 | )
61 |
62 | # Concatenate extractions
63 | extractions_df = pd.concat([private_df, public_df])
64 |
65 | assert len(extractions_df) == len(public_df) + len(private_df)
66 |
67 | # Process the merged data
68 | processing_period = TemporalContext("2020-01-01", "2020-12-31")
69 | print(f"Shape of extractions_df: {extractions_df.shape}")
70 | training_df = process_extractions_df(extractions_df, processing_period)
71 | logger.info(f"training_df shape: {training_df.shape}")
72 |
73 | # Drop labels that occur infrequently for this test
74 | value_counts = training_df["ewoc_code"].value_counts()
75 | single_labels = value_counts[value_counts < 3].index.to_list()
76 | training_df = training_df[~training_df["ewoc_code"].isin(single_labels)]
77 |
78 | print("*" * 40)
79 | for c in training_df.columns:
80 | print(c)
81 | print("*" * 40)
82 |
83 | # Direct shape assert: if process_extractions_df changes, this may have to be updated
84 | assert training_df.shape == (238, 246)
85 |
86 | # We keep original ewoc_code for this test
87 | training_df["downstream_class"] = training_df["ewoc_code"]
88 |
89 | # Compute presto embeddings
90 | presto_model_url = CropTypeParameters().feature_parameters.presto_model_url
91 |
92 | # Load pretrained Presto model
93 | logger.info(f"Presto URL: {presto_model_url}")
94 | presto_model = Presto.load_pretrained(
95 | presto_model_url,
96 | from_url=True,
97 | strict=False,
98 | valid_month_as_token=True,
99 | )
100 |
101 | # Initialize dataset
102 | df = training_df.reset_index()
103 | ds = WorldCerealTrainingDataset(
104 | df,
105 | task_type="croptype",
106 | augment=True,
107 | mask_ratio=0.30,
108 | repeats=1,
109 | )
110 | logger.info("Computing Presto embeddings ...")
111 | df = get_training_df(
112 | ds,
113 | presto_model,
114 | batch_size=256,
115 | valid_date_as_token=True,
116 | )
117 | logger.info("Presto embeddings computed.")
118 |
119 | # Train classifier
120 | logger.info("Split train/test ...")
121 | samples_train, samples_test = train_test_split(
122 | df,
123 | test_size=0.2,
124 | random_state=3,
125 | stratify=df["downstream_class"],
126 | )
127 |
128 | eval_metric = "MultiClass"
129 | loss_function = "MultiClass"
130 |
131 | logger.info("Computing class weights ...")
132 | class_weights = np.round(
133 | compute_class_weight(
134 | class_weight="balanced",
135 | classes=np.unique(samples_train["downstream_class"]),
136 | y=samples_train["downstream_class"],
137 | ),
138 | 3,
139 | )
140 | class_weights = {
141 | k: v
142 | for k, v in zip(np.unique(samples_train["downstream_class"]), class_weights)
143 | }
144 | logger.info(f"Class weights: {class_weights}")
145 |
146 | sample_weights = np.ones((len(samples_train["downstream_class"]),))
147 | sample_weights_val = np.ones((len(samples_test["downstream_class"]),))
148 | for k, v in class_weights.items():
149 | sample_weights[samples_train["downstream_class"] == k] = v
150 | sample_weights_val[samples_test["downstream_class"] == k] = v
151 | samples_train["weight"] = sample_weights
152 | samples_test["weight"] = sample_weights_val
153 |
154 | # Define classifier
155 | custom_downstream_model = CatBoostClassifier(
156 | iterations=2000, # Kept moderate to limit model size
157 | depth=8,
158 | early_stopping_rounds=20,
159 | loss_function=loss_function,
160 | eval_metric=eval_metric,
161 | random_state=3,
162 | verbose=25,
163 | class_names=np.unique(samples_train["downstream_class"]),
164 | )
165 |
166 | # Set up the CatBoost dataset Pools
167 | bands = [f"presto_ft_{i}" for i in range(128)]
168 | calibration_data = Pool(
169 | data=samples_train[bands],
170 | label=samples_train["downstream_class"],
171 | weight=samples_train["weight"],
172 | )
173 | eval_data = Pool(
174 | data=samples_test[bands],
175 | label=samples_test["downstream_class"],
176 | weight=samples_test["weight"],
177 | )
178 |
179 | # Train classifier
180 | logger.info("Training CatBoost classifier ...")
181 | custom_downstream_model.fit(
182 | calibration_data,
183 | eval_set=eval_data,
184 | )
185 |
186 | # Make predictions
187 | _ = custom_downstream_model.predict(samples_test[bands]).flatten()
188 |
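As a worked example of the "balanced" weighting used above: scikit-learn assigns weight(c) = n_samples / (n_classes * count(c)), so minority classes receive proportionally larger weights. The toy labels below are illustrative only.

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = np.array(["maize"] * 6 + ["wheat"] * 2)
weights = compute_class_weight(class_weight="balanced", classes=np.unique(y), y=y)

# 8 samples, 2 classes -> maize: 8 / (2 * 6) = 0.667, wheat: 8 / (2 * 2) = 2.0
print(dict(zip(np.unique(y), np.round(weights, 3))))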
--------------------------------------------------------------------------------
/tests/worldcerealtests/test_postprocessing.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from openeo_gfmap.inference.model_inference import (
3 | EPSG_HARMONIZED_NAME,
4 | apply_model_inference_local,
5 | )
6 |
7 | from worldcereal.openeo.postprocess import PostProcessor
8 | from worldcereal.parameters import CropLandParameters, PostprocessParameters
9 |
10 |
11 | def test_cropland_postprocessing(WorldCerealCroplandClassification):
12 | """Test the local postprocessing of a cropland product"""
13 |
14 | print("Postprocessing cropland product ...")
15 | _ = apply_model_inference_local(
16 | PostProcessor,
17 | WorldCerealCroplandClassification,
18 | parameters={
19 | "ignore_dependencies": True,
20 | EPSG_HARMONIZED_NAME: None,
21 | "classifier_url": CropLandParameters().classifier_parameters.classifier_url,
22 | "method": "smooth_probabilities",
23 | },
24 | )
25 |
26 |
27 | def test_cropland_postprocessing_majority_vote(WorldCerealCroplandClassification):
28 | """Test the local postprocessing of a cropland product"""
29 |
30 | print("Postprocessing cropland product ...")
31 | _ = apply_model_inference_local(
32 | PostProcessor,
33 | WorldCerealCroplandClassification,
34 | parameters={
35 | "ignore_dependencies": True,
36 | EPSG_HARMONIZED_NAME: None,
37 | "classifier_url": CropLandParameters().classifier_parameters.classifier_url,
38 | "method": "majority_vote",
39 | "kernel_size": 7,
40 | },
41 | )
42 |
43 |
44 | def test_croptype_postprocessing(WorldCerealCroptypeClassification):
45 | """Test the local postprocessing of a croptype product"""
46 |
47 | print("Postprocessing croptype product ...")
48 | _ = apply_model_inference_local(
49 | PostProcessor,
50 | WorldCerealCroptypeClassification,
51 | parameters={
52 | "ignore_dependencies": True,
53 | EPSG_HARMONIZED_NAME: None,
54 | "classifier_url": CropLandParameters().classifier_parameters.classifier_url,
55 | "method": "smooth_probabilities",
56 | },
57 | )
58 |
59 |
60 | def test_croptype_postprocessing_majority_vote(WorldCerealCroptypeClassification):
61 | """Test the local postprocessing of a croptype product"""
62 |
63 | print("Postprocessing croptype product ...")
64 | _ = apply_model_inference_local(
65 | PostProcessor,
66 | WorldCerealCroptypeClassification,
67 | parameters={
68 | "ignore_dependencies": True,
69 | EPSG_HARMONIZED_NAME: None,
70 | "classifier_url": CropLandParameters().classifier_parameters.classifier_url,
71 | "method": "majority_vote",
72 | "kernel_size": 7,
73 | },
74 | )
75 |
76 |
77 | def test_postprocessing_parameters():
78 | """Test the postprocessing parameters."""
79 |
80 | # This set should work
81 | params = {
82 | "enable": True,
83 | "method": "smooth_probabilities",
84 | "kernel_size": 5,
85 | "save_intermediate": False,
86 | "keep_class_probs": False,
87 | }
88 | PostprocessParameters(**params)
89 |
90 | # This one as well
91 | params["method"] = "majority_vote"
92 | PostprocessParameters(**params)
93 |
94 | # This one should fail with invalid kernel size
95 | params["kernel_size"] = 30
96 | with pytest.raises(ValueError):
97 | PostprocessParameters(**params)
98 |
99 | # This one should fail with invalid method
100 | params["method"] = "test"
101 | with pytest.raises(ValueError):
102 | PostprocessParameters(**params)
103 |
104 | # This one should fail with invalid save_intermediate
105 | params["enable"] = False
106 | params["save_intermediate"] = True
107 | params["method"] = "smooth_probabilities"
108 | with pytest.raises(ValueError):
109 | PostprocessParameters(**params)
110 |
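To make the expected failures above concrete, here is a hypothetical pydantic (v2) sketch of the validation rules this test exercises; the real PostprocessParameters model lives in worldcereal.parameters and may enforce different bounds.

from pydantic import BaseModel, model_validator

class PostprocessParametersSketch(BaseModel):
    enable: bool = True
    method: str = "smooth_probabilities"
    kernel_size: int = 5
    save_intermediate: bool = False
    keep_class_probs: bool = False

    @model_validator(mode="after")
    def _check(self):
        if self.method not in ("smooth_probabilities", "majority_vote"):
            raise ValueError(f"Invalid method: {self.method}")
        if self.method == "majority_vote" and self.kernel_size > 25:
            # Assumed upper bound: kernel_size=30 must be rejected
            raise ValueError(f"Invalid kernel_size: {self.kernel_size}")
        if self.save_intermediate and not self.enable:
            raise ValueError("save_intermediate requires enable=True")
        return self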
--------------------------------------------------------------------------------
/tests/worldcerealtests/test_preprocessing.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from pathlib import Path
4 |
5 | import pytest
6 | from openeo_gfmap import BoundingBoxExtent, FetchType
7 | from openeo_gfmap.backend import (
8 | Backend,
9 | BackendContext,
10 | cdse_connection,
11 | vito_connection,
12 | )
13 | from openeo_gfmap.temporal import TemporalContext
14 |
15 | from worldcereal.extract.patch_to_point_worldcereal import (
16 | worldcereal_preprocessed_inputs_from_patches,
17 | )
18 | from worldcereal.openeo.preprocessing import (
19 | InvalidTemporalContextError,
20 | _validate_temporal_context,
21 | correct_temporal_context,
22 | worldcereal_preprocessed_inputs,
23 | )
24 |
25 | basedir = Path(os.path.dirname(os.path.realpath(__file__)))
26 |
27 |
28 | def test_temporal_context_validation():
29 | """Test the validation of temporal context."""
30 |
31 | temporal_context = TemporalContext("2020-01-01", "2020-12-31")
32 | _validate_temporal_context(temporal_context)
33 |
34 | incorrect_temporal_context = TemporalContext("2020-01-05", "2020-03-15")
35 |
36 | with pytest.raises(InvalidTemporalContextError):
37 | _validate_temporal_context(incorrect_temporal_context)
38 |
39 | more_than_one_year = TemporalContext("2019-01-05", "2021-03-15")
40 |
41 | with pytest.raises(InvalidTemporalContextError):
42 | _validate_temporal_context(more_than_one_year)
43 |
44 |
45 | def test_temporal_context_correction():
46 | """Test the automatic correction of invalid temporal context."""
47 |
48 | incorrect_temporal_context = TemporalContext("2022-01-05", "2020-03-15")
49 | corrected_temporal_context = correct_temporal_context(incorrect_temporal_context)
50 |
51 | # Should no longer raise an exception
52 | _validate_temporal_context(corrected_temporal_context)
53 |
54 |
55 | def test_worldcereal_preprocessed_inputs_graph(SpatialExtent):
56 | """Test the worldcereal_preprocessed_inputs function.
57 | This is based on constructing the openEO graph for the job
58 | that would run, without actually running it."""
59 |
60 | temporal_extent = TemporalContext("2020-06-01", "2021-05-31")
61 |
62 | cube = worldcereal_preprocessed_inputs(
63 | connection=cdse_connection(),
64 | backend_context=BackendContext(Backend.CDSE),
65 | spatial_extent=SpatialExtent,
66 | temporal_extent=temporal_extent,
67 | fetch_type=FetchType.POLYGON,
68 | )
69 |
70 | # Ref file with processing graph
71 | ref_graph = basedir / "testresources" / "preprocess_graph.json"
72 |
73 | # # uncomment to save current graph to the ref file
74 | # with open(ref_graph, "w") as f:
75 | # f.write(json.dumps(cube.flat_graph(), indent=4))
76 |
77 | with open(ref_graph, "r") as f:
78 | expected = json.load(f)
79 | assert expected == cube.flat_graph()
80 |
81 |
82 | def test_worldcereal_preprocessed_inputs_graph_withslope():
83 | """This version has fetchtype.TILE and should include slope."""
84 |
85 | temporal_extent = TemporalContext("2018-03-01", "2019-02-28")
86 |
87 | cube = worldcereal_preprocessed_inputs(
88 | connection=cdse_connection(),
89 | backend_context=BackendContext(Backend.CDSE),
90 | spatial_extent=BoundingBoxExtent(
91 | west=44.432274, south=51.317362, east=44.698802, north=51.428224, epsg=4326
92 | ),
93 | temporal_extent=temporal_extent,
94 | )
95 |
96 | # Ref file with processing graph
97 | ref_graph = basedir / "testresources" / "preprocess_graphwithslope.json"
98 |
99 | # # uncomment to save current graph to the ref file
100 | # with open(ref_graph, "w") as f:
101 | # f.write(json.dumps(cube.flat_graph(), indent=4))
102 |
103 | with open(ref_graph, "r") as f:
104 | expected = json.load(f)
105 | assert expected == cube.flat_graph()
106 |
107 |
108 | def test_worldcereal_preprocessed_inputs_from_patches_graph():
109 | """This version gets a preprocessed cube from extracted patches."""
110 |
111 | temporal_extent = TemporalContext("2020-01-01", "2020-12-31")
112 |
113 | cube = worldcereal_preprocessed_inputs_from_patches(
114 | connection=vito_connection(),
115 | temporal_extent=temporal_extent,
116 | ref_id="test_ref_id",
117 | epsg=32631,
118 | )
119 |
120 | # Ref file with processing graph
121 | ref_graph = basedir / "testresources" / "preprocess_from_patches_graph.json"
122 |
123 | # # uncomment to save current graph to the ref file
124 | # with open(ref_graph, "w") as f:
125 | # f.write(json.dumps(cube.flat_graph(), indent=4))
126 |
127 | with open(ref_graph, "r") as f:
128 | expected = json.load(f)
129 | assert expected == cube.flat_graph()
130 |
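A minimal sketch of the kind of check the validation tests at the top of this file exercise, assuming the rule is that a processing period must span exactly one year; the real _validate_temporal_context may apply additional constraints.

import pandas as pd

def validate_one_year_span(start_date: str, end_date: str) -> None:
    # "2020-01-01".."2020-12-31" passes; "2020-01-05".."2020-03-15" and
    # "2019-01-05".."2021-03-15" both fail, mirroring the cases above
    start = pd.to_datetime(start_date)
    end = pd.to_datetime(end_date)
    span = (end + pd.Timedelta(days=1)) - start
    if span not in (pd.Timedelta(days=365), pd.Timedelta(days=366)):
        raise ValueError(f"Temporal context must span one year, got {span}")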
--------------------------------------------------------------------------------
/tests/worldcerealtests/test_rdm_interaction.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import patch
2 |
3 | import geopandas as gpd
4 | import pytest
5 | from openeo_gfmap import BoundingBoxExtent, TemporalContext
6 | from shapely import Point
7 |
8 | from worldcereal.rdm_api.rdm_interaction import RdmCollection, RdmInteraction
9 |
10 |
11 | @pytest.fixture
12 | def sample_bbox():
13 | return BoundingBoxExtent(west=0, east=1, north=1, south=0)
14 |
15 |
16 | @pytest.fixture
17 | def sample_temporal_extent():
18 | return TemporalContext(start_date="2021-01-01", end_date="2021-12-31")
19 |
20 |
21 | class TestRdmInteraction:
22 | @patch("requests.Session.get")
23 | def test_collections_from_rdm(
24 | self, mock_requests_get, sample_bbox, sample_temporal_extent
25 | ):
26 |
27 | mock_requests_get.return_value.status_code = 200
28 | mock_requests_get.return_value.json.return_value = [
29 | {
30 | "collectionId": "Foo",
31 | "title": "Foo_title",
32 | "accessType": "Public",
33 | },
34 | {
35 | "collectionId": "Bar",
36 | "title": "Bar_title",
37 | "accessType": "Public",
38 | },
39 | ]
40 | interaction = RdmInteraction()
41 | collections = interaction.get_collections(
42 | spatial_extent=sample_bbox, temporal_extent=sample_temporal_extent
43 | )
44 | ref_ids = [collection.id for collection in collections]
45 |
46 | assert ref_ids == ["Foo", "Bar"]
47 |
48 | bbox_str = f"Bbox={sample_bbox.west}&Bbox={sample_bbox.south}&Bbox={sample_bbox.east}&Bbox={sample_bbox.north}"
49 | temporal = f"&ValidityTime.Start={sample_temporal_extent.start_date}T00%3A00%3A00Z&ValidityTime.End={sample_temporal_extent.end_date}T00%3A00%3A00Z"
50 | expected_url = (
51 | f"{interaction.RDM_ENDPOINT}/collections/search?{bbox_str}{temporal}"
52 | )
53 | mock_requests_get.assert_called_with(
54 | url=expected_url, headers={"accept": "*/*"}, timeout=10
55 | )
56 |
57 | @patch("worldcereal.rdm_api.rdm_interaction.RdmInteraction.get_collections")
58 | @patch("worldcereal.rdm_api.rdm_interaction.RdmInteraction._get_download_urls")
59 | def test_download_samples(
60 | self,
61 | mock_get_download_urls,
62 | mock_collections_from_rdm,
63 | sample_bbox,
64 | sample_temporal_extent,
65 | tmp_path,
66 | ):
67 |
68 | data = {
69 | "col1": ["must", "include", "this", "column", "definitely", "check"],
70 | "col2": ["and", "this", "One", "Too", "please", "check"],
71 | "col3": ["but", "not", "This", "One", "please", "check"],
72 | "valid_time": [
73 | "2021-01-01",
74 | "2021-12-31",
75 | "2021-06-01",
76 | "2025-05-22",
77 | "2021-06-01",
78 | "2021-06-01",
79 | ], # Fourth date not within sample_temporal_extent
80 | "ewoc_code": ["1", "2", "3", "4", "5", "1"],
81 | # Fifth crop code not within list of ewoc_codes
82 | "extract": [1, 1, 1, 1, 2, 0],
83 | "geometry": [
84 | Point(0.5, 0.5),
85 | Point(0.25, 0.25),
86 | Point(2, 3),
87 | Point(0.75, 0.75),
88 | Point(0.75, 0.78),
89 | Point(0.78, 0.75),
90 | ], # Third point not within sample_polygon
91 | }
92 | gdf = gpd.GeoDataFrame(data, crs="EPSG:4326")
93 | file_path = tmp_path / "sample.parquet"
94 | gdf.to_parquet(file_path)
95 |
96 | mock_collections_from_rdm.return_value = [
97 | RdmCollection(
98 | **{
99 | "collectionId": "Foo",
100 | "title": "Foo_title",
101 | "accessType": "Public",
102 | }
103 | ),
104 | ]
105 | mock_get_download_urls.return_value = [str(file_path)]
106 |
107 | interaction = RdmInteraction()
108 | result_gdf = interaction.get_samples(
109 | spatial_extent=sample_bbox,
110 | temporal_extent=sample_temporal_extent,
111 | columns=["col1", "col2", "ref_id", "geometry"],
112 | ewoc_codes=["1", "2", "3", "4"],
113 | subset=True,
114 | )
115 |
116 | # Check that col3 and valid_time are indeed not included
117 | assert result_gdf.columns.tolist() == [
118 | "col1",
119 | "col2",
120 | "ref_id",
121 | "geometry",
122 | ]
123 |
124 | # Check that the third through last samples are excluded:
125 | # the third and fourth fall outside the spatiotemporal extent,
126 | # the fifth has a crop type not in the list of ewoc_codes,
127 | # and the last sample is not marked for extraction
128 | assert len(result_gdf) == 2
129 |
--------------------------------------------------------------------------------
/tests/worldcerealtests/test_refdata.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from shapely.geometry import Polygon
3 |
4 | from worldcereal.utils.refdata import (
5 | get_best_valid_time,
6 | month_diff,
7 | query_public_extractions,
8 | )
9 |
10 |
11 | def test_query_public_extractions():
12 | """Unittest for querying public extractions."""
13 |
14 | # Define small polygon
15 | poly = Polygon.from_bounds(*(4.535, 51.050719, 4.600936, 51.098176))
16 |
17 | # Query extractions
18 | df = query_public_extractions(poly, buffer=100)
19 |
20 | # Check if dataframe has samples
21 | assert not df.empty
22 |
23 |
24 | def test_get_best_valid_time():
25 | def process_test_case(test_case: pd.Series) -> pd.DataFrame:
26 | test_case_res = []
27 | for processing_period_middle_month in range(1, 13):
28 | test_case["true_valid_time_month"] = test_case["valid_time"].month
29 | test_case["proposed_valid_time_month"] = processing_period_middle_month
30 | test_case["valid_month_shift_backward"] = month_diff(
31 | test_case["proposed_valid_time_month"],
32 | test_case["true_valid_time_month"],
33 | )
34 | test_case["valid_month_shift_forward"] = month_diff(
35 | test_case["true_valid_time_month"],
36 | test_case["proposed_valid_time_month"],
37 | )
38 | proposed_valid_time = get_best_valid_time(test_case)
39 | test_case_res.append([processing_period_middle_month, proposed_valid_time])
40 | return pd.DataFrame(
41 | test_case_res, columns=["proposed_valid_month", "resulting_valid_time"]
42 | )
43 |
44 | test_case1 = pd.Series(
45 | {
46 | "start_date": pd.to_datetime("2019-01-01"),
47 | "end_date": pd.to_datetime("2019-12-01"),
48 | "valid_time": pd.to_datetime("2019-06-01"),
49 | }
50 | )
51 | test_case2 = pd.Series(
52 | {
53 | "start_date": pd.to_datetime("2019-01-01"),
54 | "end_date": pd.to_datetime("2019-12-01"),
55 | "valid_time": pd.to_datetime("2019-10-01"),
56 | }
57 | )
58 | test_case3 = pd.Series(
59 | {
60 | "start_date": pd.to_datetime("2019-01-01"),
61 | "end_date": pd.to_datetime("2019-12-01"),
62 | "valid_time": pd.to_datetime("2019-03-01"),
63 | }
64 | )
65 |
66 | # Process test cases
67 | test_case1_res = process_test_case(test_case1)
68 | test_case2_res = process_test_case(test_case2)
69 | test_case3_res = process_test_case(test_case3)
70 |
71 | # Asserts are valid for default MIN_EDGE_BUFFER and NUM_TIMESTEPS values
72 | # Assertions for test case 1
73 | assert (
74 | test_case1_res[test_case1_res["proposed_valid_month"].isin([1, 2, 11, 12])][
75 | "resulting_valid_time"
76 | ]
77 | .isna()
78 | .all()
79 | )
80 | assert (
81 | test_case1_res[test_case1_res["proposed_valid_month"].isin(range(3, 11))][
82 | "resulting_valid_time"
83 | ]
84 | .notna()
85 | .all()
86 | )
87 |
88 | # Assertions for test case 2
89 | assert (
90 | test_case2_res[test_case2_res["proposed_valid_month"].isin([1, 2, 3, 11, 12])][
91 | "resulting_valid_time"
92 | ]
93 | .isna()
94 | .all()
95 | )
96 | assert (
97 | test_case2_res[test_case2_res["proposed_valid_month"].isin(range(4, 11))][
98 | "resulting_valid_time"
99 | ]
100 | .notna()
101 | .all()
102 | )
103 |
104 | # Assertions for test case 3
105 | assert (
106 | test_case3_res[
107 | test_case3_res["proposed_valid_month"].isin([1, 2, 9, 10, 11, 12])
108 | ]["resulting_valid_time"]
109 | .isna()
110 | .all()
111 | )
112 | assert (
113 | test_case3_res[test_case3_res["proposed_valid_month"].isin(range(3, 9))][
114 | "resulting_valid_time"
115 | ]
116 | .notna()
117 | .all()
118 | )
119 |
--------------------------------------------------------------------------------
/tests/worldcerealtests/test_seasons.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | import numpy as np
4 | import pandas as pd
5 | from openeo_gfmap import BoundingBoxExtent
6 |
7 | from worldcereal.seasons import (
8 | circular_median_day_of_year,
9 | doy_from_tiff,
10 | doy_to_date_after,
11 | get_processing_dates_for_extent,
12 | )
13 |
14 |
15 | def test_doy_from_tiff():
16 | bounds = (574680, 5621800, 575320, 5622440)
17 | epsg = 32631
18 |
19 | doy_data = doy_from_tiff("tc-s1", "SOS", bounds, epsg, resolution=10000)
20 |
21 | assert doy_data.size == 1
22 |
23 | doy_data = int(doy_data)
24 |
25 | assert doy_data != 0
26 |
27 |
28 | def test_doy_to_date_after():
29 | bounds = (574680, 5621800, 575320, 5622440)
30 | epsg = 32631
31 |
32 | doy_data = doy_from_tiff("tc-s2", "SOS", bounds, epsg, resolution=10000)
33 |
34 | after_date = datetime.datetime(2019, 1, 1)
35 | doy_date = doy_to_date_after(int(doy_data), after_date)
36 |
37 | assert pd.to_datetime(doy_date) >= after_date
38 |
39 | after_date = datetime.datetime(2019, 8, 1)
40 | doy_date = doy_to_date_after(int(doy_data), after_date)
41 |
42 | assert pd.to_datetime(doy_date) >= after_date
43 |
44 |
45 | def test_get_processing_dates_for_extent():
46 | # Test to check if we can infer processing dates for default season
47 | # tc-annual
48 | bounds = (574680, 5621800, 575320, 5622440)
49 | epsg = 32631
50 | year = 2021
51 | extent = BoundingBoxExtent(*bounds, epsg)
52 |
53 | temporal_context = get_processing_dates_for_extent(extent, year)
54 | start_date = temporal_context.start_date
55 | end_date = temporal_context.end_date
56 |
57 | assert pd.to_datetime(end_date).year == year
58 | assert pd.to_datetime(end_date) - pd.to_datetime(start_date) == pd.Timedelta(
59 | days=364
60 | )
61 |
62 |
63 | def test_compute_median_doy():
64 | # Test to check if we can compute median day of year
65 | # if array crosses the calendar year
66 | doy_array = np.array([360, 362, 365, 1, 3, 5, 7])
67 | assert circular_median_day_of_year(doy_array) == 1
68 |
69 | # Other tests
70 | assert circular_median_day_of_year([1, 2, 3, 4, 5]) == 3
71 | assert circular_median_day_of_year([320, 330, 340, 360]) == 330
72 | assert circular_median_day_of_year([320, 330, 340, 360, 10]) == 340
73 |
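For intuition, a self-contained sketch of one way to compute a circular median: rotate the sorted days so the largest gap sits at the wrap-around point, then take the lower-middle element. This is an independent illustration rather than the code under test, although it reproduces the four expectations above.

import numpy as np

def circular_median_doy_sketch(doys, year_length=365):
    doys = np.sort(np.asarray(doys))
    # Gaps between consecutive days, including the wrap-around gap
    gaps = np.diff(np.concatenate([doys, doys[:1] + year_length]))
    cut = (np.argmax(gaps) + 1) % len(doys)  # start just after the largest gap
    unwrapped = np.concatenate([doys[cut:], doys[:cut] + year_length])
    median = unwrapped[(len(unwrapped) - 1) // 2]  # lower middle for even counts
    return int(median % year_length) or year_length  # map day 0 back to 365

assert circular_median_doy_sketch([360, 362, 365, 1, 3, 5, 7]) == 1
assert circular_median_doy_sketch([1, 2, 3, 4, 5]) == 3
assert circular_median_doy_sketch([320, 330, 340, 360]) == 330
assert circular_median_doy_sketch([320, 330, 340, 360, 10]) == 340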
--------------------------------------------------------------------------------
/tests/worldcerealtests/test_stac_api_interaction.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import MagicMock, patch
2 |
3 | import pystac
4 | import pytest
5 | from requests.auth import AuthBase
6 |
7 | from worldcereal.stac.stac_api_interaction import StacApiInteraction
8 |
9 |
10 | @pytest.fixture
11 | def mock_auth():
12 | return MagicMock(spec=AuthBase)
13 |
14 |
15 | def mock_stac_item(item_id):
16 | item = MagicMock(spec=pystac.Item)
17 | item.id = item_id
18 | item.to_dict.return_value = {"id": item_id, "some_property": "value"}
19 | return item
20 |
21 |
22 | class TestStacApiInteraction:
23 | @patch("requests.post")
24 | @patch("worldcereal.stac.stac_api_interaction.StacApiInteraction.exists")
25 | def test_upload_items_single_chunk(
26 | self, mock_exists, mock_requests_post, mock_auth
27 | ):
28 | """Test bulk upload of STAC items in one single chunk."""
29 |
30 | mock_requests_post.return_value.status_code = 200
31 | mock_requests_post.return_value.json.return_value = {"status": "success"}
32 | mock_requests_post.reason = "OK"
33 |
34 | mock_exists.return_value = True
35 |
36 | items = [mock_stac_item(f"item-{i}") for i in range(10)]
37 |
38 | interaction = StacApiInteraction(
39 | sensor="Sentinel1",
40 | base_url="http://fake-stac-api",
41 | auth=mock_auth,
42 | bulk_size=10, # Ensures all 10 items are uploaded in a single bulk request
43 | )
44 | interaction.upload_items_bulk(items)
45 |
46 | mock_requests_post.assert_called_with(
47 | url=f"http://fake-stac-api/collections/{interaction.collection_id}/bulk_items",
48 | auth=mock_auth,
49 | json={
50 | "method": "upsert",
51 | "items": {item.id: item.to_dict() for item in items},
52 | },
53 | )
54 | assert mock_requests_post.call_count == 1
55 |
56 | @patch("requests.post")
57 | @patch("worldcereal.stac.stac_api_interaction.StacApiInteraction.exists")
58 | def test_upload_items_multiple_chunk(
59 | self, mock_exists, mock_requests_post, mock_auth
60 | ):
61 | """Test bulk upload of STAC items in mulitiple chunks."""
62 |
63 | mock_requests_post.return_value.status_code = 200
64 | mock_requests_post.return_value.json.return_value = {"status": "success"}
65 | mock_requests_post.reason = "OK"
66 |
67 | mock_exists.return_value = True
68 |
69 | items = [mock_stac_item(f"item-{i}") for i in range(10)]
70 |
71 | interaction = StacApiInteraction(
72 | sensor="Sentinel1",
73 | base_url="http://fake-stac-api",
74 | auth=mock_auth,
75 | bulk_size=3, # This requires 4 chunks for 10 items
76 | )
77 | interaction.upload_items_bulk(items)
78 |
79 | assert mock_requests_post.call_count == 4
80 |
81 | expected_calls = [
82 | {
83 | "url": f"http://fake-stac-api/collections/{interaction.collection_id}/bulk_items",
84 | "auth": mock_auth,
85 | "json": {
86 | "method": "upsert",
87 | "items": {
88 | "item-0": {"id": "item-0", "some_property": "value"},
89 | "item-1": {"id": "item-1", "some_property": "value"},
90 | "item-2": {"id": "item-2", "some_property": "value"},
91 | },
92 | },
93 | },
94 | {
95 | "url": f"http://fake-stac-api/collections/{interaction.collection_id}/bulk_items",
96 | "auth": mock_auth,
97 | "json": {
98 | "method": "upsert",
99 | "items": {
100 | "item-3": {"id": "item-3", "some_property": "value"},
101 | "item-4": {"id": "item-4", "some_property": "value"},
102 | "item-5": {"id": "item-5", "some_property": "value"},
103 | },
104 | },
105 | },
106 | {
107 | "url": f"http://fake-stac-api/collections/{interaction.collection_id}/bulk_items",
108 | "auth": mock_auth,
109 | "json": {
110 | "method": "upsert",
111 | "items": {
112 | "item-6": {"id": "item-6", "some_property": "value"},
113 | "item-7": {"id": "item-7", "some_property": "value"},
114 | "item-8": {"id": "item-8", "some_property": "value"},
115 | },
116 | },
117 | },
118 | {
119 | "url": f"http://fake-stac-api/collections/{interaction.collection_id}/bulk_items",
120 | "auth": mock_auth,
121 | "json": {
122 | "method": "upsert",
123 | "items": {
124 | "item-9": {"id": "item-9", "some_property": "value"},
125 | },
126 | },
127 | },
128 | ]
129 |
130 | for i, call in enumerate(mock_requests_post.call_args_list):
131 | assert call[1] == expected_calls[i]
132 |
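The chunk arithmetic asserted above (10 items with bulk_size=3 give four POSTs of sizes 3, 3, 3 and 1) follows a standard slicing pattern; a hypothetical helper makes it explicit.

def chunk(items, bulk_size):
    # Yield successive slices of at most bulk_size items
    for i in range(0, len(items), bulk_size):
        yield items[i : i + bulk_size]

assert [len(c) for c in chunk(list(range(10)), 3)] == [3, 3, 3, 1]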
--------------------------------------------------------------------------------
/tests/worldcerealtests/test_train.py:
--------------------------------------------------------------------------------
1 | from presto.presto import Presto
2 | from torch.utils.data import DataLoader
3 |
4 | from worldcereal.parameters import CropLandParameters
5 | from worldcereal.train.data import WorldCerealTrainingDataset, get_training_df
6 |
7 |
8 | def test_worldcerealtraindataset(WorldCerealExtractionsDF):
9 | """Test creation of WorldCerealTrainingDataset and data loading"""
10 |
11 | df = WorldCerealExtractionsDF.reset_index()
12 |
13 | ds = WorldCerealTrainingDataset(
14 | df,
15 | task_type="cropland",
16 | augment=True,
17 | mask_ratio=0.25,
18 | repeats=2,
19 | )
20 |
21 | # Check if number of samples matches repeats
22 | assert len(ds) == 2 * len(df)
23 |
24 | # Check if data loading works
25 | dl = DataLoader(ds, batch_size=2, shuffle=True)
26 |
27 | for x, y, dw, latlons, month, valid_month, variable_mask, attrs in dl:
28 | assert x.shape == (2, 12, 17)
29 | assert y.shape == (2,)
30 | assert dw.shape == (2, 12)
31 | assert dw.unique().numpy()[0] == 9
32 | assert latlons.shape == (2, 2)
33 | assert month.shape == (2,)
34 | assert valid_month.shape == (2,)
35 | assert variable_mask.shape == x.shape
36 | assert isinstance(attrs, dict)
37 | break
38 |
39 |
40 | def test_get_trainingdf(WorldCerealExtractionsDF):
41 | """Test the function that computes embeddings and targets into
42 | a training dataframe using a presto model
43 | """
44 |
45 | df = WorldCerealExtractionsDF.reset_index()
46 | ds = WorldCerealTrainingDataset(df)
47 |
48 | presto_url = CropLandParameters().feature_parameters.presto_model_url
49 | presto_model = Presto.load_pretrained(presto_url, from_url=True, strict=False)
50 |
51 | training_df = get_training_df(ds, presto_model, batch_size=256)
52 |
53 | for ft in range(128):
54 | assert f"presto_ft_{ft}" in training_df.columns
55 |
--------------------------------------------------------------------------------
/tests/worldcerealtests/test_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from catboost import CatBoostClassifier
3 | from openeo_gfmap.backend import cdse_connection
4 |
5 | from worldcereal.utils.models import load_model_onnx
6 | from worldcereal.utils.upload import deploy_model
7 |
8 |
9 | def test_deploy_model():
10 | """Simple test to deploy a CatBoost model and load it back."""
11 | model = CatBoostClassifier(iterations=10).fit(X=[[1, 2], [3, 4]], y=[0, 1])
12 | presigned_uri = deploy_model(cdse_connection(), model)
13 | model = load_model_onnx(presigned_uri)
14 |
15 | # Compare model predictions with the original targets
16 | np.testing.assert_array_equal(
17 | model.run(None, {"features": [[1, 2], [3, 4]]})[0], [0, 1]
18 | )
19 |
--------------------------------------------------------------------------------
/tests/worldcerealtests/testresources/spatial_extent.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": "FeatureCollection",
3 | "features": [
4 | {
5 | "type": "Feature",
6 | "geometry": {
7 | "type": "Polygon",
8 | "coordinates": [
9 | [
10 | [
11 | 44.433631,
12 | 51.317362
13 | ],
14 | [
15 | 44.432274,
16 | 51.427238
17 | ],
18 | [
19 | 44.69808,
20 | 51.428224
21 | ],
22 | [
23 | 44.698802,
24 | 51.318344
25 | ],
26 | [
27 | 44.433631,
28 | 51.317362
29 | ]
30 | ]
31 | ]
32 | },
33 | "properties": {}
34 | }
35 | ]
36 | }
--------------------------------------------------------------------------------
/tests/worldcerealtests/testresources/test_public_extractions.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/tests/worldcerealtests/testresources/test_public_extractions.parquet
--------------------------------------------------------------------------------
/tests/worldcerealtests/testresources/worldcereal_cropland_classification.nc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/tests/worldcerealtests/testresources/worldcereal_cropland_classification.nc
--------------------------------------------------------------------------------
/tests/worldcerealtests/testresources/worldcereal_croptype_classification.nc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/tests/worldcerealtests/testresources/worldcereal_croptype_classification.nc
--------------------------------------------------------------------------------
/tests/worldcerealtests/testresources/worldcereal_preprocessed_inputs.nc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/tests/worldcerealtests/testresources/worldcereal_preprocessed_inputs.nc
--------------------------------------------------------------------------------
/tests/worldcerealtests/testresources/worldcereal_private_extractions_dummy.parquet/ref_id=2021_BEL_LPIS-Flanders_POLY_110/data_0.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WorldCereal/worldcereal-classification/66e6df14770265f5c21f87946d82311fc7fa12f3/tests/worldcerealtests/testresources/worldcereal_private_extractions_dummy.parquet/ref_id=2021_BEL_LPIS-Flanders_POLY_110/data_0.parquet
--------------------------------------------------------------------------------