├── docs ├── static │ ├── crim.png │ ├── nrcan.png │ ├── stac_mlm.png │ ├── terradue.png │ ├── wherobots.png │ └── sigspatial_2024_mlm.pdf └── legacy │ ├── dlm.md │ └── ml-model.md ├── stac_model ├── torch │ ├── __init__.py │ ├── base.py │ └── utils.py ├── __init__.py ├── __main__.py ├── runtime.py ├── input.py ├── output.py ├── base.py ├── schema.py └── examples.py ├── .github ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── question.md │ ├── feature_request.md │ └── bug_report.md ├── workflows │ ├── test.yaml │ ├── release-drafter.yml │ ├── citation.yaml │ ├── publish.yaml │ └── stac-model.yml ├── .stale.yml ├── release-drafter.yml ├── dependabot.yml ├── PULL_REQUEST_TEMPLATE.md └── remark.yaml ├── .remarkignore ├── tests ├── torch │ ├── test_metadata.py │ ├── test_unet_mlm.py │ ├── metadata.yaml │ └── test_export.py ├── test_examples.py ├── conftest.py └── test_stac_model.py ├── .pre-commit-config.yaml ├── .editorconfig ├── package.json ├── examples ├── torch │ └── mlm-metadata.yaml ├── collection.json ├── item_basic.json ├── item_pytorch_geo_unet.json ├── item_bands_expression.json ├── item_multi_io.json ├── item_datacube_variables.json ├── item_raster_bands.json └── item_eo_bands.json ├── stac-model.bump.toml ├── .safety-policy.yml ├── CONTRIBUTING.md ├── Makefile ├── README_STAC_MODEL.md ├── pyproject.toml ├── LICENSE └── CITATION.cff /docs/static/crim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stac-extensions/mlm/HEAD/docs/static/crim.png -------------------------------------------------------------------------------- /docs/static/nrcan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stac-extensions/mlm/HEAD/docs/static/nrcan.png -------------------------------------------------------------------------------- /docs/static/stac_mlm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stac-extensions/mlm/HEAD/docs/static/stac_mlm.png -------------------------------------------------------------------------------- /docs/static/terradue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stac-extensions/mlm/HEAD/docs/static/terradue.png -------------------------------------------------------------------------------- /docs/static/wherobots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stac-extensions/mlm/HEAD/docs/static/wherobots.png -------------------------------------------------------------------------------- /docs/static/sigspatial_2024_mlm.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stac-extensions/mlm/HEAD/docs/static/sigspatial_2024_mlm.pdf -------------------------------------------------------------------------------- /stac_model/torch/__init__.py: -------------------------------------------------------------------------------- 1 | # noqa: F401 2 | from stac_model.torch.export import from_torch 3 | 4 | __all__ = ["from_torch"] 5 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | # Configuration: https://help.github.com/en/github/building-a-strong-community/configuring-issue-templates-for-your-repository 2 | 
3 | blank_issues_enabled: false 4 | -------------------------------------------------------------------------------- /stac_model/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | A PydanticV2/PySTAC validation and serialization library for the STAC Machine Learning Model Extension. 3 | """ 4 | 5 | from importlib import metadata 6 | 7 | try: 8 | __version__ = metadata.version("stac-model") 9 | except metadata.PackageNotFoundError: 10 | __version__ = "unknown" 11 | -------------------------------------------------------------------------------- /.remarkignore: -------------------------------------------------------------------------------- 1 | # To save time scanning 2 | .idea/ 3 | .vscode/ 4 | .tox/ 5 | .git/ 6 | .github/**/*.yaml 7 | .github/**/*.yml 8 | *.egg-info/ 9 | build/ 10 | dist/ 11 | downloads/ 12 | env/ 13 | 14 | # actual items to ignore 15 | .pytest_cache/ 16 | node_modules/ 17 | docs/_build/ 18 | docs/build/ 19 | 20 | # potentially conflicting dev installs 21 | stac-mlm/ 22 | mlm/ 23 | -------------------------------------------------------------------------------- /tests/torch/test_metadata.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from stac_model.schema import MLModelProperties 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "mlm_example", 10 | [os.path.join("torch", "mlm-metadata.yaml")], 11 | indirect=True, 12 | ) 13 | def test_mlm_metadata_only_yaml_validation(mlm_example): 14 | MLModelProperties.model_validate(mlm_example["properties"]) 15 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Check Markdown and Examples 2 | on: [push, pull_request] 3 | jobs: 4 | check-docs: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/setup-node@v6 8 | with: 9 | node-version: 'lts/*' 10 | #cache: npm 11 | - uses: actions/checkout@v6 12 | - run: | 13 | npm install 14 | npm list 15 | npm test 16 | -------------------------------------------------------------------------------- /stac_model/torch/base.py: -------------------------------------------------------------------------------- 1 | from typing_extensions import TypedDict 2 | 3 | import torch 4 | 5 | from ..base import Path 6 | 7 | ExtraFiles = TypedDict("ExtraFiles", {"mlm-metadata": str}, total=False) 8 | 9 | 10 | class ExportedPrograms(TypedDict, total=False): 11 | model: torch.export.ExportedProgram 12 | transforms: torch.export.ExportedProgram 13 | 14 | 15 | class AOTIFiles(TypedDict, total=False): 16 | model: list[Path] 17 | transforms: list[Path] 18 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | default_stages: [commit, push] 5 | repos: 6 | - repo: https://github.com/pre-commit/pre-commit-hooks 7 | rev: v4.5.0 8 | hooks: 9 | - id: check-yaml 10 | - id: end-of-file-fixer 11 | - repo: https://github.com/astral-sh/ruff-pre-commit 12 | rev: 'v0.12.4' 13 | hooks: 14 | - id: ruff 15 | pass_filenames: false 16 | args: 17 | - --config=pyproject.toml 18 | -------------------------------------------------------------------------------- /.github/workflows/release-drafter.yml: 
-------------------------------------------------------------------------------- 1 | name: Release Drafter 2 | 3 | on: 4 | push: 5 | # branches to consider in the event; optional, defaults to all 6 | branches: 7 | - main 8 | 9 | jobs: 10 | update_release_draft: 11 | runs-on: ubuntu-latest 12 | steps: 13 | # Drafts your next Release notes as Pull Requests are merged into "master" 14 | - uses: release-drafter/release-drafter@v6.1.0 15 | env: 16 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 17 | -------------------------------------------------------------------------------- /.github/workflows/citation.yaml: -------------------------------------------------------------------------------- 1 | name: Check Citation Format 2 | on: 3 | push: 4 | paths: 5 | - CITATION.cff 6 | pull_request: 7 | paths: 8 | - CITATION.cff 9 | jobs: 10 | check-citation: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v6 14 | - name: Check whether the citation metadata from CITATION.cff is valid 15 | uses: citation-file-format/cffconvert-github-action@2.0.0 16 | with: 17 | args: "--validate" 18 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # Check http://editorconfig.org for more information 2 | # This is the main config file for this project: 3 | root = true 4 | 5 | [*] 6 | charset = utf-8 7 | end_of_line = lf 8 | insert_final_newline = true 9 | indent_style = space 10 | indent_size = 2 11 | trim_trailing_whitespace = true 12 | 13 | [*.{py, pyi}] 14 | indent_style = space 15 | indent_size = 4 16 | 17 | [Makefile] 18 | indent_style = tab 19 | 20 | [*.md] 21 | trim_trailing_whitespace = false 22 | 23 | [*.{diff,patch}] 24 | trim_trailing_whitespace = false 25 | -------------------------------------------------------------------------------- /docs/legacy/dlm.md: -------------------------------------------------------------------------------- 1 | # Deep Learning Model (DLM) Extension 2 | 3 | 4 | 5 | > [!WARNING] 6 | > This is legacy documentation reference of [Deep Learning Model extension](https://github.com/crim-ca/dlm-extension) 7 | > preceding the current Machine Learning Model (MLM) extension. 8 | 9 | 10 | 11 | Check the original [Technical Report](https://github.com/crim-ca/CCCOT03/raw/main/CCCOT03_Rapport%20Final_FINAL_EN.pdf). 12 | 13 | ![Image Description](https://i.imgur.com/cVAg5sA.png) 14 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: ❓Question 3 | about: Ask a question about this project 🎓 4 | title: '' 5 | labels: question, needs-triage 6 | assignees: 7 | --- 8 | 9 | ## Checklist 10 | 11 | 12 | 13 | - [ ] I've searched the project's [`issues`](..), looking for the following terms: 14 | - [...] 15 | 16 | ## :question: Question 17 | 18 | 19 | 20 | How can I [...]? 21 | 22 | Is it possible to [...]? 
23 | 24 | ## :paperclip: Additional context 25 | 26 | 27 | -------------------------------------------------------------------------------- /tests/test_examples.py: -------------------------------------------------------------------------------- 1 | import pystac 2 | 3 | from stac_model.schema import SCHEMA_URI 4 | 5 | # ignore typing errors introduced by generic JSON manipulation errors 6 | # mypy: disable_error_code="arg-type,call-overload,index,union-attr" 7 | 8 | def test_model_metadata_to_dict(eurosat_resnet): 9 | assert eurosat_resnet.item.to_dict() 10 | 11 | 12 | def test_validate_model_metadata(eurosat_resnet): 13 | assert pystac.read_dict(eurosat_resnet.item.to_dict()) 14 | 15 | 16 | def test_validate_model_against_schema(eurosat_resnet, mlm_validator): 17 | mlm_item = pystac.read_dict(eurosat_resnet.item.to_dict()) 18 | validated = pystac.validation.validate(mlm_item, validator=mlm_validator) 19 | assert SCHEMA_URI in validated 20 | -------------------------------------------------------------------------------- /tests/torch/test_unet_mlm.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import pytest 5 | 6 | pytest.importorskip("torchgeo") 7 | 8 | from stac_model.examples import unet_mlm 9 | 10 | 11 | def test_unet_mlm_matches_example_json(): 12 | try: 13 | item = unet_mlm().item.to_dict() 14 | except ModuleNotFoundError as e: 15 | if e.name == "torchgeo": 16 | pytest.skip("torchgeo is not installed") 17 | raise 18 | 19 | json_path = Path(__file__).resolve().parents[2] / "examples" / "item_pytorch_geo_unet.json" 20 | with open(json_path, "r", encoding="utf-8") as f: 21 | expected = json.load(f) 22 | 23 | assert item == expected, "Generated STAC Item does not match the saved example." 24 | -------------------------------------------------------------------------------- /.github/.stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 120 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 30 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - pinned 8 | - security 9 | # Label to use when marking an issue as stale 10 | staleLabel: stale 11 | # Comment to post when marking an issue as stale. Set to `false` to disable 12 | markComment: > 13 | This issue has been automatically marked as stale because it has not had 14 | recent activity. It will be closed if no further activity occurs in 30 days. Thank you 15 | for your contributions. 16 | # Comment to post when closing a stale issue. 
Set to `false` to disable 17 | closeComment: false 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🚀 Feature request 3 | about: Suggest an idea for this project 🏖 4 | title: '' 5 | labels: enhancement, needs-triage 6 | assignees: 7 | --- 8 | 9 | ## :rocket: Feature Request 10 | 11 | 12 | 13 | ## :sound: Motivation 14 | 15 | 19 | 20 | ## :satellite: Alternatives 21 | 22 | 23 | 24 | ## :paperclip: Additional context 25 | 26 | 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🐛 Bug report 3 | about: If something isn't working 🔧 4 | title: '' 5 | labels: bug, needs-triage 6 | assignees: 7 | --- 8 | 9 | ## :bug: Bug Report 10 | 11 | 12 | 13 | ## :microscope: How To Reproduce 14 | 15 | Steps to reproduce the behavior: 16 | 17 | 1. ... 18 | 19 | ### Code sample 20 | 21 | 22 | 23 | ### Environment 24 | 25 | * OS: (e.g. Linux / Windows / macOS) 26 | * Python version 27 | * stac-model version 28 | 29 | 30 | ## :chart_with_upwards_trend: Expected behavior 31 | 32 | 33 | 34 | ## :paperclip: Additional context 35 | 36 | 37 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | # Release drafter configuration https://github.com/release-drafter/release-drafter#configuration 2 | # Emojis were chosen to match the https://gitmoji.dev/ 3 | 4 | name-template: "v$NEXT_PATCH_VERSION" 5 | tag-template: "v$NEXT_PATCH_VERSION" 6 | 7 | categories: 8 | - title: ":rocket: Features" 9 | labels: [enhancement, feature] 10 | - title: ":wrench: Fixes & Refactoring" 11 | labels: [bug, refactoring, bugfix, fix] 12 | - title: ":package: Build System & CI/CD" 13 | labels: [build, ci, testing] 14 | - title: ":boom: Breaking Changes" 15 | labels: [breaking] 16 | - title: ":memo: Documentation" 17 | labels: [documentation] 18 | - title: ":arrow_up: Dependencies updates" 19 | labels: [dependencies] 20 | 21 | template: | 22 | ## What's Changed 23 | 24 | $CHANGES 25 | 26 | ## :busts_in_silhouette: List of contributors 27 | 28 | $CONTRIBUTORS 29 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # Configuration: https://dependabot.com/docs/config-file/ 2 | # Docs: https://docs.github.com/en/github/administering-a-repository/keeping-your-dependencies-updated-automatically 3 | 4 | version: 2 5 | 6 | updates: 7 | - package-ecosystem: "pip" 8 | directory: "/" 9 | schedule: 10 | interval: "monthly" 11 | allow: 12 | - dependency-type: "all" 13 | commit-message: 14 | prefix: ":arrow_up:" 15 | open-pull-requests-limit: 5 16 | 17 | - package-ecosystem: "github-actions" 18 | directory: "/" 19 | schedule: 20 | interval: "monthly" 21 | allow: 22 | - dependency-type: "all" 23 | commit-message: 24 | prefix: ":arrow_up:" 25 | open-pull-requests-limit: 5 26 | 27 | - package-ecosystem: "docker" 28 | directory: "/docker" 29 | schedule: 30 | interval: "monthly" 31 | allow: 32 | - dependency-type: "all" 33 | commit-message: 34 | prefix: ":arrow_up:" 35 | open-pull-requests-limit: 5 36 | 
-------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "stac-mlm", 3 | "version": "1.5.0", 4 | "scripts": { 5 | "test": "npm run check-markdown && npm run check-examples", 6 | "check-markdown": "remark . -f -r .github/remark.yaml -i .remarkignore", 7 | "format-markdown": "remark . -f -r .github/remark.yaml -i .remarkignore -o", 8 | "check-examples": "stac-node-validator . --lint --verbose --schemaMap https://stac-extensions.github.io/mlm/v1.5.0/schema.json=./json-schema/schema.json", 9 | "format-examples": "stac-node-validator . --format --schemaMap https://stac-extensions.github.io/mlm/v1.5.0/schema.json=./json-schema/schema.json" 10 | }, 11 | "dependencies": { 12 | "remark-cli": "^8.0.0", 13 | "remark-gfm": "^4.0.0", 14 | "remark-lint": "^7.0.0", 15 | "remark-lint-no-html": "^2.0.0", 16 | "remark-math": "^6.0.0", 17 | "remark-preset-lint-consistent": "^3.0.0", 18 | "remark-preset-lint-markdown-style-guide": "^3.0.0", 19 | "remark-preset-lint-recommended": "^4.0.0", 20 | "remark-validate-links": "^10.0.0", 21 | "stac-node-validator": "^1.0.0" 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | name: Publish JSON Schema or stac-model package via Github Release 2 | on: 3 | release: 4 | types: [published] 5 | jobs: 6 | deploy-schema: 7 | if: startsWith(github.ref, 'refs/tags/v') 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Inject env variables 11 | uses: rlespinasse/github-slug-action@v3.x 12 | - uses: actions/checkout@v6 13 | - name: deploy JSON Schema for version ${{ env.GITHUB_REF_SLUG }} 14 | uses: peaceiris/actions-gh-pages@v4 15 | with: 16 | github_token: ${{ secrets.GITHUB_TOKEN }} 17 | publish_dir: json-schema 18 | destination_dir: ${{ env.GITHUB_REF_SLUG }} 19 | publish-pypi: 20 | if: startsWith(github.ref, 'refs/tags/stac-model-v') 21 | runs-on: ubuntu-latest 22 | steps: 23 | - uses: actions/checkout@v6 24 | - name: Set up Python 25 | uses: actions/setup-python@v6.1.0 26 | with: 27 | python-version: "3.10" 28 | - name: Install uv 29 | run: make setup 30 | - name: Publish stac-model to PyPI 31 | run: | 32 | uv build 33 | uv publish --username __token__ --password ${{ secrets.PYPI_SECRET }} 34 | -------------------------------------------------------------------------------- /.github/workflows/stac-model.yml: -------------------------------------------------------------------------------- 1 | name: Check Python Linting and Tests 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | stac-model: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ["3.10", "3.11", "3.12", "3.13"] 11 | env: 12 | COVERAGE_FILE: .coverage.${{ matrix.python-version }} 13 | 14 | steps: 15 | - uses: actions/checkout@v6 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v6.1.0 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | 21 | - name: Install uv 22 | run: make setup 23 | 24 | - name: Set up cache 25 | uses: actions/cache@v4.3.0 26 | with: 27 | path: .venv 28 | key: venv-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('uv.lock') }} 29 | 30 | - name: Install dependencies 31 | run: make install-dev 32 | if: ${{ matrix.python-version == '3.10' }} 33 | 34 | - name: Install dependencies with extras 35 | 
run: make install-dev-extras 36 | if: ${{ matrix.python-version != '3.10' }} 37 | 38 | - name: Display Packages 39 | run: pip list 40 | 41 | - name: Run checks 42 | if: ${{ matrix.python-version != '3.10' }} 43 | run: make lint-all 44 | 45 | - name: Run tests 46 | run: | 47 | make test 48 | -------------------------------------------------------------------------------- /examples/torch/mlm-metadata.yaml: -------------------------------------------------------------------------------- 1 | $schema: "https://stac-extensions.github.io/mlm/v1.5.0/schema.json" 2 | properties: 3 | name: Unet 4 | architecture: Unet 5 | artifact_type: torch.export.save 6 | framework: torch 7 | framework_version: 2.7.0 8 | accelerator: cuda 9 | total_parameters: 1234567 10 | tasks: [semantic-segmentation] 11 | input: 12 | - name: Imagery 13 | bands: [red, blue, green] 14 | input: 15 | shape: [-1, 3, 512, 512] 16 | dim_order: [batch, channel, height, width] 17 | data_type: float32 18 | pre_processing_function: 19 | format: torch.export.load 20 | expression: 21 | id: export-transform 22 | name: transforms 23 | href: path/to/archive.pt2 24 | type: "application/octet-stream; framework=pytorch; profile=ExportedProgram" 25 | output: 26 | - name: segmentation-output 27 | tasks: [semantic-segmentation] 28 | result: 29 | shape: [-1, 6, 512, 512] 30 | dim_order: [batch, classes, height, width] 31 | data_type: float32 32 | classification:classes: 33 | - value: 0 34 | name: Class 0 35 | - value: 1 36 | name: Class 1 37 | - value: 2 38 | name: Class 2 39 | - value: 3 40 | name: Class 3 41 | - value: 4 42 | name: Class 4 43 | - value: 5 44 | name: Class 5 45 | -------------------------------------------------------------------------------- /stac_model/__main__.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import typer 4 | from rich.console import Console 5 | 6 | from stac_model import __version__ 7 | from stac_model.examples import eurosat_resnet 8 | from stac_model.schema import ItemMLModelExtension 9 | 10 | app = typer.Typer( 11 | name="stac-model", 12 | help="A PydanticV2 validation and serialization library for the STAC Machine Learning Model Extension", 13 | add_completion=False, 14 | ) 15 | console = Console() 16 | 17 | 18 | def version_callback(print_version: bool) -> None: 19 | """Print the version of the package.""" 20 | if print_version: 21 | console.print(f"[yellow]stac-model[/] version: [bold blue]{__version__}[/]") 22 | raise typer.Exit() 23 | 24 | 25 | @app.command(name="") 26 | def main( 27 | print_version: bool = typer.Option( 28 | None, 29 | "-v", 30 | "--version", 31 | callback=version_callback, 32 | is_eager=True, 33 | help="Prints the version of the stac-model package.", 34 | ), 35 | ) -> ItemMLModelExtension: 36 | """Generate example spec.""" 37 | ml_model_meta = eurosat_resnet() 38 | with open("example.json", mode="w", encoding="utf-8") as json_file: 39 | json.dump(ml_model_meta.item.to_dict(), json_file, indent=4) 40 | print("Example model metadata written to ./example.json.") 41 | return ml_model_meta 42 | 43 | 44 | if __name__ == "__main__": 45 | app() 46 | -------------------------------------------------------------------------------- /tests/torch/metadata.yaml: -------------------------------------------------------------------------------- 1 | # This file is used in torch export tests to check that the metadata is 2 | # validated using the MLModelMetadata Pydantic model and the export 3 | # functions work as expected. 
4 | $schema: "https://stac-extensions.github.io/mlm/v1.5.0/schema.json" 5 | properties: 6 | name: FTW 3 Class Unet 7 | architecture: torchgeo.models.unet 8 | artifact_type: torch.export.save 9 | framework: torch 10 | framework_version: 2.8.0 11 | accelerator: cuda 12 | total_parameters: 1234567 13 | tasks: [semantic-segmentation] 14 | input: 15 | - name: Imagery 16 | bands: [B04, B03, B02, B05, B04, B03, B02, B05] 17 | input: 18 | shape: [-1, 8, -1, -1] 19 | dim_order: [batch, channel, height, width] 20 | data_type: float32 21 | pre_processing_function: 22 | format: torch.export.load 23 | expression: 24 | id: export-transform 25 | name: transforms 26 | href: path/to/archive.pt2 27 | type: "application/octet-stream; framework=pytorch; profile=ExportedProgram" 28 | output: 29 | - name: segmentation-output 30 | tasks: [semantic-segmentation] 31 | result: 32 | shape: [-1, 3, -1, -1] 33 | dim_order: [batch, classes, height, width] 34 | data_type: float32 35 | classification:classes: 36 | - value: 0 37 | name: background 38 | - value: 1 39 | name: field 40 | - value: 2 41 | name: field-boundary 42 | -------------------------------------------------------------------------------- /stac-model.bump.toml: -------------------------------------------------------------------------------- 1 | 2 | [tool.bumpversion] 3 | # NOTE: 4 | # This is the bump definition for the 'stac-model' package. 5 | # For the MLM specification, refer to the main 'pyproject.toml'. 6 | # they are actually intented for versioning the MLM specification itself. 7 | # To version 'stac-model', use the 'bump-my-version bump' operation with this file. 8 | # See also https://github.com/stac-extensions/mlm/blob/main/CONTRIBUTING.md#building-and-releasing 9 | current_version = "0.4.0" 10 | parse = "(?P\\d+)\\.(?P\\d+)\\.(?P\\d+)" 11 | serialize = ["{major}.{minor}.{patch}"] 12 | search = "{current_version}" 13 | replace = "{new_version}" 14 | regex = false 15 | ignore_missing_version = true 16 | ignore_missing_files = false 17 | tag = true 18 | sign_tags = false 19 | tag_name = "stac-model-v{new_version}" 20 | tag_message = "Bump version: stac-model {current_version} → {new_version}" 21 | allow_dirty = false 22 | commit = true 23 | commit_args = "--no-verify" 24 | message = "Bump version: stac-model {current_version} → {new_version}" 25 | 26 | [[tool.bumpversion.files]] 27 | filename = "uv.lock" 28 | search = """ 29 | name = "stac-model" 30 | version = "{current_version}" 31 | """ 32 | replace = """ 33 | name = "stac-model" 34 | version = "{new_version}" 35 | """ 36 | 37 | [[tool.bumpversion.files]] 38 | filename = "pyproject.toml" 39 | search = """ 40 | name = "stac-model" 41 | version = "{current_version}" 42 | """ 43 | replace = """ 44 | name = "stac-model" 45 | version = "{new_version}" 46 | """ 47 | -------------------------------------------------------------------------------- /stac_model/runtime.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Annotated, Literal, Union 3 | 4 | from pydantic import AliasChoices, Field 5 | 6 | from stac_model.base import MLMBaseModel, OmitIfNone 7 | 8 | 9 | class AcceleratorEnum(str, Enum): 10 | amd64 = "amd64" 11 | cpu = "cpu" 12 | cuda = "cuda" 13 | xla = "xla" 14 | amd_rocm = "amd-rocm" 15 | intel_ipex_cpu = "intel-ipex-cpu" 16 | intel_ipex_gpu = "intel-ipex-gpu" 17 | macos_arm = "macos-arm" 18 | 19 | def __str__(self): 20 | return self.value 21 | 22 | 23 | AcceleratorName = Literal[ 24 | "amd64", 25 | "cpu", 26 
| "cuda", 27 | "xla", 28 | "amd-rocm", 29 | "intel-ipex-cpu", 30 | "intel-ipex-gpu", 31 | "macos-arm", 32 | ] 33 | 34 | AcceleratorType = Union[AcceleratorName, AcceleratorEnum] 35 | 36 | 37 | class Runtime(MLMBaseModel): 38 | framework: Annotated[str | None, OmitIfNone] = Field(default=None) 39 | framework_version: Annotated[str | None, OmitIfNone] = Field(default=None) 40 | file_size: Annotated[int | None, OmitIfNone] = Field( 41 | alias="file:size", 42 | validation_alias=AliasChoices("file_size", "file:size"), 43 | default=None, 44 | ) 45 | memory_size: Annotated[int | None, OmitIfNone] = Field(default=None) 46 | batch_size_suggestion: Annotated[int | None, OmitIfNone] = Field(default=None) 47 | 48 | accelerator: AcceleratorType | None = Field(default=None) 49 | accelerator_constrained: bool = Field(default=False) 50 | accelerator_summary: Annotated[str | None, OmitIfNone] = Field(default=None) 51 | accelerator_count: Annotated[int | None, OmitIfNone] = Field(default=None, ge=1) 52 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | 4 | 5 | ## Related Issue 6 | 7 | 8 | 9 | ## Type of Change 10 | 11 | 12 | 13 | - [ ] :books: Examples, docs, tutorials or dependencies update; 14 | - [ ] :wrench: Bug fix (non-breaking change which fixes an issue); 15 | - [ ] :clinking_glasses: Improvement (non-breaking change which improves an existing feature); 16 | - [ ] :rocket: New feature (non-breaking change which adds functionality); 17 | - [ ] :boom: Breaking change (fix or feature that would cause existing functionality to change); 18 | - [ ] :closed_lock_with_key: Security fix. 19 | 20 | ## Checklist 21 | 22 | 23 | 24 | - [ ] I've read the [`CONTRIBUTING.md`][contrib] guide; 25 | - [ ] I've updated the [`CHANGELOG.md`][changes] with provided changes; 26 | - [ ] I've updated the [`README.md`][readme] and/or [`best-practices.md`][best-practices] as applicable with new features; 27 | - [ ] I've updated the code style using `make check`; 28 | - [ ] I've written tests for all new methods and classes that I created; 29 | - [ ] I've written the docstring in `Google` format for all the methods and classes that I used. 
30 | 31 | [contrib]: https://github.com/stac-extensions/mlm/blob/main/CONTRIBUTING.md 32 | [changes]: https://github.com/stac-extensions/mlm/blob/main/CHANGELOG.md 33 | [readme]: https://github.com/stac-extensions/mlm/blob/main/README.md 34 | [best-practices]: https://github.com/stac-extensions/mlm/blob/main/best-practices.md 35 | -------------------------------------------------------------------------------- /.github/remark.yaml: -------------------------------------------------------------------------------- 1 | settings: 2 | listItemIndent: '1' 3 | emphasis: '*' 4 | spacedTable: false 5 | paddedTable: true 6 | stringify: 7 | entities: false 8 | escape: false 9 | plugins: 10 | # Check links 11 | - validate-links 12 | # Apply some recommended defaults for consistency 13 | - remark-preset-lint-consistent 14 | - remark-preset-lint-recommended 15 | - - lint-no-html 16 | - false 17 | # General formatting 18 | - - remark-lint-emphasis-marker 19 | - '*' 20 | - remark-lint-hard-break-spaces 21 | - remark-lint-blockquote-indentation 22 | - remark-lint-no-consecutive-blank-lines 23 | - - remark-lint-maximum-line-length 24 | - 120 25 | - remark-lint-no-literal-urls 26 | # GFM - autolink literals, footnotes, strikethrough, tables, tasklist 27 | - remark-gfm 28 | # Math Expression 29 | - remark-math 30 | # Code 31 | - remark-lint-fenced-code-flag 32 | - remark-lint-fenced-code-marker 33 | - remark-lint-no-shell-dollars 34 | - - remark-lint-code-block-style 35 | - 'fenced' 36 | # Headings 37 | - remark-lint-heading-increment 38 | - remark-lint-no-multiple-toplevel-headings 39 | - remark-lint-no-heading-punctuation 40 | - - remark-lint-maximum-heading-length 41 | - 70 42 | - - remark-lint-heading-style 43 | - atx 44 | - - remark-lint-no-shortcut-reference-link 45 | - false 46 | # Lists 47 | - - remark-lint-list-item-bullet-indent 48 | - 'one' 49 | - remark-lint-ordered-list-marker-style 50 | - remark-lint-ordered-list-marker-value 51 | - remark-lint-checkbox-character-style 52 | - - remark-lint-unordered-list-marker-style 53 | - '-' 54 | - - remark-lint-list-item-content-indent 55 | - 1 56 | - - remark-lint-list-item-indent 57 | - 'space' 58 | # Tables 59 | - remark-lint-table-pipes 60 | - remark-lint-table-cell-padding 61 | -------------------------------------------------------------------------------- /examples/collection.json: -------------------------------------------------------------------------------- 1 | { 2 | "stac_version": "1.0.0", 3 | "stac_extensions": [ 4 | "https://stac-extensions.github.io/item-assets/v1.0.0/schema.json" 5 | ], 6 | "type": "Collection", 7 | "id": "ml-model-examples", 8 | "title": "Machine Learning Model examples", 9 | "description": "Collection of items contained in the Machine Learning Model examples.", 10 | "license": "Apache-2.0", 11 | "extent": { 12 | "spatial": { 13 | "bbox": [ 14 | [ 15 | -7.882190080512502, 16 | 37.13739173208318, 17 | 27.911651652899923, 18 | 58.21798141355221 19 | ] 20 | ] 21 | }, 22 | "temporal": { 23 | "interval": [ 24 | [ 25 | "1900-01-01T00:00:00Z", 26 | "9999-12-31T23:59:59Z" 27 | ] 28 | ] 29 | } 30 | }, 31 | "item_assets": { 32 | "weights": { 33 | "title": "model weights", 34 | "roles": [ 35 | "mlm:model", 36 | "mlm:weights" 37 | ] 38 | } 39 | }, 40 | "summaries": { 41 | "datetime": { 42 | "minimum": "1900-01-01T00:00:00Z", 43 | "maximum": "9999-12-31T23:59:59Z" 44 | } 45 | }, 46 | "links": [ 47 | { 48 | "href": "collection.json", 49 | "rel": "self" 50 | }, 51 | { 52 | "href": "item_basic.json", 53 | "rel": "item" 54 | }, 55 | { 56 | 
"href": "item_bands_expression.json", 57 | "rel": "item" 58 | }, 59 | { 60 | "href": "item_datacube_variables.json", 61 | "rel": "item" 62 | }, 63 | { 64 | "href": "item_eo_bands.json", 65 | "rel": "item" 66 | }, 67 | { 68 | "href": "item_eo_and_raster_bands.json", 69 | "rel": "item" 70 | }, 71 | { 72 | "href": "item_eo_bands_summarized.json", 73 | "rel": "item" 74 | }, 75 | { 76 | "href": "item_raster_bands.json", 77 | "rel": "item" 78 | }, 79 | { 80 | "href": "item_multi_io.json", 81 | "rel": "item" 82 | }, 83 | { 84 | "href": "item_pytorch_geo_unet.json", 85 | "rel": "item" 86 | } 87 | ] 88 | } 89 | -------------------------------------------------------------------------------- /stac_model/input.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, Literal, TypeAlias, Union 2 | from typing_extensions import Self 3 | 4 | from pydantic import Field, model_validator 5 | 6 | from stac_model.base import ( 7 | DataType, 8 | MLMBaseModel, 9 | ModelBandsOrVariablesReferences, 10 | Number, 11 | OmitIfNone, 12 | ProcessingExpression, 13 | ) 14 | 15 | 16 | class InputStructure(MLMBaseModel): 17 | shape: list[Union[int, float]] = Field(min_length=1) 18 | dim_order: list[str] = Field(min_length=1) 19 | data_type: DataType 20 | 21 | @model_validator(mode="after") 22 | def validate_dimensions(self) -> Self: 23 | if len(self.shape) != len(self.dim_order): 24 | raise ValueError("Dimension order and shape must be of equal length for corresponding indices.") 25 | return self 26 | 27 | 28 | class ValueScalingClipMin(MLMBaseModel): 29 | type: Literal["clip-min"] = "clip-min" 30 | minimum: Number 31 | 32 | 33 | class ValueScalingClipMax(MLMBaseModel): 34 | type: Literal["clip-max"] = "clip-max" 35 | maximum: Number 36 | 37 | 38 | class ValueScalingClip(MLMBaseModel): 39 | type: Literal["clip"] = "clip" 40 | minimum: Number 41 | maximum: Number 42 | 43 | 44 | class ValueScalingMinMax(MLMBaseModel): 45 | type: Literal["min-max"] = "min-max" 46 | minimum: Number 47 | maximum: Number 48 | 49 | 50 | class ValueScalingZScore(MLMBaseModel): 51 | type: Literal["z-score"] = "z-score" 52 | mean: Number 53 | stddev: Number 54 | 55 | 56 | class ValueScalingOffset(MLMBaseModel): 57 | type: Literal["offset"] = "offset" 58 | value: Number 59 | 60 | 61 | class ValueScalingScale(MLMBaseModel): 62 | type: Literal["scale"] = "scale" 63 | value: Number 64 | 65 | 66 | class ValueScalingProcessingExpression(ProcessingExpression): 67 | type: Literal["processing"] = "processing" 68 | 69 | 70 | ValueScalingObject: TypeAlias = Union[ 71 | ValueScalingMinMax, 72 | ValueScalingZScore, 73 | ValueScalingClip, 74 | ValueScalingClipMin, 75 | ValueScalingClipMax, 76 | ValueScalingOffset, 77 | ValueScalingScale, 78 | ValueScalingProcessingExpression, 79 | None, 80 | ] 81 | 82 | ResizeType: TypeAlias = ( 83 | Literal[ 84 | "crop", 85 | "pad", 86 | "interpolation-nearest", 87 | "interpolation-linear", 88 | "interpolation-cubic", 89 | "interpolation-area", 90 | "interpolation-lanczos4", 91 | "interpolation-max", 92 | "wrap-fill-outliers", 93 | "wrap-inverse-map", 94 | ] 95 | | None 96 | ) 97 | 98 | 99 | class ModelInput(ModelBandsOrVariablesReferences): 100 | name: str 101 | input: InputStructure 102 | value_scaling: Annotated[list[ValueScalingObject] | None, OmitIfNone] = None 103 | resize_type: Annotated[ResizeType | None, OmitIfNone] = None 104 | pre_processing_function: ProcessingExpression | list[ProcessingExpression] | None = None 105 | 
-------------------------------------------------------------------------------- /.safety-policy.yml: -------------------------------------------------------------------------------- 1 | # Safety Security and License Configuration file 2 | # https://docs.safetycli.com/safety-docs/administration/safety-policy-files 3 | 4 | security: # configuration for the `safety check` command 5 | ignore-cvss-severity-below: 0 6 | ignore-cvss-unknown-severity: False 7 | ignore-vulnerabilities: 8 | 67599: 9 | reason: disputed pip feature not used by this project 10 | continue-on-vulnerability-error: False 11 | alert: # configuration for the `safety alert` command 12 | security: 13 | # Configuration specific to Safety's GitHub Issue alerting 14 | github-issue: 15 | # Same as for security - these allow controlling if this alert will fire based 16 | # on severity information. 17 | # default: not set 18 | # ignore-cvss-severity-below: 6 19 | # ignore-cvss-unknown-severity: False 20 | 21 | # Add a label to pull requests with the cvss severity, if available 22 | # label-severity: true 23 | 24 | # Add a label to pull requests, default is 'security' 25 | # requires private repo permissions, even on public repos 26 | # default: security 27 | labels: 28 | - security 29 | 30 | # Assign users to pull requests, default is not set 31 | # requires private repo permissions, even on public repos 32 | # default: empty 33 | # assignees: 34 | # - example-user 35 | 36 | # Prefix to give issues when creating them. Note that changing 37 | # this might cause duplicate issues to be created. 38 | # default: "[PyUp] " 39 | # issue-prefix: "[PyUp] " 40 | 41 | # Configuration specific to Safety's GitHub PR alerting 42 | github-pr: 43 | # Same as for security - these allow controlling if this alert will fire based 44 | # on severity information. 45 | # default: not set 46 | # ignore-cvss-severity-below: 6 47 | # ignore-cvss-unknown-severity: False 48 | 49 | # Set the default branch (ie, main, master) 50 | # default: empty, the default branch on GitHub 51 | branch: '' 52 | 53 | # Add a label to pull requests with the cvss severity, if available 54 | # default: true 55 | # label-severity: True 56 | 57 | # Add a label to pull requests, default is 'security' 58 | # requires private repo permissions, even on public repos 59 | # default: security 60 | labels: 61 | - security 62 | 63 | # Assign users to pull requests, default is not set 64 | # requires private repo permissions, even on public repos 65 | # default: empty 66 | # assignees: 67 | # - example-user 68 | 69 | # Configure the branch prefix for PRs created by this alert. 70 | # NB: Changing this will likely cause duplicate PRs. 
71 | # default: pyup/ 72 | branch-prefix: pyup/ 73 | 74 | # Set a global prefix for PRs 75 | # default: "[PyUp] " 76 | pr-prefix: "[PyUp] " 77 | -------------------------------------------------------------------------------- /stac_model/output.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, Any, cast 2 | 3 | from pydantic import AliasChoices, ConfigDict, Field, model_serializer 4 | from pystac.extensions.classification import Classification 5 | 6 | from stac_model.base import ( 7 | DataType, 8 | MLMBaseModel, 9 | ModelBandsOrVariablesReferences, 10 | ModelTask, 11 | OmitIfNone, 12 | ProcessingExpression, 13 | ) 14 | 15 | 16 | class ModelResult(MLMBaseModel): 17 | shape: list[int | float] = Field(..., min_length=1) 18 | dim_order: list[str] = Field(..., min_length=1) 19 | data_type: DataType 20 | 21 | 22 | # MLMClassification: TypeAlias = Annotated[ 23 | # Classification, 24 | # PlainSerializer( 25 | # lambda x: x.to_dict(), 26 | # when_used="json", 27 | # return_type=TypedDict( 28 | # "Classification", 29 | # { 30 | # "value": int, 31 | # "name": str, 32 | # "description": NotRequired[str], 33 | # "color_hint": NotRequired[str], 34 | # } 35 | # ) 36 | # ) 37 | # ] 38 | 39 | 40 | class MLMClassification(MLMBaseModel, Classification): 41 | @model_serializer() 42 | def model_dump(self, *_: Any, **__: Any) -> dict[str, Any]: 43 | return self.to_dict() # type: ignore[call-arg] 44 | 45 | def __init__( 46 | self, 47 | value: int, 48 | description: str | None = None, 49 | name: str | None = None, 50 | color_hint: str | None = None, 51 | ) -> None: 52 | Classification.__init__(self, {}) 53 | if not name and not description: 54 | raise ValueError("Class name or description is required!") 55 | self.apply( 56 | value=value, 57 | name=name or description, 58 | description=cast(str, description or name), 59 | color_hint=color_hint, 60 | ) 61 | 62 | def __hash__(self) -> int: 63 | return sum(map(hash, self.to_dict().items())) 64 | 65 | def __setattr__(self, key: str, value: Any) -> None: 66 | if key == "properties": 67 | Classification.__setattr__(self, key, value) 68 | else: 69 | MLMBaseModel.__setattr__(self, key, value) 70 | 71 | model_config = ConfigDict( 72 | populate_by_name=True, 73 | arbitrary_types_allowed=True, 74 | ) 75 | 76 | 77 | # class ClassObject(BaseModel): 78 | # value: int 79 | # name: str 80 | # description: Optional[str] = None 81 | # title: Optional[str] = None 82 | # color_hint: Optional[str] = None 83 | # nodata: Optional[bool] = False 84 | 85 | 86 | class ModelOutput(ModelBandsOrVariablesReferences): 87 | name: str 88 | tasks: set[ModelTask] 89 | result: ModelResult 90 | 91 | # NOTE: 92 | # Although it is preferable to have 'Set' to avoid duplicate, 93 | # it is more important to keep the order in this case, 94 | # which we would lose with 'Set'. 95 | # We also get some unhashable errors with 'Set', although 'MLMClassification' implements '__hash__'. 
96 | classes: Annotated[list[MLMClassification] | None, OmitIfNone] = Field( 97 | alias="classification:classes", 98 | validation_alias=AliasChoices("classification:classes", "classification_classes", "classes"), 99 | default=None, 100 | ) 101 | post_processing_function: ProcessingExpression | list[ProcessingExpression] | None = None 102 | 103 | model_config = ConfigDict( 104 | populate_by_name=True, 105 | ) 106 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | from typing import TYPE_CHECKING, Any, cast 5 | 6 | import pystac 7 | import pytest 8 | import yaml 9 | 10 | from stac_model.base import JSON 11 | from stac_model.examples import eurosat_resnet as make_eurosat_resnet 12 | from stac_model.schema import SCHEMA_URI 13 | 14 | if TYPE_CHECKING: 15 | from _pytest.fixtures import SubRequest 16 | 17 | TEST_DIR = os.path.dirname(__file__) 18 | EXAMPLES_DIR = os.path.abspath(os.path.join(TEST_DIR, "../examples")) 19 | JSON_SCHEMA_DIR = os.path.abspath(os.path.join(TEST_DIR, "../json-schema")) 20 | 21 | 22 | def get_all_stac_item_examples() -> list[str]: 23 | all_json = glob.glob("**/*.json", root_dir=EXAMPLES_DIR, recursive=True) 24 | all_geojson = glob.glob("**/*.geojson", root_dir=EXAMPLES_DIR, recursive=True) 25 | all_stac_items = [ 26 | path 27 | for path in all_json + all_geojson 28 | if os.path.splitext(os.path.basename(path))[0] not in ["collection", "catalog"] 29 | ] 30 | return all_stac_items 31 | 32 | 33 | @pytest.fixture(scope="session") 34 | def mlm_schema() -> JSON: 35 | with open(os.path.join(JSON_SCHEMA_DIR, "schema.json"), mode="r", encoding="utf-8") as schema_file: 36 | data = json.load(schema_file) 37 | return cast(JSON, data) 38 | 39 | 40 | @pytest.fixture(scope="session") 41 | def mlm_validator( 42 | request: "SubRequest", 43 | mlm_schema: dict[str, Any], 44 | ) -> pystac.validation.stac_validator.JsonSchemaSTACValidator: 45 | """ 46 | Update the :class:`pystac.validation.RegisteredValidator` with the local MLM JSON schema definition. 47 | 48 | Because the schema is *not yet* uploaded to the expected STAC schema URI, 49 | any call to :func:`pystac.validation.validate` or :meth:`pystac.stac_object.STACObject.validate` results 50 | in ``GetSchemaError`` when the schema retrieval is attempted by the validator.By adding the schema to the 51 | mapping beforehand, remote resolution can be bypassed temporarily. When evaluating modifications to the 52 | current schema, this also ensures that local changes are used instead of the remote reference. 53 | """ 54 | validator = pystac.validation.RegisteredValidator.get_validator() 55 | validator = cast(pystac.validation.stac_validator.JsonSchemaSTACValidator, validator) 56 | validator.schema_cache[SCHEMA_URI] = mlm_schema 57 | pystac.validation.RegisteredValidator.set_validator(validator) # apply globally to allow 'STACObject.validate()' 58 | return validator 59 | 60 | 61 | @pytest.fixture 62 | def mlm_example(request: "SubRequest") -> dict[str, JSON]: 63 | """ 64 | Fixture that loads an example STAC Item with MLM extension from the examples directory. 65 | 66 | Usage: 67 | 68 | ```python 69 | @pytest.mark.parametrize( 70 | "mlm_example", 71 | ["path/to/example1.json", "path/to/example2.yaml"], # or just the name if in 'EXAMPLES_DIR' 72 | indirect=True, 73 | ) 74 | def test_example(mlm_example: dict[str, JSON]) -> None: ... 
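    # Hypothetical extension (mirrors 'tests/test_examples.py'): the fixture can be combined with
    # the 'mlm_validator' fixture so the loaded example is checked against the local MLM JSON schema.
    @pytest.mark.parametrize("mlm_example", ["item_basic.json"], indirect=True)
    def test_example_schema(mlm_example: dict[str, JSON], mlm_validator) -> None:
        mlm_item = pystac.read_dict(mlm_example)
        assert SCHEMA_URI in pystac.validation.validate(mlm_item, validator=mlm_validator)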
75 | ``` 76 | """ 77 | with open(os.path.join(EXAMPLES_DIR, request.param), mode="r", encoding="utf-8") as example_file: 78 | if request.param.endswith(".json"): 79 | data = json.load(example_file) 80 | elif request.param.endswith(".yaml"): 81 | data = yaml.safe_load(example_file) 82 | else: 83 | raise ValueError(f"Unsupported file format for example: {request.param}") 84 | return cast(dict[str, JSON], data) 85 | 86 | 87 | @pytest.fixture(name="eurosat_resnet") 88 | def eurosat_resnet(): 89 | return make_eurosat_resnet() 90 | -------------------------------------------------------------------------------- /examples/item_basic.json: -------------------------------------------------------------------------------- 1 | { 2 | "stac_version": "1.0.0", 3 | "stac_extensions": [ 4 | "https://stac-extensions.github.io/mlm/v1.5.0/schema.json" 5 | ], 6 | "type": "Feature", 7 | "id": "example-model", 8 | "collection": "ml-model-examples", 9 | "geometry": { 10 | "type": "Polygon", 11 | "coordinates": [ 12 | [ 13 | [ 14 | -7.882190080512502, 15 | 37.13739173208318 16 | ], 17 | [ 18 | -7.882190080512502, 19 | 58.21798141355221 20 | ], 21 | [ 22 | 27.911651652899923, 23 | 58.21798141355221 24 | ], 25 | [ 26 | 27.911651652899923, 27 | 37.13739173208318 28 | ], 29 | [ 30 | -7.882190080512502, 31 | 37.13739173208318 32 | ] 33 | ] 34 | ] 35 | }, 36 | "bbox": [ 37 | -7.882190080512502, 38 | 37.13739173208318, 39 | 27.911651652899923, 40 | 58.21798141355221 41 | ], 42 | "properties": { 43 | "description": "Basic STAC Item with only the MLM extension and no other extension cross-references.", 44 | "datetime": null, 45 | "start_datetime": "1900-01-01T00:00:00Z", 46 | "end_datetime": "9999-12-31T23:59:59Z", 47 | "mlm:name": "example-model", 48 | "mlm:tasks": [ 49 | "classification" 50 | ], 51 | "mlm:architecture": "ResNet", 52 | "mlm:input": [ 53 | { 54 | "name": "Model with RGB input that does not refer to any band.", 55 | "bands": [], 56 | "input": { 57 | "shape": [ 58 | -1, 59 | 3, 60 | 64, 61 | 64 62 | ], 63 | "dim_order": [ 64 | "batch", 65 | "channel", 66 | "height", 67 | "width" 68 | ], 69 | "data_type": "float32" 70 | } 71 | } 72 | ], 73 | "mlm:output": [ 74 | { 75 | "name": "classification", 76 | "tasks": [ 77 | "classification" 78 | ], 79 | "result": { 80 | "shape": [ 81 | -1, 82 | 1 83 | ], 84 | "dim_order": [ 85 | "batch", 86 | "class" 87 | ], 88 | "data_type": "uint8" 89 | }, 90 | "classification_classes": [ 91 | { 92 | "value": 0, 93 | "name": "BACKGROUND", 94 | "description": "Background non-city.", 95 | "color_hint": [ 96 | 0, 97 | 0, 98 | 0 99 | ] 100 | }, 101 | { 102 | "value": 1, 103 | "name": "CITY", 104 | "description": "A city is detected.", 105 | "color_hint": [ 106 | 0, 107 | 0, 108 | 255 109 | ] 110 | } 111 | ] 112 | } 113 | ] 114 | }, 115 | "assets": { 116 | "model": { 117 | "href": "https://huggingface.co/example/model-card", 118 | "title": "Pytorch weights checkpoint", 119 | "description": "Example model.", 120 | "type": "text/html", 121 | "roles": [ 122 | "mlm:model" 123 | ], 124 | "mlm:artifact_type": "torch.save" 125 | } 126 | }, 127 | "links": [ 128 | { 129 | "rel": "collection", 130 | "href": "./collection.json", 131 | "type": "application/json" 132 | }, 133 | { 134 | "rel": "self", 135 | "href": "./item_basic.json", 136 | "type": "application/geo+json" 137 | } 138 | ] 139 | } 140 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 
to MLM specification or `stac-model` 2 | 3 | ## Project setup 4 | 5 | 1. If you don't have `uv` installed, run: 6 | 7 | ```bash 8 | make setup 9 | ``` 10 | 11 | This installs `uv` as a [standalone application][uv-install].
12 | For more details, see also the [`uv` documentation][uv-docs]. 13 | 14 | 2. Initialize project dependencies with `uv` and install `pre-commit` hooks: 15 | 16 | ```bash 17 | make install-dev 18 | make pre-commit-install 19 | ``` 20 | 21 | This will install project dependencies into the currently active environment. If you would like to 22 | use `uv`'s default behavior of managing a project-scoped environment, use `uv` commands directly to 23 | install dependencies. `uv sync` will install dependencies and dev dependencies in `.venv` and update the `uv.lock`. 24 | 25 | ## PR submission 26 | 27 | Before submitting your code please do the following steps: 28 | 29 | 1. Add any changes you want 30 | 31 | 2. Add tests for the new changes 32 | 33 | 3. Edit documentation if you have changed something significant 34 | 35 | You're then ready to run and test your contributions. 36 | 37 | 4. Run linting checks: 38 | 39 | ```bash 40 | make lint-all 41 | ``` 42 | 43 | 5. Run `tests` (including your new ones) with 44 | 45 | ```bash 46 | make test 47 | ``` 48 | 49 | 6. Upload your changes to your fork, then make a PR from there to the main repo: 50 | 51 | ```bash 52 | git checkout -b your-branch 53 | git add . 54 | git commit -m ":tada: Initial commit" 55 | git remote add origin https://github.com/your-fork/mlm-extension.git 56 | git push -u origin your-branch 57 | ``` 58 | 59 | ## Building and releasing 60 | 61 | 62 | 63 | > [!WARNING] 64 | > There are multiple types of releases for this repository:
65 | > 66 | > 1. Release for MLM specification (usually, this should include one for `stac-model` as well to support it) 67 | > 2. Release for `stac-model` only 68 | 69 | 70 | 71 | ### Building a new version of MLM specification 72 | 73 | - Check out the `main` branch, making sure the CI passed all previous tests. 74 | - Bump the version with `bump-my-version bump --verbose <version-level>`. 75 | - Consider using `--dry-run` beforehand to inspect the changes. 76 | - The `<version-level>` should be one of `major`, `minor`, or `patch`.
77 | Alternatively, the version can be set explicitly with `--new-version <version>` instead of a bump level such as `patch`. 78 | For more details, refer to the [Semantic Versions][semver] standard; 79 | - Commit the changes to `GitHub` and push the corresponding auto-generated `v{MAJOR}.{MINOR}.{PATCH}` tag. 80 | - Validate that the CI checks pass once again. 81 | - Create a `GitHub release` with the created tag.
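For instance, a `minor` release of the specification could look like the following sketch, assuming the bump configuration commits and tags automatically (adjust the bump level to your change):

```bash
git checkout main && git pull
bump-my-version bump --verbose --dry-run minor  # inspect the planned changes first
bump-my-version bump --verbose minor            # apply the bump, creating the commit and the version tag
git push origin main --follow-tags              # push the commit together with the auto-generated tag
```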
82 | 83 | 84 | 85 | > [!WARNING] 86 | > 87 | > - Ensure the "Set as the latest release" option is selected :heavy_check_mark:. 88 | > - Ensure the diff ranges from the previous MLM version, and not an intermediate `stac-model` release. 89 | 90 | 91 | 92 | ### Building a new version of `stac-model` 93 | 94 | - Apply any relevant changes and `CHANGELOG.md` entries in a PR that modifies `stac-model`. 95 | - Bump the version with `bump-my-version bump --verbose --config-file stac-model.bump.toml`. 96 | - You can pass the new version explicitly, or a rule such as `major`, `minor`, or `patch`.
97 | For more details, refer to the [Semantic Versions][semver] standard; 98 | - Once CI validation succeeded, merge the corresponding PR branch. 99 | - Checkout to `main` branch that contains the freshly created merge commit. 100 | - Push the tag `stac-model-v{MAJOR}.{MINOR}.{PATCH}`. The CI should auto-publish it to PyPI. 101 | - Create a `GitHub release` (if not automatically drafted by the CI). 102 | 103 | 104 | 105 | > [!WARNING] 106 | > 107 | > - Ensure the "Set as the latest release" option is deselected :x:. 108 | > - Ensure the diff ranges from the previous release of `stac-model`, not an intermediate MLM release. 109 | 110 | 111 | 112 | ## Other help 113 | 114 | You can contribute by spreading a word about this library. 115 | It would also be a huge contribution to write 116 | a short article on how you are using this project. 117 | You can also share how the ML Model extension does or does 118 | not serve your needs with us in the GitHub Discussions or raise 119 | Issues for bugs. 120 | 121 | [uv-install]: https://docs.astral.sh/uv/getting-started/installation/ 122 | 123 | [uv-docs]: https://docs.astral.sh/uv/ 124 | 125 | [semver]: https://semver.org/ 126 | -------------------------------------------------------------------------------- /examples/item_pytorch_geo_unet.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "Feature", 3 | "stac_version": "1.0.0", 4 | "stac_extensions": [ 5 | "https://stac-extensions.github.io/mlm/v1.5.0/schema.json", 6 | "https://stac-extensions.github.io/eo/v1.1.0/schema.json" 7 | ], 8 | "id": "pytorch_geo_unet", 9 | "geometry": { 10 | "type": "Polygon", 11 | "coordinates": [ 12 | [ 13 | [ 14 | -7.88, 15 | 37.13 16 | ], 17 | [ 18 | -7.88, 19 | 58.21 20 | ], 21 | [ 22 | 27.91, 23 | 58.21 24 | ], 25 | [ 26 | 27.91, 27 | 37.13 28 | ], 29 | [ 30 | -7.88, 31 | 37.13 32 | ] 33 | ] 34 | ] 35 | }, 36 | "bbox": [ 37 | -7.88, 38 | 37.13, 39 | 27.91, 40 | 58.21 41 | ], 42 | "properties": { 43 | "description": "STAC item generated using unet_mlm() in stac_model/examples.py example. 
Specified in https://github.com/fieldsoftheworld/ftw-baselines First 4 S2 bands are for image t1 and last 4 bands are for image t2", 44 | "start_datetime": "2015-06-23T00:00:00Z", 45 | "end_datetime": "2024-08-27T23:59:59Z", 46 | "mlm:framework": "segmentation_models_pytorch.decoders.unet.model", 47 | "mlm:accelerator_constrained": false, 48 | "mlm:name": "U-Net_efficientnet-b3", 49 | "mlm:architecture": "segmentation_models_pytorch.decoders.unet.model.Unet", 50 | "mlm:tasks": [ 51 | "semantic-segmentation" 52 | ], 53 | "mlm:input": [ 54 | { 55 | "bands": [ 56 | "B4", 57 | "B3", 58 | "B2", 59 | "B8A", 60 | "B4", 61 | "B3", 62 | "B2", 63 | "B8A" 64 | ], 65 | "variables": [], 66 | "name": "model_input", 67 | "input": { 68 | "shape": [ 69 | -1, 70 | 8, 71 | 3, 72 | 3 73 | ], 74 | "dim_order": [ 75 | "batch", 76 | "bands", 77 | "height", 78 | "width" 79 | ], 80 | "data_type": "float32" 81 | }, 82 | "value_scaling": [ 83 | { 84 | "type": "z-score", 85 | "mean": 0, 86 | "stddev": 3000 87 | } 88 | ], 89 | "pre_processing_function": null 90 | } 91 | ], 92 | "mlm:output": [ 93 | { 94 | "bands": [], 95 | "variables": [], 96 | "name": "model_output", 97 | "tasks": [ 98 | "semantic-segmentation" 99 | ], 100 | "result": { 101 | "shape": [ 102 | -1, 103 | 2 104 | ], 105 | "dim_order": [ 106 | "batch", 107 | "classes" 108 | ], 109 | "data_type": "float32" 110 | }, 111 | "classification:classes": [ 112 | { 113 | "value": 0, 114 | "name": "class_0", 115 | "description": "Auto-generated class 0" 116 | }, 117 | { 118 | "value": 1, 119 | "name": "class_1", 120 | "description": "Auto-generated class 1" 121 | } 122 | ], 123 | "post_processing_function": null 124 | } 125 | ], 126 | "mlm:total_parameters": 13160978, 127 | "mlm:pretrained": true, 128 | "datetime": null 129 | }, 130 | "links": [ 131 | { 132 | "rel": "cite-as", 133 | "href": "https://arxiv.org/abs/2409.16252", 134 | "type": "text/html", 135 | "title": "Publication for the training dataset of the model" 136 | }, 137 | { 138 | "rel": "self", 139 | "href": "./item_pytorch_geo_unet.json", 140 | "type": "application/geo+json" 141 | }, 142 | { 143 | "rel": "collection", 144 | "href": "./collection.json", 145 | "type": "application/json" 146 | } 147 | ], 148 | "assets": { 149 | "model": { 150 | "href": "https://huggingface.co/torchgeo/ftw/resolve/d2fdab6ea9d9cd38b491292cc9a5c8642533cef5/noncommercial/2-class/sentinel2_unet_effb3-bf010a31.pth", 151 | "type": "application/octet-stream; application=pytorch", 152 | "title": "U-Net_efficientnet-b3", 153 | "description": "A U-Net segmentation model with efficientnet-b3 encoder Weights are non-commercial.", 154 | "mlm:artifact_type": "torch.save", 155 | "eo:bands": [ 156 | { 157 | "name": "B4" 158 | }, 159 | { 160 | "name": "B3" 161 | }, 162 | { 163 | "name": "B2" 164 | }, 165 | { 166 | "name": "B8A" 167 | }, 168 | { 169 | "name": "B4" 170 | }, 171 | { 172 | "name": "B3" 173 | }, 174 | { 175 | "name": "B2" 176 | }, 177 | { 178 | "name": "B8A" 179 | } 180 | ], 181 | "roles": [ 182 | "mlm:model", 183 | "mlm:weights", 184 | "data" 185 | ] 186 | }, 187 | "source_code": { 188 | "href": "https://github.com/qubvel-org/segmentation_models.pytorch", 189 | "type": "text/html", 190 | "title": "Source code for U-Net_efficientnet-b3", 191 | "description": "GitHub repo of the pytorch model", 192 | "roles": [ 193 | "mlm:source_code", 194 | "code" 195 | ] 196 | } 197 | }, 198 | "collection": "ml-model-examples", 199 | "mlm:entrypoint": "segmentation_models_pytorch.decoders.unet.model.Unet" 200 | } 
-------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | #* Variables 2 | SHELL ?= /usr/bin/env bash 3 | 4 | # use the directory rather than the python binary to allow auto-discovery, which is more cross-platform compatible 5 | PYTHON_PATH := $(shell which python) 6 | # handle whether running on Windows or Unix-like systems 7 | ifneq ($(findstring $(PYTHON_PATH),bin/python),) 8 | PYTHON_ROOT := $(shell dirname $(dir $(PYTHON_PATH))) 9 | else 10 | PYTHON_ROOT := $(shell dirname $(PYTHON_PATH)) 11 | endif 12 | ifeq ($(patsubst %/bin,,$(lastword $(PYTHON_ROOT))),) 13 | PYTHON_ROOT := $(dir $(PYTHON_ROOT)) 14 | endif 15 | UV_PYTHON_ROOT ?= $(PYTHON_ROOT) 16 | 17 | # to actually reuse an existing virtual/conda environment, the 'UV_PROJECT_ENVIRONMENT' variable must be set to it 18 | # use this command: 19 | # UV_PROJECT_ENVIRONMENT=/path/to/env make [target] 20 | # consider exporting this variable in '/path/to/env/etc/conda/activate.d/env.sh' to enable it by default when 21 | # activating a conda environment, and reset it in '/path/to/env/etc/conda/deactivate.d/env.sh' 22 | UV_PROJECT_ENVIRONMENT ?= 23 | # make sure every uv command employs the specified environment path 24 | ifeq (${UV_PROJECT_ENVIRONMENT},) 25 | UV_COMMAND := uv 26 | else 27 | UV_COMMAND := UV_PROJECT_ENVIRONMENT="${UV_PROJECT_ENVIRONMENT}" uv 28 | endif 29 | 30 | #* UV 31 | .PHONY: setup 32 | setup: 33 | which uv >/dev/null || (curl -LsSf https://astral.sh/uv/install.sh | sh) 34 | 35 | .PHONY: publish 36 | publish: 37 | $(UV_COMMAND) publish --build 38 | 39 | #* Installation 40 | .PHONY: install 41 | install: setup 42 | $(UV_COMMAND) export --format requirements-txt -o requirements.txt --no-dev 43 | $(UV_COMMAND) pip install --python "$(UV_PYTHON_ROOT)" -r requirements.txt 44 | 45 | .PHONY: install-dev 46 | install-dev: setup 47 | $(UV_COMMAND) export --format requirements-txt -o requirements-dev.txt 48 | $(UV_COMMAND) pip install --python "$(UV_PYTHON_ROOT)" -r requirements-dev.txt 49 | 50 | .PHONY: install-dev-extras 51 | install-dev-extras: setup 52 | $(UV_COMMAND) export --format requirements-txt -o requirements-dev.txt 53 | $(UV_COMMAND) pip install --python "$(UV_PYTHON_ROOT)" -e .[torch] -r requirements-dev.txt 54 | 55 | .PHONY: pre-commit-install 56 | pre-commit-install: setup 57 | $(UV_COMMAND) run --no-sync --python "$(UV_PYTHON_ROOT)" pre-commit install 58 | 59 | #* Formatters 60 | .PHONY: codestyle 61 | codestyle: setup 62 | $(UV_COMMAND) run --no-sync --python "$(UV_PYTHON_ROOT)" ruff format --config=pyproject.toml stac_model tests 63 | 64 | .PHONY: format 65 | format: codestyle 66 | 67 | #* Testing 68 | .PHONY: test 69 | test: setup 70 | $(UV_COMMAND) run --no-sync --python "$(UV_PYTHON_ROOT)" pytest -c pyproject.toml -v --cov-report=html --cov=stac_model --cov-config pyproject.toml tests/ 71 | 72 | #* Linting 73 | .PHONY: check 74 | check: check-examples check-markdown check-lint check-mypy check-safety check-citation 75 | 76 | .PHONY: check-all 77 | check-all: check 78 | 79 | .PHONY: mypy 80 | mypy: setup 81 | $(UV_COMMAND) run --no-sync --python "$(UV_PYTHON_ROOT)" mypy --config-file pyproject.toml ./ 82 | 83 | .PHONY: check-mypy 84 | check-mypy: mypy 85 | 86 | .PHONY: check-safety 87 | check-safety: setup 88 | $(UV_COMMAND) run --no-sync --python "$(UV_PYTHON_ROOT)" safety check --full-report 89 | $(UV_COMMAND) run --no-sync --python "$(UV_PYTHON_ROOT)" bandit -ll --recursive stac_model 
tests 90 | 91 | .PHONY: check-citation 92 | check-citation: setup 93 | $(UV_COMMAND) run --no-sync --python "$(UV_PYTHON_ROOT)" cffconvert --validate 94 | 95 | # see https://docs.astral.sh/ruff/formatter/#sorting-imports for use of both `check` and `format` commands 96 | .PHONY: lint 97 | lint: setup 98 | $(UV_COMMAND) run --no-sync --python "$(UV_PYTHON_ROOT)" ruff check --select I --fix --config=pyproject.toml ./ 99 | $(UV_COMMAND) run --no-sync --python "$(UV_PYTHON_ROOT)" ruff format --config=pyproject.toml ./ 100 | 101 | .PHONY: check-lint 102 | check-lint: setup 103 | $(UV_COMMAND) run --python "$(UV_PYTHON_ROOT)" ruff check --config=pyproject.toml ./ 104 | 105 | .PHONY: format-lint 106 | format-lint: lint 107 | $(UV_COMMAND) run --no-sync --python "$(UV_PYTHON_ROOT)" ruff check --fix --config=pyproject.toml ./ 108 | 109 | .PHONY: install-npm 110 | install-npm: 111 | npm install 112 | 113 | .PHONY: check-markdown 114 | check-markdown: install-npm 115 | npm run check-markdown 116 | 117 | .PHONY: format-markdown 118 | format-markdown: install-npm 119 | npm run format-markdown 120 | 121 | .PHONY: check-examples 122 | check-examples: install-npm 123 | npm run check-examples 124 | 125 | .PHONY: format-examples 126 | format-examples: install-npm 127 | npm run format-examples 128 | 129 | FORMATTERS := lint markdown examples 130 | $(addprefix fix-, $(FORMATTERS)): fix-%: format-% 131 | 132 | .PHONY: lint-all 133 | lint-all: lint mypy check-safety check-markdown 134 | 135 | .PHONY: update-dev-deps 136 | update-dev-deps: setup 137 | $(UV_COMMAND) export --only-dev --format requirements-txt -o requirements-only-dev.txt 138 | $(UV_COMMAND) pip install --python "$(UV_PYTHON_ROOT)" -r requirements-only-dev.txt 139 | 140 | #* Cleaning 141 | .PHONY: pycache-remove 142 | pycache-remove: 143 | find . | grep -E "(__pycache__|\.pyc|\.pyo$$)" | xargs rm -rf 144 | 145 | .PHONY: dsstore-remove 146 | dsstore-remove: 147 | find . | grep -E ".DS_Store" | xargs rm -rf 148 | 149 | .PHONY: mypycache-remove 150 | mypycache-remove: 151 | find . | grep -E ".mypy_cache" | xargs rm -rf 152 | 153 | .PHONY: ipynbcheckpoints-remove 154 | ipynbcheckpoints-remove: 155 | find . | grep -E ".ipynb_checkpoints" | xargs rm -rf 156 | 157 | .PHONY: pytestcache-remove 158 | pytestcache-remove: 159 | find . 
| grep -E ".pytest_cache" | xargs rm -rf 160 | 161 | .PHONY: build-remove 162 | build-remove: 163 | rm -rf build/ 164 | 165 | .PHONY: cleanup 166 | cleanup: pycache-remove dsstore-remove mypycache-remove ipynbcheckpoints-remove pytestcache-remove 167 | -------------------------------------------------------------------------------- /examples/item_bands_expression.json: -------------------------------------------------------------------------------- 1 | { 2 | "$comment": "Demonstrate the use of MLM and EO for bands description, with EO bands directly in the Model Asset.", 3 | "stac_version": "1.0.0", 4 | "stac_extensions": [ 5 | "https://stac-extensions.github.io/mlm/v1.5.0/schema.json", 6 | "https://stac-extensions.github.io/eo/v1.1.0/schema.json", 7 | "https://stac-extensions.github.io/raster/v1.1.0/schema.json", 8 | "https://stac-extensions.github.io/file/v1.0.0/schema.json", 9 | "https://stac-extensions.github.io/ml-aoi/v0.2.0/schema.json" 10 | ], 11 | "type": "Feature", 12 | "id": "resnet-18_sentinel-2_all_moco_classification", 13 | "collection": "ml-model-examples", 14 | "geometry": { 15 | "type": "Polygon", 16 | "coordinates": [ 17 | [ 18 | [ 19 | -7.882190080512502, 20 | 37.13739173208318 21 | ], 22 | [ 23 | -7.882190080512502, 24 | 58.21798141355221 25 | ], 26 | [ 27 | 27.911651652899923, 28 | 58.21798141355221 29 | ], 30 | [ 31 | 27.911651652899923, 32 | 37.13739173208318 33 | ], 34 | [ 35 | -7.882190080512502, 36 | 37.13739173208318 37 | ] 38 | ] 39 | ] 40 | }, 41 | "bbox": [ 42 | -7.882190080512502, 43 | 37.13739173208318, 44 | 27.911651652899923, 45 | 58.21798141355221 46 | ], 47 | "properties": { 48 | "description": "Sourced from torchgeo python library, identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO", 49 | "datetime": null, 50 | "start_datetime": "1900-01-01T00:00:00Z", 51 | "end_datetime": "9999-12-31T23:59:59Z", 52 | "mlm:name": "Resnet-18 Sentinel-2 ALL MOCO", 53 | "mlm:tasks": [ 54 | "classification" 55 | ], 56 | "mlm:architecture": "ResNet", 57 | "mlm:framework": "pytorch", 58 | "mlm:framework_version": "2.1.2+cu121", 59 | "file:size": 43000000, 60 | "mlm:memory_size": 1, 61 | "mlm:total_parameters": 11700000, 62 | "mlm:pretrained_source": "EuroSat Sentinel-2", 63 | "mlm:accelerator": "cuda", 64 | "mlm:accelerator_constrained": false, 65 | "mlm:accelerator_summary": "Unknown", 66 | "mlm:batch_size_suggestion": 256, 67 | "mlm:input": [ 68 | { 69 | "name": "RGB+NDVI Bands Sentinel-2 Batch", 70 | "bands": [ 71 | { 72 | "name": "B04" 73 | }, 74 | { 75 | "name": "B03" 76 | }, 77 | { 78 | "name": "B02" 79 | }, 80 | { 81 | "name": "NDVI", 82 | "format": "rio-calc", 83 | "expression": "(B08 - B04) / (B08 + B04)" 84 | } 85 | ], 86 | "input": { 87 | "shape": [ 88 | -1, 89 | 4, 90 | 64, 91 | 64 92 | ], 93 | "dim_order": [ 94 | "batch", 95 | "bands", 96 | "height", 97 | "width" 98 | ], 99 | "data_type": "float32" 100 | } 101 | } 102 | ], 103 | "mlm:output": [ 104 | { 105 | "name": "classification", 106 | "tasks": [ 107 | "segmentation", 108 | "semantic-segmentation" 109 | ], 110 | "result": { 111 | "shape": [ 112 | -1, 113 | 2 114 | ], 115 | "dim_order": [ 116 | "batch", 117 | "class" 118 | ], 119 | "data_type": "float32" 120 | }, 121 | "classification_classes": [ 122 | { 123 | "value": 1, 124 | "name": "vegetation", 125 | "title": "Vegetation", 126 | "description": "Pixels where vegetation is detected.", 127 | "color_hint": "00FF00", 128 | "nodata": false 129 | }, 130 | { 131 | "value": 0, 132 | "name": "background", 133 | "title": "Non-Vegetation", 134 | "description": "Anything 
that is not classified as vegetation.", 135 | "color_hint": "000000", 136 | "nodata": false 137 | } 138 | ], 139 | "post_processing_function": null 140 | } 141 | ] 142 | }, 143 | "assets": { 144 | "weights": { 145 | "href": "https://example.com/model-rgb-ndvi.pth", 146 | "title": "Pytorch weights checkpoint", 147 | "description": "A vegetation classification model trained on Sentinel-2 imagery and NDVI.", 148 | "type": "application/octet-stream; application=pytorch", 149 | "roles": [ 150 | "mlm:model", 151 | "mlm:weights" 152 | ], 153 | "mlm:artifact_type": "torch.save", 154 | "$comment": "Following 'eo:bands' is required to fulfil schema validation of 'eo' extension.", 155 | "eo:bands": [ 156 | { 157 | "name": "B02", 158 | "common_name": "blue", 159 | "description": "Blue (band 2)", 160 | "center_wavelength": 0.49, 161 | "full_width_half_max": 0.098 162 | }, 163 | { 164 | "name": "B03", 165 | "common_name": "green", 166 | "description": "Green (band 3)", 167 | "center_wavelength": 0.56, 168 | "full_width_half_max": 0.045 169 | }, 170 | { 171 | "name": "B04", 172 | "common_name": "red", 173 | "description": "Red (band 4)", 174 | "center_wavelength": 0.665, 175 | "full_width_half_max": 0.038 176 | }, 177 | { 178 | "name": "B08", 179 | "common_name": "nir", 180 | "description": "NIR 1 (band 8)", 181 | "center_wavelength": 0.842, 182 | "full_width_half_max": 0.145 183 | } 184 | ] 185 | } 186 | }, 187 | "links": [ 188 | { 189 | "rel": "collection", 190 | "href": "./collection.json", 191 | "type": "application/json" 192 | }, 193 | { 194 | "rel": "self", 195 | "href": "./item_bands_expression.json", 196 | "type": "application/geo+json" 197 | }, 198 | { 199 | "rel": "derived_from", 200 | "href": "https://earth-search.aws.element84.com/v1/collections/sentinel-2-l2a", 201 | "type": "application/json", 202 | "ml-aoi:split": "train" 203 | } 204 | ] 205 | } 206 | -------------------------------------------------------------------------------- /tests/test_stac_model.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from collections.abc import Sized 3 | from typing import Any 4 | 5 | import pydantic 6 | import pytest 7 | 8 | from stac_model.base import ModelBand, ModelDataVariable 9 | from stac_model.input import InputStructure, ModelInput 10 | from stac_model.output import ModelOutput, ModelResult 11 | 12 | ModelClass = type[ModelInput | ModelOutput] 13 | 14 | 15 | def make_struct(model_class: ModelClass, refs: Sized) -> dict[str, Any]: 16 | if model_class is ModelInput: 17 | struct_class = InputStructure 18 | struct_field = "input" 19 | struct_xargs = {} 20 | else: 21 | struct_class = ModelResult # type: ignore[assignment] 22 | struct_field = "result" 23 | struct_xargs = {"tasks": ["classification"]} 24 | struct = struct_class( 25 | shape=[-1, len(refs), 64, 64], 26 | dim_order=["batch", "channel", "height", "width"], 27 | data_type="float32", 28 | ) 29 | return {struct_field: struct, **struct_xargs} 30 | 31 | 32 | @pytest.mark.parametrize( 33 | ["model_class", "bands"], 34 | itertools.product( 35 | [ModelInput, ModelOutput], 36 | [ 37 | ["B04", "B03", "B02"], 38 | [{"name": "B04"}, {"name": "B03"}, {"name": "B02"}], 39 | [{"name": "NDVI", "format": "rio-calc", "expression": "(B08 - B04) / (B08 + B04)"}], 40 | [ 41 | "B04", 42 | {"name": "B03"}, 43 | "B02", 44 | {"name": "NDVI", "format": "rio-calc", "expression": "(B08 - B04) / (B08 + B04)"}, 45 | ], 46 | ], 47 | ), 48 | ) 49 | def test_model_band(model_class: ModelClass, bands: 
list[ModelBand]) -> None: 50 | struct = make_struct(model_class, bands) 51 | mlm_object = model_class( 52 | name="test", 53 | bands=bands, 54 | **struct, 55 | ) 56 | mlm_bands = mlm_object.model_dump()["bands"] 57 | assert mlm_bands == bands 58 | 59 | 60 | @pytest.mark.parametrize( 61 | ["model_class", "bands"], 62 | itertools.product( 63 | [ModelInput, ModelOutput], 64 | [ 65 | [{"name": "test", "expression": "missing-format"}], 66 | [{"name": "test", "format": "missing-expression"}], 67 | ], 68 | ), 69 | ) 70 | def test_model_band_format_expression_dependency(model_class: ModelClass, bands: list[ModelBand]) -> None: 71 | with pytest.raises(pydantic.ValidationError): 72 | struct = make_struct(model_class, bands) 73 | ModelInput( 74 | name="test", 75 | bands=bands, 76 | **struct, 77 | ) 78 | 79 | 80 | @pytest.mark.parametrize( 81 | "processing_expression", 82 | [ 83 | None, 84 | {"format": "test", "expression": "test"}, 85 | [ 86 | {"format": "test", "expression": "test1"}, 87 | {"format": "test", "expression": "test2"}, 88 | ], 89 | ], 90 | ) 91 | def test_model_io_processing_expression_variants(processing_expression): 92 | model_input = ModelInput( 93 | name="test", 94 | bands=[], 95 | input=InputStructure( 96 | shape=[-1, 3, 64, 64], 97 | dim_order=["batch", "channel", "height", "width"], 98 | data_type="float32", 99 | ), 100 | pre_processing_function=processing_expression, 101 | ) 102 | model_json = model_input.model_dump() 103 | assert model_json["pre_processing_function"] == processing_expression 104 | 105 | model_output = ModelOutput( 106 | name="test", 107 | classes=[], 108 | tasks={"classification"}, 109 | result=ModelResult( 110 | shape=[-1, 2, 64, 64], 111 | dim_order=["batch", "channel", "height", "width"], 112 | data_type="float32", 113 | ), 114 | post_processing_function=processing_expression, 115 | ) 116 | model_json = model_output.model_dump() 117 | assert model_json["post_processing_function"] == processing_expression 118 | 119 | 120 | @pytest.mark.parametrize( 121 | "variables", 122 | [ 123 | [ 124 | "temperature_2m", 125 | "10m_u_component_of_wind", 126 | "10m_v_component_of_wind", 127 | ], 128 | [ 129 | {"name": "temperature_2m"}, 130 | {"name": "10m_u_component_of_wind"}, 131 | {"name": "10m_v_component_of_wind"}, 132 | ], 133 | [ 134 | {"name": "temperature_2m_celsius", "format": "rio-calc", "expression": "temperature_2m + 273.15"}, 135 | ], 136 | [ 137 | "temperature_2m", 138 | {"name": "10m_u_component_of_wind"}, 139 | "10m_v_component_of_wind", 140 | {"name": "temperature_2m_celsius", "format": "rio-calc", "expression": "temperature_2m + 273.15"}, 141 | ], 142 | ], 143 | ) 144 | def test_model_variables(variables: list[ModelDataVariable]) -> None: 145 | mlm_input = ModelInput( 146 | name="test", 147 | variables=variables, 148 | input=InputStructure( 149 | shape=[-1, len(variables), 365], 150 | dim_order=["batch", "variables", "time"], 151 | data_type="float32", 152 | ), 153 | ) 154 | mlm_variables = mlm_input.model_dump()["variables"] 155 | assert mlm_variables == variables 156 | 157 | 158 | class Omitted: 159 | pass 160 | 161 | 162 | @pytest.mark.parametrize( 163 | ["model_class", "bands", "variables", "expected_bands", "expected_variables"], 164 | [ # type: ignore 165 | (model_cls, *args) 166 | for model_cls, args in itertools.product( 167 | [ModelInput, ModelOutput], 168 | [ 169 | ( 170 | # explicit empty list should be kept 171 | [], 172 | [], 173 | [], 174 | [], 175 | ), 176 | ( 177 | # explicit None should drop the definitions 178 | None, 179 | None, 180 | 
Omitted, 181 | Omitted, 182 | ), 183 | ( 184 | # omitting the properties should default to empty definitions 185 | Omitted, 186 | Omitted, 187 | [], 188 | [], 189 | ), 190 | ], 191 | ) 192 | ], 193 | ) 194 | def test_model_bands_or_variables_defaults( 195 | model_class: ModelClass, 196 | bands: Any, 197 | variables: Any, 198 | expected_bands: Any, 199 | expected_variables: Any, 200 | ) -> None: 201 | mlm_xargs = {} 202 | if bands is not Omitted: 203 | mlm_xargs["bands"] = bands 204 | if variables is not Omitted: 205 | mlm_xargs["variables"] = variables 206 | mlm_struct = make_struct(model_class, [1, 2, 3]) 207 | mlm_input = model_class(name="test", **mlm_struct, **mlm_xargs) 208 | mlm_input_json = mlm_input.model_dump() 209 | if expected_bands is Omitted: 210 | assert "bands" not in mlm_input_json 211 | else: 212 | mlm_bands = mlm_input_json["bands"] 213 | assert mlm_bands == expected_bands 214 | if expected_variables is Omitted: 215 | assert "variables" not in mlm_input_json 216 | else: 217 | mlm_variables = mlm_input_json["variables"] 218 | assert mlm_variables == expected_variables 219 | -------------------------------------------------------------------------------- /README_STAC_MODEL.md: -------------------------------------------------------------------------------- 1 | # stac-model 2 | 3 | 4 | 5 |
6 | 7 | [![Python support][bp1]][bp2] 8 | [![PyPI Release][bp3]][bp2] 9 | [![Repository][bscm1]][bp4] 10 | [![Releases][bscm2]][bp5] 11 | 12 | [![Contributions Welcome][bp8]][bp9] 13 | 14 | [![uv][bp11]][bp12] 15 | [![Pre-commit][bp15]][bp16] 16 | [![Semantic versions][blic3]][bp5] 17 | [![Pipelines][bscm6]][bscm7] 18 | 19 | *A PydanticV2 and PySTAC validation and serialization library for the STAC ML Model Extension* 20 | 21 |
22 | 23 | > ⚠️
24 | > FIXME: update description with ML framework connectors (pytorch, scikit-learn, etc.) 25 | 26 | ## Installation 27 | 28 | ```shell 29 | pip install -U stac-model 30 | ``` 31 | 32 | or install with uv: 33 | 34 | ```shell 35 | uv add stac-model 36 | ``` 37 | 38 | Then you can run 39 | 40 | ```shell 41 | stac-model --help 42 | ``` 43 | 44 | ## Creating example metadata JSON for a STAC Item 45 | 46 | ```shell 47 | stac-model 48 | ``` 49 | 50 | This will make [this example item](./examples/item_basic.json) for an example model. 51 | 52 | ## Validating Model Metadata 53 | 54 | An alternative use of `stac_model` is to validate config files containing model metadata using the `MLModelProperties` schema. 55 | 56 | Given a YAML or JSON file with the structure in [examples/torch/mlm-metadata.yaml](./examples/torch/mlm-metadata.yaml), the model metadata can be validated as follows: 57 | 58 | ```python 59 | import yaml 60 | from stac_model.schema import MLModelProperties 61 | 62 | with open("examples/mlm-metadata.yaml", "r", encoding="utf-8") as f: 63 | metadata = yaml.safe_load(f) 64 | 65 | MLModelProperties.model_validate(metadata["properties"]) 66 | ``` 67 | 68 | ## Exporting and Packaging PyTorch Models, Transforms, and Model Metadata 69 | 70 | As of PyTorch 2.8, and stac_model 0.4.0, you can now export and package PyTorch models, transforms, 71 | and model metadata using functions in `stac_model.torch.export`. Below is an example of exporting a 72 | U-Net model pretrained on the [Fields of The World (FTW) dataset](https://fieldsofthe.world/) for 73 | field boundary segmentation in Sentinel-2 satellite imagery using the [TorchGeo](https://github.com/microsoft/torchgeo) library. 74 | 75 | > 📝 **Note:** To customize the metadata for your model you can use this [example](./tests/torch/metadata.yaml) as a template. 76 | 77 | ```python 78 | import torch 79 | import torchvision.transforms.v2 as T 80 | from torchgeo.models import Unet_Weights, unet 81 | from stac_model.torch.export import save 82 | 83 | weights = Unet_Weights.SENTINEL2_3CLASS_FTW 84 | transforms = torch.nn.Sequential( 85 | T.Resize((256, 256)), 86 | T.Normalize(mean=[0.0], std=[3000.0]) 87 | ) 88 | model = unet(weights=weights) 89 | 90 | save( 91 | output_file="ftw.pt2", 92 | model=model, # Must be an nn.Module 93 | transforms=transforms, # Must be an nn.Module 94 | metadata="metadata.yaml", # Can be a metadata yaml or stac_model.schema.MLModelProperties object 95 | input_shape=[-1, 8, -1, -1], # -1 indicates a dynamic shaped dimension 96 | device="cpu", 97 | dtype=torch.float32, 98 | aoti_compile_and_package=False, # True for AOTInductor compile otherwise use torch.export 99 | ) 100 | ``` 101 | 102 | The model, transforms, and metadata can then be loaded into an environment with only torch and stac_model as required dependencies like below: 103 | 104 | ```python 105 | import yaml 106 | from torch.export.pt2_archive._package import load_pt2 107 | 108 | pt2 = load_pt2(archive_path) 109 | metadata = yaml.safe_load(pt2.extra_files["mlm-metadata"]) 110 | 111 | # If exported with aoti_compile_and_package=True 112 | model = pt2.aoti_runners["model"] 113 | transforms = pt2.aoti_runners["transforms"] 114 | 115 | # If exported with aoti_compile_and_package=False 116 | model = pt2.exported_programs["model"].module() 117 | transforms = pt2.exported_programs["transforms"].module() 118 | 119 | # Inference 120 | batch = ... 
# An input batch tensor 121 | outputs = model(transforms(batch)) 122 | ``` 123 | 124 | ### Creating a STAC Item from a PyTorch Model 125 | 126 | You can generate a valid STAC Item using the **Machine Learning Model (MLM) Extension**. 127 | 128 | The example below demonstrates creating a STAC Item from a U-Net model pretrained on the 129 | [Fields of The World (FTW) dataset](https://fieldsofthe.world/) for field boundary segmentation 130 | in Sentinel-2 satellite imagery, using the [TorchGeo](https://github.com/microsoft/torchgeo) library 131 | 132 | ```python 133 | from stac_model.examples import unet_mlm 134 | from stac_model.torch import MLModelExtension 135 | from torchgeo.models import unet, Unet_Weights 136 | 137 | # Use default TorchGeo UNet weights 138 | weights = Unet_Weights.SENTINEL2_2CLASS_NC_FTW 139 | model = unet(weights=weights) 140 | 141 | # Create an ItemMLModelExtension using the MLM extension 142 | item_ext = MLModelExtension.from_torch( 143 | model, 144 | weights=weights, 145 | item_id="pytorch_geo_unet" 146 | ) 147 | 148 | ``` 149 | 150 | For a more complete example including STAC Item properties, geometry, and datetime ranges, 151 | see `unet_mlm()` in [`stac_model/examples.py`](stac_model/examples.py). 152 | 153 | ## 📈 Releases 154 | 155 | You can see the list of available releases on the [GitHub Releases][github-releases] page. 156 | 157 | ## 📄 License 158 | 159 | [![License][blic1]][blic2] 160 | 161 | This project is licenced under the terms of the `Apache Software License 2.0` licence. 162 | See [LICENSE][blic2] for more details. 163 | 164 | ## 💗 Credits 165 | 166 | [![Python project templated from galactipy.][bp6]][bp7] 167 | 168 | 169 | 170 | [bp1]: https://img.shields.io/pypi/pyversions/stac-model?style=for-the-badge 171 | 172 | [bp2]: https://pypi.org/project/stac-model/ 173 | 174 | [bp3]: https://img.shields.io/pypi/v/stac-model?style=for-the-badge&logo=pypi&color=3775a9 175 | 176 | [bp4]: https://github.com/stac-extensions/mlm 177 | 178 | [bp5]: https://github.com/stac-extensions/mlm/releases 179 | 180 | [bp6]: https://img.shields.io/badge/made%20with-galactipy%20%F0%9F%8C%8C-179287?style=for-the-badge&labelColor=193A3E 181 | 182 | [bp7]: https://kutt.it/7fYqQl 183 | 184 | [bp8]: https://img.shields.io/static/v1.svg?label=Contributions&message=Welcome&color=0059b3&style=for-the-badge 185 | 186 | [bp9]: https://github.com/stac-extensions/mlm/blob/main/CONTRIBUTING.md 187 | 188 | [bp11]: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/uv/main/assets/badge/v0.json&style=for-the-badge 189 | 190 | [bp12]: https://docs.astral.sh/uv/ 191 | 192 | [bp15]: https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white&style=for-the-badge 193 | 194 | [bp16]: https://github.com/stac-extensions/mlm/blob/main/.pre-commit-config.yaml 195 | 196 | [blic1]: https://img.shields.io/github/license/stac-extensions/mlm?style=for-the-badge 197 | 198 | [blic2]: https://github.com/stac-extensions/mlm/blob/main/LICENSE 199 | 200 | [blic3]: https://img.shields.io/badge/%F0%9F%93%A6-semantic%20versions-4053D6?style=for-the-badge 201 | 202 | [github-releases]: https://github.com/stac-extensions/mlm/releases 203 | 204 | [bscm1]: https://img.shields.io/badge/GitHub-100000?style=for-the-badge&logo=github&logoColor=white 205 | 206 | [bscm2]: https://img.shields.io/github/v/release/stac-extensions/mlm?filter=stac-model-v*&style=for-the-badge&logo=semantic-release&color=347d39 207 | 208 | [bscm6]: 
https://img.shields.io/github/actions/workflow/status/stac-extensions/mlm/publish.yaml?style=for-the-badge&logo=github 209 | 210 | [bscm7]: https://github.com/stac-extensions/mlm/blob/main/.github/workflows/publish.yaml 211 | 212 | [hub1]: https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuring-dependabot-version-updates#enabling-dependabot-version-updates 213 | 214 | [hub2]: https://github.com/marketplace/actions/close-stale-issues 215 | 216 | [hub6]: https://docs.github.com/en/code-security/dependabot 217 | 218 | [hub8]: https://github.com/stac-extensions/mlm/blob/main/.github/release-drafter.yml 219 | 220 | [hub9]: https://github.com/stac-extensions/mlm/blob/main/.github/.stale.yml 221 | -------------------------------------------------------------------------------- /examples/item_multi_io.json: -------------------------------------------------------------------------------- 1 | { 2 | "stac_version": "1.0.0", 3 | "stac_extensions": [ 4 | "https://stac-extensions.github.io/mlm/v1.5.0/schema.json", 5 | "https://stac-extensions.github.io/raster/v1.1.0/schema.json", 6 | "https://stac-extensions.github.io/file/v1.0.0/schema.json", 7 | "https://stac-extensions.github.io/ml-aoi/v0.2.0/schema.json" 8 | ], 9 | "type": "Feature", 10 | "id": "model-multi-input", 11 | "collection": "ml-model-examples", 12 | "geometry": { 13 | "type": "Polygon", 14 | "coordinates": [ 15 | [ 16 | [ 17 | -7.882190080512502, 18 | 37.13739173208318 19 | ], 20 | [ 21 | -7.882190080512502, 22 | 58.21798141355221 23 | ], 24 | [ 25 | 27.911651652899923, 26 | 58.21798141355221 27 | ], 28 | [ 29 | 27.911651652899923, 30 | 37.13739173208318 31 | ], 32 | [ 33 | -7.882190080512502, 34 | 37.13739173208318 35 | ] 36 | ] 37 | ] 38 | }, 39 | "bbox": [ 40 | -7.882190080512502, 41 | 37.13739173208318, 42 | 27.911651652899923, 43 | 58.21798141355221 44 | ], 45 | "properties": { 46 | "description": "Generic model that employs multiple input sources with different combination of bands, and some inputs without any band at all.", 47 | "datetime": null, 48 | "start_datetime": "1900-01-01T00:00:00Z", 49 | "end_datetime": "9999-12-31T23:59:59Z", 50 | "mlm:name": "Resnet-18 Sentinel-2 ALL MOCO", 51 | "mlm:tasks": [ 52 | "classification" 53 | ], 54 | "mlm:architecture": "ResNet", 55 | "mlm:framework": "pytorch", 56 | "mlm:framework_version": "2.1.2+cu121", 57 | "file:size": 43000000, 58 | "mlm:memory_size": 1, 59 | "mlm:total_parameters": 11700000, 60 | "mlm:pretrained_source": "EuroSat Sentinel-2", 61 | "mlm:accelerator": "cuda", 62 | "mlm:accelerator_constrained": false, 63 | "mlm:accelerator_summary": "Unknown", 64 | "mlm:batch_size_suggestion": 256, 65 | "mlm:input": [ 66 | { 67 | "name": "RGB", 68 | "bands": [ 69 | "B04", 70 | "B03", 71 | "B02" 72 | ], 73 | "input": { 74 | "shape": [ 75 | -1, 76 | 3, 77 | 64, 78 | 64 79 | ], 80 | "dim_order": [ 81 | "batch", 82 | "bands", 83 | "height", 84 | "width" 85 | ], 86 | "data_type": "uint16" 87 | }, 88 | "value_scaling": null, 89 | "resize_type": null 90 | }, 91 | { 92 | "description": "Compute NDVI from Sentinel-2 bands. 
The single 'NDVI' virtual band is then fed as 'bands' dimension to the model input.", 93 | "name": "NDVI", 94 | "bands": [ 95 | "B04", 96 | "B08" 97 | ], 98 | "pre_processing_function": { 99 | "format": "gdal-calc", 100 | "expression": "(A - B) / (A + B)" 101 | }, 102 | "input": { 103 | "shape": [ 104 | -1, 105 | 1, 106 | 64, 107 | 64 108 | ], 109 | "dim_order": [ 110 | "batch", 111 | "bands", 112 | "height", 113 | "width" 114 | ], 115 | "data_type": "uint16" 116 | } 117 | }, 118 | { 119 | "description": "Digital elevation model. Comes from another source than the Sentinel bands. Therefore, no 'bands' associated to it.", 120 | "name": "DEM", 121 | "bands": [], 122 | "input": { 123 | "shape": [ 124 | -1, 125 | 1, 126 | 64, 127 | 64 128 | ], 129 | "dim_order": [ 130 | "batch", 131 | "h", 132 | "y", 133 | "x" 134 | ], 135 | "data_type": "float32" 136 | } 137 | } 138 | ], 139 | "mlm:output": [ 140 | { 141 | "name": "vegetation-segmentation", 142 | "tasks": [ 143 | "semantic-segmentation" 144 | ], 145 | "result": { 146 | "shape": [ 147 | -1, 148 | 1 149 | ], 150 | "dim_order": [ 151 | "batch", 152 | "class" 153 | ], 154 | "data_type": "uint8" 155 | }, 156 | "classification_classes": [ 157 | { 158 | "value": 0, 159 | "name": "NON_VEGETATION", 160 | "description": "background pixels", 161 | "color_hint": null 162 | }, 163 | { 164 | "value": 1, 165 | "name": "VEGETATION", 166 | "description": "pixels where vegetation was detected", 167 | "color_hint": [ 168 | 0, 169 | 255, 170 | 0 171 | ] 172 | } 173 | ], 174 | "post_processing_function": null 175 | }, 176 | { 177 | "name": "inverse-mask", 178 | "tasks": [ 179 | "semantic-segmentation" 180 | ], 181 | "result": { 182 | "shape": [ 183 | -1, 184 | 1 185 | ], 186 | "dim_order": [ 187 | "batch", 188 | "class" 189 | ], 190 | "data_type": "uint8" 191 | }, 192 | "classification_classes": [ 193 | { 194 | "value": 0, 195 | "name": "NON_VEGETATION", 196 | "description": "background pixels", 197 | "color_hint": [ 198 | 255, 199 | 255, 200 | 255 201 | ] 202 | }, 203 | { 204 | "value": 1, 205 | "name": "VEGETATION", 206 | "description": "pixels where vegetation was detected", 207 | "color_hint": [ 208 | 0, 209 | 0, 210 | 0 211 | ] 212 | } 213 | ], 214 | "post_processing_function": { 215 | "format": "gdal-calc", 216 | "expression": "logical_not(A)" 217 | } 218 | } 219 | ] 220 | }, 221 | "assets": { 222 | "weights": { 223 | "href": "https://huggingface.co/torchgeo/resnet50_sentinel2_rgb_moco/blob/main/resnet50_sentinel2_rgb_moco.pth", 224 | "title": "Pytorch weights checkpoint", 225 | "description": "A Resnet-50 classification model trained on Sentinel-2 RGB imagery with torchgeo.", 226 | "type": "application/octet-stream; application=pytorch", 227 | "roles": [ 228 | "mlm:model", 229 | "mlm:weights" 230 | ], 231 | "mlm:artifact_type": "torch.save", 232 | "raster:bands": [ 233 | { 234 | "name": "B02 - blue", 235 | "nodata": 0, 236 | "data_type": "uint16", 237 | "bits_per_sample": 15, 238 | "spatial_resolution": 10, 239 | "scale": 0.0001, 240 | "offset": 0, 241 | "unit": "m" 242 | }, 243 | { 244 | "name": "B03 - green", 245 | "nodata": 0, 246 | "data_type": "uint16", 247 | "bits_per_sample": 15, 248 | "spatial_resolution": 10, 249 | "scale": 0.0001, 250 | "offset": 0, 251 | "unit": "m" 252 | }, 253 | { 254 | "name": "B04 - red", 255 | "nodata": 0, 256 | "data_type": "uint16", 257 | "bits_per_sample": 15, 258 | "spatial_resolution": 10, 259 | "scale": 0.0001, 260 | "offset": 0, 261 | "unit": "m" 262 | }, 263 | { 264 | "name": "B08 - nir", 265 | "nodata": 0, 266 | 
"data_type": "uint16", 267 | "bits_per_sample": 15, 268 | "spatial_resolution": 10, 269 | "scale": 0.0001, 270 | "offset": 0, 271 | "unit": "m" 272 | } 273 | ] 274 | } 275 | }, 276 | "links": [ 277 | { 278 | "rel": "collection", 279 | "href": "./collection.json", 280 | "type": "application/json" 281 | }, 282 | { 283 | "rel": "self", 284 | "href": "./item_multi_io.json", 285 | "type": "application/geo+json" 286 | }, 287 | { 288 | "rel": "derived_from", 289 | "href": "https://earth-search.aws.element84.com/v1/collections/sentinel-2-l2a", 290 | "type": "application/json", 291 | "ml-aoi:split": "train" 292 | } 293 | ] 294 | } 295 | -------------------------------------------------------------------------------- /stac_model/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections.abc import Sequence 3 | from dataclasses import dataclass 4 | from enum import Enum 5 | from typing import Annotated, Any, Literal, TypeAlias, Union 6 | from typing_extensions import Self 7 | 8 | from pydantic import BaseModel, ConfigDict, Field, model_serializer, model_validator 9 | 10 | Number: TypeAlias = int | float 11 | JSON: TypeAlias = dict[str, "JSON"] | list["JSON"] | Number | bool | str | None 12 | Path: TypeAlias = os.PathLike[str] 13 | 14 | 15 | @dataclass 16 | class _OmitIfNone: 17 | pass 18 | 19 | 20 | OmitIfNone = _OmitIfNone() 21 | 22 | 23 | class MLMBaseModel(BaseModel): 24 | """ 25 | Allows wrapping any field with an annotation to drop it entirely if unset. 26 | 27 | ```python 28 | field: Annotated[Optional[], OmitIfNone] = None 29 | # or 30 | field: Annotated[Optional[], OmitIfNone] = Field(default=None) 31 | ``` 32 | 33 | Since `OmitIfNone` implies that the value could be `None` (even though it would be dropped), 34 | the `Optional` annotation must be specified to corresponding typings to avoid `mypy` lint issues. 35 | 36 | It is important to use `MLMBaseModel`, otherwise the serializer will not be called and applied. 
37 | 38 | Reference: https://github.com/pydantic/pydantic/discussions/5461#discussioncomment-7503283 39 | """ 40 | 41 | @model_serializer 42 | def model_serialize(self): 43 | omit_if_none_fields = { 44 | key: field 45 | for key, field in self.model_fields.items() 46 | if any(isinstance(m, _OmitIfNone) for m in field.metadata) 47 | } 48 | fields = getattr(self, "model_fields", self.__fields__) # noqa 49 | values = { 50 | fields[key].alias or key: val # use the alias if specified 51 | for key, val in self 52 | if key not in omit_if_none_fields or val is not None 53 | } 54 | return values 55 | 56 | model_config = ConfigDict( 57 | populate_by_name=True, 58 | ) 59 | 60 | 61 | DataType: TypeAlias = Literal[ 62 | "uint8", 63 | "uint16", 64 | "uint32", 65 | "uint64", 66 | "int8", 67 | "int16", 68 | "int32", 69 | "int64", 70 | "float16", 71 | "float32", 72 | "float64", 73 | "cint16", 74 | "cint32", 75 | "cfloat32", 76 | "cfloat64", 77 | "other", 78 | ] 79 | 80 | 81 | class TaskEnum(str, Enum): 82 | REGRESSION = "regression" 83 | CLASSIFICATION = "classification" 84 | SCENE_CLASSIFICATION = "scene-classification" 85 | DETECTION = "detection" 86 | OBJECT_DETECTION = "object-detection" 87 | SEGMENTATION = "segmentation" 88 | SEMANTIC_SEGMENTATION = "semantic-segmentation" 89 | INSTANCE_SEGMENTATION = "instance-segmentation" 90 | PANOPTIC_SEGMENTATION = "panoptic-segmentation" 91 | SIMILARITY_SEARCH = "similarity-search" 92 | GENERATIVE = "generative" 93 | IMAGE_CAPTIONING = "image-captioning" 94 | SUPER_RESOLUTION = "super-resolution" 95 | DOWNSCALING = "downscaling" 96 | 97 | 98 | ModelTaskNames: TypeAlias = Literal[ 99 | "regression", 100 | "classification", 101 | "scene-classification", 102 | "detection", 103 | "object-detection", 104 | "segmentation", 105 | "semantic-segmentation", 106 | "instance-segmentation", 107 | "panoptic-segmentation", 108 | "similarity-search", 109 | "generative", 110 | "image-captioning", 111 | "super-resolution", 112 | "downscaling", 113 | ] 114 | 115 | 116 | ModelTask = Union[ModelTaskNames, TaskEnum] 117 | 118 | 119 | class ProcessingExpression(MLMBaseModel): 120 | """ 121 | Expression used to perform a pre-processing or post-processing step on the input or output model data. 122 | """ 123 | 124 | # FIXME: should use 'pystac' reference, but 'processing' extension is not implemented yet! 125 | format: str = Field( 126 | description="The type of the expression that is specified in the 'expression' property.", 127 | ) 128 | expression: Any = Field( 129 | description=( 130 | "An expression compliant with the 'format' specified. " 131 | "The expression can be any data type and depends on the format given. " 132 | "This represents the processing operation to be applied on the entire data before or after the model." 133 | ) 134 | ) 135 | 136 | 137 | class ModelCrossReferenceObject(MLMBaseModel): 138 | name: str = Field( 139 | description=( 140 | "Name of the reference to use for the input or output. " 141 | "The name must refer to an entry of a relevant STAC extension providing further definition details." 142 | ) 143 | ) 144 | # similar to 'ProcessingExpression', but they can be omitted here 145 | format: Annotated[str | None, OmitIfNone] = Field( 146 | default=None, 147 | description="The type of the expression that is specified in the 'expression' property.", 148 | ) 149 | expression: Annotated[Any | None, OmitIfNone] = Field( 150 | default=None, 151 | description=( 152 | "An expression compliant with the 'format' specified. 
" 153 | "The expression can be any data type and depends on the format given. " 154 | "This represents the processing operation to be applied on the data before or after the model. " 155 | "Contrary to pre/post-processing expressions, this expression is applied only to the specific " 156 | "item it refers to." 157 | ), 158 | ) 159 | 160 | @model_validator(mode="after") 161 | def validate_expression(self) -> Self: 162 | if ( # mutually dependant 163 | (self.format is not None or self.expression is not None) 164 | and (self.format is None or self.expression is None) 165 | ): 166 | raise ValueError("Model band 'format' and 'expression' are mutually dependant.") 167 | return self 168 | 169 | 170 | class ModelBand(ModelCrossReferenceObject): 171 | """ 172 | Definition of a band reference in the model input or output. 173 | """ 174 | 175 | 176 | class ModelDataVariable(ModelCrossReferenceObject): 177 | """ 178 | Definition of a data variable in the model input or output. 179 | """ 180 | 181 | 182 | class ModelBandsOrVariablesReferences(MLMBaseModel): 183 | bands: Annotated[Sequence[str | ModelBand] | None, OmitIfNone] = Field( 184 | description=( 185 | "List of bands that compose the data. " 186 | "If a string is used, it is implied to correspond to a named band. " 187 | "If no band is needed for the data, use an empty array, or omit the property entirely. " 188 | "If provided, order is critical to match the stacking method as aggregated 'bands' dimension " 189 | "in 'dim_order' and 'shape' lists." 190 | ), 191 | # default omission is interpreted the same as if empty list was provided, but populate it explicitly 192 | # if the user wishes to omit the property entirely, they can use `None` explicitly 193 | default=[], 194 | examples=[ 195 | [ 196 | "B01", 197 | {"name": "B02"}, 198 | { 199 | "name": "NDVI", 200 | "format": "rio-calc", 201 | "expression": "(B08 - B04) / (B08 + B04)", 202 | }, 203 | ], 204 | ], 205 | ) 206 | variables: Annotated[Sequence[str | ModelDataVariable] | None, OmitIfNone] = Field( 207 | description=( 208 | "List of variables that compose the data. " 209 | "If a string is used, it is implied to correspond to a named variable. " 210 | "If no variable is needed for the data, use an empty array, or omit the property entirely. " 211 | "If provided, order is critical to match the stacking method as aggregated 'variables' dimension " 212 | "in 'dim_order' and 'shape' lists." 
213 | ), 214 | # default omission is interpreted the same as if empty list was provided, but populate it explicitly 215 | # if the user wishes to omit the property entirely, they can use `None` explicitly 216 | default=[], 217 | examples=[ 218 | [ 219 | "10m_u_component_of_wind", 220 | {"name": "10m_v_component_of_wind"}, 221 | { 222 | "name": "temperature_2m_celsius", 223 | "format": "rio-calc", 224 | "expression": "temperature_2m + 273.15", 225 | }, 226 | ], 227 | ], 228 | ) 229 | -------------------------------------------------------------------------------- /examples/item_datacube_variables.json: -------------------------------------------------------------------------------- 1 | { 2 | "$comment": "Demonstrate the use of MLM and DataCube variables description.", 3 | "stac_version": "1.0.0", 4 | "stac_extensions": [ 5 | "https://stac-extensions.github.io/mlm/v1.5.0/schema.json", 6 | "https://stac-extensions.github.io/datacube/v2.3.0/schema.json", 7 | "https://stac-extensions.github.io/file/v2.1.0/schema.json", 8 | "https://stac-extensions.github.io/scientific/v1.0.0/schema.json" 9 | ], 10 | "type": "Feature", 11 | "id": "UNet_ClimateDiffuse_ERA5_Downscaling", 12 | "collection": "ml-model-examples", 13 | "bbox": [ 14 | 233.6, 15 | 54.2, 16 | 297.5, 17 | 22.6 18 | ], 19 | "geometry": { 20 | "type": "Polygon", 21 | "coordinates": [ 22 | [ 23 | [ 24 | 233.6, 25 | 54.2 26 | ], 27 | [ 28 | 297.5, 29 | 54.2 30 | ], 31 | [ 32 | 297.5, 33 | 22.6 34 | ], 35 | [ 36 | 233.6, 37 | 22.6 38 | ], 39 | [ 40 | 233.6, 41 | 54.2 42 | ] 43 | ] 44 | ] 45 | }, 46 | "properties": { 47 | "description": "UNet model for coarse-to-fine downscaling as regression task of climate indices of ERA5 dataset.", 48 | "datetime": null, 49 | "start_datetime": "1940-01-01T00:00:00Z", 50 | "end_datetime": "2100-12-31T23:59:59Z", 51 | "mlm:name": "UNet ClimateDiffuse ERA5 Downscaling", 52 | "mlm:tasks": [ 53 | "regression", 54 | "downscaling" 55 | ], 56 | "mlm:architecture": "U-Net", 57 | "mlm:framework": "pytorch", 58 | "mlm:framework_version": "2.1.2+cu118", 59 | "mlm:accelerator": "cuda", 60 | "mlm:accelerator_constrained": false, 61 | "mlm:input": [ 62 | { 63 | "name": "Coarse climate variables employed by the model downscaling regression. This model takes 2 'constant/spatial' variables and 3 'spatio-temporal' variables provided by ERA5 datacube.", 64 | "variables": [ 65 | "land_sea_mask", 66 | "geopotential", 67 | "temperature_2m", 68 | "10m_u_component_of_wind", 69 | "10m_v_component_of_wind" 70 | ], 71 | "input": { 72 | "shape": [ 73 | -1, 74 | 5, 75 | 128, 76 | 256 77 | ], 78 | "dim_order": [ 79 | "time", 80 | "variables", 81 | "lat", 82 | "lon" 83 | ], 84 | "data_type": "float32" 85 | }, 86 | "norm_by_channel": false, 87 | "resize_type": null, 88 | "pre_processing_function": { 89 | "description": "Script that performs the relevant normalization and concatenation of variables for model input.", 90 | "format": "uri", 91 | "expression": { 92 | "href": "https://raw.githubusercontent.com/robbiewatt1/ClimateDiffuse/refs/heads/main/src/DatasetUS.py", 93 | "type": "text/x-python" 94 | } 95 | } 96 | } 97 | ], 98 | "mlm:output": [ 99 | { 100 | "name": "Fine climate variables predicted by the model. 
Only the 3 'spatio-temporal' variables are predicted.", 101 | "tasks": [ 102 | "regression", 103 | "downscaling" 104 | ], 105 | "variables": [ 106 | "temperature_2m", 107 | "10m_u_component_of_wind", 108 | "10m_v_component_of_wind" 109 | ], 110 | "result": { 111 | "shape": [ 112 | -1, 113 | 3, 114 | 128, 115 | 256 116 | ], 117 | "dim_order": [ 118 | "time", 119 | "variables", 120 | "lat", 121 | "lon" 122 | ], 123 | "data_type": "float32" 124 | } 125 | } 126 | ], 127 | "cube:dimensions": { 128 | "time": { 129 | "type": "temporal", 130 | "extent": [ 131 | "1940-01-01T00:00:00Z", 132 | "2100-12-31T23:59:59Z" 133 | ], 134 | "step": "P1H" 135 | }, 136 | "lat": { 137 | "type": "spatial", 138 | "extent": [ 139 | 54.2, 140 | 22.6 141 | ], 142 | "axis": "y" 143 | }, 144 | "lon": { 145 | "type": "spatial", 146 | "extent": [ 147 | 233.6, 148 | 297.5 149 | ], 150 | "axis": "x" 151 | } 152 | }, 153 | "cube:variables": { 154 | "land_sea_mask": { 155 | "shortname": "lsm", 156 | "description": "Proportion of land-sea, where 1 indicates land and 0 indicates sea.", 157 | "type": "data", 158 | "data_type": "float32", 159 | "extent": [ 160 | 0, 161 | 1 162 | ], 163 | "dimensions": [ 164 | "lat", 165 | "lon" 166 | ], 167 | "definition": "https://codes.ecmwf.int/grib/param-db/172" 168 | }, 169 | "geopotential": { 170 | "shortname": "z", 171 | "description": "This parameter is the geopotential, which is the potential energy per unit mass at a point in the atmosphere, expressed in metres squared per second squared (m² s⁻²). It is a measure of the height of a point in the atmosphere relative to sea level.", 172 | "type": "data", 173 | "data_type": "float32", 174 | "unit": "m² s-2", 175 | "dimensions": [ 176 | "lat", 177 | "lon" 178 | ], 179 | "definition": "https://codes.ecmwf.int/grib/param-db/129" 180 | }, 181 | "2m_temperature": { 182 | "description": "This parameter is the temperature of air at 2m above the surface of land, sea or in-land waters.", 183 | "type": "data", 184 | "data_type": "float32", 185 | "unit": "K", 186 | "dimensions": [ 187 | "time", 188 | "lat", 189 | "lon" 190 | ], 191 | "definition": "https://codes.ecmwf.int/grib/param-db/167" 192 | }, 193 | "10m_u_component_of_wind": { 194 | "shortname": "10u", 195 | "description": "This parameter is the eastward component of the 10m wind. It is the horizontal speed of air moving towards the east, at a height of ten metres above the surface of the Earth, in metres per second.", 196 | "type": "data", 197 | "data_type": "float32", 198 | "unit": "m s-1", 199 | "dimensions": [ 200 | "time", 201 | "lat", 202 | "lon" 203 | ], 204 | "definition": "https://codes.ecmwf.int/grib/param-db/165" 205 | }, 206 | "10m_v_component_of_wind": { 207 | "shortname": "10v", 208 | "description": "This parameter is the northward component of the 10m wind. It is the horizontal speed of air moving towards the north, at a height of ten metres above the surface of the Earth, in metres per second.", 209 | "type": "data", 210 | "data_type": "float32", 211 | "unit": "m s-1", 212 | "dimensions": [ 213 | "time", 214 | "lat", 215 | "lon" 216 | ], 217 | "definition": "https://codes.ecmwf.int/grib/param-db/166" 218 | } 219 | }, 220 | "sci:publications": [ 221 | { 222 | "doi": "10.48550/arXiv.2404.17752", 223 | "citation": "Robbie A. Watt, Laura A. 
Mansfield, Generative Diffusion-based Downscaling for Climate" 224 | } 225 | ] 226 | }, 227 | "assets": { 228 | "weights": { 229 | "href": "https://github.com/robbiewatt1/ClimateDiffuse/raw/refs/heads/main/Model_chpt/unet.pt", 230 | "title": "U-Net Pytorch weights checkpoint", 231 | "type": "application/octet-stream; application=pytorch", 232 | "roles": [ 233 | "mlm:model", 234 | "mlm:weights" 235 | ], 236 | "mlm:artifact_type": "torch.save", 237 | "file:size": 389657415 238 | }, 239 | "model": { 240 | "href": "https://raw.githubusercontent.com/robbiewatt1/ClimateDiffuse/refs/heads/main/src/Network.py", 241 | "title": "Model implementation.", 242 | "description": "Source code to define the U-Net model.", 243 | "type": "text/x-python", 244 | "roles": [ 245 | "mlm:source_code", 246 | "code", 247 | "metadata" 248 | ] 249 | }, 250 | "train-script": { 251 | "href": "https://raw.githubusercontent.com/robbiewatt1/ClimateDiffuse/refs/heads/main/src/TrainUnet.py", 252 | "title": "Training script.", 253 | "description": "Script to run training of the model.", 254 | "type": "text/x-python", 255 | "roles": [ 256 | "mlm:training", 257 | "code", 258 | "metadata" 259 | ] 260 | } 261 | }, 262 | "inference-script": { 263 | "href": "https://raw.githubusercontent.com/robbiewatt1/ClimateDiffuse/refs/heads/main/src/Inference.py", 264 | "title": "Inference script.", 265 | "description": "Script to run inference with the model.", 266 | "type": "text/x-python", 267 | "roles": [ 268 | "mlm:inference", 269 | "code", 270 | "metadata" 271 | ] 272 | }, 273 | "links": [ 274 | { 275 | "rel": "collection", 276 | "href": "./collection.json", 277 | "type": "application/json" 278 | }, 279 | { 280 | "rel": "self", 281 | "href": "./item_datacube_variables.json", 282 | "type": "application/geo+json" 283 | }, 284 | { 285 | "rel": "via", 286 | "href": "https://github.com/robbiewatt1/ClimateDiffuse", 287 | "type": "text/html" 288 | }, 289 | { 290 | "rel": "cite-as", 291 | "href": "https://doi.org/10.48550/arXiv.2404.17752", 292 | "type": "text/html" 293 | }, 294 | { 295 | "rel": "code", 296 | "href": "https://github.com/robbiewatt1/ClimateDiffuse", 297 | "type": "text/html" 298 | } 299 | ] 300 | } 301 | -------------------------------------------------------------------------------- /tests/torch/test_export.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | pytest.importorskip("torchgeo") 4 | 5 | import pathlib 6 | 7 | import torch 8 | import torchvision.transforms.v2 as T 9 | import yaml 10 | from torch.export.pt2_archive._package import load_pt2 11 | from torchgeo.models import Unet_Weights, unet 12 | 13 | from stac_model.base import Path 14 | from stac_model.schema import MLModelProperties 15 | from stac_model.torch.export import save 16 | 17 | 18 | class TestPT2: 19 | in_channels = 3 20 | num_classes = 2 21 | height = width = 16 22 | in_h = in_w = 8 23 | metadata_path = pathlib.Path("tests") / "torch" / "metadata.yaml" 24 | 25 | @pytest.fixture 26 | def model(self) -> torch.nn.Module: 27 | return torch.nn.Conv2d(in_channels=self.in_channels, out_channels=self.num_classes, kernel_size=1, padding=0) 28 | 29 | @pytest.fixture 30 | def transforms(self) -> torch.nn.Module: 31 | return torch.nn.Sequential(T.Resize((self.height, self.width)), T.Normalize(mean=[0.0], std=[255.0])) 32 | 33 | def validate( 34 | self, 35 | archive_path: pathlib.Path, 36 | no_transforms: bool, 37 | input_shape: list[int], 38 | device: str | torch.device, 39 | dtype: torch.dtype, 40 | ) -> None: 
41 | """Validate that pt2 is loadable and model/transform are usable.""" 42 | pt2 = load_pt2(archive_path) 43 | 44 | x = torch.randn(1, self.in_channels, self.in_h, self.in_w, device=device, dtype=dtype) 45 | 46 | # Validate AOT Inductor saving 47 | if pt2.aoti_runners != {}: 48 | model_aoti = pt2.aoti_runners["model"] 49 | preds = model_aoti(x) 50 | assert preds.shape == (1, self.num_classes, self.in_h, self.in_w) 51 | 52 | if no_transforms: 53 | assert "transforms" not in pt2.aoti_runners 54 | else: 55 | assert "transforms" in pt2.aoti_runners 56 | 57 | if "transforms" in pt2.aoti_runners: 58 | transforms_aoti = pt2.aoti_runners["transforms"] 59 | transformed = transforms_aoti(x) 60 | assert transformed.shape == (1, self.in_channels, self.height, self.width) 61 | 62 | # Validate ExportedProgram saving 63 | else: 64 | model_exported = pt2.exported_programs["model"].module() 65 | preds = model_exported(x) 66 | assert preds.shape == (1, self.num_classes, self.in_h, self.in_w) 67 | 68 | if no_transforms: 69 | assert "transforms" not in pt2.exported_programs 70 | else: 71 | assert "transforms" in pt2.exported_programs 72 | 73 | if "transforms" in pt2.exported_programs: 74 | transforms_exported = pt2.exported_programs["transforms"].module() 75 | transformed = transforms_exported(x) 76 | assert transformed.shape == (1, self.in_channels, self.height, self.width) 77 | 78 | # Validate MLM model metadata 79 | metadata = pt2.extra_files["mlm-metadata"] 80 | metadata = yaml.safe_load(metadata) 81 | assert "mlm:accelerator" not in metadata["properties"] 82 | properties = MLModelProperties(**metadata["properties"]) 83 | assert properties.input[0].input.shape == input_shape 84 | assert properties.accelerator == str(device).split(":")[0] 85 | assert properties.input[0].input.data_type == str(dtype).split(".")[-1] 86 | assert properties.output[0].result.data_type == str(dtype).split(".")[-1] 87 | 88 | @pytest.mark.parametrize("no_transforms", [True, False]) 89 | @pytest.mark.parametrize("aoti_compile_and_package", [False, True]) 90 | def test_export_model_cpu( 91 | self, 92 | tmpdir: Path, 93 | model: torch.nn.Module, 94 | transforms: torch.nn.Module, 95 | aoti_compile_and_package: bool, 96 | no_transforms: bool, 97 | ) -> None: 98 | archive_path = pathlib.Path(tmpdir) / "model.pt2" 99 | input_shape = [-1, self.in_channels, -1, -1] 100 | save( 101 | output_file=archive_path, 102 | model=model, 103 | transforms=None if no_transforms else transforms, 104 | metadata=self.metadata_path, 105 | input_shape=input_shape, 106 | device="cpu", 107 | dtype=torch.float32, 108 | aoti_compile_and_package=aoti_compile_and_package, 109 | ) 110 | self.validate( 111 | archive_path=archive_path, 112 | no_transforms=no_transforms, 113 | input_shape=input_shape, 114 | device="cpu", 115 | dtype=torch.float32, 116 | ) 117 | 118 | @pytest.mark.slow 119 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") 120 | @pytest.mark.parametrize("no_transforms", [True, False]) 121 | @pytest.mark.parametrize("aoti_compile_and_package", [False, True]) 122 | def test_export_model_cuda( 123 | self, 124 | tmpdir: Path, 125 | model: torch.nn.Module, 126 | transforms: torch.nn.Module, 127 | aoti_compile_and_package: bool, 128 | no_transforms: bool, 129 | ) -> None: 130 | archive_path = pathlib.Path(tmpdir) / "model.pt2" 131 | input_shape = [-1, self.in_channels, -1, -1] 132 | save( 133 | output_file=archive_path, 134 | model=model, 135 | transforms=None if no_transforms else transforms, 136 | 
metadata=self.metadata_path, 137 | input_shape=input_shape, 138 | device="cuda", 139 | dtype=torch.float32, 140 | aoti_compile_and_package=aoti_compile_and_package, 141 | ) 142 | self.validate( 143 | archive_path=archive_path, 144 | no_transforms=no_transforms, 145 | input_shape=input_shape, 146 | device="cuda", 147 | dtype=torch.float32, 148 | ) 149 | 150 | def test_export_mlmodelproperties( 151 | self, 152 | tmpdir: Path, 153 | model: torch.nn.Module, 154 | ) -> None: 155 | archive_path = pathlib.Path(tmpdir) / "model.pt2" 156 | input_shape = [-1, self.in_channels, -1, -1] 157 | 158 | with open(self.metadata_path) as f: 159 | metadata = yaml.safe_load(f) 160 | properties = MLModelProperties(**metadata["properties"]) 161 | 162 | save( 163 | output_file=archive_path, 164 | model=model, 165 | transforms=None, 166 | metadata=properties, 167 | input_shape=input_shape, 168 | device=torch.device("cpu"), 169 | dtype=torch.float32, 170 | aoti_compile_and_package=False, 171 | ) 172 | self.validate( 173 | archive_path=archive_path, 174 | no_transforms=True, 175 | input_shape=input_shape, 176 | device="cpu", 177 | dtype=torch.float32, 178 | ) 179 | 180 | 181 | class TestTorchGeoFTWPT2(TestPT2): 182 | in_channels = 8 183 | num_classes = 3 184 | height = width = 256 185 | in_h = in_w = 128 186 | metadata_path = pathlib.Path("tests") / "torch" / "metadata.yaml" 187 | 188 | @pytest.fixture 189 | def model(self) -> torch.nn.Module: 190 | model: torch.nn.Module = unet(weights=Unet_Weights.SENTINEL2_3CLASS_FTW) 191 | return model 192 | 193 | @pytest.fixture 194 | def transforms(self) -> torch.nn.Module: 195 | return torch.nn.Sequential(T.Resize((self.height, self.width)), T.Normalize(mean=[0.0], std=[3000.0])) 196 | 197 | @pytest.mark.slow 198 | @pytest.mark.parametrize("no_transforms", [True, False]) 199 | @pytest.mark.parametrize("aoti_compile_and_package", [False, True]) 200 | def test_export_model_cpu( 201 | self, 202 | tmpdir: Path, 203 | model: torch.nn.Module, 204 | transforms: torch.nn.Module, 205 | aoti_compile_and_package: bool, 206 | no_transforms: bool, 207 | ) -> None: 208 | archive_path = pathlib.Path(tmpdir) / "model.pt2" 209 | input_shape = [-1, self.in_channels, -1, -1] 210 | save( 211 | output_file=archive_path, 212 | model=model, 213 | transforms=None if no_transforms else transforms, 214 | metadata=self.metadata_path, 215 | input_shape=input_shape, 216 | device="cpu", 217 | dtype=torch.float32, 218 | aoti_compile_and_package=aoti_compile_and_package, 219 | ) 220 | self.validate( 221 | archive_path=archive_path, 222 | no_transforms=no_transforms, 223 | input_shape=input_shape, 224 | device="cpu", 225 | dtype=torch.float32, 226 | ) 227 | 228 | @pytest.mark.slow 229 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") 230 | @pytest.mark.parametrize("no_transforms", [True, False]) 231 | @pytest.mark.parametrize("aoti_compile_and_package", [False, True]) 232 | def test_export_model_cuda( 233 | self, 234 | tmpdir: Path, 235 | model: torch.nn.Module, 236 | transforms: torch.nn.Module, 237 | aoti_compile_and_package: bool, 238 | no_transforms: bool, 239 | ) -> None: 240 | archive_path = pathlib.Path(tmpdir) / "model.pt2" 241 | input_shape = [-1, self.in_channels, -1, -1] 242 | save( 243 | output_file=archive_path, 244 | model=model, 245 | transforms=None if no_transforms else transforms, 246 | metadata=self.metadata_path, 247 | input_shape=input_shape, 248 | device="cuda", 249 | dtype=torch.float32, 250 | aoti_compile_and_package=aoti_compile_and_package, 251 | ) 
252 | self.validate( 253 | archive_path=archive_path, 254 | no_transforms=no_transforms, 255 | input_shape=input_shape, 256 | device="cuda", 257 | dtype=torch.float32, 258 | ) 259 | -------------------------------------------------------------------------------- /stac_model/torch/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import inspect 3 | import os 4 | import pathlib 5 | import tempfile 6 | import zipfile 7 | from typing import cast 8 | 9 | import kornia.augmentation as K 10 | import torch 11 | import yaml 12 | 13 | from ..base import DataType, Path 14 | from ..input import ValueScalingObject, ValueScalingZScore 15 | from ..runtime import AcceleratorName 16 | from ..schema import SCHEMA_URI, MLModelProperties 17 | from .base import AOTIFiles 18 | 19 | 20 | def normalize_dtype(torch_dtype: torch.dtype) -> DataType: 21 | """ 22 | Convert a PyTorch dtype (e.g., torch.float32) to a standardized DataType. 23 | """ 24 | return cast(DataType, str(torch_dtype).rsplit(".", 1)[-1]) 25 | 26 | 27 | def find_tensor_by_key(state_dict: dict[str, torch.Tensor], key_substring: str, reverse: bool = False) -> torch.Tensor: 28 | """ 29 | Find a tensor in the state_dict by a substring in its key. 30 | If `reverse` is True, search from the end of the dictionary. 31 | """ 32 | items = reversed(state_dict.items()) if reverse else state_dict.items() 33 | for key, tensor in items: 34 | if key_substring in key: 35 | return tensor 36 | raise ValueError(f"Could not find tensor with key containing '{key_substring}'") 37 | 38 | 39 | def get_input_hw(state_dict: dict[str, torch.Tensor]) -> tuple[int, int]: 40 | tensor = find_tensor_by_key(state_dict, "encoder._conv_stem.weight") 41 | return tensor.shape[2], tensor.shape[3] 42 | 43 | 44 | def get_input_dtype(state_dict: dict[str, torch.Tensor]) -> DataType: 45 | """ 46 | Get the data type (dtype) of the input from the first convolutional layer's weights. 47 | """ 48 | tensor = find_tensor_by_key(state_dict, "encoder._conv_stem.weight") 49 | return normalize_dtype(tensor.dtype) 50 | 51 | 52 | def get_output_dtype(state_dict: dict[str, torch.Tensor]) -> DataType: 53 | """ 54 | Get the data type (dtype) of the output from the segmentation head's last conv layer. 55 | """ 56 | tensor = find_tensor_by_key(state_dict, "segmentation_head.0.weight", reverse=True) 57 | return normalize_dtype(tensor.dtype) 58 | 59 | 60 | def get_input_channels(state_dict: dict[str, torch.Tensor]) -> int: 61 | """ 62 | Get number of input channels from the first convolutional layer's weights. 63 | """ 64 | tensor = find_tensor_by_key(state_dict, "encoder._conv_stem.weight") 65 | return int(tensor.shape[1]) 66 | 67 | 68 | def get_output_channels(state_dict: dict[str, torch.Tensor]) -> int: 69 | """ 70 | Get number of output channels from the segmentation head's last conv layer. 71 | """ 72 | tensor = find_tensor_by_key(state_dict, "segmentation_head.0.weight", reverse=True) 73 | return int(tensor.shape[0]) 74 | 75 | 76 | def extract_value_scaling(transforms: K.AugmentationSequential) -> list[ValueScalingObject]: 77 | """ 78 | Extracts value scaling definitions from a Kornia AugmentationSequential object. 79 | """ 80 | children = list(transforms.children()) 81 | 82 | def _tensor_to_value(tensor: torch.Tensor) -> int: 83 | """ 84 | Convert a tensor to an int. 
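        Only the first element of the flattened tensor is used and truncated to ``int``,
        on the assumption that a single scaling value is shared across all bands.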
85 | """ 86 | flat_tensor = tensor.view(-1) 87 | return int(flat_tensor[0]) 88 | 89 | scaling_defs: list[ValueScalingObject] = [] 90 | 91 | for t in children: 92 | if isinstance(t, K.Normalize): 93 | buffers = dict(t.named_buffers()) if hasattr(t, "named_buffers") else {} 94 | mean = buffers.get("mean") 95 | stddev = buffers.get("std") 96 | 97 | if mean is None or stddev is None: 98 | flags = getattr(t, "flags", {}) 99 | mean = mean or flags.get("mean") 100 | stddev = stddev or flags.get("std") 101 | 102 | if mean is None or stddev is None: 103 | raise AttributeError("Normalize transform missing mean/std info") 104 | 105 | scaling_defs.append(ValueScalingZScore(mean=_tensor_to_value(mean), stddev=_tensor_to_value(stddev))) 106 | 107 | elif isinstance(t, K.AugmentationSequential): 108 | scaling_defs.extend(extract_value_scaling(t)) 109 | 110 | return scaling_defs 111 | 112 | 113 | def extract_module_arg_names(module: torch.nn.Module) -> str: 114 | """Extracts the argument names of the forward method of a given module. 115 | 116 | Args: 117 | module: The PyTorch module from which to extract argument names. 118 | 119 | Returns: 120 | A list of argument names for the forward method of the module. 121 | 122 | Raises: 123 | ValueError: If the module does not have a forward method. 124 | """ 125 | if not hasattr(module, "forward"): 126 | raise ValueError("The provided module does not have a forward method.") 127 | 128 | return next(iter(inspect.signature(module.forward).parameters)) 129 | 130 | 131 | def aoti_compile_and_extract(program: torch.export.ExportedProgram, output_directory: Path) -> list[Path]: 132 | """Compiles an exported program using AOTI and extracts the files to the specified directory. 133 | 134 | Args: 135 | program: The exported program to compile. 136 | output_directory: The directory where the compiled files will be extracted. 137 | 138 | Returns: 139 | A list of file paths extracted from the compiled package. 140 | """ 141 | with tempfile.TemporaryDirectory() as tmpdir: 142 | path = torch._inductor.aoti_compile_and_package(program, package_path=os.path.join(tmpdir, "file.pt2")) 143 | 144 | with zipfile.ZipFile(path, "r") as zip_ref: 145 | zip_ref.extractall(output_directory) 146 | 147 | return [ 148 | cast(os.PathLike[str], f) 149 | for f in glob.glob(os.path.join(output_directory, "**"), recursive=True) 150 | if os.path.isfile(f) 151 | ] 152 | 153 | 154 | def aoti_compile( 155 | model_directory: Path, 156 | model_program: torch.export.ExportedProgram, 157 | transforms_directory: Path | None = None, 158 | transforms_program: torch.export.ExportedProgram | None = None, 159 | ) -> AOTIFiles: 160 | """Compiles a model and its transforms using AOTI. 161 | 162 | Args: 163 | model_directory: The directory to store the compiled model files. 164 | model_program: The exported model program. 165 | transforms_directory: The directory to store the compiled transforms files. 166 | transforms_program: The exported transforms program. 
167 | """ 168 | model_files = aoti_compile_and_extract( 169 | program=model_program, 170 | output_directory=model_directory, 171 | ) 172 | aoti_files: AOTIFiles = {"model": model_files} 173 | 174 | if transforms_program is not None and transforms_directory is not None: 175 | transforms_files = aoti_compile_and_extract( 176 | program=transforms_program, 177 | output_directory=transforms_directory, 178 | ) 179 | aoti_files["transforms"] = transforms_files 180 | 181 | return aoti_files 182 | 183 | 184 | def create_example_input_from_shape(input_shape: list[int]) -> torch.Tensor: 185 | """Creates an example input tensor based on the provided input shape. 186 | 187 | If batch dimension is dynamic (-1), it defaults to 2. Other dynamic dimensions 188 | default to 224. Ideally all dimensions are defined by the user but this provides good defaults. 189 | 190 | Args: 191 | input_shape: The shape of the input tensor. 192 | 193 | Returns: 194 | A tensor filled with random values, shaped according to the input_shape. 195 | 196 | Raises: 197 | ValueError: If all dimensions are dynamic (-1). 198 | """ 199 | shape = [] 200 | 201 | if all(dim == -1 for dim in input_shape): 202 | raise ValueError("Input shape cannot be all dynamic (-1). At least one dimension must be fixed.") 203 | elif any(dim != -1 for dim in input_shape): 204 | batch_dim = 2 if input_shape[0] == -1 else input_shape[0] 205 | shape.append(batch_dim) 206 | shape.extend([dim if dim != -1 else 224 for dim in input_shape[1:]]) 207 | else: 208 | shape = list(input_shape) 209 | 210 | return torch.randn(*shape, requires_grad=False) 211 | 212 | 213 | def model_properties_to_metadata(properties: MLModelProperties) -> str: 214 | """Converts MLModelProperties to a metadata dictionary in YAML format. 215 | 216 | Args: 217 | properties: An instance of MLModelProperties containing model metadata. 218 | 219 | Returns: 220 | A YAML string representation of the model properties. 221 | """ 222 | properties_dict = properties.model_dump(by_alias=False, exclude_none=True) 223 | properties_dict = {k.replace("mlm:", ""): v for k, v in properties_dict.items()} 224 | return yaml.dump( 225 | { 226 | "$schema": SCHEMA_URI, 227 | "properties": properties_dict, 228 | }, 229 | default_flow_style=False, 230 | ) 231 | 232 | 233 | def update_properties( 234 | metadata: Path | MLModelProperties, input_shape: list[int], device: str | torch.device, dtype: torch.dtype 235 | ) -> MLModelProperties: 236 | """Updates the MLModelProperties with the given metadata, device, and dtype. 237 | 238 | Args: 239 | metadata: Path to the YAML file containing model metadata or an instance of MLModelProperties. 240 | input_shape: The shape of the input tensor, where -1 indicates a dynamic dimension. 241 | device: The device to export the model and transforms to. 242 | dtype: The data type to use for the model and transforms. 243 | 244 | Returns: 245 | An instance of MLModelProperties with updated properties. 246 | 247 | Raises: 248 | ValidationError: if the metadata is not valid MLModelProperties. 249 | TypeError: if metadata is not a path to a YAML file or an instance of MLModelProperties. 
250 | """ 251 | if isinstance(metadata, pathlib.Path | str): 252 | with open(metadata) as f: 253 | meta = yaml.safe_load(f) 254 | properties = MLModelProperties(**meta["properties"]) 255 | elif isinstance(metadata, MLModelProperties): 256 | properties = metadata 257 | else: 258 | raise TypeError("Metadata must be a path to a YAML file or an instance of MLModelProperties.") 259 | 260 | accelerator = cast(AcceleratorName, str(device).split(":")[0]) 261 | data_type = cast(DataType, str(dtype).split(".")[-1]) 262 | 263 | properties.accelerator = accelerator 264 | properties.input[0].input.shape = input_shape # type: ignore[assignment] 265 | properties.input[0].input.data_type = data_type 266 | properties.output[0].result.data_type = data_type 267 | 268 | return properties 269 | -------------------------------------------------------------------------------- /stac_model/schema.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections.abc import Iterable 3 | from typing import ( 4 | TYPE_CHECKING, 5 | Annotated, 6 | Any, 7 | Generic, 8 | Literal, 9 | TypeVar, 10 | Union, 11 | cast, 12 | get_args, 13 | overload, 14 | ) 15 | 16 | import pystac 17 | from pydantic import ConfigDict, Field 18 | from pydantic.fields import FieldInfo 19 | from pystac.extensions.base import ( 20 | ExtensionManagementMixin, 21 | PropertiesExtension, 22 | SummariesExtension, 23 | ) 24 | 25 | from stac_model.base import ModelTask, OmitIfNone 26 | from stac_model.input import ModelInput 27 | from stac_model.output import ModelOutput 28 | from stac_model.runtime import Runtime 29 | 30 | if TYPE_CHECKING: 31 | import torch.nn as nn 32 | 33 | T = TypeVar( 34 | "T", 35 | pystac.Collection, 36 | pystac.Item, 37 | pystac.Asset, # item_assets.AssetDefinition, 38 | ) 39 | 40 | SchemaName = Literal["mlm"] 41 | SCHEMA_URI: str = "https://stac-extensions.github.io/mlm/v1.5.0/schema.json" 42 | PREFIX = f"{get_args(SchemaName)[0]}:" 43 | 44 | 45 | def mlm_prefix_adder(field_name: str) -> str: 46 | return "mlm:" + field_name 47 | 48 | 49 | class MLModelProperties(Runtime): 50 | name: str = Field(min_length=1) 51 | architecture: str = Field(min_length=1) 52 | tasks: set[ModelTask] 53 | input: list[ModelInput] 54 | output: list[ModelOutput] 55 | 56 | total_parameters: Annotated[int | None, OmitIfNone] = Field(default=None, ge=0) 57 | pretrained: Annotated[bool | None, OmitIfNone] = Field(default=True) 58 | pretrained_source: Annotated[str | None, OmitIfNone] = None 59 | 60 | model_config = ConfigDict(alias_generator=mlm_prefix_adder, populate_by_name=True, extra="ignore") 61 | 62 | 63 | class MLModelExtension( 64 | Generic[T], 65 | PropertiesExtension, 66 | # FIXME: resolve typing incompatibility? 67 | # 'pystac.Asset' does not derive from STACObject 68 | # therefore, it technically cannot be used in 'ExtensionManagementMixin[T]' 69 | # however, this makes our extension definition much easier and avoids lots of code duplication 70 | ExtensionManagementMixin[ # type: ignore[type-var] 71 | Union[ 72 | pystac.Collection, 73 | pystac.Item, 74 | pystac.Asset, 75 | ] 76 | ], 77 | ): 78 | @property 79 | def name(self) -> SchemaName: 80 | return cast(SchemaName, get_args(SchemaName)[0]) 81 | 82 | def apply( 83 | self, 84 | properties: MLModelProperties | dict[str, Any], 85 | ) -> None: 86 | """ 87 | Applies Machine Learning Model Extension properties to the extended :mod:`~pystac` object. 
88 | """ 89 | if isinstance(properties, dict): 90 | properties = MLModelProperties(**properties) 91 | data_json = json.loads(properties.model_dump_json(by_alias=True)) 92 | for prop, val in data_json.items(): 93 | self._set_property(prop, val) 94 | 95 | @classmethod 96 | def get_schema_uri(cls) -> str: 97 | return SCHEMA_URI 98 | 99 | @overload 100 | @classmethod 101 | def ext(cls, obj: pystac.Asset, add_if_missing: bool = False) -> "AssetMLModelExtension": ... 102 | 103 | @overload 104 | @classmethod 105 | def ext(cls, obj: pystac.Item, add_if_missing: bool = False) -> "ItemMLModelExtension": ... 106 | 107 | @overload 108 | @classmethod 109 | def ext(cls, obj: pystac.Collection, add_if_missing: bool = False) -> "CollectionMLModelExtension": ... 110 | 111 | # @overload 112 | # @classmethod 113 | # def ext(cls, obj: item_assets.AssetDefinition, add_if_missing: bool = False) -> "ItemAssetsMLModelExtension": 114 | # ... 115 | 116 | @classmethod 117 | def ext( 118 | cls, 119 | obj: pystac.Collection | pystac.Item | pystac.Asset, # item_assets.AssetDefinition 120 | add_if_missing: bool = False, 121 | ) -> Union[ 122 | "CollectionMLModelExtension", 123 | "ItemMLModelExtension", 124 | "AssetMLModelExtension", 125 | ]: 126 | """ 127 | Extends the given STAC Object with properties from the :stac-ext:`Machine Learning Model Extension `. 128 | 129 | This extension can be applied to instances of :class:`~pystac.Item` or :class:`~pystac.Asset`. 130 | 131 | Args: 132 | obj: STAC Object to extend with the MLM extension fields. 133 | add_if_missing: Add the MLM extension schema URI to the object if not already in `stac_extensions`. 134 | 135 | Returns: 136 | Extended object. 137 | 138 | Raises: 139 | pystac.ExtensionTypeError: If an invalid object type is passed. 140 | """ 141 | if isinstance(obj, pystac.Collection): 142 | cls.ensure_has_extension(obj, add_if_missing) 143 | return CollectionMLModelExtension(obj) 144 | elif isinstance(obj, pystac.Item): 145 | cls.ensure_has_extension(obj, add_if_missing) 146 | return ItemMLModelExtension(obj) 147 | elif isinstance(obj, pystac.Asset): 148 | cls.ensure_owner_has_extension(obj, add_if_missing) 149 | return AssetMLModelExtension(obj) 150 | # elif isinstance(obj, item_assets.AssetDefinition): 151 | # cls.ensure_owner_has_extension(obj, add_if_missing) 152 | # return ItemAssetsMLModelExtension(obj) 153 | else: 154 | raise pystac.ExtensionTypeError(cls._ext_error_message(obj)) 155 | 156 | @classmethod 157 | def summaries(cls, obj: pystac.Collection, add_if_missing: bool = False) -> "SummariesMLModelExtension": 158 | """Returns the extended summaries object for the given collection.""" 159 | cls.ensure_has_extension(obj, add_if_missing) 160 | return SummariesMLModelExtension(obj) 161 | 162 | @classmethod 163 | def from_torch(cls, model: "nn.Module", **kwargs: Any) -> "ItemMLModelExtension": 164 | from stac_model.torch import from_torch 165 | 166 | return from_torch(model, **kwargs) 167 | 168 | 169 | class SummariesMLModelExtension(SummariesExtension): 170 | """ 171 | Summaries annotated with the Machine Learning Model Extension. 172 | 173 | A concrete implementation of :class:`~SummariesExtension` that extends 174 | the ``summaries`` field of a :class:`~pystac.Collection` to include properties 175 | defined in the :stac-ext:`Machine Learning Model `. 
176 | """ 177 | 178 | def _check_mlm_property(self, prop: str) -> FieldInfo: 179 | try: 180 | return MLModelProperties.model_fields[prop] 181 | except KeyError as err: 182 | raise AttributeError(f"Name '{prop}' is not a valid MLM property.") from err 183 | 184 | def _validate_mlm_property(self, prop: str, summaries: list[Any]) -> None: 185 | # ignore mypy issue when combined with Annotated 186 | # - https://github.com/pydantic/pydantic/issues/6713 187 | # - https://github.com/pydantic/pydantic/issues/5190 188 | model = MLModelProperties.model_construct() # type: ignore[call-arg] 189 | validator = MLModelProperties.__pydantic_validator__ 190 | for value in summaries: 191 | validator.validate_assignment(model, prop, value) 192 | 193 | def get_mlm_property(self, prop: str) -> list[Any] | None: 194 | self._check_mlm_property(prop) 195 | return self.summaries.get_list(prop) 196 | 197 | def set_mlm_property(self, prop: str, summaries: list[Any]) -> None: 198 | self._check_mlm_property(prop) 199 | self._validate_mlm_property(prop, summaries) 200 | self._set_summary(prop, summaries) 201 | 202 | def __getattr__(self, prop): 203 | return self.get_mlm_property(prop) 204 | 205 | def __setattr__(self, prop, value): 206 | self.set_mlm_property(prop, value) 207 | 208 | 209 | class ItemMLModelExtension(MLModelExtension[pystac.Item]): 210 | """ 211 | Item annotated with the Machine Learning Model Extension. 212 | 213 | A concrete implementation of :class:`MLModelExtension` on an 214 | :class:`~pystac.Item` that extends the properties of the Item to 215 | include properties defined in the :stac-ext:`Machine Learning Model 216 | Extension `. 217 | 218 | This class should generally not be instantiated directly. Instead, call 219 | :meth:`MLModelExtension.ext` on an :class:`~pystac.Item` to extend it. 220 | """ 221 | 222 | def __init__(self, item: pystac.Item): 223 | self.item = item 224 | self.properties = item.properties 225 | 226 | def __repr__(self) -> str: 227 | return f"" 228 | 229 | 230 | # class ItemAssetsMLModelExtension(MLModelExtension[item_assets.AssetDefinition]): 231 | # properties: dict[str, Any] 232 | # asset_defn: item_assets.AssetDefinition 233 | # 234 | # def __init__(self, item_asset: item_assets.AssetDefinition): 235 | # self.asset_defn = item_asset 236 | # self.properties = item_asset.properties 237 | 238 | 239 | class AssetMLModelExtension(MLModelExtension[pystac.Asset]): 240 | """ 241 | Asset annotated with the Machine Learning Model Extension. 242 | 243 | A concrete implementation of :class:`MLModelExtension` on an 244 | :class:`~pystac.Asset` that extends the Asset fields to include 245 | properties defined in the :stac-ext:`Machine Learning Model 246 | Extension `. 247 | 248 | This class should generally not be instantiated directly. Instead, call 249 | :meth:`MLModelExtension.ext` on an :class:`~pystac.Asset` to extend it. 
250 | """ 251 | 252 | asset_href: str 253 | """The ``href`` value of the :class:`~pystac.Asset` being extended.""" 254 | 255 | properties: dict[str, Any] 256 | """The :class:`~pystac.Asset` fields, including extension properties.""" 257 | 258 | additional_read_properties: Iterable[dict[str, Any]] | None = None 259 | """If present, this will be a list containing 1 dictionary representing the 260 | properties of the owning :class:`~pystac.Item`.""" 261 | 262 | def __init__(self, asset: pystac.Asset): 263 | self.asset_href = asset.href 264 | self.properties = asset.extra_fields 265 | if asset.owner and isinstance(asset.owner, pystac.Item): 266 | self.additional_read_properties = [asset.owner.properties] 267 | 268 | def __repr__(self) -> str: 269 | return f"" 270 | 271 | 272 | class CollectionMLModelExtension(MLModelExtension[pystac.Collection]): 273 | def __init__(self, collection: pystac.Collection): 274 | self.collection = collection 275 | 276 | 277 | # __all__ = [ 278 | # "MLModelExtension", 279 | # "ModelInput", 280 | # "InputArray", 281 | # "Band", 282 | # "Statistics", 283 | # "ModelOutput", 284 | # "Asset", 285 | # "Runtime", 286 | # "Container", 287 | # "Asset", 288 | # ] 289 | -------------------------------------------------------------------------------- /stac_model/examples.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | import pystac 4 | import shapely 5 | from dateutil.parser import parse as parse_dt 6 | from pystac.extensions.eo import Band, EOExtension 7 | from pystac.extensions.file import FileExtension 8 | 9 | from stac_model.base import ProcessingExpression, TaskEnum 10 | from stac_model.input import InputStructure, ModelInput, ValueScalingObject 11 | from stac_model.output import MLMClassification, ModelOutput, ModelResult 12 | from stac_model.schema import ItemMLModelExtension, MLModelExtension, MLModelProperties 13 | 14 | 15 | def eurosat_resnet() -> ItemMLModelExtension: 16 | input_struct = InputStructure( 17 | shape=[-1, 13, 64, 64], 18 | dim_order=["batch", "bands", "height", "width"], 19 | data_type="float32", 20 | ) 21 | band_names = [ 22 | "B01", 23 | "B02", 24 | "B03", 25 | "B04", 26 | "B05", 27 | "B06", 28 | "B07", 29 | "B08", 30 | "B8A", 31 | "B09", 32 | "B10", 33 | "B11", 34 | "B12", 35 | ] 36 | stats_mean = [ 37 | 1354.40546513, 38 | 1118.24399958, 39 | 1042.92983953, 40 | 947.62620298, 41 | 1199.47283961, 42 | 1999.79090914, 43 | 2369.22292565, 44 | 2296.82608323, 45 | 732.08340178, 46 | 12.11327804, 47 | 1819.01027855, 48 | 1118.92391149, 49 | 2594.14080798, 50 | ] 51 | stats_stddev = [ 52 | 245.71762908, 53 | 333.00778264, 54 | 395.09249139, 55 | 593.75055589, 56 | 566.4170017, 57 | 861.18399006, 58 | 1086.63139075, 59 | 1117.98170791, 60 | 404.91978886, 61 | 4.77584468, 62 | 1002.58768311, 63 | 761.30323499, 64 | 1231.58581042, 65 | ] 66 | value_scaling = [ 67 | cast( 68 | ValueScalingObject, 69 | dict( 70 | type="z-score", 71 | mean=mean, 72 | stddev=stddev, 73 | ), 74 | ) 75 | for mean, stddev in zip(stats_mean, stats_stddev, strict=False) 76 | ] 77 | model_input = ModelInput( 78 | name="13 Band Sentinel-2 Batch", 79 | bands=band_names, 80 | input=input_struct, 81 | resize_type=None, 82 | value_scaling=value_scaling, 83 | pre_processing_function=ProcessingExpression( 84 | format="python", 85 | expression="torchgeo.datamodules.eurosat.EuroSATDataModule.collate_fn", 86 | ), # noqa: E501 87 | ) 88 | result_struct = ModelResult( 89 | shape=[-1, 10], 90 | dim_order=["batch", "class"], 91 
| data_type="float32", 92 | ) 93 | class_map = { 94 | "Annual Crop": 0, 95 | "Forest": 1, 96 | "Herbaceous Vegetation": 2, 97 | "Highway": 3, 98 | "Industrial Buildings": 4, 99 | "Pasture": 5, 100 | "Permanent Crop": 6, 101 | "Residential Buildings": 7, 102 | "River": 8, 103 | "SeaLake": 9, 104 | } 105 | class_objects = [ 106 | MLMClassification( 107 | value=class_value, 108 | name=class_name, 109 | ) 110 | for class_name, class_value in class_map.items() 111 | ] 112 | model_output = ModelOutput( 113 | name="classification", 114 | tasks={"classification"}, 115 | classes=class_objects, 116 | result=result_struct, 117 | post_processing_function=None, 118 | ) 119 | assets = { 120 | "model": pystac.Asset( 121 | title="Pytorch weights checkpoint", 122 | description=( 123 | "A Resnet-18 classification model trained on normalized Sentinel-2 " 124 | "imagery with Eurosat landcover labels with torchgeo." 125 | ), 126 | href="https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/resolve/main/resnet18_sentinel2_all_moco-59bfdff9.pth", 127 | media_type="application/octet-stream; application=pytorch", 128 | roles=[ 129 | "mlm:model", 130 | "mlm:weights", 131 | "data", 132 | ], 133 | extra_fields={"mlm:artifact_type": "torch.save"}, 134 | ), 135 | "source_code": pystac.Asset( 136 | title="Model implementation.", 137 | description="Source code to run the model.", 138 | href="https://github.com/microsoft/torchgeo/blob/61efd2e2c4df7ebe3bd03002ebbaeaa3cfe9885a/torchgeo/models/resnet.py#L207", 139 | media_type="text/x-python", 140 | roles=[ 141 | "mlm:source_code", 142 | "code", 143 | ], 144 | ), 145 | } 146 | 147 | ml_model_size = 43000000 148 | ml_model_meta = MLModelProperties( 149 | name="Resnet-18 Sentinel-2 ALL MOCO", 150 | architecture="ResNet-18", 151 | tasks={"classification"}, 152 | framework="pytorch", 153 | framework_version="2.1.2+cu121", 154 | accelerator="cuda", 155 | accelerator_constrained=False, 156 | accelerator_summary="Unknown", 157 | file_size=ml_model_size, 158 | memory_size=1, 159 | pretrained=True, 160 | pretrained_source="EuroSat Sentinel-2", 161 | total_parameters=11_700_000, 162 | input=[model_input], 163 | output=[model_output], 164 | ) 165 | # TODO, this can't be serialized but pystac.item calls for a datetime 166 | # in docs. start_datetime=datetime.strptime("1900-01-01", "%Y-%m-%d") 167 | # Is this a problem that we don't do date validation if we supply as str? 168 | start_datetime_str = "1900-01-01" 169 | end_datetime_str = "9999-01-01" # cannot be None, invalid against STAC Core! 
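    # The parsed timestamps are suffixed with "Z" to mark them as UTC, yielding
    # RFC 3339 strings (e.g. "1900-01-01T00:00:00Z") as expected by STAC Core.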
170 | start_datetime = parse_dt(start_datetime_str).isoformat() + "Z" 171 | end_datetime = parse_dt(end_datetime_str).isoformat() + "Z" 172 | bbox = [ 173 | -7.882190080512502, 174 | 37.13739173208318, 175 | 27.911651652899923, 176 | 58.21798141355221, 177 | ] 178 | geometry = shapely.geometry.Polygon.from_bounds(*bbox).__geo_interface__ 179 | item_name = "eurosat-resnet-mlm-example" 180 | col_name = "ml-model-examples" 181 | item = pystac.Item( 182 | id=item_name, 183 | collection=col_name, 184 | geometry=geometry, 185 | bbox=bbox, 186 | datetime=None, 187 | properties={ 188 | "start_datetime": start_datetime, 189 | "end_datetime": end_datetime, 190 | "description": "Sourced from torchgeo python library, identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO", 191 | }, 192 | assets=assets, 193 | ) 194 | 195 | # note: cannot use 'item.add_derived_from' since it expects a 'Item' object, but we refer to a 'Collection' here 196 | # item.add_derived_from("https://earth-search.aws.element84.com/v1/collections/sentinel-2-l2a") 197 | item.add_link( 198 | pystac.Link( 199 | target="https://earth-search.aws.element84.com/v1/collections/sentinel-2-l2a", 200 | rel=pystac.RelType.DERIVED_FROM, 201 | media_type=pystac.MediaType.JSON, 202 | ) 203 | ) 204 | 205 | # define more link references 206 | col = pystac.Collection( 207 | id=col_name, 208 | title="Machine Learning Model examples", 209 | description="Collection of items contained in the Machine Learning Model examples.", 210 | extent=pystac.Extent( 211 | temporal=pystac.TemporalExtent([[parse_dt(start_datetime), parse_dt(end_datetime)]]), 212 | spatial=pystac.SpatialExtent([bbox]), 213 | ), 214 | ) 215 | col.set_self_href("./examples/collection.json") 216 | col.add_item(item) 217 | item.set_self_href(f"./examples/{item_name}.json") 218 | 219 | model_asset = cast( 220 | FileExtension[pystac.Asset], 221 | FileExtension.ext(assets["model"], add_if_missing=True), 222 | ) 223 | model_asset.apply(size=ml_model_size) 224 | 225 | eo_model_asset = cast( 226 | EOExtension[pystac.Asset], 227 | EOExtension.ext(assets["model"], add_if_missing=True), 228 | ) 229 | # NOTE: 230 | # typically, it is recommended to add as much details as possible for the band description 231 | # minimally, the names (which are well-known for sentinel-2) are sufficient 232 | eo_bands = [] 233 | for name in band_names: 234 | band = Band({}) 235 | band.apply(name=name) 236 | eo_bands.append(band) 237 | eo_model_asset.apply(bands=eo_bands) 238 | 239 | item_mlm = MLModelExtension.ext(item, add_if_missing=True) 240 | item_mlm.apply(ml_model_meta.model_dump(by_alias=True, exclude_unset=True, exclude_defaults=True)) 241 | return item_mlm 242 | 243 | 244 | def unet_mlm() -> ItemMLModelExtension: # pragma: has-torch 245 | """ 246 | Example of a UNet model using PyTorchGeo SENTINEL2_2CLASS_NC_FTW default weights. 247 | 248 | Returns an ItemMLModelExtension with Machine Learning Model Extension metadata. 
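    Note:
        Requires the optional ``torch`` dependencies (for example, installing the
        ``stac-model[torch]`` extra on Python 3.11+), since the weights are loaded
        through ``torchgeo``.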
249 | """ 250 | from torchgeo.models import Unet_Weights, unet 251 | # Set the STAC version to 1.0.0 for compatibility with the example using relative links 252 | pystac.set_stac_version("1.0.0") 253 | 254 | weights = Unet_Weights.SENTINEL2_2CLASS_NC_FTW 255 | model = unet(weights=weights) 256 | item_id = "pytorch_geo_unet" 257 | collection_id = "ml-model-examples" 258 | bbox = [-7.88, 37.13, 27.91, 58.21] 259 | geometry = { 260 | "type": "Polygon", 261 | "coordinates": [ 262 | [ 263 | [-7.88, 37.13], 264 | [-7.88, 58.21], 265 | [27.91, 58.21], 266 | [27.91, 37.13], 267 | [-7.88, 37.13], 268 | ] 269 | ], 270 | } 271 | datetime_range: tuple[str, str] = ( 272 | "2015-06-23T00:00:00Z", # Sentinel-2A launch date (first Sentinel-2 data available) 273 | "2024-08-27T23:59:59Z", # Dataset publication date Fields of The World (FTW) 274 | ) 275 | 276 | task = {TaskEnum.SEMANTIC_SEGMENTATION} 277 | 278 | properties = { 279 | "description": "STAC item generated using unet_mlm() in stac_model/examples.py example. " 280 | "Specified in https://github.com/fieldsoftheworld/ftw-baselines " 281 | "First 4 S2 bands are for image t1 and last 4 bands are for image t2", 282 | } 283 | 284 | item_ext = MLModelExtension.from_torch( 285 | model, 286 | task=task, 287 | weights=weights, 288 | item_id=item_id, 289 | collection=collection_id, 290 | bbox=bbox, 291 | geometry=geometry, 292 | datetime_range=datetime_range, 293 | stac_properties=properties, 294 | ) 295 | 296 | # Add additional metadata regarding the examples to have a valid STAC Item 297 | item = item_ext.item 298 | item_name = f"item_{item_id}.json" 299 | item_self_href = f"./{item_name}" 300 | 301 | link = pystac.Link(rel="self", target=item_self_href, media_type=pystac.MediaType.GEOJSON) 302 | link._target_href = item_self_href 303 | item.add_link(link) 304 | item.add_link(pystac.Link(rel="collection", target="./collection.json", media_type=pystac.MediaType.JSON)) 305 | 306 | item_ext.item = item 307 | return item_ext 308 | -------------------------------------------------------------------------------- /examples/item_raster_bands.json: -------------------------------------------------------------------------------- 1 | { 2 | "stac_version": "1.0.0", 3 | "stac_extensions": [ 4 | "https://stac-extensions.github.io/mlm/v1.5.0/schema.json", 5 | "https://stac-extensions.github.io/raster/v1.1.0/schema.json", 6 | "https://stac-extensions.github.io/file/v1.0.0/schema.json", 7 | "https://stac-extensions.github.io/ml-aoi/v0.2.0/schema.json" 8 | ], 9 | "type": "Feature", 10 | "id": "resnet-18_sentinel-2_all_moco_classification", 11 | "collection": "ml-model-examples", 12 | "geometry": { 13 | "type": "Polygon", 14 | "coordinates": [ 15 | [ 16 | [ 17 | -7.882190080512502, 18 | 37.13739173208318 19 | ], 20 | [ 21 | -7.882190080512502, 22 | 58.21798141355221 23 | ], 24 | [ 25 | 27.911651652899923, 26 | 58.21798141355221 27 | ], 28 | [ 29 | 27.911651652899923, 30 | 37.13739173208318 31 | ], 32 | [ 33 | -7.882190080512502, 34 | 37.13739173208318 35 | ] 36 | ] 37 | ] 38 | }, 39 | "bbox": [ 40 | -7.882190080512502, 41 | 37.13739173208318, 42 | 27.911651652899923, 43 | 58.21798141355221 44 | ], 45 | "properties": { 46 | "description": "Sourced from torchgeo python library, identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO", 47 | "datetime": null, 48 | "start_datetime": "1900-01-01T00:00:00Z", 49 | "end_datetime": "9999-12-31T23:59:59Z", 50 | "mlm:name": "Resnet-18 Sentinel-2 ALL MOCO", 51 | "mlm:tasks": [ 52 | "classification" 53 | ], 54 | "mlm:architecture": "ResNet", 55 
| "mlm:framework": "pytorch", 56 | "mlm:framework_version": "2.1.2+cu121", 57 | "file:size": 43000000, 58 | "mlm:memory_size": 1, 59 | "mlm:total_parameters": 11700000, 60 | "mlm:pretrained_source": "EuroSat Sentinel-2", 61 | "mlm:accelerator": "cuda", 62 | "mlm:accelerator_constrained": false, 63 | "mlm:accelerator_summary": "Unknown", 64 | "mlm:batch_size_suggestion": 256, 65 | "mlm:input": [ 66 | { 67 | "name": "13 Band Sentinel-2 Batch", 68 | "bands": [ 69 | "B01", 70 | "B02", 71 | "B03", 72 | "B04", 73 | "B05", 74 | "B06", 75 | "B07", 76 | "B08", 77 | "B8A", 78 | "B09", 79 | "B10", 80 | "B11", 81 | "B12" 82 | ], 83 | "input": { 84 | "shape": [ 85 | -1, 86 | 13, 87 | 64, 88 | 64 89 | ], 90 | "dim_order": [ 91 | "batch", 92 | "bands", 93 | "height", 94 | "width" 95 | ], 96 | "data_type": "float32" 97 | }, 98 | "value_scaling": null, 99 | "resize_type": null, 100 | "pre_processing_function": { 101 | "format": "python", 102 | "expression": "torchgeo.datamodules.eurosat.EuroSATDataModule.collate_fn" 103 | } 104 | } 105 | ], 106 | "mlm:output": [ 107 | { 108 | "name": "classification", 109 | "tasks": [ 110 | "classification" 111 | ], 112 | "result": { 113 | "shape": [ 114 | -1, 115 | 10 116 | ], 117 | "dim_order": [ 118 | "batch", 119 | "class" 120 | ], 121 | "data_type": "float32" 122 | }, 123 | "classification_classes": [ 124 | { 125 | "value": 0, 126 | "name": "Annual Crop", 127 | "description": null, 128 | "title": null, 129 | "color_hint": null, 130 | "nodata": false 131 | }, 132 | { 133 | "value": 1, 134 | "name": "Forest", 135 | "description": null, 136 | "title": null, 137 | "color_hint": null, 138 | "nodata": false 139 | }, 140 | { 141 | "value": 2, 142 | "name": "Herbaceous Vegetation", 143 | "description": null, 144 | "title": null, 145 | "color_hint": null, 146 | "nodata": false 147 | }, 148 | { 149 | "value": 3, 150 | "name": "Highway", 151 | "description": null, 152 | "title": null, 153 | "color_hint": null, 154 | "nodata": false 155 | }, 156 | { 157 | "value": 4, 158 | "name": "Industrial Buildings", 159 | "description": null, 160 | "title": null, 161 | "color_hint": null, 162 | "nodata": false 163 | }, 164 | { 165 | "value": 5, 166 | "name": "Pasture", 167 | "description": null, 168 | "title": null, 169 | "color_hint": null, 170 | "nodata": false 171 | }, 172 | { 173 | "value": 6, 174 | "name": "Permanent Crop", 175 | "description": null, 176 | "title": null, 177 | "color_hint": null, 178 | "nodata": false 179 | }, 180 | { 181 | "value": 7, 182 | "name": "Residential Buildings", 183 | "description": null, 184 | "title": null, 185 | "color_hint": null, 186 | "nodata": false 187 | }, 188 | { 189 | "value": 8, 190 | "name": "River", 191 | "description": null, 192 | "title": null, 193 | "color_hint": null, 194 | "nodata": false 195 | }, 196 | { 197 | "value": 9, 198 | "name": "SeaLake", 199 | "description": null, 200 | "title": null, 201 | "color_hint": null, 202 | "nodata": false 203 | } 204 | ], 205 | "post_processing_function": null 206 | } 207 | ] 208 | }, 209 | "assets": { 210 | "weights": { 211 | "href": "https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/resolve/main/resnet18_sentinel2_all_moco-59bfdff9.pth", 212 | "title": "Pytorch weights checkpoint", 213 | "description": "A Resnet-18 classification model trained on normalized Sentinel-2 imagery with Eurosat landcover labels with torchgeo", 214 | "type": "application/octet-stream; application=pytorch", 215 | "roles": [ 216 | "mlm:model", 217 | "mlm:weights" 218 | ], 219 | "mlm:artifact_type": "torch.save", 220 
| "raster:bands": [ 221 | { 222 | "name": "B01", 223 | "nodata": 0, 224 | "data_type": "uint16", 225 | "bits_per_sample": 15, 226 | "spatial_resolution": 60, 227 | "scale": 0.0001, 228 | "offset": 0, 229 | "unit": "m" 230 | }, 231 | { 232 | "name": "B02", 233 | "nodata": 0, 234 | "data_type": "uint16", 235 | "bits_per_sample": 15, 236 | "spatial_resolution": 10, 237 | "scale": 0.0001, 238 | "offset": 0, 239 | "unit": "m" 240 | }, 241 | { 242 | "name": "B03", 243 | "nodata": 0, 244 | "data_type": "uint16", 245 | "bits_per_sample": 15, 246 | "spatial_resolution": 10, 247 | "scale": 0.0001, 248 | "offset": 0, 249 | "unit": "m" 250 | }, 251 | { 252 | "name": "B04", 253 | "nodata": 0, 254 | "data_type": "uint16", 255 | "bits_per_sample": 15, 256 | "spatial_resolution": 10, 257 | "scale": 0.0001, 258 | "offset": 0, 259 | "unit": "m" 260 | }, 261 | { 262 | "name": "B05", 263 | "nodata": 0, 264 | "data_type": "uint16", 265 | "bits_per_sample": 15, 266 | "spatial_resolution": 20, 267 | "scale": 0.0001, 268 | "offset": 0, 269 | "unit": "m" 270 | }, 271 | { 272 | "name": "B06", 273 | "nodata": 0, 274 | "data_type": "uint16", 275 | "bits_per_sample": 15, 276 | "spatial_resolution": 20, 277 | "scale": 0.0001, 278 | "offset": 0, 279 | "unit": "m" 280 | }, 281 | { 282 | "name": "B07", 283 | "nodata": 0, 284 | "data_type": "uint16", 285 | "bits_per_sample": 15, 286 | "spatial_resolution": 20, 287 | "scale": 0.0001, 288 | "offset": 0, 289 | "unit": "m" 290 | }, 291 | { 292 | "name": "B08", 293 | "nodata": 0, 294 | "data_type": "uint16", 295 | "bits_per_sample": 15, 296 | "spatial_resolution": 10, 297 | "scale": 0.0001, 298 | "offset": 0, 299 | "unit": "m" 300 | }, 301 | { 302 | "name": "B8A", 303 | "nodata": 0, 304 | "data_type": "uint16", 305 | "bits_per_sample": 15, 306 | "spatial_resolution": 20, 307 | "scale": 0.0001, 308 | "offset": 0, 309 | "unit": "m" 310 | }, 311 | { 312 | "name": "B09", 313 | "nodata": 0, 314 | "data_type": "uint16", 315 | "bits_per_sample": 15, 316 | "spatial_resolution": 60, 317 | "scale": 0.0001, 318 | "offset": 0, 319 | "unit": "m" 320 | }, 321 | { 322 | "name": "B10", 323 | "nodata": 0, 324 | "data_type": "uint16", 325 | "bits_per_sample": 15, 326 | "spatial_resolution": 60, 327 | "scale": 0.0001, 328 | "offset": 0, 329 | "unit": "m" 330 | }, 331 | { 332 | "name": "B11", 333 | "nodata": 0, 334 | "data_type": "uint16", 335 | "bits_per_sample": 15, 336 | "spatial_resolution": 20, 337 | "scale": 0.0001, 338 | "offset": 0, 339 | "unit": "m" 340 | }, 341 | { 342 | "name": "B12", 343 | "nodata": 0, 344 | "data_type": "uint16", 345 | "bits_per_sample": 15, 346 | "spatial_resolution": 20, 347 | "scale": 0.0001, 348 | "offset": 0, 349 | "unit": "m" 350 | } 351 | ] 352 | }, 353 | "source_code": { 354 | "href": "https://github.com/microsoft/torchgeo/blob/61efd2e2c4df7ebe3bd03002ebbaeaa3cfe9885a/torchgeo/models/resnet.py#L207", 355 | "title": "Model implementation.", 356 | "description": "Source code to run the model.", 357 | "type": "text/x-python", 358 | "roles": [ 359 | "mlm:source_code", 360 | "code", 361 | "metadata" 362 | ] 363 | } 364 | }, 365 | "links": [ 366 | { 367 | "rel": "collection", 368 | "href": "./collection.json", 369 | "type": "application/json" 370 | }, 371 | { 372 | "rel": "self", 373 | "href": "./item_raster_bands.json", 374 | "type": "application/geo+json" 375 | }, 376 | { 377 | "rel": "derived_from", 378 | "href": "https://earth-search.aws.element84.com/v1/collections/sentinel-2-l2a", 379 | "type": "application/json", 380 | "ml-aoi:split": "train" 381 | } 382 | ] 
383 | } 384 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["pdm-backend"] 3 | build-backend = "pdm.backend" 4 | 5 | [tool.pdm.build] 6 | includes = ["stac_model"] 7 | 8 | [project] 9 | authors = [ 10 | { name = "Ryan Avery", email = "ryan@wherobots.com" }, 11 | { name = "Francis Charette-Migneault", email = "francis.charette-migneault@crim.ca" }, 12 | ] 13 | license = { text = "Apache Software License 2.0" } 14 | requires-python = "<4.0,>=3.10" 15 | dependencies = [ 16 | "typer<1.0.0,>=0.9.0", 17 | "rich>=13.7.0,<15.0.0", 18 | "pydantic<3.0.0,>=2.6.3", 19 | "pydantic-core<3,>=2", 20 | "pystac<2.0.0,>=1.9.0", 21 | "shapely<3,>=2", 22 | "jsonschema<5.0.0,>=4.21.1", 23 | "pip>=25.0.0", 24 | ] 25 | optional-dependencies.torch = [ 26 | "torch==2.8.0; python_version>='3.11'", 27 | "torchgeo; python_version>='3.11'", 28 | "torchvision>=0.21,<1; python_version>='3.11'" 29 | ] 30 | optional-dependencies.torch-cu126 = [ 31 | "pytorch-triton; python_version>='3.11'", 32 | "torch==2.8.0; python_version>='3.11'", 33 | "torchgeo; python_version>='3.11'", 34 | "torchvision>=0.21,<1; python_version>='3.11'", 35 | ] 36 | 37 | # important: leave the name and version together for bump resolution 38 | name = "stac-model" 39 | version = "0.4.0" 40 | description = "A PydanticV2 validation and serialization libary for the STAC ML Model Extension" 41 | readme = "README_STAC_MODEL.md" 42 | keywords = [ 43 | "STAC", 44 | "SpatioTemporal Asset Catalog", 45 | "Machine Learning Model", 46 | "Artificial Intelligence", 47 | ] 48 | classifiers = [ 49 | "Development Status :: 4 - Beta", 50 | "Operating System :: OS Independent", 51 | "Topic :: Software Development :: Libraries :: Python Modules", 52 | "License :: OSI Approved :: Apache Software License", 53 | "Programming Language :: Python :: 3", 54 | "Programming Language :: Python :: 3.10", 55 | "Programming Language :: Python :: 3.11", 56 | "Programming Language :: Python :: 3.12", 57 | "Programming Language :: Python :: 3.13", 58 | "Programming Language :: Python :: 3 :: Only", 59 | "Framework :: Pydantic", 60 | "Framework :: Pydantic :: 2", 61 | "Intended Audience :: Developers", 62 | "Intended Audience :: Information Technology", 63 | "Intended Audience :: Science/Research", 64 | "Topic :: File Formats :: JSON :: JSON Schema", 65 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 66 | "Topic :: Scientific/Engineering :: GIS", 67 | "Topic :: Scientific/Engineering :: Image Processing", 68 | "Topic :: Scientific/Engineering :: Image Recognition", 69 | ] 70 | 71 | [tool.uv] 72 | dev-dependencies = [ 73 | "setuptools>=78.1.1", 74 | "mypy>=1.10.0,<2.0.0", 75 | "mypy-extensions>=0.4.3,<1.2.0", 76 | "pre-commit<3.0.0,>=2.21.0", 77 | "bandit<2.0.0,>=1.7.5", 78 | "safety>=3.6.1,<4.0.0", 79 | "typer>=0.17.0", # fix for https://github.com/pyupio/safety/issues/784 80 | "pystac<2.0.0,>=1.10.0", 81 | "pydocstyle[toml]<7.0.0,>=6.2.0", 82 | "pydoclint<0.6,>=0.3", 83 | "pytest<8.0.0,>=7.2.1", 84 | "pytest-cov<5.0.0,>=4.1.0", 85 | "pytest-mock<4.0.0,>=3.10.0", 86 | "pytest-timeout<3.0.0,>=2.2.0", 87 | "pytest-benchmark<5.0.0,>=4.0.0", 88 | "pytest-sugar<1.0.0,>=0.9.7", 89 | "pytest-click<2.0.0,>=1.1.0", 90 | "pytest-pikachu<2.0.0,>=1.0.0", 91 | "coverage<8.0.0,>=7.3.0", 92 | "ruff<1.0.0,>=0.2.2", 93 | "bump-my-version>=1.2.1", 94 | "types-python-dateutil>=2.9.0.20241003", 95 | 
"types-pyyaml>=6.0.12.20250516", 96 | "requests>=2.32.4", 97 | "urllib3>=2.5.0,<3.0.0", 98 | "httpx<1", 99 | "referencing>=0.36.2", 100 | "cffconvert", 101 | "coverage-conditional-plugin>=0.9.0", 102 | "authlib>=1.6.5", 103 | ] 104 | 105 | conflicts = [[{ extra = "torch" }, { extra = "torch-cu126" }]] 106 | prerelease = "allow" 107 | 108 | [[tool.uv.index]] 109 | name = "pytorch-cpu" 110 | url = "https://download.pytorch.org/whl/cpu" 111 | explicit = true 112 | 113 | [[tool.uv.index]] 114 | name = "pytorch-cu126" 115 | url = "https://download.pytorch.org/whl/cu126" 116 | explicit = true 117 | 118 | [tool.uv.sources] 119 | torch = [ 120 | { index = "pytorch-cpu", extra = "torch" }, 121 | { index = "pytorch-cu126", extra = "torch-cu126" }, 122 | ] 123 | torchvision = [ 124 | { index = "pytorch-cpu", extra = "torch" }, 125 | { index = "pytorch-cu126", extra = "torch-cu126" }, 126 | ] 127 | pytorch-triton = [{ index = "pytorch-cu126", extra = "torch-cu126" }] 128 | torchgeo = { git = "https://github.com/microsoft/torchgeo", rev = "95dba5c43a2828c04a6b0a316d1996c5d1c6b9c5" } 129 | cffconvert = { git = "https://github.com/citation-file-format/cffconvert.git", rev = "b6045d78aac9e02b039703b030588d54d53262ac" } 130 | 131 | [project.urls] 132 | homepage = "https://github.com/stac-extensions/mlm/blob/main/README_STAC_MODEL.md" 133 | repository = "https://github.com/crim-ca/mlm-extension" 134 | 135 | [project.scripts] 136 | stac-model = "stac_model.__main__:app" 137 | 138 | [tool.bumpversion] 139 | # NOTE: 140 | # Although these definitions are provided in this 'stac-model' project file, 141 | # they are actually intented for versioning the MLM specification itself. 142 | # To version 'stac-model', use the 'bump-my-version bump' operation using the 'stac-model.bump.toml' file. 
143 | # See also https://github.com/stac-extensions/mlm/blob/main/CONTRIBUTING.md#building-and-releasing 144 | current_version = "1.5.0" 145 | parse = "(?P\\d+)\\.(?P\\d+)\\.(?P\\d+)" 146 | serialize = ["{major}.{minor}.{patch}"] 147 | search = "{current_version}" 148 | replace = "{new_version}" 149 | regex = false 150 | ignore_missing_version = true 151 | ignore_missing_files = false 152 | tag = true 153 | sign_tags = false 154 | tag_name = "v{new_version}" 155 | tag_message = "Bump version: {current_version} → {new_version}" 156 | allow_dirty = false 157 | commit = true 158 | commit_args = "--no-verify" 159 | message = "Bump version: {current_version} → {new_version}" 160 | 161 | [[tool.bumpversion.files]] 162 | glob = "**/*.@(json|yml|yaml|py|md)" 163 | glob_exclude = [ 164 | ".git/**", 165 | "**/__pycache__/**", 166 | ".mypy_cache/**", 167 | ".tox/**", 168 | ".venv/**", 169 | "_build/**", 170 | "build/**", 171 | "dist/**", 172 | "node_modules/**", 173 | "htmlcov/**", 174 | "package-lock.json", 175 | ] 176 | search = "https://stac-extensions.github.io/mlm/v{current_version}/schema.json" 177 | replace = "https://stac-extensions.github.io/mlm/v{new_version}/schema.json" 178 | 179 | 180 | [[tool.bumpversion.files]] 181 | filename = "CHANGELOG.md" 182 | search = """ 183 | ## [Unreleased](https://github.com/stac-extensions/mlm/tree/main) 184 | """ 185 | replace = """ 186 | ## [Unreleased](https://github.com/stac-extensions/mlm/tree/main) 187 | 188 | ### Added 189 | 190 | - n/a 191 | 192 | ### Changed 193 | 194 | - n/a 195 | 196 | ### Deprecated 197 | 198 | - n/a 199 | 200 | ### Removed 201 | 202 | - n/a 203 | 204 | ### Fixed 205 | 206 | - n/a 207 | 208 | ## [v{new_version}](https://github.com/stac-extensions/mlm/tree/v{new_version}) 209 | """ 210 | 211 | [[tool.bumpversion.files]] 212 | filename = "CITATION.cff" 213 | search = "https://stac-extensions.github.io/mlm/v{current_version}/schema.json" 214 | replace = "https://stac-extensions.github.io/mlm/v{new_version}/schema.json" 215 | 216 | [[tool.bumpversion.files]] 217 | filename = "package.json" 218 | search = "\"version\": \"{current_version}\"" 219 | replace = "\"version\": \"{new_version}\"" 220 | 221 | [tool.ruff] 222 | exclude = [ 223 | ".git", 224 | "__pycache__", 225 | ".mypy_cache", 226 | ".tox", 227 | ".venv", 228 | "_build", 229 | "buck-out", 230 | "build", 231 | "dist", 232 | "env", 233 | "venv", 234 | "node_modules", 235 | ] 236 | respect-gitignore = true 237 | line-length = 120 238 | show-fixes = true 239 | 240 | [tool.ruff.lint] 241 | select = [ 242 | # pycodestyle 243 | "E", 244 | # Pyflakes 245 | "F", 246 | # pyupgrade 247 | "UP", 248 | # flake8-bugbear 249 | "B", 250 | # flake8-simplify 251 | "SIM", 252 | # isort 253 | "I", 254 | ] 255 | ignore = ["UP007", "UP015", "E501"] 256 | 257 | [tool.ruff.lint.isort] 258 | known-local-folder = ["tests", "conftest"] 259 | known-first-party = ["stac_model"] 260 | known-third-party = ["torch"] 261 | extra-standard-library = ["typing_extensions"] 262 | 263 | [tool.mypy] 264 | # https://github.com/python/mypy 265 | # https://mypy.readthedocs.io/en/latest/config_file.html#using-a-pyproject-toml-file 266 | python_version = "3.10" 267 | pretty = true 268 | show_traceback = true 269 | color_output = true 270 | exclude = '(^\\.venv/|site-packages/)' 271 | files = [ 272 | "stac_model", 273 | "tests", 274 | ] 275 | allow_redefinition = false 276 | check_untyped_defs = true 277 | disallow_any_generics = true 278 | disallow_incomplete_defs = true 279 | ignore_missing_imports = true 280 | 
implicit_reexport = false 281 | no_implicit_optional = true 282 | show_column_numbers = true 283 | show_error_codes = true 284 | show_error_context = true 285 | strict_equality = true 286 | strict_optional = true 287 | warn_no_return = true 288 | warn_redundant_casts = true 289 | warn_return_any = true 290 | warn_unreachable = true 291 | warn_unused_configs = true 292 | warn_unused_ignores = true 293 | 294 | plugins = ["pydantic.mypy"] 295 | 296 | [tool.pydantic-mypy] 297 | init_forbid_extra = true 298 | init_typed = true 299 | warn_required_dynamic_aliases = true 300 | 301 | [tool.pydocstyle] 302 | # https://github.com/PyCQA/pydocstyle 303 | # http://www.pydocstyle.org/en/stable/usage.html#available-options 304 | convention = "google" 305 | match_dir = "^(stac_model|tests)" 306 | # ignore missing documentation, just validate provided ones 307 | add_ignore = "D100,D101,D102,D103,D104,D105,D107,D200,D202,D204,D212,D401" 308 | 309 | [tool.pydoclint] 310 | # https://github.com/jsh9/pydoclint 311 | # https://jsh9.github.io/pydoclint/how_to_config.html 312 | style = "google" 313 | exclude = '\.git|\.hg|\.mypy_cache|\.tox|.?v?env|__pycache__|_build|buck-out|dist|node_modules' 314 | # don't require type hints, since we have them in the signature instead (don't duplicate) 315 | arg-type-hints-in-docstring = false 316 | arg-type-hints-in-signature = true 317 | check-return-types = false 318 | 319 | [tool.pytest.ini_options] 320 | # https://github.com/pytest-dev/pytest 321 | # https://docs.pytest.org/en/6.2.x/customize.html#pyproject-toml 322 | # Directories that are not visited by pytest collector: 323 | norecursedirs = [ 324 | "hooks", 325 | "*.egg", 326 | ".eggs", 327 | "dist", 328 | "build", 329 | "docs", 330 | ".tox", 331 | ".git", 332 | "__pycache__", 333 | "node_modules", 334 | ] 335 | markers = ["slow: mark test as slow to run (more than a few minutes)"] 336 | doctest_optionflags = [ 337 | "NUMBER", 338 | "NORMALIZE_WHITESPACE", 339 | "IGNORE_EXCEPTION_DETAIL", 340 | ] 341 | timeout = 1000 342 | 343 | # Extra options: 344 | addopts = [ 345 | "-m not slow", 346 | "--strict-markers", 347 | "--tb=short", 348 | "--doctest-modules", 349 | "--doctest-continue-on-failure", 350 | "--pikachu", 351 | ] 352 | 353 | [tool.coverage.run] 354 | source = ["tests"] 355 | branch = true 356 | plugins = ["coverage_conditional_plugin"] 357 | 358 | [tool.coverage.coverage_conditional_plugin.omit] 359 | "not is_installed('torch')" = "stac_model/torch/*.py" 360 | 361 | [tool.coverage.coverage_conditional_plugin.rules] 362 | has-torch = "is_installed('torch')" 363 | 364 | [tool.coverage.report] 365 | exclude_also = [ 366 | "def main", 367 | "if __name__ == .__main__.:" 368 | ] 369 | fail_under = 80 370 | show_missing = true 371 | 372 | [tool.coverage.paths] 373 | source = ["stac_model"] 374 | -------------------------------------------------------------------------------- /docs/legacy/ml-model.md: -------------------------------------------------------------------------------- 1 | # ML Model Extension Specification 2 | 3 | 4 | 5 | > [!WARNING] 6 | > This is legacy documentation reference of [ML-Model][ml-model] 7 | > preceding the current Machine Learning Model ([MLM][mlm-spec]) extension. 8 | 9 | 10 | 11 | ## Notable Differences 12 | 13 | - The [MLM][mlm-spec] extension covers more details at both the [Item](#item-properties) and [Asset](#asset-objects) 14 | levels, making it easier to describe and use model metadata. 
15 | 16 | - The [MLM][mlm-spec] extension covers runtime requirements using distinct [Asset Roles](#roles) 17 | ([Model][mlm-asset-model], [Container][mlm-asset-container] and [Source Code][mlm-asset-code]) which allows 18 | for more flexibility in describing how and which operations are performed by a given model. 19 | This is in contrast to the [ML-Model][ml-model] extension that records [similar information][ml-model-runtimes] 20 | in `ml-model:inference-runtime` or `ml-model:training-runtime` __*all at once*__, which leads to runtime ambiguities 21 | and limited reusability. 22 | 23 | - The [MLM][mlm-spec] extension provides additional fields to better describe the model properties, such as 24 | the [Model Inputs][mlm-inputs] to describe the input features, bands, data transforms, or any 25 | other relevant data sources and preparation steps required by the model, the [Model Outputs][mlm-outputs] to describe 26 | the output predictions, regression values, classes or other relevant information about what the model produces, and 27 | the [Model Hyperparameters][mlm-hyperparam] to better describe training configuration 28 | that lead to the model definition. All of these fields are __*undefined*__ in the [ML-Model][ml-model] extension. 29 | 30 | - The [MLM][mlm-spec] extension has a corresponding Python library [`stac-model`][mlm-stac-model], 31 | which can be used to create and validate MLM metadata using [pydantic][pydantic]. 32 | An example of the library in action is [provided in examples](./../../stac_model/examples.py). 33 | The extension also provides [pystac MLM][pystac-mlm] for easier integration with the STAC ecosystem. 34 | The [MLM Form Filler][mlm-form] is also available to help users create valid MLM metadata in a no-code fashion. 35 | In contrast, [ML-Model][ml-model] extension does not provide any support for Python integration and requires the JSON 36 | to be written manually. 37 | 38 | ## Migration Tables 39 | 40 | Following are the corresponding fields between the legacy [ML-Model][ml-model] and the current [MLM][mlm-spec] 41 | extension, which can be used to completely migrate to the newer *Machine Leaning Model* extension providing 42 | enhanced features and interconnectivity with other STAC extensions (see also [Best Practices][mlm-bp]). 43 | 44 | 45 | 46 | > [!IMPORTANT] 47 | > Only the limited set of [`ml-model`][ml-model] fields are listed below for migration guidelines. 48 | > See the full [MLM Specification](./../../README.md) for all additional fields provided to further describe models. 49 | 50 | 51 | 52 | ### Item Properties 53 | 54 | | ML-Model Field | MLM Field | Migration Details | 55 | |----------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| 56 | | `ml-model:type`
(`"ml-model"` constant) | *n/a* | Including the MLM URI in `stac_extensions` is sufficient to indicate that the Item is a model. | 57 | | `ml-model:learning_approach` | *n/a* | No direct mapping. Machine Learning training approaches can be very convoluted to describe. Instead, it is recommended to employ `derived_from` collection and other STAC Extension references to describe explicitly how the model was obtained. See [Best Practices][mlm-bp] for more details. | 58 | | `ml-model:prediction_type`
(`string`) | `mlm:tasks`
(`[string]`) | ML-Model limited to a single task. MLM allows multiple. Use `[""]` to migrate directly. | 59 | | `ml-model:architecture` | `mlm:architecture` | Direct mapping. | 60 | | `ml-model:training-processor-type`
`ml-model:training-os` | `mlm:framework`
`mlm:framework_version`
`mlm:accelerator`
`mlm:accelerator_constrained`
`mlm:accelerator_summary`
`mlm:accelerator_count` | More fields are provided to describe the subtleties of compute hardware and ML frameworks that can be intricated between them. If compute hardware imposes OS dependencies, they are typically reflected through the framework version and/or the specific accelerator. Further subtleties are permitted with [complex accelerator values][mlm-acc-type]. | 61 | 62 | ### Asset Objects 63 | 64 | #### Roles 65 | 66 | All [ML-Model Asset Roles](https://github.com/stac-extensions/ml-model/blob/main/README.md#roles) 67 | are available with a prefix change with the same sematic meaning. 68 | 69 | Further roles are also proposed in [MLM Asset Roles](./../../README.md#mlm-asset-roles). 70 | 71 | | ML-Model Field | MLM Field | Migration Details | 72 | |------------------------------|-------------------------|--------------------------------------------------------------------------------------------------------------------------------------------| 73 | | `ml-model:inference-runtime` | `mlm:inference-runtime` | Prefix change. | 74 | | `ml-model:training-runtime` | `mlm:training-runtime` | Prefix change. | 75 | | `ml-model:checkpoint` | `mlm:checkpoint` | Prefix change. Recommended addition of further `mlm` properties for [Model Asset](./../../README.md#model-asset) to describe the artifact. | 76 | 77 | 78 | 79 | > [!TIP] 80 | > In the context of [ML-Model][ml-model], Assets providing [Inference/Training Runtimes][ml-model-runtimes] 81 | > are strictly provided as [Docker Compose][docker-compose-file] definitions. While this is still permitted, 82 | > the MLM extension offers alternatives using any relevant definition for the model, as long as it is properly 83 | > identified by its applicable media-type. Additional recommendations and Asset property fields are provided 84 | > under [MLM Assets Objects](./../../README.md#assets-objects) for specific cases. 85 | 86 | 87 | 88 | ### Relation Types 89 | 90 | | ML-Model Field | MLM Field | Migration Details | 91 | |-------------------------------------------------|----------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| 92 | | `ml-model:inferencing-image` | *n/a* | Deemed redundant with `mlm:inference-runtime` Asset Role. | 93 | | `ml-model:training-image` | *n/a* | Deemed redundant with `mlm:training-runtime` Asset Role. | 94 | | `ml-model:train-data`
95 | 
96 | [mlm-acc-type]: ./../../README.md#accelerator-type-enum 
97 | 
98 | [mlm-asset-model]: ./../../README.md#model-asset 
99 | 
100 | [mlm-asset-container]: ./../../README.md#container-asset 
101 | 
102 | [mlm-asset-code]: ./../../README.md#source-code-asset 
103 | 
104 | [mlm-inputs]: ./../../README.md#model-input-object 
105 | 
106 | [mlm-outputs]: ./../../README.md#model-output-object 
107 | 
108 | [mlm-hyperparam]: ./../../README.md#model-hyperparameters-object 
109 | 
110 | [mlm-stac-model]: https://pypi.org/project/stac-model/ 
111 | 
112 | [mlm-form]: https://mlm-form.vercel.app/ 
113 | 
114 | [mlm-spec]: ./../../README.md 
115 | 
116 | [mlm-bp]: ./../../best-practices.md 
117 | 
118 | [mlm-ml-aoi]: ./../../best-practices.md#ml-aoi-and-label-extensions 
119 | 
120 | [ml-model]: https://github.com/stac-extensions/ml-model 
121 | 
122 | [ml-model-runtimes]: https://github.com/stac-extensions/ml-model/blob/main/README.md#inferencetraining-runtimes 
123 | 
124 | [pydantic]: https://docs.pydantic.dev/latest/ 
125 | 
126 | [pystac-mlm]: https://github.com/stac-utils/pystac/blob/main/pystac/extensions/mlm.py 
127 | 
128 | [docker-compose-file]: https://github.com/compose-spec/compose-spec/blob/master/spec.md#compose-file 
129 | 
-------------------------------------------------------------------------------- 
/LICENSE: 
-------------------------------------------------------------------------------- 
1 | Apache License 
2 | Version 2.0, January 2004 
3 | http://www.apache.org/licenses/ 
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 
6 | 
7 | 1. Definitions. 
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction, 
10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by 
13 | the copyright owner that is granting the License. 
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all 
16 | other entities that control, are controlled by, or are under common 
17 | control with that entity. For the purposes of this definition, 
18 | "control" means (i) the power, direct or indirect, to cause the 
19 | direction or management of such entity, whether by contract or 
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 
21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity 
24 | exercising permissions granted by this License. 
25 | 
26 | "Source" form shall mean the preferred form for making modifications, 
27 | including but not limited to software source code, documentation 
28 | source, and configuration files. 
29 | 
30 | "Object" form shall mean any form resulting from mechanical 
31 | transformation or translation of a Source form, including but 
32 | not limited to compiled object code, generated documentation, 
33 | and conversions to other media types. 
34 | 
35 | "Work" shall mean the work of authorship, whether in Source or 
36 | Object form, made available under the License, as indicated by a 
37 | copyright notice that is included in or attached to the work 
38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: If you use this standard or software, please cite it using the metadata from this file. 
3 | title: Machine Learning Model Extension Specification for SpatioTemporal Asset Catalog 4 | type: software 5 | keywords: 6 | - mlm 7 | - Machine Learning 8 | - Model 9 | - STAC 10 | url: "https://github.com/stac-extensions/mlm/blob/main/README.md" 11 | repository-code: "https://github.com/stac-extensions/mlm" 12 | license: Apache-2.0 13 | license-url: https://github.com/stac-extensions/mlm/blob/main/LICENSE 14 | identifiers: 15 | - type: doi 16 | value: "10.1145/3681769.3698586" 17 | description: "Conference paper presenting the standard." 18 | - type: url 19 | value: "https://stac-extensions.github.io/mlm/" 20 | description: "Generic URL of the MLM extension schema versions for 'stac_extensions' references." 21 | contact: 22 | - given-names: Francis 23 | family-names: Charette-Migneault 24 | email: francis.charette-migneault@crim.ca 25 | affiliation: Computer Research Institute of Montréal (CRIM) 26 | orcid: "https://orcid.org/0000-0003-4862-3349" 27 | - given-names: Ryan 28 | family-names: Avery 29 | alias: rbavery 30 | email: ryan@wherobots.com 31 | affiliation: "Wherobots, Inc." 32 | orcid: "https://orcid.org/0000-0001-7392-1474" 33 | authors: &authors 34 | - given-names: Francis 35 | family-names: Charette-Migneault 36 | alias: fmigneault 37 | email: francis.charette-migneault@crim.ca 38 | affiliation: Computer Research Institute of Montréal (CRIM) 39 | orcid: "https://orcid.org/0000-0003-4862-3349" 40 | - given-names: Ryan 41 | family-names: Avery 42 | alias: rbavery 43 | email: ryan@wherobots.com 44 | affiliation: "Wherobots, Inc." 45 | orcid: "https://orcid.org/0000-0001-7392-1474" 46 | - &crim 47 | name: Computer Research Institute of Montréal 48 | city: Montréal 49 | region: Québec 50 | alias: CRIM 51 | website: "https://www.crim.ca/" 52 | email: info@crim.ca 53 | tel: 1 (514) 840-1234 54 | country: CA 55 | post-code: H3N 1M3 56 | address: "101 – 405, avenue Ogilvy" 57 | - name: "Wherobots, Inc." 58 | address: 350 California St 59 | city: San Francisco 60 | country: US 61 | post-code: "94104" 62 | region: California 63 | website: "https://www.wherobots.ai/" 64 | location: Floor 1 - Lincoln Towne Center 65 | 66 | references: 67 | - type: software-code 68 | title: "A PydanticV2 and PySTAC validation and serialization library for the STAC ML Model Extension" 69 | keywords: 70 | - stac_model 71 | repository-code: "https://github.com/stac-extensions/mlm/tree/main/stac_model" 72 | repository-artifact: "https://pypi.org/project/stac-model/" 73 | url: "https://github.com/stac-extensions/mlm/blob/main/README_STAC_MODEL.md" 74 | authors: 75 | - given-names: Ryan 76 | family-names: Avery 77 | alias: rbavery 78 | email: ryan@wherobots.com 79 | affiliation: "Wherobots, Inc." 80 | orcid: "https://orcid.org/0000-0001-7392-1474" 81 | - given-names: Francis 82 | family-names: Charette-Migneault 83 | alias: fmigneault 84 | email: francis.charette-migneault@crim.ca 85 | affiliation: Computer Research Institute of Montréal (CRIM) 86 | orcid: "https://orcid.org/0000-0003-4862-3349" 87 | 88 | - type: standard 89 | title: STAC MLM specification 90 | authors: *authors 91 | identifiers: 92 | - type: url 93 | value: "https://stac-extensions.github.io/mlm/v1.5.0/schema.json" 94 | description: "Latest extension URL used in 'stac_extensions' references." 95 | - type: url 96 | value: "https://stac-extensions.github.io/mlm/" 97 | description: "Generic URL of the MLM extension schema versions for 'stac_extensions' references." 
98 | 99 | - type: software-code 100 | title: "Archive repository of the STAC MLM specification." 101 | repository-code: "https://github.com/crim-ca/mlm-extension" 102 | authors: *authors 103 | identifiers: 104 | - type: url 105 | value: "https://crim-ca.github.io/mlm-extension/v1.3.0/schema.json" 106 | description: "Archive extension URL used in 'stac_extensions' references." 107 | - type: url 108 | value: "https://crim-ca.github.io/mlm-extension/" 109 | description: "Generic URL of the archived MLM extension schema versions for 'stac_extensions' references." 110 | 111 | - type: report 112 | title: Project CCCOT03 – Technical Report 113 | abstract: "Project CCCOT03: Proposal for a STAC Extension for Deep Learning Models" 114 | keywords: 115 | - dlm 116 | - Deep Learning 117 | - Model 118 | - STAC 119 | repository: "https://raw.githubusercontent.com/crim-ca/CCCOT03/main/CCCOT03_Rapport%20Final_FINAL_EN.pdf" 120 | repository-code: "https://github.com/crim-ca/dlm-extension" 121 | license: Apache-2.0 122 | license-url: https://github.com/crim-ca/dlm-extension/blob/main/LICENSE 123 | date-released: "2020-12-14" 124 | languages: 125 | - en 126 | doi: "10.13140/RG.2.2.27858.68804" 127 | url: "https://www.researchgate.net/publication/349003427" 128 | institution: *crim 129 | authors: 130 | - given-names: Francis 131 | family-names: Charette-Migneault 132 | alias: fmigneault 133 | email: francis.charette-migneault@crim.ca 134 | affiliation: Computer Research Institute of Montréal (CRIM) 135 | orcid: "https://orcid.org/0000-0003-4862-3349" 136 | - given-names: Samuel 137 | family-names: Foucher 138 | alias: sfoucher 139 | orcid: "https://orcid.org/0000-0001-9557-6907" 140 | - given-names: David 141 | family-names: Landry 142 | orcid: "https://orcid.org/0000-0001-5343-2235" 143 | - given-names: Yves 144 | family-names: Moisan 145 | alias: ymoisan 146 | - name: Computer Research Institute of Montréal 147 | city: Montréal 148 | region: Québec 149 | alias: CRIM 150 | website: "https://www.crim.ca/" 151 | email: info@crim.ca 152 | tel: 1 (514) 840-1234 153 | country: CA 154 | post-code: H3N 1M3 155 | address: "101 – 405, avenue Ogilvy" 156 | - name: "Natural Resources Canada" 157 | country: CA 158 | website: "https://natural-resources.canada.ca/" 159 | - name: "Canada Centre for Mapping and Earth Observation" 160 | alias: CCMEO 161 | country: CA 162 | website: "https://natural-resources.canada.ca/research-centres-and-labs/canada-centre-for-mapping-and-earth-observation/25735" 163 | 164 | - type: conference 165 | notes: Conference reference where the demo paper presenting MLM is published. 
166 | title: "GeoSearch’24: Proceedings of the 3rd ACM SIGSPATIAL International Workshop on Searching and Mining Large Collections of Geospatial Data" 167 | conference: 168 | name: "SIGSPATIAL’24: The 32nd ACM International Conference on Advances in Geographic Information Systems" 169 | date-start: "2024-10-29" 170 | date-end: "2024-11-01" 171 | city: Atlanta 172 | region: Georgia 173 | country: US 174 | url: https://dl.acm.org/doi/proceedings/10.1145/3681769 175 | isbn: "979-8-4007-1148-0" 176 | date-published: "2024-10-29" 177 | publisher: 178 | name: "Association for Computing Machinery" 179 | authors: 180 | - given-names: Hao 181 | family-names: Li 182 | - given-names: Abhishek 183 | family-names: Potnis 184 | - given-names: Wenwen 185 | family-names: Li 186 | - given-names: Dalton 187 | family-names: Lunga 188 | - given-names: Martin 189 | family-names: Werner 190 | - given-names: Andreas 191 | family-names: Züfle 192 | 193 | preferred-citation: 194 | type: conference-paper 195 | doi: "10.1145/3681769.3698586" 196 | title: Machine Learning Model Specification for Cataloging Spatio-Temporal Models 197 | conference: 198 | name: 3rd ACM SIGSPATIAL International Workshop on Searching and Mining Large Collections of Geospatial Data 199 | alias: GeoSearch’24 200 | date-published: "2024-10-29" 201 | year: 2024 202 | month: 10 203 | pages: 4 204 | loc-start: 36 205 | loc-end: 39 206 | location: 207 | name: Georgia Tech Hotel and Conference Center 208 | city: Atlanta 209 | region: Georgia 210 | country: US 211 | languages: 212 | - en 213 | abstract: >- 214 | The Machine Learning Model (MLM) extension is a 215 | specification that extends the SpatioTemporal Asset 216 | Catalogs (STAC) framework to catalog machine learning 217 | models. This demo paper introduces the goals of the MLM, 218 | highlighting its role in improving 219 | searchability and reproducibility of geospatial models. 220 | The MLM is contextualized within the STAC ecosystem, 221 | demonstrating its compatibility and the advantages it 222 | brings to discovering relevant geospatial models and 223 | describing their inference requirements. 224 | 225 | A detailed overview of the MLM's structure and fields 226 | describes the tasks, hardware requirements, frameworks, 227 | and inputs/outputs associated with machine learning 228 | models. Three use cases are presented, showcasing the 229 | application of the MLM in describing models for land cover 230 | classification and image segmentation. These examples 231 | illustrate how the MLM facilitates easier search and better 232 | understanding of how to deploy models in inference pipelines. 233 | 234 | The discussion addresses future challenges in extending 235 | the MLM to account for the diversity in machine learning 236 | models, including foundational and fine-tuned models, 237 | multi-modal models, and the importance of describing the 238 | data pipeline and infrastructure models depend on. 239 | Finally, the paper demonstrates the potential of the MLM 240 | to be a unifying standard to enable benchmarking and 241 | comparing geospatial machine learning models. 
242 | keywords: 243 | - STAC 244 | - Catalog 245 | - Machine Learning 246 | - Spatio-Temporal Models 247 | - Search 248 | contact: 249 | - given-names: Francis 250 | family-names: Charette-Migneault 251 | email: francis.charette-migneault@crim.ca 252 | affiliation: Computer Research Institute of Montréal (CRIM) 253 | orcid: "https://orcid.org/0000-0003-4862-3349" 254 | authors: 255 | - given-names: Francis 256 | family-names: Charette-Migneault 257 | email: francis.charette-migneault@crim.ca 258 | affiliation: Computer Research Institute of Montréal (CRIM) 259 | orcid: "https://orcid.org/0000-0003-4862-3349" 260 | - given-names: Ryan 261 | family-names: Avery 262 | email: ryan@wherobots.com 263 | affiliation: "Wherobots, Inc." 264 | orcid: "https://orcid.org/0000-0001-7392-1474" 265 | - given-names: Brian 266 | family-names: Pondi 267 | email: brian.pondi@uni-muenster.de 268 | affiliation: "Institute for Geoinformatics, University of Münster" 269 | orcid: "https://orcid.org/0009-0008-0367-1690" 270 | - given-names: Joses 271 | family-names: Omojola 272 | affiliation: University of Arizona 273 | email: jomojo1@arizona.edu 274 | orcid: "https://orcid.org/0000-0001-5807-2953" 275 | - given-names: Simone 276 | family-names: Vaccari 277 | email: simone.vaccari@terradue.com 278 | affiliation: Terradue 279 | orcid: "https://orcid.org/0000-0002-2757-4165" 280 | - given-names: Parham 281 | family-names: Membari 282 | email: parham.membari@terradue.com 283 | affiliation: Terradue 284 | orcid: "https://orcid.org/0009-0004-7594-4011" 285 | - given-names: Devis 286 | family-names: Peressutti 287 | email: devis.peressutti@planet.com 288 | affiliation: "Sinergise Solutions, a Planet Labs company" 289 | orcid: "https://orcid.org/0000-0002-4660-0576" 290 | - given-names: Jia 291 | family-names: Yu 292 | email: jiayu@wherobots.com 293 | affiliation: "Wherobots, Inc." 
294 | orcid: "https://orcid.org/0000-0003-1340-6475" 295 | - given-names: Jed 296 | family-names: Sundwall 297 | email: jed@radiant.earth 298 | affiliation: Radiant Earth 299 | orcid: "https://orcid.org/0000-0001-9681-230X" 300 | -------------------------------------------------------------------------------- /examples/item_eo_bands.json: -------------------------------------------------------------------------------- 1 | { 2 | "$comment": "Demonstrate the use of MLM and EO for bands description, with EO bands directly in the Model Asset.", 3 | "stac_version": "1.0.0", 4 | "stac_extensions": [ 5 | "https://stac-extensions.github.io/mlm/v1.5.0/schema.json", 6 | "https://stac-extensions.github.io/eo/v1.1.0/schema.json", 7 | "https://stac-extensions.github.io/file/v1.0.0/schema.json", 8 | "https://stac-extensions.github.io/ml-aoi/v0.2.0/schema.json" 9 | ], 10 | "type": "Feature", 11 | "id": "resnet-18_sentinel-2_all_moco_classification", 12 | "collection": "ml-model-examples", 13 | "geometry": { 14 | "type": "Polygon", 15 | "coordinates": [ 16 | [ 17 | [ 18 | -7.882190080512502, 19 | 37.13739173208318 20 | ], 21 | [ 22 | -7.882190080512502, 23 | 58.21798141355221 24 | ], 25 | [ 26 | 27.911651652899923, 27 | 58.21798141355221 28 | ], 29 | [ 30 | 27.911651652899923, 31 | 37.13739173208318 32 | ], 33 | [ 34 | -7.882190080512502, 35 | 37.13739173208318 36 | ] 37 | ] 38 | ] 39 | }, 40 | "bbox": [ 41 | -7.882190080512502, 42 | 37.13739173208318, 43 | 27.911651652899923, 44 | 58.21798141355221 45 | ], 46 | "properties": { 47 | "description": "Sourced from torchgeo python library, identifier is ResNet18_Weights.SENTINEL2_ALL_MOCO", 48 | "datetime": null, 49 | "start_datetime": "1900-01-01T00:00:00Z", 50 | "end_datetime": "9999-12-31T23:59:59Z", 51 | "mlm:name": "Resnet-18 Sentinel-2 ALL MOCO", 52 | "mlm:tasks": [ 53 | "classification" 54 | ], 55 | "mlm:architecture": "ResNet", 56 | "mlm:framework": "pytorch", 57 | "mlm:framework_version": "2.1.2+cu121", 58 | "file:size": 43000000, 59 | "mlm:memory_size": 1, 60 | "mlm:total_parameters": 11700000, 61 | "mlm:pretrained_source": "EuroSat Sentinel-2", 62 | "mlm:accelerator": "cuda", 63 | "mlm:accelerator_constrained": false, 64 | "mlm:accelerator_summary": "Unknown", 65 | "mlm:batch_size_suggestion": 256, 66 | "mlm:input": [ 67 | { 68 | "name": "13 Band Sentinel-2 Batch", 69 | "bands": [ 70 | "B01", 71 | "B02", 72 | "B03", 73 | "B04", 74 | "B05", 75 | "B06", 76 | "B07", 77 | "B08", 78 | "B8A", 79 | "B09", 80 | "B10", 81 | "B11", 82 | "B12" 83 | ], 84 | "input": { 85 | "shape": [ 86 | -1, 87 | 13, 88 | 64, 89 | 64 90 | ], 91 | "dim_order": [ 92 | "batch", 93 | "bands", 94 | "height", 95 | "width" 96 | ], 97 | "data_type": "float32" 98 | }, 99 | "norm_by_channel": true, 100 | "resize_type": null, 101 | "value_scaling": [ 102 | { 103 | "type": "z-score", 104 | "mean": 1354.40546513, 105 | "stddev": 245.71762908 106 | }, 107 | { 108 | "type": "z-score", 109 | "mean": 1118.24399958, 110 | "stddev": 333.00778264 111 | }, 112 | { 113 | "type": "z-score", 114 | "mean": 1042.92983953, 115 | "stddev": 395.09249139 116 | }, 117 | { 118 | "type": "z-score", 119 | "mean": 947.62620298, 120 | "stddev": 593.75055589 121 | }, 122 | { 123 | "type": "z-score", 124 | "mean": 1199.47283961, 125 | "stddev": 566.4170017 126 | }, 127 | { 128 | "type": "z-score", 129 | "mean": 1999.79090914, 130 | "stddev": 861.18399006 131 | }, 132 | { 133 | "type": "z-score", 134 | "mean": 2369.22292565, 135 | "stddev": 1086.63139075 136 | }, 137 | { 138 | "type": "z-score", 139 | "mean": 
2296.82608323, 140 | "stddev": 1117.98170791 141 | }, 142 | { 143 | "type": "z-score", 144 | "mean": 732.08340178, 145 | "stddev": 404.91978886 146 | }, 147 | { 148 | "type": "z-score", 149 | "mean": 12.11327804, 150 | "stddev": 4.77584468 151 | }, 152 | { 153 | "type": "z-score", 154 | "mean": 1819.01027855, 155 | "stddev": 1002.58768311 156 | }, 157 | { 158 | "type": "z-score", 159 | "mean": 1118.92391149, 160 | "stddev": 761.30323499 161 | }, 162 | { 163 | "type": "z-score", 164 | "mean": 2594.14080798, 165 | "stddev": 1231.58581042 166 | } 167 | ], 168 | "pre_processing_function": { 169 | "format": "python", 170 | "expression": "torchgeo.datamodules.eurosat.EuroSATDataModule.collate_fn" 171 | } 172 | } 173 | ], 174 | "mlm:output": [ 175 | { 176 | "name": "classification", 177 | "tasks": [ 178 | "classification" 179 | ], 180 | "result": { 181 | "shape": [ 182 | -1, 183 | 10 184 | ], 185 | "dim_order": [ 186 | "batch", 187 | "class" 188 | ], 189 | "data_type": "float32" 190 | }, 191 | "classification_classes": [ 192 | { 193 | "value": 0, 194 | "name": "Annual Crop", 195 | "description": null, 196 | "title": null, 197 | "color_hint": null, 198 | "nodata": false 199 | }, 200 | { 201 | "value": 1, 202 | "name": "Forest", 203 | "description": null, 204 | "title": null, 205 | "color_hint": null, 206 | "nodata": false 207 | }, 208 | { 209 | "value": 2, 210 | "name": "Herbaceous Vegetation", 211 | "description": null, 212 | "title": null, 213 | "color_hint": null, 214 | "nodata": false 215 | }, 216 | { 217 | "value": 3, 218 | "name": "Highway", 219 | "description": null, 220 | "title": null, 221 | "color_hint": null, 222 | "nodata": false 223 | }, 224 | { 225 | "value": 4, 226 | "name": "Industrial Buildings", 227 | "description": null, 228 | "title": null, 229 | "color_hint": null, 230 | "nodata": false 231 | }, 232 | { 233 | "value": 5, 234 | "name": "Pasture", 235 | "description": null, 236 | "title": null, 237 | "color_hint": null, 238 | "nodata": false 239 | }, 240 | { 241 | "value": 6, 242 | "name": "Permanent Crop", 243 | "description": null, 244 | "title": null, 245 | "color_hint": null, 246 | "nodata": false 247 | }, 248 | { 249 | "value": 7, 250 | "name": "Residential Buildings", 251 | "description": null, 252 | "title": null, 253 | "color_hint": null, 254 | "nodata": false 255 | }, 256 | { 257 | "value": 8, 258 | "name": "River", 259 | "description": null, 260 | "title": null, 261 | "color_hint": null, 262 | "nodata": false 263 | }, 264 | { 265 | "value": 9, 266 | "name": "SeaLake", 267 | "description": null, 268 | "title": null, 269 | "color_hint": null, 270 | "nodata": false 271 | } 272 | ], 273 | "post_processing_function": null 274 | } 275 | ] 276 | }, 277 | "assets": { 278 | "weights": { 279 | "href": "https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/resolve/main/resnet18_sentinel2_all_moco-59bfdff9.pth", 280 | "title": "Pytorch weights checkpoint", 281 | "description": "A Resnet-18 classification model trained on normalized Sentinel-2 imagery with Eurosat landcover labels with torchgeo", 282 | "type": "application/octet-stream; application=pytorch", 283 | "roles": [ 284 | "mlm:model", 285 | "mlm:weights" 286 | ], 287 | "mlm:artifact_type": "torch.save", 288 | "$comment": "Following 'eo:bands' is required to fulfil schema validation of 'eo' extension.", 289 | "eo:bands": [ 290 | { 291 | "name": "B01", 292 | "common_name": "coastal", 293 | "description": "Coastal aerosol (band 1)", 294 | "center_wavelength": 0.443, 295 | "full_width_half_max": 0.027 296 | }, 297 | { 
298 | "name": "B02", 299 | "common_name": "blue", 300 | "description": "Blue (band 2)", 301 | "center_wavelength": 0.49, 302 | "full_width_half_max": 0.098 303 | }, 304 | { 305 | "name": "B03", 306 | "common_name": "green", 307 | "description": "Green (band 3)", 308 | "center_wavelength": 0.56, 309 | "full_width_half_max": 0.045 310 | }, 311 | { 312 | "name": "B04", 313 | "common_name": "red", 314 | "description": "Red (band 4)", 315 | "center_wavelength": 0.665, 316 | "full_width_half_max": 0.038 317 | }, 318 | { 319 | "name": "B05", 320 | "common_name": "rededge", 321 | "description": "Red edge 1 (band 5)", 322 | "center_wavelength": 0.704, 323 | "full_width_half_max": 0.019 324 | }, 325 | { 326 | "name": "B06", 327 | "common_name": "rededge", 328 | "description": "Red edge 2 (band 6)", 329 | "center_wavelength": 0.74, 330 | "full_width_half_max": 0.018 331 | }, 332 | { 333 | "name": "B07", 334 | "common_name": "rededge", 335 | "description": "Red edge 3 (band 7)", 336 | "center_wavelength": 0.783, 337 | "full_width_half_max": 0.028 338 | }, 339 | { 340 | "name": "B08", 341 | "common_name": "nir", 342 | "description": "NIR 1 (band 8)", 343 | "center_wavelength": 0.842, 344 | "full_width_half_max": 0.145 345 | }, 346 | { 347 | "name": "B8A", 348 | "common_name": "nir08", 349 | "description": "NIR 2 (band 8A)", 350 | "center_wavelength": 0.865, 351 | "full_width_half_max": 0.033 352 | }, 353 | { 354 | "name": "B09", 355 | "common_name": "nir09", 356 | "description": "NIR 3 (band 9)", 357 | "center_wavelength": 0.945, 358 | "full_width_half_max": 0.026 359 | }, 360 | { 361 | "name": "B10", 362 | "common_name": "cirrus", 363 | "description": "SWIR - Cirrus (band 10)", 364 | "center_wavelength": 1.375, 365 | "full_width_half_max": 0.026 366 | }, 367 | { 368 | "name": "B11", 369 | "common_name": "swir16", 370 | "description": "SWIR 1 (band 11)", 371 | "center_wavelength": 1.61, 372 | "full_width_half_max": 0.143 373 | }, 374 | { 375 | "name": "B12", 376 | "common_name": "swir22", 377 | "description": "SWIR 2 (band 12)", 378 | "center_wavelength": 2.19, 379 | "full_width_half_max": 0.242 380 | } 381 | ] 382 | }, 383 | "source_code": { 384 | "href": "https://github.com/microsoft/torchgeo/blob/61efd2e2c4df7ebe3bd03002ebbaeaa3cfe9885a/torchgeo/models/resnet.py#L207", 385 | "title": "Model implementation.", 386 | "description": "Source code to run the model.", 387 | "type": "text/x-python", 388 | "roles": [ 389 | "mlm:source_code", 390 | "code", 391 | "metadata" 392 | ] 393 | } 394 | }, 395 | "links": [ 396 | { 397 | "rel": "collection", 398 | "href": "./collection.json", 399 | "type": "application/json" 400 | }, 401 | { 402 | "rel": "self", 403 | "href": "./item_eo_bands.json", 404 | "type": "application/geo+json" 405 | }, 406 | { 407 | "rel": "derived_from", 408 | "href": "https://earth-search.aws.element84.com/v1/collections/sentinel-2-l2a", 409 | "type": "application/json", 410 | "ml-aoi:split": "train" 411 | } 412 | ] 413 | } 414 | --------------------------------------------------------------------------------