├── .all-contributorsrc ├── .github └── workflows │ ├── branch_ci.yml │ ├── merged_ci.yml │ ├── pull_ci.yml │ └── tagged_ci.yml ├── .gitignore ├── LICENSE ├── README.md ├── configs.example ├── callbacks │ └── default.yaml ├── config.yaml ├── datamodule │ ├── configuration │ │ └── example_configuration.yaml │ └── streamed_samples.yaml ├── experiment │ └── example_simple.yaml ├── hydra │ └── default.yaml ├── logger │ └── wandb.yaml ├── model │ └── late_fusion.yaml ├── readme.md └── trainer │ ├── all_params.yaml │ └── default.yaml ├── pvnet ├── __init__.py ├── datamodule.py ├── load_model.py ├── model_cards │ └── empty_model_card_template.md ├── models │ ├── __init__.py │ ├── base_model.py │ ├── ensemble.py │ └── late_fusion │ │ ├── README.md │ │ ├── __init__.py │ │ ├── basic_blocks.py │ │ ├── encoders │ │ ├── __init__.py │ │ ├── basic_blocks.py │ │ └── encoders3d.py │ │ ├── late_fusion.py │ │ ├── linear_networks │ │ ├── __init__.py │ │ ├── basic_blocks.py │ │ └── networks.py │ │ └── site_encoders │ │ ├── __init__.py │ │ ├── basic_blocks.py │ │ └── encoders.py ├── optimizers.py ├── training │ ├── __init__.py │ ├── lightning_module.py │ ├── plots.py │ └── train.py └── utils.py ├── pyproject.toml ├── run.py ├── scripts ├── backtest_sites.py ├── backtest_uk_gsp.py ├── checkpoint_to_huggingface.py ├── mae_analysis.py └── migrate_old_model.py └── tests ├── __init__.py ├── conftest.py ├── models ├── late_fusion │ ├── encoders │ │ └── test_encoders3d.py │ ├── linear_networks │ │ └── test_networks.py │ ├── site_encoders │ │ └── test_encoders.py │ ├── test_late_fusion.py │ └── test_save_load_pretrained.py ├── test_ensemble.py └── test_validation.py ├── test_data ├── site_data_config.yaml └── uk_data_config.yaml ├── test_datamodule.py └── test_end2end.py /.all-contributorsrc: -------------------------------------------------------------------------------- 1 | { 2 | "files": [ 3 | "README.md" 4 | ], 5 | "imageSize": 100, 6 | "commit": false, 7 | "commitType": "docs", 8 | 
"commitConvention": "angular", 9 | "contributors": [ 10 | { 11 | "login": "felix-e-h-p", 12 | "name": "Felix", 13 | "avatar_url": "https://avatars.githubusercontent.com/u/137530077?v=4", 14 | "profile": "https://github.com/felix-e-h-p", 15 | "contributions": [ 16 | "code" 17 | ] 18 | }, 19 | { 20 | "login": "Sukh-P", 21 | "name": "Sukhil Patel", 22 | "avatar_url": "https://avatars.githubusercontent.com/u/42407101?v=4", 23 | "profile": "https://github.com/Sukh-P", 24 | "contributions": [ 25 | "code" 26 | ] 27 | }, 28 | { 29 | "login": "dfulu", 30 | "name": "James Fulton", 31 | "avatar_url": "https://avatars.githubusercontent.com/u/41546094?v=4", 32 | "profile": "https://github.com/dfulu", 33 | "contributions": [ 34 | "code" 35 | ] 36 | }, 37 | { 38 | "login": "AUdaltsova", 39 | "name": "Alexandra Udaltsova", 40 | "avatar_url": "https://avatars.githubusercontent.com/u/43303448?v=4", 41 | "profile": "https://github.com/AUdaltsova", 42 | "contributions": [ 43 | "code", 44 | "review" 45 | ] 46 | }, 47 | { 48 | "login": "zakwatts", 49 | "name": "Megawattz", 50 | "avatar_url": "https://avatars.githubusercontent.com/u/47150349?v=4", 51 | "profile": "https://github.com/zakwatts", 52 | "contributions": [ 53 | "code" 54 | ] 55 | }, 56 | { 57 | "login": "peterdudfield", 58 | "name": "Peter Dudfield", 59 | "avatar_url": "https://avatars.githubusercontent.com/u/34686298?v=4", 60 | "profile": "https://github.com/peterdudfield", 61 | "contributions": [ 62 | "code" 63 | ] 64 | }, 65 | { 66 | "login": "mahdilamb", 67 | "name": "Mahdi Lamb", 68 | "avatar_url": "https://avatars.githubusercontent.com/u/4696915?v=4", 69 | "profile": "https://github.com/mahdilamb", 70 | "contributions": [ 71 | "infra" 72 | ] 73 | }, 74 | { 75 | "login": "jacobbieker", 76 | "name": "Jacob Prince-Bieker", 77 | "avatar_url": "https://avatars.githubusercontent.com/u/7170359?v=4", 78 | "profile": "https://www.jacobbieker.com", 79 | "contributions": [ 80 | "code" 81 | ] 82 | }, 83 | { 84 | "login": 
"codderrrrr", 85 | "name": "codderrrrr", 86 | "avatar_url": "https://avatars.githubusercontent.com/u/149995852?v=4", 87 | "profile": "https://github.com/codderrrrr", 88 | "contributions": [ 89 | "code" 90 | ] 91 | }, 92 | { 93 | "login": "confusedmatrix", 94 | "name": "Chris Briggs", 95 | "avatar_url": "https://avatars.githubusercontent.com/u/617309?v=4", 96 | "profile": "https://chrisxbriggs.com", 97 | "contributions": [ 98 | "code" 99 | ] 100 | }, 101 | { 102 | "login": "tmi", 103 | "name": "tmi", 104 | "avatar_url": "https://avatars.githubusercontent.com/u/147159?v=4", 105 | "profile": "https://github.com/tmi", 106 | "contributions": [ 107 | "code" 108 | ] 109 | }, 110 | { 111 | "login": "carderne", 112 | "name": "Chris Arderne", 113 | "avatar_url": "https://avatars.githubusercontent.com/u/19817302?v=4", 114 | "profile": "https://rdrn.me/", 115 | "contributions": [ 116 | "code" 117 | ] 118 | }, 119 | { 120 | "login": "Dakshbir", 121 | "name": "Dakshbir", 122 | "avatar_url": "https://avatars.githubusercontent.com/u/144359831?v=4", 123 | "profile": "https://github.com/Dakshbir", 124 | "contributions": [ 125 | "code" 126 | ] 127 | }, 128 | { 129 | "login": "MAYANK12SHARMA", 130 | "name": "MAYANK SHARMA", 131 | "avatar_url": "https://avatars.githubusercontent.com/u/145884197?v=4", 132 | "profile": "https://github.com/MAYANK12SHARMA", 133 | "contributions": [ 134 | "code" 135 | ] 136 | }, 137 | { 138 | "login": "lambaaryan011", 139 | "name": "aryan lamba ", 140 | "avatar_url": "https://avatars.githubusercontent.com/u/153702847?v=4", 141 | "profile": "https://github.com/lambaaryan011", 142 | "contributions": [ 143 | "code" 144 | ] 145 | }, 146 | { 147 | "login": "michael-gendy", 148 | "name": "michael-gendy", 149 | "avatar_url": "https://avatars.githubusercontent.com/u/64384201?v=4", 150 | "profile": "https://github.com/michael-gendy", 151 | "contributions": [ 152 | "code" 153 | ] 154 | }, 155 | { 156 | "login": "adityasuthar", 157 | "name": "Aditya Suthar", 158 | 
"avatar_url": "https://avatars.githubusercontent.com/u/95685363?v=4", 159 | "profile": "https://adityasuthar.github.io/", 160 | "contributions": [ 161 | "code" 162 | ] 163 | }, 164 | { 165 | "login": "markus-kreft", 166 | "name": "Markus Kreft", 167 | "avatar_url": "https://avatars.githubusercontent.com/u/129367085?v=4", 168 | "profile": "https://github.com/markus-kreft", 169 | "contributions": [ 170 | "code" 171 | ] 172 | }, 173 | { 174 | "login": "JackKelly", 175 | "name": "Jack Kelly", 176 | "avatar_url": "https://avatars.githubusercontent.com/u/460756?v=4", 177 | "profile": "http://jack-kelly.com", 178 | "contributions": [ 179 | "ideas" 180 | ] 181 | }, 182 | { 183 | "login": "zaryab-ali", 184 | "name": "zaryab-ali", 185 | "avatar_url": "https://avatars.githubusercontent.com/u/85732412?v=4", 186 | "profile": "https://github.com/zaryab-ali", 187 | "contributions": [ 188 | "code" 189 | ] 190 | }, 191 | { 192 | "login": "Lex-Ashu", 193 | "name": "Lex-Ashu", 194 | "avatar_url": "https://avatars.githubusercontent.com/u/181084934?v=4", 195 | "profile": "https://github.com/Lex-Ashu", 196 | "contributions": [ 197 | "code" 198 | ] 199 | } 200 | ], 201 | "contributorsPerLine": 7, 202 | "skipCi": true, 203 | "repoType": "github", 204 | "repoHost": "https://github.com", 205 | "projectName": "pvnet", 206 | "projectOwner": "openclimatefix" 207 | } 208 | -------------------------------------------------------------------------------- /.github/workflows/branch_ci.yml: -------------------------------------------------------------------------------- 1 | name: Branch CI (Python) 2 | run-name: 'Test branch commit "${{ github.event.head_commit.message }}"' 3 | 4 | on: 5 | push: 6 | branches-ignore: [ "main" ] 7 | paths-ignore: ['README.md'] 8 | 9 | jobs: 10 | branch-ci: 11 | uses: openclimatefix/.github/.github/workflows/branch_ci.yml@main 12 | secrets: inherit 13 | with: 14 | enable_linting: true 15 | enable_typechecking: false 16 | containerfile: 'None' 17 | tests_folder: 'tests' 
18 | tests_matrix: true 19 | test_python_versions: '["3.11", "3.12"]' 20 | -------------------------------------------------------------------------------- /.github/workflows/merged_ci.yml: -------------------------------------------------------------------------------- 1 | name: Merged CI 2 | run-name: 'Bump tag with merge #${{ github.event.number }} "${{ github.event.pull_request.title }}"' 3 | 4 | on: 5 | pull_request_target: 6 | types: ["closed"] 7 | branches: [ "main" ] 8 | 9 | jobs: 10 | bump-tag: 11 | uses: openclimatefix/.github/.github/workflows/bump_tag.yml@main 12 | secrets: inherit 13 | -------------------------------------------------------------------------------- /.github/workflows/pull_ci.yml: -------------------------------------------------------------------------------- 1 | name: Pull CI (Python) 2 | run-name: 'Test PR edit #${{ github.event.number }} "${{ github.event.pull_request.title }}"' 3 | 4 | on: 5 | pull_request: 6 | paths-ignore: ['README.md'] 7 | 8 | jobs: 9 | 10 | pull-ci: 11 | uses: openclimatefix/.github/.github/workflows/branch_ci.yml@main 12 | if: ${{ github.event.pull_request.head.repo.fork }} 13 | with: 14 | enable_linting: true 15 | enable_typechecking: false 16 | containerfile: 'None' 17 | tests_folder: 'tests' 18 | tests_matrix: true 19 | test_python_versions: '["3.11", "3.12"]' 20 | -------------------------------------------------------------------------------- /.github/workflows/tagged_ci.yml: -------------------------------------------------------------------------------- 1 | name: Tagged CI 2 | run-name: 'Tagged CI for ${{ github.ref_name }} by ${{ github.actor }}' 3 | 4 | on: 5 | push: 6 | tags: ["v*.*.*"] 7 | 8 | jobs: 9 | tagged-ci: 10 | uses: openclimatefix/.github/.github/workflows/tagged_ci.yml@main 11 | secrets: inherit 12 | with: 13 | containerfile: 'None' 14 | enable_pypi: true 15 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Custom 2 | config_tree.txt 3 | configs/ 4 | lightning_logs/ 5 | logs/ 6 | output/ 7 | checkpoints* 8 | csv/ 9 | notebooks/ 10 | *.html 11 | *.csv 12 | latest_logged_train_batch.png 13 | 14 | # Ignore all model cards... 15 | pvnet/model_cards/* 16 | 17 | # ...except for the empty template. 18 | !pvnet/model_cards/empty_model_card_template.md 19 | 20 | # Byte-compiled / optimized / DLL files 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | .eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | wheels/ 42 | pip-wheel-metadata/ 43 | share/python-wheels/ 44 | *.egg-info/ 45 | .installed.cfg 46 | *.egg 47 | MANIFEST 48 | 49 | # PyInstaller 50 | # Usually these files are written by a python script from a template 51 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
52 | *.manifest 53 | *.spec 54 | 55 | # Installer logs 56 | pip-log.txt 57 | pip-delete-this-directory.txt 58 | 59 | # Unit test / coverage reports 60 | htmlcov/ 61 | .tox/ 62 | .nox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *.cover 69 | *.py,cover 70 | .hypothesis/ 71 | .pytest_cache/ 72 | 73 | # Translations 74 | *.mo 75 | *.pot 76 | 77 | # Django stuff: 78 | *.log 79 | local_settings.py 80 | db.sqlite3 81 | db.sqlite3-journal 82 | 83 | # Flask stuff: 84 | instance/ 85 | .webassets-cache 86 | 87 | # Scrapy stuff: 88 | .scrapy 89 | 90 | # Sphinx documentation 91 | docs/_build/ 92 | 93 | # PyBuilder 94 | target/ 95 | 96 | # Jupyter Notebook 97 | .ipynb_checkpoints 98 | 99 | # IPython 100 | profile_default/ 101 | ipython_config.py 102 | 103 | # pyenv 104 | .python-version 105 | 106 | # pipenv 107 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 108 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 109 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 110 | # install all needed dependencies. 111 | #Pipfile.lock 112 | 113 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | .DS_Store 150 | 151 | # vim 152 | *swp 153 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Open Climate Fix Ltd 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PVNet 2 | 3 | [![All Contributors](https://img.shields.io/badge/all_contributors-21-orange.svg?style=flat-square)](#contributors-) 4 | 5 | 6 | [![tags badge](https://img.shields.io/github/v/tag/openclimatefix/PVNet?include_prereleases&sort=semver&color=FFAC5F)](https://github.com/openclimatefix/PVNet/tags) 7 | [![ease of contribution: hard](https://img.shields.io/badge/ease%20of%20contribution:%20hard-bb2629)](https://github.com/openclimatefix/ocf-meta-repo?tab=readme-ov-file#overview-of-ocfs-nowcasting-repositories) 8 | 9 | 10 | This project is used for training PVNet and running PVNet on live data. 11 | 12 | PVNet is a multi-modal late-fusion model for predicting renewable energy generation from weather 13 | data. The NWP (Numerical Weather Prediction) and satellite data are sent through a neural network 14 | which encodes them down to 1D intermediate representations. These are concatenated together with 15 | recent generation, the calculated solar coordinates (azimuth and elevation) and the location ID 16 | which has been put through an embedding layer. This 1D concatenated feature vector is put through 17 | an output network which outputs predictions of the future energy yield. 18 | 19 | 20 | ## Experiments 21 | 22 | Our paper based on this repo was accepted into the Tackling Climate Change with Machine Learning 23 | workshop at ICLR 2024 and can be viewed [here](https://www.climatechange.ai/papers/iclr2024/46). 24 | 25 | Some more structured notes on experiments we have performed with PVNet are 26 | [here](https://docs.google.com/document/d/1VumDwWd8YAfvXbOtJEv3ZJm_FHQDzrKXR0jU9vnvGQg). 27 | 28 | 29 | ## Setup / Installation 30 | 31 | ```bash 32 | git clone git@github.com:openclimatefix/PVNet.git 33 | cd PVNet 34 | pip install . 
35 | ``` 36 | 37 | The commit history is extensive. To save download time, use a depth of 1: 38 | ```bash 39 | git clone --depth 1 git@github.com:openclimatefix/PVNet.git 40 | ``` 41 | This means only the latest commit and its associated files will be downloaded. 42 | 43 | Next, in the PVNet repo, install PVNet as an editable package: 44 | 45 | ```bash 46 | pip install -e . 47 | ``` 48 | 49 | ### Additional development dependencies 50 | 51 | ```bash 52 | pip install ".[dev]" 53 | ``` 54 | 55 | 56 | 57 | ## Getting started with running PVNet 58 | 59 | Before running any code in PVNet, copy the example configuration to a 60 | configs directory: 61 | 62 | ``` 63 | cp -r configs.example configs 64 | ``` 65 | 66 | You will be making local amendments to these configs. See the README in 67 | `configs.example` for more info. 68 | 69 | ### Datasets 70 | 71 | As a minimum, in order to create samples of data/run PVNet, you will need to 72 | supply paths to NWP and GSP data. PV data can also be used. 
We list some 73 | suggested locations for downloading such datasets below: 74 | 75 | **GSP (Grid Supply Point)** - Regional PV generation data\ 76 | The University of Sheffield provides API access to download this data: 77 | https://www.solar.sheffield.ac.uk/api/ 78 | 79 | Documentation for querying generation data aggregated by GSP region can be found 80 | here: 81 | https://docs.google.com/document/d/e/2PACX-1vSDFb-6dJ2kIFZnsl-pBQvcH4inNQCA4lYL9cwo80bEHQeTK8fONLOgDf6Wm4ze_fxonqK3EVBVoAIz/pub#h.9d97iox3wzmd 82 | 83 | **NWP (Numerical weather predictions)**\ 84 | OCF maintains a Zarr formatted version of the German Weather Service's (DWD) 85 | ICON-EU NWP model here: 86 | https://huggingface.co/datasets/openclimatefix/dwd-icon-eu which includes the UK 87 | 88 | **PV**\ 89 | OCF maintains a dataset of PV generation from 1311 private PV installations 90 | here: https://huggingface.co/datasets/openclimatefix/uk_pv 91 | 92 | 93 | ### Connecting with ocf-data-sampler for sample creation 94 | 95 | Outside the PVNet repo, clone the ocf-data-sampler repo and exit the conda env created for PVNet: https://github.com/openclimatefix/ocf-data-sampler 96 | ```bash 97 | git clone git@github.com:openclimatefix/ocf-data-sampler.git 98 | conda create -n ocf-data-sampler python=3.11 99 | ``` 100 | 101 | Then activate the new environment (`conda activate ocf-data-sampler`), go inside the ocf-data-sampler repo and install its packages: 102 | 103 | ```bash 104 | pip install . 105 | ``` 106 | 107 | Then exit this environment, re-enter the pvnet conda environment and install ocf-data-sampler in editable mode (-e), passing the path to your local clone. This means the package is directly linked to the source code in the ocf-data-sampler repo. 108 | 109 | ```bash 110 | pip install -e /path/to/ocf-data-sampler 111 | ``` 112 | 113 | If you install a local version of `ocf-data-sampler` that is more recent than the version 114 | specified in `PVNet`, it is not guaranteed to function properly with this library.
115 | 116 | 117 | ### Set up and config example for streaming 118 | 119 | We will use the following example config file to describe your data sources: `/PVNet/configs/datamodule/configuration/example_configuration.yaml`. Ensure that the file paths are set to the correct locations in `example_configuration.yaml`: search for `PLACEHOLDER` to find where to input the location of the files. Delete or comment out the parts for data you are not using. 120 | 121 | At run time, the datamodule config `PVNet/configs/datamodule/streamed_samples.yaml` points to your chosen configuration file: 122 | 123 | configuration: "/FULL-PATH-TO-REPO/PVNet/configs/datamodule/configuration/example_configuration.yaml" 124 | 125 | You can also update train/val/test time ranges here to match the period you have access to. 126 | 127 | If downloading private data from a GCP bucket, make sure to authenticate gcloud (the public satellite data does not need authentication): 128 | 129 | gcloud auth login 130 | 131 | You can provide multiple storage locations as a list. For example: 132 | 133 | satellite: 134 | zarr_path: 135 | - "gs://public-datasets-eumetsat-solar-forecasting/satellite/EUMETSAT/SEVIRI_RSS/v4/2020_nonhrv.zarr" 136 | - "gs://public-datasets-eumetsat-solar-forecasting/satellite/EUMETSAT/SEVIRI_RSS/v4/2021_nonhrv.zarr" 137 | 138 | `ocf-data-sampler` is currently set up to use 11 channels from the satellite data (the 12th, HRV, is not used). 139 | 140 | ⚠️ NB: Our publicly accessible satellite data is currently saved with a blosc2 compressor, which is not supported by the tensorstore backend PVNet relies on now. We are in the process of updating this; for now, the paths above cannot be used with this codebase. 141 | 142 | ### Training PVNet 143 | 144 | How PVNet is run is determined by the configuration files. The example configs in `PVNet/configs.example` work with **streamed_samples** using `datamodule/streamed_samples.yaml`. 145 | 146 | Update the following before training: 147 | 148 | 1.
In `configs/model/late_fusion.yaml`: 149 | - Update the list of encoders to match the data sources you are using. For different NWP sources, keep the same structure but ensure: 150 | - `in_channels`: the number of variables your NWP source supplies 151 | - `image_size_pixels`: spatial crop matching your NWP resolution and the settings in your datamodule configuration (unless you coarsened, e.g. for ECMWF) 152 | 2. In `configs/trainer/default.yaml`: 153 | - Set `accelerator: cpu` if running on a system without a supported GPU 154 | 3. In `configs/datamodule/streamed_samples.yaml`: 155 | - Point `configuration:` to your local `example_configuration.yaml` (or your custom one) 156 | - Adjust the train/val/test time ranges to your available data 157 | 158 | If you create custom config files, update the main `./configs/config.yaml` defaults: 159 | 160 | defaults: 161 | - trainer: default.yaml 162 | - model: late_fusion.yaml 163 | - datamodule: streamed_samples.yaml 164 | - callbacks: null 165 | - experiment: null 166 | - hparams_search: null 167 | - hydra: default.yaml 168 | 169 | Now train PVNet: 170 | 171 | python run.py 172 | 173 | You can override any setting with Hydra, e.g.: 174 | 175 | python run.py datamodule=streamed_samples datamodule.configuration="/FULL-PATH/PVNet/configs/datamodule/configuration/example_configuration.yaml" 176 | 177 | ## Backtest 178 | 179 | If you have successfully trained a PVNet model and have a saved model checkpoint, you can use it to create a backtest, i.e. forecasts on historical data to evaluate forecast accuracy/skill. This can be done by running one of the scripts in this repo, such as [the UK GSP backtest script](scripts/backtest_uk_gsp.py) or [the PV site backtest script](scripts/backtest_sites.py); further info on how to run these is in each backtest file.
180 | 181 | ## Testing 182 | 183 | You can use `python -m pytest tests` to run tests 184 | 185 | ## Contributors ✨ 186 | 187 | Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)): 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 |
Felix
Felix

💻
Sukhil Patel
Sukhil Patel

💻
James Fulton
James Fulton

💻
Alexandra Udaltsova
Alexandra Udaltsova

💻 👀
Megawattz
Megawattz

💻
Peter Dudfield
Peter Dudfield

💻
Mahdi Lamb
Mahdi Lamb

🚇
Jacob Prince-Bieker
Jacob Prince-Bieker

💻
codderrrrr
codderrrrr

💻
Chris Briggs
Chris Briggs

💻
tmi
tmi

💻
Chris Arderne
Chris Arderne

💻
Dakshbir
Dakshbir

💻
MAYANK SHARMA
MAYANK SHARMA

💻
aryan lamba
aryan lamba

💻
michael-gendy
michael-gendy

💻
Aditya Suthar
Aditya Suthar

💻
Markus Kreft
Markus Kreft

💻
Jack Kelly
Jack Kelly

🤔
zaryab-ali
zaryab-ali

💻
Lex-Ashu
Lex-Ashu

💻
223 | 224 | 225 | 226 | 227 | 228 | 229 | This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome! 230 | -------------------------------------------------------------------------------- /configs.example/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | early_stopping: 2 | _target_: lightning.pytorch.callbacks.EarlyStopping 3 | # name of the logged metric which determines when model is improving 4 | monitor: "${resolve_monitor_loss:${model.model.output_quantiles}}" 5 | mode: "min" # can be "max" or "min" 6 | patience: 10 # how many epochs (or val check periods) of not improving until training stops 7 | min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement 8 | 9 | learning_rate_monitor: 10 | _target_: lightning.pytorch.callbacks.LearningRateMonitor 11 | logging_interval: "epoch" 12 | 13 | model_summary: 14 | _target_: lightning.pytorch.callbacks.ModelSummary 15 | max_depth: 3 16 | 17 | model_checkpoint: 18 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 19 | # name of the logged metric which determines when model is improving 20 | monitor: "${resolve_monitor_loss:${model.model.output_quantiles}}" 21 | mode: "min" # can be "max" or "min" 22 | save_top_k: 1 # save k best models (determined by above metric) 23 | save_last: True # additionaly always save model from last epoch 24 | every_n_epochs: 1 25 | verbose: False 26 | filename: "epoch={epoch}-step={step}" 27 | # The path to where the model checkpoints will be stored 28 | dirpath: "PLACEHOLDER/${model_name}" #${..model_name} 29 | auto_insert_metric_name: False 30 | save_on_train_epoch_end: False 31 | -------------------------------------------------------------------------------- /configs.example/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # 
specify here default training configuration 4 | defaults: 5 | - _self_ 6 | - trainer: default.yaml 7 | - model: late_fusion.yaml 8 | - datamodule: streamed_samples.yaml 9 | - callbacks: default.yaml # set this to null if you don't want to use callbacks 10 | - logger: wandb.yaml # set logger here or use command line (e.g. `python run.py logger=wandb`) 11 | - experiment: null 12 | - hparams_search: null 13 | - hydra: default.yaml 14 | 15 | renewable: "pv_uk" 16 | 17 | # enable color logging 18 | # - override hydra/hydra_logging: colorlog 19 | # - override hydra/job_logging: colorlog 20 | 21 | # path to original working directory 22 | # hydra hijacks working directory by changing it to the current log directory, 23 | # so it's useful to have this path as a special variable 24 | # learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 25 | work_dir: ${hydra:runtime.cwd} 26 | 27 | model_name: "default" 28 | 29 | seed: 2727831 30 | -------------------------------------------------------------------------------- /configs.example/datamodule/configuration/example_configuration.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | description: Example config for producing PVNet samples 3 | name: example_config 4 | 5 | input_data: 6 | 7 | # Either use Site OR GSP configuration 8 | site: 9 | # Path to Site data in NetCDF format 10 | file_path: PLACEHOLDER.nc 11 | # Path to metadata in CSV format 12 | metadata_file_path: PLACEHOLDER.csv 13 | time_resolution_minutes: 15 14 | interval_start_minutes: -60 15 | # Specified for intraday currently 16 | interval_end_minutes: 480 17 | dropout_timedeltas_minutes: [] 18 | dropout_fraction: 0 # Fraction of samples with dropout 19 | 20 | gsp: 21 | # Path to GSP data in zarr format 22 | # e.g. 
gs://solar-pv-nowcasting-data/PV/GSP/v7/pv_gsp.zarr 23 | zarr_path: PLACEHOLDER.zarr 24 | interval_start_minutes: -60 25 | # Specified for intraday currently 26 | interval_end_minutes: 480 27 | time_resolution_minutes: 30 28 | # Random value from the list below will be chosen as the delay when dropout is used 29 | # If set to null no dropout is applied. Only values before t0 are dropped out for GSP. 30 | # Values after t0 are assumed as targets and cannot be dropped. 31 | dropout_timedeltas_minutes: [] 32 | dropout_fraction: 0 # Fraction of samples with dropout 33 | 34 | nwp: 35 | 36 | ecmwf: 37 | provider: ecmwf 38 | # Path to ECMWF NWP data in zarr format 39 | # n.b. It is not necessary to use multiple or any NWP data. These entries can be removed 40 | zarr_path: PLACEHOLDER.zarr 41 | interval_start_minutes: -60 42 | # Specified for intraday currently 43 | interval_end_minutes: 480 44 | time_resolution_minutes: 60 45 | channels: 46 | - t2m # 2-metre temperature 47 | - dswrf # downwards short-wave radiation flux 48 | - dlwrf # downwards long-wave radiation flux 49 | - hcc # high cloud cover 50 | - mcc # medium cloud cover 51 | - lcc # low cloud cover 52 | - tcc # total cloud cover 53 | - sde # snow depth water equivalent 54 | - sr # direct solar radiation 55 | - duvrs # downwards UV radiation at surface 56 | - prate # precipitation rate 57 | - u10 # 10-metre U component of wind speed 58 | - u100 # 100-metre U component of wind speed 59 | - u200 # 200-metre U component of wind speed 60 | - v10 # 10-metre V component of wind speed 61 | - v100 # 100-metre V component of wind speed 62 | - v200 # 200-metre V component of wind speed 63 | # The following channels are accumulated and need to be diffed 64 | accum_channels: 65 | - dswrf # downwards short-wave radiation flux 66 | - dlwrf # downwards long-wave radiation flux 67 | - sr # direct solar radiation 68 | - duvrs # downwards UV radiation at surface 69 | image_size_pixels_height: 24 70 | image_size_pixels_width: 24 71 
| dropout_timedeltas_minutes: [-360] 72 | dropout_fraction: 1.0 # Fraction of samples with dropout 73 | max_staleness_minutes: null 74 | normalisation_constants: 75 | t2m: 76 | mean: 283.48333740234375 77 | std: 3.692270040512085 78 | dswrf: 79 | mean: 11458988.0 80 | std: 13025427.0 81 | dlwrf: 82 | mean: 27187026.0 83 | std: 15855867.0 84 | hcc: 85 | mean: 0.3961029052734375 86 | std: 0.42244860529899597 87 | mcc: 88 | mean: 0.3288780450820923 89 | std: 0.38039860129356384 90 | lcc: 91 | mean: 0.44901806116104126 92 | std: 0.3791404366493225 93 | tcc: 94 | mean: 0.7049227356910706 95 | std: 0.37487083673477173 96 | sde: 97 | mean: 8.107526082312688e-05 98 | std: 0.000913831521756947 # Mapped from "sd" in the Python file 99 | sr: 100 | mean: 12905302.0 101 | std: 16294988.0 102 | duvrs: 103 | mean: 1305651.25 104 | std: 1445635.25 105 | prate: 106 | mean: 3.108070450252853e-05 107 | std: 9.81039775069803e-05 108 | u10: 109 | mean: 1.7677178382873535 110 | std: 5.531515598297119 111 | u100: 112 | mean: 2.393547296524048 113 | std: 7.2320556640625 114 | u200: 115 | mean: 2.7963004112243652 116 | std: 8.049470901489258 117 | v10: 118 | mean: 0.985887885093689 119 | std: 5.411230564117432 120 | v100: 121 | mean: 1.4244288206100464 122 | std: 6.944501876831055 123 | v200: 124 | mean: 1.6010299921035767 125 | std: 7.561611652374268 126 | # Added diff_ keys for the channels under accum_channels: 127 | diff_dlwrf: 128 | mean: 1136464.0 129 | std: 131942.03125 130 | diff_dswrf: 131 | mean: 420584.6875 132 | std: 715366.3125 133 | diff_duvrs: 134 | mean: 48265.4765625 135 | std: 81605.25 136 | diff_sr: 137 | mean: 469169.5 138 | std: 818950.6875 139 | 140 | ukv: 141 | provider: ukv 142 | # Path to UKV NWP data in zarr format 143 | # e.g. gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV_intermediate_version_7.zarr 144 | # n.b. It is not necessary to use multiple or any NWP data. 
These entries can be removed 145 | zarr_path: PLACEHOLDER.zarr 146 | interval_start_minutes: -60 147 | # Specified for intraday currently 148 | interval_end_minutes: 480 149 | time_resolution_minutes: 60 150 | channels: 151 | - t # 2-metre temperature 152 | - dswrf # downwards short-wave radiation flux 153 | - dlwrf # downwards long-wave radiation flux 154 | - hcc # high cloud cover 155 | - mcc # medium cloud cover 156 | - lcc # low cloud cover 157 | - sde # snow depth water equivalent 158 | - r # relative humidity 159 | - vis # visibility 160 | - si10 # 10-metre wind speed 161 | - wdir10 # 10-metre wind direction 162 | - prate # precipitation rate 163 | # These variables exist in CEDA training data but not in the live Met Office service 164 | - hcct # height of convective cloud top, meters above surface. NaN if no clouds 165 | - cdcb # height of lowest cloud base > 3 oktas 166 | - dpt # dew point temperature 167 | - prmsl # mean sea level pressure 168 | - h # geometrical? (maybe geopotential?)
height 169 | image_size_pixels_height: 24 170 | image_size_pixels_width: 24 171 | dropout_timedeltas_minutes: [-360] 172 | dropout_fraction: 1.0 # Fraction of samples with dropout 173 | max_staleness_minutes: null 174 | normalisation_constants: 175 | t: 176 | mean: 283.64913206 177 | std: 4.38818501 178 | dswrf: 179 | mean: 111.28265039 180 | std: 190.47216887 181 | dlwrf: 182 | mean: 325.03130139 183 | std: 39.45988077 184 | hcc: 185 | mean: 29.11949682 186 | std: 38.07184418 187 | mcc: 188 | mean: 40.88984494 189 | std: 41.91144559 190 | lcc: 191 | mean: 50.08362643 192 | std: 39.33210726 193 | sde: 194 | mean: 0.00289545 195 | std: 0.1029753 196 | r: 197 | mean: 81.79229501 198 | std: 11.45012499 199 | vis: 200 | mean: 32262.03285118 201 | std: 21578.97975625 202 | si10: 203 | mean: 6.88348448 204 | std: 3.94718813 205 | wdir10: 206 | mean: 199.41891636 207 | std: 94.08407495 208 | prate: 209 | mean: 3.45793433e-05 210 | std: 0.00021497 211 | hcct: 212 | mean: -18345.97478167 213 | std: 18382.63958991 214 | cdcb: 215 | mean: 1412.26599062 216 | std: 2126.99350113 217 | dpt: 218 | mean: 280.54379901 219 | std: 4.57250482 220 | prmsl: 221 | mean: 101321.61574029 222 | std: 1252.71790539 223 | h: 224 | mean: 2096.51991356 225 | std: 1075.77812282 226 | 227 | satellite: 228 | # Path to Satellite data (non-HRV) in zarr format 229 | # e.g. 
gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/v4/2020_nonhrv.zarr 230 | zarr_path: PLACEHOLDER.zarr 231 | interval_start_minutes: -30 232 | interval_end_minutes: 0 233 | time_resolution_minutes: 5 234 | channels: 235 | - IR_016 # Surface, cloud phase 236 | - IR_039 # Surface, clouds, wind fields 237 | - IR_087 # Surface, clouds, atmospheric instability 238 | - IR_097 # Ozone 239 | - IR_108 # Surface, clouds, wind fields, atmospheric instability 240 | - IR_120 # Surface, clouds, atmospheric instability 241 | - IR_134 # Cirrus cloud height, atmospheric instability 242 | - VIS006 # Surface, clouds, wind fields 243 | - VIS008 # Surface, clouds, wind fields 244 | - WV_062 # Water vapor, high level clouds, upper air analysis 245 | - WV_073 # Water vapor, atmospheric instability, upper-level dynamics 246 | image_size_pixels_height: 24 247 | image_size_pixels_width: 24 248 | dropout_timedeltas_minutes: [] 249 | dropout_fraction: 0 # Fraction of samples with dropout 250 | normalisation_constants: 251 | IR_016: 252 | mean: 0.17594202 253 | std: 0.21462157 254 | IR_039: 255 | mean: 0.86167645 256 | std: 0.04618041 257 | IR_087: 258 | mean: 0.7719318 259 | std: 0.06687243 260 | IR_097: 261 | mean: 0.8014212 262 | std: 0.0468558 263 | IR_108: 264 | mean: 0.71254843 265 | std: 0.17482725 266 | IR_120: 267 | mean: 0.89058584 268 | std: 0.06115861 269 | IR_134: 270 | mean: 0.944365 271 | std: 0.04492306 272 | VIS006: 273 | mean: 0.09633306 274 | std: 0.12184761 275 | VIS008: 276 | mean: 0.11426069 277 | std: 0.13090034 278 | WV_062: 279 | mean: 0.7359355 280 | std: 0.16111417 281 | WV_073: 282 | mean: 0.62479186 283 | std: 0.12924142 284 | 285 | solar_position: 286 | interval_start_minutes: -60 287 | interval_end_minutes: 480 288 | time_resolution_minutes: 30 289 | -------------------------------------------------------------------------------- /configs.example/datamodule/streamed_samples.yaml: 
-------------------------------------------------------------------------------- 1 | _target_: pvnet.datamodule.UKRegionalDataModule 2 | # Path to the data configuration yaml file. You can find examples in the configuration subdirectory 3 | # in configs.example/datamodule/configuration 4 | # Use the full local path such as: /FULL/PATH/PVNet/configs/datamodule/configuration/example_configuration.yaml 5 | configuration: "PLACEHOLDER.yaml" 6 | num_workers: 20 7 | prefetch_factor: 2 8 | persistent_workers: false 9 | batch_size: 8 10 | 11 | train_period: 12 | - null 13 | - "2022-05-07" 14 | val_period: 15 | - "2022-05-08" 16 | - "2023-05-08" 17 | 18 | seed: "${seed}" 19 | 20 | # Setting the dataset pickle dir will speed up initialisation of multiple workers 21 | dataset_pickle_dir: null 22 | -------------------------------------------------------------------------------- /configs.example/experiment/example_simple.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=example_simple.yaml 5 | 6 | defaults: 7 | - override /trainer: default.yaml # choose trainer from 'configs/trainer/' 8 | - override /model: late_fusion.yaml 9 | - override /datamodule: streamed_samples.yaml 10 | - override /callbacks: default.yaml 11 | - override /logger: wandb.yaml 12 | - override /hydra: default.yaml 13 | 14 | # all parameters below will be merged with parameters from default configurations set above 15 | # this allows you to overwrite only specified parameters 16 | 17 | seed: 518 18 | 19 | trainer: 20 | min_epochs: 1 21 | max_epochs: 2 22 | 23 | datamodule: 24 | batch_size: 16 25 | -------------------------------------------------------------------------------- /configs.example/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # output paths for hydra logs 2 | run: 3 | # Local log directory for hydra 4 | dir:
PLACEHOLDER/runs/${now:%Y-%m-%d}/${now:%H-%M-%S} 5 | sweep: 6 | # Local log directory for hydra 7 | dir: PLACEHOLDER/multiruns/${now:%Y-%m-%d_%H-%M-%S} 8 | subdir: ${hydra.job.num} 9 | 10 | # you can set here environment variables that are universal for all users 11 | # for system specific variables (like data paths) it's better to use .env file! 12 | job: 13 | env_set: 14 | EXAMPLE_VAR: "example_value" 15 | -------------------------------------------------------------------------------- /configs.example/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: lightning.pytorch.loggers.wandb.WandbLogger 5 | # wandb project to log to 6 | project: "PLACEHOLDER" 7 | name: "${model_name}" 8 | # location to store the wandb local logs 9 | save_dir: "PLACEHOLDER" 10 | offline: False # set True to store all logs only locally 11 | id: null # pass correct id to resume experiment! 12 | # entity: "" # set to name of your wandb team or just remove it 13 | log_model: False 14 | prefix: "" 15 | job_type: "train" 16 | group: "" 17 | tags: [] 18 | -------------------------------------------------------------------------------- /configs.example/model/late_fusion.yaml: -------------------------------------------------------------------------------- 1 | _target_: pvnet.training.lightning_module.PVNetLightningModule 2 | 3 | model: 4 | _target_: pvnet.models.LateFusionModel 5 | output_quantiles: [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98] 6 | 7 | #-------------------------------------------- 8 | # NWP encoder 9 | #-------------------------------------------- 10 | 11 | nwp_encoders_dict: 12 | ukv: 13 | _target_: pvnet.models.late_fusion.encoders.encoders3d.DefaultPVNet 14 | _partial_: True 15 | in_channels: 2 16 | out_features: 256 17 | number_of_conv3d_layers: 6 18 | conv3d_channels: 32 19 | image_size_pixels: 24 20 | ecmwf: 21 | _target_: pvnet.models.late_fusion.encoders.encoders3d.DefaultPVNet 
22 | _partial_: True 23 | in_channels: 12 24 | out_features: 256 25 | number_of_conv3d_layers: 4 26 | conv3d_channels: 32 27 | image_size_pixels: 12 28 | 29 | #-------------------------------------------- 30 | # Sat encoder settings 31 | #-------------------------------------------- 32 | 33 | sat_encoder: 34 | _target_: pvnet.models.late_fusion.encoders.encoders3d.DefaultPVNet 35 | _partial_: True 36 | in_channels: 11 37 | out_features: 256 38 | number_of_conv3d_layers: 6 39 | conv3d_channels: 32 40 | image_size_pixels: 24 41 | 42 | add_image_embedding_channel: False 43 | 44 | #-------------------------------------------- 45 | # PV encoder settings 46 | #-------------------------------------------- 47 | 48 | pv_encoder: 49 | _target_: pvnet.models.late_fusion.site_encoders.encoders.SingleAttentionNetwork 50 | _partial_: True 51 | num_sites: 349 52 | out_features: 40 53 | num_heads: 4 54 | kdim: 40 55 | id_embed_dim: 20 56 | 57 | #-------------------------------------------- 58 | # Tabular network settings 59 | #-------------------------------------------- 60 | 61 | output_network: 62 | _target_: pvnet.models.late_fusion.linear_networks.networks.ResFCNet 63 | _partial_: True 64 | fc_hidden_features: 128 65 | n_res_blocks: 6 66 | res_block_layers: 2 67 | dropout_frac: 0.0 68 | 69 | embedding_dim: 16 70 | include_sun: True 71 | include_gsp_yield_history: False 72 | include_site_yield_history: False 73 | 74 | # The mapping between the location IDs and their embedding indices 75 | location_id_mapping: 76 | 1: 1 77 | 5: 2 78 | 110: 3 79 | # ... 
80 | 81 | #-------------------------------------------- 82 | # Times 83 | #-------------------------------------------- 84 | 85 | # Forecast and time settings 86 | forecast_minutes: 480 87 | history_minutes: 120 88 | 89 | min_sat_delay_minutes: 60 90 | 91 | # These must also be set even if identical to forecast_minutes and history_minutes 92 | sat_history_minutes: 90 93 | pv_history_minutes: 180 94 | 95 | # These must be set for each NWP encoder 96 | nwp_history_minutes: 97 | ukv: 120 98 | ecmwf: 120 99 | nwp_forecast_minutes: 100 | ukv: 480 101 | ecmwf: 480 102 | # Optional; defaults to 60, so must be set for data with different time resolution 103 | nwp_interval_minutes: 104 | ukv: 60 105 | ecmwf: 60 106 | 107 | # ---------------------------------------------- 108 | # Optimizer 109 | # ---------------------------------------------- 110 | optimizer: 111 | _target_: pvnet.optimizers.EmbAdamWReduceLROnPlateau 112 | lr: 0.0001 113 | weight_decay: 0.01 114 | amsgrad: True 115 | patience: 5 116 | factor: 0.1 117 | threshold: 0.002 118 | -------------------------------------------------------------------------------- /configs.example/readme.md: -------------------------------------------------------------------------------- 1 | This directory contains example configuration files for the PVNet project. Many paths will need to be unique to each user. You can find these paths by searching for PLACEHOLDER within these files. Not all of 2 | the values with a placeholder need to be set. For example, in the logger subdirectory there are many different loggers with PLACEHOLDERS. If only one logger is used, then only that placeholder needs to be set.
3 | 4 | run experiments by: 5 | `python run.py experiment=example_simple ` 6 | -------------------------------------------------------------------------------- /configs.example/trainer/all_params.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | # default values for all trainer parameters 4 | checkpoint_callback: True 5 | default_root_dir: null 6 | gradient_clip_val: 0.0 7 | process_position: 0 8 | num_nodes: 1 9 | num_processes: 1 10 | gpus: null 11 | auto_select_gpus: False 12 | tpu_cores: null 13 | log_gpu_memory: null 14 | overfit_batches: 0.0 15 | track_grad_norm: -1 16 | check_val_every_n_epoch: 1 17 | fast_dev_run: False 18 | accumulate_grad_batches: 1 19 | max_epochs: 1 20 | min_epochs: 1 21 | max_steps: null 22 | min_steps: null 23 | limit_train_batches: 1.0 24 | limit_val_batches: 1.0 25 | limit_test_batches: 1.0 26 | val_check_interval: 1.0 27 | flush_logs_every_n_steps: 100 28 | log_every_n_steps: 50 29 | accelerator: null 30 | sync_batchnorm: False 31 | precision: 32 32 | weights_save_path: null 33 | num_sanity_val_steps: 2 34 | truncated_bptt_steps: null 35 | resume_from_checkpoint: null 36 | profiler: null 37 | benchmark: False 38 | deterministic: False 39 | reload_dataloaders_every_epoch: False 40 | auto_lr_find: False 41 | replace_sampler_ddp: True 42 | terminate_on_nan: False 43 | auto_scale_batch_size: False 44 | prepare_data_per_node: True 45 | plugins: null 46 | amp_backend: "native" 47 | amp_level: "O2" 48 | move_metrics_to_cpu: False 49 | -------------------------------------------------------------------------------- /configs.example/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.Trainer 2 | 3 | # set `gpu` to train on GPU, `cpu` to train on CPU only 4 | accelerator: auto 5 | devices: auto 6 | 7 | min_epochs: null 8 | max_epochs: null 9 | reload_dataloaders_every_n_epochs: 0 10 | 
num_sanity_val_steps: 8 11 | fast_dev_run: false 12 | 13 | accumulate_grad_batches: 4 14 | log_every_n_steps: 50 15 | -------------------------------------------------------------------------------- /pvnet/__init__.py: -------------------------------------------------------------------------------- 1 | """PVNet source code.""" 2 | -------------------------------------------------------------------------------- /pvnet/datamodule.py: -------------------------------------------------------------------------------- 1 | """ Data module for pytorch lightning """ 2 | 3 | import os 4 | 5 | import numpy as np 6 | from lightning.pytorch import LightningDataModule 7 | from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch 8 | from ocf_data_sampler.numpy_sample.common_types import NumpySample, TensorBatch 9 | from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset 10 | from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset 11 | from ocf_data_sampler.torch_datasets.sample.base import batch_to_tensor 12 | from torch.utils.data import DataLoader, Dataset, Subset 13 | 14 | 15 | def collate_fn(samples: list[NumpySample]) -> TensorBatch: 16 | """Convert a list of NumpySample samples to a tensor batch""" 17 | return batch_to_tensor(stack_np_samples_into_batch(samples)) 18 | 19 | 20 | class BaseDataModule(LightningDataModule): 21 | """Base Datamodule which streams samples using a sampler from ocf-data-sampler.""" 22 | 23 | def __init__( 24 | self, 25 | configuration: str, 26 | batch_size: int = 16, 27 | num_workers: int = 0, 28 | prefetch_factor: int | None = None, 29 | persistent_workers: bool = False, 30 | pin_memory: bool = False, 31 | train_period: list[str | None] = [None, None], 32 | val_period: list[str | None] = [None, None], 33 | seed: int | None = None, 34 | dataset_pickle_dir: str | None = None, 35 | ): 36 | """Base Datamodule for streaming samples. 
37 | 38 | Args: 39 | configuration: Path to ocf-data-sampler configuration file. 40 | batch_size: Batch size. 41 | num_workers: Number of workers to use in multiprocess batch loading. 42 | prefetch_factor: Number of batches loaded in advance by each worker. 43 | persistent_workers: If True, the data loader will not shut down the worker processes 44 | after a dataset has been consumed once. This allows to maintain the workers Dataset 45 | instances alive. 46 | pin_memory: If True, the data loader will copy Tensors into device/CUDA pinned memory 47 | before returning them. 48 | train_period: Date range filter for train dataloader. 49 | val_period: Date range filter for val dataloader. 50 | seed: Random seed used in shuffling datasets. 51 | dataset_pickle_dir: Directory in which the val and train set will be presaved as 52 | pickle objects. Setting this speeds up instantiation of multiple workers a lot. 53 | """ 54 | super().__init__() 55 | 56 | self.configuration = configuration 57 | self.train_period = train_period 58 | self.val_period = val_period 59 | self.seed = seed 60 | self.dataset_pickle_dir = dataset_pickle_dir 61 | 62 | self._common_dataloader_kwargs = dict( 63 | batch_size=batch_size, 64 | batch_sampler=None, 65 | num_workers=num_workers, 66 | collate_fn=collate_fn, 67 | pin_memory=pin_memory, 68 | drop_last=False, 69 | timeout=0, 70 | worker_init_fn=None, 71 | prefetch_factor=prefetch_factor, 72 | persistent_workers=persistent_workers, 73 | multiprocessing_context="spawn" if num_workers>0 else None, 74 | ) 75 | 76 | def setup(self, stage: str | None = None): 77 | """Called once to prepare the datasets.""" 78 | 79 | # This logic runs only once at the start of training, therefore the val dataset is only 80 | # shuffled once 81 | if stage == "fit": 82 | 83 | # Prepare the train dataset 84 | self.train_dataset = self._get_dataset(*self.train_period) 85 | 86 | # Prepare and pre-shuffle the val dataset and set seed for reproducibility 87 | val_dataset = 
self._get_dataset(*self.val_period) 88 | 89 | shuffled_indices = np.random.default_rng(seed=self.seed).permutation(len(val_dataset)) 90 | self.val_dataset = Subset(val_dataset, shuffled_indices) 91 | 92 | if self.dataset_pickle_dir is not None: 93 | os.makedirs(self.dataset_pickle_dir, exist_ok=True) 94 | train_dataset_path = f"{self.dataset_pickle_dir}/train_dataset.pkl" 95 | val_dataset_path = f"{self.dataset_pickle_dir}/val_dataset.pkl" 96 | 97 | # For safety, these pickled datasets cannot be overwritten. 98 | # See: https://github.com/openclimatefix/pvnet/pull/445 99 | for path in [train_dataset_path, val_dataset_path]: 100 | if os.path.exists(path): 101 | raise FileExistsError( 102 | f"The pickled dataset path '{path}' already exists. Make sure that " 103 | "this can be safely deleted (i.e. not currently being used by any " 104 | "training run) and delete it manually. Else change the " 105 | "`dataset_pickle_dir` to a different directory." 106 | ) 107 | 108 | self.train_dataset.presave_pickle(train_dataset_path) 109 | self.train_dataset.presave_pickle(val_dataset_path) 110 | 111 | def teardown(self, stage: str | None = None) -> None: 112 | """Clean up the pickled datasets""" 113 | if self.dataset_pickle_dir is not None: 114 | for filename in ["val_dataset.pkl", "train_dataset.pkl"]: 115 | filepath = f"{self.dataset_pickle_dir}/{filename}" 116 | if os.path.exists(filepath): 117 | os.remove(filepath) 118 | 119 | def _get_dataset(self, start_time: str | None, end_time: str | None) -> Dataset: 120 | raise NotImplementedError 121 | 122 | def train_dataloader(self) -> DataLoader: 123 | """Construct train dataloader""" 124 | return DataLoader(self.train_dataset, shuffle=True, **self._common_dataloader_kwargs) 125 | 126 | def val_dataloader(self) -> DataLoader: 127 | """Construct val dataloader""" 128 | return DataLoader(self.val_dataset, shuffle=False, **self._common_dataloader_kwargs) 129 | 130 | 131 | class UKRegionalDataModule(BaseDataModule): 132 | """Datamodule 
for streaming UK regional samples.""" 133 | 134 | def _get_dataset(self, start_time: str | None, end_time: str | None) -> PVNetUKRegionalDataset: 135 | return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time) 136 | 137 | 138 | class SitesDataModule(BaseDataModule): 139 | """Datamodule for streaming site samples.""" 140 | 141 | def _get_dataset(self, start_time: str | None, end_time: str | None) -> SitesDataset: 142 | return SitesDataset(self.configuration, start_time=start_time, end_time=end_time) -------------------------------------------------------------------------------- /pvnet/load_model.py: -------------------------------------------------------------------------------- 1 | """Load a model from its checkpoint directory""" 2 | 3 | import glob 4 | import os 5 | 6 | import hydra 7 | import torch 8 | import yaml 9 | 10 | from pvnet.models.ensemble import Ensemble 11 | from pvnet.utils import ( 12 | DATA_CONFIG_NAME, 13 | DATAMODULE_CONFIG_NAME, 14 | FULL_CONFIG_NAME, 15 | MODEL_CONFIG_NAME, 16 | ) 17 | 18 | 19 | def get_model_from_checkpoints( 20 | checkpoint_dir_paths: list[str], 21 | val_best: bool = True, 22 | ) -> tuple[torch.nn.Module, dict, str, str | None, str | None]: 23 | """Load a model from its checkpoint directory 24 | 25 | Returns: 26 | tuple: 27 | model: nn.Module of pretrained model. 28 | model_config: path to model config used to train the model. 29 | data_config: path to data config used to create samples for the model. 30 | datamodule_config: path to datamodule used to create samples e.g train/test split info. 31 | experiment_configs: path to the full experimental config. 
32 | 33 | """ 34 | is_ensemble = len(checkpoint_dir_paths) > 1 35 | 36 | model_configs = [] 37 | models = [] 38 | data_configs = [] 39 | datamodule_configs = [] 40 | experiment_configs = [] 41 | 42 | for path in checkpoint_dir_paths: 43 | 44 | # Load lightning training module 45 | with open(f"{path}/{MODEL_CONFIG_NAME}") as cfg: 46 | model_config = yaml.load(cfg, Loader=yaml.FullLoader) 47 | 48 | lightning_module = hydra.utils.instantiate(model_config) 49 | 50 | if val_best: 51 | # Only one epoch (best) saved per model 52 | files = glob.glob(f"{path}/epoch*.ckpt") 53 | if len(files) != 1: 54 | raise ValueError( 55 | f"Found {len(files)} checkpoints @ {path}/epoch*.ckpt. Expected one." 56 | ) 57 | # TODO: Loading with weights_only=False is not recommended 58 | checkpoint = torch.load(files[0], map_location="cpu", weights_only=True) 59 | else: 60 | checkpoint = torch.load(f"{path}/last.ckpt", map_location="cpu", weights_only=True) 61 | 62 | lightning_module.load_state_dict(state_dict=checkpoint["state_dict"]) 63 | 64 | # Extract the model from the lightning module 65 | models.append(lightning_module.model) 66 | model_configs.append(model_config["model"]) 67 | 68 | # Store the data config used for the model 69 | data_config = f"{path}/{DATA_CONFIG_NAME}" 70 | 71 | if os.path.isfile(data_config): 72 | data_configs.append(data_config) 73 | else: 74 | raise FileNotFoundError(f"File {data_config} does not exist") 75 | 76 | # TODO: This should be removed in a future release since no new models will be trained on 77 | # presaved samples 78 | # Check for datamodule config 79 | # This only exists if the model was trained with presaved samples 80 | datamodule_config = f"{path}/{DATAMODULE_CONFIG_NAME}" 81 | if os.path.isfile(datamodule_config): 82 | datamodule_configs.append(datamodule_config) 83 | else: 84 | datamodule_configs.append(None) 85 | 86 | # Check for experiment config 87 | # For backwards compatibility - this might always exist 88 | experiment_config = 
f"{path}/{FULL_CONFIG_NAME}" 89 | if os.path.isfile(datamodule_config): 90 | experiment_configs.append(experiment_config) 91 | else: 92 | experiment_configs.append(None) 93 | 94 | if is_ensemble: 95 | model_config = { 96 | "_target_": "pvnet.models.ensemble.Ensemble", 97 | "model_list": model_configs, 98 | } 99 | model = Ensemble(model_list=models) 100 | 101 | else: 102 | model_config = model_configs[0] 103 | model = models[0] 104 | 105 | # Assume if using an ensemble that the members were trained on the same input data 106 | data_config = data_configs[0] 107 | datamodule_config = datamodule_configs[0] 108 | 109 | # TODO: How should we save the experimental configs if we had an ensemble? 110 | experiment_config = experiment_configs[0] 111 | 112 | return model, model_config, data_config, datamodule_config, experiment_config 113 | -------------------------------------------------------------------------------- /pvnet/model_cards/empty_model_card_template.md: -------------------------------------------------------------------------------- 1 | --- 2 | {{ card_data }} 3 | --- 4 | 7 | 8 | 9 | # TEMPLATE 10 | 11 | 12 | ## Model Description 13 | 14 | 17 | 18 | - **Developed by:** openclimatefix 19 | - **Model type:** Fusion model 20 | - **Language(s) (NLP):** en 21 | - **License:** mit 22 | 23 | # Training Details 24 | 25 | ## Data 26 | 27 | 32 | 33 | 34 | ### Preprocessing 35 | 36 | 39 | 40 | ## Results 41 | 42 | 43 | The training logs for this model commit can be found here: 44 | {{ wandb_links }} 45 | 46 | 47 | ### Hardware 48 | Trained on a single NVIDIA Tesla T4 49 | 50 | 51 | ### Software 52 | 53 | This model was trained using the following Open Climate Fix packages: 54 | 55 | - [1] https://github.com/openclimatefix/PVNet 56 | - [2] https://github.com/openclimatefix/ocf-data-sampler 57 | 58 | 59 | The versions of these packages can be found below: 60 | {{ package_versions }} 61 | -------------------------------------------------------------------------------- 
/pvnet/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Models for PVNet""" 2 | from .base_model import BaseModel 3 | from .ensemble import Ensemble 4 | from .late_fusion.late_fusion import LateFusionModel 5 | -------------------------------------------------------------------------------- /pvnet/models/ensemble.py: -------------------------------------------------------------------------------- 1 | """Ensemble model which takes a weighted average of the predictions of multiple PVNet models""" 2 | import torch 3 | from ocf_data_sampler.numpy_sample.common_types import TensorBatch 4 | from torch import nn 5 | 6 | from pvnet.models.base_model import BaseModel 7 | 8 | 9 | class Ensemble(BaseModel): 10 | """Ensemble of PVNet models""" 11 | 12 | def __init__( 13 | self, 14 | model_list: list[BaseModel], 15 | weights: list[float] | None = None, 16 | ): 17 | """Ensemble of PVNet models 18 | 19 | Args: 20 | model_list: A list of PVNet models to ensemble 21 | weights: A list of weights to apply to each model; they are normalised to sum to 1. If None, 22 | the models are weighted equally.
23 | """ 24 | 25 | # Surface check all the models are compatible 26 | output_quantiles = [] 27 | history_minutes = [] 28 | forecast_minutes = [] 29 | target_key = [] 30 | interval_minutes = [] 31 | 32 | # Get some model properties from each model 33 | for model in model_list: 34 | output_quantiles.append(model.output_quantiles) 35 | history_minutes.append(model.history_minutes) 36 | forecast_minutes.append(model.forecast_minutes) 37 | target_key.append(model._target_key) 38 | interval_minutes.append(model.interval_minutes) 39 | 40 | # Check these properties are all the same 41 | for param_list in [ 42 | output_quantiles, 43 | history_minutes, 44 | forecast_minutes, 45 | target_key, 46 | interval_minutes, 47 | ]: 48 | assert all([p == param_list[0] for p in param_list]), param_list 49 | 50 | super().__init__( 51 | history_minutes=history_minutes[0], 52 | forecast_minutes=forecast_minutes[0], 53 | output_quantiles=output_quantiles[0], 54 | target_key=target_key[0], 55 | interval_minutes=interval_minutes[0], 56 | ) 57 | 58 | self.model_list = nn.ModuleList(model_list) 59 | 60 | if weights is None: 61 | weights = torch.ones(len(model_list)) / len(model_list) 62 | else: 63 | assert len(weights) == len(model_list) 64 | weights = torch.Tensor(weights) / sum(weights) 65 | self.weights = nn.Parameter(weights, requires_grad=False) 66 | 67 | def forward(self, x: TensorBatch) -> torch.Tensor: 68 | """Run the model forward""" 69 | y_hat = 0 70 | for weight, model in zip(self.weights, self.model_list): 71 | y_hat = model(x) * weight + y_hat 72 | return y_hat 73 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/README.md: -------------------------------------------------------------------------------- 1 | ## Multimodal model architecture 2 | 3 | These are fusion models which predict GSP power output based on NWP, non-HRV satellite, GSP output history, solar coordinates, and GSP ID.
4 | 5 | The core model is `late_fusion.LateFusionModel`, and its architecture is shown in the diagram below. 6 | 7 | ![multimodal_model_diagram](https://github.com/openclimatefix/PVNet/assets/41546094/118393fa-52ec-4bfe-a0a3-268c94c25f1e) 8 | 9 | This model uses encoders which take 4D (time, channel, x, y) inputs of NWP and satellite and encode them into 1D feature vectors. Different encoders are contained inside `encoders`. 10 | 11 | Different choices for the fusion model are contained inside `linear_networks`. 12 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/__init__.py: -------------------------------------------------------------------------------- 1 | """Late fusion models""" 2 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/basic_blocks.py: -------------------------------------------------------------------------------- 1 | """Basic layers for composite models""" 2 | 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class ImageEmbedding(nn.Module): 9 | """An embedding layer which concatenates an ID embedding as a new channel onto 3D inputs.""" 10 | 11 | def __init__(self, num_embeddings: int, sequence_length: int, image_size_pixels: int, **kwargs): 12 | """An embedding layer which concatenates an ID embedding as a new channel onto 3D inputs. 13 | 14 | The embedding is a single 2D image which is repeated along dim 2 (assumed to be time) and 15 | concatenated along dim 1 (channels), i.e. x is assumed to be (batch, channel, time, H, W). 16 | 17 | Args: 18 | num_embeddings: Size of the dictionary of embeddings 19 | sequence_length: The time sequence length of the data. 20 | image_size_pixels: The spatial size of the image. Assumed square. 21 | **kwargs: See `torch.nn.Embedding` for more possible arguments.
22 | """ 23 | super().__init__() 24 | self.image_size_pixels = image_size_pixels 25 | self.sequence_length = sequence_length 26 | self._embed = nn.Embedding( 27 | num_embeddings=num_embeddings, 28 | embedding_dim=image_size_pixels * image_size_pixels, 29 | **kwargs, 30 | ) 31 | 32 | def forward(self, x: torch.Tensor, id: torch.Tensor) -> torch.Tensor: 33 | """Append the ID embedding to x as one extra channel (dim 1)""" 34 | emb = self._embed(id) 35 | emb = emb.reshape((-1, 1, 1, self.image_size_pixels, self.image_size_pixels)) 36 | emb = emb.repeat(1, 1, self.sequence_length, 1, 1) 37 | return torch.cat((x, emb), dim=1) 38 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Submodels to encode satellite and NWP inputs""" 2 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/encoders/basic_blocks.py: -------------------------------------------------------------------------------- 1 | """Basic blocks for image sequence encoders""" 2 | from abc import ABCMeta, abstractmethod 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class AbstractNWPSatelliteEncoder(nn.Module, metaclass=ABCMeta): 9 | """Abstract class for NWP/satellite encoder. 10 | 11 | The encoder will take an input of shape (batch_size, channels, sequence_length, height, width) 12 | (the channel-first layout used by nn.Conv3d) and return an output of shape (batch_size, out_features). 13 | """ 14 | 15 | def __init__( 16 | self, 17 | sequence_length: int, 18 | image_size_pixels: int, 19 | in_channels: int, 20 | out_features: int, 21 | ): 22 | """Abstract class for NWP/satellite encoder. 23 | 24 | Args: 25 | sequence_length: The time sequence length of the data. 26 | image_size_pixels: The spatial size of the image. Assumed square. 27 | in_channels: Number of input channels. 28 | out_features: Number of output features.
29 | """ 30 | super().__init__() 31 | self.out_features = out_features 32 | self.image_size_pixels = image_size_pixels 33 | self.sequence_length = sequence_length 34 | self.in_channels = in_channels 35 | 36 | @abstractmethod 37 | def forward(self): 38 | """Run model forward""" 39 | pass 40 | 41 | 42 | class ResidualConv3dBlock(nn.Module): 43 | """Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1]. 44 | 45 | This was the best performing residual block tested in the study. This implementation differs 46 | from that block just by using LeakyReLU activation to avoid dead neurons, and by including 47 | optional dropout in the residual branch. It also uses 3D convolutions 48 | rather than the 2D convolutions used in [1]. 49 | 50 | Sources: 51 | [1] https://arxiv.org/pdf/1603.05027.pdf 52 | """ 53 | 54 | def __init__( 55 | self, 56 | in_channels: int, 57 | n_layers: int = 2, 58 | dropout_frac: float = 0.0, 59 | batch_norm: bool = True, 60 | ): 61 | """Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1]. 62 | 63 | Sources: 64 | [1] https://arxiv.org/pdf/1603.05027.pdf 65 | 66 | Args: 67 | in_channels: Number of input channels (also the number of output channels). 68 | n_layers: Number of (batchnorm, dropout, LeakyReLU, conv3d) layer groups in residual pathway. 69 | dropout_frac: Probability of a whole channel being zeroed (nn.Dropout3d).
70 | batch_norm: Whether to apply nn.BatchNorm3d at the start of each layer group 71 | """ 72 | super().__init__() 73 | 74 | layers = [] 75 | for i in range(n_layers): 76 | if batch_norm: 77 | layers.append(nn.BatchNorm3d(in_channels)) 78 | layers.extend( 79 | [ 80 | nn.Dropout3d(p=dropout_frac), 81 | nn.LeakyReLU(), 82 | nn.Conv3d( 83 | in_channels=in_channels, 84 | out_channels=in_channels, 85 | kernel_size=(3, 3, 3), 86 | padding=(1, 1, 1), 87 | ), 88 | ] 89 | ) 90 | 91 | self.model = nn.Sequential(*layers) 92 | 93 | def forward(self, x: torch.Tensor) -> torch.Tensor: 94 | """Run model forward""" 95 | return self.model(x) + x 96 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/encoders/encoders3d.py: -------------------------------------------------------------------------------- 1 | """Encoder modules for the satellite/NWP data based on 3D convolutions. 2 | """ 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from pvnet.models.late_fusion.encoders.basic_blocks import ( 8 | AbstractNWPSatelliteEncoder, 9 | ResidualConv3dBlock, 10 | ) 11 | 12 | 13 | class DefaultPVNet(AbstractNWPSatelliteEncoder): 14 | """This is the original encoding module used in PVNet, with a few minor tweaks.""" 15 | 16 | def __init__( 17 | self, 18 | sequence_length: int, 19 | image_size_pixels: int, 20 | in_channels: int, 21 | out_features: int, 22 | number_of_conv3d_layers: int = 4, 23 | conv3d_channels: int = 32, 24 | fc_features: int = 128, 25 | spatial_kernel_size: int = 3, 26 | temporal_kernel_size: int = 3, 27 | padding: int | tuple[int, ...] = (1, 0, 0), 28 | stride: int | tuple[int, ...] = 1, 29 | ): 30 | """This is the original encoding module used in PVNet, with a few minor tweaks. 31 | 32 | Args: 33 | sequence_length: The time sequence length of the data. 34 | image_size_pixels: The spatial size of the image. Assumed square. 35 | in_channels: Number of input channels. 36 | out_features: Number of output features.
37 | number_of_conv3d_layers: Number of convolution 3d layers that are used. 38 | conv3d_channels: Number of channels used in each conv3d layer. 39 | fc_features: number of output nodes out of the hidden fully connected layer. 40 | spatial_kernel_size: The spatial size of the kernel used in the conv3d layers. 41 | temporal_kernel_size: The temporal size of the kernel used in the conv3d layers. 42 | padding: The padding used in the conv3d layers. If an int, the same padding 43 | is used in all dimensions. The dimensions are (time, space, space) 44 | stride: The stride used in conv3d layers. If an int, the same stride is used 45 | in all dimensions 46 | """ 47 | 48 | super().__init__(sequence_length, image_size_pixels, in_channels, out_features) 49 | 50 | if isinstance(padding, int): 51 | padding = (padding, padding, padding) 52 | 53 | if isinstance(stride, int): 54 | stride = (stride, stride, stride) 55 | 56 | # Check that the output shape of the convolutional layers will be at least 1x1 57 | cnn_spatial_output_size = image_size_pixels 58 | 59 | for _ in range(number_of_conv3d_layers): 60 | cnn_spatial_output_size = ( 61 | cnn_spatial_output_size - spatial_kernel_size + 2 * padding[1] 62 | ) // stride[1] + 1 63 | 64 | if not (cnn_spatial_output_size >= 1): 65 | raise ValueError( 66 | f"cannot use this many conv3d layers ({number_of_conv3d_layers}) with this input " 67 | f"spatial size ({image_size_pixels})" 68 | ) 69 | 70 | cnn_sequence_length = ( 71 | sequence_length 72 | - ((temporal_kernel_size - 2 * padding[0]) - 1) * number_of_conv3d_layers 73 | ) 74 | 75 | conv_layers = [] 76 | 77 | conv_layers += [ 78 | nn.Conv3d( 79 | in_channels=in_channels, 80 | out_channels=conv3d_channels, 81 | kernel_size=(temporal_kernel_size, spatial_kernel_size, spatial_kernel_size), 82 | padding=padding, 83 | stride=stride, 84 | ), 85 | nn.ELU(), 86 | ] 87 | for _ in range(0, number_of_conv3d_layers - 1): 88 | conv_layers += [ 89 | nn.Conv3d( 90 | in_channels=conv3d_channels, 91 | 
out_channels=conv3d_channels, 92 | kernel_size=(temporal_kernel_size, spatial_kernel_size, spatial_kernel_size), 93 | padding=padding, 94 | stride=stride, 95 | ), 96 | nn.ELU(), 97 | ] 98 | 99 | self.conv_layers = nn.Sequential(*conv_layers) 100 | 101 | # Calculate the size of the output of the 3D convolutional layers 102 | cnn_output_size = conv3d_channels * cnn_spatial_output_size**2 * cnn_sequence_length 103 | 104 | self.final_block = nn.Sequential( 105 | nn.Linear(in_features=cnn_output_size, out_features=fc_features), 106 | nn.ELU(), 107 | nn.Linear(in_features=fc_features, out_features=out_features), 108 | nn.ELU(), 109 | ) 110 | 111 | def forward(self, x: torch.Tensor) -> torch.Tensor: 112 | """Run model forward""" 113 | out = self.conv_layers(x) 114 | out = out.reshape(x.shape[0], -1) 115 | return self.final_block(out) 116 | 117 | 118 | class ResConv3DNet(AbstractNWPSatelliteEncoder): 119 | """3D convolutional network based on ResNet architecture. 120 | 121 | The residual blocks are implemented based on the best performing block in [1]. 122 | 123 | Sources: 124 | [1] https://arxiv.org/pdf/1603.05027.pdf 125 | """ 126 | 127 | def __init__( 128 | self, 129 | sequence_length: int, 130 | image_size_pixels: int, 131 | in_channels: int, 132 | out_features: int, 133 | hidden_channels: int = 32, 134 | n_res_blocks: int = 4, 135 | res_block_layers: int = 2, 136 | batch_norm: bool = True, 137 | dropout_frac: float = 0.0, 138 | ): 139 | """Fully connected deep network based on ResNet architecture. 140 | 141 | Args: 142 | sequence_length: The time sequence length of the data. 143 | image_size_pixels: The spatial size of the image. Assumed square. 144 | in_channels: Number of input channels. 145 | out_features: Number of output features. 146 | hidden_channels: Number of channels in middle hidden layers. 147 | n_res_blocks: Number of residual blocks to use. 148 | res_block_layers: Number of Conv3D layers used in each residual block. 
149 | batch_norm: Whether to include batch normalisation. 150 | dropout_frac: Probability of an element to be zeroed in the residual pathways. 151 | """ 152 | super().__init__(sequence_length, image_size_pixels, in_channels, out_features) 153 | 154 | model = [ 155 | nn.Conv3d( 156 | in_channels=in_channels, 157 | out_channels=hidden_channels, 158 | kernel_size=(3, 3, 3), 159 | padding=(1, 1, 1), 160 | ), 161 | ] 162 | 163 | for i in range(n_res_blocks): 164 | model.extend( 165 | [ 166 | ResidualConv3dBlock( 167 | in_channels=hidden_channels, 168 | n_layers=res_block_layers, 169 | dropout_frac=dropout_frac, 170 | batch_norm=batch_norm, 171 | ), 172 | nn.AvgPool3d((1, 2, 2), stride=(1, 2, 2)), 173 | ] 174 | ) 175 | 176 | # Calculate the size of the output of the 3D convolutional layers 177 | final_im_size = image_size_pixels // (2**n_res_blocks) 178 | cnn_output_size = hidden_channels * sequence_length * final_im_size * final_im_size 179 | 180 | model.extend( 181 | [ 182 | nn.ELU(), 183 | nn.Flatten(start_dim=1, end_dim=-1), 184 | nn.Linear(in_features=cnn_output_size, out_features=out_features), 185 | nn.ELU(), 186 | ] 187 | ) 188 | 189 | self.model = nn.Sequential(*model) 190 | 191 | def forward(self, x: torch.Tensor) -> torch.Tensor: 192 | """Run model forward""" 193 | return self.model(x) 194 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/linear_networks/__init__.py: -------------------------------------------------------------------------------- 1 | """Submodels to combine 1D feature vectors from different sources and make final predictions""" 2 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/linear_networks/basic_blocks.py: -------------------------------------------------------------------------------- 1 | """Basic blocks for the lienar networks""" 2 | from abc import ABCMeta, abstractmethod 3 | from collections import OrderedDict 4 | 5 | 
import torch 6 | from torch import nn 7 | 8 | 9 | class AbstractLinearNetwork(nn.Module, metaclass=ABCMeta): 10 | """Abstract class for a network to combine the features from all the inputs.""" 11 | 12 | def __init__( 13 | self, 14 | in_features: int, 15 | out_features: int, 16 | ): 17 | """Abstract class for a network to combine the features from all the inputs. 18 | 19 | Args: 20 | in_features: Number of input features. 21 | out_features: Number of output features. 22 | """ 23 | super().__init__() 24 | 25 | def cat_modes(self, x: OrderedDict[str, torch.Tensor] | torch.Tensor) -> torch.Tensor: 26 | """Concatenate modes of input data into 1D feature vector""" 27 | if isinstance(x, OrderedDict): 28 | return torch.cat([value for key, value in x.items()], dim=1) 29 | elif isinstance(x, torch.Tensor): 30 | return x 31 | else: 32 | raise ValueError(f"Input of unexpected type {type(x)}") 33 | 34 | @abstractmethod 35 | def forward(self, x: OrderedDict[str, torch.Tensor] | torch.Tensor) -> torch.Tensor: 36 | """Run model forward""" 37 | pass 38 | 39 | 40 | class ResidualLinearBlock(nn.Module): 41 | """Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1]. 42 | 43 | This was the best performing residual block tested in the study. This implementation differs 44 | from that block just by using LeakyReLU activation to avoid dead neuron, and by including 45 | optional dropout in the residual branch. This is also a 1D fully connected layer residual block 46 | rather than a 2D convolutional block. 47 | 48 | Sources: 49 | [1] https://arxiv.org/pdf/1603.05027.pdf 50 | """ 51 | 52 | def __init__( 53 | self, 54 | in_features: int, 55 | n_layers: int = 2, 56 | dropout_frac: float = 0.0, 57 | ): 58 | """Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1]. 59 | 60 | Sources: 61 | [1] https://arxiv.org/pdf/1603.05027.pdf 62 | 63 | Args: 64 | in_features: Number of input features. 
65 | n_layers: Number of layers in residual pathway. 66 | dropout_frac: Probability of an element to be zeroed. 67 | """ 68 | super().__init__() 69 | 70 | layers = [] 71 | for i in range(n_layers): 72 | layers += [ 73 | nn.BatchNorm1d(in_features), 74 | nn.Dropout(p=dropout_frac), 75 | nn.LeakyReLU(), 76 | nn.Linear( 77 | in_features=in_features, 78 | out_features=in_features, 79 | ), 80 | ] 81 | 82 | self.model = nn.Sequential(*layers) 83 | 84 | def forward(self, x: torch.Tensor) -> torch.Tensor: 85 | """Run model forward""" 86 | return self.model(x) + x 87 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/linear_networks/networks.py: -------------------------------------------------------------------------------- 1 | """Linear networks used for the fusion model""" 2 | from collections import OrderedDict 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from pvnet.models.late_fusion.linear_networks.basic_blocks import ( 8 | AbstractLinearNetwork, 9 | ResidualLinearBlock, 10 | ) 11 | 12 | 13 | class ResFCNet(AbstractLinearNetwork): 14 | """Fully connected deep network based on ResNet architecture. 15 | 16 | This architecture is similar to 17 | `ResFCNet`, except that it uses LeakyReLU activations internally, and batchnorm in the residual 18 | branches. The residual blocks are implemented based on the best performing block in [1]. 19 | 20 | Sources: 21 | [1] https://arxiv.org/pdf/1603.05027.pdf 22 | """ 23 | 24 | def __init__( 25 | self, 26 | in_features: int, 27 | out_features: int, 28 | fc_hidden_features: int = 128, 29 | n_res_blocks: int = 4, 30 | res_block_layers: int = 2, 31 | dropout_frac: float = 0.0, 32 | ): 33 | """Fully connected deep network based on ResNet architecture. 34 | 35 | Args: 36 | in_features: Number of input features. 37 | out_features: Number of output features. 38 | fc_hidden_features: Number of features in middle hidden layers. 
39 | n_res_blocks: Number of residual blocks to use. 40 | res_block_layers: Number of fully-connected layers used in each residual block. 41 | dropout_frac: Probability of an element to be zeroed in the residual pathways. 42 | """ 43 | super().__init__(in_features, out_features) 44 | 45 | model = [nn.Linear(in_features=in_features, out_features=fc_hidden_features)] 46 | 47 | for i in range(n_res_blocks): 48 | model += [ 49 | ResidualLinearBlock( 50 | in_features=fc_hidden_features, 51 | n_layers=res_block_layers, 52 | dropout_frac=dropout_frac, 53 | ) 54 | ] 55 | 56 | model += [ 57 | nn.LeakyReLU(), 58 | nn.Linear(in_features=fc_hidden_features, out_features=out_features), 59 | nn.LeakyReLU(negative_slope=0.01), 60 | ] 61 | 62 | self.model = nn.Sequential(*model) 63 | 64 | def forward(self, x: OrderedDict[str, torch.Tensor] | torch.Tensor) -> torch.Tensor: 65 | """Run model forward""" 66 | x = self.cat_modes(x) 67 | return self.model(x) 68 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/site_encoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Submodels to encode site-level PV data""" 2 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/site_encoders/basic_blocks.py: -------------------------------------------------------------------------------- 1 | """Basic blocks for PV-site encoders""" 2 | from abc import ABCMeta, abstractmethod 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class AbstractSitesEncoder(nn.Module, metaclass=ABCMeta): 9 | """Abstract class for encoder for output data from multiple PV sites. 10 | 11 | The encoder will take an input of shape (batch_size, sequence_length, num_sites) 12 | and return an output of shape (batch_size, out_features). 
13 | """ 14 | 15 | def __init__( 16 | self, 17 | sequence_length: int, 18 | num_sites: int, 19 | out_features: int, 20 | ): 21 | """Abstract class for PV site-level encoder. 22 | 23 | Args: 24 | sequence_length: The time sequence length of the data. 25 | num_sites: Number of PV sites in the input data. 26 | out_features: Number of output features. 27 | """ 28 | super().__init__() 29 | self.sequence_length = sequence_length 30 | self.num_sites = num_sites 31 | self.out_features = out_features 32 | 33 | @abstractmethod 34 | def forward(self) -> torch.Tensor: 35 | """Run model forward""" 36 | pass 37 | -------------------------------------------------------------------------------- /pvnet/models/late_fusion/site_encoders/encoders.py: -------------------------------------------------------------------------------- 1 | """Encoder modules for the site-level PV data. 2 | 3 | """ 4 | 5 | import einops 6 | import torch 7 | from ocf_data_sampler.numpy_sample.common_types import TensorBatch 8 | from torch import nn 9 | 10 | from pvnet.models.late_fusion.linear_networks.networks import ResFCNet 11 | from pvnet.models.late_fusion.site_encoders.basic_blocks import AbstractSitesEncoder 12 | 13 | 14 | class SimpleLearnedAggregator(AbstractSitesEncoder): 15 | """A simple model which learns a different weighted-average across all PV sites for each GSP. 16 | 17 | Each sequence from each site is independently encodeded through some dense layers wih skip- 18 | connections, then the encoded form of each sequence is aggregated through a learned weighted-sum 19 | and finally put through more dense layers. 20 | 21 | This model was written to be a simplified version of a single-headed attention layer. 22 | """ 23 | 24 | def __init__( 25 | self, 26 | sequence_length: int, 27 | num_sites: int, 28 | out_features: int, 29 | value_dim: int = 10, 30 | value_enc_resblocks: int = 2, 31 | final_resblocks: int = 2, 32 | ): 33 | """A simple sequence encoder and weighted-average model. 
34 | 35 | Args: 36 | sequence_length: The time sequence length of the data. 37 | num_sites: Number of PV sites in the input data. 38 | out_features: Number of output features. 39 | value_dim: The number of features in each encoded sequence. Similar to the value 40 | dimension in single- or multi-head attention. 41 | value_dim: The number of features in each encoded sequence. Similar to the value 42 | dimension in single- or multi-head attention. 43 | value_enc_resblocks: Number of residual blocks in the value-encoder sub-network. 44 | final_resblocks: Number of residual blocks in the final sub-network. 45 | """ 46 | 47 | super().__init__(sequence_length, num_sites, out_features) 48 | 49 | # Network used to encode each PV site sequence 50 | self._value_encoder = nn.Sequential( 51 | ResFCNet( 52 | in_features=sequence_length, 53 | out_features=value_dim, 54 | fc_hidden_features=value_dim, 55 | n_res_blocks=value_enc_resblocks, 56 | res_block_layers=2, 57 | dropout_frac=0, 58 | ), 59 | ) 60 | 61 | # The learned weighted average is stored in an embedding layer for ease of use 62 | self._attention_network = nn.Sequential( 63 | nn.Embedding(318, num_sites), 64 | nn.Softmax(dim=1), 65 | ) 66 | 67 | # Network used to process weighted average 68 | self.output_network = ResFCNet( 69 | in_features=value_dim, 70 | out_features=out_features, 71 | fc_hidden_features=value_dim, 72 | n_res_blocks=final_resblocks, 73 | res_block_layers=2, 74 | dropout_frac=0, 75 | ) 76 | 77 | def _calculate_attention(self, x: TensorBatch) -> torch.Tensor: 78 | gsp_ids = x["gsp_id"].squeeze().int() 79 | attention = self._attention_network(gsp_ids) 80 | return attention 81 | 82 | def _encode_value(self, x: TensorBatch) -> torch.Tensor: 83 | # Shape: [batch size, sequence length, PV site] 84 | pv_site_seqs = x["pv"].float() 85 | batch_size = pv_site_seqs.shape[0] 86 | 87 | pv_site_seqs = pv_site_seqs.swapaxes(1, 2).flatten(0, 1) 88 | 89 | x_seq_enc = self._value_encoder(pv_site_seqs) 90 | x_seq_out = 
x_seq_enc.unflatten(0, (batch_size, self.num_sites)) 91 | return x_seq_out 92 | 93 | def forward(self, x: TensorBatch) -> torch.Tensor: 94 | """Run model forward""" 95 | # Output has shape: [batch size, num_sites, value_dim] 96 | encodeded_seqs = self._encode_value(x) 97 | 98 | # Calculate learned averaging weights 99 | attn_avg_weights = self._calculate_attention(x) 100 | 101 | # Take weighted average across num_sites 102 | value_weighted_avg = (encodeded_seqs * attn_avg_weights.unsqueeze(-1)).sum(dim=1) 103 | 104 | # Put through final processing layers 105 | x_out = self.output_network(value_weighted_avg) 106 | 107 | return x_out 108 | 109 | 110 | class SingleAttentionNetwork(AbstractSitesEncoder): 111 | """A simple attention-based model with a single multihead attention layer 112 | 113 | For the attention layer the query is based on the target alone, the key is based on the 114 | input ID and the recent input data, the value is based on the recent input data. 115 | 116 | """ 117 | 118 | def __init__( 119 | self, 120 | sequence_length: int, 121 | num_sites: int, 122 | out_features: int, 123 | kdim: int = 10, 124 | id_embed_dim: int = 10, 125 | num_heads: int = 2, 126 | n_kv_res_blocks: int = 2, 127 | kv_res_block_layers: int = 2, 128 | use_id_in_value: bool = False, 129 | target_id_dim: int = 318, 130 | target_key_to_use: str = "gsp", 131 | input_key_to_use: str = "site", 132 | num_channels: int = 1, 133 | num_sites_in_inference: int = 1, 134 | ): 135 | """A simple attention-based model with a single multihead attention layer 136 | 137 | Args: 138 | sequence_length: The time sequence length of the data. 139 | num_sites: Number of sites in the input data. 140 | out_features: Number of output features. In this network this is also the embed and 141 | value dimension in the multi-head attention layer. 142 | kdim: The dimensions used the keys. 143 | id_embed_dim: Number of dimensiosn used in the site ID embedding layer(s). 
144 | num_heads: Number of parallel attention heads. Note that `out_features` will be split 145 | across `num_heads` so `out_features` must be a multiple of `num_heads`. 146 | n_kv_res_blocks: Number of residual blocks to use in the key and value encoders. 147 | kv_res_block_layers: Number of fully-connected layers used in each residual block within 148 | the key and value encoders. 149 | use_id_in_value: Whether to use a site ID embedding in network used to produce the 150 | value for the attention layer. 151 | target_id_dim: The number of unique IDs. 152 | target_key_to_use: The key to use for the target in the attention layer. 153 | input_key_to_use: The key to use for the input in the attention layer. 154 | num_channels: Number of channels in the input data 155 | num_sites_in_inference: Number of sites to use in inference. 156 | This is used to determine the number of sites to use in the 157 | attention layer, for a single site, 1 works, while for multiple sites 158 | this would be higher than that 159 | 160 | """ 161 | super().__init__(sequence_length, num_sites, out_features) 162 | self.sequence_length = sequence_length 163 | self.target_id_embedding = nn.Embedding(target_id_dim, out_features) 164 | self.site_id_embedding = nn.Embedding(num_sites, id_embed_dim) 165 | self._ids = nn.parameter.Parameter(torch.arange(num_sites), requires_grad=False) 166 | self.use_id_in_value = use_id_in_value 167 | self.target_key_to_use = target_key_to_use 168 | self.input_key_to_use = input_key_to_use 169 | self.num_channels = num_channels 170 | self.num_sites_in_inference = num_sites_in_inference 171 | 172 | if use_id_in_value: 173 | self.value_id_embedding = nn.Embedding(num_sites, id_embed_dim) 174 | 175 | self._value_encoder = nn.Sequential( 176 | ResFCNet( 177 | in_features=sequence_length * self.num_channels 178 | + int(use_id_in_value) * id_embed_dim, 179 | out_features=out_features, 180 | fc_hidden_features=sequence_length * self.num_channels, 181 | 
n_res_blocks=n_kv_res_blocks, 182 | res_block_layers=kv_res_block_layers, 183 | dropout_frac=0, 184 | ), 185 | ) 186 | 187 | self._key_encoder = nn.Sequential( 188 | ResFCNet( 189 | in_features=id_embed_dim + sequence_length * self.num_channels, 190 | out_features=kdim, 191 | fc_hidden_features=id_embed_dim + sequence_length * self.num_channels, 192 | n_res_blocks=n_kv_res_blocks, 193 | res_block_layers=kv_res_block_layers, 194 | dropout_frac=0, 195 | ), 196 | ) 197 | 198 | self.multihead_attn = nn.MultiheadAttention( 199 | embed_dim=out_features, 200 | kdim=kdim, 201 | vdim=out_features, 202 | num_heads=num_heads, 203 | batch_first=True, 204 | ) 205 | 206 | def _encode_inputs(self, x: TensorBatch) -> tuple[torch.Tensor, int]: 207 | # Shape: [batch size, sequence length, number of sites] 208 | # Shape: [batch size, station_id, sequence length, channels] 209 | input_data = x[f"{self.input_key_to_use}"] 210 | if len(input_data.shape) == 2: # one site per sample 211 | input_data = input_data.unsqueeze(-1) # add dimension of 1 to end to make 3D 212 | if len(input_data.shape) == 4: # Has multiple channels 213 | input_data = input_data[:, :, : self.sequence_length] 214 | input_data = einops.rearrange(input_data, "b id s c -> b (s c) id") 215 | else: 216 | input_data = input_data[:, : self.sequence_length] 217 | site_seqs = input_data.float() 218 | batch_size = site_seqs.shape[0] 219 | site_seqs = site_seqs.swapaxes(1, 2) # [batch size, Site ID, sequence length] 220 | return site_seqs, batch_size 221 | 222 | def _encode_query(self, x: TensorBatch) -> torch.Tensor: 223 | if self.target_key_to_use == "gsp": 224 | # GSP seems to have a different structure 225 | ids = x[f"{self.target_key_to_use}_id"] 226 | else: 227 | ids = x[f"{self.input_key_to_use}_id"] 228 | ids = ids.int() 229 | query = self.target_id_embedding(ids).unsqueeze(1) 230 | return query 231 | 232 | def _encode_key(self, x: TensorBatch) -> torch.Tensor: 233 | site_seqs, batch_size = self._encode_inputs(x) 234 
| 235 | # site ID embeddings are the same for each sample 236 | site_id_embed = torch.tile(self.site_id_embedding(self._ids), (batch_size, 1, 1)) 237 | # Each concated (site sequence, site ID embedding) is processed with encoder 238 | x_seq_in = torch.cat((site_seqs, site_id_embed), dim=2).flatten(0, 1) 239 | key = self._key_encoder(x_seq_in) 240 | 241 | # Reshape to [batch size, site, kdim] 242 | key = key.unflatten(0, (batch_size, self.num_sites)) 243 | return key 244 | 245 | def _encode_value(self, x: TensorBatch) -> torch.Tensor: 246 | site_seqs, batch_size = self._encode_inputs(x) 247 | 248 | if self.use_id_in_value: 249 | # site ID embeddings are the same for each sample 250 | site_id_embed = torch.tile(self.value_id_embedding(self._ids), (batch_size, 1, 1)) 251 | # Each concated (site sequence, site ID embedding) is processed with encoder 252 | x_seq_in = torch.cat((site_seqs, site_id_embed), dim=2).flatten(0, 1) 253 | else: 254 | # Encode each site sequence independently 255 | x_seq_in = site_seqs.flatten(0, 1) 256 | value = self._value_encoder(x_seq_in) 257 | 258 | # Reshape to [batch size, site, vdim] 259 | value = value.unflatten(0, (batch_size, self.num_sites)) 260 | return value 261 | 262 | def _attention_forward( 263 | self, x: dict, 264 | average_attn_weights: bool = True 265 | ) -> tuple[torch.Tensor, torch.Tensor:]: 266 | query = self._encode_query(x) 267 | key = self._encode_key(x) 268 | value = self._encode_value(x) 269 | attn_output, attn_weights = self.multihead_attn( 270 | query, key, value, average_attn_weights=average_attn_weights 271 | ) 272 | 273 | return attn_output, attn_weights 274 | 275 | def forward(self, x: TensorBatch) -> torch.Tensor: 276 | """Run model forward""" 277 | 278 | attn_output, _ = self._attention_forward(x) 279 | 280 | # Reshape from [batch_size, 1, vdim] to [batch_size, vdim] 281 | x_out = attn_output.squeeze() 282 | if len(x_out.shape) == 1: 283 | x_out = x_out.unsqueeze(0) 284 | 285 | return x_out 286 | 
-------------------------------------------------------------------------------- /pvnet/optimizers.py: -------------------------------------------------------------------------------- 1 | """Optimizer factory-function classes. 2 | """ 3 | 4 | from abc import ABC, abstractmethod 5 | 6 | import torch 7 | from torch.nn import Module 8 | from torch.nn.parameter import Parameter 9 | 10 | 11 | def find_submodule_parameters(model: Module, search_modules: list[Module]) -> list[Parameter]: 12 | """Finds all parameters within given submodule types 13 | 14 | Args: 15 | model: torch Module to search through 16 | search_modules: List of submodule types to search for 17 | """ 18 | if isinstance(model, search_modules): 19 | return model.parameters() 20 | 21 | children = list(model.children()) 22 | if len(children) == 0: 23 | return [] 24 | else: 25 | params = [] 26 | for c in children: 27 | params += find_submodule_parameters(c, search_modules) 28 | return params 29 | 30 | 31 | def find_other_than_submodule_parameters( 32 | model: Module, 33 | ignore_modules: list[Module], 34 | ) -> list[Parameter]: 35 | """Finds all parameters not with given submodule types 36 | 37 | Args: 38 | model: torch Module to search through 39 | ignore_modules: List of submodule types to ignore 40 | """ 41 | if isinstance(model, ignore_modules): 42 | return [] 43 | 44 | children = list(model.children()) 45 | if len(children) == 0: 46 | return model.parameters() 47 | else: 48 | params = [] 49 | for c in children: 50 | params += find_other_than_submodule_parameters(c, ignore_modules) 51 | return params 52 | 53 | 54 | class AbstractOptimizer(ABC): 55 | """Abstract class for optimizer 56 | 57 | Optimizer classes will be used by model like: 58 | > OptimizerGenerator = AbstractOptimizer() 59 | > optimizer = OptimizerGenerator(model) 60 | The returned object `optimizer` must be something that may be returned by `pytorch_lightning`'s 61 | `configure_optimizers()` method. 
62 | See : 63 | https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers 64 | 65 | """ 66 | 67 | @abstractmethod 68 | def __call__(self): 69 | """Abstract call""" 70 | pass 71 | 72 | 73 | class Adam(AbstractOptimizer): 74 | """Adam optimizer""" 75 | 76 | def __init__(self, lr: float = 0.0005, **kwargs): 77 | """Adam optimizer""" 78 | self.lr = lr 79 | self.kwargs = kwargs 80 | 81 | def __call__(self, model: Module): 82 | """Return optimizer""" 83 | return torch.optim.Adam(model.parameters(), lr=self.lr, **self.kwargs) 84 | 85 | 86 | class AdamW(AbstractOptimizer): 87 | """AdamW optimizer""" 88 | 89 | def __init__(self, lr: float = 0.0005, **kwargs): 90 | """AdamW optimizer""" 91 | self.lr = lr 92 | self.kwargs = kwargs 93 | 94 | def __call__(self, model: Module): 95 | """Return optimizer""" 96 | return torch.optim.AdamW(model.parameters(), lr=self.lr, **self.kwargs) 97 | 98 | 99 | 100 | class EmbAdamWReduceLROnPlateau(AbstractOptimizer): 101 | """AdamW optimizer and reduce on plateau scheduler""" 102 | 103 | def __init__( 104 | self, 105 | lr: float = 0.0005, 106 | weight_decay: float = 0.01, 107 | patience: int = 3, 108 | factor: float = 0.5, 109 | threshold: float = 2e-4, 110 | **opt_kwargs, 111 | ): 112 | """AdamW optimizer and reduce on plateau scheduler""" 113 | self.lr = lr 114 | self.weight_decay = weight_decay 115 | self.patience = patience 116 | self.factor = factor 117 | self.threshold = threshold 118 | self.opt_kwargs = opt_kwargs 119 | 120 | def __call__(self, model): 121 | """Return optimizer""" 122 | 123 | search_modules = (torch.nn.Embedding,) 124 | 125 | no_decay = find_submodule_parameters(model, search_modules) 126 | decay = find_other_than_submodule_parameters(model, search_modules) 127 | 128 | optim_groups = [ 129 | {"params": decay, "weight_decay": self.weight_decay}, 130 | {"params": no_decay, "weight_decay": 0.0}, 131 | ] 132 | opt = torch.optim.AdamW(optim_groups, lr=self.lr, **self.opt_kwargs) 133 | 134 | 
sch = torch.optim.lr_scheduler.ReduceLROnPlateau( 135 | opt, 136 | factor=self.factor, 137 | patience=self.patience, 138 | threshold=self.threshold, 139 | ) 140 | sch = { 141 | "scheduler": sch, 142 | "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val", 143 | } 144 | return [opt], [sch] 145 | 146 | 147 | class AdamWReduceLROnPlateau(AbstractOptimizer): 148 | """AdamW optimizer and reduce on plateau scheduler""" 149 | 150 | def __init__( 151 | self, 152 | lr: float = 0.0005, 153 | patience: int = 3, 154 | factor: float = 0.5, 155 | threshold: float = 2e-4, 156 | step_freq=None, 157 | **opt_kwargs, 158 | ): 159 | """AdamW optimizer and reduce on plateau scheduler""" 160 | self._lr = lr 161 | self.patience = patience 162 | self.factor = factor 163 | self.threshold = threshold 164 | self.step_freq = step_freq 165 | self.opt_kwargs = opt_kwargs 166 | 167 | def _call_multi(self, model): 168 | remaining_params = {k: p for k, p in model.named_parameters()} 169 | 170 | group_args = [] 171 | 172 | for key in self._lr.keys(): 173 | if key == "default": 174 | continue 175 | 176 | submodule_params = [] 177 | for param_name in list(remaining_params.keys()): 178 | if param_name.startswith(key): 179 | submodule_params += [remaining_params.pop(param_name)] 180 | 181 | group_args += [{"params": submodule_params, "lr": self._lr[key]}] 182 | 183 | remaining_params = [p for k, p in remaining_params.items()] 184 | group_args += [{"params": remaining_params}] 185 | 186 | opt = torch.optim.AdamW( 187 | group_args, 188 | lr=self._lr["default"] if model.lr is None else model.lr, 189 | **self.opt_kwargs, 190 | ) 191 | sch = { 192 | "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( 193 | opt, 194 | factor=self.factor, 195 | patience=self.patience, 196 | threshold=self.threshold, 197 | ), 198 | "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val", 199 | } 200 | 201 | return [opt], [sch] 202 | 203 | def __call__(self, model): 204 | 
"""Return optimizer""" 205 | if not isinstance(self._lr, float): 206 | return self._call_multi(model) 207 | else: 208 | default_lr = self._lr if model.lr is None else model.lr 209 | opt = torch.optim.AdamW(model.parameters(), lr=default_lr, **self.opt_kwargs) 210 | sch = torch.optim.lr_scheduler.ReduceLROnPlateau( 211 | opt, 212 | factor=self.factor, 213 | patience=self.patience, 214 | threshold=self.threshold, 215 | ) 216 | sch = { 217 | "scheduler": sch, 218 | "monitor": "quantile_loss/val" if model.use_quantile_regression else "MAE/val", 219 | } 220 | return [opt], [sch] 221 | -------------------------------------------------------------------------------- /pvnet/training/__init__.py: -------------------------------------------------------------------------------- 1 | """Training submodule""" 2 | from .train import train -------------------------------------------------------------------------------- /pvnet/training/lightning_module.py: -------------------------------------------------------------------------------- 1 | """Pytorch lightning module for training PVNet models""" 2 | 3 | import lightning.pytorch as pl 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | import torch.nn.functional as F 9 | import wandb 10 | import xarray as xr 11 | from ocf_data_sampler.numpy_sample.common_types import TensorBatch 12 | from ocf_data_sampler.torch_datasets.sample.base import copy_batch_to_device 13 | 14 | from pvnet.datamodule import collate_fn 15 | from pvnet.models.base_model import BaseModel 16 | from pvnet.optimizers import AbstractOptimizer 17 | from pvnet.training.plots import plot_sample_forecasts, wandb_line_plot 18 | from pvnet.utils import validate_batch_against_config 19 | 20 | 21 | class PVNetLightningModule(pl.LightningModule): 22 | """Lightning module for training PVNet models""" 23 | 24 | def __init__( 25 | self, 26 | model: BaseModel, 27 | optimizer: AbstractOptimizer, 28 | save_all_validation_results: 
bool = False, 29 | ): 30 | """Lightning module for training PVNet models 31 | 32 | Args: 33 | model: The PVNet model 34 | optimizer: Optimizer 35 | save_all_validation_results: Whether to save all the validation predictions to wandb 36 | """ 37 | super().__init__() 38 | 39 | self.model = model 40 | self._optimizer = optimizer 41 | self.save_all_validation_results = save_all_validation_results 42 | 43 | # Model must have lr to allow tuning 44 | # This setting is only used when lr is tuned with callback 45 | self.lr = None 46 | 47 | def transfer_batch_to_device( 48 | self, 49 | batch: TensorBatch, 50 | device: torch.device, 51 | dataloader_idx: int, 52 | ) -> dict: 53 | """Method to move custom batches to a given device""" 54 | return copy_batch_to_device(batch, device) 55 | 56 | def _calculate_quantile_loss(self, y_quantiles: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 57 | """Calculate quantile loss. 58 | 59 | Note: 60 | Implementation copied from: 61 | https://pytorch-forecasting.readthedocs.io/en/stable/_modules/pytorch_forecasting 62 | /metrics/quantile.html#QuantileLoss.loss 63 | 64 | Args: 65 | y_quantiles: Quantile prediction of network 66 | y: Target values 67 | 68 | Returns: 69 | Quantile loss 70 | """ 71 | losses = [] 72 | for i, q in enumerate(self.model.output_quantiles): 73 | errors = y - y_quantiles[..., i] 74 | losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1)) 75 | losses = 2 * torch.cat(losses, dim=2) 76 | 77 | return losses.mean() 78 | 79 | def configure_optimizers(self): 80 | """Configure the optimizers using learning rate found with LR finder if used""" 81 | if self.lr is not None: 82 | # Use learning rate found by learning rate finder callback 83 | self._optimizer.lr = self.lr 84 | return self._optimizer(self.model) 85 | 86 | def _calculate_common_losses( 87 | self, 88 | y: torch.Tensor, 89 | y_hat: torch.Tensor, 90 | ) -> dict[str, torch.Tensor]: 91 | """Calculate losses common to train, and val""" 92 | 93 | losses = {} 94 
| 95 | if self.model.use_quantile_regression: 96 | losses["quantile_loss"] = self._calculate_quantile_loss(y_hat, y) 97 | y_hat = self.model._quantiles_to_prediction(y_hat) 98 | 99 | losses.update({"MSE": F.mse_loss(y_hat, y), "MAE": F.l1_loss(y_hat, y)}) 100 | 101 | return losses 102 | 103 | def training_step(self, batch: TensorBatch, batch_idx: int) -> torch.Tensor: 104 | """Run training step""" 105 | y_hat = self.model(batch) 106 | 107 | y = batch[self.model._target_key][:, -self.model.forecast_len :] 108 | 109 | losses = self._calculate_common_losses(y, y_hat) 110 | losses = {f"{k}/train": v for k, v in losses.items()} 111 | 112 | self.log_dict(losses, on_step=True, on_epoch=True) 113 | 114 | if self.model.use_quantile_regression: 115 | opt_target = losses["quantile_loss/train"] 116 | else: 117 | opt_target = losses["MAE/train"] 118 | return opt_target 119 | 120 | def _calculate_val_losses( 121 | self, 122 | y: torch.Tensor, 123 | y_hat: torch.Tensor, 124 | ) -> dict[str, torch.Tensor]: 125 | """Calculate additional losses only run in validation""" 126 | 127 | losses = {} 128 | 129 | if self.model.use_quantile_regression: 130 | metric_name = "val_fraction_below/fraction_below_{:.2f}_quantile" 131 | # Add fraction below each quantile for calibration 132 | for i, quantile in enumerate(self.model.output_quantiles): 133 | below_quant = y <= y_hat[..., i] 134 | # Mask values small values, which are dominated by night 135 | mask = y >= 0.01 136 | losses[metric_name.format(quantile)] = below_quant[mask].float().mean() 137 | 138 | return losses 139 | 140 | def _calculate_step_metrics( 141 | self, 142 | y: torch.Tensor, 143 | y_hat: torch.Tensor, 144 | ) -> tuple[np.array, np.array]: 145 | """Calculate the MAE and MSE at each forecast step""" 146 | 147 | mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0).cpu().numpy() 148 | mse_each_step = torch.mean((y_hat - y) ** 2, dim=0).cpu().numpy() 149 | 150 | return mae_each_step, mse_each_step 151 | 152 | def 
_store_val_predictions(self, batch: TensorBatch, y_hat: torch.Tensor) -> None: 153 | """Internally store the validation predictions""" 154 | 155 | target_key = self.model._target_key 156 | 157 | y = batch[target_key][:, -self.model.forecast_len :].cpu().numpy() 158 | y_hat = y_hat.cpu().numpy() 159 | ids = batch[f"{target_key}_id"].cpu().numpy() 160 | init_times_utc = pd.to_datetime( 161 | batch[f"{target_key}_time_utc"][:, self.model.history_len+1] 162 | .cpu().numpy().astype("datetime64[ns]") 163 | ) 164 | 165 | if self.model.use_quantile_regression: 166 | p_levels = self.model.output_quantiles 167 | else: 168 | p_levels = [0.5] 169 | y_hat = y_hat[..., None] 170 | 171 | ds_preds_batch = xr.Dataset( 172 | data_vars=dict( 173 | y_hat=(["sample_num", "forecast_step", "p_level"], y_hat), 174 | y=(["sample_num", "forecast_step"], y), 175 | ), 176 | coords=dict( 177 | ids=("sample_num", ids), 178 | init_times_utc=("sample_num", init_times_utc), 179 | p_level=p_levels, 180 | ), 181 | ) 182 | self.all_val_results.append(ds_preds_batch) 183 | 184 | def on_validation_epoch_start(self): 185 | """Run at start of val period""" 186 | # Set up stores which we will fill during validation 187 | self.all_val_results: list[xr.Dataset] = [] 188 | self._val_horizon_maes: list[np.array] = [] 189 | if self.current_epoch==0: 190 | self._val_persistence_horizon_maes: list[np.array] = [] 191 | 192 | # Plot some sample forecasts 193 | val_dataset = self.trainer.val_dataloaders.dataset 194 | 195 | plots_per_figure = 16 196 | num_figures = 2 197 | 198 | for plot_num in range(num_figures): 199 | idxs = np.arange(plots_per_figure) + plot_num * plots_per_figure 200 | idxs = idxs[idxs None: 234 | """Run validation step""" 235 | 236 | y_hat = self.model(batch) 237 | 238 | # Internally store the val predictions 239 | self._store_val_predictions(batch, y_hat) 240 | 241 | y = batch[self.model._target_key][:, -self.model.forecast_len :] 242 | 243 | losses = self._calculate_common_losses(y, y_hat) 
244 | losses = {f"{k}/val": v for k, v in losses.items()} 245 | 246 | losses.update(self._calculate_val_losses(y, y_hat)) 247 | 248 | # Calculate the horizon MAE/MSE metrics 249 | if self.model.use_quantile_regression: 250 | y_hat_mid = self.model._quantiles_to_prediction(y_hat) 251 | else: 252 | y_hat_mid = y_hat 253 | 254 | mae_step, mse_step = self._calculate_step_metrics(y, y_hat_mid) 255 | 256 | # Store to make horizon-MAE plot 257 | self._val_horizon_maes.append(mae_step) 258 | 259 | # Also add each step to logged metrics 260 | losses.update({f"val_step_MAE/step_{i:03}": m for i, m in enumerate(mae_step)}) 261 | losses.update({f"val_step_MSE/step_{i:03}": m for i, m in enumerate(mse_step)}) 262 | 263 | # Calculate the persistance losses - we only need to do this once per training run 264 | # not every epoch 265 | if self.current_epoch==0: 266 | y_persist = ( 267 | batch[self.model._target_key][:, -(self.model.forecast_len+1)] 268 | .unsqueeze(1).expand(-1, self.model.forecast_len) 269 | ) 270 | mae_step_persist, mse_step_persist = self._calculate_step_metrics(y, y_persist) 271 | self._val_persistence_horizon_maes.append(mae_step_persist) 272 | losses.update( 273 | { 274 | "MAE/val_persistence": mae_step_persist.mean(), 275 | "MSE/val_persistence": mse_step_persist.mean() 276 | } 277 | ) 278 | 279 | # Log the metrics 280 | self.log_dict(losses, on_step=False, on_epoch=True) 281 | 282 | def on_validation_epoch_end(self) -> None: 283 | """Run on epoch end""" 284 | 285 | ds_val_results = xr.concat(self.all_val_results, dim="sample_num") 286 | self.all_val_results = [] 287 | 288 | val_horizon_maes = np.mean(self._val_horizon_maes, axis=0) 289 | self._val_horizon_maes = [] 290 | 291 | # We only run this on the first epoch 292 | if self.current_epoch==0: 293 | val_persistence_horizon_maes = np.mean(self._val_persistence_horizon_maes, axis=0) 294 | self._val_persistence_horizon_maes = [] 295 | 296 | if isinstance(self.logger, pl.loggers.WandbLogger): 297 | # 
Calculate and log extreme error metrics 298 | val_error = ds_val_results["y"] - ds_val_results["y_hat"].sel(p_level=0.5) 299 | 300 | # Factor out this part of the string for brevity below 301 | s = "error_extremes/{}_percentile_median_forecast_error" 302 | s_abs = "error_extremes/{}_percentile_median_forecast_absolute_error" 303 | 304 | extreme_error_metrics = { 305 | s.format("2nd"): val_error.quantile(0.02).item(), 306 | s.format("5th"): val_error.quantile(0.05).item(), 307 | s.format("95th"): val_error.quantile(0.95).item(), 308 | s.format("98th"): val_error.quantile(0.98).item(), 309 | s_abs.format("95th"): np.abs(val_error).quantile(0.95).item(), 310 | s_abs.format("98th"): np.abs(val_error).quantile(0.98).item(), 311 | } 312 | 313 | self.log_dict(extreme_error_metrics, on_step=False, on_epoch=True) 314 | 315 | # Optionally save all validation results - these are overridden each epoch 316 | if self.save_all_validation_results: 317 | # Add attributes 318 | ds_val_results.attrs["epoch"] = self.current_epoch 319 | 320 | # Save locally to the wandb output dir 321 | wandb_log_dir = self.logger.experiment.dir 322 | filepath = f"{wandb_log_dir}/validation_results.netcdf" 323 | ds_val_results.to_netcdf(filepath) 324 | 325 | # Uplodad to wandb 326 | self.logger.experiment.save(filepath, base_path=wandb_log_dir, policy="now") 327 | 328 | # Create the horizon accuracy curve 329 | horizon_mae_plot = wandb_line_plot( 330 | x=np.arange(self.model.forecast_len), 331 | y=val_horizon_maes, 332 | xlabel="Horizon step", 333 | ylabel="MAE", 334 | title="Val horizon loss curve", 335 | ) 336 | 337 | wandb.log({"val_horizon_mae_plot": horizon_mae_plot}) 338 | 339 | # Create persistence horizon accuracy curve but only on first epoch 340 | if self.current_epoch==0: 341 | persist_horizon_mae_plot = wandb_line_plot( 342 | x=np.arange(self.model.forecast_len), 343 | y=val_persistence_horizon_maes, 344 | xlabel="Horizon step", 345 | ylabel="MAE", 346 | title="Val persistence horizon loss 
curve", 347 | ) 348 | wandb.log({"persistence_val_horizon_mae_plot": persist_horizon_mae_plot}) 349 | -------------------------------------------------------------------------------- /pvnet/training/plots.py: -------------------------------------------------------------------------------- 1 | """Plots logged during training""" 2 | from collections.abc import Sequence 3 | 4 | import matplotlib.pyplot as plt 5 | import pandas as pd 6 | import pylab 7 | import torch 8 | import wandb 9 | from ocf_data_sampler.numpy_sample.common_types import TensorBatch 10 | 11 | 12 | def wandb_line_plot( 13 | x: Sequence[float], 14 | y: Sequence[float], 15 | xlabel: str, 16 | ylabel: str, 17 | title: str | None = None 18 | ) -> wandb.plot.CustomChart: 19 | """Make a wandb line plot""" 20 | data = [[xi, yi] for (xi, yi) in zip(x, y)] 21 | table = wandb.Table(data=data, columns=[xlabel, ylabel]) 22 | return wandb.plot.line(table, xlabel, ylabel, title=title) 23 | 24 | 25 | def plot_sample_forecasts( 26 | batch: TensorBatch, 27 | y_hat: torch.Tensor, 28 | quantiles: list[float] | None, 29 | key_to_plot: str, 30 | ) -> plt.Figure: 31 | """Plot a batch of data and the forecast from that batch""" 32 | 33 | y = batch[key_to_plot].cpu().numpy() 34 | y_hat = y_hat.cpu().numpy() 35 | ids = batch[f"{key_to_plot}_id"].cpu().numpy().squeeze() 36 | times_utc = pd.to_datetime( 37 | batch[f"{key_to_plot}_time_utc"].cpu().numpy().squeeze().astype("datetime64[ns]") 38 | ) 39 | batch_size = y.shape[0] 40 | 41 | fig, axes = plt.subplots(4, 4, figsize=(16, 16)) 42 | 43 | for i, ax in enumerate(axes.ravel()[:batch_size]): 44 | 45 | ax.plot(times_utc[i], y[i], marker=".", color="k", label=r"$y$") 46 | 47 | if quantiles is None: 48 | ax.plot( 49 | times_utc[i][-len(y_hat[i]) :], 50 | y_hat[i], 51 | marker=".", 52 | color="r", 53 | label=r"$\hat{y}$", 54 | ) 55 | else: 56 | cm = pylab.get_cmap("twilight") 57 | for nq, q in enumerate(quantiles): 58 | ax.plot( 59 | times_utc[i][-len(y_hat[i]) :], 60 | y_hat[i, 
:, nq], 61 | color=cm(q), 62 | label=r"$\hat{y}$" + f"({q})", 63 | alpha=0.7, 64 | ) 65 | 66 | ax.set_title(f"ID: {ids[i]} | {times_utc[i][0].date()}", fontsize="small") 67 | 68 | xticks = [t for t in times_utc[i] if t.minute == 0][::2] 69 | ax.set_xticks(ticks=xticks, labels=[f"{t.hour:02}" for t in xticks], rotation=90) 70 | ax.grid() 71 | 72 | axes[0, 0].legend(loc="best") 73 | 74 | if batch_size<16: 75 | for ax in axes.ravel()[batch_size:]: 76 | ax.axis("off") 77 | 78 | for ax in axes[-1, :]: 79 | ax.set_xlabel("Time (hour of day)") 80 | 81 | title = f"Normed {key_to_plot.upper()} output" 82 | 83 | plt.suptitle(title) 84 | plt.tight_layout() 85 | 86 | return fig 87 | -------------------------------------------------------------------------------- /pvnet/training/train.py: -------------------------------------------------------------------------------- 1 | """Training""" 2 | import logging 3 | import os 4 | import shutil 5 | 6 | import hydra 7 | from lightning.pytorch import ( 8 | Callback, 9 | LightningDataModule, 10 | LightningModule, 11 | Trainer, 12 | seed_everything, 13 | ) 14 | from lightning.pytorch.callbacks import ModelCheckpoint 15 | from lightning.pytorch.loggers import Logger, WandbLogger 16 | from omegaconf import DictConfig, OmegaConf 17 | 18 | from pvnet.utils import ( 19 | DATA_CONFIG_NAME, 20 | FULL_CONFIG_NAME, 21 | MODEL_CONFIG_NAME, 22 | ) 23 | 24 | log = logging.getLogger(__name__) 25 | 26 | 27 | def resolve_monitor_loss(output_quantiles: list | None) -> str: 28 | """Return the desired metric to monitor based on whether quantile regression is being used. 29 | 30 | The adds the option to use something like: 31 | monitor: "${resolve_monitor_loss:${model.model.output_quantiles}}" 32 | 33 | in early stopping and model checkpoint callbacks so the callbacks config does not need to be 34 | modified depending on whether quantile regression is being used or not. 
35 | """ 36 | if output_quantiles is None: 37 | return "MAE/val" 38 | else: 39 | return "quantile_loss/val" 40 | 41 | 42 | OmegaConf.register_new_resolver("resolve_monitor_loss", resolve_monitor_loss) 43 | 44 | 45 | def train(config: DictConfig) -> None: 46 | """Contains training pipeline. 47 | 48 | Instantiates all PyTorch Lightning objects from config. 49 | 50 | Args: 51 | config (DictConfig): Configuration composed by Hydra. 52 | """ 53 | 54 | # Set seed for random number generators in pytorch, numpy and python.random 55 | if "seed" in config: 56 | seed_everything(config.seed, workers=True) 57 | 58 | # Init lightning datamodule 59 | datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule) 60 | 61 | # Init lightning model 62 | model: LightningModule = hydra.utils.instantiate(config.model) 63 | 64 | # Init lightning loggers 65 | loggers: list[Logger] = [] 66 | if "logger" in config: 67 | for _, lg_conf in config.logger.items(): 68 | loggers.append(hydra.utils.instantiate(lg_conf)) 69 | 70 | # Init lightning callbacks 71 | callbacks: list[Callback] = [] 72 | if "callbacks" in config: 73 | for _, cb_conf in config.callbacks.items(): 74 | callbacks.append(hydra.utils.instantiate(cb_conf)) 75 | 76 | # Align the wandb id with the checkpoint path 77 | # - only works if wandb logger and model checkpoint used 78 | # - this makes it easy to push the model to huggingface 79 | use_wandb_logger = False 80 | for logger in loggers: 81 | if isinstance(logger, WandbLogger): 82 | use_wandb_logger = True 83 | wandb_logger = logger 84 | break 85 | 86 | # Set the output directory based in the wandb-id of the run 87 | if use_wandb_logger: 88 | for callback in callbacks: 89 | if isinstance(callback, ModelCheckpoint): 90 | # Calling the .experiment property instantiates a wandb run 91 | wandb_id = wandb_logger.experiment.id 92 | 93 | # Save the run results to the expected parent folder but with the folder name 94 | # set by the wandb ID 95 | save_dir = 
"/".join(callback.dirpath.split("/")[:-1] + [wandb_id]) 96 | 97 | callback.dirpath = save_dir 98 | 99 | # Save the model config 100 | os.makedirs(save_dir, exist_ok=True) 101 | OmegaConf.save(config.model, f"{save_dir}/{MODEL_CONFIG_NAME}") 102 | 103 | # Save the data config to the output directory and to wandb 104 | data_config = config.datamodule.configuration 105 | shutil.copyfile(data_config, f"{save_dir}/{DATA_CONFIG_NAME}") 106 | wandb_logger.experiment.save(f"{save_dir}/{DATA_CONFIG_NAME}", base_path=save_dir) 107 | 108 | # Save the full hydra config to the output directory and to wandb 109 | OmegaConf.save(config, f"{save_dir}/{FULL_CONFIG_NAME}") 110 | wandb_logger.experiment.save(f"{save_dir}/{FULL_CONFIG_NAME}", base_path=save_dir) 111 | 112 | break 113 | 114 | trainer: Trainer = hydra.utils.instantiate( 115 | config.trainer, 116 | logger=loggers, 117 | _convert_="partial", 118 | callbacks=callbacks, 119 | ) 120 | 121 | # Train the model completely 122 | trainer.fit(model=model, datamodule=datamodule) 123 | -------------------------------------------------------------------------------- /pvnet/utils.py: -------------------------------------------------------------------------------- 1 | """Utils""" 2 | import logging 3 | from typing import TYPE_CHECKING 4 | 5 | import rich.syntax 6 | import rich.tree 7 | from lightning.pytorch.utilities import rank_zero_only 8 | from omegaconf import DictConfig, OmegaConf 9 | 10 | if TYPE_CHECKING: 11 | from pvnet.models.base_model import BaseModel 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | PYTORCH_WEIGHTS_NAME = "model_weights.safetensors" 17 | MODEL_CONFIG_NAME = "model_config.yaml" 18 | DATA_CONFIG_NAME = "data_config.yaml" 19 | DATAMODULE_CONFIG_NAME = "datamodule_config.yaml" 20 | FULL_CONFIG_NAME = "full_experiment_config.yaml" 21 | MODEL_CARD_NAME = "README.md" 22 | 23 | 24 | def run_config_utilities(config: DictConfig) -> None: 25 | """A couple of optional utilities. 
26 | 27 | Controlled by main config file: 28 | - forcing debug friendly configuration 29 | 30 | Modifies DictConfig in place. 31 | 32 | Args: 33 | config (DictConfig): Configuration composed by Hydra. 34 | """ 35 | 36 | # Enable adding new keys to config 37 | OmegaConf.set_struct(config, False) 38 | 39 | # Force debugger friendly configuration if 40 | if config.trainer.get("fast_dev_run"): 41 | logger.info("Forcing debugger friendly configuration! ") 42 | # Debuggers don't like GPUs or multiprocessing 43 | if config.trainer.get("gpus"): 44 | config.trainer.gpus = 0 45 | if config.datamodule.get("pin_memory"): 46 | config.datamodule.pin_memory = False 47 | if config.datamodule.get("num_workers"): 48 | config.datamodule.num_workers = 0 49 | if config.datamodule.get("prefetch_factor"): 50 | config.datamodule.prefetch_factor = None 51 | 52 | # Disable adding new keys to config 53 | OmegaConf.set_struct(config, True) 54 | 55 | 56 | @rank_zero_only 57 | def print_config( 58 | config: DictConfig, 59 | fields: tuple[str] = ( 60 | "trainer", 61 | "model", 62 | "datamodule", 63 | "callbacks", 64 | "logger", 65 | "seed", 66 | ), 67 | resolve: bool = True, 68 | ) -> None: 69 | """Prints content of DictConfig using Rich library and its tree structure. 70 | 71 | Args: 72 | config (DictConfig): Configuration composed by Hydra. 73 | fields (Sequence[str], optional): Determines which main fields from config will 74 | be printed and in what order. 75 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 
76 | """ 77 | 78 | style = "dim" 79 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 80 | 81 | for field in fields: 82 | branch = tree.add(field, style=style, guide_style=style) 83 | 84 | config_section = config.get(field) 85 | 86 | branch_content = str(config_section) 87 | if isinstance(config_section, DictConfig): 88 | branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) 89 | 90 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 91 | 92 | rich.print(tree) 93 | 94 | 95 | def validate_batch_against_config( 96 | batch: dict, 97 | model: "BaseModel", 98 | ) -> None: 99 | """Validates tensor shapes in batch against model configuration.""" 100 | logger.info("Performing batch shape validation against model config.") 101 | 102 | # NWP validation 103 | if hasattr(model, 'nwp_encoders_dict'): 104 | if "nwp" not in batch: 105 | raise ValueError( 106 | "Model configured with 'nwp_encoders_dict' but 'nwp' data missing from batch." 107 | ) 108 | 109 | for source, nwp_data in batch["nwp"].items(): 110 | if source in model.nwp_encoders_dict: 111 | 112 | enc = model.nwp_encoders_dict[source] 113 | expected_channels = enc.in_channels 114 | if model.add_image_embedding_channel: 115 | expected_channels -= 1 116 | 117 | expected = (nwp_data["nwp"].shape[0], enc.sequence_length, 118 | expected_channels, enc.image_size_pixels, enc.image_size_pixels) 119 | if tuple(nwp_data["nwp"].shape) != expected: 120 | actual_shape = tuple(nwp_data['nwp'].shape) 121 | raise ValueError( 122 | f"NWP.{source} shape mismatch: expected {expected}, got {actual_shape}" 123 | ) 124 | 125 | # Satellite validation 126 | if hasattr(model, 'sat_encoder'): 127 | if "satellite_actual" not in batch: 128 | raise ValueError( 129 | "Model configured with 'sat_encoder' but 'satellite_actual' missing from batch." 
130 | ) 131 | 132 | enc = model.sat_encoder 133 | expected_channels = enc.in_channels 134 | if model.add_image_embedding_channel: 135 | expected_channels -= 1 136 | 137 | expected = (batch["satellite_actual"].shape[0], enc.sequence_length, expected_channels, 138 | enc.image_size_pixels, enc.image_size_pixels) 139 | if tuple(batch["satellite_actual"].shape) != expected: 140 | actual_shape = tuple(batch['satellite_actual'].shape) 141 | raise ValueError(f"Satellite shape mismatch: expected {expected}, got {actual_shape}") 142 | 143 | # GSP/Site validation 144 | key = model._target_key 145 | if key in batch: 146 | total_minutes = model.history_minutes + model.forecast_minutes 147 | interval = model.interval_minutes 148 | expected_len = total_minutes // interval + 1 149 | expected = (batch[key].shape[0], expected_len) 150 | if tuple(batch[key].shape) != expected: 151 | actual_shape = tuple(batch[key].shape) 152 | raise ValueError( 153 | f"{key.upper()} shape mismatch: expected {expected}, got {actual_shape}" 154 | ) 155 | 156 | logger.info("Batch shape validation successful!") 157 | 158 | 159 | def validate_gpu_config(config: DictConfig) -> None: 160 | """Abort if multiple GPUs requested by mistake i.e. `devices: 2` instead of `[2]`.""" 161 | tr = config.get("trainer", {}) 162 | dev = tr.get("devices") 163 | 164 | if isinstance(dev, int) and dev > 1: 165 | raise ValueError( 166 | f"Detected `devices: {dev}` — this requests {dev} GPUs. " 167 | "If you meant a specific GPU (e.g. GPU 2), use `devices: [2]`. " 168 | "Parallel training not supported." 
169 | ) 170 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=67", "wheel", "setuptools-git-versioning>=2.0,<3"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name="PVNet" 7 | description = "PVNet" 8 | authors = [{name="Peter Dudfield", email="info@openclimatefix.org"}] 9 | dynamic = ["version"] 10 | license={file="LICENCE"} 11 | readme = {file="README.md", content-type="text/markdown"} 12 | requires-python = ">=3.11" 13 | 14 | dependencies = [ 15 | "ocf-data-sampler>=0.5.20", 16 | "numpy", 17 | "pandas", 18 | "matplotlib", 19 | "xarray", 20 | "h5netcdf", 21 | "torch>=2.0.0", 22 | "lightning", 23 | "typer", 24 | "sqlalchemy", 25 | "fsspec[s3]", 26 | "wandb", 27 | "huggingface-hub", 28 | "tqdm", 29 | "omegaconf", 30 | "hydra-core", 31 | "rich", 32 | "einops", 33 | "safetensors", 34 | ] 35 | 36 | [dependency-groups] 37 | dev=[ 38 | "ruff", 39 | "mypy", 40 | "pytest", 41 | "pytest-cov", 42 | ] 43 | 44 | [tool.setuptools-git-versioning] 45 | enabled = true 46 | 47 | [tool.setuptools.package-dir] 48 | "pvnet" = "pvnet" 49 | 50 | [tool.mypy] 51 | exclude = [ 52 | "^tests/", 53 | ] 54 | disallow_untyped_defs = true 55 | disallow_any_unimported = true 56 | no_implicit_optional = true 57 | check_untyped_defs = true 58 | warn_return_any = true 59 | warn_unused_ignores = true 60 | show_error_codes = true 61 | warn_unreachable = true 62 | 63 | [[tool.mypy.overrides]] 64 | module = [] 65 | ignore_missing_imports = true 66 | 67 | [tool.pytest.ini_options] 68 | minversion = "6.0" 69 | addopts = "-ra -q" 70 | testpaths = [ 71 | "tests", 72 | ] 73 | 74 | [tool.ruff] 75 | line-length = 100 76 | exclude = ["tests"] 77 | target-version = "py310" 78 | 79 | [tool.ruff.lint] 80 | extend-select = ["E", "D", "I"] 81 | ignore = ["D200","D202","D210","D212","D415","D105"] 82 | 83 | 
[tool.ruff.lint.mccabe] 84 | # Unlike Flake8, default to a complexity level of 10. 85 | max-complexity = 10 86 | 87 | [tool.ruff.lint.pydocstyle] 88 | # Use Google-style docstrings. 89 | convention = "google" 90 | 91 | [tool.ruff.lint.per-file-ignores] 92 | "__init__.py" = ["F401", "E402"] 93 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | """Run training. 2 | 3 | This file can be run for example using 4 | >> python run.py experiment=example_simple 5 | """ 6 | 7 | import logging 8 | import sys 9 | 10 | import hydra 11 | from omegaconf import DictConfig 12 | 13 | from pvnet.training import train 14 | from pvnet.utils import print_config, run_config_utilities, validate_gpu_config 15 | 16 | logging.basicConfig(stream=sys.stdout, level=logging.ERROR) 17 | 18 | 19 | 20 | @hydra.main(config_path="configs/", config_name="config.yaml", version_base="1.2") 21 | def main(config: DictConfig) -> None: 22 | """Runs training""" 23 | 24 | # A couple of optional utilities: 25 | # - disabling python warnings 26 | # - forcing debug friendly configuration 27 | # - forcing multi-gpu friendly configuration 28 | run_config_utilities(config) 29 | validate_gpu_config(config) 30 | print_config(config, resolve=True) 31 | 32 | return train(config) 33 | 34 | 35 | if __name__ == "__main__": 36 | main() 37 | -------------------------------------------------------------------------------- /scripts/backtest_sites.py: -------------------------------------------------------------------------------- 1 | """ 2 | A script to run backtest for PVNet for specific sites 3 | 4 | Use: 5 | 6 | - This script uses hydra to construct the config, just like in `run.py`. 
So you need to make sure 7 | that the data config is set up appropriately for the model being run in this script. 8 | - The following variables are hard-coded near the top of the script and should be changed prior to 9 | use: 10 | - the number of workers to use; 11 | - the PVNet model checkpoint (either local or HuggingFace repo details); 12 | - the time range over which predictions are made; 13 | - the output directory where the results are stored. 14 | 15 | - Outputs netCDF files with the predictions for each t0 in separate files; 16 | each file has forecasts for all sites. 17 | Time resolution of the forecast t0s is the same as the time resolution of the generation data. 18 | 19 | - WARNING: if you are running the backtest for multiple sites (i.e. the generation data 20 | contains multiple sites), this script currently assumes that they all have the same t0s 21 | available in the generation data. If their periods do not overlap, it may be best to run 22 | this script multiple times with a different generation file for each site; 23 | otherwise silent errors could occur.
24 | 25 | ``` 26 | python scripts/backtest_sites.py 27 | ``` 28 | 29 | """ 30 | import os 31 | 32 | import hydra 33 | import numpy as np 34 | import pandas as pd 35 | import torch 36 | import xarray as xr 37 | from ocf_data_sampler.config import load_yaml_configuration 38 | from ocf_data_sampler.load.load_dataset import get_dataset_dict 39 | from ocf_data_sampler.numpy_sample.common_types import NumpyBatch 40 | from ocf_data_sampler.torch_datasets.datasets.site import SitesDatasetConcurrent 41 | from ocf_data_sampler.torch_datasets.sample.base import batch_to_tensor, copy_batch_to_device 42 | from omegaconf import DictConfig 43 | from torch.utils.data import DataLoader 44 | from tqdm import tqdm 45 | 46 | from pvnet.load_model import get_model_from_checkpoints 47 | from pvnet.models.base_model import BaseModel as PVNetBaseModel 48 | 49 | # ------------------------------------------------------------------ 50 | # USER CONFIGURED VARIABLES TO RUN THE SCRIPT 51 | 52 | num_workers = 2 53 | 54 | # Directory path to save results 55 | output_dir: str = "example_repo" 56 | 57 | # Local directory to load the PVNet checkpoint from. 
By default this should pull the best performing 58 | # checkpoint on the val set, set to None if using HF 59 | model_checkpoint_dir: str | None = None 60 | 61 | 62 | # Location to download exported PVNet model on HF, set to None if using local 63 | hf_model_id: str | None = "openclimatefix/example_repo" 64 | hf_revision: str | None = "95b1658c2b771e567fb3a0379e9bd600e0b1d209" 65 | 66 | # Forecasts will be made for all available init times between these 67 | start_datetime = "2024-06-05 00:00" 68 | end_datetime = "2024-06-05 03:00" 69 | 70 | # ------------------------------------------------------------------ 71 | # DERIVED VARIABLES 72 | 73 | # This will run on GPU if it exists 74 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 75 | 76 | # ------------------------------------------------------------------ 77 | # GLOBAL VARIABLES 78 | 79 | # When sun as elevation below this, the forecast is set to zero 80 | MIN_DAY_ELEVATION = 0 81 | 82 | # ------------------------------------------------------------------ 83 | # FUNCTIONS 84 | 85 | def preds_to_dataarray(preds, model, valid_times, site_ids): 86 | """Put numpy array of predictions into a dataarray""" 87 | 88 | if model.use_quantile_regression: 89 | output_labels = [f"forecast_mw_plevel_{int(q*100):02}" for q in model.output_quantiles] 90 | output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw" 91 | else: 92 | output_labels = ["forecast_mw"] 93 | preds = preds[..., np.newaxis] 94 | da = xr.DataArray( 95 | data=preds, 96 | dims=["site_id", "target_datetime_utc", "output_label"], 97 | coords=dict( 98 | site_id = site_ids, 99 | target_datetime_utc=valid_times, 100 | output_label=output_labels, 101 | ), 102 | ) 103 | return da 104 | 105 | def get_sites_ds(config_path: str) -> xr.Dataset: 106 | """Load site data from the path in the data config. 
107 | 108 | Args: 109 | config_path: Path to the data configuration file 110 | 111 | Returns: 112 | xarray.Dataset of PV sites data 113 | """ 114 | config = load_yaml_configuration(config_path) 115 | datasets_dict = get_dataset_dict(config.input_data) 116 | return datasets_dict["site"].to_dataset(name="site") 117 | 118 | 119 | class ModelPipe: 120 | """A class to conveniently make and process predictions from batches""" 121 | 122 | def __init__(self, model, ds_site: xr.Dataset, interval_start, interval_end, time_resolution): 123 | """A class to conveniently make and process predictions from batches 124 | 125 | Args: 126 | model: PVNet site level model 127 | ds_site: xarray dataset of pv site true values and capacities 128 | interval_start: The start timestamp (inclusive) for the prediction interval. 129 | interval_end: The end timestamp (exclusive) for the prediction interval. 130 | time_resolution: The time resolution (e.g., in minutes) for the prediction intervals. 131 | 132 | """ 133 | self.model = model 134 | self.ds_site = ds_site 135 | self.interval_start = interval_start 136 | self.interval_end = interval_end 137 | self.time_resolution = time_resolution 138 | 139 | def predict_batch(self, batch: NumpyBatch) -> xr.Dataset: 140 | """Run the batch through the model and compile the predictions into an xarray DataArray 141 | 142 | Args: 143 | batch: A batch containing inputs for a site 144 | 145 | Returns: 146 | xarray.Dataset of site forecasts for the sample 147 | """ 148 | 149 | tensor_batch = batch_to_tensor(batch) 150 | # First available timestamp in the sample (this is t0 + interval_start) 151 | first_time = pd.Timestamp(tensor_batch["site_time_utc"][0][0].item()) 152 | # Compute t0 (true start of forecast) 153 | t0 = first_time - pd.Timedelta(self.interval_start) 154 | 155 | # Generate valid times for inference (only t0 to t0 + interval_end) 156 | valid_times = pd.date_range( 157 | start=t0 + pd.Timedelta(self.time_resolution.astype(int), "min"), 158 | 
end=t0 + pd.Timedelta(self.interval_end), 159 | freq=f"{self.time_resolution.astype(int)}min", 160 | ) 161 | # Get capacity for this site 162 | site_capacities = [float(i) for i in self.ds_site["capacity_kwp"].values] 163 | # Get solar elevation and create sundown mask 164 | elevation = (tensor_batch['solar_elevation'] - 0.5) * 180 165 | # We only need elevation mask for forecasted values, not history 166 | elevation = elevation[:, -valid_times.shape[0]:] 167 | site_ids = self.ds_site["site_id"].values 168 | 169 | da_sundown_mask = xr.DataArray( 170 | data=elevation < MIN_DAY_ELEVATION, 171 | dims=["site_id", "target_datetime_utc"], 172 | coords=dict(site_id=site_ids, 173 | target_datetime_utc=valid_times, 174 | ), 175 | ) 176 | with torch.no_grad(): 177 | # Run through model to get 0-1 predictions 178 | tensor_batch = copy_batch_to_device(tensor_batch, device) 179 | y_normed = self.model(tensor_batch).detach().cpu().numpy() 180 | 181 | da_normed = preds_to_dataarray(y_normed, self.model, valid_times, site_ids) 182 | 183 | # Multiply normalised forecasts by capacity and clip negatives 184 | # Define multipliers for each id 185 | capacity_multipliers = xr.DataArray( 186 | data=site_capacities, 187 | dims=["site_id"], 188 | coords={"site_id": site_ids} 189 | ) 190 | da_abs = da_normed.clip(0, None) * capacity_multipliers 191 | 192 | # Apply sundown mask 193 | da_abs = da_abs.where(~da_sundown_mask).fillna(0.0) 194 | da_abs = da_abs.expand_dims(dim="init_time_utc", axis=0).assign_coords( 195 | init_time_utc=np.array([t0], dtype="datetime64[ns]") 196 | ) 197 | 198 | return da_abs 199 | 200 | 201 | @hydra.main(config_path="../configs", config_name="config.yaml", version_base="1.2") 202 | def main(config: DictConfig): 203 | """Runs the backtest""" 204 | 205 | dataloader_kwargs = dict( 206 | shuffle=False, 207 | batch_size=None, 208 | num_workers=num_workers, 209 | prefetch_factor=2 if num_workers>0 else None, 210 | multiprocessing_context="spawn" if num_workers>0 else 
None, 211 | pin_memory=False, 212 | drop_last=False, 213 | persistent_workers=False, 214 | sampler=None, 215 | batch_sampler=None, 216 | collate_fn=None, 217 | timeout=0, 218 | worker_init_fn=None, 219 | ) 220 | 221 | # Set up output dir 222 | os.makedirs(output_dir) 223 | 224 | # load yaml file 225 | unpacked_configuration = load_yaml_configuration(config.datamodule.configuration) 226 | 227 | interval_start = np.timedelta64( 228 | unpacked_configuration.input_data.site.interval_start_minutes, "m" 229 | ) 230 | interval_end = np.timedelta64(unpacked_configuration.input_data.site.interval_end_minutes, "m") 231 | time_resolution = np.timedelta64( 232 | unpacked_configuration.input_data.site.time_resolution_minutes, "m" 233 | ) 234 | 235 | # Create dataset 236 | dataset = SitesDatasetConcurrent( 237 | config.datamodule.configuration, start_time=start_datetime, end_time=end_datetime 238 | ) 239 | 240 | # Load the site data 241 | ds_sites = get_sites_ds(config.datamodule.configuration) 242 | 243 | # Create a dataloader 244 | dataloader = DataLoader(dataset, **dataloader_kwargs) 245 | 246 | # Load the PVNet model 247 | if model_checkpoint_dir: 248 | model, *_ = get_model_from_checkpoints([model_checkpoint_dir], val_best=True) 249 | model.eval() 250 | model.to(device) 251 | elif hf_model_id: 252 | model = PVNetBaseModel.from_pretrained( 253 | model_id=hf_model_id, 254 | revision=hf_revision).to(device).eval() 255 | else: 256 | raise ValueError("Provide a model checkpoint or a HuggingFace model") 257 | 258 | # Create object to make predictions 259 | model_pipe = ModelPipe(model, ds_sites, interval_start, interval_end, time_resolution) 260 | 261 | # Loop through the batches 262 | pbar = tqdm(total=len(dataset)) 263 | for i, batch in enumerate(dataloader): 264 | try: 265 | # Make predictions 266 | ds_abs_all = model_pipe.predict_batch(batch) 267 | t0 = ds_abs_all.init_time_utc.values[0] 268 | # Save the predictions 269 | filename = f"{output_dir}/{t0}.nc" 270 | 
ds_abs_all.to_netcdf(filename) 271 | 272 | pbar.update() 273 | except Exception as e: 274 | print(f"Exception {e} at batch {i}") 275 | pass 276 | 277 | pbar.close() 278 | del dataloader 279 | 280 | 281 | if __name__ == "__main__": 282 | main() 283 | -------------------------------------------------------------------------------- /scripts/backtest_uk_gsp.py: -------------------------------------------------------------------------------- 1 | """ 2 | A script to run backtest for PVNet and the summation model for UK regional and national 3 | 4 | Use: 5 | 6 | - This script uses exported PVNet and PVNet summation models stored either locally or on huggingface 7 | - The save directory, model paths, the backtest time range, the input data paths, and number of 8 | workers used are near the top of the script as hard-coded user variables. These should be changed. 9 | 10 | 11 | ``` 12 | python backtest_uk_gsp.py 13 | ``` 14 | 15 | """ 16 | 17 | import logging 18 | import os 19 | 20 | import numpy as np 21 | import pandas as pd 22 | import torch 23 | import xarray as xr 24 | import yaml 25 | from ocf_data_sampler.torch_datasets.sample.base import batch_to_tensor, copy_batch_to_device 26 | from pvnet_summation.data.datamodule import StreamedDataset 27 | from pvnet_summation.models.base_model import BaseModel as SummationBaseModel 28 | from torch.utils.data import DataLoader 29 | from tqdm import tqdm 30 | 31 | from pvnet.models.base_model import BaseModel as PVNetBaseModel 32 | 33 | # ------------------------------------------------------------------ 34 | # USER CONFIGURED VARIABLES 35 | output_dir = "/home/james/tmp/test_backtest/pvnet_v2" 36 | 37 | # Number of workers to use in the dataloader 38 | num_workers = 16 39 | 40 | # Location of the exported PVNet and summation model pair 41 | pvnet_model_name: str = "openclimatefix/pvnet_uk_region" 42 | pvnet_model_version: str | None = "ff09e4aee871fe094d3a2dabe9d9cea50e4b5485" 43 | 44 | summation_model_name: str = 
"openclimatefix/pvnet_v2_summation" 45 | summation_model_version: str | None = "d746683893330fe3380e57e65d40812daa343c8e" 46 | 47 | # Forecasts will be made for all available init times between these 48 | start_datetime: str | None = "2022-01-01 00:00" 49 | end_datetime: str | None = "2022-12-31 23:30" 50 | 51 | # The paths to the input data for the backtest 52 | backtest_paths = { 53 | "gsp": "/mnt/raphael/fast/crops/pv/pvlive_gsp_new_boundaries_2019-2025.zarr", 54 | "nwp": { 55 | "ukv": [ 56 | "/mnt/raphael/fast/crops/nwp/ukv/UKV_v7/UKV_intermediate_version_7.1.zarr", 57 | "/mnt/raphael/fast/crops/nwp/ukv/UKV_v7/UKV_2021_missing.zarr", 58 | "/mnt/raphael/fast/crops/nwp/ukv/UKV_v7/UKV_2022.zarr", 59 | ], 60 | "ecmwf": [ 61 | "/mnt/raphael/fast/crops/nwp/ecmwf/uk_v3/ECMWF_2019.zarr", 62 | "/mnt/raphael/fast/crops/nwp/ecmwf/uk_v3/ECMWF_2020.zarr", 63 | "/mnt/raphael/fast/crops/nwp/ecmwf/uk_v3/ECMWF_2021.zarr", 64 | "/mnt/raphael/fast/crops/nwp/ecmwf/uk_v3/ECMWF_2022.zarr", 65 | ], 66 | "cloudcasting": "/mnt/raphael/fast/cloudcasting/simvp.zarr", 67 | }, 68 | "satellite": [ 69 | "/mnt/raphael/fast/crops/sat/uk_sat_crops/v1/2019_nonhrv.zarr", 70 | "/mnt/raphael/fast/crops/sat/uk_sat_crops/v1/2020_nonhrv.zarr", 71 | "/mnt/raphael/fast/crops/sat/uk_sat_crops/v1/2021_nonhrv.zarr", 72 | "/mnt/raphael/fast/crops/sat/uk_sat_crops/v1/2022_nonhrv.zarr", 73 | ], 74 | } 75 | 76 | # When sun as elevation below this, the forecast is set to zero 77 | MIN_DAY_ELEVATION = 0 78 | 79 | # ------------------------------------------------------------------ 80 | 81 | logger = logging.getLogger(__name__) 82 | 83 | # This will run on GPU if it exists 84 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 85 | 86 | # ------------------------------------------------------------------ 87 | # FUNCTIONS 88 | 89 | _model_mismatch_msg = ( 90 | "The PVNet version running in this app is {}/{}. 
The summation model running in this app was " 91 | "trained on outputs from PVNet version {}/{}. Combining these models may lead to an error if " 92 | "the shape of PVNet output doesn't match the expected shape of the summation model. Combining " 93 | "may lead to unreliable results even if the shapes match." 94 | ) 95 | 96 | def populate_config_with_data_data_filepaths(config: dict) -> dict: 97 | """Overwrite the data source filepaths in the config (in place) with `backtest_paths` 98 | 99 | Args: 100 | config: The data config 101 | """ 102 | 103 | # Replace the GSP data path 104 | config["input_data"]["gsp"]["zarr_path"] = backtest_paths["gsp"] 105 | 106 | # Replace satellite data path if using it 107 | if "satellite" in config["input_data"]: 108 | if config["input_data"]["satellite"]["zarr_path"] != "": 109 | config["input_data"]["satellite"]["zarr_path"] = backtest_paths["satellite"] 110 | 111 | # NWP is nested so must be treated separately - paths are looked up by NWP provider 112 | if "nwp" in config["input_data"]: 113 | nwp_config = config["input_data"]["nwp"] 114 | for nwp_source in nwp_config.keys(): 115 | provider = nwp_config[nwp_source]["provider"] 116 | assert provider in backtest_paths["nwp"], f"Missing NWP path: {provider}" 117 | nwp_config[nwp_source]["zarr_path"] = backtest_paths["nwp"][provider] 118 | 119 | return config 120 | 121 | 122 | def overwrite_config_dropouts(config: dict) -> dict: 123 | """Disable satellite dropout in the config (in place) for the backtest 124 | 125 | Args: 126 | config: The data config 127 | """ 128 | if "satellite" in config["input_data"]: 129 | 130 | satellite_config = config["input_data"]["satellite"] 131 | 132 | if satellite_config["zarr_path"] != "": 133 | satellite_config["dropout_timedeltas_minutes"] = [] 134 | satellite_config["dropout_fraction"] = 0 135 | 136 | # Don't modify NWP dropout since this accounts for the expected NWP delay 137 | 138 | return config 139 | 140 | 141 | class BacktestStreamedDataset(StreamedDataset): 142 | """A torch dataset object used only for backtesting""" 143 | 144
| def _get_sample(self, t0: pd.Timestamp) -> ...: 145 | """Generate a concurrent PVNet sample for given init-time + augment for backtesting. 146 | 147 | Args: 148 | t0: init-time for sample 149 | """ 150 | 151 | sample = super()._get_sample(t0) 152 | 153 | total_capacity = self.national_gsp_data.sel(time_utc=t0).effective_capacity_mwp.item() 154 | 155 | sample.update( 156 | { 157 | "backtest_t0": t0, 158 | "backtest_national_capacity": total_capacity, 159 | } 160 | ) 161 | 162 | return sample 163 | 164 | 165 | class Forecaster: 166 | """Class for making solar forecasts for all GB GSPs and national total""" 167 | 168 | def __init__(self): 169 | """Class for making solar forecasts for all GB GSPs and national total 170 | """ 171 | 172 | # Load the GSP-level model 173 | self.model = PVNetBaseModel.from_pretrained( 174 | model_id=pvnet_model_name, 175 | revision=pvnet_model_version, 176 | ).to(device).eval() 177 | 178 | # Load the summation model 179 | self.sum_model = SummationBaseModel.from_pretrained( 180 | model_id=summation_model_name, 181 | revision=summation_model_version, 182 | ).to(device).eval() 183 | 184 | # Compare the current GSP model with the one the summation model was trained on 185 | datamodule_path = SummationBaseModel.get_datamodule_config( 186 | model_id=summation_model_name, 187 | revision=summation_model_version, 188 | ) 189 | with open(datamodule_path) as cfg: 190 | sum_pvnet_cfg = yaml.load(cfg, Loader=yaml.FullLoader)["pvnet_model"] 191 | 192 | sum_expected_gsp_model = (sum_pvnet_cfg["model_id"], sum_pvnet_cfg["revision"]) 193 | this_gsp_model = (pvnet_model_name, pvnet_model_version) 194 | 195 | if sum_expected_gsp_model != this_gsp_model: 196 | logger.warning(_model_mismatch_msg.format(*this_gsp_model, *sum_expected_gsp_model)) 197 | 198 | # These are the steps this forecast will predict for 199 | self.steps = pd.timedelta_range( 200 | start="30min", 201 | freq="30min", 202 | periods=self.model.forecast_len, 203 | ) 204 | 205 |
@torch.inference_mode() 206 | def predict(self, sample: dict) -> xr.Dataset: 207 | """Make GSP-level and national forecasts for one sample and return them as a Dataset""" 208 | 209 | x = copy_batch_to_device(batch_to_tensor(sample["pvnet_inputs"]), device) 210 | 211 | # Run batch through model 212 | normed_preds = self.model(x).detach().cpu().numpy() 213 | 214 | # Calculate sun mask 215 | # The dataloader normalises solar elevation data to the range [0, 1] 216 | elevation_degrees = (sample["pvnet_inputs"]["solar_elevation"] - 0.5) * 180 217 | # We only need elevation mask for forecasted values, not history 218 | elevation_degrees = elevation_degrees[:, -normed_preds.shape[1]:] 219 | sun_down_masks = elevation_degrees < MIN_DAY_ELEVATION 220 | 221 | # Convert GSP results to xarray DataArray 222 | t0 = sample["backtest_t0"] 223 | gsp_ids = sample["pvnet_inputs"]["gsp_id"] 224 | 225 | da_normed = self.to_dataarray( 226 | normed_preds, 227 | t0, 228 | gsp_ids, 229 | self.model.output_quantiles, 230 | ) 231 | 232 | da_sundown_mask = self.to_dataarray(sun_down_masks, t0, gsp_ids, None) 233 | 234 | # Multiply normalised forecasts by capacities and clip negatives 235 | da_abs = ( 236 | da_normed.clip(0, None) 237 | * sample["pvnet_inputs"]["gsp_effective_capacity_mwp"][None, :, None, None].numpy() 238 | ) 239 | 240 | # Apply sundown mask 241 | da_abs = da_abs.where(~da_sundown_mask).fillna(0.0) 242 | 243 | # Make national predictions using summation model 244 | # - Need to add batch dimension and convert to torch tensors on device 245 | sample["pvnet_outputs"] = torch.tensor(normed_preds[None]).to(device) 246 | sample["relative_capacity"] = sample["relative_capacity"][None].to(device) 247 | normed_national = self.sum_model(sample).detach().squeeze().cpu().numpy() 248 | 249 | # Convert national predictions to DataArray 250 | da_normed_national = self.to_dataarray( 251 | normed_national[np.newaxis], 252 | t0, 253 | gsp_ids=[0], 254 | output_quantiles=self.sum_model.output_quantiles, 255 | ) 256 |
257 | # Multiply normalised forecasts by capacity and clip negatives 258 | national_capacity = sample["backtest_national_capacity"] 259 | da_abs_national = da_normed_national.clip(0, None) * national_capacity 260 | 261 | # Apply sundown mask - All GSPs must be masked to mask national 262 | da_abs_national = da_abs_national.where(~da_sundown_mask.all(dim="gsp_id")).fillna(0.0) 263 | 264 | # Convert to Dataset and add attrs about the models used 265 | ds_result = xr.concat([da_abs_national, da_abs], dim="gsp_id").to_dataset(name="hindcast") 266 | ds_result.attrs.update( 267 | { 268 | "pvnet_model_name": pvnet_model_name, 269 | "pvnet_model_version": pvnet_model_version, 270 | "summation_model_name": summation_model_name, 271 | "summation_model_version": summation_model_version, 272 | } 273 | ) 274 | 275 | return ds_result 276 | 277 | def to_dataarray( 278 | self, 279 | preds: np.ndarray, 280 | t0: pd.Timestamp, 281 | gsp_ids: list[int], 282 | output_quantiles: list[float] | None, 283 | ) -> xr.DataArray: 284 | """Put a (gsp_id, step[, quantile]) array into a DataArray with a leading init_time_utc dim""" 285 | 286 | dims = ["init_time_utc", "gsp_id", "step"] 287 | coords = dict( 288 | init_time_utc=[t0], 289 | gsp_id=gsp_ids, 290 | step=self.steps, 291 | ) 292 | 293 | if output_quantiles is not None: 294 | dims.append("quantile") 295 | coords["quantile"] = output_quantiles 296 | 297 | return xr.DataArray(data=preds[np.newaxis, ...], dims=dims, coords=coords) 298 | 299 | # ------------------------------------------------------------------ 300 | # RUN 301 | 302 | if __name__=="__main__": 303 | 304 | # Set up output dir 305 | os.makedirs(output_dir) 306 | 307 | data_config_path = PVNetBaseModel.get_data_config( 308 | model_id=pvnet_model_name, 309 | revision=pvnet_model_version, 310 | ) 311 | 312 | with open(data_config_path) as file: 313 | data_config = yaml.load(file, Loader=yaml.FullLoader) 314 | 315 | data_config = populate_config_with_data_data_filepaths(data_config) 316 | data_config =
overwrite_config_dropouts(data_config) 317 | 318 | modified_data_config_filepath = f"{output_dir}/data_config.yaml" 319 | 320 | with open(modified_data_config_filepath, "w") as file: 321 | yaml.dump(data_config, file, default_flow_style=False) 322 | 323 | 324 | dataset = BacktestStreamedDataset( 325 | config_filename=modified_data_config_filepath, 326 | start_time=start_datetime, 327 | end_time=end_datetime, 328 | ) 329 | 330 | dataloader_kwargs = dict( 331 | num_workers=num_workers, 332 | prefetch_factor=2 if num_workers>0 else None, 333 | multiprocessing_context="spawn" if num_workers>0 else None, 334 | shuffle=False, 335 | batch_size=None, 336 | sampler=None, 337 | batch_sampler=None, 338 | collate_fn=None, 339 | drop_last=False, 340 | timeout=0, 341 | worker_init_fn=None, 342 | persistent_workers=False, 343 | ) 344 | 345 | if num_workers>0: 346 | dataset.presave_pickle(f"{output_dir}/dataset.pkl") 347 | 348 | dataloader = DataLoader(dataset, **dataloader_kwargs) 349 | forecaster = Forecaster() 350 | 351 | # Loop through the batches 352 | pbar = tqdm(total=len(dataloader)) 353 | for sample in dataloader: 354 | # Make predictions for the init-time 355 | ds_abs_all = forecaster.predict(sample) 356 | 357 | # Save the predictions 358 | t0 = pd.Timestamp(ds_abs_all.init_time_utc.item()) 359 | filename = f"{output_dir}/{t0}.nc" 360 | ds_abs_all.to_netcdf(filename) 361 | 362 | pbar.update() 363 | 364 | # Close down 365 | pbar.close() 366 | 367 | # Clean up 368 | if num_workers>0: 369 | os.remove(f"{output_dir}/dataset.pkl") -------------------------------------------------------------------------------- /scripts/checkpoint_to_huggingface.py: -------------------------------------------------------------------------------- 1 | """Command line tool to push locally save model checkpoints to huggingface 2 | 3 | To use this script, you will need to write a custom model card. You can copy and fill out 4 | `pvnet/model_cards/empty_model_card_template.md` to get you started. 
5 | 6 | These model cards should not be added to and version controlled in the repo since they are specific 7 | to each user. 8 | 9 | Then run using: 10 | 11 | ``` 12 | python checkpoint_to_huggingface.py "path/to/model/checkpoints" \ 13 | --huggingface-repo="openclimatefix/pvnet_uk_region" \ 14 | --wandb-repo="openclimatefix/pvnet2.1" \ 15 | --card-template-path="pvnet/models/model_cards/my_custom_model_card.md" \ 16 | --local-path="~/tmp/this_model" \ 17 | --no-push-to-hub 18 | ``` 19 | """ 20 | 21 | import tempfile 22 | 23 | import typer 24 | import wandb 25 | 26 | from pvnet.load_model import get_model_from_checkpoints 27 | 28 | app = typer.Typer(pretty_exceptions_show_locals=False) 29 | 30 | 31 | @app.command() 32 | def push_to_huggingface( 33 | checkpoint_dir_paths: list[str] = typer.Argument(...,), 34 | huggingface_repo: str = typer.Option(..., "--huggingface-repo"), 35 | wandb_repo: str = typer.Option(..., "--wandb-repo"), 36 | card_template_path: str = typer.Option(..., "--card-template-path"), 37 | wandb_ids: list[str] = typer.Option([], "--wandb-id"), 38 | val_best: bool = typer.Option(True), 39 | local_path: str = typer.Option(None, "--local-path"), 40 | push_to_hub: bool = typer.Option(True), 41 | ): 42 | """Push a local model to a huggingface model repo 43 | 44 | Args: 45 | checkpoint_dir_paths: Path(s) of the checkpoint directory(ies) 46 | huggingface_repo: Name of the HuggingFace repo to push the model to 47 | wandb_repo: Name of the wandb repo which has training logs 48 | card_template_path: Path to the model card template. 49 | wandb_ids: The wandb ID code(s) - if not filled out these are taken 50 | val_best: Use best model according to val loss, else last saved model 51 | local_path: Where to save the local copy of the model 52 | push_to_hub: Whether to push the model to the hub or just create local version. 
53 | """ 54 | 55 | assert push_to_hub or local_path is not None 56 | 57 | is_ensemble = len(checkpoint_dir_paths) > 1 58 | 59 | # Check that the wandb-IDs are correct 60 | all_wandb_ids = [run.id for run in wandb.Api().runs(path=wandb_repo)] 61 | 62 | # If the IDs are not supplied try and pull them from the checkpoint dir name 63 | if wandb_ids == []: 64 | for path in checkpoint_dir_paths: 65 | dirname = path.split("/")[-1] 66 | if dirname in all_wandb_ids: 67 | wandb_ids.append(dirname) 68 | else: 69 | raise Exception(f"Could not find wand run for {path} within {wandb_repo}") 70 | 71 | # Else if they are provided check that they exist 72 | else: 73 | for wandb_id in wandb_ids: 74 | if wandb_id not in all_wandb_ids: 75 | raise Exception(f"Could not find wand run for {path} within {wandb_repo}") 76 | 77 | ( 78 | model, 79 | model_config, 80 | data_config_path, 81 | datamodule_config_path, 82 | experiment_config_path, 83 | ) = get_model_from_checkpoints(checkpoint_dir_paths, val_best) 84 | 85 | if not is_ensemble: 86 | wandb_ids = wandb_ids[0] 87 | 88 | # Push to hub 89 | if local_path is None: 90 | temp_dir = tempfile.TemporaryDirectory() 91 | model_output_dir = temp_dir.name 92 | else: 93 | model_output_dir = local_path 94 | 95 | model.save_pretrained( 96 | save_directory=model_output_dir, 97 | model_config=model_config, 98 | data_config_path=data_config_path, 99 | datamodule_config_path=datamodule_config_path, 100 | experiment_config_path=experiment_config_path, 101 | wandb_repo=wandb_repo, 102 | wandb_ids=wandb_ids, 103 | card_template_path=card_template_path, 104 | push_to_hub=push_to_hub, 105 | hf_repo_id=huggingface_repo if push_to_hub else None, 106 | ) 107 | 108 | if local_path is None: 109 | temp_dir.cleanup() 110 | 111 | 112 | if __name__ == "__main__": 113 | app() 114 | -------------------------------------------------------------------------------- /scripts/mae_analysis.py: -------------------------------------------------------------------------------- 
1 | """ 2 | Script to generate analysis of MAE values for multiple model forecasts 3 | 4 | Does this for 48 hour horizon forecasts with 15 minute granularity 5 | 6 | """ 7 | 8 | import argparse 9 | 10 | import matplotlib 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import pandas as pd 14 | import wandb 15 | 16 | matplotlib.rcParams["axes.prop_cycle"] = matplotlib.cycler( 17 | color=[ 18 | "FFD053", # yellow 19 | "7BCDF3", # blue 20 | "63BCAF", # teal 21 | "086788", # dark blue 22 | "FF9736", # dark orange 23 | "E4E4E4", # grey 24 | "14120E", # black 25 | "FFAC5F", # orange 26 | "4C9A8E", # dark teal 27 | ] 28 | ) 29 | 30 | 31 | def main(project: str, runs: list[str], run_names: list[str]) -> None: 32 | """ 33 | Compare MAE values for multiple model forecasts for 48 hour horizon with 15 minute granularity 34 | 35 | Args: 36 | project: name of W&B project 37 | runs: W&B ids of runs 38 | run_names: user specified names for runs 39 | 40 | """ 41 | api = wandb.Api() 42 | dfs = [] 43 | epoch_num = [] 44 | for run in runs: 45 | run = api.run(f"openclimatefix/{project}/{run}") 46 | 47 | df = run.history(samples=run.lastHistoryStep + 1) 48 | # Get the columns that are in the format 'val_step_MAE/step_' 49 | mae_cols = [col for col in df.columns if "val_step_MAE/step_" in col] 50 | # Sort them 51 | mae_cols.sort() 52 | df = df[mae_cols] 53 | # Get last non-NaN value 54 | # Drop all rows with all NaNs 55 | df = df.dropna(how="all") 56 | # Select the last row 57 | # Get average across entire row, and get the IDX for the one with the smallest values 58 | min_row_mean = np.inf 59 | for idx, (row_idx, row) in enumerate(df.iterrows()): 60 | if row.mean() < min_row_mean: 61 | min_row_mean = row.mean() 62 | min_row_idx = idx 63 | df = df.iloc[min_row_idx] 64 | # Calculate the timedelta for each group 65 | # Get the step from the column name 66 | column_timesteps = [int(col.split("_")[-1].split("/")[0]) * 15 for col in mae_cols] 67 | dfs.append(df) 68 | 
epoch_num.append(min_row_idx) 69 | # Get the timedelta for each group 70 | groupings = [ 71 | [0, 0], 72 | [15, 15], 73 | [30, 45], 74 | [45, 60], 75 | [60, 120], 76 | [120, 240], 77 | [240, 360], 78 | [360, 480], 79 | [480, 720], 80 | [720, 1440], 81 | [1440, 2880], 82 | ] 83 | 84 | groups_df = [] 85 | grouping_starts = [grouping[0] for grouping in groupings] 86 | header = "| Timestep |" 87 | separator = "| --- |" 88 | for run_name in run_names: 89 | header += f" {run_name} MAE % |" 90 | separator += " --- |" 91 | print(header) 92 | print(separator) 93 | for grouping in groupings: 94 | group_string = f"| {grouping[0]}-{grouping[1]} minutes |" 95 | # Select indicies from column_timesteps that are within the grouping, inclusive 96 | group_idx = [ 97 | idx 98 | for idx, timestep in enumerate(column_timesteps) 99 | if timestep >= grouping[0] and timestep <= grouping[1] 100 | ] 101 | data_one_group = [] 102 | for df in dfs: 103 | mean_row = df.iloc[group_idx].mean() 104 | group_string += f" {mean_row:0.3f} |" 105 | data_one_group.append(mean_row) 106 | print(group_string) 107 | 108 | groups_df.append(data_one_group) 109 | 110 | groups_df = pd.DataFrame(groups_df, columns=run_names, index=grouping_starts) 111 | 112 | for idx, df in enumerate(dfs): 113 | print(f"{run_names[idx]}: {df.mean()*100:0.3f}") 114 | 115 | # Plot the error per timestep 116 | plt.figure() 117 | for idx, df in enumerate(dfs): 118 | plt.plot( 119 | column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}", linestyle="-" 120 | ) 121 | plt.legend() 122 | plt.xlabel("Timestep (minutes)") 123 | plt.ylabel("MAE %") 124 | plt.title("MAE % for each timestep") 125 | plt.savefig("mae_per_timestep.png") 126 | plt.show() 127 | 128 | # Plot the error per grouped timestep 129 | plt.figure() 130 | for idx, run_name in enumerate(run_names): 131 | plt.plot( 132 | groups_df[run_name], 133 | label=f"{run_name}, epoch: {epoch_num[idx]}", 134 | marker="o", 135 | linestyle="-", 136 | ) 137 | plt.legend() 
138 | plt.xlabel("Timestep (minutes)") 139 | plt.ylabel("MAE %") 140 | plt.title("MAE % for each grouped timestep") 141 | plt.savefig("mae_per_grouped_timestep.png") 142 | plt.show() 143 | 144 | 145 | if __name__ == "__main__": 146 | parser = argparse.ArgumentParser() 147 | parser.add_argument("--project", type=str, default="") 148 | # Add arguments that is a list of strings 149 | parser.add_argument("--list_of_runs", nargs="+") 150 | parser.add_argument("--run_names", nargs="+") 151 | args = parser.parse_args() 152 | main(args.project, args.list_of_runs, args.run_names) 153 | -------------------------------------------------------------------------------- /scripts/migrate_old_model.py: -------------------------------------------------------------------------------- 1 | """Script to migrate old PVNet models which are hosted on huggingface to current version 2 | 3 | This script can be used to update models from version >= v4.1 4 | """ 5 | import datetime 6 | import os 7 | import tempfile 8 | from importlib.metadata import version 9 | 10 | import torch 11 | import yaml 12 | from huggingface_hub import CommitOperationAdd, CommitOperationDelete, HfApi, file_exists 13 | from safetensors.torch import save_file 14 | 15 | from pvnet.models.base_model import BaseModel 16 | from pvnet.utils import DATA_CONFIG_NAME, MODEL_CARD_NAME, MODEL_CONFIG_NAME, PYTORCH_WEIGHTS_NAME 17 | 18 | # ------------------------------------------ 19 | # USER SETTINGS 20 | 21 | # The huggingface commit of the model you want to update 22 | repo_id: str = "openclimatefix/pvnet_uk_region" 23 | revision: str = "6feaa986a6bed3cc6c7961c6bf9e92fb15acca6a" 24 | 25 | # The local directory which will be downloaded to 26 | # If set to None a temporary directory will be used 27 | local_dir: str | None = None 28 | 29 | # Whether to upload the migrated model back to the huggingface - else just saved locally 30 | upload: bool = False 31 | 32 | # ------------------------------------------ 33 | # SETUP 34 | 35 | 
if local_dir is None: 36 | temp_dir = tempfile.TemporaryDirectory() 37 | save_dir = temp_dir.name 38 | 39 | else: 40 | os.makedirs(local_dir, exist_ok=False) 41 | save_dir = local_dir 42 | 43 | # Set up huggingface API 44 | api = HfApi() 45 | 46 | # Download the model repo 47 | _ = api.snapshot_download( 48 | repo_id=repo_id, 49 | revision=revision, 50 | local_dir=save_dir, 51 | force_download=True, 52 | ) 53 | 54 | # ------------------------------------------ 55 | # MIGRATION STEPS 56 | 57 | # Modify the model config 58 | with open(f"{save_dir}/{MODEL_CONFIG_NAME}") as cfg: 59 | model_config = yaml.load(cfg, Loader=yaml.FullLoader) 60 | 61 | # Get rid of the optimiser - we don't store this anymore 62 | if "optimizer" in model_config: 63 | del model_config["optimizer"] 64 | 65 | # This parameter has been moved out of the model to the pytorch lightning module 66 | if "save_validation_results_csv" in model_config: 67 | del model_config["save_validation_results_csv"] 68 | 69 | # This parameter has removed 70 | if "adapt_batches" in model_config: 71 | del model_config["adapt_batches"] 72 | 73 | # Rename the top level model 74 | if model_config["_target_"]=="pvnet.models.multimodal.multimodal.Model": 75 | model_config["_target_"] = "pvnet.models.LateFusionModel" 76 | elif model_config["_target_"] == "pvnet.models.LateFusionModel": 77 | pass 78 | else: 79 | raise Exception("Unknown model: " + model_config["_target_"]) 80 | 81 | # Re-find the model components in the new package structure 82 | if model_config.get("nwp_encoders_dict", None) is not None: 83 | for k, v in model_config["nwp_encoders_dict"].items(): 84 | v["_target_"] = ( 85 | v["_target_"] 86 | .replace("multimodal", "late_fusion") 87 | .replace("ResConv3DNet2", "ResConv3DNet") 88 | ) 89 | 90 | 91 | for component in ["sat_encoder", "pv_encoder", "output_network"]: 92 | if model_config.get(component, None) is not None: 93 | model_config[component]["_target_"] = ( 94 | model_config[component]["_target_"] 95 | 
.replace("multimodal", "late_fusion") 96 | .replace("ResConv3DNet2", "ResConv3DNet") 97 | .replace("ResFCNet2", "ResFCNet") 98 | ) 99 | 100 | with open(f"{save_dir}/{MODEL_CONFIG_NAME}", "w") as f: 101 | yaml.dump(model_config, f, sort_keys=False, default_flow_style=False) 102 | 103 | # Resave the model weights as safetensors if in old format 104 | if os.path.exists(f"{save_dir}/pytorch_model.bin"): 105 | state_dict = torch.load(f"{save_dir}/pytorch_model.bin", map_location="cpu", weights_only=True) 106 | save_file(state_dict, f"{save_dir}/{PYTORCH_WEIGHTS_NAME}") 107 | os.remove(f"{save_dir}/pytorch_model.bin") 108 | else: 109 | assert os.path.exists(f"{save_dir}/{PYTORCH_WEIGHTS_NAME}") 110 | 111 | # Add a note to the model card to say the model has been migrated 112 | with open(f"{save_dir}/{MODEL_CARD_NAME}", "a") as f: 113 | current_date = datetime.date.today().strftime("%Y-%m-%d") 114 | pvnet_version = version("pvnet") 115 | f.write( 116 | f"\n\n---\n**Migration Note**: This model was migrated on {current_date} " 117 | f"to pvnet version {pvnet_version}\n" 118 | ) 119 | 120 | # ------------------------------------------ 121 | # CHECKS 122 | 123 | # Check the model can be loaded 124 | model = BaseModel.from_pretrained(model_id=save_dir, revision=None) 125 | 126 | print("Model checkpoint successfully migrated") 127 | 128 | # ------------------------------------------ 129 | # UPLOAD TO HUGGINGFACE 130 | 131 | if upload: 132 | print("Uploading migrated model to huggingface") 133 | 134 | operations = [] 135 | for file in [MODEL_CARD_NAME, MODEL_CONFIG_NAME, PYTORCH_WEIGHTS_NAME, DATA_CONFIG_NAME]: 136 | # Stage modified files for upload 137 | operations.append( 138 | CommitOperationAdd( 139 | path_in_repo=file, # Name of the file in the repo 140 | path_or_fileobj=f"{save_dir}/{file}", # Local path to the file 141 | ), 142 | ) 143 | 144 | # Remove old pytorch weights file if it exists in the most recent commit 145 | if file_exists(repo_id, "pytorch_model.bin"): 146 
| operations.append( 147 | CommitOperationDelete(path_in_repo="pytorch_model.bin") 148 | ) 149 | 150 | commit_info = api.create_commit( 151 | repo_id=repo_id, 152 | operations=operations, 153 | commit_message=f"Migrate model (HF commit {revision[:7]}) to pvnet version {pvnet_version}", 154 | ) 155 | 156 | # Print the most recent commit hash 157 | c = api.list_repo_commits(repo_id=repo_id, repo_type="model")[0] 158 | 159 | print( 160 | f"\nThe latest commit is now: \n" 161 | f" date: {c.created_at} \n" 162 | f" commit hash: {c.commit_id}\n" 163 | f" by: {c.authors}\n" 164 | f" title: {c.title}\n" 165 | ) 166 | 167 | if local_dir is None: 168 | temp_dir.cleanup() 169 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Late fusion models""" 2 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import dask.array 4 | import hydra 5 | import pytest 6 | import pandas as pd 7 | import numpy as np 8 | import xarray as xr 9 | import torch 10 | 11 | from omegaconf import OmegaConf 12 | 13 | from ocf_data_sampler.torch_datasets.sample.site import SiteSample 14 | from ocf_data_sampler.torch_datasets.datasets import SitesDataset 15 | from ocf_data_sampler.numpy_sample.common_types import TensorBatch 16 | from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration 17 | 18 | from pvnet.datamodule import collate_fn 19 | from pvnet.datamodule import UKRegionalDataModule, SitesDataModule 20 | from pvnet.models import LateFusionModel 21 | 22 | 23 | 24 | _top_test_directory = os.path.dirname(os.path.realpath(__file__)) 25 | 26 | 27 | uk_sat_area_string = """msg_seviri_rss_3km: 28 | description: MSG SEVIRI Rapid Scanning Service area definition with 3 km resolution 29 | 
projection: 30 | proj: geos 31 | lon_0: 9.5 32 | h: 35785831 33 | x_0: 0 34 | y_0: 0 35 | a: 6378169 36 | rf: 295.488065897014 37 | no_defs: null 38 | type: crs 39 | shape: 40 | height: 298 41 | width: 615 42 | area_extent: 43 | lower_left_xy: [28503.830075263977, 5090183.970808983] 44 | upper_right_xy: [-1816744.1169023514, 4196063.827395439] 45 | units: m 46 | """ 47 | 48 | 49 | @pytest.fixture(scope="session") 50 | def session_tmp_path(tmp_path_factory): 51 | return tmp_path_factory.mktemp("data") 52 | 53 | 54 | @pytest.fixture(scope="session") 55 | def sat_zarr_path(session_tmp_path) -> str: 56 | variables = [ 57 | "IR_016", "IR_039", "IR_087", "IR_097", "IR_108", "IR_120", 58 | "IR_134", "VIS006", "VIS008", "WV_062", "WV_073", 59 | ] 60 | times = pd.date_range("2023-01-01 00:00", "2023-01-01 23:55", freq="5min") 61 | y = np.linspace(start=4191563, stop=5304712, num=100) 62 | x = np.linspace(start=15002, stop=-1824245, num=100) 63 | 64 | coords = ( 65 | ("variable", variables), 66 | ("time", times), 67 | ("y_geostationary", y), 68 | ("x_geostationary", x), 69 | ) 70 | 71 | data = dask.array.zeros( 72 | shape=tuple(len(coord_values) for _, coord_values in coords), 73 | chunks=(-1, 10, -1, -1), 74 | dtype=np.float32, 75 | ) 76 | 77 | attrs = {"area": uk_sat_area_string} 78 | 79 | ds = xr.DataArray(data=data, coords=coords, attrs=attrs).to_dataset(name="data") 80 | 81 | zarr_path = session_tmp_path / "test_sat.zarr" 82 | ds.to_zarr(zarr_path) 83 | 84 | return zarr_path 85 | 86 | 87 | @pytest.fixture(scope="session") 88 | def ukv_zarr_path(session_tmp_path) -> str: 89 | init_times = pd.date_range(start="2023-01-01 00:00", freq="180min", periods=24 * 7) 90 | variables = ["si10", "dswrf", "t", "prate"] 91 | steps = pd.timedelta_range("0h", "24h", freq="1h") 92 | x = np.linspace(-239_000, 857_000, 200) 93 | y = np.linspace(-183_000, 1425_000, 200) 94 | 95 | coords = ( 96 | ("init_time", init_times), 97 | ("variable", variables), 98 | ("step", steps), 99 | ("x", x), 
100 | ("y", y), 101 | ) 102 | 103 | data = dask.array.random.uniform( 104 | low=0, 105 | high=200, 106 | size=tuple(len(coord_values) for _, coord_values in coords), 107 | chunks=(1, -1, -1, 50, 50), 108 | ).astype(np.float32) 109 | 110 | ds = xr.DataArray(data=data, coords=coords).to_dataset(name="UKV") 111 | 112 | zarr_path = session_tmp_path / "ukv_nwp.zarr" 113 | ds.to_zarr(zarr_path) 114 | return zarr_path 115 | 116 | 117 | @pytest.fixture(scope="session") 118 | def ecmwf_zarr_path(session_tmp_path) -> str: 119 | init_times = pd.date_range(start="2023-01-01 00:00", freq="6h", periods=24 * 7) 120 | variables = ["t2m", "dswrf", "mcc"] 121 | steps = pd.timedelta_range("0h", "14h", freq="1h") 122 | lons = np.arange(-12.0, 3.0, 0.1) 123 | lats = np.arange(48.0, 65.0, 0.1) 124 | 125 | coords = ( 126 | ("init_time", init_times), 127 | ("variable", variables), 128 | ("step", steps), 129 | ("longitude", lons), 130 | ("latitude", lats), 131 | ) 132 | 133 | data = dask.array.random.uniform( 134 | low=0, 135 | high=200, 136 | size=tuple(len(coord_values) for _, coord_values in coords), 137 | chunks=(1, -1, -1, 50, 50), 138 | ).astype(np.float32) 139 | 140 | ds = xr.DataArray(data=data, coords=coords).to_dataset(name="ECMWF_UK") 141 | 142 | zarr_path = session_tmp_path / "ukv_ecmwf.zarr" 143 | ds.to_zarr(zarr_path) 144 | yield zarr_path 145 | 146 | 147 | @pytest.fixture(scope="session") 148 | def gsp_zarr_path(session_tmp_path) -> str: 149 | times = pd.date_range("2023-01-01 00:00", "2023-01-02 00:00", freq="30min") 150 | gsp_ids = np.arange(0, 318) 151 | capacity = np.ones((len(times), len(gsp_ids))) 152 | generation = np.random.uniform(0, 200, size=(len(times), len(gsp_ids))).astype(np.float32) 153 | 154 | coords = ( 155 | ("datetime_gmt", times), 156 | ("gsp_id", gsp_ids), 157 | ) 158 | 159 | ds_uk_gsp = xr.Dataset({ 160 | "capacity_mwp": xr.DataArray(capacity, coords=coords), 161 | "installedcapacity_mwp": xr.DataArray(capacity, coords=coords), 162 | "generation_mw": 
xr.DataArray(generation, coords=coords), 163 | }) 164 | 165 | zarr_path = session_tmp_path / "uk_gsp.zarr" 166 | ds_uk_gsp.to_zarr(zarr_path) 167 | return zarr_path 168 | 169 | 170 | @pytest.fixture(scope="session") 171 | def site_data_paths(session_tmp_path) -> tuple[str, str]: 172 | 173 | times = pd.date_range("2023-01-01 00:00", "2023-01-02 00:00", freq="15min") 174 | site_ids = np.arange(0, 10) 175 | 176 | capacity = np.ones(len(site_ids)) 177 | lons = np.linspace(-4, -3, len(site_ids)) 178 | lats = np.linspace(51, 52, len(site_ids)) 179 | 180 | coords = (("time_utc", times), ("site_id", site_ids)) 181 | 182 | generation_data = np.random.uniform( 183 | low=0, 184 | high=200, 185 | size=tuple(len(coord_values) for _, coord_values in coords) 186 | ).astype(np.float32) 187 | 188 | ds_gen = xr.DataArray(generation_data, coords=coords).to_dataset(name="generation_kw") 189 | 190 | df_meta = pd.DataFrame( 191 | { 192 | "site_id": site_ids, 193 | "capacity_kwp": capacity, 194 | "longitude": lons, 195 | "latitude": lats, 196 | } 197 | ) 198 | 199 | generation_data_path = session_tmp_path / f"sites_data.netcdf" 200 | metadata_path = session_tmp_path / f"sites_metadata.csv" 201 | ds_gen.to_netcdf(generation_data_path) 202 | df_meta.to_csv(metadata_path, index=False) 203 | 204 | return generation_data_path, metadata_path 205 | 206 | 207 | @pytest.fixture(scope="session") 208 | def uk_data_config_path( 209 | session_tmp_path, 210 | sat_zarr_path, 211 | ukv_zarr_path, 212 | ecmwf_zarr_path, 213 | gsp_zarr_path 214 | ) -> str: 215 | 216 | # Populate the config with the generated zarr paths 217 | config = load_yaml_configuration(f"{_top_test_directory}/test_data/uk_data_config.yaml") 218 | config.input_data.nwp["ukv"].zarr_path = str(ukv_zarr_path) 219 | config.input_data.nwp["ecmwf"].zarr_path = str(ecmwf_zarr_path) 220 | config.input_data.satellite.zarr_path = str(sat_zarr_path) 221 | config.input_data.gsp.zarr_path = str(gsp_zarr_path) 222 | 223 | filename = 
f"{session_tmp_path}/uk_data_config.yaml" 224 | save_yaml_configuration(config, filename) 225 | return filename 226 | 227 | 228 | @pytest.fixture(scope="session") 229 | def site_data_config_path( 230 | session_tmp_path, 231 | sat_zarr_path, 232 | ukv_zarr_path, 233 | ecmwf_zarr_path, 234 | site_data_paths, 235 | ) -> str: 236 | 237 | # Populate the config with the generated zarr paths 238 | config = load_yaml_configuration(f"{_top_test_directory}/test_data/site_data_config.yaml") 239 | config.input_data.nwp["ukv"].zarr_path = str(ukv_zarr_path) 240 | config.input_data.nwp["ecmwf"].zarr_path = str(ecmwf_zarr_path) 241 | config.input_data.satellite.zarr_path = str(sat_zarr_path) 242 | config.input_data.site.file_path = str(site_data_paths[0]) 243 | config.input_data.site.metadata_file_path = str(site_data_paths[1]) 244 | 245 | filename = f"{session_tmp_path}/site_data_config.yaml" 246 | save_yaml_configuration(config, filename) 247 | return filename 248 | 249 | 250 | @pytest.fixture(scope="session") 251 | def uk_streamed_datamodule(uk_data_config_path) -> UKRegionalDataModule: 252 | dm = UKRegionalDataModule( 253 | configuration=uk_data_config_path, 254 | batch_size=2, 255 | num_workers=0, 256 | prefetch_factor=None, 257 | ) 258 | dm.setup(stage="fit") 259 | return dm 260 | 261 | 262 | @pytest.fixture(scope="session") 263 | def site_streamed_datamodule(site_data_config_path) -> SitesDataModule: 264 | dm = SitesDataModule( 265 | configuration=site_data_config_path, 266 | batch_size=2, 267 | num_workers=0, 268 | prefetch_factor=None, 269 | ) 270 | dm.setup(stage="fit") 271 | return dm 272 | 273 | 274 | @pytest.fixture(scope="session") 275 | def uk_batch(uk_streamed_datamodule) -> TensorBatch: 276 | return next(iter(uk_streamed_datamodule.train_dataloader())) 277 | 278 | 279 | @pytest.fixture(scope="session") 280 | def site_batch(site_data_config_path) -> TensorBatch: 281 | dataset = SitesDataset(site_data_config_path) 282 | return 
collate_fn([SiteSample(dataset[i]).to_numpy() for i in range(2)]) 283 | 284 | 285 | @pytest.fixture(scope="session") 286 | def satellite_batch_component(uk_batch) -> torch.Tensor: 287 | return torch.swapaxes(uk_batch["satellite_actual"], 1, 2).float() 288 | 289 | 290 | @pytest.fixture() 291 | def model_minutes_kwargs() -> dict: 292 | return dict(forecast_minutes=480, history_minutes=60) 293 | 294 | 295 | @pytest.fixture() 296 | def encoder_model_kwargs() -> dict: 297 | # Used to test encoder model on satellite data 298 | return dict( 299 | sequence_length=7, # 30 minutes of 5 minutely satellite data = 7 time steps 300 | image_size_pixels=24, 301 | in_channels=11, 302 | out_features=128, 303 | ) 304 | 305 | 306 | @pytest.fixture() 307 | def site_encoder_model_kwargs() -> dict: 308 | """Used to test site encoder model on PV data with data sampler""" 309 | return dict( 310 | sequence_length=60 // 15 + 1, 311 | num_sites=1, 312 | out_features=128, 313 | target_key_to_use="site" 314 | ) 315 | 316 | 317 | @pytest.fixture() 318 | def raw_late_fusion_model_kwargs(model_minutes_kwargs) -> dict: 319 | return dict( 320 | sat_encoder=dict( 321 | _target_="pvnet.models.late_fusion.encoders.encoders3d.DefaultPVNet", 322 | _partial_=True, 323 | in_channels=11, 324 | out_features=128, 325 | number_of_conv3d_layers=6, 326 | conv3d_channels=32, 327 | image_size_pixels=24, 328 | ), 329 | nwp_encoders_dict={ 330 | "ukv": dict( 331 | _target_="pvnet.models.late_fusion.encoders.encoders3d.DefaultPVNet", 332 | _partial_=True, 333 | in_channels=4, 334 | out_features=128, 335 | number_of_conv3d_layers=6, 336 | conv3d_channels=32, 337 | image_size_pixels=24, 338 | ), 339 | "ecmwf": dict( 340 | _target_="pvnet.models.late_fusion.encoders.encoders3d.DefaultPVNet", 341 | _partial_=True, 342 | in_channels=3, 343 | out_features=128, 344 | number_of_conv3d_layers=2, 345 | stride=[1,2,2], 346 | conv3d_channels=32, 347 | image_size_pixels=12, 348 | ), 349 | }, 350 | 351 | 
add_image_embedding_channel=True, 352 | output_network=dict( 353 | _target_="pvnet.models.late_fusion.linear_networks.networks.ResFCNet", 354 | _partial_=True, 355 | fc_hidden_features=128, 356 | n_res_blocks=6, 357 | res_block_layers=2, 358 | dropout_frac=0.0, 359 | ), 360 | location_id_mapping={i:i for i in range(1, 318)}, 361 | embedding_dim=16, 362 | include_sun=True, 363 | include_gsp_yield_history=True, 364 | sat_history_minutes=30, 365 | nwp_history_minutes={"ukv": 120, "ecmwf": 120}, 366 | nwp_forecast_minutes={"ukv": 480, "ecmwf": 480}, 367 | nwp_interval_minutes={"ukv": 60, "ecmwf": 60}, 368 | min_sat_delay_minutes=0, 369 | **model_minutes_kwargs, 370 | ) 371 | 372 | 373 | @pytest.fixture() 374 | def late_fusion_model_kwargs(raw_late_fusion_model_kwargs) -> dict: 375 | return hydra.utils.instantiate(raw_late_fusion_model_kwargs) 376 | 377 | 378 | @pytest.fixture() 379 | def late_fusion_model(late_fusion_model_kwargs) -> LateFusionModel: 380 | return LateFusionModel(**late_fusion_model_kwargs) 381 | 382 | 383 | @pytest.fixture() 384 | def raw_late_fusion_model_kwargs_site_history(model_minutes_kwargs) -> dict: 385 | return dict( 386 | # Set inputs to None/False apart from site history 387 | target_key="pv", 388 | sat_encoder=None, 389 | nwp_encoders_dict=None, 390 | add_image_embedding_channel=False, 391 | pv_encoder=None, 392 | output_network=dict( 393 | _target_="pvnet.models.late_fusion.linear_networks.networks.ResFCNet", 394 | _partial_=True, 395 | fc_hidden_features=128, 396 | n_res_blocks=6, 397 | res_block_layers=2, 398 | dropout_frac=0.0, 399 | ), 400 | location_id_mapping=None, 401 | embedding_dim=None, 402 | include_sun=False, 403 | include_time=True, 404 | include_gsp_yield_history=False, 405 | include_site_yield_history=True, 406 | forecast_minutes=480, 407 | history_minutes=60, 408 | interval_minutes=15, 409 | ) 410 | 411 | 412 | @pytest.fixture() 413 | def late_fusion_model_kwargs_site_history(raw_late_fusion_model_kwargs_site_history) -> 
dict: 414 | return hydra.utils.instantiate(raw_late_fusion_model_kwargs_site_history) 415 | 416 | 417 | @pytest.fixture() 418 | def late_fusion_model_site_history(late_fusion_model_kwargs_site_history) -> LateFusionModel: 419 | return LateFusionModel(**late_fusion_model_kwargs_site_history) 420 | 421 | 422 | @pytest.fixture() 423 | def late_fusion_quantile_model(late_fusion_model_kwargs) -> LateFusionModel: 424 | return LateFusionModel(output_quantiles=[0.1, 0.5, 0.9], **late_fusion_model_kwargs) 425 | 426 | 427 | @pytest.fixture 428 | def trainer_cfg(): 429 | def _make(trainer_dict): 430 | return OmegaConf.create({"trainer": trainer_dict}) 431 | return _make 432 | -------------------------------------------------------------------------------- /tests/models/late_fusion/encoders/test_encoders3d.py: -------------------------------------------------------------------------------- 1 | from pvnet.models.late_fusion.encoders.encoders3d import DefaultPVNet, ResConv3DNet 2 | 3 | 4 | def _test_model_forward(batch, model_class, model_kwargs): 5 | model = model_class(**model_kwargs) 6 | y = model(batch) 7 | assert tuple(y.shape) == (2, model_kwargs["out_features"]), y.shape 8 | 9 | 10 | def _test_model_backward(batch, model_class, model_kwargs): 11 | model = model_class(**model_kwargs) 12 | y = model(batch) 13 | # Backwards on sum drives sum to zero 14 | y.sum().backward() 15 | 16 | 17 | # Test model forward on all models 18 | def test_defaultpvnet_forward(satellite_batch_component, encoder_model_kwargs): 19 | _test_model_forward(satellite_batch_component, DefaultPVNet, encoder_model_kwargs) 20 | 21 | 22 | def test_resconv3dnet_forward(satellite_batch_component, encoder_model_kwargs): 23 | _test_model_forward(satellite_batch_component, ResConv3DNet, encoder_model_kwargs) 24 | 25 | 26 | # Test model backward on all models 27 | def test_defaultpvnet_backward(satellite_batch_component, encoder_model_kwargs): 28 | _test_model_backward(satellite_batch_component, DefaultPVNet, 
encoder_model_kwargs) 29 | 30 | 31 | def test_resconv3dnet_backward(satellite_batch_component, encoder_model_kwargs): 32 | _test_model_backward(satellite_batch_component, ResConv3DNet, encoder_model_kwargs) 33 | -------------------------------------------------------------------------------- /tests/models/late_fusion/linear_networks/test_networks.py: -------------------------------------------------------------------------------- 1 | from pvnet.models.late_fusion.linear_networks.networks import ResFCNet 2 | import pytest 3 | import torch 4 | from collections import OrderedDict 5 | 6 | 7 | @pytest.fixture() 8 | def simple_linear_batch(): 9 | return torch.rand(2, 100) 10 | 11 | 12 | @pytest.fixture() 13 | def late_fusion_linear_batch(): 14 | return OrderedDict(nwp=torch.rand(2, 50), sat=torch.rand(2, 40), sun=torch.rand(2, 10)) 15 | 16 | 17 | @pytest.fixture() 18 | def multiple_batch_types(simple_linear_batch, late_fusion_linear_batch): 19 | return [simple_linear_batch, late_fusion_linear_batch] 20 | 21 | 22 | @pytest.fixture() 23 | def fc_batch_batch(): 24 | return torch.rand(2, 100) 25 | 26 | 27 | @pytest.fixture() 28 | def linear_network_kwargs(): 29 | return dict(in_features=100, out_features=10) 30 | 31 | 32 | def _test_model_forward(batches, model_class, model_kwargs): 33 | for batch in batches: 34 | model = model_class(**model_kwargs) 35 | y = model(batch) 36 | assert tuple(y.shape) == (2, model_kwargs["out_features"]), y.shape 37 | 38 | 39 | def _test_model_backward(batch, model_class, model_kwargs): 40 | model = model_class(**model_kwargs) 41 | y = model(batch) 42 | # Backwards on sum drives sum to zero 43 | y.sum().backward() 44 | 45 | 46 | # Test model forward on all models 47 | def test_resfcnet_forward(multiple_batch_types, linear_network_kwargs): 48 | _test_model_forward(multiple_batch_types, ResFCNet, linear_network_kwargs) 49 | 50 | 51 | def test_resfcnet_backward(simple_linear_batch, linear_network_kwargs): 52 | 
_test_model_backward(simple_linear_batch, ResFCNet, linear_network_kwargs) 53 | -------------------------------------------------------------------------------- /tests/models/late_fusion/site_encoders/test_encoders.py: -------------------------------------------------------------------------------- 1 | from pvnet.models.late_fusion.site_encoders.encoders import SingleAttentionNetwork 2 | 3 | 4 | def _test_model_forward(batch, model_class, kwargs, batch_size): 5 | model = model_class(**kwargs) 6 | y = model(batch) 7 | assert tuple(y.shape) == (batch_size, kwargs["out_features"]), y.shape 8 | 9 | 10 | def _test_model_backward(batch, model_class, kwargs): 11 | model = model_class(**kwargs) 12 | y = model(batch) 13 | # Backwards on sum drives sum to zero 14 | y.sum().backward() 15 | 16 | 17 | def test_singleattentionnetwork_forward(site_batch, site_encoder_model_kwargs): 18 | _test_model_forward( 19 | site_batch, 20 | SingleAttentionNetwork, 21 | site_encoder_model_kwargs, 22 | batch_size=2, 23 | ) 24 | 25 | 26 | def test_singleattentionnetwork_backward(site_batch, site_encoder_model_kwargs): 27 | _test_model_backward(site_batch, SingleAttentionNetwork, site_encoder_model_kwargs) 28 | -------------------------------------------------------------------------------- /tests/models/late_fusion/test_late_fusion.py: -------------------------------------------------------------------------------- 1 | def test_model_forward(late_fusion_model, uk_batch): 2 | y = late_fusion_model(uk_batch) 3 | 4 | # Check output is the correct shape: [batch size=2, forecast_len=16] 5 | assert tuple(y.shape) == (2, 16), y.shape 6 | 7 | def test_model_forward_site_history(late_fusion_model_site_history, site_batch): 8 | 9 | y = late_fusion_model_site_history(site_batch) 10 | 11 | # Check output is the correct shape: [batch size=2, forecast_len=32] 12 | assert tuple(y.shape) == (2, 32), y.shape 13 | 14 | 15 | def test_model_backward(late_fusion_model, uk_batch): 16 | y = 
late_fusion_model(uk_batch) 17 | 18 | # Backwards on sum drives sum to zero 19 | y.sum().backward() 20 | 21 | 22 | def test_quantile_model_forward(late_fusion_quantile_model, uk_batch): 23 | y_quantiles = late_fusion_quantile_model(uk_batch) 24 | 25 | # Check output is the correct shape: [batch size=2, forecast_len=16, num_quantiles=3] 26 | assert tuple(y_quantiles.shape) == (2, 16, 3), y_quantiles.shape 27 | 28 | 29 | def test_quantile_model_backward(late_fusion_quantile_model, uk_batch): 30 | 31 | y_quantiles = late_fusion_quantile_model(uk_batch) 32 | 33 | # Backwards on sum drives sum to zero 34 | y_quantiles.sum().backward() 35 | -------------------------------------------------------------------------------- /tests/models/late_fusion/test_save_load_pretrained.py: -------------------------------------------------------------------------------- 1 | from importlib.metadata import version 2 | from pvnet.models import BaseModel 3 | import pvnet.model_cards 4 | 5 | 6 | card_path = f"{pvnet.model_cards.__path__[0]}/empty_model_card_template.md" 7 | 8 | 9 | def test_save_pretrained( 10 | tmp_path, 11 | late_fusion_model, 12 | raw_late_fusion_model_kwargs, 13 | uk_data_config_path 14 | ): 15 | 16 | # Construct the model config 17 | model_config = { 18 | "_target_": "pvnet.models.LateFusionModel", 19 | **raw_late_fusion_model_kwargs, 20 | } 21 | 22 | # Save the model 23 | model_output_dir = f"{tmp_path}/saved_model" 24 | late_fusion_model.save_pretrained( 25 | save_directory=model_output_dir, 26 | model_config=model_config, 27 | data_config_path=uk_data_config_path, 28 | wandb_repo="test", 29 | wandb_ids="abc", 30 | card_template_path=card_path, 31 | push_to_hub=False, 32 | ) 33 | 34 | # Load the model 35 | _ = BaseModel.from_pretrained(model_id=model_output_dir, revision=None) 36 | 37 | 38 | def test_create_hugging_face_model_card(): 39 | 40 | # Create Hugging Face ModelCard 41 | card = BaseModel.create_hugging_face_model_card(card_path, wandb_repo="test", 
wandb_ids="abc") 42 | 43 | # Extract the card markdown 44 | card_markdown = card.content 45 | 46 | # Regex to find if the pvnet and ocf-data-sampler versions are present 47 | pvnet_version = version("pvnet") 48 | has_pvnet = f"pvnet=={pvnet_version}" in card_markdown 49 | 50 | ocf_sampler_version = version("ocf-data-sampler") 51 | has_ocf_data_sampler= f"ocf-data-sampler=={ocf_sampler_version}" in card_markdown 52 | 53 | assert has_pvnet, f"The hugging face card created does not display the PVNet package version" 54 | assert has_ocf_data_sampler, f"The hugging face card created does not display the ocf-data-sampler package version" 55 | -------------------------------------------------------------------------------- /tests/models/test_ensemble.py: -------------------------------------------------------------------------------- 1 | from pvnet.models.ensemble import Ensemble 2 | 3 | 4 | def test_model_init(late_fusion_model): 5 | # Without weighting 6 | ensemble_model = Ensemble(model_list=[late_fusion_model] * 3, weights=None) 7 | 8 | # With weighting 9 | ensemble_model = Ensemble(model_list=[late_fusion_model] * 3, weights=[1, 2, 3]) 10 | 11 | 12 | def test_model_forward(late_fusion_model, uk_batch): 13 | ensemble_model = Ensemble(model_list=[late_fusion_model] * 3) 14 | 15 | y = ensemble_model(uk_batch) 16 | 17 | # Check output is the correct shape: [batch size=2, forecast_len=16] 18 | assert tuple(y.shape) == (2, 16), y.shape 19 | 20 | 21 | def test_quantile_model_forward(late_fusion_quantile_model, uk_batch): 22 | ensemble_model = Ensemble(model_list=[late_fusion_quantile_model] * 3) 23 | 24 | y_quantiles = ensemble_model(uk_batch) 25 | 26 | # Check output is the correct shape: [batch size=2, forecast_len=16, num_quantiles=3] 27 | assert tuple(y_quantiles.shape) == (2, 16, 3), y_quantiles.shape 28 | -------------------------------------------------------------------------------- /tests/models/test_validation.py: 
-------------------------------------------------------------------------------- 1 | """Tests for model and trainer configuration validation utilities.""" 2 | 3 | import pytest 4 | import torch 5 | 6 | from pvnet.utils import validate_batch_against_config, validate_gpu_config 7 | 8 | 9 | def test_validate_batch_against_config( 10 | uk_batch: dict, 11 | late_fusion_model, 12 | ): 13 | """Test batch validation utility function.""" 14 | # This should pass as full uk_batch is valid 15 | validate_batch_against_config(batch=uk_batch, model=late_fusion_model) 16 | 17 | 18 | def test_validate_batch_against_config_raises_error(late_fusion_model): 19 | """Test that the validation raises an error for a mismatched batch.""" 20 | # Create batch that is missing required NWP data 21 | minimal_batch = {"gsp": torch.randn(2, 17)} 22 | with pytest.raises( 23 | ValueError, 24 | match="Model configured with 'nwp_encoders_dict' but 'nwp' data missing from batch.", 25 | ): 26 | validate_batch_against_config(batch=minimal_batch, model=late_fusion_model) 27 | 28 | 29 | @pytest.mark.parametrize( 30 | "trainer", 31 | [ 32 | {"devices": 1}, 33 | {"devices": [0]}, 34 | {"accelerator": "cpu"}, 35 | ], 36 | ids=["devices=1", "devices=[0]", "accelerator=cpu"], 37 | ) 38 | def test_validate_gpu_config_single_device(trainer_cfg, trainer): 39 | """Accept single GPU or explicit CPU configurations.""" 40 | validate_gpu_config(trainer_cfg(trainer)) 41 | 42 | 43 | @pytest.mark.parametrize( 44 | "trainer", 45 | [ 46 | {"devices": 2}, 47 | ], 48 | ids=["devices=2"], 49 | ) 50 | def test_validate_gpu_config_multiple_devices(trainer_cfg, trainer): 51 | """Reject accidental multi-GPU setups.""" 52 | with pytest.raises(ValueError, match="Parallel training not supported"): 53 | validate_gpu_config(trainer_cfg(trainer)) 54 | -------------------------------------------------------------------------------- /tests/test_data/site_data_config.yaml: 
-------------------------------------------------------------------------------- 1 | general: 2 | description: Test config for PVNet 3 | name: pvnet_test 4 | 5 | input_data: 6 | 7 | site: 8 | file_path: set_in_temp_file 9 | metadata_file_path: set_in_temp_file 10 | interval_start_minutes: -60 11 | interval_end_minutes: 480 12 | time_resolution_minutes: 15 13 | dropout_timedeltas_minutes: [] 14 | dropout_fraction: 0 15 | 16 | nwp: 17 | ukv: 18 | provider: ukv 19 | zarr_path: set_in_temp_file 20 | interval_start_minutes: -120 21 | interval_end_minutes: 480 22 | time_resolution_minutes: 60 23 | channels: ["si10", "dswrf", "t", "prate"] 24 | image_size_pixels_height: 24 25 | image_size_pixels_width: 24 26 | dropout_timedeltas_minutes: [-180] 27 | dropout_fraction: 1.0 28 | max_staleness_minutes: null 29 | normalisation_constants: 30 | si10: 31 | mean: 1 32 | std: 1 33 | dswrf: 34 | mean: 1 35 | std: 1 36 | t: 37 | mean: 1 38 | std: 1 39 | prate: 40 | mean: 1 41 | std: 1 42 | 43 | ecmwf: 44 | provider: ecmwf 45 | zarr_path: set_in_temp_file 46 | interval_start_minutes: -120 47 | interval_end_minutes: 480 48 | time_resolution_minutes: 60 49 | channels: ["t2m", "dswrf", "mcc"] 50 | image_size_pixels_height: 12 51 | image_size_pixels_width: 12 52 | dropout_timedeltas_minutes: [-180] 53 | dropout_fraction: 1.0 54 | max_staleness_minutes: null 55 | normalisation_constants: 56 | t2m: 57 | mean: 1 58 | std: 1 59 | dswrf: 60 | mean: 1 61 | std: 1 62 | mcc: 63 | mean: 1 64 | std: 1 65 | 66 | satellite: 67 | zarr_path: set_in_temp_file 68 | interval_start_minutes: -30 69 | interval_end_minutes: 0 70 | time_resolution_minutes: 5 71 | 72 | image_size_pixels_height: 24 73 | image_size_pixels_width: 24 74 | dropout_timedeltas_minutes: [] 75 | dropout_fraction: 0 76 | 77 | channels: 78 | - IR_016 79 | - IR_039 80 | - IR_087 81 | - IR_097 82 | - IR_108 83 | - IR_120 84 | - IR_134 85 | - VIS006 86 | - VIS008 87 | - WV_062 88 | - WV_073 89 | 90 | normalisation_constants: 91 | IR_016: 92 
| mean: 1 93 | std: 1 94 | IR_039: 95 | mean: 1 96 | std: 1 97 | IR_087: 98 | mean: 1 99 | std: 1 100 | IR_097: 101 | mean: 1 102 | std: 1 103 | IR_108: 104 | mean: 1 105 | std: 1 106 | IR_120: 107 | mean: 1 108 | std: 1 109 | IR_134: 110 | mean: 1 111 | std: 1 112 | VIS006: 113 | mean: 1 114 | std: 1 115 | VIS008: 116 | mean: 1 117 | std: 1 118 | WV_062: 119 | mean: 1 120 | std: 1 121 | WV_073: 122 | mean: 1 123 | std: 1 124 | 125 | solar_position: 126 | interval_start_minutes: -60 127 | interval_end_minutes: 480 128 | time_resolution_minutes: 15 129 | -------------------------------------------------------------------------------- /tests/test_data/uk_data_config.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | description: Test config for PVNet 3 | name: pvnet_test 4 | 5 | input_data: 6 | 7 | gsp: 8 | zarr_path: set_in_temp_file 9 | interval_start_minutes: -60 10 | interval_end_minutes: 480 11 | time_resolution_minutes: 30 12 | dropout_timedeltas_minutes: [] 13 | dropout_fraction: 0 14 | 15 | nwp: 16 | ukv: 17 | provider: ukv 18 | zarr_path: set_in_temp_file 19 | interval_start_minutes: -120 20 | interval_end_minutes: 480 21 | time_resolution_minutes: 60 22 | channels: ["si10", "dswrf", "t", "prate"] 23 | image_size_pixels_height: 24 24 | image_size_pixels_width: 24 25 | dropout_timedeltas_minutes: [-180] 26 | dropout_fraction: 1.0 27 | max_staleness_minutes: null 28 | normalisation_constants: 29 | si10: 30 | mean: 1 31 | std: 1 32 | dswrf: 33 | mean: 1 34 | std: 1 35 | t: 36 | mean: 1 37 | std: 1 38 | prate: 39 | mean: 1 40 | std: 1 41 | 42 | ecmwf: 43 | provider: ecmwf 44 | zarr_path: set_in_temp_file 45 | interval_start_minutes: -120 46 | interval_end_minutes: 480 47 | time_resolution_minutes: 60 48 | channels: ["t2m", "dswrf", "mcc"] 49 | image_size_pixels_height: 12 50 | image_size_pixels_width: 12 51 | dropout_timedeltas_minutes: [-180] 52 | dropout_fraction: 1.0 53 | max_staleness_minutes: null 54 | 
normalisation_constants: 55 | t2m: 56 | mean: 1 57 | std: 1 58 | dswrf: 59 | mean: 1 60 | std: 1 61 | mcc: 62 | mean: 1 63 | std: 1 64 | 65 | satellite: 66 | zarr_path: set_in_temp_file 67 | interval_start_minutes: -30 68 | interval_end_minutes: 0 69 | time_resolution_minutes: 5 70 | 71 | image_size_pixels_height: 24 72 | image_size_pixels_width: 24 73 | dropout_timedeltas_minutes: [] 74 | dropout_fraction: 0 75 | 76 | channels: 77 | - IR_016 78 | - IR_039 79 | - IR_087 80 | - IR_097 81 | - IR_108 82 | - IR_120 83 | - IR_134 84 | - VIS006 85 | - VIS008 86 | - WV_062 87 | - WV_073 88 | 89 | normalisation_constants: 90 | IR_016: 91 | mean: 1 92 | std: 1 93 | IR_039: 94 | mean: 1 95 | std: 1 96 | IR_087: 97 | mean: 1 98 | std: 1 99 | IR_097: 100 | mean: 1 101 | std: 1 102 | IR_108: 103 | mean: 1 104 | std: 1 105 | IR_120: 106 | mean: 1 107 | std: 1 108 | IR_134: 109 | mean: 1 110 | std: 1 111 | VIS006: 112 | mean: 1 113 | std: 1 114 | VIS008: 115 | mean: 1 116 | std: 1 117 | WV_062: 118 | mean: 1 119 | std: 1 120 | WV_073: 121 | mean: 1 122 | std: 1 123 | 124 | solar_position: 125 | interval_start_minutes: -60 126 | interval_end_minutes: 480 127 | time_resolution_minutes: 30 128 | -------------------------------------------------------------------------------- /tests/test_datamodule.py: -------------------------------------------------------------------------------- 1 | from pvnet.datamodule import SitesDataModule 2 | 3 | 4 | 5 | def test_sites_data_module(site_data_config_path): 6 | """Test SitesDataModule initialization""" 7 | 8 | _ = SitesDataModule( 9 | configuration=site_data_config_path, 10 | batch_size=2, 11 | num_workers=0, 12 | prefetch_factor=None, 13 | train_period=[None, None], 14 | val_period=[None, None], 15 | ) -------------------------------------------------------------------------------- /tests/test_end2end.py: -------------------------------------------------------------------------------- 1 | import lightning 2 | 3 | from pvnet.datamodule import 
UKRegionalDataModule 4 | from pvnet.optimizers import EmbAdamWReduceLROnPlateau 5 | from pvnet.training.lightning_module import PVNetLightningModule 6 | 7 | 8 | def test_model_trainer_fit(session_tmp_path, uk_data_config_path, late_fusion_model): 9 | """Test end-to-end training.""" 10 | 11 | datamodule = UKRegionalDataModule( 12 | configuration=uk_data_config_path, 13 | batch_size=2, 14 | num_workers=2, 15 | prefetch_factor=None, 16 | dataset_pickle_dir=f"{session_tmp_path}/dataset_pickles" 17 | ) 18 | 19 | lightning_model = PVNetLightningModule( 20 | model=late_fusion_model, 21 | optimizer=EmbAdamWReduceLROnPlateau(), 22 | ) 23 | 24 | # Train the model for two batches 25 | trainer = lightning.Trainer( 26 | max_epochs=2, 27 | limit_val_batches=2, 28 | limit_train_batches=2, 29 | accelerator="cpu", 30 | logger=False, 31 | enable_checkpointing=False, 32 | ) 33 | trainer.fit(model=lightning_model, datamodule=datamodule) 34 | --------------------------------------------------------------------------------