├── .copier-answers.yml ├── .github └── workflows │ ├── cd.yml │ └── ci.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── Makefile ├── make.bat └── source │ ├── conf.py │ ├── dataset.rst │ ├── download.rst │ ├── index.rst │ ├── metrics.rst │ ├── models.rst │ ├── usage.md │ ├── utils.rst │ └── validation.rst ├── examples └── find_test_2022_t0_times.py ├── notebooks ├── 01-plotting.ipynb ├── 02-data_loader_demo.ipynb ├── 03-score_model_demo.ipynb └── requirements.txt ├── pyproject.toml ├── src └── cloudcasting │ ├── __init__.py │ ├── __main__.py │ ├── _version.pyi │ ├── cli.py │ ├── constants.py │ ├── data │ └── test_2022_t0_times.csv.zip │ ├── dataset.py │ ├── download.py │ ├── metrics.py │ ├── models.py │ ├── py.typed │ ├── types.py │ ├── utils.py │ └── validation.py ├── tests ├── conftest.py ├── legacy_metrics.py ├── test_cli.py ├── test_data │ └── non_hrv_shell.netcdf ├── test_dataset.py ├── test_download.py ├── test_metrics.py ├── test_models.py └── test_validation.py └── uv.lock /.copier-answers.yml: -------------------------------------------------------------------------------- 1 | # Changes here will be overwritten by Copier 2 | _commit: v0.1.0-8-ga0c8676 3 | _src_path: gh:alan-turing-institute/python-project-template 4 | coc: our_coc 5 | email: nsimpson@turing.ac.uk 6 | full_name: cloudcasting Maintainers 7 | license: MIT 8 | min_python_version: '3.10' 9 | org: climetrend 10 | project_name: cloudcasting 11 | project_short_description: Tooling and infrastructure to enable cloud nowcasting. 12 | python_name: cloudcasting 13 | typing: strict 14 | url: https://github.com/climetrend/cloudcasting 15 | -------------------------------------------------------------------------------- /.github/workflows/cd.yml: -------------------------------------------------------------------------------- 1 | name: CD 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | push: 7 | branches: 8 | - main 9 | release: 10 | types: 11 | - published 12 | 13 | jobs: 14 | dist: 15 | needs: [pre-commit] 16 | name: Distribution build 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | with: 22 | fetch-depth: 0 23 | 24 | - name: Build sdist and wheel 25 | run: pipx run build 26 | 27 | - uses: actions/upload-artifact@v3 28 | with: 29 | path: dist 30 | 31 | - name: Check products 32 | run: pipx run twine check dist/* 33 | 34 | publish: 35 | needs: [dist] 36 | name: Publish to PyPI 37 | environment: pypi 38 | permissions: 39 | id-token: write 40 | runs-on: ubuntu-latest 41 | if: github.event_name == 'release' && github.event.action == 'published' 42 | 43 | steps: 44 | - uses: actions/download-artifact@v3 45 | with: 46 | name: artifact 47 | path: dist 48 | 49 | - uses: pypa/gh-action-pypi-publish@release/v1 50 | if: github.event_name == 'release' && github.event.action == 'published' 51 | with: 52 | # Remove this line to publish to PyPI 53 | repository-url: https://test.pypi.org/legacy/ 54 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | push: 7 | branches: 8 | - main 9 | 10 | jobs: 11 | pre-commit: 12 | name: Format + lint code 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | with: 17 | fetch-depth: 0 18 | - uses: actions/setup-python@v4 19 | with: 20 
| python-version: "3.x" 21 | - uses: pre-commit/action@v3.0.0 22 | with: 23 | extra_args: --hook-stage manual --all-files 24 | 25 | checks: 26 | name: Run tests for Python ${{ matrix.python-version }} on ${{ matrix.runs-on }} 27 | runs-on: ${{ matrix.runs-on }} 28 | needs: [pre-commit] 29 | strategy: 30 | fail-fast: false 31 | matrix: 32 | python-version: ["3.10", "3.12"] # test oldest and latest supported versions 33 | runs-on: [ubuntu-latest, macos-latest, windows-latest] # can be extended to other OSes, e.g. [ubuntu-latest, macos-latest] 34 | 35 | steps: 36 | - uses: actions/checkout@v4 37 | with: 38 | fetch-depth: 0 39 | 40 | - uses: actions/setup-python@v4 41 | with: 42 | python-version: ${{ matrix.python-version }} 43 | allow-prereleases: true 44 | 45 | - name: Install ffmpeg on macOS 46 | if: runner.os == 'macOS' 47 | run: | 48 | brew install ffmpeg 49 | 50 | - name: Install package 51 | run: python -m pip install .[dev] 52 | 53 | - name: Install nightly build of torch 54 | if: matrix.runs-on == 'windows-latest' 55 | run: pip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall --upgrade 56 | 57 | - name: Test package 58 | run: >- 59 | python -m pytest -ra --cov --cov-report=xml --cov-report=term 60 | --durations=20 61 | 62 | - name: Upload coverage report 63 | uses: codecov/codecov-action@v3.1.4 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # setuptools_scm 141 | src/*/_version.py 142 | 143 | 144 | # ruff 145 | .ruff_cache/ 146 | 147 | # OS specific stuff 148 | .DS_Store 149 | .DS_Store? 150 | ._* 151 | .Spotlight-V100 152 | .Trashes 153 | ehthumbs.db 154 | Thumbs.db 155 | 156 | # Common editor files 157 | *~ 158 | *.swp 159 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autoupdate_commit_msg: "chore: update pre-commit hooks" 3 | autofix_commit_msg: "style: pre-commit fixes" 4 | 5 | repos: 6 | - repo: https://github.com/pre-commit/pre-commit-hooks 7 | rev: "v5.0.0" 8 | hooks: 9 | - id: check-added-large-files 10 | - id: check-case-conflict 11 | - id: check-merge-conflict 12 | - id: check-symlinks 13 | - id: check-yaml 14 | - id: debug-statements 15 | - id: end-of-file-fixer 16 | - id: mixed-line-ending 17 | - id: name-tests-test 18 | args: ["--pytest-test-first"] 19 | exclude: ^tests/legacy_metrics.py 20 | - id: requirements-txt-fixer 21 | - id: trailing-whitespace 22 | 23 | - repo: https://github.com/astral-sh/ruff-pre-commit 24 | rev: "v0.9.3" 25 | hooks: 26 | # first, lint + autofix 27 | - id: ruff 28 | types_or: [python, pyi, jupyter] 29 | args: ["--fix", "--show-fixes"] 30 | # then, format 31 | - id: ruff-format 32 | 33 | - repo: https://github.com/pre-commit/mirrors-mypy 34 | rev: "v1.14.1" 35 | hooks: 36 | - id: mypy 37 | args: [] 38 | additional_dependencies: 39 | - pytest 40 | - xarray 41 | - pandas-stubs 42 | - typer 43 | - dask 44 | - pyproj 45 | - pyresample 46 | - lightning 47 | - torch 48 | - jaxtyping 49 | - types-tqdm 50 | - chex 51 | - types-PyYAML 52 | - wandb 53 | - matplotlib 54 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | 3 | build: 4 | os: "ubuntu-22.04" 5 | tools: 6 | python: "3.11" 7 | 8 | python: 9 | install: 10 | - method: pip 11 | path: . 12 | extra_requirements: 13 | - doc 14 | 15 | sphinx: 16 | configuration: docs/source/conf.py 17 | fail_on_warning: true 18 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | We value the participation of every member of our community and want to ensure 4 | that every contributor has an enjoyable and fulfilling experience. Accordingly, 5 | everyone who participates in the cloudcasting project is expected to show respect and courtesy to other community members at all time. 
6 | 7 | In the interest of fostering an open and welcoming environment, we as 8 | contributors and maintainers are dedicated to making participation in our project 9 | a harassment-free experience for everyone, regardless of age, body 10 | size, disability, ethnicity, sex characteristics, gender identity and expression, 11 | level of experience, education, socio-economic status, nationality, personal 12 | appearance, race, religion, or sexual identity and orientation. 13 | 14 | ## Our Standards 15 | 16 | Examples of behaviour that contributes to creating a positive environment 17 | include: 18 | 19 | - Using welcoming and inclusive language 20 | - Being respectful of differing viewpoints and experiences 21 | - Gracefully accepting constructive criticism 22 | - Focusing on what is best for the community 23 | - Showing empathy towards other community members 24 | 25 | Examples of unacceptable behaviour by participants include: 26 | 27 | - The use of sexualized language or imagery and unwelcome sexual attention or 28 | advances 29 | - Trolling, insulting/derogatory comments, and personal or political attacks 30 | - Public or private harassment 31 | - Publishing others' private information, such as a physical or electronic 32 | address, without explicit permission 33 | - Other conduct which could reasonably be considered inappropriate in a 34 | professional setting 35 | 36 | 43 | 44 | ## Attribution 45 | 46 | This Code of Conduct is adapted from the [Turing Data Stories Code of Conduct](https://github.com/alan-turing-institute/TuringDataStories/blob/main/CODE_OF_CONDUCT.md) which is based on the [scona project Code of Conduct](https://github.com/WhitakerLab/scona/blob/master/CODE_OF_CONDUCT.md) 47 | and the [Contributor Covenant](https://www.contributor-covenant.org), version [1.4](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) 48 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | See the [Scientific Python Developer Guide][spc-dev-intro] for a detailed 2 | description of best practices for developing scientific packages. 3 | 4 | [spc-dev-intro]: https://learn.scientific-python.org/development/ 5 | 6 | # Setting up a development environment manually 7 | 8 | You can set up a development environment by running: 9 | 10 | ```zsh 11 | python3 -m venv venv # create a virtualenv called venv 12 | source ./venv/bin/activate # now `python` points to the virtualenv python 13 | pip install -v -e ".[dev]" # -v for verbose, -e for editable, [dev] for dev dependencies 14 | ``` 15 | 16 | # Post setup 17 | 18 | You should prepare pre-commit, which will help you by checking that commits pass 19 | required checks: 20 | 21 | ```bash 22 | pip install pre-commit # or brew install pre-commit on macOS 23 | pre-commit install # this will install a pre-commit hook into the git repo 24 | ``` 25 | 26 | You can also/alternatively run `pre-commit run` (changes only) or 27 | `pre-commit run --all-files` to check even without installing the hook. 28 | 29 | # Testing 30 | 31 | Use pytest to run the unit checks: 32 | 33 | ```bash 34 | pytest 35 | ``` 36 | 37 | # Coverage 38 | 39 | Use pytest-cov to generate coverage reports: 40 | 41 | ```bash 42 | pytest --cov=cloudcasting 43 | ``` 44 | 45 | # Pre-commit 46 | 47 | This project uses pre-commit for all style checking. 
Install pre-commit and run: 48 | 49 | ```bash 50 | pre-commit run -a 51 | ``` 52 | 53 | to check all files. 54 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2024 cloudcasting Maintainers 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | of the Software, and to permit persons to whom the Software is furnished to do 8 | so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cloudcasting 2 | 3 | [![Actions Status][actions-badge]][actions-link] 4 | [![Documentation status badge](https://readthedocs.org/projects/cloudcasting/badge/?version=latest)](https://cloudcasting.readthedocs.io/en/latest/?badge=latest) 5 | 6 | Tooling and infrastructure to enable cloud nowcasting. Full documentation can be found at https://cloudcasting.readthedocs.io/. 7 | 8 | ## Linked model repos 9 | - [Optical Flow (Farneback)](https://github.com/alan-turing-institute/ocf-optical-flow) 10 | - [Optical Flow (TVL1)](https://github.com/alan-turing-institute/ocf-optical-flow-tvl1) 11 | - [Diffusion model](https://github.com/alan-turing-institute/ocf-diffusion) 12 | - [ConvLSTM](https://github.com/alan-turing-institute/ocf-convLSTM) 13 | - [IAM4VP](https://github.com/alan-turing-institute/ocf-iam4vp) 14 | 15 | The model template repo on which these are based is found [here](https://github.com/alan-turing-institute/ocf-model-template). These repositories contain the implementations of each model, as well as validation infrastructure to replicate metric scores on weights and biases. 16 | 17 | ## Installation 18 | 19 | ### For users: 20 | 21 | ```zsh 22 | git clone https://github.com/alan-turing-institute/cloudcasting 23 | cd cloudcasting 24 | python -m pip install . 25 | ``` 26 | 27 | To run metrics on GPU: 28 | 29 | ```zsh 30 | python -m pip install --upgrade "jax[cuda12]" 31 | ``` 32 | ### For making changes to the library: 33 | 34 | On macOS you first need to install `ffmpeg` with the following command. On other platforms this is 35 | not necessary. 36 | 37 | ```bash 38 | brew install ffmpeg 39 | ``` 40 | 41 | Clone and install the repo. 
42 | 43 | ```bash 44 | git clone https://github.com/alan-turing-institute/cloudcasting 45 | cd cloudcasting 46 | python -m pip install ".[dev]" 47 | ``` 48 | 49 | Install pre-commit before making development changes: 50 | 51 | ```bash 52 | pre-commit install 53 | ``` 54 | 55 | For making changes, see the [guidance on development](https://github.com/alan-turing-institute/python-project-template?tab=readme-ov-file#setting-up-a-new-project) from the template that generated this project. 56 | 57 | ## Usage 58 | 59 | ### Validating a model 60 | ```bash 61 | cloudcasting validate "path/to/config/file.yml" "path/to/model/file.py" 62 | ``` 63 | 64 | ### Downloading data 65 | ```bash 66 | cloudcasting download "2020-06-01 00:00" "2020-06-30 23:55" "path/to/data/save/dir" 67 | ``` 68 | 69 | Full options: 70 | 71 | ```bash 72 | > cloudcasting download --help 73 | 74 | Usage: cloudcasting download [OPTIONS] START_DATE 75 | END_DATE OUTPUT_DIRECTORY 76 | 77 | ╭─ Arguments ──────────────────────────────────────────╮ 78 | │ * start_date TEXT Start date in │ 79 | │ 'YYYY-MM-DD HH:MM' │ 80 | │ format │ 81 | │ [default: None] │ 82 | │ [required] │ 83 | │ * end_date TEXT End date in │ 84 | │ 'YYYY-MM-DD HH:MM' │ 85 | │ format │ 86 | │ [default: None] │ 87 | │ [required] │ 88 | │ * output_directory TEXT Directory to save │ 89 | │ the satellite data │ 90 | │ [default: None] │ 91 | │ [required] │ 92 | ╰──────────────────────────────────────────────────────╯ 93 | ╭─ Options ────────────────────────────────────────────╮ 94 | │ --download-f… TEXT Frequency to │ 95 | │ download data │ 96 | │ in pandas │ 97 | │ datetime │ 98 | │ format │ 99 | │ [default: │ 100 | │ 15min] │ 101 | │ --get-hrv --no-get-h… Whether to │ 102 | │ download HRV │ 103 | │ data │ 104 | │ [default: │ 105 | │ no-get-hrv] │ 106 | │ --override-d… --no-overr… Whether to │ 107 | │ override date │ 108 | │ range limits │ 109 | │ [default: │ 110 | │ no-override-… │ 111 | │ --lon-min FLOAT Minimum │ 112 | │ longitude │ 113 | │ [default: │ 114 | │ -16] │ 115 | │ --lon-max FLOAT Maximum │ 116 | │ longitude │ 117 | │ [default: 10] │ 118 | │ --lat-min FLOAT Minimum │ 119 | │ latitude │ 120 | │ [default: 45] │ 121 | │ --lat-max FLOAT Maximum │ 122 | │ latitude │ 123 | │ [default: 70] │ 124 | │ --test-2022-… --no-test-… Whether to │ 125 | │ filter data │ 126 | │ from 2022 to │ 127 | │ download the │ 128 | │ test set │ 129 | │ (every 2 │ 130 | │ weeks). │ 131 | │ [default: │ 132 | │ no-test-2022… │ 133 | │ --verify-202… --no-verif… Whether to │ 134 | │ download the │ 135 | │ verification │ 136 | │ data from │ 137 | │ 2023. Only │ 138 | │ used at the │ 139 | │ end of the │ 140 | │ project │ 141 | │ [default: │ 142 | │ no-verify-20… | 143 | │ --help Show this │ 144 | │ message and │ 145 | │ exit. │ 146 | ╰──────────────────────────────────────────────────────╯ 147 | ``` 148 | 149 | ## Contributing 150 | 151 | See [CONTRIBUTING.md](CONTRIBUTING.md) for instructions on how to contribute. 152 | 153 | ## License 154 | 155 | Distributed under the terms of the [MIT license](LICENSE). 
156 | 157 | 158 | 159 | [actions-badge]: https://github.com/alan-turing-institute/cloudcasting/actions/workflows/ci.yml/badge.svg?branch=main 160 | [actions-link]: https://github.com/alan-turing-institute/cloudcasting/actions 161 | [pypi-link]: https://pypi.org/project/cloudcasting/ 162 | [pypi-platforms]: https://img.shields.io/pypi/pyversions/cloudcasting 163 | [pypi-version]: https://img.shields.io/pypi/v/cloudcasting 164 | 165 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Path setup -------------------------------------------------------------- 7 | 8 | # If extensions (or modules to document with autodoc) are in another directory, 9 | # add these directories to sys.path here. If the directory is relative to the 10 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
11 | # 12 | import os 13 | import sys 14 | 15 | sys.path.insert(0, os.path.abspath("..")) 16 | 17 | # -- Project information ----------------------------------------------------- 18 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 19 | 20 | project = "cloudcasting" 21 | copyright = "2025, cloudcasting Maintainers" 22 | author = "cloudcasting Maintainers" 23 | release = "0.6" 24 | version = "0.6.0" 25 | 26 | # -- General configuration --------------------------------------------------- 27 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 28 | 29 | extensions = [ 30 | "sphinx.ext.duration", 31 | "sphinx.ext.doctest", 32 | "sphinx.ext.autodoc", 33 | "sphinx.ext.autosummary", 34 | "sphinx.ext.intersphinx", 35 | "sphinx.ext.coverage", 36 | "sphinx.ext.napoleon", 37 | "m2r2", 38 | ] 39 | 40 | intersphinx_mapping = { 41 | "python": ("https://docs.python.org/3/", None), 42 | "sphinx": ("https://www.sphinx-doc.org/en/master/", None), 43 | } 44 | intersphinx_disabled_domains = ["std"] 45 | 46 | # -- Options for HTML output ------------------------------------------------- 47 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 48 | 49 | html_theme = "sphinx_rtd_theme" 50 | -------------------------------------------------------------------------------- /docs/source/dataset.rst: -------------------------------------------------------------------------------- 1 | Dataset 2 | ======= 3 | 4 | .. automodule:: cloudcasting.dataset 5 | :members: 6 | :special-members: __init__ 7 | -------------------------------------------------------------------------------- /docs/source/download.rst: -------------------------------------------------------------------------------- 1 | Download 2 | ======== 3 | 4 | .. automodule:: cloudcasting.download 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | 2 | Documentation for cloudcasting 3 | ============================== 4 | 5 | Tooling and infrastructure to enable cloud nowcasting. 6 | Check out the :doc:`usage` section for further information on how to install and run this package. 7 | 8 | This tool was developed by `Open Climate Fix `_ and 9 | `The Alan Turing Institute `_ as part of the 10 | `Manchester Prize `_. 11 | 12 | Contents 13 | -------- 14 | 15 | .. toctree:: 16 | :maxdepth: 2 17 | 18 | usage 19 | dataset 20 | download 21 | metrics 22 | models 23 | utils 24 | validation 25 | 26 | License 27 | ------- 28 | 29 | The cloudcasting software is released under an `MIT License `_. 30 | 31 | Index 32 | ----- 33 | 34 | * :ref:`genindex` 35 | -------------------------------------------------------------------------------- /docs/source/metrics.rst: -------------------------------------------------------------------------------- 1 | Metrics 2 | ======= 3 | 4 | .. automodule:: cloudcasting.metrics 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/models.rst: -------------------------------------------------------------------------------- 1 | Models 2 | ====== 3 | 4 | .. automodule:: cloudcasting.models 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/usage.md: -------------------------------------------------------------------------------- 1 | .. 
_usage: 2 | 3 | User guide 4 | ========== 5 | 6 | **Contents:** 7 | 8 | - :ref:`install` 9 | - :ref:`optional` 10 | - :ref:`getting_started` 11 | 12 | .. _install: 13 | 14 | Installation 15 | ------------ 16 | 17 | To use cloudcasting, first install it using pip: 18 | 19 | ```bash 20 | git clone https://github.com/alan-turing-institute/cloudcasting 21 | cd cloudcasting 22 | python -m pip install . 23 | ``` 24 | 25 | .. _optional: 26 | 27 | Optional dependencies 28 | --------------------- 29 | 30 | cloudcasting supports optional dependencies, which are not installed by default. These dependencies are required for certain functionality. 31 | 32 | To run the metrics on GPU: 33 | 34 | ```bash 35 | python -m pip install --upgrade "jax[cuda12]" 36 | ``` 37 | 38 | To make changes to the library, it is necessary to install the extra `dev` dependencies, and install pre-commit: 39 | 40 | ```bash 41 | python -m pip install ".[dev]" 42 | pre-commit install 43 | ``` 44 | 45 | To create the documentation, it is necessary to install the extra `doc` dependencies: 46 | 47 | ```bash 48 | python -m pip install ".[doc]" 49 | ``` 50 | 51 | .. _getting_started: 52 | 53 | Getting started 54 | --------------- 55 | 56 | Use the command line interface to download data: 57 | 58 | ```bash 59 | cloudcasting download "2020-06-01 00:00" "2020-06-30 23:55" "path/to/data/save/dir" 60 | ``` 61 | 62 | Once you have developed a model, you can also validate the model, calculating a set of metrics with a standard dataset. 63 | To make use of the cli tool, use the [model github repo template](https://github.com/alan-turing-institute/ocf-model-template) to structure it correctly for validation. 64 | 65 | ```bash 66 | cloudcasting validate "path/to/config/file.yml" "path/to/model/file.py" 67 | ``` 68 | -------------------------------------------------------------------------------- /docs/source/utils.rst: -------------------------------------------------------------------------------- 1 | Utils 2 | ===== 3 | 4 | .. automodule:: cloudcasting.utils 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/validation.rst: -------------------------------------------------------------------------------- 1 | Validation 2 | ========== 3 | 4 | .. 
automodule:: cloudcasting.validation 5 | :members: 6 | -------------------------------------------------------------------------------- /examples/find_test_2022_t0_times.py: -------------------------------------------------------------------------------- 1 | """This script finds the 2022 test set t0 times and saves them to the cloudcasting package.""" 2 | 3 | import importlib.util 4 | import os 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import xarray as xr 9 | 10 | from cloudcasting.constants import DATA_INTERVAL_SPACING_MINUTES, FORECAST_HORIZON_MINUTES 11 | from cloudcasting.dataset import find_valid_t0_times 12 | from cloudcasting.download import _get_sat_public_dataset_path 13 | 14 | # Set a max history length which we will support in the validation process 15 | # We will not be able to fairly score models which require a longer history than this 16 | # But by setting this too long, we will reduce the samples we have to score on 17 | 18 | # The current FORECAST_HORIZON_MINUTES is 3 hours so we'll set this conservatively to 6 hours 19 | MAX_HISTORY_MINUTES = 6 * 60 20 | 21 | # We filter t0 times so they have to have a gap of at least this long between consecutive times 22 | MIN_GAP_SIZE = pd.Timedelta("1hour") 23 | 24 | # Open the 2022 dataset 25 | ds = xr.open_zarr(_get_sat_public_dataset_path(2022, is_hrv=False)) 26 | 27 | # Filter to defined time frequency 28 | mask = np.mod(ds.time.dt.minute, DATA_INTERVAL_SPACING_MINUTES) == 0 29 | ds = ds.sel(time=mask) 30 | 31 | # Mask to the odd fortnights - i.e. the 2022 test set 32 | mask = np.mod(ds.time.dt.dayofyear // 14, 2) == 1 33 | ds = ds.sel(time=mask) 34 | 35 | 36 | # Find the t0 times that we have satellite data for 37 | available_t0_times = find_valid_t0_times( 38 | datetimes=pd.DatetimeIndex(ds.time), 39 | history_mins=MAX_HISTORY_MINUTES, 40 | forecast_mins=FORECAST_HORIZON_MINUTES, 41 | sample_freq_mins=DATA_INTERVAL_SPACING_MINUTES, 42 | ) 43 | 44 | # Filter the t0 times so they have gaps of at least 1 hour 45 | _filtered_t0_times = [available_t0_times[0]] 46 | 47 | for t in available_t0_times[1:]: 48 | if (t - _filtered_t0_times[-1]) >= MIN_GAP_SIZE: 49 | _filtered_t0_times.append(t) 50 | 51 | filtered_t0_times = pd.DatetimeIndex(_filtered_t0_times) 52 | 53 | 54 | # Print the valid t0 times to sanity check 55 | print(f"Number of available t0 times: {len(filtered_t0_times)}") 56 | print(f"Actual available t0 times: {filtered_t0_times}") 57 | 58 | 59 | # Find the path of the cloudcasting package so we can save the valid times into it 60 | spec = importlib.util.find_spec("cloudcasting") 61 | if spec and spec.origin: 62 | package_path = os.path.dirname(spec.origin) 63 | else: 64 | msg = "Path of package `cloudcasting` can not be found" 65 | raise ModuleNotFoundError(msg) 66 | 67 | # Save the valid t0 times 68 | filename = "test_2022_t0_times.csv" 69 | df = pd.DataFrame(filtered_t0_times, columns=["t0_time"]).set_index("t0_time") 70 | df.to_csv( 71 | f"{package_path}/data/{filename}.zip", 72 | compression={ 73 | "method": "zip", 74 | "archive_name": filename, 75 | }, 76 | ) 77 | -------------------------------------------------------------------------------- /notebooks/requirements.txt: -------------------------------------------------------------------------------- 1 | ipykernel 2 | matplotlib 3 | seaborn 4 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = 
["setuptools>=61", "setuptools_scm[toml]>=7"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | 6 | [project] 7 | name = "cloudcasting" 8 | dynamic = ["version"] 9 | authors = [ 10 | { name = "cloudcasting Maintainers", email = "clouds@turing.ac.uk" }, 11 | ] 12 | description = "Tooling and infrastructure to enable cloud nowcasting." 13 | readme = "README.md" 14 | requires-python = ">=3.10" 15 | classifiers = [ 16 | "Development Status :: 1 - Planning", 17 | "Intended Audience :: Science/Research", 18 | "Intended Audience :: Developers", 19 | "License :: OSI Approved :: MIT License", 20 | "Operating System :: OS Independent", 21 | "Programming Language :: Python", 22 | "Programming Language :: Python :: 3", 23 | "Programming Language :: Python :: 3 :: Only", 24 | "Programming Language :: Python :: 3.10", 25 | "Programming Language :: Python :: 3.11", 26 | "Programming Language :: Python :: 3.12", 27 | "Topic :: Scientific/Engineering", 28 | "Typing :: Typed", 29 | ] 30 | dependencies = [ 31 | "gcsfs", 32 | "zarr<3.0.0", # ocf_blosc2 compatibility 33 | "xarray", 34 | "dask", 35 | "pyresample", 36 | "pyproj", 37 | "pykdtree<=1.3.12", # for macOS 38 | "ocf-blosc2>=0.0.10", # for no-import codec register 39 | "typer", 40 | "lightning", 41 | "torch>=2.3.0", # needed for numpy 2.0 42 | "jaxtyping<=0.2.34", # currently >0.2.34 causing errors 43 | "wandb", 44 | "tqdm", 45 | "moviepy==1.0.3", # currently >1.0.3 not working with wandb 46 | "imageio>=2.35.1", 47 | "numpy<2.1.0", # https://github.com/wandb/wandb/issues/8166 48 | "chex", 49 | "matplotlib" 50 | ] 51 | [project.optional-dependencies] 52 | dev = [ 53 | "pytest >=6", 54 | "pytest-cov >=3", 55 | "pre-commit", 56 | "scipy", 57 | "pytest-mock", 58 | "scikit-image", 59 | "typeguard", 60 | ] 61 | doc = [ 62 | "sphinx", 63 | "sphinx-rtd-theme", 64 | "m2r2" 65 | ] 66 | 67 | [tool.setuptools.package-data] 68 | "cloudcasting" = ["data/*.zip"] 69 | 70 | [tool.setuptools_scm] 71 | write_to = "src/cloudcasting/_version.py" 72 | 73 | [project.scripts] 74 | cloudcasting = "cloudcasting.cli:app" 75 | 76 | [project.urls] 77 | Homepage = "https://github.com/alan-turing-institute/cloudcasting" 78 | "Bug Tracker" = "https://github.com/alan-turing-institute/cloudcasting/issues" 79 | Discussions = "https://github.com/alan-turing-institute/cloudcasting/discussions" 80 | Changelog = "https://github.com/alan-turing-institute/cloudcasting/releases" 81 | 82 | [tool.pytest.ini_options] 83 | minversion = "6.0" 84 | addopts = [ 85 | "-ra", 86 | "--showlocals", 87 | "--strict-markers", 88 | "--strict-config" 89 | ] 90 | xfail_strict = true 91 | filterwarnings = [ 92 | "error", 93 | "ignore:pkg_resources:DeprecationWarning", # lightning 94 | "ignore:Deprecated call to `pkg_resources.declare_namespace:DeprecationWarning", # lightning 95 | "ignore:ast.Str is deprecated:DeprecationWarning", # jaxtyping 96 | "ignore:`newshape` keyword argument is deprecated:DeprecationWarning", # wandb using numpy 2.1.0 97 | "ignore:The keyword `fps` is no longer supported:DeprecationWarning", # wandb.Video 98 | "ignore:torch.onnx.dynamo_export is deprecated since 2.6.0:DeprecationWarning", # lighning fabric torch 2.6+ 99 | ] 100 | log_cli_level = "INFO" 101 | testpaths = [ 102 | "tests", 103 | ] 104 | 105 | [tool.coverage] 106 | run.source = ["cloudcasting"] 107 | port.exclude_lines = [ 108 | 'pragma: no cover', 109 | '\.\.\.', 110 | 'if typing.TYPE_CHECKING:', 111 | ] 112 | 113 | [tool.mypy] 114 | files = ["src", "tests"] 115 | python_version = "3.10" 116 | show_error_codes = true 117 | 
warn_unreachable = true 118 | disallow_untyped_defs = false 119 | disallow_incomplete_defs = false 120 | check_untyped_defs = true 121 | strict = true 122 | enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] 123 | 124 | [[tool.mypy.overrides]] 125 | module = "cloudcasting.*" 126 | disallow_untyped_defs = true 127 | disallow_incomplete_defs = true 128 | 129 | [[tool.mypy.overrides]] 130 | module = [ 131 | "ocf_blosc2", 132 | ] 133 | ignore_missing_imports = true 134 | 135 | [[tool.mypy.overrides]] 136 | module = [ 137 | "cloudcasting.download", 138 | "cloudcasting.cli", 139 | "cloudcasting.validation", # use of wandb.update/Table 140 | ] 141 | disallow_untyped_calls = false 142 | 143 | [tool.ruff] 144 | src = ["src"] 145 | exclude = ["notebooks/*.ipynb"] 146 | line-length = 100 # how long you want lines to be 147 | 148 | [tool.ruff.format] 149 | docstring-code-format = true # code snippets in docstrings will be formatted 150 | 151 | [tool.ruff.lint] 152 | exclude = ["notebooks/*.ipynb"] 153 | select = [ 154 | "E", "F", "W", # flake8 155 | "B", # flake8-bugbear 156 | "I", # isort 157 | "ARG", # flake8-unused-arguments 158 | "C4", # flake8-comprehensions 159 | "EM", # flake8-errmsg 160 | "ICN", # flake8-import-conventions 161 | "ISC", # flake8-implicit-str-concat 162 | "G", # flake8-logging-format 163 | "PGH", # pygrep-hooks 164 | "PIE", # flake8-pie 165 | "PL", # pylint 166 | "PT", # flake8-pytest-style 167 | "RET", # flake8-return 168 | "RUF", # Ruff-specific 169 | "SIM", # flake8-simplify 170 | "UP", # pyupgrade 171 | "YTT", # flake8-2020 172 | "EXE", # flake8-executable 173 | ] 174 | ignore = [ 175 | "PLR", # Design related pylint codes 176 | "ISC001", # Conflicts with formatter 177 | "F722" # Marks jaxtyping syntax annotations as incorrect 178 | ] 179 | unfixable = [ 180 | "F401", # Would remove unused imports 181 | "F841", # Would remove unused variables 182 | ] 183 | flake8-unused-arguments.ignore-variadic-names = true # allow unused *args/**kwargs 184 | -------------------------------------------------------------------------------- /src/cloudcasting/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | cloudcasting: Tooling and infrastructure to enable cloud nowcasting. 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | from jaxtyping import install_import_hook 8 | 9 | # Any module imported inside this `with` block, whose 10 | # name begins with the specified string, will 11 | # automatically have both `@jaxtyped` and the 12 | # typechecker applied to all of their functions and 13 | # dataclasses, meaning that they will be type-checked 14 | # (and therefore shape-checked via jaxtyping) at runtime. 
15 | with install_import_hook("cloudcasting", "typeguard.typechecked"): 16 | from cloudcasting import models, validation 17 | 18 | from cloudcasting import cli, dataset, download, metrics 19 | 20 | from ._version import version as __version__ 21 | 22 | __all__ = ( 23 | "__version__", 24 | "cli", 25 | "dataset", 26 | "download", 27 | "metrics", 28 | "models", 29 | "validation", 30 | ) 31 | -------------------------------------------------------------------------------- /src/cloudcasting/__main__.py: -------------------------------------------------------------------------------- 1 | from cloudcasting.cli import app 2 | 3 | app() 4 | -------------------------------------------------------------------------------- /src/cloudcasting/_version.pyi: -------------------------------------------------------------------------------- 1 | version: str 2 | -------------------------------------------------------------------------------- /src/cloudcasting/cli.py: -------------------------------------------------------------------------------- 1 | import typer 2 | 3 | from cloudcasting.download import download_satellite_data 4 | from cloudcasting.validation import validate_from_config 5 | 6 | # typer app code 7 | app = typer.Typer() 8 | app.command("download")(download_satellite_data) 9 | app.command("validate")(validate_from_config) 10 | -------------------------------------------------------------------------------- /src/cloudcasting/constants.py: -------------------------------------------------------------------------------- 1 | __all__ = ( 2 | "CUTOUT_MASK", 3 | "DATA_INTERVAL_SPACING_MINUTES", 4 | "FORECAST_HORIZON_MINUTES", 5 | "NUM_CHANNELS", 6 | "NUM_FORECAST_STEPS", 7 | ) 8 | 9 | from cloudcasting.utils import create_cutout_mask 10 | 11 | # These constants were locked as part of the project specification 12 | # 3 hour forecast horizon 13 | FORECAST_HORIZON_MINUTES = 180 14 | # at 15 minute intervals 15 | DATA_INTERVAL_SPACING_MINUTES = 15 16 | # gives 12 forecast steps 17 | NUM_FORECAST_STEPS = FORECAST_HORIZON_MINUTES // DATA_INTERVAL_SPACING_MINUTES 18 | # for all 11 low resolution channels 19 | NUM_CHANNELS = 11 20 | 21 | # Constants for the larger (original) image 22 | # Image size (height, width) 23 | IMAGE_SIZE_TUPLE = (372, 614) 24 | # Cutout mask (min x, max x, min y, max y) 25 | CUTOUT_MASK_BOUNDARY = (166, 336, 107, 289) 26 | # Create cutout mask 27 | CUTOUT_MASK = create_cutout_mask(CUTOUT_MASK_BOUNDARY, IMAGE_SIZE_TUPLE) 28 | 29 | # Constants for the smaller (cropped) image 30 | # Cropped image size (height, width) 31 | CROPPED_IMAGE_SIZE_TUPLE = (278, 385) 32 | # Cropped cutout mask (min x, max x, min y, max y) 33 | CROPPED_CUTOUT_MASK_BOUNDARY = (109, 279, 62, 244) 34 | # Create cropped cutout mask 35 | CROPPED_CUTOUT_MASK = create_cutout_mask(CROPPED_CUTOUT_MASK_BOUNDARY, CROPPED_IMAGE_SIZE_TUPLE) 36 | -------------------------------------------------------------------------------- /src/cloudcasting/data/test_2022_t0_times.csv.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alan-turing-institute/cloudcasting/0da4087bba01823d2477dba61ea3e6cfd557212c/src/cloudcasting/data/test_2022_t0_times.csv.zip -------------------------------------------------------------------------------- /src/cloudcasting/dataset.py: -------------------------------------------------------------------------------- 1 | """Dataset and DataModule for past and future satellite data""" 2 | 3 | __all__ = ( 4 | "SatelliteDataModule", 5 |
"SatelliteDataset", 6 | "ValidationSatelliteDataset", 7 | ) 8 | 9 | import io 10 | import pkgutil 11 | from datetime import datetime, timedelta 12 | from typing import TypedDict 13 | 14 | import numpy as np 15 | import pandas as pd 16 | import xarray as xr 17 | from lightning import LightningDataModule 18 | from numpy.typing import NDArray 19 | from torch.utils.data import DataLoader, Dataset 20 | 21 | from cloudcasting.constants import DATA_INTERVAL_SPACING_MINUTES, FORECAST_HORIZON_MINUTES 22 | from cloudcasting.utils import find_contiguous_t0_time_periods, find_contiguous_time_periods 23 | 24 | 25 | class DataloaderArgs(TypedDict): 26 | batch_size: int 27 | sampler: None 28 | batch_sampler: None 29 | num_workers: int 30 | pin_memory: bool 31 | drop_last: bool 32 | timeout: int 33 | worker_init_fn: None 34 | prefetch_factor: int | None 35 | persistent_workers: bool 36 | 37 | 38 | def load_satellite_zarrs(zarr_path: list[str] | tuple[str] | str) -> xr.Dataset: 39 | """Load the satellite data 40 | 41 | Args: 42 | zarr_path: The path to the satellite zarr(s) 43 | """ 44 | 45 | if isinstance(zarr_path, list | tuple): 46 | ds = xr.concat( 47 | [xr.open_dataset(path, engine="zarr", chunks="auto") for path in zarr_path], 48 | dim="time", 49 | coords="minimal", 50 | compat="identical", 51 | combine_attrs="override", 52 | ).sortby("time") 53 | else: 54 | ds = xr.open_dataset(zarr_path, engine="zarr", chunks="auto").sortby("time") 55 | 56 | return ds 57 | 58 | 59 | def find_valid_t0_times( 60 | datetimes: pd.DatetimeIndex, 61 | history_mins: int, 62 | forecast_mins: int, 63 | sample_freq_mins: int, 64 | ) -> pd.DatetimeIndex: 65 | """Constuct an array of all t0 times which are valid considering the gaps in the sat data""" 66 | 67 | # Find periods where we have contiguous time steps 68 | contiguous_time_periods = find_contiguous_time_periods( 69 | datetimes=datetimes, 70 | min_seq_length=int((history_mins + forecast_mins) / sample_freq_mins) + 1, 71 | max_gap_duration=timedelta(minutes=sample_freq_mins), 72 | ) 73 | 74 | # Find periods of valid init-times 75 | contiguous_t0_periods = find_contiguous_t0_time_periods( 76 | contiguous_time_periods=contiguous_time_periods, 77 | history_duration=timedelta(minutes=history_mins), 78 | forecast_duration=timedelta(minutes=forecast_mins), 79 | ) 80 | 81 | valid_t0_times = [] 82 | for _, row in contiguous_t0_periods.iterrows(): 83 | valid_t0_times.append( 84 | pd.date_range(row["start_dt"], row["end_dt"], freq=f"{sample_freq_mins}min") 85 | ) 86 | 87 | return pd.to_datetime(np.concatenate(valid_t0_times)) 88 | 89 | 90 | DataIndex = str | datetime | pd.Timestamp | int 91 | 92 | 93 | class SatelliteDataset(Dataset[tuple[NDArray[np.float32], NDArray[np.float32]]]): 94 | def __init__( 95 | self, 96 | zarr_path: list[str] | str, 97 | start_time: str | None, 98 | end_time: str | None, 99 | history_mins: int, 100 | forecast_mins: int, 101 | sample_freq_mins: int, 102 | variables: list[str] | str | None = None, 103 | preshuffle: bool = False, 104 | nan_to_num: bool = False, 105 | ): 106 | """A torch Dataset for loading past and future satellite data 107 | 108 | Args: 109 | zarr_path (list[str] | str): Path to the satellite data. 
Can be a string or list 110 | start_time (str): The satellite data is filtered to exclude timestamps before this 111 | end_time (str): The satellite data is filtered to exclude timestamps after this 112 | history_mins (int): How many minutes of history will be used as input features 113 | forecast_mins (int): How many minutes of future will be used as target features 114 | sample_freq_mins (int): The sample frequency to use for the satellite data 115 | variables (list[str] | str): The variables to load from the satellite data 116 | (defaults to all) 117 | preshuffle (bool): Whether to shuffle the data - useful for validation. 118 | Defaults to False. 119 | nan_to_num (bool): Whether to convert NaNs to -1. Defaults to False. 120 | """ 121 | 122 | # Load the sat zarr file or list of files and slice the data to the given period 123 | ds = load_satellite_zarrs(zarr_path).sel(time=slice(start_time, end_time)) 124 | 125 | if variables is not None: 126 | if isinstance(variables, str): 127 | variables = [variables] 128 | self.ds = ds.sel(variable=variables) 129 | else: 130 | self.ds = ds 131 | 132 | # Convert the satellite data to the given time frequency by selection 133 | mask = np.mod(self.ds.time.dt.minute, sample_freq_mins) == 0 134 | self.ds = self.ds.sel(time=mask) 135 | 136 | # Find the valid t0 times for the available data. This avoids trying to take samples where 137 | # there would be a missing timestamp in the sat data required for the sample 138 | self.t0_times = self._find_t0_times( 139 | pd.DatetimeIndex(self.ds.time), history_mins, forecast_mins, sample_freq_mins 140 | ) 141 | 142 | if preshuffle: 143 | self.t0_times = pd.to_datetime(np.random.permutation(self.t0_times)) 144 | 145 | self.history_mins = history_mins 146 | self.forecast_mins = forecast_mins 147 | self.sample_freq_mins = sample_freq_mins 148 | self.nan_to_num = nan_to_num 149 | 150 | @staticmethod 151 | def _find_t0_times( 152 | date_range: pd.DatetimeIndex, history_mins: int, forecast_mins: int, sample_freq_mins: int 153 | ) -> pd.DatetimeIndex: 154 | return find_valid_t0_times(date_range, history_mins, forecast_mins, sample_freq_mins) 155 | 156 | def __len__(self) -> int: 157 | return len(self.t0_times) 158 | 159 | def _get_datetime(self, t0: datetime) -> tuple[NDArray[np.float32], NDArray[np.float32]]: 160 | ds_sel = self.ds.sel( 161 | time=slice( 162 | t0 - timedelta(minutes=self.history_mins), 163 | t0 + timedelta(minutes=self.forecast_mins), 164 | ) 165 | ) 166 | 167 | # Load the data eagerly so that the same chunks aren't loaded multiple times after we split 168 | # further 169 | ds_sel = ds_sel.compute(scheduler="single-threaded") 170 | 171 | # Reshape to (channel, time, height, width) 172 | ds_sel = ds_sel.transpose("variable", "time", "y_geostationary", "x_geostationary") 173 | 174 | ds_input = ds_sel.sel(time=slice(None, t0)) 175 | ds_target = ds_sel.sel(time=slice(t0 + timedelta(minutes=self.sample_freq_mins), None)) 176 | 177 | # Convert to arrays 178 | X = ds_input.data.values 179 | y = ds_target.data.values 180 | 181 | if self.nan_to_num: 182 | X = np.nan_to_num(X, nan=-1) 183 | y = np.nan_to_num(y, nan=-1) 184 | 185 | return X.astype(np.float32), y.astype(np.float32) 186 | 187 | def __getitem__(self, key: DataIndex) -> tuple[NDArray[np.float32], NDArray[np.float32]]: 188 | if isinstance(key, int): 189 | t0 = self.t0_times[key] 190 | 191 | else: 192 | assert isinstance(key, str | datetime | pd.Timestamp) 193 | t0 = pd.Timestamp(key) 194 | assert t0 in self.t0_times 195 | 196 | return 
self._get_datetime(t0) 197 | 198 | 199 | class ValidationSatelliteDataset(SatelliteDataset): 200 | def __init__( 201 | self, 202 | zarr_path: list[str] | str, 203 | history_mins: int, 204 | forecast_mins: int = FORECAST_HORIZON_MINUTES, 205 | sample_freq_mins: int = DATA_INTERVAL_SPACING_MINUTES, 206 | nan_to_num: bool = False, 207 | ): 208 | """A torch Dataset used only in the validation proceedure. 209 | 210 | Args: 211 | zarr_path (list[str] | str): Path to the satellite data for validation. 212 | Can be a string or list 213 | history_mins (int): How many minutes of history will be used as input features 214 | forecast_mins (int): How many minutes of future will be used as target features 215 | sample_freq_mins (int): The sample frequency to use for the satellite data 216 | nan_to_num (bool): Whether to convert NaNs to -1. Defaults to False. 217 | """ 218 | 219 | super().__init__( 220 | zarr_path=zarr_path, 221 | start_time=None, 222 | end_time=None, 223 | history_mins=history_mins, 224 | forecast_mins=forecast_mins, 225 | sample_freq_mins=sample_freq_mins, 226 | preshuffle=True, 227 | nan_to_num=nan_to_num, 228 | ) 229 | 230 | @staticmethod 231 | def _find_t0_times( 232 | date_range: pd.DatetimeIndex, history_mins: int, forecast_mins: int, sample_freq_mins: int 233 | ) -> pd.DatetimeIndex: 234 | # Find the valid t0 times for the available data. This avoids trying to take samples where 235 | # there would be a missing timestamp in the sat data required for the sample 236 | available_t0_times = find_valid_t0_times( 237 | date_range, history_mins, forecast_mins, sample_freq_mins 238 | ) 239 | 240 | # Get the required 2022 test dataset t0 times 241 | val_t0_times_from_csv = ValidationSatelliteDataset._get_test_2022_t0_times() 242 | 243 | # Find the intersection of the available t0 times and the required validation t0 times 244 | val_time_available = val_t0_times_from_csv.isin(available_t0_times) 245 | 246 | # Make sure all of the required validation times are available in the data 247 | if not val_time_available.all(): 248 | msg = ( 249 | "The following validation t0 times are not available in the satellite data: \n" 250 | f"{val_t0_times_from_csv[~val_time_available]}\n" 251 | "The validation proceedure requires these t0 times to be available." 
252 | ) 253 | raise ValueError(msg) 254 | 255 | return val_t0_times_from_csv 256 | 257 | @staticmethod 258 | def _get_t0_times(path: str) -> pd.DatetimeIndex: 259 | """Load the required validation t0 times from library path""" 260 | 261 | # Load the zipped csv file as a byte stream 262 | data = pkgutil.get_data("cloudcasting", path) 263 | if data is not None: 264 | byte_stream = io.BytesIO(data) 265 | else: 266 | # Handle the case where data is None 267 | msg = f"No data found for path: {path}" 268 | raise ValueError(msg) 269 | 270 | # Load the times into pandas 271 | df = pd.read_csv(byte_stream, encoding="utf8", compression="zip") 272 | 273 | return pd.DatetimeIndex(df.t0_time) 274 | 275 | @staticmethod 276 | def _get_test_2022_t0_times() -> pd.DatetimeIndex: 277 | """Load the required 2022 test dataset t0 times from their location in the library""" 278 | return ValidationSatelliteDataset._get_t0_times("data/test_2022_t0_times.csv.zip") 279 | 280 | @staticmethod 281 | def _get_verify_2023_t0_times() -> pd.DatetimeIndex: 282 | msg = "The required 2023 verification dataset t0 times are not available" 283 | raise NotImplementedError(msg) 284 | 285 | 286 | class SatelliteDataModule(LightningDataModule): 287 | def __init__( 288 | self, 289 | zarr_path: list[str] | str, 290 | history_mins: int, 291 | forecast_mins: int, 292 | sample_freq_mins: int, 293 | batch_size: int = 16, 294 | num_workers: int = 0, 295 | variables: list[str] | str | None = None, 296 | prefetch_factor: int | None = None, 297 | train_period: list[str | None] | tuple[str | None] | None = None, 298 | val_period: list[str | None] | tuple[str | None] | None = None, 299 | test_period: list[str | None] | tuple[str | None] | None = None, 300 | nan_to_num: bool = False, 301 | pin_memory: bool = False, 302 | persistent_workers: bool = False, 303 | ): 304 | """A lightning DataModule for loading past and future satellite data 305 | 306 | Args: 307 | zarr_path (list[str] | str): Path to the satellite data. Can be a string or list 308 | history_mins (int): How many minutes of history will be used as input features 309 | forecast_mins (int): How many minutes of future will be used as target features 310 | sample_freq_mins (int): The sample frequency to use for the satellite data 311 | batch_size (int): Batch size. Defaults to 16. 312 | num_workers (int): Number of workers to use in multiprocess batch loading. 313 | Defaults to 0. 314 | variables (list[str] | str): The variables to load from the satellite data 315 | (defaults to all) 316 | prefetch_factor (int): Number of data to be prefetched at the end of each worker process 317 | train_period (list[str] | tuple[str] | None): Date range filter for train dataloader 318 | val_period (list[str] | tuple[str] | None): Date range filter for validation dataloader 319 | test_period (list[str] | tuple[str] | None): Date range filter for test dataloader 320 | nan_to_num (bool): Whether to convert NaNs to -1. Defaults to False. 321 | pin_memory (bool): If True, the data loader will copy Tensors into device/CUDA 322 | pinned memory before returning them. Defaults to False. 323 | persistent_workers (bool): If True, the data loader will not shut down the worker 324 | processes after a dataset has been consumed once. This allows you to keep the 325 | workers Dataset instances alive. Defaults to False. 
326 | """ 327 | super().__init__() 328 | 329 | if train_period is None: 330 | train_period = [None, None] 331 | if val_period is None: 332 | val_period = [None, None] 333 | if test_period is None: 334 | test_period = [None, None] 335 | 336 | assert len(train_period) == 2 337 | assert len(val_period) == 2 338 | assert len(test_period) == 2 339 | 340 | self.train_period = train_period 341 | self.val_period = val_period 342 | self.test_period = test_period 343 | 344 | self.zarr_path = zarr_path 345 | self.history_mins = history_mins 346 | self.forecast_mins = forecast_mins 347 | self.sample_freq_mins = sample_freq_mins 348 | 349 | self._common_dataloader_kwargs = DataloaderArgs( 350 | batch_size=batch_size, 351 | sampler=None, 352 | batch_sampler=None, 353 | num_workers=num_workers, 354 | pin_memory=pin_memory, 355 | drop_last=False, 356 | timeout=0, 357 | worker_init_fn=None, 358 | prefetch_factor=prefetch_factor, 359 | persistent_workers=persistent_workers, 360 | ) 361 | 362 | self.nan_to_num = nan_to_num 363 | self.variables = variables 364 | 365 | def _make_dataset( 366 | self, start_date: str | None, end_date: str | None, preshuffle: bool = False 367 | ) -> SatelliteDataset: 368 | return SatelliteDataset( 369 | zarr_path=self.zarr_path, 370 | start_time=start_date, 371 | end_time=end_date, 372 | history_mins=self.history_mins, 373 | forecast_mins=self.forecast_mins, 374 | sample_freq_mins=self.sample_freq_mins, 375 | preshuffle=preshuffle, 376 | nan_to_num=self.nan_to_num, 377 | variables=self.variables, 378 | ) 379 | 380 | def train_dataloader(self) -> DataLoader[tuple[NDArray[np.float32], NDArray[np.float32]]]: 381 | """Construct train dataloader""" 382 | dataset = self._make_dataset(self.train_period[0], self.train_period[1]) 383 | return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs) 384 | 385 | def val_dataloader(self) -> DataLoader[tuple[NDArray[np.float32], NDArray[np.float32]]]: 386 | """Construct validation dataloader""" 387 | dataset = self._make_dataset(self.val_period[0], self.val_period[1], preshuffle=True) 388 | return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs) 389 | 390 | def test_dataloader(self) -> DataLoader[tuple[NDArray[np.float32], NDArray[np.float32]]]: 391 | """Construct test dataloader""" 392 | dataset = self._make_dataset(self.test_period[0], self.test_period[1]) 393 | return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs) 394 | -------------------------------------------------------------------------------- /src/cloudcasting/download.py: -------------------------------------------------------------------------------- 1 | __all__ = ("download_satellite_data",) 2 | 3 | import logging 4 | import os 5 | from typing import Annotated 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import typer 10 | import xarray as xr 11 | from dask.diagnostics import ProgressBar # type: ignore[attr-defined] 12 | 13 | from cloudcasting.utils import lon_lat_to_geostationary_area_coords 14 | 15 | xr.set_options(keep_attrs=True) 16 | 17 | logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def _get_sat_public_dataset_path(year: int, is_hrv: bool = False) -> str: 22 | """ 23 | Get the path to the Google Public Dataset of EUMETSAT satellite data. 24 | 25 | Args: 26 | year: The year of the dataset. 27 | is_hrv: Whether to get the HRV dataset or not. 28 | 29 | Returns: 30 | The path to the dataset. 
31 | """ 32 | file_end = "hrv.zarr" if is_hrv else "nonhrv.zarr" 33 | return f"gs://public-datasets-eumetsat-solar-forecasting/satellite/EUMETSAT/SEVIRI_RSS/v4/{year}_{file_end}" 34 | 35 | 36 | def download_satellite_data( 37 | start_date: Annotated[str, typer.Argument(help="Start date in 'YYYY-MM-DD HH:MM' format")], 38 | end_date: Annotated[str, typer.Argument(help="End date in 'YYYY-MM-DD HH:MM' format")], 39 | output_directory: Annotated[str, typer.Argument(help="Directory to save the satellite data")], 40 | download_frequency: Annotated[ 41 | str, typer.Option(help="Frequency to download data in pandas datetime format") 42 | ] = "15min", 43 | get_hrv: Annotated[bool, typer.Option(help="Whether to download HRV data")] = False, 44 | override_date_bounds: Annotated[ 45 | bool, typer.Option(help="Whether to override date range limits") 46 | ] = False, 47 | lon_min: Annotated[float, typer.Option(help="Minimum longitude")] = -16, 48 | lon_max: Annotated[float, typer.Option(help="Maximum longitude")] = 10, 49 | lat_min: Annotated[float, typer.Option(help="Minimum latitude")] = 45, 50 | lat_max: Annotated[float, typer.Option(help="Maximum latitude")] = 70, 51 | test_2022_set: Annotated[ 52 | bool, 53 | typer.Option( 54 | help="Whether to filter data from 2022 to download the test set (every 2 weeks)." 55 | ), 56 | ] = False, 57 | verify_2023_set: Annotated[ 58 | bool, 59 | typer.Option( 60 | help="Whether to download the verification data from 2023. Only used at project end" 61 | ), 62 | ] = False, 63 | ) -> None: 64 | """ 65 | Download a selection of the available EUMETSAT data. 66 | 67 | Each calendar year of data within the supplied date range will be saved to a separate file in 68 | the output directory. 69 | 70 | Args: 71 | start_date (str): First datetime (inclusive) to download in 'YYYY-MM-DD HH:MM' format 72 | end_date (str): Last datetime (inclusive) to download in 'YYYY-MM-DD HH:MM' format 73 | output_directory (str): Directory to which the satellite data should be saved 74 | download_frequency (str): Frequency to download data in pandas datetime format. 75 | Defaults to "15min". 76 | get_hrv (bool): Whether to download the HRV data, otherwise only non-HRV is downloaded. 77 | Defaults to False. 78 | override_date_bounds (bool): Whether to override the date range limits 79 | lon_min (float): The west-most longitude (in degrees) of the bounding box to download. 80 | Defaults to -16. 81 | lon_max (float): The east-most longitude (in degrees) of the bounding box to download. 82 | Defaults to 10. 83 | lat_min (float): The south-most latitude (in degrees) of the bounding box to download. 84 | Defaults to 45. 85 | lat_max (float): The north-most latitude (in degrees) of the bounding box to download. 86 | Defaults to 70. 87 | test_2022_set (bool): Whether to filter data from 2022 to download the test set 88 | (every 2 weeks) 89 | verify_2023_set (bool): Whether to download verification data from 2023. Only 90 | used at project end. 91 | 92 | Raises: 93 | FileNotFoundError: If the output directory doesn't exist. 94 | ValueError: If there are issues with the date range or if output files already exist. 95 | """ 96 | 97 | # Check output directory exists 98 | if not os.path.isdir(output_directory): 99 | msg = ( 100 | f"Output directory {output_directory} does not exist. " 101 | "Please create it before attempting to download satellite data." 102 | ) 103 | raise FileNotFoundError(msg) 104 | 105 | # Build the formatable string for the output file path. 
106 |     # We can insert year later using `output_file_root.format(year=year)`
107 |     output_file_root = output_directory + "/{year}_"
108 |
109 |     # Add training split label
110 |     if test_2022_set:
111 |         output_file_root += "test_"
112 |     elif verify_2023_set:
113 |         output_file_root += "verification_"
114 |     else:
115 |         output_file_root += "training_"
116 |
117 |     # Add HRV or non-HRV label and file extension
118 |     if get_hrv:
119 |         output_file_root += "hrv.zarr"
120 |     else:
121 |         output_file_root += "nonhrv.zarr"
122 |
123 |     # Check download frequency is valid (i.e. is a pandas frequency + multiple of 5 minutes)
124 |     if pd.Timedelta(download_frequency) % pd.Timedelta("5min") != pd.Timedelta(0):
125 |         msg = (
126 |             f"Download frequency {download_frequency} is not a multiple of 5 minutes. "
127 |             "Please choose a valid frequency."
128 |         )
129 |         raise ValueError(msg)
130 |
131 |     start_date_stamp = pd.Timestamp(start_date)
132 |     end_date_stamp = pd.Timestamp(end_date)
133 |
134 |     # Check start date is before end date
135 |     if start_date_stamp > end_date_stamp:
136 |         msg = f"Start date ({start_date_stamp}) must be before end date ({end_date_stamp})."
137 |         raise ValueError(msg)
138 |
139 |     # Check date range for known limitations
140 |     if not override_date_bounds and start_date_stamp.year < 2019:
141 |         msg = (
142 |             "There are currently some issues with the EUMETSAT data before 2019/01/01. "
143 |             "We recommend only using data from this date forward. "
144 |             "To override this error set `override_date_bounds=True`"
145 |         )
146 |         raise ValueError(msg)
147 |
148 |     # Check the year is 2022 if test_2022 data is being downloaded
149 |     if test_2022_set and (start_date_stamp.year != 2022 or end_date_stamp.year != 2022):
150 |         msg = "Test data is only defined for 2022"
151 |         raise ValueError(msg)
152 |
153 |     # Check the start / end dates are correct if verification data is being downloaded
154 |     if verify_2023_set and (
155 |         start_date_stamp != pd.Timestamp("2023-01-01 00:00")
156 |         or end_date_stamp != pd.Timestamp("2023-12-31 23:55")
157 |     ):
158 |         msg = (
159 |             "Verification data requires a start date of '2023-01-01 00:00' "
160 |             "and an end date of '2023-12-31 23:55'"
161 |         )
162 |         raise ValueError(msg)
163 |
164 |     # Check the year 2023 is not included unless verification data is being downloaded
165 |     if (not verify_2023_set) and (end_date_stamp.year >= 2023):
166 |         msg = "2023 data is reserved for the verification process"
167 |         raise ValueError(msg)
168 |
169 |     years = range(start_date_stamp.year, end_date_stamp.year + 1)
170 |
171 |     # Ceiling the start date to the nearest multiple of the download frequency
172 |     # breaks down over multiple days due to starting at the Unix epoch (1970-01-01, a Thursday),
173 |     # e.g. 2022-01-01 ceiled to 1 week will be 2022-01-06 (the closest Thursday to 2022-01-01).
174 |     range_start = (
175 |         start_date_stamp.ceil(download_frequency)
176 |         if pd.Timedelta(download_frequency) <= pd.Timedelta("1day")
177 |         else start_date_stamp
178 |     )
179 |     # Create a list of dates to download
180 |     dates_to_download = pd.date_range(range_start, end_date_stamp, freq=download_frequency)
181 |
182 |     # Check that none of the filenames we will save to already exist
183 |     for year in years:
184 |         output_zarr_file = output_file_root.format(year=year)
185 |         if os.path.exists(output_zarr_file):
186 |             msg = (
187 |                 f"The zarr file {output_zarr_file} already exists. "
188 |                 "This function will not overwrite data."
189 | ) 190 | raise ValueError(msg) 191 | 192 | # Begin download loop 193 | for year in years: 194 | logger.info("Downloading data from %s", year) 195 | path = _get_sat_public_dataset_path(year, is_hrv=get_hrv) 196 | 197 | # Slice the data from this year which are between the start and end dates. 198 | ds = xr.open_zarr(path, chunks={}).sortby("time") 199 | 200 | ds = ds.sel(time=dates_to_download[dates_to_download.isin(ds.time.values)]) 201 | 202 | if year == 2022: 203 | set_str = "Test_2022" if test_2022_set else "Training" 204 | day_str = "14" if test_2022_set else "1" 205 | logger.info("Data in 2022 will be downloaded every 2 weeks due to train/test split.") 206 | logger.info("%s set selected: Starting day will be %s.", set_str, day_str) 207 | # Integer division by 14 will tell us the fortnight we're on. 208 | # checking the mod wrt 2 will let us select every 2 weeks 209 | # Test set is defined as from week 2-3, 6-7 etc. 210 | # Weeks 0-1, 4-5 etc. are included in training set 211 | if test_2022_set: 212 | mask = np.mod(ds.time.dt.dayofyear // 14, 2) == 1 213 | else: 214 | mask = np.mod(ds.time.dt.dayofyear // 14, 2) == 0 215 | ds = ds.sel(time=mask) 216 | 217 | # Convert lon-lat bounds to geostationary-coords 218 | (x_min, x_max), (y_min, y_max) = lon_lat_to_geostationary_area_coords( 219 | [lon_min, lon_max], 220 | [lat_min, lat_max], 221 | ds.data, 222 | ) 223 | 224 | # Define the spatial area to slice from 225 | ds = ds.sel( 226 | x_geostationary=slice(x_max, x_min), # x-axis is in decreasing order 227 | y_geostationary=slice(y_min, y_max), 228 | ) 229 | 230 | # Re-chunking 231 | for v in ds.variables: 232 | if "chunks" in ds[v].encoding: 233 | del ds[v].encoding["chunks"] 234 | 235 | target_chunks_dict = { 236 | "time": 2, 237 | "x_geostationary": -1, 238 | "y_geostationary": -1, 239 | "variable": -1, 240 | } 241 | 242 | ds = ds.chunk(target_chunks_dict) 243 | 244 | # Save data to zarr 245 | output_zarr_file = output_file_root.format(year=year) 246 | with ProgressBar(dt=1): 247 | ds.to_zarr(output_zarr_file) 248 | logger.info("Data for %s saved to %s", year, output_zarr_file) 249 | -------------------------------------------------------------------------------- /src/cloudcasting/metrics.py: -------------------------------------------------------------------------------- 1 | # mypy: ignore-errors 2 | ### Adapted by The Alan Turing Institute and Open Climate Fix, 2024. 3 | ### Original source: https://github.com/google-deepmind/dm_pix 4 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """Functions to compare image pairs. 18 | 19 | All functions expect float-encoded images, with values in [0, 1], with NHWC 20 | shapes. Each image metric function returns a scalar for each image pair. 21 | """ 22 | 23 | from collections.abc import Callable 24 | 25 | import chex 26 | import jax 27 | import jax.numpy as jnp 28 | 29 | # DO NOT REMOVE - Logging lib. 
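
# A minimal usage sketch (shapes and values below are illustrative only): each
# metric expects float arrays in [0, 1] with [height, width, channels] or
# [batch, height, width, channels] layout, and reduces over the last three axes.
#
#     import jax.numpy as jnp
#     a = jnp.zeros((2, 64, 64, 1))  # [batch, height, width, channels]
#     b = jnp.ones((2, 64, 64, 1))
#     mae(a, b)  # -> array of shape [batch], equal to [1., 1.] here
#     mse(a, b)  # -> array of shape [batch], equal to [1., 1.] here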
30 | 31 | 32 | def mae(a: chex.Array, b: chex.Array, ignore_nans: bool = False) -> chex.Numeric: 33 | """Returns the Mean Absolute Error between `a` and `b`. 34 | 35 | Args: 36 | a (chex.Array): First image (or set of images) 37 | b (chex.Array): Second image (or set of images) 38 | ignore_nans (bool): Defaults to False 39 | 40 | Returns: 41 | chex.Numeric: MAE between `a` and `b` 42 | """ 43 | # DO NOT REMOVE - Logging usage. 44 | 45 | chex.assert_rank([a, b], {3, 4}) 46 | chex.assert_type([a, b], float) 47 | chex.assert_equal_shape([a, b]) 48 | if ignore_nans: 49 | return jnp.nanmean(jnp.abs(a - b), axis=(-3, -2, -1)) 50 | return jnp.mean(jnp.abs(a - b), axis=(-3, -2, -1)) 51 | 52 | 53 | def mse(a: chex.Array, b: chex.Array, ignore_nans: bool = False) -> chex.Numeric: 54 | """Returns the Mean Squared Error between `a` and `b`. 55 | 56 | Args: 57 | a (chex.Array): First image (or set of images) 58 | b (chex.Array): Second image (or set of images) 59 | ignore_nans (bool): Defaults to False 60 | 61 | Returns: 62 | chex.Numeric: MSE between `a` and `b` 63 | """ 64 | # DO NOT REMOVE - Logging usage. 65 | 66 | chex.assert_rank([a, b], {3, 4}) 67 | chex.assert_type([a, b], float) 68 | chex.assert_equal_shape([a, b]) 69 | if ignore_nans: 70 | return jnp.nanmean(jnp.square(a - b), axis=(-3, -2, -1)) 71 | return jnp.mean(jnp.square(a - b), axis=(-3, -2, -1)) 72 | 73 | 74 | def psnr(a: chex.Array, b: chex.Array) -> chex.Numeric: 75 | """Returns the Peak Signal-to-Noise Ratio between `a` and `b`. 76 | 77 | Assumes that the dynamic range of the images (the difference between the 78 | maximum and the minimum allowed values) is 1.0. 79 | 80 | Args: 81 | a (chex.Array): First image (or set of images) 82 | b (chex.Array): Second image (or set of images) 83 | 84 | Returns: 85 | chex.Numeric: PSNR in decibels between `a` and `b` 86 | """ 87 | # DO NOT REMOVE - Logging usage. 88 | 89 | chex.assert_rank([a, b], {3, 4}) 90 | chex.assert_type([a, b], float) 91 | chex.assert_equal_shape([a, b]) 92 | return -10.0 * jnp.log(mse(a, b)) / jnp.log(10.0) 93 | 94 | 95 | def rmse(a: chex.Array, b: chex.Array) -> chex.Numeric: 96 | """Returns the Root Mean Squared Error between `a` and `b`. 97 | 98 | Args: 99 | a (chex.Array): First image (or set of images) 100 | b (chex.Array): Second image (or set of images) 101 | 102 | Returns: 103 | chex.Numeric: RMSE between `a` and `b` 104 | """ 105 | # DO NOT REMOVE - Logging usage. 106 | 107 | chex.assert_rank([a, b], {3, 4}) 108 | chex.assert_type([a, b], float) 109 | chex.assert_equal_shape([a, b]) 110 | return jnp.sqrt(mse(a, b)) 111 | 112 | 113 | def simse(a: chex.Array, b: chex.Array) -> chex.Numeric: 114 | """Returns the Scale-Invariant Mean Squared Error between `a` and `b`. 115 | 116 | For each image pair, a scaling factor for `b` is computed as the solution to 117 | the following problem: 118 | 119 | min_alpha || vec(a) - alpha * vec(b) ||_2^2 120 | 121 | where `a` and `b` are flattened, i.e., vec(x) = np.flatten(x). The MSE between 122 | the optimally scaled `b` and `a` is returned: mse(a, alpha*b). 123 | 124 | This is a scale-invariant metric, so for example: simse(x, y) == sims(x, y*5). 125 | 126 | This metric was used in "Shape, Illumination, and Reflectance from Shading" by 127 | Barron and Malik, TPAMI, '15. 128 | 129 | Args: 130 | a (chex.Array): First image (or set of images) 131 | b (chex.Array): Second image (or set of images) 132 | 133 | Returns: 134 | chex.Numeric: SIMSE between `a` and `b` 135 | """ 136 | # DO NOT REMOVE - Logging usage. 
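
    # Note: the optimal scaling of `b` has a closed form. Minimising
    # ||vec(a) - alpha * vec(b)||_2^2 over alpha gives alpha = <a, b> / <b, b>,
    # which is computed per image below before evaluating mse(a, alpha * b).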
137 | 138 | chex.assert_rank([a, b], {3, 4}) 139 | chex.assert_type([a, b], float) 140 | chex.assert_equal_shape([a, b]) 141 | 142 | a_dot_b = (a * b).sum(axis=(-3, -2, -1), keepdims=True) 143 | b_dot_b = (b * b).sum(axis=(-3, -2, -1), keepdims=True) 144 | alpha = a_dot_b / b_dot_b 145 | return mse(a, alpha * b) 146 | 147 | 148 | def ssim( 149 | a: chex.Array, 150 | b: chex.Array, 151 | *, 152 | max_val: float = 1.0, 153 | filter_size: int = 11, 154 | filter_sigma: float = 1.5, 155 | k1: float = 0.01, 156 | k2: float = 0.03, 157 | return_map: bool = False, 158 | precision=jax.lax.Precision.HIGHEST, 159 | filter_fn: Callable[[chex.Array], chex.Array] | None = None, 160 | ignore_nans: bool = False, 161 | ) -> chex.Numeric: 162 | """Computes the structural similarity index (SSIM) between image pairs. 163 | 164 | This function is based on the standard SSIM implementation from: 165 | Z. Wang, A. C. Bovik, H. R. Sheikh and E. P. Simoncelli, 166 | "Image quality assessment: from error visibility to structural similarity", 167 | in IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, 2004. 168 | 169 | This function was modeled after tf.image.ssim, and should produce comparable 170 | output. 171 | 172 | Note: the true SSIM is only defined on grayscale. This function does not 173 | perform any colorspace transform. If the input is in a color space, then it 174 | will compute the average SSIM. 175 | 176 | Args: 177 | a (chex.Array): First image (or set of images) 178 | b (chex.Array): Second image (or set of images) 179 | max_val (float): The maximum magnitude that `a` or `b` can have. Defaults to 1. 180 | filter_size (int): Window size (>= 1). Image dims must be at least this small. 181 | Defaults to 11 182 | filter_sigma (float): The bandwidth of the Gaussian used for filtering (> 0.). 183 | Defaults to 1.5 184 | k1 (float): One of the SSIM dampening parameters (> 0.). Defaults to 0.01. 185 | k2 (float): One of the SSIM dampening parameters (> 0.). Defaults to 0.03. 186 | return_map (bool): If True, will cause the per-pixel SSIM "map" to be returned. 187 | Defaults to False. 188 | precision: The numerical precision to use when performing convolution 189 | filter_fn (Callable[[chex.Array], chex.Array] | None): An optional argument for 190 | overriding the filter function used by SSIM, which would otherwise be a 2D 191 | Gaussian blur specified by filter_size and filter_sigma 192 | ignore_nans (bool): Defaults to False 193 | 194 | Returns: 195 | chex.Numeric: Each image's mean SSIM, or a tensor of individual values if `return_map` is True 196 | """ 197 | # DO NOT REMOVE - Logging usage. 198 | 199 | chex.assert_rank([a, b], {3, 4}) 200 | chex.assert_type([a, b], float) 201 | chex.assert_equal_shape([a, b]) 202 | 203 | if filter_fn is None: 204 | # Construct a 1D Gaussian blur filter. 205 | hw = filter_size // 2 206 | shift = (2 * hw - filter_size + 1) / 2 207 | f_i = ((jnp.arange(filter_size) - hw + shift) / filter_sigma) ** 2 208 | filt = jnp.exp(-0.5 * f_i) 209 | filt /= jnp.sum(filt) 210 | 211 | # Construct a 1D convolution. 212 | def filter_fn_1(z): 213 | return jnp.convolve(z, filt, mode="valid", precision=precision) 214 | 215 | filter_fn_vmap = jax.vmap(filter_fn_1) 216 | 217 | # Apply the vectorized filter along the y axis. 
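        # (The 2-D Gaussian blur is separable, so it is applied as two 1-D passes,
        # one along each spatial axis, using the vmapped 1-D convolution above.)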
218 | def filter_fn_y(z): 219 | z_flat = jnp.moveaxis(z, -3, -1).reshape((-1, z.shape[-3])) 220 | z_filtered_shape = ((z.shape[-4],) if z.ndim == 4 else ()) + ( 221 | z.shape[-2], 222 | z.shape[-1], 223 | -1, 224 | ) 225 | return jnp.moveaxis(filter_fn_vmap(z_flat).reshape(z_filtered_shape), -1, -3) 226 | 227 | # Apply the vectorized filter along the x axis. 228 | def filter_fn_x(z): 229 | z_flat = jnp.moveaxis(z, -2, -1).reshape((-1, z.shape[-2])) 230 | z_filtered_shape = ((z.shape[-4],) if z.ndim == 4 else ()) + ( 231 | z.shape[-3], 232 | z.shape[-1], 233 | -1, 234 | ) 235 | return jnp.moveaxis(filter_fn_vmap(z_flat).reshape(z_filtered_shape), -1, -2) 236 | 237 | # Apply the blur in both x and y. 238 | def filter_fn(z): 239 | return filter_fn_y(filter_fn_x(z)) 240 | 241 | mu0 = filter_fn(a) 242 | mu1 = filter_fn(b) 243 | mu00 = mu0 * mu0 244 | mu11 = mu1 * mu1 245 | mu01 = mu0 * mu1 246 | sigma00 = filter_fn(a**2) - mu00 247 | sigma11 = filter_fn(b**2) - mu11 248 | sigma01 = filter_fn(a * b) - mu01 249 | 250 | # Clip the variances and covariances to valid values. 251 | # Variance must be non-negative: 252 | epsilon = jnp.finfo(jnp.float32).eps ** 2 253 | sigma00 = jnp.maximum(epsilon, sigma00) 254 | sigma11 = jnp.maximum(epsilon, sigma11) 255 | sigma01 = jnp.sign(sigma01) * jnp.minimum(jnp.sqrt(sigma00 * sigma11), jnp.abs(sigma01)) 256 | 257 | c1 = (k1 * max_val) ** 2 258 | c2 = (k2 * max_val) ** 2 259 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 260 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 261 | ssim_map = numer / denom 262 | 263 | if ignore_nans: 264 | ssim_value = jnp.nanmean(ssim_map, axis=tuple(range(-3, 0))) 265 | else: 266 | ssim_value = jnp.mean(ssim_map, axis=tuple(range(-3, 0))) 267 | return ssim_map if return_map else ssim_value 268 | -------------------------------------------------------------------------------- /src/cloudcasting/models.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any 3 | 4 | import numpy as np 5 | 6 | from cloudcasting.constants import ( 7 | DATA_INTERVAL_SPACING_MINUTES, 8 | FORECAST_HORIZON_MINUTES, 9 | NUM_FORECAST_STEPS, 10 | ) 11 | from cloudcasting.types import BatchInputArray, BatchOutputArray 12 | 13 | 14 | # model interface 15 | class AbstractModel(ABC): 16 | """An abstract class for validating a generic satellite prediction model""" 17 | 18 | history_steps: int 19 | 20 | def __init__(self, history_steps: int) -> None: 21 | self.history_steps: int = history_steps 22 | 23 | @abstractmethod 24 | def forward(self, X: BatchInputArray) -> BatchOutputArray: 25 | """Abstract method for the forward pass of the model. 26 | 27 | Args: 28 | X (BatchInputArray): Either a batch or a sample of the most recent satellite data. 29 | X will be 5 dimensional and has shape [batch, channels, time, height, width] where 30 | time = {t_{-n}, ..., t_{0}} 31 | (all n values needed to predict {t'_{1}, ..., t'_{horizon}}) 32 | 33 | Returns: 34 | ForecastArray: The models prediction of the future satellite 35 | data of shape [batch, channels, rollout_steps, height, width] where 36 | rollout_steps = {t'_{1}, ..., t'_{horizon}}. 37 | """ 38 | 39 | def check_predictions(self, y_hat: BatchOutputArray) -> None: 40 | """Checks the predictions conform to expectations""" 41 | # Check no NaNs in the predictions 42 | if np.isnan(y_hat).any(): 43 | msg = f"Predictions contain NaNs: {np.isnan(y_hat).mean()=:.4g}." 
44 | raise ValueError(msg) 45 | 46 | # Check the range of the predictions. If outside the expected range this can interfere 47 | # with computing metrics like structural similarity 48 | if ((y_hat < 0) | (y_hat > 1)).any(): 49 | msg = ( 50 | "The predictions must be in the range [0, 1]. " 51 | f"Found range [{y_hat.min(), y_hat.max()}]." 52 | ) 53 | raise ValueError(msg) 54 | 55 | if y_hat.shape[-3] != NUM_FORECAST_STEPS: 56 | msg = ( 57 | f"The number of forecast steps in the model ({y_hat.shape[2]}) does not match " 58 | f"{NUM_FORECAST_STEPS=}, defined from " 59 | f"{FORECAST_HORIZON_MINUTES=} // {DATA_INTERVAL_SPACING_MINUTES=}." 60 | f"Found shape {y_hat.shape}." 61 | ) 62 | raise ValueError(msg) 63 | 64 | def __call__(self, X: BatchInputArray) -> BatchOutputArray: 65 | # check the shape of the input 66 | if X.shape[-3] != self.history_steps: 67 | msg = ( 68 | f"The number of history steps in the input ({X.shape[-3]}) does not match " 69 | f"{self.history_steps=}." 70 | ) 71 | raise ValueError(msg) 72 | 73 | # run the forward pass 74 | y_hat = self.forward(X) 75 | 76 | # carry out a set of checks on the predictions to make sure they conform to the 77 | # expectations of the validation script 78 | self.check_predictions(y_hat) 79 | 80 | return y_hat 81 | 82 | @abstractmethod 83 | def hyperparameters_dict(self) -> dict[str, Any]: 84 | """Return a dictionary of the hyperparameters used to train the model""" 85 | 86 | 87 | class VariableHorizonModel(AbstractModel): 88 | def __init__(self, rollout_steps: int, history_steps: int) -> None: 89 | self.rollout_steps: int = rollout_steps 90 | super().__init__(history_steps) 91 | -------------------------------------------------------------------------------- /src/cloudcasting/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alan-turing-institute/cloudcasting/0da4087bba01823d2477dba61ea3e6cfd557212c/src/cloudcasting/py.typed -------------------------------------------------------------------------------- /src/cloudcasting/types.py: -------------------------------------------------------------------------------- 1 | __all__ = ( 2 | "BatchInputArray", 3 | "BatchOutputArray", 4 | "BatchOutputArrayJAX", 5 | "ChannelArray", 6 | "InputArray", 7 | "MetricArray", 8 | "OutputArray", 9 | "SampleInputArray", 10 | "SampleOutputArray", 11 | "TimeArray", 12 | ) 13 | 14 | import jaxtyping 15 | import numpy as np 16 | import numpy.typing as npt 17 | from jaxtyping import Float as Float32 18 | 19 | # Type aliases for clarity + reuse 20 | Array = npt.NDArray[np.float32] # the type arg is ignored by jaxtyping, but is here for clarity 21 | TimeArray = Float32[Array, "time"] 22 | MetricArray = Float32[Array, "channels time"] 23 | ChannelArray = Float32[Array, "channels"] 24 | 25 | SampleInputArray = Float32[Array, "channels time height width"] 26 | BatchInputArray = Float32[Array, "batch channels time height width"] 27 | InputArray = SampleInputArray | BatchInputArray 28 | 29 | 30 | SampleOutputArray = Float32[Array, "channels rollout_steps height width"] 31 | BatchOutputArray = Float32[Array, "batch channels rollout_steps height width"] 32 | BatchOutputArrayJAX = Float32[jaxtyping.Array, "batch channels rollout_steps height width"] 33 | 34 | OutputArray = SampleOutputArray | BatchOutputArray 35 | -------------------------------------------------------------------------------- /src/cloudcasting/utils.py: -------------------------------------------------------------------------------- 1 | 
__all__ = ( 2 | "find_contiguous_t0_time_periods", 3 | "find_contiguous_time_periods", 4 | "lon_lat_to_geostationary_area_coords", 5 | "numpy_validation_collate_fn", 6 | ) 7 | 8 | from collections.abc import Sequence 9 | from datetime import timedelta 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import pyproj 14 | import pyresample 15 | import xarray as xr 16 | from numpy.typing import NDArray 17 | 18 | from cloudcasting.types import ( 19 | BatchInputArray, 20 | BatchOutputArray, 21 | SampleInputArray, 22 | SampleOutputArray, 23 | ) 24 | 25 | 26 | # taken from ocf_datapipes 27 | def lon_lat_to_geostationary_area_coords( 28 | x: Sequence[float], 29 | y: Sequence[float], 30 | xr_data: xr.Dataset | xr.DataArray, 31 | ) -> tuple[Sequence[float], Sequence[float]]: 32 | """Loads geostationary area and change from lon-lat to geostationary coords 33 | 34 | Args: 35 | x (Sequence[float]): Longitude east-west 36 | y Sequence[float]: Latitude north-south 37 | xr_data (xr.Dataset | xr.DataArray): xarray object with geostationary area 38 | 39 | Returns: 40 | tuple[Sequence[float], Sequence[float]]: x, y in geostationary coordinates 41 | """ 42 | # WGS84 is short for "World Geodetic System 1984", used in GPS. Uses 43 | # latitude and longitude. 44 | WGS84 = 4326 45 | 46 | try: 47 | area_definition_yaml = xr_data.attrs["area"] 48 | except KeyError: 49 | area_definition_yaml = xr_data.data.attrs["area"] 50 | geostationary_crs = pyresample.area_config.load_area_from_string(area_definition_yaml).crs # type: ignore[no-untyped-call] 51 | lonlat_to_geostationary = pyproj.Transformer.from_crs( 52 | crs_from=WGS84, 53 | crs_to=geostationary_crs, 54 | always_xy=True, 55 | ).transform 56 | return lonlat_to_geostationary(xx=x, yy=y) 57 | 58 | 59 | def find_contiguous_time_periods( 60 | datetimes: pd.DatetimeIndex, 61 | min_seq_length: int, 62 | max_gap_duration: timedelta, 63 | ) -> pd.DataFrame: 64 | """Return a pd.DataFrame where each row records the boundary of a contiguous time period. 65 | 66 | Args: 67 | datetimes (pd.DatetimeIndex): Must be sorted. 68 | min_seq_length (int): Sequences of min_seq_length or shorter will be discarded. Typically, 69 | this would be set to the `total_seq_length` of each machine learning example. 70 | max_gap_duration (timedelta): If any pair of consecutive `datetimes` is more than 71 | `max_gap_duration` apart, then this pair of `datetimes` will be considered a "gap" between 72 | two contiguous sequences. Typically, `max_gap_duration` would be set to the sample period of 73 | the timeseries. 74 | 75 | Returns: 76 | pd.DataFrame: The DataFrame has two columns `start_dt` and `end_dt` 77 | (where 'dt' is short for 'datetime'). Each row represents a single time period. 78 | """ 79 | # Sanity checks. 80 | assert len(datetimes) > 0 81 | assert min_seq_length > 1 82 | assert datetimes.is_monotonic_increasing 83 | assert datetimes.is_unique 84 | 85 | # Find indices of gaps larger than max_gap: 86 | gap_mask = pd.TimedeltaIndex(np.diff(datetimes)) > np.timedelta64(max_gap_duration) 87 | gap_indices = np.argwhere(gap_mask)[:, 0] 88 | 89 | # gap_indicies are the indices into dt_index for the timestep immediately before the gap. 90 | # e.g. if the datetimes at 12:00, 12:05, 18:00, 18:05 then gap_indicies will be [1]. 91 | # So we add 1 to gap_indices to get segment_boundaries, an index into dt_index 92 | # which identifies the _start_ of each segment. 93 | segment_boundaries = gap_indices + 1 94 | 95 | # Capture the last segment of dt_index. 
96 | segment_boundaries = np.append(segment_boundaries, len(datetimes)) 97 | 98 | periods: list[dict[str, pd.Timestamp]] = [] 99 | start_i = 0 100 | for next_start_i in segment_boundaries: 101 | n_timesteps = next_start_i - start_i 102 | if n_timesteps > min_seq_length: 103 | end_i = next_start_i - 1 104 | period = {"start_dt": datetimes[start_i], "end_dt": datetimes[end_i]} 105 | periods.append(period) 106 | start_i = next_start_i 107 | 108 | assert len(periods) > 0, ( 109 | f"Did not find an periods from {datetimes}. {min_seq_length=} {max_gap_duration=}" 110 | ) 111 | 112 | return pd.DataFrame(periods) 113 | 114 | 115 | def find_contiguous_t0_time_periods( 116 | contiguous_time_periods: pd.DataFrame, history_duration: timedelta, forecast_duration: timedelta 117 | ) -> pd.DataFrame: 118 | """Get all time periods which contain valid t0 datetimes. 119 | `t0` is the datetime of the most recent observation. 120 | 121 | Args: 122 | contiguous_time_periods (pd.DataFrame): Dataframe of continguous time periods 123 | history_duration (timedelta): Duration of the history 124 | forecast_duration (timedelta): Duration of the forecast 125 | 126 | Returns: 127 | pd.DataFrame: A DataFrame with two columns `start_dt` and `end_dt` 128 | (where 'dt' is short for 'datetime'). Each row represents a single time period. 129 | """ 130 | contiguous_time_periods["start_dt"] += np.timedelta64(history_duration) 131 | contiguous_time_periods["end_dt"] -= np.timedelta64(forecast_duration) 132 | assert (contiguous_time_periods["start_dt"] < contiguous_time_periods["end_dt"]).all() 133 | return contiguous_time_periods 134 | 135 | 136 | def numpy_validation_collate_fn( 137 | samples: list[tuple[SampleInputArray, SampleOutputArray]], 138 | ) -> tuple[BatchInputArray, BatchOutputArray]: 139 | """Collate a list of data + targets into a batch. 140 | 141 | Args: 142 | samples (list[tuple[SampleInputArray, SampleOutputArray]]): List of (X, y) samples, 143 | with sizes of X (batch, channels, time, height, width) and 144 | y (batch, channels, rollout_steps, height, width) 145 | 146 | Returns: 147 | tuple(np.ndarray, np.ndarray): The collated batch of X samples in the form 148 | (batch, channels, time, height, width) and the collated batch of y samples 149 | in the form (batch, channels, rollout_steps, height, width) 150 | """ 151 | 152 | # Create empty stores for the compiled batch 153 | X_all = np.empty((len(samples), *samples[0][0].shape), dtype=np.float32) 154 | y_all = np.empty((len(samples), *samples[0][1].shape), dtype=np.float32) 155 | 156 | # Fill the stores with the samples 157 | for i, (X, y) in enumerate(samples): 158 | X_all[i] = X 159 | y_all[i] = y 160 | return X_all, y_all 161 | 162 | 163 | def create_cutout_mask( 164 | mask_size: tuple[int, int, int, int], 165 | image_size: tuple[int, int], 166 | ) -> NDArray[np.float32]: 167 | """Create a mask with a cutout in the center. 
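
    The returned mask is NaN everywhere except inside the cutout region, which is
    set to 1, so multiplying an image by the mask keeps only the cutout pixels.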
168 | 169 | Args: 170 | mask_size: Size of the cutout 171 | image_size: Size of the image 172 | 173 | Returns: 174 | np.ndarray: The mask 175 | """ 176 | height, width = image_size 177 | min_x, max_x, min_y, max_y = mask_size 178 | 179 | mask = np.empty((height, width), dtype=np.float32) 180 | mask[:] = np.nan 181 | mask[min_y:max_y, min_x:max_x] = 1 182 | return mask 183 | -------------------------------------------------------------------------------- /src/cloudcasting/validation.py: -------------------------------------------------------------------------------- 1 | __all__ = ("validate", "validate_from_config") 2 | 3 | import importlib.util 4 | import inspect 5 | import logging 6 | import os 7 | import sys 8 | from collections.abc import Callable 9 | from functools import partial 10 | from typing import Annotated, Any, cast 11 | 12 | import jax.numpy as jnp 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | import numpy.typing as npt 16 | import typer 17 | import wandb 18 | import yaml 19 | from jax import tree 20 | from jaxtyping import Array, Float32 21 | from matplotlib.colors import Normalize 22 | from torch.utils.data import DataLoader 23 | from tqdm import tqdm 24 | 25 | try: 26 | import torch.multiprocessing as mp 27 | 28 | mp.set_start_method("spawn", force=True) 29 | except RuntimeError: 30 | pass 31 | 32 | import cloudcasting 33 | from cloudcasting import metrics as dm_pix # for compatibility if our changes are upstreamed 34 | from cloudcasting.constants import ( 35 | CROPPED_CUTOUT_MASK, 36 | CROPPED_CUTOUT_MASK_BOUNDARY, 37 | CROPPED_IMAGE_SIZE_TUPLE, 38 | CUTOUT_MASK, 39 | CUTOUT_MASK_BOUNDARY, 40 | DATA_INTERVAL_SPACING_MINUTES, 41 | FORECAST_HORIZON_MINUTES, 42 | IMAGE_SIZE_TUPLE, 43 | NUM_CHANNELS, 44 | ) 45 | from cloudcasting.dataset import ValidationSatelliteDataset 46 | from cloudcasting.models import AbstractModel 47 | from cloudcasting.types import ( 48 | BatchOutputArrayJAX, 49 | ChannelArray, 50 | MetricArray, 51 | SampleOutputArray, 52 | TimeArray, 53 | ) 54 | from cloudcasting.utils import numpy_validation_collate_fn 55 | 56 | logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") 57 | logger = logging.getLogger(__name__) 58 | 59 | # defined in manchester prize technical document 60 | WANDB_ENTITY = "manchester_prize" 61 | VIDEO_SAMPLE_DATES = [ 62 | "2022-01-17 11:00", 63 | "2022-04-11 06:00", 64 | "2022-06-10 11:00", 65 | "2022-09-30 18:15", 66 | ] 67 | VIDEO_SAMPLE_CHANNELS = ["VIS008", "IR_087"] 68 | 69 | 70 | def log_mean_metrics_to_wandb( 71 | metric_value: float, 72 | plot_name: str, 73 | metric_name: str, 74 | ) -> None: 75 | """Upload a bar chart of mean metric value to wandb 76 | 77 | Args: 78 | metric_value: The mean metric value to upload 79 | plot_name: The name under which to save the plot to wandb 80 | metric_name: The name of the metric used to label the y-axis in the uploaded plot 81 | """ 82 | data = [[metric_name, metric_value]] 83 | table = wandb.Table(data=data, columns=["metric name", "value"]) 84 | wandb.log({plot_name: wandb.plot.bar(table, "metric name", "value", title=plot_name)}) 85 | 86 | 87 | def log_per_horizon_metrics_to_wandb( 88 | horizon_mins: TimeArray, 89 | metric_values: TimeArray, 90 | plot_name: str, 91 | metric_name: str, 92 | ) -> None: 93 | """Upload a plot of metric value vs forecast horizon to wandb 94 | 95 | Args: 96 | horizon_mins: Array of the number of minutes after the init time for each predicted frame 97 | of satellite data 98 | metric_values: Array of the 
mean metric value at each forecast horizon 99 | plot_name: The name under which to save the plot to wandb 100 | metric_name: The name of the metric used to label the y-axis in the uploaded plot 101 | """ 102 | data = list(zip(horizon_mins, metric_values, strict=True)) 103 | table = wandb.Table(data=data, columns=["horizon_mins", metric_name]) 104 | wandb.log({plot_name: wandb.plot.line(table, "horizon_mins", metric_name, title=plot_name)}) 105 | 106 | 107 | def log_per_channel_metrics_to_wandb( 108 | channel_names: list[str], 109 | metric_values: ChannelArray, 110 | plot_name: str, 111 | metric_name: str, 112 | ) -> None: 113 | """Upload a bar chart of metric value for each channel to wandb 114 | 115 | Args: 116 | channel_names: List of channel names for ordering purposes 117 | metric_values: Array of the mean metric value for each channel 118 | plot_name: The name under which to save the plot to wandb 119 | metric_name: The name of the metric used to label the y-axis in the uploaded plot 120 | """ 121 | data = list(zip(channel_names, metric_values, strict=True)) 122 | table = wandb.Table(data=data, columns=["channel name", metric_name]) 123 | wandb.log({plot_name: wandb.plot.bar(table, "channel name", metric_name, title=plot_name)}) 124 | 125 | 126 | def log_prediction_video_to_wandb( 127 | y_hat: SampleOutputArray, 128 | y: SampleOutputArray, 129 | video_name: str, 130 | channel_ind: int = 8, 131 | fps: int = 1, 132 | ) -> None: 133 | """Upload a video comparing the true and predicted future satellite data to wandb 134 | 135 | Args: 136 | y_hat: The predicted sequence of satellite data 137 | y: The true sequence of satellite data 138 | video_name: The name under which to save the video to wandb 139 | channel_ind: The channel number to show in the video 140 | fps: Frames per second of the resulting video 141 | """ 142 | 143 | # Copy the arrays so we don't modify the original 144 | y_hat = y_hat.copy() 145 | y = y.copy() 146 | 147 | # Find NaNs (or infilled NaNs) in ground truth 148 | mask = np.isnan(y) | (y == -1) 149 | 150 | # Set pixels which are NaN in the ground truth to 0 in both arrays 151 | y[mask] = 0 152 | y_hat[mask] = 0 153 | 154 | create_box = True 155 | 156 | # create a boundary box for the crop 157 | if y.shape[-2:] == IMAGE_SIZE_TUPLE: 158 | boxl, boxr, boxb, boxt = CUTOUT_MASK_BOUNDARY 159 | bsize = IMAGE_SIZE_TUPLE 160 | elif y.shape[-2:] == CROPPED_IMAGE_SIZE_TUPLE: 161 | boxl, boxr, boxb, boxt = CROPPED_CUTOUT_MASK_BOUNDARY 162 | bsize = CROPPED_IMAGE_SIZE_TUPLE 163 | else: 164 | create_box = False 165 | 166 | if create_box: 167 | # box mask 168 | _maskb = np.ones(bsize, dtype=np.float32) 169 | _maskb[boxb : boxb + 2, boxl:boxr] = np.nan # Top edge 170 | _maskb[boxt - 2 : boxt, boxl:boxr] = np.nan # Bottom edge 171 | _maskb[boxb:boxt, boxl : boxl + 2] = np.nan # Left edge 172 | _maskb[boxb:boxt, boxr - 2 : boxr] = np.nan # Right edge 173 | maskb: Float32[npt.NDArray[np.float32], "1 1 a b"] = _maskb[np.newaxis, np.newaxis, :, :] 174 | 175 | y = y * maskb 176 | y_hat = y_hat * maskb 177 | 178 | # Tranpose the arrays so time is the first dimension and select the channel 179 | # Then flip the frames so they are in the correct orientation for the video 180 | y_frames = y.transpose(1, 2, 3, 0)[:, ::-1, ::-1, channel_ind : channel_ind + 1] 181 | y_hat_frames = y_hat.transpose(1, 2, 3, 0)[:, ::-1, ::-1, channel_ind : channel_ind + 1] 182 | 183 | # Concatenate the predicted and true frames so they are displayed side by side 184 | video_array = np.concatenate([y_hat_frames, 
y_frames], axis=2) 185 | 186 | # Clip the values and rescale to be between 0 and 255 and repeat for RGB channels 187 | video_array = video_array.clip(0, 1) 188 | video_array = np.repeat(video_array, 3, axis=3) * 255 189 | # add Alpha channel 190 | video_array = np.concatenate( 191 | [video_array, np.full((*video_array[:, :, :, 0].shape, 1), 255)], axis=3 192 | ) 193 | 194 | # calculate the difference between the prediction and the ground truth and add colour 195 | y_diff_frames = y_hat_frames - y_frames 196 | diff_ccmap = plt.get_cmap("bwr")(Normalize(vmin=-1, vmax=1)(y_diff_frames[:, :, :, 0])) 197 | diff_ccmap = diff_ccmap * 255 198 | 199 | # combine add difference to the video array 200 | video_array = np.concatenate([video_array, diff_ccmap], axis=2) 201 | 202 | # Set bounding box to a colour so it is visible 203 | if create_box: 204 | video_array[:, :, :, 0][np.isnan(video_array[:, :, :, 0])] = 250 205 | video_array[:, :, :, 1][np.isnan(video_array[:, :, :, 1])] = 40 206 | video_array[:, :, :, 2][np.isnan(video_array[:, :, :, 2])] = 10 207 | video_array[:, :, :, 3][np.isnan(video_array[:, :, :, 3])] = 255 208 | 209 | video_array = video_array.transpose(0, 3, 1, 2) 210 | video_array = video_array.astype(np.uint8) 211 | 212 | wandb.log( 213 | { 214 | video_name: wandb.Video( 215 | video_array, 216 | caption="Sample prediction (left), ground truth (middle), difference (right)", 217 | fps=fps, 218 | ) 219 | } 220 | ) 221 | 222 | 223 | def score_model_on_all_metrics( 224 | model: AbstractModel, 225 | valid_dataset: ValidationSatelliteDataset, 226 | batch_size: int = 1, 227 | num_workers: int = 0, 228 | batch_limit: int | None = None, 229 | metric_names: tuple[str, ...] | list[str] = ("mae", "mse", "ssim"), 230 | metric_kwargs: dict[str, dict[str, Any]] | None = None, 231 | ) -> tuple[dict[str, MetricArray], list[str]]: 232 | """Calculate the scoreboard metrics for the given model on the validation dataset. 233 | 234 | Args: 235 | model (AbstractModel): The model to score. 236 | valid_dataset (ValidationSatelliteDataset): The validation dataset to score the model on. 237 | batch_size (int, optional): Defaults to 1. 238 | num_workers (int, optional): Defaults to 0. 239 | batch_limit (int | None, optional): Defaults to None. Stop after this many batches. 240 | For testing purposes only. 241 | metric_names (tuple[str, ...] | list[str]: Names of metrics to calculate. Need to be defined 242 | in cloudcasting.metrics. Defaults to ("mae", "mse", "ssim"). 243 | metric_kwargs (dict[str, dict[str, Any]] | None, optional): kwargs to pass to functions in 244 | cloudcasting.metrics. Defaults to None. 245 | 246 | Returns: 247 | tuple[dict[str, MetricArray], list[str]]: 248 | Element 0: keys are metric names, values are arrays of results 249 | averaged over all dims except horizon and channel. 250 | Element 1: list of channel names. 
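
    Example:
        A minimal sketch (assumes `model` and `valid_dataset` have already been
        constructed; the argument values are illustrative only)::

            metrics_dict, channel_names = score_model_on_all_metrics(
                model, valid_dataset, batch_size=2, metric_names=("mae", "ssim")
            )
            metrics_dict["mae"].shape  # -> (num_channels, num_forecast_steps)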
251 | """ 252 | 253 | # check the the data spacing perfectly divides the forecast horizon 254 | assert FORECAST_HORIZON_MINUTES % DATA_INTERVAL_SPACING_MINUTES == 0, ( 255 | "forecast horizon must be a multiple of the data interval " 256 | "(please make an issue on github if you see this!!!!)" 257 | ) 258 | 259 | valid_dataloader = DataLoader( 260 | valid_dataset, 261 | batch_size=batch_size, 262 | num_workers=num_workers, 263 | shuffle=False, 264 | collate_fn=numpy_validation_collate_fn, 265 | drop_last=False, 266 | ) 267 | 268 | if metric_kwargs is None: 269 | metric_kwargs_dict: dict[str, dict[str, Any]] = {} 270 | else: 271 | metric_kwargs_dict = metric_kwargs 272 | 273 | def get_pix_function( 274 | name: str, 275 | pix_kwargs: dict[str, dict[str, Any]], 276 | ) -> Callable[ 277 | [BatchOutputArrayJAX, BatchOutputArrayJAX], Float32[Array, "batch channels time"] 278 | ]: 279 | func = getattr(dm_pix, name) 280 | sig = inspect.signature(func) 281 | if "ignore_nans" in sig.parameters: 282 | func = partial(func, ignore_nans=True) 283 | if name in pix_kwargs: 284 | func = partial(func, **pix_kwargs[name]) 285 | return cast( 286 | Callable[ 287 | [BatchOutputArrayJAX, BatchOutputArrayJAX], Float32[Array, "batch channels time"] 288 | ], 289 | func, 290 | ) 291 | 292 | metric_funcs: dict[ 293 | str, 294 | Callable[[BatchOutputArrayJAX, BatchOutputArrayJAX], Float32[Array, "batch channels time"]], 295 | ] = {name: get_pix_function(name, metric_kwargs_dict) for name in metric_names} 296 | 297 | metrics: dict[str, list[Float32[Array, "batch channels time"]]] = { 298 | metric: [] for metric in metric_funcs 299 | } 300 | 301 | # we probably want to accumulate metrics here instead of taking the mean of means! 302 | loop_steps = len(valid_dataloader) if batch_limit is None else batch_limit 303 | 304 | info_str = f"Validating model on {loop_steps} batches..." 305 | logger.info(info_str) 306 | 307 | for i, (X, y) in tqdm(enumerate(valid_dataloader), total=loop_steps): 308 | y_hat = model(X) 309 | 310 | # identify the correct mask / create a mask if necessary 311 | if X.shape[-2:] == IMAGE_SIZE_TUPLE: 312 | mask = CUTOUT_MASK 313 | elif X.shape[-2:] == CROPPED_IMAGE_SIZE_TUPLE: 314 | mask = CROPPED_CUTOUT_MASK 315 | else: 316 | mask = np.ones(X.shape[-2:], dtype=np.float32) 317 | 318 | # cutout the GB area 319 | mask_full = mask[np.newaxis, np.newaxis, np.newaxis, :, :] 320 | y_cutout = y * mask_full 321 | y_hat = y_hat * mask_full 322 | 323 | # assert shapes are the same 324 | assert y_cutout.shape == y_hat.shape, f"{y_cutout.shape=} != {y_hat.shape=}" 325 | 326 | # If nan_to_num is used in the dataset, the model will output -1 for NaNs. We need to 327 | # convert these back to NaNs for the metrics 328 | y_cutout[y_cutout == -1] = np.nan 329 | 330 | # pix accepts arrays of shape [batch, height, width, channels]. 331 | # our arrays are of shape [batch, channels, time, height, width]. 332 | # channel dim would be reduced; we add a new axis to satisfy the shape reqs. 333 | # we then reshape to squash batch, channels, and time into the leading axis, 334 | # where the vmap in metrics.py will broadcast over the leading dim. 
335 | y_jax = jnp.array(y_cutout).reshape(-1, *y_cutout.shape[-2:])[..., np.newaxis] 336 | y_hat_jax = jnp.array(y_hat).reshape(-1, *y_hat.shape[-2:])[..., np.newaxis] 337 | 338 | for metric_name, metric_func in metric_funcs.items(): 339 | # we reshape the result back into [batch, channels, time], 340 | # then take the mean over the batch 341 | metric_res = metric_func(y_hat_jax, y_jax).reshape(*y_cutout.shape[:3]) 342 | batch_reduced_metric = jnp.nanmean(metric_res, axis=0) 343 | metrics[metric_name].append(batch_reduced_metric) 344 | 345 | if batch_limit is not None and i >= batch_limit: 346 | break 347 | # convert back to numpy and reduce over all batches 348 | res = tree.map( 349 | lambda x: np.mean(np.array(x), axis=0), metrics, is_leaf=lambda x: isinstance(x, list) 350 | ) 351 | 352 | num_timesteps = FORECAST_HORIZON_MINUTES // DATA_INTERVAL_SPACING_MINUTES 353 | 354 | channel_names = valid_dataset.ds.variable.values.tolist() 355 | 356 | # technically, if we've made a mistake in the shapes/reduction dim, this would silently fail 357 | # if the number of batches equals the number of timesteps 358 | for v in res.values(): 359 | msg = ( 360 | f"metric {v.shape} is not the correct shape " 361 | f"(should be {(len(channel_names), num_timesteps)})" 362 | ) 363 | assert v.shape == (len(channel_names), num_timesteps), msg 364 | 365 | return res, channel_names 366 | 367 | 368 | def calc_mean_metrics(metrics_dict: dict[str, MetricArray]) -> dict[str, float]: 369 | """Calculate the mean metric reduced over all dimensions. 370 | 371 | Args: 372 | metrics_dict: dict mapping metric names to arrays of metric values 373 | 374 | Returns: 375 | dict: dict mapping metric names to mean metric values 376 | """ 377 | return {k: float(np.mean(v)) for k, v in metrics_dict.items()} 378 | 379 | 380 | def calc_mean_metrics_per_horizon(metrics_dict: dict[str, MetricArray]) -> dict[str, TimeArray]: 381 | """Calculate the mean of each metric for each forecast horizon. 382 | 383 | Args: 384 | metrics_dict: dict mapping metric names to arrays of metric values 385 | 386 | Returns: 387 | dict: dict mapping metric names to array of mean metric values for each horizon 388 | """ 389 | return {k: np.mean(v, axis=0) for k, v in metrics_dict.items()} 390 | 391 | 392 | def calc_mean_metrics_per_channel(metrics_dict: dict[str, MetricArray]) -> dict[str, ChannelArray]: 393 | """Calculate the mean of each metric for each channel. 394 | 395 | Args: 396 | metrics_dict: dict mapping metric names to arrays of metric values 397 | 398 | Returns: 399 | dict: dict mapping metric names to array of mean metric values for each channel 400 | """ 401 | return {k: np.mean(v, axis=1) for k, v in metrics_dict.items()} 402 | 403 | 404 | def validate( 405 | model: AbstractModel, 406 | data_path: list[str] | str, 407 | wandb_project_name: str, 408 | wandb_run_name: str, 409 | nan_to_num: bool = False, 410 | batch_size: int = 1, 411 | num_workers: int = 0, 412 | batch_limit: int | None = None, 413 | ) -> None: 414 | """Run the full validation procedure on the model and log the results to wandb. 415 | 416 | Args: 417 | model (AbstractModel): the model to be validated 418 | data_path (str): path to the validation data set 419 | nan_to_num (bool, optional): Whether to convert NaNs to -1. Defaults to False. 420 | batch_size (int, optional): Defaults to 1. 421 | num_workers (int, optional): Defaults to 0. 422 | batch_limit (int | None, optional): Defaults to None. For testing purposes only. 
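        wandb_project_name (str): The wandb project to which the results will be logged
        wandb_run_name (str): The name of the wandb run to create

    Example:
        A minimal sketch (the model object, data path, and wandb names below are
        illustrative only)::

            validate(
                model=my_model,
                data_path="path/to/2022_test_nonhrv.zarr",
                wandb_project_name="cloudcasting",
                wandb_run_name="my-model-validation",
            )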
423 | """ 424 | 425 | # Verify we can run the model forward 426 | try: 427 | model(np.zeros((1, NUM_CHANNELS, model.history_steps, *IMAGE_SIZE_TUPLE), dtype=np.float32)) 428 | except Exception as err: 429 | msg = f"Failed to run the model forward due to the following error: {err}" 430 | raise ValueError(msg) from err 431 | 432 | # grab api key from environment variable 433 | wandb_api_key = os.environ.get("WANDB_API_KEY") 434 | 435 | if not wandb_api_key: 436 | msg = "WANDB_API_KEY environment variable not set. Attempting interactive login..." 437 | logger.warning(msg) 438 | wandb.login() 439 | else: 440 | logger.info("API key found. Logging in to WandB...") 441 | wandb.login(key=wandb_api_key) 442 | 443 | # Set up the validation dataset 444 | valid_dataset = ValidationSatelliteDataset( 445 | zarr_path=data_path, 446 | history_mins=(model.history_steps - 1) * DATA_INTERVAL_SPACING_MINUTES, 447 | forecast_mins=FORECAST_HORIZON_MINUTES, 448 | sample_freq_mins=DATA_INTERVAL_SPACING_MINUTES, 449 | nan_to_num=nan_to_num, 450 | ) 451 | 452 | # Calculate the metrics before logging to wandb 453 | channel_horizon_metrics_dict, channel_names = score_model_on_all_metrics( 454 | model, 455 | valid_dataset, 456 | batch_size=batch_size, 457 | num_workers=num_workers, 458 | batch_limit=batch_limit, 459 | ) 460 | 461 | # Calculate the mean of each metric reduced over forecast horizon and channels 462 | mean_metrics_dict = calc_mean_metrics(channel_horizon_metrics_dict) 463 | 464 | # Calculate the mean of each metric for each forecast horizon 465 | horizon_metrics_dict = calc_mean_metrics_per_horizon(channel_horizon_metrics_dict) 466 | 467 | # Calculate the mean of each metric for each channel 468 | channel_metrics_dict = calc_mean_metrics_per_channel(channel_horizon_metrics_dict) 469 | 470 | # Append to the wandb run name if we are limiting the number of batches 471 | if batch_limit is not None: 472 | wandb_run_name = wandb_run_name + f"-limited-to-{batch_limit}batches" 473 | 474 | # Start a wandb run 475 | wandb.init( 476 | project=wandb_project_name, 477 | name=wandb_run_name, 478 | entity=WANDB_ENTITY, 479 | ) 480 | 481 | # Add the cloudcasting version to the wandb config 482 | wandb.config.update( 483 | { 484 | "cloudcast_version": cloudcasting.__version__, 485 | "batch_limit": batch_limit, 486 | } 487 | ) 488 | 489 | # Log the model hyperparameters to wandb 490 | wandb.config.update(model.hyperparameters_dict()) 491 | 492 | # Log plot of the horizon metrics to wandb 493 | horizon_mins = np.arange( 494 | start=DATA_INTERVAL_SPACING_MINUTES, 495 | stop=FORECAST_HORIZON_MINUTES + DATA_INTERVAL_SPACING_MINUTES, 496 | step=DATA_INTERVAL_SPACING_MINUTES, 497 | dtype=np.float32, 498 | ) 499 | 500 | # Log the mean metrics to wandb 501 | for metric_name, value in mean_metrics_dict.items(): 502 | log_mean_metrics_to_wandb( 503 | metric_value=value, 504 | plot_name=f"{metric_name}-mean", 505 | metric_name=metric_name, 506 | ) 507 | 508 | for metric_name, horizon_array in horizon_metrics_dict.items(): 509 | log_per_horizon_metrics_to_wandb( 510 | horizon_mins=horizon_mins, 511 | metric_values=horizon_array, 512 | plot_name=f"{metric_name}-horizon", 513 | metric_name=metric_name, 514 | ) 515 | 516 | # Log mean metrics per-channel 517 | for metric_name, channel_array in channel_metrics_dict.items(): 518 | log_per_channel_metrics_to_wandb( 519 | channel_names=channel_names, 520 | metric_values=channel_array, 521 | plot_name=f"{metric_name}-channel", 522 | metric_name=metric_name, 523 | ) 524 | 525 | # Log selected 
video samples to wandb 526 | channel_inds = valid_dataset.ds.get_index("variable").get_indexer(VIDEO_SAMPLE_CHANNELS) 527 | 528 | for date in VIDEO_SAMPLE_DATES: 529 | X, y = valid_dataset[date] 530 | 531 | # Expand dimensions to batch size of 1 for model then contract to sample 532 | y_hat = model(X[None, ...])[0] 533 | 534 | for channel_ind, channel_name in zip(channel_inds, VIDEO_SAMPLE_CHANNELS, strict=False): 535 | log_prediction_video_to_wandb( 536 | y_hat=y_hat, 537 | y=y, 538 | video_name=f"sample_videos/{date} - {channel_name}", 539 | channel_ind=int(channel_ind), 540 | fps=1, 541 | ) 542 | 543 | 544 | def validate_from_config( 545 | config_file: Annotated[ 546 | str, typer.Option(help="Path to config file. Defaults to 'validate_config.yml'.") 547 | ] = "validate_config.yml", 548 | model_file: Annotated[ 549 | str, typer.Option(help="Path to Python file with model definition. Defaults to 'model.py'.") 550 | ] = "model.py", 551 | ) -> None: 552 | """CLI function to validate a model from a config file. Example templates of these files can 553 | be found at https://github.com/alan-turing-institute/ocf-model-template. 554 | 555 | Args: 556 | config_file (str): Path to config file. Defaults to "validate_config.yml". 557 | model_file (str): Path to Python file with model definition. Defaults to "model.py". 558 | """ 559 | with open(config_file) as f: 560 | config: dict[str, Any] = yaml.safe_load(f) 561 | 562 | # import the model definition from file 563 | spec = importlib.util.spec_from_file_location("usermodel", model_file) 564 | # type narrowing 565 | if spec is None or spec.loader is None: 566 | msg = f"Error importing {model_file}" 567 | raise ImportError(msg) 568 | module = importlib.util.module_from_spec(spec) 569 | sys.modules["usermodel"] = module 570 | spec.loader.exec_module(module) 571 | 572 | ModelClass = getattr(module, config["model"]["name"]) 573 | model = ModelClass(**(config["model"]["params"] or {})) 574 | 575 | validate(model, **config["validation"]) 576 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | import xarray as xr 7 | 8 | from cloudcasting.models import VariableHorizonModel 9 | 10 | xr.set_options(keep_attrs=True) # type: ignore[no-untyped-call] 11 | 12 | 13 | @pytest.fixture 14 | def temp_output_dir(tmp_path): 15 | return str(tmp_path) 16 | 17 | 18 | @pytest.fixture 19 | def sat_zarr_path(temp_output_dir): 20 | # Load dataset which only contains coordinates, but no data 21 | ds = xr.load_dataset( 22 | f"{os.path.dirname(os.path.abspath(__file__))}/test_data/non_hrv_shell.netcdf" 23 | ) 24 | 25 | # Add time coord 26 | ds = ds.assign_coords(time=pd.date_range("2023-01-01 00:00", "2023-01-02 23:55", freq="5min")) 27 | 28 | # Add data to dataset 29 | ds["data"] = xr.DataArray( 30 | np.zeros([len(ds[c]) for c in ds.coords], dtype=np.float32), 31 | coords=ds.coords, 32 | ) 33 | 34 | # Transpose to variables, time, y, x (just in case) 35 | ds = ds.transpose("variable", "time", "y_geostationary", "x_geostationary") 36 | 37 | # Add some NaNs 38 | ds["data"].values[:, :, 0, 0] = np.nan 39 | 40 | # Specifiy chunking 41 | ds = ds.chunk({"time": 10, "variable": -1, "y_geostationary": -1, "x_geostationary": -1}) 42 | 43 | # Save temporarily as a zarr 44 | zarr_path = f"{temp_output_dir}/test_sat.zarr" 45 | ds.to_zarr(zarr_path) 46 | 47 | return zarr_path 
48 | 49 | 50 | @pytest.fixture 51 | def val_dataset_hyperparams(): 52 | return { 53 | "x_geostationary_size": 8, 54 | "y_geostationary_size": 9, 55 | } 56 | 57 | 58 | @pytest.fixture 59 | def val_sat_zarr_path(temp_output_dir, val_dataset_hyperparams): 60 | # The validation set requires a much larger set of times so we create it separately 61 | # Load dataset which only contains coordinates, but no data 62 | ds = xr.load_dataset( 63 | f"{os.path.dirname(os.path.abspath(__file__))}/test_data/non_hrv_shell.netcdf" 64 | ) 65 | 66 | # Make the dataset spatially small 67 | ds = ds.isel( 68 | x_geostationary=slice(0, val_dataset_hyperparams["x_geostationary_size"]), 69 | y_geostationary=slice(0, val_dataset_hyperparams["y_geostationary_size"]), 70 | ) 71 | 72 | # Add time coord 73 | ds = ds.assign_coords(time=pd.date_range("2022-01-01 00:00", "2022-12-31 23:45", freq="15min")) 74 | 75 | # Add data to dataset 76 | ds["data"] = xr.DataArray( 77 | np.zeros([len(ds[c]) for c in ds.coords], dtype=np.float32), 78 | coords=ds.coords, 79 | ) 80 | 81 | # Transpose to variables, time, y, x (just in case) 82 | ds = ds.transpose("variable", "time", "y_geostationary", "x_geostationary") 83 | 84 | # Add some NaNs 85 | ds["data"].values[:, :, 0, 0] = np.nan 86 | 87 | # Specifiy chunking 88 | ds = ds.chunk({"time": 10, "variable": -1, "y_geostationary": -1, "x_geostationary": -1}) 89 | 90 | # Save temporarily as a zarr 91 | zarr_path = f"{temp_output_dir}/val_test_sat.zarr" 92 | ds.to_zarr(zarr_path) 93 | 94 | return zarr_path 95 | 96 | 97 | class PersistenceModel(VariableHorizonModel): 98 | """A persistence model used solely for testing the validation procedure""" 99 | 100 | def forward(self, X): 101 | # Grab the most recent frame from the input data 102 | # There may be NaNs in the input data, so we need to handle these 103 | latest_frame = np.nan_to_num(X[..., -1:, :, :], nan=0.0, copy=True) 104 | 105 | # The NaN values in the input data could be filled with -1. Clip these to zero 106 | latest_frame = latest_frame.clip(0, 1) 107 | 108 | return np.repeat(latest_frame, self.rollout_steps, axis=-3) 109 | 110 | def hyperparameters_dict(self): 111 | return {"history_steps": self.history_steps} 112 | -------------------------------------------------------------------------------- /tests/legacy_metrics.py: -------------------------------------------------------------------------------- 1 | """Metrics for model output evaluation""" 2 | 3 | import numpy as np 4 | import numpy.typing as npt 5 | from jaxtyping import Float as Float32 6 | from skimage.metrics import structural_similarity # type: ignore[import-not-found] 7 | 8 | # Type aliases for clarity + reuse 9 | Array = npt.NDArray[np.float32] # the type arg is ignored by jaxtyping, but is here for clarity 10 | TimeArray = Float32[Array, "time"] 11 | MetricArray = Float32[Array, "channels time"] 12 | ChannelArray = Float32[Array, "channels"] 13 | 14 | SampleInputArray = Float32[Array, "channels time height width"] 15 | BatchInputArray = Float32[Array, "batch channels time height width"] 16 | InputArray = SampleInputArray | BatchInputArray 17 | 18 | 19 | SampleOutputArray = Float32[Array, "channels rollout_steps height width"] 20 | BatchOutputArray = Float32[Array, "batch channels rollout_steps height width"] 21 | 22 | OutputArray = SampleOutputArray | BatchOutputArray 23 | 24 | 25 | def mae_single(input: SampleOutputArray, target: SampleOutputArray) -> MetricArray: 26 | """Mean absolute error for single (non-batched) image sequences. 
27 | 
28 |     Args:
29 |         input: Array of shape [channels, time, height, width]
30 |         target: Array of shape [channels, time, height, width]
31 | 
32 |     Returns:
33 |         Array of MAE values of shape [channel, time]
34 |     """
35 |     absolute_error = np.abs(input - target)
36 |     arr: MetricArray = np.nanmean(absolute_error, axis=(2, 3))
37 |     return arr
38 | 
39 | 
40 | def mae_batch(input: BatchOutputArray, target: BatchOutputArray) -> MetricArray:
41 |     """Mean absolute error for batched image sequences.
42 | 
43 |     Args:
44 |         input: Array of shape [batch, channels, time, height, width]
45 |         target: Array of shape [batch, channels, time, height, width]
46 | 
47 |     Returns:
48 |         Array of MAE values of shape [channel, time]
49 |     """
50 |     absolute_error = np.abs(input - target)
51 |     arr: MetricArray = np.nanmean(absolute_error, axis=(0, 3, 4))
52 |     return arr
53 | 
54 | 
55 | def mse_single(input: SampleOutputArray, target: SampleOutputArray) -> MetricArray:
56 |     """Mean squared error for single (non-batched) image sequences.
57 | 
58 |     Args:
59 |         input: Array of shape [channels, time, height, width]
60 |         target: Array of shape [channels, time, height, width]
61 | 
62 |     Returns:
63 |         Array of MSE values of shape [channel, time]
64 |     """
65 |     square_error = (input - target) ** 2
66 |     arr: MetricArray = np.nanmean(square_error, axis=(2, 3))
67 |     return arr
68 | 
69 | 
70 | def mse_batch(input: BatchOutputArray, target: BatchOutputArray) -> MetricArray:
71 |     """Mean squared error for batched image sequences.
72 | 
73 |     Args:
74 |         input: Array of shape [batch, channels, time, height, width]
75 |         target: Array of shape [batch, channels, time, height, width]
76 | 
77 |     Returns:
78 |         Array of MSE values of shape [channel, time]
79 |     """
80 |     square_error = (input - target) ** 2
81 |     arr: MetricArray = np.nanmean(square_error, axis=(0, 3, 4))
82 |     return arr
83 | 
84 | 
85 | def ssim_single(input: SampleOutputArray, target: SampleOutputArray) -> MetricArray:
86 |     """Computes the Structural Similarity (SSIM) index for single (non-batched) image sequences.
87 | 
88 |     Args:
89 |         input: Array of shape [channels, time, height, width]
90 |         target: Array of shape [channels, time, height, width]
91 | 
92 |     Returns:
93 |         Array of SSIM values of shape [channel, time]
94 | 
95 |     References:
96 |         Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004).
97 |         Image quality assessment: From error visibility to structural similarity.
98 |         IEEE Transactions on Image Processing, 13, 600-612.
99 |         https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
100 |         DOI: 10.1109/TIP.2003.819861
101 |     """
102 | 
103 |     # This function assumes the data will be in the range 0-1 and will give invalid results if not
104 |     _check_input_target_ranges(input, target)
105 | 
106 |     # The following parameter settings match Wang et al. 2004
107 |     gaussian_weights = True
108 |     use_sample_covariance = False
109 |     sigma = 1.5
110 |     win_size = 11
111 | 
112 |     ssim_seq = []
113 |     for i_t in range(input.shape[1]):
114 |         _, ssim_array = structural_similarity(
115 |             input[:, i_t],
116 |             target[:, i_t],
117 |             data_range=1,
118 |             channel_axis=0,
119 |             full=True,
120 |             gaussian_weights=gaussian_weights,
121 |             use_sample_covariance=use_sample_covariance,
122 |             sigma=sigma,
123 |             win_size=win_size,
124 |         )
125 | 
126 |         # To avoid edge effects from the Gaussian filter we trim off the border
127 |         trim_width = (win_size - 1) // 2
128 |         ssim_array = ssim_array[:, trim_width:-trim_width, trim_width:-trim_width]
129 |         # Take the mean of the SSIM array over height and width for each channel
130 |         ssim_seq.append(np.nanmean(ssim_array, axis=(1, 2)))
131 |     # Stack along the time dimension to give shape [channels, time]
132 |     arr: MetricArray = np.stack(ssim_seq, axis=1)
133 |     return arr
134 | 
135 | 
136 | def ssim_batch(input: BatchOutputArray, target: BatchOutputArray) -> MetricArray:
137 |     """Structural similarity for batched image sequences.
138 | 
139 |     Args:
140 |         input: Array of shape [batch, channels, time, height, width]
141 |         target: Array of shape [batch, channels, time, height, width]
142 |         (The window size and other SSIM settings are fixed inside ssim_single.)
143 | 
144 |     Returns:
145 |         Array of SSIM values of shape [channel, time]
146 |     """
147 |     # This function assumes the data will be in the range 0-1 and will give invalid results if not
148 |     _check_input_target_ranges(input, target)
149 | 
150 |     ssim_samples = []
151 |     for i_b in range(input.shape[0]):
152 |         ssim_samples.append(ssim_single(input[i_b], target[i_b]))
153 |     arr: MetricArray = np.stack(ssim_samples, axis=0).mean(axis=0)
154 |     return arr
155 | 
156 | 
157 | def _check_input_target_ranges(input: OutputArray, target: OutputArray) -> None:
158 |     """Validate input and target arrays are within the 0-1 range.
159 | 
160 |     Args:
161 |         input: Input array
162 |         target: Target array
163 | 
164 |     Raises:
165 |         ValueError: If input or target values are outside the 0-1 range.
166 |     """
167 |     input_max, input_min = input.max(), input.min()
168 |     target_max, target_min = target.max(), target.min()
169 | 
170 |     if (input_min < 0) | (input_max > 1) | (target_min < 0) | (target_max > 1):
171 |         msg = (
172 |             f"Input and target must be in 0-1 range. "
173 |             f"Input range: {input_min}-{input_max}. 
" 174 | f"Target range: {target_min}-{target_max}" 175 | ) 176 | raise ValueError(msg) 177 | -------------------------------------------------------------------------------- /tests/test_cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from typer.testing import CliRunner 5 | 6 | from cloudcasting.cli import app 7 | 8 | 9 | @pytest.fixture 10 | def runner(): 11 | return CliRunner() 12 | 13 | 14 | @pytest.fixture 15 | def temp_output_dir(tmp_path): 16 | return str(tmp_path) 17 | 18 | 19 | def test_download_satellite_data(runner, temp_output_dir): 20 | # Define test parameters 21 | start_date = "2021-01-01 00:00" 22 | end_date = "2021-01-01 00:30" 23 | 24 | # Run the CLI command to download the file 25 | result = runner.invoke( 26 | app, 27 | [ 28 | "download", 29 | start_date, 30 | end_date, 31 | temp_output_dir, 32 | "--download-frequency=15min", 33 | "--lon-min=-1", 34 | "--lon-max=1", 35 | "--lat-min=50", 36 | "--lat-max=51", 37 | ], 38 | ) 39 | 40 | # Check if the command executed successfully 41 | assert result.exit_code == 0 42 | 43 | # Check if the output file was created 44 | expected_file = os.path.join(temp_output_dir, "2021_training_nonhrv.zarr") 45 | assert os.path.exists(expected_file) 46 | -------------------------------------------------------------------------------- /tests/test_data/non_hrv_shell.netcdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alan-turing-institute/cloudcasting/0da4087bba01823d2477dba61ea3e6cfd557212c/tests/test_data/non_hrv_shell.netcdf -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | 5 | from cloudcasting.constants import ( 6 | DATA_INTERVAL_SPACING_MINUTES, 7 | FORECAST_HORIZON_MINUTES, 8 | NUM_CHANNELS, 9 | NUM_FORECAST_STEPS, 10 | ) 11 | from cloudcasting.dataset import ( 12 | SatelliteDataModule, 13 | SatelliteDataset, 14 | ValidationSatelliteDataset, 15 | find_valid_t0_times, 16 | load_satellite_zarrs, 17 | ) 18 | 19 | 20 | def test_load_satellite_zarrs(sat_zarr_path): 21 | # Check can load with string and list of string(s) 22 | ds = load_satellite_zarrs(sat_zarr_path) 23 | ds = load_satellite_zarrs([sat_zarr_path]) 24 | 25 | # Dataset is a full 48 hours of 5 minutely data -> 48hours * (60/5) = 576 26 | assert len(ds.time) == 576 27 | 28 | 29 | def test_find_valid_t0_times(sat_zarr_path): 30 | ds = load_satellite_zarrs(sat_zarr_path) 31 | 32 | t0_times = find_valid_t0_times( 33 | pd.DatetimeIndex(ds.time), 34 | history_mins=60, 35 | forecast_mins=120, 36 | sample_freq_mins=5, 37 | ) 38 | 39 | # original timesteps 576 40 | # forecast length buffer - (120 / 5) 41 | # history length buffer - (60 / 5) 42 | # ------------ 43 | # Total 540 44 | 45 | assert len(t0_times) == 540 46 | 47 | t0_times = find_valid_t0_times( 48 | pd.DatetimeIndex(ds.time), 49 | history_mins=60, 50 | forecast_mins=120, 51 | sample_freq_mins=15, 52 | ) 53 | 54 | # original 15 minute timesteps 576 / 3 55 | # forecast length buffer - (120 / 15) 56 | # history length buffer - (60 / 15) 57 | # ------------ 58 | # Total 180 59 | 60 | assert len(t0_times) == 180 61 | 62 | 63 | def test_satellite_dataset(sat_zarr_path): 64 | dataset = SatelliteDataset( 65 | zarr_path=sat_zarr_path, 66 | start_time=None, 67 | end_time=None, 68 | 
history_mins=60, 69 | forecast_mins=120, 70 | sample_freq_mins=5, 71 | ) 72 | 73 | assert len(dataset) == 540 74 | 75 | X, y = dataset[0] 76 | 77 | # 11 channels 78 | # 20 y-dim steps 79 | # 49 x-dim steps 80 | # (60 / 5) + 1 = 13 history steps 81 | # (120 / 5) = 24 forecast steps 82 | assert X.shape == (11, 13, 20, 49) 83 | assert y.shape == (11, 24, 20, 49) 84 | 85 | assert np.sum(np.isnan(X)) == 11 * 13 86 | assert np.sum(np.isnan(y)) == 11 * 24 87 | 88 | 89 | def test_satellite_datamodule(sat_zarr_path): 90 | datamodule = SatelliteDataModule( 91 | zarr_path=sat_zarr_path, 92 | history_mins=60, 93 | forecast_mins=120, 94 | sample_freq_mins=5, 95 | batch_size=2, 96 | num_workers=2, 97 | prefetch_factor=None, 98 | ) 99 | 100 | dl = datamodule.train_dataloader() 101 | 102 | X, y = next(iter(dl)) 103 | 104 | assert X.shape == (2, 11, 13, 20, 49) 105 | assert y.shape == (2, 11, 24, 20, 49) 106 | 107 | 108 | def test_satellite_datamodule_variables(sat_zarr_path): 109 | variables = ["VIS006", "VIS008"] 110 | 111 | datamodule = SatelliteDataModule( 112 | zarr_path=sat_zarr_path, 113 | history_mins=60, 114 | forecast_mins=120, 115 | sample_freq_mins=5, 116 | batch_size=2, 117 | num_workers=2, 118 | prefetch_factor=None, 119 | variables=variables, 120 | ) 121 | 122 | dl = datamodule.train_dataloader() 123 | 124 | X, y = next(iter(dl)) 125 | 126 | assert X.shape == (2, 2, 13, 20, 49) 127 | assert y.shape == (2, 2, 24, 20, 49) 128 | 129 | 130 | def test_satellite_dataset_nan_to_num(sat_zarr_path): 131 | dataset = SatelliteDataset( 132 | zarr_path=sat_zarr_path, 133 | start_time=None, 134 | end_time=None, 135 | history_mins=60, 136 | forecast_mins=120, 137 | sample_freq_mins=5, 138 | nan_to_num=True, 139 | ) 140 | assert len(dataset) == 540 141 | 142 | X, y = dataset[0] 143 | 144 | # 11 channels 145 | # 20 y-dim steps 146 | # 49 x-dim steps 147 | # (60 / 5) + 1 = 13 history steps 148 | # (120 / 5) = 24 forecast steps 149 | assert X.shape == (11, 13, 20, 49) 150 | assert y.shape == (11, 24, 20, 49) 151 | 152 | assert np.sum(np.isnan(X)) == 0 153 | assert np.sum(np.isnan(y)) == 0 154 | 155 | assert np.sum(X[:, :, 0, 0]) == -11 * 13 156 | assert np.sum(y[:, :, 0, 0]) == -11 * 24 157 | 158 | 159 | def test_validation_dataset(val_sat_zarr_path, val_dataset_hyperparams): 160 | dataset = ValidationSatelliteDataset( 161 | zarr_path=val_sat_zarr_path, 162 | history_mins=60, 163 | forecast_mins=FORECAST_HORIZON_MINUTES, 164 | sample_freq_mins=DATA_INTERVAL_SPACING_MINUTES, 165 | ) 166 | 167 | # There are 3744 init times which all models must make predictions for 168 | assert len(dataset) == 3744 169 | 170 | X, y = dataset[0] 171 | 172 | # 11 channels 173 | # 2 y-dim steps 174 | # 1 x-dim steps 175 | # (60 / 15) + 1 = 5 history steps 176 | # (180 / 15) = 12 forecast steps 177 | assert X.shape == ( 178 | NUM_CHANNELS, 179 | 5, 180 | val_dataset_hyperparams["y_geostationary_size"], 181 | val_dataset_hyperparams["x_geostationary_size"], 182 | ) 183 | assert y.shape == ( 184 | NUM_CHANNELS, 185 | NUM_FORECAST_STEPS, 186 | val_dataset_hyperparams["y_geostationary_size"], 187 | val_dataset_hyperparams["x_geostationary_size"], 188 | ) 189 | 190 | 191 | def test_validation_dataset_raises_error(sat_zarr_path): 192 | with pytest.raises(ValueError, match="The following validation t0 times are not available"): 193 | ValidationSatelliteDataset( 194 | zarr_path=sat_zarr_path, 195 | history_mins=60, 196 | forecast_mins=FORECAST_HORIZON_MINUTES, 197 | sample_freq_mins=DATA_INTERVAL_SPACING_MINUTES, 198 | ) 199 | 
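The totals asserted in test_find_valid_t0_times and test_satellite_dataset above follow from simple buffer arithmetic: 48 hours of 5-minutely data gives 576 timestamps, and requiring 60 minutes of history and 120 minutes of forecast removes 12 and 24 candidate init times respectively, leaving 540 (or 180 at a 15-minute sampling frequency). The snippet below is only an illustrative sketch of that arithmetic (valid_t0_count is a hypothetical helper written for this note, not the library's find_valid_t0_times), but it reproduces the counts used in the tests.

import pandas as pd

def valid_t0_count(times: pd.DatetimeIndex, history_mins: int, forecast_mins: int, sample_freq_mins: int) -> int:
    # Keep only timestamps on the requested sampling grid (assumes the frequency divides 60)
    times = times[times.minute % sample_freq_mins == 0]
    # A t0 is valid if a full history window fits before it and a full forecast window after it
    first_valid = times[0] + pd.Timedelta(minutes=history_mins)
    last_valid = times[-1] - pd.Timedelta(minutes=forecast_mins)
    return int(((times >= first_valid) & (times <= last_valid)).sum())

times = pd.date_range("2023-01-01 00:00", "2023-01-02 23:55", freq="5min")  # 576 timestamps
assert valid_t0_count(times, history_mins=60, forecast_mins=120, sample_freq_mins=5) == 540
assert valid_t0_count(times, history_mins=60, forecast_mins=120, sample_freq_mins=15) == 180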
-------------------------------------------------------------------------------- /tests/test_download.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pytest 5 | import xarray as xr 6 | 7 | from cloudcasting.download import download_satellite_data 8 | 9 | 10 | @pytest.fixture 11 | def temp_output_dir(tmp_path): 12 | return str(tmp_path) 13 | 14 | 15 | def test_download_satellite_data(temp_output_dir): 16 | # Define test parameters 17 | start_date = "2021-01-01 00:00" 18 | end_date = "2021-01-01 00:30" 19 | 20 | # Run the function to download the file 21 | download_satellite_data( 22 | start_date, 23 | end_date, 24 | temp_output_dir, 25 | download_frequency="15min", 26 | lon_min=-1, 27 | lon_max=1, 28 | lat_min=50, 29 | lat_max=51, 30 | ) 31 | 32 | # Check if the output file was created 33 | expected_file = os.path.join(temp_output_dir, "2021_training_nonhrv.zarr") 34 | assert os.path.exists(expected_file) 35 | 36 | 37 | def test_download_satellite_data_test_2022_set(temp_output_dir): 38 | # Only run this test on 2022 as it's the only year with a test_2022 set. 39 | # Want to make sure that the --test-2022-set flag works as expected. 40 | start_date = "2022-01-01 00:00" 41 | end_date = "2022-03-01 00:00" 42 | 43 | # Run the function with the --test_2022-set flag 44 | download_satellite_data( 45 | start_date, 46 | end_date, 47 | temp_output_dir, 48 | download_frequency="168h", 49 | lon_min=-1, 50 | lon_max=1, 51 | lat_min=50, 52 | lat_max=51, 53 | test_2022_set=True, 54 | ) 55 | 56 | # Check if the output file was created and contains the expected data 57 | expected_file = os.path.join(temp_output_dir, "2022_test_nonhrv.zarr") 58 | assert os.path.exists(expected_file) 59 | 60 | ds = xr.open_zarr(expected_file) 61 | # Check that the data is only from the expected days of the year: every other 14 days, 62 | # starting from day 15 of the year. 63 | for day in [15, 22, 43, 50]: 64 | assert day in ds.time.dt.dayofyear.values 65 | 66 | 67 | def test_download_satellite_data_2022_nontest_set(temp_output_dir): 68 | # Only run this test on 2022 as it's the only year with a test set. 69 | # Want to make sure that the --test-2022-set flag works as expected. 70 | # We need to make jumps of at least 2 weeks to ensure that the test set is used. 71 | start_date = "2022-01-01 00:00" 72 | end_date = "2022-03-01 00:00" 73 | 74 | # Run the function with the --test-set flag turned off 75 | download_satellite_data( 76 | start_date, 77 | end_date, 78 | temp_output_dir, 79 | download_frequency="168h", 80 | lon_min=-1, 81 | lon_max=1, 82 | lat_min=50, 83 | lat_max=51, 84 | ) 85 | 86 | # Check if the output file was created and contains the expected data 87 | expected_file = os.path.join(temp_output_dir, "2022_training_nonhrv.zarr") 88 | assert os.path.exists(expected_file) 89 | 90 | # Now, we're in the training set 91 | ds = xr.open_zarr(expected_file) 92 | 93 | # Check that the data is only from the expected days of the year: every other 14 days, 94 | # starting from day 1 of the year. 95 | for day in [1, 8, 29, 36, 57]: 96 | assert day in ds.time.dt.dayofyear.values 97 | 98 | 99 | def test_download_satellite_data_test_2021_set(temp_output_dir): 100 | # Want to make sure that the --test-2022-set flag works as expected. 
101 | start_date = "2021-01-01 00:00" 102 | end_date = "2021-01-01 00:30" 103 | 104 | # Run the function with the --test-2022-set flag 105 | # Check if the expected error was raised 106 | with pytest.raises(ValueError, match=r"Test data is only defined for 2022"): 107 | download_satellite_data( 108 | start_date, 109 | end_date, 110 | temp_output_dir, 111 | download_frequency="15min", 112 | lon_min=-1, 113 | lon_max=1, 114 | lat_min=50, 115 | lat_max=51, 116 | test_2022_set=True, 117 | ) 118 | 119 | 120 | def test_download_satellite_data_verify_set(temp_output_dir): 121 | # Want to make sure that the --verify-2023-set flag works as expected. 122 | start_date = "2023-01-01 00:00" 123 | end_date = "2023-01-01 00:30" 124 | 125 | # Run the function with the --verify-2023-set flag 126 | # Check if the expected error was raised 127 | with pytest.raises( 128 | ValueError, 129 | match=r"Verification data requires a start date of '2023-01-01 00:00'", 130 | ): 131 | download_satellite_data( 132 | start_date, 133 | end_date, 134 | temp_output_dir, 135 | download_frequency="15min", 136 | lon_min=-1, 137 | lon_max=1, 138 | lat_min=50, 139 | lat_max=51, 140 | verify_2023_set=True, 141 | ) 142 | 143 | 144 | def test_download_satellite_data_2023_not_verify(temp_output_dir): 145 | # Want to make sure that the --verify-2023-set flag works as expected. 146 | start_date = "2023-01-01 00:00" 147 | end_date = "2023-01-01 00:30" 148 | 149 | # Run the function with the --verify-2023-set flag 150 | # Check if the expected error was raised 151 | with pytest.raises(ValueError, match=r"2023 data is reserved for the verification process"): 152 | download_satellite_data( 153 | start_date, 154 | end_date, 155 | temp_output_dir, 156 | download_frequency="15min", 157 | lon_min=-1, 158 | lon_max=1, 159 | lat_min=50, 160 | lat_max=51, 161 | ) 162 | 163 | 164 | def test_irregular_start_date(temp_output_dir): 165 | # Define test parameters 166 | start_date = "2021-01-01 00:02" 167 | end_date = "2021-01-01 00:30" 168 | 169 | # Run the function to download the file 170 | download_satellite_data( 171 | start_date, 172 | end_date, 173 | temp_output_dir, 174 | download_frequency="15min", 175 | lon_min=-1, 176 | lon_max=1, 177 | lat_min=50, 178 | lat_max=51, 179 | ) 180 | 181 | # Check if the output file was created 182 | expected_file = os.path.join(temp_output_dir, "2021_training_nonhrv.zarr") 183 | assert os.path.exists(expected_file) 184 | 185 | ds = xr.open_zarr(expected_file) 186 | # Check that the data ignored the 00:02 entry; only the 00:15 and 00:30 entries should exist. 
187 | assert np.all(ds.time.dt.minute.values == [np.int64(15), np.int64(30)]) 188 | 189 | 190 | def test_download_satellite_data_mock_to_zarr(temp_output_dir, monkeypatch): 191 | # make a tiny dataset to mock the to_zarr function, 192 | # but use netcdf instead of zarr (as to not recurse) 193 | mock_file_name = f"{temp_output_dir}/mock.nc" 194 | 195 | def mock_to_zarr(*args, **kwargs): 196 | xr.Dataset({"data": xr.DataArray(np.zeros([1, 1, 1, 1]))}).to_netcdf(mock_file_name) 197 | 198 | monkeypatch.setattr("xarray.Dataset.to_zarr", mock_to_zarr) 199 | 200 | # Define test parameters (known missing data here somewhere) 201 | start_date = "2020-06-01 00:00" 202 | end_date = "2020-06-30 23:55" 203 | 204 | # Run the function to download the file 205 | download_satellite_data( 206 | start_date, 207 | end_date, 208 | temp_output_dir, 209 | download_frequency="15min", 210 | lon_min=-1, 211 | lon_max=1, 212 | lat_min=50, 213 | lat_max=51, 214 | ) 215 | 216 | # Check if the output file was created 217 | expected_file = os.path.join(temp_output_dir, mock_file_name) 218 | assert os.path.exists(expected_file) 219 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | """Test if metrics match the legacy metrics""" 2 | 3 | import inspect 4 | from functools import partial 5 | from typing import cast 6 | 7 | import jax.numpy as jnp 8 | import jax.random as jr 9 | import numpy as np 10 | import pytest 11 | from jaxtyping import Array, Float32 12 | from legacy_metrics import mae_batch, mse_batch, ssim_batch 13 | 14 | from cloudcasting.metrics import mae, mse, ssim 15 | 16 | 17 | def apply_pix_metric(metric_func, y_hat, y) -> Float32[Array, "batch channels time"]: 18 | """Apply a pix metric to a sample of satellite data 19 | Args: 20 | metric_func: The pix metric function to apply 21 | y_hat: The predicted sequence of satellite data 22 | y: The true sequence of satellite data 23 | 24 | Returns: 25 | The pix metric value for the sample 26 | """ 27 | # pix accepts arrays of shape [batch, height, width, channels]. 28 | # our arrays are of shape [batch, channels, time, height, width]. 29 | # channel dim would be reduced; we add a new axis to satisfy the shape reqs. 30 | # we then reshape to squash batch, channels, and time into the leading axis, 31 | # where the vmap in metrics.py will broadcast over the leading dim. 
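    # e.g. the (1, 3, 10, 100, 100) arrays used in test_metrics below become (30, 100, 100, 1)
    # here, and the per-image metric values are reshaped back to (1, 3, 10) by the return statement.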
32 | y_jax = jnp.array(y).reshape(-1, *y.shape[-2:])[..., np.newaxis] 33 | y_hat_jax = jnp.array(y_hat).reshape(-1, *y_hat.shape[-2:])[..., np.newaxis] 34 | 35 | sig = inspect.signature(metric_func) 36 | if "ignore_nans" in sig.parameters: 37 | metric_func = partial(metric_func, ignore_nans=True) 38 | 39 | # we reshape the result back into [batch, channels, time], 40 | # then take the mean over the batch 41 | return cast(Float32[Array, "batch channels time"], metric_func(y_hat_jax, y_jax)).reshape( 42 | *y.shape[:3] 43 | ) 44 | 45 | 46 | @pytest.mark.parametrize( 47 | ("metric_func", "legacy_func"), 48 | [ 49 | (mae, mae_batch), 50 | (mse, mse_batch), 51 | (ssim, ssim_batch), 52 | ], 53 | ) 54 | def test_metrics(metric_func, legacy_func): 55 | """Test if metrics match the legacy metrics""" 56 | # Create a sample input batch 57 | shape = (1, 3, 10, 100, 100) 58 | key = jr.key(321) 59 | key, k1, k2 = jr.split(key, 3) 60 | y_hat = jr.uniform(k1, shape, minval=0, maxval=1) 61 | y = jr.uniform(k2, shape, minval=0, maxval=1) 62 | 63 | # Add NaNs to the input 64 | y = y.at[:, :, :, 0, 0].set(np.nan) 65 | 66 | # Call the metric function 67 | metric = apply_pix_metric(metric_func, y_hat, y).mean(axis=0) 68 | 69 | # Check the shape of the output 70 | assert metric.shape == (3, 10) 71 | 72 | # Check the values of the output 73 | legacy_res = legacy_func(y_hat, y) 74 | 75 | # Lower tolerance for ssim (differences in implementation) 76 | rtol = 0.001 if metric_func == ssim else 1e-5 77 | 78 | assert np.allclose(metric, legacy_res, rtol=rtol), ( 79 | f"Metric {metric_func} does not match legacy metric {legacy_func}" 80 | ) 81 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from conftest import PersistenceModel 4 | from jaxtyping import TypeCheckError 5 | 6 | from cloudcasting.constants import NUM_FORECAST_STEPS 7 | from cloudcasting.models import AbstractModel 8 | 9 | 10 | @pytest.fixture 11 | def model(): 12 | return PersistenceModel(history_steps=1, rollout_steps=NUM_FORECAST_STEPS) 13 | 14 | 15 | def test_forward(model): 16 | # Create a sample input batch 17 | X = np.random.rand(1, 3, 10, 100, 100) 18 | 19 | # Call the forward method 20 | y_hat = model.forward(X) 21 | 22 | # Check the shape of the output 23 | assert y_hat.shape == (1, 3, model.rollout_steps, 100, 100) 24 | 25 | 26 | def test_check_predictions_no_nans(model): 27 | # Create a sample prediction array without NaNs 28 | y_hat = np.random.rand(1, 3, model.rollout_steps, 100, 100) 29 | 30 | # Call the check_predictions method 31 | model.check_predictions(y_hat) 32 | 33 | 34 | def test_check_predictions_with_nans(model): 35 | # Create a sample prediction array with NaNs 36 | y_hat = np.random.rand(1, 3, model.rollout_steps, 100, 100) 37 | y_hat[0, 0, 0, 0, 0] = np.nan 38 | 39 | # Call the check_predictions method and expect a ValueError 40 | with pytest.raises(ValueError, match="Predictions contain NaNs"): 41 | model.check_predictions(y_hat) 42 | 43 | 44 | def test_check_predictions_within_range(model): 45 | # Create a sample prediction array within the expected range 46 | y_hat = np.random.rand(1, 3, model.rollout_steps, 100, 100) 47 | 48 | # Call the check_predictions method 49 | model.check_predictions(y_hat) 50 | 51 | 52 | def test_check_predictions_outside_range(model): 53 | # Create a sample prediction array outside the expected range 54 | y_hat = 
np.random.rand(1, 3, model.rollout_steps, 100, 100) * 2 55 | 56 | # Call the check_predictions method and expect a ValueError 57 | with pytest.raises(ValueError, match="The predictions must be in the range "): 58 | model.check_predictions(y_hat) 59 | 60 | 61 | def test_call(model): 62 | # Create a sample input batch 63 | X = np.random.rand(1, 3, model.history_steps, 100, 100) 64 | 65 | # Call the __call__ method 66 | y_hat = model(X) 67 | 68 | # Check the shape of the output 69 | assert y_hat.shape == (1, 3, model.rollout_steps, 100, 100) 70 | 71 | 72 | def test_incorrect_shapes(model): 73 | # Create a sample input batch with incorrect shapes 74 | X = np.random.rand(1, 3, 10, 100) 75 | 76 | # Call the __call__ method and expect a TypeCheckError 77 | with pytest.raises(TypeCheckError): 78 | model(X) 79 | 80 | 81 | def test_incorrect_horizon(): 82 | # define a model with a different forecast horizon 83 | class Model(AbstractModel): 84 | def __init__(self, history_steps: int) -> None: 85 | super().__init__(history_steps) 86 | 87 | def forward(self, X): 88 | # Grab the most recent frame from the input data 89 | latest_frame = X[..., -1:, :, :] 90 | 91 | # The NaN values in the input data could be filled with -1. Clip these to zero 92 | latest_frame = latest_frame.clip(0, 1) 93 | 94 | return np.repeat(latest_frame, NUM_FORECAST_STEPS + 1, axis=-3) 95 | 96 | def hyperparameters_dict(self): 97 | return {"history_steps": self.history_steps} 98 | 99 | model = Model(history_steps=1) 100 | X = np.random.rand(1, 3, 1, 100, 100).astype(np.float32) 101 | 102 | # Call the __call__ method and expect a ValueError 103 | with pytest.raises(ValueError, match="The number of forecast steps in the model"): 104 | model(X) 105 | -------------------------------------------------------------------------------- /tests/test_validation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from conftest import PersistenceModel 4 | 5 | from cloudcasting.constants import ( 6 | DATA_INTERVAL_SPACING_MINUTES, 7 | FORECAST_HORIZON_MINUTES, 8 | NUM_FORECAST_STEPS, 9 | ) 10 | from cloudcasting.dataset import ValidationSatelliteDataset 11 | from cloudcasting.validation import ( 12 | calc_mean_metrics, 13 | score_model_on_all_metrics, 14 | validate, 15 | validate_from_config, 16 | ) 17 | 18 | 19 | @pytest.fixture 20 | def model(): 21 | return PersistenceModel(history_steps=1, rollout_steps=NUM_FORECAST_STEPS) 22 | 23 | 24 | @pytest.mark.parametrize("nan_to_num", [True, False]) 25 | def test_score_model_on_all_metrics(model, val_sat_zarr_path, nan_to_num): 26 | # Create valid dataset 27 | valid_dataset = ValidationSatelliteDataset( 28 | zarr_path=val_sat_zarr_path, 29 | history_mins=(model.history_steps - 1) * DATA_INTERVAL_SPACING_MINUTES, 30 | forecast_mins=FORECAST_HORIZON_MINUTES, 31 | sample_freq_mins=DATA_INTERVAL_SPACING_MINUTES, 32 | nan_to_num=nan_to_num, 33 | ) 34 | 35 | metric_names = ("mae", "mse", "ssim") 36 | # use small filter size to not propagate nan to the whole image 37 | # (this is only because our test images are very small (8x9) -- 38 | # the filter window of size 11 would be bigger than the image!) 
39 | metric_kwargs = {"ssim": {"filter_size": 2}} 40 | 41 | # Call the score_model_on_all_metrics function 42 | metrics_dict, channels = score_model_on_all_metrics( 43 | model=model, 44 | valid_dataset=valid_dataset, 45 | batch_size=2, 46 | num_workers=0, 47 | batch_limit=3, 48 | metric_names=metric_names, 49 | metric_kwargs=metric_kwargs, 50 | ) 51 | 52 | # Check all the expected keys are there 53 | assert tuple(metrics_dict.keys()) == metric_names 54 | 55 | for metric_name, metric_array in metrics_dict.items(): 56 | # check all the items have the expected shape 57 | assert metric_array.shape == ( 58 | len(channels), 59 | NUM_FORECAST_STEPS, 60 | ), f"Metric {metric_name} has the wrong shape" 61 | 62 | assert not np.any(np.isnan(metric_array)), f"Metric '{metric_name}' is predicting NaNs!" 63 | 64 | 65 | def test_calc_mean_metrics(): 66 | # Create a test dictionary of metrics (channels, time) 67 | test_metrics_dict = { 68 | "mae": np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), 69 | "mse": np.array([[4.0, 5.0, 6.0], [4.0, 5.0, 6.0]]), 70 | } 71 | 72 | # Call the calc_mean_metrics function 73 | mean_metrics_dict = calc_mean_metrics(test_metrics_dict) 74 | 75 | # Check the expected keys are present 76 | assert mean_metrics_dict.keys() == {"mae", "mse"} 77 | 78 | # Check the expected values are present 79 | assert mean_metrics_dict["mae"] == 2 80 | assert mean_metrics_dict["mse"] == 5 81 | 82 | 83 | def test_validate(model, val_sat_zarr_path, mocker): 84 | # Mock the wandb functions so they aren't run in testing 85 | mocker.patch("wandb.login") 86 | mocker.patch("wandb.init") 87 | mocker.patch("wandb.config") 88 | mocker.patch("wandb.log") 89 | mocker.patch("wandb.plot.line") 90 | mocker.patch("wandb.plot.bar") 91 | 92 | validate( 93 | model=model, 94 | data_path=val_sat_zarr_path, 95 | wandb_project_name="cloudcasting-pytest", 96 | wandb_run_name="test_validate", 97 | nan_to_num=False, 98 | batch_size=2, 99 | num_workers=0, 100 | batch_limit=4, 101 | ) 102 | 103 | 104 | def test_validate_cli(val_sat_zarr_path, mocker): 105 | # write out an example model.py file 106 | with open("model.py", "w") as f: 107 | f.write( 108 | """ 109 | from cloudcasting.models import AbstractModel 110 | from cloudcasting.constants import NUM_FORECAST_STEPS 111 | import numpy as np 112 | 113 | class Model(AbstractModel): 114 | def __init__(self, history_steps: int, sigma: float) -> None: 115 | super().__init__(history_steps) 116 | self.sigma = sigma 117 | 118 | def forward(self, X): 119 | shape = X.shape 120 | return np.ones((shape[0], shape[1], NUM_FORECAST_STEPS, shape[3], shape[4])) 121 | 122 | def hyperparameters_dict(self): 123 | return {"sigma": self.sigma} 124 | """ 125 | ) 126 | 127 | # write out an example validate_config.yml file 128 | with open("validate_config.yml", "w") as f: 129 | f.write( 130 | f""" 131 | model: 132 | name: Model 133 | params: 134 | history_steps: 1 135 | sigma: 0.1 136 | validation: 137 | data_path: {val_sat_zarr_path} 138 | wandb_project_name: cloudcasting-pytest 139 | wandb_run_name: test_validate 140 | nan_to_num: False 141 | batch_size: 2 142 | num_workers: 0 143 | batch_limit: 4 144 | """ 145 | ) 146 | 147 | # Mock the wandb functions so they aren't run in testing 148 | mocker.patch("wandb.login") 149 | mocker.patch("wandb.init") 150 | mocker.patch("wandb.config") 151 | mocker.patch("wandb.log") 152 | mocker.patch("wandb.plot.line") 153 | mocker.patch("wandb.plot.bar") 154 | 155 | # run the validate_from_config function 156 | validate_from_config() 157 | 
--------------------------------------------------------------------------------