├── .copier-answers.yml
├── .github
│   └── workflows
│       ├── cd.yml
│       └── ci.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── make.bat
│   └── source
│       ├── conf.py
│       ├── dataset.rst
│       ├── download.rst
│       ├── index.rst
│       ├── metrics.rst
│       ├── models.rst
│       ├── usage.md
│       ├── utils.rst
│       └── validation.rst
├── examples
│   └── find_test_2022_t0_times.py
├── notebooks
│   ├── 01-plotting.ipynb
│   ├── 02-data_loader_demo.ipynb
│   ├── 03-score_model_demo.ipynb
│   └── requirements.txt
├── pyproject.toml
├── src
│   └── cloudcasting
│       ├── __init__.py
│       ├── __main__.py
│       ├── _version.pyi
│       ├── cli.py
│       ├── constants.py
│       ├── data
│       │   └── test_2022_t0_times.csv.zip
│       ├── dataset.py
│       ├── download.py
│       ├── metrics.py
│       ├── models.py
│       ├── py.typed
│       ├── types.py
│       ├── utils.py
│       └── validation.py
├── tests
│   ├── conftest.py
│   ├── legacy_metrics.py
│   ├── test_cli.py
│   ├── test_data
│   │   └── non_hrv_shell.netcdf
│   ├── test_dataset.py
│   ├── test_download.py
│   ├── test_metrics.py
│   ├── test_models.py
│   └── test_validation.py
└── uv.lock
/.copier-answers.yml:
--------------------------------------------------------------------------------
1 | # Changes here will be overwritten by Copier
2 | _commit: v0.1.0-8-ga0c8676
3 | _src_path: gh:alan-turing-institute/python-project-template
4 | coc: our_coc
5 | email: nsimpson@turing.ac.uk
6 | full_name: cloudcasting Maintainers
7 | license: MIT
8 | min_python_version: '3.10'
9 | org: climetrend
10 | project_name: cloudcasting
11 | project_short_description: Tooling and infrastructure to enable cloud nowcasting.
12 | python_name: cloudcasting
13 | typing: strict
14 | url: https://github.com/climetrend/cloudcasting
15 |
--------------------------------------------------------------------------------
/.github/workflows/cd.yml:
--------------------------------------------------------------------------------
1 | name: CD
2 |
3 | on:
4 | workflow_dispatch:
5 | pull_request:
6 | push:
7 | branches:
8 | - main
9 | release:
10 | types:
11 | - published
12 |
13 | jobs:
14 | dist:
15 | needs: [pre-commit]
16 | name: Distribution build
17 | runs-on: ubuntu-latest
18 |
19 | steps:
20 | - uses: actions/checkout@v4
21 | with:
22 | fetch-depth: 0
23 |
24 | - name: Build sdist and wheel
25 | run: pipx run build
26 |
27 | - uses: actions/upload-artifact@v3
28 | with:
29 | path: dist
30 |
31 | - name: Check products
32 | run: pipx run twine check dist/*
33 |
34 | publish:
35 | needs: [dist]
36 | name: Publish to PyPI
37 | environment: pypi
38 | permissions:
39 | id-token: write
40 | runs-on: ubuntu-latest
41 | if: github.event_name == 'release' && github.event.action == 'published'
42 |
43 | steps:
44 | - uses: actions/download-artifact@v3
45 | with:
46 | name: artifact
47 | path: dist
48 |
49 | - uses: pypa/gh-action-pypi-publish@release/v1
50 | if: github.event_name == 'release' && github.event.action == 'published'
51 | with:
52 | # Remove this line to publish to PyPI
53 | repository-url: https://test.pypi.org/legacy/
54 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | workflow_dispatch:
5 | pull_request:
6 | push:
7 | branches:
8 | - main
9 |
10 | jobs:
11 | pre-commit:
12 | name: Format + lint code
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v4
16 | with:
17 | fetch-depth: 0
18 | - uses: actions/setup-python@v4
19 | with:
20 | python-version: "3.x"
21 | - uses: pre-commit/action@v3.0.0
22 | with:
23 | extra_args: --hook-stage manual --all-files
24 |
25 | checks:
26 | name: Run tests for Python ${{ matrix.python-version }} on ${{ matrix.runs-on }}
27 | runs-on: ${{ matrix.runs-on }}
28 | needs: [pre-commit]
29 | strategy:
30 | fail-fast: false
31 | matrix:
32 | python-version: ["3.10", "3.12"] # test oldest and latest supported versions
33 | runs-on: [ubuntu-latest, macos-latest, windows-latest] # can be extended to other OSes, e.g. [ubuntu-latest, macos-latest]
34 |
35 | steps:
36 | - uses: actions/checkout@v4
37 | with:
38 | fetch-depth: 0
39 |
40 | - uses: actions/setup-python@v4
41 | with:
42 | python-version: ${{ matrix.python-version }}
43 | allow-prereleases: true
44 |
45 | - name: Install ffmpeg on macOS
46 | if: runner.os == 'macOS'
47 | run: |
48 | brew install ffmpeg
49 |
50 | - name: Install package
51 | run: python -m pip install .[dev]
52 |
53 | - name: Install nightly build of torch
54 | if: matrix.runs-on == 'windows-latest'
55 | run: pip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall --upgrade
56 |
57 | - name: Test package
58 | run: >-
59 | python -m pytest -ra --cov --cov-report=xml --cov-report=term
60 | --durations=20
61 |
62 | - name: Upload coverage report
63 | uses: codecov/codecov-action@v3.1.4
64 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
140 | # setuptools_scm
141 | src/*/_version.py
142 |
143 |
144 | # ruff
145 | .ruff_cache/
146 |
147 | # OS specific stuff
148 | .DS_Store
149 | .DS_Store?
150 | ._*
151 | .Spotlight-V100
152 | .Trashes
153 | ehthumbs.db
154 | Thumbs.db
155 |
156 | # Common editor files
157 | *~
158 | *.swp
159 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | ci:
2 | autoupdate_commit_msg: "chore: update pre-commit hooks"
3 | autofix_commit_msg: "style: pre-commit fixes"
4 |
5 | repos:
6 | - repo: https://github.com/pre-commit/pre-commit-hooks
7 | rev: "v5.0.0"
8 | hooks:
9 | - id: check-added-large-files
10 | - id: check-case-conflict
11 | - id: check-merge-conflict
12 | - id: check-symlinks
13 | - id: check-yaml
14 | - id: debug-statements
15 | - id: end-of-file-fixer
16 | - id: mixed-line-ending
17 | - id: name-tests-test
18 | args: ["--pytest-test-first"]
19 | exclude: ^tests/legacy_metrics.py
20 | - id: requirements-txt-fixer
21 | - id: trailing-whitespace
22 |
23 | - repo: https://github.com/astral-sh/ruff-pre-commit
24 | rev: "v0.9.3"
25 | hooks:
26 | # first, lint + autofix
27 | - id: ruff
28 | types_or: [python, pyi, jupyter]
29 | args: ["--fix", "--show-fixes"]
30 | # then, format
31 | - id: ruff-format
32 |
33 | - repo: https://github.com/pre-commit/mirrors-mypy
34 | rev: "v1.14.1"
35 | hooks:
36 | - id: mypy
37 | args: []
38 | additional_dependencies:
39 | - pytest
40 | - xarray
41 | - pandas-stubs
42 | - typer
43 | - dask
44 | - pyproj
45 | - pyresample
46 | - lightning
47 | - torch
48 | - jaxtyping
49 | - types-tqdm
50 | - chex
51 | - types-PyYAML
52 | - wandb
53 | - matplotlib
54 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: "2"
2 |
3 | build:
4 | os: "ubuntu-22.04"
5 | tools:
6 | python: "3.11"
7 |
8 | python:
9 | install:
10 | - method: pip
11 | path: .
12 | extra_requirements:
13 | - doc
14 |
15 | sphinx:
16 | configuration: docs/source/conf.py
17 | fail_on_warning: true
18 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | We value the participation of every member of our community and want to ensure
4 | that every contributor has an enjoyable and fulfilling experience. Accordingly,
5 | everyone who participates in the cloudcasting project is expected to show respect and courtesy to other community members at all times.
6 |
7 | In the interest of fostering an open and welcoming environment, we as
8 | contributors and maintainers are dedicated to making participation in our project
9 | a harassment-free experience for everyone, regardless of age, body
10 | size, disability, ethnicity, sex characteristics, gender identity and expression,
11 | level of experience, education, socio-economic status, nationality, personal
12 | appearance, race, religion, or sexual identity and orientation.
13 |
14 | ## Our Standards
15 |
16 | Examples of behaviour that contributes to creating a positive environment
17 | include:
18 |
19 | - Using welcoming and inclusive language
20 | - Being respectful of differing viewpoints and experiences
21 | - Gracefully accepting constructive criticism
22 | - Focusing on what is best for the community
23 | - Showing empathy towards other community members
24 |
25 | Examples of unacceptable behaviour by participants include:
26 |
27 | - The use of sexualized language or imagery and unwelcome sexual attention or
28 | advances
29 | - Trolling, insulting/derogatory comments, and personal or political attacks
30 | - Public or private harassment
31 | - Publishing others' private information, such as a physical or electronic
32 | address, without explicit permission
33 | - Other conduct which could reasonably be considered inappropriate in a
34 | professional setting
35 |
36 |
43 |
44 | ## Attribution
45 |
46 | This Code of Conduct is adapted from the [Turing Data Stories Code of Conduct](https://github.com/alan-turing-institute/TuringDataStories/blob/main/CODE_OF_CONDUCT.md) which is based on the [scona project Code of Conduct](https://github.com/WhitakerLab/scona/blob/master/CODE_OF_CONDUCT.md)
47 | and the [Contributor Covenant](https://www.contributor-covenant.org), version [1.4](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
48 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | See the [Scientific Python Developer Guide][spc-dev-intro] for a detailed
2 | description of best practices for developing scientific packages.
3 |
4 | [spc-dev-intro]: https://learn.scientific-python.org/development/
5 |
6 | # Setting up a development environment manually
7 |
8 | You can set up a development environment by running:
9 |
10 | ```zsh
11 | python3 -m venv venv # create a virtualenv called venv
12 | source ./venv/bin/activate # now `python` points to the virtualenv python
13 | pip install -v -e ".[dev]" # -v for verbose, -e for editable, [dev] for dev dependencies
14 | ```
15 |
16 | # Post setup
17 |
18 | You should prepare pre-commit, which will help you by checking that commits pass
19 | required checks:
20 |
21 | ```bash
22 | pip install pre-commit # or brew install pre-commit on macOS
23 | pre-commit install # this will install a pre-commit hook into the git repo
24 | ```
25 |
26 | You can also run `pre-commit run` (staged changes only) or
27 | `pre-commit run --all-files` to run the checks without installing the hook.
28 |
29 | # Testing
30 |
31 | Use pytest to run the unit checks:
32 |
33 | ```bash
34 | pytest
35 | ```
36 |
37 | # Coverage
38 |
39 | Use pytest-cov to generate coverage reports:
40 |
41 | ```bash
42 | pytest --cov=cloudcasting
43 | ```
44 |
45 | # Pre-commit
46 |
47 | This project uses pre-commit for all style checking. Install pre-commit and run:
48 |
49 | ```bash
50 | pre-commit run -a
51 | ```
52 |
53 | to check all files.
54 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2024 cloudcasting Maintainers
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | this software and associated documentation files (the "Software"), to deal in
5 | the Software without restriction, including without limitation the rights to
6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
7 | of the Software, and to permit persons to whom the Software is furnished to do
8 | so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # cloudcasting
2 |
3 | [![Actions Status][actions-badge]][actions-link]
4 | [![Documentation Status](https://readthedocs.org/projects/cloudcasting/badge/?version=latest)](https://cloudcasting.readthedocs.io/en/latest/?badge=latest)
5 |
6 | Tooling and infrastructure to enable cloud nowcasting. Full documentation can be found at https://cloudcasting.readthedocs.io/.
7 |
8 | ## Linked model repos
9 | - [Optical Flow (Farneback)](https://github.com/alan-turing-institute/ocf-optical-flow)
10 | - [Optical Flow (TVL1)](https://github.com/alan-turing-institute/ocf-optical-flow-tvl1)
11 | - [Diffusion model](https://github.com/alan-turing-institute/ocf-diffusion)
12 | - [ConvLSTM](https://github.com/alan-turing-institute/ocf-convLSTM)
13 | - [IAM4VP](https://github.com/alan-turing-institute/ocf-iam4vp)
14 |
15 | The model template repo on which these are based can be found [here](https://github.com/alan-turing-institute/ocf-model-template). These repositories contain the implementations of each model, as well as the validation infrastructure to reproduce their metric scores on Weights & Biases.
16 |
17 | ## Installation
18 |
19 | ### For users:
20 |
21 | ```zsh
22 | git clone https://github.com/alan-turing-institute/cloudcasting
23 | cd cloudcasting
24 | python -m pip install .
25 | ```
26 |
27 | To run metrics on GPU:
28 |
29 | ```zsh
30 | python -m pip install --upgrade "jax[cuda12]"
31 | ```
32 | ### For making changes to the library:
33 |
34 | On macOS you first need to install `ffmpeg` with the following command. On other platforms this is
35 | not necessary.
36 |
37 | ```bash
38 | brew install ffmpeg
39 | ```
40 |
41 | Clone and install the repo.
42 |
43 | ```bash
44 | git clone https://github.com/alan-turing-institute/cloudcasting
45 | cd cloudcasting
46 | python -m pip install ".[dev]"
47 | ```
48 |
49 | Install pre-commit before making development changes:
50 |
51 | ```bash
52 | pre-commit install
53 | ```
54 |
55 | For making changes, see the [guidance on development](https://github.com/alan-turing-institute/python-project-template?tab=readme-ov-file#setting-up-a-new-project) from the template that generated this project.
56 |
57 | ## Usage
58 |
59 | ### Validating a model
60 | ```bash
61 | cloudcasting validate "path/to/config/file.yml" "path/to/model/file.py"
62 | ```
63 |
64 | ### Downloading data
65 | ```bash
66 | cloudcasting download "2020-06-01 00:00" "2020-06-30 23:55" "path/to/data/save/dir"
67 | ```
68 |
69 | Full options:
70 |
71 | ```bash
72 | > cloudcasting download --help
73 |
74 | Usage: cloudcasting download [OPTIONS] START_DATE
75 | END_DATE OUTPUT_DIRECTORY
76 |
77 | ╭─ Arguments ──────────────────────────────────────────╮
78 | │ * start_date TEXT Start date in │
79 | │ 'YYYY-MM-DD HH:MM' │
80 | │ format │
81 | │ [default: None] │
82 | │ [required] │
83 | │ * end_date TEXT End date in │
84 | │ 'YYYY-MM-DD HH:MM' │
85 | │ format │
86 | │ [default: None] │
87 | │ [required] │
88 | │ * output_directory TEXT Directory to save │
89 | │ the satellite data │
90 | │ [default: None] │
91 | │ [required] │
92 | ╰──────────────────────────────────────────────────────╯
93 | ╭─ Options ────────────────────────────────────────────╮
94 | │ --download-f… TEXT Frequency to │
95 | │ download data │
96 | │ in pandas │
97 | │ datetime │
98 | │ format │
99 | │ [default: │
100 | │ 15min] │
101 | │ --get-hrv --no-get-h… Whether to │
102 | │ download HRV │
103 | │ data │
104 | │ [default: │
105 | │ no-get-hrv] │
106 | │ --override-d… --no-overr… Whether to │
107 | │ override date │
108 | │ range limits │
109 | │ [default: │
110 | │ no-override-… │
111 | │ --lon-min FLOAT Minimum │
112 | │ longitude │
113 | │ [default: │
114 | │ -16] │
115 | │ --lon-max FLOAT Maximum │
116 | │ longitude │
117 | │ [default: 10] │
118 | │ --lat-min FLOAT Minimum │
119 | │ latitude │
120 | │ [default: 45] │
121 | │ --lat-max FLOAT Maximum │
122 | │ latitude │
123 | │ [default: 70] │
124 | │ --test-2022-… --no-test-… Whether to │
125 | │ filter data │
126 | │ from 2022 to │
127 | │ download the │
128 | │ test set │
129 | │ (every 2 │
130 | │ weeks). │
131 | │ [default: │
132 | │ no-test-2022… │
133 | │ --verify-202… --no-verif… Whether to │
134 | │ download the │
135 | │ verification │
136 | │ data from │
137 | │ 2023. Only │
138 | │ used at the │
139 | │ end of the │
140 | │ project │
141 | │ [default: │
142 | │ no-verify-20… |
143 | │ --help Show this │
144 | │ message and │
145 | │ exit. │
146 | ╰──────────────────────────────────────────────────────╯
147 | ```
148 |
149 | ## Contributing
150 |
151 | See [CONTRIBUTING.md](CONTRIBUTING.md) for instructions on how to contribute.
152 |
153 | ## License
154 |
155 | Distributed under the terms of the [MIT license](LICENSE).
156 |
157 |
158 |
159 | [actions-badge]: https://github.com/alan-turing-institute/cloudcasting/actions/workflows/ci.yml/badge.svg?branch=main
160 | [actions-link]: https://github.com/alan-turing-institute/cloudcasting/actions
161 | [pypi-link]: https://pypi.org/project/cloudcasting/
162 | [pypi-platforms]: https://img.shields.io/pypi/pyversions/cloudcasting
163 | [pypi-version]: https://img.shields.io/pypi/v/cloudcasting
164 |
165 |
--------------------------------------------------------------------------------
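
A note on the README's Usage section: the `download` and `validate` subcommands are thin wrappers around library functions (`download_satellite_data` and `validate_from_config`, wired up in `src/cloudcasting/cli.py`), so the same download can also be driven from Python. A minimal sketch, reusing the placeholder dates and output path from the README above:

```python
# Programmatic equivalent of `cloudcasting download ...` (sketch only;
# the output path below is a placeholder, exactly as in the README).
from cloudcasting.download import download_satellite_data

download_satellite_data(
    start_date="2020-06-01 00:00",      # first timestamp (inclusive)
    end_date="2020-06-30 23:55",        # last timestamp (inclusive)
    output_directory="path/to/data/save/dir",
    download_frequency="15min",         # pandas-style frequency string
    get_hrv=False,                      # non-HRV channels only
)
```
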
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | # -- Path setup --------------------------------------------------------------
7 |
8 | # If extensions (or modules to document with autodoc) are in another directory,
9 | # add these directories to sys.path here. If the directory is relative to the
10 | # documentation root, use os.path.abspath to make it absolute, like shown here.
11 | #
12 | import os
13 | import sys
14 |
15 | sys.path.insert(0, os.path.abspath(".."))
16 |
17 | # -- Project information -----------------------------------------------------
18 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
19 |
20 | project = "cloudcasting"
21 | copyright = "2025, cloudcasting Maintainers"
22 | author = "cloudcasting Maintainers"
23 | release = "0.6"
24 | version = "0.6.0"
25 |
26 | # -- General configuration ---------------------------------------------------
27 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
28 |
29 | extensions = [
30 | "sphinx.ext.duration",
31 | "sphinx.ext.doctest",
32 | "sphinx.ext.autodoc",
33 | "sphinx.ext.autosummary",
34 | "sphinx.ext.intersphinx",
35 | "sphinx.ext.coverage",
36 | "sphinx.ext.napoleon",
37 | "m2r2",
38 | ]
39 |
40 | intersphinx_mapping = {
41 | "python": ("https://docs.python.org/3/", None),
42 | "sphinx": ("https://www.sphinx-doc.org/en/master/", None),
43 | }
44 | intersphinx_disabled_domains = ["std"]
45 |
46 | # -- Options for HTML output -------------------------------------------------
47 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
48 |
49 | html_theme = "sphinx_rtd_theme"
50 |
--------------------------------------------------------------------------------
/docs/source/dataset.rst:
--------------------------------------------------------------------------------
1 | Dataset
2 | =======
3 |
4 | .. automodule:: cloudcasting.dataset
5 | :members:
6 | :special-members: __init__
7 |
--------------------------------------------------------------------------------
/docs/source/download.rst:
--------------------------------------------------------------------------------
1 | Download
2 | ========
3 |
4 | .. automodule:: cloudcasting.download
5 | :members:
6 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 |
2 | Documentation for cloudcasting
3 | ==============================
4 |
5 | Tooling and infrastructure to enable cloud nowcasting.
6 | Check out the :doc:`usage` section for further information on how to install and run this package.
7 |
8 | This tool was developed by `Open Climate Fix <https://openclimatefix.org/>`_ and
9 | `The Alan Turing Institute <https://www.turing.ac.uk/>`_ as part of the
10 | `Manchester Prize <https://manchesterprize.org/>`_.
11 |
12 | Contents
13 | --------
14 |
15 | .. toctree::
16 | :maxdepth: 2
17 |
18 | usage
19 | dataset
20 | download
21 | metrics
22 | models
23 | utils
24 | validation
25 |
26 | License
27 | -------
28 |
29 | The cloudcasting software is released under an `MIT License <https://opensource.org/licenses/MIT>`_.
30 |
31 | Index
32 | -----
33 |
34 | * :ref:`genindex`
35 |
--------------------------------------------------------------------------------
/docs/source/metrics.rst:
--------------------------------------------------------------------------------
1 | Metrics
2 | =======
3 |
4 | .. automodule:: cloudcasting.metrics
5 | :members:
6 |
--------------------------------------------------------------------------------
/docs/source/models.rst:
--------------------------------------------------------------------------------
1 | Models
2 | ======
3 |
4 | .. automodule:: cloudcasting.models
5 | :members:
6 |
--------------------------------------------------------------------------------
/docs/source/usage.md:
--------------------------------------------------------------------------------
1 | .. _usage:
2 |
3 | User guide
4 | ==========
5 |
6 | **Contents:**
7 |
8 | - :ref:`install`
9 | - :ref:`optional`
10 | - :ref:`getting_started`
11 |
12 | .. _install:
13 |
14 | Installation
15 | ------------
16 |
17 | To use cloudcasting, first install it using pip:
18 |
19 | ```bash
20 | git clone https://github.com/alan-turing-institute/cloudcasting
21 | cd cloudcasting
22 | python -m pip install .
23 | ```
24 |
25 | .. _optional:
26 |
27 | Optional dependencies
28 | ---------------------
29 |
30 | cloudcasting supports optional dependencies, which are not installed by default. These dependencies are required for certain functionality.
31 |
32 | To run the metrics on GPU:
33 |
34 | ```bash
35 | python -m pip install --upgrade "jax[cuda12]"
36 | ```
37 |
38 | To make changes to the library, it is necessary to install the extra `dev` dependencies, and install pre-commit:
39 |
40 | ```bash
41 | python -m pip install ".[dev]"
42 | pre-commit install
43 | ```
44 |
45 | To create the documentation, it is necessary to install the extra `doc` dependencies:
46 |
47 | ```bash
48 | python -m pip install ".[doc]"
49 | ```
50 |
51 | .. _getting_started:
52 |
53 | Getting started
54 | ---------------
55 |
56 | Use the command line interface to download data:
57 |
58 | ```bash
59 | cloudcasting download "2020-06-01 00:00" "2020-06-30 23:55" "path/to/data/save/dir"
60 | ```
61 |
62 | Once you have developed a model, you can also validate it, calculating a set of metrics on a standard dataset.
63 | To make use of the CLI tool, use the [model GitHub repo template](https://github.com/alan-turing-institute/ocf-model-template) to structure your model correctly for validation.
64 |
65 | ```bash
66 | cloudcasting validate "path/to/config/file.yml" "path/to/model/file.py"
67 | ```
68 |
--------------------------------------------------------------------------------
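
Beyond the CLI, the package exposes a `SatelliteDataModule` (see `src/cloudcasting/dataset.py`) for loading past/future satellite frames in training loops. A minimal sketch, where the zarr filename is a hypothetical placeholder for data fetched with `cloudcasting download`:

```python
# Sketch of loading training batches with SatelliteDataModule.
# The zarr path is a hypothetical placeholder; the timing settings follow
# the project constants (15-minute spacing, 3-hour forecast horizon).
from cloudcasting.dataset import SatelliteDataModule

datamodule = SatelliteDataModule(
    zarr_path="path/to/data/save/dir/2020_nonhrv.zarr",  # placeholder filename
    history_mins=60,
    forecast_mins=180,
    sample_freq_mins=15,
    batch_size=2,
    train_period=["2020-06-01", "2020-06-20"],
    val_period=["2020-06-21", "2020-06-30"],
)

X, y = next(iter(datamodule.train_dataloader()))
# X holds the past frames and y the future frames, each shaped
# (batch, channel, time, height, width) after the dataset's transpose step
print(X.shape, y.shape)
```
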
/docs/source/utils.rst:
--------------------------------------------------------------------------------
1 | Utils
2 | =====
3 |
4 | .. automodule:: cloudcasting.utils
5 | :members:
6 |
--------------------------------------------------------------------------------
/docs/source/validation.rst:
--------------------------------------------------------------------------------
1 | Validation
2 | ==========
3 |
4 | .. automodule:: cloudcasting.validation
5 | :members:
6 |
--------------------------------------------------------------------------------
/examples/find_test_2022_t0_times.py:
--------------------------------------------------------------------------------
1 | """This script finds the 2022 test set t0 times and saves them to the cloudcasting package."""
2 |
3 | import importlib.util
4 | import os
5 |
6 | import numpy as np
7 | import pandas as pd
8 | import xarray as xr
9 |
10 | from cloudcasting.constants import DATA_INTERVAL_SPACING_MINUTES, FORECAST_HORIZON_MINUTES
11 | from cloudcasting.dataset import find_valid_t0_times
12 | from cloudcasting.download import _get_sat_public_dataset_path
13 |
14 | # Set a max history length which we will support in the validation process
15 | # We will not be able to fairly score models which require a longer history than this
16 | # But if we set this too long, we reduce the number of samples we can score on
17 |
18 | # The current FORECAST_HORIZON_MINUTES is 3 hours so we'll set this conservatively to 6 hours
19 | MAX_HISTORY_MINUTES = 6 * 60
20 |
21 | # We filter t0 times so that consecutive times are separated by a gap of at least this length
22 | MIN_GAP_SIZE = pd.Timedelta("1hour")
23 |
24 | # Open the 2022 dataset
25 | ds = xr.open_zarr(_get_sat_public_dataset_path(2022, is_hrv=False))
26 |
27 | # Filter to defined time frequency
28 | mask = np.mod(ds.time.dt.minute, DATA_INTERVAL_SPACING_MINUTES) == 0
29 | ds = ds.sel(time=mask)
30 |
31 | # Mask to the odd fortnights - i.e. the 2022 test set
32 | mask = np.mod(ds.time.dt.dayofyear // 14, 2) == 1
33 | ds = ds.sel(time=mask)
34 |
35 |
36 | # Find the t0 times that we have satellite data for
37 | available_t0_times = find_valid_t0_times(
38 | datetimes=pd.DatetimeIndex(ds.time),
39 | history_mins=MAX_HISTORY_MINUTES,
40 | forecast_mins=FORECAST_HORIZON_MINUTES,
41 | sample_freq_mins=DATA_INTERVAL_SPACING_MINUTES,
42 | )
43 |
44 | # Filter the t0 times so they have gaps of at least 1 hour
45 | _filtered_t0_times = [available_t0_times[0]]
46 |
47 | for t in available_t0_times[1:]:
48 | if (t - _filtered_t0_times[-1]) >= MIN_GAP_SIZE:
49 | _filtered_t0_times.append(t)
50 |
51 | filtered_t0_times = pd.DatetimeIndex(_filtered_t0_times)
52 |
53 |
54 | # Print the valid t0 times to sanity check
55 | print(f"Number of available t0 times: {len(filtered_t0_times)}")
56 | print(f"Actual available t0 times: {filtered_t0_times}")
57 |
58 |
59 | # Find the path of the cloudcasting package so we can save the valid times into it
60 | spec = importlib.util.find_spec("cloudcasting")
61 | if spec and spec.origin:
62 | package_path = os.path.dirname(spec.origin)
63 | else:
64 | msg = "Path of package `cloudcasting` can not be found"
65 | raise ModuleNotFoundError(msg)
66 |
67 | # Save the valid t0 times
68 | filename = "test_2022_t0_times.csv"
69 | df = pd.DataFrame(filtered_t0_times, columns=["t0_time"]).set_index("t0_time")
70 | df.to_csv(
71 | f"{package_path}/data/{filename}.zip",
72 | compression={
73 | "method": "zip",
74 | "archive_name": filename,
75 | },
76 | )
77 |
--------------------------------------------------------------------------------
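
The test-set selection in the script above hinges on the "odd fortnights" mask, `np.mod(ds.time.dt.dayofyear // 14, 2) == 1`. A small standalone sketch (not part of the package) showing what that mask keeps on a synthetic 2022 daily index:

```python
# Illustrates the fortnight mask used in the script above on synthetic dates.
import numpy as np
import pandas as pd

times = pd.date_range("2022-01-01", "2022-02-28", freq="1D")
fortnight = times.dayofyear // 14        # 0 for days 1-13, 1 for days 14-27, ...
mask = np.mod(fortnight, 2) == 1         # keep only the odd fortnights
print(times[mask].min(), "->", times[mask].max())  # 2022-01-14 -> 2022-02-24
```
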
/notebooks/requirements.txt:
--------------------------------------------------------------------------------
1 | ipykernel
2 | matplotlib
3 | seaborn
4 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61", "setuptools_scm[toml]>=7"]
3 | build-backend = "setuptools.build_meta"
4 |
5 |
6 | [project]
7 | name = "cloudcasting"
8 | dynamic = ["version"]
9 | authors = [
10 | { name = "cloudcasting Maintainers", email = "clouds@turing.ac.uk" },
11 | ]
12 | description = "Tooling and infrastructure to enable cloud nowcasting."
13 | readme = "README.md"
14 | requires-python = ">=3.10"
15 | classifiers = [
16 | "Development Status :: 1 - Planning",
17 | "Intended Audience :: Science/Research",
18 | "Intended Audience :: Developers",
19 | "License :: OSI Approved :: MIT License",
20 | "Operating System :: OS Independent",
21 | "Programming Language :: Python",
22 | "Programming Language :: Python :: 3",
23 | "Programming Language :: Python :: 3 :: Only",
24 | "Programming Language :: Python :: 3.10",
25 | "Programming Language :: Python :: 3.11",
26 | "Programming Language :: Python :: 3.12",
27 | "Topic :: Scientific/Engineering",
28 | "Typing :: Typed",
29 | ]
30 | dependencies = [
31 | "gcsfs",
32 | "zarr<3.0.0", # ocf_blosc2 compatibility
33 | "xarray",
34 | "dask",
35 | "pyresample",
36 | "pyproj",
37 | "pykdtree<=1.3.12", # for macOS
38 | "ocf-blosc2>=0.0.10", # for no-import codec register
39 | "typer",
40 | "lightning",
41 | "torch>=2.3.0", # needed for numpy 2.0
42 | "jaxtyping<=0.2.34", # currently >0.2.34 causing errors
43 | "wandb",
44 | "tqdm",
45 | "moviepy==1.0.3", # currently >1.0.3 not working with wandb
46 | "imageio>=2.35.1",
47 | "numpy<2.1.0", # https://github.com/wandb/wandb/issues/8166
48 | "chex",
49 | "matplotlib"
50 | ]
51 | [project.optional-dependencies]
52 | dev = [
53 | "pytest >=6",
54 | "pytest-cov >=3",
55 | "pre-commit",
56 | "scipy",
57 | "pytest-mock",
58 | "scikit-image",
59 | "typeguard",
60 | ]
61 | doc = [
62 | "sphinx",
63 | "sphinx-rtd-theme",
64 | "m2r2"
65 | ]
66 |
67 | [tool.setuptools.package-data]
68 | "cloudcasting" = ["data/*.zip"]
69 |
70 | [tool.setuptools_scm]
71 | write_to = "src/cloudcasting/_version.py"
72 |
73 | [project.scripts]
74 | cloudcasting = "cloudcasting.cli:app"
75 |
76 | [project.urls]
77 | Homepage = "https://github.com/alan-turing-institute/cloudcasting"
78 | "Bug Tracker" = "https://github.com/alan-turing-institute/cloudcasting/issues"
79 | Discussions = "https://github.com/alan-turing-institute/cloudcasting/discussions"
80 | Changelog = "https://github.com/alan-turing-institute/cloudcasting/releases"
81 |
82 | [tool.pytest.ini_options]
83 | minversion = "6.0"
84 | addopts = [
85 | "-ra",
86 | "--showlocals",
87 | "--strict-markers",
88 | "--strict-config"
89 | ]
90 | xfail_strict = true
91 | filterwarnings = [
92 | "error",
93 | "ignore:pkg_resources:DeprecationWarning", # lightning
94 | "ignore:Deprecated call to `pkg_resources.declare_namespace:DeprecationWarning", # lightning
95 | "ignore:ast.Str is deprecated:DeprecationWarning", # jaxtyping
96 | "ignore:`newshape` keyword argument is deprecated:DeprecationWarning", # wandb using numpy 2.1.0
97 | "ignore:The keyword `fps` is no longer supported:DeprecationWarning", # wandb.Video
98 | "ignore:torch.onnx.dynamo_export is deprecated since 2.6.0:DeprecationWarning", # lighning fabric torch 2.6+
99 | ]
100 | log_cli_level = "INFO"
101 | testpaths = [
102 | "tests",
103 | ]
104 |
105 | [tool.coverage]
106 | run.source = ["cloudcasting"]
107 | report.exclude_lines = [
108 | 'pragma: no cover',
109 | '\.\.\.',
110 | 'if typing.TYPE_CHECKING:',
111 | ]
112 |
113 | [tool.mypy]
114 | files = ["src", "tests"]
115 | python_version = "3.10"
116 | show_error_codes = true
117 | warn_unreachable = true
118 | disallow_untyped_defs = false
119 | disallow_incomplete_defs = false
120 | check_untyped_defs = true
121 | strict = true
122 | enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
123 |
124 | [[tool.mypy.overrides]]
125 | module = "cloudcasting.*"
126 | disallow_untyped_defs = true
127 | disallow_incomplete_defs = true
128 |
129 | [[tool.mypy.overrides]]
130 | module = [
131 | "ocf_blosc2",
132 | ]
133 | ignore_missing_imports = true
134 |
135 | [[tool.mypy.overrides]]
136 | module = [
137 | "cloudcasting.download",
138 | "cloudcasting.cli",
139 | "cloudcasting.validation", # use of wandb.update/Table
140 | ]
141 | disallow_untyped_calls = false
142 |
143 | [tool.ruff]
144 | src = ["src"]
145 | exclude = ["notebooks/*.ipynb"]
146 | line-length = 100 # how long you want lines to be
147 |
148 | [tool.ruff.format]
149 | docstring-code-format = true # code snippets in docstrings will be formatted
150 |
151 | [tool.ruff.lint]
152 | exclude = ["notebooks/*.ipynb"]
153 | select = [
154 | "E", "F", "W", # flake8
155 | "B", # flake8-bugbear
156 | "I", # isort
157 | "ARG", # flake8-unused-arguments
158 | "C4", # flake8-comprehensions
159 | "EM", # flake8-errmsg
160 | "ICN", # flake8-import-conventions
161 | "ISC", # flake8-implicit-str-concat
162 | "G", # flake8-logging-format
163 | "PGH", # pygrep-hooks
164 | "PIE", # flake8-pie
165 | "PL", # pylint
166 | "PT", # flake8-pytest-style
167 | "RET", # flake8-return
168 | "RUF", # Ruff-specific
169 | "SIM", # flake8-simplify
170 | "UP", # pyupgrade
171 | "YTT", # flake8-2020
172 | "EXE", # flake8-executable
173 | ]
174 | ignore = [
175 | "PLR", # Design related pylint codes
176 | "ISC001", # Conflicts with formatter
177 | "F722" # Marks jaxtyping syntax annotations as incorrect
178 | ]
179 | unfixable = [
180 | "F401", # Would remove unused imports
181 | "F841", # Would remove unused variables
182 | ]
183 | flake8-unused-arguments.ignore-variadic-names = true # allow unused *args/**kwargs
184 |
--------------------------------------------------------------------------------
/src/cloudcasting/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | cloudcasting: Tooling and infrastructure to enable cloud nowcasting.
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | from jaxtyping import install_import_hook
8 |
9 | # Any module imported inside this `with` block, whose
10 | # name begins with the specified string, will
11 | # automatically have both `@jaxtyped` and the
12 | # typechecker applied to all of their functions and
13 | # dataclasses, meaning that they will be type-checked
14 | # (and therefore shape-checked via jaxtyping) at runtime.
15 | with install_import_hook("cloudcasting", "typeguard.typechecked"):
16 | from cloudcasting import models, validation
17 |
18 | from cloudcasting import cli, dataset, download, metrics
19 |
20 | from ._version import version as __version__
21 |
22 | __all__ = (
23 | "__version__",
24 | "cli",
25 | "dataset",
26 | "download",
27 | "metrics",
28 | "models",
29 | "validation",
30 | )
31 |
--------------------------------------------------------------------------------
/src/cloudcasting/__main__.py:
--------------------------------------------------------------------------------
1 | from cloudcasting.cli import app
2 |
3 | app()
4 |
--------------------------------------------------------------------------------
/src/cloudcasting/_version.pyi:
--------------------------------------------------------------------------------
1 | version: str
2 |
--------------------------------------------------------------------------------
/src/cloudcasting/cli.py:
--------------------------------------------------------------------------------
1 | import typer
2 |
3 | from cloudcasting.download import download_satellite_data
4 | from cloudcasting.validation import validate_from_config
5 |
6 | # typer app code
7 | app = typer.Typer()
8 | app.command("download")(download_satellite_data)
9 | app.command("validate")(validate_from_config)
10 |
--------------------------------------------------------------------------------
/src/cloudcasting/constants.py:
--------------------------------------------------------------------------------
1 | __all__ = (
2 | "CUTOUT_MASK",
3 | "DATA_INTERVAL_SPACING_MINUTES",
4 | "FORECAST_HORIZON_MINUTES",
5 | "NUM_CHANNELS",
6 | "NUM_FORECAST_STEPS",
7 | )
8 |
9 | from cloudcasting.utils import create_cutout_mask
10 |
11 | # These constants were locked as part of the project specification
12 | # 3 hour forecast horizon
13 | FORECAST_HORIZON_MINUTES = 180
14 | # at 15 minute intervals
15 | DATA_INTERVAL_SPACING_MINUTES = 15
16 | # gives 12 forecast steps
17 | NUM_FORECAST_STEPS = FORECAST_HORIZON_MINUTES // DATA_INTERVAL_SPACING_MINUTES
18 | # for all 11 low resolution channels
19 | NUM_CHANNELS = 11
20 |
21 | # Constants for the larger (original) image
22 | # Image size (height, width)
23 | IMAGE_SIZE_TUPLE = (372, 614)
24 | # Cutout mask (min x, max x, min y, max y)
25 | CUTOUT_MASK_BOUNDARY = (166, 336, 107, 289)
26 | # Create cutout mask
27 | CUTOUT_MASK = create_cutout_mask(CUTOUT_MASK_BOUNDARY, IMAGE_SIZE_TUPLE)
28 |
29 | # Constants for the smaller (cropped) image
30 | # Cropped image size (height, width)
31 | CROPPED_IMAGE_SIZE_TUPLE = (278, 385)
32 | # Cropped cutout mask (min x, max x, min y, max y)
33 | CROPPED_CUTOUT_MASK_BOUNDARY = (109, 279, 62, 244)
34 | # Create cropped cutout mask
35 | CROPPED_CUTOUT_MASK = create_cutout_mask(CROPPED_CUTOUT_MASK_BOUNDARY, CROPPED_IMAGE_SIZE_TUPLE)
36 |
--------------------------------------------------------------------------------
/src/cloudcasting/data/test_2022_t0_times.csv.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alan-turing-institute/cloudcasting/0da4087bba01823d2477dba61ea3e6cfd557212c/src/cloudcasting/data/test_2022_t0_times.csv.zip
--------------------------------------------------------------------------------
/src/cloudcasting/dataset.py:
--------------------------------------------------------------------------------
1 | """Dataset and DataModule for past and future satellite data"""
2 |
3 | __all__ = (
4 | "SatelliteDataModule",
5 | "SatelliteDataset",
6 | "ValidationSatelliteDataset",
7 | )
8 |
9 | import io
10 | import pkgutil
11 | from datetime import datetime, timedelta
12 | from typing import TypedDict
13 |
14 | import numpy as np
15 | import pandas as pd
16 | import xarray as xr
17 | from lightning import LightningDataModule
18 | from numpy.typing import NDArray
19 | from torch.utils.data import DataLoader, Dataset
20 |
21 | from cloudcasting.constants import DATA_INTERVAL_SPACING_MINUTES, FORECAST_HORIZON_MINUTES
22 | from cloudcasting.utils import find_contiguous_t0_time_periods, find_contiguous_time_periods
23 |
24 |
25 | class DataloaderArgs(TypedDict):
26 | batch_size: int
27 | sampler: None
28 | batch_sampler: None
29 | num_workers: int
30 | pin_memory: bool
31 | drop_last: bool
32 | timeout: int
33 | worker_init_fn: None
34 | prefetch_factor: int | None
35 | persistent_workers: bool
36 |
37 |
38 | def load_satellite_zarrs(zarr_path: list[str] | tuple[str] | str) -> xr.Dataset:
39 | """Load the satellite data
40 |
41 | Args:
42 | zarr_path: The path to the satellite zarr(s)
43 | """
44 |
45 | if isinstance(zarr_path, list | tuple):
46 | ds = xr.concat(
47 | [xr.open_dataset(path, engine="zarr", chunks="auto") for path in zarr_path],
48 | dim="time",
49 | coords="minimal",
50 | compat="identical",
51 | combine_attrs="override",
52 | ).sortby("time")
53 | else:
54 | ds = xr.open_dataset(zarr_path, engine="zarr", chunks="auto").sortby("time")
55 |
56 | return ds
57 |
58 |
59 | def find_valid_t0_times(
60 | datetimes: pd.DatetimeIndex,
61 | history_mins: int,
62 | forecast_mins: int,
63 | sample_freq_mins: int,
64 | ) -> pd.DatetimeIndex:
65 | """Constuct an array of all t0 times which are valid considering the gaps in the sat data"""
66 |
67 | # Find periods where we have contiguous time steps
68 | contiguous_time_periods = find_contiguous_time_periods(
69 | datetimes=datetimes,
70 | min_seq_length=int((history_mins + forecast_mins) / sample_freq_mins) + 1,
71 | max_gap_duration=timedelta(minutes=sample_freq_mins),
72 | )
73 |
74 | # Find periods of valid init-times
75 | contiguous_t0_periods = find_contiguous_t0_time_periods(
76 | contiguous_time_periods=contiguous_time_periods,
77 | history_duration=timedelta(minutes=history_mins),
78 | forecast_duration=timedelta(minutes=forecast_mins),
79 | )
80 |
81 | valid_t0_times = []
82 | for _, row in contiguous_t0_periods.iterrows():
83 | valid_t0_times.append(
84 | pd.date_range(row["start_dt"], row["end_dt"], freq=f"{sample_freq_mins}min")
85 | )
86 |
87 | return pd.to_datetime(np.concatenate(valid_t0_times))
88 |
89 |
90 | DataIndex = str | datetime | pd.Timestamp | int
91 |
92 |
93 | class SatelliteDataset(Dataset[tuple[NDArray[np.float32], NDArray[np.float32]]]):
94 | def __init__(
95 | self,
96 | zarr_path: list[str] | str,
97 | start_time: str | None,
98 | end_time: str | None,
99 | history_mins: int,
100 | forecast_mins: int,
101 | sample_freq_mins: int,
102 | variables: list[str] | str | None = None,
103 | preshuffle: bool = False,
104 | nan_to_num: bool = False,
105 | ):
106 | """A torch Dataset for loading past and future satellite data
107 |
108 | Args:
109 | zarr_path (list[str] | str): Path to the satellite data. Can be a string or list
110 | start_time (str): The satellite data is filtered to exclude timestamps before this
111 | end_time (str): The satellite data is filtered to exclude timestamps after this
112 | history_mins (int): How many minutes of history will be used as input features
113 | forecast_mins (int): How many minutes of future will be used as target features
114 | sample_freq_mins (int): The sample frequency to use for the satellite data
115 | variables (list[str] | str): The variables to load from the satellite data
116 | (defaults to all)
117 | preshuffle (bool): Whether to shuffle the data - useful for validation.
118 | Defaults to False.
119 | nan_to_num (bool): Whether to convert NaNs to -1. Defaults to False.
120 | """
121 |
122 | # Load the sat zarr file or list of files and slice the data to the given period
123 | ds = load_satellite_zarrs(zarr_path).sel(time=slice(start_time, end_time))
124 |
125 | if variables is not None:
126 | if isinstance(variables, str):
127 | variables = [variables]
128 | self.ds = ds.sel(variable=variables)
129 | else:
130 | self.ds = ds
131 |
132 | # Convert the satellite data to the given time frequency by selection
133 | mask = np.mod(self.ds.time.dt.minute, sample_freq_mins) == 0
134 | self.ds = self.ds.sel(time=mask)
135 |
136 | # Find the valid t0 times for the available data. This avoids trying to take samples where
137 | # there would be a missing timestamp in the sat data required for the sample
138 | self.t0_times = self._find_t0_times(
139 | pd.DatetimeIndex(self.ds.time), history_mins, forecast_mins, sample_freq_mins
140 | )
141 |
142 | if preshuffle:
143 | self.t0_times = pd.to_datetime(np.random.permutation(self.t0_times))
144 |
145 | self.history_mins = history_mins
146 | self.forecast_mins = forecast_mins
147 | self.sample_freq_mins = sample_freq_mins
148 | self.nan_to_num = nan_to_num
149 |
150 | @staticmethod
151 | def _find_t0_times(
152 | date_range: pd.DatetimeIndex, history_mins: int, forecast_mins: int, sample_freq_mins: int
153 | ) -> pd.DatetimeIndex:
154 | return find_valid_t0_times(date_range, history_mins, forecast_mins, sample_freq_mins)
155 |
156 | def __len__(self) -> int:
157 | return len(self.t0_times)
158 |
159 | def _get_datetime(self, t0: datetime) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
160 | ds_sel = self.ds.sel(
161 | time=slice(
162 | t0 - timedelta(minutes=self.history_mins),
163 | t0 + timedelta(minutes=self.forecast_mins),
164 | )
165 | )
166 |
167 | # Load the data eagerly so that the same chunks aren't loaded multiple times after we split
168 | # further
169 | ds_sel = ds_sel.compute(scheduler="single-threaded")
170 |
171 | # Reshape to (channel, time, height, width)
172 | ds_sel = ds_sel.transpose("variable", "time", "y_geostationary", "x_geostationary")
173 |
174 | ds_input = ds_sel.sel(time=slice(None, t0))
175 | ds_target = ds_sel.sel(time=slice(t0 + timedelta(minutes=self.sample_freq_mins), None))
176 |
177 | # Convert to arrays
178 | X = ds_input.data.values
179 | y = ds_target.data.values
180 |
181 | if self.nan_to_num:
182 | X = np.nan_to_num(X, nan=-1)
183 | y = np.nan_to_num(y, nan=-1)
184 |
185 | return X.astype(np.float32), y.astype(np.float32)
186 |
187 | def __getitem__(self, key: DataIndex) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
188 | if isinstance(key, int):
189 | t0 = self.t0_times[key]
190 |
191 | else:
192 | assert isinstance(key, str | datetime | pd.Timestamp)
193 | t0 = pd.Timestamp(key)
194 | assert t0 in self.t0_times
195 |
196 | return self._get_datetime(t0)
197 |
198 |
199 | class ValidationSatelliteDataset(SatelliteDataset):
200 | def __init__(
201 | self,
202 | zarr_path: list[str] | str,
203 | history_mins: int,
204 | forecast_mins: int = FORECAST_HORIZON_MINUTES,
205 | sample_freq_mins: int = DATA_INTERVAL_SPACING_MINUTES,
206 | nan_to_num: bool = False,
207 | ):
208 | """A torch Dataset used only in the validation proceedure.
209 |
210 | Args:
211 | zarr_path (list[str] | str): Path to the satellite data for validation.
212 | Can be a string or list
213 | history_mins (int): How many minutes of history will be used as input features
214 | forecast_mins (int): How many minutes of future will be used as target features
215 | sample_freq_mins (int): The sample frequency to use for the satellite data
216 | nan_to_num (bool): Whether to convert NaNs to -1. Defaults to False.
217 | """
218 |
219 | super().__init__(
220 | zarr_path=zarr_path,
221 | start_time=None,
222 | end_time=None,
223 | history_mins=history_mins,
224 | forecast_mins=forecast_mins,
225 | sample_freq_mins=sample_freq_mins,
226 | preshuffle=True,
227 | nan_to_num=nan_to_num,
228 | )
229 |
230 | @staticmethod
231 | def _find_t0_times(
232 | date_range: pd.DatetimeIndex, history_mins: int, forecast_mins: int, sample_freq_mins: int
233 | ) -> pd.DatetimeIndex:
234 | # Find the valid t0 times for the available data. This avoids trying to take samples where
235 | # there would be a missing timestamp in the sat data required for the sample
236 | available_t0_times = find_valid_t0_times(
237 | date_range, history_mins, forecast_mins, sample_freq_mins
238 | )
239 |
240 | # Get the required 2022 test dataset t0 times
241 | val_t0_times_from_csv = ValidationSatelliteDataset._get_test_2022_t0_times()
242 |
243 | # Find the intersection of the available t0 times and the required validation t0 times
244 | val_time_available = val_t0_times_from_csv.isin(available_t0_times)
245 |
246 | # Make sure all of the required validation times are available in the data
247 | if not val_time_available.all():
248 | msg = (
249 | "The following validation t0 times are not available in the satellite data: \n"
250 | f"{val_t0_times_from_csv[~val_time_available]}\n"
251 | "The validation proceedure requires these t0 times to be available."
252 | )
253 | raise ValueError(msg)
254 |
255 | return val_t0_times_from_csv
256 |
257 | @staticmethod
258 | def _get_t0_times(path: str) -> pd.DatetimeIndex:
259 | """Load the required validation t0 times from library path"""
260 |
261 | # Load the zipped csv file as a byte stream
262 | data = pkgutil.get_data("cloudcasting", path)
263 | if data is not None:
264 | byte_stream = io.BytesIO(data)
265 | else:
266 | # Handle the case where data is None
267 | msg = f"No data found for path: {path}"
268 | raise ValueError(msg)
269 |
270 | # Load the times into pandas
271 | df = pd.read_csv(byte_stream, encoding="utf8", compression="zip")
272 |
273 | return pd.DatetimeIndex(df.t0_time)
274 |
275 | @staticmethod
276 | def _get_test_2022_t0_times() -> pd.DatetimeIndex:
277 | """Load the required 2022 test dataset t0 times from their location in the library"""
278 | return ValidationSatelliteDataset._get_t0_times("data/test_2022_t0_times.csv.zip")
279 |
280 | @staticmethod
281 | def _get_verify_2023_t0_times() -> pd.DatetimeIndex:
282 | msg = "The required 2023 verification dataset t0 times are not available"
283 | raise NotImplementedError(msg)
284 |
285 |
286 | class SatelliteDataModule(LightningDataModule):
287 | def __init__(
288 | self,
289 | zarr_path: list[str] | str,
290 | history_mins: int,
291 | forecast_mins: int,
292 | sample_freq_mins: int,
293 | batch_size: int = 16,
294 | num_workers: int = 0,
295 | variables: list[str] | str | None = None,
296 | prefetch_factor: int | None = None,
297 | train_period: list[str | None] | tuple[str | None] | None = None,
298 | val_period: list[str | None] | tuple[str | None] | None = None,
299 | test_period: list[str | None] | tuple[str | None] | None = None,
300 | nan_to_num: bool = False,
301 | pin_memory: bool = False,
302 | persistent_workers: bool = False,
303 | ):
304 | """A lightning DataModule for loading past and future satellite data
305 |
306 | Args:
307 | zarr_path (list[str] | str): Path to the satellite data. Can be a string or list
308 | history_mins (int): How many minutes of history will be used as input features
309 | forecast_mins (int): How many minutes of future will be used as target features
310 | sample_freq_mins (int): The sample frequency to use for the satellite data
311 | batch_size (int): Batch size. Defaults to 16.
312 | num_workers (int): Number of workers to use in multiprocess batch loading.
313 | Defaults to 0.
314 | variables (list[str] | str): The variables to load from the satellite data
315 | (defaults to all)
316 | prefetch_factor (int): Number of batches loaded in advance by each worker process
317 | train_period (list[str] | tuple[str] | None): Date range filter for train dataloader
318 | val_period (list[str] | tuple[str] | None): Date range filter for validation dataloader
319 | test_period (list[str] | tuple[str] | None): Date range filter for test dataloader
320 | nan_to_num (bool): Whether to convert NaNs to -1. Defaults to False.
321 | pin_memory (bool): If True, the data loader will copy Tensors into device/CUDA
322 | pinned memory before returning them. Defaults to False.
323 | persistent_workers (bool): If True, the data loader will not shut down the worker
324 | processes after a dataset has been consumed once. This allows you to keep the
325 | workers Dataset instances alive. Defaults to False.
326 | """
327 | super().__init__()
328 |
329 | if train_period is None:
330 | train_period = [None, None]
331 | if val_period is None:
332 | val_period = [None, None]
333 | if test_period is None:
334 | test_period = [None, None]
335 |
336 | assert len(train_period) == 2
337 | assert len(val_period) == 2
338 | assert len(test_period) == 2
339 |
340 | self.train_period = train_period
341 | self.val_period = val_period
342 | self.test_period = test_period
343 |
344 | self.zarr_path = zarr_path
345 | self.history_mins = history_mins
346 | self.forecast_mins = forecast_mins
347 | self.sample_freq_mins = sample_freq_mins
348 |
349 | self._common_dataloader_kwargs = DataloaderArgs(
350 | batch_size=batch_size,
351 | sampler=None,
352 | batch_sampler=None,
353 | num_workers=num_workers,
354 | pin_memory=pin_memory,
355 | drop_last=False,
356 | timeout=0,
357 | worker_init_fn=None,
358 | prefetch_factor=prefetch_factor,
359 | persistent_workers=persistent_workers,
360 | )
361 |
362 | self.nan_to_num = nan_to_num
363 | self.variables = variables
364 |
365 | def _make_dataset(
366 | self, start_date: str | None, end_date: str | None, preshuffle: bool = False
367 | ) -> SatelliteDataset:
368 | return SatelliteDataset(
369 | zarr_path=self.zarr_path,
370 | start_time=start_date,
371 | end_time=end_date,
372 | history_mins=self.history_mins,
373 | forecast_mins=self.forecast_mins,
374 | sample_freq_mins=self.sample_freq_mins,
375 | preshuffle=preshuffle,
376 | nan_to_num=self.nan_to_num,
377 | variables=self.variables,
378 | )
379 |
380 | def train_dataloader(self) -> DataLoader[tuple[NDArray[np.float32], NDArray[np.float32]]]:
381 | """Construct train dataloader"""
382 | dataset = self._make_dataset(self.train_period[0], self.train_period[1])
383 | return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs)
384 |
385 | def val_dataloader(self) -> DataLoader[tuple[NDArray[np.float32], NDArray[np.float32]]]:
386 | """Construct validation dataloader"""
387 | dataset = self._make_dataset(self.val_period[0], self.val_period[1], preshuffle=True)
388 | return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)
389 |
390 | def test_dataloader(self) -> DataLoader[tuple[NDArray[np.float32], NDArray[np.float32]]]:
391 | """Construct test dataloader"""
392 | dataset = self._make_dataset(self.test_period[0], self.test_period[1])
393 | return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)
394 |
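A minimal usage sketch of the DataModule above (illustrative only: the class name `SatelliteDataModule`, the zarr path, and the parameter values are assumptions, so substitute whatever `cloudcasting.dataset` actually exports):

# Sketch -- class name, data path, and values below are assumptions, not taken from this file.
from cloudcasting.dataset import SatelliteDataModule  # hypothetical name; use the real class

datamodule = SatelliteDataModule(
    zarr_path="data/2021_training_nonhrv.zarr",  # placeholder path
    history_mins=60,
    forecast_mins=180,
    sample_freq_mins=15,
    batch_size=4,
    num_workers=2,
    train_period=["2021-01-01", "2021-06-30"],
)
X, y = next(iter(datamodule.train_dataloader()))
# X: [batch, channels, time, height, width], y: [batch, channels, rollout_steps, height, width]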
--------------------------------------------------------------------------------
/src/cloudcasting/download.py:
--------------------------------------------------------------------------------
1 | __all__ = ("download_satellite_data",)
2 |
3 | import logging
4 | import os
5 | from typing import Annotated
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import typer
10 | import xarray as xr
11 | from dask.diagnostics import ProgressBar # type: ignore[attr-defined]
12 |
13 | from cloudcasting.utils import lon_lat_to_geostationary_area_coords
14 |
15 | xr.set_options(keep_attrs=True)
16 |
17 | logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | def _get_sat_public_dataset_path(year: int, is_hrv: bool = False) -> str:
22 | """
23 | Get the path to the Google Public Dataset of EUMETSAT satellite data.
24 |
25 | Args:
26 | year: The year of the dataset.
27 | is_hrv: Whether to get the HRV dataset or not.
28 |
29 | Returns:
30 | The path to the dataset.
31 | """
32 | file_end = "hrv.zarr" if is_hrv else "nonhrv.zarr"
33 | return f"gs://public-datasets-eumetsat-solar-forecasting/satellite/EUMETSAT/SEVIRI_RSS/v4/{year}_{file_end}"
34 |
35 |
36 | def download_satellite_data(
37 | start_date: Annotated[str, typer.Argument(help="Start date in 'YYYY-MM-DD HH:MM' format")],
38 | end_date: Annotated[str, typer.Argument(help="End date in 'YYYY-MM-DD HH:MM' format")],
39 | output_directory: Annotated[str, typer.Argument(help="Directory to save the satellite data")],
40 | download_frequency: Annotated[
41 | str, typer.Option(help="Frequency to download data in pandas datetime format")
42 | ] = "15min",
43 | get_hrv: Annotated[bool, typer.Option(help="Whether to download HRV data")] = False,
44 | override_date_bounds: Annotated[
45 | bool, typer.Option(help="Whether to override date range limits")
46 | ] = False,
47 | lon_min: Annotated[float, typer.Option(help="Minimum longitude")] = -16,
48 | lon_max: Annotated[float, typer.Option(help="Maximum longitude")] = 10,
49 | lat_min: Annotated[float, typer.Option(help="Minimum latitude")] = 45,
50 | lat_max: Annotated[float, typer.Option(help="Maximum latitude")] = 70,
51 | test_2022_set: Annotated[
52 | bool,
53 | typer.Option(
54 | help="Whether to filter data from 2022 to download the test set (every 2 weeks)."
55 | ),
56 | ] = False,
57 | verify_2023_set: Annotated[
58 | bool,
59 | typer.Option(
60 | help="Whether to download the verification data from 2023. Only used at project end"
61 | ),
62 | ] = False,
63 | ) -> None:
64 | """
65 | Download a selection of the available EUMETSAT data.
66 |
67 | Each calendar year of data within the supplied date range will be saved to a separate file in
68 | the output directory.
69 |
70 | Args:
71 | start_date (str): First datetime (inclusive) to download in 'YYYY-MM-DD HH:MM' format
72 | end_date (str): Last datetime (inclusive) to download in 'YYYY-MM-DD HH:MM' format
73 | output_directory (str): Directory to which the satellite data should be saved
74 | download_frequency (str): Frequency to download data in pandas datetime format.
75 | Defaults to "15min".
76 | get_hrv (bool): Whether to download the HRV data, otherwise only non-HRV is downloaded.
77 | Defaults to False.
78 | override_date_bounds (bool): Whether to override the date range limits
79 | lon_min (float): The west-most longitude (in degrees) of the bounding box to download.
80 | Defaults to -16.
81 | lon_max (float): The east-most longitude (in degrees) of the bounding box to download.
82 | Defaults to 10.
83 | lat_min (float): The south-most latitude (in degrees) of the bounding box to download.
84 | Defaults to 45.
85 | lat_max (float): The north-most latitude (in degrees) of the bounding box to download.
86 | Defaults to 70.
87 | test_2022_set (bool): Whether to filter data from 2022 to download the test set
88 | (every 2 weeks)
89 | verify_2023_set (bool): Whether to download verification data from 2023. Only
90 | used at project end.
91 |
92 | Raises:
93 | FileNotFoundError: If the output directory doesn't exist.
94 | ValueError: If there are issues with the date range or if output files already exist.
95 | """
96 |
97 | # Check output directory exists
98 | if not os.path.isdir(output_directory):
99 | msg = (
100 | f"Output directory {output_directory} does not exist. "
101 | "Please create it before attempting to download satellite data."
102 | )
103 | raise FileNotFoundError(msg)
104 |
105 |     # Build the formattable string for the output file path.
106 |     # We can insert the year later using `output_file_root.format(year=year)`
107 | output_file_root = output_directory + "/{year}_"
108 |
109 | # Add training split label
110 | if test_2022_set:
111 | output_file_root += "test_"
112 | elif verify_2023_set:
113 | output_file_root += "verification_"
114 | else:
115 | output_file_root += "training_"
116 |
117 | # Add HRV or non-HRV label and file extension
118 | if get_hrv:
119 | output_file_root += "hrv.zarr"
120 | else:
121 | output_file_root += "nonhrv.zarr"
122 |
123 | # Check download frequency is valid (i.e. is a pandas frequency + multiple of 5 minutes)
124 | if pd.Timedelta(download_frequency) % pd.Timedelta("5min") != pd.Timedelta(0):
125 | msg = (
126 | f"Download frequency {download_frequency} is not a multiple of 5 minutes. "
127 | "Please choose a valid frequency."
128 | )
129 | raise ValueError(msg)
130 |
131 | start_date_stamp = pd.Timestamp(start_date)
132 | end_date_stamp = pd.Timestamp(end_date)
133 |
134 | # Check start date is before end date
135 | if start_date_stamp > end_date_stamp:
136 |         msg = f"Start date ({start_date_stamp}) must be before end date ({end_date_stamp})."
137 | raise ValueError(msg)
138 |
139 | # Check date range for known limitations
140 | if not override_date_bounds and start_date_stamp.year < 2019:
141 | msg = (
142 | "There are currently some issues with the EUMETSAT data before 2019/01/01. "
143 | "We recommend only using data from this date forward. "
144 | "To override this error set `override_date_bounds=True`"
145 | )
146 | raise ValueError(msg)
147 |
148 | # Check the year is 2022 if test_2022 data is being downloaded
149 | if test_2022_set and (start_date_stamp.year != 2022 or end_date_stamp.year != 2022):
150 | msg = "Test data is only defined for 2022"
151 | raise ValueError(msg)
152 |
153 | # Check the start / end dates are correct if verification data is being downloaded
154 | if verify_2023_set and (
155 | start_date_stamp != pd.Timestamp("2023-01-01 00:00")
156 | or end_date_stamp != pd.Timestamp("2023-12-31 23:55")
157 | ):
158 | msg = (
159 |             "Verification data requires a start date of '2023-01-01 00:00' "
160 | "and an end date of '2023-12-31 23:55'"
161 | )
162 | raise ValueError(msg)
163 |
164 | # Check the year 2023 is not included unless verification data is being downloaded
165 | if (not verify_2023_set) and (end_date_stamp.year >= 2023):
166 | msg = "2023 data is reserved for the verification process"
167 | raise ValueError(msg)
168 |
169 | years = range(start_date_stamp.year, end_date_stamp.year + 1)
170 |
171 |     # Ceil the start date to the nearest multiple of the download frequency. This breaks down
172 |     # for frequencies longer than a day because pandas anchors to the Unix epoch (1970-01-01, a
173 |     # Thursday), e.g. 2022-01-01 ceiled to 1 week gives 2022-01-06 (the first Thursday after it).
174 | range_start = (
175 | start_date_stamp.ceil(download_frequency)
176 | if pd.Timedelta(download_frequency) <= pd.Timedelta("1day")
177 | else start_date_stamp
178 | )
179 | # Create a list of dates to download
180 | dates_to_download = pd.date_range(range_start, end_date_stamp, freq=download_frequency)
181 |
182 | # Check that none of the filenames we will save to already exist
183 | for year in years:
184 | output_zarr_file = output_file_root.format(year=year)
185 | if os.path.exists(output_zarr_file):
186 | msg = (
187 | f"The zarr file {output_zarr_file} already exists. "
188 | "This function will not overwrite data."
189 | )
190 | raise ValueError(msg)
191 |
192 | # Begin download loop
193 | for year in years:
194 | logger.info("Downloading data from %s", year)
195 | path = _get_sat_public_dataset_path(year, is_hrv=get_hrv)
196 |
197 | # Slice the data from this year which are between the start and end dates.
198 | ds = xr.open_zarr(path, chunks={}).sortby("time")
199 |
200 | ds = ds.sel(time=dates_to_download[dates_to_download.isin(ds.time.values)])
201 |
202 | if year == 2022:
203 | set_str = "Test_2022" if test_2022_set else "Training"
204 | day_str = "14" if test_2022_set else "1"
205 | logger.info("Data in 2022 will be downloaded every 2 weeks due to train/test split.")
206 | logger.info("%s set selected: Starting day will be %s.", set_str, day_str)
207 | # Integer division by 14 will tell us the fortnight we're on.
208 | # checking the mod wrt 2 will let us select every 2 weeks
209 | # Test set is defined as from week 2-3, 6-7 etc.
210 | # Weeks 0-1, 4-5 etc. are included in training set
211 | if test_2022_set:
212 | mask = np.mod(ds.time.dt.dayofyear // 14, 2) == 1
213 | else:
214 | mask = np.mod(ds.time.dt.dayofyear // 14, 2) == 0
215 | ds = ds.sel(time=mask)
216 |
217 | # Convert lon-lat bounds to geostationary-coords
218 | (x_min, x_max), (y_min, y_max) = lon_lat_to_geostationary_area_coords(
219 | [lon_min, lon_max],
220 | [lat_min, lat_max],
221 | ds.data,
222 | )
223 |
224 | # Define the spatial area to slice from
225 | ds = ds.sel(
226 | x_geostationary=slice(x_max, x_min), # x-axis is in decreasing order
227 | y_geostationary=slice(y_min, y_max),
228 | )
229 |
230 | # Re-chunking
231 | for v in ds.variables:
232 | if "chunks" in ds[v].encoding:
233 | del ds[v].encoding["chunks"]
234 |
235 | target_chunks_dict = {
236 | "time": 2,
237 | "x_geostationary": -1,
238 | "y_geostationary": -1,
239 | "variable": -1,
240 | }
241 |
242 | ds = ds.chunk(target_chunks_dict)
243 |
244 | # Save data to zarr
245 | output_zarr_file = output_file_root.format(year=year)
246 | with ProgressBar(dt=1):
247 | ds.to_zarr(output_zarr_file)
248 | logger.info("Data for %s saved to %s", year, output_zarr_file)
249 |
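For orientation, a sketch of calling this function directly from Python rather than through the CLI (the dates and output directory below are placeholders):

# Sketch: download one month of non-HRV data at 15-minute frequency (placeholder values).
from cloudcasting.download import download_satellite_data

download_satellite_data(
    start_date="2021-06-01 00:00",
    end_date="2021-06-30 23:55",
    output_directory="data",      # must already exist, otherwise FileNotFoundError is raised
    download_frequency="15min",   # must be a multiple of 5 minutes
    get_hrv=False,
)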
--------------------------------------------------------------------------------
/src/cloudcasting/metrics.py:
--------------------------------------------------------------------------------
1 | # mypy: ignore-errors
2 | ### Adapted by The Alan Turing Institute and Open Climate Fix, 2024.
3 | ### Original source: https://github.com/google-deepmind/dm_pix
4 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | """Functions to compare image pairs.
18 |
19 | All functions expect float-encoded images, with values in [0, 1], with NHWC
20 | shapes. Each image metric function returns a scalar for each image pair.
21 | """
22 |
23 | from collections.abc import Callable
24 |
25 | import chex
26 | import jax
27 | import jax.numpy as jnp
28 |
29 | # DO NOT REMOVE - Logging lib.
30 |
31 |
32 | def mae(a: chex.Array, b: chex.Array, ignore_nans: bool = False) -> chex.Numeric:
33 | """Returns the Mean Absolute Error between `a` and `b`.
34 |
35 | Args:
36 | a (chex.Array): First image (or set of images)
37 | b (chex.Array): Second image (or set of images)
38 | ignore_nans (bool): Defaults to False
39 |
40 | Returns:
41 | chex.Numeric: MAE between `a` and `b`
42 | """
43 | # DO NOT REMOVE - Logging usage.
44 |
45 | chex.assert_rank([a, b], {3, 4})
46 | chex.assert_type([a, b], float)
47 | chex.assert_equal_shape([a, b])
48 | if ignore_nans:
49 | return jnp.nanmean(jnp.abs(a - b), axis=(-3, -2, -1))
50 | return jnp.mean(jnp.abs(a - b), axis=(-3, -2, -1))
51 |
52 |
53 | def mse(a: chex.Array, b: chex.Array, ignore_nans: bool = False) -> chex.Numeric:
54 | """Returns the Mean Squared Error between `a` and `b`.
55 |
56 | Args:
57 | a (chex.Array): First image (or set of images)
58 | b (chex.Array): Second image (or set of images)
59 | ignore_nans (bool): Defaults to False
60 |
61 | Returns:
62 | chex.Numeric: MSE between `a` and `b`
63 | """
64 | # DO NOT REMOVE - Logging usage.
65 |
66 | chex.assert_rank([a, b], {3, 4})
67 | chex.assert_type([a, b], float)
68 | chex.assert_equal_shape([a, b])
69 | if ignore_nans:
70 | return jnp.nanmean(jnp.square(a - b), axis=(-3, -2, -1))
71 | return jnp.mean(jnp.square(a - b), axis=(-3, -2, -1))
72 |
73 |
74 | def psnr(a: chex.Array, b: chex.Array) -> chex.Numeric:
75 | """Returns the Peak Signal-to-Noise Ratio between `a` and `b`.
76 |
77 | Assumes that the dynamic range of the images (the difference between the
78 | maximum and the minimum allowed values) is 1.0.
79 |
80 | Args:
81 | a (chex.Array): First image (or set of images)
82 | b (chex.Array): Second image (or set of images)
83 |
84 | Returns:
85 | chex.Numeric: PSNR in decibels between `a` and `b`
86 | """
87 | # DO NOT REMOVE - Logging usage.
88 |
89 | chex.assert_rank([a, b], {3, 4})
90 | chex.assert_type([a, b], float)
91 | chex.assert_equal_shape([a, b])
92 | return -10.0 * jnp.log(mse(a, b)) / jnp.log(10.0)
93 |
94 |
95 | def rmse(a: chex.Array, b: chex.Array) -> chex.Numeric:
96 | """Returns the Root Mean Squared Error between `a` and `b`.
97 |
98 | Args:
99 | a (chex.Array): First image (or set of images)
100 | b (chex.Array): Second image (or set of images)
101 |
102 | Returns:
103 | chex.Numeric: RMSE between `a` and `b`
104 | """
105 | # DO NOT REMOVE - Logging usage.
106 |
107 | chex.assert_rank([a, b], {3, 4})
108 | chex.assert_type([a, b], float)
109 | chex.assert_equal_shape([a, b])
110 | return jnp.sqrt(mse(a, b))
111 |
112 |
113 | def simse(a: chex.Array, b: chex.Array) -> chex.Numeric:
114 | """Returns the Scale-Invariant Mean Squared Error between `a` and `b`.
115 |
116 | For each image pair, a scaling factor for `b` is computed as the solution to
117 | the following problem:
118 |
119 | min_alpha || vec(a) - alpha * vec(b) ||_2^2
120 |
121 | where `a` and `b` are flattened, i.e., vec(x) = np.flatten(x). The MSE between
122 | the optimally scaled `b` and `a` is returned: mse(a, alpha*b).
123 |
124 |     This is a scale-invariant metric, so for example: simse(x, y) == simse(x, y*5).
125 |
126 | This metric was used in "Shape, Illumination, and Reflectance from Shading" by
127 | Barron and Malik, TPAMI, '15.
128 |
129 | Args:
130 | a (chex.Array): First image (or set of images)
131 | b (chex.Array): Second image (or set of images)
132 |
133 | Returns:
134 | chex.Numeric: SIMSE between `a` and `b`
135 | """
136 | # DO NOT REMOVE - Logging usage.
137 |
138 | chex.assert_rank([a, b], {3, 4})
139 | chex.assert_type([a, b], float)
140 | chex.assert_equal_shape([a, b])
141 |
142 | a_dot_b = (a * b).sum(axis=(-3, -2, -1), keepdims=True)
143 | b_dot_b = (b * b).sum(axis=(-3, -2, -1), keepdims=True)
144 | alpha = a_dot_b / b_dot_b
145 | return mse(a, alpha * b)
146 |
147 |
148 | def ssim(
149 | a: chex.Array,
150 | b: chex.Array,
151 | *,
152 | max_val: float = 1.0,
153 | filter_size: int = 11,
154 | filter_sigma: float = 1.5,
155 | k1: float = 0.01,
156 | k2: float = 0.03,
157 | return_map: bool = False,
158 | precision=jax.lax.Precision.HIGHEST,
159 | filter_fn: Callable[[chex.Array], chex.Array] | None = None,
160 | ignore_nans: bool = False,
161 | ) -> chex.Numeric:
162 | """Computes the structural similarity index (SSIM) between image pairs.
163 |
164 | This function is based on the standard SSIM implementation from:
165 | Z. Wang, A. C. Bovik, H. R. Sheikh and E. P. Simoncelli,
166 | "Image quality assessment: from error visibility to structural similarity",
167 | in IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, 2004.
168 |
169 | This function was modeled after tf.image.ssim, and should produce comparable
170 | output.
171 |
172 | Note: the true SSIM is only defined on grayscale. This function does not
173 | perform any colorspace transform. If the input is in a color space, then it
174 | will compute the average SSIM.
175 |
176 | Args:
177 | a (chex.Array): First image (or set of images)
178 | b (chex.Array): Second image (or set of images)
179 | max_val (float): The maximum magnitude that `a` or `b` can have. Defaults to 1.
180 |         filter_size (int): Window size (>= 1). Image dims must be at least this large.
181 | Defaults to 11
182 | filter_sigma (float): The bandwidth of the Gaussian used for filtering (> 0.).
183 | Defaults to 1.5
184 | k1 (float): One of the SSIM dampening parameters (> 0.). Defaults to 0.01.
185 | k2 (float): One of the SSIM dampening parameters (> 0.). Defaults to 0.03.
186 | return_map (bool): If True, will cause the per-pixel SSIM "map" to be returned.
187 | Defaults to False.
188 | precision: The numerical precision to use when performing convolution
189 | filter_fn (Callable[[chex.Array], chex.Array] | None): An optional argument for
190 | overriding the filter function used by SSIM, which would otherwise be a 2D
191 | Gaussian blur specified by filter_size and filter_sigma
192 | ignore_nans (bool): Defaults to False
193 |
194 | Returns:
195 |         chex.Numeric: Each image's mean SSIM, or the per-pixel SSIM map if `return_map` is True
196 | """
197 | # DO NOT REMOVE - Logging usage.
198 |
199 | chex.assert_rank([a, b], {3, 4})
200 | chex.assert_type([a, b], float)
201 | chex.assert_equal_shape([a, b])
202 |
203 | if filter_fn is None:
204 | # Construct a 1D Gaussian blur filter.
205 | hw = filter_size // 2
206 | shift = (2 * hw - filter_size + 1) / 2
207 | f_i = ((jnp.arange(filter_size) - hw + shift) / filter_sigma) ** 2
208 | filt = jnp.exp(-0.5 * f_i)
209 | filt /= jnp.sum(filt)
210 |
211 | # Construct a 1D convolution.
212 | def filter_fn_1(z):
213 | return jnp.convolve(z, filt, mode="valid", precision=precision)
214 |
215 | filter_fn_vmap = jax.vmap(filter_fn_1)
216 |
217 | # Apply the vectorized filter along the y axis.
218 | def filter_fn_y(z):
219 | z_flat = jnp.moveaxis(z, -3, -1).reshape((-1, z.shape[-3]))
220 | z_filtered_shape = ((z.shape[-4],) if z.ndim == 4 else ()) + (
221 | z.shape[-2],
222 | z.shape[-1],
223 | -1,
224 | )
225 | return jnp.moveaxis(filter_fn_vmap(z_flat).reshape(z_filtered_shape), -1, -3)
226 |
227 | # Apply the vectorized filter along the x axis.
228 | def filter_fn_x(z):
229 | z_flat = jnp.moveaxis(z, -2, -1).reshape((-1, z.shape[-2]))
230 | z_filtered_shape = ((z.shape[-4],) if z.ndim == 4 else ()) + (
231 | z.shape[-3],
232 | z.shape[-1],
233 | -1,
234 | )
235 | return jnp.moveaxis(filter_fn_vmap(z_flat).reshape(z_filtered_shape), -1, -2)
236 |
237 | # Apply the blur in both x and y.
238 | def filter_fn(z):
239 | return filter_fn_y(filter_fn_x(z))
240 |
241 | mu0 = filter_fn(a)
242 | mu1 = filter_fn(b)
243 | mu00 = mu0 * mu0
244 | mu11 = mu1 * mu1
245 | mu01 = mu0 * mu1
246 | sigma00 = filter_fn(a**2) - mu00
247 | sigma11 = filter_fn(b**2) - mu11
248 | sigma01 = filter_fn(a * b) - mu01
249 |
250 | # Clip the variances and covariances to valid values.
251 | # Variance must be non-negative:
252 | epsilon = jnp.finfo(jnp.float32).eps ** 2
253 | sigma00 = jnp.maximum(epsilon, sigma00)
254 | sigma11 = jnp.maximum(epsilon, sigma11)
255 | sigma01 = jnp.sign(sigma01) * jnp.minimum(jnp.sqrt(sigma00 * sigma11), jnp.abs(sigma01))
256 |
257 | c1 = (k1 * max_val) ** 2
258 | c2 = (k2 * max_val) ** 2
259 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
260 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
261 | ssim_map = numer / denom
262 |
263 | if ignore_nans:
264 | ssim_value = jnp.nanmean(ssim_map, axis=tuple(range(-3, 0)))
265 | else:
266 | ssim_value = jnp.mean(ssim_map, axis=tuple(range(-3, 0)))
267 | return ssim_map if return_map else ssim_value
268 |
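A small sketch of how these metric functions are called (illustrative; random arrays stand in for satellite frames, and shapes follow the NHWC convention described in the module docstring):

# Sketch: metrics expect float images with values in [0, 1] and [..., H, W, C] layout.
import jax.numpy as jnp
import numpy as np

from cloudcasting.metrics import mae, ssim

rng = np.random.default_rng(0)
a = jnp.asarray(rng.random((4, 32, 32, 1), dtype=np.float32))  # [N, H, W, C]
b = jnp.asarray(rng.random((4, 32, 32, 1), dtype=np.float32))

print(mae(a, b).shape)   # (4,) -- one scalar per image
print(ssim(a, b).shape)  # (4,)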
--------------------------------------------------------------------------------
/src/cloudcasting/models.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any
3 |
4 | import numpy as np
5 |
6 | from cloudcasting.constants import (
7 | DATA_INTERVAL_SPACING_MINUTES,
8 | FORECAST_HORIZON_MINUTES,
9 | NUM_FORECAST_STEPS,
10 | )
11 | from cloudcasting.types import BatchInputArray, BatchOutputArray
12 |
13 |
14 | # model interface
15 | class AbstractModel(ABC):
16 | """An abstract class for validating a generic satellite prediction model"""
17 |
18 | history_steps: int
19 |
20 | def __init__(self, history_steps: int) -> None:
21 | self.history_steps: int = history_steps
22 |
23 | @abstractmethod
24 | def forward(self, X: BatchInputArray) -> BatchOutputArray:
25 | """Abstract method for the forward pass of the model.
26 |
27 | Args:
28 | X (BatchInputArray): Either a batch or a sample of the most recent satellite data.
29 |                 X will be 5-dimensional with shape [batch, channels, time, height, width] where
30 | time = {t_{-n}, ..., t_{0}}
31 | (all n values needed to predict {t'_{1}, ..., t'_{horizon}})
32 |
33 | Returns:
34 |             BatchOutputArray: The model's prediction of the future satellite
35 | data of shape [batch, channels, rollout_steps, height, width] where
36 | rollout_steps = {t'_{1}, ..., t'_{horizon}}.
37 | """
38 |
39 | def check_predictions(self, y_hat: BatchOutputArray) -> None:
40 | """Checks the predictions conform to expectations"""
41 | # Check no NaNs in the predictions
42 | if np.isnan(y_hat).any():
43 | msg = f"Predictions contain NaNs: {np.isnan(y_hat).mean()=:.4g}."
44 | raise ValueError(msg)
45 |
46 | # Check the range of the predictions. If outside the expected range this can interfere
47 | # with computing metrics like structural similarity
48 | if ((y_hat < 0) | (y_hat > 1)).any():
49 | msg = (
50 | "The predictions must be in the range [0, 1]. "
51 |                 f"Found range [{y_hat.min()}, {y_hat.max()}]."
52 | )
53 | raise ValueError(msg)
54 |
55 | if y_hat.shape[-3] != NUM_FORECAST_STEPS:
56 | msg = (
57 |                 f"The number of forecast steps in the model ({y_hat.shape[-3]}) does not match "
58 |                 f"{NUM_FORECAST_STEPS=}, defined from "
59 |                 f"{FORECAST_HORIZON_MINUTES=} // {DATA_INTERVAL_SPACING_MINUTES=}. "
60 | f"Found shape {y_hat.shape}."
61 | )
62 | raise ValueError(msg)
63 |
64 | def __call__(self, X: BatchInputArray) -> BatchOutputArray:
65 | # check the shape of the input
66 | if X.shape[-3] != self.history_steps:
67 | msg = (
68 | f"The number of history steps in the input ({X.shape[-3]}) does not match "
69 | f"{self.history_steps=}."
70 | )
71 | raise ValueError(msg)
72 |
73 | # run the forward pass
74 | y_hat = self.forward(X)
75 |
76 | # carry out a set of checks on the predictions to make sure they conform to the
77 | # expectations of the validation script
78 | self.check_predictions(y_hat)
79 |
80 | return y_hat
81 |
82 | @abstractmethod
83 | def hyperparameters_dict(self) -> dict[str, Any]:
84 | """Return a dictionary of the hyperparameters used to train the model"""
85 |
86 |
87 | class VariableHorizonModel(AbstractModel):
88 | def __init__(self, rollout_steps: int, history_steps: int) -> None:
89 | self.rollout_steps: int = rollout_steps
90 | super().__init__(history_steps)
91 |
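As a sketch of the intended interface, a toy persistence model could subclass VariableHorizonModel as below (illustrative only; the test suite's conftest.py defines a very similar class):

# Sketch: repeat the most recent frame for every forecast step.
import numpy as np

from cloudcasting.models import VariableHorizonModel


class Persistence(VariableHorizonModel):
    def forward(self, X):
        # take the latest frame, fill NaNs, and keep values in [0, 1]
        latest_frame = np.nan_to_num(X[..., -1:, :, :], nan=0.0).clip(0, 1)
        return np.repeat(latest_frame, self.rollout_steps, axis=-3)

    def hyperparameters_dict(self):
        return {"history_steps": self.history_steps, "rollout_steps": self.rollout_steps}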
--------------------------------------------------------------------------------
/src/cloudcasting/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alan-turing-institute/cloudcasting/0da4087bba01823d2477dba61ea3e6cfd557212c/src/cloudcasting/py.typed
--------------------------------------------------------------------------------
/src/cloudcasting/types.py:
--------------------------------------------------------------------------------
1 | __all__ = (
2 | "BatchInputArray",
3 | "BatchOutputArray",
4 | "BatchOutputArrayJAX",
5 | "ChannelArray",
6 | "InputArray",
7 | "MetricArray",
8 | "OutputArray",
9 | "SampleInputArray",
10 | "SampleOutputArray",
11 | "TimeArray",
12 | )
13 |
14 | import jaxtyping
15 | import numpy as np
16 | import numpy.typing as npt
17 | from jaxtyping import Float as Float32
18 |
19 | # Type aliases for clarity + reuse
20 | Array = npt.NDArray[np.float32] # the type arg is ignored by jaxtyping, but is here for clarity
21 | TimeArray = Float32[Array, "time"]
22 | MetricArray = Float32[Array, "channels time"]
23 | ChannelArray = Float32[Array, "channels"]
24 |
25 | SampleInputArray = Float32[Array, "channels time height width"]
26 | BatchInputArray = Float32[Array, "batch channels time height width"]
27 | InputArray = SampleInputArray | BatchInputArray
28 |
29 |
30 | SampleOutputArray = Float32[Array, "channels rollout_steps height width"]
31 | BatchOutputArray = Float32[Array, "batch channels rollout_steps height width"]
32 | BatchOutputArrayJAX = Float32[jaxtyping.Array, "batch channels rollout_steps height width"]
33 |
34 | OutputArray = SampleOutputArray | BatchOutputArray
35 |
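These aliases are jaxtyping shape annotations over float32 numpy arrays; a sketch of annotating a helper with them follows (the function below is hypothetical and only illustrates the intended axis layout):

# Sketch: the aliases document the expected axis layout of the arrays.
import numpy as np

from cloudcasting.types import BatchInputArray, SampleInputArray


def add_batch_dim(x: SampleInputArray) -> BatchInputArray:
    # [channels, time, height, width] -> [1, channels, time, height, width]
    return x[np.newaxis, ...]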
--------------------------------------------------------------------------------
/src/cloudcasting/utils.py:
--------------------------------------------------------------------------------
1 | __all__ = (
2 | "find_contiguous_t0_time_periods",
3 | "find_contiguous_time_periods",
4 | "lon_lat_to_geostationary_area_coords",
5 | "numpy_validation_collate_fn",
6 | )
7 |
8 | from collections.abc import Sequence
9 | from datetime import timedelta
10 |
11 | import numpy as np
12 | import pandas as pd
13 | import pyproj
14 | import pyresample
15 | import xarray as xr
16 | from numpy.typing import NDArray
17 |
18 | from cloudcasting.types import (
19 | BatchInputArray,
20 | BatchOutputArray,
21 | SampleInputArray,
22 | SampleOutputArray,
23 | )
24 |
25 |
26 | # taken from ocf_datapipes
27 | def lon_lat_to_geostationary_area_coords(
28 | x: Sequence[float],
29 | y: Sequence[float],
30 | xr_data: xr.Dataset | xr.DataArray,
31 | ) -> tuple[Sequence[float], Sequence[float]]:
32 |     """Load the geostationary area definition and convert lon-lat to geostationary coords
33 |
34 | Args:
35 | x (Sequence[float]): Longitude east-west
36 |         y (Sequence[float]): Latitude north-south
37 | xr_data (xr.Dataset | xr.DataArray): xarray object with geostationary area
38 |
39 | Returns:
40 | tuple[Sequence[float], Sequence[float]]: x, y in geostationary coordinates
41 | """
42 | # WGS84 is short for "World Geodetic System 1984", used in GPS. Uses
43 | # latitude and longitude.
44 | WGS84 = 4326
45 |
46 | try:
47 | area_definition_yaml = xr_data.attrs["area"]
48 | except KeyError:
49 | area_definition_yaml = xr_data.data.attrs["area"]
50 | geostationary_crs = pyresample.area_config.load_area_from_string(area_definition_yaml).crs # type: ignore[no-untyped-call]
51 | lonlat_to_geostationary = pyproj.Transformer.from_crs(
52 | crs_from=WGS84,
53 | crs_to=geostationary_crs,
54 | always_xy=True,
55 | ).transform
56 | return lonlat_to_geostationary(xx=x, yy=y)
57 |
58 |
59 | def find_contiguous_time_periods(
60 | datetimes: pd.DatetimeIndex,
61 | min_seq_length: int,
62 | max_gap_duration: timedelta,
63 | ) -> pd.DataFrame:
64 | """Return a pd.DataFrame where each row records the boundary of a contiguous time period.
65 |
66 | Args:
67 | datetimes (pd.DatetimeIndex): Must be sorted.
68 | min_seq_length (int): Sequences of min_seq_length or shorter will be discarded. Typically,
69 | this would be set to the `total_seq_length` of each machine learning example.
70 | max_gap_duration (timedelta): If any pair of consecutive `datetimes` is more than
71 | `max_gap_duration` apart, then this pair of `datetimes` will be considered a "gap" between
72 | two contiguous sequences. Typically, `max_gap_duration` would be set to the sample period of
73 | the timeseries.
74 |
75 | Returns:
76 | pd.DataFrame: The DataFrame has two columns `start_dt` and `end_dt`
77 | (where 'dt' is short for 'datetime'). Each row represents a single time period.
78 | """
79 | # Sanity checks.
80 | assert len(datetimes) > 0
81 | assert min_seq_length > 1
82 | assert datetimes.is_monotonic_increasing
83 | assert datetimes.is_unique
84 |
85 | # Find indices of gaps larger than max_gap:
86 | gap_mask = pd.TimedeltaIndex(np.diff(datetimes)) > np.timedelta64(max_gap_duration)
87 | gap_indices = np.argwhere(gap_mask)[:, 0]
88 |
89 |     # gap_indices are the indices into dt_index for the timestep immediately before the gap.
90 |     # e.g. if the datetimes are at 12:00, 12:05, 18:00, 18:05 then gap_indices will be [1].
91 | # So we add 1 to gap_indices to get segment_boundaries, an index into dt_index
92 | # which identifies the _start_ of each segment.
93 | segment_boundaries = gap_indices + 1
94 |
95 | # Capture the last segment of dt_index.
96 | segment_boundaries = np.append(segment_boundaries, len(datetimes))
97 |
98 | periods: list[dict[str, pd.Timestamp]] = []
99 | start_i = 0
100 | for next_start_i in segment_boundaries:
101 | n_timesteps = next_start_i - start_i
102 | if n_timesteps > min_seq_length:
103 | end_i = next_start_i - 1
104 | period = {"start_dt": datetimes[start_i], "end_dt": datetimes[end_i]}
105 | periods.append(period)
106 | start_i = next_start_i
107 |
108 | assert len(periods) > 0, (
109 |         f"Did not find any periods from {datetimes}. {min_seq_length=} {max_gap_duration=}"
110 | )
111 |
112 | return pd.DataFrame(periods)
113 |
114 |
115 | def find_contiguous_t0_time_periods(
116 | contiguous_time_periods: pd.DataFrame, history_duration: timedelta, forecast_duration: timedelta
117 | ) -> pd.DataFrame:
118 | """Get all time periods which contain valid t0 datetimes.
119 | `t0` is the datetime of the most recent observation.
120 |
121 | Args:
122 |         contiguous_time_periods (pd.DataFrame): DataFrame of contiguous time periods
123 | history_duration (timedelta): Duration of the history
124 | forecast_duration (timedelta): Duration of the forecast
125 |
126 | Returns:
127 | pd.DataFrame: A DataFrame with two columns `start_dt` and `end_dt`
128 | (where 'dt' is short for 'datetime'). Each row represents a single time period.
129 | """
130 | contiguous_time_periods["start_dt"] += np.timedelta64(history_duration)
131 | contiguous_time_periods["end_dt"] -= np.timedelta64(forecast_duration)
132 | assert (contiguous_time_periods["start_dt"] < contiguous_time_periods["end_dt"]).all()
133 | return contiguous_time_periods
134 |
135 |
136 | def numpy_validation_collate_fn(
137 | samples: list[tuple[SampleInputArray, SampleOutputArray]],
138 | ) -> tuple[BatchInputArray, BatchOutputArray]:
139 | """Collate a list of data + targets into a batch.
140 |
141 | Args:
142 | samples (list[tuple[SampleInputArray, SampleOutputArray]]): List of (X, y) samples,
143 |             where each X has shape (channels, time, height, width) and each
144 |             y has shape (channels, rollout_steps, height, width)
145 |
146 | Returns:
147 | tuple(np.ndarray, np.ndarray): The collated batch of X samples in the form
148 | (batch, channels, time, height, width) and the collated batch of y samples
149 | in the form (batch, channels, rollout_steps, height, width)
150 | """
151 |
152 | # Create empty stores for the compiled batch
153 | X_all = np.empty((len(samples), *samples[0][0].shape), dtype=np.float32)
154 | y_all = np.empty((len(samples), *samples[0][1].shape), dtype=np.float32)
155 |
156 | # Fill the stores with the samples
157 | for i, (X, y) in enumerate(samples):
158 | X_all[i] = X
159 | y_all[i] = y
160 | return X_all, y_all
161 |
162 |
163 | def create_cutout_mask(
164 | mask_size: tuple[int, int, int, int],
165 | image_size: tuple[int, int],
166 | ) -> NDArray[np.float32]:
167 | """Create a mask with a cutout in the center.
168 |
169 | Args:
170 | mask_size: Size of the cutout
171 | image_size: Size of the image
172 |
173 | Returns:
174 | np.ndarray: The mask
175 | """
176 | height, width = image_size
177 | min_x, max_x, min_y, max_y = mask_size
178 |
179 | mask = np.empty((height, width), dtype=np.float32)
180 | mask[:] = np.nan
181 | mask[min_y:max_y, min_x:max_x] = 1
182 | return mask
183 |
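A worked sketch of the gap-finding helpers (illustrative; the index below has one deliberate gap between 12:15 and 18:00, so two contiguous segments are found):

# Sketch: two 4-step segments separated by a gap longer than max_gap_duration.
from datetime import timedelta

import pandas as pd

from cloudcasting.utils import find_contiguous_t0_time_periods, find_contiguous_time_periods

datetimes = pd.DatetimeIndex(
    ["2022-01-01 12:00", "2022-01-01 12:05", "2022-01-01 12:10", "2022-01-01 12:15",
     "2022-01-01 18:00", "2022-01-01 18:05", "2022-01-01 18:10", "2022-01-01 18:15"]
)
periods = find_contiguous_time_periods(
    datetimes, min_seq_length=3, max_gap_duration=timedelta(minutes=5)
)
# periods has columns start_dt / end_dt, one row per contiguous segment
t0_periods = find_contiguous_t0_time_periods(
    periods, history_duration=timedelta(minutes=5), forecast_duration=timedelta(minutes=5)
)
# each row now bounds the valid t0 (most recent observation) datetimes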
--------------------------------------------------------------------------------
/src/cloudcasting/validation.py:
--------------------------------------------------------------------------------
1 | __all__ = ("validate", "validate_from_config")
2 |
3 | import importlib.util
4 | import inspect
5 | import logging
6 | import os
7 | import sys
8 | from collections.abc import Callable
9 | from functools import partial
10 | from typing import Annotated, Any, cast
11 |
12 | import jax.numpy as jnp
13 | import matplotlib.pyplot as plt
14 | import numpy as np
15 | import numpy.typing as npt
16 | import typer
17 | import wandb
18 | import yaml
19 | from jax import tree
20 | from jaxtyping import Array, Float32
21 | from matplotlib.colors import Normalize
22 | from torch.utils.data import DataLoader
23 | from tqdm import tqdm
24 |
25 | try:
26 | import torch.multiprocessing as mp
27 |
28 | mp.set_start_method("spawn", force=True)
29 | except RuntimeError:
30 | pass
31 |
32 | import cloudcasting
33 | from cloudcasting import metrics as dm_pix # for compatibility if our changes are upstreamed
34 | from cloudcasting.constants import (
35 | CROPPED_CUTOUT_MASK,
36 | CROPPED_CUTOUT_MASK_BOUNDARY,
37 | CROPPED_IMAGE_SIZE_TUPLE,
38 | CUTOUT_MASK,
39 | CUTOUT_MASK_BOUNDARY,
40 | DATA_INTERVAL_SPACING_MINUTES,
41 | FORECAST_HORIZON_MINUTES,
42 | IMAGE_SIZE_TUPLE,
43 | NUM_CHANNELS,
44 | )
45 | from cloudcasting.dataset import ValidationSatelliteDataset
46 | from cloudcasting.models import AbstractModel
47 | from cloudcasting.types import (
48 | BatchOutputArrayJAX,
49 | ChannelArray,
50 | MetricArray,
51 | SampleOutputArray,
52 | TimeArray,
53 | )
54 | from cloudcasting.utils import numpy_validation_collate_fn
55 |
56 | logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
57 | logger = logging.getLogger(__name__)
58 |
59 | # defined in manchester prize technical document
60 | WANDB_ENTITY = "manchester_prize"
61 | VIDEO_SAMPLE_DATES = [
62 | "2022-01-17 11:00",
63 | "2022-04-11 06:00",
64 | "2022-06-10 11:00",
65 | "2022-09-30 18:15",
66 | ]
67 | VIDEO_SAMPLE_CHANNELS = ["VIS008", "IR_087"]
68 |
69 |
70 | def log_mean_metrics_to_wandb(
71 | metric_value: float,
72 | plot_name: str,
73 | metric_name: str,
74 | ) -> None:
75 | """Upload a bar chart of mean metric value to wandb
76 |
77 | Args:
78 | metric_value: The mean metric value to upload
79 | plot_name: The name under which to save the plot to wandb
80 | metric_name: The name of the metric used to label the y-axis in the uploaded plot
81 | """
82 | data = [[metric_name, metric_value]]
83 | table = wandb.Table(data=data, columns=["metric name", "value"])
84 | wandb.log({plot_name: wandb.plot.bar(table, "metric name", "value", title=plot_name)})
85 |
86 |
87 | def log_per_horizon_metrics_to_wandb(
88 | horizon_mins: TimeArray,
89 | metric_values: TimeArray,
90 | plot_name: str,
91 | metric_name: str,
92 | ) -> None:
93 | """Upload a plot of metric value vs forecast horizon to wandb
94 |
95 | Args:
96 | horizon_mins: Array of the number of minutes after the init time for each predicted frame
97 | of satellite data
98 | metric_values: Array of the mean metric value at each forecast horizon
99 | plot_name: The name under which to save the plot to wandb
100 | metric_name: The name of the metric used to label the y-axis in the uploaded plot
101 | """
102 | data = list(zip(horizon_mins, metric_values, strict=True))
103 | table = wandb.Table(data=data, columns=["horizon_mins", metric_name])
104 | wandb.log({plot_name: wandb.plot.line(table, "horizon_mins", metric_name, title=plot_name)})
105 |
106 |
107 | def log_per_channel_metrics_to_wandb(
108 | channel_names: list[str],
109 | metric_values: ChannelArray,
110 | plot_name: str,
111 | metric_name: str,
112 | ) -> None:
113 | """Upload a bar chart of metric value for each channel to wandb
114 |
115 | Args:
116 | channel_names: List of channel names for ordering purposes
117 | metric_values: Array of the mean metric value for each channel
118 | plot_name: The name under which to save the plot to wandb
119 | metric_name: The name of the metric used to label the y-axis in the uploaded plot
120 | """
121 | data = list(zip(channel_names, metric_values, strict=True))
122 | table = wandb.Table(data=data, columns=["channel name", metric_name])
123 | wandb.log({plot_name: wandb.plot.bar(table, "channel name", metric_name, title=plot_name)})
124 |
125 |
126 | def log_prediction_video_to_wandb(
127 | y_hat: SampleOutputArray,
128 | y: SampleOutputArray,
129 | video_name: str,
130 | channel_ind: int = 8,
131 | fps: int = 1,
132 | ) -> None:
133 | """Upload a video comparing the true and predicted future satellite data to wandb
134 |
135 | Args:
136 | y_hat: The predicted sequence of satellite data
137 | y: The true sequence of satellite data
138 | video_name: The name under which to save the video to wandb
139 | channel_ind: The channel number to show in the video
140 | fps: Frames per second of the resulting video
141 | """
142 |
143 | # Copy the arrays so we don't modify the original
144 | y_hat = y_hat.copy()
145 | y = y.copy()
146 |
147 | # Find NaNs (or infilled NaNs) in ground truth
148 | mask = np.isnan(y) | (y == -1)
149 |
150 | # Set pixels which are NaN in the ground truth to 0 in both arrays
151 | y[mask] = 0
152 | y_hat[mask] = 0
153 |
154 | create_box = True
155 |
156 | # create a boundary box for the crop
157 | if y.shape[-2:] == IMAGE_SIZE_TUPLE:
158 | boxl, boxr, boxb, boxt = CUTOUT_MASK_BOUNDARY
159 | bsize = IMAGE_SIZE_TUPLE
160 | elif y.shape[-2:] == CROPPED_IMAGE_SIZE_TUPLE:
161 | boxl, boxr, boxb, boxt = CROPPED_CUTOUT_MASK_BOUNDARY
162 | bsize = CROPPED_IMAGE_SIZE_TUPLE
163 | else:
164 | create_box = False
165 |
166 | if create_box:
167 | # box mask
168 | _maskb = np.ones(bsize, dtype=np.float32)
169 | _maskb[boxb : boxb + 2, boxl:boxr] = np.nan # Top edge
170 | _maskb[boxt - 2 : boxt, boxl:boxr] = np.nan # Bottom edge
171 | _maskb[boxb:boxt, boxl : boxl + 2] = np.nan # Left edge
172 | _maskb[boxb:boxt, boxr - 2 : boxr] = np.nan # Right edge
173 | maskb: Float32[npt.NDArray[np.float32], "1 1 a b"] = _maskb[np.newaxis, np.newaxis, :, :]
174 |
175 | y = y * maskb
176 | y_hat = y_hat * maskb
177 |
178 |     # Transpose the arrays so time is the first dimension and select the channel
179 | # Then flip the frames so they are in the correct orientation for the video
180 | y_frames = y.transpose(1, 2, 3, 0)[:, ::-1, ::-1, channel_ind : channel_ind + 1]
181 | y_hat_frames = y_hat.transpose(1, 2, 3, 0)[:, ::-1, ::-1, channel_ind : channel_ind + 1]
182 |
183 | # Concatenate the predicted and true frames so they are displayed side by side
184 | video_array = np.concatenate([y_hat_frames, y_frames], axis=2)
185 |
186 | # Clip the values and rescale to be between 0 and 255 and repeat for RGB channels
187 | video_array = video_array.clip(0, 1)
188 | video_array = np.repeat(video_array, 3, axis=3) * 255
189 | # add Alpha channel
190 | video_array = np.concatenate(
191 | [video_array, np.full((*video_array[:, :, :, 0].shape, 1), 255)], axis=3
192 | )
193 |
194 | # calculate the difference between the prediction and the ground truth and add colour
195 | y_diff_frames = y_hat_frames - y_frames
196 | diff_ccmap = plt.get_cmap("bwr")(Normalize(vmin=-1, vmax=1)(y_diff_frames[:, :, :, 0]))
197 | diff_ccmap = diff_ccmap * 255
198 |
199 |     # append the difference map to the video array
200 | video_array = np.concatenate([video_array, diff_ccmap], axis=2)
201 |
202 | # Set bounding box to a colour so it is visible
203 | if create_box:
204 | video_array[:, :, :, 0][np.isnan(video_array[:, :, :, 0])] = 250
205 | video_array[:, :, :, 1][np.isnan(video_array[:, :, :, 1])] = 40
206 | video_array[:, :, :, 2][np.isnan(video_array[:, :, :, 2])] = 10
207 | video_array[:, :, :, 3][np.isnan(video_array[:, :, :, 3])] = 255
208 |
209 | video_array = video_array.transpose(0, 3, 1, 2)
210 | video_array = video_array.astype(np.uint8)
211 |
212 | wandb.log(
213 | {
214 | video_name: wandb.Video(
215 | video_array,
216 | caption="Sample prediction (left), ground truth (middle), difference (right)",
217 | fps=fps,
218 | )
219 | }
220 | )
221 |
222 |
223 | def score_model_on_all_metrics(
224 | model: AbstractModel,
225 | valid_dataset: ValidationSatelliteDataset,
226 | batch_size: int = 1,
227 | num_workers: int = 0,
228 | batch_limit: int | None = None,
229 | metric_names: tuple[str, ...] | list[str] = ("mae", "mse", "ssim"),
230 | metric_kwargs: dict[str, dict[str, Any]] | None = None,
231 | ) -> tuple[dict[str, MetricArray], list[str]]:
232 | """Calculate the scoreboard metrics for the given model on the validation dataset.
233 |
234 | Args:
235 | model (AbstractModel): The model to score.
236 | valid_dataset (ValidationSatelliteDataset): The validation dataset to score the model on.
237 | batch_size (int, optional): Defaults to 1.
238 | num_workers (int, optional): Defaults to 0.
239 | batch_limit (int | None, optional): Defaults to None. Stop after this many batches.
240 | For testing purposes only.
241 | metric_names (tuple[str, ...] | list[str]: Names of metrics to calculate. Need to be defined
242 | in cloudcasting.metrics. Defaults to ("mae", "mse", "ssim").
243 | metric_kwargs (dict[str, dict[str, Any]] | None, optional): kwargs to pass to functions in
244 | cloudcasting.metrics. Defaults to None.
245 |
246 | Returns:
247 | tuple[dict[str, MetricArray], list[str]]:
248 | Element 0: keys are metric names, values are arrays of results
249 | averaged over all dims except horizon and channel.
250 | Element 1: list of channel names.
251 | """
252 |
253 |     # check that the data spacing perfectly divides the forecast horizon
254 | assert FORECAST_HORIZON_MINUTES % DATA_INTERVAL_SPACING_MINUTES == 0, (
255 | "forecast horizon must be a multiple of the data interval "
256 | "(please make an issue on github if you see this!!!!)"
257 | )
258 |
259 | valid_dataloader = DataLoader(
260 | valid_dataset,
261 | batch_size=batch_size,
262 | num_workers=num_workers,
263 | shuffle=False,
264 | collate_fn=numpy_validation_collate_fn,
265 | drop_last=False,
266 | )
267 |
268 | if metric_kwargs is None:
269 | metric_kwargs_dict: dict[str, dict[str, Any]] = {}
270 | else:
271 | metric_kwargs_dict = metric_kwargs
272 |
273 | def get_pix_function(
274 | name: str,
275 | pix_kwargs: dict[str, dict[str, Any]],
276 | ) -> Callable[
277 | [BatchOutputArrayJAX, BatchOutputArrayJAX], Float32[Array, "batch channels time"]
278 | ]:
279 | func = getattr(dm_pix, name)
280 | sig = inspect.signature(func)
281 | if "ignore_nans" in sig.parameters:
282 | func = partial(func, ignore_nans=True)
283 | if name in pix_kwargs:
284 | func = partial(func, **pix_kwargs[name])
285 | return cast(
286 | Callable[
287 | [BatchOutputArrayJAX, BatchOutputArrayJAX], Float32[Array, "batch channels time"]
288 | ],
289 | func,
290 | )
291 |
292 | metric_funcs: dict[
293 | str,
294 | Callable[[BatchOutputArrayJAX, BatchOutputArrayJAX], Float32[Array, "batch channels time"]],
295 | ] = {name: get_pix_function(name, metric_kwargs_dict) for name in metric_names}
296 |
297 | metrics: dict[str, list[Float32[Array, "batch channels time"]]] = {
298 | metric: [] for metric in metric_funcs
299 | }
300 |
301 | # we probably want to accumulate metrics here instead of taking the mean of means!
302 | loop_steps = len(valid_dataloader) if batch_limit is None else batch_limit
303 |
304 | info_str = f"Validating model on {loop_steps} batches..."
305 | logger.info(info_str)
306 |
307 | for i, (X, y) in tqdm(enumerate(valid_dataloader), total=loop_steps):
308 | y_hat = model(X)
309 |
310 | # identify the correct mask / create a mask if necessary
311 | if X.shape[-2:] == IMAGE_SIZE_TUPLE:
312 | mask = CUTOUT_MASK
313 | elif X.shape[-2:] == CROPPED_IMAGE_SIZE_TUPLE:
314 | mask = CROPPED_CUTOUT_MASK
315 | else:
316 | mask = np.ones(X.shape[-2:], dtype=np.float32)
317 |
318 | # cutout the GB area
319 | mask_full = mask[np.newaxis, np.newaxis, np.newaxis, :, :]
320 | y_cutout = y * mask_full
321 | y_hat = y_hat * mask_full
322 |
323 | # assert shapes are the same
324 | assert y_cutout.shape == y_hat.shape, f"{y_cutout.shape=} != {y_hat.shape=}"
325 |
326 |         # If nan_to_num is used in the dataset, NaNs in the ground truth are filled with -1.
327 |         # We need to convert these back to NaNs for the metrics
328 | y_cutout[y_cutout == -1] = np.nan
329 |
330 | # pix accepts arrays of shape [batch, height, width, channels].
331 | # our arrays are of shape [batch, channels, time, height, width].
332 | # channel dim would be reduced; we add a new axis to satisfy the shape reqs.
333 | # we then reshape to squash batch, channels, and time into the leading axis,
334 | # where the vmap in metrics.py will broadcast over the leading dim.
335 | y_jax = jnp.array(y_cutout).reshape(-1, *y_cutout.shape[-2:])[..., np.newaxis]
336 | y_hat_jax = jnp.array(y_hat).reshape(-1, *y_hat.shape[-2:])[..., np.newaxis]
337 |
338 | for metric_name, metric_func in metric_funcs.items():
339 | # we reshape the result back into [batch, channels, time],
340 | # then take the mean over the batch
341 | metric_res = metric_func(y_hat_jax, y_jax).reshape(*y_cutout.shape[:3])
342 | batch_reduced_metric = jnp.nanmean(metric_res, axis=0)
343 | metrics[metric_name].append(batch_reduced_metric)
344 |
345 | if batch_limit is not None and i >= batch_limit:
346 | break
347 | # convert back to numpy and reduce over all batches
348 | res = tree.map(
349 | lambda x: np.mean(np.array(x), axis=0), metrics, is_leaf=lambda x: isinstance(x, list)
350 | )
351 |
352 | num_timesteps = FORECAST_HORIZON_MINUTES // DATA_INTERVAL_SPACING_MINUTES
353 |
354 | channel_names = valid_dataset.ds.variable.values.tolist()
355 |
356 |     # technically, if we've made a mistake in the shapes/reduction dim, this check could pass
357 |     # silently if the number of batches equals the number of timesteps
358 | for v in res.values():
359 | msg = (
360 | f"metric {v.shape} is not the correct shape "
361 | f"(should be {(len(channel_names), num_timesteps)})"
362 | )
363 | assert v.shape == (len(channel_names), num_timesteps), msg
364 |
365 | return res, channel_names
366 |
367 |
368 | def calc_mean_metrics(metrics_dict: dict[str, MetricArray]) -> dict[str, float]:
369 | """Calculate the mean metric reduced over all dimensions.
370 |
371 | Args:
372 | metrics_dict: dict mapping metric names to arrays of metric values
373 |
374 | Returns:
375 | dict: dict mapping metric names to mean metric values
376 | """
377 | return {k: float(np.mean(v)) for k, v in metrics_dict.items()}
378 |
379 |
380 | def calc_mean_metrics_per_horizon(metrics_dict: dict[str, MetricArray]) -> dict[str, TimeArray]:
381 | """Calculate the mean of each metric for each forecast horizon.
382 |
383 | Args:
384 | metrics_dict: dict mapping metric names to arrays of metric values
385 |
386 | Returns:
387 | dict: dict mapping metric names to array of mean metric values for each horizon
388 | """
389 | return {k: np.mean(v, axis=0) for k, v in metrics_dict.items()}
390 |
391 |
392 | def calc_mean_metrics_per_channel(metrics_dict: dict[str, MetricArray]) -> dict[str, ChannelArray]:
393 | """Calculate the mean of each metric for each channel.
394 |
395 | Args:
396 | metrics_dict: dict mapping metric names to arrays of metric values
397 |
398 | Returns:
399 | dict: dict mapping metric names to array of mean metric values for each channel
400 | """
401 | return {k: np.mean(v, axis=1) for k, v in metrics_dict.items()}
402 |
403 |
404 | def validate(
405 | model: AbstractModel,
406 | data_path: list[str] | str,
407 | wandb_project_name: str,
408 | wandb_run_name: str,
409 | nan_to_num: bool = False,
410 | batch_size: int = 1,
411 | num_workers: int = 0,
412 | batch_limit: int | None = None,
413 | ) -> None:
414 | """Run the full validation procedure on the model and log the results to wandb.
415 |
416 | Args:
417 | model (AbstractModel): the model to be validated
418 |         data_path (list[str] | str): Path to the validation data set
419 | nan_to_num (bool, optional): Whether to convert NaNs to -1. Defaults to False.
420 | batch_size (int, optional): Defaults to 1.
421 | num_workers (int, optional): Defaults to 0.
422 | batch_limit (int | None, optional): Defaults to None. For testing purposes only.
423 | """
424 |
425 | # Verify we can run the model forward
426 | try:
427 | model(np.zeros((1, NUM_CHANNELS, model.history_steps, *IMAGE_SIZE_TUPLE), dtype=np.float32))
428 | except Exception as err:
429 | msg = f"Failed to run the model forward due to the following error: {err}"
430 | raise ValueError(msg) from err
431 |
432 | # grab api key from environment variable
433 | wandb_api_key = os.environ.get("WANDB_API_KEY")
434 |
435 | if not wandb_api_key:
436 | msg = "WANDB_API_KEY environment variable not set. Attempting interactive login..."
437 | logger.warning(msg)
438 | wandb.login()
439 | else:
440 | logger.info("API key found. Logging in to WandB...")
441 | wandb.login(key=wandb_api_key)
442 |
443 | # Set up the validation dataset
444 | valid_dataset = ValidationSatelliteDataset(
445 | zarr_path=data_path,
446 | history_mins=(model.history_steps - 1) * DATA_INTERVAL_SPACING_MINUTES,
447 | forecast_mins=FORECAST_HORIZON_MINUTES,
448 | sample_freq_mins=DATA_INTERVAL_SPACING_MINUTES,
449 | nan_to_num=nan_to_num,
450 | )
451 |
452 | # Calculate the metrics before logging to wandb
453 | channel_horizon_metrics_dict, channel_names = score_model_on_all_metrics(
454 | model,
455 | valid_dataset,
456 | batch_size=batch_size,
457 | num_workers=num_workers,
458 | batch_limit=batch_limit,
459 | )
460 |
461 | # Calculate the mean of each metric reduced over forecast horizon and channels
462 | mean_metrics_dict = calc_mean_metrics(channel_horizon_metrics_dict)
463 |
464 | # Calculate the mean of each metric for each forecast horizon
465 | horizon_metrics_dict = calc_mean_metrics_per_horizon(channel_horizon_metrics_dict)
466 |
467 | # Calculate the mean of each metric for each channel
468 | channel_metrics_dict = calc_mean_metrics_per_channel(channel_horizon_metrics_dict)
469 |
470 | # Append to the wandb run name if we are limiting the number of batches
471 | if batch_limit is not None:
472 | wandb_run_name = wandb_run_name + f"-limited-to-{batch_limit}batches"
473 |
474 | # Start a wandb run
475 | wandb.init(
476 | project=wandb_project_name,
477 | name=wandb_run_name,
478 | entity=WANDB_ENTITY,
479 | )
480 |
481 | # Add the cloudcasting version to the wandb config
482 | wandb.config.update(
483 | {
484 | "cloudcast_version": cloudcasting.__version__,
485 | "batch_limit": batch_limit,
486 | }
487 | )
488 |
489 | # Log the model hyperparameters to wandb
490 | wandb.config.update(model.hyperparameters_dict())
491 |
492 | # Log plot of the horizon metrics to wandb
493 | horizon_mins = np.arange(
494 | start=DATA_INTERVAL_SPACING_MINUTES,
495 | stop=FORECAST_HORIZON_MINUTES + DATA_INTERVAL_SPACING_MINUTES,
496 | step=DATA_INTERVAL_SPACING_MINUTES,
497 | dtype=np.float32,
498 | )
499 |
500 | # Log the mean metrics to wandb
501 | for metric_name, value in mean_metrics_dict.items():
502 | log_mean_metrics_to_wandb(
503 | metric_value=value,
504 | plot_name=f"{metric_name}-mean",
505 | metric_name=metric_name,
506 | )
507 |
508 | for metric_name, horizon_array in horizon_metrics_dict.items():
509 | log_per_horizon_metrics_to_wandb(
510 | horizon_mins=horizon_mins,
511 | metric_values=horizon_array,
512 | plot_name=f"{metric_name}-horizon",
513 | metric_name=metric_name,
514 | )
515 |
516 | # Log mean metrics per-channel
517 | for metric_name, channel_array in channel_metrics_dict.items():
518 | log_per_channel_metrics_to_wandb(
519 | channel_names=channel_names,
520 | metric_values=channel_array,
521 | plot_name=f"{metric_name}-channel",
522 | metric_name=metric_name,
523 | )
524 |
525 | # Log selected video samples to wandb
526 | channel_inds = valid_dataset.ds.get_index("variable").get_indexer(VIDEO_SAMPLE_CHANNELS)
527 |
528 | for date in VIDEO_SAMPLE_DATES:
529 | X, y = valid_dataset[date]
530 |
531 | # Expand dimensions to batch size of 1 for model then contract to sample
532 | y_hat = model(X[None, ...])[0]
533 |
534 | for channel_ind, channel_name in zip(channel_inds, VIDEO_SAMPLE_CHANNELS, strict=False):
535 | log_prediction_video_to_wandb(
536 | y_hat=y_hat,
537 | y=y,
538 | video_name=f"sample_videos/{date} - {channel_name}",
539 | channel_ind=int(channel_ind),
540 | fps=1,
541 | )
542 |
543 |
544 | def validate_from_config(
545 | config_file: Annotated[
546 | str, typer.Option(help="Path to config file. Defaults to 'validate_config.yml'.")
547 | ] = "validate_config.yml",
548 | model_file: Annotated[
549 | str, typer.Option(help="Path to Python file with model definition. Defaults to 'model.py'.")
550 | ] = "model.py",
551 | ) -> None:
552 | """CLI function to validate a model from a config file. Example templates of these files can
553 | be found at https://github.com/alan-turing-institute/ocf-model-template.
554 |
555 | Args:
556 | config_file (str): Path to config file. Defaults to "validate_config.yml".
557 | model_file (str): Path to Python file with model definition. Defaults to "model.py".
558 | """
559 | with open(config_file) as f:
560 | config: dict[str, Any] = yaml.safe_load(f)
561 |
562 | # import the model definition from file
563 | spec = importlib.util.spec_from_file_location("usermodel", model_file)
564 | # type narrowing
565 | if spec is None or spec.loader is None:
566 | msg = f"Error importing {model_file}"
567 | raise ImportError(msg)
568 | module = importlib.util.module_from_spec(spec)
569 | sys.modules["usermodel"] = module
570 | spec.loader.exec_module(module)
571 |
572 | ModelClass = getattr(module, config["model"]["name"])
573 | model = ModelClass(**(config["model"]["params"] or {}))
574 |
575 | validate(model, **config["validation"])
576 |
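For reference, validate_from_config expects the YAML config to parse into a structure like the dict below (keys follow the code above; the model name, params, and paths are placeholders):

# Sketch: the parsed YAML is consumed as config["model"] and config["validation"].
config = {
    "model": {
        "name": "Persistence",  # class defined in model.py (placeholder name)
        "params": {"history_steps": 4},  # keyword args passed to the model constructor
    },
    "validation": {  # forwarded as keyword args to validate()
        "data_path": "data/2022_test_nonhrv.zarr",  # placeholder path
        "wandb_project_name": "my-project",
        "wandb_run_name": "my-run",
        "nan_to_num": False,
        "batch_size": 1,
        "num_workers": 0,
    },
}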
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import pytest
6 | import xarray as xr
7 |
8 | from cloudcasting.models import VariableHorizonModel
9 |
10 | xr.set_options(keep_attrs=True) # type: ignore[no-untyped-call]
11 |
12 |
13 | @pytest.fixture
14 | def temp_output_dir(tmp_path):
15 | return str(tmp_path)
16 |
17 |
18 | @pytest.fixture
19 | def sat_zarr_path(temp_output_dir):
20 | # Load dataset which only contains coordinates, but no data
21 | ds = xr.load_dataset(
22 | f"{os.path.dirname(os.path.abspath(__file__))}/test_data/non_hrv_shell.netcdf"
23 | )
24 |
25 | # Add time coord
26 | ds = ds.assign_coords(time=pd.date_range("2023-01-01 00:00", "2023-01-02 23:55", freq="5min"))
27 |
28 | # Add data to dataset
29 | ds["data"] = xr.DataArray(
30 | np.zeros([len(ds[c]) for c in ds.coords], dtype=np.float32),
31 | coords=ds.coords,
32 | )
33 |
34 | # Transpose to variables, time, y, x (just in case)
35 | ds = ds.transpose("variable", "time", "y_geostationary", "x_geostationary")
36 |
37 | # Add some NaNs
38 | ds["data"].values[:, :, 0, 0] = np.nan
39 |
40 |     # Specify chunking
41 | ds = ds.chunk({"time": 10, "variable": -1, "y_geostationary": -1, "x_geostationary": -1})
42 |
43 | # Save temporarily as a zarr
44 | zarr_path = f"{temp_output_dir}/test_sat.zarr"
45 | ds.to_zarr(zarr_path)
46 |
47 | return zarr_path
48 |
49 |
50 | @pytest.fixture
51 | def val_dataset_hyperparams():
52 | return {
53 | "x_geostationary_size": 8,
54 | "y_geostationary_size": 9,
55 | }
56 |
57 |
58 | @pytest.fixture
59 | def val_sat_zarr_path(temp_output_dir, val_dataset_hyperparams):
60 | # The validation set requires a much larger set of times so we create it separately
61 | # Load dataset which only contains coordinates, but no data
62 | ds = xr.load_dataset(
63 | f"{os.path.dirname(os.path.abspath(__file__))}/test_data/non_hrv_shell.netcdf"
64 | )
65 |
66 | # Make the dataset spatially small
67 | ds = ds.isel(
68 | x_geostationary=slice(0, val_dataset_hyperparams["x_geostationary_size"]),
69 | y_geostationary=slice(0, val_dataset_hyperparams["y_geostationary_size"]),
70 | )
71 |
72 | # Add time coord
73 | ds = ds.assign_coords(time=pd.date_range("2022-01-01 00:00", "2022-12-31 23:45", freq="15min"))
74 |
75 | # Add data to dataset
76 | ds["data"] = xr.DataArray(
77 | np.zeros([len(ds[c]) for c in ds.coords], dtype=np.float32),
78 | coords=ds.coords,
79 | )
80 |
81 | # Transpose to variables, time, y, x (just in case)
82 | ds = ds.transpose("variable", "time", "y_geostationary", "x_geostationary")
83 |
84 | # Add some NaNs
85 | ds["data"].values[:, :, 0, 0] = np.nan
86 |
87 | # Specify chunking
88 | ds = ds.chunk({"time": 10, "variable": -1, "y_geostationary": -1, "x_geostationary": -1})
89 |
90 | # Save temporarily as a zarr
91 | zarr_path = f"{temp_output_dir}/val_test_sat.zarr"
92 | ds.to_zarr(zarr_path)
93 |
94 | return zarr_path
95 |
96 |
97 | class PersistenceModel(VariableHorizonModel):
98 | """A persistence model used solely for testing the validation procedure"""
99 |
100 | def forward(self, X):
101 | # Grab the most recent frame from the input data
102 | # There may be NaNs in the input data, so we need to handle these
103 | latest_frame = np.nan_to_num(X[..., -1:, :, :], nan=0.0, copy=True)
104 |
105 | # If the dataset was loaded with nan_to_num=True, NaNs were filled with -1. Clip these to zero
106 | latest_frame = latest_frame.clip(0, 1)
107 |
108 | return np.repeat(latest_frame, self.rollout_steps, axis=-3)
109 |
110 | def hyperparameters_dict(self):
111 | return {"history_steps": self.history_steps}
112 |
--------------------------------------------------------------------------------
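The PersistenceModel defined above is exercised by tests/test_models.py and tests/test_validation.py. A minimal sketch of calling it directly, mirroring test_call in tests/test_models.py (the batch/channel/spatial sizes are illustrative):

import numpy as np
from conftest import PersistenceModel
from cloudcasting.constants import NUM_FORECAST_STEPS

# The persistence model repeats the most recent input frame for every rollout step.
model = PersistenceModel(history_steps=1, rollout_steps=NUM_FORECAST_STEPS)
X = np.random.rand(1, 3, model.history_steps, 100, 100)  # [batch, channels, time, height, width]
y_hat = model(X)
assert y_hat.shape == (1, 3, model.rollout_steps, 100, 100)
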
/tests/legacy_metrics.py:
--------------------------------------------------------------------------------
1 | """Metrics for model output evaluation"""
2 |
3 | import numpy as np
4 | import numpy.typing as npt
5 | from jaxtyping import Float as Float32
6 | from skimage.metrics import structural_similarity # type: ignore[import-not-found]
7 |
8 | # Type aliases for clarity + reuse
9 | Array = npt.NDArray[np.float32] # the type arg is ignored by jaxtyping, but is here for clarity
10 | TimeArray = Float32[Array, "time"]
11 | MetricArray = Float32[Array, "channels time"]
12 | ChannelArray = Float32[Array, "channels"]
13 |
14 | SampleInputArray = Float32[Array, "channels time height width"]
15 | BatchInputArray = Float32[Array, "batch channels time height width"]
16 | InputArray = SampleInputArray | BatchInputArray
17 |
18 |
19 | SampleOutputArray = Float32[Array, "channels rollout_steps height width"]
20 | BatchOutputArray = Float32[Array, "batch channels rollout_steps height width"]
21 |
22 | OutputArray = SampleOutputArray | BatchOutputArray
23 |
24 |
25 | def mae_single(input: SampleOutputArray, target: SampleOutputArray) -> MetricArray:
26 | """Mean absolute error for single (non-batched) image sequences.
27 |
28 | Args:
29 | input: Array of shape [channels, time, height, width]
30 | target: Array of shape [channels, time, height, width]
31 |
32 | Returns:
33 | Array of MAE values of shape [channel, time]
34 | """
35 | absolute_error = np.abs(input - target)
36 | arr: MetricArray = np.nanmean(absolute_error, axis=(2, 3))
37 | return arr
38 |
39 |
40 | def mae_batch(input: BatchOutputArray, target: BatchOutputArray) -> MetricArray:
41 | """Mean absolute error for batched image sequences.
42 |
43 | Args:
44 | input: Array of shape [batch, channels, time, height, width]
45 | target: Array of shape [batch, channels, time, height, width]
46 |
47 | Returns:
48 | Array of MAE values of shape [channel, time]
49 | """
50 | absolute_error = np.abs(input - target)
51 | arr: MetricArray = np.nanmean(absolute_error, axis=(0, 3, 4))
52 | return arr
53 |
54 |
55 | def mse_single(input: SampleOutputArray, target: SampleOutputArray) -> MetricArray:
56 | """Mean squared error for single (non-batched) image sequences.
57 |
58 | Args:
59 | input: Array of shape [channels, time, height, width]
60 | target: Array of shape [channels, time, height, width]
61 |
62 | Returns:
63 | Array of MSE values of shape [channel, time]
64 | """
65 | square_error = (input - target) ** 2
66 | arr: MetricArray = np.nanmean(square_error, axis=(2, 3))
67 | return arr
68 |
69 |
70 | def mse_batch(input: BatchOutputArray, target: BatchOutputArray) -> MetricArray:
71 | """Mean squared error for batched image sequences.
72 |
73 | Args:
74 | input: Array of shape [batch, channels, time, height, width]
75 | target: Array of shape [batch, channels, time, height, width]
76 |
77 | Returns:
78 | Array of MSE values of shape [channel, time]
79 | """
80 | square_error = (input - target) ** 2
81 | arr: MetricArray = np.nanmean(square_error, axis=(0, 3, 4))
82 | return arr
83 |
84 |
85 | def ssim_single(input: SampleOutputArray, target: SampleOutputArray) -> MetricArray:
86 | """Computes the Structural Similarity (SSIM) index for single (non-batched) image sequences.
87 |
88 | Args:
89 | input: Array of shape [channels, time, height, width]
90 | target: Array of shape [channels, time, height, width]
91 |
92 | Returns:
93 | Array of SSIM values of shape [channel, time]
94 |
95 | References:
96 | Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004).
97 | Image quality assessment: From error visibility to structural similarity.
98 | IEEE Transactions on Image Processing, 13, 600-612.
99 | https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf,
100 | DOI: 10.1109/TIP.2003.819861
101 | """
102 |
103 | # This function assumes the data will be in the range 0-1 and will give invalid results if not
104 | _check_input_target_ranges(input, target)
105 |
106 | # The following parameter settings match Wang et al. (2004)
107 | gaussian_weights = True
108 | use_sample_covariance = False
109 | sigma = 1.5
110 | win_size = 11
111 |
112 | ssim_seq = []
113 | for i_t in range(input.shape[1]):
114 | _, ssim_array = structural_similarity(
115 | input[:, i_t],
116 | target[:, i_t],
117 | data_range=1,
118 | channel_axis=0,
119 | full=True,
120 | gaussian_weights=gaussian_weights,
121 | use_sample_covariance=use_sample_covariance,
122 | sigma=sigma,
123 | win_size=win_size,
124 | )
125 |
126 | # To avoid edge effects from the Gaussian filter we trim off the border
127 | trim_width = (win_size - 1) // 2
128 | ssim_array = ssim_array[:, trim_width:-trim_width, trim_width:-trim_width]
129 | # Take the mean of the SSIM array over channels, height, and width
130 | ssim_seq.append(np.nanmean(ssim_array, axis=(1, 2)))
131 | # stack along channel dimension
132 | arr: MetricArray = np.stack(ssim_seq, axis=1)
133 | return arr
134 |
135 |
136 | def ssim_batch(input: BatchOutputArray, target: BatchOutputArray) -> MetricArray:
137 | """Structural similarity for batched image sequences.
138 |
139 | Args:
140 | input: Array of shape [batch, channels, time, height, width]
141 | target: Array of shape [batch, channels, time, height, width]
143 |
144 | Returns:
145 | Array of SSIM values of shape [channel, time]
146 | """
147 | # This function assumes the data will be in the range 0-1 and will give invalid results if not
148 | _check_input_target_ranges(input, target)
149 |
150 | ssim_samples = []
151 | for i_b in range(input.shape[0]):
152 | ssim_samples.append(ssim_single(input[i_b], target[i_b]))
153 | arr: MetricArray = np.stack(ssim_samples, axis=0).mean(axis=0)
154 | return arr
155 |
156 |
157 | def _check_input_target_ranges(input: OutputArray, target: OutputArray) -> None:
158 | """Validate input and target arrays are within the 0-1 range.
159 |
160 | Args:
161 | input: Input array
162 | target: Target array
163 |
164 | Raises:
165 | ValueError: If input or target values are outside the 0-1 range.
166 | """
167 | input_max, input_min = input.max(), input.min()
168 | target_max, target_min = target.max(), target.min()
169 |
170 | if (input_min < 0) | (input_max > 1) | (target_min < 0) | (target_max > 1):
171 | msg = (
172 | f"Input and target must be in 0-1 range. "
173 | f"Input range: {input_min}-{input_max}. "
174 | f"Target range: {target_min}-{target_max}"
175 | )
176 | raise ValueError(msg)
177 |
--------------------------------------------------------------------------------
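These legacy helpers reduce over the spatial axes only, so each returns one value per channel and time step. A small worked sketch of mae_single with illustrative array values:

import numpy as np
from legacy_metrics import mae_single  # tests/legacy_metrics.py

# A constant error of 0.5 everywhere gives an MAE of 0.5 for every [channel, time] entry.
input = np.zeros((2, 3, 4, 4), dtype=np.float32)        # [channels, time, height, width]
target = np.full((2, 3, 4, 4), 0.5, dtype=np.float32)
errors = mae_single(input, target)
assert errors.shape == (2, 3)                            # one value per channel and time step
assert np.allclose(errors, 0.5)
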
/tests/test_cli.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 | from typer.testing import CliRunner
5 |
6 | from cloudcasting.cli import app
7 |
8 |
9 | @pytest.fixture
10 | def runner():
11 | return CliRunner()
12 |
13 |
14 | @pytest.fixture
15 | def temp_output_dir(tmp_path):
16 | return str(tmp_path)
17 |
18 |
19 | def test_download_satellite_data(runner, temp_output_dir):
20 | # Define test parameters
21 | start_date = "2021-01-01 00:00"
22 | end_date = "2021-01-01 00:30"
23 |
24 | # Run the CLI command to download the file
25 | result = runner.invoke(
26 | app,
27 | [
28 | "download",
29 | start_date,
30 | end_date,
31 | temp_output_dir,
32 | "--download-frequency=15min",
33 | "--lon-min=-1",
34 | "--lon-max=1",
35 | "--lat-min=50",
36 | "--lat-max=51",
37 | ],
38 | )
39 |
40 | # Check if the command executed successfully
41 | assert result.exit_code == 0
42 |
43 | # Check if the output file was created
44 | expected_file = os.path.join(temp_output_dir, "2021_training_nonhrv.zarr")
45 | assert os.path.exists(expected_file)
46 |
--------------------------------------------------------------------------------
/tests/test_data/non_hrv_shell.netcdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alan-turing-institute/cloudcasting/0da4087bba01823d2477dba61ea3e6cfd557212c/tests/test_data/non_hrv_shell.netcdf
--------------------------------------------------------------------------------
/tests/test_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import pytest
4 |
5 | from cloudcasting.constants import (
6 | DATA_INTERVAL_SPACING_MINUTES,
7 | FORECAST_HORIZON_MINUTES,
8 | NUM_CHANNELS,
9 | NUM_FORECAST_STEPS,
10 | )
11 | from cloudcasting.dataset import (
12 | SatelliteDataModule,
13 | SatelliteDataset,
14 | ValidationSatelliteDataset,
15 | find_valid_t0_times,
16 | load_satellite_zarrs,
17 | )
18 |
19 |
20 | def test_load_satellite_zarrs(sat_zarr_path):
21 | # Check can load with string and list of string(s)
22 | ds = load_satellite_zarrs(sat_zarr_path)
23 | ds = load_satellite_zarrs([sat_zarr_path])
24 |
25 | # Dataset is a full 48 hours of 5 minutely data -> 48hours * (60/5) = 576
26 | assert len(ds.time) == 576
27 |
28 |
29 | def test_find_valid_t0_times(sat_zarr_path):
30 | ds = load_satellite_zarrs(sat_zarr_path)
31 |
32 | t0_times = find_valid_t0_times(
33 | pd.DatetimeIndex(ds.time),
34 | history_mins=60,
35 | forecast_mins=120,
36 | sample_freq_mins=5,
37 | )
38 |
39 | # original timesteps 576
40 | # forecast length buffer - (120 / 5)
41 | # history length buffer - (60 / 5)
42 | # ------------
43 | # Total 540
44 |
45 | assert len(t0_times) == 540
46 |
47 | t0_times = find_valid_t0_times(
48 | pd.DatetimeIndex(ds.time),
49 | history_mins=60,
50 | forecast_mins=120,
51 | sample_freq_mins=15,
52 | )
53 |
54 | # original 15 minute timesteps 576 / 3
55 | # forecast length buffer - (120 / 15)
56 | # history length buffer - (60 / 15)
57 | # ------------
58 | # Total 180
59 |
60 | assert len(t0_times) == 180
61 |
62 |
63 | def test_satellite_dataset(sat_zarr_path):
64 | dataset = SatelliteDataset(
65 | zarr_path=sat_zarr_path,
66 | start_time=None,
67 | end_time=None,
68 | history_mins=60,
69 | forecast_mins=120,
70 | sample_freq_mins=5,
71 | )
72 |
73 | assert len(dataset) == 540
74 |
75 | X, y = dataset[0]
76 |
77 | # 11 channels
78 | # 20 y-dim steps
79 | # 49 x-dim steps
80 | # (60 / 5) + 1 = 13 history steps
81 | # (120 / 5) = 24 forecast steps
82 | assert X.shape == (11, 13, 20, 49)
83 | assert y.shape == (11, 24, 20, 49)
84 |
85 | assert np.sum(np.isnan(X)) == 11 * 13
86 | assert np.sum(np.isnan(y)) == 11 * 24
87 |
88 |
89 | def test_satellite_datamodule(sat_zarr_path):
90 | datamodule = SatelliteDataModule(
91 | zarr_path=sat_zarr_path,
92 | history_mins=60,
93 | forecast_mins=120,
94 | sample_freq_mins=5,
95 | batch_size=2,
96 | num_workers=2,
97 | prefetch_factor=None,
98 | )
99 |
100 | dl = datamodule.train_dataloader()
101 |
102 | X, y = next(iter(dl))
103 |
104 | assert X.shape == (2, 11, 13, 20, 49)
105 | assert y.shape == (2, 11, 24, 20, 49)
106 |
107 |
108 | def test_satellite_datamodule_variables(sat_zarr_path):
109 | variables = ["VIS006", "VIS008"]
110 |
111 | datamodule = SatelliteDataModule(
112 | zarr_path=sat_zarr_path,
113 | history_mins=60,
114 | forecast_mins=120,
115 | sample_freq_mins=5,
116 | batch_size=2,
117 | num_workers=2,
118 | prefetch_factor=None,
119 | variables=variables,
120 | )
121 |
122 | dl = datamodule.train_dataloader()
123 |
124 | X, y = next(iter(dl))
125 |
126 | assert X.shape == (2, 2, 13, 20, 49)
127 | assert y.shape == (2, 2, 24, 20, 49)
128 |
129 |
130 | def test_satellite_dataset_nan_to_num(sat_zarr_path):
131 | dataset = SatelliteDataset(
132 | zarr_path=sat_zarr_path,
133 | start_time=None,
134 | end_time=None,
135 | history_mins=60,
136 | forecast_mins=120,
137 | sample_freq_mins=5,
138 | nan_to_num=True,
139 | )
140 | assert len(dataset) == 540
141 |
142 | X, y = dataset[0]
143 |
144 | # 11 channels
145 | # 20 y-dim steps
146 | # 49 x-dim steps
147 | # (60 / 5) + 1 = 13 history steps
148 | # (120 / 5) = 24 forecast steps
149 | assert X.shape == (11, 13, 20, 49)
150 | assert y.shape == (11, 24, 20, 49)
151 |
152 | assert np.sum(np.isnan(X)) == 0
153 | assert np.sum(np.isnan(y)) == 0
154 |
155 | assert np.sum(X[:, :, 0, 0]) == -11 * 13
156 | assert np.sum(y[:, :, 0, 0]) == -11 * 24
157 |
158 |
159 | def test_validation_dataset(val_sat_zarr_path, val_dataset_hyperparams):
160 | dataset = ValidationSatelliteDataset(
161 | zarr_path=val_sat_zarr_path,
162 | history_mins=60,
163 | forecast_mins=FORECAST_HORIZON_MINUTES,
164 | sample_freq_mins=DATA_INTERVAL_SPACING_MINUTES,
165 | )
166 |
167 | # There are 3744 init times which all models must make predictions for
168 | assert len(dataset) == 3744
169 |
170 | X, y = dataset[0]
171 |
172 | # 11 channels
173 | # 9 y-dim steps
174 | # 8 x-dim steps
175 | # (60 / 15) + 1 = 5 history steps
176 | # (180 / 15) = 12 forecast steps
177 | assert X.shape == (
178 | NUM_CHANNELS,
179 | 5,
180 | val_dataset_hyperparams["y_geostationary_size"],
181 | val_dataset_hyperparams["x_geostationary_size"],
182 | )
183 | assert y.shape == (
184 | NUM_CHANNELS,
185 | NUM_FORECAST_STEPS,
186 | val_dataset_hyperparams["y_geostationary_size"],
187 | val_dataset_hyperparams["x_geostationary_size"],
188 | )
189 |
190 |
191 | def test_validation_dataset_raises_error(sat_zarr_path):
192 | with pytest.raises(ValueError, match="The following validation t0 times are not available"):
193 | ValidationSatelliteDataset(
194 | zarr_path=sat_zarr_path,
195 | history_mins=60,
196 | forecast_mins=FORECAST_HORIZON_MINUTES,
197 | sample_freq_mins=DATA_INTERVAL_SPACING_MINUTES,
198 | )
199 |
--------------------------------------------------------------------------------
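The counts asserted in test_find_valid_t0_times follow from trimming the history buffer at the start of the time index and the forecast buffer at the end. A sketch of the arithmetic used in the comments above:

# A valid t0 needs `history_mins` of data before it and `forecast_mins` after it.
total_steps_5min = 48 * (60 // 5)                 # 576 five-minutely timestamps over 48 hours
history_buffer = 60 // 5                          # 12 steps unavailable at the start
forecast_buffer = 120 // 5                        # 24 steps unavailable at the end
assert total_steps_5min - history_buffer - forecast_buffer == 540

total_steps_15min = 576 // 3                      # 192 timestamps when sampled every 15 minutes
assert total_steps_15min - (60 // 15) - (120 // 15) == 180
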
/tests/test_download.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import pytest
5 | import xarray as xr
6 |
7 | from cloudcasting.download import download_satellite_data
8 |
9 |
10 | @pytest.fixture
11 | def temp_output_dir(tmp_path):
12 | return str(tmp_path)
13 |
14 |
15 | def test_download_satellite_data(temp_output_dir):
16 | # Define test parameters
17 | start_date = "2021-01-01 00:00"
18 | end_date = "2021-01-01 00:30"
19 |
20 | # Run the function to download the file
21 | download_satellite_data(
22 | start_date,
23 | end_date,
24 | temp_output_dir,
25 | download_frequency="15min",
26 | lon_min=-1,
27 | lon_max=1,
28 | lat_min=50,
29 | lat_max=51,
30 | )
31 |
32 | # Check if the output file was created
33 | expected_file = os.path.join(temp_output_dir, "2021_training_nonhrv.zarr")
34 | assert os.path.exists(expected_file)
35 |
36 |
37 | def test_download_satellite_data_test_2022_set(temp_output_dir):
38 | # Only run this test on 2022 as it's the only year with a test set defined.
39 | # Want to make sure that the --test-2022-set flag works as expected.
40 | start_date = "2022-01-01 00:00"
41 | end_date = "2022-03-01 00:00"
42 |
43 | # Run the function with the --test-2022-set flag
44 | download_satellite_data(
45 | start_date,
46 | end_date,
47 | temp_output_dir,
48 | download_frequency="168h",
49 | lon_min=-1,
50 | lon_max=1,
51 | lat_min=50,
52 | lat_max=51,
53 | test_2022_set=True,
54 | )
55 |
56 | # Check if the output file was created and contains the expected data
57 | expected_file = os.path.join(temp_output_dir, "2022_test_nonhrv.zarr")
58 | assert os.path.exists(expected_file)
59 |
60 | ds = xr.open_zarr(expected_file)
61 | # Check that the data only comes from the expected days of the year: alternating 14-day
62 | # blocks of test data, starting from day 15 of the year.
63 | for day in [15, 22, 43, 50]:
64 | assert day in ds.time.dt.dayofyear.values
65 |
66 |
67 | def test_download_satellite_data_2022_nontest_set(temp_output_dir):
68 | # Only run this test on 2022 as it's the only year with a test set.
69 | # Want to make sure that the --test-2022-set flag works as expected.
70 | # The date range needs to span more than two weeks so that the train/test split is exercised.
71 | start_date = "2022-01-01 00:00"
72 | end_date = "2022-03-01 00:00"
73 |
74 | # Run the function with the --test-set flag turned off
75 | download_satellite_data(
76 | start_date,
77 | end_date,
78 | temp_output_dir,
79 | download_frequency="168h",
80 | lon_min=-1,
81 | lon_max=1,
82 | lat_min=50,
83 | lat_max=51,
84 | )
85 |
86 | # Check if the output file was created and contains the expected data
87 | expected_file = os.path.join(temp_output_dir, "2022_training_nonhrv.zarr")
88 | assert os.path.exists(expected_file)
89 |
90 | # Now, we're in the training set
91 | ds = xr.open_zarr(expected_file)
92 |
93 | # Check that the data only comes from the expected days of the year: alternating 14-day
94 | # blocks of training data, starting from day 1 of the year.
95 | for day in [1, 8, 29, 36, 57]:
96 | assert day in ds.time.dt.dayofyear.values
97 |
98 |
99 | def test_download_satellite_data_test_2021_set(temp_output_dir):
100 | # Want to make sure that the --test-2022-set flag raises an error for years other than 2022.
101 | start_date = "2021-01-01 00:00"
102 | end_date = "2021-01-01 00:30"
103 |
104 | # Run the function with the --test-2022-set flag
105 | # Check if the expected error was raised
106 | with pytest.raises(ValueError, match=r"Test data is only defined for 2022"):
107 | download_satellite_data(
108 | start_date,
109 | end_date,
110 | temp_output_dir,
111 | download_frequency="15min",
112 | lon_min=-1,
113 | lon_max=1,
114 | lat_min=50,
115 | lat_max=51,
116 | test_2022_set=True,
117 | )
118 |
119 |
120 | def test_download_satellite_data_verify_set(temp_output_dir):
121 | # Want to make sure that the --verify-2023-set flag works as expected.
122 | start_date = "2023-01-01 00:00"
123 | end_date = "2023-01-01 00:30"
124 |
125 | # Run the function with the --verify-2023-set flag
126 | # Check if the expected error was raised
127 | with pytest.raises(
128 | ValueError,
129 | match=r"Verification data requires a start date of '2023-01-01 00:00'",
130 | ):
131 | download_satellite_data(
132 | start_date,
133 | end_date,
134 | temp_output_dir,
135 | download_frequency="15min",
136 | lon_min=-1,
137 | lon_max=1,
138 | lat_min=50,
139 | lat_max=51,
140 | verify_2023_set=True,
141 | )
142 |
143 |
144 | def test_download_satellite_data_2023_not_verify(temp_output_dir):
145 | # Want to make sure that the --verify-2023-set flag works as expected.
146 | start_date = "2023-01-01 00:00"
147 | end_date = "2023-01-01 00:30"
148 |
149 | # Run the function with the --verify-2023-set flag
150 | # Check if the expected error was raised
151 | with pytest.raises(ValueError, match=r"2023 data is reserved for the verification process"):
152 | download_satellite_data(
153 | start_date,
154 | end_date,
155 | temp_output_dir,
156 | download_frequency="15min",
157 | lon_min=-1,
158 | lon_max=1,
159 | lat_min=50,
160 | lat_max=51,
161 | )
162 |
163 |
164 | def test_irregular_start_date(temp_output_dir):
165 | # Define test parameters
166 | start_date = "2021-01-01 00:02"
167 | end_date = "2021-01-01 00:30"
168 |
169 | # Run the function to download the file
170 | download_satellite_data(
171 | start_date,
172 | end_date,
173 | temp_output_dir,
174 | download_frequency="15min",
175 | lon_min=-1,
176 | lon_max=1,
177 | lat_min=50,
178 | lat_max=51,
179 | )
180 |
181 | # Check if the output file was created
182 | expected_file = os.path.join(temp_output_dir, "2021_training_nonhrv.zarr")
183 | assert os.path.exists(expected_file)
184 |
185 | ds = xr.open_zarr(expected_file)
186 | # Check that the data ignored the 00:02 entry; only the 00:15 and 00:30 entries should exist.
187 | assert np.all(ds.time.dt.minute.values == [np.int64(15), np.int64(30)])
188 |
189 |
190 | def test_download_satellite_data_mock_to_zarr(temp_output_dir, monkeypatch):
191 | # make a tiny dataset to mock the to_zarr function,
192 | # but use netcdf instead of zarr (so as not to recurse)
193 | mock_file_name = f"{temp_output_dir}/mock.nc"
194 |
195 | def mock_to_zarr(*args, **kwargs):
196 | xr.Dataset({"data": xr.DataArray(np.zeros([1, 1, 1, 1]))}).to_netcdf(mock_file_name)
197 |
198 | monkeypatch.setattr("xarray.Dataset.to_zarr", mock_to_zarr)
199 |
200 | # Define test parameters (this range is known to contain missing data)
201 | start_date = "2020-06-01 00:00"
202 | end_date = "2020-06-30 23:55"
203 |
204 | # Run the function to download the file
205 | download_satellite_data(
206 | start_date,
207 | end_date,
208 | temp_output_dir,
209 | download_frequency="15min",
210 | lon_min=-1,
211 | lon_max=1,
212 | lat_min=50,
213 | lat_max=51,
214 | )
215 |
216 | # Check if the output file was created
217 | expected_file = os.path.join(temp_output_dir, mock_file_name)
218 | assert os.path.exists(expected_file)
219 |
--------------------------------------------------------------------------------
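The day-of-year assertions above imply an alternating 14-day split for 2022: days 1-14 in the training set, days 15-28 in the test set, and so on. A hedged sketch of that split, read off the assertions rather than taken from the implementation in cloudcasting.download:

# Alternating 14-day train/test blocks implied by the day-of-year checks above.
def in_test_2022_block(day_of_year: int) -> bool:
    return ((day_of_year - 1) // 14) % 2 == 1

assert all(in_test_2022_block(d) for d in [15, 22, 43, 50])        # test-set days checked above
assert not any(in_test_2022_block(d) for d in [1, 8, 29, 36, 57])  # training-set days checked above
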
/tests/test_metrics.py:
--------------------------------------------------------------------------------
1 | """Test if metrics match the legacy metrics"""
2 |
3 | import inspect
4 | from functools import partial
5 | from typing import cast
6 |
7 | import jax.numpy as jnp
8 | import jax.random as jr
9 | import numpy as np
10 | import pytest
11 | from jaxtyping import Array, Float32
12 | from legacy_metrics import mae_batch, mse_batch, ssim_batch
13 |
14 | from cloudcasting.metrics import mae, mse, ssim
15 |
16 |
17 | def apply_pix_metric(metric_func, y_hat, y) -> Float32[Array, "batch channels time"]:
18 | """Apply a pix metric to a sample of satellite data
19 | Args:
20 | metric_func: The pix metric function to apply
21 | y_hat: The predicted sequence of satellite data
22 | y: The true sequence of satellite data
23 |
24 | Returns:
25 | The pix metric value for the sample
26 | """
27 | # pix accepts arrays of shape [batch, height, width, channels].
28 | # our arrays are of shape [batch, channels, time, height, width].
29 | # the channel dim would otherwise be reduced, so we add a trailing singleton axis to satisfy the shape requirements.
30 | # we then reshape to squash batch, channels, and time into the leading axis,
31 | # where the vmap in metrics.py will broadcast over the leading dim.
32 | y_jax = jnp.array(y).reshape(-1, *y.shape[-2:])[..., np.newaxis]
33 | y_hat_jax = jnp.array(y_hat).reshape(-1, *y_hat.shape[-2:])[..., np.newaxis]
34 |
35 | sig = inspect.signature(metric_func)
36 | if "ignore_nans" in sig.parameters:
37 | metric_func = partial(metric_func, ignore_nans=True)
38 |
39 | # we reshape the result back into [batch, channels, time],
40 | # then take the mean over the batch
41 | return cast(Float32[Array, "batch channels time"], metric_func(y_hat_jax, y_jax)).reshape(
42 | *y.shape[:3]
43 | )
44 |
45 |
46 | @pytest.mark.parametrize(
47 | ("metric_func", "legacy_func"),
48 | [
49 | (mae, mae_batch),
50 | (mse, mse_batch),
51 | (ssim, ssim_batch),
52 | ],
53 | )
54 | def test_metrics(metric_func, legacy_func):
55 | """Test if metrics match the legacy metrics"""
56 | # Create a sample input batch
57 | shape = (1, 3, 10, 100, 100)
58 | key = jr.key(321)
59 | key, k1, k2 = jr.split(key, 3)
60 | y_hat = jr.uniform(k1, shape, minval=0, maxval=1)
61 | y = jr.uniform(k2, shape, minval=0, maxval=1)
62 |
63 | # Add NaNs to the input
64 | y = y.at[:, :, :, 0, 0].set(np.nan)
65 |
66 | # Call the metric function
67 | metric = apply_pix_metric(metric_func, y_hat, y).mean(axis=0)
68 |
69 | # Check the shape of the output
70 | assert metric.shape == (3, 10)
71 |
72 | # Check the values of the output
73 | legacy_res = legacy_func(y_hat, y)
74 |
75 | # Looser tolerance for ssim (differences in implementation)
76 | rtol = 0.001 if metric_func == ssim else 1e-5
77 |
78 | assert np.allclose(metric, legacy_res, rtol=rtol), (
79 | f"Metric {metric_func} does not match legacy metric {legacy_func}"
80 | )
81 |
--------------------------------------------------------------------------------
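The reshape in apply_pix_metric flattens the leading batch, channel, and time axes into one, adds a trailing singleton channel axis for pix, and restores the original leading shape afterwards. A shape-only sketch with illustrative sizes:

import numpy as np

# [batch, channels, time, height, width] -> [batch * channels * time, height, width, 1]
y = np.zeros((1, 3, 10, 100, 100), dtype=np.float32)
y_flat = y.reshape(-1, *y.shape[-2:])[..., np.newaxis]
assert y_flat.shape == (30, 100, 100, 1)

# One score per flattened image reshapes back to [batch, channels, time].
per_image_scores = np.zeros(y_flat.shape[0], dtype=np.float32)
assert per_image_scores.reshape(*y.shape[:3]).shape == (1, 3, 10)
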
/tests/test_models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from conftest import PersistenceModel
4 | from jaxtyping import TypeCheckError
5 |
6 | from cloudcasting.constants import NUM_FORECAST_STEPS
7 | from cloudcasting.models import AbstractModel
8 |
9 |
10 | @pytest.fixture
11 | def model():
12 | return PersistenceModel(history_steps=1, rollout_steps=NUM_FORECAST_STEPS)
13 |
14 |
15 | def test_forward(model):
16 | # Create a sample input batch
17 | X = np.random.rand(1, 3, 10, 100, 100)
18 |
19 | # Call the forward method
20 | y_hat = model.forward(X)
21 |
22 | # Check the shape of the output
23 | assert y_hat.shape == (1, 3, model.rollout_steps, 100, 100)
24 |
25 |
26 | def test_check_predictions_no_nans(model):
27 | # Create a sample prediction array without NaNs
28 | y_hat = np.random.rand(1, 3, model.rollout_steps, 100, 100)
29 |
30 | # Call the check_predictions method
31 | model.check_predictions(y_hat)
32 |
33 |
34 | def test_check_predictions_with_nans(model):
35 | # Create a sample prediction array with NaNs
36 | y_hat = np.random.rand(1, 3, model.rollout_steps, 100, 100)
37 | y_hat[0, 0, 0, 0, 0] = np.nan
38 |
39 | # Call the check_predictions method and expect a ValueError
40 | with pytest.raises(ValueError, match="Predictions contain NaNs"):
41 | model.check_predictions(y_hat)
42 |
43 |
44 | def test_check_predictions_within_range(model):
45 | # Create a sample prediction array within the expected range
46 | y_hat = np.random.rand(1, 3, model.rollout_steps, 100, 100)
47 |
48 | # Call the check_predictions method
49 | model.check_predictions(y_hat)
50 |
51 |
52 | def test_check_predictions_outside_range(model):
53 | # Create a sample prediction array outside the expected range
54 | y_hat = np.random.rand(1, 3, model.rollout_steps, 100, 100) * 2
55 |
56 | # Call the check_predictions method and expect a ValueError
57 | with pytest.raises(ValueError, match="The predictions must be in the range "):
58 | model.check_predictions(y_hat)
59 |
60 |
61 | def test_call(model):
62 | # Create a sample input batch
63 | X = np.random.rand(1, 3, model.history_steps, 100, 100)
64 |
65 | # Call the __call__ method
66 | y_hat = model(X)
67 |
68 | # Check the shape of the output
69 | assert y_hat.shape == (1, 3, model.rollout_steps, 100, 100)
70 |
71 |
72 | def test_incorrect_shapes(model):
73 | # Create a sample input batch with incorrect shapes
74 | X = np.random.rand(1, 3, 10, 100)
75 |
76 | # Call the __call__ method and expect a TypeCheckError
77 | with pytest.raises(TypeCheckError):
78 | model(X)
79 |
80 |
81 | def test_incorrect_horizon():
82 | # define a model with a different forecast horizon
83 | class Model(AbstractModel):
84 | def __init__(self, history_steps: int) -> None:
85 | super().__init__(history_steps)
86 |
87 | def forward(self, X):
88 | # Grab the most recent frame from the input data
89 | latest_frame = X[..., -1:, :, :]
90 |
91 | # If the dataset was loaded with nan_to_num=True, NaNs were filled with -1. Clip these to zero
92 | latest_frame = latest_frame.clip(0, 1)
93 |
94 | return np.repeat(latest_frame, NUM_FORECAST_STEPS + 1, axis=-3)
95 |
96 | def hyperparameters_dict(self):
97 | return {"history_steps": self.history_steps}
98 |
99 | model = Model(history_steps=1)
100 | X = np.random.rand(1, 3, 1, 100, 100).astype(np.float32)
101 |
102 | # Call the __call__ method and expect a ValueError
103 | with pytest.raises(ValueError, match="The number of forecast steps in the model"):
104 | model(X)
105 |
--------------------------------------------------------------------------------
/tests/test_validation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from conftest import PersistenceModel
4 |
5 | from cloudcasting.constants import (
6 | DATA_INTERVAL_SPACING_MINUTES,
7 | FORECAST_HORIZON_MINUTES,
8 | NUM_FORECAST_STEPS,
9 | )
10 | from cloudcasting.dataset import ValidationSatelliteDataset
11 | from cloudcasting.validation import (
12 | calc_mean_metrics,
13 | score_model_on_all_metrics,
14 | validate,
15 | validate_from_config,
16 | )
17 |
18 |
19 | @pytest.fixture
20 | def model():
21 | return PersistenceModel(history_steps=1, rollout_steps=NUM_FORECAST_STEPS)
22 |
23 |
24 | @pytest.mark.parametrize("nan_to_num", [True, False])
25 | def test_score_model_on_all_metrics(model, val_sat_zarr_path, nan_to_num):
26 | # Create valid dataset
27 | valid_dataset = ValidationSatelliteDataset(
28 | zarr_path=val_sat_zarr_path,
29 | history_mins=(model.history_steps - 1) * DATA_INTERVAL_SPACING_MINUTES,
30 | forecast_mins=FORECAST_HORIZON_MINUTES,
31 | sample_freq_mins=DATA_INTERVAL_SPACING_MINUTES,
32 | nan_to_num=nan_to_num,
33 | )
34 |
35 | metric_names = ("mae", "mse", "ssim")
36 | # use small filter size to not propagate nan to the whole image
37 | # (this is only because our test images are very small (8x9) --
38 | # the filter window of size 11 would be bigger than the image!)
39 | metric_kwargs = {"ssim": {"filter_size": 2}}
40 |
41 | # Call the score_model_on_all_metrics function
42 | metrics_dict, channels = score_model_on_all_metrics(
43 | model=model,
44 | valid_dataset=valid_dataset,
45 | batch_size=2,
46 | num_workers=0,
47 | batch_limit=3,
48 | metric_names=metric_names,
49 | metric_kwargs=metric_kwargs,
50 | )
51 |
52 | # Check all the expected keys are there
53 | assert tuple(metrics_dict.keys()) == metric_names
54 |
55 | for metric_name, metric_array in metrics_dict.items():
56 | # check all the items have the expected shape
57 | assert metric_array.shape == (
58 | len(channels),
59 | NUM_FORECAST_STEPS,
60 | ), f"Metric {metric_name} has the wrong shape"
61 |
62 | assert not np.any(np.isnan(metric_array)), f"Metric '{metric_name}' is predicting NaNs!"
63 |
64 |
65 | def test_calc_mean_metrics():
66 | # Create a test dictionary of metrics (channels, time)
67 | test_metrics_dict = {
68 | "mae": np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]),
69 | "mse": np.array([[4.0, 5.0, 6.0], [4.0, 5.0, 6.0]]),
70 | }
71 |
72 | # Call the calc_mean_metrics function
73 | mean_metrics_dict = calc_mean_metrics(test_metrics_dict)
74 |
75 | # Check the expected keys are present
76 | assert mean_metrics_dict.keys() == {"mae", "mse"}
77 |
78 | # Check the expected values are present
79 | assert mean_metrics_dict["mae"] == 2
80 | assert mean_metrics_dict["mse"] == 5
81 |
82 |
83 | def test_validate(model, val_sat_zarr_path, mocker):
84 | # Mock the wandb functions so they aren't run in testing
85 | mocker.patch("wandb.login")
86 | mocker.patch("wandb.init")
87 | mocker.patch("wandb.config")
88 | mocker.patch("wandb.log")
89 | mocker.patch("wandb.plot.line")
90 | mocker.patch("wandb.plot.bar")
91 |
92 | validate(
93 | model=model,
94 | data_path=val_sat_zarr_path,
95 | wandb_project_name="cloudcasting-pytest",
96 | wandb_run_name="test_validate",
97 | nan_to_num=False,
98 | batch_size=2,
99 | num_workers=0,
100 | batch_limit=4,
101 | )
102 |
103 |
104 | def test_validate_cli(val_sat_zarr_path, mocker):
105 | # write out an example model.py file
106 | with open("model.py", "w") as f:
107 | f.write(
108 | """
109 | from cloudcasting.models import AbstractModel
110 | from cloudcasting.constants import NUM_FORECAST_STEPS
111 | import numpy as np
112 |
113 | class Model(AbstractModel):
114 | def __init__(self, history_steps: int, sigma: float) -> None:
115 | super().__init__(history_steps)
116 | self.sigma = sigma
117 |
118 | def forward(self, X):
119 | shape = X.shape
120 | return np.ones((shape[0], shape[1], NUM_FORECAST_STEPS, shape[3], shape[4]))
121 |
122 | def hyperparameters_dict(self):
123 | return {"sigma": self.sigma}
124 | """
125 | )
126 |
127 | # write out an example validate_config.yml file
128 | with open("validate_config.yml", "w") as f:
129 | f.write(
130 | f"""
131 | model:
132 | name: Model
133 | params:
134 | history_steps: 1
135 | sigma: 0.1
136 | validation:
137 | data_path: {val_sat_zarr_path}
138 | wandb_project_name: cloudcasting-pytest
139 | wandb_run_name: test_validate
140 | nan_to_num: False
141 | batch_size: 2
142 | num_workers: 0
143 | batch_limit: 4
144 | """
145 | )
146 |
147 | # Mock the wandb functions so they aren't run in testing
148 | mocker.patch("wandb.login")
149 | mocker.patch("wandb.init")
150 | mocker.patch("wandb.config")
151 | mocker.patch("wandb.log")
152 | mocker.patch("wandb.plot.line")
153 | mocker.patch("wandb.plot.bar")
154 |
155 | # run the validate_from_config function
156 | validate_from_config()
157 |
--------------------------------------------------------------------------------
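The assertions in test_calc_mean_metrics above pin down the expected behaviour: each [channels, time] metric array collapses to its overall mean. A one-line sketch of that reduction, not the implementation in cloudcasting.validation:

import numpy as np

test_metrics_dict = {
    "mae": np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]),
    "mse": np.array([[4.0, 5.0, 6.0], [4.0, 5.0, 6.0]]),
}
mean_metrics = {name: float(arr.mean()) for name, arr in test_metrics_dict.items()}
assert mean_metrics == {"mae": 2.0, "mse": 5.0}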